Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix to colouring for missing plots #45

Merged
merged 6 commits into from
Jul 27, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 30 additions & 30 deletions anesthetic/samples.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,46 +127,47 @@ def plot(self, ax, paramname_x, paramname_y=None, *args, **kwargs):
kwargs['label'] = kwargs.get('label', self.label)

if do_1d_plot:
xmin, xmax = self._limits(paramname_x)
if plot_type == 'kde':
plot = plot_1d
elif plot_type == 'hist':
plot = hist_1d
elif plot_type is None:
ax.plot([], [])
else:
raise NotImplementedError("plot_type is '%s', but must be in "
"{'kde', 'hist'}." % plot_type)
if paramname_x in self and plot_type is not None:
x = self[paramname_x].compress()
xmin, xmax = self._limits(paramname_x)
if plot_type == 'kde':
plot = plot_1d
elif plot_type == 'hist':
plot = hist_1d
else:
raise NotImplementedError("plot_type is '%s', but must be"
" in {'kde', 'hist'}."
% plot_type)
return plot(ax, x, xmin=xmin, xmax=xmax, *args, **kwargs)

else:
xmin, xmax = self._limits(paramname_x)
ymin, ymax = self._limits(paramname_y)

if plot_type == 'kde':
nsamples = None
plot = contour_plot_2d
ax.scatter([], [])
elif plot_type == 'scatter':
nsamples = 500
plot = scatter_plot_2d
ax.plot([], [])
elif plot_type is None:
ax.plot([], [])
ax.scatter([], [])
else:
raise NotImplementedError("plot_type is '%s', but must be in "
"{'kde', 'scatter'}." % plot_type)
ax.plot([], [])

else:
if (paramname_x in self and paramname_y in self
and plot_type is not None):
if plot_type == 'kde':
nsamples = None
plot = contour_plot_2d
ax.scatter([], [])
elif plot_type == 'scatter':
nsamples = 500
plot = scatter_plot_2d
ax.plot([], [])
else:
raise NotImplementedError("plot_type is '%s', but must be"
"in {'kde', 'scatter'}."
% plot_type)
xmin, xmax = self._limits(paramname_x)
ymin, ymax = self._limits(paramname_y)
x = self[paramname_x].compress(nsamples)
y = self[paramname_y].compress(nsamples)
return plot(ax, x, y, xmin=xmin, xmax=xmax, ymin=ymin,
ymax=ymax, *args, **kwargs)

else:
ax.plot([], [])
ax.scatter([], [])

