Skip to content

Commit

Permalink
One more attempt!
Browse files Browse the repository at this point in the history
  • Loading branch information
Rocketknight1 committed Dec 5, 2023
1 parent 801cb41 commit 5ac3e4c
Show file tree
Hide file tree
Showing 66 changed files with 10,437 additions and 152 deletions.
186 changes: 182 additions & 4 deletions src/transformers/models/albert/modeling_tf_albert.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def __init__(self, config: AlbertConfig, **kwargs):
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)

def build(self, input_shape: tf.TensorShape):
def build(self, input_shape=None):
with tf.name_scope("word_embeddings"):
self.weight = self.add_weight(
name="weight",
Expand All @@ -168,7 +168,12 @@ def build(self, input_shape: tf.TensorShape):
initializer=get_initializer(self.initializer_range),
)

super().build(input_shape)
if self.built:
return
self.built = True
if getattr(self, "LayerNorm", None) is not None:
with tf.name_scope(self.LayerNorm.name):
self.LayerNorm.build([None, None, self.config.embedding_size])

# Copied from transformers.models.bert.modeling_tf_bert.TFBertEmbeddings.call
def call(
Expand Down Expand Up @@ -246,6 +251,7 @@ def __init__(self, config: AlbertConfig, **kwargs):
# Two different dropout probabilities; see https://github.com/google-research/albert/blob/master/modeling.py#L971-L993
self.attention_dropout = tf.keras.layers.Dropout(rate=config.attention_probs_dropout_prob)
self.output_dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
self.config = config

def transpose_for_scores(self, tensor: tf.Tensor, batch_size: int) -> tf.Tensor:
# Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size]
Expand Down Expand Up @@ -307,6 +313,26 @@ def call(

return outputs

def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "query", None) is not None:
with tf.name_scope(self.query.name):
self.query.build(self.config.hidden_size)
if getattr(self, "key", None) is not None:
with tf.name_scope(self.key.name):
self.key.build(self.config.hidden_size)
if getattr(self, "value", None) is not None:
with tf.name_scope(self.value.name):
self.value.build(self.config.hidden_size)
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build(self.config.hidden_size)
if getattr(self, "LayerNorm", None) is not None:
with tf.name_scope(self.LayerNorm.name):
self.LayerNorm.build([None, None, self.config.hidden_size])


class TFAlbertLayer(tf.keras.layers.Layer):
def __init__(self, config: AlbertConfig, **kwargs):
Expand All @@ -329,6 +355,7 @@ def __init__(self, config: AlbertConfig, **kwargs):
epsilon=config.layer_norm_eps, name="full_layer_layer_norm"
)
self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
self.config = config

def call(
self,
Expand Down Expand Up @@ -356,6 +383,23 @@ def call(

return outputs

def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "attention", None) is not None:
with tf.name_scope(self.attention.name):
self.attention.build(None)
if getattr(self, "ffn", None) is not None:
with tf.name_scope(self.ffn.name):
self.ffn.build(self.config.hidden_size)
if getattr(self, "ffn_output", None) is not None:
with tf.name_scope(self.ffn_output.name):
self.ffn_output.build(self.config.intermediate_size)
if getattr(self, "full_layer_layer_norm", None) is not None:
with tf.name_scope(self.full_layer_layer_norm.name):
self.full_layer_layer_norm.build([None, None, self.config.hidden_size])


class TFAlbertLayerGroup(tf.keras.layers.Layer):
def __init__(self, config: AlbertConfig, **kwargs):
Expand Down Expand Up @@ -399,6 +443,15 @@ def call(

return tuple(v for v in [hidden_states, layer_hidden_states, layer_attentions] if v is not None)

def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "albert_layers", None) is not None:
for layer in self.albert_layers:
with tf.name_scope(layer.name):
layer.build(None)


class TFAlbertTransformer(tf.keras.layers.Layer):
def __init__(self, config: AlbertConfig, **kwargs):
Expand All @@ -416,6 +469,7 @@ def __init__(self, config: AlbertConfig, **kwargs):
self.albert_layer_groups = [
TFAlbertLayerGroup(config, name=f"albert_layer_groups_._{i}") for i in range(config.num_hidden_groups)
]
self.config = config

def call(
self,
Expand Down Expand Up @@ -457,6 +511,18 @@ def call(
last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
)

def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "embedding_hidden_mapping_in", None) is not None:
with tf.name_scope(self.embedding_hidden_mapping_in.name):
self.embedding_hidden_mapping_in.build(self.config.embedding_size)
if getattr(self, "albert_layer_groups", None) is not None:
for layer in self.albert_layer_groups:
with tf.name_scope(layer.name):
layer.build(None)


class TFAlbertPreTrainedModel(TFPreTrainedModel):
"""
Expand Down Expand Up @@ -488,13 +554,21 @@ def __init__(self, config: AlbertConfig, input_embeddings: tf.keras.layers.Layer
# an output-only bias for each token.
self.decoder = input_embeddings

def build(self, input_shape: tf.TensorShape):
def build(self, input_shape=None):
self.bias = self.add_weight(shape=(self.config.vocab_size,), initializer="zeros", trainable=True, name="bias")
self.decoder_bias = self.add_weight(
shape=(self.config.vocab_size,), initializer="zeros", trainable=True, name="decoder/bias"
)

super().build(input_shape)
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build(self.config.hidden_size)
if getattr(self, "LayerNorm", None) is not None:
with tf.name_scope(self.LayerNorm.name):
self.LayerNorm.build([None, None, self.config.embedding_size])

def get_output_embeddings(self) -> tf.keras.layers.Layer:
return self.decoder
Expand Down Expand Up @@ -650,6 +724,20 @@ def call(
attentions=encoder_outputs.attentions,
)

def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "embeddings", None) is not None:
with tf.name_scope(self.embeddings.name):
self.embeddings.build(None)
if getattr(self, "encoder", None) is not None:
with tf.name_scope(self.encoder.name):
self.encoder.build(None)
if getattr(self, "pooler", None) is not None:
with tf.name_scope(self.pooler.name):
self.pooler.build(None)


@dataclass
class TFAlbertForPreTrainingOutput(ModelOutput):
Expand Down Expand Up @@ -825,6 +913,14 @@ def call(

return outputs

def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "albert", None) is not None:
with tf.name_scope(self.albert.name):
self.albert.build(None)


@add_start_docstrings(
"""
Expand Down Expand Up @@ -921,6 +1017,20 @@ def call(
attentions=outputs.attentions,
)

def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "albert", None) is not None:
with tf.name_scope(self.albert.name):
self.albert.build(None)
if getattr(self, "predictions", None) is not None:
with tf.name_scope(self.predictions.name):
self.predictions.build(None)
if getattr(self, "sop_classifier", None) is not None:
with tf.name_scope(self.sop_classifier.name):
self.sop_classifier.build(None)


