-
Notifications
You must be signed in to change notification settings - Fork 1.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Enhancement] Add build_func for UPSAMPLE_LAYERS #1272
Changes from all commits
e006629
8bc3d6f
3a2f218
2211ec2
8326488
4f4e31a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,9 +1,11 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
import warnings | ||
|
||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
|
||
from ..utils import xavier_init | ||
from .registry import UPSAMPLE_LAYERS | ||
from .registry import UPSAMPLE_LAYERS, build_upsample_layer_from_cfg | ||
|
||
UPSAMPLE_LAYERS.register_module('nearest', module=nn.Upsample) | ||
UPSAMPLE_LAYERS.register_module('bilinear', module=nn.Upsample) | ||
|
@@ -47,6 +49,8 @@ def forward(self, x): | |
return x | ||
|
||
|
||
# Avoid BC-breaking of importing build_upsample_layer. | ||
# It's backwards compatible with the old `build_upsample_layer`. | ||
zhouzaida marked this conversation as resolved.
Show resolved
Hide resolved
|
||
def build_upsample_layer(cfg, *args, **kwargs): | ||
"""Build upsample layer. | ||
|
||
|
@@ -65,20 +69,11 @@ def build_upsample_layer(cfg, *args, **kwargs): | |
Returns: | ||
nn.Module: Created upsample layer. | ||
""" | ||
if not isinstance(cfg, dict): | ||
raise TypeError(f'cfg must be a dict, but got {type(cfg)}') | ||
if 'type' not in cfg: | ||
raise KeyError( | ||
f'the cfg dict must contain the key "type", but got {cfg}') | ||
cfg_ = cfg.copy() | ||
|
||
layer_type = cfg_.pop('type') | ||
if layer_type not in UPSAMPLE_LAYERS: | ||
raise KeyError(f'Unrecognized upsample type {layer_type}') | ||
else: | ||
upsample = UPSAMPLE_LAYERS.get(layer_type) | ||
|
||
if upsample is nn.Upsample: | ||
cfg_['mode'] = layer_type | ||
layer = upsample(*args, **kwargs, **cfg_) | ||
return layer | ||
|
||
warnings.warn( | ||
ImportWarning( | ||
'``build_upsample_layer(cfg, *args, **kwargs)`` will be ' | ||
'deprecated. Please use ' | ||
'``UPSAMPLE_LAYERS.build(cfg, *args, **kwargs)`` instead.')) | ||
|
||
return build_upsample_layer_from_cfg(cfg, UPSAMPLE_LAYERS, *args, **kwargs) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should use UPSAMPLE_LAYERS.buil() here? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Will lead to recursive import. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Then please leave a comment above this line of code. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add docstring to indicates what is the difference between build_upsample_layer_from_cfg and build_from_cfg
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since the interfaces of all UPSAMPLE_LAYERs modules are different.
I have added usage and deprecate information.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I just wrap
build_upsample_layer
as a build_func (build_upsample_layer_from_cfg
) andbuild_upsample_layer
can be deprecated in the future.