
[4/7] Refactor preference dataset with transforms design #1276

Merged: 26 commits, merged on Aug 13, 2024 (the diff below shows changes from 17 of the 26 commits)

Commits
a3fe457  initial commit (RdoubleA, Jul 23, 2024)
9da786f  Merge branch 'main' into merged_dataset_1 (RdoubleA, Jul 26, 2024)
969909d  flesh out prompt templates (RdoubleA, Jul 26, 2024)
c422a01  Merge branch 'main' into merged_dataset_1 (RdoubleA, Jul 26, 2024)
ef79507  refactor samsum (RdoubleA, Jul 27, 2024)
5d2e7f5  Merge branch 'main' into merged_dataset_1 (RdoubleA, Jul 29, 2024)
7d54201  add all tests, update live docs (RdoubleA, Jul 30, 2024)
062ff38  Merge branch 'main' into merged_dataset_1 (RdoubleA, Jul 30, 2024)
df00fe1  fix tests (RdoubleA, Jul 30, 2024)
4157dd7  change naming (RdoubleA, Jul 31, 2024)
604fd7b  refactor preference (RdoubleA, Jul 31, 2024)
ba2e2ec  fix recipe tests (RdoubleA, Jul 31, 2024)
a531e48  remove content.strip() in tokenizer (RdoubleA, Jul 31, 2024)
cb2596c  Merge branch 'merged_dataset_1' into merged_preference_dataset (RdoubleA, Jul 31, 2024)
0478165  Merge branch 'main' into merged_preference_dataset (RdoubleA, Aug 6, 2024)
99e6e27  refactor preference, stack exchange (RdoubleA, Aug 6, 2024)
a41479b  fix merge (RdoubleA, Aug 6, 2024)
3efa81e  update dpo configs (RdoubleA, Aug 6, 2024)
fc5bb35  Merge branch 'main' into merged_preference_dataset (RdoubleA, Aug 6, 2024)
e077ff3  Merge remote-tracking branch 'upstream/main' into merged_preference_d… (RdoubleA, Aug 7, 2024)
3099984  fix merge (RdoubleA, Aug 7, 2024)
4266aa9  Merge branch 'main' into merged_preference_dataset (RdoubleA, Aug 8, 2024)
6d0e5ac  Merge branch 'main' into merged_preference_dataset (RdoubleA, Aug 10, 2024)
4bb637c  add hh rlhf dataset builder (RdoubleA, Aug 10, 2024)
c6bdd78  Merge branch 'main' into merged_preference_dataset (RdoubleA, Aug 13, 2024)
8fd001d  address comments, add hh_rlhf builder, add tests (RdoubleA, Aug 13, 2024)
12 changes: 12 additions & 0 deletions docs/source/api_ref_data.rst
@@ -53,6 +53,18 @@ Converts data from common JSON formats into a torchtune :class:`Message`.
get_sharegpt_messages
get_openai_messages

Message transforms
------------------

Converts data from common schema and conversation JSON formats into a list of torchtune :class:`Message`.

.. autosummary::
:toctree: generated/
:nosignatures:

InputOutputToMessages
ChosenRejectedToMessages

Helper funcs
------------

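As a quick orientation for this new docs section, here is a minimal usage sketch of ChosenRejectedToMessages based on the tests added in this PR; the maybe_* column names are the illustrative ones from those tests, not a required schema.

from torchtune.data import ChosenRejectedToMessages

# Illustrative raw sample; real datasets supply their own column names via column_map.
sample = {
    "maybe_prompt": "What is 2+2?",
    "maybe_chosen": "The answer is 4.",
    "maybe_rejected": "The answer is 12.",
}

transform = ChosenRejectedToMessages(
    column_map={
        "prompt": "maybe_prompt",
        "chosen": "maybe_chosen",
        "rejected": "maybe_rejected",
    },
)

# Returns {"chosen": [...], "rejected": [...]}, each a [user, assistant] Message list.
# The user turn is masked from the loss unless train_on_input=True is passed.
output = transform(sample)
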
2 changes: 1 addition & 1 deletion docs/source/api_ref_datasets.rst
@@ -23,7 +23,7 @@ torchtune supports several widely used datasets to help quickly bootstrap your f
grammar_dataset
samsum_dataset
slimorca_dataset
stack_exchanged_paired_dataset
stack_exchange_paired_dataset
cnn_dailymail_articles_dataset
wikitext_dataset

59 changes: 58 additions & 1 deletion tests/torchtune/data/test_messages.py
@@ -6,7 +6,11 @@

import pytest
from tests.test_utils import assert_dialogue_equal
from torchtune.data._messages import InputOutputToMessages, Message
from torchtune.data._messages import (
ChosenRejectedToMessages,
InputOutputToMessages,
Message,
)


class TestMessage:
@@ -93,3 +97,56 @@ def test_call_train_on_input(self, sample):
Message(role="assistant", content="hello world", masked=False, eot=True),
]
assert_dialogue_equal(actual["messages"], expected)


class TestChosenRejectedToMessages:
@pytest.fixture
def sample(self):
return {
"maybe_prompt": "hello world",
"maybe_chosen": "hello world",
"maybe_rejected": "bye world",
}

def test_call(self, sample):
transform = ChosenRejectedToMessages(
column_map={
"prompt": "maybe_prompt",
"chosen": "maybe_chosen",
"rejected": "maybe_rejected",
},
)
actual = transform(sample)
expected_chosen = [
Message(role="user", content="hello world", masked=True, eot=False),
Message(role="assistant", content="hello world", masked=False, eot=True),
]
assert_dialogue_equal(actual["chosen"], expected_chosen)

expected_rejected = [
Message(role="user", content="hello world", masked=True, eot=False),
Message(role="assistant", content="bye world", masked=False, eot=True),
]
assert_dialogue_equal(actual["rejected"], expected_rejected)

def test_call_train_on_input(self, sample):
transform = ChosenRejectedToMessages(
column_map={
"prompt": "maybe_prompt",
"chosen": "maybe_chosen",
"rejected": "maybe_rejected",
},
train_on_input=True,
)
actual = transform(sample)
expected_chosen = [
Message(role="user", content="hello world", masked=False, eot=False),
Message(role="assistant", content="hello world", masked=False, eot=True),
]
assert_dialogue_equal(actual["chosen"], expected_chosen)

expected_rejected = [
Message(role="user", content="hello world", masked=False, eot=False),
Message(role="assistant", content="bye world", masked=False, eot=True),
]
assert_dialogue_equal(actual["rejected"], expected_rejected)
117 changes: 117 additions & 0 deletions tests/torchtune/datasets/test_preference_dataset.py
@@ -0,0 +1,117 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Mapping
from unittest import mock

import pytest
from tests.test_utils import DummyPromptTemplate, DummyTokenizer
from torchtune.data import Message
from torchtune.data._common import CROSS_ENTROPY_IGNORE_IDX
from torchtune.datasets._preference import PreferenceDataset
from torchtune.modules.transforms import Transform


class ToDummyPreferenceMessages(Transform):
def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]:
chosen_messages = [
Message.from_dict(sample["prompt"][0]),
Message.from_dict(sample["chosen"][0]),
]

rejected_messages = [
Message.from_dict(sample["prompt"][0]),
Message.from_dict(sample["rejected"][0]),
]

return {"chosen": chosen_messages, "rejected": rejected_messages}


class TestSFTDataset:
@pytest.fixture
def dialogue(self):
return [
{
"prompt": [
{
"role": "user",
"content": "What is 2+2?",
"masked": True,
},
],
"chosen": [
{
"role": "assistant",
"content": "The answer is 4.",
"masked": False,
},
],
"rejected": [
{
"role": "assistant",
"content": "The answer is 12.",
"masked": False,
},
],
},
]

@pytest.fixture
def expected(self):
return {
"prompt": [
0,
5,
4,
2,
4,
],
"chosen": [
10,
3,
6,
2,
2,
-1,
],
"rejected": [
10,
3,
6,
2,
3,
-1,
],
}

@mock.patch("torchtune.datasets._preference.load_dataset")
def test_get_item(self, mock_load_dataset, dialogue, expected):
mock_load_dataset.return_value = dialogue
expected_chosen_tokens = expected["prompt"] + expected["chosen"]
expected_chosen_labels = [CROSS_ENTROPY_IGNORE_IDX] * len(
expected["prompt"]
) + expected["chosen"]
expected_rejected_tokens = expected["prompt"] + expected["rejected"]
expected_rejected_labels = [CROSS_ENTROPY_IGNORE_IDX] * len(
expected["prompt"]
) + expected["rejected"]

ds = PreferenceDataset(
source="iam/agoofy/goober",
message_transform=ToDummyPreferenceMessages(),
tokenizer=DummyTokenizer(),
prompt_template=DummyPromptTemplate(),
)
assert len(ds) == 1
mock_load_dataset.assert_called_once()

prompt, label = ds[0]["chosen_input_ids"], ds[0]["chosen_labels"]
assert prompt == expected_chosen_tokens
assert label == expected_chosen_labels

prompt, label = ds[0]["rejected_input_ids"], ds[0]["rejected_labels"]
assert prompt == expected_rejected_tokens
assert label == expected_rejected_labels
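
For readers new to the transforms design, a condensed sketch of the pipeline this test exercises; it reuses the dummy helpers above (the test patches load_dataset, so the source string is only a placeholder), whereas a real setup would pass a model tokenizer and a message transform matching its dataset schema.

from tests.test_utils import DummyPromptTemplate, DummyTokenizer
from torchtune.datasets._preference import PreferenceDataset

# ToDummyPreferenceMessages is the message transform defined at the top of this test file.
ds = PreferenceDataset(
    source="iam/agoofy/goober",  # placeholder source; the test patches load_dataset
    message_transform=ToDummyPreferenceMessages(),
    tokenizer=DummyTokenizer(),
    prompt_template=DummyPromptTemplate(),
)

# Each item holds tokenized chosen/rejected sequences plus labels where the prompt
# tokens are replaced with CROSS_ENTROPY_IGNORE_IDX so loss only covers the response.
row = ds[0]
chosen = (row["chosen_input_ids"], row["chosen_labels"])
rejected = (row["rejected_input_ids"], row["rejected_labels"])
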
97 changes: 97 additions & 0 deletions tests/torchtune/datasets/test_stack_exchange_paired_dataset.py
@@ -0,0 +1,97 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from collections import Counter
from unittest.mock import patch

import pytest
from datasets import Dataset

from tests.test_utils import DummyTokenizer
from torchtune.data._common import CROSS_ENTROPY_IGNORE_IDX

from torchtune.datasets import stack_exchange_paired_dataset


class TestStackExchangePairedDataset:
@patch("torchtune.datasets._preference.load_dataset")
@pytest.mark.parametrize("train_on_input", [True, False])
def test_dataset_get_item(self, mock_load_dataset, train_on_input):
# Truncated sample data from stack exchange paired dataset
mock_load_dataset.return_value = Dataset.from_list(
[
{
"question": "I have a question about if a animation ends that it "
"will like `gotoAndStop()` to another frame ``` if (bird.hitTestObject(pipe1))"
" { bird.gotoAndStop(3); //frame 3 = animation } ``` after it ends it will need"
" to go the Game Over frame (frame 3) and I use the `Flash Timeline` not `.as` "
"thanks!",
"response_j": "Java does not provide a convenient way to list the 'files' "
"in a 'directory', when that directory is backed by a JAR file on the classpath"
" (see [How do I list the files inside a JAR file?](https://stackoverflow.com/"
"questions/1429172/how-do-i-list-the-files-inside-a-jar-file) for some work-arounds)",
"response_k": "If you are still looking for an actual answer here is [mine]"
"(https://pastebin.com/R0jMh4ui) (it is kinda hacky but its work). To use it "
"you simply have to call one of the 2 options below",
}
]
)
ds = stack_exchange_paired_dataset(
tokenizer=DummyTokenizer(),
train_on_input=train_on_input,
)
# Generate the input and labels
sample = ds[0]

expected_chosen_counts = {
4: 20,
2: 15,
3: 15,
1: 13,
9: 6,
5: 6,
7: 6,
6: 4,
0: 1,
8: 1,
15: 1,
27: 1,
20: 1,
10: 1,
12: 1,
93: 1,
13: 1,
-1: 1,
}
assert Counter(sample["chosen_input_ids"]) == expected_chosen_counts
if train_on_input:
assert Counter(sample["chosen_labels"]) == expected_chosen_counts
else:
# Check that the input is masked
assert sample["chosen_labels"].count(CROSS_ENTROPY_IGNORE_IDX) == 54

expected_rejected_counts = {
2: 17,
3: 17,
4: 13,
1: 9,
5: 9,
7: 6,
6: 6,
9: 4,
0: 1,
8: 1,
15: 1,
27: 1,
20: 1,
37: 1,
-1: 1,
}
assert Counter(sample["rejected_input_ids"]) == expected_rejected_counts
if train_on_input:
assert Counter(sample["rejected_labels"]) == expected_rejected_counts
else:
# Check that the input is masked
assert sample["rejected_labels"].count(CROSS_ENTROPY_IGNORE_IDX) == 54
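
Stepping back from the token-count assertions, a short hedged sketch of how the builder is meant to be called; DummyTokenizer is the test stand-in, and a real run would pass a model tokenizer (and, without the mock used above, would pull the actual Hugging Face dataset).

from tests.test_utils import DummyTokenizer  # test stand-in; use a real model tokenizer in practice
from torchtune.data._common import CROSS_ENTROPY_IGNORE_IDX
from torchtune.datasets import stack_exchange_paired_dataset

ds = stack_exchange_paired_dataset(
    tokenizer=DummyTokenizer(),
    train_on_input=False,
)
sample = ds[0]

# With train_on_input=False the prompt tokens in the labels are ignored by the loss;
# with train_on_input=True the labels match the input ids exactly.
num_masked_prompt_tokens = sample["chosen_labels"].count(CROSS_ENTROPY_IGNORE_IDX)
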
15 changes: 9 additions & 6 deletions torchtune/data/__init__.py
@@ -12,16 +12,18 @@
)
from torchtune.data._common import CROSS_ENTROPY_IGNORE_IDX
from torchtune.data._converters import get_openai_messages, get_sharegpt_messages
from torchtune.data._instruct_templates import (
AlpacaInstructTemplate,
InstructTemplate,
StackExchangedPairedTemplate,
from torchtune.data._instruct_templates import AlpacaInstructTemplate, InstructTemplate
from torchtune.data._messages import (
ChosenRejectedToMessages,
InputOutputToMessages,
Message,
Role,
)
from torchtune.data._messages import InputOutputToMessages, Message, Role
from torchtune.data._prompt_templates import (
GrammarErrorCorrectionTemplate,
PromptTemplate,
PromptTemplateInterface,
QuestionAnswerTemplate,
SummarizeTemplate,
)
from torchtune.data._utils import truncate, validate_messages
@@ -41,9 +43,10 @@
"truncate",
"Message",
"validate_messages",
"StackExchangedPairedTemplate",
"Role",
"PromptTemplateInterface",
"PromptTemplate",
"InputOutputToMessages",
"ChosenRejectedToMessages",
"QuestionAnswerTemplate",
]
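
To make the net API change concrete, a small sketch of the import surface after this diff; the torchtune.datasets import is shown alongside it since the paired stack exchange flow now goes through the dataset builder rather than the removed template.

# New message transforms exported from torchtune.data after this PR:
from torchtune.data import (
    ChosenRejectedToMessages,
    InputOutputToMessages,
    Message,
    Role,
)

# StackExchangedPairedTemplate is no longer exported; the paired stack exchange data
# is built through the dataset builder instead (see the renamed docs entry above).
from torchtune.datasets import stack_exchange_paired_dataset
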