Skip to content

Commit

Permalink
Fix labels and use circular KDE in plot_trace (#1428)
Browse files Browse the repository at this point in the history
* fix circular traceplot labels

* update changelog

* fix xticklabels

* update test
  • Loading branch information
agustinaarroyuelo authored Oct 23, 2020
1 parent 9997735 commit 475e34c
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 18 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
* Fix typo in `loo_pit` extraction of log likelihood ([1418](https://github.com/arviz-devs/arviz/pull/1418))
* Have `from_pystan` store attrs as strings to allow netCDF storage ([1417](https://github.com/arviz-devs/arviz/pull/1417))
* Remove ticks and spines in `plot_violin` ([1426 ](https://github.com/arviz-devs/arviz/pull/1426))
* Use circular KDE function and fix tick labels in circular `plot_trace` ([1428](https://github.com/arviz-devs/arviz/pull/1428))

### Deprecation

Expand Down
18 changes: 9 additions & 9 deletions arviz/plots/backends/matplotlib/kdeplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,14 +85,14 @@ def plot_kde(

if is_circular == "radians":
labels = [
r"0",
r"π/4",
r"π/2",
r"3π/4",
r"π",
r"5π/4",
r"3π/2",
r"7π/4",
"0",
f"{np.pi/4:.2f}",
f"{np.pi/2:.2f}",
f"{3*np.pi/4:.2f}",
f"{np.pi:.2f}",
f"{-3*np.pi/4:.2f}",
f"{-np.pi/2:.2f}",
f"{-np.pi/4:.2f}",
]

ax.set_xticklabels(labels)
Expand Down Expand Up @@ -130,7 +130,7 @@ def plot_kde(
fill_x,
fill_y,
where=np.isin(fill_x, fill_x[idx], invert=True, assume_unique=True),
**fill_kwargs
**fill_kwargs,
)
else:
fill_kwargs.setdefault("alpha", 0)
Expand Down
30 changes: 22 additions & 8 deletions arviz/plots/backends/matplotlib/traceplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,10 +224,14 @@ def plot_trace(
for idy in range(2):
value = np.atleast_2d(value)

is_circular = var_name in circ_var_names and not idy
circular = var_name in circ_var_names and not idy
if var_name in circ_var_names and idy:
circ_units_trace = circ_var_units
else:
circ_units_trace = False

if axes is None:
ax = fig.add_subplot(spec[idx, idy], polar=is_circular)
ax = fig.add_subplot(spec[idx, idy], polar=circular)
else:
ax = axes[idx, idy]

Expand Down Expand Up @@ -255,8 +259,9 @@ def plot_trace(
fill_kwargs,
rug_kwargs,
rank_kwargs,
is_circular,
circular,
circ_var_units,
circ_units_trace,
)

else:
Expand Down Expand Up @@ -294,16 +299,17 @@ def plot_trace(
fill_kwargs,
rug_kwargs,
rank_kwargs,
is_circular,
circular,
circ_var_units,
circ_units_trace,
)
if legend:
handles.append(
Line2D(
[],
[],
label=label,
**dealiase_sel_kwargs(aux_plot_kwargs, chain_prop, 0)
**dealiase_sel_kwargs(aux_plot_kwargs, chain_prop, 0),
)
)
if legend and idy == 0:
Expand Down Expand Up @@ -337,7 +343,7 @@ def plot_trace(
ylocs = ylims[0]
values = value[chain, div_idxs]

if is_circular:
if circular:
tick = [ax.get_rmin() + ax.get_rmax() * 0.60, ax.get_rmax()]
for val in values:
ax.plot(
Expand Down Expand Up @@ -449,17 +455,23 @@ def _plot_chains_mpl(
fill_kwargs,
rug_kwargs,
rank_kwargs,
is_circular,
circular,
circ_var_units,
circ_units_trace,
):
if not is_circular:

if not circular:
circ_var_units = False

for chain_idx, row in enumerate(value):
if kind == "trace":
aux_kwargs = dealiase_sel_kwargs(trace_kwargs, chain_prop, chain_idx)
if idy:
axes.plot(data.draw.values, row, **aux_kwargs)
if circ_units_trace == "degrees":
y_tick_locs = axes.get_yticks()
y_tick_labels = [i + 2 * 180 if i < 0 else i for i in np.rad2deg(y_tick_locs)]
axes.set_yticklabels([f"{i:.0f}°" for i in y_tick_labels])

if not combined:
aux_kwargs = dealiase_sel_kwargs(plot_kwargs, chain_prop, chain_idx)
Expand All @@ -476,6 +488,7 @@ def _plot_chains_mpl(
backend="matplotlib",
show=False,
is_circular=circ_var_units,
circular=circular,
)

if kind == "rank_bars" and idy:
Expand All @@ -498,5 +511,6 @@ def _plot_chains_mpl(
backend="matplotlib",
show=False,
is_circular=circ_var_units,
circular=circular,
)
return axes
2 changes: 1 addition & 1 deletion arviz/tests/base_tests/test_plots_matplotlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ def test_plot_separation(kwargs):
{"lines": [("mu", {}, [1, 2])]},
{"lines": [("mu", {}, 8)]},
{"circ_var_names": ["mu"]},
{"circ_var_units": "degrees"},
{"circ_var_names": ["mu"], "circ_var_units": "degrees"},
],
)
def test_plot_trace(models, kwargs):
Expand Down

0 comments on commit 475e34c

Please sign in to comment.