diff --git a/anesthetic/samples.py b/anesthetic/samples.py index 2b877395..e6e13400 100644 --- a/anesthetic/samples.py +++ b/anesthetic/samples.py @@ -127,46 +127,47 @@ def plot(self, ax, paramname_x, paramname_y=None, *args, **kwargs): kwargs['label'] = kwargs.get('label', self.label) if do_1d_plot: - xmin, xmax = self._limits(paramname_x) - if plot_type == 'kde': - plot = plot_1d - elif plot_type == 'hist': - plot = hist_1d - elif plot_type is None: - ax.plot([], []) - else: - raise NotImplementedError("plot_type is '%s', but must be in " - "{'kde', 'hist'}." % plot_type) if paramname_x in self and plot_type is not None: x = self[paramname_x].compress() + xmin, xmax = self._limits(paramname_x) + if plot_type == 'kde': + plot = plot_1d + elif plot_type == 'hist': + plot = hist_1d + else: + raise NotImplementedError("plot_type is '%s', but must be" + " in {'kde', 'hist'}." + % plot_type) return plot(ax, x, xmin=xmin, xmax=xmax, *args, **kwargs) - - else: - xmin, xmax = self._limits(paramname_x) - ymin, ymax = self._limits(paramname_y) - - if plot_type == 'kde': - nsamples = None - plot = contour_plot_2d - ax.scatter([], []) - elif plot_type == 'scatter': - nsamples = 500 - plot = scatter_plot_2d - ax.plot([], []) - elif plot_type is None: - ax.plot([], []) - ax.scatter([], []) else: - raise NotImplementedError("plot_type is '%s', but must be in " - "{'kde', 'scatter'}." % plot_type) + ax.plot([], []) + else: if (paramname_x in self and paramname_y in self and plot_type is not None): + if plot_type == 'kde': + nsamples = None + plot = contour_plot_2d + ax.scatter([], []) + elif plot_type == 'scatter': + nsamples = 500 + plot = scatter_plot_2d + ax.plot([], []) + else: + raise NotImplementedError("plot_type is '%s', but must be" + "in {'kde', 'scatter'}." + % plot_type) + xmin, xmax = self._limits(paramname_x) + ymin, ymax = self._limits(paramname_y) x = self[paramname_x].compress(nsamples) y = self[paramname_y].compress(nsamples) return plot(ax, x, y, xmin=xmin, xmax=xmax, ymin=ymin, ymax=ymax, *args, **kwargs) + else: + ax.plot([], []) + ax.scatter([], []) + def plot_1d(self, axes, *args, **kwargs): """Create an array of 1D plots. @@ -195,8 +196,7 @@ def plot_1d(self, axes, *args, **kwargs): fig = axes.values[~axes.isna()][0].figure for x, ax in axes.iteritems(): - if ax is not None and x in self: - self.plot(ax, x, *args, **kwargs) + self.plot(ax, x, *args, **kwargs) return fig, axes diff --git a/tests/test_samples.py b/tests/test_samples.py index cc3c4f5a..3945c500 100644 --- a/tests/test_samples.py +++ b/tests/test_samples.py @@ -8,6 +8,7 @@ from matplotlib.collections import PathCollection from anesthetic import MCMCSamples, NestedSamples, make_1d_axes, make_2d_axes from numpy.testing import assert_array_equal +from matplotlib.colors import to_hex try: import montepython # noqa: F401 except ImportError: @@ -64,6 +65,7 @@ def test_build_mcmc(): def test_read_getdist(): + numpy.random.seed(3) mcmc = MCMCSamples(root='./tests/example_data/gd') mcmc.plot_2d(['x0', 'x1', 'x2', 'x3']) mcmc.plot_1d(['x0', 'x1', 'x2', 'x3']) @@ -78,6 +80,7 @@ def test_read_getdist(): raises=ImportError, reason="requires montepython package") def test_read_montepython(): + numpy.random.seed(3) mcmc = MCMCSamples(root='./tests/example_data/mp') mcmc.plot_2d(['x0', 'x1', 'x2', 'x3']) mcmc.plot_1d(['x0', 'x1', 'x2', 'x3']) @@ -85,6 +88,7 @@ def test_read_montepython(): def test_read_multinest(): + numpy.random.seed(3) ns = NestedSamples(root='./tests/example_data/mn') ns.plot_2d(['x0', 'x1', 'x2', 'x3']) ns.plot_1d(['x0', 'x1', 'x2', 'x3']) @@ -96,6 +100,7 @@ def test_read_multinest(): def test_read_polychord(): + numpy.random.seed(3) ns = NestedSamples(root='./tests/example_data/pc') ns.plot_2d(['x0', 'x1', 'x2', 'x3']) ns.plot_1d(['x0', 'x1', 'x2', 'x3']) @@ -103,6 +108,7 @@ def test_read_polychord(): def test_different_parameters(): + numpy.random.seed(3) params_x = ['x0', 'x1', 'x2', 'x3', 'x4'] params_y = ['x0', 'x1', 'x2'] fig, axes = make_1d_axes(params_x) @@ -118,6 +124,7 @@ def test_different_parameters(): def test_plot_2d_types(): + numpy.random.seed(3) ns = NestedSamples(root='./tests/example_data/pc') params_x = ['x0', 'x1', 'x2', 'x3'] params_y = ['x0', 'x1', 'x2'] @@ -152,6 +159,7 @@ def test_plot_2d_types(): def test_plot_2d_types_multiple_calls(): + numpy.random.seed(3) ns = NestedSamples(root='./tests/example_data/pc') params = ['x0', 'x1', 'x2', 'x3'] @@ -168,6 +176,7 @@ def test_plot_2d_types_multiple_calls(): def test_root_and_label(): + numpy.random.seed(3) ns = NestedSamples(root='./tests/example_data/pc') assert(ns.root == './tests/example_data/pc') assert(ns.label == 'pc') @@ -186,6 +195,7 @@ def test_root_and_label(): def test_plot_2d_legend(): + numpy.random.seed(3) ns = NestedSamples(root='./tests/example_data/pc') mc = MCMCSamples(root='./tests/example_data/gd') params = ['x0', 'x1', 'x2', 'x3'] @@ -250,3 +260,83 @@ def test_plot_2d_legend(): handles, labels = ax.get_legend_handles_labels() assert(labels == ['l1', 'l2']) plt.close('all') + + +def test_plot_2d_colours(): + numpy.random.seed(3) + gd = MCMCSamples(root="./tests/example_data/gd") + gd.drop(columns='x3', inplace=True) + pc = NestedSamples(root="./tests/example_data/pc") + pc.drop(columns='x4', inplace=True) + mn = NestedSamples(root="./tests/example_data/mn") + mn.drop(columns='x2', inplace=True) + + fig = plt.figure() + fig, axes = make_2d_axes(['x0', 'x1', 'x2', 'x3', 'x4'], fig=fig) + gd.plot_2d(axes, label="gd") + pc.plot_2d(axes, label="pc") + mn.plot_2d(axes, label="mn") + gd_colors = [] + pc_colors = [] + mn_colors = [] + for y, rows in axes.iterrows(): + for x, ax in rows.iteritems(): + handles, labels = ax.get_legend_handles_labels() + for handle, label in zip(handles, labels): + if isinstance(handle, Rectangle): + color = to_hex(handle.get_facecolor()) + elif isinstance(handle, PathCollection): + color = to_hex(handle.get_facecolor()[0]) + else: + color = handle.get_color() + + if label == 'gd': + gd_colors.append(color) + elif label == 'pc': + pc_colors.append(color) + elif label == 'mn': + mn_colors.append(color) + + assert(len(set(gd_colors)) == 1) + assert(len(set(mn_colors)) == 1) + assert(len(set(pc_colors)) == 1) + + +def test_plot_1d_colours(): + numpy.random.seed(3) + gd = MCMCSamples(root="./tests/example_data/gd") + gd.drop(columns='x3', inplace=True) + pc = NestedSamples(root="./tests/example_data/pc") + pc.drop(columns='x4', inplace=True) + mn = NestedSamples(root="./tests/example_data/mn") + mn.drop(columns='x2', inplace=True) + + for plot_type in ['kde', 'hist']: + fig = plt.figure() + fig, axes = make_1d_axes(['x0', 'x1', 'x2', 'x3', 'x4'], fig=fig) + gd.plot_1d(axes, plot_type=plot_type, label="gd") + pc.plot_1d(axes, plot_type=plot_type, label="pc") + mn.plot_1d(axes, plot_type=plot_type, label="mn") + gd_colors = [] + pc_colors = [] + mn_colors = [] + for x, ax in axes.iteritems(): + handles, labels = ax.get_legend_handles_labels() + for handle, label in zip(handles, labels): + if isinstance(handle, Rectangle): + color = to_hex(handle.get_facecolor()) + elif isinstance(handle, PathCollection): + color = to_hex(handle.get_facecolor()[0]) + else: + color = handle.get_color() + + if label == 'gd': + gd_colors.append(color) + elif label == 'pc': + pc_colors.append(color) + elif label == 'mn': + mn_colors.append(color) + + assert(len(set(gd_colors)) == 1) + assert(len(set(mn_colors)) == 1) + assert(len(set(pc_colors)) == 1)