Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix to colouring for missing plots #45

Merged
merged 6 commits into from
Jul 27, 2019
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 35 additions & 30 deletions anesthetic/samples.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,46 +127,47 @@ def plot(self, ax, paramname_x, paramname_y=None, *args, **kwargs):
kwargs['label'] = kwargs.get('label', self.label)

if do_1d_plot:
xmin, xmax = self._limits(paramname_x)
if plot_type == 'kde':
plot = plot_1d
elif plot_type == 'hist':
plot = hist_1d
elif plot_type is None:
ax.plot([], [])
else:
raise NotImplementedError("plot_type is '%s', but must be in "
"{'kde', 'hist'}." % plot_type)
if paramname_x in self and plot_type is not None:
x = self[paramname_x].compress()
xmin, xmax = self._limits(paramname_x)
if plot_type == 'kde':
plot = plot_1d
elif plot_type == 'hist':
plot = hist_1d
else:
raise NotImplementedError("plot_type is '%s', but must be"
" in {'kde', 'hist'}."
% plot_type)
return plot(ax, x, xmin=xmin, xmax=xmax, *args, **kwargs)

else:
xmin, xmax = self._limits(paramname_x)
ymin, ymax = self._limits(paramname_y)

if plot_type == 'kde':
nsamples = None
plot = contour_plot_2d
ax.scatter([], [])
elif plot_type == 'scatter':
nsamples = 500
plot = scatter_plot_2d
ax.plot([], [])
elif plot_type is None:
ax.plot([], [])
ax.scatter([], [])
else:
raise NotImplementedError("plot_type is '%s', but must be in "
"{'kde', 'scatter'}." % plot_type)
ax.plot([], [])

else:
if (paramname_x in self and paramname_y in self
and plot_type is not None):
if plot_type == 'kde':
nsamples = None
plot = contour_plot_2d
ax.scatter([], [])
elif plot_type == 'scatter':
nsamples = 500
plot = scatter_plot_2d
ax.plot([], [])
else:
raise NotImplementedError("plot_type is '%s', but must be"
"in {'kde', 'scatter'}."
% plot_type)
xmin, xmax = self._limits(paramname_x)
ymin, ymax = self._limits(paramname_y)
x = self[paramname_x].compress(nsamples)
y = self[paramname_y].compress(nsamples)
return plot(ax, x, y, xmin=xmin, xmax=xmax, ymin=ymin,
ymax=ymax, *args, **kwargs)

else:
ax.plot([], [])
ax.scatter([], [])

def plot_1d(self, axes, *args, **kwargs):
"""Create an array of 1D plots.

Expand All @@ -180,6 +181,9 @@ def plot_1d(self, axes, *args, **kwargs):
this is used for creating the plot. Otherwise a new set of axes are
created using the list or lists of strings.

types: str
What types of plots to produce. Options are {'kde', 'hist'}

Returns
-------
fig: matplotlib.figure.Figure
Expand All @@ -189,14 +193,15 @@ def plot_1d(self, axes, *args, **kwargs):
Pandas array of axes objects

"""
plot_type = kwargs.pop('types', 'kde')

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So far I was using plot_type (which is popped in samples.plot) to decide what kind of 1D plot I want. Now, providing the kwarg plot_type results in

TypeError: plot() got multiple values for keyword argument 'plot_type'

Code to reproduce:

pc = NestedSamples(root="./tests/example_data/pc")
pc.plot_1d(['x0', 'x1'], label="pc", plot_type='hist');

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, I had noticed that MCMCSamples.plot_1d was missing a 'types' argument, in analogy to MCMCSamples.plot_2d. In this PR I've made them more consistent in preparation for the finer-grained control suggested in #41 .

The fix is to pass types='hist' rather than plot_type='hist'. I acknowledge that this is technically a breaking change, although only to undocumented (and, tbh unknown on my part) functionality.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One of the things I'm considering is deprecating the cryptic types arguments, with the plan to replace them with subplot_types, which is more informative.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

By the way, the error actually also happens in the 2D case:

