1717
1818try :
1919 import safetensors .torch
20+
2021 _has_safetensors = True
2122except ImportError :
2223 _has_safetensors = False
3132
3233try :
3334 from huggingface_hub import (
34- create_repo , get_hf_file_metadata ,
35- hf_hub_download , hf_hub_url , model_info ,
36- repo_type_and_id_from_hf_id , upload_folder )
35+ create_repo ,
36+ get_hf_file_metadata ,
37+ hf_hub_download ,
38+ hf_hub_url ,
39+ model_info ,
40+ repo_type_and_id_from_hf_id ,
41+ upload_folder ,
42+ )
3743 from huggingface_hub .utils import EntryNotFoundError , RepositoryNotFoundError
44+
3845 hf_hub_download = partial (hf_hub_download , library_name = "timm" , library_version = __version__ )
3946 _has_hf_hub = True
4047except ImportError :
4350
4451_logger = logging .getLogger (__name__ )
4552
46- __all__ = ['get_cache_dir' , 'download_cached_file' , 'has_hf_hub' , 'hf_split' , 'load_model_config_from_hf' ,
47- 'load_state_dict_from_hf' , 'save_for_hf' , 'push_to_hf_hub' ]
53+ __all__ = [
54+ 'get_cache_dir' ,
55+ 'download_cached_file' ,
56+ 'has_hf_hub' ,
57+ 'hf_split' ,
58+ 'load_model_config_from_hf' ,
59+ 'load_state_dict_from_hf' ,
60+ 'save_for_hf' ,
61+ 'push_to_hf_hub' ,
62+ ]
4863
4964# Default name for a weights file hosted on the Huggingface Hub.
5065HF_WEIGHTS_NAME = "pytorch_model.bin" # default pytorch pkl
@@ -69,10 +84,10 @@ def get_cache_dir(child_dir: str = ''):
6984
7085
7186def download_cached_file (
72- url : Union [str , List [str ], Tuple [str , str ]],
73- check_hash : bool = True ,
74- progress : bool = False ,
75- cache_dir : Optional [Union [str , Path ]] = None ,
87+ url : Union [str , List [str ], Tuple [str , str ]],
88+ check_hash : bool = True ,
89+ progress : bool = False ,
90+ cache_dir : Optional [Union [str , Path ]] = None ,
7691):
7792 if isinstance (url , (list , tuple )):
7893 url , filename = url
@@ -95,9 +110,9 @@ def download_cached_file(
95110
96111
97112def check_cached_file (
98- url : Union [str , List [str ], Tuple [str , str ]],
99- check_hash : bool = True ,
100- cache_dir : Optional [Union [str , Path ]] = None ,
113+ url : Union [str , List [str ], Tuple [str , str ]],
114+ check_hash : bool = True ,
115+ cache_dir : Optional [Union [str , Path ]] = None ,
101116):
102117 if isinstance (url , (list , tuple )):
103118 url , filename = url
@@ -114,7 +129,7 @@ def check_cached_file(
114129 if hash_prefix :
115130 with open (cached_file , 'rb' ) as f :
116131 hd = hashlib .sha256 (f .read ()).hexdigest ()
117- if hd [:len (hash_prefix )] != hash_prefix :
132+ if hd [: len (hash_prefix )] != hash_prefix :
118133 return False
119134 return True
120135 return False
@@ -124,7 +139,8 @@ def has_hf_hub(necessary: bool = False):
124139 if not _has_hf_hub and necessary :
125140 # if no HF Hub module installed, and it is necessary to continue, raise error
126141 raise RuntimeError (
127- 'Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`.' )
142+ 'Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`.'
143+ )
128144 return _has_hf_hub
129145
130146
@@ -144,9 +160,9 @@ def load_cfg_from_json(json_file: Union[str, Path]):
144160
145161
146162def download_from_hf (
147- model_id : str ,
148- filename : str ,
149- cache_dir : Optional [Union [str , Path ]] = None ,
163+ model_id : str ,
164+ filename : str ,
165+ cache_dir : Optional [Union [str , Path ]] = None ,
150166):
151167 hf_model_id , hf_revision = hf_split (model_id )
152168 return hf_hub_download (
@@ -158,8 +174,8 @@ def download_from_hf(
158174
159175
160176def _parse_model_cfg (
161- cfg : Dict [str , Any ],
162- extra_fields : Dict [str , Any ],
177+ cfg : Dict [str , Any ],
178+ extra_fields : Dict [str , Any ],
163179) -> Tuple [Dict [str , Any ], str , Dict [str , Any ]]:
164180 """"""
165181 # legacy "single‑dict" → split
@@ -170,7 +186,7 @@ def _parse_model_cfg(
170186 "num_features" : pretrained_cfg .pop ("num_features" , None ),
171187 "pretrained_cfg" : pretrained_cfg ,
172188 }
173- if "labels" in pretrained_cfg : # rename ‑‑> label_names
189+ if "labels" in pretrained_cfg : # rename ‑‑> label_names
174190 pretrained_cfg ["label_names" ] = pretrained_cfg .pop ("labels" )
175191
176192 pretrained_cfg = cfg ["pretrained_cfg" ]
@@ -190,8 +206,8 @@ def _parse_model_cfg(
190206
191207
192208def load_model_config_from_hf (
193- model_id : str ,
194- cache_dir : Optional [Union [str , Path ]] = None ,
209+ model_id : str ,
210+ cache_dir : Optional [Union [str , Path ]] = None ,
195211):
196212 """Original HF‑Hub loader (unchanged download, shared parsing)."""
197213 assert has_hf_hub (True )
@@ -201,7 +217,7 @@ def load_model_config_from_hf(
201217
202218
203219def load_model_config_from_path (
204- model_path : Union [str , Path ],
220+ model_path : Union [str , Path ],
205221):
206222 """Load from ``<model_path>/config.json`` on the local filesystem."""
207223 model_path = Path (model_path )
@@ -214,10 +230,10 @@ def load_model_config_from_path(
214230
215231
216232def load_state_dict_from_hf (
217- model_id : str ,
218- filename : str = HF_WEIGHTS_NAME ,
219- weights_only : bool = False ,
220- cache_dir : Optional [Union [str , Path ]] = None ,
233+ model_id : str ,
234+ filename : str = HF_WEIGHTS_NAME ,
235+ weights_only : bool = False ,
236+ cache_dir : Optional [Union [str , Path ]] = None ,
221237):
222238 assert has_hf_hub (True )
223239 hf_model_id , hf_revision = hf_split (model_id )
@@ -234,7 +250,8 @@ def load_state_dict_from_hf(
234250 )
235251 _logger .info (
236252 f"[{ model_id } ] Safe alternative available for '{ filename } ' "
237- f"(as '{ safe_filename } '). Loading weights using safetensors." )
253+ f"(as '{ safe_filename } '). Loading weights using safetensors."
254+ )
238255 return safetensors .torch .load_file (cached_safe_file , device = "cpu" )
239256 except EntryNotFoundError :
240257 pass
@@ -266,9 +283,10 @@ def load_state_dict_from_hf(
266283)
267284_EXT_PRIORITY = ('.safetensors' , '.pth' , '.pth.tar' , '.bin' )
268285
286+
269287def load_state_dict_from_path (
270- path : str ,
271- weights_only : bool = False ,
288+ path : str ,
289+ weights_only : bool = False ,
272290):
273291 found_file = None
274292 for fname in _PREFERRED_FILES :
@@ -283,10 +301,7 @@ def load_state_dict_from_path(
283301 files = sorted (path .glob (f"*{ ext } " ))
284302 if files :
285303 if len (files ) > 1 :
286- logging .warning (
287- f"Multiple { ext } checkpoints in { path } : { names } . "
288- f"Using '{ files [0 ].name } '."
289- )
304+ logging .warning (f"Multiple { ext } checkpoints in { path } : { names } . " f"Using '{ files [0 ].name } '." )
290305 found_file = files [0 ]
291306
292307 if not found_file :
@@ -300,10 +315,10 @@ def load_state_dict_from_path(
300315
301316
302317def load_custom_from_hf (
303- model_id : str ,
304- filename : str ,
305- model : torch .nn .Module ,
306- cache_dir : Optional [Union [str , Path ]] = None ,
318+ model_id : str ,
319+ filename : str ,
320+ model : torch .nn .Module ,
321+ cache_dir : Optional [Union [str , Path ]] = None ,
307322):
308323 assert has_hf_hub (True )
309324 hf_model_id , hf_revision = hf_split (model_id )
@@ -317,10 +332,7 @@ def load_custom_from_hf(
317332
318333
319334def save_config_for_hf (
320- model : torch .nn .Module ,
321- config_path : str ,
322- model_config : Optional [dict ] = None ,
323- model_args : Optional [dict ] = None
335+ model : torch .nn .Module , config_path : str , model_config : Optional [dict ] = None , model_args : Optional [dict ] = None
324336):
325337 model_config = model_config or {}
326338 hf_config = {}
@@ -339,7 +351,8 @@ def save_config_for_hf(
339351 if 'labels' in model_config :
340352 _logger .warning (
341353 "'labels' as a config field for is deprecated. Please use 'label_names' and 'label_descriptions'."
342- " Renaming provided 'labels' field to 'label_names'." )
354+ " Renaming provided 'labels' field to 'label_names'."
355+ )
343356 model_config .setdefault ('label_names' , model_config .pop ('labels' ))
344357
345358 label_names = model_config .pop ('label_names' , None )
@@ -366,11 +379,11 @@ def save_config_for_hf(
366379
367380
368381def save_for_hf (
369- model : torch .nn .Module ,
370- save_directory : str ,
371- model_config : Optional [dict ] = None ,
372- model_args : Optional [dict ] = None ,
373- safe_serialization : Union [bool , Literal ["both" ]] = False ,
382+ model : torch .nn .Module ,
383+ save_directory : str ,
384+ model_config : Optional [dict ] = None ,
385+ model_args : Optional [dict ] = None ,
386+ safe_serialization : Union [bool , Literal ["both" ]] = False ,
374387):
375388 assert has_hf_hub (True )
376389 save_directory = Path (save_directory )
@@ -394,18 +407,18 @@ def save_for_hf(
394407
395408
396409def push_to_hf_hub (
397- model : torch .nn .Module ,
398- repo_id : str ,
399- commit_message : str = 'Add model' ,
400- token : Optional [str ] = None ,
401- revision : Optional [str ] = None ,
402- private : bool = False ,
403- create_pr : bool = False ,
404- model_config : Optional [dict ] = None ,
405- model_card : Optional [dict ] = None ,
406- model_args : Optional [dict ] = None ,
407- task_name : str = 'image-classification' ,
408- safe_serialization : Union [bool , Literal ["both" ]] = 'both' ,
410+ model : torch .nn .Module ,
411+ repo_id : str ,
412+ commit_message : str = 'Add model' ,
413+ token : Optional [str ] = None ,
414+ revision : Optional [str ] = None ,
415+ private : bool = False ,
416+ create_pr : bool = False ,
417+ model_config : Optional [dict ] = None ,
418+ model_card : Optional [dict ] = None ,
419+ model_args : Optional [dict ] = None ,
420+ task_name : str = 'image-classification' ,
421+ safe_serialization : Union [bool , Literal ["both" ]] = 'both' ,
409422):
410423 """
411424 Arguments:
@@ -459,9 +472,9 @@ def push_to_hf_hub(
459472
460473
461474def generate_readme (
462- model_card : dict ,
463- model_name : str ,
464- task_name : str = 'image-classification' ,
475+ model_card : dict ,
476+ model_name : str ,
477+ task_name : str = 'image-classification' ,
465478):
466479 tags = model_card .get ('tags' , None ) or [task_name , 'timm' , 'transformers' ]
467480 readme_text = "---\n "
0 commit comments