Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Fix] Change self.loss_decode back to dict in Single Loss situation. #1002

Merged
merged 9 commits into from
Nov 1, 2021
Merged
16 changes: 13 additions & 3 deletions mmseg/core/seg/sampler/ohem_pixel_sampler.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F

from ..builder import PIXEL_SAMPLERS
Expand Down Expand Up @@ -62,14 +63,23 @@ def sample(self, seg_logit, seg_label):
threshold = max(min_threshold, self.thresh)
valid_seg_weight[seg_prob[valid_mask] < threshold] = 1.
else:
losses = 0.0
for loss_module in self.context.loss_decode:
losses += loss_module(
if isinstance(self.context.loss_decode, nn.ModuleList):
losses = 0.0
for loss_module in self.context.loss_decode:
losses += loss_module(
seg_logit,
seg_label,
weight=None,
ignore_index=self.context.ignore_index,
reduction_override='none')
else:
losses = self.context.loss_decode(
seg_logit,
seg_label,
weight=None,
ignore_index=self.context.ignore_index,
reduction_override='none')

# faster than topk according to https://github.com/pytorch/pytorch/issues/22812 # noqa
_, sort_indices = losses[valid_mask].sort(descending=True)
valid_seg_weight[sort_indices[:batch_kept]] = 1.
Expand Down
38 changes: 23 additions & 15 deletions mmseg/models/decode_heads/decode_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,11 +83,11 @@ def __init__(self,

self.ignore_index = ignore_index
self.align_corners = align_corners
self.loss_decode = nn.ModuleList()

if isinstance(loss_decode, dict):
self.loss_decode.append(build_loss(loss_decode))
self.loss_decode = build_loss(loss_decode)
elif isinstance(loss_decode, (list, tuple)):
self.loss_decode = nn.ModuleList()
for loss in loss_decode:
self.loss_decode.append(build_loss(loss))
else:
Expand Down Expand Up @@ -242,19 +242,27 @@ def losses(self, seg_logit, seg_label):
else:
seg_weight = None
seg_label = seg_label.squeeze(1)
for loss_decode in self.loss_decode:
if loss_decode.loss_name not in loss:
loss[loss_decode.loss_name] = loss_decode(
seg_logit,
seg_label,
weight=seg_weight,
ignore_index=self.ignore_index)
else:
loss[loss_decode.loss_name] += loss_decode(
seg_logit,
seg_label,
weight=seg_weight,
ignore_index=self.ignore_index)

if isinstance(self.loss_decode, nn.ModuleList):
for loss_decode in self.loss_decode:
if loss_decode.loss_name not in loss:
loss[loss_decode.loss_name] = loss_decode(
seg_logit,
seg_label,
weight=seg_weight,
ignore_index=self.ignore_index)
else:
loss[loss_decode.loss_name] += loss_decode(
seg_logit,
seg_label,
weight=seg_weight,
ignore_index=self.ignore_index)
else:
loss[self.loss_decode.loss_name] = self.loss_decode(
seg_logit,
seg_label,
weight=seg_weight,
ignore_index=self.ignore_index)

loss['acc_seg'] = accuracy(seg_logit, seg_label)
return loss
8 changes: 6 additions & 2 deletions mmseg/models/decode_heads/point_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,8 +249,12 @@ def forward_test(self, inputs, prev_output, img_metas, test_cfg):
def losses(self, point_logits, point_label):
"""Compute segmentation loss."""
loss = dict()
for loss_module in self.loss_decode:
loss['point' + loss_module.loss_name] = loss_module(
if isinstance(self.loss_decode, nn.ModuleList):
for loss_module in self.loss_decode:
loss['point' + loss_module.loss_name] = loss_module(
point_logits, point_label, ignore_index=self.ignore_index)
else:
loss['point' + self.loss_decode.loss_name] = self.loss_decode(
point_logits, point_label, ignore_index=self.ignore_index)
loss['acc_point'] = accuracy(point_logits, point_label)
return loss
Expand Down
39 changes: 39 additions & 0 deletions tests/test_models/test_heads/test_point_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,42 @@ def test_point_head():
subdivision_steps=2, subdivision_num_points=8196, scale_factor=2)
output = point_head.forward_test(inputs, prev_output, None, test_cfg)
assert output.shape == (1, point_head.num_classes, 180, 180)


def test_point_head_multiple_loss():

inputs = [torch.randn(1, 32, 45, 45)]
point_head = PointHead(
in_channels=[32],
in_index=[0],
channels=16,
num_classes=19,
loss_decode=[
dict(type='CrossEntropyLoss', loss_name='loss_1'),
dict(type='CrossEntropyLoss', loss_name='loss_2')
])
assert len(point_head.fcs) == 3
fcn_head = FCNHead(
in_channels=32,
channels=16,
num_classes=19,
loss_decode=[
dict(type='CrossEntropyLoss', loss_name='loss_1'),
dict(type='CrossEntropyLoss', loss_name='loss_2')
])
if torch.cuda.is_available():
head, inputs = to_cuda(point_head, inputs)
head, inputs = to_cuda(fcn_head, inputs)
prev_output = fcn_head(inputs)
test_cfg = ConfigDict(
subdivision_steps=2, subdivision_num_points=8196, scale_factor=2)
output = point_head.forward_test(inputs, prev_output, None, test_cfg)
assert output.shape == (1, point_head.num_classes, 180, 180)

fake_label = torch.ones([1, 180, 180], dtype=torch.long)

if torch.cuda.is_available():
fake_label = fake_label.cuda()
loss = point_head.losses(output, fake_label)
assert 'pointloss_1' in loss
assert 'pointloss_2' in loss
40 changes: 40 additions & 0 deletions tests/test_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,17 @@ def _context_for_ohem():
return FCNHead(in_channels=32, channels=16, num_classes=19)


def _context_for_ohem_multiple_loss():
return FCNHead(
in_channels=32,
channels=16,
num_classes=19,
loss_decode=[
dict(type='CrossEntropyLoss', loss_name='loss_1'),
dict(type='CrossEntropyLoss', loss_name='loss_2')
])


def test_ohem_sampler():

with pytest.raises(AssertionError):
Expand Down Expand Up @@ -37,3 +48,32 @@ def test_ohem_sampler():
assert seg_weight.shape[0] == seg_logit.shape[0]
assert seg_weight.shape[1:] == seg_logit.shape[2:]
assert seg_weight.sum() == 200


def test_ohem_sampler_multiple_loss():
with pytest.raises(AssertionError):
# seg_logit and seg_label must be of the same size
sampler = OHEMPixelSampler(context=_context_for_ohem_multiple_loss())
seg_logit = torch.randn(1, 19, 45, 45)
seg_label = torch.randint(0, 19, size=(1, 1, 89, 89))
sampler.sample(seg_logit, seg_label)

# test with thresh
sampler = OHEMPixelSampler(
context=_context_for_ohem_multiple_loss(), thresh=0.7, min_kept=200)
seg_logit = torch.randn(1, 19, 45, 45)
seg_label = torch.randint(0, 19, size=(1, 1, 45, 45))
seg_weight = sampler.sample(seg_logit, seg_label)
assert seg_weight.shape[0] == seg_logit.shape[0]
assert seg_weight.shape[1:] == seg_logit.shape[2:]
assert seg_weight.sum() > 200

# test w.o thresh
sampler = OHEMPixelSampler(
context=_context_for_ohem_multiple_loss(), min_kept=200)
seg_logit = torch.randn(1, 19, 45, 45)
seg_label = torch.randint(0, 19, size=(1, 1, 45, 45))
seg_weight = sampler.sample(seg_logit, seg_label)
assert seg_weight.shape[0] == seg_logit.shape[0]
assert seg_weight.shape[1:] == seg_logit.shape[2:]
assert seg_weight.sum() == 200