pc = NestedSamples(root="./tests/example_data/pc")
fig, axes = pc.plot_2d(['x0', 'x1', 'x4', 'x5'], label="pc", plot_type='kde');

Might it be more consistent to allow plot_type as a kwarg for plot_2d that broadcasts the provided string to all subplots?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure how to fix this, but having plot_type and types in parallel like this seems wrong to me. If types is the one intended to be passed to plot_1d then plot_type should not be a valid kwarg.

I acknowledge that this is technically a breaking change, although only to undocumented (and, tbh unknown on my part) functionality.

Not undocumented with #24. Actually, plot_type was one of my major motivations for suggesting #24.

In this PR I've made them more consistent in preparation for the finer-grained control suggested in #41 .

The thing is, plot_type and types don't quite have the same role, so I'm not sure whether this is really more consistent. The need for types arises because in the 2D case we distinguish between 3 different families of plots (diagonal, lower and upper) which all need a plot_type assignment. This is a 2D specific feature and, thus, types is unnecessary in the 1D case.

One of the things I'm considering is deprecating the cryptic types arguments, with the plan to replace them with subplot_types, which is more informative.

Replacing types with a more general subplot_types that determines e.g. which plot type for which parameter would be an option. At the same type plot_type could be replaced with e.g. default_plot_type that broadcasts the plot type across all subplots and subplot_types can then be used to pick and alter specific parameters.

Copy link
Collaborator

@lukashergt lukashergt Jul 27, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In any case, might it be better to revert this specific change and address all this in a separate PR that targets #41 ?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed, reverted in 61e3f6b

if not isinstance(axes, pandas.Series):
fig, axes = make_1d_axes(axes, tex=self.tex)
else:
fig = axes.values[~axes.isna()][0].figure

for x, ax in axes.iteritems():
if ax is not None and x in self:
self.plot(ax, x, *args, **kwargs)
self.plot(ax, x, plot_type=plot_type, *args, **kwargs)

return fig, axes

Expand Down
90 changes: 90 additions & 0 deletions tests/test_samples.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from matplotlib.collections import PathCollection
from anesthetic import MCMCSamples, NestedSamples, make_1d_axes, make_2d_axes
from numpy.testing import assert_array_equal
from matplotlib.colors import to_hex
try:
import montepython # noqa: F401
except ImportError:
Expand Down Expand Up @@ -64,6 +65,7 @@ def test_build_mcmc():


def test_read_getdist():
numpy.random.seed(3)
mcmc = MCMCSamples(root='./tests/example_data/gd')
mcmc.plot_2d(['x0', 'x1', 'x2', 'x3'])
mcmc.plot_1d(['x0', 'x1', 'x2', 'x3'])
Expand All @@ -78,13 +80,15 @@ def test_read_getdist():
raises=ImportError,
reason="requires montepython package")
def test_read_montepython():
numpy.random.seed(3)
mcmc = MCMCSamples(root='./tests/example_data/mp')
mcmc.plot_2d(['x0', 'x1', 'x2', 'x3'])
mcmc.plot_1d(['x0', 'x1', 'x2', 'x3'])
plt.close("all")


def test_read_multinest():
numpy.random.seed(3)
ns = NestedSamples(root='./tests/example_data/mn')
ns.plot_2d(['x0', 'x1', 'x2', 'x3'])
ns.plot_1d(['x0', 'x1', 'x2', 'x3'])
Expand All @@ -96,13 +100,15 @@ def test_read_multinest():


def test_read_polychord():
numpy.random.seed(3)
ns = NestedSamples(root='./tests/example_data/pc')
ns.plot_2d(['x0', 'x1', 'x2', 'x3'])
ns.plot_1d(['x0', 'x1', 'x2', 'x3'])
plt.close("all")


def test_different_parameters():
numpy.random.seed(3)
params_x = ['x0', 'x1', 'x2', 'x3', 'x4']
params_y = ['x0', 'x1', 'x2']
fig, axes = make_1d_axes(params_x)
Expand All @@ -118,6 +124,7 @@ def test_different_parameters():


