Skip to content

Commit

Permalink
comments + raise error if sharding is ambiguous
Browse files Browse the repository at this point in the history
  • Loading branch information
lhoestq committed Jun 9, 2022
1 parent 54e9f39 commit 8f5579e
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 2 deletions.
13 changes: 13 additions & 0 deletions src/datasets/iterable_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,19 @@ def _shuffle_kwargs(rng: np.random.Generator, kwargs: dict) -> dict:

def _shard_kwargs(shard_idx: int, kwargs: dict) -> dict:
"""Return a copy of the input kwargs but with only one shard"""
# Having lists of different sizes makes sharding ambigious, raise an error in this case
# until we decide how to define sharding without ambiguity for users
lists_lengths = {key: len(value) for key, value in kwargs.items() if isinstance(value, list)}
if len(set(lists_lengths.values())) > 1:
raise RuntimeError(
(
"Sharding is ambiguous for this dataset: "
+ "we found several data sources lists of different lengths, and we don't know over which list we should parallelize:\n"
+ "\n".join(f"\t- key {key} has length {length}" for key, length in lists_lengths.items())
+ "\nTo fix this, check the dataset script 'gen_kwargs' and make sure to use lists only for data sources, "
+ "and use tuples otherwise. In the end there should only one single list, or several lists with the same length."
)
)
return {key: [value[shard_idx]] if isinstance(value, list) else value for key, value in kwargs.items()}


Expand Down
14 changes: 12 additions & 2 deletions src/datasets/utils/patching.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,18 @@ def __init__(self, obj, target: str, new, attrs=None):

def __enter__(self):
*submodules, target_attr = self.target.split(".")

# Patch modules:
# it's used to patch attributes of submodules like "os.path.join";
# in this case we need to patch "os" and "os.path"

for i in range(len(submodules)):
submodule = import_module(".".join(submodules[: i + 1]))
# We iterate over all the globals in self.obj in case we find "os" or "os.path"
for attr in self.obj.__dir__():
obj_attr = getattr(self.obj, attr)
# We don't check for the name of the global, but rather if its value *is* "os" or "os.path".
# This allows to patch renamed modules like "from os import path as ospath".
if obj_attr is submodule or (
(isinstance(obj_attr, _PatchedModuleObj) and obj_attr._original_module is submodule)
):
Expand All @@ -67,17 +72,22 @@ def __enter__(self):
patched = getattr(patched, key)
# finally set the target attribute
setattr(patched, target_attr, self.new)

# Patch attribute itself:
# it's used for builtins like "open",
# and also to patch "os.path.join" we may also need to patch "join"
# itself if it was imported as "from os.path import join".
if submodules: # if it's an attribute of a submodule

if submodules: # if it's an attribute of a submodule like "os.path.join"
attr_value = getattr(import_module(".".join(submodules)), target_attr)
# We iterate over all the globals in self.obj in case we find "os.path.join"
for attr in self.obj.__dir__():
# We don't check for the name of the global, but rather if its value *is* "os.path.join".
# This allows to patch renamed attributes like "from os.path import join as pjoin".
if getattr(self.obj, attr) is attr_value:
self.original[attr] = getattr(self.obj, attr)
setattr(self.obj, attr, self.new)
elif target_attr in globals()["__builtins__"]: # if it'a s builtin
elif target_attr in globals()["__builtins__"]: # if it'a s builtin like "open"
self.original[target_attr] = globals()["__builtins__"][target_attr]
setattr(self.obj, target_attr, self.new)
else:
Expand Down

1 comment on commit 8f5579e

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Show benchmarks

PyArrow==6.0.0

Show updated benchmarks!

Benchmark: benchmark_array_xd.json

