Skip to content

Commit 50461ef

Browse files
authoredJul 28, 2021
[Fix] Replace interpolate with resize (open-mmlab#731)
* Replace interpolate with resize * Replace nn.Upsample with ops.Upsample * Fix test
1 parent b5ae7a7 commit 50461ef

File tree

11 files changed

+27
-24
lines changed

11 files changed

+27
-24
lines changed
 

‎mmseg/models/backbones/swin.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from torch.nn.modules.normalization import LayerNorm
1414
from torch.nn.modules.utils import _pair as to_2tuple
1515

16+
from mmseg.ops import resize
1617
from ...utils import get_root_logger
1718
from ..builder import ATTENTION, BACKBONES
1819
from ..utils import PatchEmbed, swin_convert
@@ -745,7 +746,7 @@ def init_weights(self):
745746
if L1 != L2:
746747
S1 = int(L1**0.5)
747748
S2 = int(L2**0.5)
748-
table_pretrained_resized = F.interpolate(
749+
table_pretrained_resized = resize(
749750
table_pretrained.permute(1, 0).reshape(
750751
1, nH1, S1, S1),
751752
size=(S2, S2),

‎mmseg/models/backbones/unet.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from mmcv.runner import BaseModule
88
from mmcv.utils.parrots_wrapper import _BatchNorm
99

10+
from mmseg.ops import Upsample
1011
from ..builder import BACKBONES
1112
from ..utils import UpConvBlock
1213

@@ -203,7 +204,7 @@ def __init__(self,
203204
conv_cfg=conv_cfg,
204205
norm_cfg=norm_cfg,
205206
act_cfg=act_cfg)
206-
upsample = nn.Upsample(**upsample_cfg)
207+
upsample = Upsample(**upsample_cfg)
207208
if conv_first:
208209
self.interp_upsample = nn.Sequential(conv, upsample)
209210
else:

‎mmseg/models/backbones/vit.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,14 @@
33

44
import torch
55
import torch.nn as nn
6-
import torch.nn.functional as F
76
from mmcv.cnn import (build_norm_layer, constant_init, kaiming_init,
87
normal_init, trunc_normal_init)
98
from mmcv.cnn.bricks.transformer import FFN, MultiheadAttention
109
from mmcv.runner import BaseModule, ModuleList, _load_checkpoint
1110
from torch.nn.modules.batchnorm import _BatchNorm
1211
from torch.nn.modules.utils import _pair as to_2tuple
1312

13+
from mmseg.ops import resize
1414
from mmseg.utils import get_root_logger
1515
from ..builder import BACKBONES
1616
from ..utils import PatchEmbed, vit_convert
@@ -373,7 +373,7 @@ def resize_pos_embed(pos_embed, input_shpae, pos_shape, mode):
373373
pos_embed_weight = pos_embed[:, (-1 * pos_h * pos_w):]
374374
pos_embed_weight = pos_embed_weight.reshape(
375375
1, pos_h, pos_w, pos_embed.shape[2]).permute(0, 3, 1, 2)
376-
pos_embed_weight = F.interpolate(
376+
pos_embed_weight = resize(
377377
pos_embed_weight, size=input_shpae, align_corners=False, mode=mode)
378378
cls_token_weight = cls_token_weight.unsqueeze(1)
379379
pos_embed_weight = torch.flatten(pos_embed_weight, 2).transpose(1, 2)

‎mmseg/models/decode_heads/fpn_head.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import torch.nn as nn
33
from mmcv.cnn import ConvModule
44

5-
from mmseg.ops import resize
5+
from mmseg.ops import Upsample, resize
66
from ..builder import HEADS
77
from .decode_head import BaseDecodeHead
88

@@ -45,7 +45,7 @@ def __init__(self, feature_strides, **kwargs):
4545
act_cfg=self.act_cfg))
4646
if feature_strides[i] != feature_strides[0]:
4747
scale_head.append(
48-
nn.Upsample(
48+
Upsample(
4949
scale_factor=2,
5050
mode='bilinear',
5151
align_corners=self.align_corners))

‎mmseg/models/decode_heads/setr_mla_head.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import torch.nn as nn
33
from mmcv.cnn import ConvModule
44

5+
from mmseg.ops import Upsample
56
from ..builder import HEADS
67
from .decode_head import BaseDecodeHead
78

@@ -46,7 +47,7 @@ def __init__(self, mla_channels=128, up_scale=4, **kwargs):
4647
padding=1,
4748
norm_cfg=self.norm_cfg,
4849
act_cfg=self.act_cfg),
49-
nn.Upsample(
50+
Upsample(
5051
scale_factor=up_scale,
5152
mode='bilinear',
5253
align_corners=self.align_corners)))

‎mmseg/models/decode_heads/setr_up_head.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import torch.nn as nn
22
from mmcv.cnn import ConvModule, build_norm_layer
33

4+
from mmseg.ops import Upsample
45
from ..builder import HEADS
56
from .decode_head import BaseDecodeHead
67

@@ -59,7 +60,7 @@ def __init__(self,
5960
padding=int(kernel_size - 1) // 2,
6061
norm_cfg=self.norm_cfg,
6162
act_cfg=self.act_cfg),
62-
nn.Upsample(
63+
Upsample(
6364
scale_factor=up_scale,
6465
mode='bilinear',
6566
align_corners=self.align_corners)))

