Hardcode GELU as the intermediate activation for ESM (huggingface#22892)
* Hardcode GELU as the intermediate activation for ESM

* Sneak a quick fix to the weight tying in too

* Make the call to gelu explicit
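
For orientation, a minimal sketch of what the activation change amounts to (hypothetical class and size names, not the committed code): the intermediate layer stops resolving an activation from config.hidden_act via get_tf_activation and always applies GELU after its dense projection.

# Hypothetical sketch, not the committed modeling_tf_esm.py code: the ESM
# intermediate block now applies GELU unconditionally after the dense layer.
import tensorflow as tf


class IntermediateSketch(tf.keras.layers.Layer):
    def __init__(self, intermediate_size: int, **kwargs):
        super().__init__(**kwargs)
        self.dense = tf.keras.layers.Dense(units=intermediate_size, name="dense")

    def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
        hidden_states = self.dense(hidden_states)
        # Hardcoded activation; tf.nn.gelu defaults to the exact (erf-based) GELU.
        return tf.nn.gelu(hidden_states)


# Illustrative sizes only: batch 2, sequence length 8, hidden 320, intermediate 1280.
layer = IntermediateSketch(intermediate_size=1280)
print(layer(tf.random.normal((2, 8, 320))).shape)  # (2, 8, 1280)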
Rocketknight1 authored and novice03 committed Jun 23, 2023
1 parent c4033d6 commit f590440
Showing 1 changed file with 17 additions and 22 deletions.
39 changes: 17 additions & 22 deletions src/transformers/models/esm/modeling_tf_esm.py
@@ -37,7 +37,6 @@
     TFSequenceClassificationLoss,
     TFTokenClassificationLoss,
     get_initializer,
-    get_tf_activation,
     shape_list,
     unpack_inputs,
 )
@@ -476,24 +475,19 @@ def call(
         return outputs


-# Copied from transformers.models.bert.modeling_tf_bert.TFBertIntermediate with Bert->Esm
 class TFEsmIntermediate(tf.keras.layers.Layer):
     def __init__(self, config: EsmConfig, **kwargs):
         super().__init__(**kwargs)

         self.dense = tf.keras.layers.Dense(
-            units=config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
+            units=config.intermediate_size,
+            kernel_initializer=get_initializer(config.initializer_range),
+            name="dense",
         )
-
-        if isinstance(config.hidden_act, str):
-            self.intermediate_act_fn = get_tf_activation(config.hidden_act)
-        else:
-            self.intermediate_act_fn = config.hidden_act

     def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
         hidden_states = self.dense(inputs=hidden_states)
-        hidden_states = self.intermediate_act_fn(hidden_states)
-
+        hidden_states = tf.nn.gelu(hidden_states)
         return hidden_states


@@ -1216,23 +1210,21 @@ def __init__(self, config, name=None):
         )

         self.layer_norm = LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm")
-
-        self.decoder = None
+        if config.tie_word_embeddings:
+            self.decoder = None
+        else:
+            self.decoder = Dense(
+                config.vocab_size,
+                kernel_initializer=get_initializer(config.initializer_range),
+                name="decoder",
+                use_bias=False,
+            )
         self.config = config

     def build(self, input_shape):
         super().build(input_shape)
         # Separate bias to match the PT model and allow weight cross-loading to work
         # Put it in the build so it gets the right name when adding it as a weight
-        if not self.config.tie_word_embeddings:
-            if self.decoder is not None:
-                raise ValueError("Expected decoder not to be initialized before build when not tying weights!")
-            self.decoder = self.add_weight(
-                "decoder.weight",
-                shape=(self.config.hidden_size, self.config.vocab_size),
-                initializer=get_initializer(self.config.initializer_range),
-                trainable=True,
-            )
         self.bias = self.add_weight("bias", shape=(self.config.vocab_size,), initializer="zeros", trainable=True)

     def get_bias(self):
@@ -1244,7 +1236,10 @@ def call(self, features):
         x = self.layer_norm(x)

         # project back to size of vocabulary with bias
-        x = tf.matmul(x, self.decoder, transpose_b=True) + self.bias
+        if self.config.tie_word_embeddings:
+            x = tf.matmul(x, self.decoder, transpose_b=True) + self.bias
+        else:
+            x = self.decoder(x) + self.bias
         return x


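To close, a hedged sketch of the two decoder paths the reworked LM head distinguishes (illustrative sizes and tensor names, not taken from the commit): with tied word embeddings the features are projected back to the vocabulary by a transposed matmul against the shared embedding matrix, while the untied case now goes through the standalone Dense decoder created in __init__.

# Illustrative sketch of the tied vs. untied projection paths; vocab_size and
# hidden_size are made-up numbers, not ESM's real configuration.
import tensorflow as tf

vocab_size, hidden_size = 33, 64
features = tf.random.normal((2, 8, hidden_size))  # LM-head features after layer_norm
bias = tf.zeros((vocab_size,))                    # separate bias, as added in build()

# Tied path: reuse an embedding matrix of shape (vocab_size, hidden_size) and
# project with a transposed matmul, matching the tf.matmul branch in the diff.
embedding_matrix = tf.random.normal((vocab_size, hidden_size))
tied_logits = tf.matmul(features, embedding_matrix, transpose_b=True) + bias

# Untied path: an independent Dense(vocab_size, use_bias=False) acts as the decoder.
decoder = tf.keras.layers.Dense(vocab_size, use_bias=False, name="decoder")
untied_logits = decoder(features) + bias

print(tied_logits.shape, untied_logits.shape)  # (2, 8, 33) (2, 8, 33)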
