[4/7] Refactor preference dataset with transforms design (#1276)
RdoubleA authored Aug 13, 2024
1 parent 7a6a9b0 commit 6a7951f
Showing 17 changed files with 808 additions and 171 deletions.
3 changes: 2 additions & 1 deletion docs/source/api_ref_data.rst
@@ -22,7 +22,7 @@ and models.
     AlpacaInstructTemplate
     GrammarErrorCorrectionTemplate
     SummarizeTemplate
-    StackExchangedPairedTemplate
+    QuestionAnswerTemplate
     PromptTemplate
     PromptTemplateInterface
     ChatMLTemplate
@@ -66,6 +66,7 @@ Converts data from common schema and conversation JSON formats into a list of to
     InputOutputToMessages
     ShareGPTToMessages
     JSONToMessages
+    ChosenRejectedToMessages

 Helper functions
 ----------------
2 changes: 1 addition & 1 deletion docs/source/api_ref_datasets.rst
@@ -23,7 +23,7 @@ torchtune supports several widely used datasets to help quickly bootstrap your f
     grammar_dataset
     samsum_dataset
     slimorca_dataset
-    stack_exchanged_paired_dataset
+    stack_exchange_paired_dataset
     cnn_dailymail_articles_dataset
     wikitext_dataset

4 changes: 2 additions & 2 deletions recipes/configs/llama2/7B_lora_dpo.yaml
@@ -30,6 +30,7 @@ model:
 tokenizer:
   _component_: torchtune.models.llama2.llama2_tokenizer
   path: /tmp/Llama-2-7b-hf/tokenizer.model
+  max_seq_len: 1024

 checkpointer:
   _component_: torchtune.utils.FullModelHFCheckpointer
@@ -45,8 +46,7 @@ save_adapter_weights_only: False

 # Dataset and Sampler
 dataset:
-  _component_: torchtune.datasets.stack_exchanged_paired_dataset
-  max_seq_len: 1024
+  _component_: torchtune.datasets.stack_exchange_paired_dataset
 seed: null
 shuffle: True
 batch_size: 4
4 changes: 2 additions & 2 deletions recipes/configs/llama2/7B_lora_dpo_single_device.yaml
@@ -29,6 +29,7 @@ model:
 tokenizer:
   _component_: torchtune.models.llama2.llama2_tokenizer
   path: /tmp/Llama-2-7b-hf/tokenizer.model
+  max_seq_len: 1024

 checkpointer:
   _component_: torchtune.utils.FullModelHFCheckpointer
@@ -44,8 +45,7 @@ save_adapter_weights_only: False

 # Dataset and Sampler
 dataset:
-  _component_: torchtune.datasets.stack_exchanged_paired_dataset
-  max_seq_len: 1024
+  _component_: torchtune.datasets.stack_exchange_paired_dataset
 seed: null
 shuffle: True
 batch_size: 4
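Both DPO configs now point at the renamed stack_exchange_paired_dataset builder, and max_seq_len moves from the dataset entry onto the tokenizer. A minimal sketch of the equivalent Python setup, not part of this diff, assuming the builder takes the tokenizer keyword the same way the dataset builders exercised in the tests below do:

# Illustrative sketch only; mirrors the updated config above.
from torchtune.datasets import stack_exchange_paired_dataset
from torchtune.models.llama2 import llama2_tokenizer

# Truncation length is now configured on the tokenizer, not the dataset builder
tokenizer = llama2_tokenizer(
    path="/tmp/Llama-2-7b-hf/tokenizer.model",
    max_seq_len=1024,
)
ds = stack_exchange_paired_dataset(tokenizer=tokenizer)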
57 changes: 57 additions & 0 deletions tests/torchtune/data/test_messages.py
@@ -12,6 +12,7 @@
     MESSAGE_SAMPLE_TRAIN_ON_INPUT,
 )
 from torchtune.data._messages import (
+    ChosenRejectedToMessages,
     InputOutputToMessages,
     JSONToMessages,
     Message,
@@ -105,6 +106,62 @@ def test_call_train_on_input(self, sample):
         assert_dialogue_equal(actual["messages"], expected)


+class TestChosenRejectedToMessages:
+    @pytest.fixture
+    def sample(self):
+        return {
+            "maybe_chosen": [
+                {"role": "user", "content": "hello world"},
+                {"role": "assistant", "content": "hello world"},
+            ],
+            "maybe_rejected": [
+                {"role": "user", "content": "hello world"},
+                {"role": "assistant", "content": "bye world"},
+            ],
+        }
+
+    def test_call(self, sample):
+        transform = ChosenRejectedToMessages(
+            column_map={
+                "chosen": "maybe_chosen",
+                "rejected": "maybe_rejected",
+            },
+        )
+        actual = transform(sample)
+        expected_chosen = [
+            Message(role="user", content="hello world", masked=True, eot=False),
+            Message(role="assistant", content="hello world", masked=False, eot=True),
+        ]
+        assert_dialogue_equal(actual["chosen"], expected_chosen)
+
+        expected_rejected = [
+            Message(role="user", content="hello world", masked=True, eot=False),
+            Message(role="assistant", content="bye world", masked=False, eot=True),
+        ]
+        assert_dialogue_equal(actual["rejected"], expected_rejected)
+
+    def test_call_train_on_input(self, sample):
+        transform = ChosenRejectedToMessages(
+            column_map={
+                "chosen": "maybe_chosen",
+                "rejected": "maybe_rejected",
+            },
+            train_on_input=True,
+        )
+        actual = transform(sample)
+        expected_chosen = [
+            Message(role="user", content="hello world", masked=False, eot=False),
+            Message(role="assistant", content="hello world", masked=False, eot=True),
+        ]
+        assert_dialogue_equal(actual["chosen"], expected_chosen)
+
+        expected_rejected = [
+            Message(role="user", content="hello world", masked=False, eot=False),
+            Message(role="assistant", content="bye world", masked=False, eot=True),
+        ]
+        assert_dialogue_equal(actual["rejected"], expected_rejected)
+
+
 class TestShareGPTToMessages:
     samples = {
         "conversations": [
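The tests above drive ChosenRejectedToMessages through a column_map that renames the source columns. For quick orientation, a hedged sketch of the same transform on a sample that already uses plain "chosen"/"rejected" keys; that these are the default column names is an assumption inferred from the column_map in the tests, not stated in this diff:

# Illustrative sketch only, not part of this diff.
from torchtune.data._messages import ChosenRejectedToMessages

sample = {
    "chosen": [
        {"role": "user", "content": "What is 2+2?"},
        {"role": "assistant", "content": "2+2 is 4."},
    ],
    "rejected": [
        {"role": "user", "content": "What is 2+2?"},
        {"role": "assistant", "content": "2+2 is 12."},
    ],
}

# Assumed defaults: reads "chosen"/"rejected" when no column_map is given
transform = ChosenRejectedToMessages()
out = transform(sample)
# out["chosen"] and out["rejected"] are lists of Message objects; per the tests
# above, the user turn stays masked from the loss unless train_on_input=True.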
109 changes: 109 additions & 0 deletions tests/torchtune/datasets/test_hh_rlhf_helpful_dataset.py
@@ -0,0 +1,109 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from collections import Counter
from unittest.mock import patch

import pytest
from datasets import Dataset

from tests.test_utils import DummyTokenizer
from torchtune.data._common import CROSS_ENTROPY_IGNORE_IDX

from torchtune.datasets._hh_rlhf_helpful import hh_rlhf_helpful_dataset


class TestHHRLHFHelpfulDataset:
    @patch("torchtune.datasets._preference.load_dataset")
    @pytest.mark.parametrize("train_on_input", [True, False])
    def test_dataset_get_item(self, mock_load_dataset, train_on_input):
        # Truncated sample data from HH RLHF Helpful dataset
        mock_load_dataset.return_value = Dataset.from_list(
            [
                {
                    "chosen": [
                        {
                            "content": "helping my granny with her mobile phone issue",
                            "role": "user",
                        },
                        {
                            "content": "I see you are chatting with your grandmother "
                            "about an issue with her mobile phone. How can I help?",
                            "role": "assistant",
                        },
                        {"content": "her phone is not turning on", "role": "user"},
                        {
                            "content": "Is it on but it doesn’t power up or charge? "
                            "Or it’s off and does not turn on?",
                            "role": "assistant",
                        },
                    ],
                    "rejected": [
                        {
                            "content": "helping my granny with her mobile phone issue",
                            "role": "user",
                        },
                        {
                            "content": "I see you are chatting with your grandmother "
                            "about an issue with her mobile phone. How can I help?",
                            "role": "assistant",
                        },
                        {"content": "her phone is not turning on", "role": "user"},
                        {
                            "content": "Okay, are you concerned that her phone is broken, "
                            "or simply that it is not turning on?",
                            "role": "assistant",
                        },
                    ],
                }
            ]
        )
        ds = hh_rlhf_helpful_dataset(
            tokenizer=DummyTokenizer(),
            train_on_input=train_on_input,
        )
        # Generate the input and labels
        sample = ds[0]

        expected_chosen_counts = {
            3: 14,
            2: 11,
            4: 7,
            5: 7,
            7: 4,
            6: 4,
            0: 2,
            1: 2,
            -1: 2,
            8: 1,
            11: 1,
        }
        assert Counter(sample["chosen_input_ids"]) == expected_chosen_counts
        if train_on_input:
            assert Counter(sample["chosen_labels"]) == expected_chosen_counts
        else:
            # Check that the input is masked
            assert sample["chosen_labels"].count(CROSS_ENTROPY_IGNORE_IDX) == 16

        expected_rejected_counts = {
            3: 14,
            2: 8,
            5: 8,
            4: 6,
            6: 5,
            7: 4,
            0: 2,
            1: 2,
            -1: 2,
            8: 1,
            11: 1,
            9: 1,
        }
        assert Counter(sample["rejected_input_ids"]) == expected_rejected_counts
        if train_on_input:
            assert Counter(sample["rejected_labels"]) == expected_rejected_counts
        else:
            # Check that the input is masked
            assert sample["rejected_labels"].count(CROSS_ENTROPY_IGNORE_IDX) == 16
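The new test calls the hh_rlhf_helpful_dataset builder with a DummyTokenizer and a mocked load_dataset. A sketch of a non-test invocation, not part of this diff; the import path and keyword names follow the test above, while pairing it with the Llama2 tokenizer from the DPO configs is purely an illustrative assumption:

# Illustrative sketch only; tokenizer choice is an assumption.
from torchtune.datasets._hh_rlhf_helpful import hh_rlhf_helpful_dataset
from torchtune.models.llama2 import llama2_tokenizer

tokenizer = llama2_tokenizer(path="/tmp/Llama-2-7b-hf/tokenizer.model", max_seq_len=1024)
ds = hh_rlhf_helpful_dataset(tokenizer=tokenizer, train_on_input=False)
sample = ds[0]
# As asserted in the test, each sample holds paired tokenized sequences:
# sample["chosen_input_ids"], sample["chosen_labels"],
# sample["rejected_input_ids"], sample["rejected_labels"]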
117 changes: 117 additions & 0 deletions tests/torchtune/datasets/test_preference_dataset.py
@@ -0,0 +1,117 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Mapping
from unittest import mock

import pytest
from tests.test_utils import DummyPromptTemplate, DummyTokenizer
from torchtune.data import Message
from torchtune.data._common import CROSS_ENTROPY_IGNORE_IDX
from torchtune.datasets._preference import PreferenceDataset
from torchtune.modules.transforms import Transform


class ToDummyPreferenceMessages(Transform):
    def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]:
        chosen_messages = [
            Message.from_dict(sample["prompt"][0]),
            Message.from_dict(sample["chosen"][0]),
        ]

        rejected_messages = [
            Message.from_dict(sample["prompt"][0]),
            Message.from_dict(sample["rejected"][0]),
        ]

        return {"chosen": chosen_messages, "rejected": rejected_messages}


class TestPreferenceDataset:
    @pytest.fixture
    def dialogue(self):
        return [
            {
                "prompt": [
                    {
                        "role": "user",
                        "content": "What is 2+2?",
                        "masked": True,
                    },
                ],
                "chosen": [
                    {
                        "role": "assistant",
                        "content": "The answer is 4.",
                        "masked": False,
                    },
                ],
                "rejected": [
                    {
                        "role": "assistant",
                        "content": "The answer is 12.",
                        "masked": False,
                    },
                ],
            },
        ]

    @pytest.fixture
    def expected(self):
        return {
            "prompt": [
                0,
                5,
                4,
                2,
                4,
            ],
            "chosen": [
                10,
                3,
                6,
                2,
                2,
                -1,
            ],
            "rejected": [
                10,
                3,
                6,
                2,
                3,
                -1,
            ],
        }

    @mock.patch("torchtune.datasets._preference.load_dataset")
    def test_get_item(self, mock_load_dataset, dialogue, expected):
        mock_load_dataset.return_value = dialogue
        expected_chosen_tokens = expected["prompt"] + expected["chosen"]
        expected_chosen_labels = [CROSS_ENTROPY_IGNORE_IDX] * len(
            expected["prompt"]
        ) + expected["chosen"]
        expected_rejected_tokens = expected["prompt"] + expected["rejected"]
        expected_rejected_labels = [CROSS_ENTROPY_IGNORE_IDX] * len(
            expected["prompt"]
        ) + expected["rejected"]

        ds = PreferenceDataset(
            source="iam/agoofy/goober",
            message_transform=ToDummyPreferenceMessages(),
            tokenizer=DummyTokenizer(),
            prompt_template=DummyPromptTemplate(),
        )
        assert len(ds) == 1
        mock_load_dataset.assert_called_once()

        prompt, label = ds[0]["chosen_input_ids"], ds[0]["chosen_labels"]
        assert prompt == expected_chosen_tokens
        assert label == expected_chosen_labels

        prompt, label = ds[0]["rejected_input_ids"], ds[0]["rejected_labels"]
        assert prompt == expected_rejected_tokens
        assert label == expected_rejected_labels
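ToDummyPreferenceMessages in the test above shows the contract PreferenceDataset expects from its message_transform: a callable that maps one raw sample to a dict with "chosen" and "rejected" lists of Message objects. A hedged sketch of a custom transform following the same pattern; the "question"/"good"/"bad" column names are hypothetical, and the Message keyword arguments mirror the tests earlier in this commit:

# Illustrative sketch only; column names are hypothetical.
from typing import Any, Mapping

from torchtune.data import Message
from torchtune.modules.transforms import Transform


class MyPreferenceToMessages(Transform):
    def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]:
        # The shared prompt is masked from the loss, as in the tests above
        prompt = Message(role="user", content=sample["question"], masked=True)
        return {
            "chosen": [
                prompt,
                Message(role="assistant", content=sample["good"], masked=False),
            ],
            "rejected": [
                prompt,
                Message(role="assistant", content=sample["bad"], masked=False),
            ],
        }

# Plugged in the same way the test constructs its dataset, e.g.
# PreferenceDataset(source=..., message_transform=MyPreferenceToMessages(),
#                   tokenizer=..., prompt_template=...)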