Skip to content

Commit

Permalink
Fix circular traceplot warnings and title (#1517)
Browse files Browse the repository at this point in the history
* fix circular traceplot warnings and title

* update changelog and small fix
  • Loading branch information
agustinaarroyuelo authored Jan 23, 2021
1 parent 914f079 commit 4d8a65d
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 2 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
### New features

### Maintenance and fixes

* Fixed ovelapping titles and repeating warnings on circular traceplot ([1517](https://github.com/arviz-devs/arviz/pull/1517))
### Deprecation

### Documentation
Expand Down
3 changes: 3 additions & 0 deletions arviz/plots/backends/matplotlib/kdeplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import numpy as np
from matplotlib import pyplot as plt
from matplotlib import _pylab_helpers
import matplotlib.ticker as mticker


from ...plot_utils import _scale_fig_size
Expand Down Expand Up @@ -101,6 +102,8 @@ def plot_kde(
f"{-np.pi/4:.2f}",
]

ticks_loc = ax.get_xticks()
ax.xaxis.set_major_locator(mticker.FixedLocator(ticks_loc))
ax.set_xticklabels(labels)

x = np.linspace(-np.pi, np.pi, len(density))
Expand Down
9 changes: 8 additions & 1 deletion arviz/plots/backends/matplotlib/traceplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.lines import Line2D
import matplotlib.ticker as mticker

from ....stats.density_utils import get_bins
from ...distplot import plot_dist
Expand Down Expand Up @@ -318,9 +319,14 @@ def plot_trace(
if value[0].dtype.kind == "i" and idy == 0:
xticks = get_bins(value)
ax.set_xticks(xticks[:-1])
y = 1 / textsize
if not idy:
ax.set_yticks([])
ax.set_title(make_label(var_name, selection), fontsize=titlesize, wrap=True, y=1)
if circular:
y = 0.13 if selection else 0.12
ax.set_title(
make_label(var_name, selection), fontsize=titlesize, wrap=True, y=textsize * y
)
ax.tick_params(labelsize=xt_labelsize)

xlims = ax.get_xlim()
Expand Down Expand Up @@ -471,6 +477,7 @@ def _plot_chains_mpl(
if circ_units_trace == "degrees":
y_tick_locs = axes.get_yticks()
y_tick_labels = [i + 2 * 180 if i < 0 else i for i in np.rad2deg(y_tick_locs)]
axes.yaxis.set_major_locator(mticker.FixedLocator(y_tick_locs))
axes.set_yticklabels([f"{i:.0f}°" for i in y_tick_labels])

if not combined:
Expand Down

0 comments on commit 4d8a65d

Please sign in to comment.