Skip to content

Commit

Permalink
Unify built-in & custom head implementation (#252)
Browse files Browse the repository at this point in the history
Unify the built-in and custom head implementations so that, apart from the head_type argument required when adding a custom head, they behave identically.
Co-authored-by: calpt <calpt@mail.de>
  • Loading branch information
hSterz authored and calpt committed Feb 8, 2022
1 parent 675d10e commit 1e540d1
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 14 deletions.
12 changes: 7 additions & 5 deletions src/transformers/adapters/heads/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -508,8 +508,7 @@ def add_prediction_head_from_config(
self.add_prediction_head(head, overwrite_ok=overwrite_ok, set_active=set_active)
elif head_type in self.config.custom_heads:
# we have to re-add the head type for custom heads
config["head_type"] = head_type
self.add_custom_head(head_name, config, overwrite_ok=overwrite_ok)
self.add_custom_head(head_type, head_name, overwrite_ok=overwrite_ok, **config)
else:
raise AttributeError(
"Given head type '{}' is not known. Please register this head type before loading the model".format(
Expand Down Expand Up @@ -595,9 +594,12 @@ def set_active_adapters(
else:
logger.info("Could not identify '{}' as a valid prediction head.".format(final_block))

def add_custom_head(self, head_name, config, overwrite_ok=False, set_active=True):
if config["head_type"] in self.config.custom_heads:
head = self.config.custom_heads[config["head_type"]](head_name, config, self)
def add_custom_head(self, head_type, head_name, overwrite_ok=False, set_active=True, **kwargs):
if head_type in self.config.custom_heads:
head = self.config.custom_heads[head_type](self, head_name, **kwargs)
# When a built-in head is added as a custom head it does not have the head_type property
if not hasattr(head.config, "head_type"):
head.config["head_type"] = head_type
self.add_prediction_head(head, overwrite_ok, set_active=set_active)
else:
raise AttributeError(
Expand Down
39 changes: 30 additions & 9 deletions tests/test_adapter_custom_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,18 @@

from tests.test_modeling_common import ids_tensor
from transformers import AutoConfig, AutoModelWithHeads
from transformers.adapters.heads import PredictionHead
from transformers.adapters.heads import ClassificationHead, PredictionHead
from transformers.testing_utils import require_torch, torch_device


class CustomHead(PredictionHead):
def __init__(self, name, config, model):
super().__init__(name)
def __init__(
    self,
    model,
    head_name,
    **config,
):
    """Minimal custom prediction head used by the test suite.

    Args:
        model: The host model with heads; passed through to ``self.build`` so the
            head's layers can be constructed on top of it.
        head_name: Name under which this head is registered on the model.
        **config: Arbitrary head configuration (e.g. ``num_labels``, ``layers``,
            ``activation_function``), stored as a plain dict on ``self.config``.
    """
    super().__init__(head_name)
    # NOTE(review): config is kept as a raw kwargs dict here — presumably
    # PredictionHead.build reads keys like "layers"/"num_labels" from it; the
    # build() implementation is not visible in this chunk.
    self.config = config
    self.build(model=model)

Expand All @@ -27,8 +32,8 @@ def test_add_custom_head(self):
model_name = "bert-base-uncased"
model = AutoModelWithHeads.from_pretrained(model_name)
model.register_custom_head("tag", CustomHead)
config = {"head_type": "tag", "num_labels": 3, "layers": 2, "activation_function": "tanh"}
model.add_custom_head("custom_head", config)
config = {"num_labels": 3, "layers": 2, "activation_function": "tanh"}
model.add_custom_head(head_type="tag", head_name="custom_head", **config)
model.eval()
model.to(torch_device)
in_data = ids_tensor((1, 128), 1000)
Expand All @@ -42,8 +47,8 @@ def test_custom_head_from_model_config(self):
model_name = "bert-base-uncased"
model_config = AutoConfig.from_pretrained(model_name, custom_heads={"tag": CustomHead})
model = AutoModelWithHeads.from_pretrained(model_name, config=model_config)
config = {"head_type": "tag", "num_labels": 3, "layers": 2, "activation_function": "tanh"}
model.add_custom_head("custom_head", config)
config = {"num_labels": 3, "layers": 2, "activation_function": "tanh"}
model.add_custom_head(head_type="tag", head_name="custom_head", **config)
model.eval()
model.to(torch_device)
in_data = ids_tensor((1, 128), 1000)
Expand All @@ -58,8 +63,8 @@ def test_save_load_custom_head(self):
model_config = AutoConfig.from_pretrained(model_name, custom_heads={"tag": CustomHead})
model1 = AutoModelWithHeads.from_pretrained(model_name, config=model_config)
model2 = AutoModelWithHeads.from_pretrained(model_name, config=model_config)
config = {"head_type": "tag", "num_labels": 3, "layers": 2, "activation_function": "tanh"}
model1.add_custom_head("custom_head", config)
config = {"num_labels": 3, "layers": 2, "activation_function": "tanh"}
model1.add_custom_head(head_type="tag", head_name="custom_head", **config)

with tempfile.TemporaryDirectory() as temp_dir:
model1.save_head(temp_dir, "custom_head")
Expand All @@ -78,3 +83,19 @@ def test_save_load_custom_head(self):
state2 = model2.heads["custom_head"].state_dict()
for ((k1, v1), (k2, v2)) in zip(state1.items(), state2.items()):
self.assertTrue(torch.equal(v1, v2))

def test_builtin_head_as_custom(self):
    """A built-in head class registered as a custom head works end to end.

    Registers ``ClassificationHead`` under a custom head type, adds it via
    ``add_custom_head``, and checks the forward output shape and that the new
    head became the active head.
    """
    checkpoint = "bert-base-uncased"
    cfg = AutoConfig.from_pretrained(checkpoint, custom_heads={"tag": CustomHead})
    model = AutoModelWithHeads.from_pretrained(checkpoint, config=cfg)
    model.eval()
    input_ids = ids_tensor((1, 128), 1000)

    model.register_custom_head("classification", ClassificationHead)
    # Same call shape as for a genuine custom head — only head_type differs.
    model.add_custom_head(
        "classification", "custom_head", num_labels=3, layers=2, activation_function="tanh"
    )
    output = model(input_ids)

    # num_labels=3 -> logits of shape (batch, 3); the added head is auto-activated.
    self.assertEqual((1, 3), output[0].shape)
    self.assertEqual("custom_head", model.active_head)

0 comments on commit 1e540d1

Please sign in to comment.