Classification metrics overhaul: accuracy metrics (2/n) #4838

Merged
merged 136 commits on Dec 21, 2020
Changes from 134 commits

Commits
136 commits
6959ea0
Add stuff
tadejsv Nov 24, 2020
0679015
Change metrics documentation layout
tadejsv Nov 24, 2020
35627b5
Add stuff
tadejsv Nov 24, 2020
55fdaaf
Change testing utils
tadejsv Nov 24, 2020
35f8320
Merge branch 'cls_metrics_input_formatting' into cls_metrics_accuracy
tadejsv Nov 24, 2020
5cbf56a
Replace len(*.shape) with *.ndim
tadejsv Nov 24, 2020
9c33d0b
More descriptive error message for input formatting
tadejsv Nov 24, 2020
6562205
Replace movedim with permute
tadejsv Nov 24, 2020
b97aef2
Merge branch 'cls_metrics_input_formatting' into cls_metrics_accuracy
tadejsv Nov 24, 2020
cbbc769
PEP 8 compliance
tadejsv Nov 24, 2020
f45fc81
Division with float
tadejsv Nov 24, 2020
a04a71e
Style changes in error messages
tadejsv Nov 25, 2020
eaac5d7
More error message style improvements
tadejsv Nov 25, 2020
c1108f0
Fix typo in docs
tadejsv Nov 25, 2020
277769b
Add more descriptive variable names in utils
tadejsv Nov 25, 2020
4849298
Change internal var names
tadejsv Nov 25, 2020
22906a4
Merge remote-tracking branch 'upstream/master' into cls_metrics_input…
tadejsv Nov 25, 2020
1034a71
Merge branch 'cls_metrics_input_formatting' into cls_metrics_accuracy
tadejsv Nov 25, 2020
02bd636
Break down error checking for inputs into separate functions
tadejsv Nov 25, 2020
f97145b
Remove the (N, ..., C) option in MD-MC
tadejsv Nov 25, 2020
536feaf
Simplify select_topk
tadejsv Nov 25, 2020
4241d7c
Remove detach for inputs
tadejsv Nov 25, 2020
99d3c81
Merge branch 'cls_metrics_input_formatting' into cls_metrics_accuracy
tadejsv Nov 25, 2020
86d6c4d
Fix typos
tadejsv Nov 25, 2020
54c98a0
Merge branch 'cls_metrics_input_formatting' into cls_metrics_accuracy
tadejsv Nov 25, 2020
bb11677
Merge branch 'master' into cls_metrics_input_formatting
teddykoker Nov 25, 2020
bdc4111
Merge branch 'cls_metrics_input_formatting' into cls_metrics_accuracy
tadejsv Nov 25, 2020
cde3997
Update pytorch_lightning/metrics/classification/utils.py
tadejsv Nov 26, 2020
05a54da
Update docs/source/metrics.rst
tadejsv Nov 26, 2020
9a43a5e
Minor error message changes
tadejsv Nov 26, 2020
3f4ad3c
Update pytorch_lightning/metrics/utils.py
tadejsv Nov 26, 2020
a654e6a
Reuse case from validation in formatting
tadejsv Nov 26, 2020
7b2ef2b
Merge branch 'cls_metrics_input_formatting' of github.com:tadejsv/pyt…
tadejsv Nov 26, 2020
16ab8f7
Refactor code in _input_format_classification
tadejsv Nov 27, 2020
558276f
Merge branch 'master' into cls_metrics_input_formatting
tchaton Nov 27, 2020
ecffe18
Small improvements
tadejsv Nov 27, 2020
a907ade
Merge branch 'cls_metrics_input_formatting' into cls_metrics_accuracy
tadejsv Nov 27, 2020
725c7dd
PEP 8
tadejsv Nov 27, 2020
41ad0b7
Update pytorch_lightning/metrics/classification/utils.py
tadejsv Nov 27, 2020
ca13e76
Update pytorch_lightning/metrics/classification/utils.py
tadejsv Nov 27, 2020
ede2c7f
Update docs/source/metrics.rst
tadejsv Nov 27, 2020
c6e4de4
Update pytorch_lightning/metrics/classification/utils.py
tadejsv Nov 27, 2020
201d0de
Apply suggestions from code review
tadejsv Nov 27, 2020
f08edbc
Alphabetical reordering of regression metrics
tadejsv Nov 27, 2020
523bae3
Merge branch 'cls_metrics_input_formatting' of github.com:tadejsv/pyt…
tadejsv Nov 27, 2020
db24fae
Merge branch 'master' into cls_metrics_input_formatting
Borda Nov 27, 2020
35e3eff
Change default value of top_k and add error checking
tadejsv Nov 28, 2020
dd6f8ea
Merge branch 'cls_metrics_input_formatting' of github.com:tadejsv/pyt…
tadejsv Nov 28, 2020
c28aadf
Extract basic validation into separate function
tadejsv Nov 28, 2020
4bfc688
Merge branch 'cls_metrics_input_formatting' into cls_metrics_accuracy
tadejsv Nov 28, 2020
323285e
Update to new top_k default
tadejsv Nov 28, 2020
0cb0eac
Update desciption of parameters in input formatting
tadejsv Nov 29, 2020
28acf4c
Merge branch 'master' into cls_metrics_input_formatting
tchaton Nov 30, 2020
8e7a85a
Apply suggestions from code review
tadejsv Nov 30, 2020
829155e
Check that probabilities in preds sum to 1 (for MC)
tadejsv Nov 30, 2020
768879d
Fix coverage
tadejsv Nov 30, 2020
e4d88e2
Merge branch 'cls_metrics_input_formatting' into cls_metrics_accuracy
tadejsv Dec 1, 2020
eeded45
Split accuracy and hamming loss
tadejsv Dec 1, 2020
b49cfdc
Remove old redundant accuracy
tadejsv Dec 1, 2020
15ef14d
Merge branch 'master' into cls_metrics_input_formatting
teddykoker Dec 2, 2020
3d8f584
Merge branch 'master' into cls_metrics_accuracy
tchaton Dec 3, 2020
1568970
Merge branch 'master' into cls_metrics_input_formatting
tchaton Dec 3, 2020
a9fa730
Merge with master and resolve conflicts
tadejsv Dec 6, 2020
44ad276
Merge branch 'master' into cls_metrics_input_formatting
Borda Dec 6, 2020
96d40c8
Minor changes
tadejsv Dec 6, 2020
cca430a
Merge branch 'cls_metrics_input_formatting' of github.com:tadejsv/pyt…
tadejsv Dec 6, 2020
b0bde16
Merge branch 'cls_metrics_input_formatting' into cls_metrics_accuracy
tadejsv Dec 6, 2020
627d99a
Fix imports
tadejsv Dec 6, 2020
de3defb
Improve docstring descriptions
tadejsv Dec 6, 2020
f3c47f9
Fix edge case and simplify testing
tadejsv Dec 6, 2020
a7e91a9
Merge branch 'cls_metrics_input_formatting' into cls_metrics_accuracy
tadejsv Dec 6, 2020
b7ced6e
Fix docs
tadejsv Dec 6, 2020
e91e564
PEP8
tadejsv Dec 6, 2020
798ec03
Reorder imports
tadejsv Dec 6, 2020
7217924
Merge remote-tracking branch 'upstream/master' into cls_metrics_accuracy
tadejsv Dec 7, 2020
a7c143e
Update changelog
tadejsv Dec 7, 2020
531ae33
Update docstring
tadejsv Dec 7, 2020
2eba226
Merge branch 'master' into cls_metrics_accuracy
tadejsv Dec 7, 2020
a66cf31
Update docstring
tadejsv Dec 7, 2020
e93f83e
Merge branch 'cls_metrics_accuracy' of github.com:tadejsv/pytorch-lig…
tadejsv Dec 7, 2020
89b09f8
Reverse formatting changes for tests
tadejsv Dec 7, 2020
e715437
Change parameter order
tadejsv Dec 7, 2020
d5daec8
Remove formatting changes 2/2
tadejsv Dec 7, 2020
c820060
Remove formatting 3/3
tadejsv Dec 7, 2020
b576de0
.
tadejsv Dec 7, 2020
dae341b
Improve description of top_k parameter
tadejsv Dec 7, 2020
b2d2b71
Apply suggestions from code review
Borda Dec 7, 2020
9b2a399
Apply suggestions from code review
tadejsv Dec 7, 2020
0952df2
Remove unneeded assert
tadejsv Dec 7, 2020
c7fe698
Update pytorch_lightning/metrics/functional/accuracy.py
tadejsv Dec 7, 2020
e2bc0ab
Remove unneeded assert
tadejsv Dec 7, 2020
acbd1ca
Merge branch 'cls_metrics_accuracy' of github.com:tadejsv/pytorch-lig…
tadejsv Dec 7, 2020
8801f8a
Explicit checking of parameter values
tadejsv Dec 7, 2020
c32b36e
Apply suggestions from code review
Borda Dec 7, 2020
0314c7d
Apply suggestions from code review
tadejsv Dec 7, 2020
152cadf
Fix top_k checking
tadejsv Dec 7, 2020
022d6a6
PEP8
tadejsv Dec 7, 2020
9efc963
Don't check dist_sync in test
tadejsv Dec 8, 2020
d992f7d
add back check_dist_sync_on_step
tadejsv Dec 8, 2020
a726060
Make sure half-precision inputs are transformed (#5013)
tadejsv Dec 8, 2020
93c5d02
Fix typo
tadejsv Dec 8, 2020
0813055
Rename hamming loss to hamming distance
tadejsv Dec 8, 2020
6bf714b
Fix tests for half precision
tadejsv Dec 8, 2020
d12f1d6
Fix docs underline length
tadejsv Dec 8, 2020
a55cb46
Fix doc undeline length
tadejsv Dec 8, 2020
d75eec3
Merge branch 'master' into cls_metrics_accuracy
justusschock Dec 8, 2020
6b3b057
Replace mdmc_accuracy parameter with subset_accuracy
tadejsv Dec 8, 2020
6f218d4
Merge branch 'cls_metrics_accuracy' of github.com:tadejsv/pytorch-lig…
tadejsv Dec 8, 2020
98cb5f4
Update changelog
tadejsv Dec 8, 2020
72ca3ac
Merge branch 'master' into cls_metrics_accuracy
SkafteNicki Dec 8, 2020
474fbd0
Apply suggestions from code review
tadejsv Dec 8, 2020
03cccc3
Suggestions from code review
tadejsv Dec 8, 2020
de0213e
Fix number in docs
tadejsv Dec 8, 2020
0fbf93c
Update pytorch_lightning/metrics/classification/accuracy.py
rohitgr7 Dec 8, 2020
1b8af65
Replace topk by argsort in select_topk
tadejsv Dec 11, 2020
c2c17f0
Merge branch 'master' into cls_metrics_accuracy
tadejsv Dec 11, 2020
3c4f200
Fix changelog
tadejsv Dec 11, 2020
065c848
Merge remote-tracking branch 'upstream/master' into cls_metrics_accuracy
tadejsv Dec 12, 2020
82d550e
Add test for wrong params
tadejsv Dec 12, 2020
279e4b9
Merge branch 'master' into cls_metrics_accuracy
tadejsv Dec 12, 2020
c4e9aa2
Merge branch 'master' into cls_metrics_accuracy
s-rog Dec 13, 2020
827a544
Merge branch 'master' into cls_metrics_accuracy
tadejsv Dec 14, 2020
eb9cb3c
Add Google Colab badges (#5111)
shacharmirkin Dec 14, 2020
69123af
Fix hanging metrics tests (#5134)
tadejsv Dec 14, 2020
863885a
Merge branch 'master' into cls_metrics_accuracy
tadejsv Dec 14, 2020
811dc00
Merge remote-tracking branch 'upstream/master' into cls_metrics_accuracy
tadejsv Dec 14, 2020
f68acc0
Use torch.topk again as ddp hanging tests fixed in #5134
tadejsv Dec 14, 2020
d47b559
Merge remote-tracking branch 'upstream/release/1.2-dev' into cls_metr…
tadejsv Dec 16, 2020
3bf3c3a
Fix unwanted notebooks change
tadejsv Dec 17, 2020
4412c9c
Merge remote-tracking branch 'upstream/release/1.2-dev' into cls_metr…
tadejsv Dec 17, 2020
a9a4847
Merge remote-tracking branch 'upstream/release/1.2-dev' into cls_metr…
tadejsv Dec 21, 2020
44135c7
Fix too long line in hamming_distance
tadejsv Dec 21, 2020
908e60f
Apply suggestions from code review
Borda Dec 21, 2020
23a997e
Apply suggestions from code review
Borda Dec 21, 2020
b3e458d
protect
Borda Dec 21, 2020
92f5f83
Update CHANGELOG.md
rohitgr7 Dec 21, 2020
6 changes: 6 additions & 0 deletions CHANGELOG.md
@@ -9,6 +9,11 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- `Accuracy` metric now generalizes to Top-k accuracy for (multi-dimensional) multi-class inputs using the `top_k` parameter ([#4838](https://github.com/PyTorchLightning/pytorch-lightning/pull/4838))

- `Accuracy` metric now enables the computation of subset accuracy for multi-label or multi-dimensional multi-class inputs with the `subset_accuracy` parameter ([#4838](https://github.com/PyTorchLightning/pytorch-lightning/pull/4838))

- `HammingDistance` metric to compute the hamming distance (loss) ([#4838](https://github.com/PyTorchLightning/pytorch-lightning/pull/4838))

### Changed

@@ -19,6 +24,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Removed



### Fixed


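The first changelog entry above introduces the ``top_k`` parameter for ``Accuracy``. A minimal sketch of how it is meant to be used, mirroring the docstring example that appears further down in this diff (the illustrative tensors are mine; the printed value is the documented expected output, not re-verified here):

import torch
from pytorch_lightning.metrics import Accuracy

# Multi-class probability predictions for 3 samples and 3 classes.
target = torch.tensor([0, 1, 2])
preds = torch.tensor([[0.1, 0.9, 0.0],
                      [0.3, 0.1, 0.6],
                      [0.2, 0.5, 0.3]])

# With top_k=2, a sample counts as correct if the true class is among the
# two highest-probability predictions; here the first and third samples qualify.
accuracy = Accuracy(top_k=2)
print(accuracy(preds, target))  # tensor(0.6667), as in the docstring example below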
14 changes: 12 additions & 2 deletions docs/source/metrics.rst
@@ -292,6 +292,12 @@ FBeta
.. autoclass:: pytorch_lightning.metrics.classification.FBeta
:noindex:

Hamming Distance
~~~~~~~~~~~~~~~~

.. autoclass:: pytorch_lightning.metrics.classification.HammingDistance
:noindex:

Precision
~~~~~~~~~

@@ -323,10 +329,9 @@ Functional Metrics (Classification)
accuracy [func]
~~~~~~~~~~~~~~~

.. autofunction:: pytorch_lightning.metrics.functional.classification.accuracy
.. autofunction:: pytorch_lightning.metrics.functional.accuracy
:noindex:


auc [func]
~~~~~~~~~~

@@ -382,6 +387,11 @@ fbeta [func]
.. autofunction:: pytorch_lightning.metrics.functional.fbeta
:noindex:

hamming_distance [func]
~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: pytorch_lightning.metrics.functional.hamming_distance
:noindex:

iou [func]
~~~~~~~~~~
1 change: 1 addition & 0 deletions pytorch_lightning/metrics/__init__.py
@@ -15,6 +15,7 @@

from pytorch_lightning.metrics.classification import ( # noqa: F401
Accuracy,
HammingDistance,
Precision,
Recall,
ConfusionMatrix,
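With the one-line change above, ``HammingDistance`` is exposed alongside ``Accuracy`` in the top-level metrics namespace. A minimal sketch of the resulting import path (parameter values are illustrative):

from pytorch_lightning.metrics import Accuracy, HammingDistance

acc = Accuracy(top_k=2, subset_accuracy=False)
hd = HammingDistance(threshold=0.5)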
1 change: 1 addition & 0 deletions pytorch_lightning/metrics/classification/__init__.py
@@ -15,6 +15,7 @@
from pytorch_lightning.metrics.classification.average_precision import AveragePrecision # noqa: F401
from pytorch_lightning.metrics.classification.confusion_matrix import ConfusionMatrix # noqa: F401
from pytorch_lightning.metrics.classification.f_beta import FBeta, Fbeta, F1 # noqa: F401
from pytorch_lightning.metrics.classification.hamming_distance import HammingDistance # noqa: F401
from pytorch_lightning.metrics.classification.precision_recall import Precision, Recall # noqa: F401
from pytorch_lightning.metrics.classification.precision_recall_curve import PrecisionRecallCurve # noqa: F401
from pytorch_lightning.metrics.classification.roc import ROC # noqa: F401
90 changes: 66 additions & 24 deletions pytorch_lightning/metrics/classification/accuracy.py
@@ -16,35 +16,57 @@
import torch

from pytorch_lightning.metrics.metric import Metric
from pytorch_lightning.metrics.utils import _input_format_classification
from pytorch_lightning.metrics.functional.accuracy import _accuracy_update, _accuracy_compute


class Accuracy(Metric):
r"""
Computes `Accuracy <https://en.wikipedia.org/wiki/Accuracy_and_precision>`_:

.. math:: \text{Accuracy} = \frac{1}{N}\sum_i^N 1(y_i = \hat{y_i})
.. math::
\text{Accuracy} = \frac{1}{N}\sum_i^N 1(y_i = \hat{y}_i)

Where :math:`y` is a tensor of target values, and :math:`\hat{y}` is a
tensor of predictions. Works with binary, multiclass, and multilabel
data. Accepts logits from a model output or integer class values in
prediction. Works with multi-dimensional preds and target.
tensor of predictions.

Forward accepts
For multi-class and multi-dimensional multi-class data with probability predictions, the
parameter ``top_k`` generalizes this metric to a Top-K accuracy metric: for each sample the
top-K highest probability items are considered to find the correct label.

- ``preds`` (float or long tensor): ``(N, ...)`` or ``(N, C, ...)`` where C is the number of classes
- ``target`` (long tensor): ``(N, ...)``
For multi-label and multi-dimensional multi-class inputs, this metric computes the "global"
accuracy by default, which counts all labels or sub-samples separately. This can be
changed to subset accuracy (which requires all labels or sub-samples in the sample to
be correctly predicted) by setting ``subset_accuracy=True``.

If preds and target are the same shape and preds is a float tensor, we use the ``self.threshold`` argument.
This is the case for binary and multi-label logits.

If preds has an extra dimension as in the case of multi-class scores we perform an argmax on ``dim=1``.
Accepts all input types listed in :ref:`metrics:Input types`.

Args:
threshold:
Review thread (Dec 21, 2020):

rohitgr7 (Contributor): Just wondering, should threshold be None by default too? It's not used when we provide pred_labels, and 0.5 is used by default when we have pred_probs.

tadejsv (Contributor, Author): Well, I guess this could be done, but in a way that if pred_probs are passed in, None would default to 0.5. Otherwise this would be a very disruptive breaking change for people used to using accuracy without extra params.

rohitgr7 (Contributor): Yeah, that's what I meant: 0.5 by default when we have pred_probs.

tadejsv (Contributor, Author): Got it, will add this in the next PR.

Threshold value for binary or multi-label logits. default: 0.5
Threshold probability value for transforming probability predictions to binary
`(0,1)` predictions, in the case of binary or multi-label inputs.
top_k:
Number of highest probability predictions considered to find the correct label, relevant
only for (multi-dimensional) multi-class inputs with probability predictions. The
default value (``None``) will be interpreted as 1 for these inputs.

Should be left at default (``None``) for all other types of inputs.
subset_accuracy:
Whether to compute subset accuracy for multi-label and multi-dimensional
multi-class inputs (has no effect for other input types).

For multi-label inputs, if the parameter is set to `True`, then all labels for
each sample must be correctly predicted for the sample to count as correct. If it
is set to `False`, then all labels are counted separately - this is equivalent to
flattening inputs beforehand (i.e. ``preds = preds.flatten()`` and same for ``target``).

For multi-dimensional multi-class inputs, if the parameter is set to `True`, then all
sub-samples (on the extra axis) must be correct for the sample to be counted as correct.
If it is set to `False`, then all sub-samples are counted separately - this is equivalent,
in the case of label predictions, to flattening the inputs beforehand (i.e.
``preds = preds.flatten()`` and same for ``target``). Note that the ``top_k`` parameter
still applies in both cases, if set.
compute_on_step:
Forward only calls ``update()`` and return None if this is set to False. default: True
Forward only calls ``update()`` and return None if this is set to False.
dist_sync_on_step:
Synchronize metric state across processes at each ``forward()``
before returning the value at the step. default: False
@@ -63,10 +85,19 @@ class Accuracy(Metric):
>>> accuracy(preds, target)
tensor(0.5000)

>>> target = torch.tensor([0, 1, 2])
>>> preds = torch.tensor([[0.1, 0.9, 0], [0.3, 0.1, 0.6], [0.2, 0.5, 0.3]])
>>> accuracy = Accuracy(top_k=2)
>>> accuracy(preds, target)
tensor(0.6667)

"""

def __init__(
self,
threshold: float = 0.5,
top_k: Optional[int] = None,
subset_accuracy: bool = False,
compute_on_step: bool = True,
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
@@ -82,24 +113,35 @@ def __init__(
self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum")
self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

if not 0 <= threshold <= 1:
raise ValueError("The `threshold` should lie in the [0,1] interval.")

if top_k is not None and top_k <= 0:
raise ValueError("The `top_k` should be an integer larger than 1.")

self.threshold = threshold
self.top_k = top_k
self.subset_accuracy = subset_accuracy

def update(self, preds: torch.Tensor, target: torch.Tensor):
"""
Update state with predictions and targets.
Update state with predictions and targets. See :ref:`metrics:Input types` for more information
on input types.

Args:
preds: Predictions from model
target: Ground truth values
preds: Predictions from model (probabilities, or labels)
target: Ground truth labels
"""
preds, target = _input_format_classification(preds, target, self.threshold)
assert preds.shape == target.shape

self.correct += torch.sum(preds == target)
self.total += target.numel()
correct, total = _accuracy_update(
preds, target, threshold=self.threshold, top_k=self.top_k, subset_accuracy=self.subset_accuracy
)

self.correct += correct
self.total += total

def compute(self):
def compute(self) -> torch.Tensor:
"""
Computes accuracy over state.
Computes accuracy based on inputs passed in to ``update`` previously.
"""
return self.correct.float() / self.total
return _accuracy_compute(self.correct, self.total)
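A short sketch of the difference between the default ("global") accuracy and ``subset_accuracy=True`` for multi-label inputs, following the semantics described in the docstring above (the input tensors are illustrative and the printed values are derived from those semantics, not verified outputs):

import torch
from pytorch_lightning.metrics import Accuracy

# Multi-label inputs: 2 samples, 3 labels each, given as 0/1 label predictions.
target = torch.tensor([[0, 1, 1], [1, 0, 1]])
preds = torch.tensor([[0, 1, 1], [1, 1, 1]])

global_acc = Accuracy()                      # each label counted separately: 5 of 6 correct
subset_acc = Accuracy(subset_accuracy=True)  # whole label vector must match: 1 of 2 samples

print(global_acc(preds, target))  # tensor(0.8333)
print(subset_acc(preds, target))  # tensor(0.5000)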
105 changes: 105 additions & 0 deletions pytorch_lightning/metrics/classification/hamming_distance.py
@@ -0,0 +1,105 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Optional

import torch
from pytorch_lightning.metrics.metric import Metric
from pytorch_lightning.metrics.functional.hamming_distance import _hamming_distance_update, _hamming_distance_compute


class HammingDistance(Metric):
r"""
Computes the average `Hamming distance <https://en.wikipedia.org/wiki/Hamming_distance>`_ (also
known as Hamming loss) between targets and predictions:

.. math::
\text{Hamming distance} = \frac{1}{N \cdot L}\sum_i^N \sum_l^L 1(y_{il} \neq \hat{y_{il}})

Where :math:`y` is a tensor of target values, :math:`\hat{y}` is a tensor of predictions,
and :math:`\bullet_{il}` refers to the :math:`l`-th label of the :math:`i`-th sample of that
tensor.

This is the same as ``1-accuracy`` for binary data, while for all other types of inputs it
treats each possible label separately - meaning that, for example, multi-class data is
treated as if it were multi-label.

Accepts all input types listed in :ref:`metrics:Input types`.

Args:
threshold:
Threshold probability value for transforming probability predictions to binary
`(0,1)` predictions, in the case of binary or multi-label inputs.
compute_on_step:
Forward only calls ``update()`` and return None if this is set to False.
dist_sync_on_step:
Synchronize metric state across processes at each ``forward()``
before returning the value at the step.
process_group:
Specify the process group on which synchronization is called. default: None (which selects the entire world)
dist_sync_fn:
Callback that performs the allgather operation on the metric state. When ``None``, DDP
will be used to perform the all gather.

Example:

>>> from pytorch_lightning.metrics import HammingDistance
>>> target = torch.tensor([[0, 1], [1, 1]])
>>> preds = torch.tensor([[0, 1], [0, 1]])
>>> hamming_distance = HammingDistance()
>>> hamming_distance(preds, target)
tensor(0.2500)

"""

def __init__(
self,
threshold: float = 0.5,
compute_on_step: bool = True,
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
dist_sync_fn: Callable = None,
):
super().__init__(
compute_on_step=compute_on_step,
dist_sync_on_step=dist_sync_on_step,
process_group=process_group,
dist_sync_fn=dist_sync_fn,
)

self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum")
self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

if not 0 <= threshold <= 1:
raise ValueError("The `threshold` should lie in the [0,1] interval.")
self.threshold = threshold

def update(self, preds: torch.Tensor, target: torch.Tensor):
"""
Update state with predictions and targets. See :ref:`metrics:Input types` for more information
on input types.

Args:
preds: Predictions from model (probabilities, or labels)
target: Ground truth labels
"""
correct, total = _hamming_distance_update(preds, target, self.threshold)

self.correct += correct
self.total += total

def compute(self) -> torch.Tensor:
"""
Computes hamming distance based on inputs passed in to ``update`` previously.
"""
return _hamming_distance_compute(self.correct, self.total)
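As the docstring notes, hamming distance equals ``1 - accuracy`` for binary inputs and counts every label separately otherwise. A minimal sketch of the module and functional forms, reusing the docstring example plus an illustrative binary case:

import torch
from pytorch_lightning.metrics import HammingDistance
from pytorch_lightning.metrics.functional import hamming_distance

# Multi-label example from the docstring: 1 of 4 label entries differs -> 0.25.
target = torch.tensor([[0, 1], [1, 1]])
preds = torch.tensor([[0, 1], [0, 1]])
print(HammingDistance()(preds, target))  # tensor(0.2500)
print(hamming_distance(preds, target))   # functional form, same result

# Binary case: hamming distance is the complement of accuracy.
target_bin = torch.tensor([0, 1, 1, 0])
preds_bin = torch.tensor([0, 1, 0, 0])
print(hamming_distance(preds_bin, target_bin))  # tensor(0.2500), i.e. 1 - 0.75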
5 changes: 5 additions & 0 deletions pytorch_lightning/metrics/classification/helpers.py
@@ -405,6 +405,11 @@ def _input_format_classification(
else:
preds, target = preds.squeeze(), target.squeeze()

# Convert half precision tensors to full precision, as not all ops are supported
# for example, min() is not supported
if preds.dtype == torch.float16:
preds = preds.float()

case = _check_classification_inputs(
preds,
target,
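The snippet above upcasts half-precision predictions before further processing. A rough sketch of the situation it covers (fp16 probability predictions, e.g. from mixed-precision inference; tensors are illustrative and the printed value assumes the documented behaviour):

import torch
from pytorch_lightning.metrics import Accuracy

# Half-precision multi-class probabilities; rows sum to 1, as the input checks require.
preds = torch.tensor([[0.75, 0.25], [0.25, 0.75]], dtype=torch.float16)
target = torch.tensor([0, 0])

# The formatting helper above converts preds to float32 internally, so ops
# without half-precision support (the diff cites min()) do not fail.
accuracy = Accuracy()
print(accuracy(preds, target))  # tensor(0.5000): only the first sample is correct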
3 changes: 2 additions & 1 deletion pytorch_lightning/metrics/functional/__init__.py
@@ -13,7 +13,6 @@
# limitations under the License.
from pytorch_lightning.metrics.functional.average_precision import average_precision # noqa: F401
from pytorch_lightning.metrics.functional.classification import ( # noqa: F401
accuracy,
auc,
auroc,
dice_score,
@@ -32,8 +31,10 @@
)
from pytorch_lightning.metrics.functional.confusion_matrix import confusion_matrix # noqa: F401
# TODO: unify metrics between class and functional, add below
from pytorch_lightning.metrics.functional.accuracy import accuracy # noqa: F401
from pytorch_lightning.metrics.functional.explained_variance import explained_variance # noqa: F401
from pytorch_lightning.metrics.functional.f_beta import fbeta, f1 # noqa: F401
from pytorch_lightning.metrics.functional.hamming_distance import hamming_distance # noqa: F401
from pytorch_lightning.metrics.functional.mean_absolute_error import mean_absolute_error # noqa: F401
from pytorch_lightning.metrics.functional.mean_squared_error import mean_squared_error # noqa: F401
from pytorch_lightning.metrics.functional.mean_squared_log_error import mean_squared_log_error # noqa: F401
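The import reshuffle above moves ``accuracy`` to its new module and adds ``hamming_distance`` to the functional namespace. A minimal sketch of the functional API as exposed after this change (input tensors are illustrative; the printed values follow from the documented semantics rather than from running this PR):

import torch
from pytorch_lightning.metrics.functional import accuracy, hamming_distance

target = torch.tensor([0, 1, 2, 3])
preds = torch.tensor([0, 2, 1, 3])

print(accuracy(preds, target))          # tensor(0.5000): 2 of 4 samples correct
print(hamming_distance(preds, target))  # tensor(0.2500): 4 of 16 one-hot label entries differ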