Improving Classwise Metrics Logging #815

Closed
cemde opened this issue Jan 29, 2022 · 3 comments · Fixed by #832
Labels
enhancement New feature or request



cemde commented Jan 29, 2022

🚀 Feature

Simplify logging of classwise metrics.

Motivation

Often, we are interested not only in averaged metrics but also in class-wise metrics. These cannot be used well in MetricCollections, because the PyTorch Lightning logging mechanism does not accept non-scalar values.

This code, which computes class-wise recalls:

```python
import torchmetrics

m = torchmetrics.MetricCollection({'acc': torchmetrics.Accuracy(),
                                   'recall': torchmetrics.Recall(average='none', num_classes=5)})
```

yields this error when the MetricCollection is computed and its results are logged through PyTorch Lightning:

The metric `tensor([1.0000, 0.0000, 0.0000, 0.2500, 0.3000])` does not contain a single element, thus it cannot be converted to a scalar.

It would be very nice if the MetricCollection returned a dictionary like this one:

```python
{'acc': torch.tensor([0.3]), 'recall_0': torch.tensor([1.0]), 'recall_1': torch.tensor([0.0]), 'recall_2': torch.tensor([0.0]), 'recall_3': torch.tensor([0.25]), 'recall_4': torch.tensor([0.3])}
```
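A minimal sketch of the requested output, assuming each metric result is a plain Python float or a list of per-class floats (torch tensors are left out so the idea stands on its own): every non-scalar value is split into one scalar entry per class index, which is the shape a scalar-only logger can accept. `split_per_class` is a hypothetical helper, not torchmetrics API.

```python
from typing import Dict, List, Union

# Assumption: results are plain floats or lists of per-class floats
Value = Union[float, List[float]]


def split_per_class(results: Dict[str, Value]) -> Dict[str, float]:
    """Hypothetical helper: replace each per-class list with one scalar per class index."""
    flat: Dict[str, float] = {}
    for name, value in results.items():
        if isinstance(value, list):
            # e.g. 'recall': [1.0, 0.0] -> 'recall_0': 1.0, 'recall_1': 0.0
            for idx, v in enumerate(value):
                flat[f"{name}_{idx}"] = v
        else:
            # Already a scalar, so it can be logged as-is
            flat[name] = value
    return flat


print(split_per_class({"acc": 0.3, "recall": [1.0, 0.0, 0.0, 0.25, 0.3]}))
# {'acc': 0.3, 'recall_0': 1.0, 'recall_1': 0.0, 'recall_2': 0.0, 'recall_3': 0.25, 'recall_4': 0.3}
```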

Pitch

b) Introduce a ClasswiseWrapper and adapt the MetricCollection to unpack dictionaries:

```python
from typing import Any, Dict, List, Optional

import torch
import torch.nn.functional as F
import torchmetrics


class ClasswiseWrapper(torchmetrics.Metric):
    """Wraps a metric computed with average='none' and returns {class_label: value}."""

    def __init__(self, metric: torchmetrics.Metric, class_labels: Optional[List[str]] = None) -> None:
        super().__init__()
        self.metric = metric
        self.class_labels = class_labels

    def _convert(self, x: torch.Tensor) -> Dict[str, float]:
        # Fall back to class indices as names when no labels are given
        labels = self.class_labels if self.class_labels is not None else [str(i) for i in range(len(x))]
        return {name: val for name, val in zip(labels, x.tolist())}

    def forward(self, *args: Any, **kwargs: Any) -> Dict[str, float]:
        return self._convert(self.metric(*args, **kwargs))

    def update(self, *args: Any, **kwargs: Any) -> None:
        # update() only accumulates state and returns None, so there is nothing to convert
        self.metric.update(*args, **kwargs)

    def compute(self) -> Dict[str, float]:
        return self._convert(self.metric.compute())

    def reset(self) -> None:
        # The state lives in the wrapped metric, so reset that one
        self.metric.reset()


class_labels = ['house', 'boat', 'horse', 'computer', 'airplane']
# NOTE: `formatting` is the argument proposed in this issue, not an existing MetricCollection argument
m = torchmetrics.MetricCollection({'acc': torchmetrics.Accuracy(),
                                   'recall': ClasswiseWrapper(torchmetrics.Recall(average='none', num_classes=5), class_labels)},
                                  formatting="{metric}_{class_label}")

y_pred = F.softmax(torch.randn((8, 5)), dim=1)
y_true = torch.tensor([0, 1, 2, 3, 4, 0, 1, 2])

m(y_pred, y_true)

print(m.compute())
```

This currently produces a nested dict. It would be great if the MetricCollection could unpack it and return:

```python
{'acc': torch.tensor([0.3]), 'recall_house': torch.tensor([1.0]), 'recall_boat': torch.tensor([0.0]), 'recall_horse': torch.tensor([0.0]), 'recall_computer': torch.tensor([0.25]), 'recall_airplane': torch.tensor([0.3])}
```
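The unpacking step could look roughly like this, assuming the wrapper returns a plain `{class_label: value}` dict and the collection applies a format string such as the proposed `"{metric}_{class_label}"`. `flatten_with_format` is a hypothetical helper sketched in plain Python, not part of MetricCollection.

```python
from typing import Dict, Union

# Assumption: a result is either a scalar or a {class_label: value} dict
Result = Union[float, Dict[str, float]]


def flatten_with_format(
    results: Dict[str, Result], formatting: str = "{metric}_{class_label}"
) -> Dict[str, float]:
    """Hypothetical helper: unpack one level of nested dicts into keys built from `formatting`."""
    flat: Dict[str, float] = {}
    for metric, value in results.items():
        if isinstance(value, dict):
            # Class-wise result: emit one key per class label
            for class_label, v in value.items():
                flat[formatting.format(metric=metric, class_label=class_label)] = v
        else:
            # Scalar result: keep the metric name unchanged
            flat[metric] = value
    return flat


nested = {
    "acc": 0.3,
    "recall": {"house": 1.0, "boat": 0.0, "horse": 0.0, "computer": 0.25, "airplane": 0.3},
}
print(flatten_with_format(nested))
# {'acc': 0.3, 'recall_house': 1.0, 'recall_boat': 0.0, 'recall_horse': 0.0, 'recall_computer': 0.25, 'recall_airplane': 0.3}
```

Only one level is unpacked here, which matches a wrapper whose output is a flat `{class_label: value}` dict; the format string lets users pick a different key layout.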

Alternatives

For each metric allow:

```python
torchmetrics.Metric(average='none', class_label=['house', 'boat', 'horse', 'computer', 'airplane'])
```

and return a dictionary with keys of the form <metric>_<class_label>, which the MetricCollection then unpacks.
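For this alternative, a metric that receives the class labels would build the <metric>_<class_label> keys itself. A rough plain-Python sketch, assuming integer predictions and targets given as lists; `classwise_recall` and its `class_labels` parameter are hypothetical and not torchmetrics API.

```python
from typing import Dict, List


def classwise_recall(preds: List[int], target: List[int], class_labels: List[str]) -> Dict[str, float]:
    """Hypothetical metric: per-class recall keyed as 'recall_<class_label>'."""
    result: Dict[str, float] = {}
    for idx, label in enumerate(class_labels):
        # True positives for class idx, and all samples whose true class is idx
        tp = sum(1 for p, t in zip(preds, target) if t == idx and p == idx)
        support = sum(1 for t in target if t == idx)
        # Assumption: classes absent from `target` report 0.0 (other conventions exist)
        result[f"recall_{label}"] = tp / support if support else 0.0
    return result


print(classwise_recall([0, 1, 1, 2], [0, 1, 2, 2], ["house", "boat", "horse"]))
# {'recall_house': 1.0, 'recall_boat': 1.0, 'recall_horse': 0.5}
```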

Additional context

When researching long-tail distributions or working with medical datasets, which naturally have very long-tailed class distributions, class-wise metrics are key. In the medical ML community, hardly anyone cares about overall accuracy; what matters are class-wise sensitivity and specificity. However, logging these is tedious, and improvements here would make torchmetrics more usable for these communities.
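Class-wise sensitivity and specificity are one-vs-rest quantities: for each class c, sensitivity is TP / (TP + FN) and specificity is TN / (TN + FP). A small plain-Python sketch that returns them already flattened into scalar entries; the function name and the <metric>_<class_label> key scheme are assumptions borrowed from the proposal above, not an existing API.

```python
from typing import Dict, List


def sensitivity_specificity(preds: List[int], target: List[int], class_labels: List[str]) -> Dict[str, float]:
    """Hypothetical helper: one-vs-rest sensitivity and specificity per class, as flat scalars."""
    out: Dict[str, float] = {}
    pairs = list(zip(preds, target))
    for c, label in enumerate(class_labels):
        # One-vs-rest confusion counts for class c (booleans sum as 0/1)
        tp = sum(p == c and t == c for p, t in pairs)
        fn = sum(p != c and t == c for p, t in pairs)
        fp = sum(p == c and t != c for p, t in pairs)
        tn = sum(p != c and t != c for p, t in pairs)
        # Assumption: an undefined ratio (empty denominator) is reported as 0.0
        out[f"sensitivity_{label}"] = tp / (tp + fn) if tp + fn else 0.0
        out[f"specificity_{label}"] = tn / (tn + fp) if tn + fp else 0.0
    return out


print(sensitivity_specificity([0, 1, 1, 2], [0, 1, 2, 2], ["a", "b", "c"]))
```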

cemde added the enhancement (New feature or request) label Jan 29, 2022
Borda (Member) commented Jan 30, 2022

@SkafteNicki this may be a nice addition to the smart update, right? 🐰

SkafteNicki (Member) commented

@Borda let's just make this a general wrapper, so that it can also be used outside metric collections.

cemde (Author) commented Feb 8, 2022

> @Borda let's just make this a general wrapper, so that it can also be used outside metric collections.

My suggestion above is largely independent of the MetricCollection: the wrapper does most of the extra work, and the only thing the MetricCollection needs to do is flatten a nested dictionary. This way, anyone could also add their own wrappers and go wild, as long as they output a dict.
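If the collection's only job is to flatten whatever dict a wrapper returns, a recursive flatten would also cover custom wrappers that nest further (for example per-class results inside per-split results). A minimal sketch under that assumption; `flatten_nested` and the `_` separator are illustrative choices, not existing torchmetrics behaviour.

```python
from typing import Any, Dict


def flatten_nested(results: Dict[str, Any], sep: str = "_", prefix: str = "") -> Dict[str, Any]:
    """Hypothetical helper: recursively join nested dict keys with `sep` until only leaves remain."""
    flat: Dict[str, Any] = {}
    for key, value in results.items():
        name = f"{prefix}{sep}{key}" if prefix else str(key)
        if isinstance(value, dict):
            # Descend into any dict a wrapper returned, however deep
            flat.update(flatten_nested(value, sep=sep, prefix=name))
        else:
            flat[name] = value
    return flat


# Output of a hypothetical custom wrapper with per-split, per-class values
out = {"acc": 0.3, "recall": {"val": {"house": 1.0, "boat": 0.0}}}
print(flatten_nested(out))
# {'acc': 0.3, 'recall_val_house': 1.0, 'recall_val_boat': 0.0}
```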
