-
Hi. I'm trying to do what I believe is a rather trivial parallelization. I have an fn that accepts four arguments: configs, an rng key, a 2D dataset, and the network params.
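Roughly this shape (a toy sketch; the names and body here are made-up stand-ins for the real thing):

```python
import jax
import jax.numpy as jnp

def run(config, rng, dataset, params):
    # Stand-in body; the real fn is much more involved.
    noise = jax.random.normal(rng, dataset.shape)
    return jnp.dot(dataset + noise, params) * config
```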
I'm trying to transform the fn so that it shards the configs to individual devices, replicates the rng and the 2D dataset, and shards the network params; I also need to vmap the configs and network params across the batch dimension. For the life of me, I can't figure out how. I think part of the problem is the sheer number of available parallelization APIs, with overlapping but slightly different syntax, and my difficulty understanding the docs (not sure if it's me, or whether the docs could be improved, particularly the experimental sections). Anyone providing a basic example I can follow would be a hero! I should note that the actual fn is somewhat complicated, with multiple nested …
-
Have you checked this doc out: https://jax.readthedocs.io/en/latest/sharded-computation.html?
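If it helps, here's an untested sketch of the pattern I think you're after: `jax.vmap` over the batch dimension, then `jax.jit` with `in_shardings` to shard the configs/params across devices and replicate the rng and dataset. Shapes and names below are made up:

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

def run(config, rng, dataset, params):
    # Toy stand-in for your actual fn.
    noise = jax.random.normal(rng, dataset.shape)
    return jnp.dot(dataset + noise, params) * config

# vmap configs and params over the leading (batch) axis;
# rng and the 2D dataset are broadcast to every batch element.
batched = jax.vmap(run, in_axes=(0, None, None, 0))

# One mesh axis spanning all devices.
mesh = Mesh(np.array(jax.devices()), axis_names=("devices",))
sharded = NamedSharding(mesh, P("devices"))   # split axis 0 across devices
replicated = NamedSharding(mesh, P())         # full copy on every device

run_parallel = jax.jit(
    batched,
    in_shardings=(sharded, replicated, replicated, sharded),
)

n = jax.device_count()
configs = jnp.arange(n, dtype=jnp.float32)  # one config per device
rng = jax.random.PRNGKey(0)                 # replicated
dataset = jnp.ones((8, 4))                  # replicated 2D data
params = jnp.ones((n, 4, 2))                # per-batch-element params

out = run_parallel(configs, rng, dataset, params)
print(out.shape)  # (n, 8, 2), sharded along axis 0
```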
`Mesh` defines an n-d grid of devices that you can shard your arrays over. With pmap, you could only shard one dimension at a time, and to do multi-dimensional sharding you had to nest pmaps, which became very complicated.
With LLMs, you often shard multiple dimensions in different ways, which is why you need an abstraction that can express such a sharding in a straightforward way. This is where `Mesh` comes into play. `PartitionSpec` defines how an array should be sharded along the axes of the mesh. For example:
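A minimal sketch (this assumes 8 devices arranged in a 4x2 grid; adjust the shape to your hardware):

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# A 2D mesh: 8 devices arranged in a 4x2 grid with named axes.
devices = np.array(jax.devices()).reshape(4, 2)
mesh = Mesh(devices, axis_names=("data", "model"))

# Shard rows across the "data" axis and columns across the
# "model" axis -- both dimensions at once, no nested pmaps.
x = jnp.ones((16, 32))
x_sharded = jax.device_put(x, NamedSharding(mesh, P("data", "model")))

jax.debug.visualize_array_sharding(x_sharded)
```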