Skip to content

Commit

Permalink
Expanding plot extents slightly due to change in sigma2d default. A…
Browse files Browse the repository at this point in the history
…dding test for disjoint parameter summaries.
  • Loading branch information
Samreay committed Mar 6, 2018
1 parent 4b7b92d commit 20e48aa
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 10 deletions.
2 changes: 1 addition & 1 deletion chainconsumer/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ def get_extents(data, weight, plot=False, wide_extents=True):
icdf = (1 - cdf)[::-1]
icdf = icdf / icdf.max()
cdf = 1 - icdf[::-1]
threshold = 1e-3 if plot else 1e-5
threshold = 1e-4 if plot else 1e-5
if plot and not wide_extents:
threshold = 0.05
i1 = np.where(cdf > threshold)[0][0]
Expand Down
14 changes: 14 additions & 0 deletions tests/test_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,20 @@ def test_summary_specific(self):
diff = np.abs(expected - actual)
assert np.all(diff < tolerance)

def test_summary_disjoint(self):
tolerance = 5e-2
consumer = ChainConsumer()
consumer.add_chain(self.data, parameters="A")
consumer.add_chain(self.data, parameters="B")
consumer.configure(bins=0.8)
summary = consumer.analysis.get_summary(parameters="A")
assert len(summary) == 2 # Two chains
assert summary[1] == {} # Second chain doesnt have param A
actual = summary[0]["A"]
expected = np.array([3.5, 5.0, 6.5])
diff = np.abs(expected - actual)
assert np.all(diff < tolerance)

def test_output_text(self):
consumer = ChainConsumer()
consumer.add_chain(self.data, parameters=["a"])
Expand Down
18 changes: 9 additions & 9 deletions tests/test_plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,35 +14,35 @@ def test_plotter_extents1(self):
c.add_chain(self.data, parameters=["x"])
c.configure()
minv, maxv = c.plotter._get_parameter_extents("x", c.chains)
assert np.isclose(minv, (5.0 - 1.5 * 3.1), atol=0.1)
assert np.isclose(maxv, (5.0 + 1.5 * 3.1), atol=0.1)
assert np.isclose(minv, (5.0 - 1.5 * 3.7), atol=0.2)
assert np.isclose(maxv, (5.0 + 1.5 * 3.7), atol=0.2)

def test_plotter_extents2(self):
c = ChainConsumer()
c.add_chain(self.data, parameters=["x"])
c.add_chain(self.data + 5, parameters=["y"])
c.configure()
minv, maxv = c.plotter._get_parameter_extents("x", c.chains)
assert np.isclose(minv, (5.0 - 1.5 * 3.1), atol=0.1)
assert np.isclose(maxv, (5.0 + 1.5 * 3.1), atol=0.1)
assert np.isclose(minv, (5.0 - 1.5 * 3.7), atol=0.2)
assert np.isclose(maxv, (5.0 + 1.5 * 3.7), atol=0.2)

def test_plotter_extents3(self):
c = ChainConsumer()
c.add_chain(self.data, parameters=["x"])
c.add_chain(self.data + 5, parameters=["x"])
c.configure()
minv, maxv = c.plotter._get_parameter_extents("x", c.chains)
assert np.isclose(minv, (5.0 - 1.5 * 3.1), atol=0.1)
assert np.isclose(maxv, (10.0 + 1.5 * 3.1), atol=0.1)
assert np.isclose(minv, (5.0 - 1.5 * 3.7), atol=0.2)
assert np.isclose(maxv, (10.0 + 1.5 * 3.7), atol=0.2)

def test_plotter_extents4(self):
c = ChainConsumer()
c.add_chain(self.data, parameters=["x"])
c.add_chain(self.data + 5, parameters=["y"])
c.configure()
minv, maxv = c.plotter._get_parameter_extents("x", c.chains[:1])
assert np.isclose(minv, (5.0 - 1.5 * 3.1), atol=0.1)
assert np.isclose(maxv, (5.0 + 1.5 * 3.1), atol=0.1)
assert np.isclose(minv, (5.0 - 1.5 * 3.7), atol=0.2)
assert np.isclose(maxv, (5.0 + 1.5 * 3.7), atol=0.2)

def test_plotter_extents5(self):
x, y = np.linspace(-3, 3, 200), np.linspace(-5, 5, 200)
Expand All @@ -55,4 +55,4 @@ def test_plotter_extents5(self):
c.configure()
minv, maxv = c.plotter._get_parameter_extents("x", c.chains)
assert np.isclose(minv, -3, atol=0.001)
assert np.isclose(maxv, 3, atol=0.001)
assert np.isclose(maxv, 3, atol=0.001)

0 comments on commit 20e48aa

Please sign in to comment.