Skip to content

Commit

Permalink
adapt plotting to take ax and fig.
Browse files Browse the repository at this point in the history
  • Loading branch information
janfb committed Dec 7, 2021
1 parent 2a10717 commit e24c377
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 4 deletions.
4 changes: 4 additions & 0 deletions sbi/analysis/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ def pairplot(
labels: Optional[List[str]] = None,
ticks: Union[List, torch.Tensor] = [],
points_colors: List[str] = plt.rcParams["axes.prop_cycle"].by_key()["color"],
fig=None,
axes=None,
**kwargs
):
"""
Expand Down Expand Up @@ -72,6 +74,8 @@ def pairplot(
ticks=ticks,
points_colors=points_colors,
warn_about_deprecation=False,
fig=fig,
axes=axes,
**kwargs,
)

Expand Down
30 changes: 26 additions & 4 deletions sbi/utils/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,8 @@ def pairplot(
ticks: Union[List, torch.Tensor] = [],
points_colors: List[str] = plt.rcParams["axes.prop_cycle"].by_key()["color"],
warn_about_deprecation: bool = True,
fig=None,
axes=None,
**kwargs
):
"""
Expand All @@ -178,6 +180,8 @@ def pairplot(
warn_about_deprecation: With sbi v0.15.0, we depracated the import of this
function from `sbi.utils`. Instead, it should be imported from
`sbi.analysis`.
fig: matplotlib figure to plot on.
axes: matplotlib axes corresponding to fig.
**kwargs: Additional arguments to adjust the plot, see the source code in
`_get_default_opts()` in `sbi.utils.plot` for more details.
Expand Down Expand Up @@ -378,7 +382,9 @@ def upper_func(row, col, limits, **kwargs):
else:
pass

return _pairplot_scaffold(diag_func, upper_func, dim, limits, points, opts)
return _pairplot_scaffold(
diag_func, upper_func, dim, limits, points, opts, fig=fig, axes=axes
)


def conditional_pairplot(
Expand All @@ -395,6 +401,8 @@ def conditional_pairplot(
ticks: Union[List, torch.Tensor] = [],
points_colors: List[str] = plt.rcParams["axes.prop_cycle"].by_key()["color"],
warn_about_deprecation: bool = True,
fig=None,
axes=None,
**kwargs
):
r"""
Expand Down Expand Up @@ -430,6 +438,8 @@ def conditional_pairplot(
warn_about_deprecation: With sbi v0.15.0, we depracated the import of this
function from `sbi.utils`. Instead, it should be imported from
`sbi.analysis`.
fig: matplotlib figure to plot on.
axes: matplotlib axes corresponding to fig.
**kwargs: Additional arguments to adjust the plot, see the source code in
`_get_default_opts()` in `sbi.utils.plot` for more details.
Expand Down Expand Up @@ -524,10 +534,14 @@ def upper_func(row, col, **kwargs):
aspect="auto",
)

return _pairplot_scaffold(diag_func, upper_func, dim, limits, points, opts)
return _pairplot_scaffold(
diag_func, upper_func, dim, limits, points, opts, fig=fig, axes=axes
)


def _pairplot_scaffold(diag_func, upper_func, dim, limits, points, opts):
def _pairplot_scaffold(
diag_func, upper_func, dim, limits, points, opts, fig=None, axes=None
):
"""
Builds the scaffold for any function that plots parameters in a pairplot setting.
Expand All @@ -544,6 +558,8 @@ def _pairplot_scaffold(diag_func, upper_func, dim, limits, points, opts):
opts: Dictionary built by the functions that call `pairplot_scaffold`. Must
contain at least `labels`, `subset`, `figsize`, `subplots`,
`fig_subplots_adjust`, `title`, `title_format`, ..
fig: matplotlib figure to plot on.
axes: matplotlib axes corresponding to fig.
Returns: figure and axis
"""
Expand Down Expand Up @@ -588,7 +604,13 @@ def _pairplot_scaffold(diag_func, upper_func, dim, limits, points, opts):
raise NotImplementedError
rows = cols = len(subset)

fig, axes = plt.subplots(rows, cols, figsize=opts["figsize"], **opts["subplots"])
# Create fig and axes if they were not passed.
if fig is None or axes is None:
fig, axes = plt.subplots(
rows, cols, figsize=opts["figsize"], **opts["subplots"]
)
else:
assert axes.shape == (rows, cols), f"Passed axes must match subplot shape: {rows, cols}."
# Cast to ndarray in case of 1D subplots.
axes = np.array(axes).reshape(rows, cols)

Expand Down

0 comments on commit e24c377

Please sign in to comment.