From 379856cb3309a5811375a4f191ee98fbe755e087 Mon Sep 17 00:00:00 2001 From: janfb Date: Fri, 5 Nov 2021 18:12:32 +0100 Subject: [PATCH] adapt plotting to take ax and fig. --- sbi/analysis/plot.py | 12 ++++++++++++ sbi/utils/plot.py | 30 ++++++++++++++++++++++++++---- 2 files changed, 38 insertions(+), 4 deletions(-) diff --git a/sbi/analysis/plot.py b/sbi/analysis/plot.py index 2616ad5fa..d4b366564 100644 --- a/sbi/analysis/plot.py +++ b/sbi/analysis/plot.py @@ -32,6 +32,8 @@ def pairplot( labels: Optional[List[str]] = None, ticks: Union[List, torch.Tensor] = [], points_colors: List[str] = plt.rcParams["axes.prop_cycle"].by_key()["color"], + fig=None, + axes=None, **kwargs ): """ @@ -55,6 +57,8 @@ def pairplot( labels: List of strings specifying the names of the parameters. ticks: Position of the ticks. points_colors: Colors of the `points`. + fig: matplotlib figure to plot on. + axes: matplotlib axes corresponding to fig. **kwargs: Additional arguments to adjust the plot, see the source code in `_get_default_opts()` in `sbi.utils.plot` for more details. @@ -72,6 +76,8 @@ def pairplot( ticks=ticks, points_colors=points_colors, warn_about_deprecation=False, + fig=fig, + axes=axes, **kwargs, ) @@ -89,6 +95,8 @@ def conditional_pairplot( labels: Optional[List[str]] = None, ticks: Union[List, torch.Tensor] = [], points_colors: List[str] = plt.rcParams["axes.prop_cycle"].by_key()["color"], + fig=None, + axes=None, **kwargs ): r""" @@ -121,6 +129,8 @@ def conditional_pairplot( labels: List of strings specifying the names of the parameters. ticks: Position of the ticks. points_colors: Colors of the `points`. + fig: matplotlib figure to plot on. + axes: matplotlib axes corresponding to fig. **kwargs: Additional arguments to adjust the plot, see the source code in `_get_default_opts()` in `sbi.utils.plot` for more details. @@ -138,5 +148,7 @@ def conditional_pairplot( ticks=ticks, points_colors=points_colors, warn_about_deprecation=False, + fig=fig, + axes=axes, **kwargs, ) diff --git a/sbi/utils/plot.py b/sbi/utils/plot.py index 6f56ef680..80d0c51cd 100644 --- a/sbi/utils/plot.py +++ b/sbi/utils/plot.py @@ -152,6 +152,8 @@ def pairplot( ticks: Union[List, torch.Tensor] = [], points_colors: List[str] = plt.rcParams["axes.prop_cycle"].by_key()["color"], warn_about_deprecation: bool = True, + fig=None, + axes=None, **kwargs ): """ @@ -178,6 +180,8 @@ def pairplot( warn_about_deprecation: With sbi v0.15.0, we depracated the import of this function from `sbi.utils`. Instead, it should be imported from `sbi.analysis`. + fig: matplotlib figure to plot on. + axes: matplotlib axes corresponding to fig. **kwargs: Additional arguments to adjust the plot, see the source code in `_get_default_opts()` in `sbi.utils.plot` for more details. @@ -378,7 +382,9 @@ def upper_func(row, col, limits, **kwargs): else: pass - return _pairplot_scaffold(diag_func, upper_func, dim, limits, points, opts) + return _pairplot_scaffold( + diag_func, upper_func, dim, limits, points, opts, fig=fig, axes=axes + ) def conditional_pairplot( @@ -395,6 +401,8 @@ def conditional_pairplot( ticks: Union[List, torch.Tensor] = [], points_colors: List[str] = plt.rcParams["axes.prop_cycle"].by_key()["color"], warn_about_deprecation: bool = True, + fig=None, + axes=None, **kwargs ): r""" @@ -430,6 +438,8 @@ def conditional_pairplot( warn_about_deprecation: With sbi v0.15.0, we depracated the import of this function from `sbi.utils`. Instead, it should be imported from `sbi.analysis`. + fig: matplotlib figure to plot on. + axes: matplotlib axes corresponding to fig. **kwargs: Additional arguments to adjust the plot, see the source code in `_get_default_opts()` in `sbi.utils.plot` for more details. @@ -524,10 +534,14 @@ def upper_func(row, col, **kwargs): aspect="auto", ) - return _pairplot_scaffold(diag_func, upper_func, dim, limits, points, opts) + return _pairplot_scaffold( + diag_func, upper_func, dim, limits, points, opts, fig=fig, axes=axes + ) -def _pairplot_scaffold(diag_func, upper_func, dim, limits, points, opts): +def _pairplot_scaffold( + diag_func, upper_func, dim, limits, points, opts, fig=None, axes=None +): """ Builds the scaffold for any function that plots parameters in a pairplot setting. @@ -544,6 +558,8 @@ def _pairplot_scaffold(diag_func, upper_func, dim, limits, points, opts): opts: Dictionary built by the functions that call `pairplot_scaffold`. Must contain at least `labels`, `subset`, `figsize`, `subplots`, `fig_subplots_adjust`, `title`, `title_format`, .. + fig: matplotlib figure to plot on. + axes: matplotlib axes corresponding to fig. Returns: figure and axis """ @@ -588,7 +604,13 @@ def _pairplot_scaffold(diag_func, upper_func, dim, limits, points, opts): raise NotImplementedError rows = cols = len(subset) - fig, axes = plt.subplots(rows, cols, figsize=opts["figsize"], **opts["subplots"]) + # Create fig and axes if they were not passed. + if fig is None or axes is None: + fig, axes = plt.subplots( + rows, cols, figsize=opts["figsize"], **opts["subplots"] + ) + else: + assert axes.shape == (rows, cols), f"Passed axes must match subplot shape: {rows, cols}." # Cast to ndarray in case of 1D subplots. axes = np.array(axes).reshape(rows, cols)