Move TF building to an actual build() method #23760

Merged
merged 17 commits on Jun 6, 2023
3 changes: 0 additions & 3 deletions src/transformers/modeling_tf_pytorch_utils.py
@@ -341,9 +341,6 @@ def load_pytorch_state_dict_in_tf2_model(

K.batch_set_value(weight_value_tuples)

if tf_inputs is not None:
tf_model(tf_inputs, training=False) # Make sure restore ops are run

logger.info(f"Loaded {tf_loaded_numel:,} parameters in the TF 2.0 model.")

unexpected_keys = list(all_pytorch_weights)
35 changes: 24 additions & 11 deletions src/transformers/modeling_tf_utils.py
@@ -40,7 +40,12 @@
from .configuration_utils import PretrainedConfig
from .dynamic_module_utils import custom_object_save
from .generation import GenerationConfig, TFGenerationMixin
from .tf_utils import expand_1d, load_attributes_from_hdf5_group, save_attributes_to_hdf5_group, shape_list
from .tf_utils import (
expand_1d,
load_attributes_from_hdf5_group,
save_attributes_to_hdf5_group,
shape_list,
)
from .utils import (
SAFE_WEIGHTS_INDEX_NAME,
SAFE_WEIGHTS_NAME,
@@ -69,11 +74,14 @@
if parse(tf.__version__).minor >= 13:
from keras import backend as K
from keras.__internal__ import KerasTensor
from keras.engine.base_layer_utils import call_context
@frostming commented on Jun 7, 2023:
This would break since keras 2.13 has moved the import to keras.src.engine

See #23663

Member Author replied:
Thanks for the catch, I'll make the fix ASAP!
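
As a rough illustration of the compatibility concern (a sketch only, not the follow-up fix that eventually landed): an import guard along these lines would tolerate the keras 2.13 relocation of the private module, since call_context is an internal Keras helper with no stable public location:

    try:
        from keras.src.engine.base_layer_utils import call_context  # keras >= 2.13
    except ImportError:
        from keras.engine.base_layer_utils import call_context  # keras < 2.13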

elif parse(tf.__version__).minor >= 11:
from keras import backend as K
from keras.engine.base_layer_utils import call_context
from keras.engine.keras_tensor import KerasTensor
else:
from tensorflow.python.keras import backend as K
from tensorflow.python.keras.engine import call_context
from tensorflow.python.keras.engine.keras_tensor import KerasTensor


@@ -1140,6 +1148,13 @@ def framework(self) -> str:
"""
return "tf"

def build(self, input_shape=None):
if self.built or call_context().in_call:
self.built = True
else:
self(self.dummy_inputs, training=False)
self.built = True

def __init__(self, config, *inputs, **kwargs):
super().__init__(*inputs, **kwargs)
if not isinstance(config, PretrainedConfig):
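
For orientation, a hypothetical usage sketch (not part of the diff): with the build() override above, callers can trigger weight creation through the standard Keras entry point instead of pushing dummy inputs through the model themselves, which is the substitution made throughout the rest of this PR. TFBertModel is just a stand-in model here:

    from transformers import BertConfig, TFBertModel  # stand-in model for illustration

    config = BertConfig()
    model = TFBertModel(config)  # layers are defined, but no variables exist yet
    model.build()                # runs the model once on dummy_inputs unless already built
    assert model.built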
@@ -1864,7 +1879,7 @@ def set_input_embeddings(self, value):
main_layer.set_input_embeddings(value)
except AttributeError:
logger.info("Building the model")
self(self.dummy_inputs)
self.build()
main_layer.set_input_embeddings(value)

def get_output_embeddings(self) -> Union[None, tf.keras.layers.Layer]:
@@ -1881,7 +1896,7 @@ def get_output_embeddings(self) -> Union[None, tf.keras.layers.Layer]:
return lm_head.get_output_embeddings()
except AttributeError:
logger.info("Building the model")
self(self.dummy_inputs)
self.build()

return lm_head().get_output_embeddings()

@@ -1901,7 +1916,7 @@ def set_output_embeddings(self, value):
lm_head.set_output_embeddings(value)
except AttributeError:
logger.info("Building the model")
self(self.dummy_inputs)
self.build()
lm_head.set_output_embeddings(value)

def get_output_layer_with_bias(self) -> Union[None, tf.keras.layers.Layer]:
@@ -1939,7 +1954,7 @@ def get_bias(self) -> Union[None, Dict[str, tf.Variable]]:
try:
return lm_head.get_bias()
except AttributeError:
self(self.dummy_inputs)
self.build()

return lm_head.get_bias()
return None
@@ -1957,7 +1972,7 @@ def set_bias(self, value):
try:
lm_head.set_bias(value)
except AttributeError:
self(self.dummy_inputs)
self.build()
lm_head.set_bias(value)

def get_lm_head(self) -> tf.keras.layers.Layer:
@@ -2044,7 +2059,7 @@ def _get_word_embedding_weight(model, embedding_layer):
# The reason why the attributes don't exist might be
# because the model is not built, so retry getting
# the argument after building the model
model(model.dummy_inputs)
model.build()

embeds = getattr(embedding_layer, "weight", None)
if embeds is not None:
@@ -2867,9 +2882,9 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
# we might need to extend the variable scope for composite models
if load_weight_prefix is not None:
with tf.compat.v1.variable_scope(load_weight_prefix):
model(model.dummy_inputs) # build the network with dummy inputs
model.build() # build the network with dummy inputs
else:
model(model.dummy_inputs) # build the network with dummy inputs
model.build() # build the network with dummy inputs

if safetensors_from_pt:
from .modeling_tf_pytorch_utils import load_pytorch_state_dict_in_tf2_model
@@ -2922,8 +2937,6 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
"If you tried to load a TF 2.0 model from a PyTorch checkpoint, please set from_pt=True. "
)

model(model.dummy_inputs) # Make sure restore ops are run

if cls._keys_to_ignore_on_load_missing is not None:
for pat in cls._keys_to_ignore_on_load_missing:
missing_keys = [k for k in missing_keys if re.search(pat, k) is None]
6 changes: 4 additions & 2 deletions src/transformers/models/blip/modeling_tf_blip.py
@@ -258,6 +258,7 @@ def build(self, input_shape):
trainable=True,
name="position_embedding",
)
super().build(input_shape)

def call(self, pixel_values: tf.Tensor) -> tf.Tensor:
# Input is channels-first, we transpose. PyTorch transposes after the conv because PyTorch
@@ -282,7 +283,7 @@ def __init__(self, config: BlipTextConfig, **kwargs):

self.config = config

def build(self, input_shape: tf.TensorShape):
def build(self, input_shape: tf.TensorShape = None):
with tf.name_scope("token_embedding"):
self.weight = self.add_weight(
shape=(self.config.vocab_size, self.embed_dim),
@@ -757,13 +758,14 @@ def __init__(self, config: BlipConfig, *args, **kwargs):

self.config = config

def build(self, input_shape):
def build(self, input_shape=None):
self.logit_scale = self.add_weight(
name="logit_scale",
shape=[],
initializer=tf.keras.initializers.Constant(self.config.logit_scale_init_value),
trainable=True,
)
super().build(input_shape)

@unpack_inputs
def call(
3 changes: 2 additions & 1 deletion src/transformers/models/blip/modeling_tf_blip_text.py
@@ -543,8 +543,9 @@ def __init__(self, config, **kwargs):
)
self.config = config

def build(self, input_shape):
def build(self, input_shape=None):
self.bias = self.add_weight(name="bias", shape=(self.config.vocab_size,), initializer="zeros", trainable=True)
super().build(input_shape)

def call(self, hidden_states):
hidden_states = self.transform(hidden_states)
6 changes: 3 additions & 3 deletions src/transformers/models/clip/modeling_tf_clip.py
@@ -151,7 +151,7 @@ def __init__(self, config: CLIPVisionConfig, **kwargs):
name="patch_embedding",
)

def build(self, input_shape: tf.TensorShape):
def build(self, input_shape: tf.TensorShape = None):
factor = self.config.initializer_factor

self.class_embedding = self.add_weight(
@@ -204,7 +204,7 @@ def __init__(self, config: CLIPTextConfig, **kwargs):

self.config = config

def build(self, input_shape: tf.TensorShape):
def build(self, input_shape: tf.TensorShape = None):
with tf.name_scope("token_embedding"):
self.weight = self.add_weight(
shape=(self.config.vocab_size, self.embed_dim),
@@ -739,7 +739,7 @@ def __init__(self, config: CLIPConfig, **kwargs):
name="text_projection",
)

def build(self, input_shape: tf.TensorShape):
def build(self, input_shape: tf.TensorShape = None):
self.logit_scale = self.add_weight(
shape=(1,),
initializer=tf.keras.initializers.Constant(self.config.logit_scale_init_value),
3 changes: 2 additions & 1 deletion src/transformers/models/convbert/modeling_tf_convbert.py
@@ -346,7 +346,7 @@ def __init__(self, input_size, output_size, num_groups, kernel_initializer, **kw
self.group_in_dim = self.input_size // self.num_groups
self.group_out_dim = self.output_size // self.num_groups

def build(self, input_shape):
def build(self, input_shape=None):
self.kernel = self.add_weight(
"kernel",
shape=[self.group_out_dim, self.group_in_dim, self.num_groups],
@@ -357,6 +357,7 @@ def build(self, input_shape):
self.bias = self.add_weight(
"bias", shape=[self.output_size], initializer=self.kernel_initializer, dtype=self.dtype, trainable=True
)
super().build(input_shape)

def call(self, hidden_states):
batch_size = shape_list(hidden_states)[0]
2 changes: 1 addition & 1 deletion src/transformers/models/convnext/modeling_tf_convnext.py
@@ -155,7 +155,7 @@ def __init__(self, config, dim, drop_path=0.0, **kwargs):
else tf.keras.layers.Activation("linear", name="drop_path")
)

def build(self, input_shape: tf.TensorShape):
def build(self, input_shape: tf.TensorShape = None):
# PT's `nn.Parameters` must be mapped to a TF layer weight to inherit the same name hierarchy (and vice-versa)
self.layer_scale_parameter = (
self.add_weight(
2 changes: 1 addition & 1 deletion src/transformers/models/ctrl/modeling_tf_ctrl.py
@@ -576,7 +576,7 @@ def __init__(self, config, input_embeddings, **kwargs):
# an output-only bias for each token.
self.input_embeddings = input_embeddings

def build(self, input_shape):
def build(self, input_shape=None):
self.bias = self.add_weight(shape=(self.config.vocab_size,), initializer="zeros", trainable=True, name="bias")
super().build(input_shape)

@@ -464,7 +464,7 @@ def __init__(
)
self.init_values = config.layer_scale_init_value

def build(self, input_shape: tf.TensorShape):
def build(self, input_shape: tf.TensorShape = None):
if self.init_values > 0:
self.lambda_1 = self.add_weight(
shape=(self.config.hidden_size),
9 changes: 4 additions & 5 deletions src/transformers/models/deberta/modeling_tf_deberta.py
@@ -593,11 +593,10 @@ def call(
else:

def linear(w, b, x):
return tf.cond(
b is not None,
lambda: tf.matmul(x, w, transpose_b=True) + tf.transpose(b),
lambda: tf.matmul(x, w, transpose_b=True),
)
out = tf.matmul(x, w, transpose_b=True)
if b is not None:
out += tf.transpose(b)
return out

ws = tf.split(
tf.transpose(self.in_proj.weight[0]), num_or_size_splits=self.num_attention_heads * 3, axis=0
6 changes: 3 additions & 3 deletions src/transformers/models/dpr/modeling_tf_dpr.py
@@ -532,7 +532,7 @@ def get_input_embeddings(self):
try:
return self.ctx_encoder.bert_model.get_input_embeddings()
except AttributeError:
self(self.dummy_inputs)
self.build()
return self.ctx_encoder.bert_model.get_input_embeddings()

@unpack_inputs
@@ -613,7 +613,7 @@ def get_input_embeddings(self):
try:
return self.question_encoder.bert_model.get_input_embeddings()
except AttributeError:
self(self.dummy_inputs)
self.build()
return self.question_encoder.bert_model.get_input_embeddings()

@unpack_inputs
@@ -693,7 +693,7 @@ def get_input_embeddings(self):
try:
return self.span_predictor.encoder.bert_model.get_input_embeddings()
except AttributeError:
self(self.dummy_inputs)
self.build()
return self.span_predictor.encoder.bert_model.get_input_embeddings()

@unpack_inputs
2 changes: 1 addition & 1 deletion src/transformers/models/groupvit/modeling_tf_groupvit.py
@@ -538,7 +538,7 @@ def __init__(self, config: GroupViTTextConfig, **kwargs):

self.config = config

def build(self, input_shape: tf.TensorShape):
def build(self, input_shape: tf.TensorShape = None):
with tf.name_scope("token_embedding"):
self.weight = self.add_weight(
shape=(self.config.vocab_size, self.embed_dim),