Add Adapters to DistilBERT (#67)
* Adds adapter support to DistilBERT models (via mixins in adapter_distilbert.py)
* Adds a flex-head model for DistilBERT (DistilBertModelWithHeads)
* Moves invertible adapters to a separate InvertibleAdaptersMixin to improve modularity
* Adjusts the BERT adapters implementation to allow partial reuse for DistilBERT
calpt authored Oct 27, 2020
1 parent 0b5a2ba commit 6588b3f
Showing 16 changed files with 439 additions and 98 deletions.
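
In practice, the new flex-head class is used like the existing BertModelWithHeads. Below is a minimal usage sketch (not part of the commit). It assumes that AdapterType is exported from transformers and that the add_adapter, add_classification_head and train_adapter methods of the library's adapter and heads mixins keep their usual signatures; the adapter name and checkpoint are placeholders.

# Illustrative sketch only; method names are assumed from the library's adapter/heads mixins.
from transformers import AdapterType, DistilBertModelWithHeads, DistilBertTokenizer

tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
model = DistilBertModelWithHeads.from_pretrained("distilbert-base-uncased")

# Add a task adapter and a matching classification head, then train only the adapter:
# train_adapter() freezes the rest of the model and activates the named adapters.
model.add_adapter("sst-2", AdapterType.text_task)
model.add_classification_head("sst-2", num_labels=2)
model.train_adapter(["sst-2"])

inputs = tokenizer("This movie was great!", return_tensors="pt")
outputs = model(**inputs)  # predictions come from the active "sst-2" head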
7 changes: 7 additions & 0 deletions adapter_docs/classes/bert.rst
@@ -31,6 +31,13 @@ BertModel
:members:


BertModelWithHeads
~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.BertModelWithHeads
:members:


BertForPreTraining
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

69 changes: 69 additions & 0 deletions adapter_docs/classes/distilbert.rst
@@ -0,0 +1,69 @@
DistilBERT
===========

The DistilBERT model was proposed in the blog post
`Smaller, faster, cheaper, lighter: Introducing DistilBERT, a distilled version of BERT <https://medium.com/huggingface/distilbert-8cf3380435b5>`__,
and the paper `DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter <https://arxiv.org/abs/1910.01108>`__.
DistilBERT is a small, fast, cheap and light Transformer model trained by distilling BERT base. It has 40% fewer
parameters than `bert-base-uncased` and runs 60% faster, while preserving over 95% of BERT's performance as measured on
the GLUE language understanding benchmark.

.. note::
This class is nearly identical to the PyTorch implementation of DistilBERT in Hugging Face Transformers.
For more information, visit `the corresponding section in their documentation <https://huggingface.co/transformers/model_doc/distilbert.html>`_.


DistilBertConfig
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.DistilBertConfig
:members:


DistilBertTokenizer
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.DistilBertTokenizer
:members:


DistilBertTokenizerFast
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.DistilBertTokenizerFast
:members:


DistilBertModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.DistilBertModel
:members:


DistilBertModelWithHeads
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.DistilBertModelWithHeads
:members:


DistilBertForMaskedLM
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.DistilBertForMaskedLM
:members:


DistilBertForSequenceClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.DistilBertForSequenceClassification
:members:


DistilBertForQuestionAnswering
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.DistilBertForQuestionAnswering
:members:
7 changes: 7 additions & 0 deletions adapter_docs/classes/roberta.rst
@@ -31,6 +31,13 @@ RobertaModel
:members:


RobertaModelWithHeads
~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.RobertaModelWithHeads
:members:


RobertaForMaskedLM
~~~~~~~~~~~~~~~~~~~~~~~~~~

7 changes: 7 additions & 0 deletions adapter_docs/classes/xlmroberta.rst
@@ -32,6 +32,13 @@ XLMRobertaModel
:members:


XLMRobertaModelWithHeads
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.XLMRobertaModelWithHeads
:members:


XLMRobertaForMaskedLM
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

1 change: 1 addition & 0 deletions adapter_docs/index.rst
@@ -58,6 +58,7 @@ Currently, we support the PyTorch versions of all models listed in the *Supporte
classes/bert
classes/roberta
classes/xlmroberta
classes/distilbert


Citation
1 change: 1 addition & 0 deletions src/transformers/__init__.py
@@ -337,6 +337,7 @@
DistilBertPreTrainedModel,
DistilBertForMaskedLM,
DistilBertModel,
DistilBertModelWithHeads,
DistilBertForMultipleChoice,
DistilBertForSequenceClassification,
DistilBertForQuestionAnswering,
70 changes: 23 additions & 47 deletions src/transformers/adapter_bert.py
@@ -5,8 +5,8 @@
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from .adapter_config import DEFAULT_ADAPTER_CONFIG, AdapterType
from .adapter_model_mixin import ModelAdaptersMixin, ModelWithHeadsAdaptersMixin
from .adapter_modeling import Activation_Function_Class, Adapter, BertFusion, GLOWCouplingBlock, NICECouplingBlock
from .adapter_model_mixin import InvertibleAdaptersMixin, ModelAdaptersMixin, ModelWithHeadsAdaptersMixin
from .adapter_modeling import Activation_Function_Class, Adapter, BertFusion
from .adapter_utils import flatten_adapter_names, parse_adapter_names


@@ -40,6 +40,11 @@ class BertSelfOutputAdaptersMixin:
"""Adds adapters to the BertSelfOutput module.
"""

# override this property if layer norm has a different name
@property
def layer_norm(self):
return self.LayerNorm

def _init_adapter_modules(self):
self.attention_text_task_adapters = nn.ModuleDict(dict())
self.adapter_fusion_layer = nn.ModuleDict(dict())
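
The layer_norm property added above is the hook that makes these mixins reusable for DistilBERT, whose TransformerBlock names its layer norms sa_layer_norm and output_layer_norm instead of LayerNorm. A hypothetical override, sketched here for illustration (the actual adapter_distilbert.py is among the files not loaded in this view, and its class names may differ):

# Hypothetical sketch; the class name and wrapper structure are illustrative, not taken from the commit.
from torch import nn

class DistilBertSelfAttentionAdaptersModule(BertSelfOutputAdaptersMixin, nn.Module):
    """Wraps the attention output of a DistilBERT TransformerBlock so the BERT mixin can be reused."""

    def __init__(self, parent):
        super().__init__()
        self.parent = parent  # the wrapped TransformerBlock

    @property
    def layer_norm(self):
        # DistilBERT's name for the post-attention LayerNorm
        return self.parent.sa_layer_norm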
@@ -115,7 +120,7 @@ def get_adapter_preparams(
query = hidden_states

if adapter_config["original_ln_before"]:
hidden_states = self.LayerNorm(hidden_states + input_tensor)
hidden_states = self.layer_norm(hidden_states + input_tensor)

if not adapter_config["residual_before_ln"]:
residual = hidden_states
@@ -227,10 +232,10 @@ def adapters_forward(self, hidden_states, input_tensor, adapter_names=None):

last_config = self.config.adapters.get(adapter_names[-1][-1])
if last_config["original_ln_after"]:
hidden_states = self.LayerNorm(hidden_states + input_tensor)
hidden_states = self.layer_norm(hidden_states + input_tensor)

else:
hidden_states = self.LayerNorm(hidden_states + input_tensor)
hidden_states = self.layer_norm(hidden_states + input_tensor)

return hidden_states

@@ -239,9 +244,12 @@ class BertOutputAdaptersMixin:
"""Adds adapters to the BertOutput module.
"""

# override this property if layer norm has a different name
@property
def layer_norm(self):
return self.LayerNorm

def _init_adapter_modules(self):
# self.bert_adapter_att = BertAdapterAttention(config)
# self.bert_adapter_att = SimpleAdapterWeightingSentLvl(config)
self.adapter_fusion_layer = nn.ModuleDict(dict())
self.layer_text_task_adapters = nn.ModuleDict(dict())
self.layer_text_lang_adapters = nn.ModuleDict(dict())
@@ -311,7 +319,7 @@ def get_adapter_preparams(
query = hidden_states

if adapter_config["original_ln_before"]:
hidden_states = self.LayerNorm(hidden_states + input_tensor)
hidden_states = self.layer_norm(hidden_states + input_tensor)

if not adapter_config["residual_before_ln"]:
residual = hidden_states
@@ -424,10 +432,10 @@ def adapters_forward(self, hidden_states, input_tensor, adapter_names=None):

last_config = self.config.adapters.get(adapter_names[-1][-1])
if last_config["original_ln_after"]:
hidden_states = self.LayerNorm(hidden_states + input_tensor)
hidden_states = self.layer_norm(hidden_states + input_tensor)

else:
hidden_states = self.LayerNorm(hidden_states + input_tensor)
hidden_states = self.layer_norm(hidden_states + input_tensor)

return hidden_states
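
For orientation, the adapter_config dictionaries queried in get_adapter_preparams and adapters_forward control where the original LayerNorm and residual connection are applied around the adapter modules. A hypothetical config restricted to the keys that appear in this file (the values are placeholders, not taken from the commit):

# Illustrative only; keys follow the code above, values are placeholders.
adapter_config = {
    "original_ln_before": True,   # apply the block's LayerNorm (+ residual) before the adapter
    "original_ln_after": True,    # apply it again after the adapter output
    "residual_before_ln": True,   # take the residual from the pre-LayerNorm hidden states
    "invertible_adapter": {       # consumed by add_invertible_lang_adapter for language adapters
        "block_type": "nice",     # "nice" or "glow" coupling block
        "non_linearity": "relu",
        "reduction_factor": 2,
    },
}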

@@ -469,15 +477,15 @@ def enable_adapters(self, adapter_names: list, unfreeze_adapters: bool, unfreeze
layer.enable_adapters(adapter_names, unfreeze_adapters, unfreeze_attention)


class BertModelAdaptersMixin(ModelAdaptersMixin):
class BertModelAdaptersMixin(InvertibleAdaptersMixin, ModelAdaptersMixin):
"""Adds adapters to the BertModel module.
"""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

def _init_adapter_modules(self):
self.invertible_lang_adapters = nn.ModuleDict(dict())
super()._init_adapter_modules()

# language adapters
for language in self.config.adapters.adapter_list(AdapterType.text_lang):
@@ -496,12 +504,8 @@ def train_adapter(self, adapter_names: list):
self.train()
self.freeze_model(True)
adapter_names_flat = flatten_adapter_names(adapter_names)
self.encoder.enable_adapters(adapter_names_flat, True, False)
# unfreeze invertible adapters for invertible adapters
for adapter_name in adapter_names_flat:
if adapter_name in self.invertible_lang_adapters:
for param in self.invertible_lang_adapters[adapter_name].parameters():
param.requires_grad = True
self.encoder.enable_adapters(adapter_names, True, False)
self.enable_invertible_adapters(adapter_names_flat)
# use the adapters to be trained by default in every forward pass
self.set_active_adapters(adapter_names)

@@ -535,35 +539,7 @@ def add_adapter(self, adapter_name: str, adapter_type: AdapterType, config=None)
if adapter_type == AdapterType.text_lang:
self.add_invertible_lang_adapter(adapter_name)

def add_invertible_lang_adapter(self, language):
if language in self.invertible_lang_adapters:
raise ValueError(f"Model already contains an adapter module for '{language}'.")
inv_adap_config = self.config.adapters.get(language)["invertible_adapter"]
if inv_adap_config["block_type"] == "nice":
inv_adap = NICECouplingBlock(
[[self.config.hidden_size]],
non_linearity=inv_adap_config["non_linearity"],
reduction_factor=inv_adap_config["reduction_factor"],
)
elif inv_adap_config["block_type"] == "glow":
inv_adap = GLOWCouplingBlock(
[[self.config.hidden_size]],
non_linearity=inv_adap_config["non_linearity"],
reduction_factor=inv_adap_config["reduction_factor"],
)
else:
raise ValueError(f"Invalid invertible adapter type '{inv_adap_config['block_type']}'.")
self.invertible_lang_adapters[language] = inv_adap
self.invertible_lang_adapters[language].apply(Adapter.init_bert_weights)

def get_invertible_lang_adapter(self, language):
if language in self.invertible_lang_adapters:
return self.invertible_lang_adapters[language]
else:
return None

def add_fusion_layer(self, adapter_names):
"""See BertModel.add_attention_layer"""
def _add_fusion_layer(self, adapter_names):
self.encoder.add_fusion_layer(adapter_names)
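
The invertible-adapter methods removed above are not dropped: per the commit message they move into the new InvertibleAdaptersMixin, which BertModelAdaptersMixin now inherits and which lives in adapter_model_mixin.py (one of the files not loaded in this view). Reconstructed from the removed lines, the mixin looks roughly as follows; the real implementation may differ in detail.

# Rough reconstruction from the removed lines above; not a verbatim copy of adapter_model_mixin.py.
from torch import nn

from .adapter_modeling import Adapter, GLOWCouplingBlock, NICECouplingBlock


class InvertibleAdaptersMixin:
    """Collects the invertible (language) adapter logic formerly in BertModelAdaptersMixin,
    so that DistilBERT models can reuse it. Expects to be mixed into a model class that
    provides self.config."""

    def _init_adapter_modules(self):
        self.invertible_lang_adapters = nn.ModuleDict(dict())

    def add_invertible_lang_adapter(self, language):
        if language in self.invertible_lang_adapters:
            raise ValueError(f"Model already contains an adapter module for '{language}'.")
        inv_adap_config = self.config.adapters.get(language)["invertible_adapter"]
        if inv_adap_config["block_type"] == "nice":
            inv_adap = NICECouplingBlock(
                [[self.config.hidden_size]],
                non_linearity=inv_adap_config["non_linearity"],
                reduction_factor=inv_adap_config["reduction_factor"],
            )
        elif inv_adap_config["block_type"] == "glow":
            inv_adap = GLOWCouplingBlock(
                [[self.config.hidden_size]],
                non_linearity=inv_adap_config["non_linearity"],
                reduction_factor=inv_adap_config["reduction_factor"],
            )
        else:
            raise ValueError(f"Invalid invertible adapter type '{inv_adap_config['block_type']}'.")
        self.invertible_lang_adapters[language] = inv_adap
        self.invertible_lang_adapters[language].apply(Adapter.init_bert_weights)

    def get_invertible_lang_adapter(self, language):
        if language in self.invertible_lang_adapters:
            return self.invertible_lang_adapters[language]
        return None

    def enable_invertible_adapters(self, adapter_names):
        # unfreeze the invertible adapters that belong to the given (language) adapters
        for adapter_name in adapter_names:
            if adapter_name in self.invertible_lang_adapters:
                for param in self.invertible_lang_adapters[adapter_name].parameters():
                    param.requires_grad = True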

