Skip to content

Commit

Permalink
Add Deme.size_at() to compute the deme size at a given time.
Browse files Browse the repository at this point in the history
Closes #312.
  • Loading branch information
grahamgower committed Jun 4, 2021
1 parent 083dcea commit a9db443
Show file tree
Hide file tree
Showing 5 changed files with 134 additions and 1 deletion.
30 changes: 30 additions & 0 deletions demes/demes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1143,6 +1143,36 @@ def time_span(self):
"""
return self.start_time - self.end_time

def size_at(self, time: float) -> float:
"""
Get the size of the deme at a given time.
:param float time: The time at which the size should be calculated.
:return: The deme size.
:rtype: float
"""
for epoch in self.epochs:
if epoch.start_time >= time >= epoch.end_time:
break
else:
raise ValueError(
f"time {time} is outside deme {self.name}'s existence interval: "
f"({self.start_time}, {self.end_time}]"
)

if math.isclose(time, epoch.end_time) or epoch.size_function == "constant":
N = epoch.end_size
elif epoch.size_function == "exponential":
dt = (epoch.start_time - time) / epoch.time_span
r = math.log(epoch.end_size / epoch.start_size)
N = epoch.start_size * math.exp(r * dt)
elif epoch.size_function == "linear":
dt = (epoch.start_time - time) / epoch.time_span
N = epoch.start_size + (epoch.end_size - epoch.start_size) * dt
else:
raise NotImplementedError(f"unknown size_function '{epoch.size_function}'")
return N


@attr.s(auto_attribs=True, kw_only=True, slots=True)
class Graph:
Expand Down
16 changes: 15 additions & 1 deletion demes/hypothesis_strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,13 +89,16 @@ def epochs_lists(
max_epochs=5,
min_deme_size=FLOAT_EPS,
max_deme_size=FLOAT_MAX,
size_functions=None,
):
"""
A hypothesis strategy for creating lists of Epochs for a deme.
:param float start_time: The start time of the deme.
:param int max_epochs: The maximum number of epochs in the list.
"""
if size_functions is None:
size_functions = ["constant", "exponential", "linear"]
assert max_epochs >= 2
times = draw(
st.lists(
Expand All @@ -117,8 +120,15 @@ def epochs_lists(
start_size = draw(st.floats(min_value=min_deme_size, max_value=max_deme_size))
if i == 0 and math.isinf(start_time):
end_size = start_size
size_function = "constant"
else:
end_size = draw(st.floats(min_value=min_deme_size, max_value=max_deme_size))
size_function = draw(st.sampled_from(size_functions))
if size_function == "constant":
end_size = start_size
else:
end_size = draw(
st.floats(min_value=min_deme_size, max_value=max_deme_size)
)
cloning_rate = draw(st.floats(min_value=0, max_value=1))
selfing_rate = draw(st.floats(min_value=0, max_value=prec32(1 - cloning_rate)))

Expand All @@ -127,6 +137,7 @@ def epochs_lists(
end_time=end_time,
start_size=start_size,
end_size=end_size,
size_function=size_function,
cloning_rate=cloning_rate,
selfing_rate=selfing_rate,
)
Expand Down Expand Up @@ -319,6 +330,7 @@ def graphs(
max_pulses=10,
min_deme_size=FLOAT_EPS,
max_deme_size=FLOAT_MAX,
size_functions=None,
):
"""
A hypothesis strategy for creating a Graph.
Expand All @@ -336,6 +348,7 @@ def test_something(self, graph: demes.Graph):
:param int max_pulses: The maximum number of pulses in the graph.
:param float min_deme_size: The minimum size of a deme in any epoch.
:param float max_deme_size: The maximum size of a deme in any epoch.
:param list size_functions: Allowable values for an epoch's size_function.
"""
generation_time = draw(
st.none() | st.floats(min_value=FLOAT_EPS, max_value=FLOAT_MAX)
Expand Down Expand Up @@ -418,6 +431,7 @@ def test_something(self, graph: demes.Graph):
max_epochs=max_epochs,
min_deme_size=min_deme_size,
max_deme_size=max_deme_size,
size_functions=size_functions,
)
),
start_time=start_time,
Expand Down
3 changes: 3 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -54,3 +54,6 @@ ignore_missing_imports = True
[tool.black]
target_version = py36

[tool:pytest]
addopts = -n auto
testpaths = tests
24 changes: 24 additions & 0 deletions tests/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import pathlib
import functools

import demes


@functools.lru_cache(maxsize=None)
def example_files():
cwd = pathlib.Path(__file__).parent.resolve()
example_dir = cwd / "../examples"
files = list(example_dir.glob("**/*.yaml"))
files += list(example_dir.glob("**/*.yml"))
assert len(files) > 1
return files


@functools.lru_cache(maxsize=None)
def example_graphs():
return [demes.load(fn) for fn in example_files()]


@functools.lru_cache(maxsize=None)
def example_demes():
return [deme for graph in example_graphs() for deme in graph.demes]
62 changes: 62 additions & 0 deletions tests/test_demes.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
)
import demes
import demes.hypothesis_strategies
import tests


class TestEpoch(unittest.TestCase):
Expand Down Expand Up @@ -1750,6 +1751,67 @@ def test_isclose(self):
)


class TestDemeSizeAt:
@pytest.mark.parametrize("deme", tests.example_demes())
def test_deme_start_and_end_times(self, deme):
N = deme.size_at(deme.start_time)
assert N == deme.epochs[0].start_size
N = deme.size_at(deme.end_time)
assert N == deme.epochs[-1].end_size

@pytest.mark.parametrize("deme", tests.example_demes())
def test_times_within_each_epoch(self, deme):
for epoch in deme.epochs:
if math.isinf(epoch.start_time):
# The deme has the same size from end_time back to infinity.
for t in [epoch.end_time, epoch.end_time + 100, math.inf]:
N = deme.size_at(t)
assert N == epoch.start_size
else:
# Recalling that an epoch spans over the open-closed interval
# (start_time, end_time], we test several times in this range.
dt = epoch.start_time - epoch.end_time
r = math.log(epoch.end_size / epoch.start_size)
for p in [0, 1e-6, 1 / 3, 0.1, 1 - 1e-6]:
t = epoch.end_time + p * dt
N = deme.size_at(t)
if epoch.size_function == "constant":
assert N == epoch.start_size
elif epoch.size_function == "exponential":
expected_N = epoch.start_size * math.exp(r * (1 - p))
assert math.isclose(N, expected_N)
elif epoch.size_function == "linear":
expected_N = epoch.start_size + (
epoch.end_size - epoch.start_size
) * (1 - p)
assert math.isclose(N, expected_N)
else:
raise AssertionError(
f"No tests for size_function '{epoch.size_function}'"
)

def test_bad_time(self):
b = demes.Builder()
b.add_deme("A", epochs=[dict(start_size=1, end_time=100)])
b.add_deme("B", ancestors=["A"], epochs=[dict(start_size=1)])
graph = b.resolve()
with pytest.raises(ValueError, match="existence interval"):
graph["A"].size_at(10)
with pytest.raises(ValueError, match="existence interval"):
graph["B"].size_at(200)

def test_unknown_size_function(self):
b = demes.Builder()
b.add_deme(
"A",
epochs=[dict(start_size=1, end_time=100), dict(start_size=1, end_size=50)],
)
graph = b.resolve()
graph["A"].epochs[-1].size_function = "foo"
with pytest.raises(NotImplementedError, match="size_function"):
graph["A"].size_at(10)


class TestGraph(unittest.TestCase):
def test_bad_generation_time(self):
for generation_time in ([], {}, "42", "inf", math.nan):
Expand Down

0 comments on commit a9db443

Please sign in to comment.