Skip to content

Commit

Permalink
density grid (#379)
Browse files Browse the repository at this point in the history
* density grid

* remove unused import

* fix tests
  • Loading branch information
aloctavodia authored and ahartikainen committed Oct 27, 2018
1 parent 8e92503 commit 77fc227
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 13 deletions.
34 changes: 23 additions & 11 deletions arviz/plots/densityplot.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
"""KDE and histogram plots for multiple variables."""
import numpy as np
import matplotlib.pyplot as plt

from ..data import convert_to_dataset
from ..stats import hpd
from .kdeplot import _fast_kde
from .plot_utils import _scale_fig_size, make_label, xarray_var_iter
from .plot_utils import (
_scale_fig_size,
make_label,
xarray_var_iter,
default_grid,
_create_axes_grid,
)
from ..utils import _var_names


Expand Down Expand Up @@ -103,17 +108,25 @@ def plot_density(
colors = [colors for _ in range(n_data)]

to_plot = [list(xarray_var_iter(data, var_names, combined=True)) for data in datasets]
all_labels = set()
all_labels = []
length_plotters = []
for plotters in to_plot:
length_plotters.append(len(plotters))
for var_name, selection, _ in plotters:
all_labels.add(make_label(var_name, selection))
label = make_label(var_name, selection)
if label not in all_labels:
all_labels.append(label)

length_plotters = max(length_plotters)
rows, cols = default_grid(length_plotters, max_cols=3)

(figsize, _, titlesize, xt_labelsize, linewidth, markersize) = _scale_fig_size(
figsize, textsize, len(all_labels), 1
figsize, textsize, rows, cols
)

fig, axes = plt.subplots(len(all_labels), 1, squeeze=False, figsize=figsize)
axis_map = {label: ax for label, ax in zip(all_labels, axes.flatten())}
fig, ax = _create_axes_grid(length_plotters, rows, cols, figsize=figsize, squeeze=False)

axis_map = {label: ax_ for label, ax_ in zip(all_labels, ax.flatten())}

for m_idx, plotters in enumerate(to_plot):
for var_name, selection, values in plotters:
Expand All @@ -136,14 +149,13 @@ def plot_density(
)

if n_data > 1:
ax = axes.flatten()[0]
for m_idx, label in enumerate(data_labels):
ax.plot([], label=label, c=colors[m_idx], markersize=markersize)
ax.legend(fontsize=xt_labelsize)
ax[0].plot([], label=label, c=colors[m_idx], markersize=markersize)
ax[0].legend(fontsize=xt_labelsize)

fig.tight_layout()

return axes
return ax


def _d_helper(
Expand Down
3 changes: 1 addition & 2 deletions arviz/tests/test_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,12 +116,11 @@ def test_plot_density_float(models, kwargs):
obj = [getattr(models, model_fit) for model_fit in ["pymc3_fit", "stan_fit", "pyro_fit"]]
axes = plot_density(obj, **kwargs)
assert axes.shape[0] >= 18
assert axes.shape[1] == 1


def test_plot_density_discrete(discrete_model):
axes = plot_density(discrete_model, shade=0.9)
assert axes.shape[1] == 1
assert axes.shape[0] == 2


@pytest.mark.parametrize("model_fit", ["pymc3_fit", "stan_fit", "pyro_fit"])
Expand Down

0 comments on commit 77fc227

Please sign in to comment.