
SPMD multi core implementation #6382

Closed
mfatih7 opened this issue Jan 25, 2024 · 3 comments

mfatih7 commented Jan 25, 2024

Hello

Without SPMD, we can train 8 duplicates of a model on 8 TPU cores on the same v3 device and optimize model weights concurrently.

With SPMD, is it possible to place 4 copies of a bigger model on 8 cores on the same v3 device and train them concurrently?

If so, is there an example?
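
For reference, the non-SPMD setup described in the question is typically the xmp.spawn data-parallel pattern; a minimal sketch (an editorial illustration, with MyModel and train_loader as placeholder names):

import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp

def _mp_fn(index):
    # Runs once per TPU core; each replica holds a full copy of the model.
    device = xm.xla_device()
    model = MyModel().to(device)  # placeholder model
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    loader = pl.MpDeviceLoader(train_loader, device)  # placeholder DataLoader
    for inputs, labels in loader:
        optimizer.zero_grad()
        loss = torch.nn.functional.cross_entropy(model(inputs), labels)
        loss.backward()
        xm.optimizer_step(optimizer)  # all-reduce gradients across the replicas

if __name__ == '__main__':
    xmp.spawn(_mp_fn)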

jonb377 (Collaborator) commented Feb 7, 2024

Hey @mfatih7, it sounds like your use case is to group the devices in pairs and train data parallel across those groups, correct?

This is achievable with SPMD. As an example of FSDP sharding within the groups, you can define your mesh like mesh = xs.Mesh(range(8), (4, 2), ('replica', 'fsdp')), shard your model parameters along the fsdp axis, and shard your inputs across all devices, e.g.

import torch_xla.distributed.spmd as xs  # SPMD sharding API

# 8 devices, 4 replicas, 2-way FSDP within each replica.
mesh = xs.Mesh(range(8), (4, 2), ('replica', 'fsdp'))

# Shard the model parameters along the `fsdp` axis.
# The parameters are replicated along all unspecified mesh axes (i.e. the `replica` axis in this case).
xs.mark_sharding(model.weight, mesh, ('fsdp', None))

# Shard the inputs' batch dimension along all mesh axes.
xs.mark_sharding(inputs, mesh, (('replica', 'fsdp'), None))
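
For completeness, a minimal end-to-end sketch (an editorial addition, not part of the original comment) of how that mesh could drive a training step, assuming torch_xla's SPMD API (xr.use_spmd, xs.Mesh, xs.mark_sharding); MyModel and loader are placeholder names:

import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

xr.use_spmd()  # enable SPMD execution mode

device = xm.xla_device()
num_devices = xr.global_runtime_device_count()  # 8 on a v3-8

# 4 replica groups, 2-way FSDP inside each group.
mesh = xs.Mesh(range(num_devices), (num_devices // 2, 2), ('replica', 'fsdp'))

model = MyModel().to(device)  # placeholder model
for p in model.parameters():
    if p.dim() >= 2:
        # Shard dim 0 along `fsdp`; replicate along the `replica` axis.
        xs.mark_sharding(p, mesh, ('fsdp',) + (None,) * (p.dim() - 1))

optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

for inputs, labels in loader:  # placeholder data loader
    inputs, labels = inputs.to(device), labels.to(device)
    # Shard the global batch across all 8 devices (replica x fsdp).
    xs.mark_sharding(inputs, mesh, (('replica', 'fsdp'), None))
    optimizer.zero_grad()
    loss = torch.nn.functional.cross_entropy(model(inputs), labels)
    loss.backward()
    optimizer.step()
    xm.mark_step()  # cut and execute the graph for this step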

jonb377 (Collaborator) commented Feb 7, 2024

It would also be worth checking out the FSDPv2 wrapper if you just want to train a bigger model using all devices: #6379
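
A minimal sketch of using that wrapper (an editorial illustration, assuming the SpmdFullyShardedDataParallel class added in #6379; constructor details may vary across torch_xla versions, and MyModel and inputs are placeholders):

import numpy as np
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs
from torch_xla.experimental.spmd_fully_sharded_data_parallel import (
    SpmdFullyShardedDataParallel as FSDPv2)

xr.use_spmd()

# FSDPv2 shards along a mesh axis named 'fsdp'; here a 1-D mesh over all devices.
num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.array(range(num_devices)), (num_devices,), ('fsdp',))

model = FSDPv2(MyModel().to(xm.xla_device()), mesh=mesh)  # placeholder model

# Shard each input batch (placeholder `inputs`) along the same axis, then train as usual.
xs.mark_sharding(inputs, mesh, ('fsdp', None))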

yeounoh self-assigned this Mar 16, 2024
yeounoh (Contributor) commented Mar 16, 2024

> Hey @mfatih7, it sounds like your use case is to group the devices in pairs and train data parallel across those groups, correct?
>
> This is achievable with SPMD. As an example of FSDP sharding within the groups, you can define your mesh like mesh = xs.Mesh(range(8), (4, 2), ('replica', 'fsdp')), shard your model parameters along the fsdp axis, and shard your inputs across all devices (see the code in the comment above).

To fully run 4 copies of the model (2-way model parallelism within each copy), you would need to shard your input along the replica axis (4-way data parallelism). Otherwise, the model (or the data) will be all-gathered during the computation. Note, though, regarding your question ("With SPMD, is it possible to place 4 copies of a bigger model on 8 cores on the same v3 device and train them concurrently?"): yes, but each group (pair of devices) will need an extra all-reduce within the group, and SPMD inserts that for you.
cc @jonb377
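
A one-line sketch of that variant (an editorial illustration, reusing the mesh and xs.mark_sharding call from the example above):

# Shard the batch only along the `replica` axis (4-way data parallel); the input
# is then replicated within each 2-device `fsdp` group, so each pair of devices
# trains its own 2-way-sharded copy of the model.
xs.mark_sharding(inputs, mesh, ('replica', None))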

I am closing this issue, @mfatih7

yeounoh closed this as completed Mar 16, 2024