From 92d5bda9329d9d2af20b09d8219791ba03e4f6bf Mon Sep 17 00:00:00 2001 From: TerenceLiu98 Date: Wed, 24 Apr 2024 23:30:18 +0000 Subject: [PATCH] add classification --- .gitignore | 199 +++++++++++++++ README.md | 9 +- notebooks/classification-gemma.ipynb | 364 +++++++++++++++++++++++++++ requirements.txt | 16 ++ src/utils.py | 17 ++ 5 files changed, 603 insertions(+), 2 deletions(-) create mode 100644 .gitignore create mode 100644 notebooks/classification-gemma.ipynb create mode 100644 requirements.txt create mode 100644 src/utils.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..e40c27b --- /dev/null +++ b/.gitignore @@ -0,0 +1,199 @@ +# Created by https://www.toptal.com/developers/gitignore/api/visualstudiocode,python +# Edit at https://www.toptal.com/developers/gitignore?templates=visualstudiocode,python + +### others ### +data/ +experiment/ + +### Python ### +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ + +### Python Patch ### +# Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration +poetry.toml + +# ruff +.ruff_cache/ + +# LSP config files +pyrightconfig.json + +### VisualStudioCode ### +.vscode/* +!.vscode/settings.json +!.vscode/tasks.json +!.vscode/launch.json +!.vscode/extensions.json +!.vscode/*.code-snippets + +# Local History for Visual Studio Code +.history/ + +# Built Visual Studio Code Extensions +*.vsix + +### VisualStudioCode Patch ### +# Ignore all local history of files +.history +.ionide + +# End of https://www.toptal.com/developers/gitignore/api/visualstudiocode,python \ No newline at end of file diff --git a/README.md b/README.md index a57a2df..939ef10 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,12 @@ # languageM -*Learn how to fit LLM(Large Language Model)* +*Learn how to fit LLMs (Large Language Models)* ## Base -* Model: [Google/Gemma](https://blog.google/technology/developers/gemma-open-models/) and [THUDM/GLM](https://github.com/THUDM/GLM) \ No newline at end of file +* Models: [Google/Gemma](https://blog.google/technology/developers/gemma-open-models/) and [THUDM/GLM](https://github.com/THUDM/GLM) +* Packages: [transformers](https://huggingface.co/docs/transformers/index), [Datasets](https://huggingface.co/docs/datasets/index), [trl](https://huggingface.co/docs/trl/index) + +## HOW-TO + +1. 
Preparation: `python -m pip install -r requirements.txt` \ No newline at end of file diff --git a/notebooks/classification-gemma.ipynb b/notebooks/classification-gemma.ipynb new file mode 100644 index 0000000..9c39565 --- /dev/null +++ b/notebooks/classification-gemma.ipynb @@ -0,0 +1,364 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import sys\n", + "sys.path.append(\"../src\")\n", + "from utils import *" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "'''\n", + "data source: https://www.kaggle.com/competitions/ml-olympiad-toxic-language-ptbr-detection\n", + "'''\n", + "\n", + "dataset = load_dataset(\"csv\", data_files=\"../data/PTBR/train.csv\")[\"train\"]\n", + "testset = load_dataset(\"csv\", data_files=\"../data/PTBR/test.csv\")[\"train\"]\n", + "\n", + "dataset = dataset.shuffle().train_test_split(0.2)\n", + "trainset, validset = dataset[\"train\"], dataset[\"test\"]\n", + "dataset = DatasetDict({\"train\": trainset, \"valid\": validset, \"test\": testset})" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "MODEL_ID = \"google/gemma-7b\"\n", + "\n", + "SEED = 42\n", + "NUM_CLASSES = 2\n", + "\n", + "EPOCHS = 2\n", + "\n", + "LORA_R = 8\n", + "LORA_ALPHA = 16\n", + "LORA_DROPOUT = 0.1\n", + "BATCH_SIZE = 2\n", + "\n", + "set_seed(SEED)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "592e2aefe76641e3a6fc22affb1781c7", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Map: 0%| | 0/13440 [00:00\n", + " \n", + " \n", + " [10/10 07:32, Epoch 0/1]\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
+       "      <tr><th>Step</th><th>Training Loss</th><th>Validation Loss</th><th>Accuracy</th></tr>\n",
+       "      <tr><td>5</td><td>0.352300</td><td>1.110698</td><td>0.502083</td></tr>\n",
+       "      <tr><td>10</td><td>0.861100</td><td>1.112852</td><td>0.501786</td></tr>\n",
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "TrainOutput(global_step=10, training_loss=0.6066924571990967, metrics={'train_runtime': 452.6593, 'train_samples_per_second': 0.044, 'train_steps_per_second': 0.022, 'total_flos': 6787820038656.0, 'train_loss': 0.6066924571990967, 'epoch': 0.001488095238095238})" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# if using `Trainer`\n", + "trainer = Trainer(\n", + " model=lora_model,\n", + " args=TrainingArguments(\n", + " output_dir=\"../experiment/\",\n", + " learning_rate=2e-5,\n", + " num_train_epochs=1,\n", + " weight_decay=0.01,\n", + " logging_steps=5,\n", + " max_steps=10,\n", + " per_device_train_batch_size=BATCH_SIZE,\n", + " per_device_eval_batch_size=BATCH_SIZE,\n", + " evaluation_strategy=\"steps\",\n", + " save_strategy=\"steps\",\n", + " load_best_model_at_end=True,\n", + " report_to=\"none\"\n", + " ),\n", + " train_dataset=dataset[\"train\"],\n", + " eval_dataset=dataset[\"valid\"],\n", + " tokenizer=tokenizer,\n", + " data_collator=DataCollatorWithPadding(tokenizer=tokenizer),\n", + " compute_metrics=compute_metrics,\n", + ")\n", + "\n", + "trainer.train()" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "max_steps is given, it will override any value given in num_train_epochs\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + "

+       "      [10/10 07:32, Epoch 0/1]\n",
+       "      <tr><th>Step</th><th>Training Loss</th><th>Validation Loss</th><th>Accuracy</th></tr>\n",
+       "      <tr><td>5</td><td>0.280700</td><td>1.126119</td><td>0.497024</td></tr>\n",
+       "      <tr><td>10</td><td>0.803900</td><td>1.128193</td><td>0.497024</td></tr>\n",
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "TrainOutput(global_step=10, training_loss=0.542291009426117, metrics={'train_runtime': 452.9047, 'train_samples_per_second': 0.044, 'train_steps_per_second': 0.022, 'total_flos': 6787820038656.0, 'train_loss': 0.542291009426117, 'epoch': 0.001488095238095238})" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# if using `SFTTrainer`\n", + "trainer = SFTTrainer(\n", + " model=lora_model,\n", + " args=TrainingArguments(\n", + " output_dir=\"../experiment/\",\n", + " learning_rate=2e-5,\n", + " num_train_epochs=1,\n", + " weight_decay=0.01,\n", + " logging_steps=5,\n", + " max_steps=10,\n", + " per_device_train_batch_size=BATCH_SIZE,\n", + " per_device_eval_batch_size=BATCH_SIZE,\n", + " evaluation_strategy=\"steps\",\n", + " save_strategy=\"steps\",\n", + " load_best_model_at_end=True,\n", + " report_to=\"none\"\n", + " ),\n", + " train_dataset=dataset[\"train\"],\n", + " eval_dataset=dataset[\"valid\"],\n", + " dataset_text_field=\"text\",\n", + " tokenizer=tokenizer,\n", + " data_collator=DataCollatorWithPadding(tokenizer=tokenizer),\n", + " compute_metrics=compute_metrics,\n", + ")\n", + "\n", + "trainer.train()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "languageM", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.14" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..94bdb32 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,16 @@ +tqdm +numpy +pandas +ipykernel +ipywidgets +scikit-learn +trl +peft +torch +torchvision +torchaudio +datasets +accelerate +modelscope +bitsandbytes +transformers \ No newline at end of file diff --git a/src/utils.py b/src/utils.py new file mode 100644 index 0000000..b182b13 --- /dev/null +++ b/src/utils.py @@ -0,0 +1,17 @@ +import os +import json +import torch + +import numpy as np +import pandas as pd + +from random import randint +from datasets import load_dataset, DatasetDict + +import transformers +from trl import SFTTrainer +from peft import LoraConfig, PeftModel, prepare_model_for_kbit_training, get_peft_model, TaskType +from transformers import set_seed, AutoTokenizer, AutoModelForCausalLM, AutoModelForSequenceClassification, pipeline, BitsAndBytesConfig, DataCollatorWithPadding, Trainer, TrainingArguments + +import warnings +warnings.filterwarnings("ignore") \ No newline at end of file