Skip to content

Commit ae9bb38

Browse files
committed
Adding licensing information to cspnet.py
1 parent 019550e commit ae9bb38

File tree

2 files changed

+61
-4
lines changed

2 files changed

+61
-4
lines changed

timm/models/_hub.py

Lines changed: 43 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,9 @@
3232
try:
3333
from huggingface_hub import (
3434
create_repo, get_hf_file_metadata,
35-
hf_hub_download, hf_hub_url,
35+
hf_hub_download, hf_hub_url, model_info,
3636
repo_type_and_id_from_hf_id, upload_folder)
37-
from huggingface_hub.utils import EntryNotFoundError
37+
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError
3838
hf_hub_download = partial(hf_hub_download, library_name="timm", library_version=__version__)
3939
_has_hf_hub = True
4040
except ImportError:
@@ -540,3 +540,44 @@ def _get_safe_alternatives(filename: str) -> Iterable[str]:
540540
yield HF_OPEN_CLIP_SAFE_WEIGHTS_NAME
541541
if filename not in (HF_WEIGHTS_NAME, HF_OPEN_CLIP_WEIGHTS_NAME) and filename.endswith(".bin"):
542542
yield filename[:-4] + ".safetensors"
543+
544+
545+
def _get_license_from_hf_hub(model_id: str | None, hf_hub_id: str | None) -> str | None:
546+
"""Retrieve license information for a model from Hugging Face Hub.
547+
548+
Fetches the license field from the model card metadata on Hugging Face Hub
549+
for the specified model. Returns None if the model is not found, if
550+
huggingface_hub is not installed, or if the model is marked as "untrained".
551+
552+
Args:
553+
model_id: The model identifier/name. In the case of None we assume an untrained model.
554+
hf_hub_id: The Hugging Face Hub organization/user ID. If it is None,
555+
we will return None as we cannot infer the license terms.
556+
557+
Returns:
558+
The license string in lowercase if found, None otherwise.
559+
560+
Note:
561+
Requires huggingface_hub package to be installed. Will log a warning
562+
and return None if the package is not available.
563+
"""
564+
if not has_hf_hub(True):
565+
msg = "For updated license information run `pip install huggingface_hub`."
566+
_logger.warning(msg=msg)
567+
return None
568+
569+
if not (model_id and hf_hub_id):
570+
return None
571+
572+
repo_id: str = hf_hub_id + model_id
573+
574+
try:
575+
info = model_info(repo_id=repo_id)
576+
577+
except RepositoryNotFoundError:
578+
# TODO: any wish what happens here? @rwightman
579+
print(repo_id)
580+
return None
581+
582+
license = info.card_data.get("license").lower() if info.card_data else None
583+
return license

timm/models/cspnet.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
2323
from timm.layers import ClassifierHead, ConvNormAct, DropPath, get_attn, create_act_layer, make_divisible
2424
from ._builder import build_model_with_cfg
25+
from ._hub import _get_license_from_hf_hub
2526
from ._manipulate import named_apply, MATCH_PREV_GROUP
2627
from ._registry import register_model, generate_default_cfgs
2728

@@ -918,82 +919,97 @@ def _cfg(url='', **kwargs):
918919
'crop_pct': 0.887, 'interpolation': 'bilinear',
919920
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
920921
'first_conv': 'stem.conv1.conv', 'classifier': 'head.fc',
922+
'license': _get_license_from_hf_hub(kwargs.pop('model_id', None), kwargs.get('hf_hub_id')),
921923
**kwargs
922924
}
923925

