diff --git a/demo2.ipynb b/demo2.ipynb
new file mode 100644
index 0000000000..009164652d
--- /dev/null
+++ b/demo2.ipynb
@@ -0,0 +1,2258 @@
+{
+ "cells": [
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# How to Fine-Tune LLMs with TRL\n",
+ "\n",
+ "_Authored by Philipp Schmid_\n",
+ "Original post: [How to Fine-Tune LLMs in 2024 with Hugging Face](https://www.philschmid.de/fine-tune-llms-in-2024-with-trl)\n",
+ "\n",
+ "_Adapted by Quentin Gallouédec_\n",
+ "\n",
+ "Large Language Models, or LLMs, have seen a lot of progress in the last year. We went from no ChatGPT competitor to a whole zoo of LLMs, including Meta AI's [Llama 3](https://huggingface.co/blog/llama31), Mistral's [Mistral](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2) & [Mixtral](https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1) models, TII's [Falcon](https://huggingface.co/tiiuae/falcon-40b), and many more. \n",
+ "Those LLMs can be used for a variety of tasks, including chatbots, question answering, and summarization, without any additional training. However, if you want to customize a model for your application, you may need to fine-tune it on your data to achieve higher-quality results than prompting alone, or to save costs by training a smaller, more efficient model.\n",
+ "\n",
+ "This blog post walks you through how to fine-tune open LLMs using Hugging Face [TRL](https://huggingface.co/docs/trl/index), [Transformers](https://huggingface.co/docs/transformers/index) & [Datasets](https://huggingface.co/docs/datasets/index). In the blog, we are going to:\n",
+ "\n",
+ "1. Define our use case \n",
+ "2. Set up the development environment\n",
+ "3. Create and prepare the dataset\n",
+ "4. Fine-tune the LLM using `trl` and the `SFTTrainer` \n",
+ "5. Test and evaluate the LLM\n",
+ "6. Deploy the LLM for production\n",
+ "\n",
+ "_Note: This blog was created to run on consumer-size GPUs (24GB), e.g. NVIDIA A10G or RTX 4090/3090, but can be easily adapted to run on bigger GPUs._\n",
+ "\n",
+ "\n",
+ "## 1. Define our use case \n",
+ "\n",
+ "When fine-tuning LLMs, it is important to know your use case and the task you want to solve. This will help you choose the right model, or help you create a dataset to fine-tune your model. If you haven't defined your use case yet, you might want to go back to the drawing board.\n",
+ "Note that not all use cases require fine-tuning; it is always recommended to evaluate and try out already fine-tuned models or API-based models before fine-tuning your own model. \n",
+ "\n",
+ "As an example, we are going to use the following use case:\n",
+ "\n",
+ "> We want to fine-tune a model that can generate SQL queries from natural language instructions, which can then be integrated into our BI tool. The goal is to reduce the time it takes to create a SQL query and make it easier for non-technical users to create SQL queries.\n",
+ "\n",
+ "Text to SQL can be a good use case for fine-tuning LLMs, as it is a complex task that requires a lot of (internal) knowledge about the data and the SQL language. \n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 2. Set up the development environment\n",
+ "\n",
+ "Our first step is to install the Hugging Face libraries and PyTorch, including trl, transformers and datasets. If you haven't heard of trl yet, don't worry. It is a new library on top of transformers and datasets, which makes it easier to fine-tune, align, and apply RLHF to open LLMs.
\n" + ] + }, + { + "cell_type": "code", + "execution_count": 53, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n", + "To disable this warning, you can either:\n", + "\t- Avoid using `tokenizers` before the fork if possible\n", + "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Requirement already satisfied: trl in ./nb/lib/python3.11/site-packages (0.12.1)\n", + "Requirement already satisfied: hf_transfer in ./nb/lib/python3.11/site-packages (0.1.8)\n", + "Requirement already satisfied: accelerate>=0.34.0 in ./nb/lib/python3.11/site-packages (from trl) (1.1.1)\n", + "Requirement already satisfied: datasets>=2.21.0 in ./nb/lib/python3.11/site-packages (from trl) (3.1.0)\n", + "Requirement already satisfied: rich in ./nb/lib/python3.11/site-packages (from trl) (13.9.4)\n", + "Requirement already satisfied: transformers>=4.46.0 in ./nb/lib/python3.11/site-packages (from trl) (4.46.3)\n", + "Requirement already satisfied: huggingface-hub>=0.21.0 in ./nb/lib/python3.11/site-packages (from accelerate>=0.34.0->trl) (0.26.2)\n", + "Requirement already satisfied: numpy<3.0.0,>=1.17 in ./nb/lib/python3.11/site-packages (from accelerate>=0.34.0->trl) (2.1.3)\n", + "Requirement already satisfied: packaging>=20.0 in ./nb/lib/python3.11/site-packages (from accelerate>=0.34.0->trl) (24.2)\n", + "Requirement already satisfied: psutil in ./nb/lib/python3.11/site-packages (from accelerate>=0.34.0->trl) (6.1.0)\n", + "Requirement already satisfied: pyyaml in ./nb/lib/python3.11/site-packages (from accelerate>=0.34.0->trl) (6.0.2)\n", + "Requirement already satisfied: safetensors>=0.4.3 in ./nb/lib/python3.11/site-packages (from accelerate>=0.34.0->trl) (0.4.5)\n", + "Requirement already satisfied: torch>=1.10.0 in ./nb/lib/python3.11/site-packages (from accelerate>=0.34.0->trl) (2.5.1)\n", + "Requirement already satisfied: filelock in ./nb/lib/python3.11/site-packages (from datasets>=2.21.0->trl) (3.16.1)\n", + "Requirement already satisfied: pyarrow>=15.0.0 in ./nb/lib/python3.11/site-packages (from datasets>=2.21.0->trl) (18.1.0)\n", + "Requirement already satisfied: dill<0.3.9,>=0.3.0 in ./nb/lib/python3.11/site-packages (from datasets>=2.21.0->trl) (0.3.8)\n", + "Requirement already satisfied: pandas in ./nb/lib/python3.11/site-packages (from datasets>=2.21.0->trl) (2.2.3)\n", + "Requirement already satisfied: requests>=2.32.2 in ./nb/lib/python3.11/site-packages (from datasets>=2.21.0->trl) (2.32.3)\n", + "Requirement already satisfied: tqdm>=4.66.3 in ./nb/lib/python3.11/site-packages (from datasets>=2.21.0->trl) (4.67.1)\n", + "Requirement already satisfied: xxhash in ./nb/lib/python3.11/site-packages (from datasets>=2.21.0->trl) (3.5.0)\n", + "Requirement already satisfied: multiprocess<0.70.17 in ./nb/lib/python3.11/site-packages (from datasets>=2.21.0->trl) (0.70.16)\n", + "Requirement already satisfied: fsspec<=2024.9.0,>=2023.1.0 in ./nb/lib/python3.11/site-packages (from fsspec[http]<=2024.9.0,>=2023.1.0->datasets>=2.21.0->trl) (2024.9.0)\n", + "Requirement already satisfied: aiohttp in ./nb/lib/python3.11/site-packages (from datasets>=2.21.0->trl) (3.11.7)\n", + "Requirement already satisfied: regex!=2019.12.17 in ./nb/lib/python3.11/site-packages (from transformers>=4.46.0->trl) 
(2024.11.6)\n", + "Requirement already satisfied: tokenizers<0.21,>=0.20 in ./nb/lib/python3.11/site-packages (from transformers>=4.46.0->trl) (0.20.3)\n", + "Requirement already satisfied: markdown-it-py>=2.2.0 in ./nb/lib/python3.11/site-packages (from rich->trl) (3.0.0)\n", + "Requirement already satisfied: pygments<3.0.0,>=2.13.0 in ./nb/lib/python3.11/site-packages (from rich->trl) (2.18.0)\n", + "Requirement already satisfied: aiohappyeyeballs>=2.3.0 in ./nb/lib/python3.11/site-packages (from aiohttp->datasets>=2.21.0->trl) (2.4.3)\n", + "Requirement already satisfied: aiosignal>=1.1.2 in ./nb/lib/python3.11/site-packages (from aiohttp->datasets>=2.21.0->trl) (1.3.1)\n", + "Requirement already satisfied: attrs>=17.3.0 in ./nb/lib/python3.11/site-packages (from aiohttp->datasets>=2.21.0->trl) (24.2.0)\n", + "Requirement already satisfied: frozenlist>=1.1.1 in ./nb/lib/python3.11/site-packages (from aiohttp->datasets>=2.21.0->trl) (1.5.0)\n", + "Requirement already satisfied: multidict<7.0,>=4.5 in ./nb/lib/python3.11/site-packages (from aiohttp->datasets>=2.21.0->trl) (6.1.0)\n", + "Requirement already satisfied: propcache>=0.2.0 in ./nb/lib/python3.11/site-packages (from aiohttp->datasets>=2.21.0->trl) (0.2.0)\n", + "Requirement already satisfied: yarl<2.0,>=1.17.0 in ./nb/lib/python3.11/site-packages (from aiohttp->datasets>=2.21.0->trl) (1.18.0)\n", + "Requirement already satisfied: typing-extensions>=3.7.4.3 in ./nb/lib/python3.11/site-packages (from huggingface-hub>=0.21.0->accelerate>=0.34.0->trl) (4.12.2)\n", + "Requirement already satisfied: mdurl~=0.1 in ./nb/lib/python3.11/site-packages (from markdown-it-py>=2.2.0->rich->trl) (0.1.2)\n", + "Requirement already satisfied: charset-normalizer<4,>=2 in ./nb/lib/python3.11/site-packages (from requests>=2.32.2->datasets>=2.21.0->trl) (3.4.0)\n", + "Requirement already satisfied: idna<4,>=2.5 in ./nb/lib/python3.11/site-packages (from requests>=2.32.2->datasets>=2.21.0->trl) (3.10)\n", + "Requirement already satisfied: urllib3<3,>=1.21.1 in ./nb/lib/python3.11/site-packages (from requests>=2.32.2->datasets>=2.21.0->trl) (2.2.3)\n", + "Requirement already satisfied: certifi>=2017.4.17 in ./nb/lib/python3.11/site-packages (from requests>=2.32.2->datasets>=2.21.0->trl) (2024.8.30)\n", + "Requirement already satisfied: networkx in ./nb/lib/python3.11/site-packages (from torch>=1.10.0->accelerate>=0.34.0->trl) (3.4.2)\n", + "Requirement already satisfied: jinja2 in ./nb/lib/python3.11/site-packages (from torch>=1.10.0->accelerate>=0.34.0->trl) (3.1.4)\n", + "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.4.127 in ./nb/lib/python3.11/site-packages (from torch>=1.10.0->accelerate>=0.34.0->trl) (12.4.127)\n", + "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.4.127 in ./nb/lib/python3.11/site-packages (from torch>=1.10.0->accelerate>=0.34.0->trl) (12.4.127)\n", + "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.4.127 in ./nb/lib/python3.11/site-packages (from torch>=1.10.0->accelerate>=0.34.0->trl) (12.4.127)\n", + "Requirement already satisfied: nvidia-cudnn-cu12==9.1.0.70 in ./nb/lib/python3.11/site-packages (from torch>=1.10.0->accelerate>=0.34.0->trl) (9.1.0.70)\n", + "Requirement already satisfied: nvidia-cublas-cu12==12.4.5.8 in ./nb/lib/python3.11/site-packages (from torch>=1.10.0->accelerate>=0.34.0->trl) (12.4.5.8)\n", + "Requirement already satisfied: nvidia-cufft-cu12==11.2.1.3 in ./nb/lib/python3.11/site-packages (from torch>=1.10.0->accelerate>=0.34.0->trl) (11.2.1.3)\n", + "Requirement already 
satisfied: nvidia-curand-cu12==10.3.5.147 in ./nb/lib/python3.11/site-packages (from torch>=1.10.0->accelerate>=0.34.0->trl) (10.3.5.147)\n",
+ "Requirement already satisfied: nvidia-cusolver-cu12==11.6.1.9 in ./nb/lib/python3.11/site-packages (from torch>=1.10.0->accelerate>=0.34.0->trl) (11.6.1.9)\n",
+ "Requirement already satisfied: nvidia-cusparse-cu12==12.3.1.170 in ./nb/lib/python3.11/site-packages (from torch>=1.10.0->accelerate>=0.34.0->trl) (12.3.1.170)\n",
+ "Requirement already satisfied: nvidia-nccl-cu12==2.21.5 in ./nb/lib/python3.11/site-packages (from torch>=1.10.0->accelerate>=0.34.0->trl) (2.21.5)\n",
+ "Requirement already satisfied: nvidia-nvtx-cu12==12.4.127 in ./nb/lib/python3.11/site-packages (from torch>=1.10.0->accelerate>=0.34.0->trl) (12.4.127)\n",
+ "Requirement already satisfied: nvidia-nvjitlink-cu12==12.4.127 in ./nb/lib/python3.11/site-packages (from torch>=1.10.0->accelerate>=0.34.0->trl) (12.4.127)\n",
+ "Requirement already satisfied: triton==3.1.0 in ./nb/lib/python3.11/site-packages (from torch>=1.10.0->accelerate>=0.34.0->trl) (3.1.0)\n",
+ "Requirement already satisfied: sympy==1.13.1 in ./nb/lib/python3.11/site-packages (from torch>=1.10.0->accelerate>=0.34.0->trl) (1.13.1)\n",
+ "Requirement already satisfied: mpmath<1.4,>=1.1.0 in ./nb/lib/python3.11/site-packages (from sympy==1.13.1->torch>=1.10.0->accelerate>=0.34.0->trl) (1.3.0)\n",
+ "Requirement already satisfied: python-dateutil>=2.8.2 in ./nb/lib/python3.11/site-packages (from pandas->datasets>=2.21.0->trl) (2.9.0.post0)\n",
+ "Requirement already satisfied: pytz>=2020.1 in ./nb/lib/python3.11/site-packages (from pandas->datasets>=2.21.0->trl) (2024.2)\n",
+ "Requirement already satisfied: tzdata>=2022.7 in ./nb/lib/python3.11/site-packages (from pandas->datasets>=2.21.0->trl) (2024.2)\n",
+ "Requirement already satisfied: six>=1.5 in ./nb/lib/python3.11/site-packages (from python-dateutil>=2.8.2->pandas->datasets>=2.21.0->trl) (1.16.0)\n",
+ "Requirement already satisfied: MarkupSafe>=2.0 in ./nb/lib/python3.11/site-packages (from jinja2->torch>=1.10.0->accelerate>=0.34.0->trl) (3.0.2)\n",
+ "\n",
+ "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m24.0\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m24.3.1\u001b[0m\n",
+ "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n",
+ "Note: you may need to restart the kernel to use updated packages.\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Install TRL\n",
+ "%pip install trl hf_transfer"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "If you are using a GPU with the Ampere architecture (e.g. NVIDIA A10G or RTX 4090/3090) or newer, you can use Flash Attention. Flash Attention is a method that reorders the attention computation and leverages classical techniques (tiling, recomputation) to significantly speed it up and reduce memory usage from quadratic to linear in sequence length. TL;DR: it accelerates training by up to 3x. Learn more at [FlashAttention](https://github.com/Dao-AILab/flash-attention/tree/main).\n",
+ "\n",
+ "_Note: If your machine has less than 96GB of RAM and lots of CPU cores, reduce the number of `MAX_JOBS`. On the `g6.2xlarge` we used `4`._"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 54,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# import torch; assert torch.cuda.get_device_capability()[0] >= 8, 'Hardware not supported for Flash Attention'\n",
+ "# # install flash-attn\n",
+ "# !pip install ninja packaging\n",
+ "# !MAX_JOBS=4 pip install flash-attn --no-build-isolation"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "_Installing flash-attn can take quite a bit of time (10-45 minutes)._"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "We will use the [Hugging Face Hub](https://huggingface.co/models) as a remote model versioning service. This means we will automatically push our model, logs and information to the Hub during training.\n",
+ "\n",
+ "1. **Sign up**: Create an account at [Hugging Face](https://huggingface.co/join) if you don’t already have one. \n",
+ "2. **Generate a token**: Go to [Token Settings](https://huggingface.co/settings/tokens) and create a new token: \n",
+ "   - Select *Fine-grained*. \n",
+ "   - Assign any name. \n",
+ "   - Enable\n",
+ "     - \"Write access to contents/settings of all repos under your personal namespace.\" and\n",
+ "     - \"Read access to contents of all public gated repos you can access.\"\n",
+ "   - Create the token and copy-paste it in the next cell.\n",
+ "\n",
+ "⚠️ **Keep your token secret**: Don't push this notebook with your token in it. ⚠️ "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 55,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from huggingface_hub import login\n",
+ "\n",
+ "login(\n",
+ "    token=\"...\",  # ADD YOUR TOKEN HERE\n",
+ "    add_to_git_credential=True,\n",
+ ")\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 3. Create and prepare the dataset\n",
+ "\n",
+ "Once you have determined that fine-tuning is the right solution, we need to create a dataset to fine-tune our model. The dataset should be a diverse set of demonstrations of the task you want to solve. There are several ways to create such a dataset, including:\n",
+ "* Using existing open-source datasets, e.g., [Spider](https://huggingface.co/datasets/spider)\n",
+ "* Using LLMs to create synthetic datasets, e.g., [Alpaca](https://huggingface.co/datasets/tatsu-lab/alpaca)\n",
+ "* Using humans to create datasets, e.g., [Dolly](https://huggingface.co/datasets/databricks/databricks-dolly-15k)\n",
+ "* Using a combination of the above methods, e.g., [Orca](https://huggingface.co/datasets/Open-Orca/OpenOrca)\n",
+ "\n",
+ "Each of these methods has its own advantages and disadvantages, and the right choice depends on your budget, time, and quality requirements. For example, using an existing dataset is the easiest but might not be tailored to your specific use case, while using humans might be the most accurate but can be time-consuming and expensive. It is also possible to combine several methods to create an instruction dataset, as shown in [Orca: Progressive Learning from Complex Explanation Traces of GPT-4](https://arxiv.org/abs/2306.02707).\n",
+ "\n",
+ "In our example we will use an already existing dataset called [sql-create-context](https://huggingface.co/datasets/b-mc2/sql-create-context), which contains samples of natural language instructions, schema definitions and the corresponding SQL query."
+ ] + }, + { + "cell_type": "code", + "execution_count": 56, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'answer': 'SELECT COUNT(*) FROM head WHERE age > 56', 'question': 'How many heads of the departments are older than 56 ?', 'context': 'CREATE TABLE head (age INTEGER)'}\n" + ] + } + ], + "source": [ + "from datasets import load_dataset\n", + "\n", + "dataset = load_dataset(\"b-mc2/sql-create-context\", split=\"train\")\n", + "print(dataset[0])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "With the latest release of `trl` we now support popular [standard and conversational dataset formats](https://huggingface.co/docs/trl/en/dataset_formats). This means we only need to convert our dataset to one of the supported formats and `trl` will take care of the rest. Those formats include:\n", + "\n", + "* Standard format\n", + "\n", + "```python\n", + "{\"text\": \"The sky is blue.\"}\n", + "```\n", + "\n", + "* Conversational format\n", + "\n", + "```python\n", + "{\"messages\": [{\"role\": \"user\", \"content\": \"What color is the sky?\"},\n", + " {\"role\": \"assistant\", \"content\": \"It is blue.\"}]}\n", + "```" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In our example we are going to load our open-source dataset using the 🤗 Datasets library and then convert it into the the conversational format, where we include the schema definition in the system message for our assistant.\n", + "\n", + "_Note: This step can be different for your use case. For example, if you have already a dataset from, e.g. working with OpenAI, you can skip this step and go directly to the fine-tuning step._" + ] + }, + { + "cell_type": "code", + "execution_count": 57, + "metadata": {}, + "outputs": [], + "source": [ + "from datasets import load_dataset\n", + "\n", + "# Convert dataset to OAI messages\n", + "system_message = \"\"\"You are an text to SQL query translator. Users will ask you questions in English and you will generate a SQL query based on the provided SCHEMA.\n", + "SCHEMA:\n", + "{schema}\"\"\"\n", + "\n", + "\n", + "def create_conversation(example):\n", + " return {\n", + " \"messages\": [\n", + " {\"role\": \"system\", \"content\": system_message.format(schema=example[\"context\"])},\n", + " {\"role\": \"user\", \"content\": example[\"question\"]},\n", + " {\"role\": \"assistant\", \"content\": example[\"answer\"]}],\n", + " }\n", + "\n", + "\n", + "# Load dataset from the hub\n", + "dataset = load_dataset(\"b-mc2/sql-create-context\", split=\"train[:50000]\")\n", + "\n", + "# Convert dataset to conversational format\n", + "dataset = dataset.map(create_conversation, remove_columns=dataset.column_names)\n", + "\n", + "# Split dataset into train and test (95% train, 5% test)\n", + "dataset = dataset.train_test_split(test_size=0.05)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now let's see if we can load it, and how it looks like." + ] + }, + { + "cell_type": "code", + "execution_count": 58, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[{'content': 'You are an text to SQL query translator. 
Users will ask you questions in English and you will generate a SQL query based on the provided SCHEMA.\\nSCHEMA:\\nCREATE TABLE table_name_62 (attendance VARCHAR, result VARCHAR)', 'role': 'system'}, {'content': 'What is Attendance, when Result is \"2-4\"?', 'role': 'user'}, {'content': 'SELECT attendance FROM table_name_62 WHERE result = \"2-4\"', 'role': 'assistant'}]\n" + ] + } + ], + "source": [ + "print(dataset[\"train\"][345][\"messages\"])" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 4. Fine-tune LLM using `trl` and the `SFTTrainer` \n", + "\n", + "We are now ready to fine-tune our model. We will use the [SFTTrainer](https://huggingface.co/docs/trl/sft_trainer) from `trl` to fine-tune our model. The `SFTTrainer` makes it straightfoward to supervise fine-tune open LLMs. The `SFTTrainer` is a subclass of the `Trainer` from the `transformers` library and supports all the same features, including logging, evaluation, and checkpointing, but adds additiional quality of life features, including:\n", + "* Dataset formatting, including standard and conversational format\n", + "* Training on completions only, ignoring prompts\n", + "* Packing datasets for more efficient training\n", + "* PEFT (parameter-efficient fine-tuning) support including Q-LoRA\n", + "* Preparing the model and tokenizer for conversational fine-tuning (e.g. adding special tokens)\n", + "\n", + "We will use the dataset formatting, packing and PEFT features in our example. As peft method we will use [QLoRA](https://huggingface.co/paper/2305.14314) a technique to reduce the memory footprint of large language models during finetuning, without sacrificing performance by using quantization. If you want to learn more about QLoRA and how it works, check out [Making LLMs even more accessible with bitsandbytes, 4-bit quantization and QLoRA](https://huggingface.co/blog/4bit-transformers-bitsandbytes) blog post.\n", + "\n", + "Now, lets get started! 🚀" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "First, we will load our LLM. For our use case we are going to use [Llama 3.1 8B](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B). \n", + "But we can easily swap out the model for another model, e.g. [Mistral](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2) or [Mixtral](https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1) models, TII [Falcon](https://huggingface.co/tiiuae/falcon-40b), or any other LLMs by changing our `model_id` variable. We will use `bitsandbytes` to quantize our model to 4-bit.\n", + "\n", + "_Note: Be aware the bigger the model the more memory it will require. In our example we will use the 8B version, which can be tuned on 24GB GPUs._" + ] + }, + { + "cell_type": "code", + "execution_count": 59, + "metadata": {}, + "outputs": [], + "source": [ + "from transformers import AutoModelForCausalLM#, BitsAndBytesConfig\n", + "\n", + "# Hugging Face model id\n", + "model_id = \"Qwen/Qwen2.5-0.5B\" # or `mistralai/Mistral-7B-v0.1`\n", + "\n", + "# Load model and tokenizer\n", + "model = AutoModelForCausalLM.from_pretrained(model_id)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Correctly, preparing the LLM and Tokenizer for training chat/conversational models is crucial. We need to add new special tokens to the tokenizer and model and teach to understand the different roles in a conversation. 
In `trl` we have a convinient method called [`setup_chat_format`](https://huggingface.co/docs/trl/main/en/sft_trainer#add-special-tokens-for-chat-format), which:\n", + "* Adds special tokens to the tokenizer, e.g. `<|im_start|>` and `<|im_end|>`, to indicate the start and end of a conversation.\n", + "* Resizes the model’s embedding layer to accommodate the new tokens.\n", + "* Sets the `chat_template` of the tokenizer, which is used to format the input data into a chat-like format. The default is `chatml` from OpenAI." + ] + }, + { + "cell_type": "code", + "execution_count": 60, + "metadata": {}, + "outputs": [], + "source": [ + "from transformers import AutoTokenizer\n", + "from trl import setup_chat_format\n", + "\n", + "# Load tokenizer\n", + "tokenizer = AutoTokenizer.from_pretrained(model_id)\n", + "# tokenizer.padding_side = 'right' # to prevent warnings\n", + "\n", + "# Set chat template to OAI chatml\n", + "# model, tokenizer = setup_chat_format(model, tokenizer)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now that our tokenzer have a proper chat template, it can convert our data into a formatted conversation, based on its chat template. Let's see how it looks like." + ] + }, + { + "cell_type": "code", + "execution_count": 61, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "<|im_start|>system\n", + "You are an text to SQL query translator. Users will ask you questions in English and you will generate a SQL query based on the provided SCHEMA.\n", + "SCHEMA:\n", + "CREATE TABLE table_name_62 (attendance VARCHAR, result VARCHAR)<|im_end|>\n", + "<|im_start|>user\n", + "What is Attendance, when Result is \"2-4\"?<|im_end|>\n", + "<|im_start|>assistant\n", + "SELECT attendance FROM table_name_62 WHERE result = \"2-4\"<|im_end|>\n", + "\n" + ] + } + ], + "source": [ + "from trl import apply_chat_template\n", + "\n", + "example = dataset[\"train\"][345]\n", + "formatted_example = apply_chat_template(example, tokenizer)\n", + "print(formatted_example[\"text\"])" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The `SFTTrainer`  supports a native integration with `peft`, which makes it super easy to efficiently tune LLMs using, e.g. QLoRA. We only need to create our `LoraConfig` and provide it to the trainer. Our `LoraConfig` parameters are defined based on the [qlora paper](https://arxiv.org/pdf/2305.14314.pdf) and sebastian's [blog post](https://magazine.sebastianraschka.com/p/practical-tips-for-finetuning-llms)." + ] + }, + { + "cell_type": "code", + "execution_count": 62, + "metadata": {}, + "outputs": [], + "source": [ + "# from peft import LoraConfig\n", + "\n", + "# # LoRA config based on QLoRA paper & Sebastian Raschka experiment\n", + "# peft_config = LoraConfig(\n", + "# lora_alpha=128,\n", + "# lora_dropout=0.05,\n", + "# r=256,\n", + "# bias=\"none\",\n", + "# target_modules=\"all-linear\",\n", + "# task_type=\"CAUSAL_LM\", \n", + "# )" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Before we can start our training we need to define the hyperparameters (`TrainingArguments`) we want to use." 
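+ "\n",
+ "For this demo we keep almost every hyperparameter at the trainer defaults and only configure saving and Hub upload (next cell). For a longer run you would typically set more of the underlying `TrainingArguments`; a sketch with illustrative values (not tuned for this task) could look like this:\n",
+ "\n",
+ "```python\n",
+ "from trl import SFTConfig\n",
+ "\n",
+ "training_args = SFTConfig(\n",
+ "    output_dir=\"Qwen2.5-0.5B-SQL\",  # directory to save to and repository id\n",
+ "    num_train_epochs=3,             # number of training epochs\n",
+ "    per_device_train_batch_size=4,  # batch size per device during training\n",
+ "    gradient_accumulation_steps=4,  # accumulate gradients over 4 steps\n",
+ "    learning_rate=2e-4,             # learning rate, based on the QLoRA paper\n",
+ "    max_grad_norm=0.3,              # gradient clipping, based on the QLoRA paper\n",
+ "    warmup_ratio=0.03,              # warmup ratio, based on the QLoRA paper\n",
+ "    lr_scheduler_type=\"constant\",   # keep the learning rate constant\n",
+ "    bf16=True,                      # bfloat16 precision, needs Ampere or newer\n",
+ "    logging_steps=10,               # log the training loss every 10 steps\n",
+ "    save_strategy=\"epoch\",          # save checkpoint every epoch\n",
+ "    push_to_hub=True,               # push model to hub\n",
+ "    report_to=\"none\",               # don't report to wandb\n",
+ ")\n",
+ "```"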
+ ] + }, + { + "cell_type": "code", + "execution_count": 63, + "metadata": {}, + "outputs": [], + "source": [ + "from trl import SFTConfig\n", + "\n", + "training_args = SFTConfig(\n", + " output_dir=\"Qwen2.5-0.5B-SQL\", # directory to save and repository id\n", + " save_strategy=\"epoch\", # save checkpoint every epoch\n", + " push_to_hub=True, # push model to hub\n", + " report_to=\"none\", # don't report to wandb\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We now have every building block we need to create our `SFTTrainer` to start then training our model." + ] + }, + { + "cell_type": "code", + "execution_count": 64, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/fsx/qgallouedec/trl/trl/trainer/sft_trainer.py:248: UserWarning: You didn't pass a `max_seq_length` argument to the SFTTrainer, this will default to 1024\n", + " warnings.warn(\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "82c6a0e2e52d42a793cfd0be4f17441c", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Map: 0%| | 0/47500 [00:00\n", + " \n", + " \n", + " [ 2999/17814 06:17 < 31:06, 7.94 it/s, Epoch 0.50/3]\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", 
+ " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", 
+ " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", 
+ " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
StepTraining Loss
101.578700
201.099800
300.935500
400.816000
500.825600
600.729100
700.714400
800.680500
900.682500
1000.692700
1100.693800
1200.695100
1300.723100
1400.674500
1500.721600
1600.715700
1700.640500
1800.693800
1900.651500
2000.943500
2101.108900
2200.791800
2300.739100
2400.676400
2500.688700
2600.665100
2700.670400
2800.686600
2900.679300
3000.672500
3100.680900
3200.604300
3300.643200
3400.649500
3500.661400
3600.687200
3700.646700
3800.688400
3900.690000
4000.649300
4100.628200
4200.642100
4300.625900
4400.678500
4500.620500
4600.632800
4700.611900
4800.625800
4900.632900
5000.607400
5100.638600
5200.639700
5300.655200
5400.611800
5500.617600
5600.589500
5700.631400
5800.649300
5900.650300
6000.619900
6100.620300
6200.638200
6300.653700
6400.669100
6500.649000
6600.667900
6700.633000
6800.635800
6900.645200
7000.641000
7100.613200
7200.644100
7300.614900
7400.604600
7500.613800
7600.583300
7700.648700
7800.594600
7900.584100
8000.630900
8100.580700
8200.628100
8300.617400
8400.632200
8500.661200
8600.636200
8700.617100
8800.587000
8900.595000
9000.613700
9100.650500
9200.635700
9300.585400
9400.640100
9500.602500
9600.602300
9700.628400
9800.591400
9900.613800
10000.630400
10100.635500
10200.603000
10300.638600
10400.610000
10500.617700
10600.602900
10700.642400
10800.574200
10900.610700
11000.602400
11100.565800
11200.611800
11300.598400
11400.572400
11500.609400
11600.623400
11700.643900
11800.603100
11900.582800
12000.625300
12100.621200
12200.602500
12300.603600
12400.605700
12500.621700
12600.634400
12700.600800
12800.598900
12900.608400
13000.603900
13100.631000
13200.604400
13300.565300
13400.612100
13500.595600
13600.546300
13700.616500
13800.633100
13900.601200
14000.571400
14100.573000
14200.576800
14300.624200
14400.611300
14500.592600
14600.621700
14700.602400
14800.631600
14900.628400
15000.590400
15100.576200
15200.559700
15300.558600
15400.555300
15500.586700
15600.620300
15700.604200
15800.565900
15900.605600
16000.576200
16100.586900
16200.538900
16300.630800
16400.573800
16500.590000
16600.554300
16700.561400
16800.608600
16900.624900
17000.574500
17100.585700
17200.584100
17300.568800
17400.604800
17500.557200
17600.565600
17700.598000
17800.565200
17900.604600
18000.588700
18100.552400
18200.627300
18300.601200
18400.558900
18500.541700
18600.610900
18700.558100
18800.584000
18900.607400
19000.561400
19100.571900
19200.634600
19300.589600
19400.583700
19500.581000
19600.563500
19700.555800
19800.568500
19900.574000
20000.597900
20100.601300
20200.586100
20300.551000
20400.581400
20500.580100
20600.588300
20700.558400
20800.561900
20900.622200
21000.577300
21100.618500
21200.585500
21300.553000
21400.569300
21500.559400
21600.552000
21700.572300
21800.541100
21900.571200
22000.530100
22100.560900
22200.604000
22300.581000
22400.550900
22500.590800
22600.603700
22700.581600
22800.583100
22900.608700
23000.574100
23100.567600
23200.593300
23300.573500
23400.557900
23500.587300
23600.550300
23700.572700
23800.533900
23900.564300
24000.552300
24100.548400
24200.605300
24300.640300
24400.565600
24500.541200
24600.557400
24700.534200
24800.541800
24900.554900
25000.531100
25100.616900
25200.549000
25300.607900
25400.568100
25500.557300
25600.582100
25700.555500
25800.553700
25900.559900
26000.512100
26100.556200
26200.559800
26300.563400
26400.557700
26500.624200
26600.553400
26700.582400
26800.553800
26900.527100
27000.557600
27100.565900
27200.609700
27300.600900
27400.547100
27500.569500
27600.587900
27700.545300
27800.551800
27900.546400
28000.615400
28100.574700
28200.609600
28300.547100
28400.598900
28500.574900
28600.565000
28700.530400
28800.558300
28900.563300
29000.555800
29100.536700
29200.565800
29300.568500
29400.547200
29500.571600
29600.579100
29700.533600
29800.545400
29900.550600

" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "ename": "KeyboardInterrupt", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[65], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;66;03m# start training, the model will be automatically saved to the hub and the output directory\u001b[39;00m\n\u001b[0;32m----> 2\u001b[0m \u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 4\u001b[0m \u001b[38;5;66;03m# save model \u001b[39;00m\n\u001b[1;32m 5\u001b[0m trainer\u001b[38;5;241m.\u001b[39msave_model()\n", + "File \u001b[0;32m/fsx/qgallouedec/trl/nb/lib/python3.11/site-packages/transformers/trainer.py:2114\u001b[0m, in \u001b[0;36mTrainer.train\u001b[0;34m(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)\u001b[0m\n\u001b[1;32m 2111\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 2112\u001b[0m \u001b[38;5;66;03m# Disable progress bars when uploading models during checkpoints to avoid polluting stdout\u001b[39;00m\n\u001b[1;32m 2113\u001b[0m hf_hub_utils\u001b[38;5;241m.\u001b[39mdisable_progress_bars()\n\u001b[0;32m-> 2114\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43minner_training_loop\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 2115\u001b[0m \u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2116\u001b[0m \u001b[43m \u001b[49m\u001b[43mresume_from_checkpoint\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mresume_from_checkpoint\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2117\u001b[0m \u001b[43m \u001b[49m\u001b[43mtrial\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtrial\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2118\u001b[0m \u001b[43m \u001b[49m\u001b[43mignore_keys_for_eval\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mignore_keys_for_eval\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2119\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 2120\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[1;32m 2121\u001b[0m hf_hub_utils\u001b[38;5;241m.\u001b[39menable_progress_bars()\n", + "File \u001b[0;32m/fsx/qgallouedec/trl/nb/lib/python3.11/site-packages/transformers/trainer.py:2481\u001b[0m, in \u001b[0;36mTrainer._inner_training_loop\u001b[0;34m(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)\u001b[0m\n\u001b[1;32m 2475\u001b[0m context \u001b[38;5;241m=\u001b[39m (\n\u001b[1;32m 2476\u001b[0m functools\u001b[38;5;241m.\u001b[39mpartial(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39maccelerator\u001b[38;5;241m.\u001b[39mno_sync, model\u001b[38;5;241m=\u001b[39mmodel)\n\u001b[1;32m 2477\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m i \u001b[38;5;241m!=\u001b[39m \u001b[38;5;28mlen\u001b[39m(batch_samples) \u001b[38;5;241m-\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[1;32m 2478\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m contextlib\u001b[38;5;241m.\u001b[39mnullcontext\n\u001b[1;32m 2479\u001b[0m )\n\u001b[1;32m 2480\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m context():\n\u001b[0;32m-> 2481\u001b[0m tr_loss_step \u001b[38;5;241m=\u001b[39m 
\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtraining_step\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnum_items_in_batch\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 2483\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m (\n\u001b[1;32m 2484\u001b[0m args\u001b[38;5;241m.\u001b[39mlogging_nan_inf_filter\n\u001b[1;32m 2485\u001b[0m \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m is_torch_xla_available()\n\u001b[1;32m 2486\u001b[0m \u001b[38;5;129;01mand\u001b[39;00m (torch\u001b[38;5;241m.\u001b[39misnan(tr_loss_step) \u001b[38;5;129;01mor\u001b[39;00m torch\u001b[38;5;241m.\u001b[39misinf(tr_loss_step))\n\u001b[1;32m 2487\u001b[0m ):\n\u001b[1;32m 2488\u001b[0m \u001b[38;5;66;03m# if loss is nan or inf simply add the average of previous logged losses\u001b[39;00m\n\u001b[1;32m 2489\u001b[0m tr_loss \u001b[38;5;241m=\u001b[39m tr_loss \u001b[38;5;241m+\u001b[39m tr_loss \u001b[38;5;241m/\u001b[39m (\u001b[38;5;241m1\u001b[39m \u001b[38;5;241m+\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mglobal_step \u001b[38;5;241m-\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_globalstep_last_logged)\n", + "File \u001b[0;32m/fsx/qgallouedec/trl/nb/lib/python3.11/site-packages/transformers/trainer.py:3612\u001b[0m, in \u001b[0;36mTrainer.training_step\u001b[0;34m(***failed resolving arguments***)\u001b[0m\n\u001b[1;32m 3610\u001b[0m scaled_loss\u001b[38;5;241m.\u001b[39mbackward()\n\u001b[1;32m 3611\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 3612\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43maccelerator\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbackward\u001b[49m\u001b[43m(\u001b[49m\u001b[43mloss\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 3613\u001b[0m \u001b[38;5;66;03m# Finally we need to normalize the loss for reporting\u001b[39;00m\n\u001b[1;32m 3614\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m num_items_in_batch \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n", + "File \u001b[0;32m/fsx/qgallouedec/trl/nb/lib/python3.11/site-packages/accelerate/accelerator.py:2241\u001b[0m, in \u001b[0;36mAccelerator.backward\u001b[0;34m(self, loss, **kwargs)\u001b[0m\n\u001b[1;32m 2239\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlomo_backward(loss, learning_rate)\n\u001b[1;32m 2240\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 2241\u001b[0m \u001b[43mloss\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbackward\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m/fsx/qgallouedec/trl/nb/lib/python3.11/site-packages/torch/_tensor.py:581\u001b[0m, in \u001b[0;36mTensor.backward\u001b[0;34m(self, gradient, retain_graph, create_graph, inputs)\u001b[0m\n\u001b[1;32m 571\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m has_torch_function_unary(\u001b[38;5;28mself\u001b[39m):\n\u001b[1;32m 572\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m handle_torch_function(\n\u001b[1;32m 573\u001b[0m Tensor\u001b[38;5;241m.\u001b[39mbackward,\n\u001b[1;32m 574\u001b[0m 
(\u001b[38;5;28mself\u001b[39m,),\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 579\u001b[0m inputs\u001b[38;5;241m=\u001b[39minputs,\n\u001b[1;32m 580\u001b[0m )\n\u001b[0;32m--> 581\u001b[0m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mautograd\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbackward\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 582\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mgradient\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mretain_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcreate_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minputs\u001b[49m\n\u001b[1;32m 583\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m/fsx/qgallouedec/trl/nb/lib/python3.11/site-packages/torch/autograd/__init__.py:347\u001b[0m, in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001b[0m\n\u001b[1;32m 342\u001b[0m retain_graph \u001b[38;5;241m=\u001b[39m create_graph\n\u001b[1;32m 344\u001b[0m \u001b[38;5;66;03m# The reason we repeat the same comment below is that\u001b[39;00m\n\u001b[1;32m 345\u001b[0m \u001b[38;5;66;03m# some Python versions print out the first line of a multi-line function\u001b[39;00m\n\u001b[1;32m 346\u001b[0m \u001b[38;5;66;03m# calls in the traceback and some print out the last line\u001b[39;00m\n\u001b[0;32m--> 347\u001b[0m \u001b[43m_engine_run_backward\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 348\u001b[0m \u001b[43m \u001b[49m\u001b[43mtensors\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 349\u001b[0m \u001b[43m \u001b[49m\u001b[43mgrad_tensors_\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 350\u001b[0m \u001b[43m \u001b[49m\u001b[43mretain_graph\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 351\u001b[0m \u001b[43m \u001b[49m\u001b[43mcreate_graph\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 352\u001b[0m \u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 353\u001b[0m \u001b[43m \u001b[49m\u001b[43mallow_unreachable\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 354\u001b[0m \u001b[43m \u001b[49m\u001b[43maccumulate_grad\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 355\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m/fsx/qgallouedec/trl/nb/lib/python3.11/site-packages/torch/autograd/graph.py:825\u001b[0m, in \u001b[0;36m_engine_run_backward\u001b[0;34m(t_outputs, *args, **kwargs)\u001b[0m\n\u001b[1;32m 823\u001b[0m unregister_hooks \u001b[38;5;241m=\u001b[39m _register_logging_hooks_on_whole_graph(t_outputs)\n\u001b[1;32m 824\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 825\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mVariable\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_execution_engine\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun_backward\u001b[49m\u001b[43m(\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;66;43;03m# Calls into the C++ engine to run the backward pass\u001b[39;49;00m\n\u001b[1;32m 826\u001b[0m \u001b[43m \u001b[49m\u001b[43mt_outputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m 
\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\n\u001b[1;32m 827\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;66;03m# Calls into the C++ engine to run the backward pass\u001b[39;00m\n\u001b[1;32m 828\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[1;32m 829\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m attach_logging_hooks:\n",
+ "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
+ ]
+ }
+ ],
+ "source": [
+ "# start training, the model will be automatically saved to the hub and the output directory\n",
+ "trainer.train()\n",
+ "\n",
+ "# save model \n",
+ "trainer.save_model()"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The training with Flash Attention for 3 epochs with a dataset of 10k samples took 02:05:58 on a `g6.2xlarge`. The instance costs `$1.212/h`, which brings us to a total cost of only about `$2.5`."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 52,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "\n",
+ "# free the memory again\n",
+ "del model\n",
+ "del trainer\n",
+ "torch.cuda.empty_cache()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Merge LoRA adapter into the original model\n",
+ "\n",
+ "When using QLoRA, we only train adapters and not the full model. This means that when saving the model during training, we only save the adapter weights and not the full model. If you want to save the full model, which makes it easier to use with Text Generation Inference, you can merge the adapter weights into the model weights using the `merge_and_unload` method and then save the model with the `save_pretrained` method. This will save a standard model, which can be used for inference.\n",
+ "\n",
+ "_Note: This requires > 30GB of CPU memory._"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "07e8fc0add14407c8c32c88abe998e78",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Loading checkpoint shards:   0%|          | 0/4 [00:00
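<?, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "# A minimal merge sketch: it assumes the LoRA adapter was saved to\n",
+ "# `training_args.output_dir` during training and that `peft` is installed;\n",
+ "# adjust the path to your setup.\n",
+ "from peft import AutoPeftModelForCausalLM\n",
+ "\n",
+ "# Load the PEFT model (base model + adapter) on the CPU\n",
+ "model = AutoPeftModelForCausalLM.from_pretrained(\n",
+ "    training_args.output_dir,\n",
+ "    low_cpu_mem_usage=True,\n",
+ ")\n",
+ "\n",
+ "# Merge the LoRA weights into the base model and save a standalone model\n",
+ "merged_model = model.merge_and_unload()\n",
+ "merged_model.save_pretrained(training_args.output_dir, safe_serialization=True, max_shard_size=\"2GB\")"
+ ]
+ },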