Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allows non-strict load with distributed checkpoints #9715

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions Dockerfile.ci
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ WORKDIR /workspace
# Install NeMo requirements
ARG TE_TAG=7d576ed25266a17a7b651f2c12e8498f67e0baea
ARG MODELOPT_VERSION=0.13.0
ARG MCORE_TAG=de1b7c223303f6ba21e0540f27361334116efcbc
ARG MCORE_TAG=c0164bcfd4f8213a10a6b1e47ef80721a68b4fb6
ARG APEX_TAG=810ffae374a2b9cb4b5c5e28eaeca7d7998fca0c
RUN \
--mount=type=bind,source=requirements,target=requirements \
Expand Down Expand Up @@ -69,14 +69,14 @@ git clone https://github.com/state-spaces/mamba.git && \
git checkout v2.0.3 && \
python setup.py install && \
cd .. && \
rm -rf mamba
rm -rf mamba

git clone https://github.com/Dao-AILab/causal-conv1d && \
cd causal-conv1d && \
git checkout v1.2.2.post1 && \
python setup.py install && \
cd .. && \
rm -rf causal-conv1d
rm -rf causal-conv1d

EOF

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@ model:
dist_ckpt_torch_dist_multiproc: 2 # number of extra processes per rank used during ckpt save with PyTorch distributed format
dist_ckpt_assume_constant_structure: False # set to True only if the state dict structure doesn't change within a single job. Allows caching some computation across checkpoint saves.
dist_ckpt_parallel_dist_opt: True # parallel save/load of a DistributedOptimizer. 'True' allows performant save and reshardable checkpoints. Set to 'False' only in order to minimize the number of checkpoint files.
dist_ckpt_load_strictness: null # defines checkpoint keys mismatch behavior (only during dist-ckpt load). Choices: assume_ok_unexpected (default - try loading without any check), log_all (log mismatches), raise_all (raise mismatches)

## Activation Checkpointing
# NeMo Megatron supports 'selective' activation checkpointing where only the memory intensive part of attention is checkpointed.
Expand Down
50 changes: 22 additions & 28 deletions nemo/utils/callbacks/dist_ckpt_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from abc import ABC, abstractmethod
from contextlib import contextmanager
from time import time
from typing import Any, Dict, Optional
from typing import Any, Dict, Optional, Union

import pytorch_lightning as pl
from lightning_fabric.plugins import CheckpointIO
Expand All @@ -44,6 +44,7 @@
FullyParallelSaveStrategyWrapper,
)
from megatron.core.dist_checkpointing.strategies.torch import TorchDistSaveShardedStrategy
from megatron.core.dist_checkpointing.validation import StrictHandling
from megatron.core.parallel_state import get_data_parallel_group

HAVE_MEGATRON_CORE = True
Expand Down Expand Up @@ -188,6 +189,9 @@ class DistributedCheckpointIO(AsyncCompatibleCheckpointIO):
load_directly_on_device (bool, optional): if True, loads the weights directly
on GPU. Has effect only for `zarr` based checkpoints (PyT Distributed
always loads on device). Defaults to True.
load_strictness (StrictHandling, optional): defines loading strictness.
If not None, overwrites the `strict` flag passed to `load_checkpoint`.
Defaults to None.
async_save (bool): whether to save asynchronously. Should be set to True if
this class will be wrapped with AsyncFinalizableCheckpointIO.
torch_dist_multiproc (int, optional): number of extra processes per rank
Expand All @@ -202,6 +206,7 @@ def __init__(
self,
save_ckpt_format: str,
load_directly_on_device: bool = True,
load_strictness: Optional['StrictHandling'] = None,
async_save: bool = False,
torch_dist_multiproc: Optional[int] = None,
assume_constant_structure: bool = False,
Expand All @@ -215,6 +220,7 @@ def __init__(

self.save_ckpt_format = save_ckpt_format
self.load_directly_on_device = load_directly_on_device
self.load_strictness = load_strictness
self.async_save = async_save
self.torch_dist_multiproc = torch_dist_multiproc
self.assume_constant_structure = assume_constant_structure
Expand All @@ -238,6 +244,7 @@ def from_config(cls, model_cfg: dict, async_save: bool = False):
return cls(
save_ckpt_format=model_cfg.get('dist_ckpt_format', 'zarr'),
load_directly_on_device=model_cfg.get('dist_ckpt_load_on_device', True),
load_strictness=model_cfg.get('dist_ckpt_load_strictness', None),
async_save=async_save,
torch_dist_multiproc=model_cfg.get('dist_ckpt_torch_dist_multiproc', None),
parallel_save=model_cfg.get('dist_ckpt_parallel_save', False),
Expand Down Expand Up @@ -275,7 +282,7 @@ def load_checkpoint(
path: _PATH,
map_location: Optional[Any] = None,
sharded_state_dict: Dict[str, Any] = None,
strict: Optional[bool] = True,
strict: Union[None, bool, 'StrictHandling'] = None,
validate_access_integrity: Optional[bool] = True,
) -> Dict[str, Any]:
"""Loads a distributed checkpoint.
Expand All @@ -287,6 +294,10 @@ def load_checkpoint(
defines the loading procedure for the distributed checkpoint.
Defaults to None to comply with the CheckpointIO interface,
but it's a required argument.
strict (bool, StrictHandling, optional): adjust load strictness. bool value
is translated to StrictHandling instance. Gets overwritten by
`self.load_strictness`. Defaults to None. If `self.load_strictness`
is also None, strict becomes StrictHandling.ASSUME_OK_UNEXPECTED.

Returns:
Dist[str, Any]: loaded checkpoint.
Expand All @@ -311,40 +322,23 @@ def load_checkpoint(
if sharded_strategy is not None:
logging.info(f'Using {sharded_strategy} dist-ckpt load strategy.')

if not strict:
sharded_state_dict = self.adjust_non_strict_load(path, sharded_state_dict)
if isinstance(strict, bool):
strict = StrictHandling.ASSUME_OK_UNEXPECTED if strict else StrictHandling.LOG_ALL
if self.load_strictness is not None:
# Overwrites function argument
strict = self.load_strictness
if strict is None:
# Default behavior
strict = StrictHandling.ASSUME_OK_UNEXPECTED

return dist_checkpointing.load(
sharded_state_dict=sharded_state_dict,
checkpoint_dir=path,
sharded_strategy=sharded_strategy,
validate_access_integrity=validate_access_integrity,
strict=strict,
)

def adjust_non_strict_load(self, path: _PATH, sharded_state_dict: Dict[str, Any]):
ckpt_sharded_metadata = dist_checkpointing.load_tensors_metadata(path)
loaded_keys = []
missing_keys = []
unexpected_keys = []

def should_remove_missing_sharded_base(x: Any):
if isinstance(x, ShardedBase):
if x.key in ckpt_sharded_metadata:
loaded_keys.append(x.key)
return False
else:
unexpected_keys.append(x.key)
return True
return False

_, sharded_state_dict = extract_matching_values(sharded_state_dict, should_remove_missing_sharded_base)
logging.info(f'The following keys are not in the checkpoint and will not be loaded: {unexpected_keys}')

# TODO: compute missing_keys by:
# 1. all_gather_object of loaded_keys
# 2. missing_keys = ckpt_sharded_metadata.keys() - loaded_keys
return sharded_state_dict

@_debug_time('DistributedCheckpointIO.remove_checkpoint')
def remove_checkpoint(self, path: _PATH) -> None:
"""Remove a distributed checkpoint.
Expand Down
Loading