
load_tf_weights doesn't handle the weights added to the TF models at the top level #18802

Closed
ydshieh opened this issue Aug 29, 2022 · 2 comments · Fixed by #18833

ydshieh (Collaborator) commented Aug 29, 2022

System Info

  • transformers version: 4.22.0.dev0
  • Platform: Windows-10-10.0.22000-SP0
  • Python version: 3.9.11
  • Huggingface_hub version: 0.8.1
  • PyTorch version (GPU?): 1.12.1+cu113 (True)
  • Tensorflow version (GPU?): 2.9.1 (False)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?: No
  • Using distributed or parallel set-up in script?: No

Who can help?

@gante

Reproduction

(TF)MarianMTModel has a weight final_logits_bias added at the top level (i.e. not under any layer):

self.final_logits_bias = self.add_weight(

However, the method load_tf_weights only handles weights that live under layers:

for layer in model.layers:

This causes a problem when we load TF checkpoints for TFMarianMTModel: final_logits_bias is not loaded.
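To make the failure mode concrete, here is a minimal pure-Python sketch (no TensorFlow needed; the class and attribute names mirror the Keras structure but are simplified, and the loop only imitates the one in load_tf_weights):

```python
# Sketch of why a loading loop that only walks `model.layers` misses
# weights registered directly on the model (assumed simplified structure).

class Weight:
    def __init__(self, name):
        self.name = name

class Layer:
    def __init__(self, weights):
        self.weights = weights

class Model:
    def __init__(self):
        # weights owned by a sub-layer are visible to the layer loop below
        self.layers = [Layer([Weight("model/main_layer/kernel")])]
        # a weight added via `add_weight` at the top level (like MarianMT's
        # `final_logits_bias`) belongs to the model itself, not to any layer
        self.weights = self.layers[0].weights + [Weight("final_logits_bias")]

def names_seen_by_layer_loop(model):
    # imitates the `for layer in model.layers:` loop in `load_tf_weights`
    seen = []
    for layer in model.layers:
        seen.extend(w.name for w in layer.weights)
    return seen

seen = names_seen_by_layer_loop(Model())
print(seen)  # "final_logits_bias" never appears, so it is never restored
```

Since the loop never visits final_logits_bias, that weight keeps its initial (zero) value after loading, which matches the output shown below.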

from transformers import MarianMTModel, TFMarianMTModel
model_name = "Helsinki-NLP/opus-mt-en-ROMANCE"

pt_model = MarianMTModel.from_pretrained(model_name)
tf_model_from_pt = TFMarianMTModel.from_pretrained(model_name, from_pt=True)
tf_model = TFMarianMTModel.from_pretrained(model_name, from_pt=False)

# Only has `TFMarianMainLayer` in `layers`
print(tf_model.layers)

print(pt_model.final_logits_bias.numpy())
print(tf_model_from_pt.final_logits_bias.numpy())
print(tf_model.final_logits_bias.numpy())

Outputs:

[<transformers.models.marian.modeling_tf_marian.TFMarianMainLayer object at 0x000001F00ECE9940>]
[[11.757146  -1.7759448 -7.3816853 ... -1.6559223 -1.6663467  0.       ]]
[[11.757146  -1.7759448 -7.3816853 ... -1.6559223 -1.6663467  0.       ]]
[[0. 0. 0. ... 0. 0. 0.]]

Expected behavior

load_tf_weights should be able to load weights like final_logits_bias, and the TF checkpoint should be loaded correctly.
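One possible direction (a sketch only; the actual fix landed in #18833 and may differ) is to collect weights from model.weights, which in Keras includes both layer-owned weights and weights added with add_weight directly on the model, instead of iterating model.layers. Reusing the simplified mock from above:

```python
# Illustrative sketch, not the actual change from PR #18833.

class Weight:
    def __init__(self, name):
        self.name = name

class Model:
    def __init__(self):
        # Keras's `model.weights` includes layer-owned weights *and* weights
        # added with `add_weight` on the model itself (e.g. final_logits_bias)
        self.weights = [Weight("model/main_layer/kernel"),
                        Weight("final_logits_bias")]

def names_seen_by_weight_loop(model):
    # iterating model.weights instead of model.layers sees every weight
    return [w.name for w in model.weights]

print(names_seen_by_weight_loop(Model()))
```

With this traversal, final_logits_bias would be matched against the checkpoint like any other weight.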

@ydshieh ydshieh added the bug label Aug 29, 2022
ydshieh (Collaborator, Author) commented Aug 29, 2022

Related to #18149.

ydshieh (Collaborator, Author) commented Aug 29, 2022

cc @patrickvonplaten as we might need to change the core method load_tf_weights.
