Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Enhance] Enhance get_torchvision_models #1867

Merged
merged 16 commits into from
Apr 18, 2022
Merged
57 changes: 57 additions & 0 deletions mmcv/model_zoo/torchvision_0.12.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
{
"alexnet": "https://download.pytorch.org/models/alexnet-owt-7be5be79.pth",
"densenet121": "https://download.pytorch.org/models/densenet121-a639ec97.pth",
"densenet169": "https://download.pytorch.org/models/densenet169-b2777c0a.pth",
"densenet201": "https://download.pytorch.org/models/densenet201-c1103571.pth",
"densenet161": "https://download.pytorch.org/models/densenet161-8d451a50.pth",
"efficientnet_b0": "https://download.pytorch.org/models/efficientnet_b0_rwightman-3dd342df.pth",
"efficientnet_b1": "https://download.pytorch.org/models/efficientnet_b1_rwightman-533bc792.pth",
"efficientnet_b2": "https://download.pytorch.org/models/efficientnet_b2_rwightman-bcdf34b7.pth",
"efficientnet_b3": "https://download.pytorch.org/models/efficientnet_b3_rwightman-cf984f9c.pth",
"efficientnet_b4": "https://download.pytorch.org/models/efficientnet_b4_rwightman-7eb33cd5.pth",
"efficientnet_b5": "https://download.pytorch.org/models/efficientnet_b5_lukemelas-b6417697.pth",
"efficientnet_b6": "https://download.pytorch.org/models/efficientnet_b6_lukemelas-c76e70fd.pth",
"efficientnet_b7": "https://download.pytorch.org/models/efficientnet_b7_lukemelas-dcc49843.pth",
"googlenet": "https://download.pytorch.org/models/googlenet-1378be20.pth",
"inception_v3_google": "https://download.pytorch.org/models/inception_v3_google-0cc3c7bd.pth",
"mobilenet_v2": "https://download.pytorch.org/models/mobilenet_v2-b0353104.pth",
"mobilenet_v3_large": "https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth",
"mobilenet_v3_small": "https://download.pytorch.org/models/mobilenet_v3_small-047dcff4.pth",
"regnet_y_400mf": "https://download.pytorch.org/models/regnet_y_400mf-c65dace8.pth",
"regnet_y_800mf": "https://download.pytorch.org/models/regnet_y_800mf-1b27b58c.pth",
"regnet_y_1_6gf": "https://download.pytorch.org/models/regnet_y_1_6gf-b11a554e.pth",
"regnet_y_3_2gf": "https://download.pytorch.org/models/regnet_y_3_2gf-b5a9779c.pth",
"regnet_y_8gf": "https://download.pytorch.org/models/regnet_y_8gf-d0d0e4a8.pth",
"regnet_y_16gf": "https://download.pytorch.org/models/regnet_y_16gf-9e6ed7dd.pth",
"regnet_y_32gf": "https://download.pytorch.org/models/regnet_y_32gf-4dee3f7a.pth",
"regnet_x_400mf": "https://download.pytorch.org/models/regnet_x_400mf-adf1edd5.pth",
"regnet_x_800mf": "https://download.pytorch.org/models/regnet_x_800mf-ad17e45c.pth",
"regnet_x_1_6gf": "https://download.pytorch.org/models/regnet_x_1_6gf-e3633e7f.pth",
"regnet_x_3_2gf": "https://download.pytorch.org/models/regnet_x_3_2gf-f342aeae.pth",
"regnet_x_8gf": "https://download.pytorch.org/models/regnet_x_8gf-03ceed89.pth",
"regnet_x_16gf": "https://download.pytorch.org/models/regnet_x_16gf-2007eb11.pth",
"regnet_x_32gf": "https://download.pytorch.org/models/regnet_x_32gf-9d47f8d0.pth",
"resnet18": "https://download.pytorch.org/models/resnet18-f37072fd.pth",
"resnet34": "https://download.pytorch.org/models/resnet34-b627a593.pth",
"resnet50": "https://download.pytorch.org/models/resnet50-0676ba61.pth",
"resnet101": "https://download.pytorch.org/models/resnet101-63fe2227.pth",
"resnet152": "https://download.pytorch.org/models/resnet152-394f9c45.pth",
"resnext50_32x4d": "https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth",
"resnext101_32x8d": "https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth",
"wide_resnet50_2": "https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth",
"wide_resnet101_2": "https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth",
"shufflenetv2_x0.5": "https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth",
"shufflenetv2_x1.0": "https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth",
"shufflenetv2_x1.5": null,
"shufflenetv2_x2.0": null,
"squeezenet1_0": "https://download.pytorch.org/models/squeezenet1_0-b66bff10.pth",
"squeezenet1_1": "https://download.pytorch.org/models/squeezenet1_1-b8a52dc0.pth",
"vgg11": "https://download.pytorch.org/models/vgg11-8a719046.pth",
"vgg13": "https://download.pytorch.org/models/vgg13-19584684.pth",
"vgg16": "https://download.pytorch.org/models/vgg16-397923af.pth",
"vgg19": "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth",
"vgg11_bn": "https://download.pytorch.org/models/vgg11_bn-6002323d.pth",
"vgg13_bn": "https://download.pytorch.org/models/vgg13_bn-abd245e5.pth",
"vgg16_bn": "https://download.pytorch.org/models/vgg16_bn-6c64b313.pth",
"vgg19_bn": "https://download.pytorch.org/models/vgg19_bn-c79401a0.pth"
}
52 changes: 43 additions & 9 deletions mmcv/runner/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from ..fileio import FileClient
from ..fileio import load as load_file
from ..parallel import is_module_wrapper
from ..utils import load_url, mkdir_or_exist
from ..utils import digit_version, load_url, mkdir_or_exist
from .dist_utils import get_dist_info

