From 0a0d875711b5f50136e97eb8437a0e4162094b70 Mon Sep 17 00:00:00 2001
From: Arthur
Date: Tue, 5 Jul 2022 15:03:26 +0200
Subject: [PATCH 01/16] initial commit

---
 .../modeling_flax_pytorch_utils.py       | 72 ++++++++++++++++---
 src/transformers/modeling_flax_utils.py  | 17 ++++-
 src/transformers/utils/hub.py            |  5 ++
 3 files changed, 84 insertions(+), 10 deletions(-)

diff --git a/src/transformers/modeling_flax_pytorch_utils.py b/src/transformers/modeling_flax_pytorch_utils.py
index a91d41b9d6d91b..7120fad64103fe 100644
--- a/src/transformers/modeling_flax_pytorch_utils.py
+++ b/src/transformers/modeling_flax_pytorch_utils.py
@@ -27,7 +27,8 @@
 from flax.serialization import from_bytes
 from flax.traverse_util import flatten_dict, unflatten_dict
 
-from .utils import logging
+from .utils import FLAX_WEIGHTS_NAME, WEIGHTS_INDEX_NAME, logging
+from .utils.hub import get_checkpoint_shard_files
 
 
 logger = logging.get_logger(__name__)
@@ -38,7 +39,9 @@
 #####################
 
 
-def load_pytorch_checkpoint_in_flax_state_dict(flax_model, pytorch_checkpoint_path, allow_missing_keys=False):
+def load_pytorch_checkpoint_in_flax_state_dict(
+    flax_model, pytorch_checkpoint_path, is_sharded, allow_missing_keys=False
+):
     """Load pytorch checkpoints in a flax model"""
     try:
         import torch  # noqa: F401
@@ -50,14 +53,17 @@ def load_pytorch_checkpoint_in_flax_state_dict(flax_model, pytorch_checkpoint_pa
         )
         raise
 
-    pt_path = os.path.abspath(pytorch_checkpoint_path)
-    logger.info(f"Loading PyTorch weights from {pt_path}")
+    if not is_sharded:
+        pt_path = os.path.abspath(pytorch_checkpoint_path)
+        logger.info(f"Loading PyTorch weights from {pt_path}")
 
-    pt_state_dict = torch.load(pt_path, map_location="cpu")
-    logger.info(f"PyTorch checkpoint contains {sum(t.numel() for t in pt_state_dict.values()):,} parameters.")
-
-    flax_state_dict = convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model)
+        pt_state_dict = torch.load(pt_path, map_location="cpu")
+        logger.info(f"PyTorch checkpoint contains {sum(t.numel() for t in pt_state_dict.values()):,} parameters.")
+        flax_state_dict = convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model)
+    else:
+        # model is sharded and pytorch_checkpoint_path already contains the list of shard files
+        flax_state_dict = convert_pytorch_sharded_state_dict_to_flax(pytorch_checkpoint_path, flax_model)
 
     return flax_state_dict
 
@@ -156,6 +162,56 @@ def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model):
     return unflatten_dict(flax_state_dict)
 
 
