Skip to content

Commit

Permalink
test: add sbc and tarp plotting tests
Browse files Browse the repository at this point in the history
  • Loading branch information
janfb committed Aug 7, 2024
1 parent 3a999de commit 2892c5a
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 1 deletion.
4 changes: 3 additions & 1 deletion sbi/analysis/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -2094,7 +2094,7 @@ def pp_plot_lc2st(
)


def plot_tarp(ecp: Tensor, alpha: Tensor, title="") -> Tuple[Figure, Axes]:
def plot_tarp(ecp: Tensor, alpha: Tensor, title: Optional[str]) -> Tuple[Figure, Axes]:
"""
Plots the expected coverage probability (ECP) against the credibility
level,alpha, for a given alpha grid.
Expand All @@ -2117,6 +2117,8 @@ def plot_tarp(ecp: Tensor, alpha: Tensor, title="") -> Tuple[Figure, Axes]:

fig = plt.figure(figsize=(6, 6))
ax: Axes = plt.gca()
if title is None:
title = ""

ax.plot(alpha, ecp, color="blue", label="TARP")
ax.plot(alpha, alpha, color="black", linestyle="--", label="ideal")
Expand Down
32 changes: 32 additions & 0 deletions tests/sbc_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,14 @@

from __future__ import annotations

from typing import Union

import pytest
import torch
from torch import eye, ones, zeros
from torch.distributions import MultivariateNormal, Uniform

from sbi.analysis import sbc_rank_plot
from sbi.diagnostics import check_sbc, get_nltp, run_sbc
from sbi.inference import SNLE, SNPE, simulate_for_sbi
from sbi.simulators import linear_gaussian
Expand Down Expand Up @@ -208,3 +211,32 @@ def test_sbc_checks():
assert (checks["ks_pvals"] > 0.05).all()
assert (checks["c2st_ranks"] < 0.55).all()
assert (checks["c2st_dap"] < 0.55).all()


# add test for sbc plotting
@pytest.mark.parametrize("num_bins", (None, 30))
@pytest.mark.parametrize("plot_type", ("cdf", "hist"))
@pytest.mark.parametrize("legend_kwargs", (None, {"loc": "upper left"}))
@pytest.mark.parametrize("num_rank_sets", (1, 2))
def test_sbc_plotting(
num_bins: int, plot_type: str, legend_kwargs: Union[None, dict], num_rank_sets: int
):
"""Test the uniformity checks for SBC."""

num_dim = 2
num_posterior_samples = 1000

# Ranks should be distributed uniformly in [0, num_posterior_samples]
ranks = [
torch.distributions.Uniform(
zeros(num_dim), num_posterior_samples * ones(num_dim)
).sample((num_posterior_samples,))
] * num_rank_sets

sbc_rank_plot(
ranks,
num_posterior_samples,
num_bins=num_bins,
plot_type=plot_type,
legend_kwargs=legend_kwargs,
)
12 changes: 12 additions & 0 deletions tests/tarp_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from torch.distributions import Normal, Uniform
from torch.nn import L1Loss

from sbi.analysis.plot import plot_tarp
from sbi.diagnostics.tarp import _run_tarp, check_tarp, get_tarp_references, run_tarp
from sbi.inference import SNPE
from sbi.simulators import linear_gaussian
Expand Down Expand Up @@ -286,3 +287,14 @@ def simulator(theta):
atc, kspvals = check_tarp(ecp, alpha)
assert -0.5 < atc < 0.5
assert kspvals > 0.05


# Test tarp plotting
@pytest.mark.parametrize("title", ["Correct", None])
def test_tarp_plotting(title: str, accurate_samples):
theta, samples = accurate_samples
references = get_tarp_references(theta)

ecp, alpha = _run_tarp(samples, theta, references)

plot_tarp(ecp, alpha, title=title)

0 comments on commit 2892c5a

Please sign in to comment.