diff --git a/CHANGELOG.md b/CHANGELOG.md index c82f6d9229..c0fbc0134b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,7 @@ * Fix typo in `loo_pit` extraction of log likelihood ([1418](https://github.com/arviz-devs/arviz/pull/1418)) * Have `from_pystan` store attrs as strings to allow netCDF storage ([1417](https://github.com/arviz-devs/arviz/pull/1417)) * Remove ticks and spines in `plot_violin` ([1426 ](https://github.com/arviz-devs/arviz/pull/1426)) +* Use circular KDE function and fix tick labels in circular `plot_trace` ([1428](https://github.com/arviz-devs/arviz/pull/1428)) ### Deprecation diff --git a/arviz/plots/backends/matplotlib/kdeplot.py b/arviz/plots/backends/matplotlib/kdeplot.py index 173421473b..feca641605 100644 --- a/arviz/plots/backends/matplotlib/kdeplot.py +++ b/arviz/plots/backends/matplotlib/kdeplot.py @@ -85,14 +85,14 @@ def plot_kde( if is_circular == "radians": labels = [ - r"0", - r"π/4", - r"π/2", - r"3π/4", - r"π", - r"5π/4", - r"3π/2", - r"7π/4", + "0", + f"{np.pi/4:.2f}", + f"{np.pi/2:.2f}", + f"{3*np.pi/4:.2f}", + f"{np.pi:.2f}", + f"{-3*np.pi/4:.2f}", + f"{-np.pi/2:.2f}", + f"{-np.pi/4:.2f}", ] ax.set_xticklabels(labels) @@ -130,7 +130,7 @@ def plot_kde( fill_x, fill_y, where=np.isin(fill_x, fill_x[idx], invert=True, assume_unique=True), - **fill_kwargs + **fill_kwargs, ) else: fill_kwargs.setdefault("alpha", 0) diff --git a/arviz/plots/backends/matplotlib/traceplot.py b/arviz/plots/backends/matplotlib/traceplot.py index b0591663ab..2855c2d97b 100644 --- a/arviz/plots/backends/matplotlib/traceplot.py +++ b/arviz/plots/backends/matplotlib/traceplot.py @@ -224,10 +224,14 @@ def plot_trace( for idy in range(2): value = np.atleast_2d(value) - is_circular = var_name in circ_var_names and not idy + circular = var_name in circ_var_names and not idy + if var_name in circ_var_names and idy: + circ_units_trace = circ_var_units + else: + circ_units_trace = False if axes is None: - ax = fig.add_subplot(spec[idx, idy], polar=is_circular) + ax = fig.add_subplot(spec[idx, idy], polar=circular) else: ax = axes[idx, idy] @@ -255,8 +259,9 @@ def plot_trace( fill_kwargs, rug_kwargs, rank_kwargs, - is_circular, + circular, circ_var_units, + circ_units_trace, ) else: @@ -294,8 +299,9 @@ def plot_trace( fill_kwargs, rug_kwargs, rank_kwargs, - is_circular, + circular, circ_var_units, + circ_units_trace, ) if legend: handles.append( @@ -303,7 +309,7 @@ def plot_trace( [], [], label=label, - **dealiase_sel_kwargs(aux_plot_kwargs, chain_prop, 0) + **dealiase_sel_kwargs(aux_plot_kwargs, chain_prop, 0), ) ) if legend and idy == 0: @@ -337,7 +343,7 @@ def plot_trace( ylocs = ylims[0] values = value[chain, div_idxs] - if is_circular: + if circular: tick = [ax.get_rmin() + ax.get_rmax() * 0.60, ax.get_rmax()] for val in values: ax.plot( @@ -449,10 +455,12 @@ def _plot_chains_mpl( fill_kwargs, rug_kwargs, rank_kwargs, - is_circular, + circular, circ_var_units, + circ_units_trace, ): - if not is_circular: + + if not circular: circ_var_units = False for chain_idx, row in enumerate(value): @@ -460,6 +468,10 @@ def _plot_chains_mpl( aux_kwargs = dealiase_sel_kwargs(trace_kwargs, chain_prop, chain_idx) if idy: axes.plot(data.draw.values, row, **aux_kwargs) + if circ_units_trace == "degrees": + y_tick_locs = axes.get_yticks() + y_tick_labels = [i + 2 * 180 if i < 0 else i for i in np.rad2deg(y_tick_locs)] + axes.set_yticklabels([f"{i:.0f}°" for i in y_tick_labels]) if not combined: aux_kwargs = dealiase_sel_kwargs(plot_kwargs, chain_prop, chain_idx) @@ -476,6 +488,7 @@ def _plot_chains_mpl( backend="matplotlib", show=False, is_circular=circ_var_units, + circular=circular, ) if kind == "rank_bars" and idy: @@ -498,5 +511,6 @@ def _plot_chains_mpl( backend="matplotlib", show=False, is_circular=circ_var_units, + circular=circular, ) return axes diff --git a/arviz/tests/base_tests/test_plots_matplotlib.py b/arviz/tests/base_tests/test_plots_matplotlib.py index 41274098f6..8757957722 100644 --- a/arviz/tests/base_tests/test_plots_matplotlib.py +++ b/arviz/tests/base_tests/test_plots_matplotlib.py @@ -192,7 +192,7 @@ def test_plot_separation(kwargs): {"lines": [("mu", {}, [1, 2])]}, {"lines": [("mu", {}, 8)]}, {"circ_var_names": ["mu"]}, - {"circ_var_units": "degrees"}, + {"circ_var_names": ["mu"], "circ_var_units": "degrees"}, ], ) def test_plot_trace(models, kwargs):