+def convert_pytorch_sharded_state_dict_to_flax(shard_filenames, flax_model):
+    import torch
+
+    # Load the index
+    flax_state_dict = {}
+    for shard_file in shard_filenames:
+        # load using msgpack utils
+        pt_state_dict = torch.load(shard_file)
+        pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()}
+
+        model_prefix = flax_model.base_model_prefix
+        random_flax_state_dict = flatten_dict(flax_model.params)
+
+        load_model_with_head_into_base_model = (model_prefix not in flax_model.params) and (
+            model_prefix in set([k.split(".")[0] for k in pt_state_dict.keys()])
+        )
+        load_base_model_into_model_with_head = (model_prefix in flax_model.params) and (
+            model_prefix not in set([k.split(".")[0] for k in pt_state_dict.keys()])
+        )
+        # Need to change some parameters name to match Flax names
+        for pt_key, pt_tensor in pt_state_dict.items():
+
+            pt_tuple_key = tuple(pt_key.split("."))
+
+            # remove base model prefix if necessary
+            has_base_model_prefix = pt_tuple_key[0] == model_prefix
+            if load_model_with_head_into_base_model and has_base_model_prefix:
+                pt_tuple_key = pt_tuple_key[1:]
+
+            # Correctly rename weight parameters
+            flax_key, flax_tensor = rename_key_and_reshape_tensor(
+                pt_tuple_key, pt_tensor, random_flax_state_dict, model_prefix
+            )
+            # add model prefix if necessary
+            require_base_model_prefix = (model_prefix,) + flax_key in random_flax_state_dict
+            if load_base_model_into_model_with_head and require_base_model_prefix:
+                flax_key = (model_prefix,) + flax_key
+
+            if flax_key in random_flax_state_dict:
+                if flax_tensor.shape != random_flax_state_dict[flax_key].shape:
+                    raise ValueError(
+                        f"PyTorch checkpoint seems to be incorrect. Weight {pt_key} was expected to be of shape "
+                        f"{random_flax_state_dict[flax_key].shape}, but is {flax_tensor.shape}."
+                    )
+
+            # also add unexpected weight so that warning is thrown
+            flax_state_dict[flax_key] = jnp.asarray(flax_tensor)
+    return unflatten_dict(flax_state_dict)
+
+
 #####################
 # Flax => PyTorch #
 #####################
diff --git a/src/transformers/modeling_flax_utils.py b/src/transformers/modeling_flax_utils.py
index 77eaa900de622f..2aa69691abef60 100644
--- a/src/transformers/modeling_flax_utils.py
+++ b/src/transformers/modeling_flax_utils.py
@@ -42,6 +42,7 @@
     FLAX_WEIGHTS_INDEX_NAME,
     FLAX_WEIGHTS_NAME,
     HUGGINGFACE_CO_RESOLVE_ENDPOINT,
+    WEIGHTS_INDEX_NAME,
     WEIGHTS_NAME,
     EntryNotFoundError,
     PushToHubMixin,
@@ -58,7 +59,7 @@
     logging,
     replace_return_docstrings,
 )
-from .utils.hub import convert_file_size_to_int, get_checkpoint_shard_files
+from .utils.hub import convert_file_size_to_int, get_checkpoint_shard_files, hf_url_exists
 
 
 logger = logging.get_logger(__name__)
@@ -639,6 +640,10 @@ def from_pretrained(
                 if from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
                     # Load from a PyTorch checkpoint
                     archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
+                elif from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME)):
+                    # Load from a sharded pytorch checkpoint
+                    archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME)
+                    is_sharded = True
                 elif os.path.isfile(os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)):
                     # Load from a Flax checkpoint
                     archive_file = os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)
@@ -660,6 +665,14 @@ def from_pretrained(
                 )
             elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
                 archive_file = pretrained_model_name_or_path
