
Add Dice Loss #6960

Open · pri1311 wants to merge 11 commits into main
Conversation

@pri1311 commented Nov 18, 2022

Hi, I have added an implementation of dice loss, as mentioned in #6435.

Will add tests once the correctness of the function is confirmed.

@facebook-github-bot
Hi @pri1311!

Thank you for your pull request and welcome to our community.

Action Required

In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process

In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (e.g. your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed. The tagging process may take up to 1 hour after signing. Please give it that time before contacting us about it.

If you have received this in error or have any questions, please contact us at cla@meta.com. Thanks!

@datumbox (Contributor) left a comment:

Thanks a lot @pri1311! I'll leave it to @oke-aditya to help you check the validity of the implementation and write the tests. The approach looks good overall. I've added just a minor comment below.

@datumbox changed the title from [New Feature] Dice Loss to Add Dice Loss on Nov 18, 2022
@datumbox linked an issue on Nov 18, 2022 that may be closed by this pull request
@facebook-github-bot

Thank you for signing our Contributor License Agreement. We can now accept your code for this (and any) Meta Open Source project. Thanks!

@oke-aditya (Contributor) left a comment:

Thanks a lot for the PR @pri1311!

The implementation looks correct. It seems to be adapted from Kornia:
https://kornia.readthedocs.io/en/latest/_modules/kornia/losses/dice.html#dice_loss

Before we go further into adding tests / reviewing the implementation, we need to think about the following. cc @datumbox

Do we want Dice Loss to support (B, C, H, W)?

  1. A few libraries, like Kornia, do this.

  2. A few libraries, like PaddlePaddle, compute over (B, num_classes).
     See Add dice loss PaddlePaddle/Paddle#10717
     https://github.com/PaddlePaddle/Paddle/blob/develop/python/paddle/nn/functional/loss.py#L41

We can either compute the dice loss on a single channel and leave it to the end user to combine them, or provide a (B, C, H, W) version of the loss. Both sound good, but let's finalize one. (A shape sketch of the two conventions follows below.)
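For concreteness, a hypothetical sketch of the two input conventions (shapes only; the variable names are illustrative):

import torch

# Option 1 (Kornia-style): per-pixel multi-class logits, channels-first
logits_bchw = torch.randn(8, 5, 32, 32)  # (B, C, H, W), C = num_classes

# Option 2 (PaddlePaddle-style): one score vector per sample
logits_bc = torch.randn(8, 5)  # (B, num_classes)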

@oke-aditya (Contributor)
A quick example of how this loss works currently:

import torch
import torch.nn.functional as F

# dice_loss is the function added by this PR (torchvision/ops/dice_loss.py)
N = 5  # num_classes
input2 = torch.randn(1, N, 3, 5, requires_grad=True)
target = torch.empty(1, 3, 5, dtype=torch.long).random_(N)
target_one_hot: torch.Tensor = F.one_hot(target, num_classes=input2.shape[1])
target_one_hot = target_one_hot.permute((0, 3, 1, 2))  # (B, H, W, C) -> (B, C, H, W)
dl = dice_loss(input2, target_one_hot)

@pri1311 (Author) commented Nov 20, 2022

2. A few libraries, like PaddlePaddle, compute over (B, num_classes).
   See PaddlePaddle/Paddle#10717
   https://github.com/PaddlePaddle/Paddle/blob/develop/python/paddle/nn/functional/loss.py#L41

import torch
import torch.nn.functional as F

def dice_loss(inputs: torch.Tensor, targets: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    # compute softmax over the classes axis (the last dimension)
    p = F.softmax(inputs, dim=-1)
    p = p.reshape(p.shape[0], -1)
    # or we could directly take a one-hot tensor as input
    targets = F.one_hot(targets, num_classes=inputs.shape[-1])
    targets = targets.reshape(p.shape[0], -1)

    intersection = torch.sum(p * targets, dim=-1)
    cardinality = torch.sum(p + targets, dim=-1)

    dice_score = 2.0 * intersection / (cardinality + eps)
    return 1.0 - dice_score

This will support shape (B, any_shape_here, num_classes) for the input and (B, any_shape_here) for the target/label, as in the usage sketch below.
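For instance, a hypothetical call under this convention, using the snippet above (shapes for illustration only):

import torch

inputs = torch.randn(8, 16, 16, 5)          # (B, N1, N2, num_classes) logits
targets = torch.randint(0, 5, (8, 16, 16))  # (B, N1, N2) class indices
loss = dice_loss(inputs, targets)           # shape (8,): one loss per batch element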


Although I believe the output of torchvision models is (B, C, H, W).

@oke-aditya (Contributor) commented Nov 20, 2022

Yeah, this is what I had in mind: we compute softmax over the classes axis, very similar to an implementation I had in my repo. Although leaving one-hot to the end user is the best option; our function should not do the one-hot conversion, as there are various one-hot strategies (an example follows below).

Notice now that the shape of the tensor we return also depends on the input.
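As an illustration of why one-hot is best left to the caller, a user might want label-smoothed rather than hard one-hot targets (a hypothetical sketch):

import torch
import torch.nn.functional as F

num_classes = 5
target = torch.randint(0, num_classes, (1, 3, 5))  # (B, H, W) class indices
hard = F.one_hot(target, num_classes).float()      # hard one-hot, (B, H, W, C)
smooth = hard * 0.9 + 0.1 / num_classes            # label-smoothed variant; rows still sum to 1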

@pri1311 (Author) commented Dec 16, 2022

Firstly, apologies for the delay; I had my end-term exams going on, so I had to put the PR on hold.

I think this implementation should align with the requirements. I should move on to adding tests now, but I will require a little help there.

@oke-aditya (Contributor) commented Dec 16, 2022

Sure, can you let me know what help you need in writing tests? @pri1311

A simple way is to look at the tests for other losses, like generalized_box_iou_loss.

You need to test the loss on a few predefined values and check that the output matches what you expect: compute it for fixed inputs and assert the values, something like #6960 (comment) or the sketch below.
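A minimal sketch of such a fixed-value test (an assumption-laden example: it presumes the channels-first (B, C, H, W) convention and that ops.dice_loss takes (input, target, eps) as discussed in this thread):

import torch
from torchvision import ops  # assumes dice_loss is exported from torchvision.ops by this PR

def test_dice_loss_perfect_match():
    # logits whose softmax is (nearly) exactly the one-hot target -> loss ~ 0
    logits = torch.tensor([[[[100.0, -100.0]], [[-100.0, 100.0]]]])  # (1, 2, 1, 2)
    target = torch.tensor([[[[1.0, 0.0]], [[0.0, 1.0]]]])            # one-hot, same shape
    loss = ops.dice_loss(logits, target, eps=1e-8)
    torch.testing.assert_close(loss, torch.tensor([0.0]), atol=1e-4, rtol=0)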

@pri1311 (Author) commented Dec 22, 2022

Hi @oke-aditya, I've added the basic tests, could you check once?

Also, I have trouble understanding the failing checks; could you point out where I've gone wrong?

Review comment on test/test_ops.py:

    @@ -1770,5 +1770,53 @@ def test_is_leaf_node(self, dim, p, block_size, inplace):
            assert len(graph_node_names[0]) == 1 + op_obj.n_inputs


    class TestDiceLoss:

Contributor: I suggest writing this in a new file; test_ops.py is already a huge file, >1.5k lines of code.

Contributor: Aah, a quick check shows that tests for all the losses have been written in this file. Let this stay for now; thoughts, @pmeier @NicolasHug?

Review comment on test/test_ops.py:

        expected = reduction_fn(expected)
        torch.testing.assert_close(ops.dice_loss(input_ones, label_zeros, reduction=reduction), expected)

    @pytest.mark.parametrize("device", cpu_and_gpu())

Contributor: I'm not sure gradcheck is needed. Can you provide some reference for why you added it? cc @YosuaMichael

@pri1311 (Author) replied Jan 1, 2023:

I tried to follow the test cases for focal loss, and I believe it does check grads, although not in a separate function (see the sketch below).
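For reference, the kind of inline gradient check the focal loss tests do looks roughly like this (a sketch; the shapes and the reduction value are assumptions):

import torch
from torchvision import ops

inputs = torch.randn(4, 3, 8, 8, requires_grad=True)     # (B, C, H, W) logits
targets = torch.softmax(torch.randn(4, 3, 8, 8), dim=1)  # soft, one-hot-like targets
loss = ops.dice_loss(inputs, targets, reduction="mean")  # "mean" is an assumed reduction value
loss.backward()                                          # should populate inputs.grad
assert inputs.grad is not None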

@oke-aditya (Contributor) commented Dec 22, 2022

Hi @pri1311, thanks for continuing the work!

I had a quick review, although I have not fully validated the tests yet.

Note that the test failure is unrelated to this change; you can verify that by viewing the failure logs (no need to log in to CircleCI):

https://app.circleci.com/pipelines/github/pytorch/vision/22420/workflows/cfac5971-ed1c-4ef6-9ef2-aa8a94a848af/jobs/1801044

But the lint failure does seem to be related to this change. You need to lint the code as mentioned in the contributing.md file (either use the pre-commit hooks or lint manually):

https://app.circleci.com/pipelines/github/pytorch/vision/22420/workflows/0908373e-af80-48ec-b8be-fb55ce07ce10/jobs/1801049

Edit: Merry Xmas 🎄 🎅. Due to vacations, code review might be delayed, but we are on it.

@oke-aditya (Contributor)

I will review this, validate it, and get back to you soon!

@oke-aditya (Contributor)

I'm sorry for the delay in reviewing this PR. I will definitely review it over the weekend. It's been a really hard year for me, but I will get back to this.

@oke-aditya (Contributor) left a comment:

I'm unable to match the results with kornia. Let me post what I validated.

@oke-aditya (Contributor) commented Jan 29, 2023

This is a modified dice loss from Kornia, with the one_hot conversion commented out and the final torch.mean replaced with 1.0 - dice_score:


import torch
import torch.nn.functional as F


def dice_loss(input: torch.Tensor, target: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    r"""Criterion that computes Sørensen-Dice Coefficient loss.

    According to [1], we compute the Sørensen-Dice Coefficient as follows:

    .. math::

        \text{Dice}(x, class) = \frac{2 |X \cap Y|}{|X| + |Y|}

    Where:
       - :math:`X` is expected to be the scores of each class.
       - :math:`Y` is expected to be the one-hot tensor with the class labels.

    the loss is finally computed as:

    .. math::

        \text{loss}(x, class) = 1 - \text{Dice}(x, class)

    Reference:
        [1] https://en.wikipedia.org/wiki/S%C3%B8rensen%E2%80%93Dice_coefficient

    Args:
        input: logits tensor with shape :math:`(N, C, H, W)`, where C = number of classes.
        target: one-hot tensor with shape :math:`(N, C, H, W)` (the one_hot conversion
          below is commented out for this comparison, so the caller must pass one-hot targets).
        eps: scalar to enforce numerical stability.

    Return:
        the computed loss.

    Example:
        >>> N = 5  # num_classes
        >>> input = torch.randn(1, N, 3, 5, requires_grad=True)
        >>> target = torch.empty(1, 3, 5, dtype=torch.long).random_(N)
        >>> target = F.one_hot(target, num_classes=N).permute(0, 3, 1, 2).float()
        >>> output = dice_loss(input, target)
        >>> output.backward()
    """
    if not isinstance(input, torch.Tensor):
        raise TypeError(f"Input type is not a torch.Tensor. Got {type(input)}")

    if not len(input.shape) == 4:
        raise ValueError(f"Invalid input shape, we expect BxNxHxW. Got: {input.shape}")

    if not input.shape[-2:] == target.shape[-2:]:
        raise ValueError(f"input and target shapes must be the same. Got: {input.shape} and {target.shape}")

    if not input.device == target.device:
        raise ValueError(f"input and target must be in the same device. Got: {input.device} and {target.device}")

    # compute softmax over the classes axis
    input_soft: torch.Tensor = F.softmax(input, dim=1)

    # the original Kornia code creates the one-hot labels here; it is commented out
    # for this comparison, so `target` is expected to already be one-hot:
    # target: torch.Tensor = one_hot(target, num_classes=input.shape[1], device=input.device, dtype=input.dtype)

    # compute the actual dice score
    dims = (1, 2, 3)
    intersection = torch.sum(input_soft * target, dims)
    cardinality = torch.sum(input_soft + target, dims)

    dice_score = 2.0 * intersection / (cardinality + eps)

    return 1.0 - dice_score

And this is my verification script. I simply copied your code into a module named dltv, while the Kornia version above is in dlkor:

from dltv import dice_loss
from dlkor import dice_loss as dice_loss_kornia
import torch
import torch.nn.functional as F

device = "cpu"

#expected = torch.tensor([0.4028, 0.6101, 0.5916, 0.6347], device=device)


if __name__ == '__main__':

    shape = (16, 4, 4, 2)  # (B, N1, N2, C) -- torchvision-style channels-last
    input_zeros = torch.zeros(shape, device=device)
    input_zeros[:, :, :, 0] = 1.0
    input_zeros[:, :, :, 1] = 0.0
    label_zeros = torch.zeros(shape, device=device)
    label_zeros.copy_(input_zeros)
    input_zeros[:, :, :, 0] = 100.0  # turn the predictions into confident logits
    expected = torch.zeros(16, device=device)  # a perfect match should give zero loss

    out = dice_loss(input_zeros, label_zeros, eps=0)
    print(out)

    out2 = dice_loss_kornia(input_zeros, label_zeros, eps=0)
    print(out2)

Here is the output: [screenshot omitted; the two functions print different loss values]

Can you please help me understand why there is a difference? It might be something I'm unaware of, or something Kornia does (by commenting out one_hot I might have made a mistake).

@pri1311 (Author) commented Jan 31, 2023

@oke-aditya I haven't given it a detailed look yet, but I think one thing of concern while comparing the two libraries is the input dimensions: torchvision follows (B, N1, N2, ..., C), while Kornia follows (B, C, H, W).

@oke-aditya (Contributor)

My concern is that the results should have matched, at least for the all-zeros cases: dice loss itself is well defined, and adjusting for flexible shapes shouldn't affect the values.

@pri1311 (Author) commented Jan 31, 2023

Results don't match because in this piece of code we specify the one-hot values per the torchvision input format (B, N1, N2, ..., C) (also followed by PaddleSeg):

input_zeros[:, :, :, 0] = 1.0
input_zeros[:, :, :, 1] = 0.0

But for Kornia we need these values on dimension 1, not the last dimension; if you transpose the tensor, it should work fine (see the sketch below). I can make changes so that torchvision accepts (B, C, N1, N2, ...).
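For illustration, the transpose in question applied to the verification script above (a sketch):

# move the channel axis from last to position 1 for the Kornia-style function:
# (B, N1, N2, C) -> (B, C, N1, N2)
input_chw = input_zeros.permute(0, 3, 1, 2)
label_chw = label_zeros.permute(0, 3, 1, 2)
out2 = dice_loss_kornia(input_chw, label_chw, eps=0)
# out2 should now match dice_loss(input_zeros, label_zeros, eps=0)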

@oke-aditya (Contributor)

Channels-first (B, C, N1, N2, ...) is what we assume in all the models as well as across the codebase, so I think it would be much better to use that.

@NicolasHug mentioned this pull request on Feb 10, 2023
@oke-aditya (Contributor)

Thanks a lot for working on this, @pri1311. I will do another round of validating the results, and then we should be good. Although we have to wait for approval from Nicolas (as he is the maintainer 😄).

@chadrockey
@NicolasHug looks like this missed 0.15. I tried this loss function and it seems to work well for segmentation problems with unbalanced classes. Its inclusion would be great, unless there's other functionality I'm overlooking that implements another dice loss function such that this PR is no longer needed.

Projects: none yet

Development: successfully merging this pull request may close the issue "New Feature: Dice Loss".

7 participants