44# This source code is licensed under the BSD-style license found in the
55# LICENSE file in the root directory of this source tree.
66
7- from typing import Any , Dict , List , Mapping , Optional
7+ from typing import Any , Callable , Dict , List , Mapping , Optional
88
99import numpy as np
1010from datasets import load_dataset
@@ -86,6 +86,9 @@ class requires the dataset to have "chosen" and "rejected" model responses. Thes
8686 Since PreferenceDataset only supports text data, it requires a
8787 :class:`~torchtune.modules.tokenizers.ModelTokenizer` instead of the ``model_transform`` in
8888 :class:`~torchtune.datasets.SFTDataset`.
89+ filter_fn (Optional[Callable]): callable used to filter the dataset prior to any pre-processing. See
90+ the Hugging Face `docs <https://huggingface.co/docs/datasets/v2.20.0/process#select-and-filter>`_ for more
91+ details. Default is None, in which case no filtering is applied.
8992 **load_dataset_kwargs (Dict[str, Any]): additional keyword arguments to pass to ``load_dataset``. See Hugging
9093 Face's `API ref <https://huggingface.co/docs/datasets/en/package_reference/loading_methods#datasets.load_dataset>`_
9194 for more details.
@@ -97,12 +100,16 @@ def __init__(
97100 source : str ,
98101 message_transform : Transform ,
99102 tokenizer : ModelTokenizer ,
103+ filter_fn : Optional [Callable ] = None ,
100104 ** load_dataset_kwargs : Dict [str , Any ],
101105 ) -> None :
102106 self ._tokenizer = tokenizer
103107 self ._message_transform = message_transform
104108 self ._data = load_dataset (source , ** load_dataset_kwargs )
105109
110+ if filter_fn is not None :
111+ self ._data = self ._data .filter (filter_fn )
112+
# Number of samples in ``self._data``, i.e. the loaded dataset after any ``filter_fn`` filtering applied in ``__init__``.
106113 def __len__ (self ):
107114 return len (self ._data )
108115
@@ -149,6 +156,7 @@ def preference_dataset(
149156 column_map : Optional [Dict [str , str ]] = None ,
150157 train_on_input : bool = False ,
151158 new_system_prompt : Optional [str ] = None ,
159+ filter_fn : Optional [Callable ] = None ,
152160 split : str = "train" ,
153161 ** load_dataset_kwargs : Dict [str , Any ],
154162) -> PreferenceDataset :
@@ -214,6 +222,9 @@ def preference_dataset(
214222 new_system_prompt (Optional[str]): if specified, prepend a system message to every sample for both chosen
215223 and rejected. This can serve as instructions to guide the model response. Setting this will OVERRIDE
216224 any system messages already present in the dataset. Default is None.
225+ filter_fn (Optional[Callable]): callable used to filter the dataset prior to any pre-processing. See
226+ the Hugging Face `docs <https://huggingface.co/docs/datasets/v2.20.0/process#select-and-filter>`_ for more
227+ details. Default is None.
217228 split (str): ``split`` argument for ``datasets.load_dataset``. You can use this argument to load a subset
218229 of a given split, e.g. ``split="train[:10%]"``. Default is "train".
219230 **load_dataset_kwargs (Dict[str, Any]): additional keyword arguments to pass to ``load_dataset``.
@@ -291,6 +302,7 @@ def preference_dataset(
291302 source = source ,
292303 message_transform = message_transform ,
293304 tokenizer = tokenizer ,
305+ filter_fn = filter_fn ,
294306 split = split ,
295307 ** load_dataset_kwargs ,
296308 )
0 commit comments