Revert "add fsdp fix for tp > 1 (#8689)" (#8807)
This reverts commit e38e352.

Signed-off-by: Marek Wawrzos <mwawrzos@nvidia.com>
mwawrzos authored Apr 4, 2024
1 parent c6218e3 commit dd74f7c
Showing 2 changed files with 8 additions and 4 deletions.
nemo/collections/nlp/parts/megatron_trainer_builder.py: 7 changes (6 additions, 1 deletion)
@@ -115,7 +115,12 @@ def _plugins(self) -> list:
             if megatron_amp_O2 and not with_distributed_adam:
                 plugins.append(MegatronHalfPrecisionPlugin(precision=plugin_precision, device='cuda', scaler=scaler))
             else:
-                plugins.append(PipelineMixedPrecisionPlugin(precision=plugin_precision, device='cuda', scaler=scaler))
+                if self.cfg.model.get('fsdp', False):
+                    plugins.append(FSDPPrecision(precision=plugin_precision, scaler=scaler))
+                else:
+                    plugins.append(
+                        PipelineMixedPrecisionPlugin(precision=plugin_precision, device='cuda', scaler=scaler)
+                    )
             self.cfg.trainer.precision = None

         if self.cfg.get('cluster_type', None) == 'BCP':
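For orientation, the restored branch selects the precision plugin from the `model.fsdp` flag in the trainer builder's config. Below is a minimal, hypothetical sketch of that selection using an OmegaConf-style config; the field values and print statements are illustrative only, not taken from a real NeMo config:

    from omegaconf import OmegaConf

    # Illustrative config fragment; real NeMo model configs carry many more fields.
    cfg = OmegaConf.create({'model': {'fsdp': True}, 'trainer': {'precision': 'bf16-mixed'}})

    # Mirrors the restored branch: FSDP runs get FSDPPrecision, all other
    # mixed-precision runs keep PipelineMixedPrecisionPlugin.
    if cfg.model.get('fsdp', False):
        print("would append FSDPPrecision(precision=plugin_precision, scaler=scaler)")
    else:
        print("would append PipelineMixedPrecisionPlugin(precision=plugin_precision, device='cuda', scaler=scaler)")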
nemo/collections/nlp/parts/nlp_overrides.py: 5 changes (2 additions, 3 deletions)
@@ -34,7 +34,7 @@
 from pytorch_lightning.loops.fetchers import _DataFetcher
 from pytorch_lightning.plugins import ClusterEnvironment
 from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
-from pytorch_lightning.plugins.precision import FSDPPrecision, MixedPrecisionPlugin
+from pytorch_lightning.plugins.precision import MixedPrecisionPlugin
 from pytorch_lightning.strategies import DDPStrategy, FSDPStrategy
 from pytorch_lightning.trainer.states import TrainerFn
 from pytorch_lightning.trainer.trainer import Trainer
@@ -66,7 +66,6 @@

 try:
     from apex.transformer.pipeline_parallel.utils import get_num_microbatches
-
     from nemo.core.optim.distributed_adam import MegatronDistributedFusedAdam

     HAVE_APEX = True
@@ -1151,7 +1150,7 @@ def dummy():
         return instance


-class PipelineMixedPrecisionPlugin(MixedPrecisionPlugin, FSDPPrecision):
+class PipelineMixedPrecisionPlugin(MixedPrecisionPlugin):
     """ Overrides PTL autocasting to not wrap training/val/test_step.
     We do this because we have the megatron-core fwd/bwd functions in training_step.
     This means .backward is being called in training_step so we do not want the whole
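The docstring explains the rationale: PTL's MixedPrecisionPlugin normally wraps training/val/test_step in torch.autocast through its forward_context() hook, while NeMo's training_step drives the megatron-core fwd/bwd functions (and calls .backward) itself, so the step must not be autocast-wrapped at that level. A hedged sketch of that pattern, assuming a pytorch_lightning version that exposes MixedPrecisionPlugin and its forward_context() hook; this is illustrative, not the actual NeMo implementation:

    from contextlib import contextmanager

    from pytorch_lightning.plugins.precision import MixedPrecisionPlugin


    class StepLevelNoAutocastPrecision(MixedPrecisionPlugin):
        # Hypothetical subclass illustrating the idea from the docstring: leave
        # forward_context() empty so the *_step methods run outside torch.autocast,
        # letting the fwd/bwd functions invoked inside training_step manage precision.

        @contextmanager
        def forward_context(self):
            yield  # intentionally no torch.autocast wrapper here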
