From 7b33da1135df8e4c0232f20e4f7fb82018aa489b Mon Sep 17 00:00:00 2001 From: TerenceLiu98 Date: Thu, 25 Apr 2024 12:42:03 +0000 Subject: [PATCH] add instruction --- README.md | 2 +- ...ation-gemma.ipynb => classification.ipynb} | 50 +-- notebooks/instruction.ipynb | 387 ++++++++++++++++++ requirements.txt | 3 +- src/utils.py | 2 +- 5 files changed, 416 insertions(+), 28 deletions(-) rename notebooks/{classification-gemma.ipynb => classification.ipynb} (92%) create mode 100644 notebooks/instruction.ipynb diff --git a/README.md b/README.md index 939ef10..85f300f 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,7 @@ ## Base -* Models: [Google/Gemma](https://blog.google/technology/developers/gemma-open-models/) and [THUDM/GLM](https://github.com/THUDM/GLM) +* Models: [Google/Gemma](https://blog.google/technology/developers/gemma-open-models/) * Packages: [transformers](https://huggingface.co/docs/transformers/index), [Datasets](https://huggingface.co/docs/datasets/index), [trl](https://huggingface.co/docs/trl/index) ## HOW-TO diff --git a/notebooks/classification-gemma.ipynb b/notebooks/classification.ipynb similarity index 92% rename from notebooks/classification-gemma.ipynb rename to notebooks/classification.ipynb index 88a7fae..8d8d5b8 100644 --- a/notebooks/classification-gemma.ipynb +++ b/notebooks/classification.ipynb @@ -73,7 +73,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "283eecad62ac4c468217952044f683a9", + "model_id": "d3ecaba97ac945b89d7851f2780f4faa", "version_major": 2, "version_minor": 0 }, @@ -87,7 +87,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "ba35d6b0db1d492783b8b7fa41d9532f", + "model_id": "8eeaea87085c477482edbcf859cb48ab", "version_major": 2, "version_minor": 0 }, @@ -100,7 +100,7 @@ } ], "source": [ - "tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)\n", + "tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)\n", "\n", "dataset[\"test\"] = dataset[\"test\"].map(lambda samples: tokenizer(samples[\"text\"], max_length=512, truncation=True), batched=True)\n", "dataset[\"train\"] = dataset[\"train\"].map(lambda samples: tokenizer(samples[\"text\"], max_length=512, truncation=True), batched=True)\n", @@ -123,7 +123,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "0598ac059eb74a14aecaea95eda91ee7", + "model_id": "808a16b1185243e6bfc8c83f09f5f9ba", "version_major": 2, "version_minor": 0 }, @@ -152,7 +152,7 @@ ], "source": [ "lora_config = LoraConfig(r = LORA_R, lora_alpha = LORA_ALPHA, lora_dropout = LORA_DROPOUT, task_type = TaskType.SEQ_CLS, target_modules = \"all-linear\")\n", - "model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID, device_map=\"auto\", num_labels=NUM_CLASSES)\n", + "model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID, num_labels=NUM_CLASSES, trust_remote_code=True)\n", "lora_model = get_peft_model(model, lora_config)\n", "lora_model.print_trainable_parameters()" ] @@ -188,7 +188,7 @@ "
\n", " \n", " \n", - " [10/10 00:17, Epoch 0/1]\n", + " [10/10 00:16, Epoch 0/1]\n", "
\n", " \n", " \n", @@ -202,15 +202,15 @@ " \n", " \n", " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", " \n", " \n", "
50.9770001.1739030.5600001.4234001.1916490.510000
101.4503001.1558360.5600001.1648001.1911050.520000

" @@ -273,7 +273,7 @@ "

\n", " \n", " \n", - " [10/10 00:17, Epoch 0/1]\n", + " [10/10 00:16, Epoch 0/1]\n", "
\n", " \n", " \n", @@ -287,15 +287,15 @@ " \n", " \n", " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", " \n", " \n", "
50.8271001.1062550.5500001.2952001.1388540.530000
101.3601001.0894110.5400001.1003001.1377750.530000