metric read_batch_formatted_as_numpy after write_array2d read_batch_formatted_as_numpy after write_flattened_sequence read_batch_formatted_as_numpy after write_nested_sequence read_batch_unformated after write_array2d read_batch_unformated after write_flattened_sequence read_batch_unformated after write_nested_sequence read_col_formatted_as_numpy after write_array2d read_col_formatted_as_numpy after write_flattened_sequence read_col_formatted_as_numpy after write_nested_sequence read_col_unformated after write_array2d read_col_unformated after write_flattened_sequence read_col_unformated after write_nested_sequence read_formatted_as_numpy after write_array2d read_formatted_as_numpy after write_flattened_sequence read_formatted_as_numpy after write_nested_sequence read_unformated after write_array2d read_unformated after write_flattened_sequence read_unformated after write_nested_sequence write_array2d write_flattened_sequence write_nested_sequence
new / old (diff) 0.008841 / 0.011353 (-0.002512) 0.004394 / 0.011008 (-0.006614) 0.030241 / 0.038508 (-0.008268) 0.038195 / 0.023109 (0.015086) 0.317927 / 0.275898 (0.042029) 0.335620 / 0.323480 (0.012140) 0.007048 / 0.007986 (-0.000937) 0.004022 / 0.004328 (-0.000307) 0.007926 / 0.004250 (0.003675) 0.046433 / 0.037052 (0.009380) 0.301609 / 0.258489 (0.043120) 0.333866 / 0.293841 (0.040025) 0.032276 / 0.128546 (-0.096270) 0.010072 / 0.075646 (-0.065574) 0.255439 / 0.419271 (-0.163832) 0.052335 / 0.043533 (0.008802) 0.313149 / 0.255139 (0.058010) 0.319257 / 0.283200 (0.036057) 0.101400 / 0.141683 (-0.040283) 1.816260 / 1.452155 (0.364105) 1.903185 / 1.492716 (0.410468)

Benchmark: benchmark_getitem_100B.json

metric get_batch_of_1024_random_rows get_batch_of_1024_rows get_first_row get_last_row
new / old (diff) 0.341608 / 0.018006 (0.323602) 0.561422 / 0.000490 (0.560932) 0.020254 / 0.000200 (0.020054) 0.000382 / 0.000054 (0.000327)

Benchmark: benchmark_indices_mapping.json

metric select shard shuffle sort train_test_split
new / old (diff) 0.028965 / 0.037411 (-0.008446) 0.113329 / 0.014526 (0.098803) 0.118505 / 0.176557 (-0.058052) 0.168717 / 0.737135 (-0.568418) 0.119576 / 0.296338 (-0.176762)

Benchmark: benchmark_iterating.json

metric read 5000 read 50000 read_batch 50000 10 read_batch 50000 100 read_batch 50000 1000 read_formatted numpy 5000 read_formatted pandas 5000 read_formatted tensorflow 5000 read_formatted torch 5000 read_formatted_batch numpy 5000 10 read_formatted_batch numpy 5000 1000 shuffled read 5000 shuffled read 50000 shuffled read_batch 50000 10 shuffled read_batch 50000 100 shuffled read_batch 50000 1000 shuffled read_formatted numpy 5000 shuffled read_formatted_batch numpy 5000 10 shuffled read_formatted_batch numpy 5000 1000
new / old (diff) 0.437199 / 0.215209 (0.221990) 4.356905 / 2.077655 (2.279250) 1.942159 / 1.504120 (0.438039) 1.739902 / 1.541195 (0.198708) 1.894144 / 1.468490 (0.425654) 0.452778 / 4.584777 (-4.131999) 4.854245 / 3.745712 (1.108532) 4.109992 / 5.269862 (-1.159869) 0.957141 / 4.565676 (-3.608535) 0.053757 / 0.424275 (-0.370518) 0.012372 / 0.007607 (0.004765) 0.544744 / 0.226044 (0.318700) 5.509171 / 2.268929 (3.240243) 2.402038 / 55.444624 (-53.042587) 2.066981 / 6.876477 (-4.809495) 2.228226 / 2.142072 (0.086154) 0.569526 / 4.805227 (-4.235701) 0.126627 / 6.500664 (-6.374037) 0.065661 / 0.075469 (-0.009808)

Benchmark: benchmark_map_filter.json

metric filter map fast-tokenizer batched map identity map identity batched map no-op batched map no-op batched numpy map no-op batched pandas map no-op batched pytorch map no-op batched tensorflow
new / old (diff) 1.628698 / 1.841788 (-0.213089) 15.369940 / 8.074308 (7.295632) 26.864462 / 10.191392 (16.673070) 0.882322 / 0.680424 (0.201898) 0.544910 / 0.534201 (0.010709) 0.498911 / 0.579283 (-0.080372) 0.507665 / 0.434364 (0.073301) 0.329500 / 0.540337 (-0.210837) 0.335282 / 1.386936 (-1.051654)
PyArrow==latest
Show updated benchmarks!

Benchmark: benchmark_array_xd.json

