From 43de4105d12dc75686aecf3c3e840ef2fa240c29 Mon Sep 17 00:00:00 2001 From: Matt Date: Tue, 6 Jun 2023 18:30:51 +0100 Subject: [PATCH] Move TF building to an actual build() method (#23760) * A fun new PR where I break the entire codebase again * A fun new PR where I break the entire codebase again * Handle cross-attention * Move calls to model(model.dummy_inputs) to the new build() method * Seeing what fails with the build context thing * make fix-copies * Let's see what fails with new build methods * Fix the pytorch crossload build calls * Fix the overridden build methods in vision_text_dual_encoder * Make sure all our build methods set self.built or call super().build(), which also sets it * make fix-copies * Remove finished TODO * Tentatively remove unneeded (?) line * Transpose b in deberta correctly and remove unused threading local * Get rid of build_with_dummies and all it stands for * Rollback some changes to TF-PT crossloading * Correctly call super().build() --- src/transformers/modeling_tf_pytorch_utils.py | 3 - src/transformers/modeling_tf_utils.py | 35 +++++--- .../models/blip/modeling_tf_blip.py | 6 +- .../models/blip/modeling_tf_blip_text.py | 3 +- .../models/clip/modeling_tf_clip.py | 6 +- .../models/convbert/modeling_tf_convbert.py | 3 +- .../models/convnext/modeling_tf_convnext.py | 2 +- .../models/ctrl/modeling_tf_ctrl.py | 2 +- .../data2vec/modeling_tf_data2vec_vision.py | 2 +- .../models/deberta/modeling_tf_deberta.py | 9 +-- .../models/dpr/modeling_tf_dpr.py | 6 +- .../models/groupvit/modeling_tf_groupvit.py | 2 +- .../models/led/modeling_tf_led.py | 81 ++++++++++--------- .../longformer/modeling_tf_longformer.py | 81 ++++++++++--------- .../mobilebert/modeling_tf_mobilebert.py | 1 + .../models/sam/modeling_tf_sam.py | 1 + .../modeling_tf_vision_text_dual_encoder.py | 3 +- .../models/vit_mae/modeling_tf_vit_mae.py | 7 +- .../models/whisper/modeling_tf_whisper.py | 11 +-- .../models/xlnet/modeling_tf_xlnet.py | 1 + tests/models/bart/test_modeling_tf_bart.py | 2 +- .../test_modeling_tf_encoder_decoder.py | 4 +- tests/models/gpt2/test_modeling_tf_gpt2.py | 2 +- tests/models/opt/test_modeling_tf_opt.py | 2 +- ...test_modeling_tf_vision_encoder_decoder.py | 4 +- .../whisper/test_modeling_tf_whisper.py | 2 +- tests/test_modeling_tf_common.py | 16 ++-- 27 files changed, 159 insertions(+), 138 deletions(-) diff --git a/src/transformers/modeling_tf_pytorch_utils.py b/src/transformers/modeling_tf_pytorch_utils.py index 3b1c030699b..3d08be1a8a1 100644 --- a/src/transformers/modeling_tf_pytorch_utils.py +++ b/src/transformers/modeling_tf_pytorch_utils.py @@ -341,9 +341,6 @@ def load_pytorch_state_dict_in_tf2_model( K.batch_set_value(weight_value_tuples) - if tf_inputs is not None: - tf_model(tf_inputs, training=False) # Make sure restore ops are run - logger.info(f"Loaded {tf_loaded_numel:,} parameters in the TF 2.0 model.") unexpected_keys = list(all_pytorch_weights) diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py index ba9f4ec87f4..e278fceb4a4 100644 --- a/src/transformers/modeling_tf_utils.py +++ b/src/transformers/modeling_tf_utils.py @@ -40,7 +40,12 @@ from .configuration_utils import PretrainedConfig from .dynamic_module_utils import custom_object_save from .generation import GenerationConfig, TFGenerationMixin -from .tf_utils import expand_1d, load_attributes_from_hdf5_group, save_attributes_to_hdf5_group, shape_list +from .tf_utils import ( + expand_1d, + load_attributes_from_hdf5_group, + save_attributes_to_hdf5_group, + 
shape_list, +) from .utils import ( SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, @@ -69,11 +74,14 @@ if parse(tf.__version__).minor >= 13: from keras import backend as K from keras.__internal__ import KerasTensor + from keras.engine.base_layer_utils import call_context elif parse(tf.__version__).minor >= 11: from keras import backend as K + from keras.engine.base_layer_utils import call_context from keras.engine.keras_tensor import KerasTensor else: from tensorflow.python.keras import backend as K + from tensorflow.python.keras.engine import call_context from tensorflow.python.keras.engine.keras_tensor import KerasTensor @@ -1140,6 +1148,13 @@ def framework(self) -> str: """ return "tf" + def build(self, input_shape=None): + if self.built or call_context().in_call: + self.built = True + else: + self(self.dummy_inputs, training=False) + self.built = True + def __init__(self, config, *inputs, **kwargs): super().__init__(*inputs, **kwargs) if not isinstance(config, PretrainedConfig): @@ -1867,7 +1882,7 @@ def set_input_embeddings(self, value): main_layer.set_input_embeddings(value) except AttributeError: logger.info("Building the model") - self(self.dummy_inputs) + self.build() main_layer.set_input_embeddings(value) def get_output_embeddings(self) -> Union[None, tf.keras.layers.Layer]: @@ -1884,7 +1899,7 @@ def get_output_embeddings(self) -> Union[None, tf.keras.layers.Layer]: return lm_head.get_output_embeddings() except AttributeError: logger.info("Building the model") - self(self.dummy_inputs) + self.build() return lm_head().get_output_embeddings() @@ -1904,7 +1919,7 @@ def set_output_embeddings(self, value): lm_head.set_output_embeddings(value) except AttributeError: logger.info("Building the model") - self(self.dummy_inputs) + self.build() lm_head.set_output_embeddings(value) def get_output_layer_with_bias(self) -> Union[None, tf.keras.layers.Layer]: @@ -1942,7 +1957,7 @@ def get_bias(self) -> Union[None, Dict[str, tf.Variable]]: try: return lm_head.get_bias() except AttributeError: - self(self.dummy_inputs) + self.build() return lm_head.get_bias() return None @@ -1960,7 +1975,7 @@ def set_bias(self, value): try: lm_head.set_bias(value) except AttributeError: - self(self.dummy_inputs) + self.build() lm_head.set_bias(value) def get_lm_head(self) -> tf.keras.layers.Layer: @@ -2047,7 +2062,7 @@ def _get_word_embedding_weight(model, embedding_layer): # The reason why the attributes don't exist might be # because the model is not built, so retry getting # the argument after building the model - model(model.dummy_inputs) + model.build() embeds = getattr(embedding_layer, "weight", None) if embeds is not None: @@ -2870,9 +2885,9 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): # we might need to extend the variable scope for composite models if load_weight_prefix is not None: with tf.compat.v1.variable_scope(load_weight_prefix): - model(model.dummy_inputs) # build the network with dummy inputs + model.build() # build the network with dummy inputs else: - model(model.dummy_inputs) # build the network with dummy inputs + model.build() # build the network with dummy inputs if safetensors_from_pt: from .modeling_tf_pytorch_utils import load_pytorch_state_dict_in_tf2_model @@ -2925,8 +2940,6 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): "If you tried to load a TF 2.0 model from a PyTorch checkpoint, please set from_pt=True. 
" ) - model(model.dummy_inputs) # Make sure restore ops are run - if cls._keys_to_ignore_on_load_missing is not None: for pat in cls._keys_to_ignore_on_load_missing: missing_keys = [k for k in missing_keys if re.search(pat, k) is None] diff --git a/src/transformers/models/blip/modeling_tf_blip.py b/src/transformers/models/blip/modeling_tf_blip.py index 428151ea9a3..b94c005eb48 100644 --- a/src/transformers/models/blip/modeling_tf_blip.py +++ b/src/transformers/models/blip/modeling_tf_blip.py @@ -258,6 +258,7 @@ def build(self, input_shape): trainable=True, name="position_embedding", ) + super().build(input_shape) def call(self, pixel_values: tf.Tensor) -> tf.Tensor: # Input is channels-first, we transpose. PyTorch transposes after the conv because PyTorch @@ -282,7 +283,7 @@ def __init__(self, config: BlipTextConfig, **kwargs): self.config = config - def build(self, input_shape: tf.TensorShape): + def build(self, input_shape: tf.TensorShape = None): with tf.name_scope("token_embedding"): self.weight = self.add_weight( shape=(self.config.vocab_size, self.embed_dim), @@ -757,13 +758,14 @@ def __init__(self, config: BlipConfig, *args, **kwargs): self.config = config - def build(self, input_shape): + def build(self, input_shape=None): self.logit_scale = self.add_weight( name="logit_scale", shape=[], initializer=tf.keras.initializers.Constant(self.config.logit_scale_init_value), trainable=True, ) + super().build(input_shape) @unpack_inputs def call( diff --git a/src/transformers/models/blip/modeling_tf_blip_text.py b/src/transformers/models/blip/modeling_tf_blip_text.py index 19ebdac62e2..6fef07e8a3f 100644 --- a/src/transformers/models/blip/modeling_tf_blip_text.py +++ b/src/transformers/models/blip/modeling_tf_blip_text.py @@ -543,8 +543,9 @@ def __init__(self, config, **kwargs): ) self.config = config - def build(self, input_shape): + def build(self, input_shape=None): self.bias = self.add_weight(name="bias", shape=(self.config.vocab_size,), initializer="zeros", trainable=True) + super().build(input_shape) def call(self, hidden_states): hidden_states = self.transform(hidden_states) diff --git a/src/transformers/models/clip/modeling_tf_clip.py b/src/transformers/models/clip/modeling_tf_clip.py index 778f1ed2c92..009e474440d 100644 --- a/src/transformers/models/clip/modeling_tf_clip.py +++ b/src/transformers/models/clip/modeling_tf_clip.py @@ -151,7 +151,7 @@ def __init__(self, config: CLIPVisionConfig, **kwargs): name="patch_embedding", ) - def build(self, input_shape: tf.TensorShape): + def build(self, input_shape: tf.TensorShape = None): factor = self.config.initializer_factor self.class_embedding = self.add_weight( @@ -204,7 +204,7 @@ def __init__(self, config: CLIPTextConfig, **kwargs): self.config = config - def build(self, input_shape: tf.TensorShape): + def build(self, input_shape: tf.TensorShape = None): with tf.name_scope("token_embedding"): self.weight = self.add_weight( shape=(self.config.vocab_size, self.embed_dim), @@ -739,7 +739,7 @@ def __init__(self, config: CLIPConfig, **kwargs): name="text_projection", ) - def build(self, input_shape: tf.TensorShape): + def build(self, input_shape: tf.TensorShape = None): self.logit_scale = self.add_weight( shape=(1,), initializer=tf.keras.initializers.Constant(self.config.logit_scale_init_value), diff --git a/src/transformers/models/convbert/modeling_tf_convbert.py b/src/transformers/models/convbert/modeling_tf_convbert.py index 9b2bf2383bb..4beb01cb78b 100644 --- a/src/transformers/models/convbert/modeling_tf_convbert.py +++ 
b/src/transformers/models/convbert/modeling_tf_convbert.py @@ -346,7 +346,7 @@ def __init__(self, input_size, output_size, num_groups, kernel_initializer, **kw self.group_in_dim = self.input_size // self.num_groups self.group_out_dim = self.output_size // self.num_groups - def build(self, input_shape): + def build(self, input_shape=None): self.kernel = self.add_weight( "kernel", shape=[self.group_out_dim, self.group_in_dim, self.num_groups], @@ -357,6 +357,7 @@ def build(self, input_shape): self.bias = self.add_weight( "bias", shape=[self.output_size], initializer=self.kernel_initializer, dtype=self.dtype, trainable=True ) + super().build(input_shape) def call(self, hidden_states): batch_size = shape_list(hidden_states)[0] diff --git a/src/transformers/models/convnext/modeling_tf_convnext.py b/src/transformers/models/convnext/modeling_tf_convnext.py index 23a77a928ec..1629988900a 100644 --- a/src/transformers/models/convnext/modeling_tf_convnext.py +++ b/src/transformers/models/convnext/modeling_tf_convnext.py @@ -155,7 +155,7 @@ def __init__(self, config, dim, drop_path=0.0, **kwargs): else tf.keras.layers.Activation("linear", name="drop_path") ) - def build(self, input_shape: tf.TensorShape): + def build(self, input_shape: tf.TensorShape = None): # PT's `nn.Parameters` must be mapped to a TF layer weight to inherit the same name hierarchy (and vice-versa) self.layer_scale_parameter = ( self.add_weight( diff --git a/src/transformers/models/ctrl/modeling_tf_ctrl.py b/src/transformers/models/ctrl/modeling_tf_ctrl.py index 4dd9e739250..18c8c2c5883 100644 --- a/src/transformers/models/ctrl/modeling_tf_ctrl.py +++ b/src/transformers/models/ctrl/modeling_tf_ctrl.py @@ -576,7 +576,7 @@ def __init__(self, config, input_embeddings, **kwargs): # an output-only bias for each token. 
self.input_embeddings = input_embeddings - def build(self, input_shape): + def build(self, input_shape=None): self.bias = self.add_weight(shape=(self.config.vocab_size,), initializer="zeros", trainable=True, name="bias") super().build(input_shape) diff --git a/src/transformers/models/data2vec/modeling_tf_data2vec_vision.py b/src/transformers/models/data2vec/modeling_tf_data2vec_vision.py index 8ebb8c68ff8..ee8bec20a01 100644 --- a/src/transformers/models/data2vec/modeling_tf_data2vec_vision.py +++ b/src/transformers/models/data2vec/modeling_tf_data2vec_vision.py @@ -464,7 +464,7 @@ def __init__( ) self.init_values = config.layer_scale_init_value - def build(self, input_shape: tf.TensorShape): + def build(self, input_shape: tf.TensorShape = None): if self.init_values > 0: self.lambda_1 = self.add_weight( shape=(self.config.hidden_size), diff --git a/src/transformers/models/deberta/modeling_tf_deberta.py b/src/transformers/models/deberta/modeling_tf_deberta.py index 57e6ea8b1e9..5fc4ce783cc 100644 --- a/src/transformers/models/deberta/modeling_tf_deberta.py +++ b/src/transformers/models/deberta/modeling_tf_deberta.py @@ -593,11 +593,10 @@ def call( else: def linear(w, b, x): - return tf.cond( - b is not None, - lambda: tf.matmul(x, w, transpose_b=True) + tf.transpose(b), - lambda: tf.matmul(x, w, transpose_b=True), - ) + out = tf.matmul(x, w, transpose_b=True) + if b is not None: + out += tf.transpose(b) + return out ws = tf.split( tf.transpose(self.in_proj.weight[0]), num_or_size_splits=self.num_attention_heads * 3, axis=0 diff --git a/src/transformers/models/dpr/modeling_tf_dpr.py b/src/transformers/models/dpr/modeling_tf_dpr.py index 759e22c8c71..837537a5cad 100644 --- a/src/transformers/models/dpr/modeling_tf_dpr.py +++ b/src/transformers/models/dpr/modeling_tf_dpr.py @@ -532,7 +532,7 @@ def get_input_embeddings(self): try: return self.ctx_encoder.bert_model.get_input_embeddings() except AttributeError: - self(self.dummy_inputs) + self.build() return self.ctx_encoder.bert_model.get_input_embeddings() @unpack_inputs @@ -613,7 +613,7 @@ def get_input_embeddings(self): try: return self.question_encoder.bert_model.get_input_embeddings() except AttributeError: - self(self.dummy_inputs) + self.build() return self.question_encoder.bert_model.get_input_embeddings() @unpack_inputs @@ -693,7 +693,7 @@ def get_input_embeddings(self): try: return self.span_predictor.encoder.bert_model.get_input_embeddings() except AttributeError: - self(self.dummy_inputs) + self.build() return self.span_predictor.encoder.bert_model.get_input_embeddings() @unpack_inputs diff --git a/src/transformers/models/groupvit/modeling_tf_groupvit.py b/src/transformers/models/groupvit/modeling_tf_groupvit.py index 5c989356a5d..e6d6c1d3252 100644 --- a/src/transformers/models/groupvit/modeling_tf_groupvit.py +++ b/src/transformers/models/groupvit/modeling_tf_groupvit.py @@ -538,7 +538,7 @@ def __init__(self, config: GroupViTTextConfig, **kwargs): self.config = config - def build(self, input_shape: tf.TensorShape): + def build(self, input_shape: tf.TensorShape = None): with tf.name_scope("token_embedding"): self.weight = self.add_weight( shape=(self.config.vocab_size, self.embed_dim), diff --git a/src/transformers/models/led/modeling_tf_led.py b/src/transformers/models/led/modeling_tf_led.py index 6e962ea4934..a661f4b703b 100644 --- a/src/transformers/models/led/modeling_tf_led.py +++ b/src/transformers/models/led/modeling_tf_led.py @@ -135,6 +135,7 @@ def call(self, input_shape: tf.TensorShape, past_key_values_length: int = 0): 
class TFLEDEncoderSelfAttention(tf.keras.layers.Layer): def __init__(self, config, layer_id, **kwargs): super().__init__(**kwargs) + self.config = config if config.hidden_size % config.num_attention_heads != 0: raise ValueError( @@ -191,6 +192,16 @@ def __init__(self, config, layer_id, **kwargs): self.one_sided_attn_window_size = attention_window // 2 + def build(self, input_shape=None): + if not self.built: + with tf.name_scope("query_global"): + self.query_global.build((self.config.hidden_size,)) + with tf.name_scope("key_global"): + self.key_global.build((self.config.hidden_size,)) + with tf.name_scope("value_global"): + self.value_global.build((self.config.hidden_size,)) + super().build(input_shape) + def call( self, inputs, @@ -271,9 +282,8 @@ def call( ) = self._get_global_attn_indices(is_index_global_attn) # this function is only relevant for global attention - attn_scores = tf.cond( - is_global_attn, - lambda: self._concat_with_global_key_attn_probs( + if is_global_attn: + attn_scores = self._concat_with_global_key_attn_probs( attn_scores=attn_scores, query_vectors=query_vectors, key_vectors=key_vectors, @@ -281,26 +291,24 @@ def call( is_index_global_attn_nonzero=is_index_global_attn_nonzero, is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero, is_local_index_no_global_attn_nonzero=is_local_index_no_global_attn_nonzero, - ), - lambda: attn_scores, - ) + ) + attn_probs = stable_softmax(attn_scores, axis=-1) # softmax sometimes inserts NaN if all positions are masked, replace them with 0 # Make sure to create a mask with the proper shape: # if is_global_attn==True => [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1] # if is_global_attn==False => [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 1] - masked_index = tf.cond( - is_global_attn, - lambda: tf.tile( + if is_global_attn: + masked_index = tf.tile( is_index_masked[:, :, None, None], (1, 1, self.num_heads, self.one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1), - ), - lambda: tf.tile( + ) + else: + masked_index = tf.tile( is_index_masked[:, :, None, None], (1, 1, self.num_heads, self.one_sided_attn_window_size * 2 + 1), - ), - ) + ) attn_probs = tf.where( masked_index, tf.zeros(shape_list(masked_index), dtype=attn_probs.dtype), @@ -324,19 +332,19 @@ def call( value_vectors = tf.reshape(value_vectors, (batch_size, seq_len, self.num_heads, self.head_dim)) # if global attention, compute sum of global and local attn - attn_output = tf.cond( - is_global_attn, - lambda: self._compute_attn_output_with_global_indices( + + if is_global_attn: + attn_output = self._compute_attn_output_with_global_indices( value_vectors=value_vectors, attn_probs=attn_probs, max_num_global_attn_indices=max_num_global_attn_indices, is_index_global_attn_nonzero=is_index_global_attn_nonzero, is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero, - ), - lambda: self._sliding_chunks_matmul_attn_probs_value( + ) + else: + attn_output = self._sliding_chunks_matmul_attn_probs_value( attn_probs, value_vectors, self.one_sided_attn_window_size - ), - ) + ) tf.debugging.assert_equal( shape_list(attn_output), [batch_size, seq_len, self.num_heads, self.head_dim], message="Unexpected size" @@ -345,10 +353,8 @@ def call( attn_output = tf.reshape(attn_output, (batch_size, seq_len, embed_dim)) # compute value for global attention and overwrite to attention output - # TODO: remove the redundant computation - attn_output, global_attn_probs = 
tf.cond( - is_global_attn, - lambda: self._compute_global_attn_output_from_hidden( + if is_global_attn: + attn_output, global_attn_probs = self._compute_global_attn_output_from_hidden( attn_output=attn_output, hidden_states=hidden_states, max_num_global_attn_indices=max_num_global_attn_indices, @@ -358,25 +364,25 @@ def call( is_local_index_no_global_attn_nonzero=is_local_index_no_global_attn_nonzero, is_index_masked=is_index_masked, training=training, - ), - lambda: (attn_output, tf.zeros((batch_size, self.num_heads, max_num_global_attn_indices, seq_len))), - ) + ) + else: + # Leave attn_output unchanged + global_attn_probs = tf.zeros((batch_size, self.num_heads, max_num_global_attn_indices, seq_len)) # make sure that local attention probabilities are set to 0 for indices of global attn # Make sure to create a mask with the proper shape: # if is_global_attn==True => [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1] # if is_global_attn==False => [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 1] - masked_global_attn_index = tf.cond( - is_global_attn, - lambda: tf.tile( + if is_global_attn: + masked_global_attn_index = tf.tile( is_index_global_attn[:, :, None, None], (1, 1, self.num_heads, self.one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1), - ), - lambda: tf.tile( + ) + else: + masked_global_attn_index = tf.tile( is_index_global_attn[:, :, None, None], (1, 1, self.num_heads, self.one_sided_attn_window_size * 2 + 1), - ), - ) + ) attn_probs = tf.where( masked_global_attn_index, tf.zeros(shape_list(masked_global_attn_index), dtype=attn_probs.dtype), @@ -1864,13 +1870,10 @@ def _pad_to_window_size( input_ids = tf.pad(input_ids, paddings, constant_values=pad_token_id) if inputs_embeds is not None: - - def pad_embeddings(): + if padding_len > 0: input_ids_padding = tf.fill((batch_size, padding_len), pad_token_id) inputs_embeds_padding = self.embed_tokens(input_ids_padding) - return tf.concat([inputs_embeds, inputs_embeds_padding], axis=-2) - - inputs_embeds = tf.cond(tf.math.greater(padding_len, 0), pad_embeddings, lambda: inputs_embeds) + inputs_embeds = tf.concat([inputs_embeds, inputs_embeds_padding], axis=-2) attention_mask = tf.pad(attention_mask, paddings, constant_values=False) # no attention on the padding tokens diff --git a/src/transformers/models/longformer/modeling_tf_longformer.py b/src/transformers/models/longformer/modeling_tf_longformer.py index 60cee2a83e8..2bfe79e21dd 100644 --- a/src/transformers/models/longformer/modeling_tf_longformer.py +++ b/src/transformers/models/longformer/modeling_tf_longformer.py @@ -652,6 +652,7 @@ def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool class TFLongformerSelfAttention(tf.keras.layers.Layer): def __init__(self, config, layer_id, **kwargs): super().__init__(**kwargs) + self.config = config if config.hidden_size % config.num_attention_heads != 0: raise ValueError( @@ -708,6 +709,16 @@ def __init__(self, config, layer_id, **kwargs): self.one_sided_attn_window_size = attention_window // 2 + def build(self, input_shape=None): + if not self.built: + with tf.name_scope("query_global"): + self.query_global.build((self.config.hidden_size,)) + with tf.name_scope("key_global"): + self.key_global.build((self.config.hidden_size,)) + with tf.name_scope("value_global"): + self.value_global.build((self.config.hidden_size,)) + super().build(input_shape) + def call( self, inputs, @@ -788,9 +799,8 @@ def call( ) = 
self._get_global_attn_indices(is_index_global_attn) # this function is only relevant for global attention - attn_scores = tf.cond( - is_global_attn, - lambda: self._concat_with_global_key_attn_probs( + if is_global_attn: + attn_scores = self._concat_with_global_key_attn_probs( attn_scores=attn_scores, query_vectors=query_vectors, key_vectors=key_vectors, @@ -798,26 +808,24 @@ def call( is_index_global_attn_nonzero=is_index_global_attn_nonzero, is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero, is_local_index_no_global_attn_nonzero=is_local_index_no_global_attn_nonzero, - ), - lambda: attn_scores, - ) + ) + attn_probs = stable_softmax(attn_scores, axis=-1) # softmax sometimes inserts NaN if all positions are masked, replace them with 0 # Make sure to create a mask with the proper shape: # if is_global_attn==True => [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1] # if is_global_attn==False => [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 1] - masked_index = tf.cond( - is_global_attn, - lambda: tf.tile( + if is_global_attn: + masked_index = tf.tile( is_index_masked[:, :, None, None], (1, 1, self.num_heads, self.one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1), - ), - lambda: tf.tile( + ) + else: + masked_index = tf.tile( is_index_masked[:, :, None, None], (1, 1, self.num_heads, self.one_sided_attn_window_size * 2 + 1), - ), - ) + ) attn_probs = tf.where( masked_index, tf.zeros(shape_list(masked_index), dtype=attn_probs.dtype), @@ -841,19 +849,19 @@ def call( value_vectors = tf.reshape(value_vectors, (batch_size, seq_len, self.num_heads, self.head_dim)) # if global attention, compute sum of global and local attn - attn_output = tf.cond( - is_global_attn, - lambda: self._compute_attn_output_with_global_indices( + + if is_global_attn: + attn_output = self._compute_attn_output_with_global_indices( value_vectors=value_vectors, attn_probs=attn_probs, max_num_global_attn_indices=max_num_global_attn_indices, is_index_global_attn_nonzero=is_index_global_attn_nonzero, is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero, - ), - lambda: self._sliding_chunks_matmul_attn_probs_value( + ) + else: + attn_output = self._sliding_chunks_matmul_attn_probs_value( attn_probs, value_vectors, self.one_sided_attn_window_size - ), - ) + ) tf.debugging.assert_equal( shape_list(attn_output), [batch_size, seq_len, self.num_heads, self.head_dim], message="Unexpected size" @@ -862,10 +870,8 @@ def call( attn_output = tf.reshape(attn_output, (batch_size, seq_len, embed_dim)) # compute value for global attention and overwrite to attention output - # TODO: remove the redundant computation - attn_output, global_attn_probs = tf.cond( - is_global_attn, - lambda: self._compute_global_attn_output_from_hidden( + if is_global_attn: + attn_output, global_attn_probs = self._compute_global_attn_output_from_hidden( attn_output=attn_output, hidden_states=hidden_states, max_num_global_attn_indices=max_num_global_attn_indices, @@ -875,25 +881,25 @@ def call( is_local_index_no_global_attn_nonzero=is_local_index_no_global_attn_nonzero, is_index_masked=is_index_masked, training=training, - ), - lambda: (attn_output, tf.zeros((batch_size, self.num_heads, max_num_global_attn_indices, seq_len))), - ) + ) + else: + # Leave attn_output unchanged + global_attn_probs = tf.zeros((batch_size, self.num_heads, max_num_global_attn_indices, seq_len)) # make sure that local attention probabilities are set to 0 
for indices of global attn # Make sure to create a mask with the proper shape: # if is_global_attn==True => [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1] # if is_global_attn==False => [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 1] - masked_global_attn_index = tf.cond( - is_global_attn, - lambda: tf.tile( + if is_global_attn: + masked_global_attn_index = tf.tile( is_index_global_attn[:, :, None, None], (1, 1, self.num_heads, self.one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1), - ), - lambda: tf.tile( + ) + else: + masked_global_attn_index = tf.tile( is_index_global_attn[:, :, None, None], (1, 1, self.num_heads, self.one_sided_attn_window_size * 2 + 1), - ), - ) + ) attn_probs = tf.where( masked_global_attn_index, tf.zeros(shape_list(masked_global_attn_index), dtype=attn_probs.dtype), @@ -1828,13 +1834,10 @@ def _pad_to_window_size( position_ids = tf.pad(position_ids, paddings, constant_values=pad_token_id) if inputs_embeds is not None: - - def pad_embeddings(): + if padding_len > 0: input_ids_padding = tf.cast(tf.fill((batch_size, padding_len), self.pad_token_id), tf.int64) inputs_embeds_padding = self.embeddings(input_ids_padding) - return tf.concat([inputs_embeds, inputs_embeds_padding], axis=-2) - - inputs_embeds = tf.cond(tf.math.greater(padding_len, 0), pad_embeddings, lambda: inputs_embeds) + inputs_embeds = tf.concat([inputs_embeds, inputs_embeds_padding], axis=-2) attention_mask = tf.pad(attention_mask, paddings, constant_values=False) # no attention on the padding tokens token_type_ids = tf.pad(token_type_ids, paddings, constant_values=0) # pad with token_type_id = 0 diff --git a/src/transformers/models/mobilebert/modeling_tf_mobilebert.py b/src/transformers/models/mobilebert/modeling_tf_mobilebert.py index c454a8b35db..bc508a47984 100644 --- a/src/transformers/models/mobilebert/modeling_tf_mobilebert.py +++ b/src/transformers/models/mobilebert/modeling_tf_mobilebert.py @@ -151,6 +151,7 @@ def __init__(self, feat_size, epsilon=None, **kwargs): def build(self, input_shape): self.bias = self.add_weight("bias", shape=[self.feat_size], initializer="zeros") self.weight = self.add_weight("weight", shape=[self.feat_size], initializer="ones") + super().build(input_shape) def call(self, inputs: tf.Tensor): return inputs * self.weight + self.bias diff --git a/src/transformers/models/sam/modeling_tf_sam.py b/src/transformers/models/sam/modeling_tf_sam.py index 46710b32984..b1ef53eb653 100644 --- a/src/transformers/models/sam/modeling_tf_sam.py +++ b/src/transformers/models/sam/modeling_tf_sam.py @@ -581,6 +581,7 @@ def build(self, input_shape): initializer=tf.keras.initializers.RandomNormal(mean=0.0, stddev=self.scale), trainable=False, ) + super().build(input_shape) def call(self, input_coords, input_shape=None): """Positionally encode points that are normalized to [0,1].""" diff --git a/src/transformers/models/vision_text_dual_encoder/modeling_tf_vision_text_dual_encoder.py b/src/transformers/models/vision_text_dual_encoder/modeling_tf_vision_text_dual_encoder.py index 6e0c65a813f..34349c86617 100644 --- a/src/transformers/models/vision_text_dual_encoder/modeling_tf_vision_text_dual_encoder.py +++ b/src/transformers/models/vision_text_dual_encoder/modeling_tf_vision_text_dual_encoder.py @@ -225,6 +225,7 @@ def build(self, input_shape=None): # Build in the build() method to make sure the names are right initializer = 
tf.keras.initializers.Constant(self.config.logit_scale_init_value) self.logit_scale = self.add_weight(shape=(1,), initializer=initializer, name="logit_scale") + super().build(input_shape) @classmethod def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): @@ -591,7 +592,7 @@ def from_vision_text_pretrained( if text_model.name != "text_model": raise ValueError("text model must be created with the name `text_model`.") - model(model.dummy_inputs) # Ensure model is fully built + model.build() # Ensure model is fully built return model diff --git a/src/transformers/models/vit_mae/modeling_tf_vit_mae.py b/src/transformers/models/vit_mae/modeling_tf_vit_mae.py index e7d7770bcf2..21898bbe83b 100644 --- a/src/transformers/models/vit_mae/modeling_tf_vit_mae.py +++ b/src/transformers/models/vit_mae/modeling_tf_vit_mae.py @@ -966,11 +966,8 @@ def patchify(self, pixel_values): """ patch_size, num_channels = self.config.patch_size, self.config.num_channels # make sure channels are last - pixel_values = tf.cond( - tf.math.equal(shape_list(pixel_values)[1], num_channels), - lambda: tf.transpose(pixel_values, perm=(0, 2, 3, 1)), - lambda: pixel_values, - ) + if shape_list(pixel_values)[1] == num_channels: + pixel_values = tf.transpose(pixel_values, perm=(0, 2, 3, 1)) # sanity checks tf.debugging.assert_equal( diff --git a/src/transformers/models/whisper/modeling_tf_whisper.py b/src/transformers/models/whisper/modeling_tf_whisper.py index b8cd87f67ef..3fab2af4376 100644 --- a/src/transformers/models/whisper/modeling_tf_whisper.py +++ b/src/transformers/models/whisper/modeling_tf_whisper.py @@ -766,11 +766,12 @@ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, past_key_ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] batch_size, seq_len = input_shape[0], input_shape[1] - combined_attention_mask = tf.cond( - tf.math.greater(seq_len, 1), - lambda: _make_causal_mask(input_shape, past_key_values_length=past_key_values_length), - lambda: _expand_mask(tf.ones((batch_size, seq_len + past_key_values_length)), tgt_len=seq_len), - ) + if seq_len > 1: + combined_attention_mask = _make_causal_mask(input_shape, past_key_values_length=past_key_values_length) + else: + combined_attention_mask = _expand_mask( + tf.ones((batch_size, seq_len + past_key_values_length)), tgt_len=seq_len + ) if attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] diff --git a/src/transformers/models/xlnet/modeling_tf_xlnet.py b/src/transformers/models/xlnet/modeling_tf_xlnet.py index c5f3805ec98..a0e6a8c2aa5 100644 --- a/src/transformers/models/xlnet/modeling_tf_xlnet.py +++ b/src/transformers/models/xlnet/modeling_tf_xlnet.py @@ -476,6 +476,7 @@ def build(self, input_shape): self.mask_emb = self.add_weight( shape=(1, 1, self.d_model), initializer=initializer, trainable=True, name="mask_emb" ) + super().build(input_shape) def _prune_heads(self, heads_to_prune): raise NotImplementedError diff --git a/tests/models/bart/test_modeling_tf_bart.py b/tests/models/bart/test_modeling_tf_bart.py index c113011c567..2068c5f6651 100644 --- a/tests/models/bart/test_modeling_tf_bart.py +++ b/tests/models/bart/test_modeling_tf_bart.py @@ -328,7 +328,7 @@ def test_save_load_after_resize_token_embeddings(self): old_total_size = config.vocab_size new_total_size = old_total_size + new_tokens_size model = model_class(config=copy.deepcopy(config)) # `resize_token_embeddings` mutates `config` - model(model.dummy_inputs) # builds the embeddings layer + model.build() 
model.resize_token_embeddings(new_total_size) # fetch the output for an input exclusively made of new members of the vocabulary diff --git a/tests/models/encoder_decoder/test_modeling_tf_encoder_decoder.py b/tests/models/encoder_decoder/test_modeling_tf_encoder_decoder.py index aa22e961f65..ab5da3d41e6 100644 --- a/tests/models/encoder_decoder/test_modeling_tf_encoder_decoder.py +++ b/tests/models/encoder_decoder/test_modeling_tf_encoder_decoder.py @@ -1070,9 +1070,9 @@ def test_encoder_decoder_save_load_from_encoder_decoder(self): # create two random BERT models for bert2bert & initialize weights (+cross_attention weights) encoder = TFBertModel(config.encoder) - encoder(encoder.dummy_inputs) + encoder.build() decoder = TFBertLMHeadModel(config.decoder) - decoder(decoder.dummy_inputs) + decoder.build() encoder_decoder_orig = TFEncoderDecoderModel(encoder=encoder, decoder=decoder) diff --git a/tests/models/gpt2/test_modeling_tf_gpt2.py b/tests/models/gpt2/test_modeling_tf_gpt2.py index c69ab863373..c0aeeddafd1 100644 --- a/tests/models/gpt2/test_modeling_tf_gpt2.py +++ b/tests/models/gpt2/test_modeling_tf_gpt2.py @@ -463,7 +463,7 @@ def test_onnx_runtime_optimize(self): continue model = model_class(config) - model(model.dummy_inputs) + model.build() onnx_model_proto, _ = tf2onnx.convert.from_keras(model, opset=self.onnx_min_opset) diff --git a/tests/models/opt/test_modeling_tf_opt.py b/tests/models/opt/test_modeling_tf_opt.py index 85514c9d720..0fd3a22bf37 100644 --- a/tests/models/opt/test_modeling_tf_opt.py +++ b/tests/models/opt/test_modeling_tf_opt.py @@ -194,7 +194,7 @@ def _get_word_embedding_weight(model, embedding_layer): else: # Here we build the word embeddings weights if not exists. # And then we retry to get the attribute once built. - model(model.dummy_inputs) + model.build() if hasattr(embedding_layer, "weight"): return embedding_layer.weight else: diff --git a/tests/models/vision_encoder_decoder/test_modeling_tf_vision_encoder_decoder.py b/tests/models/vision_encoder_decoder/test_modeling_tf_vision_encoder_decoder.py index 04062014b84..7a29bbd2211 100644 --- a/tests/models/vision_encoder_decoder/test_modeling_tf_vision_encoder_decoder.py +++ b/tests/models/vision_encoder_decoder/test_modeling_tf_vision_encoder_decoder.py @@ -729,9 +729,9 @@ def test_encoder_decoder_save_load_from_encoder_decoder(self): # create two random ViT/GPT2 models for vit-gpt2 & initialize weights (+cross_attention weights) encoder = TFViTModel(config.encoder) - encoder(encoder.dummy_inputs) + encoder.build() decoder = TFGPT2LMHeadModel(config.decoder) - decoder(decoder.dummy_inputs) + decoder.build() encoder_decoder_orig = TFVisionEncoderDecoderModel(encoder=encoder, decoder=decoder) diff --git a/tests/models/whisper/test_modeling_tf_whisper.py b/tests/models/whisper/test_modeling_tf_whisper.py index b9ad982176e..0783bd67bf4 100644 --- a/tests/models/whisper/test_modeling_tf_whisper.py +++ b/tests/models/whisper/test_modeling_tf_whisper.py @@ -281,7 +281,7 @@ def test_save_load_strict(self): for model_class in self.all_model_classes: model = model_class(config) - model(model.dummy_inputs) + model.build() with tempfile.TemporaryDirectory() as tmpdirname: model.save_pretrained(tmpdirname, saved_model=False) diff --git a/tests/test_modeling_tf_common.py b/tests/test_modeling_tf_common.py index 69363686837..586a2a761dc 100644 --- a/tests/test_modeling_tf_common.py +++ b/tests/test_modeling_tf_common.py @@ -348,7 +348,7 @@ def test_onnx_compliancy(self): with tf.Graph().as_default() as g: model = 
model_class(config) - model(model.dummy_inputs) + model.build() for op in g.get_operations(): model_op_names.add(op.node_def.op) @@ -375,7 +375,7 @@ def test_onnx_runtime_optimize(self): for model_class in self.all_model_classes: model = model_class(config) - model(model.dummy_inputs) + model.build() onnx_model_proto, _ = tf2onnx.convert.from_keras(model, opset=self.onnx_min_opset) @@ -1180,7 +1180,7 @@ def test_resize_token_embeddings(self): def _get_word_embedding_weight(model, embedding_layer): if isinstance(embedding_layer, tf.keras.layers.Embedding): # builds the embeddings layer - model(model.dummy_inputs) + model.build() return embedding_layer.embeddings else: return model._get_word_embedding_weight(embedding_layer) @@ -1243,7 +1243,7 @@ def test_save_load_after_resize_token_embeddings(self): old_total_size = config.vocab_size new_total_size = old_total_size + new_tokens_size model = model_class(config=copy.deepcopy(config)) # `resize_token_embeddings` mutates `config` - model(model.dummy_inputs) # builds the embeddings layer + model.build() model.resize_token_embeddings(new_total_size) # fetch the output for an input exclusively made of new members of the vocabulary @@ -2313,8 +2313,8 @@ def test_checkpoint_sharding_local(self): # Finally, check the model can be reloaded new_model = TFBertModel.from_pretrained(tmp_dir) - model(model.dummy_inputs) - new_model(model.dummy_inputs) + model.build() + new_model.build() for p1, p2 in zip(model.weights, new_model.weights): self.assertTrue(np.allclose(p1.numpy(), p2.numpy())) @@ -2440,7 +2440,7 @@ def test_push_to_hub(self): ) model = TFBertModel(config) # Make sure model is properly initialized - _ = model(model.dummy_inputs) + model.build() logging.set_verbosity_info() logger = logging.get_logger("transformers.utils.hub") @@ -2509,7 +2509,7 @@ def test_push_to_hub_in_organization(self): ) model = TFBertModel(config) # Make sure model is properly initialized - _ = model(model.dummy_inputs) + model.build() model.push_to_hub("valid_org/test-model-tf-org", use_auth_token=self._token)
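
The pattern this patch standardizes: wherever the codebase previously forced weight creation by running a forward pass by hand with model(model.dummy_inputs), it now calls model.build(), and the new base TFPreTrainedModel.build() performs that dummy-input forward pass itself unless the model is already built or is already inside a Keras call context. The short sketch below is illustrative only and not part of the patch; the specific model and config values (TFBertModel, a tiny BertConfig) are arbitrary example choices.

    from transformers import BertConfig, TFBertModel

    # A deliberately small config so the example builds quickly.
    config = BertConfig(hidden_size=32, num_hidden_layers=2, num_attention_heads=2, intermediate_size=64)
    model = TFBertModel(config)

    # Before this patch, callers created the weights by running dummy inputs manually:
    #     model(model.dummy_inputs)
    # After this patch, the Keras-style build() method runs the dummy-input pass internally,
    # and skips it when the model is already built or is currently inside a call context.
    model.build()
    assert model.built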