Revert "Revert "TF refactor that we'll need later""
This reverts commit 1beb0f3.
Rocketknight1 committed Dec 4, 2023
1 parent 1beb0f3 · commit 4492de7
Showing 1 changed file with 7 additions and 6 deletions.
src/transformers/models/t5/modeling_tf_t5.py (13 changes: 7 additions & 6 deletions)
@@ -75,16 +75,17 @@


 class TFT5LayerNorm(tf.keras.layers.Layer):
-    def __init__(self, epsilon=1e-6, **kwargs):
+    def __init__(self, hidden_size, epsilon=1e-6, **kwargs):
         """
         Construct a layernorm module in the T5 style No bias and no subtraction of mean.
         """
         super().__init__(**kwargs)
         self.variance_epsilon = epsilon
+        self.hidden_size = hidden_size

     def build(self, input_shape):
         """Build shared word embedding layer"""
-        self.weight = self.add_weight("weight", shape=(input_shape[-1],), initializer="ones")
+        self.weight = self.add_weight("weight", shape=(self.hidden_size,), initializer="ones")
         super().build(input_shape)

     def call(self, hidden_states):
@@ -157,7 +158,7 @@ def __init__(self, config, **kwargs):
         else:
             self.DenseReluDense = TFT5DenseActDense(config, name="DenseReluDense")

-        self.layer_norm = TFT5LayerNorm(epsilon=config.layer_norm_epsilon, name="layer_norm")
+        self.layer_norm = TFT5LayerNorm(config.d_model, epsilon=config.layer_norm_epsilon, name="layer_norm")
         self.dropout = tf.keras.layers.Dropout(config.dropout_rate)

     def call(self, hidden_states, training=False):
@@ -439,7 +440,7 @@ def __init__(self, config, has_relative_attention_bias=False, **kwargs):
             has_relative_attention_bias=has_relative_attention_bias,
             name="SelfAttention",
         )
-        self.layer_norm = TFT5LayerNorm(epsilon=config.layer_norm_epsilon, name="layer_norm")
+        self.layer_norm = TFT5LayerNorm(config.d_model, epsilon=config.layer_norm_epsilon, name="layer_norm")
         self.dropout = tf.keras.layers.Dropout(config.dropout_rate)

     def call(
@@ -477,7 +478,7 @@ def __init__(self, config, **kwargs):
             has_relative_attention_bias=False,
             name="EncDecAttention",
         )
-        self.layer_norm = TFT5LayerNorm(epsilon=config.layer_norm_epsilon, name="layer_norm")
+        self.layer_norm = TFT5LayerNorm(config.d_model, epsilon=config.layer_norm_epsilon, name="layer_norm")
         self.dropout = tf.keras.layers.Dropout(config.dropout_rate)

     def call(
@@ -640,7 +641,7 @@ def __init__(self, config, embed_tokens=None, **kwargs):
             TFT5Block(config, has_relative_attention_bias=bool(i == 0), name=f"block_._{i}")
             for i in range(config.num_layers)
         ]
-        self.final_layer_norm = TFT5LayerNorm(epsilon=config.layer_norm_epsilon, name="final_layer_norm")
+        self.final_layer_norm = TFT5LayerNorm(config.d_model, epsilon=config.layer_norm_epsilon, name="final_layer_norm")
         self.dropout = tf.keras.layers.Dropout(config.dropout_rate)

     def _prune_heads(self, heads_to_prune):
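What the revert restores, in practical terms: TFT5LayerNorm receives its hidden size in the constructor instead of reading it from input_shape inside build(), so the weight shape is known before the layer ever sees data. Below is a minimal standalone sketch of the restored layer, assuming TensorFlow 2.x with its bundled Keras 2 (as transformers targeted at the time); the hidden size of 512 and the eager build() call are illustrative and not part of the commit.

```python
import tensorflow as tf


class TFT5LayerNorm(tf.keras.layers.Layer):
    def __init__(self, hidden_size, epsilon=1e-6, **kwargs):
        # T5-style layer norm: no bias and no subtraction of the mean.
        super().__init__(**kwargs)
        self.variance_epsilon = epsilon
        self.hidden_size = hidden_size

    def build(self, input_shape):
        # The weight shape comes from the constructor argument rather than
        # input_shape[-1], so build() no longer depends on seeing real inputs.
        self.weight = self.add_weight("weight", shape=(self.hidden_size,), initializer="ones")
        super().build(input_shape)

    def call(self, hidden_states):
        # Scale by the reciprocal root mean square, then apply the learned weight.
        variance = tf.math.reduce_mean(tf.math.square(hidden_states), axis=-1, keepdims=True)
        hidden_states = hidden_states * tf.math.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states


# Illustrative usage (not from the commit): because the weight shape is fixed
# at construction time, the layer can be built without a dummy forward pass.
layer_norm = TFT5LayerNorm(512, epsilon=1e-6, name="layer_norm")
layer_norm.build(input_shape=None)
x = tf.random.normal((2, 3, 512))
print(layer_norm(x).shape)  # (2, 3, 512)
```

This is also why the four call sites in the diff now pass config.d_model explicitly when constructing their layer norms.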
