Feat: Add support of multiple datasets in config (#889)
Co-authored-by: Rafi Ayub <33648637+RdoubleA@users.noreply.github.com>
Co-authored-by: ebsmothers <ebs@meta.com>
3 people authored May 3, 2024
1 parent 9274c89 commit d36e818
Showing 18 changed files with 374 additions and 133 deletions.
1 change: 1 addition & 0 deletions docs/source/api_ref_datasets.rst
@@ -46,3 +46,4 @@ Class representations for the above dataset builders.

InstructDataset
ChatDataset
ConcatDataset
19 changes: 19 additions & 0 deletions docs/source/tutorials/datasets.rst
@@ -47,6 +47,25 @@ You could tweak :code:`max_seq_len` to achieve that directly from the config.
# Original is 512
max_seq_len: 256
It is also possible to train on multiple datasets by combining them into a single :class:`~torchtune.datasets.ConcatDataset`. For example:

.. code-block:: yaml

    dataset:
      - _component_: torchtune.datasets.instruct_dataset
        source: vicgalle/alpaca-gpt4
        template: AlpacaInstructTemplate
        split: train
        train_on_input: True
      - _component_: torchtune.datasets.instruct_dataset
        source: samsum
        template: SummarizeTemplate
        column_map: {"output": "summary"}
        split: train
        train_on_input: False

The preceding snippet demonstrates how you can configure each individual dataset's parameters, then combine them into a single concatenated dataset for training.
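
To make the mechanics concrete, the sketch below illustrates the index-mapping idea behind a concatenated dataset. This is a simplified illustration only, not the actual :class:`~torchtune.datasets.ConcatDataset` implementation, and the class name :code:`MinimalConcatDataset` is invented for this example.

.. code-block:: python

    # Simplified sketch of the index-mapping idea; not torchtune's actual
    # ConcatDataset implementation.
    from typing import Any, List, Sequence

    class MinimalConcatDataset:
        """Expose several map-style datasets as one contiguous dataset."""

        def __init__(self, datasets: List[Sequence[Any]]) -> None:
            self._datasets = datasets

        def __len__(self) -> int:
            # Total length is the sum of the constituent dataset lengths.
            return sum(len(d) for d in self._datasets)

        def __getitem__(self, index: int) -> Any:
            # Walk the datasets until the global index lands inside one of them.
            for dataset in self._datasets:
                if index < len(dataset):
                    return dataset[index]
                index -= len(dataset)
            raise IndexError("index out of range")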

Customizing instruct templates
------------------------------

18 changes: 12 additions & 6 deletions recipes/full_finetune_distributed.py
@@ -12,7 +12,7 @@
from warnings import warn

import torch
from omegaconf import DictConfig
from omegaconf import DictConfig, ListConfig

from torch import nn
from torch.distributed import init_process_group
@@ -27,7 +27,7 @@
from torch.utils.data import DataLoader, DistributedSampler

from torchtune import config, modules, utils

from torchtune.datasets import ConcatDataset
from torchtune.recipe_interfaces import FTRecipeInterface
from torchtune.utils.activations import apply_selective_activation_checkpointing

@@ -357,10 +357,16 @@ def _setup_data(
iterable datasets and streaming datasets are not supported.
"""
world_size, rank = utils.get_world_size_and_rank()
ds = config.instantiate(
cfg_dataset,
tokenizer=self._tokenizer,
)

if isinstance(cfg_dataset, ListConfig):
datasets = [
config.instantiate(single_cfg_dataset, tokenizer=self._tokenizer)
for single_cfg_dataset in cfg_dataset
]
ds = ConcatDataset(datasets=datasets)
else:
ds = config.instantiate(cfg_dataset, tokenizer=self._tokenizer)

sampler = DistributedSampler(
ds,
num_replicas=world_size,
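
Aside: the branch on ListConfig above (repeated in the other recipes below) works because OmegaConf parses a YAML list under the dataset key into a ListConfig, while a single mapping stays a DictConfig. A small standalone sketch, independent of torchtune, demonstrating that behavior:

from omegaconf import DictConfig, ListConfig, OmegaConf

# A single dataset entry parses to a DictConfig ...
single = OmegaConf.create(
    """
dataset:
  _component_: torchtune.datasets.alpaca_dataset
  train_on_input: True
"""
)

# ... while a YAML list of dataset entries parses to a ListConfig,
# which is what triggers the ConcatDataset code path in the recipes.
multi = OmegaConf.create(
    """
dataset:
  - _component_: torchtune.datasets.instruct_dataset
    source: vicgalle/alpaca-gpt4
  - _component_: torchtune.datasets.instruct_dataset
    source: samsum
"""
)

assert isinstance(single.dataset, DictConfig)
assert isinstance(multi.dataset, ListConfig)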
17 changes: 11 additions & 6 deletions recipes/full_finetune_single_device.py
@@ -11,14 +11,14 @@
from warnings import warn

import torch
from omegaconf import DictConfig
from omegaconf import DictConfig, ListConfig

from torch import nn
from torch.optim import Optimizer
from torch.utils.data import DataLoader, DistributedSampler

from torchtune import config, modules, utils

from torchtune.datasets import ConcatDataset
from torchtune.recipe_interfaces import FTRecipeInterface

from tqdm import tqdm
@@ -320,10 +320,15 @@ def _setup_data(
DistributedSamplers with Map-style Datasets which fit into memory. Other samplers,
iterable datasets and streaming datasets are not supported.
"""
ds = config.instantiate(
cfg_dataset,
tokenizer=self._tokenizer,
)
if isinstance(cfg_dataset, ListConfig):
datasets = [
config.instantiate(single_cfg_dataset, tokenizer=self._tokenizer)
for single_cfg_dataset in cfg_dataset
]
ds = ConcatDataset(datasets=datasets)
else:
ds = config.instantiate(config=cfg_dataset, tokenizer=self._tokenizer)

sampler = DistributedSampler(
ds,
num_replicas=1,
14 changes: 12 additions & 2 deletions recipes/lora_dpo_distributed.py
@@ -12,7 +12,7 @@
from warnings import warn

import torch
from omegaconf import DictConfig
from omegaconf import DictConfig, ListConfig

from torch import nn
from torch.distributed import destroy_process_group, init_process_group
@@ -26,6 +26,7 @@
from torch.utils.data import DataLoader, DistributedSampler
from torchtune import config, modules, utils
from torchtune.data import CROSS_ENTROPY_IGNORE_IDX
from torchtune.datasets import ConcatDataset
from torchtune.modules.peft.peft_utils import (
disable_adapter,
get_adapter_params,
@@ -412,7 +413,16 @@ def _setup_data(
iterable datasets and streaming datasets are not supported.
"""
world_size, rank = utils.get_world_size_and_rank()
ds = config.instantiate(cfg_dataset, tokenizer=self._tokenizer)

if isinstance(cfg_dataset, ListConfig):
datasets = [
config.instantiate(single_cfg_dataset, tokenizer=self._tokenizer)
for single_cfg_dataset in cfg_dataset
]
ds = ConcatDataset(datasets=datasets)
else:
ds = config.instantiate(cfg_dataset, tokenizer=self._tokenizer)

sampler = DistributedSampler(
ds, num_replicas=world_size, rank=rank, shuffle=shuffle, seed=0
)
16 changes: 11 additions & 5 deletions recipes/lora_dpo_single_device.py
@@ -11,13 +11,14 @@
from warnings import warn

import torch
from omegaconf import DictConfig
from omegaconf import DictConfig, ListConfig

from torch import nn
from torch.optim import Optimizer
from torch.utils.data import DataLoader, DistributedSampler
from torchtune import config, modules, utils
from torchtune.data import CROSS_ENTROPY_IGNORE_IDX
from torchtune.datasets import ConcatDataset
from torchtune.modules.peft.peft_utils import (
disable_adapter,
get_adapter_params,
@@ -297,10 +298,15 @@ def _setup_data(
Map-style Datasets which fit into memory and an option for random shuffling.
Samplers, iterable datasets, and streaming datasets are not supported.
"""
ds = config.instantiate(
cfg_dataset,
tokenizer=self._tokenizer,
)
if isinstance(cfg_dataset, ListConfig):
datasets = [
config.instantiate(single_cfg_dataset, tokenizer=self._tokenizer)
for single_cfg_dataset in cfg_dataset
]
ds = ConcatDataset(datasets=datasets)
else:
ds = config.instantiate(cfg_dataset, tokenizer=self._tokenizer)

sampler = DistributedSampler(
ds,
num_replicas=1,
14 changes: 12 additions & 2 deletions recipes/lora_finetune_distributed.py
@@ -13,7 +13,7 @@
from warnings import warn

import torch
from omegaconf import DictConfig
from omegaconf import DictConfig, ListConfig

from torch import nn
from torch.distributed import destroy_process_group, init_process_group
@@ -26,6 +26,7 @@
from torch.optim import Optimizer
from torch.utils.data import DataLoader, DistributedSampler
from torchtune import config, modules, utils
from torchtune.datasets import ConcatDataset
from torchtune.modules.peft.peft_utils import (
get_adapter_params,
get_merged_lora_ckpt,
@@ -409,7 +410,16 @@ def _setup_data(
iterable datasets and streaming datasets are not supported.
"""
world_size, rank = utils.get_world_size_and_rank()
ds = config.instantiate(cfg_dataset, tokenizer=self._tokenizer)

if isinstance(cfg_dataset, ListConfig):
datasets = [
config.instantiate(single_cfg_dataset, tokenizer=self._tokenizer)
for single_cfg_dataset in cfg_dataset
]
ds = ConcatDataset(datasets=datasets)
else:
ds = config.instantiate(cfg_dataset, tokenizer=self._tokenizer)

sampler = DistributedSampler(
ds, num_replicas=world_size, rank=rank, shuffle=shuffle, seed=0
)
16 changes: 11 additions & 5 deletions recipes/lora_finetune_single_device.py
@@ -12,12 +12,13 @@
from warnings import warn

import torch
from omegaconf import DictConfig
from omegaconf import DictConfig, ListConfig

from torch import nn
from torch.optim import Optimizer
from torch.utils.data import DataLoader, DistributedSampler
from torchtune import config, modules, utils
from torchtune.datasets import ConcatDataset
from torchtune.modules.peft.peft_utils import (
get_adapter_params,
get_merged_lora_ckpt,
@@ -337,10 +338,15 @@ def _setup_data(
Map-style Datasets which fit into memory and an option for random shuffling.
Samplers, iterable datasets, and streaming datasets are not supported.
"""
ds = config.instantiate(
cfg_dataset,
tokenizer=self._tokenizer,
)
if isinstance(cfg_dataset, ListConfig):
datasets = [
config.instantiate(single_cfg_dataset, tokenizer=self._tokenizer)
for single_cfg_dataset in cfg_dataset
]
ds = ConcatDataset(datasets=datasets)
else:
ds = config.instantiate(cfg_dataset, tokenizer=self._tokenizer)

sampler = DistributedSampler(
ds,
num_replicas=1,
73 changes: 40 additions & 33 deletions tests/torchtune/datasets/test_alpaca_dataset.py
@@ -7,6 +7,7 @@
from unittest.mock import patch

import pytest
from datasets import Dataset

from tests.test_utils import get_assets_path
from torchtune.data._common import CROSS_ENTROPY_IGNORE_IDX
@@ -28,17 +29,19 @@ def test_label_no_masking(self, load_dataset, tokenizer):
"""

# mock the call to HF datasets
load_dataset.return_value = [
{
"instruction": "Give three tips for staying healthy.",
"input": "",
"output": (
"1.Eat a balanced diet and make sure to include plenty of fruits and vegetables."
"2. Exercise regularly to keep your body active and strong."
"3. Get enough sleep and maintain a consistent sleep schedule."
),
}
]
load_dataset.return_value = Dataset.from_list(
[
{
"instruction": "Give three tips for staying healthy.",
"input": "",
"output": (
"1.Eat a balanced diet and make sure to include plenty of fruits and vegetables."
"2. Exercise regularly to keep your body active and strong."
"3. Get enough sleep and maintain a consistent sleep schedule."
),
}
]
)

alpaca_ds = alpaca_dataset(tokenizer=tokenizer)
input, labels = alpaca_ds[0]
@@ -55,17 +58,19 @@ def test_label_masking(self, load_dataset, tokenizer):
"""

# mock the call to HF datasets
load_dataset.return_value = [
{
"instruction": "Give three tips for staying healthy.",
"input": "",
"output": (
"1.Eat a balanced diet and make sure to include plenty of fruits and vegetables."
"2. Exercise regularly to keep your body active and strong."
"3. Get enough sleep and maintain a consistent sleep schedule."
),
}
]
load_dataset.return_value = Dataset.from_list(
[
{
"instruction": "Give three tips for staying healthy.",
"input": "",
"output": (
"1.Eat a balanced diet and make sure to include plenty of fruits and vegetables."
"2. Exercise regularly to keep your body active and strong."
"3. Get enough sleep and maintain a consistent sleep schedule."
),
}
]
)

alpaca_ds = alpaca_dataset(tokenizer=tokenizer, train_on_input=False)

@@ -90,17 +95,19 @@ def test_alpaca_clean(self, load_dataset, tokenizer):
"""

# mock the call to HF datasets
load_dataset.return_value = [
{
"instruction": "Give three tips for staying healthy.",
"input": "",
"output": (
"1.Eat a balanced diet and make sure to include plenty of fruits and vegetables."
"2. Exercise regularly to keep your body active and strong."
"3. Get enough sleep and maintain a consistent sleep schedule."
),
}
]
load_dataset.return_value = Dataset.from_list(
[
{
"instruction": "Give three tips for staying healthy.",
"input": "",
"output": (
"1.Eat a balanced diet and make sure to include plenty of fruits and vegetables."
"2. Exercise regularly to keep your body active and strong."
"3. Get enough sleep and maintain a consistent sleep schedule."
),
}
]
)

alpaca_ds = alpaca_cleaned_dataset(tokenizer=tokenizer)
input, labels = alpaca_ds[0]
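
For completeness, a rough usage sketch of the new concatenation path. This is not one of the tests added in this PR; it assumes ConcatDataset only relies on __len__ and __getitem__ of the wrapped map-style datasets, so plain lists of (tokens, labels) pairs stand in for real instruct datasets.

from torchtune.datasets import ConcatDataset

# Toy stand-ins for tokenized datasets (assumption: any sized, indexable
# collection works here).
ds_a = [([1, 2, 3], [1, 2, 3]), ([4, 5], [4, 5])]
ds_b = [([6, 7], [6, 7])]

combined = ConcatDataset(datasets=[ds_a, ds_b])

assert len(combined) == len(ds_a) + len(ds_b)  # lengths add up
assert combined[2] == ([6, 7], [6, 7])         # index 2 falls through to ds_b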