class TFAlbertSOPHead(tf.keras.layers.Layer):
def __init__(self, config: AlbertConfig, **kwargs):
Expand All @@ -932,13 +1042,22 @@ def __init__(self, config: AlbertConfig, **kwargs):
kernel_initializer=get_initializer(config.initializer_range),
name="classifier",
)
self.config = config

def call(self, pooled_output: tf.Tensor, training: bool) -> tf.Tensor:
dropout_pooled_output = self.dropout(inputs=pooled_output, training=training)
logits = self.classifier(inputs=dropout_pooled_output)

return logits

def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "classifier", None) is not None:
with tf.name_scope(self.classifier.name):
self.classifier.build(self.config.hidden_size)


@add_start_docstrings("""Albert Model with a `language modeling` head on top.""", ALBERT_START_DOCSTRING)
class TFAlbertForMaskedLM(TFAlbertPreTrainedModel, TFMaskedLanguageModelingLoss):
Expand Down Expand Up @@ -1035,6 +1154,17 @@ def call(
attentions=outputs.attentions,
)

def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "albert", None) is not None:
with tf.name_scope(self.albert.name):
self.albert.build(None)
if getattr(self, "predictions", None) is not None:
with tf.name_scope(self.predictions.name):
self.predictions.build(None)


@add_start_docstrings(
"""
Expand All @@ -1058,6 +1188,7 @@ def __init__(self, config: AlbertConfig, *inputs, **kwargs):
self.classifier = tf.keras.layers.Dense(
units=config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
)
self.config = config

@unpack_inputs
@add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
Expand Down Expand Up @@ -1117,6 +1248,17 @@ def call(
attentions=outputs.attentions,
)

def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "albert", None) is not None:
with tf.name_scope(self.albert.name):
self.albert.build(None)
if getattr(self, "classifier", None) is not None:
with tf.name_scope(self.classifier.name):
self.classifier.build(self.config.hidden_size)


@add_start_docstrings(
"""
Expand Down Expand Up @@ -1145,6 +1287,7 @@ def __init__(self, config: AlbertConfig, *inputs, **kwargs):
self.classifier = tf.keras.layers.Dense(
units=config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
)
self.config = config

@unpack_inputs
@add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
Expand Down Expand Up @@ -1200,6 +1343,17 @@ def call(
attentions=outputs.attentions,
)

def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "albert", None) is not None:
with tf.name_scope(self.albert.name):
self.albert.build(None)
if getattr(self, "classifier", None) is not None:
with tf.name_scope(self.classifier.name):
self.classifier.build(self.config.hidden_size)


@add_start_docstrings(
"""
Expand All @@ -1221,6 +1375,7 @@ def __init__(self, config: AlbertConfig, *inputs, **kwargs):
self.qa_outputs = tf.keras.layers.Dense(
units=config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs"
)
self.config = config

@unpack_inputs
@add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
Expand Down Expand Up @@ -1295,6 +1450,17 @@ def call(
attentions=outputs.attentions,
)

def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "albert", None) is not None:
with tf.name_scope(self.albert.name):
self.albert.build(None)
if getattr(self, "qa_outputs", None) is not None:
with tf.name_scope(self.qa_outputs.name):
self.qa_outputs.build(self.config.hidden_size)


@add_start_docstrings(
"""
Expand All @@ -1316,6 +1482,7 @@ def __init__(self, config: AlbertConfig, *inputs, **kwargs):
self.classifier = tf.keras.layers.Dense(
units=1, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
)
self.config = config

@unpack_inputs
@add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
Expand Down Expand Up @@ -1394,3 +1561,14 @@ def call(
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)

def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "albert", None) is not None:
with tf.name_scope(self.albert.name):
self.albert.build(None)
if getattr(self, "classifier", None) is not None:
with tf.name_scope(self.classifier.name):
self.classifier.build(self.config.hidden_size)
Loading

0 comments on commit 5ac3e4c

Please sign in to comment.