Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix multi rope for forest plot #1390

Merged
merged 3 commits into from
Sep 21, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
* Use method="average" in `scipy.stats.rankdata` ([#1380](https://github.com/arviz-devs/arviz/pull/1380))
* Add more `plot_parallel` examples ([#1380](https://github.com/arviz-devs/arviz/pull/1380))
* Bump minimum xarray version to 0.16.1 ([#1389](https://github.com/arviz-devs/arviz/pull/1389))
* Fix multi rope for `plot_forest` ([#1390](https://github.com/arviz-devs/arviz/pull/1390))

### Deprecation

Expand Down
49 changes: 28 additions & 21 deletions arviz/plots/backends/bokeh/forestplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,22 +266,25 @@ def label_idxs():

return label_idxs()

def display_multiple_ropes(self, rope, ax, y, linewidth, rope_var):
def display_multiple_ropes(self, rope, ax, y, linewidth, var_name, selection):
"""Display ROPE when more than one interval is provided."""
vals = dict(rope[rope_var][0])["rope"]
ax.line(
vals,
(y + 0.05, y + 0.05),
line_width=linewidth * 2,
color=[
color
for _, color in zip(
range(3), cycle(plt.rcParams["axes.prop_cycle"].by_key()["color"])
for sel in rope.get(var_name, []):
# pylint: disable=line-too-long
if all(k in selection and selection[k] == v for k, v in sel.items() if k != "rope"):
vals = sel["rope"]
ax.line(
vals,
(y + 0.05, y + 0.05),
line_width=linewidth * 2,
color=[
color
for _, color in zip(
range(3), cycle(plt.rcParams["axes.prop_cycle"].by_key()["color"])
)
][2],
line_alpha=0.7,
)
][2],
line_alpha=0.7,
)
return ax
return ax

def ridgeplot(
self,
Expand Down Expand Up @@ -452,9 +455,9 @@ def forestplot(self, hdi_prob, quartiles, linewidth, markersize, ax, rope):
qlist = [endpoint, 50, 100 - endpoint]

for plotter in self.plotters.values():
for y, rope_var, values, color in plotter.treeplot(qlist, hdi_prob):
for y, selection, values, color in plotter.treeplot(qlist, hdi_prob):
if isinstance(rope, dict):
self.display_multiple_ropes(rope, ax, y, linewidth, rope_var)
self.display_multiple_ropes(rope, ax, y, linewidth, plotter.var_name, selection)

mid = len(values) // 2
param_iter = zip(
Expand Down Expand Up @@ -560,6 +563,7 @@ def iterator(self):
skip_dims = set()

label_dict = OrderedDict()
selection_list = []
for name, grouped_datum in zip(self.model_names, grouped_data):
for _, sub_data in grouped_datum:
datum_iter = xarray_var_iter(
Expand All @@ -569,6 +573,7 @@ def iterator(self):
reverse_selections=True,
)
for _, selection, values in datum_iter:
selection_list.append(selection)
label = make_label(self.var_name, selection, position="beside")
if label not in label_dict:
label_dict[label] = OrderedDict()
Expand All @@ -577,22 +582,24 @@ def iterator(self):
label_dict[label][name].append(values)

y = self.y_start
for label, model_data in label_dict.items():
for idx, (label, model_data) in enumerate(label_dict.items()):
for model_name, value_list in model_data.items():
if model_name:
row_label = "{}: {}".format(model_name, label)
else:
row_label = label
for values in value_list:
yield y, row_label, label, values, self.model_color[model_name]
yield y, row_label, label, selection_list[idx], values, self.model_color[
model_name
]
y += self.chain_offset
y += self.var_offset
y += self.group_offset

def labels_ticks_and_vals(self):
"""Get labels, ticks, values, and colors for the variable."""
y_ticks = defaultdict(list)
for y, label, _, vals, color in self.iterator():
for y, label, _, _, vals, color in self.iterator():
y_ticks[label].append((y, vals, color))
labels, ticks, vals, colors = [], [], [], []
for label, data in y_ticks.items():
Expand All @@ -604,10 +611,10 @@ def labels_ticks_and_vals(self):

def treeplot(self, qlist, hdi_prob):
"""Get data for each treeplot for the variable."""
for y, _, label, values, color in self.iterator():
for y, _, _, selection, values, color in self.iterator():
ntiles = np.percentile(values.flatten(), qlist)
ntiles[0], ntiles[-1] = hdi(values.flatten(), hdi_prob, multimodal=False)
yield y, label, ntiles, color
yield y, selection, ntiles, color

def ridgeplot(self, hdi_prob, mult, ridgeplot_kind):
"""Get data for each ridgeplot for the variable."""
Expand Down
45 changes: 26 additions & 19 deletions arviz/plots/backends/matplotlib/forestplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,19 +230,22 @@ def label_idxs():

return label_idxs()

def display_multiple_ropes(self, rope, ax, y, linewidth, rope_var):
def display_multiple_ropes(self, rope, ax, y, linewidth, var_name, selection):
"""Display ROPE when more than one interval is provided."""
vals = dict(rope[rope_var][0])["rope"]
ax.plot(
vals,
(y + 0.05, y + 0.05),
lw=linewidth * 2,
color="C2",
solid_capstyle="round",
zorder=0,
alpha=0.7,
)
return ax
for sel in rope.get(var_name, []):
# pylint: disable=line-too-long
if all(k in selection and selection[k] == v for k, v in sel.items() if k != "rope"):
vals = sel["rope"]
ax.plot(
vals,
(y + 0.05, y + 0.05),
lw=linewidth * 2,
color="C2",
solid_capstyle="round",
zorder=0,
alpha=0.7,
)
return ax

def ridgeplot(
self,
Expand Down Expand Up @@ -376,9 +379,9 @@ def forestplot(
qlist = [endpoint, 50, 100 - endpoint]

for plotter in self.plotters.values():
for y, rope_var, values, color in plotter.treeplot(qlist, hdi_prob):
for y, selection, values, color in plotter.treeplot(qlist, hdi_prob):
if isinstance(rope, dict):
self.display_multiple_ropes(rope, ax, y, linewidth, rope_var)
self.display_multiple_ropes(rope, ax, y, linewidth, plotter.var_name, selection)

mid = len(values) // 2
param_iter = zip(
Expand Down Expand Up @@ -502,6 +505,7 @@ def iterator(self):
skip_dims = set()

label_dict = OrderedDict()
selection_list = []
for name, grouped_datum in zip(self.model_names, grouped_data):
for _, sub_data in grouped_datum:
datum_iter = xarray_var_iter(
Expand All @@ -511,6 +515,7 @@ def iterator(self):
reverse_selections=True,
)
for _, selection, values in datum_iter:
selection_list.append(selection)
label = make_label(self.var_name, selection, position="beside")
if label not in label_dict:
label_dict[label] = OrderedDict()
Expand All @@ -519,22 +524,24 @@ def iterator(self):
label_dict[label][name].append(values)

y = self.y_start
for label, model_data in label_dict.items():
for idx, (label, model_data) in enumerate(label_dict.items()):
for model_name, value_list in model_data.items():
if model_name:
row_label = "{}: {}".format(model_name, label)
else:
row_label = label
for values in value_list:
yield y, row_label, label, values, self.model_color[model_name]
yield y, row_label, label, selection_list[idx], values, self.model_color[
model_name
]
y += self.chain_offset
y += self.var_offset
y += self.group_offset

def labels_ticks_and_vals(self):
"""Get labels, ticks, values, and colors for the variable."""
y_ticks = defaultdict(list)
for y, label, _, vals, color in self.iterator():
for y, label, _, _, vals, color in self.iterator():
y_ticks[label].append((y, vals, color))
labels, ticks, vals, colors = [], [], [], []
for label, data in y_ticks.items():
Expand All @@ -546,10 +553,10 @@ def labels_ticks_and_vals(self):

def treeplot(self, qlist, hdi_prob):
"""Get data for each treeplot for the variable."""
for y, _, label, values, color in self.iterator():
for y, _, _, selection, values, color in self.iterator():
ntiles = np.percentile(values.flatten(), qlist)
ntiles[0], ntiles[-1] = hdi(values.flatten(), hdi_prob, multimodal=False)
yield y, label, ntiles, color
yield y, selection, ntiles, color

def ridgeplot(self, hdi_prob, mult, ridgeplot_kind):
"""Get data for each ridgeplot for the variable."""
Expand Down
15 changes: 14 additions & 1 deletion arviz/plots/forestplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,10 +136,23 @@ def plot_forest(
>>> var_names=["^the"],
>>> filter_vars="regex",
>>> combined=True,
>>> ridgeplot_overlap=3,
>>> figsize=(9, 7))
>>> axes[0].set_title('Estimated theta for 8 schools model')

Forestplot with ropes

.. plot::
:context: close-figs

>>> rope = {'theta': [{'school': 'Choate', 'rope': (2, 4)}], 'mu': [{'rope': (-2, 2)}]}
>>> axes = az.plot_forest(non_centered_data,
>>> rope=rope,
>>> var_names='~tau',
>>> combined=True,
>>> figsize=(9, 7))
>>> axes[0].set_title('Estimated theta for 8 schools model')


Ridgeplot

.. plot::
Expand Down
5 changes: 4 additions & 1 deletion arviz/tests/base_tests/test_plots_bokeh.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,7 +490,10 @@ def test_plot_ess_no_divergences(models):
(
{
"var_names": ["mu", "tau"],
"rope": {"mu": [{"rope": (-0.1, 0.1)}], "tau": [{"rope": (0.2, 0.5)}]},
"rope": {
"mu": [{"rope": (-0.1, 0.1)}],
"theta": [{"school": "Choate", "rope": (0.2, 0.5)}],
},
},
1,
),
Expand Down
7 changes: 5 additions & 2 deletions arviz/tests/base_tests/test_plots_matplotlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,8 +279,11 @@ def test_plot_trace_futurewarning(models, prop):
({"kind": "ridgeplot", "r_hat": True, "ess": True, "ridgeplot_alpha": 0}, 3),
(
{
"var_names": ["mu", "tau"],
"rope": {"mu": [{"rope": (-0.1, 0.1)}], "tau": [{"rope": (0.2, 0.5)}]},
"var_names": ["mu", "theta"],
"rope": {
"mu": [{"rope": (-0.1, 0.1)}],
"theta": [{"school": "Choate", "rope": (0.2, 0.5)}],
},
},
1,
),
Expand Down