add classification
This commit is contained in:
parent
91e070e623
commit
92d5bda932
|
@ -0,0 +1,199 @@
|
||||||
|
# Created by https://www.toptal.com/developers/gitignore/api/visualstudiocode,python
|
||||||
|
# Edit at https://www.toptal.com/developers/gitignore?templates=visualstudiocode,python
|
||||||
|
|
||||||
|
### others ###
|
||||||
|
data/
|
||||||
|
experiment/
|
||||||
|
|
||||||
|
### Python ###
|
||||||
|
# Byte-compiled / optimized / DLL files
|
||||||
|
__pycache__/
|
||||||
|
*.py[cod]
|
||||||
|
*$py.class
|
||||||
|
|
||||||
|
# C extensions
|
||||||
|
*.so
|
||||||
|
|
||||||
|
# Distribution / packaging
|
||||||
|
.Python
|
||||||
|
build/
|
||||||
|
develop-eggs/
|
||||||
|
dist/
|
||||||
|
downloads/
|
||||||
|
eggs/
|
||||||
|
.eggs/
|
||||||
|
lib/
|
||||||
|
lib64/
|
||||||
|
parts/
|
||||||
|
sdist/
|
||||||
|
var/
|
||||||
|
wheels/
|
||||||
|
share/python-wheels/
|
||||||
|
*.egg-info/
|
||||||
|
.installed.cfg
|
||||||
|
*.egg
|
||||||
|
MANIFEST
|
||||||
|
|
||||||
|
# PyInstaller
|
||||||
|
# Usually these files are written by a python script from a template
|
||||||
|
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||||
|
*.manifest
|
||||||
|
*.spec
|
||||||
|
|
||||||
|
# Installer logs
|
||||||
|
pip-log.txt
|
||||||
|
pip-delete-this-directory.txt
|
||||||
|
|
||||||
|
# Unit test / coverage reports
|
||||||
|
htmlcov/
|
||||||
|
.tox/
|
||||||
|
.nox/
|
||||||
|
.coverage
|
||||||
|
.coverage.*
|
||||||
|
.cache
|
||||||
|
nosetests.xml
|
||||||
|
coverage.xml
|
||||||
|
*.cover
|
||||||
|
*.py,cover
|
||||||
|
.hypothesis/
|
||||||
|
.pytest_cache/
|
||||||
|
cover/
|
||||||
|
|
||||||
|
# Translations
|
||||||
|
*.mo
|
||||||
|
*.pot
|
||||||
|
|
||||||
|
# Django stuff:
|
||||||
|
*.log
|
||||||
|
local_settings.py
|
||||||
|
db.sqlite3
|
||||||
|
db.sqlite3-journal
|
||||||
|
|
||||||
|
# Flask stuff:
|
||||||
|
instance/
|
||||||
|
.webassets-cache
|
||||||
|
|
||||||
|
# Scrapy stuff:
|
||||||
|
.scrapy
|
||||||
|
|
||||||
|
# Sphinx documentation
|
||||||
|
docs/_build/
|
||||||
|
|
||||||
|
# PyBuilder
|
||||||
|
.pybuilder/
|
||||||
|
target/
|
||||||
|
|
||||||
|
# Jupyter Notebook
|
||||||
|
.ipynb_checkpoints
|
||||||
|
|
||||||
|
# IPython
|
||||||
|
profile_default/
|
||||||
|
ipython_config.py
|
||||||
|
|
||||||
|
# pyenv
|
||||||
|
# For a library or package, you might want to ignore these files since the code is
|
||||||
|
# intended to run in multiple environments; otherwise, check them in:
|
||||||
|
# .python-version
|
||||||
|
|
||||||
|
# pipenv
|
||||||
|
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||||
|
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||||
|
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||||
|
# install all needed dependencies.
|
||||||
|
#Pipfile.lock
|
||||||
|
|
||||||
|
# poetry
|
||||||
|
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
||||||
|
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
||||||
|
# commonly ignored for libraries.
|
||||||
|
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
||||||
|
#poetry.lock
|
||||||
|
|
||||||
|
# pdm
|
||||||
|
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
||||||
|
#pdm.lock
|
||||||
|
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
||||||
|
# in version control.
|
||||||
|
# https://pdm.fming.dev/#use-with-ide
|
||||||
|
.pdm.toml
|
||||||
|
|
||||||
|
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
||||||
|
__pypackages__/
|
||||||
|
|
||||||
|
# Celery stuff
|
||||||
|
celerybeat-schedule
|
||||||
|
celerybeat.pid
|
||||||
|
|
||||||
|
# SageMath parsed files
|
||||||
|
*.sage.py
|
||||||
|
|
||||||
|
# Environments
|
||||||
|
.env
|
||||||
|
.venv
|
||||||
|
env/
|
||||||
|
venv/
|
||||||
|
ENV/
|
||||||
|
env.bak/
|
||||||
|
venv.bak/
|
||||||
|
|
||||||
|
# Spyder project settings
|
||||||
|
.spyderproject
|
||||||
|
.spyproject
|
||||||
|
|
||||||
|
# Rope project settings
|
||||||
|
.ropeproject
|
||||||
|
|
||||||
|
# mkdocs documentation
|
||||||
|
/site
|
||||||
|
|
||||||
|
# mypy
|
||||||
|
.mypy_cache/
|
||||||
|
.dmypy.json
|
||||||
|
dmypy.json
|
||||||
|
|
||||||
|
# Pyre type checker
|
||||||
|
.pyre/
|
||||||
|
|
||||||
|
# pytype static type analyzer
|
||||||
|
.pytype/
|
||||||
|
|
||||||
|
# Cython debug symbols
|
||||||
|
cython_debug/
|
||||||
|
|
||||||
|
# PyCharm
|
||||||
|
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
||||||
|
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
||||||
|
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||||
|
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||||
|
#.idea/
|
||||||
|
|
||||||
|
### Python Patch ###
|
||||||
|
# Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration
|
||||||
|
poetry.toml
|
||||||
|
|
||||||
|
# ruff
|
||||||
|
.ruff_cache/
|
||||||
|
|
||||||
|
# LSP config files
|
||||||
|
pyrightconfig.json
|
||||||
|
|
||||||
|
### VisualStudioCode ###
|
||||||
|
.vscode/*
|
||||||
|
!.vscode/settings.json
|
||||||
|
!.vscode/tasks.json
|
||||||
|
!.vscode/launch.json
|
||||||
|
!.vscode/extensions.json
|
||||||
|
!.vscode/*.code-snippets
|
||||||
|
|
||||||
|
# Local History for Visual Studio Code
|
||||||
|
.history/
|
||||||
|
|
||||||
|
# Built Visual Studio Code Extensions
|
||||||
|
*.vsix
|
||||||
|
|
||||||
|
### VisualStudioCode Patch ###
|
||||||
|
# Ignore all local history of files
|
||||||
|
.history
|
||||||
|
.ionide
|
||||||
|
|
||||||
|
# End of https://www.toptal.com/developers/gitignore/api/visualstudiocode,python
|
|
@ -1,7 +1,12 @@
|
||||||
# languageM
|
# languageM
|
||||||
|
|
||||||
*Learn how to fit LLM(Large Language Model)*
|
*Learn how to fit LLMs (Large Language Models)*
|
||||||
|
|
||||||
## Base
|
## Base
|
||||||
|
|
||||||
* Model: [Google/Gemma](https://blog.google/technology/developers/gemma-open-models/) and [THUDM/GLM](https://github.com/THUDM/GLM)
|
* Models: [Google/Gemma](https://blog.google/technology/developers/gemma-open-models/) and [THUDM/GLM](https://github.com/THUDM/GLM)
|
||||||
|
* Packages: [transformers](https://huggingface.co/docs/transformers/index), [datasets](https://huggingface.co/docs/datasets/index), [trl](https://huggingface.co/docs/trl/index), [peft](https://huggingface.co/docs/peft/index)
|
||||||
|
|
||||||
|
## HOW-TO
|
||||||
|
|
||||||
|
1. Preparation: `python -m pip install -r requirements.txt`
|
|
@ -0,0 +1,364 @@
|
||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 1,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"import sys\n",
|
||||||
|
"sys.path.append(\"../src\")\n",
|
||||||
|
"from utils import *"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 2,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"'''\n",
|
||||||
|
"data source: https://www.kaggle.com/competitions/ml-olympiad-toxic-language-ptbr-detection\n",
|
||||||
|
"'''\n",
|
||||||
|
"\n",
|
||||||
|
"dataset = load_dataset(\"csv\", data_files=\"../data/PTBR/train.csv\")[\"train\"]\n",
|
||||||
|
"testset = load_dataset(\"csv\", data_files=\"../data/PTBR/test.csv\")[\"train\"]\n",
|
||||||
|
"\n",
|
||||||
|
"dataset = dataset.shuffle().train_test_split(0.2)\n",
|
||||||
|
"trainset, validset = dataset[\"train\"], dataset[\"test\"]\n",
|
||||||
|
"dataset = DatasetDict({\"train\": trainset, \"valid\": validset, \"test\": testset})"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 3,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"MODEL_ID = \"google/gemma-7b\"\n",
|
||||||
|
"\n",
|
||||||
|
"SEED = 42\n",
|
||||||
|
"NUM_CLASSES = 2\n",
|
||||||
|
"\n",
|
||||||
|
"EPOCHS = 2\n",
|
||||||
|
"\n",
|
||||||
|
"LORA_R = 8\n",
|
||||||
|
"LORA_ALPHA = 16\n",
|
||||||
|
"LORA_DROPOUT = 0.1\n",
|
||||||
|
"BATCH_SIZE = 2\n",
|
||||||
|
"\n",
|
||||||
|
"set_seed(SEED)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 4,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"application/vnd.jupyter.widget-view+json": {
|
||||||
|
"model_id": "592e2aefe76641e3a6fc22affb1781c7",
|
||||||
|
"version_major": 2,
|
||||||
|
"version_minor": 0
|
||||||
|
},
|
||||||
|
"text/plain": [
|
||||||
|
"Map: 0%| | 0/13440 [00:00<?, ? examples/s]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "display_data"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"application/vnd.jupyter.widget-view+json": {
|
||||||
|
"model_id": "c0227aeeb1194900a3eebd50230363bd",
|
||||||
|
"version_major": 2,
|
||||||
|
"version_minor": 0
|
||||||
|
},
|
||||||
|
"text/plain": [
|
||||||
|
"Map: 0%| | 0/3360 [00:00<?, ? examples/s]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "display_data"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)\n",
|
||||||
|
"\n",
|
||||||
|
"dataset[\"test\"] = dataset[\"test\"].map(lambda samples: tokenizer(samples[\"text\"], max_length=512, truncation=True), batched=True)\n",
|
||||||
|
"dataset[\"train\"] = dataset[\"train\"].map(lambda samples: tokenizer(samples[\"text\"], max_length=512, truncation=True), batched=True)\n",
|
||||||
|
"dataset[\"valid\"] = dataset[\"valid\"].map(lambda samples: tokenizer(samples[\"text\"], max_length=512, truncation=True), batched=True)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 5,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Gemma's activation function should be approximate GeLU and not exact GeLU.\n",
|
||||||
|
"Changing the activation function to `gelu_pytorch_tanh`.if you want to use the legacy `gelu`, edit the `model.config` to set `hidden_activation=gelu` instead of `hidden_act`. See https://github.com/huggingface/transformers/pull/29402 for more details.\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"application/vnd.jupyter.widget-view+json": {
|
||||||
|
"model_id": "1adbb0ec04254d7dadd0552b61449ba9",
|
||||||
|
"version_major": 2,
|
||||||
|
"version_minor": 0
|
||||||
|
},
|
||||||
|
"text/plain": [
|
||||||
|
"Loading checkpoint shards: 0%| | 0/2 [00:00<?, ?it/s]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "display_data"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Some weights of GemmaForSequenceClassification were not initialized from the model checkpoint at google/gemma-2b and are newly initialized: ['score.weight']\n",
|
||||||
|
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"trainable params: 9,826,320 || all params: 2,516,019,232 || trainable%: 0.3905502738223902\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"lora_config = LoraConfig(r = LORA_R, lora_alpha = LORA_ALPHA, lora_dropout = LORA_DROPOUT, task_type = TaskType.SEQ_CLS, target_modules = \"all-linear\")\n",
|
||||||
|
"model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID, device_map=\"auto\", num_labels=NUM_CLASSES)\n",
|
||||||
|
"lora_model = get_peft_model(model, lora_config)\n",
|
||||||
|
"lora_model.print_trainable_parameters()"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 6,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"def compute_metrics(eval_pred):\n",
|
||||||
|
" predictions, labels = eval_pred\n",
|
||||||
|
" predictions = np.argmax(predictions, axis=1)\n",
|
||||||
|
" return {\"accuracy\": (predictions == labels).mean()}"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 9,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"max_steps is given, it will override any value given in num_train_epochs\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/html": [
|
||||||
|
"\n",
|
||||||
|
" <div>\n",
|
||||||
|
" \n",
|
||||||
|
" <progress value='10' max='10' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
|
||||||
|
" [10/10 07:32, Epoch 0/1]\n",
|
||||||
|
" </div>\n",
|
||||||
|
" <table border=\"1\" class=\"dataframe\">\n",
|
||||||
|
" <thead>\n",
|
||||||
|
" <tr style=\"text-align: left;\">\n",
|
||||||
|
" <th>Step</th>\n",
|
||||||
|
" <th>Training Loss</th>\n",
|
||||||
|
" <th>Validation Loss</th>\n",
|
||||||
|
" <th>Accuracy</th>\n",
|
||||||
|
" </tr>\n",
|
||||||
|
" </thead>\n",
|
||||||
|
" <tbody>\n",
|
||||||
|
" <tr>\n",
|
||||||
|
" <td>5</td>\n",
|
||||||
|
" <td>0.352300</td>\n",
|
||||||
|
" <td>1.110698</td>\n",
|
||||||
|
" <td>0.502083</td>\n",
|
||||||
|
" </tr>\n",
|
||||||
|
" <tr>\n",
|
||||||
|
" <td>10</td>\n",
|
||||||
|
" <td>0.861100</td>\n",
|
||||||
|
" <td>1.112852</td>\n",
|
||||||
|
" <td>0.501786</td>\n",
|
||||||
|
" </tr>\n",
|
||||||
|
" </tbody>\n",
|
||||||
|
"</table><p>"
|
||||||
|
],
|
||||||
|
"text/plain": [
|
||||||
|
"<IPython.core.display.HTML object>"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "display_data"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"TrainOutput(global_step=10, training_loss=0.6066924571990967, metrics={'train_runtime': 452.6593, 'train_samples_per_second': 0.044, 'train_steps_per_second': 0.022, 'total_flos': 6787820038656.0, 'train_loss': 0.6066924571990967, 'epoch': 0.001488095238095238})"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 9,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"# if using `Trainer`\n",
|
||||||
|
"trainer = Trainer(\n",
|
||||||
|
" model=lora_model,\n",
|
||||||
|
" args=TrainingArguments(\n",
|
||||||
|
" output_dir=\"../experiment/\",\n",
|
||||||
|
" learning_rate=2e-5,\n",
|
||||||
|
" num_train_epochs=1,\n",
|
||||||
|
" weight_decay=0.01,\n",
|
||||||
|
" logging_steps=5,\n",
|
||||||
|
" max_steps=10,\n",
|
||||||
|
" per_device_train_batch_size=BATCH_SIZE,\n",
|
||||||
|
" per_device_eval_batch_size=BATCH_SIZE,\n",
|
||||||
|
" evaluation_strategy=\"steps\",\n",
|
||||||
|
" save_strategy=\"steps\",\n",
|
||||||
|
" load_best_model_at_end=True,\n",
|
||||||
|
" report_to=\"none\"\n",
|
||||||
|
" ),\n",
|
||||||
|
" train_dataset=dataset[\"train\"],\n",
|
||||||
|
" eval_dataset=dataset[\"valid\"],\n",
|
||||||
|
" tokenizer=tokenizer,\n",
|
||||||
|
" data_collator=DataCollatorWithPadding(tokenizer=tokenizer),\n",
|
||||||
|
" compute_metrics=compute_metrics,\n",
|
||||||
|
")\n",
|
||||||
|
"\n",
|
||||||
|
"trainer.train()"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 10,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"max_steps is given, it will override any value given in num_train_epochs\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/html": [
|
||||||
|
"\n",
|
||||||
|
" <div>\n",
|
||||||
|
" \n",
|
||||||
|
" <progress value='10' max='10' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
|
||||||
|
" [10/10 07:32, Epoch 0/1]\n",
|
||||||
|
" </div>\n",
|
||||||
|
" <table border=\"1\" class=\"dataframe\">\n",
|
||||||
|
" <thead>\n",
|
||||||
|
" <tr style=\"text-align: left;\">\n",
|
||||||
|
" <th>Step</th>\n",
|
||||||
|
" <th>Training Loss</th>\n",
|
||||||
|
" <th>Validation Loss</th>\n",
|
||||||
|
" <th>Accuracy</th>\n",
|
||||||
|
" </tr>\n",
|
||||||
|
" </thead>\n",
|
||||||
|
" <tbody>\n",
|
||||||
|
" <tr>\n",
|
||||||
|
" <td>5</td>\n",
|
||||||
|
" <td>0.280700</td>\n",
|
||||||
|
" <td>1.126119</td>\n",
|
||||||
|
" <td>0.497024</td>\n",
|
||||||
|
" </tr>\n",
|
||||||
|
" <tr>\n",
|
||||||
|
" <td>10</td>\n",
|
||||||
|
" <td>0.803900</td>\n",
|
||||||
|
" <td>1.128193</td>\n",
|
||||||
|
" <td>0.497024</td>\n",
|
||||||
|
" </tr>\n",
|
||||||
|
" </tbody>\n",
|
||||||
|
"</table><p>"
|
||||||
|
],
|
||||||
|
"text/plain": [
|
||||||
|
"<IPython.core.display.HTML object>"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "display_data"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"TrainOutput(global_step=10, training_loss=0.542291009426117, metrics={'train_runtime': 452.9047, 'train_samples_per_second': 0.044, 'train_steps_per_second': 0.022, 'total_flos': 6787820038656.0, 'train_loss': 0.542291009426117, 'epoch': 0.001488095238095238})"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 10,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"# if using `SFTTrainer`\n",
|
||||||
|
"trainer = SFTTrainer(\n",
|
||||||
|
" model=lora_model,\n",
|
||||||
|
" args=TrainingArguments(\n",
|
||||||
|
" output_dir=\"../experiment/\",\n",
|
||||||
|
" learning_rate=2e-5,\n",
|
||||||
|
" num_train_epochs=1,\n",
|
||||||
|
" weight_decay=0.01,\n",
|
||||||
|
" logging_steps=5,\n",
|
||||||
|
" max_steps=10,\n",
|
||||||
|
" per_device_train_batch_size=BATCH_SIZE,\n",
|
||||||
|
" per_device_eval_batch_size=BATCH_SIZE,\n",
|
||||||
|
" evaluation_strategy=\"steps\",\n",
|
||||||
|
" save_strategy=\"steps\",\n",
|
||||||
|
" load_best_model_at_end=True,\n",
|
||||||
|
" report_to=\"none\"\n",
|
||||||
|
" ),\n",
|
||||||
|
" train_dataset=dataset[\"train\"],\n",
|
||||||
|
" eval_dataset=dataset[\"valid\"],\n",
|
||||||
|
" dataset_text_field=\"text\",\n",
|
||||||
|
" tokenizer=tokenizer,\n",
|
||||||
|
" data_collator=DataCollatorWithPadding(tokenizer=tokenizer),\n",
|
||||||
|
" compute_metrics=compute_metrics,\n",
|
||||||
|
")\n",
|
||||||
|
"\n",
|
||||||
|
"trainer.train()"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "languageM",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.10.14"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 2
|
||||||
|
}
|
|
@ -0,0 +1,16 @@
|
||||||
|
tqdm
|
||||||
|
numpy
|
||||||
|
pandas
|
||||||
|
ipykernel
|
||||||
|
ipywidgets
|
||||||
|
scikit-learn
|
||||||
|
trl
|
||||||
|
peft
|
||||||
|
torch
|
||||||
|
torchvision
|
||||||
|
torchaudio
|
||||||
|
datasets
|
||||||
|
accelerate
|
||||||
|
modelscope
|
||||||
|
bitsandbytes
|
||||||
|
transformers
|
|
@ -0,0 +1,17 @@
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
import torch
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
|
from random import randint
|
||||||
|
from datasets import load_dataset, DatasetDict
|
||||||
|
|
||||||
|
import transformers
|
||||||
|
from trl import SFTTrainer
|
||||||
|
from peft import LoraConfig, PeftModel, prepare_model_for_kbit_training, get_peft_model, TaskType
|
||||||
|
from transformers import set_seed, AutoTokenizer, AutoModelForCausalLM, AutoModelForSequenceClassification, pipeline, BitsAndBytesConfig, DataCollatorWithPadding, Trainer, TrainingArguments
|
||||||
|
|
||||||
|
import warnings
|
||||||
|
# Install a blanket "ignore" filter at import time: with no category or
# module argument this matches every warning, so importing this module
# silences warnings process-wide (including deprecation notices).
warnings.filterwarnings("ignore")
|
Loading…
Reference in New Issue