diff --git a/pyro/ops/stats.py b/pyro/ops/stats.py index a582082671..2ec57d4784 100644 --- a/pyro/ops/stats.py +++ b/pyro/ops/stats.py @@ -3,6 +3,7 @@ import math import numbers +from typing import List, Tuple, Union import torch from torch.fft import irfft, rfft @@ -261,6 +262,66 @@ def quantile(input, probs, dim=0): return quantiles if probs.shape != torch.Size([]) else quantiles.squeeze(dim) +def weighed_quantile( + input: torch.Tensor, + probs: Union[List[float], Tuple[float, ...], torch.Tensor], + log_weights: torch.Tensor, + dim: int = 0, +) -> torch.Tensor: + """ + Computes quantiles of weighed ``input`` samples at ``probs``. + + :param torch.Tensor input: the input tensor. + :param list probs: quantile positions. + :param torch.Tensor log_weights: sample weights tensor. + :param int dim: dimension to take quantiles from ``input``. + :returns torch.Tensor: quantiles of ``input`` at ``probs``. + + Example: + >>> from pyro.ops.stats import weighed_quantile + >>> import torch + >>> input = torch.Tensor([[10, 50, 40], [20, 30, 0]]) + >>> probs = torch.Tensor([0.2, 0.8]) + >>> log_weights = torch.Tensor([0.4, 0.5, 0.1]).log() + >>> result = weighed_quantile(input, probs, log_weights, -1) + >>> torch.testing.assert_close(result, torch.Tensor([[40.4, 47.6], [9.0, 26.4]])) + """ + dim = dim if dim >= 0 else (len(input.shape) + dim) + if isinstance(probs, (list, tuple)): + probs = torch.tensor(probs, dtype=input.dtype, device=input.device) + assert isinstance(probs, torch.Tensor) + # Calculate normalized weights + weights = (log_weights - torch.logsumexp(log_weights, 0)).exp() + # Sort input and weights + sorted_input, sorting_indices = input.sort(dim) + weights = weights[sorting_indices].cumsum(dim) + # Scale weights to be between zero and one + weights = weights - weights.min(dim, keepdim=True)[0] + weights = weights / weights.max(dim, keepdim=True)[0] + # Calculate indices + indices_above = ( + (weights[..., None] <= probs) + .sum(dim, keepdim=True) + .swapaxes(dim, -1) + .clamp(max=input.size(dim) - 1)[..., 0] + ) + indices_below = (indices_above - 1).clamp(min=0) + # Calculate below and above qunatiles + quantiles_below = sorted_input.gather(dim, indices_below) + quantiles_above = sorted_input.gather(dim, indices_above) + # Calculate weights for below and above quantiles + probs_shape = [None] * dim + [slice(None)] + [None] * (len(input.shape) - dim - 1) + expanded_probs_shape = list(input.shape) + expanded_probs_shape[dim] = len(probs) + probs = probs[probs_shape].expand(*expanded_probs_shape) + weights_below = weights.gather(dim, indices_below) + weights_above = weights.gather(dim, indices_above) + weights_below = (weights_above - probs) / (weights_above - weights_below) + weights_above = 1 - weights_below + # Return quantiles + return weights_below * quantiles_below + weights_above * quantiles_above + + def pi(input, prob, dim=0): """ Computes percentile interval which assigns equal probability mass diff --git a/tests/ops/test_stats.py b/tests/ops/test_stats.py index f77b464900..4346be5feb 100644 --- a/tests/ops/test_stats.py +++ b/tests/ops/test_stats.py @@ -20,6 +20,7 @@ resample, split_gelman_rubin, waic, + weighed_quantile, ) from tests.common import assert_close, assert_equal, xfail_if_not_implemented @@ -57,6 +58,25 @@ def test_quantile(): assert_equal(quantile(z, probs=0.8413), torch.tensor(1.0), prec=0.02) +@pytest.mark.init(rng_seed=3) +def test_weighed_quantile(): + # Fixed values test + input = torch.Tensor([[10, 50, 40], [20, 30, 0]]) + probs = [0.2, 0.8] + log_weights = torch.Tensor([0.4, 0.5, 0.1]).log() + result = weighed_quantile(input, probs, log_weights, -1) + assert_equal(result, torch.Tensor([[40.4, 47.6], [9.0, 26.4]])) + + # Random values test + dist = torch.distributions.normal.Normal(0, 1) + input = dist.sample((100000,)) + probs = [0.1, 0.7, 0.95] + log_weights = dist.log_prob(input) + result = weighed_quantile(input, probs, log_weights) + result_dist = torch.distributions.normal.Normal(0, torch.tensor(0.5).sqrt()) + assert_equal(result, result_dist.icdf(torch.Tensor(probs)), prec=0.01) + + def test_pi(): x = torch.randn(1000).exp() assert_equal(pi(x, prob=0.8), quantile(x, probs=[0.1, 0.9]))