From 787d29f1abad45fa26aac85a5d88ab53e12c36c0 Mon Sep 17 00:00:00 2001 From: aloctavodia Date: Fri, 26 Oct 2018 15:49:04 -0300 Subject: [PATCH 1/3] density grid --- arviz/plots/densityplot.py | 33 +++++++++++++++++++++++---------- arviz/tests/test_plots.py | 3 +-- 2 files changed, 24 insertions(+), 12 deletions(-) diff --git a/arviz/plots/densityplot.py b/arviz/plots/densityplot.py index ce82a95f8f..f81d58c2b1 100644 --- a/arviz/plots/densityplot.py +++ b/arviz/plots/densityplot.py @@ -5,7 +5,13 @@ from ..data import convert_to_dataset from ..stats import hpd from .kdeplot import _fast_kde -from .plot_utils import _scale_fig_size, make_label, xarray_var_iter +from .plot_utils import ( + _scale_fig_size, + make_label, + xarray_var_iter, + default_grid, + _create_axes_grid, +) from ..utils import _var_names @@ -103,17 +109,25 @@ def plot_density( colors = [colors for _ in range(n_data)] to_plot = [list(xarray_var_iter(data, var_names, combined=True)) for data in datasets] - all_labels = set() + all_labels = [] + length_plotters = [] for plotters in to_plot: + length_plotters.append(len(plotters)) for var_name, selection, _ in plotters: - all_labels.add(make_label(var_name, selection)) + all_labels.append(make_label(var_name, selection)) + + length_plotters = max(length_plotters) + rows, cols = default_grid(length_plotters, max_cols=3) (figsize, _, titlesize, xt_labelsize, linewidth, markersize) = _scale_fig_size( - figsize, textsize, len(all_labels), 1 + figsize, textsize, rows, cols + ) + + fig, ax = _create_axes_grid( + length_plotters, rows, cols, figsize=figsize, squeeze=False, constrained_layout=True ) - fig, axes = plt.subplots(len(all_labels), 1, squeeze=False, figsize=figsize) - axis_map = {label: ax for label, ax in zip(all_labels, axes.flatten())} + axis_map = {label: ax_ for label, ax_ in zip(all_labels, ax.flatten())} for m_idx, plotters in enumerate(to_plot): for var_name, selection, values in plotters: @@ -136,14 +150,13 @@ def plot_density( ) if n_data > 1: - ax = axes.flatten()[0] for m_idx, label in enumerate(data_labels): - ax.plot([], label=label, c=colors[m_idx], markersize=markersize) - ax.legend(fontsize=xt_labelsize) + ax[0].plot([], label=label, c=colors[m_idx], markersize=markersize) + ax[0].legend(fontsize=xt_labelsize) fig.tight_layout() - return axes + return ax def _d_helper( diff --git a/arviz/tests/test_plots.py b/arviz/tests/test_plots.py index d396acd50e..6356d894d3 100644 --- a/arviz/tests/test_plots.py +++ b/arviz/tests/test_plots.py @@ -116,12 +116,11 @@ def test_plot_density_float(models, kwargs): obj = [getattr(models, model_fit) for model_fit in ["pymc3_fit", "stan_fit", "pyro_fit"]] axes = plot_density(obj, **kwargs) assert axes.shape[0] >= 18 - assert axes.shape[1] == 1 def test_plot_density_discrete(discrete_model): axes = plot_density(discrete_model, shade=0.9) - assert axes.shape[1] == 1 + assert axes.shape[0] == 2 @pytest.mark.parametrize("model_fit", ["pymc3_fit", "stan_fit", "pyro_fit"]) From e9a0a5f46e9b632d6d19782a74f691d7b14d78cc Mon Sep 17 00:00:00 2001 From: aloctavodia Date: Fri, 26 Oct 2018 16:56:28 -0300 Subject: [PATCH 2/3] remove unused import --- arviz/plots/densityplot.py | 1 - 1 file changed, 1 deletion(-) diff --git a/arviz/plots/densityplot.py b/arviz/plots/densityplot.py index f81d58c2b1..3a88ed1ddf 100644 --- a/arviz/plots/densityplot.py +++ b/arviz/plots/densityplot.py @@ -1,6 +1,5 @@ """KDE and histogram plots for multiple variables.""" import numpy as np -import matplotlib.pyplot as plt from ..data import convert_to_dataset from ..stats import hpd From 92aeeca71e7f8c7431773e4e13004d87cec3439a Mon Sep 17 00:00:00 2001 From: aloctavodia Date: Sat, 27 Oct 2018 08:47:29 -0300 Subject: [PATCH 3/3] fix tests --- arviz/plots/densityplot.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/arviz/plots/densityplot.py b/arviz/plots/densityplot.py index 3a88ed1ddf..39f18a6ae4 100644 --- a/arviz/plots/densityplot.py +++ b/arviz/plots/densityplot.py @@ -113,7 +113,9 @@ def plot_density( for plotters in to_plot: length_plotters.append(len(plotters)) for var_name, selection, _ in plotters: - all_labels.append(make_label(var_name, selection)) + label = make_label(var_name, selection) + if label not in all_labels: + all_labels.append(label) length_plotters = max(length_plotters) rows, cols = default_grid(length_plotters, max_cols=3) @@ -122,9 +124,7 @@ def plot_density( figsize, textsize, rows, cols ) - fig, ax = _create_axes_grid( - length_plotters, rows, cols, figsize=figsize, squeeze=False, constrained_layout=True - ) + fig, ax = _create_axes_grid(length_plotters, rows, cols, figsize=figsize, squeeze=False) axis_map = {label: ax_ for label, ax_ in zip(all_labels, ax.flatten())}