Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

dice loss #396

Merged
merged 4 commits into from
Mar 11, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion mmseg/models/losses/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@
from .cross_entropy_loss import (CrossEntropyLoss, binary_cross_entropy,
cross_entropy, mask_cross_entropy)
from .lovasz_loss import LovaszLoss
from .dice_loss import DiceLoss
from .utils import reduce_loss, weight_reduce_loss, weighted_loss

__all__ = [
'accuracy', 'Accuracy', 'cross_entropy', 'binary_cross_entropy',
'mask_cross_entropy', 'CrossEntropyLoss', 'reduce_loss',
'weight_reduce_loss', 'weighted_loss', 'LovaszLoss'
'weight_reduce_loss', 'weighted_loss', 'LovaszLoss', 'DiceLoss'
]
97 changes: 97 additions & 0 deletions mmseg/models/losses/dice_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
import torch
xvjiarui marked this conversation as resolved.
Show resolved Hide resolved
import torch.nn as nn
import torch.nn.functional as F

from ..builder import LOSSES
from .utils import weighted_loss

@weighted_loss
def dice_loss(pred,
target,
valid_mask,
smooth=1,
exponent=2,
class_weight=None,
ignore_index=-1):
assert pred.shape[0] == target.shape[0]
total_loss = 0
num_classes = pred.shape[1]
for i in range(num_classes):

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use for loop might be inefficient? Some implementation support to process multi-class in a batched manner.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if i != ignore_index:
dice_loss = binary_dice_loss(pred[:, i], target[..., i], valid_mask=valid_mask, smooth=smooth, exponent=exponent)
if class_weight is not None:
dice_loss *= class_weight[i]
total_loss += dice_loss
return total_loss / num_classes

@weighted_loss
def binary_dice_loss(pred,
target,
valid_mask,
smooth=1,
exponent=2,
**kwards):
assert pred.shape[0] == target.shape[0]
pred = pred.contiguous().view(pred.shape[0], -1)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

.contiguous().view() can be replaced by reshape?

target = target.contiguous().view(target.shape[0], -1)
valid_mask = valid_mask.contiguous().view(valid_mask.shape[0], -1)

num = torch.sum(torch.mul(pred, target) * valid_mask, dim=1) * 2 + smooth
den = torch.sum((pred.pow(exponent) + target.pow(exponent)) * valid_mask, dim=1) + smooth
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We may directly use denominator.


return 1 - num / den

@LOSSES.register_module()
class DiceLoss(nn.Module):
"""DiceLoss.

"""
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We may add some docstring here.

def __init__(self,
loss_type='multi_class',
smooth=1,
exponent=2,
reduction='mean',
class_weight=None,
loss_weight=1.0,
ignore_index=-1):
super(DiceLoss, self).__init__()
assert loss_type in ['multi_class', 'binary']
if loss_type == 'multi_class':
self.cls_criterion = dice_loss
else:
self.cls_criterion = binary_dice_loss
self.smooth = smooth
self.exponent = exponent
self.reduction = reduction
self.class_weight = class_weight
self.loss_weight = loss_weight
self.ignore_index = ignore_index

def forward(self,
pred,
target,
avg_factor=None,
reduction_override=None):
assert reduction_override in (None, 'none', 'mean', 'sum')
reduction = (
reduction_override if reduction_override else self.reduction)
if self.class_weight is not None:
class_weight = pred.new_tensor(self.class_weight)
else:
class_weight = None

pred = F.softmax(pred, dim=1)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In instance segmentation, dice loss may use sigmoid for activation. Suggest supporting both cases.

one_hot_target = F.one_hot(torch.clamp_min(target.long(), 0))
valid_mask = (target != self.ignore_index).long()

loss = self.loss_weight * self.cls_criterion(
pred,
one_hot_target,
valid_mask=valid_mask,
reduction=reduction,
avg_factor=avg_factor,
smooth=self.smooth,
exponent=self.exponent,
class_weight=class_weight,
ignore_index=self.ignore_index)
return loss
40 changes: 40 additions & 0 deletions tests/test_models/test_losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,3 +202,43 @@ def test_lovasz_loss():
logits = torch.rand(2, 4, 4)
labels = (torch.rand(2, 4, 4)).long()
lovasz_loss(logits, labels, ignore_index=None)

def test_dice_lose():
from mmseg.models import build_loss
import sys

# loss_type should be 'binary' or 'multi_class'
with pytest.raises(AssertionError):
loss_cfg = dict(
type='DiceLoss',
loss_type='Binary',
reduction='none',
loss_weight=1.0)
build_loss(loss_cfg)

# test dice loss with loss_type = 'multi_class'
loss_cfg = dict(
type='DiceLoss',
loss_type='multi_class',
reduction='none',
class_weight=[1.0, 2.0, 3.0],
loss_weight=1.0,
ignore_index=1)
dice_loss = build_loss(loss_cfg)
logits = torch.rand(8, 3, 4, 4)
labels = (torch.rand(8, 4, 4) * 3).long()
dice_loss(logits, labels)

# test dice loss with loss_type = 'binary'
loss_cfg = dict(
type='DiceLoss',
loss_type='binary',
smooth=2,
exponent=3,
reduction='sum',
loss_weight=1.0,
ignore_index=0)
dice_loss = build_loss(loss_cfg)
logits = torch.rand(16, 4, 4)
labels = (torch.rand(16, 4, 4)).long()
dice_loss(logits, labels)