Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Fix] Fix binary segmentation when num_classes==1 #2016

Merged
merged 11 commits into from
Sep 8, 2022
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 30 additions & 2 deletions mmseg/models/decode_heads/decode_head.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from abc import ABCMeta, abstractmethod

import torch
Expand All @@ -18,6 +19,9 @@ class BaseDecodeHead(BaseModule, metaclass=ABCMeta):
in_channels (int|Sequence[int]): Input channels.
channels (int): Channels after modules, before conv_seg.
num_classes (int): Number of classes.
out_channels (int): Output channels of conv_seg.
threshold (float): Threshold for binary segmentation in the case of
`num_classes==1`. Default: None.
dropout_ratio (float): Ratio of dropout layer. Default: 0.1.
conv_cfg (dict|None): Config of conv layers. Default: None.
norm_cfg (dict|None): Config of norm layers. Default: None.
Expand Down Expand Up @@ -56,6 +60,8 @@ def __init__(self,
channels,
*,
num_classes,
out_channels=None,
threshold=None,
dropout_ratio=0.1,
conv_cfg=None,
norm_cfg=None,
Expand All @@ -74,7 +80,6 @@ def __init__(self,
super(BaseDecodeHead, self).__init__(init_cfg)
self._init_inputs(in_channels, in_index, input_transform)
self.channels = channels
self.num_classes = num_classes
self.dropout_ratio = dropout_ratio
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
Expand All @@ -84,6 +89,29 @@ def __init__(self,
self.ignore_index = ignore_index
self.align_corners = align_corners

if out_channels is None:
if num_classes == 2:
warnings.warn('For binary segmentation, we suggest using'
'`out_channels = 1` to define the output'
'channels of segmentor, and use `threshold`'
'to convert seg_logist into a prediction'
'applying a threshold')
out_channels = num_classes

if out_channels != num_classes and out_channels != 1:
raise ValueError(
'out_channels should be equal to num_classes,'
'except out_channels == 1 and num_classes == 2, but got'
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
'except out_channels == 1 and num_classes == 2, but got'
'except binary segmentation set out_channels == 1 and num_classes == 2, but got'

f'out_channels={out_channels}, num_classes={num_classes}')

if out_channels == 1 and threshold is None:
threshold = 0.3
warnings.warn('threshold is not defined for binary, and defaults'
'to 0.3')
self.num_classes = num_classes
self.out_channels = out_channels
self.threshold = threshold

if isinstance(loss_decode, dict):
self.loss_decode = build_loss(loss_decode)
elif isinstance(loss_decode, (list, tuple)):
Expand All @@ -99,7 +127,7 @@ def __init__(self,
else:
self.sampler = None

self.conv_seg = nn.Conv2d(channels, num_classes, kernel_size=1)
self.conv_seg = nn.Conv2d(channels, self.out_channels, kernel_size=1)
if dropout_ratio > 0:
self.dropout = nn.Dropout2d(dropout_ratio)
else:
Expand Down
1 change: 1 addition & 0 deletions mmseg/models/segmentors/cascade_encoder_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def _init_decode_head(self, decode_head):
self.decode_head.append(builder.build_head(decode_head[i]))
self.align_corners = self.decode_head[-1].align_corners
self.num_classes = self.decode_head[-1].num_classes
self.out_channels = self.decode_head[-1].out_channels

def encode_decode(self, img, img_metas):
"""Encode images with backbone and decode into a semantic segmentation
Expand Down
22 changes: 17 additions & 5 deletions mmseg/models/segmentors/encoder_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def _init_decode_head(self, decode_head):
self.decode_head = builder.build_head(decode_head)
self.align_corners = self.decode_head.align_corners
self.num_classes = self.decode_head.num_classes
self.out_channels = self.decode_head.out_channels

def _init_auxiliary_head(self, auxiliary_head):
"""Initialize ``auxiliary_head``"""
Expand Down Expand Up @@ -162,10 +163,10 @@ def slide_inference(self, img, img_meta, rescale):
h_stride, w_stride = self.test_cfg.stride
h_crop, w_crop = self.test_cfg.crop_size
batch_size, _, h_img, w_img = img.size()
num_classes = self.num_classes
out_channels = self.out_channels
h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1
w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1
preds = img.new_zeros((batch_size, num_classes, h_img, w_img))
preds = img.new_zeros((batch_size, out_channels, h_img, w_img))
count_mat = img.new_zeros((batch_size, 1, h_img, w_img))
for h_idx in range(h_grids):
for w_idx in range(w_grids):
Expand Down Expand Up @@ -245,7 +246,10 @@ def inference(self, img, img_meta, rescale):
seg_logit = self.slide_inference(img, img_meta, rescale)
else:
seg_logit = self.whole_inference(img, img_meta, rescale)
output = F.softmax(seg_logit, dim=1)
if self.out_channels == 1:
output = F.sigmoid(seg_logit)
else:
output = F.softmax(seg_logit, dim=1)
flip = img_meta[0]['flip']
if flip:
flip_direction = img_meta[0]['flip_direction']
Expand All @@ -260,7 +264,11 @@ def inference(self, img, img_meta, rescale):
def simple_test(self, img, img_meta, rescale=True):
"""Simple test with single image."""
seg_logit = self.inference(img, img_meta, rescale)
seg_pred = seg_logit.argmax(dim=1)
if self.out_channels == 1:
seg_pred = (seg_logit >
self.decode_head.threshold).to(seg_logit).squeeze(1)
else:
seg_pred = seg_logit.argmax(dim=1)
if torch.onnx.is_in_onnx_export():
# our inference backend only support 4D output
seg_pred = seg_pred.unsqueeze(0)
Expand All @@ -283,7 +291,11 @@ def aug_test(self, imgs, img_metas, rescale=True):
cur_seg_logit = self.inference(imgs[i], img_metas[i], rescale)
seg_logit += cur_seg_logit
seg_logit /= len(imgs)
seg_pred = seg_logit.argmax(dim=1)
if self.out_channels == 1:
seg_pred = (seg_logit >
self.decode_head.threshold).to(seg_logit).squeeze(1)
else:
seg_pred = seg_logit.argmax(dim=1)
seg_pred = seg_pred.cpu().numpy()
# unravel batch dim
seg_pred = list(seg_pred)
Expand Down
12 changes: 12 additions & 0 deletions tests/test_models/test_heads/test_decode_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,18 @@ def test_decode_head():
in_index=[-1],
input_transform='resize_concat')

with pytest.raises(ValueError):
# out_channels should be equal to num_classes
BaseDecodeHead(32, 16, num_classes=19, out_channels=18)

# test out_channels
head = BaseDecodeHead(32, 16, num_classes=2)
assert head.out_channels == 2

# test out_channels == 1 and num_classes == 2
head = BaseDecodeHead(32, 16, num_classes=2, out_channels=1)
assert head.out_channels == 1 and head.num_classes == 2

# test default dropout
head = BaseDecodeHead(32, 16, num_classes=19)
assert hasattr(head, 'dropout') and head.dropout.p == 0.1
Expand Down
6 changes: 6 additions & 0 deletions tests/test_models/test_segmentors/test_encoder_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,12 @@ def test_encoder_decoder():
segmentor = build_segmentor(cfg)
_segmentor_forward_train_test(segmentor)

# test out_channels == 1
segmentor.out_channels = 1
segmentor.decode_head.out_channels = 1
segmentor.decode_head.threshold = 0.3
_segmentor_forward_train_test(segmentor)

# test slide mode
cfg.test_cfg = ConfigDict(mode='slide', crop_size=(3, 3), stride=(2, 2))
segmentor = build_segmentor(cfg)
Expand Down