diff --git a/notebooks/classification-gemma.ipynb b/notebooks/classification-gemma.ipynb index 9c39565..88a7fae 100644 --- a/notebooks/classification-gemma.ipynb +++ b/notebooks/classification-gemma.ipynb @@ -1,5 +1,14 @@ { "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Text Classification\n", + "\n", + "In this notebook, I will walk through an example of fine-tuning an LLM for text classification." + ] + }, { "cell_type": "code", "execution_count": 1, @@ -26,6 +35,12 @@ "\n", "dataset = dataset.shuffle().train_test_split(0.2)\n", "trainset, validset = dataset[\"train\"], dataset[\"test\"]\n", + "\n", + "# subsample each split, as this is just an example notebook\n", + "trainset = trainset.select(range(100))\n", + "validset = validset.select(range(100))\n", + "testset = testset.select(range(100))\n", + "\n", "dataset = DatasetDict({\"train\": trainset, \"valid\": validset, \"test\": testset})" ] }, @@ -35,7 +50,7 @@ "metadata": {}, "outputs": [], "source": [ - "MODEL_ID = \"google/gemma-7b\"\n", + "MODEL_ID = \"google/gemma-2b\"\n", "\n", "SEED = 42\n", "NUM_CLASSES = 2\n", @@ -58,12 +73,12 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "592e2aefe76641e3a6fc22affb1781c7", + "model_id": "283eecad62ac4c468217952044f683a9", "version_major": 2, "version_minor": 0 }, "text/plain": [ - "Map: 0%| | 0/13440 [00:00\n", " \n", " \n", - " [10/10 07:32, Epoch 0/1]\n", + " [10/10 00:17, Epoch 0/1]\n", " \n", " \n", " \n", @@ -187,15 +202,15 @@ " \n", " \n", " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", " \n", " \n", "
50.3523001.1106980.5020830.9770001.1739030.560000
100.8611001.1128520.5017861.4503001.1558360.560000

" @@ -206,24 +221,14 @@ }, "metadata": {}, "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "TrainOutput(global_step=10, training_loss=0.6066924571990967, metrics={'train_runtime': 452.6593, 'train_samples_per_second': 0.044, 'train_steps_per_second': 0.022, 'total_flos': 6787820038656.0, 'train_loss': 0.6066924571990967, 'epoch': 0.001488095238095238})" - ] - }, - "execution_count": 9, - "metadata": {}, - "output_type": "execute_result" } ], "source": [ - "# if using `Trainer`\n", + "# if using `Trainer`: \n", "trainer = Trainer(\n", " model=lora_model,\n", " args=TrainingArguments(\n", - " output_dir=\"../experiment/\",\n", + " output_dir=\"../experiment/classification/Trainer/\",\n", " learning_rate=2e-5,\n", " num_train_epochs=1,\n", " weight_decay=0.01,\n", @@ -233,6 +238,7 @@ " per_device_eval_batch_size=BATCH_SIZE,\n", " evaluation_strategy=\"steps\",\n", " save_strategy=\"steps\",\n", + " save_steps=5,\n", " load_best_model_at_end=True,\n", " report_to=\"none\"\n", " ),\n", @@ -243,7 +249,106 @@ " compute_metrics=compute_metrics,\n", ")\n", "\n", - "trainer.train()" + "trainer.train()\n", + "trainer.save_state()\n", + "trainer.save_model(output_dir=\"../experiment/classification/Trainer/model/\")" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "max_steps is given, it will override any value given in num_train_epochs\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + "

\n", + " \n", + " \n", + " [10/10 00:17, Epoch 0/1]\n", + "
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
StepTraining LossValidation LossAccuracy
50.8271001.1062550.550000
101.3601001.0894110.540000

" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# if using `SFTTrainer`\n", + "trainer = SFTTrainer(\n", + " model=lora_model,\n", + " args=TrainingArguments(\n", + " output_dir=\"../experiment/classification/SFTTrainer/\",\n", + " learning_rate=2e-5,\n", + " num_train_epochs=1,\n", + " weight_decay=0.01,\n", + " logging_steps=5,\n", + " max_steps=10,\n", + " per_device_train_batch_size=BATCH_SIZE,\n", + " per_device_eval_batch_size=BATCH_SIZE,\n", + " evaluation_strategy=\"steps\",\n", + " save_strategy=\"steps\",\n", + " save_steps=5,\n", + " load_best_model_at_end=True,\n", + " report_to=\"none\"\n", + " ),\n", + " train_dataset=dataset[\"train\"],\n", + " eval_dataset=dataset[\"valid\"],\n", + " dataset_text_field=\"text\", # main difference between SFTTrainer and Trainer\n", + " tokenizer=tokenizer,\n", + " data_collator=DataCollatorWithPadding(tokenizer=tokenizer),\n", + " compute_metrics=compute_metrics,\n", + ")\n", + "\n", + "trainer.train()\n", + "trainer.save_state()\n", + "trainer.save_model(output_dir=\"../experiment/classification/SFTTrainer/model/\")" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "del model\n", + "del trainer\n", + "torch.cuda.empty_cache()" + ] }, { @@ -251,93 +356,82 @@ "execution_count": 10, "metadata": {}, "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "max_steps is given, it will override any value given in num_train_epochs\n" - ] - }, { "data": { - "text/html": [ - "\n", - "

\n", - " \n", - " \n", - " [10/10 07:32, Epoch 0/1]\n", - "
\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
StepTraining LossValidation LossAccuracy
50.2807001.1261190.497024
100.8039001.1281930.497024

" - ], + "application/vnd.jupyter.widget-view+json": { + "model_id": "3a6911922f574bc29d938985bb1cf2f5", + "version_major": 2, + "version_minor": 0 + }, "text/plain": [ - "" + "Loading checkpoint shards: 0%| | 0/2 [00:00