
Commit

check for attr before calling
Signed-off-by: arendu <adithya.r@gmail.com>
arendu committed Jan 13, 2024
1 parent 482dde4 commit 6a3c6d4
Showing 2 changed files with 15 additions and 12 deletions.
23 changes: 13 additions & 10 deletions nemo/collections/nlp/parts/mixins/nlp_adapter_mixins.py
@@ -196,18 +196,11 @@ def add_adapter(self, peft_cfgs: Union[PEFTConfig, List[PEFTConfig]]):
         self.tunable_base_param_keys = set()
 
         for cfg in peft_cfgs:
-            if cfg.weight_tying:
+            if hasattr(cfg, "weight_tying") and cfg.weight_tying:
                 self.tie_weights(cfg)
 
-            if cfg.tunable_base_param_names:
-                for n, p in self.named_parameters():
-                    for tpn in cfg.tunable_base_param_names:
-                        if (
-                            f".{tpn}." in n
-                        ):  # TODO: simplistic param name matching, should support regex-like syntax @adithyare
-                            self.tunable_base_param_keys.add(n)
-                            p.requires_grad = True  # We set these to true to trigger setup_optimizer_param_groups
-
+            if hasattr(cfg, "tunable_base_param_names") and cfg.tunable_base_param_names:
+                self.set_tunable_base_params(cfg)
         self.use_peft = True

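The point of the new hasattr checks: add_adapter can receive several PEFTConfig subclasses, and not all of them define weight_tying or tunable_base_param_names, so each attribute is tested before it is read. A minimal standalone sketch of the guard pattern (plain Python, not NeMo code; the two config classes below are hypothetical stand-ins):

    class HypotheticalAdapterCfg:
        weight_tying = True  # defines weight_tying only

    class HypotheticalSelectiveCfg:
        tunable_base_param_names = ["layernorm", "dense"]  # defines tunable_base_param_names only

    def add_adapter_sketch(peft_cfgs):
        for cfg in peft_cfgs:
            # Guard with hasattr so a cfg that lacks the attribute is simply skipped
            # instead of raising AttributeError.
            if hasattr(cfg, "weight_tying") and cfg.weight_tying:
                print("would tie weights for", type(cfg).__name__)
            if hasattr(cfg, "tunable_base_param_names") and cfg.tunable_base_param_names:
                print("would unfreeze base params for", type(cfg).__name__)

    add_adapter_sketch([HypotheticalAdapterCfg(), HypotheticalSelectiveCfg()])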
def _get_config_and_state_dict_from_nemo(self, filepath, map_location):
@@ -303,6 +296,16 @@ def load_adapters(
         assert set(state_dict.keys()) == self.adapter_keys.union(self.tunable_base_param_keys)
         super().load_state_dict(state_dict, strict=False)
 
+    def set_tunable_base_params(self, peft_cfg):
+        for n, p in self.named_parameters():
+            for tpn in peft_cfg.tunable_base_param_names:
+                # TODO: simplistic param name matching, should support regex-like syntax @adithyare
+                if (f".{tpn}." in n):
+                    self.tunable_base_param_keys.add(n)
+                    p.requires_grad = True  # We set these to true to trigger setup_optimizer_param_groups
+
+
     def tie_weights(self, peft_cfg):
         pos_idx = 0

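The relocated matching logic in set_tunable_base_params still uses the simple substring test f".{tpn}." in n that the TODO comments flag. A rough sketch of what that test does and does not match, with made-up parameter names (not taken from a real NeMo model):

    tunable_base_param_names = ["layernorm", "dense"]
    named = [
        "model.layers.0.input_layernorm.weight",       # not matched: key appears as "input_layernorm", not ".layernorm."
        "model.layers.0.self_attention.dense.weight",  # matched by ".dense."
        "model.layers.0.mlp.dense_h_to_4h.weight",     # not matched: ".dense_h_to_4h." is not ".dense."
    ]
    for n in named:
        for tpn in tunable_base_param_names:
            if f".{tpn}." in n:
                print("would set requires_grad=True on:", n)
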
4 changes: 2 additions & 2 deletions nemo/collections/nlp/parts/peft_config.py
@@ -44,8 +44,7 @@ def __init__(self, peft_cfg: DictConfig, name_key_to_cfg: Dict):
         self.name_key_to_cfg = name_key_to_cfg
 
         self.layer_selection = peft_cfg.get("layer_selection", None)
-        self.weight_tying = peft_cfg.get("weight_tying", False)
-        self.tunable_param_names = peft_cfg.get("tunable_param_names", [])
+        self.weight_tying = peft_cfg.get("weight_tying", False)  # TODO: move this attr to LoraPEFTConfig and AdapterPEFTConfig classes
 
     def get_config_dict(self):
         return self.name_key_to_cfg
@@ -55,6 +54,7 @@ class SelectivePEFTConfig(PEFTConfig):
     def __init__(self, cfg):
         selective_cfg = cfg.peft.selective_tuning
         super().__init__(selective_cfg, name_key_to_cfg={})
+        self.tunable_base_param_names = selective_cfg.get("tunable_base_param_names", [])
 
 
 class LoraPEFTConfig(PEFTConfig):
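The SelectivePEFTConfig change reads tunable_base_param_names from the peft.selective_tuning section of the experiment config, defaulting to an empty list when the key is absent. A rough sketch of that lookup with an OmegaConf config (the layout mirrors the diff; the field values are invented for illustration):

    from omegaconf import OmegaConf

    cfg = OmegaConf.create(
        {"peft": {"selective_tuning": {"tunable_base_param_names": ["layernorm", "dense"]}}}
    )
    selective_cfg = cfg.peft.selective_tuning
    # Same pattern as the new line in SelectivePEFTConfig.__init__:
    print(selective_cfg.get("tunable_base_param_names", []))  # ['layernorm', 'dense']

    # When the key is missing, .get falls back to the empty-list default.
    bare = OmegaConf.create({"peft": {"selective_tuning": {}}})
    print(bare.peft.selective_tuning.get("tunable_base_param_names", []))  # []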
