Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add split argument to Generator #7015

Merged
merged 11 commits into from
Jul 26, 2024
4 changes: 4 additions & 0 deletions src/datasets/arrow_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1068,6 +1068,7 @@ def from_generator(
keep_in_memory: bool = False,
gen_kwargs: Optional[dict] = None,
num_proc: Optional[int] = None,
split: Optional[NamedSplit] = None,
piercus marked this conversation as resolved.
Show resolved Hide resolved
**kwargs,
):
"""Create a Dataset from a generator.
Expand All @@ -1088,6 +1089,8 @@ def from_generator(
Number of processes when downloading and generating the dataset locally.
This is helpful if the dataset is made of multiple files. Multiprocessing is disabled by default.
If `num_proc` is greater than one, then all list values in `gen_kwargs` must be the same length. These values will be split between calls to the generator. The number of shards will be the minimum of the shortest list in `gen_kwargs` and `num_proc`.
split ([`NamedSplit`], defaults to `Split.TRAIN`):
    Split name to be assigned to the dataset.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This docstring should go below <Added version="2.7.0"/>, because the version added tag corresponds to the num_proc parameter above split.

I would suggest to align its type with the rest of the code as: ([`NamedSplit`], defaults to `Split.TRAIN`).

I would also add a specific version added tag for the split parameter: . We may eventually change this depending on the next release.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just used <Added version="2.21.0"/>, please cross-check


<Added version="2.7.0"/>
**kwargs (additional keyword arguments):
Expand Down Expand Up @@ -1126,6 +1129,7 @@ def from_generator(
keep_in_memory=keep_in_memory,
gen_kwargs=gen_kwargs,
num_proc=num_proc,
split=split,
**kwargs,
).read()

Expand Down
2 changes: 2 additions & 0 deletions src/datasets/io/abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def __init__(
keep_in_memory: bool = False,
streaming: bool = False,
num_proc: Optional[int] = None,
split: Optional[NamedSplit] = None,
piercus marked this conversation as resolved.
Show resolved Hide resolved
**kwargs,
):
self.features = features
Expand All @@ -47,6 +48,7 @@ def __init__(
self.streaming = streaming
self.num_proc = num_proc
self.kwargs = kwargs
self.split = split if split else "train"
piercus marked this conversation as resolved.
Show resolved Hide resolved

@abstractmethod
def read(self) -> Union[Dataset, IterableDataset]:
Expand Down
9 changes: 6 additions & 3 deletions src/datasets/io/generator.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Callable, Optional

from .. import Features
from .. import Features, NamedSplit
from ..packaged_modules.generator.generator import Generator
from .abc import AbstractDatasetInputStream

Expand All @@ -15,6 +15,7 @@ def __init__(
streaming: bool = False,
gen_kwargs: Optional[dict] = None,
num_proc: Optional[int] = None,
split: Optional[NamedSplit] = None,
piercus marked this conversation as resolved.
Show resolved Hide resolved
**kwargs,
):
super().__init__(
Expand All @@ -23,20 +24,22 @@ def __init__(
keep_in_memory=keep_in_memory,
streaming=streaming,
num_proc=num_proc,
split=split,
piercus marked this conversation as resolved.
Show resolved Hide resolved
**kwargs,
)
self.builder = Generator(
cache_dir=cache_dir,
features=features,
generator=generator,
gen_kwargs=gen_kwargs,
split=split,
**kwargs,
)

def read(self):
# Build iterable dataset
if self.streaming:
dataset = self.builder.as_streaming_dataset(split="train")
dataset = self.builder.as_streaming_dataset(split=self.split)
piercus marked this conversation as resolved.
Show resolved Hide resolved
# Build regular (map-style) dataset
else:
download_config = None
Expand All @@ -52,6 +55,6 @@ def read(self):
num_proc=self.num_proc,
)
dataset = self.builder.as_dataset(
split="train", verification_mode=verification_mode, in_memory=self.keep_in_memory
split=self.split, verification_mode=verification_mode, in_memory=self.keep_in_memory
piercus marked this conversation as resolved.
Show resolved Hide resolved
)
return dataset
2 changes: 1 addition & 1 deletion src/datasets/io/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def read(self):

# Build dataset for splits
dataset = self.builder.as_dataset(
split="train", verification_mode=verification_mode, in_memory=self.keep_in_memory
split=self.split, verification_mode=verification_mode, in_memory=self.keep_in_memory
piercus marked this conversation as resolved.
Show resolved Hide resolved
)
return dataset

Expand Down
9 changes: 4 additions & 5 deletions src/datasets/iterable_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2062,6 +2062,7 @@ def from_generator(
generator: Callable,
features: Optional[Features] = None,
gen_kwargs: Optional[dict] = None,
split: Optional[NamedSplit] = None,
piercus marked this conversation as resolved.
Show resolved Hide resolved
) -> "IterableDataset":
"""Create an Iterable Dataset from a generator.

