Commit 5de5001

Add filter_fn to all generic dataset classes and builders API (#1789)
1 parent 8488725 commit 5de5001

14 files changed, +92 -14 lines changed
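
The commit threads one new keyword argument, ``filter_fn``, through every generic dataset class and builder and forwards it to Hugging Face ``Dataset.filter`` before any message transforms or tokenization run. A minimal sketch of the resulting API using the Alpaca builder (the tokenizer path and the predicate are illustrative assumptions, not part of this commit):

from torchtune.datasets import alpaca_dataset
from torchtune.models.llama2 import llama2_tokenizer

# Hypothetical predicate: keep only samples with a non-empty "input" column.
def has_context(example) -> bool:
    return len(example["input"]) > 0

tokenizer = llama2_tokenizer("/path/to/tokenizer.model")  # assumed local path
ds = alpaca_dataset(
    tokenizer,
    filter_fn=has_context,  # applied to the raw dataset before any pre-processing
)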

torchtune/datasets/_alpaca.py

Lines changed: 6 additions & 1 deletion
@@ -6,7 +6,7 @@

 from functools import partial

-from typing import Any, Dict, Optional, Union
+from typing import Any, Callable, Dict, Optional, Union

 from torchtune.data._messages import AlpacaToMessages

@@ -22,6 +22,7 @@ def alpaca_dataset(
     column_map: Optional[Dict[str, str]] = None,
     train_on_input: bool = True,
     packed: bool = False,
+    filter_fn: Optional[Callable] = None,
     split: str = "train",
     **load_dataset_kwargs: Dict[str, Any],
 ) -> Union[SFTDataset, PackedDataset]:
@@ -52,6 +53,9 @@ def alpaca_dataset(
             the default column names ``"instruction``, ``"input"``, and ``"output"`` in ``tatsu-lab/alpaca``.
         train_on_input (bool): Whether the model is trained on the prompt or not. Default is False.
         packed (bool): Whether or not to pack the dataset to ``max_seq_len`` prior to training. Default is False.
+        filter_fn (Optional[Callable]): callable used to filter the dataset prior to any pre-processing. See
+            the Hugging Face `docs <https://huggingface.co/docs/datasets/v2.20.0/process#select-and-filter>`_ for more
+            details.
         split (str): ``split`` argument for ``datasets.load_dataset``. You can use this argument to load a subset
             of a given split, e.g. ``split="train[:10%]"``. Default is "train".
         **load_dataset_kwargs (Dict[str, Any]): additional keyword arguments to pass to ``load_dataset``. See Hugging
@@ -78,6 +82,7 @@ def alpaca_dataset(
         source=source,
         message_transform=message_transform,
         model_transform=tokenizer,
+        filter_fn=filter_fn,
         split=split,
         **load_dataset_kwargs,
     )

torchtune/datasets/_chat.py

Lines changed: 6 additions & 1 deletion
@@ -4,7 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

-from typing import Any, Dict, Optional, Union
+from typing import Any, Callable, Dict, Optional, Union

 from torchtune.data._messages import OpenAIToMessages, ShareGPTToMessages
 from torchtune.datasets._packed import PackedDataset
@@ -21,6 +21,7 @@ def chat_dataset(
     train_on_input: bool = False,
     new_system_prompt: Optional[str] = None,
     packed: bool = False,
+    filter_fn: Optional[Callable] = None,
     split: str = "train",
     **load_dataset_kwargs: Dict[str, Any],
 ) -> Union[SFTDataset, PackedDataset]:
@@ -82,6 +83,9 @@ def chat_dataset(
         new_system_prompt (Optional[str]): if specified, prepend a system message. This can
             serve as instructions to guide the model response. Default is None.
         packed (bool): Whether or not to pack the dataset to ``max_seq_len`` prior to training. Default is False.
+        filter_fn (Optional[Callable]): callable used to filter the dataset prior to any pre-processing. See
+            the Hugging Face `docs <https://huggingface.co/docs/datasets/v2.20.0/process#select-and-filter>`_ for more
+            details.
         split (str): ``split`` argument for ``datasets.load_dataset``. You can use this argument to load a subset
             of a given split, e.g. ``split="train[:10%]"``. Default is "train".
         **load_dataset_kwargs (Dict[str, Any]): additional keyword arguments to pass to ``load_dataset``,
@@ -171,6 +175,7 @@ def chat_dataset(
         message_transform=message_transform,
         model_transform=tokenizer,
         split=split,
+        filter_fn=filter_fn,
         **load_dataset_kwargs,
     )
     if packed:
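
Since ``chat_dataset`` is the generic conversational builder, the predicate sees raw conversation rows. A sketch that drops single-turn samples, assuming a local ShareGPT-style JSON file (the file name, column name, and predicate are hypothetical; the remaining arguments follow the ``chat_dataset`` docstring):

from torchtune.datasets import chat_dataset
from torchtune.models.llama2 import llama2_tokenizer

# Hypothetical predicate: keep only conversations with at least two turns.
def multi_turn(example) -> bool:
    return len(example["conversations"]) >= 2

tokenizer = llama2_tokenizer("/path/to/tokenizer.model")  # assumed local path
ds = chat_dataset(
    tokenizer,
    source="json",
    data_files="chats.json",  # hypothetical local file in ShareGPT format
    conversation_column="conversations",
    conversation_style="sharegpt",
    filter_fn=multi_turn,
)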

torchtune/datasets/_cnn_dailymail.py

Lines changed: 6 additions & 1 deletion
@@ -4,7 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

-from typing import Any, Dict, Optional
+from typing import Any, Callable, Dict, Optional

 from torchtune.datasets._text_completion import TextCompletionDataset

@@ -15,6 +15,7 @@ def cnn_dailymail_articles_dataset(
     tokenizer: ModelTokenizer,
     source: str = "ccdv/cnn_dailymail",
     max_seq_len: Optional[int] = None,
+    filter_fn: Optional[Callable] = None,
     split: str = "train",
     **load_dataset_kwargs: Dict[str, Any],
 ) -> TextCompletionDataset:
@@ -30,6 +31,9 @@ def cnn_dailymail_articles_dataset(
         max_seq_len (Optional[int]): Maximum number of tokens in the returned input and label token id lists.
             Default is None, disabling truncation. We recommend setting this to the highest you can fit in memory
             and is supported by the model. For example, llama2-7B supports up to 4096 for sequence length.
+        filter_fn (Optional[Callable]): callable used to filter the dataset prior to any pre-processing. See
+            the Hugging Face `docs <https://huggingface.co/docs/datasets/v2.20.0/process#select-and-filter>`_ for more
+            details.
         split (str): ``split`` argument for ``datasets.load_dataset``. You can use this argument to load a subset
             of a given split, e.g. ``split="train[:10%]"``. Default is "train".
         **load_dataset_kwargs (Dict[str, Any]): additional keyword arguments to pass to ``load_dataset``.
@@ -43,6 +47,7 @@ def cnn_dailymail_articles_dataset(
         source=source,
         column="article",
         max_seq_len=max_seq_len,
+        filter_fn=filter_fn,
         split=split,
         # This is used to specify the version of the dataset, a required argument
         # by the cnn_dailymail dataset builder:
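
Because ``filter_fn`` runs on raw rows before tokenization, it can cheaply drop articles that would be heavily truncated at ``max_seq_len``. A sketch, with an arbitrary 20,000-character threshold chosen purely for illustration:

from torchtune.datasets import cnn_dailymail_articles_dataset
from torchtune.models.llama2 import llama2_tokenizer

# Hypothetical predicate: skip raw articles longer than ~20k characters.
def short_enough(example) -> bool:
    return len(example["article"]) < 20_000

tokenizer = llama2_tokenizer("/path/to/tokenizer.model")  # assumed local path
ds = cnn_dailymail_articles_dataset(
    tokenizer,
    max_seq_len=4096,
    filter_fn=short_enough,
)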

torchtune/datasets/_grammar.py

Lines changed: 6 additions & 1 deletion
@@ -5,7 +5,7 @@
 # LICENSE file in the root directory of this source tree.


-from typing import Any, Dict, Optional, Union
+from typing import Any, Callable, Dict, Optional, Union

 from torchtune.data import InputOutputToMessages
 from torchtune.datasets._packed import PackedDataset
@@ -21,6 +21,7 @@ def grammar_dataset(
     train_on_input: bool = False,
     new_system_prompt: Optional[str] = None,
     packed: bool = False,
+    filter_fn: Optional[Callable] = None,
     split: str = "train",
     **load_dataset_kwargs: Dict[str, Any],
 ) -> Union[SFTDataset, PackedDataset]:
@@ -53,6 +54,9 @@ def grammar_dataset(
             serve as instructions to guide the model response. Setting this will OVERRIDE any system
             messages already present in the dataset. Default is None.
         packed (bool): Whether or not to pack the dataset to tokenizer's ``max_seq_len`` prior to training. Default is False.
+        filter_fn (Optional[Callable]): callable used to filter the dataset prior to any pre-processing. See
+            the Hugging Face `docs <https://huggingface.co/docs/datasets/v2.20.0/process#select-and-filter>`_ for more
+            details.
         split (str): ``split`` argument for ``datasets.load_dataset``. You can use this argument to load a subset
             of a given split, e.g. ``split="train[:10%]"``. Default is "train".
         **load_dataset_kwargs (Dict[str, Any]): additional keyword arguments to pass to ``load_dataset``.
@@ -79,6 +83,7 @@ def grammar_dataset(
         source=source,
         message_transform=message_transform,
         model_transform=tokenizer,
+        filter_fn=filter_fn,
         split=split,
         **load_dataset_kwargs,
     )

torchtune/datasets/_hh_rlhf_helpful.py

Lines changed: 6 additions & 1 deletion
@@ -4,7 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

-from typing import Any, Dict, Optional
+from typing import Any, Callable, Dict, Optional

 from torchtune.data import ChosenRejectedToMessages
 from torchtune.datasets._preference import PreferenceDataset
@@ -18,6 +18,7 @@ def hh_rlhf_helpful_dataset(
     column_map: Optional[Dict[str, str]] = None,
     train_on_input: bool = False,
     new_system_prompt: Optional[str] = None,
+    filter_fn: Optional[Callable] = None,
     split: str = "train",
     **load_dataset_kwargs: Dict[str, Any],
 ) -> PreferenceDataset:
@@ -42,6 +43,9 @@ def hh_rlhf_helpful_dataset(
         new_system_prompt (Optional[str]): if specified, prepend a system message to every sample for both chosen
             and rejected. This can serve as instructions to guide the model response. Setting this will OVERRIDE
             any system messages already present in the dataset. Default is None.
+        filter_fn (Optional[Callable]): callable used to filter the dataset prior to any pre-processing. See
+            the Hugging Face `docs <https://huggingface.co/docs/datasets/v2.20.0/process#select-and-filter>`_ for more
+            details.
         split (str): ``split`` argument for ``datasets.load_dataset``. You can use this argument to load a subset
             of a given split, e.g. ``split="train[:10%]"``. Default is "train".
         **load_dataset_kwargs (Dict[str, Any]): additional keyword arguments to pass to ``load_dataset``.
@@ -60,6 +64,7 @@ def hh_rlhf_helpful_dataset(
         source=source,
         message_transform=message_transform,
         tokenizer=tokenizer,
+        filter_fn=filter_fn,
         split=split,
         **load_dataset_kwargs,
     )

torchtune/datasets/_instruct.py

Lines changed: 7 additions & 1 deletion
@@ -4,7 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

-from typing import Any, Dict, Optional, Union
+from typing import Any, Callable, Dict, Optional, Union

 from torchtune.data import InputOutputToMessages
 from torchtune.datasets._packed import PackedDataset
@@ -20,6 +20,7 @@ def instruct_dataset(
     train_on_input: bool = False,
     new_system_prompt: Optional[str] = None,
     packed: bool = False,
+    filter_fn: Optional[Callable] = None,
     split: str = "train",
     **load_dataset_kwargs: Dict[str, Any],
 ) -> Union[SFTDataset, PackedDataset]:
@@ -65,6 +66,9 @@ def instruct_dataset(
         new_system_prompt (Optional[str]): if specified, prepend a system message. This can
             serve as instructions to guide the model response. Default is None.
         packed (bool): Whether or not to pack the dataset to tokenizer's ``max_seq_len`` prior to training. Default is False.
+        filter_fn (Optional[Callable]): callable used to filter the dataset prior to any pre-processing. See
+            the Hugging Face `docs <https://huggingface.co/docs/datasets/v2.20.0/process#select-and-filter>`_ for more
+            details.
         split (str): ``split`` argument for ``datasets.load_dataset``. You can use this argument to load a subset
             of a given split, e.g. ``split="train[:10%]"``. Default is "train".
         **load_dataset_kwargs (Dict[str, Any]): additional keyword arguments to pass to ``load_dataset``,
@@ -137,9 +141,11 @@ def instruct_dataset(
         source=source,
         message_transform=message_transform,
         model_transform=tokenizer,
+        filter_fn=filter_fn,
         split=split,
         **load_dataset_kwargs,
     )
+
     if packed:
         if tokenizer.max_seq_len is None:
             raise ValueError(
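
For the fully generic instruct builder, ``filter_fn`` composes with any custom source passed through to ``load_dataset``. A sketch assuming a hypothetical local ``data.json`` whose rows carry the default "input"/"output" columns plus a made-up "quality" score:

from torchtune.datasets import instruct_dataset
from torchtune.models.llama2 import llama2_tokenizer

# Hypothetical predicate: keep rows whose invented "quality" column clears a threshold.
def good_enough(example) -> bool:
    return example["quality"] >= 3

tokenizer = llama2_tokenizer("/path/to/tokenizer.model")  # assumed local path
ds = instruct_dataset(
    tokenizer,
    source="json",
    data_files="data.json",  # hypothetical file; forwarded to load_dataset
    filter_fn=good_enough,
)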

torchtune/datasets/_preference.py

Lines changed: 13 additions & 1 deletion
@@ -4,7 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

-from typing import Any, Dict, List, Mapping, Optional
+from typing import Any, Callable, Dict, List, Mapping, Optional

 import numpy as np
 from datasets import load_dataset
@@ -86,6 +86,9 @@ class requires the dataset to have "chosen" and "rejected" model responses. Thes
             Since PreferenceDataset only supports text data, it requires a
             :class:`~torchtune.modules.tokenizers.ModelTokenizer` instead of the ``model_transform`` in
             :class:`~torchtune.datasets.SFTDataset`.
+        filter_fn (Optional[Callable]): callable used to filter the dataset prior to any pre-processing. See
+            the Hugging Face `docs <https://huggingface.co/docs/datasets/v2.20.0/process#select-and-filter>`_ for more
+            details.
         **load_dataset_kwargs (Dict[str, Any]): additional keyword arguments to pass to ``load_dataset``. See Hugging
             Face's `API ref <https://huggingface.co/docs/datasets/en/package_reference/loading_methods#datasets.load_dataset>`_
             for more details.
@@ -97,12 +100,16 @@ def __init__(
         source: str,
         message_transform: Transform,
         tokenizer: ModelTokenizer,
+        filter_fn: Optional[Callable] = None,
         **load_dataset_kwargs: Dict[str, Any],
     ) -> None:
         self._tokenizer = tokenizer
         self._message_transform = message_transform
         self._data = load_dataset(source, **load_dataset_kwargs)

+        if filter_fn is not None:
+            self._data = self._data.filter(filter_fn)
+
     def __len__(self):
         return len(self._data)

@@ -149,6 +156,7 @@ def preference_dataset(
     column_map: Optional[Dict[str, str]] = None,
     train_on_input: bool = False,
     new_system_prompt: Optional[str] = None,
+    filter_fn: Optional[Callable] = None,
     split: str = "train",
     **load_dataset_kwargs: Dict[str, Any],
 ) -> PreferenceDataset:
@@ -214,6 +222,9 @@ def preference_dataset(
         new_system_prompt (Optional[str]): if specified, prepend a system message to every sample for both chosen
             and rejected. This can serve as instructions to guide the model response. Setting this will OVERRIDE
             any system messages already present in the dataset. Default is None.
+        filter_fn (Optional[Callable]): callable used to filter the dataset prior to any pre-processing. See
+            the Hugging Face `docs <https://huggingface.co/docs/datasets/v2.20.0/process#select-and-filter>`_ for more
+            details.
         split (str): ``split`` argument for ``datasets.load_dataset``. You can use this argument to load a subset
             of a given split, e.g. ``split="train[:10%]"``. Default is "train".
         **load_dataset_kwargs (Dict[str, Any]): additional keyword arguments to pass to ``load_dataset``.
@@ -291,6 +302,7 @@ def preference_dataset(
         source=source,
         message_transform=message_transform,
         tokenizer=tokenizer,
+        filter_fn=filter_fn,
         split=split,
         **load_dataset_kwargs,
     )
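
The ``PreferenceDataset`` hunk above is the heart of the commit: when a predicate is supplied, the freshly loaded dataset is passed straight through ``datasets.Dataset.filter`` before any message transform or tokenization. The same behavior can be checked in isolation against the Hugging Face API (the toy rows below are invented):

from datasets import Dataset

# Two made-up preference pairs, one with an empty "chosen" response.
data = Dataset.from_dict(
    {"chosen": ["helpful answer", ""], "rejected": ["curt answer", "rude answer"]}
)

# Mirrors the new constructor logic: filtering runs before any message transform.
filtered = data.filter(lambda example: len(example["chosen"]) > 0)
print(len(filtered))  # 1 -- the row with an empty "chosen" field is dropped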

torchtune/datasets/_samsum.py

Lines changed: 6 additions & 1 deletion
@@ -5,7 +5,7 @@
 # LICENSE file in the root directory of this source tree.


-from typing import Any, Dict, Optional, Union
+from typing import Any, Callable, Dict, Optional, Union

 from torchtune.data import InputOutputToMessages
 from torchtune.datasets._packed import PackedDataset
@@ -21,6 +21,7 @@ def samsum_dataset(
     train_on_input: bool = False,
     new_system_prompt: Optional[str] = None,
     packed: bool = False,
+    filter_fn: Optional[Callable] = None,
     split: str = "train",
     **load_dataset_kwargs: Dict[str, Any],
 ) -> Union[SFTDataset, PackedDataset]:
@@ -53,6 +54,9 @@ def samsum_dataset(
             serve as instructions to guide the model response. Setting this will OVERRIDE any system
             messages already present in the dataset. Default is None.
         packed (bool): Whether or not to pack the dataset to tokenizer's ``max_seq_len`` prior to training. Default is False.
+        filter_fn (Optional[Callable]): callable used to filter the dataset prior to any pre-processing. See
+            the Hugging Face `docs <https://huggingface.co/docs/datasets/v2.20.0/process#select-and-filter>`_ for more
+            details.
         split (str): ``split`` argument for ``datasets.load_dataset``. You can use this argument to load a subset
             of a given split, e.g. ``split="train[:10%]"``. Default is "train".
         **load_dataset_kwargs (Dict[str, Any]): additional keyword arguments to pass to ``load_dataset``.
@@ -81,6 +85,7 @@ def samsum_dataset(
         message_transform=message_transform,
         model_transform=tokenizer,
         split=split,
+        filter_fn=filter_fn,
         **load_dataset_kwargs,
     )
     if packed:

torchtune/datasets/_slimorca.py

Lines changed: 6 additions & 1 deletion
@@ -4,7 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

-from typing import Any, Dict, Optional, Union
+from typing import Any, Callable, Dict, Optional, Union

 from torchtune.data import ShareGPTToMessages
 from torchtune.datasets._packed import PackedDataset
@@ -21,6 +21,7 @@ def slimorca_dataset(
     train_on_input: bool = False,
     new_system_prompt: Optional[str] = None,
     packed: bool = False,
+    filter_fn: Optional[Callable] = None,
     split: str = "train",
     **load_dataset_kwargs: Dict[str, Any],
 ) -> Union[SFTDataset, PackedDataset]:
@@ -50,6 +51,9 @@ def slimorca_dataset(
             serve as instructions to guide the model response. Setting this will OVERRIDE any system
             messages already present in the dataset. Default is None.
         packed (bool): Whether or not to pack the dataset to tokenizer's ``max_seq_len`` prior to training. Default is False.
+        filter_fn (Optional[Callable]): callable used to filter the dataset prior to any pre-processing. See
+            the Hugging Face `docs <https://huggingface.co/docs/datasets/v2.20.0/process#select-and-filter>`_ for more
+            details.
         split (str): ``split`` argument for ``datasets.load_dataset``. You can use this argument to load a subset
             of a given split, e.g. ``split="train[:10%]"``. Default is "train".
         **load_dataset_kwargs (Dict[str, Any]): additional keyword arguments to pass to ``load_dataset``.
@@ -79,6 +83,7 @@ def slimorca_dataset(
         source=source,
         message_transform=message_transform,
         model_transform=tokenizer,
+        filter_fn=filter_fn,
         split=split,
         **load_dataset_kwargs,
     )

torchtune/datasets/_stack_exchange_paired.py

Lines changed: 6 additions & 1 deletion
@@ -4,7 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

-from typing import Any, Dict, Mapping, Optional
+from typing import Any, Callable, Dict, Mapping, Optional

 from torchtune.data import Message
 from torchtune.datasets._preference import PreferenceDataset
@@ -78,6 +78,7 @@ def stack_exchange_paired_dataset(
     source: str = "lvwerra/stack-exchange-paired",
     column_map: Optional[Dict[str, str]] = None,
     train_on_input: bool = False,
+    filter_fn: Optional[Callable] = None,
     split: str = "train",
     **load_dataset_kwargs: Dict[str, Any],
 ) -> PreferenceDataset:
@@ -100,6 +101,9 @@ def stack_exchange_paired_dataset(
             Keys should be "prompt", "chosen", and "rejected" and values should be the actual column names.
             Default is None, keeping the default column names.
         train_on_input (bool): Whether the model is trained on the prompt or not. Default is False.
+        filter_fn (Optional[Callable]): callable used to filter the dataset prior to any pre-processing. See
+            the Hugging Face `docs <https://huggingface.co/docs/datasets/v2.20.0/process#select-and-filter>`_ for more
+            details.
         split (str): ``split`` argument for ``datasets.load_dataset``. You can use this argument to load a subset
             of a given split, e.g. ``split="train[:10%]"``. Default is "train".
         **load_dataset_kwargs (Dict[str, Any]): additional keyword arguments to pass to ``load_dataset``.
@@ -122,6 +126,7 @@ def stack_exchange_paired_dataset(
         source=source,
         message_transform=message_transform,
         tokenizer=tokenizer,
+        filter_fn=filter_fn,
         split=split,
         data_dir="data/rl",
         **load_dataset_kwargs,
