@@ -25,20 +25,20 @@ def test_bisenetv1_backbone():
     model.init_weights()
     model.train()
     batch_size = 2
-    imgs = torch.randn(batch_size, 3, 256, 512)
+    imgs = torch.randn(batch_size, 3, 64, 128)
     feat = model(imgs)

     assert len(feat) == 3
     # output for segment Head
-    assert feat[0].shape == torch.Size([batch_size, 256, 32, 64])
+    assert feat[0].shape == torch.Size([batch_size, 256, 8, 16])
     # for auxiliary head 1
-    assert feat[1].shape == torch.Size([batch_size, 128, 32, 64])
+    assert feat[1].shape == torch.Size([batch_size, 128, 8, 16])
     # for auxiliary head 2
-    assert feat[2].shape == torch.Size([batch_size, 128, 16, 32])
+    assert feat[2].shape == torch.Size([batch_size, 128, 4, 8])

     # Test input with rare shape
     batch_size = 2
-    imgs = torch.randn(batch_size, 3, 527, 279)
+    imgs = torch.randn(batch_size, 3, 95, 27)
     feat = model(imgs)
     assert len(feat) == 3

@@ -47,20 +47,20 @@ def test_bisenetv1_backbone():
         BiSeNetV1(
             backbone_cfg=backbone_cfg,
             in_channels=3,
-            spatial_channels=(64, 64, 64))
+            spatial_channels=(16, 16, 16))

     with pytest.raises(AssertionError):
         # BiSeNetV1 context path constraints.
         BiSeNetV1(
             backbone_cfg=backbone_cfg,
             in_channels=3,
-            context_channels=(128, 256, 512, 1024))
+            context_channels=(16, 32, 64, 128))


 def test_bisenetv1_spatial_path():
     with pytest.raises(AssertionError):
         # BiSeNetV1 spatial path channel constraints.
-        SpatialPath(num_channels=(64, 64, 64), in_channels=3)
+        SpatialPath(num_channels=(16, 16, 16), in_channels=3)


 def test_bisenetv1_context_path():
@@ -79,31 +79,31 @@ def test_bisenetv1_context_path():
     with pytest.raises(AssertionError):
         # BiSeNetV1 context path constraints.
         ContextPath(
-            backbone_cfg=backbone_cfg, context_channels=(128, 256, 512, 1024))
+            backbone_cfg=backbone_cfg, context_channels=(16, 32, 64, 128))


 def test_bisenetv1_attention_refinement_module():
-    x_arm = AttentionRefinementModule(256, 64)
-    assert x_arm.conv_layer.in_channels == 256
-    assert x_arm.conv_layer.out_channels == 64
+    x_arm = AttentionRefinementModule(32, 8)
+    assert x_arm.conv_layer.in_channels == 32
+    assert x_arm.conv_layer.out_channels == 8
     assert x_arm.conv_layer.kernel_size == (3, 3)
-    x = torch.randn(2, 256, 32, 64)
+    x = torch.randn(2, 32, 8, 16)
     x_out = x_arm(x)
-    assert x_out.shape == torch.Size([2, 64, 32, 64])
+    assert x_out.shape == torch.Size([2, 8, 8, 16])


 def test_bisenetv1_feature_fusion_module():
-    ffm = FeatureFusionModule(128, 256)
-    assert ffm.conv1.in_channels == 128
-    assert ffm.conv1.out_channels == 256
+    ffm = FeatureFusionModule(16, 32)
+    assert ffm.conv1.in_channels == 16
+    assert ffm.conv1.out_channels == 32
     assert ffm.conv1.kernel_size == (1, 1)
     assert ffm.gap.output_size == (1, 1)
-    assert ffm.conv_atten[0].in_channels == 256
-    assert ffm.conv_atten[0].out_channels == 256
+    assert ffm.conv_atten[0].in_channels == 32
+    assert ffm.conv_atten[0].out_channels == 32
     assert ffm.conv_atten[0].kernel_size == (1, 1)

-    ffm = FeatureFusionModule(128, 128)
-    x1 = torch.randn(2, 64, 64, 128)
-    x2 = torch.randn(2, 64, 64, 128)
+    ffm = FeatureFusionModule(16, 16)
+    x1 = torch.randn(2, 8, 8, 16)
+    x2 = torch.randn(2, 8, 8, 16)
     x_out = ffm(x1, x2)
-    assert x_out.shape == torch.Size([2, 128, 64, 128])
+    assert x_out.shape == torch.Size([2, 16, 8, 16])
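Note (not part of the diff): the new assertion shapes follow directly from shrinking the test input, given the output strides implied by both the old and the new assertions (1/8 of the input for feat[0] and feat[1], 1/16 for feat[2]). A minimal sketch of that arithmetic, using a hypothetical helper for illustration only:

# Hypothetical helper, only to show why a (64, 128) input yields the
# (8, 16) and (4, 8) feature maps asserted above.
def expected_hw(height, width, stride):
    return height // stride, width // stride

assert expected_hw(64, 128, 8) == (8, 16)     # feat[0] and feat[1]
assert expected_hw(64, 128, 16) == (4, 8)     # feat[2]
assert expected_hw(256, 512, 8) == (32, 64)   # the old, larger test input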