SPMD multi core implementation #6382
Hey @mfatih7, it sounds like your use case is to group the devices in pairs and train data parallel across those groups, correct? This is achievable with SPMD. As an example for FSDP sharding within the groups, you can define your mesh like:

```python
# 8 devices, 4 replicas, 2-way FSDP within each replica.
mesh = xs.Mesh(range(8), (4, 2), ('replica', 'fsdp'))

# Shard the model parameters FSDP across the `fsdp` axis.
# The parameters will be replicated along all unspecified mesh axes
# (i.e. the `replica` axis in this case).
xs.mark_sharding(model.weight, mesh, ('fsdp', None))

# Shard the inputs' batch dimension along all mesh axes.
xs.mark_sharding(inputs, mesh, (('replica', 'fsdp'), None))
```
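For concreteness, here is a minimal end-to-end sketch of how that mesh could fit into a training step, assuming a recent torch_xla with the `torch_xla.distributed.spmd` API; `MyModel` and `loader` are placeholders, not names from this thread:

```python
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

xr.use_spmd()  # enable SPMD execution mode before creating XLA tensors

device = xm.xla_device()
num_devices = xr.global_runtime_device_count()  # 8 on a v3-8

# 4 replica groups, 2-way FSDP within each group (same layout as above).
mesh = xs.Mesh(np.array(range(num_devices)), (num_devices // 2, 2),
               ('replica', 'fsdp'))

model = MyModel().to(device)  # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

# Shard 2-D parameters along `fsdp`; unspecified axes (here `replica`) replicate.
for p in model.parameters():
    if p.dim() == 2:
        xs.mark_sharding(p, mesh, ('fsdp', None))

for inputs, labels in loader:  # placeholder data loader
    inputs, labels = inputs.to(device), labels.to(device)
    # Split the global batch across all 8 devices (replica * fsdp).
    xs.mark_sharding(inputs, mesh, (('replica', 'fsdp'), None))
    optimizer.zero_grad()
    loss = torch.nn.functional.cross_entropy(model(inputs), labels)
    loss.backward()
    optimizer.step()
    xm.mark_step()  # materialize the step on the XLA device
```

Because the batch dimension is sharded along both mesh axes, the global batch is split 8 ways, while each parameter shard lives on the 2 devices of its replica group.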
It would also be worth checking out the FSDPv2 wrapper if you just want to train a bigger model using all devices: #6379
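As a rough illustration of that FSDPv2 path, here is a hedged sketch based on the `SpmdFullyShardedDataParallel` wrapper in recent torch_xla releases (not code from this thread); `MyModel` is a placeholder:

```python
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs
from torch_xla.experimental.spmd_fully_sharded_data_parallel import (
    SpmdFullyShardedDataParallel as FSDPv2,
)

xr.use_spmd()

num_devices = xr.global_runtime_device_count()
# A 1-D `fsdp` mesh spreads the parameters across all devices.
mesh = xs.Mesh(np.array(range(num_devices)), (num_devices,), ('fsdp',))
xs.set_global_mesh(mesh)  # FSDPv2 picks up the global mesh by default

model = FSDPv2(MyModel().to(xm.xla_device()))  # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
```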
To fully run 4 copies of the model (2-way model parallel), you would need to shard your input along the `replica` axis.

I am closing this issue, @mfatih7.
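Referring back to that input-sharding suggestion, a hedged sketch reusing the `mesh` and `inputs` names from the earlier example:

```python
# With the same (4, 2) mesh, shard the batch dimension only along `replica`:
# each 2-device group then sees its own slice of the batch, while the inputs
# stay replicated inside the group where the model itself is parallelized.
xs.mark_sharding(inputs, mesh, ('replica', None))
```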
Hello
Without SPMD, we can train 8 duplicates of a model on the 8 TPU cores of the same v3 device and optimize the model weights concurrently.
With SPMD, is it possible to place 4 copies of a bigger model on the 8 cores of the same v3 device and train them concurrently?
If possible, is there any example?
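For contrast with the SPMD examples above, here is a minimal sketch of the non-SPMD, one-process-per-core setup the question refers to; `MyModel` and `train_loader` are placeholders:

```python
import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp

def _mp_fn(index):
    device = xm.xla_device()
    model = MyModel().to(device)  # placeholder model, one replica per core
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    loader = pl.MpDeviceLoader(train_loader, device)  # placeholder loader
    for inputs, labels in loader:
        optimizer.zero_grad()
        loss = torch.nn.functional.cross_entropy(model(inputs), labels)
        loss.backward()
        # xm.optimizer_step all-reduces gradients across the 8 replicas; use a
        # plain optimizer.step() instead if the copies should stay independent.
        xm.optimizer_step(optimizer)

if __name__ == '__main__':
    xmp.spawn(_mp_fn, args=())  # one process per TPU core
```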