diff --git a/mmcv/cnn/bricks/conv.py b/mmcv/cnn/bricks/conv.py index ace744e039..a00b0a52ce 100644 --- a/mmcv/cnn/bricks/conv.py +++ b/mmcv/cnn/bricks/conv.py @@ -1,4 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. +import inspect from typing import Dict, Optional from mmengine.registry import MODELS @@ -35,7 +36,8 @@ def build_conv_layer(cfg: Optional[Dict], *args, **kwargs) -> nn.Module: cfg_ = cfg.copy() layer_type = cfg_.pop('type') - + if inspect.isclass(layer_type): + return layer_type(*args, **kwargs, **cfg_) # type: ignore # Switch registry to the target scope. If `conv_layer` cannot be found # in the registry, fallback to search `conv_layer` in the # mmengine.MODELS. diff --git a/mmcv/cnn/bricks/norm.py b/mmcv/cnn/bricks/norm.py index 2fff684af0..5aabab21a0 100644 --- a/mmcv/cnn/bricks/norm.py +++ b/mmcv/cnn/bricks/norm.py @@ -98,14 +98,17 @@ def build_norm_layer(cfg: Dict, layer_type = cfg_.pop('type') - # Switch registry to the target scope. If `norm_layer` cannot be found - # in the registry, fallback to search `norm_layer` in the - # mmengine.MODELS. - with MODELS.switch_scope_and_registry(None) as registry: - norm_layer = registry.get(layer_type) - if norm_layer is None: - raise KeyError(f'Cannot find {norm_layer} in registry under scope ' - f'name {registry.scope}') + if inspect.isclass(layer_type): + norm_layer = layer_type + else: + # Switch registry to the target scope. If `norm_layer` cannot be found + # in the registry, fallback to search `norm_layer` in the + # mmengine.MODELS. + with MODELS.switch_scope_and_registry(None) as registry: + norm_layer = registry.get(layer_type) + if norm_layer is None: + raise KeyError(f'Cannot find {norm_layer} in registry under ' + f'scope name {registry.scope}') abbr = infer_abbr(norm_layer) assert isinstance(postfix, (int, str)) @@ -113,7 +116,7 @@ def build_norm_layer(cfg: Dict, requires_grad = cfg_.pop('requires_grad', True) cfg_.setdefault('eps', 1e-5) - if layer_type != 'GN': + if norm_layer is not nn.GroupNorm: layer = norm_layer(num_features, **cfg_) if layer_type == 'SyncBN' and hasattr(layer, '_specify_ddp_gpu_num'): layer._specify_ddp_gpu_num(1) diff --git a/mmcv/cnn/bricks/padding.py b/mmcv/cnn/bricks/padding.py index 4135a190d6..3b29996b94 100644 --- a/mmcv/cnn/bricks/padding.py +++ b/mmcv/cnn/bricks/padding.py @@ -1,4 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. +import inspect from typing import Dict import torch.nn as nn @@ -27,7 +28,8 @@ def build_padding_layer(cfg: Dict, *args, **kwargs) -> nn.Module: cfg_ = cfg.copy() padding_type = cfg_.pop('type') - + if inspect.isclass(padding_type): + return padding_type(*args, **kwargs, **cfg_) # Switch registry to the target scope. If `padding_layer` cannot be found # in the registry, fallback to search `padding_layer` in the # mmengine.MODELS. diff --git a/mmcv/cnn/bricks/plugin.py b/mmcv/cnn/bricks/plugin.py index 83ba3737ab..3195ed13cf 100644 --- a/mmcv/cnn/bricks/plugin.py +++ b/mmcv/cnn/bricks/plugin.py @@ -79,15 +79,18 @@ def build_plugin_layer(cfg: Dict, cfg_ = cfg.copy() layer_type = cfg_.pop('type') - - # Switch registry to the target scope. If `plugin_layer` cannot be found - # in the registry, fallback to search `plugin_layer` in the - # mmengine.MODELS. - with MODELS.switch_scope_and_registry(None) as registry: - plugin_layer = registry.get(layer_type) - if plugin_layer is None: - raise KeyError(f'Cannot find {plugin_layer} in registry under scope ' - f'name {registry.scope}') + if inspect.isclass(layer_type): + plugin_layer = layer_type + else: + # Switch registry to the target scope. If `plugin_layer` cannot be + # found in the registry, fallback to search `plugin_layer` in the + # mmengine.MODELS. + with MODELS.switch_scope_and_registry(None) as registry: + plugin_layer = registry.get(layer_type) + if plugin_layer is None: + raise KeyError( + f'Cannot find {plugin_layer} in registry under scope ' + f'name {registry.scope}') abbr = infer_abbr(plugin_layer) assert isinstance(postfix, (int, str)) diff --git a/mmcv/cnn/bricks/upsample.py b/mmcv/cnn/bricks/upsample.py index d91689a1c8..78fb5bf371 100644 --- a/mmcv/cnn/bricks/upsample.py +++ b/mmcv/cnn/bricks/upsample.py @@ -1,4 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. +import inspect from typing import Dict import torch @@ -76,15 +77,18 @@ def build_upsample_layer(cfg: Dict, *args, **kwargs) -> nn.Module: layer_type = cfg_.pop('type') + if inspect.isclass(layer_type): + upsample = layer_type # Switch registry to the target scope. If `upsample` cannot be found # in the registry, fallback to search `upsample` in the # mmengine.MODELS. - with MODELS.switch_scope_and_registry(None) as registry: - upsample = registry.get(layer_type) - if upsample is None: - raise KeyError(f'Cannot find {upsample} in registry under scope ' - f'name {registry.scope}') - if upsample is nn.Upsample: - cfg_['mode'] = layer_type + else: + with MODELS.switch_scope_and_registry(None) as registry: + upsample = registry.get(layer_type) + if upsample is None: + raise KeyError(f'Cannot find {upsample} in registry under scope ' + f'name {registry.scope}') + if upsample is nn.Upsample: + cfg_['mode'] = layer_type layer = upsample(*args, **kwargs, **cfg_) return layer diff --git a/mmcv/ops/nms.py b/mmcv/ops/nms.py index feab4f3cad..fb08ba07c6 100644 --- a/mmcv/ops/nms.py +++ b/mmcv/ops/nms.py @@ -293,8 +293,9 @@ def batched_nms(boxes: Tensor, max_coordinate + torch.tensor(1).to(boxes)) boxes_for_nms = boxes + offsets[:, None] - nms_type = nms_cfg_.pop('type', 'nms') - nms_op = eval(nms_type) + nms_op = nms_cfg_.pop('type', 'nms') + if isinstance(nms_op, str): + nms_op = eval(nms_op) split_thr = nms_cfg_.pop('split_thr', 10000) # Won't split to multiple nms nodes when exporting to onnx diff --git a/tests/test_cnn/test_build_layers.py b/tests/test_cnn/test_build_layers.py index c8903ac40d..eefdd640ca 100644 --- a/tests/test_cnn/test_build_layers.py +++ b/tests/test_cnn/test_build_layers.py @@ -1,4 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. +import inspect from importlib import import_module import numpy as np @@ -7,10 +8,14 @@ import torch.nn as nn from mmengine.registry import MODELS from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm +from torch.nn import ReflectionPad2d, Upsample -from mmcv.cnn.bricks import (build_activation_layer, build_conv_layer, +from mmcv.cnn.bricks import (ContextBlock, ConvModule, ConvTranspose2d, + GeneralizedAttention, NonLocal2d, + build_activation_layer, build_conv_layer, build_norm_layer, build_padding_layer, build_plugin_layer, build_upsample_layer, is_norm) +from mmcv.cnn.bricks.activation import Clamp from mmcv.cnn.bricks.norm import infer_abbr as infer_norm_abbr from mmcv.cnn.bricks.plugin import infer_abbr as infer_plugin_abbr from mmcv.cnn.bricks.upsample import PixelShufflePack @@ -65,18 +70,19 @@ def test_build_conv_layer(): kwargs.pop('groups') for type_name, module in MODELS.module_dict.items(): - cfg = dict(type=type_name) - # SparseInverseConv2d and SparseInverseConv3d do not have the argument - # 'dilation' - if type_name == 'SparseInverseConv2d' or type_name == \ - 'SparseInverseConv3d': - kwargs.pop('dilation') - if 'conv' in type_name.lower(): - layer = build_conv_layer(cfg, **kwargs) - assert isinstance(layer, module) - assert layer.in_channels == kwargs['in_channels'] - assert layer.out_channels == kwargs['out_channels'] - kwargs['dilation'] = 2 # recover the key + for type_name_ in (type_name, module): + cfg = dict(type=type_name_) + # SparseInverseConv2d and SparseInverseConv3d do not have the + # argument 'dilation' + if type_name == 'SparseInverseConv2d' or type_name == \ + 'SparseInverseConv3d': + kwargs.pop('dilation') + if 'conv' in type_name.lower(): + layer = build_conv_layer(cfg, **kwargs) + assert isinstance(layer, module) + assert layer.in_channels == kwargs['in_channels'] + assert layer.out_channels == kwargs['out_channels'] + kwargs['dilation'] = 2 # recover the key def test_infer_norm_abbr(): @@ -162,17 +168,18 @@ def test_build_norm_layer(): if type_name == 'MMSyncBN': # skip MMSyncBN continue for postfix in ['_test', 1]: - cfg = dict(type=type_name) - if type_name == 'GN': - cfg['num_groups'] = 3 - name, layer = build_norm_layer(cfg, 3, postfix=postfix) - assert name == abbr_mapping[type_name] + str(postfix) - assert isinstance(layer, module) - if type_name == 'GN': - assert layer.num_channels == 3 - assert layer.num_groups == cfg['num_groups'] - elif type_name != 'LN': - assert layer.num_features == 3 + for type_name_ in (type_name, module): + cfg = dict(type=type_name_) + if type_name == 'GN': + cfg['num_groups'] = 3 + name, layer = build_norm_layer(cfg, 3, postfix=postfix) + assert name == abbr_mapping[type_name] + str(postfix) + assert isinstance(layer, module) + if type_name == 'GN': + assert layer.num_channels == 3 + assert layer.num_groups == cfg['num_groups'] + elif type_name != 'LN': + assert layer.num_features == 3 def test_build_activation_layer(): @@ -184,7 +191,7 @@ def test_build_activation_layer(): for module_name in ['activation', 'hsigmoid', 'hswish', 'swish']: act_module = import_module(f'mmcv.cnn.bricks.{module_name}') for key, value in act_module.__dict__.items(): - if isinstance(value, type) and issubclass(value, nn.Module): + if inspect.isclass(value) and issubclass(value, nn.Module): act_names.append(key) with pytest.raises(TypeError): @@ -210,10 +217,12 @@ def test_build_activation_layer(): assert isinstance(layer, module) # sanity check for Clamp - act = build_activation_layer(dict(type='Clamp')) - x = torch.randn(10) * 1000 - y = act(x) - assert np.logical_and((y >= -1).numpy(), (y <= 1).numpy()).all() + for type_name in ('Clamp', Clamp): + act = build_activation_layer(dict(type='Clamp')) + x = torch.randn(10) * 1000 + y = act(x) + assert np.logical_and((y >= -1).numpy(), (y <= 1).numpy()).all() + act = build_activation_layer(dict(type='Clip', min=0)) y = act(x) assert np.logical_and((y >= 0).numpy(), (y <= 1).numpy()).all() @@ -227,7 +236,7 @@ def test_build_padding_layer(): for module_name in ['padding']: pad_module = import_module(f'mmcv.cnn.bricks.{module_name}') for key, value in pad_module.__dict__.items(): - if isinstance(value, type) and issubclass(value, nn.Module): + if inspect.isclass(value) and issubclass(value, nn.Module): pad_names.append(key) with pytest.raises(TypeError): @@ -250,12 +259,12 @@ def test_build_padding_layer(): cfg['type'] = type_name layer = build_padding_layer(cfg, 2) assert isinstance(layer, module) - - input_x = torch.randn(1, 2, 5, 5) - cfg = dict(type='reflect') - padding_layer = build_padding_layer(cfg, 2) - res = padding_layer(input_x) - assert res.shape == (1, 2, 9, 9) + for type_name in (ReflectionPad2d, 'reflect'): + input_x = torch.randn(1, 2, 5, 5) + cfg = dict(type=type_name) + padding_layer = build_padding_layer(cfg, 2) + res = padding_layer(input_x) + assert res.shape == (1, 2, 9, 9) def test_upsample_layer(): @@ -280,38 +289,48 @@ def test_upsample_layer(): assert isinstance(layer, nn.Upsample) assert layer.mode == type_name + cfg = dict() + cfg['type'] = Upsample + layer_from_cls = build_upsample_layer(cfg) + assert isinstance(layer_from_cls, nn.Upsample) + assert layer_from_cls.mode == 'nearest' + cfg = dict( type='deconv', in_channels=3, out_channels=3, kernel_size=3, stride=2) layer = build_upsample_layer(cfg) assert isinstance(layer, nn.ConvTranspose2d) - cfg = dict(type='deconv') - kwargs = dict(in_channels=3, out_channels=3, kernel_size=3, stride=2) - layer = build_upsample_layer(cfg, **kwargs) - assert isinstance(layer, nn.ConvTranspose2d) - assert layer.in_channels == kwargs['in_channels'] - assert layer.out_channels == kwargs['out_channels'] - assert layer.kernel_size == (kwargs['kernel_size'], kwargs['kernel_size']) - assert layer.stride == (kwargs['stride'], kwargs['stride']) - - layer = build_upsample_layer(cfg, 3, 3, 3, 2) - assert isinstance(layer, nn.ConvTranspose2d) - assert layer.in_channels == kwargs['in_channels'] - assert layer.out_channels == kwargs['out_channels'] - assert layer.kernel_size == (kwargs['kernel_size'], kwargs['kernel_size']) - assert layer.stride == (kwargs['stride'], kwargs['stride']) - - cfg = dict( - type='pixel_shuffle', - in_channels=3, - out_channels=3, - scale_factor=2, - upsample_kernel=3) - layer = build_upsample_layer(cfg) + for type_name in ('deconv', ConvTranspose2d): + cfg = dict(type=ConvTranspose2d) + kwargs = dict(in_channels=3, out_channels=3, kernel_size=3, stride=2) + layer = build_upsample_layer(cfg, **kwargs) + assert isinstance(layer, nn.ConvTranspose2d) + assert layer.in_channels == kwargs['in_channels'] + assert layer.out_channels == kwargs['out_channels'] + assert layer.kernel_size == (kwargs['kernel_size'], + kwargs['kernel_size']) + assert layer.stride == (kwargs['stride'], kwargs['stride']) + + layer = build_upsample_layer(cfg, 3, 3, 3, 2) + assert isinstance(layer, nn.ConvTranspose2d) + assert layer.in_channels == kwargs['in_channels'] + assert layer.out_channels == kwargs['out_channels'] + assert layer.kernel_size == (kwargs['kernel_size'], + kwargs['kernel_size']) + assert layer.stride == (kwargs['stride'], kwargs['stride']) + + for type_name in ('pixel_shuffle', PixelShufflePack): + cfg = dict( + type=type_name, + in_channels=3, + out_channels=3, + scale_factor=2, + upsample_kernel=3) + layer = build_upsample_layer(cfg) - assert isinstance(layer, PixelShufflePack) - assert layer.scale_factor == 2 - assert layer.upsample_kernel == 3 + assert isinstance(layer, PixelShufflePack) + assert layer.scale_factor == 2 + assert layer.upsample_kernel == 3 def test_pixel_shuffle_pack(): @@ -396,35 +415,42 @@ def test_build_plugin_layer(): build_plugin_layer(cfg, postfix=[1, 2]) # test ContextBlock - for postfix in ['', '_test', 1]: - cfg = dict(type='ContextBlock') - name, layer = build_plugin_layer( - cfg, postfix=postfix, in_channels=16, ratio=1. / 4) - assert name == 'context_block' + str(postfix) - assert isinstance(layer, MODELS.module_dict['ContextBlock']) + for type_name in ('ContextBlock', ContextBlock): + for postfix in ['', '_test', 1]: + cfg = dict(type=type_name) + name, layer = build_plugin_layer( + cfg, postfix=postfix, in_channels=16, ratio=1. / 4) + assert name == 'context_block' + str(postfix) + assert isinstance(layer, MODELS.module_dict['ContextBlock']) # test GeneralizedAttention - for postfix in ['', '_test', 1]: - cfg = dict(type='GeneralizedAttention') - name, layer = build_plugin_layer(cfg, postfix=postfix, in_channels=16) - assert name == 'gen_attention_block' + str(postfix) - assert isinstance(layer, MODELS.module_dict['GeneralizedAttention']) + for type_name in ('GeneralizedAttention', GeneralizedAttention): + for postfix in ['', '_test', 1]: + cfg = dict(type=type_name) + name, layer = build_plugin_layer( + cfg, postfix=postfix, in_channels=16) + assert name == 'gen_attention_block' + str(postfix) + assert isinstance(layer, + MODELS.module_dict['GeneralizedAttention']) # test NonLocal2d - for postfix in ['', '_test', 1]: - cfg = dict(type='NonLocal2d') - name, layer = build_plugin_layer(cfg, postfix=postfix, in_channels=16) - assert name == 'nonlocal_block' + str(postfix) - assert isinstance(layer, MODELS.module_dict['NonLocal2d']) + for type_name in ('NonLocal2d', NonLocal2d): + for postfix in ['', '_test', 1]: + cfg = dict(type='NonLocal2d') + name, layer = build_plugin_layer( + cfg, postfix=postfix, in_channels=16) + assert name == 'nonlocal_block' + str(postfix) + assert isinstance(layer, MODELS.module_dict['NonLocal2d']) # test ConvModule for postfix in ['', '_test', 1]: - cfg = dict(type='ConvModule') - name, layer = build_plugin_layer( - cfg, - postfix=postfix, - in_channels=16, - out_channels=4, - kernel_size=3) - assert name == 'conv_block' + str(postfix) - assert isinstance(layer, MODELS.module_dict['ConvModule']) + for type_name in ('ConvModule', ConvModule): + cfg = dict(type=type_name) + name, layer = build_plugin_layer( + cfg, + postfix=postfix, + in_channels=16, + out_channels=4, + kernel_size=3) + assert name == 'conv_block' + str(postfix) + assert isinstance(layer, MODELS.module_dict['ConvModule'])