Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add function for calculating quantiles of weighed samples. #3340

Merged
merged 6 commits into from
Mar 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 61 additions & 0 deletions pyro/ops/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import math
import numbers
from typing import List, Tuple, Union

import torch
from torch.fft import irfft, rfft
Expand Down Expand Up @@ -261,6 +262,66 @@ def quantile(input, probs, dim=0):
return quantiles if probs.shape != torch.Size([]) else quantiles.squeeze(dim)


def weighed_quantile(
    input: torch.Tensor,
    probs: Union[float, List[float], Tuple[float, ...], torch.Tensor],
    log_weights: torch.Tensor,
    dim: int = 0,
) -> torch.Tensor:
    """
    Computes quantiles of weighed ``input`` samples at ``probs``.

    Samples are sorted along ``dim`` and each one is placed at the cumulative
    sum of the normalized weights, rescaled so that the smallest sample sits at
    position ``0`` and the largest at position ``1``. A quantile is the linear
    interpolation between the two sorted samples whose positions bracket the
    requested probability.

    :param torch.Tensor input: the input tensor.
    :param probs: quantile position(s), expected to lie in ``[0, 1]``. A list,
        tuple or 1-dimensional tensor yields one quantile per entry; a number
        or 0-dimensional tensor yields a single quantile.
    :type probs: float, list, tuple or torch.Tensor
    :param torch.Tensor log_weights: logarithm of the sample weights (they do
        not need to be normalized). Must be 1-dimensional with one entry per
        sample along ``dim``.
    :param int dim: dimension to take quantiles from ``input``.
    :returns torch.Tensor: quantiles of ``input`` at ``probs``. The size of
        ``dim`` is replaced by the number of entries in ``probs``; for scalar
        ``probs`` the dimension ``dim`` is squeezed away.
    :raises TypeError: if ``probs`` is not a number, list, tuple or tensor.

    Example:
        >>> from pyro.ops.stats import weighed_quantile
        >>> import torch
        >>> input = torch.tensor([[10.0, 50.0, 40.0], [20.0, 30.0, 0.0]])
        >>> probs = torch.tensor([0.2, 0.8])
        >>> log_weights = torch.tensor([0.4, 0.5, 0.1]).log()
        >>> result = weighed_quantile(input, probs, log_weights, -1)
        >>> torch.testing.assert_close(result, torch.tensor([[40.4, 47.6], [9.0, 26.4]]))
    """
    dim = dim if dim >= 0 else (len(input.shape) + dim)
    if isinstance(probs, (int, float, list, tuple)):
        probs = torch.tensor(probs, dtype=input.dtype, device=input.device)
    if not isinstance(probs, torch.Tensor):
        raise TypeError(
            "probs must be a number, list, tuple or torch.Tensor, "
            f"got {type(probs).__name__}"
        )
    # A scalar ``probs`` is handled as a one-element batch and squeezed at the end.
    scalar_probs = probs.dim() == 0
    if scalar_probs:
        probs = probs.unsqueeze(0)
    num_samples = input.size(dim)

    # Calculate normalized weights
    weights = (log_weights - torch.logsumexp(log_weights, 0)).exp()
    # Sort input and reorder the weights accordingly before accumulating them
    sorted_input, sorting_indices = input.sort(dim)
    weights = weights[sorting_indices].cumsum(dim)
    # Scale cumulative weights to be between zero and one. With a single
    # sample the range is empty, so its position stays at zero instead of
    # becoming 0 / 0 = NaN.
    weights = weights - weights.min(dim, keepdim=True)[0]
    if num_samples > 1:
        weights = weights / weights.max(dim, keepdim=True)[0]

    # For every probability, count the positions that do not exceed it: this is
    # the index of the first sorted sample lying strictly above the probability.
    indices_above = (
        (weights[..., None] <= probs)
        .sum(dim, keepdim=True)
        .swapaxes(dim, -1)
        .clamp(max=num_samples - 1)[..., 0]
    )
    indices_below = (indices_above - 1).clamp(min=0)
    # Calculate below and above quantiles
    quantiles_below = sorted_input.gather(dim, indices_below)
    quantiles_above = sorted_input.gather(dim, indices_above)

    # Align probs with ``dim`` so that it broadcasts against the gathered values
    probs_view = [1] * input.dim()
    probs_view[dim] = -1
    probs = probs.reshape(probs_view)

    # Calculate interpolation weights for below and above quantiles
    weights_below = weights.gather(dim, indices_below)
    weights_above = weights.gather(dim, indices_above)
    gap = weights_above - weights_below
    # Both bracketing samples can share one position (single sample along
    # ``dim``, or zero-weight samples at the top when probs == 1). The
    # interpolation would then be x / 0, so all mass goes to the lower sample.
    degenerate = gap == 0
    safe_gap = torch.where(degenerate, torch.ones_like(gap), gap)
    fraction_below = (weights_above - probs) / safe_gap
    fraction_below = torch.where(
        degenerate, torch.ones_like(fraction_below), fraction_below
    )
    fraction_above = 1 - fraction_below

    # Return quantiles
    quantiles = fraction_below * quantiles_below + fraction_above * quantiles_above
    return quantiles.squeeze(dim) if scalar_probs else quantiles


def pi(input, prob, dim=0):
"""
Computes percentile interval which assigns equal probability mass
Expand Down
20 changes: 20 additions & 0 deletions tests/ops/test_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
resample,
split_gelman_rubin,
waic,
weighed_quantile,
)
from tests.common import assert_close, assert_equal, xfail_if_not_implemented

Expand Down Expand Up @@ -57,6 +58,25 @@ def test_quantile():
assert_equal(quantile(z, probs=0.8413), torch.tensor(1.0), prec=0.02)


@pytest.mark.init(rng_seed=3)
def test_weighed_quantile():
    """Check ``weighed_quantile`` on a fixed example and on weighted normal draws."""
    # Deterministic case: a 2x3 input with one weight per column, with
    # quantiles taken along the last dimension and compared to fixed values.
    fixed_samples = torch.Tensor([[10, 50, 40], [20, 30, 0]])
    fixed_log_weights = torch.Tensor([0.4, 0.5, 0.1]).log()
    fixed_result = weighed_quantile(fixed_samples, [0.2, 0.8], fixed_log_weights, -1)
    assert_equal(fixed_result, torch.Tensor([[40.4, 47.6], [9.0, 26.4]]))

    # Stochastic case: standard-normal draws weighted by their own density are
    # compared against the quantiles of a Normal(0, sqrt(0.5)) distribution.
    standard_normal = torch.distributions.normal.Normal(0, 1)
    draws = standard_normal.sample((100000,))
    levels = [0.1, 0.7, 0.95]
    estimated = weighed_quantile(draws, levels, standard_normal.log_prob(draws))
    reference = torch.distributions.normal.Normal(0, torch.tensor(0.5).sqrt())
    assert_equal(estimated, reference.icdf(torch.Tensor(levels)), prec=0.01)


def test_pi():
x = torch.randn(1000).exp()
assert_equal(pi(x, prob=0.8), quantile(x, probs=[0.1, 0.9]))
Expand Down
Loading