
Exception running inference with MCore Distributed Checkpoint with different TP setting than training #8460

Closed
ryxli opened this issue Feb 20, 2024 · 7 comments

ryxli commented Feb 20, 2024

Describe the bug


I have an mcore distributed checkpoint trained with PP=1, TP=1. When I run inference from this checkpoint with TP set higher than 1, I get exceptions and intermittent hangs.

When running inference from the mcore distributed checkpoint with TP > 1, the following exception is raised:

Traceback (most recent call last):
  File "/workspace/src/3rdparty/NeMo/examples/nlp/language_modeling/megatron_gpt_eval.py", line 271, in main
    model = MegatronGPTModel.load_from_checkpoint(checkpoint_path, hparams_file=cfg.hparams_file, trainer=trainer, **kwargs)
  File "/workspace/src/3rdparty/NeMo/nemo/collections/nlp/models/nlp_model.py", line 397, in load_from_checkpoint
    checkpoint = dist_checkpointing.load(sharded_state_dict=checkpoint, checkpoint_dir=checkpoint_dir)
  File "/workspace/src/3rdparty/Megatron-LM/megatron/core/dist_checkpointing/serialization.py", line 92, in load
    validate_sharding_integrity(nested_values(sharded_state_dict))
  File "/workspace/src/3rdparty/Megatron-LM/megatron/core/dist_checkpointing/serialization.py", line 306, in validate_sharding_integrity
    _validate_sharding_for_key(shardings)
  File "/workspace/src/3rdparty/Megatron-LM/megatron/core/dist_checkpointing/serialization.py", line 344, in _validate_sharding_for_key
    raise CheckpointingException(f'Invalid access pattern for {rank_sharding[0][1]}')
megatron.core.dist_checkpointing.core.CheckpointingException: Invalid access pattern for ShardedTensor(key='model.embedding.word_embeddings.weight')



Error executing job with overrides: ['inference.greedy=True', 'inference.add_BOS=True', 'trainer.devices=8', 'trainer.num_nodes=1', 'tensor_model_parallel_size=8', 'pipeline_model_parallel_size=1']
Traceback (most recent call last):
  File "/workspace/src/3rdparty/NeMo/examples/nlp/language_modeling/megatron_gpt_eval.py", line 308, in main
    response = model.generate(
  File "/workspace/src/3rdparty/NeMo/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py", line 1414, in generate
    return megatron_gpt_generate(
  File "/workspace/src/3rdparty/NeMo/nemo/collections/nlp/modules/common/text_generation_utils.py", line 127, in megatron_gpt_generate
    output = generate(
  File "/workspace/src/3rdparty/NeMo/nemo/collections/nlp/modules/common/text_generation_utils.py", line 645, in generate
    output = synced_generate(
  File "/workspace/src/3rdparty/NeMo/nemo/collections/nlp/modules/common/text_generation_utils.py", line 510, in synced_generate
    for tokens, lengths, output_logits, full_logits in batch_token_iterator:
  File "/workspace/src/3rdparty/NeMo/nemo/collections/nlp/modules/common/text_generation_utils.py", line 888, in sample_sequence_batch
    torch.distributed.broadcast(done, src, group)
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/c10d_logger.py", line 72, in wrapper
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/distributed_c10d.py", line 1947, in broadcast
    work = group.broadcast([tensor], opts)
torch.distributed.DistNetworkError: Broken pipe

Steps/Code to reproduce bug


  1. Have an existing distributed mcore gpt checkpoint saved in a directory trained with TP=1, PP=1
  2. Pass checkpoint_dir, checkpoint_name into megatron_gpt_eval.py
  3. [Optional] In my case, I enabled activation checkpointing, so I also had to pass kwargs["activations_checkpoint_granularity"] = None and kwargs["activations_checkpoint_method"] = None to MegatronGPTModel.load_from_checkpoint (see the sketch after this list)
  4. Run inference with TP > 1 on a device with multiple gpus
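
For step 3, a minimal sketch of the kwargs override (names taken from the traceback above; the surrounding objects in examples/nlp/language_modeling/megatron_gpt_eval.py, such as trainer, cfg and checkpoint_path, are assumed to exist):

    from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel

    # Sketch only: clear the activation-checkpointing settings stored in the
    # training config so the model can be instantiated for inference.
    kwargs = {
        "activations_checkpoint_granularity": None,
        "activations_checkpoint_method": None,
    }
    model = MegatronGPTModel.load_from_checkpoint(
        checkpoint_path,                  # path to the mcore distributed checkpoint
        hparams_file=cfg.hparams_file,
        trainer=trainer,
        **kwargs,
    )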

This raises an exception every time; a percentage of runs are still able to complete, but most of the time the process ends up hanging.

Expected behavior


With mcore distributed checkpointing, I expect to be able to load an mcore model with a different model-parallel configuration, without any errors, using the example inference scripts.

Environment overview

  • Environment location: AWS p4de.24xlarge (A100)
  • Method of NeMo install: from source, using the NeMo r1.23.0 branch and its dependencies

Environment details


  • OS version - Ubuntu 22.04.3 LTS (pytorch 23.12 container)
  • PyTorch version - 2.2.0a0+81ea7a4
  • Python version - Python 3.10.12

Additional context

Attaching full logs in files.

tp8_error.log

ryxli added the bug label on Feb 20, 2024

ryxli commented Mar 20, 2024

@dimapihtar @ericharper
The same issue occurs when trying to load the distributed checkpoint for continued training / SFT.

Loading the distributed checkpoint on a single A100 works fine with gbs=1, tp=1, pp=1, mbs=1.
When scaling to 8 GPUs and changing gbs to 8, loading the checkpoint fails.

Has the team been able to reproduce this internally?
Currently using the following commits:
Megatron-LM ad53b1e38689a0ceed75ade7821f4e6c7554abb4
NeMo 9b64e39
TransformerEngine: da30634a6c9ccdbb6c587b6c93b1860e4b038204

[0]:    model = cls.load_from_checkpoint(checkpoint_path=checkpoint_path, trainer=trainer, hparams_file=f.name)
[0]:  File "/workspace/src/3rdparty/NeMo/build/__editable__.nemo_toolkit-1.23.0rc0-py3-none-any/nemo/collections/nlp/models/nlp_model.py", line 397, in load_from_checkpoint
[0]:    checkpoint = dist_checkpointing.load(sharded_state_dict=checkpoint, checkpoint_dir=checkpoint_dir)
[0]:  File "/workspace/src/3rdparty/Megatron-LM/build/__editable__.megatron_core-0.5.0rc0-cp310-cp310-linux_x86_64/megatron/core/dist_checkpointing/serialization.py", line 99, in load
[0]:    validate_sharding_integrity(nested_values(sharded_state_dict))
[0]:  File "/workspace/src/3rdparty/Megatron-LM/build/__editable__.megatron_core-0.5.0rc0-cp310-cp310-linux_x86_64/megatron/core/dist_checkpointing/serialization.py", line 346, in validate_sharding_integrity
[0]:    _validate_sharding_for_key(shardings)
[0]:  File "/workspace/src/3rdparty/Megatron-LM/build/__editable__.megatron_core-0.5.0rc0-cp310-cp310-linux_x86_64/megatron/core/dist_checkpointing/serialization.py", line 384, in _validate_sharding_for_key
[0]:    raise CheckpointingException(f'Invalid access pattern for {rank_sharding[0][1]}')
[0]:megatron.core.dist_checkpointing.core.CheckpointingException: Invalid access pattern for ShardedTensor(key='model.embedding.word_embeddings.weight')


ryxli commented Mar 20, 2024

Some debug logs as well; let me know if anything else would be useful to include:

> rank_sharding

[(0, ShardedTensor(key='model.embedding.word_embeddings.weight', data=None, dtype=torch.float32, local_shape=(50304, 768), global_shape=(50304, 768), global_offset=(0, 0), axis_fragmentations=(1, 1), replica_id=(0, 0, 0), prepend_axis_num=0, allow_shape_mismatch=True, flattened_range=None)), (1, ShardedTensor(key='model.embedding.word_embeddings.weight', data=None, dtype=torch.float32, local_shape=(50304, 768), global_shape=(50304, 768), global_offset=(0, 0), axis_fragmentations=(1, 1), replica_id=(0, 0, 0), prepend_axis_num=0, allow_shape_mismatch=True, flattened_range=None)), (2, ShardedTensor(key='model.embedding.word_embeddings.weight', data=None, dtype=torch.float32, local_shape=(50304, 768), global_shape=(50304, 768), global_offset=(0, 0), axis_fragmentations=(1, 1), replica_id=(0, 0, 0), prepend_axis_num=0, allow_shape_mismatch=True, flattened_range=None)), (3, ShardedTensor(key='model.embedding.word_embeddings.weight', data=None, dtype=torch.float32, local_shape=(50304, 768), global_shape=(50304, 768), global_offset=(0, 0), axis_fragmentations=(1, 1), replica_id=(0, 0, 0), prepend_axis_num=0, allow_shape_mismatch=True, flattened_range=None)), (4, ShardedTensor(key='model.embedding.word_embeddings.weight', data=None, dtype=torch.float32, local_shape=(50304, 768), global_shape=(50304, 768), global_offset=(0, 0), axis_fragmentations=(1, 1), replica_id=(0, 0, 0), prepend_axis_num=0, allow_shape_mismatch=True, flattened_range=None)), (5, ShardedTensor(key='model.embedding.word_embeddings.weight', data=None, dtype=torch.float32, local_shape=(50304, 768), global_shape=(50304, 768), global_offset=(0, 0), axis_fragmentations=(1, 1), replica_id=(0, 0, 0), prepend_axis_num=0, allow_shape_mismatch=True, flattened_range=None)), (6, ShardedTensor(key='model.embedding.word_embeddings.weight', data=None, dtype=torch.float32, local_shape=(50304, 768), global_shape=(50304, 768), global_offset=(0, 0), axis_fragmentations=(1, 1), replica_id=(0, 0, 0), prepend_axis_num=0, allow_shape_mismatch=True, flattened_range=None)), (7, ShardedTensor(key='model.embedding.word_embeddings.weight', data=None, dtype=torch.float32, local_shape=(50304, 768), global_shape=(50304, 768), global_offset=(0, 0), axis_fragmentations=(1, 1), replica_id=(0, 0, 0), prepend_axis_num=0, allow_shape_mismatch=True, flattened_range=None))]
> torch.zeros(rank_sharding[0][1].axis_fragmentations)
tensor([[0.]])

The assertion checks that each shard is accessed exactly once across all ranks:

        if not torch.all(shard_access_cnt == 1):
            logger.error(f'Invalid access pattern for {rank_sharding[0][1]}: {shard_access_cnt}')
            raise CheckpointingException(f'Invalid access pattern for {rank_sharding[0][1]}')

But this is the behavior I observe:

> shard_access_cnt
tensor([[8]], dtype=torch.int32)

> for rank, sharding in rank_sharding:
    print(sharding.replica_id)
(0, 0, 0)
(0, 0, 0)
(0, 0, 0)
(0, 0, 0)
(0, 0, 0)
(0, 0, 0)
(0, 0, 0)
(0, 0, 0)

Note that there is a TODO to also check the shard_access_cnt of replicas: https://github.com/NVIDIA/Megatron-LM/blob/0fecd76e995c136021d478c6c52caa57c2f9aa25/megatron/core/dist_checkpointing/serialization.py#L444C1-L447C59

But there should only be one replica per rank, which is the one above.

Since I'm setting tensor parallel to 1 here and gbs to 8, I don't think the embedding weights should be expected to be sharded: with TP=1 the tensor is fully replicated across the 8 ranks (a single data-parallel group of world size 8). I believe that during training this dist ckpt also used tp=1. So it should be fine for the embedding weights to be fully replicated across all of the ranks?
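
To illustrate the counting problem (a minimal, self-contained sketch of the behavior described above, not the actual Megatron-LM code): when every rank reports replica_id == (0, 0, 0), every rank is treated as the main replica of the same unsharded tensor, so the single shard is counted 8 times instead of once.

    import torch

    world_size = 8
    # What load_from_checkpoint currently produces: every rank claims to be
    # the main replica of the full embedding weight.
    observed_replica_ids = [(0, 0, 0)] * world_size
    # What I would expect with tp=1, dp=8: the last element encodes the
    # data-parallel rank, so only one rank is the main replica.
    expected_replica_ids = [(0, 0, rank) for rank in range(world_size)]

    def shard_access_cnt(replica_ids):
        # axis_fragmentations == (1, 1) -> a single shard covering the whole tensor
        cnt = torch.zeros((1, 1), dtype=torch.int32)
        for replica_id in replica_ids:
            if all(r == 0 for r in replica_id):  # count only "main" replicas
                cnt[0, 0] += 1
        return cnt

    print(shard_access_cnt(observed_replica_ids))  # tensor([[8]]) -> fails the check
    print(shard_access_cnt(expected_replica_ids))  # tensor([[1]]) -> passes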

Off topic:
I also notice that when loading from the distributed checkpoint, the saved optimizer states are loaded as well, which is probably expected behavior. Is there any reference or guidance on how to load without these saved optimizer states, or does it not matter if they get overwritten later on?


ryxli commented Mar 20, 2024

I seem to have figured out the root cause.

Background

When loading from a distributed checkpoint, we call NLPModel.load_from_checkpoint(..) (ref).
For distributed checkpoints, loading the state_dict is deferred until the class is initialized.

The current logic is as follows:

            if 'cfg' in kwargs:
                model = ptl_load_state(cls, checkpoint, strict=strict, **kwargs)
            else:
                model = ptl_load_state(cls, checkpoint, strict=strict, cfg=cfg, **kwargs)
                # cfg = checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY].cfg

            if checkpoint_dir is not None:
                sharded_state_dict = model.sharded_state_dict()
                checkpoint['state_dict'] = sharded_state_dict
                # dist checkpointing needs torch.distributed to load the checkpoint
                if parallel_state.is_unitialized():

                    def dummy():
                        return

                    if model.trainer.strategy.launcher is not None:
                        model.trainer.strategy.launcher.launch(dummy, trainer=model.trainer)
                    model.trainer.strategy.setup_environment()

For MegatronGPTModel (MCoreGPTModel), this instantiates LanguageModelEmbedding() if cfg.pre_process is true.

On the line sharded_state_dict = model.sharded_state_dict(), we call MegatronGPTModel.sharded_state_dict(...), which in turn calls VocabParallelEmbedding.sharded_state_dict(...):

    def sharded_state_dict(
        self, prefix: str = '', sharded_offsets: Tuple[Tuple[int, int, int]] = ()
    ) -> ShardedStateDict:
        """ Non-default implementation for embeddings due to `allow_shape_mismatch` param """
        state_dict = self.state_dict(prefix='', keep_vars=True)

        weight_prefix = f'{prefix}weight'
        return {
            weight_prefix: make_tp_sharded_tensor_for_checkpoint(
                tensor=state_dict['weight'],
                key=weight_prefix,
                allow_shape_mismatch=True,
                prepend_offsets=sharded_offsets,
            )
        }

make_tp_sharded_tensor_for_checkpoint:

def make_tp_sharded_tensor_for_checkpoint(
    tensor, key, tp_axis=0, replica_id=None, prepend_offsets=(), **kwargs
):
    prepend_axis_num = len(prepend_offsets)
    if replica_id is None:
        replica_id = (0, 0, parallel_state.get_data_parallel_rank(with_context_parallel=True))
    ...

The Issue

At this point parallel_state is not yet available, since torch.distributed.is_initialized() == False, so any data-parallel rank query returns 0. This incorrectly sets every replica_id to (0, 0, 0), which in turn causes the assertion mentioned above to fail when validating the sharded tensors.
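
To make this concrete, a tiny sketch of what replica_id evaluates to on each of the 8 processes (the fall-back-to-0 behavior is my interpretation of the observed replica_ids, not a quote of the Megatron-LM implementation):

    # Hypothetical stand-in for parallel_state.get_data_parallel_rank():
    # before setup_environment() runs, the query effectively falls back to 0.
    def data_parallel_rank(dist_initialized: bool, real_dp_rank: int) -> int:
        return real_dp_rank if dist_initialized else 0

    for real_dp_rank in range(8):
        before = (0, 0, data_parallel_rank(False, real_dp_rank))  # (0, 0, 0) on every rank
        after = (0, 0, data_parallel_rank(True, real_dp_rank))    # (0, 0, 0) ... (0, 0, 7)
        print(before, after)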

Potential Resolution

It seems like an easy fix would be to move the initialization logic above the model.sharded_state_dict() call (the current code is shown below, followed by a sketch of the reordering), but I'm not informed enough to know what the downstream impact would be.

            if checkpoint_dir is not None:
                sharded_state_dict = model.sharded_state_dict()
                checkpoint['state_dict'] = sharded_state_dict
                # dist checkpointing needs torch.distributed to load the checkpoint
                if parallel_state.is_unitialized():

                    def dummy():
                        return

                    if model.trainer.strategy.launcher is not None:
                        model.trainer.strategy.launcher.launch(dummy, trainer=model.trainer)
                    model.trainer.strategy.setup_environment()
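
For reference, a sketch of the reordering I have in mind (same code as above with the torch.distributed setup moved ahead of the sharded_state_dict() call; untested):

            if checkpoint_dir is not None:
                # Set up torch.distributed / parallel_state FIRST, so that
                # get_data_parallel_rank() returns the real rank when the
                # sharded state dict (and its replica_ids) is built below.
                if parallel_state.is_unitialized():

                    def dummy():
                        return

                    if model.trainer.strategy.launcher is not None:
                        model.trainer.strategy.launcher.launch(dummy, trainer=model.trainer)
                    model.trainer.strategy.setup_environment()

                # Only now build the sharded state dict for dist checkpointing
                sharded_state_dict = model.sharded_state_dict()
                checkpoint['state_dict'] = sharded_state_dict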

Can you please let me know if this is a correct understanding?

@dimapihtar @ericharper

github-actions bot commented Apr 20, 2024

This issue is stale because it has been open for 30 days with no activity. Remove stale label or comment or this will be closed in 7 days.

github-actions bot added the stale label on Apr 20, 2024

ryxli commented Apr 20, 2024

Any updates on this issue?

github-actions bot removed the stale label on Apr 21, 2024
github-actions bot commented May 21, 2024

This issue is stale because it has been open for 30 days with no activity. Remove stale label or comment or this will be closed in 7 days.

github-actions bot added the stale label on May 21, 2024
github-actions bot commented May 28, 2024

This issue was closed because it has been inactive for 7 days since being marked as stale.

github-actions bot closed this as not planned (won't fix, can't repro, duplicate, stale) on May 28, 2024