add instruction

This commit is contained in:
TerenceLiu98 2024-04-25 12:42:03 +00:00
parent a663caaa4c
commit 7b33da1135
5 changed files with 416 additions and 28 deletions

View File

@ -4,7 +4,7 @@
## Base
* Models: [Google/Gemma](https://blog.google/technology/developers/gemma-open-models/) and [THUDM/GLM](https://github.com/THUDM/GLM)
* Models: [Google/Gemma](https://blog.google/technology/developers/gemma-open-models/)
* Packages: [transformers](https://huggingface.co/docs/transformers/index), [Datasets](https://huggingface.co/docs/datasets/index), [trl](https://huggingface.co/docs/trl/index)
## HOW-TO

View File

@ -73,7 +73,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "283eecad62ac4c468217952044f683a9",
"model_id": "d3ecaba97ac945b89d7851f2780f4faa",
"version_major": 2,
"version_minor": 0
},
@ -87,7 +87,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "ba35d6b0db1d492783b8b7fa41d9532f",
"model_id": "8eeaea87085c477482edbcf859cb48ab",
"version_major": 2,
"version_minor": 0
},
@ -100,7 +100,7 @@
}
],
"source": [
"tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)\n",
"tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)\n",
"\n",
"dataset[\"test\"] = dataset[\"test\"].map(lambda samples: tokenizer(samples[\"text\"], max_length=512, truncation=True), batched=True)\n",
"dataset[\"train\"] = dataset[\"train\"].map(lambda samples: tokenizer(samples[\"text\"], max_length=512, truncation=True), batched=True)\n",
@ -123,7 +123,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "0598ac059eb74a14aecaea95eda91ee7",
"model_id": "808a16b1185243e6bfc8c83f09f5f9ba",
"version_major": 2,
"version_minor": 0
},
@ -152,7 +152,7 @@
],
"source": [
"lora_config = LoraConfig(r = LORA_R, lora_alpha = LORA_ALPHA, lora_dropout = LORA_DROPOUT, task_type = TaskType.SEQ_CLS, target_modules = \"all-linear\")\n",
"model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID, device_map=\"auto\", num_labels=NUM_CLASSES)\n",
"model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID, num_labels=NUM_CLASSES, trust_remote_code=True)\n",
"lora_model = get_peft_model(model, lora_config)\n",
"lora_model.print_trainable_parameters()"
]
@ -188,7 +188,7 @@
" <div>\n",
" \n",
" <progress value='10' max='10' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
" [10/10 00:17, Epoch 0/1]\n",
" [10/10 00:16, Epoch 0/1]\n",
" </div>\n",
" <table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
@ -202,15 +202,15 @@
" <tbody>\n",
" <tr>\n",
" <td>5</td>\n",
" <td>0.977000</td>\n",
" <td>1.173903</td>\n",
" <td>0.560000</td>\n",
" <td>1.423400</td>\n",
" <td>1.191649</td>\n",
" <td>0.510000</td>\n",
" </tr>\n",
" <tr>\n",
" <td>10</td>\n",
" <td>1.450300</td>\n",
" <td>1.155836</td>\n",
" <td>0.560000</td>\n",
" <td>1.164800</td>\n",
" <td>1.191105</td>\n",
" <td>0.520000</td>\n",
" </tr>\n",
" </tbody>\n",
"</table><p>"
@ -273,7 +273,7 @@
" <div>\n",
" \n",
" <progress value='10' max='10' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
" [10/10 00:17, Epoch 0/1]\n",
" [10/10 00:16, Epoch 0/1]\n",
" </div>\n",
" <table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
@ -287,15 +287,15 @@
" <tbody>\n",
" <tr>\n",
" <td>5</td>\n",
" <td>0.827100</td>\n",
" <td>1.106255</td>\n",
" <td>0.550000</td>\n",
" <td>1.295200</td>\n",
" <td>1.138854</td>\n",
" <td>0.530000</td>\n",
" </tr>\n",
" <tr>\n",
" <td>10</td>\n",
" <td>1.360100</td>\n",
" <td>1.089411</td>\n",
" <td>0.540000</td>\n",
" <td>1.100300</td>\n",
" <td>1.137775</td>\n",
" <td>0.530000</td>\n",
" </tr>\n",
" </tbody>\n",
"</table><p>"
@ -359,7 +359,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "3a6911922f574bc29d938985bb1cf2f5",
"model_id": "ef2e165bd6c74fb2b962021549e619bd",
"version_major": 2,
"version_minor": 0
},
@ -377,7 +377,7 @@
"Some weights of GemmaForSequenceClassification were not initialized from the model checkpoint at google/gemma-2b and are newly initialized: ['score.weight']\n",
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
"The model 'PeftModelForSequenceClassification' is not supported for text-classification. Supported models are ['AlbertForSequenceClassification', 'BartForSequenceClassification', 'BertForSequenceClassification', 'BigBirdForSequenceClassification', 'BigBirdPegasusForSequenceClassification', 'BioGptForSequenceClassification', 'BloomForSequenceClassification', 'CamembertForSequenceClassification', 'CanineForSequenceClassification', 'LlamaForSequenceClassification', 'ConvBertForSequenceClassification', 'CTRLForSequenceClassification', 'Data2VecTextForSequenceClassification', 'DebertaForSequenceClassification', 'DebertaV2ForSequenceClassification', 'DistilBertForSequenceClassification', 'ElectraForSequenceClassification', 'ErnieForSequenceClassification', 'ErnieMForSequenceClassification', 'EsmForSequenceClassification', 'FalconForSequenceClassification', 'FlaubertForSequenceClassification', 'FNetForSequenceClassification', 'FunnelForSequenceClassification', 'GemmaForSequenceClassification', 'GPT2ForSequenceClassification', 'GPT2ForSequenceClassification', 'GPTBigCodeForSequenceClassification', 'GPTNeoForSequenceClassification', 'GPTNeoXForSequenceClassification', 'GPTJForSequenceClassification', 'IBertForSequenceClassification', 'JambaForSequenceClassification', 'LayoutLMForSequenceClassification', 'LayoutLMv2ForSequenceClassification', 'LayoutLMv3ForSequenceClassification', 'LEDForSequenceClassification', 'LiltForSequenceClassification', 'LlamaForSequenceClassification', 'LongformerForSequenceClassification', 'LukeForSequenceClassification', 'MarkupLMForSequenceClassification', 'MBartForSequenceClassification', 'MegaForSequenceClassification', 'MegatronBertForSequenceClassification', 'MistralForSequenceClassification', 'MixtralForSequenceClassification', 'MobileBertForSequenceClassification', 'MPNetForSequenceClassification', 'MptForSequenceClassification', 'MraForSequenceClassification', 'MT5ForSequenceClassification', 'MvpForSequenceClassification', 'NezhaForSequenceClassification', 'NystromformerForSequenceClassification', 'OpenLlamaForSequenceClassification', 'OpenAIGPTForSequenceClassification', 'OPTForSequenceClassification', 'PerceiverForSequenceClassification', 'PersimmonForSequenceClassification', 'PhiForSequenceClassification', 'PLBartForSequenceClassification', 'QDQBertForSequenceClassification', 'Qwen2ForSequenceClassification', 'Qwen2MoeForSequenceClassification', 'ReformerForSequenceClassification', 'RemBertForSequenceClassification', 'RobertaForSequenceClassification', 'RobertaPreLayerNormForSequenceClassification', 'RoCBertForSequenceClassification', 'RoFormerForSequenceClassification', 'SqueezeBertForSequenceClassification', 'StableLmForSequenceClassification', 'Starcoder2ForSequenceClassification', 'T5ForSequenceClassification', 'TapasForSequenceClassification', 'TransfoXLForSequenceClassification', 'UMT5ForSequenceClassification', 'XLMForSequenceClassification', 'XLMRobertaForSequenceClassification', 'XLMRobertaXLForSequenceClassification', 'XLNetForSequenceClassification', 'XmodForSequenceClassification', 'YosoForSequenceClassification'].\n",
"100%|██████████| 100/100 [02:30<00:00, 1.50s/it]\n"
"100%|██████████| 100/100 [02:23<00:00, 1.43s/it]\n"
]
}
],
@ -410,12 +410,12 @@
"text": [
" precision recall f1-score support\n",
"\n",
" 0 0.58 0.91 0.71 56\n",
" 1 0.58 0.16 0.25 44\n",
" 0 0.60 0.90 0.72 59\n",
" 1 0.45 0.12 0.19 41\n",
"\n",
" accuracy 0.58 100\n",
" macro avg 0.58 0.53 0.48 100\n",
"weighted avg 0.58 0.58 0.51 100\n",
" macro avg 0.53 0.51 0.45 100\n",
"weighted avg 0.54 0.58 0.50 100\n",
"\n"
]
}

