Hi there, nice package! But I think the reason the torch dataloader was slow is that you are transforming to tensor and normalizing for torch but not for jax.
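To illustrate the asymmetry being described, here is a minimal sketch (assuming an MNIST-style benchmark; the dataset and constants are hypothetical, not taken from the actual benchmark code). The torch pipeline pays for `ToTensor` and `Normalize` on every sample, while the jax side serves raw arrays:

```python
import torchvision
import torchvision.transforms as T

transform = T.Compose([
    T.ToTensor(),                       # PIL image -> float tensor, per sample
    T.Normalize((0.1307,), (0.3081,)),  # per-sample normalization
])
# This dataset does extra per-sample work that the jax-side data never does.
torch_ds = torchvision.datasets.MNIST(
    "data", train=True, download=True, transform=transform
)
# vs. a jax-side dataset built directly from the raw, untransformed arrays
```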
You're right that I should probably have created a more in-depth comparison between the two data loaders.
I should remove the transformations and run the benchmark again. I haven't checked the PyTorch dataloader's source code to see whether the transformations are applied lazily (i.e. only at iteration time, which would indeed make it slower) or eagerly, when I generate the datasets. I'm inclined to think it's the latter, because indexing the dataset already returns the transformed data (I'll double-check this later -- I'm currently on my phone on my way home from work). Or perhaps the transformations are applied during indexing. In any case, I'll double-check and post the results :)
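A quick way to check this is to index the same dataset with and without a transform. In torchvision, the transform runs inside `__getitem__`, i.e. lazily at access time, not when the dataset object is constructed (a small sketch using MNIST as an example):

```python
import torchvision
import torchvision.transforms as T

raw = torchvision.datasets.MNIST("data", train=True, download=True)
tfm = torchvision.datasets.MNIST("data", train=True, download=True,
                                 transform=T.ToTensor())

print(type(raw[0][0]))  # <class 'PIL.Image.Image'> -- untouched until accessed
print(type(tfm[0][0]))  # <class 'torch.Tensor'>    -- transformed on indexing
```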
Hi again. I double-checked it and you were right! The PyTorch dataloader is not slow; my tests were skewed because I applied the transformations only on the PyTorch side. Without them, Jaxonloader is only slightly faster, and the difference is almost negligible. Thanks for pointing this out! I removed the misleading statements.
As a side note, I noticed another issue when testing this on a GPU (in which case it's actually a bit slower!): it's trying to fit the entire dataset into GPU memory! That's not good and needs to be fixed. I'll get on that soon.
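A minimal sketch of the fix being described, assuming the loader currently does a single `jax.device_put` on the full dataset (this is not Jaxonloader's actual API, just an illustration of the pattern): keep the data on the host and transfer one batch at a time during iteration.

```python
import numpy as np
import jax

def iter_batches(x: np.ndarray, batch_size: int):
    """Yield device-resident batches from a host-resident array."""
    for start in range(0, len(x), batch_size):
        # Only this slice is copied to the accelerator, so peak device
        # memory is one batch rather than the whole dataset.
        yield jax.device_put(x[start:start + batch_size])

# usage:
# data = np.random.rand(60_000, 784).astype(np.float32)
# for batch in iter_batches(data, 128):
#     ...
```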