Multilabel Ranking metrics #787

Merged: 46 commits from ranking_metrics into master, Mar 21, 2022

Commits
4d9a03b  update (SkafteNicki, Nov 15, 2021)
3ad8efb  update (SkafteNicki, Nov 15, 2021)
ed05654  update (SkafteNicki, Nov 15, 2021)
b4ddaf6  Merge branch 'master' into ranking_metrics (SkafteNicki, Jan 20, 2022)
b3fc696  changelog (SkafteNicki, Jan 20, 2022)
1b2a882  init (SkafteNicki, Jan 20, 2022)
555219d  docs (SkafteNicki, Jan 20, 2022)
8c0eea8  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jan 20, 2022)
2e1422d  Merge branch 'master' into ranking_metrics (Borda, Feb 5, 2022)
a4fb0b1  Merge branch 'master' into ranking_metrics (SkafteNicki, Feb 10, 2022)
ff09cd4  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Feb 10, 2022)
b48c429  somethings working (SkafteNicki, Feb 11, 2022)
dd8bb1d  ranking (SkafteNicki, Feb 28, 2022)
61776b4  docs refs (SkafteNicki, Feb 28, 2022)
7fc6568  Merge branch 'master' into ranking_metrics (SkafteNicki, Feb 28, 2022)
96e43ea  add docs (SkafteNicki, Feb 28, 2022)
eb941c3  mypy (SkafteNicki, Feb 28, 2022)
4cba85b  fix tests (SkafteNicki, Feb 28, 2022)
9e73469  Merge branch 'master' into ranking_metrics (SkafteNicki, Mar 1, 2022)
409ce9f  update to new format (SkafteNicki, Mar 1, 2022)
6db88ff  update naming (SkafteNicki, Mar 1, 2022)
fdd9df6  reference (SkafteNicki, Mar 1, 2022)
16b3048  fix docs (SkafteNicki, Mar 1, 2022)
4d03ab8  fix doctest (SkafteNicki, Mar 1, 2022)
d556ec2  chlog (Borda, Mar 1, 2022)
ada8caa  . (Borda, Mar 1, 2022)
326911b  try fix tests (SkafteNicki, Mar 1, 2022)
d6047c0  Merge branch 'ranking_metrics' of https://github.com/PyTorchLightning… (SkafteNicki, Mar 1, 2022)
54b2b50  fix tests (SkafteNicki, Mar 1, 2022)
e080960  Merge branch 'master' into ranking_metrics (mergify[bot], Mar 1, 2022)
3168120  Merge branch 'master' into ranking_metrics (mergify[bot], Mar 3, 2022)
cc16ac2  Merge branch 'master' into ranking_metrics (mergify[bot], Mar 3, 2022)
95b8f0d  Merge branch 'master' into ranking_metrics (mergify[bot], Mar 7, 2022)
375489a  Merge branch 'master' into ranking_metrics (mergify[bot], Mar 11, 2022)
1e87d99  Merge branch 'master' into ranking_metrics (mergify[bot], Mar 18, 2022)
a9db5a6  Merge branch 'master' into ranking_metrics (mergify[bot], Mar 18, 2022)
4806886  Merge branch 'master' into ranking_metrics (mergify[bot], Mar 19, 2022)
c79b01d  Merge branch 'master' into ranking_metrics (mergify[bot], Mar 20, 2022)
9507972  update (SkafteNicki, Mar 21, 2022)
937aa85  Apply suggestions from code review (Borda, Mar 21, 2022)
a049748  fix flake8 (SkafteNicki, Mar 21, 2022)
56fe436  Update torchmetrics/__init__.py (SkafteNicki, Mar 21, 2022)
6ce6d6d  Update torchmetrics/__init__.py (SkafteNicki, Mar 21, 2022)
bc9dd8e  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Mar 21, 2022)
bc2ce5a  move floating (SkafteNicki, Mar 21, 2022)
cfd0f75  Merge branch 'ranking_metrics' of https://github.com/PyTorchLightning… (SkafteNicki, Mar 21, 2022)
CHANGELOG.md: 10 additions & 0 deletions

@@ -11,6 +11,16 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added


- Added `CoverageError` to classification metrics ([#787](https://github.com/PyTorchLightning/metrics/pull/787))


- Added `LabelRankingAveragePrecision` ([#787](https://github.com/PyTorchLightning/metrics/pull/787))


- Added `LabelRankingLoss` ([#787](https://github.com/PyTorchLightning/metrics/pull/787))


- Added support for `MetricCollection` in `MetricTracker` ([#718](https://github.com/PyTorchLightning/metrics/pull/718))


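As a point of reference, here is a minimal usage sketch of the three module metrics listed in the changelog above, assuming the input convention exercised by the tests in this PR (float per-label scores of shape (N, num_labels) and binary multilabel targets); the shapes and random values are illustrative only:

import torch
from torchmetrics import CoverageError, LabelRankingAveragePrecision, LabelRankingLoss

# 4 samples, 5 labels: per-label scores and binary ground truth
preds = torch.rand(4, 5)
target = torch.randint(2, (4, 5))

coverage = CoverageError()
lrap = LabelRankingAveragePrecision()
ranking_loss = LabelRankingLoss()

print(coverage(preds, target))      # average ranking depth needed to cover all true labels
print(lrap(preds, target))          # label ranking average precision, 1.0 is a perfect ranking
print(ranking_loss(preds, target))  # fraction of wrongly ordered label pairs, lower is better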
docs/source/references/functional.rst: 21 additions & 0 deletions

@@ -112,6 +112,13 @@ confusion_matrix [func]
:noindex:


coverage_error [func]
~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: torchmetrics.functional.coverage_error
:noindex:


dice_score [func]
~~~~~~~~~~~~~~~~~

@@ -161,6 +168,20 @@ kl_divergence [func]
:noindex:


label_ranking_average_precision [func]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: torchmetrics.functional.label_ranking_average_precision
:noindex:


label_ranking_loss [func]
~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: torchmetrics.functional.label_ranking_loss
:noindex:


matthews_corrcoef [func]
~~~~~~~~~~~~~~~~~~~~~~~~

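The functional counterparts registered above can be called directly on tensors; a short sketch under the same assumed input convention (illustrative values only):

import torch
from torchmetrics.functional import coverage_error, label_ranking_average_precision, label_ranking_loss

preds = torch.rand(4, 5)           # per-label scores (probabilities or logits)
target = torch.randint(2, (4, 5))  # binary multilabel targets

print(coverage_error(preds, target))
print(label_ranking_average_precision(preds, target))
print(label_ranking_loss(preds, target))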
docs/source/references/modules.rst: 18 additions & 0 deletions

@@ -274,6 +274,12 @@ ConfusionMatrix
.. autoclass:: torchmetrics.ConfusionMatrix
:noindex:

CoverageError
~~~~~~~~~~~~~

.. autoclass:: torchmetrics.CoverageError
:noindex:

F1Score
~~~~~~~

@@ -310,6 +316,18 @@ KLDivergence
.. autoclass:: torchmetrics.KLDivergence
:noindex:

LabelRankingAveragePrecision
~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: torchmetrics.LabelRankingAveragePrecision
:noindex:

LabelRankingLoss
~~~~~~~~~~~~~~~~

.. autoclass:: torchmetrics.LabelRankingLoss
:noindex:

MatthewsCorrCoef
~~~~~~~~~~~~~~~~

tests/classification/test_ranking.py: 101 additions & 0 deletions

@@ -0,0 +1,101 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
import torch
from sklearn.metrics import coverage_error as sk_coverage_error
from sklearn.metrics import label_ranking_average_precision_score as sk_label_ranking
from sklearn.metrics import label_ranking_loss as sk_label_ranking_loss

from tests.classification.inputs import _input_multilabel_logits as _input_mlb_logits
from tests.classification.inputs import _input_multilabel_prob as _input_mlb_prob
from tests.helpers import seed_all
from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester
from torchmetrics.classification.ranking import CoverageError, LabelRankingAveragePrecision, LabelRankingLoss
from torchmetrics.functional.classification.ranking import (
    coverage_error,
    label_ranking_average_precision,
    label_ranking_loss,
)

seed_all(42)


def _sk_coverage_error(preds, target, sample_weight=None):
    if sample_weight is not None:
        sample_weight = sample_weight.numpy()
    return sk_coverage_error(target.numpy(), preds.numpy(), sample_weight=sample_weight)


def _sk_label_ranking(preds, target, sample_weight=None):
    if sample_weight is not None:
        sample_weight = sample_weight.numpy()
    return sk_label_ranking(target.numpy(), preds.numpy(), sample_weight=sample_weight)


def _sk_label_ranking_loss(preds, target, sample_weight=None):
    if sample_weight is not None:
        sample_weight = sample_weight.numpy()
    return sk_label_ranking_loss(target.numpy(), preds.numpy(), sample_weight=sample_weight)


@pytest.mark.parametrize(
    "metric, functional_metric, sk_metric",
    [
        (CoverageError, coverage_error, _sk_coverage_error),
        (LabelRankingAveragePrecision, label_ranking_average_precision, _sk_label_ranking),
        (LabelRankingLoss, label_ranking_loss, _sk_label_ranking_loss),
    ],
)
@pytest.mark.parametrize(
    "preds, target",
    [
        (_input_mlb_logits.preds, _input_mlb_logits.target),
        (_input_mlb_prob.preds, _input_mlb_prob.target),
    ],
)
@pytest.mark.parametrize("sample_weight", [None, torch.rand(NUM_BATCHES, BATCH_SIZE)])
class TestRanking(MetricTester):
    @pytest.mark.parametrize("ddp", [False, True])
    @pytest.mark.parametrize("dist_sync_on_step", [False, True])
    def test_ranking_class(
        self, ddp, dist_sync_on_step, preds, target, metric, functional_metric, sk_metric, sample_weight
    ):
        self.run_class_metric_test(
            ddp=ddp,
            preds=preds,
            target=target,
            metric_class=metric,
            sk_metric=sk_metric,
            dist_sync_on_step=dist_sync_on_step,
            fragment_kwargs=True,
            sample_weight=sample_weight,
        )

    def test_ranking_functional(self, preds, target, metric, functional_metric, sk_metric, sample_weight):
        self.run_functional_metric_test(
            preds,
            target,
            metric_functional=functional_metric,
            sk_metric=sk_metric,
            fragment_kwargs=True,
            sample_weight=sample_weight,
        )

    def test_ranking_differentiability(self, preds, target, metric, functional_metric, sk_metric, sample_weight):
        self.run_differentiability_test(
            preds=preds,
            target=target,
            metric_module=metric,
            metric_functional=functional_metric,
        )
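The test class above compares the new metrics against their scikit-learn references through the _sk_* wrappers; a standalone sketch of that parity check (not part of the PR, shapes and tolerance are illustrative):

import torch
from sklearn.metrics import coverage_error as sk_coverage_error
from torchmetrics.functional import coverage_error

preds = torch.rand(8, 4)
target = torch.randint(2, (8, 4))

tm_val = coverage_error(preds, target)                     # torchmetrics takes (preds, target)
sk_val = sk_coverage_error(target.numpy(), preds.numpy())  # sklearn takes (y_true, y_score)
assert torch.allclose(tm_val.float(), torch.tensor(sk_val).float(), atol=1e-6)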
torchmetrics/__init__.py: 6 additions & 2 deletions

@@ -2,8 +2,6 @@
import logging as __logging
import os

from torchmetrics.__about__ import * # noqa: F401, F403

_logger = __logging.getLogger("torchmetrics")
_logger.addHandler(__logging.StreamHandler())
_logger.setLevel(__logging.INFO)
@@ -32,12 +30,15 @@
    CalibrationError,
    CohenKappa,
    ConfusionMatrix,
    CoverageError,
    F1Score,
    FBetaScore,
    HammingDistance,
    HingeLoss,
    JaccardIndex,
    KLDivergence,
    LabelRankingAveragePrecision,
    LabelRankingLoss,
    MatthewsCorrCoef,
    Precision,
    PrecisionRecallCurve,
@@ -115,6 +116,7 @@
"CohenKappa",
"ConfusionMatrix",
"CosineSimilarity",
"CoverageError",
"TweedieDevianceScore",
"ExplainedVariance",
"ExtendedEditDistance",
@@ -124,6 +126,8 @@
"HingeLoss",
"JaccardIndex",
"KLDivergence",
"LabelRankingAveragePrecision",
"LabelRankingLoss",
"MatthewsCorrCoef",
"MaxMetric",
"MeanAbsoluteError",
torchmetrics/classification/__init__.py: 5 additions & 0 deletions

@@ -29,6 +29,11 @@
from torchmetrics.classification.matthews_corrcoef import MatthewsCorrCoef # noqa: F401
from torchmetrics.classification.precision_recall import Precision, Recall # noqa: F401
from torchmetrics.classification.precision_recall_curve import PrecisionRecallCurve # noqa: F401
from torchmetrics.classification.ranking import (  # noqa: F401
    CoverageError,
    LabelRankingAveragePrecision,
    LabelRankingLoss,
)
from torchmetrics.classification.roc import ROC # noqa: F401
from torchmetrics.classification.specificity import Specificity # noqa: F401
from torchmetrics.classification.stat_scores import StatScores # noqa: F401
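Together, the two __init__.py diffs above wire the new metrics into three import paths; a quick sketch of where each name should resolve after this PR:

# top-level exports added to torchmetrics/__init__.py
from torchmetrics import CoverageError, LabelRankingAveragePrecision, LabelRankingLoss

# re-exported from the classification subpackage
from torchmetrics.classification import CoverageError  # noqa: F811

# functional versions, as used by the tests
from torchmetrics.functional.classification.ranking import (
    coverage_error,
    label_ranking_average_precision,
    label_ranking_loss,
)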