Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Hinge metric #120

Merged
merged 39 commits into from
Mar 25, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
a395248
Hinge loss initial commit
ethanwharris Mar 21, 2021
6ceb5fd
Hinge loss initial commit
ethanwharris Mar 21, 2021
9bfdc44
Add tests, squared arg, and class version
ethanwharris Mar 22, 2021
99a4338
Fix import order
ethanwharris Mar 22, 2021
bf42484
Add multiclass_mode argument and tests
ethanwharris Mar 23, 2021
05e1d88
Merge branch 'master' into feature/hinge_loss
ethanwharris Mar 23, 2021
d5c8f33
Add doc strings
ethanwharris Mar 23, 2021
1b6ec75
Merge branch 'master' into feature/hinge_loss
ethanwharris Mar 23, 2021
8a56ca1
Add to docs
ethanwharris Mar 23, 2021
3588c30
Update CHANGELOG.md
ethanwharris Mar 23, 2021
d6faa4a
Fix doctest
ethanwharris Mar 23, 2021
7e50f1a
Add squeeze
ethanwharris Mar 23, 2021
075d6cc
Fix docs
ethanwharris Mar 23, 2021
e44f28f
Merge branch 'master' into feature/hinge_loss
ethanwharris Mar 23, 2021
7f94c27
Update tests
ethanwharris Mar 23, 2021
a393ba3
Merge branch 'feature/hinge_loss' of https://github.com/ethanwharris/…
ethanwharris Mar 23, 2021
153d490
Update mathin docstrings
ethanwharris Mar 23, 2021
661baa5
Merge branch 'master' into feature/hinge_loss
Borda Mar 23, 2021
c6c72d8
Change HingeLoss -> Hinge
ethanwharris Mar 24, 2021
d7ca161
Change hinge_loss -> hinge
ethanwharris Mar 24, 2021
ba869db
Add multiclass_mode check to Hinge
ethanwharris Mar 24, 2021
450cdad
Merge branch 'master' into feature/hinge_loss
ethanwharris Mar 24, 2021
cfea3a0
Apply suggestions from code review
Borda Mar 24, 2021
d38b8f2
Add MulticlassMode enum
ethanwharris Mar 24, 2021
4f50d98
Remove whitespace
ethanwharris Mar 24, 2021
2557459
Fix import order
ethanwharris Mar 24, 2021
176377b
Change loss / losses for measure / measures
ethanwharris Mar 24, 2021
9c86cf3
Merge branch 'master' into feature/hinge_loss
ethanwharris Mar 24, 2021
667765d
Merge branch 'master' into feature/hinge_loss
ethanwharris Mar 24, 2021
86a5453
Merge branch 'master' into feature/hinge_loss
ethanwharris Mar 25, 2021
f4ae809
Update torchmetrics/classification/hinge.py
ethanwharris Mar 25, 2021
fd6b990
Update torchmetrics/classification/hinge.py
ethanwharris Mar 25, 2021
519f479
Update torchmetrics/classification/hinge.py
ethanwharris Mar 25, 2021
f569ea7
Add comments
ethanwharris Mar 25, 2021
d472198
Merge branch 'master' into feature/hinge_loss
SkafteNicki Mar 25, 2021
62e009b
Fix hanging test
ethanwharris Mar 25, 2021
31f0f13
Merge branch 'feature/hinge_loss' of https://github.com/ethanwharris/…
ethanwharris Mar 25, 2021
ccb8839
Merge branch 'master' into feature/hinge_loss
ethanwharris Mar 25, 2021
bdae664
Merge branch 'master' into feature/hinge_loss
ethanwharris Mar 25, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added `MatthewsCorrcoef` metric ([#98](https://github.com/PyTorchLightning/metrics/pull/98))


- Added `Hinge` metric ([#120](https://github.com/PyTorchLightning/metrics/pull/120))


- Added multilabel support to `ROC` metric ([#114](https://github.com/PyTorchLightning/metrics/pull/114))


### Changed

- Changed `ExplainedVariance` from storing all preds/targets to tracking 5 statistics ([#68](https://github.com/PyTorchLightning/metrics/pull/68))
Expand Down
6 changes: 6 additions & 0 deletions docs/source/references/functional.rst
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,12 @@ hamming_distance [func]
.. autofunction:: torchmetrics.functional.hamming_distance
:noindex:

hinge [func]
~~~~~~~~~~~~

.. autofunction:: torchmetrics.functional.hinge
:noindex:

iou [func]
~~~~~~~~~~

Expand Down
20 changes: 13 additions & 7 deletions docs/source/references/modules.rst
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,18 @@ FBeta
.. autoclass:: torchmetrics.FBeta
:noindex:

HammingDistance
~~~~~~~~~~~~~~~

.. autoclass:: torchmetrics.HammingDistance
:noindex:

Hinge
~~~~~

.. autoclass:: torchmetrics.Hinge
:noindex:

IoU
~~~

Expand All @@ -174,12 +186,6 @@ MatthewsCorrcoef
.. autoclass:: torchmetrics.MatthewsCorrcoef
:noindex:

Hamming Distance
~~~~~~~~~~~~~~~~

.. autoclass:: torchmetrics.HammingDistance
:noindex:

Precision
~~~~~~~~~

Expand Down Expand Up @@ -269,4 +275,4 @@ R2Score
~~~~~~~

.. autoclass:: torchmetrics.R2Score
:noindex:
:noindex:
160 changes: 160 additions & 0 deletions tests/classification/test_hinge.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial

import numpy as np
import pytest
import torch
from sklearn.metrics import hinge_loss as sk_hinge
from sklearn.preprocessing import OneHotEncoder

from tests.classification.inputs import Input
from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, NUM_CLASSES, MetricTester
from torchmetrics import Hinge
from torchmetrics.functional import hinge
from torchmetrics.functional.classification.hinge import MulticlassMode

torch.manual_seed(42)

_input_binary = Input(
preds=torch.randn(NUM_BATCHES, BATCH_SIZE),
target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE))
)

_input_binary_single = Input(
preds=torch.randn((NUM_BATCHES, 1)),
target=torch.randint(high=2, size=(NUM_BATCHES, 1))
)

_input_multiclass = Input(
preds=torch.randn(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES),
target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE))
)


def _sk_hinge(preds, target, squared, multiclass_mode):
sk_preds, sk_target = preds.numpy(), target.numpy()

if multiclass_mode == MulticlassMode.ONE_VS_ALL:
enc = OneHotEncoder()
enc.fit(sk_target.reshape(-1, 1))
sk_target = enc.transform(sk_target.reshape(-1, 1)).toarray()

if sk_preds.ndim == 1 or multiclass_mode == MulticlassMode.ONE_VS_ALL:
sk_target = 2 * sk_target - 1

if squared or sk_target.max() != 1 or sk_target.min() != -1:
# Squared not an option in sklearn and infers classes incorrectly with single element, so adapted from source
if sk_preds.ndim == 1 or multiclass_mode == MulticlassMode.ONE_VS_ALL:
margin = sk_target * sk_preds
else:
mask = np.ones_like(sk_preds, dtype=bool)
mask[np.arange(sk_target.shape[0]), sk_target] = False
margin = sk_preds[~mask]
margin -= np.max(sk_preds[mask].reshape(sk_target.shape[0], -1), axis=1)
measures = 1 - margin
measures = np.clip(measures, 0, None)

if squared:
measures = measures ** 2
return measures.mean(axis=0)
else:
if multiclass_mode == MulticlassMode.ONE_VS_ALL:
result = np.zeros(sk_preds.shape[1])
for i in range(result.shape[0]):
result[i] = sk_hinge(y_true=sk_target[:, i], pred_decision=sk_preds[:, i])
return result

return sk_hinge(y_true=sk_target, pred_decision=sk_preds)


@pytest.mark.parametrize(
"preds, target, squared, multiclass_mode",
[
(_input_binary.preds, _input_binary.target, False, None),
(_input_binary.preds, _input_binary.target, True, None),
(_input_binary_single.preds, _input_binary_single.target, False, None),
(_input_binary_single.preds, _input_binary_single.target, True, None),
(_input_multiclass.preds, _input_multiclass.target, False, MulticlassMode.CRAMMER_SINGER),
(_input_multiclass.preds, _input_multiclass.target, True, MulticlassMode.CRAMMER_SINGER),
(_input_multiclass.preds, _input_multiclass.target, False, MulticlassMode.ONE_VS_ALL),
(_input_multiclass.preds, _input_multiclass.target, True, MulticlassMode.ONE_VS_ALL),
],
)
class TestHinge(MetricTester):

@pytest.mark.parametrize("ddp", [True, False])
@pytest.mark.parametrize("dist_sync_on_step", [True, False])
def test_hinge_class(self, ddp, dist_sync_on_step, preds, target, squared, multiclass_mode):
self.run_class_metric_test(
ddp=ddp,
preds=preds,
target=target,
metric_class=Hinge,
sk_metric=partial(_sk_hinge, squared=squared, multiclass_mode=multiclass_mode),
dist_sync_on_step=dist_sync_on_step,
metric_args={
"squared": squared,
"multiclass_mode": multiclass_mode,
},
)

def test_hinge_fn(self, preds, target, squared, multiclass_mode):
self.run_functional_metric_test(
preds,
target,
metric_functional=partial(hinge, squared=squared, multiclass_mode=multiclass_mode),
sk_metric=partial(_sk_hinge, squared=squared, multiclass_mode=multiclass_mode),
)


_input_multi_target = Input(
preds=torch.randn(BATCH_SIZE),
target=torch.randint(high=2, size=(BATCH_SIZE, 2))
)

_input_binary_different_sizes = Input(
preds=torch.randn(BATCH_SIZE * 2),
target=torch.randint(high=2, size=(BATCH_SIZE,))
)

_input_multi_different_sizes = Input(
preds=torch.randn(BATCH_SIZE * 2, NUM_CLASSES),
target=torch.randint(high=NUM_CLASSES, size=(BATCH_SIZE,))
)

_input_extra_dim = Input(
preds=torch.randn(BATCH_SIZE, NUM_CLASSES, 2),
target=torch.randint(high=2, size=(BATCH_SIZE,))
)


@pytest.mark.parametrize(
"preds, target, multiclass_mode",
[
(_input_multi_target.preds, _input_multi_target.target, None),
(_input_binary_different_sizes.preds, _input_binary_different_sizes.target, None),
(_input_multi_different_sizes.preds, _input_multi_different_sizes.target, None),
(_input_extra_dim.preds, _input_extra_dim.target, None),
(_input_multiclass.preds[0], _input_multiclass.target[0], 'invalid_mode')
],
)
def test_bad_inputs_fn(preds, target, multiclass_mode):
with pytest.raises(ValueError):
_ = hinge(preds, target, multiclass_mode=multiclass_mode)


def test_bad_inputs_class():
with pytest.raises(ValueError):
Hinge(multiclass_mode='invalid_mode')
1 change: 1 addition & 0 deletions torchmetrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
ConfusionMatrix,
FBeta,
HammingDistance,
Hinge,
IoU,
MatthewsCorrcoef,
Precision,
Expand Down
1 change: 1 addition & 0 deletions torchmetrics/classification/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from torchmetrics.classification.confusion_matrix import ConfusionMatrix # noqa: F401
from torchmetrics.classification.f_beta import F1, FBeta # noqa: F401
from torchmetrics.classification.hamming_distance import HammingDistance # noqa: F401
from torchmetrics.classification.hinge import Hinge # noqa: F401
from torchmetrics.classification.iou import IoU # noqa: F401
from torchmetrics.classification.matthews_corrcoef import MatthewsCorrcoef # noqa: F401
from torchmetrics.classification.precision_recall import Precision, Recall # noqa: F401
Expand Down
126 changes: 126 additions & 0 deletions torchmetrics/classification/hinge.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Optional, Union

from torch import Tensor, tensor

from torchmetrics.functional.classification.hinge import MulticlassMode, _hinge_compute, _hinge_update
from torchmetrics.metric import Metric


class Hinge(Metric):
r"""
Computes the mean `Hinge loss <https://en.wikipedia.org/wiki/Hinge_loss>`_, typically used for Support Vector
Machines (SVMs). In the binary case it is defined as:

.. math::
\text{Hinge loss} = \max(0, 1 - y \times \hat{y})

Where :math:`y \in {-1, 1}` is the target, and :math:`\hat{y} \in \mathbb{R}` is the prediction.

In the multi-class case, when ``multiclass_mode=None`` (default), ``multiclass_mode=MulticlassMode.CRAMMER_SINGER``
or ``multiclass_mode="crammer-singer"``, this metric will compute the multi-class hinge loss defined by Crammer and
Singer as:

.. math::
\text{Hinge loss} = \max\left(0, 1 - \hat{y}_y + \max_{i \ne y} (\hat{y}_i)\right)

Where :math:`y \in {0, ..., \mathrm{C}}` is the target class (where :math:`\mathrm{C}` is the number of classes),
and :math:`\hat{y} \in \mathbb{R}^\mathrm{C}` is the predicted output per class.

In the multi-class case when ``multiclass_mode=MulticlassMode.ONE_VS_ALL`` or ``multiclass_mode='one-vs-all'``, this
metric will use a one-vs-all approach to compute the hinge loss, giving a vector of C outputs where each entry pits
that class against all remaining classes.

This metric can optionally output the mean of the squared hinge loss by setting ``squared=True``

Only accepts inputs with preds shape of (N) (binary) or (N, C) (multi-class) and target shape of (N).

Args:
squared:
If True, this will compute the squared hinge loss. Otherwise, computes the regular hinge loss (default).
multiclass_mode:
Which approach to use for multi-class inputs (has no effect in the binary case). ``None`` (default),
``MulticlassMode.CRAMMER_SINGER`` or ``"crammer-singer"``, uses the Crammer Singer multi-class hinge loss.
``MulticlassMode.ONE_VS_ALL`` or ``"one-vs-all"`` computes the hinge loss in a one-vs-all fashion.

Raises:
ValueError:
If ``multiclass_mode`` is not: None, ``MulticlassMode.CRAMMER_SINGER``, ``"crammer-singer"``,
``MulticlassMode.ONE_VS_ALL`` or ``"one-vs-all"``.

Example:
# binary example
>>> import torch
ethanwharris marked this conversation as resolved.
Show resolved Hide resolved
>>> from torchmetrics import Hinge
>>> target = torch.tensor([0, 1, 1])
>>> preds = torch.tensor([-2.2, 2.4, 0.1])
>>> hinge = Hinge()
>>> hinge(preds, target)
tensor(0.3000)


# multiclass example, default mode
>>> target = torch.tensor([0, 1, 2])
ethanwharris marked this conversation as resolved.
Show resolved Hide resolved
>>> preds = torch.tensor([[-1.0, 0.9, 0.2], [0.5, -1.1, 0.8], [2.2, -0.5, 0.3]])
>>> hinge = Hinge()
>>> hinge(preds, target)
tensor(2.9000)


# multiclass example, one vs all mode
>>> target = torch.tensor([0, 1, 2])
ethanwharris marked this conversation as resolved.
Show resolved Hide resolved
>>> preds = torch.tensor([[-1.0, 0.9, 0.2], [0.5, -1.1, 0.8], [2.2, -0.5, 0.3]])
>>> hinge = Hinge(multiclass_mode="one-vs-all")
>>> hinge(preds, target)
tensor([2.2333, 1.5000, 1.2333])
"""

def __init__(
self,
squared: bool = False,
multiclass_mode: Optional[Union[str, MulticlassMode]] = None,
compute_on_step: bool = True,
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
dist_sync_fn: Callable = None,
):
super().__init__(
compute_on_step=compute_on_step,
dist_sync_on_step=dist_sync_on_step,
process_group=process_group,
dist_sync_fn=dist_sync_fn,
)

self.add_state("measure", default=tensor(0.0), dist_reduce_fx="sum")
self.add_state("total", default=tensor(0), dist_reduce_fx="sum")

if multiclass_mode not in (None, MulticlassMode.CRAMMER_SINGER, MulticlassMode.ONE_VS_ALL):
raise ValueError(
"The `multiclass_mode` should be either None / 'crammer-singer' / MulticlassMode.CRAMMER_SINGER"
"(default) or 'one-vs-all' / MulticlassMode.ONE_VS_ALL,"
f" got {multiclass_mode}."
)

self.squared = squared
self.multiclass_mode = multiclass_mode

def update(self, preds: Tensor, target: Tensor):
measure, total = _hinge_update(preds, target, squared=self.squared, multiclass_mode=self.multiclass_mode)

self.measure = measure + self.measure
self.total = total + self.total

def compute(self) -> Tensor:
return _hinge_compute(self.measure, self.total)
1 change: 1 addition & 0 deletions torchmetrics/functional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from torchmetrics.functional.classification.dice import dice_score # noqa: F401
from torchmetrics.functional.classification.f_beta import f1, fbeta # noqa: F401
from torchmetrics.functional.classification.hamming_distance import hamming_distance # noqa: F401
from torchmetrics.functional.classification.hinge import hinge # noqa: F401
from torchmetrics.functional.classification.iou import iou # noqa: F401
from torchmetrics.functional.classification.matthews_corrcoef import matthews_corrcoef # noqa: F401
from torchmetrics.functional.classification.precision_recall import precision, precision_recall, recall # noqa: F401
Expand Down
1 change: 1 addition & 0 deletions torchmetrics/functional/classification/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from torchmetrics.functional.classification.dice import dice_score # noqa: F401
from torchmetrics.functional.classification.f_beta import f1, fbeta # noqa: F401
from torchmetrics.functional.classification.hamming_distance import hamming_distance # noqa: F401
from torchmetrics.functional.classification.hinge import hinge # noqa: F401
from torchmetrics.functional.classification.iou import iou # noqa: F401
from torchmetrics.functional.classification.matthews_corrcoef import matthews_corrcoef # noqa: F401
from torchmetrics.functional.classification.precision_recall import precision, precision_recall, recall # noqa: F401
Expand Down
Loading