Skip to content

Commit

Permalink
Allow simultaneous pulses using multiple sources
Browse files Browse the repository at this point in the history
  • Loading branch information
apragsdale committed Jul 28, 2021
1 parent 586f073 commit 6b49ef9
Show file tree
Hide file tree
Showing 9 changed files with 314 additions and 199 deletions.
2 changes: 1 addition & 1 deletion .mergify.yml
Original file line number Diff line number Diff line change
Expand Up @@ -51,4 +51,4 @@ pull_request_rules:
- check-success=tests (windows-2019, 3.9)
actions:
queue:
name: default
name: default
136 changes: 89 additions & 47 deletions demes/demes.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,26 +409,47 @@ class Pulse:
Parameters for a pulse of migration from one deme to another.
Source and destination demes follow the forwards-in-time convention,
of migrations born in the source deme having children in the dest
deme.
deme. If more than one source is given, migration is concurrent, so that
the sum of the migrant proportions must sum to less than one.
:ivar str source: The source deme.
:ivar list(str) sources: The source deme(s).
:ivar str dest: The destination deme.
:ivar float ~.time: The time of migration.
:ivar float proportion: At the instant after migration, this is the
proportion of individuals in the destination deme made up of
individuals from the source deme.
:ivar list(float) proportions: Immediately following migration, the proportion(s)
of individuals in the destination deme made up of migrant individuals or
having parents from the source deme(s).
"""

source: Name = attr.ib(
validator=[attr.validators.instance_of(str), valid_deme_name]
sources: List[Name] = attr.ib(
validator=attr.validators.and_(
attr.validators.deep_iterable(
member_validator=attr.validators.and_(
attr.validators.instance_of(str), valid_deme_name
),
iterable_validator=attr.validators.instance_of(list),
),
nonzero_len,
)
)
dest: Name = attr.ib(validator=[attr.validators.instance_of(str), valid_deme_name])
time: Time = attr.ib(validator=[int_or_float, positive, finite])
proportion: Proportion = attr.ib(validator=[int_or_float, unit_interval])
proportions: List[Proportion] = attr.ib(
validator=attr.validators.deep_iterable(
member_validator=attr.validators.and_(int_or_float, unit_interval),
iterable_validator=attr.validators.instance_of(list),
)
)

def __attrs_post_init__(self):
if self.source == self.dest:
raise ValueError("source and dest cannot be the same deme")
for source in self.sources:
if source == self.dest:
raise ValueError(f"source ({source}) cannot be the same as dest")
if self.sources.count(source) != 1:
raise ValueError(f"source ({source}) cannot be repeated in sources")
if len(self.sources) != len(self.proportions):
raise ValueError("sources and proportions must have the same length")
if sum(self.proportions) > 1:
raise ValueError("proportions must sum to less than one")

def assert_close(
self,
Expand All @@ -455,14 +476,28 @@ def assert_close(
assert (
self.__class__ is other.__class__
), f"Failed as other pulse is not instance of {self.__class__} type."
assert self.source == other.source
assert len(self.sources) == len(other.sources)
for s in self.sources:
assert s in other.sources
for s in other.sources:
assert s in self.sources
assert self.dest == other.dest
assert math.isclose(
self.time, other.time, rel_tol=rel_tol, abs_tol=abs_tol
), f"Failed for time {self.time} != {other.time} (other)."
assert len(self.proportions) == len(other.proportions)
assert math.isclose(
self.proportion, other.proportion, rel_tol=rel_tol, abs_tol=abs_tol
), f"Failed for proportion {self.proportion} != {other.proportion} (other)."
sum(self.proportions),
sum(other.proportions),
rel_tol=rel_tol,
abs_tol=abs_tol,
), (
f"Failed for unequal proportions sums: "
f"sum({self.proportions}) != sum({other.proportions}) (other)."
)
assert isclose_deme_proportions(
self.sources, self.proportions, other.sources, other.proportions
)

def isclose(
self,
Expand Down Expand Up @@ -1556,40 +1591,43 @@ def _add_asymmetric_migration(
self.migrations.append(migration)
return migration

def _add_pulse(self, *, source, dest, proportion, time) -> Pulse:
def _add_pulse(self, *, sources, dest, proportions, time) -> Pulse:
"""
Add a pulse of migration at a fixed time.
Source and destination demes follow the forwards-in-time convention.
:param str source: The name of the source deme.
:param list(str) sources: The name(s) of the source deme(s).
:param str dest: The name of the destination deme.
:param float proportion: At the instant after migration, this is the expected
proportion of individuals in the destination deme made up of individuals
from the source deme.
:param list(float) proportion(s): Immediately following migration, the
proportion(s) of individuals in the destination deme made up of
migrant individuals or having parents from the source deme(s).
:param float time: The time at which migrations occur.
:return: Newly created pulse.
:rtype: :class:`.Pulse`
"""
for deme_name in (source, dest):
for deme_name in sources + [dest]:
if deme_name not in self:
raise ValueError(f"{deme_name} not in graph")
self._check_time_intersection(source, dest, time)
for source in sources:
self._check_time_intersection(source, dest, time)
if time == self[dest].end_time:
raise ValueError(
f"invalid pulse at time={time}, which is dest={dest}'s end_time"
)
if time == self[source].start_time:
raise ValueError(
f"invalid pulse at time={time}, which is source={source}'s start_time"
)
for source in sources:
if time == self[source].start_time:
raise ValueError(
f"invalid pulse at time={time}, "
f"which is source={source}'s start_time"
)

# We create the new pulse object (which checks for common errors)
# before checking for edge cases below.
new_pulse = Pulse(
source=source,
sources=sources,
dest=dest,
time=time,
proportion=proportion,
proportions=proportions,
)

# Check for models that have multiple pulses defined at the same time.
Expand All @@ -1598,20 +1636,23 @@ def _add_pulse(self, *, source, dest, proportion, time) -> Pulse:
# interpretation of the model. Such models are valid, but the behaviour
# may not be what the user expects.
# See https://github.com/popsim-consortium/demes-python/issues/46
sources = set()
dests = set()
all_sources = set()
all_dests = set()
for pulse in self.pulses:
if pulse.time == time:
sources.add(pulse.source)
dests.add(pulse.dest)
if source in dests or dest in (sources | dests):
all_sources.update(pulse.sources)
all_dests.add(pulse.dest)
if any(source in all_dests for source in sources) or dest in (
all_sources | all_dests
):
warnings.warn(
"Multiple pulses are defined for the same deme(s) at time "
f"{time}. The ancestry proportions after this time will thus "
"depend on the order in which the pulses have been specified. "
"To avoid unexpected behaviour, the graph can instead "
"be structured to introduce a new deme at this time with "
"the desired ancestry proportions."
"the desired ancestry proportions or to specify concurrent "
"pulses with multiple sources."
)

self.pulses.append(new_pulse)
Expand Down Expand Up @@ -1916,7 +1957,7 @@ def fromdict(cls, data: MutableMapping[str, Any]) -> "Graph":
)

pulse_defaults = pop_object(defaults, "pulse", {}, scope="defaults")
allowed_fields_pulse = ["source", "dest", "time", "proportion"]
allowed_fields_pulse = ["sources", "dest", "time", "proportions"]
check_allowed(pulse_defaults, allowed_fields_pulse, "defaults.pulse")

# epoch defaults may also be specified within a Deme definition.
Expand Down Expand Up @@ -2066,15 +2107,15 @@ def fromdict(cls, data: MutableMapping[str, Any]) -> "Graph":
):
check_allowed(pulse_data, allowed_fields_pulse, f"pulse[{i}]")
insert_defaults(pulse_data, pulse_defaults)
for field in ("source", "dest", "time", "proportion"):
for field in ("sources", "dest", "time", "proportions"):
if field not in pulse_data:
raise KeyError(f"pulse[{i}]: required field '{field}' not found")
try:
graph._add_pulse(
source=pulse_data.pop("source"),
sources=pulse_data.pop("sources"),
dest=pulse_data.pop("dest"),
time=pulse_data.pop("time"),
proportion=pulse_data.pop("proportion"),
proportions=pulse_data.pop("proportions"),
)
except (TypeError, ValueError) as e:
raise e.__class__(f"pulse[{i}]: invalid pulse") from e
Expand Down Expand Up @@ -2397,29 +2438,29 @@ def add_migration(
def add_pulse(
self,
*,
source: str = None,
sources: List[str] = None,
dest: str = None,
proportion: float = None,
proportions: List[float] = None,
time: float = None,
):
"""
Add a pulse of migration at a fixed time.
Source and destination demes follow the forwards-in-time convention.
:param str source: The name of the source deme.
:param list(str) source: The name of the source deme(s).
:param str dest: The name of the destination deme.
:param float proportion: At the instant after migration, this is the
expected proportion of individuals in the destination deme made up
of individuals from the source deme.
:param list(float) proportion: At the instant after migration, this is the
expected proportion(s) of individuals in the destination deme made up
of individuals from the source deme(s).
:param float time: The time at which migrations occur.
"""
pulse: MutableMapping[str, Any] = dict()
if source is not None:
pulse["source"] = source
if sources is not None:
pulse["sources"] = sources
if dest is not None:
pulse["dest"] = dest
if proportion is not None:
pulse["proportion"] = proportion
if proportions is not None:
pulse["proportions"] = proportions
if time is not None:
pulse["time"] = time

Expand Down Expand Up @@ -2536,7 +2577,8 @@ def _remove_transient_demes(self) -> None:
continue
if start_time == end_time:
for pulse in self.data.get("pulses", []):
assert pulse["source"] != deme["name"]
for s in pulse["sources"]:
assert s != deme["name"]
assert pulse["dest"] != deme["name"]
for migration in self.data.get("migrations", []):
assert deme["name"] not in migration.get("demes", [])
Expand Down
4 changes: 2 additions & 2 deletions demes/hypothesis_strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,10 +315,10 @@ def pulses_lists(draw, graph, max_pulses=10):
)
ingress_proportions[(dest, time)] += proportion
pulse = dict(
source=source,
sources=[source],
dest=dest,
time=time,
proportion=proportion,
proportions=[proportion],
)
pulses.append(pulse)
n_pulses -= 1
Expand Down
14 changes: 9 additions & 5 deletions demes/ms.py
Original file line number Diff line number Diff line change
Expand Up @@ -778,10 +778,10 @@ def migration_matrix_at(time):
# The order of pulses will later be reversed such that realised
# ancestry proportions are maintained forwards in time.
b.add_pulse(
source=f"deme{k + 1}",
sources=[f"deme{k + 1}"],
dest=f"deme{j + 1}",
time=time,
proportion=p,
proportions=[p],
)

# Resolve/remove growth_rate in oldest epochs.
Expand Down Expand Up @@ -830,7 +830,7 @@ def remap_deme_names(graph: demes.Graph, names: Mapping[str, str]) -> demes.Grap
migration.source = names[migration.source]
migration.dest = names[migration.dest]
for pulse in graph.pulses:
pulse.source = names[pulse.source]
pulse.sources = [names[s] for s in pulse.sources]
pulse.dest = names[pulse.dest]
for k, deme in list(graph._deme_map.items()):
del graph._deme_map[k]
Expand Down Expand Up @@ -989,8 +989,12 @@ def get_growth_rate(epoch):
pulse = deme_or_pulse
num_demes += 1
new_deme_id = num_demes
e1 = Split(pulse.time, deme_id[pulse.dest], 1 - pulse.proportion)
e2 = Join(pulse.time, new_deme_id, deme_id[pulse.source])
if len(pulse.sources) > 1:
raise ValueError(
"Currently pulses will only a single source are supported"
)
e1 = Split(pulse.time, deme_id[pulse.dest], 1 - pulse.proportions[0])
e2 = Join(pulse.time, new_deme_id, deme_id[pulse.sources[0]])
events.extend([e1, e2])

# Turn migrations off at the start_time. We schedule all start_time
Expand Down
12 changes: 6 additions & 6 deletions examples/jacobs_papuans.yml
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,9 @@ migrations:
- {demes: [CEU, Ghost], rate: 0.000442}

pulses:
- {source: Nea1, dest: Ghost, time: 1853.0, proportion: 0.024}
- {source: Den2, dest: Papuan, time: 1575.8620689655172, proportion: 0.018}
- {source: Nea1, dest: CHB, time: 1566.0, proportion: 0.011}
- {source: Nea1, dest: Papuan, time: 1412.0, proportion: 0.002}
- {source: Den1, dest: Papuan, time: 1027.5862068965516, proportion: 0.022}
- {source: Nea1, dest: CHB, time: 883.0, proportion: 0.002}
- {sources: [Nea1], dest: Ghost, time: 1853.0, proportions: [0.024]}
- {sources: [Den2], dest: Papuan, time: 1575.8620689655172, proportions: [0.018]}
- {sources: [Nea1], dest: CHB, time: 1566.0, proportions: [0.011]}
- {sources: [Nea1], dest: Papuan, time: 1412.0, proportions: [0.002]}
- {sources: [Den1], dest: Papuan, time: 1027.5862068965516, proportions: [0.022]}
- {sources: [Nea1], dest: CHB, time: 883.0, proportions: [0.002]}
4 changes: 2 additions & 2 deletions examples/offshoots.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ migrations:
- demes: [offshoot1, offshoot2]
rate: 2e-5
pulses:
- source: offshoot1
- sources: [offshoot1]
dest: ancestral
proportion: 0.1
proportions: [0.1]
time: 50
Loading

0 comments on commit 6b49ef9

Please sign in to comment.