Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[1/n] Merged fine-tuning dataset: grammar + samsum #1234

Merged
merged 22 commits into from
Aug 5, 2024
Merged
Show file tree
Hide file tree
Changes from 19 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/source/api_ref_data.rst
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ and models.
GrammarErrorCorrectionTemplate
SummarizeTemplate
StackExchangedPairedTemplate
PromptTemplate
CustomPromptTemplate

ChatFormat
ChatMLFormat
Expand Down
1 change: 1 addition & 0 deletions docs/source/api_ref_datasets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -56,3 +56,4 @@ Class representations for the above dataset builders.
ConcatDataset
PackedDataset
PreferenceDataset
FinetuneDataset
4 changes: 2 additions & 2 deletions recipes/full_finetune_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,13 +391,13 @@ def _setup_data(

if isinstance(cfg_dataset, ListConfig):
datasets = [
config.instantiate(single_cfg_dataset, tokenizer=self._tokenizer)
config.instantiate(single_cfg_dataset, self._tokenizer)
for single_cfg_dataset in cfg_dataset
]
ds = ConcatDataset(datasets=datasets)
packed = False
else:
ds = config.instantiate(cfg_dataset, tokenizer=self._tokenizer)
ds = config.instantiate(cfg_dataset, self._tokenizer)
packed = cfg_dataset.get("packed", False)

sampler = DistributedSampler(
Expand Down
4 changes: 2 additions & 2 deletions recipes/full_finetune_single_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,13 +338,13 @@ def _setup_data(
"""
if isinstance(cfg_dataset, ListConfig):
datasets = [
config.instantiate(single_cfg_dataset, tokenizer=self._tokenizer)
config.instantiate(single_cfg_dataset, self._tokenizer)
for single_cfg_dataset in cfg_dataset
]
ds = ConcatDataset(datasets=datasets)
packed = False
else:
ds = config.instantiate(cfg_dataset, tokenizer=self._tokenizer)
ds = config.instantiate(cfg_dataset, self._tokenizer)
packed = cfg_dataset.get("packed", False)

sampler = DistributedSampler(
Expand Down
4 changes: 2 additions & 2 deletions recipes/lora_finetune_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,13 +496,13 @@ def _setup_data(

if isinstance(cfg_dataset, ListConfig):
datasets = [
config.instantiate(single_cfg_dataset, tokenizer=self._tokenizer)
config.instantiate(single_cfg_dataset, self._tokenizer)
for single_cfg_dataset in cfg_dataset
]
ds = ConcatDataset(datasets=datasets)
packed = False
else:
ds = config.instantiate(cfg_dataset, tokenizer=self._tokenizer)
ds = config.instantiate(cfg_dataset, self._tokenizer)
packed = cfg_dataset.get("packed", False)

sampler = DistributedSampler(
Expand Down
4 changes: 2 additions & 2 deletions recipes/lora_finetune_single_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,13 +428,13 @@ def _setup_data(
"""
if isinstance(cfg_dataset, ListConfig):
datasets = [
config.instantiate(single_cfg_dataset, tokenizer=self._tokenizer)
config.instantiate(single_cfg_dataset, self._tokenizer)
for single_cfg_dataset in cfg_dataset
]
ds = ConcatDataset(datasets=datasets)
packed = False
else:
ds = config.instantiate(cfg_dataset, tokenizer=self._tokenizer)
ds = config.instantiate(cfg_dataset, self._tokenizer)
packed = cfg_dataset.get("packed", False)

sampler = DistributedSampler(
Expand Down
4 changes: 2 additions & 2 deletions recipes/qat_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,13 +419,13 @@ def _setup_data(

if isinstance(cfg_dataset, ListConfig):
datasets = [
config.instantiate(single_cfg_dataset, tokenizer=self._tokenizer)
config.instantiate(single_cfg_dataset, self._tokenizer)
for single_cfg_dataset in cfg_dataset
]
ds = ConcatDataset(datasets=datasets)
packed = False
else:
ds = config.instantiate(cfg_dataset, tokenizer=self._tokenizer)
ds = config.instantiate(cfg_dataset, self._tokenizer)
packed = cfg_dataset.get("packed", False)

sampler = DistributedSampler(
Expand Down
2 changes: 1 addition & 1 deletion tests/recipes/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@


class DummyDataset(Dataset):
def __init__(self, **kwargs):
def __init__(self, *args, **kwargs):
self._data = torch.LongTensor(
[
[0, 2, 4, 2, 5, 6, 7, 8, 9, 1, 2, 4, 3, 3, 5, 6, 8, 2, 1, 1],
Expand Down
46 changes: 36 additions & 10 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,18 @@
import sys
import unittest
from contextlib import contextmanager
from functools import partial
from io import StringIO
from pathlib import Path
from typing import Any, Dict, Generator, List, Optional, TextIO, Tuple, Union
from typing import Any, Dict, Generator, List, Mapping, Optional, TextIO, Tuple, Union

import pytest

import torch
from torch import nn
from torchtune.data import ChatFormat, Message, truncate
from torchtune.data import ChatFormat, CustomPromptTemplate, Message, truncate
from torchtune.modules.tokenizers import ModelTokenizer
from torchtune.modules.transforms import Transform

skip_if_cuda_not_available = unittest.skipIf(
not torch.cuda.is_available(), "CUDA is not available"
Expand All @@ -39,7 +41,10 @@
}


class DummyTokenizer(ModelTokenizer):
class DummyTokenizer(ModelTokenizer, Transform):
def __init__(self, max_seq_len: Optional[int] = None):
self.max_seq_len = max_seq_len

def encode(self, text, add_bos=True, add_eos=True, **kwargs) -> List[int]:
words = text.split()
tokens = [len(word) for word in words]
Expand All @@ -50,7 +55,8 @@ def encode(self, text, add_bos=True, add_eos=True, **kwargs) -> List[int]:
return tokens

def tokenize_messages(
self, messages: List[Message], max_seq_len: Optional[int] = None
self,
messages: List[Message],
) -> Tuple[List[int], List[bool]]:
"""
A simplified version of Llama2Tokenizer's ``tokenize_messages`` for testing purposes.
Expand All @@ -69,15 +75,16 @@ def tokenize_messages(
mask.append(message.masked)

# Tokenize current message, append with masks
tokens = []
for item in message.content:
if item["type"] == "text":
tokens = self.encode(
tokens = tokens + self.encode(
item["content"],
add_bos=False,
add_eos=False,
)
elif item["type"] == "image":
tokens = [self.image_id]
tokens = tokens + [self.image_id]

tokenized_messages.extend(tokens)
mask.extend([message.masked] * len(tokens))
Expand All @@ -92,16 +99,25 @@ def tokenize_messages(
start_of_turn = False

# Break out early if we reach max_seq_len
if max_seq_len and len(tokenized_messages) >= max_seq_len:
if self.max_seq_len and len(tokenized_messages) >= self.max_seq_len:
break

# Finally, truncate if necessary
if max_seq_len:
tokenized_messages = truncate(tokenized_messages, max_seq_len, self.eos_id)
mask = truncate(mask, max_seq_len, message.masked)
if self.max_seq_len:
tokenized_messages = truncate(
tokenized_messages, self.max_seq_len, self.eos_id
)
mask = truncate(mask, self.max_seq_len, message.masked)

return tokenized_messages, mask

def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]:
messages = sample.pop("messages")
tokens, mask = self.tokenize_messages(messages)
sample["tokens"] = tokens
sample["mask"] = mask
return sample

@property
def eos_id(self):
return -1
Expand Down Expand Up @@ -141,6 +157,16 @@ def format(
return formatted_dialogue


DummyPromptTemplate = partial(
CustomPromptTemplate,
template={
"system": ("System:\n", "\n"),
"user": ("User:\n", "\n"),
"assistant": ("Assistant:\n", "\n"),
},
)


def get_assets_path():
return Path(__file__).parent / "assets"

Expand Down
2 changes: 1 addition & 1 deletion tests/torchtune/data/test_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from tests.test_utils import assert_dialogue_equal
from torchtune.data import get_openai_messages, get_sharegpt_messages
from torchtune.data._types import Message
from torchtune.data._messages import Message

# Taken from Open-Orca/SlimOrca-Dedup on Hugging Face:
# https://huggingface.co/datasets/Open-Orca/SlimOrca-Dedup
Expand Down
94 changes: 1 addition & 93 deletions tests/torchtune/data/test_instruct_templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from torchtune.data import (
AlpacaInstructTemplate,
GrammarErrorCorrectionTemplate,
SummarizeTemplate,
)

# Taken from Open-Orca/SlimOrca-Dedup on Hugging Face:
# https://huggingface.co/datasets/Open-Orca/SlimOrca-Dedup
CHAT_SAMPLE = {
"system": "You are an AI assistant. User will you give you a task. Your goal is to complete the task as faithfully as you can. While performing the task think step-by-step and justify your steps.", # noqa: B950
"user": "Please briefly summarize this news article:\n\nAOL.com Video - Father Lets 8-Year-Old Drive On Icy Road\n\nDescription:Would you let your 8-year-old drive your car? How about on an icy road? Well one father in Russia did just that, and recorded the entire thing. To her credit, the child seemed to be doing a great job. (0:44)\n\nTags: 8-year-old driver , caught on camera , child driver , pix11\n\nSummary:", # noqa: B950
"assistant": "A father in Russia allowed his 8-year-old child to drive his car on an icy road and recorded the event. The child appeared to be handling the situation well, showcasing their driving skills despite the challenging conditions.", # noqa: B950
}
from torchtune.data import AlpacaInstructTemplate


class TestAlpacaInstructTemplate:
Expand Down Expand Up @@ -72,83 +60,3 @@ def test_format_with_column_map(self):
actual = self.template.format(modified_sample, column_map=column_map)

assert actual == expected_prompt


class TestGrammarErrorCorrectionTemplate:
samples = [
{
"input": "Bitcoin is for $7,094 this morning, which CoinDesk says.",
"output": "Bitcoin goes for $7,094 this morning, according to CoinDesk.",
},
{
"input": "Much many brands and sellers still in the market.",
"output": "Many brands and sellers still in the market.",
},
]
expected_prompts = [
(
"Correct this to standard English: Bitcoin is for $7,094 this morning, which CoinDesk says.\n"
"---\n"
"Corrected: "
),
(
"Correct this to standard English: Much many brands and sellers still in the market.\n"
"---\n"
"Corrected: "
),
]

template = GrammarErrorCorrectionTemplate()

def test_format(self):
for sample, expected_prompt in zip(self.samples, self.expected_prompts):
column_map = {"sentence": "input"}
actual = self.template.format(sample, column_map=column_map)
assert actual == expected_prompt


class TestSummarizeTemplate:
samples = [
{
"id": "13818513",
"dialogue": "Amanda: I baked cookies. Do you want some? Jerry: Sure! Amanda: I'll bring you tomorrow :-)",
"summary": "Amanda baked cookies and will bring Jerry some tomorrow.",
},
{
"id": "13728867",
"dialogue": "Olivia: Who are you voting for in this election? Oliver: Liberals as always. Olivia: Me too!! Oliver: Great", # noqa: B950
"summary": "Olivia and Olivier are voting for liberals in this election.",
},
]
expected_prompts = [
(
"Summarize this dialogue:\n"
"Amanda: I baked cookies. Do you want some? Jerry: Sure! Amanda: I'll bring you tomorrow :-)\n"
"---\n"
"Summary:\n"
),
(
"Summarize this dialogue:\n"
"Olivia: Who are you voting for in this election? Oliver: Liberals as always. Olivia: Me too!! Oliver: Great\n"
"---\n"
"Summary:\n"
),
]

template = SummarizeTemplate()

def test_format(self):
for sample, expected_prompt in zip(self.samples, self.expected_prompts):
actual = self.template.format(sample)
assert actual == expected_prompt

def test_format_with_column_map(self):
column_map = {"dialogue": "not_a_dialogue"}
for sample, expected_prompt in zip(self.samples, self.expected_prompts):
modified_sample = sample.copy()
modified_sample["not_a_dialogue"] = modified_sample["dialogue"]
del modified_sample["dialogue"]

actual = self.template.format(modified_sample, column_map=column_map)

assert actual == expected_prompt
Loading
Loading