So if I have some data and I want to shard it, which should I use? It seems the docs for … My big issue is that when slicing into arrays I don't seem to get the same result with both functions. Here's a small example: …
My question is why does … Additionally, it seems that …
You should use …
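To add more on the indexing question: that's just the semantics of jit vs pmap. When you use `PmapSharding`, the `__getitem__` portion goes a different way and gives you the device-local array. But slicing a `NamedSharding` array goes via the `jit` path and gives you a global array back. You can use `.addressable_shards` to get the shards of a global Array on each device.

Ideally, I would suggest moving away from pmap and using `shard_map` + `jax.jit` instead, since that is more in line with the new sharding semantics in JAX.

You can read https://jax.readthedocs.io/en/latest/sharded-computation.html, https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html and https://jax.readthedocs.io/en/latest/notebooks/shard_map.html for more information.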
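A minimal sketch of both points, assuming a single-process setup with a recent JAX (on a CPU-only machine it fakes 4 devices via `XLA_FLAGS`, which only works if set before JAX is first imported). Slicing the `NamedSharding` array returns a global `jax.Array`, while `.addressable_shards` exposes the per-device pieces. The `shard_map` + `jax.jit` part illustrates the suggested replacement for pmap. The import fallback is there because older JAX versions keep `shard_map` under `jax.experimental`; the function `double` is hypothetical.

```python
import os
# Assumption: fake 4 host devices on CPU so the sharding is visible.
# Has no effect if JAX was already imported or real accelerators are used.
os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=4")

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

try:
    from jax import shard_map  # newer JAX
except ImportError:
    from jax.experimental.shard_map import shard_map  # older JAX

n = jax.device_count()
mesh = Mesh(np.array(jax.devices()), ("x",))
sharding = NamedSharding(mesh, P("x"))  # split rows across the "x" mesh axis

host = np.arange(n * 4, dtype=np.float32).reshape(n, 4)
x = jax.device_put(host, sharding)  # one row per device

# Slicing goes through the jit path: the result is a global jax.Array,
# not a handle to device 0's local buffer.
row0 = x[0]
print(type(row0), row0.shape)

# For per-device pieces, iterate over the addressable shards instead.
for shard in x.addressable_shards:
    print(shard.device, shard.index, shard.data.shape)

# shard_map + jit: the inner function sees one per-device block of shape (1, 4).
@jax.jit
def double(a):
    return shard_map(lambda block: block * 2, mesh=mesh,
                     in_specs=P("x"), out_specs=P("x"))(a)

y = double(x)
print(y.sharding)
```

With this approach, the slicing behaviour no longer depends on how the array was created: every array is global, and device-local access is always explicit through `addressable_shards`.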
Hi yashk2810, thanks for the response! So if I'm understanding you correctly, even though the visualisation doesn't quite show that it's on device 0, that is just the semantics of … With your advice, I've got a working example with …
My use case is that some threads move the data around while other threads use it. I'd prefer certain threads to carry the overhead of moving the data so that they don't block the others, which is why I want to do the explicit …
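A minimal sketch of that producer/consumer split, assuming the explicit step meant here is a `jax.device_put` onto a `NamedSharding` (the original sentence is cut off, so this is a guess). `loader`, `consume`, and the batch contents are hypothetical. Calling `block_until_ready()` on the loader thread makes that thread wait for the transfer to finish before it hands the array to the consumer.

```python
import os
# Assumption: fake 4 host devices on CPU; ignored if JAX is already imported.
os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=4")

import queue
import threading

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

n = jax.device_count()
mesh = Mesh(np.array(jax.devices()), ("x",))
sharding = NamedSharding(mesh, P("x"))

# Bounded so the loader cannot stage too many batches ahead of the consumer.
ready = queue.Queue(maxsize=2)

def loader(num_batches):
    """Hypothetical producer thread: pays the host-to-device transfer cost."""
    for i in range(num_batches):
        host_batch = np.full((n * 2, 4), i, dtype=np.float32)
        dev_batch = jax.device_put(host_batch, sharding)
        dev_batch.block_until_ready()  # wait for the transfer on this thread
        ready.put(dev_batch)
    ready.put(None)  # sentinel: no more batches

@jax.jit
def consume(batch):
    """Hypothetical consumer computation on an already-sharded batch."""
    return jnp.sum(batch)

t = threading.Thread(target=loader, args=(3,))
t.start()
totals = []
while (batch := ready.get()) is not None:
    totals.append(float(consume(batch)))
t.join()
print(totals)
```

The main thread only ever receives arrays that are already resident on the devices, so its jitted work is not stalled behind host-to-device copies.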