44import json
55import os
66import time
7+ from functools import cache
78from pathlib import Path
8- from typing import Any , Dict , Literal , Optional , Type , Union
9+ from typing import Any , Callable , Dict , Literal , Optional , Type , Union
910
1011import huggingface_hub
11- from huggingface_hub import (file_exists , hf_hub_download , list_repo_files ,
12- try_to_load_from_cache )
12+ from huggingface_hub import hf_hub_download
13+ from huggingface_hub import list_repo_files as hf_list_repo_files
14+ from huggingface_hub import try_to_load_from_cache
1315from huggingface_hub .utils import (EntryNotFoundError , HfHubHTTPError ,
1416 HFValidationError , LocalEntryNotFoundError ,
1517 RepositoryNotFoundError ,
@@ -86,6 +88,65 @@ class ConfigFormat(str, enum.Enum):
8688 MISTRAL = "mistral"
8789
8890
91+ def with_retry (func : Callable [[], Any ],
92+ log_msg : str ,
93+ max_retries : int = 2 ,
94+ retry_delay : int = 2 ):
95+ for attempt in range (max_retries ):
96+ try :
97+ return func ()
98+ except Exception as e :
99+ if attempt == max_retries - 1 :
100+ logger .error ("%s: %s" , log_msg , e )
101+ raise
102+ logger .error ("%s: %s, retrying %d of %d" , log_msg , e , attempt + 1 ,
103+ max_retries )
104+ time .sleep (retry_delay )
105+ retry_delay *= 2
106+
107+
108+ # @cache doesn't cache exceptions
109+ @cache
110+ def list_repo_files (
111+ repo_id : str ,
112+ * ,
113+ revision : Optional [str ] = None ,
114+ repo_type : Optional [str ] = None ,
115+ token : Union [str , bool , None ] = None ,
116+ ) -> list [str ]:
117+
118+ def lookup_files ():
119+ try :
120+ return hf_list_repo_files (repo_id ,
121+ revision = revision ,
122+ repo_type = repo_type ,
123+ token = token )
124+ except huggingface_hub .errors .OfflineModeIsEnabled :
125+ # Don't raise in offline mode,
126+ # all we know is that we don't have this
127+ # file cached.
128+ return []
129+
130+ return with_retry (lookup_files , "Error retrieving file list" )
131+
132+
133+ def file_exists (
134+ repo_id : str ,
135+ file_name : str ,
136+ * ,
137+ repo_type : Optional [str ] = None ,
138+ revision : Optional [str ] = None ,
139+ token : Union [str , bool , None ] = None ,
140+ ) -> bool :
141+
142+ file_list = list_repo_files (repo_id ,
143+ repo_type = repo_type ,
144+ revision = revision ,
145+ token = token )
146+ return file_name in file_list
147+
148+
149+ # In offline mode the result can be a false negative
89150def file_or_path_exists (model : Union [str , Path ], config_name : str ,
90151 revision : Optional [str ]) -> bool :
91152 if Path (model ).exists ():
@@ -103,31 +164,10 @@ def file_or_path_exists(model: Union[str, Path], config_name: str,
103164 # hf_hub. This will fail in offline mode.
104165
105166 # Call HF to check if the file exists
106- # 2 retries and exponential backoff
107- max_retries = 2
108- retry_delay = 2
109- for attempt in range (max_retries ):
110- try :
111- return file_exists (model ,
112- config_name ,
113- revision = revision ,
114- token = HF_TOKEN )
115- except huggingface_hub .errors .OfflineModeIsEnabled :
116- # Don't raise in offline mode,
117- # all we know is that we don't have this
118- # file cached.
119- return False
120- except Exception as e :
121- logger .error (
122- "Error checking file existence: %s, retrying %d of %d" , e ,
123- attempt + 1 , max_retries )
124- if attempt == max_retries - 1 :
125- logger .error ("Error checking file existence: %s" , e )
126- raise
127- time .sleep (retry_delay )
128- retry_delay *= 2
129- continue
130- return False
167+ return file_exists (str (model ),
168+ config_name ,
169+ revision = revision ,
170+ token = HF_TOKEN )
131171
132172
133173def patch_rope_scaling (config : PretrainedConfig ) -> None :
@@ -208,32 +248,7 @@ def get_config(
208248 revision = revision ):
209249 config_format = ConfigFormat .MISTRAL
210250 else :
211- # If we're in offline mode and found no valid config format, then
212- # raise an offline mode error to indicate to the user that they
213- # don't have files cached and may need to go online.
214- # This is conveniently triggered by calling file_exists().
215-
216- # Call HF to check if the file exists
217- # 2 retries and exponential backoff
218- max_retries = 2
219- retry_delay = 2
220- for attempt in range (max_retries ):
221- try :
222- file_exists (model ,
223- HF_CONFIG_NAME ,
224- revision = revision ,
225- token = HF_TOKEN )
226- except Exception as e :
227- logger .error (
228- "Error checking file existence: %s, retrying %d of %d" ,
229- e , attempt + 1 , max_retries )
230- if attempt == max_retries :
231- logger .error ("Error checking file existence: %s" , e )
232- raise e
233- time .sleep (retry_delay )
234- retry_delay *= 2
235-
236- raise ValueError (f"No supported config format found in { model } " )
251+ raise ValueError (f"No supported config format found in { model } ." )
237252
238253 if config_format == ConfigFormat .HF :
239254 config_dict , _ = PretrainedConfig .get_config_dict (
@@ -339,10 +354,11 @@ def get_hf_file_to_dict(file_name: str,
339354 file_name = file_name ,
340355 revision = revision )
341356
342- if file_path is None and file_or_path_exists (
343- model = model , config_name = file_name , revision = revision ):
357+ if file_path is None :
344358 try :
345359 hf_hub_file = hf_hub_download (model , file_name , revision = revision )
360+ except huggingface_hub .errors .OfflineModeIsEnabled :
361+ return None
346362 except (RepositoryNotFoundError , RevisionNotFoundError ,
347363 EntryNotFoundError , LocalEntryNotFoundError ) as e :
348364 logger .debug ("File or repository not found in hf_hub_download" , e )
@@ -363,6 +379,7 @@ def get_hf_file_to_dict(file_name: str,
363379 return None
364380
365381
382+ @cache
366383def get_pooling_config (model : str , revision : Optional [str ] = 'main' ):
367384 """
368385 This function gets the pooling and normalize
@@ -390,6 +407,8 @@ def get_pooling_config(model: str, revision: Optional[str] = 'main'):
390407 if modules_dict is None :
391408 return None
392409
410+ logger .info ("Found sentence-transformers modules configuration." )
411+
393412 pooling = next ((item for item in modules_dict
394413 if item ["type" ] == "sentence_transformers.models.Pooling" ),
395414 None )
@@ -408,6 +427,7 @@ def get_pooling_config(model: str, revision: Optional[str] = 'main'):
408427 if pooling_type_name is not None :
409428 pooling_type_name = get_pooling_config_name (pooling_type_name )
410429
430+ logger .info ("Found pooling configuration." )
411431 return {"pooling_type" : pooling_type_name , "normalize" : normalize }
412432
413433 return None
@@ -435,6 +455,7 @@ def get_pooling_config_name(pooling_name: str) -> Union[str, None]:
435455 return None
436456
437457
458+ @cache
438459def get_sentence_transformer_tokenizer_config (model : str ,
439460 revision : Optional [str ] = 'main'
440461 ):
@@ -491,6 +512,8 @@ def get_sentence_transformer_tokenizer_config(model: str,
491512 if not encoder_dict :
492513 return None
493514
515+ logger .info ("Found sentence-transformers tokenize configuration." )
516+
494517 if all (k in encoder_dict for k in ("max_seq_length" , "do_lower_case" )):
495518 return encoder_dict
496519 return None
0 commit comments