Skip to content

Commit

Permalink
fix multi rope for forest plot (#1390)
Browse files Browse the repository at this point in the history
* fix multi rope

* update tests and changelog

Co-authored-by: Ari Hartikainen <ahartikainen@users.noreply.github.com>
  • Loading branch information
aloctavodia and ahartikainen authored Sep 21, 2020
1 parent 1b9a194 commit 652302f
Show file tree
Hide file tree
Showing 6 changed files with 78 additions and 44 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
* Use method="average" in `scipy.stats.rankdata` ([#1380](https://github.com/arviz-devs/arviz/pull/1380))
* Add more `plot_parallel` examples ([#1380](https://github.com/arviz-devs/arviz/pull/1380))
* Bump minimum xarray version to 0.16.1 ([#1389](https://github.com/arviz-devs/arviz/pull/1389))
* Fix multi rope for `plot_forest` ([#1390](https://github.com/arviz-devs/arviz/pull/1390))

### Deprecation

Expand Down
49 changes: 28 additions & 21 deletions arviz/plots/backends/bokeh/forestplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,22 +266,25 @@ def label_idxs():

return label_idxs()

def display_multiple_ropes(self, rope, ax, y, linewidth, rope_var):
def display_multiple_ropes(self, rope, ax, y, linewidth, var_name, selection):
"""Display ROPE when more than one interval is provided."""
vals = dict(rope[rope_var][0])["rope"]
ax.line(
vals,
(y + 0.05, y + 0.05),
line_width=linewidth * 2,
color=[
color
for _, color in zip(
range(3), cycle(plt.rcParams["axes.prop_cycle"].by_key()["color"])
for sel in rope.get(var_name, []):
# pylint: disable=line-too-long
if all(k in selection and selection[k] == v for k, v in sel.items() if k != "rope"):
vals = sel["rope"]
ax.line(
vals,
(y + 0.05, y + 0.05),
line_width=linewidth * 2,
color=[
color
for _, color in zip(
range(3), cycle(plt.rcParams["axes.prop_cycle"].by_key()["color"])
)
][2],
line_alpha=0.7,
)
][2],
line_alpha=0.7,
)
return ax
return ax

def ridgeplot(
self,
Expand Down Expand Up @@ -452,9 +455,9 @@ def forestplot(self, hdi_prob, quartiles, linewidth, markersize, ax, rope):
qlist = [endpoint, 50, 100 - endpoint]

for plotter in self.plotters.values():
for y, rope_var, values, color in plotter.treeplot(qlist, hdi_prob):
for y, selection, values, color in plotter.treeplot(qlist, hdi_prob):
if isinstance(rope, dict):
self.display_multiple_ropes(rope, ax, y, linewidth, rope_var)
self.display_multiple_ropes(rope, ax, y, linewidth, plotter.var_name, selection)

mid = len(values) // 2
param_iter = zip(
Expand Down Expand Up @@ -560,6 +563,7 @@ def iterator(self):
skip_dims = set()

label_dict = OrderedDict()
selection_list = []
for name, grouped_datum in zip(self.model_names, grouped_data):
for _, sub_data in grouped_datum:
datum_iter = xarray_var_iter(
Expand All @@ -569,6 +573,7 @@ def iterator(self):
reverse_selections=True,
)
for _, selection, values in datum_iter:
selection_list.append(selection)
label = make_label(self.var_name, selection, position="beside")
if label not in label_dict:
label_dict[label] = OrderedDict()
Expand All @@ -577,22 +582,24 @@ def iterator(self):
label_dict[label][name].append(values)

y = self.y_start
for label, model_data in label_dict.items():
for idx, (label, model_data) in enumerate(label_dict.items()):
for model_name, value_list in model_data.items():
if model_name:
row_label = "{}: {}".format(model_name, label)
else:
row_label = label
for values in value_list:
yield y, row_label, label, values, self.model_color[model_name]
yield y, row_label, label, selection_list[idx], values, self.model_color[
model_name
]
y += self.chain_offset
y += self.var_offset
y += self.group_offset

def labels_ticks_and_vals(self):
"""Get labels, ticks, values, and colors for the variable."""
y_ticks = defaultdict(list)
for y, label, _, vals, color in self.iterator():
for y, label, _, _, vals, color in self.iterator():
y_ticks[label].append((y, vals, color))
labels, ticks, vals, colors = [], [], [], []
for label, data in y_ticks.items():
Expand All @@ -604,10 +611,10 @@ def labels_ticks_and_vals(self):

def treeplot(self, qlist, hdi_prob):
"""Get data for each treeplot for the variable."""
for y, _, label, values, color in self.iterator():
for y, _, _, selection, values, color in self.iterator():
ntiles = np.percentile(values.flatten(), qlist)
ntiles[0], ntiles[-1] = hdi(values.flatten(), hdi_prob, multimodal=False)
yield y, label, ntiles, color
yield y, selection, ntiles, color

def ridgeplot(self, hdi_prob, mult, ridgeplot_kind):
"""Get data for each ridgeplot for the variable."""
Expand Down
45 changes: 26 additions & 19 deletions arviz/plots/backends/matplotlib/forestplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,19 +230,22 @@ def label_idxs():

return label_idxs()

def display_multiple_ropes(self, rope, ax, y, linewidth, rope_var):
def display_multiple_ropes(self, rope, ax, y, linewidth, var_name, selection):
"""Display ROPE when more than one interval is provided."""
vals = dict(rope[rope_var][0])["rope"]
ax.plot(
vals,
(y + 0.05, y + 0.05),
lw=linewidth * 2,
color="C2",
solid_capstyle="round",
zorder=0,
alpha=0.7,
)
return ax
for sel in rope.get(var_name, []):
# pylint: disable=line-too-long
if all(k in selection and selection[k] == v for k, v in sel.items() if k != "rope"):
vals = sel["rope"]
ax.plot(
vals,
(y + 0.05, y + 0.05),
lw=linewidth * 2,
color="C2",
solid_capstyle="round",
zorder=0,
alpha=0.7,
)
return ax

def ridgeplot(
self,
Expand Down Expand Up @@ -376,9 +379,9 @@ def forestplot(
qlist = [endpoint, 50, 100 - endpoint]

for plotter in self.plotters.values():
for y, rope_var, values, color in plotter.treeplot(qlist, hdi_prob):
for y, selection, values, color in plotter.treeplot(qlist, hdi_prob):
if isinstance(rope, dict):
self.display_multiple_ropes(rope, ax, y, linewidth, rope_var)
self.display_multiple_ropes(rope, ax, y, linewidth, plotter.var_name, selection)

mid = len(values) // 2
param_iter = zip(
Expand Down Expand Up @@ -502,6 +505,7 @@ def iterator(self):
skip_dims = set()

label_dict = OrderedDict()
selection_list = []
for name, grouped_datum in zip(self.model_names, grouped_data):
for _, sub_data in grouped_datum:
datum_iter = xarray_var_iter(
Expand All @@ -511,6 +515,7 @@ def iterator(self):
reverse_selections=True,
)
for _, selection, values in datum_iter:
selection_list.append(selection)
label = make_label(self.var_name, selection, position="beside")
if label not in label_dict:
label_dict[label] = OrderedDict()
Expand All @@ -519,22 +524,24 @@ def iterator(self):
label_dict[label][name].append(values)

y = self.y_start
for label, model_data in label_dict.items():
for idx, (label, model_data) in enumerate(label_dict.items()):
for model_name, value_list in model_data.items():
if model_name:
row_label = "{}: {}".format(model_name, label)
else:
row_label = label
for values in value_list:
yield y, row_label, label, values, self.model_color[model_name]
yield y, row_label, label, selection_list[idx], values, self.model_color[
model_name
]
y += self.chain_offset
y += self.var_offset
y += self.group_offset

def labels_ticks_and_vals(self):
"""Get labels, ticks, values, and colors for the variable."""
y_ticks = defaultdict(list)
for y, label, _, vals, color in self.iterator():
for y, label, _, _, vals, color in self.iterator():
y_ticks[label].append((y, vals, color))
labels, ticks, vals, colors = [], [], [], []
for label, data in y_ticks.items():
Expand All @@ -546,10 +553,10 @@ def labels_ticks_and_vals(self):

def treeplot(self, qlist, hdi_prob):
"""Get data for each treeplot for the variable."""
for y, _, label, values, color in self.iterator():
for y, _, _, selection, values, color in self.iterator():
ntiles = np.percentile(values.flatten(), qlist)
ntiles[0], ntiles[-1] = hdi(values.flatten(), hdi_prob, multimodal=False)
yield y, label, ntiles, color
yield y, selection, ntiles, color

def ridgeplot(self, hdi_prob, mult, ridgeplot_kind):
"""Get data for each ridgeplot for the variable."""
Expand Down
15 changes: 14 additions & 1 deletion arviz/plots/forestplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,10 +136,23 @@ def plot_forest(
>>> var_names=["^the"],
>>> filter_vars="regex",
>>> combined=True,
>>> ridgeplot_overlap=3,
>>> figsize=(9, 7))
>>> axes[0].set_title('Estimated theta for 8 schools model')
Forestplot with ropes
.. plot::
:context: close-figs
>>> rope = {'theta': [{'school': 'Choate', 'rope': (2, 4)}], 'mu': [{'rope': (-2, 2)}]}
>>> axes = az.plot_forest(non_centered_data,
>>> rope=rope,
>>> var_names='~tau',
>>> combined=True,
>>> figsize=(9, 7))
>>> axes[0].set_title('Estimated theta for 8 schools model')
Ridgeplot
.. plot::
Expand Down
5 changes: 4 additions & 1 deletion arviz/tests/base_tests/test_plots_bokeh.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,7 +490,10 @@ def test_plot_ess_no_divergences(models):
(
{
"var_names": ["mu", "tau"],
"rope": {"mu": [{"rope": (-0.1, 0.1)}], "tau": [{"rope": (0.2, 0.5)}]},
"rope": {
"mu": [{"rope": (-0.1, 0.1)}],
"theta": [{"school": "Choate", "rope": (0.2, 0.5)}],
},
},
1,
),
Expand Down
7 changes: 5 additions & 2 deletions arviz/tests/base_tests/test_plots_matplotlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,8 +279,11 @@ def test_plot_trace_futurewarning(models, prop):
({"kind": "ridgeplot", "r_hat": True, "ess": True, "ridgeplot_alpha": 0}, 3),
(
{
"var_names": ["mu", "tau"],
"rope": {"mu": [{"rope": (-0.1, 0.1)}], "tau": [{"rope": (0.2, 0.5)}]},
"var_names": ["mu", "theta"],
"rope": {
"mu": [{"rope": (-0.1, 0.1)}],
"theta": [{"school": "Choate", "rope": (0.2, 0.5)}],
},
},
1,
),
Expand Down

0 comments on commit 652302f

Please sign in to comment.