metric read_batch_formatted_as_numpy after write_array2d read_batch_formatted_as_numpy after write_flattened_sequence read_batch_formatted_as_numpy after write_nested_sequence read_batch_unformated after write_array2d read_batch_unformated after write_flattened_sequence read_batch_unformated after write_nested_sequence read_col_formatted_as_numpy after write_array2d read_col_formatted_as_numpy after write_flattened_sequence read_col_formatted_as_numpy after write_nested_sequence read_col_unformated after write_array2d read_col_unformated after write_flattened_sequence read_col_unformated after write_nested_sequence read_formatted_as_numpy after write_array2d read_formatted_as_numpy after write_flattened_sequence read_formatted_as_numpy after write_nested_sequence read_unformated after write_array2d read_unformated after write_flattened_sequence read_unformated after write_nested_sequence write_array2d write_flattened_sequence write_nested_sequence
new / old (diff) 0.008915 / 0.011353 (-0.002438) 0.004334 / 0.011008 (-0.006674) 0.030189 / 0.038508 (-0.008319) 0.037055 / 0.023109 (0.013946) 0.344273 / 0.275898 (0.068375) 0.355564 / 0.323480 (0.032084) 0.006927 / 0.007986 (-0.001059) 0.005291 / 0.004328 (0.000963) 0.007686 / 0.004250 (0.003436) 0.046618 / 0.037052 (0.009565) 0.320852 / 0.258489 (0.062363) 0.359540 / 0.293841 (0.065699) 0.032123 / 0.128546 (-0.096423) 0.010017 / 0.075646 (-0.065629) 0.253773 / 0.419271 (-0.165498) 0.052609 / 0.043533 (0.009076) 0.332853 / 0.255139 (0.077714) 0.343424 / 0.283200 (0.060225) 0.104532 / 0.141683 (-0.037151) 1.850045 / 1.452155 (0.397891) 1.935952 / 1.492716 (0.443236)

Benchmark: benchmark_getitem_100B.json

metric get_batch_of_1024_random_rows get_batch_of_1024_rows get_first_row get_last_row
new / old (diff) 0.425327 / 0.018006 (0.407321) 0.557215 / 0.000490 (0.556726) 0.032740 / 0.000200 (0.032540) 0.000522 / 0.000054 (0.000468)

Benchmark: benchmark_indices_mapping.json

metric select shard shuffle sort train_test_split
new / old (diff) 0.027173 / 0.037411 (-0.010239) 0.106898 / 0.014526 (0.092372) 0.120610 / 0.176557 (-0.055946) 0.181720 / 0.737135 (-0.555415) 0.119280 / 0.296338 (-0.177058)

Benchmark: benchmark_iterating.json

metric read 5000 read 50000 read_batch 50000 10 read_batch 50000 100 read_batch 50000 1000 read_formatted numpy 5000 read_formatted pandas 5000 read_formatted tensorflow 5000 read_formatted torch 5000 read_formatted_batch numpy 5000 10 read_formatted_batch numpy 5000 1000 shuffled read 5000 shuffled read 50000 shuffled read_batch 50000 10 shuffled read_batch 50000 100 shuffled read_batch 50000 1000 shuffled read_formatted numpy 5000 shuffled read_formatted_batch numpy 5000 10 shuffled read_formatted_batch numpy 5000 1000
new / old (diff) 0.423998 / 0.215209 (0.208789) 4.234794 / 2.077655 (2.157139) 1.883358 / 1.504120 (0.379238) 1.709793 / 1.541195 (0.168598) 1.880241 / 1.468490 (0.411751) 0.440232 / 4.584777 (-4.144545) 4.706728 / 3.745712 (0.961016) 3.465078 / 5.269862 (-1.804784) 0.966368 / 4.565676 (-3.599308) 0.053135 / 0.424275 (-0.371140) 0.012086 / 0.007607 (0.004479) 0.534382 / 0.226044 (0.308337) 5.331280 / 2.268929 (3.062352) 2.296390 / 55.444624 (-53.148235) 1.989872 / 6.876477 (-4.886604) 2.155626 / 2.142072 (0.013553) 0.562478 / 4.805227 (-4.242750) 0.123916 / 6.500664 (-6.376748) 0.062047 / 0.075469 (-0.013422)

Benchmark: benchmark_map_filter.json

metric filter map fast-tokenizer batched map identity map identity batched map no-op batched map no-op batched numpy map no-op batched pandas map no-op batched pytorch map no-op batched tensorflow
new / old (diff) 1.630135 / 1.841788 (-0.211653) 15.363766 / 8.074308 (7.289458) 26.770961 / 10.191392 (16.579569) 0.884467 / 0.680424 (0.204043) 0.540846 / 0.534201 (0.006645) 0.493756 / 0.579283 (-0.085527) 0.498028 / 0.434364 (0.063665) 0.317786 / 0.540337 (-0.222552) 0.329176 / 1.386936 (-1.057760)

CML watermark

Please sign in to comment.