Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pairplot of two variables with bokeh #1179

Merged
merged 8 commits into from
May 11, 2020
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
* New grayscale style. This also add two new cmaps `cet_grey_r` and `cet_grey_r`. These are perceptually uniform gray scale cmaps from colorcet (linear_grey_10_95_c0) (#1164)
* Add warmup groups to InferenceData objects, initial support for PyStan (#1126) and PyMC3 (#1171)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Being picky here. Looks like theres a blank line that can be removed

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this changelog item correct? It feels like this PR is internally facing readability changes (which is great) but the feature described here ie something else?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It was a duplicated line. Thank you for your helpful comment

* Integrate jointplot into pairplot, add point-estimate and overlay of plot kinds #1079
OriolAbril marked this conversation as resolved.
Show resolved Hide resolved
### Maintenance and fixes
* Changed `diagonal` argument for `marginals` and fixed `point_estimate_marker_kwargs` in `plot_pair` (#1167)
* Fixed behaviour of `credible_interval=None` in `plot_posterior` (#1115)
Expand Down
45 changes: 25 additions & 20 deletions arviz/plots/backends/bokeh/pairplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,13 @@ def plot_pair(
)
numvars = vars_to_plot

(figsize, _, _, _, _, markersize) = _scale_fig_size(figsize, textsize, numvars - 2, numvars - 2)
if numvars == 2:
offset = 1
else:
offset = 2
(figsize, _, _, _, _, markersize) = _scale_fig_size(
figsize, textsize, numvars - offset, numvars - offset
)

point_estimate_marker_kwargs.setdefault("line_width", markersize)
point_estimate_kwargs.setdefault("line_color", "orange")
Expand Down Expand Up @@ -147,22 +153,22 @@ def get_width_and_height(jointplot, rotate):
return width, height

if marginals:
var = 0
marginals_offset = 0
else:
var = 1
marginals_offset = 1

if ax is None:
ax = []
backend_kwargs.setdefault("width", int(figsize[0] / (numvars - 1) * dpi))
backend_kwargs.setdefault("height", int(figsize[1] / (numvars - 1) * dpi))
for row in range(numvars - var):
for row in range(numvars - marginals_offset):
row_ax = []
var1 = (
flat_var_names[row + var]
flat_var_names[row + marginals_offset]
if tmp_flat_var_names is None
else tmp_flat_var_names[row + var]
else tmp_flat_var_names[row + marginals_offset]
)
for n, col in enumerate(range(numvars - var)):
for col in range(numvars - marginals_offset):
var2 = (
flat_var_names[col] if tmp_flat_var_names is None else tmp_flat_var_names[col]
)
Expand All @@ -179,7 +185,7 @@ def get_width_and_height(jointplot, rotate):
row_ax.append(None)
else:
jointplot = row == col and numvars == 2 and marginals
rotate = n == 1
rotate = col == 1
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, much easier to read now

width, height = get_width_and_height(jointplot, rotate)
if jointplot:
ax_ = bkp.figure(width=width, height=height, tooltips=tooltips)
Expand All @@ -189,19 +195,19 @@ def get_width_and_height(jointplot, rotate):
ax.append(row_ax)
ax = np.array(ax)
else:
assert ax.shape == (numvars - var, numvars - var)
assert ax.shape == (numvars - marginals_offset, numvars - marginals_offset)

# pylint: disable=too-many-nested-blocks
for i in range(0, numvars - var):
for i in range(0, numvars - marginals_offset):

var1 = flat_var_names[i] if tmp_flat_var_names is None else tmp_flat_var_names[i]

for j in range(0, numvars - var):
for j in range(0, numvars - marginals_offset):

var2 = (
flat_var_names[j + var]
flat_var_names[j + marginals_offset]
if tmp_flat_var_names is None
else tmp_flat_var_names[j + var]
else tmp_flat_var_names[j + marginals_offset]
)

if j == i and marginals:
Expand All @@ -217,9 +223,9 @@ def get_width_and_height(jointplot, rotate):
)

ax[j, i].xaxis.axis_label = flat_var_names[i]
ax[j, i].yaxis.axis_label = flat_var_names[j + var]
ax[j, i].yaxis.axis_label = flat_var_names[j + marginals_offset]

elif j + var > i:
elif j + marginals_offset > i:

if "scatter" in kind:
if divergences:
Expand All @@ -229,7 +235,7 @@ def get_width_and_height(jointplot, rotate):

if "kde" in kind:
var1_kde = infdata_group[i]
var2_kde = infdata_group[j + var]
var2_kde = infdata_group[j + marginals_offset]
plot_kde(
var1_kde,
var2_kde,
Expand All @@ -242,7 +248,7 @@ def get_width_and_height(jointplot, rotate):

if "hexbin" in kind:
var1_hexbin = infdata_group[i]
var2_hexbin = infdata_group[j + var]
var2_hexbin = infdata_group[j + marginals_offset]
ax[j, i].grid.visible = False
ax[j, i].hexbin(
var1_hexbin, var2_hexbin, **hexbin_kwargs,
Expand Down Expand Up @@ -289,13 +295,12 @@ def get_width_and_height(jointplot, rotate):
ax[-1, -1].add_layout(ax_pe_hline)

if reference_values:
x = reference_values_copy[flat_var_names[j + var]]
x = reference_values_copy[flat_var_names[j + marginals_offset]]
y = reference_values_copy[flat_var_names[i]]
if x and y:
ax[j, i].circle(y, x, **reference_values_kwargs)

ax[j, i].xaxis.axis_label = flat_var_names[i]
ax[j, i].yaxis.axis_label = flat_var_names[j + var]
ax[j, i].yaxis.axis_label = flat_var_names[j + marginals_offset]

show_layout(ax, show)

Expand Down