
ignore_index for F1 doesn't behave as expected. #613

Closed · yukw777 opened this issue Nov 9, 2021 · 7 comments · Fixed by #1195
Labels: bug / fix (Something isn't working), help wanted (Extra attention is needed)

Comments

@yukw777 (Contributor) commented Nov 9, 2021

🐛 Bug

F1 doesn't ignore indices properly.

To Reproduce

Run the following code.

import torch
from torchmetrics import F1

f1 = F1(ignore_index=0)
f1(torch.tensor([1, 1, 1, 1, 2, 1, 1]), torch.tensor([0, 0, 1, 1, 2, 0, 0]))

This returns tensor(0.6000), not the expected tensor(1.).

Expected behavior

The specified ignore_index should not count towards the F1 score. For example, the above code example should be effectively equivalent to the following:

import torch
from torchmetrics import F1

f1 = F1()
f1(torch.tensor([1, 1, 2]), torch.tensor([1, 1, 2]))
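Both calls should then return tensor(1.), since after dropping the ignored positions every remaining prediction matches its target.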

Environment

  • PyTorch Version (e.g., 1.0): 1.10
  • OS (e.g., Linux): Linux
  • How you installed PyTorch (conda, pip, source): pip
  • Build command you used (if compiling from source): N/A
  • Python version: 3.8
  • CUDA/cuDNN version: N/A
  • GPU models and configuration: N/A
  • Any other relevant information:
@yukw777 added the bug / fix and help wanted labels on Nov 9, 2021
@github-actions bot commented Nov 9, 2021

Hi! Thanks for your contribution, great first issue!

@yukw777 (Contributor, Author) commented Nov 9, 2021

https://github.com/PyTorchLightning/metrics/blob/fef83a028880f2ad3e0c265b3e5bb8a184805798/torchmetrics/functional/classification/stat_scores.py#L132-L143

It seems like if reduce is not macro, it just deletes that particular column, which I don't think is the right thing to do. Rather, we should delete the "rows" whose value is ignore_index, right?
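A quick way to see where the 0.6000 comes from under this column-deletion behavior (a minimal sketch of the micro-averaged arithmetic, not the library's actual code):

import torch

preds = torch.tensor([1, 1, 1, 1, 2, 1, 1])
target = torch.tensor([0, 0, 1, 1, 2, 0, 0])

# Deleting the class-0 column removes class 0 from the statistics, but the
# four samples whose target is 0 stay in, so their class-1 predictions
# count as false positives for class 1.
tp = ((preds == target) & (target != 0)).sum()  # 3: two correct 1s, one correct 2
fp = ((preds != target) & (preds != 0)).sum()   # 4: the 1s predicted where target is 0
fn = ((preds != target) & (target != 0)).sum()  # 0

precision = tp / (tp + fp)  # 3/7
recall = tp / (tp + fn)     # 1.0
print(2 * precision * recall / (precision + recall))  # tensor(0.6000)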

@Hommus commented Jan 7, 2022

I've had a similar experience using ignore_index with IoU (Jaccard Index): the IoU value starts at 100.00 and tends towards 0 as training progresses.

@tchayintr commented Jan 7, 2022

I am not sure about the original intention of ignore_index, but to get the expected behavior for now, I make some modifications to the tensors before passing them to torchmetrics.F1():

import torch
from torchmetrics import F1

ignore_index = 0
y = torch.tensor([0, 0, 1, 1, 2, 0, 0])
y_hat = torch.tensor([1, 1, 1, 1, 2, 1, 1])

# Overwrite predictions at the ignored positions with ignore_index,
# so those positions can never be counted as errors.
inactive_index = y == ignore_index
y_hat[inactive_index] = ignore_index

f1 = F1(ignore_index=0)
f1(y_hat, y)  # tensor(1.)
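An alternative sketch (not from the thread) that avoids mutating predictions is to drop the ignored positions entirely before calling the metric, which mirrors the "effectively equivalent" example from the issue description:

import torch
from torchmetrics import F1

ignore_index = 0
y = torch.tensor([0, 0, 1, 1, 2, 0, 0])
y_hat = torch.tensor([1, 1, 1, 1, 2, 1, 1])

# Keep only the positions whose target is not the ignored index.
keep = y != ignore_index

f1 = F1()
f1(y_hat[keep], y[keep])  # same as f1(torch.tensor([1, 1, 2]), torch.tensor([1, 1, 2]))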

@Borda (Member) commented Jan 19, 2022

@tchayintr would you be interested in sending a PR? @stancld may help if needed. 🐰

@tchayintr commented

@Borda Sure.
Let me thoroughly review the code, particularly metrics/torchmetrics/functional/classification/stat_scores.py, before considering possibilities and drafting a PR.

@SkafteNicki SkafteNicki added this to the v0.9 milestone Mar 23, 2022
@SkafteNicki SkafteNicki modified the milestones: v0.9, v0.10 May 12, 2022
@SkafteNicki (Member) commented
This issue will be fixed by the classification refactor: see issue #1001 and PR #1195 for all changes.

Small recap: this issue reports that the ignore_index argument currently gives the wrong result in the f1_score metric, due to how ignore_index samples are accounted for. In the refactor the code has been changed to correctly ignore those samples; see the example below using the new multiclass_f1_score function:

from torchmetrics.functional import multiclass_f1_score
import torch

preds = torch.tensor([1, 1, 1, 1, 2, 1, 1])
target = torch.tensor([0, 0, 1, 1, 2, 0, 0])

multiclass_f1_score(preds, target, num_classes=3, average="micro", ignore_index=0)  # tensor(1.)

which gives the correct result.
The issue will be closed when #1195 is merged.
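For reference, the class-based counterpart from the same refactor behaves the same way (a sketch assuming torchmetrics >= 0.10, where the wrappers live in torchmetrics.classification):

import torch
from torchmetrics.classification import MulticlassF1Score

metric = MulticlassF1Score(num_classes=3, average="micro", ignore_index=0)
metric(torch.tensor([1, 1, 1, 1, 2, 1, 1]), torch.tensor([0, 0, 1, 1, 2, 0, 0]))  # tensor(1.)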
