From 8ad1e35322f5d873cb88963eecad207409451453 Mon Sep 17 00:00:00 2001 From: "Oriol (ZBook)" Date: Sun, 24 May 2020 02:29:11 +0200 Subject: [PATCH 1/6] fixes to matplotlib subplot handling --- arviz/plots/backends/matplotlib/pairplot.py | 53 +++++++++++-------- arviz/plots/traceplot.py | 2 +- arviz/tests/base_tests/test_plots_bokeh.py | 2 +- .../tests/base_tests/test_plots_matplotlib.py | 12 ++++- 4 files changed, 43 insertions(+), 26 deletions(-) diff --git a/arviz/plots/backends/matplotlib/pairplot.py b/arviz/plots/backends/matplotlib/pairplot.py index 4ce9f24cec..ec0513d7f3 100644 --- a/arviz/plots/backends/matplotlib/pairplot.py +++ b/arviz/plots/backends/matplotlib/pairplot.py @@ -199,49 +199,58 @@ def plot_pair( ax.tick_params(labelsize=xt_labelsize) else: + not_marginals = int(not marginals) + num_subplot_cols = numvars - not_marginals max_plots = ( - numvars ** 2 if rcParams["plot.max_subplots"] is None else rcParams["plot.max_subplots"] + num_subplot_cols ** 2 + if rcParams["plot.max_subplots"] is None + else rcParams["plot.max_subplots"] ) - vars_to_plot = np.sum(np.arange(numvars).cumsum() < max_plots) - if vars_to_plot < numvars: + cols_to_plot = np.sum(np.arange(num_subplot_cols).cumsum() <= max_plots) + print(marginals) + print(cols_to_plot) + if cols_to_plot < num_subplot_cols: warnings.warn( "rcParams['plot.max_subplots'] ({max_plots}) is smaller than the number " "of resulting pair plots with these variables, generating only a " - "{side}x{side} grid".format(max_plots=max_plots, side=vars_to_plot), + "{side}x{side} grid".format(max_plots=max_plots, side=cols_to_plot), UserWarning, ) - numvars = vars_to_plot + numvars = cols_to_plot - 1 + vars_to_plot = numvars - not_marginals (figsize, ax_labelsize, _, xt_labelsize, _, markersize) = _scale_fig_size( - figsize, textsize, numvars - 2, numvars - 2 + figsize, textsize, vars_to_plot, vars_to_plot ) point_estimate_marker_kwargs.setdefault("s", markersize + 50) if ax is None: - fig, ax = plt.subplots(numvars, numvars, figsize=figsize, **backend_kwargs) + fig, ax = plt.subplots( + vars_to_plot, + vars_to_plot, + figsize=figsize, + **backend_kwargs, + ) hexbin_values = [] - for i in range(0, numvars): + for i in range(0, vars_to_plot): var1 = infdata_group[i] - for j in range(0, numvars): - var2 = infdata_group[j] + for j in range(0, vars_to_plot): + var2 = infdata_group[j+not_marginals] if i > j: if ax[j, i].get_figure() is not None: ax[j, i].remove() continue - elif i == j: - if marginals: - loc = "right" - plot_dist(var1, ax=ax[i, j], **marginal_kwargs) - else: - loc = "left" - if ax[j, i].get_figure() is not None: - ax[j, i].remove() - continue + elif i == j and marginals: + loc = "right" + plot_dist(var1, ax=ax[i, j], **marginal_kwargs) else: + if i == j: + loc = "left" + if "scatter" in kind: ax[j, i].plot(var1, var2, **scatter_kwargs) @@ -285,7 +294,7 @@ def plot_pair( if reference_values: x_name = flat_var_names[i] - y_name = flat_var_names[j] + y_name = flat_var_names[j+not_marginals] if x_name and y_name not in difference: ax[j, i].plot( reference_values_copy[x_name], @@ -293,7 +302,7 @@ def plot_pair( **reference_values_kwargs, ) - if j != numvars - 1: + if j != vars_to_plot - 1: ax[j, i].axes.get_xaxis().set_major_formatter(NullFormatter()) else: ax[j, i].set_xlabel( @@ -303,7 +312,7 @@ def plot_pair( ax[j, i].axes.get_yaxis().set_major_formatter(NullFormatter()) else: ax[j, i].set_ylabel( - "{}".format(flat_var_names[j]), fontsize=ax_labelsize, wrap=True + "{}".format(flat_var_names[j+not_marginals]), fontsize=ax_labelsize, wrap=True ) ax[j, i].tick_params(labelsize=xt_labelsize) diff --git a/arviz/plots/traceplot.py b/arviz/plots/traceplot.py index bf19b416cb..9a8a66b479 100644 --- a/arviz/plots/traceplot.py +++ b/arviz/plots/traceplot.py @@ -230,7 +230,7 @@ def plot_trace( plotters = list(xarray_var_iter(data, var_names=var_names, combined=True, skip_dims=skip_dims)) max_plots = rcParams["plot.max_subplots"] - max_plots = len(plotters) if max_plots is None else max_plots + max_plots = len(plotters) if max_plots is None else max(max_plots // 2, 1) if len(plotters) > max_plots: warnings.warn( "rcParams['plot.max_subplots'] ({max_plots}) is smaller than the number " diff --git a/arviz/tests/base_tests/test_plots_bokeh.py b/arviz/tests/base_tests/test_plots_bokeh.py index cb98bf67d1..7eb96ebe1c 100644 --- a/arviz/tests/base_tests/test_plots_bokeh.py +++ b/arviz/tests/base_tests/test_plots_bokeh.py @@ -149,7 +149,7 @@ def test_plot_trace_discrete(discrete_model): def test_plot_trace_max_subplots_warning(models): with pytest.warns(UserWarning): - with rc_context(rc={"plot.max_subplots": 1}): + with rc_context(rc={"plot.max_subplots": 2}): axes = plot_trace(models.model_1, backend="bokeh", show=False) assert axes.shape diff --git a/arviz/tests/base_tests/test_plots_matplotlib.py b/arviz/tests/base_tests/test_plots_matplotlib.py index c678064499..8db8de2cf7 100644 --- a/arviz/tests/base_tests/test_plots_matplotlib.py +++ b/arviz/tests/base_tests/test_plots_matplotlib.py @@ -188,9 +188,9 @@ def test_plot_trace_discrete(discrete_model): def test_plot_trace_max_subplots_warning(models): with pytest.warns(UserWarning): - with rc_context(rc={"plot.max_subplots": 1}): + with rc_context(rc={"plot.max_subplots": 6}): axes = plot_trace(models.model_1) - assert axes.shape + assert axes.shape == (3, 2) @pytest.mark.parametrize("kwargs", [{"var_names": ["mu", "tau"], "lines": [("hey", {}, [1])]}]) @@ -477,6 +477,14 @@ def test_plot_pair_overlaid(models, kwargs): assert ax is ax2 assert ax.shape +@pytest.mark.parametrize("marginals", [True, False]) +def test_plot_pair_shapes(marginals): + rng = np.random.default_rng() + idata = from_dict({"a": rng.standard_normal((4, 500, 3))}) + ax = plot_pair(idata, marginals=marginals) + side = 2 + marginals + assert ax.shape == (side, side) + @pytest.mark.parametrize("kind", ["kde", "cumulative", "scatter"]) @pytest.mark.parametrize("alpha", [None, 0.2, 1]) From cccfa1732b4e5f4e084a9e1f29b61b88a36129ae Mon Sep 17 00:00:00 2001 From: "Oriol (ZBook)" Date: Sun, 24 May 2020 02:45:41 +0200 Subject: [PATCH 2/6] lint --- arviz/plots/backends/matplotlib/pairplot.py | 15 ++++++--------- arviz/tests/base_tests/test_plots_matplotlib.py | 1 + 2 files changed, 7 insertions(+), 9 deletions(-) diff --git a/arviz/plots/backends/matplotlib/pairplot.py b/arviz/plots/backends/matplotlib/pairplot.py index ec0513d7f3..4eb78e9410 100644 --- a/arviz/plots/backends/matplotlib/pairplot.py +++ b/arviz/plots/backends/matplotlib/pairplot.py @@ -226,18 +226,13 @@ def plot_pair( point_estimate_marker_kwargs.setdefault("s", markersize + 50) if ax is None: - fig, ax = plt.subplots( - vars_to_plot, - vars_to_plot, - figsize=figsize, - **backend_kwargs, - ) + fig, ax = plt.subplots(vars_to_plot, vars_to_plot, figsize=figsize, **backend_kwargs,) hexbin_values = [] for i in range(0, vars_to_plot): var1 = infdata_group[i] for j in range(0, vars_to_plot): - var2 = infdata_group[j+not_marginals] + var2 = infdata_group[j + not_marginals] if i > j: if ax[j, i].get_figure() is not None: ax[j, i].remove() @@ -294,7 +289,7 @@ def plot_pair( if reference_values: x_name = flat_var_names[i] - y_name = flat_var_names[j+not_marginals] + y_name = flat_var_names[j + not_marginals] if x_name and y_name not in difference: ax[j, i].plot( reference_values_copy[x_name], @@ -312,7 +307,9 @@ def plot_pair( ax[j, i].axes.get_yaxis().set_major_formatter(NullFormatter()) else: ax[j, i].set_ylabel( - "{}".format(flat_var_names[j+not_marginals]), fontsize=ax_labelsize, wrap=True + "{}".format(flat_var_names[j + not_marginals]), + fontsize=ax_labelsize, + wrap=True, ) ax[j, i].tick_params(labelsize=xt_labelsize) diff --git a/arviz/tests/base_tests/test_plots_matplotlib.py b/arviz/tests/base_tests/test_plots_matplotlib.py index 8db8de2cf7..cddea66304 100644 --- a/arviz/tests/base_tests/test_plots_matplotlib.py +++ b/arviz/tests/base_tests/test_plots_matplotlib.py @@ -477,6 +477,7 @@ def test_plot_pair_overlaid(models, kwargs): assert ax is ax2 assert ax.shape + @pytest.mark.parametrize("marginals", [True, False]) def test_plot_pair_shapes(marginals): rng = np.random.default_rng() From dc34a3218c50d86a28a4410cc3de1618afbdf52a Mon Sep 17 00:00:00 2001 From: "Oriol (ZBook)" Date: Sun, 24 May 2020 02:52:59 +0200 Subject: [PATCH 3/6] minor fix --- arviz/plots/backends/matplotlib/pairplot.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/arviz/plots/backends/matplotlib/pairplot.py b/arviz/plots/backends/matplotlib/pairplot.py index 4eb78e9410..bb52cb2141 100644 --- a/arviz/plots/backends/matplotlib/pairplot.py +++ b/arviz/plots/backends/matplotlib/pairplot.py @@ -216,7 +216,7 @@ def plot_pair( "{side}x{side} grid".format(max_plots=max_plots, side=cols_to_plot), UserWarning, ) - numvars = cols_to_plot - 1 + numvars = cols_to_plot - marginals vars_to_plot = numvars - not_marginals (figsize, ax_labelsize, _, xt_labelsize, _, markersize) = _scale_fig_size( From cd3cc93a216255274cdc98b2d3f7c19a8f6a94a6 Mon Sep 17 00:00:00 2001 From: "Oriol (ZBook)" Date: Sun, 24 May 2020 13:02:39 +0200 Subject: [PATCH 4/6] remove prints --- arviz/plots/backends/matplotlib/pairplot.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/arviz/plots/backends/matplotlib/pairplot.py b/arviz/plots/backends/matplotlib/pairplot.py index bb52cb2141..c123ac8fbf 100644 --- a/arviz/plots/backends/matplotlib/pairplot.py +++ b/arviz/plots/backends/matplotlib/pairplot.py @@ -207,8 +207,6 @@ def plot_pair( else rcParams["plot.max_subplots"] ) cols_to_plot = np.sum(np.arange(num_subplot_cols).cumsum() <= max_plots) - print(marginals) - print(cols_to_plot) if cols_to_plot < num_subplot_cols: warnings.warn( "rcParams['plot.max_subplots'] ({max_plots}) is smaller than the number " From 5073f96b8dc9fe024b4778040c8e60c794cb24dd Mon Sep 17 00:00:00 2001 From: "Oriol (ZBook)" Date: Sun, 24 May 2020 13:59:35 +0200 Subject: [PATCH 5/6] extend tests and more fixes --- arviz/plots/backends/matplotlib/pairplot.py | 9 +++++---- arviz/tests/base_tests/test_plots_matplotlib.py | 14 ++++++++++---- 2 files changed, 15 insertions(+), 8 deletions(-) diff --git a/arviz/plots/backends/matplotlib/pairplot.py b/arviz/plots/backends/matplotlib/pairplot.py index c123ac8fbf..3c01308db5 100644 --- a/arviz/plots/backends/matplotlib/pairplot.py +++ b/arviz/plots/backends/matplotlib/pairplot.py @@ -206,16 +206,17 @@ def plot_pair( if rcParams["plot.max_subplots"] is None else rcParams["plot.max_subplots"] ) - cols_to_plot = np.sum(np.arange(num_subplot_cols).cumsum() <= max_plots) + cols_to_plot = np.sum(np.arange(1, num_subplot_cols+1).cumsum() <= max_plots) if cols_to_plot < num_subplot_cols: + vars_to_plot = cols_to_plot warnings.warn( "rcParams['plot.max_subplots'] ({max_plots}) is smaller than the number " "of resulting pair plots with these variables, generating only a " - "{side}x{side} grid".format(max_plots=max_plots, side=cols_to_plot), + "{side}x{side} grid".format(max_plots=max_plots, side=vars_to_plot), UserWarning, ) - numvars = cols_to_plot - marginals - vars_to_plot = numvars - not_marginals + else: + vars_to_plot = numvars - not_marginals (figsize, ax_labelsize, _, xt_labelsize, _, markersize) = _scale_fig_size( figsize, textsize, vars_to_plot, vars_to_plot diff --git a/arviz/tests/base_tests/test_plots_matplotlib.py b/arviz/tests/base_tests/test_plots_matplotlib.py index cddea66304..4e2a4666db 100644 --- a/arviz/tests/base_tests/test_plots_matplotlib.py +++ b/arviz/tests/base_tests/test_plots_matplotlib.py @@ -479,11 +479,17 @@ def test_plot_pair_overlaid(models, kwargs): @pytest.mark.parametrize("marginals", [True, False]) -def test_plot_pair_shapes(marginals): +@pytest.mark.parametrize("max_subplots", [True, False]) +def test_plot_pair_shapes(marginals, max_subplots): rng = np.random.default_rng() - idata = from_dict({"a": rng.standard_normal((4, 500, 3))}) - ax = plot_pair(idata, marginals=marginals) - side = 2 + marginals + idata = from_dict({"a": rng.standard_normal((4, 500, 5))}) + if max_subplots: + with rc_context({"plot.max_subplots": 6}): + with pytest.warns(UserWarning, match="3x3 grid"): + ax = plot_pair(idata, marginals=marginals) + else: + ax = plot_pair(idata, marginals=marginals) + side = 3 if max_subplots else (4 + marginals) assert ax.shape == (side, side) From 3b63b00a915b8be27a4fe00fdb9d66b502158a88 Mon Sep 17 00:00:00 2001 From: "Oriol (ZBook)" Date: Sun, 24 May 2020 14:01:16 +0200 Subject: [PATCH 6/6] lint --- arviz/plots/backends/matplotlib/pairplot.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/arviz/plots/backends/matplotlib/pairplot.py b/arviz/plots/backends/matplotlib/pairplot.py index 3c01308db5..5343945a4a 100644 --- a/arviz/plots/backends/matplotlib/pairplot.py +++ b/arviz/plots/backends/matplotlib/pairplot.py @@ -206,7 +206,7 @@ def plot_pair( if rcParams["plot.max_subplots"] is None else rcParams["plot.max_subplots"] ) - cols_to_plot = np.sum(np.arange(1, num_subplot_cols+1).cumsum() <= max_plots) + cols_to_plot = np.sum(np.arange(1, num_subplot_cols + 1).cumsum() <= max_plots) if cols_to_plot < num_subplot_cols: vars_to_plot = cols_to_plot warnings.warn(