From 5fa8dbc6ac93c1cd82b7dafd1dfb2cb899654979 Mon Sep 17 00:00:00 2001 From: Marc Romeijn Date: Thu, 3 Oct 2024 01:15:57 -0700 Subject: [PATCH 1/2] Moving steps to MegatronParallel to improve UX for Fabric Signed-off-by: Marc Romeijn --- nemo/lightning/megatron_parallel.py | 171 +++++++++++++++++- .../pytorch/strategies/megatron_strategy.py | 49 +---- 2 files changed, 176 insertions(+), 44 deletions(-) diff --git a/nemo/lightning/megatron_parallel.py b/nemo/lightning/megatron_parallel.py index 096c7728d4a1..6965530ead3f 100644 --- a/nemo/lightning/megatron_parallel.py +++ b/nemo/lightning/megatron_parallel.py @@ -26,6 +26,7 @@ Generic, Iterable, Iterator, + Mapping, List, Optional, Protocol, @@ -35,9 +36,9 @@ Union, cast, runtime_checkable, + TYPE_CHECKING, ) -import pytorch_lightning as pl import torch import torch.distributed from megatron.core import parallel_state @@ -51,6 +52,10 @@ DataT = TypeVar("DataT", Tensor, Dict[str, Tensor], Sequence[Tensor]) ModelT = TypeVar("ModelT", bound=nn.Module) T = TypeVar('T') +STEP_OUTPUT = Optional[Union[Tensor, Mapping[str, Any]]] + +if TYPE_CHECKING: + import pytorch_lightning as pl @runtime_checkable @@ -293,6 +298,134 @@ def forward( self.callbacks.event("on_megatron_step_end", step=step, microbatch_outputs=microbatch_outputs, reduced=reduced) return reduced + + def training_step( + self, + data: DataT, + data_step: Optional[Callable[[Iterator[DataT]], DataT]] = None, + forward_step: Optional[Callable[[ModelT, DataT], Tensor]] = None, + loss_reduction: Optional["MegatronLossReduction[DataT, Any]"] = None, + seq_length: Optional[int] = None, + micro_batch_size: Optional[int] = None, + num_microbatches: Optional[int] = None, + **kwargs + ) -> STEP_OUTPUT: + return self._step( + "training", + data, + data_step=data_step, + forward_step=forward_step, + loss_reduction=loss_reduction, + seq_length=seq_length, + micro_batch_size=micro_batch_size, + num_microbatches=num_microbatches, + forward_only=False, + **kwargs + ) + + def validation_step( + self, + data: DataT, + data_step: Optional[Callable[[Iterator[DataT]], DataT]] = None, + forward_step: Optional[Callable[[ModelT, DataT], Tensor]] = None, + loss_reduction: Optional["MegatronLossReduction[DataT, Any]"] = None, + seq_length: Optional[int] = None, + micro_batch_size: Optional[int] = None, + num_microbatches: Optional[int] = None, + **kwargs + ) -> STEP_OUTPUT: + return self._step( + "validation", + data, + data_step=data_step, + forward_step=forward_step, + loss_reduction=loss_reduction, + seq_length=seq_length, + micro_batch_size=micro_batch_size, + num_microbatches=num_microbatches, + forward_only=True, + **kwargs + ) + + def test_step( + self, + data: DataT, + data_step: Optional[Callable[[Iterator[DataT]], DataT]] = None, + forward_step: Optional[Callable[[ModelT, DataT], Tensor]] = None, + loss_reduction: Optional["MegatronLossReduction[DataT, Any]"] = None, + seq_length: Optional[int] = None, + micro_batch_size: Optional[int] = None, + num_microbatches: Optional[int] = None, + **kwargs + ) -> STEP_OUTPUT: + return self._step( + "test", + data, + data_step=data_step, + forward_step=forward_step, + loss_reduction=loss_reduction, + seq_length=seq_length, + micro_batch_size=micro_batch_size, + num_microbatches=num_microbatches, + forward_only=True, + **kwargs + ) + + def predict_step( + self, + data: DataT, + data_step: Optional[Callable[[Iterator[DataT]], DataT]] = None, + forward_step: Optional[Callable[[ModelT, DataT], Tensor]] = None, + loss_reduction: Optional["MegatronLossReduction[DataT, 
Any]"] = None, + seq_length: Optional[int] = None, + micro_batch_size: Optional[int] = None, + num_microbatches: Optional[int] = None, + **kwargs + ) -> STEP_OUTPUT: + return self._step( + "predict", + data, + data_step=data_step, + forward_step=forward_step, + loss_reduction=loss_reduction, + seq_length=seq_length, + micro_batch_size=micro_batch_size, + num_microbatches=num_microbatches, + forward_only=True, + **kwargs + ) + + def _step( + self, + step_type: str, + data: DataT, + data_step: Optional[Callable[[Iterator[DataT]], DataT]] = None, + forward_step: Optional[Callable[[ModelT, DataT], Tensor]] = None, + loss_reduction: Optional["MegatronLossReduction[DataT, Any]"] = None, + seq_length: Optional[int] = None, + micro_batch_size: Optional[int] = None, + num_microbatches: Optional[int] = None, + forward_only: bool = True, + **kwargs + ) -> STEP_OUTPUT: + if not hasattr(self.module, f"{step_type}_step"): + raise AttributeError(f"self.module must have a `{step_type}_step` method") + + _data_step = data_step or _ModuleStepFunction.from_data_step(self.module, step_type) + _forward_step = forward_step or _ModuleStepFunction.from_forward_step(self.module, step_type) + _loss_reduction = loss_reduction or _ModuleStepFunction.from_loss_reduction(self.module, step_type) + + return self.forward( + data=data, + data_step=_data_step, + forward_step=_forward_step, + loss_reduction=_loss_reduction, + seq_length=seq_length, + micro_batch_size=micro_batch_size, + num_microbatches=num_microbatches, + forward_only=forward_only, + **kwargs + ) def wrapped_forward_step( self, forward_step, loss_reduction, data_step, context @@ -536,6 +669,38 @@ def __init__(self, name: str, is_property: bool = False, includes_self: bool = F self.name = name self.is_property = is_property self.includes_self = includes_self + + @classmethod + def from_data_step(cls, module: "pl.LightningModule", step_type: str) -> Optional["_ModuleStepFunction"]: + for fn_name in [f"{step_type}_data_step", "data_step"]: + if hasattr(module, fn_name): + return _ModuleStepFunction(fn_name) + + return None + + @classmethod + def from_forward_step(cls, module: "pl.LightningModule", step_type: str) -> Optional["_ModuleStepFunction"]: + from megatron.core import parallel_state + + if parallel_state.is_pipeline_last_stage(): + if not hasattr(module, f"{step_type}_step"): + raise ValueError(f"LightningModule does not have {step_type}_step method") + + return _ModuleStepFunction(f"{step_type}_step", includes_self=True) + + for fn_name in [f"{step_type}_forward_step", "forward_step"]: + if hasattr(module, fn_name): + return _ModuleStepFunction(fn_name, includes_self=True) + + return None + + @classmethod + def from_loss_reduction(cls, module: "pl.LightningModule", step_type: str) -> Optional["_ModuleStepFunction"]: + for fn_name in [f"{step_type}_loss_reduction", "loss_reduction"]: + if hasattr(module, fn_name): + return _ModuleStepFunction(fn_name, is_property=True) + + return None def __call__(self, module: nn.Module): @@ -1045,7 +1210,7 @@ def model(self) -> Union[ModelT, List[ModelT]]: return self.pipeline.pipeline @property - def pl_module(self) -> pl.LightningModule: + def pl_module(self) -> "pl.LightningModule": """ Retrieves the PyTorch Lightning module from the pipeline. @@ -1055,7 +1220,7 @@ def pl_module(self) -> pl.LightningModule: return self.pipeline.module @property - def trainer(self) -> pl.Trainer: + def trainer(self) -> "pl.Trainer": """ Retrieves the PyTorch Lightning trainer from the pipeline. 
diff --git a/nemo/lightning/pytorch/strategies/megatron_strategy.py b/nemo/lightning/pytorch/strategies/megatron_strategy.py index d1e2c7dbae57..3c389091c8e1 100644 --- a/nemo/lightning/pytorch/strategies/megatron_strategy.py +++ b/nemo/lightning/pytorch/strategies/megatron_strategy.py @@ -470,8 +470,7 @@ def setup_optimizers(self, trainer: "pl.Trainer") -> None: @override def training_step(self, dataloader_iter, *args: Any, **kwargs: Any) -> STEP_OUTPUT: assert self.lightning_module is not None - assert self.model is not None - kwargs = self._update_step_kwargs(dataloader_iter, kwargs, "training") + assert isinstance(self.model, MegatronParallel) with self.precision_plugin.train_step_context(): # TODO: Do we need this? # Set grad to zero. @@ -480,7 +479,7 @@ def training_step(self, dataloader_iter, *args: Any, **kwargs: Any) -> STEP_OUTP for opt in self.optimizers: opt.zero_grad() - out = self.model(dataloader_iter, forward_only=False, *args, **kwargs) + out = self.model.training_step(dataloader_iter, *args, **kwargs) if torch.is_tensor(out): reduced_train_loss = out @@ -553,11 +552,10 @@ def optimizer_step( @override def validation_step(self, dataloader_iter, *args: Any, **kwargs: Any) -> STEP_OUTPUT: assert self.lightning_module is not None - assert self.model is not None - kwargs = self._update_step_kwargs(dataloader_iter, kwargs, "validation") + assert isinstance(self.model, MegatronParallel) with self.precision_plugin.val_step_context(): # TODO: Do we need this? - out = self.model(dataloader_iter, forward_only=True, *args, **kwargs) + out = self.model.validation_step(dataloader_iter, *args, **kwargs) from megatron.core import parallel_state @@ -581,20 +579,18 @@ def validation_step(self, dataloader_iter, *args: Any, **kwargs: Any) -> STEP_OU @override def test_step(self, dataloader_iter, *args: Any, **kwargs: Any) -> STEP_OUTPUT: assert self.lightning_module is not None - assert self.model is not None - kwargs = self._update_step_kwargs(dataloader_iter, kwargs, "test") + assert isinstance(self.model, MegatronParallel) with self.precision_plugin.test_step_context(): # TODO: Do we need this? - return self.model(dataloader_iter, forward_only=True, *args, **kwargs) + return self.model.test_step(dataloader_iter, *args, **kwargs) @override def predict_step(self, dataloader_iter, *args: Any, **kwargs: Any) -> STEP_OUTPUT: assert self.lightning_module is not None - assert self.model is not None - kwargs = self._update_step_kwargs(dataloader_iter, kwargs, "predict") + assert isinstance(self.model, MegatronParallel) with self.precision_plugin.predict_step_context(): # TODO: Do we need this? 
- return self.model(dataloader_iter, forward_only=True, *args, **kwargs) + return self.model.predict_step(dataloader_iter, *args, **kwargs) @override def teardown(self) -> None: @@ -770,35 +766,6 @@ def current_epoch_step(self) -> int: self.trainer.fit_loop.epoch_loop.manual_optimization.optim_step_progress.current.completed, ) - def _get_data_step(self, step_type: str) -> Optional[_ModuleStepFunction]: - for fn_name in [f"{step_type}_data_step", "data_step"]: - if hasattr(self.lightning_module, fn_name): - return _ModuleStepFunction(fn_name) - - return None - - def _get_forward_step(self, step_type: str) -> Optional[_ModuleStepFunction]: - from megatron.core import parallel_state - - if parallel_state.is_pipeline_last_stage(): - if not hasattr(self.lightning_module, f"{step_type}_step"): - raise ValueError(f"LightningModule does not have {step_type}_step method") - - return _ModuleStepFunction(f"{step_type}_step", includes_self=True) - - for fn_name in [f"{step_type}_forward_step", "forward_step"]: - if hasattr(self.lightning_module, fn_name): - return _ModuleStepFunction(fn_name, includes_self=True) - - return None - - def _get_loss_reduction(self, step_type: str) -> Optional[_ModuleStepFunction]: - for fn_name in [f"{step_type}_loss_reduction", "loss_reduction"]: - if hasattr(self.lightning_module, fn_name): - return _ModuleStepFunction(fn_name, is_property=True) - - return None - @property def distributed_sampler_kwargs(self) -> Dict[str, Any]: from nemo.utils import AppState From d620525962c9ac941d4b617c17c126aba2fe735f Mon Sep 17 00:00:00 2001 From: marcromeyn Date: Thu, 3 Oct 2024 08:17:12 +0000 Subject: [PATCH 2/2] Apply isort and black reformatting Signed-off-by: marcromeyn --- nemo/lightning/megatron_parallel.py | 122 ++++++++++++++-------------- 1 file changed, 61 insertions(+), 61 deletions(-) diff --git a/nemo/lightning/megatron_parallel.py b/nemo/lightning/megatron_parallel.py index 6965530ead3f..59d024b07b84 100644 --- a/nemo/lightning/megatron_parallel.py +++ b/nemo/lightning/megatron_parallel.py @@ -20,14 +20,15 @@ from collections import defaultdict from dataclasses import dataclass from typing import ( + TYPE_CHECKING, Any, Callable, Dict, Generic, Iterable, Iterator, - Mapping, List, + Mapping, Optional, Protocol, Sequence, @@ -36,7 +37,6 @@ Union, cast, runtime_checkable, - TYPE_CHECKING, ) import torch @@ -298,107 +298,107 @@ def forward( self.callbacks.event("on_megatron_step_end", step=step, microbatch_outputs=microbatch_outputs, reduced=reduced) return reduced - + def training_step( - self, - data: DataT, + self, + data: DataT, data_step: Optional[Callable[[Iterator[DataT]], DataT]] = None, forward_step: Optional[Callable[[ModelT, DataT], Tensor]] = None, loss_reduction: Optional["MegatronLossReduction[DataT, Any]"] = None, seq_length: Optional[int] = None, micro_batch_size: Optional[int] = None, num_microbatches: Optional[int] = None, - **kwargs + **kwargs, ) -> STEP_OUTPUT: return self._step( "training", - data, - data_step=data_step, - forward_step=forward_step, - loss_reduction=loss_reduction, - seq_length=seq_length, - micro_batch_size=micro_batch_size, - num_microbatches=num_microbatches, + data, + data_step=data_step, + forward_step=forward_step, + loss_reduction=loss_reduction, + seq_length=seq_length, + micro_batch_size=micro_batch_size, + num_microbatches=num_microbatches, forward_only=False, - **kwargs + **kwargs, ) - + def validation_step( - self, - data: DataT, + self, + data: DataT, data_step: Optional[Callable[[Iterator[DataT]], DataT]] = None, 
forward_step: Optional[Callable[[ModelT, DataT], Tensor]] = None, loss_reduction: Optional["MegatronLossReduction[DataT, Any]"] = None, seq_length: Optional[int] = None, micro_batch_size: Optional[int] = None, num_microbatches: Optional[int] = None, - **kwargs + **kwargs, ) -> STEP_OUTPUT: return self._step( - "validation", - data, - data_step=data_step, - forward_step=forward_step, - loss_reduction=loss_reduction, - seq_length=seq_length, - micro_batch_size=micro_batch_size, - num_microbatches=num_microbatches, + "validation", + data, + data_step=data_step, + forward_step=forward_step, + loss_reduction=loss_reduction, + seq_length=seq_length, + micro_batch_size=micro_batch_size, + num_microbatches=num_microbatches, forward_only=True, - **kwargs + **kwargs, ) - + def test_step( - self, - data: DataT, + self, + data: DataT, data_step: Optional[Callable[[Iterator[DataT]], DataT]] = None, forward_step: Optional[Callable[[ModelT, DataT], Tensor]] = None, loss_reduction: Optional["MegatronLossReduction[DataT, Any]"] = None, seq_length: Optional[int] = None, micro_batch_size: Optional[int] = None, num_microbatches: Optional[int] = None, - **kwargs + **kwargs, ) -> STEP_OUTPUT: return self._step( - "test", - data, - data_step=data_step, - forward_step=forward_step, - loss_reduction=loss_reduction, - seq_length=seq_length, - micro_batch_size=micro_batch_size, - num_microbatches=num_microbatches, + "test", + data, + data_step=data_step, + forward_step=forward_step, + loss_reduction=loss_reduction, + seq_length=seq_length, + micro_batch_size=micro_batch_size, + num_microbatches=num_microbatches, forward_only=True, - **kwargs + **kwargs, ) - + def predict_step( - self, - data: DataT, + self, + data: DataT, data_step: Optional[Callable[[Iterator[DataT]], DataT]] = None, forward_step: Optional[Callable[[ModelT, DataT], Tensor]] = None, loss_reduction: Optional["MegatronLossReduction[DataT, Any]"] = None, seq_length: Optional[int] = None, micro_batch_size: Optional[int] = None, num_microbatches: Optional[int] = None, - **kwargs + **kwargs, ) -> STEP_OUTPUT: return self._step( - "predict", - data, - data_step=data_step, - forward_step=forward_step, - loss_reduction=loss_reduction, - seq_length=seq_length, + "predict", + data, + data_step=data_step, + forward_step=forward_step, + loss_reduction=loss_reduction, + seq_length=seq_length, micro_batch_size=micro_batch_size, num_microbatches=num_microbatches, forward_only=True, - **kwargs + **kwargs, ) - + def _step( - self, + self, step_type: str, - data: DataT, + data: DataT, data_step: Optional[Callable[[Iterator[DataT]], DataT]] = None, forward_step: Optional[Callable[[ModelT, DataT], Tensor]] = None, loss_reduction: Optional["MegatronLossReduction[DataT, Any]"] = None, @@ -406,11 +406,11 @@ def _step( micro_batch_size: Optional[int] = None, num_microbatches: Optional[int] = None, forward_only: bool = True, - **kwargs + **kwargs, ) -> STEP_OUTPUT: if not hasattr(self.module, f"{step_type}_step"): raise AttributeError(f"self.module must have a `{step_type}_step` method") - + _data_step = data_step or _ModuleStepFunction.from_data_step(self.module, step_type) _forward_step = forward_step or _ModuleStepFunction.from_forward_step(self.module, step_type) _loss_reduction = loss_reduction or _ModuleStepFunction.from_loss_reduction(self.module, step_type) @@ -424,7 +424,7 @@ def _step( micro_batch_size=micro_batch_size, num_microbatches=num_microbatches, forward_only=forward_only, - **kwargs + **kwargs, ) def wrapped_forward_step( @@ -669,15 +669,15 @@ def 
__init__(self, name: str, is_property: bool = False, includes_self: bool = F self.name = name self.is_property = is_property self.includes_self = includes_self - + @classmethod def from_data_step(cls, module: "pl.LightningModule", step_type: str) -> Optional["_ModuleStepFunction"]: for fn_name in [f"{step_type}_data_step", "data_step"]: if hasattr(module, fn_name): return _ModuleStepFunction(fn_name) - + return None - + @classmethod def from_forward_step(cls, module: "pl.LightningModule", step_type: str) -> Optional["_ModuleStepFunction"]: from megatron.core import parallel_state @@ -693,7 +693,7 @@ def from_forward_step(cls, module: "pl.LightningModule", step_type: str) -> Opti return _ModuleStepFunction(fn_name, includes_self=True) return None - + @classmethod def from_loss_reduction(cls, module: "pl.LightningModule", step_type: str) -> Optional["_ModuleStepFunction"]: for fn_name in [f"{step_type}_loss_reduction", "loss_reduction"]: