diff --git a/demo2.ipynb b/demo2.ipynb new file mode 100644 index 0000000000..009164652d --- /dev/null +++ b/demo2.ipynb @@ -0,0 +1,2258 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# How to Fine-Tune LLMs with TRL\n", + "\n", + "_Authored by Philipp Schmid_\n", + "Original post: [How to Fine-Tune LLMs in 2024 with Hugging Face](https://www.philschmid.de/fine-tune-llms-in-2024-with-trl)\n", + "\n", + "_Adapted by Quentin Gallouédec_\n", + "\n", + "Large Language Models (LLMs) have seen a lot of progress in the last year. We went from having no ChatGPT competitor to a whole zoo of LLMs, including Meta AI's [Llama 3](https://huggingface.co/blog/llama31), Mistral's [Mistral](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2) & [Mixtral](https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1) models, TII's [Falcon](https://huggingface.co/tiiuae/falcon-40b), and many more. \n", + "Those LLMs can be used for a variety of tasks, including chatbots, question answering, and summarization, without any additional training. However, if you want to customize a model for your application, you may need to fine-tune it on your data to achieve higher-quality results than prompting alone, or to save costs by training smaller, more efficient models.\n", + "\n", + "This blog post walks you through how to fine-tune open LLMs using Hugging Face [TRL](https://huggingface.co/docs/trl/index), [Transformers](https://huggingface.co/docs/transformers/index) & [Datasets](https://huggingface.co/docs/datasets/index). In the blog, we are going to:\n", + "\n", + "1. Define our use case \n", + "2. Setup development environment\n", + "3. Create and prepare the dataset\n", + "4. Fine-tune LLM using `trl` and the `SFTTrainer` \n", + "5. Test and evaluate the LLM\n", + "6. Deploy the LLM for Production\n", + "\n", + "_Note: This blog was created to run on consumer-size GPUs (24GB), e.g. NVIDIA A10G or RTX 4090/3090, but can be easily adapted to run on bigger GPUs._\n", + "\n", + "\n", + "## 1. Define our use case \n", + "\n", + "When fine-tuning LLMs, it is important that you know your use case and the task you want to solve. This will help you choose the right model, or help you create a dataset to fine-tune your model. If you haven't defined your use case yet, you might want to go back to the drawing board.\n", + "Note that not all use cases require fine-tuning; it is always recommended to evaluate and try out already fine-tuned models or API-based models before fine-tuning your own. \n", + "\n", + "As an example, we are going to use the following use case:\n", + "\n", + "> We want to fine-tune a model that can generate SQL queries from natural language instructions, which can then be integrated into our BI tool. The goal is to reduce the time it takes to create a SQL query and make it easier for non-technical users to create SQL queries.\n", + "\n", + "Text-to-SQL can be a good use case for fine-tuning LLMs, as it is a complex task that requires a lot of (internal) knowledge about the data and the SQL language. \n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2. Setup development environment\n", + "\n", + "Our first step is to install the Hugging Face libraries and PyTorch, including trl, transformers and datasets. If you haven't heard of trl yet, don't worry. It is a library on top of transformers and datasets that makes it easier to fine-tune, align, and apply RLHF to open LLMs.
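\n", + "\n", + "Before installing anything, it can be worth a quick sanity check that a CUDA GPU is visible from PyTorch. This is optional; the sketch assumes PyTorch is already installed in your environment:\n", + "\n", + "```python\n", + "import torch\n", + "\n", + "# Check that PyTorch can see a CUDA GPU before installing and training\n", + "print(torch.cuda.is_available())\n", + "if torch.cuda.is_available():\n", + "    print(torch.cuda.get_device_name(0))\n", + "```\n", + "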
\n" + ] + }, + { + "cell_type": "code", + "execution_count": 53, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n", + "To disable this warning, you can either:\n", + "\t- Avoid using `tokenizers` before the fork if possible\n", + "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Requirement already satisfied: trl in ./nb/lib/python3.11/site-packages (0.12.1)\n", + "Requirement already satisfied: hf_transfer in ./nb/lib/python3.11/site-packages (0.1.8)\n", + "Requirement already satisfied: accelerate>=0.34.0 in ./nb/lib/python3.11/site-packages (from trl) (1.1.1)\n", + "Requirement already satisfied: datasets>=2.21.0 in ./nb/lib/python3.11/site-packages (from trl) (3.1.0)\n", + "Requirement already satisfied: rich in ./nb/lib/python3.11/site-packages (from trl) (13.9.4)\n", + "Requirement already satisfied: transformers>=4.46.0 in ./nb/lib/python3.11/site-packages (from trl) (4.46.3)\n", + "Requirement already satisfied: huggingface-hub>=0.21.0 in ./nb/lib/python3.11/site-packages (from accelerate>=0.34.0->trl) (0.26.2)\n", + "Requirement already satisfied: numpy<3.0.0,>=1.17 in ./nb/lib/python3.11/site-packages (from accelerate>=0.34.0->trl) (2.1.3)\n", + "Requirement already satisfied: packaging>=20.0 in ./nb/lib/python3.11/site-packages (from accelerate>=0.34.0->trl) (24.2)\n", + "Requirement already satisfied: psutil in ./nb/lib/python3.11/site-packages (from accelerate>=0.34.0->trl) (6.1.0)\n", + "Requirement already satisfied: pyyaml in ./nb/lib/python3.11/site-packages (from accelerate>=0.34.0->trl) (6.0.2)\n", + "Requirement already satisfied: safetensors>=0.4.3 in ./nb/lib/python3.11/site-packages (from accelerate>=0.34.0->trl) (0.4.5)\n", + "Requirement already satisfied: torch>=1.10.0 in ./nb/lib/python3.11/site-packages (from accelerate>=0.34.0->trl) (2.5.1)\n", + "Requirement already satisfied: filelock in ./nb/lib/python3.11/site-packages (from datasets>=2.21.0->trl) (3.16.1)\n", + "Requirement already satisfied: pyarrow>=15.0.0 in ./nb/lib/python3.11/site-packages (from datasets>=2.21.0->trl) (18.1.0)\n", + "Requirement already satisfied: dill<0.3.9,>=0.3.0 in ./nb/lib/python3.11/site-packages (from datasets>=2.21.0->trl) (0.3.8)\n", + "Requirement already satisfied: pandas in ./nb/lib/python3.11/site-packages (from datasets>=2.21.0->trl) (2.2.3)\n", + "Requirement already satisfied: requests>=2.32.2 in ./nb/lib/python3.11/site-packages (from datasets>=2.21.0->trl) (2.32.3)\n", + "Requirement already satisfied: tqdm>=4.66.3 in ./nb/lib/python3.11/site-packages (from datasets>=2.21.0->trl) (4.67.1)\n", + "Requirement already satisfied: xxhash in ./nb/lib/python3.11/site-packages (from datasets>=2.21.0->trl) (3.5.0)\n", + "Requirement already satisfied: multiprocess<0.70.17 in ./nb/lib/python3.11/site-packages (from datasets>=2.21.0->trl) (0.70.16)\n", + "Requirement already satisfied: fsspec<=2024.9.0,>=2023.1.0 in ./nb/lib/python3.11/site-packages (from fsspec[http]<=2024.9.0,>=2023.1.0->datasets>=2.21.0->trl) (2024.9.0)\n", + "Requirement already satisfied: aiohttp in ./nb/lib/python3.11/site-packages (from datasets>=2.21.0->trl) (3.11.7)\n", + "Requirement already satisfied: regex!=2019.12.17 in ./nb/lib/python3.11/site-packages (from transformers>=4.46.0->trl) 
(2024.11.6)\n", + "Requirement already satisfied: tokenizers<0.21,>=0.20 in ./nb/lib/python3.11/site-packages (from transformers>=4.46.0->trl) (0.20.3)\n", + "Requirement already satisfied: markdown-it-py>=2.2.0 in ./nb/lib/python3.11/site-packages (from rich->trl) (3.0.0)\n", + "Requirement already satisfied: pygments<3.0.0,>=2.13.0 in ./nb/lib/python3.11/site-packages (from rich->trl) (2.18.0)\n", + "Requirement already satisfied: aiohappyeyeballs>=2.3.0 in ./nb/lib/python3.11/site-packages (from aiohttp->datasets>=2.21.0->trl) (2.4.3)\n", + "Requirement already satisfied: aiosignal>=1.1.2 in ./nb/lib/python3.11/site-packages (from aiohttp->datasets>=2.21.0->trl) (1.3.1)\n", + "Requirement already satisfied: attrs>=17.3.0 in ./nb/lib/python3.11/site-packages (from aiohttp->datasets>=2.21.0->trl) (24.2.0)\n", + "Requirement already satisfied: frozenlist>=1.1.1 in ./nb/lib/python3.11/site-packages (from aiohttp->datasets>=2.21.0->trl) (1.5.0)\n", + "Requirement already satisfied: multidict<7.0,>=4.5 in ./nb/lib/python3.11/site-packages (from aiohttp->datasets>=2.21.0->trl) (6.1.0)\n", + "Requirement already satisfied: propcache>=0.2.0 in ./nb/lib/python3.11/site-packages (from aiohttp->datasets>=2.21.0->trl) (0.2.0)\n", + "Requirement already satisfied: yarl<2.0,>=1.17.0 in ./nb/lib/python3.11/site-packages (from aiohttp->datasets>=2.21.0->trl) (1.18.0)\n", + "Requirement already satisfied: typing-extensions>=3.7.4.3 in ./nb/lib/python3.11/site-packages (from huggingface-hub>=0.21.0->accelerate>=0.34.0->trl) (4.12.2)\n", + "Requirement already satisfied: mdurl~=0.1 in ./nb/lib/python3.11/site-packages (from markdown-it-py>=2.2.0->rich->trl) (0.1.2)\n", + "Requirement already satisfied: charset-normalizer<4,>=2 in ./nb/lib/python3.11/site-packages (from requests>=2.32.2->datasets>=2.21.0->trl) (3.4.0)\n", + "Requirement already satisfied: idna<4,>=2.5 in ./nb/lib/python3.11/site-packages (from requests>=2.32.2->datasets>=2.21.0->trl) (3.10)\n", + "Requirement already satisfied: urllib3<3,>=1.21.1 in ./nb/lib/python3.11/site-packages (from requests>=2.32.2->datasets>=2.21.0->trl) (2.2.3)\n", + "Requirement already satisfied: certifi>=2017.4.17 in ./nb/lib/python3.11/site-packages (from requests>=2.32.2->datasets>=2.21.0->trl) (2024.8.30)\n", + "Requirement already satisfied: networkx in ./nb/lib/python3.11/site-packages (from torch>=1.10.0->accelerate>=0.34.0->trl) (3.4.2)\n", + "Requirement already satisfied: jinja2 in ./nb/lib/python3.11/site-packages (from torch>=1.10.0->accelerate>=0.34.0->trl) (3.1.4)\n", + "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.4.127 in ./nb/lib/python3.11/site-packages (from torch>=1.10.0->accelerate>=0.34.0->trl) (12.4.127)\n", + "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.4.127 in ./nb/lib/python3.11/site-packages (from torch>=1.10.0->accelerate>=0.34.0->trl) (12.4.127)\n", + "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.4.127 in ./nb/lib/python3.11/site-packages (from torch>=1.10.0->accelerate>=0.34.0->trl) (12.4.127)\n", + "Requirement already satisfied: nvidia-cudnn-cu12==9.1.0.70 in ./nb/lib/python3.11/site-packages (from torch>=1.10.0->accelerate>=0.34.0->trl) (9.1.0.70)\n", + "Requirement already satisfied: nvidia-cublas-cu12==12.4.5.8 in ./nb/lib/python3.11/site-packages (from torch>=1.10.0->accelerate>=0.34.0->trl) (12.4.5.8)\n", + "Requirement already satisfied: nvidia-cufft-cu12==11.2.1.3 in ./nb/lib/python3.11/site-packages (from torch>=1.10.0->accelerate>=0.34.0->trl) (11.2.1.3)\n", + "Requirement already 
satisfied: nvidia-curand-cu12==10.3.5.147 in ./nb/lib/python3.11/site-packages (from torch>=1.10.0->accelerate>=0.34.0->trl) (10.3.5.147)\n", + "Requirement already satisfied: nvidia-cusolver-cu12==11.6.1.9 in ./nb/lib/python3.11/site-packages (from torch>=1.10.0->accelerate>=0.34.0->trl) (11.6.1.9)\n", + "Requirement already satisfied: nvidia-cusparse-cu12==12.3.1.170 in ./nb/lib/python3.11/site-packages (from torch>=1.10.0->accelerate>=0.34.0->trl) (12.3.1.170)\n", + "Requirement already satisfied: nvidia-nccl-cu12==2.21.5 in ./nb/lib/python3.11/site-packages (from torch>=1.10.0->accelerate>=0.34.0->trl) (2.21.5)\n", + "Requirement already satisfied: nvidia-nvtx-cu12==12.4.127 in ./nb/lib/python3.11/site-packages (from torch>=1.10.0->accelerate>=0.34.0->trl) (12.4.127)\n", + "Requirement already satisfied: nvidia-nvjitlink-cu12==12.4.127 in ./nb/lib/python3.11/site-packages (from torch>=1.10.0->accelerate>=0.34.0->trl) (12.4.127)\n", + "Requirement already satisfied: triton==3.1.0 in ./nb/lib/python3.11/site-packages (from torch>=1.10.0->accelerate>=0.34.0->trl) (3.1.0)\n", + "Requirement already satisfied: sympy==1.13.1 in ./nb/lib/python3.11/site-packages (from torch>=1.10.0->accelerate>=0.34.0->trl) (1.13.1)\n", + "Requirement already satisfied: mpmath<1.4,>=1.1.0 in ./nb/lib/python3.11/site-packages (from sympy==1.13.1->torch>=1.10.0->accelerate>=0.34.0->trl) (1.3.0)\n", + "Requirement already satisfied: python-dateutil>=2.8.2 in ./nb/lib/python3.11/site-packages (from pandas->datasets>=2.21.0->trl) (2.9.0.post0)\n", + "Requirement already satisfied: pytz>=2020.1 in ./nb/lib/python3.11/site-packages (from pandas->datasets>=2.21.0->trl) (2024.2)\n", + "Requirement already satisfied: tzdata>=2022.7 in ./nb/lib/python3.11/site-packages (from pandas->datasets>=2.21.0->trl) (2024.2)\n", + "Requirement already satisfied: six>=1.5 in ./nb/lib/python3.11/site-packages (from python-dateutil>=2.8.2->pandas->datasets>=2.21.0->trl) (1.16.0)\n", + "Requirement already satisfied: MarkupSafe>=2.0 in ./nb/lib/python3.11/site-packages (from jinja2->torch>=1.10.0->accelerate>=0.34.0->trl) (3.0.2)\n", + "\n", + "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m24.0\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m24.3.1\u001b[0m\n", + "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n", + "Note: you may need to restart the kernel to use updated packages.\n" + ] + } + ], + "source": [ + "# Install TRL\n", + "%pip install trl hf_transfer" + ] + },
+ { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "If you are using a GPU with Ampere architecture (e.g. NVIDIA A10G or RTX 4090/3090) or newer, you can use Flash Attention. Flash Attention is a method that reorders the attention computation and leverages classical techniques (tiling, recomputation) to significantly speed it up and reduce memory usage from quadratic to linear in sequence length. The TL;DR: it accelerates training up to 3x. Learn more at [FlashAttention](https://github.com/Dao-AILab/flash-attention/tree/main).\n", + "\n", + "_Note: If your machine has less than 96GB of RAM and lots of CPU cores, reduce the number of `MAX_JOBS`.
On the `g6.2xlarge` we used `4`._" + ] + }, + { + "cell_type": "code", + "execution_count": 54, + "metadata": {}, + "outputs": [], + "source": [ + "# import torch; assert torch.cuda.get_device_capability()[0] >= 8, 'Hardware not supported for Flash Attention'\n", + "# # install flash-attn\n", + "# !pip install ninja packaging\n", + "# !MAX_JOBS=4 pip install flash-attn --no-build-isolation" + ] + },
+ { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "_Installing flash attention can take quite a bit of time (10-45 minutes)._" + ] + },
+ { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We will use the [Hugging Face Hub](https://huggingface.co/models) as a remote model versioning service. This means we will automatically push our model, logs and information to the Hub during training.\n", + "\n", + "1. **Sign up**: Create an account at [Hugging Face](https://huggingface.co/join) if you don’t already have one. \n", + "2. **Generate a token**: Go to [Token Settings](https://huggingface.co/settings/tokens) and create a new token: \n", + " - Select *Fine-grained*. \n", + " - Assign any name. \n", + " - Enable\n", + " - \"Write access to contents/settings of all repos under your personal namespace.\" and\n", + " - \"Read access to contents of all public gated repos you can access.\"\n", + " - Create the token and copy-paste it in the next cell.\n", + "\n", + "⚠️ **Keep your token secret**: Don't push this notebook with your token in it. ⚠️ " + ] + },
+ { + "cell_type": "code", + "execution_count": 55, + "metadata": {}, + "outputs": [], + "source": [ + "from huggingface_hub import login\n", + "\n", + "login(\n", + " token=\"...\", # ADD YOUR TOKEN HERE\n", + " add_to_git_credential=True,\n", + ")\n" + ] + },
+ { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 3. Create and prepare the dataset\n", + "\n", + "Once you have determined that fine-tuning is the right solution, we need to create a dataset to fine-tune our model. The dataset should be a diverse set of demonstrations of the task you want to solve. There are several ways to create such a dataset, including:\n", + "* Using existing open-source datasets, e.g., [Spider](https://huggingface.co/datasets/spider)\n", + "* Using LLMs to create synthetic datasets, e.g., [Alpaca](https://huggingface.co/datasets/tatsu-lab/alpaca)\n", + "* Using humans to create datasets, e.g., [Dolly](https://huggingface.co/datasets/databricks/databricks-dolly-15k).\n", + "* Using a combination of the above methods, e.g., [Orca](https://huggingface.co/datasets/Open-Orca/OpenOrca)\n", + "\n", + "Each method has its own advantages and disadvantages, and the right choice depends on your budget, time, and quality requirements. For example, using an existing dataset is the easiest but might not be tailored to your specific use case, while using humans might be the most accurate but can be time-consuming and expensive. It is also possible to combine several methods to create an instruction dataset, as shown in [Orca: Progressive Learning from Complex Explanation Traces of GPT-4](https://arxiv.org/abs/2306.02707).\n", + "\n", + "In our example we will use an existing dataset called [sql-create-context](https://huggingface.co/datasets/b-mc2/sql-create-context), which contains samples of natural language instructions, schema definitions and the corresponding SQL query."
+ ] + }, + { + "cell_type": "code", + "execution_count": 56, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'answer': 'SELECT COUNT(*) FROM head WHERE age > 56', 'question': 'How many heads of the departments are older than 56 ?', 'context': 'CREATE TABLE head (age INTEGER)'}\n" + ] + } + ], + "source": [ + "from datasets import load_dataset\n", + "\n", + "dataset = load_dataset(\"b-mc2/sql-create-context\", split=\"train\")\n", + "print(dataset[0])" + ] + },
+ { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "With the latest release of `trl` we now support popular [standard and conversational dataset formats](https://huggingface.co/docs/trl/en/dataset_formats). This means we only need to convert our dataset to one of the supported formats and `trl` will take care of the rest. Those formats include:\n", + "\n", + "* Standard format\n", + "\n", + "```python\n", + "{\"text\": \"The sky is blue.\"}\n", + "```\n", + "\n", + "* Conversational format\n", + "\n", + "```python\n", + "{\"messages\": [{\"role\": \"user\", \"content\": \"What color is the sky?\"},\n", + " {\"role\": \"assistant\", \"content\": \"It is blue.\"}]}\n", + "```" + ] + },
+ { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In our example we are going to load our open-source dataset using the 🤗 Datasets library and then convert it into the conversational format, where we include the schema definition in the system message for our assistant.\n", + "\n", + "_Note: This step can be different for your use case. For example, if you already have a dataset, e.g. from working with OpenAI, you can skip this step and go directly to the fine-tuning step._" + ] + },
+ { + "cell_type": "code", + "execution_count": 57, + "metadata": {}, + "outputs": [], + "source": [ + "from datasets import load_dataset\n", + "\n", + "# Convert dataset to OAI messages\n", + "system_message = \"\"\"You are a text to SQL query translator. Users will ask you questions in English and you will generate a SQL query based on the provided SCHEMA.\n", + "SCHEMA:\n", + "{schema}\"\"\"\n", + "\n", + "\n", + "def create_conversation(example):\n", + " return {\n", + " \"messages\": [\n", + " {\"role\": \"system\", \"content\": system_message.format(schema=example[\"context\"])},\n", + " {\"role\": \"user\", \"content\": example[\"question\"]},\n", + " {\"role\": \"assistant\", \"content\": example[\"answer\"]}],\n", + " }\n", + "\n", + "\n", + "# Load dataset from the hub\n", + "dataset = load_dataset(\"b-mc2/sql-create-context\", split=\"train[:50000]\")\n", + "\n", + "# Convert dataset to conversational format\n", + "dataset = dataset.map(create_conversation, remove_columns=dataset.column_names)\n", + "\n", + "# Split dataset into train and test (95% train, 5% test)\n", + "dataset = dataset.train_test_split(test_size=0.05)" + ] + },
+ { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now let's load a sample and see what it looks like." + ] + },
+ { + "cell_type": "code", + "execution_count": 58, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[{'content': 'You are a text to SQL query translator. Users will ask you questions in English and you will generate a SQL query based on the provided SCHEMA.\\nSCHEMA:\\nCREATE TABLE table_name_62 (attendance VARCHAR, result VARCHAR)', 'role': 'system'}, {'content': 'What is Attendance, when Result is \"2-4\"?', 'role': 'user'}, {'content': 'SELECT attendance FROM table_name_62 WHERE result = \"2-4\"', 'role': 'assistant'}]\n" + ] + } + ], + "source": [ + "print(dataset[\"train\"][345][\"messages\"])" + ] + },
+ { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 4. Fine-tune LLM using `trl` and the `SFTTrainer` \n", + "\n", + "We are now ready to fine-tune our model using the [SFTTrainer](https://huggingface.co/docs/trl/sft_trainer) from `trl`. The `SFTTrainer` makes supervised fine-tuning of open LLMs straightforward. It is a subclass of the `Trainer` from the `transformers` library and supports all the same features, including logging, evaluation, and checkpointing, but adds additional quality-of-life features, including:\n", + "* Dataset formatting, including standard and conversational format\n", + "* Training on completions only, ignoring prompts\n", + "* Packing datasets for more efficient training\n", + "* PEFT (parameter-efficient fine-tuning) support including Q-LoRA\n", + "* Preparing the model and tokenizer for conversational fine-tuning (e.g. adding special tokens)\n", + "\n", + "We will use the dataset formatting feature in our example, and keep the PEFT (QLoRA) configuration as an optional, commented-out extra. [QLoRA](https://huggingface.co/papers/2305.14314) is a technique that uses quantization to reduce the memory footprint of large language models during fine-tuning, without sacrificing performance. If you want to learn more about QLoRA and how it works, check out the [Making LLMs even more accessible with bitsandbytes, 4-bit quantization and QLoRA](https://huggingface.co/blog/4bit-transformers-bitsandbytes) blog post.\n", + "\n", + "Now, let's get started! 🚀" + ] + },
+ { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "First, we will load our LLM. For our use case we are going to use [Qwen2.5 0.5B](https://huggingface.co/Qwen/Qwen2.5-0.5B). \n", + "But we can easily swap out the model for another one, e.g. [Llama 3.1 8B](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B), Mistral's [Mistral](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2) or [Mixtral](https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1) models, TII's [Falcon](https://huggingface.co/tiiuae/falcon-40b), or any other LLM by changing our `model_id` variable. For larger models, you can use `bitsandbytes` to quantize the model to 4-bit; a sketch follows the loading cell below.\n", + "\n", + "_Note: Be aware that the bigger the model, the more memory it will require. In our example we use the 0.5B version, which fits comfortably on consumer GPUs; even an 8B model can be tuned on 24GB GPUs with QLoRA._" + ] + },
+ { + "cell_type": "code", + "execution_count": 59, + "metadata": {}, + "outputs": [], + "source": [ + "from transformers import AutoModelForCausalLM # , BitsAndBytesConfig (for 4-bit quantization)\n", + "\n", + "# Hugging Face model id\n", + "model_id = \"Qwen/Qwen2.5-0.5B\" # or `mistralai/Mistral-7B-v0.1`\n", + "\n", + "# Load model\n", + "model = AutoModelForCausalLM.from_pretrained(model_id)" + ] + },
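+ { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "If you want to inspect the converted splits outside the notebook, or reuse them later, you can optionally dump them to disk. A minimal sketch (the file names are just examples and are not used elsewhere in this notebook):\n", + "\n", + "```python\n", + "# Optionally save the formatted splits as JSON lines for inspection/reuse\n", + "dataset['train'].to_json('train_dataset.json', orient='records')\n", + "dataset['test'].to_json('test_dataset.json', orient='records')\n", + "```" + ] + },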
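+ { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "For reference, if you are fine-tuning a larger model, the loading step might instead look like the sketch below. This is not what we run in this notebook (the 0.5B model doesn't need it); it assumes you have installed `bitsandbytes` and `flash-attn`, and that your GPU supports bfloat16:\n", + "\n", + "```python\n", + "import torch\n", + "from transformers import AutoModelForCausalLM, BitsAndBytesConfig\n", + "\n", + "# 4-bit NF4 quantization config, following the QLoRA paper\n", + "bnb_config = BitsAndBytesConfig(\n", + "    load_in_4bit=True,\n", + "    bnb_4bit_use_double_quant=True,\n", + "    bnb_4bit_quant_type='nf4',\n", + "    bnb_4bit_compute_dtype=torch.bfloat16,\n", + ")\n", + "\n", + "model = AutoModelForCausalLM.from_pretrained(\n", + "    model_id,\n", + "    quantization_config=bnb_config,\n", + "    attn_implementation='flash_attention_2',  # requires flash-attn\n", + "    torch_dtype=torch.bfloat16,\n", + ")\n", + "```" + ] + },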
+ { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Correctly preparing the LLM and tokenizer for training chat/conversational models is crucial. We need to add new special tokens to the tokenizer and model and teach the model to understand the different roles in a conversation. In `trl` we have a convenient method called [`setup_chat_format`](https://huggingface.co/docs/trl/main/en/sft_trainer#add-special-tokens-for-chat-format), which:\n", + "* Adds special tokens to the tokenizer, e.g. `<|im_start|>` and `<|im_end|>`, to indicate the start and end of a conversation.\n", + "* Resizes the model’s embedding layer to accommodate the new tokens.\n", + "* Sets the `chat_template` of the tokenizer, which is used to format the input data into a chat-like format. The default is `chatml` from OpenAI." + ] + },
+ { + "cell_type": "code", + "execution_count": 60, + "metadata": {}, + "outputs": [], + "source": [ + "from transformers import AutoTokenizer\n", + "from trl import setup_chat_format\n", + "\n", + "# Load tokenizer\n", + "tokenizer = AutoTokenizer.from_pretrained(model_id)\n", + "# tokenizer.padding_side = 'right' # to prevent warnings\n", + "\n", + "# Set chat template to OAI chatml (commented out: Qwen2.5's tokenizer already provides a chatml template)\n", + "# model, tokenizer = setup_chat_format(model, tokenizer)" + ] + },
+ { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now that our tokenizer has a proper chat template, it can convert our data into a formatted conversation. Let's see what it looks like." + ] + },
+ { + "cell_type": "code", + "execution_count": 61, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "<|im_start|>system\n", + "You are a text to SQL query translator. Users will ask you questions in English and you will generate a SQL query based on the provided SCHEMA.\n", + "SCHEMA:\n", + "CREATE TABLE table_name_62 (attendance VARCHAR, result VARCHAR)<|im_end|>\n", + "<|im_start|>user\n", + "What is Attendance, when Result is \"2-4\"?<|im_end|>\n", + "<|im_start|>assistant\n", + "SELECT attendance FROM table_name_62 WHERE result = \"2-4\"<|im_end|>\n", + "\n" + ] + } + ], + "source": [ + "from trl import apply_chat_template\n", + "\n", + "example = dataset[\"train\"][345]\n", + "formatted_example = apply_chat_template(example, tokenizer)\n", + "print(formatted_example[\"text\"])" + ] + },
+ { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The `SFTTrainer` supports a native integration with `peft`, which makes it easy to efficiently tune LLMs using, e.g., QLoRA. We only need to create our `LoraConfig` and provide it to the trainer. Our `LoraConfig` parameters are defined based on the [QLoRA paper](https://arxiv.org/pdf/2305.14314.pdf) and Sebastian Raschka's [blog post](https://magazine.sebastianraschka.com/p/practical-tips-for-finetuning-llms)." + ] + },
+ { + "cell_type": "code", + "execution_count": 62, + "metadata": {}, + "outputs": [], + "source": [ + "# from peft import LoraConfig\n", + "\n", + "# # LoRA config based on QLoRA paper & Sebastian Raschka experiment\n", + "# peft_config = LoraConfig(\n", + "# lora_alpha=128,\n", + "# lora_dropout=0.05,\n", + "# r=256,\n", + "# bias=\"none\",\n", + "# target_modules=\"all-linear\",\n", + "# task_type=\"CAUSAL_LM\", \n", + "# )" + ] + },
+ { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Before we can start training, we need to define the hyperparameters (`SFTConfig`, TRL's subclass of `TrainingArguments`) we want to use."
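, + "\n", + "In this notebook we keep the configuration minimal and rely on the trainer's defaults (e.g. 3 training epochs and a learning rate of 5e-5). For a real run you will likely want to set a few more knobs; the values below are illustrative starting points, not what we run here:\n", + "\n", + "```python\n", + "training_args = SFTConfig(\n", + "    output_dir='Qwen2.5-0.5B-SQL',\n", + "    num_train_epochs=3,              # total number of training epochs\n", + "    per_device_train_batch_size=8,   # lower this if you run out of GPU memory\n", + "    gradient_accumulation_steps=2,   # effective batch size = 8 * 2\n", + "    learning_rate=2e-4,              # common starting point for (Q)LoRA runs\n", + "    bf16=True,                       # requires an Ampere or newer GPU\n", + "    logging_steps=10,                # log the loss every 10 steps\n", + "    save_strategy='epoch',           # save a checkpoint every epoch\n", + "    push_to_hub=True,                # push model to the Hub\n", + "    report_to='none',                # disable experiment trackers\n", + ")\n", + "```"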
+ ] + }, + { + "cell_type": "code", + "execution_count": 63, + "metadata": {}, + "outputs": [], + "source": [ + "from trl import SFTConfig\n", + "\n", + "training_args = SFTConfig(\n", + " output_dir=\"Qwen2.5-0.5B-SQL\", # directory to save and repository id\n", + " save_strategy=\"epoch\", # save checkpoint every epoch\n", + " push_to_hub=True, # push model to hub\n", + " report_to=\"none\", # don't report to wandb\n", + ")" + ] + },
+ { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We now have every building block we need to create our `SFTTrainer` and start training our model." + ] + },
+ { + "cell_type": "code", + "execution_count": 64, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/fsx/qgallouedec/trl/trl/trainer/sft_trainer.py:248: UserWarning: You didn't pass a `max_seq_length` argument to the SFTTrainer, this will default to 1024\n", + " warnings.warn(\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "82c6a0e2e52d42a793cfd0be4f17441c", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Map: 0%| | 0/47500 [00:00, ? examples/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "from trl import SFTTrainer\n", + "\n", + "trainer = SFTTrainer(\n", + " model=model,\n", + " args=training_args,\n", + " train_dataset=dataset[\"train\"],\n", + " processing_class=tokenizer,\n", + ")" + ] + },
+ { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Start training our model by calling the `train()` method on our `Trainer` instance. This will start the training loop and train our model for 3 epochs (the default). Since we fine-tune the full model here, the full model weights are saved; if you enable the PEFT config above, only the adapter weights are saved instead." + ] + },
+ { + "cell_type": "code", + "execution_count": 65, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "
Step | Training Loss\n", + "10 | 1.578700\n", + "100 | 0.692700\n", + "500 | 0.607400\n", + "1000 | 0.630400\n", + "1500 | 0.590400\n", + "2000 | 0.597900\n", + "2500 | 0.531100\n", + "2990 | 0.550600\n", + "(progress table condensed: the full output logs the loss every 10 steps; it falls from ~1.58 at step 10 to ~0.55 around step 2990)\n", + "
"
+ ],
+ "text/plain": [
+ "