Language modeling flex heads #210

Merged · 14 commits · Aug 24, 2021
154 changes: 128 additions & 26 deletions src/transformers/adapters/head_utils.py
@@ -6,6 +6,8 @@
logger = logging.getLogger(__name__)


# The "layers" attributes in the configs below map from static head module names to flex head module names.
# In this context, "None" refers to a flex-head layer without weights (e.g. dropout, acts).
STATIC_TO_FLEX_HEAD_MAP = {
# BERT
"BertForSequenceClassification": {
@@ -15,7 +17,7 @@
"activation_function": None,
"use_pooler": True,
},
"layers": ["classifier"],
"layers": [None, "classifier"],
},
"BertForMultipleChoice": {
"config": {
@@ -24,23 +26,53 @@
"activation_function": None,
"use_pooler": True,
},
"layers": ["classifier"],
"layers": [None, "classifier"],
},
"BertForTokenClassification": {
"config": {
"head_type": "tagging",
"layers": 1,
"activation_function": None,
},
"layers": ["classifier"],
"layers": [None, "classifier"],
},
"BertForQuestionAnswering": {
"config": {
"head_type": "question_answering",
"layers": 1,
"activation_function": None,
},
"layers": ["qa_outputs"],
"layers": [None, "qa_outputs"],
},
"BertForMaskedLM": {
"config": {
"head_type": "masked_lm",
"layers": 2,
"activation_function": "gelu_orig",
"layer_norm": True,
"bias": True,
},
"layers": [
"cls.predictions.transform.dense",
None,
"cls.predictions.transform.LayerNorm",
"cls.predictions.decoder",
],
},
"BertLMHeadModel": {
"config": {
"head_type": "causal_lm",
"layers": 2,
"activation_function": "gelu_orig",
"layer_norm": True,
"bias": True,
},
"layers": [
"cls.predictions.transform.dense",
None,
"cls.predictions.transform.LayerNorm",
"cls.predictions.decoder",
],
},
# RoBERTa
"RobertaForSequenceClassification": {
@@ -50,7 +82,7 @@
"activation_function": "tanh",
"use_pooler": False,
},
"layers": ["classifier.dense", "classifier.out_proj"],
"layers": [None, "classifier.dense", None, None, "classifier.out_proj"],
},
"RobertaForMultipleChoice": {
"config": {
@@ -59,23 +91,43 @@
"activation_function": None,
"use_pooler": True,
},
"layers": ["classifier"],
"layers": [None, "classifier"],
},
"RobertaForTokenClassification": {
"config": {
"head_type": "tagging",
"layers": 1,
"activation_function": None,
},
"layers": ["classifier"],
"layers": [None, "classifier"],
},
"RobertaForQuestionAnswering": {
"config": {
"head_type": "question_answering",
"layers": 1,
"activation_function": None,
},
"layers": ["qa_outputs"],
"layers": [None, "qa_outputs"],
},
"RobertaForMaskedLM": {
"config": {
"head_type": "masked_lm",
"layers": 2,
"activation_function": "gelu_orig",
"layer_norm": True,
"bias": True,
},
"layers": ["lm_head.dense", None, "lm_head.layer_norm", "lm_head.decoder"],
},
"RobertaForCausalLM": {
"config": {
"head_type": "causal_lm",
"layers": 2,
"activation_function": "gelu_orig",
"layer_norm": True,
"bias": True,
},
"layers": ["lm_head.dense", None, "lm_head.layer_norm", "lm_head.decoder"],
},
# XLM-RoBERTa
"XLMRobertaForSequenceClassification": {
@@ -85,7 +137,7 @@
"activation_function": "tanh",
"use_pooler": False,
},
"layers": ["classifier.dense", "classifier.out_proj"],
"layers": [None, "classifier.dense", None, None, "classifier.out_proj"],
},
"XLMRobertaForMultipleChoice": {
"config": {
@@ -94,23 +146,43 @@
"activation_function": None,
"use_pooler": True,
},
"layers": ["classifier"],
"layers": [None, "classifier"],
},
"XLMRobertaForTokenClassification": {
"config": {
"head_type": "tagging",
"layers": 1,
"activation_function": None,
},
"layers": ["classifier"],
"layers": [None, "classifier"],
},
"XLMRobertaForQuestionAnswering": {
"config": {
"head_type": "question_answering",
"layers": 1,
"activation_function": None,
},
"layers": ["qa_outputs"],
"layers": [None, "qa_outputs"],
},
"XLMRobertaForMaskedLM": {
"config": {
"head_type": "masked_lm",
"layers": 2,
"activation_function": "gelu_orig",
"layer_norm": True,
"bias": True,
},
"layers": ["lm_head.dense", "lm_head.layer_norm", "lm_head.decoder"],
},
"XLMRobertaForCausalLM": {
"config": {
"head_type": "causal_lm",
"layers": 2,
"activation_function": "gelu_orig",
"layer_norm": True,
"bias": True,
},
"layers": ["lm_head.dense", None, "lm_head.layer_norm", "lm_head.decoder"],
},
# BART
"BartForSequenceClassification": {
@@ -119,15 +191,21 @@
"layers": 2,
"activation_function": "tanh",
},
"layers": ["classification_head.dense", "classification_head.out_proj"],
"layers": [None, "classification_head.dense", None, None, "classification_head.out_proj"],
},
"BartForQuestionAnswering": {
"config": {
"head_type": "question_answering",
"layers": 1,
"activation_function": None,
},
"layers": ["qa_outputs"],
"layers": [None, "qa_outputs"],
},
"BartForConditionalGeneration": {
"config": {
"head_type": "seq2seq_lm",
},
"layers": ["lm_head"],
},
# MBART
"MBartForSequenceClassification": {
@@ -136,15 +214,21 @@
"layers": 2,
"activation_function": "tanh",
},
"layers": ["classification_head.dense", "classification_head.out_proj"],
"layers": [None, "classification_head.dense", None, None, "classification_head.out_proj"],
},
"MBartForQuestionAnswering": {
"config": {
"head_type": "question_answering",
"layers": 1,
"activation_function": None,
},
"layers": ["qa_outputs"],
"layers": [None, "qa_outputs"],
},
"MBartForConditionalGeneration": {
"config": {
"head_type": "seq2seq_lm",
},
"layers": ["lm_head"],
},
# DistilBERT
"DistilBertForSequenceClassification": {
@@ -153,31 +237,41 @@
"layers": 2,
"activation_function": "relu",
},
"layers": ["pre_classifier", "classifier"],
"layers": [None, "pre_classifier", None, None, "classifier"],
},
"DistilBertForMultipleChoice": {
"config": {
"head_type": "multiple_choice",
"layers": 2,
"activation_function": "relu",
},
"layers": ["pre_classifier", "classifier"],
"layers": [None, "pre_classifier", None, None, "classifier"],
},
"DistilBertForTokenClassification": {
"config": {
"head_type": "tagging",
"layers": 1,
"activation_function": None,
},
"layers": ["classifier"],
"layers": [None, "classifier"],
},
"DistilBertForQuestionAnswering": {
"config": {
"head_type": "question_answering",
"layers": 1,
"activation_function": None,
},
"layers": ["qa_outputs"],
"layers": [None, "qa_outputs"],
},
"DistilBertForMaskedLM": {
"config": {
"head_type": "masked_lm",
"layers": 2,
"activation_function": "gelu_orig",
"layer_norm": True,
"bias": True,
},
"layers": ["vocab_transform", None, "vocab_layer_norm", "vocab_projector"],
},
# GPT-2
"GPT2ForSequenceClassification": {
@@ -187,7 +281,13 @@
"activation_function": None,
"bias": False,
},
"layers": ["score"],
"layers": [None, "score"],
},
"GPT2LMHeadModel": {
"config": {
"head_type": "causal_lm",
},
"layers": ["lm_head"],
},
}
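For orientation, the None entries in the mapping above stand for flex-head modules that carry no weights, so they occupy an index in the flex head but no static weights are renamed onto them. A minimal sketch of the resulting indexing for the "RobertaForMaskedLM" entry (the real flex head lives in adapters/heads/language_modeling.py and is not shown in this diff; the module types and sizes below are assumptions):

```python
import torch.nn as nn

# Illustrative stand-in for a flex masked-LM head's module stack, mirroring the
# "RobertaForMaskedLM" entry above. Hidden/vocab sizes are made up for the example.
hidden_size, vocab_size = 768, 50265
flex_mlm_head = nn.Sequential(
    nn.Linear(hidden_size, hidden_size),  # index 0 <- "lm_head.dense"
    nn.GELU(),                            # index 1 <- None (activation, no weights)
    nn.LayerNorm(hidden_size),            # index 2 <- "lm_head.layer_norm"
    nn.Linear(hidden_size, vocab_size),   # index 3 <- "lm_head.decoder"
)
```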

@@ -213,16 +313,18 @@ def get_head_config_and_rename_list(model_class_name, head_name, label2id, num_l
config = copy.deepcopy(data["config"])
if config["head_type"] == "multiple_choice":
config["num_choices"] = num_labels
else:
config["label2id"] = label2id
elif config["head_type"] not in ["causal_lm", "masked_lm", "seq2seq_lm"]:
config["num_labels"] = num_labels
config["label2id"] = label2id
config["label2id"] = label2id
# rename
rename_list = []
i = 0
for name in data["layers"]:
escaped_name = re.escape(name)
rename_list.append((rf"{escaped_name}\.(\S+)", f"heads.{head_name}.{i+1}.{{0}}"))
i += 3 if config["activation_function"] else 2 # there's always a dropout layer in between
if name is not None:
escaped_name = re.escape(name)
rename_list.append((rf"{escaped_name}\.(\S+)", f"heads.{head_name}.{i}.{{0}}"))
i += 1
rename_func = lambda k, rename_list=rename_list: _regex_list_rename_func(k, rename_list)

return config, rename_func
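To make the renaming concrete, here is a self-contained sketch (not part of the diff) of the logic above applied to one entry: each non-None layer name becomes a regex rule that rewrites the static head's weight keys to the positional index of the corresponding flex-head module, while None entries only advance the index. The head name and example state-dict keys are assumptions.

```python
import re

# Mirrors the "RobertaForMaskedLM" entry from STATIC_TO_FLEX_HEAD_MAP above.
layers = ["lm_head.dense", None, "lm_head.layer_norm", "lm_head.decoder"]
head_name = "mlm"  # assumed flex head name

rename_list = []
for i, name in enumerate(layers):
    if name is not None:  # None entries (dropout/activation) have no weights to rename
        escaped_name = re.escape(name)
        rename_list.append((rf"{escaped_name}\.(\S+)", f"heads.{head_name}.{i}.{{0}}"))

# Example static-head keys (assumed) rewritten to their flex-head positions.
for old_key in ["lm_head.dense.weight", "lm_head.layer_norm.bias", "lm_head.decoder.weight"]:
    for pattern, template in rename_list:
        match = re.match(pattern, old_key)
        if match:
            print(old_key, "->", template.format(match.group(1)))
            break
# lm_head.dense.weight    -> heads.mlm.0.weight
# lm_head.layer_norm.bias -> heads.mlm.2.bias
# lm_head.decoder.weight  -> heads.mlm.3.weight
```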
1 change: 1 addition & 0 deletions src/transformers/adapters/heads/__init__.py
@@ -1,3 +1,4 @@
# flake8: noqa
from .base import *
from .dependency_parsing import *
from .language_modeling import BertStyleMaskedLMHead, CausalLMHead, Seq2SeqLMHead
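Finally, a hedged usage sketch of what the newly exported heads enable; the model class and helper method names below are assumptions inferred from the head types added here, not taken from the diff:

```python
# Assumed API: a *ModelWithHeads class exposing add_*_head helpers backed by the
# classes exported above (BertStyleMaskedLMHead / CausalLMHead / Seq2SeqLMHead).
from transformers import BertModelWithHeads  # class name assumed for this library version

model = BertModelWithHeads.from_pretrained("bert-base-uncased")
model.add_masked_lm_head("mlm")  # method name assumed; presumably backed by BertStyleMaskedLMHead
model.add_causal_lm_head("clm")  # method name assumed; presumably backed by CausalLMHead
```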