-
Notifications
You must be signed in to change notification settings - Fork 402
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
11 changed files
with
283 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
# Copyright The PyTorch Lightning team. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
from collections import namedtuple | ||
|
||
import pytest | ||
import torch | ||
from scipy.stats import pearsonr | ||
|
||
from tests.helpers import seed_all | ||
from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester | ||
from torchmetrics.functional.regression.pearson import pearson_corrcoef | ||
from torchmetrics.regression.pearson import PearsonCorrcoef | ||
|
||
seed_all(42) | ||
|
||
|
||
Input = namedtuple('Input', ["preds", "target"]) | ||
|
||
_single_target_inputs1 = Input( | ||
preds=torch.rand(NUM_BATCHES, BATCH_SIZE), | ||
target=torch.rand(NUM_BATCHES, BATCH_SIZE), | ||
) | ||
|
||
_single_target_inputs2 = Input( | ||
preds=torch.randn(NUM_BATCHES, BATCH_SIZE), | ||
target=torch.randn(NUM_BATCHES, BATCH_SIZE), | ||
) | ||
|
||
|
||
def _sk_pearsonr(preds, target): | ||
sk_preds = preds.view(-1).numpy() | ||
sk_target = target.view(-1).numpy() | ||
return pearsonr(sk_target, sk_preds)[0] | ||
|
||
|
||
@pytest.mark.parametrize("preds, target", [ | ||
(_single_target_inputs1.preds, _single_target_inputs1.target), | ||
(_single_target_inputs2.preds, _single_target_inputs2.target), | ||
]) | ||
class TestPearsonCorrcoef(MetricTester): | ||
atol = 1e-2 | ||
|
||
@pytest.mark.parametrize("ddp", [True, False]) | ||
@pytest.mark.parametrize("dist_sync_on_step", [True, False]) | ||
def test_pearson_corrcoef(self, preds, target, ddp, dist_sync_on_step): | ||
self.run_class_metric_test( | ||
ddp=ddp, | ||
preds=preds, | ||
target=target, | ||
metric_class=PearsonCorrcoef, | ||
sk_metric=_sk_pearsonr, | ||
dist_sync_on_step=dist_sync_on_step, | ||
) | ||
|
||
def test_pearson_corrcoef_functional(self, preds, target): | ||
self.run_functional_metric_test( | ||
preds=preds, | ||
target=target, | ||
metric_functional=pearson_corrcoef, | ||
sk_metric=_sk_pearsonr | ||
) | ||
|
||
# Pearson half + cpu does not work due to missing support in torch.sqrt | ||
@pytest.mark.xfail(reason="PearsonCorrcoef metric does not support cpu + half precision") | ||
def test_pearson_corrcoef_half_cpu(self, preds, target): | ||
self.run_precision_test_cpu(preds, target, PearsonCorrcoef, pearson_corrcoef) | ||
|
||
@pytest.mark.skipif(not torch.cuda.is_available(), reason='test requires cuda') | ||
def test_pearson_corrcoef_half_gpu(self, preds, target): | ||
self.run_precision_test_gpu(preds, target, PearsonCorrcoef, pearson_corrcoef) | ||
|
||
|
||
def test_error_on_different_shape(): | ||
metric = PearsonCorrcoef() | ||
with pytest.raises(RuntimeError, match='Predictions and targets are expected to have the same shape'): | ||
metric(torch.randn(100, ), torch.randn(50, )) | ||
|
||
with pytest.raises(ValueError, match='Expected both predictions and target to be 1 dimensional tensors.'): | ||
metric(torch.randn(100, 2), torch.randn(100, 2)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
# Copyright The PyTorch Lightning team. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
from typing import Optional, Tuple | ||
|
||
import torch | ||
from torch import Tensor | ||
|
||
from torchmetrics.utilities.checks import _check_same_shape | ||
|
||
|
||
def _pearson_corrcoef_update( | ||
preds: Tensor, | ||
target: Tensor, | ||
old_mean: Optional[Tensor] = None, | ||
old_cov: Optional[Tensor] = None, | ||
old_nobs: Optional[Tensor] = None | ||
) -> Tuple[Tensor, Tensor, Tensor, Tensor]: | ||
""" updates current estimates of the mean, cov and n_obs with new data for calculating | ||
pearsons correlation | ||
""" | ||
# Data checking | ||
_check_same_shape(preds, target) | ||
preds = preds.squeeze() | ||
target = target.squeeze() | ||
if preds.ndim > 1 or target.ndim > 1: | ||
raise ValueError('Expected both predictions and target to be 1 dimensional tensors.') | ||
|
||
return preds, target | ||
|
||
|
||
def _pearson_corrcoef_compute(preds: Tensor, target: Tensor, eps: float = 1e-6) -> Tensor: | ||
""" computes the final pearson correlation based on covariance matrix and number of observatiosn """ | ||
preds_diff = preds - preds.mean() | ||
target_diff = target - target.mean() | ||
|
||
cov = (preds_diff * target_diff).mean() | ||
preds_std = torch.sqrt((preds_diff * preds_diff).mean()) | ||
target_std = torch.sqrt((target_diff * target_diff).mean()) | ||
|
||
denom = preds_std * target_std | ||
# prevent division by zero | ||
if denom == 0: | ||
denom += eps | ||
|
||
corrcoef = cov / denom | ||
return torch.clamp(corrcoef, -1.0, 1.0) | ||
|
||
|
||
def pearson_corrcoef(preds: Tensor, target: Tensor) -> Tensor: | ||
""" | ||
Computes pearson correlation coefficient. | ||
Args: | ||
preds: estimated scores | ||
target: ground truth scores | ||
Example: | ||
>>> from torchmetrics.functional import pearson_corrcoef | ||
>>> target = torch.tensor([3, -0.5, 2, 7]) | ||
>>> preds = torch.tensor([2.5, 0.0, 2, 8]) | ||
>>> pearson_corrcoef(preds, target) | ||
tensor(0.9849) | ||
""" | ||
preds, target = _pearson_corrcoef_update(preds, target) | ||
return _pearson_corrcoef_compute(preds, target) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
# Copyright The PyTorch Lightning team. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
from typing import Any, Optional | ||
|
||
import torch | ||
from torch import Tensor | ||
|
||
from torchmetrics.functional.regression.pearson import _pearson_corrcoef_compute, _pearson_corrcoef_update | ||
from torchmetrics.metric import Metric | ||
from torchmetrics.utilities import rank_zero_warn | ||
|
||
|
||
class PearsonCorrcoef(Metric): | ||
r""" | ||
Computes `pearson correlation coefficient | ||
<https://en.wikipedia.org/wiki/Pearson_correlation_coefficient>`_: | ||
.. math:: | ||
P_{corr}(x,y) = \frac{cov(x,y)}{\sigma_x \sigma_y} | ||
Where :math:`y` is a tensor of target values, and :math:`x` is a | ||
tensor of predictions. | ||
Forward accepts | ||
- ``preds`` (float tensor): ``(N,)`` | ||
- ``target``(float tensor): ``(N,)`` | ||
Args: | ||
compute_on_step: | ||
Forward only calls ``update()`` and return None if this is set to False. default: True | ||
dist_sync_on_step: | ||
Synchronize metric state across processes at each ``forward()`` | ||
before returning the value at the step. default: False | ||
process_group: | ||
Specify the process group on which synchronization is called. default: None (which selects the entire world) | ||
Example: | ||
>>> from torchmetrics import PearsonCorrcoef | ||
>>> target = torch.tensor([3, -0.5, 2, 7]) | ||
>>> preds = torch.tensor([2.5, 0.0, 2, 8]) | ||
>>> pearson = PearsonCorrcoef() | ||
>>> pearson(preds, target) | ||
tensor(0.9849) | ||
""" | ||
def __init__( | ||
self, | ||
compute_on_step: bool = True, | ||
dist_sync_on_step: bool = False, | ||
process_group: Optional[Any] = None, | ||
): | ||
super().__init__( | ||
compute_on_step=compute_on_step, | ||
dist_sync_on_step=dist_sync_on_step, | ||
process_group=process_group, | ||
) | ||
|
||
rank_zero_warn( | ||
'Metric `PearsonCorrcoef` will save all targets and predictions in buffer.' | ||
' For large datasets this may lead to large memory footprint.' | ||
) | ||
|
||
self.add_state("preds", default=[], dist_reduce_fx=None) | ||
self.add_state("target", default=[], dist_reduce_fx=None) | ||
|
||
def update(self, preds: Tensor, target: Tensor): | ||
""" | ||
Update state with predictions and targets. | ||
Args: | ||
preds: Predictions from model | ||
target: Ground truth values | ||
""" | ||
preds, target = _pearson_corrcoef_update(preds, target) | ||
self.preds.append(preds) | ||
self.target.append(target) | ||
|
||
def compute(self): | ||
""" | ||
Computes pearson correlation coefficient over state. | ||
""" | ||
preds = torch.cat(self.preds, dim=0) | ||
target = torch.cat(self.target, dim=0) | ||
return _pearson_corrcoef_compute(preds, target) |