Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Traceplot legend #1070

Merged
merged 13 commits into from
Mar 1, 2020
7 changes: 4 additions & 3 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,12 @@
and `essplot` (#1024)
* New defaults for cross validation: `loo` (old: waic) and `log` -scale (old: `deviance` -scale) (#1067)
* **Experimental Feature**: Added `arviz.wrappers` module to allow ArviZ to
refit the models if necessary
* **Experimental Feature**: Added `reloo` function to ArviZ
refit the models if necessary (#771)
* **Experimental Feature**: Added `reloo` function to ArviZ (#771)
* ArviZ version to InferenceData attributes. (#1086)
* Add `log_likelihood` argument to `from_pymc3`
* Add `log_likelihood` argument to `from_pymc3` (#1082)
* Integrated rcParams for `plot.bokeh.layout` and `plot.backend`. (#1089)
* Add automatic legends in `plot_trace` with compact=True (matplotlib only) (#1070)


### Maintenance and fixes
Expand Down
46 changes: 26 additions & 20 deletions arviz/plots/backends/bokeh/traceplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def plot_trace(
rug,
lines,
combined,
chain_prop,
legend,
plot_kwargs: [Dict],
fill_kwargs: [Dict],
Expand All @@ -31,7 +32,7 @@ def plot_trace(
trace_kwargs: [Dict],
plotters,
divergence_data,
colors,
axes,
backend_config,
backend_kwargs: [Dict],
show,
Expand Down Expand Up @@ -67,16 +68,17 @@ def plot_trace(
trace_kwargs.setdefault("line_width", linewidth)
plot_kwargs.setdefault("line_width", linewidth)

axes = []
for i in range(len(plotters)):
if i != 0:
_axes = [
bkp.figure(**backend_kwargs),
bkp.figure(x_range=axes[0][1].x_range, **backend_kwargs),
]
else:
_axes = [bkp.figure(**backend_kwargs), bkp.figure(**backend_kwargs)]
axes.append(_axes)
if axes is None:
axes = []
for i in range(len(plotters)):
if i != 0:
_axes = [
bkp.figure(**backend_kwargs),
bkp.figure(x_range=axes[0][1].x_range, **backend_kwargs),
]
else:
_axes = [bkp.figure(**backend_kwargs), bkp.figure(**backend_kwargs)]
axes.append(_axes)

axes = np.array(axes)

Expand Down Expand Up @@ -149,7 +151,7 @@ def plot_trace(
data=cds_data,
x_name=draw_name,
y_name=y_name,
colors=colors,
chain_prop=chain_prop,
combined=combined,
rug=rug,
legend=legend,
Expand All @@ -169,7 +171,7 @@ def plot_trace(
data=cds_data,
x_name=draw_name,
y_name=y_name,
colors=colors,
chain_prop=chain_prop,
combined=combined,
rug=rug,
legend=legend,
Expand Down Expand Up @@ -267,7 +269,7 @@ def _plot_chains_bokeh(
data,
x_name,
y_name,
colors,
chain_prop,
combined,
rug,
legend,
Expand All @@ -282,27 +284,29 @@ def _plot_chains_bokeh(
if legend:
trace_kwargs["legend_label"] = "chain {}".format(chain_idx)
ax_trace.line(
x=x_name, y=y_name, source=cds, line_color=colors[chain_idx], **trace_kwargs,
x=x_name,
y=y_name,
source=cds,
**{chain_prop[0]: chain_prop[1][chain_idx]},
**trace_kwargs,
)
if marker:
ax_trace.circle(
x=x_name,
y=y_name,
source=cds,
radius=0.30,
line_color=colors[chain_idx],
fill_color=colors[chain_idx],
alpha=0.5,
**{chain_prop[0]: chain_prop[1][chain_idx],},
)
if not combined:
rug_kwargs["cds"] = cds
if legend:
plot_kwargs["legend_label"] = "chain {}".format(chain_idx)
plot_kwargs["line_color"] = colors[chain_idx]
plot_kwargs[chain_prop[0]] = chain_prop[1][chain_idx]
plot_dist(
cds.data[y_name],
ax=ax_density,
color=colors[chain_idx],
rug=rug,
hist_kwargs=hist_kwargs,
plot_kwargs=plot_kwargs,
Expand All @@ -312,15 +316,16 @@ def _plot_chains_bokeh(
backend_kwargs={},
show=False,
)
plot_kwargs.pop(chain_prop[0])

if combined:
rug_kwargs["cds"] = data
if legend:
plot_kwargs["legend_label"] = "combined chains"
plot_kwargs[chain_prop[0]] = chain_prop[1][-1]
plot_dist(
np.concatenate([item.data[y_name] for item in data.values()]).flatten(),
ax=ax_density,
color=colors[-1],
rug=rug,
hist_kwargs=hist_kwargs,
plot_kwargs=plot_kwargs,
Expand All @@ -330,3 +335,4 @@ def _plot_chains_bokeh(
backend_kwargs={},
show=False,
)
plot_kwargs.pop(chain_prop[0])
72 changes: 59 additions & 13 deletions arviz/plots/backends/matplotlib/traceplot.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Matplotlib traceplot."""
from itertools import cycle

import warnings
import matplotlib.pyplot as plt
Expand All @@ -7,7 +8,7 @@

from . import backend_kwarg_defaults, backend_show
from ...distplot import plot_dist
from ...plot_utils import _scale_fig_size, get_bins, make_label
from ...plot_utils import _scale_fig_size, get_bins, make_label, format_coords_as_labels


def plot_trace(
Expand All @@ -17,7 +18,9 @@ def plot_trace(
figsize,
rug,
lines,
compact_prop,
combined,
chain_prop,
legend,
plot_kwargs,
fill_kwargs,
Expand All @@ -26,7 +29,7 @@ def plot_trace(
trace_kwargs,
plotters,
divergence_data,
colors,
axes,
backend_kwargs,
show,
):
Expand Down Expand Up @@ -123,7 +126,8 @@ def plot_trace(
trace_kwargs.setdefault("linewidth", linewidth)
plot_kwargs.setdefault("linewidth", linewidth)

_, axes = plt.subplots(len(plotters), 2, squeeze=False, figsize=figsize, **backend_kwargs)
if axes is None:
_, axes = plt.subplots(len(plotters), 2, squeeze=False, figsize=figsize, **backend_kwargs)

# Check the input for lines
if lines is not None:
Expand All @@ -144,12 +148,15 @@ def plot_trace(
value = np.atleast_2d(value)

if len(value.shape) == 2:
if compact_prop:
plot_kwargs[compact_prop[0]] = compact_prop[1][0]
trace_kwargs[compact_prop[0]] = compact_prop[1][0]
_plot_chains_mpl(
axes,
idx,
value,
data,
colors,
chain_prop,
combined,
xt_labelsize,
rug,
Expand All @@ -159,15 +166,34 @@ def plot_trace(
fill_kwargs,
rug_kwargs,
)
if compact_prop:
plot_kwargs.pop(compact_prop[0])
trace_kwargs.pop(compact_prop[0])
else:
sub_data = data[var_name].sel(**selection)
legend_labels = format_coords_as_labels(sub_data, skip_dims=("chain", "draw"))
legend_title = ", ".join(
[
"{}".format(coord_name)
for coord_name in sub_data.coords
if coord_name not in {"chain", "draw"}
]
)
value = value.reshape((value.shape[0], value.shape[1], -1))
for sub_idx in range(value.shape[2]):
compact_prop_cycle = cycle(compact_prop[1])
handles = []
for sub_idx, label, prop in zip(
range(value.shape[2]), legend_labels, compact_prop_cycle
):
if compact_prop:
plot_kwargs[compact_prop[0]] = prop
trace_kwargs[compact_prop[0]] = prop
_plot_chains_mpl(
axes,
idx,
value[..., sub_idx],
data,
colors,
chain_prop,
combined,
xt_labelsize,
rug,
Expand All @@ -177,6 +203,16 @@ def plot_trace(
fill_kwargs,
rug_kwargs,
)
if legend:
handles.append(
Line2D(
[], [], label=label, **{chain_prop[0]: chain_prop[1][0]}, **plot_kwargs
)
)
if legend:
axes[idx, 0].legend(handles=handles, title=legend_title)
plot_kwargs.pop(compact_prop[0], None)
trace_kwargs.pop(compact_prop[0], None)

if value[0].dtype.kind == "i":
xticks = get_bins(value)
Expand Down Expand Up @@ -247,12 +283,18 @@ def plot_trace(
axes[idx, 1].set_xlim(left=data.draw.min(), right=data.draw.max())
axes[idx, 1].set_ylim(*ylims[1])
if legend:
legend_kwargs = trace_kwargs if combined else plot_kwargs
handles = [
Line2D([], [], color=color, label=chain_id)
for chain_id, color in zip(data.chain.values, colors)
Line2D([], [], label=chain_id, **{chain_prop[0]: prop}, **legend_kwargs)
for chain_id, prop in zip(data.chain.values, chain_prop[1])
]
if combined:
handles.insert(0, Line2D([], [], color=colors[-1], label="combined"))
handles.insert(
0,
Line2D(
[], [], label="combined", **{chain_prop[0]: chain_prop[1][-1]}, **plot_kwargs
),
)
axes[0, 1].legend(handles=handles, title="chain")

if backend_show(show):
Expand All @@ -266,7 +308,7 @@ def _plot_chains_mpl(
idx,
value,
data,
colors,
chain_prop,
combined,
xt_labelsize,
rug,
Expand All @@ -277,10 +319,12 @@ def _plot_chains_mpl(
rug_kwargs,
):
for chain_idx, row in enumerate(value):
axes[idx, 1].plot(data.draw.values, row, color=colors[chain_idx], **trace_kwargs)
axes[idx, 1].plot(
data.draw.values, row, **{chain_prop[0]: chain_prop[1][chain_idx]}, **trace_kwargs
)

if not combined:
plot_kwargs["color"] = colors[chain_idx]
plot_kwargs[chain_prop[0]] = chain_prop[1][chain_idx]
plot_dist(
values=row,
textsize=xt_labelsize,
Expand All @@ -293,9 +337,10 @@ def _plot_chains_mpl(
backend="matplotlib",
show=False,
)
plot_kwargs.pop(chain_prop[0])

if combined:
plot_kwargs["color"] = colors[-1]
plot_kwargs[chain_prop[0]] = chain_prop[1][-1]
plot_dist(
values=value.flatten(),
textsize=xt_labelsize,
Expand All @@ -308,3 +353,4 @@ def _plot_chains_mpl(
backend="matplotlib",
show=False,
)
plot_kwargs.pop(chain_prop[0])
21 changes: 18 additions & 3 deletions arviz/plots/plot_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Utilities for plotting."""
import warnings
from typing import Dict, Any
from itertools import product, tee
import importlib
from scipy.signal import convolve, convolve2d
Expand All @@ -17,6 +18,8 @@
from ..utils import conditional_jit, _stack
from ..rcparams import rcParams

KwargSpec = Dict[str, Any]


def make_2d(ary):
"""Convert any array into a 2d numpy array.
Expand Down Expand Up @@ -599,9 +602,21 @@ def color_from_dim(dataarray, dim_name):
return colors, color_mapping


def format_coords_as_labels(dataarray):
"""Format 1d or multi-d dataarray coords as strings."""
coord_labels = dataarray.coords.to_index().values
def format_coords_as_labels(dataarray, skip_dims=None):
"""Format 1d or multi-d dataarray coords as strings.

Parameters
----------
dataarray : xarray.DataArray
DataArray whose coordinates will be converted to labels.
skip_dims : str of list_like, optional
Dimensions whose values should not be included in the labels
"""
if skip_dims is None:
coord_labels = dataarray.coords.to_index()
else:
coord_labels = dataarray.coords.to_index().droplevel(skip_dims).drop_duplicates()
coord_labels = coord_labels.values
if isinstance(coord_labels[0], tuple):
fmt = ", ".join(["{}" for _ in coord_labels[0]])
coord_labels[:] = [fmt.format(*x) for x in coord_labels]
Expand Down
Loading