Skip to content

Commit

Permalink
Load sharded pt to flax (huggingface#18419)
Browse files Browse the repository at this point in the history
* initial commit

* add small test

* add cross pt tf flag to test

* fix quality

* style

* update test with new repo

* fix failing test

* update

* fix wrong param ordering

* style

* update based on review

* update related to recent new caching mechanism

* quality

* Update based on review

Co-authored-by: sgugger <sylvain.gugger@gmail.com>

* quality and style

* Update src/transformers/modeling_flax_utils.py
Co-authored-by: sgugger <sylvain.gugger@gmail.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
  • Loading branch information
3 people authored and amyeroberts committed Oct 18, 2022
1 parent 13d28b3 commit df69750
Showing 1 changed file with 8 additions and 0 deletions.
8 changes: 8 additions & 0 deletions src/transformers/modeling_flax_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from .utils import (
FLAX_WEIGHTS_INDEX_NAME,
FLAX_WEIGHTS_NAME,
WEIGHTS_INDEX_NAME,
WEIGHTS_NAME,
PushToHubMixin,
add_code_sample_docstrings,
Expand Down Expand Up @@ -705,6 +706,13 @@ def from_pretrained(
)
if resolved_archive_file is not None:
is_sharded = True
# Maybe the checkpoint is pytorch sharded, we try to grab the pytorch index name in this case.
elif resolved_archive_file is None and from_pt:
resolved_archive_file = cached_file(
pretrained_model_name_or_path, WEIGHTS_INDEX_NAME, **cached_file_kwargs
)
if resolved_archive_file is not None:
is_sharded = True
if resolved_archive_file is None:
# Otherwise, maybe there is a TF or Flax model file. We try those to give a helpful error
# message.
Expand Down

0 comments on commit df69750

Please sign in to comment.