
Support torch dataloader without torch formatting #5357

Merged
merged 7 commits into main from support-torch-dataloader-without-torch-formatting on Dec 15, 2022

Conversation

lhoestq (Member) commented Dec 13, 2022

In #5084 we made the torch formatting consistent with the map-style dataset formatting: a torch-formatted iterable dataset now yields torch tensors.

The previous behavior of the torch formatting for iterable datasets was simply to make the iterable dataset inherit from torch.utils.data.Dataset so it would work in a torch DataLoader. However, ideally an unformatted dataset should also work with a DataLoader. To fix that, datasets.IterableDataset should inherit from torch.utils.data.IterableDataset.

Since we don't want to import torch on startup, I created this PR to dynamically make the datasets.IterableDataset class inherit from the torch one when a datasets.IterableDataset is instantiated, and only if PyTorch is available.

```python
>>> from datasets import load_dataset
>>> ds = load_dataset("c4", "en", streaming=True, split="train")
>>> import torch.utils.data
>>> isinstance(ds, torch.utils.data.IterableDataset)
True
>>> dataloader = torch.utils.data.DataLoader(ds, batch_size=32, num_workers=4)
>>> for example in dataloader:
...     ...
```
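Under the hood, the trick looks roughly like the following. This is a minimal sketch of the dynamic-inheritance idea; the class and helper names are illustrative, not the PR's exact code:

```python
import importlib.util


class _InfoMixin:
    """Stand-in for datasets' DatasetInfoMixin; starting from a non-object
    base keeps the __bases__ assignment below legal in CPython."""


def _maybe_add_torch_parent(cls):
    # torch is only imported if the user already has it installed, so
    # importing this module alone never triggers `import torch`.
    if importlib.util.find_spec("torch") is not None:
        import torch.utils.data

        if torch.utils.data.IterableDataset not in cls.__bases__:
            cls.__bases__ += (torch.utils.data.IterableDataset,)


class MyIterableDataset(_InfoMixin):
    def __init__(self, generator):
        self._generator = generator
        _maybe_add_torch_parent(type(self))  # graft the parent at instantiation

    def __iter__(self):
        yield from self._generator()
```

Since the class object itself is patched, every instance created after the first instantiation (and even before it) passes the isinstance check.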

HuggingFaceDocBuilderDev commented Dec 13, 2022

The documentation is not available anymore as the PR was closed or merged.

lhoestq (Member, Author) commented Dec 14, 2022

Need some more time to fix the tests, especially with pickle

@lhoestq lhoestq marked this pull request as draft December 14, 2022 14:13
@lhoestq lhoestq force-pushed the support-torch-dataloader-without-torch-formatting branch from 4907aaf to be49a81 on December 14, 2022 15:24
@lhoestq lhoestq marked this pull request as ready for review December 14, 2022 15:45
mariosasko (Collaborator) left a comment


Looks good!

This is probably the least hacky we can get here :)

Review thread on src/datasets/iterable_dataset.py (outdated, resolved)
Co-authored-by: Mario Šaško <mariosasko777@gmail.com>
polinaeterna (Contributor) left a comment


I like the hack :)

I just left a fix in the docs (not related to this PR).

And I actually don't quite understand the idea - what's the motivation behind making only IterableDataset compatible with torch DataLoader without setting the format explicitly?

Review threads on docs/source/use_with_pytorch.mdx and src/datasets/iterable_dataset.py (outdated, resolved)
lhoestq (Member, Author) commented Dec 15, 2022

> And I actually don't quite understand the idea - what's the motivation behind making only IterableDataset compatible with torch DataLoader without setting the format explicitly?

Setting the format to PyTorch means setting the output types of the dataset to PyTorch tensors. However, sometimes your dataset isn't made of tensors (e.g. it contains raw text), but you still want to be able to use a PyTorch DataLoader.
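Concretely, a hedged sketch reusing the c4 example above (the "text" column comes from that dataset): without .with_format("torch"), examples stay plain Python dicts, and torch's default collate_fn batches the strings into lists rather than tensors.

```python
from torch.utils.data import DataLoader

from datasets import load_dataset

# No .with_format("torch"): each example stays a plain Python dict like
# {"text": "...", "timestamp": "...", "url": "..."}.
ds = load_dataset("c4", "en", streaming=True, split="train")

dataloader = DataLoader(ds, batch_size=4)
batch = next(iter(dataloader))
print(type(batch["text"]), len(batch["text"]))  # <class 'list'> 4 -> raw strings, no tensors
```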

mariosasko (Collaborator) commented

A bit more context.

The arrow-backed Dataset supports DataLoader(ds) (even if the format is not "torch"), and we want to be able to do the same with IterableDataset for consistency. However, this is when the PyTorch internals come into play - an iterable dataset needs to be an instance of torch.utils.data.IterableDataset due to this check (notice there is no check for the map-style version). Hence the explicit subclassing in this PR.
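To illustrate the consequence of that check, here is a toy sketch (class names made up, not torch internals): a class that only defines __iter__ goes down DataLoader's map-style code path and fails, while the same iterator wrapped in a torch.utils.data.IterableDataset subclass is consumed correctly.

```python
import torch.utils.data


class PlainStream:
    """Defines __iter__ but does NOT subclass torch.utils.data.IterableDataset."""

    def __iter__(self):
        return iter(range(8))


class TorchStream(torch.utils.data.IterableDataset):
    """Same iterator, but recognized by the isinstance check in DataLoader."""

    def __iter__(self):
        return iter(range(8))


# Map-style path: DataLoader tries to sample indices, which requires
# __len__/__getitem__ that PlainStream doesn't have.
try:
    next(iter(torch.utils.data.DataLoader(PlainStream(), batch_size=4)))
except TypeError as err:
    print("map-style path failed:", err)

# Iterable path: DataLoader simply consumes __iter__.
print(next(iter(torch.utils.data.DataLoader(TorchStream(), batch_size=4))))
# tensor([0, 1, 2, 3])
```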

lhoestq (Member, Author) commented Dec 15, 2022

Exactly :) Btw I just took your comments into account @polinaeterna, so feel free to review again.

@lhoestq lhoestq merged commit 0bec9f3 into main Dec 15, 2022
@lhoestq lhoestq deleted the support-torch-dataloader-without-torch-formatting branch December 15, 2022 19:15
corbyrosset commented Jan 4, 2023

@lhoestq just checking, does this change still preserve the fix to the "data duplicate when setting num_works > 1 with streaming data" issue from before?

#3423

lhoestq (Member, Author) commented Jan 4, 2023

Yes :)
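For context, the idea behind that fix can be illustrated with the generic PyTorch worker-sharding pattern. This is a toy sketch only; datasets implements it internally by splitting the dataset's shards across workers rather than striding over items.

```python
from torch.utils.data import DataLoader, IterableDataset, get_worker_info


class ShardedStream(IterableDataset):
    """Each DataLoader worker yields a disjoint slice of the stream,
    so num_workers > 1 produces every example exactly once."""

    def __init__(self, n):
        self.n = n

    def __iter__(self):
        info = get_worker_info()
        if info is None:  # single-process data loading
            yield from range(self.n)
        else:  # worker `info.id` of `info.num_workers` takes every k-th item
            yield from range(info.id, self.n, info.num_workers)


if __name__ == "__main__":  # guard needed when workers are spawned
    loader = DataLoader(ShardedStream(8), num_workers=2, batch_size=None)
    print(sorted(int(x) for x in loader))  # [0, 1, ..., 7] -> each item exactly once
```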
