Skip to content

Commit

Permalink
More TF fixes (#28081)
Browse files Browse the repository at this point in the history
* More build_in_name_scope()

* Make sure we set the save spec now we don't do it with dummies anymore

* make fixup
  • Loading branch information
Rocketknight1 authored Dec 18, 2023
1 parent 0695b24 commit 71d47f0
Show file tree
Hide file tree
Showing 5 changed files with 6 additions and 5 deletions.
1 change: 1 addition & 0 deletions src/transformers/modeling_tf_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1147,6 +1147,7 @@ def __init__(self, config, *inputs, **kwargs):
self.config = config
self.name_or_path = config.name_or_path
self.generation_config = GenerationConfig.from_model_config(config) if self.can_generate() else None
self._set_save_spec(self.input_signature)

def get_config(self):
return self.config.to_dict()
Expand Down
4 changes: 2 additions & 2 deletions tests/models/auto/test_modeling_tf_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ def test_from_pretrained_with_tuple_values(self):
config = copy.deepcopy(model.config)
config.architectures = ["FunnelBaseModel"]
model = TFAutoModel.from_config(config)
model.build()
model.build_in_name_scope()

self.assertIsInstance(model, TFFunnelBaseModel)

Expand Down Expand Up @@ -249,7 +249,7 @@ def test_new_model_registration(self):
config = NewModelConfig(**tiny_config.to_dict())

model = auto_class.from_config(config)
model.build()
model.build_in_name_scope()

self.assertIsInstance(model, TFNewModel)

Expand Down
2 changes: 1 addition & 1 deletion tests/models/gpt2/test_modeling_tf_gpt2.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,7 +445,7 @@ def test_onnx_runtime_optimize(self):
continue

model = model_class(config)
model.build()
model.build_in_name_scope()

onnx_model_proto, _ = tf2onnx.convert.from_keras(model, opset=self.onnx_min_opset)

Expand Down
2 changes: 1 addition & 1 deletion tests/models/whisper/test_modeling_tf_whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,7 @@ def test_encoder_sinusoidal_embed_positions(self):
config = self.model_tester.get_config()
for model_class in self.all_model_classes:
model = model_class(config)
model.build()
model.build_in_name_scope()

embeds = model.get_encoder().embed_positions.get_weights()[0]
sinusoids = sinusoidal_embedding_init(embeds.shape).numpy()
Expand Down
2 changes: 1 addition & 1 deletion tests/utils/test_modeling_tf_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ def test_saved_model_creation_extended(self):
for model_class in self.all_model_classes[:2]:
class_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
model = model_class(config)
model.build()
model.build_in_name_scope()
num_out = len(model(class_inputs_dict))

for key in list(class_inputs_dict.keys()):
Expand Down

0 comments on commit 71d47f0

Please sign in to comment.