From 5606696525d91f11d8093085934fa352b98ce97c Mon Sep 17 00:00:00 2001 From: Samuel Hinton Date: Thu, 19 Nov 2020 21:56:35 +1000 Subject: [PATCH] Updating extents to include some padding. Fixing watermark issues. --- README.md | 4 ++++ chainconsumer/analysis.py | 4 ++-- chainconsumer/chainconsumer.py | 2 +- chainconsumer/helpers.py | 14 +++++++++---- chainconsumer/plotter.py | 36 ++++++++++++++++++---------------- 5 files changed, 36 insertions(+), 24 deletions(-) diff --git a/README.md b/README.md index f82bf932..66a1b67c 100644 --- a/README.md +++ b/README.md @@ -64,6 +64,10 @@ post, it can be solved by explicitly install the `matplotlib` dependency `dvipng ### Update History +##### 0.33.0 +* Adding extra padding to bin extents for KDE and smoothing +* Updating watermarking to work with matplotlib v3.0.0+ + ##### 0.32.0 * Fixing matplotlib axis formatter issue. diff --git a/chainconsumer/analysis.py b/chainconsumer/analysis.py index 46a2131d..d1f4d91b 100644 --- a/chainconsumer/analysis.py +++ b/chainconsumer/analysis.py @@ -314,14 +314,14 @@ def get_covariance_table(self, chain=0, parameters=None, caption="Parameter Cova parameters, cov = self.get_covariance(chain=chain, parameters=parameters) return self._get_2d_latex_table(parameters, cov, caption, label) - def _get_smoothed_histogram(self, chain, parameter): + def _get_smoothed_histogram(self, chain, parameter, pad=False): data = chain.get_data(parameter) smooth = chain.config["smooth"] if chain.grid: bins = get_grid_bins(data) else: bins = chain.config["bins"] - bins, smooth = get_smoothed_bins(smooth, bins, data, chain.weights) + bins, smooth = get_smoothed_bins(smooth, bins, data, chain.weights, pad=pad) hist, edges = np.histogram(data, bins=bins, density=True, weights=chain.weights) if chain.power is not None: diff --git a/chainconsumer/chainconsumer.py b/chainconsumer/chainconsumer.py index c987c39a..75e8b7be 100644 --- a/chainconsumer/chainconsumer.py +++ b/chainconsumer/chainconsumer.py @@ -20,7 +20,7 @@ class ChainConsumer(object): """ - __version__ = "0.32.0" + __version__ = "0.33.0" def __init__(self): logging.basicConfig(level=logging.INFO) diff --git a/chainconsumer/helpers.py b/chainconsumer/helpers.py index d0efdf00..8f515f01 100644 --- a/chainconsumer/helpers.py +++ b/chainconsumer/helpers.py @@ -2,7 +2,7 @@ import numpy as np -def get_extents(data, weight, plot=False, wide_extents=True, tiny=False): +def get_extents(data, weight, plot=False, wide_extents=True, tiny=False, pad=False): hist, be = np.histogram(data, weights=weight, bins=2000) bc = 0.5 * (be[1:] + be[:-1]) cdf = hist.cumsum() @@ -17,7 +17,13 @@ def get_extents(data, weight, plot=False, wide_extents=True, tiny=False): threshold = 0.3 i1 = np.where(cdf > threshold)[0][0] i2 = np.where(icdf > threshold)[0][0] - return bc[i1], bc[-i2] + lower = bc[i1] + upper = bc[-i2] + if pad: + width = upper - lower + lower -= 0.2 * width + upper += 0.2 * width + return lower, upper def get_bins(chains): @@ -25,8 +31,8 @@ def get_bins(chains): return proposal -def get_smoothed_bins(smooth, bins, data, weight, marginalised=True, plot=False): - minv, maxv = get_extents(data, weight, plot=plot) +def get_smoothed_bins(smooth, bins, data, weight, marginalised=True, plot=False, pad=False): + minv, maxv = get_extents(data, weight, plot=plot, pad=pad) if smooth is None or not smooth or smooth == 0: return np.linspace(minv, maxv, int(bins)), 0 else: diff --git a/chainconsumer/plotter.py b/chainconsumer/plotter.py index 9f413a57..5be11cdb 100644 --- a/chainconsumer/plotter.py +++ b/chainconsumer/plotter.py @@ -3,6 +3,7 @@ import numpy as np import matplotlib.pyplot as plt import matplotlib +from matplotlib.font_manager import FontProperties from matplotlib.ticker import MaxNLocator, ScalarFormatter, LogLocator from matplotlib.textpath import TextPath from numpy import meshgrid @@ -315,31 +316,33 @@ def _add_watermark(self, fig, axes, figsize, text, dpi=300, size_scale=1.0): # dx, dy = figsize dy, dx = dy * dpi, dx * dpi rotation = 180 / np.pi * np.arctan2(-dy, dx) - fontdict = self.parent.config["watermark_text_kwargs"] - if "usetex" in fontdict: - usetex = fontdict["usetex"] - else: - usetex = self.parent.config["usetex"] - fontdict["usetex"] = usetex - if fontdict["usetex"]: + property_dict = self.parent.config["watermark_text_kwargs"] + + keys_in_font_dict = ["family", "style", "variant", "weight", "stretch", "size"] + fontdict = {k: property_dict[k] for k in keys_in_font_dict if k in property_dict} + font_prop = FontProperties(**fontdict) + usetex = property_dict.get("usetex", self.parent.config["usetex"]) + if usetex: px, py, scale = 0.5, 0.5, 1.0 else: - px, py, scale = 0.45, 0.55, 0.8 - bb0 = TextPath((0, 0), text, size=50, prop=fontdict, usetex=usetex).get_extents() - bb1 = TextPath((0, 0), text, size=51, prop=fontdict, usetex=usetex).get_extents() + px, py, scale = 0.5, 0.5, 0.8 + + bb0 = TextPath((0, 0), text, size=50, prop=font_prop, usetex=usetex).get_extents() + bb1 = TextPath((0, 0), text, size=51, prop=font_prop, usetex=usetex).get_extents() dw = (bb1.width - bb0.width) * (dpi / 100) dh = (bb1.height - bb0.height) * (dpi / 100) size = np.sqrt(dy ** 2 + dx ** 2) / (dh * abs(dy / dx) + dw) * 0.6 * scale * size_scale if axes is not None: - if fontdict["usetex"]: + if usetex: size *= 0.7 else: - size *= 0.85 - fontdict["size"] = int(size) + size *= 0.8 + size = int(size) + print(f"Font size is {size}") if axes is None: - fig.text(px, py, text, fontdict=fontdict, rotation=rotation) + fig.text(px, py, text, fontdict=property_dict, rotation=rotation, fontsize=size) else: - axes.text(px, py, text, transform=axes.transAxes, fontdict=fontdict, rotation=rotation) + axes.text(px, py, text, transform=axes.transAxes, fontdict=property_dict, rotation=rotation, fontsize=size) def plot_walks( self, @@ -1215,11 +1218,10 @@ def _plot_bars(self, ax, parameter, chain, flip=False, summary=False): # pragma kde = chain.config["kde"] zorder = chain.config["zorder"] title_size = self.parent.config["label_font_size"] - chain_row = chain.get_data(parameter) weights = chain.weights if smooth or kde: - xs, ys, _ = self.parent.analysis._get_smoothed_histogram(chain, parameter) + xs, ys, _ = self.parent.analysis._get_smoothed_histogram(chain, parameter, pad=True) if flip: ax.plot(ys, xs, color=colour, ls=linestyle, lw=linewidth, zorder=zorder) else: