1010from huggingface_hub import (file_exists , hf_hub_download , list_repo_files ,
1111 try_to_load_from_cache )
1212from huggingface_hub .utils import (EntryNotFoundError , HfHubHTTPError ,
13- LocalEntryNotFoundError ,
13+ HFValidationError , LocalEntryNotFoundError ,
1414 RepositoryNotFoundError ,
1515 RevisionNotFoundError )
1616from torch import nn
@@ -265,49 +265,66 @@ def get_config(
265265 return config
266266
267267
268+ def try_get_local_file (model : Union [str , Path ],
269+ file_name : str ,
270+ revision : Optional [str ] = 'main' ) -> Optional [Path ]:
271+ file_path = Path (model ) / file_name
272+ if file_path .is_file ():
273+ return file_path
274+ else :
275+ try :
276+ cached_filepath = try_to_load_from_cache (repo_id = model ,
277+ filename = file_name ,
278+ revision = revision )
279+ if isinstance (cached_filepath , str ):
280+ return Path (cached_filepath )
281+ except HFValidationError :
282+ ...
283+ return None
284+
285+
268286def get_hf_file_to_dict (file_name : str ,
269287 model : Union [str , Path ],
270288 revision : Optional [str ] = 'main' ):
271289 """
272- Downloads a file from the Hugging Face Hub and returns
290+ Downloads a file from the Hugging Face Hub and returns
273291 its contents as a dictionary.
274292
275293 Parameters:
276294 - file_name (str): The name of the file to download.
277295 - model (str): The name of the model on the Hugging Face Hub.
278- - revision (str): The specific version of the model.
296+ - revision (str): The specific version of the model.
279297
280298 Returns:
281- - config_dict (dict): A dictionary containing
299+ - config_dict (dict): A dictionary containing
282300 the contents of the downloaded file.
283301 """
284- file_path = Path (model ) / file_name
285302
286- if file_or_path_exists (model = model ,
287- config_name = file_name ,
288- revision = revision ):
303+ file_path = try_get_local_file (model = model ,
304+ file_name = file_name ,
305+ revision = revision )
289306
290- if not file_path .is_file ():
291- try :
292- hf_hub_file = hf_hub_download (model ,
293- file_name ,
294- revision = revision )
295- except (RepositoryNotFoundError , RevisionNotFoundError ,
296- EntryNotFoundError , LocalEntryNotFoundError ) as e :
297- logger .debug ("File or repository not found in hf_hub_download" ,
298- e )
299- return None
300- except HfHubHTTPError as e :
301- logger .warning (
302- "Cannot connect to Hugging Face Hub. Skipping file "
303- "download for '%s':" ,
304- file_name ,
305- exc_info = e )
306- return None
307- file_path = Path (hf_hub_file )
307+ if file_path is None and file_or_path_exists (
308+ model = model , config_name = file_name , revision = revision ):
309+ try :
310+ hf_hub_file = hf_hub_download (model , file_name , revision = revision )
311+ except (RepositoryNotFoundError , RevisionNotFoundError ,
312+ EntryNotFoundError , LocalEntryNotFoundError ) as e :
313+ logger .debug ("File or repository not found in hf_hub_download" , e )
314+ return None
315+ except HfHubHTTPError as e :
316+ logger .warning (
317+ "Cannot connect to Hugging Face Hub. Skipping file "
318+ "download for '%s':" ,
319+ file_name ,
320+ exc_info = e )
321+ return None
322+ file_path = Path (hf_hub_file )
308323
324+ if file_path is not None and file_path .is_file ():
309325 with open (file_path ) as file :
310326 return json .load (file )
327+
311328 return None
312329
313330
@@ -328,7 +345,12 @@ def get_pooling_config(model: str, revision: Optional[str] = 'main'):
328345 """
329346
330347 modules_file_name = "modules.json"
331- modules_dict = get_hf_file_to_dict (modules_file_name , model , revision )
348+
349+ modules_dict = None
350+ if file_or_path_exists (model = model ,
351+ config_name = modules_file_name ,
352+ revision = revision ):
353+ modules_dict = get_hf_file_to_dict (modules_file_name , model , revision )
332354
333355 if modules_dict is None :
334356 return None
@@ -382,17 +404,17 @@ def get_sentence_transformer_tokenizer_config(model: str,
382404 revision : Optional [str ] = 'main'
383405 ):
384406 """
385- Returns the tokenization configuration dictionary for a
407+ Returns the tokenization configuration dictionary for a
386408 given Sentence Transformer BERT model.
387409
388410 Parameters:
389- - model (str): The name of the Sentence Transformer
411+ - model (str): The name of the Sentence Transformer
390412 BERT model.
391413 - revision (str, optional): The revision of the m
392414 odel to use. Defaults to 'main'.
393415
394416 Returns:
395- - dict: A dictionary containing the configuration parameters
417+ - dict: A dictionary containing the configuration parameters
396418 for the Sentence Transformer BERT model.
397419 """
398420 sentence_transformer_config_files = [
@@ -404,20 +426,33 @@ def get_sentence_transformer_tokenizer_config(model: str,
404426 "sentence_xlm-roberta_config.json" ,
405427 "sentence_xlnet_config.json" ,
406428 ]
407- try :
408- # If model is on HuggingfaceHub, get the repo files
409- repo_files = list_repo_files (model , revision = revision , token = HF_TOKEN )
410- except Exception as e :
411- logger .debug ("Error getting repo files" , e )
412- repo_files = []
413-
414429 encoder_dict = None
415- for config_name in sentence_transformer_config_files :
416- if config_name in repo_files or Path (model ).exists ():
417- encoder_dict = get_hf_file_to_dict (config_name , model , revision )
430+
431+ for config_file in sentence_transformer_config_files :
432+ if try_get_local_file (model = model ,
433+ file_name = config_file ,
434+ revision = revision ) is not None :
435+ encoder_dict = get_hf_file_to_dict (config_file , model , revision )
418436 if encoder_dict :
419437 break
420438
439+ if not encoder_dict :
440+ try :
441+ # If model is on HuggingfaceHub, get the repo files
442+ repo_files = list_repo_files (model ,
443+ revision = revision ,
444+ token = HF_TOKEN )
445+ except Exception as e :
446+ logger .debug ("Error getting repo files" , e )
447+ repo_files = []
448+
449+ for config_name in sentence_transformer_config_files :
450+ if config_name in repo_files :
451+ encoder_dict = get_hf_file_to_dict (config_name , model ,
452+ revision )
453+ if encoder_dict :
454+ break
455+
421456 if not encoder_dict :
422457 return None
423458
0 commit comments