From 6b49ef9bea2acfdf10758930abe2b3ffe9e612f0 Mon Sep 17 00:00:00 2001 From: Aaron Ragsdale Date: Thu, 8 Jul 2021 14:45:49 -0500 Subject: [PATCH] Allow simultaneous pulses using multiple sources --- .mergify.yml | 2 +- demes/demes.py | 136 +++++++++++------ demes/hypothesis_strategies.py | 4 +- demes/ms.py | 14 +- examples/jacobs_papuans.yml | 12 +- examples/offshoots.yml | 4 +- tests/test_demes.py | 265 +++++++++++++++++++++------------ tests/test_load_dump.py | 24 +-- tests/test_ms.py | 52 +++---- 9 files changed, 314 insertions(+), 199 deletions(-) diff --git a/.mergify.yml b/.mergify.yml index e7fba0e3..33633776 100644 --- a/.mergify.yml +++ b/.mergify.yml @@ -51,4 +51,4 @@ pull_request_rules: - check-success=tests (windows-2019, 3.9) actions: queue: - name: default \ No newline at end of file + name: default diff --git a/demes/demes.py b/demes/demes.py index b5368c9e..1e1a85e5 100644 --- a/demes/demes.py +++ b/demes/demes.py @@ -409,26 +409,47 @@ class Pulse: Parameters for a pulse of migration from one deme to another. Source and destination demes follow the forwards-in-time convention, of migrations born in the source deme having children in the dest - deme. + deme. If more than one source is given, migration is concurrent, so that + the sum of the migrant proportions must sum to less than one. - :ivar str source: The source deme. + :ivar list(str) sources: The source deme(s). :ivar str dest: The destination deme. :ivar float ~.time: The time of migration. - :ivar float proportion: At the instant after migration, this is the - proportion of individuals in the destination deme made up of - individuals from the source deme. + :ivar list(float) proportions: Immediately following migration, the proportion(s) + of individuals in the destination deme made up of migrant individuals or + having parents from the source deme(s). """ - source: Name = attr.ib( - validator=[attr.validators.instance_of(str), valid_deme_name] + sources: List[Name] = attr.ib( + validator=attr.validators.and_( + attr.validators.deep_iterable( + member_validator=attr.validators.and_( + attr.validators.instance_of(str), valid_deme_name + ), + iterable_validator=attr.validators.instance_of(list), + ), + nonzero_len, + ) ) dest: Name = attr.ib(validator=[attr.validators.instance_of(str), valid_deme_name]) time: Time = attr.ib(validator=[int_or_float, positive, finite]) - proportion: Proportion = attr.ib(validator=[int_or_float, unit_interval]) + proportions: List[Proportion] = attr.ib( + validator=attr.validators.deep_iterable( + member_validator=attr.validators.and_(int_or_float, unit_interval), + iterable_validator=attr.validators.instance_of(list), + ) + ) def __attrs_post_init__(self): - if self.source == self.dest: - raise ValueError("source and dest cannot be the same deme") + for source in self.sources: + if source == self.dest: + raise ValueError(f"source ({source}) cannot be the same as dest") + if self.sources.count(source) != 1: + raise ValueError(f"source ({source}) cannot be repeated in sources") + if len(self.sources) != len(self.proportions): + raise ValueError("sources and proportions must have the same length") + if sum(self.proportions) > 1: + raise ValueError("proportions must sum to less than one") def assert_close( self, @@ -455,14 +476,28 @@ def assert_close( assert ( self.__class__ is other.__class__ ), f"Failed as other pulse is not instance of {self.__class__} type." - assert self.source == other.source + assert len(self.sources) == len(other.sources) + for s in self.sources: + assert s in other.sources + for s in other.sources: + assert s in self.sources assert self.dest == other.dest assert math.isclose( self.time, other.time, rel_tol=rel_tol, abs_tol=abs_tol ), f"Failed for time {self.time} != {other.time} (other)." + assert len(self.proportions) == len(other.proportions) assert math.isclose( - self.proportion, other.proportion, rel_tol=rel_tol, abs_tol=abs_tol - ), f"Failed for proportion {self.proportion} != {other.proportion} (other)." + sum(self.proportions), + sum(other.proportions), + rel_tol=rel_tol, + abs_tol=abs_tol, + ), ( + f"Failed for unequal proportions sums: " + f"sum({self.proportions}) != sum({other.proportions}) (other)." + ) + assert isclose_deme_proportions( + self.sources, self.proportions, other.sources, other.proportions + ) def isclose( self, @@ -1556,40 +1591,43 @@ def _add_asymmetric_migration( self.migrations.append(migration) return migration - def _add_pulse(self, *, source, dest, proportion, time) -> Pulse: + def _add_pulse(self, *, sources, dest, proportions, time) -> Pulse: """ Add a pulse of migration at a fixed time. Source and destination demes follow the forwards-in-time convention. - :param str source: The name of the source deme. + :param list(str) sources: The name(s) of the source deme(s). :param str dest: The name of the destination deme. - :param float proportion: At the instant after migration, this is the expected - proportion of individuals in the destination deme made up of individuals - from the source deme. + :param list(float) proportion(s): Immediately following migration, the + proportion(s) of individuals in the destination deme made up of + migrant individuals or having parents from the source deme(s). :param float time: The time at which migrations occur. :return: Newly created pulse. :rtype: :class:`.Pulse` """ - for deme_name in (source, dest): + for deme_name in sources + [dest]: if deme_name not in self: raise ValueError(f"{deme_name} not in graph") - self._check_time_intersection(source, dest, time) + for source in sources: + self._check_time_intersection(source, dest, time) if time == self[dest].end_time: raise ValueError( f"invalid pulse at time={time}, which is dest={dest}'s end_time" ) - if time == self[source].start_time: - raise ValueError( - f"invalid pulse at time={time}, which is source={source}'s start_time" - ) + for source in sources: + if time == self[source].start_time: + raise ValueError( + f"invalid pulse at time={time}, " + f"which is source={source}'s start_time" + ) # We create the new pulse object (which checks for common errors) # before checking for edge cases below. new_pulse = Pulse( - source=source, + sources=sources, dest=dest, time=time, - proportion=proportion, + proportions=proportions, ) # Check for models that have multiple pulses defined at the same time. @@ -1598,20 +1636,23 @@ def _add_pulse(self, *, source, dest, proportion, time) -> Pulse: # interpretation of the model. Such models are valid, but the behaviour # may not be what the user expects. # See https://github.com/popsim-consortium/demes-python/issues/46 - sources = set() - dests = set() + all_sources = set() + all_dests = set() for pulse in self.pulses: if pulse.time == time: - sources.add(pulse.source) - dests.add(pulse.dest) - if source in dests or dest in (sources | dests): + all_sources.update(pulse.sources) + all_dests.add(pulse.dest) + if any(source in all_dests for source in sources) or dest in ( + all_sources | all_dests + ): warnings.warn( "Multiple pulses are defined for the same deme(s) at time " f"{time}. The ancestry proportions after this time will thus " "depend on the order in which the pulses have been specified. " "To avoid unexpected behaviour, the graph can instead " "be structured to introduce a new deme at this time with " - "the desired ancestry proportions." + "the desired ancestry proportions or to specify concurrent " + "pulses with multiple sources." ) self.pulses.append(new_pulse) @@ -1916,7 +1957,7 @@ def fromdict(cls, data: MutableMapping[str, Any]) -> "Graph": ) pulse_defaults = pop_object(defaults, "pulse", {}, scope="defaults") - allowed_fields_pulse = ["source", "dest", "time", "proportion"] + allowed_fields_pulse = ["sources", "dest", "time", "proportions"] check_allowed(pulse_defaults, allowed_fields_pulse, "defaults.pulse") # epoch defaults may also be specified within a Deme definition. @@ -2066,15 +2107,15 @@ def fromdict(cls, data: MutableMapping[str, Any]) -> "Graph": ): check_allowed(pulse_data, allowed_fields_pulse, f"pulse[{i}]") insert_defaults(pulse_data, pulse_defaults) - for field in ("source", "dest", "time", "proportion"): + for field in ("sources", "dest", "time", "proportions"): if field not in pulse_data: raise KeyError(f"pulse[{i}]: required field '{field}' not found") try: graph._add_pulse( - source=pulse_data.pop("source"), + sources=pulse_data.pop("sources"), dest=pulse_data.pop("dest"), time=pulse_data.pop("time"), - proportion=pulse_data.pop("proportion"), + proportions=pulse_data.pop("proportions"), ) except (TypeError, ValueError) as e: raise e.__class__(f"pulse[{i}]: invalid pulse") from e @@ -2397,29 +2438,29 @@ def add_migration( def add_pulse( self, *, - source: str = None, + sources: List[str] = None, dest: str = None, - proportion: float = None, + proportions: List[float] = None, time: float = None, ): """ Add a pulse of migration at a fixed time. Source and destination demes follow the forwards-in-time convention. - :param str source: The name of the source deme. + :param list(str) source: The name of the source deme(s). :param str dest: The name of the destination deme. - :param float proportion: At the instant after migration, this is the - expected proportion of individuals in the destination deme made up - of individuals from the source deme. + :param list(float) proportion: At the instant after migration, this is the + expected proportion(s) of individuals in the destination deme made up + of individuals from the source deme(s). :param float time: The time at which migrations occur. """ pulse: MutableMapping[str, Any] = dict() - if source is not None: - pulse["source"] = source + if sources is not None: + pulse["sources"] = sources if dest is not None: pulse["dest"] = dest - if proportion is not None: - pulse["proportion"] = proportion + if proportions is not None: + pulse["proportions"] = proportions if time is not None: pulse["time"] = time @@ -2536,7 +2577,8 @@ def _remove_transient_demes(self) -> None: continue if start_time == end_time: for pulse in self.data.get("pulses", []): - assert pulse["source"] != deme["name"] + for s in pulse["sources"]: + assert s != deme["name"] assert pulse["dest"] != deme["name"] for migration in self.data.get("migrations", []): assert deme["name"] not in migration.get("demes", []) diff --git a/demes/hypothesis_strategies.py b/demes/hypothesis_strategies.py index 718d3e42..d03853a0 100644 --- a/demes/hypothesis_strategies.py +++ b/demes/hypothesis_strategies.py @@ -315,10 +315,10 @@ def pulses_lists(draw, graph, max_pulses=10): ) ingress_proportions[(dest, time)] += proportion pulse = dict( - source=source, + sources=[source], dest=dest, time=time, - proportion=proportion, + proportions=[proportion], ) pulses.append(pulse) n_pulses -= 1 diff --git a/demes/ms.py b/demes/ms.py index 727fa73c..0fa60000 100644 --- a/demes/ms.py +++ b/demes/ms.py @@ -778,10 +778,10 @@ def migration_matrix_at(time): # The order of pulses will later be reversed such that realised # ancestry proportions are maintained forwards in time. b.add_pulse( - source=f"deme{k + 1}", + sources=[f"deme{k + 1}"], dest=f"deme{j + 1}", time=time, - proportion=p, + proportions=[p], ) # Resolve/remove growth_rate in oldest epochs. @@ -830,7 +830,7 @@ def remap_deme_names(graph: demes.Graph, names: Mapping[str, str]) -> demes.Grap migration.source = names[migration.source] migration.dest = names[migration.dest] for pulse in graph.pulses: - pulse.source = names[pulse.source] + pulse.sources = [names[s] for s in pulse.sources] pulse.dest = names[pulse.dest] for k, deme in list(graph._deme_map.items()): del graph._deme_map[k] @@ -989,8 +989,12 @@ def get_growth_rate(epoch): pulse = deme_or_pulse num_demes += 1 new_deme_id = num_demes - e1 = Split(pulse.time, deme_id[pulse.dest], 1 - pulse.proportion) - e2 = Join(pulse.time, new_deme_id, deme_id[pulse.source]) + if len(pulse.sources) > 1: + raise ValueError( + "Currently pulses will only a single source are supported" + ) + e1 = Split(pulse.time, deme_id[pulse.dest], 1 - pulse.proportions[0]) + e2 = Join(pulse.time, new_deme_id, deme_id[pulse.sources[0]]) events.extend([e1, e2]) # Turn migrations off at the start_time. We schedule all start_time diff --git a/examples/jacobs_papuans.yml b/examples/jacobs_papuans.yml index c2d76363..512945f9 100644 --- a/examples/jacobs_papuans.yml +++ b/examples/jacobs_papuans.yml @@ -79,9 +79,9 @@ migrations: - {demes: [CEU, Ghost], rate: 0.000442} pulses: -- {source: Nea1, dest: Ghost, time: 1853.0, proportion: 0.024} -- {source: Den2, dest: Papuan, time: 1575.8620689655172, proportion: 0.018} -- {source: Nea1, dest: CHB, time: 1566.0, proportion: 0.011} -- {source: Nea1, dest: Papuan, time: 1412.0, proportion: 0.002} -- {source: Den1, dest: Papuan, time: 1027.5862068965516, proportion: 0.022} -- {source: Nea1, dest: CHB, time: 883.0, proportion: 0.002} +- {sources: [Nea1], dest: Ghost, time: 1853.0, proportions: [0.024]} +- {sources: [Den2], dest: Papuan, time: 1575.8620689655172, proportions: [0.018]} +- {sources: [Nea1], dest: CHB, time: 1566.0, proportions: [0.011]} +- {sources: [Nea1], dest: Papuan, time: 1412.0, proportions: [0.002]} +- {sources: [Den1], dest: Papuan, time: 1027.5862068965516, proportions: [0.022]} +- {sources: [Nea1], dest: CHB, time: 883.0, proportions: [0.002]} diff --git a/examples/offshoots.yml b/examples/offshoots.yml index d3be5341..35f07887 100644 --- a/examples/offshoots.yml +++ b/examples/offshoots.yml @@ -33,7 +33,7 @@ migrations: - demes: [offshoot1, offshoot2] rate: 2e-5 pulses: - - source: offshoot1 + - sources: [offshoot1] dest: ancestral - proportion: 0.1 + proportions: [0.1] time: 50 diff --git a/tests/test_demes.py b/tests/test_demes.py index 189bdb48..cef65b04 100644 --- a/tests/test_demes.py +++ b/tests/test_demes.py @@ -557,51 +557,90 @@ class TestPulse: def test_bad_time(self): for time in ("inf", "100", {}, [], math.nan): with pytest.raises(TypeError): - Pulse(source="a", dest="b", time=time, proportion=0.1) + Pulse(sources=["a"], dest="b", time=time, proportions=[0.1]) for time in (-10000, -1, -1e-9, 0, math.inf): with pytest.raises(ValueError): - Pulse(source="a", dest="b", time=time, proportion=0.1) + Pulse(sources=["a"], dest="b", time=time, proportions=[0.1]) - def test_bad_proportion(self): - for proportion in ("inf", "100", {}, [], math.nan): + def test_bad_proportions(self): + for proportion in ("inf", "100", {}, math.nan): with pytest.raises(TypeError): - Pulse(source="a", dest="b", time=1, proportion=proportion) + Pulse(sources=["a"], dest="b", time=1, proportions=[proportion]) for proportion in (-10000, -1, -1e-9, 1.2, 100, math.inf): with pytest.raises(ValueError): - Pulse(source="a", dest="b", time=1, proportion=proportion) + Pulse(sources=["a"], dest="b", time=1, proportions=[proportion]) def test_bad_demes(self): for name in (None, 0, math.inf, 1e3, {}, []): with pytest.raises(TypeError): - Pulse(source=name, dest="a", time=1, proportion=0.1) + Pulse(sources=[name], dest="a", time=1, proportions=[0.1]) with pytest.raises(TypeError): - Pulse(source="a", dest=name, time=1, proportion=0.1) + Pulse(sources=["a"], dest=name, time=1, proportions=[0.1]) for name in ("a", "", "pop 1"): with pytest.raises(ValueError): - Pulse(source=name, dest="a", time=1, proportion=0.1) + Pulse(sources=[name], dest="a", time=1, proportions=[0.1]) with pytest.raises(ValueError): - Pulse(source="a", dest=name, time=1, proportion=0.1) + Pulse(sources=["a"], dest=name, time=1, proportions=[0.1]) + + def test_bad_multipopulation_pulse(self): + with pytest.raises(ValueError): + Pulse(sources=["a"], dest="b", time=1, proportions=[0.4, 0.5]) + with pytest.raises(ValueError): + Pulse(sources=["a", "b"], dest="c", time=1, proportions=[0.4]) + with pytest.raises(ValueError): + Pulse(sources=["a", "b"], dest="c", time=1, proportions=[0.6, 0.7]) + with pytest.raises(ValueError): + Pulse( + sources=["a", "b", "c"], dest="b", time=1, proportions=[0.5, 0.1, 0.41] + ) def test_valid_pulse(self): - Pulse(source="a", dest="b", time=1, proportion=1e-9) - Pulse(source="a", dest="b", time=100, proportion=0.9) + Pulse(sources=["a"], dest="b", time=1, proportions=[1e-9]) + Pulse(sources=["a"], dest="b", time=100, proportions=[0.9]) + Pulse(sources=["a", "b", "c"], dest="d", time=1, proportions=[0.1, 0.2, 0.7]) def test_isclose(self): eps = 1e-50 - p1 = Pulse(source="a", dest="b", time=1, proportion=1e-9) + p1 = Pulse(sources=["a"], dest="b", time=1, proportions=[1e-9]) assert p1.isclose(p1) - assert p1.isclose(Pulse(source="a", dest="b", time=1, proportion=1e-9)) - assert p1.isclose(Pulse(source="a", dest="b", time=1 + eps, proportion=1e-9)) - assert p1.isclose(Pulse(source="a", dest="b", time=1, proportion=1e-9 + eps)) + assert p1.isclose(Pulse(sources=["a"], dest="b", time=1, proportions=[1e-9])) + assert p1.isclose( + Pulse(sources=["a"], dest="b", time=1 + eps, proportions=[1e-9]) + ) + assert p1.isclose( + Pulse(sources=["a"], dest="b", time=1, proportions=[1e-9 + eps]) + ) - assert not p1.isclose(Pulse(source="a", dest="c", time=1, proportion=1e-9)) - assert not p1.isclose(Pulse(source="b", dest="a", time=1, proportion=1e-9)) - assert not p1.isclose(Pulse(source="a", dest="b", time=1, proportion=2e-9)) assert not p1.isclose( - Pulse(source="a", dest="b", time=1 + 1e-9, proportion=1e-9) + Pulse(sources=["a"], dest="c", time=1, proportions=[1e-9]) + ) + assert not p1.isclose( + Pulse(sources=["b"], dest="a", time=1, proportions=[1e-9]) + ) + assert not p1.isclose( + Pulse(sources=["a"], dest="b", time=1, proportions=[2e-9]) + ) + assert not p1.isclose( + Pulse(sources=["a"], dest="b", time=1 + 1e-9, proportions=[1e-9]) + ) + + multipulse = Pulse(sources=["a", "b"], dest="c", time=1, proportions=[0.1, 0.2]) + assert multipulse.isclose( + Pulse(sources=["b", "a"], dest="c", time=1, proportions=[0.2, 0.1]) + ) + assert multipulse.isclose( + Pulse( + sources=["a", "b"], dest="c", time=1, proportions=[0.1 + eps, 0.2 + eps] + ) + ) + assert not multipulse.isclose( + Pulse(sources=["a"], dest="c", time=1, proportions=[0.1]) + ) + assert not multipulse.isclose( + Pulse(sources=["a", "b"], dest="c", time=1, proportions=[0.2, 0.2]) ) @@ -1904,10 +1943,10 @@ def test_isclose(self): g3.assert_close(g4) # The order in which pulses are added shouldn't matter. - b3.add_pulse(source="d1", dest="d2", proportion=0.01, time=100) - b3.add_pulse(source="d1", dest="d2", proportion=0.01, time=50) - b4.add_pulse(source="d1", dest="d2", proportion=0.01, time=50) - b4.add_pulse(source="d1", dest="d2", proportion=0.01, time=100) + b3.add_pulse(sources=["d1"], dest="d2", proportions=[0.01], time=100) + b3.add_pulse(sources=["d1"], dest="d2", proportions=[0.01], time=50) + b4.add_pulse(sources=["d1"], dest="d2", proportions=[0.01], time=50) + b4.add_pulse(sources=["d1"], dest="d2", proportions=[0.01], time=100) g3 = b3.resolve() g4 = b4.resolve() g3.assert_close(g4) @@ -2011,7 +2050,7 @@ def test_isclose(self): b4 = copy.deepcopy(b2) b4.add_deme("d1", epochs=[dict(start_size=1000, end_time=0)]) b4.add_deme("d2", epochs=[dict(start_size=1000, end_time=0)]) - b4.add_pulse(source="d1", dest="d2", proportion=0.01, time=100) + b4.add_pulse(sources=["d1"], dest="d2", proportions=[0.01], time=100) g3 = b3.resolve() g4 = b4.resolve() assert not g3.isclose(g4) @@ -2019,11 +2058,11 @@ def test_isclose(self): b3 = copy.deepcopy(b2) b3.add_deme("d1", epochs=[dict(start_size=1000, end_time=0)]) b3.add_deme("d2", epochs=[dict(start_size=1000, end_time=0)]) - b3.add_pulse(source="d2", dest="d1", proportion=0.01, time=100) + b3.add_pulse(sources=["d2"], dest="d1", proportions=[0.01], time=100) b4 = copy.deepcopy(b2) b4.add_deme("d1", epochs=[dict(start_size=1000, end_time=0)]) b4.add_deme("d2", epochs=[dict(start_size=1000, end_time=0)]) - b4.add_pulse(source="d1", dest="d2", proportion=0.01, time=100) + b4.add_pulse(sources=["d1"], dest="d2", proportions=[0.01], time=100) g3 = b3.resolve() g4 = b4.resolve() assert not g3.isclose(g4) @@ -2050,12 +2089,12 @@ def test_isclose_pulse_ordering(self): # Order of pulses matters for simultaneous pulses. b2 = copy.deepcopy(b1) - b2.add_pulse(source="a", dest="b", time=100, proportion=0.1) - b2.add_pulse(source="b", dest="c", time=100, proportion=0.1) + b2.add_pulse(sources=["a"], dest="b", time=100, proportions=[0.1]) + b2.add_pulse(sources=["b"], dest="c", time=100, proportions=[0.1]) g2 = b2.resolve() b3 = copy.deepcopy(b1) - b3.add_pulse(source="b", dest="c", time=100, proportion=0.1) - b3.add_pulse(source="a", dest="b", time=100, proportion=0.1) + b3.add_pulse(sources=["b"], dest="c", time=100, proportions=[0.1]) + b3.add_pulse(sources=["a"], dest="b", time=100, proportions=[0.1]) g3 = b3.resolve() assert not g2.isclose(g3) @@ -2131,7 +2170,7 @@ def test_successors_predecessors(self): b = Builder(defaults=dict(epoch=dict(start_size=1))) b.add_deme("a") b.add_deme("b") - b.add_pulse(source="a", dest="b", proportion=0.1, time=100) + b.add_pulse(sources=["a"], dest="b", proportions=[0.1], time=100) g = b.resolve() assert g.successors() == {"a": [], "b": []} assert g.predecessors() == {"a": [], "b": []} @@ -2179,8 +2218,8 @@ def test_discrete_demographic_events(self): b = Builder() b.add_deme("a", epochs=[dict(start_size=1)]) b.add_deme("b", epochs=[dict(start_size=1)]) - b.add_pulse(source="a", dest="b", time=100, proportion=0.1) - b.add_pulse(source="a", dest="b", time=200, proportion=0.2) + b.add_pulse(sources=["a"], dest="b", time=100, proportions=[0.1]) + b.add_pulse(sources=["a"], dest="b", time=200, proportions=[0.2]) g = b.resolve() de = g.discrete_demographic_events() assert len(de) == 5 @@ -2188,8 +2227,8 @@ def test_discrete_demographic_events(self): assert len(de[event]) == 0 assert self.dfsorted(de["pulses"]) == self.dfsorted( [ - Pulse(source="a", dest="b", proportion=0.1, time=100), - Pulse(source="a", dest="b", proportion=0.2, time=200), + Pulse(sources=["a"], dest="b", proportions=[0.1], time=100), + Pulse(sources=["a"], dest="b", proportions=[0.2], time=200), ] ) @@ -2365,7 +2404,7 @@ def test_basic_resolution(self): b = Builder() b.add_deme("a", epochs=[dict(start_size=1)]) b.add_deme("b", epochs=[dict(start_size=1)]) - b.add_pulse(source="a", dest="b", proportion=0.1, time=100) + b.add_pulse(sources=["a"], dest="b", proportions=[0.1], time=100) b.resolve() def test_bad_data_dict(self): @@ -2964,6 +3003,15 @@ def test_bad_pulses(self): with pytest.raises(TypeError): b.resolve() + # pulses have repeated sources + b = Builder(defaults=dict(epoch=dict(start_size=1))) + b.add_deme("a") + b.add_deme("b") + b.add_deme("c") + b.add_pulse(sources=["a", "a"], dest="c", proportions=[0.1, 0.1], time=1) + with pytest.raises(ValueError): + b.resolve() + # pulse is not a dict for data in (None, [], "string", 0, 1e-5): b = Builder() @@ -2975,22 +3023,22 @@ def test_bad_pulses(self): # dest not in graph b = Builder() b.add_deme("a", epochs=[dict(start_size=100, end_time=0)]) - b.add_pulse(source="a", dest="b", proportion=0.1, time=10) + b.add_pulse(sources=["a"], dest="b", proportions=[0.1], time=10) with pytest.raises(ValueError): b.resolve() # source not in graph b = Builder() b.add_deme("a", epochs=[dict(start_size=100, end_time=0)]) - b.add_pulse(source="b", dest="a", proportion=0.1, time=10) + b.add_pulse(sources=["b"], dest="a", proportions=[0.1], time=10) with pytest.raises(ValueError): b.resolve() - for field in ("source", "dest", "time", "proportion"): + for field in ("sources", "dest", "time", "proportions"): b = Builder() b.add_deme("a", epochs=[dict(start_size=1)]) b.add_deme("b", epochs=[dict(start_size=1)]) - b.add_pulse(source="a", dest="b", proportion=0.1, time=100) + b.add_pulse(sources=["a"], dest="b", proportions=[0.1], time=100) del b.data["pulses"][0][field] with pytest.raises(KeyError): b.resolve() @@ -2999,7 +3047,7 @@ def test_bad_pulse_time(self): b = Builder() b.add_deme("deme1", epochs=[dict(start_size=1000, end_time=0)]) b.add_deme("deme2", epochs=[dict(end_time=100, start_size=1000)]) - b.add_pulse(source="deme1", dest="deme2", proportion=0.1, time=10) + b.add_pulse(sources=["deme1"], dest="deme2", proportions=[0.1], time=10) with pytest.raises(ValueError): b.resolve() @@ -3010,13 +3058,13 @@ def test_bad_pulse_time(self): # Can't have pulse at the dest deme's end_time. b2 = copy.deepcopy(b) - b2.add_pulse(source="A", dest="B", time=g["B"].end_time, proportion=0.1) + b2.add_pulse(sources=["A"], dest="B", time=g["B"].end_time, proportions=[0.1]) with pytest.raises(ValueError): b2.resolve() # Can't have pulse at the source deme's start_time. b2 = copy.deepcopy(b) - b2.add_pulse(source="B", dest="A", time=g["B"].start_time, proportion=0.1) + b2.add_pulse(sources=["B"], dest="A", time=g["B"].start_time, proportions=[0.1]) with pytest.raises(ValueError): b2.resolve() @@ -3029,22 +3077,22 @@ def test_simultaneous_pulses_warning(self): # Warn for duplicate pulses b2 = copy.deepcopy(b1) - b2.add_pulse(source="d0", dest="d1", time=T, proportion=0.1) - b2.add_pulse(source="d0", dest="d1", time=T, proportion=0.1) + b2.add_pulse(sources=["d0"], dest="d1", time=T, proportions=[0.1]) + b2.add_pulse(sources=["d0"], dest="d1", time=T, proportions=[0.1]) with pytest.warns(UserWarning, match="Multiple pulses.*same.*time"): b2.resolve() # Warn for: d0 -> d1; d1 -> d2. b2 = copy.deepcopy(b1) - b2.add_pulse(source="d0", dest="d1", time=T, proportion=0.1) - b2.add_pulse(source="d1", dest="d2", time=T, proportion=0.1) + b2.add_pulse(sources=["d0"], dest="d1", time=T, proportions=[0.1]) + b2.add_pulse(sources=["d1"], dest="d2", time=T, proportions=[0.1]) with pytest.warns(UserWarning, match="Multiple pulses.*same.*time"): b2.resolve() # Warn for: d0 -> d2; d1 -> d2. b2 = copy.deepcopy(b1) - b2.add_pulse(source="d0", dest="d2", time=T, proportion=0.1) - b2.add_pulse(source="d1", dest="d2", time=T, proportion=0.1) + b2.add_pulse(sources=["d0"], dest="d2", time=T, proportions=[0.1]) + b2.add_pulse(sources=["d1"], dest="d2", time=T, proportions=[0.1]) with pytest.warns(UserWarning, match="Multiple pulses.*same.*time"): b2.resolve() @@ -3058,27 +3106,35 @@ def test_unrelated_pulses_no_warning(self): # Shouldn't warn for: d0 -> d1; d0 -> d2. b2 = copy.deepcopy(b1) - b2.add_pulse(source="d0", dest="d1", time=T, proportion=0.1) - b2.add_pulse(source="d0", dest="d2", time=T, proportion=0.1) - b2.resolve() + b2.add_pulse(sources=["d0"], dest="d1", time=T, proportions=[0.1]) + b2.add_pulse(sources=["d0"], dest="d2", time=T, proportions=[0.1]) + with pytest.warns(None) as record: + b2.resolve() + assert len(record) == 0 # Shouldn't warn for: d0 -> d1; d2 -> d3. b2 = copy.deepcopy(b1) - b2.add_pulse(source="d0", dest="d1", time=T, proportion=0.1) - b2.add_pulse(source="d2", dest="d3", time=T, proportion=0.1) - b2.resolve() + b2.add_pulse(sources=["d0"], dest="d1", time=T, proportions=[0.1]) + b2.add_pulse(sources=["d2"], dest="d3", time=T, proportions=[0.1]) + with pytest.warns(None) as record: + b2.resolve() + assert len(record) == 0 # Different pulse times shouldn't warn for: d0 -> d1; d1 -> d2. b2 = copy.deepcopy(b1) - b2.add_pulse(source="d0", dest="d1", time=T, proportion=0.1) - b2.add_pulse(source="d1", dest="d2", time=2 * T, proportion=0.1) - b2.resolve() + b2.add_pulse(sources=["d0"], dest="d1", time=T, proportions=[0.1]) + b2.add_pulse(sources=["d1"], dest="d2", time=2 * T, proportions=[0.1]) + with pytest.warns(None) as record: + b2.resolve() + assert len(record) == 0 # Different pulse times shouldn't warn for: d0 -> d2; d1 -> d2. b2 = copy.deepcopy(b1) - b2.add_pulse(source="d0", dest="d2", time=T, proportion=0.1) - b2.add_pulse(source="d1", dest="d2", time=2 * T, proportion=0.1) - b2.resolve() + b2.add_pulse(sources=["d0"], dest="d2", time=T, proportions=[0.1]) + b2.add_pulse(sources=["d1"], dest="d2", time=2 * T, proportions=[0.1]) + with pytest.warns(None) as record: + b2.resolve() + assert len(record) == 0 @pytest.mark.filterwarnings("ignore:Multiple pulses.*same.*time") def test_pulse_proportions_sum_greater_than_one(self): @@ -3086,10 +3142,18 @@ def test_pulse_proportions_sum_greater_than_one(self): b.add_deme("a") b.add_deme("b") b.add_deme("c") - b.add_pulse(source="b", dest="a", time=100, proportion=0.6) - b.add_pulse(source="c", dest="a", time=100, proportion=0.6) + b.add_pulse(sources=["b"], dest="a", time=100, proportions=[0.6]) + b.add_pulse(sources=["c"], dest="a", time=100, proportions=[0.6]) b.resolve() + b = Builder(defaults=dict(epoch=dict(start_size=1))) + b.add_deme("a") + b.add_deme("b") + b.add_deme("c") + b.add_pulse(sources=["b", "c"], dest="a", time=100, proportions=[0.6, 0.6]) + with pytest.raises(ValueError): + b.resolve() + @pytest.mark.filterwarnings("ignore:Multiple pulses.*same.*time") def test_pulse_order(self): b = Builder(defaults=dict(epoch=dict(start_size=1))) @@ -3099,43 +3163,43 @@ def test_pulse_order(self): # Pulses defined in oldest-to-youngest order have order maintained. b.data["pulses"] = [] - b.add_pulse(source="a", dest="b", time=200, proportion=0.1) - b.add_pulse(source="b", dest="c", time=100, proportion=0.1) + b.add_pulse(sources=["a"], dest="b", time=200, proportions=[0.1]) + b.add_pulse(sources=["b"], dest="c", time=100, proportions=[0.1]) g = b.resolve() - assert g.pulses[0].source == "a" + assert g.pulses[0].sources[0] == "a" assert g.pulses[0].dest == "b" - assert g.pulses[1].source == "b" + assert g.pulses[1].sources[0] == "b" assert g.pulses[1].dest == "c" # Pulses defined out of order will be sorted oldest-to-youngest. b.data["pulses"] = [] - b.add_pulse(source="b", dest="c", time=100, proportion=0.1) - b.add_pulse(source="a", dest="b", time=200, proportion=0.1) + b.add_pulse(sources=["b"], dest="c", time=100, proportions=[0.1]) + b.add_pulse(sources=["a"], dest="b", time=200, proportions=[0.1]) g = b.resolve() - assert g.pulses[0].source == "a" + assert g.pulses[0].sources[0] == "a" assert g.pulses[0].dest == "b" - assert g.pulses[1].source == "b" + assert g.pulses[1].sources[0] == "b" assert g.pulses[1].dest == "c" # Simultaneous pulses should be ordered as they were defined. b.data["pulses"] = [] - b.add_pulse(source="a", dest="b", time=100, proportion=0.1) - b.add_pulse(source="b", dest="c", time=100, proportion=0.1) + b.add_pulse(sources=["a"], dest="b", time=100, proportions=[0.1]) + b.add_pulse(sources=["b"], dest="c", time=100, proportions=[0.1]) g = b.resolve() - assert g.pulses[0].source == "a" + assert g.pulses[0].sources[0] == "a" assert g.pulses[0].dest == "b" - assert g.pulses[1].source == "b" + assert g.pulses[1].sources[0] == "b" assert g.pulses[1].dest == "c" # Reverse the order of simultaneous pulses, to check it's no accident. # The order should still match the definitions. b.data["pulses"] = [] - b.add_pulse(source="b", dest="c", time=100, proportion=0.1) - b.add_pulse(source="a", dest="b", time=100, proportion=0.1) + b.add_pulse(sources=["b"], dest="c", time=100, proportions=[0.1]) + b.add_pulse(sources=["a"], dest="b", time=100, proportions=[0.1]) g = b.resolve() - assert g.pulses[0].source == "b" + assert g.pulses[0].sources[0] == "b" assert g.pulses[0].dest == "c" - assert g.pulses[1].source == "a" + assert g.pulses[1].sources[0] == "a" assert g.pulses[1].dest == "b" def test_toplevel_defaults_deme(self): @@ -3314,24 +3378,29 @@ def test_toplevel_defaults_migration(self): @pytest.mark.filterwarnings("ignore:Multiple pulses.*same.*time") def test_toplevel_defaults_pulse(self): - # source - b = Builder(defaults=dict(pulse=dict(source="a"))) + # sources + b = Builder(defaults=dict(pulse=dict(sources=["a"]))) for name in "abcd": b.add_deme(name, epochs=[dict(start_size=1)]) for name in "bcd": - b.add_pulse(dest=name, proportion=0.1, time=100) - b.add_pulse(source="d", dest="a", proportion=0.2, time=50) + b.add_pulse(dest=name, proportions=[0.1], time=100) + b.add_pulse(sources=["d"], dest="a", proportions=[0.2], time=50) g = b.resolve() - assert g.pulses[0].source == g.pulses[1].source == g.pulses[2].source == "a" - assert g.pulses[3].source == "d" + assert ( + g.pulses[0].sources[0] + == g.pulses[1].sources[0] + == g.pulses[2].sources[0] + == "a" + ) + assert g.pulses[3].sources[0] == "d" # dest b = Builder(defaults=dict(pulse=dict(dest="a"))) for name in "abcd": b.add_deme(name, epochs=[dict(start_size=1)]) for name in "bcd": - b.add_pulse(source=name, proportion=0.1, time=100) - b.add_pulse(dest="d", source="a", proportion=0.2, time=50) + b.add_pulse(sources=[name], proportions=[0.1], time=100) + b.add_pulse(dest="d", sources=["a"], proportions=[0.2], time=50) g = b.resolve() assert g.pulses[0].dest == g.pulses[1].dest == g.pulses[2].dest == "a" assert g.pulses[3].dest == "d" @@ -3341,27 +3410,27 @@ def test_toplevel_defaults_pulse(self): for name in "abcd": b.add_deme(name, epochs=[dict(start_size=1)]) for name in "bcd": - b.add_pulse(source="a", dest=name, proportion=0.1) - b.add_pulse(source="d", dest="a", proportion=0.2, time=50) + b.add_pulse(sources=["a"], dest=name, proportions=[0.1]) + b.add_pulse(sources=["d"], dest="a", proportions=[0.2], time=50) g = b.resolve() assert g.pulses[0].time == g.pulses[1].time == g.pulses[2].time == 100 assert g.pulses[3].time == 50 - # proportion - b = Builder(defaults=dict(pulse=dict(proportion=0.1))) + # proportions + b = Builder(defaults=dict(pulse=dict(proportions=[0.1]))) for name in "abcd": b.add_deme(name, epochs=[dict(start_size=1)]) for name in "bcd": - b.add_pulse(source="a", dest=name, time=100) - b.add_pulse(source="d", dest="a", time=50, proportion=0.2) + b.add_pulse(sources=["a"], dest=name, time=100) + b.add_pulse(sources=["d"], dest="a", time=50, proportions=[0.2]) g = b.resolve() assert ( - g.pulses[0].proportion - == g.pulses[1].proportion - == g.pulses[2].proportion + g.pulses[0].proportions[0] + == g.pulses[1].proportions[0] + == g.pulses[2].proportions[0] == 0.1 ) - assert g.pulses[3].proportion == 0.2 + assert g.pulses[3].proportions[0] == 0.2 # Test toplevel epoch defaults, including overrides. def test_toplevel_defaults_epoch(self): diff --git a/tests/test_load_dump.py b/tests/test_load_dump.py index 4082d93c..c03cb1d2 100644 --- a/tests/test_load_dump.py +++ b/tests/test_load_dump.py @@ -140,25 +140,25 @@ def jacobs_papuans(): demes=["Ghost3", "Eurasia"], rate=4.42e-4, start_time=T_Eu_bottleneck ) - b.add_pulse(source="NeaI", dest="EastAsia", proportion=0.002, time=883) - b.add_pulse(source="NeaI", dest="Papua", proportion=0.002, time=1412) - b.add_pulse(source="NeaI", dest="Eurasia", proportion=0.011, time=1566) - b.add_pulse(source="NeaI", dest="Ghost1", proportion=0.024, time=1853) + b.add_pulse(sources=["NeaI"], dest="EastAsia", proportions=[0.002], time=883) + b.add_pulse(sources=["NeaI"], dest="Papua", proportions=[0.002], time=1412) + b.add_pulse(sources=["NeaI"], dest="Eurasia", proportions=[0.011], time=1566) + b.add_pulse(sources=["NeaI"], dest="Ghost1", proportions=[0.024], time=1853) m_Den_Papuan = 0.04 p = 0.55 # S10.i p. 31 T_Den1_Papuan_mig = 29.8e3 / generation_time T_Den2_Papuan_mig = 45.7e3 / generation_time b.add_pulse( - source="DenI1", + sources=["DenI1"], dest="Papua", - proportion=p * m_Den_Papuan, + proportions=[p * m_Den_Papuan], time=T_Den1_Papuan_mig, ) b.add_pulse( - source="DenI2", + sources=["DenI2"], dest="Papua", - proportion=(1 - p) * m_Den_Papuan, + proportions=[(1 - p) * m_Den_Papuan], time=T_Den2_Papuan_mig, ) @@ -275,8 +275,8 @@ def check_dumps_complex(self, *, format, simplified): assert f"{deme.name}" in string assert "pulses" in string for pulse in g.pulses: - assert "source" in string - assert pulse.source in string + assert "sources" in string + assert pulse.sources[0] in string assert "dest" in string assert pulse.dest in string assert "migrations" in string @@ -565,10 +565,10 @@ def test_float_subclass(self): self.check_dump_load_roundtrip(b.resolve()) b.add_pulse( - source="A", + sources=["A"], dest="B", time=T[1], - proportion=decimal.Decimal("0.0022"), + proportions=[decimal.Decimal("0.0022")], ) self.check_dump_load_roundtrip(b.resolve()) diff --git a/tests/test_ms.py b/tests/test_ms.py index 8e507436..f62ace0d 100644 --- a/tests/test_ms.py +++ b/tests/test_ms.py @@ -587,9 +587,9 @@ def test_split(self): source = f"deme{num_demes + 1}" # pulse properties pulse = graph.pulses[0] - assert pulse.source == source + assert pulse.sources[0] == source assert pulse.dest == dest - assert math.isclose(pulse.proportion, 1 - p) + assert math.isclose(pulse.proportions[0], 1 - p) assert math.isclose(pulse.time, t * 4 * N0) # source deme properties assert math.isinf(graph[source].start_time) @@ -635,13 +635,13 @@ def test_split_twice_immediately(self): assert math.isclose(graph["deme3"].end_time, T * 4 * N0) assert len(graph.pulses) == 2 # The order of pulses matters here. - assert graph.pulses[0].source == "deme3" + assert graph.pulses[0].sources[0] == "deme3" assert graph.pulses[0].dest == "deme1" - assert math.isclose(graph.pulses[0].proportion, 1 - 0.8) + assert math.isclose(graph.pulses[0].proportions[0], 1 - 0.8) assert math.isclose(graph.pulses[0].time, T * 4 * N0) - assert graph.pulses[1].source == "deme2" + assert graph.pulses[1].sources[0] == "deme2" assert graph.pulses[1].dest == "deme1" - assert math.isclose(graph.pulses[1].proportion, 1 - 0.7) + assert math.isclose(graph.pulses[1].proportions[0], 1 - 0.7) assert math.isclose(graph.pulses[1].time, T * 4 * N0) def test_join(self): @@ -844,9 +844,9 @@ def test_split_then_join(self): assert len(graph["deme2"].epochs) == 1 # check pulse pulse = graph.pulses[0] - assert pulse.source == "deme2" + assert pulse.sources[0] == "deme2" assert pulse.dest == "deme1" - assert pulse.proportion == 1 - p + assert pulse.proportions[0] == 1 - p assert math.isclose(pulse.time, T1 * 4 * N0) def test_split_then_join_immediately(self): @@ -870,9 +870,9 @@ def test_split_then_join_immediately(self): assert len(graph["deme2"].epochs) == 1 # check pulse pulse = graph.pulses[0] - assert pulse.source == "deme2" + assert pulse.sources[0] == "deme2" assert pulse.dest == "deme1" - assert pulse.proportion == 1 - p + assert pulse.proportions[0] == 1 - p assert math.isclose(pulse.time, T1 * 4 * N0) @pytest.mark.filterwarnings("ignore:Multiple pulses.*same.*time") @@ -888,14 +888,14 @@ def test_split_then_join_sequence_1(self): graph = demes.from_ms(cmd, N0=N0) assert len(graph.demes) == 3 assert len(graph.pulses) == 2 - assert graph.pulses[0].source == "deme1" + assert graph.pulses[0].sources[0] == "deme1" assert graph.pulses[0].dest == "deme2" assert math.isclose(graph.pulses[0].time, T1 * 4 * N0) - assert math.isclose(graph.pulses[0].proportion, 0.4) - assert graph.pulses[1].source == "deme2" + assert math.isclose(graph.pulses[0].proportions[0], 0.4) + assert graph.pulses[1].sources[0] == "deme2" assert graph.pulses[1].dest == "deme3" assert math.isclose(graph.pulses[1].time, T1 * 4 * N0) - assert math.isclose(graph.pulses[1].proportion, 0.4) + assert math.isclose(graph.pulses[1].proportions[0], 0.4) @pytest.mark.filterwarnings("ignore:Multiple pulses.*same.*time") def test_split_then_join_sequence_2(self): @@ -910,14 +910,14 @@ def test_split_then_join_sequence_2(self): graph = demes.from_ms(cmd, N0=N0) assert len(graph.demes) == 3 assert len(graph.pulses) == 2 - assert graph.pulses[0].source == "deme1" + assert graph.pulses[0].sources[0] == "deme1" assert graph.pulses[0].dest == "deme3" assert math.isclose(graph.pulses[0].time, T1 * 4 * N0) - assert math.isclose(graph.pulses[0].proportion, 0.4) - assert graph.pulses[1].source == "deme2" + assert math.isclose(graph.pulses[0].proportions[0], 0.4) + assert graph.pulses[1].sources[0] == "deme2" assert graph.pulses[1].dest == "deme3" assert math.isclose(graph.pulses[1].time, T1 * 4 * N0) - assert math.isclose(graph.pulses[1].proportion, 0.4) + assert math.isclose(graph.pulses[1].proportions[0], 0.4) @pytest.mark.filterwarnings("ignore:Multiple pulses.*same.*time") def test_split_then_join_sequence_3(self): @@ -932,14 +932,14 @@ def test_split_then_join_sequence_3(self): graph = demes.from_ms(cmd, N0=N0) assert len(graph.demes) == 3 assert len(graph.pulses) == 2 - assert graph.pulses[0].source == "deme2" + assert graph.pulses[0].sources[0] == "deme2" assert graph.pulses[0].dest == "deme1" assert math.isclose(graph.pulses[0].time, T1 * 4 * N0) - assert math.isclose(graph.pulses[0].proportion, 0.4) - assert graph.pulses[1].source == "deme2" + assert math.isclose(graph.pulses[0].proportions[0], 0.4) + assert graph.pulses[1].sources[0] == "deme2" assert graph.pulses[1].dest == "deme3" assert math.isclose(graph.pulses[1].time, T1 * 4 * N0) - assert math.isclose(graph.pulses[1].proportion, 0.4) + assert math.isclose(graph.pulses[1].proportions[0], 0.4) @pytest.mark.filterwarnings("ignore:Multiple pulses.*same.*time") def test_split_then_join_sequence_then_join(self): @@ -1349,7 +1349,7 @@ def test_deme_names(self): assert "A" in graph assert "B" in graph assert len(graph.pulses) == 1 - assert graph.pulses[0].source == "B" + assert graph.pulses[0].sources[0] == "B" assert graph.pulses[0].dest == "A" # bad deme names @@ -1982,7 +1982,7 @@ def test_pulse(self): b = demes.Builder() b.add_deme("a", epochs=[dict(start_size=N0)]) b.add_deme("b", epochs=[dict(start_size=N0)]) - b.add_pulse(source="a", dest="b", time=T0, proportion=0.1) + b.add_pulse(sources=["a"], dest="b", time=T0, proportions=[0.1]) graph1 = b.resolve() cmd = demes.to_ms(graph1, N0=N0) structure, events = self.parse_command(cmd) @@ -2011,8 +2011,8 @@ def test_pulse_order(self): b.add_deme("a", epochs=[dict(start_size=N0)]) b.add_deme("b", epochs=[dict(start_size=N0)]) b.add_deme("c", epochs=[dict(start_size=N0)]) - b.add_pulse(source="a", dest="b", time=T0, proportion=0.1) - b.add_pulse(source="b", dest="c", time=T0, proportion=0.1) + b.add_pulse(sources=["a"], dest="b", time=T0, proportions=[0.1]) + b.add_pulse(sources=["b"], dest="c", time=T0, proportions=[0.1]) graph1 = b.resolve() cmd = demes.to_ms(graph1, N0=N0)