diff --git a/CHANGELOG.md b/CHANGELOG.md
index 3a883322463..008f16ba52b 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -19,9 +19,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 - Added argument `normalize` to `LPIPS` metric ([#1216](https://github.com/Lightning-AI/metrics/pull/1216))
 
-
 - Added support for multiprocessing of batches in `PESQ` metric ([#1227](https://github.com/Lightning-AI/metrics/pull/1227))
 
+- Added support for multioutput in `PearsonCorrCoef` and `SpearmanCorrCoef` ([#1200](https://github.com/Lightning-AI/metrics/pull/1200))
+
 ### Changed
 
 - Classification refactor (
diff --git a/src/torchmetrics/functional/regression/pearson.py b/src/torchmetrics/functional/regression/pearson.py
index cc3b26edb3d..5273bbb8388 100644
--- a/src/torchmetrics/functional/regression/pearson.py
+++ b/src/torchmetrics/functional/regression/pearson.py
@@ -28,6 +28,7 @@ def _pearson_corrcoef_update(
     var_y: Tensor,
     corr_xy: Tensor,
     n_prior: Tensor,
+    num_outputs: int,
 ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
     """Updates and returns variables required to compute Pearson Correlation Coefficient.
 
@@ -43,18 +44,24 @@
     """
     # Data checking
    _check_same_shape(preds, target)
-    preds = preds.squeeze()
-    target = target.squeeze()
-    if preds.ndim > 1 or target.ndim > 1:
-        raise ValueError("Expected both predictions and target to be 1 dimensional tensors.")
-
-    n_obs = preds.numel()
-    mx_new = (n_prior * mean_x + preds.mean() * n_obs) / (n_prior + n_obs)
-    my_new = (n_prior * mean_y + target.mean() * n_obs) / (n_prior + n_obs)
+    if preds.ndim > 2 or target.ndim > 2:
+        raise ValueError(
+            f"Expected both predictions and target to be either 1- or 2-dimensional tensors,"
+            f" but got {preds.ndim} and {target.ndim}."
+        )
+    if (num_outputs == 1 and preds.ndim != 1) or (num_outputs > 1 and num_outputs != preds.shape[-1]):
+        raise ValueError(
+            f"Expected argument `num_outputs` to match the second dimension of input, but got {num_outputs}"
+            f" and {preds.shape[-1]}."
+        )
+
+    n_obs = preds.shape[0]
+    mx_new = (n_prior * mean_x + preds.mean(0) * n_obs) / (n_prior + n_obs)
+    my_new = (n_prior * mean_y + target.mean(0) * n_obs) / (n_prior + n_obs)
     n_prior += n_obs
-    var_x += ((preds - mx_new) * (preds - mean_x)).sum()
-    var_y += ((target - my_new) * (target - mean_y)).sum()
-    corr_xy += ((preds - mx_new) * (target - mean_y)).sum()
+    var_x += ((preds - mx_new) * (preds - mean_x)).sum(0)
+    var_y += ((target - my_new) * (target - mean_y)).sum(0)
+    corr_xy += ((preds - mx_new) * (target - mean_y)).sum(0)
     mean_x = mx_new
     mean_y = my_new
 
@@ -89,15 +96,25 @@ def pearson_corrcoef(preds: Tensor, target: Tensor) -> Tensor:
         preds: estimated scores
         target: ground truth scores
 
-    Example:
+    Example (single output regression):
         >>> from torchmetrics.functional import pearson_corrcoef
         >>> target = torch.tensor([3, -0.5, 2, 7])
         >>> preds = torch.tensor([2.5, 0.0, 2, 8])
         >>> pearson_corrcoef(preds, target)
         tensor(0.9849)
+
+    Example (multi output regression):
+        >>> from torchmetrics.functional import pearson_corrcoef
+        >>> target = torch.tensor([[3, -0.5], [2, 7]])
+        >>> preds = torch.tensor([[2.5, 0.0], [2, 8]])
+        >>> pearson_corrcoef(preds, target)
+        tensor([1., 1.])
     """
-    _temp = torch.zeros(1, dtype=preds.dtype, device=preds.device)
+    d = preds.shape[1] if preds.ndim == 2 else 1
+    _temp = torch.zeros(d, dtype=preds.dtype, device=preds.device)
     mean_x, mean_y, var_x = _temp.clone(), _temp.clone(), _temp.clone()
     var_y, corr_xy, nb = _temp.clone(), _temp.clone(), _temp.clone()
-    _, _, var_x, var_y, corr_xy, nb = _pearson_corrcoef_update(preds, target, mean_x, mean_y, var_x, var_y, corr_xy, nb)
+    _, _, var_x, var_y, corr_xy, nb = _pearson_corrcoef_update(
+        preds, target, mean_x, mean_y, var_x, var_y, corr_xy, nb, num_outputs=1 if preds.ndim == 1 else preds.shape[-1]
+    )
     return _pearson_corrcoef_compute(var_x, var_y, corr_xy, nb)
diff --git a/src/torchmetrics/functional/regression/spearman.py b/src/torchmetrics/functional/regression/spearman.py
index 6879cf4c132..4eb1206e5b1 100644
--- a/src/torchmetrics/functional/regression/spearman.py
+++ b/src/torchmetrics/functional/regression/spearman.py
@@ -52,7 +52,7 @@ def _rank_data(data: Tensor) -> Tensor:
     return rank
 
 
-def _spearman_corrcoef_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
+def _spearman_corrcoef_update(preds: Tensor, target: Tensor, num_outputs: int) -> Tuple[Tensor, Tensor]:
     """Updates and returns variables required to compute Spearman Correlation Coefficient.
 
     Checks for same shape and type of input tensors.
@@ -68,10 +68,17 @@ def _spearman_corrcoef_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, Te
             f" Got preds: {preds.dtype} and target: {target.dtype}."
         )
     _check_same_shape(preds, target)
-    preds = preds.squeeze()
-    target = target.squeeze()
-    if preds.ndim > 1 or target.ndim > 1:
-        raise ValueError("Expected both predictions and target to be 1 dimensional tensors.")
+    if preds.ndim > 2 or target.ndim > 2:
+        raise ValueError(
+            f"Expected both predictions and target to be either 1- or 2-dimensional tensors,"
+            f" but got {preds.ndim} and {target.ndim}."
+        )
+    if (num_outputs == 1 and preds.ndim != 1) or (num_outputs > 1 and num_outputs != preds.shape[-1]):
+        raise ValueError(
+            f"Expected argument `num_outputs` to match the second dimension of input, but got {num_outputs}"
+            f" and {preds.shape[-1]}."
+        )
+
 
     return preds, target
 
@@ -86,20 +93,23 @@ def _spearman_corrcoef_compute(preds: Tensor, target: Tensor, eps: float = 1e-6)
     Example:
         >>> target = torch.tensor([3, -0.5, 2, 7])
         >>> preds = torch.tensor([2.5, 0.0, 2, 8])
-        >>> preds, target = _spearman_corrcoef_update(preds, target)
+        >>> preds, target = _spearman_corrcoef_update(preds, target, num_outputs=1)
        >>> _spearman_corrcoef_compute(preds, target)
         tensor(1.0000)
     """
+    if preds.ndim == 1:
+        preds = _rank_data(preds)
+        target = _rank_data(target)
+    else:
+        preds = torch.stack([_rank_data(p) for p in preds.T]).T
+        target = torch.stack([_rank_data(t) for t in target.T]).T
 
-    preds = _rank_data(preds)
-    target = _rank_data(target)
-
-    preds_diff = preds - preds.mean()
-    target_diff = target - target.mean()
+    preds_diff = preds - preds.mean(0)
+    target_diff = target - target.mean(0)
 
-    cov = (preds_diff * target_diff).mean()
-    preds_std = torch.sqrt((preds_diff * preds_diff).mean())
-    target_std = torch.sqrt((target_diff * target_diff).mean())
+    cov = (preds_diff * target_diff).mean(0)
+    preds_std = torch.sqrt((preds_diff * preds_diff).mean(0))
+    target_std = torch.sqrt((target_diff * target_diff).mean(0))
 
     corrcoef = cov / (preds_std * target_std + eps)
     return torch.clamp(corrcoef, -1.0, 1.0)
@@ -119,13 +129,19 @@
         preds: estimated scores
         target: ground truth scores
 
-    Example:
+    Example (single output regression):
         >>> from torchmetrics.functional import spearman_corrcoef
         >>> target = torch.tensor([3, -0.5, 2, 7])
         >>> preds = torch.tensor([2.5, 0.0, 2, 8])
         >>> spearman_corrcoef(preds, target)
         tensor(1.0000)
 
+    Example (multi output regression):
+        >>> from torchmetrics.functional import spearman_corrcoef
+        >>> target = torch.tensor([[3, -0.5], [2, 7]])
+        >>> preds = torch.tensor([[2.5, 0.0], [2, 8]])
+        >>> spearman_corrcoef(preds, target)
+        tensor([1.0000, 1.0000])
     """
-    preds, target = _spearman_corrcoef_update(preds, target)
+    preds, target = _spearman_corrcoef_update(preds, target, num_outputs=1 if preds.ndim == 1 else preds.shape[-1])
     return _spearman_corrcoef_compute(preds, target)
diff --git a/src/torchmetrics/regression/pearson.py b/src/torchmetrics/regression/pearson.py
index 463b8ed540a..0bf188ab682 100644
--- a/src/torchmetrics/regression/pearson.py
+++ b/src/torchmetrics/regression/pearson.py
@@ -73,13 +73,14 @@ class PearsonCorrCoef(Metric):
 
     Forward accepts
 
-    - ``preds`` (float tensor): ``(N,)``
-    - ``target``(float tensor): ``(N,)``
+    - ``preds`` (float tensor): either single output tensor with shape ``(N,)`` or multioutput tensor of shape ``(N,d)``
+    - ``target``(float tensor): either single output tensor with shape ``(N,)`` or multioutput tensor of shape ``(N,d)``
 
     Args:
+        num_outputs: Number of outputs in multioutput setting
         kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
 
-    Example:
+    Example (single output regression):
         >>> from torchmetrics import PearsonCorrCoef
         >>> target = torch.tensor([3, -0.5, 2, 7])
         >>> preds = torch.tensor([2.5, 0.0, 2, 8])
@@ -87,6 +88,14 @@ class PearsonCorrCoef(Metric):
         >>> pearson(preds, target)
         tensor(0.9849)
 
+    Example (multi output regression):
+        >>> from torchmetrics import PearsonCorrCoef
+        >>> target = torch.tensor([[3, -0.5], [2, 7]])
+        >>> preds = torch.tensor([[2.5, 0.0], [2, 8]])
+        >>> pearson = PearsonCorrCoef(num_outputs=2)
+        >>> pearson(preds, target)
+        tensor([1., 1.])
+
     """
     is_differentiable = True
     higher_is_better = None  # both -1 and 1 are optimal
@@ -102,16 +111,20 @@
 
     def __init__(
         self,
+        num_outputs: int = 1,
         **kwargs: Any,
     ) -> None:
         super().__init__(**kwargs)
+        if not isinstance(num_outputs, int) or num_outputs < 1:
+            raise ValueError(f"Expected argument `num_outputs` to be an int larger than 0, but got {num_outputs}")
+        self.num_outputs = num_outputs
 
-        self.add_state("mean_x", default=torch.tensor(0.0), dist_reduce_fx=None)
-        self.add_state("mean_y", default=torch.tensor(0.0), dist_reduce_fx=None)
-        self.add_state("var_x", default=torch.tensor(0.0), dist_reduce_fx=None)
-        self.add_state("var_y", default=torch.tensor(0.0), dist_reduce_fx=None)
-        self.add_state("corr_xy", default=torch.tensor(0.0), dist_reduce_fx=None)
-        self.add_state("n_total", default=torch.tensor(0.0), dist_reduce_fx=None)
+        self.add_state("mean_x", default=torch.zeros(self.num_outputs), dist_reduce_fx=None)
+        self.add_state("mean_y", default=torch.zeros(self.num_outputs), dist_reduce_fx=None)
+        self.add_state("var_x", default=torch.zeros(self.num_outputs), dist_reduce_fx=None)
+        self.add_state("var_y", default=torch.zeros(self.num_outputs), dist_reduce_fx=None)
+        self.add_state("corr_xy", default=torch.zeros(self.num_outputs), dist_reduce_fx=None)
+        self.add_state("n_total", default=torch.zeros(self.num_outputs), dist_reduce_fx=None)
 
     def update(self, preds: Tensor, target: Tensor) -> None:  # type: ignore
         """Update state with predictions and targets.
@@ -121,12 +134,21 @@ def update(self, preds: Tensor, target: Tensor) -> None:  # type: ignore
             target: Ground truth values
         """
         self.mean_x, self.mean_y, self.var_x, self.var_y, self.corr_xy, self.n_total = _pearson_corrcoef_update(
-            preds, target, self.mean_x, self.mean_y, self.var_x, self.var_y, self.corr_xy, self.n_total
+            preds,
+            target,
+            self.mean_x,
+            self.mean_y,
+            self.var_x,
+            self.var_y,
+            self.corr_xy,
+            self.n_total,
+            self.num_outputs,
         )
 
     def compute(self) -> Tensor:
         """Computes pearson correlation coefficient over state."""
-        if self.mean_x.numel() > 1:  # multiple devices, need further reduction
+        if (self.num_outputs == 1 and self.mean_x.numel() > 1) or (self.num_outputs > 1 and self.mean_x.ndim > 1):
+            # multiple devices, need further reduction
             var_x, var_y, corr_xy, n_total = _final_aggregation(
                 self.mean_x, self.mean_y, self.var_x, self.var_y, self.corr_xy, self.n_total
             )
diff --git a/src/torchmetrics/regression/spearman.py b/src/torchmetrics/regression/spearman.py
index 4351d0444dc..53b54b2dc03 100644
--- a/src/torchmetrics/regression/spearman.py
+++ b/src/torchmetrics/regression/spearman.py
@@ -32,10 +32,16 @@ class SpearmanCorrCoef(Metric):
     Spearmans correlations coefficient corresponds to the standard pearsons correlation coefficient calculated
     on the rank variables.
 
+    Forward accepts
+
+    - ``preds`` (float tensor): ``(N,d)``
+    - ``target``(float tensor): ``(N,d)``
+
     Args:
+        num_outputs: Number of outputs in multioutput setting
         kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
 
-    Example:
+    Example (single output regression):
         >>> from torchmetrics import SpearmanCorrCoef
         >>> target = torch.tensor([3, -0.5, 2, 7])
         >>> preds = torch.tensor([2.5, 0.0, 2, 8])
@@ -43,6 +49,14 @@ class SpearmanCorrCoef(Metric):
         >>> spearman(preds, target)
         tensor(1.0000)
 
+    Example (multi output regression):
+        >>> from torchmetrics import SpearmanCorrCoef
+        >>> target = torch.tensor([[3, -0.5], [2, 7]])
+        >>> preds = torch.tensor([[2.5, 0.0], [2, 8]])
+        >>> spearman = SpearmanCorrCoef(num_outputs=2)
+        >>> spearman(preds, target)
+        tensor([1.0000, 1.0000])
+
     """
     is_differentiable: bool = False
     higher_is_better: bool = True
@@ -52,6 +66,7 @@
 
     def __init__(
         self,
+        num_outputs: int = 1,
         **kwargs: Any,
     ) -> None:
         super().__init__(**kwargs)
@@ -59,6 +74,9 @@ def __init__(
             "Metric `SpearmanCorrcoef` will save all targets and predictions in the buffer."
             " For large datasets, this may lead to large memory footprint."
         )
+        if not isinstance(num_outputs, int) or num_outputs < 1:
+            raise ValueError(f"Expected argument `num_outputs` to be an int larger than 0, but got {num_outputs}")
+        self.num_outputs = num_outputs
 
         self.add_state("preds", default=[], dist_reduce_fx="cat")
         self.add_state("target", default=[], dist_reduce_fx="cat")
@@ -70,7 +88,7 @@ def update(self, preds: Tensor, target: Tensor) -> None:  # type: ignore
             preds: Predictions from model
             target: Ground truth values
         """
-        preds, target = _spearman_corrcoef_update(preds, target)
+        preds, target = _spearman_corrcoef_update(preds, target, num_outputs=self.num_outputs)
         self.preds.append(preds)
         self.target.append(target)
diff --git a/tests/unittests/regression/test_pearson.py b/tests/unittests/regression/test_pearson.py
index 400251585a6..7364f653740 100644
--- a/tests/unittests/regression/test_pearson.py
+++ b/tests/unittests/regression/test_pearson.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from collections import namedtuple
+from functools import partial
 
 import pytest
 import torch
@@ -20,7 +21,7 @@
 from torchmetrics.functional.regression.pearson import pearson_corrcoef
 from torchmetrics.regression.pearson import PearsonCorrCoef
 from unittests.helpers import seed_all
-from unittests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester
+from unittests.helpers.testers import BATCH_SIZE, EXTRA_DIM, NUM_BATCHES, MetricTester
 
 seed_all(42)
 
@@ -37,10 +38,22 @@
 )
 
 
+_multi_target_inputs1 = Input(
+    preds=torch.rand(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM),
+    target=torch.rand(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM),
+)
+
+_multi_target_inputs2 = Input(
+    preds=torch.randn(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM),
+    target=torch.randn(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM),
+)
+
+
 def _sk_pearsonr(preds, target):
-    sk_preds = preds.view(-1).numpy()
-    sk_target = target.view(-1).numpy()
-    return pearsonr(sk_target, sk_preds)[0]
+    if preds.ndim == 2:
+        return [pearsonr(t.numpy(), p.numpy())[0] for t, p in zip(target.T, preds.T)]
+    else:
+        return pearsonr(target.numpy(), preds.numpy())[0]
 
 
 @pytest.mark.parametrize(
@@ -48,12 +61,17 @@
     [
         (_single_target_inputs1.preds, _single_target_inputs1.target),
         (_single_target_inputs2.preds, _single_target_inputs2.target),
+        (_multi_target_inputs1.preds, _multi_target_inputs1.target),
+        (_multi_target_inputs2.preds, _multi_target_inputs2.target),
     ],
 )
 class TestPearsonCorrcoef(MetricTester):
+    atol = 1e-3
+
     @pytest.mark.parametrize("compute_on_cpu", [True, False])
     @pytest.mark.parametrize("ddp", [True, False])
     def test_pearson_corrcoef(self, preds, target, compute_on_cpu, ddp):
+        num_outputs = EXTRA_DIM if preds.ndim == 3 else 1
         self.run_class_metric_test(
             ddp=ddp,
             preds=preds,
@@ -61,7 +79,7 @@ class TestPearsonCorrcoef(MetricTester):
             metric_class=PearsonCorrCoef,
             sk_metric=_sk_pearsonr,
             dist_sync_on_step=False,
-            metric_args={"compute_on_cpu": compute_on_cpu},
+            metric_args={"num_outputs": num_outputs, "compute_on_cpu": compute_on_cpu},
         )
 
     def test_pearson_corrcoef_functional(self, preds, target):
@@ -70,24 +88,35 @@
         )
 
     def test_pearson_corrcoef_differentiability(self, preds, target):
+        num_outputs = EXTRA_DIM if preds.ndim == 3 else 1
         self.run_differentiability_test(
-            preds=preds, target=target, metric_module=PearsonCorrCoef, metric_functional=pearson_corrcoef
+            preds=preds,
+            target=target,
+            metric_module=partial(PearsonCorrCoef, num_outputs=num_outputs),
+            metric_functional=pearson_corrcoef,
         )
 
     # Pearson half + cpu does not work due to missing support in torch.sqrt
     @pytest.mark.xfail(reason="PearsonCorrCoef metric does not support cpu + half precision")
     def test_pearson_corrcoef_half_cpu(self, preds, target):
-        self.run_precision_test_cpu(preds, target, PearsonCorrCoef, pearson_corrcoef)
+        num_outputs = EXTRA_DIM if preds.ndim == 3 else 1
+        self.run_precision_test_cpu(preds, target, partial(PearsonCorrCoef, num_outputs=num_outputs), pearson_corrcoef)
 
     @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda")
     def test_pearson_corrcoef_half_gpu(self, preds, target):
-        self.run_precision_test_gpu(preds, target, PearsonCorrCoef, pearson_corrcoef)
+        num_outputs = EXTRA_DIM if preds.ndim == 3 else 1
+        self.run_precision_test_gpu(preds, target, partial(PearsonCorrCoef, num_outputs=num_outputs), pearson_corrcoef)
 
 
 def test_error_on_different_shape():
-    metric = PearsonCorrCoef()
+    metric = PearsonCorrCoef(num_outputs=1)
     with pytest.raises(RuntimeError, match="Predictions and targets are expected to have the same shape"):
         metric(torch.randn(100), torch.randn(50))
 
-    with pytest.raises(ValueError, match="Expected both predictions and target to be 1 dimensional tensors."):
-        metric(torch.randn(100, 2), torch.randn(100, 2))
+    metric = PearsonCorrCoef(num_outputs=5)
+    with pytest.raises(ValueError, match="Expected both predictions and target to be either 1- or 2-.*"):
+        metric(torch.randn(100, 2, 5), torch.randn(100, 2, 5))
+
+    metric = PearsonCorrCoef(num_outputs=2)
+    with pytest.raises(ValueError, match="Expected argument `num_outputs` to match the second dimension of input.*"):
+        metric(torch.randn(100, 5), torch.randn(100, 5))
diff --git a/tests/unittests/regression/test_spearman.py b/tests/unittests/regression/test_spearman.py
index 7acc8c4b205..221ea8e9c5c 100644
--- a/tests/unittests/regression/test_spearman.py
+++ b/tests/unittests/regression/test_spearman.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from collections import namedtuple
+from functools import partial
 
 import pytest
 import torch
@@ -20,7 +21,7 @@
 from torchmetrics.functional.regression.spearman import _rank_data, spearman_corrcoef
 from torchmetrics.regression.spearman import SpearmanCorrCoef
 from unittests.helpers import seed_all
-from unittests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester
+from unittests.helpers.testers import BATCH_SIZE, EXTRA_DIM, NUM_BATCHES, MetricTester
 
 seed_all(42)
 
@@ -36,6 +37,16 @@
     target=torch.randn(NUM_BATCHES, BATCH_SIZE),
 )
 
+_multi_target_inputs1 = Input(
+    preds=torch.rand(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM),
+    target=torch.rand(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM),
+)
+
+_multi_target_inputs2 = Input(
+    preds=torch.randn(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM),
+    target=torch.randn(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM),
+)
+
 _specific_input = Input(
     preds=torch.stack([torch.tensor([1.0, 0.0, 4.0, 1.0, 0.0, 3.0, 0.0]) for _ in range(NUM_BATCHES)]),
     target=torch.stack([torch.tensor([4.0, 0.0, 3.0, 3.0, 3.0, 1.0, 1.0]) for _ in range(NUM_BATCHES)]),
@@ -60,9 +71,10 @@ def test_ranking(preds, target):
 
 
 def _sk_metric(preds, target):
-    sk_preds = preds.view(-1).numpy()
-    sk_target = target.view(-1).numpy()
-    return spearmanr(sk_target, sk_preds)[0]
+    if preds.ndim == 2:
+        return [spearmanr(t.numpy(), p.numpy())[0] for t, p in zip(target.T, preds.T)]
+    else:
+        return spearmanr(target.numpy(), preds.numpy())[0]
 
 
 @pytest.mark.parametrize(
@@ -70,6 +82,8 @@ def _sk_metric(preds, target):
     [
         (_single_target_inputs1.preds, _single_target_inputs1.target),
         (_single_target_inputs2.preds, _single_target_inputs2.target),
+        (_multi_target_inputs1.preds, _multi_target_inputs1.target),
+        (_multi_target_inputs2.preds, _multi_target_inputs2.target),
         (_specific_input.preds, _specific_input.target),
     ],
 )
@@ -79,6 +93,7 @@ class TestSpearmanCorrCoef(MetricTester):
     @pytest.mark.parametrize("ddp", [True, False])
     @pytest.mark.parametrize("dist_sync_on_step", [True, False])
     def test_spearman_corrcoef(self, preds, target, ddp, dist_sync_on_step):
+        num_outputs = EXTRA_DIM if preds.ndim == 3 else 1
         self.run_class_metric_test(
             ddp,
             preds,
@@ -86,30 +101,46 @@ def test_spearman_corrcoef(self, preds, target, ddp, dist_sync_on_step):
             SpearmanCorrCoef,
             _sk_metric,
             dist_sync_on_step,
+            metric_args={"num_outputs": num_outputs},
         )
 
     def test_spearman_corrcoef_functional(self, preds, target):
         self.run_functional_metric_test(preds, target, spearman_corrcoef, _sk_metric)
 
     def test_spearman_corrcoef_differentiability(self, preds, target):
+        num_outputs = EXTRA_DIM if preds.ndim == 3 else 1
         self.run_differentiability_test(
-            preds=preds, target=target, metric_module=SpearmanCorrCoef, metric_functional=spearman_corrcoef
+            preds=preds,
+            target=target,
+            metric_module=partial(SpearmanCorrCoef, num_outputs=num_outputs),
+            metric_functional=spearman_corrcoef,
         )
 
     # Spearman half + cpu does not work due to missing support in torch.arange
     @pytest.mark.xfail(reason="Spearman metric does not support cpu + half precision")
     def test_spearman_corrcoef_half_cpu(self, preds, target):
-        self.run_precision_test_cpu(preds, target, SpearmanCorrCoef, spearman_corrcoef)
+        num_outputs = EXTRA_DIM if preds.ndim == 3 else 1
+        self.run_precision_test_cpu(
+            preds, target, partial(SpearmanCorrCoef, num_outputs=num_outputs), spearman_corrcoef
+        )
 
     @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda")
     def test_spearman_corrcoef_half_gpu(self, preds, target):
-        self.run_precision_test_gpu(preds, target, SpearmanCorrCoef, spearman_corrcoef)
+        num_outputs = EXTRA_DIM if preds.ndim == 3 else 1
+        self.run_precision_test_gpu(
+            preds, target, partial(SpearmanCorrCoef, num_outputs=num_outputs), spearman_corrcoef
+        )
 
 
 def test_error_on_different_shape():
-    metric = SpearmanCorrCoef()
+    metric = SpearmanCorrCoef(num_outputs=1)
     with pytest.raises(RuntimeError, match="Predictions and targets are expected to have the same shape"):
         metric(torch.randn(100), torch.randn(50))
 
-    with pytest.raises(ValueError, match="Expected both predictions and target to be 1 dimensional tensors."):
-        metric(torch.randn(100, 2), torch.randn(100, 2))
+    metric = SpearmanCorrCoef(num_outputs=5)
+    with pytest.raises(ValueError, match="Expected both predictions and target to be either 1- or 2-dimensional.*"):
+        metric(torch.randn(100, 2, 5), torch.randn(100, 2, 5))
+
+    metric = SpearmanCorrCoef(num_outputs=2)
+    with pytest.raises(ValueError, match="Expected argument `num_outputs` to match the second dimension of input.*"):
+        metric(torch.randn(100, 5), torch.randn(100, 5))
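Usage sketch (illustrative only, not part of the patch): the snippet below exercises the multioutput API introduced above, reusing the tensor values and expected results from the doctests in the patch; everything else (variable names, the combined functional/module comparison) is an assumption for demonstration.

    import torch
    from torchmetrics import PearsonCorrCoef, SpearmanCorrCoef
    from torchmetrics.functional import pearson_corrcoef, spearman_corrcoef

    # two samples (N=2) with two outputs each (d=2), as in the doctests above
    target = torch.tensor([[3.0, -0.5], [2.0, 7.0]])
    preds = torch.tensor([[2.5, 0.0], [2.0, 8.0]])

    # functional form: the number of outputs is inferred from the trailing dimension
    pearson_corrcoef(preds, target)   # tensor([1., 1.])
    spearman_corrcoef(preds, target)  # tensor([1.0000, 1.0000])

    # module form: state is kept per output, so num_outputs must be declared up front
    pearson = PearsonCorrCoef(num_outputs=2)
    spearman = SpearmanCorrCoef(num_outputs=2)
    pearson(preds, target)   # tensor([1., 1.])
    spearman(preds, target)  # tensor([1.0000, 1.0000])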