924-
925926
default_cfgs = generate_default_cfgs({
926927
'cspresnet50.ra_in1k': _cfg(
927928
hf_hub_id='timm/',
929+
model_id='cspresnet50.ra_in1k',
928930
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/cspresnet50_ra-d3e8d487.pth'),
929931
'cspresnet50d.untrained': _cfg(),
930932
'cspresnet50w.untrained': _cfg(),
931933
'cspresnext50.ra_in1k': _cfg(
932934
hf_hub_id='timm/',
935+
model_id='cspresnext50.ra_in1k',
933936
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/cspresnext50_ra_224-648b4713.pth',
934937
),
935938
'cspdarknet53.ra_in1k': _cfg(
936939
hf_hub_id='timm/',
940+
model_id='cspdarknet53.ra_in1k',
937941
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/cspdarknet53_ra_256-d05c7c21.pth'),
938942

939943
'darknet17.untrained': _cfg(),
940944
'darknet21.untrained': _cfg(),
941945
'sedarknet21.untrained': _cfg(),
942946
'darknet53.c2ns_in1k': _cfg(
943947
hf_hub_id='timm/',
948+
model_id='darknet53.c2ns_in1k',
944949
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/darknet53_256_c2ns-3aeff817.pth',
945950
interpolation='bicubic', test_input_size=(3, 288, 288), test_crop_pct=1.0),
946951
'darknetaa53.c2ns_in1k': _cfg(
947952
hf_hub_id='timm/',
953+
model_id='darknetaa53.c2ns_in1k',
948954
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/darknetaa53_c2ns-5c28ec8a.pth',
949955
test_input_size=(3, 288, 288), test_crop_pct=1.0),
950956

951957
'cs3darknet_s.untrained': _cfg(interpolation='bicubic'),
952958
'cs3darknet_m.c2ns_in1k': _cfg(
953959
hf_hub_id='timm/',
960+
model_id='cs3darknet_m.c2ns_in1k',
954961
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/cs3darknet_m_c2ns-43f06604.pth',
955962
interpolation='bicubic', test_input_size=(3, 288, 288), test_crop_pct=0.95,
956963
),
957964
'cs3darknet_l.c2ns_in1k': _cfg(
958965
hf_hub_id='timm/',
966+
model_id='cs3darknet_l.c2ns_in1k',
959967
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/cs3darknet_l_c2ns-16220c5d.pth',
960968
interpolation='bicubic', test_input_size=(3, 288, 288), test_crop_pct=0.95),
961969
'cs3darknet_x.c2ns_in1k': _cfg(
962970
hf_hub_id='timm/',
971+
model_id='cs3darknet_x.c2ns_in1k',
963972
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/cs3darknet_x_c2ns-4e4490aa.pth',
964973
interpolation='bicubic', crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
965974

966975
'cs3darknet_focus_s.ra4_e3600_r256_in1k': _cfg(
967976
hf_hub_id='timm/',
977+
model_id='cs3darknet_focus_s.ra4_e3600_r256_in1k',
968978
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
969979
interpolation='bicubic', test_input_size=(3, 320, 320), test_crop_pct=1.0),
970980
'cs3darknet_focus_m.c2ns_in1k': _cfg(
971981
hf_hub_id='timm/',
982+
model_id='cs3darknet_focus_m.c2ns_in1k',
972983
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/cs3darknet_focus_m_c2ns-e23bed41.pth',
973984
interpolation='bicubic', test_input_size=(3, 288, 288), test_crop_pct=0.95),
974985
'cs3darknet_focus_l.c2ns_in1k': _cfg(
975986
hf_hub_id='timm/',
987+
model_id='cs3darknet_focus_l.c2ns_in1k',
976988
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/cs3darknet_focus_l_c2ns-65ef8888.pth',
977989
interpolation='bicubic', test_input_size=(3, 288, 288), test_crop_pct=0.95),
978990
'cs3darknet_focus_x.untrained': _cfg(interpolation='bicubic'),
979991

980992
'cs3sedarknet_l.c2ns_in1k': _cfg(
981993
hf_hub_id='timm/',
994+
model_id='cs3sedarknet_l.c2ns_in1k',
982995
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/cs3sedarknet_l_c2ns-e8d1dc13.pth',
983996
interpolation='bicubic', test_input_size=(3, 288, 288), test_crop_pct=0.95),
984997
'cs3sedarknet_x.c2ns_in1k': _cfg(
985998
hf_hub_id='timm/',
999+
model_id='cs3sedarknet_x.c2ns_in1k',
9861000
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/cs3sedarknet_x_c2ns-b4d0abc0.pth',
9871001
interpolation='bicubic', test_input_size=(3, 288, 288), test_crop_pct=1.0),
9881002

9891003
'cs3sedarknet_xdw.untrained': _cfg(interpolation='bicubic'),
9901004

9911005
'cs3edgenet_x.c2_in1k': _cfg(
9921006
hf_hub_id='timm/',
1007+
model_id='cs3edgenet_x.c2_in1k',
9931008
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/cs3edgenet_x_c2-2e1610a9.pth',
9941009
interpolation='bicubic', test_input_size=(3, 288, 288), test_crop_pct=1.0),
9951010
'cs3se_edgenet_x.c2ns_in1k': _cfg(
9961011
hf_hub_id='timm/',
1012+
model_id='cs3se_edgenet_x.c2ns_in1k',
9971013
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/cs3se_edgenet_x_c2ns-76f8e3ac.pth',
9981014
interpolation='bicubic', crop_pct=0.95, test_input_size=(3, 320, 320), test_crop_pct=1.0),
9991015
})
@@ -1111,4 +1127,4 @@ def cs3edgenet_x(pretrained=False, **kwargs) -> CspNet:
11111127

11121128
@register_model
11131129
def cs3se_edgenet_x(pretrained=False, **kwargs) -> CspNet:
1114-
return _create_cspnet('cs3se_edgenet_x', pretrained=pretrained, **kwargs)
1130+
return _create_cspnet('cs3se_edgenet_x', pretrained=pretrained, **kwargs)

0 commit comments

Comments
 (0)