This commit is contained in:
TerenceLiu98 2024-04-25 01:02:04 +00:00
parent 92d5bda932
commit a663caaa4c
2 changed files with 195 additions and 100 deletions

View File

@ -1,5 +1,14 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Text Classification\n",
"\n",
"In this notebook, I will give you an example of fine-tuning an LLM for text classification."
]
},
{
"cell_type": "code",
"execution_count": 1,
@ -26,6 +35,12 @@
"\n",
"dataset = dataset.shuffle().train_test_split(0.2)\n",
"trainset, validset = dataset[\"train\"], dataset[\"test\"]\n",
"\n",
"# sampling, as it is just an example file\n",
"trainset = trainset.select(range(100))\n",
"validset = validset.select(range(100))\n",
"testset = testset.select(range(100))\n",
"\n",
"dataset = DatasetDict({\"train\": trainset, \"valid\": validset, \"test\": testset})"
]
},
@ -35,7 +50,7 @@
"metadata": {},
"outputs": [],
"source": [
"MODEL_ID = \"google/gemma-7b\"\n",
"MODEL_ID = \"google/gemma-2b\"\n",
"\n",
"SEED = 42\n",
"NUM_CLASSES = 2\n",
@ -58,12 +73,12 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "592e2aefe76641e3a6fc22affb1781c7",
"model_id": "283eecad62ac4c468217952044f683a9",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Map: 0%| | 0/13440 [00:00<?, ? examples/s]"
"Map: 0%| | 0/100 [00:00<?, ? examples/s]"
]
},
"metadata": {},
@ -72,12 +87,12 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "c0227aeeb1194900a3eebd50230363bd",
"model_id": "ba35d6b0db1d492783b8b7fa41d9532f",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Map: 0%| | 0/3360 [00:00<?, ? examples/s]"
"Map: 0%| | 0/100 [00:00<?, ? examples/s]"
]
},
"metadata": {},
@ -108,7 +123,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "1adbb0ec04254d7dadd0552b61449ba9",
"model_id": "0598ac059eb74a14aecaea95eda91ee7",
"version_major": 2,
"version_minor": 0
},
@ -156,7 +171,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 7,
"metadata": {},
"outputs": [
{
@ -173,7 +188,7 @@
" <div>\n",
" \n",
" <progress value='10' max='10' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
" [10/10 07:32, Epoch 0/1]\n",
" [10/10 00:17, Epoch 0/1]\n",
" </div>\n",
" <table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
@ -187,15 +202,15 @@
" <tbody>\n",
" <tr>\n",
" <td>5</td>\n",
" <td>0.352300</td>\n",
" <td>1.110698</td>\n",
" <td>0.502083</td>\n",
" <td>0.977000</td>\n",
" <td>1.173903</td>\n",
" <td>0.560000</td>\n",
" </tr>\n",
" <tr>\n",
" <td>10</td>\n",
" <td>0.861100</td>\n",
" <td>1.112852</td>\n",
" <td>0.501786</td>\n",
" <td>1.450300</td>\n",
" <td>1.155836</td>\n",
" <td>0.560000</td>\n",
" </tr>\n",
" </tbody>\n",
"</table><p>"
@ -206,24 +221,14 @@
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"TrainOutput(global_step=10, training_loss=0.6066924571990967, metrics={'train_runtime': 452.6593, 'train_samples_per_second': 0.044, 'train_steps_per_second': 0.022, 'total_flos': 6787820038656.0, 'train_loss': 0.6066924571990967, 'epoch': 0.001488095238095238})"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# if using `Trainer`\n",
"# if using `Trainer`: \n",
"trainer = Trainer(\n",
" model=lora_model,\n",
" args=TrainingArguments(\n",
" output_dir=\"../experiment/\",\n",
" output_dir=\"../experiment/classification/Trainer/\",\n",
" learning_rate=2e-5,\n",
" num_train_epochs=1,\n",
" weight_decay=0.01,\n",
@ -233,6 +238,7 @@
" per_device_eval_batch_size=BATCH_SIZE,\n",
" evaluation_strategy=\"steps\",\n",
" save_strategy=\"steps\",\n",
" save_steps=5,\n",
" load_best_model_at_end=True,\n",
" report_to=\"none\"\n",
" ),\n",
@ -243,7 +249,106 @@
" compute_metrics=compute_metrics,\n",
")\n",
"\n",
"trainer.train()"
"trainer.train()\n",
"trainer.save_state()\n",
"trainer.save_model(output_dir=\"../experiment/classification/Trainer/model/\")"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"max_steps is given, it will override any value given in num_train_epochs\n"
]
},
{
"data": {
"text/html": [
"\n",
" <div>\n",
" \n",
" <progress value='10' max='10' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
" [10/10 00:17, Epoch 0/1]\n",
" </div>\n",
" <table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>Step</th>\n",
" <th>Training Loss</th>\n",
" <th>Validation Loss</th>\n",
" <th>Accuracy</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>5</td>\n",
" <td>0.827100</td>\n",
" <td>1.106255</td>\n",
" <td>0.550000</td>\n",
" </tr>\n",
" <tr>\n",
" <td>10</td>\n",
" <td>1.360100</td>\n",
" <td>1.089411</td>\n",
" <td>0.540000</td>\n",
" </tr>\n",
" </tbody>\n",
"</table><p>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# if using `SFTTrainer`\n",
"trainer = SFTTrainer(\n",
" model=lora_model,\n",
" args=TrainingArguments(\n",
" output_dir=\"../experiment/classification/SFTTrainer/\",\n",
" learning_rate=2e-5,\n",
" num_train_epochs=1,\n",
" weight_decay=0.01,\n",
" logging_steps=5,\n",
" max_steps=10,\n",
" per_device_train_batch_size=BATCH_SIZE,\n",
" per_device_eval_batch_size=BATCH_SIZE,\n",
" evaluation_strategy=\"steps\",\n",
" save_strategy=\"steps\",\n",
" save_steps=5,\n",
" load_best_model_at_end=True,\n",
" report_to=\"none\"\n",
" ),\n",
" train_dataset=dataset[\"train\"],\n",
" eval_dataset=dataset[\"valid\"],\n",
"    dataset_text_field=\"text\", # main difference between SFTTrainer and Trainer\n",
" tokenizer=tokenizer,\n",
" data_collator=DataCollatorWithPadding(tokenizer=tokenizer),\n",
" compute_metrics=compute_metrics,\n",
")\n",
"\n",
"trainer.train()\n",
"trainer.save_state()\n",
"trainer.save_model(output_dir=\"../experiment/classification/SFTTrainer/model/\")"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"del model\n",
"del trainer\n",
"torch.cuda.empty_cache()"
]
},
{
@ -251,93 +356,82 @@
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"max_steps is given, it will override any value given in num_train_epochs\n"
]
},
{
"data": {
"text/html": [
"\n",
" <div>\n",
" \n",
" <progress value='10' max='10' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
" [10/10 07:32, Epoch 0/1]\n",
" </div>\n",
" <table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>Step</th>\n",
" <th>Training Loss</th>\n",
" <th>Validation Loss</th>\n",
" <th>Accuracy</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>5</td>\n",
" <td>0.280700</td>\n",
" <td>1.126119</td>\n",
" <td>0.497024</td>\n",
" </tr>\n",
" <tr>\n",
" <td>10</td>\n",
" <td>0.803900</td>\n",
" <td>1.128193</td>\n",
" <td>0.497024</td>\n",
" </tr>\n",
" </tbody>\n",
"</table><p>"
],
"application/vnd.jupyter.widget-view+json": {
"model_id": "3a6911922f574bc29d938985bb1cf2f5",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"<IPython.core.display.HTML object>"
"Loading checkpoint shards: 0%| | 0/2 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"TrainOutput(global_step=10, training_loss=0.542291009426117, metrics={'train_runtime': 452.9047, 'train_samples_per_second': 0.044, 'train_steps_per_second': 0.022, 'total_flos': 6787820038656.0, 'train_loss': 0.542291009426117, 'epoch': 0.001488095238095238})"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
"name": "stderr",
"output_type": "stream",
"text": [
"Some weights of GemmaForSequenceClassification were not initialized from the model checkpoint at google/gemma-2b and are newly initialized: ['score.weight']\n",
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
"The model 'PeftModelForSequenceClassification' is not supported for text-classification. Supported models are ['AlbertForSequenceClassification', 'BartForSequenceClassification', 'BertForSequenceClassification', 'BigBirdForSequenceClassification', 'BigBirdPegasusForSequenceClassification', 'BioGptForSequenceClassification', 'BloomForSequenceClassification', 'CamembertForSequenceClassification', 'CanineForSequenceClassification', 'LlamaForSequenceClassification', 'ConvBertForSequenceClassification', 'CTRLForSequenceClassification', 'Data2VecTextForSequenceClassification', 'DebertaForSequenceClassification', 'DebertaV2ForSequenceClassification', 'DistilBertForSequenceClassification', 'ElectraForSequenceClassification', 'ErnieForSequenceClassification', 'ErnieMForSequenceClassification', 'EsmForSequenceClassification', 'FalconForSequenceClassification', 'FlaubertForSequenceClassification', 'FNetForSequenceClassification', 'FunnelForSequenceClassification', 'GemmaForSequenceClassification', 'GPT2ForSequenceClassification', 'GPT2ForSequenceClassification', 'GPTBigCodeForSequenceClassification', 'GPTNeoForSequenceClassification', 'GPTNeoXForSequenceClassification', 'GPTJForSequenceClassification', 'IBertForSequenceClassification', 'JambaForSequenceClassification', 'LayoutLMForSequenceClassification', 'LayoutLMv2ForSequenceClassification', 'LayoutLMv3ForSequenceClassification', 'LEDForSequenceClassification', 'LiltForSequenceClassification', 'LlamaForSequenceClassification', 'LongformerForSequenceClassification', 'LukeForSequenceClassification', 'MarkupLMForSequenceClassification', 'MBartForSequenceClassification', 'MegaForSequenceClassification', 'MegatronBertForSequenceClassification', 'MistralForSequenceClassification', 'MixtralForSequenceClassification', 'MobileBertForSequenceClassification', 'MPNetForSequenceClassification', 'MptForSequenceClassification', 'MraForSequenceClassification', 'MT5ForSequenceClassification', 'MvpForSequenceClassification', 
'NezhaForSequenceClassification', 'NystromformerForSequenceClassification', 'OpenLlamaForSequenceClassification', 'OpenAIGPTForSequenceClassification', 'OPTForSequenceClassification', 'PerceiverForSequenceClassification', 'PersimmonForSequenceClassification', 'PhiForSequenceClassification', 'PLBartForSequenceClassification', 'QDQBertForSequenceClassification', 'Qwen2ForSequenceClassification', 'Qwen2MoeForSequenceClassification', 'ReformerForSequenceClassification', 'RemBertForSequenceClassification', 'RobertaForSequenceClassification', 'RobertaPreLayerNormForSequenceClassification', 'RoCBertForSequenceClassification', 'RoFormerForSequenceClassification', 'SqueezeBertForSequenceClassification', 'StableLmForSequenceClassification', 'Starcoder2ForSequenceClassification', 'T5ForSequenceClassification', 'TapasForSequenceClassification', 'TransfoXLForSequenceClassification', 'UMT5ForSequenceClassification', 'XLMForSequenceClassification', 'XLMRobertaForSequenceClassification', 'XLMRobertaXLForSequenceClassification', 'XLNetForSequenceClassification', 'XmodForSequenceClassification', 'YosoForSequenceClassification'].\n",
"100%|██████████| 100/100 [02:30<00:00, 1.50s/it]\n"
]
}
],
"source": [
"# if using `SFTTrainer`\n",
"trainer = SFTTrainer(\n",
" model=lora_model,\n",
" args=TrainingArguments(\n",
" output_dir=\"../experiment/\",\n",
" learning_rate=2e-5,\n",
" num_train_epochs=1,\n",
" weight_decay=0.01,\n",
" logging_steps=5,\n",
" max_steps=10,\n",
" per_device_train_batch_size=BATCH_SIZE,\n",
" per_device_eval_batch_size=BATCH_SIZE,\n",
" evaluation_strategy=\"steps\",\n",
" save_strategy=\"steps\",\n",
" load_best_model_at_end=True,\n",
" report_to=\"none\"\n",
" ),\n",
" train_dataset=dataset[\"train\"],\n",
" eval_dataset=dataset[\"valid\"],\n",
" dataset_text_field=\"text\",\n",
" tokenizer=tokenizer,\n",
" data_collator=DataCollatorWithPadding(tokenizer=tokenizer),\n",
" compute_metrics=compute_metrics,\n",
")\n",
"## load the tuned model and use it in testset\n",
"base_model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID, num_labels=NUM_CLASSES)\n",
"merged_model = PeftModel.from_pretrained(base_model, \"../experiment/classification/Trainer/model/\")\n",
"trained_model = merged_model.merge_and_unload()\n",
"\n",
"trainer.train()"
"trained_tokenizer = AutoTokenizer.from_pretrained(\"../experiment/classification/Trainer/model/\")\n",
"\n",
"cls_pipeline = pipeline(\"text-classification\", merged_model, tokenizer=tokenizer)\n",
"\n",
"predictions = []\n",
"for text in tqdm(dataset[\"test\"][\"text\"]):\n",
" prediction = cls_pipeline(text)\n",
" prediction = int(prediction[0]['label'].split('_')[1])\n",
" predictions.append(prediction)\n",
" "
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" precision recall f1-score support\n",
"\n",
" 0 0.58 0.91 0.71 56\n",
" 1 0.58 0.16 0.25 44\n",
"\n",
" accuracy 0.58 100\n",
" macro avg 0.58 0.53 0.48 100\n",
"weighted avg 0.58 0.58 0.51 100\n",
"\n"
]
}
],
"source": [
"from sklearn.metrics import classification_report\n",
"\n",
"print(classification_report(pd.DataFrame(validset)[\"label\"], predictions))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {

View File

@ -5,6 +5,7 @@ import torch
import numpy as np
import pandas as pd
from tqdm import tqdm
from random import randint
from datasets import load_dataset, DatasetDict