def plot_1d(self, axes, *args, **kwargs):
"""Create an array of 1D plots.

Expand Down Expand Up @@ -195,8 +196,7 @@ def plot_1d(self, axes, *args, **kwargs):
fig = axes.values[~axes.isna()][0].figure

for x, ax in axes.iteritems():
if ax is not None and x in self:
self.plot(ax, x, *args, **kwargs)
self.plot(ax, x, *args, **kwargs)

return fig, axes

Expand Down
90 changes: 90 additions & 0 deletions tests/test_samples.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from matplotlib.collections import PathCollection
from anesthetic import MCMCSamples, NestedSamples, make_1d_axes, make_2d_axes
from numpy.testing import assert_array_equal
from matplotlib.colors import to_hex
try:
import montepython # noqa: F401
except ImportError:
Expand Down Expand Up @@ -64,6 +65,7 @@ def test_build_mcmc():


def test_read_getdist():
numpy.random.seed(3)
mcmc = MCMCSamples(root='./tests/example_data/gd')
mcmc.plot_2d(['x0', 'x1', 'x2', 'x3'])
mcmc.plot_1d(['x0', 'x1', 'x2', 'x3'])
Expand All @@ -78,13 +80,15 @@ def test_read_getdist():
raises=ImportError,
reason="requires montepython package")
def test_read_montepython():
numpy.random.seed(3)
mcmc = MCMCSamples(root='./tests/example_data/mp')
mcmc.plot_2d(['x0', 'x1', 'x2', 'x3'])
mcmc.plot_1d(['x0', 'x1', 'x2', 'x3'])
plt.close("all")


def test_read_multinest():
numpy.random.seed(3)
ns = NestedSamples(root='./tests/example_data/mn')
ns.plot_2d(['x0', 'x1', 'x2', 'x3'])
ns.plot_1d(['x0', 'x1', 'x2', 'x3'])
Expand All @@ -96,13 +100,15 @@ def test_read_multinest():


def test_read_polychord():
numpy.random.seed(3)
ns = NestedSamples(root='./tests/example_data/pc')
ns.plot_2d(['x0', 'x1', 'x2', 'x3'])
ns.plot_1d(['x0', 'x1', 'x2', 'x3'])
plt.close("all")


def test_different_parameters():
numpy.random.seed(3)
params_x = ['x0', 'x1', 'x2', 'x3', 'x4']
params_y = ['x0', 'x1', 'x2']
fig, axes = make_1d_axes(params_x)
Expand All @@ -118,6 +124,7 @@ def test_different_parameters():


def test_plot_2d_types():
numpy.random.seed(3)
ns = NestedSamples(root='./tests/example_data/pc')
params_x = ['x0', 'x1', 'x2', 'x3']
params_y = ['x0', 'x1', 'x2']
Expand Down Expand Up @@ -152,6 +159,7 @@ def test_plot_2d_types():


def test_plot_2d_types_multiple_calls():
numpy.random.seed(3)
ns = NestedSamples(root='./tests/example_data/pc')
params = ['x0', 'x1', 'x2', 'x3']

Expand All @@ -168,6 +176,7 @@ def test_plot_2d_types_multiple_calls():


def test_root_and_label():
numpy.random.seed(3)
ns = NestedSamples(root='./tests/example_data/pc')
assert(ns.root == './tests/example_data/pc')
assert(ns.label == 'pc')
Expand All @@ -186,6 +195,7 @@ def test_root_and_label():


def test_plot_2d_legend():
numpy.random.seed(3)
ns = NestedSamples(root='./tests/example_data/pc')
mc = MCMCSamples(root='./tests/example_data/gd')
params = ['x0', 'x1', 'x2', 'x3']
Expand Down Expand Up @@ -250,3 +260,83 @@ def test_plot_2d_legend():
handles, labels = ax.get_legend_handles_labels()
assert(labels == ['l1', 'l2'])
plt.close('all')


def test_plot_2d_colours():
numpy.random.seed(3)
gd = MCMCSamples(root="./tests/example_data/gd")
gd.drop(columns='x3', inplace=True)
pc = NestedSamples(root="./tests/example_data/pc")
pc.drop(columns='x4', inplace=True)
mn = NestedSamples(root="./tests/example_data/mn")
mn.drop(columns='x2', inplace=True)

fig = plt.figure()
fig, axes = make_2d_axes(['x0', 'x1', 'x2', 'x3', 'x4'], fig=fig)
gd.plot_2d(axes, label="gd")
pc.plot_2d(axes, label="pc")
mn.plot_2d(axes, label="mn")
gd_colors = []
pc_colors = []
mn_colors = []
for y, rows in axes.iterrows():
for x, ax in rows.iteritems():
handles, labels = ax.get_legend_handles_labels()
for handle, label in zip(handles, labels):
if isinstance(handle, Rectangle):
color = to_hex(handle.get_facecolor())
elif isinstance(handle, PathCollection):
color = to_hex(handle.get_facecolor()[0])
else:
color = handle.get_color()

if label == 'gd':
gd_colors.append(color)
elif label == 'pc':
pc_colors.append(color)
elif label == 'mn':
mn_colors.append(color)

assert(len(set(gd_colors)) == 1)
assert(len(set(mn_colors)) == 1)
assert(len(set(pc_colors)) == 1)


def test_plot_1d_colours():
numpy.random.seed(3)
gd = MCMCSamples(root="./tests/example_data/gd")
gd.drop(columns='x3', inplace=True)
pc = NestedSamples(root="./tests/example_data/pc")
pc.drop(columns='x4', inplace=True)
mn = NestedSamples(root="./tests/example_data/mn")
mn.drop(columns='x2', inplace=True)

for plot_type in ['kde', 'hist']:
fig = plt.figure()
fig, axes = make_1d_axes(['x0', 'x1', 'x2', 'x3', 'x4'], fig=fig)
gd.plot_1d(axes, plot_type=plot_type, label="gd")
pc.plot_1d(axes, plot_type=plot_type, label="pc")
mn.plot_1d(axes, plot_type=plot_type, label="mn")
gd_colors = []
pc_colors = []
mn_colors = []
for x, ax in axes.iteritems():
handles, labels = ax.get_legend_handles_labels()
for handle, label in zip(handles, labels):
if isinstance(handle, Rectangle):
color = to_hex(handle.get_facecolor())
elif isinstance(handle, PathCollection):
color = to_hex(handle.get_facecolor()[0])
else:
color = handle.get_color()

if label == 'gd':
gd_colors.append(color)
elif label == 'pc':
pc_colors.append(color)
elif label == 'mn':
mn_colors.append(color)

assert(len(set(gd_colors)) == 1)
assert(len(set(mn_colors)) == 1)
assert(len(set(pc_colors)) == 1)