From a3fe457e89bfb9ca86dbea398ab00ebbc56fb81c Mon Sep 17 00:00:00 2001 From: RdoubleA Date: Mon, 22 Jul 2024 18:44:19 -0700 Subject: [PATCH 01/16] initial commit --- torchtune/models/gemma/_tokenizer.py | 23 +++++++++++++++-- torchtune/models/llama2/_tokenizer.py | 23 +++++++++++++++-- torchtune/models/llama3/_tokenizer.py | 34 ++++++++++++++++++++------ torchtune/models/mistral/_tokenizer.py | 23 +++++++++++++++-- torchtune/models/phi3/_tokenizer.py | 23 +++++++++++++++-- 5 files changed, 111 insertions(+), 15 deletions(-) diff --git a/torchtune/models/gemma/_tokenizer.py b/torchtune/models/gemma/_tokenizer.py index bed4f8606c..eb7fea4c1d 100644 --- a/torchtune/models/gemma/_tokenizer.py +++ b/torchtune/models/gemma/_tokenizer.py @@ -4,7 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import List, Optional, Tuple +from typing import Any, List, Mapping, Optional, Tuple from torchtune.data import Message from torchtune.modules.tokenizers import ( @@ -12,11 +12,12 @@ SentencePieceBaseTokenizer, tokenize_messages_no_special_tokens, ) +from torchtune.modules.transforms import Transform WHITESPACE_CHARS = [" ", "\n", "\t", "\r", "\v"] -class GemmaTokenizer(ModelTokenizer): +class GemmaTokenizer(ModelTokenizer, Transform): """ Gemma's implementation of the SentencePiece tokenizer @@ -119,3 +120,21 @@ def tokenize_messages( eos_id=self.eos_id, max_seq_len=max_seq_len, ) + + def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]: + """ + Apply ``tokenize_messages`` to the "messages" field in the sample. + + Args: + sample (Mapping[str, Any]): A sample with a "messages" field containing + a List[Message] to tokenize + + Returns: + Mapping[str, Any]: The sample with added "tokens" and "mask" fields + and the "messages" field removed. + """ + messages = sample.pop("messages") + tokens, mask = self.tokenize_messages(messages) + sample["tokens"] = tokens + sample["mask"] = mask + return sample diff --git a/torchtune/models/llama2/_tokenizer.py b/torchtune/models/llama2/_tokenizer.py index 4358a48566..d8c822d4c2 100644 --- a/torchtune/models/llama2/_tokenizer.py +++ b/torchtune/models/llama2/_tokenizer.py @@ -4,7 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import List, Optional, Tuple +from typing import Any, List, Mapping, Optional, Tuple from torchtune.data import Message from torchtune.modules.tokenizers import ( @@ -12,11 +12,12 @@ SentencePieceBaseTokenizer, tokenize_messages_no_special_tokens, ) +from torchtune.modules.transforms import Transform WHITESPACE_CHARS = [" ", "\n", "\t", "\r", "\v"] -class Llama2Tokenizer(ModelTokenizer): +class Llama2Tokenizer(ModelTokenizer, Transform): """ Llama2's implementation of the SentencePiece tokenizer. Llama2Tokenizer does not include any additional special tokens. The prompt template described in @@ -131,3 +132,21 @@ def tokenize_messages( eos_id=self.eos_id, max_seq_len=max_seq_len, ) + + def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]: + """ + Apply ``tokenize_messages`` to the "messages" field in the sample. + + Args: + sample (Mapping[str, Any]): A sample with a "messages" field containing + a List[Message] to tokenize + + Returns: + Mapping[str, Any]: The sample with added "tokens" and "mask" fields + and the "messages" field removed. 
+ """ + messages = sample.pop("messages") + tokens, mask = self.tokenize_messages(messages) + sample["tokens"] = tokens + sample["mask"] = mask + return sample diff --git a/torchtune/models/llama3/_tokenizer.py b/torchtune/models/llama3/_tokenizer.py index 0f1509b4ca..fa61d7b66c 100644 --- a/torchtune/models/llama3/_tokenizer.py +++ b/torchtune/models/llama3/_tokenizer.py @@ -4,10 +4,11 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Dict, List, Optional, Tuple +from typing import Any, Dict, List, Mapping, Optional, Tuple from torchtune.data import Message, truncate from torchtune.modules.tokenizers import ModelTokenizer, TikTokenBaseTokenizer +from torchtune.modules.transforms import Transform CL100K_PATTERN = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+""" # noqa @@ -38,7 +39,7 @@ LLAMA3_SPECIAL_TOKENS = {**SPECIAL_TOKENS, **RESERVED_TOKENS} -class Llama3Tokenizer(ModelTokenizer): +class Llama3Tokenizer(ModelTokenizer, Transform): """ tiktoken tokenizer configured with Llama3 Instruct's special tokens, as described in https://llama.meta.com/docs/model-cards-and-prompt-formats/meta-llama-3 @@ -60,6 +61,7 @@ def __init__( self, path: str, special_tokens: Optional[Dict[str, int]] = None, + max_seq_len: Optional[int] = None, ): self.special_tokens = ( special_tokens if special_tokens is not None else LLAMA3_SPECIAL_TOKENS @@ -95,6 +97,7 @@ def __init__( eos_id=self.eos_id, special_tokens=self.special_tokens, ) + self.max_seq_len = max_seq_len def _validate_special_tokens( self, @@ -216,7 +219,6 @@ def tokenize_message( def tokenize_messages( self, messages: List[Message], - max_seq_len: Optional[int] = None, add_eos: bool = True, ) -> Tuple[List[int], List[bool]]: """ @@ -237,14 +239,32 @@ def tokenize_messages( tokens = tokens + tokenized_message mask = mask + ([message.masked] * len(tokenized_message)) - if max_seq_len and len(tokens) >= max_seq_len: + if self.max_seq_len and len(tokens) >= self.max_seq_len: break if add_eos: tokens = tokens + [self.eos_id] mask = mask + [True] - if max_seq_len: - tokens = truncate(tokens, max_seq_len, self.eos_id) - mask = truncate(mask, max_seq_len, True) + if self.max_seq_len: + tokens = truncate(tokens, self.max_seq_len, self.eos_id) + mask = truncate(mask, self.max_seq_len, True) return tokens, mask + + def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]: + """ + Apply ``tokenize_messages`` to the "messages" field in the sample. + + Args: + sample (Mapping[str, Any]): A sample with a "messages" field containing + a List[Message] to tokenize + + Returns: + Mapping[str, Any]: The sample with added "tokens" and "mask" fields + and the "messages" field removed. + """ + messages = sample.pop("messages") + tokens, mask = self.tokenize_messages(messages) + sample["tokens"] = tokens + sample["mask"] = mask + return sample diff --git a/torchtune/models/mistral/_tokenizer.py b/torchtune/models/mistral/_tokenizer.py index f8a2f4b645..ac35f53465 100644 --- a/torchtune/models/mistral/_tokenizer.py +++ b/torchtune/models/mistral/_tokenizer.py @@ -4,7 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
-from typing import List, Optional, Tuple +from typing import Any, List, Mapping, Optional, Tuple from torchtune.data import Message from torchtune.modules.tokenizers import ( @@ -12,11 +12,12 @@ SentencePieceBaseTokenizer, tokenize_messages_no_special_tokens, ) +from torchtune.modules.transforms import Transform WHITESPACE_CHARS = [" ", "\n", "\t", "\r", "\v"] -class MistralTokenizer(ModelTokenizer): +class MistralTokenizer(ModelTokenizer, Transform): """ Mistral's implementation of the SentencePiece tokenizer @@ -147,3 +148,21 @@ def tokenize_messages( eos_id=self.eos_id, max_seq_len=max_seq_len, ) + + def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]: + """ + Apply ``tokenize_messages`` to the "messages" field in the sample. + + Args: + sample (Mapping[str, Any]): A sample with a "messages" field containing + a List[Message] to tokenize + + Returns: + Mapping[str, Any]: The sample with added "tokens" and "mask" fields + and the "messages" field removed. + """ + messages = sample.pop("messages") + tokens, mask = self.tokenize_messages(messages) + sample["tokens"] = tokens + sample["mask"] = mask + return sample diff --git a/torchtune/models/phi3/_tokenizer.py b/torchtune/models/phi3/_tokenizer.py index 888be82c1c..958b5ac64e 100644 --- a/torchtune/models/phi3/_tokenizer.py +++ b/torchtune/models/phi3/_tokenizer.py @@ -4,10 +4,11 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Dict, List, Optional, Tuple +from typing import Any, Dict, List, Mapping, Optional, Tuple from torchtune.data import Message, truncate from torchtune.modules.tokenizers import ModelTokenizer, SentencePieceBaseTokenizer +from torchtune.modules.transforms import Transform PHI3_SPECIAL_TOKENS = { "<|endoftext|>": 32000, @@ -24,7 +25,7 @@ } -class Phi3MiniTokenizer(ModelTokenizer): +class Phi3MiniTokenizer(ModelTokenizer, Transform): """ SentencePiece tokenizer configured with Phi3 Mini's special tokens. @@ -219,3 +220,21 @@ def tokenize_messages( mask = truncate(mask, max_seq_len, message.masked) return tokenized_messages, mask + + def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]: + """ + Apply ``tokenize_messages`` to the "messages" field in the sample. + + Args: + sample (Mapping[str, Any]): A sample with a "messages" field containing + a List[Message] to tokenize + + Returns: + Mapping[str, Any]: The sample with added "tokens" and "mask" fields + and the "messages" field removed. 
+ """ + messages = sample.pop("messages") + tokens, mask = self.tokenize_messages(messages) + sample["tokens"] = tokens + sample["mask"] = mask + return sample From 969909def631250b9122781b12ecf56036509fd7 Mon Sep 17 00:00:00 2001 From: RdoubleA Date: Fri, 26 Jul 2024 16:49:02 -0700 Subject: [PATCH 02/16] flesh out prompt templates --- recipes/full_finetune_distributed.py | 4 +- tests/test_utils.py | 17 ++- .../datasets/test_grammar_dataset.py | 19 +-- torchtune/data/_prompt_templates.py | 137 ++++++++++++++++++ torchtune/datasets/__init__.py | 2 + torchtune/datasets/_finetune.py | 122 ++++++++++++++++ torchtune/datasets/_grammar.py | 61 ++++++-- 7 files changed, 329 insertions(+), 33 deletions(-) create mode 100644 torchtune/data/_prompt_templates.py create mode 100644 torchtune/datasets/_finetune.py diff --git a/recipes/full_finetune_distributed.py b/recipes/full_finetune_distributed.py index b7cfa040a3..b9cdde574c 100644 --- a/recipes/full_finetune_distributed.py +++ b/recipes/full_finetune_distributed.py @@ -106,7 +106,7 @@ def __init__(self, cfg: DictConfig) -> None: if ( cfg.get("fsdp_cpu_offload", False) - and cfg.optimizer.get("fused", False) + and cfg.get("fused", False) and not utils.torch_version_ge("2.4.0") ): raise RuntimeError( @@ -397,7 +397,7 @@ def _setup_data( ds = ConcatDataset(datasets=datasets) packed = False else: - ds = config.instantiate(cfg_dataset, tokenizer=self._tokenizer) + ds = config.instantiate(cfg_dataset, self._tokenizer) packed = cfg_dataset.get("packed", False) sampler = DistributedSampler( diff --git a/tests/test_utils.py b/tests/test_utils.py index 35ea7cef22..45ef7f6189 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -12,7 +12,7 @@ from contextlib import contextmanager from io import StringIO from pathlib import Path -from typing import Any, Dict, Generator, List, Optional, TextIO, Tuple, Union +from typing import Any, Dict, Generator, List, Mapping, Optional, TextIO, Tuple, Union import pytest @@ -20,6 +20,7 @@ from torch import nn from torchtune.data import ChatFormat, Message, truncate from torchtune.modules.tokenizers import ModelTokenizer +from torchtune.modules.transforms import Transform skip_if_cuda_not_available = unittest.skipIf( not torch.cuda.is_available(), "CUDA is not available" @@ -39,7 +40,7 @@ } -class DummyTokenizer(ModelTokenizer): +class DummyTokenizer(ModelTokenizer, Transform): def encode(self, text, add_bos=True, add_eos=True, **kwargs) -> List[int]: words = text.split() tokens = [len(word) for word in words] @@ -69,15 +70,16 @@ def tokenize_messages( mask.append(message.masked) # Tokenize current message, append with masks + tokens = [] for item in message.content: if item["type"] == "text": - tokens = self.encode( + tokens = tokens + self.encode( item["content"], add_bos=False, add_eos=False, ) elif item["type"] == "image": - tokens = [self.image_id] + tokens = tokens + [self.image_id] tokenized_messages.extend(tokens) mask.extend([message.masked] * len(tokens)) @@ -102,6 +104,13 @@ def tokenize_messages( return tokenized_messages, mask + def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]: + messages = sample.pop("messages") + tokens, mask = self.tokenize_messages(messages) + sample["tokens"] = tokens + sample["mask"] = mask + return sample + @property def eos_id(self): return -1 diff --git a/tests/torchtune/datasets/test_grammar_dataset.py b/tests/torchtune/datasets/test_grammar_dataset.py index 832c289b5a..524ee95592 100644 --- a/tests/torchtune/datasets/test_grammar_dataset.py +++ 
b/tests/torchtune/datasets/test_grammar_dataset.py @@ -20,7 +20,7 @@ class TestGrammarDataset: def tokenizer(self): return DummyTokenizer() - @patch("torchtune.datasets._instruct.load_dataset") + @patch("torchtune.datasets._finetune.load_dataset") def test_label_no_masking(self, load_dataset, tokenizer): """ Test whether the input and the labels are correctly created when the input is not masked. @@ -36,7 +36,7 @@ def test_label_no_masking(self, load_dataset, tokenizer): ] ) - grammar_ds = grammar_dataset(tokenizer=tokenizer, train_on_input=True) + grammar_ds = grammar_dataset(model_transform=tokenizer, train_on_input=True) input, labels = grammar_ds[0]["tokens"], grammar_ds[0]["labels"] assert len(input) == len(labels) @@ -44,7 +44,7 @@ def test_label_no_masking(self, load_dataset, tokenizer): assert input[0] == tokenizer.bos_id assert CROSS_ENTROPY_IGNORE_IDX not in labels - @patch("torchtune.datasets._instruct.load_dataset") + @patch("torchtune.datasets._finetune.load_dataset") def test_label_masking(self, load_dataset, tokenizer): """ Test whether the input and the labels are correctly created when the input is masked. @@ -60,15 +60,7 @@ def test_label_masking(self, load_dataset, tokenizer): ] ) - grammar_ds = grammar_dataset(tokenizer=tokenizer) - - # Extract the prompt and tokenize it; we'll need this to test whether we're masking the - # input correctly - sample = grammar_ds._data[0] - prompt = grammar_ds.template.format( - sample=sample, column_map={"sentence": "input"} - ) - encoded_prompt = tokenizer.encode(text=prompt, add_bos=True, add_eos=False) + grammar_ds = grammar_dataset(model_transform=tokenizer) # Generate the input and labels input, labels = grammar_ds[0]["tokens"], grammar_ds[0]["labels"] @@ -76,4 +68,5 @@ def test_label_masking(self, load_dataset, tokenizer): assert len(input) == len(labels) assert labels[-1] == tokenizer.eos_id assert input[0] == tokenizer.bos_id - assert labels.count(CROSS_ENTROPY_IGNORE_IDX) == len(encoded_prompt) + # Check that the input is masked + assert labels.count(CROSS_ENTROPY_IGNORE_IDX) == 16 diff --git a/torchtune/data/_prompt_templates.py b/torchtune/data/_prompt_templates.py new file mode 100644 index 0000000000..45b7dbd99c --- /dev/null +++ b/torchtune/data/_prompt_templates.py @@ -0,0 +1,137 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +from functools import partial +from typing import Dict, List, Protocol, Tuple + +from torchtune.data import Message, Role + + +class PromptTemplate(Protocol): + """ + Interface for prompt templates. Each prompt template can include structured + text for system, user, and assistant roles that are prepended or appended to + the message content. + """ + + # Template should map role to a tuple containing the tag to prepend to the text + # and tag to append to the text. 
Leave as empty strings to not prepend or append + template: Dict[Role, Tuple[str, str]] + + def __call__( + self, + messages: List[Message], + ) -> List[Message]: + """ + Format each role's message(s) according to the prompt template + + Args: + messages (List[Message]): a single conversation, structured as a list + of :class:`~torchtune.data.Message` objects + + Returns: + The formatted list of messages + """ + pass + + +class CustomPromptTemplate(PromptTemplate): + """ + Define a custom prompt template by passing in a dictionary mapping role to + the prepend and append tags. For example, to achieve the following prompt + template:: + + System: {content}\n + User: {content}\n + Assistant: {content}\n + Tool: {content}\n + + You can define the template as follows:: + + template = { + "system": ("System: ", "\n"), + "user": ("User: ", "\n"), + "assistant": ("Assistant: ", "\n"), + "ipython": ("Tool: ", "\n"), + } + + Once instantiated, you must call the prompt template on a list of messages. It + will return the same list of messages updated with the template. + + Args: + template (Dict[Role, Tuple[str, str]]): a dictionary mapping role to the + prepend and append tags + """ + + def __init__( + self, + template: Dict[Role, Tuple[str, str]], + ): + self.template = template + + def __call__(self, messages: List[Message]) -> List[Message]: + """ + Format each role's message(s) according to the prompt template by prepending + and appending the defined tags. + + Args: + messages (List[Message]): list of messages to apply the template to + + Returns: + List[Message]: The formatted list of messages + """ + formatted_dialogue = [] + for message in messages: + prepend_tag = self.template[message.role][0] + append_tag = self.template[message.role][1] + content = ( + [{"type": "text", "content": prepend_tag}] + + message.content + + [{"type": "text", "content": append_tag}] + ) + formatted_dialogue.append( + Message( + role=message.role, + content=content, + masked=message.masked, + ipython=message.ipython, + eot=message.eot, + ), + ) + return formatted_dialogue + + +GrammarErrorCorrectionTemplate = partial( + CustomPromptTemplate, + template={ + "user": ("Correct this to standard English: ", "\n---\n"), + "assistant": ("Corrected: ", ""), + }, +) +GrammarErrorCorrectionTemplate.__doc__ = """ +A prompt template for grammar error correction tasks:: + + Correct this to standard English: {user_message} + --- + Corrected: {assistant_message} + +""" +SummarizeTemplate = partial( + CustomPromptTemplate, + template={ + "user": ("Summarize this dialogue:\n", "\n---\n"), + "assistant": ("Summary:\n", ""), + }, +) +SummarizeTemplate.__doc__ = """ +A prompt template for summarization tasks:: + + Summarize this dialogue: + {user_message} + --- + Summary: + {assistant_message} + +""" diff --git a/torchtune/datasets/__init__.py b/torchtune/datasets/__init__.py index dddf411e27..f1305bb719 100644 --- a/torchtune/datasets/__init__.py +++ b/torchtune/datasets/__init__.py @@ -8,6 +8,7 @@ from torchtune.datasets._chat import chat_dataset, ChatDataset from torchtune.datasets._cnn_dailymail import cnn_dailymail_articles_dataset from torchtune.datasets._concat import ConcatDataset +from torchtune.datasets._finetune import FinetuneDataset from torchtune.datasets._grammar import grammar_dataset from torchtune.datasets._instruct import instruct_dataset, InstructDataset from torchtune.datasets._packed import PackedDataset @@ -39,4 +40,5 @@ "ConcatDataset", "wikitext_dataset", "PreferenceDataset", + "FinetuneDataset", ] diff 
--git a/torchtune/datasets/_finetune.py b/torchtune/datasets/_finetune.py new file mode 100644 index 0000000000..4643ce2b13 --- /dev/null +++ b/torchtune/datasets/_finetune.py @@ -0,0 +1,122 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any, Callable, Dict, Mapping, Optional + +import numpy as np + +from datasets import load_dataset +from torch.utils.data import Dataset +from torchtune.data import CROSS_ENTROPY_IGNORE_IDX, PromptTemplate +from torchtune.modules.transforms import Transform + + +class FinetuneDataset(Dataset): + """ + Dataset class for creating instruct, chat, tool, or multimodal datasets for fine-tuning. + + All datasets can be considered "conversations" with the model, or AI assistant. + Thus, we can format all text content as messages in a conversation assigned to + a :class:`~torchtune.data.Role`: + - system messages contain the system prompt + - user messages contain the input prompt into the model + - assistant messages are the response of the model and what you actually want + to train for and compute loss directly against + - ipython messages are the return from a tool call + + Chat datasets are multiple rounds of user-assistant messages. Instruct datasets + are typically a single round involving a specific instruction and the model's response. + + The :class:`~torchtune.data.Message` forms the core data unit that all tokenizer + APIs expect. The key component of this class that ensures any dataset is transformed + into thie format is the ``message_transform``. This is a callable class that takes + in a sample dictionary - typically a single row from a Hugging Face dataset or a single + json - that processes the sample in any configurable way to output a list of messages:: + + [ + Message( + role=, + content=, + ), + ... + ] + + For any custom dataset, use the ``message_transform`` to contain all pre-processing to + return the list of messages. + + Any model specific pre-processing that needs to happen can be configured with the ``model_transform`` + parameter. This is another callable class that contains any custom logic tied to the + model you are fine-tuning. For example, text + image multimodal datasets requires processing + the images in a way specific to the vision encoder being used by the model and is agnostic + to the specific dataset. + + Tokenization is handled by the ``model_transform``. All :class:`~torchtune.modules.tokenizers.ModelTokenizer`s + can be treated as a ``model_transform`` since it uses the model-specific tokenizer to + transform the list of messages outputted from the ``message_transform`` into tokens + used by the model for training. Text-only datasets will simply pass the :class:`~torchtune.modules.tokenizers.ModelTokenizer` + into ``model_transform``. + + The general pipeline is then: raw sample -> optional filter -> apply dataset-specific message transform -> apply + optional prompt template -> apply model-specific transform -> tokens used for training + + Args: + tokenizer (ModelTokenizer): Tokenizer used by the model that implements the ``tokenize_messages`` method. 
+ source (str): path string of dataset, anything supported by Hugging Face's ``load_dataset`` + (https://huggingface.co/docs/datasets/en/package_reference/loading_methods#datasets.load_dataset.path) + convert_to_messages (Callable[[Mapping[str, Any]], List[Message]]): function that keys into the desired field in the sample + and converts to a list of :class:`~torchtune.data.Message` that follows the Llama format with the expected keys + chat_format (Optional[ChatFormat]): template used to format the chat. This is used to add structured text around the actual + messages, such as the [INST] tags in Llama2 and in Mistral. The extra text will still get tokenized as normal text, not + as special tokens. In models like Llama3 where the tokenizer adds tags as special tokens, ``chat_format`` is not needed, + unless you want to structure messages in a particular way for inference. + max_seq_len (int): Maximum number of tokens in the returned input and label token id lists. + train_on_input (bool): Whether the model is trained on the prompt or not. Default is False. + **load_dataset_kwargs (Dict[str, Any]): additional keyword arguments to pass to ``load_dataset``. + """ + + def __init__( + self, + *, + source: str, + message_transform: Transform, + model_transform: Transform, + prompt_template: Optional[PromptTemplate] = None, + filter_fn: Optional[Callable] = None, + **load_dataset_kwargs: Dict[str, Any], + ) -> None: + self._message_transform = message_transform + self._prompt_template = prompt_template + self._model_transform = model_transform + + self._data = load_dataset(source, **load_dataset_kwargs) + self._data = self._data.filter(filter_fn) + + def __len__(self): + return len(self._data) + + def __getitem__(self, index: int) -> Dict[str, Any]: + sample = self._data[index] + return self._prepare_sample(sample) + + def _prepare_sample(self, sample: Mapping[str, Any]) -> Dict[str, Any]: + transformed_sample = self._message_transform(sample) + if self._prompt_template is not None: + transformed_sample["messages"] = self._prompt_template( + transformed_sample["messages"] + ) + tokenized_dict = self._model_transform(transformed_sample) + + # Wherever mask == True, set to CROSS_ENTROPY_IGNORE_IDX. Otherwise keep as tokens + tokenized_dict["labels"] = list( + np.where( + tokenized_dict["mask"], + CROSS_ENTROPY_IGNORE_IDX, + tokenized_dict["tokens"], + ) + ) + assert len(tokenized_dict["tokens"]) == len(tokenized_dict["labels"]) + + return tokenized_dict diff --git a/torchtune/datasets/_grammar.py b/torchtune/datasets/_grammar.py index 157c4ddc6e..401bf74c49 100644 --- a/torchtune/datasets/_grammar.py +++ b/torchtune/datasets/_grammar.py @@ -4,18 +4,54 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
-from torchtune.datasets._instruct import instruct_dataset, InstructDataset -from torchtune.modules.tokenizers import ModelTokenizer + +from typing import Any, Dict, Mapping, Optional + +from torch.utils.data import Dataset +from torchtune.data import Message +from torchtune.data._templates import GrammarErrorCorrectionTemplate +from torchtune.datasets._finetune import FinetuneDataset +from torchtune.datasets._packed import PackedDataset +from torchtune.modules.transforms import Transform + + +class GrammarMessages(Transform): + def __init__( + self, train_on_input: bool = False, column_map: Optional[Dict[str, str]] = None + ): + self.train_on_input = train_on_input + self.column_map = column_map + self.template = GrammarErrorCorrectionTemplate() + + def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]: + column_map = self.column_map or {} + key_input = column_map.get("input", "input") + key_output = column_map.get("output", "output") + messages = [ + Message( + role="user", + content=sample[key_input], + masked=not self.train_on_input, + eot=False, + ), + Message( + role="assistant", + content=sample[key_output], + masked=False, + eot=True, + ), + ] + sample["messages"] = self.template(messages=messages) + return sample def grammar_dataset( - tokenizer: ModelTokenizer, + model_transform: Transform, *, source: str = "liweili/c4_200m", train_on_input: bool = False, packed: bool = False, - split: str = "train", -) -> InstructDataset: +) -> Dataset: """ Support for grammar correction datasets and their variants from Hugging Face Datasets. Here is an `example `_ of a grammar correction dataset. @@ -36,8 +72,6 @@ def grammar_dataset( source (str): path string of dataset, anything supported by Hugging Face's ``load_dataset``. train_on_input (bool): Whether the model is trained on the prompt or not. Default is False. packed (bool): Whether or not to pack the dataset to ``max_seq_len`` prior to training. Default is False. - split (str): ``split`` argument for ``datasets.load_dataset``. You can use this argument to load a subset - of a given split, e.g. ``split="train[:10%]"``. Default is "train". 
Returns: InstructDataset: dataset configured with source data and template @@ -50,12 +84,11 @@ def grammar_dataset( >>> Batch size: 8 """ - return instruct_dataset( - tokenizer=tokenizer, + message_transform = GrammarMessages(train_on_input=train_on_input) + ds = FinetuneDataset( source=source, - template="torchtune.data.GrammarErrorCorrectionTemplate", - column_map={"sentence": "input"}, - train_on_input=train_on_input, - packed=packed, - split=split, + message_transform=message_transform, + model_transform=model_transform, + split="train", ) + return PackedDataset(ds) if packed else ds From ef79507b1850ba3c946a5e99caf26debf123a415 Mon Sep 17 00:00:00 2001 From: RdoubleA Date: Fri, 26 Jul 2024 17:41:13 -0700 Subject: [PATCH 03/16] refactor samsum --- recipes/full_finetune_distributed.py | 2 +- recipes/full_finetune_single_device.py | 4 +- recipes/lora_finetune_distributed.py | 4 +- recipes/lora_finetune_single_device.py | 4 +- recipes/qat_distributed.py | 4 +- tests/torchtune/data/test_converters.py | 2 +- .../torchtune/datasets/test_samsum_dataset.py | 8 +- .../models/llama3/test_llama3_tokenizer.py | 2 +- torchtune/data/__init__.py | 10 +- torchtune/data/_chat_formats.py | 2 +- torchtune/data/_converters.py | 2 +- torchtune/data/_instruct_templates.py | 107 ------------------ torchtune/data/{_types.py => _messages.py} | 33 +++++- torchtune/data/_utils.py | 2 +- torchtune/datasets/_finetune.py | 30 +++-- torchtune/datasets/_grammar.py | 51 +++------ torchtune/datasets/_samsum.py | 34 ++++-- torchtune/modules/tokenizers/_utils.py | 2 +- 18 files changed, 117 insertions(+), 186 deletions(-) rename torchtune/data/{_types.py => _messages.py} (78%) diff --git a/recipes/full_finetune_distributed.py b/recipes/full_finetune_distributed.py index b9cdde574c..07b84b37fa 100644 --- a/recipes/full_finetune_distributed.py +++ b/recipes/full_finetune_distributed.py @@ -391,7 +391,7 @@ def _setup_data( if isinstance(cfg_dataset, ListConfig): datasets = [ - config.instantiate(single_cfg_dataset, tokenizer=self._tokenizer) + config.instantiate(single_cfg_dataset, self._tokenizer) for single_cfg_dataset in cfg_dataset ] ds = ConcatDataset(datasets=datasets) diff --git a/recipes/full_finetune_single_device.py b/recipes/full_finetune_single_device.py index eed942b4c8..b76982cab0 100644 --- a/recipes/full_finetune_single_device.py +++ b/recipes/full_finetune_single_device.py @@ -338,13 +338,13 @@ def _setup_data( """ if isinstance(cfg_dataset, ListConfig): datasets = [ - config.instantiate(single_cfg_dataset, tokenizer=self._tokenizer) + config.instantiate(single_cfg_dataset, self._tokenizer) for single_cfg_dataset in cfg_dataset ] ds = ConcatDataset(datasets=datasets) packed = False else: - ds = config.instantiate(cfg_dataset, tokenizer=self._tokenizer) + ds = config.instantiate(cfg_dataset, self._tokenizer) packed = cfg_dataset.get("packed", False) sampler = DistributedSampler( diff --git a/recipes/lora_finetune_distributed.py b/recipes/lora_finetune_distributed.py index 12fd7c5eb9..a493e795ba 100644 --- a/recipes/lora_finetune_distributed.py +++ b/recipes/lora_finetune_distributed.py @@ -496,13 +496,13 @@ def _setup_data( if isinstance(cfg_dataset, ListConfig): datasets = [ - config.instantiate(single_cfg_dataset, tokenizer=self._tokenizer) + config.instantiate(single_cfg_dataset, self._tokenizer) for single_cfg_dataset in cfg_dataset ] ds = ConcatDataset(datasets=datasets) packed = False else: - ds = config.instantiate(cfg_dataset, tokenizer=self._tokenizer) + ds = config.instantiate(cfg_dataset, 
self._tokenizer) packed = cfg_dataset.get("packed", False) sampler = DistributedSampler( diff --git a/recipes/lora_finetune_single_device.py b/recipes/lora_finetune_single_device.py index 4ac2d6876e..8331fb842b 100644 --- a/recipes/lora_finetune_single_device.py +++ b/recipes/lora_finetune_single_device.py @@ -428,13 +428,13 @@ def _setup_data( """ if isinstance(cfg_dataset, ListConfig): datasets = [ - config.instantiate(single_cfg_dataset, tokenizer=self._tokenizer) + config.instantiate(single_cfg_dataset, self._tokenizer) for single_cfg_dataset in cfg_dataset ] ds = ConcatDataset(datasets=datasets) packed = False else: - ds = config.instantiate(cfg_dataset, tokenizer=self._tokenizer) + ds = config.instantiate(cfg_dataset, self._tokenizer) packed = cfg_dataset.get("packed", False) sampler = DistributedSampler( diff --git a/recipes/qat_distributed.py b/recipes/qat_distributed.py index 5f5bb7b81a..ba990b8e11 100644 --- a/recipes/qat_distributed.py +++ b/recipes/qat_distributed.py @@ -419,13 +419,13 @@ def _setup_data( if isinstance(cfg_dataset, ListConfig): datasets = [ - config.instantiate(single_cfg_dataset, tokenizer=self._tokenizer) + config.instantiate(single_cfg_dataset, self._tokenizer) for single_cfg_dataset in cfg_dataset ] ds = ConcatDataset(datasets=datasets) packed = False else: - ds = config.instantiate(cfg_dataset, tokenizer=self._tokenizer) + ds = config.instantiate(cfg_dataset, self._tokenizer) packed = cfg_dataset.get("packed", False) sampler = DistributedSampler( diff --git a/tests/torchtune/data/test_converters.py b/tests/torchtune/data/test_converters.py index 88d1d1a424..f86e60af88 100644 --- a/tests/torchtune/data/test_converters.py +++ b/tests/torchtune/data/test_converters.py @@ -6,7 +6,7 @@ from tests.test_utils import assert_dialogue_equal from torchtune.data import get_openai_messages, get_sharegpt_messages -from torchtune.data._types import Message +from torchtune.data._messages import Message # Taken from Open-Orca/SlimOrca-Dedup on Hugging Face: # https://huggingface.co/datasets/Open-Orca/SlimOrca-Dedup diff --git a/tests/torchtune/datasets/test_samsum_dataset.py b/tests/torchtune/datasets/test_samsum_dataset.py index eb275dea4f..a41903ab4c 100644 --- a/tests/torchtune/datasets/test_samsum_dataset.py +++ b/tests/torchtune/datasets/test_samsum_dataset.py @@ -20,7 +20,7 @@ class TestSamsumDataset: def tokenizer(self): return DummyTokenizer() - @patch("torchtune.datasets._instruct.load_dataset") + @patch("torchtune.datasets._finetune.load_dataset") def test_label_no_masking(self, load_dataset, tokenizer): """ Test whether the input and the labels are correctly created when the input is not masked. @@ -37,7 +37,7 @@ def test_label_no_masking(self, load_dataset, tokenizer): ] ) - samsum_ds = samsum_dataset(tokenizer=tokenizer, train_on_input=True) + samsum_ds = samsum_dataset(model_transform=tokenizer, train_on_input=True) input, labels = samsum_ds[0]["tokens"], samsum_ds[0]["labels"] assert len(input) == len(labels) @@ -45,7 +45,7 @@ def test_label_no_masking(self, load_dataset, tokenizer): assert input[0] == tokenizer.bos_id assert CROSS_ENTROPY_IGNORE_IDX not in labels - @patch("torchtune.datasets._instruct.load_dataset") + @patch("torchtune.datasets._finetune.load_dataset") def test_label_masking(self, load_dataset, tokenizer): """ Test whether the input and the labels are correctly created when the input is masked. 
@@ -62,7 +62,7 @@ def test_label_masking(self, load_dataset, tokenizer): ] ) - samsum_ds = samsum_dataset(tokenizer=tokenizer) + samsum_ds = samsum_dataset(model_transform=tokenizer) # Extract the prompt and tokenize it; we'll need this to test whether we're masking the # input correctly diff --git a/tests/torchtune/models/llama3/test_llama3_tokenizer.py b/tests/torchtune/models/llama3/test_llama3_tokenizer.py index 3553eb36b5..79569e0e29 100644 --- a/tests/torchtune/models/llama3/test_llama3_tokenizer.py +++ b/tests/torchtune/models/llama3/test_llama3_tokenizer.py @@ -7,7 +7,7 @@ from pathlib import Path import pytest -from torchtune.data._types import Message +from torchtune.data._messages import Message from torchtune.models.llama3 import llama3_tokenizer, Llama3Tokenizer ASSETS = Path(__file__).parent.parent.parent.parent / "assets" diff --git a/torchtune/data/__init__.py b/torchtune/data/__init__.py index 37cf80bc5b..dbd250388c 100644 --- a/torchtune/data/__init__.py +++ b/torchtune/data/__init__.py @@ -14,12 +14,16 @@ from torchtune.data._converters import get_openai_messages, get_sharegpt_messages from torchtune.data._instruct_templates import ( AlpacaInstructTemplate, - GrammarErrorCorrectionTemplate, InstructTemplate, StackExchangedPairedTemplate, +) +from torchtune.data._messages import Message, Role +from torchtune.data._prompt_templates import ( + CustomPromptTemplate, + GrammarErrorCorrectionTemplate, + PromptTemplate, SummarizeTemplate, ) -from torchtune.data._types import Message, Role from torchtune.data._utils import truncate, validate_messages __all__ = [ @@ -39,4 +43,6 @@ "validate_messages", "StackExchangedPairedTemplate", "Role", + "CustomPromptTemplate", + "PromptTemplate", ] diff --git a/torchtune/data/_chat_formats.py b/torchtune/data/_chat_formats.py index 4bd87472da..9d1751bca8 100644 --- a/torchtune/data/_chat_formats.py +++ b/torchtune/data/_chat_formats.py @@ -7,7 +7,7 @@ from abc import ABC, abstractmethod from typing import Dict, List, Tuple -from torchtune.data._types import Message, Role +from torchtune.data._messages import Message, Role class ChatFormat(ABC): diff --git a/torchtune/data/_converters.py b/torchtune/data/_converters.py index e25910f76b..160aa2ac70 100644 --- a/torchtune/data/_converters.py +++ b/torchtune/data/_converters.py @@ -6,7 +6,7 @@ from typing import Any, List, Mapping -from torchtune.data._types import Message +from torchtune.data._messages import Message def get_sharegpt_messages( diff --git a/torchtune/data/_instruct_templates.py b/torchtune/data/_instruct_templates.py index e7acafe933..49f52c2ccc 100644 --- a/torchtune/data/_instruct_templates.py +++ b/torchtune/data/_instruct_templates.py @@ -130,113 +130,6 @@ def format( return prompt -class GrammarErrorCorrectionTemplate(InstructTemplate): - """ - Prompt template for grammar correction datasets. - - .. code-block:: text - - Correct this to standard English: - --- - Corrected: - - """ - - template = "Correct this to standard English: {sentence}\n---\nCorrected: " - - @classmethod - def format( - cls, sample: Mapping[str, Any], column_map: Optional[Dict[str, str]] = None - ) -> str: - """ - Generate prompt from sentence. - - Args: - sample (Mapping[str, Any]): a single data sample with sentence - column_map (Optional[Dict[str, str]]): a mapping from the expected placeholder names - in the template to the column names in the sample. If None, assume these are identical. 
- - Examples: - >>> # Simple sentence - >>> GrammarErrorCorrectionTemplate.format(sample={"sentence": "The quik brown fox jumps the lazy dog"}) - Correct this to standard English: The quik brown fox jumps the lazy dog - --- - Corrected: - - >>> # Sentence with column map where the 'sentence' key is actually named 'input' in the given sample - >>> GrammarErrorCorrectionTemplate.format( - ... sample={"input": "The quik brown fox jumps the lazy dog"}, - ... column_map={"sentence": "input"} - ... ) - Correct this to standard English: The quik brown fox jumps the lazy dog - --- - Corrected: - - Returns: - The formatted prompt - """ - column_map = column_map or {} - key_sentence = column_map.get("sentence", "sentence") - - prompt = cls.template.format(sentence=sample[key_sentence]) - return prompt - - -class SummarizeTemplate(InstructTemplate): - """ - Prompt template to format datasets for summarization tasks. - - .. code-block:: text - - Summarize this dialogue: - - --- - Summary: - - """ - - template = "Summarize this dialogue:\n{dialogue}\n---\nSummary:\n" - - @classmethod - def format( - cls, sample: Mapping[str, Any], column_map: Optional[Dict[str, str]] = None - ) -> str: - """ - Generate prompt from dialogue. - - Args: - sample (Mapping[str, Any]): a single data sample with dialog - column_map (Optional[Dict[str, str]]): a mapping from the expected placeholder names - in the template to the column names in the sample. If None, assume these are identical. - - Examples: - >>> # Simple dialogue - >>> SummarizeTemplate.format(sample={"dialogue": "Hello, how are you? Did you know the capital of France is Paris?"}) - Summarize this dialogue: - Hello, how are you? Did you know the capital of France is Paris? - --- - Summary: - - >>> # Dialogue with column map where the 'dialogue' key is actually named 'prompt' in the given sample - >>> SummarizeTemplate.format( - ... sample={"prompt": "Hello, how are you? Did you know the capital of France is Paris?"}, - ... column_map={"dialogue": "prompt"} - ... ) - Summarize this dialogue: - Hello, how are you? Did you know the capital of France is Paris? - --- - Summary: - - Returns: - The formatted prompt - """ - column_map = column_map or {} - key_dialogue = column_map.get("dialogue", "dialogue") - - prompt = cls.template.format(dialogue=sample[key_dialogue]) - return prompt - - class StackExchangedPairedTemplate(InstructTemplate): """ Prompt template for preference datasets similar to StackExchangedPaired. diff --git a/torchtune/data/_types.py b/torchtune/data/_messages.py similarity index 78% rename from torchtune/data/_types.py rename to torchtune/data/_messages.py index 75f43e7e26..1a77de959f 100644 --- a/torchtune/data/_types.py +++ b/torchtune/data/_messages.py @@ -4,7 +4,9 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Dict, List, Literal, Union +from typing import Any, Dict, List, Literal, Mapping, Optional, Union + +from torchtune.modules.transforms import Transform Role = Literal[ "system", # Origin is system prompt @@ -106,3 +108,32 @@ def _validate_message(self) -> None: raise RuntimeError( f"Only assistant messages can be tool calls. 
Found role {self.role} in message: {self.text_content}" ) + + +class ToInputOutputMessages(Transform): + def __init__( + self, train_on_input: bool = False, column_map: Optional[Dict[str, str]] = None + ): + self.train_on_input = train_on_input + self.column_map = column_map + + def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]: + column_map = self.column_map or {} + key_input = column_map.get("input", "input") + key_output = column_map.get("output", "output") + messages = [ + Message( + role="user", + content=sample[key_input], + masked=not self.train_on_input, + eot=False, + ), + Message( + role="assistant", + content=sample[key_output], + masked=False, + eot=True, + ), + ] + sample["messages"] = messages + return sample diff --git a/torchtune/data/_utils.py b/torchtune/data/_utils.py index 83cac657c2..36dd187775 100644 --- a/torchtune/data/_utils.py +++ b/torchtune/data/_utils.py @@ -6,7 +6,7 @@ from typing import Any, List, Optional -from torchtune.data._types import Message +from torchtune.data._messages import Message def truncate( diff --git a/torchtune/datasets/_finetune.py b/torchtune/datasets/_finetune.py index 4643ce2b13..e746938a49 100644 --- a/torchtune/datasets/_finetune.py +++ b/torchtune/datasets/_finetune.py @@ -63,17 +63,27 @@ class FinetuneDataset(Dataset): optional prompt template -> apply model-specific transform -> tokens used for training Args: - tokenizer (ModelTokenizer): Tokenizer used by the model that implements the ``tokenize_messages`` method. - source (str): path string of dataset, anything supported by Hugging Face's ``load_dataset`` + source (str): path to dataset repository on Hugging Face. For local datasets, + define source as the data file type (e.g. "json", "csv", "text") and pass + in the filepath in ``data_files``. See Hugging Face's ``load_dataset`` (https://huggingface.co/docs/datasets/en/package_reference/loading_methods#datasets.load_dataset.path) - convert_to_messages (Callable[[Mapping[str, Any]], List[Message]]): function that keys into the desired field in the sample - and converts to a list of :class:`~torchtune.data.Message` that follows the Llama format with the expected keys - chat_format (Optional[ChatFormat]): template used to format the chat. This is used to add structured text around the actual - messages, such as the [INST] tags in Llama2 and in Mistral. The extra text will still get tokenized as normal text, not - as special tokens. In models like Llama3 where the tokenizer adds tags as special tokens, ``chat_format`` is not needed, - unless you want to structure messages in a particular way for inference. - max_seq_len (int): Maximum number of tokens in the returned input and label token id lists. - train_on_input (bool): Whether the model is trained on the prompt or not. Default is False. + for more details. + message_transform (Transform): callable that keys into the desired fields in the sample + and converts text content to a list of :class:`~torchtune.data.Message`. It is expected that the final list + of messages are stored in the ``"messages"`` key. + model_transform (Transform): callable that applies model-specific pre-processing to the sample after the list of + messages is created from ``message_transform``. This includes tokenization and any modality-specific + transforms. + prompt_template (Optional[PromptTemplate]): template used to format the messages based on their role. This is used + to add structured text around the actual messages. 
The structured text is used in three scenarios: + - Task-specific templates to gear models for a particular task that it will expect after training + - Model-specific templates that are required whenever the model is prompted, such as the [INST] + tags in Llama2 and in Mistral + - Community standardized templates, such as :class:`~torchtune.data.ChatMLTemplate` + The extra text will still get tokenized as normal text, not as special tokens. + filter_fn (Optional[Callable]): callable used to filter the dataset prior to any pre-processing. See + the Hugging Face `docs `_ for more + details. **load_dataset_kwargs (Dict[str, Any]): additional keyword arguments to pass to ``load_dataset``. """ diff --git a/torchtune/datasets/_grammar.py b/torchtune/datasets/_grammar.py index 401bf74c49..2250a40775 100644 --- a/torchtune/datasets/_grammar.py +++ b/torchtune/datasets/_grammar.py @@ -5,52 +5,28 @@ # LICENSE file in the root directory of this source tree. -from typing import Any, Dict, Mapping, Optional +from typing import Dict, Optional from torch.utils.data import Dataset -from torchtune.data import Message -from torchtune.data._templates import GrammarErrorCorrectionTemplate +from torchtune.data import ToInputOutputMessages +from torchtune.data._prompt_templates import ( + GrammarErrorCorrectionTemplate, + PromptTemplate, +) from torchtune.datasets._finetune import FinetuneDataset from torchtune.datasets._packed import PackedDataset from torchtune.modules.transforms import Transform -class GrammarMessages(Transform): - def __init__( - self, train_on_input: bool = False, column_map: Optional[Dict[str, str]] = None - ): - self.train_on_input = train_on_input - self.column_map = column_map - self.template = GrammarErrorCorrectionTemplate() - - def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]: - column_map = self.column_map or {} - key_input = column_map.get("input", "input") - key_output = column_map.get("output", "output") - messages = [ - Message( - role="user", - content=sample[key_input], - masked=not self.train_on_input, - eot=False, - ), - Message( - role="assistant", - content=sample[key_output], - masked=False, - eot=True, - ), - ] - sample["messages"] = self.template(messages=messages) - return sample - - def grammar_dataset( model_transform: Transform, *, source: str = "liweili/c4_200m", + column_map: Optional[Dict[str, str]] = None, + prompt_template: Optional[PromptTemplate] = GrammarErrorCorrectionTemplate(), train_on_input: bool = False, packed: bool = False, + split: str = "train", ) -> Dataset: """ Support for grammar correction datasets and their variants from Hugging Face Datasets. @@ -70,6 +46,8 @@ def grammar_dataset( Args: tokenizer (ModelTokenizer): Tokenizer used by the model that implements the ``tokenize_messages`` method. source (str): path string of dataset, anything supported by Hugging Face's ``load_dataset``. + prompt_template (Optional[PromptTemplate]): optional template used to format the prompt. Default + is :class:`~torchtune.data.GrammarErrorCorrectionTemplate`. train_on_input (bool): Whether the model is trained on the prompt or not. Default is False. packed (bool): Whether or not to pack the dataset to ``max_seq_len`` prior to training. Default is False. 
@@ -84,11 +62,14 @@ def grammar_dataset( >>> Batch size: 8 """ - message_transform = GrammarMessages(train_on_input=train_on_input) + message_transform = ToInputOutputMessages( + train_on_input=train_on_input, column_map=column_map + ) ds = FinetuneDataset( source=source, message_transform=message_transform, model_transform=model_transform, - split="train", + prompt_template=prompt_template, + split=split, ) return PackedDataset(ds) if packed else ds diff --git a/torchtune/datasets/_samsum.py b/torchtune/datasets/_samsum.py index f1254eb6e3..d40fe4d9d8 100644 --- a/torchtune/datasets/_samsum.py +++ b/torchtune/datasets/_samsum.py @@ -4,18 +4,26 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from torchtune.datasets import instruct_dataset, InstructDataset -from torchtune.modules.tokenizers import ModelTokenizer + +from typing import Dict, Optional + +from torchtune.data import ToInputOutputMessages +from torchtune.data._prompt_templates import PromptTemplate, SummarizeTemplate +from torchtune.datasets._finetune import FinetuneDataset +from torchtune.datasets._packed import PackedDataset +from torchtune.modules.transforms import Transform def samsum_dataset( - tokenizer: ModelTokenizer, + model_transform: Transform, *, - source: str = "samsum", + source: str = "Samsung/samsum", + column_map: Optional[Dict[str, str]] = None, + prompt_template: Optional[PromptTemplate] = SummarizeTemplate(), train_on_input: bool = False, packed: bool = False, split: str = "train", -) -> InstructDataset: +) -> FinetuneDataset: """ Support for summarization datasets and their variants from Hugging Face Datasets. An example is the `SAMsum dataset `_. @@ -49,13 +57,15 @@ def samsum_dataset( >>> print(f"Batch size: {len(batch)}") >>> Batch size: 8 """ - - return instruct_dataset( - tokenizer=tokenizer, + column_map = column_map or {"input": "dialogue", "output": "summary"} + message_transform = ToInputOutputMessages( + train_on_input=train_on_input, column_map=column_map + ) + ds = FinetuneDataset( source=source, - template="torchtune.data.SummarizeTemplate", - column_map={"output": "summary"}, - train_on_input=train_on_input, - packed=packed, + message_transform=message_transform, + model_transform=model_transform, + prompt_template=prompt_template, split=split, ) + return PackedDataset(ds) if packed else ds diff --git a/torchtune/modules/tokenizers/_utils.py b/torchtune/modules/tokenizers/_utils.py index b3e04e3c01..c76fcc6634 100644 --- a/torchtune/modules/tokenizers/_utils.py +++ b/torchtune/modules/tokenizers/_utils.py @@ -7,7 +7,7 @@ import json from typing import Dict, List, Optional, Protocol, Tuple -from torchtune.data._types import Message +from torchtune.data._messages import Message from torchtune.data._utils import truncate From 7d542017f2d50a9bdc0ec7a1beace72f8927ee9d Mon Sep 17 00:00:00 2001 From: RdoubleA Date: Tue, 30 Jul 2024 14:45:42 -0700 Subject: [PATCH 04/16] add all tests, update live docs --- docs/source/api_ref_data.rst | 2 + docs/source/api_ref_datasets.rst | 1 + recipes/full_finetune_distributed.py | 2 +- tests/test_utils.py | 13 +- .../torchtune/data/test_instruct_templates.py | 94 +------------ tests/torchtune/data/test_messages.py | 95 +++++++++++++ tests/torchtune/data/test_prompt_templates.py | 132 ++++++++++++++++++ .../datasets/test_finetune_dataset.py | 110 +++++++++++++++ .../datasets/test_grammar_dataset.py | 66 ++++++++- .../torchtune/datasets/test_samsum_dataset.py | 84 +++++++++-- 
torchtune/data/__init__.py | 3 +- torchtune/data/_messages.py | 11 +- torchtune/data/_prompt_templates.py | 33 +++-- torchtune/datasets/_finetune.py | 6 +- torchtune/datasets/_grammar.py | 13 +- torchtune/datasets/_samsum.py | 13 +- torchtune/models/llama3/_tokenizer.py | 3 +- 17 files changed, 540 insertions(+), 141 deletions(-) create mode 100644 tests/torchtune/data/test_messages.py create mode 100644 tests/torchtune/data/test_prompt_templates.py create mode 100644 tests/torchtune/datasets/test_finetune_dataset.py diff --git a/docs/source/api_ref_data.rst b/docs/source/api_ref_data.rst index 518216c4c8..a1b7584a86 100644 --- a/docs/source/api_ref_data.rst +++ b/docs/source/api_ref_data.rst @@ -23,6 +23,8 @@ and models. GrammarErrorCorrectionTemplate SummarizeTemplate StackExchangedPairedTemplate + PromptTemplate + CustomPromptTemplate ChatFormat ChatMLFormat diff --git a/docs/source/api_ref_datasets.rst b/docs/source/api_ref_datasets.rst index 09012b3abf..6aea567fae 100644 --- a/docs/source/api_ref_datasets.rst +++ b/docs/source/api_ref_datasets.rst @@ -56,3 +56,4 @@ Class representations for the above dataset builders. ConcatDataset PackedDataset PreferenceDataset + FinetuneDataset diff --git a/recipes/full_finetune_distributed.py b/recipes/full_finetune_distributed.py index 889f40ff69..0b8a4bb6b6 100644 --- a/recipes/full_finetune_distributed.py +++ b/recipes/full_finetune_distributed.py @@ -106,7 +106,7 @@ def __init__(self, cfg: DictConfig) -> None: if ( cfg.get("fsdp_cpu_offload", False) - and cfg.get("fused", False) + and cfg.optimizer.get("fused", False) and not utils.torch_version_ge("2.4.0") ): raise RuntimeError( diff --git a/tests/test_utils.py b/tests/test_utils.py index 45ef7f6189..9182fff530 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -10,6 +10,7 @@ import sys import unittest from contextlib import contextmanager +from functools import partial from io import StringIO from pathlib import Path from typing import Any, Dict, Generator, List, Mapping, Optional, TextIO, Tuple, Union @@ -18,7 +19,7 @@ import torch from torch import nn -from torchtune.data import ChatFormat, Message, truncate +from torchtune.data import ChatFormat, CustomPromptTemplate, Message, truncate from torchtune.modules.tokenizers import ModelTokenizer from torchtune.modules.transforms import Transform @@ -150,6 +151,16 @@ def format( return formatted_dialogue +DummyPromptTemplate = partial( + CustomPromptTemplate, + template={ + "system": ("System:\n", "\n"), + "user": ("User:\n", "\n"), + "assistant": ("Assistant:\n", "\n"), + }, +) + + def get_assets_path(): return Path(__file__).parent / "assets" diff --git a/tests/torchtune/data/test_instruct_templates.py b/tests/torchtune/data/test_instruct_templates.py index be8f34ebd8..f30a70a169 100644 --- a/tests/torchtune/data/test_instruct_templates.py +++ b/tests/torchtune/data/test_instruct_templates.py @@ -4,19 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from torchtune.data import ( - AlpacaInstructTemplate, - GrammarErrorCorrectionTemplate, - SummarizeTemplate, -) - -# Taken from Open-Orca/SlimOrca-Dedup on Hugging Face: -# https://huggingface.co/datasets/Open-Orca/SlimOrca-Dedup -CHAT_SAMPLE = { - "system": "You are an AI assistant. User will you give you a task. Your goal is to complete the task as faithfully as you can. 
While performing the task think step-by-step and justify your steps.", # noqa: B950 - "user": "Please briefly summarize this news article:\n\nAOL.com Video - Father Lets 8-Year-Old Drive On Icy Road\n\nDescription:Would you let your 8-year-old drive your car? How about on an icy road? Well one father in Russia did just that, and recorded the entire thing. To her credit, the child seemed to be doing a great job. (0:44)\n\nTags: 8-year-old driver , caught on camera , child driver , pix11\n\nSummary:", # noqa: B950 - "assistant": "A father in Russia allowed his 8-year-old child to drive his car on an icy road and recorded the event. The child appeared to be handling the situation well, showcasing their driving skills despite the challenging conditions.", # noqa: B950 -} +from torchtune.data import AlpacaInstructTemplate class TestAlpacaInstructTemplate: @@ -72,83 +60,3 @@ def test_format_with_column_map(self): actual = self.template.format(modified_sample, column_map=column_map) assert actual == expected_prompt - - -class TestGrammarErrorCorrectionTemplate: - samples = [ - { - "input": "Bitcoin is for $7,094 this morning, which CoinDesk says.", - "output": "Bitcoin goes for $7,094 this morning, according to CoinDesk.", - }, - { - "input": "Much many brands and sellers still in the market.", - "output": "Many brands and sellers still in the market.", - }, - ] - expected_prompts = [ - ( - "Correct this to standard English: Bitcoin is for $7,094 this morning, which CoinDesk says.\n" - "---\n" - "Corrected: " - ), - ( - "Correct this to standard English: Much many brands and sellers still in the market.\n" - "---\n" - "Corrected: " - ), - ] - - template = GrammarErrorCorrectionTemplate() - - def test_format(self): - for sample, expected_prompt in zip(self.samples, self.expected_prompts): - column_map = {"sentence": "input"} - actual = self.template.format(sample, column_map=column_map) - assert actual == expected_prompt - - -class TestSummarizeTemplate: - samples = [ - { - "id": "13818513", - "dialogue": "Amanda: I baked cookies. Do you want some? Jerry: Sure! Amanda: I'll bring you tomorrow :-)", - "summary": "Amanda baked cookies and will bring Jerry some tomorrow.", - }, - { - "id": "13728867", - "dialogue": "Olivia: Who are you voting for in this election? Oliver: Liberals as always. Olivia: Me too!! Oliver: Great", # noqa: B950 - "summary": "Olivia and Olivier are voting for liberals in this election.", - }, - ] - expected_prompts = [ - ( - "Summarize this dialogue:\n" - "Amanda: I baked cookies. Do you want some? Jerry: Sure! Amanda: I'll bring you tomorrow :-)\n" - "---\n" - "Summary:\n" - ), - ( - "Summarize this dialogue:\n" - "Olivia: Who are you voting for in this election? Oliver: Liberals as always. Olivia: Me too!! 
Oliver: Great\n" - "---\n" - "Summary:\n" - ), - ] - - template = SummarizeTemplate() - - def test_format(self): - for sample, expected_prompt in zip(self.samples, self.expected_prompts): - actual = self.template.format(sample) - assert actual == expected_prompt - - def test_format_with_column_map(self): - column_map = {"dialogue": "not_a_dialogue"} - for sample, expected_prompt in zip(self.samples, self.expected_prompts): - modified_sample = sample.copy() - modified_sample["not_a_dialogue"] = modified_sample["dialogue"] - del modified_sample["dialogue"] - - actual = self.template.format(modified_sample, column_map=column_map) - - assert actual == expected_prompt diff --git a/tests/torchtune/data/test_messages.py b/tests/torchtune/data/test_messages.py new file mode 100644 index 0000000000..243e4573a8 --- /dev/null +++ b/tests/torchtune/data/test_messages.py @@ -0,0 +1,95 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import pytest +from tests.test_utils import assert_dialogue_equal +from torchtune.data._messages import Message, ToInputOutputMessages + + +class TestMessage: + @pytest.fixture + def text_message(self): + return Message(role="user", content="hello world") + + @pytest.fixture + def image_message(self): + return Message( + role="user", + content=[ + {"type": "text", "content": "hello"}, + {"type": "image"}, + {"type": "text", "content": "world"}, + ], + ) + + def test_message_validation(self, text_message): + message = text_message + assert message.role == "user" + assert message.content == [{"type": "text", "content": "hello world"}] + + with pytest.raises( + ValueError, + match="Only assistant messages can be tool calls. Found role user in message: hello world", + ): + message = Message(role="user", content="hello world", ipython=True) + + with pytest.raises( + ValueError, + match="Media tokens in tool calls are not supported. 
Both are set in message: hello world", + ): + message = Message( + role="user", + content=[{"type": "image"}], + ipython=True, + ) + + def test_from_dict(self): + message = Message.from_dict({"role": "user", "content": "hello world"}) + assert message.role == "user" + assert message.content[0]["content"] == "hello world" + assert not message.masked + assert not message.ipython + assert message.eot + + def test_contains_media(self, text_message, image_message): + assert not text_message.contains_media + assert image_message.contains_media + + def test_text_content(self, text_message, image_message): + assert text_message.text_content == "hello world" + assert image_message.text_content == "hello world" + + +class TestToInputOutputMessages: + @pytest.fixture + def sample(self): + return { + "maybe_input": "hello world", + "maybe_output": "hello world", + } + + def test_call(self, sample): + transform = ToInputOutputMessages( + column_map={"input": "maybe_input", "output": "maybe_output"} + ) + actual = transform(sample) + expected = [ + Message(role="user", content="hello world", masked=True, eot=False), + Message(role="assistant", content="hello world", masked=False, eot=True), + ] + assert_dialogue_equal(actual, expected) + + def test_call_train_on_input(self, sample): + transform = ToInputOutputMessages( + column_map={"input": "maybe_input", "output": "maybe_output"}, + train_on_input=True, + ) + actual = transform(sample) + expected = [ + Message(role="user", content="hello world", masked=False, eot=False), + Message(role="assistant", content="hello world", masked=False, eot=True), + ] + assert_dialogue_equal(actual, expected) diff --git a/tests/torchtune/data/test_prompt_templates.py b/tests/torchtune/data/test_prompt_templates.py new file mode 100644 index 0000000000..363a338eec --- /dev/null +++ b/tests/torchtune/data/test_prompt_templates.py @@ -0,0 +1,132 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +from tests.test_utils import assert_dialogue_equal +from torchtune.data import GrammarErrorCorrectionTemplate, Message, SummarizeTemplate + + +class TestGrammarErrorCorrectionTemplate: + samples = [ + { + "messages": [ + Message( + role="user", + content="Bitcoin is for $7,094 this morning, which CoinDesk says.", + ), + Message( + role="assistant", + content="Bitcoin goes for $7,094 this morning, according to CoinDesk.", + ), + ] + }, + { + "messages": [ + Message( + role="user", + content="Much many brands and sellers still in the market.", + ), + Message( + role="assistant", + content="Many brands and sellers still in the market.", + ), + ], + }, + ] + expected_prompts = [ + [ + Message( + role="user", + content="Correct this to standard English: Bitcoin is for $7,094 this morning, which CoinDesk says.\n" + "---\n", + ), + Message( + role="assistant", + content="Corrected: Bitcoin goes for $7,094 this morning, according to CoinDesk.", + ), + ], + [ + Message( + role="user", + content="Correct this to standard English: Much many brands and sellers still in the market.\n" + "---\n", + ), + Message( + role="assistant", + content="Corrected: Many brands and sellers still in the market.", + ), + ], + ] + + template = GrammarErrorCorrectionTemplate() + + def test_call(self): + for sample, expected_prompt in zip(self.samples, self.expected_prompts): + actual = self.template(sample["messages"]) + assert_dialogue_equal(actual, expected_prompt) + + +class TestSummarizeTemplate: + samples = [ + { + "messages": [ + Message( + role="user", + content="Amanda: I baked cookies. Do you want some? Jerry: Sure! Amanda: I'll bring you tomorrow :-)", + ), + Message( + role="assistant", + content="Amanda baked cookies and will bring Jerry some tomorrow.", + ), + ], + }, + { + "messages": [ + Message( + role="user", + content="Olivia: Who are you voting for in this election? Oliver: Liberals as always. Olivia: Me too!! Oliver: Great", # noqa: B950 + ), + Message( + role="assistant", + content="Olivia and Olivier are voting for liberals in this election.", + ), + ], + }, + ] + expected_prompts = [ + [ + Message( + role="user", + content="Summarize this dialogue:\n" + "Amanda: I baked cookies. Do you want some? Jerry: Sure! Amanda: I'll bring you tomorrow :-)\n" + "---\n", + ), + Message( + role="assistant", + content="Summary:\n" + "Amanda baked cookies and will bring Jerry some tomorrow.", + ), + ], + [ + Message( + role="user", + content="Summarize this dialogue:\n" + "Olivia: Who are you voting for in this election? Oliver: Liberals as always. Olivia: Me too!! Oliver: Great\n" + "---\n", + ), + Message( + role="assistant", + content="Summary:\n" + "Olivia and Olivier are voting for liberals in this election.", + ), + ], + ] + + template = SummarizeTemplate() + + def test_call(self): + for sample, expected_prompt in zip(self.samples, self.expected_prompts): + actual = self.template(sample["messages"]) + assert_dialogue_equal(actual, expected_prompt) diff --git a/tests/torchtune/datasets/test_finetune_dataset.py b/tests/torchtune/datasets/test_finetune_dataset.py new file mode 100644 index 0000000000..2c8dccb2b2 --- /dev/null +++ b/tests/torchtune/datasets/test_finetune_dataset.py @@ -0,0 +1,110 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +from typing import Any, Mapping +from unittest import mock + +import pytest +from tests.test_utils import DummyPromptTemplate, DummyTokenizer +from torchtune.data import Message +from torchtune.data._common import CROSS_ENTROPY_IGNORE_IDX +from torchtune.datasets._finetune import FinetuneDataset +from torchtune.modules.transforms import Transform + + +class ToDummyMessages(Transform): + def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]: + dialogue = sample["dialogue"] + messages = [Message.from_dict(d) for d in dialogue] + return {"messages": messages} + + +class TestFinetuneDataset: + @pytest.fixture + def dialogue(self): + return [ + { + "dialogue": [ + { + "role": "system", + "content": "You are an AI assistant.", + "masked": True, + }, + { + "role": "user", + "content": "What is the meaning of life?", + "masked": True, + }, + { + "role": "assistant", + "content": "The meaning of life is 42.", + "masked": False, + }, + { + "role": "user", + "content": "That's ridiculous.", + "masked": True, + }, + {"role": "assistant", "content": "I agree.", "masked": False}, + ], + }, + ] + + @mock.patch("torchtune.datasets._finetune.load_dataset") + def test_get_item(self, mock_load_dataset, dialogue): + mock_load_dataset.return_value = dialogue + expected_tokenized_prompts = [ + [ + 0, + 7, + 3, + 3, + 2, + 2, + 10, + 5, + 4, + 2, + 3, + 7, + 2, + 5, + 10, + 3, + 7, + 2, + 4, + 2, + 3, + -1, + 0, + 5, + 6, + 11, + 10, + 1, + 6, + -1, + ] + ] + prompt_lengths = (14, 4) + expected_labels = [ + [CROSS_ENTROPY_IGNORE_IDX] * prompt_lengths[0] + + [10, 3, 7, 2, 4, 2, 3, -1] + + [CROSS_ENTROPY_IGNORE_IDX] * prompt_lengths[1] + + [10, 1, 6, -1] + ] + ds = FinetuneDataset( + source="iam/agoofy/goober", + message_transform=ToDummyMessages(), + model_transform=DummyTokenizer(), + prompt_template=DummyPromptTemplate(), + ) + assert len(ds) == 1 + mock_load_dataset.assert_called_once() + prompt, label = ds[0]["tokens"], ds[0]["labels"] + assert prompt == expected_tokenized_prompts[0] + assert label == expected_labels[0] diff --git a/tests/torchtune/datasets/test_grammar_dataset.py b/tests/torchtune/datasets/test_grammar_dataset.py index 524ee95592..10b2b2b425 100644 --- a/tests/torchtune/datasets/test_grammar_dataset.py +++ b/tests/torchtune/datasets/test_grammar_dataset.py @@ -39,10 +39,36 @@ def test_label_no_masking(self, load_dataset, tokenizer): grammar_ds = grammar_dataset(model_transform=tokenizer, train_on_input=True) input, labels = grammar_ds[0]["tokens"], grammar_ds[0]["labels"] - assert len(input) == len(labels) - assert labels[-1] == tokenizer.eos_id - assert input[0] == tokenizer.bos_id - assert CROSS_ENTROPY_IGNORE_IDX not in labels + assert input == [ + 0, + 7, + 4, + 2, + 8, + 8, + 7, + 2, + 3, + 6, + 4, + 8, + 5, + 8, + 5, + 3, + 10, + 7, + 4, + 3, + 6, + 4, + 8, + 9, + 2, + 9, + -1, + ] + assert labels == input @patch("torchtune.datasets._finetune.load_dataset") def test_label_masking(self, load_dataset, tokenizer): @@ -65,8 +91,34 @@ def test_label_masking(self, load_dataset, tokenizer): # Generate the input and labels input, labels = grammar_ds[0]["tokens"], grammar_ds[0]["labels"] - assert len(input) == len(labels) - assert labels[-1] == tokenizer.eos_id - assert input[0] == tokenizer.bos_id + assert input == [ + 0, + 7, + 4, + 2, + 8, + 8, + 7, + 2, + 3, + 6, + 4, + 8, + 5, + 8, + 5, + 3, + 10, + 7, + 4, + 3, + 6, + 4, + 8, + 9, + 2, + 9, + -1, + ] # Check that the input is masked assert labels.count(CROSS_ENTROPY_IGNORE_IDX) == 16 diff --git 
a/tests/torchtune/datasets/test_samsum_dataset.py b/tests/torchtune/datasets/test_samsum_dataset.py index a41903ab4c..9b18e43455 100644 --- a/tests/torchtune/datasets/test_samsum_dataset.py +++ b/tests/torchtune/datasets/test_samsum_dataset.py @@ -40,10 +40,41 @@ def test_label_no_masking(self, load_dataset, tokenizer): samsum_ds = samsum_dataset(model_transform=tokenizer, train_on_input=True) input, labels = samsum_ds[0]["tokens"], samsum_ds[0]["labels"] - assert len(input) == len(labels) - assert labels[-1] == tokenizer.eos_id - assert input[0] == tokenizer.bos_id - assert CROSS_ENTROPY_IGNORE_IDX not in labels + assert input == [ + 0, + 9, + 4, + 9, + 7, + 1, + 5, + 8, + 2, + 3, + 4, + 5, + 6, + 5, + 7, + 4, + 5, + 3, + 8, + 3, + 3, + 8, + 6, + 5, + 7, + 3, + 4, + 5, + 5, + 4, + 9, + -1, + ] + assert labels == input @patch("torchtune.datasets._finetune.load_dataset") def test_label_masking(self, load_dataset, tokenizer): @@ -64,16 +95,41 @@ def test_label_masking(self, load_dataset, tokenizer): samsum_ds = samsum_dataset(model_transform=tokenizer) - # Extract the prompt and tokenize it; we'll need this to test whether we're masking the - # input correctly - sample = samsum_ds._data[0] - prompt = samsum_ds.template.format(sample=sample) - encoded_prompt = tokenizer.encode(text=prompt, add_bos=True, add_eos=False) - # Generate the input and labels input, labels = samsum_ds[0]["tokens"], samsum_ds[0]["labels"] - assert len(input) == len(labels) - assert labels[-1] == tokenizer.eos_id - assert input[0] == tokenizer.bos_id - assert labels.count(CROSS_ENTROPY_IGNORE_IDX) == len(encoded_prompt) + assert input == [ + 0, + 9, + 4, + 9, + 7, + 1, + 5, + 8, + 2, + 3, + 4, + 5, + 6, + 5, + 7, + 4, + 5, + 3, + 8, + 3, + 3, + 8, + 6, + 5, + 7, + 3, + 4, + 5, + 5, + 4, + 9, + -1, + ] + assert labels.count(CROSS_ENTROPY_IGNORE_IDX) == 21 diff --git a/torchtune/data/__init__.py b/torchtune/data/__init__.py index dbd250388c..9c21b3bdad 100644 --- a/torchtune/data/__init__.py +++ b/torchtune/data/__init__.py @@ -17,7 +17,7 @@ InstructTemplate, StackExchangedPairedTemplate, ) -from torchtune.data._messages import Message, Role +from torchtune.data._messages import Message, Role, ToInputOutputMessages from torchtune.data._prompt_templates import ( CustomPromptTemplate, GrammarErrorCorrectionTemplate, @@ -45,4 +45,5 @@ "Role", "CustomPromptTemplate", "PromptTemplate", + "ToInputOutputMessages", ] diff --git a/torchtune/data/_messages.py b/torchtune/data/_messages.py index dc04d20805..85a0cd4816 100644 --- a/torchtune/data/_messages.py +++ b/torchtune/data/_messages.py @@ -24,7 +24,9 @@ class Message: the appropriate special tokens based on the flags set in this class. Args: - role (Role): role of the message writer. Can be "system", "user", "assistant", or "ipython". + role (Role): role of the message writer. Can be "system" for system prompts, + "user" for human prompts, "assistant" for model responses, or "ipython" + for tool call returns. content (Union[str, List[Dict[str, str]]]): content of the message. If it is text only content, you can pass in a string. If it is multimodal content, pass in a list of dictionaries formatted as follows:: @@ -101,11 +103,11 @@ def text_content(self) -> str: def _validate_message(self) -> None: if self.ipython and self.contains_media: - raise RuntimeError( + raise ValueError( f"Media tokens in tool calls are not supported. 
Both are set in message: {self.text_content}" ) if self.ipython and self.role != "assistant": - raise RuntimeError( + raise ValueError( f"Only assistant messages can be tool calls. Found role {self.role} in message: {self.text_content}" ) @@ -135,5 +137,4 @@ def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]: eot=True, ), ] - sample["messages"] = messages - return sample + return {"messages": messages} diff --git a/torchtune/data/_prompt_templates.py b/torchtune/data/_prompt_templates.py index 45b7dbd99c..8e825a3a8d 100644 --- a/torchtune/data/_prompt_templates.py +++ b/torchtune/data/_prompt_templates.py @@ -39,27 +39,34 @@ def __call__( class CustomPromptTemplate(PromptTemplate): """ - Define a custom prompt template by passing in a dictionary mapping role to + Define a quick custom prompt template by passing in a dictionary mapping role to the prepend and append tags. For example, to achieve the following prompt template:: - System: {content}\n - User: {content}\n - Assistant: {content}\n - Tool: {content}\n + System: {content}\\n + User: {content}\\n + Assistant: {content}\\n + Tool: {content}\\n You can define the template as follows:: template = { - "system": ("System: ", "\n"), - "user": ("User: ", "\n"), - "assistant": ("Assistant: ", "\n"), - "ipython": ("Tool: ", "\n"), + "system": ("System: ", "\\n"), + "user": ("User: ", "\\n"), + "assistant": ("Assistant: ", "\\n"), + "ipython": ("Tool: ", "\\n"), } Once instantiated, you must call the prompt template on a list of messages. It will return the same list of messages updated with the template. + Note: + Any tags prepended/appended to the assistant message will be included + in the loss calculation. Consider using the append tags for user messages for + tags that need to come before the assistant message but should not be included in + loss. For more custom masking and prompt templating, you can create your own + class based off the :class:`~torchtune.data.PromptTemplate` interface. + Args: template (Dict[Role, Tuple[str, str]]): a dictionary mapping role to the prepend and append tags @@ -106,8 +113,7 @@ def __call__(self, messages: List[Message]) -> List[Message]: GrammarErrorCorrectionTemplate = partial( CustomPromptTemplate, template={ - "user": ("Correct this to standard English: ", "\n---\n"), - "assistant": ("Corrected: ", ""), + "user": ("Correct this to standard English: ", "\n---\nCorrected: "), }, ) GrammarErrorCorrectionTemplate.__doc__ = """ @@ -117,12 +123,12 @@ def __call__(self, messages: List[Message]) -> List[Message]: --- Corrected: {assistant_message} +Please see :class:`~torchtune.data.CustomPromptTemplate` for full API arguments. """ SummarizeTemplate = partial( CustomPromptTemplate, template={ - "user": ("Summarize this dialogue:\n", "\n---\n"), - "assistant": ("Summary:\n", ""), + "user": ("Summarize this dialogue:\n", "\n---\nSummary:\n"), }, ) SummarizeTemplate.__doc__ = """ @@ -134,4 +140,5 @@ def __call__(self, messages: List[Message]) -> List[Message]: Summary: {assistant_message} +Please see :class:`~torchtune.data.CustomPromptTemplate` for full API arguments. """ diff --git a/torchtune/datasets/_finetune.py b/torchtune/datasets/_finetune.py index e746938a49..c38ac0a1e7 100644 --- a/torchtune/datasets/_finetune.py +++ b/torchtune/datasets/_finetune.py @@ -21,6 +21,7 @@ class FinetuneDataset(Dataset): All datasets can be considered "conversations" with the model, or AI assistant. 
Thus, we can format all text content as messages in a conversation assigned to a :class:`~torchtune.data.Role`: + - system messages contain the system prompt - user messages contain the input prompt into the model - assistant messages are the response of the model and what you actually want @@ -76,10 +77,12 @@ class FinetuneDataset(Dataset): transforms. prompt_template (Optional[PromptTemplate]): template used to format the messages based on their role. This is used to add structured text around the actual messages. The structured text is used in three scenarios: + - Task-specific templates to gear models for a particular task that it will expect after training - Model-specific templates that are required whenever the model is prompted, such as the [INST] tags in Llama2 and in Mistral - Community standardized templates, such as :class:`~torchtune.data.ChatMLTemplate` + The extra text will still get tokenized as normal text, not as special tokens. filter_fn (Optional[Callable]): callable used to filter the dataset prior to any pre-processing. See the Hugging Face `docs `_ for more @@ -102,7 +105,8 @@ def __init__( self._model_transform = model_transform self._data = load_dataset(source, **load_dataset_kwargs) - self._data = self._data.filter(filter_fn) + if filter_fn is not None: + self._data = self._data.filter(filter_fn) def __len__(self): return len(self._data) diff --git a/torchtune/datasets/_grammar.py b/torchtune/datasets/_grammar.py index 2250a40775..72b7eeb836 100644 --- a/torchtune/datasets/_grammar.py +++ b/torchtune/datasets/_grammar.py @@ -44,12 +44,21 @@ def grammar_dataset( - If ``train_on_input`` is False, the prompt is masked out (tokens replaced with -100) Args: - tokenizer (ModelTokenizer): Tokenizer used by the model that implements the ``tokenize_messages`` method. - source (str): path string of dataset, anything supported by Hugging Face's ``load_dataset``. + model_transform (Transform): model specific transform to convert a list of messages + output by the dataset to tokens. This will always be a :class:`~torchtune.modules.tokenizers.ModelTokenizer`. + source (str): path to dataset repository on Hugging Face. For local datasets, + define source as the data file type (e.g. "json", "csv", "text") and pass + in the filepath in ``data_files``. See Hugging Face's ``load_dataset`` + (https://huggingface.co/docs/datasets/en/package_reference/loading_methods#datasets.load_dataset.path) + for more details. Default is ``liweili/c4_200m``. + column_map (Optional[Dict[str, str]]): a mapping from the expected columns in the prompt template + to the new column names in the dataset. If None, assume these are identical. prompt_template (Optional[PromptTemplate]): optional template used to format the prompt. Default is :class:`~torchtune.data.GrammarErrorCorrectionTemplate`. train_on_input (bool): Whether the model is trained on the prompt or not. Default is False. packed (bool): Whether or not to pack the dataset to ``max_seq_len`` prior to training. Default is False. + split (str): ``split`` argument for ``datasets.load_dataset``. You can use this argument to load a subset + of a given split, e.g. ``split="train[:10%]"``. Default is "train". 
Returns: InstructDataset: dataset configured with source data and template diff --git a/torchtune/datasets/_samsum.py b/torchtune/datasets/_samsum.py index d40fe4d9d8..b6ba76340c 100644 --- a/torchtune/datasets/_samsum.py +++ b/torchtune/datasets/_samsum.py @@ -40,8 +40,17 @@ def samsum_dataset( - If ``train_on_input`` is False, the prompt is masked out (tokens replaced with -100) Args: - tokenizer (ModelTokenizer): Tokenizer used by the model that implements the ``tokenize_messages`` method. - source (str): path string of dataset, anything supported by Hugging Face's `load_dataset`. + model_transform (Transform): model specific transform to convert a list of messages + output by the dataset to tokens. This will always be a :class:`~torchtune.modules.tokenizers.ModelTokenizer`. + source (str): path to dataset repository on Hugging Face. For local datasets, + define source as the data file type (e.g. "json", "csv", "text") and pass + in the filepath in ``data_files``. See Hugging Face's ``load_dataset`` + (https://huggingface.co/docs/datasets/en/package_reference/loading_methods#datasets.load_dataset.path) + for more details. Default is ``Samsung/samsum``. + column_map (Optional[Dict[str, str]]): a mapping from the expected columns in the prompt template + to the new column names in the dataset. If None, assume these are identical. + prompt_template (Optional[PromptTemplate]): optional template used to format the prompt. Default + is :class:`~torchtune.data.SummarizeTemplate`. train_on_input (bool): Whether the model is trained on the prompt or not. Default is False. packed (bool): Whether or not to pack the dataset to ``max_seq_len`` prior to training. Default is False. split (str): ``split`` argument for ``datasets.load_dataset``. You can use this argument to load a subset diff --git a/torchtune/models/llama3/_tokenizer.py b/torchtune/models/llama3/_tokenizer.py index 107cd9a86d..191c944425 100644 --- a/torchtune/models/llama3/_tokenizer.py +++ b/torchtune/models/llama3/_tokenizer.py @@ -49,6 +49,8 @@ class Llama3Tokenizer(ModelTokenizer, Transform): special_tokens (Optional[Dict[str, int]]): mapping containing special text tokens and their registered token IDs. If left as None, this will be set to the canonical Llama3 special tokens. + max_seq_len (Optional[int]): maximum sequence length for tokenizing a single list of messages, + after which the input will be truncated. Default is None. Examples: >>> tokenizer = Llama3Tokenizer("/path/to/tt_model") @@ -226,7 +228,6 @@ def tokenize_messages( Args: messages (List[Message]): The list of messages to tokenize. - max_seq_len (Optional[int]): The maximum sequence length. add_eos (bool): Whether to add the tokenizer's eos_id. Default True.
Returns: From df00fe10c1f3ef830baff992f50cb0660645b4ae Mon Sep 17 00:00:00 2001 From: RdoubleA Date: Tue, 30 Jul 2024 16:22:48 -0700 Subject: [PATCH 05/16] fix tests --- tests/torchtune/data/test_messages.py | 8 +++---- tests/torchtune/data/test_prompt_templates.py | 22 ++++++++++--------- .../datasets/test_grammar_dataset.py | 2 +- .../torchtune/datasets/test_samsum_dataset.py | 2 +- torchtune/data/_prompt_templates.py | 17 ++++++++------ torchtune/datasets/_chat.py | 2 +- torchtune/datasets/_instruct.py | 2 +- 7 files changed, 30 insertions(+), 25 deletions(-) diff --git a/tests/torchtune/data/test_messages.py b/tests/torchtune/data/test_messages.py index 243e4573a8..c3813f23a6 100644 --- a/tests/torchtune/data/test_messages.py +++ b/tests/torchtune/data/test_messages.py @@ -21,7 +21,7 @@ def image_message(self): content=[ {"type": "text", "content": "hello"}, {"type": "image"}, - {"type": "text", "content": "world"}, + {"type": "text", "content": " world"}, ], ) @@ -38,7 +38,7 @@ def test_message_validation(self, text_message): with pytest.raises( ValueError, - match="Media tokens in tool calls are not supported. Both are set in message: hello world", + match="Media tokens in tool calls are not supported. Both are set in message: ", ): message = Message( role="user", @@ -80,7 +80,7 @@ def test_call(self, sample): Message(role="user", content="hello world", masked=True, eot=False), Message(role="assistant", content="hello world", masked=False, eot=True), ] - assert_dialogue_equal(actual, expected) + assert_dialogue_equal(actual["messages"], expected) def test_call_train_on_input(self, sample): transform = ToInputOutputMessages( @@ -92,4 +92,4 @@ def test_call_train_on_input(self, sample): Message(role="user", content="hello world", masked=False, eot=False), Message(role="assistant", content="hello world", masked=False, eot=True), ] - assert_dialogue_equal(actual, expected) + assert_dialogue_equal(actual["messages"], expected) diff --git a/tests/torchtune/data/test_prompt_templates.py b/tests/torchtune/data/test_prompt_templates.py index 363a338eec..6bcd3b8d6f 100644 --- a/tests/torchtune/data/test_prompt_templates.py +++ b/tests/torchtune/data/test_prompt_templates.py @@ -40,22 +40,24 @@ class TestGrammarErrorCorrectionTemplate: Message( role="user", content="Correct this to standard English: Bitcoin is for $7,094 this morning, which CoinDesk says.\n" - "---\n", + "---\n" + "Corrected: ", ), Message( role="assistant", - content="Corrected: Bitcoin goes for $7,094 this morning, according to CoinDesk.", + content="Bitcoin goes for $7,094 this morning, according to CoinDesk.", ), ], [ Message( role="user", content="Correct this to standard English: Much many brands and sellers still in the market.\n" - "---\n", + "---\n" + "Corrected: ", ), Message( role="assistant", - content="Corrected: Many brands and sellers still in the market.", + content="Many brands and sellers still in the market.", ), ], ] @@ -101,12 +103,12 @@ class TestSummarizeTemplate: role="user", content="Summarize this dialogue:\n" "Amanda: I baked cookies. Do you want some? Jerry: Sure! Amanda: I'll bring you tomorrow :-)\n" - "---\n", + "---\n" + "Summary:\n", ), Message( role="assistant", - content="Summary:\n" - "Amanda baked cookies and will bring Jerry some tomorrow.", + content="Amanda baked cookies and will bring Jerry some tomorrow.", ), ], [ @@ -114,12 +116,12 @@ class TestSummarizeTemplate: role="user", content="Summarize this dialogue:\n" "Olivia: Who are you voting for in this election? 
Oliver: Liberals as always. Olivia: Me too!! Oliver: Great\n" - "---\n", + "---\n" + "Summary:\n", ), Message( role="assistant", - content="Summary:\n" - "Olivia and Olivier are voting for liberals in this election.", + content="Olivia and Olivier are voting for liberals in this election.", ), ], ] diff --git a/tests/torchtune/datasets/test_grammar_dataset.py b/tests/torchtune/datasets/test_grammar_dataset.py index 10b2b2b425..dd94e530b2 100644 --- a/tests/torchtune/datasets/test_grammar_dataset.py +++ b/tests/torchtune/datasets/test_grammar_dataset.py @@ -121,4 +121,4 @@ def test_label_masking(self, load_dataset, tokenizer): -1, ] # Check that the input is masked - assert labels.count(CROSS_ENTROPY_IGNORE_IDX) == 16 + assert labels.count(CROSS_ENTROPY_IGNORE_IDX) == 17 diff --git a/tests/torchtune/datasets/test_samsum_dataset.py b/tests/torchtune/datasets/test_samsum_dataset.py index 9b18e43455..e31d4bd310 100644 --- a/tests/torchtune/datasets/test_samsum_dataset.py +++ b/tests/torchtune/datasets/test_samsum_dataset.py @@ -132,4 +132,4 @@ def test_label_masking(self, load_dataset, tokenizer): 9, -1, ] - assert labels.count(CROSS_ENTROPY_IGNORE_IDX) == 21 + assert labels.count(CROSS_ENTROPY_IGNORE_IDX) == 22 diff --git a/torchtune/data/_prompt_templates.py b/torchtune/data/_prompt_templates.py index 8e825a3a8d..b07112885e 100644 --- a/torchtune/data/_prompt_templates.py +++ b/torchtune/data/_prompt_templates.py @@ -91,13 +91,16 @@ def __call__(self, messages: List[Message]) -> List[Message]: """ formatted_dialogue = [] for message in messages: - prepend_tag = self.template[message.role][0] - append_tag = self.template[message.role][1] - content = ( - [{"type": "text", "content": prepend_tag}] - + message.content - + [{"type": "text", "content": append_tag}] - ) + if message.role in self.template: + prepend_tag = self.template[message.role][0] + append_tag = self.template[message.role][1] + content = ( + [{"type": "text", "content": prepend_tag}] + + message.content + + [{"type": "text", "content": append_tag}] + ) + else: + content = message.content formatted_dialogue.append( Message( role=message.role, diff --git a/torchtune/datasets/_chat.py b/torchtune/datasets/_chat.py index 9b9bca9b48..36597fc7b2 100644 --- a/torchtune/datasets/_chat.py +++ b/torchtune/datasets/_chat.py @@ -112,8 +112,8 @@ def _prepare_sample(self, sample: Mapping[str, Any]) -> Dict[str, List[int]]: def chat_dataset( - *, tokenizer: ModelTokenizer, + *, source: str, conversation_style: str, chat_format: Optional[str] = None, diff --git a/torchtune/datasets/_instruct.py b/torchtune/datasets/_instruct.py index 9918f6a5ea..ba0a66f667 100644 --- a/torchtune/datasets/_instruct.py +++ b/torchtune/datasets/_instruct.py @@ -121,8 +121,8 @@ def _prepare_sample(self, sample: Mapping[str, Any]) -> Dict[str, List[int]]: def instruct_dataset( - *, tokenizer: ModelTokenizer, + *, source: str, template: str, column_map: Optional[Dict[str, str]] = None, From 4157dd7b1e527f2961da335b513968838848ceb4 Mon Sep 17 00:00:00 2001 From: RdoubleA Date: Tue, 30 Jul 2024 21:48:45 -0700 Subject: [PATCH 06/16] change naming --- tests/torchtune/data/test_messages.py | 8 ++++---- torchtune/data/__init__.py | 4 ++-- torchtune/data/_messages.py | 2 +- torchtune/datasets/_grammar.py | 4 ++-- torchtune/datasets/_samsum.py | 4 ++-- 5 files changed, 11 insertions(+), 11 deletions(-) diff --git a/tests/torchtune/data/test_messages.py b/tests/torchtune/data/test_messages.py index c3813f23a6..6890b78c9c 100644 --- a/tests/torchtune/data/test_messages.py +++ 
b/tests/torchtune/data/test_messages.py @@ -6,7 +6,7 @@ import pytest from tests.test_utils import assert_dialogue_equal -from torchtune.data._messages import Message, ToInputOutputMessages +from torchtune.data._messages import InputOutputToMessages, Message class TestMessage: @@ -63,7 +63,7 @@ def test_text_content(self, text_message, image_message): assert image_message.text_content == "hello world" -class TestToInputOutputMessages: +class TestInputOutputToMessages: @pytest.fixture def sample(self): return { @@ -72,7 +72,7 @@ def sample(self): } def test_call(self, sample): - transform = ToInputOutputMessages( + transform = InputOutputToMessages( column_map={"input": "maybe_input", "output": "maybe_output"} ) actual = transform(sample) @@ -83,7 +83,7 @@ def test_call(self, sample): assert_dialogue_equal(actual["messages"], expected) def test_call_train_on_input(self, sample): - transform = ToInputOutputMessages( + transform = InputOutputToMessages( column_map={"input": "maybe_input", "output": "maybe_output"}, train_on_input=True, ) diff --git a/torchtune/data/__init__.py b/torchtune/data/__init__.py index 9c21b3bdad..fafa7d2ae1 100644 --- a/torchtune/data/__init__.py +++ b/torchtune/data/__init__.py @@ -17,7 +17,7 @@ InstructTemplate, StackExchangedPairedTemplate, ) -from torchtune.data._messages import Message, Role, ToInputOutputMessages +from torchtune.data._messages import InputOutputToMessages, Message, Role from torchtune.data._prompt_templates import ( CustomPromptTemplate, GrammarErrorCorrectionTemplate, @@ -45,5 +45,5 @@ "Role", "CustomPromptTemplate", "PromptTemplate", - "ToInputOutputMessages", + "InputOutputToMessages", ] diff --git a/torchtune/data/_messages.py b/torchtune/data/_messages.py index 85a0cd4816..2c75ece6d2 100644 --- a/torchtune/data/_messages.py +++ b/torchtune/data/_messages.py @@ -112,7 +112,7 @@ def _validate_message(self) -> None: ) -class ToInputOutputMessages(Transform): +class InputOutputToMessages(Transform): def __init__( self, train_on_input: bool = False, column_map: Optional[Dict[str, str]] = None ): diff --git a/torchtune/datasets/_grammar.py b/torchtune/datasets/_grammar.py index 72b7eeb836..faf8f2b1e8 100644 --- a/torchtune/datasets/_grammar.py +++ b/torchtune/datasets/_grammar.py @@ -8,7 +8,7 @@ from typing import Dict, Optional from torch.utils.data import Dataset -from torchtune.data import ToInputOutputMessages +from torchtune.data import InputOutputToMessages from torchtune.data._prompt_templates import ( GrammarErrorCorrectionTemplate, PromptTemplate, @@ -71,7 +71,7 @@ def grammar_dataset( >>> Batch size: 8 """ - message_transform = ToInputOutputMessages( + message_transform = InputOutputToMessages( train_on_input=train_on_input, column_map=column_map ) ds = FinetuneDataset( diff --git a/torchtune/datasets/_samsum.py b/torchtune/datasets/_samsum.py index b6ba76340c..2d106b2dc6 100644 --- a/torchtune/datasets/_samsum.py +++ b/torchtune/datasets/_samsum.py @@ -7,7 +7,7 @@ from typing import Dict, Optional -from torchtune.data import ToInputOutputMessages +from torchtune.data import InputOutputToMessages from torchtune.data._prompt_templates import PromptTemplate, SummarizeTemplate from torchtune.datasets._finetune import FinetuneDataset from torchtune.datasets._packed import PackedDataset @@ -67,7 +67,7 @@ def samsum_dataset( >>> Batch size: 8 """ column_map = column_map or {"input": "dialogue", "output": "summary"} - message_transform = ToInputOutputMessages( + message_transform = InputOutputToMessages( train_on_input=train_on_input, 
column_map=column_map ) ds = FinetuneDataset( From ba2e2ecc4f819a59d4ed8ee1ea3faf0747020cfa Mon Sep 17 00:00:00 2001 From: RdoubleA Date: Wed, 31 Jul 2024 11:55:19 -0700 Subject: [PATCH 07/16] fix recipe tests --- tests/recipes/utils.py | 2 +- torchtune/datasets/_chat.py | 2 +- torchtune/datasets/_instruct.py | 2 +- torchtune/datasets/_preference.py | 4 ++-- torchtune/models/gemma/_tokenizer.py | 16 ++++++++++------ torchtune/models/llama2/_tokenizer.py | 16 ++++++++++------ torchtune/models/mistral/_tokenizer.py | 15 +++++++++------ torchtune/models/phi3/_tokenizer.py | 22 +++++++++++++--------- torchtune/models/qwen2/_tokenizer.py | 15 +++++++++------ torchtune/modules/tokenizers/_utils.py | 1 + 10 files changed, 57 insertions(+), 38 deletions(-) diff --git a/tests/recipes/utils.py b/tests/recipes/utils.py index 66297984fb..0fc73a7c53 100644 --- a/tests/recipes/utils.py +++ b/tests/recipes/utils.py @@ -20,7 +20,7 @@ class DummyDataset(Dataset): - def __init__(self, **kwargs): + def __init__(self, *args, **kwargs): self._data = torch.LongTensor( [ [0, 2, 4, 2, 5, 6, 7, 8, 9, 1, 2, 4, 3, 3, 5, 6, 8, 2, 1, 1], diff --git a/torchtune/datasets/_chat.py b/torchtune/datasets/_chat.py index 36597fc7b2..323f1dfa12 100644 --- a/torchtune/datasets/_chat.py +++ b/torchtune/datasets/_chat.py @@ -102,7 +102,7 @@ def _prepare_sample(self, sample: Mapping[str, Any]) -> Dict[str, List[int]]: messages = self.chat_format.format(messages) validate_messages(messages) tokens, mask = self._tokenizer.tokenize_messages( - messages, max_seq_len=self.max_seq_len + messages, ) # Wherever mask == True, set to CROSS_ENTROPY_IGNORE_IDX. Otherwise keep as tokens labels = list(np.where(mask, CROSS_ENTROPY_IGNORE_IDX, tokens)) diff --git a/torchtune/datasets/_instruct.py b/torchtune/datasets/_instruct.py index ba0a66f667..cb0006663a 100644 --- a/torchtune/datasets/_instruct.py +++ b/torchtune/datasets/_instruct.py @@ -110,7 +110,7 @@ def _prepare_sample(self, sample: Mapping[str, Any]) -> Dict[str, List[int]]: validate_messages(messages) tokens, mask = self._tokenizer.tokenize_messages( - messages, max_seq_len=self.max_seq_len + messages, ) # Wherever mask == True, set to CROSS_ENTROPY_IGNORE_IDX. Otherwise keep as tokens diff --git a/torchtune/datasets/_preference.py b/torchtune/datasets/_preference.py index 790cf2af22..476b04323a 100644 --- a/torchtune/datasets/_preference.py +++ b/torchtune/datasets/_preference.py @@ -97,14 +97,14 @@ def _prepare_sample(self, sample: Mapping[str, Any]) -> Dict[str, List[int]]: # TODO: Trunction differs from original DPO repo # in DPO: first truncate prompts, then responses chosen_input_ids, c_masks = self._tokenizer.tokenize_messages( - chosen_message, self.max_seq_len + chosen_message, ) chosen_labels = list( np.where(c_masks, CROSS_ENTROPY_IGNORE_IDX, chosen_input_ids) ) rejected_input_ids, r_masks = self._tokenizer.tokenize_messages( - rejected_message, self.max_seq_len + rejected_message, ) rejected_labels = list( np.where(r_masks, CROSS_ENTROPY_IGNORE_IDX, rejected_input_ids) diff --git a/torchtune/models/gemma/_tokenizer.py b/torchtune/models/gemma/_tokenizer.py index eb7fea4c1d..b2b3cd7dbb 100644 --- a/torchtune/models/gemma/_tokenizer.py +++ b/torchtune/models/gemma/_tokenizer.py @@ -23,6 +23,8 @@ class GemmaTokenizer(ModelTokenizer, Transform): Args: path (str): Path to pretrained tokenizer file. + max_seq_len (Optional[int]): A max sequence length to truncate tokens to. 
+ Default: None Examples: >>> tokenizer = GemmaTokenizer("/path/to/spm_model") @@ -34,6 +36,7 @@ class GemmaTokenizer(ModelTokenizer, Transform): def __init__( self, path: str, + max_seq_len: Optional[int] = None, ): self._spm_model = SentencePieceBaseTokenizer(path) @@ -43,6 +46,8 @@ def __init__( # During generation, stop when eos_id is encountered self.stop_tokens = [self.eos_id] + self.max_seq_len = max_seq_len + @property def eos_id(self): return self._spm_model.eos_id @@ -80,14 +85,15 @@ def decode( return self._spm_model.decode(token_ids) def tokenize_messages( - self, messages: List[Message], max_seq_len: Optional[int] = None + self, + messages: List[Message], ) -> Tuple[List[int], List[bool]]: r"""Tokenize a list of messages one at a time then concatenate them, returning a list of tokens and a list of masks. Example: - >>> tokenizer = GemmaTokenizer(tokenizer_path) + >>> tokenizer = GemmaTokenizer(tokenizer_path, max_seq_len) >>> messages = [ Message(role="system", content="system message\n", masked=True), Message(role="user", content="user prompt\n", masked=True), @@ -95,7 +101,7 @@ def tokenize_messages( ] >>> # tokenize_messages encodes messages separately and concats - >>> tokenizer.tokenize_messages(messages, max_seq_len)[0] + >>> tokenizer.tokenize_messages(messages)[0] [1, 1788, 2643, 13, 1792, 9508, 13, 465, 22137, 2933, 2] @@ -107,8 +113,6 @@ def tokenize_messages( Args: messages (List[Message]): A list of messages, each containing role, content, and masked attributes. - max_seq_len (Optional[int]): A max sequence length to truncate tokens to. - Default: None Returns: Tuple[List[int], List[bool]]: The tokenized messages @@ -118,7 +122,7 @@ def tokenize_messages( messages=messages, bos_id=self.bos_id, eos_id=self.eos_id, - max_seq_len=max_seq_len, + max_seq_len=self.max_seq_len, ) def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]: diff --git a/torchtune/models/llama2/_tokenizer.py b/torchtune/models/llama2/_tokenizer.py index d8c822d4c2..71f4ac2aef 100644 --- a/torchtune/models/llama2/_tokenizer.py +++ b/torchtune/models/llama2/_tokenizer.py @@ -31,6 +31,8 @@ class Llama2Tokenizer(ModelTokenizer, Transform): Args: path (str): Path to pretrained SentencePiece tokenizer file. + max_seq_len (Optional[int]): A max sequence length to truncate tokens to. + Default: None Examples: >>> tokenizer = Llama2Tokenizer("/path/to/spm_model") @@ -42,6 +44,7 @@ class Llama2Tokenizer(ModelTokenizer, Transform): def __init__( self, path: str, + max_seq_len: Optional[int] = None, ): self._spm_model = SentencePieceBaseTokenizer(path) @@ -51,6 +54,8 @@ def __init__( # During generation, stop when eos_id is encountered self.stop_tokens = [self.eos_id] + self.max_seq_len = max_seq_len + @property def eos_id(self): return self._spm_model.eos_id @@ -88,7 +93,8 @@ def decode( return self._spm_model.decode(token_ids) def tokenize_messages( - self, messages: List[Message], max_seq_len: Optional[int] = None + self, + messages: List[Message], ) -> Tuple[List[int], List[bool]]: r"""Tokenize a list of messages one at a time then concatenate them, returning a list of tokens and a list of masks. @@ -100,7 +106,7 @@ def tokenize_messages( beginning off the tokenized s2. 
Example: - >>> tokenizer = Llama2Tokenizer(tokenizer_path) + >>> tokenizer = Llama2Tokenizer(tokenizer_path, max_seq_len) >>> messages = [ Message(role="system", content="system message\n", masked=True), Message(role="user", content="user prompt\n", masked=True), @@ -108,7 +114,7 @@ def tokenize_messages( ] >>> # tokenize_messages encodes messages separately and concats - >>> tokenizer.tokenize_messages(messages, max_seq_len)[0] + >>> tokenizer.tokenize_messages(messages)[0] [1, 1788, 2643, 13, 1792, 9508, 13, 465, 22137, 2933, 2] >>> # Same result as encoding the full string in one go @@ -119,8 +125,6 @@ def tokenize_messages( Args: messages (List[Message]): A list of messages, each containing role, content, and masked attributes. - max_seq_len (Optional[int]): A max sequence length to truncate tokens to. - Default: None Returns: Tuple[List[int], List[bool]]: The tokenized messages @@ -130,7 +134,7 @@ def tokenize_messages( messages=messages, bos_id=self.bos_id, eos_id=self.eos_id, - max_seq_len=max_seq_len, + max_seq_len=self.max_seq_len, ) def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]: diff --git a/torchtune/models/mistral/_tokenizer.py b/torchtune/models/mistral/_tokenizer.py index 4d25a52143..7c63c17ca8 100644 --- a/torchtune/models/mistral/_tokenizer.py +++ b/torchtune/models/mistral/_tokenizer.py @@ -23,6 +23,8 @@ class MistralTokenizer(ModelTokenizer, Transform): Args: path (str): Path to pretrained tokenizer file. + max_seq_len (Optional[int]): A max sequence length to truncate tokens to. + Default: None Examples: >>> tokenizer = MistralTokenizer("/path/to/spm_model") @@ -34,6 +36,7 @@ class MistralTokenizer(ModelTokenizer, Transform): def __init__( self, path: str, + max_seq_len: Optional[int] = None, ): self._spm_model = SentencePieceBaseTokenizer(path) @@ -43,6 +46,8 @@ def __init__( # During generation, stop when eos_id is encountered self.stop_tokens = [self.eos_id] + self.max_seq_len = max_seq_len + @property def eos_id(self): return self._spm_model.eos_id @@ -103,7 +108,7 @@ def decode( return self._spm_model.decode(token_ids) def tokenize_messages( - self, messages: List[Message], max_seq_len: Optional[int] = None + self, messages: List[Message] ) -> Tuple[List[int], List[bool]]: r"""Tokenize a list of messages one at a time then concatenate them, returning a list of tokens and a list of masks. @@ -115,7 +120,7 @@ def tokenize_messages( beginning off the tokenized s2. Example: - >>> tokenizer = MistralTokenizer(tokenizer_path) + >>> tokenizer = MistralTokenizer(tokenizer_path, max_seq_len) >>> messages = [ Message(role="system", content="system message\n", masked=True), Message(role="user", content="user prompt\n", masked=True), @@ -123,7 +128,7 @@ def tokenize_messages( ] >>> # tokenize_messages encodes messages separately and concats - >>> tokenizer.tokenize_messages(messages, max_seq_len)[0] + >>> tokenizer.tokenize_messages(messages)[0] [1, 1788, 2643, 13, 1792, 9508, 13, 465, 22137, 2933, 2] @@ -135,8 +140,6 @@ def tokenize_messages( Args: messages (List[Message]): A list of messages, each containing role, content, and masked attributes. - max_seq_len (Optional[int]): A max sequence length to truncate tokens to. 
- Default: None Returns: Tuple[List[int], List[bool]]: The tokenized messages @@ -146,7 +149,7 @@ def tokenize_messages( messages=messages, bos_id=self.bos_id, eos_id=self.eos_id, - max_seq_len=max_seq_len, + max_seq_len=self.max_seq_len, ) def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]: diff --git a/torchtune/models/phi3/_tokenizer.py b/torchtune/models/phi3/_tokenizer.py index 958b5ac64e..221db90ed0 100644 --- a/torchtune/models/phi3/_tokenizer.py +++ b/torchtune/models/phi3/_tokenizer.py @@ -34,6 +34,8 @@ class Phi3MiniTokenizer(ModelTokenizer, Transform): special_tokens (Optional[Dict[str, int]]): mapping containing special text tokens and their registered token IDs. If left as None, this will be set to the canonical Phi3 special tokens. + max_seq_len (Optional[int]): A max sequence length to truncate tokens to. + Default: None Examples: >>> tokenizer = Phi3MiniTokenizer("/path/to/spm_model") @@ -46,6 +48,7 @@ def __init__( self, path: str, special_tokens: Optional[Dict[str, int]] = None, + max_seq_len: Optional[int] = None, ): self._spm_model = SentencePieceBaseTokenizer(path) @@ -60,6 +63,8 @@ def __init__( # During generation, stop when eos_id is encountered self.stop_tokens = [self.eos_id] + self.max_seq_len = max_seq_len + @property def vocab_size(self): return self._spm_model.vocab_size @@ -104,7 +109,6 @@ def decode(self, ids: List[int]) -> str: def tokenize_messages( self, messages: List[Message], - max_seq_len: Optional[int] = None, *, add_eos: bool = False, ignore_system_prompts: bool = True, @@ -113,7 +117,7 @@ def tokenize_messages( returning a list of tokens and a list of masks. Example: - >>> tokenizer = Phi3MiniTokenizer(tokenizer_path) + >>> tokenizer = Phi3MiniTokenizer(tokenizer_path, max_seq_len) >>> messages = [ Message(role="system", content="system message\n", masked=True), Message(role="user", content="user prompt\n", masked=True), @@ -121,7 +125,7 @@ def tokenize_messages( ] >>> # tokenize_messages encodes messages separately and concats - >>> tokenizer.tokenize_messages(messages, max_seq_len)[0] + >>> tokenizer.tokenize_messages(messages)[0] [1, 1788, 2643, 13, 1792, 9508, 13, 465, 22137, 2933, 2] >>> # Same result as encoding the full string in one go @@ -132,8 +136,6 @@ def tokenize_messages( Args: messages (List[Message]): A list of messages, each containing role, content, and masked attributes. - max_seq_len (Optional[int]): A max sequence length to truncate tokens to. - Default: None add_eos (bool): Whether to append EOS after assistant message, default to False ignore_system_prompts (bool): Whether to ignore system prompts. This matches the HF implementation, default to True. 
@@ -211,13 +213,15 @@ def tokenize_messages( start_of_turn = False # Break out early if we reach max_seq_len - if max_seq_len and len(tokenized_messages) >= max_seq_len: + if self.max_seq_len and len(tokenized_messages) >= self.max_seq_len: break # Finally, truncate if necessary - if max_seq_len and len(tokenized_messages) >= max_seq_len: - tokenized_messages = truncate(tokenized_messages, max_seq_len, self.eos_id) - mask = truncate(mask, max_seq_len, message.masked) + if self.max_seq_len and len(tokenized_messages) >= self.max_seq_len: + tokenized_messages = truncate( + tokenized_messages, self.max_seq_len, self.eos_id + ) + mask = truncate(mask, self.max_seq_len, message.masked) return tokenized_messages, mask diff --git a/torchtune/models/qwen2/_tokenizer.py b/torchtune/models/qwen2/_tokenizer.py index d5a8f3fa5d..8020564a0c 100644 --- a/torchtune/models/qwen2/_tokenizer.py +++ b/torchtune/models/qwen2/_tokenizer.py @@ -84,6 +84,8 @@ class Qwen2Tokenizer(ModelTokenizer): merges.txt contains all BPE merge operations, and this file is required to split a single word into byte-level BPE tokens. special_tokens (Optional[Dict[str, int]]): Special tokens to add to the tokenizer. Default is None. + max_seq_len (Optional[int]): A max sequence length to truncate tokens to. + Default: None errors (str): Paradigm to follow when decoding bytes to UTF-8. Defaults to "replace". See [bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode) for more information. unk_token (Optional[str]): The unknown token. A token that is not in the vocabulary cannot be converted @@ -120,6 +122,7 @@ def __init__( path: str, merges_file: str, special_tokens: Optional[Dict[str, int]] = None, + max_seq_len: Optional[int] = None, *, errors: str = "replace", unk_token: Optional[str] = ENDOFTEXT, @@ -166,6 +169,8 @@ def __init__( r"(\L)", options=self.special_tokens.keys() ) + self.max_seq_len = max_seq_len + def _bpe_without_cache(self, token): word = tuple(token) pairs = get_pairs(word) @@ -320,7 +325,6 @@ def decode( def tokenize_messages( self, messages: List[Message], - max_seq_len: Optional[int] = None, apply_chat_template: bool = True, ) -> Tuple[List[int], List[bool]]: """ @@ -329,7 +333,6 @@ def tokenize_messages( Args: messages (List[Message]): The message list to tokenize. - max_seq_len (Optional[int]): The maximum sequence length. apply_chat_template (bool): Whether to apply Qwen2 chat template. 
Returns: @@ -354,7 +357,7 @@ def tokenize_messages( tokens.extend(tokenized_message) mask.extend([message.masked] * len(tokenized_message)) - if max_seq_len and len(tokens) >= max_seq_len: + if self.max_seq_len and len(tokens) >= self.max_seq_len: break if not is_generation: @@ -363,7 +366,7 @@ def tokenize_messages( if messages: last_message_masked = messages[-1].masked mask = mask + [last_message_masked] - if max_seq_len: - tokens = truncate(tokens, max_seq_len, self.eos_id) - mask = truncate(mask, max_seq_len, True) + if self.max_seq_len: + tokens = truncate(tokens, self.max_seq_len, self.eos_id) + mask = truncate(mask, self.max_seq_len, True) return tokens, mask diff --git a/torchtune/modules/tokenizers/_utils.py b/torchtune/modules/tokenizers/_utils.py index f34af05165..a658008b1e 100644 --- a/torchtune/modules/tokenizers/_utils.py +++ b/torchtune/modules/tokenizers/_utils.py @@ -50,6 +50,7 @@ class ModelTokenizer(Protocol): """ special_tokens: Dict[str, int] + max_seq_len: Optional[int] def tokenize_messages( self, messages: List[Message], **kwargs: Dict[str, Any] From a531e481123c2729e37d94ed5d3580f466c58465 Mon Sep 17 00:00:00 2001 From: RdoubleA Date: Wed, 31 Jul 2024 12:53:17 -0700 Subject: [PATCH 08/16] remove content.strip() in tokenizer --- tests/torchtune/models/llama3/test_llama3_tokenizer.py | 1 + torchtune/models/llama3/_tokenizer.py | 2 +- torchtune/models/phi3/_tokenizer.py | 2 +- 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/torchtune/models/llama3/test_llama3_tokenizer.py b/tests/torchtune/models/llama3/test_llama3_tokenizer.py index 79569e0e29..2a0f26302d 100644 --- a/tests/torchtune/models/llama3/test_llama3_tokenizer.py +++ b/tests/torchtune/models/llama3/test_llama3_tokenizer.py @@ -194,6 +194,7 @@ def user_interleaved_image_text_message(self, user_text_a, user_text_b): 376, 110, 46, + 32, 128011, 1542, 720, diff --git a/torchtune/models/llama3/_tokenizer.py b/torchtune/models/llama3/_tokenizer.py index 191c944425..dea6225df3 100644 --- a/torchtune/models/llama3/_tokenizer.py +++ b/torchtune/models/llama3/_tokenizer.py @@ -178,7 +178,7 @@ def _tokenize_body(self, message: Message) -> List[int]: for item in message.content: if item["type"] == "text": tokenized_body += self.encode( - item["content"].strip(), add_bos=False, add_eos=False + item["content"], add_bos=False, add_eos=False ) elif item["type"] == "image": tokenized_body += [self.image_id] diff --git a/torchtune/models/phi3/_tokenizer.py b/torchtune/models/phi3/_tokenizer.py index 221db90ed0..8124ca1453 100644 --- a/torchtune/models/phi3/_tokenizer.py +++ b/torchtune/models/phi3/_tokenizer.py @@ -189,7 +189,7 @@ def tokenize_messages( for item in message.content: if item["type"] == "text": tokens = tokens + self.encode( - item["content"].rstrip(" "), + item["content"], add_bos=False, add_eos=False, trim_leading_whitespace=True, # Always trim whitespace (just to match HF tokenizer implementation) From c57a26a8812eacd6ec3a5a978152ab4bf1551125 Mon Sep 17 00:00:00 2001 From: RdoubleA Date: Wed, 31 Jul 2024 14:44:55 -0700 Subject: [PATCH 09/16] fix test --- tests/test_utils.py | 16 +++++++++++----- .../torchtune/datasets/test_slimorca_dataset.py | 12 ++++++------ 2 files changed, 17 insertions(+), 11 deletions(-) diff --git a/tests/test_utils.py b/tests/test_utils.py index 9182fff530..bd0dbf4988 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -42,6 +42,9 @@ class DummyTokenizer(ModelTokenizer, Transform): + def __init__(self, max_seq_len: Optional[int] = None): + 
self.max_seq_len = max_seq_len + def encode(self, text, add_bos=True, add_eos=True, **kwargs) -> List[int]: words = text.split() tokens = [len(word) for word in words] @@ -52,7 +55,8 @@ def encode(self, text, add_bos=True, add_eos=True, **kwargs) -> List[int]: return tokens def tokenize_messages( - self, messages: List[Message], max_seq_len: Optional[int] = None + self, + messages: List[Message], ) -> Tuple[List[int], List[bool]]: """ A simplified version of Llama2Tokenizer's ``tokenize_messages`` for testing purposes. @@ -95,13 +99,15 @@ def tokenize_messages( start_of_turn = False # Break out early if we reach max_seq_len - if max_seq_len and len(tokenized_messages) >= max_seq_len: + if self.max_seq_len and len(tokenized_messages) >= self.max_seq_len: break # Finally, truncate if necessary - if max_seq_len: - tokenized_messages = truncate(tokenized_messages, max_seq_len, self.eos_id) - mask = truncate(mask, max_seq_len, message.masked) + if self.max_seq_len: + tokenized_messages = truncate( + tokenized_messages, self.max_seq_len, self.eos_id + ) + mask = truncate(mask, self.max_seq_len, message.masked) return tokenized_messages, mask diff --git a/tests/torchtune/datasets/test_slimorca_dataset.py b/tests/torchtune/datasets/test_slimorca_dataset.py index 508f65ce11..0b8de933e4 100644 --- a/tests/torchtune/datasets/test_slimorca_dataset.py +++ b/tests/torchtune/datasets/test_slimorca_dataset.py @@ -14,19 +14,19 @@ class TestSlimOrcaDataset: - @pytest.fixture - def tokenizer(self): - return DummyTokenizer() + def tokenizer(self, max_seq_len=None): + return DummyTokenizer(max_seq_len=max_seq_len) @patch("torchtune.datasets._chat.load_dataset") - def test_value_error(self, load_dataset, tokenizer): + def test_value_error(self, load_dataset): load_dataset.return_value = [] with pytest.raises(ValueError): - slimorca_dataset(tokenizer=tokenizer, max_seq_len=3) + slimorca_dataset(tokenizer=self.tokenizer, max_seq_len=3) @patch("torchtune.datasets._chat.load_dataset") @pytest.mark.parametrize("max_seq_len", [128, 512, 1024, 4096]) - def test_dataset_get_item(self, load_dataset, tokenizer, max_seq_len): + def test_dataset_get_item(self, load_dataset, max_seq_len): + tokenizer = self.tokenizer(max_seq_len=max_seq_len) # Sample data from slimorca dataset load_dataset.return_value = Dataset.from_list( [ From 6dd1794afea364ae907c22ed77c9777a7e4ee718 Mon Sep 17 00:00:00 2001 From: RdoubleA Date: Thu, 1 Aug 2024 13:57:06 -0700 Subject: [PATCH 10/16] readd strip --- torchtune/models/llama3/_tokenizer.py | 2 +- torchtune/models/phi3/_tokenizer.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/torchtune/models/llama3/_tokenizer.py b/torchtune/models/llama3/_tokenizer.py index dea6225df3..191c944425 100644 --- a/torchtune/models/llama3/_tokenizer.py +++ b/torchtune/models/llama3/_tokenizer.py @@ -178,7 +178,7 @@ def _tokenize_body(self, message: Message) -> List[int]: for item in message.content: if item["type"] == "text": tokenized_body += self.encode( - item["content"], add_bos=False, add_eos=False + item["content"].strip(), add_bos=False, add_eos=False ) elif item["type"] == "image": tokenized_body += [self.image_id] diff --git a/torchtune/models/phi3/_tokenizer.py b/torchtune/models/phi3/_tokenizer.py index 8124ca1453..221db90ed0 100644 --- a/torchtune/models/phi3/_tokenizer.py +++ b/torchtune/models/phi3/_tokenizer.py @@ -189,7 +189,7 @@ def tokenize_messages( for item in message.content: if item["type"] == "text": tokens = tokens + self.encode( - item["content"], + 
item["content"].rstrip(" "), add_bos=False, add_eos=False, trim_leading_whitespace=True, # Always trim whitespace (just to match HF tokenizer implementation) From fed7dae256a17e6f70833fb3ce0d4c0e8e904562 Mon Sep 17 00:00:00 2001 From: RdoubleA Date: Thu, 1 Aug 2024 14:20:32 -0700 Subject: [PATCH 11/16] fix test --- tests/torchtune/models/llama3/test_llama3_tokenizer.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/torchtune/models/llama3/test_llama3_tokenizer.py b/tests/torchtune/models/llama3/test_llama3_tokenizer.py index 2a0f26302d..f1d2ad9317 100644 --- a/tests/torchtune/models/llama3/test_llama3_tokenizer.py +++ b/tests/torchtune/models/llama3/test_llama3_tokenizer.py @@ -194,7 +194,6 @@ def user_interleaved_image_text_message(self, user_text_a, user_text_b): 376, 110, 46, - 32, 128011, 1542, 720, @@ -401,7 +400,7 @@ def test_validate_special_tokens(self): with pytest.raises( ValueError, match="<|begin_of_text|> missing from special_tokens" ): - tokenizer = Llama3Tokenizer( + _ = Llama3Tokenizer( path=str(ASSETS / "tiktoken_small.model"), # Same as LLAMA3_SPECIAL_TOKENS but one missing special_tokens={ From 9822cdd8fc8fbcf90abd700f0b39ff20cba15efa Mon Sep 17 00:00:00 2001 From: RdoubleA Date: Fri, 2 Aug 2024 16:18:02 -0700 Subject: [PATCH 12/16] test torchvision instal; --- .github/workflows/build_linux_wheels.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/build_linux_wheels.yaml b/.github/workflows/build_linux_wheels.yaml index e45c95859d..46266bd043 100644 --- a/.github/workflows/build_linux_wheels.yaml +++ b/.github/workflows/build_linux_wheels.yaml @@ -42,3 +42,4 @@ jobs: pre-script: .github/scripts/pre_build_script.sh trigger-event: ${{ github.event_name }} build-platform: 'python-build-package' + pip-install-torch-extra-args: torchvision From a67ce9b9a3291029995017f75a79522b5b0db7a4 Mon Sep 17 00:00:00 2001 From: RdoubleA Date: Mon, 5 Aug 2024 10:01:22 -0700 Subject: [PATCH 13/16] buff docstrings --- torchtune/data/_messages.py | 14 +++++++++++ torchtune/data/_prompt_templates.py | 19 ++++++++++----- torchtune/datasets/_finetune.py | 38 ++++++++++++++++++----------- 3 files changed, 51 insertions(+), 20 deletions(-) diff --git a/torchtune/data/_messages.py b/torchtune/data/_messages.py index 2c75ece6d2..b42df602d5 100644 --- a/torchtune/data/_messages.py +++ b/torchtune/data/_messages.py @@ -113,6 +113,20 @@ def _validate_message(self) -> None: class InputOutputToMessages(Transform): + """ + Message transform class that converts a sample with "input" and "output" fields, + (or equivalent fields specified in column_map) to user and assistant messages, + respectively. This is useful for datasets that have two columns, one containing + the user prompt and the other containing the model response. + + Args: + train_on_input (bool): Whether the model is trained on the user prompt or not. + Default is False. + column_map (Optional[Dict[str, str]]): a mapping to change the expected "input" + and "output" column names to the actual column names in the dataset. Default is None, + keeping the default "input" and "output" column names. 
diff --git a/torchtune/data/_prompt_templates.py b/torchtune/data/_prompt_templates.py
index b07112885e..24375e9bdd 100644
--- a/torchtune/data/_prompt_templates.py
+++ b/torchtune/data/_prompt_templates.py
@@ -39,7 +39,7 @@ def __call__(

 class CustomPromptTemplate(PromptTemplate):
     """
-    Define a quick custom prompt template by passing in a dictionary mapping role to
+    Quickly define a custom prompt template by passing in a dictionary mapping role to
     the prepend and append tags. For example, to achieve the following prompt
     template::
@@ -48,7 +48,12 @@ class CustomPromptTemplate(PromptTemplate):
         Assistant: {content}\\n
         Tool: {content}\\n

-    You can define the template as follows::
+    You need to pass in a tuple for each role, where ``PREPEND_TAG`` is the string
+    added before the text content and ``APPEND_TAG`` is the string added after::
+
+        template = {role: (PREPEND_TAG, APPEND_TAG)}
+
+    Thus, the template would be defined as follows::

         template = {
             "system": ("System: ", "\\n"),
@@ -62,10 +67,12 @@ class CustomPromptTemplate(PromptTemplate):

     Note:
         Any tags prepended/appended to the assistant message will be included
-        in the loss calculation. Consider using the append tags for user messages for
-        tags that need to come before the assistant message but should not be included in
-        loss. For more custom masking and prompt templating, you can create your own
-        class based off the :class:`~torchtune.data.PromptTemplate` interface.
+        in the loss calculation. All other prepend/append tags for other roles
+        (system, user, ipython) are, in most cases, not included in loss. Consider using
+        the append tags for user messages for tags that need to come before the
+        assistant message but should not be included in loss. For more custom masking
+        and prompt templating, you can create your own class based off the
+        :class:`~torchtune.data.PromptTemplate` interface.

     Args:
         template (Dict[Role, Tuple[str, str]]): a dictionary mapping role to the
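To make the prepend/append tag semantics concrete, here is a rough sketch of instantiating the class with the example template from the docstring and applying it to a short conversation. Exact ``Message`` fields may differ; this is an illustration, not code from the patch::

    from torchtune.data import CustomPromptTemplate, Message

    template = CustomPromptTemplate(
        template={
            "system": ("System: ", "\n"),
            "user": ("User: ", "\n"),
            "assistant": ("Assistant: ", "\n"),
            "ipython": ("Tool: ", "\n"),
        }
    )

    messages = [
        Message(role="user", content="What is 2 + 2?"),
        Message(role="assistant", content="4"),
    ]

    # Each message's text content comes back wrapped in its role's tags,
    # e.g. "User: What is 2 + 2?\n" and "Assistant: 4\n". Note that the
    # assistant tags will be part of the loss, as the Note above warns.
    templated = template(messages)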
diff --git a/torchtune/datasets/_finetune.py b/torchtune/datasets/_finetune.py
index c38ac0a1e7..bed0cff890 100644
--- a/torchtune/datasets/_finetune.py
+++ b/torchtune/datasets/_finetune.py
@@ -16,10 +16,21 @@ class FinetuneDataset(Dataset):
     """
-    Dataset class for creating instruct, chat, tool, or multimodal datasets for fine-tuning.
-
-    All datasets can be considered "conversations" with the model, or AI assistant.
-    Thus, we can format all text content as messages in a conversation assigned to
+    Primary class for creating any dataset for supervised fine-tuning either from
+    Hugging Face Hub, local files, or remote files. This class supports instruct,
+    chat, tool, or multimodal data for fine-tuning. At a high level, this class
+    will load the data from source and apply the following pre-processing steps
+    when a sample is retrieved:
+
+    1. Dataset-specific transform. This is typically unique to each dataset and extracts
+       the necessary columns into torchtune's :class:`~torchtune.data.Message` format,
+       a standardized API for all model tokenizers.
+    2. If specified, apply a prompt template for the task you are fine-tuning for.
+    3. Model-specific transform or tokenization
+
+    All datasets are formatted into :class:`~torchtune.data.Message`s because for
+    fine-tuning, datasets can be considered as "conversations" with the model,
+    or AI assistant. Thus, we can standardize all text content as messages in a conversation assigned to
     a :class:`~torchtune.data.Role`:

     - system messages contain the system prompt
     - user messages contain the input prompt into the model
     - assistant messages are the response of the model and what you actually want
       to train for and compute loss directly against
     - ipython messages are the return from a tool call

     Chat datasets are multiple rounds of user-assistant messages. Instruct datasets
     are typically a single round involving a specific instruction and the model's response.
+    Tool datasets are a type of chat dataset that includes ipython messages. Multimodal
+    datasets are a type of chat dataset that incorporates media into the user messages.

     The :class:`~torchtune.data.Message` forms the core data unit that all tokenizer
     APIs expect. The key component of this class that ensures any dataset is transformed
-    into thie format is the ``message_transform``. This is a callable class that takes
-    in a sample dictionary - typically a single row from a Hugging Face dataset or a single
-    json - that processes the sample in any configurable way to output a list of messages::
+    into this format is the ``message_transform``. This is a callable class that takes
+    in a sample dictionary - typically a single row from the source dataset - that
+    processes the sample in any configurable way to output a list of messages::

         [
             Message(
@@ -48,11 +61,11 @@ class FinetuneDataset(Dataset):
     For any custom dataset, use the ``message_transform`` to contain all pre-processing to
     return the list of messages.

-    Any model specific pre-processing that needs to happen can be configured with the ``model_transform``
+    Any model-specific pre-processing that needs to happen can be configured with the ``model_transform``
     parameter. This is another callable class that contains any custom logic tied to the
-    model you are fine-tuning. For example, text + image multimodal datasets requires processing
-    the images in a way specific to the vision encoder being used by the model and is agnostic
-    to the specific dataset.
+    model you are fine-tuning and will carry over to inference. For example, text + image
+    multimodal datasets requires processing the images in a way specific to the vision
+    encoder being used by the model and is agnostic to the specific dataset.

     Tokenization is handled by the ``model_transform``. All :class:`~torchtune.modules.tokenizers.ModelTokenizer`s
     can be treated as a ``model_transform`` since it uses the model-specific tokenizer to
@@ -60,9 +73,6 @@ class FinetuneDataset(Dataset):
     used by the model for training. Text-only datasets will simply pass the :class:`~torchtune.modules.tokenizers.ModelTokenizer`
     into ``model_transform``.

-    The general pipeline is then: raw sample -> optional filter -> apply dataset-specific message transform -> apply
-    optional prompt template -> apply model-specific transform -> tokens used for training
-
     Args:
         source (str): path to dataset repository on Hugging Face. For local datasets,
             define source as the data file type (e.g. "json", "csv", "text") and pass

From 07cc0d11ce712e6632cd616fa42587019d10fc74 Mon Sep 17 00:00:00 2001
From: RdoubleA
Date: Mon, 5 Aug 2024 10:52:05 -0700
Subject: [PATCH 14/16] rename prompt template

---
 docs/source/api_ref_data.rst                      |  2 +-
 docs/source/api_ref_datasets.rst                  |  2 +-
 tests/test_utils.py                               |  4 ++--
 tests/torchtune/datasets/test_finetune_dataset.py |  6 +++---
 torchtune/data/__init__.py                        |  4 ++--
 torchtune/data/_prompt_templates.py               | 12 ++++++------
 torchtune/datasets/__init__.py                    |  4 ++--
 torchtune/datasets/_finetune.py                   |  4 ++--
 torchtune/datasets/_grammar.py                    |  4 ++--
 torchtune/datasets/_samsum.py                     |  6 +++---
 10 files changed, 24 insertions(+), 24 deletions(-)

diff --git a/docs/source/api_ref_data.rst b/docs/source/api_ref_data.rst
index a1b7584a86..adac6289ca 100644
--- a/docs/source/api_ref_data.rst
+++ b/docs/source/api_ref_data.rst
@@ -24,7 +24,7 @@ and models.
     SummarizeTemplate
     StackExchangedPairedTemplate
     PromptTemplate
-    CustomPromptTemplate
+    PromptTemplateInterface
     ChatFormat
     ChatMLFormat
diff --git a/docs/source/api_ref_datasets.rst b/docs/source/api_ref_datasets.rst
index 6aea567fae..e682aa6e02 100644
--- a/docs/source/api_ref_datasets.rst
+++ b/docs/source/api_ref_datasets.rst
@@ -56,4 +56,4 @@ Class representations for the above dataset builders.
     ConcatDataset
     PackedDataset
     PreferenceDataset
-    FinetuneDataset
+    SFTDataset
diff --git a/tests/test_utils.py b/tests/test_utils.py
index bd0dbf4988..cfbc6dece7 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -19,7 +19,7 @@
 import torch
 from torch import nn
-from torchtune.data import ChatFormat, CustomPromptTemplate, Message, truncate
+from torchtune.data import ChatFormat, Message, PromptTemplate, truncate
 from torchtune.modules.tokenizers import ModelTokenizer
 from torchtune.modules.transforms import Transform
@@ -158,7 +158,7 @@ def format(

 DummyPromptTemplate = partial(
-    CustomPromptTemplate,
+    PromptTemplate,
     template={
         "system": ("System:\n", "\n"),
         "user": ("User:\n", "\n"),
diff --git a/tests/torchtune/datasets/test_finetune_dataset.py b/tests/torchtune/datasets/test_finetune_dataset.py
index 2c8dccb2b2..9fd5b20245 100644
--- a/tests/torchtune/datasets/test_finetune_dataset.py
+++ b/tests/torchtune/datasets/test_finetune_dataset.py
@@ -11,7 +11,7 @@
 from tests.test_utils import DummyPromptTemplate, DummyTokenizer
 from torchtune.data import Message
 from torchtune.data._common import CROSS_ENTROPY_IGNORE_IDX
-from torchtune.datasets._finetune import FinetuneDataset
+from torchtune.datasets._finetune import SFTDataset
 from torchtune.modules.transforms import Transform
@@ -22,7 +22,7 @@ def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]:
         return {"messages": messages}

-class TestFinetuneDataset:
+class TestSFTDataset:
     @pytest.fixture
     def dialogue(self):
         return [
@@ -97,7 +97,7 @@ def test_get_item(self, mock_load_dataset, dialogue):
             + [CROSS_ENTROPY_IGNORE_IDX] * prompt_lengths[1]
             + [10, 1, 6, -1]
         ]
-        ds = FinetuneDataset(
+        ds = SFTDataset(
             source="iam/agoofy/goober",
             message_transform=ToDummyMessages(),
             model_transform=DummyTokenizer(),
diff --git a/torchtune/data/__init__.py b/torchtune/data/__init__.py
index fafa7d2ae1..51380d4c0d 100644
--- a/torchtune/data/__init__.py
+++ b/torchtune/data/__init__.py
@@ -19,9 +19,9 @@
 )
 from torchtune.data._messages import InputOutputToMessages, Message, Role
 from torchtune.data._prompt_templates import (
-    CustomPromptTemplate,
     GrammarErrorCorrectionTemplate,
     PromptTemplate,
+    PromptTemplateInterface,
     SummarizeTemplate,
 )
 from torchtune.data._utils import truncate, validate_messages
@@ -43,7 +43,7 @@
     "validate_messages",
     "StackExchangedPairedTemplate",
     "Role",
-    "CustomPromptTemplate",
+    "PromptTemplateInterface",
     "PromptTemplate",
     "InputOutputToMessages",
 ]
diff --git a/torchtune/data/_prompt_templates.py b/torchtune/data/_prompt_templates.py
index 24375e9bdd..47a969bf52 100644
--- a/torchtune/data/_prompt_templates.py
+++ b/torchtune/data/_prompt_templates.py
@@ -9,7 +9,7 @@
 from torchtune.data import Message, Role

-class PromptTemplate(Protocol):
+class PromptTemplateInterface(Protocol):
     """
     Interface for prompt templates. Each prompt template can include structured
     text for system, user, and assistant roles that are prepended or appended to
@@ -37,7 +37,7 @@ def __call__(
         pass

-class CustomPromptTemplate(PromptTemplate):
+class PromptTemplate(PromptTemplateInterface):
     """
     Quickly define a custom prompt template by passing in a dictionary mapping role to
     the prepend and append tags. For example, to achieve the following prompt
@@ -121,7 +121,7 @@ def __call__(self, messages: List[Message]) -> List[Message]:

 GrammarErrorCorrectionTemplate = partial(
-    CustomPromptTemplate,
+    PromptTemplate,
     template={
         "user": ("Correct this to standard English: ", "\n---\nCorrected: "),
     },
@@ -133,10 +133,10 @@ def __call__(self, messages: List[Message]) -> List[Message]:
     ---
     Corrected: {assistant_message}

-Please see :class:`~torchtune.data.CustomPromptTemplate` for full API arguments.
+Please see :class:`~torchtune.data.PromptTemplate` for full API arguments.
 """
 SummarizeTemplate = partial(
-    CustomPromptTemplate,
+    PromptTemplate,
     template={
         "user": ("Summarize this dialogue:\n", "\n---\nSummary:\n"),
     },
@@ -150,5 +150,5 @@ def __call__(self, messages: List[Message]) -> List[Message]:
     Summary: {assistant_message}

-Please see :class:`~torchtune.data.CustomPromptTemplate` for full API arguments.
+Please see :class:`~torchtune.data.PromptTemplate` for full API arguments.
 """
diff --git a/torchtune/datasets/__init__.py b/torchtune/datasets/__init__.py
index f1305bb719..3094689862 100644
--- a/torchtune/datasets/__init__.py
+++ b/torchtune/datasets/__init__.py
@@ -8,7 +8,7 @@
 from torchtune.datasets._chat import chat_dataset, ChatDataset
 from torchtune.datasets._cnn_dailymail import cnn_dailymail_articles_dataset
 from torchtune.datasets._concat import ConcatDataset
-from torchtune.datasets._finetune import FinetuneDataset
+from torchtune.datasets._finetune import SFTDataset
 from torchtune.datasets._grammar import grammar_dataset
 from torchtune.datasets._instruct import instruct_dataset, InstructDataset
 from torchtune.datasets._packed import PackedDataset
@@ -40,5 +40,5 @@
     "ConcatDataset",
     "wikitext_dataset",
     "PreferenceDataset",
-    "FinetuneDataset",
+    "SFTDataset",
 ]
diff --git a/torchtune/datasets/_finetune.py b/torchtune/datasets/_finetune.py
index bed0cff890..07aa7a0fae 100644
--- a/torchtune/datasets/_finetune.py
+++ b/torchtune/datasets/_finetune.py
@@ -14,7 +14,7 @@
 from torchtune.modules.transforms import Transform

-class FinetuneDataset(Dataset):
+class SFTDataset(Dataset):
     """
     Primary class for creating any dataset for supervised fine-tuning either from
     Hugging Face Hub, local files, or remote files. This class supports instruct,
@@ -84,7 +84,7 @@ class FinetuneDataset(Dataset):
             of messages are stored in the ``"messages"`` key.
         model_transform (Transform): callable that applies model-specific pre-processing to the sample after the list of
             messages is created from ``message_transform``. This includes tokenization and any modality-specific
-            transforms.
+            transforms. It is expected to return at minimum ``"tokens"`` and ``"mask"`` keys.
         prompt_template (Optional[PromptTemplate]): template used to format the messages based on their role. This is used
             to add structured text around the actual messages. The structured text is used in three scenarios:
diff --git a/torchtune/datasets/_grammar.py b/torchtune/datasets/_grammar.py
index faf8f2b1e8..f4b5dc38e0 100644
--- a/torchtune/datasets/_grammar.py
+++ b/torchtune/datasets/_grammar.py
@@ -13,7 +13,7 @@
     GrammarErrorCorrectionTemplate,
     PromptTemplate,
 )
-from torchtune.datasets._finetune import FinetuneDataset
+from torchtune.datasets._finetune import SFTDataset
 from torchtune.datasets._packed import PackedDataset
 from torchtune.modules.transforms import Transform
@@ -74,7 +74,7 @@ def grammar_dataset(
     message_transform = InputOutputToMessages(
         train_on_input=train_on_input, column_map=column_map
     )
-    ds = FinetuneDataset(
+    ds = SFTDataset(
         source=source,
         message_transform=message_transform,
         model_transform=model_transform,
diff --git a/torchtune/datasets/_samsum.py b/torchtune/datasets/_samsum.py
index 2d106b2dc6..187198b352 100644
--- a/torchtune/datasets/_samsum.py
+++ b/torchtune/datasets/_samsum.py
@@ -9,7 +9,7 @@
 from torchtune.data import InputOutputToMessages
 from torchtune.data._prompt_templates import PromptTemplate, SummarizeTemplate
-from torchtune.datasets._finetune import FinetuneDataset
+from torchtune.datasets._finetune import SFTDataset
 from torchtune.datasets._packed import PackedDataset
 from torchtune.modules.transforms import Transform
@@ -23,7 +23,7 @@ def samsum_dataset(
     train_on_input: bool = False,
     packed: bool = False,
     split: str = "train",
-) -> FinetuneDataset:
+) -> SFTDataset:
     """
     Support for summarization datasets and their variants from Hugging Face Datasets.
     An example is the `SAMsum dataset `_.
@@ -70,7 +70,7 @@ def samsum_dataset(
     message_transform = InputOutputToMessages(
         train_on_input=train_on_input, column_map=column_map
     )
-    ds = FinetuneDataset(
+    ds = SFTDataset(
         source=source,
         message_transform=message_transform,
         model_transform=model_transform,

From b2a7139930e3e6f453d32ed985caf296d8e154b3 Mon Sep 17 00:00:00 2001
From: RdoubleA
Date: Mon, 5 Aug 2024 15:45:13 -0700
Subject: [PATCH 15/16] fix docstrings

---
 docs/source/api_ref_modules.rst | 3 +++
 torchtune/datasets/_finetune.py | 29 ++++++++++++++++-------------
 torchtune/datasets/_grammar.py  | 13 ++++++-------
 torchtune/datasets/_samsum.py   | 12 ++++++------
 4 files changed, 31 insertions(+), 26 deletions(-)

diff --git a/docs/source/api_ref_modules.rst b/docs/source/api_ref_modules.rst
index 4cf8dd0140..d290c2a977 100644
--- a/docs/source/api_ref_modules.rst
+++ b/docs/source/api_ref_modules.rst
@@ -35,6 +35,8 @@ model specific tokenizers.
     tokenizers.SentencePieceBaseTokenizer
     tokenizers.TikTokenBaseTokenizer
+    tokenizers.ModelTokenizer
+    tokenizers.BaseTokenizer

 Tokenizer Utilities
 -------------------
@@ -93,6 +95,7 @@ Functions used for preprocessing images.
     :toctree: generated/
     :nosignatures:

+    transforms.Transform
     transforms.get_canvas_best_fit
     transforms.resize_with_pad
     transforms.tile_crop
diff --git a/torchtune/datasets/_finetune.py b/torchtune/datasets/_finetune.py
index 07aa7a0fae..9e0df23ef2 100644
--- a/torchtune/datasets/_finetune.py
+++ b/torchtune/datasets/_finetune.py
@@ -28,16 +28,17 @@ class SFTDataset(Dataset):
     2. If specified, apply a prompt template for the task you are fine-tuning for.
     3. Model-specific transform or tokenization
-
-    All datasets are formatted into :class:`~torchtune.data.Message`s because for
-    fine-tuning, datasets can be considered as "conversations" with the model,
+
+    All datasets are formatted into a list of :class:`~torchtune.data.Message`
+    because for fine-tuning, datasets can be considered as "conversations" with the model,
     or AI assistant. Thus, we can standardize all text content as messages in a conversation assigned to
-    a :class:`~torchtune.data.Role`:
+    a role:

-    - system messages contain the system prompt
-    - user messages contain the input prompt into the model
-    - assistant messages are the response of the model and what you actually want
+    - ``"system"`` messages contain the system prompt
+    - ``"user"`` messages contain the input prompt into the model
+    - ``"assistant"`` messages are the response of the model and what you actually want
       to train for and compute loss directly against
-    - ipython messages are the return from a tool call
+    - ``"ipython"`` messages are the return from a tool call

     Chat datasets are multiple rounds of user-assistant messages. Instruct datasets
     are typically a single round involving a specific instruction and the model's response.
@@ -67,7 +68,7 @@ class SFTDataset(Dataset):
     multimodal datasets requires processing the images in a way specific to the vision
     encoder being used by the model and is agnostic to the specific dataset.

-    Tokenization is handled by the ``model_transform``. All :class:`~torchtune.modules.tokenizers.ModelTokenizer`s
+    Tokenization is handled by the ``model_transform``. All :class:`~torchtune.modules.tokenizers.ModelTokenizer`
     can be treated as a ``model_transform`` since it uses the model-specific tokenizer to
     transform the list of messages outputted from the ``message_transform`` into tokens
     used by the model for training. Text-only datasets will simply pass the :class:`~torchtune.modules.tokenizers.ModelTokenizer`
     into ``model_transform``.
@@ -76,9 +77,9 @@ class SFTDataset(Dataset):
     Args:
         source (str): path to dataset repository on Hugging Face. For local datasets,
             define source as the data file type (e.g. "json", "csv", "text") and pass
-            in the filepath in ``data_files``. See Hugging Face's ``load_dataset``
-            (https://huggingface.co/docs/datasets/en/package_reference/loading_methods#datasets.load_dataset.path)
-            for more details.
+            in the filepath in ``data_files``. See `Hugging Face's
+            `_
+            ``load_dataset`` for more details.
         message_transform (Transform): callable that keys into the desired fields in the sample
             and converts text content to a list of :class:`~torchtune.data.Message`. It is expected that the final list
             of messages are stored in the ``"messages"`` key.
@@ -91,13 +92,15 @@ class SFTDataset(Dataset):
             - Task-specific templates to gear models for a particular task that it will expect after training
             - Model-specific templates that are required whenever the model is prompted, such as the [INST]
               tags in Llama2 and in Mistral
-            - Community standardized templates, such as :class:`~torchtune.data.ChatMLTemplate`
+            - Community standardized templates, such as :class:`~torchtune.data.ChatMLFormat`

             The extra text will still get tokenized as normal text, not as special tokens.
         filter_fn (Optional[Callable]): callable used to filter the dataset prior to any pre-processing. See
             the Hugging Face `docs `_ for more details.
-        **load_dataset_kwargs (Dict[str, Any]): additional keyword arguments to pass to ``load_dataset``.
+        **load_dataset_kwargs (Dict[str, Any]): additional keyword arguments to pass to ``load_dataset``. See Hugging
+            Face's `docs `_
+            for more details.
     """

     def __init__(
diff --git a/torchtune/datasets/_grammar.py b/torchtune/datasets/_grammar.py
index f4b5dc38e0..2505330aa3 100644
--- a/torchtune/datasets/_grammar.py
+++ b/torchtune/datasets/_grammar.py
@@ -5,9 +5,8 @@
 # LICENSE file in the root directory of this source tree.

-from typing import Dict, Optional
+from typing import Dict, Optional, Union

-from torch.utils.data import Dataset
 from torchtune.data import InputOutputToMessages
 from torchtune.data._prompt_templates import (
     GrammarErrorCorrectionTemplate,
     PromptTemplate,
 )
@@ -27,7 +26,7 @@ def grammar_dataset(
     train_on_input: bool = False,
     packed: bool = False,
     split: str = "train",
-) -> Dataset:
+) -> Union[SFTDataset, PackedDataset]:
     """
     Support for grammar correction datasets and their variants from Hugging Face Datasets.
     Here is an `example `_ of a grammar correction dataset.
@@ -48,9 +47,9 @@ def grammar_dataset(
         output by the dataset to tokens. This will always be a :class:`~torchtune.modules.tokenizers.ModelTokenizer`.
         source (str): path to dataset repository on Hugging Face. For local datasets,
             define source as the data file type (e.g. "json", "csv", "text") and pass
-            in the filepath in ``data_files``. See Hugging Face's ``load_dataset``
-            (https://huggingface.co/docs/datasets/en/package_reference/loading_methods#datasets.load_dataset.path)
-            for more details. Default is ``liweili/c4_200m``.
+            in the filepath in ``data_files``. See `Hugging Face's
+            `_
+            ``load_dataset`` for more details. Default is ``liweili/c4_200m``.
         column_map (Optional[Dict[str, str]]): a mapping from the expected columns in the prompt template
             to the new column names in the dataset. If None, assume these are identical.
         prompt_template (Optional[PromptTemplate]): optional template used to format the prompt. Default
@@ -61,7 +60,7 @@ def grammar_dataset(
             of a given split, e.g. ``split="train[:10%]"``. Default is "train".

     Returns:
-        InstructDataset: dataset configured with source data and template
+        Union[SFTDataset, PackedDataset]: dataset configured with source data and template

     Example:
diff --git a/torchtune/datasets/_samsum.py b/torchtune/datasets/_samsum.py
index 187198b352..ec0e819108 100644
--- a/torchtune/datasets/_samsum.py
+++ b/torchtune/datasets/_samsum.py
@@ -5,7 +5,7 @@
 # LICENSE file in the root directory of this source tree.

-from typing import Dict, Optional
+from typing import Dict, Optional, Union

 from torchtune.data import InputOutputToMessages
 from torchtune.data._prompt_templates import PromptTemplate, SummarizeTemplate
 from torchtune.datasets._finetune import SFTDataset
 from torchtune.datasets._packed import PackedDataset
 from torchtune.modules.transforms import Transform
@@ -23,7 +23,7 @@ def samsum_dataset(
     train_on_input: bool = False,
     packed: bool = False,
     split: str = "train",
-) -> SFTDataset:
+) -> Union[SFTDataset, PackedDataset]:
     """
     Support for summarization datasets and their variants from Hugging Face Datasets.
     An example is the `SAMsum dataset `_.
@@ -44,9 +44,9 @@ def samsum_dataset(
         output by the dataset to tokens. This will always be a :class:`~torchtune.modules.tokenizers.ModelTokenizer`.
         source (str): path to dataset repository on Hugging Face. For local datasets,
             define source as the data file type (e.g. "json", "csv", "text") and pass
-            in the filepath in ``data_files``. See Hugging Face's ``load_dataset``
-            (https://huggingface.co/docs/datasets/en/package_reference/loading_methods#datasets.load_dataset.path)
-            for more details. Default is ``Samsung/samsum``.
+            in the filepath in ``data_files``. See `Hugging Face's
+            `_
+            ``load_dataset`` for more details. Default is ``Samsung/samsum``.
         column_map (Optional[Dict[str, str]]): a mapping from the expected columns in the prompt template
             to the new column names in the dataset. If None, assume these are identical.
         prompt_template (Optional[PromptTemplate]): optional template used to format the prompt. Default
@@ -57,7 +57,7 @@ def samsum_dataset(
             of a given split, e.g. ``split="train[:10%]"``. Default is "train".

     Returns:
-        InstructDataset: dataset configured with source data and template
+        Union[SFTDataset, PackedDataset]: dataset configured with source data and template

     Example:

From 75db6228276f829732001b89360253deb943f40f Mon Sep 17 00:00:00 2001
From: RdoubleA
Date: Mon, 5 Aug 2024 16:02:08 -0700
Subject: [PATCH 16/16] fix doc build

---
 torchtune/datasets/_finetune.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torchtune/datasets/_finetune.py b/torchtune/datasets/_finetune.py
index 9e0df23ef2..3fdf5dcc41 100644
--- a/torchtune/datasets/_finetune.py
+++ b/torchtune/datasets/_finetune.py
@@ -99,7 +99,7 @@ class SFTDataset(Dataset):
             the Hugging Face `docs `_ for more details.
         **load_dataset_kwargs (Dict[str, Any]): additional keyword arguments to pass to ``load_dataset``. See Hugging
-            Face's `docs `_
+            Face's `API ref `_
             for more details.
     """
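Taken together, the series lands on a single ``SFTDataset`` entry point composed with small transforms: a message transform to build the conversation, an optional prompt template, and the model tokenizer acting as the model transform. A rough end-to-end sketch of the resulting API follows; the file path, column names, and template tags are placeholders, not values from this PR::

    from torchtune.data import InputOutputToMessages, PromptTemplate
    from torchtune.datasets import SFTDataset
    from torchtune.models.llama3 import llama3_tokenizer

    # The model tokenizer doubles as the model_transform: it turns the list
    # of messages produced by the message transform into "tokens" and "mask".
    tokenizer = llama3_tokenizer("/path/to/tokenizer.model")

    ds = SFTDataset(
        source="json",
        data_files="data/my_samples.json",  # placeholder local file
        split="train",
        message_transform=InputOutputToMessages(
            train_on_input=False,
            column_map={"input": "question", "output": "answer"},
        ),
        model_transform=tokenizer,
        prompt_template=PromptTemplate(
            template={"user": ("Question: ", "\nAnswer: ")}
        ),
    )

    sample = ds[0]  # expected to contain "tokens" and "mask" keys

The same wiring underlies the ``grammar_dataset`` and ``samsum_dataset`` builders updated above, which simply pre-fill the message transform and the task-specific prompt template.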