@@ -310,26 +310,31 @@ def from_pretrained(
310
310
)
311
311
312
312
# Load model
313
- if os .path .isdir (pretrained_model_name_or_path ):
313
+ pretrained_path_with_subfolder = (
314
+ pretrained_model_name_or_path
315
+ if subfolder is None
316
+ else os .path .join (pretrained_model_name_or_path , subfolder )
317
+ )
318
+ if os .path .isdir (pretrained_path_with_subfolder ):
314
319
if from_pt :
315
- if not os .path .isfile (os .path .join (pretrained_model_name_or_path , WEIGHTS_NAME )):
320
+ if not os .path .isfile (os .path .join (pretrained_path_with_subfolder , WEIGHTS_NAME )):
316
321
raise EnvironmentError (
317
- f"Error no file named { WEIGHTS_NAME } found in directory { pretrained_model_name_or_path } "
322
+ f"Error no file named { WEIGHTS_NAME } found in directory { pretrained_path_with_subfolder } "
318
323
)
319
- model_file = os .path .join (pretrained_model_name_or_path , WEIGHTS_NAME )
320
- elif os .path .isfile (os .path .join (pretrained_model_name_or_path , FLAX_WEIGHTS_NAME )):
324
+ model_file = os .path .join (pretrained_path_with_subfolder , WEIGHTS_NAME )
325
+ elif os .path .isfile (os .path .join (pretrained_path_with_subfolder , FLAX_WEIGHTS_NAME )):
321
326
# Load from a Flax checkpoint
322
- model_file = os .path .join (pretrained_model_name_or_path , FLAX_WEIGHTS_NAME )
327
+ model_file = os .path .join (pretrained_path_with_subfolder , FLAX_WEIGHTS_NAME )
323
328
# Check if pytorch weights exist instead
324
- elif os .path .isfile (os .path .join (pretrained_model_name_or_path , WEIGHTS_NAME )):
329
+ elif os .path .isfile (os .path .join (pretrained_path_with_subfolder , WEIGHTS_NAME )):
325
330
raise EnvironmentError (
326
- f"{ WEIGHTS_NAME } file found in directory { pretrained_model_name_or_path } . Please load the model"
331
+ f"{ WEIGHTS_NAME } file found in directory { pretrained_path_with_subfolder } . Please load the model"
327
332
" using `from_pt=True`."
328
333
)
329
334
else :
330
335
raise EnvironmentError (
331
336
f"Error no file named { FLAX_WEIGHTS_NAME } or { WEIGHTS_NAME } found in directory "
332
- f"{ pretrained_model_name_or_path } ."
337
+ f"{ pretrained_path_with_subfolder } ."
333
338
)
334
339
else :
335
340
try :
0 commit comments