Language modeling flex heads (#210)
* Init flex LM heads.

* Finished flex LM head implementations.
Added tests for all possible head conversions.

* Invertible adapters in flex LM heads.

* Fix output_embedding method implementation for XModelWithHeads

* hacked fix for GPT-2 pad_token_id problem
calpt committed Aug 24, 2021
1 parent b64e69d commit 84289df
Showing 13 changed files with 787 additions and 74 deletions.
154 changes: 128 additions & 26 deletions src/transformers/adapters/head_utils.py
@@ -6,6 +6,8 @@
logger = logging.getLogger(__name__)


# The "layers" attributes in the configs below map from static head module names to flex head module names.
# In this context, "None" refers to a flex-head layer without weights (e.g. dropout, activation).
STATIC_TO_FLEX_HEAD_MAP = {
# BERT
"BertForSequenceClassification": {
@@ -15,7 +17,7 @@
"activation_function": None,
"use_pooler": True,
},
"layers": ["classifier"],
"layers": [None, "classifier"],
},
"BertForMultipleChoice": {
"config": {
@@ -24,23 +26,53 @@
"activation_function": None,
"use_pooler": True,
},
"layers": ["classifier"],
"layers": [None, "classifier"],
},
"BertForTokenClassification": {
"config": {
"head_type": "tagging",
"layers": 1,
"activation_function": None,
},
"layers": ["classifier"],
"layers": [None, "classifier"],
},
"BertForQuestionAnswering": {
"config": {
"head_type": "question_answering",
"layers": 1,
"activation_function": None,
},
"layers": ["qa_outputs"],
"layers": [None, "qa_outputs"],
},
"BertForMaskedLM": {
"config": {
"head_type": "masked_lm",
"layers": 2,
"activation_function": "gelu_orig",
"layer_norm": True,
"bias": True,
},
"layers": [
"cls.predictions.transform.dense",
None,
"cls.predictions.transform.LayerNorm",
"cls.predictions.decoder",
],
},
"BertLMHeadModel": {
"config": {
"head_type": "causal_lm",
"layers": 2,
"activation_function": "gelu_orig",
"layer_norm": True,
"bias": True,
},
"layers": [
"cls.predictions.transform.dense",
None,
"cls.predictions.transform.LayerNorm",
"cls.predictions.decoder",
],
},
# RoBERTa
"RobertaForSequenceClassification": {
@@ -50,7 +82,7 @@
"activation_function": "tanh",
"use_pooler": False,
},
"layers": ["classifier.dense", "classifier.out_proj"],
"layers": [None, "classifier.dense", None, None, "classifier.out_proj"],
},
"RobertaForMultipleChoice": {
"config": {
@@ -59,23 +91,43 @@
"activation_function": None,
"use_pooler": True,
},
"layers": ["classifier"],
"layers": [None, "classifier"],
},
"RobertaForTokenClassification": {
"config": {
"head_type": "tagging",
"layers": 1,
"activation_function": None,
},
"layers": ["classifier"],
"layers": [None, "classifier"],
},
"RobertaForQuestionAnswering": {
"config": {
"head_type": "question_answering",
"layers": 1,
"activation_function": None,
},
"layers": ["qa_outputs"],
"layers": [None, "qa_outputs"],
},
"RobertaForMaskedLM": {
"config": {
"head_type": "masked_lm",
"layers": 2,
"activation_function": "gelu_orig",
"layer_norm": True,
"bias": True,
},
"layers": ["lm_head.dense", None, "lm_head.layer_norm", "lm_head.decoder"],
},
"RobertaForCausalLM": {
"config": {
"head_type": "causal_lm",
"layers": 2,
"activation_function": "gelu_orig",
"layer_norm": True,
"bias": True,
},
"layers": ["lm_head.dense", None, "lm_head.layer_norm", "lm_head.decoder"],
},
# XLM-RoBERTa
"XLMRobertaForSequenceClassification": {
@@ -85,7 +137,7 @@
"activation_function": "tanh",
"use_pooler": False,
},
"layers": ["classifier.dense", "classifier.out_proj"],
"layers": [None, "classifier.dense", None, None, "classifier.out_proj"],
},
"XLMRobertaForMultipleChoice": {
"config": {
@@ -94,23 +146,43 @@
"activation_function": None,
"use_pooler": True,
},
"layers": ["classifier"],
"layers": [None, "classifier"],
},
"XLMRobertaForTokenClassification": {
"config": {
"head_type": "tagging",
"layers": 1,
"activation_function": None,
},
"layers": ["classifier"],
"layers": [None, "classifier"],
},
"XLMRobertaForQuestionAnswering": {
"config": {
"head_type": "question_answering",
"layers": 1,
"activation_function": None,
},
"layers": ["qa_outputs"],
"layers": [None, "qa_outputs"],
},
"XLMRobertaForMaskedLM": {
"config": {
"head_type": "masked_lm",
"layers": 2,
"activation_function": "gelu_orig",
"layer_norm": True,
"bias": True,
},
"layers": ["lm_head.dense", "lm_head.layer_norm", "lm_head.decoder"],
},
"XLMRobertaForCausalLM": {
"config": {
"head_type": "causal_lm",
"layers": 2,
"activation_function": "gelu_orig",
"layer_norm": True,
"bias": True,
},
"layers": ["lm_head.dense", None, "lm_head.layer_norm", "lm_head.decoder"],
},
# BART
"BartForSequenceClassification": {
@@ -119,15 +191,21 @@
"layers": 2,
"activation_function": "tanh",
},
"layers": ["classification_head.dense", "classification_head.out_proj"],
"layers": [None, "classification_head.dense", None, None, "classification_head.out_proj"],
},
"BartForQuestionAnswering": {
"config": {
"head_type": "question_answering",
"layers": 1,
"activation_function": None,
},
"layers": ["qa_outputs"],
"layers": [None, "qa_outputs"],
},
"BartForConditionalGeneration": {
"config": {
"head_type": "seq2seq_lm",
},
"layers": ["lm_head"],
},
# MBART
"MBartForSequenceClassification": {
@@ -136,15 +214,21 @@
"layers": 2,
"activation_function": "tanh",
},
"layers": ["classification_head.dense", "classification_head.out_proj"],
"layers": [None, "classification_head.dense", None, None, "classification_head.out_proj"],
},
"MBartForQuestionAnswering": {
"config": {
"head_type": "question_answering",
"layers": 1,
"activation_function": None,
},
"layers": ["qa_outputs"],
"layers": [None, "qa_outputs"],
},
"MBartForConditionalGeneration": {
"config": {
"head_type": "seq2seq_lm",
},
"layers": ["lm_head"],
},
# DistilBERT
"DistilBertForSequenceClassification": {
@@ -153,31 +237,41 @@
"layers": 2,
"activation_function": "relu",
},
"layers": ["pre_classifier", "classifier"],
"layers": [None, "pre_classifier", None, None, "classifier"],
},
"DistilBertForMultipleChoice": {
"config": {
"head_type": "multiple_choice",
"layers": 2,
"activation_function": "relu",
},
"layers": ["pre_classifier", "classifier"],
"layers": [None, "pre_classifier", None, None, "classifier"],
},
"DistilBertForTokenClassification": {
"config": {
"head_type": "tagging",
"layers": 1,
"activation_function": None,
},
"layers": ["classifier"],
"layers": [None, "classifier"],
},
"DistilBertForQuestionAnswering": {
"config": {
"head_type": "question_answering",
"layers": 1,
"activation_function": None,
},
"layers": ["qa_outputs"],
"layers": [None, "qa_outputs"],
},
"DistilBertForMaskedLM": {
"config": {
"head_type": "masked_lm",
"layers": 2,
"activation_function": "gelu_orig",
"layer_norm": True,
"bias": True,
},
"layers": ["vocab_transform", None, "vocab_layer_norm", "vocab_projector"],
},
# GPT-2
"GPT2ForSequenceClassification": {
@@ -187,7 +281,13 @@
"activation_function": None,
"bias": False,
},
"layers": ["score"],
"layers": [None, "score"],
},
"GPT2LMHeadModel": {
"config": {
"head_type": "causal_lm",
},
"layers": ["lm_head"],
},
}
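
As an illustration of how the map above is read (this sketch is not part of the diff; the head name "mlm" is a placeholder): the position of each entry in a "layers" list is the index of the corresponding sub-module in the flex head, and None marks flex-head layers such as dropout or activation that carry no weights and therefore have no static counterpart.

# Standalone sketch of the assumed "layers" semantics, using the BertForMaskedLM entry.
entry = {
    "config": {"head_type": "masked_lm", "layers": 2, "activation_function": "gelu_orig"},
    "layers": [
        "cls.predictions.transform.dense",      # flex sub-module 0
        None,                                   # activation, no weights to copy
        "cls.predictions.transform.LayerNorm",  # flex sub-module 2
        "cls.predictions.decoder",              # flex sub-module 3
    ],
}

for flex_idx, static_name in enumerate(entry["layers"]):
    if static_name is not None:
        print(f"{static_name}.*  ->  heads.mlm.{flex_idx}.*")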

@@ -213,16 +313,18 @@ def get_head_config_and_rename_list(model_class_name, head_name, label2id, num_l
config = copy.deepcopy(data["config"])
if config["head_type"] == "multiple_choice":
config["num_choices"] = num_labels
else:
config["label2id"] = label2id
elif config["head_type"] not in ["causal_lm", "masked_lm", "seq2seq_lm"]:
config["num_labels"] = num_labels
config["label2id"] = label2id
config["label2id"] = label2id
# rename
rename_list = []
i = 0
for name in data["layers"]:
escaped_name = re.escape(name)
rename_list.append((rf"{escaped_name}\.(\S+)", f"heads.{head_name}.{i+1}.{{0}}"))
i += 3 if config["activation_function"] else 2 # there's always a dropout layer in between
if name is not None:
escaped_name = re.escape(name)
rename_list.append((rf"{escaped_name}\.(\S+)", f"heads.{head_name}.{i}.{{0}}"))
i += 1
rename_func = lambda k, rename_list=rename_list: _regex_list_rename_func(k, rename_list)

return config, rename_func
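
The rename list built in the loop above pairs a regex over static parameter names with a target template for the flex head. A self-contained approximation of how such a list is applied to checkpoint keys is shown below; _regex_list_rename_func itself is not part of this diff, so this helper is an assumed, simplified stand-in, and the head name "mlm" is a placeholder.

import re

# Assumed stand-in for _regex_list_rename_func: rewrite the first matching key,
# leave unrelated keys untouched.
rename_list = [
    (r"cls\.predictions\.transform\.dense\.(\S+)", "heads.mlm.0.{0}"),
    (r"cls\.predictions\.transform\.LayerNorm\.(\S+)", "heads.mlm.2.{0}"),
    (r"cls\.predictions\.decoder\.(\S+)", "heads.mlm.3.{0}"),
]

def rename_key(key, rename_list=rename_list):
    for pattern, template in rename_list:
        match = re.match(pattern, key)
        if match:
            return template.format(match.group(1))
    return key

print(rename_key("cls.predictions.transform.dense.weight"))  # heads.mlm.0.weight
print(rename_key("cls.predictions.decoder.bias"))            # heads.mlm.3.bias
print(rename_key("bert.embeddings.word_embeddings.weight"))  # unchanged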
1 change: 1 addition & 0 deletions src/transformers/adapters/heads/__init__.py
@@ -1,3 +1,4 @@
# flake8: noqa
from .base import *
from .dependency_parsing import *
from .language_modeling import BertStyleMaskedLMHead, CausalLMHead, Seq2SeqLMHead
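
From the user side, these classes back the LM variants of the flex-head API. A hypothetical usage sketch follows; the class and method names are assumptions based on the existing flex-head pattern (e.g. add_classification_head), not taken from this diff.

from transformers import BertModelWithHeads  # adapter-transformers fork

# Attach a flexible masked-LM head; when converting a static BertForMaskedLM
# checkpoint instead, the STATIC_TO_FLEX_HEAD_MAP entry above supplies the
# weight renaming. The head name "mlm" is arbitrary.
model = BertModelWithHeads.from_pretrained("bert-base-uncased")
model.add_masked_lm_head("mlm")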
