diff --git a/examples/causal_language_modeling/peft_lora_clm_with_additional_tokens.ipynb b/examples/causal_language_modeling/peft_lora_clm_with_additional_tokens.ipynb new file mode 100644 index 0000000000..81762de08c --- /dev/null +++ b/examples/causal_language_modeling/peft_lora_clm_with_additional_tokens.ipynb @@ -0,0 +1,1012 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "5f239612-620e-4430-8685-9fdc6b179b41", + "metadata": {}, + "source": [ + "# Training PEFT models with new tokens being added to the embedding layers and tokenizer\n", + "\n", + "In this example, we will learn how to train a LoRA model when adding new tokens to the tokenizer and model. \n", + "This is a common use case when doing the following:\n", + "1. Instruction finetuning with new tokens being added, such as `<|user|>`, `<|assistant|>`, `<|system|>`, ``, ``, to properly format the conversations\n", + "2. Finetuning on a specific language wherein language-specific tokens are added, e.g., Korean tokens being added to the vocabulary for finetuning an LLM on Korean datasets.\n", + "3. Instruction finetuning to return outputs in a certain format to enable agent behaviour, with new tokens such as `<|FUNCTIONS|>`, `<|BROWSE|>`, `<|TEXT2IMAGE|>`, `<|ASR|>`, `<|TTS|>`, `<|GENERATECODE|>`, `<|RAG|>`.\n", + "\n", + "In such cases, you add the embedding modules to the LoRA `target_modules`. PEFT will take care of saving the embedding layers with the newly added tokens along with the adapter weights that were trained on the specific initialization of the embedding weights of the added tokens." + ] + }, + { + "cell_type": "markdown", + "id": "b27c55e8-edaa-4059-90bc-d6096d596902", + "metadata": {}, + "source": [ + "Let's import the necessary libraries" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "6f864c90", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "\n", + "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"3\"\n", + "os.environ[\"WANDB_PROJECT\"] = \"PeftExamples\"\n", + "import transformers\n", + "from peft import (\n", + " LoraConfig,\n", + " PeftConfig,\n", + " PeftModel,\n", + " get_peft_model,\n", + " prepare_model_for_int8_training,\n", + ")\n", + "from transformers import (\n", + " AutoModelForCausalLM,\n", + " AutoTokenizer,\n", + " HfArgumentParser,\n", + " TrainingArguments,\n", + " Trainer,\n", + " default_data_collator,\n", + ")\n", + "import torch\n", + "from dataclasses import dataclass, field\n", + "from typing import Optional\n", + "from dataclass_csv import DataclassReader\n", + "from torch.utils.data import Dataset, DataLoader\n", + "\n", + "from enum import Enum" + ] + }, + { + "cell_type": "markdown", + "id": "74950a3f-bb63-4ce5-9e2b-1b83f92b13a2", + "metadata": {}, + "source": [ + "## Prepare Model and Tokenizer" + ] + }, + { + "cell_type": "markdown", + "id": "76763f5e-64b2-409b-8845-ae5589f8a4e0", + "metadata": {}, + "source": [ + "Now, we will be adding 27 new tokens as well as replacing the existing pad, bos, and eos tokens of the model."
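For concreteness, here is a minimal sketch of the kind of LoRA config the intro above implies. The embedding module names (`embed_tokens`, `lm_head`) match Llama/Mistral-style 🤗 transformers models and are the names this PR registers in `EMBEDDING_LAYER_NAMES`; the rank, scaling, and attention projections are illustrative, not necessarily the notebook's exact values:

```python
from peft import LoraConfig

# Target the input/output embedding layers alongside the usual attention
# projections; PEFT will then save the (resized) embedding weights together
# with the adapter weights.
config = LoraConfig(
    r=64,            # illustrative rank
    lora_alpha=128,  # illustrative scaling
    target_modules=["q_proj", "v_proj", "embed_tokens", "lm_head"],
    task_type="CAUSAL_LM",
)
```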
+ ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "fd0498ea-547e-418d-bf13-c9abafdd5476", + "metadata": {}, + "outputs": [], + "source": [ + "class SpecialTokens(str, Enum):\n", + " begin_target = \"<|begintarget|>\"\n", + " end_target = \"<|endtarget|>\"\n", + " begin_context = \"<|begincontext|>\"\n", + " end_context = \"<|endcontext|>\"\n", + " system = \"<|system|>\"\n", + " user = \"<|user|>\"\n", + " begin_last_user_utterance = \"<|beginlastuserutterance|>\"\n", + " end_last_user_utterance = \"<|endlastuserutterance|>\"\n", + " begin_dsts = \"<|begindsts|>\"\n", + " end_dsts = \"<|enddsts|>\"\n", + " begin_dst = \"<|begindst|>\"\n", + " end_dst = \"<|enddst|>\"\n", + " begin_belief = \"<|beginbelief|>\"\n", + " end_belief = \"<|endbelief|>\"\n", + " begin_response = \"<|beginresponse|>\"\n", + " end_response = \"<|endresponse|>\"\n", + " begin_action = \"<|beginaction|>\"\n", + " end_action = \"<|endaction|>\"\n", + " begin_user_action = \"<|beginuseraction|>\"\n", + " end_user_action = \"<|enduseraction|>\"\n", + " sys_actions = \"<|sysactions|>\"\n", + " begin_intent = \"<|beginintent|>\"\n", + " end_intent = \"<|endintent|>\"\n", + " begin_requested_slots = \"<|beginrequestedslots|>\"\n", + " end_requested_slots = \"<|endrequestedslots|>\"\n", + " pad_token = \"<|pad|>\"\n", + " bos_token = \"<|startoftext|>\"\n", + "\n", + " @classmethod\n", + " def list(cls):\n", + " return [c.value for c in cls]" + ] + }, + { + "cell_type": "markdown", + "id": "ae4a4255-5f13-4eef-a024-4f1de0f2173b", + "metadata": {}, + "source": [ + "We will be finetuning the Mistral-7B model. Let's load the tokenizer and add the special tokens, followed by loading the base model and resizing the embedding layers to accommodate the newly added tokens." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "f0eedef9", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "91c67b6377fc4dd7977bf544de784d51", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Loading checkpoint shards: 0%| | 0/2
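The code cell that loads the tokenizer and base model did not survive this rendering of the notebook (only its "Loading checkpoint shards" progress output remains), so the following is a hedged sketch of the pattern the markdown above describes, assuming Mistral-7B and the `SpecialTokens` enum; using `<|endtarget|>` as the new eos token is an assumption:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "mistralai/Mistral-7B-v0.1"  # assumed checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Register the new special tokens and the replacement pad/bos/eos tokens.
tokenizer.add_special_tokens(
    {
        "pad_token": SpecialTokens.pad_token.value,
        "bos_token": SpecialTokens.bos_token.value,
        "eos_token": SpecialTokens.end_target.value,  # assumption: generation stops at <|endtarget|>
        "additional_special_tokens": SpecialTokens.list(),
    }
)
model = AutoModelForCausalLM.from_pretrained(model_name)
# Grow the embedding matrix (and tied lm_head) to cover the added tokens.
model.resize_token_embeddings(len(tokenizer))
```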
[00:00<|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|startoftext|><|begincontext|><|user|> Can you find me place to eat?<|system|> What kind of food would you like to have and where would you like me to search in?<|user|> Food kind of California will be perfect in SF.<|system|> There are 10 restaurants, Al's Place is one of the good restaurant in San Francisco.<|user|> Can you look for any other restaurant?<|system|> Alta Msp is one of the good restaurant in San Francisco.<|beginlastuserutterance|> Can you find me the address?<|endlastuserutterance|><|endcontext|><|begintarget|><|begindsts|><|begindst|><|beginintent|> FindRestaurants<|endintent|><|beginrequestedslots|> Restaurants^street_address<|endrequestedslots|><|beginbelief|> Restaurants^city->SF~San Francisco|Restaurants^cuisine->California<|endbelief|><|enddst|><|enddsts|><|beginuseraction|> REQUEST->Restaurants^street_address~<|enduseraction|><|beginaction|> INFORM->Restaurants^street_address~1275 Minnesota Street<|endaction|><|beginresponse|> The street address of the restaurant is 1275 Minnesota Street.<|endresponse|><|endtarget|><|endtarget|>\"" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "tokenizer.decode(train_dataset[0][\"input_ids\"])" + ] + }, + { + "cell_type": "markdown", + "id": "239d1c83-196d-471e-9bf7-5f36dafa9894", 
+ "metadata": {}, + "source": [ + "# Train the model" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "ec80d6ee", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n", + "Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.\n", + "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33msmangrul\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n" + ] + }, + { + "data": { + "text/html": [ + "Tracking run with wandb version 0.16.0" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Run data is saved locally in /raid/sourab/temp/wandb/run-20231128_230934-edod21gq" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Syncing run ethereal-eon-1 to Weights & Biases (docs)
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + " View project at https://wandb.ai/smangrul/PeftExamples" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + " View run at https://wandb.ai/smangrul/PeftExamples/runs/edod21gq" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + "
\n", + " \n", + " \n", + " [246/246 05:51, Epoch 2/2]\n", + "
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
StepTraining Loss
105.189800
203.745500
302.371500
401.630200
501.302600
600.999400
700.704100
800.527800
900.509700
1000.382300
1100.318200
1200.323500
1300.263400
1400.290900
1500.277400
1600.232800
1700.223600
1800.229600
1900.233100
2000.210200
2100.245800
2200.197300
2300.210100
2400.209800

" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "TrainOutput(global_step=246, training_loss=0.8516577879587809, metrics={'train_runtime': 354.9013, 'train_samples_per_second': 5.556, 'train_steps_per_second': 0.693, 'total_flos': 4.318233532091597e+16, 'train_loss': 0.8516577879587809, 'epoch': 2.0})" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "training_args = TrainingArguments(\n", + " output_dir=\"mistral_lora_clm_with_added_tokens\",\n", + " num_train_epochs=2,\n", + " save_total_limit=5,\n", + " per_device_train_batch_size=8,\n", + " warmup_steps=10,\n", + " weight_decay=0.0001,\n", + " dataloader_drop_last=True,\n", + " bf16=True,\n", + " logging_steps=10,\n", + " learning_rate=1e-5,\n", + " gradient_checkpointing=True,\n", + " gradient_checkpointing_kwargs={\"use_reentrant\": False},\n", + " remove_unused_columns=False,\n", + " hub_model_id=\"smangrul/mistral_lora_clm_with_added_tokens\",\n", + " push_to_hub=True,\n", + " hub_private_repo=True,\n", + ")\n", + "trainer = Trainer(\n", + " model=model,\n", + " args=training_args,\n", + " train_dataset=train_dataset,\n", + " data_collator=default_data_collator,\n", + ")\n", + "# model.config.use_cache = False\n", + "trainer.train()" + ] + }, + { + "cell_type": "markdown", + "id": "7bc1cbed-4eb9-4aaa-ab5f-5b91bf432307", + "metadata": {}, + "source": [ + "# Check the model output on a sample from evaluation dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "71851793", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "context=\"<|begincontext|><|user|>Can you find me a place to eat please?<|system|>Where at? And what kind of cuisine are you craving?<|user|>Somewhere in SF, and I am really craving Thai food at the moment!<|system|>I found a bunch of restaurants, there's actually 10 that you might like in San Francisco, one of them being Baan Thai House & Wine Bar<|user|>How can I reach them? 
And what's their address?<|system|>You can reach them by phone at 415-379-4505 and visit them at 534 Irving Street<|beginlastuserutterance|>Great, that restaurant sounds good<|endlastuserutterance|><|endcontext|>\" \n", + "\n", + " target_predicted='<|begintarget|><|begindsts|><|begindst|><|beginintent|> FindRestaurants<|endintent|><|beginbelief|> Restaurants^city->SF~San Francisco|Restaurants^cuisine->Thai|Restaurants^restaurant_name->Baan Thai House & Wine Bar<|endbelief|><|enddst|><|enddsts|><|beginuseraction|> REQUEST->Restaurants^phone_number~|REQUEST->Restaurants^street_address~<|enduseraction|><|beginaction|> INFORM->Restaurants^phone_number~415-379-4505|INFORM->Restaurants^street_address~534 Irving Street<|endaction|><|beginresponse|> Great, the phone number is 415-379-4505 and the address is 534 Irving Street<|endresponse|><|endtarget|>' \n", + "\n", + " target='<|begintarget|><|begindsts|><|begindst|><|beginintent|>FindRestaurants<|endintent|><|beginbelief|>Restaurants^city->SF~San Francisco|Restaurants^cuisine->Thai|Restaurants^restaurant_name->Baan Thai House & Wine Bar<|endbelief|><|enddst|><|enddsts|><|beginuseraction|>SELECT->Restaurants^~<|enduseraction|><|beginaction|>OFFER_INTENT->Restaurants^intent~ReserveRestaurant<|endaction|><|beginresponse|>Want me to book a table?<|endresponse|><|endtarget|>'\n" + ] + } + ], + "source": [ + "import random\n", + "\n", + "i = random.randint(0, len(dataset[\"test\"]) - 1)  # randint is inclusive on both ends\n", + "context = dataset[\"test\"][i][\"context\"]\n", + "\n", + "batch = tokenizer(context, return_tensors=\"pt\")\n", + "batch = {k: v.to(\"cuda\") for k, v in batch.items()}\n", + "model.eval()\n", + "output_tokens = model.generate(\n", + " **batch,\n", + " max_new_tokens=256,\n", + " do_sample=True,\n", + " temperature=0.2,\n", + " top_p=0.95,\n", + " top_k=50,\n", + " eos_token_id=tokenizer.eos_token_id,\n", + " pad_token_id=tokenizer.pad_token_id,\n", + ")\n", + "target_predicted = tokenizer.decode(output_tokens[0], skip_special_tokens=False).split(\"<|endcontext|>\")[1]\n", + "target = dataset[\"test\"][i][\"target\"]\n", + "print(f\"{context=} \\n\\n {target_predicted=} \\n\\n {target=}\")" + ] + }, + { + "cell_type": "markdown", + "id": "f940a660-2f7c-4a3a-b412-3f037aedb890", + "metadata": {}, + "source": [ + "# Save the Adapter model " + ] + }, + { + "cell_type": "markdown", + "id": "7ebe05e9-9b93-42f6-bba8-46b8cc3d100f", + "metadata": {}, + "source": [ + "When LoRA layers are applied to embedding layers, the corresponding base model embedding layers are also saved. " + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "3d7459ba-caa8-4f10-aa70-89be4541cbdf", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/raid/sourab/peft/src/peft/utils/save_and_load.py:128: UserWarning: Setting `is_embedding_layer_resized` to `True` as embedding layers found in `target_modules`\n", + " warnings.warn(\"Setting `is_embedding_layer_resized` to `True` as embedding layers found in `target_modules`\")\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "8d23186832014f209939ab83e79da011", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Upload 3 LFS files: 0%| | 0/3 [00:00<|user|>Can you find me a place to eat please?<|system|>Where at?
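The save cell itself is truncated above (only its warning and the LFS upload progress survive), so the following is a guess at its shape rather than the notebook's verbatim code. With `embed_tokens`/`lm_head` in `target_modules`, `save_pretrained`'s default `save_embedding_layers="auto"` also writes the resized embedding weights alongside the adapter:

```python
# Presumed shape of the truncated save cell: push the trained adapter (plus
# the embedding layers, auto-detected from target_modules) to the Hub.
trainer.push_to_hub()
trainer.model.save_pretrained(training_args.output_dir)
```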
And what kind of cuisine are you craving?<|user|>Somewhere in SF, and I am really craving Thai food at the moment!<|system|>I found a bunch of restaurants, there's actually 10 that you might like in San Francisco, one of them being Baan Thai House & Wine Bar<|user|>How can I reach them? And what's their address?<|system|>You can reach them by phone at 415-379-4505 and visit them at 534 Irving Street<|beginlastuserutterance|>Great, that restaurant sounds good<|endlastuserutterance|><|endcontext|>\" \n", + "\n", + " target_predicted='<|begintarget|><|begindsts|><|begindst|><|beginintent|> FindRestaurant<|endintent|><|beginbelief|> Restaurants^city->SF~San Francisco|Restaurants^cuisine->Thai|Restaurants^restaurant_name->Baan Thai House & Wine Bar<|endbelief|><|enddst|><|enddsts|><|beginuseraction|> REQUEST->Restaurants^phone_number~|REQUEST->Restaurants^street_address~<|enduseraction|><|beginaction|> INFORM->Restaurants^phone_number~415-379-4505|INFORM->Restaurants^street_address~534 Irving Street<|endaction|><|beginresponse|> The phone number is 415-379-4505 and the address is 534 Irving Street<|endresponse|><|endtarget|>' \n", + "\n", + " target='<|begintarget|><|begindsts|><|begindst|><|beginintent|>FindRestaurants<|endintent|><|beginbelief|>Restaurants^city->SF~San Francisco|Restaurants^cuisine->Thai|Restaurants^restaurant_name->Baan Thai House & Wine Bar<|endbelief|><|enddst|><|enddsts|><|beginuseraction|>SELECT->Restaurants^~<|enduseraction|><|beginaction|>OFFER_INTENT->Restaurants^intent~ReserveRestaurant<|endaction|><|beginresponse|>Want me to book a table?<|endresponse|><|endtarget|>'\n" + ] + } + ], + "source": [ + "from peft import PeftModel\n", + "\n", + "inference_model = AutoModelForCausalLM.from_pretrained(\n", + " model_name,\n", + " low_cpu_mem_usage=True,\n", + " # use_flash_attention_2=True,\n", + ")\n", + "inference_model.resize_token_embeddings(len(tokenizer))\n", + "\n", + "inference_model = PeftModel.from_pretrained(inference_model, \"smangrul/mistral_lora_clm_with_added_tokens\")\n", + "inference_model.to(\"cuda\")\n", + "inference_model.eval()\n", + "\n", + "output_tokens = inference_model.generate(\n", + " **batch,\n", + " max_new_tokens=256,\n", + " do_sample=True,\n", + " temperature=0.2,\n", + " top_p=0.95,\n", + " top_k=50,\n", + " eos_token_id=tokenizer.eos_token_id,\n", + " pad_token_id=tokenizer.pad_token_id,\n", + ")\n", + "\n", + "target_predicted = tokenizer.decode(output_tokens[0], skip_special_tokens=False).split(\"<|endcontext|>\")[1]\n", + "print(f\"{context=} \\n\\n {target_predicted=} \\n\\n {target=}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fd57f6e8-761f-4e0b-941c-f6973e13b186", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.13" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/src/peft/peft_model.py b/src/peft/peft_model.py index c5c7825baa..24ef48c22e 100644 --- a/src/peft/peft_model.py +++ b/src/peft/peft_model.py @@ -159,6 +159,8 @@ def save_pretrained( save_directory: str, safe_serialization: bool = True, selected_adapters: Optional[List[str]] = None, + save_embedding_layers: Union[str, bool] = "auto", + 
is_main_process: bool = True, **kwargs: Any, ): r""" @@ -172,6 +174,14 @@ def save_pretrained( exist). safe_serialization (`bool`, *optional*): Whether to save the adapter files in safetensors format. + selected_adapters (`list(str)`, *optional*): + A list of adapters to be saved. If `None`, will default to all adapters. + save_embedding_layers (`Union[bool, str]`, *optional*, defaults to `auto`): + If `True`, save the embedding layers in addition to adapter weights. If `auto`, checks for the common + embedding layers `peft.utils.other.EMBEDDING_LAYER_NAMES` in the config's `target_modules`, when available, + and sets the flag accordingly. This only works for 🤗 transformers models. + is_main_process (`bool`, *optional*): + Whether the process calling this is the main process or not. Will default to `True`. kwargs (additional keyword arguments, *optional*): Additional keyword arguments passed along to the `push_to_hub` method. """ @@ -190,19 +200,23 @@ def save_pretrained( f" {list(self.peft_config.keys())} - got {selected_adapters}." ) - os.makedirs(save_directory, exist_ok=True) - self.create_or_update_model_card(save_directory) + if is_main_process: + os.makedirs(save_directory, exist_ok=True) + self.create_or_update_model_card(save_directory) for adapter_name in selected_adapters: peft_config = self.peft_config[adapter_name] # save only the trainable weights output_state_dict = get_peft_model_state_dict( - self, state_dict=kwargs.get("state_dict", None), adapter_name=adapter_name + self, + state_dict=kwargs.get("state_dict", None), + adapter_name=adapter_name, + save_embedding_layers=save_embedding_layers, ) output_dir = os.path.join(save_directory, adapter_name) if adapter_name != "default" else save_directory os.makedirs(output_dir, exist_ok=True) - if safe_serialization: + if is_main_process and safe_serialization: # Section copied from: https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_utils.py#L2111-L2134 # Safetensors does not allow tensor aliasing.
# We're going to remove aliases before saving @@ -230,7 +244,7 @@ def save_pretrained( os.path.join(output_dir, SAFETENSORS_WEIGHTS_NAME), metadata={"format": "pt"}, ) - else: + elif is_main_process: torch.save(output_state_dict, os.path.join(output_dir, WEIGHTS_NAME)) # save the config and change the inference mode to `True` @@ -257,7 +271,8 @@ def save_pretrained( else: auto_mapping_dict = None - peft_config.save_pretrained(output_dir, auto_mapping_dict=auto_mapping_dict) + if is_main_process: + peft_config.save_pretrained(output_dir, auto_mapping_dict=auto_mapping_dict) peft_config.inference_mode = inference_mode @classmethod @@ -721,24 +736,27 @@ def create_or_update_model_card(self, output_dir: str): if hasattr(self.config, "quantization_config"): quantization_config = self.config.quantization_config.to_dict() training_config_text = "" + quantization_prefix = "The following `bitsandbytes` quantization config was used during training:" # Adds quantization information if it was used if quantization_config is not None: - training_config_text += "\nThe following `bitsandbytes` quantization config was used during training:\n" + training_config_text += f"\n{quantization_prefix}\n" training_config_text += "\n".join([f"- {name}: {value}" for name, value in quantization_config.items()]) training_config_text += "\n" - training_procedure_heading = "## Training procedure\n" - if training_procedure_heading in lines: - lines.insert(lines.index(training_procedure_heading) + 2, training_config_text) - else: - lines.append(f"{training_procedure_heading}\n{training_config_text}") + training_procedure_heading = "## Training procedure" + if quantization_prefix not in lines and bool(training_config_text): + if training_procedure_heading in lines: + lines.insert(lines.index(training_procedure_heading) + 2, training_config_text) + else: + lines.append(f"{training_procedure_heading}\n{training_config_text}") # Adds peft version - framework_block_heading = "### Framework versions\n" - if framework_block_heading in lines: - lines.insert(lines.index(framework_block_heading) + 2, f"- PEFT {__version__}\n") - else: - lines.append(f"{framework_block_heading}\n\n- PEFT {__version__}\n") + framework_block_heading = "### Framework versions" + if f"- PEFT {__version__}" not in lines: + if framework_block_heading in lines: + lines.insert(lines.index(framework_block_heading) + 2, f"- PEFT {__version__}") + else: + lines.append(f"{framework_block_heading}\n\n- PEFT {__version__}") card.text = "\n".join(lines) card.save(filename) diff --git a/src/peft/utils/other.py b/src/peft/utils/other.py index e811bee5ba..1c34701739 100644 --- a/src/peft/utils/other.py +++ b/src/peft/utils/other.py @@ -583,3 +583,4 @@ def id_tensor_storage(tensor: torch.Tensor) -> Tuple[torch.device, int, int]: WEIGHTS_NAME = "adapter_model.bin" SAFETENSORS_WEIGHTS_NAME = "adapter_model.safetensors" CONFIG_NAME = "adapter_config.json" +EMBEDDING_LAYER_NAMES = ["embed_tokens", "lm_head"] diff --git a/src/peft/utils/save_and_load.py b/src/peft/utils/save_and_load.py index 07e653bef1..97bde0d6fe 100644 --- a/src/peft/utils/save_and_load.py +++ b/src/peft/utils/save_and_load.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import os +import warnings from typing import Optional import torch @@ -20,11 +21,26 @@ from huggingface_hub.utils import EntryNotFoundError from safetensors.torch import load_file as safe_load_file -from .other import SAFETENSORS_WEIGHTS_NAME, WEIGHTS_NAME, infer_device +from .other import EMBEDDING_LAYER_NAMES, SAFETENSORS_WEIGHTS_NAME, WEIGHTS_NAME, infer_device from .peft_types import PeftType -def get_peft_model_state_dict(model, state_dict=None, adapter_name="default", unwrap_compiled=False): +def has_valid_embedding_base_layer(layer): + """Check if the layer has an embedding base layer""" + return hasattr(layer, "base_layer") and isinstance(layer.base_layer, (torch.nn.Linear, torch.nn.Embedding)) + + +def get_embedding_layer_name(model, layer, is_prompt_learning): + """Get the name of the embedding module for a given layer.""" + for name, module in model.named_modules(): + if (is_prompt_learning and module == layer) or module == layer.base_layer: + return name + return None + + +def get_peft_model_state_dict( + model, state_dict=None, adapter_name="default", unwrap_compiled=False, save_embedding_layers="auto" +): """ Get the state dict of the Peft model. @@ -37,6 +53,10 @@ def get_peft_model_state_dict(model, state_dict=None, adapter_name="default", un The name of the adapter whose state dict should be returned. unwrap_compiled (`bool`, *optional*, defaults to `False`): Whether to unwrap the model if torch.compile was used. + save_embedding_layers (`Union[bool, str]`, *optional*, defaults to `auto`): + If `True`, save the embedding layers in addition to adapter weights. If `auto`, checks for the common + embedding layers `peft.utils.other.EMBEDDING_LAYER_NAMES` in the config's `target_modules`, when available, + and sets the flag accordingly. This only works for 🤗 transformers models.
""" if unwrap_compiled: model = getattr(model, "_orig_mod", model) @@ -100,6 +120,27 @@ def get_peft_model_state_dict(model, state_dict=None, adapter_name="default", un if any(f"{module_name}.modules_to_save.{adapter_name}" in key for module_name in model.modules_to_save): to_return[key.replace("modules_to_save.", "")] = value + # check the common embedding layers in `target_modules` to reset `save_embedding_layers` if necessary + if ( + save_embedding_layers == "auto" + and hasattr(config, "target_modules") + and any(k in config.target_modules for k in EMBEDDING_LAYER_NAMES) + ): + warnings.warn("Setting `save_embedding_layers` to `True` as embedding layers found in `target_modules`.") + save_embedding_layers = True + elif save_embedding_layers == "auto": + save_embedding_layers = False + + if save_embedding_layers and hasattr(model, "get_input_embeddings"): + for layer in [model.get_input_embeddings(), model.get_output_embeddings()]: + if config.is_prompt_learning or has_valid_embedding_base_layer(layer): + # support from version >= 0.6.2 + embedding_module_name = get_embedding_layer_name(model, layer, config.is_prompt_learning) + if embedding_module_name: + to_return.update({k: v for k, v in state_dict.items() if embedding_module_name in k}) + elif save_embedding_layers: + warnings.warn("Could not identify embedding layer(s) because the model is not a 🤗 transformers model.") + to_return = {k.replace(f".{adapter_name}", ""): v for k, v in to_return.items()} return to_return diff --git a/tests/test_custom_models.py b/tests/test_custom_models.py index 347df218b2..b298388a84 100644 --- a/tests/test_custom_models.py +++ b/tests/test_custom_models.py @@ -333,6 +333,33 @@ def forward(self, X): return X +class ModelEmbWithEmbeddingUtils(nn.Module): + # Adds `get_input_embeddings` and `get_output_embeddings` methods to mimic 🤗 transformers models + def __init__(self): + super().__init__() + self.embed_tokens = nn.Embedding(100, 5) + self.conv1d = Conv1D(1, 5) + self.relu = nn.ReLU() + self.flat = nn.Flatten() + self.lin0 = nn.Linear(10, 2) + self.sm = nn.LogSoftmax(dim=-1) + + def forward(self, X): + X = self.embed_tokens(X) + X = self.conv1d(X) + X = self.relu(X) + X = self.flat(X) + X = self.lin0(X) + X = self.sm(X) + return X + + def get_input_embeddings(self): + return self.embed_tokens + + def get_output_embeddings(self): + return None + + class ModelConv2D(nn.Module): def __init__(self): super().__init__() @@ -750,6 +777,55 @@ def test_non_existing_model_card(self): # rough check that the model card is pre-filled self.assertGreater(len(model_card), 1000) + @parameterized.expand(["auto", True, False]) + def test_targeting_lora_to_embedding_layer(self, save_embedding_layers): + model = ModelEmbWithEmbeddingUtils() + config = LoraConfig(target_modules=["embed_tokens", "lin0"], init_lora_weights=False) + model = get_peft_model(model, config) + + with tempfile.TemporaryDirectory() as tmp_dirname: + if save_embedding_layers == "auto": + # assert warning + msg_start = "Setting `save_embedding_layers` to `True` as embedding layers found in `target_modules`." 
+ with self.assertWarns(UserWarning, msg=msg_start): + model.save_pretrained(tmp_dirname, save_embedding_layers=save_embedding_layers) + else: + model.save_pretrained(tmp_dirname, save_embedding_layers=save_embedding_layers) + from safetensors.torch import load_file as safe_load_file + + state_dict = safe_load_file(os.path.join(tmp_dirname, "adapter_model.safetensors")) + if save_embedding_layers in ["auto", True]: + self.assertTrue("base_model.model.embed_tokens.base_layer.weight" in state_dict) + self.assertTrue( + torch.allclose( + model.base_model.model.embed_tokens.base_layer.weight, + state_dict["base_model.model.embed_tokens.base_layer.weight"], + ) + ) + else: + self.assertFalse("base_model.model.embed_tokens.base_layer.weight" in state_dict) + del state_dict + + @parameterized.expand(["auto", True, False]) + def test_targeting_lora_to_embedding_layer_non_transformers(self, save_embedding_layers): + model = ModelEmbConv1D() + config = LoraConfig(target_modules=["emb", "lin0"], init_lora_weights=False) + model = get_peft_model(model, config) + + with tempfile.TemporaryDirectory() as tmp_dirname: + if save_embedding_layers is True: + # assert warning + msg_start = "Could not identify embedding layer(s) because the model is not a 🤗 transformers model." + with self.assertWarns(UserWarning, msg=msg_start): + model.save_pretrained(tmp_dirname, save_embedding_layers=save_embedding_layers) + else: + model.save_pretrained(tmp_dirname, save_embedding_layers=save_embedding_layers) + from safetensors.torch import load_file as safe_load_file + + state_dict = safe_load_file(os.path.join(tmp_dirname, "adapter_model.safetensors")) + self.assertFalse("base_model.model.emb.base_layer.weight" in state_dict) + del state_dict + @parameterized.expand( [ LoraConfig(target_modules=["lin0"], init_lora_weights=False),
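To close the loop on the API surface this PR adds, a short usage sketch of the two new `save_pretrained` arguments; `model` is assumed to be any trained PEFT model and the output path is illustrative:

```python
import torch.distributed as dist

# In a multi-process run, let only rank 0 write files to avoid clobbering;
# single-process runs fall back to is_main=True.
is_main = not dist.is_initialized() or dist.get_rank() == 0
model.save_pretrained(
    "adapter_out",               # illustrative output directory
    save_embedding_layers=True,  # force-save embeddings even when "auto" would not detect them
    is_main_process=is_main,
)
```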