-
Notifications
You must be signed in to change notification settings - Fork 2.6k
/
up_conv_block.py
102 lines (93 loc) · 3.92 KB
/
up_conv_block.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule, build_upsample_layer
class UpConvBlock(nn.Module):
    """Upsample convolution block in decoder for UNet.

    This upsample convolution block consists of one upsample module
    followed by one convolution block. The upsample module expands the
    high-level low-resolution feature map and the convolution block fuses
    the upsampled high-level low-resolution feature map and the low-level
    high-resolution feature map from encoder.

    Args:
        conv_block (type | callable): Constructor of the convolution block
            (e.g. an ``nn.Module`` subclass), not an already-built module.
            It is called with the keyword arguments ``in_channels``,
            ``out_channels``, ``num_convs``, ``stride``, ``dilation``,
            ``with_cp``, ``conv_cfg``, ``norm_cfg``, ``act_cfg``, ``dcn``
            and ``plugins``; the returned module is applied to the
            concatenated features in ``forward``.
        in_channels (int): Number of input channels of the high-level
            low-resolution feature map (the ``x`` argument of ``forward``).
        skip_channels (int): Number of input channels of the low-level
            high-resolution feature map from encoder. The upsample module
            projects ``x`` to this many channels, so ``conv_block`` is
            built with ``2 * skip_channels`` input channels.
        out_channels (int): Number of output channels.
        num_convs (int): Number of convolutional layers in the conv_block.
            Default: 2.
        stride (int): Stride of convolutional layer in conv_block. Default: 1.
        dilation (int): Dilation rate of convolutional layer in conv_block.
            Default: 1.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Forwarded to
            ``conv_block`` and to the configured upsample layer (not to the
            1x1 fallback used when ``upsample_cfg`` is None).
            Default: False.
        conv_cfg (dict | None): Config dict for convolution layer. Forwarded
            to ``conv_block`` and to the 1x1 fallback ``ConvModule``; it is
            NOT passed to the layer built from ``upsample_cfg``.
            Default: None.
        norm_cfg (dict | None): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict | None): Config dict for activation layer in ConvModule.
            Default: dict(type='ReLU').
        upsample_cfg (dict | None): The upsample config of the upsample
            module in decoder. Default: dict(type='InterpConv'). If the size
            of high-level feature map is already the same as that of the
            skip feature map (low-level feature map from encoder), pass None:
            no upsampling is done and a 1x1 ``ConvModule`` (stride 1, no
            padding) only projects ``in_channels`` to ``skip_channels``.
        dcn (dict | None): Deformable convolution config. Must be None;
            anything else raises ``AssertionError`` (not implemented yet).
            Default: None.
        plugins (dict | None): Plugins for convolutional layers. Must be
            None; anything else raises ``AssertionError`` (not implemented
            yet). Default: None.
    """

    # NOTE: the default config dicts below are evaluated once and shared by
    # every instance. This class only forwards them and never mutates them;
    # the callees are presumed to copy before modifying (verify if changed).
    def __init__(self,
                 conv_block,
                 in_channels,
                 skip_channels,
                 out_channels,
                 num_convs=2,
                 stride=1,
                 dilation=1,
                 with_cp=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 upsample_cfg=dict(type='InterpConv'),
                 dcn=None,
                 plugins=None):
        super().__init__()
        assert dcn is None, 'Not implemented yet.'
        assert plugins is None, 'Not implemented yet.'

        # The conv block consumes [skip, upsampled x] concatenated; both
        # halves carry skip_channels channels, hence 2 * skip_channels.
        self.conv_block = conv_block(
            in_channels=2 * skip_channels,
            out_channels=out_channels,
            num_convs=num_convs,
            stride=stride,
            dilation=dilation,
            with_cp=with_cp,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg,
            dcn=None,
            plugins=None)

        if upsample_cfg is not None:
            # Upsample x and project it from in_channels to skip_channels.
            self.upsample = build_upsample_layer(
                cfg=upsample_cfg,
                in_channels=in_channels,
                out_channels=skip_channels,
                with_cp=with_cp,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg)
        else:
            # No upsampling requested: a 1x1 conv (stride 1, no padding)
            # keeps the spatial size and only matches the channel count.
            self.upsample = ConvModule(
                in_channels,
                skip_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg)

    def forward(self, skip, x):
        """Upsample ``x``, concatenate it with ``skip`` and fuse them.

        Args:
            skip (Tensor): Low-level high-resolution feature map from the
                encoder, expected to have ``skip_channels`` channels.
            x (Tensor): High-level low-resolution feature map with
                ``in_channels`` channels.

        Returns:
            Tensor: Output of ``conv_block`` applied to ``[skip, upsample(x)]``
            concatenated along dim 1.
        """
        x = self.upsample(x)
        # dim 1 is presumed to be the channel axis (NCHW), matching the
        # 2 * skip_channels arithmetic in __init__. torch.cat requires every
        # other dimension of skip and the upsampled x to match.
        out = torch.cat([skip, x], dim=1)
        out = self.conv_block(out)
        return out