Moving steps to MegatronParallel to improve UX for Fabric #10732

Merged · 3 commits · Oct 10, 2024
171 changes: 168 additions & 3 deletions nemo/lightning/megatron_parallel.py
@@ -20,13 +20,15 @@
from collections import defaultdict
from dataclasses import dataclass
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Generic,
Iterable,
Iterator,
List,
Mapping,
Optional,
Protocol,
Sequence,
@@ -37,7 +39,6 @@
runtime_checkable,
)

import pytorch_lightning as pl
import torch
import torch.distributed
from megatron.core import parallel_state
@@ -51,6 +52,10 @@
DataT = TypeVar("DataT", Tensor, Dict[str, Tensor], Sequence[Tensor])
ModelT = TypeVar("ModelT", bound=nn.Module)
T = TypeVar('T')
STEP_OUTPUT = Optional[Union[Tensor, Mapping[str, Any]]]

if TYPE_CHECKING:
import pytorch_lightning as pl


@runtime_checkable
@@ -294,6 +299,134 @@ def forward(

return reduced

def training_step(
self,
data: DataT,
data_step: Optional[Callable[[Iterator[DataT]], DataT]] = None,
forward_step: Optional[Callable[[ModelT, DataT], Tensor]] = None,
loss_reduction: Optional["MegatronLossReduction[DataT, Any]"] = None,
seq_length: Optional[int] = None,
micro_batch_size: Optional[int] = None,
num_microbatches: Optional[int] = None,
**kwargs,
) -> STEP_OUTPUT:
return self._step(
"training",
data,
data_step=data_step,
forward_step=forward_step,
loss_reduction=loss_reduction,
seq_length=seq_length,
micro_batch_size=micro_batch_size,
num_microbatches=num_microbatches,
forward_only=False,
**kwargs,
)

def validation_step(
self,
data: DataT,
data_step: Optional[Callable[[Iterator[DataT]], DataT]] = None,
forward_step: Optional[Callable[[ModelT, DataT], Tensor]] = None,
loss_reduction: Optional["MegatronLossReduction[DataT, Any]"] = None,
seq_length: Optional[int] = None,
micro_batch_size: Optional[int] = None,
num_microbatches: Optional[int] = None,
**kwargs,
) -> STEP_OUTPUT:
return self._step(
"validation",
data,
data_step=data_step,
forward_step=forward_step,
loss_reduction=loss_reduction,
seq_length=seq_length,
micro_batch_size=micro_batch_size,
num_microbatches=num_microbatches,
forward_only=True,
**kwargs,
)

def test_step(
self,
data: DataT,
data_step: Optional[Callable[[Iterator[DataT]], DataT]] = None,
forward_step: Optional[Callable[[ModelT, DataT], Tensor]] = None,
loss_reduction: Optional["MegatronLossReduction[DataT, Any]"] = None,
seq_length: Optional[int] = None,
micro_batch_size: Optional[int] = None,
num_microbatches: Optional[int] = None,
**kwargs,
) -> STEP_OUTPUT:
return self._step(
"test",
data,
data_step=data_step,
forward_step=forward_step,
loss_reduction=loss_reduction,
seq_length=seq_length,
micro_batch_size=micro_batch_size,
num_microbatches=num_microbatches,
forward_only=True,
**kwargs,
)

def predict_step(
self,
data: DataT,
data_step: Optional[Callable[[Iterator[DataT]], DataT]] = None,
forward_step: Optional[Callable[[ModelT, DataT], Tensor]] = None,
loss_reduction: Optional["MegatronLossReduction[DataT, Any]"] = None,
seq_length: Optional[int] = None,
micro_batch_size: Optional[int] = None,
num_microbatches: Optional[int] = None,
**kwargs,
) -> STEP_OUTPUT:
return self._step(
"predict",
data,
data_step=data_step,
forward_step=forward_step,
loss_reduction=loss_reduction,
seq_length=seq_length,
micro_batch_size=micro_batch_size,
num_microbatches=num_microbatches,
forward_only=True,
**kwargs,
)

def _step(
self,
step_type: str,
data: DataT,
data_step: Optional[Callable[[Iterator[DataT]], DataT]] = None,
forward_step: Optional[Callable[[ModelT, DataT], Tensor]] = None,
loss_reduction: Optional["MegatronLossReduction[DataT, Any]"] = None,
seq_length: Optional[int] = None,
micro_batch_size: Optional[int] = None,
num_microbatches: Optional[int] = None,
forward_only: bool = True,
**kwargs,
) -> STEP_OUTPUT:
if not hasattr(self.module, f"{step_type}_step"):
raise AttributeError(f"self.module must have a `{step_type}_step` method")

_data_step = data_step or _ModuleStepFunction.from_data_step(self.module, step_type)
_forward_step = forward_step or _ModuleStepFunction.from_forward_step(self.module, step_type)
_loss_reduction = loss_reduction or _ModuleStepFunction.from_loss_reduction(self.module, step_type)

return self.forward(
data=data,
data_step=_data_step,
forward_step=_forward_step,
loss_reduction=_loss_reduction,
seq_length=seq_length,
micro_batch_size=micro_batch_size,
num_microbatches=num_microbatches,
forward_only=forward_only,
**kwargs,
)

def wrapped_forward_step(
self, forward_step, loss_reduction, data_step, context
) -> Callable[[nn.Module, DataT], Tuple[torch.Tensor, "MegatronCallbackProtocol"]]:
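
The step methods added in this hunk are the core of the UX change: a Fabric (or otherwise Trainer-free) caller can now invoke a step directly on the wrapped pipeline instead of relying on the strategy's _update_step_kwargs plumbing. A minimal usage sketch under assumptions — MyMegatronModule, batch_iterator, the parallel sizes, and the exact MegatronParallel construction are placeholders, not part of this diff:

    from nemo.lightning.megatron_parallel import MegatronParallel

    pipeline = MegatronParallel(MyMegatronModule())  # placeholder module that defines training_step

    # Training resolves data_step / forward_step / loss_reduction from the module
    # and runs the pipeline with forward_only=False:
    loss = pipeline.training_step(
        batch_iterator,        # Iterator[DataT] yielding micro-batches (placeholder)
        seq_length=2048,       # placeholder values
        micro_batch_size=2,
        num_microbatches=4,
    )

    # The evaluation entry points share the same _step plumbing with forward_only=True:
    val_out = pipeline.validation_step(batch_iterator)
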
@@ -537,6 +670,38 @@ def __init__(self, name: str, is_property: bool = False, includes_self: bool = False
self.is_property = is_property
self.includes_self = includes_self

@classmethod
def from_data_step(cls, module: "pl.LightningModule", step_type: str) -> Optional["_ModuleStepFunction"]:
for fn_name in [f"{step_type}_data_step", "data_step"]:
if hasattr(module, fn_name):
return _ModuleStepFunction(fn_name)

return None

@classmethod
def from_forward_step(cls, module: "pl.LightningModule", step_type: str) -> Optional["_ModuleStepFunction"]:
from megatron.core import parallel_state

if parallel_state.is_pipeline_last_stage():
if not hasattr(module, f"{step_type}_step"):
raise ValueError(f"LightningModule does not have {step_type}_step method")

return _ModuleStepFunction(f"{step_type}_step", includes_self=True)

for fn_name in [f"{step_type}_forward_step", "forward_step"]:
if hasattr(module, fn_name):
return _ModuleStepFunction(fn_name, includes_self=True)

return None

@classmethod
def from_loss_reduction(cls, module: "pl.LightningModule", step_type: str) -> Optional["_ModuleStepFunction"]:
for fn_name in [f"{step_type}_loss_reduction", "loss_reduction"]:
if hasattr(module, fn_name):
return _ModuleStepFunction(fn_name, is_property=True)

return None

def __call__(self, module: nn.Module):

attr = getattr(module, self.name)
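
The from_* classmethods added in this hunk replace the strategy's private _get_data_step, _get_forward_step and _get_loss_reduction helpers (removed further down) and resolve each hook by name on the LightningModule, preferring a step-specific override over the generic one; from_forward_step additionally uses the full <step_type>_step on the last pipeline stage. A small illustrative module and the resulting lookups — the class and its stub bodies are sketch-only, not real NeMo code:

    import pytorch_lightning as pl

    class SketchModule(pl.LightningModule):
        def data_step(self, dataloader_iter):
            # Generic hook: used whenever no "<step_type>_data_step" override exists.
            return next(dataloader_iter)

        def validation_data_step(self, dataloader_iter):
            # Step-specific hook: wins over data_step for step_type="validation".
            return next(dataloader_iter)

        def training_loss_reduction(self):
            ...  # would return a MegatronLossReduction in real code

    # Resolution order follows the hasattr checks above, most specific name first:
    #   from_data_step(module, "validation")    -> _ModuleStepFunction("validation_data_step")
    #   from_data_step(module, "training")      -> _ModuleStepFunction("data_step")
    #   from_loss_reduction(module, "training") -> _ModuleStepFunction("training_loss_reduction", is_property=True)
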
@@ -1045,7 +1210,7 @@ def model(self) -> Union[ModelT, List[ModelT]]:
return self.pipeline.pipeline

@property
def pl_module(self) -> pl.LightningModule:
def pl_module(self) -> "pl.LightningModule":
"""
Retrieves the PyTorch Lightning module from the pipeline.

@@ -1055,7 +1220,7 @@ def pl_module(self) -> pl.LightningModule:
return self.pipeline.module

@property
def trainer(self) -> pl.Trainer:
def trainer(self) -> "pl.Trainer":
"""
Retrieves the PyTorch Lightning trainer from the pipeline.

49 changes: 8 additions & 41 deletions nemo/lightning/pytorch/strategies/megatron_strategy.py
@@ -470,8 +470,7 @@ def setup_optimizers(self, trainer: "pl.Trainer") -> None:
@override
def training_step(self, dataloader_iter, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
assert self.lightning_module is not None
assert self.model is not None
kwargs = self._update_step_kwargs(dataloader_iter, kwargs, "training")
assert isinstance(self.model, MegatronParallel)

with self.precision_plugin.train_step_context(): # TODO: Do we need this?
# Set grad to zero.
@@ -480,7 +479,7 @@ def training_step(self, dataloader_iter, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
for opt in self.optimizers:
opt.zero_grad()

out = self.model(dataloader_iter, forward_only=False, *args, **kwargs)
out = self.model.training_step(dataloader_iter, *args, **kwargs)

if torch.is_tensor(out):
reduced_train_loss = out
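
On the strategy side the change is symmetric: the _update_step_kwargs call disappears and training_step defers to the wrapper, which now owns the data_step / forward_step / loss_reduction resolution. A condensed control-flow sketch of the updated method — not the verbatim implementation (the precision context, loss logging, and the exact zero-grad guard are omitted):

    def training_step(self, dataloader_iter, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
        assert self.lightning_module is not None
        assert isinstance(self.model, MegatronParallel)

        # Zero gradients before the forward/backward pass (guard condition omitted).
        for opt in self.optimizers:
            opt.zero_grad()

        # Step-function lookup now happens inside MegatronParallel._step,
        # not in the strategy.
        return self.model.training_step(dataloader_iter, *args, **kwargs)
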
@@ -553,11 +552,10 @@ def optimizer_step(
@override
def validation_step(self, dataloader_iter, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
assert self.lightning_module is not None
assert self.model is not None
kwargs = self._update_step_kwargs(dataloader_iter, kwargs, "validation")
assert isinstance(self.model, MegatronParallel)

with self.precision_plugin.val_step_context(): # TODO: Do we need this?
out = self.model(dataloader_iter, forward_only=True, *args, **kwargs)
out = self.model.validation_step(dataloader_iter, *args, **kwargs)

from megatron.core import parallel_state

@@ -581,20 +579,18 @@ def validation_step(self, dataloader_iter, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
@override
def test_step(self, dataloader_iter, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
assert self.lightning_module is not None
assert self.model is not None
kwargs = self._update_step_kwargs(dataloader_iter, kwargs, "test")
assert isinstance(self.model, MegatronParallel)

with self.precision_plugin.test_step_context(): # TODO: Do we need this?
return self.model(dataloader_iter, forward_only=True, *args, **kwargs)
return self.model.test_step(dataloader_iter, *args, **kwargs)

@override
def predict_step(self, dataloader_iter, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
assert self.lightning_module is not None
assert self.model is not None
kwargs = self._update_step_kwargs(dataloader_iter, kwargs, "predict")
assert isinstance(self.model, MegatronParallel)

with self.precision_plugin.predict_step_context(): # TODO: Do we need this?
return self.model(dataloader_iter, forward_only=True, *args, **kwargs)
return self.model.predict_step(dataloader_iter, *args, **kwargs)

@override
def teardown(self) -> None:
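
validation_step, test_step, and predict_step follow the same delegation pattern; the only asymmetry is the forward_only flag that MegatronParallel._step passes on to forward. A compact sketch of the mapping, reusing the placeholder pipeline and batch_iterator names from the earlier sketch:

    # strategy.training_step   -> MegatronParallel.training_step   (forward_only=False)
    # strategy.validation_step -> MegatronParallel.validation_step (forward_only=True)
    # strategy.test_step       -> MegatronParallel.test_step       (forward_only=True)
    # strategy.predict_step    -> MegatronParallel.predict_step    (forward_only=True)
    val_out = pipeline.validation_step(batch_iterator)  # forward only; no backward pass is run
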
@@ -770,35 +766,6 @@ def current_epoch_step(self) -> int:
self.trainer.fit_loop.epoch_loop.manual_optimization.optim_step_progress.current.completed,
)

def _get_data_step(self, step_type: str) -> Optional[_ModuleStepFunction]:
for fn_name in [f"{step_type}_data_step", "data_step"]:
if hasattr(self.lightning_module, fn_name):
return _ModuleStepFunction(fn_name)

return None

def _get_forward_step(self, step_type: str) -> Optional[_ModuleStepFunction]:
from megatron.core import parallel_state

if parallel_state.is_pipeline_last_stage():
if not hasattr(self.lightning_module, f"{step_type}_step"):
raise ValueError(f"LightningModule does not have {step_type}_step method")

return _ModuleStepFunction(f"{step_type}_step", includes_self=True)

for fn_name in [f"{step_type}_forward_step", "forward_step"]:
if hasattr(self.lightning_module, fn_name):
return _ModuleStepFunction(fn_name, includes_self=True)

return None

def _get_loss_reduction(self, step_type: str) -> Optional[_ModuleStepFunction]:
for fn_name in [f"{step_type}_loss_reduction", "loss_reduction"]:
if hasattr(self.lightning_module, fn_name):
return _ModuleStepFunction(fn_name, is_property=True)

return None

@property
def distributed_sampler_kwargs(self) -> Dict[str, Any]:
from nemo.utils import AppState