Skip to content

Commit

Permalink
Fixes #83
Browse files Browse the repository at this point in the history
Creating new formatters per x and y axis instead of a global formatter
  • Loading branch information
Samreay committed Oct 30, 2020
1 parent ae32883 commit 2faed7b
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 8 deletions.
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,10 @@ post, it can be solved by explicitly install the `matplotlib` dependency `dvipng

### Update History

##### 0.31.2, 0.31.3
##### 0.32.0
* Fixing matplotlib axis formatter issue.

##### 0.31.2, 0.31.3
* Conda-forge updates

##### 0.31.1
Expand Down
2 changes: 1 addition & 1 deletion chainconsumer/chainconsumer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class ChainConsumer(object):
"""

__version__ = "0.31.3"
__version__ = "0.32.0"

def __init__(self):
logging.basicConfig(level=logging.INFO)
Expand Down
14 changes: 8 additions & 6 deletions chainconsumer/plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,7 +445,7 @@ def plot_walks(
if i == 0 and plot_posterior:
for chain in chains:
if chain.posterior is not None:
self._plot_walk(ax, "$\log(P)$", chain.posterior - chain.posterior.max(), convolve=convolve, color=chain.config["color"])
self._plot_walk(ax, r"$\log(P)$", chain.posterior - chain.posterior.max(), convolve=convolve, color=chain.config["color"])
else:
if log_weight is None:
log_weight = np.any([chain.weights.mean() < 0.1 for chain in chains])
Expand Down Expand Up @@ -939,9 +939,6 @@ def _get_figure(self, all_parameters, flip, figsize=(5, 5), external_extents=Non
fig, axes = plt.subplots(n, n, figsize=figsize, squeeze=False, gridspec_kw=gridspec_kw)
fig.subplots_adjust(left=0.1, right=0.95, top=0.95, bottom=0.1, wspace=0.05 * spacing, hspace=0.05 * spacing)

formatter = ScalarFormatter(useOffset=False)
formatter.set_powerlimits((-3, 4))

extents = self._get_custom_extents(all_parameters, chains, external_extents)

if plot_hists:
Expand All @@ -953,6 +950,11 @@ def _get_figure(self, all_parameters, flip, figsize=(5, 5), external_extents=Non
for i, p1 in enumerate(params1):
for j, p2 in enumerate(params2):
ax = axes[i, j]
formatter_x = ScalarFormatter(useOffset=True)
formatter_x.set_powerlimits((-3, 4))
formatter_y = ScalarFormatter(useOffset=True)
formatter_y.set_powerlimits((-3, 4))

display_x_ticks = False
display_y_ticks = False
if i < j:
Expand Down Expand Up @@ -1001,7 +1003,7 @@ def _get_figure(self, all_parameters, flip, figsize=(5, 5), external_extents=Non
_ = [l.set_fontsize(tick_font_size) for l in ax.get_xticklabels()]
if not logx:
ax.xaxis.set_major_locator(MaxNLocator(max_ticks, prune="lower"))
ax.xaxis.set_major_formatter(formatter)
ax.xaxis.set_major_formatter(formatter_x)
else:
ax.xaxis.set_major_locator(LogLocator(numticks=max_ticks))
else:
Expand All @@ -1012,7 +1014,7 @@ def _get_figure(self, all_parameters, flip, figsize=(5, 5), external_extents=Non
_ = [l.set_fontsize(tick_font_size) for l in ax.get_yticklabels()]
if not logy:
ax.yaxis.set_major_locator(MaxNLocator(max_ticks, prune="lower"))
ax.yaxis.set_major_formatter(formatter)
ax.yaxis.set_major_formatter(formatter_y)
else:
ax.yaxis.set_major_locator(LogLocator(numticks=max_ticks))
else:
Expand Down

0 comments on commit 2faed7b

Please sign in to comment.