Revert "TF refactor that we'll need later"
This reverts commit ca07202.
Rocketknight1 committed Dec 4, 2023
1 parent ca07202 commit 1beb0f3
Showing 1 changed file with 6 additions and 7 deletions.
src/transformers/models/t5/modeling_tf_t5.py — 13 changes: 6 additions & 7 deletions
@@ -75,17 +75,16 @@
 
 
 class TFT5LayerNorm(tf.keras.layers.Layer):
-    def __init__(self, hidden_size, epsilon=1e-6, **kwargs):
+    def __init__(self, epsilon=1e-6, **kwargs):
         """
         Construct a layernorm module in the T5 style No bias and no subtraction of mean.
         """
         super().__init__(**kwargs)
         self.variance_epsilon = epsilon
-        self.hidden_size = hidden_size
 
     def build(self, input_shape):
         """Build shared word embedding layer"""
-        self.weight = self.add_weight("weight", shape=(self.hidden_size,), initializer="ones")
+        self.weight = self.add_weight("weight", shape=(input_shape[-1],), initializer="ones")
         super().build(input_shape)
 
     def call(self, hidden_states):
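For context, here is a minimal self-contained sketch of what the restored class looks like end to end. The constructor and build() mirror the "+" lines above; the call() body is not part of this diff and is written here to follow the T5-style formula the docstring describes (scale by the reciprocal root mean square, no bias, no mean subtraction). The class name T5StyleLayerNorm is illustrative rather than taken from the repository.

import tensorflow as tf

class T5StyleLayerNorm(tf.keras.layers.Layer):
    """Sketch of a T5-style layer norm: no bias and no subtraction of the mean."""

    def __init__(self, epsilon=1e-6, **kwargs):
        super().__init__(**kwargs)
        self.variance_epsilon = epsilon

    def build(self, input_shape):
        # The weight shape is inferred from the last input dimension at build time,
        # so callers no longer need to pass hidden_size / config.d_model.
        self.weight = self.add_weight("weight", shape=(input_shape[-1],), initializer="ones")
        super().build(input_shape)

    def call(self, hidden_states):
        # Scale by the reciprocal root mean square of the activations.
        variance = tf.math.reduce_mean(tf.math.square(hidden_states), axis=-1, keepdims=True)
        hidden_states = hidden_states * tf.math.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states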
@@ -158,7 +157,7 @@ def __init__(self, config, **kwargs):
         else:
             self.DenseReluDense = TFT5DenseActDense(config, name="DenseReluDense")
 
-        self.layer_norm = TFT5LayerNorm(config.d_model, epsilon=config.layer_norm_epsilon, name="layer_norm")
+        self.layer_norm = TFT5LayerNorm(epsilon=config.layer_norm_epsilon, name="layer_norm")
         self.dropout = tf.keras.layers.Dropout(config.dropout_rate)
 
     def call(self, hidden_states, training=False):
@@ -440,7 +439,7 @@ def __init__(self, config, has_relative_attention_bias=False, **kwargs):
             has_relative_attention_bias=has_relative_attention_bias,
             name="SelfAttention",
         )
-        self.layer_norm = TFT5LayerNorm(config.d_model, epsilon=config.layer_norm_epsilon, name="layer_norm")
+        self.layer_norm = TFT5LayerNorm(epsilon=config.layer_norm_epsilon, name="layer_norm")
         self.dropout = tf.keras.layers.Dropout(config.dropout_rate)
 
     def call(
@@ -478,7 +477,7 @@ def __init__(self, config, **kwargs):
             has_relative_attention_bias=False,
             name="EncDecAttention",
         )
-        self.layer_norm = TFT5LayerNorm(config.d_model, epsilon=config.layer_norm_epsilon, name="layer_norm")
+        self.layer_norm = TFT5LayerNorm(epsilon=config.layer_norm_epsilon, name="layer_norm")
         self.dropout = tf.keras.layers.Dropout(config.dropout_rate)
 
     def call(
@@ -641,7 +640,7 @@ def __init__(self, config, embed_tokens=None, **kwargs):
             TFT5Block(config, has_relative_attention_bias=bool(i == 0), name=f"block_._{i}")
             for i in range(config.num_layers)
         ]
-        self.final_layer_norm = TFT5LayerNorm(config.d_model, epsilon=config.layer_norm_epsilon, name="final_layer_norm")
+        self.final_layer_norm = TFT5LayerNorm(epsilon=config.layer_norm_epsilon, name="final_layer_norm")
         self.dropout = tf.keras.layers.Dropout(config.dropout_rate)
 
     def _prune_heads(self, heads_to_prune):
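The remaining hunks are the matching call-site change: once the weight shape is inferred in build(), config.d_model no longer has to be passed wherever the layer is constructed. A rough usage sketch under that assumption, reusing the T5StyleLayerNorm sketch above (batch, sequence, and hidden sizes are arbitrary example values):

import tensorflow as tf

# Construct the layer without a hidden size, as the reverted call sites now do.
layer_norm = T5StyleLayerNorm(epsilon=1e-6, name="layer_norm")

# The weight is created lazily on the first call, from the input's last dimension.
hidden_states = tf.random.normal((2, 7, 512))
normed = layer_norm(hidden_states)
print(layer_norm.weight.shape)  # (512,)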
