
Modify the Dice loss #376

Open
zifuwanggg wants to merge 3 commits into main

Conversation

@zifuwanggg zifuwanggg commented Oct 11, 2024

This PR modifies the Dice loss in training.loss_fns, following JDTLoss and segmentation_models.pytorch.

The original Dice loss is incompatible with soft labels. For example, with a ground truth value of 0.5 for a single pixel, the loss is minimized when the predicted value is 1, which is clearly erroneous. To address this, the intersection term is rewritten as $\frac{|x|_1 + |y|_1 - |x-y|_1}{2}$. This reformulation is provably equivalent to the original when the ground truth is binary (i.e., one-hot hard labels). Moreover, since the new loss is minimized if and only if the prediction is identical to the ground truth, even when the ground truth contains fractional values, it resolves the issue with soft labels [1, 2].
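
To make the single-pixel claim concrete, here is a quick sanity check (an illustrative sketch, not part of the PR code):

# Single pixel with ground truth y = 0.5.
# Original Dice score: 2*x*y / (x + y); reformulated: (x + y - |x - y|) / (x + y).
y = 0.5
for x in (0.5, 1.0):
    dice_orig = 2 * x * y / (x + y)
    dice_reform = (x + y - abs(x - y)) / (x + y)
    print(x, round(dice_orig, 3), round(dice_reform, 3))

# 0.5 0.5 1.0
# 1.0 0.667 0.667

The original score prefers the overconfident prediction x = 1 over the correct x = 0.5, whereas the reformulated score is maximized exactly at x = y.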

Although the original SAM/SAM2 models were trained without soft labels, this modification enables soft label training for downstream fine-tuning without changing the existing behavior.

Example

import torch
import torch.linalg as LA
import torch.nn.functional as F

torch.manual_seed(0)

b, c, h, w = 4, 3, 32, 32
dims = (0, 2, 3)

pred = torch.rand(b, c, h, w).softmax(dim=1)
soft_label = torch.rand(b, c, h, w).softmax(dim=1)
hard_label = torch.randint(low=0, high=c, size=(b, h, w))
one_hot_label = F.one_hot(hard_label, c).permute(0, 3, 1, 2)

def dice_old(x, y, dims):
    cardinality = torch.sum(x, dim=dims) + torch.sum(y, dim=dims)
    intersection = torch.sum(x * y, dim=dims)
    return 2 * intersection / cardinality

def dice_new(x, y, dims):
    cardinality = torch.sum(x, dim=dims) + torch.sum(y, dim=dims)
    difference = LA.vector_norm(x - y, ord=1, dim=dims)
    intersection = (cardinality - difference) / 2
    return 2 * intersection / cardinality

print(dice_old(pred, one_hot_label, dims), dice_new(pred, one_hot_label, dims))
print(dice_old(pred, soft_label, dims), dice_new(pred, soft_label, dims))
print(dice_old(pred, pred, dims), dice_new(pred, pred, dims))

# tensor([0.3345, 0.3310, 0.3317]) tensor([0.3345, 0.3310, 0.3317])
# tensor([0.3321, 0.3333, 0.3350]) tensor([0.8680, 0.8690, 0.8700])
# tensor([0.3487, 0.3502, 0.3544]) tensor([1., 1., 1.])
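
The equivalence on one-hot labels, and the fact that only the new formulation reaches 1 for a perfect match, can also be checked directly (continuing the snippet above; illustrative only):

assert torch.allclose(dice_old(pred, one_hot_label, dims), dice_new(pred, one_hot_label, dims))
assert torch.allclose(dice_new(pred, pred, dims), torch.ones(c))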

References

[1] Dice Semimetric Losses: Optimizing the Dice Score with Soft Labels. Zifu Wang, Teodora Popordanoska, Jeroen Bertels, Robin Lemmens, Matthew B. Blaschko. MICCAI 2023.

[2] Jaccard Metric Losses: Optimizing the Jaccard Index with Soft Labels. Zifu Wang, Xuefei Ning, Matthew B. Blaschko. NeurIPS 2023.

@z-jiaming

@zifuwanggg Thanks for your commit, it has been very helpful to me.

But after using your code, neither (_hard_label, hard_label) nor (_soft_label, soft_label) gives a loss close to 0. I would like to know what I should change.

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# from collections import defaultdict
# from typing import Dict, List

import torch
import torch.distributed
import torch.linalg as LA
import torch.nn as nn
import torch.nn.functional as F

# from training.trainer import CORE_LOSS_KEY

# from training.utils.distributed import get_world_size, is_dist_avail_and_initialized

def dice_loss(inputs, targets, num_objects, loss_on_multimask=False):
    """
    Compute the DICE loss, similar to generalized IOU for masks
    Args:
        inputs: A float tensor of arbitrary shape.
                The predictions for each example.
        targets: A float tensor with the same shape as inputs. Stores the binary
                 classification label for each element in inputs
                (0 for the negative class and 1 for the positive class).
        num_objects: Number of objects in the batch
        loss_on_multimask: True if multimask prediction is enabled
    Returns:
        Dice loss tensor
    """
    inputs = inputs.sigmoid()
    if loss_on_multimask:
        # inputs and targets are [N, M, H, W] where M corresponds to multiple predicted masks
        assert inputs.dim() == 4 and targets.dim() == 4
        # flatten spatial dimension while keeping multimask channel dimension
        inputs = inputs.flatten(2)
        targets = targets.flatten(2)
        numerator = 2 * (inputs * targets).sum(-1)
    else:
        inputs = inputs.flatten(1)
        numerator = 2 * (inputs * targets).sum(1)
    denominator = inputs.sum(-1) + targets.sum(-1)
    loss = 1 - (numerator + 1) / (denominator + 1)
    if loss_on_multimask:
        return loss / num_objects
    return loss.sum() / num_objects

def dice_loss_new(inputs, targets, num_objects, loss_on_multimask=False):
    """
    Compute the DICE loss, similar to generalized IOU for masks
    Reference:
        Dice Semimetric Losses: Optimizing the Dice Score with Soft Labels.
                Wang, Z. et. al. MICCAI 2023.
    Args:
        inputs: A float tensor of arbitrary shape.
                The predictions for each example.
        targets: A float tensor with the same shape as inputs. Stores the binary
                 classification label for each element in inputs
                (0 for the negative class and 1 for the positive class).
        num_objects: Number of objects in the batch
        loss_on_multimask: True if multimask prediction is enabled
    Returns:
        Dice loss tensor
    """
    inputs = inputs.sigmoid()
    if loss_on_multimask:
        # inputs and targets are [N, M, H, W] where M corresponds to multiple predicted masks
        assert inputs.dim() == 4 and targets.dim() == 4
        # flatten spatial dimension while keeping multimask channel dimension
        inputs = inputs.flatten(2)
        targets = targets.flatten(2)
    else:
        inputs = inputs.flatten(1)
    denominator = inputs.sum(-1) + targets.sum(-1)
    difference = LA.vector_norm(inputs - targets, ord=1, dim=-1)
    numerator = (denominator - difference) / 2
    loss = 1 - (numerator + 1) / (denominator + 1)
    if loss_on_multimask:
        return loss / num_objects
    return loss.sum() / num_objects


B, N, H, W = 4, 3, 10, 10

pred_mask = torch.rand(B, N, H, W)
hard_label = torch.randint(0, 2, (B, N, H, W))
soft_label = torch.rand(B, N, H, W)

out_scale, out_bias = 20.0, -10.0
_pred_mask = pred_mask * out_scale + out_bias
_hard_label = hard_label * out_scale + out_bias
_soft_label = soft_label * out_scale + out_bias

print(dice_loss(_pred_mask, pred_mask, B, loss_on_multimask=True))
print(dice_loss(_pred_mask, hard_label, B, loss_on_multimask=True))
print(dice_loss(_pred_mask, soft_label, B, loss_on_multimask=True))
print(dice_loss(_hard_label, hard_label, B, loss_on_multimask=True))
print(dice_loss(_soft_label, soft_label, B, loss_on_multimask=True))


print(dice_loss_new(_pred_mask, pred_mask, B, loss_on_multimask=True))
print(dice_loss_new(_pred_mask, hard_label, B, loss_on_multimask=True))
print(dice_loss_new(_pred_mask, soft_label, B, loss_on_multimask=True))
print(dice_loss_new(_hard_label, hard_label, B, loss_on_multimask=True))
print(dice_loss_new(_soft_label, soft_label, B, loss_on_multimask=True))

Its output:

tensor([[0.0756, 0.0542, 0.0678],
        [0.0701, 0.0715, 0.0490],
        [0.0621, 0.0610, 0.0625],
        [0.0578, 0.0611, 0.0718]])
tensor([[0.1287, 0.1192, 0.1156],
        [0.1324, 0.1408, 0.0959],
        [0.1331, 0.1358, 0.1431],
        [0.1195, 0.1217, 0.1240]])
tensor([[0.1321, 0.1090, 0.1310],
        [0.1268, 0.1212, 0.1004],
        [0.1174, 0.1274, 0.1366],
        [0.1215, 0.1168, 0.1289]])
tensor([[1.2204e-05, 1.2770e-05, 1.0207e-05],
        [1.1235e-05, 1.1697e-05, 1.1459e-05],
        [1.0371e-05, 1.1459e-05, 1.3068e-05],
        [1.2472e-05, 1.1712e-05, 1.1221e-05]])
tensor([[0.0731, 0.0692, 0.0667],
        [0.0539, 0.0582, 0.0650],
        [0.0640, 0.0703, 0.0634],
        [0.0621, 0.0645, 0.0699]])


tensor([[0.1481, 0.1417, 0.1485],
        [0.1457, 0.1486, 0.1431],
        [0.1463, 0.1460, 0.1475],
        [0.1443, 0.1447, 0.1491]])
tensor([[0.1880, 0.1834, 0.1816],
        [0.1899, 0.1941, 0.1718],
        [0.1903, 0.1916, 0.1952],
        [0.1835, 0.1846, 0.1857]])
tensor([[0.1811, 0.1718, 0.1838],
        [0.1811, 0.1782, 0.1704],
        [0.1764, 0.1825, 0.1880],
        [0.1790, 0.1760, 0.1801]])
tensor([[0.1237, 0.1236, 0.1239],
        [0.1238, 0.1237, 0.1237],
        [0.1239, 0.1237, 0.1236],
        [0.1236, 0.1237, 0.1238]])
tensor([[0.1471, 0.1463, 0.1474],
        [0.1432, 0.1438, 0.1450],
        [0.1474, 0.1459, 0.1452],
        [0.1454, 0.1452, 0.1479]])

@zifuwanggg
Author

@z-jiaming, thanks for the question.

_soft_label is a linear transformation of soft_label, so the loss is not expected to return 0 in this case. Why would you expect the loss to be 0 for (_soft_label, soft_label)? Please also note that there is a sigmoid inside the loss function, so the inputs are passed through torch.sigmoid before the Dice score is computed.
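
To see the effect of that sigmoid (a quick illustration, assuming the out_scale=20.0, out_bias=-10.0 values from the snippet above): sigmoid(20 * x - 10) pushes values toward 0 and 1, so it approximately recovers a hard label but not a soft one.

import torch

torch.manual_seed(0)
soft = torch.rand(5)                      # soft labels in (0, 1)
hard = torch.randint(0, 2, (5,)).float()  # hard labels in {0, 1}

print(torch.sigmoid(hard * 20.0 - 10.0))  # ~0 or ~1, very close to `hard`
print(torch.sigmoid(soft * 20.0 - 10.0))  # sharpened toward 0/1, generally != `soft`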

@z-jiaming

z-jiaming commented Dec 5, 2024

@zifuwanggg Thanks for your reply!
But when I remove the sigmoid and pass the labels in directly, the losses are still not zero:

print(dice_loss(pred_mask, pred_mask, B, loss_on_multimask=True))
print(dice_loss(pred_mask, hard_label, B, loss_on_multimask=True))
print(dice_loss(pred_mask, soft_label, B, loss_on_multimask=True))
print(dice_loss(hard_label, hard_label, B, loss_on_multimask=True))
print(dice_loss(soft_label, soft_label, B, loss_on_multimask=True))

print("\n")

print(dice_loss_new(pred_mask, pred_mask, B, loss_on_multimask=True))
print(dice_loss_new(pred_mask, hard_label, B, loss_on_multimask=True))
print(dice_loss_new(pred_mask, soft_label, B, loss_on_multimask=True))
print(dice_loss_new(hard_label, hard_label, B, loss_on_multimask=True))
print(dice_loss_new(soft_label, soft_label, B, loss_on_multimask=True))

Its output:

tensor([[0.0887, 0.0846, 0.0834],
        [0.0830, 0.0920, 0.0995],
        [0.0792, 0.0877, 0.0892],
        [0.0907, 0.0855, 0.0647]])
tensor([[0.1197, 0.1250, 0.1220],
        [0.1343, 0.1443, 0.1222],
        [0.1312, 0.1396, 0.1263],
        [0.1474, 0.1343, 0.1089]])
tensor([[0.1284, 0.1251, 0.1318],
        [0.1172, 0.1256, 0.1283],
        [0.1199, 0.1290, 0.1250],
        [0.1324, 0.1289, 0.1153]])
tensor([[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]])
tensor([[0.0844, 0.0795, 0.0910],
        [0.0890, 0.0767, 0.0815],
        [0.0804, 0.0834, 0.0820],
        [0.0891, 0.0900, 0.0764]])


tensor([[0.1237, 0.1237, 0.1238],
        [0.1238, 0.1237, 0.1236],
        [0.1238, 0.1237, 0.1237],
        [0.1237, 0.1237, 0.1239]])
tensor([[0.1836, 0.1862, 0.1848],
        [0.1909, 0.1959, 0.1848],
        [0.1893, 0.1935, 0.1868],
        [0.1973, 0.1909, 0.1783]])
tensor([[0.1665, 0.1645, 0.1693],
        [0.1583, 0.1661, 0.1653],
        [0.1641, 0.1679, 0.1655],
        [0.1672, 0.1670, 0.1654]])
tensor([[0.1237, 0.1236, 0.1239],
        [0.1237, 0.1238, 0.1238],
        [0.1235, 0.1236, 0.1237],
        [0.1234, 0.1237, 0.1236]])
tensor([[0.1237, 0.1238, 0.1237],
        [0.1237, 0.1239, 0.1238],
        [0.1238, 0.1238, 0.1238],
        [0.1237, 0.1237, 0.1238]])

Only dice_loss(hard_label, hard_label, B, loss_on_multimask=True) returns 0; even dice_loss_new(hard_label, hard_label, ...) does not.

@zifuwanggg
Author

@z-jiaming, sorry, I made a mistake. In dice_loss_new,

numerator = (denominator - difference) / 2

should be changed to

numerator = denominator - difference

because the numerator of the Dice score is $2 \cdot$ intersection $= |x|_1 + |y|_1 - |x-y|_1$, so the factor of 2 is already absorbed and the division by 2 must be dropped. Now the loss should work as expected.
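
For reference, this is how dice_loss_new reads with the fix applied (restating the code above, with only the numerator line changed):

def dice_loss_new(inputs, targets, num_objects, loss_on_multimask=False):
    # Soft-label-compatible Dice loss with the corrected numerator.
    inputs = inputs.sigmoid()
    if loss_on_multimask:
        # inputs and targets are [N, M, H, W]; keep the multimask channel dimension
        assert inputs.dim() == 4 and targets.dim() == 4
        inputs = inputs.flatten(2)
        targets = targets.flatten(2)
    else:
        inputs = inputs.flatten(1)
    denominator = inputs.sum(-1) + targets.sum(-1)
    difference = LA.vector_norm(inputs - targets, ord=1, dim=-1)
    # Numerator of the Dice score is 2 * intersection = |x| + |y| - |x - y|.
    numerator = denominator - difference
    loss = 1 - (numerator + 1) / (denominator + 1)
    if loss_on_multimask:
        return loss / num_objects
    return loss.sum() / num_objects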

@z-jiaming

Thanks a lot!!!!