+            # check if an index file exists
+            elif hf_url_exists(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME):
+                archive_file = hf_bucket_url(
+                    pretrained_model_name_or_path,
+                    filename=WEIGHTS_INDEX_NAME,
+                    revision=revision,
+                )
+                is_sharded = True
             else:
                 filename = WEIGHTS_NAME if from_pt else FLAX_WEIGHTS_NAME
                 archive_file = hf_bucket_url(
@@ -780,7 +793,7 @@ def from_pretrained(
             model = cls(config, *model_args, _do_init=_do_init, **model_kwargs)
 
         if from_pt:
-            state = load_pytorch_checkpoint_in_flax_state_dict(model, resolved_archive_file)
+            state = load_pytorch_checkpoint_in_flax_state_dict(model, resolved_archive_file, is_sharded)
         else:
 
             if is_sharded:
diff --git a/src/transformers/utils/hub.py b/src/transformers/utils/hub.py
index 4e46298e28a920..91d7ea6f65b9b8 100644
--- a/src/transformers/utils/hub.py
+++ b/src/transformers/utils/hub.py
@@ -118,6 +118,11 @@ def is_remote_url(url_or_filename):
     return parsed.scheme in ("http", "https")
 
 
+def hf_url_exists(path, file):
+    r = requests.head(os.path.join("https://huggingface.co/", path, "raw/main/", file))
+    return r.status_code == requests.codes.ok
+
+
 def hf_bucket_url(
     model_id: str, filename: str, subfolder: Optional[str] = None, revision: Optional[str] = None, mirror=None
 ) -> str:

From 4dcc32ec0e0b9a20dcc72582225bf6b9f0426eb1 Mon Sep 17 00:00:00 2001
From: Arthur
Date: Tue, 5 Jul 2022 15:13:10 +0200
Subject: [PATCH 02/16] add small test

---
 tests/test_modeling_flax_common.py | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/tests/test_modeling_flax_common.py b/tests/test_modeling_flax_common.py
index f90615efea3604..b4506c4438a9b0 100644
--- a/tests/test_modeling_flax_common.py
+++ b/tests/test_modeling_flax_common.py
@@ -1099,6 +1099,12 @@ def test_checkpoint_sharding_local(self):
         for p1, p2 in zip(flatten_dict(model.params).values(), flatten_dict(new_model.params).values()):
             self.assertTrue(np.allclose(np.array(p1), np.array(p2)))
 
+    def test_from_sharded_pt(self):
+        model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-random-bert-sharded", from_pt=True)
+        ref_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-flax-only")
+        for p1, p2 in zip(flatten_dict(model.params).values(), flatten_dict(ref_model.params).values()):
+            assert np.allclose(np.array(p1), np.array(p2))
+
     def test_gradient_checkpointing(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

From 556fa9a3921cc0a4d3e14072b63f75bc94a60393 Mon Sep 17 00:00:00 2001
From: Arthur
Date: Tue, 5 Jul 2022 17:38:47 +0200
Subject: [PATCH 03/16] add cross pt tf flag to test

---
 tests/test_modeling_flax_common.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/tests/test_modeling_flax_common.py b/tests/test_modeling_flax_common.py
index b4506c4438a9b0..f33e0fb5d9e71e 100644
--- a/tests/test_modeling_flax_common.py
+++ b/tests/test_modeling_flax_common.py
@@ -1099,6 +1099,7 @@ def test_checkpoint_sharding_local(self):
         for p1, p2 in zip(flatten_dict(model.params).values(), flatten_dict(new_model.params).values()):
             self.assertTrue(np.allclose(np.array(p1), np.array(p2)))
 
+    @is_pt_flax_cross_test
     def test_from_sharded_pt(self):
         model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-random-bert-sharded", from_pt=True)
         ref_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-flax-only")

From 2e8b241c5eb3eaea711de1e1084164fc8974b71d Mon Sep 17 00:00:00 2001
From: Arthur
Date: Tue, 5 Jul 2022 17:41:12 +0200
Subject: [PATCH 04/16] fix quality

---
 src/transformers/modeling_flax_pytorch_utils.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/transformers/modeling_flax_pytorch_utils.py b/src/transformers/modeling_flax_pytorch_utils.py
index 7120fad64103fe..dadc98fc9cdaa1 100644
--- a/src/transformers/modeling_flax_pytorch_utils.py
+++ b/src/transformers/modeling_flax_pytorch_utils.py
@@ -27,8 +27,8 @@
 from flax.serialization import from_bytes
 from flax.traverse_util import flatten_dict, unflatten_dict
 
-from .utils import FLAX_WEIGHTS_NAME, WEIGHTS_INDEX_NAME, logging
-from .utils.hub import get_checkpoint_shard_files
+from .utils import logging
+
 
 logger = logging.get_logger(__name__)

From cfda646e93059b0f791a9835009475f9ed7956c3 Mon Sep 17 00:00:00 2001
From: Arthur
Date: Tue, 5 Jul 2022 17:41:47 +0200
Subject: [PATCH 05/16] style

---
 src/transformers/modeling_flax_pytorch_utils.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/src/transformers/modeling_flax_pytorch_utils.py b/src/transformers/modeling_flax_pytorch_utils.py
index dadc98fc9cdaa1..249f754db726ab 100644
--- a/src/transformers/modeling_flax_pytorch_utils.py
+++ b/src/transformers/modeling_flax_pytorch_utils.py
@@ -30,7 +30,6 @@
 from .utils import logging
 
-
 logger = logging.get_logger(__name__)

From 4e49a48d225e5d0010c9a074cdbc1e93fffb93da Mon Sep 17 00:00:00 2001
From: Arthur
Date: Tue, 5 Jul 2022 22:31:44 +0200
Subject: [PATCH 06/16] update test with new repo

---
 tests/test_modeling_flax_common.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/test_modeling_flax_common.py b/tests/test_modeling_flax_common.py
index f33e0fb5d9e71e..8e9b27c515307b 100644
--- a/tests/test_modeling_flax_common.py
+++ b/tests/test_modeling_flax_common.py
@@ -1102,7 +1102,7 @@ def test_checkpoint_sharding_local(self):
     @is_pt_flax_cross_test
     def test_from_sharded_pt(self):
         model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-random-bert-sharded", from_pt=True)
-        ref_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-flax-only")
+        ref_model = FlaxBertModel.from_pretrained("ArthurZ/tiny-random-bert-flax-only")
         for p1, p2 in zip(flatten_dict(model.params).values(), flatten_dict(ref_model.params).values()):
             assert np.allclose(np.array(p1), np.array(p2))

From 74a451b8b333989c797198d72a98f098ccdd7e47 Mon Sep 17 00:00:00 2001
From: Arthur
Date: Tue, 2 Aug 2022 09:44:44 +0200
Subject: [PATCH 07/16] fix failing test

---
 tests/test_modeling_flax_common.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/test_modeling_flax_common.py b/tests/test_modeling_flax_common.py
index 8e9b27c515307b..3f0a8c6c80792f 100644
--- a/tests/test_modeling_flax_common.py
+++ b/tests/test_modeling_flax_common.py
@@ -1102,7 +1102,7 @@ def test_checkpoint_sharding_local(self):
     @is_pt_flax_cross_test
     def test_from_sharded_pt(self):
         model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-random-bert-sharded", from_pt=True)
-        ref_model = FlaxBertModel.from_pretrained("ArthurZ/tiny-random-bert-flax-only")
+        ref_model = FlaxBertModel.from_pretrained("ArthurZ/tiny-random-bert-fx-only")
         for p1, p2 in zip(flatten_dict(model.params).values(), flatten_dict(ref_model.params).values()):
             assert np.allclose(np.array(p1), np.array(p2))

From 708bcb6f4de4eb1af405f08cf043c12e1cb9788d Mon Sep 17 00:00:00 2001
From: Arthur
Date: Tue, 2 Aug 2022 09:45:31 +0200
Subject: [PATCH 08/16] update

---
 tests/test_modeling_flax_common.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/test_modeling_flax_common.py b/tests/test_modeling_flax_common.py
index 3f0a8c6c80792f..5bc7d97d46bcb2 100644
--- a/tests/test_modeling_flax_common.py
+++ b/tests/test_modeling_flax_common.py
@@ -1102,7 +1102,7 @@ def test_checkpoint_sharding_local(self):
     @is_pt_flax_cross_test
     def test_from_sharded_pt(self):
         model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-random-bert-sharded", from_pt=True)
-        ref_model = FlaxBertModel.from_pretrained("ArthurZ/tiny-random-bert-fx-only")
+        ref_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-random-bert-fx-only")
         for p1, p2 in zip(flatten_dict(model.params).values(), flatten_dict(ref_model.params).values()):
             assert np.allclose(np.array(p1), np.array(p2))

From c49bb18d588e21c4e18e68030396e3e0cb384f78 Mon Sep 17 00:00:00 2001
From: Arthur
Date: Tue, 2 Aug 2022 10:07:15 +0200
Subject: [PATCH 09/16] fix wrong param ordering

---
 tests/test_modeling_flax_common.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/tests/test_modeling_flax_common.py b/tests/test_modeling_flax_common.py
index 5e9c061bc5b7cb..49505856954c3e 100644
--- a/tests/test_modeling_flax_common.py
+++ b/tests/test_modeling_flax_common.py
@@ -1103,8 +1103,9 @@ def test_checkpoint_sharding_local(self):
     def test_from_sharded_pt(self):
         model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-random-bert-sharded", from_pt=True)
         ref_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-random-bert-fx-only")
-        for p1, p2 in zip(flatten_dict(model.params).values(), flatten_dict(ref_model.params).values()):
-            assert np.allclose(np.array(p1), np.array(p2))
+        for key,ref_val in flatten_dict(ref_model.params).items():
+            val = flatten_dict(model.params)[key]
+            assert np.allclose(np.array(val), np.array(ref_val))
 
     def test_gradient_checkpointing(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

From e20f971f6c83a770dea9ff17ec597587891e8690 Mon Sep 17 00:00:00 2001
From: Arthur
Date: Tue, 2 Aug 2022 10:07:37 +0200
Subject: [PATCH 10/16] style

---
 tests/test_modeling_flax_common.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/test_modeling_flax_common.py b/tests/test_modeling_flax_common.py
index 49505856954c3e..837f874889ae7d 100644
--- a/tests/test_modeling_flax_common.py
+++ b/tests/test_modeling_flax_common.py
@@ -1103,7 +1103,7 @@ def test_checkpoint_sharding_local(self):
     def test_from_sharded_pt(self):
         model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-random-bert-sharded", from_pt=True)
         ref_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-random-bert-fx-only")
-        for key,ref_val in flatten_dict(ref_model.params).items():
+        for key, ref_val in flatten_dict(ref_model.params).items():
             val = flatten_dict(model.params)[key]
             assert np.allclose(np.array(val), np.array(ref_val))

From 0e037a446148bc620881b8565bf87f20fe14dc93 Mon Sep 17 00:00:00 2001
From: Arthur
Date: Thu, 11 Aug 2022 12:28:23 +0200
Subject: [PATCH 11/16] update based on review

---
 .../modeling_flax_pytorch_utils.py       |  6 +++-
 src/transformers/modeling_flax_utils.py  | 35 ++++++++++++++-----
 2 files changed, 31 insertions(+), 10 deletions(-)

diff --git a/src/transformers/modeling_flax_pytorch_utils.py b/src/transformers/modeling_flax_pytorch_utils.py
index 249f754db726ab..5aab79b8a8a8a8 100644
--- a/src/transformers/modeling_flax_pytorch_utils.py
+++ b/src/transformers/modeling_flax_pytorch_utils.py
@@ -61,7 +61,7 @@ def load_pytorch_checkpoint_in_flax_state_dict(
 
         flax_state_dict = convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model)
     else:
-        # model is sharded and pytorch_checkpoint_path already contains the list of shard files
+        # model is sharded and pytorch_checkpoint_path already contains the list of .pt shard files
         flax_state_dict = convert_pytorch_sharded_state_dict_to_flax(pytorch_checkpoint_path, flax_model)
 
     return flax_state_dict
@@ -160,6 +160,10 @@ def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model):
     return unflatten_dict(flax_state_dict)
 
 
+############################
+# Sharded Pytorch => Flax #
+############################
+
 def convert_pytorch_sharded_state_dict_to_flax(shard_filenames, flax_model):
     import torch
 
diff --git a/src/transformers/modeling_flax_utils.py b/src/transformers/modeling_flax_utils.py
index cae2d072dd24db..04d547a172392e 100644
--- a/src/transformers/modeling_flax_utils.py
+++ b/src/transformers/modeling_flax_utils.py
@@ -665,14 +665,6 @@ def from_pretrained(
                 )
             elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
                 archive_file = pretrained_model_name_or_path
-            # check if an index file exists
-            elif hf_url_exists(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME):
-                archive_file = hf_bucket_url(
-                    pretrained_model_name_or_path,
-                    filename=WEIGHTS_INDEX_NAME,
-                    revision=revision,
-                )
-                is_sharded = True
             else:
                 filename = WEIGHTS_NAME if from_pt else FLAX_WEIGHTS_NAME
                 archive_file = hf_bucket_url(
@@ -730,7 +722,32 @@ def from_pretrained(
                     is_sharded = True
             except EntryNotFoundError:
                 has_file_kwargs = {"revision": revision, "proxies": proxies, "use_auth_token": use_auth_token}
-                if has_file(pretrained_model_name_or_path, WEIGHTS_NAME, **has_file_kwargs):
+                # check if an index file exists
+                if has_file(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME, **has_file_kwargs):
+                    if from_pt :
+                        archive_file = hf_bucket_url(
+                            pretrained_model_name_or_path,
+                            filename=WEIGHTS_INDEX_NAME,
+                            revision=revision,
+                        )
+                        resolved_archive_file = cached_path(
+                            archive_file,
+                            cache_dir=cache_dir,
+                            force_download=force_download,
+                            proxies=proxies,
+                            resume_download=resume_download,
+                            local_files_only=local_files_only,
+                            use_auth_token=use_auth_token,
+                            user_agent=user_agent,
+                        )
+                        is_sharded = True
+                    else:
+                        raise EnvironmentError(
+                            f"{pretrained_model_name_or_path} does not appear to have a file named"
+                            f" {FLAX_WEIGHTS_INDEX_NAME} but there is a sharded file for PyTorch weights. Use `from_pt=True` to"
+                            " load this model from those weights."
+                        )
+                elif has_file(pretrained_model_name_or_path, WEIGHTS_NAME, **has_file_kwargs):
                     raise EnvironmentError(
                         f"{pretrained_model_name_or_path} does not appear to have a file named"
                         f" {FLAX_WEIGHTS_NAME} but there is a file for PyTorch weights. Use `from_pt=True` to"
                         " load this model from those weights."
                     )

From 2847af18ee75502731b8552230595eb1d394c751 Mon Sep 17 00:00:00 2001
From: Arthur
Date: Thu, 11 Aug 2022 14:09:27 +0200
Subject: [PATCH 12/16] update related to recent new caching mechanism

---
 .../modeling_flax_pytorch_utils.py       |  1 +
 src/transformers/modeling_flax_utils.py  | 27 ++++++-------------
 2 files changed, 9 insertions(+), 19 deletions(-)

diff --git a/src/transformers/modeling_flax_pytorch_utils.py b/src/transformers/modeling_flax_pytorch_utils.py
index 5aab79b8a8a8a8..76eaa53f89d04c 100644
--- a/src/transformers/modeling_flax_pytorch_utils.py
+++ b/src/transformers/modeling_flax_pytorch_utils.py
@@ -160,6 +160,7 @@ def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model):
     return unflatten_dict(flax_state_dict)
 
 
+
 ############################
 # Sharded Pytorch => Flax #
 ############################
diff --git a/src/transformers/modeling_flax_utils.py b/src/transformers/modeling_flax_utils.py
index add14a8c992625..5fdbf7aaaf5e17 100644
--- a/src/transformers/modeling_flax_utils.py
+++ b/src/transformers/modeling_flax_utils.py
@@ -715,29 +715,18 @@ def from_pretrained(
                     "use_auth_token": use_auth_token,
                 }
                 # check if an index file exists
-                if has_file(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME, **has_file_kwargs):
-                    if from_pt :
-                        archive_file = hf_bucket_url(
-                            pretrained_model_name_or_path,
-                            filename=WEIGHTS_INDEX_NAME,
-                            revision=revision,
-                        )
-                        resolved_archive_file = cached_path(
-                            archive_file,
-                            cache_dir=cache_dir,
-                            force_download=force_download,
-                            proxies=proxies,
-                            resume_download=resume_download,
-                            local_files_only=local_files_only,
-                            use_auth_token=use_auth_token,
-                            user_agent=user_agent,
+                if has_file(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME, **has_file_kwargs):
+                    if from_pt:
+                        resolved_archive_file = cached_file(
+                            pretrained_model_name_or_path, WEIGHTS_INDEX_NAME, **cached_file_kwargs
                         )
+                        if resolved_archive_file is not None:
                             is_sharded = True
+                    else:
                         raise EnvironmentError(
                             f"{pretrained_model_name_or_path} does not appear to have a file named"
                             f" {FLAX_WEIGHTS_INDEX_NAME} but there is a sharded file for PyTorch weights. Use"
                             " `from_pt=True` to load this model from those weights."
                         )
                 elif has_file(pretrained_model_name_or_path, WEIGHTS_NAME, **has_file_kwargs):
                     raise EnvironmentError(
                         f"{pretrained_model_name_or_path} does not appear to have a file named"
                         f" {FLAX_WEIGHTS_NAME} but there is a file for PyTorch weights. Use `from_pt=True` to"
                         " load this model from those weights."
                     )

From 55e8f7ce68f111e63f291692f41ff3639df8efe9 Mon Sep 17 00:00:00 2001
From: Arthur
Date: Thu, 11 Aug 2022 14:16:45 +0200
Subject: [PATCH 13/16] quality

---
 src/transformers/modeling_flax_utils.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/src/transformers/modeling_flax_utils.py b/src/transformers/modeling_flax_utils.py
index 5fdbf7aaaf5e17..eccf3d9cd84dac 100644
--- a/src/transformers/modeling_flax_utils.py
+++ b/src/transformers/modeling_flax_utils.py
@@ -40,7 +40,6 @@
 from .utils import (
     FLAX_WEIGHTS_INDEX_NAME,
     FLAX_WEIGHTS_NAME,
-    HUGGINGFACE_CO_RESOLVE_ENDPOINT,
     WEIGHTS_INDEX_NAME,
     WEIGHTS_NAME,
     PushToHubMixin,

From f8712307867ed8629395f1d33c1e30d522df3926 Mon Sep 17 00:00:00 2001
From: Arthur
Date: Thu, 11 Aug 2022 16:07:34 +0200
Subject: [PATCH 14/16] Update based on review

Co-authored-by: sgugger
---
 src/transformers/modeling_flax_utils.py | 28 ++++++++++++-------------
 1 file changed, 14 insertions(+), 14 deletions(-)

diff --git a/src/transformers/modeling_flax_utils.py b/src/transformers/modeling_flax_utils.py
index eccf3d9cd84dac..6e619423f52b7c 100644
--- a/src/transformers/modeling_flax_utils.py
+++ b/src/transformers/modeling_flax_utils.py
@@ -705,6 +705,13 @@ def from_pretrained(
                 )
                 if resolved_archive_file is not None:
                     is_sharded = True
+                # Maybe the checkpoint is pytorch sharded, we try to grab the pytorch index name in this case.
+            elif resolved_archive_file is None and from_pt:
+                resolved_archive_file = cached_file(
+                    pretrained_model_name_or_path, WEIGHTS_INDEX_NAME, **cached_file_kwargs
+                )
+                if resolved_archive_file is not None:
+                    is_sharded = True
             if resolved_archive_file is None:
                 # Otherwise, maybe there is a TF or Flax model file. We try those to give a helpful error
                 # message.
@@ -714,25 +721,18 @@ def from_pretrained(
                     "use_auth_token": use_auth_token,
                 }
                 # check if an index file exists
-                if has_file(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME, **has_file_kwargs):
-                    if from_pt:
-                        resolved_archive_file = cached_file(
-                            pretrained_model_name_or_path, WEIGHTS_INDEX_NAME, **cached_file_kwargs
-                        )
-                        if resolved_archive_file is not None:
-                            is_sharded = True
-                    else:
-                        raise EnvironmentError(
-                            f"{pretrained_model_name_or_path} does not appear to have a file named"
-                            f" {FLAX_WEIGHTS_INDEX_NAME} but there is a sharded file for PyTorch weights. Use"
-                            " `from_pt=True` to load this model from those weights."
-                        )
-                elif has_file(pretrained_model_name_or_path, WEIGHTS_NAME, **has_file_kwargs):
+                if has_file(pretrained_model_name_or_path, WEIGHTS_NAME, **has_file_kwargs):
                     raise EnvironmentError(
                         f"{pretrained_model_name_or_path} does not appear to have a file named"
                         f" {FLAX_WEIGHTS_NAME} but there is a file for PyTorch weights. Use `from_pt=True` to"
                         " load this model from those weights."
                     )
+                elif has_file(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME, **has_file_kwargs):
+                    raise EnvironmentError(
+                        f"{pretrained_model_name_or_path} does not appear to have a file named"
+                        f" {FLAX_WEIGHTS_INDEX_NAME} but there is a sharded file for PyTorch weights. Use"
Use" + " `from_pt=True` to load this model from those weights." + ) else: raise EnvironmentError( f"{pretrained_model_name_or_path} does not appear to have a file named" From b584707283b8d6c3a13ef56ca37f91b655ed6d84 Mon Sep 17 00:00:00 2001 From: Arthur Date: Thu, 11 Aug 2022 16:16:59 +0200 Subject: [PATCH 15/16] quality and style --- src/transformers/modeling_flax_utils.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/transformers/modeling_flax_utils.py b/src/transformers/modeling_flax_utils.py index 6e619423f52b7c..069499ada273ad 100644 --- a/src/transformers/modeling_flax_utils.py +++ b/src/transformers/modeling_flax_utils.py @@ -706,7 +706,7 @@ def from_pretrained( if resolved_archive_file is not None: is_sharded = True # Maybe the checkpoint is pytorch sharded, we try to grab the pytorch index name in this case. - elif resolved_archive_file is None and from_pt: + elif resolved_archive_file is None and from_pt: resolved_archive_file = cached_file( pretrained_model_name_or_path, WEIGHTS_INDEX_NAME, **cached_file_kwargs ) @@ -720,7 +720,6 @@ def from_pretrained( "proxies": proxies, "use_auth_token": use_auth_token, } - # check if an index file exists if has_file(pretrained_model_name_or_path, WEIGHTS_NAME, **has_file_kwargs): raise EnvironmentError( f"{pretrained_model_name_or_path} does not appear to have a file named" From 095739924174f6806e1dc48b791e29fd151a7111 Mon Sep 17 00:00:00 2001 From: Arthur <48595927+ArthurZucker@users.noreply.github.com> Date: Thu, 11 Aug 2022 16:34:11 +0200 Subject: [PATCH 16/16] Update src/transformers/modeling_flax_utils.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> --- src/transformers/modeling_flax_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/modeling_flax_utils.py b/src/transformers/modeling_flax_utils.py index 069499ada273ad..00bb5480ffe3e9 100644 --- a/src/transformers/modeling_flax_utils.py +++ b/src/transformers/modeling_flax_utils.py @@ -705,7 +705,7 @@ def from_pretrained( ) if resolved_archive_file is not None: is_sharded = True - # Maybe the checkpoint is pytorch sharded, we try to grab the pytorch index name in this case. + # Maybe the checkpoint is pytorch sharded, we try to grab the pytorch index name in this case. elif resolved_archive_file is None and from_pt: resolved_archive_file = cached_file( pretrained_model_name_or_path, WEIGHTS_INDEX_NAME, **cached_file_kwargs