Skip to content

Commit

Permalink
fix #1039: legend handling at top level, add uniform region.
Browse files Browse the repository at this point in the history
  • Loading branch information
janfb committed Aug 6, 2024
1 parent 724ea4f commit ba6c1e6
Showing 1 changed file with 13 additions and 12 deletions.
25 changes: 13 additions & 12 deletions sbi/analysis/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -1517,6 +1517,7 @@ def _sbc_rank_plot(
show_ylabel: bool = False,
sharey: bool = False,
fig: Optional[FigureBase] = None,
legend_kwargs: Optional[Dict] = None,
ax=None, # no type hint to avoid hassle with pyright. Should be `array(Axes).`
figsize: Optional[tuple] = None,
) -> Tuple[Figure, Axes]:
Expand Down Expand Up @@ -1569,6 +1570,9 @@ def _sbc_rank_plot(
plot_type in plot_types
), "plot type {plot_type} not implemented, use one in {plot_types}."

if legend_kwargs is None:
legend_kwargs = dict(loc=1, handlelength=0.8)

num_sbc_runs, num_parameters = ranks_list[0].shape
num_ranks = len(ranks_list)

Expand Down Expand Up @@ -1627,7 +1631,6 @@ def _sbc_rank_plot(
xlabel=f"posterior ranks {parameter_labels[jj]}",
# Show legend and ylabel only in first subplot.
show_ylabel=jj == 0,
show_legend=jj == 0,
alpha=line_alpha,
)
if ii == 0 and show_uniform_region:
Expand All @@ -1647,7 +1650,6 @@ def _sbc_rank_plot(
xlabel=f"posterior rank {parameter_labels[jj]}",
# Show legend and ylabel only in first subplot.
show_ylabel=show_ylabel,
show_legend=jj == 0,
alpha=line_alpha,
xlim_offset_factor=xlim_offset_factor,
)
Expand All @@ -1658,6 +1660,10 @@ def _sbc_rank_plot(
num_posterior_samples,
alpha=uniform_region_alpha,
)
# show legend only in first subplot.
if jj == 0 and ranks_labels[ii] is not None:
plt.legend(**legend_kwargs)

else:
raise ValueError(
f"plot_type {plot_type} not defined, use one in {plot_types}"
Expand Down Expand Up @@ -1685,7 +1691,6 @@ def _sbc_rank_plot(
xlabel="posterior rank",
# Plot ylabel and legend at last.
show_ylabel=jj == (num_parameters - 1),
show_legend=jj == (num_parameters - 1),
alpha=line_alpha,
)
if show_uniform_region:
Expand All @@ -1695,6 +1700,8 @@ def _sbc_rank_plot(
num_repeats,
alpha=uniform_region_alpha,
)
# show legend on the last subplot.
plt.legend(**legend_kwargs)

return fig, ax # pyright: ignore[reportReturnType]

Expand All @@ -1708,10 +1715,8 @@ def _plot_ranks_as_hist(
color: str = "firebrick",
alpha: float = 0.8,
show_ylabel: bool = False,
show_legend: bool = False,
num_ticks: int = 3,
xlim_offset_factor: float = 0.1,
legend_kwargs: Optional[Dict] = None,
) -> None:
"""Plot ranks as histograms on the current axis.
Expand Down Expand Up @@ -1743,8 +1748,6 @@ def _plot_ranks_as_hist(
plt.ylabel("counts")
else:
plt.yticks([])
if show_legend and ranks_label:
plt.legend(loc=1, handlelength=0.8, **legend_kwargs or {})

plt.xlim(-xlim_offset, num_posterior_samples + xlim_offset)
plt.xticks(np.linspace(0, num_posterior_samples, num_ticks))
Expand All @@ -1760,9 +1763,7 @@ def _plot_ranks_as_cdf(
color: Optional[str] = None,
alpha: float = 0.8,
show_ylabel: bool = True,
show_legend: bool = False,
num_ticks: int = 3,
legend_kwargs: Optional[Dict] = None,
) -> None:
"""Plot ranks as empirical CDFs on the current axis.
Expand Down Expand Up @@ -1800,8 +1801,6 @@ def _plot_ranks_as_cdf(
else:
# Plot ticks only
plt.yticks(np.linspace(0, 1, 3), [])
if show_legend and ranks_label:
plt.legend(loc=2, handlelength=0.8, **legend_kwargs or {})

plt.ylim(0, 1)
plt.xlim(0, num_bins)
Expand Down Expand Up @@ -1835,6 +1834,7 @@ def _plot_cdf_region_expected_under_uniformity(
y2=np.repeat(upper / np.max(upper), num_repeats), # pyright: ignore[reportArgumentType]
color=color,
alpha=alpha,
label="expected under uniformity",
)


Expand All @@ -1857,6 +1857,7 @@ def _plot_hist_region_expected_under_uniformity(
y2=np.repeat(upper, num_bins), # pyright: ignore[reportArgumentType]
color=color,
alpha=alpha,
label="expected under uniformity",
)


Expand Down Expand Up @@ -1996,7 +1997,7 @@ def marginal_plot_with_probs_intensity(
weights = np.array([np.mean(w) for w in weights])
# remove empty bins
id = list(set(range(n_bins)) - set(probs_per_marginal["bins"]))
patches = np.delete(patches, id)
patches = np.delete(np.array(patches), id)
bins = np.delete(bins, id)

# normalize color intensity
Expand Down

0 comments on commit ba6c1e6

Please sign in to comment.