‎mmseg/models/necks/fpn.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from mmcv.cnn import ConvModule
44
from mmcv.runner import BaseModule, auto_fp16
55

6+
from mmseg.ops import resize
67
from ..builder import NECKS
78

89

@@ -173,11 +174,10 @@ def forward(self, inputs):
173174
# In some cases, fixing `scale factor` (e.g. 2) is preferred, but
174175
# it cannot co-exist with `size` in `F.interpolate`.
175176
if 'scale_factor' in self.upsample_cfg:
176-
laterals[i - 1] += F.interpolate(laterals[i],
177-
**self.upsample_cfg)
177+
laterals[i - 1] += resize(laterals[i], **self.upsample_cfg)
178178
else:
179179
prev_shape = laterals[i - 1].shape[2:]
180-
laterals[i - 1] += F.interpolate(
180+
laterals[i - 1] += resize(
181181
laterals[i], size=prev_shape, **self.upsample_cfg)
182182

183183
# build outputs

‎mmseg/models/necks/multilevel_neck.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import torch.nn as nn
2-
import torch.nn.functional as F
32
from mmcv.cnn import ConvModule, xavier_init
43

4+
from mmseg.ops import resize
55
from ..builder import NECKS
66

77

@@ -70,7 +70,7 @@ def forward(self, inputs):
7070
inputs = [inputs[0] for _ in range(self.num_outs)]
7171
outs = []
7272
for i in range(self.num_outs):
73-
x_resize = F.interpolate(
73+
x_resize = resize(
7474
inputs[i], scale_factor=self.scales[i], mode='bilinear')
7575
outs.append(self.convs[i](x_resize))
7676
return tuple(outs)

‎tests/test_models/test_backbones/test_unet.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
import pytest
22
import torch
33
from mmcv.cnn import ConvModule
4-
from torch import nn
54

65
from mmseg.models.backbones.unet import (BasicConvBlock, DeconvModule,
76
InterpConv, UNet, UpConvBlock)
7+
from mmseg.ops import Upsample
88
from .utils import check_norm_state
99

1010

@@ -145,7 +145,7 @@ def test_interp_conv():
145145
block = InterpConv(64, 32, conv_first=False)
146146
x = torch.randn(1, 64, 128, 128)
147147
x_out = block(x)
148-
assert isinstance(block.interp_upsample[0], nn.Upsample)
148+
assert isinstance(block.interp_upsample[0], Upsample)
149149
assert isinstance(block.interp_upsample[1], ConvModule)
150150
assert x_out.shape == torch.Size([1, 32, 256, 256])
151151

@@ -154,7 +154,7 @@ def test_interp_conv():
154154
x = torch.randn(1, 64, 128, 128)
155155
x_out = block(x)
156156
assert isinstance(block.interp_upsample[0], ConvModule)
157-
assert isinstance(block.interp_upsample[1], nn.Upsample)
157+
assert isinstance(block.interp_upsample[1], Upsample)
158158
assert x_out.shape == torch.Size([1, 32, 256, 256])
159159

160160
# test InterpConv with bilinear upsample for upsample 2X.
@@ -166,7 +166,7 @@ def test_interp_conv():
166166
scale_factor=2, mode='bilinear', align_corners=False))
167167
x = torch.randn(1, 64, 128, 128)
168168
x_out = block(x)
169-
assert isinstance(block.interp_upsample[0], nn.Upsample)
169+
assert isinstance(block.interp_upsample[0], Upsample)
170170
assert isinstance(block.interp_upsample[1], ConvModule)
171171
assert x_out.shape == torch.Size([1, 32, 256, 256])
172172
assert block.interp_upsample[0].mode == 'bilinear'
@@ -179,7 +179,7 @@ def test_interp_conv():
179179
upsample_cfg=dict(scale_factor=2, mode='nearest'))
180180
x = torch.randn(1, 64, 128, 128)
181181
x_out = block(x)
182-
assert isinstance(block.interp_upsample[0], nn.Upsample)
182+
assert isinstance(block.interp_upsample[0], Upsample)
183183
assert isinstance(block.interp_upsample[1], ConvModule)
184184
assert x_out.shape == torch.Size([1, 32, 256, 256])
185185
assert block.interp_upsample[0].mode == 'nearest'

