diff --git a/src/dynode/abstract_initializer.py b/src/dynode/abstract_initializer.py index 540d610..dde761e 100644 --- a/src/dynode/abstract_initializer.py +++ b/src/dynode/abstract_initializer.py @@ -8,6 +8,8 @@ from abc import ABC, abstractmethod from typing import Any +from numpy import ndarray + from . import SEIC_Compartments, utils @@ -35,22 +37,23 @@ def get_initial_state( assert self.INITIAL_STATE is not None return self.INITIAL_STATE - def load_initial_population_fractions(self) -> None: + def load_initial_population_fractions(self) -> ndarray: """ - loads age demographics for the US and - sets the inital population fraction by age bin. + Loads age demographics for the specified region and + returns the inital population fraction by age bin. - Updates + Returns ---------- - `self.config.INITIAL_POPULATION_FRACTIONS` : numpy.ndarray - proportion of the total population that falls into each age group, - length of this array is equal the number of age groups sums to 1.0. + numpy.ndarray + Proportion of the total population that falls into each age group, + `len(self.load_initial_population_fractions()) == self.config.NUM_AGE_GROUPS` + `np.sum(self.load_initial_population_fractions()) == 1.0 """ populations_path = ( self.config.DEMOGRAPHIC_DATA_PATH + "population_rescaled_age_distributions/" ) # TODO support getting more regions than just 1 - self.config.INITIAL_POPULATION_FRACTIONS = utils.load_age_demographics( + return utils.load_age_demographics( populations_path, self.config.REGIONS, self.config.AGE_LIMITS )[self.config.REGIONS[0]] diff --git a/src/dynode/abstract_parameters.py b/src/dynode/abstract_parameters.py index 3216240..3a42dd3 100644 --- a/src/dynode/abstract_parameters.py +++ b/src/dynode/abstract_parameters.py @@ -482,14 +482,16 @@ def seasonality( ) ) - def retrieve_population_counts(self) -> None: - """ - A wrapper function which takes calculates the age stratified population counts across all the INITIAL_STATE compartments - (minus the book-keeping C compartment.) and stores it in the self.config.POPULATION parameter. + def retrieve_population_counts(self) -> np.ndarray: + """Calculates the age stratified population counts across all tracked + self.INITIAL_STATE compartments, excluding the book-keeping C compartment. - We do not recieve this data exactly from the initializer, but it is trivial to recalculate. + Returns + ------- + np.ndarray + population counts of each age bin within `self.INITIAL_STATE` """ - self.config.POPULATION = np.sum( # sum together S+E+I compartments + return np.sum( # sum together S+E+I compartments np.array( [ np.sum( @@ -508,7 +510,7 @@ def retrieve_population_counts(self) -> None: axis=(0), # sum across compartments, keep age bins ) - def load_cross_immunity_matrix(self) -> None: + def load_cross_immunity_matrix(self) -> jax.Array: """ Loads the Crossimmunity matrix given the strain interactions matrix. Strain interactions matrix is a matrix of shape @@ -518,21 +520,18 @@ def load_cross_immunity_matrix(self) -> None: previously from a strain in dim 1. Neither the strain interactions matrix nor the crossimmunity matrix take into account waning. - Updates + Returns ---------- - self.config.CROSSIMMUNITY_MATRIX: - updates this matrix to shape - (self.config.NUM_STRAINS, self.config.NUM_PREV_INF_HIST) + jax.Array + matrix of shape (self.config.NUM_STRAINS, self.config.NUM_PREV_INF_HIST) containing the relative immune escape values for each challenging strain compared to each prior immune history in the model. """ - self.config.CROSSIMMUNITY_MATRIX = ( - utils.strain_interaction_to_cross_immunity( - self.config.NUM_STRAINS, self.config.STRAIN_INTERACTIONS - ) + return utils.strain_interaction_to_cross_immunity( + self.config.NUM_STRAINS, self.config.STRAIN_INTERACTIONS ) - def load_vaccination_model(self) -> None: + def load_vaccination_model(self) -> tuple[jax.Array, jax.Array, jax.Array]: """ loads parameters of a polynomial spline vaccination model stratified on age bin and current vaccination status. Reads spline @@ -545,7 +544,7 @@ def load_vaccination_model(self) -> None: Raises `FileNotFoundError` if directory path does not contain region specific file matching expected naming convention. - UPDATES + Returns ----------- the following are 3 parallel lists, each with leading dimensions `(NUM_AGE_GROUPS, MAX_VAX_COUNT+1)` identifying the vaccination spline @@ -562,7 +561,6 @@ def load_vaccination_model(self) -> None: array defining the coefficients (a,b,c,d) of each base equation `(a + b(t) + c(t)^2 + d(t)^3)` for the spline defined by `VACCINATION_MODEL_KNOT_LOCATIONS[i][j]`. - """ # if the user passes a directory instead of a file path # check to see if the state exists in the directory and use that @@ -642,12 +640,10 @@ def load_vaccination_model(self) -> None: vax_knot_locations[age_group_idx, vax_idx, :] = np.array( knot_locations ) - self.config.VACCINATION_MODEL_KNOTS = jnp.array(vax_knots) - self.config.VACCINATION_MODEL_KNOT_LOCATIONS = jnp.array( - vax_knot_locations - ) - self.config.VACCINATION_MODEL_BASE_EQUATIONS = jnp.array( - vax_base_equations + return ( + jnp.array(vax_knots), + jnp.array(vax_knot_locations), + jnp.array(vax_base_equations), ) def seasonal_vaccination_reset(self, t: ArrayLike) -> ArrayLike: @@ -699,7 +695,7 @@ def seasonal_vaccination_reset(self, t: ArrayLike) -> ArrayLike: # if no seasonal vaccination, this function always returns zero return 0 - def load_contact_matrix(self) -> None: + def load_contact_matrix(self) -> np.ndarray: """ loads region specific contact matrix, usually sourced from https://github.com/mobs-lab/mixing-patterns @@ -711,7 +707,7 @@ def load_contact_matrix(self) -> None: where `CONTACT_MATRIX[i][j]` refers to the per capita interaction rate between age bin `i` and `j` """ - self.config.CONTACT_MATRIX = utils.load_demographic_data( + return utils.load_demographic_data( self.config.DEMOGRAPHIC_DATA_PATH, self.config.REGIONS, self.config.NUM_AGE_GROUPS, diff --git a/src/dynode/covid_sero_initializer.py b/src/dynode/covid_sero_initializer.py index e17f194..e707523 100644 --- a/src/dynode/covid_sero_initializer.py +++ b/src/dynode/covid_sero_initializer.py @@ -7,6 +7,7 @@ import jax.numpy as jnp import numpy as np import pandas as pd +from jax import Array from . import SEIC_Compartments, utils from .abstract_initializer import AbstractInitializer @@ -29,7 +30,9 @@ def __init__(self, config_initializer_path, global_variables_path): self.config = Config(global_json).add_file(initializer_json) if not hasattr(self.config, "INITIAL_POPULATION_FRACTIONS"): - self.load_initial_population_fractions() + self.config.INITIAL_POPULATION_FRACTIONS = ( + self.load_initial_population_fractions() + ) self.config.POPULATION = ( self.config.POP_SIZE * self.config.INITIAL_POPULATION_FRACTIONS @@ -37,22 +40,34 @@ def __init__(self, config_initializer_path, global_variables_path): # self.POPULATION.shape = (NUM_AGE_GROUPS,) if not hasattr(self.config, "INIT_IMMUNE_HISTORY"): - self.load_immune_history_via_serological_data() + self.config.INIT_IMMUNE_HISTORY = ( + self.load_immune_history_via_serological_data() + ) # self.INIT_IMMUNE_HISTORY.shape = (age, hist, num_vax, waning) if not hasattr(self.config, "CONTACT_MATRIX"): - self.load_contact_matrix() + self.config.CONTACT_MATRIX = self.load_contact_matrix() if not hasattr(self.config, "CROSSIMMUNITY_MATRIX"): - self.load_cross_immunity_matrix() + self.config.CROSSIMMUNITY_MATRIX = ( + self.load_cross_immunity_matrix() + ) # stratify initial infections appropriately across age, hist, vax counts if not ( - hasattr(self.config, "INIT_INFECTED_DIST") + hasattr(self.config, "INIT_INFECTIOUS_DIST") and hasattr(self.config, "INIT_EXPOSED_DIST") ) and hasattr(self.config, "CONTACT_MATRIX_PATH"): - self.load_init_infection_infected_and_exposed_dist_via_contact_matrix() + self.config.INIT_INFECTION_DIST = ( + self.load_initial_infection_dist_via_contact_matrix() + ) + self.config.INIT_INFECTIOUS_DIST = ( + self.get_initial_infectious_distribution() + ) + self.config.INIT_EXPOSED_DIST = ( + self.get_initial_exposed_distribution() + ) # load initial state using - # INIT_IMMUNE_HISTORY, INIT_INFECTED_DIST, and INIT_EXPOSED_DIST + # INIT_IMMUNE_HISTORY, INIT_INFECTIOUS_DIST, and INIT_EXPOSED_DIST self.INITIAL_STATE = self.load_initial_state( self.config.INITIAL_INFECTIONS ) @@ -63,7 +78,7 @@ def load_initial_state( """ a function which takes a number of initial infections, disperses them across infectious and exposed compartments - according to `self.config.INIT_INFECTED_DIST` + according to `self.config.INIT_INFECTIOUS_DIST` and `self.config.INIT_EXPOSED_DIST` matricies, then subtracts both those populations from the total population and places the remaining individuals in the susceptible compartment, @@ -78,10 +93,10 @@ def load_initial_state( ---------- the following variables be loaded into self: CONTACT_MATRIX: loading in config or via self.load_contact_matrix() - INIT_INFECTED_DIST: loaded in config or via - `load_init_infection_infected_and_exposed_dist_via_contact_matrix()` + INIT_INFECTIOUS_DIST: loaded in config or via + `get_initial_infectious_distribution()` INIT_EXPOSED_DIST: loaded in config or via - `load_init_infection_infected_and_exposed_dist_via_contact_matrix()` + `get_initial_exposed_distribution()` INIT_IMMUNE_HISTORY: loaded in config or via `load_immune_history_via_serological_data()`. @@ -91,9 +106,9 @@ def load_initial_state( a tuple of len 4 representing the S, E, I, and C compartment population counts after model initialization. """ - # create population distribution with INIT_INFECTED_DIST then sum by age + # create population distribution with INIT_INFECTIOUS_DIST then sum by age initial_infectious_count = ( - initial_infections * self.config.INIT_INFECTED_DIST + initial_infections * self.config.INIT_INFECTIOUS_DIST ) initial_infectious_count_ages = jnp.sum( initial_infectious_count, @@ -132,16 +147,15 @@ def load_initial_state( jnp.zeros(initial_exposed_count.shape), # c ) - def load_immune_history_via_serological_data(self) -> None: + def load_immune_history_via_serological_data(self) -> np.ndarray: """ loads the sero init file for `self.config.REGIONS[0]` and - converts it to a numpy matrix representing the initial immune history - of individuals within each age bin in the system. Saving matrix - to self.config.INIT_IMMUNE_HISTORY + returns a numpy matrix representing the initial immune history + of individuals within each age bin in the system. - Updates + Returns --------- - INIT_IMMUNE_HISTORY: np.ndarray + np.ndarray a matrix of shape (self.config.NUM_AGE_GROUPS, 2**self.config.NUM_STRAINS, self.config.MAX_VACCINATION_COUNT + 1, @@ -227,9 +241,9 @@ def load_immune_history_via_serological_data(self) -> None: "each age group does not sum to 1 in the sero initialization file of %s" % str(self.config.REGIONS[0]) ) - self.config.INIT_IMMUNE_HISTORY = sero_matrix + return sero_matrix - def load_init_infection_infected_and_exposed_dist_via_contact_matrix( + def load_initial_infection_dist_via_contact_matrix( self, ) -> None: """ @@ -251,11 +265,12 @@ def load_init_infection_infected_and_exposed_dist_via_contact_matrix( individuals in age bin `i`, who fall under immune history `j`, vaccination count `k`, and waning bin `l` """ - # use relative wait times in each compartment to get distribution of infections across - # infected vs exposed compartments - exposed_to_infectous_ratio = self.config.EXPOSED_TO_INFECTIOUS / ( - self.config.EXPOSED_TO_INFECTIOUS + self.config.INFECTIOUS_PERIOD - ) + if not hasattr(self.config, "CONTACT_MATRIX_PATH"): + raise RuntimeError( + "Attempting to build initial infection distribution " + "without a path to a contact matrix in " + "self.config.CONTACT_MATRIX_PATH" + ) # use contact matrix to get the infection age distributions eig_data = np.linalg.eig(self.config.CONTACT_MATRIX) max_index = np.argmax(eig_data[0]) @@ -319,27 +334,51 @@ def load_init_infection_infected_and_exposed_dist_via_contact_matrix( ), ] = 0 # disperse infections across E and I compartments by exposed_to_infectous_ratio - self.config.INIT_INFECTION_DIST = infection_dist - self.config.INIT_EXPOSED_DIST = ( - exposed_to_infectous_ratio * self.config.INIT_INFECTION_DIST + return infection_dist + + def get_initial_infectious_distribution(self): + if not hasattr(self.config, "INIT_INFECTION_DIST"): + raise RuntimeError( + "this function requires `self.config.INIT_INFECTION_DIST`" + "set if via load_initial_infection_dist_via_contact_matrix()" + "before calling this function" + ) + # use relative wait times in each compartment to get distribution + # of infections across infected vs exposed compartments + exposed_to_infectous_ratio = self.config.EXPOSED_TO_INFECTIOUS / ( + self.config.EXPOSED_TO_INFECTIOUS + self.config.INFECTIOUS_PERIOD ) - self.config.INIT_INFECTED_DIST = ( + return ( 1 - exposed_to_infectous_ratio ) * self.config.INIT_INFECTION_DIST - def load_contact_matrix(self) -> None: + def get_initial_exposed_distribution(self): + if not hasattr(self.config, "INIT_INFECTION_DIST"): + raise RuntimeError( + "this function requires `self.config.INIT_INFECTION_DIST`" + "set if via load_initial_infection_dist_via_contact_matrix()" + "before calling this function" + ) + # use relative wait times in each compartment to get distribution + # of infections across infected vs exposed compartments + exposed_to_infectous_ratio = self.config.EXPOSED_TO_INFECTIOUS / ( + self.config.EXPOSED_TO_INFECTIOUS + self.config.INFECTIOUS_PERIOD + ) + return exposed_to_infectous_ratio * self.config.INIT_INFECTION_DIST + + def load_contact_matrix(self) -> np.ndarray: """ - loads region specific contact matrix, usually sourced from + returns region specific contact matrix, usually sourced from https://github.com/mobs-lab/mixing-patterns - Updates + Returns ---------- - `self.config.CONTACT_MATRIX` : numpy.ndarray + numpy.ndarray a matrix of shape (self.config.NUM_AGE_GROUPS, self.config.NUM_AGE_GROUPS) where `CONTACT_MATRIX[i][j]` refers to the per capita interaction rate between age bin `i` and `j` """ - self.config.CONTACT_MATRIX = utils.load_demographic_data( + return utils.load_demographic_data( self.config.DEMOGRAPHIC_DATA_PATH, self.config.REGIONS, self.config.NUM_AGE_GROUPS, @@ -347,7 +386,7 @@ def load_contact_matrix(self) -> None: self.config.AGE_LIMITS, )[self.config.REGIONS[0]]["avg_CM"] - def load_cross_immunity_matrix(self) -> None: + def load_cross_immunity_matrix(self) -> Array: """ Loads the Crossimmunity matrix given the strain interactions matrix. Strain interactions matrix is a matrix of shape @@ -357,16 +396,13 @@ def load_cross_immunity_matrix(self) -> None: previously from a strain in dim 1. Neither the strain interactions matrix nor the crossimmunity matrix take into account waning. - Updates + Returns ---------- - self.config.CROSSIMMUNITY_MATRIX: - updates this matrix to shape - (self.config.NUM_STRAINS, self.config.NUM_PREV_INF_HIST) + jax.Array + matrix of shape (self.config.NUM_STRAINS, self.config.NUM_PREV_INF_HIST) containing the relative immune escape values for each challenging strain compared to each prior immune history in the model. """ - self.config.CROSSIMMUNITY_MATRIX = ( - utils.strain_interaction_to_cross_immunity( - self.config.NUM_STRAINS, self.config.STRAIN_INTERACTIONS - ) + return utils.strain_interaction_to_cross_immunity( + self.config.NUM_STRAINS, self.config.STRAIN_INTERACTIONS ) diff --git a/src/dynode/mechanistic_inferer.py b/src/dynode/mechanistic_inferer.py index fca87c7..bd906b7 100644 --- a/src/dynode/mechanistic_inferer.py +++ b/src/dynode/mechanistic_inferer.py @@ -46,19 +46,32 @@ def __init__( self.runner = runner self.INITIAL_STATE = initial_state self.infer_complete = False - self.set_infer_algo() - self.retrieve_population_counts() - self.load_vaccination_model() - self.load_contact_matrix() - - def set_infer_algo(self, inferer_type: str = "mcmc") -> None: - """Sets the inferer's inference algorithm and sampler. + # set inference algo to mcmc + self.inference_algo = self.set_infer_algo() + # retrieve population age distribution via passed initial state + self.config.POPULATION = self.retrieve_population_counts() + # load all vaccination splines + ( + self.config.VACCINATION_MODEL_KNOTS, + self.config.VACCINATION_MODEL_KNOT_LOCATIONS, + self.config.VACCINATION_MODEL_BASE_EQUATIONS, + ) = self.load_vaccination_model() + self.config.CONTACT_MATRIX = self.load_contact_matrix() + + def set_infer_algo(self, inferer_type: str = "mcmc") -> MCMC: + """returns inference algorithm with attached sampler. Parameters ---------- inferer_type : str, optional infer algo you wish to use, by default "mcmc" + Returns + ---------- + MCMC + returns MCMC inference algorithm as it is the only supported + algorithm currently + Raises ------ NotImplementedError @@ -75,7 +88,7 @@ def set_infer_algo(self, inferer_type: str = "mcmc") -> None: if inferer_type == "mcmc": # default to max tree depth of 5 if not specified tree_depth = getattr(self.config, "MAX_TREE_DEPTH", 5) - self.inference_algo = MCMC( + return MCMC( NUTS( self.likelihood, dense_mass=True, diff --git a/src/dynode/static_value_parameters.py b/src/dynode/static_value_parameters.py index 6a31601..d09d6f0 100644 --- a/src/dynode/static_value_parameters.py +++ b/src/dynode/static_value_parameters.py @@ -23,10 +23,13 @@ def __init__( self.config = Config(global_json).add_file(runner_json) self.INITIAL_STATE = INITIAL_STATE # load self.config.POPULATION - self.retrieve_population_counts() + self.config.POPULATION = self.retrieve_population_counts() # load self.config.VACCINATION_MODEL_KNOTS/ # VACCINATION_MODEL_KNOT_LOCATIONS/VACCINATION_MODEL_BASE_EQUATIONS - self.load_vaccination_model() - # load self.config.CONTACT_MATRIX - self.load_contact_matrix() - # rest of the work is handled by the AbstractParameters + # load all vaccination splines + ( + self.config.VACCINATION_MODEL_KNOTS, + self.config.VACCINATION_MODEL_KNOT_LOCATIONS, + self.config.VACCINATION_MODEL_BASE_EQUATIONS, + ) = self.load_vaccination_model() + self.config.CONTACT_MATRIX = self.load_contact_matrix()