Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add function for calculating quantiles of weighed samples. #3340

Merged
merged 6 commits into from
Mar 18, 2024
Merged
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 55 additions & 0 deletions pyro/ops/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,61 @@ def quantile(input, probs, dim=0):
return quantiles if probs.shape != torch.Size([]) else quantiles.squeeze(dim)


def weighed_quantile(input, probs, log_weights, dim=0):
BenZickel marked this conversation as resolved.
Show resolved Hide resolved
"""
Computes quantiles of weighed ``input`` samples at ``probs``.

:param torch.Tensor input: the input tensor.
:param list probs: quantile positions.
:param torch.Tensor log_weights: sample weights tensor.
:param int dim: dimension to take quantiles from ``input``.
:returns torch.Tensor: quantiles of ``input`` at ``probs``.

Example:
>>> from pyro.ops.stats import weighed_quantile
>>> import torch
>>> input = torch.Tensor([[10, 50, 40], [20, 30, 0]])
>>> probs = torch.Tensor([0.2, 0.8])
>>> log_weights = torch.Tensor([0.4, 0.5, 0.1]).log()
>>> result = weighed_quantile(input, probs, log_weights, -1)
>>> torch.testing.assert_close(result, torch.Tensor([[40.4, 47.6], [9.0, 26.4]]))
"""
dim = dim if dim >= 0 else (len(input.shape) + dim)
if isinstance(probs, (list, tuple)):
probs = torch.tensor(probs, dtype=input.dtype, device=input.device)
# Calculate normalized weights
weights = (log_weights - torch.logsumexp(log_weights, 0)).exp()
# Sort input and weights
sorted_input, sorting_indices = input.sort(dim)
weights = weights[sorting_indices].cumsum(dim)
# Scale weights to be between zero and one
weights = weights - weights.min(dim, keepdim=True)[0]
weights = weights / weights.max(dim, keepdim=True)[0]
# Calculate indices
indices_above = (
(weights[..., None] <= probs)
.sum(dim, keepdim=True)
.swapaxes(dim, -1)
.clamp(max=input.size(dim) - 1)[..., 0]
)
indices_below = (indices_above - 1).clamp(min=0)
# Calculate below and above qunatiles
quantiles_below = sorted_input.gather(dim, indices_below)
quantiles_above = sorted_input.gather(dim, indices_above)
# Calculate weights for below and above quantiles
probs_shape = [None] * len(input.shape)
probs_shape[dim] = slice(None)
expanded_probs_shape = list(input.shape)
expanded_probs_shape[dim] = len(probs)
probs = probs[probs_shape].expand(*expanded_probs_shape)
weights_below = weights.gather(dim, indices_below)
weights_above = weights.gather(dim, indices_above)
weights_below = (weights_above - probs) / (weights_above - weights_below)
weights_above = 1 - weights_below
# Return quantiles
return weights_below * quantiles_below + weights_above * quantiles_above


def pi(input, prob, dim=0):
"""
Computes percentile interval which assigns equal probability mass
Expand Down
Loading