diff --git a/arviz/plots/autocorrplot.py b/arviz/plots/autocorrplot.py index 572981322d..ba4141857d 100644 --- a/arviz/plots/autocorrplot.py +++ b/arviz/plots/autocorrplot.py @@ -1,7 +1,7 @@ import matplotlib.pyplot as plt import numpy as np -from .plot_utils import _scale_text, default_grid, selection_to_string, xarray_var_iter +from .plot_utils import _scale_text, default_grid, make_label, xarray_var_iter from ..utils import convert_to_xarray @@ -66,7 +66,7 @@ def autocorrplot(posterior, var_names=None, max_lag=100, symmetric_plot=False, c ymax=y[midpoint + min_lag:midpoint + max_lag], lw=linewidth) ax.hlines(0, min_lag, max_lag, 'steelblue') - ax.set_title('{} ({})'.format(var_name, selection_to_string(selection)), fontsize=textsize) + ax.set_title(make_label(var_name, selection), fontsize=textsize) ax.tick_params(labelsize=textsize) y_min = min(y_min, y.min()) diff --git a/arviz/plots/densityplot.py b/arviz/plots/densityplot.py index 69c240ecd4..0a563c2ed8 100644 --- a/arviz/plots/densityplot.py +++ b/arviz/plots/densityplot.py @@ -3,11 +3,11 @@ from .kdeplot import fast_kde from ..stats import hpd -from ..utils import trace_to_dataframe, expand_variable_names -from .plot_utils import _scale_text +from ..utils import convert_to_xarray +from .plot_utils import _scale_text, make_label, xarray_var_iter -def densityplot(trace, models=None, varnames=None, alpha=0.05, point_estimate='mean', +def densityplot(data, data_labels=None, var_names=None, alpha=0.05, point_estimate='mean', colors='cycle', outline=True, hpd_markers='', shade=0., bw=4.5, figsize=None, textsize=None, skip_first=0): """ @@ -17,11 +17,11 @@ def densityplot(trace, models=None, varnames=None, alpha=0.05, point_estimate='m Parameters ---------- - trace : Pandas DataFrame or PyMC3 trace or list of these objects - Posterior samples - models : list - List with names for the models in the list of traces. Useful when - plotting more that one trace. + data : xarray.Dataset, object that can be converted, or list of these + Posterior samples + data_labels : list[str] + List with names for the samples in the list of datasets. Useful when + plotting more than one trace. varnames: list List of variables to plot (defaults to None, which results in all variables plotted). @@ -60,60 +60,60 @@ def densityplot(trace, models=None, varnames=None, alpha=0.05, point_estimate='m ax : Matplotlib axes """ - if not isinstance(trace, (list, tuple)): - trace = [trace_to_dataframe(trace[skip_first:], combined=True)] + if not isinstance(data, (list, tuple)): + datasets = [convert_to_xarray(data)] else: - trace = [trace_to_dataframe(tr[skip_first:], combined=True) for tr in trace] + datasets = [convert_to_xarray(d) for d in data] + datasets = [data.where(data.draw >= skip_first).dropna('draw') for data in datasets] if point_estimate not in ('mean', 'median', None): - raise ValueError("Point estimate should be 'mean', 'median' or None") + raise ValueError(f"Point estimate should be 'mean', 'median' or None, not {point_estimate}") - length_trace = len(trace) + n_data = len(datasets) - if models is None: - if length_trace > 1: - models = ['m_{}'.format(i) for i in range(length_trace)] + if data_labels is None: + if n_data > 1: + data_labels = [f'{idx}' for idx in range(n_data)] else: - models = [''] - elif len(models) != length_trace: - raise ValueError( - "The number of names for the models does not match the number of models") - - length_models = len(models) + data_labels = [''] + elif len(data_labels) != n_data: + raise ValueError(f'The number of names for the models ({len(data_labels)}) ' + f'does not match the number of models ({n_data})') if colors == 'cycle': - colors = ['C{}'.format(i % 10) for i in range(length_models)] + colors = [f'C{idx % 10}' for idx in range(n_data)] elif isinstance(colors, str): - colors = [colors for i in range(length_models)] + colors = [colors for _ in range(n_data)] - if varnames is None: - varnames = set.union(*[set(tr.columns) for tr in trace]) - else: - varnames = set.union(*[set(expand_variable_names(tr, varnames)) for tr in trace]) + to_plot = [list(xarray_var_iter(data, var_names, combined=True)) for data in datasets] + all_labels = set() + for plotters in to_plot: + for var_name, selection, _ in plotters: + all_labels.add(make_label(var_name, selection)) if figsize is None: - figsize = (6, len(varnames) * 2) + figsize = (6, len(all_labels) * 2) textsize, linewidth, markersize = _scale_text(figsize, textsize=textsize) - fig, dplot = plt.subplots(len(varnames), 1, squeeze=False, figsize=figsize) - dplot = dplot.flatten() + fig, axes = plt.subplots(len(all_labels), 1, squeeze=False, figsize=figsize) + axis_map = {label: ax for label, ax in zip(all_labels, axes.flatten())} - for v_idx, vname in enumerate(varnames): - for t_idx, tr in enumerate(trace): - if vname in tr.columns: - vec = tr[vname].values - _d_helper(vec, vname, colors[t_idx], bw, textsize, linewidth, markersize, alpha, - point_estimate, hpd_markers, outline, shade, dplot[v_idx]) + for m_idx, plotters in enumerate(to_plot): + for var_name, selection, values in plotters: + label = make_label(var_name, selection) + _d_helper(values, label, colors[m_idx], bw, textsize, linewidth, markersize, + alpha, point_estimate, hpd_markers, outline, shade, axis_map[label]) - if length_trace > 1: - for m_idx, model in enumerate(models): - dplot[0].plot([], label=model, c=colors[m_idx], markersize=markersize) - dplot[0].legend(fontsize=textsize) + if n_data > 1: + ax = axes.flatten()[0] + for m_idx, label in enumerate(data_labels): + ax.plot([], label=label, c=colors[m_idx], markersize=markersize) + ax.legend(fontsize=textsize) fig.tight_layout() - return dplot + return axes def _d_helper(vec, vname, color, bw, textsize, linewidth, markersize, alpha, @@ -172,8 +172,8 @@ def _d_helper(vec, vname, color, bw, textsize, linewidth, markersize, alpha, ax.hist(vec, bins=bins, color=color, alpha=shade) if hpd_markers: - ax.plot(xmin, 0, 'v', color=color, markeredgecolor='k', markersize=markersize) - ax.plot(xmax, 0, 'v', color=color, markeredgecolor='k', markersize=markersize) + ax.plot(xmin, 0, hpd_markers, color=color, markeredgecolor='k', markersize=markersize) + ax.plot(xmax, 0, hpd_markers, color=color, markeredgecolor='k', markersize=markersize) if point_estimate is not None: if point_estimate == 'mean': diff --git a/arviz/plots/plot_utils.py b/arviz/plots/plot_utils.py index 01372866b2..d6a8f51157 100644 --- a/arviz/plots/plot_utils.py +++ b/arviz/plots/plot_utils.py @@ -145,6 +145,24 @@ def selection_to_string(selection): return ', '.join(['{}: {}'.format(k, v) for k, v in selection.items()]) +def make_label(var_name, selection): + """Consistent labelling for plots + + Parameters + ---------- + var_name : str + Name of the variable + + selection : dict[Any] -> Any + Coordinates of the variable + + Returns + ------- + str + A text representation of the label + """ + return f'{var_name} ({selection_to_string(selection)})' + def xarray_var_iter(data, var_names=None, combined=False): """Converts xarray data to an iterator over vectors diff --git a/arviz/tests/test_plots.py b/arviz/tests/test_plots.py index 588013c06b..dc1c70e05c 100644 --- a/arviz/tests/test_plots.py +++ b/arviz/tests/test_plots.py @@ -21,8 +21,8 @@ def setup_class(cls): def test_density_plot(self): - assert densityplot(self.df_trace).shape == (1,) - assert densityplot(self.short_trace).shape == (18,) + for obj in (self.short_trace, self.fit): + assert densityplot(obj).shape == (18, 1) def test_traceplot(self): assert traceplot(self.df_trace).shape == (1, 2) diff --git a/doc/index.rst b/doc/index.rst index 49a28e4fc0..d93b7b816b 100644 --- a/doc/index.rst +++ b/doc/index.rst @@ -49,9 +49,9 @@ Contributions and issue reports are very welcome at `the github repository - +
- +
diff --git a/examples/densityplot.py b/examples/densityplot.py index b5f3ebe90c..834e98c3eb 100644 --- a/examples/densityplot.py +++ b/examples/densityplot.py @@ -8,5 +8,7 @@ az.style.use('arviz-darkgrid') -trace = az.utils.load_trace('data/centered_eight_trace.gzip') -az.densityplot(trace, varnames=('tau', 'theta__0')) +centered_data = az.load_data('data/centered_eight.nc') +non_centered_data = az.load_data('data/non_centered_eight.nc') +az.densityplot([centered_data, non_centered_data], ['Centered', 'Non Centered'], + var_names=['theta'], shade=0.1, alpha=0.01)