Support torch dataloader without torch formatting #5357
Need some more time to fix the tests, especially with pickle.
4907aaf to be49a81
Looks good!
This is probably the least hacky we can get here :)
Co-authored-by: Mario Šaško <mariosasko777@gmail.com>
I like the hack :)
I just left a fix in the docs (not related to this PR).
And I actually don't quite understand the idea - what's the motivation behind making only `IterableDataset` compatible with torch `DataLoader` without setting the format explicitly?
Setting the format to PyTorch means setting the output types of the dataset to be PyTorch tensors. However, sometimes your dataset is not made of tensors but you still want to be able to use a PyTorch `DataLoader`.
A bit more context. The arrow-backed |
Exactly :) Btw I just took your comments into account @polinaeterna, so feel free to review again.
Yes :)
In #5084 we make the torch formatting consistent with the map-style datasets formatting: a torch-formatted iterable dataset will yield torch tensors.
The previous behavior of the torch formatting for iterable datasets was simply to make the iterable dataset inherit from `torch.utils.data.Dataset` to make it work in a torch `DataLoader`. However, ideally an unformatted dataset should also work with a `DataLoader`. To fix that, `datasets.IterableDataset` should inherit from `torch.utils.data.IterableDataset`.
Since we don't want to import torch on startup, I created this PR to dynamically make the `datasets.IterableDataset` class inherit from the torch one when a `datasets.IterableDataset` is instantiated and if PyTorch is available.