forked from Lightning-AI/torchmetrics
Commit
Spearman correlation coefficient (Lightning-AI#158)
* ranking
* init files
* update
* nearly working
* fix tests
* pep8
* add docs
* fix doctests
* fix docs
* pep8
* isort
* ghlog
* Apply suggestions from code review

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
1 parent: 3f3b1cb. Commit: 06c13fd. 10 changed files with 340 additions and 12 deletions.
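Before the diff, a minimal usage sketch of the metric this commit adds. The functional call mirrors the doctest in the new file below; the module-based SpearmanCorrcoef is only imported and called in the tests, so its forward behaviour here is assumed to follow the usual torchmetrics update/compute pattern rather than confirmed by this diff:

>>> import torch
>>> from torchmetrics.functional import spearman_corrcoef
>>> target = torch.tensor([3., -0.5, 2., 7.])
>>> preds = torch.tensor([2.5, 0.0, 2., 8.])
>>> spearman_corrcoef(preds, target)
tensor(1.0000)
>>> from torchmetrics.regression.spearman import SpearmanCorrcoef
>>> metric = SpearmanCorrcoef()
>>> metric(preds, target)  # assumed: forward returns the metric for the current batch
tensor(1.0000)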
New file (101 lines added): tests for the Spearman correlation coefficient metric.
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import namedtuple

import pytest
import torch
from scipy.stats import rankdata, spearmanr

from tests.helpers import seed_all
from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester
from torchmetrics.functional.regression.spearman import _rank_data, spearman_corrcoef
from torchmetrics.regression.spearman import SpearmanCorrcoef

seed_all(42)

Input = namedtuple('Input', ["preds", "target"])

_single_target_inputs1 = Input(
    preds=torch.rand(NUM_BATCHES, BATCH_SIZE),
    target=torch.rand(NUM_BATCHES, BATCH_SIZE),
)

_single_target_inputs2 = Input(
    preds=torch.randn(NUM_BATCHES, BATCH_SIZE),
    target=torch.randn(NUM_BATCHES, BATCH_SIZE),
)


@pytest.mark.parametrize(
    "preds, target", [
        (_single_target_inputs1.preds, _single_target_inputs1.target),
        (_single_target_inputs2.preds, _single_target_inputs2.target),
    ]
)
def test_ranking(preds, target):
    """ test that ranking function works as expected """
    for p, t in zip(preds, target):
        scipy_ranking = [rankdata(p.numpy()), rankdata(t.numpy())]
        tm_ranking = [_rank_data(p), _rank_data(t)]
        assert (torch.tensor(scipy_ranking[0]) == tm_ranking[0]).all()
        assert (torch.tensor(scipy_ranking[1]) == tm_ranking[1]).all()


def _sk_metric(preds, target):
    sk_preds = preds.view(-1).numpy()
    sk_target = target.view(-1).numpy()
    return spearmanr(sk_target, sk_preds)[0]


@pytest.mark.parametrize(
    "preds, target", [
        (_single_target_inputs1.preds, _single_target_inputs1.target),
        (_single_target_inputs2.preds, _single_target_inputs2.target),
    ]
)
class TestSpearmanCorrcoef(MetricTester):
    atol = 1e-2

    @pytest.mark.parametrize("ddp", [True, False])
    @pytest.mark.parametrize("dist_sync_on_step", [True, False])
    def test_spearman_corrcoef(self, preds, target, ddp, dist_sync_on_step):
        self.run_class_metric_test(
            ddp,
            preds,
            target,
            SpearmanCorrcoef,
            _sk_metric,
            dist_sync_on_step,
        )

    def test_spearman_corrcoef_functional(self, preds, target):
        self.run_functional_metric_test(preds, target, spearman_corrcoef, _sk_metric)

    # Spearman half + cpu does not work due to missing support in torch.arange
    @pytest.mark.xfail(reason="Spearman metric does not support cpu + half precision")
    def test_spearman_corrcoef_half_cpu(self, preds, target):
        self.run_precision_test_cpu(preds, target, SpearmanCorrcoef, spearman_corrcoef)

    @pytest.mark.skipif(not torch.cuda.is_available(), reason='test requires cuda')
    def test_spearman_corrcoef_half_gpu(self, preds, target):
        self.run_precision_test_gpu(preds, target, SpearmanCorrcoef, spearman_corrcoef)


def test_error_on_different_shape():
    metric = SpearmanCorrcoef()
    with pytest.raises(RuntimeError, match='Predictions and targets are expected to have the same shape'):
        metric(torch.randn(100, ), torch.randn(50, ))

    with pytest.raises(ValueError, match='Expected both predictions and target to be 1 dimensional tensors.'):
        metric(torch.randn(100, 2), torch.randn(100, 2))
New file (111 lines added): functional implementation of the Spearman correlation coefficient.
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple

import torch
from torch import Tensor

from torchmetrics.utilities.checks import _check_same_shape


def _find_repeats(data: Tensor):
    """ find and return values which have repeats, i.e. the same value appears more than once in the tensor """
    temp = data.detach().clone()
    temp = temp.sort()[0]

    # mask marking positions where the sorted values change, i.e. the start of each run of equal values
    change = torch.cat([torch.tensor([True], device=temp.device), temp[1:] != temp[:-1]])
    unique = temp[change]
    change_idx = torch.cat([
        torch.nonzero(change),
        torch.tensor([[temp.numel()]], device=temp.device)
    ]).flatten()
    # run lengths of each unique value; values occurring at least twice are the repeats
    freq = change_idx[1:] - change_idx[:-1]
    atleast2 = freq > 1
    return unique[atleast2]


def _rank_data(data: Tensor):
    """ Calculate the rank for each element of a tensor. The rank refers to the indices of an element in the
    corresponding sorted tensor (starting from 1). Duplicates of the same value will be assigned the mean of
    their rank.

    Adapted from:
        https://github.com/scipy/scipy/blob/v1.6.2/scipy/stats/stats.py#L4140-L4303
    """
    n = data.numel()
    rank = torch.empty_like(data)
    idx = data.argsort()
    rank[idx[:n]] = torch.arange(1, n + 1, dtype=data.dtype, device=data.device)

    repeats = _find_repeats(data)
    for r in repeats:
        condition = data == r  # all elements tied at value r share the mean of their ranks
        rank[condition] = rank[condition].mean()
    return rank


def _spearman_corrcoef_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
    if preds.dtype != target.dtype:
        raise TypeError(
            "Expected `preds` and `target` to have the same data type."
            f" Got pred: {preds.dtype} and target: {target.dtype}."
        )
    _check_same_shape(preds, target)

    if preds.ndim > 1 or target.ndim > 1:
        raise ValueError('Expected both predictions and target to be 1 dimensional tensors.')

    return preds, target


def _spearman_corrcoef_compute(preds: Tensor, target: Tensor, eps: float = 1e-6) -> Tensor:
    preds = _rank_data(preds)
    target = _rank_data(target)

    # Pearson correlation computed on the rank-transformed data
    preds_diff = preds - preds.mean()
    target_diff = target - target.mean()

    cov = (preds_diff * target_diff).mean()
    preds_std = torch.sqrt((preds_diff * preds_diff).mean())
    target_std = torch.sqrt((target_diff * target_diff).mean())

    corrcoef = cov / (preds_std * target_std + eps)
    return torch.clamp(corrcoef, -1.0, 1.0)


def spearman_corrcoef(preds: Tensor, target: Tensor) -> Tensor:
    r"""
    Computes `Spearman's rank correlation coefficient
    <https://en.wikipedia.org/wiki/Spearman%27s_rank_correlation_coefficient>`_:

    .. math::
        r_s = \frac{\operatorname{cov}(rg_x, rg_y)}{\sigma_{rg_x} \sigma_{rg_y}}

    where :math:`rg_x` and :math:`rg_y` are the ranks associated with the variables :math:`x` and :math:`y`.
    Spearman's correlation coefficient corresponds to the standard Pearson correlation coefficient calculated
    on the rank variables.

    Args:
        preds: estimated scores
        target: ground truth scores

    Example:
        >>> from torchmetrics.functional import spearman_corrcoef
        >>> target = torch.tensor([3, -0.5, 2, 7])
        >>> preds = torch.tensor([2.5, 0.0, 2, 8])
        >>> spearman_corrcoef(preds, target)
        tensor(1.0000)
    """
    preds, target = _spearman_corrcoef_update(preds, target)
    return _spearman_corrcoef_compute(preds, target)
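To illustrate the tie handling in `_find_repeats` and `_rank_data` above, a small worked example. This is a sketch: the output is written out from the average-rank definition in the docstring and matches what `scipy.stats.rankdata` returns for the same input:

>>> import torch
>>> from torchmetrics.functional.regression.spearman import _rank_data
>>> _rank_data(torch.tensor([1., 2., 2., 3.]))  # the tied 2s share the mean of ranks 2 and 3
tensor([1.0000, 2.5000, 2.5000, 4.0000])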