Skip to content

Commit

Permalink
Merge pull request #362 from grahamgower/pulse-misc
Browse files Browse the repository at this point in the history
fix some pulse edge cases
  • Loading branch information
grahamgower authored Jul 26, 2021
2 parents 0c24b41 + c462416 commit 586f073
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 46 deletions.
31 changes: 10 additions & 21 deletions demes/demes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1230,7 +1230,7 @@ class Graph:
# because we're using slotted classes and can't add attributes after
# object creation (e.g. in __attrs_post_init__()).
_deme_map: Dict[Name, Deme] = attr.ib(
factory=dict, init=False, repr=False, cmp=False
factory=dict, init=False, repr=False, eq=False, order=False
)

def __attrs_post_init__(self):
Expand Down Expand Up @@ -1289,7 +1289,6 @@ def assert_close(
- The graphs' ``description`` and ``doi`` attributes.
- The order in which ``migrations`` were specified.
- The order in which admixture ``pulses`` were specified.
- The order in which ``demes`` were specified.
- The order in which a deme's ``ancestors`` were specified.
Expand Down Expand Up @@ -1329,13 +1328,12 @@ def assert_sorted_eq(aa, bb, *, rel_tol, abs_tol, name):
abs_tol=abs_tol,
name="migrations",
)
assert_sorted_eq(
self.pulses,
other.pulses,
rel_tol=rel_tol,
abs_tol=abs_tol,
name="pulses",
)
assert len(self.pulses) == len(other.pulses)
for i, (self_pulse, other_pulse) in enumerate(zip(self.pulses, other.pulses)):
try:
self_pulse.assert_close(other_pulse, rel_tol=rel_tol, abs_tol=abs_tol)
except AssertionError as e:
raise AssertionError(f"Failed for pulses (number {i})") from e