387
notebooks/instruction.ipynb Normal file
View File

@ -0,0 +1,387 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Instruction Tuning \n",
"\n",
"In this notebook, I will give you an example of LLM's fine-tuning with instruction. "
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"sys.path.append(\"../src\")\n",
"from utils import *"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"MODEL_ID = \"google/gemma-2b\"\n",
"\n",
"SEED = 42\n",
"NUM_CLASSES = 2\n",
"\n",
"EPOCHS = 2\n",
"\n",
"LORA_R = 8\n",
"LORA_ALPHA = 16\n",
"LORA_DROPOUT = 0.1\n",
"BATCH_SIZE = 1\n",
"\n",
"set_seed(SEED)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"'''\n",
"data source: https://github.com/tatsu-lab/stanford_alpaca\n",
"'''\n",
"\n",
"dataset = load_dataset(\"json\", data_files=\"../data/alpaca/data.json\")[\"train\"]\n",
"dataset = dataset.shuffle().train_test_split(0.2)\n",
"trainset, validset = dataset[\"train\"], dataset[\"test\"]\n",
"dataset = validset.train_test_split(0.5)\n",
"validset, testset = dataset[\"train\"], dataset[\"test\"]\n",
"\n",
"# sampling, as it is just a example file\n",
"trainset = trainset.select(range(5)).to_pandas()\n",
"validset = validset.select(range(5)).to_pandas()\n",
"testset = testset.select(range(5)).to_pandas()\n",
"\n",
"trainset[\"prompt\"] = \"<start_of_turn>user \" + trainset[\"instruction\"] + \"<end_of_turn>\\n\\n\" + \"<start_of_turn>model \" + trainset[\"output\"] + \"<end_of_turn>\\n\"\n",
"validset[\"prompt\"] = \"<start_of_turn>user \" + validset[\"instruction\"] + \"<end_of_turn>\\n\\n\" + \"<start_of_turn>model \" + validset[\"output\"] + \"<end_of_turn>\\n\"\n",
"testset[\"prompt\"] = \"<start_of_turn>user \" + testset[\"instruction\"] + \"<end_of_turn>\\n\\n\" + \"<start_of_turn>model \"\n",
"hello = testset\n",
"\n",
"trainset, validset, testset = Dataset.from_pandas(trainset[[\"prompt\"]]), Dataset.from_pandas(validset[[\"prompt\"]]), Dataset.from_pandas(testset[[\"prompt\"]])"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "d14f410e8838476ebc4cc027bf37f7ba",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Map: 0%| | 0/5 [00:00<?, ? examples/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "4416e436cf5f44f094c90a11243893a3",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Map: 0%| | 0/5 [00:00<?, ? examples/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)\n",
"\n",
"trainset = trainset.map(lambda samples: tokenizer(samples[\"prompt\"], max_length=512, truncation=True), batched=True).remove_columns(\"prompt\")\n",
"validset = validset.map(lambda samples: tokenizer(samples[\"prompt\"], max_length=512, truncation=True), batched=True).remove_columns(\"prompt\")\n",
"#testset = testset.map(lambda samples: tokenizer(samples[\"prompt\"], max_length=512, truncation=True), batched=True).remove_columns(\"prompt\")\n",
"\n",
"dataset = DatasetDict({\"train\": trainset, \"valid\": validset, \"test\": testset})"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Gemma's activation function should be approximate GeLU and not exact GeLU.\n",
"Changing the activation function to `gelu_pytorch_tanh`.if you want to use the legacy `gelu`, edit the `model.config` to set `hidden_activation=gelu` instead of `hidden_act`. See https://github.com/huggingface/transformers/pull/29402 for more details.\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "54159c44cb59468a84da3a99d8fc61f1",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Loading checkpoint shards: 0%| | 0/2 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"trainable params: 9,805,824 || all params: 2,515,978,240 || trainable%: 0.3897420034920493\n"
]
}
],
"source": [
"lora_config = LoraConfig(r = LORA_R, lora_alpha = LORA_ALPHA, lora_dropout = LORA_DROPOUT, task_type = \"CAUSAL_LM\", target_modules = \"all-linear\")\n",
"model = AutoModelForCausalLM.from_pretrained(MODEL_ID, device_map=\"auto\")\n",
"lora_model = get_peft_model(model, lora_config)\n",
"lora_model.print_trainable_parameters()"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"max_steps is given, it will override any value given in num_train_epochs\n"
]
},
{
"data": {
"text/html": [
"\n",
" <div>\n",
" \n",
" <progress value='2' max='20' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
" [ 2/20 : < :, Epoch 0.20/4]\n",
" </div>\n",
" <table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>Step</th>\n",
" <th>Training Loss</th>\n",
" <th>Validation Loss</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" </tbody>\n",
"</table><p>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# if using `Trainer`: \n",
"trainer = Trainer(\n",
" model=lora_model,\n",
" args=TrainingArguments(\n",
" output_dir=\"../experiment/instruction/Trainer/\",\n",
" learning_rate=2e-5,\n",
" num_train_epochs=1,\n",
" weight_decay=0.01,\n",
" logging_steps=5,\n",
" max_steps=20,\n",
" per_device_train_batch_size=BATCH_SIZE,\n",
" per_device_eval_batch_size=BATCH_SIZE,\n",
" evaluation_strategy=\"steps\",\n",
" save_strategy=\"steps\",\n",
" save_steps=5,\n",
" load_best_model_at_end=True,\n",
" report_to=\"none\"\n",
" ),\n",
" train_dataset=dataset[\"train\"],\n",
" eval_dataset=dataset[\"valid\"],\n",
" tokenizer=tokenizer,\n",
" data_collator=transformers.DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False),\n",
")\n",
"\n",
"trainer.train()\n",
"trainer.save_state()\n",
"trainer.save_model(output_dir=\"../experiment/instruction/Trainer/model/\")"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"The model 'PeftModelForCausalLM' is not supported for text-generation. Supported models are ['BartForCausalLM', 'BertLMHeadModel', 'BertGenerationDecoder', 'BigBirdForCausalLM', 'BigBirdPegasusForCausalLM', 'BioGptForCausalLM', 'BlenderbotForCausalLM', 'BlenderbotSmallForCausalLM', 'BloomForCausalLM', 'CamembertForCausalLM', 'LlamaForCausalLM', 'CodeGenForCausalLM', 'CohereForCausalLM', 'CpmAntForCausalLM', 'CTRLLMHeadModel', 'Data2VecTextForCausalLM', 'DbrxForCausalLM', 'ElectraForCausalLM', 'ErnieForCausalLM', 'FalconForCausalLM', 'FuyuForCausalLM', 'GemmaForCausalLM', 'GitForCausalLM', 'GPT2LMHeadModel', 'GPT2LMHeadModel', 'GPTBigCodeForCausalLM', 'GPTNeoForCausalLM', 'GPTNeoXForCausalLM', 'GPTNeoXJapaneseForCausalLM', 'GPTJForCausalLM', 'JambaForCausalLM', 'LlamaForCausalLM', 'MambaForCausalLM', 'MarianForCausalLM', 'MBartForCausalLM', 'MegaForCausalLM', 'MegatronBertForCausalLM', 'MistralForCausalLM', 'MixtralForCausalLM', 'MptForCausalLM', 'MusicgenForCausalLM', 'MusicgenMelodyForCausalLM', 'MvpForCausalLM', 'OlmoForCausalLM', 'OpenLlamaForCausalLM', 'OpenAIGPTLMHeadModel', 'OPTForCausalLM', 'PegasusForCausalLM', 'PersimmonForCausalLM', 'PhiForCausalLM', 'PLBartForCausalLM', 'ProphetNetForCausalLM', 'QDQBertLMHeadModel', 'Qwen2ForCausalLM', 'Qwen2MoeForCausalLM', 'RecurrentGemmaForCausalLM', 'ReformerModelWithLMHead', 'RemBertForCausalLM', 'RobertaForCausalLM', 'RobertaPreLayerNormForCausalLM', 'RoCBertForCausalLM', 'RoFormerForCausalLM', 'RwkvForCausalLM', 'Speech2Text2ForCausalLM', 'StableLmForCausalLM', 'Starcoder2ForCausalLM', 'TransfoXLLMHeadModel', 'TrOCRForCausalLM', 'WhisperForCausalLM', 'XGLMForCausalLM', 'XLMWithLMHeadModel', 'XLMProphetNetForCausalLM', 'XLMRobertaForCausalLM', 'XLMRobertaXLForCausalLM', 'XLNetLMHeadModel', 'XmodForCausalLM'].\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'generated_text': '<start_of_turn>user Name two major world religions.<end_of_turn>\\n\\n<start_of_turn>model inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol'}\n",
"Two major world religions are Christianity and Islam.\n"
]
}
],
"source": [
"ins_pipeline = pipeline(\"text-generation\", model=trainer.model, tokenizer=trainer.tokenizer)\n",
"output = ins_pipeline(testset[\"prompt\"][0], max_new_tokens=256, do_sample=True)\n",
"print(output[0])\n",
"print(hello.loc[0, \"output\"])"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"max_steps is given, it will override any value given in num_train_epochs\n"
]
},
{
"data": {
"text/html": [
"\n",
" <div>\n",
" \n",
" <progress value='4' max='10' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
" [ 4/10 00:00 < 00:01, 3.36 it/s, Epoch 0.60/2]\n",
" </div>\n",
" <table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>Step</th>\n",
" <th>Training Loss</th>\n",
" <th>Validation Loss</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" </tbody>\n",
"</table><p>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# if using `SFTTrainer`\n",
"trainer = SFTTrainer(\n",
" model=lora_model,\n",
" args=TrainingArguments(\n",
" output_dir=\"../experiment/instruction/SFTTrainer/\",\n",
" learning_rate=2e-5,\n",
" num_train_epochs=1,\n",
" weight_decay=0.01,\n",
" logging_steps=5,\n",
" max_steps=10,\n",
" per_device_train_batch_size=BATCH_SIZE,\n",
" per_device_eval_batch_size=BATCH_SIZE,\n",
" evaluation_strategy=\"steps\",\n",
" save_strategy=\"steps\",\n",
" save_steps=5,\n",
" load_best_model_at_end=True,\n",
" report_to=\"none\"\n",
" ),\n",
" train_dataset=dataset[\"train\"],\n",
" eval_dataset=dataset[\"valid\"],\n",
" dataset_text_field=\"prompt\",\n",
" tokenizer=tokenizer,\n",
" data_collator=transformers.DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False),\n",
")\n",
"\n",
"trainer.train()\n",
"trainer.save_state()\n",
"trainer.save_model(output_dir=\"../experiment/instruction/SFTTrainer/model/\")"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"The model 'PeftModelForCausalLM' is not supported for text-generation. Supported models are ['BartForCausalLM', 'BertLMHeadModel', 'BertGenerationDecoder', 'BigBirdForCausalLM', 'BigBirdPegasusForCausalLM', 'BioGptForCausalLM', 'BlenderbotForCausalLM', 'BlenderbotSmallForCausalLM', 'BloomForCausalLM', 'CamembertForCausalLM', 'LlamaForCausalLM', 'CodeGenForCausalLM', 'CohereForCausalLM', 'CpmAntForCausalLM', 'CTRLLMHeadModel', 'Data2VecTextForCausalLM', 'DbrxForCausalLM', 'ElectraForCausalLM', 'ErnieForCausalLM', 'FalconForCausalLM', 'FuyuForCausalLM', 'GemmaForCausalLM', 'GitForCausalLM', 'GPT2LMHeadModel', 'GPT2LMHeadModel', 'GPTBigCodeForCausalLM', 'GPTNeoForCausalLM', 'GPTNeoXForCausalLM', 'GPTNeoXJapaneseForCausalLM', 'GPTJForCausalLM', 'JambaForCausalLM', 'LlamaForCausalLM', 'MambaForCausalLM', 'MarianForCausalLM', 'MBartForCausalLM', 'MegaForCausalLM', 'MegatronBertForCausalLM', 'MistralForCausalLM', 'MixtralForCausalLM', 'MptForCausalLM', 'MusicgenForCausalLM', 'MusicgenMelodyForCausalLM', 'MvpForCausalLM', 'OlmoForCausalLM', 'OpenLlamaForCausalLM', 'OpenAIGPTLMHeadModel', 'OPTForCausalLM', 'PegasusForCausalLM', 'PersimmonForCausalLM', 'PhiForCausalLM', 'PLBartForCausalLM', 'ProphetNetForCausalLM', 'QDQBertLMHeadModel', 'Qwen2ForCausalLM', 'Qwen2MoeForCausalLM', 'RecurrentGemmaForCausalLM', 'ReformerModelWithLMHead', 'RemBertForCausalLM', 'RobertaForCausalLM', 'RobertaPreLayerNormForCausalLM', 'RoCBertForCausalLM', 'RoFormerForCausalLM', 'RwkvForCausalLM', 'Speech2Text2ForCausalLM', 'StableLmForCausalLM', 'Starcoder2ForCausalLM', 'TransfoXLLMHeadModel', 'TrOCRForCausalLM', 'WhisperForCausalLM', 'XGLMForCausalLM', 'XLMWithLMHeadModel', 'XLMProphetNetForCausalLM', 'XLMRobertaForCausalLM', 'XLMRobertaXLForCausalLM', 'XLNetLMHeadModel', 'XmodForCausalLM'].\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'generated_text': '<start_of_turn>user Name two major world religions.<end_of_turn>\\n\\n<start_of_turn>model Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora '}\n",
"Two major world religions are Christianity and Islam.\n"
]
}
],
"source": [
"ins_pipeline = pipeline(\"text-generation\", model=trainer.model, tokenizer=trainer.tokenizer)\n",
"output = ins_pipeline(testset[\"prompt\"][0], max_new_tokens=256, do_sample=True)\n",
"print(output[0])\n",
"print(hello.loc[0, \"output\"])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "languageM",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.14"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@ -14,3 +14,4 @@ accelerate
modelscope
bitsandbytes
transformers
sentencepiece

View File

@ -7,7 +7,7 @@ import pandas as pd
from tqdm import tqdm
from random import randint
from datasets import load_dataset, DatasetDict
from datasets import load_dataset, DatasetDict, Dataset
import transformers
from trl import SFTTrainer