
device_put vs device_put_sharded #24279

Answered by yashk2810
sash-a asked this question in Q&A

To add more on the indexing question:

That's just the semantics of jit vs. pmap. When you use PmapSharding, __getitem__ takes a different code path and gives you back the device-local array.
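A minimal sketch of the pmap side, assuming at least one local device. The doubling function and the shapes are illustrative only and not from the thread. The sharding type that gets printed may differ between JAX versions, because newer releases have been moving pmap onto the jit/shard_map machinery.

```python
import jax
import jax.numpy as jnp
import numpy as np

# pmap maps over the leading axis, so its size must equal the local device count.
n = jax.local_device_count()
inputs = jnp.arange(n * 3).reshape(n, 3)

# Hypothetical per-device computation, for illustration only.
y = jax.pmap(lambda a: a * 2)(inputs)

# Which sharding the output carries (PmapSharding on typical pmap setups).
print(type(y.sharding).__name__)

# With PmapSharding, indexing the leading axis hands back the block
# that lives on the corresponding device.
y0 = y[0]
print(y0, y0.devices())
```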

Slicing a NamedSharding array, on the other hand, goes through the jit path and gives you a global array back. To get the per-device shards of a global Array, use .addressable_shards.
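A sketch of the NamedSharding side, assuming a single host. The 1-D mesh, the axis name "x", and the array shape are illustrative choices, not taken from the thread. Indexing returns a slice of the global array, and addressable_shards exposes what each device actually holds.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# 1-D mesh over every available device (also works with a single CPU device).
mesh = Mesh(np.array(jax.devices()), axis_names=("x",))
n = len(jax.devices())

# Global (8n, 2) array whose rows are split across the "x" mesh axis.
x = jax.device_put(
    jnp.arange(8 * n * 2).reshape(8 * n, 2),
    NamedSharding(mesh, P("x")),
)

# Indexing goes through the jit path: x[0] is row 0 of the *global* array,
# returned as another global jax.Array rather than a per-device buffer.
first_row = x[0]
print(first_row.shape, first_row.sharding)

# Per-device pieces of the global array: one Shard per addressable device.
for shard in x.addressable_shards:
    print(shard.device, shard.index, shard.data.shape)
```

Each shard here holds an (8, 2) block, and shard.index gives the slice of the global array it covers.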

Ideally, I would suggest moving away from pmap and using shard_map + jax.jit instead, since that is more in line with the new sharding semantics in JAX.
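A rough sketch of the suggested shard_map + jax.jit pattern, replacing a pmap-style per-device reduction. It assumes a 1-D mesh named "x". The import fallback is there because shard_map has lived under jax.experimental in older releases and at the top level in newer ones. The sum reduction itself is only an illustration.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P

try:  # newer JAX exposes shard_map at the top level
    from jax import shard_map
except ImportError:  # older releases keep it under jax.experimental
    from jax.experimental.shard_map import shard_map

mesh = Mesh(np.array(jax.devices()), axis_names=("x",))
n = len(jax.devices())


def local_sum(block):
    # `block` is this device's slice of the input; reduce it locally,
    # then all-reduce across the "x" mesh axis (like psum inside pmap).
    return jax.lax.psum(jnp.sum(block, keepdims=True), "x")


@jax.jit
def global_sum(x):
    # in_specs splits x along "x"; out_specs=P() declares a replicated result.
    return shard_map(local_sum, mesh=mesh, in_specs=P("x"), out_specs=P())(x)


x = jnp.arange(8 * n, dtype=jnp.float32)
total = global_sum(x)
print(total)
```

The per-device body reads like a pmap function. However, the inputs and outputs are ordinary global jax.Arrays, so indexing them behaves like the NamedSharding case above.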

For more details, see: https://jax.readthedocs.io/en/latest/sharded-computation.html and https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelizati…

Replies: 3 comments 1 reply

Answer selected by sash-a
1 reply
@yashk2810

Category
Q&A
2 participants