From f1c98447936f0c89d172dd6e6aa6f566720d3660 Mon Sep 17 00:00:00 2001 From: Chen Cui Date: Mon, 4 Dec 2023 17:01:40 -0800 Subject: [PATCH 1/2] support O2 Signed-off-by: Chen Cui --- nemo/collections/nlp/parts/mixins/nlp_adapter_mixins.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/nemo/collections/nlp/parts/mixins/nlp_adapter_mixins.py b/nemo/collections/nlp/parts/mixins/nlp_adapter_mixins.py index 16a3850852d4..0c7fe51c60fc 100644 --- a/nemo/collections/nlp/parts/mixins/nlp_adapter_mixins.py +++ b/nemo/collections/nlp/parts/mixins/nlp_adapter_mixins.py @@ -67,7 +67,7 @@ def __init__(self, *args, **kwargs): self.use_ptuning_only = False super().__init__(*args, **kwargs) if hasattr(self, "enc_dec_model"): - self.model_prefix = "enc_dec_model." # for T5 + self.model_prefix = "enc_dec_model.module." if self.cfg.megatron_amp_O2 else "enc_dec_model." # for T5 else: self.model_prefix = "model.module." if self.cfg.megatron_amp_O2 else "model." @@ -351,7 +351,7 @@ def sharded_state_dict(self, prefix: str = ''): if not use_mcore_gpt or (self.use_peft and self.setup_complete): return None else: - return self.model.sharded_state_dict(prefix=self.model_prefix) + return super().sharded_state_dict(prefix=prefix) def load_state_dict(self, state_dict, strict: bool = True): if len(state_dict) == 0: From 4c41c6b8a552959aaaf2b82bf729d817389bdfb4 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 5 Dec 2023 01:05:31 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- nemo/collections/nlp/parts/mixins/nlp_adapter_mixins.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo/collections/nlp/parts/mixins/nlp_adapter_mixins.py b/nemo/collections/nlp/parts/mixins/nlp_adapter_mixins.py index 0c7fe51c60fc..853ffc6ea012 100644 --- a/nemo/collections/nlp/parts/mixins/nlp_adapter_mixins.py +++ b/nemo/collections/nlp/parts/mixins/nlp_adapter_mixins.py @@ -67,7 +67,7 @@ def __init__(self, *args, **kwargs): self.use_ptuning_only = False super().__init__(*args, **kwargs) if hasattr(self, "enc_dec_model"): - self.model_prefix = "enc_dec_model.module." if self.cfg.megatron_amp_O2 else "enc_dec_model." # for T5 + self.model_prefix = "enc_dec_model.module." if self.cfg.megatron_amp_O2 else "enc_dec_model." # for T5 else: self.model_prefix = "model.module." if self.cfg.megatron_amp_O2 else "model."