diff --git a/demes/demes.py b/demes/demes.py index f27f0932..9de4ce56 100644 --- a/demes/demes.py +++ b/demes/demes.py @@ -1,4 +1,4 @@ -from typing import List, Union, Optional, Dict, MutableMapping, Any, Set +from typing import List, Union, Optional, Dict, MutableMapping, Any, Set, Tuple import itertools import math import numbers @@ -1631,10 +1631,52 @@ def _add_pulse(self, *, source, dest, proportion, time) -> Pulse: self.pulses.append(new_pulse) return new_pulse - def _migration_matrices(self): + def migration_matrices(self) -> Tuple[List[List[List[float]]], List[Number]]: """ - Return a list of migration matrices, and a list of end times that - partition them. The start time for the first matrix is inf. + Get the migration matrices and the end times that partition them. + + Returns a list of matrices, one for each time interval + over which migration rates do not change, in time-descending + order (from most ancient to most recent). For a migration matrix list + :math:`M`, the migration rate is :math:`M[i][j][k]` from deme + :math:`k` into deme :math:`j` during the :math:`i` 'th time interval. + The order of the demes' indices in each matrix matches the + order of demes in the graph's deme list (I.e. deme :math:`j` + corresponds to ``Graph.demes[j]``). + + There is always at least one migration matrix in the list, even when + the graph defines no migrations. + + A list of end times to which the matrices apply is also + returned. The time intervals to which the migration rates apply are an + open-closed interval ``(start_time, end_time]``, where the start time + of the first matrix is ``inf`` and the start time of subsequent + matrices match the end time of the previous matrix in the list. + + .. note:: + The last entry of the list of end times is always ``0``, + even when all demes in the graph go extinct before time ``0``. + + + .. code:: + + graph = demes.load("gutenkunst_ooa.yml") + mm_list, end_times = graph.migration_matrices() + start_times = [math.inf] + end_times[:-1] + assert len(mm_list) == len(end_times) == len(start_times) + deme_ids = {deme.name: j for j, deme in enumerate(graph.demes)} + j = deme_ids["YRI"] + k = deme_ids["CEU"] + for mm, start_time, end_time in zip(mm_list, start_times, end_times): + print( + f"CEU -> YRI migration rate is {mm[j][k]} during the " + f"time interval ({start_time}, {end_time}]" + ) + + :return: A 2-tuple of ``(mm_list, end_times)``, + where ``mm_list`` is a list of migration matrices, + and ``end_times`` are a list of end times for each matrix. + :rtype: tuple[list[list[list[float]]], list[float]] """ uniq_times = set(migration.start_time for migration in self.migrations) uniq_times.update(migration.end_time for migration in self.migrations) @@ -1644,7 +1686,7 @@ def _migration_matrices(self): # Extend to t=0 even when there are no migrations. end_times.append(0) n = len(self.demes) - mm_list = [[[0] * n for _ in range(n)] for _ in range(len(end_times))] + mm_list = [[[0.0] * n for _ in range(n)] for _ in range(len(end_times))] deme_id = {deme.name: j for j, deme in enumerate(self.demes)} for migration in self.migrations: start_time = math.inf @@ -1670,7 +1712,7 @@ def _check_migration_rates(self): deme in any interval of time. """ start_time = math.inf - mm_list, end_times = self._migration_matrices() + mm_list, end_times = self.migration_matrices() for migration_matrix, end_time in zip(mm_list, end_times): for j, row in enumerate(migration_matrix): row_sum = sum(row)