" @@ -359,7 +359,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "3a6911922f574bc29d938985bb1cf2f5", + "model_id": "ef2e165bd6c74fb2b962021549e619bd", "version_major": 2, "version_minor": 0 }, @@ -377,7 +377,7 @@ "Some weights of GemmaForSequenceClassification were not initialized from the model checkpoint at google/gemma-2b and are newly initialized: ['score.weight']\n", "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n", "The model 'PeftModelForSequenceClassification' is not supported for text-classification. Supported models are ['AlbertForSequenceClassification', 'BartForSequenceClassification', 'BertForSequenceClassification', 'BigBirdForSequenceClassification', 'BigBirdPegasusForSequenceClassification', 'BioGptForSequenceClassification', 'BloomForSequenceClassification', 'CamembertForSequenceClassification', 'CanineForSequenceClassification', 'LlamaForSequenceClassification', 'ConvBertForSequenceClassification', 'CTRLForSequenceClassification', 'Data2VecTextForSequenceClassification', 'DebertaForSequenceClassification', 'DebertaV2ForSequenceClassification', 'DistilBertForSequenceClassification', 'ElectraForSequenceClassification', 'ErnieForSequenceClassification', 'ErnieMForSequenceClassification', 'EsmForSequenceClassification', 'FalconForSequenceClassification', 'FlaubertForSequenceClassification', 'FNetForSequenceClassification', 'FunnelForSequenceClassification', 'GemmaForSequenceClassification', 'GPT2ForSequenceClassification', 'GPT2ForSequenceClassification', 'GPTBigCodeForSequenceClassification', 'GPTNeoForSequenceClassification', 'GPTNeoXForSequenceClassification', 'GPTJForSequenceClassification', 'IBertForSequenceClassification', 'JambaForSequenceClassification', 'LayoutLMForSequenceClassification', 'LayoutLMv2ForSequenceClassification', 'LayoutLMv3ForSequenceClassification', 'LEDForSequenceClassification', 'LiltForSequenceClassification', 'LlamaForSequenceClassification', 'LongformerForSequenceClassification', 'LukeForSequenceClassification', 'MarkupLMForSequenceClassification', 'MBartForSequenceClassification', 'MegaForSequenceClassification', 'MegatronBertForSequenceClassification', 'MistralForSequenceClassification', 'MixtralForSequenceClassification', 'MobileBertForSequenceClassification', 'MPNetForSequenceClassification', 'MptForSequenceClassification', 'MraForSequenceClassification', 'MT5ForSequenceClassification', 'MvpForSequenceClassification', 'NezhaForSequenceClassification', 'NystromformerForSequenceClassification', 'OpenLlamaForSequenceClassification', 'OpenAIGPTForSequenceClassification', 'OPTForSequenceClassification', 'PerceiverForSequenceClassification', 'PersimmonForSequenceClassification', 'PhiForSequenceClassification', 'PLBartForSequenceClassification', 'QDQBertForSequenceClassification', 'Qwen2ForSequenceClassification', 'Qwen2MoeForSequenceClassification', 'ReformerForSequenceClassification', 'RemBertForSequenceClassification', 'RobertaForSequenceClassification', 'RobertaPreLayerNormForSequenceClassification', 'RoCBertForSequenceClassification', 'RoFormerForSequenceClassification', 'SqueezeBertForSequenceClassification', 'StableLmForSequenceClassification', 'Starcoder2ForSequenceClassification', 'T5ForSequenceClassification', 'TapasForSequenceClassification', 'TransfoXLForSequenceClassification', 'UMT5ForSequenceClassification', 'XLMForSequenceClassification', 'XLMRobertaForSequenceClassification', 'XLMRobertaXLForSequenceClassification', 'XLNetForSequenceClassification', 'XmodForSequenceClassification', 'YosoForSequenceClassification'].\n", - "100%|██████████| 100/100 [02:30<00:00, 1.50s/it]\n" + "100%|██████████| 100/100 [02:23<00:00, 1.43s/it]\n" ] } ], @@ -410,12 +410,12 @@ "text": [ " precision recall f1-score support\n", "\n", - " 0 0.58 0.91 0.71 56\n", - " 1 0.58 0.16 0.25 44\n", + " 0 0.60 0.90 0.72 59\n", + " 1 0.45 0.12 0.19 41\n", "\n", " accuracy 0.58 100\n", - " macro avg 0.58 0.53 0.48 100\n", - "weighted avg 0.58 0.58 0.51 100\n", + " macro avg 0.53 0.51 0.45 100\n", + "weighted avg 0.54 0.58 0.50 100\n", "\n" ] } diff --git a/notebooks/instruction.ipynb b/notebooks/instruction.ipynb new file mode 100644 index 0000000..dff7752 --- /dev/null +++ b/notebooks/instruction.ipynb @@ -0,0 +1,387 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Instruction Tuning \n", + "\n", + "In this notebook, I will give you an example of LLM's fine-tuning with instruction. " + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import sys\n", + "sys.path.append(\"../src\")\n", + "from utils import *" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "MODEL_ID = \"google/gemma-2b\"\n", + "\n", + "SEED = 42\n", + "NUM_CLASSES = 2\n", + "\n", + "EPOCHS = 2\n", + "\n", + "LORA_R = 8\n", + "LORA_ALPHA = 16\n", + "LORA_DROPOUT = 0.1\n", + "BATCH_SIZE = 1\n", + "\n", + "set_seed(SEED)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "'''\n", + "data source: https://github.com/tatsu-lab/stanford_alpaca\n", + "'''\n", + "\n", + "dataset = load_dataset(\"json\", data_files=\"../data/alpaca/data.json\")[\"train\"]\n", + "dataset = dataset.shuffle().train_test_split(0.2)\n", + "trainset, validset = dataset[\"train\"], dataset[\"test\"]\n", + "dataset = validset.train_test_split(0.5)\n", + "validset, testset = dataset[\"train\"], dataset[\"test\"]\n", + "\n", + "# sampling, as it is just a example file\n", + "trainset = trainset.select(range(5)).to_pandas()\n", + "validset = validset.select(range(5)).to_pandas()\n", + "testset = testset.select(range(5)).to_pandas()\n", + "\n", + "trainset[\"prompt\"] = \"user \" + trainset[\"instruction\"] + \"\\n\\n\" + \"model \" + trainset[\"output\"] + \"\\n\"\n", + "validset[\"prompt\"] = \"user \" + validset[\"instruction\"] + \"\\n\\n\" + \"model \" + validset[\"output\"] + \"\\n\"\n", + "testset[\"prompt\"] = \"user \" + testset[\"instruction\"] + \"\\n\\n\" + \"model \"\n", + "hello = testset\n", + "\n", + "trainset, validset, testset = Dataset.from_pandas(trainset[[\"prompt\"]]), Dataset.from_pandas(validset[[\"prompt\"]]), Dataset.from_pandas(testset[[\"prompt\"]])" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "d14f410e8838476ebc4cc027bf37f7ba", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Map: 0%| | 0/5 [00:00\n", + " \n", + " \n", + " [ 2/20 : < :, Epoch 0.20/4]\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
StepTraining LossValidation Loss

" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# if using `Trainer`: \n", + "trainer = Trainer(\n", + " model=lora_model,\n", + " args=TrainingArguments(\n", + " output_dir=\"../experiment/instruction/Trainer/\",\n", + " learning_rate=2e-5,\n", + " num_train_epochs=1,\n", + " weight_decay=0.01,\n", + " logging_steps=5,\n", + " max_steps=20,\n", + " per_device_train_batch_size=BATCH_SIZE,\n", + " per_device_eval_batch_size=BATCH_SIZE,\n", + " evaluation_strategy=\"steps\",\n", + " save_strategy=\"steps\",\n", + " save_steps=5,\n", + " load_best_model_at_end=True,\n", + " report_to=\"none\"\n", + " ),\n", + " train_dataset=dataset[\"train\"],\n", + " eval_dataset=dataset[\"valid\"],\n", + " tokenizer=tokenizer,\n", + " data_collator=transformers.DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False),\n", + ")\n", + "\n", + "trainer.train()\n", + "trainer.save_state()\n", + "trainer.save_model(output_dir=\"../experiment/instruction/Trainer/model/\")" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "The model 'PeftModelForCausalLM' is not supported for text-generation. Supported models are ['BartForCausalLM', 'BertLMHeadModel', 'BertGenerationDecoder', 'BigBirdForCausalLM', 'BigBirdPegasusForCausalLM', 'BioGptForCausalLM', 'BlenderbotForCausalLM', 'BlenderbotSmallForCausalLM', 'BloomForCausalLM', 'CamembertForCausalLM', 'LlamaForCausalLM', 'CodeGenForCausalLM', 'CohereForCausalLM', 'CpmAntForCausalLM', 'CTRLLMHeadModel', 'Data2VecTextForCausalLM', 'DbrxForCausalLM', 'ElectraForCausalLM', 'ErnieForCausalLM', 'FalconForCausalLM', 'FuyuForCausalLM', 'GemmaForCausalLM', 'GitForCausalLM', 'GPT2LMHeadModel', 'GPT2LMHeadModel', 'GPTBigCodeForCausalLM', 'GPTNeoForCausalLM', 'GPTNeoXForCausalLM', 'GPTNeoXJapaneseForCausalLM', 'GPTJForCausalLM', 'JambaForCausalLM', 'LlamaForCausalLM', 'MambaForCausalLM', 'MarianForCausalLM', 'MBartForCausalLM', 'MegaForCausalLM', 'MegatronBertForCausalLM', 'MistralForCausalLM', 'MixtralForCausalLM', 'MptForCausalLM', 'MusicgenForCausalLM', 'MusicgenMelodyForCausalLM', 'MvpForCausalLM', 'OlmoForCausalLM', 'OpenLlamaForCausalLM', 'OpenAIGPTLMHeadModel', 'OPTForCausalLM', 'PegasusForCausalLM', 'PersimmonForCausalLM', 'PhiForCausalLM', 'PLBartForCausalLM', 'ProphetNetForCausalLM', 'QDQBertLMHeadModel', 'Qwen2ForCausalLM', 'Qwen2MoeForCausalLM', 'RecurrentGemmaForCausalLM', 'ReformerModelWithLMHead', 'RemBertForCausalLM', 'RobertaForCausalLM', 'RobertaPreLayerNormForCausalLM', 'RoCBertForCausalLM', 'RoFormerForCausalLM', 'RwkvForCausalLM', 'Speech2Text2ForCausalLM', 'StableLmForCausalLM', 'Starcoder2ForCausalLM', 'TransfoXLLMHeadModel', 'TrOCRForCausalLM', 'WhisperForCausalLM', 'XGLMForCausalLM', 'XLMWithLMHeadModel', 'XLMProphetNetForCausalLM', 'XLMRobertaForCausalLM', 'XLMRobertaXLForCausalLM', 'XLNetLMHeadModel', 'XmodForCausalLM'].\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'generated_text': 'user Name two major world religions.\\n\\nmodel inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol inol'}\n", + "Two major world religions are Christianity and Islam.\n" + ] + } + ], + "source": [ + "ins_pipeline = pipeline(\"text-generation\", model=trainer.model, tokenizer=trainer.tokenizer)\n", + "output = ins_pipeline(testset[\"prompt\"][0], max_new_tokens=256, do_sample=True)\n", + "print(output[0])\n", + "print(hello.loc[0, \"output\"])" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "max_steps is given, it will override any value given in num_train_epochs\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + "

\n", + " \n", + " \n", + " [ 4/10 00:00 < 00:01, 3.36 it/s, Epoch 0.60/2]\n", + "
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
StepTraining LossValidation Loss

" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# if using `SFTTrainer`\n", + "trainer = SFTTrainer(\n", + " model=lora_model,\n", + " args=TrainingArguments(\n", + " output_dir=\"../experiment/instruction/SFTTrainer/\",\n", + " learning_rate=2e-5,\n", + " num_train_epochs=1,\n", + " weight_decay=0.01,\n", + " logging_steps=5,\n", + " max_steps=10,\n", + " per_device_train_batch_size=BATCH_SIZE,\n", + " per_device_eval_batch_size=BATCH_SIZE,\n", + " evaluation_strategy=\"steps\",\n", + " save_strategy=\"steps\",\n", + " save_steps=5,\n", + " load_best_model_at_end=True,\n", + " report_to=\"none\"\n", + " ),\n", + " train_dataset=dataset[\"train\"],\n", + " eval_dataset=dataset[\"valid\"],\n", + " dataset_text_field=\"prompt\",\n", + " tokenizer=tokenizer,\n", + " data_collator=transformers.DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False),\n", + ")\n", + "\n", + "trainer.train()\n", + "trainer.save_state()\n", + "trainer.save_model(output_dir=\"../experiment/instruction/SFTTrainer/model/\")" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "The model 'PeftModelForCausalLM' is not supported for text-generation. Supported models are ['BartForCausalLM', 'BertLMHeadModel', 'BertGenerationDecoder', 'BigBirdForCausalLM', 'BigBirdPegasusForCausalLM', 'BioGptForCausalLM', 'BlenderbotForCausalLM', 'BlenderbotSmallForCausalLM', 'BloomForCausalLM', 'CamembertForCausalLM', 'LlamaForCausalLM', 'CodeGenForCausalLM', 'CohereForCausalLM', 'CpmAntForCausalLM', 'CTRLLMHeadModel', 'Data2VecTextForCausalLM', 'DbrxForCausalLM', 'ElectraForCausalLM', 'ErnieForCausalLM', 'FalconForCausalLM', 'FuyuForCausalLM', 'GemmaForCausalLM', 'GitForCausalLM', 'GPT2LMHeadModel', 'GPT2LMHeadModel', 'GPTBigCodeForCausalLM', 'GPTNeoForCausalLM', 'GPTNeoXForCausalLM', 'GPTNeoXJapaneseForCausalLM', 'GPTJForCausalLM', 'JambaForCausalLM', 'LlamaForCausalLM', 'MambaForCausalLM', 'MarianForCausalLM', 'MBartForCausalLM', 'MegaForCausalLM', 'MegatronBertForCausalLM', 'MistralForCausalLM', 'MixtralForCausalLM', 'MptForCausalLM', 'MusicgenForCausalLM', 'MusicgenMelodyForCausalLM', 'MvpForCausalLM', 'OlmoForCausalLM', 'OpenLlamaForCausalLM', 'OpenAIGPTLMHeadModel', 'OPTForCausalLM', 'PegasusForCausalLM', 'PersimmonForCausalLM', 'PhiForCausalLM', 'PLBartForCausalLM', 'ProphetNetForCausalLM', 'QDQBertLMHeadModel', 'Qwen2ForCausalLM', 'Qwen2MoeForCausalLM', 'RecurrentGemmaForCausalLM', 'ReformerModelWithLMHead', 'RemBertForCausalLM', 'RobertaForCausalLM', 'RobertaPreLayerNormForCausalLM', 'RoCBertForCausalLM', 'RoFormerForCausalLM', 'RwkvForCausalLM', 'Speech2Text2ForCausalLM', 'StableLmForCausalLM', 'Starcoder2ForCausalLM', 'TransfoXLLMHeadModel', 'TrOCRForCausalLM', 'WhisperForCausalLM', 'XGLMForCausalLM', 'XLMWithLMHeadModel', 'XLMProphetNetForCausalLM', 'XLMRobertaForCausalLM', 'XLMRobertaXLForCausalLM', 'XLNetLMHeadModel', 'XmodForCausalLM'].\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'generated_text': 'user Name two major world religions.\\n\\nmodel Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Abbiamo Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora Allora '}\n", + "Two major world religions are Christianity and Islam.\n" + ] + } + ], + "source": [ + "ins_pipeline = pipeline(\"text-generation\", model=trainer.model, tokenizer=trainer.tokenizer)\n", + "output = ins_pipeline(testset[\"prompt\"][0], max_new_tokens=256, do_sample=True)\n", + "print(output[0])\n", + "print(hello.loc[0, \"output\"])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "languageM", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.14" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/requirements.txt b/requirements.txt index 94bdb32..a4d1213 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,4 +13,5 @@ datasets accelerate modelscope bitsandbytes -transformers \ No newline at end of file +transformers +sentencepiece \ No newline at end of file diff --git a/src/utils.py b/src/utils.py index 6e88fe1..9d007c2 100644 --- a/src/utils.py +++ b/src/utils.py @@ -7,7 +7,7 @@ import pandas as pd from tqdm import tqdm from random import randint -from datasets import load_dataset, DatasetDict +from datasets import load_dataset, DatasetDict, Dataset import transformers from trl import SFTTrainer