IterableDataset formatting in numpy/torch/tf/jax #5084
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.
Actually I'm not happy with this implementation. It always requires the iterable dataset to have definite
Closing this one since it has too many conflicts and still requires some work. It will be easier to open a new PR.
This code now returns a numpy array:
It also works with "arrow", "pandas", "torch", "tf" and "jax"
Implementation details:
I'm using the existing code to format an Arrow Table to the right output format for simplicity.
Therefore it's probably not the most optimized approach.
For example to output PyTorch tensors it does this for every example:
python data -> arrow table -> numpy extracted data -> pytorch formatted data
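To make the per-example cost of that chain concrete, here is a minimal sketch that uses plain-Python stand-ins for each stage. It is not the code from this PR: `FormattedIterable`, `to_columnar`, `extract_row` and `to_output` are hypothetical names, and the real stages would go through pyarrow, numpy and torch instead of lists and tuples.

```python
from typing import Any, Callable, Dict, Iterable, Iterator, List

Example = Dict[str, Any]
Stage = Callable[[Example], Example]


def to_columnar(example: Example) -> Example:
    """Stand-in for 'python data -> arrow table': wrap each value in a one-row column."""
    return {name: [value] for name, value in example.items()}


def extract_row(table: Example) -> Example:
    """Stand-in for 'arrow table -> numpy extracted data': pull row 0 back out."""
    return {name: column[0] for name, column in table.items()}


def to_output(example: Example) -> Example:
    """Stand-in for the final formatting step: nested lists become tuples."""

    def convert(value: Any) -> Any:
        if isinstance(value, list):
            return tuple(convert(item) for item in value)
        return value

    return {name: convert(value) for name, value in example.items()}


class FormattedIterable:
    """Hypothetical wrapper: runs every stage on every example, lazily."""

    def __init__(self, source: Iterable[Example], stages: List[Stage]) -> None:
        self.source = source
        self.stages = stages

    def __iter__(self) -> Iterator[Example]:
        for example in self.source:
            # The whole chain is repeated for each example, which is why
            # this approach is simple but not the most optimized one.
            for stage in self.stages:
                example = stage(example)
            yield example


def examples() -> Iterator[Example]:
    # Assumed toy data standing in for a streamed dataset.
    yield {"pixels": [[0, 1], [2, 3]], "label": 7}
    yield {"pixels": [[4, 5], [6, 7]], "label": 1}


formatted = FormattedIterable(examples(), [to_columnar, extract_row, to_output])
first = next(iter(formatted))
print(first)
```

Every example pays for all three conversions, so a more optimized version would probably skip the intermediate table when the source data is already in a suitable form.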
Releasing this feature
Even though I consider this as a bug/inconsistency, this change is a breaking change.
And I'm sure some users were relying on the torch iterable dataset to return PIL Image and used data collators to convert to pytorch.
So I guess this is datasets 3.0?
TODO
Close #5083