Commit

add some unit tests, more templates

RdoubleA committed Mar 25, 2024
1 parent c30d1c0 commit 89f6c2b
Showing 9 changed files with 285 additions and 67 deletions.
31 changes: 31 additions & 0 deletions tests/torchtune/config/test_utils.py
@@ -11,6 +11,7 @@
_get_component_from_path,
_merge_yaml_and_cli_args,
InstantiationError,
_get_template,
)
from torchtune.utils.argparse import TuneArgumentParser

@@ -107,3 +108,33 @@ def test_merge_yaml_and_cli_args(self, mock_load):
ValueError, match="Command-line overrides must be in the form of key=value"
):
_ = _merge_yaml_and_cli_args(yaml_args, cli_args)

def test_get_template(self):
# Test valid template class
template = _get_template("AlpacaInstructTemplate")
assert isinstance(template, AlpacaInstructTemplate)

# Test invalid template class
with pytest.raises(
ValueError,
match="Must be a PromptTemplate class or a string with placeholders.",
):
_ = _get_template("InvalidTemplate")

# Test valid template strings
s = [
"Instruction: {instruction}\nInput: {input}",
"Instruction: {instruction}",
"{a}",
]
for t in s:
assert _get_template(t) == t

# Test invalid template strings
s = ["hello", "{}", "a}{b"]
for t in s:
with pytest.raises(
ValueError,
match="Must be a PromptTemplate class or a string with placeholders.",
):
_ = _get_template(t)
96 changes: 96 additions & 0 deletions tests/torchtune/datasets/test_chat_dataset.py
@@ -0,0 +1,96 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from unittest import mock

import pytest
from torchtune.data import AlpacaInstructTemplate
from torchtune.datasets._common import CROSS_ENTROPY_IGNORE_IDX

from torchtune.datasets._instruct import _get_template, InstructDataset


class DummyTokenizer:
def encode(self, text, **kwargs):
words = text.split()
return [len(word) for word in words]


class DummyTemplate:
def __init__(self, template):
self.template = template

def format(self, sample, column_map):
        return self.template.format(**sample)


def dummy_transform(sample):
    # Appends a fixed 10-character word to the fields used in the prompt;
    # the "10" entries in the expected token sequences below come from it
    sample["instruction"] = sample["instruction"] + " asdfghjkl; "
    sample["input"] = sample["input"] + " asdfghjkl; "
    return sample

class TestChatDataset:
template = DummyTemplate(
"Instruction:\n{instruction}\n\nInput:\n{input}\n\nResponse: "
)
expected_tokenized_prompts = [
[12, 4, 2, 3, 2, 12, 10, 6, 4, 2, 3, 2, 6, 10, 9, 1, 5, 4, 4, 3, 6, 2, 4],
[12, 4, 2, 2, 12, 10, 6, 4, 2, 2, 6, 10, 9, 1, 6, 4, 4, 3, 6, 2, 4],
]

def get_samples(self):
return [
{
"instruction": "This is not an instruction.",
"input": "This is not an input.",
"output": "I never know what I'm doing, do you?",
},
{
"instruction": "This is an instruction.",
"input": "This is an input.",
"output": "I always know what I'm doing, do you?",
},
]

@mock.patch("torchtune.datasets._instruct.load_dataset")
def test_get_item_no_train_on_input(self, mock_load_dataset):
mock_load_dataset.return_value = self.get_samples()
prompt_lengths = (15, 13)
expected_labels = [
[CROSS_ENTROPY_IGNORE_IDX] * prompt_lengths[0] + [1, 5, 4, 4, 3, 6, 2, 4],
[CROSS_ENTROPY_IGNORE_IDX] * prompt_lengths[1] + [1, 6, 4, 4, 3, 6, 2, 4],
]

dataset = InstructDataset(
tokenizer=DummyTokenizer(),
source="iam/agoofy/goober",
template=self.template,
transform=dummy_transform,
train_on_input=False,
)
assert len(dataset) == 2
mock_load_dataset.assert_called_once()

for i in range(len(dataset)):
prompt, label = dataset[i]
assert prompt == self.expected_tokenized_prompts[i]
assert label == expected_labels[i]

@mock.patch("torchtune.datasets._instruct.load_dataset")
def test_get_item_train_on_input(self, mock_load_dataset):
mock_load_dataset.return_value = self.get_samples()
expected_labels = self.expected_tokenized_prompts

dataset = InstructDataset(
tokenizer=DummyTokenizer(),
source="iam/agoofy/goober",
template=self.template,
transform=dummy_transform,
train_on_input=True,
)
assert len(dataset) == 2
mock_load_dataset.assert_called_once()

for i in range(len(dataset)):
prompt, label = dataset[i]
assert prompt == self.expected_tokenized_prompts[i]
assert label == expected_labels[i]
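
Taken together, these two tests pin down the __getitem__ contract of InstructDataset. A rough sketch of the flow they imply follows; this is not the actual implementation, and the attribute names are assumptions:

from torchtune.datasets._utils import tokenize_prompt_and_response

def getitem_sketch(dataset, index):
    # Raw sample as returned by load_dataset
    sample = dataset._data[index]
    # User-supplied transform, e.g. dummy_transform above
    transformed = dataset._transform(sample)
    # Fill the prompt template from the transformed sample
    prompt = dataset.template.format(transformed, None)
    # The target completion is the "output" column
    response = transformed["output"]
    # Tokenize, masking the prompt unless train_on_input is set
    return tokenize_prompt_and_response(
        dataset._tokenizer, prompt, response, train_on_input=dataset.train_on_input
    )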
31 changes: 0 additions & 31 deletions tests/torchtune/datasets/test_instruct_dataset.py
@@ -100,34 +100,3 @@ def test_get_item_train_on_input(self, mock_load_dataset):
prompt, label = dataset[i]
assert prompt == self.expected_tokenized_prompts[i]
assert label == expected_labels[i]


def test_get_template():
# Test valid template class
template = _get_template("AlpacaInstructTemplate")
assert isinstance(template, AlpacaInstructTemplate)

# Test invalid template class
with pytest.raises(
ValueError,
match="Must be a PromptTemplate class or a string with placeholders.",
):
_ = _get_template("InvalidTemplate")

# Test valid template strings
s = [
"Instruction: {instruction}\nInput: {input}",
"Instruction: {instruction}",
"{a}",
]
for t in s:
assert _get_template(t) == t

# Test invalid template strings
s = ["hello", "{}", "a}{b"]
for t in s:
with pytest.raises(
ValueError,
match="Must be a PromptTemplate class or a string with placeholders.",
):
_ = _get_template(t)
39 changes: 39 additions & 0 deletions tests/torchtune/datasets/test_utils.py
@@ -0,0 +1,39 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from unittest import mock

import pytest
from torchtune.data import AlpacaInstructTemplate
from torchtune.datasets._common import CROSS_ENTROPY_IGNORE_IDX

from torchtune.datasets._instruct import _get_template
from torchtune.datasets._utils import tokenize_prompt_and_response


class DummyTokenizer:
def encode(self, text, **kwargs):
words = text.split()
return [len(word) for word in words]


def test_tokenize_prompt_and_response():
tokenizer = DummyTokenizer()
    prompt = "Instruction:\nThis is an instruction. asdfghjkl; \n\nInput:\nThis is an input. asdfghjkl; \n\nResponse: "
response = "I always know what I'm doing, do you?"
prompt_length = 13
expected_tokenized_prompt = [
12, 4, 2, 2, 12, 10, 6, 4, 2, 2, 6, 10, 9, 1, 6, 4, 4, 3, 6, 2, 4,
]
expected_tokenized_label = [CROSS_ENTROPY_IGNORE_IDX] * prompt_length + [1, 6, 4, 4, 3, 6, 2, 4]

tokenized_prompt, tokenized_label = tokenize_prompt_and_response(tokenizer, prompt, response)
assert tokenized_prompt == expected_tokenized_prompt
assert tokenized_label == expected_tokenized_label

tokenized_prompt, tokenized_label = tokenize_prompt_and_response(tokenizer, prompt, response, train_on_input=True)
assert tokenized_prompt == expected_tokenized_prompt
assert tokenized_label == expected_tokenized_prompt
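
The implementation of tokenize_prompt_and_response is not part of this diff. Based on the assertions above, a minimal sketch of the behavior they imply might look like this; the exact signature and join logic are assumptions:

from torchtune.datasets._common import CROSS_ENTROPY_IGNORE_IDX

def tokenize_prompt_and_response(tokenizer, prompt, response, train_on_input=False):
    # Tokenize the prompt alone to know how many positions to mask
    prompt_tokens = tokenizer.encode(prompt)
    # Tokenize the concatenated prompt + response as the model input
    tokens = tokenizer.encode(prompt + response)
    if train_on_input:
        # Learn on every position, including the prompt
        labels = list(tokens)
    else:
        # Mask prompt positions so the loss only covers the response
        labels = [CROSS_ENTROPY_IGNORE_IDX] * len(prompt_tokens) + tokens[len(prompt_tokens):]
    return tokens, labels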
33 changes: 33 additions & 0 deletions torchtune/config/_utils.py
@@ -148,3 +148,36 @@ def _merge_yaml_and_cli_args(yaml_args: Namespace, cli_args: List[str]) -> DictConfig:

# CLI takes precedence over yaml args
return OmegaConf.merge(yaml_conf, cli_conf)

def _get_template(template: str) -> PromptTemplate:
    """
    Get the prompt template class from the template string.

    The string should either be the name of a PromptTemplate class directly,
    or a raw string with one or more placeholders. If neither applies, an
    error is raised.

    Args:
        template (str): class name of the template, or a string with placeholders

    Returns:
        PromptTemplate: an instance of the template class, or the same string
            after it has been verified

    Raises:
        ValueError: if the template is neither a PromptTemplate class nor a
            proper template string
    """
path = "torchtune.data." + template
try:
template_class = _get_component_from_path(path)
return template_class()
except InstantiationError:
        # Verify that the string can be used as a template, i.e. it contains
        # at least one variable placeholder
pattern = r"\{.+?\}"
if not re.search(pattern, template):
raise ValueError(
f"Invalid template '{template}': "
+ "Must be a PromptTemplate class or a string with placeholders."
) from None
return template
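
For illustration, a short usage sketch of _get_template; the placeholder string below is made up:

from torchtune.config._utils import _get_template

# A class name resolves to an instantiated template from torchtune.data
template = _get_template("AlpacaInstructTemplate")

# A raw string with at least one {placeholder} passes through unchanged
template = _get_template("Question: {question}\nAnswer: ")

# Anything else raises a ValueError
template = _get_template("no placeholders here")  # ValueError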
84 changes: 83 additions & 1 deletion torchtune/data/_templates.py
@@ -164,7 +164,7 @@ class Llama2ChatTemplate(PromptTemplate):
You are a helpful, respectful and honest assistant.
<</SYS>>
I am going to Paris, what should I see? [/INST] "
I am going to Paris, what should I see? [/INST] Paris, the capital of France, is known for its stunning architecture..."
"""

B_INST, E_INST = "[INST]", "[/INST]"
@@ -196,3 +196,85 @@ def format(
)
else:
return self.template["no_system"].format(user=sample["user"])

class MistralChatTemplate(PromptTemplate):
"""
Prompt template that formats according to Mistral's instruct model:
https://docs.mistral.ai/models/
It is identical to `Llama2ChatTemplate`, except it does not support system
prompts.
Example:
"[INST] I am going to Paris, what should I see? [/INST] Paris, the capital of France, is known for its stunning architecture..."
"""

    B_INST, E_INST = "[INST]", "[/INST]"
    template = f"{B_INST} {{user}} {E_INST} "

def format(
self, sample: Sample, column_map: Optional[Dict[str, str]] = None
) -> str:
"""
Generate prompt from a user message
Args:
sample (Sample): a single data sample, expects only "user" in the sample.
column_map (Optional[Dict[str, str]]): a mapping from the expected
role names in the template to the actual role names in the sample.
If None, assume these are "user".
Returns:
The formatted prompt
Raises:
ValueError: if the sample contains a "system" key
"""
if "system" in sample:
raise ValueError("System prompts are not supported in MistralChatTemplate")
return self.template.format(user=sample["user"])
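
A quick usage sketch, assuming the new class is exported from torchtune.data like the existing templates:

from torchtune.data import MistralChatTemplate

sample = {"user": "I am going to Paris, what should I see?"}
prompt = MistralChatTemplate().format(sample)
# prompt == "[INST] I am going to Paris, what should I see? [/INST] "

# Passing a "system" key raises a ValueError
MistralChatTemplate().format({"system": "Be terse.", "user": "Hi"})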


class ChatMLTemplate(PromptTemplate):
"""
OpenAI's Chat Markup Language used by their chat models:
https://github.com/MicrosoftDocs/azure-docs/blob/main/articles/ai-services/openai/includes/chat-markup-language.md
    It is the default chat template used by many HuggingFace models.
Example:
<|im_start|>system
Provide some context and/or instructions to the model.<|im_end|>
<|im_start|>user
The user’s message goes here<|im_end|>
<|im_start|>assistant
The assistant’s response goes here<|im_end|>
"""

    IM_START, IM_END = "<|im_start|>", "<|im_end|>"
    template = {
        "system": f"{IM_START}system\n{{system}}{IM_END}\n{IM_START}user\n{{user}}{IM_END}\n{IM_START}assistant",
        "no_system": f"{IM_START}user\n{{user}}{IM_END}\n{IM_START}assistant",
    }

def format(
self, sample: Sample, column_map: Optional[Dict[str, str]] = None
) -> str:
"""
Generate prompt from a user message and optional system prompt.
Args:
sample (Sample): a single data sample, expects role keys "system" (optional)
and "user" in the sample.
column_map (Optional[Dict[str, str]]): a mapping from the expected
role names in the template to the actual role names in the sample.
If None, assume these are "system" and "user".
Returns:
The formatted prompt
"""
if "system" in sample:
return self.template["system"].format(
system=sample["system"], user=sample["user"]
)
else:
return self.template["no_system"].format(user=sample["user"])
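
And similarly for ChatMLTemplate, exercising both the system and no-system branches (again assuming it is exported from torchtune.data):

from torchtune.data import ChatMLTemplate

sample = {
    "system": "Provide some context and/or instructions to the model.",
    "user": "The user's message goes here",
}
prompt = ChatMLTemplate().format(sample)
# <|im_start|>system
# Provide some context and/or instructions to the model.<|im_end|>
# <|im_start|>user
# The user's message goes here<|im_end|>
# <|im_start|>assistant

# Without a "system" key, only the user turn and assistant header are emitted
prompt = ChatMLTemplate().format({"user": "Hello"})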
1 change: 1 addition & 0 deletions torchtune/datasets/_chat.py
@@ -15,6 +15,7 @@
tokenize_prompt_and_response,
truncate_if_necessary,
)
from torchtune.config._utils import _get_template
from torchtune.modules import Tokenizer


3 changes: 2 additions & 1 deletion torchtune/datasets/_instruct.py
@@ -11,7 +11,8 @@

from torchtune.data import PromptTemplate
from torchtune.datasets._types import Sample
from torchtune.datasets._utils import _get_template, tokenize_prompt_and_response
from torchtune.datasets._utils import tokenize_prompt_and_response
from torchtune.config._utils import _get_template
from torchtune.modules import Tokenizer

