Skip to content

Commit

Permalink
fixing return types of some methods and updating their docstrings (#325)
Browse files Browse the repository at this point in the history
* all functions now return instead of updating self

* small bugfix

* fixing mypy
  • Loading branch information
arik-shurygin authored Jan 13, 2025
1 parent 57f52cb commit b2a571f
Show file tree
Hide file tree
Showing 5 changed files with 142 additions and 91 deletions.
19 changes: 11 additions & 8 deletions src/dynode/abstract_initializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from abc import ABC, abstractmethod
from typing import Any

from numpy import ndarray

from . import SEIC_Compartments, utils


Expand Down Expand Up @@ -35,22 +37,23 @@ def get_initial_state(
assert self.INITIAL_STATE is not None
return self.INITIAL_STATE

def load_initial_population_fractions(self) -> None:
def load_initial_population_fractions(self) -> ndarray:
"""
loads age demographics for the US and
sets the inital population fraction by age bin.
Loads age demographics for the specified region and
returns the inital population fraction by age bin.
Updates
Returns
----------
`self.config.INITIAL_POPULATION_FRACTIONS` : numpy.ndarray
proportion of the total population that falls into each age group,
length of this array is equal the number of age groups sums to 1.0.
numpy.ndarray
Proportion of the total population that falls into each age group,
`len(self.load_initial_population_fractions()) == self.config.NUM_AGE_GROUPS`
`np.sum(self.load_initial_population_fractions()) == 1.0
"""
populations_path = (
self.config.DEMOGRAPHIC_DATA_PATH
+ "population_rescaled_age_distributions/"
)
# TODO support getting more regions than just 1
self.config.INITIAL_POPULATION_FRACTIONS = utils.load_age_demographics(
return utils.load_age_demographics(
populations_path, self.config.REGIONS, self.config.AGE_LIMITS
)[self.config.REGIONS[0]]
48 changes: 22 additions & 26 deletions src/dynode/abstract_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,14 +482,16 @@ def seasonality(
)
)

def retrieve_population_counts(self) -> None:
"""
A wrapper function which takes calculates the age stratified population counts across all the INITIAL_STATE compartments
(minus the book-keeping C compartment.) and stores it in the self.config.POPULATION parameter.
def retrieve_population_counts(self) -> np.ndarray:
"""Calculates the age stratified population counts across all tracked
self.INITIAL_STATE compartments, excluding the book-keeping C compartment.
We do not recieve this data exactly from the initializer, but it is trivial to recalculate.
Returns
-------
np.ndarray
population counts of each age bin within `self.INITIAL_STATE`
"""
self.config.POPULATION = np.sum( # sum together S+E+I compartments
return np.sum( # sum together S+E+I compartments
np.array(
[
np.sum(
Expand All @@ -508,7 +510,7 @@ def retrieve_population_counts(self) -> None:
axis=(0), # sum across compartments, keep age bins
)

def load_cross_immunity_matrix(self) -> None:
def load_cross_immunity_matrix(self) -> jax.Array:
"""
Loads the Crossimmunity matrix given the strain interactions matrix.
Strain interactions matrix is a matrix of shape
Expand All @@ -518,21 +520,18 @@ def load_cross_immunity_matrix(self) -> None:
previously from a strain in dim 1. Neither the strain interactions matrix
nor the crossimmunity matrix take into account waning.
Updates
Returns
----------
self.config.CROSSIMMUNITY_MATRIX:
updates this matrix to shape
(self.config.NUM_STRAINS, self.config.NUM_PREV_INF_HIST)
jax.Array
matrix of shape (self.config.NUM_STRAINS, self.config.NUM_PREV_INF_HIST)
containing the relative immune escape values for each challenging
strain compared to each prior immune history in the model.
"""
self.config.CROSSIMMUNITY_MATRIX = (
utils.strain_interaction_to_cross_immunity(
self.config.NUM_STRAINS, self.config.STRAIN_INTERACTIONS
)
return utils.strain_interaction_to_cross_immunity(
self.config.NUM_STRAINS, self.config.STRAIN_INTERACTIONS
)

def load_vaccination_model(self) -> None:
def load_vaccination_model(self) -> tuple[jax.Array, jax.Array, jax.Array]:
"""
loads parameters of a polynomial spline vaccination model
stratified on age bin and current vaccination status. Reads spline
Expand All @@ -545,7 +544,7 @@ def load_vaccination_model(self) -> None:
Raises `FileNotFoundError` if directory path does not contain region
specific file matching expected naming convention.
UPDATES
Returns
-----------
the following are 3 parallel lists, each with leading dimensions
`(NUM_AGE_GROUPS, MAX_VAX_COUNT+1)` identifying the vaccination spline
Expand All @@ -562,7 +561,6 @@ def load_vaccination_model(self) -> None:
array defining the coefficients (a,b,c,d) of each
base equation `(a + b(t) + c(t)^2 + d(t)^3)` for the spline defined
by `VACCINATION_MODEL_KNOT_LOCATIONS[i][j]`.
"""
# if the user passes a directory instead of a file path
# check to see if the state exists in the directory and use that
Expand Down Expand Up @@ -642,12 +640,10 @@ def load_vaccination_model(self) -> None:
vax_knot_locations[age_group_idx, vax_idx, :] = np.array(
knot_locations
)
self.config.VACCINATION_MODEL_KNOTS = jnp.array(vax_knots)
self.config.VACCINATION_MODEL_KNOT_LOCATIONS = jnp.array(
vax_knot_locations
)
self.config.VACCINATION_MODEL_BASE_EQUATIONS = jnp.array(
vax_base_equations
return (
jnp.array(vax_knots),
jnp.array(vax_knot_locations),
jnp.array(vax_base_equations),
)

def seasonal_vaccination_reset(self, t: ArrayLike) -> ArrayLike:
Expand Down Expand Up @@ -699,7 +695,7 @@ def seasonal_vaccination_reset(self, t: ArrayLike) -> ArrayLike:
# if no seasonal vaccination, this function always returns zero
return 0

def load_contact_matrix(self) -> None:
def load_contact_matrix(self) -> np.ndarray:
"""
loads region specific contact matrix, usually sourced from
https://github.com/mobs-lab/mixing-patterns
Expand All @@ -711,7 +707,7 @@ def load_contact_matrix(self) -> None:
where `CONTACT_MATRIX[i][j]` refers to the per capita
interaction rate between age bin `i` and `j`
"""
self.config.CONTACT_MATRIX = utils.load_demographic_data(
return utils.load_demographic_data(
self.config.DEMOGRAPHIC_DATA_PATH,
self.config.REGIONS,
self.config.NUM_AGE_GROUPS,
Expand Down
124 changes: 80 additions & 44 deletions src/dynode/covid_sero_initializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import jax.numpy as jnp
import numpy as np
import pandas as pd
from jax import Array

from . import SEIC_Compartments, utils
from .abstract_initializer import AbstractInitializer
Expand All @@ -29,30 +30,44 @@ def __init__(self, config_initializer_path, global_variables_path):
self.config = Config(global_json).add_file(initializer_json)

if not hasattr(self.config, "INITIAL_POPULATION_FRACTIONS"):
self.load_initial_population_fractions()
self.config.INITIAL_POPULATION_FRACTIONS = (
self.load_initial_population_fractions()
)

self.config.POPULATION = (
self.config.POP_SIZE * self.config.INITIAL_POPULATION_FRACTIONS
)
# self.POPULATION.shape = (NUM_AGE_GROUPS,)

if not hasattr(self.config, "INIT_IMMUNE_HISTORY"):
self.load_immune_history_via_serological_data()
self.config.INIT_IMMUNE_HISTORY = (
self.load_immune_history_via_serological_data()
)
# self.INIT_IMMUNE_HISTORY.shape = (age, hist, num_vax, waning)
if not hasattr(self.config, "CONTACT_MATRIX"):
self.load_contact_matrix()
self.config.CONTACT_MATRIX = self.load_contact_matrix()

if not hasattr(self.config, "CROSSIMMUNITY_MATRIX"):
self.load_cross_immunity_matrix()
self.config.CROSSIMMUNITY_MATRIX = (
self.load_cross_immunity_matrix()
)
# stratify initial infections appropriately across age, hist, vax counts
if not (
hasattr(self.config, "INIT_INFECTED_DIST")
hasattr(self.config, "INIT_INFECTIOUS_DIST")
and hasattr(self.config, "INIT_EXPOSED_DIST")
) and hasattr(self.config, "CONTACT_MATRIX_PATH"):
self.load_init_infection_infected_and_exposed_dist_via_contact_matrix()
self.config.INIT_INFECTION_DIST = (
self.load_initial_infection_dist_via_contact_matrix()
)
self.config.INIT_INFECTIOUS_DIST = (
self.get_initial_infectious_distribution()
)
self.config.INIT_EXPOSED_DIST = (
self.get_initial_exposed_distribution()
)

# load initial state using
# INIT_IMMUNE_HISTORY, INIT_INFECTED_DIST, and INIT_EXPOSED_DIST
# INIT_IMMUNE_HISTORY, INIT_INFECTIOUS_DIST, and INIT_EXPOSED_DIST
self.INITIAL_STATE = self.load_initial_state(
self.config.INITIAL_INFECTIONS
)
Expand All @@ -63,7 +78,7 @@ def load_initial_state(
"""
a function which takes a number of initial infections,
disperses them across infectious and exposed compartments
according to `self.config.INIT_INFECTED_DIST`
according to `self.config.INIT_INFECTIOUS_DIST`
and `self.config.INIT_EXPOSED_DIST` matricies,
then subtracts both those populations from the total population and
places the remaining individuals in the susceptible compartment,
Expand All @@ -78,10 +93,10 @@ def load_initial_state(
----------
the following variables be loaded into self:
CONTACT_MATRIX: loading in config or via self.load_contact_matrix()
INIT_INFECTED_DIST: loaded in config or via
`load_init_infection_infected_and_exposed_dist_via_contact_matrix()`
INIT_INFECTIOUS_DIST: loaded in config or via
`get_initial_infectious_distribution()`
INIT_EXPOSED_DIST: loaded in config or via
`load_init_infection_infected_and_exposed_dist_via_contact_matrix()`
`get_initial_exposed_distribution()`
INIT_IMMUNE_HISTORY: loaded in config or via
`load_immune_history_via_serological_data()`.
Expand All @@ -91,9 +106,9 @@ def load_initial_state(
a tuple of len 4 representing the S, E, I, and C compartment
population counts after model initialization.
"""
# create population distribution with INIT_INFECTED_DIST then sum by age
# create population distribution with INIT_INFECTIOUS_DIST then sum by age
initial_infectious_count = (
initial_infections * self.config.INIT_INFECTED_DIST
initial_infections * self.config.INIT_INFECTIOUS_DIST
)
initial_infectious_count_ages = jnp.sum(
initial_infectious_count,
Expand Down Expand Up @@ -132,16 +147,15 @@ def load_initial_state(
jnp.zeros(initial_exposed_count.shape), # c
)

def load_immune_history_via_serological_data(self) -> None:
def load_immune_history_via_serological_data(self) -> np.ndarray:
"""
loads the sero init file for `self.config.REGIONS[0]` and
converts it to a numpy matrix representing the initial immune history
of individuals within each age bin in the system. Saving matrix
to self.config.INIT_IMMUNE_HISTORY
returns a numpy matrix representing the initial immune history
of individuals within each age bin in the system.
Updates
Returns
---------
INIT_IMMUNE_HISTORY: np.ndarray
np.ndarray
a matrix of shape (self.config.NUM_AGE_GROUPS,
2**self.config.NUM_STRAINS,
self.config.MAX_VACCINATION_COUNT + 1,
Expand Down Expand Up @@ -227,9 +241,9 @@ def load_immune_history_via_serological_data(self) -> None:
"each age group does not sum to 1 in the sero initialization file of %s"
% str(self.config.REGIONS[0])
)
self.config.INIT_IMMUNE_HISTORY = sero_matrix
return sero_matrix

def load_init_infection_infected_and_exposed_dist_via_contact_matrix(
def load_initial_infection_dist_via_contact_matrix(
self,
) -> None:
"""
Expand All @@ -251,11 +265,12 @@ def load_init_infection_infected_and_exposed_dist_via_contact_matrix(
individuals in age bin `i`, who fall under
immune history `j`, vaccination count `k`, and waning bin `l`
"""
# use relative wait times in each compartment to get distribution of infections across
# infected vs exposed compartments
exposed_to_infectous_ratio = self.config.EXPOSED_TO_INFECTIOUS / (
self.config.EXPOSED_TO_INFECTIOUS + self.config.INFECTIOUS_PERIOD
)
if not hasattr(self.config, "CONTACT_MATRIX_PATH"):
raise RuntimeError(
"Attempting to build initial infection distribution "
"without a path to a contact matrix in "
"self.config.CONTACT_MATRIX_PATH"
)
# use contact matrix to get the infection age distributions
eig_data = np.linalg.eig(self.config.CONTACT_MATRIX)
max_index = np.argmax(eig_data[0])
Expand Down Expand Up @@ -319,35 +334,59 @@ def load_init_infection_infected_and_exposed_dist_via_contact_matrix(
),
] = 0
# disperse infections across E and I compartments by exposed_to_infectous_ratio
self.config.INIT_INFECTION_DIST = infection_dist
self.config.INIT_EXPOSED_DIST = (
exposed_to_infectous_ratio * self.config.INIT_INFECTION_DIST
return infection_dist

def get_initial_infectious_distribution(self):
if not hasattr(self.config, "INIT_INFECTION_DIST"):
raise RuntimeError(
"this function requires `self.config.INIT_INFECTION_DIST`"
"set if via load_initial_infection_dist_via_contact_matrix()"
"before calling this function"
)
# use relative wait times in each compartment to get distribution
# of infections across infected vs exposed compartments
exposed_to_infectous_ratio = self.config.EXPOSED_TO_INFECTIOUS / (
self.config.EXPOSED_TO_INFECTIOUS + self.config.INFECTIOUS_PERIOD
)
self.config.INIT_INFECTED_DIST = (
return (
1 - exposed_to_infectous_ratio
) * self.config.INIT_INFECTION_DIST

def load_contact_matrix(self) -> None:
def get_initial_exposed_distribution(self):
if not hasattr(self.config, "INIT_INFECTION_DIST"):
raise RuntimeError(
"this function requires `self.config.INIT_INFECTION_DIST`"
"set if via load_initial_infection_dist_via_contact_matrix()"
"before calling this function"
)
# use relative wait times in each compartment to get distribution
# of infections across infected vs exposed compartments
exposed_to_infectous_ratio = self.config.EXPOSED_TO_INFECTIOUS / (
self.config.EXPOSED_TO_INFECTIOUS + self.config.INFECTIOUS_PERIOD
)
return exposed_to_infectous_ratio * self.config.INIT_INFECTION_DIST

def load_contact_matrix(self) -> np.ndarray:
"""
loads region specific contact matrix, usually sourced from
returns region specific contact matrix, usually sourced from
https://github.com/mobs-lab/mixing-patterns
Updates
Returns
----------
`self.config.CONTACT_MATRIX` : numpy.ndarray
numpy.ndarray
a matrix of shape (self.config.NUM_AGE_GROUPS, self.config.NUM_AGE_GROUPS)
where `CONTACT_MATRIX[i][j]` refers to the per capita
interaction rate between age bin `i` and `j`
"""
self.config.CONTACT_MATRIX = utils.load_demographic_data(
return utils.load_demographic_data(
self.config.DEMOGRAPHIC_DATA_PATH,
self.config.REGIONS,
self.config.NUM_AGE_GROUPS,
self.config.AGE_LIMITS[0],
self.config.AGE_LIMITS,
)[self.config.REGIONS[0]]["avg_CM"]

def load_cross_immunity_matrix(self) -> None:
def load_cross_immunity_matrix(self) -> Array:
"""
Loads the Crossimmunity matrix given the strain interactions matrix.
Strain interactions matrix is a matrix of shape
Expand All @@ -357,16 +396,13 @@ def load_cross_immunity_matrix(self) -> None:
previously from a strain in dim 1. Neither the strain interactions matrix
nor the crossimmunity matrix take into account waning.
Updates
Returns
----------
self.config.CROSSIMMUNITY_MATRIX:
updates this matrix to shape
(self.config.NUM_STRAINS, self.config.NUM_PREV_INF_HIST)
jax.Array
matrix of shape (self.config.NUM_STRAINS, self.config.NUM_PREV_INF_HIST)
containing the relative immune escape values for each challenging
strain compared to each prior immune history in the model.
"""
self.config.CROSSIMMUNITY_MATRIX = (
utils.strain_interaction_to_cross_immunity(
self.config.NUM_STRAINS, self.config.STRAIN_INTERACTIONS
)
return utils.strain_interaction_to_cross_immunity(
self.config.NUM_STRAINS, self.config.STRAIN_INTERACTIONS
)
Loading

0 comments on commit b2a571f

Please sign in to comment.