[NeMo-UX] Use custom BatchProgress class which does not restore states (NVIDIA#10383)

* [WIP] fix batch sampler to match megatron dataloaders

Signed-off-by: ashors1 <ashors@nvidia.com>

* make batchprogress configurable

Signed-off-by: ashors1 <ashors@nvidia.com>

* Apply isort and black reformatting

Signed-off-by: ashors1 <ashors1@users.noreply.github.com>

---------

Signed-off-by: ashors1 <ashors@nvidia.com>
Signed-off-by: ashors1 <ashors1@users.noreply.github.com>
Co-authored-by: ashors1 <ashors1@users.noreply.github.com>
Co-authored-by: Shriya Rishab <69161273+ShriyaPalsamudram@users.noreply.github.com>
Co-authored-by: Pablo Garay <palenq@gmail.com>
4 people authored Sep 15, 2024
1 parent 3a60491 commit b5798de
Showing 3 changed files with 25 additions and 0 deletions.
8 changes: 8 additions & 0 deletions nemo/lightning/pytorch/strategies/fsdp_strategy.py
@@ -35,6 +35,7 @@

from nemo.lightning import io
from nemo.lightning.pytorch.strategies.utils import (
_MegatronBatchProgress,
ckpt_to_dir,
create_checkpoint_io,
fix_progress_bar,
@@ -73,13 +74,15 @@ def __init__(
ckpt_load_optimizer: bool = True,
ckpt_save_optimizer: bool = True,
data_sampler=None,
overwrite_batch_progress: bool = True,
**kwargs,
):
super().__init__(auto_wrap_policy=auto_wrap_policy, state_dict_type=state_dict_type, **kwargs)

self.data_sampler = data_sampler
self.ckpt_load_optimizer = ckpt_load_optimizer
self.ckpt_save_optimizer = ckpt_save_optimizer
self.overwrite_batch_progress = overwrite_batch_progress

@override
def setup_environment(self) -> None:
@@ -92,6 +95,11 @@ def setup(self, trainer: pl.Trainer) -> None:
self.trainer = trainer
setup_data_sampler(self.trainer)
fix_progress_bar(trainer)

trainer_fn = trainer.state.fn
if trainer_fn == TrainerFn.FITTING and self.overwrite_batch_progress:
trainer.fit_loop.epoch_loop.batch_progress = _MegatronBatchProgress()

super().setup(trainer)

def _get_loss_reduction(self, step_type: str):
7 changes: 7 additions & 0 deletions nemo/lightning/pytorch/strategies/megatron_strategy.py
@@ -61,6 +61,7 @@
from nemo.lightning.pytorch.callbacks import ModelTransform
from nemo.lightning.pytorch.strategies.utils import (
RestoreConfig,
_MegatronBatchProgress,
ckpt_to_dir,
create_checkpoint_io,
fix_progress_bar,
@@ -152,6 +153,8 @@ class MegatronStrategy(DDPStrategy, io.IOMixin):
that prints the metrics to stdout. Suitable for non-interactive settings.
progress_interval (int): How frequently to print progress to stdout. Only used when
replace_progress_bar is True.
overwrite_batch_progress (bool): Whether to overwrite the default PTL _BatchProgress class with
_MegatronBatchProgress. This should be True whenever you are using a Megatron-based dataset.
**kwargs: Additional keyword arguments.
Note:
@@ -194,6 +197,7 @@ def __init__(
replace_progress_bar: bool = True,
progress_interval: int = 1,
restore_config: Optional[RestoreConfig] = None,
overwrite_batch_progress: bool = True,
**kwargs,
) -> None:
super().__init__(
@@ -234,6 +238,7 @@ def __init__(

self.replace_progress_bar = replace_progress_bar
self.progress_interval = progress_interval
self.overwrite_batch_progress = overwrite_batch_progress

self.restore_config = restore_config

@@ -331,6 +336,8 @@ def setup(self, trainer: pl.Trainer) -> None:
self.configure_ddp()

trainer.fit_loop.epoch_loop.automatic_optimization = _MegatronAutomaticOptimization(trainer)
if self.overwrite_batch_progress:
trainer.fit_loop.epoch_loop.batch_progress = _MegatronBatchProgress()

import torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook as post_localSGD

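For reference, a minimal usage sketch (not part of this commit) of the new flag on MegatronStrategy. The trainer arguments shown are assumptions, and every other strategy option is left at its default:

import pytorch_lightning as pl
from nemo.lightning.pytorch.strategies.megatron_strategy import MegatronStrategy

# Default (overwrite_batch_progress=True): PTL's _BatchProgress is replaced with
# _MegatronBatchProgress, so batch progress starts over when restoring from a
# checkpoint, as expected for Megatron-based datasets.
strategy = MegatronStrategy(overwrite_batch_progress=True)

# Opt out: keep PTL's stock _BatchProgress, which restores its state on resume.
# strategy = MegatronStrategy(overwrite_batch_progress=False)

trainer = pl.Trainer(strategy=strategy, devices=1, max_steps=100)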
10 changes: 10 additions & 0 deletions nemo/lightning/pytorch/strategies/utils.py
@@ -25,10 +25,12 @@
from megatron.core.dist_checkpointing.strategies.torch import sharded_tensor_to_torch_sharded_tensor
from megatron.core.transformer.utils import _get_extra_state_offsets
from pytorch_lightning.callbacks import TQDMProgressBar
from pytorch_lightning.loops.progress import _BatchProgress
from pytorch_lightning.plugins.io.wrapper import _WrappingCheckpointIO
from torch.distributed._sharded_tensor import ShardedTensor as TorchShardedTensor
from torch.distributed._tensor import DTensor, Replicate, Shard
from torch.distributed.device_mesh import DeviceMesh
from typing_extensions import override

from nemo.lightning import _strategy_lib
from nemo.lightning.io.pl import MegatronCheckpointIO
@@ -46,6 +48,14 @@ class RestoreConfig:
load_artifacts: bool = True


class _MegatronBatchProgress(_BatchProgress):
@override
def load_state_dict(self, state_dict: dict) -> None:
## in megatron, we want to start the batch progress over when
## restoring from a checkpoint
return


def setup_parallel_ranks(strategy: pl.strategies.Strategy):
from megatron.core.model_parallel_config import ModelParallelConfig

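A small sketch (again, not part of the commit) of what the no-op load_state_dict buys on restore. It assumes _BatchProgress exposes the increment_*/state_dict helpers found in recent PyTorch Lightning releases:

from pytorch_lightning.loops.progress import _BatchProgress

from nemo.lightning.pytorch.strategies.utils import _MegatronBatchProgress

# Simulate progress that a checkpoint would have stored.
checkpointed = _BatchProgress()
checkpointed.increment_ready()
checkpointed.increment_completed()
state = checkpointed.state_dict()

# Stock PTL progress resumes where the checkpoint left off.
resumed = _BatchProgress()
resumed.load_state_dict(state)
assert resumed.total.completed == 1

# The Megatron variant ignores the checkpointed state, so batch
# progress starts over after restoring from a checkpoint.
megatron = _MegatronBatchProgress()
megatron.load_state_dict(state)
assert megatron.total.completed == 0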
