From 4fe5a37579c324a6ec54a9f036adc3671fe46796 Mon Sep 17 00:00:00 2001 From: HAOCHENYE <21724054@zju.edu.cn> Date: Mon, 11 Apr 2022 20:32:03 +0800 Subject: [PATCH 01/16] enhance get_torchvision_models --- mmcv/runner/checkpoint.py | 33 +++++++++++++++++++++++++-------- 1 file changed, 25 insertions(+), 8 deletions(-) diff --git a/mmcv/runner/checkpoint.py b/mmcv/runner/checkpoint.py index 7eaa0816c9..864cb959ad 100644 --- a/mmcv/runner/checkpoint.py +++ b/mmcv/runner/checkpoint.py @@ -18,7 +18,7 @@ from ..fileio import FileClient from ..fileio import load as load_file from ..parallel import is_module_wrapper -from ..utils import load_url, mkdir_or_exist +from ..utils import digit_version, load_url, mkdir_or_exist from .dist_utils import get_dist_info ENV_MMCV_HOME = 'MMCV_HOME' @@ -107,13 +107,30 @@ def load(module, prefix=''): def get_torchvision_models(): model_urls = dict() - for _, name, ispkg in pkgutil.walk_packages(torchvision.models.__path__): - if ispkg: - continue - _zoo = import_module(f'torchvision.models.{name}') - if hasattr(_zoo, 'model_urls'): - _urls = getattr(_zoo, 'model_urls') - model_urls.update(_urls) + if digit_version(torchvision.__version__) <= digit_version('0.12.1'): + for _, name, ispkg in pkgutil.walk_packages( + torchvision.models.__path__): + if ispkg: + continue + if name.startswith('_'): + continue + _zoo = import_module(f'torchvision.models.{name}') + if hasattr(_zoo, 'model_urls'): + _urls = getattr(_zoo, 'model_urls') + model_urls.update(_urls) + else: + for cls_name, cls in torchvision.models.__dict__.items(): + if not hasattr(cls, '__base__'): + continue + if cls.__base__ != torchvision.models._api.WeightsEnum: + continue + cls_key = cls_name.replace('_Weights', '').lower() + if hasattr(cls, 'DEFAULT'): + model_urls[cls_key] = cls.DEFAULT.url + else: + warnings.warn(f'{cls_key} does not have default weight, see ' + f'more information in' + f'torchvision.models.{cls_name}') return model_urls From d429f647bc08b58f7cb478ee4d1d687e8ca22f34 Mon Sep 17 00:00:00 2001 From: HAOCHENYE <21724054@zju.edu.cn> Date: Tue, 12 Apr 2022 10:57:06 +0800 Subject: [PATCH 02/16] simplify logic --- mmcv/runner/checkpoint.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/mmcv/runner/checkpoint.py b/mmcv/runner/checkpoint.py index 864cb959ad..05607ae0d5 100644 --- a/mmcv/runner/checkpoint.py +++ b/mmcv/runner/checkpoint.py @@ -120,17 +120,9 @@ def get_torchvision_models(): model_urls.update(_urls) else: for cls_name, cls in torchvision.models.__dict__.items(): - if not hasattr(cls, '__base__'): - continue - if cls.__base__ != torchvision.models._api.WeightsEnum: - continue cls_key = cls_name.replace('_Weights', '').lower() - if hasattr(cls, 'DEFAULT'): + if cls_name.endswith('_Weights') and hasattr(cls, 'DEFAULT'): model_urls[cls_key] = cls.DEFAULT.url - else: - warnings.warn(f'{cls_key} does not have default weight, see ' - f'more information in' - f'torchvision.models.{cls_name}') return model_urls From 7e98ec49cafc503bfcf43b6d93fbf36de36161f1 Mon Sep 17 00:00:00 2001 From: HAOCHENYE <21724054@zju.edu.cn> Date: Wed, 13 Apr 2022 13:51:11 +0800 Subject: [PATCH 03/16] Dump ckpt in torchvision lower than 0.13.0 to a json file --- mmcv/runner/checkpoint.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/mmcv/runner/checkpoint.py b/mmcv/runner/checkpoint.py index 05607ae0d5..8d9b53ef75 100644 --- a/mmcv/runner/checkpoint.py +++ b/mmcv/runner/checkpoint.py @@ -2,12 +2,10 @@ import io import os import os.path as osp -import pkgutil import re import time import warnings from collections import OrderedDict -from importlib import import_module from tempfile import TemporaryDirectory import torch @@ -107,18 +105,20 @@ def load(module, prefix=''): def get_torchvision_models(): model_urls = dict() + # Since torchvision reconstruct its weight loading logic, some model keys + # and urls in `model_urls` have been changed. If you want to experiment + # based on old weights, please use torchvision lower than 13.0. See more + # details at https://github.com/open-mmlab/mmcv/issues/1848. if digit_version(torchvision.__version__) <= digit_version('0.12.1'): - for _, name, ispkg in pkgutil.walk_packages( - torchvision.models.__path__): - if ispkg: - continue - if name.startswith('_'): - continue - _zoo = import_module(f'torchvision.models.{name}') - if hasattr(_zoo, 'model_urls'): - _urls = getattr(_zoo, 'model_urls') - model_urls.update(_urls) + model_zoo_path = osp.join( + osp.dirname(__file__), '..', 'model_zoo', + 'torchvision_before0.13.json') + return mmcv.load(model_zoo_path) else: + warnings.warn( + 'Checkpoints loaded from torchvision have been changed ' + 'since torchvision 0.13.0. If you want to experiment based on old ' + 'weights, please use torchvision lower than 13.0') for cls_name, cls in torchvision.models.__dict__.items(): cls_key = cls_name.replace('_Weights', '').lower() if cls_name.endswith('_Weights') and hasattr(cls, 'DEFAULT'): From 87c865e9633979bd307fa95a53d4c816575fc887 Mon Sep 17 00:00:00 2001 From: HAOCHENYE <21724054@zju.edu.cn> Date: Wed, 13 Apr 2022 15:17:43 +0800 Subject: [PATCH 04/16] add json --- mmcv/model_zoo/torchvision_before0.13.json | 57 ++++++++++++++++++++++ 1 file changed, 57 insertions(+) create mode 100644 mmcv/model_zoo/torchvision_before0.13.json diff --git a/mmcv/model_zoo/torchvision_before0.13.json b/mmcv/model_zoo/torchvision_before0.13.json new file mode 100644 index 0000000000..06defe6748 --- /dev/null +++ b/mmcv/model_zoo/torchvision_before0.13.json @@ -0,0 +1,57 @@ +{ + "alexnet": "https://download.pytorch.org/models/alexnet-owt-7be5be79.pth", + "densenet121": "https://download.pytorch.org/models/densenet121-a639ec97.pth", + "densenet169": "https://download.pytorch.org/models/densenet169-b2777c0a.pth", + "densenet201": "https://download.pytorch.org/models/densenet201-c1103571.pth", + "densenet161": "https://download.pytorch.org/models/densenet161-8d451a50.pth", + "efficientnet_b0": "https://download.pytorch.org/models/efficientnet_b0_rwightman-3dd342df.pth", + "efficientnet_b1": "https://download.pytorch.org/models/efficientnet_b1_rwightman-533bc792.pth", + "efficientnet_b2": "https://download.pytorch.org/models/efficientnet_b2_rwightman-bcdf34b7.pth", + "efficientnet_b3": "https://download.pytorch.org/models/efficientnet_b3_rwightman-cf984f9c.pth", + "efficientnet_b4": "https://download.pytorch.org/models/efficientnet_b4_rwightman-7eb33cd5.pth", + "efficientnet_b5": "https://download.pytorch.org/models/efficientnet_b5_lukemelas-b6417697.pth", + "efficientnet_b6": "https://download.pytorch.org/models/efficientnet_b6_lukemelas-c76e70fd.pth", + "efficientnet_b7": "https://download.pytorch.org/models/efficientnet_b7_lukemelas-dcc49843.pth", + "googlenet": "https://download.pytorch.org/models/googlenet-1378be20.pth", + "inception_v3_google": "https://download.pytorch.org/models/inception_v3_google-0cc3c7bd.pth", + "mobilenet_v2": "https://download.pytorch.org/models/mobilenet_v2-b0353104.pth", + "mobilenet_v3_large": "https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth", + "mobilenet_v3_small": "https://download.pytorch.org/models/mobilenet_v3_small-047dcff4.pth", + "regnet_y_400mf": "https://download.pytorch.org/models/regnet_y_400mf-c65dace8.pth", + "regnet_y_800mf": "https://download.pytorch.org/models/regnet_y_800mf-1b27b58c.pth", + "regnet_y_1_6gf": "https://download.pytorch.org/models/regnet_y_1_6gf-b11a554e.pth", + "regnet_y_3_2gf": "https://download.pytorch.org/models/regnet_y_3_2gf-b5a9779c.pth", + "regnet_y_8gf": "https://download.pytorch.org/models/regnet_y_8gf-d0d0e4a8.pth", + "regnet_y_16gf": "https://download.pytorch.org/models/regnet_y_16gf-9e6ed7dd.pth", + "regnet_y_32gf": "https://download.pytorch.org/models/regnet_y_32gf-4dee3f7a.pth", + "regnet_x_400mf": "https://download.pytorch.org/models/regnet_x_400mf-adf1edd5.pth", + "regnet_x_800mf": "https://download.pytorch.org/models/regnet_x_800mf-ad17e45c.pth", + "regnet_x_1_6gf": "https://download.pytorch.org/models/regnet_x_1_6gf-e3633e7f.pth", + "regnet_x_3_2gf": "https://download.pytorch.org/models/regnet_x_3_2gf-f342aeae.pth", + "regnet_x_8gf": "https://download.pytorch.org/models/regnet_x_8gf-03ceed89.pth", + "regnet_x_16gf": "https://download.pytorch.org/models/regnet_x_16gf-2007eb11.pth", + "regnet_x_32gf": "https://download.pytorch.org/models/regnet_x_32gf-9d47f8d0.pth", + "resnet18": "https://download.pytorch.org/models/resnet18-f37072fd.pth", + "resnet34": "https://download.pytorch.org/models/resnet34-b627a593.pth", + "resnet50": "https://download.pytorch.org/models/resnet50-0676ba61.pth", + "resnet101": "https://download.pytorch.org/models/resnet101-63fe2227.pth", + "resnet152": "https://download.pytorch.org/models/resnet152-394f9c45.pth", + "resnext50_32x4d": "https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth", + "resnext101_32x8d": "https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth", + "wide_resnet50_2": "https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth", + "wide_resnet101_2": "https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth", + "shufflenetv2_x0.5": "https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth", + "shufflenetv2_x1.0": "https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth", + "shufflenetv2_x1.5": null, + "shufflenetv2_x2.0": null, + "squeezenet1_0": "https://download.pytorch.org/models/squeezenet1_0-b66bff10.pth", + "squeezenet1_1": "https://download.pytorch.org/models/squeezenet1_1-b8a52dc0.pth", + "vgg11": "https://download.pytorch.org/models/vgg11-8a719046.pth", + "vgg13": "https://download.pytorch.org/models/vgg13-19584684.pth", + "vgg16": "https://download.pytorch.org/models/vgg16-397923af.pth", + "vgg19": "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth", + "vgg11_bn": "https://download.pytorch.org/models/vgg11_bn-6002323d.pth", + "vgg13_bn": "https://download.pytorch.org/models/vgg13_bn-abd245e5.pth", + "vgg16_bn": "https://download.pytorch.org/models/vgg16_bn-6c64b313.pth", + "vgg19_bn": "https://download.pytorch.org/models/vgg19_bn-c79401a0.pth" +} From 2866cdf775d09e515d06e7dba6c9c94f92ee59cd Mon Sep 17 00:00:00 2001 From: HAOCHENYE <21724054@zju.edu.cn> Date: Fri, 15 Apr 2022 18:36:18 +0800 Subject: [PATCH 05/16] refactor load urls logic --- mmcv/runner/checkpoint.py | 49 ++++++++++++------- .../model_zoo/torchvision_before0.13.json | 3 ++ tests/test_load_model_zoo.py | 10 +++- 3 files changed, 44 insertions(+), 18 deletions(-) create mode 100644 tests/data/model_zoo/torchvision_before0.13.json diff --git a/mmcv/runner/checkpoint.py b/mmcv/runner/checkpoint.py index 8d9b53ef75..a4f63f9610 100644 --- a/mmcv/runner/checkpoint.py +++ b/mmcv/runner/checkpoint.py @@ -2,10 +2,12 @@ import io import os import os.path as osp +import pkgutil import re import time import warnings from collections import OrderedDict +from importlib import import_module from tempfile import TemporaryDirectory import torch @@ -104,25 +106,38 @@ def load(module, prefix=''): def get_torchvision_models(): - model_urls = dict() - # Since torchvision reconstruct its weight loading logic, some model keys - # and urls in `model_urls` have been changed. If you want to experiment - # based on old weights, please use torchvision lower than 13.0. See more - # details at https://github.com/open-mmlab/mmcv/issues/1848. - if digit_version(torchvision.__version__) <= digit_version('0.12.1'): - model_zoo_path = osp.join( - osp.dirname(__file__), '..', 'model_zoo', - 'torchvision_before0.13.json') - return mmcv.load(model_zoo_path) + if digit_version(torchvision.__version__) <= digit_version('0.12.0'): + model_urls = dict() + # When the version of torchvision is lower than 0.13, the model url is + # not declared in `torchvision.model.__init__.py`, so we need to + # iterate through `torchvision.models.__path__` to get the url for each + # model. + for _, name, ispkg in pkgutil.walk_packages( + torchvision.models.__path__): + if ispkg: + continue + _zoo = import_module(f'torchvision.models.{name}') + if hasattr(_zoo, 'model_urls'): + _urls = getattr(_zoo, 'model_urls') + model_urls.update(_urls) else: - warnings.warn( - 'Checkpoints loaded from torchvision have been changed ' - 'since torchvision 0.13.0. If you want to experiment based on old ' - 'weights, please use torchvision lower than 13.0') + # Since torchvision bumps to 0.13, the weight loading logic, model keys + # and model urls have been changed. We will load the old version urls + # in mmcv/model_zoo/torchvision_before0.13.json to prevent from BC + # Breaking. If your version of torchvision is higher than 0.13.0, + # new urls will be added to `model_urls` additionally. You can get the + # newest torchvision model by resnet50.IMAGENET1K_V1. + json_path = osp.join(mmcv.__path__[0], + 'model_zoo/torchvision_before0.13.json') + model_urls = mmcv.load(json_path) for cls_name, cls in torchvision.models.__dict__.items(): - cls_key = cls_name.replace('_Weights', '').lower() - if cls_name.endswith('_Weights') and hasattr(cls, 'DEFAULT'): - model_urls[cls_key] = cls.DEFAULT.url + if not cls_name.endswith('_Weights'): + continue + + for weight_enum in cls: + cls_key = cls_name.replace('_Weights', '').lower() + cls_key = f'{cls_key}.{weight_enum.name}' + model_urls[cls_key] = weight_enum.url return model_urls diff --git a/tests/data/model_zoo/torchvision_before0.13.json b/tests/data/model_zoo/torchvision_before0.13.json new file mode 100644 index 0000000000..2ea3a8e474 --- /dev/null +++ b/tests/data/model_zoo/torchvision_before0.13.json @@ -0,0 +1,3 @@ +{ + "resnet50": "https://download.pytorch.org/models/resnet50-0676ba61.pth" +} diff --git a/tests/test_load_model_zoo.py b/tests/test_load_model_zoo.py index 35492fa8a0..ed7a482a76 100644 --- a/tests/test_load_model_zoo.py +++ b/tests/test_load_model_zoo.py @@ -4,6 +4,7 @@ from unittest.mock import patch import pytest +import torchvision import mmcv from mmcv.runner.checkpoint import (DEFAULT_CACHE_DIR, ENV_MMCV_HOME, @@ -11,7 +12,7 @@ _load_checkpoint, get_deprecated_model_names, get_external_models) -from mmcv.utils import TORCH_VERSION +from mmcv.utils import TORCH_VERSION, digit_version @patch('mmcv.__path__', [osp.join(osp.dirname(__file__), 'data/')]) @@ -96,6 +97,13 @@ def test_load_external_url(): assert url == ('url:https://download.pytorch.org/models/resnet50-0676b' 'a61.pth') + if digit_version(torchvision.__version__) > digit_version('0.12.0'): + assert ( + _load_checkpoint('torchvision://resnet50.IMAGENET1K_V1') == + 'url:https://download.pytorch.org/models/resnet50-0676ba61.pth') + assert ( + _load_checkpoint('torchvision://resnet50') == + 'url:https://download.pytorch.org/models/resnet50-0676ba61.pth') # test open-mmlab:// with default MMCV_HOME os.environ.pop(ENV_MMCV_HOME, None) os.environ.pop(ENV_XDG_CACHE_HOME, None) From d4e8365fd40bad7900b6ac0ea386c4f0542be760 Mon Sep 17 00:00:00 2001 From: HAOCHENYE <21724054@zju.edu.cn> Date: Fri, 15 Apr 2022 18:41:47 +0800 Subject: [PATCH 06/16] fix unit test --- tests/test_load_model_zoo.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_load_model_zoo.py b/tests/test_load_model_zoo.py index ed7a482a76..7aa5df399e 100644 --- a/tests/test_load_model_zoo.py +++ b/tests/test_load_model_zoo.py @@ -98,9 +98,9 @@ def test_load_external_url(): 'a61.pth') if digit_version(torchvision.__version__) > digit_version('0.12.0'): - assert ( - _load_checkpoint('torchvision://resnet50.IMAGENET1K_V1') == - 'url:https://download.pytorch.org/models/resnet50-0676ba61.pth') + # Test load new format torchvision models. + _load_checkpoint('torchvision://resnet50.IMAGENET1K_V1') + _load_checkpoint('torchvision://resnet50.DEFAULT') assert ( _load_checkpoint('torchvision://resnet50') == 'url:https://download.pytorch.org/models/resnet50-0676ba61.pth') From 532ee137222ceda0653691cd7b01f1a6907c40e2 Mon Sep 17 00:00:00 2001 From: HAOCHENYE <21724054@zju.edu.cn> Date: Fri, 15 Apr 2022 23:34:44 +0800 Subject: [PATCH 07/16] change url key to lower letters --- ...on_before0.13.json => torchvision_0.12.json} | 0 mmcv/runner/checkpoint.py | 17 +++++++++++------ .../data/model_zoo/torchvision_before0.13.json | 3 --- tests/test_load_model_zoo.py | 4 ++-- 4 files changed, 13 insertions(+), 11 deletions(-) rename mmcv/model_zoo/{torchvision_before0.13.json => torchvision_0.12.json} (100%) delete mode 100644 tests/data/model_zoo/torchvision_before0.13.json diff --git a/mmcv/model_zoo/torchvision_before0.13.json b/mmcv/model_zoo/torchvision_0.12.json similarity index 100% rename from mmcv/model_zoo/torchvision_before0.13.json rename to mmcv/model_zoo/torchvision_0.12.json diff --git a/mmcv/runner/checkpoint.py b/mmcv/runner/checkpoint.py index a4f63f9610..657ce29c3a 100644 --- a/mmcv/runner/checkpoint.py +++ b/mmcv/runner/checkpoint.py @@ -123,21 +123,26 @@ def get_torchvision_models(): else: # Since torchvision bumps to 0.13, the weight loading logic, model keys # and model urls have been changed. We will load the old version urls - # in mmcv/model_zoo/torchvision_before0.13.json to prevent from BC + # in mmcv/model_zoo/torchvision_0.12.json to prevent from BC # Breaking. If your version of torchvision is higher than 0.13.0, # new urls will be added to `model_urls` additionally. You can get the - # newest torchvision model by resnet50.IMAGENET1K_V1. + # newest torchvision model by resnet50.imagenet1k_v1. json_path = osp.join(mmcv.__path__[0], - 'model_zoo/torchvision_before0.13.json') + 'model_zoo/torchvision_0.12.json') model_urls = mmcv.load(json_path) for cls_name, cls in torchvision.models.__dict__.items(): - if not cls_name.endswith('_Weights'): + if (not cls_name.endswith('_Weights') + or not hasattr(cls, 'DEFAULT')): continue - + # Since `cls.DEFAULT` can not be accessed by iterate cls, we set + # default urls explicitly. + cls_key = cls_name.replace('_Weights', '').lower() + model_urls[f'{cls_key}.default'] = cls.DEFAULT.url for weight_enum in cls: cls_key = cls_name.replace('_Weights', '').lower() - cls_key = f'{cls_key}.{weight_enum.name}' + cls_key = f'{cls_key}.{weight_enum.name.lower()}' model_urls[cls_key] = weight_enum.url + return model_urls diff --git a/tests/data/model_zoo/torchvision_before0.13.json b/tests/data/model_zoo/torchvision_before0.13.json deleted file mode 100644 index 2ea3a8e474..0000000000 --- a/tests/data/model_zoo/torchvision_before0.13.json +++ /dev/null @@ -1,3 +0,0 @@ -{ - "resnet50": "https://download.pytorch.org/models/resnet50-0676ba61.pth" -} diff --git a/tests/test_load_model_zoo.py b/tests/test_load_model_zoo.py index 7aa5df399e..153d7c65f8 100644 --- a/tests/test_load_model_zoo.py +++ b/tests/test_load_model_zoo.py @@ -99,8 +99,8 @@ def test_load_external_url(): if digit_version(torchvision.__version__) > digit_version('0.12.0'): # Test load new format torchvision models. - _load_checkpoint('torchvision://resnet50.IMAGENET1K_V1') - _load_checkpoint('torchvision://resnet50.DEFAULT') + _load_checkpoint('torchvision://resnet50.imagenet1k_v1') + _load_checkpoint('torchvision://resnet50.default') assert ( _load_checkpoint('torchvision://resnet50') == 'url:https://download.pytorch.org/models/resnet50-0676ba61.pth') From d2931fd4d944f6726469b7be33f0fc73ef4cb293 Mon Sep 17 00:00:00 2001 From: HAOCHENYE <21724054@zju.edu.cn> Date: Fri, 15 Apr 2022 23:39:48 +0800 Subject: [PATCH 08/16] check torchvision version rather than check torch version in unittest --- tests/test_load_model_zoo.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/test_load_model_zoo.py b/tests/test_load_model_zoo.py index 153d7c65f8..0a29d9b6d8 100644 --- a/tests/test_load_model_zoo.py +++ b/tests/test_load_model_zoo.py @@ -12,7 +12,7 @@ _load_checkpoint, get_deprecated_model_names, get_external_models) -from mmcv.utils import TORCH_VERSION, digit_version +from mmcv.utils import digit_version @patch('mmcv.__path__', [osp.join(osp.dirname(__file__), 'data/')]) @@ -78,8 +78,9 @@ def load(filepath, map_location=None): @patch('torch.load', load) def test_load_external_url(): # test modelzoo:// + torchvision_version = torchvision.__version__ url = _load_checkpoint('modelzoo://resnet50') - if TORCH_VERSION < '1.9.0': + if torchvision_version < '0.10.0': assert url == ('url:https://download.pytorch.org/models/resnet50-19c8e' '357.pth') else: @@ -89,7 +90,7 @@ def test_load_external_url(): # test torchvision:// url = _load_checkpoint('torchvision://resnet50') - if TORCH_VERSION < '1.9.0': + if torchvision_version < '0.10.0': assert url == ('url:https://download.pytorch.org/models/resnet50-19c8e' '357.pth') else: From d3b2da9bfb3118795c128e38020c4d8cc90c730d Mon Sep 17 00:00:00 2001 From: HAOCHENYE <21724054@zju.edu.cn> Date: Sat, 16 Apr 2022 21:47:25 +0800 Subject: [PATCH 09/16] Fix CI and refine test logic of torchvision version --- mmcv/runner/checkpoint.py | 1 + tests/test_load_model_zoo.py | 34 ++++++++++++++-------------------- 2 files changed, 15 insertions(+), 20 deletions(-) diff --git a/mmcv/runner/checkpoint.py b/mmcv/runner/checkpoint.py index 657ce29c3a..6b7ec87400 100644 --- a/mmcv/runner/checkpoint.py +++ b/mmcv/runner/checkpoint.py @@ -131,6 +131,7 @@ def get_torchvision_models(): 'model_zoo/torchvision_0.12.json') model_urls = mmcv.load(json_path) for cls_name, cls in torchvision.models.__dict__.items(): + # if (not cls_name.endswith('_Weights') or not hasattr(cls, 'DEFAULT')): continue diff --git a/tests/test_load_model_zoo.py b/tests/test_load_model_zoo.py index 0a29d9b6d8..8f03b5c454 100644 --- a/tests/test_load_model_zoo.py +++ b/tests/test_load_model_zoo.py @@ -79,32 +79,26 @@ def load(filepath, map_location=None): def test_load_external_url(): # test modelzoo:// torchvision_version = torchvision.__version__ - url = _load_checkpoint('modelzoo://resnet50') - if torchvision_version < '0.10.0': - assert url == ('url:https://download.pytorch.org/models/resnet50-19c8e' - '357.pth') + if digit_version(torchvision_version) < digit_version('0.10.0'): + assert (_load_checkpoint('modelzoo://resnet50') == + 'url:https://download.pytorch.org/models/resnet50-19c8e' + '357.pth') + assert (_load_checkpoint('torchvision://resnet50') == + 'url:https://download.pytorch.org/models/resnet50-19c8e' + '357.pth') else: - # filename of checkpoint is renamed in torch1.9.0 - assert url == ('url:https://download.pytorch.org/models/resnet50-0676b' - 'a61.pth') - - # test torchvision:// - url = _load_checkpoint('torchvision://resnet50') - if torchvision_version < '0.10.0': - assert url == ('url:https://download.pytorch.org/models/resnet50-19c8e' - '357.pth') - else: - # filename of checkpoint is renamed in torch1.9.0 - assert url == ('url:https://download.pytorch.org/models/resnet50-0676b' - 'a61.pth') + assert (_load_checkpoint('modelzoo://resnet50') == + 'url:https://download.pytorch.org/models/resnet50-0676b' + 'a61.pth') + assert (_load_checkpoint('torchvision://resnet50') == + 'url:https://download.pytorch.org/models/resnet50-0676b' + 'a61.pth') if digit_version(torchvision.__version__) > digit_version('0.12.0'): # Test load new format torchvision models. _load_checkpoint('torchvision://resnet50.imagenet1k_v1') _load_checkpoint('torchvision://resnet50.default') - assert ( - _load_checkpoint('torchvision://resnet50') == - 'url:https://download.pytorch.org/models/resnet50-0676ba61.pth') + # test open-mmlab:// with default MMCV_HOME os.environ.pop(ENV_MMCV_HOME, None) os.environ.pop(ENV_XDG_CACHE_HOME, None) From f0d01a4f6c82410b1e784a0ca5bac091ee3341d5 Mon Sep 17 00:00:00 2001 From: HAOCHENYE <21724054@zju.edu.cn> Date: Sat, 16 Apr 2022 22:02:52 +0800 Subject: [PATCH 10/16] add comment --- mmcv/runner/checkpoint.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/mmcv/runner/checkpoint.py b/mmcv/runner/checkpoint.py index 6b7ec87400..2f03947f25 100644 --- a/mmcv/runner/checkpoint.py +++ b/mmcv/runner/checkpoint.py @@ -131,7 +131,11 @@ def get_torchvision_models(): 'model_zoo/torchvision_0.12.json') model_urls = mmcv.load(json_path) for cls_name, cls in torchvision.models.__dict__.items(): - # + # The name of torchvision model weights classes ends with + # `_Weights` such as ResNet18_Weights`. However, some model weight + # classes, such as `MNASNet0_75_Weights` does not have any urls in + # torchvision 0.13.0 and cannot be iterated. Here we simply check + # `DEFAULT` attribute to ensure the class is not empty. if (not cls_name.endswith('_Weights') or not hasattr(cls, 'DEFAULT')): continue From 479b89fe04cd87b0c32298966bd594f0786554b5 Mon Sep 17 00:00:00 2001 From: HAOCHENYE <21724054@zju.edu.cn> Date: Sun, 17 Apr 2022 09:56:08 +0800 Subject: [PATCH 11/16] support compare pre-release version --- mmcv/runner/checkpoint.py | 2 +- tests/test_load_model_zoo.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/mmcv/runner/checkpoint.py b/mmcv/runner/checkpoint.py index 2f03947f25..03af2c7484 100644 --- a/mmcv/runner/checkpoint.py +++ b/mmcv/runner/checkpoint.py @@ -106,7 +106,7 @@ def load(module, prefix=''): def get_torchvision_models(): - if digit_version(torchvision.__version__) <= digit_version('0.12.0'): + if digit_version(torchvision.__version__) < digit_version('0.13.0a0'): model_urls = dict() # When the version of torchvision is lower than 0.13, the model url is # not declared in `torchvision.model.__init__.py`, so we need to diff --git a/tests/test_load_model_zoo.py b/tests/test_load_model_zoo.py index 8f03b5c454..b1c4f6acbd 100644 --- a/tests/test_load_model_zoo.py +++ b/tests/test_load_model_zoo.py @@ -94,7 +94,7 @@ def test_load_external_url(): 'url:https://download.pytorch.org/models/resnet50-0676b' 'a61.pth') - if digit_version(torchvision.__version__) > digit_version('0.12.0'): + if digit_version(torchvision.__version__) >= digit_version('0.13.0a0'): # Test load new format torchvision models. _load_checkpoint('torchvision://resnet50.imagenet1k_v1') _load_checkpoint('torchvision://resnet50.default') From 0b3ad6537139fe1c75941c87c5227555e42d9f4f Mon Sep 17 00:00:00 2001 From: HAOCHENYE <21724054@zju.edu.cn> Date: Sun, 17 Apr 2022 13:34:27 +0800 Subject: [PATCH 12/16] support loaad modeel like torchvision --- mmcv/runner/checkpoint.py | 4 ++++ tests/test_load_model_zoo.py | 1 + 2 files changed, 5 insertions(+) diff --git a/mmcv/runner/checkpoint.py b/mmcv/runner/checkpoint.py index 03af2c7484..cdc71b57db 100644 --- a/mmcv/runner/checkpoint.py +++ b/mmcv/runner/checkpoint.py @@ -430,6 +430,10 @@ def load_from_torchvision(filename, map_location=None): model_name = filename[11:] else: model_name = filename[14:] + + # Support get model urls like torchvision, `ResNet50_Weights.IMAGENET1K_V1` + # will be mapped to resnet50.imagenet1k_v1 + model_name = model_name.replace('_Weights', '').lower() return load_from_http(model_urls[model_name], map_location=map_location) diff --git a/tests/test_load_model_zoo.py b/tests/test_load_model_zoo.py index b1c4f6acbd..42f260e8aa 100644 --- a/tests/test_load_model_zoo.py +++ b/tests/test_load_model_zoo.py @@ -97,6 +97,7 @@ def test_load_external_url(): if digit_version(torchvision.__version__) >= digit_version('0.13.0a0'): # Test load new format torchvision models. _load_checkpoint('torchvision://resnet50.imagenet1k_v1') + _load_checkpoint('torchvision://ResNet50_Weights.IMAGENET1K_V1') _load_checkpoint('torchvision://resnet50.default') # test open-mmlab:// with default MMCV_HOME From 5bacf771998e6decb52441bfd97fa9b2487c2395 Mon Sep 17 00:00:00 2001 From: HAOCHENYE <21724054@zju.edu.cn> Date: Mon, 18 Apr 2022 11:55:31 +0800 Subject: [PATCH 13/16] refine comment. --- mmcv/runner/checkpoint.py | 19 ++++++++++--------- tests/test_load_model_zoo.py | 2 +- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/mmcv/runner/checkpoint.py b/mmcv/runner/checkpoint.py index cdc71b57db..c83ddfbe8a 100644 --- a/mmcv/runner/checkpoint.py +++ b/mmcv/runner/checkpoint.py @@ -121,18 +121,18 @@ def get_torchvision_models(): _urls = getattr(_zoo, 'model_urls') model_urls.update(_urls) else: - # Since torchvision bumps to 0.13, the weight loading logic, model keys - # and model urls have been changed. We will load the old version urls - # in mmcv/model_zoo/torchvision_0.12.json to prevent from BC - # Breaking. If your version of torchvision is higher than 0.13.0, - # new urls will be added to `model_urls` additionally. You can get the - # newest torchvision model by resnet50.imagenet1k_v1. + # Since torchvision bumps to v0.13, the weight loading logic, + # model keys and model urls have been changed. Here the URLs of old + # version is loaded to avoid breaking back compatibility. If the + # torchvision version>=0.13.0, new URLs will be added. Users can get + # the resnet50 checkpoint by setting 'resnet50.imagent1k_v1', + # 'resnet50' or 'ResNet50_Weights.IMAGENET1K_V1' in the config. json_path = osp.join(mmcv.__path__[0], 'model_zoo/torchvision_0.12.json') model_urls = mmcv.load(json_path) for cls_name, cls in torchvision.models.__dict__.items(): # The name of torchvision model weights classes ends with - # `_Weights` such as ResNet18_Weights`. However, some model weight + # `_Weights` such as `ResNet18_Weights`. However, some model weight # classes, such as `MNASNet0_75_Weights` does not have any urls in # torchvision 0.13.0 and cannot be iterated. Here we simply check # `DEFAULT` attribute to ensure the class is not empty. @@ -431,8 +431,9 @@ def load_from_torchvision(filename, map_location=None): else: model_name = filename[14:] - # Support get model urls like torchvision, `ResNet50_Weights.IMAGENET1K_V1` - # will be mapped to resnet50.imagenet1k_v1 + # Support getting model urls like torchvision + # `ResNet50_Weights.IMAGENET1K_V1` will be mapped to + # resnet50.imagenet1k_v1. model_name = model_name.replace('_Weights', '').lower() return load_from_http(model_urls[model_name], map_location=map_location) diff --git a/tests/test_load_model_zoo.py b/tests/test_load_model_zoo.py index 42f260e8aa..9fc855a176 100644 --- a/tests/test_load_model_zoo.py +++ b/tests/test_load_model_zoo.py @@ -94,7 +94,7 @@ def test_load_external_url(): 'url:https://download.pytorch.org/models/resnet50-0676b' 'a61.pth') - if digit_version(torchvision.__version__) >= digit_version('0.13.0a0'): + if digit_version(torchvision_version) >= digit_version('0.13.0a0'): # Test load new format torchvision models. _load_checkpoint('torchvision://resnet50.imagenet1k_v1') _load_checkpoint('torchvision://ResNet50_Weights.IMAGENET1K_V1') From 5ba18528685511cd97d719edcc203ba4cf897a97 Mon Sep 17 00:00:00 2001 From: HAOCHENYE <21724054@zju.edu.cn> Date: Mon, 18 Apr 2022 14:40:35 +0800 Subject: [PATCH 14/16] fix unit test and comment --- mmcv/runner/checkpoint.py | 4 ++-- tests/test_load_model_zoo.py | 10 +++++++--- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/mmcv/runner/checkpoint.py b/mmcv/runner/checkpoint.py index c83ddfbe8a..acf2c522d3 100644 --- a/mmcv/runner/checkpoint.py +++ b/mmcv/runner/checkpoint.py @@ -139,7 +139,7 @@ def get_torchvision_models(): if (not cls_name.endswith('_Weights') or not hasattr(cls, 'DEFAULT')): continue - # Since `cls.DEFAULT` can not be accessed by iterate cls, we set + # Since `cls.DEFAULT` can not be accessed by iterating cls, we set # default urls explicitly. cls_key = cls_name.replace('_Weights', '').lower() model_urls[f'{cls_key}.default'] = cls.DEFAULT.url @@ -431,7 +431,7 @@ def load_from_torchvision(filename, map_location=None): else: model_name = filename[14:] - # Support getting model urls like torchvision + # Support getting model urls in the same way as torchvision # `ResNet50_Weights.IMAGENET1K_V1` will be mapped to # resnet50.imagenet1k_v1. model_name = model_name.replace('_Weights', '').lower() diff --git a/tests/test_load_model_zoo.py b/tests/test_load_model_zoo.py index 9fc855a176..da62aef488 100644 --- a/tests/test_load_model_zoo.py +++ b/tests/test_load_model_zoo.py @@ -79,7 +79,7 @@ def load(filepath, map_location=None): def test_load_external_url(): # test modelzoo:// torchvision_version = torchvision.__version__ - if digit_version(torchvision_version) < digit_version('0.10.0'): + if digit_version(torchvision_version) < digit_version('0.10.0a0'): assert (_load_checkpoint('modelzoo://resnet50') == 'url:https://download.pytorch.org/models/resnet50-19c8e' '357.pth') @@ -96,8 +96,12 @@ def test_load_external_url(): if digit_version(torchvision_version) >= digit_version('0.13.0a0'): # Test load new format torchvision models. - _load_checkpoint('torchvision://resnet50.imagenet1k_v1') - _load_checkpoint('torchvision://ResNet50_Weights.IMAGENET1K_V1') + (_load_checkpoint('torchvision://resnet50.imagenet1k_v1') == + 'url:https://download.pytorch.org/models/resnet50-0676ba61.pth') + + (_load_checkpoint('torchvision://ResNet50_Weights.IMAGENET1K_V1') == + 'url:https://download.pytorch.org/models/resnet50-0676ba61.pth') + _load_checkpoint('torchvision://resnet50.default') # test open-mmlab:// with default MMCV_HOME From 0df01968c0e449e6d604a39f7c79b7a24ff8efc3 Mon Sep 17 00:00:00 2001 From: HAOCHENYE <21724054@zju.edu.cn> Date: Mon, 18 Apr 2022 15:18:35 +0800 Subject: [PATCH 15/16] fxi unit test bug --- tests/test_load_model_zoo.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/tests/test_load_model_zoo.py b/tests/test_load_model_zoo.py index da62aef488..904cb94031 100644 --- a/tests/test_load_model_zoo.py +++ b/tests/test_load_model_zoo.py @@ -96,11 +96,13 @@ def test_load_external_url(): if digit_version(torchvision_version) >= digit_version('0.13.0a0'): # Test load new format torchvision models. - (_load_checkpoint('torchvision://resnet50.imagenet1k_v1') == - 'url:https://download.pytorch.org/models/resnet50-0676ba61.pth') + assert ( + _load_checkpoint('torchvision://resnet50.imagenet1k_v1') == + 'url:https://download.pytorch.org/models/resnet50-0676ba61.pth') - (_load_checkpoint('torchvision://ResNet50_Weights.IMAGENET1K_V1') == - 'url:https://download.pytorch.org/models/resnet50-0676ba61.pth') + assert ( + _load_checkpoint('torchvision://ResNet50_Weights.IMAGENET1K_V1') == + 'url:https://download.pytorch.org/models/resnet50-0676ba61.pth') _load_checkpoint('torchvision://resnet50.default') From da9616c89ce1d9b34779c0aedfcf207295e0a5d0 Mon Sep 17 00:00:00 2001 From: HAOCHENYE <21724054@zju.edu.cn> Date: Mon, 18 Apr 2022 22:50:05 +0800 Subject: [PATCH 16/16] support get model by lower weights --- mmcv/runner/checkpoint.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mmcv/runner/checkpoint.py b/mmcv/runner/checkpoint.py index acf2c522d3..835ee725a0 100644 --- a/mmcv/runner/checkpoint.py +++ b/mmcv/runner/checkpoint.py @@ -434,7 +434,7 @@ def load_from_torchvision(filename, map_location=None): # Support getting model urls in the same way as torchvision # `ResNet50_Weights.IMAGENET1K_V1` will be mapped to # resnet50.imagenet1k_v1. - model_name = model_name.replace('_Weights', '').lower() + model_name = model_name.lower().replace('_weights', '') return load_from_http(model_urls[model_name], map_location=map_location)