Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Weighted Pearson, Spearman, and R2 score #1759

Draft
wants to merge 20 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 19 commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

-
- Fixed bug in `PearsonCorrCoef` is updated on single samples at a time ([#2019](https://github.com/Lightning-AI/torchmetrics/pull/2019)


## [1.1.2] - 2023-09-11
Expand Down
95 changes: 72 additions & 23 deletions src/torchmetrics/functional/regression/pearson.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Tuple
from typing import Optional, Tuple

import torch
from torch import Tensor

from torchmetrics.functional.regression.utils import _check_data_shape_to_num_outputs
from torchmetrics.functional.regression.utils import _check_data_shape_to_weights, _check_data_shape_to_num_outputs
from torchmetrics.utilities import rank_zero_warn
from torchmetrics.utilities.checks import _check_same_shape

Expand All @@ -32,6 +32,7 @@ def _pearson_corrcoef_update(
corr_xy: Tensor,
n_prior: Tensor,
num_outputs: int,
weights: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
"""Update and returns variables required to compute Pearson Correlation Coefficient.

Expand All @@ -47,34 +48,56 @@ def _pearson_corrcoef_update(
corr_xy: current covariance estimate between x and y tensor
n_prior: current number of observed observations
num_outputs: Number of outputs in multioutput setting

weights: weights associated with scores
"""
# Data checking
_check_same_shape(preds, target)
_check_data_shape_to_num_outputs(preds, target, num_outputs)
n_obs = preds.shape[0]
if weights is not None:
_check_data_shape_to_weights(preds, weights)

n_obs = preds.shape[0] if weights is None else weights.sum()
cond = n_prior.mean() > 0 or n_obs == 1

if cond:
mx_new = (n_prior * mean_x + preds.sum(0)) / (n_prior + n_obs)
my_new = (n_prior * mean_y + target.sum(0)) / (n_prior + n_obs)
if weights is None:
mx_new = (n_prior * mean_x + preds.sum(0)) / (n_prior + n_obs)
my_new = (n_prior * mean_y + target.sum(0)) / (n_prior + n_obs)
else:
mx_new = (n_prior * mean_x + torch.matmul(weights, preds)) / (n_prior + n_obs)
my_new = (n_prior * mean_y + torch.matmul(weights, target)) / (n_prior + n_obs)
else:
mx_new = preds.mean(0)
my_new = target.mean(0)
if weights is None:
mx_new = preds.mean(0)
my_new = target.mean(0)
else:
mx_new = torch.matmul(weights, preds) / weights.sum()
my_new = torch.matmul(weights, target) / weights.sum()
Comment on lines 63 to +76
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we rather have have if weights is NOne to make it full of ones to have no effect but the code simpler


n_prior += n_obs

# Calculate variances
if cond:
var_x += ((preds - mx_new) * (preds - mean_x)).sum(0)
var_y += ((target - my_new) * (target - mean_y)).sum(0)
if weights is None:
var_x += ((preds - mx_new) * (preds - mean_x)).sum(0)
var_y += ((target - my_new) * (target - mean_y)).sum(0)
else:
var_x += torch.matmul(weights, (preds - mx_new) * (preds - mean_x))
var_y += torch.matmul(weights, (preds - my_new) * (preds - mean_y))
else:
var_x += preds.var(0) * (n_obs - 1)
var_y += target.var(0) * (n_obs - 1)
corr_xy += ((preds - mx_new) * (target - mean_y)).sum(0)
mean_x = mx_new
mean_y = my_new
if weights is None:
var_x += preds.var(0) * (n_obs - 1)
var_y += target.var(0) * (n_obs - 1)
else:
var_x += torch.matmul(weights, (preds - mx_new) ** 2)
var_y += torch.matmul(weights, (target - my_new) ** 2)

if weights is None:
corr_xy += ((preds - mx_new) * (target - my_new)).sum(0)
else:
corr_xy += torch.matmul(weights, (preds - mx_new) * (target - my_new))

return mean_x, mean_y, var_x, var_y, corr_xy, n_prior
return mx_new, my_new, var_x, var_y, corr_xy, n_prior


def _pearson_corrcoef_compute(
Expand All @@ -92,9 +115,6 @@ def _pearson_corrcoef_compute(
nb: number of observations

"""
var_x /= nb - 1
var_y /= nb - 1
corr_xy /= nb - 1
# if var_x, var_y is float16 and on cpu, make it bfloat16 as sqrt is not supported for float16
# on cpu, remove this after https://github.com/pytorch/pytorch/issues/54774 is fixed
if var_x.dtype == torch.float16 and var_x.device == torch.device("cpu"):
Expand All @@ -114,12 +134,16 @@ def _pearson_corrcoef_compute(
return torch.clamp(corrcoef, -1.0, 1.0)


def pearson_corrcoef(preds: Tensor, target: Tensor) -> Tensor:
def pearson_corrcoef(preds: Tensor, target: Tensor, weights: Optional[Tensor] = None) -> Tensor:
"""Compute pearson correlation coefficient.

Args:
preds: estimated scores
target: ground truth scores
preds: torch.Tensor of shape (n_samples,) or (n_samples, n_outputs)
Estimated scores
target: torch.Tensor of shape (n_samples,) or (n_samples, n_outputs)
Ground truth scores
weights: torch.Tensor of shape (n_samples,), default=None
Sample weights

Example (single output regression):
>>> from torchmetrics.functional.regression import pearson_corrcoef
Expand All @@ -128,19 +152,44 @@ def pearson_corrcoef(preds: Tensor, target: Tensor) -> Tensor:
>>> pearson_corrcoef(preds, target)
tensor(0.9849)

Example (weighted single output regression):
>>> from torchmetrics.functional.regression import pearson_corrcoef
>>> target = torch.tensor([3, -0.5, 2, 7])
>>> preds = torch.tensor([2.5, 0.0, 2, 8])
>>> weights = torch.tensor([2.5, 0.0, 2, 8])
>>> pearson_corrcoef(preds, target, weights)
tensor(0.9849)

Example (multi output regression):
>>> from torchmetrics.functional.regression import pearson_corrcoef
>>> target = torch.tensor([[3, -0.5], [2, 7]])
>>> preds = torch.tensor([[2.5, 0.0], [2, 8]])
>>> pearson_corrcoef(preds, target)
tensor([1., 1.])

Example (weighted multiple output regression):
>>> from torchmetrics.functional.regression import pearson_corrcoef
>>> target = torch.tensor([3, -0.5, 2, 7])
>>> preds = torch.tensor([2.5, 0.0, 2, 8])
>>> weights = torch.tensor([2.5, 0.0, 2, 8])
>>> pearson_corrcoef(preds, target, weights)
tensor(0.9849)

"""
d = preds.shape[1] if preds.ndim == 2 else 1
_temp = torch.zeros(d, dtype=preds.dtype, device=preds.device)
mean_x, mean_y, var_x = _temp.clone(), _temp.clone(), _temp.clone()
var_y, corr_xy, nb = _temp.clone(), _temp.clone(), _temp.clone()
_, _, var_x, var_y, corr_xy, nb = _pearson_corrcoef_update(
preds, target, mean_x, mean_y, var_x, var_y, corr_xy, nb, num_outputs=1 if preds.ndim == 1 else preds.shape[-1]
preds,
target,
mean_x,
mean_y,
var_x,
var_y,
corr_xy,
nb,
num_outputs=1 if preds.ndim == 1 else preds.shape[-1],
weights=weights,
)
return _pearson_corrcoef_compute(var_x, var_y, corr_xy, nb)
22 changes: 22 additions & 0 deletions src/torchmetrics/functional/regression/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,3 +41,25 @@ def _check_data_shape_to_num_outputs(
f"Expected argument `num_outputs` to match the second dimension of input, but got {num_outputs}"
f" and {preds.shape[1]}."
)


def _check_data_shape_to_weights(preds: Tensor, weights: Tensor) -> None:
"""Check that the predictions and weights have the correct shape, else raise error.

This test assumes that the prediction and target tensors have been confirmed to have the same shape.
It further assumes that the `pred` is either a 1- or 2-dimensional tensor.

"""
if preds.ndim == 1 and preds.shape != weights.shape:
raise ValueError(
f"Expected `preds.shape` to equal to `weights.shape`, but got {preds.shape} and {weights.shape}."
)
elif preds.ndim == 2:
if weights.ndim == 1 and preds.shape[0] != len(weights):
raise ValueError(
f"Expected `preds.shape[0]` to equal to `len(weights)` but got {preds.shape[0]} and {len(weights)}."
)
if weights.ndim == 2 and preds.shape != weights.shape:
raise ValueError(
f"Expected `preds.shape` to equal to `weights.shape`, but got {preds.shape} and {weights.shape}."
)
3 changes: 2 additions & 1 deletion src/torchmetrics/regression/pearson.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def __init__(
self.add_state("corr_xy", default=torch.zeros(self.num_outputs), dist_reduce_fx=None)
self.add_state("n_total", default=torch.zeros(self.num_outputs), dist_reduce_fx=None)

def update(self, preds: Tensor, target: Tensor) -> None:
def update(self, preds: Tensor, target: Tensor, weights: Optional[Tensor] = None) -> None:
"""Update state with predictions and targets."""
self.mean_x, self.mean_y, self.var_x, self.var_y, self.corr_xy, self.n_total = _pearson_corrcoef_update(
preds,
Expand All @@ -153,6 +153,7 @@ def update(self, preds: Tensor, target: Tensor) -> None:
self.corr_xy,
self.n_total,
self.num_outputs,
weights=weights,
)

def compute(self) -> Tensor:
Expand Down
2 changes: 1 addition & 1 deletion tests/unittests/helpers/testers.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ def _functional_test(
extra_kwargs = {k: v[i] if isinstance(v, Tensor) else v for k, v in kwargs_update.items()}
tm_result = metric(preds[i], target[i], **extra_kwargs)
extra_kwargs = {
k: v.cpu() if isinstance(v, Tensor) else v
k: v[i].cpu() if isinstance(v, Tensor) else v
for k, v in (extra_kwargs if fragment_kwargs else kwargs_update).items()
}
ref_result = reference_metric(
Expand Down
54 changes: 49 additions & 5 deletions tests/unittests/regression/test_pearson.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,11 @@
from collections import namedtuple
from functools import partial

import numpy as np
import pytest
import torch
from scipy.stats import pearsonr
from torch import Tensor
from torchmetrics.functional.regression.pearson import pearson_corrcoef
from torchmetrics.regression.pearson import PearsonCorrCoef, _final_aggregation

Expand All @@ -28,6 +30,7 @@

Input = namedtuple("Input", ["preds", "target"])


_single_target_inputs1 = Input(
preds=torch.rand(NUM_BATCHES, BATCH_SIZE),
target=torch.rand(NUM_BATCHES, BATCH_SIZE),
Expand All @@ -50,12 +53,34 @@
)


def _ref_metric(preds, target, weights=None):
if weights is None:
return _scipy_pearson(preds, target)
return _weighted_pearson(preds, target, weights)


def _scipy_pearson(preds, target):
if preds.ndim == 2:
return [pearsonr(t.numpy(), p.numpy())[0] for t, p in zip(target.T, preds.T)]
return pearsonr(target.numpy(), preds.numpy())[0]


def _weighted_pearson(preds, target, weights):
preds = preds.numpy() if isinstance(preds, Tensor) else preds
target = target.numpy() if isinstance(target, Tensor) else target
weights = weights.numpy() if isinstance(weights, Tensor) else weights

if preds.ndim == 2:
return [_weighted_pearson(p, t, weights) for p, t in zip(preds.T, target.T)]

mx = (weights * preds).sum() / weights.sum()
my = (weights * target).sum() / weights.sum()
var_x = (weights * (preds - mx) ** 2).sum()
var_y = (weights * (target - my) ** 2).sum()
cov_xy = (weights * (preds - mx) * (target - my)).sum()
return cov_xy / np.sqrt(var_x * var_y)


@pytest.mark.parametrize(
"preds, target",
[
Expand All @@ -70,24 +95,43 @@ class TestPearsonCorrCoef(MetricTester):

atol = 1e-3

@pytest.mark.parametrize(
"kwargs_update",
[
pytest.param({}, id="None weights"),
pytest.param({"weights": torch.rand(NUM_BATCHES, BATCH_SIZE)}, id="tensor weights"),
],
)
@pytest.mark.parametrize("compute_on_cpu", [True, False])
@pytest.mark.parametrize("ddp", [True, False])
def test_pearson_corrcoef(self, preds, target, compute_on_cpu, ddp):
def test_pearson_corrcoef(self, preds, target, kwargs_update, compute_on_cpu, ddp):
"""Test class implementation of metric."""
num_outputs = EXTRA_DIM if preds.ndim == 3 else 1
self.run_class_metric_test(
ddp=ddp,
preds=preds,
target=target,
metric_class=PearsonCorrCoef,
reference_metric=_scipy_pearson,
reference_metric=_ref_metric,
metric_args={"num_outputs": num_outputs, "compute_on_cpu": compute_on_cpu},
matsumotosan marked this conversation as resolved.
Show resolved Hide resolved
weights=kwargs_update.get("weights", None),
)

def test_pearson_corrcoef_functional(self, preds, target):
@pytest.mark.parametrize(
"kwargs_update",
[
pytest.param({}, id="None weights"),
pytest.param({"weights": torch.rand(NUM_BATCHES, BATCH_SIZE)}, id="tensor weights"),
],
)
def test_pearson_corrcoef_functional(self, preds, target, kwargs_update):
"""Test functional implementation of metric."""
self.run_functional_metric_test(
preds=preds, target=target, metric_functional=pearson_corrcoef, reference_metric=_scipy_pearson
preds=preds,
target=target,
metric_functional=pearson_corrcoef,
reference_metric=_ref_metric,
weights=kwargs_update.get("weights", None),
)

def test_pearson_corrcoef_differentiability(self, preds, target):
Expand All @@ -100,7 +144,7 @@ def test_pearson_corrcoef_differentiability(self, preds, target):
metric_functional=pearson_corrcoef,
)

def test_pearson_corrcoef_half_cpu(self, preds, target):
def test_pearson_corrcoef_half_cpu(self, preds, target, metric_args):
"""Test dtype support of the metric on CPU."""
num_outputs = EXTRA_DIM if preds.ndim == 3 else 1
self.run_precision_test_cpu(preds, target, partial(PearsonCorrCoef, num_outputs=num_outputs), pearson_corrcoef)
Expand Down