ENV_MMCV_HOME = 'MMCV_HOME'
Expand Down Expand Up @@ -106,14 +106,48 @@ def load(module, prefix=''):


def get_torchvision_models():
model_urls = dict()
for _, name, ispkg in pkgutil.walk_packages(torchvision.models.__path__):
if ispkg:
continue
_zoo = import_module(f'torchvision.models.{name}')
if hasattr(_zoo, 'model_urls'):
_urls = getattr(_zoo, 'model_urls')
model_urls.update(_urls)
if digit_version(torchvision.__version__) < digit_version('0.13.0a0'):
model_urls = dict()
# When the version of torchvision is lower than 0.13, the model url is
# not declared in `torchvision.model.__init__.py`, so we need to
# iterate through `torchvision.models.__path__` to get the url for each
# model.
for _, name, ispkg in pkgutil.walk_packages(
torchvision.models.__path__):
if ispkg:
continue
_zoo = import_module(f'torchvision.models.{name}')
if hasattr(_zoo, 'model_urls'):
_urls = getattr(_zoo, 'model_urls')
model_urls.update(_urls)
else:
# Since torchvision bumps to 0.13, the weight loading logic, model keys
HAOCHENYE marked this conversation as resolved.
Show resolved Hide resolved
# and model urls have been changed. We will load the old version urls
HAOCHENYE marked this conversation as resolved.
Show resolved Hide resolved
# in mmcv/model_zoo/torchvision_0.12.json to prevent from BC
# Breaking. If your version of torchvision is higher than 0.13.0,
# new urls will be added to `model_urls` additionally. You can get the
# newest torchvision model by resnet50.imagenet1k_v1.
json_path = osp.join(mmcv.__path__[0],
'model_zoo/torchvision_0.12.json')
model_urls = mmcv.load(json_path)
for cls_name, cls in torchvision.models.__dict__.items():
# The name of torchvision model weights classes ends with
# `_Weights` such as ResNet18_Weights`. However, some model weight
HAOCHENYE marked this conversation as resolved.
Show resolved Hide resolved
# classes, such as `MNASNet0_75_Weights` does not have any urls in
# torchvision 0.13.0 and cannot be iterated. Here we simply check
# `DEFAULT` attribute to ensure the class is not empty.
if (not cls_name.endswith('_Weights')
or not hasattr(cls, 'DEFAULT')):
HAOCHENYE marked this conversation as resolved.
Show resolved Hide resolved
continue
# Since `cls.DEFAULT` can not be accessed by iterate cls, we set
HAOCHENYE marked this conversation as resolved.
Show resolved Hide resolved
# default urls explicitly.
cls_key = cls_name.replace('_Weights', '').lower()
HAOCHENYE marked this conversation as resolved.
Show resolved Hide resolved
model_urls[f'{cls_key}.default'] = cls.DEFAULT.url
for weight_enum in cls:
cls_key = cls_name.replace('_Weights', '').lower()
cls_key = f'{cls_key}.{weight_enum.name.lower()}'
model_urls[cls_key] = weight_enum.url