‎tools/deploy_test.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from mmseg.apis import single_gpu_test
1515
from mmseg.datasets import build_dataloader, build_dataset
1616
from mmseg.models.segmentors.base import BaseSegmentor
17+
from mmseg.ops import resize
1718

1819

1920
class ONNXRuntimeSegmentor(BaseSegmentor):
@@ -79,7 +80,7 @@ def simple_test(self, img: torch.Tensor, img_meta: Iterable,
7980
if not (ori_shape[0] == seg_pred.shape[-2]
8081
and ori_shape[1] == seg_pred.shape[-1]):
8182
seg_pred = torch.from_numpy(seg_pred).float()
82-
seg_pred = torch.nn.functional.interpolate(
83+
seg_pred = resize(
8384
seg_pred, size=tuple(ori_shape[:2]), mode='nearest')
8485
seg_pred = seg_pred.long().detach().cpu().numpy()
8586
seg_pred = seg_pred[0]
@@ -127,7 +128,7 @@ def simple_test(self, img: torch.Tensor, img_meta: Iterable,
127128
if not (ori_shape[0] == seg_pred.shape[-2]
128129
and ori_shape[1] == seg_pred.shape[-1]):
129130
seg_pred = torch.from_numpy(seg_pred).float()
130-
seg_pred = torch.nn.functional.interpolate(
131+
seg_pred = resize(
131132
seg_pred, size=tuple(ori_shape[:2]), mode='nearest')
132133
seg_pred = seg_pred.long().detach().cpu().numpy()
133134
seg_pred = seg_pred[0]

‎tools/pytorch2onnx.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from mmseg.apis.inference import LoadImage
1717
from mmseg.datasets.pipelines import Compose
1818
from mmseg.models import build_segmentor
19+
from mmseg.ops import resize
1920

2021
torch.manual_seed(3)
2122

@@ -210,10 +211,7 @@ def pytorch2onnx(model,
210211

211212
if dynamic_export and test_mode == 'whole':
212213
# scale image for dynamic shape test
213-
img_list = [
214-
nn.functional.interpolate(_, scale_factor=1.5)
215-
for _ in img_list
216-
]
214+
img_list = [resize(_, scale_factor=1.5) for _ in img_list]
217215
# concate flip image for batch test
218216
flip_img_list = [_.flip(-1) for _ in img_list]
219217
img_list = [

0 commit comments

Comments
 (0)
Please sign in to comment.