diff --git a/CHANGELOG.md b/CHANGELOG.md index 223cef715f..76476dcbeb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,7 +3,7 @@ ## v0.x.x Unreleased ### New features * Added InferenceData dataset containing circular variables (#1265) - +* Added `is_circular` argument to `plot_dist` and `plot_kde` allowing for a circular histogram (Matplotlib, Bokeh) or 1D KDE plot (Matplotlib). (#1266) ### Maintenance and fixes * plot_posterior: fix overlap of hdi and rope (#1263) @@ -73,7 +73,7 @@ ### New features * Stats and plotting functions that provide `var_names` arg can now filter parameters based on partial naming (`filter="like"`) or regular expressions (`filter="regex"`) (see [#1154](https://github.com/arviz-devs/arviz/pull/1154)). -* Add `true_values` argument for `plot_pair`. It allows for a scatter plot showing the true values of the variables #1140 +* Add `true_values` argument for `plot_pair`. It allows for a scatter plot showing the true values of the variables (#1140) * Allow xarray.Dataarray input for plots.(#1120) * Revamped the `hpd` function to make it work with mutidimensional arrays, InferenceData and xarray objects (#1117) * Skip test for optional/extra dependencies when not installed (#1113) diff --git a/arviz/plots/backends/bokeh/distplot.py b/arviz/plots/backends/bokeh/distplot.py index e2826a0638..8b6dc82482 100644 --- a/arviz/plots/backends/bokeh/distplot.py +++ b/arviz/plots/backends/bokeh/distplot.py @@ -6,6 +6,7 @@ from . import backend_kwarg_defaults from .. import show_layout from ...kdeplot import plot_kde +from ...plot_utils import set_bokeh_circular_ticks_labels from ....numeric_utils import get_bins @@ -29,6 +30,7 @@ def plot_dist( contourf_kwargs, pcolormesh_kwargs, hist_kwargs, + is_circular, ax, backend_kwargs, show, @@ -43,14 +45,22 @@ def plot_dist( **backend_kwargs, } if ax is None: - ax = bkp.figure(**backend_kwargs) + if is_circular: + ax = bkp.figure(x_axis_type=None, y_axis_type=None) + else: + ax = bkp.figure(**backend_kwargs) if kind == "auto": kind = "hist" if values.dtype.kind == "i" else "kde" if kind == "hist": _histplot_bokeh_op( - values=values, values2=values2, rotated=rotated, ax=ax, hist_kwargs=hist_kwargs + values=values, + values2=values2, + rotated=rotated, + ax=ax, + hist_kwargs=hist_kwargs, + is_circular=is_circular, ) elif kind == "kde": if plot_kwargs is None: @@ -91,7 +101,7 @@ def plot_dist( return ax -def _histplot_bokeh_op(values, values2, rotated, ax, hist_kwargs): +def _histplot_bokeh_op(values, values2, rotated, ax, hist_kwargs, is_circular): """Add a histogram for the data to the axes.""" if values2 is not None: raise NotImplementedError("Insert hexbin plot here") @@ -105,6 +115,8 @@ def _histplot_bokeh_op(values, values2, rotated, ax, hist_kwargs): hist_kwargs["fill_color"] = color hist_kwargs["line_color"] = color + hist_kwargs.setdefault("line_alpha", 0) + # remove defaults for mpl hist_kwargs.pop("rwidth", None) hist_kwargs.pop("align", None) @@ -119,8 +131,45 @@ def _histplot_bokeh_op(values, values2, rotated, ax, hist_kwargs): if hist_kwargs.pop("cumulative", False): hist = np.cumsum(hist) hist /= hist[-1] - if rotated: - ax.quad(top=edges[:-1], bottom=edges[1:], left=0, right=hist, **hist_kwargs) + + if is_circular: + + if is_circular == "degrees": + edges = np.deg2rad(edges) + labels = ["0°", "45°", "90°", "135°", "180°", "225°", "270°", "315°"] + else: + + labels = [ + r"0", + r"π/4", + r"π/2", + r"3π/4", + r"π", + r"5π/4", + r"3π/2", + r"7π/4", + ] + + delta = np.mean(np.diff(edges) / 2) + + ax.annular_wedge( + x=0, + y=0, + inner_radius=0, + outer_radius=hist, + start_angle=edges[1:] - delta, + end_angle=edges[:-1] - delta, + direction="clock", + **hist_kwargs, + ) + + ax = set_bokeh_circular_ticks_labels(ax, hist, labels) + else: - ax.quad(top=hist, bottom=0, left=edges[:-1], right=edges[1:], **hist_kwargs) + + if rotated: + ax.quad(top=edges[:-1], bottom=edges[1:], left=0, right=hist, **hist_kwargs) + else: + ax.quad(top=hist, bottom=0, left=edges[:-1], right=edges[1:], **hist_kwargs) + return ax diff --git a/arviz/plots/backends/bokeh/energyplot.py b/arviz/plots/backends/bokeh/energyplot.py index c7f936194e..6cef8465ff 100644 --- a/arviz/plots/backends/bokeh/energyplot.py +++ b/arviz/plots/backends/bokeh/energyplot.py @@ -72,7 +72,12 @@ def plot_energy( hist_kwargs["line_color"] = None hist_kwargs["line_alpha"] = alpha _histplot_bokeh_op( - value.flatten(), values2=None, rotated=False, ax=ax, hist_kwargs=hist_kwargs + value.flatten(), + values2=None, + rotated=False, + ax=ax, + hist_kwargs=hist_kwargs, + is_circular=False, ) else: diff --git a/arviz/plots/backends/matplotlib/distplot.py b/arviz/plots/backends/matplotlib/distplot.py index f71c0110ba..bf5ee25abf 100644 --- a/arviz/plots/backends/matplotlib/distplot.py +++ b/arviz/plots/backends/matplotlib/distplot.py @@ -29,6 +29,7 @@ def plot_dist( contourf_kwargs, pcolormesh_kwargs, hist_kwargs, + is_circular, ax, backend_kwargs, show, @@ -43,11 +44,16 @@ def plot_dist( ) backend_kwargs = None if ax is None: - ax = plt.gca() + ax = plt.gca(polar=is_circular) if kind == "hist": ax = _histplot_mpl_op( - values=values, values2=values2, rotated=rotated, ax=ax, hist_kwargs=hist_kwargs + values=values, + values2=values2, + rotated=rotated, + ax=ax, + hist_kwargs=hist_kwargs, + is_circular=is_circular, ) elif kind == "kde": @@ -77,6 +83,7 @@ def plot_dist( ax=ax, backend="matplotlib", backend_kwargs=backend_kwargs, + is_circular=is_circular, show=show, ) @@ -86,20 +93,48 @@ def plot_dist( return ax -def _histplot_mpl_op(values, values2, rotated, ax, hist_kwargs): +def _histplot_mpl_op(values, values2, rotated, ax, hist_kwargs, is_circular): """Add a histogram for the data to the axes.""" + bins = hist_kwargs.pop("bins", None) + + if is_circular == "degrees": + if bins is None: + bins = get_bins(values) + values = np.deg2rad(values) + bins = np.deg2rad(bins) + + elif is_circular: + labels = [ + r"0", + r"π/4", + r"π/2", + r"3π/4", + r"π", + r"5π/4", + r"3π/2", + r"7π/4", + ] + + ax.set_xticklabels(labels) + if values2 is not None: raise NotImplementedError("Insert hexbin plot here") - bins = hist_kwargs.pop("bins") if bins is None: bins = get_bins(values) - ax.hist(np.asarray(values).flatten(), bins=bins, **hist_kwargs) + + n, _, _ = ax.hist(np.asarray(values).flatten(), bins=bins, **hist_kwargs) if rotated: ax.set_yticks(bins[:-1]) - else: + elif not is_circular: ax.set_xticks(bins[:-1]) + + if is_circular: + ax.set_ylim(0, 1.5 * n.max()) + ax.set_yticklabels([]) + if hist_kwargs.get("label") is not None: ax.legend() + return ax diff --git a/arviz/plots/backends/matplotlib/kdeplot.py b/arviz/plots/backends/matplotlib/kdeplot.py index 9acd045c00..f8a9aa9a87 100644 --- a/arviz/plots/backends/matplotlib/kdeplot.py +++ b/arviz/plots/backends/matplotlib/kdeplot.py @@ -32,6 +32,7 @@ def plot_kde( contour_kwargs, contourf_kwargs, pcolormesh_kwargs, + is_circular, ax, legend, backend_kwargs, @@ -76,7 +77,27 @@ def plot_kde( rug_space = max(density) * rug_kwargs.pop("space") - x = np.linspace(lower, upper, len(density)) + if is_circular: + + if is_circular == "radians": + labels = [ + r"0", + r"π/4", + r"π/2", + r"3π/4", + r"π", + r"5π/4", + r"3π/2", + r"7π/4", + ] + + ax.set_xticklabels(labels) + + x = np.linspace(-np.pi, np.pi, len(density)) + ax.set_yticklabels([]) + + else: + x = np.linspace(lower, upper, len(density)) fill_func = ax.fill_between fill_x, fill_y = x, density diff --git a/arviz/plots/distplot.py b/arviz/plots/distplot.py index 024123acd1..fbe3a5e7a8 100644 --- a/arviz/plots/distplot.py +++ b/arviz/plots/distplot.py @@ -3,7 +3,6 @@ import xarray as xr from .plot_utils import get_plotting_function, matplotlib_kwarg_dealiaser -from ..numeric_utils import get_bins from ..data import InferenceData from ..rcparams import rcParams @@ -29,6 +28,7 @@ def plot_dist( contourf_kwargs=None, pcolormesh_kwargs=None, hist_kwargs=None, + is_circular=False, ax=None, backend=None, backend_kwargs=None, @@ -90,6 +90,9 @@ def plot_dist( Keywords passed to ax.pcolormesh. Ignored for 1D KDE. hist_kwargs : dict Keywords passed to the histogram. + is_circular : {False, True, "radians", "degrees"}. Default False. + Select input type {"radians", "degrees"} for circular histogram or KDE plot. If True, + default input type is "radians". ax: axes, optional Matplotlib axes or bokeh figures. backend: str, optional @@ -160,7 +163,6 @@ def plot_dist( if kind == "hist": hist_kwargs = matplotlib_kwarg_dealiaser(hist_kwargs, "hist") - hist_kwargs.setdefault("bins", get_bins(values)) hist_kwargs.setdefault("cumulative", cumulative) hist_kwargs.setdefault("color", color) hist_kwargs.setdefault("label", label) @@ -197,6 +199,7 @@ def plot_dist( hist_kwargs=hist_kwargs, ax=ax, backend_kwargs=backend_kwargs, + is_circular=is_circular, show=show, **kwargs, ) diff --git a/arviz/plots/kdeplot.py b/arviz/plots/kdeplot.py index 99f0b07950..5b34fdf876 100644 --- a/arviz/plots/kdeplot.py +++ b/arviz/plots/kdeplot.py @@ -26,6 +26,7 @@ def plot_kde( contour_kwargs=None, contourf_kwargs=None, pcolormesh_kwargs=None, + is_circular=False, ax=None, legend=True, backend=None, @@ -81,6 +82,9 @@ def plot_kde( Keywords passed to ax.contourf to draw filled contours. Ignored for 1D KDE. pcolormesh_kwargs : dict Keywords passed to ax.pcolormesh. Ignored for 1D KDE. + is_circular : {False, True, "radians", "degrees"}. Default False. + Select input type {"radians", "degrees"} for circular histogram or KDE plot. If True, + default input type is "radians". ax: axes, optional Matplotlib axes or bokeh figures. legend : bool @@ -230,6 +234,7 @@ def plot_kde( contour_kwargs=contour_kwargs, contourf_kwargs=contourf_kwargs, pcolormesh_kwargs=pcolormesh_kwargs, + is_circular=is_circular, ax=ax, legend=legend, backend_kwargs=backend_kwargs, @@ -246,6 +251,7 @@ def plot_kde( kde_plot_args.pop("textsize") kde_plot_args.pop("label") kde_plot_args.pop("legend") + kde_plot_args.pop("is_circular") else: kde_plot_args.pop("return_glyph") diff --git a/arviz/plots/plot_utils.py b/arviz/plots/plot_utils.py index 98268df100..d2429f001a 100644 --- a/arviz/plots/plot_utils.py +++ b/arviz/plots/plot_utils.py @@ -717,3 +717,51 @@ def sample_reference_distribution(dist, shape): x_ss.append(x_s) densities.append(density) return np.array(x_ss).T, np.array(densities).T + + +def set_bokeh_circular_ticks_labels(ax, hist, labels): + """Place ticks and ticklabels on Bokeh's circular histogram.""" + ticks = np.linspace(-np.pi, np.pi, len(labels), endpoint=False) + ax.annular_wedge( + x=0, + y=0, + inner_radius=0, + outer_radius=np.max(hist) * 1.1, + start_angle=ticks, + end_angle=ticks, + line_color="grey", + ) + + radii_circles = np.linspace(0, np.max(hist) * 1.1, 4) + ax.circle(0, 0, radius=radii_circles, fill_color=None, line_color="grey") + + offset = np.max(hist * 1.05) * 0.15 + ticks_labels_pos_1 = np.max(hist * 1.05) + ticks_labels_pos_2 = ticks_labels_pos_1 * np.sqrt(2) / 2 + + ax.text( + [ + ticks_labels_pos_1 + offset, + ticks_labels_pos_2 + offset, + 0, + -ticks_labels_pos_2 - offset, + -ticks_labels_pos_1 - offset, + -ticks_labels_pos_2 - offset, + 0, + ticks_labels_pos_2 + offset, + ], + [ + 0, + ticks_labels_pos_2 + offset / 2, + ticks_labels_pos_1 + offset, + ticks_labels_pos_2 + offset / 2, + 0, + -ticks_labels_pos_2 - offset, + -ticks_labels_pos_1 - offset, + -ticks_labels_pos_2 - offset, + ], + text=labels, + text_align="center", + ) + + return ax diff --git a/arviz/tests/base_tests/test_plot_utils.py b/arviz/tests/base_tests/test_plot_utils.py index 3c3777001a..162b53de28 100644 --- a/arviz/tests/base_tests/test_plot_utils.py +++ b/arviz/tests/base_tests/test_plot_utils.py @@ -17,6 +17,7 @@ xarray_var_iter, vectorized_to_hex, _dealiase_sel_kwargs, + set_bokeh_circular_ticks_labels, ) from ...rcparams import rc_context from ...numeric_utils import get_bins @@ -264,3 +265,24 @@ def test_dealiase_sel_kwargs(): assert res["alpha"] == 0.4 assert "line_color" in res assert res["line_color"] == "red" + + +# Check if Bokeh is installed +bokeh_installed = importlib.util.find_spec("bokeh") is not None # pylint: disable=invalid-name + + +@pytest.mark.skipif( + not (bokeh_installed | running_on_ci()), reason="test requires bokeh which is not installed", +) +def test_set_bokeh_circular_ticks_labels(): + """Assert the axes returned after placing ticks and tick labels for circular plots.""" + import bokeh.plotting as bkp + + ax = bkp.figure(x_axis_type=None, y_axis_type=None) + hist = np.linspace(0, 1, 10) + labels = ["0°", "45°", "90°", "135°", "180°", "225°", "270°", "315°"] + ax = set_bokeh_circular_ticks_labels(ax, hist, labels) + renderers = ax.renderers + assert len(renderers) == 3 + assert renderers[2].data_source.data["text"] == labels + assert len(renderers[0].data_source.data["start_angle"]) == len(labels) diff --git a/arviz/tests/base_tests/test_plots_bokeh.py b/arviz/tests/base_tests/test_plots_bokeh.py index d2a2ab0732..b90838f4ec 100644 --- a/arviz/tests/base_tests/test_plots_bokeh.py +++ b/arviz/tests/base_tests/test_plots_bokeh.py @@ -191,7 +191,21 @@ def test_plot_kde_cumulative(continuous_model, kwargs): assert axes -@pytest.mark.parametrize("kwargs", [{"kind": "hist"}, {"kind": "kde"}]) +@pytest.mark.parametrize( + "kwargs", + [ + {"kind": "hist"}, + {"kind": "kde"}, + {"is_circular": False}, + {"is_circular": False, "kind": "hist"}, + {"is_circular": True}, + {"is_circular": True, "kind": "hist"}, + {"is_circular": "radians"}, + {"is_circular": "radians", "kind": "hist"}, + {"is_circular": "degrees"}, + {"is_circular": "degrees", "kind": "hist"}, + ], +) def test_plot_dist(continuous_model, kwargs): axes = plot_dist(continuous_model["x"], backend="bokeh", show=False, **kwargs) assert axes diff --git a/arviz/tests/base_tests/test_plots_matplotlib.py b/arviz/tests/base_tests/test_plots_matplotlib.py index 62d7d58aac..9d7df33e97 100644 --- a/arviz/tests/base_tests/test_plots_matplotlib.py +++ b/arviz/tests/base_tests/test_plots_matplotlib.py @@ -339,6 +339,10 @@ def test_plot_joint_bad(models): }, {"contour": False}, {"contour": False, "pcolormesh_kwargs": {"cmap": "plasma"}}, + {"is_circular": False}, + {"is_circular": True}, + {"is_circular": "radians"}, + {"is_circular": "degrees"}, ], ) def test_plot_kde(continuous_model, kwargs): @@ -370,7 +374,21 @@ def test_plot_kde_cumulative(continuous_model, kwargs): assert axes -@pytest.mark.parametrize("kwargs", [{"kind": "hist"}, {"kind": "kde"}]) +@pytest.mark.parametrize( + "kwargs", + [ + {"kind": "hist"}, + {"kind": "kde"}, + {"is_circular": False}, + {"is_circular": False, "kind": "hist"}, + {"is_circular": True}, + {"is_circular": True, "kind": "hist"}, + {"is_circular": "radians"}, + {"is_circular": "radians", "kind": "hist"}, + {"is_circular": "degrees"}, + {"is_circular": "degrees", "kind": "hist"}, + ], +) def test_plot_dist(continuous_model, kwargs): axes = plot_dist(continuous_model["x"], **kwargs) assert axes