Skip to content

Commit 4b8880a

Browse files
author
Mishig Davaadorj
authored
Make flax from_pretrained work with local subfolder (open-mmlab#608)
1 parent dd350c8 commit 4b8880a

File tree

1 file changed

+14
-9
lines changed

1 file changed

+14
-9
lines changed

src/diffusers/modeling_flax_utils.py

+14-9
Original file line numberDiff line numberDiff line change
@@ -310,26 +310,31 @@ def from_pretrained(
310310
)
311311

312312
# Load model
313-
if os.path.isdir(pretrained_model_name_or_path):
313+
pretrained_path_with_subfolder = (
314+
pretrained_model_name_or_path
315+
if subfolder is None
316+
else os.path.join(pretrained_model_name_or_path, subfolder)
317+
)
318+
if os.path.isdir(pretrained_path_with_subfolder):
314319
if from_pt:
315-
if not os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
320+
if not os.path.isfile(os.path.join(pretrained_path_with_subfolder, WEIGHTS_NAME)):
316321
raise EnvironmentError(
317-
f"Error no file named {WEIGHTS_NAME} found in directory {pretrained_model_name_or_path} "
322+
f"Error no file named {WEIGHTS_NAME} found in directory {pretrained_path_with_subfolder} "
318323
)
319-
model_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
320-
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)):
324+
model_file = os.path.join(pretrained_path_with_subfolder, WEIGHTS_NAME)
325+
elif os.path.isfile(os.path.join(pretrained_path_with_subfolder, FLAX_WEIGHTS_NAME)):
321326
# Load from a Flax checkpoint
322-
model_file = os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)
327+
model_file = os.path.join(pretrained_path_with_subfolder, FLAX_WEIGHTS_NAME)
323328
# Check if pytorch weights exist instead
324-
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
329+
elif os.path.isfile(os.path.join(pretrained_path_with_subfolder, WEIGHTS_NAME)):
325330
raise EnvironmentError(
326-
f"{WEIGHTS_NAME} file found in directory {pretrained_model_name_or_path}. Please load the model"
331+
f"{WEIGHTS_NAME} file found in directory {pretrained_path_with_subfolder}. Please load the model"
327332
" using `from_pt=True`."
328333
)
329334
else:
330335
raise EnvironmentError(
331336
f"Error no file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME} found in directory "
332-
f"{pretrained_model_name_or_path}."
337+
f"{pretrained_path_with_subfolder}."
333338
)
334339
else:
335340
try:

0 commit comments

Comments
 (0)