From 475e34c1bc79e5eef58ef8a12819bab18342bad8 Mon Sep 17 00:00:00 2001
From: Agustina Arroyuelo <agustinaarroyuelo@gmail.com>
Date: Fri, 23 Oct 2020 11:35:35 -0300
Subject: [PATCH] Fix labels and use circular KDE in plot_trace (#1428)

* fix circular traceplot labels

* update changelog

* fix xticklabels

* update test
---
 CHANGELOG.md                                  |  1 +
 arviz/plots/backends/matplotlib/kdeplot.py    | 18 +++++------
 arviz/plots/backends/matplotlib/traceplot.py  | 30 ++++++++++++++-----
 .../tests/base_tests/test_plots_matplotlib.py |  2 +-
 4 files changed, 33 insertions(+), 18 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index c82f6d9229..c0fbc0134b 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -16,6 +16,7 @@
 * Fix typo in `loo_pit` extraction of log likelihood ([1418](https://github.com/arviz-devs/arviz/pull/1418))
 * Have `from_pystan` store attrs as strings to allow netCDF storage ([1417](https://github.com/arviz-devs/arviz/pull/1417))
 * Remove ticks and spines in `plot_violin`  ([1426 ](https://github.com/arviz-devs/arviz/pull/1426))
+* Use circular KDE function and fix tick labels in circular `plot_trace` ([1428](https://github.com/arviz-devs/arviz/pull/1428))
 
 ### Deprecation
 
diff --git a/arviz/plots/backends/matplotlib/kdeplot.py b/arviz/plots/backends/matplotlib/kdeplot.py
index 173421473b..feca641605 100644
--- a/arviz/plots/backends/matplotlib/kdeplot.py
+++ b/arviz/plots/backends/matplotlib/kdeplot.py
@@ -85,14 +85,14 @@ def plot_kde(
 
             if is_circular == "radians":
                 labels = [
-                    r"0",
-                    r"π/4",
-                    r"π/2",
-                    r"3π/4",
-                    r"π",
-                    r"5π/4",
-                    r"3π/2",
-                    r"7π/4",
+                    "0",
+                    f"{np.pi/4:.2f}",
+                    f"{np.pi/2:.2f}",
+                    f"{3*np.pi/4:.2f}",
+                    f"{np.pi:.2f}",
+                    f"{-3*np.pi/4:.2f}",
+                    f"{-np.pi/2:.2f}",
+                    f"{-np.pi/4:.2f}",
                 ]
 
                 ax.set_xticklabels(labels)
@@ -130,7 +130,7 @@ def plot_kde(
                 fill_x,
                 fill_y,
                 where=np.isin(fill_x, fill_x[idx], invert=True, assume_unique=True),
-                **fill_kwargs
+                **fill_kwargs,
             )
         else:
             fill_kwargs.setdefault("alpha", 0)
diff --git a/arviz/plots/backends/matplotlib/traceplot.py b/arviz/plots/backends/matplotlib/traceplot.py
index b0591663ab..2855c2d97b 100644
--- a/arviz/plots/backends/matplotlib/traceplot.py
+++ b/arviz/plots/backends/matplotlib/traceplot.py
@@ -224,10 +224,14 @@ def plot_trace(
         for idy in range(2):
             value = np.atleast_2d(value)
 
-            is_circular = var_name in circ_var_names and not idy
+            circular = var_name in circ_var_names and not idy
+            if var_name in circ_var_names and idy:
+                circ_units_trace = circ_var_units
+            else:
+                circ_units_trace = False
 
             if axes is None:
-                ax = fig.add_subplot(spec[idx, idy], polar=is_circular)
+                ax = fig.add_subplot(spec[idx, idy], polar=circular)
             else:
                 ax = axes[idx, idy]
 
@@ -255,8 +259,9 @@ def plot_trace(
                     fill_kwargs,
                     rug_kwargs,
                     rank_kwargs,
-                    is_circular,
+                    circular,
                     circ_var_units,
+                    circ_units_trace,
                 )
 
             else:
@@ -294,8 +299,9 @@ def plot_trace(
                         fill_kwargs,
                         rug_kwargs,
                         rank_kwargs,
-                        is_circular,
+                        circular,
                         circ_var_units,
+                        circ_units_trace,
                     )
                     if legend:
                         handles.append(
@@ -303,7 +309,7 @@ def plot_trace(
                                 [],
                                 [],
                                 label=label,
-                                **dealiase_sel_kwargs(aux_plot_kwargs, chain_prop, 0)
+                                **dealiase_sel_kwargs(aux_plot_kwargs, chain_prop, 0),
                             )
                         )
                 if legend and idy == 0:
@@ -337,7 +343,7 @@ def plot_trace(
                             ylocs = ylims[0]
                         values = value[chain, div_idxs]
 
-                        if is_circular:
+                        if circular:
                             tick = [ax.get_rmin() + ax.get_rmax() * 0.60, ax.get_rmax()]
                             for val in values:
                                 ax.plot(
@@ -449,10 +455,12 @@ def _plot_chains_mpl(
     fill_kwargs,
     rug_kwargs,
     rank_kwargs,
-    is_circular,
+    circular,
     circ_var_units,
+    circ_units_trace,
 ):
-    if not is_circular:
+
+    if not circular:
         circ_var_units = False
 
     for chain_idx, row in enumerate(value):
@@ -460,6 +468,10 @@ def _plot_chains_mpl(
             aux_kwargs = dealiase_sel_kwargs(trace_kwargs, chain_prop, chain_idx)
             if idy:
                 axes.plot(data.draw.values, row, **aux_kwargs)
+                if circ_units_trace == "degrees":
+                    y_tick_locs = axes.get_yticks()
+                    y_tick_labels = [i + 2 * 180 if i < 0 else i for i in np.rad2deg(y_tick_locs)]
+                    axes.set_yticklabels([f"{i:.0f}°" for i in y_tick_labels])
 
         if not combined:
             aux_kwargs = dealiase_sel_kwargs(plot_kwargs, chain_prop, chain_idx)
@@ -476,6 +488,7 @@ def _plot_chains_mpl(
                     backend="matplotlib",
                     show=False,
                     is_circular=circ_var_units,
+                    circular=circular,
                 )
 
     if kind == "rank_bars" and idy:
@@ -498,5 +511,6 @@ def _plot_chains_mpl(
                 backend="matplotlib",
                 show=False,
                 is_circular=circ_var_units,
+                circular=circular,
             )
     return axes
diff --git a/arviz/tests/base_tests/test_plots_matplotlib.py b/arviz/tests/base_tests/test_plots_matplotlib.py
index 41274098f6..8757957722 100644
--- a/arviz/tests/base_tests/test_plots_matplotlib.py
+++ b/arviz/tests/base_tests/test_plots_matplotlib.py
@@ -192,7 +192,7 @@ def test_plot_separation(kwargs):
         {"lines": [("mu", {}, [1, 2])]},
         {"lines": [("mu", {}, 8)]},
         {"circ_var_names": ["mu"]},
-        {"circ_var_units": "degrees"},
+        {"circ_var_names": ["mu"], "circ_var_units": "degrees"},
     ],
 )
 def test_plot_trace(models, kwargs):