diff --git a/docs/source/api_ref_data.rst b/docs/source/api_ref_data.rst index 3f868e2048..454cf128c8 100644 --- a/docs/source/api_ref_data.rst +++ b/docs/source/api_ref_data.rst @@ -6,8 +6,6 @@ torchtune.data .. currentmodule:: torchtune.data -.. _chat_formats: - Text templates -------------- @@ -18,14 +16,12 @@ and models. :toctree: generated/ :nosignatures: - InstructTemplate GrammarErrorCorrectionTemplate SummarizeTemplate QuestionAnswerTemplate PromptTemplate PromptTemplateInterface ChatMLTemplate - ChatFormat Types ----- @@ -37,18 +33,6 @@ Types Message Role -Converters ----------- - -Converts data from common JSON formats into a torchtune :class:`Message`. - -.. autosummary:: - :toctree: generated/ - :nosignatures: - - get_sharegpt_messages - get_openai_messages - .. _message_transforms_ref: Message transforms diff --git a/docs/source/api_ref_datasets.rst b/docs/source/api_ref_datasets.rst index 40def346e4..98d328ee54 100644 --- a/docs/source/api_ref_datasets.rst +++ b/docs/source/api_ref_datasets.rst @@ -6,11 +6,11 @@ torchtune.datasets .. currentmodule:: torchtune.datasets -For a detailed general usage guide, please see our :ref:`datasets tutorial `. +For a detailed general usage guide, please see :ref:`datasets_overview`. Text datasets ------------------- +------------- torchtune supports several widely used text-only datasets to help quickly bootstrap your fine-tuning. diff --git a/docs/source/basics/packing.rst b/docs/source/basics/packing.rst new file mode 100644 index 0000000000..2673079de5 --- /dev/null +++ b/docs/source/basics/packing.rst @@ -0,0 +1,54 @@ +.. _packing_usage_label: + +============== +Sample packing +============== + +Sample packing involves concatenating multiple samples from your dataset into a single sequence, upto a maximum +sequence length. This requires some pre-processing of the dataset which may +slow down time-to-first-batch, but can introduce significant training speedups +depending on the dataset. In torchtune, sample packing is done by iterating through your dataset and performing +greedy packing upon dataset initialization. You can use sample packing with any of the single dataset builders by passing in +:code:`packed=True`. + +To set the max sequence length to pack to, make sure to define ``max_seq_len`` on your tokenizer. + +.. code-block:: python + + from torchtune.datasets import alpaca_dataset, PackedDataset + from torchtune.models.llama3 import llama3_tokenizer + + # Load in tokenizer + tokenizer = llama3_tokenizer( + path="/tmp/Llama-3.2-1B-Instruct/original/tokenizer.model", + max_seq_len=8192, + ) + dataset = alpaca_dataset( + tokenizer=tokenizer, + packed=True, + ) + print(isinstance(dataset, PackedDataset)) # True + +.. code-block:: yaml + + # YAML config + tokenizer: + _component_: torchtune.models.llama3.llama3_tokenizer + path: /tmp/Llama-3.2-1B-Instruct/original/tokenizer.model + max_seq_len: 8192 + + dataset: + _component_: torchtune.datasets.alpaca_dataset + packed: True + +.. code-block:: bash + + # Command line + tune run full_finetune_single_device --config llama3_2/1B_full_single_device \ + dataset.packed=True tokenizer.max_seq_len=8192 + +torchtune will automatically handle document masking and relative position IDs when sample packing is enabled +to prevent different irrelevant samples from cross-attending. This is done via PyTorch's `Flex Attention `_, +which enables the use of flash attention with non-causal masks. 
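For intuition, the sketch below shows what such a document-causal mask looks like when expressed directly against PyTorch's Flex Attention API. This only illustrates the masking idea, not torchtune's internal implementation; the packed sequence of three documents is made up for the example, and it assumes a GPU that supports Flex Attention.

.. code-block:: python

    import torch
    from torch.nn.attention.flex_attention import create_block_mask, flex_attention

    # A packed sequence of 128 tokens holding three documents of lengths 48, 32, and 48
    doc_lengths = torch.tensor([48, 32, 48], device="cuda")
    document_ids = torch.repeat_interleave(
        torch.arange(len(doc_lengths), device="cuda"), doc_lengths
    )

    def document_causal_mask(b, h, q_idx, kv_idx):
        # A query token may only attend to earlier tokens from its own document
        same_doc = document_ids[q_idx] == document_ids[kv_idx]
        causal = q_idx >= kv_idx
        return same_doc & causal

    block_mask = create_block_mask(
        document_causal_mask, B=None, H=None, Q_LEN=128, KV_LEN=128, device="cuda"
    )

    # (batch, num_heads, seq_len, head_dim)
    q = torch.randn(1, 4, 128, 64, device="cuda", dtype=torch.float16)
    k = torch.randn(1, 4, 128, 64, device="cuda", dtype=torch.float16)
    v = torch.randn(1, 4, 128, 64, device="cuda", dtype=torch.float16)
    out = flex_attention(q, k, v, block_mask=block_mask)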
If your hardware does not support Flex Attention +(for CUDA devices, it must be Turing or above), standard SDPA with memory-efficient attention will be used as a fallback, +while retaining the document masking and relative position IDs. diff --git a/docs/source/index.rst b/docs/source/index.rst index 315b7f44e0..f4bc3925ff 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -131,6 +131,7 @@ torchtune tutorials. basics/message_transforms basics/tokenizers basics/prompt_templates + basics/packing .. toctree:: :glob: @@ -144,7 +145,6 @@ torchtune tutorials. tutorials/qlora_finetune tutorials/qat_finetune tutorials/e2e_flow - tutorials/datasets tutorials/memory_optimizations tutorials/llama_kd_tutorial diff --git a/docs/source/recipes/lora_finetune_single_device.rst b/docs/source/recipes/lora_finetune_single_device.rst index be294d325a..83d7a385c0 100644 --- a/docs/source/recipes/lora_finetune_single_device.rst +++ b/docs/source/recipes/lora_finetune_single_device.rst @@ -51,7 +51,6 @@ Interested in seeing this recipe in action? Check out some of our tutorials to s * :ref:`Finetuning Llama2 with LoRA` * :ref:`Finetuning Llama2 with QLoRA` -* :ref:`End-to-End Workflow with torchtune` * :ref:`Fine-tuning Llama3 with Chat Data` * :ref:`Meta Llama3 in torchtune` * :ref:`Fine-Tune Your First LLM` diff --git a/docs/source/tutorials/chat.rst b/docs/source/tutorials/chat.rst index a5d5454d7f..ee29529007 100644 --- a/docs/source/tutorials/chat.rst +++ b/docs/source/tutorials/chat.rst @@ -18,7 +18,7 @@ custom chat dataset for fine-tuning Llama3 Instruct. .. grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites - * Be familiar with :ref:`configuring datasets` + * Be familiar with :ref:`configuring datasets` * Know how to :ref:`download Llama3 Instruct weights ` diff --git a/docs/source/tutorials/datasets.rst b/docs/source/tutorials/datasets.rst deleted file mode 100644 index 781573b89e..0000000000 --- a/docs/source/tutorials/datasets.rst +++ /dev/null @@ -1,562 +0,0 @@ -.. _dataset_tutorial_label: - -==================================== -Configuring Datasets for Fine-Tuning -==================================== - -This tutorial will guide you through how to set up a dataset to fine-tune on. - -.. grid:: 2 - - .. grid-item-card:: :octicon:`mortar-board;1em;` What you will learn - - * How to quickly get started with built-in datasets - * How to use any dataset from Hugging Face Hub - * How to use instruct, chat, or text completion datasets - * How to configure datasets from code, config, or command-line - * How to fully customize your own dataset - - .. grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites - - * Know how to :ref:`configure components from the config` - -Datasets are a core component of fine-tuning workflows that serve as a "steering -wheel" to guide LLM generation for a particular use case. Many publicly shared -open-source datasets have become popular for fine-tuning LLMs and serve as a great -starting point to train your model. torchtune gives you the tools to download external -community datasets, load in custom local datasets, or create your own datasets. - -Built-in datasets ------------------ - -To use one of the built-in datasets in the library, simply import and call the dataset builder -function. You can see a list of all supported datasets :ref:`here`. - -.. code-block:: python - - from torchtune.datasets import alpaca_dataset - - # Load in tokenizer - tokenizer = ... - dataset = alpaca_dataset(tokenizer) - -.. 
code-block:: yaml - - # YAML config - dataset: - _component_: torchtune.datasets.alpaca_dataset - -.. code-block:: bash - - # Command line - tune run full_finetune_single_device --config llama3/8B_full_single_device \ - dataset=torchtune.datasets.alpaca_dataset - -Hugging Face datasets ---------------------- - -We provide first class support for datasets on the Hugging Face hub. Under the hood, -all of our built-in datasets and dataset builders are using Hugging Face's `load_dataset() `_ -to load in your data, whether local or on the hub. - -You can pass in a Hugging Face dataset path to the ``source`` parameter in any of our builders -to specify which dataset on the hub to download or use from a local directory path (see `Local and remote datasets`_). Additionally, all builders accept -any keyword-arguments that ``load_dataset()`` supports. You can see a full list -on Hugging Face's `documentation. `_ - -.. code-block:: python - - from torchtune.datasets import text_completion_dataset - - # Load in tokenizer - tokenizer = ... - dataset = text_completion_dataset( - tokenizer, - source="allenai/c4", - # Keyword-arguments that are passed into load_dataset - split="train", - data_dir="realnewslike", - ) - -.. code-block:: yaml - - # YAML config - dataset: - _component_: torchtune.datasets.text_completion_dataset - source: allenai/c4 - split: train - data_dir: realnewslike - -.. code-block:: bash - - # Command line - tune run full_finetune_single_device --config llama3/8B_full_single_device \ - dataset=torchtune.datasets.text_completion_dataset dataset.source=allenai/c4 \ - dataset.split=train dataset.data_dir=realnewslike - -Setting max sequence length ---------------------------- - -The default collator, :func:`~torchtune.data.padded_collate`, used in all -our training recipes will pad samples to the max sequence length within the batch, -not globally. If you wish to set an upper limit on the max sequence length globally, -you can specify it in the dataset builder with ``max_seq_len``. Any sample in the dataset -that is longer than ``max_seq_len`` will be truncated in :func:`~torchtune.data.truncate`. -The tokenizer's EOS ids are ensured to be the last token, except in :class:`~torchtune.datasets.TextCompletionDataset`. - -Generally, you want the max sequence length returned in each data sample to match the context window -size of your model. You can also decrease this value to reduce memory usage -depending on your hardware constraints. - -.. code-block:: python - - from torchtune.datasets import alpaca_dataset - - # Load in tokenizer - tokenizer = ... - dataset = alpaca_dataset( - tokenizer=tokenizer, - max_seq_len=4096, - ) - -.. code-block:: yaml - - # YAML config - dataset: - _component_: torchtune.datasets.alpaca_dataset - max_seq_len: 4096 - -.. code-block:: bash - - # Command line - tune run full_finetune_single_device --config llama3/8B_full_single_device \ - dataset.max_seq_len=4096 - -Sample packing --------------- - -You can use sample packing with any of the single dataset builders by passing in -:code:`packed=True`. This requires some pre-processing of the dataset which may -slow down time-to-first-batch, but can introduce significant training speedups -depending on the dataset. - -.. code-block:: python - - from torchtune.datasets import alpaca_dataset, PackedDataset - - # Load in tokenizer - tokenizer = ... - dataset = alpaca_dataset( - tokenizer=tokenizer, - packed=True, - ) - print(isinstance(dataset, PackedDataset)) # True - -.. 
code-block:: yaml - - # YAML config - dataset: - _component_: torchtune.datasets.alpaca_dataset - packed: True - -.. code-block:: bash - - # Command line - tune run full_finetune_single_device --config llama3/8B_full_single_device \ - dataset.packed=True - - -Custom unstructured text corpus -------------------------------- - -For continued pre-training, typically a similar data setup to pre-training is used -for a simple text completion task. This means no instruct templates, chat formats, -and minimal special tokens (only BOS and, optionally, EOS). To specify an unstructured text corpus, -you can use the :func:`~torchtune.datasets.text_completion_dataset` builder with -a Hugging Face dataset or a custom local corpus. Here is how to specify it for local -files: - -.. code-block:: python - - from torchtune.datasets import text_completion_dataset - - # Load in tokenizer - tokenizer = ... - dataset = text_completion_dataset( - tokenizer, - source="text", - data_files="path/to/my_data.txt", - split="train", - ) - -.. code-block:: yaml - - # YAML config - dataset: - _component_: torchtune.datasets.text_completion_dataset - source: text - data_files: path/to/my_data.txt - split: train - -.. code-block:: bash - - # Command line - tune run --nproc_per_node 4 full_finetune_distributed --config llama3/8B_full \ - dataset=torchtune.datasets.text_completion_dataset dataset.source=text \ - dataset.data_files=path/to/my_data.txt dataset.split=train - -Custom instruct dataset and instruct templates ----------------------------------------------- - -If you have a custom instruct dataset that's not already provided in the library, -you can use the :func:`~torchtune.datasets.instruct_dataset` builder and specify -the source path. Instruct datasets typically have multiple columns with text that -are formatted into a prompt template. - -To fine-tune an LLM on a particular task, a common approach is to create a fixed instruct -template that guides the model to generate output with a specific goal. Instruct templates -are simply flavor text that structures your inputs for the model. It is model agnostic -and is tokenized normally just like any other text, but it can help condition the model -to respond better to an expected format. For example, the :class:`~torchtune.data.AlpacaInstructTemplate` -structures the data in the following way: - -.. code-block:: python - - "Below is an instruction that describes a task, paired with an input that provides further context. " - "Write a response that appropriately completes the request.\n\n" - "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n" - -Here is an example of a sample that is formatted with :class:`~torchtune.data.AlpacaInstructTemplate`: - -.. code-block:: python - - from torchtune.data import AlpacaInstructTemplate - - sample = { - "instruction": "Classify the following into animals, plants, and minerals", - "input": "Oak tree, copper ore, elephant", - } - prompt = AlpacaInstructTemplate.format(sample) - print(prompt) - # Below is an instruction that describes a task, paired with an input that provides further context. - # Write a response that appropriately completes the request. - # - # ### Instruction: - # Classify the following into animals, plants, and minerals - # - # ### Input: - # Oak tree, copper ore, elephant - # - # ### Response: - # - -We provide :ref:`other instruct templates ` -for common tasks such summarization and grammar correction. 
If you need to create your own -instruct template for a custom task, you can inherit from :class:`~torchtune.data.InstructTemplate` -and create your own class. - -.. code-block:: python - - from torchtune.datasets import instruct_dataset - from torchtune.data import InstructTemplate - - class CustomTemplate(InstructTemplate): - # Define the template as string with {} as placeholders for data columns - template = ... - - # Implement this method - @classmethod - def format( - cls, sample: Mapping[str, Any], column_map: Optional[Dict[str, str]] = None - ) -> str: - ... - - # Load in tokenizer - tokenizer = ... - dataset = instruct_dataset( - tokenizer=tokenizer, - source="my/dataset/path", - template="import.path.to.CustomTemplate", - ) - -.. code-block:: yaml - - # YAML config - dataset: - _component_: torchtune.datasets.instruct_dataset - source: my/dataset/path - template: import.path.to.CustomTemplate - -.. code-block:: bash - - # Command line - tune run full_finetune_single_device --config llama3/8B_full_single_device \ - dataset=torchtune.datasets.instruct_dataset dataset.source=my/dataset/path \ - dataset.template=import.path.to.CustomTemplate - - -torchtune uses :code:`importlib.import_module` (see ``importlib`` `docs `_ for more details) -to locate components from their dotpaths. You can place your custom template class -in any Python file as long as the file is accessible by Python's import mechanism. -This means the module should be in a directory that is included in Python's search -paths (:code:`sys.path`). This often includes: - -- The current directory from which your Python interpreter or script is run. -- Directories where Python packages are installed (like :code:`site-packages`). -- Any directories added to :code:`sys.path` at runtime using :code:`sys.path.append` or through the :code:`PYTHONPATH` environment variable. - - -Custom chat dataset and chat formats ------------------------------------- - -If you have a custom chat/conversational dataset that's not already provided in the library, -you can use the :func:`~torchtune.datasets.chat_dataset` builder and specify -the source path. Chat datasets typically have a single column with multiple back -and forth messages between the user and assistant. - -Chat formats are similar to instruct templates, except that they format system, -user, and assistant messages into a list of messages (see :class:`~torchtune.data.ChatFormat`) -for a conversational dataset. These can be configured quite similarly to instruct -datasets. - -Here is how messages would be formatted using the :class:`~torchtune.data.Llama2ChatFormat`: - -.. code-block:: python - - from torchtune.data import Llama2ChatFormat, Message - - messages = [ - Message( - role="system", - content="You are a helpful, respectful, and honest assistant.", - ), - Message( - role="user", - content="I am going to Paris, what should I see?", - ), - Message( - role="assistant", - content="Paris, the capital of France, is known for its stunning architecture..." - ), - ] - formatted_messages = Llama2ChatFormat.format(messages) - print(formatted_messages) - # [ - # Message( - # role="user", - # content="[INST] <>\nYou are a helpful, respectful and honest assistant.\n<>\n\n" - # "I am going to Paris, what should I see? [/INST] ", - # ), - # Message( - # role="assistant", - # content="Paris, the capital of France, is known for its stunning architecture..." - # ), - # ] - -Note that the system message is now incorporated in the user message. 
If you create custom ChatFormats -you can also add more advanced behavior. - -.. code-block:: python - - from torchtune.datasets import chat_dataset - from torchtune.data import ChatFormat - - class CustomChatFormat(ChatFormat): - # Define templates for system, user, assistant messages - # as strings with {} as placeholders for message content - system = ... - user = ... - assistant = ... - - # Implement this method - @classmethod - def format( - cls, - sample: List[Message], - ) -> List[Message]: - ... - - # Load in tokenizer - tokenizer = ... - dataset = chat_dataset( - tokenizer=tokenizer, - source="my/dataset/path", - split="train", - conversation_style="openai", - chat_format="import.path.to.CustomChatFormat", - ) - -.. code-block:: yaml - - # YAML config - dataset: - _component_: torchtune.datasets.chat_dataset - source: my/dataset/path - conversation_style: openai - chat_format: import.path.to.CustomChatFormat - -.. code-block:: bash - - # Command line - tune run full_finetune_single_device --config llama3/8B_full_single_device \ - dataset=torchtune.datasets.chat_dataset dataset.source=my/dataset/path \ - dataset.conversation_style=openai dataset.chat_format=import.path.to.CustomChatFormat - - -Multiple in-memory datasets ---------------------------- - -It is also possible to train on multiple datasets and configure them individually using -our :class:`~torchtune.datasets.ConcatDataset` interface. You can even mix instruct and chat datasets -or other custom datasets. - -.. code-block:: yaml - - # YAML config - dataset: - - _component_: torchtune.datasets.instruct_dataset - source: vicgalle/alpaca-gpt4 - template: torchtune.data.AlpacaInstructTemplate - split: train - train_on_input: True - - _component_: torchtune.datasets.instruct_dataset - source: samsum - template: torchtune.data.SummarizeTemplate - column_map: - output: summary - split: train - train_on_input: False - - _component_: torchtune.datasets.chat_dataset - ... - - -Local and remote datasets -------------------------- - -To use a dataset saved on your local hard drive, simply specify the file type for -``source`` and pass in the ``data_files`` argument using any of the dataset -builder functions. We support all `file types `_ -supported by Hugging Face's ``load_dataset``, including csv, json, txt, and more. - -.. code-block:: python - - from torchtune.datasets import instruct_dataset - - # Load in tokenizer - tokenizer = ... - # Local files - dataset = instruct_dataset( - tokenizer=tokenizer, - source="csv", - split="train", - template="import.path.to.CustomTemplate" - data_files="path/to/my/data.csv", - ) - # Remote files - dataset = instruct_dataset( - tokenizer=tokenizer, - source="json", - split="train", - template="import.path.to.CustomTemplate" - data_files="https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v2.0.json", - # You can also pass in any kwarg that load_dataset accepts - field="data", - ) - -.. code-block:: yaml - - # YAML config - local files - dataset: - _component_: torchtune.datasets.instruct_dataset - source: csv - template: import.path.to.CustomTemplate - data_files: path/to/my/data.csv - - # YAML config - remote files - dataset: - _component_: torchtune.datasets.instruct_dataset - source: json - template: import.path.to.CustomTemplate - data_files: https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v2.0.json - field: data - -.. 
code-block:: bash - - # Command line - local files - tune run full_finetune_single_device --config llama3/8B_full_single_device \ - dataset=torchtune.datasets.chat_dataset dataset.source=csv \ - dataset.template=import.path.to.CustomTemplate dataset.data_files=path/to/my/data.csv - -Fully customized datasets -------------------------- - -More advanced tasks and dataset formats that don't fit into the templating and processing -that :class:`~torchtune.datasets.SFTDataset` and :class:`~torchtune.datasets.TextCompletionDataset` provide may require -you to create your own dataset class for more flexibility. Let's walk through the :class:`~torchtune.datasets.PreferenceDataset`, -which has custom functionality for RLHF preference data, as an example to understand what you'll need to do. - -.. code-block:: python - - chosen_message = [ - Message(role="user", content=prompt, masked=True), - Message(role="assistant", content=transformed_sample[key_chosen]), - ] - rejected_message = [ - Message(role="user", content=prompt, masked=True), - Message(role="assistant", content=transformed_sample[key_rejected]), - ] - - chosen_input_ids, c_masks = self._tokenizer.tokenize_messages( - chosen_message, self.max_seq_len - ) - chosen_labels = list( - np.where(c_masks, CROSS_ENTROPY_IGNORE_IDX, chosen_input_ids) - ) - - rejected_input_ids, r_masks = self._tokenizer.tokenize_messages( - rejected_message, self.max_seq_len - ) - rejected_labels = list( - np.where(r_masks, CROSS_ENTROPY_IGNORE_IDX, rejected_input_ids) - ) - -For a specific dataset that's easy to customize from the config, you can create -a builder function. This is the builder function for the :func:`~torchtune.datasets.stack_exchanged_paired_dataset`, -which creates a :class:`~torchtune.datasets.PreferenceDataset` configured to use -a paired dataset from Hugging Face. Notice that we've also had -to add a custom instruct template as well. - -.. code-block:: python - - def stack_exchanged_paired_dataset( - tokenizer: ModelTokenizer, - max_seq_len: int = 1024, - ) -> PreferenceDataset: - return PreferenceDataset( - tokenizer=tokenizer, - source="lvwerra/stack-exchange-paired", - template=StackExchangedPairedTemplate(), - column_map={ - "prompt": "question", - "chosen": "response_j", - "rejected": "response_k", - }, - max_seq_len=max_seq_len, - split="train", - data_dir="data/rl", - ) - -Now we can easily specify our custom dataset from the config, or from command-line. - -.. code-block:: yaml - - # This is how you would configure the Alpaca dataset using the builder - dataset: - _component_: torchtune.datasets.stack_exchanged_paired_dataset - max_seq_len: 512 - -.. code-block:: bash - - # Command line - local files - tune run full_finetune_single_device --config llama3/8B_full_single_device \ - dataset=torchtune.datasets.stack_exchanged_paired_dataset dataset.max_seq_len=512 diff --git a/recipes/configs/generation.yaml b/recipes/configs/generation.yaml index efc93700a1..e9c5d0d4f5 100644 --- a/recipes/configs/generation.yaml +++ b/recipes/configs/generation.yaml @@ -27,11 +27,12 @@ tokenizer: _component_: torchtune.models.llama2.llama2_tokenizer path: /tmp/Llama-2-7b-hf/tokenizer.model max_seq_len: null + prompt_template: null # Generation arguments; defaults taken from gpt-fast -prompt: "Tell me a joke?" -instruct_template: null -chat_format: null +prompt: + system: null + user: "Tell me a joke." 
max_new_tokens: 300 temperature: 0.6 # 0.8 and 0.6 are popular values to try top_k: 300 diff --git a/recipes/generate.py b/recipes/generate.py index fea44ddace..56723b04bd 100644 --- a/recipes/generate.py +++ b/recipes/generate.py @@ -6,15 +6,14 @@ import itertools import sys import time -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List import torch from omegaconf import DictConfig from torch import nn from torchtune import config, generation, training, utils -from torchtune.config._utils import _get_component_from_path -from torchtune.data import ChatFormat, InstructTemplate, Message +from torchtune.data import Message, Role from torchtune.training import FullModelTorchTuneCheckpointer logger = utils.get_logger("DEBUG") @@ -100,54 +99,28 @@ def _setup_model( def convert_prompt_to_tokens( self, - prompt: Union[DictConfig, str], - chat_format: Optional[ChatFormat], - instruct_template: Optional[InstructTemplate], - ) -> List[Message]: + prompt: Dict[Role, str], + ) -> List[int]: """ - Either: - (1) a raw string is passed as the prompt, in which case we call tokenizer.encode directly, or - (2) a DictConfig is passed as the prompt. In this case there are three possibilities: - (a) an InstructTemplate is provided. Since instruct templates output a string, we will - call tokenizer.encode on the output of the instruct template. - (b) a ChatFormat is provided. Since chat formats output a list of messages, we will - call tokenizer.tokenize_messages on the output of the chat format. - (c) neither an InstructTemplate nor a ChatFormat is provided. In this case we will - convert the DictConfig to a list of messages and call tokenizer.tokenize_messages directly. + Convert the prompt string to a user message with optional system messages + and tokenize using the prompt template defined on the tokenizer. 
""" - - # Should only be chat-style prompt or instruct-style prompt - if chat_format and instruct_template: - raise ValueError( - "Cannot pass both chat format and instruct template for generation" - ) - - # If instruct template is provided, assert that the prompt is a DictConfig - # and apply it - if instruct_template: - if not isinstance(prompt, DictConfig): - raise ValueError("Cannot apply instruct template to raw string") - instruct_template = _get_component_from_path(instruct_template) - prompt = instruct_template.format(prompt) - - # To hit this block, either the raw prompt is a string or an - # instruct template has been provided to convert it to a string - if isinstance(prompt, str): - return self._tokenizer.encode(prompt, add_bos=True, add_eos=False) - - # dict.items() will respect order for Python >= 3.7 - else: - messages = [Message(role=k, content=v) for k, v in prompt.items()] - messages += [Message(role="assistant", content="")] - if chat_format: - chat_format = _get_component_from_path(chat_format) - messages = chat_format.format(messages) - return self._tokenizer.tokenize_messages(messages)[0] + messages = [] + if "system" in prompt and prompt["system"] is not None: + messages.append(Message(role="system", content=prompt["system"])) + messages.extend( + [ + Message(role="user", content=prompt["user"]), + # Empty assistant message to kick-start generation + Message(role="assistant", content=""), + ] + ) + return self._tokenizer({"messages": messages}, inference=True)["tokens"] @torch.inference_mode() def generate(self, cfg: DictConfig): tokens = self.convert_prompt_to_tokens( - cfg.prompt, cfg.get("chat_format", None), cfg.get("instruct_template", None) + cfg.prompt, ) prompt = torch.tensor(tokens, dtype=torch.int, device=self._device) diff --git a/tests/test_utils.py b/tests/test_utils.py index 8ba28f1bf4..bcb26285a1 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -168,32 +168,6 @@ def image_id(self): return -2 -class DummyChatFormat: - - B_SYS, E_SYS = "System:\n", "\n" - B_INST, E_INST = "User:\n", "\nAssistant:\n" - B_ASST, E_ASST = "", "" - system = f"{B_SYS}{{content}}{E_SYS}" - user = f"{B_INST}{{content}}{E_INST}" - assistant = f"{B_ASST}{{content}}{E_ASST}" - - @classmethod - def format( - cls, - messages, - ): - formats = {"system": cls.system, "user": cls.user, "assistant": cls.assistant} - formatted_dialogue = [] - for message in messages: - content = formats.get(message.role).format( - content=message.content[0]["content"] - ) - formatted_dialogue.append( - Message(role=message.role, content=content, masked=message.masked), - ) - return formatted_dialogue - - DummyPromptTemplate = partial( PromptTemplate, template={ diff --git a/tests/torchtune/data/test_converters.py b/tests/torchtune/data/test_converters.py deleted file mode 100644 index 8c0265630e..0000000000 --- a/tests/torchtune/data/test_converters.py +++ /dev/null @@ -1,94 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
- -from tests.test_utils import ( - assert_dialogue_equal, - CHAT_SAMPLE, - MESSAGE_SAMPLE, - MESSAGE_SAMPLE_TRAIN_ON_INPUT, -) -from torchtune.data import get_openai_messages, get_sharegpt_messages - - -class TestShareGPTToLlama2Messages: - samples = { - "conversations": [ - { - "from": "system", - "value": CHAT_SAMPLE["system"], - }, - { - "from": "human", - "value": CHAT_SAMPLE["user"], - }, - { - "from": "gpt", - "value": CHAT_SAMPLE["assistant"], - }, - ] - } - - def test_conversion(self): - converted_messages = get_sharegpt_messages(self.samples) - assert_dialogue_equal(converted_messages, MESSAGE_SAMPLE) - - def test_conversion_train_on_input(self): - converted_messages = get_sharegpt_messages(self.samples, train_on_input=True) - assert_dialogue_equal(converted_messages, MESSAGE_SAMPLE_TRAIN_ON_INPUT) - - -class TestOpenAIToLlama2Messages: - samples_1 = { - "id": "DUMMY", - "conversations": [ - { - "role": "system", - "content": CHAT_SAMPLE["system"], - }, - { - "role": "user", - "content": CHAT_SAMPLE["user"], - }, - { - "role": "assistant", - "content": CHAT_SAMPLE["assistant"], - }, - ], - } - - samples_2 = { - "id": "DUMMY", - "messages": [ - { - "role": "system", - "content": CHAT_SAMPLE["system"], - }, - { - "role": "user", - "content": CHAT_SAMPLE["user"], - }, - { - "role": "assistant", - "content": CHAT_SAMPLE["assistant"], - }, - ], - } - - def test_conversion_conversations_key(self): - converted_messages_1 = get_openai_messages(self.samples_1) - assert_dialogue_equal(converted_messages_1, MESSAGE_SAMPLE) - - def test_conversion_messages_key(self): - converted_messages_2 = get_openai_messages(self.samples_2) - assert_dialogue_equal(converted_messages_2, MESSAGE_SAMPLE) - - def test_conversion_conversations_key_train_on_input(self): - converted_messages_1 = get_openai_messages(self.samples_1, train_on_input=True) - assert_dialogue_equal(converted_messages_1, MESSAGE_SAMPLE_TRAIN_ON_INPUT) - - def test_conversion_messages_key_train_on_input(self): - converted_messages_2 = get_openai_messages(self.samples_2, train_on_input=True) - assert_dialogue_equal(converted_messages_2, MESSAGE_SAMPLE_TRAIN_ON_INPUT) diff --git a/tests/torchtune/datasets/test_chat_dataset.py b/tests/torchtune/datasets/test_chat_dataset.py index 8e99ad92ae..4c6af92aa0 100644 --- a/tests/torchtune/datasets/test_chat_dataset.py +++ b/tests/torchtune/datasets/test_chat_dataset.py @@ -4,19 +4,14 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
-import pytest from tests.common import ASSETS -from tests.test_utils import DummyChatFormat, DummyTokenizer +from tests.test_utils import DummyTokenizer from torchtune.data._common import CROSS_ENTROPY_IGNORE_IDX from torchtune.datasets import chat_dataset class TestChatDataset: - @pytest.fixture - def chat_format(self): - return DummyChatFormat - - def test_get_item(self, chat_format): + def test_get_item(self): expected_tokenized_prompts = [ [ 0, diff --git a/tests/torchtune/datasets/test_instruct_dataset.py b/tests/torchtune/datasets/test_instruct_dataset.py index 04f3c2f49c..b0dc25112a 100644 --- a/tests/torchtune/datasets/test_instruct_dataset.py +++ b/tests/torchtune/datasets/test_instruct_dataset.py @@ -7,7 +7,6 @@ import pytest from tests.common import ASSETS from tests.test_utils import DummyTokenizer -from torchtune.data import InstructTemplate from torchtune.data._common import CROSS_ENTROPY_IGNORE_IDX from torchtune.datasets import instruct_dataset @@ -18,14 +17,6 @@ def dummy_transform(sample): return sample -class DummyTemplate(InstructTemplate): - template = "Instruction:\n{instruction}\n\nResponse:\n" - - @classmethod - def format(cls, sample, column_map): - return cls.template.format(**sample) - - class TestInstructDataset: @pytest.mark.parametrize("train_on_input", [True, False]) def test_get_item(self, train_on_input): diff --git a/torchtune/data/__init__.py b/torchtune/data/__init__.py index 0f7e00ae99..0d1c1f12bf 100644 --- a/torchtune/data/__init__.py +++ b/torchtune/data/__init__.py @@ -4,7 +4,6 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from torchtune.data._chat_formats import ChatFormat from torchtune.data._collate import ( left_pad_sequence, padded_collate, @@ -14,8 +13,6 @@ padded_collate_tiled_images_and_mask, ) from torchtune.data._common import CROSS_ENTROPY_IGNORE_IDX -from torchtune.data._converters import get_openai_messages, get_sharegpt_messages -from torchtune.data._instruct_templates import InstructTemplate from torchtune.data._messages import ( AlpacaToMessages, ChosenRejectedToMessages, @@ -37,10 +34,8 @@ from torchtune.data._utils import format_content_with_images, load_image, truncate __all__ = [ - "ChatFormat", "CROSS_ENTROPY_IGNORE_IDX", "GrammarErrorCorrectionTemplate", - "InstructTemplate", "SummarizeTemplate", "OpenAIToMessages", "ShareGPTToMessages", @@ -56,8 +51,6 @@ "ChosenRejectedToMessages", "QuestionAnswerTemplate", "ChatMLTemplate", - "get_openai_messages", - "get_sharegpt_messages", "padded_collate_sft", "padded_collate_dpo", "left_pad_sequence", diff --git a/torchtune/data/_chat_formats.py b/torchtune/data/_chat_formats.py deleted file mode 100644 index 3c08711b44..0000000000 --- a/torchtune/data/_chat_formats.py +++ /dev/null @@ -1,44 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -from abc import ABC, abstractmethod -from typing import Dict, List, Tuple - -from torchtune.data._messages import Message, Role - - -class ChatFormat(ABC): - """ - Warning: - This class is deprecated and will be removed in a future release. Please use - :class:`~torchtune.data.PromptTemplate` for custom chat formats. - - Interface for chat formats. Each chat format should include tags for system, - user, and assistant roles that are prepended or appended to the message - content. 
- """ - - # Template should map role to a tuple containing the tag to prepend to the text - # and tag to append to the text. Leave as empty strings to not prepend or append - template: Dict[Role, Tuple[str, str]] - - @classmethod - @abstractmethod - def format( - cls, - sample: List[Message], - ) -> List[Message]: - """ - Format each role's message(s) according to the chat format - - Args: - sample (List[Message]): a single conversation, structured as a list - of `Message` objects - - Returns: - The formatted list of messages - """ - pass diff --git a/torchtune/data/_converters.py b/torchtune/data/_converters.py deleted file mode 100644 index d54ba7c008..0000000000 --- a/torchtune/data/_converters.py +++ /dev/null @@ -1,154 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -from typing import Any, List, Mapping - -from torchtune.data._messages import Message -from torchtune.utils._logging import deprecated - - -@deprecated( - msg="Please use an instance of `torchtune.data.ShareGPTToMessages` as the " - "`message_transform` argument for `torchtune.datasets.SFTDataset` instead." -) -def get_sharegpt_messages( - sample: Mapping[str, Any], train_on_input: bool = False -) -> List[Message]: - """ - Warning: - This class is deprecated and will be removed in a future release. Please use - :class:`~torchtune.data.ShareGPTToMessages` instead. The following are equivalent: - - .. code-block:: python - - # Deprecated - transformed_sample = get_sharegpt_messages(sample, train_on_input=True) - - # New - transformed_sample = ShareGPTToMessages(train_on_input=True)(sample) - - Convert a chat sample adhering to the ShareGPT json structure to torchtune's :class:`~torchtune.data.Message` - structure. - - ShareGPT follows:: - - { - "conversations": [ - { - "from": , - "value": , - }, - ... - ] - } - - :class:`~torchtune.data.Message` follows:: - - [ - { - "role": , - "content": , - }, - ... - ] - - Args: - sample (Mapping[str, Any]): a single data sample with "conversations" field pointing - to a list of dict messages. - train_on_input (bool): whether the prompt should remain unmasked. Default: False - - Returns: - List[Message]: A list of messages with "role" and "content" fields. - """ - role_map = {"system": "system", "human": "user", "gpt": "assistant"} - conversations = sample["conversations"] - - messages = [] - for message in conversations: - role = role_map[message["from"]] - content = message["value"] - masked = (role != "assistant") and (not train_on_input) - messages.append( - Message( - role=role, content=[{"type": "text", "content": content}], masked=masked - ) - ) - return messages - - -@deprecated( - msg="Please use an instance of `torchtune.data.OpenAIToMessages` as the " - "`message_transform` argument for `torchtune.datasets.SFTDataset` instead." -) -def get_openai_messages( - sample: Mapping[str, Any], - train_on_input: bool = False, -) -> List[Message]: - """ - Warning: - This class is deprecated and will be removed in a future release. Please use - :class:`~torchtune.data.OpenAIToMessages` instead. The following are equivalent: - - .. 
code-block:: python - - # Deprecated - transformed_sample = get_openai_messages(sample, train_on_input=True) - - # New - transformed_sample = OpenAIToMessages(train_on_input=True)(sample) - - Convert a chat sample adhering to the OpenAI API json structure to torchtune's :class:`~torchtune.data.Message` - structure. - - OpenAI API `standard chat format `_ follows:: - - { - # key could be "messages" OR "conversations" - "messages": [ - { - "role": , - "content": , - }, - ... - ] - } - - :class:`~torchtune.data.Message` follows:: - - [ - { - "role": , - "content": , - }, - ... - ] - - Args: - sample (Mapping[str, Any]): a single data sample with "conversations" field pointing - to a list of dict messages. - train_on_input (bool): whether the prompt should remain unmasked. Default: False - - Raises: - ValueError: If the sample does not contain "messages" or "conversations" key. - - Returns: - List[Message]: A list of messages with "role" and "content" fields. - """ - if "messages" in sample: - messages_key = "messages" - elif "conversations" in sample: - messages_key = "conversations" - else: - raise ValueError( - f"Sample does not contain 'messages' or 'conversations' key. Existing keys: {sample.keys()}" - ) - conversations = sample[messages_key] - - messages = [] - for message in conversations: - message["masked"] = (message["role"] != "assistant") and (not train_on_input) - messages.append(Message.from_dict(message)) - return messages diff --git a/torchtune/data/_instruct_templates.py b/torchtune/data/_instruct_templates.py deleted file mode 100644 index 4e8f3ad5a5..0000000000 --- a/torchtune/data/_instruct_templates.py +++ /dev/null @@ -1,41 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -from abc import ABC, abstractmethod -from typing import Any, Dict, Mapping, Optional - - -class InstructTemplate(ABC): - """ - Warning: - This class is deprecated and will be removed in a future release. Please use - :class:`~torchtune.data.PromptTemplate` for custom instruct templates. - - Interface for instruction templates. Each template should include the template - prompt with placeholders for the data inputs. - """ - - template = "" - - @classmethod - @abstractmethod - def format( - cls, sample: Mapping[str, Any], column_map: Optional[Dict[str, str]] = None - ) -> str: - """ - Format the prompt template with the given arguments. - - Args: - sample (Mapping[str, Any]): a single data sample with various fields - column_map (Optional[Dict[str, str]]): a mapping from the expected - placeholder names in the template to the column names in the sample. - If None, assume these are identical. Note: if the sample output is not named - as "output" in the dataset, you always need to map it to "output" in column_map. 
- - Returns: - The formatted prompt - """ - pass diff --git a/torchtune/datasets/_concat.py b/torchtune/datasets/_concat.py index 304a605641..bf85bf0939 100644 --- a/torchtune/datasets/_concat.py +++ b/torchtune/datasets/_concat.py @@ -52,14 +52,11 @@ class ConcatDataset(Dataset): dataset: - _component_: torchtune.datasets.instruct_dataset source: vicgalle/alpaca-gpt4 - template: torchtune.data.AlpacaInstructTemplate split: train train_on_input: True - _component_: torchtune.datasets.instruct_dataset source: samsum - template: torchtune.data.SummarizeTemplate column_map: {"output": "summary"} - output: summary split: train train_on_input: False
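As a minimal sketch of the new prompt handling in ``recipes/generate.py``: the ``prompt`` entry in ``generation.yaml`` is now a mapping with an optional ``system`` field and a ``user`` field, which is converted to :class:`~torchtune.data.Message` objects and tokenized through the prompt template defined on the tokenizer. The tokenizer checkpoint path below is illustrative only.

.. code-block:: python

    from torchtune.data import Message
    from torchtune.models.llama3 import llama3_tokenizer

    # Illustrative checkpoint path -- substitute your own
    tokenizer = llama3_tokenizer(
        path="/tmp/Llama-3.2-1B-Instruct/original/tokenizer.model",
    )

    # Equivalent of the `prompt:` section in generation.yaml
    prompt = {"system": None, "user": "Tell me a joke."}

    messages = []
    if prompt.get("system") is not None:
        messages.append(Message(role="system", content=prompt["system"]))
    messages.extend(
        [
            Message(role="user", content=prompt["user"]),
            # Empty assistant message to kick-start generation
            Message(role="assistant", content=""),
        ]
    )

    # Model tokenizers are callable transforms; inference=True leaves the trailing
    # assistant turn open so the model can continue generating from it
    tokens = tokenizer({"messages": messages}, inference=True)["tokens"]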