-
Notifications
You must be signed in to change notification settings - Fork 2.7k
dice loss #396
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
dice loss #396
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,11 +1,12 @@ | ||
from .accuracy import Accuracy, accuracy | ||
from .cross_entropy_loss import (CrossEntropyLoss, binary_cross_entropy, | ||
cross_entropy, mask_cross_entropy) | ||
from .dice_loss import DiceLoss | ||
from .lovasz_loss import LovaszLoss | ||
from .utils import reduce_loss, weight_reduce_loss, weighted_loss | ||
|
||
__all__ = [ | ||
'accuracy', 'Accuracy', 'cross_entropy', 'binary_cross_entropy', | ||
'mask_cross_entropy', 'CrossEntropyLoss', 'reduce_loss', | ||
'weight_reduce_loss', 'weighted_loss', 'LovaszLoss' | ||
'weight_reduce_loss', 'weighted_loss', 'LovaszLoss', 'DiceLoss' | ||
] |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,116 @@ | ||
"""Modified from https://github.com/LikeLy-Journey/SegmenTron/blob/master/ | ||
segmentron/solver/loss.py (Apache-2.0 License)""" | ||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
|
||
from ..builder import LOSSES | ||
from .utils import weighted_loss | ||
|
||
|
||
@weighted_loss | ||
def dice_loss(pred, | ||
target, | ||
valid_mask, | ||
smooth=1, | ||
exponent=2, | ||
class_weight=None, | ||
ignore_index=-1): | ||
assert pred.shape[0] == target.shape[0] | ||
total_loss = 0 | ||
num_classes = pred.shape[1] | ||
for i in range(num_classes): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Use for loop might be inefficient? Some implementation support to process multi-class in a batched manner. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
if i != ignore_index: | ||
dice_loss = binary_dice_loss( | ||
pred[:, i], | ||
target[..., i], | ||
valid_mask=valid_mask, | ||
smooth=smooth, | ||
exponent=exponent) | ||
if class_weight is not None: | ||
dice_loss *= class_weight[i] | ||
total_loss += dice_loss | ||
return total_loss / num_classes | ||
|
||
|
||
@weighted_loss | ||
def binary_dice_loss(pred, target, valid_mask, smooth=1, exponent=2, **kwards): | ||
assert pred.shape[0] == target.shape[0] | ||
pred = pred.contiguous().view(pred.shape[0], -1) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
target = target.contiguous().view(target.shape[0], -1) | ||
valid_mask = valid_mask.contiguous().view(valid_mask.shape[0], -1) | ||
|
||
num = torch.sum(torch.mul(pred, target) * valid_mask, dim=1) * 2 + smooth | ||
den = torch.sum(pred.pow(exponent) + target.pow(exponent), dim=1) + smooth | ||
|
||
return 1 - num / den | ||
|
||
|
||
@LOSSES.register_module() | ||
class DiceLoss(nn.Module): | ||
"""DiceLoss. | ||
|
||
This loss is proposed in `V-Net: Fully Convolutional Neural Networks for | ||
Volumetric Medical Image Segmentation <https://arxiv.org/abs/1606.04797>`_. | ||
|
||
Args: | ||
loss_type (str, optional): Binary or multi-class loss. | ||
Default: 'multi_class'. Options are "binary" and "multi_class". | ||
smooth (float): A float number to smooth loss, and avoid NaN error. | ||
Default: 1 | ||
exponent (float): An float number to calculate denominator | ||
value: \\sum{x^exponent} + \\sum{y^exponent}. Default: 2. | ||
reduction (str, optional): The method used to reduce the loss. Options | ||
are "none", "mean" and "sum". This parameter only works when | ||
per_image is True. Default: 'mean'. | ||
class_weight (list[float], optional): The weight for each class. | ||
Default: None. | ||
loss_weight (float, optional): Weight of the loss. Default to 1.0. | ||
ignore_index (int | None): The label index to be ignored. Default: 255. | ||
""" | ||
|
||
def __init__(self, | ||
loss_type='multi_class', | ||
smooth=1, | ||
exponent=2, | ||
reduction='mean', | ||
class_weight=None, | ||
loss_weight=1.0, | ||
ignore_index=255): | ||
super(DiceLoss, self).__init__() | ||
assert loss_type in ['multi_class', 'binary'] | ||
if loss_type == 'multi_class': | ||
self.cls_criterion = dice_loss | ||
else: | ||
self.cls_criterion = binary_dice_loss | ||
self.smooth = smooth | ||
self.exponent = exponent | ||
self.reduction = reduction | ||
self.class_weight = class_weight | ||
self.loss_weight = loss_weight | ||
self.ignore_index = ignore_index | ||
|
||
def forward(self, pred, target, avg_factor=None, reduction_override=None): | ||
assert reduction_override in (None, 'none', 'mean', 'sum') | ||
reduction = ( | ||
reduction_override if reduction_override else self.reduction) | ||
if self.class_weight is not None: | ||
class_weight = pred.new_tensor(self.class_weight) | ||
else: | ||
class_weight = None | ||
|
||
pred = F.softmax(pred, dim=1) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In instance segmentation, dice loss may use sigmoid for activation. Suggest supporting both cases. |
||
one_hot_target = F.one_hot(torch.clamp_min(target.long(), 0)) | ||
valid_mask = (target != self.ignore_index).long() | ||
|
||
loss = self.loss_weight * self.cls_criterion( | ||
pred, | ||
one_hot_target, | ||
valid_mask=valid_mask, | ||
reduction=reduction, | ||
avg_factor=avg_factor, | ||
smooth=self.smooth, | ||
exponent=self.exponent, | ||
class_weight=class_weight, | ||
ignore_index=self.ignore_index) | ||
return loss |
Uh oh!
There was an error while loading. Please reload this page.