From fbfa59212ccb26adffa668b5e66c3df6ce8bae5a Mon Sep 17 00:00:00 2001
From: Cheng-Ping Hsieh
Date: Fri, 18 Aug 2023 13:33:13 -0700
Subject: [PATCH 1/3] Fix restore

Signed-off-by: Cheng-Ping Hsieh
---
 .../nlp/models/language_modeling/megatron_gpt_sft_model.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_sft_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_sft_model.py
index 6f5356ebc757..5080a8c95dd5 100644
--- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_sft_model.py
+++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_sft_model.py
@@ -583,6 +583,7 @@ def inference_epoch_end(self, outputs, mode, data_cfg):
         # Merge the functionality of previous on_inference_epoch_end() within inference_epoch_end() func here
         app_state = AppState()
         self._restore_activation_checkpointing_args()
+        self._restore_sequence_parallelism_args()
         if hasattr(self, "_train_ds"):
             _reconfigure_microbatch_calculator(
                 rank=app_state.global_rank,

From b51d6f6430cb2c7614ffc6b82498166149dcaf0d Mon Sep 17 00:00:00 2001
From: Jason Wang
Date: Fri, 18 Aug 2023 13:56:40 -0700
Subject: [PATCH 2/3] reset and restore transformer config sequence parallel

Signed-off-by: Jason Wang
---
 .../nlp/models/language_modeling/megatron_gpt_model.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
index d799cb6fb044..792764786f5f 100644
--- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
+++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
@@ -1348,6 +1348,7 @@ def _reset_sequence_parallelism_args(self):

         # Reset config values. Needed for calling generate.
         self.cfg.sequence_parallel = False
+        self.transformer_config.sequence_parallel = False

         # Reset model parameters.
         for module in self.get_gpt_module_list():
@@ -1362,6 +1363,7 @@ def _restore_sequence_parallelism_args(self):
         """
         # Restore config values.
         self.cfg.sequence_parallel = self.last_sequence_parallel
+        self.transformer_config.sequence_parallel = self.last_sequence_parallel

         # Restore model parameters.
         for module in self.get_gpt_module_list():

From d2f3742e729a5439631a5018fbb92249e7fc596c Mon Sep 17 00:00:00 2001
From: Jason Wang
Date: Fri, 18 Aug 2023 15:07:02 -0700
Subject: [PATCH 3/3] modify model parallel config as well

Signed-off-by: Jason Wang
---
 .../nlp/models/language_modeling/megatron_gpt_model.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
index 792764786f5f..358f3387b812 100644
--- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
+++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
@@ -1348,6 +1348,7 @@ def _reset_sequence_parallelism_args(self):

         # Reset config values. Needed for calling generate.
         self.cfg.sequence_parallel = False
+        self.model_parallel_config.sequence_parallel = False
         self.transformer_config.sequence_parallel = False

         # Reset model parameters.
@@ -1363,6 +1364,7 @@ def _restore_sequence_parallelism_args(self):
         """
         # Restore config values.
         self.cfg.sequence_parallel = self.last_sequence_parallel
+        self.model_parallel_config.sequence_parallel = self.last_sequence_parallel
         self.transformer_config.sequence_parallel = self.last_sequence_parallel

         # Restore model parameters.
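
Note on the net effect of the series: patch 1 makes the SFT model restore
sequence parallelism at the end of an inference epoch, and patches 2 and 3
keep all three config objects (self.cfg, self.model_parallel_config,
self.transformer_config) in sync with the per-module flags during the
reset/restore cycle. A minimal sketch of the two methods as they stand after
this series follows; the lines not shown in the hunks above (the caching of
last_sequence_parallel and the inner submodule loop) are reconstructed
assumptions, not verbatim source:

    def _reset_sequence_parallelism_args(self):
        """Disable sequence parallelism everywhere, remembering the old value."""
        # Assumed: the previous setting is cached so the restore can undo this.
        self.last_sequence_parallel = self.cfg.sequence_parallel

        # Reset config values. Needed for calling generate.
        self.cfg.sequence_parallel = False
        self.model_parallel_config.sequence_parallel = False
        self.transformer_config.sequence_parallel = False

        # Reset model parameters.
        for module in self.get_gpt_module_list():
            # Assumed inner loop: walk submodules and clear any SP flag they carry.
            for mod in module.modules():
                if hasattr(mod, "sequence_parallel"):
                    mod.sequence_parallel = False

    def _restore_sequence_parallelism_args(self):
        """Restore the sequence-parallel setting saved by the reset above."""
        # Restore config values.
        self.cfg.sequence_parallel = self.last_sequence_parallel
        self.model_parallel_config.sequence_parallel = self.last_sequence_parallel
        self.transformer_config.sequence_parallel = self.last_sequence_parallel

        # Restore model parameters.
        for module in self.get_gpt_module_list():
            for mod in module.modules():
                if hasattr(mod, "sequence_parallel"):
                    mod.sequence_parallel = self.last_sequence_parallel

As the in-diff comment "Needed for calling generate." indicates, sequence
parallelism is turned off before generation and must be turned back on
afterwards; before this series, only self.cfg and the module flags were
restored, leaving the other config objects stale.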