update
This commit is contained in:
parent
92d5bda932
commit
a663caaa4c
|
@ -1,5 +1,14 @@
|
|||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Text Classification\n",
|
||||
"\n",
|
||||
"In this notebook, I will give you an example of LLM's fine-tuning for text classification. "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
|
@ -26,6 +35,12 @@
|
|||
"\n",
|
||||
"dataset = dataset.shuffle().train_test_split(0.2)\n",
|
||||
"trainset, validset = dataset[\"train\"], dataset[\"test\"]\n",
|
||||
"\n",
|
||||
"# sampling, as it is just a example file\n",
|
||||
"trainset = trainset.select(range(100))\n",
|
||||
"validset = validset.select(range(100))\n",
|
||||
"testset = testset.select(range(100))\n",
|
||||
"\n",
|
||||
"dataset = DatasetDict({\"train\": trainset, \"valid\": validset, \"test\": testset})"
|
||||
]
|
||||
},
|
||||
|
@ -35,7 +50,7 @@
|
|||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"MODEL_ID = \"google/gemma-7b\"\n",
|
||||
"MODEL_ID = \"google/gemma-2b\"\n",
|
||||
"\n",
|
||||
"SEED = 42\n",
|
||||
"NUM_CLASSES = 2\n",
|
||||
|
@ -58,12 +73,12 @@
|
|||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "592e2aefe76641e3a6fc22affb1781c7",
|
||||
"model_id": "283eecad62ac4c468217952044f683a9",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"Map: 0%| | 0/13440 [00:00<?, ? examples/s]"
|
||||
"Map: 0%| | 0/100 [00:00<?, ? examples/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
|
@ -72,12 +87,12 @@
|
|||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "c0227aeeb1194900a3eebd50230363bd",
|
||||
"model_id": "ba35d6b0db1d492783b8b7fa41d9532f",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"Map: 0%| | 0/3360 [00:00<?, ? examples/s]"
|
||||
"Map: 0%| | 0/100 [00:00<?, ? examples/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
|
@ -108,7 +123,7 @@
|
|||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "1adbb0ec04254d7dadd0552b61449ba9",
|
||||
"model_id": "0598ac059eb74a14aecaea95eda91ee7",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
|
@ -156,7 +171,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
|
@ -173,7 +188,7 @@
|
|||
" <div>\n",
|
||||
" \n",
|
||||
" <progress value='10' max='10' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
|
||||
" [10/10 07:32, Epoch 0/1]\n",
|
||||
" [10/10 00:17, Epoch 0/1]\n",
|
||||
" </div>\n",
|
||||
" <table border=\"1\" class=\"dataframe\">\n",
|
||||
" <thead>\n",
|
||||
|
@ -187,15 +202,15 @@
|
|||
" <tbody>\n",
|
||||
" <tr>\n",
|
||||
" <td>5</td>\n",
|
||||
" <td>0.352300</td>\n",
|
||||
" <td>1.110698</td>\n",
|
||||
" <td>0.502083</td>\n",
|
||||
" <td>0.977000</td>\n",
|
||||
" <td>1.173903</td>\n",
|
||||
" <td>0.560000</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <td>10</td>\n",
|
||||
" <td>0.861100</td>\n",
|
||||
" <td>1.112852</td>\n",
|
||||
" <td>0.501786</td>\n",
|
||||
" <td>1.450300</td>\n",
|
||||
" <td>1.155836</td>\n",
|
||||
" <td>0.560000</td>\n",
|
||||
" </tr>\n",
|
||||
" </tbody>\n",
|
||||
"</table><p>"
|
||||
|
@ -206,24 +221,14 @@
|
|||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"TrainOutput(global_step=10, training_loss=0.6066924571990967, metrics={'train_runtime': 452.6593, 'train_samples_per_second': 0.044, 'train_steps_per_second': 0.022, 'total_flos': 6787820038656.0, 'train_loss': 0.6066924571990967, 'epoch': 0.001488095238095238})"
|
||||
]
|
||||
},
|
||||
"execution_count": 9,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# if using `Trainer`\n",
|
||||
"# if using `Trainer`: \n",
|
||||
"trainer = Trainer(\n",
|
||||
" model=lora_model,\n",
|
||||
" args=TrainingArguments(\n",
|
||||
" output_dir=\"../experiment/\",\n",
|
||||
" output_dir=\"../experiment/classification/Trainer/\",\n",
|
||||
" learning_rate=2e-5,\n",
|
||||
" num_train_epochs=1,\n",
|
||||
" weight_decay=0.01,\n",
|
||||
|
@ -233,6 +238,7 @@
|
|||
" per_device_eval_batch_size=BATCH_SIZE,\n",
|
||||
" evaluation_strategy=\"steps\",\n",
|
||||
" save_strategy=\"steps\",\n",
|
||||
" save_steps=5,\n",
|
||||
" load_best_model_at_end=True,\n",
|
||||
" report_to=\"none\"\n",
|
||||
" ),\n",
|
||||
|
@ -243,7 +249,106 @@
|
|||
" compute_metrics=compute_metrics,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"trainer.train()"
|
||||
"trainer.train()\n",
|
||||
"trainer.save_state()\n",
|
||||
"trainer.save_model(output_dir=\"../experiment/classification/Trainer/model/\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"max_steps is given, it will override any value given in num_train_epochs\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"\n",
|
||||
" <div>\n",
|
||||
" \n",
|
||||
" <progress value='10' max='10' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
|
||||
" [10/10 00:17, Epoch 0/1]\n",
|
||||
" </div>\n",
|
||||
" <table border=\"1\" class=\"dataframe\">\n",
|
||||
" <thead>\n",
|
||||
" <tr style=\"text-align: left;\">\n",
|
||||
" <th>Step</th>\n",
|
||||
" <th>Training Loss</th>\n",
|
||||
" <th>Validation Loss</th>\n",
|
||||
" <th>Accuracy</th>\n",
|
||||
" </tr>\n",
|
||||
" </thead>\n",
|
||||
" <tbody>\n",
|
||||
" <tr>\n",
|
||||
" <td>5</td>\n",
|
||||
" <td>0.827100</td>\n",
|
||||
" <td>1.106255</td>\n",
|
||||
" <td>0.550000</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <td>10</td>\n",
|
||||
" <td>1.360100</td>\n",
|
||||
" <td>1.089411</td>\n",
|
||||
" <td>0.540000</td>\n",
|
||||
" </tr>\n",
|
||||
" </tbody>\n",
|
||||
"</table><p>"
|
||||
],
|
||||
"text/plain": [
|
||||
"<IPython.core.display.HTML object>"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# if using `SFTTrainer`\n",
|
||||
"trainer = SFTTrainer(\n",
|
||||
" model=lora_model,\n",
|
||||
" args=TrainingArguments(\n",
|
||||
" output_dir=\"../experiment/classification/SFTTrainer/\",\n",
|
||||
" learning_rate=2e-5,\n",
|
||||
" num_train_epochs=1,\n",
|
||||
" weight_decay=0.01,\n",
|
||||
" logging_steps=5,\n",
|
||||
" max_steps=10,\n",
|
||||
" per_device_train_batch_size=BATCH_SIZE,\n",
|
||||
" per_device_eval_batch_size=BATCH_SIZE,\n",
|
||||
" evaluation_strategy=\"steps\",\n",
|
||||
" save_strategy=\"steps\",\n",
|
||||
" save_steps=5,\n",
|
||||
" load_best_model_at_end=True,\n",
|
||||
" report_to=\"none\"\n",
|
||||
" ),\n",
|
||||
" train_dataset=dataset[\"train\"],\n",
|
||||
" eval_dataset=dataset[\"valid\"],\n",
|
||||
" dataset_text_field=\"text\", # main different between SFTTrainer and Trainer\n",
|
||||
" tokenizer=tokenizer,\n",
|
||||
" data_collator=DataCollatorWithPadding(tokenizer=tokenizer),\n",
|
||||
" compute_metrics=compute_metrics,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"trainer.train()\n",
|
||||
"trainer.save_state()\n",
|
||||
"trainer.save_model(output_dir=\"../experiment/classification/SFTTrainer/model/\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"del model\n",
|
||||
"del trainer\n",
|
||||
"torch.cuda.empty_cache()"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -251,93 +356,82 @@
|
|||
"execution_count": 10,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"max_steps is given, it will override any value given in num_train_epochs\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"\n",
|
||||
" <div>\n",
|
||||
" \n",
|
||||
" <progress value='10' max='10' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
|
||||
" [10/10 07:32, Epoch 0/1]\n",
|
||||
" </div>\n",
|
||||
" <table border=\"1\" class=\"dataframe\">\n",
|
||||
" <thead>\n",
|
||||
" <tr style=\"text-align: left;\">\n",
|
||||
" <th>Step</th>\n",
|
||||
" <th>Training Loss</th>\n",
|
||||
" <th>Validation Loss</th>\n",
|
||||
" <th>Accuracy</th>\n",
|
||||
" </tr>\n",
|
||||
" </thead>\n",
|
||||
" <tbody>\n",
|
||||
" <tr>\n",
|
||||
" <td>5</td>\n",
|
||||
" <td>0.280700</td>\n",
|
||||
" <td>1.126119</td>\n",
|
||||
" <td>0.497024</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <td>10</td>\n",
|
||||
" <td>0.803900</td>\n",
|
||||
" <td>1.128193</td>\n",
|
||||
" <td>0.497024</td>\n",
|
||||
" </tr>\n",
|
||||
" </tbody>\n",
|
||||
"</table><p>"
|
||||
],
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "3a6911922f574bc29d938985bb1cf2f5",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"<IPython.core.display.HTML object>"
|
||||
"Loading checkpoint shards: 0%| | 0/2 [00:00<?, ?it/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"TrainOutput(global_step=10, training_loss=0.542291009426117, metrics={'train_runtime': 452.9047, 'train_samples_per_second': 0.044, 'train_steps_per_second': 0.022, 'total_flos': 6787820038656.0, 'train_loss': 0.542291009426117, 'epoch': 0.001488095238095238})"
|
||||
]
|
||||
},
|
||||
"execution_count": 10,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Some weights of GemmaForSequenceClassification were not initialized from the model checkpoint at google/gemma-2b and are newly initialized: ['score.weight']\n",
|
||||
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
|
||||
"The model 'PeftModelForSequenceClassification' is not supported for text-classification. Supported models are ['AlbertForSequenceClassification', 'BartForSequenceClassification', 'BertForSequenceClassification', 'BigBirdForSequenceClassification', 'BigBirdPegasusForSequenceClassification', 'BioGptForSequenceClassification', 'BloomForSequenceClassification', 'CamembertForSequenceClassification', 'CanineForSequenceClassification', 'LlamaForSequenceClassification', 'ConvBertForSequenceClassification', 'CTRLForSequenceClassification', 'Data2VecTextForSequenceClassification', 'DebertaForSequenceClassification', 'DebertaV2ForSequenceClassification', 'DistilBertForSequenceClassification', 'ElectraForSequenceClassification', 'ErnieForSequenceClassification', 'ErnieMForSequenceClassification', 'EsmForSequenceClassification', 'FalconForSequenceClassification', 'FlaubertForSequenceClassification', 'FNetForSequenceClassification', 'FunnelForSequenceClassification', 'GemmaForSequenceClassification', 'GPT2ForSequenceClassification', 'GPT2ForSequenceClassification', 'GPTBigCodeForSequenceClassification', 'GPTNeoForSequenceClassification', 'GPTNeoXForSequenceClassification', 'GPTJForSequenceClassification', 'IBertForSequenceClassification', 'JambaForSequenceClassification', 'LayoutLMForSequenceClassification', 'LayoutLMv2ForSequenceClassification', 'LayoutLMv3ForSequenceClassification', 'LEDForSequenceClassification', 'LiltForSequenceClassification', 'LlamaForSequenceClassification', 'LongformerForSequenceClassification', 'LukeForSequenceClassification', 'MarkupLMForSequenceClassification', 'MBartForSequenceClassification', 'MegaForSequenceClassification', 'MegatronBertForSequenceClassification', 'MistralForSequenceClassification', 'MixtralForSequenceClassification', 'MobileBertForSequenceClassification', 'MPNetForSequenceClassification', 'MptForSequenceClassification', 'MraForSequenceClassification', 'MT5ForSequenceClassification', 'MvpForSequenceClassification', 'NezhaForSequenceClassification', 'NystromformerForSequenceClassification', 'OpenLlamaForSequenceClassification', 'OpenAIGPTForSequenceClassification', 'OPTForSequenceClassification', 'PerceiverForSequenceClassification', 'PersimmonForSequenceClassification', 'PhiForSequenceClassification', 'PLBartForSequenceClassification', 'QDQBertForSequenceClassification', 'Qwen2ForSequenceClassification', 'Qwen2MoeForSequenceClassification', 'ReformerForSequenceClassification', 'RemBertForSequenceClassification', 'RobertaForSequenceClassification', 'RobertaPreLayerNormForSequenceClassification', 'RoCBertForSequenceClassification', 'RoFormerForSequenceClassification', 'SqueezeBertForSequenceClassification', 'StableLmForSequenceClassification', 'Starcoder2ForSequenceClassification', 'T5ForSequenceClassification', 'TapasForSequenceClassification', 'TransfoXLForSequenceClassification', 'UMT5ForSequenceClassification', 'XLMForSequenceClassification', 'XLMRobertaForSequenceClassification', 'XLMRobertaXLForSequenceClassification', 'XLNetForSequenceClassification', 'XmodForSequenceClassification', 'YosoForSequenceClassification'].\n",
|
||||
"100%|██████████| 100/100 [02:30<00:00, 1.50s/it]\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# if using `SFTTrainer`\n",
|
||||
"trainer = SFTTrainer(\n",
|
||||
" model=lora_model,\n",
|
||||
" args=TrainingArguments(\n",
|
||||
" output_dir=\"../experiment/\",\n",
|
||||
" learning_rate=2e-5,\n",
|
||||
" num_train_epochs=1,\n",
|
||||
" weight_decay=0.01,\n",
|
||||
" logging_steps=5,\n",
|
||||
" max_steps=10,\n",
|
||||
" per_device_train_batch_size=BATCH_SIZE,\n",
|
||||
" per_device_eval_batch_size=BATCH_SIZE,\n",
|
||||
" evaluation_strategy=\"steps\",\n",
|
||||
" save_strategy=\"steps\",\n",
|
||||
" load_best_model_at_end=True,\n",
|
||||
" report_to=\"none\"\n",
|
||||
" ),\n",
|
||||
" train_dataset=dataset[\"train\"],\n",
|
||||
" eval_dataset=dataset[\"valid\"],\n",
|
||||
" dataset_text_field=\"text\",\n",
|
||||
" tokenizer=tokenizer,\n",
|
||||
" data_collator=DataCollatorWithPadding(tokenizer=tokenizer),\n",
|
||||
" compute_metrics=compute_metrics,\n",
|
||||
")\n",
|
||||
"## load the tuned model and use it in testset\n",
|
||||
"base_model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID, num_labels=NUM_CLASSES)\n",
|
||||
"merged_model = PeftModel.from_pretrained(base_model, \"../experiment/classification/Trainer/model/\")\n",
|
||||
"trained_model = merged_model.merge_and_unload()\n",
|
||||
"\n",
|
||||
"trainer.train()"
|
||||
"trained_tokenizer = AutoTokenizer.from_pretrained(\"../experiment/classification/Trainer/model/\")\n",
|
||||
"\n",
|
||||
"cls_pipeline = pipeline(\"text-classification\", merged_model, tokenizer=tokenizer)\n",
|
||||
"\n",
|
||||
"predictions = []\n",
|
||||
"for text in tqdm(dataset[\"test\"][\"text\"]):\n",
|
||||
" prediction = cls_pipeline(text)\n",
|
||||
" prediction = int(prediction[0]['label'].split('_')[1])\n",
|
||||
" predictions.append(prediction)\n",
|
||||
" "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
" precision recall f1-score support\n",
|
||||
"\n",
|
||||
" 0 0.58 0.91 0.71 56\n",
|
||||
" 1 0.58 0.16 0.25 44\n",
|
||||
"\n",
|
||||
" accuracy 0.58 100\n",
|
||||
" macro avg 0.58 0.53 0.48 100\n",
|
||||
"weighted avg 0.58 0.58 0.51 100\n",
|
||||
"\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from sklearn.metrics import classification_report\n",
|
||||
"\n",
|
||||
"print(classification_report(pd.DataFrame(validset)[\"label\"], predictions))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
|
|
|
@ -5,6 +5,7 @@ import torch
|
|||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from tqdm import tqdm
|
||||
from random import randint
|
||||
from datasets import load_dataset, DatasetDict
|
||||
|
||||
|
|
Loading…
Reference in New Issue