7 changes: 6 additions & 1 deletion torchtune/datasets/_alpaca.py
@@ -6,7 +6,7 @@

from functools import partial

from typing import Any, Dict, Optional, Union
from typing import Any, Callable, Dict, Optional, Union

from torchtune.data._messages import AlpacaToMessages

@@ -22,6 +22,7 @@ def alpaca_dataset(
column_map: Optional[Dict[str, str]] = None,
train_on_input: bool = True,
packed: bool = False,
filter_fn: Optional[Callable] = None,
split: str = "train",
**load_dataset_kwargs: Dict[str, Any],
) -> Union[SFTDataset, PackedDataset]:
@@ -52,6 +53,9 @@
the default column names ``"instruction"``, ``"input"``, and ``"output"`` in ``tatsu-lab/alpaca``.
train_on_input (bool): Whether the model is trained on the prompt or not. Default is True.
packed (bool): Whether or not to pack the dataset to ``max_seq_len`` prior to training. Default is False.
filter_fn (Optional[Callable]): callable used to filter the dataset prior to any pre-processing. See
the Hugging Face `docs <https://huggingface.co/docs/datasets/v2.20.0/process#select-and-filter>`_ for more
details.
split (str): ``split`` argument for ``datasets.load_dataset``. You can use this argument to load a subset
of a given split, e.g. ``split="train[:10%]"``. Default is "train".
**load_dataset_kwargs (Dict[str, Any]): additional keyword arguments to pass to ``load_dataset``. See Hugging
@@ -78,6 +82,7 @@ def alpaca_dataset(
source=source,
message_transform=message_transform,
model_transform=tokenizer,
filter_fn=filter_fn,
split=split,
**load_dataset_kwargs,
)
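
For context (not part of the diff): a minimal sketch of how the new argument could be used with ``alpaca_dataset``. The tokenizer path and the predicate are illustrative, not from this PR.

```python
from torchtune.datasets import alpaca_dataset
from torchtune.models.llama2 import llama2_tokenizer

# Hypothetical tokenizer path -- substitute your own checkpoint.
tokenizer = llama2_tokenizer("/tmp/tokenizer.model")

# Keep only samples that actually provide an "input" field.
ds = alpaca_dataset(
    tokenizer=tokenizer,
    filter_fn=lambda example: example["input"] != "",
)
```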
7 changes: 6 additions & 1 deletion torchtune/datasets/_chat.py
@@ -4,7 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Dict, Optional, Union
from typing import Any, Callable, Dict, Optional, Union

from torchtune.data._messages import OpenAIToMessages, ShareGPTToMessages
from torchtune.datasets._packed import PackedDataset
@@ -21,6 +21,7 @@ def chat_dataset(
train_on_input: bool = False,
new_system_prompt: Optional[str] = None,
packed: bool = False,
filter_fn: Optional[Callable] = None,
split: str = "train",
**load_dataset_kwargs: Dict[str, Any],
) -> Union[SFTDataset, PackedDataset]:
@@ -82,6 +83,9 @@
new_system_prompt (Optional[str]): if specified, prepend a system message. This can
serve as instructions to guide the model response. Default is None.
packed (bool): Whether or not to pack the dataset to ``max_seq_len`` prior to training. Default is False.
filter_fn (Optional[Callable]): callable used to filter the dataset prior to any pre-processing. See
the Hugging Face `docs <https://huggingface.co/docs/datasets/v2.20.0/process#select-and-filter>`_ for more
details.
split (str): ``split`` argument for ``datasets.load_dataset``. You can use this argument to load a subset
of a given split, e.g. ``split="train[:10%]"``. Default is "train".
**load_dataset_kwargs (Dict[str, Any]): additional keyword arguments to pass to ``load_dataset``,
@@ -171,6 +175,7 @@ def chat_dataset(
message_transform=message_transform,
model_transform=tokenizer,
split=split,
filter_fn=filter_fn,
**load_dataset_kwargs,
)
if packed:
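
Similarly (illustrative only, not from this PR): filtering a ShareGPT-style chat dataset by conversation length before any tokenization. The source, the 8-turn cutoff, and the ``conversations`` column name are assumptions based on ShareGPT defaults.

```python
from torchtune.datasets import chat_dataset
from torchtune.models.llama2 import llama2_tokenizer

tokenizer = llama2_tokenizer("/tmp/tokenizer.model")  # hypothetical path

ds = chat_dataset(
    tokenizer=tokenizer,
    source="Open-Orca/SlimOrca-Dedup",
    conversation_column="conversations",
    conversation_style="sharegpt",
    # Drop samples with more than 8 turns -- arbitrary illustrative cutoff.
    filter_fn=lambda example: len(example["conversations"]) <= 8,
)
```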
7 changes: 6 additions & 1 deletion torchtune/datasets/_cnn_dailymail.py
@@ -4,7 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Dict, Optional
from typing import Any, Callable, Dict, Optional

from torchtune.datasets._text_completion import TextCompletionDataset

@@ -15,6 +15,7 @@ def cnn_dailymail_articles_dataset(
tokenizer: ModelTokenizer,
source: str = "ccdv/cnn_dailymail",
max_seq_len: Optional[int] = None,
filter_fn: Optional[Callable] = None,
split: str = "train",
**load_dataset_kwargs: Dict[str, Any],
) -> TextCompletionDataset:
@@ -30,6 +31,9 @@
max_seq_len (Optional[int]): Maximum number of tokens in the returned input and label token id lists.
Default is None, disabling truncation. We recommend setting this to the highest value that fits in memory
and is supported by the model. For example, llama2-7B supports up to 4096 for sequence length.
filter_fn (Optional[Callable]): callable used to filter the dataset prior to any pre-processing. See
the Hugging Face `docs <https://huggingface.co/docs/datasets/v2.20.0/process#select-and-filter>`_ for more
details.
split (str): ``split`` argument for ``datasets.load_dataset``. You can use this argument to load a subset
of a given split, e.g. ``split="train[:10%]"``. Default is "train".
**load_dataset_kwargs (Dict[str, Any]): additional keyword arguments to pass to ``load_dataset``.
@@ -43,6 +47,7 @@
source=source,
column="article",
max_seq_len=max_seq_len,
filter_fn=filter_fn,
split=split,
# This is used to specify the version of the dataset, a required argument
# of the cnn_dailymail dataset builder:
7 changes: 6 additions & 1 deletion torchtune/datasets/_grammar.py
@@ -5,7 +5,7 @@
# LICENSE file in the root directory of this source tree.


from typing import Any, Dict, Optional, Union
from typing import Any, Callable, Dict, Optional, Union

from torchtune.data import InputOutputToMessages
from torchtune.datasets._packed import PackedDataset
@@ -21,6 +21,7 @@ def grammar_dataset(
train_on_input: bool = False,
new_system_prompt: Optional[str] = None,
packed: bool = False,
filter_fn: Optional[Callable] = None,
split: str = "train",
**load_dataset_kwargs: Dict[str, Any],
) -> Union[SFTDataset, PackedDataset]:
@@ -53,6 +54,9 @@
serve as instructions to guide the model response. Setting this will OVERRIDE any system
messages already present in the dataset. Default is None.
packed (bool): Whether or not to pack the dataset to tokenizer's ``max_seq_len`` prior to training. Default is False.
filter_fn (Optional[Callable]): callable used to filter the dataset prior to any pre-processing. See
the Hugging Face `docs <https://huggingface.co/docs/datasets/v2.20.0/process#select-and-filter>`_ for more
details.
split (str): ``split`` argument for ``datasets.load_dataset``. You can use this argument to load a subset
of a given split, e.g. ``split="train[:10%]"``. Default is "train".
**load_dataset_kwargs (Dict[str, Any]): additional keyword arguments to pass to ``load_dataset``.
@@ -79,6 +83,7 @@
source=source,
message_transform=message_transform,
model_transform=tokenizer,
filter_fn=filter_fn,
split=split,
**load_dataset_kwargs,
)
7 changes: 6 additions & 1 deletion torchtune/datasets/_hh_rlhf_helpful.py
@@ -4,7 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Dict, Optional
from typing import Any, Callable, Dict, Optional

from torchtune.data import ChosenRejectedToMessages
from torchtune.datasets._preference import PreferenceDataset
@@ -18,6 +18,7 @@ def hh_rlhf_helpful_dataset(
column_map: Optional[Dict[str, str]] = None,
train_on_input: bool = False,
new_system_prompt: Optional[str] = None,
filter_fn: Optional[Callable] = None,
split: str = "train",
**load_dataset_kwargs: Dict[str, Any],
) -> PreferenceDataset:
@@ -42,6 +43,9 @@
new_system_prompt (Optional[str]): if specified, prepend a system message to every sample for both chosen
and rejected. This can serve as instructions to guide the model response. Setting this will OVERRIDE
any system messages already present in the dataset. Default is None.
filter_fn (Optional[Callable]): callable used to filter the dataset prior to any pre-processing. See
the Hugging Face `docs <https://huggingface.co/docs/datasets/v2.20.0/process#select-and-filter>`_ for more
details.
split (str): ``split`` argument for ``datasets.load_dataset``. You can use this argument to load a subset
of a given split, e.g. ``split="train[:10%]"``. Default is "train".
**load_dataset_kwargs (Dict[str, Any]): additional keyword arguments to pass to ``load_dataset``.
@@ -60,6 +64,7 @@
source=source,
message_transform=message_transform,
tokenizer=tokenizer,
filter_fn=filter_fn,
split=split,
**load_dataset_kwargs,
)
8 changes: 7 additions & 1 deletion torchtune/datasets/_instruct.py
@@ -4,7 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Dict, Optional, Union
from typing import Any, Callable, Dict, Optional, Union

from torchtune.data import InputOutputToMessages
from torchtune.datasets._packed import PackedDataset
@@ -20,6 +20,7 @@ def instruct_dataset(
train_on_input: bool = False,
new_system_prompt: Optional[str] = None,
packed: bool = False,
filter_fn: Optional[Callable] = None,
split: str = "train",
**load_dataset_kwargs: Dict[str, Any],
) -> Union[SFTDataset, PackedDataset]:
@@ -65,6 +66,9 @@
new_system_prompt (Optional[str]): if specified, prepend a system message. This can
serve as instructions to guide the model response. Default is None.
packed (bool): Whether or not to pack the dataset to tokenizer's ``max_seq_len`` prior to training. Default is False.
filter_fn (Optional[Callable]): callable used to filter the dataset prior to any pre-processing. See
the Hugging Face `docs <https://huggingface.co/docs/datasets/v2.20.0/process#select-and-filter>`_ for more
details.
split (str): ``split`` argument for ``datasets.load_dataset``. You can use this argument to load a subset
of a given split, e.g. ``split="train[:10%]"``. Default is "train".
**load_dataset_kwargs (Dict[str, Any]): additional keyword arguments to pass to ``load_dataset``,
@@ -137,9 +141,11 @@ def instruct_dataset(
source=source,
message_transform=message_transform,
model_transform=tokenizer,
filter_fn=filter_fn,
split=split,
**load_dataset_kwargs,
)

if packed:
if tokenizer.max_seq_len is None:
raise ValueError(
14 changes: 13 additions & 1 deletion torchtune/datasets/_preference.py
@@ -4,7 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Dict, List, Mapping, Optional
from typing import Any, Callable, Dict, List, Mapping, Optional

import numpy as np
from datasets import load_dataset
@@ -86,6 +86,9 @@ class requires the dataset to have "chosen" and "rejected" model responses. Thes
Since PreferenceDataset only supports text data, it requires a
:class:`~torchtune.modules.tokenizers.ModelTokenizer` instead of the ``model_transform`` in
:class:`~torchtune.datasets.SFTDataset`.
filter_fn (Optional[Callable]): callable used to filter the dataset prior to any pre-processing. See
the Hugging Face `docs <https://huggingface.co/docs/datasets/v2.20.0/process#select-and-filter>`_ for more
details.
**load_dataset_kwargs (Dict[str, Any]): additional keyword arguments to pass to ``load_dataset``. See Hugging
Face's `API ref <https://huggingface.co/docs/datasets/en/package_reference/loading_methods#datasets.load_dataset>`_
for more details.
@@ -97,12 +100,16 @@ def __init__(
source: str,
message_transform: Transform,
tokenizer: ModelTokenizer,
filter_fn: Optional[Callable] = None,
**load_dataset_kwargs: Dict[str, Any],
) -> None:
self._tokenizer = tokenizer
self._message_transform = message_transform
self._data = load_dataset(source, **load_dataset_kwargs)

if filter_fn is not None:
self._data = self._data.filter(filter_fn)

def __len__(self):
return len(self._data)

@@ -149,6 +156,7 @@ def preference_dataset(
column_map: Optional[Dict[str, str]] = None,
train_on_input: bool = False,
new_system_prompt: Optional[str] = None,
filter_fn: Optional[Callable] = None,
split: str = "train",
**load_dataset_kwargs: Dict[str, Any],
) -> PreferenceDataset:
@@ -214,6 +222,9 @@ def preference_dataset(
new_system_prompt (Optional[str]): if specified, prepend a system message to every sample for both chosen
and rejected. This can serve as instructions to guide the model response. Setting this will OVERRIDE
any system messages already present in the dataset. Default is None.
filter_fn (Optional[Callable]): callable used to filter the dataset prior to any pre-processing. See
the Hugging Face `docs <https://huggingface.co/docs/datasets/v2.20.0/process#select-and-filter>`_ for more
details.
split (str): ``split`` argument for ``datasets.load_dataset``. You can use this argument to load a subset
of a given split, e.g. ``split="train[:10%]"``. Default is "train".
**load_dataset_kwargs (Dict[str, Any]): additional keyword arguments to pass to ``load_dataset``.
@@ -291,6 +302,7 @@ def preference_dataset(
source=source,
message_transform=message_transform,
tokenizer=tokenizer,
filter_fn=filter_fn,
split=split,
**load_dataset_kwargs,
)
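
The mechanics here follow the Hugging Face ``Dataset.filter`` API: ``filter_fn`` receives a single example dict and returns a bool. A standalone sketch of the code path added above, using an arbitrary dataset and column for illustration:

```python
from datasets import load_dataset

def filter_fn(example):
    # Keep samples with a non-empty "output" column (column name is
    # specific to tatsu-lab/alpaca; adjust for your dataset).
    return len(example["output"]) > 0

data = load_dataset("tatsu-lab/alpaca", split="train")
data = data.filter(filter_fn)  # same call PreferenceDataset now makes
```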
7 changes: 6 additions & 1 deletion torchtune/datasets/_samsum.py
@@ -5,7 +5,7 @@
# LICENSE file in the root directory of this source tree.


from typing import Any, Dict, Optional, Union
from typing import Any, Callable, Dict, Optional, Union

from torchtune.data import InputOutputToMessages
from torchtune.datasets._packed import PackedDataset
@@ -21,6 +21,7 @@ def samsum_dataset(
train_on_input: bool = False,
new_system_prompt: Optional[str] = None,
packed: bool = False,
filter_fn: Optional[Callable] = None,
split: str = "train",
**load_dataset_kwargs: Dict[str, Any],
) -> Union[SFTDataset, PackedDataset]:
@@ -53,6 +54,9 @@
serve as instructions to guide the model response. Setting this will OVERRIDE any system
messages already present in the dataset. Default is None.
packed (bool): Whether or not to pack the dataset to tokenizer's ``max_seq_len`` prior to training. Default is False.
filter_fn (Optional[Callable]): callable used to filter the dataset prior to any pre-processing. See
the Hugging Face `docs <https://huggingface.co/docs/datasets/v2.20.0/process#select-and-filter>`_ for more
details.
split (str): ``split`` argument for ``datasets.load_dataset``. You can use this argument to load a subset
of a given split, e.g. ``split="train[:10%]"``. Default is "train".
**load_dataset_kwargs (Dict[str, Any]): additional keyword arguments to pass to ``load_dataset``.
Expand Down Expand Up @@ -81,6 +85,7 @@ def samsum_dataset(
message_transform=message_transform,
model_transform=tokenizer,
split=split,
filter_fn=filter_fn,
**load_dataset_kwargs,
)
if packed:
7 changes: 6 additions & 1 deletion torchtune/datasets/_slimorca.py
@@ -4,7 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Dict, Optional, Union
from typing import Any, Callable, Dict, Optional, Union

from torchtune.data import ShareGPTToMessages
from torchtune.datasets._packed import PackedDataset
@@ -21,6 +21,7 @@ def slimorca_dataset(
train_on_input: bool = False,
new_system_prompt: Optional[str] = None,
packed: bool = False,
filter_fn: Optional[Callable] = None,
split: str = "train",
**load_dataset_kwargs: Dict[str, Any],
) -> Union[SFTDataset, PackedDataset]:
@@ -50,6 +51,9 @@
serve as instructions to guide the model response. Setting this will OVERRIDE any system
messages already present in the dataset. Default is None.
packed (bool): Whether or not to pack the dataset to tokenizer's ``max_seq_len`` prior to training. Default is False.
filter_fn (Optional[Callable]): callable used to filter the dataset prior to any pre-processing. See
the Hugging Face `docs <https://huggingface.co/docs/datasets/v2.20.0/process#select-and-filter>`_ for more
details.
split (str): ``split`` argument for ``datasets.load_dataset``. You can use this argument to load a subset
of a given split, e.g. ``split="train[:10%]"``. Default is "train".
**load_dataset_kwargs (Dict[str, Any]): additional keyword arguments to pass to ``load_dataset``.
@@ -79,6 +83,7 @@
source=source,
message_transform=message_transform,
model_transform=tokenizer,
filter_fn=filter_fn,
split=split,
**load_dataset_kwargs,
)
7 changes: 6 additions & 1 deletion torchtune/datasets/_stack_exchange_paired.py
@@ -4,7 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Dict, Mapping, Optional
from typing import Any, Callable, Dict, Mapping, Optional

from torchtune.data import Message
from torchtune.datasets._preference import PreferenceDataset
@@ -78,6 +78,7 @@ def stack_exchange_paired_dataset(
source: str = "lvwerra/stack-exchange-paired",
column_map: Optional[Dict[str, str]] = None,
train_on_input: bool = False,
filter_fn: Optional[Callable] = None,
split: str = "train",
**load_dataset_kwargs: Dict[str, Any],
) -> PreferenceDataset:
@@ -100,6 +101,9 @@
Keys should be "prompt", "chosen", and "rejected" and values should be the actual column names.
Default is None, keeping the default column names.
train_on_input (bool): Whether the model is trained on the prompt or not. Default is False.
filter_fn (Optional[Callable]): callable used to filter the dataset prior to any pre-processing. See
the Hugging Face `docs <https://huggingface.co/docs/datasets/v2.20.0/process#select-and-filter>`_ for more
details.
split (str): ``split`` argument for ``datasets.load_dataset``. You can use this argument to load a subset
of a given split, e.g. ``split="train[:10%]"``. Default is "train".
**load_dataset_kwargs (Dict[str, Any]): additional keyword arguments to pass to ``load_dataset``.
@@ -122,6 +126,7 @@
source=source,
message_transform=message_transform,
tokenizer=tokenizer,
filter_fn=filter_fn,
split=split,
data_dir="data/rl",
**load_dataset_kwargs,