Add support for Stack, Parallel & BatchSplit composition to prefix tuning (#476)
calpt committed Feb 14, 2023
1 parent 1540ab0 commit 1adc4a5
Showing 30 changed files with 680 additions and 209 deletions.
19 changes: 17 additions & 2 deletions adapter_docs/adapter_composition.md
@@ -4,7 +4,7 @@ One of the great advantages of using adapters is the possibility to combine mult
To enable such adapter compositions, `adapter-transformers` comes with a modular and flexible concept to define how the input to the model should flow through the available adapters.
This not only allows stacking ([_MAD-X_](https://arxiv.org/pdf/2005.00052.pdf)) and fusing ([_AdapterFusion_](https://arxiv.org/pdf/2005.00247.pdf)) adapters, but also enables even more complex adapter setups.

-## Adapter activation
+## Adapter Activation

The single location where all the adapter composition magic happens is the `active_adapters` property of the model class.
In the simplest case, you can set the name of a single adapter here to activate it:
@@ -33,9 +33,21 @@ with AdapterSetup("adapter_name"):
    outputs = model(**inputs)
```
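
For orientation, here is a minimal, self-contained sketch of the two activation styles shown above. The checkpoint, adapter name, and inputs are illustrative, and it assumes the `AutoAdapterModel` class and the `AdapterSetup` context manager are importable as in `adapter-transformers` v3.x:

```python
from transformers import AutoAdapterModel, AutoTokenizer
from transformers.adapters import AdapterSetup

# Illustrative checkpoint and adapter name.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoAdapterModel.from_pretrained("bert-base-uncased")
model.add_adapter("adapter_name")

inputs = tokenizer("Adapters are modular.", return_tensors="pt")

# Option 1: activate the adapter globally via the active_adapters property.
model.active_adapters = "adapter_name"
outputs = model(**inputs)

# Option 2: activate it only inside a context block.
model.set_active_adapters(None)
with AdapterSetup("adapter_name"):
    outputs = model(**inputs)
```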

## Composition Blocks - Overview

The basic building blocks of the more advanced setups are simple objects derived from `AdapterCompositionBlock`,
each representing a different possibility to combine single adapters.
They are presented in more detail below.
The following table gives an overview of the composition blocks and which adapter methods support them.

| Block | (Bottleneck)<br> Adapters | Prefix<br> Tuning | Compacter | LoRA | (IA)³ |
| --- | --- | --- | --- | --- | --- |
| [`Stack`](#stack) | ✅ | ✅ | ✅ | | |
| [`Fuse`](#fuse) | ✅ | | ✅ | | |
| [`Split`](#split) | ✅ | | ✅ | | |
| [`BatchSplit`](#batchsplit) | ✅ | ✅ | ✅ | | |
| [`Parallel`](#parallel) | ✅ | ✅ | ✅ | | |
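
For a concrete sense of how these blocks are used, here is a short sketch; the adapter names are placeholders, `model` is assumed to already have adapters "a" and "b" added, and the `batch_sizes` keyword is assumed per the `BatchSplit` signature documented for `adapter-transformers` v3.x:

```python
import transformers.adapters.composition as ac

# Composition blocks are plain objects that simply list the adapters they combine;
# nothing runs until a block is set as the model's active adapter setup.
stack = ac.Stack("a", "b")                                 # pass the output of "a" on to "b"
parallel = ac.Parallel("a", "b")                           # run "a" and "b" on replicated copies of the batch
batch_split = ac.BatchSplit("a", "b", batch_sizes=[4, 4])  # first 4 samples through "a", next 4 through "b"

model.active_adapters = parallel
```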

Next, we present all composition blocks in more detail.

## `Stack`

@@ -66,6 +78,9 @@ model.add_adapter("c")
model.active_adapters = ac.Stack("a", "b", "c")
```

Since v3.2.0, stacking is also supported for prefix tuning.
Stacked prefixes are prepended to the input states from right to left, i.e. `Stack("a", "b", "c")` will first prepend the prefix states for "a" to the input vectors, then prepend "b" to the resulting vectors, and so on.
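
As a sketch of what this looks like in code (the adapter names are placeholders and the `"prefix_tuning"` config string is assumed to be the standard identifier for `PrefixTuningConfig`):

```python
import transformers.adapters.composition as ac

model.add_adapter("a", config="prefix_tuning")
model.add_adapter("b", config="prefix_tuning")
model.active_adapters = ac.Stack("a", "b")

# During the forward pass, the prefix states end up ordered roughly as
#   [prefix_b, prefix_a, input_states]
# since the prefix for "a" is prepended first and "b" is then prepended to the result.
outputs = model(**inputs)  # `inputs` as tokenized earlier
```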

In v1.x of `adapter-transformers`, stacking adapters was done using a list of adapter names, i.e. the example from above would be defined as `["a", "b", "c"]`.
For backwards compatibility, you can still do this, although it is recommended to use the new syntax.

2 changes: 1 addition & 1 deletion src/transformers/adapters/composition.py
@@ -200,7 +200,7 @@ def adjust_tensors_for_parallel(hidden_states, *tensors):
"""
outputs = []
for tensor in tensors:
if tensor is not None and hidden_states.shape[0] != tensor.shape[0]:
if tensor is not None and hidden_states.shape[0] >= tensor.shape[0]:
repeats = [1] * len(tensor.shape)
repeats[0] = hidden_states.shape[0] // tensor.shape[0]
new_tensor = tensor.repeat(*repeats)
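
To see why the guard changed from `!=` to `>=`: with a `Parallel` block, `hidden_states` may carry several replicated copies of the batch while other tensors (attention masks, residuals) still have the original batch size, and a tensor that is already larger should simply be left untouched rather than "repeated" zero times. A self-contained sketch that re-implements the helper along the lines of the hunk above (the completion of the loop body is assumed):

```python
import torch


def adjust_tensors_for_parallel(hidden_states, *tensors):
    """Repeat each tensor along dim 0 until its batch size matches hidden_states."""
    outputs = []
    for tensor in tensors:
        if tensor is not None and hidden_states.shape[0] >= tensor.shape[0]:
            repeats = [1] * len(tensor.shape)
            repeats[0] = hidden_states.shape[0] // tensor.shape[0]
            outputs.append(tensor.repeat(*repeats))
        else:
            # With the previous `!=` check, a tensor *larger* than hidden_states would
            # have been repeated 0 times (an empty tensor); `>=` leaves it unchanged.
            outputs.append(tensor)
    return tuple(outputs)


hidden_states = torch.zeros(8, 16, 64)    # batch replicated for 2 parallel adapters
attention_mask = torch.ones(4, 1, 1, 16)  # still the original batch size of 4

(adjusted_mask,) = adjust_tensors_for_parallel(hidden_states, attention_mask)
print(adjusted_mask.shape)  # torch.Size([8, 1, 1, 16])
```
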
6 changes: 5 additions & 1 deletion src/transformers/adapters/layer.py
@@ -5,7 +5,7 @@
import torch
from torch import nn

-from .composition import AdapterCompositionBlock, BatchSplit, Fuse, Parallel, Split, Stack
+from .composition import AdapterCompositionBlock, BatchSplit, Fuse, Parallel, Split, Stack, adjust_tensors_for_parallel
from .configuration import AdapterConfig
from .context import AdapterSetup, ForwardContext
from .modeling import Adapter, BertFusion, ParallelAdapter
@@ -525,6 +525,10 @@ def adapter_layer_forward(self, hidden_states, residual_input, layer_norm):
        Returns:
            torch.Tensor: Output hidden states of the adapter layer.
        """
        # Batch sizes might be different due to prefix tuning w. Parallel block
        (residual_input,) = adjust_tensors_for_parallel(hidden_states, residual_input)
        # Replicate in both directions as residual might be larger (e.g. GPT-J)
        (hidden_states,) = adjust_tensors_for_parallel(residual_input, hidden_states)
        adapter_setup = self.get_active_setup(self.adapters)
        if adapter_setup is not None:
            input_hidden_states = hidden_states