Skip to content

Commit

Permalink
add option to display rank plots instead of trace (#1134)
Browse files Browse the repository at this point in the history
* add option to display rank plots instead of trace

* update changelog and add galery examples
  • Loading branch information
aloctavodia committed Apr 6, 2020
1 parent afaec70 commit 6383820
Show file tree
Hide file tree
Showing 10 changed files with 156 additions and 38 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@
## v0.x.x Unreleased

### New features
* Integrate jointplot into pairplot, add point-estimate and overlay of plot kinds #1079
* Add out-of-sample groups (`predictions` and `predictions_constant_data`) and `constant_data` group to pyro translation #1090
* Add `num_chains` and `pred_dims` arguments to io_pyro #1090
* Integrate jointplot into pairplot, add point-estimate and overlay of plot kinds (#1079)
* Allow xarray.Dataarray input for plots.(#1120)
* Revamped the `hpd` function to make it work with mutidimensional arrays, InferenceData and xarray objects (#1117)
* Skip test for optional/extra dependencies when not installed (#1113)
* Add option to display rank plots instead of trace (#1134)
### Maintenance and fixes
* Fixed behaviour of `credible_interval=None` in `plot_posterior` (#1115)
* Fixed hist kind of `plot_dist` with multidimensional input (#1115)
Expand Down
48 changes: 34 additions & 14 deletions arviz/plots/backends/bokeh/traceplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from . import backend_kwarg_defaults
from .. import show_layout
from ...distplot import plot_dist
from ...rankplot import plot_rank
from ...plot_utils import xarray_var_iter, make_label, _scale_fig_size
from ....rcparams import rcParams

Expand All @@ -19,6 +20,7 @@ def plot_trace(
data,
var_names,
divergences,
kind,
figsize,
rug,
lines,
Expand All @@ -30,6 +32,7 @@ def plot_trace(
rug_kwargs: [Dict],
hist_kwargs: [Dict],
trace_kwargs: [Dict],
rank_kwargs: [Dict],
plotters,
divergence_data,
axes,
Expand Down Expand Up @@ -68,6 +71,9 @@ def plot_trace(
trace_kwargs.setdefault("line_width", linewidth)
plot_kwargs.setdefault("line_width", linewidth)

if rank_kwargs is None:
rank_kwargs = {}

if axes is None:
axes = []
for i in range(len(plotters)):
Expand Down Expand Up @@ -154,12 +160,14 @@ def plot_trace(
chain_prop=chain_prop,
combined=combined,
rug=rug,
kind=kind,
legend=legend,
trace_kwargs=trace_kwargs,
hist_kwargs=hist_kwargs,
plot_kwargs=plot_kwargs,
fill_kwargs=fill_kwargs,
rug_kwargs=rug_kwargs,
rank_kwargs=rank_kwargs,
)
else:
for y_name in cds_var_groups[var_name]:
Expand All @@ -174,12 +182,14 @@ def plot_trace(
chain_prop=chain_prop,
combined=combined,
rug=rug,
kind=kind,
legend=legend,
trace_kwargs=trace_kwargs,
hist_kwargs=hist_kwargs,
plot_kwargs=plot_kwargs,
fill_kwargs=fill_kwargs,
rug_kwargs=rug_kwargs,
rank_kwargs=rank_kwargs,
)

for col in (0, 1):
Expand Down Expand Up @@ -272,33 +282,36 @@ def _plot_chains_bokeh(
chain_prop,
combined,
rug,
kind,
legend,
trace_kwargs,
hist_kwargs,
plot_kwargs,
fill_kwargs,
rug_kwargs,
rank_kwargs,
):
marker = trace_kwargs.pop("marker", True)
for chain_idx, cds in data.items():
if legend:
trace_kwargs["legend_label"] = "chain {}".format(chain_idx)
ax_trace.line(
x=x_name,
y=y_name,
source=cds,
**{chain_prop[0]: chain_prop[1][chain_idx]},
**trace_kwargs,
)
if marker:
ax_trace.circle(
if kind == "trace":
if legend:
trace_kwargs["legend_label"] = "chain {}".format(chain_idx)
ax_trace.line(
x=x_name,
y=y_name,
source=cds,
radius=0.30,
alpha=0.5,
**{chain_prop[0]: chain_prop[1][chain_idx],},
**{chain_prop[0]: chain_prop[1][chain_idx]},
**trace_kwargs,
)
if marker:
ax_trace.circle(
x=x_name,
y=y_name,
source=cds,
radius=0.30,
alpha=0.5,
**{chain_prop[0]: chain_prop[1][chain_idx],},
)
if not combined:
rug_kwargs["cds"] = cds
if legend:
Expand All @@ -318,6 +331,13 @@ def _plot_chains_bokeh(
)
plot_kwargs.pop(chain_prop[0])

if kind == "rank_bars":
value = np.array([item.data[y_name] for item in data.values()])
plot_rank(value, kind="bars", axes=ax_trace, backend="bokeh", show=False, **rank_kwargs)
elif kind == "rank_vlines":
value = np.array([item.data[y_name] for item in data.values()])
plot_rank(value, kind="vlines", axes=ax_trace, backend="bokeh", show=False, **rank_kwargs)

if combined:
rug_kwargs["cds"] = data
if legend:
Expand Down
68 changes: 47 additions & 21 deletions arviz/plots/backends/matplotlib/traceplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,15 @@

from . import backend_kwarg_defaults, backend_show
from ...distplot import plot_dist
from ...rankplot import plot_rank
from ...plot_utils import _scale_fig_size, get_bins, make_label, format_coords_as_labels


def plot_trace(
data,
var_names, # pylint: disable=unused-argument
divergences,
kind,
figsize,
rug,
lines,
Expand All @@ -27,6 +29,7 @@ def plot_trace(
rug_kwargs,
hist_kwargs,
trace_kwargs,
rank_kwargs,
plotters,
divergence_data,
axes,
Expand All @@ -47,6 +50,8 @@ def plot_trace(
One or more variables to be plotted.
divergences : {"bottom", "top", None, False}
Plot location of divergences on the traceplots. Options are "bottom", "top", or False-y.
kind : {"trace", "rank_bar", "rank_vlines"}, optional
Choose between plotting sampled values per iteration and rank plots.
figsize : figure size tuple
If None, size is (12, variables * 2)
rug : bool
Expand All @@ -70,6 +75,8 @@ def plot_trace(
Extra keyword arguments passed to `arviz.plot_dist`. Only affects discrete variables.
trace_kwargs : dict
Extra keyword arguments passed to `plt.plot`
rank_kwargs : dict
Extra keyword arguments passed to `arviz.plot_rank`
Returns
-------
axes : matplotlib axes
Expand Down Expand Up @@ -160,11 +167,13 @@ def plot_trace(
combined,
xt_labelsize,
rug,
kind,
trace_kwargs,
hist_kwargs,
plot_kwargs,
fill_kwargs,
rug_kwargs,
rank_kwargs,
)
if compact_prop:
plot_kwargs.pop(compact_prop[0])
Expand Down Expand Up @@ -197,11 +206,13 @@ def plot_trace(
combined,
xt_labelsize,
rug,
kind,
trace_kwargs,
hist_kwargs,
plot_kwargs,
fill_kwargs,
rug_kwargs,
rank_kwargs,
)
if legend:
handles.append(
Expand Down Expand Up @@ -241,18 +252,19 @@ def plot_trace(
else:
ylocs = [ylim[0] for ylim in ylims]
values = value[chain, div_idxs]
axes[idx, 1].plot(
div_draws,
np.zeros_like(div_idxs) + ylocs[1],
marker="|",
color="black",
markeredgewidth=1.5,
markersize=30,
linestyle="None",
alpha=hist_kwargs["alpha"],
zorder=-5,
)
axes[idx, 1].set_ylim(*ylims[1])
if kind == "trace":
axes[idx, 1].plot(
div_draws,
np.zeros_like(div_idxs) + ylocs[1],
marker="|",
color="black",
markeredgewidth=1.5,
markersize=30,
linestyle="None",
alpha=hist_kwargs["alpha"],
zorder=-5,
)
axes[idx, 1].set_ylim(*ylims[1])
axes[idx, 0].plot(
values,
np.zeros_like(values) + ylocs[0],
Expand All @@ -276,12 +288,18 @@ def plot_trace(
"line-positions should be numeric, found {}".format(line_values)
)
axes[idx, 0].vlines(line_values, *ylims[0], colors="black", linewidth=1.5, alpha=0.75)
axes[idx, 1].hlines(
line_values, *xlims[1], colors="black", linewidth=1.5, alpha=trace_kwargs["alpha"]
)
if kind == "trace":
axes[idx, 1].hlines(
line_values,
*xlims[1],
colors="black",
linewidth=1.5,
alpha=trace_kwargs["alpha"]
)
axes[idx, 0].set_ylim(bottom=0, top=ylims[0][1])
axes[idx, 1].set_xlim(left=data.draw.min(), right=data.draw.max())
axes[idx, 1].set_ylim(*ylims[1])
if kind == "trace":
axes[idx, 1].set_xlim(left=data.draw.min(), right=data.draw.max())
axes[idx, 1].set_ylim(*ylims[1])
if legend:
legend_kwargs = trace_kwargs if combined else plot_kwargs
handles = [
Expand All @@ -295,7 +313,7 @@ def plot_trace(
[], [], label="combined", **{chain_prop[0]: chain_prop[1][-1]}, **plot_kwargs
),
)
axes[0, 1].legend(handles=handles, title="chain")
axes[0, 0].legend(handles=handles, title="chain")

if backend_show(show):
plt.show()
Expand All @@ -312,16 +330,19 @@ def _plot_chains_mpl(
combined,
xt_labelsize,
rug,
kind,
trace_kwargs,
hist_kwargs,
plot_kwargs,
fill_kwargs,
rug_kwargs,
rank_kwargs,
):
for chain_idx, row in enumerate(value):
axes[idx, 1].plot(
data.draw.values, row, **{chain_prop[0]: chain_prop[1][chain_idx]}, **trace_kwargs
)
if kind == "trace":
axes[idx, 1].plot(
data.draw.values, row, **{chain_prop[0]: chain_prop[1][chain_idx]}, **trace_kwargs
)

if not combined:
plot_kwargs[chain_prop[0]] = chain_prop[1][chain_idx]
Expand All @@ -339,6 +360,11 @@ def _plot_chains_mpl(
)
plot_kwargs.pop(chain_prop[0])

if kind == "rank_bars":
plot_rank(data=value, kind="bars", axes=axes[idx, 1], **rank_kwargs)
elif kind == "rank_vlines":
plot_rank(data=value, kind="vlines", axes=axes[idx, 1], **rank_kwargs)

if combined:
plot_kwargs[chain_prop[0]] = chain_prop[1][-1]
plot_dist(
Expand Down
20 changes: 19 additions & 1 deletion arviz/plots/traceplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def plot_trace(
transform: Optional[Callable] = None,
coords: Optional[CoordSpec] = None,
divergences: Optional[str] = "bottom",
kind: Optional[str] = "trace",
figsize: Optional[Tuple[float, float]] = None,
rug: bool = False,
lines: Optional[List[Tuple[str, CoordSpec, Any]]] = None,
Expand All @@ -36,13 +37,14 @@ def plot_trace(
rug_kwargs: Optional[KwargSpec] = None,
hist_kwargs: Optional[KwargSpec] = None,
trace_kwargs: Optional[KwargSpec] = None,
rank_kwargs: Optional[KwargSpec] = None,
ax=None,
backend: Optional[str] = None,
backend_config: Optional[KwargSpec] = None,
backend_kwargs: Optional[KwargSpec] = None,
show: Optional[bool] = None,
):
"""Plot distribution (histogram or kernel density estimates) and sampled values.
"""Plot distribution (histogram or kernel density estimates) and sampled values or rank plot.
If `divergences` data is available in `sample_stats`, will plot the location of divergences as
dashed vertical lines.
Expand All @@ -58,6 +60,8 @@ def plot_trace(
Coordinates of var_names to be plotted. Passed to `Dataset.sel`
divergences : {"bottom", "top", None}, optional
Plot location of divergences on the traceplots.
kind : {"trace", "rank_bar", "rank_vlines"}, optional
Choose between plotting sampled values per iteration and rank plots.
transform : callable, optional
Function to transform data (defaults to None i.e.the identity function)
figsize : tuple of (float, float), optional
Expand Down Expand Up @@ -118,6 +122,13 @@ def plot_trace(
>>> az.plot_trace(data, compact=True)
Display a rank plot instead of trace
.. plot::
:context: close-figs
>>> az.plot_trace(data, var_names=["mu", "tau"], kind="rank_bars")
Combine all chains into one distribution
.. plot::
Expand All @@ -135,6 +146,9 @@ def plot_trace(
>>> az.plot_trace(data, var_names=('theta_t', 'theta'), coords=coords, lines=lines)
"""
if kind not in {"trace", "rank_vlines", "rank_bars"}:
raise ValueError("The value of kind must be either trace, rank_vlines or rank_bars.")

if divergences:
try:
divergence_data = convert_to_dataset(data, group="sample_stats").diverging
Expand Down Expand Up @@ -235,6 +249,8 @@ def plot_trace(
fill_kwargs = {}
if rug_kwargs is None:
rug_kwargs = {}
if rank_kwargs is None:
rank_kwargs = {}

# TODO: Check if this can be further simplified
trace_plot_args = dict(
Expand All @@ -243,6 +259,7 @@ def plot_trace(
var_names=var_names,
# coords = coords,
divergences=divergences,
kind=kind,
figsize=figsize,
rug=rug,
lines=lines,
Expand All @@ -251,6 +268,7 @@ def plot_trace(
rug_kwargs=rug_kwargs,
hist_kwargs=hist_kwargs,
trace_kwargs=trace_kwargs,
rank_kwargs=rank_kwargs,
compact_prop=compact_prop,
combined=combined,
chain_prop=chain_prop,
Expand Down
Loading

0 comments on commit 6383820

Please sign in to comment.