Skip to content

Commit 0ee3016

Browse files
authored
[New Feature]add lovasz loss (#351)
* add lovasz loss * Modify as comments * Modify paper url * add unittest and remove Var * impove unittest
1 parent 1c96a89 commit 0ee3016

File tree

3 files changed

+365
-1
lines changed

3 files changed

+365
-1
lines changed

mmseg/models/losses/__init__.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
from .accuracy import Accuracy, accuracy
22
from .cross_entropy_loss import (CrossEntropyLoss, binary_cross_entropy,
33
cross_entropy, mask_cross_entropy)
4+
from .lovasz_loss import LovaszLoss
45
from .utils import reduce_loss, weight_reduce_loss, weighted_loss
56

67
__all__ = [
78
'accuracy', 'Accuracy', 'cross_entropy', 'binary_cross_entropy',
89
'mask_cross_entropy', 'CrossEntropyLoss', 'reduce_loss',
9-
'weight_reduce_loss', 'weighted_loss'
10+
'weight_reduce_loss', 'weighted_loss', 'LovaszLoss'
1011
]

mmseg/models/losses/lovasz_loss.py

+303
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,303 @@
1+
"""Modified from https://github.com/bermanmaxim/LovaszSoftmax/blob/master/pytor
2+
ch/lovasz_losses.py Lovasz-Softmax and Jaccard hinge loss in PyTorch Maxim
3+
Berman 2018 ESAT-PSI KU Leuven (MIT License)"""
4+
5+
import mmcv
6+
import torch
7+
import torch.nn as nn
8+
import torch.nn.functional as F
9+
10+
from ..builder import LOSSES
11+
from .utils import weight_reduce_loss
12+
13+
14+
def lovasz_grad(gt_sorted):
15+
"""Computes gradient of the Lovasz extension w.r.t sorted errors.
16+
17+
See Alg. 1 in paper.
18+
"""
19+
p = len(gt_sorted)
20+
gts = gt_sorted.sum()
21+
intersection = gts - gt_sorted.float().cumsum(0)
22+
union = gts + (1 - gt_sorted).float().cumsum(0)
23+
jaccard = 1. - intersection / union
24+
if p > 1: # cover 1-pixel case
25+
jaccard[1:p] = jaccard[1:p] - jaccard[0:-1]
26+
return jaccard
27+
28+
29+
def flatten_binary_logits(logits, labels, ignore_index=None):
30+
"""Flattens predictions in the batch (binary case) Remove labels equal to
31+
'ignore_index'."""
32+
logits = logits.view(-1)
33+
labels = labels.view(-1)
34+
if ignore_index is None:
35+
return logits, labels
36+
valid = (labels != ignore_index)
37+
vlogits = logits[valid]
38+
vlabels = labels[valid]
39+
return vlogits, vlabels
40+
41+
42+
def flatten_probs(probs, labels, ignore_index=None):
43+
"""Flattens predictions in the batch."""
44+
if probs.dim() == 3:
45+
# assumes output of a sigmoid layer
46+
B, H, W = probs.size()
47+
probs = probs.view(B, 1, H, W)
48+
B, C, H, W = probs.size()
49+
probs = probs.permute(0, 2, 3, 1).contiguous().view(-1, C) # B*H*W, C=P,C
50+
labels = labels.view(-1)
51+
if ignore_index is None:
52+
return probs, labels
53+
valid = (labels != ignore_index)
54+
vprobs = probs[valid.nonzero().squeeze()]
55+
vlabels = labels[valid]
56+
return vprobs, vlabels
57+
58+
59+
def lovasz_hinge_flat(logits, labels):
60+
"""Binary Lovasz hinge loss.
61+
62+
Args:
63+
logits (torch.Tensor): [P], logits at each prediction
64+
(between -infty and +infty).
65+
labels (torch.Tensor): [P], binary ground truth labels (0 or 1).
66+
67+
Returns:
68+
torch.Tensor: The calculated loss.
69+
"""
70+
if len(labels) == 0:
71+
# only void pixels, the gradients should be 0
72+
return logits.sum() * 0.
73+
signs = 2. * labels.float() - 1.
74+
errors = (1. - logits * signs)
75+
errors_sorted, perm = torch.sort(errors, dim=0, descending=True)
76+
perm = perm.data
77+
gt_sorted = labels[perm]
78+
grad = lovasz_grad(gt_sorted)
79+
loss = torch.dot(F.relu(errors_sorted), grad)
80+
return loss
81+
82+
83+
def lovasz_hinge(logits,
84+
labels,
85+
classes='present',
86+
per_image=False,
87+
class_weight=None,
88+
reduction='mean',
89+
avg_factor=None,
90+
ignore_index=255):
91+
"""Binary Lovasz hinge loss.
92+
93+
Args:
94+
logits (torch.Tensor): [B, H, W], logits at each pixel
95+
(between -infty and +infty).
96+
labels (torch.Tensor): [B, H, W], binary ground truth masks (0 or 1).
97+
classes (str | list[int], optional): Placeholder, to be consistent with
98+
other loss. Default: None.
99+
per_image (bool, optional): If per_image is True, compute the loss per
100+
image instead of per batch. Default: False.
101+
class_weight (list[float], optional): Placeholder, to be consistent
102+
with other loss. Default: None.
103+
reduction (str, optional): The method used to reduce the loss. Options
104+
are "none", "mean" and "sum". This parameter only works when
105+
per_image is True. Default: 'mean'.
106+
avg_factor (int, optional): Average factor that is used to average
107+
the loss. This parameter only works when per_image is True.
108+
Default: None.
109+
ignore_index (int | None): The label index to be ignored. Default: 255.
110+
111+
Returns:
112+
torch.Tensor: The calculated loss.
113+
"""
114+
if per_image:
115+
loss = [
116+
lovasz_hinge_flat(*flatten_binary_logits(
117+
logit.unsqueeze(0), label.unsqueeze(0), ignore_index))
118+
for logit, label in zip(logits, labels)
119+
]
120+
loss = weight_reduce_loss(
121+
torch.stack(loss), None, reduction, avg_factor)
122+
else:
123+
loss = lovasz_hinge_flat(
124+
*flatten_binary_logits(logits, labels, ignore_index))
125+
return loss
126+
127+
128+
def lovasz_softmax_flat(probs, labels, classes='present', class_weight=None):
129+
"""Multi-class Lovasz-Softmax loss.
130+
131+
Args:
132+
probs (torch.Tensor): [P, C], class probabilities at each prediction
133+
(between 0 and 1).
134+
labels (torch.Tensor): [P], ground truth labels (between 0 and C - 1).
135+
classes (str | list[int], optional): Classes choosed to calculate loss.
136+
'all' for all classes, 'present' for classes present in labels, or
137+
a list of classes to average. Default: 'present'.
138+
class_weight (list[float], optional): The weight for each class.
139+
Default: None.
140+
141+
Returns:
142+
torch.Tensor: The calculated loss.
143+
"""
144+
if probs.numel() == 0:
145+
# only void pixels, the gradients should be 0
146+
return probs * 0.
147+
C = probs.size(1)
148+
losses = []
149+
class_to_sum = list(range(C)) if classes in ['all', 'present'] else classes
150+
for c in class_to_sum:
151+
fg = (labels == c).float() # foreground for class c
152+
if (classes == 'present' and fg.sum() == 0):
153+
continue
154+
if C == 1:
155+
if len(classes) > 1:
156+
raise ValueError('Sigmoid output possible only with 1 class')
157+
class_pred = probs[:, 0]
158+
else:
159+
class_pred = probs[:, c]
160+
errors = (fg - class_pred).abs()
161+
errors_sorted, perm = torch.sort(errors, 0, descending=True)
162+
perm = perm.data
163+
fg_sorted = fg[perm]
164+
loss = torch.dot(errors_sorted, lovasz_grad(fg_sorted))
165+
if class_weight is not None:
166+
loss *= class_weight[c]
167+
losses.append(loss)
168+
return torch.stack(losses).mean()
169+
170+
171+
def lovasz_softmax(probs,
172+
labels,
173+
classes='present',
174+
per_image=False,
175+
class_weight=None,
176+
reduction='mean',
177+
avg_factor=None,
178+
ignore_index=255):
179+
"""Multi-class Lovasz-Softmax loss.
180+
181+
Args:
182+
probs (torch.Tensor): [B, C, H, W], class probabilities at each
183+
prediction (between 0 and 1).
184+
labels (torch.Tensor): [B, H, W], ground truth labels (between 0 and
185+
C - 1).
186+
classes (str | list[int], optional): Classes choosed to calculate loss.
187+
'all' for all classes, 'present' for classes present in labels, or
188+
a list of classes to average. Default: 'present'.
189+
per_image (bool, optional): If per_image is True, compute the loss per
190+
image instead of per batch. Default: False.
191+
class_weight (list[float], optional): The weight for each class.
192+
Default: None.
193+
reduction (str, optional): The method used to reduce the loss. Options
194+
are "none", "mean" and "sum". This parameter only works when
195+
per_image is True. Default: 'mean'.
196+
avg_factor (int, optional): Average factor that is used to average
197+
the loss. This parameter only works when per_image is True.
198+
Default: None.
199+
ignore_index (int | None): The label index to be ignored. Default: 255.
200+
201+
Returns:
202+
torch.Tensor: The calculated loss.
203+
"""
204+
205+
if per_image:
206+
loss = [
207+
lovasz_softmax_flat(
208+
*flatten_probs(
209+
prob.unsqueeze(0), label.unsqueeze(0), ignore_index),
210+
classes=classes,
211+
class_weight=class_weight)
212+
for prob, label in zip(probs, labels)
213+
]
214+
loss = weight_reduce_loss(
215+
torch.stack(loss), None, reduction, avg_factor)
216+
else:
217+
loss = lovasz_softmax_flat(
218+
*flatten_probs(probs, labels, ignore_index),
219+
classes=classes,
220+
class_weight=class_weight)
221+
return loss
222+
223+
224+
@LOSSES.register_module()
225+
class LovaszLoss(nn.Module):
226+
"""LovaszLoss.
227+
228+
This loss is proposed in `The Lovasz-Softmax loss: A tractable surrogate
229+
for the optimization of the intersection-over-union measure in neural
230+
networks <https://arxiv.org/abs/1705.08790>`_.
231+
232+
Args:
233+
loss_type (str, optional): Binary or multi-class loss.
234+
Default: 'multi_class'. Options are "binary" and "multi_class".
235+
classes (str | list[int], optional): Classes choosed to calculate loss.
236+
'all' for all classes, 'present' for classes present in labels, or
237+
a list of classes to average. Default: 'present'.
238+
per_image (bool, optional): If per_image is True, compute the loss per
239+
image instead of per batch. Default: False.
240+
reduction (str, optional): The method used to reduce the loss. Options
241+
are "none", "mean" and "sum". This parameter only works when
242+
per_image is True. Default: 'mean'.
243+
class_weight (list[float], optional): The weight for each class.
244+
Default: None.
245+
loss_weight (float, optional): Weight of the loss. Defaults to 1.0.
246+
"""
247+
248+
def __init__(self,
249+
loss_type='multi_class',
250+
classes='present',
251+
per_image=False,
252+
reduction='mean',
253+
class_weight=None,
254+
loss_weight=1.0):
255+
super(LovaszLoss, self).__init__()
256+
assert loss_type in ('binary', 'multi_class'), "loss_type should be \
257+
'binary' or 'multi_class'."
258+
259+
if loss_type == 'binary':
260+
self.cls_criterion = lovasz_hinge
261+
else:
262+
self.cls_criterion = lovasz_softmax
263+
assert classes in ('all', 'present') or mmcv.is_list_of(classes, int)
264+
if not per_image:
265+
assert reduction == 'none', "reduction should be 'none' when \
266+
per_image is False."
267+
268+
self.classes = classes
269+
self.per_image = per_image
270+
self.reduction = reduction
271+
self.loss_weight = loss_weight
272+
self.class_weight = class_weight
273+
274+
def forward(self,
275+
cls_score,
276+
label,
277+
weight=None,
278+
avg_factor=None,
279+
reduction_override=None,
280+
**kwargs):
281+
"""Forward function."""
282+
assert reduction_override in (None, 'none', 'mean', 'sum')
283+
reduction = (
284+
reduction_override if reduction_override else self.reduction)
285+
if self.class_weight is not None:
286+
class_weight = cls_score.new_tensor(self.class_weight)
287+
else:
288+
class_weight = None
289+
290+
# if multi-class loss, transform logits to probs
291+
if self.cls_criterion == lovasz_softmax:
292+
cls_score = F.softmax(cls_score, dim=1)
293+
294+
loss_cls = self.loss_weight * self.cls_criterion(
295+
cls_score,
296+
label,
297+
self.classes,
298+
self.per_image,
299+
class_weight=class_weight,
300+
reduction=reduction,
301+
avg_factor=avg_factor,
302+
**kwargs)
303+
return loss_cls

tests/test_models/test_losses.py

+60
Original file line numberDiff line numberDiff line change
@@ -142,3 +142,63 @@ def test_accuracy():
142142
with pytest.raises(AssertionError):
143143
accuracy = Accuracy()
144144
accuracy(pred[:, :, None], true_label)
145+
146+
147+
def test_lovasz_loss():
148+
from mmseg.models import build_loss
149+
150+
# loss_type should be 'binary' or 'multi_class'
151+
with pytest.raises(AssertionError):
152+
loss_cfg = dict(
153+
type='LovaszLoss',
154+
loss_type='Binary',
155+
reduction='none',
156+
loss_weight=1.0)
157+
build_loss(loss_cfg)
158+
159+
# reduction should be 'none' when per_image is False.
160+
with pytest.raises(AssertionError):
161+
loss_cfg = dict(type='LovaszLoss', loss_type='multi_class')
162+
build_loss(loss_cfg)
163+
164+
# test lovasz loss with loss_type = 'multi_class' and per_image = False
165+
loss_cfg = dict(type='LovaszLoss', reduction='none', loss_weight=1.0)
166+
lovasz_loss = build_loss(loss_cfg)
167+
logits = torch.rand(1, 3, 4, 4)
168+
labels = (torch.rand(1, 4, 4) * 2).long()
169+
lovasz_loss(logits, labels)
170+
171+
# test lovasz loss with loss_type = 'multi_class' and per_image = True
172+
loss_cfg = dict(
173+
type='LovaszLoss',
174+
per_image=True,
175+
reduction='mean',
176+
class_weight=[1.0, 2.0, 3.0],
177+
loss_weight=1.0)
178+
lovasz_loss = build_loss(loss_cfg)
179+
logits = torch.rand(1, 3, 4, 4)
180+
labels = (torch.rand(1, 4, 4) * 2).long()
181+
lovasz_loss(logits, labels, ignore_index=None)
182+
183+
# test lovasz loss with loss_type = 'binary' and per_image = False
184+
loss_cfg = dict(
185+
type='LovaszLoss',
186+
loss_type='binary',
187+
reduction='none',
188+
loss_weight=1.0)
189+
lovasz_loss = build_loss(loss_cfg)
190+
logits = torch.rand(2, 4, 4)
191+
labels = (torch.rand(2, 4, 4)).long()
192+
lovasz_loss(logits, labels)
193+
194+
# test lovasz loss with loss_type = 'binary' and per_image = True
195+
loss_cfg = dict(
196+
type='LovaszLoss',
197+
loss_type='binary',
198+
per_image=True,
199+
reduction='mean',
200+
loss_weight=1.0)
201+
lovasz_loss = build_loss(loss_cfg)
202+
logits = torch.rand(2, 4, 4)
203+
labels = (torch.rand(2, 4, 4)).long()
204+
lovasz_loss(logits, labels, ignore_index=None)

0 commit comments

Comments
 (0)