diff --git a/src/transformers/modeling_flax_pytorch_utils.py b/src/transformers/modeling_flax_pytorch_utils.py
index a91d41b9d6d91b..76eaa53f89d04c 100644
--- a/src/transformers/modeling_flax_pytorch_utils.py
+++ b/src/transformers/modeling_flax_pytorch_utils.py
@@ -38,7 +38,9 @@
 #####################
 
 
-def load_pytorch_checkpoint_in_flax_state_dict(flax_model, pytorch_checkpoint_path, allow_missing_keys=False):
+def load_pytorch_checkpoint_in_flax_state_dict(
+    flax_model, pytorch_checkpoint_path, is_sharded, allow_missing_keys=False
+):
     """Load pytorch checkpoints in a flax model"""
     try:
         import torch  # noqa: F401
@@ -50,14 +52,17 @@ def load_pytorch_checkpoint_in_flax_state_dict(flax_model, pytorch_checkpoint_pa
         )
         raise
 
-    pt_path = os.path.abspath(pytorch_checkpoint_path)
-    logger.info(f"Loading PyTorch weights from {pt_path}")
+    if not is_sharded:
+        pt_path = os.path.abspath(pytorch_checkpoint_path)
+        logger.info(f"Loading PyTorch weights from {pt_path}")
 
-    pt_state_dict = torch.load(pt_path, map_location="cpu")
-    logger.info(f"PyTorch checkpoint contains {sum(t.numel() for t in pt_state_dict.values()):,} parameters.")
-
-    flax_state_dict = convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model)
+        pt_state_dict = torch.load(pt_path, map_location="cpu")
+        logger.info(f"PyTorch checkpoint contains {sum(t.numel() for t in pt_state_dict.values()):,} parameters.")
+        flax_state_dict = convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model)
+    else:
+        # model is sharded and pytorch_checkpoint_path already contains the list of .pt shard files
+        flax_state_dict = convert_pytorch_sharded_state_dict_to_flax(pytorch_checkpoint_path, flax_model)
 
     return flax_state_dict
 
 
@@ -156,6 +161,61 @@ def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model):
     return unflatten_dict(flax_state_dict)
 
 
+############################
+# Sharded Pytorch => Flax #
+############################
+
+
+def convert_pytorch_sharded_state_dict_to_flax(shard_filenames, flax_model):
+    import torch
+
+    # convert each shard and merge it into a single flax state dict
+    flax_state_dict = {}
+    for shard_file in shard_filenames:
+        # load the pytorch shard and convert its tensors to numpy
+        pt_state_dict = torch.load(shard_file)
+        pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()}
+
+        model_prefix = flax_model.base_model_prefix
+        random_flax_state_dict = flatten_dict(flax_model.params)
+
+        load_model_with_head_into_base_model = (model_prefix not in flax_model.params) and (
+            model_prefix in set([k.split(".")[0] for k in pt_state_dict.keys()])
+        )
+        load_base_model_into_model_with_head = (model_prefix in flax_model.params) and (
+            model_prefix not in set([k.split(".")[0] for k in pt_state_dict.keys()])
+        )
+        # Need to change some parameter names to match Flax names
+        for pt_key, pt_tensor in pt_state_dict.items():
+
+            pt_tuple_key = tuple(pt_key.split("."))
+
+            # remove base model prefix if necessary
+            has_base_model_prefix = pt_tuple_key[0] == model_prefix
+            if load_model_with_head_into_base_model and has_base_model_prefix:
+                pt_tuple_key = pt_tuple_key[1:]
+
+            # Correctly rename weight parameters
+            flax_key, flax_tensor = rename_key_and_reshape_tensor(
+                pt_tuple_key, pt_tensor, random_flax_state_dict, model_prefix
+            )
+            # add model prefix if necessary
+            require_base_model_prefix = (model_prefix,) + flax_key in random_flax_state_dict
+            if load_base_model_into_model_with_head and require_base_model_prefix:
+                flax_key = (model_prefix,) + flax_key
+
+            if flax_key in random_flax_state_dict:
+                if flax_tensor.shape != random_flax_state_dict[flax_key].shape:
+                    raise ValueError(
+                        f"PyTorch checkpoint seems to be incorrect. Weight {pt_key} was expected to be of shape "
+                        f"{random_flax_state_dict[flax_key].shape}, but is {flax_tensor.shape}."
+                    )
+
+            # also add unexpected weight so that warning is thrown
+            flax_state_dict[flax_key] = jnp.asarray(flax_tensor)
+    return unflatten_dict(flax_state_dict)
+
+
 #####################
 # Flax => PyTorch #
 #####################
diff --git a/src/transformers/modeling_flax_utils.py b/src/transformers/modeling_flax_utils.py
index 683e25631c0f44..00bb5480ffe3e9 100644
--- a/src/transformers/modeling_flax_utils.py
+++ b/src/transformers/modeling_flax_utils.py
@@ -40,6 +40,7 @@
 from .utils import (
     FLAX_WEIGHTS_INDEX_NAME,
     FLAX_WEIGHTS_NAME,
+    WEIGHTS_INDEX_NAME,
     WEIGHTS_NAME,
     PushToHubMixin,
     add_code_sample_docstrings,
@@ -650,6 +651,10 @@ def from_pretrained(
             if from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
                 # Load from a PyTorch checkpoint
                 archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
+            elif from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME)):
+                # Load from a sharded pytorch checkpoint
+                archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME)
+                is_sharded = True
             elif os.path.isfile(os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)):
                 # Load from a Flax checkpoint
                 archive_file = os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)
@@ -700,6 +705,13 @@
                     )
                     if resolved_archive_file is not None:
                         is_sharded = True
+                # Maybe the checkpoint is pytorch sharded, so we try to grab the pytorch index name in this case.
+                elif resolved_archive_file is None and from_pt:
+                    resolved_archive_file = cached_file(
+                        pretrained_model_name_or_path, WEIGHTS_INDEX_NAME, **cached_file_kwargs
+                    )
+                    if resolved_archive_file is not None:
+                        is_sharded = True
                 if resolved_archive_file is None:
                     # Otherwise, maybe there is a TF or Flax model file. We try those to give a helpful error
                     # message.
@@ -714,6 +726,12 @@
                             f" {FLAX_WEIGHTS_NAME} but there is a file for PyTorch weights. Use `from_pt=True` to"
                             " load this model from those weights."
                         )
+                    elif has_file(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME, **has_file_kwargs):
+                        raise EnvironmentError(
+                            f"{pretrained_model_name_or_path} does not appear to have a file named"
+                            f" {FLAX_WEIGHTS_INDEX_NAME} but there is a sharded file for PyTorch weights. Use"
+                            " `from_pt=True` to load this model from those weights."
+                        )
                     else:
                         raise EnvironmentError(
                             f"{pretrained_model_name_or_path} does not appear to have a file named"
@@ -761,7 +779,7 @@
         model = cls(config, *model_args, _do_init=_do_init, **model_kwargs)
 
         if from_pt:
-            state = load_pytorch_checkpoint_in_flax_state_dict(model, resolved_archive_file)
+            state = load_pytorch_checkpoint_in_flax_state_dict(model, resolved_archive_file, is_sharded)
         else:
 
             if is_sharded:
diff --git a/tests/test_modeling_flax_common.py b/tests/test_modeling_flax_common.py
index e22c7e6705b3bd..837f874889ae7d 100644
--- a/tests/test_modeling_flax_common.py
+++ b/tests/test_modeling_flax_common.py
@@ -1099,6 +1099,14 @@ def test_checkpoint_sharding_local(self):
         for p1, p2 in zip(flatten_dict(model.params).values(), flatten_dict(new_model.params).values()):
             self.assertTrue(np.allclose(np.array(p1), np.array(p2)))
 
+    @is_pt_flax_cross_test
+    def test_from_sharded_pt(self):
+        model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-random-bert-sharded", from_pt=True)
+        ref_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-random-bert-fx-only")
+        for key, ref_val in flatten_dict(ref_model.params).items():
+            val = flatten_dict(model.params)[key]
+            assert np.allclose(np.array(val), np.array(ref_val))
+
     def test_gradient_checkpointing(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
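
Usage sketch (not part of the patch): the round trip below exercises the new code path end to end, assuming a transformers build that includes this change. The local directory name is an arbitrary choice, and `max_shard_size` is set artificially low so that `save_pretrained` writes a `pytorch_model.bin.index.json` plus several shard files, which `from_pt=True` then picks up through the sharded branch above.

import numpy as np
from flax.traverse_util import flatten_dict
from transformers import BertModel, FlaxBertModel

save_dir = "./tiny-bert-pt-sharded"  # hypothetical local path

# Save a PyTorch checkpoint in shards; the small max_shard_size forces an index
# file (pytorch_model.bin.index.json) plus several pytorch_model-*.bin shards.
pt_model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
pt_model.save_pretrained(save_dir, max_shard_size="50kB")

# from_pt=True now detects the PyTorch index file, sets is_sharded=True and
# converts each shard via convert_pytorch_sharded_state_dict_to_flax.
fx_model = FlaxBertModel.from_pretrained(save_dir, from_pt=True)

# Sanity check against the existing non-sharded PyTorch -> Flax path.
ref_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-random-bert", from_pt=True)
for key, ref_val in flatten_dict(ref_model.params).items():
    val = flatten_dict(fx_model.params)[key]
    assert np.allclose(np.array(val), np.array(ref_val))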