add classification

TerenceLiu98 2024-04-24 23:30:18 +00:00
parent 91e070e623
commit 92d5bda932
5 changed files with 603 additions and 2 deletions

.gitignore (vendored, new file)

@@ -0,0 +1,199 @@
# Created by https://www.toptal.com/developers/gitignore/api/visualstudiocode,python
# Edit at https://www.toptal.com/developers/gitignore?templates=visualstudiocode,python
### others ###
data/
experiment/
### Python ###
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
### Python Patch ###
# Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration
poetry.toml
# ruff
.ruff_cache/
# LSP config files
pyrightconfig.json
### VisualStudioCode ###
.vscode/*
!.vscode/settings.json
!.vscode/tasks.json
!.vscode/launch.json
!.vscode/extensions.json
!.vscode/*.code-snippets
# Local History for Visual Studio Code
.history/
# Built Visual Studio Code Extensions
*.vsix
### VisualStudioCode Patch ###
# Ignore all local history of files
.history
.ionide
# End of https://www.toptal.com/developers/gitignore/api/visualstudiocode,python

README.md

@@ -1,7 +1,12 @@
# languageM
-*Learn how to fit LLM(Large Language Model)*
+*Learn how to fit LLMs (Large Language Models)*
## Base
-* Model: [Google/Gemma](https://blog.google/technology/developers/gemma-open-models/) and [THUDM/GLM](https://github.com/THUDM/GLM)
+* Models: [Google/Gemma](https://blog.google/technology/developers/gemma-open-models/) and [THUDM/GLM](https://github.com/THUDM/GLM)
* Packages: [transformers](https://huggingface.co/docs/transformers/index), [Datasets](https://huggingface.co/docs/datasets/index), [trl](https://huggingface.co/docs/trl/index)
## HOW-TO
1. Preparation: `python -m pip install -r requirements.txt`
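
The classification notebook in this commit reads `data/PTBR/train.csv` and `data/PTBR/test.csv` (the Kaggle ML Olympiad toxic-language PTBR competition data) and tokenizes a `text` column. A minimal layout check before running it; the paths, the `text` column, and the test split being unlabelled are assumptions read off the notebook, not a fixed interface:

```python
# sanity-check sketch: confirm the CSV layout the notebook expects
# (paths and column names are assumptions taken from the notebook)
import pandas as pd

train = pd.read_csv("data/PTBR/train.csv")
test = pd.read_csv("data/PTBR/test.csv")
print(train.columns.tolist(), len(train))
print(test.columns.tolist(), len(test))
assert "text" in train.columns, "the notebook tokenizes samples['text']"
```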

Jupyter notebook (new file)

@@ -0,0 +1,364 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"sys.path.append(\"../src\")\n",
"from utils import *"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"'''\n",
"data source: https://www.kaggle.com/competitions/ml-olympiad-toxic-language-ptbr-detection\n",
"'''\n",
"\n",
"dataset = load_dataset(\"csv\", data_files=\"../data/PTBR/train.csv\")[\"train\"]\n",
"testset = load_dataset(\"csv\", data_files=\"../data/PTBR/test.csv\")[\"train\"]\n",
"\n",
"dataset = dataset.shuffle().train_test_split(0.2)\n",
"trainset, validset = dataset[\"train\"], dataset[\"test\"]\n",
"dataset = DatasetDict({\"train\": trainset, \"valid\": validset, \"test\": testset})"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"MODEL_ID = \"google/gemma-7b\"\n",
"\n",
"SEED = 42\n",
"NUM_CLASSES = 2\n",
"\n",
"EPOCHS = 2\n",
"\n",
"LORA_R = 8\n",
"LORA_ALPHA = 16\n",
"LORA_DROPOUT = 0.1\n",
"BATCH_SIZE = 2\n",
"\n",
"set_seed(SEED)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "592e2aefe76641e3a6fc22affb1781c7",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Map: 0%| | 0/13440 [00:00<?, ? examples/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "c0227aeeb1194900a3eebd50230363bd",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Map: 0%| | 0/3360 [00:00<?, ? examples/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)\n",
"\n",
"dataset[\"test\"] = dataset[\"test\"].map(lambda samples: tokenizer(samples[\"text\"], max_length=512, truncation=True), batched=True)\n",
"dataset[\"train\"] = dataset[\"train\"].map(lambda samples: tokenizer(samples[\"text\"], max_length=512, truncation=True), batched=True)\n",
"dataset[\"valid\"] = dataset[\"valid\"].map(lambda samples: tokenizer(samples[\"text\"], max_length=512, truncation=True), batched=True)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Gemma's activation function should be approximate GeLU and not exact GeLU.\n",
"Changing the activation function to `gelu_pytorch_tanh`.if you want to use the legacy `gelu`, edit the `model.config` to set `hidden_activation=gelu` instead of `hidden_act`. See https://github.com/huggingface/transformers/pull/29402 for more details.\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "1adbb0ec04254d7dadd0552b61449ba9",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Loading checkpoint shards: 0%| | 0/2 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Some weights of GemmaForSequenceClassification were not initialized from the model checkpoint at google/gemma-2b and are newly initialized: ['score.weight']\n",
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"trainable params: 9,826,320 || all params: 2,516,019,232 || trainable%: 0.3905502738223902\n"
]
}
],
"source": [
"lora_config = LoraConfig(r = LORA_R, lora_alpha = LORA_ALPHA, lora_dropout = LORA_DROPOUT, task_type = TaskType.SEQ_CLS, target_modules = \"all-linear\")\n",
"model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID, device_map=\"auto\", num_labels=NUM_CLASSES)\n",
"lora_model = get_peft_model(model, lora_config)\n",
"lora_model.print_trainable_parameters()"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"def compute_metrics(eval_pred):\n",
" predictions, labels = eval_pred\n",
" predictions = np.argmax(predictions, axis=1)\n",
" return {\"accuracy\": (predictions == labels).mean()}"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"max_steps is given, it will override any value given in num_train_epochs\n"
]
},
{
"data": {
"text/html": [
"\n",
" <div>\n",
" \n",
" <progress value='10' max='10' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
" [10/10 07:32, Epoch 0/1]\n",
" </div>\n",
" <table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>Step</th>\n",
" <th>Training Loss</th>\n",
" <th>Validation Loss</th>\n",
" <th>Accuracy</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>5</td>\n",
" <td>0.352300</td>\n",
" <td>1.110698</td>\n",
" <td>0.502083</td>\n",
" </tr>\n",
" <tr>\n",
" <td>10</td>\n",
" <td>0.861100</td>\n",
" <td>1.112852</td>\n",
" <td>0.501786</td>\n",
" </tr>\n",
" </tbody>\n",
"</table><p>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"TrainOutput(global_step=10, training_loss=0.6066924571990967, metrics={'train_runtime': 452.6593, 'train_samples_per_second': 0.044, 'train_steps_per_second': 0.022, 'total_flos': 6787820038656.0, 'train_loss': 0.6066924571990967, 'epoch': 0.001488095238095238})"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# if using `Trainer`\n",
"trainer = Trainer(\n",
" model=lora_model,\n",
" args=TrainingArguments(\n",
" output_dir=\"../experiment/\",\n",
" learning_rate=2e-5,\n",
" num_train_epochs=1,\n",
" weight_decay=0.01,\n",
" logging_steps=5,\n",
" max_steps=10,\n",
" per_device_train_batch_size=BATCH_SIZE,\n",
" per_device_eval_batch_size=BATCH_SIZE,\n",
" evaluation_strategy=\"steps\",\n",
" save_strategy=\"steps\",\n",
" load_best_model_at_end=True,\n",
" report_to=\"none\"\n",
" ),\n",
" train_dataset=dataset[\"train\"],\n",
" eval_dataset=dataset[\"valid\"],\n",
" tokenizer=tokenizer,\n",
" data_collator=DataCollatorWithPadding(tokenizer=tokenizer),\n",
" compute_metrics=compute_metrics,\n",
")\n",
"\n",
"trainer.train()"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"max_steps is given, it will override any value given in num_train_epochs\n"
]
},
{
"data": {
"text/html": [
"\n",
" <div>\n",
" \n",
" <progress value='10' max='10' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
" [10/10 07:32, Epoch 0/1]\n",
" </div>\n",
" <table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>Step</th>\n",
" <th>Training Loss</th>\n",
" <th>Validation Loss</th>\n",
" <th>Accuracy</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>5</td>\n",
" <td>0.280700</td>\n",
" <td>1.126119</td>\n",
" <td>0.497024</td>\n",
" </tr>\n",
" <tr>\n",
" <td>10</td>\n",
" <td>0.803900</td>\n",
" <td>1.128193</td>\n",
" <td>0.497024</td>\n",
" </tr>\n",
" </tbody>\n",
"</table><p>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"TrainOutput(global_step=10, training_loss=0.542291009426117, metrics={'train_runtime': 452.9047, 'train_samples_per_second': 0.044, 'train_steps_per_second': 0.022, 'total_flos': 6787820038656.0, 'train_loss': 0.542291009426117, 'epoch': 0.001488095238095238})"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# if using `SFTTrainer`\n",
"trainer = SFTTrainer(\n",
" model=lora_model,\n",
" args=TrainingArguments(\n",
" output_dir=\"../experiment/\",\n",
" learning_rate=2e-5,\n",
" num_train_epochs=1,\n",
" weight_decay=0.01,\n",
" logging_steps=5,\n",
" max_steps=10,\n",
" per_device_train_batch_size=BATCH_SIZE,\n",
" per_device_eval_batch_size=BATCH_SIZE,\n",
" evaluation_strategy=\"steps\",\n",
" save_strategy=\"steps\",\n",
" load_best_model_at_end=True,\n",
" report_to=\"none\"\n",
" ),\n",
" train_dataset=dataset[\"train\"],\n",
" eval_dataset=dataset[\"valid\"],\n",
" dataset_text_field=\"text\",\n",
" tokenizer=tokenizer,\n",
" data_collator=DataCollatorWithPadding(tokenizer=tokenizer),\n",
" compute_metrics=compute_metrics,\n",
")\n",
"\n",
"trainer.train()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "languageM",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.14"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

requirements.txt (new file)

@@ -0,0 +1,16 @@
tqdm
numpy
pandas
ipykernel
ipywidgets
scikit-learn
trl
peft
torch
torchvision
torchaudio
datasets
accelerate
modelscope
bitsandbytes
transformers

src/utils.py (new file)

@@ -0,0 +1,17 @@
import os
import json
import torch
import numpy as np
import pandas as pd
from random import randint
from datasets import load_dataset, DatasetDict
import transformers
from trl import SFTTrainer
from peft import LoraConfig, PeftModel, prepare_model_for_kbit_training, get_peft_model, TaskType
from transformers import set_seed, AutoTokenizer, AutoModelForCausalLM, AutoModelForSequenceClassification, pipeline, BitsAndBytesConfig, DataCollatorWithPadding, Trainer, TrainingArguments
import warnings
warnings.filterwarnings("ignore")
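
The notebook stops after `trainer.train()` and never scores the unlabelled test split it loads. A hedged sketch of doing so with the fine-tuned adapter, reusing the notebook's `lora_model`, `tokenizer`, and `dataset` objects; the batch size and device handling here are assumptions, not part of the commit:

```python
# sketch: run the LoRA-adapted classifier over the tokenized test split
# assumes the notebook's lora_model, tokenizer, and dataset objects are still in scope
import torch
from torch.utils.data import DataLoader
from transformers import DataCollatorWithPadding

collator = DataCollatorWithPadding(tokenizer=tokenizer)

# keep only the tensor columns the model consumes
keep = ("input_ids", "attention_mask")
test_split = dataset["test"].remove_columns(
    [c for c in dataset["test"].column_names if c not in keep]
)
loader = DataLoader(test_split, batch_size=8, collate_fn=collator)

lora_model.eval()
predictions = []
with torch.no_grad():
    for batch in loader:
        batch = {k: v.to(lora_model.device) for k, v in batch.items()}
        logits = lora_model(**batch).logits
        predictions.extend(logits.argmax(dim=-1).cpu().tolist())

print(predictions[:10])
```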