Skip to content

Commit

Permalink
Fixes #61
Browse files Browse the repository at this point in the history
Also adding `global_point` option to configure for more flexibility.

v0.26.1 RC1
  • Loading branch information
Samreay committed Jul 18, 2018
1 parent 6cbd912 commit 13a6907
Show file tree
Hide file tree
Showing 8 changed files with 100 additions and 37 deletions.
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,10 @@ post, it can be solved by explicitly install the `matplotlib` dependency `dvipng

### Update History

##### 0.26.1
* Adding ability to plot maximum points on 2D contour, not just global posterior maximum.
* Fixing truth dictionary mutation on `plot_walks`

##### 0.26.0
* Adding ability to pass in a power to raise the surface to for each chain.
* Adding methods to retrieve the maximum posterior point: `Analysis.get_max_posteriors`
Expand Down
7 changes: 4 additions & 3 deletions chainconsumer/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def get_summary(self, squeeze=True, parameters=None, chains=None):
else:
if isinstance(chains, (int, str)):
chains = [chains]
chains = [self.parent.chains[self.parent._get_chain(c)] for c in chains]
chains = [self.parent.chains[i] for c in chains for i in self.parent._get_chain(c)]

for chain in chains:
res = {}
Expand Down Expand Up @@ -174,7 +174,7 @@ def get_max_posteriors(self, parameters=None, squeeze=True, chains=None):
else:
if isinstance(chains, (int, str)):
chains = [chains]
chains = [self.parent.chains[self.parent._get_chain(c)] for c in chains]
chains = [self.parent.chains[i] for c in chains for i in self.parent._get_chain(c)]

if isinstance(parameters, str):
parameters = [parameters]
Expand Down Expand Up @@ -244,7 +244,8 @@ def get_covariance(self, chain=0, parameters=None):
2D covariance matrix.
"""
index = self.parent._get_chain(chain)
chain = self.parent.chains[index]
assert len(index) == 1, "Please specify only one chain, have %d chains" % len(index)
chain = self.parent.chains[index[0]]
if parameters is None:
parameters = chain.parameters

Expand Down
37 changes: 22 additions & 15 deletions chainconsumer/chainconsumer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class ChainConsumer(object):
figures, tables, diagnostics, you name it.
"""
__version__ = "0.26.0"
__version__ = "0.26.1"

def __init__(self):
logging.basicConfig(level=logging.INFO)
Expand Down Expand Up @@ -251,7 +251,7 @@ def remove_chain(self, chain=-1):
if isinstance(chain, str) or isinstance(chain, int):
chain = [chain]

chain = sorted([self._get_chain(c) for c in chain])[::-1]
chain = sorted([i for c in chain for i in self._get_chain(c)])[::-1]
assert len(chain) == len(list(set(chain))), "Error, you are trying to remove a chain more than once."

for index in chain:
Expand All @@ -270,7 +270,7 @@ def configure(self, statistics="max", max_ticks=5, plot_hists=True, flip=True,
colors=None, linestyles=None, linewidths=None, kde=False, smooth=None,
cloud=None, shade=None, shade_alpha=None, shade_gradient=None, bar_shade=None,
num_cloud=None, color_params=None, plot_color_params=False, cmaps=None,
plot_contour=None, plot_point=None, marker_style=None, marker_size=None, marker_alpha=None,
plot_contour=None, plot_point=None, global_point=True, marker_style=None, marker_size=None, marker_alpha=None,
usetex=True, diagonal_tick_labels=True, label_font_size=12, tick_font_size=10,
spacing=None, contour_labels=None, contour_label_font_size=10,
legend_kwargs=None, legend_location=None, legend_artists=None,
Expand Down Expand Up @@ -376,6 +376,10 @@ def configure(self, statistics="max", max_ticks=5, plot_hists=True, flip=True,
25 concurrent chains.
plot_point : bool|list[bool], optional
Whether to plot a maximum likelihood point. Defaults to true for more then 24 chains.
global_point : bool, optional
Whether the point which gets plotted is the global posterior maximum, or the marginalised 2D
posterior maximum. Note that when you use marginalised 2D maximums for the points, you do not
get the 1D histograms. Defaults to `True`, for a global maximum value.
marker_style : str|list[str], optional
The marker style to use when plotting points. Defaults to `'.'`
marker_size : numeric|list[numeric], optional
Expand Down Expand Up @@ -678,6 +682,7 @@ def configure(self, statistics="max", max_ticks=5, plot_hists=True, flip=True,
assert isinstance(summary_area, float), "summary_area needs to be a float, not %s!" % type(summary_area)
assert summary_area > 0, "summary_area should be a positive number, instead is %s!" % summary_area
assert summary_area < 1, "summary_area must be less than unity, instead is %s!" % summary_area
assert isinstance(global_point, bool), "global_point should be a bool"

# List options
for i, c in enumerate(self.chains):
Expand Down Expand Up @@ -731,6 +736,7 @@ def configure(self, statistics="max", max_ticks=5, plot_hists=True, flip=True,
self.config["legend_artists"] = legend_artists
self.config["legend_color_text"] = legend_color_text
self.config["watermark_text_kwargs"] = watermark_text_kwargs_default
self.config["global_point"] = global_point

self._configured = True
return self
Expand Down Expand Up @@ -788,28 +794,29 @@ def divide_chain(self, chain=0):
A new ChainConsumer instance with the same settings as the parent instance, containing
``num_walker`` chains.
"""
index = self._get_chain(chain)
chain = self.chains[index]

assert chain.walkers is not None, "The chain you have selected was not added with any walkers!"
num_walkers = chain.walkers
data = np.split(chain.chain, num_walkers)
ws = np.split(chain.weights, num_walkers)
indexes = self._get_chain(chain)
con = ChainConsumer()
for j, (c, w) in enumerate(zip(data, ws)):
con.add_chain(c, weights=w, name="Chain %d" % j, parameters=chain.parameters)

for index in indexes:
chain = self.chains[index]
assert chain.walkers is not None, "The chain you have selected was not added with any walkers!"
num_walkers = chain.walkers
data = np.split(chain.chain, num_walkers)
ws = np.split(chain.weights, num_walkers)
for j, (c, w) in enumerate(zip(data, ws)):
con.add_chain(c, weights=w, name="Chain %d" % j, parameters=chain.parameters)
return con

def _get_chain(self, chain):
if isinstance(chain, Chain):
return self.chains.index(chain)
return [self.chains.index(chain)]
if isinstance(chain, str):
names = [c.name for c in self.chains]
assert chain in names, "Chain %s not found!" % chain
index = names.index(chain)
index = [i for i, n in enumerate(names) if chain == n]
elif isinstance(chain, int):
assert chain < len(self.chains), "Chain index %d not found!" % chain
index = chain
index = [chain]
else:
raise ValueError("Type %s not recognised for chain" % type(chain))
return index
Expand Down
6 changes: 4 additions & 2 deletions chainconsumer/diagnostic.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ def gelman_rubin(self, chain=None, threshold=0.05):
return np.all([self.gelman_rubin(k, threshold=threshold) for k in range(len(self.parent.chains))])

index = self.parent._get_chain(chain)
chain = self.parent.chains[index]
assert len(index) == 1, "Please specify only one chain, have %d chains" % len(index)
chain = self.parent.chains[index[0]]

num_walkers = chain.walkers
parameters = chain.parameters
Expand Down Expand Up @@ -101,7 +102,8 @@ def geweke(self, chain=None, first=0.1, last=0.5, threshold=0.05):
return np.all([self.geweke(k, threshold=threshold) for k in range(len(self.parent.chains))])

index = self.parent._get_chain(chain)
chain = self.parent.chains[index]
assert len(index) == 1, "Please specify only one chain, have %d chains" % len(index)
chain = self.parent.chains[index[0]]

num_walkers = chain.walkers
assert num_walkers is not None and num_walkers > 0, \
Expand Down
4 changes: 3 additions & 1 deletion chainconsumer/helpers.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numpy as np


def get_extents(data, weight, plot=False, wide_extents=True):
def get_extents(data, weight, plot=False, wide_extents=True, tiny=False):
hist, be = np.histogram(data, weights=weight, bins=2000)
bc = 0.5 * (be[1:] + be[:-1])
cdf = hist.cumsum()
Expand All @@ -12,6 +12,8 @@ def get_extents(data, weight, plot=False, wide_extents=True):
threshold = 1e-4 if plot else 1e-5
if plot and not wide_extents:
threshold = 0.05
if tiny:
threshold = 0.3
i1 = np.where(cdf > threshold)[0][0]
i2 = np.where(icdf > threshold)[0][0]
return bc[i1], bc[-i2]
Expand Down
41 changes: 29 additions & 12 deletions chainconsumer/plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ def plot(self, figsize="GROW", parameters=None, chains=None, extents=None, filen
if max_val is None or m > max_val:
max_val = m

if num_chain_points:
if num_chain_points and self.parent.config["global_point"]:
m = self._plot_point_histogram(ax, subgroups, p1, flip=do_flip)
if max_val is None or m > max_val:
max_val = m
Expand Down Expand Up @@ -395,9 +395,8 @@ def plot_walks(self, parameters=None, truth=None, extents=None, display=False,
for chain in chains:
if p in chain.parameters:
chain_row = chain.get_data(p)
self._plot_walk(ax, p, chain_row, truth=truth.get(p),
extents=extents.get(p), convolve=convolve, color=chain.config["color"])
truth[p] = None
self._plot_walk(ax, p, chain_row, extents=extents.get(p), convolve=convolve, color=chain.config["color"])
self._plot_walk_truth(ax, truth.get(p))
else:
if i == 0 and plot_posterior:
for chain in chains:
Expand Down Expand Up @@ -728,7 +727,7 @@ def _sanitise(self, chains, parameters, truth, extents, color_p=False, blind=Non
else:
if isinstance(chains, str) or isinstance(chains, int):
chains = [chains]
chains = [self.parent._get_chain(c) for c in chains]
chains = [i for c in chains for i in self.parent._get_chain(c)]

chains = [self.parent.chains[i] for i in chains]

Expand Down Expand Up @@ -893,8 +892,12 @@ def _get_parameter_extents(self, parameter, chains, wide_extents=True):
if parameter not in chain.parameters:
continue # pragma: no cover
if not chain.config["plot_contour"]:
min_prop = chain.posterior_max_params[parameter]
max_prop = min_prop
if self.parent.config["global_point"]:
min_prop = chain.posterior_max_params.get(parameter)
max_prop = min_prop
else:
data = chain.get_data(parameter)
min_prop, max_prop = get_extents(data, chain.weights, tiny=True)
else:
data = chain.get_data(parameter)
if chain.grid:
Expand All @@ -917,10 +920,23 @@ def _get_levels(self):
return levels

def _plot_points(self, ax, chains_groups, markers, sizes, alphas, py, px): # pragma: no cover
global_point = self.parent.config["global_point"]
for marker, chains, size, alpha in zip(markers, chains_groups, sizes, alphas):
res = self.parent.analysis.get_max_posteriors(parameters=[px, py], chains=chains, squeeze=False)
xs = [r[px] for r in res if r is not None]
ys = [r[py] for r in res if r is not None]
if global_point:
res = self.parent.analysis.get_max_posteriors(parameters=[px, py], chains=chains, squeeze=False)
xs = [r[px] for r in res if r is not None]
ys = [r[py] for r in res if r is not None]
else:
xs, ys, res = [], [], []
for chain in chains:
if px in chain.parameters and py in chain.parameters:
hist, x_centers, y_centers = self._get_smoothed_histogram2d(chain, py, px)
index = np.unravel_index(hist.argmax(), hist.shape)
ys.append(x_centers[index[0]])
xs.append(y_centers[index[1]])
res.append({"px": xs[-1], "py": ys[-1]})
else:
res.append(None)
cs = [c.config["color"] for c, r in zip(chains, res) if r is not None]
h = ax.scatter(xs, ys, marker=marker, c=cs, s=size, linewidth=0.7, alpha=alpha)
return h
Expand Down Expand Up @@ -1104,8 +1120,9 @@ def _plot_walk(self, ax, parameter, data, truth=None, extents=None,
filt = np.ones(convolve) / convolve
filtered = np.convolve(data, filt, mode="same")
ax.plot(x[:-1], filtered[:-1], ls=':', color=color2, alpha=1)
if truth is not None:
ax.axhline(truth, **self.parent.config_truth)

def _plot_walk_truth(self, ax, truth):
ax.axhline(truth, **self.parent.config_truth)

def _convert_to_stdev(self, sigma): # pragma: no cover
# From astroML
Expand Down
28 changes: 28 additions & 0 deletions examples/Basics/plot_hundreds_of_chains.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,3 +71,31 @@

fig = c.plotter.plot()
fig.set_size_inches(2.5 + fig.get_size_inches()) # Resize fig for doco. You don't need this.

###############################################################################
# If you've loaded a whole host of chains in, but only want to focus on one
# set, you can also pick out all chains with the same name when plotting.

fig = c.plotter.plot(chains="Sim1")
fig.set_size_inches(2.5 + fig.get_size_inches()) # Resize fig for doco. You don't need this.

###############################################################################
# Finally, we should clarify what exactly the points mean! If you don't specify
# anything, by defaults the points represent the coordinates of the
# maximum posterior value. However, in high dimensional surfaces, this maximum
# value across all dimensions can be different to the maximum posterior value
# of a 2D slice. If we want to plot, instead of the global maximum as defined
# by the posterior values, the maximum point of each 2D slice, we can specify
# to `configure` that `global_point=False`.

c.configure(legend_artists=True, global_point=False)
fig = c.plotter.plot(chains="Sim1")
fig.set_size_inches(2.5 + fig.get_size_inches()) # Resize fig for doco. You don't need this.

###############################################################################
# Note here that the histograms have disappeared. This is because the maximal
# point changes for each pair of parameters, and so none of the points can
# be used in a histogram. Whilst one could use the maximum point, marginalising
# across all parameters, this can be misleading if only two parameters
# are requested to be plotted. As such, we do not report histograms for
# the maximal 2D posterior points.
10 changes: 6 additions & 4 deletions tests/test_chainconsumer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from chainconsumer import ChainConsumer


class TestChain(object):
class TestChainConsumer(object):
np.random.seed(1)
n = 2000000
data = np.random.normal(loc=5.0, scale=1.5, size=n)
Expand All @@ -28,8 +28,10 @@ def test_get_chain_via_object(self):
c = ChainConsumer()
c.add_chain(self.data, name="A")
c.add_chain(self.data, name="B")
assert c._get_chain(c.chains[0]) == 0
assert c._get_chain(c.chains[1]) == 1
assert c._get_chain(c.chains[0])[0] == 0
assert c._get_chain(c.chains[1])[0] == 1
assert len(c._get_chain(c.chains[0])) == 1
assert len(c._get_chain(c.chains[1])) == 1

def test_summary_bad_input1(self):
with pytest.raises(AssertionError):
Expand Down Expand Up @@ -165,7 +167,7 @@ def test_remove_multiple_chains3(self):

def test_remove_multiple_chains_fails(self):
with pytest.raises(AssertionError):
ChainConsumer().add_chain(self.data).remove_chain(chain=[0, 0])
ChainConsumer().add_chain(self.data).remove_chain(chain=[0,0])

def test_shade_alpha_algorithm1(self):
consumer = ChainConsumer()
Expand Down

0 comments on commit 13a6907

Please sign in to comment.