Commit
Merge branch 'master' into dev/vit
calpt committed Jun 27, 2022
2 parents 94bec82 + 60fc367 commit 27e7bd5
Showing 4 changed files with 102 additions and 5 deletions.
83 changes: 81 additions & 2 deletions src/transformers/adapters/model_mixin.py
@@ -472,6 +472,9 @@ def delete_adapter(self, adapter_name: str):
            return
        del self.config.adapters.adapters[adapter_name]
        self.apply_to_adapter_layers(lambda i, layer: layer.delete_adapter(adapter_name))
        # PHM Layer
        if adapter_name in self.shared_parameters:
            del self.shared_parameters[adapter_name]
        if isinstance(self, InvertibleAdaptersMixin):
            self.delete_invertible_adapter(adapter_name)
        # Reset active adapters if this was the only active adapter
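
The added lines above make delete_adapter also drop the shared PHM parameters registered under the adapter's name. A minimal sketch of the resulting behavior, assuming a PHM-based config such as CompacterConfig (the checkpoint and adapter name are illustrative, not part of this commit):

from transformers import AutoModel
from transformers.adapters import CompacterConfig  # adapter config backed by PHM layers

model = AutoModel.from_pretrained("bert-base-uncased")
model.add_adapter("phm_adapter", config=CompacterConfig())

# Compacter adapters register an entry in model.shared_parameters;
# delete_adapter() now removes that entry together with the per-layer modules.
model.delete_adapter("phm_adapter")
assert "phm_adapter" not in model.shared_parameters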
@@ -762,20 +765,96 @@ def get_adapter(self, name) -> dict:
        Returns:
            dict: A nested dictionary containing the weights of the adapter. The dictionary is structured as follows:
            {<layer id>: {<module location>: <nn.Module>}}.
            {<layer id>: {<module location>: <nn.Module>}}. <layer id> = -1 indicates global/shared weights.
        """
        destination = defaultdict(dict)

        # global weights are saved at index -1
        if name in self.shared_parameters:
            destination[-1]["shared"] = self.shared_parameters[name]
        if isinstance(self, InvertibleAdaptersMixin) and name in self.invertible_adapters:
            destination[-1]["invertible"] = self.invertible_adapters[name]

        # use a custom index to ensure numbering is from 0 to N layers
        for i, (_, layer) in enumerate(self.iter_layers()):
            for module in layer.modules():
                if isinstance(module, AdapterLayerBase):
                    adapter_module = module.get_adapter(name)
                    if adapter_module is not None:
                        destination[i][module.location_key] = adapter_module
                        # location_key might already be added before -> concat to ModuleList
                        if module.location_key in destination[i]:
                            old_module = destination[i][module.location_key]
                            if isinstance(old_module, nn.ModuleList):
                                old_module.append(adapter_module)
                            else:
                                destination[i][module.location_key] = nn.ModuleList([old_module, adapter_module])
                        else:
                            destination[i][module.location_key] = adapter_module

        return dict(destination)

    def adapter_summary(self, as_dict=False) -> Union[str, dict]:
        """
        Returns a string summary of all adapters currently added to the model. Each entry in the summary table has the
        following attributes:

            - name: the name of the adapter
            - architecture: the architectural base of the adapter
            - #param: the number of parameters of the adapter
            - %param: the number of parameters of the adapter relative to the full model
            - active: whether the adapter is active
            - train: whether the adapter weights are enabled for training
        """
        # table header
        header = ["name", "architecture", "#param", "%param", "active", "train"]
        # rows containing adapter info
        rows = []
        # fill in data for adapters
        for name, config_name in self.config.adapters.adapters.items():
            config = self.config.adapters.config_map[config_name]
            row = {"name": name, "architecture": config.architecture or "bottleneck"}
            weights = self.get_adapter(name)
            row["active"] = self.active_adapters is not None and name in self.active_adapters.flatten()
            # count parameters
            no_params = 0
            train = True
            for _, module_dict in weights.items():
                for _, module in module_dict.items():
                    no_params += sum(p.numel() for p in module.parameters())
                    train &= all(p.requires_grad for p in module.parameters())
            row["#param"] = no_params
            row["train"] = train
            rows.append(row)
        # count no. of parameters in base network
        model_no_params = sum(p.numel() for p in self.base_model.parameters())
        model_no_params -= sum([r["#param"] for r in rows])
        # add %param info
        for row in rows:
            row["%param"] = row["#param"] / model_no_params * 100
        # add full model info
        rows.append(
            {
                "name": "Full model",
                "#param": model_no_params,
                "%param": 100.0,
                "train": not getattr(self.base_model, "model_frozen", False),
            }
        )

        if as_dict:
            return rows
        else:
            # print
            total_length = 80
            header_format = "{:<25}{:<15}{:>12}{:>12}{:>8}{:>8}"
            row_format = "{:<25}{:<15}{:>12}{:>12.3f}{:>8}{:>8}"
            s = [header_format.format(*map(lambda x: x.title(), header))]
            s.append("-" * total_length)
            for row in rows:
                s.append(row_format.format(*[row.get(h, "") for h in header]))
            s.insert(len(s) - 1, "-" * total_length)
            return "\n".join(s)

    def eject_prefix_tuning(self, name: str):
        """
        Converts the prefix tuning with the given name from the reparameterized form into the flat form.
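For context, a brief usage sketch of the two methods touched above (the checkpoint and adapter name are illustrative assumptions, not part of this commit):

from transformers import AutoModel

model = AutoModel.from_pretrained("bert-base-uncased")
model.add_adapter("my_adapter")  # default bottleneck adapter

# get_adapter() returns {<layer id>: {<module location>: <nn.Module>}};
# index -1, if present, holds global/shared weights (e.g. PHM or invertible modules).
weights = model.get_adapter("my_adapter")
global_weights = weights.get(-1, {})
per_layer = {i: modules for i, modules in weights.items() if i >= 0}

# adapter_summary() reports name, architecture, #param, %param, active and train
# for each adapter plus a closing "Full model" row.
print(model.adapter_summary())
rows = model.adapter_summary(as_dict=True)  # same data as a list of dicts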
4 changes: 3 additions & 1 deletion src/transformers/adapters/prefix_tuning.py
@@ -278,7 +278,9 @@ def enable_adapters(self, adapter_setup: AdapterCompositionBlock, unfreeze_adapt
    def get_adapter(self, adapter_name):
        # Make sure to only return params once
        if adapter_name in self.prefixes and self.prefixes[adapter_name] == 0:
            return self.pool.get_prefix(adapter_name)
            prefix_module = self.pool.get_prefix(adapter_name)
            if prefix_module is not None:
                return prefix_module[self.location_key]

        return None

16 changes: 16 additions & 0 deletions tests_adapters/methods/test_adapter_common.py
@@ -185,6 +185,22 @@ def test_model_config_serialization(self):
        # should not raise an exception
        model.config.to_json_string()

    def test_model_adapter_summary(self):
        # count model parameters before
        model = self.get_model()
        model_no_params = sum(p.numel() for p in model.parameters())
        for k, v in ADAPTER_CONFIG_MAP.items():
            # HACK: reduce the reduction factor such that
            # the small test model can have a phm_dim of 4
            if hasattr(v, "phm_layer") and v.phm_layer:
                v = v.__class__(reduction_factor=4)
            model.add_adapter(k, config=v)
        summary = model.adapter_summary(as_dict=True)
        self.assertEqual(len(ADAPTER_CONFIG_MAP) + 1, len(summary))
        for name in ADAPTER_CONFIG_MAP.keys():
            self.assertTrue(any([row["name"] == name for row in summary]))
        self.assertEqual(model_no_params, summary[-1]["#param"])

def test_loading_adapter_weights_with_prefix(self):
if self.config_class not in ADAPTER_MODEL_MAPPING:
self.skipTest("Does not support flex heads.")
4 changes: 2 additions & 2 deletions tests_adapters/test_adapter_composition.py
@@ -310,8 +310,8 @@ def test_parallel_training(self):
        training_args = TrainingArguments(
            output_dir="./examples",
            do_train=True,
            learning_rate=0.1,
            max_steps=15,
            learning_rate=0.5,
            max_steps=20,
            no_cuda=True,
            remove_unused_columns=False,
        )
