Fix OOM in pearson metric #380

Merged: 35 commits merged into master from pearson_update on Jul 28, 2021
Commits (35)
1f5b441  fix  (SkafteNicki, Jul 16, 2021)
aade247  [pre-commit.ci] auto fixes from pre-commit.com hooks  (pre-commit-ci[bot], Jul 16, 2021)
8969914  changelog  (SkafteNicki, Jul 16, 2021)
74b2e26  typing  (Borda, Jul 16, 2021)
8bf2e54  [pre-commit.ci] auto fixes from pre-commit.com hooks  (pre-commit-ci[bot], Jul 16, 2021)
ca5cdb5  Merge branch 'master' into pearson_update  (mergify[bot], Jul 16, 2021)
fb0ecdd  Merge branch 'master' into pearson_update  (SkafteNicki, Jul 17, 2021)
94a40a2  doctest  (SkafteNicki, Jul 17, 2021)
b1deed4  Merge branch 'pearson_update' of https://github.com/PyTorchLightning/…  (SkafteNicki, Jul 17, 2021)
75c799c  variable renaming  (SkafteNicki, Jul 17, 2021)
2593d90  Merge branch 'master' into pearson_update  (mergify[bot], Jul 19, 2021)
bdcf234  Merge branch 'master' into pearson_update  (mergify[bot], Jul 24, 2021)
174731c  rename  (Borda, Jul 24, 2021)
d61f43c  [pre-commit.ci] auto fixes from pre-commit.com hooks  (pre-commit-ci[bot], Jul 24, 2021)
b7881b0  fmt  (Borda, Jul 24, 2021)
9bf9a3a  .  (Borda, Jul 24, 2021)
453b040  Merge branch 'pearson_update' of https://github.com/PyTorchLightning/…  (Borda, Jul 24, 2021)
452af45  Merge branch 'master' into pearson_update  (mergify[bot], Jul 24, 2021)
2f91bcb  Merge branch 'master' into pearson_update  (mergify[bot], Jul 24, 2021)
290ce1d  Merge branch 'master' into pearson_update  (mergify[bot], Jul 24, 2021)
77ac987  Merge branch 'master' into pearson_update  (mergify[bot], Jul 24, 2021)
3712f49  Merge branch 'master' into pearson_update  (mergify[bot], Jul 26, 2021)
e71eef8  CI: update mergify  (Borda, Jul 26, 2021)
ad0bf5d  Merge branch 'master' into pearson_update  (mergify[bot], Jul 26, 2021)
d7e33d2  Merge branch 'master' into pearson_update  (Borda, Jul 26, 2021)
d6e8fdf  Merge branch 'master' into pearson_update  (mergify[bot], Jul 26, 2021)
597b362  Merge branch 'master' into pearson_update  (mergify[bot], Jul 26, 2021)
2c86c4f  Merge branch 'master' into pearson_update  (mergify[bot], Jul 26, 2021)
013eca8  Merge branch 'master' into pearson_update  (mergify[bot], Jul 28, 2021)
2d3d694  Merge branch 'master' into pearson_update  (mergify[bot], Jul 28, 2021)
3b6c474  fix tests  (SkafteNicki, Jul 28, 2021)
ba7ff13  Merge branch 'master' into pearson_update  (SkafteNicki, Jul 28, 2021)
084361b  fix merge  (SkafteNicki, Jul 28, 2021)
0b6aaf9  [pre-commit.ci] auto fixes from pre-commit.com hooks  (pre-commit-ci[bot], Jul 28, 2021)
da43c3a  fix flake8  (SkafteNicki, Jul 28, 2021)
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -32,6 +32,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Moved `R2Score` from `regression.r2score` to `regression.r2` ([#371](https://github.com/PyTorchLightning/metrics/pull/371))


- Pearson metrics now only store 6 statistics instead of all predictions and targets ([#380](https://github.com/PyTorchLightning/metrics/pull/380))


### Deprecated

- Rename `r2score` >> `r2_score` and `kldivergence` >> `kl_divergence` in `functional` ([#371](https://github.com/PyTorchLightning/metrics/pull/371))
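The CHANGELOG entry above summarises the core of the fix: instead of buffering every prediction and target, the metric now carries six running statistics, so its memory footprint no longer grows with the amount of data seen. A minimal sketch of what that means in practice, assuming the torchmetrics API at the time of this PR (the class was still named `PearsonCorrcoef`):

```python
import torch
from torchmetrics import PearsonCorrcoef  # class name at the time of this PR

metric = PearsonCorrcoef()
for _ in range(1_000):                    # many large batches
    preds = torch.randn(10_000)
    target = preds + torch.randn(10_000)
    metric.update(preds, target)

# After this PR the state is O(1): six one-element tensors,
# no stored predictions or targets.
states = ["mean_x", "mean_y", "var_x", "var_y", "corr_xy", "n_total"]
print({name: getattr(metric, name).shape for name in states})
print(metric.compute())
```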
5 changes: 2 additions & 3 deletions tests/regression/test_pearson.py
@@ -53,15 +53,14 @@ class TestPearsonCorrcoef(MetricTester):
atol = 1e-2

@pytest.mark.parametrize("ddp", [True, False])
@pytest.mark.parametrize("dist_sync_on_step", [True, False])
def test_pearson_corrcoef(self, preds, target, ddp, dist_sync_on_step):
def test_pearson_corrcoef(self, preds, target, ddp):
self.run_class_metric_test(
ddp=ddp,
preds=preds,
target=target,
metric_class=PearsonCorrcoef,
sk_metric=_sk_pearsonr,
dist_sync_on_step=dist_sync_on_step,
dist_sync_on_step=False,
)

def test_pearson_corrcoef_functional(self, preds, target):
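The test change pins `dist_sync_on_step` to `False` rather than parametrizing over it. Independently of the test harness, the property the test checks is that the metric matches scipy's reference `pearsonr`; a minimal sketch of that comparison (not the repository's `MetricTester` harness, and using the class name from the time of this PR):

```python
import torch
from scipy.stats import pearsonr
from torchmetrics import PearsonCorrcoef

preds = torch.randn(100)
target = preds + 0.5 * torch.randn(100)

metric = PearsonCorrcoef()
# feed the data in chunks to exercise the running-statistics update
for p, t in zip(preds.chunk(4), target.chunk(4)):
    metric.update(p, t)

reference, _ = pearsonr(preds.numpy(), target.numpy())
assert torch.allclose(
    metric.compute(), torch.tensor(reference, dtype=torch.float32), atol=1e-2
)
```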
81 changes: 47 additions & 34 deletions torchmetrics/functional/regression/pearson.py
@@ -22,56 +22,66 @@
def _pearson_corrcoef_update(
preds: Tensor,
target: Tensor,
) -> Tuple[Tensor, Tensor]:
mean_x: Tensor,
mean_y: Tensor,
var_x: Tensor,
var_y: Tensor,
corr_xy: Tensor,
n_prior: Tensor,
) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
"""
Updates and returns variables required to compute Pearson Correlation Coefficient.
Checks for same shape of input tensors.

Args:
preds: Predicted tensor
target: Ground truth tensor
mean_x: current mean estimate of x tensor
mean_y: current mean estimate of y tensor
var_x: current variance estimate of x tensor
var_y: current variance estimate of y tensor
corr_xy: current covariance estimate between x and y tensor
n_prior: current number of observed observations
"""

# Data checking
_check_same_shape(preds, target)
preds = preds.squeeze()
target = target.squeeze()
if preds.ndim > 1 or target.ndim > 1:
raise ValueError('Expected both predictions and target to be 1 dimensional tensors.')

return preds, target


def _pearson_corrcoef_compute(preds: Tensor, target: Tensor, eps: float = 1e-6) -> Tensor:
n_obs = preds.numel()
mx_new = (n_prior * mean_x + preds.mean() * n_obs) / (n_prior + n_obs)
my_new = (n_prior * mean_y + target.mean() * n_obs) / (n_prior + n_obs)
n_prior += n_obs
var_x += ((preds - mx_new) * (preds - mean_x)).sum()
var_y += ((target - my_new) * (target - mean_y)).sum()
corr_xy += ((preds - mx_new) * (target - mean_y)).sum()
mean_x = mx_new
mean_y = my_new

return mean_x, mean_y, var_x, var_y, corr_xy, n_prior


def _pearson_corrcoef_compute(
var_x: Tensor,
var_y: Tensor,
corr_xy: Tensor,
nb: Tensor,
) -> Tensor:
"""
Computes Pearson Correlation Coefficient.
Computes the final pearson correlation based on accumulated statistics

Args:
preds: Predicted tensor
target: Ground truth tensor
eps: Avoids ZeroDivisionError. default: 1e-6

Example:
>>> target = torch.tensor([3, -0.5, 2, 7])
>>> preds = torch.tensor([2.5, 0.0, 2, 8])
>>> preds, target = _pearson_corrcoef_update(preds, target)
>>> _pearson_corrcoef_compute(preds, target)
tensor(0.9849)
"""
var_x: variance estimate of x tensor
var_y: variance estimate of y tensor
corr_xy: covariance estimate between x and y tensor
nb: number of observations

preds_diff = preds - preds.mean()
target_diff = target - target.mean()

cov = (preds_diff * target_diff).mean()
preds_std = torch.sqrt((preds_diff * preds_diff).mean())
target_std = torch.sqrt((target_diff * target_diff).mean())

denom = preds_std * target_std
# prevent division by zero
if denom == 0:
denom += eps

corrcoef = cov / denom
"""
var_x /= (nb - 1)
var_y /= (nb - 1)
corr_xy /= (nb - 1)
corrcoef = (corr_xy / (var_x * var_y).sqrt()).squeeze()
return torch.clamp(corrcoef, -1.0, 1.0)


@@ -90,5 +100,8 @@ def pearson_corrcoef(preds: Tensor, target: Tensor) -> Tensor:
>>> pearson_corrcoef(preds, target)
tensor(0.9849)
"""
preds, target = _pearson_corrcoef_update(preds, target)
return _pearson_corrcoef_compute(preds, target)
_temp = torch.zeros(1, dtype=preds.dtype, device=preds.device)
mean_x, mean_y, var_x = _temp.clone(), _temp.clone(), _temp.clone()
var_y, corr_xy, nb = _temp.clone(), _temp.clone(), _temp.clone()
_, _, var_x, var_y, corr_xy, nb = _pearson_corrcoef_update(preds, target, mean_x, mean_y, var_x, var_y, corr_xy, nb)
return _pearson_corrcoef_compute(var_x, var_y, corr_xy, nb)
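As a quick illustration of the update/compute split above, the running statistics can be accumulated over several chunks and still reproduce the single-shot result. A minimal sketch, assuming the private helpers keep the signatures shown in this diff:

```python
import torch
from torchmetrics.functional import pearson_corrcoef
from torchmetrics.functional.regression.pearson import (
    _pearson_corrcoef_compute,
    _pearson_corrcoef_update,
)

preds = torch.randn(1_000)
target = preds + 0.5 * torch.randn(1_000)

# accumulate the six statistics chunk by chunk
zero = torch.zeros(1)
mean_x, mean_y, var_x, var_y, corr_xy, n_total = (zero.clone() for _ in range(6))
for p, t in zip(preds.chunk(4), target.chunk(4)):
    mean_x, mean_y, var_x, var_y, corr_xy, n_total = _pearson_corrcoef_update(
        p, t, mean_x, mean_y, var_x, var_y, corr_xy, n_total
    )
chunked = _pearson_corrcoef_compute(var_x, var_y, corr_xy, n_total)

# single-shot result through the public functional interface
assert torch.allclose(chunked, pearson_corrcoef(preds, target), atol=1e-4)
```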
75 changes: 59 additions & 16 deletions torchmetrics/regression/pearson.py
@@ -11,15 +11,45 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, List, Optional
from typing import Any, List, Optional, Tuple

import torch
from torch import Tensor

from torchmetrics.functional.regression.pearson import _pearson_corrcoef_compute, _pearson_corrcoef_update
from torchmetrics.metric import Metric
from torchmetrics.utilities import rank_zero_warn
from torchmetrics.utilities.data import dim_zero_cat


def _final_aggregation(
means_x: Tensor,
means_y: Tensor,
vars_x: Tensor,
vars_y: Tensor,
corrs_xy: Tensor,
nbs: Tensor,
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
"""
Aggregate the statistics from multiple devices. Formula taken from here:
https://stackoverflow.com/questions/68395368/estimate-running-correlation-on-multiple-nodes
"""
# assert len(means_x) > 1 and len(means_y) > 1 and len(vars_x) > 1 and len(vars_y) > 1 and len(corrs_xy) > 1
mx1, my1, vx1, vy1, cxy1, n1 = means_x[0], means_y[0], vars_x[0], vars_y[0], corrs_xy[0], nbs[0]
for i in range(1, len(means_x)):
mx2, my2, vx2, vy2, cxy2, n2 = means_x[i], means_y[i], vars_x[i], vars_y[i], corrs_xy[i], nbs[i]

nb = n1 + n2
mean_x = (n1 * mx1 + n2 * mx2) / nb
mean_y = (n1 * my1 + n2 * my2) / nb
var_x = (1 / (n1 + n2 - 1) * ((n1 - 1) * vx1 + (n2 - 1) * vx2 + ((n1 * n2) / (n1 + n2)) * (mx1 - mx2)**2))
var_y = (1 / (n1 + n2 - 1) * ((n1 - 1) * vy1 + (n2 - 1) * vy2 + ((n1 * n2) / (n1 + n2)) * (my1 - my2)**2))

corr1 = n1 * cxy1 + n1 * (mx1 - mean_x) * (my1 - mean_y)
corr2 = n2 * cxy2 + n2 * (mx2 - mean_x) * (my2 - mean_y)
corr_xy = (corr1 + corr2) / (n1 + n2)

mx1, my1, vx1, vy1, cxy1, n1 = mean_x, mean_y, var_x, var_y, corr_xy, nb

return var_x, var_y, corr_xy, nb


class PearsonCorrcoef(Metric):
Expand Down Expand Up @@ -58,6 +88,12 @@ class PearsonCorrcoef(Metric):
"""
preds: List[Tensor]
target: List[Tensor]
mean_x: Tensor
mean_y: Tensor
var_x: Tensor
var_y: Tensor
corr_xy: Tensor
n_total: Tensor

def __init__(
self,
@@ -71,13 +107,12 @@ def __init__(
process_group=process_group,
)

rank_zero_warn(
'Metric `PearsonCorrcoef` will save all targets and predictions in buffer.'
' For large datasets this may lead to large memory footprint.'
)

self.add_state("preds", default=[], dist_reduce_fx="cat")
self.add_state("target", default=[], dist_reduce_fx="cat")
self.add_state("mean_x", default=torch.zeros(1), dist_reduce_fx=None)
self.add_state("mean_y", default=torch.zeros(1), dist_reduce_fx=None)
self.add_state("var_x", default=torch.zeros(1), dist_reduce_fx=None)
self.add_state("var_y", default=torch.zeros(1), dist_reduce_fx=None)
self.add_state("corr_xy", default=torch.zeros(1), dist_reduce_fx=None)
self.add_state("n_total", default=torch.zeros(1), dist_reduce_fx=None)

def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore
"""
@@ -87,17 +122,25 @@ def update(self, preds: Tensor, target: Tensor) -> None:  # type: ignore
preds: Predictions from model
target: Ground truth values
"""
preds, target = _pearson_corrcoef_update(preds, target)
self.preds.append(preds)
self.target.append(target)
self.mean_x, self.mean_y, self.var_x, self.var_y, self.corr_xy, self.n_total = _pearson_corrcoef_update(
preds, target, self.mean_x, self.mean_y, self.var_x, self.var_y, self.corr_xy, self.n_total
)

def compute(self) -> Tensor:
"""
Computes pearson correlation coefficient over state.
"""
preds = dim_zero_cat(self.preds)
target = dim_zero_cat(self.target)
return _pearson_corrcoef_compute(preds, target)
if self.mean_x.numel() > 1: # multiple devices, need further reduction
var_x, var_y, corr_xy, n_total = _final_aggregation(
self.mean_x, self.mean_y, self.var_x, self.var_y, self.corr_xy, self.n_total
)
else:
var_x = self.var_x
var_y = self.var_y
corr_xy = self.corr_xy
n_total = self.n_total

return _pearson_corrcoef_compute(var_x, var_y, corr_xy, n_total)

@property
def is_differentiable(self) -> bool:
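For the multi-device path, `_final_aggregation` merges per-device statistics pairwise using the parallel-variance formula from the Stack Overflow answer linked in its docstring. A minimal standalone check of that merge for the mean and variance, with made-up data rather than the repository's tests:

```python
import torch

# statistics computed independently on two "devices"
x1, x2 = torch.randn(50), torch.randn(80)
n1, n2 = torch.tensor(50.0), torch.tensor(80.0)
m1, m2 = x1.mean(), x2.mean()
v1, v2 = x1.var(unbiased=True), x2.var(unbiased=True)

# pairwise merge, mirroring the formula in _final_aggregation
n = n1 + n2
mean = (n1 * m1 + n2 * m2) / n
var = ((n1 - 1) * v1 + (n2 - 1) * v2 + (n1 * n2 / n) * (m1 - m2) ** 2) / (n - 1)

# the merged statistics agree with a direct computation on all the data
full = torch.cat([x1, x2])
assert torch.allclose(mean, full.mean(), atol=1e-6)
assert torch.allclose(var, full.var(unbiased=True), atol=1e-5)
```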