Skip to content

Commit

Permalink
Fix FlaxPretTrainedModel pt weights check (huggingface#19133)
Browse files Browse the repository at this point in the history
* Fix FlaxPretTrainedModel pt weights check

* Update src/transformers/modeling_flax_utils.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* fix raise comment

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
  • Loading branch information
2 people authored and oneraghavan committed Sep 26, 2022
1 parent 1822c90 commit 8241fe0
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion src/transformers/modeling_flax_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -665,7 +665,7 @@ def from_pretrained(
archive_file = os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_INDEX_NAME)
is_sharded = True
# At this stage we don't have a weight file so we will raise an error.
elif os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME):
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
raise EnvironmentError(
f"Error no file named {FLAX_WEIGHTS_NAME} found in directory {pretrained_model_name_or_path} "
"but there is a file for PyTorch weights. Use `from_pt=True` to load this model from those "
Expand Down

0 comments on commit 8241fe0

Please sign in to comment.