Skip to content

Commit

Permalink
Merge pull request #111 from ColCarroll/densityplot
Browse files Browse the repository at this point in the history
Update densityplot to use xarray
  • Loading branch information
ColCarroll authored Jun 27, 2018
2 parents 4ba5e57 + dea4912 commit 91b0063
Show file tree
Hide file tree
Showing 6 changed files with 72 additions and 52 deletions.
4 changes: 2 additions & 2 deletions arviz/plots/autocorrplot.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import matplotlib.pyplot as plt
import numpy as np

from .plot_utils import _scale_text, default_grid, selection_to_string, xarray_var_iter
from .plot_utils import _scale_text, default_grid, make_label, xarray_var_iter
from ..utils import convert_to_xarray


Expand Down Expand Up @@ -66,7 +66,7 @@ def autocorrplot(posterior, var_names=None, max_lag=100, symmetric_plot=False, c
ymax=y[midpoint + min_lag:midpoint + max_lag],
lw=linewidth)
ax.hlines(0, min_lag, max_lag, 'steelblue')
ax.set_title('{} ({})'.format(var_name, selection_to_string(selection)), fontsize=textsize)
ax.set_title(make_label(var_name, selection), fontsize=textsize)
ax.tick_params(labelsize=textsize)
y_min = min(y_min, y.min())

Expand Down
88 changes: 44 additions & 44 deletions arviz/plots/densityplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@

from .kdeplot import fast_kde
from ..stats import hpd
from ..utils import trace_to_dataframe, expand_variable_names
from .plot_utils import _scale_text
from ..utils import convert_to_xarray
from .plot_utils import _scale_text, make_label, xarray_var_iter


