Skip to content

Commit

Permalink
Fix peft weights loading (NVIDIA#9341)
Browse files Browse the repository at this point in the history
Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>
Signed-off-by: Boxiang Wang <boxiangw@nvidia.com>
  • Loading branch information
yaoyu-33 authored and BoxiangW committed Jun 5, 2024
1 parent 6b76597 commit e4c4b83
Showing 1 changed file with 2 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -133,10 +133,10 @@ def load_adapters(
state_dict = torch.load(filepath, map_location)['state_dict']
else:
raise RuntimeError(f"{filepath} is not nemo file or ckpt file")
if self.cfg.megatron_amp_O2:
state_dict = {replace_prefix(k, 'model.', 'model.module.'): v for k, v in state_dict.items()}
if not self.ptuning_only_and_non_first_stage:
assert set(state_dict.keys()) == self.adapter_keys.union(self.tunable_base_param_keys)
if self.cfg.megatron_amp_O2:
state_dict = {replace_prefix(k, 'model.', 'model.module.'): v for k, v in state_dict.items()}

missing_keys, unexpected_keys = NLPModel.load_state_dict(self, state_dict, strict=False)

Expand Down

0 comments on commit e4c4b83

Please sign in to comment.