From cc78cb8acec0cf2279c8e9486e0be46c2b9f5c2b Mon Sep 17 00:00:00 2001 From: Sangkug Lym Date: Sat, 20 Jan 2024 01:02:23 +0900 Subject: [PATCH] Add the interface to use SHARP in the FSDP strategy (#8202) Signed-off-by: Sangkug Lym Signed-off-by: stevehuang52 --- nemo/collections/nlp/parts/megatron_trainer_builder.py | 1 + nemo/collections/nlp/parts/nlp_overrides.py | 2 ++ 2 files changed, 3 insertions(+) diff --git a/nemo/collections/nlp/parts/megatron_trainer_builder.py b/nemo/collections/nlp/parts/megatron_trainer_builder.py index c58e0be4a5089..3a19d8e186e83 100644 --- a/nemo/collections/nlp/parts/megatron_trainer_builder.py +++ b/nemo/collections/nlp/parts/megatron_trainer_builder.py @@ -69,6 +69,7 @@ def _training_strategy(self) -> Union[NLPDDPStrategy, NLPFSDPStrategy]: sharded_checkpoint=sharded_checkpoint, precision=self.cfg.trainer.precision, nccl_communicator_config_path=self.cfg.model.get('nccl_communicator_config_path', None), + sharp=self.cfg.model.get('sharp', False), ) return NLPDDPStrategy( diff --git a/nemo/collections/nlp/parts/nlp_overrides.py b/nemo/collections/nlp/parts/nlp_overrides.py index 5075863c3dbb6..6ee36d6983cb9 100644 --- a/nemo/collections/nlp/parts/nlp_overrides.py +++ b/nemo/collections/nlp/parts/nlp_overrides.py @@ -517,6 +517,7 @@ def __init__( sharded_checkpoint: bool = False, precision: Union[int, str] = 'bf16-mixed', nccl_communicator_config_path: Optional[str] = None, + sharp: bool = False, **kwargs: Union[Any, Dict[str, Any]], ) -> None: if not HAVE_APEX: @@ -561,6 +562,7 @@ def __init__( ) self.nccl_communicator_config_path = nccl_communicator_config_path + self.sharp = sharp super().__init__(**kwargs) def _set_mixed_precision_recipe(