Skip to content

Commit

Permalink
all unit tests, big update
Browse files Browse the repository at this point in the history
  • Loading branch information
RdoubleA committed Mar 25, 2024
1 parent 89f6c2b commit b33c3c9
Show file tree
Hide file tree
Showing 24 changed files with 497 additions and 330 deletions.
2 changes: 1 addition & 1 deletion docs/source/api_ref_datasets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,4 @@ torchtune.datasets
alpaca_dataset
grammar_dataset
samsum_dataset
SlimOrcaDataset
slimorca_dataset
4 changes: 2 additions & 2 deletions docs/source/examples/configs.rst
Original file line number Diff line number Diff line change
Expand Up @@ -243,5 +243,5 @@ name directly. Any nested fields in the components can be overridden with dot no
.. code-block:: bash
# Change to SlimOrcaDataset and set train_on_input to False
tune full_finetune --config my_config.yaml dataset=torchtune.datasets.SlimOrcaDataset dataset.train_on_input=False
# Change to slimorca_dataset and set train_on_input to False
tune full_finetune --config my_config.yaml dataset=torchtune.datasets.slimorca_dataset dataset.train_on_input=False
3 changes: 2 additions & 1 deletion tests/torchtune/config/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@
import pytest
from torchtune.config._utils import (
_get_component_from_path,
_get_template,
_merge_yaml_and_cli_args,
InstantiationError,
_get_template,
)
from torchtune.data import AlpacaInstructTemplate
from torchtune.utils.argparse import TuneArgumentParser

