Skip to content

Commit

Permalink
Updating extents to include some padding. Fixing watermark issues.
Browse files Browse the repository at this point in the history
  • Loading branch information
Samreay committed Nov 19, 2020
1 parent 8850b3d commit 5606696
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 24 deletions.
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,10 @@ post, it can be solved by explicitly install the `matplotlib` dependency `dvipng

### Update History

##### 0.33.0
* Adding extra padding to bin extents for KDE and smoothing
* Updating watermarking to work with matplotlib v3.0.0+

##### 0.32.0
* Fixing matplotlib axis formatter issue.

Expand Down
4 changes: 2 additions & 2 deletions chainconsumer/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,14 +314,14 @@ def get_covariance_table(self, chain=0, parameters=None, caption="Parameter Cova
parameters, cov = self.get_covariance(chain=chain, parameters=parameters)
return self._get_2d_latex_table(parameters, cov, caption, label)

def _get_smoothed_histogram(self, chain, parameter):
def _get_smoothed_histogram(self, chain, parameter, pad=False):
data = chain.get_data(parameter)
smooth = chain.config["smooth"]
if chain.grid:
bins = get_grid_bins(data)
else:
bins = chain.config["bins"]
bins, smooth = get_smoothed_bins(smooth, bins, data, chain.weights)
bins, smooth = get_smoothed_bins(smooth, bins, data, chain.weights, pad=pad)

hist, edges = np.histogram(data, bins=bins, density=True, weights=chain.weights)
if chain.power is not None:
Expand Down
2 changes: 1 addition & 1 deletion chainconsumer/chainconsumer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class ChainConsumer(object):
"""

__version__ = "0.32.0"
__version__ = "0.33.0"

def __init__(self):
logging.basicConfig(level=logging.INFO)
Expand Down
14 changes: 10 additions & 4 deletions chainconsumer/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import numpy as np


def get_extents(data, weight, plot=False, wide_extents=True, tiny=False):
def get_extents(data, weight, plot=False, wide_extents=True, tiny=False, pad=False):
hist, be = np.histogram(data, weights=weight, bins=2000)
bc = 0.5 * (be[1:] + be[:-1])
cdf = hist.cumsum()
Expand All @@ -17,16 +17,22 @@ def get_extents(data, weight, plot=False, wide_extents=True, tiny=False):
threshold = 0.3
i1 = np.where(cdf > threshold)[0][0]
i2 = np.where(icdf > threshold)[0][0]
return bc[i1], bc[-i2]
lower = bc[i1]
upper = bc[-i2]
if pad:
width = upper - lower
lower -= 0.2 * width
upper += 0.2 * width
return lower, upper


def get_bins(chains):
proposal = [max(35, np.floor(1.0 * np.power(chain.chain.shape[0] / chain.chain.shape[1], 0.25))) for chain in chains]
return proposal


def get_smoothed_bins(smooth, bins, data, weight, marginalised=True, plot=False):
minv, maxv = get_extents(data, weight, plot=plot)
def get_smoothed_bins(smooth, bins, data, weight, marginalised=True, plot=False, pad=False):
minv, maxv = get_extents(data, weight, plot=plot, pad=pad)
if smooth is None or not smooth or smooth == 0:
return np.linspace(minv, maxv, int(bins)), 0
else:
Expand Down
36 changes: 19 additions & 17 deletions chainconsumer/plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
from matplotlib.font_manager import FontProperties
from matplotlib.ticker import MaxNLocator, ScalarFormatter, LogLocator
from matplotlib.textpath import TextPath
from numpy import meshgrid
Expand Down Expand Up @@ -315,31 +316,33 @@ def _add_watermark(self, fig, axes, figsize, text, dpi=300, size_scale=1.0): #
dx, dy = figsize
dy, dx = dy * dpi, dx * dpi
rotation = 180 / np.pi * np.arctan2(-dy, dx)
fontdict = self.parent.config["watermark_text_kwargs"]
if "usetex" in fontdict:
usetex = fontdict["usetex"]
else:
usetex = self.parent.config["usetex"]
fontdict["usetex"] = usetex
if fontdict["usetex"]:
property_dict = self.parent.config["watermark_text_kwargs"]

keys_in_font_dict = ["family", "style", "variant", "weight", "stretch", "size"]
fontdict = {k: property_dict[k] for k in keys_in_font_dict if k in property_dict}
font_prop = FontProperties(**fontdict)
usetex = property_dict.get("usetex", self.parent.config["usetex"])
if usetex:
px, py, scale = 0.5, 0.5, 1.0
else:
px, py, scale = 0.45, 0.55, 0.8
bb0 = TextPath((0, 0), text, size=50, prop=fontdict, usetex=usetex).get_extents()
bb1 = TextPath((0, 0), text, size=51, prop=fontdict, usetex=usetex).get_extents()
px, py, scale = 0.5, 0.5, 0.8

bb0 = TextPath((0, 0), text, size=50, prop=font_prop, usetex=usetex).get_extents()
bb1 = TextPath((0, 0), text, size=51, prop=font_prop, usetex=usetex).get_extents()
dw = (bb1.width - bb0.width) * (dpi / 100)
dh = (bb1.height - bb0.height) * (dpi / 100)
size = np.sqrt(dy ** 2 + dx ** 2) / (dh * abs(dy / dx) + dw) * 0.6 * scale * size_scale
if axes is not None:
if fontdict["usetex"]:
if usetex:
size *= 0.7
else:
size *= 0.85
fontdict["size"] = int(size)
size *= 0.8
size = int(size)
print(f"Font size is {size}")
if axes is None:
fig.text(px, py, text, fontdict=fontdict, rotation=rotation)
fig.text(px, py, text, fontdict=property_dict, rotation=rotation, fontsize=size)
else:
axes.text(px, py, text, transform=axes.transAxes, fontdict=fontdict, rotation=rotation)
axes.text(px, py, text, transform=axes.transAxes, fontdict=property_dict, rotation=rotation, fontsize=size)

def plot_walks(
self,
Expand Down Expand Up @@ -1215,11 +1218,10 @@ def _plot_bars(self, ax, parameter, chain, flip=False, summary=False): # pragma
kde = chain.config["kde"]
zorder = chain.config["zorder"]
title_size = self.parent.config["label_font_size"]

chain_row = chain.get_data(parameter)
weights = chain.weights
if smooth or kde:
xs, ys, _ = self.parent.analysis._get_smoothed_histogram(chain, parameter)
xs, ys, _ = self.parent.analysis._get_smoothed_histogram(chain, parameter, pad=True)
if flip:
ax.plot(ys, xs, color=colour, ls=linestyle, lw=linewidth, zorder=zorder)
else:
Expand Down

0 comments on commit 5606696

Please sign in to comment.