
Distributed optimizer support for multiple dtypes #1721

Merged

Conversation

timmoon10 (Contributor) commented:

This PR adds logic so that parameters can be configured with different dtypes for the grad reduce-scatters and param all-gathers. I have two NeMo use-cases in mind (a configuration sketch follows the list):

  • For GPT, most grads can be reduced in BF16 but embedding grads need to be reduced in FP32 to avoid learning issues.
  • For FP8 support, weight matrices can be stored in FP8 while most other parameters (e.g. biases, layernorm params, embeddings) are in BF16. We would like to handle FP8 and BF16 param all-gathers in the same optimizer.
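
As a rough illustration of the use-cases above, here is a minimal sketch of how a per-parameter-group dtype configuration might look with apex's DistributedFusedAdam. The per-group keys `grad_sync_dtype` and `param_sync_dtype` (and their acceptance inside param groups) are assumptions inferred from this PR's description rather than a confirmed API, and the toy model, learning rate, and launch setup are illustrative only.

```python
# Sketch only: assumes per-group "grad_sync_dtype"/"param_sync_dtype" overrides
# as described in this PR; the exact option names may differ in the merged API.
import torch
import torch.distributed as dist
from apex.contrib.optimizers.distributed_fused_adam import DistributedFusedAdam

# Requires a distributed launcher, e.g.
#   torchrun --nproc_per_node=<N> this_script.py
dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

# Toy model: an embedding plus a linear layer, on GPU for NCCL comms.
model = torch.nn.Sequential(
    torch.nn.Embedding(1024, 256),
    torch.nn.Linear(256, 256),
).cuda()

embedding_params = list(model[0].parameters())
other_params = list(model[1].parameters())

optimizer = DistributedFusedAdam(
    [
        # Most params: BF16 grad reduce-scatter and BF16 param all-gather.
        {
            "params": other_params,
            "grad_sync_dtype": torch.bfloat16,
            "param_sync_dtype": torch.bfloat16,
        },
        # Embedding grads: reduce in FP32 to avoid learning issues.
        {
            "params": embedding_params,
            "grad_sync_dtype": torch.float32,
        },
    ],
    lr=1e-4,
)
```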

This also includes changes from #1719, which returns the state dict on all ranks and not just rank 0. We can either merge that first and rebase, or merge this and close #1719.

Rough draft.

…checkpoint-allgather
Handle case where we load old checkpoints without multi-dtype support
Signed-off-by: Tim Moon <tmoon@nvidia.com>
crcrpar (Collaborator) left a comment:

lgtm

Review thread on apex/contrib/test/optimizers/test_dist_adam.py (outdated, resolved)
@crcrpar crcrpar merged commit 52e18c8 into NVIDIA:master Sep 6, 2023
@timmoon10 timmoon10 deleted the distopt-multi-dtype-checkpoint-allgather branch September 11, 2023 19:47