Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

dice loss #396

Merged
merged 4 commits into from
Mar 11, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion mmseg/models/losses/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from .accuracy import Accuracy, accuracy
from .cross_entropy_loss import (CrossEntropyLoss, binary_cross_entropy,
cross_entropy, mask_cross_entropy)
from .dice_loss import DiceLoss
from .lovasz_loss import LovaszLoss
from .utils import reduce_loss, weight_reduce_loss, weighted_loss

__all__ = [
'accuracy', 'Accuracy', 'cross_entropy', 'binary_cross_entropy',
'mask_cross_entropy', 'CrossEntropyLoss', 'reduce_loss',
'weight_reduce_loss', 'weighted_loss', 'LovaszLoss'
'weight_reduce_loss', 'weighted_loss', 'LovaszLoss', 'DiceLoss'
]
116 changes: 116 additions & 0 deletions mmseg/models/losses/dice_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
"""Modified from https://github.com/LikeLy-Journey/SegmenTron/blob/master/
segmentron/solver/loss.py (Apache-2.0 License)"""
import torch
xvjiarui marked this conversation as resolved.
Show resolved Hide resolved
import torch.nn as nn
import torch.nn.functional as F

from ..builder import LOSSES
from .utils import weighted_loss


@weighted_loss
def dice_loss(pred,
target,
valid_mask,
smooth=1,
exponent=2,
class_weight=None,
ignore_index=-1):
assert pred.shape[0] == target.shape[0]
total_loss = 0
num_classes = pred.shape[1]
for i in range(num_classes):

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use for loop might be inefficient? Some implementation support to process multi-class in a batched manner.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if i != ignore_index:
dice_loss = binary_dice_loss(
pred[:, i],
target[..., i],
valid_mask=valid_mask,
smooth=smooth,
exponent=exponent)
if class_weight is not None:
dice_loss *= class_weight[i]
total_loss += dice_loss
return total_loss / num_classes


@weighted_loss
def binary_dice_loss(pred, target, valid_mask, smooth=1, exponent=2, **kwards):
assert pred.shape[0] == target.shape[0]
pred = pred.contiguous().view(pred.shape[0], -1)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

.contiguous().view() can be replaced by reshape?

target = target.contiguous().view(target.shape[0], -1)
valid_mask = valid_mask.contiguous().view(valid_mask.shape[0], -1)

num = torch.sum(torch.mul(pred, target) * valid_mask, dim=1) * 2 + smooth
den = torch.sum(pred.pow(exponent) + target.pow(exponent), dim=1) + smooth

return 1 - num / den


@LOSSES.register_module()
class DiceLoss(nn.Module):
"""DiceLoss.

This loss is proposed in `V-Net: Fully Convolutional Neural Networks for
Volumetric Medical Image Segmentation <https://arxiv.org/abs/1606.04797>`_.

Args:
loss_type (str, optional): Binary or multi-class loss.
Default: 'multi_class'. Options are "binary" and "multi_class".
smooth (float): A float number to smooth loss, and avoid NaN error.
Default: 1
exponent (float): An float number to calculate denominator
value: \\sum{x^exponent} + \\sum{y^exponent}. Default: 2.
reduction (str, optional): The method used to reduce the loss. Options
are "none", "mean" and "sum". This parameter only works when
per_image is True. Default: 'mean'.
class_weight (list[float], optional): The weight for each class.
Default: None.
loss_weight (float, optional): Weight of the loss. Default to 1.0.
ignore_index (int | None): The label index to be ignored. Default: 255.
"""

def __init__(self,
loss_type='multi_class',
smooth=1,
exponent=2,
reduction='mean',
class_weight=None,
loss_weight=1.0,
ignore_index=255):
super(DiceLoss, self).__init__()
assert loss_type in ['multi_class', 'binary']
if loss_type == 'multi_class':
self.cls_criterion = dice_loss
else:
self.cls_criterion = binary_dice_loss
self.smooth = smooth
self.exponent = exponent
self.reduction = reduction
self.class_weight = class_weight
self.loss_weight = loss_weight
self.ignore_index = ignore_index

def forward(self, pred, target, avg_factor=None, reduction_override=None):
assert reduction_override in (None, 'none', 'mean', 'sum')
reduction = (
reduction_override if reduction_override else self.reduction)
if self.class_weight is not None:
class_weight = pred.new_tensor(self.class_weight)
else:
class_weight = None

pred = F.softmax(pred, dim=1)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In instance segmentation, dice loss may use sigmoid for activation. Suggest supporting both cases.

one_hot_target = F.one_hot(torch.clamp_min(target.long(), 0))
valid_mask = (target != self.ignore_index).long()

loss = self.loss_weight * self.cls_criterion(
pred,
one_hot_target,
valid_mask=valid_mask,
reduction=reduction,
avg_factor=avg_factor,
smooth=self.smooth,
exponent=self.exponent,
class_weight=class_weight,
ignore_index=self.ignore_index)
return loss
40 changes: 40 additions & 0 deletions tests/test_models/test_losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,3 +202,43 @@ def test_lovasz_loss():
logits = torch.rand(2, 4, 4)
labels = (torch.rand(2, 4, 4)).long()
lovasz_loss(logits, labels, ignore_index=None)


def test_dice_lose():
from mmseg.models import build_loss

# loss_type should be 'binary' or 'multi_class'
with pytest.raises(AssertionError):
loss_cfg = dict(
type='DiceLoss',
loss_type='Binary',
reduction='none',
loss_weight=1.0)
build_loss(loss_cfg)

# test dice loss with loss_type = 'multi_class'
loss_cfg = dict(
type='DiceLoss',
loss_type='multi_class',
reduction='none',
class_weight=[1.0, 2.0, 3.0],
loss_weight=1.0,
ignore_index=1)
dice_loss = build_loss(loss_cfg)
logits = torch.rand(8, 3, 4, 4)
labels = (torch.rand(8, 4, 4) * 3).long()
dice_loss(logits, labels)

# test dice loss with loss_type = 'binary'
loss_cfg = dict(
type='DiceLoss',
loss_type='binary',
smooth=2,
exponent=3,
reduction='sum',
loss_weight=1.0,
ignore_index=0)
dice_loss = build_loss(loss_cfg)
logits = torch.rand(16, 4, 4)
labels = (torch.rand(16, 4, 4)).long()
dice_loss(logits, labels)