diff --git a/mmseg/core/seg/sampler/ohem_pixel_sampler.py b/mmseg/core/seg/sampler/ohem_pixel_sampler.py index 72ba941f03..833a28768c 100644 --- a/mmseg/core/seg/sampler/ohem_pixel_sampler.py +++ b/mmseg/core/seg/sampler/ohem_pixel_sampler.py @@ -1,5 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. import torch +import torch.nn as nn import torch.nn.functional as F from ..builder import PIXEL_SAMPLERS @@ -62,14 +63,19 @@ def sample(self, seg_logit, seg_label): threshold = max(min_threshold, self.thresh) valid_seg_weight[seg_prob[valid_mask] < threshold] = 1. else: + if not isinstance(self.context.loss_decode, nn.ModuleList): + losses_decode = [self.context.loss_decode] + else: + losses_decode = self.context.loss_decode losses = 0.0 - for loss_module in self.context.loss_decode: + for loss_module in losses_decode: losses += loss_module( seg_logit, seg_label, weight=None, ignore_index=self.context.ignore_index, reduction_override='none') + # faster than topk according to https://github.com/pytorch/pytorch/issues/22812 # noqa _, sort_indices = losses[valid_mask].sort(descending=True) valid_seg_weight[sort_indices[:batch_kept]] = 1. diff --git a/mmseg/models/decode_heads/decode_head.py b/mmseg/models/decode_heads/decode_head.py index c36555eaf2..1443a81da2 100644 --- a/mmseg/models/decode_heads/decode_head.py +++ b/mmseg/models/decode_heads/decode_head.py @@ -83,11 +83,11 @@ def __init__(self, self.ignore_index = ignore_index self.align_corners = align_corners - self.loss_decode = nn.ModuleList() if isinstance(loss_decode, dict): - self.loss_decode.append(build_loss(loss_decode)) + self.loss_decode = build_loss(loss_decode) elif isinstance(loss_decode, (list, tuple)): + self.loss_decode = nn.ModuleList() for loss in loss_decode: self.loss_decode.append(build_loss(loss)) else: @@ -242,7 +242,12 @@ def losses(self, seg_logit, seg_label): else: seg_weight = None seg_label = seg_label.squeeze(1) - for loss_decode in self.loss_decode: + + if not isinstance(self.loss_decode, nn.ModuleList): + losses_decode = [self.loss_decode] + else: + losses_decode = self.loss_decode + for loss_decode in losses_decode: if loss_decode.loss_name not in loss: loss[loss_decode.loss_name] = loss_decode( seg_logit, diff --git a/mmseg/models/decode_heads/point_head.py b/mmseg/models/decode_heads/point_head.py index 56dfd4ed8b..7276218053 100644 --- a/mmseg/models/decode_heads/point_head.py +++ b/mmseg/models/decode_heads/point_head.py @@ -249,9 +249,14 @@ def forward_test(self, inputs, prev_output, img_metas, test_cfg): def losses(self, point_logits, point_label): """Compute segmentation loss.""" loss = dict() - for loss_module in self.loss_decode: + if not isinstance(self.loss_decode, nn.ModuleList): + losses_decode = [self.loss_decode] + else: + losses_decode = self.loss_decode + for loss_module in losses_decode: loss['point' + loss_module.loss_name] = loss_module( point_logits, point_label, ignore_index=self.ignore_index) + loss['acc_point'] = accuracy(point_logits, point_label) return loss diff --git a/tests/test_models/test_heads/test_point_head.py b/tests/test_models/test_heads/test_point_head.py index 6c5ea65768..142ab16c6c 100644 --- a/tests/test_models/test_heads/test_point_head.py +++ b/tests/test_models/test_heads/test_point_head.py @@ -21,3 +21,41 @@ def test_point_head(): subdivision_steps=2, subdivision_num_points=8196, scale_factor=2) output = point_head.forward_test(inputs, prev_output, None, test_cfg) assert output.shape == (1, point_head.num_classes, 180, 180) + + # test multiple losses case + inputs = [torch.randn(1, 32, 45, 45)] + point_head_multiple_losses = PointHead( + in_channels=[32], + in_index=[0], + channels=16, + num_classes=19, + loss_decode=[ + dict(type='CrossEntropyLoss', loss_name='loss_1'), + dict(type='CrossEntropyLoss', loss_name='loss_2') + ]) + assert len(point_head_multiple_losses.fcs) == 3 + fcn_head_multiple_losses = FCNHead( + in_channels=32, + channels=16, + num_classes=19, + loss_decode=[ + dict(type='CrossEntropyLoss', loss_name='loss_1'), + dict(type='CrossEntropyLoss', loss_name='loss_2') + ]) + if torch.cuda.is_available(): + head, inputs = to_cuda(point_head_multiple_losses, inputs) + head, inputs = to_cuda(fcn_head_multiple_losses, inputs) + prev_output = fcn_head_multiple_losses(inputs) + test_cfg = ConfigDict( + subdivision_steps=2, subdivision_num_points=8196, scale_factor=2) + output = point_head_multiple_losses.forward_test(inputs, prev_output, None, + test_cfg) + assert output.shape == (1, point_head.num_classes, 180, 180) + + fake_label = torch.ones([1, 180, 180], dtype=torch.long) + + if torch.cuda.is_available(): + fake_label = fake_label.cuda() + loss = point_head_multiple_losses.losses(output, fake_label) + assert 'pointloss_1' in loss + assert 'pointloss_2' in loss diff --git a/tests/test_sampler.py b/tests/test_sampler.py index 8e613a5a1f..14092243f5 100644 --- a/tests/test_sampler.py +++ b/tests/test_sampler.py @@ -10,6 +10,17 @@ def _context_for_ohem(): return FCNHead(in_channels=32, channels=16, num_classes=19) +def _context_for_ohem_multiple_loss(): + return FCNHead( + in_channels=32, + channels=16, + num_classes=19, + loss_decode=[ + dict(type='CrossEntropyLoss', loss_name='loss_1'), + dict(type='CrossEntropyLoss', loss_name='loss_2') + ]) + + def test_ohem_sampler(): with pytest.raises(AssertionError): @@ -37,3 +48,31 @@ def test_ohem_sampler(): assert seg_weight.shape[0] == seg_logit.shape[0] assert seg_weight.shape[1:] == seg_logit.shape[2:] assert seg_weight.sum() == 200 + + # test multiple losses case + with pytest.raises(AssertionError): + # seg_logit and seg_label must be of the same size + sampler = OHEMPixelSampler(context=_context_for_ohem_multiple_loss()) + seg_logit = torch.randn(1, 19, 45, 45) + seg_label = torch.randint(0, 19, size=(1, 1, 89, 89)) + sampler.sample(seg_logit, seg_label) + + # test with thresh in multiple losses case + sampler = OHEMPixelSampler( + context=_context_for_ohem_multiple_loss(), thresh=0.7, min_kept=200) + seg_logit = torch.randn(1, 19, 45, 45) + seg_label = torch.randint(0, 19, size=(1, 1, 45, 45)) + seg_weight = sampler.sample(seg_logit, seg_label) + assert seg_weight.shape[0] == seg_logit.shape[0] + assert seg_weight.shape[1:] == seg_logit.shape[2:] + assert seg_weight.sum() > 200 + + # test w.o thresh in multiple losses case + sampler = OHEMPixelSampler( + context=_context_for_ohem_multiple_loss(), min_kept=200) + seg_logit = torch.randn(1, 19, 45, 45) + seg_label = torch.randint(0, 19, size=(1, 1, 45, 45)) + seg_weight = sampler.sample(seg_logit, seg_label) + assert seg_weight.shape[0] == seg_logit.shape[0] + assert seg_weight.shape[1:] == seg_logit.shape[2:] + assert seg_weight.sum() == 200