Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Distributed checkpointing with mcore GPT #7116

Merged
merged 121 commits into from
Aug 28, 2023
Merged

Distributed checkpointing with mcore GPT #7116

merged 121 commits into from
Aug 28, 2023

Conversation

ericharper
Copy link
Collaborator

@ericharper ericharper commented Jul 27, 2023

This PR needs mcore dist ckpt for GPT PR to be pushed before merging.

What does this PR do ?

Adds distributed checkpointing when using mcore gpt.

Distributed checkpointing enables training runs to restart automatically with different model parallel configs.
The checkpoint is saved to disk according to the sharded_state_dict:

Below is a sample of what the checkpoint looks like on disk.

common.pt                                                     model.decoder.layers.self_attention.linear_qkv.weight                           optimizer.state.exp_avg.model.embedding.word_embeddings.weight                     optimizer.state.fp32_from_fp16.model.decoder.final_layernorm.bias
metadata.json                                                 model.embedding.position_embeddings.weight                                      optimizer.state.exp_avg.model.output_layer.weight                                  optimizer.state.fp32_from_fp16.model.decoder.final_layernorm.weight
model.decoder.final_layernorm.bias                            model.embedding.word_embeddings.weight                                          optimizer.state.exp_avg_sq.model.decoder.final_layernorm.bias                      optimizer.state.fp32_from_fp16.model.decoder.layers.input_layernorm.bias
model.decoder.final_layernorm.weight                          model.output_layer.weight                                                       optimizer.state.exp_avg_sq.model.decoder.final_layernorm.weight                    optimizer.state.fp32_from_fp16.model.decoder.layers.input_layernorm.weight
model.decoder.layers.input_layernorm.bias                     optimizer.state.exp_avg.model.decoder.final_layernorm.bias                      optimizer.state.exp_avg_sq.model.decoder.layers.input_layernorm.bias               optimizer.state.fp32_from_fp16.model.decoder.layers.mlp.linear_fc1.bias
model.decoder.layers.input_layernorm.weight                   optimizer.state.exp_avg.model.decoder.final_layernorm.weight                    optimizer.state.exp_avg_sq.model.decoder.layers.input_layernorm.weight             optimizer.state.fp32_from_fp16.model.decoder.layers.mlp.linear_fc1.weight
model.decoder.layers.mlp.linear_fc1.bias                      optimizer.state.exp_avg.model.decoder.layers.input_layernorm.bias               optimizer.state.exp_avg_sq.model.decoder.layers.mlp.linear_fc1.bias                optimizer.state.fp32_from_fp16.model.decoder.layers.mlp.linear_fc2.bias
model.decoder.layers.mlp.linear_fc1._extra_state              optimizer.state.exp_avg.model.decoder.layers.input_layernorm.weight             optimizer.state.exp_avg_sq.model.decoder.layers.mlp.linear_fc1.weight              optimizer.state.fp32_from_fp16.model.decoder.layers.mlp.linear_fc2.weight
model.decoder.layers.mlp.linear_fc1.weight                    optimizer.state.exp_avg.model.decoder.layers.mlp.linear_fc1.bias                optimizer.state.exp_avg_sq.model.decoder.layers.mlp.linear_fc2.bias                optimizer.state.fp32_from_fp16.model.decoder.layers.post_self_attn_layernorm.bias
model.decoder.layers.mlp.linear_fc2.bias                      optimizer.state.exp_avg.model.decoder.layers.mlp.linear_fc1.weight              optimizer.state.exp_avg_sq.model.decoder.layers.mlp.linear_fc2.weight              optimizer.state.fp32_from_fp16.model.decoder.layers.post_self_attn_layernorm.weight
model.decoder.layers.mlp.linear_fc2._extra_state              optimizer.state.exp_avg.model.decoder.layers.mlp.linear_fc2.bias                optimizer.state.exp_avg_sq.model.decoder.layers.post_self_attn_layernorm.bias      optimizer.state.fp32_from_fp16.model.decoder.layers.self_attention.linear_proj.bias
model.decoder.layers.mlp.linear_fc2.weight                    optimizer.state.exp_avg.model.decoder.layers.mlp.linear_fc2.weight              optimizer.state.exp_avg_sq.model.decoder.layers.post_self_attn_layernorm.weight    optimizer.state.fp32_from_fp16.model.decoder.layers.self_attention.linear_proj.weight
model.decoder.layers.post_self_attn_layernorm.bias            optimizer.state.exp_avg.model.decoder.layers.post_self_attn_layernorm.bias      optimizer.state.exp_avg_sq.model.decoder.layers.self_attention.linear_proj.bias    optimizer.state.fp32_from_fp16.model.decoder.layers.self_attention.linear_qkv.bias
model.decoder.layers.post_self_attn_layernorm.weight          optimizer.state.exp_avg.model.decoder.layers.post_self_attn_layernorm.weight    optimizer.state.exp_avg_sq.model.decoder.layers.self_attention.linear_proj.weight  optimizer.state.fp32_from_fp16.model.decoder.layers.self_attention.linear_qkv.weight
model.decoder.layers.self_attention.linear_proj.bias          optimizer.state.exp_avg.model.decoder.layers.self_attention.linear_proj.bias    optimizer.state.exp_avg_sq.model.decoder.layers.self_attention.linear_qkv.bias     optimizer.state.fp32_from_fp16.model.embedding.position_embeddings.weight
model.decoder.layers.self_attention.linear_proj._extra_state  optimizer.state.exp_avg.model.decoder.layers.self_attention.linear_proj.weight  optimizer.state.exp_avg_sq.model.decoder.layers.self_attention.linear_qkv.weight   optimizer.state.fp32_from_fp16.model.embedding.word_embeddings.weight
model.decoder.layers.self_attention.linear_proj.weight        optimizer.state.exp_avg.model.decoder.layers.self_attention.linear_qkv.bias     optimizer.state.exp_avg_sq.model.embedding.position_embeddings.weight              optimizer.state.fp32_from_fp16.model.output_layer.weight
model.decoder.layers.self_attention.linear_qkv.bias           optimizer.state.exp_avg.model.decoder.layers.self_attention.linear_qkv.weight   optimizer.state.exp_avg_sq.model.embedding.word_embeddings.weight
model.decoder.layers.self_attention.linear_qkv._extra_state   optimizer.state.exp_avg.model.embedding.position_embeddings.weight              optimizer.state.exp_avg_sq.model.output_layer.weight

