From d966911f145899ef2865183f4ed832eddc5ac6e2 Mon Sep 17 00:00:00 2001 From: Ari Hartikainen <ahartikainen@users.noreply.github.com> Date: Thu, 5 Mar 2020 12:28:36 +0200 Subject: [PATCH 1/7] [WIP] Modify pairplot to include jointplot and cornerplot like features (#1079) * add jointplot features into pairplot * add scatter_kde kind for pairplot * add point_estimate arguments * bokeh backend * fix None argument for color in kdeplot bokeh backend * run black, pylint and pytest * remove scatter_kde kind among several other changes * minor changes * run pytest * add plot width and height to backend_kwargs fix pylint issues fix hover feature fix hover feature minor fixes * update docstring run pylint * update changelog --- CHANGELOG.md | 1 + arviz/plots/backends/bokeh/pairplot.py | 1 - 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index fb08719949..222f1a2cdc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,6 +19,7 @@ * New grayscale style. This also add two new cmaps `cet_grey_r` and `cet_grey_r`. These are perceptually uniform gray scale cmaps from colorcet (linear_grey_10_95_c0) (#1164) * Add warmup groups to InferenceData objects, initial support for PyStan (#1126) and PyMC3 (#1171) +* Integrate jointplot into pairplot, add point-estimate and overlay of plot kinds #1079 ### Maintenance and fixes * Changed `diagonal` argument for `marginals` and fixed `point_estimate_marker_kwargs` in `plot_pair` (#1167) * Fixed behaviour of `credible_interval=None` in `plot_posterior` (#1115) diff --git a/arviz/plots/backends/bokeh/pairplot.py b/arviz/plots/backends/bokeh/pairplot.py index 3daa81f7d8..b936015464 100644 --- a/arviz/plots/backends/bokeh/pairplot.py +++ b/arviz/plots/backends/bokeh/pairplot.py @@ -293,7 +293,6 @@ def get_width_and_height(jointplot, rotate): y = reference_values_copy[flat_var_names[i]] if x and y: ax[j, i].circle(y, x, **reference_values_kwargs) - ax[j, i].xaxis.axis_label = flat_var_names[i] ax[j, i].yaxis.axis_label = flat_var_names[j + var] From 090e7cf2e444bcbde898a75d652c945eecbeeca8 Mon Sep 17 00:00:00 2001 From: agustinaarroyuelo <agustinaarroyuelo@gmail.com> Date: Sun, 10 May 2020 15:40:15 -0300 Subject: [PATCH 2/7] fix jointplot with bokeh --- arviz/plots/backends/bokeh/pairplot.py | 44 +++++++++++++++----------- 1 file changed, 25 insertions(+), 19 deletions(-) diff --git a/arviz/plots/backends/bokeh/pairplot.py b/arviz/plots/backends/bokeh/pairplot.py index b936015464..517d2a57d7 100644 --- a/arviz/plots/backends/bokeh/pairplot.py +++ b/arviz/plots/backends/bokeh/pairplot.py @@ -105,7 +105,13 @@ def plot_pair( ) numvars = vars_to_plot - (figsize, _, _, _, _, markersize) = _scale_fig_size(figsize, textsize, numvars - 2, numvars - 2) + if numvars == 2: + offset = 1 + else: + offset = 2 + (figsize, _, _, _, _, markersize) = _scale_fig_size( + figsize, textsize, numvars - offset, numvars - offset + ) point_estimate_marker_kwargs.setdefault("line_width", markersize) point_estimate_kwargs.setdefault("line_color", "orange") @@ -147,22 +153,22 @@ def get_width_and_height(jointplot, rotate): return width, height if marginals: - var = 0 + marginals_offset = 0 else: - var = 1 + marginals_offset = 1 if ax is None: ax = [] backend_kwargs.setdefault("width", int(figsize[0] / (numvars - 1) * dpi)) backend_kwargs.setdefault("height", int(figsize[1] / (numvars - 1) * dpi)) - for row in range(numvars - var): + for row in range(numvars - marginals_offset): row_ax = [] var1 = ( - flat_var_names[row + var] + flat_var_names[row + marginals_offset] if tmp_flat_var_names is None - else tmp_flat_var_names[row + var] + else tmp_flat_var_names[row + marginals_offset] ) - for n, col in enumerate(range(numvars - var)): + for col in range(numvars - marginals_offset): var2 = ( flat_var_names[col] if tmp_flat_var_names is None else tmp_flat_var_names[col] ) @@ -179,7 +185,7 @@ def get_width_and_height(jointplot, rotate): row_ax.append(None) else: jointplot = row == col and numvars == 2 and marginals - rotate = n == 1 + rotate = col == 1 width, height = get_width_and_height(jointplot, rotate) if jointplot: ax_ = bkp.figure(width=width, height=height, tooltips=tooltips) @@ -189,19 +195,19 @@ def get_width_and_height(jointplot, rotate): ax.append(row_ax) ax = np.array(ax) else: - assert ax.shape == (numvars - var, numvars - var) + assert ax.shape == (numvars - marginals_offset, numvars - marginals_offset) # pylint: disable=too-many-nested-blocks - for i in range(0, numvars - var): + for i in range(0, numvars - marginals_offset): var1 = flat_var_names[i] if tmp_flat_var_names is None else tmp_flat_var_names[i] - for j in range(0, numvars - var): + for j in range(0, numvars - marginals_offset): var2 = ( - flat_var_names[j + var] + flat_var_names[j + marginals_offset] if tmp_flat_var_names is None - else tmp_flat_var_names[j + var] + else tmp_flat_var_names[j + marginals_offset] ) if j == i and marginals: @@ -217,9 +223,9 @@ def get_width_and_height(jointplot, rotate): ) ax[j, i].xaxis.axis_label = flat_var_names[i] - ax[j, i].yaxis.axis_label = flat_var_names[j + var] + ax[j, i].yaxis.axis_label = flat_var_names[j + marginals_offset] - elif j + var > i: + elif j + marginals_offset > i: if "scatter" in kind: if divergences: @@ -229,7 +235,7 @@ def get_width_and_height(jointplot, rotate): if "kde" in kind: var1_kde = infdata_group[i] - var2_kde = infdata_group[j + var] + var2_kde = infdata_group[j + marginals_offset] plot_kde( var1_kde, var2_kde, @@ -242,7 +248,7 @@ def get_width_and_height(jointplot, rotate): if "hexbin" in kind: var1_hexbin = infdata_group[i] - var2_hexbin = infdata_group[j + var] + var2_hexbin = infdata_group[j + marginals_offset] ax[j, i].grid.visible = False ax[j, i].hexbin( var1_hexbin, var2_hexbin, **hexbin_kwargs, @@ -289,12 +295,12 @@ def get_width_and_height(jointplot, rotate): ax[-1, -1].add_layout(ax_pe_hline) if reference_values: - x = reference_values_copy[flat_var_names[j + var]] + x = reference_values_copy[flat_var_names[j + marginals_offset]] y = reference_values_copy[flat_var_names[i]] if x and y: ax[j, i].circle(y, x, **reference_values_kwargs) ax[j, i].xaxis.axis_label = flat_var_names[i] - ax[j, i].yaxis.axis_label = flat_var_names[j + var] + ax[j, i].yaxis.axis_label = flat_var_names[j + marginals_offset] show_layout(ax, show) From 86d09bba75dc7bb65f305de4ea5b416594f8b187 Mon Sep 17 00:00:00 2001 From: amukh18 <45681148+amukh18@users.noreply.github.com> Date: Mon, 11 May 2020 04:15:12 +0900 Subject: [PATCH 3/7] To transform colors in plotting functions to hex (#1084) * Transform colors to hex in plot_khat * Reformatted with black * Added hex conversion to khatplot.py and vectorized_to_hex to plot_utils.py * Black changes * Added keep_alpha parameter and list comprehension * Pydocstyle changes * More pydocstyle changes * generalised vectorized_to_hex * Black changes * Black changes * Modified khatplot hline_kwargs, vectorized_to_hex and test for vectorized_to_hex * Black changes * Rewrote tests and modified vectorized_to_hex * lint and minor fixes Co-authored-by: Oriol (ZBook) <oriol.abril.pla@gmail.com> --- arviz/plots/khatplot.py | 3 +++ arviz/plots/plot_utils.py | 25 ++++++++++++++++++++++- arviz/tests/base_tests/test_plot_utils.py | 15 ++++++++++++++ 3 files changed, 42 insertions(+), 1 deletion(-) diff --git a/arviz/plots/khatplot.py b/arviz/plots/khatplot.py index 1437a55ac2..2df5c48ba6 100644 --- a/arviz/plots/khatplot.py +++ b/arviz/plots/khatplot.py @@ -11,6 +11,7 @@ format_coords_as_labels, get_plotting_function, matplotlib_kwarg_dealiaser, + vectorized_to_hex, ) from ..stats import ELPDData from ..rcparams import rcParams @@ -138,6 +139,7 @@ def plot_khat( hlines_kwargs.setdefault("alpha", 0.7) hlines_kwargs.setdefault("zorder", -1) hlines_kwargs.setdefault("color", "C1") + hlines_kwargs["color"] = vectorized_to_hex(hlines_kwargs["color"]) if coords is None: coords = {} @@ -200,6 +202,7 @@ def plot_khat( khats = khats if isinstance(khats, np.ndarray) else khats.values.flatten() alphas = 0.5 + 0.2 * (khats > 0.5) + 0.3 * (khats > 1) rgba_c[:, 3] = alphas + rgba_c = vectorized_to_hex(rgba_c) plot_khat_kwargs = dict( hover_label=hover_label, diff --git a/arviz/plots/plot_utils.py b/arviz/plots/plot_utils.py index 874dc6a30e..14bc8057f3 100644 --- a/arviz/plots/plot_utils.py +++ b/arviz/plots/plot_utils.py @@ -4,6 +4,7 @@ from itertools import product, tee import importlib from scipy.stats import mode +from matplotlib.colors import to_hex import packaging import numpy as np @@ -468,7 +469,7 @@ def color_from_dim(dataarray, dim_name): ---------- dataarray : xarray.DataArray dim_name : str - dimension whose coordinates will be used as color code. + dimension whose coordinates will be used as color code. Returns ------- @@ -490,6 +491,28 @@ def color_from_dim(dataarray, dim_name): return colors, color_mapping +def vectorized_to_hex(c_values, keep_alpha=False): + """Convert a color (including vector of colors) to hex. + + Parameters + ---------- + c: Matplotlib color + + keep_alpha: boolean + to select if alpha values should be kept in the final hex values. + + Returns + ------- + rgba_hex : vector of hex values + """ + try: + hex_color = to_hex(c_values, keep_alpha) + + except ValueError: + hex_color = [to_hex(color, keep_alpha) for color in c_values] + return hex_color + + def format_coords_as_labels(dataarray, skip_dims=None): """Format 1d or multi-d dataarray coords as strings. diff --git a/arviz/tests/base_tests/test_plot_utils.py b/arviz/tests/base_tests/test_plot_utils.py index 1f81610f1c..e2cf8ec0c6 100644 --- a/arviz/tests/base_tests/test_plot_utils.py +++ b/arviz/tests/base_tests/test_plot_utils.py @@ -15,6 +15,7 @@ matplotlib_kwarg_dealiaser, xarray_to_ndarray, xarray_var_iter, + vectorized_to_hex, ) from ...rcparams import rc_context from ...numeric_utils import get_bins @@ -229,3 +230,17 @@ def test_matplotlib_kwarg_dealiaser(params): dealiased = matplotlib_kwarg_dealiaser(params["input"][0], kind=params["input"][1]) for returned in dealiased: assert returned in params["output"] + + +@pytest.mark.parametrize("c_values", ["#0000ff", "blue", [0, 0, 1]]) +def test_vectorized_to_hex_scalar(c_values): + output = vectorized_to_hex(c_values) + assert output == "#0000ff" + + +@pytest.mark.parametrize( + "c_values", [["blue", "blue"], ["blue", "#0000ff"], np.array([[0, 0, 1], [0, 0, 1]])] +) +def test_vectorized_to_hex_array(c_values): + output = vectorized_to_hex(c_values) + assert np.all([item == "#0000ff" for item in output]) From 0e9e24fe8bcffb9df228583769aa0039a60fdf72 Mon Sep 17 00:00:00 2001 From: Ari Hartikainen <ahartikainen@users.noreply.github.com> Date: Thu, 5 Mar 2020 12:28:36 +0200 Subject: [PATCH 4/7] [WIP] Modify pairplot to include jointplot and cornerplot like features (#1079) * add jointplot features into pairplot * add scatter_kde kind for pairplot * add point_estimate arguments * bokeh backend * fix None argument for color in kdeplot bokeh backend * run black, pylint and pytest * remove scatter_kde kind among several other changes * minor changes * run pytest * add plot width and height to backend_kwargs fix pylint issues fix hover feature fix hover feature minor fixes * update docstring run pylint * update changelog --- CHANGELOG.md | 1 + arviz/plots/backends/bokeh/pairplot.py | 1 - 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index fb08719949..222f1a2cdc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,6 +19,7 @@ * New grayscale style. This also add two new cmaps `cet_grey_r` and `cet_grey_r`. These are perceptually uniform gray scale cmaps from colorcet (linear_grey_10_95_c0) (#1164) * Add warmup groups to InferenceData objects, initial support for PyStan (#1126) and PyMC3 (#1171) +* Integrate jointplot into pairplot, add point-estimate and overlay of plot kinds #1079 ### Maintenance and fixes * Changed `diagonal` argument for `marginals` and fixed `point_estimate_marker_kwargs` in `plot_pair` (#1167) * Fixed behaviour of `credible_interval=None` in `plot_posterior` (#1115) diff --git a/arviz/plots/backends/bokeh/pairplot.py b/arviz/plots/backends/bokeh/pairplot.py index 3daa81f7d8..b936015464 100644 --- a/arviz/plots/backends/bokeh/pairplot.py +++ b/arviz/plots/backends/bokeh/pairplot.py @@ -293,7 +293,6 @@ def get_width_and_height(jointplot, rotate): y = reference_values_copy[flat_var_names[i]] if x and y: ax[j, i].circle(y, x, **reference_values_kwargs) - ax[j, i].xaxis.axis_label = flat_var_names[i] ax[j, i].yaxis.axis_label = flat_var_names[j + var] From 535e165070bf4c10b1e995766ec65b4a25e1e871 Mon Sep 17 00:00:00 2001 From: agustinaarroyuelo <agustinaarroyuelo@gmail.com> Date: Sun, 10 May 2020 15:40:15 -0300 Subject: [PATCH 5/7] fix jointplot with bokeh --- arviz/plots/backends/bokeh/pairplot.py | 44 +++++++++++++++----------- 1 file changed, 25 insertions(+), 19 deletions(-) diff --git a/arviz/plots/backends/bokeh/pairplot.py b/arviz/plots/backends/bokeh/pairplot.py index b936015464..517d2a57d7 100644 --- a/arviz/plots/backends/bokeh/pairplot.py +++ b/arviz/plots/backends/bokeh/pairplot.py @@ -105,7 +105,13 @@ def plot_pair( ) numvars = vars_to_plot - (figsize, _, _, _, _, markersize) = _scale_fig_size(figsize, textsize, numvars - 2, numvars - 2) + if numvars == 2: + offset = 1 + else: + offset = 2 + (figsize, _, _, _, _, markersize) = _scale_fig_size( + figsize, textsize, numvars - offset, numvars - offset + ) point_estimate_marker_kwargs.setdefault("line_width", markersize) point_estimate_kwargs.setdefault("line_color", "orange") @@ -147,22 +153,22 @@ def get_width_and_height(jointplot, rotate): return width, height if marginals: - var = 0 + marginals_offset = 0 else: - var = 1 + marginals_offset = 1 if ax is None: ax = [] backend_kwargs.setdefault("width", int(figsize[0] / (numvars - 1) * dpi)) backend_kwargs.setdefault("height", int(figsize[1] / (numvars - 1) * dpi)) - for row in range(numvars - var): + for row in range(numvars - marginals_offset): row_ax = [] var1 = ( - flat_var_names[row + var] + flat_var_names[row + marginals_offset] if tmp_flat_var_names is None - else tmp_flat_var_names[row + var] + else tmp_flat_var_names[row + marginals_offset] ) - for n, col in enumerate(range(numvars - var)): + for col in range(numvars - marginals_offset): var2 = ( flat_var_names[col] if tmp_flat_var_names is None else tmp_flat_var_names[col] ) @@ -179,7 +185,7 @@ def get_width_and_height(jointplot, rotate): row_ax.append(None) else: jointplot = row == col and numvars == 2 and marginals - rotate = n == 1 + rotate = col == 1 width, height = get_width_and_height(jointplot, rotate) if jointplot: ax_ = bkp.figure(width=width, height=height, tooltips=tooltips) @@ -189,19 +195,19 @@ def get_width_and_height(jointplot, rotate): ax.append(row_ax) ax = np.array(ax) else: - assert ax.shape == (numvars - var, numvars - var) + assert ax.shape == (numvars - marginals_offset, numvars - marginals_offset) # pylint: disable=too-many-nested-blocks - for i in range(0, numvars - var): + for i in range(0, numvars - marginals_offset): var1 = flat_var_names[i] if tmp_flat_var_names is None else tmp_flat_var_names[i] - for j in range(0, numvars - var): + for j in range(0, numvars - marginals_offset): var2 = ( - flat_var_names[j + var] + flat_var_names[j + marginals_offset] if tmp_flat_var_names is None - else tmp_flat_var_names[j + var] + else tmp_flat_var_names[j + marginals_offset] ) if j == i and marginals: @@ -217,9 +223,9 @@ def get_width_and_height(jointplot, rotate): ) ax[j, i].xaxis.axis_label = flat_var_names[i] - ax[j, i].yaxis.axis_label = flat_var_names[j + var] + ax[j, i].yaxis.axis_label = flat_var_names[j + marginals_offset] - elif j + var > i: + elif j + marginals_offset > i: if "scatter" in kind: if divergences: @@ -229,7 +235,7 @@ def get_width_and_height(jointplot, rotate): if "kde" in kind: var1_kde = infdata_group[i] - var2_kde = infdata_group[j + var] + var2_kde = infdata_group[j + marginals_offset] plot_kde( var1_kde, var2_kde, @@ -242,7 +248,7 @@ def get_width_and_height(jointplot, rotate): if "hexbin" in kind: var1_hexbin = infdata_group[i] - var2_hexbin = infdata_group[j + var] + var2_hexbin = infdata_group[j + marginals_offset] ax[j, i].grid.visible = False ax[j, i].hexbin( var1_hexbin, var2_hexbin, **hexbin_kwargs, @@ -289,12 +295,12 @@ def get_width_and_height(jointplot, rotate): ax[-1, -1].add_layout(ax_pe_hline) if reference_values: - x = reference_values_copy[flat_var_names[j + var]] + x = reference_values_copy[flat_var_names[j + marginals_offset]] y = reference_values_copy[flat_var_names[i]] if x and y: ax[j, i].circle(y, x, **reference_values_kwargs) ax[j, i].xaxis.axis_label = flat_var_names[i] - ax[j, i].yaxis.axis_label = flat_var_names[j + var] + ax[j, i].yaxis.axis_label = flat_var_names[j + marginals_offset] show_layout(ax, show) From 9280aa2cd7545eeb64b828fa0b629e47f5c8a08b Mon Sep 17 00:00:00 2001 From: agustinaarroyuelo <agustinaarroyuelo@gmail.com> Date: Mon, 11 May 2020 09:04:25 -0300 Subject: [PATCH 6/7] update changelog --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 222f1a2cdc..3da39c02d6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,8 +19,8 @@ * New grayscale style. This also add two new cmaps `cet_grey_r` and `cet_grey_r`. These are perceptually uniform gray scale cmaps from colorcet (linear_grey_10_95_c0) (#1164) * Add warmup groups to InferenceData objects, initial support for PyStan (#1126) and PyMC3 (#1171) -* Integrate jointplot into pairplot, add point-estimate and overlay of plot kinds #1079 ### Maintenance and fixes +* Fixed `plot_pair` functionality for two variables with bokeh backend (#1179) * Changed `diagonal` argument for `marginals` and fixed `point_estimate_marker_kwargs` in `plot_pair` (#1167) * Fixed behaviour of `credible_interval=None` in `plot_posterior` (#1115) * Fixed hist kind of `plot_dist` with multidimensional input (#1115) From 2ac10ef9615fc99897f76ffdcbc86d8d1f8a188f Mon Sep 17 00:00:00 2001 From: agustinaarroyuelo <agustinaarroyuelo@gmail.com> Date: Mon, 11 May 2020 09:16:47 -0300 Subject: [PATCH 7/7] fixed changelog --- CHANGELOG.md | 2 -- 1 file changed, 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 667dd80d23..4d4c280471 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,7 +5,6 @@ ### New features * Stats and plotting functions that provide `var_names` arg can now filter parameters based on partial naming (`filter="like"`) or regular expressions (`filter="regex"`) (see [#1154](https://github.com/arviz-devs/arviz/pull/1154)). * Add `true_values` argument for `plot_pair`. It allows for a scatter plot showing the true values of the variables #1140 -* Integrate jointplot into pairplot, add point-estimate and overlay of plot kinds #1079 * Add out-of-sample groups (`predictions` and `predictions_constant_data`) and `constant_data` group to pyro translation #1090 * Add `num_chains` and `pred_dims` arguments to io_pyro #1090 * Allow xarray.Dataarray input for plots.(#1120) @@ -19,7 +18,6 @@ * New grayscale style. This also add two new cmaps `cet_grey_r` and `cet_grey_r`. These are perceptually uniform gray scale cmaps from colorcet (linear_grey_10_95_c0) (#1164) * Add warmup groups to InferenceData objects, initial support for PyStan (#1126) and PyMC3 (#1171) -* Integrate jointplot into pairplot, add point-estimate and overlay of plot kinds #1079 ### Maintenance and fixes * Fixed `plot_pair` functionality for two variables with bokeh backend (#1179) * Changed `diagonal` argument for `marginals` and fixed `point_estimate_marker_kwargs` in `plot_pair` (#1167)