Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix sbc reduce funs, refactor plotting. #793

Merged
merged 2 commits into from
Dec 19, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 18 additions & 5 deletions sbi/analysis/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -986,6 +986,9 @@ def sbc_rank_plot(
parameter_labels: Optional[List[str]] = None,
ranks_labels: Optional[List[str]] = None,
colors: Optional[List[str]] = None,
fig: Optional[Figure] = None,
ax: Optional[Axes] = None,
figsize: Optional[tuple] = None,
kwargs: Dict = {},
) -> Tuple[Figure, Axes]:
"""Plot simulation-based calibration ranks as empirical CDFs or histograms.
Expand Down Expand Up @@ -1016,6 +1019,9 @@ def sbc_rank_plot(
parameter_labels,
ranks_labels,
colors,
fig=fig,
ax=ax,
figsize=figsize,
**kwargs,
)

Expand All @@ -1032,6 +1038,7 @@ def _sbc_rank_plot(
line_alpha: float = 0.8,
show_uniform_region: bool = True,
uniform_region_alpha: float = 0.3,
uniform_region_color: str = "gray",
xlim_offset_factor: float = 0.1,
num_cols: int = 4,
params_in_subplots: bool = False,
Expand Down Expand Up @@ -1145,15 +1152,18 @@ def _sbc_rank_plot(
num_repeats,
ranks_label=ranks_labels[ii],
color=f"C{ii}" if colors is None else colors[ii],
xlabel=f"posterior rank {parameter_labels[jj]}",
xlabel=f"posterior ranks {parameter_labels[jj]}",
# Show legend and ylabel only in first subplot.
show_ylabel=jj == 0,
show_legend=jj == 0,
alpha=line_alpha,
)
if ii == 0 and show_uniform_region:
_plot_cdf_region_expected_under_uniformity(
num_sbc_runs, num_bins, num_repeats, alpha=0.1
num_sbc_runs,
num_bins,
num_repeats,
alpha=uniform_region_alpha,
)
elif plot_type == "hist":
_plot_ranks_as_hist(
Expand Down Expand Up @@ -1208,7 +1218,10 @@ def _sbc_rank_plot(
)
if show_uniform_region:
_plot_cdf_region_expected_under_uniformity(
num_sbc_runs, num_bins, num_repeats, alpha=uniform_region_alpha
num_sbc_runs,
num_bins,
num_repeats,
alpha=uniform_region_alpha,
)

return fig, ax
Expand Down Expand Up @@ -1329,7 +1342,7 @@ def _plot_cdf_region_expected_under_uniformity(
num_bins: int,
num_repeats: int,
alpha: float = 0.2,
color: str = "grey",
color: str = "gray",
) -> None:
"""Plot region of empirical cdfs expected under uniformity on the current axis."""

Expand Down Expand Up @@ -1358,7 +1371,7 @@ def _plot_hist_region_expected_under_uniformity(
num_bins: int,
num_posterior_samples: int,
alpha: float = 0.2,
color: str = "grey",
color: str = "gray",
) -> None:
"""Plot region of empirical cdfs expected under uniformity."""

Expand Down
8 changes: 6 additions & 2 deletions sbi/analysis/sbc.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,9 @@ def sbc_on_batch(
"`reduce_fn` must either be the string `marginals` or a Callable or a List "
"of Callables."
)
reduce_fns = [(lambda theta, x: theta[i]) for i in range(thetas.shape[1])]
reduce_fns = [
eval(f"lambda theta, x: theta[:, {i}]") for i in range(thetas.shape[1])
]

if isinstance(reduce_fns, Callable):
reduce_fns = [reduce_fns]
Expand All @@ -175,7 +177,9 @@ def sbc_on_batch(

# rank for each posterior dimension as in Talts et al. section 4.1.
for i, reduce_fn in enumerate(reduce_fns):
ranks[idx, i] = (reduce_fn(ths, xo) < reduce_fn(tho, xo)).sum().item()
ranks[idx, i] = (
(reduce_fn(ths, xo) < reduce_fn(tho.unsqueeze(0), xo)).sum().item()
)

return ranks, dap_samples

Expand Down
4 changes: 3 additions & 1 deletion tests/plot_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,9 @@ def test_sbc_rank_plot(num_parameters, num_cols, custom_figure, plot_type):
ranks,
num_posterior_samples,
plot_type=plot_type,
kwargs=dict(fig=fig, ax=ax, num_cols=num_cols, params_in_subplots=True),
fig=fig,
ax=ax,
kwargs=dict(num_cols=num_cols, params_in_subplots=True),
)
if not custom_figure:
if num_parameters > num_cols:
Expand Down
33 changes: 33 additions & 0 deletions tests/sbc_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from sbi.inference import SNLE, SNPE, simulate_for_sbi
from sbi.simulators import linear_gaussian
from sbi.utils import BoxUniform, MultipleIndependent
from tests.test_utils import PosteriorPotential, TractablePosterior


@pytest.mark.parametrize("reduce_fn_str", ("marginals", "posterior_log_prob"))
Expand Down Expand Up @@ -64,6 +65,38 @@ def simulator(theta):
get_nltp(thetas, xs, posterior)


def test_sbc_accuracy():
    """Check that SBC ranks look uniform when the posterior equals the prior.

    The `TractablePosterior` used here samples from the distribution stored in
    its potential and ignores the observation, and that distribution is set to
    the prior the parameters were drawn from. The ranks computed by `run_sbc`
    are therefore expected to pass the uniformity checks asserted below.
    """
    dim = 2
    num_sbc_runs = 1000
    num_samples_per_run = 1000

    def simulate(theta):
        # Gaussian toy problem: observation = parameter + unit-variance noise.
        return torch.randn_like(theta) + theta

    prior = BoxUniform(-ones(dim), ones(dim))

    # Posterior is deliberately identical to the prior.
    posterior = TractablePosterior(
        potential_fn=PosteriorPotential(posterior=prior, prior=prior)
    )

    thetas = prior.sample((num_sbc_runs,))
    xs = simulate(thetas)

    ranks, daps = run_sbc(
        thetas,
        xs,
        posterior,
        num_workers=1,
        num_posterior_samples=num_samples_per_run,
        reduce_fns="marginals",
    )

    fresh_prior_samples = prior.sample((num_sbc_runs,))
    check_results = check_sbc(
        ranks,
        fresh_prior_samples,
        daps,
        num_posterior_samples=num_samples_per_run,
    )
    # NOTE(review): unpacking relies on the order of the values returned by
    # `check_sbc`; the third entry is not checked here.
    pvals, c2st_ranks, _ = check_results.values()

    assert (c2st_ranks <= 0.6).all(), "posterior ranks must be close to uniform."
    assert (pvals > 0.05).all(), "posterior ranks uniformity test p-values too small."


@pytest.mark.slow
def test_sbc_checks():
"""Test the uniformity checks for SBC."""
Expand Down
143 changes: 142 additions & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,20 @@

from __future__ import annotations

from typing import Tuple, Union
from typing import Any, Callable, Dict, Optional, Tuple, Union

import torch
from torch import Tensor
from torch.distributions import Distribution

from sbi.inference.posteriors.base_posterior import NeuralPosterior
from sbi.inference.posteriors.direct_posterior import DirectPosterior
from sbi.inference.potentials.base_potential import BasePotential
from sbi.simulators.linear_gaussian import true_posterior_linear_gaussian_mvn_prior
from sbi.types import Shape, TorchTransform
from sbi.utils import BoxUniform, within_support
from sbi.utils.metrics import c2st
from sbi.utils.torchutils import ensure_theta_batched


def kl_d_via_monte_carlo(
Expand Down Expand Up @@ -149,3 +152,141 @@ def check_c2st(x: Tensor, y: Tensor, alg: str, tol: float = 0.1) -> None:
assert (
(0.5 - tol) <= score <= (0.5 + tol)
), f"{alg}'s c2st={score:.2f} is too far from the desired near-chance performance."


class PosteriorPotential(BasePotential):
    r"""Potential function backed by a closed-form posterior distribution.

    Calling an instance evaluates the log-probability of the wrapped posterior
    and replaces it with $-\infty$ wherever `within_support` reports that the
    parameters lie outside of the prior.
    """

    # NOTE(review): presumably tells the base class that iid observations are
    # not supported by this potential -- confirm against `BasePotential`.
    allow_iid_x = False  # type: ignore

    def __init__(
        self,
        posterior: Distribution,
        prior: Distribution,
        x_o: Optional[Tensor] = None,
        device: str = "cpu",
    ):
        r"""Store the closed-form posterior next to the prior.

        Args:
            posterior: The posterior distribution, already fixed to the
                observation of interest.
            prior: The prior distribution.
            x_o: Must be left as `None`; the posterior is expected to be
                conditioned on the observation already.
            device: Device string forwarded to the `BasePotential` constructor.

        Raises:
            AssertionError: If `x_o` is not `None`.
        """
        super().__init__(prior, x_o, device)

        x_o_missing = x_o is None
        assert (
            x_o_missing
        ), "No need to pass x_o, passed Posterior must be fixed to x_o."
        self.posterior = posterior

    def __call__(self, theta: Tensor, track_gradients: bool = True) -> Tensor:
        r"""Evaluate the potential at `theta`.

        Args:
            theta: The parameter set at which to evaluate the potential function.
            track_gradients: Whether to track the gradients.

        Returns:
            The posterior log-probability of `theta`, with entries outside of
            the prior support replaced by $-\infty$.
        """
        batched_theta = ensure_theta_batched(torch.as_tensor(theta))

        with torch.set_grad_enabled(track_gradients):
            log_prob = self.posterior.log_prob(batched_theta)
            inside_prior = within_support(self.prior, batched_theta)

            # Zero probability (log-prob of -inf) outside of the prior support.
            minus_inf = torch.tensor(
                float("-inf"), dtype=torch.float32, device=self.device
            )
            masked_log_prob = torch.where(inside_prior, log_prob, minus_inf)

        return masked_log_prob


class TractablePosterior(NeuralPosterior):
    r"""Posterior $p(\theta|x_o)$ with `log_prob()` and `sample()` methods.

    Both methods delegate to the tractable distribution stored as
    `potential_fn.posterior`, so no approximate sampling scheme is involved.
    """

    def __init__(
        self,
        potential_fn: Callable,
        theta_transform: Optional[TorchTransform] = None,
        device: Optional[str] = "cpu",
        x_shape: Optional[torch.Size] = None,
    ):
        """Wrap a `PosteriorPotential` as a posterior object.

        Args:
            potential_fn: A `PosteriorPotential` whose `posterior` attribute
                provides samples and log-probabilities.
            theta_transform: Forwarded unchanged to the `NeuralPosterior`
                constructor.
            device: Device string, e.g., "cpu", "cuda" or "cuda:0"; forwarded
                unchanged to the `NeuralPosterior` constructor.
            x_shape: Shape of the observed data.

        Raises:
            AssertionError: If `potential_fn` is not a `PosteriorPotential`.
        """
        assert isinstance(potential_fn, PosteriorPotential)
        super().__init__(potential_fn, theta_transform, device, x_shape)

    def sample(
        self,
        sample_shape: Shape = torch.Size(),
        x: Optional[Tensor] = None,
        show_progress_bars: bool = True,
        mcmc_method: Optional[str] = None,
        mcmc_parameters: Optional[Dict[str, Any]] = None,
    ) -> Tensor:
        """Draw samples directly from the wrapped tractable posterior.

        Args:
            sample_shape: Shape of the batch of samples to draw.
            x: Ignored by this implementation.
            show_progress_bars: Ignored by this implementation.
            mcmc_method: Ignored by this implementation.
            mcmc_parameters: Ignored by this implementation.

        Returns:
            Samples returned by `potential_fn.posterior.sample(sample_shape)`.
        """
        tractable_posterior = self.potential_fn.posterior
        return tractable_posterior.sample(sample_shape)

    def log_prob(
        self,
        theta: Tensor,
    ) -> Tensor:
        r"""Returns the log-probability of the wrapped posterior at `theta`.

        The value comes straight from `potential_fn.posterior.log_prob`; the
        prior-support masking applied by the potential itself is not used here.

        Args:
            theta: Parameters $\theta$.

        Returns:
            Log-probability of `theta` under the wrapped posterior
            distribution.
        """
        batched_theta = ensure_theta_batched(torch.as_tensor(theta))
        tractable_posterior = self.potential_fn.posterior
        return tractable_posterior.log_prob(batched_theta)

    def map(
        self,
        x: Optional[Tensor] = None,
        num_iter: int = 1_000,
        num_to_optimize: int = 100,
        learning_rate: float = 0.01,
        init_method: Union[str, Tensor] = "posterior",
        num_init_samples: int = 1_000,
        save_best_every: int = 10,
        show_progress_bars: bool = False,
        force_update: bool = False,
    ) -> Tensor:
        """Delegate the maximum-a-posteriori (MAP) estimate to `NeuralPosterior`.

        Every argument is passed through by keyword to `NeuralPosterior.map`,
        whose result is returned unchanged.
        """
        map_options = dict(
            x=x,
            num_iter=num_iter,
            num_to_optimize=num_to_optimize,
            learning_rate=learning_rate,
            init_method=init_method,
            num_init_samples=num_init_samples,
            save_best_every=save_best_every,
            show_progress_bars=show_progress_bars,
            force_update=force_update,
        )
        return super().map(**map_options)