-
Notifications
You must be signed in to change notification settings - Fork 4.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Auto convert moe param groups (#5354)
When using frameworks like HF Accelerate with MoE models in HF there's an issue when DeepSpeed is creating the optimizer where we have no way to automatically create the compatible MoE param groups. This PR detects if no client optimizer is set and model_parameters are passed to DeepSpeed that they are either MoE compatible or makes them MoE compatible automatically. This was never an issue previously since (1) MoE hasn't really been tested outside MDS and (2) MDS manually converts the weight-decay param groups into being MoE compatible before deepspeed.initialize. The error that is triggered if the param groups are not MoE compatible is triggered here: https://github.com/microsoft/DeepSpeed/blob/cc897ecf15fdac5437fa4a2743154dc6c1749da4/deepspeed/runtime/zero/stage_1_and_2.py#L610-L612 Tagging @tohtana and @ykim362 to help review --------- Co-authored-by: Jeff Rasley <jeff.rasley@snowflake.com>
- Loading branch information
1 parent
731fd68
commit 42a8eaa
Showing
3 changed files
with
69 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters