Fix bug in BEiT adapters #439

Merged: 1 commit, Oct 27, 2022
src/transformers/adapters/mixins/beit.py (5 changes: 0 additions, 5 deletions)

@@ -17,11 +17,6 @@ def _init_adapter_modules(self):
         self.attention_adapters = AdapterLayer("mh_adapter", self.config)
         self.attention_adapters._init_adapter_modules()


-class BeitOutputAdaptersMixin:
-    """Adds adapters to the BeitOutput module."""
-
-    def _init_adapter_modules(self):
-        self.output_adapters = AdapterLayer("output_adapter", self.config)
-        self.output_adapters._init_adapter_modules()

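For context: the `BeitLayerAdaptersMixin` kept by this PR (only partially visible in the hunk above) presumably already initializes both the attention and output `AdapterLayer` instances, which is what makes the separate `BeitOutputAdaptersMixin` redundant. A minimal sketch of that assumption, not the actual file contents:

```python
# Hypothetical reconstruction for illustration; the mixin body is not shown in
# this diff, so the exact names, ordering, and import path are assumptions.
from ..layer import AdapterLayer  # assumed location of AdapterLayer


class BeitLayerAdaptersMixin:
    """Adds adapters to the BeitLayer module (attention and output adapters)."""

    def _init_adapter_modules(self):
        self.attention_adapters = AdapterLayer("mh_adapter", self.config)
        self.attention_adapters._init_adapter_modules()
        self.output_adapters = AdapterLayer("output_adapter", self.config)
        self.output_adapters._init_adapter_modules()
```

Since `mixins/beit.py` has 0 additions here and the updated `BeitLayer.forward` below references `self.output_adapters`, the layer mixin must already provide it.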
src/transformers/models/beit/modeling_beit.py (19 changes: 6 additions, 13 deletions)

@@ -28,12 +28,7 @@
 from ...activations import ACT2FN
 from ...adapters.context import ForwardContext
 from ...adapters.lora import Linear as LoRALinear
-from ...adapters.mixins.beit import (
-    BeitLayerAdaptersMixin,
-    BeitModelAdaptersMixin,
-    BeitModelWithHeadsAdaptersMixin,
-    BeitOutputAdaptersMixin,
-)
+from ...adapters.mixins.beit import BeitLayerAdaptersMixin, BeitModelAdaptersMixin, BeitModelWithHeadsAdaptersMixin
 from ...adapters.prefix_tuning import PrefixTuningShim
 from ...modeling_outputs import (
     BaseModelOutput,
@@ -374,19 +369,17 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         return hidden_states


-class BeitOutput(BeitOutputAdaptersMixin, nn.Module):
+class BeitOutput(nn.Module):
     def __init__(self, config: BeitConfig) -> None:
         super().__init__()
         self.config = config

         self.dense = LoRALinear(config.intermediate_size, config.hidden_size, "output", config)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
-        self._init_adapter_modules()

     def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
         hidden_states = self.dense(hidden_states)
         hidden_states = self.dropout(hidden_states)
-        hidden_states = self.output_adapters.adapter_layer_forward(hidden_states, input_tensor, None)
         return hidden_states

@@ -431,14 +424,14 @@ def forward(
         attention_output = self_attention_outputs[0]
         outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights

-        hidden_states = self.attention_adapters.adapter_layer_forward(attention_output, hidden_states, None)
-
         # apply lambda_1 if present
         if self.lambda_1 is not None:
             attention_output = self.lambda_1 * attention_output

         # first residual connection
-        hidden_states = self.drop_path(attention_output) + hidden_states
+        hidden_states = self.attention_adapters.adapter_layer_forward(
+            self.drop_path(attention_output), hidden_states, None
+        )

         # in BEiT, layernorm is also applied after self-attention
         layer_output = self.layernorm_after(hidden_states)
@@ -450,7 +443,7 @@
             layer_output = self.lambda_2 * layer_output

         # second residual connection
-        layer_output = self.drop_path(layer_output) + hidden_states
+        layer_output = self.output_adapters.adapter_layer_forward(self.drop_path(layer_output), hidden_states, None)

         outputs = (layer_output,) + outputs
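As far as the diff shows, `adapter_layer_forward(hidden_states, residual_input, layer_norm)` is expected to apply the residual connection itself. The old code called it on the raw `attention_output` before `lambda_1` scaling and `drop_path`, and then added `self.drop_path(attention_output) + hidden_states` again on top; similarly, the output adapter ran inside `BeitOutput`, before `lambda_2` and the second manual residual. After this change, both adapter calls wrap the scaled, drop-path'd tensor and replace the manual residual adds. A rough sketch of the assumed `adapter_layer_forward` semantics (the real implementation also handles adapter composition, layer norm, and the no-adapter case):

```python
from typing import Optional

import torch


def adapter_layer_forward_sketch(
    hidden_states: torch.Tensor,
    residual_input: torch.Tensor,
    layer_norm: Optional[torch.nn.Module] = None,
) -> torch.Tensor:
    """Simplified stand-in for AdapterLayer.adapter_layer_forward; an assumption, not the real code."""
    # With an adapter active, a bottleneck would transform hidden_states here.
    output = hidden_states + residual_input  # the residual connection happens inside the call
    if layer_norm is not None:
        output = layer_norm(output)
    return output
```

With no adapters active this reduces to exactly the residual the layer needs, so `adapter_layer_forward_sketch(self.drop_path(attention_output), hidden_states)` computes the same `drop_path(attention_output) + hidden_states` as before, while an active adapter now sees the properly scaled attention output instead of being bypassed by a second residual add.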