Fix bug in BEiT adapters (#439)
jannik-brinkmann committed Oct 27, 2022
1 parent 5df99d4 commit 2aa1d36
Showing 2 changed files with 6 additions and 18 deletions.
src/transformers/adapters/mixins/beit.py: 5 changes (0 additions, 5 deletions)
@@ -17,11 +17,6 @@ def _init_adapter_modules(self):
         self.attention_adapters = AdapterLayer("mh_adapter", self.config)
         self.attention_adapters._init_adapter_modules()


-class BeitOutputAdaptersMixin:
-    """Adds adapters to the BeitOutput module."""
-
-    def _init_adapter_modules(self):
-        self.output_adapters = AdapterLayer("output_adapter", self.config)
-        self.output_adapters._init_adapter_modules()

src/transformers/models/beit/modeling_beit.py: 19 changes (6 additions, 13 deletions)
@@ -28,12 +28,7 @@
 from ...activations import ACT2FN
 from ...adapters.context import ForwardContext
 from ...adapters.lora import Linear as LoRALinear
-from ...adapters.mixins.beit import (
-    BeitLayerAdaptersMixin,
-    BeitModelAdaptersMixin,
-    BeitModelWithHeadsAdaptersMixin,
-    BeitOutputAdaptersMixin,
-)
+from ...adapters.mixins.beit import BeitLayerAdaptersMixin, BeitModelAdaptersMixin, BeitModelWithHeadsAdaptersMixin
 from ...adapters.prefix_tuning import PrefixTuningShim
 from ...modeling_outputs import (
     BaseModelOutput,
@@ -374,19 +369,17 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         return hidden_states


-class BeitOutput(BeitOutputAdaptersMixin, nn.Module):
+class BeitOutput(nn.Module):
     def __init__(self, config: BeitConfig) -> None:
         super().__init__()
         self.config = config

         self.dense = LoRALinear(config.intermediate_size, config.hidden_size, "output", config)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
-        self._init_adapter_modules()

     def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
         hidden_states = self.dense(hidden_states)
         hidden_states = self.dropout(hidden_states)
-        hidden_states = self.output_adapters.adapter_layer_forward(hidden_states, input_tensor, None)
         return hidden_states

@@ -431,14 +424,14 @@ def forward(
         attention_output = self_attention_outputs[0]
         outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights

-        hidden_states = self.attention_adapters.adapter_layer_forward(attention_output, hidden_states, None)
-
         # apply lambda_1 if present
         if self.lambda_1 is not None:
             attention_output = self.lambda_1 * attention_output

         # first residual connection
-        hidden_states = self.drop_path(attention_output) + hidden_states
+        hidden_states = self.attention_adapters.adapter_layer_forward(
+            self.drop_path(attention_output), hidden_states, None
+        )

         # in BEiT, layernorm is also applied after self-attention
         layer_output = self.layernorm_after(hidden_states)
@@ -450,7 +443,7 @@
             layer_output = self.lambda_2 * layer_output

         # second residual connection
-        layer_output = self.drop_path(layer_output) + hidden_states
+        layer_output = self.output_adapters.adapter_layer_forward(self.drop_path(layer_output), hidden_states, None)

         outputs = (layer_output,) + outputs

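To exercise the corrected adapter path, the snippet below is a minimal sketch (not part of this commit) that attaches a bottleneck adapter to a small BEiT model and runs a forward pass. It assumes the adapter-transformers fork these files belong to, where BeitModel picks up add_adapter() and set_active_adapters() through BeitModelAdaptersMixin; the adapter name and the reduced config sizes are arbitrary illustration values.

# Minimal sketch, assuming the adapter-transformers fork (installed as `adapter-transformers`,
# imported as `transformers`); the adapter name and the small config values are arbitrary.
import torch
from transformers import BeitConfig, BeitModel

config = BeitConfig(hidden_size=192, num_hidden_layers=4, num_attention_heads=3, intermediate_size=768)
model = BeitModel(config)

model.add_adapter("beit_bottleneck")          # adds (randomly initialized) adapter weights in every layer
model.set_active_adapters("beit_bottleneck")  # routes the forward pass through the adapters

pixel_values = torch.randn(1, 3, config.image_size, config.image_size)
with torch.no_grad():
    hidden = model(pixel_values=pixel_values).last_hidden_state
print(hidden.shape)  # (batch_size, num_patches + 1, hidden_size)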