diff --git a/nemo/collections/nlp/parts/megatron_trainer_builder.py b/nemo/collections/nlp/parts/megatron_trainer_builder.py
index c58e0be4a5089..3a19d8e186e83 100644
--- a/nemo/collections/nlp/parts/megatron_trainer_builder.py
+++ b/nemo/collections/nlp/parts/megatron_trainer_builder.py
@@ -69,6 +69,7 @@ def _training_strategy(self) -> Union[NLPDDPStrategy, NLPFSDPStrategy]:
                 sharded_checkpoint=sharded_checkpoint,
                 precision=self.cfg.trainer.precision,
                 nccl_communicator_config_path=self.cfg.model.get('nccl_communicator_config_path', None),
+                sharp=self.cfg.model.get('sharp', False),
             )

         return NLPDDPStrategy(
diff --git a/nemo/collections/nlp/parts/nlp_overrides.py b/nemo/collections/nlp/parts/nlp_overrides.py
index 5075863c3dbb6..6ee36d6983cb9 100644
--- a/nemo/collections/nlp/parts/nlp_overrides.py
+++ b/nemo/collections/nlp/parts/nlp_overrides.py
@@ -517,6 +517,7 @@ def __init__(
         sharded_checkpoint: bool = False,
         precision: Union[int, str] = 'bf16-mixed',
         nccl_communicator_config_path: Optional[str] = None,
+        sharp: bool = False,
         **kwargs: Union[Any, Dict[str, Any]],
     ) -> None:
         if not HAVE_APEX:
@@ -561,6 +562,7 @@ def __init__(
             )

         self.nccl_communicator_config_path = nccl_communicator_config_path
+        self.sharp = sharp
         super().__init__(**kwargs)

     def _set_mixed_precision_recipe(
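Usage note (not part of the diff): the change reads the flag from the model section of the experiment config, so enabling it is a matter of setting `model.sharp: True` there. Below is a minimal sketch of how the new lookup behaves, assuming a standard NeMo/Hydra-style OmegaConf config; the key name `sharp` and its `False` default come from the diff above, while the surrounding config values are purely illustrative.

# Minimal sketch, assuming an OmegaConf-based model config as used by NeMo.
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "trainer": {"precision": "bf16-mixed"},
        "model": {
            "nccl_communicator_config_path": None,
            "sharp": True,  # hypothetical user setting; defaults to False when omitted
        },
    }
)

# Mirrors the lookup added in megatron_trainer_builder.py:
sharp_enabled = cfg.model.get("sharp", False)
print(sharp_enabled)  # True

The strategy simply stores the value on `self.sharp`; how it is consumed downstream (e.g. when distributed communication is initialized) is outside the scope of this hunk.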