1
1
import pytest
2
2
import torch
3
3
from mmcv .cnn import ConvModule
4
- from torch import nn
5
4
6
5
from mmseg .models .backbones .unet import (BasicConvBlock , DeconvModule ,
7
6
InterpConv , UNet , UpConvBlock )
7
+ from mmseg .ops import Upsample
8
8
from .utils import check_norm_state
9
9
10
10
@@ -145,7 +145,7 @@ def test_interp_conv():
145
145
block = InterpConv (64 , 32 , conv_first = False )
146
146
x = torch .randn (1 , 64 , 128 , 128 )
147
147
x_out = block (x )
148
- assert isinstance (block .interp_upsample [0 ], nn . Upsample )
148
+ assert isinstance (block .interp_upsample [0 ], Upsample )
149
149
assert isinstance (block .interp_upsample [1 ], ConvModule )
150
150
assert x_out .shape == torch .Size ([1 , 32 , 256 , 256 ])
151
151
@@ -154,7 +154,7 @@ def test_interp_conv():
154
154
x = torch .randn (1 , 64 , 128 , 128 )
155
155
x_out = block (x )
156
156
assert isinstance (block .interp_upsample [0 ], ConvModule )
157
- assert isinstance (block .interp_upsample [1 ], nn . Upsample )
157
+ assert isinstance (block .interp_upsample [1 ], Upsample )
158
158
assert x_out .shape == torch .Size ([1 , 32 , 256 , 256 ])
159
159
160
160
# test InterpConv with bilinear upsample for upsample 2X.
@@ -166,7 +166,7 @@ def test_interp_conv():
166
166
scale_factor = 2 , mode = 'bilinear' , align_corners = False ))
167
167
x = torch .randn (1 , 64 , 128 , 128 )
168
168
x_out = block (x )
169
- assert isinstance (block .interp_upsample [0 ], nn . Upsample )
169
+ assert isinstance (block .interp_upsample [0 ], Upsample )
170
170
assert isinstance (block .interp_upsample [1 ], ConvModule )
171
171
assert x_out .shape == torch .Size ([1 , 32 , 256 , 256 ])
172
172
assert block .interp_upsample [0 ].mode == 'bilinear'
@@ -179,7 +179,7 @@ def test_interp_conv():
179
179
upsample_cfg = dict (scale_factor = 2 , mode = 'nearest' ))
180
180
x = torch .randn (1 , 64 , 128 , 128 )
181
181
x_out = block (x )
182
- assert isinstance (block .interp_upsample [0 ], nn . Upsample )
182
+ assert isinstance (block .interp_upsample [0 ], Upsample )
183
183
assert isinstance (block .interp_upsample [1 ], ConvModule )
184
184
assert x_out .shape == torch .Size ([1 , 32 , 256 , 256 ])
185
185
assert block .interp_upsample [0 ].mode == 'nearest'
0 commit comments