diff --git a/docs/source/api_ref_datasets.rst b/docs/source/api_ref_datasets.rst index a1fd140028..741782687f 100644 --- a/docs/source/api_ref_datasets.rst +++ b/docs/source/api_ref_datasets.rst @@ -46,3 +46,4 @@ Class representations for the above dataset builders. InstructDataset ChatDataset + ConcatDataset diff --git a/docs/source/tutorials/datasets.rst b/docs/source/tutorials/datasets.rst index 9be7ddaa5e..dfa9a9a5f6 100644 --- a/docs/source/tutorials/datasets.rst +++ b/docs/source/tutorials/datasets.rst @@ -47,6 +47,25 @@ You could tweak :code:`max_seq_len` to achieve that directly from the config. # Original is 512 max_seq_len: 256 +It is also possible to train on multiple datasets by combining them into a single :class:`~torchtune.datasets.ConcatDataset`. For example: + +.. code-block:: yaml + + dataset: + - _component_: torchtune.datasets.instruct_dataset + source: vicgalle/alpaca-gpt4 + template: AlpacaInstructTemplate + split: train + train_on_input: True + - _component_: torchtune.datasets.instruct_dataset + source: samsum + template: SummarizeTemplate + column_map: {"output": "summary"} + split: train + train_on_input: False + +The preceding snippet demonstrates how you can configure each individual dataset's parameters, then combine them into a single concatenated dataset for training. + Customizing instruct templates ------------------------------ diff --git a/recipes/full_finetune_distributed.py b/recipes/full_finetune_distributed.py index 096d66db3d..b1d371ba2c 100644 --- a/recipes/full_finetune_distributed.py +++ b/recipes/full_finetune_distributed.py @@ -12,7 +12,7 @@ from warnings import warn import torch -from omegaconf import DictConfig +from omegaconf import DictConfig, ListConfig from torch import nn from torch.distributed import init_process_group @@ -27,7 +27,7 @@ from torch.utils.data import DataLoader, DistributedSampler from torchtune import config, modules, utils - +from torchtune.datasets import ConcatDataset from torchtune.recipe_interfaces import FTRecipeInterface from torchtune.utils.activations import apply_selective_activation_checkpointing @@ -357,10 +357,16 @@ def _setup_data( iterable datasets and streaming datasets are not supported. """ world_size, rank = utils.get_world_size_and_rank() - ds = config.instantiate( - cfg_dataset, - tokenizer=self._tokenizer, - ) + + if isinstance(cfg_dataset, ListConfig): + datasets = [ + config.instantiate(single_cfg_dataset, tokenizer=self._tokenizer) + for single_cfg_dataset in cfg_dataset + ] + ds = ConcatDataset(datasets=datasets) + else: + ds = config.instantiate(cfg_dataset, tokenizer=self._tokenizer) + sampler = DistributedSampler( ds, num_replicas=world_size, diff --git a/recipes/full_finetune_single_device.py b/recipes/full_finetune_single_device.py index e13902c0c9..e738e8647f 100644 --- a/recipes/full_finetune_single_device.py +++ b/recipes/full_finetune_single_device.py @@ -11,14 +11,14 @@ from warnings import warn import torch -from omegaconf import DictConfig +from omegaconf import DictConfig, ListConfig from torch import nn from torch.optim import Optimizer from torch.utils.data import DataLoader, DistributedSampler from torchtune import config, modules, utils - +from torchtune.datasets import ConcatDataset from torchtune.recipe_interfaces import FTRecipeInterface from tqdm import tqdm @@ -320,10 +320,15 @@ def _setup_data( DistributedSamplers with Map-style Datasets which fit into memory. Other samplers, iterable datasets and streaming datasets are not supported. """ - ds = config.instantiate( - cfg_dataset, - tokenizer=self._tokenizer, - ) + if isinstance(cfg_dataset, ListConfig): + datasets = [ + config.instantiate(single_cfg_dataset, tokenizer=self._tokenizer) + for single_cfg_dataset in cfg_dataset + ] + ds = ConcatDataset(datasets=datasets) + else: + ds = config.instantiate(config=cfg_dataset, tokenizer=self._tokenizer) + sampler = DistributedSampler( ds, num_replicas=1, diff --git a/recipes/lora_dpo_distributed.py b/recipes/lora_dpo_distributed.py index 8c535a087a..30dab1615b 100644 --- a/recipes/lora_dpo_distributed.py +++ b/recipes/lora_dpo_distributed.py @@ -12,7 +12,7 @@ from warnings import warn import torch -from omegaconf import DictConfig +from omegaconf import DictConfig, ListConfig from torch import nn from torch.distributed import destroy_process_group, init_process_group @@ -26,6 +26,7 @@ from torch.utils.data import DataLoader, DistributedSampler from torchtune import config, modules, utils from torchtune.data import CROSS_ENTROPY_IGNORE_IDX +from torchtune.datasets import ConcatDataset from torchtune.modules.peft.peft_utils import ( disable_adapter, get_adapter_params, @@ -412,7 +413,16 @@ def _setup_data( iterable datasets and streaming datasets are not supported. """ world_size, rank = utils.get_world_size_and_rank() - ds = config.instantiate(cfg_dataset, tokenizer=self._tokenizer) + + if isinstance(cfg_dataset, ListConfig): + datasets = [ + config.instantiate(single_cfg_dataset, tokenizer=self._tokenizer) + for single_cfg_dataset in cfg_dataset + ] + ds = ConcatDataset(datasets=datasets) + else: + ds = config.instantiate(cfg_dataset, tokenizer=self._tokenizer) + sampler = DistributedSampler( ds, num_replicas=world_size, rank=rank, shuffle=shuffle, seed=0 ) diff --git a/recipes/lora_dpo_single_device.py b/recipes/lora_dpo_single_device.py index 2283725a05..baa0053bc5 100644 --- a/recipes/lora_dpo_single_device.py +++ b/recipes/lora_dpo_single_device.py @@ -11,13 +11,14 @@ from warnings import warn import torch -from omegaconf import DictConfig +from omegaconf import DictConfig, ListConfig from torch import nn from torch.optim import Optimizer from torch.utils.data import DataLoader, DistributedSampler from torchtune import config, modules, utils from torchtune.data import CROSS_ENTROPY_IGNORE_IDX +from torchtune.datasets import ConcatDataset from torchtune.modules.peft.peft_utils import ( disable_adapter, get_adapter_params, @@ -297,10 +298,15 @@ def _setup_data( Map-style Datasets which fit into memory and an option for random shuffling. Samplers, iterable datasets, and streaming datasets are not supported. """ - ds = config.instantiate( - cfg_dataset, - tokenizer=self._tokenizer, - ) + if isinstance(cfg_dataset, ListConfig): + datasets = [ + config.instantiate(single_cfg_dataset, tokenizer=self._tokenizer) + for single_cfg_dataset in cfg_dataset + ] + ds = ConcatDataset(datasets=datasets) + else: + ds = config.instantiate(cfg_dataset, tokenizer=self._tokenizer) + sampler = DistributedSampler( ds, num_replicas=1, diff --git a/recipes/lora_finetune_distributed.py b/recipes/lora_finetune_distributed.py index 16c08a4aa1..64605a697b 100644 --- a/recipes/lora_finetune_distributed.py +++ b/recipes/lora_finetune_distributed.py @@ -13,7 +13,7 @@ from warnings import warn import torch -from omegaconf import DictConfig +from omegaconf import DictConfig, ListConfig from torch import nn from torch.distributed import destroy_process_group, init_process_group @@ -26,6 +26,7 @@ from torch.optim import Optimizer from torch.utils.data import DataLoader, DistributedSampler from torchtune import config, modules, utils +from torchtune.datasets import ConcatDataset from torchtune.modules.peft.peft_utils import ( get_adapter_params, get_merged_lora_ckpt, @@ -409,7 +410,16 @@ def _setup_data( iterable datasets and streaming datasets are not supported. """ world_size, rank = utils.get_world_size_and_rank() - ds = config.instantiate(cfg_dataset, tokenizer=self._tokenizer) + + if isinstance(cfg_dataset, ListConfig): + datasets = [ + config.instantiate(single_cfg_dataset, tokenizer=self._tokenizer) + for single_cfg_dataset in cfg_dataset + ] + ds = ConcatDataset(datasets=datasets) + else: + ds = config.instantiate(cfg_dataset, tokenizer=self._tokenizer) + sampler = DistributedSampler( ds, num_replicas=world_size, rank=rank, shuffle=shuffle, seed=0 ) diff --git a/recipes/lora_finetune_single_device.py b/recipes/lora_finetune_single_device.py index bef234e55f..e481cc167d 100644 --- a/recipes/lora_finetune_single_device.py +++ b/recipes/lora_finetune_single_device.py @@ -12,12 +12,13 @@ from warnings import warn import torch -from omegaconf import DictConfig +from omegaconf import DictConfig, ListConfig from torch import nn from torch.optim import Optimizer from torch.utils.data import DataLoader, DistributedSampler from torchtune import config, modules, utils +from torchtune.datasets import ConcatDataset from torchtune.modules.peft.peft_utils import ( get_adapter_params, get_merged_lora_ckpt, @@ -337,10 +338,15 @@ def _setup_data( Map-style Datasets which fit into memory and an option for random shuffling. Samplers, iterable datasets, and streaming datasets are not supported. """ - ds = config.instantiate( - cfg_dataset, - tokenizer=self._tokenizer, - ) + if isinstance(cfg_dataset, ListConfig): + datasets = [ + config.instantiate(single_cfg_dataset, tokenizer=self._tokenizer) + for single_cfg_dataset in cfg_dataset + ] + ds = ConcatDataset(datasets=datasets) + else: + ds = config.instantiate(cfg_dataset, tokenizer=self._tokenizer) + sampler = DistributedSampler( ds, num_replicas=1, diff --git a/tests/torchtune/datasets/test_alpaca_dataset.py b/tests/torchtune/datasets/test_alpaca_dataset.py index 2a05cefd06..59bcec737f 100644 --- a/tests/torchtune/datasets/test_alpaca_dataset.py +++ b/tests/torchtune/datasets/test_alpaca_dataset.py @@ -7,6 +7,7 @@ from unittest.mock import patch import pytest +from datasets import Dataset from tests.test_utils import get_assets_path from torchtune.data._common import CROSS_ENTROPY_IGNORE_IDX @@ -28,17 +29,19 @@ def test_label_no_masking(self, load_dataset, tokenizer): """ # mock the call to HF datasets - load_dataset.return_value = [ - { - "instruction": "Give three tips for staying healthy.", - "input": "", - "output": ( - "1.Eat a balanced diet and make sure to include plenty of fruits and vegetables." - "2. Exercise regularly to keep your body active and strong." - "3. Get enough sleep and maintain a consistent sleep schedule." - ), - } - ] + load_dataset.return_value = Dataset.from_list( + [ + { + "instruction": "Give three tips for staying healthy.", + "input": "", + "output": ( + "1.Eat a balanced diet and make sure to include plenty of fruits and vegetables." + "2. Exercise regularly to keep your body active and strong." + "3. Get enough sleep and maintain a consistent sleep schedule." + ), + } + ] + ) alpaca_ds = alpaca_dataset(tokenizer=tokenizer) input, labels = alpaca_ds[0] @@ -55,17 +58,19 @@ def test_label_masking(self, load_dataset, tokenizer): """ # mock the call to HF datasets - load_dataset.return_value = [ - { - "instruction": "Give three tips for staying healthy.", - "input": "", - "output": ( - "1.Eat a balanced diet and make sure to include plenty of fruits and vegetables." - "2. Exercise regularly to keep your body active and strong." - "3. Get enough sleep and maintain a consistent sleep schedule." - ), - } - ] + load_dataset.return_value = Dataset.from_list( + [ + { + "instruction": "Give three tips for staying healthy.", + "input": "", + "output": ( + "1.Eat a balanced diet and make sure to include plenty of fruits and vegetables." + "2. Exercise regularly to keep your body active and strong." + "3. Get enough sleep and maintain a consistent sleep schedule." + ), + } + ] + ) alpaca_ds = alpaca_dataset(tokenizer=tokenizer, train_on_input=False) @@ -90,17 +95,19 @@ def test_alpaca_clean(self, load_dataset, tokenizer): """ # mock the call to HF datasets - load_dataset.return_value = [ - { - "instruction": "Give three tips for staying healthy.", - "input": "", - "output": ( - "1.Eat a balanced diet and make sure to include plenty of fruits and vegetables." - "2. Exercise regularly to keep your body active and strong." - "3. Get enough sleep and maintain a consistent sleep schedule." - ), - } - ] + load_dataset.return_value = Dataset.from_list( + [ + { + "instruction": "Give three tips for staying healthy.", + "input": "", + "output": ( + "1.Eat a balanced diet and make sure to include plenty of fruits and vegetables." + "2. Exercise regularly to keep your body active and strong." + "3. Get enough sleep and maintain a consistent sleep schedule." + ), + } + ] + ) alpaca_ds = alpaca_cleaned_dataset(tokenizer=tokenizer) input, labels = alpaca_ds[0] diff --git a/tests/torchtune/datasets/test_chat_dataset.py b/tests/torchtune/datasets/test_chat_dataset.py index 3397a8e244..2c24045c67 100644 --- a/tests/torchtune/datasets/test_chat_dataset.py +++ b/tests/torchtune/datasets/test_chat_dataset.py @@ -7,10 +7,10 @@ from unittest import mock import pytest +from datasets import Dataset from tests.test_utils import DummyTokenizer from torchtune.data import Message from torchtune.data._common import CROSS_ENTROPY_IGNORE_IDX - from torchtune.datasets import ChatDataset @@ -31,9 +31,11 @@ def format( formats = {"system": cls.system, "user": cls.user, "assistant": cls.assistant} formatted_dialogue = [] for message in messages: - content = formats.get(message.role).format(content=message.content) + content = formats.get(message["role"]).format(content=message["content"]) formatted_dialogue.append( - Message(role=message.role, content=content, masked=message.masked), + Message( + role=message["role"], content=content, masked=message["masked"] + ), ) return formatted_dialogue @@ -57,26 +59,30 @@ def dialogue(self): return [ { "dialogue": [ - Message( - role="system", content="You are an AI assistant.", masked=True - ), - Message( - role="user", content="What is the meaning of life?", masked=True - ), - Message( - role="assistant", - content="The meaning of life is 42.", - masked=False, - ), - Message(role="user", content="That's ridiculous.", masked=True), - Message(role="assistant", content="I agree.", masked=False), + { + "role": "system", + "content": "You are an AI assistant.", + "masked": True, + }, + { + "role": "user", + "content": "What is the meaning of life?", + "masked": True, + }, + { + "role": "assistant", + "content": "The meaning of life is 42.", + "masked": False, + }, + {"role": "user", "content": "That's ridiculous.", "masked": True}, + {"role": "assistant", "content": "I agree.", "masked": False}, ], }, ] @mock.patch("torchtune.datasets._chat.load_dataset") def test_get_item(self, mock_load_dataset, chat_format, dialogue): - mock_load_dataset.return_value = dialogue + mock_load_dataset.return_value = Dataset.from_list(dialogue) expected_tokenized_prompts = [ [ 0, @@ -114,15 +120,7 @@ def test_get_item(self, mock_load_dataset, chat_format, dialogue): prompt_lengths = (15, 5) expected_labels = [ [CROSS_ENTROPY_IGNORE_IDX] * prompt_lengths[0] - + [ - 3, - 7, - 2, - 4, - 2, - 3, - -1, - ] + + [3, 7, 2, 4, 2, 3, -1] + [CROSS_ENTROPY_IGNORE_IDX] * prompt_lengths[1] + [1, 6, -1] ] diff --git a/tests/torchtune/datasets/test_concat_dataset.py b/tests/torchtune/datasets/test_concat_dataset.py new file mode 100644 index 0000000000..32ecc3bc5a --- /dev/null +++ b/tests/torchtune/datasets/test_concat_dataset.py @@ -0,0 +1,53 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import pytest +from datasets import Dataset +from torchtune.datasets._concat import ConcatDataset + + +class TestConcatDataset: + @pytest.fixture + def datasets(self): + ds1 = Dataset.from_list([{"data": f"ds1_{i}"} for i in range(4)]) + ds2 = Dataset.from_list([{"data": f"ds2_{i}"} for i in range(8)]) + ds3 = Dataset.from_list([{"data": f"ds3_{i}"} for i in range(15)]) + ds4 = Dataset.from_list([{"data": f"ds4_{i}"} for i in range(16)]) + ds5 = Dataset.from_list([{"data": f"ds5_{i}"} for i in range(23)]) + ds6 = Dataset.from_list([{"data": f"ds6_{i}"} for i in range(42)]) + return [ds1, ds2, ds3, ds4, ds5, ds6] + + def test_length(self, datasets): + """Test the correct computation of total length""" + multi_dataset = ConcatDataset(datasets) + + # sum of individual datasets lengths + expected_length = 4 + 8 + 15 + 16 + 23 + 42 # 108 + assert len(multi_dataset) == expected_length + + def test_getitem(self, datasets): + """Test item retrieval across dataset boundaries""" + multi_dataset = ConcatDataset(datasets) + + # Testing indices across different datasets + assert multi_dataset[-1] is None # Index out of range + assert multi_dataset[0] == {"data": "ds1_0"} + assert multi_dataset[3] == {"data": "ds1_3"} + assert multi_dataset[4] == {"data": "ds2_0"} + assert multi_dataset[10] == {"data": "ds2_6"} + assert multi_dataset[20] == {"data": "ds3_8"} + assert multi_dataset[35] == {"data": "ds4_8"} + assert multi_dataset[50] == {"data": "ds5_7"} + assert multi_dataset[70] == {"data": "ds6_4"} + assert multi_dataset[90] == {"data": "ds6_24"} + assert multi_dataset[108] is None # Index out of range + + def test_invalid_index_type(self, datasets): + """Test handling of invalid index types""" + multi_dataset = ConcatDataset(datasets) + + with pytest.raises(TypeError): + multi_dataset["invalid_type"] # Non-integer index diff --git a/tests/torchtune/datasets/test_grammar_dataset.py b/tests/torchtune/datasets/test_grammar_dataset.py index 20c209f004..11957e4cb6 100644 --- a/tests/torchtune/datasets/test_grammar_dataset.py +++ b/tests/torchtune/datasets/test_grammar_dataset.py @@ -7,6 +7,7 @@ from unittest.mock import patch import pytest +from datasets import Dataset from tests.test_utils import get_assets_path from torchtune.data._common import CROSS_ENTROPY_IGNORE_IDX @@ -29,12 +30,14 @@ def test_label_no_masking(self, load_dataset, tokenizer): """ # mock the call to HF datasets - load_dataset.return_value = [ - { - "input": "Bitcoin is for $7,094 this morning, which CoinDesk says.", - "output": "Bitcoin goes for $7,094 this morning, according to CoinDesk.", - } - ] + load_dataset.return_value = Dataset.from_list( + [ + { + "input": "Bitcoin is for $7,094 this morning, which CoinDesk says.", + "output": "Bitcoin goes for $7,094 this morning, according to CoinDesk.", + } + ] + ) grammar_ds = grammar_dataset(tokenizer=tokenizer, train_on_input=True) input, labels = grammar_ds[0] @@ -51,12 +54,14 @@ def test_label_masking(self, load_dataset, tokenizer): """ # mock the call to HF datasets - load_dataset.return_value = [ - { - "input": "Bitcoin is for $7,094 this morning, which CoinDesk says.", - "output": "Bitcoin goes for $7,094 this morning, according to CoinDesk.", - } - ] + load_dataset.return_value = Dataset.from_list( + [ + { + "input": "Bitcoin is for $7,094 this morning, which CoinDesk says.", + "output": "Bitcoin goes for $7,094 this morning, according to CoinDesk.", + } + ] + ) grammar_ds = grammar_dataset(tokenizer=tokenizer) diff --git a/tests/torchtune/datasets/test_instruct_dataset.py b/tests/torchtune/datasets/test_instruct_dataset.py index b31e8a2aec..2fcbb0021f 100644 --- a/tests/torchtune/datasets/test_instruct_dataset.py +++ b/tests/torchtune/datasets/test_instruct_dataset.py @@ -6,10 +6,9 @@ from unittest import mock +from datasets import Dataset from tests.test_utils import DummyTokenizer - from torchtune.data._common import CROSS_ENTROPY_IGNORE_IDX - from torchtune.datasets import InstructDataset @@ -60,6 +59,34 @@ class TestInstructDataset: -1, ], [0, 12, 4, 2, 2, 12, 10, 6, 4, 2, 2, 6, 10, 9, 1, 6, 4, 4, 3, 6, 2, 4, -1], + [ + 0, + 12, + 4, + 2, + 3, + 2, + 12, + 10, + 6, + 4, + 2, + 3, + 2, + 6, + 10, + 9, + 1, + 5, + 4, + 4, + 3, + 6, + 2, + 4, + -1, + ], + [0, 12, 4, 2, 2, 12, 10, 6, 4, 2, 2, 6, 10, 9, 1, 6, 4, 4, 3, 6, 2, 4, -1], ] def get_samples(self): @@ -78,7 +105,7 @@ def get_samples(self): @mock.patch("torchtune.datasets._instruct.load_dataset") def test_get_item_no_train_on_input(self, mock_load_dataset): - mock_load_dataset.return_value = self.get_samples() + mock_load_dataset.return_value = Dataset.from_list(self.get_samples()) prompt_lengths = (16, 14) expected_labels = [ [CROSS_ENTROPY_IGNORE_IDX] * prompt_lengths[0] @@ -104,7 +131,7 @@ def test_get_item_no_train_on_input(self, mock_load_dataset): @mock.patch("torchtune.datasets._instruct.load_dataset") def test_get_item_train_on_input(self, mock_load_dataset): - mock_load_dataset.return_value = self.get_samples() + mock_load_dataset.return_value = Dataset.from_list(self.get_samples()) expected_labels = self.expected_tokenized_prompts dataset = InstructDataset( diff --git a/tests/torchtune/datasets/test_samsum_dataset.py b/tests/torchtune/datasets/test_samsum_dataset.py index 6ec6a52679..cae43e04d1 100644 --- a/tests/torchtune/datasets/test_samsum_dataset.py +++ b/tests/torchtune/datasets/test_samsum_dataset.py @@ -7,6 +7,7 @@ from unittest.mock import patch import pytest +from datasets import Dataset from tests.test_utils import get_assets_path from torchtune.data._common import CROSS_ENTROPY_IGNORE_IDX @@ -29,13 +30,15 @@ def test_label_no_masking(self, load_dataset, tokenizer): """ # mock the call to HF datasets - load_dataset.return_value = [ - { - "id": "13818513", - "dialogue": "Amanda: I baked cookies. Do you want some? Jerry: Sure! Amanda: I'll bring you tomorrow :-)", - "summary": "Amanda baked cookies and will bring Jerry some tomorrow.", - }, - ] + load_dataset.return_value = Dataset.from_list( + [ + { + "id": "13818513", + "dialogue": "Amanda: I baked cookies. Do you want some? Jerry: Sure! Amanda: I'll bring you tomorrow :-)", + "summary": "Amanda baked cookies and will bring Jerry some tomorrow.", + }, + ] + ) samsum_ds = samsum_dataset(tokenizer=tokenizer, train_on_input=True) input, labels = samsum_ds[0] @@ -52,13 +55,15 @@ def test_label_masking(self, load_dataset, tokenizer): """ # mock the call to HF datasets - load_dataset.return_value = [ - { - "id": "13818513", - "dialogue": "Amanda: I baked cookies. Do you want some? Jerry: Sure! Amanda: I'll bring you tomorrow :-)", - "summary": "Amanda baked cookies and will bring Jerry some tomorrow.", - }, - ] + load_dataset.return_value = Dataset.from_list( + [ + { + "id": "13818513", + "dialogue": "Amanda: I baked cookies. Do you want some? Jerry: Sure! Amanda: I'll bring you tomorrow :-)", + "summary": "Amanda baked cookies and will bring Jerry some tomorrow.", + }, + ] + ) samsum_ds = samsum_dataset(tokenizer=tokenizer) diff --git a/tests/torchtune/datasets/test_slimorca_dataset.py b/tests/torchtune/datasets/test_slimorca_dataset.py index 03a8396271..4f3f46e718 100644 --- a/tests/torchtune/datasets/test_slimorca_dataset.py +++ b/tests/torchtune/datasets/test_slimorca_dataset.py @@ -6,6 +6,7 @@ from unittest.mock import patch import pytest +from datasets import Dataset from tests.test_utils import get_assets_path @@ -30,24 +31,26 @@ def test_value_error(self, load_dataset, tokenizer): @pytest.mark.parametrize("max_seq_len", [128, 512, 1024, 4096]) def test_dataset_get_item(self, load_dataset, tokenizer, max_seq_len): # Sample data from slimorca dataset - load_dataset.return_value = [ - { - "conversations": [ - { - "from": "system", - "value": "You are an AI assistant. User will you give you a task. Your goal is to complete the task as faithfully as you can. While performing the task think step-by-step and justify your steps.", # noqa: B950 - }, - { - "from": "human", - "value": "Please briefly summarize this news article:\n\nAOL.com Video - Father Lets 8-Year-Old Drive On Icy Road\n\nDescription:Would you let your 8-year-old drive your car? How about on an icy road? Well one father in Russia did just that, and recorded the entire thing. To her credit, the child seemed to be doing a great job. (0:44)\n\nTags: 8-year-old driver , caught on camera , child driver , pix11\n\nSummary:", # noqa: B950 - }, - { - "from": "gpt", - "value": "A father in Russia allowed his 8-year-old child to drive his car on an icy road and recorded the event. The child appeared to be handling the situation well, showcasing their driving skills despite the challenging conditions.", # noqa: B950 - }, - ] - } - ] + load_dataset.return_value = Dataset.from_list( + [ + { + "conversations": [ + { + "from": "system", + "value": "You are an AI assistant. User will you give you a task. Your goal is to complete the task as faithfully as you can. While performing the task think step-by-step and justify your steps.", # noqa: B950 + }, + { + "from": "human", + "value": "Please briefly summarize this news article:\n\nAOL.com Video - Father Lets 8-Year-Old Drive On Icy Road\n\nDescription:Would you let your 8-year-old drive your car? How about on an icy road? Well one father in Russia did just that, and recorded the entire thing. To her credit, the child seemed to be doing a great job. (0:44)\n\nTags: 8-year-old driver , caught on camera , child driver , pix11\n\nSummary:", # noqa: B950 + }, + { + "from": "gpt", + "value": "A father in Russia allowed his 8-year-old child to drive his car on an icy road and recorded the event. The child appeared to be handling the situation well, showcasing their driving skills despite the challenging conditions.", # noqa: B950 + }, + ] + } + ] + ) ds = slimorca_dataset( tokenizer=tokenizer, max_seq_len=max_seq_len, diff --git a/torchtune/datasets/__init__.py b/torchtune/datasets/__init__.py index 23a939890c..aee09de0ff 100644 --- a/torchtune/datasets/__init__.py +++ b/torchtune/datasets/__init__.py @@ -6,6 +6,7 @@ from torchtune.datasets._alpaca import alpaca_cleaned_dataset, alpaca_dataset from torchtune.datasets._chat import chat_dataset, ChatDataset +from torchtune.datasets._concat import ConcatDataset from torchtune.datasets._grammar import grammar_dataset from torchtune.datasets._instruct import instruct_dataset, InstructDataset from torchtune.datasets._samsum import samsum_dataset @@ -23,4 +24,5 @@ "ChatDataset", "instruct_dataset", "chat_dataset", + "ConcatDataset", ] diff --git a/torchtune/datasets/_concat.py b/torchtune/datasets/_concat.py new file mode 100644 index 0000000000..c650e54a9a --- /dev/null +++ b/torchtune/datasets/_concat.py @@ -0,0 +1,77 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import List, Tuple + +from torch.utils.data import Dataset + +from torchtune import utils + +log = utils.get_logger("DEBUG") + + +class ConcatDataset(Dataset): + """ + A dataset class for concatenating multiple sub-datasets into a single dataset. This class enables the + unified handling of different datasets as if they were a single dataset, simplifying tasks such as + training models on multiple sources of data simultaneously. + + The class internally manages the aggregation of different datasets and allows transparent indexing across them. + However, it requires all constituent datasets to be fully loaded into memory, which might not be optimal for + very large datasets. + + Upon initialization, this class computes the cumulative length of all datasets and maintains an internal mapping + of indices to the respective datasets. This approach allows the `ConcatDataset` to delegate data retrieval to + the appropriate sub-dataset transparently when a particular index is accessed. + + Note: + Using this class with very large datasets can lead to high memory consumption, as it requires all datasets to + be loaded into memory. For large-scale scenarios, consider other strategies that might stream data on demand. + + Args: + datasets (List[Dataset]): A list of datasets to concatenate. Each dataset must be an instance of a class + derived from `torch.utils.data.Dataset`. + + Attributes: + _datasets (List[Dataset]): Stores the list of datasets passed during initialization. + _len (int): The total combined length of all datasets. + _indexes (List[Tuple[int, int, int]]): A list of tuples where each tuple contains the starting index, the + ending index, and the dataset index for quick lookup and access during indexing operations. + + Example: + >>> dataset1 = MyCustomDataset(params1) + >>> dataset2 = MyCustomDataset(params2) + >>> concat_dataset = ConcatDataset([dataset1, dataset2]) + >>> print(len(concat_dataset)) # Total length of both datasets + >>> data_point = concat_dataset[1500] # Accesses an element from the appropriate dataset + + This class primarily focuses on providing a unified interface to access elements from multiple datasets, + enhancing the flexibility in handling diverse data sources for training machine learning models. + """ + + def __init__(self, datasets: List[Dataset]): + self._datasets = datasets + self._len = sum(len(dataset) for dataset in datasets) + self._indexes = [] + + # Calculate distribution of indexes in all datasets + cumulative_index = 0 + for idx, dataset in enumerate(datasets): + next_cumulative_index = cumulative_index + len(dataset) + self._indexes.append((cumulative_index, next_cumulative_index, idx)) + cumulative_index = next_cumulative_index + + log.debug(f"Datasets summary length: {self._len}") + log.debug(f"Datasets indexes: {self._indexes}") + + def __getitem__(self, index: int) -> Tuple[List[int], List[int]]: + for start, stop, dataset_index in self._indexes: + if start <= index < stop: + dataset = self._datasets[dataset_index] + return dataset[index - start] + + def __len__(self) -> int: + return self._len diff --git a/torchtune/utils/__init__.py b/torchtune/utils/__init__.py index da6cb53d74..cb3e52d747 100644 --- a/torchtune/utils/__init__.py +++ b/torchtune/utils/__init__.py @@ -46,6 +46,7 @@ register_optim_in_bwd_hooks, set_activation_checkpointing, ) + from .precision import ( get_dtype, list_dtypes,