Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Pearson correlation coefficient #157

Merged
merged 26 commits into from
Apr 13, 2021
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added `__getitem__` as metric arithmetic operation ([#142](https://github.com/PyTorchLightning/metrics/pull/142))


- Added `PearsonCorrcoef` metric ([#157](https://github.com/PyTorchLightning/metrics/pull/157))


### Changed

- Changed `ExplainedVariance` from storing all preds/targets to tracking 5 statistics ([#68](https://github.com/PyTorchLightning/metrics/pull/68))
Expand Down
7 changes: 7 additions & 0 deletions docs/source/references/functional.rst
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,13 @@ mean_squared_log_error [func]
:noindex:


pearson_corrcoef [func]
~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: torchmetrics.functional.pearson_corrcoef
:noindex:


psnr [func]
~~~~~~~~~~~

Expand Down
7 changes: 7 additions & 0 deletions docs/source/references/modules.rst
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,13 @@ MeanSquaredLogError
:noindex:


PearsonCorrcoef
~~~~~~~~~~~~~~~

.. autoclass:: torchmetrics.PearsonCorrcoef
:noindex:


PSNR
~~~~

Expand Down
119 changes: 119 additions & 0 deletions tests/regression/test_pearson.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import namedtuple

import pytest
import torch
from scipy.stats import pearsonr

from tests.helpers import seed_all
from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester
from torchmetrics.functional.regression.pearson import _update_cov, _update_mean, pearson_corrcoef
from torchmetrics.regression.pearson import PearsonCorrcoef

seed_all(42)


def test_update_functions(tmpdir):
""" Test that updating the estimates are equal to estimating them on all data """
data = torch.randn(100, 2)
batch1, batch2 = data.chunk(2)

def _mean_cov(data):
mean = data.mean(0)
diff = data - mean
cov = diff.T @ diff
return mean, cov

mean_update, cov_update, size_update = torch.zeros(2), torch.zeros(2, 2), torch.zeros(1)
for batch in [batch1, batch2]:
new_mean = _update_mean(mean_update, size_update, batch)
new_cov = _update_cov(cov_update, mean_update, new_mean, batch)

assert not torch.allclose(new_mean, mean_update), "mean estimate did not update"
assert not torch.allclose(new_cov, cov_update), "covariance estimate did not update"

size_update += batch.shape[0]
mean_update = new_mean
cov_update = new_cov

mean, cov = _mean_cov(data)

assert torch.allclose(mean, mean_update), "updated mean does not correspond to mean of all data"
assert torch.allclose(cov, cov_update), "updated covariance does not correspond to covariance of all data"


Input = namedtuple('Input', ["preds", "target"])

_single_target_inputs1 = Input(
preds=torch.rand(NUM_BATCHES, BATCH_SIZE),
target=torch.rand(NUM_BATCHES, BATCH_SIZE),
)

_single_target_inputs2 = Input(
preds=torch.randn(NUM_BATCHES, BATCH_SIZE),
target=torch.randn(NUM_BATCHES, BATCH_SIZE),
)


def _sk_pearsonr(preds, target):
sk_preds = preds.view(-1).numpy()
sk_target = target.view(-1).numpy()
return pearsonr(sk_target, sk_preds)[0]


@pytest.mark.parametrize("preds, target", [
(_single_target_inputs1.preds, _single_target_inputs1.target),
(_single_target_inputs2.preds, _single_target_inputs2.target),
])
class TestPearsonCorrcoef(MetricTester):
atol = 1e-4

@pytest.mark.parametrize("ddp", [True, False])
@pytest.mark.parametrize("dist_sync_on_step", [True, False])
def test_pearson_corrcoef(self, preds, target, ddp, dist_sync_on_step):
self.run_class_metric_test(
ddp=ddp,
preds=preds,
target=target,
metric_class=PearsonCorrcoef,
sk_metric=_sk_pearsonr,
dist_sync_on_step=dist_sync_on_step,
)

def test_pearson_corrcoef_functional(self, preds, target):
self.run_functional_metric_test(
preds=preds,
target=target,
metric_functional=pearson_corrcoef,
sk_metric=_sk_pearsonr
)

# Pearson half + cpu does not work due to missing support in torch.sqrt
@pytest.mark.xfail(reason="PearsonCorrcoef metric does not support cpu + half precision")
def test_pearson_corrcoef_half_cpu(self, preds, target):
self.run_precision_test_cpu(preds, target, PearsonCorrcoef, pearson_corrcoef)

@pytest.mark.skipif(not torch.cuda.is_available(), reason='test requires cuda')
def test_pearson_corrcoef_half_gpu(self, preds, target):
self.run_precision_test_gpu(preds, target, PearsonCorrcoef, pearson_corrcoef)


def test_error_on_different_shape():
metric = PearsonCorrcoef()
with pytest.raises(RuntimeError, match='Predictions and targets are expected to have the same shape'):
metric(torch.randn(100, ), torch.randn(50, ))

with pytest.raises(ValueError, match='Expected both predictions and target to be 1 dimensional tensors.'):
metric(torch.randn(100, 2), torch.randn(100, 2))
1 change: 1 addition & 0 deletions torchmetrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from torchmetrics.collections import MetricCollection # noqa: F401 E402
from torchmetrics.metric import Metric # noqa: F401 E402
from torchmetrics.regression import ( # noqa: F401 E402
PearsonCorrcoef,
PSNR,
SSIM,
ExplainedVariance,
Expand Down
1 change: 1 addition & 0 deletions torchmetrics/functional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from torchmetrics.functional.regression.mean_relative_error import mean_relative_error # noqa: F401
from torchmetrics.functional.regression.mean_squared_error import mean_squared_error # noqa: F401
from torchmetrics.functional.regression.mean_squared_log_error import mean_squared_log_error # noqa: F401
from torchmetrics.functional.regression.pearson import pearson_corrcoef # noqa: F401
from torchmetrics.functional.regression.psnr import psnr # noqa: F401
from torchmetrics.functional.regression.r2score import r2score # noqa: F401
from torchmetrics.functional.regression.ssim import ssim # noqa: F401
Expand Down
2 changes: 1 addition & 1 deletion torchmetrics/functional/classification/f_beta.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from torchmetrics.utilities.enums import AverageMethod, MDMCAverageMethod


def _safe_divide(num: torch.Tensor, denom: torch.Tensor):
def _safe_divide(num: Tensor, denom: Tensor):
""" prevent zero division """
denom[denom == 0.] = 1
return num / denom
Expand Down
1 change: 1 addition & 0 deletions torchmetrics/functional/regression/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from torchmetrics.functional.regression.mean_absolute_error import mean_absolute_error # noqa: F401
from torchmetrics.functional.regression.mean_squared_error import mean_squared_error # noqa: F401
from torchmetrics.functional.regression.mean_squared_log_error import mean_squared_log_error # noqa: F401
from torchmetrics.functional.regression.pearson import pearson_corrcoef # noqa: F401
from torchmetrics.functional.regression.psnr import psnr # noqa: F401
from torchmetrics.functional.regression.r2score import r2score # noqa: F401
from torchmetrics.functional.regression.ssim import ssim # noqa: F401
106 changes: 106 additions & 0 deletions torchmetrics/functional/regression/pearson.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple

import torch
from torch import Tensor

from torchmetrics.utilities.checks import _check_same_shape


def _update_mean(old_mean: Tensor, old_nobs: Tensor, data: Tensor) -> Tensor:
""" Update a mean estimate given new data
Args:
old_mean: current mean estimate
old_nobs: number of observation until now
data: data used for updating the estimate
Returns:
new_mean: updated mean estimate
"""
data_size = data.shape[0]
return (old_mean * old_nobs + data.mean(dim=0) * data_size) / (old_nobs + data_size)


def _update_cov(old_cov: Tensor, old_mean: Tensor, new_mean: Tensor, data: Tensor):
""" Update a covariance estimate given new data
Args:
old_cov: current covariance estimate
old_mean: current mean estimate
new_mean: updated mean estimate
data: data used for updating the estimate
Returns:
new_mean: updated covariance estimate
"""
return old_cov + (data - new_mean).T @ (data - old_mean)


def _pearson_corrcoef_update(
SkafteNicki marked this conversation as resolved.
Show resolved Hide resolved
preds: Tensor,
target: Tensor,
old_mean: Optional[Tensor] = None,
old_cov: Optional[Tensor] = None,
old_nobs: Optional[Tensor] = None
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
""" updates current estimates of the mean, cov and n_obs with new data for calculating
pearsons correlation
"""
# Data checking
_check_same_shape(preds, target)
preds = preds.squeeze()
target = target.squeeze()
if preds.ndim > 1 or target.ndim > 1:
raise ValueError('Expected both predictions and target to be 1 dimensional tensors.')
data = torch.stack([preds, target], dim=1)

if old_mean is None:
old_mean = 0
if old_cov is None:
old_cov = 0
if old_nobs is None:
old_nobs = 0
Borda marked this conversation as resolved.
Show resolved Hide resolved

new_mean = _update_mean(old_mean, old_nobs, data)
new_cov = _update_cov(old_cov, old_mean, new_mean, data)
new_size = old_nobs + preds.numel()

return new_mean, new_cov, new_size


def _pearson_corrcoef_compute(c: Tensor, nobs: Tensor):
SkafteNicki marked this conversation as resolved.
Show resolved Hide resolved
""" computes the final pearson correlation based on covariance matrix and number of observatiosn """
c /= (nobs - 1)
x_var = c[0, 0]
y_var = c[1, 1]
cov = c[0, 1]
corrcoef = cov / (x_var * y_var).sqrt()
return torch.clamp(corrcoef, -1.0, 1.0)


def pearson_corrcoef(preds: Tensor, target: Tensor) -> Tensor:
"""
Computes pearson correlation coefficient.

Args:
preds: estimated scores
target: ground truth scores

Example:
>>> from torchmetrics.functional import pearson_corrcoef
>>> target = torch.tensor([3, -0.5, 2, 7])
>>> preds = torch.tensor([2.5, 0.0, 2, 8])
>>> pearson_corrcoef(preds, target)
tensor(0.9849)
"""
_, c, nobs = _pearson_corrcoef_update(preds, target)
return _pearson_corrcoef_compute(c, nobs)
1 change: 1 addition & 0 deletions torchmetrics/regression/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from torchmetrics.regression.mean_absolute_error import MeanAbsoluteError # noqa: F401
from torchmetrics.regression.mean_squared_error import MeanSquaredError # noqa: F401
from torchmetrics.regression.mean_squared_log_error import MeanSquaredLogError # noqa: F401
from torchmetrics.regression.pearson import PearsonCorrcoef # noqa: F401
from torchmetrics.regression.psnr import PSNR # noqa: F401
from torchmetrics.regression.r2score import R2Score # noqa: F401
from torchmetrics.regression.ssim import SSIM # noqa: F401
92 changes: 92 additions & 0 deletions torchmetrics/regression/pearson.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Optional

import torch
from torch import Tensor

from torchmetrics.functional.regression.pearson import _pearson_corrcoef_compute, _pearson_corrcoef_update
from torchmetrics.metric import Metric


class PearsonCorrcoef(Metric):
r"""
Computes `pearson correlation coefficient
<https://en.wikipedia.org/wiki/Pearson_correlation_coefficient>`_:

.. math:: \text{P_corr}(x,y) = \frac{cov(x,y)}{\sigma_x \times \sigma_y}

Where :math:`y` is a tensor of target values, and :math:`x` is a
tensor of predictions.

Forward accepts

- ``preds`` (float tensor): ``(N,)``
- ``target``(float tensor): ``(N,)``

Args:
compute_on_step:
Forward only calls ``update()`` and return None if this is set to False. default: True
dist_sync_on_step:
Synchronize metric state across processes at each ``forward()``
before returning the value at the step. default: False
process_group:
Specify the process group on which synchronization is called. default: None (which selects the entire world)
dist_sync_fn:
Callback that performs the allgather operation on the metric state. When ``None``, DDP
will be used to perform the allgather

Example:
>>> from torchmetrics import PearsonCorrcoef
>>> target = torch.tensor([3, -0.5, 2, 7])
>>> preds = torch.tensor([2.5, 0.0, 2, 8])
>>> pearson = PearsonCorrcoef()
>>> pearson(preds, target)
tensor(0.9849)

"""
def __init__(
self,
compute_on_step: bool = True,
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
dist_sync_fn: Optional[Callable] = None,
):
super().__init__(
compute_on_step=compute_on_step,
dist_sync_on_step=dist_sync_on_step,
process_group=process_group,
dist_sync_fn=dist_sync_fn,
)
self.add_state("cov", default=torch.zeros(2, 2), dist_reduce_fx="sum")
self.add_state("mean", default=torch.zeros(2), dist_reduce_fx="sum")
self.add_state("n_obs", default=torch.zeros(1), dist_reduce_fx="sum")

def update(self, preds: Tensor, target: Tensor):
"""
Update state with predictions and targets.

Args:
preds: Predictions from model
target: Ground truth values
"""
self.mean, self.cov, self.n_obs = _pearson_corrcoef_update(
preds, target, self.mean, self.cov, self.n_obs
)

def compute(self):
"""
Computes pearson correlation coefficient over state.
"""
return _pearson_corrcoef_compute(self.cov, self.n_obs)