ClasswiseWrapper double prefixes with self.prefix when enclosed in a MetricCollection #1915

Closed
anivegesana opened this issue Jul 12, 2023 · 2 comments · Fixed by #1918
Labels
bug / fix Something isn't working help wanted Extra attention is needed

Comments

@anivegesana

🐛 Bug

#1866 makes ClasswiseWrapper put self.prefix in front of its per-class keys, so the prefix is applied twice when the wrapper is enclosed in a MetricCollection. This is because MetricCollection already handles prefixing, here:

for key, v in res.items():
    if hasattr(m, "prefix") and m.prefix is not None:
        key = f"{m.prefix}{key}"
    if hasattr(m, "postfix") and m.postfix is not None:
        key = f"{key}{m.postfix}"

but #1866 doesn't account for it.
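To make the mechanism concrete, here is a dependency-free sketch that mimics the two steps. FakeClasswise and fake_collection_compute are hypothetical stand-ins, not torchmetrics APIs: the wrapper bakes its own prefix into every per-class key, and the collection loop quoted above then prepends the same prefix attribute again. The prefix "f1_" is chosen to match the keys reported under Expected behavior.

```python
from typing import Dict, List, Optional


class FakeClasswise:
    """Hypothetical stand-in for a classwise wrapper that stores a public prefix."""

    def __init__(self, labels: List[str], prefix: Optional[str] = None) -> None:
        self.labels = labels
        self.prefix = prefix
        self.postfix = None

    def compute(self, values: List[float]) -> Dict[str, float]:
        # The wrapper already puts its own prefix in front of every class name.
        prefix = self.prefix or ""
        return {f"{prefix}{label}": v for label, v in zip(self.labels, values)}


def fake_collection_compute(
    metrics: Dict[str, FakeClasswise], values: List[float], collection_prefix: str
) -> Dict[str, float]:
    """Hypothetical stand-in for the MetricCollection flattening loop quoted above."""
    out: Dict[str, float] = {}
    for m in metrics.values():
        res = m.compute(values)
        for key, v in res.items():
            # Same checks as the quoted snippet, so the prefix is applied a second time.
            if hasattr(m, "prefix") and m.prefix is not None:
                key = f"{m.prefix}{key}"
            if hasattr(m, "postfix") and m.postfix is not None:
                key = f"{key}{m.postfix}"
            # The collection's own prefix ("val/") goes on last.
            out[f"{collection_prefix}{key}"] = v
    return out


result = fake_collection_compute(
    {"f1": FakeClasswise(["Tree", "Bush"], prefix="f1_")}, [0.0, 0.0], "val/"
)
print(sorted(result))  # ['val/f1_f1_Bush', 'val/f1_f1_Tree']
```

Any metric that both returns already-prefixed keys and exposes a public prefix attribute ends up prefixed twice under this loop.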

To Reproduce

Enclose a ClasswiseWrapper with a prefix within a MetricCollection.

The script below computes a multiclass accuracy and a classwise F1 score:
from torchmetrics import Accuracy, ClasswiseWrapper, F1Score, MetricCollection
import torch

category_names = ['Tree', 'Bush']
num_classes = len(category_names)

input_ = torch.rand((5, num_classes, 3, 3))
target = torch.ones((5, num_classes, 3, 3)).long()


val_metrics = MetricCollection(
    {
        "acc": Accuracy(task="multiclass", num_classes=num_classes),
        "f1": ClasswiseWrapper(
            F1Score(
                task="multiclass",
                num_classes=num_classes,
                average="none",
                dist_sync_on_step=True,
            ),
            category_names,
            prefix="f1_",
        ),
    },
    prefix="val/",
)

val_metrics(input_, target)

Expected behavior

I should get {'val/acc': tensor(0.), 'val/f1_Tree': tensor(0.), 'val/f1_Bush': tensor(0.)}. I instead get {'val/acc': tensor(0.), 'val/f1_f1_Tree': tensor(0.), 'val/f1_f1_Bush': tensor(0.)}.

Environment

  • torchmetrics 1.0.0 via pip
  • Python 3.10.6 and PyTorch 1.12
  • OS: Linux
@anivegesana anivegesana added bug / fix Something isn't working help wanted Extra attention is needed labels Jul 12, 2023
@github-actions

Hi! Thanks for your contribution! Great first issue!

@anivegesana
Author

Perhaps adding a flag to Metric subclasses that specifies whether a MetricCollection should prefix their keys, such as _key_names_fix = False, could fix the issue.
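A minimal sketch of how such an opt-out flag could work, assuming the collection reads it with getattr and defaults to True so metrics that never define it keep today's behaviour. ClasswiseLike, LegacyDictMetric and collection_compute are hypothetical stand-ins, not torchmetrics classes, and this illustrates the proposal only; it is not necessarily what the fix in #1918 does.

```python
from typing import Any, Dict, List, Optional


class ClasswiseLike:
    """Hypothetical classwise wrapper that opts out of collection-level renaming."""

    # Proposed flag: False tells the collection that the keys returned by
    # compute() are already final and must not be prefixed again.
    _key_names_fix = False

    def __init__(self, labels: List[str], prefix: Optional[str] = None) -> None:
        self.labels = labels
        self.prefix = prefix
        self.postfix = None

    def compute(self) -> Dict[str, float]:
        prefix = self.prefix or ""
        return {f"{prefix}{label}": 0.0 for label in self.labels}


class LegacyDictMetric:
    """Hypothetical dict-returning metric without the flag (current behaviour)."""

    def __init__(self, prefix: Optional[str] = None) -> None:
        self.prefix = prefix
        self.postfix = None

    def compute(self) -> Dict[str, float]:
        return {"score": 1.0}


def collection_compute(metrics: Dict[str, Any], collection_prefix: str) -> Dict[str, float]:
    """Hypothetical MetricCollection loop that honours the proposed flag."""
    out: Dict[str, float] = {}
    for m in metrics.values():
        # Metrics that do not define the flag default to True (prefix as before).
        fix_keys = getattr(m, "_key_names_fix", True)
        for key, v in m.compute().items():
            if fix_keys and getattr(m, "prefix", None) is not None:
                key = f"{m.prefix}{key}"
            if fix_keys and getattr(m, "postfix", None) is not None:
                key = f"{key}{m.postfix}"
            out[f"{collection_prefix}{key}"] = v
    return out


result = collection_compute(
    {"f1": ClasswiseLike(["Tree", "Bush"], prefix="f1_"), "legacy": LegacyDictMetric(prefix="p_")},
    "val/",
)
print(sorted(result))  # ['val/f1_Bush', 'val/f1_Tree', 'val/p_score']
```

Defaulting the flag to True keeps the change backward compatible: only wrappers that build their own keys need to set it.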