Then inside a module directory we have the sharded tensor:

ls model.decoder.layers.mlp.linear_fc1.weight/
0.0.0  1.0.0  10.0.0  11.0.0  12.0.0  13.0.0  14.0.0  15.0.0  2.0.0  3.0.0  4.0.0  5.0.0  6.0.0  7.0.0  8.0.0  9.0.0

To implement distributed checkpointing for a model, the sharded_state_dict has to be defined.
This is done in megatron core so that in NeMo, if the module is from mcore, we only have to call module.sharded_state_dict().

Collection: NLP

Usage

Usage is automatic when using mcore:

model.mcore_gpt=True

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you add or update any necessary documentation?
  • Does the PR affect components that are optional to install? (Ex: Numba, Pynini, Apex etc)
    • Reviewer: Does the PR have correct import guards for all optional libraries?

PR Type:

  • New Feature
  • Bugfix
  • Documentation

If you haven't finished some of the above items you can still open "Draft" PR.

Who can review?

Anyone in the community is free to review the PR once the checks have passed.
Contributor guidelines contains specific people who can review PRs to various areas.

Additional Information

  • Related to # (issue)

ericharper and others added 30 commits June 7, 2023 12:00
Signed-off-by: ericharper <complex451@gmail.com>
Signed-off-by: ericharper <complex451@gmail.com>
Signed-off-by: ericharper <complex451@gmail.com>
Signed-off-by: ericharper <complex451@gmail.com>
Signed-off-by: ericharper <complex451@gmail.com>
Signed-off-by: ericharper <complex451@gmail.com>
Signed-off-by: ericharper <complex451@gmail.com>
Signed-off-by: ericharper <complex451@gmail.com>
Signed-off-by: ericharper <complex451@gmail.com>
Signed-off-by: ericharper <complex451@gmail.com>
Signed-off-by: ericharper <complex451@gmail.com>
Signed-off-by: ericharper <complex451@gmail.com>
Signed-off-by: ericharper <complex451@gmail.com>
Signed-off-by: ericharper <complex451@gmail.com>
Signed-off-by: ericharper <complex451@gmail.com>
Signed-off-by: ericharper <complex451@gmail.com>
Signed-off-by: ericharper <complex451@gmail.com>
Signed-off-by: ericharper <complex451@gmail.com>
Signed-off-by: ericharper <complex451@gmail.com>
Signed-off-by: ericharper <complex451@gmail.com>
Signed-off-by: ericharper <complex451@gmail.com>
@github-actions github-actions bot removed the CI label Aug 23, 2023
@github-actions github-actions bot added the CI label Aug 24, 2023
Signed-off-by: eharper <eharper@nvidia.com>
@ericharper ericharper marked this pull request as ready for review August 25, 2023 00:49
mikolajblaz and others added 3 commits August 25, 2023 17:56
* Integrate  new DistOpt state dict

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Change optimizer fp32_param key

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Eric Harper <complex451@gmail.com>
Signed-off-by: eharper <eharper@nvidia.com>
Jenkinsfile Show resolved Hide resolved
Comment on lines +28 to +32
from megatron.core.dist_checkpointing.optimizer import (
get_param_id_to_sharded_param_map,
make_sharded_optimizer_tensor,
optim_state_to_sharding_state,
)

