Skip to content

Commit

Permalink
Add Builder API.
Browse files Browse the repository at this point in the history
This was fairly invasive, so there are a bunch of unrelated
changes that went in too. In particular, spec compliance for
defaults and symmetric/asymmetric migrations.

Closes popsim-consortium#157, popsim-consortium#174, popsim-consortium#183, popsim-consortium#184, popsim-consortium#201, popsim-consortium#208.
  • Loading branch information
grahamgower committed Feb 22, 2021
1 parent d9b4347 commit e2a7e05
Show file tree
Hide file tree
Showing 15 changed files with 2,702 additions and 1,436 deletions.
15 changes: 14 additions & 1 deletion demes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,18 @@
except ImportError:
pass

from .demes import Epoch, Migration, Pulse, Deme, Graph, Split, Branch, Merge, Admix
from .demes import (
Builder,
Epoch,
Migration,
SymmetricMigration,
AsymmetricMigration,
Pulse,
Deme,
Graph,
Split,
Branch,
Merge,
Admix,
)
from .load_dump import load_asdict, loads_asdict, load, loads, dump, dumps
96 changes: 52 additions & 44 deletions demes/convert/msprime_.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from typing import List, Mapping, Tuple
import typing
import math
import collections
import itertools
Expand Down Expand Up @@ -33,8 +32,8 @@ def to_msprime(graph: demes.Graph):
pop_id = {deme.id: j for j, deme in enumerate(graph.demes)}

def growth_rate(epoch: demes.Epoch) -> float:
initial_size = typing.cast(float, epoch.end_size)
final_size = typing.cast(float, epoch.start_size)
initial_size = epoch.end_size
final_size = epoch.start_size
if initial_size == final_size:
growth_rate = 0.0
else:
Expand All @@ -61,7 +60,7 @@ def growth_rate(epoch: demes.Epoch) -> float:
initial_size = Ne_invalid
_growth_rate = 0.0
else:
initial_size = typing.cast(float, deme.epochs[-1].end_size)
initial_size = deme.epochs[-1].end_size
_growth_rate = growth_rate(deme.epochs[-1])
population_configurations.append(
msprime.PopulationConfiguration(
Expand Down Expand Up @@ -94,9 +93,7 @@ def growth_rate(epoch: demes.Epoch) -> float:
population_id=pop_id[deme.id],
)
)
if epoch == deme.epochs[0] and not math.isinf(
typing.cast(float, epoch.start_time)
):
if epoch == deme.epochs[0] and not math.isinf(epoch.start_time):
# If this deme doesn't exist at time=inf, invalidate Ne when
# the deme ceases to exist.
demographic_events.append(
Expand All @@ -119,19 +116,15 @@ def growth_rate(epoch: demes.Epoch) -> float:
)

mig_rate_events = []
for migration in reversed(graph.migrations):
dest = pop_id[migration.source]
source = pop_id[migration.dest]
start_time = migration.end_time
end_time = migration.start_time

def append_migration(dest, source, start_time, end_time, rate):
if start_time == 0:
migration_matrix[dest][source] = migration.rate
migration_matrix[dest][source] = rate
else:
mig_rate_events.append(
msprime.MigrationRateChange(
time=start_time,
rate=migration.rate,
rate=rate,
matrix_index=(dest, source),
)
)
Expand All @@ -141,6 +134,22 @@ def growth_rate(epoch: demes.Epoch) -> float:
)
)

for migration in reversed(graph.migrations):
rate = migration.rate
start_time = migration.end_time
end_time = migration.start_time
if isinstance(migration, demes.AsymmetricMigration):
dest = pop_id[migration.source]
source = pop_id[migration.dest]
append_migration(dest, source, start_time, end_time, rate)
else:
assert isinstance(migration, demes.SymmetricMigration)
for x, y in itertools.permutations(migration.demes, 2):
pop_x = pop_id[x]
pop_y = pop_id[y]
append_migration(pop_x, pop_y, start_time, end_time, rate)
append_migration(pop_y, pop_x, start_time, end_time, rate)

