Skip to content

Commit

Permalink
fix get_peft_state_dict
Browse files Browse the repository at this point in the history
Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>
  • Loading branch information
akoumpa committed Jun 6, 2024
1 parent 85f810a commit 53f0761
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions nemo/collections/nlp/parts/mixins/nlp_adapter_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,7 +350,7 @@ def load_adapters(
'.nemo'
), "Inferring peft scheme is only supported for .nemo checkpoints. Please supply the `peft_cfgs` argument."
peft_cfgs = [PEFT_CONFIG_MAP[conf.peft.peft_scheme](conf)]
if self.cfg.megatron_amp_O2:
if getattr(self, 'megatron_amp_O2', False):
state_dict = {replace_prefix(k, 'model.', 'model.module.'): v for k, v in state_dict.items()}
self.add_adapter(peft_cfgs)
if not self.ptuning_only_and_non_first_stage:
Expand Down Expand Up @@ -397,11 +397,11 @@ def get_peft_state_dict(self):
"""
Gets the keys associated with the adapters only.
"""
state_dict = super().state_dict()
state_dict = self._unwrap_model().state_dict()
peft_state_dict = {}
for k in self.adapter_keys.union(self.tunable_base_param_keys):
# state_dict keys needs to be in non-O2 format and will be corrected in PEFTSaveRestoreConnector if O2=True
new_k = k.replace("model.module.", "model.", 1)
new_k = k.replace("module.", "", 1)
peft_state_dict[new_k] = state_dict[new_k]
return peft_state_dict

Expand Down

0 comments on commit 53f0761

Please sign in to comment.