def isclose(
self,
Expand All @@ -1353,7 +1351,6 @@ def isclose(
- The graphs' ``description`` and ``doi`` attributes.
- The order in which ``migrations`` were specified.
- The order in which admixture ``pulses`` were specified.
- The order in which ``demes`` were specified.
- The order in which a deme's ``ancestors`` were specified.
Expand Down Expand Up @@ -1617,17 +1614,6 @@ def _add_pulse(self, *, source, dest, proportion, time) -> Pulse:
"the desired ancestry proportions."
)

# Check for multiple pulses into dest at the same time that
# give a sum of proportions > 1.
proportion_sum = proportion
for pulse in self.pulses:
if dest == pulse.dest and pulse.time == time:
proportion_sum += pulse.proportion
if proportion_sum > 1:
raise ValueError(
f"sum of pulse proportions > 1 for dest={dest} at time={time}"
)

self.pulses.append(new_pulse)
return new_pulse

Expand Down Expand Up @@ -2093,6 +2079,9 @@ def fromdict(cls, data: MutableMapping[str, Any]) -> "Graph":
except (TypeError, ValueError) as e:
raise e.__class__(f"pulse[{i}]: invalid pulse") from e

# Sort pulses from oldest to youngest.
graph.pulses.sort(key=lambda pulse: pulse.time, reverse=True)

return graph

def asdict(self) -> MutableMapping[str, Any]:
Expand Down
113 changes: 89 additions & 24 deletions tests/test_demes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2041,6 +2041,24 @@ def test_isclose(self):
g2 = b2.resolve()
assert not g1.isclose(g2)

@pytest.mark.filterwarnings("ignore:Multiple pulses.*same.*time")
def test_isclose_pulse_ordering(self):
b1 = Builder(defaults=dict(epoch=dict(start_size=1)))
b1.add_deme("a")
b1.add_deme("b")
b1.add_deme("c")

# Order of pulses matters for simultaneous pulses.
b2 = copy.deepcopy(b1)
b2.add_pulse(source="a", dest="b", time=100, proportion=0.1)
b2.add_pulse(source="b", dest="c", time=100, proportion=0.1)
g2 = b2.resolve()
b3 = copy.deepcopy(b1)
b3.add_pulse(source="b", dest="c", time=100, proportion=0.1)
b3.add_pulse(source="a", dest="b", time=100, proportion=0.1)
g3 = b3.resolve()
assert not g2.isclose(g3)

def test_successors_predecessors(self):
# single population
b = Builder()
Expand Down Expand Up @@ -3002,7 +3020,7 @@ def test_bad_pulse_time(self):
with pytest.raises(ValueError):
b2.resolve()

def test_pulse_same_time(self):
def test_simultaneous_pulses_warning(self):
b1 = Builder()
for j in range(4):
b1.add_deme(f"d{j}", epochs=[dict(start_size=1000)])
Expand All @@ -3013,65 +3031,112 @@ def test_pulse_same_time(self):
b2 = copy.deepcopy(b1)
b2.add_pulse(source="d0", dest="d1", time=T, proportion=0.1)
b2.add_pulse(source="d0", dest="d1", time=T, proportion=0.1)
with pytest.warns(UserWarning):
with pytest.warns(UserWarning, match="Multiple pulses.*same.*time"):
b2.resolve()

# Warn for: d0 -> d1; d1 -> d2.
b2 = copy.deepcopy(b1)
b2.add_pulse(source="d0", dest="d1", time=T, proportion=0.1)
b2.add_pulse(source="d1", dest="d2", time=T, proportion=0.1)
with pytest.warns(UserWarning):
with pytest.warns(UserWarning, match="Multiple pulses.*same.*time"):
b2.resolve()

# Warn for: d0 -> d2; d1 -> d2.
b2 = copy.deepcopy(b1)
b2.add_pulse(source="d0", dest="d2", time=T, proportion=0.1)
b2.add_pulse(source="d1", dest="d2", time=T, proportion=0.1)
with pytest.warns(UserWarning):
with pytest.warns(UserWarning, match="Multiple pulses.*same.*time"):
b2.resolve()

@pytest.mark.filterwarnings("error:Multiple pulses.*same.*time")
def test_unrelated_pulses_no_warning(self):
b1 = Builder()
for j in range(4):
b1.add_deme(f"d{j}", epochs=[dict(start_size=1000)])

T = 100 # time of pulses

# Shouldn't warn for: d0 -> d1; d0 -> d2.
b2 = copy.deepcopy(b1)
b2.add_pulse(source="d0", dest="d1", time=T, proportion=0.1)
b2.add_pulse(source="d0", dest="d2", time=T, proportion=0.1)
with pytest.warns(None) as record:
b2.resolve()
assert len(record) == 0
b2.resolve()

# Shouldn't warn for: d0 -> d1; d2 -> d3.
b2 = copy.deepcopy(b1)
b2.add_pulse(source="d0", dest="d1", time=T, proportion=0.1)
b2.add_pulse(source="d2", dest="d3", time=T, proportion=0.1)
with pytest.warns(None) as record:
b2.resolve()
assert len(record) == 0
b2.resolve()

# Different pulse times shouldn't warn for: d0 -> d1; d1 -> d2.
b2 = copy.deepcopy(b1)
b2.add_pulse(source="d0", dest="d1", time=T, proportion=0.1)
b2.add_pulse(source="d1", dest="d2", time=2 * T, proportion=0.1)
with pytest.warns(None) as record:
b2.resolve()
assert len(record) == 0
b2.resolve()

# Different pulse times shouldn't warn for: d0 -> d2; d1 -> d2.
b2 = copy.deepcopy(b1)
b2.add_pulse(source="d0", dest="d2", time=T, proportion=0.1)
b2.add_pulse(source="d1", dest="d2", time=2 * T, proportion=0.1)
with pytest.warns(None) as record:
b2.resolve()
assert len(record) == 0
b2.resolve()

def test_pulse_proportions_sum(self):
@pytest.mark.filterwarnings("ignore:Multiple pulses.*same.*time")
def test_pulse_proportions_sum_greater_than_one(self):
b = Builder(defaults=dict(epoch=dict(start_size=1)))
b.add_deme("a")
b.add_deme("b")
b.add_deme("c")
b.add_pulse(source="b", dest="a", time=100, proportion=0.6)
b.add_pulse(source="c", dest="a", time=100, proportion=0.6)
with pytest.warns(UserWarning):
with pytest.raises(ValueError):
b.resolve()
b.resolve()

@pytest.mark.filterwarnings("ignore:Multiple pulses.*same.*time")
def test_pulse_order(self):
b = Builder(defaults=dict(epoch=dict(start_size=1)))
b.add_deme("a")
b.add_deme("b")
b.add_deme("c")

# Pulses defined in oldest-to-youngest order have order maintained.
b.data["pulses"] = []
b.add_pulse(source="a", dest="b", time=200, proportion=0.1)
b.add_pulse(source="b", dest="c", time=100, proportion=0.1)
g = b.resolve()
assert g.pulses[0].source == "a"
assert g.pulses[0].dest == "b"
assert g.pulses[1].source == "b"
assert g.pulses[1].dest == "c"

# Pulses defined out of order will be sorted oldest-to-youngest.
b.data["pulses"] = []
b.add_pulse(source="b", dest="c", time=100, proportion=0.1)
b.add_pulse(source="a", dest="b", time=200, proportion=0.1)
g = b.resolve()
assert g.pulses[0].source == "a"
assert g.pulses[0].dest == "b"
assert g.pulses[1].source == "b"
assert g.pulses[1].dest == "c"

# Simultaneous pulses should be ordered as they were defined.
b.data["pulses"] = []
b.add_pulse(source="a", dest="b", time=100, proportion=0.1)
b.add_pulse(source="b", dest="c", time=100, proportion=0.1)
g = b.resolve()
assert g.pulses[0].source == "a"
assert g.pulses[0].dest == "b"
assert g.pulses[1].source == "b"
assert g.pulses[1].dest == "c"

# Reverse the order of simultaneous pulses, to check it's no accident.
# The order should still match the definitions.
b.data["pulses"] = []
b.add_pulse(source="b", dest="c", time=100, proportion=0.1)
b.add_pulse(source="a", dest="b", time=100, proportion=0.1)
g = b.resolve()
assert g.pulses[0].source == "b"
assert g.pulses[0].dest == "c"
assert g.pulses[1].source == "a"
assert g.pulses[1].dest == "b"

def test_toplevel_defaults_deme(self):
# description
Expand Down Expand Up @@ -3247,14 +3312,15 @@ def test_toplevel_defaults_migration(self):
assert g.migrations[8].source == "x"
assert g.migrations[8].dest == "y"

@pytest.mark.filterwarnings("ignore:Multiple pulses.*same.*time")
def test_toplevel_defaults_pulse(self):
# source
b = Builder(defaults=dict(pulse=dict(source="a")))
for name in "abcd":
b.add_deme(name, epochs=[dict(start_size=1)])
for name in "bcd":
b.add_pulse(dest=name, proportion=0.1, time=100)
b.add_pulse(source="d", dest="a", proportion=0.2, time=200)
b.add_pulse(source="d", dest="a", proportion=0.2, time=50)
g = b.resolve()
assert g.pulses[0].source == g.pulses[1].source == g.pulses[2].source == "a"
assert g.pulses[3].source == "d"
Expand All @@ -3265,9 +3331,8 @@ def test_toplevel_defaults_pulse(self):
b.add_deme(name, epochs=[dict(start_size=1)])
for name in "bcd":
b.add_pulse(source=name, proportion=0.1, time=100)
b.add_pulse(dest="d", source="a", proportion=0.2, time=200)
with pytest.warns(UserWarning):
g = b.resolve()
b.add_pulse(dest="d", source="a", proportion=0.2, time=50)
g = b.resolve()
assert g.pulses[0].dest == g.pulses[1].dest == g.pulses[2].dest == "a"
assert g.pulses[3].dest == "d"

Expand Down
1 change: 0 additions & 1 deletion tests/test_ms.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ def N_ref_macs(*, theta, mu):
class TestFromMs:
def test_ignored_options_have_no_effect(self):
def check(command, N0=1):
# with pytest.warns(UserWarning, match="Ignoring unknown args"):
graph = demes.from_ms(command, N0=N0)
assert len(graph.pulses) == 0
assert len(graph.migrations) == 0
Expand Down

0 comments on commit 586f073

Please sign in to comment.