
Feat: Add support of multiple datasets in config #889

Merged: 25 commits, merged May 3, 2024.

The diffs below are shown from the first 6 of the 25 commits, so some details (such as the MultiDataset name and the DictConfig check) reflect an intermediate state that later commits revise.
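For context before the commit log: the feature lets a torchtune config declare a list of datasets instead of a single one. A minimal sketch of the idea (the dataset component paths here are illustrative, not taken from this PR's configs):

from omegaconf import ListConfig, OmegaConf

# Illustrative config fragment: the `dataset` key holds a list of datasets.
cfg = OmegaConf.create(
    "dataset:\n"
    "  - _component_: torchtune.datasets.alpaca_dataset\n"
    "  - _component_: torchtune.datasets.grammar_dataset\n"
)

# The recipes branch on this type to choose between a single dataset and a
# concatenation of several (see the diffs below).
assert isinstance(cfg.dataset, ListConfig)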
Commits
c46dd6c
Support of list of datasets added to chat and instruct dataset constr…
EvilFreelancer Apr 27, 2024
83eff03
Tests added, docstrings updated
EvilFreelancer Apr 27, 2024
c7fcdab
All other tests fixed to a new format
EvilFreelancer Apr 27, 2024
da9df43
MultiDataset class added to utils
EvilFreelancer Apr 28, 2024
d6913c9
Multi-source logic removed from chat and instruct dataset builders
EvilFreelancer Apr 28, 2024
46ea2bc
MultiDataset logic added to some recipes
EvilFreelancer Apr 28, 2024
90a6294
isinstance(cfg_dataset, ListConfig) instead of check for DictConfig
EvilFreelancer Apr 28, 2024
abed501
MultiDataset refactoring
EvilFreelancer Apr 28, 2024
09078c4
All recipes switched to a new format of MultiConfig
EvilFreelancer Apr 28, 2024
071456d
Update torchtune/utils/multi_dataset.py
EvilFreelancer Apr 28, 2024
c4af9ea
len of MultiDataset will be calculated in constructor
EvilFreelancer Apr 28, 2024
c1307bf
Extra check removed from __getitem__ for MultiDataset
EvilFreelancer Apr 28, 2024
78d871d
Datasets cumulative indexes calculation moved to constructor
EvilFreelancer Apr 29, 2024
ccd99cf
Tests of MultiDataset utils class added
EvilFreelancer Apr 29, 2024
61c6d22
MultiDataset constructor type fixed
EvilFreelancer Apr 30, 2024
1ba70ac
noqa comment removed from MultiDataset
EvilFreelancer Apr 30, 2024
c0746d5
MultiDataset class moved to datasets namespace
EvilFreelancer Apr 30, 2024
913b352
Path to MultiDataset class fixed
EvilFreelancer Apr 30, 2024
c44badf
Merge branch 'pytorch:main' into feat-concatenate-datasets
EvilFreelancer Apr 30, 2024
26b2997
New dataset format added to tutorials
EvilFreelancer Apr 30, 2024
4faa4d4
MultiDataset renamed to ConcatDataset
EvilFreelancer May 3, 2024
0d11d88
A comprehensive docstring about ConcatDataset added
EvilFreelancer May 3, 2024
90731a3
Update docs/source/tutorials/datasets.rst
EvilFreelancer May 3, 2024
854eb05
ConcatDataset added to api_ref_dataset.rst
EvilFreelancer May 3, 2024
f7a3f95
Extra EOL added to datasets.rst
EvilFreelancer May 3, 2024
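Taken together, the commits above describe the shape of the final class. Below is a rough sketch of what ConcatDataset likely looks like after the renames and refactors (an approximation for orientation, not the merged code): cumulative index ranges and the total length are computed once in the constructor, and __getitem__ does a plain range lookup with no extra bounds check.

from typing import List, Tuple

from torch.utils.data import Dataset


class ConcatDataset(Dataset):
    """Presents several map-style datasets as one contiguous dataset."""

    def __init__(self, datasets: List[Dataset]):
        self._datasets = datasets
        # Precompute (start, stop, dataset_index) windows so that __getitem__
        # only needs a scan, and cache the total length.
        self._indexes: List[Tuple[int, int, int]] = []
        cumulative = 0
        for idx, dataset in enumerate(datasets):
            self._indexes.append((cumulative, cumulative + len(dataset), idx))
            cumulative += len(dataset)
        self._len = cumulative

    def __getitem__(self, index: int):
        for start, stop, dataset_idx in self._indexes:
            if start <= index < stop:
                return self._datasets[dataset_idx][index - start]

    def __len__(self) -> int:
        return self._len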
10 changes: 6 additions & 4 deletions recipes/full_finetune_distributed.py
@@ -357,10 +357,12 @@ def _setup_data(
         iterable datasets and streaming datasets are not supported.
         """
         world_size, rank = utils.get_world_size_and_rank()
-        ds = config.instantiate(
-            cfg_dataset,
-            tokenizer=self._tokenizer,
-        )
+
+        if isinstance(cfg_dataset.get(0), DictConfig):
Contributor commented:

You can also just directly check if it's a ListConfig. If it's a single dataset then this might fail.

Suggested change:

-        if isinstance(cfg_dataset.get(0), DictConfig):
+        if isinstance(cfg_dataset, ListConfig):

I also wonder if there's a better way to handle this so we don't have to repeat this if-else check across all recipes...

Contributor Author (EvilFreelancer) replied:

Good point, I've replaced it with a ListConfig check.

Contributor Author (EvilFreelancer, Apr 28, 2024) replied:

> I also wonder if there's a better way to handle this so we don't have to repeat this if-else check across all recipes...

Yeah, this logic could be moved (for example) into the config.instantiate method, but I guess that would break the single responsibility principle. So I suggest leaving it as is.
+            ds = utils.MultiDataset(datasets=cfg_dataset, tokenizer=self._tokenizer)
+        else:
+            ds = config.instantiate(cfg_dataset, tokenizer=self._tokenizer)
+
         sampler = DistributedSampler(
             ds,
             num_replicas=world_size,
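On the reviewer's question about not repeating this if-else across recipes, one possible factoring is a small shared helper. This is a sketch only: instantiate_dataset is a hypothetical name, not something this PR adds.

from omegaconf import ListConfig

from torchtune import config, utils


def instantiate_dataset(cfg_dataset, tokenizer):
    # Hypothetical helper centralizing the branch each recipe repeats:
    # a ListConfig means the config declared several datasets to concatenate.
    if isinstance(cfg_dataset, ListConfig):
        return utils.MultiDataset(datasets=cfg_dataset, tokenizer=tokenizer)
    return config.instantiate(cfg_dataset, tokenizer=tokenizer)

Each recipe's _setup_data would then collapse to a single call like ds = instantiate_dataset(cfg_dataset, self._tokenizer).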
9 changes: 5 additions & 4 deletions recipes/full_finetune_single_device.py
@@ -320,10 +320,11 @@ def _setup_data(
         DistributedSamplers with Map-style Datasets which fit into memory. Other samplers,
         iterable datasets and streaming datasets are not supported.
         """
-        ds = config.instantiate(
-            cfg_dataset,
-            tokenizer=self._tokenizer,
-        )
+        if isinstance(cfg_dataset.get(0), DictConfig):
+            ds = utils.MultiDataset(datasets=cfg_dataset, tokenizer=self._tokenizer)
+        else:
+            ds = config.instantiate(cfg_dataset, tokenizer=self._tokenizer)
+
         sampler = DistributedSampler(
             ds,
             num_replicas=1,
7 changes: 6 additions & 1 deletion recipes/lora_dpo_distributed.py
@@ -412,7 +412,12 @@ def _setup_data(
         iterable datasets and streaming datasets are not supported.
         """
         world_size, rank = utils.get_world_size_and_rank()
-        ds = config.instantiate(cfg_dataset, tokenizer=self._tokenizer)
+
+        if isinstance(cfg_dataset.get(0), DictConfig):
+            ds = utils.MultiDataset(datasets=cfg_dataset, tokenizer=self._tokenizer)
+        else:
+            ds = config.instantiate(cfg_dataset, tokenizer=self._tokenizer)
+
         sampler = DistributedSampler(
             ds, num_replicas=world_size, rank=rank, shuffle=shuffle, seed=0
         )
9 changes: 5 additions & 4 deletions recipes/lora_dpo_single_device.py
@@ -297,10 +297,11 @@ def _setup_data(
         Map-style Datasets which fit into memory and an option for random shuffling.
         Samplers, iterable datasets, and streaming datasets are not supported.
         """
-        ds = config.instantiate(
-            cfg_dataset,
-            tokenizer=self._tokenizer,
-        )
+        if isinstance(cfg_dataset.get(0), DictConfig):
+            ds = utils.MultiDataset(datasets=cfg_dataset, tokenizer=self._tokenizer)
+        else:
+            ds = config.instantiate(cfg_dataset, tokenizer=self._tokenizer)
+
         sampler = DistributedSampler(
             ds,
             num_replicas=1,
7 changes: 6 additions & 1 deletion recipes/lora_finetune_distributed.py
@@ -409,7 +409,12 @@ def _setup_data(
         iterable datasets and streaming datasets are not supported.
         """
         world_size, rank = utils.get_world_size_and_rank()
-        ds = config.instantiate(cfg_dataset, tokenizer=self._tokenizer)
+
+        if isinstance(cfg_dataset.get(0), DictConfig):
+            ds = utils.MultiDataset(datasets=cfg_dataset, tokenizer=self._tokenizer)
+        else:
+            ds = config.instantiate(cfg_dataset, tokenizer=self._tokenizer)
+
         sampler = DistributedSampler(
             ds, num_replicas=world_size, rank=rank, shuffle=shuffle, seed=0
         )
9 changes: 5 additions & 4 deletions recipes/lora_finetune_single_device.py
@@ -337,10 +337,11 @@ def _setup_data(
         Map-style Datasets which fit into memory and an option for random shuffling.
         Samplers, iterable datasets, and streaming datasets are not supported.
         """
-        ds = config.instantiate(
-            cfg_dataset,
-            tokenizer=self._tokenizer,
-        )
+        if isinstance(cfg_dataset.get(0), DictConfig):
+            ds = utils.MultiDataset(datasets=cfg_dataset, tokenizer=self._tokenizer)
+        else:
+            ds = config.instantiate(cfg_dataset, tokenizer=self._tokenizer)
+
         sampler = DistributedSampler(
             ds,
             num_replicas=1,
73 changes: 40 additions & 33 deletions tests/torchtune/datasets/test_alpaca_dataset.py
Contributor commented:

Do we still need to make these changes to the test samples to use Dataset across all the individual dataset test files?

Contributor Author replied:

I think we need to keep the datasets.Dataset objects, since this is a common format and anyone can see that the data should come in this format; plus there is no need to perform any extra transformations, as was the case, for example, in the _chat tests.
@@ -7,6 +7,7 @@
 from unittest.mock import patch
 
 import pytest
+from datasets import Dataset
 
 from tests.test_utils import get_assets_path
 from torchtune.data._common import CROSS_ENTROPY_IGNORE_IDX

@@ -28,17 +29,19 @@ def test_label_no_masking(self, load_dataset, tokenizer):
         """
 
         # mock the call to HF datasets
-        load_dataset.return_value = [
-            {
-                "instruction": "Give three tips for staying healthy.",
-                "input": "",
-                "output": (
-                    "1.Eat a balanced diet and make sure to include plenty of fruits and vegetables."
-                    "2. Exercise regularly to keep your body active and strong."
-                    "3. Get enough sleep and maintain a consistent sleep schedule."
-                ),
-            }
-        ]
+        load_dataset.return_value = Dataset.from_list(
Contributor commented:

What's the story with these Dataset.from_list changes? I know it will give us the right return type (Dataset instead of a raw list); is there anything besides that motivating the change? (I am fine with keeping them in, mainly asking out of curiosity.)

Contributor commented:

Technically it's a more accurate return type for load_dataset. I think it's OK to just leave these out and stick to primitives for simplicity, but I don't have a strong opinion here.

Contributor Author replied:

I've invested considerable time into understanding how to test my new dataset class. It was initially unclear what format and content its elements required. Hence, I suggest providing clear guidelines elucidating the expected format for dataset elements, to save fellow programmers time.

Contributor commented:

Thanks, this is very good feedback. I agree that the contracts of the various dataset components are not always obvious and take some time to sort through. Aside from improving live docs and better code comments, I'm open to any suggestions you have on how to make this clearer based on your experience.
+            [
+                {
+                    "instruction": "Give three tips for staying healthy.",
+                    "input": "",
+                    "output": (
+                        "1.Eat a balanced diet and make sure to include plenty of fruits and vegetables."
+                        "2. Exercise regularly to keep your body active and strong."
+                        "3. Get enough sleep and maintain a consistent sleep schedule."
+                    ),
+                }
+            ]
+        )
 
         alpaca_ds = alpaca_dataset(tokenizer=tokenizer)
         input, labels = alpaca_ds[0]

@@ -55,17 +58,19 @@ def test_label_masking(self, load_dataset, tokenizer):
         """
 
         # mock the call to HF datasets
-        load_dataset.return_value = [
-            {
-                "instruction": "Give three tips for staying healthy.",
-                "input": "",
-                "output": (
-                    "1.Eat a balanced diet and make sure to include plenty of fruits and vegetables."
-                    "2. Exercise regularly to keep your body active and strong."
-                    "3. Get enough sleep and maintain a consistent sleep schedule."
-                ),
-            }
-        ]
+        load_dataset.return_value = Dataset.from_list(
+            [
+                {
+                    "instruction": "Give three tips for staying healthy.",
+                    "input": "",
+                    "output": (
+                        "1.Eat a balanced diet and make sure to include plenty of fruits and vegetables."
+                        "2. Exercise regularly to keep your body active and strong."
+                        "3. Get enough sleep and maintain a consistent sleep schedule."
+                    ),
+                }
+            ]
+        )
 
         alpaca_ds = alpaca_dataset(tokenizer=tokenizer, train_on_input=False)
 
@@ -90,17 +95,19 @@ def test_alpaca_clean(self, load_dataset, tokenizer):
         """
 
         # mock the call to HF datasets
-        load_dataset.return_value = [
-            {
-                "instruction": "Give three tips for staying healthy.",
-                "input": "",
-                "output": (
-                    "1.Eat a balanced diet and make sure to include plenty of fruits and vegetables."
-                    "2. Exercise regularly to keep your body active and strong."
-                    "3. Get enough sleep and maintain a consistent sleep schedule."
-                ),
-            }
-        ]
+        load_dataset.return_value = Dataset.from_list(
+            [
+                {
+                    "instruction": "Give three tips for staying healthy.",
+                    "input": "",
+                    "output": (
+                        "1.Eat a balanced diet and make sure to include plenty of fruits and vegetables."
+                        "2. Exercise regularly to keep your body active and strong."
+                        "3. Get enough sleep and maintain a consistent sleep schedule."
+                    ),
+                }
+            ]
+        )
 
         alpaca_ds = alpaca_cleaned_dataset(tokenizer=tokenizer)
         input, labels = alpaca_ds[0]
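To make the motivation behind the Dataset.from_list discussion above concrete, here is a small standalone example (independent of torchtune) showing that the mock now matches the real return type of datasets.load_dataset:

from datasets import Dataset

# A plain list also supports indexing, but Dataset.from_list reproduces the
# actual return type of datasets.load_dataset, so column access and iteration
# behave in tests exactly as they do in production code.
samples = [{"instruction": "Say hi.", "input": "", "output": "Hi!"}]
ds = Dataset.from_list(samples)

assert isinstance(ds, Dataset)
assert ds[0]["output"] == "Hi!"
assert ds.column_names == ["instruction", "input", "output"]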
50 changes: 24 additions & 26 deletions tests/torchtune/datasets/test_chat_dataset.py
@@ -7,10 +7,10 @@
 from unittest import mock
 
 import pytest
+from datasets import Dataset
 from tests.test_utils import DummyTokenizer
 from torchtune.data import Message
 from torchtune.data._common import CROSS_ENTROPY_IGNORE_IDX
-
 from torchtune.datasets import ChatDataset

@@ -31,9 +31,11 @@ def format(
         formats = {"system": cls.system, "user": cls.user, "assistant": cls.assistant}
         formatted_dialogue = []
         for message in messages:
-            content = formats.get(message.role).format(content=message.content)
+            content = formats.get(message["role"]).format(content=message["content"])
             formatted_dialogue.append(
-                Message(role=message.role, content=content, masked=message.masked),
+                Message(
+                    role=message["role"], content=content, masked=message["masked"]
+                ),
             )
         return formatted_dialogue
@@ -57,26 +59,30 @@ def dialogue(self):
         return [
             {
                 "dialogue": [
-                    Message(
-                        role="system", content="You are an AI assistant.", masked=True
-                    ),
-                    Message(
-                        role="user", content="What is the meaning of life?", masked=True
-                    ),
-                    Message(
-                        role="assistant",
-                        content="The meaning of life is 42.",
-                        masked=False,
-                    ),
-                    Message(role="user", content="That's ridiculous.", masked=True),
-                    Message(role="assistant", content="I agree.", masked=False),
+                    {
+                        "role": "system",
+                        "content": "You are an AI assistant.",
+                        "masked": True,
+                    },
+                    {
+                        "role": "user",
+                        "content": "What is the meaning of life?",
+                        "masked": True,
+                    },
+                    {
+                        "role": "assistant",
+                        "content": "The meaning of life is 42.",
+                        "masked": False,
+                    },
+                    {"role": "user", "content": "That's ridiculous.", "masked": True},
+                    {"role": "assistant", "content": "I agree.", "masked": False},
                 ],
             },
         ]
 
     @mock.patch("torchtune.datasets._chat.load_dataset")
     def test_get_item(self, mock_load_dataset, chat_format, dialogue):
-        mock_load_dataset.return_value = dialogue
+        mock_load_dataset.return_value = Dataset.from_list(dialogue)
         expected_tokenized_prompts = [
             [
                 0,

@@ -114,15 +120,7 @@ def test_get_item(self, mock_load_dataset, chat_format, dialogue):
         prompt_lengths = (15, 5)
         expected_labels = [
             [CROSS_ENTROPY_IGNORE_IDX] * prompt_lengths[0]
-            + [
-                3,
-                7,
-                2,
-                4,
-                2,
-                3,
-                -1,
-            ]
+            + [3, 7, 2, 4, 2, 3, -1]
             + [CROSS_ENTROPY_IGNORE_IDX] * prompt_lengths[1]
             + [1, 6, -1]
         ]
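A likely reason the chat fixtures switched from Message objects to plain dicts (an inference; the thread does not state it explicitly): Dataset.from_list serializes samples to Arrow, which accepts plain Python types but not arbitrary objects. A sketch of the round trip:

from datasets import Dataset

from torchtune.data import Message

# Arrow-backed Datasets store plain types, so the fixture keeps each message
# as a dict...
raw = {"role": "user", "content": "Hello", "masked": True}
ds = Dataset.from_list([{"dialogue": [raw]}])

# ...and the chat format under test rebuilds Message objects from those dicts.
sample = ds[0]["dialogue"][0]
message = Message(
    role=sample["role"], content=sample["content"], masked=sample["masked"]
)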
29 changes: 17 additions & 12 deletions tests/torchtune/datasets/test_grammar_dataset.py
@@ -7,6 +7,7 @@
 from unittest.mock import patch
 
 import pytest
+from datasets import Dataset
 
 from tests.test_utils import get_assets_path
 from torchtune.data._common import CROSS_ENTROPY_IGNORE_IDX

@@ -29,12 +30,14 @@ def test_label_no_masking(self, load_dataset, tokenizer):
         """
 
         # mock the call to HF datasets
-        load_dataset.return_value = [
-            {
-                "input": "Bitcoin is for $7,094 this morning, which CoinDesk says.",
-                "output": "Bitcoin goes for $7,094 this morning, according to CoinDesk.",
-            }
-        ]
+        load_dataset.return_value = Dataset.from_list(
+            [
+                {
+                    "input": "Bitcoin is for $7,094 this morning, which CoinDesk says.",
+                    "output": "Bitcoin goes for $7,094 this morning, according to CoinDesk.",
+                }
+            ]
+        )
 
         grammar_ds = grammar_dataset(tokenizer=tokenizer, train_on_input=True)
         input, labels = grammar_ds[0]
@@ -51,12 +54,14 @@ def test_label_masking(self, load_dataset, tokenizer):
         """
 
         # mock the call to HF datasets
-        load_dataset.return_value = [
-            {
-                "input": "Bitcoin is for $7,094 this morning, which CoinDesk says.",
-                "output": "Bitcoin goes for $7,094 this morning, according to CoinDesk.",
-            }
-        ]
+        load_dataset.return_value = Dataset.from_list(
+            [
+                {
+                    "input": "Bitcoin is for $7,094 this morning, which CoinDesk says.",
+                    "output": "Bitcoin goes for $7,094 this morning, according to CoinDesk.",
+                }
+            ]
+        )
 
         grammar_ds = grammar_dataset(tokenizer=tokenizer)
 
35 changes: 31 additions & 4 deletions tests/torchtune/datasets/test_instruct_dataset.py
@@ -6,10 +6,9 @@
 
 from unittest import mock
 
+from datasets import Dataset
 from tests.test_utils import DummyTokenizer
-
 from torchtune.data._common import CROSS_ENTROPY_IGNORE_IDX
-
 from torchtune.datasets import InstructDataset
@@ -60,6 +59,34 @@ class TestInstructDataset:
             -1,
         ],
         [0, 12, 4, 2, 2, 12, 10, 6, 4, 2, 2, 6, 10, 9, 1, 6, 4, 4, 3, 6, 2, 4, -1],
+        [
+            0,
+            12,
+            4,
+            2,
+            3,
+            2,
+            12,
+            10,
+            6,
+            4,
+            2,
+            3,
+            2,
+            6,
+            10,
+            9,
+            1,
+            5,
+            4,
+            4,
+            3,
+            6,
+            2,
+            4,
+            -1,
+        ],
+        [0, 12, 4, 2, 2, 12, 10, 6, 4, 2, 2, 6, 10, 9, 1, 6, 4, 4, 3, 6, 2, 4, -1],
     ]
 
     def get_samples(self):
@@ -78,7 +105,7 @@ def get_samples(self):
 
     @mock.patch("torchtune.datasets._instruct.load_dataset")
     def test_get_item_no_train_on_input(self, mock_load_dataset):
-        mock_load_dataset.return_value = self.get_samples()
+        mock_load_dataset.return_value = Dataset.from_list(self.get_samples())
         prompt_lengths = (16, 14)
         expected_labels = [
             [CROSS_ENTROPY_IGNORE_IDX] * prompt_lengths[0]

@@ -104,7 +131,7 @@ def test_get_item_no_train_on_input(self, mock_load_dataset):
 
     @mock.patch("torchtune.datasets._instruct.load_dataset")
     def test_get_item_train_on_input(self, mock_load_dataset):
-        mock_load_dataset.return_value = self.get_samples()
+        mock_load_dataset.return_value = Dataset.from_list(self.get_samples())
         expected_labels = self.expected_tokenized_prompts
 
         dataset = InstructDataset(