[NeMo-UX] Adding GPTModel & MockDataModule #9011

Merged · 16 commits · Apr 27, 2024
Changes from 4 commits
51 changes: 51 additions & 0 deletions nemo/io/pl.py
@@ -42,7 +42,11 @@

"""
from megatron.core import dist_checkpointing

if storage_options is not None:
raise TypeError(
"`Trainer.save_checkpoint(..., storage_options=...)` with `storage_options` arg"
@@ -54,13 +58,24 @@
if fs.isdir(checkpoint_dir) and dist_checkpointing.check_is_distributed_checkpoint(checkpoint_dir):
logging.info(f'Distributed checkpoint at path {checkpoint_dir} already exists, skipping saving')
return

fs.makedirs(checkpoint_dir, exist_ok=True)
dist_checkpointing.save(sharded_state_dict=checkpoint, checkpoint_dir=str(checkpoint_dir))

@override
def load_checkpoint(
self, path: _PATH, sharded_state_dict=None, map_location: Optional[Callable] = None
) -> Dict[str, Any]:
"""Loads checkpoint using :func:`torch.load`, with additional handling for ``fsspec`` remote loading of files.

@@ -77,20 +92,40 @@

"""
from megatron.core import dist_checkpointing

if map_location is not None:
raise ValueError("`map_location` argument is not supported for `MegatronCheckpointIO.load_checkpoint`.")

# Try to read the checkpoint at `path`. If not exist, do not restore checkpoint.
fs = get_filesystem(path)
if not fs.exists(path):
raise FileNotFoundError(f"Checkpoint file not found: {path}")
if not fs.isdir(path):
raise ValueError(f"Distributed checkpoints should be a directory. Found: {path}.")

# return pl_load(path, map_location=map_location)

checkpoint = dist_checkpointing.load(sharded_state_dict=sharded_state_dict, checkpoint_dir=str(path))
checkpoint = _fix_tensors_device(checkpoint)

return checkpoint
@@ -113,7 +148,11 @@
"""Ensure checkpoint tensors are on the correct device."""
assert torch.cuda.is_initialized(), (torch.cuda.is_available(), torch.cuda.is_initialized())
cur_dev = torch.device("cuda", index=torch.cuda.current_device())

from megatron.core.dist_checkpointing.dict_utils import dict_list_map_outplace

def _fix_device(t):
@@ -130,7 +169,11 @@
to be used as a directory for distributed checkpoints.
"""
filepath = Path(filepath)

if not filepath.suffix == ".ckpt":
filepath = filepath.with_suffix(filepath.suffix + ".ckpt")

@@ -158,10 +201,18 @@

"""
from megatron.core import dist_checkpointing

checkpoint_dir = ckpt_to_dir(path)
fs = get_filesystem(checkpoint_dir)
if fs.isdir(checkpoint_dir) and dist_checkpointing.check_is_distributed_checkpoint(checkpoint_dir):
return True

return False
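
For context, a rough usage sketch of the checkpoint I/O this file implements. The class name MegatronCheckpointIO is taken from the error strings above; the constructor arguments and trainer wiring are assumptions, not part of this diff.

# Hypothetical sketch (not from this PR): using the Megatron checkpoint I/O so that
# checkpoints are written as sharded directories via megatron.core.dist_checkpointing.
from nemo.io.pl import MegatronCheckpointIO  # module path from this diff; class name assumed from the error strings

checkpoint_io = MegatronCheckpointIO()  # assumed to construct without required arguments

# save_checkpoint coerces the target path to a ".ckpt" directory (ckpt_to_dir) and calls
# dist_checkpointing.save(sharded_state_dict=checkpoint, checkpoint_dir=...):
# checkpoint_io.save_checkpoint(checkpoint, "results/step=1000.ckpt")

# load_checkpoint requires the path to be a distributed-checkpoint directory, rejects
# map_location, and returns the loaded (device-fixed) state dict:
# checkpoint = checkpoint_io.load_checkpoint("results/step=1000.ckpt", sharded_state_dict=template)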
24 changes: 24 additions & 0 deletions nemo/lightning/__init__.py
@@ -0,0 +1,24 @@
from typing import Union

from lightning.pytorch import plugins as _pl_plugins
from lightning_fabric.plugins.environments import slurm

from nemo.lightning.base import get_vocab_size, teardown
from nemo.lightning.pytorch.plugins import MegatronDataSampler
from nemo.lightning.pytorch.plugins import data_sampler as _data_sampler
from nemo.lightning.pytorch.strategies import MegatronStrategy


# We monkey-patch because NVIDIA uses a naming convention for its SLURM jobs
def _is_slurm_interactive_mode():
job_name = slurm.SLURMEnvironment.job_name()
return job_name is None or job_name.endswith("bash") or job_name.endswith("interactive")


slurm._is_slurm_interactive_mode = _is_slurm_interactive_mode # noqa: SLF001


_pl_plugins._PLUGIN_INPUT = Union[_pl_plugins._PLUGIN_INPUT, _data_sampler.DataSampler] # noqa: SLF001


__all__ = ["MegatronStrategy", "MegatronDataSampler", "get_vocab_size", "teardown"]
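
To illustrate the monkey patch above, a small sketch of how the patched predicate classifies job names. The names below are made-up examples, not values from this PR; the real function reads the name from SLURMEnvironment.job_name() rather than taking it as an argument.

# Mirrors the predicate in _is_slurm_interactive_mode, applied to hypothetical job names:
for name in [None, "bash", "dev_interactive", "gpt_pretrain_175b"]:
    interactive = name is None or name.endswith("bash") or name.endswith("interactive")
    print(f"{name!r} -> {'interactive' if interactive else 'batch'}")
# Expected: None, 'bash', and 'dev_interactive' count as interactive (the SLURM environment
# plugin is effectively bypassed); 'gpt_pretrain_175b' is treated as a real batch job.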
51 changes: 51 additions & 0 deletions nemo/lightning/base.py
@@ -0,0 +1,51 @@
import gc
import os
from pathlib import Path
from typing import Optional

import torch
import torch.distributed
from pytorch_lightning import Trainer
from torch import nn


NEMO_MEGATRON_MODEL_PARALLEL_APPSTATE_OVERRIDE = "NEMO_MEGATRON_MODEL_PARALLEL_APPSTATE_OVERRIDE"
DEFAULT_NEMO_CACHE_HOME = Path.home() / ".cache" / "nemo"
NEMO_CACHE_HOME = Path(os.getenv("NEMO_HOME", DEFAULT_NEMO_CACHE_HOME))
DEFAULT_NEMO_DATASETS_CACHE = NEMO_CACHE_HOME / "datasets"
NEMO_DATASETS_CACHE = Path(os.getenv("NEMO_DATASETS_CACHE", DEFAULT_NEMO_DATASETS_CACHE))


def get_vocab_size(config, vocab_size: int, make_vocab_size_divisible_by: int = 128) -> int:
from nemo.utils import logging

after = vocab_size
multiple = make_vocab_size_divisible_by * config.tensor_model_parallel_size
while (after % multiple) != 0:
after += 1
logging.info(
f"Padded vocab_size: {after}, original vocab_size: {vocab_size}, dummy tokens:" f" {after - vocab_size}."
)

return after


def teardown(trainer: Trainer, model: Optional[nn.Module] = None) -> None:
# Destroy torch distributed
if torch.distributed.is_initialized():
from megatron.core import mpu

mpu.destroy_model_parallel()
torch.distributed.destroy_process_group()

trainer._teardown() # noqa: SLF001
if model is not None:
for obj in gc.get_objects():
if torch.is_tensor(obj) and obj.is_cuda:
del obj

gc.collect()
torch.cuda.empty_cache()


__all__ = ["get_vocab_size", "teardown"]
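
A small worked example of the vocab-size padding above; the config object and numbers are hypothetical and exist only to show the arithmetic.

# Hypothetical config: tensor_model_parallel_size=2, so the padding multiple is 128 * 2 = 256.
class _DummyConfig:
    tensor_model_parallel_size = 2

# 50257 (GPT-2's vocabulary size) is padded up to the next multiple of 256:
padded = get_vocab_size(_DummyConfig(), vocab_size=50257)
assert padded == 50432  # 175 dummy tokens added

# teardown(trainer, model) is the shutdown counterpart: when torch.distributed is initialized
# it destroys the Megatron model-parallel groups and the process group, tears down the trainer,
# and, if a model is passed, deletes lingering CUDA tensors and empties the CUDA cache.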