Skip to content

Commit 166a524

Browse files
committed
Adding new approach
1 parent a955d2c commit 166a524

File tree

2 files changed

+25
-11
lines changed

2 files changed

+25
-11
lines changed

timm/models/_hub.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -542,17 +542,17 @@ def _get_safe_alternatives(filename: str) -> Iterable[str]:
542542
yield filename[:-4] + ".safetensors"
543543

544544

545-
def _get_license_from_hf_hub(model_id: str | None, hf_hub_id: str | None) -> str | None:
545+
def _get_license_from_hf_hub(hf_hub_id: str | None) -> str | None:
546546
"""Retrieve license information for a model from Hugging Face Hub.
547547
548548
Fetches the license field from the model card metadata on Hugging Face Hub
549-
for the specified model. Returns None if the model is not found, if
550-
huggingface_hub is not installed, or if the model is marked as "untrained".
549+
for the specified model. This function is called lazily when the license
550+
attribute is accessed on PretrainedCfg objects that don't have an explicit
551+
license set.
551552
552553
Args:
553-
model_id: The model identifier/name. In the case of None we assume an untrained model.
554-
hf_hub_id: The Hugging Face Hub organization/user ID. If it is None,
555-
we will return None as we cannot infer the license terms.
554+
hf_hub_id: The Hugging Face Hub model ID (e.g., 'organization/model').
555+
If None or empty, returns None as license cannot be determined.
556556
557557
Returns:
558558
The license string in lowercase if found, None otherwise.
@@ -566,17 +566,17 @@ def _get_license_from_hf_hub(model_id: str | None, hf_hub_id: str | None) -> str
566566
_logger.warning(msg=msg)
567567
return None
568568

569-
if not (model_id and hf_hub_id):
569+
if hf_hub_id is None or hf_hub_id == "timm/":
570570
return None
571571

572-
repo_id: str = hf_hub_id + model_id
573-
574572
try:
575-
info = model_info(repo_id=repo_id)
573+
info = model_info(repo_id=hf_hub_id)
576574

577575
except RepositoryNotFoundError:
578576
# TODO: any wish what happens here? @rwightman
579-
print(repo_id)
577+
return None
578+
579+
except Exception as _:
580580
return None
581581

582582
license = info.card_data.get("license").lower() if info.card_data else None

timm/models/_pretrained.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,20 @@ class PretrainedCfg:
5858
def has_weights(self):
5959
return self.url or self.file or self.hf_hub_id
6060

61+
def __getattribute__(self, name):
62+
if name == 'license': # Intercept license access to set it in case it was not set anywhere else.
63+
license_value = super().__getattribute__('license')
64+
65+
if license_value is None:
66+
from ._hub import _get_license_from_hf_hub
67+
license_value = _get_license_from_hf_hub(hf_hub_id=self.hf_hub_id)
68+
69+
self.license = license_value
70+
71+
return license_value
72+
73+
return super().__getattribute__(name)
74+
6175
def to_dict(self, remove_source=False, remove_null=True):
6276
return filter_pretrained_cfg(
6377
asdict(self),

0 commit comments

Comments
 (0)