Add torch f-t5
cccntu committed Jul 14, 2021
1 parent c7cd32c commit 54d12d5
Showing 1 changed file with 27 additions and 1 deletion.
src/transformers/models/f_t5/modeling_t5.py (27 additions, 1 deletion)
@@ -523,10 +523,36 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value):
         return outputs
 
 
+class FT5FourierTransform(nn.Module):
+    def __init__(self, config: T5Config, has_relative_attention_bias=False):
+        super().__init__()
+        self.fourier_transform = torch.fft.fftn
+    def forward(
+        self,
+        hidden_states,
+        mask=None,
+        key_value_states=None,
+        position_bias=None,
+        past_key_value=None,
+        layer_head_mask=None,
+        query_length=None,
+        use_cache=False,
+        output_attentions=False,
+    ):
+        attn_output = self.fourier_transform(hidden_states).real
+        outputs = (attn_output,) + (None,) + (None,)
+
+        if output_attentions:
+            outputs = outputs + (None,)
+        return outputs
+
 class T5LayerSelfAttention(nn.Module):
     def __init__(self, config, has_relative_attention_bias=False):
         super().__init__()
-        self.SelfAttention = T5Attention(config, has_relative_attention_bias=has_relative_attention_bias)
+        if config.is_decoder:
+            self.SelfAttention = T5Attention(config, has_relative_attention_bias=has_relative_attention_bias)
+        else:
+            self.SelfAttention = FT5FourierTransform(config, has_relative_attention_bias=has_relative_attention_bias)
         self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
         self.dropout = nn.Dropout(config.dropout_rate)
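Note (not part of the commit): the sketch below is a minimal standalone illustration of the FNet-style token mixing that FT5FourierTransform performs. It shows that taking an FFT of the hidden states and keeping only the real part yields a tensor with the same shape as a self-attention output, which is why the module can be swapped in behind T5LayerSelfAttention for the encoder. Be aware that torch.fft.fftn with no dim argument transforms every dimension, including the batch dimension; the two-dimensional variant over only the sequence and hidden dimensions, as described in the FNet paper, is included for comparison. Tensor sizes here are illustrative.

# Standalone sketch (not from the commit): FNet-style mixing vs. the commit's fftn call.
import torch

batch, seq_len, d_model = 2, 8, 16
hidden_states = torch.randn(batch, seq_len, d_model)

# What the commit's code does: fftn over all dimensions (batch included), real part only.
mixed_all_dims = torch.fft.fftn(hidden_states).real

# FNet-paper variant for comparison: transform only the sequence and hidden dimensions.
mixed_seq_hidden = torch.fft.fftn(hidden_states, dim=(-2, -1)).real

# Either way the output shape matches the input, so it can stand in for attn_output
# in the (attn_output, None, None) tuple returned by FT5FourierTransform.forward.
assert mixed_all_dims.shape == hidden_states.shape
assert mixed_seq_hidden.shape == hidden_states.shape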
