Update PT/Flax weight conversion after #24030 (#24556)

Merged · 3 commits · Jun 28, 2023
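For context, #24030 switched the affected models from the deprecated `torch.nn.utils.weight_norm` to the parametrization-based `torch.nn.utils.parametrizations.weight_norm`, which changes the state-dict keys the PT/Flax conversion has to match. A minimal sketch of the key difference (assuming a PyTorch version that ships the parametrized variant, i.e. 2.1 or later; the `Conv1d` module is just an illustration):

    import torch
    from torch import nn

    # Deprecated API: stores the magnitude/direction pair directly on the module.
    old = torch.nn.utils.weight_norm(nn.Conv1d(4, 4, 3))
    print(sorted(old.state_dict()))
    # ['bias', 'weight_g', 'weight_v']

    # Parametrization-based API (what #24030 moves to): nests them under
    # `parametrizations.<name>.original0` / `original1`.
    new = torch.nn.utils.parametrizations.weight_norm(nn.Conv1d(4, 4, 3))
    print(sorted(new.state_dict()))
    # ['bias', 'parametrizations.weight.original0', 'parametrizations.weight.original1']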
src/transformers/modeling_flax_pytorch_utils.py · 28 additions, 0 deletions
@@ -120,6 +120,16 @@ def is_key_or_prefix_key_in_dict(key: Tuple[str]) -> bool:
     if pt_tuple_key[-1] == "beta":
         return renamed_pt_tuple_key, pt_tensor
 
+    # New `weight_norm` from https://github.com/huggingface/transformers/pull/24030
+    name = None
+    if pt_tuple_key[-3::2] == ("parametrizations", "original0"):
+        name = pt_tuple_key[-2] + "_g"
+    elif pt_tuple_key[-3::2] == ("parametrizations", "original1"):
+        name = pt_tuple_key[-2] + "_v"
+    if name is not None:
+        renamed_pt_tuple_key = pt_tuple_key[:-3] + (name,)
+        return renamed_pt_tuple_key, pt_tensor
+
     return pt_tuple_key, pt_tensor


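The slice `pt_tuple_key[-3::2]` takes every other element starting from the third-from-last, so it matches keys of the form (..., "parametrizations", <param name>, "originalN") while skipping the parameter name in the middle. A standalone sketch of the renaming, run on hypothetical keys (it mirrors the hunk above but is not the library helper itself):

    # PyTorch's parametrized weight_norm stores a weight as
    # `<module>.parametrizations.<name>.original0` (magnitude, "g") and
    # `<module>.parametrizations.<name>.original1` (direction, "v");
    # the Flax side expects `<module>.<name>_g` / `<module>.<name>_v`.
    def rename_weight_norm_key(pt_tuple_key):
        name = None
        if pt_tuple_key[-3::2] == ("parametrizations", "original0"):
            name = pt_tuple_key[-2] + "_g"
        elif pt_tuple_key[-3::2] == ("parametrizations", "original1"):
            name = pt_tuple_key[-2] + "_v"
        if name is not None:
            return pt_tuple_key[:-3] + (name,)
        return pt_tuple_key

    print(rename_weight_norm_key(("conv", "parametrizations", "weight", "original0")))
    # ('conv', 'weight_g')
    print(rename_weight_norm_key(("conv", "parametrizations", "weight", "original1")))
    # ('conv', 'weight_v')
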
@@ -372,6 +382,24 @@ def load_flax_weights_in_pytorch_model(pt_model, flax_state):
         else:
             flax_key = ".".join(flax_key_tuple)
 
+        # We also need to look at `pt_model_dict` and see if there are keys requiring further transformation.
+        special_pt_names = {}
+        # New `weight_norm` from https://github.com/huggingface/transformers/pull/24030
+        for key in pt_model_dict:
+            key_components = key.split(".")
+            name = None
+            if key_components[-3::2] == ["parametrizations", "original0"]:
+                name = key_components[-2] + "_g"
+            elif key_components[-3::2] == ["parametrizations", "original1"]:
+                name = key_components[-2] + "_v"
+            if name is not None:
+                key_components = key_components[:-3] + [name]
+                key_to_check = ".".join(key_components)
+                special_pt_names[key_to_check] = key
+
+        if flax_key in special_pt_names:
+            flax_key = special_pt_names[flax_key]
+
         if flax_key in pt_model_dict:
             if flax_tensor.shape != pt_model_dict[flax_key].shape:
                 raise ValueError(
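On the PT-loading side the transformation runs in reverse: rather than renaming a PyTorch key, the loop builds a map from the converted name back to the parametrized state-dict key, then redirects `flax_key` through it before the existence and shape checks. A hedged sketch with a hypothetical `pt_model_dict` (tensor values elided):

    # A converted Flax key such as "conv.weight_g" must be resolved back to
    # the parametrized PyTorch state-dict name before it can be matched
    # against pt_model_dict.
    pt_model_dict = {
        "conv.parametrizations.weight.original0": None,
        "conv.parametrizations.weight.original1": None,
        "conv.bias": None,
    }

    special_pt_names = {}
    for key in pt_model_dict:
        key_components = key.split(".")
        name = None
        if key_components[-3::2] == ["parametrizations", "original0"]:
            name = key_components[-2] + "_g"
        elif key_components[-3::2] == ["parametrizations", "original1"]:
            name = key_components[-2] + "_v"
        if name is not None:
            special_pt_names[".".join(key_components[:-3] + [name])] = key

    print(special_pt_names)
    # {'conv.weight_g': 'conv.parametrizations.weight.original0',
    #  'conv.weight_v': 'conv.parametrizations.weight.original1'}

    flax_key = "conv.weight_g"
    flax_key = special_pt_names.get(flax_key, flax_key)
    # -> 'conv.parametrizations.weight.original0', which is in pt_model_dict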