diff --git a/docs/source/api_ref_data.rst b/docs/source/api_ref_data.rst index 31b56ff481..6d137a5372 100644 --- a/docs/source/api_ref_data.rst +++ b/docs/source/api_ref_data.rst @@ -22,7 +22,7 @@ and models. AlpacaInstructTemplate GrammarErrorCorrectionTemplate SummarizeTemplate - StackExchangedPairedTemplate + QuestionAnswerTemplate PromptTemplate PromptTemplateInterface ChatMLTemplate @@ -66,6 +66,7 @@ Converts data from common schema and conversation JSON formats into a list of to InputOutputToMessages ShareGPTToMessages JSONToMessages + ChosenRejectedToMessages Helper functions ---------------- diff --git a/docs/source/api_ref_datasets.rst b/docs/source/api_ref_datasets.rst index e682aa6e02..4ffca6e06c 100644 --- a/docs/source/api_ref_datasets.rst +++ b/docs/source/api_ref_datasets.rst @@ -23,7 +23,7 @@ torchtune supports several widely used datasets to help quickly bootstrap your f grammar_dataset samsum_dataset slimorca_dataset - stack_exchanged_paired_dataset + stack_exchange_paired_dataset cnn_dailymail_articles_dataset wikitext_dataset diff --git a/recipes/configs/llama2/7B_lora_dpo.yaml b/recipes/configs/llama2/7B_lora_dpo.yaml index 8ae96f809b..4d8464f621 100644 --- a/recipes/configs/llama2/7B_lora_dpo.yaml +++ b/recipes/configs/llama2/7B_lora_dpo.yaml @@ -30,6 +30,7 @@ model: tokenizer: _component_: torchtune.models.llama2.llama2_tokenizer path: /tmp/Llama-2-7b-hf/tokenizer.model + max_seq_len: 1024 checkpointer: _component_: torchtune.utils.FullModelHFCheckpointer @@ -45,8 +46,7 @@ save_adapter_weights_only: False # Dataset and Sampler dataset: - _component_: torchtune.datasets.stack_exchanged_paired_dataset - max_seq_len: 1024 + _component_: torchtune.datasets.stack_exchange_paired_dataset seed: null shuffle: True batch_size: 4 diff --git a/recipes/configs/llama2/7B_lora_dpo_single_device.yaml b/recipes/configs/llama2/7B_lora_dpo_single_device.yaml index de14613e91..6c3cb45961 100644 --- a/recipes/configs/llama2/7B_lora_dpo_single_device.yaml +++ b/recipes/configs/llama2/7B_lora_dpo_single_device.yaml @@ -29,6 +29,7 @@ model: tokenizer: _component_: torchtune.models.llama2.llama2_tokenizer path: /tmp/Llama-2-7b-hf/tokenizer.model + max_seq_len: 1024 checkpointer: _component_: torchtune.utils.FullModelHFCheckpointer @@ -44,8 +45,7 @@ save_adapter_weights_only: False # Dataset and Sampler dataset: - _component_: torchtune.datasets.stack_exchanged_paired_dataset - max_seq_len: 1024 + _component_: torchtune.datasets.stack_exchange_paired_dataset seed: null shuffle: True batch_size: 4 diff --git a/tests/torchtune/data/test_messages.py b/tests/torchtune/data/test_messages.py index f1df60895a..7b48a39fc1 100644 --- a/tests/torchtune/data/test_messages.py +++ b/tests/torchtune/data/test_messages.py @@ -12,6 +12,7 @@ MESSAGE_SAMPLE_TRAIN_ON_INPUT, ) from torchtune.data._messages import ( + ChosenRejectedToMessages, InputOutputToMessages, JSONToMessages, Message, @@ -105,6 +106,62 @@ def test_call_train_on_input(self, sample): assert_dialogue_equal(actual["messages"], expected) +class TestChosenRejectedToMessages: + @pytest.fixture + def sample(self): + return { + "maybe_chosen": [ + {"role": "user", "content": "hello world"}, + {"role": "assistant", "content": "hello world"}, + ], + "maybe_rejected": [ + {"role": "user", "content": "hello world"}, + {"role": "assistant", "content": "bye world"}, + ], + } + + def test_call(self, sample): + transform = ChosenRejectedToMessages( + column_map={ + "chosen": "maybe_chosen", + "rejected": "maybe_rejected", + }, + ) + actual = 
transform(sample) + expected_chosen = [ + Message(role="user", content="hello world", masked=True, eot=False), + Message(role="assistant", content="hello world", masked=False, eot=True), + ] + assert_dialogue_equal(actual["chosen"], expected_chosen) + + expected_rejected = [ + Message(role="user", content="hello world", masked=True, eot=False), + Message(role="assistant", content="bye world", masked=False, eot=True), + ] + assert_dialogue_equal(actual["rejected"], expected_rejected) + + def test_call_train_on_input(self, sample): + transform = ChosenRejectedToMessages( + column_map={ + "chosen": "maybe_chosen", + "rejected": "maybe_rejected", + }, + train_on_input=True, + ) + actual = transform(sample) + expected_chosen = [ + Message(role="user", content="hello world", masked=False, eot=False), + Message(role="assistant", content="hello world", masked=False, eot=True), + ] + assert_dialogue_equal(actual["chosen"], expected_chosen) + + expected_rejected = [ + Message(role="user", content="hello world", masked=False, eot=False), + Message(role="assistant", content="bye world", masked=False, eot=True), + ] + assert_dialogue_equal(actual["rejected"], expected_rejected) + + class TestShareGPTToMessages: samples = { "conversations": [ diff --git a/tests/torchtune/datasets/test_hh_rlhf_helpful_dataset.py b/tests/torchtune/datasets/test_hh_rlhf_helpful_dataset.py new file mode 100644 index 0000000000..40a7d02c8c --- /dev/null +++ b/tests/torchtune/datasets/test_hh_rlhf_helpful_dataset.py @@ -0,0 +1,109 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +from collections import Counter +from unittest.mock import patch + +import pytest +from datasets import Dataset + +from tests.test_utils import DummyTokenizer +from torchtune.data._common import CROSS_ENTROPY_IGNORE_IDX + +from torchtune.datasets._hh_rlhf_helpful import hh_rlhf_helpful_dataset + + +class TestHHRLHFHelpfulDataset: + @patch("torchtune.datasets._preference.load_dataset") + @pytest.mark.parametrize("train_on_input", [True, False]) + def test_dataset_get_item(self, mock_load_dataset, train_on_input): + # Truncated sample data from HH RLHF Helpful dataset + mock_load_dataset.return_value = Dataset.from_list( + [ + { + "chosen": [ + { + "content": "helping my granny with her mobile phone issue", + "role": "user", + }, + { + "content": "I see you are chatting with your grandmother " + "about an issue with her mobile phone. How can I help?", + "role": "assistant", + }, + {"content": "her phone is not turning on", "role": "user"}, + { + "content": "Is it on but it doesn’t power up or charge? " + "Or it’s off and does not turn on?", + "role": "assistant", + }, + ], + "rejected": [ + { + "content": "helping my granny with her mobile phone issue", + "role": "user", + }, + { + "content": "I see you are chatting with your grandmother " + "about an issue with her mobile phone. 
How can I help?", + "role": "assistant", + }, + {"content": "her phone is not turning on", "role": "user"}, + { + "content": "Okay, are you concerned that her phone is broken, " + "or simply that it is not turning on?", + "role": "assistant", + }, + ], + } + ] + ) + ds = hh_rlhf_helpful_dataset( + tokenizer=DummyTokenizer(), + train_on_input=train_on_input, + ) + # Generate the input and labels + sample = ds[0] + + expected_chosen_counts = { + 3: 14, + 2: 11, + 4: 7, + 5: 7, + 7: 4, + 6: 4, + 0: 2, + 1: 2, + -1: 2, + 8: 1, + 11: 1, + } + assert Counter(sample["chosen_input_ids"]) == expected_chosen_counts + if train_on_input: + assert Counter(sample["chosen_labels"]) == expected_chosen_counts + else: + # Check that the input is masked + assert sample["chosen_labels"].count(CROSS_ENTROPY_IGNORE_IDX) == 16 + + expected_rejected_counts = { + 3: 14, + 2: 8, + 5: 8, + 4: 6, + 6: 5, + 7: 4, + 0: 2, + 1: 2, + -1: 2, + 8: 1, + 11: 1, + 9: 1, + } + assert Counter(sample["rejected_input_ids"]) == expected_rejected_counts + if train_on_input: + assert Counter(sample["rejected_labels"]) == expected_rejected_counts + else: + # Check that the input is masked + assert sample["rejected_labels"].count(CROSS_ENTROPY_IGNORE_IDX) == 16 diff --git a/tests/torchtune/datasets/test_preference_dataset.py b/tests/torchtune/datasets/test_preference_dataset.py new file mode 100644 index 0000000000..d4f32acc94 --- /dev/null +++ b/tests/torchtune/datasets/test_preference_dataset.py @@ -0,0 +1,117 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any, Mapping +from unittest import mock + +import pytest +from tests.test_utils import DummyPromptTemplate, DummyTokenizer +from torchtune.data import Message +from torchtune.data._common import CROSS_ENTROPY_IGNORE_IDX +from torchtune.datasets._preference import PreferenceDataset +from torchtune.modules.transforms import Transform + + +class ToDummyPreferenceMessages(Transform): + def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]: + chosen_messages = [ + Message.from_dict(sample["prompt"][0]), + Message.from_dict(sample["chosen"][0]), + ] + + rejected_messages = [ + Message.from_dict(sample["prompt"][0]), + Message.from_dict(sample["rejected"][0]), + ] + + return {"chosen": chosen_messages, "rejected": rejected_messages} + + +class TestPreferenceDataset: + @pytest.fixture + def dialogue(self): + return [ + { + "prompt": [ + { + "role": "user", + "content": "What is 2+2?", + "masked": True, + }, + ], + "chosen": [ + { + "role": "assistant", + "content": "The answer is 4.", + "masked": False, + }, + ], + "rejected": [ + { + "role": "assistant", + "content": "The answer is 12.", + "masked": False, + }, + ], + }, + ] + + @pytest.fixture + def expected(self): + return { + "prompt": [ + 0, + 5, + 4, + 2, + 4, + ], + "chosen": [ + 10, + 3, + 6, + 2, + 2, + -1, + ], + "rejected": [ + 10, + 3, + 6, + 2, + 3, + -1, + ], + } + + @mock.patch("torchtune.datasets._preference.load_dataset") + def test_get_item(self, mock_load_dataset, dialogue, expected): + mock_load_dataset.return_value = dialogue + expected_chosen_tokens = expected["prompt"] + expected["chosen"] + expected_chosen_labels = [CROSS_ENTROPY_IGNORE_IDX] * len( + expected["prompt"] + ) + expected["chosen"] + expected_rejected_tokens = expected["prompt"] + expected["rejected"] + expected_rejected_labels = 
[CROSS_ENTROPY_IGNORE_IDX] * len( + expected["prompt"] + ) + expected["rejected"] + + ds = PreferenceDataset( + source="iam/agoofy/goober", + message_transform=ToDummyPreferenceMessages(), + tokenizer=DummyTokenizer(), + prompt_template=DummyPromptTemplate(), + ) + assert len(ds) == 1 + mock_load_dataset.assert_called_once() + + prompt, label = ds[0]["chosen_input_ids"], ds[0]["chosen_labels"] + assert prompt == expected_chosen_tokens + assert label == expected_chosen_labels + + prompt, label = ds[0]["rejected_input_ids"], ds[0]["rejected_labels"] + assert prompt == expected_rejected_tokens + assert label == expected_rejected_labels diff --git a/tests/torchtune/datasets/test_stack_exchange_paired_dataset.py b/tests/torchtune/datasets/test_stack_exchange_paired_dataset.py new file mode 100644 index 0000000000..889a825979 --- /dev/null +++ b/tests/torchtune/datasets/test_stack_exchange_paired_dataset.py @@ -0,0 +1,154 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +from collections import Counter +from unittest.mock import patch + +import pytest +from datasets import Dataset + +from tests.test_utils import assert_dialogue_equal, DummyTokenizer +from torchtune.data._common import CROSS_ENTROPY_IGNORE_IDX +from torchtune.data._messages import Message + +from torchtune.datasets._stack_exchange_paired import ( + stack_exchange_paired_dataset, + StackExchangePairedToMessages, +) + + +class TestStackExchangePairedDataset: + @patch("torchtune.datasets._preference.load_dataset") + @pytest.mark.parametrize("train_on_input", [True, False]) + def test_dataset_get_item(self, mock_load_dataset, train_on_input): + # Truncated sample data from stack exchange paired dataset + mock_load_dataset.return_value = Dataset.from_list( + [ + { + "question": "I have a question about if a animation ends that it " + "will like `gotoAndStop()` to another frame ``` if (bird.hitTestObject(pipe1))" + " { bird.gotoAndStop(3); //frame 3 = animation } ``` after it ends it will need" + " to go the Game Over frame (frame 3) and I use the `Flash Timeline` not `.as` " + "thanks!", + "response_j": "Java does not provide a convenient way to list the 'files' " + "in a 'directory', when that directory is backed by a JAR file on the classpath" + " (see [How do I list the files inside a JAR file?](https://stackoverflow.com/" + "questions/1429172/how-do-i-list-the-files-inside-a-jar-file) for some work-arounds)", + "response_k": "If you are still looking for an actual answer here is [mine]" + "(https://pastebin.com/R0jMh4ui) (it is kinda hacky but its work). 
To use it " + "you simply have to call one of the 2 options below", + } + ] + ) + ds = stack_exchange_paired_dataset( + tokenizer=DummyTokenizer(), + train_on_input=train_on_input, + ) + # Generate the input and labels + sample = ds[0] + + expected_chosen_counts = { + 4: 20, + 2: 15, + 3: 15, + 1: 13, + 9: 6, + 5: 6, + 7: 6, + 6: 4, + 0: 1, + 8: 1, + 15: 1, + 27: 1, + 20: 1, + 10: 1, + 12: 1, + 93: 1, + 13: 1, + -1: 1, + } + assert Counter(sample["chosen_input_ids"]) == expected_chosen_counts + if train_on_input: + assert Counter(sample["chosen_labels"]) == expected_chosen_counts + else: + # Check that the input is masked + assert sample["chosen_labels"].count(CROSS_ENTROPY_IGNORE_IDX) == 54 + + expected_rejected_counts = { + 2: 17, + 3: 17, + 4: 13, + 1: 9, + 5: 9, + 7: 6, + 6: 6, + 9: 4, + 0: 1, + 8: 1, + 15: 1, + 27: 1, + 20: 1, + 37: 1, + -1: 1, + } + assert Counter(sample["rejected_input_ids"]) == expected_rejected_counts + if train_on_input: + assert Counter(sample["rejected_labels"]) == expected_rejected_counts + else: + # Check that the input is masked + assert sample["rejected_labels"].count(CROSS_ENTROPY_IGNORE_IDX) == 54 + + +class TestStackExchangePairedToMessages: + @pytest.fixture + def sample(self): + return { + "maybe_prompt": "hello world", + "maybe_chosen": "hello world", + "maybe_rejected": "bye world", + } + + def test_call(self, sample): + transform = StackExchangePairedToMessages( + column_map={ + "prompt": "maybe_prompt", + "chosen": "maybe_chosen", + "rejected": "maybe_rejected", + }, + ) + actual = transform(sample) + expected_chosen = [ + Message(role="user", content="hello world", masked=True, eot=False), + Message(role="assistant", content="hello world", masked=False, eot=True), + ] + assert_dialogue_equal(actual["chosen"], expected_chosen) + + expected_rejected = [ + Message(role="user", content="hello world", masked=True, eot=False), + Message(role="assistant", content="bye world", masked=False, eot=True), + ] + assert_dialogue_equal(actual["rejected"], expected_rejected) + + def test_call_train_on_input(self, sample): + transform = StackExchangePairedToMessages( + column_map={ + "prompt": "maybe_prompt", + "chosen": "maybe_chosen", + "rejected": "maybe_rejected", + }, + train_on_input=True, + ) + actual = transform(sample) + expected_chosen = [ + Message(role="user", content="hello world", masked=False, eot=False), + Message(role="assistant", content="hello world", masked=False, eot=True), + ] + assert_dialogue_equal(actual["chosen"], expected_chosen) + + expected_rejected = [ + Message(role="user", content="hello world", masked=False, eot=False), + Message(role="assistant", content="bye world", masked=False, eot=True), + ] + assert_dialogue_equal(actual["rejected"], expected_rejected) diff --git a/torchtune/data/__init__.py b/torchtune/data/__init__.py index 89dc77bb60..520399a30e 100644 --- a/torchtune/data/__init__.py +++ b/torchtune/data/__init__.py @@ -12,12 +12,9 @@ ) from torchtune.data._common import CROSS_ENTROPY_IGNORE_IDX from torchtune.data._converters import get_openai_messages, get_sharegpt_messages -from torchtune.data._instruct_templates import ( - AlpacaInstructTemplate, - InstructTemplate, - StackExchangedPairedTemplate, -) +from torchtune.data._instruct_templates import AlpacaInstructTemplate, InstructTemplate from torchtune.data._messages import ( + ChosenRejectedToMessages, InputOutputToMessages, JSONToMessages, Message, @@ -29,6 +26,7 @@ GrammarErrorCorrectionTemplate, PromptTemplate, PromptTemplateInterface, + QuestionAnswerTemplate, 
SummarizeTemplate, ) from torchtune.data._utils import truncate, validate_messages @@ -48,11 +46,12 @@ "truncate", "Message", "validate_messages", - "StackExchangedPairedTemplate", "Role", "PromptTemplateInterface", "PromptTemplate", "InputOutputToMessages", + "ChosenRejectedToMessages", + "QuestionAnswerTemplate", "ChatMLTemplate", "get_openai_messages", "get_sharegpt_messages", diff --git a/torchtune/data/_instruct_templates.py b/torchtune/data/_instruct_templates.py index 49f52c2ccc..b66ee72804 100644 --- a/torchtune/data/_instruct_templates.py +++ b/torchtune/data/_instruct_templates.py @@ -128,50 +128,3 @@ def format( instruction=sample[key_instruction] ) return prompt - - -class StackExchangedPairedTemplate(InstructTemplate): - """ - Prompt template for preference datasets similar to StackExchangedPaired. - - .. code-block:: text - - Question: - - Answer: - """ - - template = "Question: {question}\n\nAnswer: " - - @classmethod - def format( - cls, sample: Mapping[str, Any], column_map: Optional[Dict[str, str]] = None - ) -> str: - """ - Generate prompt from instruction and input. - - Args: - sample (Mapping[str, Any]): a single data sample with instruction - column_map (Optional[Dict[str, str]]): a mapping from the expected placeholder names - in the template to the column names in the sample. If None, assume these are identical. - - Examples: - >>> # Simple question - >>> StackExchangedPairedTemplate.format(sample={"question": "What is the capital of France?"}) - Question: What is the capital of France?\\n\\nAnswer: - - >>> # Question with column map where the 'question' key is actually named 'prompt' in the given sample - >>> StackExchangedPairedTemplate.format( - ... sample={"prompt": "What is the capital of France?"}, - ... column_map={"question": "prompt"} - ... ) - Question: What is the capital of France?\\n\\nAnswer: - - Returns: - The formatted prompt - """ - column_map = column_map or {} - key_prompt = column_map.get("prompt", "prompt") - prompt = cls.template.format(question=sample[key_prompt]) - - return prompt diff --git a/torchtune/data/_messages.py b/torchtune/data/_messages.py index 5a34fc3600..a768299c59 100644 --- a/torchtune/data/_messages.py +++ b/torchtune/data/_messages.py @@ -154,6 +154,65 @@ def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]: return {"messages": messages} +class ChosenRejectedToMessages(Transform): + """ + Transform for converting datasets with "chosen" and "rejected" columns containing + conversations to a list of chosen and rejected messages. For example:: + + | chosen | rejected | + |----------------------------------------|----------------------------------------| + | [{"role": "user", "content": Q1}, | [{"role": "user", "content": Q1}, | + | {"role": "assistant", "content": A1}] | {"role": "assistant", "content": A2}] | + + will be converted to: + + .. code-block:: python + + chosen = [ + Message(role="user", content="Q1"), + Message(role="assistant", content="A1"), + ] + rejected = [ + Message(role="user", content="Q1"), + Message(role="assistant", content="A2"), + ] + + Args: + train_on_input (bool): Whether the model is trained on the user prompt or not. + Default is False. + column_map (Optional[Dict[str, str]]): a mapping to change the expected + "chosen" and "rejected" column names to the actual column names in the dataset. + Default is None, keeping the default column names. 
+ """ + + def __init__( + self, train_on_input: bool = False, column_map: Optional[Dict[str, str]] = None + ): + self.train_on_input = train_on_input + self._column_map = column_map + + def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]: + column_map = self._column_map or {} + key_chosen = column_map.get("chosen", "chosen") + key_rejected = column_map.get("rejected", "rejected") + + chosen_messages = [] + for message in sample[key_chosen]: + message["masked"] = (message["role"] != "assistant") and ( + not self.train_on_input + ) + chosen_messages.append(Message.from_dict(message)) + + rejected_messages = [] + for message in sample[key_rejected]: + message["masked"] = (message["role"] != "assistant") and ( + not self.train_on_input + ) + rejected_messages.append(Message.from_dict(message)) + + return {"chosen": chosen_messages, "rejected": rejected_messages} + + class ShareGPTToMessages(Transform): """ Convert a chat sample adhering to the ShareGPT json structure to torchtune's :class:`~torchtune.data.Message` diff --git a/torchtune/data/_prompt_templates.py b/torchtune/data/_prompt_templates.py index 20a22119d4..4d88de1c9c 100644 --- a/torchtune/data/_prompt_templates.py +++ b/torchtune/data/_prompt_templates.py @@ -211,3 +211,18 @@ def __call__( Please see :class:`~torchtune.data.PromptTemplate` for full API arguments. """ +QuestionAnswerTemplate = partial( + PromptTemplate, + template={ + "user": ("Question: ", "\n\nAnswer: "), + }, +) +QuestionAnswerTemplate.__doc__ = """ +A prompt template for question answering tasks:: + + Question: {user_message} + + Answer: {assistant_message} + +Please see :class:`~torchtune.data.PromptTemplate` for full API arguments. +""" diff --git a/torchtune/datasets/__init__.py b/torchtune/datasets/__init__.py index 97749588db..4645d038ec 100644 --- a/torchtune/datasets/__init__.py +++ b/torchtune/datasets/__init__.py @@ -9,13 +9,14 @@ from torchtune.datasets._cnn_dailymail import cnn_dailymail_articles_dataset from torchtune.datasets._concat import ConcatDataset from torchtune.datasets._grammar import grammar_dataset +from torchtune.datasets._hh_rlhf_helpful import hh_rlhf_helpful_dataset from torchtune.datasets._instruct import instruct_dataset, InstructDataset from torchtune.datasets._packed import PackedDataset from torchtune.datasets._preference import PreferenceDataset from torchtune.datasets._samsum import samsum_dataset from torchtune.datasets._sft import SFTDataset from torchtune.datasets._slimorca import slimorca_dataset -from torchtune.datasets._stack_exchanged_paired import stack_exchanged_paired_dataset +from torchtune.datasets._stack_exchange_paired import stack_exchange_paired_dataset from torchtune.datasets._text_completion import ( text_completion_dataset, TextCompletionDataset, @@ -27,7 +28,7 @@ "alpaca_cleaned_dataset", "grammar_dataset", "samsum_dataset", - "stack_exchanged_paired_dataset", + "stack_exchange_paired_dataset", "InstructDataset", "slimorca_dataset", "ChatDataset", @@ -41,4 +42,5 @@ "wikitext_dataset", "PreferenceDataset", "SFTDataset", + "hh_rlhf_helpful_dataset", ] diff --git a/torchtune/datasets/_hh_rlhf_helpful.py b/torchtune/datasets/_hh_rlhf_helpful.py new file mode 100644 index 0000000000..c1baa91300 --- /dev/null +++ b/torchtune/datasets/_hh_rlhf_helpful.py @@ -0,0 +1,58 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+
+from typing import Dict, Optional
+
+from torchtune.data import ChosenRejectedToMessages, PromptTemplate
+from torchtune.datasets._preference import PreferenceDataset
+from torchtune.modules.tokenizers import ModelTokenizer
+
+
+def hh_rlhf_helpful_dataset(
+    tokenizer: ModelTokenizer,
+    *,
+    source: str = "RLHFlow/HH-RLHF-Helpful-standard",
+    column_map: Optional[Dict[str, str]] = None,
+    prompt_template: Optional[PromptTemplate] = None,
+    train_on_input: bool = False,
+    split: str = "train",
+) -> PreferenceDataset:
+    """
+    Constructs preference datasets similar to `Anthropic's helpful/harmless RLHF
+    data <https://huggingface.co/datasets/Anthropic/hh-rlhf>`_. This is
+    the processed helpful subset of the original dataset in a standardized format.
+
+    Args:
+        tokenizer (ModelTokenizer): Tokenizer used by the model that implements the ``tokenize_messages`` method.
+        source (str): path to dataset repository on Hugging Face. For local datasets,
+            define source as the data file type (e.g. "json", "csv", "text") and pass
+            in the filepath in ``data_files``. See Hugging Face's ``load_dataset``
+            (https://huggingface.co/docs/datasets/en/package_reference/loading_methods#datasets.load_dataset.path)
+            for more details. Default is ``RLHFlow/HH-RLHF-Helpful-standard``.
+        column_map (Optional[Dict[str, str]]): a mapping from the expected columns in the prompt template
+            to the new column names in the dataset. If None, assume these are identical.
+        prompt_template (Optional[PromptTemplate]): optional template used to format the prompt. Default
+            is None.
+        train_on_input (bool): Whether the model is trained on the prompt or not. Default is False.
+        split (str): ``split`` argument for ``datasets.load_dataset``. You can use this argument to load a subset
+            of a given split, e.g. ``split="train[:10%]"``. Default is "train".
+
+    Returns:
+        PreferenceDataset: The preference dataset built from source paired data.
+    """
+
+    message_transform = ChosenRejectedToMessages(
+        train_on_input=train_on_input, column_map=column_map
+    )
+
+    return PreferenceDataset(
+        source=source,
+        message_transform=message_transform,
+        tokenizer=tokenizer,
+        prompt_template=prompt_template,
+        split=split,
+    )
diff --git a/torchtune/datasets/_preference.py b/torchtune/datasets/_preference.py
index 476b04323a..2acd29e85a 100644
--- a/torchtune/datasets/_preference.py
+++ b/torchtune/datasets/_preference.py
@@ -4,70 +4,114 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import Any, Callable, Dict, List, Mapping, Optional
+from typing import Any, Dict, List, Mapping, Optional
 
 import numpy as np
 from datasets import load_dataset
 from torch.utils.data import Dataset
-from torchtune.data import CROSS_ENTROPY_IGNORE_IDX, InstructTemplate, Message
+from torchtune.data import CROSS_ENTROPY_IGNORE_IDX, PromptTemplate
 from torchtune.modules.tokenizers import ModelTokenizer
+from torchtune.modules.transforms import Transform
 
 
 class PreferenceDataset(Dataset):
     """
-    Class that supports any custom dataset with instruction-based prompts and a
-    configurable template.
-
-    The general flow from loading a sample to tokenized prompt is:
-    load sample -> apply transform -> format into template -> tokenize
+    Primary class for fine-tuning via preference modelling techniques (e.g. training
+    a preference model for RLHF, or directly optimizing a model through DPO) on a
+    preference dataset sourced from Hugging Face Hub, local files, or remote files. This
+    class requires the dataset to have "chosen" and "rejected" model responses. These are
+    typically either full conversations between user and assistant in separate columns::
+
+        |  chosen                                 |  rejected                               |
+        |-----------------------------------------|-----------------------------------------|
+        | [{"role": "user", "content": Q1},       | [{"role": "user", "content": Q1},       |
+        |  {"role": "assistant", "content": A1}]  |  {"role": "assistant", "content": A2}]  |
+
+    or a user prompt column with separate chosen and rejected assistant responses::
+
+        |  prompt  |  chosen  |  rejected  |
+        |----------|----------|------------|
+        |  Q1      |  A1      |  A2        |
+
+    At a high level, this class will load the data from source and apply the following pre-processing steps
+    when a sample is retrieved:
+
+    1. Dataset-specific transform. This is typically unique to each dataset and extracts
+       the necessary prompt and chosen/rejected columns into torchtune's :class:`~torchtune.data.Message`
+       format, a standardized API for all model tokenizers.
+    2. If specified, apply a prompt template for the task you are fine-tuning for.
+    3. Tokenization
+
+    All datasets are formatted into a list of :class:`~torchtune.data.Message`
+    because preference datasets can be considered as chosen and rejected "conversations"
+    with the model, or AI assistant. Thus, we can standardize all text content as messages
+    in a conversation assigned to a role:
+
+    - ``"user"`` messages contain the input prompt into the model
+    - ``"assistant"`` messages are the response of the model and what you actually want
+      to train for and compute loss directly against
+
+    The :class:`~torchtune.data.Message` forms the core data unit that all tokenizer
+    APIs expect. The key component of this class that ensures any dataset is transformed
+    into this format is the ``message_transform``. This is a callable class that takes
+    in a sample dictionary - typically a single row from the source dataset - and
+    processes the sample in any configurable way to output a list of messages::
+
+        [
+            Message(
+                role=<system|user|assistant>,
+                content=<message>,
+            ),
+            ...
+        ]
 
-    If the column/key names differ from the expected names in the :class:`~torchtune.data.InstructTemplate`,
-    then the ``column_map`` argument can be used to provide this mapping.
+    For any custom dataset, use the ``message_transform`` to contain all pre-processing to
+    return the list of messages.
 
     Args:
-        tokenizer (ModelTokenizer): Tokenizer used by the model that implements the ``tokenize_messages`` method.
         source (str): path to dataset repository on Hugging Face. For local datasets,
             define source as the data file type (e.g. "json", "csv", "text") and pass
-            in the filepath in ``data_files``. See Hugging Face's ``load_dataset``
-            (https://huggingface.co/docs/datasets/en/package_reference/loading_methods#datasets.load_dataset.path)
+            in the filepath in ``data_files``. See `Hugging Face's
+            <https://huggingface.co/docs/datasets/en/package_reference/loading_methods#datasets.load_dataset.path>`_
+            ``load_dataset`` for more details.
+        message_transform (Transform): callable that keys into the desired fields in the sample
+            and converts text content to a list of :class:`~torchtune.data.Message`. It is expected that the final list
+            of messages is stored in the ``"chosen"`` and ``"rejected"`` keys.
+        tokenizer (ModelTokenizer): Tokenizer used by the model that implements the ``tokenize_messages`` method.
+            Since PreferenceDataset only supports text data, it requires a
+            :class:`~torchtune.modules.tokenizers.ModelTokenizer` instead of the ``model_transform`` in
+            :class:`~torchtune.datasets.SFTDataset`.
+        prompt_template (Optional[PromptTemplate]): template used to format the messages based on their role. This is used
+            to add structured text around the actual messages and is called on both chosen and rejected messages.
+            The structured text is used in three scenarios:
+
+            - Task-specific templates to gear models for a particular task that they will expect after training
+            - Model-specific templates that are required whenever the model is prompted, such as the [INST]
+              tags in Llama2 and in Mistral
+            - Community standardized templates, such as :class:`~torchtune.data.ChatMLTemplate`
+
+            The extra text added by the template will still get tokenized as normal text, not as special tokens.
+        **load_dataset_kwargs (Dict[str, Any]): additional keyword arguments to pass to ``load_dataset``. See Hugging
+            Face's `API ref <https://huggingface.co/docs/datasets/en/package_reference/main_classes#datasets.load_dataset>`_
+            for more details.
-        template (InstructTemplate): template used to format the prompt. If the placeholder variable
-            names in the template do not match the column/key names in the dataset, use ``column_map`` to map them.
-        transform (Optional[Callable]): transform to apply to the sample before formatting to the template.
-            Default is None.
-        column_map (Optional[Dict[str, str]]): a mapping from the expected placeholder names in the template
-            to the column/key names in the sample. If None, assume these are identical.
-        max_seq_len (Optional[int]): Maximum number of tokens in the returned input and label token id lists.
-            Default is None, disabling truncation. We recommend setting this to the highest you can fit in memory
-            and is supported by the model. For example, llama2-7B supports up to 4096 for sequence length.
-        **load_dataset_kwargs (Dict[str, Any]): additional keyword arguments to pass to ``load_dataset``,
-            such as ``data_files`` or ``split``.
""" def __init__( self, - tokenizer: ModelTokenizer, + *, source: str, - template: InstructTemplate, - transform: Optional[Callable] = None, - column_map: Optional[Dict[str, str]] = None, - max_seq_len: Optional[int] = None, + message_transform: Transform, + tokenizer: ModelTokenizer, + prompt_template: Optional[PromptTemplate] = None, **load_dataset_kwargs: Dict[str, Any], ) -> None: self._tokenizer = tokenizer + self._prompt_template = prompt_template + self._message_transform = message_transform self._data = load_dataset(source, **load_dataset_kwargs) - self.template = template - self._transform = transform - self._column_map = column_map - self.max_seq_len = max_seq_len - self._data = self._data.filter( - lambda x: len(x[column_map["prompt"]]) + len(x[column_map["chosen"]]) - <= max_seq_len - and len(x[column_map["prompt"]]) + len(x[column_map["rejected"]]) - <= max_seq_len - ) def __len__(self): return len(self._data) @@ -77,47 +121,39 @@ def __getitem__(self, index: int) -> Dict[str, List[int]]: return self._prepare_sample(sample) def _prepare_sample(self, sample: Mapping[str, Any]) -> Dict[str, List[int]]: - transformed_sample = self._transform(sample) if self._transform else sample - prompt = self.template.format(transformed_sample, self._column_map) - - column_map = self._column_map or {} - key_chosen = column_map.get("chosen", "chosen") - key_rejected = column_map.get("rejected", "rejected") - - chosen_message = [ - Message(role="user", content=prompt, masked=True), - Message(role="assistant", content=transformed_sample[key_chosen]), - ] - - rejected_message = [ - Message(role="user", content=prompt, masked=True), - Message(role="assistant", content=transformed_sample[key_rejected]), - ] - - # TODO: Trunction differs from original DPO repo + transformed_sample = self._message_transform(sample) + if self._prompt_template is not None: + transformed_sample["chosen"] = self._prompt_template( + transformed_sample["chosen"] + ) + transformed_sample["rejected"] = self._prompt_template( + transformed_sample["rejected"] + ) + + # TODO: Truncation differs from original DPO repo # in DPO: first truncate prompts, then responses - chosen_input_ids, c_masks = self._tokenizer.tokenize_messages( - chosen_message, + chosen_input_ids, chosen_masks = self._tokenizer.tokenize_messages( + transformed_sample["chosen"], ) chosen_labels = list( - np.where(c_masks, CROSS_ENTROPY_IGNORE_IDX, chosen_input_ids) + np.where(chosen_masks, CROSS_ENTROPY_IGNORE_IDX, chosen_input_ids) ) - rejected_input_ids, r_masks = self._tokenizer.tokenize_messages( - rejected_message, + rejected_input_ids, rejected_masks = self._tokenizer.tokenize_messages( + transformed_sample["rejected"], ) rejected_labels = list( - np.where(r_masks, CROSS_ENTROPY_IGNORE_IDX, rejected_input_ids) + np.where(rejected_masks, CROSS_ENTROPY_IGNORE_IDX, rejected_input_ids) ) assert len(chosen_input_ids) == len(chosen_labels) assert len(rejected_input_ids) == len(rejected_labels) - batch = dict( + tokenized_dict = dict( chosen_input_ids=chosen_input_ids, chosen_labels=chosen_labels, rejected_input_ids=rejected_input_ids, rejected_labels=rejected_labels, ) - return batch + return tokenized_dict diff --git a/torchtune/datasets/_stack_exchange_paired.py b/torchtune/datasets/_stack_exchange_paired.py new file mode 100644 index 0000000000..4f3e6ed1fe --- /dev/null +++ b/torchtune/datasets/_stack_exchange_paired.py @@ -0,0 +1,123 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Any, Dict, Mapping, Optional
+
+from torchtune.data import Message, PromptTemplate, QuestionAnswerTemplate
+from torchtune.datasets._preference import PreferenceDataset
+from torchtune.modules.tokenizers import ModelTokenizer
+from torchtune.modules.transforms import Transform
+
+
+class StackExchangePairedToMessages(Transform):
+    """
+    Transform for converting datasets similar to the format in the `Stack Exchange Paired dataset
+    <https://huggingface.co/datasets/lvwerra/stack-exchange-paired>`_::
+
+        |  prompt  |  chosen  |  rejected  |
+        |----------|----------|------------|
+        |  Q1      |  A1      |  A2        |
+
+    into a list of chosen and rejected messages:
+
+    .. code-block:: python
+
+        chosen = [
+            Message(role="user", content="Q1"),
+            Message(role="assistant", content="A1"),
+        ]
+        rejected = [
+            Message(role="user", content="Q1"),
+            Message(role="assistant", content="A2"),
+        ]
+
+    Args:
+        train_on_input (bool): Whether the model is trained on the user prompt or not.
+            Default is False.
+        column_map (Optional[Dict[str, str]]): a mapping to change the expected "prompt",
+            "chosen", and "rejected" column names to the actual column names in the dataset.
+            Default is None, keeping the default column names.
+    """
+
+    def __init__(
+        self, train_on_input: bool = False, column_map: Optional[Dict[str, str]] = None
+    ):
+        self.train_on_input = train_on_input
+        self._column_map = column_map
+
+    def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]:
+        column_map = self._column_map or {}
+        key_prompt = column_map.get("prompt", "prompt")
+        key_chosen = column_map.get("chosen", "chosen")
+        key_rejected = column_map.get("rejected", "rejected")
+
+        chosen_messages = [
+            Message(
+                role="user", content=sample[key_prompt], masked=not self.train_on_input
+            ),
+            Message(role="assistant", content=sample[key_chosen]),
+        ]
+
+        rejected_messages = [
+            Message(
+                role="user", content=sample[key_prompt], masked=not self.train_on_input
+            ),
+            Message(role="assistant", content=sample[key_rejected]),
+        ]
+
+        return {"chosen": chosen_messages, "rejected": rejected_messages}
+
+
+def stack_exchange_paired_dataset(
+    tokenizer: ModelTokenizer,
+    *,
+    source: str = "lvwerra/stack-exchange-paired",
+    column_map: Optional[Dict[str, str]] = None,
+    prompt_template: Optional[PromptTemplate] = QuestionAnswerTemplate(),
+    train_on_input: bool = False,
+    split: str = "train",
+) -> PreferenceDataset:
+    """
+    Family of preference datasets similar to the `Stack Exchange Paired dataset
+    <https://huggingface.co/datasets/lvwerra/stack-exchange-paired>`_.
+
+    Args:
+        tokenizer (ModelTokenizer): Tokenizer used by the model that implements the ``tokenize_messages`` method.
+        source (str): path to dataset repository on Hugging Face. For local datasets,
+            define source as the data file type (e.g. "json", "csv", "text") and pass
+            in the filepath in ``data_files``. See `Hugging Face's
+            <https://huggingface.co/docs/datasets/en/package_reference/loading_methods#datasets.load_dataset.path>`_
+            ``load_dataset`` for more details. Default is ``lvwerra/stack-exchange-paired``.
+        column_map (Optional[Dict[str, str]]): a mapping from the expected columns in the prompt template
+            to the new column names in the dataset. If None, assume these are identical.
+        prompt_template (Optional[PromptTemplate]): optional template used to format the prompt. Default
+            is :class:`~torchtune.data.QuestionAnswerTemplate`.
+        train_on_input (bool): Whether the model is trained on the prompt or not. Default is False.
+        split (str): ``split`` argument for ``datasets.load_dataset``. You can use this argument to load a subset
+            of a given split, e.g. ``split="train[:10%]"``. Default is "train".
+
+    Returns:
+        PreferenceDataset: The preference dataset built from source paired data.
+    """
+
+    column_map = column_map or {
+        "prompt": "question",
+        "chosen": "response_j",
+        "rejected": "response_k",
+    }
+
+    message_transform = StackExchangePairedToMessages(
+        train_on_input=train_on_input, column_map=column_map
+    )
+
+    return PreferenceDataset(
+        source=source,
+        message_transform=message_transform,
+        tokenizer=tokenizer,
+        prompt_template=prompt_template,
+        split=split,
+        data_dir="data/rl",
+    )
diff --git a/torchtune/datasets/_stack_exchanged_paired.py b/torchtune/datasets/_stack_exchanged_paired.py
deleted file mode 100644
index a53e5755e9..0000000000
--- a/torchtune/datasets/_stack_exchanged_paired.py
+++ /dev/null
@@ -1,46 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-
-from torchtune.data import StackExchangedPairedTemplate
-from torchtune.datasets._preference import PreferenceDataset
-from torchtune.modules.tokenizers import ModelTokenizer
-
-
-def stack_exchanged_paired_dataset(
-    tokenizer: ModelTokenizer,
-    *,
-    source: str = "lvwerra/stack-exchange-paired",
-    max_seq_len: int = 1024,
-    split: str = "train",
-) -> PreferenceDataset:
-    """
-    Family of preference datasets similar to `StackExchangePaired data
-    <https://huggingface.co/datasets/lvwerra/stack-exchange-paired>`_.
-
-    Args:
-        tokenizer (ModelTokenizer): Tokenizer used by the model that implements the ``tokenize_messages`` method.
-        source (str): path string of dataset, anything supported by Hugging Face's `load_dataset`.
-        max_seq_len (int): Maximum number of tokens in the returned input and label token id lists.
-            Default is 1024.
-        split (str): ``split`` argument for ``datasets.load_dataset``. You can use this argument to load a subset
-            of a given split, e.g. ``split="train[:10%]"``. Default is "train".
-
-    Returns:
-        PreferenceDataset: The preference dataset built from source paired data.
-    """
-    return PreferenceDataset(
-        tokenizer=tokenizer,
-        source=source,
-        template=StackExchangedPairedTemplate(),
-        column_map={
-            "prompt": "question",
-            "chosen": "response_j",
-            "rejected": "response_k",
-        },
-        max_seq_len=max_seq_len,
-        split=split,
-        data_dir="data/rl",
-    )
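Reviewer note: a minimal usage sketch of the refactored preference-dataset API in this diff. It is illustrative only, not part of the change: the tokenizer path, the local JSON file, and the custom column names below are placeholder assumptions. Note that truncation now lives on the tokenizer (see the ``max_seq_len`` moves in the YAML configs above) rather than on the dataset builder.

# Minimal sketch (assumes this branch of torchtune is installed and the
# placeholder files below exist).
from torchtune.data import ChosenRejectedToMessages
from torchtune.datasets import PreferenceDataset, stack_exchange_paired_dataset
from torchtune.models.llama2 import llama2_tokenizer

# Truncation is configured on the tokenizer, mirroring the updated DPO configs.
tokenizer = llama2_tokenizer(
    path="/tmp/Llama-2-7b-hf/tokenizer.model",  # placeholder path
    max_seq_len=1024,
)

# Built-in builder: maps question/response_j/response_k internally and applies
# QuestionAnswerTemplate by default.
ds = stack_exchange_paired_dataset(tokenizer=tokenizer, split="train[:1%]")

# Custom preference data whose chosen/rejected conversation columns are stored
# under different names (hypothetical local JSON file and column names).
custom_ds = PreferenceDataset(
    source="json",
    message_transform=ChosenRejectedToMessages(
        column_map={"chosen": "my_chosen", "rejected": "my_rejected"},
    ),
    tokenizer=tokenizer,
    data_files="my_preference_data.json",
    split="train",
)

sample = custom_ds[0]
# Each sample carries tokenized chosen/rejected sequences plus labels with the
# prompt masked out (unless train_on_input=True in the message transform):
# chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels
print(sample.keys())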