[Feature] Support kernel updation for some decoder heads. #1299

Merged · 12 commits · Feb 27, 2022
17 changes: 13 additions & 4 deletions mmseg/models/decode_heads/aspp_head.py
@@ -91,8 +91,9 @@ def __init__(self, dilations=(1, 6, 12, 18), **kwargs):
             norm_cfg=self.norm_cfg,
             act_cfg=self.act_cfg)

-    def forward(self, inputs):
-        """Forward function."""
+    def forward_feature(self, inputs):
+        """Output both the feature map before `self.cls_seg` and the
+        learnable semantic kernels for kernel update."""
         x = self._transform_inputs(inputs)
         aspp_outs = [
             resize(
@@ -103,6 +104,14 @@ def forward(self, inputs):
         ]
         aspp_outs.extend(self.aspp_modules(x))
         aspp_outs = torch.cat(aspp_outs, dim=1)
-        output = self.bottleneck(aspp_outs)
-        output = self.cls_seg(output)
+        feats = self.bottleneck(aspp_outs)
+        output = self.cls_seg(feats)
+        seg_kernels = self.conv_seg.weight.clone()
+        seg_kernels = seg_kernels[None].expand(
+            feats.size(0), *seg_kernels.size())
+        return output, feats, seg_kernels
+
+    def forward(self, inputs):
+        """Forward function."""
+        output, feats, seg_kernels = self.forward_feature(inputs)
         return output
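
As a sanity check on the new contract (not part of the diff): `seg_kernels` is a per-image copy of `conv_seg.weight`, so re-applying it to `feats` with the classifier bias reproduces the logits. A minimal sketch, with `relogits` as an illustrative name; `eval()` matters because `cls_seg` applies dropout during training:

import torch
import torch.nn.functional as F

from mmseg.models.decode_heads import ASPPHead

head = ASPPHead(
    in_channels=8, channels=4, num_classes=19,
    dilations=(1, 12, 24)).eval()  # disable the dropout inside cls_seg
inputs = [torch.randn(1, 8, 45, 45)]

with torch.no_grad():
    output, feats, seg_kernels = head.forward_feature(inputs)
    # seg_kernels[0] is a copy of conv_seg.weight, so a 1x1 convolution
    # over feats (plus the bias) rebuilds the logits exactly.
    relogits = F.conv2d(feats, seg_kernels[0], bias=head.conv_seg.bias)

assert torch.allclose(output, relogits, atol=1e-5)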
20 changes: 15 additions & 5 deletions mmseg/models/decode_heads/fcn_head.py
@@ -72,11 +72,21 @@ def __init__(self,
                 norm_cfg=self.norm_cfg,
                 act_cfg=self.act_cfg)

-    def forward(self, inputs):
-        """Forward function."""
+    def forward_feature(self, inputs):
+        """Output both the feature map before `self.cls_seg` and the
+        learnable semantic kernels for kernel update."""
         x = self._transform_inputs(inputs)
-        output = self.convs(x)
+        feats = self.convs(x)
         if self.concat_input:
-            output = self.conv_cat(torch.cat([x, output], dim=1))
-        output = self.cls_seg(output)
+            feats = self.conv_cat(torch.cat([x, feats], dim=1))
+        output = self.cls_seg(feats)
+
+        seg_kernels = self.conv_seg.weight.clone()
+        seg_kernels = seg_kernels[None].expand(
+            feats.size(0), *seg_kernels.size())
+        return output, feats, seg_kernels
+
+    def forward(self, inputs):
+        """Forward function."""
+        output, feats, seg_kernels = self.forward_feature(inputs)
         return output
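
The `[None].expand(...)` step makes the kernels per-image: every sample in a batch starts from the same classifier weights, which a downstream kernel-update module could then specialize independently. A small illustrative sketch of that property (not from the PR):

import torch

from mmseg.models.decode_heads import FCNHead

head = FCNHead(in_channels=8, channels=4, num_classes=19)
inputs = [torch.randn(2, 8, 23, 23)]  # batch of two images

output, feats, seg_kernels = head.forward_feature(inputs)

# One kernel set per image, all views of the same conv_seg.weight copy.
assert seg_kernels.shape == (2, 19, 4, 1, 1)
assert torch.equal(seg_kernels[0], seg_kernels[1])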
18 changes: 14 additions & 4 deletions mmseg/models/decode_heads/psp_head.py
@@ -92,12 +92,22 @@ def __init__(self, pool_scales=(1, 2, 3, 6), **kwargs):
             norm_cfg=self.norm_cfg,
             act_cfg=self.act_cfg)

-    def forward(self, inputs):
-        """Forward function."""
+    def forward_feature(self, inputs):
+        """Output both the feature map before `self.cls_seg` and the
+        learnable semantic kernels for kernel update."""
         x = self._transform_inputs(inputs)
         psp_outs = [x]
         psp_outs.extend(self.psp_modules(x))
         psp_outs = torch.cat(psp_outs, dim=1)
-        output = self.bottleneck(psp_outs)
-        output = self.cls_seg(output)
+        feats = self.bottleneck(psp_outs)
+        output = self.cls_seg(feats)
+
+        seg_kernels = self.conv_seg.weight.clone()
+        seg_kernels = seg_kernels[None].expand(
+            feats.size(0), *seg_kernels.size())
+        return output, feats, seg_kernels
+
+    def forward(self, inputs):
+        """Forward function."""
+        output, feats, seg_kernels = self.forward_feature(inputs)
         return output
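
For a picture of what might consume this triple, a K-Net-style update uses the current kernels to predict soft masks, pools mask-weighted group features, and nudges the kernels toward them. `naive_kernel_update` below is a hypothetical sketch of that idea, not mmseg code:

import torch
import torch.nn.functional as F


def naive_kernel_update(feats, seg_kernels):
    # feats: (N, C, H, W); seg_kernels: (N, K, C, 1, 1)
    updated = []
    for i in range(feats.size(0)):
        # Predict soft class masks with the current kernels.
        logits = F.conv2d(feats[i:i + 1], seg_kernels[i])  # (1, K, H, W)
        masks = logits.softmax(dim=1)
        # Mask-weighted average feature per class: (K, C).
        group = torch.einsum('nkhw,nchw->kc', masks, feats[i:i + 1])
        group = group / masks.sum(dim=(0, 2, 3)).clamp(min=1e-6)[:, None]
        # Nudge each kernel toward its group feature.
        updated.append(seg_kernels[i] + group[..., None, None])
    return torch.stack(updated)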
20 changes: 15 additions & 5 deletions mmseg/models/decode_heads/uper_head.py
@@ -84,8 +84,9 @@ def psp_forward(self, inputs):

         return output

-    def forward(self, inputs):
-        """Forward function."""
+    def forward_feature(self, inputs):
+        """Output both the feature map before `self.cls_seg` and the
+        learnable semantic kernels for kernel update."""

         inputs = self._transform_inputs(inputs)

@@ -101,7 +102,7 @@ def forward(self, inputs):
         used_backbone_levels = len(laterals)
         for i in range(used_backbone_levels - 1, 0, -1):
             prev_shape = laterals[i - 1].shape[2:]
-            laterals[i - 1] = laterals[i - 1] + resize(
+            laterals[i - 1] += resize(
                 laterals[i],
                 size=prev_shape,
                 mode='bilinear',
@@ -122,6 +123,15 @@ def forward(self, inputs):
                 mode='bilinear',
                 align_corners=self.align_corners)
         fpn_outs = torch.cat(fpn_outs, dim=1)
-        output = self.fpn_bottleneck(fpn_outs)
-        output = self.cls_seg(output)
+        feats = self.fpn_bottleneck(fpn_outs)
+        output = self.cls_seg(feats)
+
+        seg_kernels = self.conv_seg.weight.clone()
+        seg_kernels = seg_kernels[None].expand(
+            feats.size(0), *seg_kernels.size())
+        return output, feats, seg_kernels
+
+    def forward(self, inputs):
+        """Forward function."""
+        output, feats, seg_kernels = self.forward_feature(inputs)
         return output
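
Note the public `forward` stays logits-only, so existing pipelines are unaffected; only callers that opt into `forward_feature` see the richer output. A sketch, with shapes borrowed from the new test below:

import torch

from mmseg.models.decode_heads import UPerHead

inputs = [torch.randn(1, 4, 45, 45), torch.randn(1, 2, 21, 21)]
head = UPerHead(
    in_channels=[4, 2], channels=2, num_classes=19, in_index=[-2, -1])

# Existing callers keep getting logits only from forward().
logits = head(inputs)
assert logits.shape == (1, 19, 45, 45)

# Kernel-update consumers call forward_feature() for the full triple.
output, feats, seg_kernels = head.forward_feature(inputs)
assert feats.shape == (1, 2, 45, 45)
assert seg_kernels.shape == (1, 19, 2, 1, 1)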
18 changes: 18 additions & 0 deletions tests/test_models/test_heads/test_aspp_head.py
@@ -74,3 +74,21 @@ def test_dw_aspp_head():
     assert head.aspp_modules[2].depthwise_conv.dilation == (24, 24)
     outputs = head(inputs)
     assert outputs.shape == (1, head.num_classes, 45, 45)
+
+
+def test_kernel_update_forward():
+
+    # test kernel update
+    inputs = [torch.randn(1, 8, 45, 45)]
+    out_channels = 4
+    head = ASPPHead(
+        in_channels=8,
+        channels=out_channels,
+        num_classes=19,
+        dilations=(1, 12, 24))
+    if torch.cuda.is_available():
+        head, inputs = to_cuda(head, inputs)
+    output, feats, seg_kernels = head.forward_feature(inputs)
+    assert output.shape == (1, head.num_classes, 45, 45)
+    assert feats.shape == (1, out_channels, 45, 45)
+    assert seg_kernels.shape == (1, head.num_classes, out_channels, 1, 1)
15 changes: 15 additions & 0 deletions tests/test_models/test_heads/test_fcn_head.py
@@ -129,3 +129,18 @@ def test_sep_fcn_head():
     assert head.concat_input
     assert isinstance(head.convs[0], DepthwiseSeparableConvModule)
     assert isinstance(head.convs[1], DepthwiseSeparableConvModule)
+
+
+def test_kernel_update_forward():
+
+    # test kernel update
+    inputs = [torch.randn(1, 8, 23, 23)]
+    out_channels = 4
+    head = FCNHead(in_channels=8, channels=out_channels, num_classes=19)
+    if torch.cuda.is_available():
+        head, inputs = to_cuda(head, inputs)
+
+    output, feats, seg_kernels = head.forward_feature(inputs)
+    assert output.shape == (1, head.num_classes, 23, 23)
+    assert feats.shape == (1, out_channels, 23, 23)
+    assert seg_kernels.shape == (1, head.num_classes, out_channels, 1, 1)
18 changes: 18 additions & 0 deletions tests/test_models/test_heads/test_psp_head.py
@@ -34,3 +34,21 @@ def test_psp_head():
     assert head.psp_modules[2][0].output_size == 3
     outputs = head(inputs)
     assert outputs.shape == (1, head.num_classes, 23, 23)
+
+
+def test_kernel_update_forward():
+
+    # test kernel update
+    inputs = [torch.randn(1, 4, 23, 23)]
+    out_channels = 2
+    head = PSPHead(
+        in_channels=4,
+        channels=out_channels,
+        num_classes=19,
+        pool_scales=(1, 2, 3))
+    if torch.cuda.is_available():
+        head, inputs = to_cuda(head, inputs)
+    output, feats, seg_kernels = head.forward_feature(inputs)
+    assert output.shape == (1, head.num_classes, 23, 23)
+    assert feats.shape == (1, out_channels, 23, 23)
+    assert seg_kernels.shape == (1, head.num_classes, out_channels, 1, 1)
17 changes: 17 additions & 0 deletions tests/test_models/test_heads/test_uper_head.py
@@ -33,3 +33,20 @@ def test_uper_head():
         head, inputs = to_cuda(head, inputs)
     outputs = head(inputs)
     assert outputs.shape == (1, head.num_classes, 45, 45)
+
+
+def test_kernel_update_forward():
+    # test kernel update
+    inputs = [torch.randn(1, 4, 45, 45), torch.randn(1, 2, 21, 21)]
+    out_channels = 2
+    head = UPerHead(
+        in_channels=[4, 2],
+        channels=out_channels,
+        num_classes=19,
+        in_index=[-2, -1])
+    if torch.cuda.is_available():
+        head, inputs = to_cuda(head, inputs)
+    output, feats, seg_kernels = head.forward_feature(inputs)
+    assert output.shape == (1, head.num_classes, 45, 45)
+    assert feats.shape == (1, out_channels, 45, 45)
+    assert seg_kernels.shape == (1, head.num_classes, out_channels, 1, 1)
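
One detail shared by all four heads: `conv_seg.weight.clone()` is taken before the expand, so downstream code receives a buffer it can modify without touching the parameter itself, while gradients still flow back to it. A standalone sketch of that autograd behavior (plain PyTorch, not code from this PR):

import torch

weight = torch.nn.Parameter(torch.randn(19, 4, 1, 1))

# clone() copies storage (safe for downstream edits) but stays on the
# autograd graph; expand() then broadcasts the copy across the batch.
seg_kernels = weight.clone()[None].expand(2, 19, 4, 1, 1)

seg_kernels.sum().backward()
assert weight.grad is not None  # gradients still reach the parameter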