diff --git a/README.md b/README.md
index 939ef10..85f300f 100644
--- a/README.md
+++ b/README.md
@@ -4,7 +4,7 @@
## Base
-* Models: [Google/Gemma](https://blog.google/technology/developers/gemma-open-models/) and [THUDM/GLM](https://github.com/THUDM/GLM)
+* Models: [Google/Gemma](https://blog.google/technology/developers/gemma-open-models/)
* Packages: [transformers](https://huggingface.co/docs/transformers/index), [Datasets](https://huggingface.co/docs/datasets/index), [trl](https://huggingface.co/docs/trl/index)
## HOW-TO
diff --git a/notebooks/classification-gemma.ipynb b/notebooks/classification.ipynb
similarity index 92%
rename from notebooks/classification-gemma.ipynb
rename to notebooks/classification.ipynb
index 88a7fae..8d8d5b8 100644
--- a/notebooks/classification-gemma.ipynb
+++ b/notebooks/classification.ipynb
@@ -73,7 +73,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
- "model_id": "283eecad62ac4c468217952044f683a9",
+ "model_id": "d3ecaba97ac945b89d7851f2780f4faa",
"version_major": 2,
"version_minor": 0
},
@@ -87,7 +87,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
- "model_id": "ba35d6b0db1d492783b8b7fa41d9532f",
+ "model_id": "8eeaea87085c477482edbcf859cb48ab",
"version_major": 2,
"version_minor": 0
},
@@ -100,7 +100,7 @@
}
],
"source": [
- "tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)\n",
+ "tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)\n",
"\n",
"dataset[\"test\"] = dataset[\"test\"].map(lambda samples: tokenizer(samples[\"text\"], max_length=512, truncation=True), batched=True)\n",
"dataset[\"train\"] = dataset[\"train\"].map(lambda samples: tokenizer(samples[\"text\"], max_length=512, truncation=True), batched=True)\n",
@@ -123,7 +123,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
- "model_id": "0598ac059eb74a14aecaea95eda91ee7",
+ "model_id": "808a16b1185243e6bfc8c83f09f5f9ba",
"version_major": 2,
"version_minor": 0
},
@@ -152,7 +152,7 @@
],
"source": [
"lora_config = LoraConfig(r = LORA_R, lora_alpha = LORA_ALPHA, lora_dropout = LORA_DROPOUT, task_type = TaskType.SEQ_CLS, target_modules = \"all-linear\")\n",
- "model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID, device_map=\"auto\", num_labels=NUM_CLASSES)\n",
+ "model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID, num_labels=NUM_CLASSES, trust_remote_code=True)\n",
"lora_model = get_peft_model(model, lora_config)\n",
"lora_model.print_trainable_parameters()"
]
@@ -188,7 +188,7 @@
"
\n",
" \n",
"
\n",
- " [10/10 00:17, Epoch 0/1]\n",
+ " [10/10 00:16, Epoch 0/1]\n",
"
\n",
" \n",
" \n",
@@ -202,15 +202,15 @@
" \n",
" \n",
" 5 | \n",
- " 0.977000 | \n",
- " 1.173903 | \n",
- " 0.560000 | \n",
+ " 1.423400 | \n",
+ " 1.191649 | \n",
+ " 0.510000 | \n",
"
\n",
" \n",
" 10 | \n",
- " 1.450300 | \n",
- " 1.155836 | \n",
- " 0.560000 | \n",
+ " 1.164800 | \n",
+ " 1.191105 | \n",
+ " 0.520000 | \n",
"
\n",
" \n",
"
"
@@ -273,7 +273,7 @@
"
\n",
" \n",
"
\n",
- " [10/10 00:17, Epoch 0/1]\n",
+ " [10/10 00:16, Epoch 0/1]\n",
"
\n",
" \n",
" \n",
@@ -287,15 +287,15 @@
" \n",
" \n",
" 5 | \n",
- " 0.827100 | \n",
- " 1.106255 | \n",
- " 0.550000 | \n",
+ " 1.295200 | \n",
+ " 1.138854 | \n",
+ " 0.530000 | \n",
"
\n",
" \n",
" 10 | \n",
- " 1.360100 | \n",
- " 1.089411 | \n",
- " 0.540000 | \n",
+ " 1.100300 | \n",
+ " 1.137775 | \n",
+ " 0.530000 | \n",
"
\n",
" \n",
"
"
@@ -359,7 +359,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
- "model_id": "3a6911922f574bc29d938985bb1cf2f5",
+ "model_id": "ef2e165bd6c74fb2b962021549e619bd",
"version_major": 2,
"version_minor": 0
},
@@ -377,7 +377,7 @@
"Some weights of GemmaForSequenceClassification were not initialized from the model checkpoint at google/gemma-2b and are newly initialized: ['score.weight']\n",
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
"The model 'PeftModelForSequenceClassification' is not supported for text-classification. Supported models are ['AlbertForSequenceClassification', 'BartForSequenceClassification', 'BertForSequenceClassification', 'BigBirdForSequenceClassification', 'BigBirdPegasusForSequenceClassification', 'BioGptForSequenceClassification', 'BloomForSequenceClassification', 'CamembertForSequenceClassification', 'CanineForSequenceClassification', 'LlamaForSequenceClassification', 'ConvBertForSequenceClassification', 'CTRLForSequenceClassification', 'Data2VecTextForSequenceClassification', 'DebertaForSequenceClassification', 'DebertaV2ForSequenceClassification', 'DistilBertForSequenceClassification', 'ElectraForSequenceClassification', 'ErnieForSequenceClassification', 'ErnieMForSequenceClassification', 'EsmForSequenceClassification', 'FalconForSequenceClassification', 'FlaubertForSequenceClassification', 'FNetForSequenceClassification', 'FunnelForSequenceClassification', 'GemmaForSequenceClassification', 'GPT2ForSequenceClassification', 'GPT2ForSequenceClassification', 'GPTBigCodeForSequenceClassification', 'GPTNeoForSequenceClassification', 'GPTNeoXForSequenceClassification', 'GPTJForSequenceClassification', 'IBertForSequenceClassification', 'JambaForSequenceClassification', 'LayoutLMForSequenceClassification', 'LayoutLMv2ForSequenceClassification', 'LayoutLMv3ForSequenceClassification', 'LEDForSequenceClassification', 'LiltForSequenceClassification', 'LlamaForSequenceClassification', 'LongformerForSequenceClassification', 'LukeForSequenceClassification', 'MarkupLMForSequenceClassification', 'MBartForSequenceClassification', 'MegaForSequenceClassification', 'MegatronBertForSequenceClassification', 'MistralForSequenceClassification', 'MixtralForSequenceClassification', 'MobileBertForSequenceClassification', 'MPNetForSequenceClassification', 'MptForSequenceClassification', 'MraForSequenceClassification', 'MT5ForSequenceClassification', 'MvpForSequenceClassification', 
'NezhaForSequenceClassification', 'NystromformerForSequenceClassification', 'OpenLlamaForSequenceClassification', 'OpenAIGPTForSequenceClassification', 'OPTForSequenceClassification', 'PerceiverForSequenceClassification', 'PersimmonForSequenceClassification', 'PhiForSequenceClassification', 'PLBartForSequenceClassification', 'QDQBertForSequenceClassification', 'Qwen2ForSequenceClassification', 'Qwen2MoeForSequenceClassification', 'ReformerForSequenceClassification', 'RemBertForSequenceClassification', 'RobertaForSequenceClassification', 'RobertaPreLayerNormForSequenceClassification', 'RoCBertForSequenceClassification', 'RoFormerForSequenceClassification', 'SqueezeBertForSequenceClassification', 'StableLmForSequenceClassification', 'Starcoder2ForSequenceClassification', 'T5ForSequenceClassification', 'TapasForSequenceClassification', 'TransfoXLForSequenceClassification', 'UMT5ForSequenceClassification', 'XLMForSequenceClassification', 'XLMRobertaForSequenceClassification', 'XLMRobertaXLForSequenceClassification', 'XLNetForSequenceClassification', 'XmodForSequenceClassification', 'YosoForSequenceClassification'].\n",
- "100%|██████████| 100/100 [02:30<00:00, 1.50s/it]\n"
+ "100%|██████████| 100/100 [02:23<00:00, 1.43s/it]\n"
]
}
],
@@ -410,12 +410,12 @@
"text": [
" precision recall f1-score support\n",
"\n",
- " 0 0.58 0.91 0.71 56\n",
- " 1 0.58 0.16 0.25 44\n",
+ " 0 0.60 0.90 0.72 59\n",
+ " 1 0.45 0.12 0.19 41\n",
"\n",
" accuracy 0.58 100\n",
- " macro avg 0.58 0.53 0.48 100\n",
- "weighted avg 0.58 0.58 0.51 100\n",
+ " macro avg 0.53 0.51 0.45 100\n",
+ "weighted avg 0.54 0.58 0.50 100\n",
"\n"
]
}
diff --git a/notebooks/instruction.ipynb b/notebooks/instruction.ipynb
new file mode 100644
index 0000000..dff7752
--- /dev/null
+++ b/notebooks/instruction.ipynb
@@ -0,0 +1,387 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Instruction Tuning \n",
+ "\n",
+    "In this notebook, I will walk through an example of fine-tuning an LLM on instruction data. "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import sys\n",
+ "sys.path.append(\"../src\")\n",
+ "from utils import *"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "MODEL_ID = \"google/gemma-2b\"\n",
+ "\n",
+ "SEED = 42\n",
+ "NUM_CLASSES = 2\n",
+ "\n",
+ "EPOCHS = 2\n",
+ "\n",
+ "LORA_R = 8\n",
+ "LORA_ALPHA = 16\n",
+ "LORA_DROPOUT = 0.1\n",
+ "BATCH_SIZE = 1\n",
+ "\n",
+ "set_seed(SEED)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "'''\n",
+ "data source: https://github.com/tatsu-lab/stanford_alpaca\n",
+ "'''\n",
+ "\n",
+ "dataset = load_dataset(\"json\", data_files=\"../data/alpaca/data.json\")[\"train\"]\n",
+ "dataset = dataset.shuffle().train_test_split(0.2)\n",
+ "trainset, validset = dataset[\"train\"], dataset[\"test\"]\n",
+ "dataset = validset.train_test_split(0.5)\n",
+ "validset, testset = dataset[\"train\"], dataset[\"test\"]\n",
+ "\n",
+    "# sampling, as it is just an example file\n",
+ "trainset = trainset.select(range(5)).to_pandas()\n",
+ "validset = validset.select(range(5)).to_pandas()\n",
+ "testset = testset.select(range(5)).to_pandas()\n",
+ "\n",
+ "trainset[\"prompt\"] = \"user \" + trainset[\"instruction\"] + \"\\n\\n\" + \"model \" + trainset[\"output\"] + \"\\n\"\n",
+ "validset[\"prompt\"] = \"user \" + validset[\"instruction\"] + \"\\n\\n\" + \"model \" + validset[\"output\"] + \"\\n\"\n",
+ "testset[\"prompt\"] = \"user \" + testset[\"instruction\"] + \"\\n\\n\" + \"model \"\n",
+ "hello = testset\n",
+ "\n",
+ "trainset, validset, testset = Dataset.from_pandas(trainset[[\"prompt\"]]), Dataset.from_pandas(validset[[\"prompt\"]]), Dataset.from_pandas(testset[[\"prompt\"]])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "d14f410e8838476ebc4cc027bf37f7ba",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Map: 0%| | 0/5 [00:00, ? examples/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "4416e436cf5f44f094c90a11243893a3",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Map: 0%| | 0/5 [00:00, ? examples/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)\n",
+ "\n",
+ "trainset = trainset.map(lambda samples: tokenizer(samples[\"prompt\"], max_length=512, truncation=True), batched=True).remove_columns(\"prompt\")\n",
+ "validset = validset.map(lambda samples: tokenizer(samples[\"prompt\"], max_length=512, truncation=True), batched=True).remove_columns(\"prompt\")\n",
+ "#testset = testset.map(lambda samples: tokenizer(samples[\"prompt\"], max_length=512, truncation=True), batched=True).remove_columns(\"prompt\")\n",
+ "\n",
+ "dataset = DatasetDict({\"train\": trainset, \"valid\": validset, \"test\": testset})"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Gemma's activation function should be approximate GeLU and not exact GeLU.\n",
+ "Changing the activation function to `gelu_pytorch_tanh`.if you want to use the legacy `gelu`, edit the `model.config` to set `hidden_activation=gelu` instead of `hidden_act`. See https://github.com/huggingface/transformers/pull/29402 for more details.\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "54159c44cb59468a84da3a99d8fc61f1",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Loading checkpoint shards: 0%| | 0/2 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "trainable params: 9,805,824 || all params: 2,515,978,240 || trainable%: 0.3897420034920493\n"
+ ]
+ }
+ ],
+ "source": [
+ "lora_config = LoraConfig(r = LORA_R, lora_alpha = LORA_ALPHA, lora_dropout = LORA_DROPOUT, task_type = \"CAUSAL_LM\", target_modules = \"all-linear\")\n",
+ "model = AutoModelForCausalLM.from_pretrained(MODEL_ID, device_map=\"auto\")\n",
+ "lora_model = get_peft_model(model, lora_config)\n",
+ "lora_model.print_trainable_parameters()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "max_steps is given, it will override any value given in num_train_epochs\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [ 2/20 : < :, Epoch 0.20/4]\n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " Step | \n",
+ " Training Loss | \n",
+ " Validation Loss | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ "
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "# if using `Trainer`: \n",
+ "trainer = Trainer(\n",
+ " model=lora_model,\n",
+ " args=TrainingArguments(\n",
+ " output_dir=\"../experiment/instruction/Trainer/\",\n",
+ " learning_rate=2e-5,\n",
+ " num_train_epochs=1,\n",
+ " weight_decay=0.01,\n",
+ " logging_steps=5,\n",
+ " max_steps=20,\n",
+ " per_device_train_batch_size=BATCH_SIZE,\n",
+ " per_device_eval_batch_size=BATCH_SIZE,\n",
+ " evaluation_strategy=\"steps\",\n",
+ " save_strategy=\"steps\",\n",
+ " save_steps=5,\n",
+ " load_best_model_at_end=True,\n",
+ " report_to=\"none\"\n",
+ " ),\n",
+ " train_dataset=dataset[\"train\"],\n",
+ " eval_dataset=dataset[\"valid\"],\n",
+ " tokenizer=tokenizer,\n",
+ " data_collator=transformers.DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False),\n",
+ ")\n",
+ "\n",
+ "trainer.train()\n",
+ "trainer.save_state()\n",
+ "trainer.save_model(output_dir=\"../experiment/instruction/Trainer/model/\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "The model 'PeftModelForCausalLM' is not supported for text-generation. Supported models are ['BartForCausalLM', 'BertLMHeadModel', 'BertGenerationDecoder', 'BigBirdForCausalLM', 'BigBirdPegasusForCausalLM', 'BioGptForCausalLM', 'BlenderbotForCausalLM', 'BlenderbotSmallForCausalLM', 'BloomForCausalLM', 'CamembertForCausalLM', 'LlamaForCausalLM', 'CodeGenForCausalLM', 'CohereForCausalLM', 'CpmAntForCausalLM', 'CTRLLMHeadModel', 'Data2VecTextForCausalLM', 'DbrxForCausalLM', 'ElectraForCausalLM', 'ErnieForCausalLM', 'FalconForCausalLM', 'FuyuForCausalLM', 'GemmaForCausalLM', 'GitForCausalLM', 'GPT2LMHeadModel', 'GPT2LMHeadModel', 'GPTBigCodeForCausalLM', 'GPTNeoForCausalLM', 'GPTNeoXForCausalLM', 'GPTNeoXJapaneseForCausalLM', 'GPTJForCausalLM', 'JambaForCausalLM', 'LlamaForCausalLM', 'MambaForCausalLM', 'MarianForCausalLM', 'MBartForCausalLM', 'MegaForCausalLM', 'MegatronBertForCausalLM', 'MistralForCausalLM', 'MixtralForCausalLM', 'MptForCausalLM', 'MusicgenForCausalLM', 'MusicgenMelodyForCausalLM', 'MvpForCausalLM', 'OlmoForCausalLM', 'OpenLlamaForCausalLM', 'OpenAIGPTLMHeadModel', 'OPTForCausalLM', 'PegasusForCausalLM', 'PersimmonForCausalLM', 'PhiForCausalLM', 'PLBartForCausalLM', 'ProphetNetForCausalLM', 'QDQBertLMHeadModel', 'Qwen2ForCausalLM', 'Qwen2MoeForCausalLM', 'RecurrentGemmaForCausalLM', 'ReformerModelWithLMHead', 'RemBertForCausalLM', 'RobertaForCausalLM', 'RobertaPreLayerNormForCausalLM', 'RoCBertForCausalLM', 'RoFormerForCausalLM', 'RwkvForCausalLM', 'Speech2Text2ForCausalLM', 'StableLmForCausalLM', 'Starcoder2ForCausalLM', 'TransfoXLLMHeadModel', 'TrOCRForCausalLM', 'WhisperForCausalLM', 'XGLMForCausalLM', 'XLMWithLMHeadModel', 'XLMProphetNetForCausalLM', 'XLMRobertaForCausalLM', 'XLMRobertaXLForCausalLM', 'XLNetLMHeadModel', 'XmodForCausalLM'].\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'generated_text': 'user Name two major world religions.\\n\\nmodel inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol'}\n",
+ "Two major world religions are Christianity and Islam.\n"
+ ]
+ }
+ ],
+ "source": [
+ "ins_pipeline = pipeline(\"text-generation\", model=trainer.model, tokenizer=trainer.tokenizer)\n",
+ "output = ins_pipeline(testset[\"prompt\"][0], max_new_tokens=256, do_sample=True)\n",
+ "print(output[0])\n",
+ "print(hello.loc[0, \"output\"])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "max_steps is given, it will override any value given in num_train_epochs\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [ 4/10 00:00 < 00:01, 3.36 it/s, Epoch 0.60/2]\n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " Step | \n",
+ " Training Loss | \n",
+ " Validation Loss | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ "
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "# if using `SFTTrainer`\n",
+ "trainer = SFTTrainer(\n",
+ " model=lora_model,\n",
+ " args=TrainingArguments(\n",
+ " output_dir=\"../experiment/instruction/SFTTrainer/\",\n",
+ " learning_rate=2e-5,\n",
+ " num_train_epochs=1,\n",
+ " weight_decay=0.01,\n",
+ " logging_steps=5,\n",
+ " max_steps=10,\n",
+ " per_device_train_batch_size=BATCH_SIZE,\n",
+ " per_device_eval_batch_size=BATCH_SIZE,\n",
+ " evaluation_strategy=\"steps\",\n",
+ " save_strategy=\"steps\",\n",
+ " save_steps=5,\n",
+ " load_best_model_at_end=True,\n",
+ " report_to=\"none\"\n",
+ " ),\n",
+ " train_dataset=dataset[\"train\"],\n",
+ " eval_dataset=dataset[\"valid\"],\n",
+ " dataset_text_field=\"prompt\",\n",
+ " tokenizer=tokenizer,\n",
+ " data_collator=transformers.DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False),\n",
+ ")\n",
+ "\n",
+ "trainer.train()\n",
+ "trainer.save_state()\n",
+ "trainer.save_model(output_dir=\"../experiment/instruction/SFTTrainer/model/\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "The model 'PeftModelForCausalLM' is not supported for text-generation. Supported models are ['BartForCausalLM', 'BertLMHeadModel', 'BertGenerationDecoder', 'BigBirdForCausalLM', 'BigBirdPegasusForCausalLM', 'BioGptForCausalLM', 'BlenderbotForCausalLM', 'BlenderbotSmallForCausalLM', 'BloomForCausalLM', 'CamembertForCausalLM', 'LlamaForCausalLM', 'CodeGenForCausalLM', 'CohereForCausalLM', 'CpmAntForCausalLM', 'CTRLLMHeadModel', 'Data2VecTextForCausalLM', 'DbrxForCausalLM', 'ElectraForCausalLM', 'ErnieForCausalLM', 'FalconForCausalLM', 'FuyuForCausalLM', 'GemmaForCausalLM', 'GitForCausalLM', 'GPT2LMHeadModel', 'GPT2LMHeadModel', 'GPTBigCodeForCausalLM', 'GPTNeoForCausalLM', 'GPTNeoXForCausalLM', 'GPTNeoXJapaneseForCausalLM', 'GPTJForCausalLM', 'JambaForCausalLM', 'LlamaForCausalLM', 'MambaForCausalLM', 'MarianForCausalLM', 'MBartForCausalLM', 'MegaForCausalLM', 'MegatronBertForCausalLM', 'MistralForCausalLM', 'MixtralForCausalLM', 'MptForCausalLM', 'MusicgenForCausalLM', 'MusicgenMelodyForCausalLM', 'MvpForCausalLM', 'OlmoForCausalLM', 'OpenLlamaForCausalLM', 'OpenAIGPTLMHeadModel', 'OPTForCausalLM', 'PegasusForCausalLM', 'PersimmonForCausalLM', 'PhiForCausalLM', 'PLBartForCausalLM', 'ProphetNetForCausalLM', 'QDQBertLMHeadModel', 'Qwen2ForCausalLM', 'Qwen2MoeForCausalLM', 'RecurrentGemmaForCausalLM', 'ReformerModelWithLMHead', 'RemBertForCausalLM', 'RobertaForCausalLM', 'RobertaPreLayerNormForCausalLM', 'RoCBertForCausalLM', 'RoFormerForCausalLM', 'RwkvForCausalLM', 'Speech2Text2ForCausalLM', 'StableLmForCausalLM', 'Starcoder2ForCausalLM', 'TransfoXLLMHeadModel', 'TrOCRForCausalLM', 'WhisperForCausalLM', 'XGLMForCausalLM', 'XLMWithLMHeadModel', 'XLMProphetNetForCausalLM', 'XLMRobertaForCausalLM', 'XLMRobertaXLForCausalLM', 'XLNetLMHeadModel', 'XmodForCausalLM'].\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'generated_text': 'user Name two major world religions.\\n\\nmodel Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora '}\n",
+ "Two major world religions are Christianity and Islam.\n"
+ ]
+ }
+ ],
+ "source": [
+ "ins_pipeline = pipeline(\"text-generation\", model=trainer.model, tokenizer=trainer.tokenizer)\n",
+ "output = ins_pipeline(testset[\"prompt\"][0], max_new_tokens=256, do_sample=True)\n",
+ "print(output[0])\n",
+ "print(hello.loc[0, \"output\"])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "languageM",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.14"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/requirements.txt b/requirements.txt
index 94bdb32..a4d1213 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -13,4 +13,5 @@ datasets
accelerate
modelscope
bitsandbytes
-transformers
\ No newline at end of file
+transformers
+sentencepiece
\ No newline at end of file
diff --git a/src/utils.py b/src/utils.py
index 6e88fe1..9d007c2 100644
--- a/src/utils.py
+++ b/src/utils.py
@@ -7,7 +7,7 @@ import pandas as pd
from tqdm import tqdm
from random import randint
-from datasets import load_dataset, DatasetDict
+from datasets import load_dataset, DatasetDict, Dataset
import transformers
from trl import SFTTrainer