From 60fc3673c3e5c85d5d993025b4cf99a1c4600ddd Mon Sep 17 00:00:00 2001
From: calpt
Date: Mon, 27 Jun 2022 14:05:01 +0200
Subject: [PATCH] Add `adapter_summary()` method (#371)

---
 src/transformers/adapters/model_mixin.py      | 83 ++++++++++++++++++-
 src/transformers/adapters/prefix_tuning.py    |  4 +-
 tests_adapters/methods/test_adapter_common.py | 16 ++++
 tests_adapters/test_adapter_composition.py    |  2 +-
 4 files changed, 101 insertions(+), 4 deletions(-)

diff --git a/src/transformers/adapters/model_mixin.py b/src/transformers/adapters/model_mixin.py
index 3482262e5..28b4beb0d 100644
--- a/src/transformers/adapters/model_mixin.py
+++ b/src/transformers/adapters/model_mixin.py
@@ -345,6 +345,9 @@ def delete_adapter(self, adapter_name: str):
             return
         del self.config.adapters.adapters[adapter_name]
         self.apply_to_adapter_layers(lambda i, layer: layer.delete_adapter(adapter_name))
+        # PHM Layer
+        if adapter_name in self.shared_parameters:
+            del self.shared_parameters[adapter_name]
         if isinstance(self, InvertibleAdaptersMixin):
             self.delete_invertible_adapter(adapter_name)
         # Reset active adapters if this was the only active adapter
@@ -754,20 +757,96 @@ def get_adapter(self, name) -> dict:
 
         Returns:
             dict: A nested dictionary containing the weights of the adapter. The dictionary is structured as follow:
-            {<layer_id>: {<module_location>: <nn.Module>}}.
+            {<layer_id>: {<module_location>: <nn.Module>}}. <layer_id> = -1 indicates global/ shared weights.
         """
         destination = defaultdict(dict)
 
+        # global weights are saved at index -1
+        if name in self.shared_parameters:
+            destination[-1]["shared"] = self.shared_parameters[name]
+        if isinstance(self, InvertibleAdaptersMixin) and name in self.invertible_adapters:
+            destination[-1]["invertible"] = self.invertible_adapters[name]
+
+        # use a custom index to ensure numbering is from 0 to N layers
         for i, (_, layer) in enumerate(self.iter_layers()):
             for module in layer.modules():
                 if isinstance(module, AdapterLayerBase):
                     adapter_module = module.get_adapter(name)
                     if adapter_module is not None:
-                        destination[i][module.location_key] = adapter_module
+                        # location_key might already be added before -> concat to ModuleList
+                        if module.location_key in destination[i]:
+                            old_module = destination[i][module.location_key]
+                            if isinstance(old_module, nn.ModuleList):
+                                old_module.append(adapter_module)
+                            else:
+                                destination[i][module.location_key] = nn.ModuleList([old_module, adapter_module])
+                        else:
+                            destination[i][module.location_key] = adapter_module
 
         return dict(destination)
 
+    def adapter_summary(self, as_dict=False) -> Union[str, dict]:
+        """
+        Returns a string summary of all adapters currently added to the model. Each entry in the summary table has the
+        following attributes:
+
+            - name: the name of the adapter
+            - architecture: the architectural base of the adapter
+            - #param: the number of parameters of the adapter
+            - %param: the number of parameters of the adapter relative to the full model
+            - active: whether the adapter is active
+            - train: whether the adapter weights are enabled for training
+        """
+        # table header
+        header = ["name", "architecture", "#param", "%param", "active", "train"]
+        # rows containing adapter info
+        rows = []
+        # fill in data for adapters
+        for name, config_name in self.config.adapters.adapters.items():
+            config = self.config.adapters.config_map[config_name]
+            row = {"name": name, "architecture": config.architecture or "bottleneck"}
+            weights = self.get_adapter(name)
+            row["active"] = self.active_adapters is not None and name in self.active_adapters.flatten()
+            # count parameters
+            no_params = 0
+            train = True
+            for _, module_dict in weights.items():
+                for _, module in module_dict.items():
+                    no_params += sum(p.numel() for p in module.parameters())
+                    train &= all(p.requires_grad for p in module.parameters())
+            row["#param"] = no_params
+            row["train"] = train
+            rows.append(row)
+        # count no. of parameters in base network
+        model_no_params = sum(p.numel() for p in self.base_model.parameters())
+        model_no_params -= sum([r["#param"] for r in rows])
+        # add %param info
+        for row in rows:
+            row["%param"] = row["#param"] / model_no_params * 100
+        # add full model info
+        rows.append(
+            {
+                "name": "Full model",
+                "#param": model_no_params,
+                "%param": 100.0,
+                "train": not getattr(self.base_model, "model_frozen", False),
+            }
+        )
+
+        if as_dict:
+            return rows
+        else:
+            # print
+            total_length = 80
+            header_format = "{:<25}{:<15}{:>12}{:>12}{:>8}{:>8}"
+            row_format = "{:<25}{:<15}{:>12}{:>12.3f}{:>8}{:>8}"
+            s = [header_format.format(*map(lambda x: x.title(), header))]
+            s.append("-" * total_length)
+            for row in rows:
+                s.append(row_format.format(*[row.get(h, "") for h in header]))
+            s.insert(len(s) - 1, "-" * total_length)
+            return "\n".join(s)
+
     def eject_prefix_tuning(self, name: str):
         """
         Converts the prefix tuning with the given name from the reparameterized form into the flat form.
diff --git a/src/transformers/adapters/prefix_tuning.py b/src/transformers/adapters/prefix_tuning.py
index 1c6f015b1..506a2b224 100644
--- a/src/transformers/adapters/prefix_tuning.py
+++ b/src/transformers/adapters/prefix_tuning.py
@@ -278,7 +278,9 @@ def enable_adapters(self, adapter_setup: AdapterCompositionBlock, unfreeze_adapt
     def get_adapter(self, adapter_name):
         # Make sure to only return params once
         if adapter_name in self.prefixes and self.prefixes[adapter_name] == 0:
-            return self.pool.get_prefix(adapter_name)
+            prefix_module = self.pool.get_prefix(adapter_name)
+            if prefix_module is not None:
+                return prefix_module[self.location_key]
 
         return None
 
diff --git a/tests_adapters/methods/test_adapter_common.py b/tests_adapters/methods/test_adapter_common.py
index bc42ff757..6bdd12497 100644
--- a/tests_adapters/methods/test_adapter_common.py
+++ b/tests_adapters/methods/test_adapter_common.py
@@ -182,6 +182,22 @@ def test_model_config_serialization(self):
         # should not raise an exception
         model.config.to_json_string()
 
+    def test_model_adapter_summary(self):
+        # count model parameters before
+        model = self.get_model()
+        model_no_params = sum(p.numel() for p in model.parameters())
+        for k, v in ADAPTER_CONFIG_MAP.items():
+            # HACK: reduce the reduction factor such that
+            # the small test model can have a phm_dim of 4
+            if hasattr(v, "phm_layer") and v.phm_layer:
+                v = v.__class__(reduction_factor=4)
+            model.add_adapter(k, config=v)
+        summary = model.adapter_summary(as_dict=True)
+        self.assertEqual(len(ADAPTER_CONFIG_MAP) + 1, len(summary))
+        for name in ADAPTER_CONFIG_MAP.keys():
+            self.assertTrue(any([row["name"] == name for row in summary]))
+        self.assertEqual(model_no_params, summary[-1]["#param"])
+
     def test_loading_adapter_weights_with_prefix(self):
         if self.config_class not in ADAPTER_MODEL_MAPPING:
             self.skipTest("Does not support flex heads.")
diff --git a/tests_adapters/test_adapter_composition.py b/tests_adapters/test_adapter_composition.py
index 1c0230251..9450a6fee 100644
--- a/tests_adapters/test_adapter_composition.py
+++ b/tests_adapters/test_adapter_composition.py
@@ -311,7 +311,7 @@ def test_parallel_training(self):
         train_dataset = self.dataset(tokenizer)
 
         training_args = TrainingArguments(
-            output_dir="./examples", do_train=True, learning_rate=0.1, max_steps=15, no_cuda=True
+            output_dir="./examples", do_train=True, learning_rate=0.5, max_steps=20, no_cuda=True
        )
 
         # evaluate
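
Usage sketch (illustrative, not part of the diff above): the snippet below shows how the `adapter_summary()` method added in src/transformers/adapters/model_mixin.py could be called. It assumes the flex-head `BertAdapterModel` class shipped with this library; the checkpoint, adapter names, and config strings are placeholder choices.

# Minimal usage sketch for adapter_summary(); names and configs below are illustrative.
from transformers.adapters import BertAdapterModel

model = BertAdapterModel.from_pretrained("bert-base-uncased")
model.add_adapter("seq_adapter", config="pfeiffer")          # bottleneck adapter
model.add_adapter("prefix_adapter", config="prefix_tuning")  # prefix tuning adapter
model.set_active_adapters("seq_adapter")

# String form: fixed-width table with Name, Architecture, #Param, %Param, Active, Train columns.
print(model.adapter_summary())

# Dict form: one row per adapter plus a final "Full model" row.
for row in model.adapter_summary(as_dict=True):
    print(row["name"], row["#param"], round(row["%param"], 3))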
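
A second illustrative sketch, this one for the extended `get_adapter()` return structure documented above: layer index -1 now carries global/shared weights (PHM parameters under "shared", invertible adapters under "invertible"), and repeated location keys within a layer are merged into an `nn.ModuleList`. The setup mirrors the previous sketch and is not taken from the patch.

# Inspect the nested dict returned by get_adapter() after this change; setup is illustrative.
from transformers.adapters import BertAdapterModel

model = BertAdapterModel.from_pretrained("bert-base-uncased")
model.add_adapter("seq_adapter", config="pfeiffer")

weights = model.get_adapter("seq_adapter")
for layer_id, modules in sorted(weights.items()):
    # layer_id == -1 holds global/ shared weights; 0..N index the transformer layers
    for location_key, module in modules.items():
        n_params = sum(p.numel() for p in module.parameters())
        print(f"layer {layer_id:>3} | {location_key:<20} | {type(module).__name__:<15} | {n_params} params")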