diff --git a/chainconsumer/helpers.py b/chainconsumer/helpers.py index 889fc3e8..b280efe3 100644 --- a/chainconsumer/helpers.py +++ b/chainconsumer/helpers.py @@ -9,7 +9,7 @@ def get_extents(data, weight, plot=False, wide_extents=True): icdf = (1 - cdf)[::-1] icdf = icdf / icdf.max() cdf = 1 - icdf[::-1] - threshold = 1e-3 if plot else 1e-5 + threshold = 1e-4 if plot else 1e-5 if plot and not wide_extents: threshold = 0.05 i1 = np.where(cdf > threshold)[0][0] diff --git a/tests/test_analysis.py b/tests/test_analysis.py index 99b92c1a..f6ac5355 100644 --- a/tests/test_analysis.py +++ b/tests/test_analysis.py @@ -91,6 +91,20 @@ def test_summary_specific(self): diff = np.abs(expected - actual) assert np.all(diff < tolerance) + def test_summary_disjoint(self): + tolerance = 5e-2 + consumer = ChainConsumer() + consumer.add_chain(self.data, parameters="A") + consumer.add_chain(self.data, parameters="B") + consumer.configure(bins=0.8) + summary = consumer.analysis.get_summary(parameters="A") + assert len(summary) == 2 # Two chains + assert summary[1] == {} # Second chain doesnt have param A + actual = summary[0]["A"] + expected = np.array([3.5, 5.0, 6.5]) + diff = np.abs(expected - actual) + assert np.all(diff < tolerance) + def test_output_text(self): consumer = ChainConsumer() consumer.add_chain(self.data, parameters=["a"]) diff --git a/tests/test_plotter.py b/tests/test_plotter.py index 0160c95a..f26b51cf 100644 --- a/tests/test_plotter.py +++ b/tests/test_plotter.py @@ -14,8 +14,8 @@ def test_plotter_extents1(self): c.add_chain(self.data, parameters=["x"]) c.configure() minv, maxv = c.plotter._get_parameter_extents("x", c.chains) - assert np.isclose(minv, (5.0 - 1.5 * 3.1), atol=0.1) - assert np.isclose(maxv, (5.0 + 1.5 * 3.1), atol=0.1) + assert np.isclose(minv, (5.0 - 1.5 * 3.7), atol=0.2) + assert np.isclose(maxv, (5.0 + 1.5 * 3.7), atol=0.2) def test_plotter_extents2(self): c = ChainConsumer() @@ -23,8 +23,8 @@ def test_plotter_extents2(self): c.add_chain(self.data + 5, parameters=["y"]) c.configure() minv, maxv = c.plotter._get_parameter_extents("x", c.chains) - assert np.isclose(minv, (5.0 - 1.5 * 3.1), atol=0.1) - assert np.isclose(maxv, (5.0 + 1.5 * 3.1), atol=0.1) + assert np.isclose(minv, (5.0 - 1.5 * 3.7), atol=0.2) + assert np.isclose(maxv, (5.0 + 1.5 * 3.7), atol=0.2) def test_plotter_extents3(self): c = ChainConsumer() @@ -32,8 +32,8 @@ def test_plotter_extents3(self): c.add_chain(self.data + 5, parameters=["x"]) c.configure() minv, maxv = c.plotter._get_parameter_extents("x", c.chains) - assert np.isclose(minv, (5.0 - 1.5 * 3.1), atol=0.1) - assert np.isclose(maxv, (10.0 + 1.5 * 3.1), atol=0.1) + assert np.isclose(minv, (5.0 - 1.5 * 3.7), atol=0.2) + assert np.isclose(maxv, (10.0 + 1.5 * 3.7), atol=0.2) def test_plotter_extents4(self): c = ChainConsumer() @@ -41,8 +41,8 @@ def test_plotter_extents4(self): c.add_chain(self.data + 5, parameters=["y"]) c.configure() minv, maxv = c.plotter._get_parameter_extents("x", c.chains[:1]) - assert np.isclose(minv, (5.0 - 1.5 * 3.1), atol=0.1) - assert np.isclose(maxv, (5.0 + 1.5 * 3.1), atol=0.1) + assert np.isclose(minv, (5.0 - 1.5 * 3.7), atol=0.2) + assert np.isclose(maxv, (5.0 + 1.5 * 3.7), atol=0.2) def test_plotter_extents5(self): x, y = np.linspace(-3, 3, 200), np.linspace(-5, 5, 200) @@ -55,4 +55,4 @@ def test_plotter_extents5(self): c.configure() minv, maxv = c.plotter._get_parameter_extents("x", c.chains) assert np.isclose(minv, -3, atol=0.001) - assert np.isclose(maxv, 3, atol=0.001) \ No newline at end of file + assert np.isclose(maxv, 3, atol=0.001)