diff --git a/demes/demes.py b/demes/demes.py index cc0b3909..deb852a0 100644 --- a/demes/demes.py +++ b/demes/demes.py @@ -1143,6 +1143,36 @@ def time_span(self): """ return self.start_time - self.end_time + def size_at(self, time: float) -> float: + """ + Get the size of the deme at a given time. + + :param float time: The time at which the size should be calculated. + :return: The deme size. + :rtype: float + """ + for epoch in self.epochs: + if epoch.start_time >= time >= epoch.end_time: + break + else: + raise ValueError( + f"time {time} is outside deme {self.name}'s existence interval: " + f"({self.start_time}, {self.end_time}]" + ) + + if math.isclose(time, epoch.end_time) or epoch.size_function == "constant": + N = epoch.end_size + elif epoch.size_function == "exponential": + dt = (epoch.start_time - time) / epoch.time_span + r = math.log(epoch.end_size / epoch.start_size) + N = epoch.start_size * math.exp(r * dt) + elif epoch.size_function == "linear": + dt = (epoch.start_time - time) / epoch.time_span + N = epoch.start_size + (epoch.end_size - epoch.start_size) * dt + else: + raise NotImplementedError(f"unknown size_function '{epoch.size_function}'") + return N + @attr.s(auto_attribs=True, kw_only=True, slots=True) class Graph: diff --git a/demes/hypothesis_strategies.py b/demes/hypothesis_strategies.py index 1e781683..4b89aa60 100644 --- a/demes/hypothesis_strategies.py +++ b/demes/hypothesis_strategies.py @@ -89,6 +89,7 @@ def epochs_lists( max_epochs=5, min_deme_size=FLOAT_EPS, max_deme_size=FLOAT_MAX, + size_functions=None, ): """ A hypothesis strategy for creating lists of Epochs for a deme. @@ -96,6 +97,8 @@ def epochs_lists( :param float start_time: The start time of the deme. :param int max_epochs: The maximum number of epochs in the list. """ + if size_functions is None: + size_functions = ["constant", "exponential", "linear"] assert max_epochs >= 2 times = draw( st.lists( @@ -117,8 +120,15 @@ def epochs_lists( start_size = draw(st.floats(min_value=min_deme_size, max_value=max_deme_size)) if i == 0 and math.isinf(start_time): end_size = start_size + size_function = "constant" else: - end_size = draw(st.floats(min_value=min_deme_size, max_value=max_deme_size)) + size_function = draw(st.sampled_from(size_functions)) + if size_function == "constant": + end_size = start_size + else: + end_size = draw( + st.floats(min_value=min_deme_size, max_value=max_deme_size) + ) cloning_rate = draw(st.floats(min_value=0, max_value=1)) selfing_rate = draw(st.floats(min_value=0, max_value=prec32(1 - cloning_rate))) @@ -127,6 +137,7 @@ def epochs_lists( end_time=end_time, start_size=start_size, end_size=end_size, + size_function=size_function, cloning_rate=cloning_rate, selfing_rate=selfing_rate, ) @@ -319,6 +330,7 @@ def graphs( max_pulses=10, min_deme_size=FLOAT_EPS, max_deme_size=FLOAT_MAX, + size_functions=None, ): """ A hypothesis strategy for creating a Graph. @@ -336,6 +348,7 @@ def test_something(self, graph: demes.Graph): :param int max_pulses: The maximum number of pulses in the graph. :param float min_deme_size: The minimum size of a deme in any epoch. :param float max_deme_size: The maximum size of a deme in any epoch. + :param list size_functions: Allowable values for an epoch's size_function. """ generation_time = draw( st.none() | st.floats(min_value=FLOAT_EPS, max_value=FLOAT_MAX) @@ -418,6 +431,7 @@ def test_something(self, graph: demes.Graph): max_epochs=max_epochs, min_deme_size=min_deme_size, max_deme_size=max_deme_size, + size_functions=size_functions, ) ), start_time=start_time, diff --git a/setup.cfg b/setup.cfg index f7d85777..406a773e 100644 --- a/setup.cfg +++ b/setup.cfg @@ -54,3 +54,6 @@ ignore_missing_imports = True [tool.black] target_version = py36 +[tool:pytest] +addopts = -n auto +testpaths = tests diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 00000000..1491a7ea --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1,24 @@ +import pathlib +import functools + +import demes + + +@functools.lru_cache(maxsize=None) +def example_files(): + cwd = pathlib.Path(__file__).parent.resolve() + example_dir = cwd / "../examples" + files = list(example_dir.glob("**/*.yaml")) + files += list(example_dir.glob("**/*.yml")) + assert len(files) > 1 + return files + + +@functools.lru_cache(maxsize=None) +def example_graphs(): + return [demes.load(fn) for fn in example_files()] + + +@functools.lru_cache(maxsize=None) +def example_demes(): + return [deme for graph in example_graphs() for deme in graph.demes] diff --git a/tests/test_demes.py b/tests/test_demes.py index 486f8a84..fd7b6010 100644 --- a/tests/test_demes.py +++ b/tests/test_demes.py @@ -23,6 +23,7 @@ ) import demes import demes.hypothesis_strategies +import tests class TestEpoch(unittest.TestCase): @@ -1750,6 +1751,67 @@ def test_isclose(self): ) +class TestDemeSizeAt: + @pytest.mark.parametrize("deme", tests.example_demes()) + def test_deme_start_and_end_times(self, deme): + N = deme.size_at(deme.start_time) + assert N == deme.epochs[0].start_size + N = deme.size_at(deme.end_time) + assert N == deme.epochs[-1].end_size + + @pytest.mark.parametrize("deme", tests.example_demes()) + def test_times_within_each_epoch(self, deme): + for epoch in deme.epochs: + if math.isinf(epoch.start_time): + # The deme has the same size from end_time back to infinity. + for t in [epoch.end_time, epoch.end_time + 100, math.inf]: + N = deme.size_at(t) + assert N == epoch.start_size + else: + # Recalling that an epoch spans over the open-closed interval + # (start_time, end_time], we test several times in this range. + dt = epoch.start_time - epoch.end_time + r = math.log(epoch.end_size / epoch.start_size) + for p in [0, 1e-6, 1 / 3, 0.1, 1 - 1e-6]: + t = epoch.end_time + p * dt + N = deme.size_at(t) + if epoch.size_function == "constant": + assert N == epoch.start_size + elif epoch.size_function == "exponential": + expected_N = epoch.start_size * math.exp(r * (1 - p)) + assert math.isclose(N, expected_N) + elif epoch.size_function == "linear": + expected_N = epoch.start_size + ( + epoch.end_size - epoch.start_size + ) * (1 - p) + assert math.isclose(N, expected_N) + else: + raise AssertionError( + f"No tests for size_function '{epoch.size_function}'" + ) + + def test_bad_time(self): + b = demes.Builder() + b.add_deme("A", epochs=[dict(start_size=1, end_time=100)]) + b.add_deme("B", ancestors=["A"], epochs=[dict(start_size=1)]) + graph = b.resolve() + with pytest.raises(ValueError, match="existence interval"): + graph["A"].size_at(10) + with pytest.raises(ValueError, match="existence interval"): + graph["B"].size_at(200) + + def test_unknown_size_function(self): + b = demes.Builder() + b.add_deme( + "A", + epochs=[dict(start_size=1, end_time=100), dict(start_size=1, end_size=50)], + ) + graph = b.resolve() + graph["A"].epochs[-1].size_function = "foo" + with pytest.raises(NotImplementedError, match="size_function"): + graph["A"].size_at(10) + + class TestGraph(unittest.TestCase): def test_bad_generation_time(self): for generation_time in ([], {}, "42", "inf", math.nan):