Add a convenience method for building in your own name scope
Rocketknight1 committed Dec 1, 2023
1 parent 0ad4e7e commit f13b801
Showing 1 changed file with 12 additions and 8 deletions.
src/transformers/modeling_tf_utils.py (12 additions, 8 deletions)
@@ -1122,6 +1122,10 @@ def dummy_inputs(self) -> Dict[str, tf.Tensor]:
                 )
         return dummies
 
+    def build_in_name_scope(self):
+        with tf.name_scope(self.name):
+            self.build(input_shape=None)
+
     @property
     def framework(self) -> str:
         """
@@ -1869,7 +1873,7 @@ def set_input_embeddings(self, value):
             main_layer.set_input_embeddings(value)
         except AttributeError:
             logger.info("Building the model")
-            self.build()
+            self.build_in_name_scope()
             main_layer.set_input_embeddings(value)
 
     def get_output_embeddings(self) -> Union[None, tf.keras.layers.Layer]:
@@ -1886,7 +1890,7 @@ def get_output_embeddings(self) -> Union[None, tf.keras.layers.Layer]:
                 return lm_head.get_output_embeddings()
             except AttributeError:
                 logger.info("Building the model")
-                self.build()
+                self.build_in_name_scope()
 
                 return lm_head().get_output_embeddings()

@@ -1906,7 +1910,7 @@ def set_output_embeddings(self, value):
                 lm_head.set_output_embeddings(value)
             except AttributeError:
                 logger.info("Building the model")
-                self.build()
+                self.build_in_name_scope()
                 lm_head.set_output_embeddings(value)
 
     def get_output_layer_with_bias(self) -> Union[None, tf.keras.layers.Layer]:
@@ -1944,7 +1948,7 @@ def get_bias(self) -> Union[None, Dict[str, tf.Variable]]:
             try:
                 return lm_head.get_bias()
             except AttributeError:
-                self.build()
+                self.build_in_name_scope()
 
                 return lm_head.get_bias()
         return None
@@ -1962,7 +1966,7 @@ def set_bias(self, value):
             try:
                 lm_head.set_bias(value)
             except AttributeError:
-                self.build()
+                self.build_in_name_scope()
                 lm_head.set_bias(value)
 
     def get_lm_head(self) -> tf.keras.layers.Layer:
@@ -2049,7 +2053,7 @@ def _get_word_embedding_weight(model, embedding_layer):
         # The reason why the attributes don't exist might be
         # because the model is not built, so retry getting
         # the argument after building the model
-        model.build()
+        model.build_in_name_scope()
 
         embeds = getattr(embedding_layer, "weight", None)
         if embeds is not None:
@@ -2908,9 +2912,9 @@ def from_pretrained(
         # we might need to extend the variable scope for composite models
         if load_weight_prefix is not None:
            with tf.compat.v1.variable_scope(load_weight_prefix):
                model.build_in_name_scope()  # build the network with dummy inputs
        else:
            model.build_in_name_scope()  # build the network with dummy inputs

        if safetensors_from_pt:
            from .modeling_tf_pytorch_utils import load_pytorch_state_dict_in_tf2_model
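
For readers landing on this commit, a minimal usage sketch of the new helper follows. It is illustrative only and not part of the diff: the checkpoint name ("bert-base-uncased") and the exact weight-name prefix in the comments are assumptions.

    from transformers import AutoConfig, TFAutoModel

    # Hypothetical checkpoint; any TF-backed config would do.
    config = AutoConfig.from_pretrained("bert-base-uncased")
    model = TFAutoModel.from_config(config)

    # Calling Keras build() directly does not enter the model's own name scope,
    # so freshly created weights can get names that no longer match the prefixed
    # names stored in a checkpoint. The new helper re-enters the scope first,
    # i.e. it runs: with tf.name_scope(model.name): model.build(input_shape=None)
    model.build_in_name_scope()

    # Weight names should now carry the model prefix (exact prefix is illustrative):
    print(model.weights[0].name)  # e.g. "tf_bert_model/bert/embeddings/..."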