def test_plot_2d_types():
numpy.random.seed(3)
ns = NestedSamples(root='./tests/example_data/pc')
params_x = ['x0', 'x1', 'x2', 'x3']
params_y = ['x0', 'x1', 'x2']
Expand Down Expand Up @@ -152,6 +159,7 @@ def test_plot_2d_types():


def test_plot_2d_types_multiple_calls():
numpy.random.seed(3)
ns = NestedSamples(root='./tests/example_data/pc')
params = ['x0', 'x1', 'x2', 'x3']

Expand All @@ -168,6 +176,7 @@ def test_plot_2d_types_multiple_calls():


def test_root_and_label():
numpy.random.seed(3)
ns = NestedSamples(root='./tests/example_data/pc')
assert(ns.root == './tests/example_data/pc')
assert(ns.label == 'pc')
Expand All @@ -186,6 +195,7 @@ def test_root_and_label():


def test_plot_2d_legend():
numpy.random.seed(3)
ns = NestedSamples(root='./tests/example_data/pc')
mc = MCMCSamples(root='./tests/example_data/gd')
params = ['x0', 'x1', 'x2', 'x3']
Expand Down Expand Up @@ -250,3 +260,83 @@ def test_plot_2d_legend():
handles, labels = ax.get_legend_handles_labels()
assert(labels == ['l1', 'l2'])
plt.close('all')


def test_plot_2d_colours():
numpy.random.seed(3)
gd = MCMCSamples(root="./tests/example_data/gd")
gd.drop(columns='x3', inplace=True)
pc = NestedSamples(root="./tests/example_data/pc")
pc.drop(columns='x4', inplace=True)
mn = NestedSamples(root="./tests/example_data/mn")
mn.drop(columns='x2', inplace=True)

fig = plt.figure()
fig, axes = make_2d_axes(['x0', 'x1', 'x2', 'x3', 'x4'], fig=fig)
gd.plot_2d(axes, label="gd")
pc.plot_2d(axes, label="pc")
mn.plot_2d(axes, label="mn")
gd_colors = []
pc_colors = []
mn_colors = []
for y, rows in axes.iterrows():
for x, ax in rows.iteritems():
handles, labels = ax.get_legend_handles_labels()
for handle, label in zip(handles, labels):
if isinstance(handle, Rectangle):
color = to_hex(handle.get_facecolor())
elif isinstance(handle, PathCollection):
color = to_hex(handle.get_facecolor()[0])
else:
color = handle.get_color()

if label == 'gd':
gd_colors.append(color)
elif label == 'pc':
pc_colors.append(color)
elif label == 'mn':
mn_colors.append(color)

assert(len(set(gd_colors)) == 1)
assert(len(set(mn_colors)) == 1)
assert(len(set(pc_colors)) == 1)


def test_plot_1d_colours():
numpy.random.seed(3)
gd = MCMCSamples(root="./tests/example_data/gd")
gd.drop(columns='x3', inplace=True)
pc = NestedSamples(root="./tests/example_data/pc")
pc.drop(columns='x4', inplace=True)
mn = NestedSamples(root="./tests/example_data/mn")
mn.drop(columns='x2', inplace=True)

for types in ['kde', 'hist']:
fig = plt.figure()
fig, axes = make_1d_axes(['x0', 'x1', 'x2', 'x3', 'x4'], fig=fig)
gd.plot_1d(axes, types=types, label="gd")
pc.plot_1d(axes, types=types, label="pc")
mn.plot_1d(axes, types=types, label="mn")
gd_colors = []
pc_colors = []
mn_colors = []
for x, ax in axes.iteritems():
handles, labels = ax.get_legend_handles_labels()
for handle, label in zip(handles, labels):
if isinstance(handle, Rectangle):
color = to_hex(handle.get_facecolor())
elif isinstance(handle, PathCollection):
color = to_hex(handle.get_facecolor()[0])
else:
color = handle.get_color()

if label == 'gd':
gd_colors.append(color)
elif label == 'pc':
pc_colors.append(color)
elif label == 'mn':
mn_colors.append(color)

assert(len(set(gd_colors)) == 1)
assert(len(set(mn_colors)) == 1)
assert(len(set(pc_colors)) == 1)