Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added small eps to dice and iou to avoid NaN #2545

Merged
merged 1 commit into from
Jul 9, 2020

Conversation

bernardomig
Copy link
Contributor

The code fixes a bug calculating both the IoU and the Dice metrics, which resulted in a nan in the calculation. This was due to not including a small eps in the division on both these metrics. I have added a eps (equal to 1e-15), which successfully prevents this NaN calculation.

Example

Previous Code (with bug)

import torch
from pytorch_lightning.metrics.functional import iou

y = torch.tensor([0, 1, 2, 3, 3])
y_true = torch.tensor([0, 1, 2, 2, 2])

# reduction = 'none' so we can see the iou for each class
out = iou(y, y_true, num_classes=5, reduction='none')

print(out)
>>> tensor([1.0000, 1.0000, 0.3333, 0.0000,    nan])

The fifth class has a IoU of nan, not 0.

Fixed code

import torch
from pytorch_lightning.metrics.functional import iou

y = torch.tensor([0, 1, 2, 3, 3])
y_true = torch.tensor([0, 1, 2, 2, 2])

# reduction = 'none' so we can see the iou for each class
out = iou(y, y_true, num_classes=5, reduction='none')

print(out)
>>> tensor([1.0000, 1.0000, 0.3333, 0.0000, 0.0000])

@mergify mergify bot requested a review from a team July 7, 2020 22:21
@codecov
Copy link

codecov bot commented Jul 7, 2020

Codecov Report

Merging #2545 into master will not change coverage.
The diff coverage is 100%.

@@          Coverage Diff           @@
##           master   #2545   +/-   ##
======================================
  Coverage      90%     90%           
======================================
  Files          69      69           
  Lines        5669    5669           
======================================
  Hits         5077    5077           
  Misses        592     592           

@@ -891,7 +891,7 @@ def dice_score(

tp, fp, tn, fn, sup = stat_scores(pred=pred, target=target, class_index=i)

denom = (2 * tp + fp + fn).to(torch.float)
denom = (2 * tp + fp + fn + 1e-15).to(torch.float)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would rather add if tp + fp + fn if greater then 0

Copy link
Member

@Borda Borda Jul 9, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

in fact

1 / (1 + 1e-15)
0.9999999999999989

cc: @justusschock @SkafteNicki

@mergify mergify bot requested a review from a team July 7, 2020 22:33
@williamFalcon williamFalcon merged commit 9a367a8 into Lightning-AI:master Jul 9, 2020
@Borda Borda added the bug Something isn't working label Jul 9, 2020
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
bug Something isn't working
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants