Improving Classwise Metrics Logging #815
@SkafteNicki, this may be a nice addition to the smart update, right? 🐰
@Borda, let's just make this a general wrapper so that it can also be used outside metric collections.
My suggestion above is fairly independent of the `MetricCollection`. The only thing the `MetricCollection` needs to do is flatten a nested dictionary; the wrapper does most of the extra work. This way, anyone can also add their own wrappers and go wild, as long as they output a dict?!
🚀 Feature
Simplify logging of classwise metrics.
Motivation
Often, we are interested not only in averaged metrics but also in classwise metrics. These do not work well in a `MetricCollection`, because the PyTorch Lightning logging mechanism does not accept non-scalar values.
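To make the limitation concrete: a classwise metric produces one value per class, and a logger that only accepts scalars needs those values split into separate keys. The helper below is a minimal, hypothetical sketch in plain Python (not a torchmetrics or Lightning API). It assumes the per-class values are already plain numbers (for example via a tensor's `.tolist()`) and uses the `<metric>_<class_label>` naming proposed in this issue.

```python
from typing import Dict, Optional, Sequence


def classwise_to_dict(
    name: str,
    values: Sequence[float],
    labels: Optional[Sequence[str]] = None,
) -> Dict[str, float]:
    """Split one per-class result into scalar entries keyed <metric>_<class_label>.

    Hypothetical helper: falls back to integer class indices when no labels are given.
    """
    if labels is None:
        labels = [str(i) for i in range(len(values))]
    if len(labels) != len(values):
        raise ValueError("need exactly one label per class value")
    # One plain float per key, which a scalar-only logger can accept.
    return {f"{name}_{label}": float(v) for label, v in zip(labels, values)}


print(classwise_to_dict("recall", [0.5, 0.75, 1.0], labels=["cat", "dog", "bird"]))
# {'recall_cat': 0.5, 'recall_dog': 0.75, 'recall_bird': 1.0}
```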
This code, which computes classwise recalls:
yields this error when the `MetricCollection` is used together with the PyTorch Lightning logging capabilities:
It would be very nice if the `MetricCollection` returned a dictionary like this one:
Pitch
b) Introduce a `ClasswiseWrapper` and adapt the `MetricCollection` to unpack dictionaries:
This currently leads to a nested dict, but if the `MetricCollection` could unpack this dict to return a flat dictionary, that would be great.
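A minimal sketch of the division of work suggested in the comments: the wrapper turns a per-class result into a dict, and the collection only flattens it. Everything here is hypothetical plain Python written for illustration. `ToyClasswiseRecall` stands in for a real metric, and `ClasswiseWrapper` and `flatten_dict` are not torchmetrics' own implementations; a real version would subclass `torchmetrics.Metric` and operate on tensors.

```python
from typing import Any, Dict, List, Optional, Sequence


class ToyClasswiseRecall:
    """Stand-in metric: accumulates per-class recall = TP / (TP + FN)."""

    def __init__(self, num_classes: int) -> None:
        self.tp = [0] * num_classes
        self.support = [0] * num_classes  # TP + FN, i.e. number of targets per class

    def update(self, preds: Sequence[int], target: Sequence[int]) -> None:
        for p, t in zip(preds, target):
            self.support[t] += 1
            if p == t:
                self.tp[t] += 1

    def compute(self) -> List[float]:
        # Non-scalar result: one value per class (0.0 for classes never seen).
        return [tp / n if n else 0.0 for tp, n in zip(self.tp, self.support)]


class ClasswiseWrapper:
    """Hypothetical wrapper: returns {<metric>_<class_label>: value} instead of a list."""

    def __init__(self, metric: Any, name: str, labels: Optional[Sequence[str]] = None) -> None:
        self.metric = metric
        self.name = name
        self.labels = labels

    def update(self, *args: Any, **kwargs: Any) -> None:
        self.metric.update(*args, **kwargs)

    def compute(self) -> Dict[str, float]:
        values = self.metric.compute()
        labels = self.labels or [str(i) for i in range(len(values))]
        return {f"{self.name}_{label}": float(v) for label, v in zip(labels, values)}


def flatten_dict(results: Dict[str, Any]) -> Dict[str, float]:
    """Collection side: lift entries of nested dicts to the top level."""
    flat: Dict[str, float] = {}
    for key, value in results.items():
        if isinstance(value, dict):
            flat.update(flatten_dict(value))  # inner keys already carry the metric name
        else:
            flat[key] = value
    return flat


wrapped = ClasswiseWrapper(ToyClasswiseRecall(num_classes=3), name="recall")
wrapped.update(preds=[0, 1, 1, 2], target=[0, 1, 2, 2])
nested = {"recall": wrapped.compute(), "accuracy": 0.75}
print(flatten_dict(nested))
# {'recall_0': 1.0, 'recall_1': 1.0, 'recall_2': 0.5, 'accuracy': 0.75}
```

Because the wrapper already prefixes every key with the metric name, the collection can flatten by lifting nested entries one level up and dropping the outer key. This sketch silently overwrites duplicate keys; a real implementation would need a policy for such collisions.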
Alternatives
For each metric, allow:
and return a dictionary with keys of the form `<metric>_<class_label>`, which then gets unpacked by the `MetricCollection`.
Additional context
When researching long-tailed distributions or working with medical datasets, which naturally have very long-tailed class distributions, classwise metrics are key. No one in the medical ML community cares about the overall accuracy; what they care about is classwise sensitivity and specificity. However, logging these is tedious. Improvements here would make `torchmetrics` more usable for these communities.
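For comparison with the pitch, the alternative (each metric optionally returning `<metric>_<class_label>` keys itself, which the collection then unpacks) might be sketched as below, using classwise specificity from the medical use case. `ToyClasswiseSpecificity`, its `as_dict` flag, and `unpack_results` are hypothetical names in plain Python, not existing torchmetrics arguments or functions. The collection side raises on non-scalar results and on duplicate keys, so a misconfigured metric fails loudly instead of reaching the logger.

```python
from typing import Dict, List, Optional, Sequence, Tuple, Union

# A metric result is either a scalar, a per-class list, or a flat dict of scalars.
Result = Union[float, List[float], Dict[str, float]]


class ToyClasswiseSpecificity:
    """Stand-in metric: per-class specificity = TN / (TN + FP) from hard labels.

    `as_dict` is a hypothetical flag illustrating the proposal, not a torchmetrics argument.
    """

    def __init__(
        self, num_classes: int, as_dict: bool = False, labels: Optional[Sequence[str]] = None
    ) -> None:
        self.num_classes = num_classes
        self.as_dict = as_dict
        self.labels = list(labels) if labels else [str(i) for i in range(num_classes)]
        self.pairs: List[Tuple[int, int]] = []  # (pred, target) pairs seen so far

    def update(self, preds: Sequence[int], target: Sequence[int]) -> None:
        self.pairs.extend(zip(preds, target))

    def compute(self) -> Union[List[float], Dict[str, float]]:
        values = []
        for c in range(self.num_classes):
            # One-vs-rest counts for class c; 0.0 when undefined (arbitrary convention).
            tn = sum(1 for p, t in self.pairs if p != c and t != c)
            fp = sum(1 for p, t in self.pairs if p == c and t != c)
            values.append(tn / (tn + fp) if tn + fp else 0.0)
        if not self.as_dict:
            return values  # non-scalar: a scalar-only logger cannot take this
        return {f"specificity_{label}": v for label, v in zip(self.labels, values)}


def unpack_results(results: Dict[str, Result]) -> Dict[str, float]:
    """Collection side: keep scalars, unpack dict results, reject lists and duplicate keys."""
    flat: Dict[str, float] = {}
    for name, value in results.items():
        entries = value.items() if isinstance(value, dict) else [(name, value)]
        for key, v in entries:
            if isinstance(v, (list, tuple)):
                raise TypeError(f"{key!r} is not a scalar")
            if key in flat:
                raise KeyError(f"duplicate metric key {key!r}")
            flat[key] = float(v)
    return flat


spec = ToyClasswiseSpecificity(num_classes=3, as_dict=True)
spec.update(preds=[0, 1, 1, 2], target=[0, 1, 2, 2])
print(unpack_results({"specificity": spec.compute(), "accuracy": 0.75}))
# {'specificity_0': 1.0, 'specificity_1': 0.6666666666666666, 'specificity_2': 1.0, 'accuracy': 0.75}
```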