Add auto variable sharding for all backbones/tasks #1689

Open

mattdangerw opened this issue Jul 10, 2024 · 1 comment

Labels: Gemma (Gemma model specific issues), team-created (Issues created by Keras Hub team as part of development roadmap), type:feature (New feature or request)

mattdangerw (Member) commented Jul 10, 2024

We want model parallelism to be easy to use across the library. At a high level, a user should express their hardware and (optionally) the desired model-parallel vs. data-parallel split for the device grid.

Currently, we have an auto layout map helper for Gemma here, but it is not a scalable design. The correct layout map will depend on the config of the model. E.g. you need to shard a Gemma model with multi-head attention differently than one with multi-query attention.
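
For context, the existing Gemma-specific flow looks roughly like this (a sketch using the Keras 3 distribution API and the KerasNLP GemmaBackbone.get_layout_map helper as they exist at the time of writing; the mesh shape is just an example):

devices = keras.distribution.list_devices()
device_mesh = keras.distribution.DeviceMesh(
    shape=(1, 8), axis_names=("batch", "model"), devices=devices
)
layout_map = keras_nlp.models.GemmaBackbone.get_layout_map(
    device_mesh, model_parallel_dim_name="model"
)
distribution = keras.distribution.ModelParallel(
    device_mesh, layout_map, batch_dim_name="batch"
)
keras.distribution.set_distribution(distribution)
gemma_model = keras_nlp.models.GemmaCausalLM.from_preset("gemma_2b_en")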

I think there are two main directions we can go with the API.

  1. Write our own manual sharding recipe for each model, given the model's config. Do this for all models (most will share the same recipe, especially our transformer models).
  2. Use some form of autosharding functionality in JAX, or add an autosharding API to Keras. In this case, we would not need to write the sharding recipes ourselves per model.

One potential high-level API would be to directly take in a device mesh when constructing the model. For both 1) and 2), we could support an API something like this...

devices = keras.distribution.list_devices()
device_mesh = keras.distribution.DeviceMesh(
    shape=(2, 4), axis_names=("batch", "model"), devices=devices
)
gemma_model = keras_nlp.models.GemmaCausalLM.from_preset(
    "gemma_2b_en",
    device_mesh=device_mesh,  # proposed new argument
)

For 1) we would need to enter into a LayoutMap scope after loading the config for a model but before loading the weights. For 2) it would depend on the details of the autosharding API we use.
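
As a rough sketch of what 1) could look like internally (the loading helpers and the config-aware get_layout_map variant below are hypothetical, not an existing API), from_preset would build the layout map from the config and then create variables and load weights inside the distribution scope:

import keras

def from_preset_with_sharding(cls, preset, device_mesh):
    # Load only the config first, so the layout map can depend on it
    # (e.g. multi-head vs. multi-query attention).
    config = cls.load_preset_config(preset)  # hypothetical helper
    layout_map = cls.get_layout_map(device_mesh, config=config)  # hypothetical config-aware variant
    distribution = keras.distribution.ModelParallel(
        device_mesh, layout_map, batch_dim_name="batch"
    )
    # Variables created under the scope are placed according to the layout
    # map, so weights are sharded as they are loaded.
    with distribution.scope():
        model = cls.from_config(config)
        cls.load_preset_weights(model, preset)  # hypothetical helper
    return model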

@mattdangerw mattdangerw added the type:feature New feature or request label Jul 10, 2024
@github-actions github-actions bot added the Gemma Gemma model specific issues label Jul 10, 2024
mattdangerw (Member, Author) commented

We should also keep the docstring for the method on the Backbone base class, and factor out all the error checking somehow. That way the per-model code here could be really minimal.
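
As a hypothetical illustration (the base-class get_layout_map helper below is invented, and the variable-path patterns are abbreviated for brevity), the per-model override could then shrink to just the model-specific sharding rules:

from keras_nlp.models import Backbone

class GemmaBackbone(Backbone):
    @classmethod
    def get_layout_map(cls, device_mesh, model_parallel_dim_name="model"):
        # Hypothetical base-class method that carries the shared docstring,
        # validates the mesh and axis names, and returns an empty LayoutMap.
        layout_map = super().get_layout_map(device_mesh, model_parallel_dim_name)
        model_dim = model_parallel_dim_name
        # Only the Gemma-specific rules remain in the subclass.
        layout_map["token_embedding/embeddings"] = (model_dim, None)
        layout_map["decoder_block.*attention.*(query|key|value).kernel"] = (
            model_dim, None, None,
        )
        layout_map["decoder_block.*ffw_gating.*kernel"] = (None, model_dim)
        return layout_map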

@mattdangerw mattdangerw changed the title Add get_layout_map() for all backbones Add auto variable sharding for all backbones/tasks Sep 19, 2024
@sachinprasadhs sachinprasadhs added the team-created Issues created by Keras Hub team as part of development roadmap. label Nov 8, 2024