JAX doesn't recognise TF types as `jnp` arrays. This line in `get_dataset` prefetches your batch to the devices and shards it for you, so you shouldn't need to change `input_dtype` as long as `prefetch_to_device` runs correctly. (I just ran the notebook with a GPU on Colab and it seems fine.)
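For anyone unsure what "prefetch and shard" means here, below is a minimal sketch, not the notebook's code. It assumes the input pipeline yields batches of NumPy arrays (for example via `tfds.as_numpy`) rather than raw `tf.Tensor`s. `shard_batch` and `prefetch_sketch` are hypothetical helpers: the first splits the leading batch dimension into one slice per local device, and the second keeps a small queue of batches already converted to JAX arrays. Note that the dtypes pass through unchanged, which is why `input_dtype` shouldn't need touching.

```python
import collections
import itertools

import jax
import jax.numpy as jnp
import numpy as np


def shard_batch(batch):
    """Hypothetical helper: reshape each leaf from (global_batch, ...)
    to (n_devices, per_device_batch, ...)."""
    n = jax.local_device_count()

    def _reshape(x):
        x = np.asarray(x)  # assumes host data is already NumPy-compatible
        if x.shape[0] % n:
            raise ValueError(f"batch size {x.shape[0]} not divisible by {n} devices")
        return x.reshape((n, -1) + x.shape[1:])

    return jax.tree_util.tree_map(_reshape, batch)


def prefetch_sketch(iterator, size=2):
    """Hypothetical, simplified stand-in for a prefetch-to-device helper:
    keeps up to `size` batches queued as JAX arrays."""
    queue = collections.deque()

    def enqueue(k):
        for batch in itertools.islice(iterator, k):
            # jnp.asarray copies host data to the default device and keeps the
            # dtype; a real multi-device version would place each shard on
            # its own device instead.
            queue.append(jax.tree_util.tree_map(jnp.asarray, batch))

    enqueue(size)
    while queue:
        yield queue.popleft()
        enqueue(1)


# Usage: three fake float32 image batches sized to divide evenly across devices.
n = jax.local_device_count()
host_batches = (
    {"image": np.ones((2 * n, 4, 4, 3), np.float32),
     "label": np.arange(2 * n, dtype=np.int32)}
    for _ in range(3)
)
for b in prefetch_sketch(map(shard_batch, host_batches)):
    print(b["image"].shape, b["image"].dtype)
```

The queue depth of 2 is an arbitrary choice; it only overlaps host-to-device copies with computation and does not change the data.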
Hi, thanks for the great work!
There is an assertion error when checking the dataset. This is confusing, because as far as I understand it should fail for everyone.
Possibly a version issue (maybe some versions of JAX recognise TF types as `jnp` arrays?).
`get_dataset` is shown below, with the lines that fix the error commented out.
For anyone reading: I'm using 0.3.21 with CUDA (not TPU).