Skip to content

Commit

Permalink
fix sbc reduce funs, refactor plotting.
Browse files Browse the repository at this point in the history
  • Loading branch information
janfb committed Dec 17, 2022
1 parent 214ee4e commit 49a2d4f
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 7 deletions.
23 changes: 18 additions & 5 deletions sbi/analysis/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -986,6 +986,9 @@ def sbc_rank_plot(
parameter_labels: Optional[List[str]] = None,
ranks_labels: Optional[List[str]] = None,
colors: Optional[List[str]] = None,
fig: Optional[Figure] = None,
ax: Optional[Axes] = None,
figsize: Optional[tuple] = None,
kwargs: Dict = {},
) -> Tuple[Figure, Axes]:
"""Plot simulation-based calibration ranks as empirical CDFs or histograms.
Expand Down Expand Up @@ -1016,6 +1019,9 @@ def sbc_rank_plot(
parameter_labels,
ranks_labels,
colors,
fig=fig,
ax=ax,
figsize=figsize,
**kwargs,
)

Expand All @@ -1032,6 +1038,7 @@ def _sbc_rank_plot(
line_alpha: float = 0.8,
show_uniform_region: bool = True,
uniform_region_alpha: float = 0.3,
uniform_region_color: str = "gray",
xlim_offset_factor: float = 0.1,
num_cols: int = 4,
params_in_subplots: bool = False,
Expand Down Expand Up @@ -1145,15 +1152,18 @@ def _sbc_rank_plot(
num_repeats,
ranks_label=ranks_labels[ii],
color=f"C{ii}" if colors is None else colors[ii],
xlabel=f"posterior rank {parameter_labels[jj]}",
xlabel=f"posterior ranks {parameter_labels[jj]}",
# Show legend and ylabel only in first subplot.
show_ylabel=jj == 0,
show_legend=jj == 0,
alpha=line_alpha,
)
if ii == 0 and show_uniform_region:
_plot_cdf_region_expected_under_uniformity(
num_sbc_runs, num_bins, num_repeats, alpha=0.1
num_sbc_runs,
num_bins,
num_repeats,
alpha=uniform_region_alpha,
)
elif plot_type == "hist":
_plot_ranks_as_hist(
Expand Down Expand Up @@ -1208,7 +1218,10 @@ def _sbc_rank_plot(
)
if show_uniform_region:
_plot_cdf_region_expected_under_uniformity(
num_sbc_runs, num_bins, num_repeats, alpha=uniform_region_alpha
num_sbc_runs,
num_bins,
num_repeats,
alpha=uniform_region_alpha,
)

return fig, ax
Expand Down Expand Up @@ -1329,7 +1342,7 @@ def _plot_cdf_region_expected_under_uniformity(
num_bins: int,
num_repeats: int,
alpha: float = 0.2,
color: str = "grey",
color: str = "gray",
) -> None:
"""Plot region of empirical cdfs expected under uniformity on the current axis."""

Expand Down Expand Up @@ -1358,7 +1371,7 @@ def _plot_hist_region_expected_under_uniformity(
num_bins: int,
num_posterior_samples: int,
alpha: float = 0.2,
color: str = "grey",
color: str = "gray",
) -> None:
"""Plot region of empirical cdfs expected under uniformity."""

Expand Down
8 changes: 6 additions & 2 deletions sbi/analysis/sbc.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,9 @@ def sbc_on_batch(
"`reduce_fn` must either be the string `marginals` or a Callable or a List "
"of Callables."
)
reduce_fns = [(lambda theta, x: theta[i]) for i in range(thetas.shape[1])]
reduce_fns = [
eval(f"lambda theta, x: theta[:, {i}]") for i in range(thetas.shape[1])
]

if isinstance(reduce_fns, Callable):
reduce_fns = [reduce_fns]
Expand All @@ -175,7 +177,9 @@ def sbc_on_batch(

# rank for each posterior dimension as in Talts et al. section 4.1.
for i, reduce_fn in enumerate(reduce_fns):
ranks[idx, i] = (reduce_fn(ths, xo) < reduce_fn(tho, xo)).sum().item()
ranks[idx, i] = (
(reduce_fn(ths, xo) < reduce_fn(tho.reshape(1, -1), xo)).sum().item()
)

return ranks, dap_samples

Expand Down

0 comments on commit 49a2d4f

Please sign in to comment.