diff --git a/CHANGELOG.md b/CHANGELOG.md index b2bf074d13..3467cbd59a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,12 +3,12 @@ ## v0.x.x Unreleased ### New features -* Integrate jointplot into pairplot, add point-estimate and overlay of plot kinds #1079 * Add out-of-sample groups (`predictions` and `predictions_constant_data`) and `constant_data` group to pyro translation #1090 * Add `num_chains` and `pred_dims` arguments to io_pyro #1090 * Integrate jointplot into pairplot, add point-estimate and overlay of plot kinds (#1079) * Allow xarray.Dataarray input for plots.(#1120) * Skip test for optional/extra dependencies when not installed (#1113) +* Add option to display rank plots instead of trace (#1134) ### Maintenance and fixes * Fixed behaviour of `credible_interval=None` in `plot_posterior` (#1115) * Fixed hist kind of `plot_dist` with multidimensional input (#1115) diff --git a/arviz/plots/backends/bokeh/traceplot.py b/arviz/plots/backends/bokeh/traceplot.py index d14856fc94..1e4f889a49 100644 --- a/arviz/plots/backends/bokeh/traceplot.py +++ b/arviz/plots/backends/bokeh/traceplot.py @@ -11,6 +11,7 @@ from . import backend_kwarg_defaults from .. import show_layout from ...distplot import plot_dist +from ...rankplot import plot_rank from ...plot_utils import xarray_var_iter, make_label, _scale_fig_size from ....rcparams import rcParams @@ -19,6 +20,7 @@ def plot_trace( data, var_names, divergences, + kind, figsize, rug, lines, @@ -30,6 +32,7 @@ def plot_trace( rug_kwargs: [Dict], hist_kwargs: [Dict], trace_kwargs: [Dict], + rank_kwargs: [Dict], plotters, divergence_data, axes, @@ -68,6 +71,9 @@ def plot_trace( trace_kwargs.setdefault("line_width", linewidth) plot_kwargs.setdefault("line_width", linewidth) + if rank_kwargs is None: + rank_kwargs = {} + if axes is None: axes = [] for i in range(len(plotters)): @@ -154,12 +160,14 @@ def plot_trace( chain_prop=chain_prop, combined=combined, rug=rug, + kind=kind, legend=legend, trace_kwargs=trace_kwargs, hist_kwargs=hist_kwargs, plot_kwargs=plot_kwargs, fill_kwargs=fill_kwargs, rug_kwargs=rug_kwargs, + rank_kwargs=rank_kwargs, ) else: for y_name in cds_var_groups[var_name]: @@ -174,12 +182,14 @@ def plot_trace( chain_prop=chain_prop, combined=combined, rug=rug, + kind=kind, legend=legend, trace_kwargs=trace_kwargs, hist_kwargs=hist_kwargs, plot_kwargs=plot_kwargs, fill_kwargs=fill_kwargs, rug_kwargs=rug_kwargs, + rank_kwargs=rank_kwargs, ) for col in (0, 1): @@ -272,33 +282,36 @@ def _plot_chains_bokeh( chain_prop, combined, rug, + kind, legend, trace_kwargs, hist_kwargs, plot_kwargs, fill_kwargs, rug_kwargs, + rank_kwargs, ): marker = trace_kwargs.pop("marker", True) for chain_idx, cds in data.items(): - if legend: - trace_kwargs["legend_label"] = "chain {}".format(chain_idx) - ax_trace.line( - x=x_name, - y=y_name, - source=cds, - **{chain_prop[0]: chain_prop[1][chain_idx]}, - **trace_kwargs, - ) - if marker: - ax_trace.circle( + if kind == "trace": + if legend: + trace_kwargs["legend_label"] = "chain {}".format(chain_idx) + ax_trace.line( x=x_name, y=y_name, source=cds, - radius=0.30, - alpha=0.5, - **{chain_prop[0]: chain_prop[1][chain_idx],}, + **{chain_prop[0]: chain_prop[1][chain_idx]}, + **trace_kwargs, ) + if marker: + ax_trace.circle( + x=x_name, + y=y_name, + source=cds, + radius=0.30, + alpha=0.5, + **{chain_prop[0]: chain_prop[1][chain_idx],}, + ) if not combined: rug_kwargs["cds"] = cds if legend: @@ -318,6 +331,13 @@ def _plot_chains_bokeh( ) plot_kwargs.pop(chain_prop[0]) + if kind == "rank_bars": + value = np.array([item.data[y_name] for item in data.values()]) + plot_rank(value, kind="bars", axes=ax_trace, backend="bokeh", show=False, **rank_kwargs) + elif kind == "rank_vlines": + value = np.array([item.data[y_name] for item in data.values()]) + plot_rank(value, kind="vlines", axes=ax_trace, backend="bokeh", show=False, **rank_kwargs) + if combined: rug_kwargs["cds"] = data if legend: diff --git a/arviz/plots/backends/matplotlib/traceplot.py b/arviz/plots/backends/matplotlib/traceplot.py index c8de44b2ae..fbd51316e0 100644 --- a/arviz/plots/backends/matplotlib/traceplot.py +++ b/arviz/plots/backends/matplotlib/traceplot.py @@ -8,6 +8,7 @@ from . import backend_kwarg_defaults, backend_show from ...distplot import plot_dist +from ...rankplot import plot_rank from ...plot_utils import _scale_fig_size, get_bins, make_label, format_coords_as_labels @@ -15,6 +16,7 @@ def plot_trace( data, var_names, # pylint: disable=unused-argument divergences, + kind, figsize, rug, lines, @@ -27,6 +29,7 @@ def plot_trace( rug_kwargs, hist_kwargs, trace_kwargs, + rank_kwargs, plotters, divergence_data, axes, @@ -47,6 +50,8 @@ def plot_trace( One or more variables to be plotted. divergences : {"bottom", "top", None, False} Plot location of divergences on the traceplots. Options are "bottom", "top", or False-y. + kind : {"trace", "rank_bar", "rank_vlines"}, optional + Choose between plotting sampled values per iteration and rank plots. figsize : figure size tuple If None, size is (12, variables * 2) rug : bool @@ -70,6 +75,8 @@ def plot_trace( Extra keyword arguments passed to `arviz.plot_dist`. Only affects discrete variables. trace_kwargs : dict Extra keyword arguments passed to `plt.plot` + rank_kwargs : dict + Extra keyword arguments passed to `arviz.plot_rank` Returns ------- axes : matplotlib axes @@ -160,11 +167,13 @@ def plot_trace( combined, xt_labelsize, rug, + kind, trace_kwargs, hist_kwargs, plot_kwargs, fill_kwargs, rug_kwargs, + rank_kwargs, ) if compact_prop: plot_kwargs.pop(compact_prop[0]) @@ -197,11 +206,13 @@ def plot_trace( combined, xt_labelsize, rug, + kind, trace_kwargs, hist_kwargs, plot_kwargs, fill_kwargs, rug_kwargs, + rank_kwargs, ) if legend: handles.append( @@ -241,18 +252,19 @@ def plot_trace( else: ylocs = [ylim[0] for ylim in ylims] values = value[chain, div_idxs] - axes[idx, 1].plot( - div_draws, - np.zeros_like(div_idxs) + ylocs[1], - marker="|", - color="black", - markeredgewidth=1.5, - markersize=30, - linestyle="None", - alpha=hist_kwargs["alpha"], - zorder=-5, - ) - axes[idx, 1].set_ylim(*ylims[1]) + if kind == "trace": + axes[idx, 1].plot( + div_draws, + np.zeros_like(div_idxs) + ylocs[1], + marker="|", + color="black", + markeredgewidth=1.5, + markersize=30, + linestyle="None", + alpha=hist_kwargs["alpha"], + zorder=-5, + ) + axes[idx, 1].set_ylim(*ylims[1]) axes[idx, 0].plot( values, np.zeros_like(values) + ylocs[0], @@ -276,12 +288,18 @@ def plot_trace( "line-positions should be numeric, found {}".format(line_values) ) axes[idx, 0].vlines(line_values, *ylims[0], colors="black", linewidth=1.5, alpha=0.75) - axes[idx, 1].hlines( - line_values, *xlims[1], colors="black", linewidth=1.5, alpha=trace_kwargs["alpha"] - ) + if kind == "trace": + axes[idx, 1].hlines( + line_values, + *xlims[1], + colors="black", + linewidth=1.5, + alpha=trace_kwargs["alpha"] + ) axes[idx, 0].set_ylim(bottom=0, top=ylims[0][1]) - axes[idx, 1].set_xlim(left=data.draw.min(), right=data.draw.max()) - axes[idx, 1].set_ylim(*ylims[1]) + if kind == "trace": + axes[idx, 1].set_xlim(left=data.draw.min(), right=data.draw.max()) + axes[idx, 1].set_ylim(*ylims[1]) if legend: legend_kwargs = trace_kwargs if combined else plot_kwargs handles = [ @@ -295,7 +313,7 @@ def plot_trace( [], [], label="combined", **{chain_prop[0]: chain_prop[1][-1]}, **plot_kwargs ), ) - axes[0, 1].legend(handles=handles, title="chain") + axes[0, 0].legend(handles=handles, title="chain") if backend_show(show): plt.show() @@ -312,16 +330,19 @@ def _plot_chains_mpl( combined, xt_labelsize, rug, + kind, trace_kwargs, hist_kwargs, plot_kwargs, fill_kwargs, rug_kwargs, + rank_kwargs, ): for chain_idx, row in enumerate(value): - axes[idx, 1].plot( - data.draw.values, row, **{chain_prop[0]: chain_prop[1][chain_idx]}, **trace_kwargs - ) + if kind == "trace": + axes[idx, 1].plot( + data.draw.values, row, **{chain_prop[0]: chain_prop[1][chain_idx]}, **trace_kwargs + ) if not combined: plot_kwargs[chain_prop[0]] = chain_prop[1][chain_idx] @@ -339,6 +360,11 @@ def _plot_chains_mpl( ) plot_kwargs.pop(chain_prop[0]) + if kind == "rank_bars": + plot_rank(data=value, kind="bars", axes=axes[idx, 1], **rank_kwargs) + elif kind == "rank_vlines": + plot_rank(data=value, kind="vlines", axes=axes[idx, 1], **rank_kwargs) + if combined: plot_kwargs[chain_prop[0]] = chain_prop[1][-1] plot_dist( diff --git a/arviz/plots/traceplot.py b/arviz/plots/traceplot.py index e0628ece87..7618364fe8 100644 --- a/arviz/plots/traceplot.py +++ b/arviz/plots/traceplot.py @@ -23,6 +23,7 @@ def plot_trace( transform: Optional[Callable] = None, coords: Optional[CoordSpec] = None, divergences: Optional[str] = "bottom", + kind: Optional[str] = "trace", figsize: Optional[Tuple[float, float]] = None, rug: bool = False, lines: Optional[List[Tuple[str, CoordSpec, Any]]] = None, @@ -36,13 +37,14 @@ def plot_trace( rug_kwargs: Optional[KwargSpec] = None, hist_kwargs: Optional[KwargSpec] = None, trace_kwargs: Optional[KwargSpec] = None, + rank_kwargs: Optional[KwargSpec] = None, ax=None, backend: Optional[str] = None, backend_config: Optional[KwargSpec] = None, backend_kwargs: Optional[KwargSpec] = None, show: Optional[bool] = None, ): - """Plot distribution (histogram or kernel density estimates) and sampled values. + """Plot distribution (histogram or kernel density estimates) and sampled values or rank plot. If `divergences` data is available in `sample_stats`, will plot the location of divergences as dashed vertical lines. @@ -58,6 +60,8 @@ def plot_trace( Coordinates of var_names to be plotted. Passed to `Dataset.sel` divergences : {"bottom", "top", None}, optional Plot location of divergences on the traceplots. + kind : {"trace", "rank_bar", "rank_vlines"}, optional + Choose between plotting sampled values per iteration and rank plots. transform : callable, optional Function to transform data (defaults to None i.e.the identity function) figsize : tuple of (float, float), optional @@ -118,6 +122,13 @@ def plot_trace( >>> az.plot_trace(data, compact=True) + Display a rank plot instead of trace + + .. plot:: + :context: close-figs + + >>> az.plot_trace(data, var_names=["mu", "tau"], kind="rank_bars") + Combine all chains into one distribution .. plot:: @@ -135,6 +146,9 @@ def plot_trace( >>> az.plot_trace(data, var_names=('theta_t', 'theta'), coords=coords, lines=lines) """ + if kind not in {"trace", "rank_vlines", "rank_bars"}: + raise ValueError("The value of kind must be either trace, rank_vlines or rank_bars.") + if divergences: try: divergence_data = convert_to_dataset(data, group="sample_stats").diverging @@ -235,6 +249,8 @@ def plot_trace( fill_kwargs = {} if rug_kwargs is None: rug_kwargs = {} + if rank_kwargs is None: + rank_kwargs = {} # TODO: Check if this can be further simplified trace_plot_args = dict( @@ -243,6 +259,7 @@ def plot_trace( var_names=var_names, # coords = coords, divergences=divergences, + kind=kind, figsize=figsize, rug=rug, lines=lines, @@ -251,6 +268,7 @@ def plot_trace( rug_kwargs=rug_kwargs, hist_kwargs=hist_kwargs, trace_kwargs=trace_kwargs, + rank_kwargs=rank_kwargs, compact_prop=compact_prop, combined=combined, chain_prop=chain_prop, diff --git a/arviz/tests/base_tests/test_plots_bokeh.py b/arviz/tests/base_tests/test_plots_bokeh.py index b1452cba6e..9103d5510a 100644 --- a/arviz/tests/base_tests/test_plots_bokeh.py +++ b/arviz/tests/base_tests/test_plots_bokeh.py @@ -131,6 +131,8 @@ def test_plot_density_bad_kwargs(models): {"combined": True, "compact": True, "legend": True}, {"divergences": "top"}, {"divergences": False}, + {"kind": "rank_vlines"}, + {"kind": "rank_bars"}, {"lines": [("mu", {}, [1, 2])]}, {"lines": [("mu", {}, 8)]}, ], diff --git a/arviz/tests/base_tests/test_plots_matplotlib.py b/arviz/tests/base_tests/test_plots_matplotlib.py index d6143b6786..4760f8ff1a 100644 --- a/arviz/tests/base_tests/test_plots_matplotlib.py +++ b/arviz/tests/base_tests/test_plots_matplotlib.py @@ -144,6 +144,8 @@ def test_plot_density_bad_kwargs(models): {"combined": True, "compact": True, "legend": True}, {"divergences": "top", "legend": True}, {"divergences": False}, + {"kind": "rank_vlines"}, + {"kind": "rank_bars"}, {"lines": [("mu", {}, [1, 2])]}, {"lines": [("mu", {}, 8)]}, ], @@ -160,7 +162,7 @@ def test_plot_trace_legend(compact, combined): axes = plot_trace( idata, var_names=["home", "atts_star"], compact=compact, combined=combined, legend=True ) - assert axes[0, 1].get_legend() + assert axes[0, 0].get_legend() compact_legend = axes[1, 0].get_legend() if compact: assert axes.shape == (2, 2) diff --git a/examples/bokeh/bokeh_plot_trace_bars.py b/examples/bokeh/bokeh_plot_trace_bars.py new file mode 100644 index 0000000000..917a3b95e3 --- /dev/null +++ b/examples/bokeh/bokeh_plot_trace_bars.py @@ -0,0 +1,10 @@ +""" +Traceplot rank_bars Bokeh +=============== + +_thumb: .1, .8 +""" +import arviz as az + +data = az.load_arviz_data("non_centered_eight") +ax = az.plot_trace(data, var_names=("tau", "mu"), kind="rank_bars", backend="bokeh") diff --git a/examples/bokeh/bokeh_plot_trace_vlines.py b/examples/bokeh/bokeh_plot_trace_vlines.py new file mode 100644 index 0000000000..3ad32e03a3 --- /dev/null +++ b/examples/bokeh/bokeh_plot_trace_vlines.py @@ -0,0 +1,10 @@ +""" +Traceplot rank_vlines Bokeh +=============== + +_thumb: .1, .8 +""" +import arviz as az + +data = az.load_arviz_data("non_centered_eight") +ax = az.plot_trace(data, var_names=("tau", "mu"), kind="rank_vlines", backend="bokeh") diff --git a/examples/matplotlib/mpl_plot_trace_bars.py b/examples/matplotlib/mpl_plot_trace_bars.py new file mode 100644 index 0000000000..be9b11fc15 --- /dev/null +++ b/examples/matplotlib/mpl_plot_trace_bars.py @@ -0,0 +1,15 @@ +""" +Traceplot rank_bars +========= + +_thumb: .1, .8 +""" +import matplotlib.pyplot as plt +import arviz as az + +az.style.use("arviz-darkgrid") + +data = az.load_arviz_data("non_centered_eight") +az.plot_trace(data, var_names=("tau", "mu"), kind="rank_bars") + +plt.show() diff --git a/examples/matplotlib/mpl_plot_trace_vlines.py b/examples/matplotlib/mpl_plot_trace_vlines.py new file mode 100644 index 0000000000..850afbf1c7 --- /dev/null +++ b/examples/matplotlib/mpl_plot_trace_vlines.py @@ -0,0 +1,15 @@ +""" +Traceplot rank_vlines +========= + +_thumb: .1, .8 +""" +import matplotlib.pyplot as plt +import arviz as az + +az.style.use("arviz-darkgrid") + +data = az.load_arviz_data("non_centered_eight") +az.plot_trace(data, var_names=("tau", "mu"), kind="rank_vlines") + +plt.show()