Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

IterableDataset formatting in numpy/torch/tf/jax #5084

Closed
wants to merge 27 commits into from

Conversation

lhoestq
Copy link
Member

@lhoestq lhoestq commented Oct 6, 2022

This code now returns a numpy array:

from datasets import load_dataset

ds = load_dataset("imagenet-1k", split="train", streaming=True).with_format("np")
print(next(iter(ds))["image"])

It also works with "arrow", "pandas", "torch", "tf" and "jax"

Implementation details:

I'm using the existing code to format an Arrow Table to the right output format for simplicity.
Therefore it's probbaly not the most optimized approach.

For example to output PyTorch tensors it does this for every example:

python data -> arrow table -> numpy extracted data -> pytorch formatted data

Releasing this feature

Even though I consider this as a bug/inconsistency, this change is a breaking change.
And I'm sure some users were relying on the torch iterable dataset to return PIL Image and used data collators to convert to pytorch.

So I guess this is datasets 3.0 ?

TODO

Close #5083

@lhoestq lhoestq changed the base branch from main to image-formatting October 6, 2022 16:53
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

Base automatically changed from image-formatting to main October 10, 2022 13:21
@lhoestq
Copy link
Member Author

lhoestq commented Dec 20, 2022

Actually I'm not happy with this implementation. It always require the iterable dataset to have definite features, which removes a lot of flexibility. So I think we need an actual formatting from python objects, not from arrow data.

@lhoestq
Copy link
Member Author

lhoestq commented Dec 20, 2022

Closing this one since it has too many conflicts and still require some work - it will be easier to open a new PR

@lhoestq lhoestq closed this Dec 20, 2022
@albertvillanova albertvillanova deleted the iterable-ds-formatting branch September 24, 2023 10:06
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Support numpy/torch/tf/jax formatting for IterableDataset
2 participants