Expand All @@ -2074,7 +2075,8 @@ def from_generator(
Keyword arguments to be passed to the `generator` callable.
You can define a sharded iterable dataset by passing the list of shards in `gen_kwargs`.
This can be used to improve shuffling and when iterating over the dataset with multiple workers.

split ([`NamedSplit`], defaults to `Split.TRAIN`):
    Split name to be assigned to the dataset.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same comments as before.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I also added <Added version="2.21.0"/> please cross-check

Returns:
`IterableDataset`

Expand Down Expand Up @@ -2105,10 +2107,7 @@ def from_generator(
from .io.generator import GeneratorDatasetInputStream

return GeneratorDatasetInputStream(
generator=generator,
features=features,
gen_kwargs=gen_kwargs,
streaming=True,
generator=generator, features=features, gen_kwargs=gen_kwargs, streaming=True, split=split
).read()

@staticmethod
Expand Down
12 changes: 11 additions & 1 deletion src/datasets/packaged_modules/generator/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,21 @@ def __post_init__(self):
class Generator(datasets.GeneratorBasedBuilder):
BUILDER_CONFIG_CLASS = GeneratorConfig

def __init__(
    self,
    split: Optional[datasets.NamedSplit] = None,
    **kwargs,
):
    """Builder for datasets produced by a user-supplied generator.

    Args:
        split (`NamedSplit`, *optional*):
            Split name to assign to the generated dataset. Defaults to
            `datasets.Split.TRAIN` when not provided.
        **kwargs (additional keyword arguments):
            Forwarded to `GeneratorBasedBuilder.__init__`.
    """
    # Store the target split before delegating to the base class;
    # `_split_generators` reads it when declaring the single split.
    self.split = split if split is not None else datasets.Split.TRAIN
    # `__init__` must return None — do not `return` the super() call result.
    super().__init__(**kwargs)

piercus marked this conversation as resolved.
Show resolved Hide resolved
def _info(self):
    """Build the `DatasetInfo`, carrying over any user-provided features."""
    features = self.config.features
    return datasets.DatasetInfo(features=features)

def _split_generators(self, dl_manager):
    """Declare the single split this builder produces.

    Uses `self.split` (set in `__init__`, defaulting to `Split.TRAIN`)
    instead of a hard-coded `datasets.Split.TRAIN`, so callers of
    `from_generator` can choose the split name.
    """
    return [datasets.SplitGenerator(name=self.split, gen_kwargs=self.config.gen_kwargs)]
piercus marked this conversation as resolved.
Show resolved Hide resolved

def _generate_examples(self, **gen_kwargs):
for idx, ex in enumerate(self.config.generator(**gen_kwargs)):
Expand Down
15 changes: 10 additions & 5 deletions tests/test_arrow_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3867,10 +3867,11 @@ def _gen():
return _gen


def _check_generator_dataset(dataset, expected_features):
def _check_generator_dataset(dataset, expected_features, split):
    """Shared assertions for datasets built via `Dataset.from_generator`.

    Checks type, shape, assigned split, column names, and per-column dtypes
    against `expected_features` (a mapping of column name -> dtype string).
    """
    assert isinstance(dataset, Dataset)
    # The shared data_generator fixture yields 4 rows with 3 columns.
    assert dataset.num_rows == 4
    assert dataset.num_columns == 3
    assert dataset.split == split
    assert dataset.column_names == ["col_1", "col_2", "col_3"]
    for name in expected_features:
        assert dataset.features[name].dtype == expected_features[name]
Expand All @@ -3880,9 +3881,12 @@ def _check_generator_dataset(dataset, expected_features):
def test_dataset_from_generator_keep_in_memory(keep_in_memory, data_generator, tmp_path):
cache_dir = tmp_path / "cache"
expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"}
split = NamedSplit("validation")
with assert_arrow_memory_increases() if keep_in_memory else assert_arrow_memory_doesnt_increase():
dataset = Dataset.from_generator(data_generator, cache_dir=cache_dir, keep_in_memory=keep_in_memory)
_check_generator_dataset(dataset, expected_features)
dataset = Dataset.from_generator(
data_generator, cache_dir=cache_dir, keep_in_memory=keep_in_memory, split=split
)
_check_generator_dataset(dataset, expected_features, split)


@pytest.mark.parametrize(
Expand All @@ -3898,12 +3902,13 @@ def test_dataset_from_generator_keep_in_memory(keep_in_memory, data_generator, t
def test_dataset_from_generator_features(features, data_generator, tmp_path):
cache_dir = tmp_path / "cache"
default_expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"}
split = NamedSplit("validation")
expected_features = features.copy() if features else default_expected_features
features = (
Features({feature: Value(dtype) for feature, dtype in features.items()}) if features is not None else None
)
dataset = Dataset.from_generator(data_generator, features=features, cache_dir=cache_dir)
_check_generator_dataset(dataset, expected_features)
dataset = Dataset.from_generator(data_generator, features=features, cache_dir=cache_dir, split=split)
_check_generator_dataset(dataset, expected_features, split)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would add a specific test_dataset_from_generator_split with a parametrized split values, such as not passing any value, passing NamedSplit("train"), passing literal "train", passing other NamedSplit, etc.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

test_dataset_from_generator_split added, still i have impacted _check_generator_dataset to share the same generic check everywhere


@require_not_windows
Expand Down
Loading