Moving steps to MegatronParallel to improve UX for Fabric (#10732)
* Moving steps to MegatronParallel to improve UX for Fabric

Signed-off-by: Marc Romeijn <mromeijn@nvidia.com>

* Apply isort and black reformatting

Signed-off-by: marcromeyn <marcromeyn@users.noreply.github.com>

---------

Signed-off-by: Marc Romeijn <mromeijn@nvidia.com>
Signed-off-by: marcromeyn <marcromeyn@users.noreply.github.com>
Co-authored-by: marcromeyn <marcromeyn@users.noreply.github.com>
marcromeyn and marcromeyn committed Oct 11, 2024
1 parent a485ac5 commit 6e8f362
Showing 2 changed files with 176 additions and 44 deletions.
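
The diff below adds explicit training_step, validation_step, test_step, and predict_step methods to MegatronParallel, each resolving the wrapped module's hooks and calling forward with the appropriate forward_only flag. A minimal sketch of the resulting Fabric-facing call pattern, assuming an already-constructed MegatronParallel instance and batch iterators (neither is part of this commit):

```python
from typing import Any, Iterator


def run_steps(megatron_parallel: Any, train_iter: Iterator, val_iter: Iterator) -> None:
    """Sketch of the call pattern enabled by this commit.

    Assumes `megatron_parallel` is a MegatronParallel instance wrapping a
    LightningModule that defines `training_step`/`validation_step`, and that
    the iterators yield batches in the format its data step expects.
    """
    # Previously callers went through forward() and passed flags such as
    # forward_only themselves; now the step type is explicit.
    train_out = megatron_parallel.training_step(train_iter)  # -> _step("training", ..., forward_only=False)
    val_out = megatron_parallel.validation_step(val_iter)    # -> _step("validation", ..., forward_only=True)
    print(train_out, val_out)
```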
171 changes: 168 additions & 3 deletions nemo/lightning/megatron_parallel.py
@@ -21,13 +21,15 @@
from contextlib import nullcontext
from dataclasses import dataclass
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Generic,
Iterable,
Iterator,
List,
Mapping,
Optional,
Protocol,
Sequence,
@@ -38,7 +40,6 @@
runtime_checkable,
)

import pytorch_lightning as pl
import torch
import torch.distributed
from megatron.core import parallel_state
@@ -52,6 +53,10 @@
DataT = TypeVar("DataT", Tensor, Dict[str, Tensor], Sequence[Tensor])
ModelT = TypeVar("ModelT", bound=nn.Module)
T = TypeVar('T')
STEP_OUTPUT = Optional[Union[Tensor, Mapping[str, Any]]]

if TYPE_CHECKING:
import pytorch_lightning as pl


@runtime_checkable
@@ -295,6 +300,134 @@ def forward(

return reduced

def training_step(
self,
data: DataT,
data_step: Optional[Callable[[Iterator[DataT]], DataT]] = None,
forward_step: Optional[Callable[[ModelT, DataT], Tensor]] = None,
loss_reduction: Optional["MegatronLossReduction[DataT, Any]"] = None,
seq_length: Optional[int] = None,
micro_batch_size: Optional[int] = None,
num_microbatches: Optional[int] = None,
**kwargs,
) -> STEP_OUTPUT:
return self._step(
"training",
data,
data_step=data_step,
forward_step=forward_step,
loss_reduction=loss_reduction,
seq_length=seq_length,
micro_batch_size=micro_batch_size,
num_microbatches=num_microbatches,
forward_only=False,
**kwargs,
)

def validation_step(
self,
data: DataT,
data_step: Optional[Callable[[Iterator[DataT]], DataT]] = None,
forward_step: Optional[Callable[[ModelT, DataT], Tensor]] = None,
loss_reduction: Optional["MegatronLossReduction[DataT, Any]"] = None,
seq_length: Optional[int] = None,
micro_batch_size: Optional[int] = None,
num_microbatches: Optional[int] = None,
**kwargs,
) -> STEP_OUTPUT:
return self._step(
"validation",
data,
data_step=data_step,
forward_step=forward_step,
loss_reduction=loss_reduction,
seq_length=seq_length,
micro_batch_size=micro_batch_size,
num_microbatches=num_microbatches,
forward_only=True,
**kwargs,
)

def test_step(
self,
data: DataT,
data_step: Optional[Callable[[Iterator[DataT]], DataT]] = None,
forward_step: Optional[Callable[[ModelT, DataT], Tensor]] = None,
loss_reduction: Optional["MegatronLossReduction[DataT, Any]"] = None,
seq_length: Optional[int] = None,
micro_batch_size: Optional[int] = None,
num_microbatches: Optional[int] = None,
**kwargs,
) -> STEP_OUTPUT:
return self._step(
"test",
data,
data_step=data_step,
forward_step=forward_step,
loss_reduction=loss_reduction,
seq_length=seq_length,
micro_batch_size=micro_batch_size,
num_microbatches=num_microbatches,
forward_only=True,
**kwargs,
)

def predict_step(
self,
data: DataT,
data_step: Optional[Callable[[Iterator[DataT]], DataT]] = None,
forward_step: Optional[Callable[[ModelT, DataT], Tensor]] = None,
loss_reduction: Optional["MegatronLossReduction[DataT, Any]"] = None,
seq_length: Optional[int] = None,
micro_batch_size: Optional[int] = None,
num_microbatches: Optional[int] = None,
**kwargs,
) -> STEP_OUTPUT:
return self._step(
"predict",
data,
data_step=data_step,
forward_step=forward_step,
loss_reduction=loss_reduction,
seq_length=seq_length,
micro_batch_size=micro_batch_size,
num_microbatches=num_microbatches,
forward_only=True,
**kwargs,
)

def _step(
self,
step_type: str,
data: DataT,
data_step: Optional[Callable[[Iterator[DataT]], DataT]] = None,
forward_step: Optional[Callable[[ModelT, DataT], Tensor]] = None,
loss_reduction: Optional["MegatronLossReduction[DataT, Any]"] = None,
seq_length: Optional[int] = None,
micro_batch_size: Optional[int] = None,
num_microbatches: Optional[int] = None,
forward_only: bool = True,
**kwargs,
) -> STEP_OUTPUT:
if not hasattr(self.module, f"{step_type}_step"):
raise AttributeError(f"self.module must have a `{step_type}_step` method")

_data_step = data_step or _ModuleStepFunction.from_data_step(self.module, step_type)
_forward_step = forward_step or _ModuleStepFunction.from_forward_step(self.module, step_type)
_loss_reduction = loss_reduction or _ModuleStepFunction.from_loss_reduction(self.module, step_type)

return self.forward(
data=data,
data_step=_data_step,
forward_step=_forward_step,
loss_reduction=_loss_reduction,
seq_length=seq_length,
micro_batch_size=micro_batch_size,
num_microbatches=num_microbatches,
forward_only=forward_only,
**kwargs,
)

def wrapped_forward_step(
self, forward_step, loss_reduction, data_step, context
) -> Callable[[nn.Module, DataT], Tuple[torch.Tensor, "MegatronCallbackProtocol"]]:
@@ -543,6 +676,38 @@ def __init__(self, name: str, is_property: bool = False, includes_self: bool = F
self.is_property = is_property
self.includes_self = includes_self

@classmethod
def from_data_step(cls, module: "pl.LightningModule", step_type: str) -> Optional["_ModuleStepFunction"]:
for fn_name in [f"{step_type}_data_step", "data_step"]:
if hasattr(module, fn_name):
return _ModuleStepFunction(fn_name)

return None

@classmethod
def from_forward_step(cls, module: "pl.LightningModule", step_type: str) -> Optional["_ModuleStepFunction"]:
from megatron.core import parallel_state

if parallel_state.is_pipeline_last_stage():
if not hasattr(module, f"{step_type}_step"):
raise ValueError(f"LightningModule does not have {step_type}_step method")

return _ModuleStepFunction(f"{step_type}_step", includes_self=True)

for fn_name in [f"{step_type}_forward_step", "forward_step"]:
if hasattr(module, fn_name):
return _ModuleStepFunction(fn_name, includes_self=True)

return None

@classmethod
def from_loss_reduction(cls, module: "pl.LightningModule", step_type: str) -> Optional["_ModuleStepFunction"]:
for fn_name in [f"{step_type}_loss_reduction", "loss_reduction"]:
if hasattr(module, fn_name):
return _ModuleStepFunction(fn_name, is_property=True)

return None

def __call__(self, module: nn.Module):

attr = getattr(module, self.name)
Expand Down Expand Up @@ -1051,7 +1216,7 @@ def model(self) -> Union[ModelT, List[ModelT]]:
return self.pipeline.pipeline

@property
def pl_module(self) -> pl.LightningModule:
def pl_module(self) -> "pl.LightningModule":
"""
Retrieves the PyTorch Lightning module from the pipeline.
@@ -1061,7 +1226,7 @@ def pl_module(self) -> pl.LightningModule:
return self.pipeline.module

@property
def trainer(self) -> pl.Trainer:
def trainer(self) -> "pl.Trainer":
"""
Retrieves the PyTorch Lightning trainer from the pipeline.
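
The from_data_step, from_forward_step, and from_loss_reduction classmethods added above (moved out of MegatronStrategy; see the second file below) resolve hooks on the wrapped LightningModule purely by attribute name. A minimal sketch of a module exposing such hooks; the method bodies are hypothetical, and only the hook names and their lookup order come from the code above:

```python
import pytorch_lightning as pl


class SketchModule(pl.LightningModule):
    """Illustrative only: names match what the _ModuleStepFunction resolvers look up."""

    def data_step(self, dataloader_iter):
        # from_data_step checks "<step_type>_data_step" first, then falls back to "data_step".
        return next(dataloader_iter)

    def forward_step(self, batch):
        # Used on non-last pipeline stages; a "<step_type>_forward_step" hook
        # would take precedence over this generic one if defined.
        return self.model(batch)  # hypothetical wrapped-model attribute

    def training_step(self, batch):
        # Required on the last pipeline stage: from_forward_step raises a
        # ValueError there if "<step_type>_step" is missing.
        return self.forward_step(batch)

    # A "training_loss_reduction" or "loss_reduction" attribute would be resolved
    # analogously by from_loss_reduction (registered with is_property=True).
```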
49 changes: 8 additions & 41 deletions nemo/lightning/pytorch/strategies/megatron_strategy.py
@@ -471,8 +471,7 @@ def setup_optimizers(self, trainer: "pl.Trainer") -> None:
@override
def training_step(self, dataloader_iter, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
assert self.lightning_module is not None
assert self.model is not None
kwargs = self._update_step_kwargs(dataloader_iter, kwargs, "training")
assert isinstance(self.model, MegatronParallel)

with self.precision_plugin.train_step_context(): # TODO: Do we need this?
# Set grad to zero.
@@ -481,7 +480,7 @@ def training_step(self, dataloader_iter, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
for opt in self.optimizers:
opt.zero_grad()

out = self.model(dataloader_iter, forward_only=False, *args, **kwargs)
out = self.model.training_step(dataloader_iter, *args, **kwargs)

if torch.is_tensor(out):
reduced_train_loss = out
@@ -554,11 +553,10 @@ def optimizer_step(
@override
def validation_step(self, dataloader_iter, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
assert self.lightning_module is not None
assert self.model is not None
kwargs = self._update_step_kwargs(dataloader_iter, kwargs, "validation")
assert isinstance(self.model, MegatronParallel)

with self.precision_plugin.val_step_context(): # TODO: Do we need this?
out = self.model(dataloader_iter, forward_only=True, *args, **kwargs)
out = self.model.validation_step(dataloader_iter, *args, **kwargs)

from megatron.core import parallel_state

@@ -582,20 +580,18 @@ def validation_step(self, dataloader_iter, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
@override
def test_step(self, dataloader_iter, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
assert self.lightning_module is not None
assert self.model is not None
kwargs = self._update_step_kwargs(dataloader_iter, kwargs, "test")
assert isinstance(self.model, MegatronParallel)

with self.precision_plugin.test_step_context(): # TODO: Do we need this?
return self.model(dataloader_iter, forward_only=True, *args, **kwargs)
return self.model.test_step(dataloader_iter, *args, **kwargs)

@override
def predict_step(self, dataloader_iter, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
assert self.lightning_module is not None
assert self.model is not None
kwargs = self._update_step_kwargs(dataloader_iter, kwargs, "predict")
assert isinstance(self.model, MegatronParallel)

with self.precision_plugin.predict_step_context(): # TODO: Do we need this?
return self.model(dataloader_iter, forward_only=True, *args, **kwargs)
return self.model.predict_step(dataloader_iter, *args, **kwargs)

@override
def teardown(self) -> None:
@@ -773,35 +769,6 @@ def current_epoch_step(self) -> int:
self.trainer.fit_loop.epoch_loop.manual_optimization.optim_step_progress.current.completed,
)

def _get_data_step(self, step_type: str) -> Optional[_ModuleStepFunction]:
for fn_name in [f"{step_type}_data_step", "data_step"]:
if hasattr(self.lightning_module, fn_name):
return _ModuleStepFunction(fn_name)

return None

def _get_forward_step(self, step_type: str) -> Optional[_ModuleStepFunction]:
from megatron.core import parallel_state

if parallel_state.is_pipeline_last_stage():
if not hasattr(self.lightning_module, f"{step_type}_step"):
raise ValueError(f"LightningModule does not have {step_type}_step method")

return _ModuleStepFunction(f"{step_type}_step", includes_self=True)

for fn_name in [f"{step_type}_forward_step", "forward_step"]:
if hasattr(self.lightning_module, fn_name):
return _ModuleStepFunction(fn_name, includes_self=True)

return None

def _get_loss_reduction(self, step_type: str) -> Optional[_ModuleStepFunction]:
for fn_name in [f"{step_type}_loss_reduction", "loss_reduction"]:
if hasattr(self.lightning_module, fn_name):
return _ModuleStepFunction(fn_name, is_property=True)

return None

@property
def distributed_sampler_kwargs(self) -> Dict[str, Any]:
from nemo.utils import AppState
