diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..e40c27b --- /dev/null +++ b/.gitignore @@ -0,0 +1,199 @@ +# Created by https://www.toptal.com/developers/gitignore/api/visualstudiocode,python +# Edit at https://www.toptal.com/developers/gitignore?templates=visualstudiocode,python + +### others ### +data/ +experiment/ + +### Python ### +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ + +### Python Patch ### +# Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration +poetry.toml + +# ruff +.ruff_cache/ + +# LSP config files +pyrightconfig.json + +### VisualStudioCode ### +.vscode/* +!.vscode/settings.json +!.vscode/tasks.json +!.vscode/launch.json +!.vscode/extensions.json +!.vscode/*.code-snippets + +# Local History for Visual Studio Code +.history/ + +# Built Visual Studio Code Extensions +*.vsix + +### VisualStudioCode Patch ### +# Ignore all local history of files +.history +.ionide + +# End of https://www.toptal.com/developers/gitignore/api/visualstudiocode,python \ No newline at end of file diff --git a/README.md b/README.md index a57a2df..939ef10 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,12 @@ # languageM -*Learn how to fit LLM(Large Language Model)* +*Learn how to fit LLMs (Large Language Models)* ## Base -* Model: [Google/Gemma](https://blog.google/technology/developers/gemma-open-models/) and [THUDM/GLM](https://github.com/THUDM/GLM) \ No newline at end of file +* Models: [Google/Gemma](https://blog.google/technology/developers/gemma-open-models/) and [THUDM/GLM](https://github.com/THUDM/GLM) +* Packages: [transformers](https://huggingface.co/docs/transformers/index), [Datasets](https://huggingface.co/docs/datasets/index), [trl](https://huggingface.co/docs/trl/index) + +## HOW-TO + +1. Preparation: `python -m pip install -r requirements.txt` \ No newline at end of file diff --git a/notebooks/classification-gemma.ipynb b/notebooks/classification-gemma.ipynb new file mode 100644 index 0000000..9c39565 --- /dev/null +++ b/notebooks/classification-gemma.ipynb @@ -0,0 +1,364 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import sys\n", + "sys.path.append(\"../src\")\n", + "from utils import *" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "'''\n", + "data source: https://www.kaggle.com/competitions/ml-olympiad-toxic-language-ptbr-detection\n", + "'''\n", + "\n", + "dataset = load_dataset(\"csv\", data_files=\"../data/PTBR/train.csv\")[\"train\"]\n", + "testset = load_dataset(\"csv\", data_files=\"../data/PTBR/test.csv\")[\"train\"]\n", + "\n", + "dataset = dataset.shuffle().train_test_split(0.2)\n", + "trainset, validset = dataset[\"train\"], dataset[\"test\"]\n", + "dataset = DatasetDict({\"train\": trainset, \"valid\": validset, \"test\": testset})" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "MODEL_ID = \"google/gemma-7b\"\n", + "\n", + "SEED = 42\n", + "NUM_CLASSES = 2\n", + "\n", + "EPOCHS = 2\n", + "\n", + "LORA_R = 8\n", + "LORA_ALPHA = 16\n", + "LORA_DROPOUT = 0.1\n", + "BATCH_SIZE = 2\n", + "\n", + "set_seed(SEED)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "592e2aefe76641e3a6fc22affb1781c7", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Map: 0%| | 0/13440 [00:00, ? examples/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "c0227aeeb1194900a3eebd50230363bd", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Map: 0%| | 0/3360 [00:00, ? examples/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)\n", + "\n", + "dataset[\"test\"] = dataset[\"test\"].map(lambda samples: tokenizer(samples[\"text\"], max_length=512, truncation=True), batched=True)\n", + "dataset[\"train\"] = dataset[\"train\"].map(lambda samples: tokenizer(samples[\"text\"], max_length=512, truncation=True), batched=True)\n", + "dataset[\"valid\"] = dataset[\"valid\"].map(lambda samples: tokenizer(samples[\"text\"], max_length=512, truncation=True), batched=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Gemma's activation function should be approximate GeLU and not exact GeLU.\n", + "Changing the activation function to `gelu_pytorch_tanh`.if you want to use the legacy `gelu`, edit the `model.config` to set `hidden_activation=gelu` instead of `hidden_act`. See https://github.com/huggingface/transformers/pull/29402 for more details.\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "1adbb0ec04254d7dadd0552b61449ba9", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Loading checkpoint shards: 0%| | 0/2 [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Some weights of GemmaForSequenceClassification were not initialized from the model checkpoint at google/gemma-2b and are newly initialized: ['score.weight']\n", + "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "trainable params: 9,826,320 || all params: 2,516,019,232 || trainable%: 0.3905502738223902\n" + ] + } + ], + "source": [ + "lora_config = LoraConfig(r = LORA_R, lora_alpha = LORA_ALPHA, lora_dropout = LORA_DROPOUT, task_type = TaskType.SEQ_CLS, target_modules = \"all-linear\")\n", + "model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID, device_map=\"auto\", num_labels=NUM_CLASSES)\n", + "lora_model = get_peft_model(model, lora_config)\n", + "lora_model.print_trainable_parameters()" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "def compute_metrics(eval_pred):\n", + " predictions, labels = eval_pred\n", + " predictions = np.argmax(predictions, axis=1)\n", + " return {\"accuracy\": (predictions == labels).mean()}" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "max_steps is given, it will override any value given in num_train_epochs\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + "
Step | \n", + "Training Loss | \n", + "Validation Loss | \n", + "Accuracy | \n", + "
---|---|---|---|
5 | \n", + "0.352300 | \n", + "1.110698 | \n", + "0.502083 | \n", + "
10 | \n", + "0.861100 | \n", + "1.112852 | \n", + "0.501786 | \n", + "
"
+ ],
+ "text/plain": [
+ " "
+ ],
+ "text/plain": [
+ "\n",
+ " \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " Step \n",
+ " Training Loss \n",
+ " Validation Loss \n",
+ " Accuracy \n",
+ " \n",
+ " \n",
+ " 5 \n",
+ " 0.280700 \n",
+ " 1.126119 \n",
+ " 0.497024 \n",
+ " \n",
+ " \n",
+ " \n",
+ "10 \n",
+ " 0.803900 \n",
+ " 1.128193 \n",
+ " 0.497024 \n",
+ "