Check notice

Code scanning / CodeQL

Unused import

Import of 'make_sharded_optimizer_tensor' is not used.
Copy link
Collaborator

@aklife97 aklife97 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thank you!

@ericharper ericharper merged commit d6357fd into main Aug 28, 2023
@ericharper ericharper deleted the mcore_gpt_dist_ckpt branch August 28, 2023 21:52
rohitrango pushed a commit to rohitrango/NeMo that referenced this pull request Jun 25, 2024
* start adding gpt from megatron core path

Signed-off-by: ericharper <complex451@gmail.com>

* set model parallel config

Signed-off-by: ericharper <complex451@gmail.com>

* use model parallel config object

Signed-off-by: ericharper <complex451@gmail.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update args

Signed-off-by: ericharper <complex451@gmail.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* set vp size to none if it is 1

Signed-off-by: ericharper <complex451@gmail.com>

* set vp size to none if it is 1

Signed-off-by: ericharper <complex451@gmail.com>

* add TransformerConfig

Signed-off-by: ericharper <complex451@gmail.com>

* start updating to TransformerConfig

Signed-off-by: ericharper <complex451@gmail.com>

* add todo

Signed-off-by: ericharper <complex451@gmail.com>

* revert to model parallel config

Signed-off-by: ericharper <complex451@gmail.com>

* add hidden_size to model_parallel_config

Signed-off-by: ericharper <complex451@gmail.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* remove imports

Signed-off-by: ericharper <complex451@gmail.com>

* revert

Signed-off-by: ericharper <complex451@gmail.com>

* remove import

Signed-off-by: ericharper <complex451@gmail.com>

* small clean up

Signed-off-by: ericharper <complex451@gmail.com>

* update hidden size in peft base model, add mcore commit to jenkins

Signed-off-by: ericharper <complex451@gmail.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update module args

Signed-off-by: ericharper <complex451@gmail.com>

* add config obj to flash attention tests

Signed-off-by: ericharper <complex451@gmail.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* remove args

Signed-off-by: ericharper <complex451@gmail.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* remove sequence parallel arg

Signed-off-by: ericharper <complex451@gmail.com>

* update args

Signed-off-by: ericharper <complex451@gmail.com>

* add config to self

Signed-off-by: ericharper <complex451@gmail.com>

* update args

Signed-off-by: ericharper <complex451@gmail.com>

* update args

Signed-off-by: ericharper <complex451@gmail.com>