return model_urls


Expand Down
39 changes: 21 additions & 18 deletions tests/test_load_model_zoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,15 @@
from unittest.mock import patch

import pytest
import torchvision

import mmcv
from mmcv.runner.checkpoint import (DEFAULT_CACHE_DIR, ENV_MMCV_HOME,
ENV_XDG_CACHE_HOME, _get_mmcv_home,
_load_checkpoint,
get_deprecated_model_names,
get_external_models)
from mmcv.utils import TORCH_VERSION
from mmcv.utils import digit_version


@patch('mmcv.__path__', [osp.join(osp.dirname(__file__), 'data/')])
Expand Down Expand Up @@ -77,24 +78,26 @@ def load(filepath, map_location=None):
@patch('torch.load', load)
def test_load_external_url():
# test modelzoo://
url = _load_checkpoint('modelzoo://resnet50')
if TORCH_VERSION < '1.9.0':
assert url == ('url:https://download.pytorch.org/models/resnet50-19c8e'
'357.pth')
torchvision_version = torchvision.__version__
if digit_version(torchvision_version) < digit_version('0.10.0'):
HAOCHENYE marked this conversation as resolved.
Show resolved Hide resolved
assert (_load_checkpoint('modelzoo://resnet50') ==
'url:https://download.pytorch.org/models/resnet50-19c8e'
'357.pth')
assert (_load_checkpoint('torchvision://resnet50') ==
'url:https://download.pytorch.org/models/resnet50-19c8e'
'357.pth')
else:
# filename of checkpoint is renamed in torch1.9.0
assert url == ('url:https://download.pytorch.org/models/resnet50-0676b'
'a61.pth')

# test torchvision://
url = _load_checkpoint('torchvision://resnet50')
if TORCH_VERSION < '1.9.0':
assert url == ('url:https://download.pytorch.org/models/resnet50-19c8e'
'357.pth')
else:
# filename of checkpoint is renamed in torch1.9.0
assert url == ('url:https://download.pytorch.org/models/resnet50-0676b'
'a61.pth')
assert (_load_checkpoint('modelzoo://resnet50') ==
'url:https://download.pytorch.org/models/resnet50-0676b'
'a61.pth')
assert (_load_checkpoint('torchvision://resnet50') ==
'url:https://download.pytorch.org/models/resnet50-0676b'
'a61.pth')

if digit_version(torchvision.__version__) >= digit_version('0.13.0a0'):
HAOCHENYE marked this conversation as resolved.
Show resolved Hide resolved
# Test load new format torchvision models.
_load_checkpoint('torchvision://resnet50.imagenet1k_v1')
HAOCHENYE marked this conversation as resolved.
Show resolved Hide resolved
_load_checkpoint('torchvision://resnet50.default')

# test open-mmlab:// with default MMCV_HOME
os.environ.pop(ENV_MMCV_HOME, None)
Expand Down