Support DataLoader with num_workers > 0 in streaming mode #4375
Conversation
- make it picklable
- parallelize over the shards when num_workers is passed
The documentation is not available anymore as the PR was closed or merged.
Alright, this is finally ready for review! It's quite long, I'm sorry, but it's not easy to disentangle everything ^^' The main additions are in
def xpathrglob(path, pattern, **kwargs):
    """Rglob function for argument of type :obj:`~pathlib.Path` that supports both local paths and remote URLs.

class xPath(type(Path())):
Many changes in this file are just about moving functions inside this class. For example, I moved xpathrglob to xPath.rglob.
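For readers following along, here is a hedged before/after sketch of that kind of move (signatures simplified; the remote-URL handling of the real methods is omitted):

# before: a module-level helper taking the path as its first argument
# def xpathrglob(path, pattern, **kwargs): ...
#
# after: the same logic becomes a method of the extended Path class
from pathlib import Path

class xPath(type(Path())):
    def rglob(self, pattern, **kwargs):
        # sketch only: fall back to pathlib's rglob for local paths;
        # the real method additionally handles remote URLs
        yield from super().rglob(pattern)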
if worker_info.id == 0 and self.n_shards < worker_info.num_workers:
    logger.warning(
        f"Too many dataloader workers: {worker_info.num_workers} (max is dataset.n_shards={self.n_shards}). "
        f"Stopping dataloader workers [{self.n_shards}...{worker_info.num_workers - 1}]."
    )
    logger.warning(
        f"To parallelize data loading, we give each process some shards (or data sources) to process. "
        f"Therefore it's unnecessary to have a number of workers greater than dataset.n_shards={self.n_shards}. "
        f"To enable more parallelism, please split the dataset in more files than {self.n_shards}."
    )
# split workload
shards_indices = list(range(worker_info.id, self.n_shards, worker_info.num_workers))
if shards_indices:
    logger.debug(
        f"dataloader worker#{worker_info.id}: Starting to iterate over {len(shards_indices)}/{self.n_shards} shards."
    )
    for shard_idx in shards_indices:
        for key, example in self._iter_shard(shard_idx):
            yield self._apply_feature_types(example)
This is where we shard the iterable dataset when it's passed to a DataLoader worker.
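For context, a hedged usage sketch of what this enables (the dataset name is just an example): a streaming dataset can be passed to a DataLoader with several workers, and each worker only iterates over its own subset of shards.

from datasets import load_dataset
from torch.utils.data import DataLoader

ds = load_dataset("c4", "en", split="train", streaming=True).with_format("torch")
dataloader = DataLoader(ds, batch_size=32, num_workers=4)

for batch in dataloader:
    ...  # each worker yields examples from a disjoint subset of the dataset's shards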
    ex_iterable = self._ex_iterable.shuffle_data_sources(self._effective_generator())
else:
    ex_iterable = self._ex_iterable
yield from ex_iterable.shard_data_sources(shard_idx)
This is what is called when iterating in a DataLoader worker. The idea is to iterate over only one shard out of the self.n_shards available.
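As a quick illustration of the round-robin split used for the workload (the numbers are chosen arbitrarily):

n_shards, num_workers = 8, 3
for worker_id in range(num_workers):
    shards_indices = list(range(worker_id, n_shards, num_workers))
    print(f"worker {worker_id} iterates over shards {shards_indices}")
# worker 0 iterates over shards [0, 3, 6]
# worker 1 iterates over shards [1, 4, 7]
# worker 2 iterates over shards [2, 5]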
# xml.etree.ElementTree
for submodule in ["ElementTree", "ET"]:
    patch_submodule(module, f"{submodule}.parse", wrap_auth(xet_parse)).start()

patch_submodule(module, "pathlib.Path", xPath).start()
Now we can just pass the source object to patch in the module, and it will patch it even if the attribute alone is imported, or even if a parent module has been imported and renamed (see test_patching.py for a list of all supported cases - I probably have to add a docstring to patch_submodule as well).
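To make that behavior concrete, here is a minimal, hypothetical sketch of the idea (not the actual patch_submodule implementation): rebind pathlib.Path inside a loaded dataset-script module, whether the script did "import pathlib", "import pathlib as pl", or "from pathlib import Path" (possibly renamed).

import pathlib
import types

def sketch_patch_path(module: types.ModuleType, new_cls: type) -> None:
    for name, value in list(vars(module).items()):
        if value is pathlib.Path:
            # `from pathlib import Path` (possibly renamed): rebind the attribute itself
            setattr(module, name, new_cls)
        elif value is pathlib:
            # `import pathlib` (possibly renamed): replace the module reference with a
            # thin proxy whose Path attribute points to the replacement class
            proxy = types.SimpleNamespace(**vars(pathlib))
            proxy.Path = new_cls
            setattr(module, name, proxy)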
This is 🔥. Just two comments:
(Also thanks for the review instructions/code clarifications)
Added some comments and an error when lists have different lengths for sharding :)
Let's resolve the merge conflict and the CI error (if it's related to the changes), and I can review the PR again.
Feel free to review again :) The CI failure is unrelated to this PR and will be fixed by #4472 (the Hub now returns 401 instead of 404 for unauthenticated requests to non-existing repos)
All looks good now! Thanks!
Co-authored-by: Mario Šaško <mariosasko777@gmail.com>
CI failures are unrelated to this PR - merging :) (the CI failures are a mix of pip install failures and Hub failures)
@lhoestq you're our hero :)
Issue

It's currently not possible to properly stream a dataset using multiple torch.utils.data.DataLoader workers:
- the TorchIterableDataset can't be pickled and passed to the subprocesses (#3950, "Streaming Datasets don't work with Transformers Trainer when dataloader_num_workers>1")
- fsspec doesn't work out of the box in subprocesses (#3951, "… data urls rather than use network")

Solution in this PR

I fixed these to enable passing an IterableDataset to a torch.utils.data.DataLoader with num_workers > 0.

I also had to shard the IterableDataset to give each worker a shard, otherwise data would be duplicated. This is implemented in TorchIterableDataset.__iter__ and uses the new IterableDataset._iter_shard(shard_idx) method.

I also had to make a few changes to the patching that enables streaming in dataset scripts:
- the extended path functions are now grouped in the xPath class, so that Path outside of dataset scripts stays unchanged - otherwise I didn't change the content of the extended Path methods for streaming
- in the pd.read_csv patch, opening the file in "rb" mode was missing and causing some datasets to not work in streaming mode, and compression inference was also missing (see the sketch after this list)
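A hedged, hypothetical reconstruction of what that patched read_csv could look like (the xopen stand-in below is an assumption; the library's own streaming open also handles authentication and chained URLs):

import fsspec
import pandas as pd

def xopen(file, mode="rb", use_auth_token=None):
    # stand-in for the streaming-aware open (assumption); fsspec covers plain http(s) URLs
    return fsspec.open(file, mode=mode).open()

def xpandas_read_csv(filepath_or_buffer, use_auth_token=None, **kwargs):
    if not isinstance(filepath_or_buffer, str):
        return pd.read_csv(filepath_or_buffer, **kwargs)
    if kwargs.get("compression", "infer") == "infer":
        # pandas can only infer compression from a string path, not from the binary
        # buffer we pass it, so infer it from the extension here
        extensions = {".gz": "gzip", ".bz2": "bz2", ".zip": "zip", ".xz": "xz"}
        kwargs["compression"] = next(
            (comp for ext, comp in extensions.items() if filepath_or_buffer.endswith(ext)), None
        )
    # open in "rb" mode so that remote and compressed files work in streaming mode
    return pd.read_csv(xopen(filepath_or_buffer, "rb", use_auth_token=use_auth_token), **kwargs)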
A few details regarding fsspec in multiprocessing

From fsspec/filesystem_spec#963 (comment):
Therefore in a DataLoader's worker, I clear the reference to the loop and thread (1). We should be fine for 2 and 3 already since we don't use fsspec class instances from the parent process.
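A hedged sketch of what that clearing could look like, assuming fsspec keeps the dedicated IO thread and event loop in the module-level iothread and loop lists of fsspec.asyn (the helper name is made up):

import fsspec.asyn

def _clear_fsspec_loop_and_thread() -> None:  # hypothetical helper name
    # called inside each DataLoader worker before iterating (point 1 above)
    fsspec.asyn.iothread[0] = None
    fsspec.asyn.loop[0] = None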
Fix #3950
Fix #3951
TODO: