Skip to content

Commit ed00d06

Browse files
committed
Running formatting with command from CONTRIBUTING.md
- Skipping the cspnet.py file with formatting due to large diff
1 parent ae9bb38 commit ed00d06

File tree

1 file changed

+76
-63
lines changed

1 file changed

+76
-63
lines changed

timm/models/_hub.py

Lines changed: 76 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
try:
1919
import safetensors.torch
20+
2021
_has_safetensors = True
2122
except ImportError:
2223
_has_safetensors = False
@@ -31,10 +32,16 @@
3132

3233
try:
3334
from huggingface_hub import (
34-
create_repo, get_hf_file_metadata,
35-
hf_hub_download, hf_hub_url, model_info,
36-
repo_type_and_id_from_hf_id, upload_folder)
35+
create_repo,
36+
get_hf_file_metadata,
37+
hf_hub_download,
38+
hf_hub_url,
39+
model_info,
40+
repo_type_and_id_from_hf_id,
41+
upload_folder,
42+
)
3743
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError
44+
3845
hf_hub_download = partial(hf_hub_download, library_name="timm", library_version=__version__)
3946
_has_hf_hub = True
4047
except ImportError:
@@ -43,8 +50,16 @@
4350

4451
_logger = logging.getLogger(__name__)
4552

46-
__all__ = ['get_cache_dir', 'download_cached_file', 'has_hf_hub', 'hf_split', 'load_model_config_from_hf',
47-
'load_state_dict_from_hf', 'save_for_hf', 'push_to_hf_hub']
53+
__all__ = [
54+
'get_cache_dir',
55+
'download_cached_file',
56+
'has_hf_hub',
57+
'hf_split',
58+
'load_model_config_from_hf',
59+
'load_state_dict_from_hf',
60+
'save_for_hf',
61+
'push_to_hf_hub',
62+
]
4863

4964
# Default name for a weights file hosted on the Huggingface Hub.
5065
HF_WEIGHTS_NAME = "pytorch_model.bin" # default pytorch pkl
@@ -69,10 +84,10 @@ def get_cache_dir(child_dir: str = ''):
6984

7085

