Skip to content

Commit

Permalink
Proper build() methods for TF (#27794)
Browse files (browse the repository at this point in the history)
* Add a convenience method for building in your own name scope

* Second attempt at auto layer building

* Revert "Second attempt at auto layer building"

This reverts commit e03a3aa.

* Attempt #3

* Revert "Attempt #3"

This reverts commit b9df7a0.

* Add missing attributes that we're going to need later

* Add some attributes we're going to need later

* A fourth attempt! Feel the power flow through you!

* Revert "A fourth attempt! Feel the power flow through you!"

This reverts commit 6bf4aaf.

* Add more values we'll need later

* TF refactor that we'll need later

* Revert "TF refactor that we'll need later"

This reverts commit ca07202.

* Revert "Revert "TF refactor that we'll need later""

This reverts commit 1beb0f3.

* make fixup

* Attempt five!

* Revert "Attempt five!"

This reverts commit 3302207.

* Attempt six - this time don't add empty methods

* Revert "Attempt six - this time don't add empty methods"

This reverts commit 67d6012.

* Attempt seven - better base model class detection!

* Revert "Attempt seven - better base model class detection!"

This reverts commit 5f14845.

* Another attribute we'll need later

* Try again with the missing attribute!

* Revert "Try again with the missing attribute!"

This reverts commit 760c6f3.

* This is the attempt that will pierce the heavens!

* Revert "This is the attempt that will pierce the heavens!"

This reverts commit c868bb6.

* Attempt seven - snag list is steadily decreasing

* Revert "Attempt seven - snag list is steadily decreasing"

This reverts commit 46fbd97.

* Attempt eight - will an empty snag list do it?

* Revert "Attempt eight - will an empty snag list do it?"

This reverts commit 7c8a3c2.

* Fixes to Hubert issues that cause problems later

* Trying again with Conv1D/SeparableConv fixes

* Revert "Trying again with Conv1D/SeparableConv fixes"

This reverts commit 55092bc.

* Apply the build shape fixes to Wav2Vec2 as well

* One more attempt!

* Revert "One more attempt!"

This reverts commit 5ac3e4c.

* Another attempt!

* Revert "Another attempt!"

This reverts commit ea16d89.

* Let's see how many failures we get without the internal build method

* Fix OpenAI

* Fix MobileBERT

* (Mostly) fix GroupViT

* Fix BLIP

* One more BLIP fix

* One more BLIP fix!

* Fix Regnet

* Finally fully fix GroupViT

* Fix Data2Vec and add the new AdaptivePool

* Fix Segformer

* Fix Albert

* Fix Deberta/DebertaV2

* Fix XLM

* Actually fix XLM

* Fix Flaubert

* Fix lxmert

* Fix Resnet

* Fix ConvBERT

* Fix ESM

* Fix Convnext / ConvnextV2

* Fix SAM

* Fix Efficientformer

* Fix LayoutLMv3

* Fix speech_to_text

* Fix mpnet and mobilevit

* Fix Swin

* Fix CTRL

* Fix CVT

* Fix DPR

* Fix Wav2Vec2

* Fix T5

* Fix Hubert

* Fix GPT2

* Fix Whisper

* Fix DeiT

* Fix the encoder-decoder / dual-encoder classes

* make fix-copies

* build in name scope

* Fix summarization test

* Fix tied weight names for BART + Blenderbot

* Fix tied weight name building

* Fix to TFESM weight building

* Update TF SAM

* Expand all the shapes out into Big Boy Shapes
  • Loading branch information
Rocketknight1 authored Dec 14, 2023
1 parent 52c3788 commit 050e0b4
Show file tree
Hide file tree
Showing 73 changed files with 11,039 additions and 503 deletions.
43 changes: 25 additions & 18 deletions src/transformers/modeling_tf_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
from huggingface_hub import Repository, list_repo_files
from keras import backend as K
from packaging.version import parse
from tensorflow.python.util.keras_deps import get_call_context_function

from . import DataCollatorWithPadding, DefaultDataCollator
from .activations_tf import get_tf_activation
Expand Down Expand Up @@ -1122,6 +1121,10 @@ def dummy_inputs(self) -> Dict[str, tf.Tensor]:
)
return dummies

def build_in_name_scope(self):
    """
    Call `build()` from inside a name scope named after this object.

    Enters `tf.name_scope(self.name)` and, within it, invokes `self.build(input_shape=None)`.
    """
    own_scope = tf.name_scope(self.name)
    with own_scope:
        self.build(input_shape=None)

