Skip to content

Commit

Permalink
Added possibility for multiple ropes
Browse files Browse the repository at this point in the history
Initial commit arviz-devs#448 completely rewritten.
Rope can be a list of 2 or a dict like in posteriorplot() passed to the rope argument
Argument rope_values writes ROPE values in plot when multiple ropes are given.
  • Loading branch information
GWeindel authored and ahartikainen committed Jan 14, 2019
1 parent 7b81d2e commit c2fb25d
Showing 1 changed file with 40 additions and 15 deletions.
55 changes: 40 additions & 15 deletions arviz/plots/forestplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,9 @@ def plot_forest(
var_names=None,
combined=False,
credible_interval=0.94,
quartiles=True,
rope=None,
rope_values=True,
quartiles=True,
eff_n=False,
r_hat=False,
colors="cycle",
Expand All @@ -37,8 +38,7 @@ def plot_forest(
markersize=None,
ridgeplot_alpha=None,
ridgeplot_overlap=2,
rope_alpha=.5,
figsize=None
figsize=None,
):
"""Forest plot to compare credible intervals from a number of distributions.
Expand All @@ -63,11 +63,14 @@ def plot_forest(
chains will be plotted separately.
credible_interval : float, optional
Credible interval to plot. Defaults to 0.94.
rope: tuple or dictionary of tuples
Lower and upper values of the Region Of Practical Equivalence. If a list with one interval only is provided, the ROPE will be displayed across the y-axis. If more than one interval is provided the
length of the list should match the number of variables.
rope_values : bool, optional
Flag for plotting lower and upper values of the Region Of Practical Equivalence as text
quartiles : bool, optional
Flag for plotting the interquartile range, in addition to the credible_interval intervals.
Defaults to True
rope: tuple
Lower and upper values of the Region Of Practical Equivalence defined for all displayed variables
r_hat : bool, optional
Flag for plotting Split R-hat statistics. Requires 2 or more chains. Defaults to False
eff_n : bool, optional
Expand All @@ -89,8 +92,6 @@ def plot_forest(
a black outline is used.
ridgeplot_overlap : float
Overlap height for ridgeplots.
rope_alpha : float
Transparency for rope interval.
figsize : tuple
Figure size. If None it will be defined automatically.
Expand Down Expand Up @@ -140,8 +141,7 @@ def plot_forest(
axes = np.atleast_1d(axes)
if kind == "forestplot":
plot_handler.forestplot(
credible_interval, quartiles, xt_labelsize, titlesize, linewidth,
markersize, axes[0], rope, rope_alpha
credible_interval, quartiles, xt_labelsize, titlesize, linewidth, markersize, axes[0], rope, rope_values
)
elif kind == "ridgeplot":
plot_handler.ridgeplot(ridgeplot_overlap, xt_labelsize, linewidth, ridgeplot_alpha, axes[0])
Expand Down Expand Up @@ -246,6 +246,23 @@ def labels_and_ticks(self):
idxs.append(sub_idxs)
return np.concatenate(labels), np.concatenate(idxs)

def display_multiple_ropes(self, rope, ax, y, linewidth, rope_var,markersize, rope_values,xt_labelsize):
vals = dict(rope[rope_var][0])["rope"]
ax.plot(
vals,
(y+.05, y+.05),
lw=linewidth * 2,
color="C2",
solid_capstyle="round",
zorder=0,
alpha=0.7,
)
text_props = {"size": xt_labelsize, "horizontalalignment": "center", "color": "C2"}
if rope_values:
ax.text(vals[0], y+.08, vals[0])#, **text_props)
ax.text(vals[1], y+.08, vals[1])#, **text_props)


def ridgeplot(self, mult, xt_labelsize, linewidth, alpha, ax):
"""Draw ridgeplot for each plotter.
Expand Down Expand Up @@ -281,7 +298,7 @@ def ridgeplot(self, mult, xt_labelsize, linewidth, alpha, ax):
return ax

def forestplot(
self, credible_interval, quartiles, xt_labelsize, titlesize, linewidth, markersize, ax, rope, rope_alpha
self, credible_interval, quartiles, xt_labelsize, titlesize, linewidth, markersize, ax, rope, rope_values
):
"""Draw forestplot for each plotter.
Expand Down Expand Up @@ -309,6 +326,7 @@ def forestplot(
else:
qlist = [endpoint, 50, 100 - endpoint]

label, ticks = self.labels_and_ticks()
for plotter in self.plotters.values():
for y, values, color in plotter.treeplot(qlist, credible_interval):
mid = len(values) // 2
Expand All @@ -325,11 +343,18 @@ def forestplot(
markersize=markersize * 0.75,
color=color,
)
if rope is not None and len(rope) == 2:
ax.axvspan(rope[0], rope[1], 0, len(values), color='C2', alpha=rope_alpha)
elif rope is not None:
raise ValueError('Argument `rope` must be None or an '
'iterable of length 2')
if isinstance(rope, dict):
self.display_multiple_ropes(rope, ax, y, linewidth, label[ticks==y][0],markersize, rope_values,xt_labelsize)
if rope is None or isinstance(rope, dict):
return
elif len(rope) == 2:
ax.axvspan(rope[0], rope[1], 0, self.y_max(), color='C2', alpha=.5)
else:
raise ValueError(
"Argument `rope` must be None, a dictionary like"
'{"var_name": {"rope": (lo, hi)}}, or an '
"iterable of length 2"
)
ax.tick_params(labelsize=xt_labelsize)
ax.set_title(
"{:.1%} Credible Interval".format(credible_interval), fontsize=titlesize, wrap=True
Expand Down

0 comments on commit c2fb25d

Please sign in to comment.