diff --git a/CHANGELOG.md b/CHANGELOG.md index 78452b5cf7..38ab4593af 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -26,6 +26,7 @@ * Use method="average" in `scipy.stats.rankdata` ([#1380](https://github.com/arviz-devs/arviz/pull/1380)) * Add more `plot_parallel` examples ([#1380](https://github.com/arviz-devs/arviz/pull/1380)) * Bump minimum xarray version to 0.16.1 ([#1389](https://github.com/arviz-devs/arviz/pull/1389) +* Fix multi rope for `plot_forest` ([#1390](https://github.com/arviz-devs/arviz/pull/1390)) ### Deprecation diff --git a/arviz/plots/backends/bokeh/forestplot.py b/arviz/plots/backends/bokeh/forestplot.py index 3c674477de..1924ec6fea 100644 --- a/arviz/plots/backends/bokeh/forestplot.py +++ b/arviz/plots/backends/bokeh/forestplot.py @@ -266,22 +266,25 @@ def label_idxs(): return label_idxs() - def display_multiple_ropes(self, rope, ax, y, linewidth, rope_var): + def display_multiple_ropes(self, rope, ax, y, linewidth, var_name, selection): """Display ROPE when more than one interval is provided.""" - vals = dict(rope[rope_var][0])["rope"] - ax.line( - vals, - (y + 0.05, y + 0.05), - line_width=linewidth * 2, - color=[ - color - for _, color in zip( - range(3), cycle(plt.rcParams["axes.prop_cycle"].by_key()["color"]) + for sel in rope.get(var_name, []): + # pylint: disable=line-too-long + if all(k in selection and selection[k] == v for k, v in sel.items() if k != "rope"): + vals = sel["rope"] + ax.line( + vals, + (y + 0.05, y + 0.05), + line_width=linewidth * 2, + color=[ + color + for _, color in zip( + range(3), cycle(plt.rcParams["axes.prop_cycle"].by_key()["color"]) + ) + ][2], + line_alpha=0.7, ) - ][2], - line_alpha=0.7, - ) - return ax + return ax def ridgeplot( self, @@ -452,9 +455,9 @@ def forestplot(self, hdi_prob, quartiles, linewidth, markersize, ax, rope): qlist = [endpoint, 50, 100 - endpoint] for plotter in self.plotters.values(): - for y, rope_var, values, color in plotter.treeplot(qlist, hdi_prob): + for y, selection, values, color in plotter.treeplot(qlist, hdi_prob): if isinstance(rope, dict): - self.display_multiple_ropes(rope, ax, y, linewidth, rope_var) + self.display_multiple_ropes(rope, ax, y, linewidth, plotter.var_name, selection) mid = len(values) // 2 param_iter = zip( @@ -560,6 +563,7 @@ def iterator(self): skip_dims = set() label_dict = OrderedDict() + selection_list = [] for name, grouped_datum in zip(self.model_names, grouped_data): for _, sub_data in grouped_datum: datum_iter = xarray_var_iter( @@ -569,6 +573,7 @@ def iterator(self): reverse_selections=True, ) for _, selection, values in datum_iter: + selection_list.append(selection) label = make_label(self.var_name, selection, position="beside") if label not in label_dict: label_dict[label] = OrderedDict() @@ -577,14 +582,16 @@ def iterator(self): label_dict[label][name].append(values) y = self.y_start - for label, model_data in label_dict.items(): + for idx, (label, model_data) in enumerate(label_dict.items()): for model_name, value_list in model_data.items(): if model_name: row_label = "{}: {}".format(model_name, label) else: row_label = label for values in value_list: - yield y, row_label, label, values, self.model_color[model_name] + yield y, row_label, label, selection_list[idx], values, self.model_color[ + model_name + ] y += self.chain_offset y += self.var_offset y += self.group_offset @@ -592,7 +599,7 @@ def iterator(self): def labels_ticks_and_vals(self): """Get labels, ticks, values, and colors for the variable.""" y_ticks = defaultdict(list) - for y, label, _, vals, color in self.iterator(): + for y, label, _, _, vals, color in self.iterator(): y_ticks[label].append((y, vals, color)) labels, ticks, vals, colors = [], [], [], [] for label, data in y_ticks.items(): @@ -604,10 +611,10 @@ def labels_ticks_and_vals(self): def treeplot(self, qlist, hdi_prob): """Get data for each treeplot for the variable.""" - for y, _, label, values, color in self.iterator(): + for y, _, _, selection, values, color in self.iterator(): ntiles = np.percentile(values.flatten(), qlist) ntiles[0], ntiles[-1] = hdi(values.flatten(), hdi_prob, multimodal=False) - yield y, label, ntiles, color + yield y, selection, ntiles, color def ridgeplot(self, hdi_prob, mult, ridgeplot_kind): """Get data for each ridgeplot for the variable.""" diff --git a/arviz/plots/backends/matplotlib/forestplot.py b/arviz/plots/backends/matplotlib/forestplot.py index 569a9cf23a..d77c4c4bc8 100644 --- a/arviz/plots/backends/matplotlib/forestplot.py +++ b/arviz/plots/backends/matplotlib/forestplot.py @@ -230,19 +230,22 @@ def label_idxs(): return label_idxs() - def display_multiple_ropes(self, rope, ax, y, linewidth, rope_var): + def display_multiple_ropes(self, rope, ax, y, linewidth, var_name, selection): """Display ROPE when more than one interval is provided.""" - vals = dict(rope[rope_var][0])["rope"] - ax.plot( - vals, - (y + 0.05, y + 0.05), - lw=linewidth * 2, - color="C2", - solid_capstyle="round", - zorder=0, - alpha=0.7, - ) - return ax + for sel in rope.get(var_name, []): + # pylint: disable=line-too-long + if all(k in selection and selection[k] == v for k, v in sel.items() if k != "rope"): + vals = sel["rope"] + ax.plot( + vals, + (y + 0.05, y + 0.05), + lw=linewidth * 2, + color="C2", + solid_capstyle="round", + zorder=0, + alpha=0.7, + ) + return ax def ridgeplot( self, @@ -376,9 +379,9 @@ def forestplot( qlist = [endpoint, 50, 100 - endpoint] for plotter in self.plotters.values(): - for y, rope_var, values, color in plotter.treeplot(qlist, hdi_prob): + for y, selection, values, color in plotter.treeplot(qlist, hdi_prob): if isinstance(rope, dict): - self.display_multiple_ropes(rope, ax, y, linewidth, rope_var) + self.display_multiple_ropes(rope, ax, y, linewidth, plotter.var_name, selection) mid = len(values) // 2 param_iter = zip( @@ -502,6 +505,7 @@ def iterator(self): skip_dims = set() label_dict = OrderedDict() + selection_list = [] for name, grouped_datum in zip(self.model_names, grouped_data): for _, sub_data in grouped_datum: datum_iter = xarray_var_iter( @@ -511,6 +515,7 @@ def iterator(self): reverse_selections=True, ) for _, selection, values in datum_iter: + selection_list.append(selection) label = make_label(self.var_name, selection, position="beside") if label not in label_dict: label_dict[label] = OrderedDict() @@ -519,14 +524,16 @@ def iterator(self): label_dict[label][name].append(values) y = self.y_start - for label, model_data in label_dict.items(): + for idx, (label, model_data) in enumerate(label_dict.items()): for model_name, value_list in model_data.items(): if model_name: row_label = "{}: {}".format(model_name, label) else: row_label = label for values in value_list: - yield y, row_label, label, values, self.model_color[model_name] + yield y, row_label, label, selection_list[idx], values, self.model_color[ + model_name + ] y += self.chain_offset y += self.var_offset y += self.group_offset @@ -534,7 +541,7 @@ def iterator(self): def labels_ticks_and_vals(self): """Get labels, ticks, values, and colors for the variable.""" y_ticks = defaultdict(list) - for y, label, _, vals, color in self.iterator(): + for y, label, _, _, vals, color in self.iterator(): y_ticks[label].append((y, vals, color)) labels, ticks, vals, colors = [], [], [], [] for label, data in y_ticks.items(): @@ -546,10 +553,10 @@ def labels_ticks_and_vals(self): def treeplot(self, qlist, hdi_prob): """Get data for each treeplot for the variable.""" - for y, _, label, values, color in self.iterator(): + for y, _, _, selection, values, color in self.iterator(): ntiles = np.percentile(values.flatten(), qlist) ntiles[0], ntiles[-1] = hdi(values.flatten(), hdi_prob, multimodal=False) - yield y, label, ntiles, color + yield y, selection, ntiles, color def ridgeplot(self, hdi_prob, mult, ridgeplot_kind): """Get data for each ridgeplot for the variable.""" diff --git a/arviz/plots/forestplot.py b/arviz/plots/forestplot.py index 9e29b8f9f5..7a906ee7d8 100644 --- a/arviz/plots/forestplot.py +++ b/arviz/plots/forestplot.py @@ -136,10 +136,23 @@ def plot_forest( >>> var_names=["^the"], >>> filter_vars="regex", >>> combined=True, - >>> ridgeplot_overlap=3, >>> figsize=(9, 7)) >>> axes[0].set_title('Estimated theta for 8 schools model') + Forestpĺot with ropes + + .. plot:: + :context: close-figs + + >>> rope = {'theta': [{'school': 'Choate', 'rope': (2, 4)}], 'mu': [{'rope': (-2, 2)}]} + >>> axes = az.plot_forest(non_centered_data, + >>> rope=rope, + >>> var_names='~tau', + >>> combined=True, + >>> figsize=(9, 7)) + >>> axes[0].set_title('Estimated theta for 8 schools model') + + Ridgeplot .. plot:: diff --git a/arviz/tests/base_tests/test_plots_bokeh.py b/arviz/tests/base_tests/test_plots_bokeh.py index 1cada4a1a8..94d07a8884 100644 --- a/arviz/tests/base_tests/test_plots_bokeh.py +++ b/arviz/tests/base_tests/test_plots_bokeh.py @@ -490,7 +490,10 @@ def test_plot_ess_no_divergences(models): ( { "var_names": ["mu", "tau"], - "rope": {"mu": [{"rope": (-0.1, 0.1)}], "tau": [{"rope": (0.2, 0.5)}]}, + "rope": { + "mu": [{"rope": (-0.1, 0.1)}], + "theta": [{"school": "Choate", "rope": (0.2, 0.5)}], + }, }, 1, ), diff --git a/arviz/tests/base_tests/test_plots_matplotlib.py b/arviz/tests/base_tests/test_plots_matplotlib.py index 6e9c1f4fcb..7189980592 100644 --- a/arviz/tests/base_tests/test_plots_matplotlib.py +++ b/arviz/tests/base_tests/test_plots_matplotlib.py @@ -279,8 +279,11 @@ def test_plot_trace_futurewarning(models, prop): ({"kind": "ridgeplot", "r_hat": True, "ess": True, "ridgeplot_alpha": 0}, 3), ( { - "var_names": ["mu", "tau"], - "rope": {"mu": [{"rope": (-0.1, 0.1)}], "tau": [{"rope": (0.2, 0.5)}]}, + "var_names": ["mu", "theta"], + "rope": { + "mu": [{"rope": (-0.1, 0.1)}], + "theta": [{"school": "Choate", "rope": (0.2, 0.5)}], + }, }, 1, ),