diff --git a/src/transformers/adapters/heads/base.py b/src/transformers/adapters/heads/base.py
index d3a37738c1..e26688bdca 100644
--- a/src/transformers/adapters/heads/base.py
+++ b/src/transformers/adapters/heads/base.py
@@ -1,4 +1,5 @@
 import logging
+from dataclasses import dataclass
 from typing import List, Optional, Union
 
 import torch
@@ -23,6 +24,35 @@
 logger = logging.getLogger(__name__)
 
 
+@dataclass
+class MultiHeadOutput(ModelOutput):
+    head_outputs: List[ModelOutput] = None
+    loss: Optional[torch.FloatTensor] = None
+
+    def __getitem__(self, k):
+        # integer indices access the head output at that position,
+        # e.g. output[1] is equivalent to output.head_outputs[1]
+        if isinstance(k, int):
+            return self.head_outputs[k]
+        # string keys address the attribute in the underlying dict,
+        # e.g. output["loss"] is equivalent to output.loss
+        else:
+            return super().__getitem__(k)
+
+    def __setitem__(self, k, v):
+        if isinstance(k, int):
+            self.head_outputs[k] = v
+        else:
+            return super().__setitem__(k, v)
+
+    def __iter__(self):
+        # iterates over the head outputs
+        return iter(self.head_outputs)
+
+    def __len__(self):
+        return len(self.head_outputs)
+
+
 # Let this class inherit from nn.Sequential to provide iterable access as before
 class PredictionHead(nn.Sequential):
     def __init__(self, name):
@@ -662,14 +692,21 @@ def _get_head_input(outputs, cls_out, batch):
                     )
                 )
             head_outputs = []
+            labels = kwargs.pop("labels", None)
             for i, head in enumerate(self.active_head):
                 head_module = self.heads[head]
                 batch_idx = range(sum(self.active_head.batch_sizes[:i]), sum(self.active_head.batch_sizes[: i + 1]))
+                kwargs["labels"] = labels[batch_idx] if labels is not None else None
                 head_inputs, head_cls_input = _get_head_input(all_outputs, cls_output, batch_idx)
                 # head_attention = attention_mask[batch_idx] if attention_mask is not None else None
                 head_output = head_module(head_inputs, head_cls_input, attention_mask, return_dict, **kwargs)
                 head_outputs.append(head_output)
-            return head_outputs
+            combined_loss = (
+                sum([out["loss"] for out in head_outputs])
+                if all("loss" in out and out["loss"] is not None for out in head_outputs)
+                else None
+            )
+            return MultiHeadOutput(head_outputs=head_outputs, loss=combined_loss)
         elif self.has_parallel_adapters or isinstance(self.active_head, Parallel):
             if len(self.active_head) != self.config.adapters.active_setup.parallel_channels:
                 raise ValueError("The number of parallel adapters and the number of active heads must match.")
@@ -681,7 +718,12 @@ def _get_head_input(outputs, cls_out, batch):
                 head_inputs, head_cls_input = _get_head_input(all_outputs, cls_output, batch_idx)
                 head_output = head_module(head_inputs, head_cls_input, attention_mask, return_dict, **kwargs)
                 head_outputs.append(head_output)
-            return head_outputs
+            combined_loss = (
+                torch.sum(torch.stack([out["loss"] for out in head_outputs]))
+                if all("loss" in out and out["loss"] is not None for out in head_outputs)
+                else None
+            )
+            return MultiHeadOutput(head_outputs=head_outputs, loss=combined_loss)
         elif len(used_heads) > 1:
             head_outputs = []
             for head in used_heads:
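A rough usage sketch (not part of the patch above) of how the new MultiHeadOutput is meant to be consumed with a Parallel head setup: integer indices return the individual head outputs, string keys and attributes behave like a regular ModelOutput, and the combined loss allows a single backward pass over all heads. The checkpoint name and the dummy batch below are illustrative assumptions; the adapter and head calls mirror those used in the tests further down.

import torch
from transformers import AutoModelWithHeads
from transformers.adapters.composition import Parallel

model = AutoModelWithHeads.from_pretrained("bert-base-uncased")
model.add_adapter("mrpc1")
model.add_adapter("mrpc2")
model.add_classification_head("mrpc1", num_labels=2)
model.add_classification_head("mrpc2", num_labels=3)
model.active_adapters = Parallel("mrpc1", "mrpc2")
model.train_adapter(Parallel("mrpc1", "mrpc2"))

# dummy batch, purely illustrative
batch = {
    "input_ids": torch.randint(0, model.config.vocab_size, (4, 16)),
    "labels": torch.randint(0, 2, (4, 1)),
}
output = model(**batch)

# integer indices address the per-head outputs ...
logits_per_head = [output[i]["logits"] for i in range(len(output))]
# ... while output["loss"] / output.loss hold the summed per-head losses
if output.loss is not None:
    output.loss.backward()  # one backward pass trains both adapters and heads jointly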
diff --git a/tests/test_adapter.py b/tests/test_adapter.py
index 141856f1ec..a43dc9a879 100644
--- a/tests/test_adapter.py
+++ b/tests/test_adapter.py
@@ -18,7 +18,7 @@
 from transformers.testing_utils import require_torch, torch_device
 
 from .test_adapter_common import AdapterModelTestMixin
-from .test_adapter_composition import ParallelAdapterInferenceTestMixin
+from .test_adapter_composition import ParallelAdapterInferenceTestMixin, ParallelTrainingMixin
 from .test_adapter_conversion import ModelClassConversionTestMixin
 from .test_adapter_fusion_common import AdapterFusionModelTestMixin
 from .test_adapter_heads import PredictionHeadModelTestMixin
@@ -78,6 +78,7 @@ class BertAdapterTest(
     PredictionHeadModelTestMixin,
     AdapterTrainingTestMixin,
     ParallelAdapterInferenceTestMixin,
+    ParallelTrainingMixin,
     BertAdapterTestBase,
     unittest.TestCase,
 ):
@@ -144,6 +145,7 @@ class DistilBertAdapterTest(
     PredictionHeadModelTestMixin,
     AdapterTrainingTestMixin,
     ParallelAdapterInferenceTestMixin,
+    ParallelTrainingMixin,
     DistilBertAdapterTestBase,
     unittest.TestCase,
 ):
@@ -181,6 +183,7 @@ class BartAdapterTest(
     PredictionHeadModelTestMixin,
     AdapterTrainingTestMixin,
     ParallelAdapterInferenceTestMixin,
+    ParallelTrainingMixin,
     BartAdapterTestBase,
     unittest.TestCase,
 ):
@@ -251,6 +254,7 @@ class GPT2AdapterTest(
     PredictionHeadModelTestMixin,
     AdapterTrainingTestMixin,
     ParallelAdapterInferenceTestMixin,
+    ParallelTrainingMixin,
     GPT2AdapterTestBase,
     unittest.TestCase,
 ):
diff --git a/tests/test_adapter_composition.py b/tests/test_adapter_composition.py
index ae0df99966..9bd99cf6e4 100644
--- a/tests/test_adapter_composition.py
+++ b/tests/test_adapter_composition.py
@@ -1,8 +1,20 @@
+import copy
+import random
 import unittest
 
 import torch
 
-from transformers import AutoModelWithHeads, BertConfig, BertForSequenceClassification
+from tests.test_adapter_training import filter_parameters
+from transformers import (
+    AutoModelWithHeads,
+    AutoTokenizer,
+    BertConfig,
+    BertForSequenceClassification,
+    GlueDataset,
+    GlueDataTrainingArguments,
+    Trainer,
+    TrainingArguments,
+)
 from transformers.adapters.composition import BatchSplit, Fuse, Parallel, Split, Stack, parse_composition
 from transformers.testing_utils import require_torch, torch_device
 
@@ -207,3 +219,156 @@ def test_batch_split_with_heads(self):
         self.assertEqual(2, len(output))
         self.assertTrue(torch.allclose(output[0]["logits"], outputs_a["logits"]))
         self.assertTrue(torch.allclose(output[1]["logits"], outputs_b["logits"]))
+
+
+def create_twin_adapters(model, name):
+    # create an adapter with a classification head
+    adapter1, adapter2 = name + "_1", name + "_2"
+    model.add_adapter(adapter1)
+    model.add_classification_head(adapter1)
+    # create a twin initialized with the same random weights
+    model.add_adapter(adapter2)
+    model.add_classification_head(adapter2)
+
+    state_dict = model.state_dict()
+    for k, v in state_dict.items():
+        if adapter1 in k:
+            state_dict[k.replace(adapter1, adapter2)] = v
+    model.load_state_dict(state_dict)
+
+    return adapter1, adapter2
+
+
+class ParallelTrainingMixin:
+    def train_model(self, model, dataset):
+        # trains the model in eval mode for 2 epochs
+        random.seed(42)
+        torch.manual_seed(42)
+        # plain SGD is used here; with adaptive optimizers the two runs would not end up exactly identical
+        optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
+        for epoch in range(2):
+            for data_input in dataset:
+                optimizer.zero_grad()
+                output = model(**data_input)
+                loss = output["loss"]
+                loss.backward()
+                optimizer.step()
+        return model
+
+    def test_parallel_training(self):
+        tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name, use_fast=False)
+        if tokenizer.pad_token is None:
+            tokenizer.pad_token = tokenizer.eos_token
+        model = AutoModelWithHeads.from_config(self.config())
+
+        model.add_adapter("mrpc1")
+        model.add_adapter("mrpc2")
+        model.add_classification_head("mrpc1", num_labels=2)
+        model.add_classification_head("mrpc2", num_labels=3)
+        model.eval()
+        model.active_adapters = Parallel("mrpc1", "mrpc2")
+        model.train_adapter(Parallel("mrpc1", "mrpc2"))
+
+        # all weights of the first adapter should be activated
+        for k, v in filter_parameters(model, "adapters.mrpc1.").items():
+            self.assertTrue(v.requires_grad, k)
+        # the second adapter is trained in parallel, so its weights should be activated as well
+        for k, v in filter_parameters(model, "adapters.mrpc2.").items():
+            self.assertTrue(v.requires_grad, k)
+        # weights of the base model should be frozen (check on some examples)
+        for k, v in filter_parameters(model, "encoder.layer.0.attention").items():
+            self.assertFalse(v.requires_grad, k)
+
+        state_dict_pre = copy.deepcopy(model.state_dict())
+
+        # setup dataset
+        data_args = GlueDataTrainingArguments(
+            task_name="mrpc", data_dir="./tests/fixtures/tests_samples/MRPC", overwrite_cache=True
+        )
+        train_dataset = GlueDataset(data_args, tokenizer=tokenizer, mode="train")
+        training_args = TrainingArguments(
+            output_dir="./examples", do_train=True, learning_rate=0.1, max_steps=7, no_cuda=True
+        )
+
+        # train
+        trainer = Trainer(
+            model=model,
+            args=training_args,
+            train_dataset=train_dataset,
+        )
+        trainer.train()
+
+        for ((k1, v1), (k2, v2)) in zip(state_dict_pre.items(), model.state_dict().items()):
+            if "mrpc" in k1:
+                self.assertFalse(torch.equal(v1, v2))
+            else:
+                self.assertTrue(torch.equal(v1, v2))
+
+    def test_parallel_training_equivalent_to_single_adapters(self):
+        model = AutoModelWithHeads.from_config(self.config())
+        model.eval()
+
+        a1, a2 = create_twin_adapters(model, "a")
+        b1, b2 = create_twin_adapters(model, "b")
+
+        dataset = []
+        for i in range(3):
+            input_data = self.get_input_samples((3, 128), config=model.config)
+            input_data["labels"] = torch.randint(0, 2, (3, 1))
+            dataset.append(input_data)
+
+        for adapter in [a1, b1]:
+            model.active_head = adapter
+            model.set_active_adapters(adapter)
+            model.train_adapter(adapter)
+            model.eval()
+
+            model = self.train_model(model, dataset)
+
+        model.set_active_adapters(Parallel(a2, b2))
+        model.train_adapter(Parallel(a2, b2))
+        model.eval()
+
+        model = self.train_model(model, dataset)
+
+        state_dict = model.state_dict()
+        for k, v in state_dict.items():
+            if a1 in k:
+                self.assertTrue(torch.allclose(v, state_dict[k.replace(a1, a2)]))
+            if b1 in k:
+                self.assertTrue(torch.allclose(v, state_dict[k.replace(b1, b2)]))
+
+    def test_parallel_training_single_forward_pass(self):
+        model = AutoModelWithHeads.from_config(self.config())
+        model.eval()
+
+        a1, a2 = create_twin_adapters(model, "a")
+        b1, b2 = create_twin_adapters(model, "b")
+
+        state_dict = model.state_dict()
+        for k, v in state_dict.items():
+            if a1 in k:
+                self.assertTrue(torch.equal(v, state_dict[k.replace(a1, a2)]))
+            if b1 in k:
+                self.assertTrue(torch.equal(v, state_dict[k.replace(b1, b2)]))
+
+        input_data = self.get_input_samples((3, 128), config=model.config)
+        input_data["labels"] = torch.randint(0, 2, (3, 1))
+
+        outputs = []
+        for adapter in [a1, b1]:
+            model.active_head = adapter
+            model.set_active_adapters(adapter)
+            model.train_adapter(adapter)
+            model.eval()
+            outputs.append(model(**input_data))
+
+        model.set_active_adapters(Parallel(a2, b2))
+        model.train_adapter(Parallel(a2, b2))
+        model.eval()
+
+        parallel_outputs = model(**input_data)
+
+        for out1, out2 in zip(outputs, parallel_outputs.head_outputs):
+            self.assertTrue(torch.equal(out1["loss"], out2["loss"]))
+            self.assertTrue(torch.allclose(out1["logits"], out2["logits"]))
diff --git a/tests/test_adapter_training.py b/tests/test_adapter_training.py
index 3dd6fa367d..348169c5f3 100644
--- a/tests/test_adapter_training.py
+++ b/tests/test_adapter_training.py
@@ -11,7 +11,7 @@
     Trainer,
     TrainingArguments,
 )
-from transformers.adapters.composition import Fuse
+from transformers.adapters.composition import BatchSplit, Fuse
 from transformers.testing_utils import require_torch
 
 
@@ -21,6 +21,29 @@ def filter_parameters(model, filter_string):
 
 @require_torch
 class AdapterTrainingTestMixin:
+    def trainings_run(self, model, tokenizer):
+        # setup dataset
+        data_args = GlueDataTrainingArguments(
+            task_name="mrpc", data_dir="./tests/fixtures/tests_samples/MRPC", overwrite_cache=True
+        )
+        train_dataset = GlueDataset(data_args, tokenizer=tokenizer, mode="train")
+        training_args = TrainingArguments(
+            output_dir="./examples",
+            do_train=True,
+            learning_rate=0.1,
+            max_steps=7,
+            no_cuda=True,
+            per_device_train_batch_size=8,
+        )
+
+        # train
+        trainer = Trainer(
+            model=model,
+            args=training_args,
+            train_dataset=train_dataset,
+        )
+        trainer.train()
+
     def test_train_single_adapter(self):
         tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name, use_fast=False)
         if tokenizer.pad_token is None:
@@ -51,22 +74,7 @@ def test_train_single_adapter(self):
 
         state_dict_pre = copy.deepcopy(model.state_dict())
 
-        # setup dataset
-        data_args = GlueDataTrainingArguments(
-            task_name="mrpc", data_dir="./tests/fixtures/tests_samples/MRPC", overwrite_cache=True
-        )
-        train_dataset = GlueDataset(data_args, tokenizer=tokenizer, mode="train")
-        training_args = TrainingArguments(
-            output_dir="./examples", do_train=True, learning_rate=0.1, max_steps=7, no_cuda=True
-        )
-
-        # evaluate
-        trainer = Trainer(
-            model=model,
-            args=training_args,
-            train_dataset=train_dataset,
-        )
-        trainer.train()
+        self.trainings_run(model, tokenizer)
 
         for ((k1, v1), (k2, v2)) in zip(state_dict_pre.items(), model.state_dict().items()):
             if "mrpc" in k1:
@@ -120,22 +128,7 @@ def patched_fusion_reg_loss():
 
         model.base_model.get_fusion_regularization_loss = patched_fusion_reg_loss
 
-        # setup dataset
-        data_args = GlueDataTrainingArguments(
-            task_name="mrpc", data_dir="./tests/fixtures/tests_samples/MRPC", overwrite_cache=True
-        )
-        train_dataset = GlueDataset(data_args, tokenizer=tokenizer, mode="train")
-        training_args = TrainingArguments(
-            output_dir="./examples", do_train=True, learning_rate=0.1, max_steps=7, no_cuda=True
-        )
-
-        # evaluate
-        trainer = Trainer(
-            model=model,
-            args=training_args,
-            train_dataset=train_dataset,
-        )
-        trainer.train()
+        self.trainings_run(model, tokenizer)
 
         for ((k1, v1), (k2, v2)) in zip(state_dict_pre.items(), model.state_dict().items()):
             if "adapter_fusion_layer" in k1 or "classifier" in k1 or "classification_head" in k1 or "score" in k1:
@@ -144,3 +137,37 @@ def patched_fusion_reg_loss():
                 self.assertTrue(torch.equal(v1, v2), k1)
 
         self.assertTrue(regularization_called)
+
+    def test_batch_split_training(self):
+        tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name, use_fast=False)
+        if tokenizer.pad_token is None:
+            tokenizer.pad_token = tokenizer.eos_token
+        model = AutoModelWithHeads.from_config(self.config())
+
+        model.add_adapter("mrpc1")
+        model.add_adapter("mrpc2")
+        model.add_classification_head("mrpc1")
+        model.add_classification_head("mrpc2")
+        adapter_setup = BatchSplit("mrpc1", "mrpc2", batch_sizes=[3, 3])
+        model.active_adapters = adapter_setup
+        model.train_adapter(adapter_setup)
+
+        # all weights of the first adapter should be activated
+        for k, v in filter_parameters(model, "adapters.mrpc1.").items():
+            self.assertTrue(v.requires_grad, k)
+        # the second adapter is part of the BatchSplit setup as well, so its weights should also be activated
+        for k, v in filter_parameters(model, "adapters.mrpc2.").items():
+            self.assertTrue(v.requires_grad, k)
+        # weights of the base model should be frozen (check on some examples)
+        for k, v in filter_parameters(model, "encoder.layer.0.attention").items():
+            self.assertFalse(v.requires_grad, k)
+
+        state_dict_pre = copy.deepcopy(model.state_dict())
+
+        self.trainings_run(model, tokenizer)
+
+        for ((k1, v1), (k2, v2)) in zip(state_dict_pre.items(), model.state_dict().items()):
+            if "mrpc" in k1:
+                self.assertFalse(torch.equal(v1, v2))
+            else:
+                self.assertTrue(torch.equal(v1, v2))
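A similarly hedged sketch of what the BatchSplit label handling added in heads/base.py enables: a single forward pass routes the first batch_sizes[0] examples through one adapter and head and the remaining examples through the other, slicing the labels per sub-batch before each head computes its loss. The checkpoint name and the dummy batch are again illustrative; the composition calls follow test_batch_split_training above.

import torch
from transformers import AutoModelWithHeads
from transformers.adapters.composition import BatchSplit

model = AutoModelWithHeads.from_pretrained("bert-base-uncased")
model.add_adapter("mrpc1")
model.add_adapter("mrpc2")
model.add_classification_head("mrpc1")
model.add_classification_head("mrpc2")
adapter_setup = BatchSplit("mrpc1", "mrpc2", batch_sizes=[3, 3])
model.active_adapters = adapter_setup
model.train_adapter(adapter_setup)

# dummy batch of 6 examples: the first 3 go through "mrpc1", the last 3 through "mrpc2"
batch = {
    "input_ids": torch.randint(0, model.config.vocab_size, (6, 16)),
    "labels": torch.randint(0, 2, (6, 1)),
}
output = model(**batch)  # a MultiHeadOutput with one entry per head

assert len(output) == 2
# output.loss sums the per-head losses, so Trainer or a manual loop can backpropagate it directly
if output.loss is not None:
    output.loss.backward()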