
Parallel training #226

Merged · 14 commits · Sep 14, 2021
46 changes: 44 additions & 2 deletions src/transformers/adapters/heads/base.py
@@ -1,4 +1,5 @@
import logging
from dataclasses import dataclass
from typing import List, Optional, Union

import torch
@@ -23,6 +24,35 @@
logger = logging.getLogger(__name__)


@dataclass
class MultiHeadOutput(ModelOutput):
head_outputs: List[ModelOutput] = None
loss: Optional[torch.FloatTensor] = None

def __getitem__(self, k):
# with integer indices, the head output at that position is accessed,
# e.g. output[1] is equivalent to output.head_outputs[1]
if isinstance(k, int):
return self.head_outputs[k]
# with string keys, the attribute in the underlying dict is addressed,
# e.g. output["loss"] is equivalent to output.loss
else:
return super().__getitem__(k)

def __setitem__(self, k, v):
if isinstance(k, int):
self.head_outputs[k] = v
else:
return super().__setitem__(k, v)

def __iter__(self):
# iterates over the head outputs
return iter(self.head_outputs)

def __len__(self):
return len(self.head_outputs)


# Let this class inherit from nn.Sequential to provide iterable access as before
class PredictionHead(nn.Sequential):
def __init__(self, name):
@@ -662,14 +692,21 @@ def _get_head_input(outputs, cls_out, batch):
)
)
head_outputs = []
labels = kwargs.pop("labels", None)
for i, head in enumerate(self.active_head):
head_module = self.heads[head]
batch_idx = range(sum(self.active_head.batch_sizes[:i]), sum(self.active_head.batch_sizes[: i + 1]))
kwargs["labels"] = labels[batch_idx] if labels is not None else None
head_inputs, head_cls_input = _get_head_input(all_outputs, cls_output, batch_idx)
# head_attention = attention_mask[batch_idx] if attention_mask is not None else None
head_output = head_module(head_inputs, head_cls_input, attention_mask, return_dict, **kwargs)
head_outputs.append(head_output)
return head_outputs
combined_loss = (
sum([out["loss"] for out in head_outputs])
if all("loss" in out and out["loss"] is not None for out in head_outputs)
else None
)
return MultiHeadOutput(head_outputs=head_outputs, loss=combined_loss)
elif self.has_parallel_adapters or isinstance(self.active_head, Parallel):
if len(self.active_head) != self.config.adapters.active_setup.parallel_channels:
raise ValueError("The number of parallel adapters and the number of active heads must match.")
@@ -681,7 +718,12 @@ def _get_head_input(outputs, cls_out, batch):
head_inputs, head_cls_input = _get_head_input(all_outputs, cls_output, batch_idx)
head_output = head_module(head_inputs, head_cls_input, attention_mask, return_dict, **kwargs)
head_outputs.append(head_output)
return head_outputs
combined_loss = (
torch.sum(torch.stack([out["loss"] for out in head_outputs]))
if all("loss" in out and out["loss"] is not None for out in head_outputs)
else None
)
return MultiHeadOutput(head_outputs=head_outputs, loss=combined_loss)
elif len(used_heads) > 1:
head_outputs = []
for head in used_heads:
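For reference, a minimal sketch of how the new `MultiHeadOutput` behaves (the `model` and `batch` below are hypothetical stand-ins): integer indices select a single head's output, string keys address the combined fields, and the summed loss allows a single backward pass across all heads:

```python
# Hypothetical usage; assumes a model whose active heads run in parallel.
output = model(**batch)             # forward returns a MultiHeadOutput

first_head = output[0]              # same as output.head_outputs[0]
total_loss = output["loss"]         # same as output.loss (sum of per-head losses)

for head_output in output:          # __iter__ walks the individual head outputs
    print(head_output["logits"].shape)

total_loss.backward()               # one backward pass covers all parallel heads
```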
6 changes: 5 additions & 1 deletion tests/test_adapter.py
@@ -18,7 +18,7 @@
from transformers.testing_utils import require_torch, torch_device

from .test_adapter_common import AdapterModelTestMixin
from .test_adapter_composition import ParallelAdapterInferenceTestMixin
from .test_adapter_composition import ParallelAdapterInferenceTestMixin, ParallelTrainingMixin
from .test_adapter_conversion import ModelClassConversionTestMixin
from .test_adapter_fusion_common import AdapterFusionModelTestMixin
from .test_adapter_heads import PredictionHeadModelTestMixin
@@ -78,6 +78,7 @@ class BertAdapterTest(
PredictionHeadModelTestMixin,
AdapterTrainingTestMixin,
ParallelAdapterInferenceTestMixin,
ParallelTrainingMixin,
BertAdapterTestBase,
unittest.TestCase,
):
@@ -144,6 +145,7 @@ class DistilBertAdapterTest(
PredictionHeadModelTestMixin,
AdapterTrainingTestMixin,
ParallelAdapterInferenceTestMixin,
ParallelTrainingMixin,
DistilBertAdapterTestBase,
unittest.TestCase,
):
@@ -181,6 +183,7 @@ class BartAdapterTest(
PredictionHeadModelTestMixin,
AdapterTrainingTestMixin,
ParallelAdapterInferenceTestMixin,
ParallelTrainingMixin,
BartAdapterTestBase,
unittest.TestCase,
):
@@ -251,6 +254,7 @@ class GPT2AdapterTest(
PredictionHeadModelTestMixin,
AdapterTrainingTestMixin,
ParallelAdapterInferenceTestMixin,
ParallelTrainingMixin,
GPT2AdapterTestBase,
unittest.TestCase,
):
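The same pattern extends to any other adapter-enabled model's test suite; a sketch of opting in (the Roberta class names are hypothetical):

```python
# Hypothetical example: a model test class opts into the new parallel-training
# tests simply by adding ParallelTrainingMixin to its bases.
class RobertaAdapterTest(
    AdapterModelTestMixin,
    PredictionHeadModelTestMixin,
    ParallelAdapterInferenceTestMixin,
    ParallelTrainingMixin,
    RobertaAdapterTestBase,
    unittest.TestCase,
):
    pass
```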
167 changes: 166 additions & 1 deletion tests/test_adapter_composition.py
@@ -1,8 +1,20 @@
import copy
import random
import unittest

import torch

from transformers import AutoModelWithHeads, BertConfig, BertForSequenceClassification
from tests.test_adapter_training import filter_parameters
from transformers import (
AutoModelWithHeads,
AutoTokenizer,
BertConfig,
BertForSequenceClassification,
GlueDataset,
GlueDataTrainingArguments,
Trainer,
TrainingArguments,
)
from transformers.adapters.composition import BatchSplit, Fuse, Parallel, Split, Stack, parse_composition
from transformers.testing_utils import require_torch, torch_device

@@ -207,3 +219,156 @@ def test_batch_split_with_heads(self):
self.assertEqual(2, len(output))
self.assertTrue(torch.allclose(output[0]["logits"], outputs_a["logits"]))
self.assertTrue(torch.allclose(output[1]["logits"], outputs_b["logits"]))


def create_twin_adapters(model, name):
# create the first adapter with a classification head
adapter1, adapter2 = name + "_1", name + "_2"
model.add_adapter(adapter1)
model.add_classification_head(adapter1)
# create a twin; the first adapter's weights are copied over below so both start identical
model.add_adapter(adapter2)
model.add_classification_head(adapter2)

state_dict = model.state_dict()
for k, v in state_dict.items():
if adapter1 in k:
state_dict[k.replace(adapter1, adapter2)] = v
model.load_state_dict(state_dict)

return adapter1, adapter2


class ParallelTrainingMixin:
def train_model(self, model, dataset):
# trains the model for 2 epochs; the model stays in eval mode so dropout is disabled and runs are deterministic
random.seed(42)
torch.manual_seed(42)
# use plain SGD; with other optimizers the twin adapters would not end up exactly identical
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
for epoch in range(2):
for data_input in dataset:
optimizer.zero_grad()
output = model(**data_input)
loss = output["loss"]
loss.backward()
optimizer.step()
return model

def test_parallel_training(self):
tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name, use_fast=False)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelWithHeads.from_config(self.config())

model.add_adapter("mrpc1")
model.add_adapter("mrpc2")
model.add_classification_head("mrpc1", num_labels=2)
model.add_classification_head("mrpc2", num_labels=3)
model.eval()
model.active_adapters = Parallel("mrpc1", "mrpc2")
model.train_adapter(Parallel("mrpc1", "mrpc2"))

# all weights of the first adapter should be activated
for k, v in filter_parameters(model, "adapters.mrpc1.").items():
self.assertTrue(v.requires_grad, k)
# all weights of the second adapter should be activated as well, since both are trained in parallel
for k, v in filter_parameters(model, "adapters.mrpc2.").items():
self.assertTrue(v.requires_grad, k)
# weights of the base model should be frozen (spot check on a few examples)
for k, v in filter_parameters(model, "encoder.layer.0.attention").items():
self.assertFalse(v.requires_grad, k)

state_dict_pre = copy.deepcopy(model.state_dict())

# setup dataset
data_args = GlueDataTrainingArguments(
task_name="mrpc", data_dir="./tests/fixtures/tests_samples/MRPC", overwrite_cache=True
)
train_dataset = GlueDataset(data_args, tokenizer=tokenizer, mode="train")
training_args = TrainingArguments(
output_dir="./examples", do_train=True, learning_rate=0.1, max_steps=7, no_cuda=True
)

# train
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_dataset,
)
trainer.train()

for (k1, v1), (k2, v2) in zip(state_dict_pre.items(), model.state_dict().items()):
if "mrpc" in k1:
self.assertFalse(torch.equal(v1, v2))
else:
self.assertTrue(torch.equal(v1, v2))

def test_parallel_training_equivalent_to_single_adapters(self):
model = AutoModelWithHeads.from_config(self.config())
model.eval()

a1, a2 = create_twin_adapters(model, "a")
b1, b2 = create_twin_adapters(model, "b")

dataset = []
for i in range(3):
input_data = self.get_input_samples((3, 128), config=model.config)
input_data["labels"] = torch.randint(0, 2, (3, 1))
dataset.append(input_data)

for adapter in [a1, b1]:
model.active_head = adapter
model.set_active_adapters(adapter)
model.train_adapter(adapter)
model.eval()

model = self.train_model(model, dataset)

model.set_active_adapters(Parallel(a2, b2))
model.train_adapter(Parallel(a2, b2))
model.eval()

model = self.train_model(model, dataset)

state_dict = model.state_dict()
for k, v in state_dict.items():
if a1 in k:
self.assertTrue(torch.allclose(v, state_dict[k.replace(a1, a2)]))
if b1 in k:
self.assertTrue(torch.allclose(v, state_dict[k.replace(b1, b2)]))

def test_parallel_training_single_forward_pass(self):
model = AutoModelWithHeads.from_config(self.config())
model.eval()

a1, a2 = create_twin_adapters(model, "a")
b1, b2 = create_twin_adapters(model, "b")

state_dict = model.state_dict()
for k, v in state_dict.items():
if a1 in k:
self.assertTrue(torch.equal(v, state_dict[k.replace(a1, a2)]))
if b1 in k:
self.assertTrue(torch.equal(v, state_dict[k.replace(b1, b2)]))

input_data = self.get_input_samples((3, 128), config=model.config)
input_data["labels"] = torch.randint(0, 2, (3, 1))

outputs = []
for adapter in [a1, b1]:
model.active_head = adapter
model.set_active_adapters(adapter)
model.train_adapter(adapter)
model.eval()
outputs.append(model(**input_data))

model.set_active_adapters(Parallel(a2, b2))
model.train_adapter(Parallel(a2, b2))
model.eval()

parallel_outputs = model(**input_data)

for out1, out2 in zip(outputs, parallel_outputs.head_outputs):
self.assertTrue(torch.equal(out1["loss"], out2["loss"]))
self.assertTrue(torch.allclose(out1["logits"], out2["logits"]))
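Taken together, these changes let several adapters, each with its own head, be trained in a single forward/backward pass. A minimal end-to-end sketch following the tests above (checkpoint name, adapter names, and hyperparameters are illustrative):

```python
import torch

from transformers import AutoModelWithHeads
from transformers.adapters.composition import Parallel

# Hypothetical setup: two adapters, each with its own classification head.
model = AutoModelWithHeads.from_pretrained("bert-base-uncased")
model.add_adapter("task_a")
model.add_adapter("task_b")
model.add_classification_head("task_a", num_labels=2)
model.add_classification_head("task_b", num_labels=3)

# Activate both adapters in parallel and freeze the base model.
model.set_active_adapters(Parallel("task_a", "task_b"))
model.train_adapter(Parallel("task_a", "task_b"))

optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Dummy batch; the input is replicated internally for each parallel channel.
batch = {
    "input_ids": torch.randint(0, 1000, (4, 16)),
    "attention_mask": torch.ones(4, 16, dtype=torch.long),
    "labels": torch.randint(0, 2, (4, 1)),
}

output = model(**batch)   # MultiHeadOutput with one entry per head
output.loss.backward()    # combined loss sums the per-head losses
optimizer.step()
```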