Merge branch 'master' into sync/v3.0.2
calpt committed Oct 8, 2020
2 parents 7a5ae91 + 451a0a8 commit 6268523
Showing 13 changed files with 296 additions and 53 deletions.
6 changes: 3 additions & 3 deletions README.md
@@ -6,16 +6,16 @@
</h1>

<h3 align="center">
A friendly fork of Huggingface's <i>Transformers</i>, adding Adapters to PyTorch language models
A friendly fork of HuggingFace's <i>Transformers</i>, adding Adapters to PyTorch language models
</h3>

![Tests](https://github.com/Adapter-Hub/adapter-transformers/workflows/Tests/badge.svg)
[![GitHub](https://img.shields.io/github/license/adapter-hub/adapter-transformers.svg?color=blue)](https://github.com/adapter-hub/adapter-transformers/blob/master/LICENSE)
![PyPI](https://img.shields.io/pypi/v/adapter-transformers)

`adapter-transformers` is an extension of [Huggingface's Transformers](https://github.com/huggingface/transformers) library, integrating adapters into state-of-the-art language models by incorporating **[AdapterHub](https://adapterhub.ml)**, a central repository for pre-trained adapter modules.
`adapter-transformers` is an extension of [HuggingFace's Transformers](https://github.com/huggingface/transformers) library, integrating adapters into state-of-the-art language models by incorporating **[AdapterHub](https://adapterhub.ml)**, a central repository for pre-trained adapter modules.

This library can be used as a drop-in replacement for Huggingface Transformers and regulary synchronizes new upstream changes.
This library can be used as a drop-in replacement for HuggingFace Transformers and regularly synchronizes new upstream changes.
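
As a rough sketch of what that drop-in usage looks like (illustrative only; it assumes the `sentiment/sst-2@ukp` adapter from the Hub and the `load_adapter()` / `set_active_adapters()` calls used elsewhere in this repository):

```python
from transformers import BertModel

# Load a standard BERT model; the adapter methods are mixed into the usual model classes.
model = BertModel.from_pretrained("bert-base-uncased")

# Attach a pre-trained task adapter from AdapterHub (name and type as used in run_fusion_glue.py).
model.load_adapter("sentiment/sst-2@ukp", "text_task")

# Activate the adapter so it is used during the forward pass.
model.set_active_adapters(["sst-2"])
```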

## Installation

13 changes: 11 additions & 2 deletions adapter_docs/training.md
@@ -142,14 +142,23 @@ python run_language_modeling.py \

## Train AdapterFusion

We provide an example for training _AdapterFusion_ on the GLUE dataset: [run_fusion_glue.py](https://github.com/Adapter-Hub/adapter-transformers/blob/master/examples/text-classification/run_fusion_glue.py). To start training, you can run something like the following:
We provide an example for training _AdapterFusion_ ([Pfeiffer et al., 2020](https://arxiv.org/pdf/2005.00247)) on the GLUE dataset: [run_fusion_glue.py](https://github.com/Adapter-Hub/adapter-transformers/blob/master/examples/text-classification/run_fusion_glue.py). You can adapt this script to train AdapterFusion with different pre-trained adapters on your own dataset.

```eval_rst
.. important::
AdapterFusion on a target task is trained in a second training stage, after independently training adapters on individual tasks.
When setting up a fusion architecture on your model, make sure to load the pre-trained adapter modules to be fused using ``model.load_adapter()`` before adding a fusion layer, as illustrated in the sketch below.
For more on AdapterFusion, also refer to `Pfeiffer et al., 2020 <https://arxiv.org/pdf/2005.00247>`_.
```
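
In code, the setup described in the note amounts to roughly the following minimal sketch; the `load_adapter()`, `add_fusion()`, and `train_fusion()` calls mirror `run_fusion_glue.py`, while the model class and `num_labels` are assumptions made for illustration:

```python
from transformers import AutoModelForSequenceClassification
from transformers.adapter_config import PfeifferConfig

# Assumed setup: a sequence classification model as in the GLUE example scripts.
model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)

# 1. Load the pre-trained adapters to be fused (without their prediction heads).
model.load_adapter("sentiment/sst-2@ukp", "text_task", config=PfeifferConfig(), with_head=False)
model.load_adapter("nli/multinli@ukp", "text_task", config=PfeifferConfig(), with_head=False)

# 2. Add a fusion layer on top of the loaded adapters and switch to fusion training mode.
adapter_setup = [["sst-2", "mnli"]]
model.add_fusion(adapter_setup[0], "dynamic")
model.train_fusion(adapter_setup)
```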

To start fusion training on SST-2 as the target task, you can run something like the following:

```
export GLUE_DIR=/path/to/glue
export TASK_NAME=SST-2
python run_fusion_glue.py \
--model_name_or_path bert-base-cased \
--model_name_or_path bert-base-uncased \
--task_name $TASK_NAME \
--do_train \
--do_eval \
28 changes: 28 additions & 0 deletions examples/test_examples.py
@@ -32,6 +32,7 @@
if SRC_DIRS is not None:
import run_generation
import run_glue
import run_fusion_glue
import run_language_modeling
import run_squad

@@ -106,6 +107,33 @@ def test_run_glue_adapters(self):
for value in result.values():
self.assertGreaterEqual(value, 0.75)

def test_run_fusion_glue(self):
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)

testargs = """
run_fusion_glue.py
--model_name_or_path bert-base-uncased
--data_dir ./tests/fixtures/tests_samples/MRPC/
--task_name mrpc
--do_train
--do_eval
--output_dir ./tests/fixtures/tests_samples/temp_dir
--per_device_train_batch_size=2
--per_device_eval_batch_size=1
--learning_rate=5e-5
--max_steps=20
--warmup_steps=2
--overwrite_output_dir
--seed=42
--max_seq_length=128
""".split()
with patch.object(sys, "argv", testargs):
result = run_fusion_glue.main()
del result["eval_loss"]
for value in result.values():
self.assertGreaterEqual(value, 0.5)

def test_run_language_modeling(self):
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)
33 changes: 17 additions & 16 deletions examples/text-classification/run_fusion_glue.py
@@ -144,18 +144,16 @@ def main():
cache_dir=model_args.cache_dir,
)

# Setup adapters
# ~~~~~ Here comes the interesting part of setting up AdapterFusion training ~~~~~

from transformers.adapter_config import PfeifferConfig

model.load_adapter(
"sentiment/sst-2@ukp", "text_task", config=PfeifferConfig(), with_head=False, version="AdapterFusion"
)
model.load_adapter(
"nli/multinli@ukp", "text_task", config=PfeifferConfig(), with_head=False, version="AdapterFusion"
)
model.load_adapter("nli/rte@ukp", "text_task", config=PfeifferConfig(), with_head=False, version="AdapterFusion")
model.load_adapter("sts/mrpc@ukp", "text_task", config=PfeifferConfig(), with_head=False, version="AdapterFusion")
model.load_adapter("sts/qqp@ukp", "text_task", config=PfeifferConfig(), with_head=False, version="AdapterFusion")
# First, load the pre-trained adapters we want to fuse from Hub
model.load_adapter("sentiment/sst-2@ukp", "text_task", config=PfeifferConfig(), with_head=False)
model.load_adapter("nli/multinli@ukp", "text_task", config=PfeifferConfig(), with_head=False)
model.load_adapter("nli/rte@ukp", "text_task", config=PfeifferConfig(), with_head=False)
model.load_adapter("sts/mrpc@ukp", "text_task", config=PfeifferConfig(), with_head=False)
model.load_adapter("sts/qqp@ukp", "text_task", config=PfeifferConfig(), with_head=False)
model.load_adapter("comsense/cosmosqa@ukp", "text_task", config=PfeifferConfig(), with_head=False)
model.load_adapter("comsense/csqa@ukp", "text_task", config=PfeifferConfig(), with_head=False)
model.load_adapter("comsense/hellaswag@ukp", "text_task", config=PfeifferConfig(), with_head=False)
@@ -167,10 +165,10 @@ def main():
model.load_adapter("qa/boolq@ukp", "text_task", config=PfeifferConfig(), with_head=False)
model.load_adapter("sentiment/imdb@ukp", "text_task", config=PfeifferConfig(), with_head=False)

adapter_names = [
adapter_setup = [
[
"sst_glue",
"multinli",
"sst-2",
"mnli",
"rte",
"mrpc",
"qqp",
@@ -187,8 +185,12 @@ def main():
]
]

model.add_fusion(adapter_names[0], "static", {"regularization": False})
model.base_model.train_fusion(adapter_names[0])
# Add a fusion layer and tell the model to train fusion
model.add_fusion(adapter_setup[0], "dynamic")
model.train_fusion(adapter_setup)

# ~~~~~ The rest is the same as in the standard training setup ~~~~~

# Get datasets
train_dataset = GlueDataset(data_args, tokenizer=tokenizer) if training_args.do_train else None
eval_dataset = GlueDataset(data_args, tokenizer=tokenizer, mode="dev") if training_args.do_eval else None
Expand All @@ -210,7 +212,6 @@ def compute_metrics(p: EvalPrediction) -> Dict:
compute_metrics=compute_metrics,
do_save_full_model=False,
do_save_adapter_fusion=True,
adapter_names=adapter_names,
)

# Training
2 changes: 1 addition & 1 deletion setup.py
@@ -101,7 +101,7 @@

setup(
name="adapter-transformers",
version="1.0.0",
version="1.0.1",
author="Jonas Pfeiffer, Andreas Rücklé, Clifton Poth, based on work by Thomas Wolf, Lysandre Debut, Victor Sanh, Julien Chaumond, Sam Shleifer, Patrick von Platen, Google AI Language Team Authors, Open AI team Authors, Facebook AI Authors, Carnegie Mellon University Authors",
author_email="pfeiffer@ukp.tu-darmstadt.de",
description="A friendly fork of Huggingface's Transformers, adding Adapters to PyTorch language models",
2 changes: 1 addition & 1 deletion src/transformers/__init__.py
@@ -2,7 +2,7 @@
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.

__version__ = "1.0.0"
__version__ = "1.0.1"

# Work around to update TensorFlow's absl.logging threshold which alters the
# default Python logging output behavior when present.
42 changes: 39 additions & 3 deletions src/transformers/adapter_bert.py
@@ -493,11 +493,11 @@ def _init_adapter_modules(self):
self.add_fusion_layer(fusion_adapter_names)

def train_adapter(self, adapter_names: list):
"""Sets the model in mode for training the given adapters."""
"""Sets the model into mode for training the given adapters."""
self.train()
self.freeze_model(True)
adapter_names_flat = flatten_adapter_names(adapter_names)
self.encoder.enable_adapters(adapter_names, True, False)
self.encoder.enable_adapters(adapter_names_flat, True, False)
# also unfreeze the corresponding invertible adapters, if available
for adapter_name in adapter_names_flat:
if adapter_name in self.invertible_lang_adapters:
Expand All @@ -507,7 +507,7 @@ def train_adapter(self, adapter_names: list):
self.set_active_adapters(adapter_names)

def train_fusion(self, adapter_names: list):
"""Sets the model in mode for training of adapter fusion determined by a list of adapter names."""
"""Sets the model into mode for training of adapter fusion determined by a list of adapter names."""
self.train()
self.freeze_model(True)
adapter_names_flat = flatten_adapter_names(adapter_names)
@@ -669,6 +669,17 @@ def add_tagging_head(
}
self.add_prediction_head(head_name, config, overwrite_ok)

def add_qa_head(
self, head_name, num_labels=2, layers=1, activation_function="tanh", overwrite_ok=False,
):
config = {
"head_type": "question_answering",
"num_labels": num_labels,
"layers": layers,
"activation_function": activation_function,
}
self.add_prediction_head(head_name, config, overwrite_ok)

def add_prediction_head(
self, head_name, config, overwrite_ok=False,
):
@@ -770,6 +781,31 @@ def forward_head(self, outputs, head_name=None, attention_mask=None, labels=None
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
outputs = (loss,) + outputs

elif head["head_type"] == "question_answering":
logits = self.heads[head_name](sequence_output)

start_logits, end_logits = logits.split(1, dim=-1)
start_logits = start_logits.squeeze(-1)
end_logits = end_logits.squeeze(-1)

outputs = (start_logits, end_logits,) + outputs[2:]
if labels is not None:
start_positions, end_positions = labels
if len(start_positions.size()) > 1:
start_positions = start_positions.squeeze(-1)
if len(end_positions.size()) > 1:
end_positions = end_positions.squeeze(-1)
# sometimes the start/end positions are outside our model inputs, we ignore these terms
ignored_index = start_logits.size(1)
start_positions.clamp_(0, ignored_index)
end_positions.clamp_(0, ignored_index)

loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
start_loss = loss_fct(start_logits, start_positions)
end_loss = loss_fct(end_logits, end_positions)
total_loss = (start_loss + end_loss) / 2
outputs = (total_loss,) + outputs

else:
raise ValueError("Unknown head_type '{}'".format(head["head_type"]))

17 changes: 15 additions & 2 deletions src/transformers/adapter_config.py
@@ -189,8 +189,12 @@ def get(self, adapter_name: str, return_type: bool = False):
config = self.config_map.get(config_name, None)
else:
config = ADAPTER_CONFIG_MAP.get(config_name, None)
if not config:
if not config and adapter_type in self.config_map:
config = self.config_map[adapter_type]
elif (
not config
): # If no config is specified via config_name or adapter_type, we just use the global default
config = DEFAULT_ADAPTER_CONFIG
if isinstance(config, str):
config = ADAPTER_CONFIG_MAP[config]
else:
Expand All @@ -203,6 +207,9 @@ def get(self, adapter_name: str, return_type: bool = False):
def add(self, adapter_name: str, adapter_type: AdapterType, config: Optional[Union[str, dict]] = None):
if adapter_name in self.adapters:
raise ValueError(f"An adapter with the name '{adapter_name}' has already been added.")
if config is None and adapter_type not in self.config_map:
# if config is not specified & no per-type default is set, manually set global default
config = DEFAULT_ADAPTER_CONFIG
config_name = config
if isinstance(config, str):
if config not in ADAPTER_CONFIG_MAP and config not in self.config_map:
@@ -248,7 +255,12 @@ def common_config_value(self, adapter_names: list, attribute: str):
"""
common_value = None
for i, name in enumerate(adapter_names):
config_value = self.get(name).get(attribute, None)
config = self.get(name)
if not config:
raise ValueError(
f"No adapter with name '{name}' found. Make sure that an adapter with this name is loaded."
)
config_value = config.get(attribute, None)
if i > 0 and config_value != common_value:
raise ValueError(f"All given adapters must define the same value for config attribute {attribute}.")
common_value = config_value
@@ -257,6 +269,7 @@ def to_dict(self):
def to_dict(self):
output_dict = {}
output_dict["adapters"] = copy.deepcopy(self.adapters)
output_dict["config_map"] = copy.deepcopy(self.config_map)
return output_dict


28 changes: 17 additions & 11 deletions src/transformers/adapter_model_mixin.py
Expand Up @@ -644,7 +644,13 @@ def add_adapter(self, adapter_name: str, adapter_type: AdapterType, config=None)

@abstractmethod
def train_adapter(self, adapter_names: list):
"""Sets the model in mode for training the given type of adapter.
"""Sets the model into mode for training the given adapters.
"""
pass

@abstractmethod
def train_fusion(self, adapter_names: list):
"""Sets the model into mode for training of adapter fusion determined by a list of adapter names.
"""
pass

@@ -696,7 +702,7 @@ def set_adapter_config(self, adapter_type: AdapterType, adapter_config):
else:
raise ValueError("Invalid adapter type {}".format(adapter_type))

def set_adapter_fusion_config(self, adapter_fusion_config, kwargs={}):
def set_adapter_fusion_config(self, adapter_fusion_config, override_kwargs=None):
"""Sets the adapter fusion configuration.
Args:
@@ -705,15 +711,16 @@ def set_adapter_fusion_config(self, adapter_fusion_config, kwargs={}):
- a dictionary representing the adapter fusion configuration
- the path to a file containing the adapter fusion configuration
"""
if override_kwargs is None:
override_kwargs = {}
if isinstance(adapter_fusion_config, str) and adapter_fusion_config in ADAPTERFUSION_CONFIG_MAP:
self.config.adapter_fusion = AdapterFusionConfig.load(adapter_fusion_config, **kwargs)
# ADAPTERFUSION_CONFIG_MAP[adapter_fusion_config](**kwargs).to_dict()
self.config.adapter_fusion = AdapterFusionConfig.load(adapter_fusion_config, **override_kwargs)
elif isinstance(adapter_fusion_config, Mapping):
self.config.adapter_fusion = adapter_fusion_config
else:
raise ValueError("Invalid adapter type {}".format(adapter_fusion_config))

def add_fusion(self, adapter_names, adapter_fusion_config=None, kwargs={}):
def add_fusion(self, adapter_names, adapter_fusion_config=None, override_kwargs=None):
"""Adds AdapterFusion to the model with alll the necessary configurations and weight initializations
Args:
@@ -722,14 +729,13 @@ def add_fusion(self, adapter_names, adapter_fusion_config=None, kwargs={}):
- a string identifying a pre-defined adapter fusion configuration
- a dictionary representing the adapter fusion configuration
- the path to a file containing the adapter fusion configuration
kwargs: dictionary items for values which should be overwritten in the default AdapterFusion configuration
Returns:
override_kwargs: dictionary items for values which should be overwritten in the default AdapterFusion configuration
"""
if not hasattr(self.config, "adapter_fusion"):
if override_kwargs is None:
override_kwargs = {}
if adapter_fusion_config is not None:
self.set_adapter_fusion_config(adapter_fusion_config, kwargs)
self.set_adapter_fusion_config(adapter_fusion_config, **override_kwargs)
else:
self.set_adapter_fusion_config(DEFAULT_ADAPTERFUSION_CONFIG)
elif hasattr(self.config, "adapter_fusion") and adapter_fusion_config is not None:
@@ -945,7 +951,7 @@ def add_adapter(self, adapter_name: str, adapter_type: AdapterType, config=None)
self.base_model.add_adapter(adapter_name, adapter_type, config)

def train_adapter(self, adapter_names: list):
"""Sets the model in mode for training the given type of adapter."""
"""Sets the model into mode for training the given adapters."""
self.base_model.train_adapter(adapter_names)

def train_fusion(self, adapter_names: list):
18 changes: 6 additions & 12 deletions src/transformers/adapter_modeling.py
@@ -5,10 +5,8 @@


class Activation_Function_Class(nn.Module):
"""Implementation of the gelu activation function.
For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
Also see https://arxiv.org/abs/1606.08415
"""
Implementation of various activation functions.
"""

def __init__(self, hidden_act):
@@ -42,6 +40,10 @@ def forward(self, x):


class Adapter(nn.Module):
"""
Implementation of a single Adapter block.
"""

def __init__(
self,
input_size,
@@ -466,11 +468,3 @@ def jacobian(self, x, c=[], rev=False):

def output_dims(self, input_dims):
return input_dims


if __name__ == "__main__":
adapter = Adapter(50)

batch = torch.rand(16, 50)

print(adapter(batch))