7186
def download_cached_file(
72-
url: Union[str, List[str], Tuple[str, str]],
73-
check_hash: bool = True,
74-
progress: bool = False,
75-
cache_dir: Optional[Union[str, Path]] = None,
87+
url: Union[str, List[str], Tuple[str, str]],
88+
check_hash: bool = True,
89+
progress: bool = False,
90+
cache_dir: Optional[Union[str, Path]] = None,
7691
):
7792
if isinstance(url, (list, tuple)):
7893
url, filename = url
@@ -95,9 +110,9 @@ def download_cached_file(
95110

96111

97112
def check_cached_file(
98-
url: Union[str, List[str], Tuple[str, str]],
99-
check_hash: bool = True,
100-
cache_dir: Optional[Union[str, Path]] = None,
113+
url: Union[str, List[str], Tuple[str, str]],
114+
check_hash: bool = True,
115+
cache_dir: Optional[Union[str, Path]] = None,
101116
):
102117
if isinstance(url, (list, tuple)):
103118
url, filename = url
@@ -114,7 +129,7 @@ def check_cached_file(
114129
if hash_prefix:
115130
with open(cached_file, 'rb') as f:
116131
hd = hashlib.sha256(f.read()).hexdigest()
117-
if hd[:len(hash_prefix)] != hash_prefix:
132+
if hd[: len(hash_prefix)] != hash_prefix:
118133
return False
119134
return True
120135
return False
@@ -124,7 +139,8 @@ def has_hf_hub(necessary: bool = False):
124139
if not _has_hf_hub and necessary:
125140
# if no HF Hub module installed, and it is necessary to continue, raise error
126141
raise RuntimeError(
127-
'Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`.')
142+
'Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`.'
143+
)
128144
return _has_hf_hub
129145

130146

@@ -144,9 +160,9 @@ def load_cfg_from_json(json_file: Union[str, Path]):
144160

145161

146162
def download_from_hf(
147-
model_id: str,
148-
filename: str,
149-
cache_dir: Optional[Union[str, Path]] = None,
163+
model_id: str,
164+
filename: str,
165+
cache_dir: Optional[Union[str, Path]] = None,
150166
):
151167
hf_model_id, hf_revision = hf_split(model_id)
152168
return hf_hub_download(
@@ -158,8 +174,8 @@ def download_from_hf(
158174

159175

160176
def _parse_model_cfg(
161-
cfg: Dict[str, Any],
162-
extra_fields: Dict[str, Any],
177+
cfg: Dict[str, Any],
178+
extra_fields: Dict[str, Any],
163179
) -> Tuple[Dict[str, Any], str, Dict[str, Any]]:
164180
""""""
165181
# legacy "single‑dict" → split
@@ -170,7 +186,7 @@ def _parse_model_cfg(
170186
"num_features": pretrained_cfg.pop("num_features", None),
171187
"pretrained_cfg": pretrained_cfg,
172188
}
173-
if "labels" in pretrained_cfg: # rename ‑‑> label_names
189+
if "labels" in pretrained_cfg: # rename ‑‑> label_names
174190
pretrained_cfg["label_names"] = pretrained_cfg.pop("labels")
175191

176192
pretrained_cfg = cfg["pretrained_cfg"]
@@ -190,8 +206,8 @@ def _parse_model_cfg(
190206

191207

192208
def load_model_config_from_hf(
193-
model_id: str,
194-
cache_dir: Optional[Union[str, Path]] = None,
209+
model_id: str,
210+
cache_dir: Optional[Union[str, Path]] = None,
195211
):
196212
"""Original HF‑Hub loader (unchanged download, shared parsing)."""
197213
assert has_hf_hub(True)
@@ -201,7 +217,7 @@ def load_model_config_from_hf(
201217

202218

203219
def load_model_config_from_path(
204-
model_path: Union[str, Path],
220+
model_path: Union[str, Path],
205221
):
206222
"""Load from ``<model_path>/config.json`` on the local filesystem."""
207223
model_path = Path(model_path)
@@ -214,10 +230,10 @@ def load_model_config_from_path(
214230

215231

216232
def load_state_dict_from_hf(
217-
model_id: str,
218-
filename: str = HF_WEIGHTS_NAME,
219-
weights_only: bool = False,
220-
cache_dir: Optional[Union[str, Path]] = None,
233+
model_id: str,
234+
filename: str = HF_WEIGHTS_NAME,
235+
weights_only: bool = False,
236+
cache_dir: Optional[Union[str, Path]] = None,
221237
):
222238
assert has_hf_hub(True)
223239
hf_model_id, hf_revision = hf_split(model_id)
@@ -234,7 +250,8 @@ def load_state_dict_from_hf(
234250
)
235251
_logger.info(
236252
f"[{model_id}] Safe alternative available for '{filename}' "
237-
f"(as '{safe_filename}'). Loading weights using safetensors.")
253+
f"(as '{safe_filename}'). Loading weights using safetensors."
254+
)
238255
return safetensors.torch.load_file(cached_safe_file, device="cpu")
239256
except EntryNotFoundError:
240257
pass
@@ -266,9 +283,10 @@ def load_state_dict_from_hf(
266283
)
267284
_EXT_PRIORITY = ('.safetensors', '.pth', '.pth.tar', '.bin')
268285

286+
269287
def load_state_dict_from_path(
270-
path: str,
271-
weights_only: bool = False,
288+
path: str,
289+
weights_only: bool = False,
272290
):
273291
found_file = None
274292
for fname in _PREFERRED_FILES:
@@ -283,10 +301,7 @@ def load_state_dict_from_path(
283301
files = sorted(path.glob(f"*{ext}"))
284302
if files:
285303
if len(files) > 1:
286-
logging.warning(
287-
f"Multiple {ext} checkpoints in {path}: {names}. "
288-
f"Using '{files[0].name}'."
289-
)
304+
logging.warning(f"Multiple {ext} checkpoints in {path}: {names}. " f"Using '{files[0].name}'.")
290305
found_file = files[0]
291306

292307
if not found_file:
@@ -300,10 +315,10 @@ def load_state_dict_from_path(
300315

301316

302317
def load_custom_from_hf(
303-
model_id: str,
304-
filename: str,
305-
model: torch.nn.Module,
306-
cache_dir: Optional[Union[str, Path]] = None,
318+
model_id: str,
319+
filename: str,
320+
model: torch.nn.Module,
321+
cache_dir: Optional[Union[str, Path]] = None,
307322
):
308323
assert has_hf_hub(True)
309324
hf_model_id, hf_revision = hf_split(model_id)
@@ -317,10 +332,7 @@ def load_custom_from_hf(
317332

318333

319334
def save_config_for_hf(
320-
model: torch.nn.Module,
321-
config_path: str,
322-
model_config: Optional[dict] = None,
323-
model_args: Optional[dict] = None
335+
model: torch.nn.Module, config_path: str, model_config: Optional[dict] = None, model_args: Optional[dict] = None
324336
):
325337
model_config = model_config or {}
326338
hf_config = {}
@@ -339,7 +351,8 @@ def save_config_for_hf(
339351
if 'labels' in model_config:
340352
_logger.warning(
341353
"'labels' as a config field for is deprecated. Please use 'label_names' and 'label_descriptions'."
342-
" Renaming provided 'labels' field to 'label_names'.")
354+
" Renaming provided 'labels' field to 'label_names'."
355+
)
343356
model_config.setdefault('label_names', model_config.pop('labels'))
344357

345358
label_names = model_config.pop('label_names', None)
@@ -366,11 +379,11 @@ def save_config_for_hf(
366379

367380

368381
def save_for_hf(
369-
model: torch.nn.Module,
370-
save_directory: str,
371-
model_config: Optional[dict] = None,
372-
model_args: Optional[dict] = None,
373-
safe_serialization: Union[bool, Literal["both"]] = False,
382+
model: torch.nn.Module,
383+
save_directory: str,
384+
model_config: Optional[dict] = None,
385+
model_args: Optional[dict] = None,
386+
safe_serialization: Union[bool, Literal["both"]] = False,
374387
):
375388
assert has_hf_hub(True)
376389
save_directory = Path(save_directory)
@@ -394,18 +407,18 @@ def save_for_hf(
394407

395408

396409
def push_to_hf_hub(
397-
model: torch.nn.Module,
398-
repo_id: str,
399-
commit_message: str = 'Add model',
400-
token: Optional[str] = None,
401-
revision: Optional[str] = None,
402-
private: bool = False,
403-
create_pr: bool = False,
404-
model_config: Optional[dict] = None,
405-
model_card: Optional[dict] = None,
406-
model_args: Optional[dict] = None,
407-
task_name: str = 'image-classification',
408-
safe_serialization: Union[bool, Literal["both"]] = 'both',
410+
model: torch.nn.Module,
411+
repo_id: str,
412+
commit_message: str = 'Add model',
413+
token: Optional[str] = None,
414+
revision: Optional[str] = None,
415+
private: bool = False,
416+
create_pr: bool = False,
417+
model_config: Optional[dict] = None,
418+
model_card: Optional[dict] = None,
419+
model_args: Optional[dict] = None,
420+
task_name: str = 'image-classification',
421+
safe_serialization: Union[bool, Literal["both"]] = 'both',
409422
):
410423
"""
411424
Arguments:
@@ -459,9 +472,9 @@ def push_to_hf_hub(
459472

460473

461474
def generate_readme(
462-
model_card: dict,
463-
model_name: str,
464-
task_name: str = 'image-classification',
475+
model_card: dict,
476+
model_name: str,
477+
task_name: str = 'image-classification',
465478
):
466479
tags = model_card.get('tags', None) or [task_name, 'timm', 'transformers']
467480
readme_text = "---\n"

0 commit comments

Comments
 (0)