load_tf_weights doesn't handle the weights added to the TF models at the top level #18802
System Info
transformers version: 4.22.0.dev0
Who can help?
@gante
Reproduction
(TF)MarianMTModel has a weight, final_logits_bias, that is added at the top level (i.e. not under any layer):
transformers/src/transformers/models/marian/modeling_tf_marian.py, line 1287 in 5f06a09
However, the method load_tf_weights only handles weights that live under some layer:
transformers/src/transformers/modeling_tf_utils.py, line 850 in 5f06a09
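To make the failure mode concrete, here is a framework-free sketch. It assumes, as described above, that load_tf_weights restores weights by iterating over the model's layers. ToyModel, load_layerwise and the "layer/weight" key format are hypothetical simplifications for illustration, not the real transformers, Keras, or h5 checkpoint API.

```python
from __future__ import annotations

from dataclasses import dataclass, field


@dataclass
class ToyModel:
    """Hypothetical stand-in for a TF model: per-layer weights plus model-level weights."""

    # layer name -> {weight name -> value}
    layers: dict[str, dict[str, float]]
    # weights created on the model itself, outside any layer (like final_logits_bias)
    top_level_weights: dict[str, float] = field(default_factory=dict)


def load_layerwise(model: ToyModel, checkpoint: dict[str, float]) -> set[str]:
    """Restore weights by walking the layers only (hypothetical simplification).

    `checkpoint` maps "layer/weight" names (or a bare "weight" name for
    model-level weights) to saved values. Returns the keys that were restored.
    """
    restored: set[str] = set()
    for layer_name, weights in model.layers.items():
        for weight_name in weights:
            key = f"{layer_name}/{weight_name}"
            if key in checkpoint:
                weights[weight_name] = checkpoint[key]
                restored.add(key)
    # model.top_level_weights is never visited, so those values keep their init.
    return restored


model = ToyModel(
    layers={"model": {"kernel": 0.0}},
    top_level_weights={"final_logits_bias": 0.0},
)
checkpoint = {"model/kernel": 1.5, "final_logits_bias": 2.0}
restored = load_layerwise(model, checkpoint)
print(sorted(restored))  # ['model/kernel']
print(model.top_level_weights)  # {'final_logits_bias': 0.0}
```

The saved value for final_logits_bias is present in the checkpoint, but because nothing in the loop looks outside model.layers, the model keeps its initial value and no error is raised.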
This causes a problem when we load TF checkpoints for TFMarianMTModel: final_logits_bias is not loaded.
Outputs:
Expected behavior
load_tf_weights should be able to load top-level weights like final_logits_bias, so that the TF checkpoint is loaded correctly.
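A minimal sketch of the expected behavior on a hypothetical, framework-free toy model: after the per-layer pass, the loader also looks up weights that belong directly to the model. ToyModel, load_all and the "layer/weight" key format are illustrative assumptions, not the transformers API, and the sketch does not claim to show how the actual fix matches top-level weights to entries in the h5 checkpoint.

```python
from __future__ import annotations

from dataclasses import dataclass, field


@dataclass
class ToyModel:
    """Hypothetical stand-in for a TF model: per-layer weights plus model-level weights."""

    # layer name -> {weight name -> value}
    layers: dict[str, dict[str, float]]
    # weights created on the model itself, outside any layer (like final_logits_bias)
    top_level_weights: dict[str, float] = field(default_factory=dict)


def load_all(model: ToyModel, checkpoint: dict[str, float]) -> set[str]:
    """Restore layer weights and model-level weights; return the restored keys."""
    restored: set[str] = set()
    # 1) weights owned by layers, keyed as "layer/weight"
    for layer_name, weights in model.layers.items():
        for weight_name in weights:
            key = f"{layer_name}/{weight_name}"
            if key in checkpoint:
                weights[weight_name] = checkpoint[key]
                restored.add(key)
    # 2) weights attached directly to the model, which a layer-only loop would skip
    for weight_name in model.top_level_weights:
        if weight_name in checkpoint:
            model.top_level_weights[weight_name] = checkpoint[weight_name]
            restored.add(weight_name)
    return restored


model = ToyModel(
    layers={"model": {"kernel": 0.0}},
    top_level_weights={"final_logits_bias": 0.0},
)
checkpoint = {"model/kernel": 1.5, "final_logits_bias": 2.0}
restored = load_all(model, checkpoint)
print(sorted(restored))  # ['final_logits_bias', 'model/kernel']
print(model.top_level_weights)  # {'final_logits_bias': 2.0}
```

Keeping the top-level pass separate from the per-layer pass leaves the existing layer matching untouched and only adds a lookup for weights that have no owning layer.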