diff --git a/CHANGELOG.md b/CHANGELOG.md
index 004b823e83..114500620e 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -12,6 +12,8 @@
 * Add `skipna` argument to `hpd` and `summary` (#1035)
 * Added `transform` argument to `plot_trace`, `plot_forest`, `plot_pair`, `plot_posterior`, `plot_rank`, `plot_parallel`,  `plot_violin`,`plot_density`, `plot_joint` (#1036)
 * Add `marker` functionality to `bokeh_plot_elpd` (#1040)
+* Added the functionality [interactive legends](https://docs.bokeh.org/en/1.4.0/docs/user_guide/interaction/legends.html) for bokeh plots of `densityplot`, `energyplot` 
+  and `essplot` (#1024)
 
 
 ### Maintenance and fixes
diff --git a/arviz/plots/backends/bokeh/densityplot.py b/arviz/plots/backends/bokeh/densityplot.py
index b5874370f4..6495db154a 100644
--- a/arviz/plots/backends/bokeh/densityplot.py
+++ b/arviz/plots/backends/bokeh/densityplot.py
@@ -1,8 +1,9 @@
 """Bokeh Densityplot."""
-import bokeh.plotting as bkp
+from collections import defaultdict
 import numpy as np
+import bokeh.plotting as bkp
 from bokeh.layouts import gridplot
-from bokeh.models.annotations import Title
+from bokeh.models.annotations import Title, Legend
 
 from . import backend_kwarg_defaults, backend_show
 from ...plot_utils import (
@@ -66,18 +67,17 @@ def plot_density(
     if data_labels is None:
         data_labels = {}
 
+    legend_items = defaultdict(list)
     for m_idx, plotters in enumerate(to_plot):
-        for ax_idx, (var_name, selection, values) in enumerate(plotters):
+        for var_name, selection, values in plotters:
             label = make_label(var_name, selection)
 
             if data_labels:
                 data_label = data_labels[m_idx]
-                if ax_idx != 0 or data_label == "":
-                    data_label = None
             else:
                 data_label = None
 
-            _d_helper(
+            plotted = _d_helper(
                 values.flatten(),
                 label,
                 colors[m_idx],
@@ -90,8 +90,14 @@ def plot_density(
                 outline,
                 shade,
                 axis_map[label],
-                data_label=data_label,
             )
+            if data_label is not None:
+                legend_items[axis_map[label]].append((data_label, plotted))
+
+    for ax1, legend in legend_items.items():
+        legend = Legend(items=legend, location="center_right", orientation="horizontal",)
+        ax1.add_layout(legend, "above")
+        ax1.legend.click_policy = "hide"
 
     if backend_show(show):
         grid = gridplot(ax.tolist(), toolbar_location="above")
@@ -113,11 +119,10 @@ def _d_helper(
     outline,
     shade,
     ax,
-    data_label,
 ):
+
     extra = dict()
-    if data_label is not None:
-        extra["legend_label"] = data_label
+    plotted = []
 
     if vec.dtype.kind == "f":
         if credible_interval != 1:
@@ -133,29 +138,41 @@ def _d_helper(
         ymax = density[-1]
 
         if outline:
-            ax.line(x, density, line_color=color, line_width=line_width, **extra)
-            ax.line(
-                [xmin, xmin],
-                [-ymin / 100, ymin],
-                line_color=color,
-                line_dash="solid",
-                line_width=line_width,
+            plotted.append(ax.line(x, density, line_color=color, line_width=line_width, **extra))
+            plotted.append(
+                ax.line(
+                    [xmin, xmin],
+                    [-ymin / 100, ymin],
+                    line_color=color,
+                    line_dash="solid",
+                    line_width=line_width,
+                    muted_color=color,
+                    muted_alpha=0.2,
+                )
             )
-            ax.line(
-                [xmax, xmax],
-                [-ymax / 100, ymax],
-                line_color=color,
-                line_dash="solid",
-                line_width=line_width,
+            plotted.append(
+                ax.line(
+                    [xmax, xmax],
+                    [-ymax / 100, ymax],
+                    line_color=color,
+                    line_dash="solid",
+                    line_width=line_width,
+                    muted_color=color,
+                    muted_alpha=0.2,
+                )
             )
 
         if shade:
-            ax.patch(
-                np.r_[x[::-1], x, x[-1:]],
-                np.r_[np.zeros_like(x), density, [0]],
-                fill_color=color,
-                fill_alpha=shade,
-                **extra
+            plotted.append(
+                ax.patch(
+                    np.r_[x[::-1], x, x[-1:]],
+                    np.r_[np.zeros_like(x), density, [0]],
+                    fill_color=color,
+                    fill_alpha=shade,
+                    muted_color=color,
+                    muted_alpha=0.2,
+                    **extra
+                )
             )
 
     else:
@@ -165,35 +182,46 @@ def _d_helper(
         _, hist, edges = histogram(vec, bins=bins)
 
         if outline:
-            ax.quad(
-                top=hist,
-                bottom=0,
-                left=edges[:-1],
-                right=edges[1:],
-                line_color=color,
-                fill_color=None,
-                **extra
+            plotted.append(
+                ax.quad(
+                    top=hist,
+                    bottom=0,
+                    left=edges[:-1],
+                    right=edges[1:],
+                    line_color=color,
+                    fill_color=None,
+                    muted_color=color,
+                    muted_alpha=0.2,
+                    **extra
+                )
             )
         else:
-            ax.quad(
-                top=hist,
-                bottom=0,
-                left=edges[:-1],
-                right=edges[1:],
-                line_color=color,
-                fill_color=color,
-                fill_alpha=shade,
-                **extra
+            plotted.append(
+                ax.quad(
+                    top=hist,
+                    bottom=0,
+                    left=edges[:-1],
+                    right=edges[1:],
+                    line_color=color,
+                    fill_color=color,
+                    fill_alpha=shade,
+                    muted_color=color,
+                    muted_alpha=0.2,
+                    **extra
+                )
             )
 
     if hpd_markers:
-        ax.diamond(xmin, 0, line_color="black", fill_color=color, size=markersize)
-        ax.diamond(xmax, 0, line_color="black", fill_color=color, size=markersize)
+        plotted.append(ax.diamond(xmin, 0, line_color="black", fill_color=color, size=markersize))
+        plotted.append(ax.diamond(xmax, 0, line_color="black", fill_color=color, size=markersize))
 
     if point_estimate is not None:
         est = calculate_point_estimate(point_estimate, vec, bw)
-        ax.circle(est, 0, fill_color=color, line_color="black", size=markersize)
+        plotted.append(ax.circle(est, 0, fill_color=color, line_color="black", size=markersize))
 
     _title = Title()
     _title.text = vname
     ax.title = _title
+    ax.title.text_font_size = "13pt"
+
+    return plotted
diff --git a/arviz/plots/backends/bokeh/energyplot.py b/arviz/plots/backends/bokeh/energyplot.py
index 6f8a33dda2..e2dd5433df 100644
--- a/arviz/plots/backends/bokeh/energyplot.py
+++ b/arviz/plots/backends/bokeh/energyplot.py
@@ -1,6 +1,7 @@
 """Bokeh energyplot."""
 import bokeh.plotting as bkp
 from bokeh.models import Label
+from bokeh.models.annotations import Legend
 
 from . import backend_kwarg_defaults, backend_show
 from .distplot import _histplot_bokeh_op
@@ -39,6 +40,7 @@ def plot_energy(
     if ax is None:
         ax = bkp.figure(width=int(figsize[0] * dpi), height=int(figsize[1] * dpi), **backend_kwargs)
 
+    labels = []
     if kind == "kde":
         for alpha, color, label, value in series:
             fill_kwargs["fill_alpha"] = alpha
@@ -46,7 +48,7 @@ def plot_energy(
             plot_kwargs["line_color"] = color
             plot_kwargs["line_alpha"] = alpha
             plot_kwargs.setdefault("line_width", line_width)
-            plot_kde(
+            _, glyph = plot_kde(
                 value,
                 bw=bw,
                 label=label,
@@ -57,7 +59,10 @@ def plot_energy(
                 backend="bokeh",
                 backend_kwargs={},
                 show=False,
+                return_glyph=True,
             )
+            labels.append((label, glyph,))
+
     elif kind in {"hist", "histogram"}:
         hist_kwargs = plot_kwargs.copy()
         hist_kwargs.update(**fill_kwargs)
@@ -78,7 +83,7 @@ def plot_energy(
         for idx, val in enumerate(e_bfmi(energy)):
             bfmi_info = Label(
                 x=int(figsize[0] * dpi * 0.58),
-                y=int(figsize[1] * dpi * 0.83) - 20 * idx,
+                y=int(figsize[1] * dpi * 0.73) - 20 * idx,
                 x_units="screen",
                 y_units="screen",
                 text="chain {:>2} BFMI = {:.2f}".format(idx, val),
@@ -91,8 +96,9 @@ def plot_energy(
 
             ax.add_layout(bfmi_info)
 
-    if legend:
-        ax.legend.location = "top_left"
+    if legend and label is not None:
+        legend = Legend(items=labels, location="center_right", orientation="horizontal",)
+        ax.add_layout(legend, "above")
         ax.legend.click_policy = "hide"
 
     if backend_show(show):
diff --git a/arviz/plots/backends/bokeh/essplot.py b/arviz/plots/backends/bokeh/essplot.py
index 69a3b66f04..4c00dc97a9 100644
--- a/arviz/plots/backends/bokeh/essplot.py
+++ b/arviz/plots/backends/bokeh/essplot.py
@@ -4,7 +4,7 @@
 import numpy as np
 from bokeh.layouts import gridplot
 from bokeh.models import Dash, Span, ColumnDataSource
-from bokeh.models.annotations import Title
+from bokeh.models.annotations import Title, Legend
 from scipy.stats import rankdata
 
 from . import backend_kwarg_defaults, backend_show
@@ -74,12 +74,12 @@ def plot_ess(
     for (var_name, selection, x), ax_ in zip(
         plotters, (item for item in ax.flatten() if item is not None)
     ):
-        ax_.circle(np.asarray(xdata), np.asarray(x), size=6)
+        bulk_points = ax_.circle(np.asarray(xdata), np.asarray(x), size=6)
         if kind == "evolution":
-            ax_.line(np.asarray(xdata), np.asarray(x), legend_label="bulk")
+            bulk_line = ax_.line(np.asarray(xdata), np.asarray(x))
             ess_tail = ess_tail_dataset[var_name].sel(**selection)
-            ax_.line(np.asarray(xdata), np.asarray(ess_tail), color="orange", legend_label="tail")
-            ax_.circle(np.asarray(xdata), np.asarray(ess_tail), size=6, color="orange")
+            tail_points = ax_.line(np.asarray(xdata), np.asarray(ess_tail), color="orange")
+            tail_line = ax_.circle(np.asarray(xdata), np.asarray(ess_tail), size=6, color="orange")
         elif rug:
             if rug_kwargs is None:
                 rug_kwargs = {}
@@ -153,6 +153,15 @@ def plot_ess(
 
         ax_.renderers.append(hline)
 
+        if kind == "evolution":
+            legend = Legend(
+                items=[("bulk", [bulk_points, bulk_line]), ("tail", [tail_line, tail_points])],
+                location="center_right",
+                orientation="horizontal",
+            )
+            ax_.add_layout(legend, "above")
+            ax_.legend.click_policy = "hide"
+
         title = Title()
         title.text = make_label(var_name, selection)
         ax_.title = title
diff --git a/arviz/plots/backends/bokeh/kdeplot.py b/arviz/plots/backends/bokeh/kdeplot.py
index 851fd92d22..142168d2a5 100644
--- a/arviz/plots/backends/bokeh/kdeplot.py
+++ b/arviz/plots/backends/bokeh/kdeplot.py
@@ -27,7 +27,6 @@ def plot_kde(
     values,
     values2,
     rug,
-    label,
     quantiles,
     rotated,
     contour,
@@ -39,9 +38,9 @@ def plot_kde(
     contourf_kwargs,
     pcolormesh_kwargs,
     ax,
-    legend,
     backend_kwargs,
     show,
+    return_glyph,
 ):
     """Bokeh kde plot."""
     if backend_kwargs is None:
@@ -59,9 +58,7 @@ def plot_kde(
     if ax is None:
         ax = bkp.figure(**backend_kwargs)
 
-    if legend and label is not None:
-        plot_kwargs["legend_label"] = label
-
+    glyphs = []
     if values2 is None:
         if plot_kwargs is None:
             plot_kwargs = {}
@@ -103,6 +100,7 @@ def plot_kde(
                 else:
                     glyph = Dash(x=0.0, y=rug_varname, **rug_kwargs)
                 ax.add_glyph(cds_rug, glyph)
+            glyphs.append(glyph)
 
         x = np.linspace(lower, upper, len(density))
 
@@ -124,9 +122,10 @@ def plot_kde(
                         (np.zeros_like(density[idx]), [density[idx][-1]], density[idx][::-1], [0])
                     )
                     if not rotated:
-                        ax.patch(patch_x, patch_y, **fill_kwargs)
+                        patch = ax.patch(patch_x, patch_y, **fill_kwargs)
                     else:
-                        ax.patch(patch_y, patch_x, **fill_kwargs)
+                        patch = ax.patch(patch_y, patch_x, **fill_kwargs)
+                    glyphs.append(patch)
         else:
             if fill_kwargs.get("fill_alpha", False):
                 patch_x = np.concatenate((x, [x[-1]], x[::-1], [x[0]]))
@@ -134,14 +133,16 @@ def plot_kde(
                     (np.zeros_like(density), [density[-1]], density[::-1], [0])
                 )
                 if not rotated:
-                    ax.patch(patch_x, patch_y, **fill_kwargs)
+                    patch = ax.patch(patch_x, patch_y, **fill_kwargs)
                 else:
-                    ax.patch(patch_y, patch_x, **fill_kwargs)
+                    patch = ax.patch(patch_y, patch_x, **fill_kwargs)
+                glyphs.append(patch)
 
             if not rotated:
-                ax.line(x, density, **plot_kwargs)
+                line = ax.line(x, density, **plot_kwargs)
             else:
-                ax.line(density, x, **plot_kwargs)
+                line = ax.line(density, x, **plot_kwargs)
+            glyphs.append(line)
 
     else:
         if contour_kwargs is None:
@@ -196,7 +197,8 @@ def plot_kde(
                     continue
                 vertices, _ = contour_generator.create_filled_contour(level, level_upper)
                 for seg in vertices:
-                    ax.patch(*seg.T, fill_color=color, **contour_kwargs)
+                    patch = ax.patch(*seg.T, fill_color=color, **contour_kwargs)
+                    glyphs.append(patch)
 
             if fill_last:
                 ax.background_fill_color = colors[0]
@@ -217,7 +219,7 @@ def plot_kde(
             else:
                 colors = cmap
 
-            ax.image(
+            image = ax.image(
                 image=[density.T],
                 x=xmin,
                 y=ymin,
@@ -226,10 +228,15 @@ def plot_kde(
                 palette=colors,
                 **pcolormesh_kwargs
             )
+            glyphs.append(image)
             ax.x_range.range_padding = ax.y_range.range_padding = 0
 
     if backend_show(show):
         bkp.show(ax, toolbar_location="above")
+
+    if return_glyph:
+        return ax, glyphs
+
     return ax
 
 
diff --git a/arviz/plots/kdeplot.py b/arviz/plots/kdeplot.py
index f1e62ce058..eec8d56b2b 100644
--- a/arviz/plots/kdeplot.py
+++ b/arviz/plots/kdeplot.py
@@ -29,6 +29,7 @@ def plot_kde(
     backend=None,
     backend_kwargs=None,
     show=None,
+    return_glyph=False,
     **kwargs
 ):
     """1D or 2D KDE plot taking into account boundary conditions.
@@ -88,10 +89,15 @@ def plot_kde(
         check the plotting method of the backend.
     show : bool, optional
         Call backend show function.
+    return_glyph : bool, optional
+        Internal argument to return glyphs for bokeh
 
     Returns
     -------
-    axes : matplotlib axes or bokeh figures
+    axes : matplotlib.Axes or bokeh.plotting.Figure
+        Object containing the kde plot
+    glyphs : list, optional
+        Bokeh glyphs present in plot.  Only provided if ``return_glyph`` is True.
 
     Examples
     --------
@@ -209,11 +215,16 @@ def plot_kde(
         legend=legend,
         backend_kwargs=backend_kwargs,
         show=show,
+        return_glyph=return_glyph,
         **kwargs,
     )
 
     if backend == "bokeh":
         kde_plot_args.pop("textsize")
+        kde_plot_args.pop("label")
+        kde_plot_args.pop("legend")
+    else:
+        kde_plot_args.pop("return_glyph")
 
     # TODO: Add backend kwargs
     plot = get_plotting_function("plot_kde", "kdeplot", backend)