diff --git a/src/transformers/modeling_flax_pytorch_utils.py b/src/transformers/modeling_flax_pytorch_utils.py
index 5de6a985df41..21d312647f9a 100644
--- a/src/transformers/modeling_flax_pytorch_utils.py
+++ b/src/transformers/modeling_flax_pytorch_utils.py
@@ -120,6 +120,16 @@ def is_key_or_prefix_key_in_dict(key: Tuple[str]) -> bool:
     if pt_tuple_key[-1] == "beta":
         return renamed_pt_tuple_key, pt_tensor
 
+    # New `weight_norm` from https://github.com/huggingface/transformers/pull/24030
+    name = None
+    if pt_tuple_key[-3::2] == ("parametrizations", "original0"):
+        name = pt_tuple_key[-2] + "_g"
+    elif pt_tuple_key[-3::2] == ("parametrizations", "original1"):
+        name = pt_tuple_key[-2] + "_v"
+    if name is not None:
+        renamed_pt_tuple_key = pt_tuple_key[:-3] + (name,)
+        return renamed_pt_tuple_key, pt_tensor
+
     return pt_tuple_key, pt_tensor
 
 
@@ -372,6 +382,24 @@ def load_flax_weights_in_pytorch_model(pt_model, flax_state):
         else:
             flax_key = ".".join(flax_key_tuple)
 
+        # We also need to look at `pt_model_dict` and see if there are keys requiring further transformation.
+        special_pt_names = {}
+        # New `weight_norm` from https://github.com/huggingface/transformers/pull/24030
+        for key in pt_model_dict:
+            key_components = key.split(".")
+            name = None
+            if key_components[-3::2] == ["parametrizations", "original0"]:
+                name = key_components[-2] + "_g"
+            elif key_components[-3::2] == ["parametrizations", "original1"]:
+                name = key_components[-2] + "_v"
+            if name is not None:
+                key_components = key_components[:-3] + [name]
+                key_to_check = ".".join(key_components)
+                special_pt_names[key_to_check] = key
+
+        if flax_key in special_pt_names:
+            flax_key = special_pt_names[flax_key]
+
         if flax_key in pt_model_dict:
             if flax_tensor.shape != pt_model_dict[flax_key].shape:
                 raise ValueError(
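
For context, here is a minimal sketch of why the `[-3::2]` slice identifies the new-style weight-norm keys. This is illustration only, not part of the patch; it assumes a PyTorch version that ships `torch.nn.utils.parametrizations.weight_norm` (2.1+), and the `conv` prefix below is a hypothetical module name:

```python
import torch.nn as nn
from torch.nn.utils.parametrizations import weight_norm

# The parametrization-based weight_norm stores the magnitude/direction
# factors under `parametrizations.<name>.original0` / `original1` instead
# of the old `<name>_g` / `<name>_v` suffixes used by the legacy
# torch.nn.utils.weight_norm.
layer = weight_norm(nn.Conv1d(2, 2, kernel_size=3), name="weight")
print(sorted(layer.state_dict().keys()))
# ['bias', 'parametrizations.weight.original0', 'parametrizations.weight.original1']

# Slicing the dotted key components with [-3::2] picks out the
# ('parametrizations', 'original0') pair while skipping the parameter name
# between them, so the key can be rewritten to the old-style suffix that
# the Flax <-> PyTorch key matching expects.
key_components = "conv.parametrizations.weight.original0".split(".")
assert key_components[-3::2] == ["parametrizations", "original0"]
renamed = key_components[:-3] + [key_components[-2] + "_g"]
assert ".".join(renamed) == "conv.weight_g"
```

The same slice-and-rename step appears twice in the patch: once on key tuples when converting PyTorch weights to Flax, and once on dotted key strings when loading Flax weights into a PyTorch model.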