Add auto variable sharding for all backbones/tasks #1689
Labels: Gemma (Gemma model specific issues), team-created (Issues created by the Keras Hub team as part of the development roadmap), type:feature (New feature or request)
We want model parallelism to be easy to use across the library. At a high level, a user should express their hardware and (possibly) a desired model-parallel vs. data-parallel split for the device grid.
Currently, we have an auto layout helper for Gemma here, but it is not a scalable design. The correct layout map will depend on the config of the model. E.g. you need to shard a Gemma model with multi-head attention differently than one with multi-query attention.
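To make that concrete, here is a purely illustrative, hypothetical helper (`kv_kernel_layout` is not a real function in the library) showing why one hard-coded layout cannot cover both cases. The key/value projection kernel has shape `(hidden_dim, num_key_value_heads, head_dim)`, so one possible convention of splitting attention heads across the "model" mesh axis only works when the config has enough KV heads:

```python
def kv_kernel_layout(num_key_value_heads, model_parallel_size):
    # Hypothetical helper, illustrative only.
    # KV kernel shape: (hidden_dim, num_key_value_heads, head_dim).
    if num_key_value_heads < model_parallel_size:
        # Multi-query / grouped-query attention: not enough KV heads to
        # split across the "model" mesh axis, so replicate the KV kernels.
        return (None, None, None)
    # Multi-head attention: shard across the attention-head axis.
    return (None, "model", None)
```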
I think there are two main directions we can go with the API: 1) keep writing layout maps ourselves, but generate them from the model config, or 2) rely on an autosharding API.
One potential high-level API would be to directly take in a device mesh when constructing the model. For both 1) and 2), we could support an API something like this...
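A minimal sketch of that surface, assuming the existing Keras 3 `keras.distribution` mesh utilities; the `device_mesh` argument to `from_preset` is the proposal here, not an existing parameter:

```python
import keras
import keras_nlp

# Describe the hardware as a 2x4 grid: 2-way data parallel ("batch"),
# 4-way model parallel ("model").
devices = keras.distribution.list_devices()
device_mesh = keras.distribution.DeviceMesh(
    shape=(2, 4),
    axis_names=["batch", "model"],
    devices=devices,
)

# Proposed API: the preset loader derives the correct layout map for this
# model's config and shards variables as the weights are loaded.
model = keras_nlp.models.GemmaCausalLM.from_preset(
    "gemma_7b_en",
    device_mesh=device_mesh,
)
```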
For 1) we would need to enter into a LayoutMap scope after loading the config for a model but before loading the weights. For 2) it would depend on the details of the autosharding API we use.
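For 1), a rough sketch of what the loader could do internally, assuming a recent Keras 3 where `keras.distribution.ModelParallel` picks up the device mesh from the `LayoutMap`. The helper names (`build_gemma_layout_map`, `load_sharded`) are hypothetical and the regex sharding rules are illustrative only:

```python
import keras

def build_gemma_layout_map(config, device_mesh):
    # Hypothetical helper: derive sharding rules from the model config
    # rather than hard-coding one map for all Gemma variants.
    layout_map = keras.distribution.LayoutMap(device_mesh)
    layout_map["token_embedding/embeddings"] = ("model", None)
    layout_map["decoder_block.*attention.*(query|key|value).kernel"] = (
        "model", None, None
    )
    layout_map["decoder_block.*attention_output.kernel"] = ("model", None, None)
    layout_map["decoder_block.*ffw_gating.*kernel"] = (None, "model")
    layout_map["decoder_block.*ffw_linear.kernel"] = ("model", None)
    return layout_map

def load_sharded(config, device_mesh, build_and_load_fn):
    # Read the config first, build a config-aware layout map, then create
    # variables and load weights inside the distribution scope.
    layout_map = build_gemma_layout_map(config, device_mesh)
    distribution = keras.distribution.ModelParallel(
        layout_map=layout_map, batch_dim_name="batch"
    )
    with distribution.scope():
        return build_and_load_fn(config)
```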