Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make the legends interactive in bokeh #1024

Merged
merged 20 commits into from
Feb 13, 2020
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
* Add `skipna` argument to `hpd` and `summary` (#1035)
* Added `transform` argument to `plot_trace`, `plot_forest`, `plot_pair`, `plot_posterior`, `plot_rank`, `plot_parallel`, `plot_violin`,`plot_density`, `plot_joint` (#1036)
* Add `marker` functionality to `bokeh_plot_elpd` (#1040)
* Add the functionality for `interactive legends` for bokeh plots of `densityplot`, `energyplot` and `essplot` (#1024)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe link to bokeh docs on interactive pegends instead of formatting it like code?



### Maintenance and fixes
Expand Down
127 changes: 79 additions & 48 deletions arviz/plots/backends/bokeh/densityplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import bokeh.plotting as bkp
import numpy as np
from bokeh.layouts import gridplot
from bokeh.models.annotations import Title
from bokeh.models.annotations import Title, Legend

from . import backend_kwarg_defaults, backend_show
from ...plot_utils import (
Expand Down Expand Up @@ -66,18 +66,17 @@ def plot_density(
if data_labels is None:
data_labels = {}

legend_items = {}
for m_idx, plotters in enumerate(to_plot):
for ax_idx, (var_name, selection, values) in enumerate(plotters):
for var_name, selection, values in plotters:
label = make_label(var_name, selection)

if data_labels:
data_label = data_labels[m_idx]
if ax_idx != 0 or data_label == "":
data_label = None
else:
data_label = None

_d_helper(
plotted = _d_helper(
values.flatten(),
label,
colors[m_idx],
Expand All @@ -90,8 +89,18 @@ def plot_density(
outline,
shade,
axis_map[label],
data_label=data_label,
)
if data_label is not None:
if axis_map[label] in legend_items.keys():
legend_items[axis_map[label]].append((data_label, plotted))
else:
legend_items[axis_map[label]] = []
legend_items[axis_map[label]].append((data_label, plotted))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could legend items be a defaultdict?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it will make the code neat. Thanks, for the suggestion.


for ax1, legend in legend_items.items():
legend = Legend(items=legend, location="center_right", orientation="horizontal",)
ax1.add_layout(legend, "above")
ax1.legend.click_policy = "hide"

if backend_show(show):
grid = gridplot(ax.tolist(), toolbar_location="above")
Expand All @@ -113,11 +122,10 @@ def _d_helper(
outline,
shade,
ax,
data_label,
):

extra = dict()
if data_label is not None:
extra["legend_label"] = data_label
plotted = []

if vec.dtype.kind == "f":
if credible_interval != 1:
Expand All @@ -133,29 +141,41 @@ def _d_helper(
ymax = density[-1]

if outline:
ax.line(x, density, line_color=color, line_width=line_width, **extra)
ax.line(
[xmin, xmin],
[-ymin / 100, ymin],
line_color=color,
line_dash="solid",
line_width=line_width,
plotted.append(ax.line(x, density, line_color=color, line_width=line_width, **extra))
plotted.append(
ax.line(
[xmin, xmin],
[-ymin / 100, ymin],
line_color=color,
line_dash="solid",
line_width=line_width,
muted_color=color,
muted_alpha=0.2,
)
)
ax.line(
[xmax, xmax],
[-ymax / 100, ymax],
line_color=color,
line_dash="solid",
line_width=line_width,
plotted.append(
ax.line(
[xmax, xmax],
[-ymax / 100, ymax],
line_color=color,
line_dash="solid",
line_width=line_width,
muted_color=color,
muted_alpha=0.2,
)
)

if shade:
ax.patch(
np.r_[x[::-1], x, x[-1:]],
np.r_[np.zeros_like(x), density, [0]],
fill_color=color,
fill_alpha=shade,
**extra
plotted.append(
ax.patch(
np.r_[x[::-1], x, x[-1:]],
np.r_[np.zeros_like(x), density, [0]],
fill_color=color,
fill_alpha=shade,
muted_color=color,
muted_alpha=0.2,
**extra
)
)

else:
Expand All @@ -165,35 +185,46 @@ def _d_helper(
_, hist, edges = histogram(vec, bins=bins)

if outline:
ax.quad(
top=hist,
bottom=0,
left=edges[:-1],
right=edges[1:],
line_color=color,
fill_color=None,
**extra
plotted.append(
ax.quad(
top=hist,
bottom=0,
left=edges[:-1],
right=edges[1:],
line_color=color,
fill_color=None,
muted_color=color,
muted_alpha=0.2,
**extra
)
)
else:
ax.quad(
top=hist,
bottom=0,
left=edges[:-1],
right=edges[1:],
line_color=color,
fill_color=color,
fill_alpha=shade,
**extra
plotted.append(
ax.quad(
top=hist,
bottom=0,
left=edges[:-1],
right=edges[1:],
line_color=color,
fill_color=color,
fill_alpha=shade,
muted_color=color,
muted_alpha=0.2,
**extra
)
)

if hpd_markers:
ax.diamond(xmin, 0, line_color="black", fill_color=color, size=markersize)
ax.diamond(xmax, 0, line_color="black", fill_color=color, size=markersize)
plotted.append(ax.diamond(xmin, 0, line_color="black", fill_color=color, size=markersize))
plotted.append(ax.diamond(xmax, 0, line_color="black", fill_color=color, size=markersize))

if point_estimate is not None:
est = calculate_point_estimate(point_estimate, vec, bw)
ax.circle(est, 0, fill_color=color, line_color="black", size=markersize)
plotted.append(ax.circle(est, 0, fill_color=color, line_color="black", size=markersize))

_title = Title()
_title.text = vname
ax.title = _title
ax.title.text_font_size = "13pt"

return plotted
OriolAbril marked this conversation as resolved.
Show resolved Hide resolved
14 changes: 10 additions & 4 deletions arviz/plots/backends/bokeh/energyplot.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Bokeh energyplot."""
import bokeh.plotting as bkp
from bokeh.models import Label
from bokeh.models.annotations import Legend

from . import backend_kwarg_defaults, backend_show
from .distplot import _histplot_bokeh_op
Expand Down Expand Up @@ -39,14 +40,15 @@ def plot_energy(
if ax is None:
ax = bkp.figure(width=int(figsize[0] * dpi), height=int(figsize[1] * dpi), **backend_kwargs)

labels = []
if kind == "kde":
for alpha, color, label, value in series:
fill_kwargs["fill_alpha"] = alpha
fill_kwargs["fill_color"] = color
plot_kwargs["line_color"] = color
plot_kwargs["line_alpha"] = alpha
plot_kwargs.setdefault("line_width", line_width)
plot_kde(
_, glyph = plot_kde(
value,
bw=bw,
label=label,
Expand All @@ -57,7 +59,10 @@ def plot_energy(
backend="bokeh",
backend_kwargs={},
show=False,
return_glyph=True,
)
labels.append((label, glyph,))

elif kind in {"hist", "histogram"}:
hist_kwargs = plot_kwargs.copy()
hist_kwargs.update(**fill_kwargs)
Expand All @@ -78,7 +83,7 @@ def plot_energy(
for idx, val in enumerate(e_bfmi(energy)):
bfmi_info = Label(
x=int(figsize[0] * dpi * 0.58),
y=int(figsize[1] * dpi * 0.83) - 20 * idx,
y=int(figsize[1] * dpi * 0.73) - 20 * idx,
x_units="screen",
y_units="screen",
text="chain {:>2} BFMI = {:.2f}".format(idx, val),
Expand All @@ -91,8 +96,9 @@ def plot_energy(

ax.add_layout(bfmi_info)

if legend:
ax.legend.location = "top_left"
if legend and label is not None:
legend = Legend(items=labels, location="center_right", orientation="horizontal",)
ax.add_layout(legend, "above")
ax.legend.click_policy = "hide"

if backend_show(show):
Expand Down
19 changes: 14 additions & 5 deletions arviz/plots/backends/bokeh/essplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import numpy as np
from bokeh.layouts import gridplot
from bokeh.models import Dash, Span, ColumnDataSource
from bokeh.models.annotations import Title
from bokeh.models.annotations import Title, Legend
from scipy.stats import rankdata

from . import backend_kwarg_defaults, backend_show
Expand Down Expand Up @@ -74,12 +74,12 @@ def plot_ess(
for (var_name, selection, x), ax_ in zip(
plotters, (item for item in ax.flatten() if item is not None)
):
ax_.circle(np.asarray(xdata), np.asarray(x), size=6)
bulk_points = ax_.circle(np.asarray(xdata), np.asarray(x), size=6)
if kind == "evolution":
ax_.line(np.asarray(xdata), np.asarray(x), legend_label="bulk")
bulk_line = ax_.line(np.asarray(xdata), np.asarray(x))
ess_tail = ess_tail_dataset[var_name].sel(**selection)
ax_.line(np.asarray(xdata), np.asarray(ess_tail), color="orange", legend_label="tail")
ax_.circle(np.asarray(xdata), np.asarray(ess_tail), size=6, color="orange")
tail_points = ax_.line(np.asarray(xdata), np.asarray(ess_tail), color="orange")
tail_line = ax_.circle(np.asarray(xdata), np.asarray(ess_tail), size=6, color="orange")
elif rug:
if rug_kwargs is None:
rug_kwargs = {}
Expand Down Expand Up @@ -153,6 +153,15 @@ def plot_ess(

ax_.renderers.append(hline)

if kind == "evolution":
legend = Legend(
items=[("bulk", [bulk_points, bulk_line]), ("tail", [tail_line, tail_points])],
location="center_right",
orientation="horizontal",
)
ax_.add_layout(legend, "above")
ax_.legend.click_policy = "hide"

title = Title()
title.text = make_label(var_name, selection)
ax_.title = title
Expand Down
Loading