Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add Matthews corrcoef #98

Merged
merged 15 commits into from
Mar 23, 2021
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added `CohenKappa` metric ([#69](https://github.com/PyTorchLightning/metrics/pull/69))


- Added `MatthewsCorrcoef` metric ([#98](https://github.com/PyTorchLightning/metrics/pull/98))

### Changed

- Changed `ExplainedVariance` from storing all preds/targets to tracking 5 statistics ([#68](https://github.com/PyTorchLightning/metrics/pull/68))
Expand Down
5 changes: 5 additions & 0 deletions docs/source/references/functional.rst
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,11 @@ iou [func]
.. autofunction:: torchmetrics.functional.iou
:noindex:

matthews_corrcoef [func]
~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: torchmetrics.functional.matthews_corrcoef
:noindex:

roc [func]
~~~~~~~~~~~~~~~~~~~~~
Expand Down
6 changes: 6 additions & 0 deletions docs/source/references/modules.rst
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,12 @@ IoU
.. autoclass:: torchmetrics.IoU
:noindex:

MatthewsCorrcoef
~~~~~~~~~~~~~~~~

.. autoclass:: torchmetrics.MatthewsCorrcoef
:noindex:

Hamming Distance
~~~~~~~~~~~~~~~~

Expand Down
127 changes: 127 additions & 0 deletions tests/classification/test_matthews_corrcoef.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import pytest
import torch
from sklearn.metrics import matthews_corrcoef as sk_matthews_corrcoef

from tests.classification.inputs import _input_binary, _input_binary_prob
from tests.classification.inputs import _input_multiclass as _input_mcls
from tests.classification.inputs import _input_multiclass_prob as _input_mcls_prob
from tests.classification.inputs import _input_multidim_multiclass as _input_mdmc
from tests.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob
from tests.classification.inputs import _input_multilabel as _input_mlb
from tests.classification.inputs import _input_multilabel_prob as _input_mlb_prob
from tests.helpers.testers import NUM_CLASSES, THRESHOLD, MetricTester
from torchmetrics.classification.matthews_corrcoef import MatthewsCorrcoef
from torchmetrics.functional.classification.matthews_corrcoef import matthews_corrcoef

# Seed torch's global RNG with a fixed value when this test module is imported.
# NOTE(review): presumably for reproducible test data — confirm what random state
# this actually affects, since the input fixtures are imported above this line.
torch.manual_seed(42)


def _sk_matthews_corrcoef_binary_prob(preds, target):
    """scikit-learn reference for binary probability inputs.

    Flattens both tensors, turns the predictions into 0/1 labels by comparing
    them against ``THRESHOLD`` (``>=``), and scores them with scikit-learn.
    """
    hard_preds = (preds.view(-1).numpy() >= THRESHOLD).astype(np.uint8)
    flat_target = target.view(-1).numpy()
    return sk_matthews_corrcoef(y_true=flat_target, y_pred=hard_preds)


def _sk_matthews_corrcoef_binary(preds, target):
    """scikit-learn reference for binary label inputs.

    Both tensors are flattened to 1-D and handed to scikit-learn unchanged.
    """
    flat_preds = preds.view(-1).numpy()
    flat_target = target.view(-1).numpy()
    return sk_matthews_corrcoef(y_true=flat_target, y_pred=flat_preds)


def _sk_matthews_corrcoef_multilabel_prob(preds, target):
    """scikit-learn reference for multi-label probability inputs.

    Every element is treated as one binary decision: both tensors are
    flattened and the predictions are binarised with ``>= THRESHOLD``
    before being scored by scikit-learn.
    """
    hard_preds = (preds.view(-1).numpy() >= THRESHOLD).astype(np.uint8)
    flat_target = target.view(-1).numpy()
    return sk_matthews_corrcoef(y_true=flat_target, y_pred=hard_preds)


def _sk_matthews_corrcoef_multilabel(preds, target):
    """scikit-learn reference for multi-label label inputs.

    Both tensors are flattened to 1-D so every element counts as one
    decision, then scored by scikit-learn.
    """
    flat_preds = preds.view(-1).numpy()
    flat_target = target.view(-1).numpy()
    return sk_matthews_corrcoef(y_true=flat_target, y_pred=flat_preds)


def _sk_matthews_corrcoef_multiclass_prob(preds, target):
    """scikit-learn reference for multi-class probability inputs.

    Predicted labels are the argmax over the last dimension of ``preds``;
    labels and targets are then flattened and scored by scikit-learn.
    """
    last_dim = preds.dim() - 1
    label_preds = preds.argmax(dim=last_dim).view(-1).numpy()
    flat_target = target.view(-1).numpy()
    return sk_matthews_corrcoef(y_true=flat_target, y_pred=label_preds)


def _sk_matthews_corrcoef_multiclass(preds, target):
    """scikit-learn reference for multi-class label inputs.

    Both tensors are flattened to 1-D and handed to scikit-learn unchanged.
    """
    flat_preds = preds.view(-1).numpy()
    flat_target = target.view(-1).numpy()
    return sk_matthews_corrcoef(y_true=flat_target, y_pred=flat_preds)


def _sk_matthews_corrcoef_multidim_multiclass_prob(preds, target):
    """scikit-learn reference for multi-dim multi-class probability inputs.

    Predicted labels are the argmax over the second-to-last dimension of
    ``preds``; labels and targets are then flattened and scored by
    scikit-learn.
    """
    second_to_last_dim = preds.dim() - 2
    label_preds = preds.argmax(dim=second_to_last_dim).view(-1).numpy()
    flat_target = target.view(-1).numpy()
    return sk_matthews_corrcoef(y_true=flat_target, y_pred=label_preds)


def _sk_matthews_corrcoef_multidim_multiclass(preds, target):
    """scikit-learn reference for multi-dim multi-class label inputs.

    Both tensors are flattened to 1-D and handed to scikit-learn unchanged.
    """
    flat_preds = preds.view(-1).numpy()
    flat_target = target.view(-1).numpy()
    return sk_matthews_corrcoef(y_true=flat_target, y_pred=flat_preds)


@pytest.mark.parametrize(
    "preds, target, sk_metric, num_classes",
    [
        (_input_binary_prob.preds, _input_binary_prob.target, _sk_matthews_corrcoef_binary_prob, 2),
        (_input_binary.preds, _input_binary.target, _sk_matthews_corrcoef_binary, 2),
        (_input_mlb_prob.preds, _input_mlb_prob.target, _sk_matthews_corrcoef_multilabel_prob, 2),
        (_input_mlb.preds, _input_mlb.target, _sk_matthews_corrcoef_multilabel, 2),
        (_input_mcls_prob.preds, _input_mcls_prob.target, _sk_matthews_corrcoef_multiclass_prob, NUM_CLASSES),
        (_input_mcls.preds, _input_mcls.target, _sk_matthews_corrcoef_multiclass, NUM_CLASSES),
        (
            _input_mdmc_prob.preds,
            _input_mdmc_prob.target,
            _sk_matthews_corrcoef_multidim_multiclass_prob,
            NUM_CLASSES,
        ),
        (_input_mdmc.preds, _input_mdmc.target, _sk_matthews_corrcoef_multidim_multiclass, NUM_CLASSES),
    ],
)
class TestMatthewsCorrCoef(MetricTester):
    """Compares the class-based and functional MCC against the scikit-learn references.

    Each parametrised case pairs an input fixture with the matching
    ``_sk_matthews_corrcoef_*`` reference and the number of classes to use.
    """

    @pytest.mark.parametrize("ddp", [True, False])
    @pytest.mark.parametrize("dist_sync_on_step", [True, False])
    def test_matthews_corrcoef(self, preds, target, sk_metric, num_classes, ddp, dist_sync_on_step):
        """Run the ``MatthewsCorrcoef`` module through the shared class-metric tester."""
        metric_kwargs = {"num_classes": num_classes, "threshold": THRESHOLD}
        self.run_class_metric_test(
            ddp=ddp,
            preds=preds,
            target=target,
            metric_class=MatthewsCorrcoef,
            sk_metric=sk_metric,
            dist_sync_on_step=dist_sync_on_step,
            metric_args=metric_kwargs,
        )

    def test_matthews_corrcoef_functional(self, preds, target, sk_metric, num_classes):
        """Run the functional ``matthews_corrcoef`` through the shared functional tester."""
        metric_kwargs = {"num_classes": num_classes, "threshold": THRESHOLD}
        self.run_functional_metric_test(
            preds,
            target,
            metric_functional=matthews_corrcoef,
            sk_metric=sk_metric,
            metric_args=metric_kwargs,
        )
1 change: 1 addition & 0 deletions torchmetrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
FBeta,
HammingDistance,
IoU,
MatthewsCorrcoef,
Precision,
PrecisionRecallCurve,
Recall,
Expand Down
1 change: 1 addition & 0 deletions torchmetrics/classification/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from torchmetrics.classification.f_beta import F1, FBeta # noqa: F401
from torchmetrics.classification.hamming_distance import HammingDistance # noqa: F401
from torchmetrics.classification.iou import IoU # noqa: F401
from torchmetrics.classification.matthews_corrcoef import MatthewsCorrcoef # noqa: F401
from torchmetrics.classification.precision_recall import Precision, Recall # noqa: F401
from torchmetrics.classification.precision_recall_curve import PrecisionRecallCurve # noqa: F401
from torchmetrics.classification.roc import ROC # noqa: F401
Expand Down
113 changes: 113 additions & 0 deletions torchmetrics/classification/matthews_corrcoef.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Optional

import torch

from torchmetrics.functional.classification.matthews_corrcoef import (
_matthews_corrcoef_compute,
_matthews_corrcoef_update,
)
from torchmetrics.metric import Metric


class MatthewsCorrcoef(Metric):
    r"""
    Calculates `Matthews correlation coefficient
    <https://en.wikipedia.org/wiki/Matthews_correlation_coefficient>`_ that measures
    the general correlation or quality of a classification. In the binary case it
    is defined as:

    .. math::
        MCC = \frac{TP*TN - FP*FN}{\sqrt{(TP+FP)*(TP+FN)*(TN+FP)*(TN+FN)}}

    where TP, TN, FP and FN are respectively the true positives, true negatives,
    false positives and false negatives. Also works in the case of multi-label or
    multi-class input.

    The metric accumulates a ``(num_classes, num_classes)`` confusion matrix over
    all ``update()`` calls; ``compute()`` reduces it to a single scalar tensor.

    Forward accepts

    - ``preds`` (float or long tensor): ``(N, ...)`` or ``(N, C, ...)`` where C is the number of classes
    - ``target`` (long tensor): ``(N, ...)``

    If preds and target are the same shape and preds is a float tensor, we use the ``self.threshold`` argument
    to convert into integer labels. This is the case for binary and multi-label probabilities.

    If preds has an extra dimension as in the case of multi-class scores we perform an argmax on ``dim=1``.

    Args:
        num_classes: Number of classes in the dataset.
        threshold:
            Threshold value for binary or multi-label probabilities. default: 0.5
        compute_on_step:
            Forward only calls ``update()`` and return None if this is set to False. default: True
        dist_sync_on_step:
            Synchronize metric state across processes at each ``forward()``
            before returning the value at the step. default: False
        process_group:
            Specify the process group on which synchronization is called. default: None (which selects the entire world)
        dist_sync_fn:
            Callback that performs the allgather operation on the metric state. When ``None``, DDP
            will be used to perform the allgather

    Example:

        >>> from torchmetrics import MatthewsCorrcoef
        >>> target = torch.tensor([1, 1, 0, 0])
        >>> preds = torch.tensor([0, 1, 0, 0])
        >>> matthews_corrcoef = MatthewsCorrcoef(num_classes=2)
        >>> matthews_corrcoef(preds, target)
        tensor(0.5774)

    """

    def __init__(
        self,
        num_classes: int,
        threshold: float = 0.5,
        compute_on_step: bool = True,
        dist_sync_on_step: bool = False,
        process_group: Optional[Any] = None,
        dist_sync_fn: Optional[Callable] = None,
    ) -> None:

        super().__init__(
            compute_on_step=compute_on_step,
            dist_sync_on_step=dist_sync_on_step,
            process_group=process_group,
            dist_sync_fn=dist_sync_fn,
        )
        self.num_classes = num_classes
        self.threshold = threshold

        # Running confusion matrix; registered with a "sum" reduction so that
        # per-process matrices are added together when the state is synchronized.
        self.add_state("confmat", default=torch.zeros(num_classes, num_classes), dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
        """
        Update state with predictions and targets.

        Args:
            preds: Predictions from model
            target: Ground truth values
        """
        confmat = _matthews_corrcoef_update(preds, target, self.num_classes, self.threshold)
        self.confmat += confmat

    def compute(self) -> torch.Tensor:
        """
        Computes the Matthews correlation coefficient from the accumulated confusion matrix.
        """
        return _matthews_corrcoef_compute(self.confmat)
1 change: 1 addition & 0 deletions torchmetrics/functional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from torchmetrics.functional.classification.f_beta import f1, fbeta # noqa: F401
from torchmetrics.functional.classification.hamming_distance import hamming_distance # noqa: F401
from torchmetrics.functional.classification.iou import iou # noqa: F401
from torchmetrics.functional.classification.matthews_corrcoef import matthews_corrcoef # noqa: F401
from torchmetrics.functional.classification.precision_recall import precision, precision_recall, recall # noqa: F401
from torchmetrics.functional.classification.precision_recall_curve import precision_recall_curve # noqa: F401
from torchmetrics.functional.classification.roc import roc # noqa: F401
Expand Down
1 change: 1 addition & 0 deletions torchmetrics/functional/classification/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from torchmetrics.functional.classification.f_beta import f1, fbeta # noqa: F401
from torchmetrics.functional.classification.hamming_distance import hamming_distance # noqa: F401
from torchmetrics.functional.classification.iou import iou # noqa: F401
from torchmetrics.functional.classification.matthews_corrcoef import matthews_corrcoef # noqa: F401
from torchmetrics.functional.classification.precision_recall import precision, precision_recall, recall # noqa: F401
from torchmetrics.functional.classification.precision_recall_curve import precision_recall_curve # noqa: F401
from torchmetrics.functional.classification.roc import roc # noqa: F401
Expand Down
65 changes: 65 additions & 0 deletions torchmetrics/functional/classification/matthews_corrcoef.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch

from torchmetrics.functional.classification.confusion_matrix import _confusion_matrix_update

# The coefficient is computed purely from a confusion matrix (see
# ``_matthews_corrcoef_compute``), so the update step is the confusion-matrix update itself.
_matthews_corrcoef_update = _confusion_matrix_update


def _matthews_corrcoef_compute(confmat: torch.Tensor) -> torch.Tensor:
    """Reduce a confusion matrix to the Matthews correlation coefficient.

    Uses the multi-class formulation built from the matrix's row sums,
    column sums, trace (correctly classified samples) and total count.

    Args:
        confmat: 2-D square confusion matrix of counts.

    Returns:
        Scalar float tensor with the coefficient. When the denominator is
        zero (for example, only a single class occurs in the predictions or
        in the targets) the coefficient is undefined; ``0`` is returned
        instead of ``nan``, matching scikit-learn's behaviour.
    """
    col_sums = confmat.sum(dim=0).float()
    row_sums = confmat.sum(dim=1).float()
    n_correct = torch.trace(confmat).float()
    n_samples = confmat.sum().float()

    numerator = n_correct * n_samples - torch.sum(col_sums * row_sums)
    row_term = n_samples ** 2 - torch.sum(row_sums * row_sums)
    col_term = n_samples ** 2 - torch.sum(col_sums * col_sums)

    # Keep the two square roots separate (rather than sqrt of the product)
    # so that large sample counts do not overflow in float32.
    denominator = torch.sqrt(row_term) * torch.sqrt(col_term)

    # Guard the 0 / 0 case without a Python-level branch on tensor data.
    return torch.where(denominator == 0, torch.zeros_like(denominator), numerator / denominator)


def matthews_corrcoef(
    preds: torch.Tensor,
    target: torch.Tensor,
    num_classes: int,
    threshold: float = 0.5
) -> torch.Tensor:
    r"""
    Computes the `Matthews correlation coefficient
    <https://en.wikipedia.org/wiki/Matthews_correlation_coefficient>`_, a measure of
    the overall correlation between predictions and targets in a classification
    task. For binary problems it reads:

    .. math::
        MCC = \frac{TP*TN - FP*FN}{\sqrt{(TP+FP)*(TP+FN)*(TN+FP)*(TN+FN)}}

    with TP, TN, FP and FN denoting the true positives, true negatives,
    false positives and false negatives respectively. Multi-label and
    multi-class inputs are supported as well.

    Args:
        preds: (float or long tensor), either a ``(N, ...)`` tensor of labels or a
            ``(N, C, ...)`` tensor of labels/probabilities, where C is the number of classes
        target: (long tensor), tensor of shape ``(N, ...)`` holding the ground truth labels
        num_classes: Number of classes in the dataset.
        threshold:
            Threshold value for binary or multi-label probabilities. default: 0.5

    Example:
        >>> from torchmetrics.functional import matthews_corrcoef
        >>> target = torch.tensor([1, 1, 0, 0])
        >>> preds = torch.tensor([0, 1, 0, 0])
        >>> matthews_corrcoef(preds, target, num_classes=2)
        tensor(0.5774)

    """
    # Build the confusion matrix first, then reduce it to the scalar coefficient.
    return _matthews_corrcoef_compute(
        _matthews_corrcoef_update(preds, target, num_classes, threshold)
    )