
Commit

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci

Signed-off-by: Marc Romeyn <marcromeyn@gmail.com>
pre-commit-ci[bot] authored and marcromeyn committed Apr 22, 2024
1 parent 23f302e commit ec5b8c9
Showing 6 changed files with 126 additions and 182 deletions.
33 changes: 12 additions & 21 deletions nemo/io/pl.py
@@ -42,7 +42,7 @@ def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH, storage_optio
"""
from megatron.core import dist_checkpointing

if storage_options is not None:
raise TypeError(
"`Trainer.save_checkpoint(..., storage_options=...)` with `storage_options` arg"
@@ -54,16 +54,13 @@ def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH, storage_optio
if fs.isdir(checkpoint_dir) and dist_checkpointing.check_is_distributed_checkpoint(checkpoint_dir):
logging.info(f'Distributed checkpoint at path {checkpoint_dir} already exists, skipping saving')
return

fs.makedirs(checkpoint_dir, exist_ok=True)
dist_checkpointing.save(sharded_state_dict=checkpoint, checkpoint_dir=str(checkpoint_dir))

@override
def load_checkpoint(
self,
path: _PATH,
sharded_state_dict=None,
map_location: Optional[Callable] = None
self, path: _PATH, sharded_state_dict=None, map_location: Optional[Callable] = None
) -> Dict[str, Any]:
"""Loads checkpoint using :func:`torch.load`, with additional handling for ``fsspec`` remote loading of files.
@@ -80,26 +77,20 @@ def load_checkpoint(
"""
from megatron.core import dist_checkpointing

if map_location is not None:
raise ValueError(
"`map_location` argument is not supported for `MegatronCheckpointIO.load_checkpoint`."
)
raise ValueError("`map_location` argument is not supported for `MegatronCheckpointIO.load_checkpoint`.")

# Try to read the checkpoint at `path`. If not exist, do not restore checkpoint.
fs = get_filesystem(path)
if not fs.exists(path):
raise FileNotFoundError(f"Checkpoint file not found: {path}")
if not fs.isdir(path):
raise ValueError(
f"Distributed checkpoints should be a directory. Found: {path}."
)

raise ValueError(f"Distributed checkpoints should be a directory. Found: {path}.")

# return pl_load(path, map_location=map_location)

checkpoint = dist_checkpointing.load(
sharded_state_dict=sharded_state_dict, checkpoint_dir=str(path)
)
checkpoint = dist_checkpointing.load(sharded_state_dict=sharded_state_dict, checkpoint_dir=str(path))
checkpoint = _fix_tensors_device(checkpoint)

return checkpoint
@@ -122,7 +113,7 @@ def _fix_tensors_device(ckpt: Dict) -> Dict:
"""Ensure checkpoint tensors are on the correct device."""
assert torch.cuda.is_initialized(), (torch.cuda.is_available(), torch.cuda.is_initialized())
cur_dev = torch.device("cuda", index=torch.cuda.current_device())

from megatron.core.dist_checkpointing.dict_utils import dict_list_map_outplace

def _fix_device(t):
@@ -139,7 +130,7 @@ def ckpt_to_dir(filepath: Union[str, Path]) -> Path:
to be used as a directory for distributed checkpoints.
"""
filepath = Path(filepath)

if not filepath.suffix == ".ckpt":
filepath = filepath.with_suffix(filepath.suffix + ".ckpt")

@@ -167,10 +158,10 @@ def is_distributed_ckpt(path) -> bool:
"""
from megatron.core import dist_checkpointing

checkpoint_dir = ckpt_to_dir(path)
fs = get_filesystem(checkpoint_dir)
if fs.isdir(checkpoint_dir) and dist_checkpointing.check_is_distributed_checkpoint(checkpoint_dir):
return True

return False
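
The helpers above wrap megatron.core.dist_checkpointing, which stores a checkpoint as a directory of shards rather than a single file: ckpt_to_dir maps a .ckpt filepath to that directory, is_distributed_ckpt checks whether such a checkpoint already exists there, and MegatronCheckpointIO.load_checkpoint refuses plain files for the same reason. A minimal usage sketch follows; it is not part of this commit, and the import path is only inferred from the file path shown in this diff (nemo/io/pl.py).

# Hypothetical usage sketch, not part of this commit.
# The import path is inferred from the file location shown in the diff.
from nemo.io.pl import ckpt_to_dir, is_distributed_ckpt

path = "results/megatron_gpt--step=1000.ckpt"

# Directory that dist_checkpointing would use for this checkpoint path.
checkpoint_dir = ckpt_to_dir(path)
print(checkpoint_dir)

# True only if that directory already holds a megatron.core distributed checkpoint.
print(is_distributed_ckpt(path))

When loading, MegatronCheckpointIO.load_checkpoint(path, sharded_state_dict=...) expects path to resolve to such a directory and raises a ValueError for a plain file, as the hunk above shows.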
116 changes: 40 additions & 76 deletions nemo/lightning/_strategy_lib.py
@@ -2,21 +2,12 @@
import os
from collections import defaultdict
from contextlib import contextmanager
from typing import (
TYPE_CHECKING,
Any,
Dict,
Generator,
Optional,
Protocol,
TypeVar,
)
from typing import TYPE_CHECKING, Any, Dict, Generator, Optional, Protocol, TypeVar

import torch
from lightning.fabric.utilities.types import Optimizable
from torch import nn


NEMO_MEGATRON_MODEL_PARALLEL_APPSTATE_OVERRIDE = "NEMO_MEGATRON_MODEL_PARALLEL_APPSTATE_OVERRIDE"


@@ -30,12 +21,7 @@ def sharded_state_dict(self, prefix=""):


def init_parallel_ranks(
world_size: int,
global_rank: int,
local_rank: int,
parallel_config: "ModelParallelConfig",
seed=1234,
fp8=False,
world_size: int, global_rank: int, local_rank: int, parallel_config: "ModelParallelConfig", seed=1234, fp8=False,
) -> None:
"""
Initializes the parallel ranks for distributed training.
@@ -52,17 +38,13 @@ def init_parallel_ranks(
seed (int, optional): The seed for random number generation. Defaults to 1234.
fp8 (bool, optional): Whether to use fp8 precision for model parameters. Defaults to False.
"""
from nemo.collections.nlp.modules.common.megatron.megatron_init import (
initialize_model_parallel_for_nemo,
)
from nemo.collections.nlp.modules.common.megatron.megatron_init import initialize_model_parallel_for_nemo
from nemo.utils import AppState

app_state = AppState()

if os.environ.get(NEMO_MEGATRON_MODEL_PARALLEL_APPSTATE_OVERRIDE, "false").lower() == "true":
init_world_size = (
app_state.tensor_model_parallel_size * app_state.pipeline_model_parallel_size
)
init_world_size = app_state.tensor_model_parallel_size * app_state.pipeline_model_parallel_size
init_global_rank = app_state.global_rank
init_local_rank = app_state.local_rank
else:
@@ -78,9 +60,7 @@
pipeline_model_parallel_size=parallel_config.pipeline_model_parallel_size,
virtual_pipeline_model_parallel_size=parallel_config.virtual_pipeline_model_parallel_size,
seed=seed,
pipeline_model_parallel_split_rank=getattr(
parallel_config, "pipeline_model_parallel_split_rank", None
),
pipeline_model_parallel_split_rank=getattr(parallel_config, "pipeline_model_parallel_split_rank", None),
use_fp8=fp8,
init_mpi_proc_group=getattr(parallel_config, "ub_tp_comm_overlap", False),
# apex_transformer_log_level=self.cfg.get('apex_transformer_log_level', 30),
@@ -91,6 +71,7 @@ def init_model_parallel(model: Optional[nn.Module] = None) -> None:
"""Initializes Megatron-LM model parallel if using model parallelism."""
import torch.distributed
from megatron.core import parallel_state

from nemo.utils import AppState

app_state = AppState()
@@ -134,18 +115,18 @@ def init_model_parallel(model: Optional[nn.Module] = None) -> None:
if hasattr(child, "set_tensor_parallel_group"):
tp_group = parallel_state.get_tensor_model_parallel_group()
child.set_tensor_parallel_group(tp_group)


@contextmanager
def megatron_lazy_init_context(config) -> Generator[None, None, None]:
def monkey_patched(c):
return {"device": "meta"}

from megatron.core.transformer.custom_layers import transformer_engine as _te

original = _te._get_extra_te_kwargs # noqa: SLF001
_te._get_extra_te_kwargs = monkey_patched # noqa: SLF001
original = _te._get_extra_te_kwargs # noqa: SLF001
_te._get_extra_te_kwargs = monkey_patched # noqa: SLF001

_orig_perform_initialization = config.perform_initialization
_orig_use_cpu_initialization = config.use_cpu_initialization

@@ -154,11 +135,11 @@ def monkey_patched(c):

yield

_te._get_extra_te_kwargs = original # noqa: SLF001
_te._get_extra_te_kwargs = original # noqa: SLF001
config.perform_initialization = _orig_perform_initialization
config.use_cpu_initialization = _orig_use_cpu_initialization


@contextmanager
def megatron_cpu_init_context(config) -> Generator[None, None, None]:
_orig_use_cpu_initialization = config.use_cpu_initialization
@@ -168,7 +149,7 @@ def megatron_cpu_init_context(config) -> Generator[None, None, None]:
yield

config.use_cpu_initialization = _orig_use_cpu_initialization


ModelT = TypeVar("ModelT", bound=nn.Module)
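
The two context managers above temporarily rewrite initialization settings on a Megatron config object: megatron_lazy_init_context patches Transformer Engine's extra kwargs so layers land on the meta device, while megatron_cpu_init_context switches use_cpu_initialization on for the duration of the block; both restore the original config values on exit. A rough sketch of how megatron_cpu_init_context might be used follows; DummyConfig and build_module are placeholders, not names from this diff.

# Hypothetical sketch, not part of this commit. Any object exposing a
# `use_cpu_initialization` attribute can be passed; DummyConfig and build_module
# are placeholders.
from dataclasses import dataclass

from nemo.lightning._strategy_lib import megatron_cpu_init_context


@dataclass
class DummyConfig:
    perform_initialization: bool = True
    use_cpu_initialization: bool = False


def build_module(config):
    # Placeholder for the real Megatron module factory.
    return object()


config = DummyConfig()
with megatron_cpu_init_context(config):
    # Inside the block the config requests CPU initialization; the original flag
    # value is restored when the context exits.
    module = build_module(config)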

@@ -182,7 +163,7 @@ class GradScaler(torch.cuda.amp.GradScaler):

def __init__(
self,
init_scale=2.0**16,
init_scale=2.0 ** 16,
growth_factor=2.0,
backoff_factor=0.5,
growth_interval=2000,
@@ -208,17 +189,13 @@ def _unscale_grads_(self, optimizer, *args):

def _maybe_opt_step(self, optimizer, optimizer_state, *args, **kwargs):
from megatron.core import parallel_state

retval = None
found_inf = torch.cuda.FloatTensor(
[sum(v.item() for v in optimizer_state["found_inf_per_device"].values())]
)
found_inf = torch.cuda.FloatTensor([sum(v.item() for v in optimizer_state["found_inf_per_device"].values())])

# Update across all model parallel instances.
torch.distributed.all_reduce(
found_inf,
op=torch.distributed.ReduceOp.MAX,
group=parallel_state.get_model_parallel_group(),
found_inf, op=torch.distributed.ReduceOp.MAX, group=parallel_state.get_model_parallel_group(),
)

if found_inf.item() == 0:
@@ -236,7 +213,7 @@ def update(self, new_scale=None):
3. Apply hysteresis to grad scale update.
"""
from megatron.core import parallel_state

if not self._enabled:
return

@@ -248,8 +225,7 @@ def update(self, new_scale=None):
self._scale.fill_(new_scale) # type: ignore[union-attr]
else:
reason = (
"new_scale should be a float or a 1-element torch.cuda.FloatTensor with"
" requires_grad=False."
"new_scale should be a float or a 1-element torch.cuda.FloatTensor with" " requires_grad=False."
)
assert isinstance(new_scale, torch.cuda.FloatTensor), reason # type: ignore[attr-defined]
assert new_scale.numel() == 1, reason
@@ -270,19 +246,15 @@

# Update across all model parallel instances.
torch.distributed.all_reduce(
found_inf_combined,
op=torch.distributed.ReduceOp.MAX,
group=parallel_state.get_model_parallel_group(),
found_inf_combined, op=torch.distributed.ReduceOp.MAX, group=parallel_state.get_model_parallel_group(),
)

if len(found_infs) > 1:
for i in range(1, len(found_infs)):
found_inf = found_infs[i]
# Update across all model parallel instances.
torch.distributed.all_reduce(
found_inf,
op=torch.distributed.ReduceOp.MAX,
group=parallel_state.get_model_parallel_group(),
found_inf, op=torch.distributed.ReduceOp.MAX, group=parallel_state.get_model_parallel_group(),
)
found_inf_combined += found_inf

@@ -291,7 +263,7 @@ def update(self, new_scale=None):
if self._hysteresis_tracker <= 0:
# When hysteresis becomes zero, follow the native grad scale update rule.
# Increase scale and reset growth tracker
torch._amp_update_scale_( # noqa: SLF001
torch._amp_update_scale_( # noqa: SLF001
_scale,
_growth_tracker,
found_inf_combined,
@@ -306,7 +278,7 @@ def update(self, new_scale=None):
# When no inf found, follow the native grad scale update rule.
# Increment growth_tracker, update scale when growth tracker reaches the interval, and
# reset the hysteresis tracker.
torch._amp_update_scale_( # noqa: SLF001
torch._amp_update_scale_( # noqa: SLF001
_scale,
_growth_tracker,
found_inf_combined,
Expand All @@ -318,7 +290,7 @@ def update(self, new_scale=None):

# To prepare for next iteration, clear the data collected from optimizers this iteration.
self._per_optimizer_states = defaultdict(
torch.cuda.amp.grad_scaler._refresh_per_optimizer_state # noqa: SLF001
torch.cuda.amp.grad_scaler._refresh_per_optimizer_state # noqa: SLF001
)

def state_dict(self):
@@ -383,21 +355,19 @@ def enable_nvidia_optimizations() -> None:
# NVFUSER available starting with 21.11
if NVIDIA_TORCH_MAJOR >= 21 or (NVIDIA_TORCH_MAJOR == 21 and NVIDIA_TORCH_MINOR >= 11):
# NVFUSER
torch._C._jit_set_profiling_executor(True) # noqa: SLF001
torch._C._jit_set_profiling_mode(True) # noqa: SLF001
torch._C._jit_override_can_fuse_on_cpu(False) # noqa: SLF001
torch._C._jit_override_can_fuse_on_gpu(False) # noqa: SLF001
torch._C._jit_set_texpr_fuser_enabled(False) # noqa: SLF001
torch._C._jit_set_profiling_executor(True) # noqa: SLF001
torch._C._jit_set_profiling_mode(True) # noqa: SLF001
torch._C._jit_override_can_fuse_on_cpu(False) # noqa: SLF001
torch._C._jit_override_can_fuse_on_gpu(False) # noqa: SLF001
torch._C._jit_set_texpr_fuser_enabled(False) # noqa: SLF001
# torch._C._jit_set_nvfuser_enabled(True)
torch._C._debug_set_autodiff_subgraph_inlining(False) # noqa: SLF001
torch._C._debug_set_autodiff_subgraph_inlining(False) # noqa: SLF001
else:
# Not a Nvidia container. NVFUSER Dependency check is on users
pass


def optimizer_sharded_state_dict(
model: SharedStateDictProtocol, optimizer: Optimizable
) -> Dict[str, torch.Tensor]:
def optimizer_sharded_state_dict(model: SharedStateDictProtocol, optimizer: Optimizable) -> Dict[str, torch.Tensor]:
"""
Sharded state dictionary for an MainParamsOptimizerWrapper.
Used to save and load the optimizer state when training with distributed_checkpoint.
@@ -413,21 +383,20 @@ def optimizer_sharded_state_dict(
make_sharded_optimizer_tensor,
optim_state_to_sharding_state,
)

from nemo.core.optim import MainParamsOptimizerWrapper
from nemo.core.optim.optimizers import init_optimizer_states

model_sharded_state_dict = model.sharded_state_dict()

# remove _extra_state
model_sharded_state_dict = {
key: value
for key, value in model_sharded_state_dict.items()
if not key.endswith("_extra_state")
key: value for key, value in model_sharded_state_dict.items() if not key.endswith("_extra_state")
}

if hasattr(optimizer, "sharded_state_dict"):
return optimizer.sharded_state_dict(model_sharded_state_dict)

if not isinstance(optimizer, MainParamsOptimizerWrapper):
# Regular optimizer, e.g. Adam or FusedAdam
init_optimizer_states(optimizer)
@@ -447,9 +416,7 @@ def optimizer_sharded_state_dict(
)

# Convert fp32_from_fp16_params
assert len(optimizer_state_dict["fp32_from_fp16_params"]) == len(
optimizer_state_dict["optimizer"]["param_groups"]
)
assert len(optimizer_state_dict["fp32_from_fp16_params"]) == len(optimizer_state_dict["optimizer"]["param_groups"])

def get_safe(param_id):
try:
@@ -459,14 +426,11 @@ def get_safe(param_id):

optimizer_state_dict["fp32_from_fp16_params"] = [
[
make_sharded_optimizer_tensor(
get_safe(param_id), fp32_param, prefix="optimizer.state.fp32_param"
)
make_sharded_optimizer_tensor(get_safe(param_id), fp32_param, prefix="optimizer.state.fp32_param")
for param_id, fp32_param in zip(state_group["params"], fp32_group)
]
for fp32_group, state_group in zip(
optimizer_state_dict["fp32_from_fp16_params"],
optimizer_state_dict["optimizer"]["param_groups"],
optimizer_state_dict["fp32_from_fp16_params"], optimizer_state_dict["optimizer"]["param_groups"],
)
]

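
Combined with the dist_checkpointing calls in nemo/io/pl.py above, optimizer_sharded_state_dict lets the optimizer state be written into the same sharded checkpoint as the model weights. A rough sketch, not part of this commit: model and optimizer are assumed to exist already, and the "optimizer" key is an illustrative choice rather than something mandated by the code above.

# Hypothetical sketch, not part of this commit. `model` and `optimizer` are passed
# in by the caller; the "optimizer" key is an illustrative assumption.
import os

from megatron.core import dist_checkpointing

from nemo.lightning._strategy_lib import optimizer_sharded_state_dict


def save_with_optimizer(model, optimizer, checkpoint_dir: str) -> None:
    # Model tensors described as sharded tensors, as dist_checkpointing expects.
    sharded_state = model.sharded_state_dict()

    # Optimizer state in the same sharded format (built by the helper above).
    sharded_state["optimizer"] = optimizer_sharded_state_dict(model, optimizer)

    # Same call pattern as MegatronCheckpointIO.save_checkpoint in nemo/io/pl.py.
    os.makedirs(checkpoint_dir, exist_ok=True)
    dist_checkpointing.save(sharded_state_dict=sharded_state, checkpoint_dir=checkpoint_dir)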
[Diffs for the remaining 4 changed files are not shown here.]
