From 38cc03aba60cd205771a420b3db1b36972105b44 Mon Sep 17 00:00:00 2001 From: xiexinch Date: Thu, 22 Sep 2022 16:39:36 +0800 Subject: [PATCH 1/6] add out_channels --- mmseg/models/decode_heads/decode_head.py | 33 ++++++++++++++++++++-- mmseg/models/segmentors/base.py | 15 ++++++---- mmseg/models/segmentors/encoder_decoder.py | 1 + 3 files changed, 42 insertions(+), 7 deletions(-) diff --git a/mmseg/models/decode_heads/decode_head.py b/mmseg/models/decode_heads/decode_head.py index 9d81ec237a..c7223f944c 100644 --- a/mmseg/models/decode_heads/decode_head.py +++ b/mmseg/models/decode_heads/decode_head.py @@ -1,4 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. +import warnings from abc import ABCMeta, abstractmethod from typing import List, Tuple @@ -44,6 +45,9 @@ class BaseDecodeHead(BaseModule, metaclass=ABCMeta): in_channels (int|Sequence[int]): Input channels. channels (int): Channels after modules, before conv_seg. num_classes (int): Number of classes. + out_channels (int): Output channels of conv_seg. + threshold (float): Threshold for binary segmentation in the case of + `out_channels==1`. Default: None. dropout_ratio (float): Ratio of dropout layer. Default: 0.1. conv_cfg (dict|None): Config of conv layers. Default: None. norm_cfg (dict|None): Config of norm layers. Default: None. 
@@ -82,6 +86,8 @@ def __init__(self, channels, *, num_classes, + out_channels=None, + threshold=None, dropout_ratio=0.1, conv_cfg=None, norm_cfg=None, @@ -100,7 +106,6 @@ def __init__(self, super().__init__(init_cfg) self._init_inputs(in_channels, in_index, input_transform) self.channels = channels - self.num_classes = num_classes self.dropout_ratio = dropout_ratio self.conv_cfg = conv_cfg self.norm_cfg = norm_cfg @@ -110,6 +115,30 @@ def __init__(self, self.ignore_index = ignore_index self.align_corners = align_corners + if out_channels is None: + if num_classes == 2: + warnings.warn('For binary segmentation, we suggest using' + '`out_channels = 1` to define the output' + 'channels of segmentor, and use `threshold`' + 'to convert seg_logist into a prediction' + 'applying a threshold') + out_channels = num_classes + + if out_channels != num_classes and out_channels != 1: + raise ValueError( + 'out_channels should be equal to num_classes,' + 'except binary segmentation set out_channels == 1 and' + f'num_classes == 2, but got out_channels={out_channels}' + f'and num_classes={num_classes}') + + if out_channels == 1 and threshold is None: + threshold = 0.3 + warnings.warn('threshold is not defined for binary, and defaults' + 'to 0.3') + self.num_classes = num_classes + self.out_channels = out_channels + self.threshold = threshold + if isinstance(loss_decode, dict): self.loss_decode = build_loss(loss_decode) elif isinstance(loss_decode, (list, tuple)): @@ -125,7 +154,7 @@ def __init__(self, else: self.sampler = None - self.conv_seg = nn.Conv2d(channels, num_classes, kernel_size=1) + self.conv_seg = nn.Conv2d(channels, self.out_channels, kernel_size=1) if dropout_ratio > 0: self.dropout = nn.Dropout2d(dropout_ratio) else: diff --git a/mmseg/models/segmentors/base.py b/mmseg/models/segmentors/base.py index 8ae713f8f8..d22e3baddd 100644 --- a/mmseg/models/segmentors/base.py +++ b/mmseg/models/segmentors/base.py @@ -148,8 +148,6 @@ def postprocess_result(self, segmentation 
before normalization. """ batch_size, C, H, W = seg_logits.shape - assert C > 1, ('This post processes does not binary segmentation, and ' - f'channels `seg_logtis` must be > 1 but got {C}') if data_samples is None: data_samples = [] @@ -175,7 +173,11 @@ def postprocess_result(self, align_corners=self.align_corners, warning=False).squeeze(0) # i_seg_logits shape is C, H, W with original shape - i_seg_pred = i_seg_logits.argmax(dim=0, keepdim=True) + if C > 1: + i_seg_pred = i_seg_logits.argmax(dim=0, keepdim=True) + else: + i_seg_pred = (i_seg_logits > + self.decode_head.threshold).to(i_seg_logits) data_samples[i].set_data({ 'seg_logits': PixelData(**{'data': i_seg_logits}), @@ -183,8 +185,11 @@ def postprocess_result(self, PixelData(**{'data': i_seg_pred}) }) else: - i_seg_logits = seg_logits[i] - i_seg_pred = i_seg_logits.argmax(dim=0, keepdim=True) + if C > 1: + i_seg_pred = i_seg_logits.argmax(dim=0, keepdim=True) + else: + i_seg_pred = (i_seg_logits > + self.decode_head.threshold).to(i_seg_logits) prediction = SegDataSample() prediction.set_data({ 'seg_logits': diff --git a/mmseg/models/segmentors/encoder_decoder.py b/mmseg/models/segmentors/encoder_decoder.py index cf99ecf63c..c4f44ba005 100644 --- a/mmseg/models/segmentors/encoder_decoder.py +++ b/mmseg/models/segmentors/encoder_decoder.py @@ -100,6 +100,7 @@ def _init_decode_head(self, decode_head: ConfigType) -> None: self.decode_head = MODELS.build(decode_head) self.align_corners = self.decode_head.align_corners self.num_classes = self.decode_head.num_classes + self.out_channels = self.decode_head.out_channels def _init_auxiliary_head(self, auxiliary_head: ConfigType) -> None: """Initialize ``auxiliary_head``""" From a03b1c1e0967507291dc58d6add5401f2ad66902 Mon Sep 17 00:00:00 2001 From: xiexinch Date: Thu, 22 Sep 2022 17:36:56 +0800 Subject: [PATCH 2/6] fix forward --- mmseg/models/segmentors/base.py | 1 + 1 file changed, 1 insertion(+) diff --git a/mmseg/models/segmentors/base.py 
b/mmseg/models/segmentors/base.py index d22e3baddd..35a58bf542 100644 --- a/mmseg/models/segmentors/base.py +++ b/mmseg/models/segmentors/base.py @@ -185,6 +185,7 @@ def postprocess_result(self, PixelData(**{'data': i_seg_pred}) }) else: + i_seg_logits = seg_logits[i] if C > 1: i_seg_pred = i_seg_logits.argmax(dim=0, keepdim=True) else: From 463a9485d38338d924d65389c289861761f02cfe Mon Sep 17 00:00:00 2001 From: xiexinch Date: Mon, 26 Sep 2022 19:32:42 +0800 Subject: [PATCH 3/6] add decode_head ut --- .../test_heads/test_decode_head.py | 193 ++++++++++++++++++ 1 file changed, 193 insertions(+) create mode 100644 tests/test_models/test_heads/test_decode_head.py diff --git a/tests/test_models/test_heads/test_decode_head.py b/tests/test_models/test_heads/test_decode_head.py new file mode 100644 index 0000000000..88e6bed10f --- /dev/null +++ b/tests/test_models/test_heads/test_decode_head.py @@ -0,0 +1,193 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from unittest.mock import patch + +import pytest +import torch +from mmengine.structures import PixelData + +from mmseg.models.decode_heads.decode_head import BaseDecodeHead +from mmseg.structures import SegDataSample +from .utils import to_cuda + + +@patch.multiple(BaseDecodeHead, __abstractmethods__=set()) +def test_decode_head(): + + with pytest.raises(AssertionError): + # default input_transform doesn't accept multiple inputs + BaseDecodeHead([32, 16], 16, num_classes=19) + + with pytest.raises(AssertionError): + # default input_transform doesn't accept multiple inputs + BaseDecodeHead(32, 16, num_classes=19, in_index=[-1, -2]) + + with pytest.raises(AssertionError): + # supported mode is resize_concat only + BaseDecodeHead(32, 16, num_classes=19, input_transform='concat') + + with pytest.raises(AssertionError): + # in_channels should be list|tuple + BaseDecodeHead(32, 16, num_classes=19, input_transform='resize_concat') + + with pytest.raises(AssertionError): + # in_index should be list|tuple + 
BaseDecodeHead([32], + 16, + in_index=-1, + num_classes=19, + input_transform='resize_concat') + + with pytest.raises(AssertionError): + # len(in_index) should equal len(in_channels) + BaseDecodeHead([32, 16], + 16, + num_classes=19, + in_index=[-1], + input_transform='resize_concat') + + with pytest.raises(ValueError): + # out_channels should be equal to num_classes + BaseDecodeHead(32, 16, num_classes=19, out_channels=18) + + # test out_channels + head = BaseDecodeHead(32, 16, num_classes=2) + assert head.out_channels == 2 + + # test out_channels == 1 and num_classes == 2 + head = BaseDecodeHead(32, 16, num_classes=2, out_channels=1) + assert head.out_channels == 1 and head.num_classes == 2 + + # test default dropout + head = BaseDecodeHead(32, 16, num_classes=19) + assert hasattr(head, 'dropout') and head.dropout.p == 0.1 + + # test set dropout + head = BaseDecodeHead(32, 16, num_classes=19, dropout_ratio=0.2) + assert hasattr(head, 'dropout') and head.dropout.p == 0.2 + + # test no input_transform + inputs = [torch.randn(1, 32, 45, 45)] + head = BaseDecodeHead(32, 16, num_classes=19) + if torch.cuda.is_available(): + head, inputs = to_cuda(head, inputs) + assert head.in_channels == 32 + assert head.input_transform is None + transformed_inputs = head._transform_inputs(inputs) + assert transformed_inputs.shape == (1, 32, 45, 45) + + # test input_transform = resize_concat + inputs = [torch.randn(1, 32, 45, 45), torch.randn(1, 16, 21, 21)] + head = BaseDecodeHead([32, 16], + 16, + num_classes=19, + in_index=[0, 1], + input_transform='resize_concat') + if torch.cuda.is_available(): + head, inputs = to_cuda(head, inputs) + assert head.in_channels == 48 + assert head.input_transform == 'resize_concat' + transformed_inputs = head._transform_inputs(inputs) + assert transformed_inputs.shape == (1, 48, 45, 45) + + # test multi-loss, loss_decode is dict + with pytest.raises(TypeError): + # loss_decode must be a dict or sequence of dict. 
+ BaseDecodeHead(3, 16, num_classes=19, loss_decode=['CrossEntropyLoss']) + + inputs = torch.randn(2, 19, 8, 8).float() + data_samples = [ + SegDataSample(gt_sem_seg=PixelData(data=torch.ones(64, 64).long())) + for _ in range(2) + ] + + head = BaseDecodeHead( + 3, + 16, + num_classes=19, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)) + if torch.cuda.is_available(): + head, inputs = to_cuda(head, inputs) + loss = head.loss_by_feat( + seg_logits=inputs, batch_data_samples=data_samples) + assert 'loss_ce' in loss + + # test multi-loss, loss_decode is list of dict + inputs = torch.randn(2, 19, 8, 8).float() + data_samples = [ + SegDataSample(gt_sem_seg=PixelData(data=torch.ones(64, 64).long())) + for _ in range(2) + ] + head = BaseDecodeHead( + 3, + 16, + num_classes=19, + loss_decode=[ + dict(type='CrossEntropyLoss', loss_name='loss_1'), + dict(type='CrossEntropyLoss', loss_name='loss_2') + ]) + if torch.cuda.is_available(): + head, inputs = to_cuda(head, inputs) + + loss = head.loss_by_feat( + seg_logits=inputs, batch_data_samples=data_samples) + assert 'loss_1' in loss + assert 'loss_2' in loss + + # 'loss_decode' must be a dict or sequence of dict + with pytest.raises(TypeError): + BaseDecodeHead(3, 16, num_classes=19, loss_decode=['CrossEntropyLoss']) + with pytest.raises(TypeError): + BaseDecodeHead(3, 16, num_classes=19, loss_decode=0) + + # test multi-loss, loss_decode is list of dict + inputs = torch.randn(2, 19, 8, 8).float() + data_samples = [ + SegDataSample(gt_sem_seg=PixelData(data=torch.ones(64, 64).long())) + for _ in range(2) + ] + head = BaseDecodeHead( + 3, + 16, + num_classes=19, + loss_decode=(dict(type='CrossEntropyLoss', loss_name='loss_1'), + dict(type='CrossEntropyLoss', loss_name='loss_2'), + dict(type='CrossEntropyLoss', loss_name='loss_3'))) + if torch.cuda.is_available(): + head, inputs = to_cuda(head, inputs) + loss = head.loss_by_feat( + seg_logits=inputs, batch_data_samples=data_samples) + assert 
'loss_1' in loss + assert 'loss_2' in loss + assert 'loss_3' in loss + + # test multi-loss, loss_decode is list of dict, names of them are identical + inputs = torch.randn(2, 19, 8, 8).float() + data_samples = [ + SegDataSample(gt_sem_seg=PixelData(data=torch.ones(64, 64).long())) + for _ in range(2) + ] + head = BaseDecodeHead( + 3, + 16, + num_classes=19, + loss_decode=(dict(type='CrossEntropyLoss', loss_name='loss_ce'), + dict(type='CrossEntropyLoss', loss_name='loss_ce'), + dict(type='CrossEntropyLoss', loss_name='loss_ce'))) + if torch.cuda.is_available(): + head, inputs = to_cuda(head, inputs) + loss_3 = head.loss_by_feat( + seg_logits=inputs, batch_data_samples=data_samples) + + head = BaseDecodeHead( + 3, + 16, + num_classes=19, + loss_decode=(dict(type='CrossEntropyLoss', loss_name='loss_ce'))) + if torch.cuda.is_available(): + head, inputs = to_cuda(head, inputs) + loss = head.loss_by_feat( + seg_logits=inputs, batch_data_samples=data_samples) + assert 'loss_ce' in loss + assert 'loss_ce' in loss_3 + assert loss_3['loss_ce'] == 3 * loss['loss_ce'] From 445d0623dd00ba473dba14ac604b5f301ed0010a Mon Sep 17 00:00:00 2001 From: xiexinch Date: Tue, 27 Sep 2022 11:29:23 +0800 Subject: [PATCH 4/6] add segmentor ut --- tests/test_models/test_segmentors/__init__.py | 1 + .../test_cascade_encoder_decoder.py | 57 ++++++++ .../test_segmentors/test_encoder_decoder.py | 53 +++++++ tests/test_models/test_segmentors/utils.py | 133 ++++++++++++++++++ 4 files changed, 244 insertions(+) create mode 100644 tests/test_models/test_segmentors/__init__.py create mode 100644 tests/test_models/test_segmentors/test_cascade_encoder_decoder.py create mode 100644 tests/test_models/test_segmentors/test_encoder_decoder.py create mode 100644 tests/test_models/test_segmentors/utils.py diff --git a/tests/test_models/test_segmentors/__init__.py b/tests/test_models/test_segmentors/__init__.py new file mode 100644 index 0000000000..ef101fec61 --- /dev/null +++ 
b/tests/test_models/test_segmentors/__init__.py @@ -0,0 +1 @@ +# Copyright (c) OpenMMLab. All rights reserved. diff --git a/tests/test_models/test_segmentors/test_cascade_encoder_decoder.py b/tests/test_models/test_segmentors/test_cascade_encoder_decoder.py new file mode 100644 index 0000000000..941816d253 --- /dev/null +++ b/tests/test_models/test_segmentors/test_cascade_encoder_decoder.py @@ -0,0 +1,57 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmengine import ConfigDict + +from mmseg.models import build_segmentor +from .utils import _segmentor_forward_train_test + + +def test_cascade_encoder_decoder(): + + # test 1 decode head, w.o. aux head + cfg = ConfigDict( + type='CascadeEncoderDecoder', + num_stages=2, + backbone=dict(type='ExampleBackbone'), + decode_head=[ + dict(type='ExampleDecodeHead'), + dict(type='ExampleCascadeDecodeHead') + ]) + cfg.test_cfg = ConfigDict(mode='whole') + segmentor = build_segmentor(cfg) + _segmentor_forward_train_test(segmentor) + + # test slide mode + cfg.test_cfg = ConfigDict(mode='slide', crop_size=(3, 3), stride=(2, 2)) + segmentor = build_segmentor(cfg) + _segmentor_forward_train_test(segmentor) + + # test 1 decode head, 1 aux head + cfg = ConfigDict( + type='CascadeEncoderDecoder', + num_stages=2, + backbone=dict(type='ExampleBackbone'), + decode_head=[ + dict(type='ExampleDecodeHead'), + dict(type='ExampleCascadeDecodeHead') + ], + auxiliary_head=dict(type='ExampleDecodeHead')) + cfg.test_cfg = ConfigDict(mode='whole') + segmentor = build_segmentor(cfg) + _segmentor_forward_train_test(segmentor) + + # test 1 decode head, 2 aux head + cfg = ConfigDict( + type='CascadeEncoderDecoder', + num_stages=2, + backbone=dict(type='ExampleBackbone'), + decode_head=[ + dict(type='ExampleDecodeHead'), + dict(type='ExampleCascadeDecodeHead') + ], + auxiliary_head=[ + dict(type='ExampleDecodeHead'), + dict(type='ExampleDecodeHead') + ]) + cfg.test_cfg = ConfigDict(mode='whole') + segmentor = build_segmentor(cfg) + 
_segmentor_forward_train_test(segmentor) diff --git a/tests/test_models/test_segmentors/test_encoder_decoder.py b/tests/test_models/test_segmentors/test_encoder_decoder.py new file mode 100644 index 0000000000..a86420a6f9 --- /dev/null +++ b/tests/test_models/test_segmentors/test_encoder_decoder.py @@ -0,0 +1,53 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmengine import ConfigDict + +from mmseg.models import build_segmentor +from .utils import _segmentor_forward_train_test + + +def test_encoder_decoder(): + + # test 1 decode head, w.o. aux head + + cfg = ConfigDict( + type='EncoderDecoder', + backbone=dict(type='ExampleBackbone'), + decode_head=dict(type='ExampleDecodeHead'), + train_cfg=None, + test_cfg=dict(mode='whole')) + segmentor = build_segmentor(cfg) + _segmentor_forward_train_test(segmentor) + + # test out_channels == 1 + segmentor.out_channels = 1 + segmentor.decode_head.out_channels = 1 + segmentor.decode_head.threshold = 0.3 + _segmentor_forward_train_test(segmentor) + + # test slide mode + cfg.test_cfg = ConfigDict(mode='slide', crop_size=(3, 3), stride=(2, 2)) + segmentor = build_segmentor(cfg) + _segmentor_forward_train_test(segmentor) + + # test 1 decode head, 1 aux head + cfg = ConfigDict( + type='EncoderDecoder', + backbone=dict(type='ExampleBackbone'), + decode_head=dict(type='ExampleDecodeHead'), + auxiliary_head=dict(type='ExampleDecodeHead')) + cfg.test_cfg = ConfigDict(mode='whole') + segmentor = build_segmentor(cfg) + _segmentor_forward_train_test(segmentor) + + # test 1 decode head, 2 aux head + cfg = ConfigDict( + type='EncoderDecoder', + backbone=dict(type='ExampleBackbone'), + decode_head=dict(type='ExampleDecodeHead'), + auxiliary_head=[ + dict(type='ExampleDecodeHead'), + dict(type='ExampleDecodeHead') + ]) + cfg.test_cfg = ConfigDict(mode='whole') + segmentor = build_segmentor(cfg) + _segmentor_forward_train_test(segmentor) diff --git a/tests/test_models/test_segmentors/utils.py 
b/tests/test_models/test_segmentors/utils.py new file mode 100644 index 0000000000..09f5c76107 --- /dev/null +++ b/tests/test_models/test_segmentors/utils.py @@ -0,0 +1,133 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from mmengine.optim import OptimWrapper +from mmengine.structures import PixelData +from torch import nn +from torch.optim import SGD + +from mmseg.models import SegDataPreProcessor +from mmseg.models.decode_heads.cascade_decode_head import BaseCascadeDecodeHead +from mmseg.models.decode_heads.decode_head import BaseDecodeHead +from mmseg.registry import MODELS +from mmseg.structures import SegDataSample + + +def _demo_mm_inputs(input_shape=(1, 3, 8, 16), num_classes=10): + """Create a superset of inputs needed to run test or train batches. + + Args: + input_shape (tuple): + input batch dimensions + + num_classes (int): + number of semantic classes + """ + (N, C, H, W) = input_shape + + imgs = torch.randn(*input_shape) + segs = torch.randint( + low=0, high=num_classes - 1, size=(N, H, W), dtype=torch.long) + + img_metas = [{ + 'img_shape': (H, W), + 'ori_shape': (H, W), + 'pad_shape': (H, W, C), + 'filename': '.png', + 'scale_factor': 1.0, + 'flip': False, + 'flip_direction': 'horizontal' + } for _ in range(N)] + + data_samples = [ + SegDataSample( + gt_sem_seg=PixelData(data=segs[i]), metainfo=img_metas[i]) + for i in range(N) + ] + + mm_inputs = {'imgs': torch.FloatTensor(imgs), 'data_samples': data_samples} + + return mm_inputs + + +@MODELS.register_module() +class ExampleBackbone(nn.Module): + + def __init__(self): + super().__init__() + self.conv = nn.Conv2d(3, 3, 3) + + def init_weights(self, pretrained=None): + pass + + def forward(self, x): + return [self.conv(x)] + + +@MODELS.register_module() +class ExampleDecodeHead(BaseDecodeHead): + + def __init__(self): + super().__init__(3, 3, num_classes=19) + + def forward(self, inputs): + return self.cls_seg(inputs[0]) + + +@MODELS.register_module() +class 
ExampleCascadeDecodeHead(BaseCascadeDecodeHead): + + def __init__(self): + super().__init__(3, 3, num_classes=19) + + def forward(self, inputs, prev_out): + return self.cls_seg(inputs[0]) + + +def _segmentor_forward_train_test(segmentor): + if isinstance(segmentor.decode_head, nn.ModuleList): + num_classes = segmentor.decode_head[-1].num_classes + else: + num_classes = segmentor.decode_head.num_classes + # uses the default input_shape of _demo_mm_inputs, i.e. batch size 1 + mm_inputs = _demo_mm_inputs(num_classes=num_classes) + + # convert to cuda Tensor if applicable + if torch.cuda.is_available(): + segmentor = segmentor.cuda() + + # check data preprocessor + if not hasattr(segmentor, + 'data_preprocessor') or segmentor.data_preprocessor is None: + segmentor.data_preprocessor = SegDataPreProcessor() + + mm_inputs = segmentor.data_preprocessor(mm_inputs, True) + imgs = mm_inputs.pop('imgs') + data_samples = mm_inputs.pop('data_samples') + + # create optimizer wrapper + optimizer = SGD(segmentor.parameters(), lr=0.1) + optim_wrapper = OptimWrapper(optimizer) + + # Test forward train + losses = segmentor.forward(imgs, data_samples, mode='loss') + assert isinstance(losses, dict) + + # Test train_step + data_batch = dict(inputs=imgs, data_samples=data_samples) + outputs = segmentor.train_step(data_batch, optim_wrapper) + assert isinstance(outputs, dict) + assert 'loss' in outputs + + # Test val_step + with torch.no_grad(): + segmentor.eval() + data_batch = dict(inputs=imgs, data_samples=data_samples) + outputs = segmentor.val_step(data_batch) + assert isinstance(outputs, list) + + # Test forward simple test + with torch.no_grad(): + segmentor.eval() + data_batch = dict(inputs=imgs, data_samples=data_samples) + results = segmentor.forward(imgs, data_samples, mode='tensor') + assert isinstance(results, torch.Tensor) From ec28bd8a3fe56f44226eda4202dfd6d173a3e0bf Mon Sep 17 00:00:00 2001 From: xiexinch Date: Tue, 27 Sep 2022 14:40:52 +0800 Subject: [PATCH 5/6] refine postprocess --- 
mmseg/models/segmentors/base.py | 40 +++++++------------ .../test_segmentors/test_encoder_decoder.py | 12 ++++-- tests/test_models/test_segmentors/utils.py | 5 ++- 3 files changed, 26 insertions(+), 31 deletions(-) diff --git a/mmseg/models/segmentors/base.py b/mmseg/models/segmentors/base.py index 35a58bf542..861ba0287f 100644 --- a/mmseg/models/segmentors/base.py +++ b/mmseg/models/segmentors/base.py @@ -150,7 +150,7 @@ def postprocess_result(self, batch_size, C, H, W = seg_logits.shape if data_samples is None: - data_samples = [] + data_samples = [SegDataSample()] * batch_size only_prediction = True else: only_prediction = False @@ -172,31 +172,19 @@ def postprocess_result(self, mode='bilinear', align_corners=self.align_corners, warning=False).squeeze(0) - # i_seg_logits shape is C, H, W with original shape - if C > 1: - i_seg_pred = i_seg_logits.argmax(dim=0, keepdim=True) - else: - i_seg_pred = (i_seg_logits > - self.decode_head.threshold).to(i_seg_logits) - data_samples[i].set_data({ - 'seg_logits': - PixelData(**{'data': i_seg_logits}), - 'pred_sem_seg': - PixelData(**{'data': i_seg_pred}) - }) else: i_seg_logits = seg_logits[i] - if C > 1: - i_seg_pred = i_seg_logits.argmax(dim=0, keepdim=True) - else: - i_seg_pred = (i_seg_logits > - self.decode_head.threshold).to(i_seg_logits) - prediction = SegDataSample() - prediction.set_data({ - 'seg_logits': - PixelData(**{'data': i_seg_logits}), - 'pred_sem_seg': - PixelData(**{'data': i_seg_pred}) - }) - data_samples.append(prediction) + + if C > 1: + i_seg_pred = i_seg_logits.argmax(dim=0, keepdim=True) + else: + i_seg_pred = (i_seg_logits > + self.decode_head.threshold).to(i_seg_logits) + data_samples[i].set_data({ + 'seg_logits': + PixelData(**{'data': i_seg_logits}), + 'pred_sem_seg': + PixelData(**{'data': i_seg_pred}) + }) + return data_samples diff --git a/tests/test_models/test_segmentors/test_encoder_decoder.py b/tests/test_models/test_segmentors/test_encoder_decoder.py index a86420a6f9..81f89db412 100644 
--- a/tests/test_models/test_segmentors/test_encoder_decoder.py +++ b/tests/test_models/test_segmentors/test_encoder_decoder.py @@ -1,4 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. + from mmengine import ConfigDict from mmseg.models import build_segmentor @@ -19,9 +20,14 @@ def test_encoder_decoder(): _segmentor_forward_train_test(segmentor) # test out_channels == 1 - segmentor.out_channels = 1 - segmentor.decode_head.out_channels = 1 - segmentor.decode_head.threshold = 0.3 + cfg = ConfigDict( + type='EncoderDecoder', + backbone=dict(type='ExampleBackbone'), + decode_head=dict( + type='ExampleDecodeHead', num_classes=2, out_channels=1), + train_cfg=None, + test_cfg=dict(mode='whole')) + segmentor = build_segmentor(cfg) _segmentor_forward_train_test(segmentor) # test slide mode diff --git a/tests/test_models/test_segmentors/utils.py b/tests/test_models/test_segmentors/utils.py index 09f5c76107..9b155c0961 100644 --- a/tests/test_models/test_segmentors/utils.py +++ b/tests/test_models/test_segmentors/utils.py @@ -66,8 +66,9 @@ def forward(self, x): @MODELS.register_module() class ExampleDecodeHead(BaseDecodeHead): - def __init__(self): - super().__init__(3, 3, num_classes=19) + def __init__(self, num_classes=19, out_channels=None): + super().__init__( + 3, 3, num_classes=num_classes, out_channels=out_channels) def forward(self, inputs): return self.cls_seg(inputs[0]) From 4aaa1b909c4e4a9364b916b37d926f49c89405b6 Mon Sep 17 00:00:00 2001 From: xiexinch Date: Tue, 27 Sep 2022 14:57:02 +0800 Subject: [PATCH 6/6] fix --- mmseg/models/segmentors/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mmseg/models/segmentors/base.py b/mmseg/models/segmentors/base.py index 861ba0287f..dfceddd99f 100644 --- a/mmseg/models/segmentors/base.py +++ b/mmseg/models/segmentors/base.py @@ -150,7 +150,7 @@ def postprocess_result(self, batch_size, C, H, W = seg_logits.shape if data_samples is None: - data_samples = [SegDataSample()] * batch_size + 
data_samples = [SegDataSample() for _ in range(batch_size)] only_prediction = True else: only_prediction = False