Fix T5 adapter layer input (#479)
calpt committed Jan 18, 2023
1 parent cffdf39 commit dcd9f70
Showing 2 changed files with 43 additions and 18 deletions.
49 changes: 34 additions & 15 deletions src/transformers/adapters/layer.py
@@ -512,42 +512,61 @@ def adapter_batchsplit(self, adapter_setup: BatchSplit, hidden_states, input_tensor
         hidden_states = torch.cat(children_hidden, 0)
         return hidden_states

-    def adapter_layer_forward(self, hidden_states, input_tensor, layer_norm):
-        """
-        Called for each forward pass through adapters.
+    def adapter_layer_forward(self, hidden_states, residual_input, layer_norm):
+        """Forward pass through the adapter layer.
+
+        NOTE: This method should only be called if the calling module directly inherits from AdapterLayer. Otherwise,
+        call the regular forward() method.
+
+        Args:
+            hidden_states (torch.Tensor): Input hidden states to the adapter layer.
+            residual_input (torch.Tensor): Residual input to the adapter layer.
+            layer_norm (torch.nn.Module): Transformer layer normalization module to be used by the adapter layer.
+
+        Returns:
+            torch.Tensor: Output hidden states of the adapter layer.
         """
         adapter_setup = self.get_active_setup(self.adapters)
         if adapter_setup is not None:
             input_hidden_states = hidden_states

             if isinstance(adapter_setup, Stack):
-                hidden_states, _, input_tensor = self.adapter_stack(
-                    adapter_setup, hidden_states, input_tensor, layer_norm
+                hidden_states, _, residual_input = self.adapter_stack(
+                    adapter_setup, hidden_states, residual_input, layer_norm
                 )
             elif isinstance(adapter_setup, Fuse):
-                hidden_states = self.adapter_fusion(adapter_setup, hidden_states, input_tensor, layer_norm)
+                hidden_states = self.adapter_fusion(adapter_setup, hidden_states, residual_input, layer_norm)
             elif isinstance(adapter_setup, Split):
-                hidden_states = self.adapter_split(adapter_setup, hidden_states, input_tensor, layer_norm)
+                hidden_states = self.adapter_split(adapter_setup, hidden_states, residual_input, layer_norm)
             elif isinstance(adapter_setup, Parallel):
                 # notice that we are overriding input tensor here to keep the same dim as hidden_states for the residual
                 # in case we were blowing up the batch for parallel processing of multiple adapters for the same input
-                hidden_states, input_tensor = self.adapter_parallel(
-                    adapter_setup, hidden_states, input_tensor, layer_norm
+                hidden_states, residual_input = self.adapter_parallel(
+                    adapter_setup, hidden_states, residual_input, layer_norm
                 )
             elif isinstance(adapter_setup, BatchSplit):
-                hidden_states = self.adapter_batchsplit(adapter_setup, hidden_states, input_tensor, layer_norm)
+                hidden_states = self.adapter_batchsplit(adapter_setup, hidden_states, residual_input, layer_norm)
             else:
                 raise ValueError(f"Invalid adapter setup {adapter_setup}")

             last_adapter = self.adapters[adapter_setup.last()]
-            hidden_states = last_adapter.post_forward(hidden_states, input_hidden_states, input_tensor, layer_norm)
+            hidden_states = last_adapter.post_forward(hidden_states, input_hidden_states, residual_input, layer_norm)

         elif layer_norm:
-            hidden_states = layer_norm(hidden_states + input_tensor)
+            hidden_states = layer_norm(hidden_states + residual_input)
         else:
-            hidden_states = hidden_states + input_tensor
+            hidden_states = hidden_states + residual_input

         return hidden_states

-    def forward(self, hidden_states, input_tensor, layer_norm):
-        return self.adapter_layer_forward(hidden_states, input_tensor, layer_norm)
+    def forward(self, hidden_states, residual_input, layer_norm):
+        """Forward pass through the adapter layer.
+
+        Args:
+            hidden_states (torch.Tensor): Input hidden states to the adapter layer.
+            residual_input (torch.Tensor): Residual input to the adapter layer.
+            layer_norm (torch.nn.Module): Transformer layer normalization module to be used by the adapter layer.
+
+        Returns:
+            torch.Tensor: Output hidden states of the adapter layer.
+        """
+        return self.adapter_layer_forward(hidden_states, residual_input, layer_norm)
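The rename makes the calling contract explicit: hidden_states carries the sublayer output that the adapters act on, residual_input carries the tensor added back as the residual, and layer_norm (if given) is applied last. A minimal standalone sketch of the fallback path with no active adapter, mirroring the elif/else branches above (illustrative only, not the library code):

import torch

def residual_fallback(hidden_states, residual_input, layer_norm=None):
    # hidden_states: output of the Transformer sublayer (attention or feed-forward)
    # residual_input: tensor added back as the residual connection
    if layer_norm is not None:
        return layer_norm(hidden_states + residual_input)
    return hidden_states + residual_input

# Hypothetical shapes: (batch, seq_len, hidden_dim)
x = torch.randn(2, 8, 16)
sublayer_out = torch.randn(2, 8, 16)
assert residual_fallback(sublayer_out, x).shape == x.shape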
12 changes: 9 additions & 3 deletions src/transformers/models/t5/modeling_t5.py
@@ -340,7 +340,9 @@ def __init__(self, config: T5Config):
     def forward(self, hidden_states):
         forwarded_states = self.layer_norm(hidden_states)
         forwarded_states = self.DenseReluDense(forwarded_states)
-        hidden_states = self.adapter_layer_forward(hidden_states, self.dropout(forwarded_states), None)
+        hidden_states = self.adapter_layer_forward(
+            hidden_states=self.dropout(forwarded_states), residual_input=hidden_states, layer_norm=None
+        )
         return hidden_states
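Why the keyword arguments matter here: T5 uses a pre-layer-norm residual block, i.e. output = hidden_states + dropout(FF(layer_norm(hidden_states))). The old positional call handed the untransformed hidden_states to the adapter layer and used the dropped-out feed-forward output as the residual, inverting that pattern; the corrected call passes the dropped-out feed-forward output as hidden_states and the block input as residual_input. A self-contained sketch of the corrected ordering (toy module with made-up dimensions, and nn.LayerNorm standing in for T5's RMS-style norm):

import torch
from torch import nn

class PreNormFFSketch(nn.Module):
    # Toy stand-in for T5LayerFF that follows the corrected residual ordering.
    def __init__(self, d_model=16, d_ff=32, dropout=0.1):
        super().__init__()
        self.layer_norm = nn.LayerNorm(d_model)
        self.ff = nn.Sequential(nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model))
        self.dropout = nn.Dropout(dropout)

    def forward(self, hidden_states):
        forwarded_states = self.ff(self.layer_norm(hidden_states))
        # The transformed output is what an adapter would act on; the block input is the residual.
        return self.adapter_call_sketch(
            hidden_states=self.dropout(forwarded_states), residual_input=hidden_states, layer_norm=None
        )

    @staticmethod
    def adapter_call_sketch(hidden_states, residual_input, layer_norm=None):
        out = hidden_states  # an active adapter would transform this tensor here
        out = out + residual_input
        return layer_norm(out) if layer_norm is not None else out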


@@ -609,7 +611,9 @@ def forward(
             use_cache=use_cache,
             output_attentions=output_attentions,
         )
-        hidden_states = self.adapter_layer_forward(hidden_states, self.dropout(attention_output[0]), None)
+        hidden_states = self.adapter_layer_forward(
+            hidden_states=self.dropout(attention_output[0]), residual_input=hidden_states, layer_norm=None
+        )
         outputs = (hidden_states,) + attention_output[1:]  # add attentions if we output them
         return outputs

@@ -647,7 +651,9 @@ def forward(
             query_length=query_length,
             output_attentions=output_attentions,
         )
-        layer_output = self.adapter_layer_forward(hidden_states, self.dropout(attention_output[0]), None)
+        layer_output = self.adapter_layer_forward(
+            hidden_states=self.dropout(attention_output[0]), residual_input=hidden_states, layer_norm=None
+        )
         outputs = (layer_output,) + attention_output[1:]  # add attentions if we output them
         return outputs

