Skip to content

Commit

Permalink
Classification metrics overhaul: stat scores (3/n) (#4839)
Browse files Browse the repository at this point in the history
* Add stuff

* Change metrics documentation layout

* Add stuff

* Add stat scores

* Change testing utils

* Replace len(*.shape) with *.ndim

* More descriptive error message for input formatting

* Replace movedim with permute

* PEP 8 compliance

* WIP

* Add reduce_scores function

* Temporarily add back legacy class_reduce

* Division with float

* PEP 8 compliance

* Remove precision recall

* Replace movedim with permute

* Add back tests

* Add empty newlines

* Add empty line

* Fix permute

* Fix some issues with old versions of PyTorch

* Style changes in error messages

* More error message style improvements

* Fix typo in docs

* Add more descriptive variable names in utils

* Change internal var names

* Break down error checking for inputs into separate functions

* Remove the (N, ..., C) option in MD-MC

* Simplify select_topk

* Remove detach for inputs

* Fix typos

* Update pytorch_lightning/metrics/classification/utils.py

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

* Update docs/source/metrics.rst

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

* Minor error message changes

* Update pytorch_lightning/metrics/utils.py

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

* Reuse case from validation in formatting

* Refactor code in _input_format_classification

* Small improvements

* PEP 8

* Update pytorch_lightning/metrics/classification/utils.py

Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>

* Update pytorch_lightning/metrics/classification/utils.py

Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>

* Update docs/source/metrics.rst

Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>

* Update pytorch_lightning/metrics/classification/utils.py

Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>

* Apply suggestions from code review

Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>

* Alphabetical reordering of regression metrics

* Change default value of top_k and add error checking

* Extract basic validation into separate function

* Update to new top_k default

* Update desciption of parameters in input formatting

* Apply suggestions from code review

Co-authored-by: Nicki Skafte <skaftenicki@gmail.com>

* Check that probabilities in preds sum to 1 (for MC)

* Fix coverage

* Split accuracy and hamming loss

* Remove old redundant accuracy

* Minor changes

* Fix imports

* Improve docstring descriptions

* Fix imports

* Fix edge case and simplify testing

* Fix docs

* PEP8

* Reorder imports

* Add top_k parameter

* Update changelog

* Update docstring

* Update docstring

* Reverse formatting changes for tests

* Change parameter order

* Remove formatting changes 2/2

* Remove formatting 3/3

* .

* Improve description of top_k parameter

* Apply suggestions from code review

* Apply suggestions from code review

Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>

* Remove unneeded assert

* Update pytorch_lightning/metrics/functional/accuracy.py

Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>

* Remove unneeded assert

* Explicit checking of parameter values

* Apply suggestions from code review

Co-authored-by: Nicki Skafte <skaftenicki@gmail.com>

* Apply suggestions from code review

* Fix top_k checking

* PEP8

* Don't check dist_sync in test

* add back check_dist_sync_on_step

* Make sure half-precision inputs are transformed (#5013)

* Fix typo

* Rename hamming loss to hamming distance

* Fix tests for half precision

* Fix docs underline length

* Fix doc undeline length

* Replace mdmc_accuracy parameter with subset_accuracy

* Update changelog

* Fix unwanted accuracy change

* Enable top_k for ML prob inputs

* Test that default threshold is 0.5

* Fix typo

* Update top_k description in helpers

* updates

* Update styling and add back tests

* Remove excess spaces

* fix torch.where for old versions

* fix linting

* Update docstring

* Fix docstring

* Apply suggestions from code review (mostly docs)

* Default threshold to None, accept only (0,1)

* Change wrong threshold message

* Improve documentation and add tests

* Add back ddp tests

* Change stat reduce method and default

* Remove DDP tests and fix doctests

* Fix doctest

* Update changelog

* Refactoring

* Fix typo

* Refactor

* Increase coverage

* Fix linting

* Consistent use of backticks

* Fix too long line in docs

* Apply suggestions from code review

* Fix deprecation test

* Fix deprecation test

* Default threshold back to 0.5

* Minor documentation fixes

* Add types to tests

Co-authored-by: Teddy Koker <teddy.koker@gmail.com>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: chaton <thomas@grid.ai>
Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
Co-authored-by: Nicki Skafte <skaftenicki@gmail.com>
Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>
  • Loading branch information
7 people authored Dec 30, 2020
1 parent 2094633 commit 7f71ee9
Show file tree
Hide file tree
Showing 16 changed files with 948 additions and 142 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,16 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added `max_fpr` parameter to `auroc` metric for computing partial auroc metric ([#3790](https://github.com/PyTorchLightning/pytorch-lightning/pull/3790))

- `StatScores` metric to compute the number of true positives, false positives, true negatives and false negatives ([#4839](https://github.com/PyTorchLightning/pytorch-lightning/pull/4839))


### Changed

- `stat_scores` metric now calculates stat scores over all classes and gains new parameters, in line with the new `StatScores` metric ([#4839](https://github.com/PyTorchLightning/pytorch-lightning/pull/4839))

### Deprecated

- `stat_scores_multiple_classes` is deprecated in favor of `stat_scores` ([#4839](https://github.com/PyTorchLightning/pytorch-lightning/pull/4839))

### Removed

Expand Down
68 changes: 62 additions & 6 deletions docs/source/metrics.rst
Original file line number Diff line number Diff line change
Expand Up @@ -251,13 +251,62 @@ the possible class labels are 0, 1, 2, 3, etc. Below are some examples of differ
ml_preds = torch.tensor([[0.2, 0.8, 0.9], [0.5, 0.6, 0.1], [0.3, 0.1, 0.1]])
ml_target = torch.tensor([[0, 1, 1], [1, 0, 0], [0, 0, 0]])

In some rare cases, you might have inputs which appear to be (multi-dimensional) multi-class
but are actually binary/multi-label. For example, if both predictions and targets are 1d
binary tensors. Or it could be the other way around, you want to treat binary/multi-label
inputs as 2-class (multi-dimensional) multi-class inputs.

Using the ``is_multiclass`` parameter
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

In some cases, you might have inputs which appear to be (multi-dimensional) multi-class
but are actually binary/multi-label - for example, if both predictions and targets are
integer (binary) tensors. Or it could be the other way around, you want to treat
binary/multi-label inputs as 2-class (multi-dimensional) multi-class inputs.

For these cases, the metrics where this distinction would make a difference, expose the
``is_multiclass`` argument.
``is_multiclass`` argument. Let's see how this is used on the example of
:class:`~pytorch_lightning.metrics.classification.StatScores` metric.

First, let's consider the case with label predictions with 2 classes, which we want to
treat as binary.

.. testcode::

from pytorch_lightning.metrics.functional import stat_scores

# These inputs are supposed to be binary, but appear as multi-class
preds = torch.tensor([0, 1, 0])
target = torch.tensor([1, 1, 0])

As you can see below, by default the inputs are treated
as multi-class. We can set ``is_multiclass=False`` to treat the inputs as binary -
which is the same as converting the predictions to float beforehand.

.. doctest::

>>> stat_scores(preds, target, reduce='macro', num_classes=2)
tensor([[1, 1, 1, 0, 1],
[1, 0, 1, 1, 2]])
>>> stat_scores(preds, target, reduce='macro', num_classes=1, is_multiclass=False)
tensor([[1, 0, 1, 1, 2]])
>>> stat_scores(preds.float(), target, reduce='macro', num_classes=1)
tensor([[1, 0, 1, 1, 2]])

Next, consider the opposite example: inputs are binary (as predictions are probabilities),
but we would like to treat them as 2-class multi-class, to obtain the metric for both classes.

.. testcode::

preds = torch.tensor([0.2, 0.7, 0.3])
target = torch.tensor([1, 1, 0])

In this case we can set ``is_multiclass=True``, to treat the inputs as multi-class.

.. doctest::

>>> stat_scores(preds, target, reduce='macro', num_classes=1)
tensor([[1, 0, 1, 1, 2]])
>>> stat_scores(preds, target, reduce='macro', num_classes=2, is_multiclass=True)
tensor([[1, 1, 1, 0, 1],
[1, 0, 1, 1, 2]])


Class Metrics (Classification)
------------------------------
Expand Down Expand Up @@ -323,6 +372,13 @@ ROC
:noindex:


StatScores
~~~~~~~~~~

.. autoclass:: pytorch_lightning.metrics.classification.StatScores
:noindex:


Functional Metrics (Classification)
-----------------------------------

Expand Down Expand Up @@ -444,7 +500,7 @@ select_topk [func]
stat_scores [func]
~~~~~~~~~~~~~~~~~~

.. autofunction:: pytorch_lightning.metrics.functional.classification.stat_scores
.. autofunction:: pytorch_lightning.metrics.functional.stat_scores
:noindex:


Expand Down
1 change: 1 addition & 0 deletions pytorch_lightning/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
ROC,
FBeta,
F1,
StatScores
)

from pytorch_lightning.metrics.regression import ( # noqa: F401
Expand Down
1 change: 1 addition & 0 deletions pytorch_lightning/metrics/classification/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,4 @@
from pytorch_lightning.metrics.classification.precision_recall import Precision, Recall # noqa: F401
from pytorch_lightning.metrics.classification.precision_recall_curve import PrecisionRecallCurve # noqa: F401
from pytorch_lightning.metrics.classification.roc import ROC # noqa: F401
from pytorch_lightning.metrics.classification.stat_scores import StatScores # noqa: F401
46 changes: 24 additions & 22 deletions pytorch_lightning/metrics/classification/accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

class Accuracy(Metric):
r"""
Computes `Accuracy <https://en.wikipedia.org/wiki/Accuracy_and_precision>`_:
Computes `Accuracy <https://en.wikipedia.org/wiki/Accuracy_and_precision>`__:
.. math::
\text{Accuracy} = \frac{1}{N}\sum_i^N 1(y_i = \hat{y}_i)
Expand All @@ -43,7 +43,7 @@ class Accuracy(Metric):
Args:
threshold:
Threshold probability value for transforming probability predictions to binary
`(0,1)` predictions, in the case of binary or multi-label inputs.
(0,1) predictions, in the case of binary or multi-label inputs.
top_k:
Number of highest probability predictions considered to find the correct label, relevant
only for (multi-dimensional) multi-class inputs with probability predictions. The
Expand All @@ -54,27 +54,29 @@ class Accuracy(Metric):
Whether to compute subset accuracy for multi-label and multi-dimensional
multi-class inputs (has no effect for other input types).
For multi-label inputs, if the parameter is set to `True`, then all labels for
each sample must be correctly predicted for the sample to count as correct. If it
is set to `False`, then all labels are counted separately - this is equivalent to
flattening inputs beforehand (i.e. ``preds = preds.flatten()`` and same for ``target``).
For multi-dimensional multi-class inputs, if the parameter is set to `True`, then all
sub-sample (on the extra axis) must be correct for the sample to be counted as correct.
If it is set to `False`, then all sub-samples are counter separately - this is equivalent,
in the case of label predictions, to flattening the inputs beforehand (i.e.
``preds = preds.flatten()`` and same for ``target``). Note that the ``top_k`` parameter
still applies in both cases, if set.
- For multi-label inputs, if the parameter is set to ``True``, then all labels for
each sample must be correctly predicted for the sample to count as correct. If it
is set to ``False``, then all labels are counted separately - this is equivalent to
flattening inputs beforehand (i.e. ``preds = preds.flatten()`` and same for ``target``).
- For multi-dimensional multi-class inputs, if the parameter is set to ``True``, then all
sub-sample (on the extra axis) must be correct for the sample to be counted as correct.
If it is set to ``False``, then all sub-samples are counter separately - this is equivalent,
in the case of label predictions, to flattening the inputs beforehand (i.e.
``preds = preds.flatten()`` and same for ``target``). Note that the ``top_k`` parameter
still applies in both cases, if set.
compute_on_step:
Forward only calls ``update()`` and return None if this is set to False.
Forward only calls ``update()`` and return ``None`` if this is set to ``False``.
dist_sync_on_step:
Synchronize metric state across processes at each ``forward()``
before returning the value at the step. default: False
before returning the value at the step
process_group:
Specify the process group on which synchronization is called. default: None (which selects the entire world)
Specify the process group on which synchronization is called.
default: ``None`` (which selects the entire world)
dist_sync_fn:
Callback that performs the allgather operation on the metric state. When `None`, DDP
will be used to perform the allgather. default: None
Callback that performs the allgather operation on the metric state. When ``None``, DDP
will be used to perform the allgather
Example:
Expand Down Expand Up @@ -113,11 +115,11 @@ def __init__(
self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum")
self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

if not 0 <= threshold <= 1:
raise ValueError("The `threshold` should lie in the [0,1] interval.")
if not 0 < threshold < 1:
raise ValueError(f"The `threshold` should be a float in the (0,1) interval, got {threshold}")

if top_k is not None and top_k <= 0:
raise ValueError("The `top_k` should be an integer larger than 1.")
if top_k is not None and (not isinstance(top_k, int) or top_k <= 0):
raise ValueError(f"The `top_k` should be an integer larger than 0, got {top_k}")

self.threshold = threshold
self.top_k = top_k
Expand Down
11 changes: 6 additions & 5 deletions pytorch_lightning/metrics/classification/hamming_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,15 @@ class HammingDistance(Metric):
Args:
threshold:
Threshold probability value for transforming probability predictions to binary
`(0,1)` predictions, in the case of binary or multi-label inputs.
(0 or 1) predictions, in the case of binary or multi-label inputs.
compute_on_step:
Forward only calls ``update()`` and return None if this is set to False.
Forward only calls ``update()`` and return ``None`` if this is set to ``False``.
dist_sync_on_step:
Synchronize metric state across processes at each ``forward()``
before returning the value at the step.
process_group:
Specify the process group on which synchronization is called. default: None (which selects the entire world)
Specify the process group on which synchronization is called.
default: ``None`` (which selects the entire world)
dist_sync_fn:
Callback that performs the allgather operation on the metric state. When ``None``, DDP
will be used to perform the all gather.
Expand Down Expand Up @@ -80,8 +81,8 @@ def __init__(
self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum")
self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

if not 0 <= threshold <= 1:
raise ValueError("The `threshold` should lie in the [0,1] interval.")
if not 0 < threshold < 1:
raise ValueError("The `threshold` should lie in the (0,1) interval.")
self.threshold = threshold

def update(self, preds: torch.Tensor, target: torch.Tensor):
Expand Down
Loading

0 comments on commit 7f71ee9

Please sign in to comment.