diff --git a/CHANGELOG.md b/CHANGELOG.md index 62b636b9e4..0069a12513 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,8 @@ ### New features * Added `labeller` argument to enable label customization in plots and summary ([1201](https://github.com/arviz-devs/arviz/pull/1201)) * Added `arviz.labels` module with classes and utilities ([1201](https://github.com/arviz-devs/arviz/pull/1201)) +* Added probability estimate within ROPE in `plot_posterior` ([1570](https://github.com/arviz-devs/arviz/pull/1570)) +* Added `rope_color` and `ref_val_color` arguments to `plot_posterior` ([1570](https://github.com/arviz-devs/arviz/pull/1570)) ### Maintenance and fixes * Enforced using coordinate values as default labels ([1201](https://github.com/arviz-devs/arviz/pull/1201)) diff --git a/arviz/plots/backends/bokeh/posteriorplot.py b/arviz/plots/backends/bokeh/posteriorplot.py index 324e3d6e61..60eac3ebbb 100644 --- a/arviz/plots/backends/bokeh/posteriorplot.py +++ b/arviz/plots/backends/bokeh/posteriorplot.py @@ -13,6 +13,7 @@ calculate_point_estimate, format_sig_figs, round_num, + vectorized_to_hex, ) from .. import show_layout from . import backend_kwarg_defaults, create_axes_grid @@ -37,6 +38,8 @@ def plot_posterior( textsize, ref_val, rope, + ref_val_color, + rope_color, labeller, kwargs, backend_kwargs, @@ -87,6 +90,8 @@ def plot_posterior( linewidth=linewidth, ref_val=ref_val, rope=rope, + ref_val_color=ref_val_color, + rope_color=rope_color, ax_labelsize=ax_labelsize, **kwargs, ) @@ -117,6 +122,8 @@ def _plot_posterior_op( skipna, ref_val, rope, + ref_val_color, + rope_color, ax_labelsize, round_to: Optional[int] = None, **kwargs, @@ -155,9 +162,20 @@ def display_ref_val(max_data): val, format_as_percent(greater_than_ref_probability, 1), ) - ax.line([val, val], [0, 0.8 * max_data], line_color="blue", line_alpha=0.65) + ax.line( + [val, val], + [0, 0.8 * max_data], + line_color=vectorized_to_hex(ref_val_color), + line_alpha=0.65, + ) - ax.text(x=[values.mean()], y=[max_data * 0.6], text=[ref_in_posterior], text_align="center") + ax.text( + x=[values.mean()], + y=[max_data * 0.6], + text=[ref_in_posterior], + text_color=vectorized_to_hex(ref_val_color), + text_align="center", + ) def display_rope(max_data): if rope is None: @@ -185,15 +203,28 @@ def display_rope(max_data): vals, (max_data * 0.02, max_data * 0.02), line_width=linewidth * 5, - line_color="red", + line_color=vectorized_to_hex(rope_color), line_alpha=0.7, ) - + probability_within_rope = ((values > vals[0]) & (values <= vals[1])).mean() text_props = dict( - text_font_size="{}pt".format(ax_labelsize), text_color="black", text_align="center" + text_color=vectorized_to_hex(rope_color), + text_align="center", + ) + ax.text( + x=values.mean(), + y=[max_data * 0.45], + text=[f"{format_as_percent(probability_within_rope, 1)} in ROPE"], + **text_props, ) - ax.text(x=vals, y=[max_data * 0.2, max_data * 0.2], text=rope_text, **text_props) + ax.text( + x=vals, + y=[max_data * 0.2, max_data * 0.2], + text_font_size="{}pt".format(ax_labelsize), + text=rope_text, + **text_props, + ) def display_point_estimate(max_data): if not point_estimate: diff --git a/arviz/plots/backends/matplotlib/posteriorplot.py b/arviz/plots/backends/matplotlib/posteriorplot.py index 2b0c286880..2c6a38731b 100644 --- a/arviz/plots/backends/matplotlib/posteriorplot.py +++ b/arviz/plots/backends/matplotlib/posteriorplot.py @@ -12,6 +12,7 @@ calculate_point_estimate, format_sig_figs, round_num, + vectorized_to_hex, ) from . import backend_kwarg_defaults, backend_show, create_axes_grid, matplotlib_kwarg_dealiaser @@ -35,6 +36,8 @@ def plot_posterior( textsize, ref_val, rope, + ref_val_color, + rope_color, labeller, kwargs, backend_kwargs, @@ -87,6 +90,8 @@ def plot_posterior( skipna=skipna, ref_val=ref_val, rope=rope, + ref_val_color=ref_val_color, + rope_color=rope_color, ax_labelsize=ax_labelsize, xt_labelsize=xt_labelsize, **kwargs, @@ -119,6 +124,8 @@ def _plot_posterior_op( skipna, ref_val, rope, + ref_val_color, + rope_color, ax_labelsize, xt_labelsize, round_to=None, @@ -158,15 +165,22 @@ def display_ref_val(): val, format_as_percent(greater_than_ref_probability, 1), ) - ax.axvline(val, ymin=0.05, ymax=0.75, color="C1", lw=linewidth, alpha=0.65) + ax.axvline( + val, + ymin=0.05, + ymax=0.75, + lw=linewidth, + alpha=0.65, + color=vectorized_to_hex(ref_val_color), + ) ax.text( values.mean(), plot_height * 0.6, ref_in_posterior, size=ax_labelsize, - color="C1", weight="semibold", horizontalalignment="center", + color=vectorized_to_hex(ref_val_color), ) def display_rope(): @@ -190,24 +204,33 @@ def display_rope(): "iterable of length 2" ) rope_text = [f"{val:.{format_sig_figs(val, round_to)}g}" for val in vals] - ax.plot( vals, (plot_height * 0.02, plot_height * 0.02), lw=linewidth * 5, - color="C2", solid_capstyle="butt", zorder=0, alpha=0.7, + color=vectorized_to_hex(rope_color), + ) + probability_within_rope = ((values > vals[0]) & (values <= vals[1])).mean() + ax.text( + values.mean(), + plot_height * 0.45, + f"{format_as_percent(probability_within_rope, 1)} in ROPE", + weight="semibold", + horizontalalignment="center", + size=ax_labelsize, + color=vectorized_to_hex(rope_color), ) - text_props = {"size": ax_labelsize, "color": "C2"} ax.text( vals[0], plot_height * 0.2, rope_text[0], weight="semibold", horizontalalignment="right", - **text_props, + size=ax_labelsize, + color=vectorized_to_hex(rope_color), ) ax.text( vals[1], @@ -215,7 +238,8 @@ def display_rope(): rope_text[1], weight="semibold", horizontalalignment="left", - **text_props, + size=ax_labelsize, + color=vectorized_to_hex(rope_color), ) def display_point_estimate(): diff --git a/arviz/plots/posteriorplot.py b/arviz/plots/posteriorplot.py index 47f2f4be97..982aa87d88 100644 --- a/arviz/plots/posteriorplot.py +++ b/arviz/plots/posteriorplot.py @@ -24,6 +24,8 @@ def plot_posterior( group="posterior", rope=None, ref_val=None, + rope_color="C2", + ref_val_color="C1", kind="kde", bw="default", circular=False, @@ -84,6 +86,10 @@ def plot_posterior( display the percentage below and above the values in ref_val. Must be None (default), a constant, a list or a dictionary like see an example below. If a list is provided, its length should match the number of variables. + rope_color: str, optional + Specifies the color of ROPE and displayed percentage within ROPE + ref_val_color: str, optional + Specifies the color of the displayed percentage kind: str Type of plot to display (kde or hist) For discrete variables this argument is ignored and a histogram is always used. @@ -255,6 +261,8 @@ def plot_posterior( textsize=textsize, ref_val=ref_val, rope=rope, + ref_val_color=ref_val_color, + rope_color=rope_color, labeller=labeller, kwargs=kwargs, backend_kwargs=backend_kwargs,