diff --git a/notebooks/classification-gemma.ipynb b/notebooks/classification-gemma.ipynb
index 9c39565..88a7fae 100644
--- a/notebooks/classification-gemma.ipynb
+++ b/notebooks/classification-gemma.ipynb
@@ -1,5 +1,14 @@
{
"cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Text Classification\n",
+ "\n",
+    "In this notebook, I will give you an example of fine-tuning an LLM for text classification."
+ ]
+ },
{
"cell_type": "code",
"execution_count": 1,
@@ -26,6 +35,12 @@
"\n",
"dataset = dataset.shuffle().train_test_split(0.2)\n",
"trainset, validset = dataset[\"train\"], dataset[\"test\"]\n",
+ "\n",
+    "# sampling, as this is just an example file\n",
+ "trainset = trainset.select(range(100))\n",
+ "validset = validset.select(range(100))\n",
+ "testset = testset.select(range(100))\n",
+ "\n",
"dataset = DatasetDict({\"train\": trainset, \"valid\": validset, \"test\": testset})"
]
},
@@ -35,7 +50,7 @@
"metadata": {},
"outputs": [],
"source": [
- "MODEL_ID = \"google/gemma-7b\"\n",
+ "MODEL_ID = \"google/gemma-2b\"\n",
"\n",
"SEED = 42\n",
"NUM_CLASSES = 2\n",
@@ -58,12 +73,12 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
- "model_id": "592e2aefe76641e3a6fc22affb1781c7",
+ "model_id": "283eecad62ac4c468217952044f683a9",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
- "Map: 0%| | 0/13440 [00:00, ? examples/s]"
+ "Map: 0%| | 0/100 [00:00, ? examples/s]"
]
},
"metadata": {},
@@ -72,12 +87,12 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
- "model_id": "c0227aeeb1194900a3eebd50230363bd",
+ "model_id": "ba35d6b0db1d492783b8b7fa41d9532f",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
- "Map: 0%| | 0/3360 [00:00, ? examples/s]"
+ "Map: 0%| | 0/100 [00:00, ? examples/s]"
]
},
"metadata": {},
@@ -108,7 +123,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
- "model_id": "1adbb0ec04254d7dadd0552b61449ba9",
+ "model_id": "0598ac059eb74a14aecaea95eda91ee7",
"version_major": 2,
"version_minor": 0
},
@@ -156,7 +171,7 @@
},
{
"cell_type": "code",
- "execution_count": 9,
+ "execution_count": 7,
"metadata": {},
"outputs": [
{
@@ -173,7 +188,7 @@
"
\n",
" \n",
"
\n",
- " [10/10 07:32, Epoch 0/1]\n",
+ " [10/10 00:17, Epoch 0/1]\n",
"
\n",
" \n",
" \n",
@@ -187,15 +202,15 @@
" \n",
" \n",
" 5 | \n",
- " 0.352300 | \n",
- " 1.110698 | \n",
- " 0.502083 | \n",
+ " 0.977000 | \n",
+ " 1.173903 | \n",
+ " 0.560000 | \n",
"
\n",
" \n",
" 10 | \n",
- " 0.861100 | \n",
- " 1.112852 | \n",
- " 0.501786 | \n",
+ " 1.450300 | \n",
+ " 1.155836 | \n",
+ " 0.560000 | \n",
"
\n",
" \n",
"
"
@@ -206,24 +221,14 @@
},
"metadata": {},
"output_type": "display_data"
- },
- {
- "data": {
- "text/plain": [
- "TrainOutput(global_step=10, training_loss=0.6066924571990967, metrics={'train_runtime': 452.6593, 'train_samples_per_second': 0.044, 'train_steps_per_second': 0.022, 'total_flos': 6787820038656.0, 'train_loss': 0.6066924571990967, 'epoch': 0.001488095238095238})"
- ]
- },
- "execution_count": 9,
- "metadata": {},
- "output_type": "execute_result"
}
],
"source": [
- "# if using `Trainer`\n",
+ "# if using `Trainer`: \n",
"trainer = Trainer(\n",
" model=lora_model,\n",
" args=TrainingArguments(\n",
- " output_dir=\"../experiment/\",\n",
+ " output_dir=\"../experiment/classification/Trainer/\",\n",
" learning_rate=2e-5,\n",
" num_train_epochs=1,\n",
" weight_decay=0.01,\n",
@@ -233,6 +238,7 @@
" per_device_eval_batch_size=BATCH_SIZE,\n",
" evaluation_strategy=\"steps\",\n",
" save_strategy=\"steps\",\n",
+ " save_steps=5,\n",
" load_best_model_at_end=True,\n",
" report_to=\"none\"\n",
" ),\n",
@@ -243,7 +249,106 @@
" compute_metrics=compute_metrics,\n",
")\n",
"\n",
- "trainer.train()"
+ "trainer.train()\n",
+ "trainer.save_state()\n",
+ "trainer.save_model(output_dir=\"../experiment/classification/Trainer/model/\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "max_steps is given, it will override any value given in num_train_epochs\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n",
+ " \n",
+ "
\n",
+ " [10/10 00:17, Epoch 0/1]\n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " Step | \n",
+ " Training Loss | \n",
+ " Validation Loss | \n",
+ " Accuracy | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 5 | \n",
+ " 0.827100 | \n",
+ " 1.106255 | \n",
+ " 0.550000 | \n",
+ "
\n",
+ " \n",
+ " 10 | \n",
+ " 1.360100 | \n",
+ " 1.089411 | \n",
+ " 0.540000 | \n",
+ "
\n",
+ " \n",
+ "
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "# if using `SFTTrainer`\n",
+ "trainer = SFTTrainer(\n",
+ " model=lora_model,\n",
+ " args=TrainingArguments(\n",
+ " output_dir=\"../experiment/classification/SFTTrainer/\",\n",
+ " learning_rate=2e-5,\n",
+ " num_train_epochs=1,\n",
+ " weight_decay=0.01,\n",
+ " logging_steps=5,\n",
+ " max_steps=10,\n",
+ " per_device_train_batch_size=BATCH_SIZE,\n",
+ " per_device_eval_batch_size=BATCH_SIZE,\n",
+ " evaluation_strategy=\"steps\",\n",
+ " save_strategy=\"steps\",\n",
+ " save_steps=5,\n",
+ " load_best_model_at_end=True,\n",
+ " report_to=\"none\"\n",
+ " ),\n",
+ " train_dataset=dataset[\"train\"],\n",
+ " eval_dataset=dataset[\"valid\"],\n",
+    "    dataset_text_field=\"text\", # main difference between SFTTrainer and Trainer\n",
+ " tokenizer=tokenizer,\n",
+ " data_collator=DataCollatorWithPadding(tokenizer=tokenizer),\n",
+ " compute_metrics=compute_metrics,\n",
+ ")\n",
+ "\n",
+ "trainer.train()\n",
+ "trainer.save_state()\n",
+ "trainer.save_model(output_dir=\"../experiment/classification/SFTTrainer/model/\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "del model\n",
+ "del trainer\n",
+ "torch.cuda.empty_cache()"
]
},
{
@@ -251,93 +356,82 @@
"execution_count": 10,
"metadata": {},
"outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "max_steps is given, it will override any value given in num_train_epochs\n"
- ]
- },
{
"data": {
- "text/html": [
- "\n",
- " \n",
- " \n",
- "
\n",
- " [10/10 07:32, Epoch 0/1]\n",
- "
\n",
- " \n",
- " \n",
- " \n",
- " Step | \n",
- " Training Loss | \n",
- " Validation Loss | \n",
- " Accuracy | \n",
- "
\n",
- " \n",
- " \n",
- " \n",
- " 5 | \n",
- " 0.280700 | \n",
- " 1.126119 | \n",
- " 0.497024 | \n",
- "
\n",
- " \n",
- " 10 | \n",
- " 0.803900 | \n",
- " 1.128193 | \n",
- " 0.497024 | \n",
- "
\n",
- " \n",
- "
"
- ],
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "3a6911922f574bc29d938985bb1cf2f5",
+ "version_major": 2,
+ "version_minor": 0
+ },
"text/plain": [
- ""
+ "Loading checkpoint shards: 0%| | 0/2 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
- "data": {
- "text/plain": [
- "TrainOutput(global_step=10, training_loss=0.542291009426117, metrics={'train_runtime': 452.9047, 'train_samples_per_second': 0.044, 'train_steps_per_second': 0.022, 'total_flos': 6787820038656.0, 'train_loss': 0.542291009426117, 'epoch': 0.001488095238095238})"
- ]
- },
- "execution_count": 10,
- "metadata": {},
- "output_type": "execute_result"
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Some weights of GemmaForSequenceClassification were not initialized from the model checkpoint at google/gemma-2b and are newly initialized: ['score.weight']\n",
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
+ "The model 'PeftModelForSequenceClassification' is not supported for text-classification. Supported models are ['AlbertForSequenceClassification', 'BartForSequenceClassification', 'BertForSequenceClassification', 'BigBirdForSequenceClassification', 'BigBirdPegasusForSequenceClassification', 'BioGptForSequenceClassification', 'BloomForSequenceClassification', 'CamembertForSequenceClassification', 'CanineForSequenceClassification', 'LlamaForSequenceClassification', 'ConvBertForSequenceClassification', 'CTRLForSequenceClassification', 'Data2VecTextForSequenceClassification', 'DebertaForSequenceClassification', 'DebertaV2ForSequenceClassification', 'DistilBertForSequenceClassification', 'ElectraForSequenceClassification', 'ErnieForSequenceClassification', 'ErnieMForSequenceClassification', 'EsmForSequenceClassification', 'FalconForSequenceClassification', 'FlaubertForSequenceClassification', 'FNetForSequenceClassification', 'FunnelForSequenceClassification', 'GemmaForSequenceClassification', 'GPT2ForSequenceClassification', 'GPT2ForSequenceClassification', 'GPTBigCodeForSequenceClassification', 'GPTNeoForSequenceClassification', 'GPTNeoXForSequenceClassification', 'GPTJForSequenceClassification', 'IBertForSequenceClassification', 'JambaForSequenceClassification', 'LayoutLMForSequenceClassification', 'LayoutLMv2ForSequenceClassification', 'LayoutLMv3ForSequenceClassification', 'LEDForSequenceClassification', 'LiltForSequenceClassification', 'LlamaForSequenceClassification', 'LongformerForSequenceClassification', 'LukeForSequenceClassification', 'MarkupLMForSequenceClassification', 'MBartForSequenceClassification', 'MegaForSequenceClassification', 'MegatronBertForSequenceClassification', 'MistralForSequenceClassification', 'MixtralForSequenceClassification', 'MobileBertForSequenceClassification', 'MPNetForSequenceClassification', 'MptForSequenceClassification', 'MraForSequenceClassification', 'MT5ForSequenceClassification', 'MvpForSequenceClassification', 
'NezhaForSequenceClassification', 'NystromformerForSequenceClassification', 'OpenLlamaForSequenceClassification', 'OpenAIGPTForSequenceClassification', 'OPTForSequenceClassification', 'PerceiverForSequenceClassification', 'PersimmonForSequenceClassification', 'PhiForSequenceClassification', 'PLBartForSequenceClassification', 'QDQBertForSequenceClassification', 'Qwen2ForSequenceClassification', 'Qwen2MoeForSequenceClassification', 'ReformerForSequenceClassification', 'RemBertForSequenceClassification', 'RobertaForSequenceClassification', 'RobertaPreLayerNormForSequenceClassification', 'RoCBertForSequenceClassification', 'RoFormerForSequenceClassification', 'SqueezeBertForSequenceClassification', 'StableLmForSequenceClassification', 'Starcoder2ForSequenceClassification', 'T5ForSequenceClassification', 'TapasForSequenceClassification', 'TransfoXLForSequenceClassification', 'UMT5ForSequenceClassification', 'XLMForSequenceClassification', 'XLMRobertaForSequenceClassification', 'XLMRobertaXLForSequenceClassification', 'XLNetForSequenceClassification', 'XmodForSequenceClassification', 'YosoForSequenceClassification'].\n",
+ "100%|██████████| 100/100 [02:30<00:00, 1.50s/it]\n"
+ ]
}
],
"source": [
- "# if using `SFTTrainer`\n",
- "trainer = SFTTrainer(\n",
- " model=lora_model,\n",
- " args=TrainingArguments(\n",
- " output_dir=\"../experiment/\",\n",
- " learning_rate=2e-5,\n",
- " num_train_epochs=1,\n",
- " weight_decay=0.01,\n",
- " logging_steps=5,\n",
- " max_steps=10,\n",
- " per_device_train_batch_size=BATCH_SIZE,\n",
- " per_device_eval_batch_size=BATCH_SIZE,\n",
- " evaluation_strategy=\"steps\",\n",
- " save_strategy=\"steps\",\n",
- " load_best_model_at_end=True,\n",
- " report_to=\"none\"\n",
- " ),\n",
- " train_dataset=dataset[\"train\"],\n",
- " eval_dataset=dataset[\"valid\"],\n",
- " dataset_text_field=\"text\",\n",
- " tokenizer=tokenizer,\n",
- " data_collator=DataCollatorWithPadding(tokenizer=tokenizer),\n",
- " compute_metrics=compute_metrics,\n",
- ")\n",
+ "## load the tuned model and use it in testset\n",
+ "base_model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID, num_labels=NUM_CLASSES)\n",
+ "merged_model = PeftModel.from_pretrained(base_model, \"../experiment/classification/Trainer/model/\")\n",
+ "trained_model = merged_model.merge_and_unload()\n",
"\n",
- "trainer.train()"
+ "trained_tokenizer = AutoTokenizer.from_pretrained(\"../experiment/classification/Trainer/model/\")\n",
+ "\n",
+    "cls_pipeline = pipeline(\"text-classification\", model=trained_model, tokenizer=trained_tokenizer)\n",
+ "\n",
+ "predictions = []\n",
+ "for text in tqdm(dataset[\"test\"][\"text\"]):\n",
+ " prediction = cls_pipeline(text)\n",
+ " prediction = int(prediction[0]['label'].split('_')[1])\n",
+ " predictions.append(prediction)\n",
+ " "
]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ " precision recall f1-score support\n",
+ "\n",
+ " 0 0.58 0.91 0.71 56\n",
+ " 1 0.58 0.16 0.25 44\n",
+ "\n",
+ " accuracy 0.58 100\n",
+ " macro avg 0.58 0.53 0.48 100\n",
+ "weighted avg 0.58 0.58 0.51 100\n",
+ "\n"
+ ]
+ }
+ ],
+ "source": [
+ "from sklearn.metrics import classification_report\n",
+ "\n",
+    "print(classification_report(pd.DataFrame(testset)[\"label\"], predictions))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
}
],
"metadata": {
diff --git a/src/utils.py b/src/utils.py
index b182b13..6e88fe1 100644
--- a/src/utils.py
+++ b/src/utils.py
@@ -5,6 +5,7 @@ import torch
import numpy as np
import pandas as pd
+from tqdm import tqdm
from random import randint
from datasets import load_dataset, DatasetDict