def densityplot(trace, models=None, varnames=None, alpha=0.05, point_estimate='mean',
def densityplot(data, data_labels=None, var_names=None, alpha=0.05, point_estimate='mean',
colors='cycle', outline=True, hpd_markers='', shade=0., bw=4.5, figsize=None,
textsize=None, skip_first=0):
"""
Expand All @@ -17,11 +17,11 @@ def densityplot(trace, models=None, varnames=None, alpha=0.05, point_estimate='m
Parameters
----------
trace : Pandas DataFrame or PyMC3 trace or list of these objects
Posterior samples
models : list
List with names for the models in the list of traces. Useful when
plotting more that one trace.
data : xarray.Dataset, object that can be converted, or list of these
Posterior samples
data_labels : list[str]
List with names for the samples in the list of datasets. Useful when
plotting more than one trace.
var_names : list
List of variables to plot (defaults to None, which results in all
variables plotted).
Expand Down Expand Up @@ -60,60 +60,60 @@ def densityplot(trace, models=None, varnames=None, alpha=0.05, point_estimate='m
ax : Matplotlib axes
"""
if not isinstance(trace, (list, tuple)):
trace = [trace_to_dataframe(trace[skip_first:], combined=True)]
if not isinstance(data, (list, tuple)):
datasets = [convert_to_xarray(data)]
else:
trace = [trace_to_dataframe(tr[skip_first:], combined=True) for tr in trace]
datasets = [convert_to_xarray(d) for d in data]
datasets = [data.where(data.draw >= skip_first).dropna('draw') for data in datasets]

if point_estimate not in ('mean', 'median', None):
raise ValueError("Point estimate should be 'mean', 'median' or None")
raise ValueError(f"Point estimate should be 'mean', 'median' or None, not {point_estimate}")

length_trace = len(trace)
n_data = len(datasets)

if models is None:
if length_trace > 1:
models = ['m_{}'.format(i) for i in range(length_trace)]
if data_labels is None:
if n_data > 1:
data_labels = [f'{idx}' for idx in range(n_data)]
else:
models = ['']
elif len(models) != length_trace:
raise ValueError(
"The number of names for the models does not match the number of models")

length_models = len(models)
data_labels = ['']
elif len(data_labels) != n_data:
raise ValueError(f'The number of names for the models ({len(data_labels)}) '
f'does not match the number of models ({n_data})')

if colors == 'cycle':
colors = ['C{}'.format(i % 10) for i in range(length_models)]
colors = [f'C{idx % 10}' for idx in range(n_data)]
elif isinstance(colors, str):
colors = [colors for i in range(length_models)]
colors = [colors for _ in range(n_data)]

if varnames is None:
varnames = set.union(*[set(tr.columns) for tr in trace])
else:
varnames = set.union(*[set(expand_variable_names(tr, varnames)) for tr in trace])
to_plot = [list(xarray_var_iter(data, var_names, combined=True)) for data in datasets]
all_labels = set()
for plotters in to_plot:
for var_name, selection, _ in plotters:
all_labels.add(make_label(var_name, selection))

if figsize is None:
figsize = (6, len(varnames) * 2)
figsize = (6, len(all_labels) * 2)

textsize, linewidth, markersize = _scale_text(figsize, textsize=textsize)

fig, dplot = plt.subplots(len(varnames), 1, squeeze=False, figsize=figsize)
dplot = dplot.flatten()
fig, axes = plt.subplots(len(all_labels), 1, squeeze=False, figsize=figsize)
axis_map = {label: ax for label, ax in zip(all_labels, axes.flatten())}

for v_idx, vname in enumerate(varnames):
for t_idx, tr in enumerate(trace):
if vname in tr.columns:
vec = tr[vname].values
_d_helper(vec, vname, colors[t_idx], bw, textsize, linewidth, markersize, alpha,
point_estimate, hpd_markers, outline, shade, dplot[v_idx])
for m_idx, plotters in enumerate(to_plot):
for var_name, selection, values in plotters:
label = make_label(var_name, selection)
_d_helper(values, label, colors[m_idx], bw, textsize, linewidth, markersize,
alpha, point_estimate, hpd_markers, outline, shade, axis_map[label])

if length_trace > 1:
for m_idx, model in enumerate(models):
dplot[0].plot([], label=model, c=colors[m_idx], markersize=markersize)
dplot[0].legend(fontsize=textsize)
if n_data > 1:
ax = axes.flatten()[0]
for m_idx, label in enumerate(data_labels):
ax.plot([], label=label, c=colors[m_idx], markersize=markersize)
ax.legend(fontsize=textsize)

fig.tight_layout()

return dplot
return axes


def _d_helper(vec, vname, color, bw, textsize, linewidth, markersize, alpha,
Expand Down Expand Up @@ -172,8 +172,8 @@ def _d_helper(vec, vname, color, bw, textsize, linewidth, markersize, alpha,
ax.hist(vec, bins=bins, color=color, alpha=shade)

if hpd_markers:
ax.plot(xmin, 0, 'v', color=color, markeredgecolor='k', markersize=markersize)
ax.plot(xmax, 0, 'v', color=color, markeredgecolor='k', markersize=markersize)
ax.plot(xmin, 0, hpd_markers, color=color, markeredgecolor='k', markersize=markersize)
ax.plot(xmax, 0, hpd_markers, color=color, markeredgecolor='k', markersize=markersize)

if point_estimate is not None:
if point_estimate == 'mean':
Expand Down
18 changes: 18 additions & 0 deletions arviz/plots/plot_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,24 @@ def selection_to_string(selection):
return ', '.join(['{}: {}'.format(k, v) for k, v in selection.items()])


def make_label(var_name, selection):
    """Consistent labelling for plots

    Parameters
    ----------
    var_name : str
        Name of the variable
    selection : dict[Any] -> Any
        Coordinates of the variable. May be empty (or None) when there are
        no coordinates to report.

    Returns
    -------
    str
        A text representation of the label: ``var_name (key: value, ...)``
        when there is a selection, otherwise just ``var_name``.
    """
    # Without coordinates the parenthesised suffix would be an empty,
    # dangling "()" in every plot title, so emit the bare name instead.
    if not selection:
        return f'{var_name}'
    return f'{var_name} ({selection_to_string(selection)})'


def xarray_var_iter(data, var_names=None, combined=False):
"""Converts xarray data to an iterator over vectors
Expand Down
4 changes: 2 additions & 2 deletions arviz/tests/test_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ def setup_class(cls):


def test_density_plot(self):
assert densityplot(self.df_trace).shape == (1,)
assert densityplot(self.short_trace).shape == (18,)
for obj in (self.short_trace, self.fit):
assert densityplot(obj).shape == (18, 1)

def test_traceplot(self):
assert traceplot(self.df_trace).shape == (1, 2)
Expand Down
4 changes: 2 additions & 2 deletions doc/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,9 @@ Contributions and issue reports are very welcome at `the github repository
<img src="_static/traceplot_thumb.png">
</div>
</a>
<a href="examples/energyplot.html">
<a href="examples/densityplot.html">
<div class="col-md-4 thumbnail">
<img src="_static/energyplot_thumb.png">
<img src="_static/densityplot_thumb.png">
</div>
</a>
<a href="examples/jointplot.html">
<div class="col-md-4 thumbnail">
Expand Down
6 changes: 4 additions & 2 deletions examples/densityplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,7 @@

az.style.use('arviz-darkgrid')

trace = az.utils.load_trace('data/centered_eight_trace.gzip')
az.densityplot(trace, varnames=('tau', 'theta__0'))
centered_data = az.load_data('data/centered_eight.nc')
non_centered_data = az.load_data('data/non_centered_eight.nc')
az.densityplot([centered_data, non_centered_data], ['Centered', 'Non Centered'],
var_names=['theta'], shade=0.1, alpha=0.01)

0 comments on commit 91b0063

Please sign in to comment.