MulticlassJaccardIndex returns 0 for classes not present in prediction and target #1693

Closed · JohannesK14 opened this issue Apr 5, 2023 · 3 comments · Fixed by #1821
Labels: bug / fix (Something isn't working), help wanted (Extra attention is needed)

🐛 Bug

MulticlassJaccardIndex returns 0 for classes that are present in neither the prediction nor the target. This drags the score down when used in conjunction with average="macro".

There was a similar issue when the metrics were still part of PyTorch Lightning: Lightning-AI/pytorch-lightning#3098

To Reproduce

MWE:

import torch
from torchmetrics.classification import MulticlassJaccardIndex

# edge case: class 2 is not present in the target AND the prediction
target = torch.tensor([0, 1, 0, 0])
preds = torch.tensor([0, 1, 0, 1])

metric = MulticlassJaccardIndex(num_classes=3, average="none")
print(metric(preds, target))
# tensor([0.6667, 0.5000, 0.0000])
# I would expect tensor([0.6667, 0.5000, 1.0000])

metric = MulticlassJaccardIndex(num_classes=3, average="macro")
print(metric(preds, target))
# tensor(0.3889)

Second example (adapted from Lightning-AI/pytorch-lightning#3098):

import torch
from torchmetrics.functional.classification.jaccard import multiclass_jaccard_index

target = torch.tensor([0, 1])
pred = torch.tensor([0, 1])
print(multiclass_jaccard_index(pred, target, num_classes=10))  # Should return tensor(1.)
print(
    multiclass_jaccard_index(pred, target, num_classes=10, average="none")
)  # Should return tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])

Expected behavior

See comments in the MWEs.
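
For reference, a minimal sketch of how the expected values from the first MWE can be derived by hand (my own workaround, not part of the torchmetrics API): compute the per-class IoU from the confusion matrix and treat classes that occur in neither preds nor target (0/0) as 1.0 instead of 0.0.

import torch
from torchmetrics.classification import MulticlassConfusionMatrix

target = torch.tensor([0, 1, 0, 0])
preds = torch.tensor([0, 1, 0, 1])

# rows = target, columns = prediction
confmat = MulticlassConfusionMatrix(num_classes=3)(preds, target)
tp = confmat.diag()
fp = confmat.sum(dim=0) - tp  # column sum minus diagonal
fn = confmat.sum(dim=1) - tp  # row sum minus diagonal
denom = tp + fp + fn
iou = tp / denom
iou[denom == 0] = 1.0  # class absent from both preds and target
print(iou)         # tensor([0.6667, 0.5000, 1.0000])
print(iou.mean())  # tensor(0.7222), macro IoU with the absent class counted as 1.0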

Environment

  • Python 3.10.9 (conda)
  • PyTorch 1.13.1 (conda)
  • torchmetrics 0.11.4 (pip)
JohannesK14 added the bug / fix and help wanted labels on Apr 5, 2023

github-actions bot commented Apr 5, 2023

Hi! Thanks for your contribution, great first issue!

JohannesK14 (Author) commented

sklearn raises a warning in the described case: https://scikit-learn.org/stable/modules/generated/sklearn.metrics.jaccard_score.html
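
For illustration, a small example with the same data as in the first MWE (passing labels=[0, 1, 2] so that sklearn also scores the absent class 2):

import warnings

from sklearn.metrics import jaccard_score

target = [0, 1, 0, 0]
preds = [0, 1, 0, 1]

# class 2 has no true and no predicted samples, so its score is ill-defined;
# sklearn emits an UndefinedMetricWarning and fills in 0.0 (recent versions
# expose a zero_division argument to control the fill value)
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    per_class = jaccard_score(target, preds, labels=[0, 1, 2], average=None)

print(per_class)  # [0.66666667 0.5        0.        ]
print([str(w.message) for w in caught])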

JohannesK14 (Author) commented

I thought a bit more about this problem, and I think it is questionable whether we should simply set the IoU to 1.0 for unseen classes.

I found that the JaccardIndex implementation internally updates a confusion matrix. The attached code shows the calculations for two "batches", where class 2 is missing in the first.

As one can see, the default value of 0.0 when a class is missing leads to a worse macro IoU.

BUT calling compute() after all batches recalculates the IoU values from the confusion matrix accumulated over all batches. In my opinion, this yields the correct IoU values; a warning or error should only be raised if a class was missing from all batches.

Nevertheless, a problem arises in combination with the Lightning logging process, which averages the per-batch macro IoU values instead of calling a final compute() on the confusion matrix accumulated over all batches.

Personally, I no longer use self.log in the validation/test step; I only update the Jaccard metric there. At the end of the validation/test epoch I call compute() once, which uses the confusion matrix accumulated over all batches, and then log that value (see the sketch after the code below).

import torch
from torchmetrics.classification import MulticlassJaccardIndex

m_none = MulticlassJaccardIndex(num_classes=3, average="none")
m_macro = MulticlassJaccardIndex(num_classes=3, average="macro")

# edge case: class 2 is not present in the target AND the prediction
target = torch.tensor([0, 1, 0, 0, 0, 0])
preds = torch.tensor([0, 1, 0, 1, 1, 0])

# confusion matrix (CONSIDERS ONLY THE "BATCH"; rows = target, columns = prediction):
# 3 2 0  -> iou = 3 / (3 + 2 + 0) = 3 / 5 = 0.6
# 0 1 0  -> iou = 1 / (2 + 1 + 0) = 1 / 3 = 0.33
# 0 0 0  -> iou = 0 / (0 + 0 + 0) = 0 / 0 = nan (is set to 0)

print(f"Batch 1: iou per class: {m_none(preds, target)}")
batch_1_mean_iou = m_macro(preds, target)
print(f"Batch 1: mean iou: {batch_1_mean_iou}")

target = torch.tensor([0, 1, 2, 0, 2, 2])
preds = torch.tensor([0, 1, 0, 1, 2, 1])

# confusion matrix (CONSIDERS ONLY THE "BATCH"):
# 1 1 0 -> iou = 1 / (1 + 1 + 1) = 1 / 3 = 0.33
# 0 1 0 -> iou = 1 / (1 + 1 + 1) = 1 / 3 = 0.33
# 1 1 1 -> iou = 1 / (1 + 1 + 1) = 1 / 3 = 0.33

# IMPORTANT: the confusion matrix gets updated during the following forward calls
# printed values are only for the "batch"

print(f"Batch 2: iou per class: {m_none(preds, target)}")
batch_2_mean_iou = m_macro(preds, target)
print(f"Batch 2: mean iou: {batch_2_mean_iou}")

# confusion matrix (CONSIDERS ALL THE "BATCHES") -> element-wise sum
# 4 3 0 -> iou = 4 / (4 + 3 + 1) = 4 / 8 = 0.5
# 0 2 0 -> iou = 2 / (3 + 2 + 1) = 2 / 6 = 0.33
# 1 1 1 -> iou = 1 / (1 + 1 + 1) = 1 / 3 = 0.33

print(f"Mean IoU: {(batch_1_mean_iou + batch_2_mean_iou) / 2}")
# With 0.0 as default IoU for unseen classes

print("Mean IoU with 1.0 as default IoU for unseen classes: 0.485")

print(f"IoU overall: {m_macro.compute()}")

print("Mean IoU if we ignore the IoU for unseen classes: 0.3975")
