Add Dice Loss #6960
base: main
Conversation
Hi @pri1311! Thank you for your pull request and welcome to our community.

Action Required: In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process: In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (e.g. your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged. If you have received this in error or have any questions, please contact us at cla@meta.com. Thanks!
Thanks a lot @pri1311! I'll leave it to @oke-aditya to help you check the validity of the implementation and write the tests. The approach looks good overall. I've added just a minor comment below.
Thank you for signing our Contributor License Agreement. We can now accept your code for this (and any) Meta Open Source project. Thanks!
Thanks a lot for the PR @pri1311
The implementation looks correct. It seems like this is adapted from Kornia:
https://kornia.readthedocs.io/en/latest/_modules/kornia/losses/dice.html#dice_loss
Before we go into adding tests / review implementation further.
We need to think about the following. cc @datumbox

Do we want Dice Loss to support (B, C, H, W)?

- A few libraries, like kornia, do this.
- A few libraries, like PaddlePaddle, use (B, Num_classes) to compute. See Add dice loss PaddlePaddle/Paddle#10717 and https://github.com/PaddlePaddle/Paddle/blob/develop/python/paddle/nn/functional/loss.py#L41
We can either compute the dice loss on a single channel and leave it to the end user to combine the channels, or provide a (B, C, H, W) version of the loss. Both sound good, but let's finalize on that.
A quick example of how this loss works currently.
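For reference, the dice coefficient underlying this loss can be sketched in plain Python. This is a hedged illustration of the per-channel computation under discussion (dice = 2·|pred ∩ target| / (|pred| + |target|), loss = 1 − dice); `dice_loss_single_channel` is a hypothetical name for illustration, not the PR's API:

```python
def dice_loss_single_channel(pred, target, eps=1e-7):
    """Dice loss for one channel; pred and target are flat lists of
    probabilities / binary labels. eps avoids division by zero."""
    intersection = sum(p * t for p, t in zip(pred, target))
    cardinality = sum(pred) + sum(target)
    # dice coefficient in [0, 1]; loss is its complement
    return 1.0 - (2.0 * intersection + eps) / (cardinality + eps)

# Perfect prediction -> loss close to 0
print(dice_loss_single_channel([1.0, 0.0, 1.0], [1.0, 0.0, 1.0]))
# Completely wrong prediction -> loss close to 1
print(dice_loss_single_channel([1.0, 1.0, 0.0], [0.0, 0.0, 1.0]))
```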
This will support shape (B, any_shape_here, num_classes) for the input and (B, any_shape_here) for the target/label. Although, I believe the output of torchvision models is (B, C, H, W).
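If the loss settles on a channels-last input, a (B, C, H, W) model output would need a permute first. A minimal sketch with assumed shapes (the tensor names are illustrative, not from the PR):

```python
import torch

# A segmentation model emits (B, C, H, W) logits, while a channels-last
# loss expects (B, ..., C). permute moves the class dim to the end.
logits_bchw = torch.randn(2, 3, 4, 4)                   # model output (B, C, H, W)
logits_channels_last = logits_bchw.permute(0, 2, 3, 1)  # (B, H, W, C) for the loss
labels = torch.randint(0, 3, (2, 4, 4))                 # (B, H, W) integer labels
print(logits_channels_last.shape, labels.shape)
```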
Yeah, this is what I had in mind. Notice now that the tensor we return also depends on the input.
Firstly, apologies for the delay; I had my end-term exams going on, so I had to put the PR on hold. I think this implementation should align with the requirements.
Sure, can you let me know what help you need in writing tests? @pri1311, a simple way is to look at the tests for other losses, like generalized_box_iou_loss etc. You need to test it for a few predefined values and check that it satisfies them: compute the loss for fixed inputs and assert the values.
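As a hedged sketch of such a fixed-value test: `dice_loss_ref` below is a hypothetical stand-in for the PR's `ops.dice_loss` (softmax over a trailing class dimension, one-hot targets), not the actual implementation, so only the testing pattern is the point here:

```python
import torch

def dice_loss_ref(inputs, targets, eps=1e-7):
    # Stand-in reference: softmax over the last (class) dimension,
    # then 1 - dice coefficient per batch element.
    probs = inputs.softmax(dim=-1)
    targets_one_hot = torch.nn.functional.one_hot(
        targets, num_classes=inputs.shape[-1]
    ).to(inputs.dtype)
    dims = tuple(range(1, inputs.dim()))
    intersection = (probs * targets_one_hot).sum(dims)
    cardinality = (probs + targets_one_hot).sum(dims)
    return 1.0 - (2.0 * intersection + eps) / (cardinality + eps)

def test_fixed_values():
    # Uniform logits over 2 classes -> softmax gives 0.5 everywhere.
    inputs = torch.zeros(1, 4, 2)
    targets = torch.zeros(1, 4, dtype=torch.long)
    # intersection = 0.5 * 4 = 2, cardinality = 4 + 4 = 8 -> dice = 0.5
    expected = torch.tensor([0.5])
    torch.testing.assert_close(
        dice_loss_ref(inputs, targets), expected, atol=1e-5, rtol=1e-5
    )

test_fixed_values()
```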
Hi @oke-aditya, I've added the basic tests, could you check once? Also, I have trouble understanding the failing checks; could you maybe point out where I've gone wrong?
@@ -1770,5 +1770,53 @@ def test_is_leaf_node(self, dim, p, block_size, inplace):
    assert len(graph_node_names[0]) == 1 + op_obj.n_inputs


class TestDiceLoss:
I suggest writing this in a new file. test_ops.py is already a huge file, >1.5k lines of code.
Aah, a quick check shows that everyone has written tests for losses in this file. Let this stay for now. Thoughts, @pmeier @NicolasHug?
expected = reduction_fn(expected)
torch.testing.assert_close(ops.dice_loss(input_ones, label_zeros, reduction=reduction), expected)

@pytest.mark.parametrize("device", cpu_and_gpu())
I'm not sure if gradcheck is needed. Can you provide some reference for why you added this? cc @YosuaMichael.
I tried to follow the test cases for Focal Loss, and I believe it checks for grad, although not in a separate function.
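For context, a standalone gradcheck would look roughly like the sketch below; it compares analytic against numeric gradients in double precision. `dice_loss_ref` is a hypothetical softmax-plus-dice stand-in for illustration, not the PR's function:

```python
import torch

def dice_loss_ref(inputs, targets, eps=1e-7):
    # Hypothetical stand-in: softmax over the last (class) dimension,
    # then the summed complement of the dice coefficient.
    probs = inputs.softmax(dim=-1)
    one_hot = torch.nn.functional.one_hot(
        targets, num_classes=inputs.shape[-1]
    ).to(inputs.dtype)
    dims = tuple(range(1, inputs.dim()))
    intersection = (probs * one_hot).sum(dims)
    cardinality = (probs + one_hot).sum(dims)
    return (1.0 - (2.0 * intersection + eps) / (cardinality + eps)).sum()

# gradcheck needs double precision and requires_grad on the inputs.
inputs = torch.randn(1, 3, 2, dtype=torch.double, requires_grad=True)
targets = torch.zeros(1, 3, dtype=torch.long)
assert torch.autograd.gradcheck(lambda x: dice_loss_ref(x, targets), (inputs,))
```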
Hi @pri1311, thanks for continuing the work! I had a quick review, although I have not fully validated the tests yet. Note that the test failure is unrelated to this change; you can confirm that by viewing the logs of the failing test (no need to log in to Circle CI). But the lint failure seems to be related to this change. You need to lint the code as mentioned in the contributing.md file (either use pre-commit hooks or lint manually).
I will review this, validate, and get back to you soon!
I'm sorry for the delay in reviewing this PR. I will definitely review it over the weekend. It's been a really hard year for me, but I will get back to this.
I'm unable to match the results with kornia. Let me post what I validated.

This is a modified dice loss from kornia, with the one_hot call commented out and torch.mean replaced with 1.0 - dice_loss.

And this is my verification script. I simply copied your code and named it as

Here is the output. Can you please help me understand why there is a difference? It might be something I'm unaware of, or something that kornia does (by commenting out one_hot I might have made a mistake).
@oke-aditya I haven't given it a detailed look yet, but I think one thing of concern when comparing the two libraries is the input dimensions: torchvision follows (B, N1, N2, ..., C) and kornia follows (B, C, H, W).
My concern is that the results should have matched, at least for cases where we have all zeros. Dice loss is well defined, and adjusting for flexible shapes shouldn't affect the values.
Results don't match because in this piece of code we are specifying the one-hot values as per the torchvision input format (B, N1, N2, ..., C) (also followed by PaddleSeg). But for Kornia we need these values on dimension 1, not the last dimension.
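This layout difference can be seen directly with torch.nn.functional.one_hot, which always appends the class dimension last; getting classes onto dimension 1 (kornia's layout) takes an explicit permute. A small sketch with assumed label values:

```python
import torch
import torch.nn.functional as F

labels = torch.tensor([[0, 2], [1, 0]])           # (B=2, N=2) integer labels
one_hot_last = F.one_hot(labels, num_classes=3)   # (B, N, C): channels-last layout
one_hot_dim1 = one_hot_last.permute(0, 2, 1)      # (B, C, N): classes on dim 1
print(one_hot_last.shape, one_hot_dim1.shape)
```

Comparing against a (B, C, H, W) implementation without this permute would mix up the class and spatial axes, which would explain mismatched results.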
B C N H W is something we assume in all the models as well as the codebase, so I think it would be much better to use it.
Thanks a lot for working on this @pri1311. I will do another round of validation of the results, and then we should be good.
@NicolasHug looks like this missed 0.15. I tried this loss function and it seems to work well for segmentation problems with unbalanced classes. Its inclusion would be great, unless there's other functionality I'm overlooking that implements another dice loss function such that this PR is no longer needed.
Hi, I have added an implementation for dice loss as mentioned in #6435.
Will add tests once the correctness of the function is confirmed.