commit 7b33da1135 (parent a663caaa4c)
Author: TerenceLiu98
Date: 2024-04-25 12:42:03 +00:00

    add instruction

5 changed files with 416 additions and 28 deletions

View File

@@ -4,7 +4,7 @@
 ## Base
-* Models: [Google/Gemma](https://blog.google/technology/developers/gemma-open-models/) and [THUDM/GLM](https://github.com/THUDM/GLM)
+* Models: [Google/Gemma](https://blog.google/technology/developers/gemma-open-models/)
 * Packages: [transformers](https://huggingface.co/docs/transformers/index), [Datasets](https://huggingface.co/docs/datasets/index), [trl](https://huggingface.co/docs/trl/index)
 ## HOW-TO

View File

@@ -73,7 +73,7 @@
 {
 "data": {
 "application/vnd.jupyter.widget-view+json": {
-"model_id": "283eecad62ac4c468217952044f683a9",
+"model_id": "d3ecaba97ac945b89d7851f2780f4faa",
 "version_major": 2,
 "version_minor": 0
 },
@@ -87,7 +87,7 @@
 {
 "data": {
 "application/vnd.jupyter.widget-view+json": {
-"model_id": "ba35d6b0db1d492783b8b7fa41d9532f",
+"model_id": "8eeaea87085c477482edbcf859cb48ab",
 "version_major": 2,
 "version_minor": 0
 },
@@ -100,7 +100,7 @@
 }
 ],
 "source": [
-"tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)\n",
+"tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)\n",
 "\n",
 "dataset[\"test\"] = dataset[\"test\"].map(lambda samples: tokenizer(samples[\"text\"], max_length=512, truncation=True), batched=True)\n",
 "dataset[\"train\"] = dataset[\"train\"].map(lambda samples: tokenizer(samples[\"text\"], max_length=512, truncation=True), batched=True)\n",
@@ -123,7 +123,7 @@
 {
 "data": {
 "application/vnd.jupyter.widget-view+json": {
-"model_id": "0598ac059eb74a14aecaea95eda91ee7",
+"model_id": "808a16b1185243e6bfc8c83f09f5f9ba",
 "version_major": 2,
 "version_minor": 0
 },
@@ -152,7 +152,7 @@
 ],
 "source": [
 "lora_config = LoraConfig(r = LORA_R, lora_alpha = LORA_ALPHA, lora_dropout = LORA_DROPOUT, task_type = TaskType.SEQ_CLS, target_modules = \"all-linear\")\n",
-"model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID, device_map=\"auto\", num_labels=NUM_CLASSES)\n",
+"model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID, num_labels=NUM_CLASSES, trust_remote_code=True)\n",
 "lora_model = get_peft_model(model, lora_config)\n",
 "lora_model.print_trainable_parameters()"
 ]
@@ -188,7 +188,7 @@
 " <div>\n",
 " \n",
 " <progress value='10' max='10' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
-" [10/10 00:17, Epoch 0/1]\n",
+" [10/10 00:16, Epoch 0/1]\n",
 " </div>\n",
 " <table border=\"1\" class=\"dataframe\">\n",
 " <thead>\n",
@@ -202,15 +202,15 @@
 " <tbody>\n",
 " <tr>\n",
 " <td>5</td>\n",
-" <td>0.977000</td>\n",
-" <td>1.173903</td>\n",
-" <td>0.560000</td>\n",
+" <td>1.423400</td>\n",
+" <td>1.191649</td>\n",
+" <td>0.510000</td>\n",
 " </tr>\n",
 " <tr>\n",
 " <td>10</td>\n",
-" <td>1.450300</td>\n",
-" <td>1.155836</td>\n",
-" <td>0.560000</td>\n",
+" <td>1.164800</td>\n",
+" <td>1.191105</td>\n",
+" <td>0.520000</td>\n",
 " </tr>\n",
 " </tbody>\n",
 "</table><p>"
@@ -273,7 +273,7 @@
 " <div>\n",
 " \n",
 " <progress value='10' max='10' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
-" [10/10 00:17, Epoch 0/1]\n",
+" [10/10 00:16, Epoch 0/1]\n",
 " </div>\n",
 " <table border=\"1\" class=\"dataframe\">\n",
 " <thead>\n",
@@ -287,15 +287,15 @@
 " <tbody>\n",
 " <tr>\n",
 " <td>5</td>\n",
-" <td>0.827100</td>\n",
-" <td>1.106255</td>\n",
-" <td>0.550000</td>\n",
+" <td>1.295200</td>\n",
+" <td>1.138854</td>\n",
+" <td>0.530000</td>\n",
 " </tr>\n",
 " <tr>\n",
 " <td>10</td>\n",
-" <td>1.360100</td>\n",
-" <td>1.089411</td>\n",
-" <td>0.540000</td>\n",
+" <td>1.100300</td>\n",
+" <td>1.137775</td>\n",
+" <td>0.530000</td>\n",
 " </tr>\n",
 " </tbody>\n",
 "</table><p>"
@@ -359,7 +359,7 @@
 {
 "data": {
 "application/vnd.jupyter.widget-view+json": {
-"model_id": "3a6911922f574bc29d938985bb1cf2f5",
+"model_id": "ef2e165bd6c74fb2b962021549e619bd",
 "version_major": 2,
 "version_minor": 0
 },
@@ -377,7 +377,7 @@
 "Some weights of GemmaForSequenceClassification were not initialized from the model checkpoint at google/gemma-2b and are newly initialized: ['score.weight']\n",
 "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
 "The model 'PeftModelForSequenceClassification' is not supported for text-classification. Supported models are ['AlbertForSequenceClassification', 'BartForSequenceClassification', 'BertForSequenceClassification', 'BigBirdForSequenceClassification', 'BigBirdPegasusForSequenceClassification', 'BioGptForSequenceClassification', 'BloomForSequenceClassification', 'CamembertForSequenceClassification', 'CanineForSequenceClassification', 'LlamaForSequenceClassification', 'ConvBertForSequenceClassification', 'CTRLForSequenceClassification', 'Data2VecTextForSequenceClassification', 'DebertaForSequenceClassification', 'DebertaV2ForSequenceClassification', 'DistilBertForSequenceClassification', 'ElectraForSequenceClassification', 'ErnieForSequenceClassification', 'ErnieMForSequenceClassification', 'EsmForSequenceClassification', 'FalconForSequenceClassification', 'FlaubertForSequenceClassification', 'FNetForSequenceClassification', 'FunnelForSequenceClassification', 'GemmaForSequenceClassification', 'GPT2ForSequenceClassification', 'GPT2ForSequenceClassification', 'GPTBigCodeForSequenceClassification', 'GPTNeoForSequenceClassification', 'GPTNeoXForSequenceClassification', 'GPTJForSequenceClassification', 'IBertForSequenceClassification', 'JambaForSequenceClassification', 'LayoutLMForSequenceClassification', 'LayoutLMv2ForSequenceClassification', 'LayoutLMv3ForSequenceClassification', 'LEDForSequenceClassification', 'LiltForSequenceClassification', 'LlamaForSequenceClassification', 'LongformerForSequenceClassification', 'LukeForSequenceClassification', 'MarkupLMForSequenceClassification', 'MBartForSequenceClassification', 'MegaForSequenceClassification', 'MegatronBertForSequenceClassification', 'MistralForSequenceClassification', 'MixtralForSequenceClassification', 'MobileBertForSequenceClassification', 'MPNetForSequenceClassification', 'MptForSequenceClassification', 'MraForSequenceClassification', 'MT5ForSequenceClassification', 'MvpForSequenceClassification', 'NezhaForSequenceClassification', 'NystromformerForSequenceClassification', 'OpenLlamaForSequenceClassification', 'OpenAIGPTForSequenceClassification', 'OPTForSequenceClassification', 'PerceiverForSequenceClassification', 'PersimmonForSequenceClassification', 'PhiForSequenceClassification', 'PLBartForSequenceClassification', 'QDQBertForSequenceClassification', 'Qwen2ForSequenceClassification', 'Qwen2MoeForSequenceClassification', 'ReformerForSequenceClassification', 'RemBertForSequenceClassification', 'RobertaForSequenceClassification', 'RobertaPreLayerNormForSequenceClassification', 'RoCBertForSequenceClassification', 'RoFormerForSequenceClassification', 'SqueezeBertForSequenceClassification', 'StableLmForSequenceClassification', 'Starcoder2ForSequenceClassification', 'T5ForSequenceClassification', 'TapasForSequenceClassification', 'TransfoXLForSequenceClassification', 'UMT5ForSequenceClassification', 'XLMForSequenceClassification', 'XLMRobertaForSequenceClassification', 'XLMRobertaXLForSequenceClassification', 'XLNetForSequenceClassification', 'XmodForSequenceClassification', 'YosoForSequenceClassification'].\n",
-"100%|██████████| 100/100 [02:30<00:00, 1.50s/it]\n"
+"100%|██████████| 100/100 [02:23<00:00, 1.43s/it]\n"
 ]
 }
 ],
@@ -410,12 +410,12 @@
 "text": [
 "              precision    recall  f1-score   support\n",
 "\n",
-"           0       0.58      0.91      0.71        56\n",
-"           1       0.58      0.16      0.25        44\n",
+"           0       0.60      0.90      0.72        59\n",
+"           1       0.45      0.12      0.19        41\n",
 "\n",
 "    accuracy                           0.58       100\n",
-"   macro avg       0.58      0.53      0.48       100\n",
-"weighted avg       0.58      0.58      0.51       100\n",
+"   macro avg       0.53      0.51      0.45       100\n",
+"weighted avg       0.54      0.58      0.50       100\n",
 "\n"
 ]
 }
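Note on the notebook changes above: every loading call now passes `trust_remote_code=True`, and the classifier is loaded without `device_map="auto"`. A minimal sketch of the resulting loading pattern, assuming the notebook's `MODEL_ID` and `NUM_CLASSES` constants:

```python
# Sketch of the updated loading pattern from the diff above (not the full notebook).
from transformers import AutoTokenizer, AutoModelForSequenceClassification

MODEL_ID = "google/gemma-2b"   # as in the notebook
NUM_CLASSES = 2                # as in the notebook

# trust_remote_code=True permits checkpoints that ship custom modeling/tokenizer
# code on the Hub; for Gemma itself it changes nothing.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_ID,
    num_labels=NUM_CLASSES,   # adds a fresh classification head ('score.weight')
    trust_remote_code=True,   # replaces the earlier device_map="auto" argument
)
```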

notebooks/instruction.ipynb (new file)

@@ -0,0 +1,387 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Instruction Tuning \n",
"\n",
"In this notebook, I will give you an example of LLM's fine-tuning with instruction. "
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"sys.path.append(\"../src\")\n",
"from utils import *"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"MODEL_ID = \"google/gemma-2b\"\n",
"\n",
"SEED = 42\n",
"NUM_CLASSES = 2\n",
"\n",
"EPOCHS = 2\n",
"\n",
"LORA_R = 8\n",
"LORA_ALPHA = 16\n",
"LORA_DROPOUT = 0.1\n",
"BATCH_SIZE = 1\n",
"\n",
"set_seed(SEED)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"'''\n",
"data source: https://github.com/tatsu-lab/stanford_alpaca\n",
"'''\n",
"\n",
"dataset = load_dataset(\"json\", data_files=\"../data/alpaca/data.json\")[\"train\"]\n",
"dataset = dataset.shuffle().train_test_split(0.2)\n",
"trainset, validset = dataset[\"train\"], dataset[\"test\"]\n",
"dataset = validset.train_test_split(0.5)\n",
"validset, testset = dataset[\"train\"], dataset[\"test\"]\n",
"\n",
"# sampling, as it is just a example file\n",
"trainset = trainset.select(range(5)).to_pandas()\n",
"validset = validset.select(range(5)).to_pandas()\n",
"testset = testset.select(range(5)).to_pandas()\n",
"\n",
"trainset[\"prompt\"] = \"<start_of_turn>user \" + trainset[\"instruction\"] + \"<end_of_turn>\\n\\n\" + \"<start_of_turn>model \" + trainset[\"output\"] + \"<end_of_turn>\\n\"\n",
"validset[\"prompt\"] = \"<start_of_turn>user \" + validset[\"instruction\"] + \"<end_of_turn>\\n\\n\" + \"<start_of_turn>model \" + validset[\"output\"] + \"<end_of_turn>\\n\"\n",
"testset[\"prompt\"] = \"<start_of_turn>user \" + testset[\"instruction\"] + \"<end_of_turn>\\n\\n\" + \"<start_of_turn>model \"\n",
"hello = testset\n",
"\n",
"trainset, validset, testset = Dataset.from_pandas(trainset[[\"prompt\"]]), Dataset.from_pandas(validset[[\"prompt\"]]), Dataset.from_pandas(testset[[\"prompt\"]])"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "d14f410e8838476ebc4cc027bf37f7ba",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Map: 0%| | 0/5 [00:00<?, ? examples/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "4416e436cf5f44f094c90a11243893a3",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Map: 0%| | 0/5 [00:00<?, ? examples/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)\n",
"\n",
"trainset = trainset.map(lambda samples: tokenizer(samples[\"prompt\"], max_length=512, truncation=True), batched=True).remove_columns(\"prompt\")\n",
"validset = validset.map(lambda samples: tokenizer(samples[\"prompt\"], max_length=512, truncation=True), batched=True).remove_columns(\"prompt\")\n",
"#testset = testset.map(lambda samples: tokenizer(samples[\"prompt\"], max_length=512, truncation=True), batched=True).remove_columns(\"prompt\")\n",
"\n",
"dataset = DatasetDict({\"train\": trainset, \"valid\": validset, \"test\": testset})"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Gemma's activation function should be approximate GeLU and not exact GeLU.\n",
"Changing the activation function to `gelu_pytorch_tanh`.if you want to use the legacy `gelu`, edit the `model.config` to set `hidden_activation=gelu` instead of `hidden_act`. See https://github.com/huggingface/transformers/pull/29402 for more details.\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "54159c44cb59468a84da3a99d8fc61f1",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Loading checkpoint shards: 0%| | 0/2 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"trainable params: 9,805,824 || all params: 2,515,978,240 || trainable%: 0.3897420034920493\n"
]
}
],
"source": [
"lora_config = LoraConfig(r = LORA_R, lora_alpha = LORA_ALPHA, lora_dropout = LORA_DROPOUT, task_type = \"CAUSAL_LM\", target_modules = \"all-linear\")\n",
"model = AutoModelForCausalLM.from_pretrained(MODEL_ID, device_map=\"auto\")\n",
"lora_model = get_peft_model(model, lora_config)\n",
"lora_model.print_trainable_parameters()"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"max_steps is given, it will override any value given in num_train_epochs\n"
]
},
{
"data": {
"text/html": [
"\n",
" <div>\n",
" \n",
" <progress value='2' max='20' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
" [ 2/20 : < :, Epoch 0.20/4]\n",
" </div>\n",
" <table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>Step</th>\n",
" <th>Training Loss</th>\n",
" <th>Validation Loss</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" </tbody>\n",
"</table><p>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# if using `Trainer`: \n",
"trainer = Trainer(\n",
" model=lora_model,\n",
" args=TrainingArguments(\n",
" output_dir=\"../experiment/instruction/Trainer/\",\n",
" learning_rate=2e-5,\n",
" num_train_epochs=1,\n",
" weight_decay=0.01,\n",
" logging_steps=5,\n",
" max_steps=20,\n",
" per_device_train_batch_size=BATCH_SIZE,\n",
" per_device_eval_batch_size=BATCH_SIZE,\n",
" evaluation_strategy=\"steps\",\n",
" save_strategy=\"steps\",\n",
" save_steps=5,\n",
" load_best_model_at_end=True,\n",
" report_to=\"none\"\n",
" ),\n",
" train_dataset=dataset[\"train\"],\n",
" eval_dataset=dataset[\"valid\"],\n",
" tokenizer=tokenizer,\n",
" data_collator=transformers.DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False),\n",
")\n",
"\n",
"trainer.train()\n",
"trainer.save_state()\n",
"trainer.save_model(output_dir=\"../experiment/instruction/Trainer/model/\")"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"The model 'PeftModelForCausalLM' is not supported for text-generation. Supported models are ['BartForCausalLM', 'BertLMHeadModel', 'BertGenerationDecoder', 'BigBirdForCausalLM', 'BigBirdPegasusForCausalLM', 'BioGptForCausalLM', 'BlenderbotForCausalLM', 'BlenderbotSmallForCausalLM', 'BloomForCausalLM', 'CamembertForCausalLM', 'LlamaForCausalLM', 'CodeGenForCausalLM', 'CohereForCausalLM', 'CpmAntForCausalLM', 'CTRLLMHeadModel', 'Data2VecTextForCausalLM', 'DbrxForCausalLM', 'ElectraForCausalLM', 'ErnieForCausalLM', 'FalconForCausalLM', 'FuyuForCausalLM', 'GemmaForCausalLM', 'GitForCausalLM', 'GPT2LMHeadModel', 'GPT2LMHeadModel', 'GPTBigCodeForCausalLM', 'GPTNeoForCausalLM', 'GPTNeoXForCausalLM', 'GPTNeoXJapaneseForCausalLM', 'GPTJForCausalLM', 'JambaForCausalLM', 'LlamaForCausalLM', 'MambaForCausalLM', 'MarianForCausalLM', 'MBartForCausalLM', 'MegaForCausalLM', 'MegatronBertForCausalLM', 'MistralForCausalLM', 'MixtralForCausalLM', 'MptForCausalLM', 'MusicgenForCausalLM', 'MusicgenMelodyForCausalLM', 'MvpForCausalLM', 'OlmoForCausalLM', 'OpenLlamaForCausalLM', 'OpenAIGPTLMHeadModel', 'OPTForCausalLM', 'PegasusForCausalLM', 'PersimmonForCausalLM', 'PhiForCausalLM', 'PLBartForCausalLM', 'ProphetNetForCausalLM', 'QDQBertLMHeadModel', 'Qwen2ForCausalLM', 'Qwen2MoeForCausalLM', 'RecurrentGemmaForCausalLM', 'ReformerModelWithLMHead', 'RemBertForCausalLM', 'RobertaForCausalLM', 'RobertaPreLayerNormForCausalLM', 'RoCBertForCausalLM', 'RoFormerForCausalLM', 'RwkvForCausalLM', 'Speech2Text2ForCausalLM', 'StableLmForCausalLM', 'Starcoder2ForCausalLM', 'TransfoXLLMHeadModel', 'TrOCRForCausalLM', 'WhisperForCausalLM', 'XGLMForCausalLM', 'XLMWithLMHeadModel', 'XLMProphetNetForCausalLM', 'XLMRobertaForCausalLM', 'XLMRobertaXLForCausalLM', 'XLNetLMHeadModel', 'XmodForCausalLM'].\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'generated_text': '<start_of_turn>user Name two major world religions.<end_of_turn>\\n\\n<start_of_turn>model inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol'}\n",
"Two major world religions are Christianity and Islam.\n"
]
}
],
"source": [
"ins_pipeline = pipeline(\"text-generation\", model=trainer.model, tokenizer=trainer.tokenizer)\n",
"output = ins_pipeline(testset[\"prompt\"][0], max_new_tokens=256, do_sample=True)\n",
"print(output[0])\n",
"print(hello.loc[0, \"output\"])"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"max_steps is given, it will override any value given in num_train_epochs\n"
]
},
{
"data": {
"text/html": [
"\n",
" <div>\n",
" \n",
" <progress value='4' max='10' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
" [ 4/10 00:00 < 00:01, 3.36 it/s, Epoch 0.60/2]\n",
" </div>\n",
" <table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>Step</th>\n",
" <th>Training Loss</th>\n",
" <th>Validation Loss</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" </tbody>\n",
"</table><p>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# if using `SFTTrainer`\n",
"trainer = SFTTrainer(\n",
" model=lora_model,\n",
" args=TrainingArguments(\n",
" output_dir=\"../experiment/instruction/SFTTrainer/\",\n",
" learning_rate=2e-5,\n",
" num_train_epochs=1,\n",
" weight_decay=0.01,\n",
" logging_steps=5,\n",
" max_steps=10,\n",
" per_device_train_batch_size=BATCH_SIZE,\n",
" per_device_eval_batch_size=BATCH_SIZE,\n",
" evaluation_strategy=\"steps\",\n",
" save_strategy=\"steps\",\n",
" save_steps=5,\n",
" load_best_model_at_end=True,\n",
" report_to=\"none\"\n",
" ),\n",
" train_dataset=dataset[\"train\"],\n",
" eval_dataset=dataset[\"valid\"],\n",
" dataset_text_field=\"prompt\",\n",
" tokenizer=tokenizer,\n",
" data_collator=transformers.DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False),\n",
")\n",
"\n",
"trainer.train()\n",
"trainer.save_state()\n",
"trainer.save_model(output_dir=\"../experiment/instruction/SFTTrainer/model/\")"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"The model 'PeftModelForCausalLM' is not supported for text-generation. Supported models are ['BartForCausalLM', 'BertLMHeadModel', 'BertGenerationDecoder', 'BigBirdForCausalLM', 'BigBirdPegasusForCausalLM', 'BioGptForCausalLM', 'BlenderbotForCausalLM', 'BlenderbotSmallForCausalLM', 'BloomForCausalLM', 'CamembertForCausalLM', 'LlamaForCausalLM', 'CodeGenForCausalLM', 'CohereForCausalLM', 'CpmAntForCausalLM', 'CTRLLMHeadModel', 'Data2VecTextForCausalLM', 'DbrxForCausalLM', 'ElectraForCausalLM', 'ErnieForCausalLM', 'FalconForCausalLM', 'FuyuForCausalLM', 'GemmaForCausalLM', 'GitForCausalLM', 'GPT2LMHeadModel', 'GPT2LMHeadModel', 'GPTBigCodeForCausalLM', 'GPTNeoForCausalLM', 'GPTNeoXForCausalLM', 'GPTNeoXJapaneseForCausalLM', 'GPTJForCausalLM', 'JambaForCausalLM', 'LlamaForCausalLM', 'MambaForCausalLM', 'MarianForCausalLM', 'MBartForCausalLM', 'MegaForCausalLM', 'MegatronBertForCausalLM', 'MistralForCausalLM', 'MixtralForCausalLM', 'MptForCausalLM', 'MusicgenForCausalLM', 'MusicgenMelodyForCausalLM', 'MvpForCausalLM', 'OlmoForCausalLM', 'OpenLlamaForCausalLM', 'OpenAIGPTLMHeadModel', 'OPTForCausalLM', 'PegasusForCausalLM', 'PersimmonForCausalLM', 'PhiForCausalLM', 'PLBartForCausalLM', 'ProphetNetForCausalLM', 'QDQBertLMHeadModel', 'Qwen2ForCausalLM', 'Qwen2MoeForCausalLM', 'RecurrentGemmaForCausalLM', 'ReformerModelWithLMHead', 'RemBertForCausalLM', 'RobertaForCausalLM', 'RobertaPreLayerNormForCausalLM', 'RoCBertForCausalLM', 'RoFormerForCausalLM', 'RwkvForCausalLM', 'Speech2Text2ForCausalLM', 'StableLmForCausalLM', 'Starcoder2ForCausalLM', 'TransfoXLLMHeadModel', 'TrOCRForCausalLM', 'WhisperForCausalLM', 'XGLMForCausalLM', 'XLMWithLMHeadModel', 'XLMProphetNetForCausalLM', 'XLMRobertaForCausalLM', 'XLMRobertaXLForCausalLM', 'XLNetLMHeadModel', 'XmodForCausalLM'].\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'generated_text': '<start_of_turn>user Name two major world religions.<end_of_turn>\\n\\n<start_of_turn>model Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora '}\n",
"Two major world religions are Christianity and Islam.\n"
]
}
],
"source": [
"ins_pipeline = pipeline(\"text-generation\", model=trainer.model, tokenizer=trainer.tokenizer)\n",
"output = ins_pipeline(testset[\"prompt\"][0], max_new_tokens=256, do_sample=True)\n",
"print(output[0])\n",
"print(hello.loc[0, \"output\"])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "languageM",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.14"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
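One observation on the new notebook: the Gemma turn markers (`<start_of_turn>user ... <end_of_turn>`) are concatenated by hand in pandas. A hedged alternative is the tokenizer's chat template, assuming a checkpoint whose tokenizer ships one (the instruction-tuned `google/gemma-2b-it` does; the base model may not):

```python
# Sketch: build the same Gemma prompt via the tokenizer's chat template
# instead of manual string concatenation. Assumes google/gemma-2b-it, whose
# tokenizer_config defines a chat template.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it")
messages = [{"role": "user", "content": "Name two major world religions."}]
prompt = tokenizer.apply_chat_template(
    messages,
    tokenize=False,              # return a string, not token ids
    add_generation_prompt=True,  # end with an open "<start_of_turn>model" turn
)
print(prompt)
# Roughly: <bos><start_of_turn>user\n...<end_of_turn>\n<start_of_turn>model\n
```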

View File

@@ -13,4 +13,5 @@ datasets
 accelerate
 modelscope
 bitsandbytes
 transformers
+sentencepiece
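The requirements change adds `sentencepiece`, presumably because Gemma's tokenizer vocabulary is SentencePiece-based. A quick sanity check that the new dependency resolves:

```python
# Sanity check for the new requirement: Gemma tokenization depends on
# SentencePiece (an ImportError here means the package is missing).
import sentencepiece  # noqa: F401

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
print(tokenizer.tokenize("Hello, Gemma!"))  # SentencePiece pieces, e.g. '▁Hello'
```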

View File

@@ -7,7 +7,7 @@ import pandas as pd
 from tqdm import tqdm
 from random import randint
-from datasets import load_dataset, DatasetDict
+from datasets import load_dataset, DatasetDict, Dataset
 import transformers
 from trl import SFTTrainer
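The widened `datasets` import matches the new notebook, which drops into pandas to build prompt strings and then converts back with `Dataset.from_pandas`. A minimal sketch of that round-trip:

```python
# Minimal sketch of the pandas -> datasets.Dataset round-trip used by the
# instruction notebook (toy data; column names follow the notebook).
import pandas as pd
from datasets import Dataset

df = pd.DataFrame({
    "instruction": ["Name two major world religions."],
    "output": ["Christianity and Islam."],
})
df["prompt"] = ("<start_of_turn>user " + df["instruction"] + "<end_of_turn>\n\n"
                + "<start_of_turn>model " + df["output"] + "<end_of_turn>\n")

ds = Dataset.from_pandas(df[["prompt"]])  # DataFrame columns become features
print(ds[0]["prompt"])
```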