Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Fix] Fix binary segmentation #2101

Merged
merged 6 commits into from
Sep 27, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 31 additions & 2 deletions mmseg/models/decode_heads/decode_head.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from abc import ABCMeta, abstractmethod
from typing import List, Tuple

Expand Down Expand Up @@ -44,6 +45,9 @@ class BaseDecodeHead(BaseModule, metaclass=ABCMeta):
in_channels (int|Sequence[int]): Input channels.
channels (int): Channels after modules, before conv_seg.
num_classes (int): Number of classes.
out_channels (int): Output channels of conv_seg.
threshold (float): Threshold for binary segmentation in the case of
`num_classes==1`. Default: None.
dropout_ratio (float): Ratio of dropout layer. Default: 0.1.
conv_cfg (dict|None): Config of conv layers. Default: None.
norm_cfg (dict|None): Config of norm layers. Default: None.
Expand Down Expand Up @@ -82,6 +86,8 @@ def __init__(self,
channels,
*,
num_classes,
out_channels=None,
threshold=None,
dropout_ratio=0.1,
conv_cfg=None,
norm_cfg=None,
Expand All @@ -100,7 +106,6 @@ def __init__(self,
super().__init__(init_cfg)
self._init_inputs(in_channels, in_index, input_transform)
self.channels = channels
self.num_classes = num_classes
self.dropout_ratio = dropout_ratio
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
Expand All @@ -110,6 +115,30 @@ def __init__(self,
self.ignore_index = ignore_index
self.align_corners = align_corners

if out_channels is None:
if num_classes == 2:
warnings.warn('For binary segmentation, we suggest using'
'`out_channels = 1` to define the output'
'channels of segmentor, and use `threshold`'
'to convert seg_logist into a prediction'
'applying a threshold')
out_channels = num_classes

if out_channels != num_classes and out_channels != 1:
raise ValueError(
'out_channels should be equal to num_classes,'
'except binary segmentation set out_channels == 1 and'
f'num_classes == 2, but got out_channels={out_channels}'
f'and num_classes={num_classes}')

if out_channels == 1 and threshold is None:
threshold = 0.3
warnings.warn('threshold is not defined for binary, and defaults'
'to 0.3')
self.num_classes = num_classes
self.out_channels = out_channels
self.threshold = threshold

if isinstance(loss_decode, dict):
self.loss_decode = build_loss(loss_decode)
elif isinstance(loss_decode, (list, tuple)):
Expand All @@ -125,7 +154,7 @@ def __init__(self,
else:
self.sampler = None

self.conv_seg = nn.Conv2d(channels, num_classes, kernel_size=1)
self.conv_seg = nn.Conv2d(channels, self.out_channels, kernel_size=1)
if dropout_ratio > 0:
self.dropout = nn.Dropout2d(dropout_ratio)
else:
Expand Down
32 changes: 13 additions & 19 deletions mmseg/models/segmentors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,11 +148,9 @@ def postprocess_result(self,
segmentation before normalization.
"""
batch_size, C, H, W = seg_logits.shape
assert C > 1, ('This post processes does not binary segmentation, and '
f'channels `seg_logtis` must be > 1 but got {C}')

if data_samples is None:
data_samples = []
data_samples = [SegDataSample() for _ in range(batch_size)]
only_prediction = True
else:
only_prediction = False
Expand All @@ -174,23 +172,19 @@ def postprocess_result(self,
mode='bilinear',
align_corners=self.align_corners,
warning=False).squeeze(0)
# i_seg_logits shape is C, H, W with original shape
i_seg_pred = i_seg_logits.argmax(dim=0, keepdim=True)
data_samples[i].set_data({
'seg_logits':
PixelData(**{'data': i_seg_logits}),
'pred_sem_seg':
PixelData(**{'data': i_seg_pred})
})
else:
i_seg_logits = seg_logits[i]

if C > 1:
i_seg_pred = i_seg_logits.argmax(dim=0, keepdim=True)
prediction = SegDataSample()
prediction.set_data({
'seg_logits':
PixelData(**{'data': i_seg_logits}),
'pred_sem_seg':
PixelData(**{'data': i_seg_pred})
})
data_samples.append(prediction)
else:
i_seg_pred = (i_seg_logits >
self.decode_head.threshold).to(i_seg_logits)
data_samples[i].set_data({
'seg_logits':
PixelData(**{'data': i_seg_logits}),
'pred_sem_seg':
PixelData(**{'data': i_seg_pred})
})

return data_samples
1 change: 1 addition & 0 deletions mmseg/models/segmentors/encoder_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ def _init_decode_head(self, decode_head: ConfigType) -> None:
self.decode_head = MODELS.build(decode_head)
self.align_corners = self.decode_head.align_corners
self.num_classes = self.decode_head.num_classes
self.out_channels = self.decode_head.out_channels

def _init_auxiliary_head(self, auxiliary_head: ConfigType) -> None:
"""Initialize ``auxiliary_head``"""
Expand Down
193 changes: 193 additions & 0 deletions tests/test_models/test_heads/test_decode_head.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,193 @@
# Copyright (c) OpenMMLab. All rights reserved.
from unittest.mock import patch

import pytest
import torch
from mmengine.structures import PixelData

from mmseg.models.decode_heads.decode_head import BaseDecodeHead
from mmseg.structures import SegDataSample
from .utils import to_cuda


@patch.multiple(BaseDecodeHead, __abstractmethods__=set())
def test_decode_head():

with pytest.raises(AssertionError):
# default input_transform doesn't accept multiple inputs
BaseDecodeHead([32, 16], 16, num_classes=19)

with pytest.raises(AssertionError):
# default input_transform doesn't accept multiple inputs
BaseDecodeHead(32, 16, num_classes=19, in_index=[-1, -2])

with pytest.raises(AssertionError):
# supported mode is resize_concat only
BaseDecodeHead(32, 16, num_classes=19, input_transform='concat')

with pytest.raises(AssertionError):
# in_channels should be list|tuple
BaseDecodeHead(32, 16, num_classes=19, input_transform='resize_concat')

with pytest.raises(AssertionError):
# in_index should be list|tuple
BaseDecodeHead([32],
16,
in_index=-1,
num_classes=19,
input_transform='resize_concat')

with pytest.raises(AssertionError):
# len(in_index) should equal len(in_channels)
BaseDecodeHead([32, 16],
16,
num_classes=19,
in_index=[-1],
input_transform='resize_concat')

with pytest.raises(ValueError):
# out_channels should be equal to num_classes
BaseDecodeHead(32, 16, num_classes=19, out_channels=18)

# test out_channels
head = BaseDecodeHead(32, 16, num_classes=2)
assert head.out_channels == 2

# test out_channels == 1 and num_classes == 2
head = BaseDecodeHead(32, 16, num_classes=2, out_channels=1)
assert head.out_channels == 1 and head.num_classes == 2

# test default dropout
head = BaseDecodeHead(32, 16, num_classes=19)
assert hasattr(head, 'dropout') and head.dropout.p == 0.1

# test set dropout
head = BaseDecodeHead(32, 16, num_classes=19, dropout_ratio=0.2)
assert hasattr(head, 'dropout') and head.dropout.p == 0.2

# test no input_transform
inputs = [torch.randn(1, 32, 45, 45)]
head = BaseDecodeHead(32, 16, num_classes=19)
if torch.cuda.is_available():
head, inputs = to_cuda(head, inputs)
assert head.in_channels == 32
assert head.input_transform is None
transformed_inputs = head._transform_inputs(inputs)
assert transformed_inputs.shape == (1, 32, 45, 45)

# test input_transform = resize_concat
inputs = [torch.randn(1, 32, 45, 45), torch.randn(1, 16, 21, 21)]
head = BaseDecodeHead([32, 16],
16,
num_classes=19,
in_index=[0, 1],
input_transform='resize_concat')
if torch.cuda.is_available():
head, inputs = to_cuda(head, inputs)
assert head.in_channels == 48
assert head.input_transform == 'resize_concat'
transformed_inputs = head._transform_inputs(inputs)
assert transformed_inputs.shape == (1, 48, 45, 45)

# test multi-loss, loss_decode is dict
with pytest.raises(TypeError):
# loss_decode must be a dict or sequence of dict.
BaseDecodeHead(3, 16, num_classes=19, loss_decode=['CrossEntropyLoss'])

inputs = torch.randn(2, 19, 8, 8).float()
data_samples = [
SegDataSample(gt_sem_seg=PixelData(data=torch.ones(64, 64).long()))
for _ in range(2)
]

head = BaseDecodeHead(
3,
16,
num_classes=19,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0))
if torch.cuda.is_available():
head, inputs = to_cuda(head, inputs)
loss = head.loss_by_feat(
seg_logits=inputs, batch_data_samples=data_samples)
assert 'loss_ce' in loss

# test multi-loss, loss_decode is list of dict
inputs = torch.randn(2, 19, 8, 8).float()
data_samples = [
SegDataSample(gt_sem_seg=PixelData(data=torch.ones(64, 64).long()))
for _ in range(2)
]
head = BaseDecodeHead(
3,
16,
num_classes=19,
loss_decode=[
dict(type='CrossEntropyLoss', loss_name='loss_1'),
dict(type='CrossEntropyLoss', loss_name='loss_2')
])
if torch.cuda.is_available():
head, inputs = to_cuda(head, inputs)

loss = head.loss_by_feat(
seg_logits=inputs, batch_data_samples=data_samples)
assert 'loss_1' in loss
assert 'loss_2' in loss

# 'loss_decode' must be a dict or sequence of dict
with pytest.raises(TypeError):
BaseDecodeHead(3, 16, num_classes=19, loss_decode=['CrossEntropyLoss'])
with pytest.raises(TypeError):
BaseDecodeHead(3, 16, num_classes=19, loss_decode=0)

# test multi-loss, loss_decode is list of dict
inputs = torch.randn(2, 19, 8, 8).float()
data_samples = [
SegDataSample(gt_sem_seg=PixelData(data=torch.ones(64, 64).long()))
for _ in range(2)
]
head = BaseDecodeHead(
3,
16,
num_classes=19,
loss_decode=(dict(type='CrossEntropyLoss', loss_name='loss_1'),
dict(type='CrossEntropyLoss', loss_name='loss_2'),
dict(type='CrossEntropyLoss', loss_name='loss_3')))
if torch.cuda.is_available():
head, inputs = to_cuda(head, inputs)
loss = head.loss_by_feat(
seg_logits=inputs, batch_data_samples=data_samples)
assert 'loss_1' in loss
assert 'loss_2' in loss
assert 'loss_3' in loss

# test multi-loss, loss_decode is list of dict, names of them are identical
inputs = torch.randn(2, 19, 8, 8).float()
data_samples = [
SegDataSample(gt_sem_seg=PixelData(data=torch.ones(64, 64).long()))
for _ in range(2)
]
head = BaseDecodeHead(
3,
16,
num_classes=19,
loss_decode=(dict(type='CrossEntropyLoss', loss_name='loss_ce'),
dict(type='CrossEntropyLoss', loss_name='loss_ce'),
dict(type='CrossEntropyLoss', loss_name='loss_ce')))
if torch.cuda.is_available():
head, inputs = to_cuda(head, inputs)
loss_3 = head.loss_by_feat(
seg_logits=inputs, batch_data_samples=data_samples)

head = BaseDecodeHead(
3,
16,
num_classes=19,
loss_decode=(dict(type='CrossEntropyLoss', loss_name='loss_ce')))
if torch.cuda.is_available():
head, inputs = to_cuda(head, inputs)
loss = head.loss_by_feat(
seg_logits=inputs, batch_data_samples=data_samples)
assert 'loss_ce' in loss
assert 'loss_ce' in loss_3
assert loss_3['loss_ce'] == 3 * loss['loss_ce']
1 change: 1 addition & 0 deletions tests/test_models/test_segmentors/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Copyright (c) OpenMMLab. All rights reserved.
57 changes: 57 additions & 0 deletions tests/test_models/test_segmentors/test_cascade_encoder_decoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine import ConfigDict

from mmseg.models import build_segmentor
from .utils import _segmentor_forward_train_test


def test_cascade_encoder_decoder():

# test 1 decode head, w.o. aux head
cfg = ConfigDict(
type='CascadeEncoderDecoder',
num_stages=2,
backbone=dict(type='ExampleBackbone'),
decode_head=[
dict(type='ExampleDecodeHead'),
dict(type='ExampleCascadeDecodeHead')
])
cfg.test_cfg = ConfigDict(mode='whole')
segmentor = build_segmentor(cfg)
_segmentor_forward_train_test(segmentor)

# test slide mode
cfg.test_cfg = ConfigDict(mode='slide', crop_size=(3, 3), stride=(2, 2))
segmentor = build_segmentor(cfg)
_segmentor_forward_train_test(segmentor)

# test 1 decode head, 1 aux head
cfg = ConfigDict(
type='CascadeEncoderDecoder',
num_stages=2,
backbone=dict(type='ExampleBackbone'),
decode_head=[
dict(type='ExampleDecodeHead'),
dict(type='ExampleCascadeDecodeHead')
],
auxiliary_head=dict(type='ExampleDecodeHead'))
cfg.test_cfg = ConfigDict(mode='whole')
segmentor = build_segmentor(cfg)
_segmentor_forward_train_test(segmentor)

# test 1 decode head, 2 aux head
cfg = ConfigDict(
type='CascadeEncoderDecoder',
num_stages=2,
backbone=dict(type='ExampleBackbone'),
decode_head=[
dict(type='ExampleDecodeHead'),
dict(type='ExampleCascadeDecodeHead')
],
auxiliary_head=[
dict(type='ExampleDecodeHead'),
dict(type='ExampleDecodeHead')
])
cfg.test_cfg = ConfigDict(mode='whole')
segmentor = build_segmentor(cfg)
_segmentor_forward_train_test(segmentor)
Loading