* update args

Signed-off-by: ericharper <complex451@gmail.com>

* add config to test

Signed-off-by: ericharper <complex451@gmail.com>

* get hidden_size from config

Signed-off-by: ericharper <complex451@gmail.com>

* add try except

Signed-off-by: ericharper <complex451@gmail.com>

* use default

Signed-off-by: ericharper <complex451@gmail.com>

* update config with hidden size

Signed-off-by: ericharper <complex451@gmail.com>

* remove arg

Signed-off-by: ericharper <complex451@gmail.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* comment out jenkins test

Signed-off-by: ericharper <complex451@gmail.com>

* revert import

Signed-off-by: ericharper <complex451@gmail.com>

* build transformer config

Signed-off-by: ericharper <complex451@gmail.com>

* add model to provider func

Signed-off-by: ericharper <complex451@gmail.com>

* update forward and float16 wrapper

Signed-off-by: ericharper <complex451@gmail.com>

* instantiate model parallel config after init model parallel

Signed-off-by: ericharper <complex451@gmail.com>

* set virtual rank

Signed-off-by: ericharper <complex451@gmail.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add GQA config to megatron gpt model (NVIDIA#7096)

* Add GQA config in gpt config file

Signed-off-by: jasonwan <jasonwan@nvidia.com>

* Verify mcore is enabled when using GQA

Signed-off-by: jasonwan <jasonwan@nvidia.com>

---------

Signed-off-by: jasonwan <jasonwan@nvidia.com>

* revert

Signed-off-by: ericharper <complex451@gmail.com>

* update strategy and exp_manager

Signed-off-by: ericharper <complex451@gmail.com>

* update model checkpoint

Signed-off-by: ericharper <complex451@gmail.com>

* update megatron gpt model

Signed-off-by: ericharper <complex451@gmail.com>

* correct var

Signed-off-by: ericharper <complex451@gmail.com>

* check for mcore gpt and use gpt model list

Signed-off-by: ericharper <complex451@gmail.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* remove model prefix

Signed-off-by: ericharper <complex451@gmail.com>

* setup te tp groups

Signed-off-by: ericharper <complex451@gmail.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* revert

Signed-off-by: eharper <eharper@nvidia.com>

* revert

Signed-off-by: eharper <eharper@nvidia.com>

* add default

Signed-off-by: eharper <eharper@nvidia.com>

* add default

Signed-off-by: eharper <eharper@nvidia.com>

* revert

Signed-off-by: eharper <eharper@nvidia.com>

* update sharded state dict for interleaved

Signed-off-by: eharper <eharper@nvidia.com>

* update load for interleaved

Signed-off-by: eharper <eharper@nvidia.com>

* check sharded state dict is nonempty

Signed-off-by: eharper <eharper@nvidia.com>

* remove import

Signed-off-by: eharper <eharper@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* revert comment

Signed-off-by: eharper <eharper@nvidia.com>

* inject before checking legacy ckpt

Signed-off-by: eharper <eharper@nvidia.com>

* revert

Signed-off-by: eharper <eharper@nvidia.com>

* pop arg for now

Signed-off-by: eharper <eharper@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* revert jenkins change

Signed-off-by: eharper <eharper@nvidia.com>

* remove device state_dict

Signed-off-by: eharper <eharper@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* reduce batch size for max steps

Signed-off-by: eharper <eharper@nvidia.com>

* update megatron core commit

Signed-off-by: eharper <eharper@nvidia.com>

* Integrate dist ckpt with new DistOpt state dict v2 (NVIDIA#7281)

* Integrate  new DistOpt state dict

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Change optimizer fp32_param key

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Eric Harper <complex451@gmail.com>

* update apex commit

Signed-off-by: eharper <eharper@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: ericharper <complex451@gmail.com>
Signed-off-by: jasonwan <jasonwan@nvidia.com>
Signed-off-by: eharper <eharper@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jason Wang <jasonwan@nvidia.com>
Co-authored-by: mikolajblaz <mikolajblaz@users.noreply.github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CI core Changes to NeMo Core NLP
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants