Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit ab5ebc8

Browse files
author
谢昕辰
authoredMar 29, 2021
Merge 22f9fb4 into 340132d
2 parents 340132d + 22f9fb4 commit ab5ebc8

File tree

2 files changed

+19
-27
lines changed

2 files changed

+19
-27
lines changed
 

‎mmseg/models/losses/dice_loss.py

+17-14
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ def dice_loss(pred,
1515
smooth=1,
1616
exponent=2,
1717
class_weight=None,
18-
ignore_index=-1):
18+
ignore_index=255):
1919
assert pred.shape[0] == target.shape[0]
2020
total_loss = 0
2121
num_classes = pred.shape[1]
@@ -36,9 +36,9 @@ def dice_loss(pred,
3636
@weighted_loss
3737
def binary_dice_loss(pred, target, valid_mask, smooth=1, exponent=2, **kwards):
3838
assert pred.shape[0] == target.shape[0]
39-
pred = pred.contiguous().view(pred.shape[0], -1)
40-
target = target.contiguous().view(target.shape[0], -1)
41-
valid_mask = valid_mask.contiguous().view(valid_mask.shape[0], -1)
39+
pred = pred.reshape(pred.shape[0], -1)
40+
target = target.reshape(target.shape[0], -1)
41+
valid_mask = valid_mask.reshape(valid_mask.shape[0], -1)
4242

4343
num = torch.sum(torch.mul(pred, target) * valid_mask, dim=1) * 2 + smooth
4444
den = torch.sum(pred.pow(exponent) + target.pow(exponent), dim=1) + smooth
@@ -70,27 +70,27 @@ class DiceLoss(nn.Module):
7070
"""
7171

7272
def __init__(self,
73-
loss_type='multi_class',
7473
smooth=1,
7574
exponent=2,
7675
reduction='mean',
7776
class_weight=None,
7877
loss_weight=1.0,
79-
ignore_index=255):
78+
ignore_index=255,
79+
**kwards):
8080
super(DiceLoss, self).__init__()
81-
assert loss_type in ['multi_class', 'binary']
82-
if loss_type == 'multi_class':
83-
self.cls_criterion = dice_loss
84-
else:
85-
self.cls_criterion = binary_dice_loss
8681
self.smooth = smooth
8782
self.exponent = exponent
8883
self.reduction = reduction
8984
self.class_weight = class_weight
9085
self.loss_weight = loss_weight
9186
self.ignore_index = ignore_index
9287

93-
def forward(self, pred, target, avg_factor=None, reduction_override=None):
88+
def forward(self,
89+
pred,
90+
target,
91+
avg_factor=None,
92+
reduction_override=None,
93+
**kwards):
9494
assert reduction_override in (None, 'none', 'mean', 'sum')
9595
reduction = (
9696
reduction_override if reduction_override else self.reduction)
@@ -100,10 +100,13 @@ def forward(self, pred, target, avg_factor=None, reduction_override=None):
100100
class_weight = None
101101

102102
pred = F.softmax(pred, dim=1)
103-
one_hot_target = F.one_hot(torch.clamp_min(target.long(), 0))
103+
num_classes = pred.shape[1]
104+
one_hot_target = F.one_hot(
105+
torch.clamp(target.long(), 0, num_classes - 1),
106+
num_classes=num_classes)
104107
valid_mask = (target != self.ignore_index).long()
105108

106-
loss = self.loss_weight * self.cls_criterion(
109+
loss = self.loss_weight * dice_loss(
107110
pred,
108111
one_hot_target,
109112
valid_mask=valid_mask,

‎tests/test_models/test_losses.py

+2-13
Original file line numberDiff line numberDiff line change
@@ -207,19 +207,9 @@ def test_lovasz_loss():
207207
def test_dice_lose():
208208
from mmseg.models import build_loss
209209

210-
# loss_type should be 'binary' or 'multi_class'
211-
with pytest.raises(AssertionError):
212-
loss_cfg = dict(
213-
type='DiceLoss',
214-
loss_type='Binary',
215-
reduction='none',
216-
loss_weight=1.0)
217-
build_loss(loss_cfg)
218-
219210
# test dice loss with loss_type = 'multi_class'
220211
loss_cfg = dict(
221212
type='DiceLoss',
222-
loss_type='multi_class',
223213
reduction='none',
224214
class_weight=[1.0, 2.0, 3.0],
225215
loss_weight=1.0,
@@ -232,13 +222,12 @@ def test_dice_lose():
232222
# test dice loss with loss_type = 'binary'
233223
loss_cfg = dict(
234224
type='DiceLoss',
235-
loss_type='binary',
236225
smooth=2,
237226
exponent=3,
238227
reduction='sum',
239228
loss_weight=1.0,
240229
ignore_index=0)
241230
dice_loss = build_loss(loss_cfg)
242-
logits = torch.rand(16, 4, 4)
243-
labels = (torch.rand(16, 4, 4)).long()
231+
logits = torch.rand(8, 2, 4, 4)
232+
labels = (torch.rand(8, 4, 4) * 2).long()
244233
dice_loss(logits, labels)

0 commit comments

Comments
 (0)
Please sign in to comment.