Skip to content

Commit

Permalink
Export (and document) Graph.migration_matrices().
Browse files Browse the repository at this point in the history
  • Loading branch information
grahamgower committed Jun 4, 2021
1 parent ad616b3 commit 1c460f1
Showing 1 changed file with 48 additions and 6 deletions.
54 changes: 48 additions & 6 deletions demes/demes.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Union, Optional, Dict, MutableMapping, Any, Set
from typing import List, Union, Optional, Dict, MutableMapping, Any, Set, Tuple
import itertools
import math
import numbers
Expand Down Expand Up @@ -1631,10 +1631,52 @@ def _add_pulse(self, *, source, dest, proportion, time) -> Pulse:
self.pulses.append(new_pulse)
return new_pulse

def _migration_matrices(self):
def migration_matrices(self) -> Tuple[List[List[List[float]]], List[Number]]:
"""
Return a list of migration matrices, and a list of end times that
partition them. The start time for the first matrix is inf.
Get the migration matrices and the end times that partition them.
Returns a list of matrices, one for each time interval
over which migration rates do not change, in time-descending
order (from most ancient to most recent). For a migration matrix list
:math:`M`, the migration rate is :math:`M[i][j][k]` from deme
:math:`k` into deme :math:`j` during the :math:`i` 'th time interval.
The order of the demes' indices in each matrix matches the
order of demes in the graph's deme list (I.e. deme :math:`j`
corresponds to ``Graph.demes[j]``).
There is always at least one migration matrix in the list, even when
the graph defines no migrations.
A list of end times to which the matrices apply is also
returned. The time intervals to which the migration rates apply are an
open-closed interval ``(start_time, end_time]``, where the start time
of the first matrix is ``inf`` and the start time of subsequent
matrices match the end time of the previous matrix in the list.
.. note::
The last entry of the list of end times is always ``0``,
even when all demes in the graph go extinct before time ``0``.
.. code::
graph = demes.load("gutenkunst_ooa.yml")
mm_list, end_times = graph.migration_matrices()
start_times = [math.inf] + end_times[:-1]
assert len(mm_list) == len(end_times) == len(start_times)
deme_ids = {deme.name: j for j, deme in enumerate(graph.demes)}
j = deme_ids["YRI"]
k = deme_ids["CEU"]
for mm, start_time, end_time in zip(mm_list, start_times, end_times):
print(
f"CEU -> YRI migration rate is {mm[j][k]} during the "
f"time interval ({start_time}, {end_time}]"
)
:return: A 2-tuple of ``(mm_list, end_times)``,
where ``mm_list`` is a list of migration matrices,
and ``end_times`` are a list of end times for each matrix.
:rtype: tuple[list[list[list[float]]], list[float]]
"""
uniq_times = set(migration.start_time for migration in self.migrations)
uniq_times.update(migration.end_time for migration in self.migrations)
Expand All @@ -1644,7 +1686,7 @@ def _migration_matrices(self):
# Extend to t=0 even when there are no migrations.
end_times.append(0)
n = len(self.demes)
mm_list = [[[0] * n for _ in range(n)] for _ in range(len(end_times))]
mm_list = [[[0.0] * n for _ in range(n)] for _ in range(len(end_times))]
deme_id = {deme.name: j for j, deme in enumerate(self.demes)}
for migration in self.migrations:
start_time = math.inf
Expand All @@ -1670,7 +1712,7 @@ def _check_migration_rates(self):
deme in any interval of time.
"""
start_time = math.inf
mm_list, end_times = self._migration_matrices()
mm_list, end_times = self.migration_matrices()
for migration_matrix, end_time in zip(mm_list, end_times):
for j, row in enumerate(migration_matrix):
row_sum = sum(row)
Expand Down

0 comments on commit 1c460f1

Please sign in to comment.