_CONFIG = {
Expand Down
106 changes: 106 additions & 0 deletions tests/torchtune/data/test_templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,23 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import pytest
from torchtune.data import (
AlpacaInstructTemplate,
ChatMLTemplate,
GrammarErrorCorrectionTemplate,
Llama2ChatTemplate,
MistralChatTemplate,
SummarizeTemplate,
)

# Taken from Open-Orca/SlimOrca-Dedup on HuggingFace
# Shared fixture for the chat template tests in this file. Keep every string
# exactly as written (including the "User will you give you a task" wording):
# the expected prompts in those tests repeat this text verbatim, so any edit
# here would make their equality assertions fail.
CHAT_SAMPLE = {
"system": "You are an AI assistant. User will you give you a task. Your goal is to complete the task as faithfully as you can. While performing the task think step-by-step and justify your steps.", # noqa: B950
"user": "Please briefly summarize this news article:\n\nAOL.com Video - Father Lets 8-Year-Old Drive On Icy Road\n\nDescription:Would you let your 8-year-old drive your car? How about on an icy road? Well one father in Russia did just that, and recorded the entire thing. To her credit, the child seemed to be doing a great job. (0:44)\n\nTags: 8-year-old driver , caught on camera , child driver , pix11\n\nSummary:", # noqa: B950
"assistant": "A father in Russia allowed his 8-year-old child to drive his car on an icy road and recorded the event. The child appeared to be handling the situation well, showcasing their driving skills despite the challenging conditions.", # noqa: B950
}


class TestAlpacaInstructTemplate:
samples = [
Expand Down Expand Up @@ -144,3 +155,98 @@ def test_format_with_column_map(self):
actual = self.template.format(modified_sample, column_map=column_map)

assert actual == expected_prompt


class TestLlama2ChatTemplate:
    """Checks the prompt Llama2ChatTemplate builds from CHAT_SAMPLE."""

    # Prompt the assertions below expect for CHAT_SAMPLE. It contains the
    # system and user text only; the assistant reply is not part of it.
    expected_prompt = (
        "[INST] <<SYS>>\nYou are an AI assistant. User will you give you a task. "
        "Your goal is to complete the task as faithfully as you can. While performing "
        "the task think step-by-step and justify your steps.\n<</SYS>>\n\nPlease "
        "briefly summarize this news article:\n\nAOL.com Video - Father Lets 8-Year-Old "
        "Drive On Icy Road\n\nDescription:Would you let your 8-year-old drive your car? "
        "How about on an icy road? Well one father in Russia did just that, and recorded "
        "the entire thing. To her credit, the child seemed to be doing a great job. "
        "(0:44)\n\nTags: 8-year-old driver , caught on camera , child driver , pix11\n\n"
        "Summary: [/INST] "
    )

    template = Llama2ChatTemplate()

    def test_format(self):
        """The sample's default keys yield the expected prompt."""
        assert self.template.format(CHAT_SAMPLE) == self.expected_prompt

    def test_format_with_column_map(self):
        """A system entry stored under another key is found via ``column_map``."""
        renamed_sample = dict(CHAT_SAMPLE)
        # Move the system text under a non-default key.
        renamed_sample["not_system"] = renamed_sample.pop("system")

        formatted = self.template.format(
            renamed_sample, column_map={"system": "not_system"}
        )

        assert formatted == self.expected_prompt


class TestMistralChatTemplate:
    """Checks MistralChatTemplate output and its rejection of system prompts."""

    # Prompt the assertions below expect; it carries only the user text.
    expected_prompt = (
        "[INST] Please briefly summarize this news article:\n\nAOL.com Video - Father Lets 8-Year-Old "
        "Drive On Icy Road\n\nDescription:Would you let your 8-year-old drive your car? "
        "How about on an icy road? Well one father in Russia did just that, and recorded "
        "the entire thing. To her credit, the child seemed to be doing a great job. "
        "(0:44)\n\nTags: 8-year-old driver , caught on camera , child driver , pix11\n\n"
        "Summary: [/INST] "
    )

    template = MistralChatTemplate()

    def test_format(self):
        """A sample without a system entry formats; one with it must raise."""
        sample_without_system = {
            key: value for key, value in CHAT_SAMPLE.items() if key != "system"
        }
        assert self.template.format(sample_without_system) == self.expected_prompt

        # CHAT_SAMPLE still carries its "system" entry, which must be rejected.
        with pytest.raises(
            ValueError, match="System prompts are not supported in MistralChatTemplate"
        ):
            self.template.format(CHAT_SAMPLE)

    def test_format_with_column_map(self):
        """A user entry stored under another key is found via ``column_map``."""
        renamed_sample = dict(CHAT_SAMPLE)
        # Drop the system entry, then move the user text under a non-default key.
        renamed_sample.pop("system")
        renamed_sample["not_user"] = renamed_sample.pop("user")

        formatted = self.template.format(
            renamed_sample, column_map={"user": "not_user"}
        )

        assert formatted == self.expected_prompt


class TestChatMLTemplate:
    """Checks the prompt ChatMLTemplate builds from CHAT_SAMPLE."""

    # Prompt the assertions below expect for CHAT_SAMPLE. It ends right after
    # the assistant header; the assistant reply itself is not part of it.
    expected_prompt = (
        "<|im_start|>system\nYou are an AI assistant. User will you give you a task. "
        "Your goal is to complete the task as faithfully as you can. While performing "
        "the task think step-by-step and justify your steps.<|im_end|>\n<|im_start|>user\nPlease "
        "briefly summarize this news article:\n\nAOL.com Video - Father Lets 8-Year-Old "
        "Drive On Icy Road\n\nDescription:Would you let your 8-year-old drive your car? "
        "How about on an icy road? Well one father in Russia did just that, and recorded "
        "the entire thing. To her credit, the child seemed to be doing a great job. "
        "(0:44)\n\nTags: 8-year-old driver , caught on camera , child driver , pix11\n\n"
        "Summary:<|im_end|>\n<|im_start|>assistant\n"
    )

    template = ChatMLTemplate()

    def test_format(self):
        """The sample's default keys yield the expected prompt."""
        assert self.template.format(CHAT_SAMPLE) == self.expected_prompt

    def test_format_with_column_map(self):
        """A system entry stored under another key is found via ``column_map``."""
        renamed_sample = dict(CHAT_SAMPLE)
        # Move the system text under a non-default key.
        renamed_sample["not_system"] = renamed_sample.pop("system")

        formatted = self.template.format(
            renamed_sample, column_map={"system": "not_system"}
        )

        assert formatted == self.expected_prompt
95 changes: 95 additions & 0 deletions tests/torchtune/data/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from torchtune.data import tokenize_prompt_and_response, truncate_if_necessary
from torchtune.data._common import CROSS_ENTROPY_IGNORE_IDX


class DummyTokenizer:
    """Minimal tokenizer stand-in for the tests in this module.

    Each whitespace-separated word is "tokenized" to its character count, so
    the expected token ids in the tests can be worked out by hand.
    """

    def encode(self, text, **kwargs):
        """Return one token per whitespace-separated word: that word's length.

        Args:
            text: String to tokenize. It is split on runs of whitespace, so
                spaces and newlines never produce tokens themselves.
            **kwargs: Accepted so callers may pass extra options; ignored.

        Returns:
            List of word lengths in order of appearance (empty for text that
            is empty or whitespace-only).
        """
        return [len(word) for word in text.split()]

    @property
    def eos_id(self):
        """End-of-sequence token id; -1 cannot collide with a word length."""
        return -1


def test_tokenize_prompt_and_response():
    """Check tokens and labels both without and with ``train_on_input``."""
    prompt = "Instruction:\nThis is an instruction.\n\nInput:\nThis is an input.\n\nResponse: "
    response = "I always know what I'm doing, do you?"

    # DummyTokenizer maps each whitespace-separated word to its length.
    prompt_ids = [12, 4, 2, 2, 12, 6, 4, 2, 2, 6, 9]
    response_ids = [1, 6, 4, 4, 3, 6, 2, 4]
    expected_tokens = prompt_ids + response_ids
    # Without train_on_input, the label is expected to hold the ignore index
    # at every prompt position, followed by the response tokens.
    expected_masked_labels = [CROSS_ENTROPY_IGNORE_IDX] * len(prompt_ids) + response_ids

    tokenizer = DummyTokenizer()

    tokens, labels = tokenize_prompt_and_response(tokenizer, prompt, response)
    assert tokens == expected_tokens
    assert labels == expected_masked_labels

    # With train_on_input=True, the label is expected to equal the tokens.
    tokens, labels = tokenize_prompt_and_response(
        tokenizer, prompt, response, train_on_input=True
    )
    assert tokens == expected_tokens
    assert labels == expected_tokens


def test_truncate_if_necessary():
    """Check that sequences within the limit pass through and longer ones are cut."""
    prompt_tokens = [1, 2, 3, 4, -1]
    label_tokens = [1, 2, 3, 4, -1]

    # (max_seq_len, expected result for both prompt and label)
    scenarios = [
        # Length equals the limit: expected back unchanged.
        (5, [1, 2, 3, 4, -1]),
        # One over the limit: expected to shrink to 4 entries that still end
        # with -1, the value DummyTokenizer.eos_id returns.
        (4, [1, 2, 3, -1]),
    ]

    for max_seq_len, expected in scenarios:
        truncated_prompt_tokens, truncated_label_tokens = truncate_if_necessary(
            tokenizer=DummyTokenizer(),
            prompt_tokens=prompt_tokens,
            label_tokens=label_tokens,
            max_seq_len=max_seq_len,
        )
        assert truncated_prompt_tokens == expected
        assert truncated_label_tokens == expected
4 changes: 2 additions & 2 deletions tests/torchtune/datasets/test_alpaca_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@
import pytest

from tests.test_utils import get_assets_path
from torchtune.data._common import CROSS_ENTROPY_IGNORE_IDX

from torchtune.datasets._alpaca import alpaca_dataset
from torchtune.datasets._common import CROSS_ENTROPY_IGNORE_IDX
from torchtune.datasets import alpaca_dataset
from torchtune.modules.tokenizer import Tokenizer


Expand Down
Loading

0 comments on commit b33c3c9

Please sign in to comment.