diff --git a/CHANGELOG.md b/CHANGELOG.md index 004b823e83..114500620e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,8 @@ * Add `skipna` argument to `hpd` and `summary` (#1035) * Added `transform` argument to `plot_trace`, `plot_forest`, `plot_pair`, `plot_posterior`, `plot_rank`, `plot_parallel`, `plot_violin`,`plot_density`, `plot_joint` (#1036) * Add `marker` functionality to `bokeh_plot_elpd` (#1040) +* Added the functionality [interactive legends](https://docs.bokeh.org/en/1.4.0/docs/user_guide/interaction/legends.html) for bokeh plots of `densityplot`, `energyplot` + and `essplot` (#1024) ### Maintenance and fixes diff --git a/arviz/plots/backends/bokeh/densityplot.py b/arviz/plots/backends/bokeh/densityplot.py index b5874370f4..6495db154a 100644 --- a/arviz/plots/backends/bokeh/densityplot.py +++ b/arviz/plots/backends/bokeh/densityplot.py @@ -1,8 +1,9 @@ """Bokeh Densityplot.""" -import bokeh.plotting as bkp +from collections import defaultdict import numpy as np +import bokeh.plotting as bkp from bokeh.layouts import gridplot -from bokeh.models.annotations import Title +from bokeh.models.annotations import Title, Legend from . import backend_kwarg_defaults, backend_show from ...plot_utils import ( @@ -66,18 +67,17 @@ def plot_density( if data_labels is None: data_labels = {} + legend_items = defaultdict(list) for m_idx, plotters in enumerate(to_plot): - for ax_idx, (var_name, selection, values) in enumerate(plotters): + for var_name, selection, values in plotters: label = make_label(var_name, selection) if data_labels: data_label = data_labels[m_idx] - if ax_idx != 0 or data_label == "": - data_label = None else: data_label = None - _d_helper( + plotted = _d_helper( values.flatten(), label, colors[m_idx], @@ -90,8 +90,14 @@ def plot_density( outline, shade, axis_map[label], - data_label=data_label, ) + if data_label is not None: + legend_items[axis_map[label]].append((data_label, plotted)) + + for ax1, legend in legend_items.items(): + legend = Legend(items=legend, location="center_right", orientation="horizontal",) + ax1.add_layout(legend, "above") + ax1.legend.click_policy = "hide" if backend_show(show): grid = gridplot(ax.tolist(), toolbar_location="above") @@ -113,11 +119,10 @@ def _d_helper( outline, shade, ax, - data_label, ): + extra = dict() - if data_label is not None: - extra["legend_label"] = data_label + plotted = [] if vec.dtype.kind == "f": if credible_interval != 1: @@ -133,29 +138,41 @@ def _d_helper( ymax = density[-1] if outline: - ax.line(x, density, line_color=color, line_width=line_width, **extra) - ax.line( - [xmin, xmin], - [-ymin / 100, ymin], - line_color=color, - line_dash="solid", - line_width=line_width, + plotted.append(ax.line(x, density, line_color=color, line_width=line_width, **extra)) + plotted.append( + ax.line( + [xmin, xmin], + [-ymin / 100, ymin], + line_color=color, + line_dash="solid", + line_width=line_width, + muted_color=color, + muted_alpha=0.2, + ) ) - ax.line( - [xmax, xmax], - [-ymax / 100, ymax], - line_color=color, - line_dash="solid", - line_width=line_width, + plotted.append( + ax.line( + [xmax, xmax], + [-ymax / 100, ymax], + line_color=color, + line_dash="solid", + line_width=line_width, + muted_color=color, + muted_alpha=0.2, + ) ) if shade: - ax.patch( - np.r_[x[::-1], x, x[-1:]], - np.r_[np.zeros_like(x), density, [0]], - fill_color=color, - fill_alpha=shade, - **extra + plotted.append( + ax.patch( + np.r_[x[::-1], x, x[-1:]], + np.r_[np.zeros_like(x), density, [0]], + fill_color=color, + fill_alpha=shade, + muted_color=color, + muted_alpha=0.2, + **extra + ) ) else: @@ -165,35 +182,46 @@ def _d_helper( _, hist, edges = histogram(vec, bins=bins) if outline: - ax.quad( - top=hist, - bottom=0, - left=edges[:-1], - right=edges[1:], - line_color=color, - fill_color=None, - **extra + plotted.append( + ax.quad( + top=hist, + bottom=0, + left=edges[:-1], + right=edges[1:], + line_color=color, + fill_color=None, + muted_color=color, + muted_alpha=0.2, + **extra + ) ) else: - ax.quad( - top=hist, - bottom=0, - left=edges[:-1], - right=edges[1:], - line_color=color, - fill_color=color, - fill_alpha=shade, - **extra + plotted.append( + ax.quad( + top=hist, + bottom=0, + left=edges[:-1], + right=edges[1:], + line_color=color, + fill_color=color, + fill_alpha=shade, + muted_color=color, + muted_alpha=0.2, + **extra + ) ) if hpd_markers: - ax.diamond(xmin, 0, line_color="black", fill_color=color, size=markersize) - ax.diamond(xmax, 0, line_color="black", fill_color=color, size=markersize) + plotted.append(ax.diamond(xmin, 0, line_color="black", fill_color=color, size=markersize)) + plotted.append(ax.diamond(xmax, 0, line_color="black", fill_color=color, size=markersize)) if point_estimate is not None: est = calculate_point_estimate(point_estimate, vec, bw) - ax.circle(est, 0, fill_color=color, line_color="black", size=markersize) + plotted.append(ax.circle(est, 0, fill_color=color, line_color="black", size=markersize)) _title = Title() _title.text = vname ax.title = _title + ax.title.text_font_size = "13pt" + + return plotted diff --git a/arviz/plots/backends/bokeh/energyplot.py b/arviz/plots/backends/bokeh/energyplot.py index 6f8a33dda2..e2dd5433df 100644 --- a/arviz/plots/backends/bokeh/energyplot.py +++ b/arviz/plots/backends/bokeh/energyplot.py @@ -1,6 +1,7 @@ """Bokeh energyplot.""" import bokeh.plotting as bkp from bokeh.models import Label +from bokeh.models.annotations import Legend from . import backend_kwarg_defaults, backend_show from .distplot import _histplot_bokeh_op @@ -39,6 +40,7 @@ def plot_energy( if ax is None: ax = bkp.figure(width=int(figsize[0] * dpi), height=int(figsize[1] * dpi), **backend_kwargs) + labels = [] if kind == "kde": for alpha, color, label, value in series: fill_kwargs["fill_alpha"] = alpha @@ -46,7 +48,7 @@ def plot_energy( plot_kwargs["line_color"] = color plot_kwargs["line_alpha"] = alpha plot_kwargs.setdefault("line_width", line_width) - plot_kde( + _, glyph = plot_kde( value, bw=bw, label=label, @@ -57,7 +59,10 @@ def plot_energy( backend="bokeh", backend_kwargs={}, show=False, + return_glyph=True, ) + labels.append((label, glyph,)) + elif kind in {"hist", "histogram"}: hist_kwargs = plot_kwargs.copy() hist_kwargs.update(**fill_kwargs) @@ -78,7 +83,7 @@ def plot_energy( for idx, val in enumerate(e_bfmi(energy)): bfmi_info = Label( x=int(figsize[0] * dpi * 0.58), - y=int(figsize[1] * dpi * 0.83) - 20 * idx, + y=int(figsize[1] * dpi * 0.73) - 20 * idx, x_units="screen", y_units="screen", text="chain {:>2} BFMI = {:.2f}".format(idx, val), @@ -91,8 +96,9 @@ def plot_energy( ax.add_layout(bfmi_info) - if legend: - ax.legend.location = "top_left" + if legend and label is not None: + legend = Legend(items=labels, location="center_right", orientation="horizontal",) + ax.add_layout(legend, "above") ax.legend.click_policy = "hide" if backend_show(show): diff --git a/arviz/plots/backends/bokeh/essplot.py b/arviz/plots/backends/bokeh/essplot.py index 69a3b66f04..4c00dc97a9 100644 --- a/arviz/plots/backends/bokeh/essplot.py +++ b/arviz/plots/backends/bokeh/essplot.py @@ -4,7 +4,7 @@ import numpy as np from bokeh.layouts import gridplot from bokeh.models import Dash, Span, ColumnDataSource -from bokeh.models.annotations import Title +from bokeh.models.annotations import Title, Legend from scipy.stats import rankdata from . import backend_kwarg_defaults, backend_show @@ -74,12 +74,12 @@ def plot_ess( for (var_name, selection, x), ax_ in zip( plotters, (item for item in ax.flatten() if item is not None) ): - ax_.circle(np.asarray(xdata), np.asarray(x), size=6) + bulk_points = ax_.circle(np.asarray(xdata), np.asarray(x), size=6) if kind == "evolution": - ax_.line(np.asarray(xdata), np.asarray(x), legend_label="bulk") + bulk_line = ax_.line(np.asarray(xdata), np.asarray(x)) ess_tail = ess_tail_dataset[var_name].sel(**selection) - ax_.line(np.asarray(xdata), np.asarray(ess_tail), color="orange", legend_label="tail") - ax_.circle(np.asarray(xdata), np.asarray(ess_tail), size=6, color="orange") + tail_points = ax_.line(np.asarray(xdata), np.asarray(ess_tail), color="orange") + tail_line = ax_.circle(np.asarray(xdata), np.asarray(ess_tail), size=6, color="orange") elif rug: if rug_kwargs is None: rug_kwargs = {} @@ -153,6 +153,15 @@ def plot_ess( ax_.renderers.append(hline) + if kind == "evolution": + legend = Legend( + items=[("bulk", [bulk_points, bulk_line]), ("tail", [tail_line, tail_points])], + location="center_right", + orientation="horizontal", + ) + ax_.add_layout(legend, "above") + ax_.legend.click_policy = "hide" + title = Title() title.text = make_label(var_name, selection) ax_.title = title diff --git a/arviz/plots/backends/bokeh/kdeplot.py b/arviz/plots/backends/bokeh/kdeplot.py index 851fd92d22..142168d2a5 100644 --- a/arviz/plots/backends/bokeh/kdeplot.py +++ b/arviz/plots/backends/bokeh/kdeplot.py @@ -27,7 +27,6 @@ def plot_kde( values, values2, rug, - label, quantiles, rotated, contour, @@ -39,9 +38,9 @@ def plot_kde( contourf_kwargs, pcolormesh_kwargs, ax, - legend, backend_kwargs, show, + return_glyph, ): """Bokeh kde plot.""" if backend_kwargs is None: @@ -59,9 +58,7 @@ def plot_kde( if ax is None: ax = bkp.figure(**backend_kwargs) - if legend and label is not None: - plot_kwargs["legend_label"] = label - + glyphs = [] if values2 is None: if plot_kwargs is None: plot_kwargs = {} @@ -103,6 +100,7 @@ def plot_kde( else: glyph = Dash(x=0.0, y=rug_varname, **rug_kwargs) ax.add_glyph(cds_rug, glyph) + glyphs.append(glyph) x = np.linspace(lower, upper, len(density)) @@ -124,9 +122,10 @@ def plot_kde( (np.zeros_like(density[idx]), [density[idx][-1]], density[idx][::-1], [0]) ) if not rotated: - ax.patch(patch_x, patch_y, **fill_kwargs) + patch = ax.patch(patch_x, patch_y, **fill_kwargs) else: - ax.patch(patch_y, patch_x, **fill_kwargs) + patch = ax.patch(patch_y, patch_x, **fill_kwargs) + glyphs.append(patch) else: if fill_kwargs.get("fill_alpha", False): patch_x = np.concatenate((x, [x[-1]], x[::-1], [x[0]])) @@ -134,14 +133,16 @@ def plot_kde( (np.zeros_like(density), [density[-1]], density[::-1], [0]) ) if not rotated: - ax.patch(patch_x, patch_y, **fill_kwargs) + patch = ax.patch(patch_x, patch_y, **fill_kwargs) else: - ax.patch(patch_y, patch_x, **fill_kwargs) + patch = ax.patch(patch_y, patch_x, **fill_kwargs) + glyphs.append(patch) if not rotated: - ax.line(x, density, **plot_kwargs) + line = ax.line(x, density, **plot_kwargs) else: - ax.line(density, x, **plot_kwargs) + line = ax.line(density, x, **plot_kwargs) + glyphs.append(line) else: if contour_kwargs is None: @@ -196,7 +197,8 @@ def plot_kde( continue vertices, _ = contour_generator.create_filled_contour(level, level_upper) for seg in vertices: - ax.patch(*seg.T, fill_color=color, **contour_kwargs) + patch = ax.patch(*seg.T, fill_color=color, **contour_kwargs) + glyphs.append(patch) if fill_last: ax.background_fill_color = colors[0] @@ -217,7 +219,7 @@ def plot_kde( else: colors = cmap - ax.image( + image = ax.image( image=[density.T], x=xmin, y=ymin, @@ -226,10 +228,15 @@ def plot_kde( palette=colors, **pcolormesh_kwargs ) + glyphs.append(image) ax.x_range.range_padding = ax.y_range.range_padding = 0 if backend_show(show): bkp.show(ax, toolbar_location="above") + + if return_glyph: + return ax, glyphs + return ax diff --git a/arviz/plots/kdeplot.py b/arviz/plots/kdeplot.py index f1e62ce058..eec8d56b2b 100644 --- a/arviz/plots/kdeplot.py +++ b/arviz/plots/kdeplot.py @@ -29,6 +29,7 @@ def plot_kde( backend=None, backend_kwargs=None, show=None, + return_glyph=False, **kwargs ): """1D or 2D KDE plot taking into account boundary conditions. @@ -88,10 +89,15 @@ def plot_kde( check the plotting method of the backend. show : bool, optional Call backend show function. + return_glyph : bool, optional + Internal argument to return glyphs for bokeh Returns ------- - axes : matplotlib axes or bokeh figures + axes : matplotlib.Axes or bokeh.plotting.Figure + Object containing the kde plot + glyphs : list, optional + Bokeh glyphs present in plot. Only provided if ``return_glyph`` is True. Examples -------- @@ -209,11 +215,16 @@ def plot_kde( legend=legend, backend_kwargs=backend_kwargs, show=show, + return_glyph=return_glyph, **kwargs, ) if backend == "bokeh": kde_plot_args.pop("textsize") + kde_plot_args.pop("label") + kde_plot_args.pop("legend") + else: + kde_plot_args.pop("return_glyph") # TODO: Add backend kwargs plot = get_plotting_function("plot_kde", "kdeplot", backend)