# Collapse migration rate events in the same generation.
# This is not strictly needed, but usually results in fewer events.
mig_rate_events.sort(key=lambda de: de.time)
Expand Down Expand Up @@ -212,11 +221,11 @@ def from_msprime(
# up the correct `Epoch`s, `Migration`s, and `Pulse`s outside of the graph.
gtmp: Mapping[str, dict] = {"demes": {}}

# List of demes.Epoch, keyed by deme id.
epochs: Mapping[str, List[demes.Epoch]] = collections.defaultdict(list)
# List of epoch dicts, keyed by deme id.
epochs: Mapping[str, List[dict]] = collections.defaultdict(list)
# List of deme.Migration, keyed by (source, dest) indexes
migrations: Mapping[
Tuple[int, int], List[demes.Migration]
Tuple[int, int], List[demes.AsymmetricMigration]
] = collections.defaultdict(list)
# migration_matrix in the previous ddb epoch
prev_mm = np.zeros((num_pops, num_pops))
Expand All @@ -240,7 +249,7 @@ def from_msprime(
gtmp["demes"][pop_name] = {
"ancestors": [],
"proportions": [],
"epochs": [demes.Epoch(start_size=1)],
"epochs": [dict(start_size=1)],
}
pop_param_changes.add(pop_name)

Expand All @@ -257,7 +266,7 @@ def from_msprime(
gtmp["demes"][parent] = {
"ancestors": [],
"proportions": [],
"epochs": [demes.Epoch(start_size=1)],
"epochs": [dict(start_size=1)],
}

if math.isclose(sum(proportions), 1):
Expand All @@ -266,7 +275,7 @@ def from_msprime(
# Set attributes after deme creation, to avoid internal
# checks about the ancestors' existence time intervals.
gtmp["demes"][child]["epochs"] = [
demes.Epoch(start_time=ddb_epoch.start_time, start_size=1)
dict(start_time=ddb_epoch.start_time, start_size=1)
]
gtmp["demes"][child]["ancestors"] = ancestors
gtmp["demes"][child]["proportions"] = proportions
Expand Down Expand Up @@ -294,17 +303,15 @@ def from_msprime(
if deme_id not in epochs:
epochs[deme_id].append(gtmp["demes"][deme_id]["epochs"][0])
last_epoch = epochs[deme_id][-1]
last_epoch.end_time = ddb_epoch.start_time
if last_epoch.start_time == ddb_epoch.end_time:
last_epoch.start_size = pop.end_size
last_epoch.end_size = pop.start_size
last_epoch["end_time"] = ddb_epoch.start_time
if last_epoch.get("start_time") == ddb_epoch.end_time:
last_epoch["start_size"] = pop.end_size
last_epoch["end_size"] = pop.start_size

if name[j] in pop_param_changes:
# Add new epoch, to be fixed in the next ddb_epoch iteration.
epochs[deme_id].append(
demes.Epoch(
start_time=ddb_epoch.start_time, end_time=0, start_size=1
)
dict(start_time=ddb_epoch.start_time, end_time=0, start_size=1)
)

# Construct per-pair lists of migrations from the migration matrix.
Expand All @@ -315,7 +322,7 @@ def from_msprime(
continue
if prev_mm[j, k] != msp_mm[j, k] and msp_mm[j, k] != 0:
# new Migration
m = demes.Migration(
m = demes.AsymmetricMigration(
source=name[j],
dest=name[k],
start_time=ddb_epoch.end_time,
Expand All @@ -333,39 +340,40 @@ def from_msprime(

for deme_id in epochs.keys():
epoch = epochs[deme_id][0]
if epoch.start_time is None or math.isinf(epoch.start_time):
epochs[deme_id][0] = demes.Epoch(
start_time=epoch.start_time,
end_time=epoch.end_time,
start_size=epoch.end_size,
end_size=epoch.end_size,
start_time = epoch.get("start_time")
if start_time is None or math.isinf(start_time):
epochs[deme_id][0] = dict(
start_time=start_time,
end_time=epoch["end_time"],
start_size=epoch["end_size"],
end_size=epoch["end_size"],
)

# Create a fresh demes graph, now that we have complete epoch information
# for each deme. This also validates consistency between parameters.
g = demes.Graph(
b = demes.Builder(
description="Converted from msprime demography.",
time_units="generations",
)

for deme_id, deme_dict in gtmp["demes"].items():
g.deme(
b.add_deme(
deme_id,
ancestors=deme_dict["ancestors"],
proportions=deme_dict["proportions"],
start_time=epochs[deme_id][0].start_time,
start_time=epochs[deme_id][0]["start_time"],
epochs=[
demes.Epoch(
end_time=epoch.end_time,
start_size=epoch.start_size,
end_size=epoch.end_size,
dict(
end_time=epoch["end_time"],
start_size=epoch["start_size"],
end_size=epoch["end_size"],
)
for epoch in epochs[deme_id]
],
)

for pulse in pulses:
g.pulse(
b.add_pulse(
source=pulse.source,
dest=pulse.dest,
proportion=pulse.proportion,
Expand All @@ -374,15 +382,15 @@ def from_msprime(

for migration_list in migrations.values():
for migration in migration_list:
g.migration(
b.add_migration(
source=migration.source,
dest=migration.dest,
start_time=migration.start_time,
end_time=migration.end_time,
rate=migration.rate,
)

return g
return b.resolve()


if __name__ == "__main__":
Expand Down
Loading

0 comments on commit e2a7e05

Please sign in to comment.