Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Enhancement] Make build_xxx_layer allow accepting a class type #2782

Merged
merged 10 commits into from
May 11, 2023
4 changes: 3 additions & 1 deletion mmcv/cnn/bricks/conv.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
import inspect
from typing import Dict, Optional

from mmengine.registry import MODELS
Expand Down Expand Up @@ -35,7 +36,8 @@ def build_conv_layer(cfg: Optional[Dict], *args, **kwargs) -> nn.Module:
cfg_ = cfg.copy()

layer_type = cfg_.pop('type')

if inspect.isclass(layer_type):
return layer_type(*args, **kwargs, **cfg_) # type: ignore
# Switch registry to the target scope. If `conv_layer` cannot be found
# in the registry, fallback to search `conv_layer` in the
# mmengine.MODELS.
Expand Down
21 changes: 12 additions & 9 deletions mmcv/cnn/bricks/norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,22 +98,25 @@ def build_norm_layer(cfg: Dict,

layer_type = cfg_.pop('type')

# Switch registry to the target scope. If `norm_layer` cannot be found
# in the registry, fallback to search `norm_layer` in the
# mmengine.MODELS.
with MODELS.switch_scope_and_registry(None) as registry:
norm_layer = registry.get(layer_type)
if norm_layer is None:
raise KeyError(f'Cannot find {norm_layer} in registry under scope '
f'name {registry.scope}')
if inspect.isclass(layer_type):
norm_layer = layer_type
else:
# Switch registry to the target scope. If `norm_layer` cannot be found
# in the registry, fallback to search `norm_layer` in the
# mmengine.MODELS.
with MODELS.switch_scope_and_registry(None) as registry:
norm_layer = registry.get(layer_type)
if norm_layer is None:
raise KeyError(f'Cannot find {norm_layer} in registry under '
f'scope name {registry.scope}')
abbr = infer_abbr(norm_layer)

assert isinstance(postfix, (int, str))
name = abbr + str(postfix)

requires_grad = cfg_.pop('requires_grad', True)
cfg_.setdefault('eps', 1e-5)
if layer_type != 'GN':
if norm_layer is not nn.GroupNorm:
layer = norm_layer(num_features, **cfg_)
if layer_type == 'SyncBN' and hasattr(layer, '_specify_ddp_gpu_num'):
layer._specify_ddp_gpu_num(1)
Expand Down
4 changes: 3 additions & 1 deletion mmcv/cnn/bricks/padding.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
import inspect
from typing import Dict

import torch.nn as nn
Expand Down Expand Up @@ -27,7 +28,8 @@ def build_padding_layer(cfg: Dict, *args, **kwargs) -> nn.Module:

cfg_ = cfg.copy()
padding_type = cfg_.pop('type')

if inspect.isclass(padding_type):
return padding_type(*args, **kwargs, **cfg_)
# Switch registry to the target scope. If `padding_layer` cannot be found
# in the registry, fallback to search `padding_layer` in the
# mmengine.MODELS.
Expand Down
21 changes: 12 additions & 9 deletions mmcv/cnn/bricks/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,15 +79,18 @@ def build_plugin_layer(cfg: Dict,
cfg_ = cfg.copy()

layer_type = cfg_.pop('type')

# Switch registry to the target scope. If `plugin_layer` cannot be found
# in the registry, fallback to search `plugin_layer` in the
# mmengine.MODELS.
with MODELS.switch_scope_and_registry(None) as registry:
plugin_layer = registry.get(layer_type)
if plugin_layer is None:
raise KeyError(f'Cannot find {plugin_layer} in registry under scope '
f'name {registry.scope}')
if inspect.isclass(layer_type):
plugin_layer = layer_type
else:
# Switch registry to the target scope. If `plugin_layer` cannot be
# found in the registry, fallback to search `plugin_layer` in the
# mmengine.MODELS.
with MODELS.switch_scope_and_registry(None) as registry:
plugin_layer = registry.get(layer_type)
if plugin_layer is None:
raise KeyError(
f'Cannot find {plugin_layer} in registry under scope '
f'name {registry.scope}')
abbr = infer_abbr(plugin_layer)

assert isinstance(postfix, (int, str))
Expand Down
18 changes: 11 additions & 7 deletions mmcv/cnn/bricks/upsample.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
import inspect
from typing import Dict

import torch
Expand Down Expand Up @@ -76,15 +77,18 @@ def build_upsample_layer(cfg: Dict, *args, **kwargs) -> nn.Module:

layer_type = cfg_.pop('type')

if inspect.isclass(layer_type):
upsample = layer_type
# Switch registry to the target scope. If `upsample` cannot be found
# in the registry, fallback to search `upsample` in the
# mmengine.MODELS.
with MODELS.switch_scope_and_registry(None) as registry:
upsample = registry.get(layer_type)
if upsample is None:
raise KeyError(f'Cannot find {upsample} in registry under scope '
f'name {registry.scope}')
if upsample is nn.Upsample:
cfg_['mode'] = layer_type
else:
with MODELS.switch_scope_and_registry(None) as registry:
upsample = registry.get(layer_type)
if upsample is None:
raise KeyError(f'Cannot find {upsample} in registry under scope '
f'name {registry.scope}')
if upsample is nn.Upsample:
cfg_['mode'] = layer_type
layer = upsample(*args, **kwargs, **cfg_)
return layer
5 changes: 3 additions & 2 deletions mmcv/ops/nms.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,8 +293,9 @@ def batched_nms(boxes: Tensor,
max_coordinate + torch.tensor(1).to(boxes))
boxes_for_nms = boxes + offsets[:, None]

nms_type = nms_cfg_.pop('type', 'nms')
nms_op = eval(nms_type)
nms_op = nms_cfg_.pop('type', 'nms')
if isinstance(nms_op, str):
nms_op = eval(nms_op)

split_thr = nms_cfg_.pop('split_thr', 10000)
# Won't split to multiple nms nodes when exporting to onnx
Expand Down
Loading