@property
def framework(self) -> str:
"""
Expand All @@ -1130,15 +1133,7 @@ def framework(self) -> str:
return "tf"

def build(self, input_shape=None):
call_context = get_call_context_function()
if self.built or call_context().in_call:
self.built = True
else:
self.built = True
# Set the serving spec quickly to ensure that Keras doesn't use the specific dummy input shapes as the spec
# Setting it in build() allows users to override the shape when loading a non-pretrained model from config
self._set_save_spec(self.input_signature)
self(self.dummy_inputs, training=False)
pass # This is just here to make sure we don't call the superclass build()

def __init__(self, config, *inputs, **kwargs):
super().__init__(*inputs, **kwargs)
Expand Down Expand Up @@ -1869,7 +1864,7 @@ def set_input_embeddings(self, value):
main_layer.set_input_embeddings(value)
except AttributeError:
logger.info("Building the model")
self.build()
self.build_in_name_scope()
main_layer.set_input_embeddings(value)

def get_output_embeddings(self) -> Union[None, tf.keras.layers.Layer]:
Expand All @@ -1886,7 +1881,7 @@ def get_output_embeddings(self) -> Union[None, tf.keras.layers.Layer]:
return lm_head.get_output_embeddings()
except AttributeError:
logger.info("Building the model")
self.build()
self.build_in_name_scope()

return lm_head().get_output_embeddings()

Expand All @@ -1906,7 +1901,7 @@ def set_output_embeddings(self, value):
lm_head.set_output_embeddings(value)
except AttributeError:
logger.info("Building the model")
self.build()
self.build_in_name_scope()
lm_head.set_output_embeddings(value)

def get_output_layer_with_bias(self) -> Union[None, tf.keras.layers.Layer]:
Expand Down Expand Up @@ -1944,7 +1939,7 @@ def get_bias(self) -> Union[None, Dict[str, tf.Variable]]:
try:
return lm_head.get_bias()
except AttributeError:
self.build()
self.build_in_name_scope()

return lm_head.get_bias()
return None
Expand All @@ -1962,7 +1957,7 @@ def set_bias(self, value):
try:
lm_head.set_bias(value)
except AttributeError:
self.build()
self.build_in_name_scope()
lm_head.set_bias(value)

def get_lm_head(self) -> tf.keras.layers.Layer:
Expand Down Expand Up @@ -2049,7 +2044,7 @@ def _get_word_embedding_weight(model, embedding_layer):
# The reason why the attributes don't exist might be
# because the model is not built, so retry getting
# the argument after building the model
model.build()
model.build_in_name_scope()

embeds = getattr(embedding_layer, "weight", None)
if embeds is not None:
Expand Down Expand Up @@ -2914,9 +2909,9 @@ def from_pretrained(
# we might need to extend the variable scope for composite models
if load_weight_prefix is not None:
with tf.compat.v1.variable_scope(load_weight_prefix):
model.build() # build the network with dummy inputs
model.build_in_name_scope() # build the network with dummy inputs
else:
model.build() # build the network with dummy inputs
model.build_in_name_scope() # build the network with dummy inputs

if safetensors_from_pt:
from .modeling_tf_pytorch_utils import load_pytorch_state_dict_in_tf2_model
Expand Down Expand Up @@ -3215,6 +3210,9 @@ def __init__(self, nf, nx, initializer_range=0.02, **kwargs):
self.initializer_range = initializer_range

def build(self, input_shape):
if self.built:
return
self.built = True
self.weight = self.add_weight(
"weight", shape=[self.nx, self.nf], initializer=get_initializer(self.initializer_range)
)
Expand Down Expand Up @@ -3398,6 +3396,7 @@ def __init__(self, config: PretrainedConfig, initializer_range: float = 0.02, **
self.has_last_dropout = hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0
if self.has_last_dropout:
self.last_dropout = tf.keras.layers.Dropout(config.summary_last_dropout)
self.hidden_size = config.hidden_size

def call(self, inputs, cls_index=None, training=False):
if not isinstance(inputs, (dict, tuple, list)):
Expand Down Expand Up @@ -3450,6 +3449,14 @@ def call(self, inputs, cls_index=None, training=False):

return output

def build(self, input_shape):
    """
    Build the optional `summary` sub-layer, at most once.

    Does nothing when `self.built` is already set. Otherwise marks the layer as built and, if a non-`None`
    `summary` attribute exists, builds it inside a `"summary"` name scope, passing `self.hidden_size` as the
    build shape.

    Args:
        input_shape: Accepted for signature compatibility; not read by this method.
    """
    if not self.built:
        # Flip the flag before building the child so a repeated call is a no-op.
        self.built = True
        summary_layer = getattr(self, "summary", None)
        if summary_layer is not None:
            with tf.name_scope("summary"):
                summary_layer.build(self.hidden_size)


def get_initializer(initializer_range: float = 0.02) -> tf.keras.initializers.TruncatedNormal:
"""
Expand Down
Loading

0 comments on commit 050e0b4

Please sign in to comment.