From 1beb0f39f293ed9c27594575e1c849aadeb15c13 Mon Sep 17 00:00:00 2001
From: Matt
Date: Mon, 4 Dec 2023 18:02:45 +0000
Subject: [PATCH] Revert "TF refactor that we'll need later"

This reverts commit ca07202fb5b7b7436b893baa8d688b4f348ea7b9.
---
 src/transformers/models/t5/modeling_tf_t5.py | 13 ++++++-------
 1 file changed, 6 insertions(+), 7 deletions(-)

diff --git a/src/transformers/models/t5/modeling_tf_t5.py b/src/transformers/models/t5/modeling_tf_t5.py
index 9272acf047dd47..f0de49645a9b5f 100644
--- a/src/transformers/models/t5/modeling_tf_t5.py
+++ b/src/transformers/models/t5/modeling_tf_t5.py
@@ -75,17 +75,16 @@ class TFT5LayerNorm(tf.keras.layers.Layer):
-    def __init__(self, hidden_size, epsilon=1e-6, **kwargs):
+    def __init__(self, epsilon=1e-6, **kwargs):
         """
         Construct a layernorm module in the T5 style No bias and no subtraction of mean.
         """
         super().__init__(**kwargs)
         self.variance_epsilon = epsilon
-        self.hidden_size = hidden_size
 
     def build(self, input_shape):
         """Build shared word embedding layer"""
-        self.weight = self.add_weight("weight", shape=(self.hidden_size,), initializer="ones")
+        self.weight = self.add_weight("weight", shape=(input_shape[-1],), initializer="ones")
         super().build(input_shape)
 
     def call(self, hidden_states):
@@ -158,7 +157,7 @@ def __init__(self, config, **kwargs):
         else:
             self.DenseReluDense = TFT5DenseActDense(config, name="DenseReluDense")
 
-        self.layer_norm = TFT5LayerNorm(config.d_model, epsilon=config.layer_norm_epsilon, name="layer_norm")
+        self.layer_norm = TFT5LayerNorm(epsilon=config.layer_norm_epsilon, name="layer_norm")
         self.dropout = tf.keras.layers.Dropout(config.dropout_rate)
 
     def call(self, hidden_states, training=False):
@@ -440,7 +439,7 @@ def __init__(self, config, has_relative_attention_bias=False, **kwargs):
             has_relative_attention_bias=has_relative_attention_bias,
             name="SelfAttention",
         )
-        self.layer_norm = TFT5LayerNorm(config.d_model, epsilon=config.layer_norm_epsilon, name="layer_norm")
+        self.layer_norm = TFT5LayerNorm(epsilon=config.layer_norm_epsilon, name="layer_norm")
         self.dropout = tf.keras.layers.Dropout(config.dropout_rate)
 
     def call(
@@ -478,7 +477,7 @@ def __init__(self, config, **kwargs):
             has_relative_attention_bias=False,
             name="EncDecAttention",
         )
-        self.layer_norm = TFT5LayerNorm(config.d_model, epsilon=config.layer_norm_epsilon, name="layer_norm")
+        self.layer_norm = TFT5LayerNorm(epsilon=config.layer_norm_epsilon, name="layer_norm")
        self.dropout = tf.keras.layers.Dropout(config.dropout_rate)
 
     def call(
@@ -641,7 +640,7 @@ def __init__(self, config, embed_tokens=None, **kwargs):
             TFT5Block(config, has_relative_attention_bias=bool(i == 0), name=f"block_._{i}")
             for i in range(config.num_layers)
         ]
-        self.final_layer_norm = TFT5LayerNorm(config.d_model, epsilon=config.layer_norm_epsilon, name="final_layer_norm")
+        self.final_layer_norm = TFT5LayerNorm(epsilon=config.layer_norm_epsilon, name="final_layer_norm")
         self.dropout = tf.keras.layers.Dropout(config.dropout_rate)
 
     def _prune_heads(self, heads_to_prune):
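
For context, the hunks above remove the explicit hidden_size constructor argument and let Keras infer the weight shape from input_shape at build time, so callers no longer need to pass config.d_model. Below is a minimal, self-contained sketch of how TFT5LayerNorm reads after this revert. Note that the call body shown here follows the standard T5 RMS-norm formulation and is not part of the hunks above, and the toy shapes in the usage lines are illustrative only.

    import tensorflow as tf


    class TFT5LayerNorm(tf.keras.layers.Layer):
        def __init__(self, epsilon=1e-6, **kwargs):
            # No hidden_size argument: the weight shape is deferred to build().
            super().__init__(**kwargs)
            self.variance_epsilon = epsilon

        def build(self, input_shape):
            # The scale weight takes its size from the last input dimension.
            self.weight = self.add_weight("weight", shape=(input_shape[-1],), initializer="ones")
            super().build(input_shape)

        def call(self, hidden_states):
            # T5-style RMS norm: no mean subtraction and no bias.
            variance = tf.math.reduce_mean(tf.math.square(hidden_states), axis=-1, keepdims=True)
            hidden_states = hidden_states * tf.math.rsqrt(variance + self.variance_epsilon)
            return self.weight * hidden_states


    # The weight is created lazily on the first call, once the input shape is known.
    layer_norm = TFT5LayerNorm(epsilon=1e-6, name="layer_norm")
    outputs = layer_norm(tf.random.normal((2, 8, 512)))  # builds a (512,) weight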