From 53f0761dcd80626be8ba55d4f71886085bbcb9b6 Mon Sep 17 00:00:00 2001 From: Alexandros Koumparoulis Date: Thu, 6 Jun 2024 13:05:17 -0700 Subject: [PATCH] fix get_peft_state_dict Signed-off-by: Alexandros Koumparoulis --- nemo/collections/nlp/parts/mixins/nlp_adapter_mixins.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/nemo/collections/nlp/parts/mixins/nlp_adapter_mixins.py b/nemo/collections/nlp/parts/mixins/nlp_adapter_mixins.py index a05a3f165e571..e1cea845f2358 100644 --- a/nemo/collections/nlp/parts/mixins/nlp_adapter_mixins.py +++ b/nemo/collections/nlp/parts/mixins/nlp_adapter_mixins.py @@ -350,7 +350,7 @@ def load_adapters( '.nemo' ), "Inferring peft scheme is only supported for .nemo checkpoints. Please supply the `peft_cfgs` argument." peft_cfgs = [PEFT_CONFIG_MAP[conf.peft.peft_scheme](conf)] - if self.cfg.megatron_amp_O2: + if getattr(self, 'megatron_amp_O2', False): state_dict = {replace_prefix(k, 'model.', 'model.module.'): v for k, v in state_dict.items()} self.add_adapter(peft_cfgs) if not self.ptuning_only_and_non_first_stage: @@ -397,11 +397,11 @@ def get_peft_state_dict(self): """ Gets the keys associated with the adapters only. """ - state_dict = super().state_dict() + state_dict = self._unwrap_model().state_dict() peft_state_dict = {} for k in self.adapter_keys.union(self.tunable_base_param_keys): # state_dict keys needs to be in non-O2 format and will be corrected in PEFTSaveRestoreConnector if O2=True - new_k = k.replace("model.module.", "model.", 1) + new_k = k.replace("module.", "", 1) peft_state_dict[new_k] = state_dict[new_k] return peft_state_dict