Move TF building to an actual build() method (huggingface#23760)
* A fun new PR where I break the entire codebase again

* A fun new PR where I break the entire codebase again

* Handle cross-attention

* Move calls to model(model.dummy_inputs) to the new build() method

* Seeing what fails with the build context thing

* make fix-copies

* Let's see what fails with new build methods

* Fix the pytorch crossload build calls

* Fix the overridden build methods in vision_text_dual_encoder

* Make sure all our build methods set self.built or call super().build(), which also sets it (see the sketch after this list)

* make fix-copies

* Remove finished TODO

* Tentatively remove unneeded (?) line

* Transpose b in deberta correctly and remove unused threading local

* Get rid of build_with_dummies and all it stands for

* Rollback some changes to TF-PT crossloading

* Correctly call super().build()
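
The pattern these commits converge on is sketched below: a base-class build() that runs the model's dummy inputs exactly once, and per-layer build() overrides that create their weights and then call super().build(), which also sets self.built. This is a simplified, runnable illustration rather than the library code itself; the real method guards against re-entrant building with Keras's internal call_context(), which the sketch replaces with a plain flag, and the Sketch* names are made up.

    import tensorflow as tf


    class SketchModel(tf.keras.Model):
        """Toy stand-in for a TF pretrained model; all names here are illustrative."""

        def __init__(self):
            super().__init__()
            self.dense = tf.keras.layers.Dense(4)

        @property
        def dummy_inputs(self):
            # The real property returns model-specific dummy tensors.
            return tf.constant([[1.0, 2.0, 3.0]])

        def build(self, input_shape=None):
            # The library version guards with Keras's call_context().in_call;
            # a plain flag plays that role in this self-contained sketch.
            if self.built or getattr(self, "_in_dummy_build", False):
                self.built = True
                return
            self._in_dummy_build = True
            try:
                # One forward pass on dummy inputs creates every weight.
                self(self.dummy_inputs, training=False)
            finally:
                self._in_dummy_build = False
            self.built = True

        def call(self, inputs, training=False):
            return self.dense(inputs)


    class SketchBiasLayer(tf.keras.layers.Layer):
        """A layer whose build() adds weights and then calls super().build()."""

        def build(self, input_shape=None):
            self.bias = self.add_weight(name="bias", shape=(4,), initializer="zeros", trainable=True)
            super().build(input_shape)  # super().build() also sets self.built


    model = SketchModel()
    model.build()  # replaces the old idiom model(model.dummy_inputs)
    assert model.built
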
Rocketknight1 authored and novice03 committed Jun 23, 2023
1 parent 63b2963 commit 43de410
Showing 27 changed files with 159 additions and 138 deletions.
3 changes: 0 additions & 3 deletions src/transformers/modeling_tf_pytorch_utils.py
@@ -341,9 +341,6 @@ def load_pytorch_state_dict_in_tf2_model(

K.batch_set_value(weight_value_tuples)

- if tf_inputs is not None:
-     tf_model(tf_inputs, training=False) # Make sure restore ops are run

logger.info(f"Loaded {tf_loaded_numel:,} parameters in the TF 2.0 model.")

unexpected_keys = list(all_pytorch_weights)
35 changes: 24 additions & 11 deletions src/transformers/modeling_tf_utils.py
@@ -40,7 +40,12 @@
from .configuration_utils import PretrainedConfig
from .dynamic_module_utils import custom_object_save
from .generation import GenerationConfig, TFGenerationMixin
-from .tf_utils import expand_1d, load_attributes_from_hdf5_group, save_attributes_to_hdf5_group, shape_list
+from .tf_utils import (
+    expand_1d,
+    load_attributes_from_hdf5_group,
+    save_attributes_to_hdf5_group,
+    shape_list,
+)
from .utils import (
SAFE_WEIGHTS_INDEX_NAME,
SAFE_WEIGHTS_NAME,
@@ -69,11 +74,14 @@
if parse(tf.__version__).minor >= 13:
from keras import backend as K
from keras.__internal__ import KerasTensor
+ from keras.engine.base_layer_utils import call_context
elif parse(tf.__version__).minor >= 11:
from keras import backend as K
+ from keras.engine.base_layer_utils import call_context
from keras.engine.keras_tensor import KerasTensor
else:
from tensorflow.python.keras import backend as K
+ from tensorflow.python.keras.engine import call_context
from tensorflow.python.keras.engine.keras_tensor import KerasTensor


@@ -1140,6 +1148,13 @@ def framework(self) -> str:
"""
return "tf"

+    def build(self, input_shape=None):
+        if self.built or call_context().in_call:
+            self.built = True
+        else:
+            self(self.dummy_inputs, training=False)
+            self.built = True

def __init__(self, config, *inputs, **kwargs):
super().__init__(*inputs, **kwargs)
if not isinstance(config, PretrainedConfig):
@@ -1867,7 +1882,7 @@ def set_input_embeddings(self, value):
main_layer.set_input_embeddings(value)
except AttributeError:
logger.info("Building the model")
- self(self.dummy_inputs)
+ self.build()
main_layer.set_input_embeddings(value)

def get_output_embeddings(self) -> Union[None, tf.keras.layers.Layer]:
@@ -1884,7 +1899,7 @@ def get_output_embeddings(self) -> Union[None, tf.keras.layers.Layer]:
return lm_head.get_output_embeddings()
except AttributeError:
logger.info("Building the model")
- self(self.dummy_inputs)
+ self.build()

return lm_head().get_output_embeddings()

@@ -1904,7 +1919,7 @@ def set_output_embeddings(self, value):
lm_head.set_output_embeddings(value)
except AttributeError:
logger.info("Building the model")
- self(self.dummy_inputs)
+ self.build()
lm_head.set_output_embeddings(value)

def get_output_layer_with_bias(self) -> Union[None, tf.keras.layers.Layer]:
@@ -1942,7 +1957,7 @@ def get_bias(self) -> Union[None, Dict[str, tf.Variable]]:
try:
return lm_head.get_bias()
except AttributeError:
- self(self.dummy_inputs)
+ self.build()

return lm_head.get_bias()
return None
@@ -1960,7 +1975,7 @@ def set_bias(self, value):
try:
lm_head.set_bias(value)
except AttributeError:
- self(self.dummy_inputs)
+ self.build()
lm_head.set_bias(value)

def get_lm_head(self) -> tf.keras.layers.Layer:
@@ -2047,7 +2062,7 @@ def _get_word_embedding_weight(model, embedding_layer):
# The reason why the attributes don't exist might be
# because the model is not built, so retry getting
# the argument after building the model
- model(model.dummy_inputs)
+ model.build()

embeds = getattr(embedding_layer, "weight", None)
if embeds is not None:
@@ -2870,9 +2885,9 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
# we might need to extend the variable scope for composite models
if load_weight_prefix is not None:
with tf.compat.v1.variable_scope(load_weight_prefix):
- model(model.dummy_inputs) # build the network with dummy inputs
+ model.build() # build the network with dummy inputs
else:
- model(model.dummy_inputs) # build the network with dummy inputs
+ model.build() # build the network with dummy inputs

if safetensors_from_pt:
from .modeling_tf_pytorch_utils import load_pytorch_state_dict_in_tf2_model
@@ -2925,8 +2940,6 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
"If you tried to load a TF 2.0 model from a PyTorch checkpoint, please set from_pt=True. "
)

- model(model.dummy_inputs) # Make sure restore ops are run

if cls._keys_to_ignore_on_load_missing is not None:
for pat in cls._keys_to_ignore_on_load_missing:
missing_keys = [k for k in missing_keys if re.search(pat, k) is None]
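
Taken together, the changes in this file mean a caller no longer needs a manual forward pass on dummy inputs to materialise the weights. A brief usage sketch, assuming a transformers install that includes this commit (TFBertModel is only an example class):

    from transformers import BertConfig, TFBertModel

    model = TFBertModel(BertConfig())
    # The old idiom was model(model.dummy_inputs); build() now does the same
    # thing internally and is effectively a no-op on an already-built model.
    model.build()
    assert model.built
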
6 changes: 4 additions & 2 deletions src/transformers/models/blip/modeling_tf_blip.py
@@ -258,6 +258,7 @@ def build(self, input_shape):
trainable=True,
name="position_embedding",
)
+ super().build(input_shape)

def call(self, pixel_values: tf.Tensor) -> tf.Tensor:
# Input is channels-first, we transpose. PyTorch transposes after the conv because PyTorch
@@ -282,7 +283,7 @@ def __init__(self, config: BlipTextConfig, **kwargs):

self.config = config

- def build(self, input_shape: tf.TensorShape):
+ def build(self, input_shape: tf.TensorShape = None):
with tf.name_scope("token_embedding"):
self.weight = self.add_weight(
shape=(self.config.vocab_size, self.embed_dim),
@@ -757,13 +758,14 @@ def __init__(self, config: BlipConfig, *args, **kwargs):

self.config = config

- def build(self, input_shape):
+ def build(self, input_shape=None):
self.logit_scale = self.add_weight(
name="logit_scale",
shape=[],
initializer=tf.keras.initializers.Constant(self.config.logit_scale_init_value),
trainable=True,
)
+ super().build(input_shape)

@unpack_inputs
def call(
3 changes: 2 additions & 1 deletion src/transformers/models/blip/modeling_tf_blip_text.py
@@ -543,8 +543,9 @@ def __init__(self, config, **kwargs):
)
self.config = config

- def build(self, input_shape):
+ def build(self, input_shape=None):
self.bias = self.add_weight(name="bias", shape=(self.config.vocab_size,), initializer="zeros", trainable=True)
+ super().build(input_shape)

def call(self, hidden_states):
hidden_states = self.transform(hidden_states)
6 changes: 3 additions & 3 deletions src/transformers/models/clip/modeling_tf_clip.py
@@ -151,7 +151,7 @@ def __init__(self, config: CLIPVisionConfig, **kwargs):
name="patch_embedding",
)

- def build(self, input_shape: tf.TensorShape):
+ def build(self, input_shape: tf.TensorShape = None):
factor = self.config.initializer_factor

self.class_embedding = self.add_weight(
@@ -204,7 +204,7 @@ def __init__(self, config: CLIPTextConfig, **kwargs):

self.config = config

- def build(self, input_shape: tf.TensorShape):
+ def build(self, input_shape: tf.TensorShape = None):
with tf.name_scope("token_embedding"):
self.weight = self.add_weight(
shape=(self.config.vocab_size, self.embed_dim),
@@ -739,7 +739,7 @@ def __init__(self, config: CLIPConfig, **kwargs):
name="text_projection",
)

- def build(self, input_shape: tf.TensorShape):
+ def build(self, input_shape: tf.TensorShape = None):
self.logit_scale = self.add_weight(
shape=(1,),
initializer=tf.keras.initializers.Constant(self.config.logit_scale_init_value),
3 changes: 2 additions & 1 deletion src/transformers/models/convbert/modeling_tf_convbert.py
@@ -346,7 +346,7 @@ def __init__(self, input_size, output_size, num_groups, kernel_initializer, **kw
self.group_in_dim = self.input_size // self.num_groups
self.group_out_dim = self.output_size // self.num_groups

- def build(self, input_shape):
+ def build(self, input_shape=None):
self.kernel = self.add_weight(
"kernel",
shape=[self.group_out_dim, self.group_in_dim, self.num_groups],
@@ -357,6 +357,7 @@ def build(self, input_shape):
self.bias = self.add_weight(
"bias", shape=[self.output_size], initializer=self.kernel_initializer, dtype=self.dtype, trainable=True
)
+ super().build(input_shape)

def call(self, hidden_states):
batch_size = shape_list(hidden_states)[0]
2 changes: 1 addition & 1 deletion src/transformers/models/convnext/modeling_tf_convnext.py
@@ -155,7 +155,7 @@ def __init__(self, config, dim, drop_path=0.0, **kwargs):
else tf.keras.layers.Activation("linear", name="drop_path")
)

- def build(self, input_shape: tf.TensorShape):
+ def build(self, input_shape: tf.TensorShape = None):
# PT's `nn.Parameters` must be mapped to a TF layer weight to inherit the same name hierarchy (and vice-versa)
self.layer_scale_parameter = (
self.add_weight(
2 changes: 1 addition & 1 deletion src/transformers/models/ctrl/modeling_tf_ctrl.py
@@ -576,7 +576,7 @@ def __init__(self, config, input_embeddings, **kwargs):
# an output-only bias for each token.
self.input_embeddings = input_embeddings

- def build(self, input_shape):
+ def build(self, input_shape=None):
self.bias = self.add_weight(shape=(self.config.vocab_size,), initializer="zeros", trainable=True, name="bias")
super().build(input_shape)

@@ -464,7 +464,7 @@ def __init__(
)
self.init_values = config.layer_scale_init_value

- def build(self, input_shape: tf.TensorShape):
+ def build(self, input_shape: tf.TensorShape = None):
if self.init_values > 0:
self.lambda_1 = self.add_weight(
shape=(self.config.hidden_size),
9 changes: 4 additions & 5 deletions src/transformers/models/deberta/modeling_tf_deberta.py
@@ -593,11 +593,10 @@ def call(
else:

def linear(w, b, x):
-    return tf.cond(
-        b is not None,
-        lambda: tf.matmul(x, w, transpose_b=True) + tf.transpose(b),
-        lambda: tf.matmul(x, w, transpose_b=True),
-    )
+    out = tf.matmul(x, w, transpose_b=True)
+    if b is not None:
+        out += tf.transpose(b)
+    return out

ws = tf.split(
tf.transpose(self.in_proj.weight[0]), num_or_size_splits=self.num_attention_heads * 3, axis=0
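
For reference, the rewritten helper above replaces a tf.cond over the Python-level check "b is not None"; whether a bias exists is already known when the function is traced, so a plain Python conditional is simpler and avoids building an unused branch. A standalone sketch of the fixed helper, with purely illustrative shapes:

    import tensorflow as tf

    def linear(w, b, x):
        # Project with the transposed weight; add the (transposed) bias only if one exists.
        out = tf.matmul(x, w, transpose_b=True)
        if b is not None:
            out += tf.transpose(b)
        return out

    # Illustrative shapes, not the ones DeBERTa actually uses.
    x = tf.random.normal((2, 5, 8))   # (batch, seq_len, hidden)
    w = tf.random.normal((8, 8))      # (proj_dim, hidden)
    b = tf.random.normal((8, 1))      # column vector, hence the transpose
    print(linear(w, b, x).shape)      # (2, 5, 8)
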
6 changes: 3 additions & 3 deletions src/transformers/models/dpr/modeling_tf_dpr.py
@@ -532,7 +532,7 @@ def get_input_embeddings(self):
try:
return self.ctx_encoder.bert_model.get_input_embeddings()
except AttributeError:
- self(self.dummy_inputs)
+ self.build()
return self.ctx_encoder.bert_model.get_input_embeddings()

@unpack_inputs
@@ -613,7 +613,7 @@ def get_input_embeddings(self):
try:
return self.question_encoder.bert_model.get_input_embeddings()
except AttributeError:
- self(self.dummy_inputs)
+ self.build()
return self.question_encoder.bert_model.get_input_embeddings()

@unpack_inputs
@@ -693,7 +693,7 @@ def get_input_embeddings(self):
try:
return self.span_predictor.encoder.bert_model.get_input_embeddings()
except AttributeError:
- self(self.dummy_inputs)
+ self.build()
return self.span_predictor.encoder.bert_model.get_input_embeddings()

@unpack_inputs
2 changes: 1 addition & 1 deletion src/transformers/models/groupvit/modeling_tf_groupvit.py
@@ -538,7 +538,7 @@ def __init__(self, config: GroupViTTextConfig, **kwargs):

self.config = config

- def build(self, input_shape: tf.TensorShape):
+ def build(self, input_shape: tf.TensorShape = None):
with tf.name_scope("token_embedding"):
self.weight = self.add_weight(
shape=(self.config.vocab_size, self.embed_dim),