diff --git a/poetry.lock b/poetry.lock
index 3360078b..6e4f1c45 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.
 
 [[package]]
 name = "absl-py"
@@ -2058,6 +2058,23 @@ numpy = ">=1.16.6"
 [package.extras]
 test = ["cffi", "hypothesis", "pandas", "pytest", "pytz"]
 
+[[package]]
+name = "pydocstyle"
+version = "6.3.0"
+description = "Python docstring style checker"
+optional = false
+python-versions = ">=3.6"
+files = [
+    {file = "pydocstyle-6.3.0-py3-none-any.whl", hash = "sha256:118762d452a49d6b05e194ef344a55822987a462831ade91ec5c06fd2169d019"},
+    {file = "pydocstyle-6.3.0.tar.gz", hash = "sha256:7ce43f0c0ac87b07494eb9c0b462c0b73e6ff276807f204d6b53edc72b7e44e1"},
+]
+
+[package.dependencies]
+snowballstemmer = ">=2.2.0"
+
+[package.extras]
+toml = ["tomli (>=1.2.3)"]
+
 [[package]]
 name = "pygments"
 version = "2.18.0"
@@ -2422,6 +2439,17 @@ files = [
     {file = "six-1.16.0.tar.gz", hash = "sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926"},
 ]
 
+[[package]]
+name = "snowballstemmer"
+version = "2.2.0"
+description = "This package provides 29 stemmers for 28 languages generated from Snowball algorithms."
+optional = false
+python-versions = "*"
+files = [
+    {file = "snowballstemmer-2.2.0-py2.py3-none-any.whl", hash = "sha256:c8e1716e83cc398ae16824e5572ae04e0d9fc2c6b985fb0f900f5f0c96ecba1a"},
+    {file = "snowballstemmer-2.2.0.tar.gz", hash = "sha256:09b16deb8547d3412ad7b590689584cd0fe25ec8db3be37788be3810cbf19cb1"},
+]
+
 [[package]]
 name = "tensorflow-probability"
 version = "0.24.0"
@@ -2677,4 +2705,4 @@ type = ["pytest-mypy"]
 [metadata]
 lock-version = "2.0"
 python-versions = "^3.10"
-content-hash = "e55445f43a1bcbc76c57ace67cb9d3281e15cfb2d393ec546f3182f2bbe12d61"
+content-hash = "78c2ddfc29549185e42c48b9d0ff9a68bb5124b04fb16c32d7fb6512c84d75a3"
diff --git a/pyproject.toml b/pyproject.toml
index 761dfeb7..f409ed54 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -25,6 +25,7 @@ mypy = "^1.10.0"
 requests = "^2.32.3"
 docker = "^7.1.0"
 bayeux-ml = "^0.1.14"
+pydocstyle = "^6.3.0"
 
 
 [tool.poetry.group.dev.dependencies]
diff --git a/src/dynode/__init__.py b/src/dynode/__init__.py
index 90666ea9..8c34632f 100644
--- a/src/dynode/__init__.py
+++ b/src/dynode/__init__.py
@@ -1,4 +1,13 @@
-# needs to exist to define a module
+"""DynODE, a dynamic ordinary differential equation (ODE) model framework.
+
+DynODE is a compartmental mechanistic ODE model that accounts for
+age structure, immunity history, vaccination, immunity waning, and
+multiple variants.
+
+DynODE is currently under active development and will be substantially
+refactored in the near future!
+"""
+
 # ruff: noqa: E402
 
 import jax
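Since this change wires pydocstyle into the project, here is a minimal sketch of checking the newly documented modules programmatically; `pydocstyle.check` is the library's public entry point, but the file paths below are illustrative assumptions, not part of this diff.

```python
# Hedged usage sketch (not part of this diff): run pydocstyle over the
# modules touched by this PR; paths are illustrative.
from pydocstyle import check

for error in check(["src/dynode/__init__.py", "src/dynode/config.py"]):
    print(error)  # each error is a D1xx/D2xx/D4xx docstring violation
```
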
diff --git a/src/dynode/abstract_initializer.py b/src/dynode/abstract_initializer.py
index 47143e46..5f143881 100644
--- a/src/dynode/abstract_initializer.py
+++ b/src/dynode/abstract_initializer.py
@@ -1,5 +1,5 @@
-"""
-A module that creates an abstract class for an initializer object.
+"""A module that creates an abstract class for an initializer object.
+
 An initializer objects primary purpose is initialize the state on which ODEs will be run.
 AbstractInitializers will often be tasked with reading, parsing, and combining data sources
 to produce an initial state representing some analyzed population
@@ -8,18 +8,27 @@
 from abc import ABC, abstractmethod
 from typing import Any
 
+from numpy import ndarray
+
 from . import SEIC_Compartments, utils
 
 
 class AbstractInitializer(ABC):
-    """
-    An Abstract class meant for use by disease-specific initializers.
-    an initializers sole responsibility is to return an INITIAL_STATE
+    """An abstract class meant for use by disease-specific initializers.
+
+    An initializer's sole responsibility is to return an INITIAL_STATE
     parameter via self.get_initial_state().
     """
 
     @abstractmethod
     def __init__(self, initializer_config) -> None:
+        """Load parameters from `initializer_config` and generate self.INITIAL_STATE.
+
+        Parameters
+        ----------
+        initializer_config : str
+            str path to config json holding necessary initializer parameters.
+        """
        # add these for mypy
        self.INITIAL_STATE: SEIC_Compartments | None = None
        self.config: Any = {}
@@ -28,27 +37,32 @@ def __init__(self, initializer_config) -> None:
     def get_initial_state(
         self,
     ) -> SEIC_Compartments:
-        """
-        Returns the initial state of the model as defined by the child class in __init__
+        """Get the initial state of the model as defined by the child class in __init__.
+
+        Returns
+        -------
+        SEIC_Compartments
+            tuple of matrices representing the initial state of each compartment
+            in the model.
         """
         assert self.INITIAL_STATE is not None
         return self.INITIAL_STATE
 
-    def load_initial_population_fractions(self) -> None:
-        """
-        a wrapper function which loads age demographics for the US and sets the inital population fraction by age bin.
+    def load_initial_population_fractions(self) -> ndarray:
+        """Load age demographics for the specified region.
 
-        Updates
-        ----------
-        `self.config.INITIAL_POPULATION_FRACTIONS` : numpy.ndarray
-            proportion of the total population that falls into each age group,
-            length of this array is equal the number of age groups and will sum to 1.0.
+        Returns
+        -------
+        numpy.ndarray
+            Proportion of the total population that falls into each age group.
+            `len(self.load_initial_population_fractions()) == self.config.NUM_AGE_GROUPS`
+            `np.sum(self.load_initial_population_fractions()) == 1.0`
         """
         populations_path = (
             self.config.DEMOGRAPHIC_DATA_PATH
             + "population_rescaled_age_distributions/"
         )
         # TODO support getting more regions than just 1
-        self.config.INITIAL_POPULATION_FRACTIONS = utils.load_age_demographics(
+        return utils.load_age_demographics(
             populations_path, self.config.REGIONS, self.config.AGE_LIMITS
         )[self.config.REGIONS[0]]
diff --git a/src/dynode/abstract_parameters.py b/src/dynode/abstract_parameters.py
index d89b5d52..2badf434 100644
--- a/src/dynode/abstract_parameters.py
+++ b/src/dynode/abstract_parameters.py
@@ -1,9 +1,8 @@
-"""
-A module containing an abstract class used to set up parameters for
-running in Ordinary Differential Equations (ODEs).
+"""A module to set up parameters for use in Ordinary Differential Equations (ODEs).
 
-Responsible for loading and assembling functions to describe vaccination uptake, seasonality,
-external transmission of new or existing viruses and other generic respiratory virus aspects.
+Responsible for loading and assembling functions to describe vaccination uptake,
+seasonality, external transmission of new or existing viruses, and other generic
+respiratory virus aspects.
 """
 
 import copy
@@ -26,8 +25,9 @@
 class AbstractParameters:
     """A class to define a disease-agnostic parameters object for running disease models.
- Manages parameter passing and creation, as well as definition of - seasonality, vaccination, external introductions, and external beta shifting functions + + Manages parameter passing, sampling and creation, as well as definition of + seasonality, vaccination, external introductions, and external beta shifting functions. """ UPSTREAM_PARAMETERS = [ @@ -62,6 +62,7 @@ class AbstractParameters: @abstractmethod def __init__(self) -> None: + """Initialize a parameters object for passing data to ODEs.""" # add these for mypy type checker self.config = Config("{}") initial_state = tuple( @@ -73,22 +74,22 @@ def __init__(self) -> None: def _solve_runner( self, parameters: dict, tf: int, runner: MechanisticRunner ) -> Solution: - """runs the runner for `tf` days using parameters defined in `parameters` - returning a Diffrax Solution object + """ + Run the ODE solver for a specified number of days using given parameters. Parameters ---------- parameters : dict - parameters object containing parameters required by the runner ODEs + Dictionary containing parameters required by the runner's ODEs. tf : int - number of days to run the runner for + Number of days to run the ODE solver. runner : MechanisticRunner - runner class designated with solving ODEs + Instance of MechanisticRunner designated for solving ODEs. Returns ------- Solution - diffrax solution object returned from runner.run() + Diffrax solution object returned from `runner.run()`. """ if "INITIAL_INFECTIONS_SCALE" in parameters.keys(): initial_state = self.scale_initial_infections( @@ -105,18 +106,19 @@ def _solve_runner( def _get_upstream_parameters(self) -> dict: """ - returns a dictionary containing self.UPSTREAM_PARAMETERS, erroring if any of the parameters - within are not found within self.config. - - Samples any parameters which are of type(numpyro.distribution). + Retrieve upstream parameters from the configuration, sampling any distributions. Returns - ------------ - dict[str: Any] - - returns a dictionary where keys map to parameters within self.UPSTREAM_PARAMETERS and the values - are the value of that parameter within self.config, distributions are sampled and replaced - with a jax ArrayLike representing that value in the JIT compilation scheme used by jax/numpyro. + ------- + dict[str, Any] + Dictionary mapping keys to parameters within `self.UPSTREAM_PARAMETERS`. + Values are taken from `self.config`, with numpyro.distribution objects + sampled and replaced with JAX ArrayLike samples. + + Raises + ------ + RuntimeError + If any parameter in `self.UPSTREAM_PARAMETERS` is not found in `self.config`. """ # multiple chains of MCMC calling get_parameters() # should not share references, deep copy, GH issue for this created @@ -137,27 +139,24 @@ def _get_upstream_parameters(self) -> dict: return parameters def generate_downstream_parameters(self, parameters: dict) -> dict: - """takes an existing parameters object and attempts to generate a number of - downstream dependent parameters, based on the values contained within `parameters`. - - Raises RuntimeError if a downstream parameter - does not find the necessary values it needs within `parameters` - - Example - --------- - if the parameter `Y = 1/X` then X must be defined within `parameters` and - we call `parameters["Y"] = 1 / parameters["X"]` + """ + Generate downstream dependent parameters based on upstream values. 
 
         Parameters
         ----------
         parameters : dict
-            parameters dictionary generated by `self._get_upstream_parameters()`
-            containing static or sampled values on which downstream parameters may depend
+            Dictionary generated by `self._get_upstream_parameters()` containing
+            static or sampled values that downstream parameters may depend on.
 
         Returns
         -------
         dict
-            an appended onto version of `parameters` with additional downstream parameters added.
+            Updated version of `parameters` with additional downstream parameters added.
+
+        Raises
+        ------
+        RuntimeError
+            If a downstream parameter cannot find the necessary upstream values within `parameters`.
         """
         try:
             # create or re-recreate parameters based on other possibly sampled parameters
@@ -207,17 +206,14 @@
         return parameters
 
     def get_parameters(self) -> dict:
-        """
-        Goes through parameters listed in self.UPSTREAM_PARAMETERS, sampling them
-        if they are distributions, collecting them untouched otherwise.
-        Then attempts to generate any downstream parameters that rely on those parameters
-        in self.generate_downstream_parameters(). Returning the resulting dictionary
-        for use in the ODEs (ordinary differential equations)
+        """Sample upstream distributions and generate downstream parameters.
 
         Returns
-        -----------
-        dict{str:obj} where obj may either be a float value,
-        or a jax tracer, in the case of a sampled value or list containing sampled values.
+        -------
+        dict[str, Any]
+            Dictionary containing a combination of `self.UPSTREAM_PARAMETERS` found
+            in `self.config` and downstream parameters generated from
+            `self.generate_downstream_parameters()`.
         """
         parameters = self._get_upstream_parameters()
         parameters = self.generate_downstream_parameters(parameters)
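The upstream/downstream split above is easiest to see in a toy sketch. This is an illustration of the pattern only; the config keys, the `BETA = R0 / infectious period` derivation, and the use of a `seed` handler around the bare `numpyro.sample` calls are assumptions for the example, not code from this diff.

```python
import numpyro
import numpyro.distributions as dist
from numpyro.handlers import seed

config = {
    "INFECTIOUS_PERIOD": dist.TruncatedNormal(7.0, 1.0, low=0.0),
    "STRAIN_R0s": [1.2, 2.5],
}

def get_parameters() -> dict:
    # upstream: sample distribution-valued entries, pass static ones through
    params = {
        key: numpyro.sample(key, val) if isinstance(val, dist.Distribution) else val
        for key, val in config.items()
    }
    # downstream: derive dependent parameters from the (possibly sampled) values
    params["BETA"] = [r0 / params["INFECTIOUS_PERIOD"] for r0 in params["STRAIN_R0s"]]
    return params

with seed(rng_seed=0):  # sampling needs a PRNG key outside of MCMC
    print(get_parameters())
```
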
@@ -232,53 +228,41 @@ def external_i(
         self,
         t: ArrayLike,
         introduction_times: jax.Array,
         introduction_scales: jax.Array,
         introduction_pcts: jax.Array,
     ) -> jax.Array:
-        """
-        Given some time t, returns jnp.array of shape self.INITIAL_STATE[self.config.COMPARTMENT_IDX.I] representing external infected persons
-        interacting with the population. it does so by calling some function f_s(t) for each strain s.
-
-        MUST BE CONTINUOUS AND DIFFERENTIABLE FOR ALL TIMES t.
-
-        The stratafication of the external population is decided by the introduced strains, which are defined by
-        3 parallel lists of the time they peak (`introduction_times`),
-        the number of external infected individuals introduced as a % of the tracked population (`introduction_pcts`)
-        and how quickly or slowly those individuals contact the tracked population (`introduction_scales`)
+        """Calculate the number of external infected individuals interacting with the population at time t.
 
         Parameters
         ----------
-        `t`: float as Traced
-            current time in the model, due to the just-in-time nature of Jax this float value may be contained within a
-            traced array of shape () and size 1. Thus no explicit comparison should be done on "t".
-
-        `introduction_times`: list[int] as Traced
-            a list representing the times at which external strains should be introduced, in days, after t=0 of the model
-            This list is ordered inversely to self.config.STRAIN_R0s. If 2 external strains are defined, the two
-            values in `introduction_times` will refer to the last 2 STRAIN_R0s, not the first two.
-
-        `introduction_scales`: list[float] as Traced
-            a list representing the standard deviation of the curve that external strains are introduced with, in days
-            This list is ordered inversely to self.config.STRAIN_R0s. If 2 external strains are defined, the two
-            values in `introduction_times` will refer to the last 2 STRAIN_R0s, not the first two.
-
-        `introduction_pcts`: list[float] as Traced
-            a list representing the proportion of each age bin in self.POPULATION[self.config.INTRODUCTION_AGE_MASK]
-            that will be exposed to the introduced strain over the entire course of the introduction.
-            This list is ordered inversely to self.config.STRAIN_R0s. If 2 external strains are defined, the two
-            values in `introduction_times` will refer to the last 2 STRAIN_R0s, not the first two.
+        t : ArrayLike
+            Current time in the model.
 
-        Returns
-        -----------
-        external_i_compartment: jax.Array
-            jnp.array(shape=(self.INITIAL_STATE[self.config.COMPARTMENT_IDX.I].shape)) of external individuals to the system
-            interacting with susceptibles within the system, used to impact force of infection.
-        """
+        introduction_times : jax.Array
+            List representing times at which external strains should peak
+            in their rate of introduction.
 
-        # define a function that returns 0 for non-introduced strains
-        def zero_function(_):
-            return 0
+        introduction_scales : jax.Array
+            List representing the standard deviation of the curve for introducing
+            external strains, in days.
+
+        introduction_pcts : jax.Array
+            List representing the proportion of each age bin in
+            `self.POPULATION[self.config.INTRODUCTION_AGE_MASK]`
+            that will be exposed to the introduced strain over the entire curve.
+
+        Returns
+        -------
+        jax.Array
+            An array of shape matching `self.INITIAL_STATE[self.config.COMPARTMENT_IDX.I]`
+            representing external individuals interacting with tracked susceptibles.
+
+        Notes
+        -----
+        Use `self.config.INTRODUCTION_AGE_MASK` to select which age bins are
+        affected by external populations. External populations are not tracked
+        but still interact with the contact matrix, influencing spread dynamics.
+        """
         external_i_distributions = [
-            zero_function for _ in range(self.config.NUM_STRAINS)
-        ]
+            lambda _: 0 for _ in range(self.config.NUM_STRAINS)
+        ]  # start with zero functions
         introduction_percentage_by_strain = [0] * self.config.NUM_STRAINS
         for introduced_strain_idx, (
             introduced_time,
@@ -287,7 +271,7 @@ def zero_function(_):
         ) in enumerate(
             zip(introduction_times, introduction_scales, introduction_pcts)
         ):
-            # earlier introduced strains earlier will be placed closer to historical strains (0 and 1)
+            # INTRODUCED_STRAINS are parallel to the END of the STRAIN_R0s
             dist_idx = (
                 self.config.NUM_STRAINS
                 - self.config.NUM_INTRODUCED_STRAINS
@@ -296,7 +280,7 @@ def zero_function(_):
             # use a normal PDF with std dv
             external_i_distributions[dist_idx] = partial(
                 pdf, loc=introduced_time, scale=introduction_scale
-            )
+            )  # type: ignore
             introduction_percentage_by_strain[dist_idx] = introduction_perc
             # with our external_i_distributions set up, now we can execute them on `t`
             # set up our return value
@@ -323,29 +307,29 @@ def zero_function(_):
     @partial(jax.jit, static_argnums=(0))
     def vaccination_rate(self, t: ArrayLike) -> jax.Array:
         """
-        Given some time t, returns a jnp.array of shape (self.config.NUM_AGE_GROUPS, self.config.MAX_VACCINATION_COUNT + 1)
-        representing the age / vax history stratified vaccination rates for an additional vaccine. Used by transmission models
-        to determine vaccination rates at a particular time step.
-        In the cases that your model's definition of t=0 is later the vaccination spline's definition of t=0
-        use the `VACCINATION_MODEL_DAYS_SHIFT` config parameter to shift the vaccination spline's t=0 right.
- - MUST BE CONTINUOUS AND DIFFERENTIABLE FOR ALL TIMES t. If you want a piecewise implementation of vax rates must declare jump points - in the MCMC object. + Calculate the instantaneous vaccination rates stratified by age and vaccination history. Parameters ---------- - t: float as Traced - current time in the model, due to the just-in-time nature of Jax this float value may be contained within a - traced array of shape () and size 1. Thus no explicit comparison should be done on "t". + t : ArrayLike + Current time in the model. Returns - ----------- - vaccination_rates: jnp.Array - jnp.array(shape=(self.config.NUM_AGE_GROUPS, self.config.MAX_VACCINATION_COUNT + 1)) of vaccination rates for each age bin and vax history strata. + ------- + jax.Array + An array of shape (self.config.NUM_AGE_GROUPS, self.config.MAX_VACCINATION_COUNT + 1) + representing vaccination rates for each age bin and vaccination history strata. + + Notes + ----- + Use `self.config.VACCINATION_MODEL_DAYS_SHIFT` to adjust t=0 + specifically for this function. + Refer to `load_vaccination_model` for details on spline definitions + and loading. + The function is continuous and differentiable for all times `t`. """ # shifting splines if needed for multi-epochs, 0 by default t_added = getattr(self.config, "VACCINATION_MODEL_DAYS_SHIFT", 0) - # default to 1.0 (unchanged) if parameter does not exist vaccination_rates_log = utils.evaluate_cubic_spline( t + t_added, self.config.VACCINATION_MODEL_KNOT_LOCATIONS, @@ -367,18 +351,44 @@ def vaccination_rate(self, t: ArrayLike) -> jax.Array: @partial(jax.jit, static_argnums=(0)) def beta_coef(self, t: ArrayLike) -> ArrayLike: - """Returns a coefficient for the beta value for cases of external impacts - on transmission not directly accounted for in the model. - Currently implemented via an array search with timings BETA_TIMES and coefficients BETA_COEFICIENTS + """ + Calculate the coefficient to modify the transmission rate based on external factors. Parameters ---------- - t: float as Traced - current time in the model. Due to the just-in-time nature of Jax this float value may be contained within a - traced array of shape () and size 1. Thus no explicit comparison should be done on "t". + t : ArrayLike + Current time in the model. + + Returns + ------- + ArrayLike + Coefficient by which BETA can be multiplied to + externally increase or decrease its value. + + Examples + -------- + Multiple values of `t` being passed at once for this example only. + + >>> self.config.BETA_COEFICIENTS + jnp.array([-1.0, 0.0, 1.0]) + + >>> self.config.BETA_TIMES + jnp.array([25, 50]) + + >>> self.beta_coef(t=[0, 24]) + [-1. -1.] - Returns: - Coefficient by which BETA can be multiplied to externally increase or decrease the value to account for measures or seasonal forcing. + >>> self.beta_coef(t=[25, 26, 49]) + [0. 0. 0.] + + >>> self.beta_coef(t=[50, 51, 99]) + [1. 1. 1.] + + Notes + ----- + The function defaults to a coefficient of 1.0 if no modifications are + specified. It uses `BETA_TIMES` and `BETA_COEFICIENTS` from the + configuration for adjustments. 
""" # a smart lookup function that works with JAX just in time compilation # if t > self.config.BETA_TIMES_i, return self.config.BETA_COEFICIENTS_i @@ -399,14 +409,15 @@ def seasonality( seasonality_second_wave: ArrayLike, seasonality_shift: ArrayLike, ) -> ArrayLike: - """ - Returns the seasonlity coefficient as determined by two cosine waves - multiplied by `seasonality_peak` and `seasonality_second_wave` and shifted by `seasonality_shift` days. + """Calculate seasonlity coefficient for time `t`. - Parameters - ----------- - t: int/Traced as jax.Tracer during runtime + As determined by two cosine waves multiplied by `seasonality_peak` and + `seasonality_second_wave` and shifted by `seasonality_shift` days. + Parameters + ---------- + t: ArrayLike + Current model day. seasonality_amplitude: float/Traced maximum and minimum of the combined curves, taking values of `1 +/-seasonality_amplitude` respectively @@ -417,12 +428,12 @@ def seasonality( seasonality_shift: float/Traced horizontal shift across time in days, cant not exceed +/-(365/2) if seasonality_shift=0, peak occurs at t=0. + Returns - ----------- + ------- Seasonality coefficient signaling an increase (>1) or decrease (<1) - in transmission due to the impact of seasonality. - + in transmission due to the impact of seasonality.\ """ # cosine curves are defined by a cycle of 365 days begining at jan 1st # start by shifting the curve some number of days such that we line up with our INIT_DATE @@ -464,14 +475,17 @@ def seasonality( ) ) - def retrieve_population_counts(self) -> None: - """ - A wrapper function which takes calculates the age stratified population counts across all the INITIAL_STATE compartments - (minus the book-keeping C compartment.) and stores it in the self.config.POPULATION parameter. + def retrieve_population_counts(self) -> np.ndarray: + """Calculate the age stratified population counts across all tracked compartments. + + Excludes the book-keeping C compartment. - We do not recieve this data exactly from the initializer, but it is trivial to recalculate. + Returns + ------- + np.ndarray + population counts of each age bin within `self.INITIAL_STATE` """ - self.config.POPULATION = np.sum( # sum together S+E+I compartments + return np.sum( # sum together S+E+I compartments np.array( [ np.sum( @@ -490,34 +504,62 @@ def retrieve_population_counts(self) -> None: axis=(0), # sum across compartments, keep age bins ) - def load_cross_immunity_matrix(self) -> None: - """ - Loads the Crossimmunity matrix given the strain interactions matrix. - Strain interactions matrix is a matrix of shape (num_strains, num_strains) representing the relative immune escape risk - of those who are being challenged by a strain in dim 0 but have recovered from a strain in dim 1. - Neither the strain interactions matrix nor the crossimmunity matrix take into account waning. + def load_cross_immunity_matrix(self) -> jax.Array: + """Load the Crossimmunity matrix given the strain interactions matrix. - Updates - ---------- - self.config.CROSSIMMUNITY_MATRIX: - updates this matrix to shape (self.config.NUM_STRAINS, self.config.NUM_PREV_INF_HIST) containing the relative immune escape - values for each challenging strain compared to each prior immune history in the model. + Returns + ------- + jax.Array + matrix of shape (self.config.NUM_STRAINS, self.config.NUM_PREV_INF_HIST) + containing the relative immune escape values for each challenging + strain compared to each prior immune history in the model. 
 
-    def load_cross_immunity_matrix(self) -> None:
-        """
-        Loads the Crossimmunity matrix given the strain interactions matrix.
-        Strain interactions matrix is a matrix of shape (num_strains, num_strains) representing the relative immune escape risk
-        of those who are being challenged by a strain in dim 0 but have recovered from a strain in dim 1.
-        Neither the strain interactions matrix nor the crossimmunity matrix take into account waning.
+    def load_cross_immunity_matrix(self) -> jax.Array:
+        """Load the Crossimmunity matrix given the strain interactions matrix.
 
-        Updates
-        ----------
-        self.config.CROSSIMMUNITY_MATRIX:
-            updates this matrix to shape (self.config.NUM_STRAINS, self.config.NUM_PREV_INF_HIST) containing the relative immune escape
-            values for each challenging strain compared to each prior immune history in the model.
+        Returns
+        -------
+        jax.Array
+            matrix of shape (self.config.NUM_STRAINS, self.config.NUM_PREV_INF_HIST)
+            containing the relative immune escape values for each challenging
+            strain compared to each prior immune history in the model.
+
+        Notes
+        -----
+        Strain interactions matrix is a matrix of shape
+        (self.config.NUM_STRAINS, self.config.NUM_STRAINS),
+        representing the relative immune escape risk of those who are being
+        challenged by a strain in dim 0 but have recovered
+        previously from a strain in dim 1. Neither the strain interactions
+        matrix nor the crossimmunity matrix take into account waning.
         """
-        self.config.CROSSIMMUNITY_MATRIX = (
-            utils.strain_interaction_to_cross_immunity(
-                self.config.NUM_STRAINS, self.config.STRAIN_INTERACTIONS
-            )
+        return utils.strain_interaction_to_cross_immunity(
+            self.config.NUM_STRAINS, self.config.STRAIN_INTERACTIONS
         )
 
-    def load_vaccination_model(self) -> None:
-        """
-        loads parameters of a polynomial spline vaccination model
-        stratified on age bin and current vaccination status.
-
-        Raises FileNotFoundError if directory given does not contain the state-specific
-        filename. Formatted as spline_fits_state_name.csv.
+    def load_vaccination_model(self) -> tuple[jax.Array, jax.Array, jax.Array]:
+        """Load parameters of a polynomial spline vaccination model.
 
-        Also raises FileNotFoundError if passed non-csv or non-file paths.
+        Returns
+        -------
+        The following are 3 parallel arrays, each with leading dimensions
+        `(NUM_AGE_GROUPS, MAX_VAX_COUNT+1)` identifying the vaccination spline
+        from age group I and vaccination count J to vaccination count J+1.
+        (individuals vaccinated while at `MAX_VAX_COUNT` generally stay in
+        the same tier, but this is ODE specific).
+        VACCINATION_MODEL_KNOTS: jax.Array
+            array of knot coefficients for each knot located on
+            `VACCINATION_MODEL_KNOT_LOCATIONS[i][j]`
+        VACCINATION_MODEL_KNOT_LOCATIONS: jax.Array
+            array of knot locations by model day, with 0 indicating the knot is
+            placed on self.config.INIT_DATE.
+        VACCINATION_MODEL_BASE_EQUATIONS: jax.Array
+            array defining the coefficients (a,b,c,d) of each
+            base equation `(a + b(t) + c(t)^2 + d(t)^3)` for the spline defined
+            by `VACCINATION_MODEL_KNOT_LOCATIONS[i][j]`.
+
+        Raises
+        ------
+        FileNotFoundError
+            If the path is not a csv file or directory, or if the directory
+            does not contain a region-specific file matching the expected
+            naming convention.
+
+        Notes
+        -----
+        Reads spline information from `self.config.VACCINATION_MODEL_DATA`.
+        If the path given is a directory, attempts a region-specific lookup
+        with `self.config.REGIONS[0]`, using format
+        `self.config.VACCINATION_MODEL_DATA/spline_fits_{region_name}`
+        """
         # if the user passes a directory instead of a file path
         # check to see if the state exists in the directory and use that
@@ -597,29 +639,34 @@
                 vax_knot_locations[age_group_idx, vax_idx, :] = np.array(
                     knot_locations
                 )
-        self.config.VACCINATION_MODEL_KNOTS = jnp.array(vax_knots)
-        self.config.VACCINATION_MODEL_KNOT_LOCATIONS = jnp.array(
-            vax_knot_locations
-        )
-        self.config.VACCINATION_MODEL_BASE_EQUATIONS = jnp.array(
-            vax_base_equations
+        return (
+            jnp.array(vax_knots),
+            jnp.array(vax_knot_locations),
+            jnp.array(vax_base_equations),
         )
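Because `load_vaccination_model` returns base-equation coefficients plus knot coefficients and locations, here is a hedged sketch of how such a spline is typically evaluated (a truncated-power cubic basis). This is an assumed form for illustration only, not a copy of `utils.evaluate_cubic_spline`, and the example values are invented.

```python
import jax.numpy as jnp

def eval_cubic_spline(t, base_coeffs, knot_coeffs, knot_locs):
    # base equation a + b*t + c*t^2 + d*t^3
    a, b, c, d = base_coeffs
    base = a + b * t + c * t**2 + d * t**3
    # each knot adds coef * max(t - loc, 0)^3, keeping the curve C2-continuous
    knots = jnp.sum(knot_coeffs * jnp.where(t > knot_locs, t - knot_locs, 0.0) ** 3)
    return base + knots

# hypothetical single age/vax stratum with two knots
print(eval_cubic_spline(30.0, (0.1, 0.01, 0.0, 0.0),
                        jnp.array([2e-5, -1e-5]), jnp.array([10.0, 20.0])))
```
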
 
     def seasonal_vaccination_reset(self, t: ArrayLike) -> ArrayLike:
-        """
-        if model implements seasonal vaccination, returns evaluation of a continuously differentiable function
-        at time `t` to outflow individuals from the top most vaccination bin (functionally the seasonal tier)
+        """Calculate seasonal vaccination outflow coefficient.
+
+        If model implements seasonal vaccination, returns evaluation of a
+        continuously differentiable function at time `t` to outflow individuals
+        from the topmost vaccination bin (functionally the seasonal tier)
         into the second highest bin.
 
-        Example
+        Parameters
         ----------
-        if self.config.SEASONAL_VACCINATION == True
-
-        at `t=utils.date_to_sim_day(self.config.VACCINATION_SEASON_CHANGE)` returns 1
-        else returns near 0 for t far from self.config.VACCINATION_SEASON_CHANGE.
-
-        This value of 1 is used by model ODES to outflow individuals from the top vaccination bin
-        into the one below it, indicating a new vaccination season.
+        t: ArrayLike
+            current time in the model.
+
+        Examples
+        --------
+        >>> assert self.config.SEASONAL_VACCINATION
+        >>> self.config.VACCINATION_SEASON_CHANGE
+        50
+        >>> np.isclose(self.seasonal_vaccination_reset(50), 1.0)
+        True
+        >>> np.isclose([self.seasonal_vaccination_reset(t) for t in (49, 51)], 1.0)
+        array([False, False])
         """
         if (
             hasattr(self.config, "SEASONAL_VACCINATION")
@@ -647,17 +694,19 @@
         # if no seasonal vaccination, this function always returns zero
         return 0
 
-    def load_contact_matrix(self) -> None:
-        """
-        a wrapper function that loads a contact matrix for the USA based on mixing paterns data found here:
-        https://github.com/mobs-lab/mixing-patterns
+    def load_contact_matrix(self) -> np.ndarray:
+        """Load region specific contact matrix.
 
-        Updates
-        ----------
-        `self.config.CONTACT_MATRIX` : numpy.ndarray
-            a matrix of shape (self.config.NUM_AGE_GROUPS, self.config.NUM_AGE_GROUPS) with each value representing TODO
+        Usually sourced from https://github.com/mobs-lab/mixing-patterns.
+
+        Returns
+        -------
+        numpy.ndarray
+            a matrix of shape (self.config.NUM_AGE_GROUPS, self.config.NUM_AGE_GROUPS)
+            where `CONTACT_MATRIX[i][j]` refers to the per capita
+            interaction rate between age bin `i` and `j`
         """
-        self.config.CONTACT_MATRIX = utils.load_demographic_data(
+        return utils.load_demographic_data(
             self.config.DEMOGRAPHIC_DATA_PATH,
             self.config.REGIONS,
             self.config.NUM_AGE_GROUPS,
@@ -669,23 +718,31 @@
     def scale_initial_infections(
         self, scale_factor: ArrayLike
     ) -> SEIC_Compartments:
         """
-        a function which modifies returns a modified version of
-        self.INITIAL_STATE scaling the number of initial infections by `scale_factor`.
+        Scale the number of initial infections by a specified factor.
 
-        Preserves the ratio of the Exposed/Infectious compartment population sizes.
-        Does not modified self.INITIAL_STATE, returns a copy.
+        This function returns a modified version of `self.INITIAL_STATE`,
+        scaling the number of initial infections while preserving the ratio
+        between the Exposed and Infectious compartments. The original
+        `self.INITIAL_STATE` remains unchanged.
 
         Parameters
         ----------
-        scale_factor: float
-            a multiplier value >=0.0.
-            `scale_factor` < 1 reduces number of initial infections,
-            `scale_factor` == 1.0 leaves initial infections unchanged,
-            `scale_factor` > 1 increases number of initial infections.
+        scale_factor : float
+            A multiplier value >= 0.0.
+            - `scale_factor < 1`: Reduces the number of initial infections.
+            - `scale_factor == 1.0`: Leaves initial infections unchanged.
+            - `scale_factor > 1`: Increases the number of initial infections.
 
         Returns
-        ---------
-        A copy of INITIAL_INFECTIONS with each compartment being scaled according to `scale_factor`
+        -------
+        SEIC_Compartments
+            A copy of `self.INITIAL_STATE` with each compartment
+            scaled up or down depending on `scale_factor`.
+
+        Notes
+        -----
+        The function ensures that the relative sizes of
+        Exposed and Infectious compartments are preserved during scaling.
         """
         pop_counts_by_compartment = jnp.array(
             [
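To close out this file, a hedged sketch of the jit-friendly array lookup that `beta_coef` (earlier in this diff) describes. The constants mirror its doctest, while the `searchsorted` implementation is an assumption; the real lookup may differ.

```python
import jax
import jax.numpy as jnp

BETA_TIMES = jnp.array([25.0, 50.0])
BETA_COEFICIENTS = jnp.array([-1.0, 0.0, 1.0])  # spelling follows the config key

@jax.jit
def beta_coef(t):
    # index 0 while t < 25, 1 while 25 <= t < 50, 2 once t >= 50
    return BETA_COEFICIENTS[jnp.searchsorted(BETA_TIMES, t, side="right")]

print(beta_coef(24.0), beta_coef(25.0), beta_coef(50.0))  # -1.0 0.0 1.0
```
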
diff --git a/src/dynode/config.py b/src/dynode/config.py
index 0310ae2f..36b9ac8b 100644
--- a/src/dynode/config.py
+++ b/src/dynode/config.py
@@ -13,6 +13,7 @@
 import warnings
 from enum import IntEnum
 from functools import partial
+from typing import Any
 
 import jax.numpy as jnp
 import numpy as np
@@ -22,31 +23,35 @@
 
 
 class Config:
-    """
-    A Configuration class designed to take JSON config files,
-    validate them, and generate downstream parameters where applicable
-    """
+    """A factory class to validate and build on top of JSON config files."""
+
+    def __init__(self, config_json_str: str) -> None:
+        """Initialize configuration instance.
 
-    def __init__(self, config_json_str) -> None:
+        Parameters
+        ----------
+        config_json_str : str
+            JSON string representing a dictionary you wish to merge in
+        """
         self.add_file(config_json_str)
 
-    def add_file(self, config_json_str):
-        """loads a JSON string into self,
-        overriding any shared names, asserting valid configuration of parameters,
-        and setting any downstream parameters.
+    def add_file(self, config_json_str: str):
+        """Merge in another configuration JSON and assert new valid state.
+
+        Overrides any shared names and sets downstream parameters.
 
         Parameters
         ----------
         config_json_str : str
-            JSON string representing a dictionary you wish to add into self
+            JSON string representing a dictionary you wish to merge in
 
         Returns
         -------
         Config
-            self with the parameters from `config_json_str` added on, as well as
-            any downstream parameters generated.
+            self with the parameters from `config_json_str` added on,
+            as well as any downstream parameters generated.
         """
-        # adds another config to self.__dict__ and resets downstream parameters again
+        # adds another config to self.__dict__ and reruns downstream parameters
         config = json.loads(
             config_json_str, object_hook=distribution_converter
         )
@@ -56,15 +61,34 @@
         self.set_downstream_parameters()
         return self
 
-    def asdict(self):
+    def _asdict(self):
         return self.__dict__
 
-    def convert_types(self, config):
-        """
-        takes a dictionary of config parameters, consults the PARAMETERS global list and attempts to convert the type
-        of each parameter whos name matches.
+    def convert_types(self, config: dict[str, str | Any]) -> dict[str, Any]:
+        """Convert parameters to correct types.
+
+        Takes a dictionary of config parameters, consults the PARAMETERS
+        global list and attempts to convert the type
+        of each key within `config` which matches a `name` from PARAMETERS.
+
+        Parameters
+        ----------
+        config : dict[str, Any]
+            parameters whose types you wish to adjust
+
+        Returns
+        -------
+        dict[str, Any]
+            `config` with types of matched parameters modified
+
+        Raises
+        ------
+        ConfigParserError
+            if type casting of any parameter within `Config` fails.
         """
-        for parameter in PARAMETERS:
+        for p in PARAMETERS:
+            assert isinstance(p, dict), "mypy assert on %s" % p
+            parameter = p
             key = parameter["name"]
             # if this validator needs to be cast
             if "type" in parameter.keys():
@@ -80,10 +104,20 @@
         return config
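For orientation, a hedged sketch of the JSON shape that `distribution_converter` (further down in this file) recognizes. The config key and the `"Normal"` entry are illustrative, assuming `Normal` is among the registered `distribution_types`; `loc`/`scale` follow numpyro's Normal signature, and the import path is assumed.

```python
import json
from dynode.config import distribution_converter  # import path assumed

config_json = """
{
    "INFECTIOUS_PERIOD": {
        "distribution": "Normal",
        "params": {"loc": 7.0, "scale": 1.0}
    }
}
"""
# object_hook fires on every nested JSON object, innermost first, so the
# inner {"distribution": ..., "params": ...} dict arrives already parsed
config = json.loads(config_json, object_hook=distribution_converter)
```
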
 
     def set_downstream_parameters(self):
-        """
-        A function that checks if a specific parameter exists, then sets any parameters that depend on it.
-
-        E.g., `NUM_AGE_GROUPS` = len(`AGE_LIMITS`) if `AGE_LIMITS` exists, set `NUM_AGE_GROUPS`
+        """Generate dependent downstream parameters.
+
+        Checks if a specific parameter exists, sets any parameters that depend on it.
+
+        Examples
+        --------
+        >>> hasattr(self, "AGE_LIMITS")
+        True
+        >>> hasattr(self, "NUM_AGE_GROUPS")
+        False
+        >>> self.set_downstream_parameters()
+        >>> hasattr(self, "NUM_AGE_GROUPS")
+        True
+        >>> assert len(self.AGE_LIMITS) == self.NUM_AGE_GROUPS
         """
         for parameter in PARAMETERS:
             key = parameter["name"]
@@ -95,17 +129,21 @@
             downstream_function(self, key)
 
     def assert_valid_configuration(self):
-        """
-        checks the soundness of parameters passed into Config by referencing the name of parameters passed to the config
-        with the PARAMETERS global variable. If a distribution is passed instead of a value, blindly accepts the distribution.
+        """Validate parameters passed into Config.
 
-        Raises assert errors if parameter(s) are incongruent in some way.
+        References each PARAMETERS entry's `validate` functions, if listed.
+
+        Raises
+        ------
+        ConfigValidationError
+            if parameter(s) are incongruent in some way, either individually
+            or in combination with one another.
         """
         for param in PARAMETERS:
             key = param["name"]
             key = make_list_if_not(key)
             validator_funcs = param.get("validate", False)
-            # if there are validators to test, and the key(s) are found in our config, lets test them
+            # if validator_funcs, and the key(s) are found in self, lets test
             if validator_funcs and all([hasattr(self, k) for k in key]):
                 validator_funcs = make_list_if_not(validator_funcs)
                 vals = [getattr(self, k) for k in key]
@@ -122,48 +160,72 @@
                 ]
             except Exception as e:
                 if len(key) > 1:
-                    err_text = """There was an issue validating your Config object.
-                    The error was caused by the intersection of the following parameters: %s.
-                    %s""" % (
+                    err_text = """There was an issue validating your Config
+                    object. The error was caused by the intersection of
+                    the following parameters: %s. %s""" % (
                         key,
                         e,
                     )
                 else:
-                    err_text = """The following error occured while validating the %s
-                    parameter in your configuration file: %s""" % (
+                    err_text = """The following error occurred while
+                    validating the %s parameter in your configuration
+                    file: %s""" % (
                         key[0],
                         e,
                     )
                 raise ConfigValidationError(err_text)
 
 
-def make_list_if_not(obj):
-    return obj if isinstance(obj, (list, np.ndarray)) else [obj]
+def make_list_if_not(obj: Any) -> list[Any] | np.ndarray:
+    """Turn an object into a list if it is not already one.
 
+    Parameters
+    ----------
+    obj : Any
+        object, may or may not be iterable
 
-def distribution_converter(dct):
+    Returns
+    -------
+    list[Any] | np.ndarray
+        `obj` unchanged if it is already a list or ndarray,
+        otherwise `[obj]`, a single-element list containing obj.
     """
-    Converts a distribution or transform as specified in JSON config file into
-    a numpyro distribution/transform object.
-    This function is called as a part of json.loads(object_hook=distribution_converter)
-    meaning it executes on EVERY JSON object within a JSON string,
-    recursively from innermost nested outwards.
+    return obj if isinstance(obj, (list, np.ndarray)) else [obj]
 
-    a distribution is identified by the `distribution` and `params` keys inside of a json object
-    while a transform is identified by the `transform` and `params` keys inside of a json object
-    and a constraint is identified by the `constraint` and `params` keys inside of a json object
 
+def distribution_converter(
+    dct: dict,
+) -> (
+    dict
+    | distributions.Distribution
+    | transforms.Transform
+    | distributions.constraints.Constraint
+):
+    """Convert a distribution or transform JSON object to its numpyro equivalent.
- PARAMETERS + This function is called as a part of `json.loads(object_hook=distribution_converter)` + meaning it executes on EVERY JSON object, recursively from innermost nested outwards. + + Parameters ---------- - `dct`: dict + dct : dict A dictionary representing any JSON object that is passed into `Config`. - Including nested JSON objects which are executed from deepest nested outwards. Returns - ----------- - dict or numpyro.distributions object. If `distribution_converter` identifies that dct is a valid JSON representation of a - numpyro distribution or transform, it will return it. Otherwise it returns dct unmodified. + ------- + dict | distributions.Distribution | transforms.Transform | distributions.constraints.Constraint + distributions.Distribution if json dict has "distribution" and + "params" key. transforms.Transform if dict has a "transform" key and + "params" key. distributions.constraints.Constraint if dict has + "constraint" and "params" key. Otherwise dict returned untouched. + + Notes + ----- + A distribution is identified by the `distribution` and `params` + keys inside of a json object. + A transform is identified by the `transform` and `params` + keys inside of a json object. + A constraint is identified by the `constraint` and `params` + keys inside of a json object """ try: if "distribution" in dct.keys() and "params" in dct.keys(): @@ -173,14 +235,16 @@ def distribution_converter(dct): distribution = distribution_types[numpyro_dst]( **numpyro_dst_params ) - # numpyro does lazy eval of distributions, if the user passes in invalid parameter values - # they wont be caught until runtime, so we sample here to raise an error + # numpyro does lazy eval of distributions, + # if the user passes in invalid parameter values they wont be + # caught until runtime, sample here to raise any errors early _ = distribution.sample(PRNGKey(1)) return distribution else: raise KeyError( - "The distribution name was not found in the available distributions, " - "see distribution names here: https://num.pyro.ai/en/stable/distributions.html#distributions" + "The distribution name was not found in the " + "available distributions, see distribution names here: " + "https://num.pyro.ai/en/stable/distributions.html#distributions" ) elif "transform" in dct.keys() and "params" in dct.keys(): numpyro_transform = dct["transform"] @@ -192,8 +256,9 @@ def distribution_converter(dct): return transform else: raise KeyError( - "The transform name was not found in the available transformations, " - "see transform names here: https://num.pyro.ai/en/stable/distributions.html#transforms" + "The transform name was not found in the available " + "transformations, see transform names here: " + "https://num.pyro.ai/en/stable/distributions.html#transforms" ) elif "constraint" in dct.keys(): numpyro_constraint = dct["constraint"] @@ -212,15 +277,16 @@ def distribution_converter(dct): return constraint else: raise KeyError( - "The constraint name was not found in the available constraints, " - "see constraint names here: https://num.pyro.ai/en/stable/_modules/numpyro/distributions/constraints.html" + "The constraint name was not found in the available " + "constraints, see constraint names here: " + "https://num.pyro.ai/en/stable/_modules/numpyro/distributions/constraints.html" ) except Exception as e: # reraise the error raise ConfigParserError( - "There was an error parsing the following distribution/transformation: %s \n " - "see docs to make sure you didnt misspell something: 
https://num.pyro.ai/en/stable/distributions.html#distributions \n" - "or you may have passed incorrect parameters types/names into the distribution" + "There was an error parsing the following object: %s \n " + "see docs to make sure you didnt misspell a parameter: " + "https://num.pyro.ai/en/stable/distributions.html#distributions" % str(dct) ) from e # do nothing if this isnt a distribution or transform @@ -230,10 +296,7 @@ def distribution_converter(dct): ############################################################################# #######################DOWNSTREAM/VALIDATION FUNCTIONS####################### ############################################################################# -def set_downstream_age_variables(conf, _): - """ - given AGE_LIMITS, set downstream variables from there - """ +def _set_downstream_age_variables(conf, _): conf.NUM_AGE_GROUPS = len(conf.AGE_LIMITS) conf.AGE_GROUP_STRS = [ @@ -244,7 +307,7 @@ def set_downstream_age_variables(conf, _): conf.AGE_GROUP_IDX = IntEnum("age", conf.AGE_GROUP_STRS, start=0) -def set_num_waning_compartments_and_rates(conf, _): +def _set_num_waning_compartments_and_rates(conf, _): conf.NUM_WANING_COMPARTMENTS = len(conf.WANING_TIMES) # odes often need waning rates not times # since last waning compartment set to 0, avoid a div by zero error here @@ -256,17 +319,11 @@ def set_num_waning_compartments_and_rates(conf, _): ) -def set_num_introduced_strains(conf, _): - """ - given INTRODUCTION_TIMES, set downstream variables from there - """ +def _set_num_introduced_strains(conf, _): conf.NUM_INTRODUCED_STRAINS = len(conf.INTRODUCTION_TIMES) -def set_wane_enum(conf, _): - """ - given NUM_WANING_COMPARTMENTS set the WANE_IDX - """ +def _set_wane_enum(conf, _): conf.WANE_IDX = IntEnum( "w_idx", ["W" + str(idx) for idx in range(conf.NUM_WANING_COMPARTMENTS)], @@ -274,14 +331,14 @@ def set_wane_enum(conf, _): ) -def path_checker(key, value): +def _path_checker(key, value): assert os.path.exists(value), "%s : %s is not a valid path" % (key, value) -def test_positive(key, value): - """ - checks if a value is positive. - If `value` is a distribution, checks that the lower bound of its support is positive +def _test_positive(key, value): + """Check if a value is positive. + + If distribution, check that the lower bound of its support is positive. """ if issubclass(type(value), distributions.Distribution): if hasattr(value.support, "lower_bound"): @@ -311,7 +368,7 @@ def test_positive(key, value): ) -def test_enum_len(key, enum, expected_len): +def _test_enum_len(key, enum, expected_len): assert ( len(enum) == expected_len ), "Expected %s to have %s entries, got %s" % ( @@ -321,9 +378,9 @@ def test_enum_len(key, enum, expected_len): ) -def test_not_negative(key, value): - """ - checks if a value is not negative. +def _test_not_negative(key, value): + """Check if a value is not negative. 
+
+    If `value` is a distribution, checks that the lower bound of its support is not negative.
     """
     if issubclass(type(value), distributions.Distribution):
@@ -354,10 +411,8 @@
     )
 
-def test_all_in_list(key, lst, func):
-    """
-    a function which tests a different constraint function defined in this file across all values of a list
-    """
+def _test_all_in_list(key, lst, func):
+    """Test a constraint function across all values of a list."""
     try:
         for i, value in enumerate(lst):
             func(key, value)
@@ -369,9 +424,9 @@
     ) from e
 
-def age_limit_checks(key, age_limits):
-    test_not_negative(key, age_limits[0])
-    test_ascending(key, age_limits)
+def _age_limit_checks(key, age_limits):
+    _test_not_negative(key, age_limits[0])
+    _test_ascending(key, age_limits)
     assert all(
         [isinstance(a, int) for a in age_limits]
     ), "ages must be int, not float because census age data is specified as int"
@@ -382,9 +437,9 @@
     )
 
-def compare_geq(keys, vals):
-    """
-    compares that vals[0] >= vals[1],
+def _compare_geq(keys, vals):
+    """Assert that vals[0] >= vals[1].
+
     attempting to compare the upper and lower bounds of vals[0] and vals[1]
     if either or both are distributions.
     some distribution `a` is considered >= distribution `b` if
@@ -484,7 +539,7 @@
     )
 
-def test_type(key, val, tested_type):
+def _test_type(key, val, tested_type):
     assert isinstance(val, tested_type) or issubclass(
         type(val), tested_type
     ), "%s must be an %s, found %s" % (
@@ -494,22 +549,22 @@
     )
 
-def test_non_empty(key, val):
+def _test_non_empty(key, val):
     assert len(val) > 0, "%s is expected to be a non-empty list" % key
 
-def test_len(keys, vals):
+def _test_len(keys, vals):
     assert vals[0] == len(vals[1]), "len(%s) must equal to %s" % (
         keys[1],
         keys[0],
     )
 
-def test_equal_len(keys, vals):
-    test_len(keys, [len(vals[0]), vals[1]])
+def _test_equal_len(keys, vals):
+    _test_len(keys, [len(vals[0]), vals[1]])
 
-def test_shape(keys, vals):
+def _test_shape(keys, vals):
     key1, key2 = keys[0], keys[1]
     shape_of_matrix, array = vals[0], vals[1]
     assert shape_of_matrix == array.shape, "%s.shape must equal to %s" % (
         key2,
         key1,
     )
 
-def test_ascending(key, lst):
+def _test_ascending(key, lst):
     assert all([lst[idx - 1] < lst[idx] for idx in range(1, len(lst))]), (
         "%s must be placed in increasing order" % key
     )
 
-def test_zero(key, val):
+def _test_zero(key, val):
     assert val == 0, "value in %s must be zero" % key
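To make the schema documented below concrete, here is a hedged sketch of a single PARAMETERS entry and how its pieces are consumed by `Config`; the entry shown is illustrative (written as if inside this module), not copied from the list that follows.

```python
from functools import partial

example_entry = {
    # key looked up in the merged JSON config
    "name": "WANING_TIMES",
    # every validator must pass; each has signature f(key, value) -> None
    "validate": [partial(_test_type, tested_type=list)],
    # optional cast applied by Config.convert_types after JSON parsing
    "type": list,
    # optional hook run by Config.set_downstream_parameters once the key exists
    "downstream": _set_num_waning_compartments_and_rates,
}
```
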
@@ -548,16 +603,23 @@
 #############################################################################
 """
 PARAMETERS:
-A list of possible parameters contained within any Config file that a model may be expected to read in.
-name: the parameter name as written in the JSON config or a list of parameter names.
-    if isinstance(name, list) all parameter names must be present before any other sections are executed.
-validate: a single function, or list of functions, each with a signature of f(str, obj) -> None
+A list of possible parameters contained within any Config file that a model
+may be expected to read in.
+
+name: the parameter name as written in the JSON config or a
+    list of parameter names. If `isinstance(name, list)` all parameter names
+    must be present before any other sections are executed.
+validate: a single function, or list of functions,
+    each with a signature of f(str, obj) -> None
     that raise assertion errors if their conditions are not met.
     Note: ALL validators must pass for Config to accept the parameter
-    For the case of test_type, the type of the parameter may be ANY of the tested_type dtypes.
-type: If the parameter type is a non-json primative type, specify a function that takes in the nearest JSON primative type and does
-    the type conversion. E.G: np.array recieves a JSON primative (list) and returns a numpy array.
-downstream: if receiving this parameter kicks off downstream parameters to be modified or created, a function which takes the Config()
+    For the case of test_type, the type of the parameter may be
+    ANY of the tested_type dtypes.
+type: If the parameter type is a non-JSON primitive type, specify a function
+    that takes in the nearest JSON primitive type and does the type conversion.
+    e.g., np.array receives a JSON primitive (list) and returns a numpy array.
+downstream: if receiving this parameter kicks off downstream parameters to be
+    modified or created, a function which takes the Config()
 class is accepted to modify/create the downstream parameters.
 Note about partial(): the partial function creates an anonymous function, taking a named function as input as well as some
 """
 PARAMETERS = [
     {
         "name": "SAVE_PATH",
-        "validate": [partial(test_type, tested_type=str), path_checker],
+        "validate": [partial(_test_type, tested_type=str), _path_checker],
     },
     {
         "name": "DEMOGRAPHIC_DATA_PATH",
-        "validate": [partial(test_type, tested_type=str), path_checker],
+        "validate": [partial(_test_type, tested_type=str), _path_checker],
     },
     {
         "name": "SEROLOGICAL_DATA_PATH",
-        "validate": [partial(test_type, tested_type=str), path_checker],
+        "validate": [partial(_test_type, tested_type=str), _path_checker],
     },
     {
         "name": "SIM_DATA_PATH",
-        "validate": [partial(test_type, tested_type=str), path_checker],
+        "validate": [partial(_test_type, tested_type=str), _path_checker],
     },
     {
         "name": "VACCINATION_MODEL_DATA",
-        "validate": [partial(test_type, tested_type=str), path_checker],
+        "validate": [partial(_test_type, tested_type=str), _path_checker],
     },
     {
         "name": "AGE_LIMITS",
-        "validate": [partial(test_type, tested_type=list), age_limit_checks],
-        "downstream": set_downstream_age_variables,
+        "validate": [partial(_test_type, tested_type=list), _age_limit_checks],
+        "downstream": _set_downstream_age_variables,
     },
     {
         "name": "POP_SIZE",
-        "validate": [partial(test_type, tested_type=int), test_positive],
+        "validate": [partial(_test_type, tested_type=int), _test_positive],
     },
     {
         "name": "INITIAL_INFECTIONS",
         "validate": [
-            partial(test_type, tested_type=(int, float)),
-            test_not_negative,
+            partial(_test_type, tested_type=(int, float)),
+            _test_not_negative,
         ],
     },
     {
         "name": "INITIAL_INFECTIONS_SCALE",
         "validate": [
             partial(
-                test_type, tested_type=(int, float, distributions.Distribution)
+                _test_type,
+                tested_type=(int, float, distributions.Distribution),
             ),
-            test_not_negative,
+            _test_not_negative,
         ],
     },
     {
         "name": ["POP_SIZE", "INITIAL_INFECTIONS"],
-        "validate": compare_geq,
+        "validate": _compare_geq,
     },
     {
         "name": "INFECTIOUS_PERIOD",
         "validate": [
             partial(
-                test_type, tested_type=(int, float, distributions.Distribution)
+                _test_type,
+                tested_type=(int, float, distributions.Distribution),
             ),
-            test_not_negative,
+            _test_not_negative,
         ],
     },
     {
         "name": "EXPOSED_TO_INFECTIOUS",
         "validate": [
partial( - test_type, tested_type=(int, float, distributions.Distribution) + _test_type, + tested_type=(int, float, distributions.Distribution), ), - test_not_negative, + _test_not_negative, ], }, { "name": "WANING_TIMES", "validate": [ - partial(test_type, tested_type=list), - lambda key, vals: [test_positive(key, val) for val in vals[:-1]], - lambda key, vals: test_zero(key, vals[-1]), - lambda key, vals: [test_type(key, val, int) for val in vals], + partial(_test_type, tested_type=list), + lambda key, vals: [_test_positive(key, val) for val in vals[:-1]], + lambda key, vals: _test_zero(key, vals[-1]), + lambda key, vals: [_test_type(key, val, int) for val in vals], ], - "downstream": set_num_waning_compartments_and_rates, + "downstream": _set_num_waning_compartments_and_rates, }, { "name": "NUM_WANING_COMPARTMENTS", "validate": [ - partial(test_type, tested_type=int), - test_positive, + partial(_test_type, tested_type=int), + _test_positive, ], - "downstream": set_wane_enum, + "downstream": _set_wane_enum, }, { "name": "WANING_PROTECTIONS", "validate": lambda key, vals: [ - test_not_negative(key, val) for val in vals + _test_not_negative(key, val) for val in vals ], "type": np.array, }, { "name": ["NUM_WANING_COMPARTMENTS", "WANING_TIMES"], - "validate": test_len, + "validate": _test_len, }, { "name": ["NUM_WANING_COMPARTMENTS", "WANING_PROTECTIONS"], - "validate": test_len, + "validate": _test_len, }, { "name": "STRAIN_INTERACTIONS", - "validate": test_non_empty, + "validate": _test_non_empty, "type": np.array, }, { "name": ["NUM_STRAINS", "STRAIN_INTERACTIONS"], # check that STRAIN_INTERACTIONS shape is (NUM_STRAINS, NUM_STRAINS) - "validate": lambda key, vals: test_shape( + "validate": lambda key, vals: _test_shape( key, [(vals[0], vals[0]), vals[1]] ), }, { "name": ["NUM_STRAINS", "CROSSIMMUNITY_MATRIX"], # check that CROSSIMMUNITY_MATRIX shape is (NUM_STRAINS, 2**NUM_STRAINS) - "validate": lambda key, vals: test_shape( + "validate": lambda key, vals: _test_shape( key, [(vals[0], 2 ** vals[0]), vals[1]] ), }, { "name": ["NUM_STRAINS", "STRAIN_IDX"], # check that len(STRAIN_IDX)==NUM_STRAINS - "validate": lambda keys, vals: test_enum_len( + "validate": lambda keys, vals: _test_enum_len( keys[1], vals[1], vals[0] ), }, { "name": "MAX_VACCINATION_COUNT", - "validate": test_not_negative, + "validate": _test_not_negative, }, { "name": "AGE_DOSE_SPECIFIC_VAX_COEF", "type": np.array, "validate": [ - lambda key, val: test_all_in_list( - key, val.flatten(), test_not_negative + lambda key, val: _test_all_in_list( + key, val.flatten(), _test_not_negative ), ], }, @@ -710,107 +775,107 @@ class is accepted to modify/create the downstream parameters. 
"NUM_AGE_GROUPS", "MAX_VACCINATION_COUNT", ], - "validate": lambda keys, vals: test_shape( + "validate": lambda keys, vals: _test_shape( keys, ((vals[1], vals[2] + 1), vals[0]) ), }, { "name": "VACCINE_EFF_MATRIX", - "validate": test_non_empty, + "validate": _test_non_empty, "type": np.array, }, { "name": "BETA_TIMES", "validate": lambda key, lst: [ - test_not_negative(key, beta_time) for beta_time in lst + _test_not_negative(key, beta_time) for beta_time in lst ], "type": np.array, }, { "name": "BETA_COEFICIENTS", "validate": lambda key, lst: [ - test_not_negative(key, beta_time) for beta_time in lst + _test_not_negative(key, beta_time) for beta_time in lst ], "type": jnp.array, }, { "name": "CONSTANT_STEP_SIZE", "validate": [ - test_not_negative, - partial(test_type, tested_type=(int, float)), + _test_not_negative, + partial(_test_type, tested_type=(int, float)), ], "type": float, }, { "name": "SOLVER_RELATIVE_TOLERANCE", "validate": [ - test_not_negative, - partial(test_type, tested_type=float), + _test_not_negative, + partial(_test_type, tested_type=float), # RTOL <= 1 - lambda key, val: compare_geq(["1.0", key], [1.0, val]), + lambda key, val: _compare_geq(["1.0", key], [1.0, val]), ], "type": float, }, { "name": "SOLVER_ABSOLUTE_TOLERANCE", "validate": [ - test_not_negative, - partial(test_type, tested_type=float), + _test_not_negative, + partial(_test_type, tested_type=float), # ATOL <= 1 - lambda key, val: compare_geq(["1.0", key], [1.0, val]), + lambda key, val: _compare_geq(["1.0", key], [1.0, val]), ], "type": float, }, { "name": "SOLVER_MAX_STEPS", "validate": [ - partial(test_type, tested_type=(int)), + partial(_test_type, tested_type=(int)), # STEPS >= 1 - lambda key, val: compare_geq([key, "1"], [val, 1]), + lambda key, val: _compare_geq([key, "1"], [val, 1]), ], "type": int, }, { "name": "STRAIN_R0s", "validate": [ - partial(test_type, tested_type=np.ndarray), - test_non_empty, - partial(test_all_in_list, func=test_not_negative), + partial(_test_type, tested_type=np.ndarray), + _test_non_empty, + partial(_test_all_in_list, func=_test_not_negative), ], "type": np.array, }, { "name": ["NUM_STRAINS", "MAX_VACCINATION_COUNT", "VACCINE_EFF_MATRIX"], # check that VACCINE_EFF_MATRIX shape is (NUM_STRAINS, MAX_VACCINATION_COUNT + 1) - "validate": lambda key, vals: test_shape( + "validate": lambda key, vals: _test_shape( key, [(vals[0], vals[1] + 1), vals[2]] ), }, { "name": "INTRODUCTION_TIMES", "validate": [ - partial(test_type, tested_type=list), + partial(_test_type, tested_type=list), lambda key, val: [ - [test_not_negative(key, intro_time) for intro_time in val] + [_test_not_negative(key, intro_time) for intro_time in val] ], ], - "downstream": set_num_introduced_strains, + "downstream": _set_num_introduced_strains, }, { "name": "INTRODUCTION_SCALES", "validate": [ - partial(test_type, tested_type=list), + partial(_test_type, tested_type=list), lambda key, val: [ - [test_positive(key, intro_scale) for intro_scale in val] + [_test_positive(key, intro_scale) for intro_scale in val] ], ], }, { "name": "INTRODUCTION_PCTS", "validate": [ - partial(test_type, tested_type=list), + partial(_test_type, tested_type=list), lambda key, val: [ - [test_not_negative(key, intro_perc) for intro_perc in val] + [_test_not_negative(key, intro_perc) for intro_perc in val] ], ], }, @@ -821,10 +886,10 @@ class is accepted to modify/create the downstream parameters. 
"INTRODUCTION_PCTS", ], "validate": [ - lambda key, val: test_equal_len( + lambda key, val: _test_equal_len( [key[0], key[1]], [val[0], val[1]] ), - lambda key, val: test_equal_len( + lambda key, val: _test_equal_len( [key[1], key[2]], [val[1], val[2]] ), # by transitive property, len(INTRODUCTION_TIMES) == len(INTRODUCTION_PCTS) @@ -834,33 +899,36 @@ class is accepted to modify/create the downstream parameters. "name": "SEASONALITY_AMPLITUDE", "validate": [ partial( - test_type, tested_type=(float, int, distributions.Distribution) + _test_type, + tested_type=(float, int, distributions.Distribution), ), # -1.0 <= SEASONALITY_PEAK <= 1.0 - lambda key, val: compare_geq([key, "-1.0"], [val, -1.0]), - lambda key, val: compare_geq(["1.0", key], [1.0, val]), + lambda key, val: _compare_geq([key, "-1.0"], [val, -1.0]), + lambda key, val: _compare_geq(["1.0", key], [1.0, val]), ], }, { "name": "SEASONALITY_SECOND_WAVE", "validate": [ partial( - test_type, tested_type=(float, int, distributions.Distribution) + _test_type, + tested_type=(float, int, distributions.Distribution), ), # 0 <= SEASONALITY_SECOND_WAVE <= 1.0 - lambda key, val: compare_geq([key, "0"], [val, 0]), - lambda key, val: compare_geq(["1.0", key], [1.0, val]), + lambda key, val: _compare_geq([key, "0"], [val, 0]), + lambda key, val: _compare_geq(["1.0", key], [1.0, val]), ], }, { "name": "SEASONALITY_SHIFT", "validate": [ partial( - test_type, tested_type=(float, int, distributions.Distribution) + _test_type, + tested_type=(float, int, distributions.Distribution), ), # -365/2 <= SEASONALITY_SHIFT <= 365/2 - lambda key, val: compare_geq([key, "-365/2"], [val, -182.5]), - lambda key, val: compare_geq(["365/2", key], [182.5, val]), + lambda key, val: _compare_geq([key, "-365/2"], [val, -182.5]), + lambda key, val: _compare_geq(["365/2", key], [182.5, val]), ], }, { @@ -885,7 +953,7 @@ class is accepted to modify/create the downstream parameters. }, { "name": "SEASONAL_VACCINATION", - "validate": partial(test_type, tested_type=(bool)), + "validate": partial(_test_type, tested_type=(bool)), }, { "name": "COMPARTMENT_IDX", @@ -920,22 +988,20 @@ class is accepted to modify/create the downstream parameters. { "name": "MAX_TREE_DEPTH", "validate": [ - partial(test_type, tested_type=(int)), - test_positive, + partial(_test_type, tested_type=(int)), + _test_positive, ], }, ] class ConfigParserError(Exception): - """A basic class meant to denote when the Config - class is having an issue parsing a configuration file""" + """Exception when the Config class is having an issue parsing a configuration file.""" pass class ConfigValidationError(Exception): - """A basic class meant to denote when the Config - class is having an issue validating a configuration file""" + """Exception when the Config class is having an issue validating a configuration file.""" pass diff --git a/src/dynode/covid_sero_initializer.py b/src/dynode/covid_sero_initializer.py index 2fbafe8f..97288899 100644 --- a/src/dynode/covid_sero_initializer.py +++ b/src/dynode/covid_sero_initializer.py @@ -1,12 +1,11 @@ -"""This module defines a covid initializer that uses serology data combined with -an interaction and immunity matrix to create an -initial state of immunity, exposed, and infectious individuals.""" +"""Define a covid initializer for parsing and transforming input serology data.""" import os import jax.numpy as jnp import numpy as np import pandas as pd +from jax import Array from . 
import SEIC_Compartments, utils from .abstract_initializer import AbstractInitializer @@ -14,18 +13,30 @@ class CovidSeroInitializer(AbstractInitializer): - """A Covid Specific initializer class using serology input data to stratify immunity and initial infections""" + """A Covid specific initializer class using serology input data to stratify immunity.""" def __init__(self, config_initializer_path, global_variables_path): - """ - initialize a mechanistic model for covid19 case prediction using serological data. + """Create an initializer for covid19 case prediction using serological data. + + Updates the `self.INITIAL_STATE` jax array to contain all relevant + age and immune distributions of the specified population. + + Parameters + ---------- + config_initializer_path : str + Path to initializer specific JSON parameters. + global_variables_path : str + Path to global JSON for parameters shared across all components + of the model. """ initializer_json = open(config_initializer_path, "r").read() global_json = open(global_variables_path, "r").read() self.config = Config(global_json).add_file(initializer_json) if not hasattr(self.config, "INITIAL_POPULATION_FRACTIONS"): - self.load_initial_population_fractions() + self.config.INITIAL_POPULATION_FRACTIONS = ( + self.load_initial_population_fractions() + ) self.config.POPULATION = ( self.config.POP_SIZE * self.config.INITIAL_POPULATION_FRACTIONS @@ -33,21 +44,34 @@ def __init__(self, config_initializer_path, global_variables_path): # self.POPULATION.shape = (NUM_AGE_GROUPS,) if not hasattr(self.config, "INIT_IMMUNE_HISTORY"): - self.load_immune_history_via_serological_data() + self.config.INIT_IMMUNE_HISTORY = ( + self.load_immune_history_via_serological_data() + ) # self.INIT_IMMUNE_HISTORY.shape = (age, hist, num_vax, waning) if not hasattr(self.config, "CONTACT_MATRIX"): - self.load_contact_matrix() + self.config.CONTACT_MATRIX = self.load_contact_matrix() if not hasattr(self.config, "CROSSIMMUNITY_MATRIX"): - self.load_cross_immunity_matrix() + self.config.CROSSIMMUNITY_MATRIX = ( + self.load_cross_immunity_matrix() + ) # stratify initial infections appropriately across age, hist, vax counts if not ( - hasattr(self.config, "INIT_INFECTED_DIST") + hasattr(self.config, "INIT_INFECTIOUS_DIST") and hasattr(self.config, "INIT_EXPOSED_DIST") ) and hasattr(self.config, "CONTACT_MATRIX_PATH"): - self.load_init_infection_infected_and_exposed_dist_via_contact_matrix() + self.config.INIT_INFECTION_DIST = ( + self.load_initial_infection_dist_via_contact_matrix() + ) + self.config.INIT_INFECTIOUS_DIST = ( + self.get_initial_infectious_distribution() + ) + self.config.INIT_EXPOSED_DIST = ( + self.get_initial_exposed_distribution() + ) - # load initial state using INIT_IMMUNE_HISTORY, INIT_INFECTED_DIST, and INIT_EXPOSED_DIST + # load initial state using + # INIT_IMMUNE_HISTORY, INIT_INFECTIOUS_DIST, and INIT_EXPOSED_DIST self.INITIAL_STATE = self.load_initial_state( self.config.INITIAL_INFECTIONS ) @@ -55,33 +79,41 @@ def __init__(self, config_initializer_path, global_variables_path): def load_initial_state( self, initial_infections: float ) -> SEIC_Compartments: - """ - a function which takes a number of initial infections, - disperses them across infectious and exposed compartments according to the INIT_INFECTED_DIST - and INIT_EXPOSED_DIST distributions, then subtracts both those populations from the total population and - places the remaining individuals in the susceptible compartment, - distributed according to the INIT_IMMUNE_HISTORY 
distribution. + """Disperse initial infections across infectious and exposed compartments. Parameters ---------- - initial_infections: the number of infections to disperse between infectious and exposed compartments. - - Requires - ---------- - the following variables be loaded into self: - CONTACT_MATRIX: loading in config or via self.load_contact_matrix() - INIT_INFECTED_DIST: loaded in config or via load_init_infection_infected_and_exposed_dist_via_contact_matrix() - INIT_EXPOSED_DIST: loaded in config or via load_init_infection_infected_and_exposed_dist_via_contact_matrix() - INIT_IMMUNE_HISTORY: loaded in config or via load_immune_history_via_serological_data(). + initial_infections : float + the number of infections to disperse between + infectious and exposed compartments. Returns - ---------- - INITIAL_STATE: tuple(jnp.ndarray) - a tuple of len 4 representing the S, E, I, and C compartment population counts after model initialization. + ------- + INITIAL_STATE: SEIC_Compartments + a tuple of len 4 representing the S, E, I, and C compartment + population counts after model initialization. + + Notes + ----- + Requires the following variables be loaded into self: + - CONTACT_MATRIX: loaded in config or via + `self.load_contact_matrix()` + - INIT_INFECTIOUS_DIST: loaded in config or via + `get_initial_infectious_distribution()` + - INIT_EXPOSED_DIST: loaded in config or via + `get_initial_exposed_distribution()` + - INIT_IMMUNE_HISTORY: loaded in config or via + `load_immune_history_via_serological_data()`. + + Age and immune history distributions of infectious and exposed + populations are dictated by `self.config.INIT_INFECTIOUS_DIST` and + `self.config.INIT_EXPOSED_DIST` matrices. Subtracts both those + populations from the total population and places the remaining + individuals in the susceptible compartment, distributed according to + the `self.config.INIT_IMMUNE_HISTORY` matrix. """ - # create population distribution using INIT_INFECTED_DIST, then sum them for later use + # create population distribution with INIT_INFECTIOUS_DIST then sum by age initial_infectious_count = ( - initial_infections * self.config.INIT_INFECTED_DIST + initial_infections * self.config.INIT_INFECTIOUS_DIST ) initial_infectious_count_ages = jnp.sum( initial_infectious_count, ( self.config.I_AXIS_IDX.hist, self.config.I_AXIS_IDX.vax, self.config.I_AXIS_IDX.strain, ), ) - # create population distribution using INIT_EXPOSED_DIST, then sum them for later use + # create population distribution with INIT_EXPOSED_DIST then sum by age initial_exposed_count = ( initial_infections * self.config.INIT_EXPOSED_DIST ) @@ -103,7 +135,8 @@ def load_initial_state( self.config.I_AXIS_IDX.strain, ), ) - # susceptible / partial susceptible = Total population - infected_count - exposed_count + # susceptible / partial susceptible = + # Total population - infected_count - exposed_count initial_susceptible_count = ( self.config.POPULATION - initial_infectious_count_ages @@ -119,13 +152,31 @@ def load_initial_state( jnp.zeros(initial_exposed_count.shape), # c ) - def load_immune_history_via_serological_data(self) -> None: - """ - loads the sero init file for self.config.REGIONS[0] and converts it to a numpy matrix - representing the initial immune history of the individuals in the system. Saving matrix - to self.config.INIT_IMMUNE_HISTORY + def load_immune_history_via_serological_data(self) -> np.ndarray: + """Load the serology init file and calculate the initial immune history of susceptibles. 
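+ + Reads the sero init file for `self.config.REGIONS[0]` and converts + it to a numpy matrix of initial immune history proportions.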
- assumes each age bin in INIT_IMMUNE_HISTORY sums to 1, will fail if not. + Returns + ------- + np.ndarray + The initial immune history of individuals within each age bin in the system. + `INIT_IMMUNE_HISTORY[i][j][k][l]` describes the proportion of + individuals in age bin `i`, who fall under + immune history `j`, vaccination count `k`, and waning bin `l`. + + Examples + -------- + Assume united_states_initialization.csv exists and is valid. + >>> init = CovidSeroInitializer("c.json", "global.json") + >>> immune_histories = init.load_immune_history_via_serological_data() + >>> immune_histories.shape == (init.config.NUM_AGE_GROUPS, + ... 2**init.config.NUM_STRAINS, + ... init.config.MAX_VACCINATION_COUNT + 1, + ... init.config.NUM_WANING_COMPARTMENTS) + True + >>> # sum across all bins except age group and ensure they sum to 1 + >>> all(np.isclose(np.sum(immune_histories, axis=(1, 2, 3)), + ... np.ones(init.config.NUM_AGE_GROUPS))) + True """ file_name = ( str(self.config.REGIONS[0]).replace(" ", "_") @@ -200,22 +251,33 @@ def load_immune_history_via_serological_data(self) -> None: "each age group does not sum to 1 in the sero initialization file of %s" % str(self.config.REGIONS[0]) ) - self.config.INIT_IMMUNE_HISTORY = sero_matrix + return sero_matrix - def load_init_infection_infected_and_exposed_dist_via_contact_matrix( + def load_initial_infection_dist_via_contact_matrix( self, - ) -> None: - """ - a function which estimates the demographics of initial infections by looking at the currently susceptible population's - level of protection as well as the contact matrix for mixing patterns. + ) -> np.ndarray: + """Estimate the demographics and immune histories of initial infections. - Disperses these infections across the E and I compartments by the ratio of the waiting times in each compartment. + Looks at the currently susceptible population's proposed level of + protection as well as the contact matrix for mixing patterns. Tailored + specifically for initialization with the `omicron` strain for + February 2022. + + Returns + ------- + np.ndarray + matrix describing the proportion of new infections falling under + each stratification of the compartment. E.g. + `INIT_INFECTION_DIST[i][j][k][l]` describes the proportion of + individuals in age bin `i`, who fall under + immune history `j`, vaccination count `k`, and strain `l`. """ - # use relative wait times in each compartment to get distribution of infections across - # infected vs exposed compartments - exposed_to_infectous_ratio = self.config.EXPOSED_TO_INFECTIOUS / ( - self.config.EXPOSED_TO_INFECTIOUS + self.config.INFECTIOUS_PERIOD - ) + if not hasattr(self.config, "CONTACT_MATRIX_PATH"): + raise RuntimeError( + "Attempting to build initial infection distribution " + "without a path to a contact matrix in " + "self.config.CONTACT_MATRIX_PATH" + ) # use contact matrix to get the infection age distributions eig_data = np.linalg.eig(self.config.CONTACT_MATRIX) max_index = np.argmax(eig_data[0]) @@ -279,25 +341,86 @@ def load_init_infection_infected_and_exposed_dist_via_contact_matrix( ), ] = 0 # disperse infections across E and I compartments by exposed_to_infectous_ratio - self.config.INIT_INFECTION_DIST = infection_dist - self.config.INIT_EXPOSED_DIST = ( - exposed_to_infectous_ratio * self.config.INIT_INFECTION_DIST + return infection_dist + + def get_initial_infectious_distribution(self) -> np.ndarray: + """Get actively infectious proportion of initial infections. 
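+ + Computed from `INIT_INFECTION_DIST` as `(1 - r) * INIT_INFECTION_DIST`, + where `r = EXPOSED_TO_INFECTIOUS / (EXPOSED_TO_INFECTIOUS + INFECTIOUS_PERIOD)` + (see Notes in `get_initial_exposed_distribution`).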
+ + Returns + ------- + np.ndarray + actively infectious compartment as a proportion of `INIT_INFECTION_DIST`. + + Raises + ------ + RuntimeError + if self.config.INIT_INFECTION_DIST does not exist. Usually created + via self.load_initial_infection_dist_via_contact_matrix(). + """ + if not hasattr(self.config, "INIT_INFECTION_DIST"): + raise RuntimeError( + "this function requires `self.config.INIT_INFECTION_DIST`; " + "set it via load_initial_infection_dist_via_contact_matrix() " + "before calling this function" + ) + # use relative wait times in each compartment to get distribution + # of infections across infected vs exposed compartments + exposed_to_infectous_ratio = self.config.EXPOSED_TO_INFECTIOUS / ( + self.config.EXPOSED_TO_INFECTIOUS + self.config.INFECTIOUS_PERIOD ) - self.config.INIT_INFECTED_DIST = ( + return ( 1 - exposed_to_infectous_ratio ) * self.config.INIT_INFECTION_DIST - def load_contact_matrix(self) -> None: + def get_initial_exposed_distribution(self) -> np.ndarray: + """Get exposed proportion of initial infections. + + Returns + ------- + np.ndarray + actively exposed compartment as a proportion of `INIT_INFECTION_DIST`. + + Raises + ------ + RuntimeError + if self.config.INIT_INFECTION_DIST does not exist. Usually created + via self.load_initial_infection_dist_via_contact_matrix(). + + Notes + ----- + Ratio of initial infections across the E and I compartments is + dictated by the ratio of their waiting times. + ``` + self.config.EXPOSED_TO_INFECTIOUS + / (self.config.EXPOSED_TO_INFECTIOUS + self.config.INFECTIOUS_PERIOD) + ``` """ - a wrapper function that loads a contact matrix for the USA based on mixing paterns data found here: - https://github.com/mobs-lab/mixing-patterns + if not hasattr(self.config, "INIT_INFECTION_DIST"): + raise RuntimeError( + "this function requires `self.config.INIT_INFECTION_DIST`; " + "set it via load_initial_infection_dist_via_contact_matrix() " + "before calling this function" + ) + # use relative wait times in each compartment to get distribution + # of infections across infected vs exposed compartments + exposed_to_infectous_ratio = self.config.EXPOSED_TO_INFECTIOUS / ( + self.config.EXPOSED_TO_INFECTIOUS + self.config.INFECTIOUS_PERIOD + ) + return exposed_to_infectous_ratio * self.config.INIT_INFECTION_DIST - Updates - ---------- - `self.config.CONTACT_MATRIX` : numpy.ndarray - a matrix of shape (self.config.NUM_AGE_GROUPS, self.config.NUM_AGE_GROUPS) with each value representing TODO + def load_contact_matrix(self) -> np.ndarray: + """Load the region-specific contact matrix. + + Usually sourced from https://github.com/mobs-lab/mixing-patterns + + Returns + ------- + numpy.ndarray + a matrix of shape (self.config.NUM_AGE_GROUPS, self.config.NUM_AGE_GROUPS) + where `CONTACT_MATRIX[i][j]` refers to the per capita + interaction rate between age bins `i` and `j`. """ - self.config.CONTACT_MATRIX = utils.load_demographic_data( + return utils.load_demographic_data( self.config.DEMOGRAPHIC_DATA_PATH, self.config.REGIONS, self.config.NUM_AGE_GROUPS, @@ -305,21 +428,27 @@ def __init__(self, config_initializer_path, global_variables_path): self.config.AGE_LIMITS, )[self.config.REGIONS[0]]["avg_CM"] - def load_cross_immunity_matrix(self) -> None: - """ - Loads the Crossimmunity matrix given the strain interactions matrix. - Strain interactions matrix is a matrix of shape (num_strains, num_strains) representing the relative immune escape risk - of those who are being challenged by a strain in dim 0 but have recovered from a strain in dim 1. 
- Neither the strain interactions matrix nor the crossimmunity matrix take into account waning. + def load_cross_immunity_matrix(self) -> Array: + """Load the cross-immunity matrix given the strain interactions matrix. - Updates - ---------- - self.config.CROSSIMMUNITY_MATRIX: - updates this matrix to shape (self.config.NUM_STRAINS, self.config.NUM_PREV_INF_HIST) containing the relative immune escape - values for each challenging strain compared to each prior immune history in the model. + Returns + ------- + jax.Array + matrix of shape (self.config.NUM_STRAINS, self.config.NUM_PREV_INF_HIST) + containing the relative immune escape values for each challenging + strain compared to each prior immune history in the model. + + Notes + ----- + Strain interactions matrix is a matrix of shape + (self.config.NUM_STRAINS, self.config.NUM_STRAINS) + representing the relative immune escape risk of those who are being + challenged by a strain in dim 0 but have recovered + previously from a strain in dim 1. + + Neither the strain interactions matrix + nor the cross-immunity matrix takes waning into account. """ - self.config.CROSSIMMUNITY_MATRIX = ( - utils.strain_interaction_to_cross_immunity( - self.config.NUM_STRAINS, self.config.STRAIN_INTERACTIONS - ) + return utils.strain_interaction_to_cross_immunity( + self.config.NUM_STRAINS, self.config.STRAIN_INTERACTIONS ) diff --git a/src/dynode/dynode_runner.py b/src/dynode/dynode_runner.py index f800bdcf..aa3df968 100644 --- a/src/dynode/dynode_runner.py +++ b/src/dynode/dynode_runner.py @@ -1,6 +1,6 @@ -""" -The following abstract class defines a an abstract_azure_runner, -commonly used to accelerate runs of the model onto azure this file +"""Defines an abstract_azure_runner to standardize DynODE experiments. + +Commonly used to accelerate runs of the model on Azure, this file aids the user in the production of timeseries to describe a model run It also handles the saving of stderr and stdout copies as the job executes. @@ -25,8 +25,9 @@ class AbstractDynodeRunner(ABC): - """An Abstract class made to standardize the process of running an experiment on Azure. - Children of this class may use the functions within to standardize their processies across experiments + """An abstract class made to standardize the process of running simulations and fitting. + + Children of this class may use functions within to standardize their processes across experiments. """ def __init__(self, azure_output_dir): @@ -44,30 +45,36 @@ def __init__(self, azure_output_dir): @abstractmethod def process_state(self, state, **kwargs): - """Abstract function meant to be implemented by the instance of the runner. - This handles all of the logic of actually getting a solution object. Feel free to override - or use a different function + """Abstract function meant to be implemented by an instance of the runner. - Calls upon save_config, save_inference_posteriors/save_static_run_timelines to + Entry point that handles all of the logic of getting a solution object. + + Should call helper functions like save_config, + save_inference_posteriors/save_static_run_timeseries to easily save its outputs for later visualization. Parameters ---------- state : str USPS state code for an individual state or territory. + kwargs : any + any other parameters needed to identify an individual simulation. """ pass def save_config(self, config_path: str, suffix: str = "_used"): - """saves a config json located at `config_path` appending `suffix` to the filename - to help distinguish it from other configs. 
+ """Save a copy of config json located at `config_path`. + + Appends `suffix` to the filename to help distinguish it from input configs. Parameters ---------- config_path : str - the path, relative or absolute, to the config file wishing to be saved. + the path, relative or absolute, + to the config file wishing to be saved. suffix : str, optional - suffix to append onto filename, if "" config path remains untouched, by default "_used" + suffix to append onto filename, + if "" config filename remains untouched, by default "_used" """ config_path = config_path.replace( "\\", "/" diff --git a/src/dynode/mechanistic_inferer.py b/src/dynode/mechanistic_inferer.py index 188d3e8d..81f2e06c 100644 --- a/src/dynode/mechanistic_inferer.py +++ b/src/dynode/mechanistic_inferer.py @@ -1,7 +1,6 @@ -""" -The following code is used to fit a series of prior parameter distributions via running them -through Ordinary Differential Equations (ODEs) and comparing the likelihood of the output to some -observed metrics. +"""Fit a series of prior parameter distributions through Ordinary Differential Equations (ODEs). + +Compare the likelihood of the output to some observed metrics. """ import datetime @@ -26,9 +25,10 @@ class MechanisticInferer(AbstractParameters): - """ - A class responsible for managing the fitting process of a mechanistic runner. - Taking in priors, sampling from their distributions, managing MCMC or the sampling/fitting proceedure of choice, + """Manage the fitting process of epidemiological parameters on ODEs. + + Taking in priors, sampling from their distributions, + managing MCMC or the sampling/fitting proceedure of choice, and coordinating the parsing and use of the posterier distributions. """ @@ -39,25 +39,52 @@ def __init__( runner: MechanisticRunner, initial_state: SEIC_Compartments, ): + """Initialize an inferer object with config JSONS, a set of ODEs, and an initial state. + + Parameters + ---------- + global_variables_path : str + Path to global JSON for parameters shared across all components + of the model. + distributions_path : str + Path to inferer specific JSON of parameters containing prior distributions. + runner : MechanisticRunner + Runner class to solve ODEs and return infection timeseries. + initial_state : SEIC_Compartments + Initial compartment state at t=0. + """ distributions_json = open(distributions_path, "r").read() global_json = open(global_variables_path, "r").read() self.config = Config(global_json).add_file(distributions_json) self.runner = runner self.INITIAL_STATE = initial_state self.infer_complete = False - self.set_infer_algo() - self.retrieve_population_counts() - self.load_vaccination_model() - self.load_contact_matrix() - - def set_infer_algo(self, inferer_type: str = "mcmc") -> None: - """Sets the inferer's inference algorithm and sampler. + # set inference algo to mcmc + self.inference_algo = self.set_infer_algo() + # retrieve population age distribution via passed initial state + self.config.POPULATION = self.retrieve_population_counts() + # load all vaccination splines + ( + self.config.VACCINATION_MODEL_KNOTS, + self.config.VACCINATION_MODEL_KNOT_LOCATIONS, + self.config.VACCINATION_MODEL_BASE_EQUATIONS, + ) = self.load_vaccination_model() + self.config.CONTACT_MATRIX = self.load_contact_matrix() + + def set_infer_algo(self, inferer_type: str = "mcmc") -> MCMC: + """Set inference algorithm with attached sampler. 
Parameters ---------- inferer_type : str, optional infer algo you wish to use, by default "mcmc" + Returns + ------- + MCMC + the MCMC inference algorithm, as it is currently the only + supported algorithm. + Raises ------ NotImplementedError @@ -74,7 +101,7 @@ def set_infer_algo(self, inferer_type: str = "mcmc") -> None: if inferer_type == "mcmc": # default to max tree depth of 5 if not specified tree_depth = getattr(self.config, "MAX_TREE_DEPTH", 5) - self.inference_algo = MCMC( + return MCMC( NUTS( self.likelihood, dense_mass=True, @@ -90,24 +117,25 @@ def set_infer_algo(self, inferer_type: str = "mcmc") -> None: def _get_predictions( self, parameters: dict, solution: Solution ) -> jax.Array: - """generates post-hoc predictions from solved timeseries in `Solution` and - parameters used to generate them within `parameters`. This will often be hospitalizations - but could be more than just that. + """Generate post-hoc predictions from solved timeseries in `Solution`. + + Optionally uses `parameters` if sampled variables are needed to generate predictions. Parameters ---------- parameters : dict - parameters object returned by `get_parameters()` possibly containing information about the - infection hospitalization ratio + Parameters object returned by `self.get_parameters()` if needed + to produce predictions for likelihood. solution : Solution - Solution object returned by `_solve_runner` or any call to `self.runner.run()` - containing compartment timeseries + Solution object returned by `self._solve_runner()` or any + call to `self.runner.run()` containing compartment timeseries Returns ------- jax.Array or tuple[jax.Array] - one or more jax arrays representing the different post-hoc predictions generated from - `solution`. If fitting upon hospitalizations only, then a single jax.Array representing hospitalizations will be present. + one or more jax arrays representing the different + post-hoc predictions generated from `solution`. In this case + only hospitalizations are returned. """ # add 1 to idxs because we are stratified by time in the solution object # sum down to just time x age bins @@ -129,7 +157,25 @@ def _get_predictions( def run_simulation( self, tf: int - ) -> dict[str, Union[Solution, jax.Array],]: + ) -> dict[str, Solution | jax.Array | dict]: + """Solve ODEs and package results together with post-hoc predictions. + + Parameters + ---------- + tf : int + number of days to run simulation for + + Returns + ------- + dict[str, Solution | jax.Array | dict] + dictionary containing the following key-value pairs: + + solution : `diffrax.Solution` object of simulation timeseries + hospitalizations : `jax.Array` return value from + `self._get_predictions()` + parameters : `dict` result of `self.get_parameters()` passed to the + runner to generate the `diffrax.Solution` object. + """ parameters = self.get_parameters() solution = self._solve_runner(parameters, tf, self.runner) hospitalizations = self._get_predictions(parameters, solution) @@ -144,46 +190,27 @@ def likelihood( self, tf: int, obs_metrics: jax.Array, ): - """ + """Sample likelihood of observed metrics given a suite of parameter values. + Given some observed metrics, samples the likelihood of them occuring under a set of parameter distributions sampled by self.inference_algo. 
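+ + This method is consumed as the numpyro model by the inference + algorithm, e.g. wrapped as `MCMC(NUTS(self.likelihood, ...))` + within `set_infer_algo()`.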
- If `obs_metrics` is not defined and `infer_mode=False`, returns a dictionary - containing the Solution object returned by `self.runner`, the hospitalizations - predicted by the model, and the parameters returned by `self.get_parameters()` - - if obs_metrics is None likelihood will not actually fit to values and instead return Solutions - based on randomly sampled values. - - if obs_metrics is None, will run model for runs for `tf` days - otherwise runs for `len(obs_metrics)` days. If both `tf` and `obs_metrics` are None, raises RuntimeError. - Currently expects hospitalization data and samples IHR. Parameters ---------- - obs_metrics : jax.Array, optional - observed data, currently expecting hospitalization data, by default None - tf : int, optional - days to run model for, if obs_metrics is not provided, this parameter is used, by default None - infer_mode : bool, optional - whether or not to sample log likelihood of hospitalizations - using `obs_metrics` as observed variables, by default True + tf : int + days to run simulation for before comparing to obs_metrics + obs_metrics : jax.Array + observed data, currently expecting hospitalization data Returns ------- - dict[str, Union[Solution, jax.Array, dict]] - dictionary containing three keys, `solution`, `hospitalizations`, and `parameters` - containing the `Solution` object returned by self.runner, the predicted hospitalizations, and - the parameters run respectively - - Raises - ------ - RuntimeError - if obs_metrics is None AND tf is none, raises runtime error. Need one or the other + None """ dct = self.run_simulation(tf) solution = dct["solution"] predicted_metrics = dct["hospitalizations"] + assert isinstance(predicted_metrics, jax.Array) assert isinstance(solution, Solution) self._checkpoint_compartment_sizes(solution) predicted_metrics = jnp.maximum(predicted_metrics, 1e-6) @@ -194,19 +221,19 @@ def likelihood( ) def infer(self, obs_metrics: jax.Array) -> MCMC: - """ - Infer parameters given priors inside of self.config, - returns an inference_algo object with posterior distributions for each sampled parameter. + """Infer parameters given priors inside of self.config. + Parameters - ----------- - obs_metrics: jnp.array - observed metrics on which likelihood will be calculated on to tune parameters. - See `likelihood()` method for implemented definition of `obs_metrics` + ---------- + obs_metrics: jax.Array + observed metrics on which the likelihood will be calculated + to tune parameters. Returns - ----------- - an inference object, often numpyro.infer.MCMC object used to infer parameters. - This can be used to print summaries, pass along covariance matrices, or query posterier distributions + ------- + MCMC + The inference object, currently `numpyro.infer.MCMC`, + used to infer parameters and produce posterior samples. """ self.inference_algo.run( rng_key=PRNGKey(self.config.INFERENCE_PRNGKEY), @@ -219,15 +246,16 @@ def infer(self, obs_metrics: jax.Array) -> MCMC: return self.inference_algo def _debug_likelihood(self, **kwargs) -> bx.Model: - """uses Bayeux to recreate the self.likelihood function for purposes of basic sanity checking + """EXPERIMENTAL function that recreates `self.likelihood` for basic sanity checking. - passes all parameters given to it to `self.likelihood`, initializes with `self.INITIAL_STATE` + Passes all parameters given to it to `self.likelihood`, + initializes with `self.INITIAL_STATE` and passes `self.config.INFERENCE_PRNGKEY` as seed for randomness. 
Returns ------- Bayeux.Model - model object used to debug + Model object used to debug. """ bx_model = bx.Model.from_numpyro( jax.tree_util.Partial(self.likelihood, **kwargs), @@ -240,15 +268,15 @@ def _debug_likelihood(self, **kwargs) -> bx.Model: return bx_model def _checkpoint_compartment_sizes(self, solution: Solution): - """marks the final_timesteps parameters as well as any - requested dates from self.config.COMPARTMENT_SAVE_DATES if the + """Take note of compartment sizes at end of `solution` and on key dates. + + Saves requested dates from `self.config.COMPARTMENT_SAVE_DATES` if the parameter exists. Skipping over any invalid dates. This method does not actually save the compartment sizes to a file, instead it stores the values within `self.inference_algo.get_samples()` so that they may be later saved by self.checkpoint() or by the user. - Parameters ---------- solution : diffrax.Solution @@ -277,23 +305,31 @@ def _checkpoint_compartment_sizes(self, solution: Solution): def checkpoint( self, checkpoint_path: str, group_by_chain: bool = True ) -> None: - """ - a function which saves the posterior samples from `self.inference_algo` into `checkpoint_path` as a json file. - will save anything sampled or numpyro.deterministic as long as it is tracked by `self.inference_algo`. + """Save the posterior samples from `self.inference_algo`. + + Saves samples into `checkpoint_path` as a json file. + Will save anything sampled or marked numpyro.deterministic as + long as it is within the numpyro trace. Parameters - ----------- + ---------- checkpoint_path: str - a path to which the json file is saved to. Throws error if folders do not exist, overwrites existing JSON files within. + the path the json file is saved to. Throws an error if the folder + does not exist; overwrites existing JSON files within. + + group_by_chain: bool, optional + whether or not saved JSON should retain chain/sample structure + or flatten all chains together into a single list of samples. + By default True, which retains chain structure, creating 2D lists. Raises - ---------- + ------ ValueError if inference has not been called (not self.infer_complete), and thus there are no posteriors to be saved to `checkpoint_path` Returns - ----------- + ------- None """ if not self.infer_complete: @@ -327,54 +363,56 @@ def load_posterior_particle( tuple[int, int], dict[str, Union[Solution, jax.Array, dict[str, jax.Array]]], ]: - """ - loads a list (or singular) of particles defined by a chain/particle tuple. - Using sampled values from self.inference_algo.get_samples() to run - `self.likelihood` with static values from that particle. + """Simulate posterior particles without full inference flow. + + Particles are identified by a (chain, particle) indexing tuple. - if `external_posteriors` are specified, uses them instead of self.inference_algo.get_samples() - to load static particle values. + If `external_particle` is specified, uses that dict instead of + self.inference_algo.get_samples() to load numpyro sites. Parameters - ------------ + ---------- particles: Union[tuple[int, int], list[tuple[int, int]]] - a single tuple or list of tuples, each of which specifies the (chain_num, particle_num) to load + a single tuple or list of tuples, each of which specifies + the (chain_num, particle_num) to load will error if values are out of range of what was sampled. tf: Union[int, None]: - number of days to run posterior model for, defaults to same number of days used in fitting - if possible. 
- external_posteriors: dict - for use of particles defined somewhere outside of this instance of the MechanisticInferer. - For example, loading a checkpoint.json containing saved posteriors from a different run. - expects keys that match those given to `numpyro.sample` often from - inference_algo.get_samples(group_by_chain=True). + number of days to run posterior model for, + defaults to same number of days used in fitting, if possible. + external_particle: dict + for use of particles defined somewhere outside of this class + instance. For example, loading a checkpoint.json containing saved + posteriors from a different run. Expects keys that match sampled + sites within `get_parameters()`. verbose: bool, optional - whether or not to pring out the current chain_particle value being executed + whether or not to print out the current + (chain, particle) value being loaded. Returns - --------------- + ------- `dict[tuple(int, int)]` a dictionary containing - the returned value of `self.likelihood` evaluated with values from (chain_num, particle_num). - Posterior values used append to the dictionary under the "posteriors" key. + the returned value of `self.run_simulation()` evaluated with values + from (chain_num, particle_num). + Posterior values used are appended to the dictionary under the "posteriors" key. - Example - -------------- + Examples + -------- `load_posterior_particle([(0, 100), [1, 120],...]) = {(0, 100): {solution: diffrax.Solution, "posteriors": {...}}, (1, 120): {solution: diffrax.Solution, "posteriors": {...}} ...}` - Note - ------------ - Very important note if you choose to use `external_posteriors`. In the scenario - this instance of `MechanisticInferer.likelihood` samples parameters not named in `external_posteriors` - they will be RESAMPLED according to the distribution passed in the config. - This method will also salt the RNG key used on the prior according to the - chain & particule numbers currently being run. + Notes + ----- + In the scenario this instance of `MechanisticInferer.run_simulation()` + samples parameters not named in `external_particle`, + they will be resampled according to the prior specified in self.config. + Some RNG salt is applied to each (chain, particle) pair so samples + missing from `external_particle` are sampled differently for each particle. - This may be useful to you if you wish to fit upon some data, then introduce - a new varying parameter over the posteriors (often during projection). + This may be useful to you if you wish to fit upon some data, then + vary a new parameter over the posteriors (often during projection). """ # if its a single particle, convert to len(1) list for simplicity if isinstance(particles, tuple): @@ -435,9 +473,9 @@ def _load_posterior_single_particle( tf: int, chain_paricle_seed: int, ) -> dict: - """ - PRIVATE FUNCTION - used by `load_posterior_particle` to actually execute a single posterior particle on `self.likelihood` + """Execute `self.run_simulation()` on a single posterior particle. + + Used by `load_posterior_particle`. Dont touch unless you know what you are doing. 
Parameters @@ -446,17 +484,20 @@ def _load_posterior_single_particle( a dictionary linking a parameter name to its posterior value, a single value or list depending on the sampled parameter tf : int - the number of days to run the posteriors for + the number of days to run the simulation for chain_paricle_seed : int - some salting unique to the particle being run, used to randomize any NEW parameters sampled that are + some salting unique to the particle being run, + used to randomize any NEW parameters sampled that are not within `single_particle` Returns ------- - dict[str: [jax.Array, Solution]] - a solution_dict containing the return value of `self.likelihood` as well as - a field `posteriors` containing the values within `single_particle` as well as - any new sampled values created by `self.likelihood` that were not found in `single_particle` + dict[str, jax.Array | Solution | dict] + a solution_dict containing the return value of + `self.run_simulation()` as well as a key "posteriors". + "posteriors" contains the values within `single_particle` as well + as any newly sampled values created by `self.run_simulation()` + that were not found in `single_particle`. """ # run the model with the same seed, but this time # all calls to numpyro.sample() will lookup the value from single_particle_chain diff --git a/src/dynode/mechanistic_runner.py b/src/dynode/mechanistic_runner.py index 47fd3815..85a7f652 100644 --- a/src/dynode/mechanistic_runner.py +++ b/src/dynode/mechanistic_runner.py @@ -1,6 +1,4 @@ -""" -The following is a class which runs a series of ODE equations, and returns Solution objects for analysis or fitting. -""" +"""Solve a system of ODEs and return a Solution object.""" import datetime from collections.abc import Callable @@ -14,6 +12,7 @@ ODETerm, PIDController, SaveAt, + Solution, Tsit5, diffeqsolve, ) @@ -27,8 +26,7 @@ class MechanisticRunner: - """A class responsible for solving Ordinary Differential Equations (ODEs) - given some initial state, parameters, and the equations themselves""" + """Solves ODEs using Diffrax and produces Solution objects.""" def __init__( self, model: Callable[ [jax.typing.ArrayLike, PyTree, dict], SEIC_Compartments, ], ): + """Initialize MechanisticRunner for solving Ordinary Differential Equations. + + Parameters + ---------- + model : Callable[[jax.typing.ArrayLike, PyTree, dict], SEIC_Compartments] + Set of ODEs, taking time, initial state, and dictionary of + parameters. + """ self.model = model def run( self, initial_state: SEIC_Compartments, args: dict, tf: Union[int, datetime.date] = 100, - ): - """ - run `self.model` using `initial_state` as y@t=0 and parameters provided by the `args` dictionary. - `self.model` will run for `tf` days if isinstance(tf, int) - or until specified datetime if isintance(tf, datetime). + ) -> Solution: + """Solve ODEs for `tf` days using `initial_state` and `args` parameters. + + Uses diffrax.Tsit5() solver. + + Parameters + ---------- + initial_state : SEIC_Compartments + tuple of jax arrays representing the compartments modeled by + ODEs in their initial states at t=0. + args : dict[str, Any] + arguments to pass to ODEs containing necessary parameters to + solve. + tf : int | datetime.date, optional + number of days to solve ODEs for, if date is passed, runs + up to that date, by default 100 days + + Returns + ------- + diffrax.Solution + Solution object, with sol.ys containing compartment states for each day + including t=0 and t=tf. 
For more information on what's included + within diffrax.Solution see: + https://docs.kidger.site/diffrax/api/solution/ - NOTE + Notes -------------- - - No partial date (or time) calculations partial days are truncated down. - - Uses date object within `args['INIT_DATE']` to calculate time between `t=0` and `t=tf` - - if `args["CONSTANT_STEP_SIZE"] > 0` uses constant stepsizer of that size, else uses adaptive step sizing - - discontinuous timepoints can not be specified with constant step sizer - - implemented with `diffrax.Tsit5()` solver + - No partial date (or time) calculations; partial days are truncated + - if `args["CONSTANT_STEP_SIZE"] > 0` uses constant stepsizer of + that size, else uses adaptive step sizing with + `args["SOLVER_RELATIVE_TOLERANCE"]` and + `args["SOLVER_ABSOLUTE_TOLERANCE"]` + - discontinuous timepoints cannot be specified with the constant step sizer """ term = ODETerm( lambda t, state, parameters: self.model(state, t, parameters) ) diff --git a/src/dynode/static_value_parameters.py b/src/dynode/static_value_parameters.py index 615ca650..e73910db 100644 --- a/src/dynode/static_value_parameters.py +++ b/src/dynode/static_value_parameters.py @@ -1,7 +1,4 @@ -""" -This class is responsible for providing parameters to the model in the case that -no parameters are being sampled, and thus no complex inference or fitting is needed. -""" +"""Provides static parameters to ODEs to solve.""" from . import SEIC_Compartments from .abstract_parameters import AbstractParameters @@ -9,7 +6,7 @@ class StaticValueParameters(AbstractParameters): - """A Parameters class made for use on all static parameters, with no in-built sampling mechanism""" + """A Parameters class made for use on static parameters, with no sampling mechanism.""" def __init__( self, @@ -17,15 +14,32 @@ def __init__( runner_config_path: str, global_variables_path: str, ) -> None: + """Initialize a parameters object with config JSONs and an initial state. + + Parameters + ---------- + INITIAL_STATE : SEIC_Compartments + Initial compartment state at t=0. + runner_config_path : str + Path to runner-specific JSON of static parameters. + global_variables_path : str + Path to global JSON for parameters shared across all components + of the model. 
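+ + Examples + -------- + Assume `initial_state` is a valid SEIC_Compartments tuple and both + config paths point to valid JSON files. + >>> static_params = StaticValueParameters(initial_state, + ... "runner.json", + ... "global.json") + >>> ode_args = static_params.get_parameters()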
+ """ runner_json = open(runner_config_path, "r").read() global_json = open(global_variables_path, "r").read() self.config = Config(global_json).add_file(runner_json) self.INITIAL_STATE = INITIAL_STATE # load self.config.POPULATION - self.retrieve_population_counts() + self.config.POPULATION = self.retrieve_population_counts() # load self.config.VACCINATION_MODEL_KNOTS/ # VACCINATION_MODEL_KNOT_LOCATIONS/VACCINATION_MODEL_BASE_EQUATIONS - self.load_vaccination_model() - # load self.config.CONTACT_MATRIX - self.load_contact_matrix() - # rest of the work is handled by the AbstractParameters + # load all vaccination splines + ( + self.config.VACCINATION_MODEL_KNOTS, + self.config.VACCINATION_MODEL_KNOT_LOCATIONS, + self.config.VACCINATION_MODEL_BASE_EQUATIONS, + ) = self.load_vaccination_model() + self.config.CONTACT_MATRIX = self.load_contact_matrix() diff --git a/src/dynode/utils.py b/src/dynode/utils.py index f132d45c..5bb1d5dd 100644 --- a/src/dynode/utils.py +++ b/src/dynode/utils.py @@ -1,4 +1,4 @@ -"""A utils file full of different utility functions used within various components of Initialization, inference, running, and interpretation""" +"""utility functions used within various components of initialization, inference, and interpretation.""" import datetime import glob @@ -27,28 +27,35 @@ # SAMPLING FUNCTIONS # @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ def sample_if_distribution(parameters): - """ - given a dictionary of keys and parameters, searches through all keys - and samples the distribution associated with that key, if it exists. - Otherwise keeps the value associated with that key. - Converts lists with distributions inside to `jnp.ndarray` + """Search through a dictionary and sample any `numpyro.distribution` objects found. + + Replaces the distribution object within `parameters` with a sample from + that distribution and converts all lists to `jnp.ndarray`. - Lists containing distributions will have the parameter's index - marked according to its position in the matrix. - For some 2x2 matrix `x`, `x_1_1` refers to the sampled - version of the last element of `x` + Numpyro sample site names will match the key of the `parameters` dict unless + the distribution is part of a list. Lists containing distributions will have + site name suffixes according to their index in the matrix. Parameters ---------- - `parameters: dict{str: obj}` - a dictionary mapping a parameter name to an object, either a value or a distribution. - `numpyro.distribution` objects are sampled, and their sampled value replaces the distribution object - within parameters. Capable of sampling lists with static values and distributions together. + parameters : dict[str: Any] + A dictionary mapping parameter names to any object. + `numpyro.distribution` objects are sampled, and their sampled values replace + the distribution objects within `parameters`. Returns - ---------- - parameters dictionary with any `numpyro.distribution` objects replaced with jax.tracer samples - of those distributions from `numpyro.sample` + ------- + dict + The parameters dictionary with any `numpyro.distribution` objects replaced by + samples of those distributions from `numpyro.sample`. All lists and + `np.ndarray` are replaced by `jnp.array`. 
+ + Examples + -------- + >>> import numpyro.distributions as dist + >>> params = {'a': dist.Normal(0, 1), 'b': [dist.Normal(0, 1), dist.Normal(0, 1)]} + >>> new_params = sample_if_distribution(params) + >>> # this replaces 'a' with a sample from Normal(0, 1) and each element in 'b' with samples from Normal(0, 1) """ for key, param in parameters.items(): # if distribution, sample and replace @@ -97,31 +104,38 @@ def sample_if_distribution(parameters): def identify_distribution_indexes( parameters: dict[str, Any], ) -> dict[str, dict[str, str | tuple | None]]: - """ - A inverse of the `sample_if_distribution()` which allows users to identify the locations - of numpyro samples. Given a dictionary of parameters, identifies which parameters - are numpyro distributions or are distributions within a list and returns a mapping - between the sample names and its actual parameter name and index. + """Identify the locations and site names of numpyro samples. - Example - -------------- - parameters = {"test": [0, numpyro.distributions.Normal(), 2], "example": numpyro.distributions.Normal()} - identify_distribution_indexes(parameters) = {"test_1": {"sample_name": "test", "sample_idx": tuple(1)}, - "example": {"sample_name": "example", "sample_idx": None}} + The inverse of `sample_if_distribution()`, identifies which parameters + are numpyro distributions and returns a mapping between the sample site + names and their actual parameter names and indexes. Parameters - ------------- - a dictionary containing keys of different parameters names and values of any type + ---------- + parameters : dict[str, Any] + A dictionary containing keys of different parameter + names and values of any type. Returns - ------------ - `dict[str, dict[str, str | tuple[int] | None]]` - - a dictionary mapping the sample name to the parameter name within `parameters`. - (if the sampled parameter is within a larger list, returns a tuple of indexes as well, otherwise None) - key: str -> sampled parameter name as produced by `sample_if_distribution()` - value: `dict[str, str | tuple | None]` -> "sample_name" = sample name within input `parameters` - -> "sample_idx" = sample index if within list, else None + ------- + dict[str, dict[str, str | tuple[int] | None]] + A dictionary mapping the sample name to the dict key within `parameters`. + If the sampled parameter is within a larger list, returns a tuple of indexes as well, + otherwise None. + + - key: `str` + Sampled parameter name as produced by `sample_if_distribution()`. + - value: `dict[str, str | tuple | None]` + "sample_name" maps to key within `parameters` and "sample_idx" provides + the indexes of the distribution if it is found in a list, otherwise None. + + Examples + -------- + >>> import numpyro.distributions as dist + >>> parameters = {"test": [0, dist.Normal(), 2], "example": dist.Normal()} + >>> identify_distribution_indexes(parameters) + {'test_1': {'sample_name': 'test', 'sample_idx': (1,)}, + 'example': {'sample_name': 'example', 'sample_idx': None}} """ def get_index(indexes): @@ -169,22 +183,30 @@ def get_index(indexes): # @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ -# Vaccination modeling, using cubic splines to model vax uptake in the population stratified by age and current vax shot. +# Vaccination modeling, using cubic splines to model vax uptake +# in the population stratified by age and current vax shot. 
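+# As a worked example of the spline form documented below, a single +# age x vaccination cell with base coefficients (a, b, c, d) = (1, 0, 0, 1) +# and no active knots evaluates to f(2) = 1 + 0*2 + 0*2**2 + 1*2**3 = 9 +# on simulation day t=2.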
def base_equation(t, coefficients): - """ - the base of a spline equation, without knots, follows a simple cubic formula - a + bt + ct^2 + dt^3. This is a vectorized version of this equation which takes in - a matrix of `a` values, as well as a marix of `b`, `c`, and `d` coefficients. - PARAMETERS + """Compute the base of a spline equation without knots. + + Follows a simple cubic formula: a + bt + ct^2 + dt^3. + This is a vectorized version that takes in a matrix of + coefficients for each age x vaccination combination. + + Parameters ---------- - t: jax.tracer array - a jax tracer containing within it the time in days since model simulation start - intercepts: jnp.array() - intercepts of each cubic spline base equation for all combinations of age bin and vax history - intercepts.shape=(NUM_AGE_GROUPS, MAX_VACCINATION_COUNT + 1) - coefficients: jnp.array() - coefficients of each cubic spline base equation for all combinations of age bin and vax history - coefficients.shape=(NUM_AGE_GROUPS, MAX_VACCINATION_COUNT + 1, 3) + t : jax.ArrayLike + Simulation day. + coefficients : jnp.ndarray + Coefficients of each cubic spline base equation for all + combinations of age bin and vaccination history. + Shape: (NUM_AGE_GROUPS, MAX_VACCINATION_COUNT + 1, 4) + + Returns + ------- + jnp.ndarray + The result of executing the base equation `a + bt + ct^2 + dt^3` + for each age group and vaccination count combination. + Shape: (NUM_AGE_GROUPS, MAX_VACCINATION_COUNT + 1) """ return jnp.sum( coefficients @@ -194,6 +216,30 @@ def base_equation(t, coefficients): def conditional_knots(t, knots, coefficients): + """Evaluate knots of a spline. + + Evaluates combination of an indicator variable and the + coefficient associated with that knot. + + Executes the following equation: + sum_{i}^{len(knots)}(coefficients[i] * (t - knots[i])^3 * I(t > knots[i])) + where I() is an indicator variable. + + Parameters + ---------- + t : jax.ArrayLike + Simulation day. + knots : jax.Array + Knot locations to compare with `t`. + coefficients : jax.Array + Knot coefficients to multiply each knot with, + assuming it is active at some timestep `t`. + + Returns + ------- + jax.Array + Resulting values summed over the last dimension of the matrices. + """ indicators = jnp.where(t > knots, t - knots, 0) # multiply coefficients by 3 since we taking derivative of cubic spline. return jnp.sum(indicators**3 * coefficients, axis=-1) @@ -208,32 +254,34 @@ def evaluate_cubic_spline( base_equations: jnp.ndarray, knot_coefficients: jnp.ndarray, ) -> float: - """ - Returns the value of a cubic spline with knots and coefficients evaluated on day `t` for each age_bin x vax history combination. - Cubic spline equation: + """Evaluate a cubic spline with knots and coefficients on day `t`. - f(t) = a + bt + ct^2 + dt^3 + sum_{i}^{len(knots)}(knot_coefficients_{i} * (t-knot_locations_{i})^3 * I(t > knot_locations_{i})) - - Where coef/knots[i] is the i'th index of each array. and the I() function is an indicator variable 1 or 0. 
+ Cubic spline equation for each age_bin x vaccination history combination: + ``` + f(t) = a + bt + ct^2 + dt^3 + + sum_{i}^{len(knot_locations)}(knot_coefficients[i] + * (t - knot_locations[i])^3 + * I(t > knot_locations[i])) + ``` Parameters ---------- - t: jax.tracer array - a jax tracer containing within it the time in days since model simulation start - knot_locations: jnp.ndarray - knot locations of each cubic spline for all combinations of age bin and vax history - knots.shape=(NUM_AGE_GROUPS, MAX_VACCINATION_COUNT + 1, # knots in each spline) - base_equations" jnp.ndarray - the base equation coefficients (a + bt + ct^2 + dt^3) of each cubic spline for all combinations of age bin and vax history - knots.shape=(NUM_AGE_GROUPS, MAX_VACCINATION_COUNT + 1, 4) - knot_coefficients: jnp.ndarray - knot coefficients of each cubic spline for all combinations of age bin and vax history. - including first 4 coefficients for the base equation. - coefficients.shape=(NUM_AGE_GROUPS, MAX_VACCINATION_COUNT + 1, # knots in each spline + 4) + t : jax.ArrayLike + Simulation day. + knot_locations : jnp.ndarray + Knot locations for all combinations of age bin and vaccination history. + Shape: (NUM_AGE_GROUPS, MAX_VACCINATION_COUNT + 1, #knots) + base_equations : jnp.ndarray + Base equation coefficients (a + bt + ct^2 + dt^3) for all combinations of age bin and vaccination history. + Shape: (NUM_AGE_GROUPS, MAX_VACCINATION_COUNT + 1, 4) + knot_coefficients : jnp.ndarray + Knot coefficients for all combinations of age bin and vaccination history. + Shape: (NUM_AGE_GROUPS, MAX_VACCINATION_COUNT + 1, #knots) Returns - ---------- - jnp.array() containing the proportion of individuals in each age x vax combination that will be vaccinated during this time step. + ------- + jnp.ndarray + Proportion of individuals in each age x vaccination combination vaccinated during this time step. """ base = base_equation(t, base_equations) knots = conditional_knots(t, knot_locations, knot_coefficients) @@ -241,8 +289,9 @@ def evaluate_cubic_spline( def season_1peak(t, seasonality_second_wave, seasonality_shift): - """ - a utils function used to calculate seasonality, + """Deprecated. + + A utils function used to calculate seasonality, this one is for the winter wave occuring at t=0 if `seasonality_shift=0` and `seasonality_second_wave=0` """ @@ -252,8 +301,9 @@ def season_1peak(t, seasonality_second_wave, seasonality_shift): def season_2peak(t, seasonality_second_wave, seasonality_shift): - """ - a utils function used to calculate seasonality, + """Deprecated. + + A utils function used to calculate seasonality, this one is for the summer wave occuring at t=182.5 if `seasonality_shift=0` and `seasonality_second_wave=1` """ @@ -268,39 +318,57 @@ def season_2peak(t, seasonality_second_wave, seasonality_shift): def sim_day_to_date(sim_day: int, init_date: datetime.date): - """ - given the current model's simulation day as an integer, and the date of the model initialization. - returns a date object representing the current simulation day. + """Compute date object for given `sim_day` and `init_date`. + + Given current model's simulation day as integer and + initialization date, returns date object representing current simulation day. Parameters ---------- - sim_day: int - current model simulation day where sim_day==0==init_date - init_date: datetime.date - the initialization date of the simulation, usually found in the config.INIT_DATE parameter + sim_day : int + Current model simulation day where sim_day==0==init_date. 
     """
     base = base_equation(t, base_equations)
     knots = conditional_knots(t, knot_locations, knot_coefficients)
@@ -241,8 +289,9 @@
 def season_1peak(t, seasonality_second_wave, seasonality_shift):
-    """
-    a utils function used to calculate seasonality,
+    """Deprecate.
+
+    A utils function used to calculate seasonality,
     this one is for the winter wave occuring at t=0
     if `seasonality_shift=0` and `seasonality_second_wave=0`
     """
@@ -252,8 +301,9 @@
 def season_2peak(t, seasonality_second_wave, seasonality_shift):
-    """
-    a utils function used to calculate seasonality,
+    """Deprecate.
+
+    A utils function used to calculate seasonality,
     this one is for the summer wave occuring at t=182.5
     if `seasonality_shift=0` and `seasonality_second_wave=1`
     """
@@ -268,39 +318,57 @@
 def sim_day_to_date(sim_day: int, init_date: datetime.date):
-    """
-    given the current model's simulation day as an integer, and the date of the model initialization.
-    returns a date object representing the current simulation day.
+    """Compute the date object for a given `sim_day` and `init_date`.
+
+    Given the current model's simulation day as an integer and the
+    initialization date, returns a date object representing the current
+    simulation day.

     Parameters
     ----------
-    sim_day: int
-        current model simulation day where sim_day==0==init_date
-    init_date: datetime.date
-        the initialization date of the simulation, usually found in the config.INIT_DATE parameter
+    sim_day : int
+        Current model simulation day where sim_day==0==init_date.
+    init_date : datetime.date
+        Initialization date, usually found in the config.INIT_DATE parameter.

     Returns
-    -----------
-    datetime.date object representing the current sim_day of the simulation.
+    -------
+    datetime.date
+        Date object representing the current `sim_day`.
+
+    Examples
+    --------
+    >>> import datetime
+    >>> init_date = datetime.date(2022, 10, 15)
+    >>> sim_day_to_date(10, init_date)
+    datetime.date(2022, 10, 25)
     """
     return init_date + datetime.timedelta(days=sim_day)

-def sim_day_to_epiweek(sim_day: int, init_date: datetime.date):
-    """
-    given the current model's simulation day as an integer, and the date of the model initialization.
-    returns an integer cdc epiweek that sim_day falls in.
+def sim_day_to_epiweek(
+    sim_day: int, init_date: datetime.date
+) -> epiweeks.Week:
+    """Calculate the CDC epiweek that `sim_day` falls in.

     Parameters
     ----------
-    sim_day: int
-        current model simulation day where sim_day==0==init_date
-    init_date: datetime.date
-        the initialization date of the simulation, usually found in the config.INIT_DATE parameter
+    sim_day : int
+        Current model simulation day where sim_day==0==init_date.
+    init_date : datetime.date
+        Initialization date, usually found in the config.INIT_DATE parameter.

     Returns
-    -----------
-    epiweek.Week object representing the cdc epiweek of the simulation on day sim_day.
+    -------
+    epiweeks.Week
+        CDC epiweek on day `sim_day`.
+
+    Examples
+    --------
+    >>> import datetime
+    >>> init_date = datetime.date(2022, 10, 15)
+    >>> sim_day_to_epiweek(10, init_date)
+    epiweeks.Week(year=2022, week=43)
     """
     date = sim_day_to_date(sim_day, init_date)
     epi_week = epiweeks.Week.fromdate(date)
@@ -308,67 +376,92 @@
 def date_to_sim_day(date: datetime.date, init_date: datetime.date):
-    """
-    given a date object, converts back to simulation days using init_date as reference for t=0
+    """Convert a date object to simulation days using `init_date` as t=0.

     Parameters
     ----------
-    sim_day: datetime.date
-        date to be converted to a simulation day
-    init_date: datetime.date
-        the initialization date of the simulation, usually found in the config.INIT_DATE parameter
+    date : datetime.date
+        Date being converted into integer simulation days.
+    init_date : datetime.date
+        Initialization date, usually found in the config.INIT_DATE parameter.

     Returns
-    -----------
-    int simulation day representing how many days from `init_date` have passed.
+    -------
+    int
+        How many days have passed since `init_date`.
+
+    Examples
+    --------
+    >>> import datetime
+    >>> init_date = datetime.date(2022, 10, 15)
+    >>> date = datetime.date(2022, 11, 5)
+    >>> date_to_sim_day(date, init_date)
+    21
     """
     return (date - init_date).days

 def date_to_epi_week(date: datetime.date):
-    """
-    given a date object, converts to cdc epi week using init_date as reference for t=0
+    """Convert a date object to the CDC epi week it falls in.

     Parameters
     ----------
-    sim_day: datetime.date
-        date to be converted to a simulation day
+    date : datetime.date
+        Date to be converted to an epi week.

     Returns
-    -----------
-    epiweeks.Week obj representing the epi_week that `date` falls in
+    -------
+    epiweeks.Week
+        The epi_week that `date` falls in.
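+
+    Examples
+    --------
+    An illustrative example; the date is chosen arbitrarily:
+
+    >>> import datetime
+    >>> date_to_epi_week(datetime.date(2022, 10, 25))
+    epiweeks.Week(year=2022, week=43)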
""" epi_week = epiweeks.Week.fromdate(date) return epi_week def new_immune_state(current_state: int, exposed_strain: int) -> int: - """a method using BITWISE OR to determine a new immune state position given - current state and the exposing strain + """Determine a new immune state after applying an exposing strain to an immune state. + + Uses bitwise OR given the current state and the exposing strain. Parameters ---------- - current_state: int - int representing the current state of the individual or group being exposed to a strain - exposed_strain: int - int representing the strain exposed to the individuals in state `current_state`. - expects that `0 <= exposed_strain <= num_strains - 1` - num_strains: int - number of strains in the model - - Example - ---------- + current_state : int + Int representing the current state of the + individual or group being exposed to a strain. + exposed_strain : int + Int representing the strain exposed to the + individuals in state `current_state`. + expects that `0 <= exposed_strain <= num_strains - 1`. + + Returns + ------- + int + Individual or population's new immune state after exposure and recovery + from `exposed_strain`. + + Examples + -------- num_strains = 2, possible states are: - 00(no exposure), 1(exposed to strain 0 only), 2(exposed to strain 1 only), 3(exposed to both) - - new_immune_state(current_state, exposed_strain): new_state (explanation) - new_immune_state(0, 0): 1 (no previous exposure, now exposed to strain 0) - new_immune_state(0, 1): 2 (no previous exposure, now exposed to strain 1) - new_immune_state(1, 0): 1 (exposed to strain 0 already, no change in state) - new_immune_state(2, 1): 2 (exposed to strain 1 already, no change in state) - new_immune_state(1, 1): 3 (exposed to strain 0 prev, now exposed to both) - new_immune_state(2, 0): 3 (exposed to strain 1 prev, now exposed to both) - new_immune_state(3, 0): 3 (exposed to both already, no change in state) - new_immune_state(3, 1): 3 (exposed to both already, no change in state) + 00(no exposure), 1(exposed to strain 0 only), 2(exposed to strain 1 only), + 3(exposed to both) + + >>> new_immune_state(current_state = 0, exposed_strain = 0) + 1 #no previous exposure, now exposed to strain 0 + >>> new_immune_state(0, 1) + 2 #no previous exposure, now exposed to strain 1 + >>> new_immune_state(1, 0) + 1 #exposed to strain 0 already, no change in state + >>> new_immune_state(2, 1) + 2 #exposed to strain 1 already, no change in state + >>> new_immune_state(1, 1) + 3 #exposed to strain 0 previously, now exposed to both + >>> new_immune_state(2, 0) + 3 #exposed to strain 1 previously, now exposed to both + >>> new_immune_state(3, 0) + 3 #exposed to both already, no change in state + >>> new_immune_state(3, 1) + 3 #exposed to both already, no change in state """ if isinstance(exposed_strain, (int, float)) and isinstance( current_state, (int, float) @@ -380,7 +473,7 @@ def new_immune_state(current_state: int, exposed_strain: int) -> int: int(current_state_binary, 2) | int(exposed_strain_binary, 2), "b" ) return int(new_state, 2) - else: # being used with jax tracers + else: # being passed jax.ArrayLike # if we are passing jax tracers, convert to bit arrays first current_state_binary = jnp.unpackbits( jnp.array([current_state]).astype("uint8") @@ -394,29 +487,33 @@ def new_immune_state(current_state: int, exposed_strain: int) -> int: def all_immune_states_with(strain: int, num_strains: int): - """ - a function returning all of the immune states which contain an exposure to `strain` + 
"""Determine all immune states which contain an exposure to `strain`. Parameters ---------- - strain: int - int representing the exposed to strain, expects that `0 <= strain <= num_strains - 1` - num_strains: int - number of strains in the model + strain : int + Int representing the exposed-to strain, + expects that `0 <= strain <= num_strains - 1`. + num_strains : int + Number of strains in the model. Returns - ---------- - list[int] representing all states that include previous exposure to `strain` - - Example - ---------- - in a simple model where num_strains = 2 the following is returned. - Reminder: state = 0 (no exposure), - state = 1/2 (exposure to strain 0/1 respectively), state=3 (exposed to both) - - all_immune_states_with(0, 2) -> [1, 3] - - all_immune_states_with(1, 2) -> [2, 3] + ------- + list[int] + all immune states that include previous exposure to `strain` + + Examples + -------- + in a simple model where num_strains = 2 + Reminder: + state = 0 (no exposure), + state = 1/2 (exposure to strain 0/1 respectively), + state = 3 (exposed to both) + + >>> all_immune_states_with(strain = 0, num_strains = 2) + [1, 3] + >>> all_immune_states_with(strain = 1, num_strains = 2) + [2, 3] """ # represent all possible states as binary binary_array = [bin(val) for val in range(2**num_strains)] @@ -434,29 +531,33 @@ def all_immune_states_with(strain: int, num_strains: int): def all_immune_states_without(strain: int, num_strains: int): - """ - function returning all of the immune states which DO NOT contain an exposure to `strain` + """Determine all immune states which do not contain an exposure to `strain`. Parameters ---------- - strain: int - int representing the NOT exposed to strain, expects that `0 <= strain <= num_strains - 1` - num_strains: int - number of strains in the model + strain : int + Int representing the NOT exposed to strain, + expects that `0 <= strain <= num_strains - 1`. + num_strains : int + Number of strains in the model. Returns - ---------- - list[int] representing all states that DO NOT include previous exposure to `strain` - - Example - ---------- - in a simple model where num_strains = 2 the following is returned. - Reminder: state = 0 (no exposure), - state = 1/2 (exposure to strain 0/1 respectively), state=3 (exposed to both) - - all_immune_states_without(strain = 0, num_strains = 2) -> [0, 2] - - all_immune_states_without(strain = 1, num_strains = 2) -> [0, 1] + ------- + list[int] representing all immune states that + do not include previous exposure to `strain` + + Examples + -------- + in a simple model where num_strains = 2. + Reminder: + state = 0 (no exposure), + state = 1/2 (exposure to strain 0/1 respectively), + state = 3 (exposed to both) + + >>> all_immune_states_with(strain = 0, num_strains = 2) + [0, 2] + >>> all_immune_states_with(strain = 1, num_strains = 2) + [0, 1] """ all_states = list(range(2**num_strains)) states_with_strain = all_immune_states_with(strain, num_strains) @@ -465,21 +566,23 @@ def all_immune_states_without(strain: int, num_strains: int): def get_strains_exposed_to(state: int, num_strains: int): - """ - Returns a list of integers representing the strains a given individual was exposed to end up in state `state`. - Says nothing of the order at which an individual was exposed to those strains, list returned sorted increasing. + """Unpack all strain exposures an immune state was exposed to. + + Says nothing of the order at which an individual was exposed to strains. 
     """
     state_binary = format(state, "b")
     # prepend 0s if needed.
@@ -498,29 +601,38 @@
 def combined_strains_mapping(
     from_strain: int, to_strain: int, num_strains: int
 ):
-    """
-    given a strain `from_strain` and `to_strain` returns a mapping of all immune states before and after strains are combined.
-
-    Example
-    -----------
-    in a basic 2 strain model you have the following immune states:
-    0-> no exposure, 1 -> strain 0 exposure, 2-> strain 1 exposure, 3-> exposure to both
-
-    calling `combine_strains(1, 0, 2)` will combine strains 0 and 1 returning
-    `{0:0, 1:1, 2:1, 3:1}`, because there is no functional difference strain 0 and 1 the immune state space becomes binary.
+    """Merge two strain definitions together.

     Parameters
     ----------
-    from_strain: int
-        the strain index representing the strain being collapsed, whos references will be rerouted.
-    to_strain: int
-        the strain index representing the strain being joined with to_strain, typically the ancestral or 0 index.
+    from_strain : int
+        The strain index representing the strain being collapsed,
+        whose references will be rerouted.
+    to_strain : int
+        The strain index that `from_strain` is merged into,
+        typically the ancestral or 0 index.
+    num_strains : int
+        Number of strains in the model, constrains immune state space.

     Returns
-    -----------
-    dict[int:int] mapping from immune state -> immune state before and after `from_strain` is combined with `to_strain` for all states.
+    -------
+    tuple[dict[int, int], dict[int, int]]
+        First dict maps immune state -> immune state before and
+        after `from_strain` is combined with `to_strain` for all states.

-    dict[int:int] mapping from strain idx -> strain idx before and after`from_strain` is combined with `to_strain` for all strains.
+        Second dict maps strain idx -> strain idx before and
+        after `from_strain` is combined with `to_strain` for all strains.
+
+    Examples
+    --------
+    In a basic 2 strain model you have the following immune states:
+    0 -> no exposure, 1 -> strain 0 exposure,
+    2 -> strain 1 exposure, 3 -> exposure to both.
+
+    >>> combined_strains_mapping(from_strain=1, to_strain=0, num_strains=2)
+    ({0: 0, 1: 1, 2: 1, 3: 1}, {0: 0, 1: 0})
+
+    The immune state space becomes binary, and both strain 0 and
+    strain 1 now route to strain 0.
     """
     # we do nothing if from_strain is equal to to_strain, we arent collapsing anything there.
     if from_strain == to_strain:
@@ -569,33 +681,41 @@
 def combine_strains(
     compartment: np.ndarray,
     state_mapping,
     strain_mapping,
     num_strains,
     state_dim=1,
     strain_dim=3,
     strain_axis=False,
 ):
-    """
-    takes an individual compartment and combines the states and strains within it according to `state_mapping` and `strain_mapping`.
-    If compartment has a strain axis, strain_axis=True.
+ """Merge two or more strain definitions together within a compartment. + + Combines the state dimensions and optionally the strain dimension if + `strain_axis=True`. Parameters ---------- - compartment: np.ndarray - the compartment being changed, must be four dimensional with immune state in the 2nd dimension and strain (if applicable) in the last dimension. - state_mapping: dict[int:int] - a mapping of pre-combine state to post-combine state, as generated by combined_strains_mapping(), must cover all states found in `compartment`. - can be many to one relationship of keys to state values. - strain_mapping: dict[int:int] - a mapping of pre-combine strain to post-combine strain, as generated by combined_strains_mapping(), must cover all strains found in `compartment`. - can be many to one relationship of keys to state values. - num_strains: int - number of strains in the model - state_dim: int - if the dimension of the immune_state column is non-standard, specify which dimension immune state is found in - strain_dim: int - if the dimension of the strain column is non-standard, specify which dimension strain is found in - strain_axis: bool - whether or not `compartment` includes a strain axis in the last dimension that must also be combined. + compartment : np.ndarray + The compartment being changed, must be four dimensional + with immune state in the `state_dim` dimension and + strain (if applicable) in the `strain_dim` dimension. + state_mapping : dict[int:int] + A mapping of pre-combine state to post-combine state, + as generated by `combined_strains_mapping()`, + must cover all states found in `compartment[state_dim]`. + strain_mapping : dict[int:int] + A mapping of pre-combine strain to post-combine strain, + as generated by `combined_strains_mapping()`, + must cover all strains found in `compartment[strain_dim]`. + num_strains : int + Number of strains in the model. + state_dim : int + Which dimension in `compartment` immune state is found in, default 1. + strain_dim : int + Which dimension in `compartment` strain num is found in, if applicable, + default 3. + strain_axis : bool + Whether or not `compartment` includes a strain axis + in `strain_dim`. Not all compartments track `strain`. Returns - ---------- + ------- np.ndarray: - a modified copy of `compartment` with all immune states and strains combined according to state_mapping and strain_mapping + A modified copy of `compartment` with all immune states and + strains combined according to `state_mapping` and `strain_mapping` """ # begin with a copy of the compartment in all zeros strain_combined_compartment = np.zeros(compartment.shape) @@ -637,40 +757,53 @@ def combine_strains( def combine_epochs( epoch_solutions, from_strains, to_strains, strain_idxs, num_tracked_strains ): - """ - given N epochs, combines their solutions by translating all immune states and infections of past epochs into the most recent epochs defintions. - Solutions are expected to be 5 dimensions, with the first dimension being timesteps, and the remaining 4 following the standard compartment structure. - immune state in the 3rd dimension (of 5) and strain in the last dimension (if applicable). + """Deprecate due to bad design and complexity. - `epoch_solutions`, `from_strains`, and `to_strains` must be given in order from earliest to most recent epoch. - Does not assume anything about the dates or times these events occured on other than them being sequential. 
     """
     # begin with a copy of the compartment in all zeros
     strain_combined_compartment = np.zeros(compartment.shape)
@@ -637,40 +757,53 @@
 def combine_epochs(
     epoch_solutions, from_strains, to_strains, strain_idxs, num_tracked_strains
 ):
-    """
-    given N epochs, combines their solutions by translating all immune states and infections of past epochs into the most recent epochs defintions.
-    Solutions are expected to be 5 dimensions, with the first dimension being timesteps, and the remaining 4 following the standard compartment structure.
-    immune state in the 3rd dimension (of 5) and strain in the last dimension (if applicable).
-
-    `epoch_solutions`, `from_strains`, and `to_strains` must be given in order from earliest to most recent epoch.
-    Does not assume anything about the dates or times these events occured on other than them being sequential.
+    """Deprecate due to bad design and complexity.
+
+    Given N epochs, combines their solutions by translating all immune states
+    and infections of past epochs into the most recent epoch's definitions.
+    Solutions are expected to be 5 dimensional, with the first dimension being
+    timesteps and the remaining 4 following the standard compartment structure:
+    immune state in the 3rd dimension (of 5) and
+    strain, if applicable for the E and I compartments, in the last dimension.
+
+    `epoch_solutions`, `from_strains`, and `to_strains` must be given in
+    order from earliest to most recent epoch. Does not assume anything about
+    the dates or times these events occurred on other than them being sequential.

     Parameters
     ----------
     epoch_solutions: list[tuple(np.ndarray)]
-        a list of each epoch's solution.ys object as given by Diffeqsolve().ys or BasicMechanisticModel.run().ys
+        a list of each epoch's solution.ys object as given by
+        Diffeqsolve().ys or BasicMechanisticModel.run().ys
         in order from earliest to most recent epoch.
     from_strains: list[int/None]
-        a parallel list of strain indexes indiciating the strain combinations that occured at the end of each epoch.
-        len(from_strain) = N-1 for N epochs, since last epoch does not combine with anything
+        a parallel list of strain indexes indicating the strain
+        combinations that occurred at the end of each epoch.
+        len(from_strain) = N-1 for N epochs, since the last epoch
+        does not combine with anything
     to_strains: list[int/None]
-        a parallel list of strain indexes indiciating the strain combinations that occured at the end of each epoch.
-        len(to_strain) = N-1 for N epochs, since last epoch does not combine with anything
+        a parallel list of strain indexes indicating the strain combinations
+        that occurred at the end of each epoch.
+        len(to_strain) = N-1 for N epochs, since the last epoch does
+        not combine with anything
     strain_idxs: list[IntEnum]
         a list of IntEnums to identify the strain name to index for each epoch.
        len(strain_idxs) = N for N epochs
     num_strains_consistent: int
-        the number of strains consistent across all epochs, their definitions may change but there are always `num_tracked_strains` tracked in each epoch.
+        the number of strains consistent across all epochs; their definitions
+        may change but there are always `num_tracked_strains`
+        tracked in each epoch.

     Returns
-    -----------
-    tuple(np.ndarray): a single state object that combines the timelines of all N epochs with states and strain definitions matching that of the most recent epoch.
+    -------
+    tuple(np.ndarray): a single state object that combines the timelines of
+        all N epochs with states and strain definitions matching that
+        of the most recent epoch.
     """
-    transition_tables = []
     # create transition tables for each epoch to the next
     for idx, (from_strain, to_strain) in enumerate(
@@ -747,19 +880,28 @@
 def find_age_bin(age: int, age_limits: list[int]) -> int:
     """
-    Given an age, return the age bin it belongs to in the age limits array
+    Given an age, return the age bin it belongs to in the age limits array.

     Parameters
     ----------
-    age: int
-        age of the individual to be binned
-    age_limits: list(int)
-        age limit for each age bin in the model, begining with minimum age
-        values are exclusive in upper bound. so [0,18) means 0-17, 18+
+    age : int
+        Age of the individual or population to be binned.
+    age_limits : list[int]
+        Age limit for each age bin in the model, beginning with minimum age;
+        values are exclusive in the upper bound, so [0, 18] means 0-17 and 18+.
     Returns
-    ----------
-    The index of the bin, assuming 0 is the youngest age bin and len(age_limits)-1 is the oldest age bin
+    -------
+    int
+        The index of the bin, assuming 0 is the youngest age bin
+        and len(age_limits)-1 is the oldest age bin.
+
+    Examples
+    --------
+    >>> [find_age_bin(age=age, age_limits=[0, 18, 50, 65])
+    ...  for age in [0, 17, 18, 49, 50, 64, 65, 100]]
+    [0, 0, 1, 1, 2, 2, 3, 3]
     """
     current_bin = -1
     for age_limit in age_limits:
@@ -771,30 +913,32 @@
 def find_vax_bin(vax_shots: int, max_doses: int) -> int:
-    """
-    Given a number of vaccinations, returns the bin it belongs to given the maximum doses ceiling
+    """Calculate vaccination bin.

     Parameters
     ----------
-    vax_shots: int
-        the number of vaccinations given to the individual
-    max_doses: int
-        the number of doses maximum before all subsequent doses are no longer counted
+    vax_shots : int
+        The number of vaccinations given.
+    max_doses : int
+        The maximum number of counted doses; all subsequent
+        doses are ignored.

     Returns
-    ----------
-    The index of the vax bin, min(vax_shots, max_doses)
+    -------
+    int
+        Index of the vaccination bin the population or individual
+        belongs to, equal to min(vax_shots, max_doses).
     """
     return min(vax_shots, max_doses)

 def convert_hist(strains: str, STRAIN_IDX: IntEnum) -> int:
-    """
-    a function that transforms a comma separated list of strains and transform them into an immune history state.
+    """Parse a comma separated list of strains into an immune history state.
+
+    Any unrecognized strain strings inside of `strains` do not contribute
+    to the returned state.

-    Example
-    ----------
+    Examples
+    --------
     strains: "alpha, delta, omicron"
     STRAIN_IDX: delta=0, omicron=1
     num_strains: 2
@@ -809,7 +953,6 @@
         an enum containing the name of each strain and its associated strain index,
         as initialized by ConfigBase.
     num_strains:
         the number of _tracked_ strains in the model.
-
     """
     state = 0
     for strain in filter(None, strains.split(",")):
@@ -819,20 +962,19 @@
 def convert_strain(strain: str, STRAIN_IDX: IntEnum) -> int:
-    """
-    given a text description of a string, return the correct strain index as specified by the STRAIN_IDX enum.
-    If strain is not found in STRAIN_IDX, return 0 (the oldest strain included in the model)
+    """Look up a strain name in STRAIN_IDX, returning 0 if not found.

     Parameters
-    -----------
+    ----------
     strain: str
         a string representing the infecting strain, capitalization does not matter.
     STRAIN_IDX: intEnum
         an enum containing the name of each strain and its associated strain index, as initialized by ConfigBase.

     Returns
-    ----------
-    STRAIN_IDX[strain] if exists, else 0
+    -------
+    int
+        STRAIN_IDX[strain] if it exists, else 0.
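+
+    Examples
+    --------
+    An illustrative sketch with a hypothetical two-strain enum:
+
+    >>> from enum import IntEnum
+    >>> STRAIN_IDX = IntEnum("STRAIN_IDX", {"delta": 0, "omicron": 1})
+    >>> int(convert_strain("Omicron", STRAIN_IDX))
+    1
+    >>> int(convert_strain("alpha", STRAIN_IDX))  # unknown, returns 0
+    0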
     """
     if strain.lower() in STRAIN_IDX._member_map_:
         return STRAIN_IDX[strain.lower()]
@@ -842,18 +984,19 @@
 def find_waning_compartment(TSLIE: int, waning_times: list[int]) -> int:
     """
-    Given a TSLIE (time since last immunogenetic event) in days, returns the waning compartment index of the event.
+    Determine the waning compartment index based on time since last immunogenetic event (TSLIE).

     Parameters
     ----------
-    TSLIE: int
-        the number of days since the initialization of the model that the immunogenetic event occured (this could be vaccination or infection).
-    waning_times: list(int)
-        the number of days an individual stays in each waning compartment, ending in zero as the last compartment does not wane.
+    TSLIE : int
+        Days since the immunogenetic event (e.g., vaccination or infection).
+    waning_times : list[int]
+        Days an individual stays in each waning compartment, ending in zero
+        because the last compartment does not wane.

     Returns
-    ----------
-    index of the waning compartment that an event belongs, to if that event happened `TSLIE` days in the past.
+    -------
+    int
+        Index of the waning bin for an event that occurred `TSLIE` days ago.
     """
     # possible with cumulative sum, but this solution still O(N) and more interpretable
     current_bin = 0
@@ -871,23 +1014,28 @@
 def strain_interaction_to_cross_immunity(
     num_strains: int, strain_interactions: np.ndarray
 ) -> Array:
     """
-    a function which takes a strain_interactions matrix, which is of shape (num_strains, num_strains)
-    and returns a cross immunity matrix of shape (num_strains, 2**num_strains) representing the immunity
-    of all 2**num_strains immune histories against some challenging strain.
+    Convert a strain interaction matrix to a cross-immunity matrix.

     Parameters
     ----------
-    num_strains: int
-        the number of strains for which the crossimmunity matrix is being generated.
-    strain_interactions: np.array
-        a matrix of shape (num_strains, num_strains) representing the relative immunity of someone recovered from
-        one strain to a different challenging strain. 1's in the diagnal representing 0 reinfection (before waning).
+    num_strains : int
+        Number of strains in the model.
+    strain_interactions : np.ndarray
+        Matrix (num_strains, num_strains) representing
+        relative immunity from one strain to another.
+        `strain_interactions[i][j] = 1.0` indicates
+        full immunity from challenging strain `i` after
+        recovery from strain `j`.

     Returns
-    ----------
-    crossimmunity_matrix: jnp.array
-        a matrix of shape (num_strains, 2**num_strains) representing the relative immunity of someone with a specific
-        immune history to a challenging strain.
+    -------
+    jax.Array
+        Matrix (num_strains, 2**num_strains) representing immunity
+        for all immune history permutations against a challenging strain.
+
+    Notes
+    -----
+    Relative immunity does not account for waning.
     """
     infection_history = range(2**num_strains)
     crossimmunity_matrix = jnp.zeros((num_strains, len(infection_history)))
@@ -936,25 +1084,21 @@
 def drop_sample_chains(samples: dict, dropped_chain_vals: list):
     """
-    a function, given a dictionary which is the result of a call to `mcmc.get_samples()`
-    drops specified chains from the posterior samples. This is usually done when a single or multiple
-    chains do not converge with the other chains. This ensures that this divergent chain does not
-    impact posterior distributions meant to summarize the posterior samples.
+    Drop specified chains from posterior samples.

     Parameters
-    -----------
-    `samples`: dict{str: list}
-        a dictionary where parameter names are keys and samples are a list.
-        In the case of M chains and N samples per chain, the list will be of shape MxN
-        with one row per chain, each containing N samples.
-
-    `dropped_chain_vals`: list
-        a list of indexes (rows in the MxN grouped samples list) to be dropped,
-        if the list is empty no chains are dropped.
+    ----------
+    samples : dict[str, list]
+        Dictionary with parameter names as keys and sample
+        lists as values. Shapes of these values are (M, N) for
+        a model with M chains and N samples per chain.
+    dropped_chain_vals : list[int]
+        List of chain indices to be dropped. If empty, no chains are dropped.

     Returns
-    ----------
-    dict{str: list} a copy of the samples dictionary with chains in `dropped_chain_vals` dropped
+    -------
+    dict[str, list]
+        Copy of the samples dictionary with the specified chains removed.
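+
+    Examples
+    --------
+    An illustrative sketch with 3 chains of 2 samples each,
+    dropping the middle chain:
+
+    >>> import numpy as np
+    >>> samples = {"beta": np.arange(6).reshape(3, 2)}
+    >>> len(drop_sample_chains(samples, [1])["beta"])
+    2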
     """
     # Create a new dictionary to store the filtered samples
     filtered_dict = {}
@@ -980,31 +1124,23 @@
 def flatten_list_parameters(
     samples: dict[str, np.ndarray],
 ) -> dict[str, np.ndarray]:
     """
-    given a dictionary of parameter names and samples, identifies any parameters that are
-    placed under a single name, but actually multiple independent draws from the same distribution.
-    These parameters are often the result of a call to `numpyro.plate(P)` for some number of draws `P`
-    After identifying plated samples, this function will separate the `P` draws into their own
-    keys in the samples dictionary.
+    Flatten plated parameters into separate keys in the samples dictionary.

     Parameters
     ----------
-    `samples`: dict{str: np.ndarray}
-        a dictionary where parameter names are keys and samples are a list.
-        In the case of M chains and N samples per chain, the list will be of shape MxN normally
-        with one row per chain, each containing N samples.
-        In the case that the parameter is drawn P independent times, the list will be of shape
-        MxNxP.
+    samples : dict[str, np.ndarray]
+        Dictionary with parameter names as keys and sample
+        arrays as values. Arrays may have shape MxNxP for P independent draws.

     Returns
-    ----------
-    dict{str: np.ndarray} a dictionary in which parameters with lists of shape MxNxP are split into
-    P separate parameters, each with lists of shape MxN for M chains and N samples.
-    This function scaled with any number of dimensions > 2. So for PxQ independent draws
-    will be separated into PxQ parameters each with lists of shape MxN.
-
-    NOTE
-    -----------
-    If you only have parameters of dimension 2, nothing will be changed and a copy of your dict will be returned
+    -------
+    dict[str, np.ndarray]
+        Dictionary with plated parameters split into separate keys.
+        Each new key has arrays of shape MxN. This scales to any number
+        of dimensions > 2, so PxQ independent draws are separated into
+        PxQ parameters, each with arrays of shape MxN.
+
+    Notes
+    -----
+    If no plated parameters are present, returns a copy of the dictionary.
     """
     return_dict = {}
     for key, value in samples.items():
@@ -1029,19 +1165,20 @@
 def drop_keys_with_substring(dct: dict[str, Any], drop_s: str):
-    """A simple helper function designed to drop keys from a dictionary if they contain some substring
+    """
+    Drop keys from a dictionary if they contain a specified substring.

     Parameters
     ----------
     dct : dict[str, Any]
-        a dictionary with string keys
+        Dictionary with string keys.
     drop_s : str
-        keys containing `drop_s` as a substring will be dropped
+        Substring to check for in keys.

     Returns
     -------
     dict[str, Any]
-        dct with keys containing drop_s removed, otherwise untouched.
+        Dictionary with keys containing `drop_s` removed; otherwise untouched.
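+
+    Examples
+    --------
+    A minimal example:
+
+    >>> drop_keys_with_substring({"beta_0": 1, "gamma": 2}, "beta")
+    {'gamma': 2}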
     """
     keys_to_drop = [key for key in dct.keys() if drop_s in key]
     for key in keys_to_drop:
@@ -1055,30 +1192,33 @@
 def convolve_hosp_to_death(hosp, hfr, shape, scale, padding="nan"):
-    """
-    Model deaths based on hospitalizations. The function calculates expected deaths based
-    on input weekly age-specific `hospitalization` and hospitalization fatality risk
+    """Model deaths based on hospitalizations.
+
+    The function calculates expected deaths based on input weekly age-specific
+    `hospitalization` and hospitalization fatality risk
     (`hfr`), then delay the deaths (relative to hospitalization) based on
     a gamma distribution of parameters `shape` and `scale`. The gamma
     specification is _daily_, which then gets discretized into 5 weeks for
     convolution.
+
     Parameters
     ----------
     `hosp` : numpy.array
-        age-specific weekly hospitalization with shape of (num_weeks, NUM_AGE_GROUPS)
+        Age-specific weekly hospitalization with shape of (num_weeks, NUM_AGE_GROUPS).
     `hfr`: numpy.array
-        age-specific hospitalization fatality risk with shape of (NUM_AGE_GROUPS)
+        Age-specific hospitalization fatality risk with shape of (NUM_AGE_GROUPS).
-    shape: float
-        shape parameter of the gamma delay distribution, is > 0
-    scale: float
-        scale parameter of the gamma delay distribution, is > 0 and 1/rate
-    padding: str "nan", "nearest" or "no"
-        boolean flag determining if the output array is of same length as `hosp` with
+    shape : float
+        Shape parameter of the gamma delay distribution, is > 0.
+    scale : float
+        Scale parameter of the gamma delay distribution, is > 0 and 1/rate.
+    padding : str {"nan", "nearest", "no"}
+        String flag determining whether the output array is of the same
+        length as `hosp`, with the
     first 4 weeks padded with nan or not. Note: the "valid" modelled deaths
     would always be 4 weeks less than input hospitalization.
+
     Returns
-    ----------
-    numpy.array:
-        list of `num_day` vaccination rates arrays, each by the shape of (NUM_AGE_GROUPS,
-        MAX_VAX_COUNT + 1)
+    -------
+    numpy.array
+        Age-specific weekly modelled deaths with shape of
+        (num_weeks, NUM_AGE_GROUPS); with `padding="no"` the output is
+        4 weeks shorter than the input hospitalizations.
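+
+    Examples
+    --------
+    An illustrative sketch; values are arbitrary and the default
+    `padding="nan"` keeps the output the same length as `hosp`:
+
+    >>> import numpy as np
+    >>> hosp = np.ones((20, 4))  # 20 weeks, 4 age groups
+    >>> hfr = np.array([0.01, 0.02, 0.05, 0.20])
+    >>> deaths = convolve_hosp_to_death(hosp, hfr, shape=2.0, scale=3.0)
+    >>> deaths.shape
+    (20, 4)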
     """
     expected_deaths = hosp * hfr[None, :]
@@ -1109,19 +1249,22 @@
 def generate_yearly_age_bins_from_limits(age_limits: list) -> list[list[int]]:
     """
-    given age limits, generates age bins with each year contained in that bin up to 85 years old exclusive
-
-    Example
-    ----------
-    age_limits = [0, 5, 10, 15 ... 80]
-    returns [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9], [10, 11, 12, 13, 14]... [80, 81, 82, 83, 84]]
+    Generate age bins up to 85 years old exclusive based on age limits.

     Parameters
     ----------
-    age_limits: list(int):
-        beginning with minimum age inclusive, boundary of each age bin exclusive. Not including last age bin.
-        do not include implicit 85 in age_limits, this function appends that bin automatically.
+    age_limits : list[int]
+        Boundaries of each age bin, beginning with the minimum age inclusive.
+        Do not include the implicit 85 in age_limits; this function appends
+        that bin automatically.
+
+    Returns
+    -------
+    list[list[int]]
+        List of lists containing the integer years within each age bin.
+
+    Examples
+    --------
+    >>> generate_yearly_age_bins_from_limits([0, 5, 10, ... 80])
+    [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9], ... [80, 81, 82, 83, 84]]
     """
     age_groups = []
     for age_idx in range(1, len(age_limits)):
@@ -1137,24 +1280,24 @@
 def load_age_demographics(
     path: str,
     regions: list[str],
     age_limits: list[int],
 ) -> dict[str, np.ndarray]:
-    """Returns normalized proportions of each agebin as defined by age_limits for the regions given.
-    Does this by searching for age demographics data in path.
+    """Load normalized proportions of each age bin for the given regions.

     Parameters
     ----------
-    path: str
-        path to the demographic-data folder, either relative or absolute.
-    regions: list(str)
-        list of FIPS regions to create normalized proportions for
-    age_limits: list(int)
-        age limits for each age bin in the model, begining with minimum age
-        values are exclusive in upper bound. so [0, 18, 50] means 0-17, 18-49, 50+
-        max age is enforced at 84 inclusive. All persons older than 84 in population numbers are counted as 84 years old
+    path : str
+        Path to the demographic-data folder.
+    regions : list[str]
+        List of FIPS regions.
+    age_limits : list[int]
+        Age limits for each bin; values are exclusive upper bounds,
+        so [0, 18, 50] means 0-17, 18-49, and 50+.
+        Max tracked age is enforced at 84 inclusive. All
+        populations older than 84 are counted as 84 years old.

     Returns
-    ----------
-    demographic_data : dict
-        a dictionary maping FIPS code region supplied in `regions` to an array of length `len(age_limits)` representing
+    -------
+    demographic_data : dict[str, np.ndarray]
+        A dictionary mapping each FIPS region supplied in `regions`
+        to an array of length `len(age_limits)` representing
         the __relative__ population proportion of each bin, summing to 1.
     """
     assert os.path.exists(
@@ -1217,21 +1360,18 @@
 def plot_sample_chains(samples):
-    """
-    a function that given a dictionary of parameter names and MxN samples for each parameter
-    plots the trace plot of each of the M chains through the N samples in that chain.
+    """Plot trace plots of M chains through N samples for each parameter.

     Parameters
     ----------
-    `samples`: dict{str: list}
-        a dictionary where parameter names are keys and samples are a list.
-        In the case of M chains and N samples per chain, the list will be of shape MxN
-        with one row per chain, each containing N samples.
+    samples : dict[str, list]
+        Dictionary with parameter names as keys and sample lists as values
+        (MxN shape).

     Returns
-    ----------
-    plots each parameter along with each chain of that parameter,
-    also returns `plt.fig` and `plt.axs` objects for modification.
+    -------
+    tuple[matplotlib.Figure, matplotlib.Axes]
+        Plots each parameter along with each chain of that parameter,
+        and returns the `plt.fig` and `plt.axs` objects for modification.
     """
     # ensure samples are all NxM before plotting
     if any([samples[key].ndim == 3 for key in samples.keys()]):
@@ -1259,42 +1399,42 @@
 def get_timeline_from_solution_with_command(
     sol,
     compartment_idx: IntEnum,
     w_idx: IntEnum,
     strain_idx: IntEnum,
     command: str,
 ):
-    """
-    A function designed to execute `command` over a `sol` object, returning a timeline after `command` is used to select a certain view of `sol`
+    """Execute `command` over a Solution object to obtain a view on the timeseries.

     Possible values of `command` include:
     - a compartment title, as specified in the `compartment_idx` IntEnum. Eg:"S", "E", "I"
     - a strain title, as specified in `strain_idx` IntEnum. Eg "omicron", "delta"
     - a wane index, as specified by `w_idx`. Eg: "W0" "W1"
-    - a numpy slice of a compartment title, as specified in the `compartment_idx` IntEnum. Eg: "S[:, 0, 0, :]" or "E[:, 1:3, [0,1], 1]"
+    - a numpy slice of a compartment title, as specified in the `compartment_idx`
+      IntEnum. Eg: "S[:, 0, 0, :]" or "E[:, 1:3, [0,1], 1]"
      Format must include compartment title, followed by square brackets
     and comma separated slices. Do NOT include extra time dimension found
     in the sol object. Assume dimensionality of the compartment as in initialization.

     Parameters
     ----------
     `sol` : tuple(jnp.array)
-        generally .ys object containing ODE run as described by https://docs.kidger.site/diffrax/api/solution/
+        Generally .ys object containing ODE run as described by
+        https://docs.kidger.site/diffrax/api/solution/
         a tuple containing the ys of the ODE run.
     `compartment_idx`: IntEnum:
-        an enum containing the name of each compartment and its associated compartment index,
-        as initialized by the config file of the model that generated `sol`
+        An enum containing the name of each compartment and its associated compartment index,
+        as initialized by the config file of the model that generated `sol`.
     `w_idx`: IntEnum:
-        an enum containing the name of each waning compartment and its associated compartment index,
-        as initialized by the config file of the model that generated `sol`
+        An enum containing the name of each waning compartment and its associated compartment index,
+        as initialized by the config file of the model that generated `sol`.
     `strain_idx`: intEnum
-        an enum containing the name of each strain and its associated strain index,
-        as initialized by the config file of the model that generated `sol`
+        An enum containing the name of each strain and its associated strain index,
+        as initialized by the config file of the model that generated `sol`.
     `command`: str
-        a string command of the format specified in the function description.
+        A string command of the format specified in the function description.

     Returns
-    ----------
-    tuple(jnp.array, str):
-        a slice of the `sol` object collapsed into the first dimension of the command selected.
-        eg: return.shape = sol[0].shape[0] since all first dimensions in sol are equal normally.
-        label: a string with the label of the new line,
-        this helps with interpretability as commands sometimes lack necessary context
+    -------
+    tuple(jnp.array, str)
+        A slice of the `sol` object collapsed into the first dimension
+        of the selected command, together with a string label for the
+        new line, which helps with interpretability since commands
+        sometimes lack necessary context.
     """

     def is_close(x):
@@ -1381,21 +1521,20 @@
 def get_var_proportions(inferer, solution):
     """
-    Calculate _daily_ variant proportions based on a simulation run.
+    Calculate daily variant proportions based on a simulation run.

     Parameters
     ----------
-    `inferer` : AbstractParameters
-        an AbstractParameters (e.g., MechanisticInferer or StaticValueParameters) that
-        is used to produce `solution`.
-    `solution`: tuple(jnp.array)
-        solution object that comes out from an ODE run (specifically through
-        `diffrax.diffeqsolve`)
+    inferer : AbstractParameters
+        An AbstractParameters (e.g., MechanisticInferer or
+        StaticValueParameters) used to produce `solution`.
+    solution : diffrax.Solution
+        Solution object from an ODE run (specifically through `diffrax.diffeqsolve`).

     Returns
-    ----------
-    jnp.array:
-        an array of strain prevalence by the shape of (num_days, NUM_STRAINS)
+    -------
+    jnp.ndarray
+        Array of strain prevalence with shape (num_days, inferer.config.NUM_STRAINS).
     """
     strain_incidence = jnp.sum(
         solution.ys[inferer.config.COMPARTMENT_IDX.C],
@@ -1412,22 +1551,19 @@
 def get_seroprevalence(inferer, solution):
     """
-    Calculate the seroprevalence (more precisely the cumulative attack rate) based on
-    a simulation run.
+    Calculate the seroprevalence (cumulative attack rate) based on a simulation run.

     Parameters
     ----------
-    `inferer` : AbstractParameters
-        an AbstractParameters (e.g., MechanisticInferer or StaticValueParameters) that
-        is used to produce `solution`.
-    `solution`: tuple(jnp.array)
-        solution object that comes out from an ODE run (specifically through
-        `diffrax.diffeqsolve`)
+    inferer : AbstractParameters
+        An AbstractParameters (e.g., MechanisticInferer or StaticValueParameters)
+        used to produce `solution`.
+    solution : diffrax.Solution
+        Solution object from an ODE run (specifically through `diffrax.diffeqsolve`).

     Returns
-    ----------
-    jnp.array:
-        an array of seroprevalence by the shape of (num_days, NUM_AGE_GROUPS)
+    -------
+    jnp.ndarray
+        Array of seroprevalence with shape (num_days, NUM_AGE_GROUPS).
     """
     never_infected = jnp.sum(
         solution.ys[inferer.config.COMPARTMENT_IDX.S][:, :, 0, :, :],
@@ -1442,24 +1578,21 @@
 def get_foi_suscept(p, force_of_infection):
-    """
-    Calculate the force of infections experienced by the susceptibles, _after_
-    factoring their immunity.
+    """Calculate the force of infection experienced by susceptibles after factoring in their immunity.

     Parameters
     ----------
-    `p` : Parameters
-        a Parameters object which is a spoofed dictionary for easy referencing,
-        which is an output of `.get_parameters()` from AbstractParameter.
-    `force_of_infection`: jnp.array
-        an array of (NUM_AGE_GROUPS, NUM_STRAINS) that quantifies the force of
-        infection experienced by age group by strain.
+    p : Parameters
+        A Parameters object that is a spoofed dictionary for easy referencing,
+        output of `.get_parameters()` from AbstractParameter.
+    force_of_infection : jnp.ndarray
+        Array of shape (NUM_AGE_GROUPS, NUM_STRAINS) quantifying
+        the force of infection by age group and strain.

     Returns
-    ----------
-    jnp.array:
-        an array of immunity protection by the shape of (NUM_STRAINS, num_days,
-        NUM_AGE_GROUPS)
+    -------
+    list[jnp.ndarray]
+        Force of infection experienced by susceptibles in each immune state
+        after factoring in immunity, with one entry per strain.
     """
     foi_suscept = []
     for strain in range(p.NUM_STRAINS):
@@ -1499,9 +1632,10 @@
 def get_immunity(inferer, solution):
-    """
-    Calculate the age-strain-specific population immunity. Specifically, the expected
-    immunity of a randomly selected person of certain age towards certain strain.
+    """Calculate the age-strain-specific population immunity.
+
+    Specifically, the expected immunity of a randomly selected person of
+    a certain age towards a certain strain.

     Parameters
     ----------
@@ -1513,7 +1647,7 @@
         `diffrax.diffeqsolve`)

     Returns
-    ----------
+    -------
     jnp.array:
         an array of immunity protection by the shape of (NUM_STRAINS, num_days,
         NUM_AGE_GROUPS)
@@ -1537,8 +1671,7 @@
 def get_vaccination_rates(inferer, num_day):
-    """
-    Calculate _daily_ vaccination rates over the course of `num_day`.
+    """Calculate _daily_ vaccination rates over the course of `num_day`.

     Parameters
     ----------
@@ -1549,7 +1682,7 @@
         number of simulation days

     Returns
-    ----------
+    -------
     list:
         list of `num_day` vaccination rates arrays, each by the shape of (NUM_AGE_GROUPS,
         MAX_VACCINATION_COUNT + 1)
@@ -1562,16 +1695,13 @@

 # @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@

-def rho(M: np.ndarray) -> np.ndarray:
-    return np.max(np.real(np.linalg.eigvals(M)))
-
-
 def make_two_settings_matrices(
     path_to_population_data: str,
     path_to_settings_data: str,
     region: str = "United States",
 ) -> tuple[np.ndarray, np.ndarray, pd.DataFrame]:
-    """
+    """Load and parse settings contact matrices for a given region.
+ For a single region, read the two column (age, population counts) population csv (up to age 85) then read the 85 column interaction settings csvs by setting (four files) and combine them into an aggregate 85 x 85 matrix @@ -1664,7 +1794,8 @@ def create_age_grouped_CM( minimum_age: int, age_limits, ) -> tuple[np.ndarray, list[float]]: - """ + """Load a contact matrix and group it into age bins. + Parameters ---------- region_data : pd.DataFrame @@ -1672,6 +1803,13 @@ def create_age_grouped_CM( population sizes setting_CM : np.ndarray An 85x85 contact matrix for a given setting (either school or other) + num_age_groups : int + number of age bins. + minimum_age : int + lowest possible tracked age in years. + age_limits : list[int] + Age limit for each age bin in the model, beginning with minimum age, + values are exclusive in upper bound. so [0,18] means 0-17, 18+. Returns ------- @@ -1727,10 +1865,26 @@ def load_demographic_data( minimum_age, age_limits, ) -> dict[str, dict[str, np.ndarray]]: - """ - Loads demography data for the specified FIPS regions, contact mixing data sourced from: + """Load demography data for the specified FIPS regions. + + Contact mixing data sourced often from: https://github.com/mobs-lab/mixing-patterns + Parameters + ---------- + demographics_path : str + path to demographic data directory, contains "contact_matrices" and + "population_rescaled_age_distributions" directories. + regions : list[str] + list of FIPS regions to load. + num_age_groups : int + number of age bins. + minimum_age : int + lowest possible tracked age in years. + age_limits : list[int] + Age limit for each age bin in the model, beginning with minimum age, + values are exclusive in upper bound. so [0,18] means 0-17, 18+. + Returns ------- demographic_data : dict @@ -1790,8 +1944,8 @@ def load_demographic_data( # Save one of the two N_ages (they are the same) in a new N_age var N_age = N_age_sch # Rescale contact matrices by leading eigenvalue - avg_CM = avg_CM / rho(avg_CM) - sch_CM = sch_CM / rho(sch_CM) + avg_CM = avg_CM / np.max(np.real(np.linalg.eigvals(avg_CM))) + sch_CM = sch_CM / np.max(np.real(np.linalg.eigvals(sch_CM))) # Transform Other cm with the new age limits [NB: to transpose?] region_demographic_data_dict = { "sch_CM": sch_CM.T, @@ -1817,52 +1971,95 @@ class Parameters(object): """A dummy container that converts a dictionary into attributes.""" def __init__(self, dict: dict): + """Initialize an empty spoof parameters object. + + Parameters + ---------- + dict : dict + parameters and data for spoof class to hold. + """ self.__dict__ = dict class dual_logger_out(object): - """ - a class that splits stdout, flushing its contents to a file as well as to stdout - this is useful for Azure Batch to save logs but also see the output live on the node + """Split stdout, flushing its contents to a file as well as to stdout. + + Useful for experiments to save logs but also see the output live. """ def __init__(self, name, mode): + """Spoofs stdout __init__ but redirects flow to a file as well. + + Parameters + ---------- + name : str + File name to pipe output to. + mode : str + file open mode, usually "w" or "x". + """ self.file = open(name, mode) self.stdout = sys.stdout sys.stdout = self def close(self): + """Finish writing to file and direct stdout back to sys.stdout.""" sys.stdout = self.stdout self.file.close() def write(self, data): + """Write `data` to file and to sys.stdout. 
+
+        Parameters
+        ----------
+        data : str
+            Data to write to file and to sys.stdout.
+        """
         self.file.write(data)
         self.stdout.write(data)

     def flush(self):
+        """Flush file contents."""
         self.file.flush()

 class dual_logger_err(object):
-    """
-    a class that splits stderror, flushing its contents to a file as well as to terminal if an error occurs
-    this is useful for Azure Batch to save logs but also see the output live on the node
+    """Split stderr, flushing its contents to a file as well as to the terminal.
+
+    Useful for experiments to save logs but also see the output live.
     """

     def __init__(self, name, mode):
+        """Spoof stderr's __init__ but redirect flow to a file as well.
+
+        Parameters
+        ----------
+        name : str
+            File name to pipe output to.
+        mode : str
+            File open mode, usually "w" or "x".
+        """
         self.file = open(name, mode)
         self.stderr = sys.stderr
         sys.stderr = self

     def close(self):
+        """Finish writing to file and direct stderr back to sys.stderr."""
         sys.stderr = self.stderr
         self.file.close()

     def write(self, data):
+        """Write `data` to file and to sys.stderr.
+
+        Parameters
+        ----------
+        data : str
+            Data to write to file and to sys.stderr.
+        """
         self.file.write(data)
         self.stderr.write(data)

     def flush(self):
+        """Flush file contents."""
         self.file.flush()
@@ -1874,8 +2071,9 @@
 def find_files(
     directory: str, filename_contains: str, recursive=False
 ) -> list[str]:
-    """searched `directory` for any files with `filename_contains`,
-    optionally searched recrusively down from `directory`
+    """Search `directory` for any files with `filename_contains`.
+
+    Optionally search recursively down from `directory`.

     Parameters
     ----------
diff --git a/src/dynode/vis_utils.py b/src/dynode/vis_utils.py
index d96faf67..0b852b8a 100644
--- a/src/dynode/vis_utils.py
+++ b/src/dynode/vis_utils.py
@@ -1,4 +1,4 @@
-"""A series of utility functions for generating visualizations for the model"""
+"""A set of utility functions for generating visualizations for the model."""

 from typing import Any

@@ -19,6 +19,8 @@
 class VisualizationError(Exception):
+    """An exception class for Visualization Errors."""
+
     pass

@@ -87,51 +89,36 @@
         "seaborn-v0_8-colorblind",
     ],
 ) -> plt.Figure:
-    """Given a dataframe resembling the azure_visualizer_timeline csv,
-    if it exists, returns an overview figure. The figure will contain 1 column
-    per state in `timeseries_df["state"]` if the column exists. The
-    figure will contain one row per plot_type
+    """Generate an overview figure containing subplots for various model metrics.

     Parameters
     ----------
-    timeseries_df : pandas.DataFrame
-        a dataframe containing at least the following columns:
-        ["date", "chain_particle", "state"] followed by columns identifying
-        different timeseries of interest to be plotted.
-        E.g. vaccination_0_17, vaccination_18_49, total_infection_incidence.
-        columns that share the same plot_type will be plotted on the same plot,
-        with their differences in the legend.
-        All chain_particle replicates are plotted as low
-        opacity lines for each plot_type
+    timeseries_df : pd.DataFrame
+        DataFrame containing at least ["date", "chain_particle", "state"]
+        followed by columns for the different time series to be plotted.
+
     pop_sizes : dict[str, int]
-        population sizes of each state as a dictionary.
-        Keys must match the "state" column within timeseries_df
+        Population sizes for each state as a dictionary. Keys must match
+        the values in the "state" column of `timeseries_df`.
+ plot_types : np.ndarray[str], optional - each of the plot types to be plotted. - plot_types not found in `timeseries_df` are skipped. - columns are identified using the "in" operation, - so plot_type must be found in each of its identified columns - by default ["seasonality_coef", "vaccination_", - "_external_introductions", "_strain_proportion", "_average_immunity", - "total_infection_incidence", "pred_hosp_"] + Types of plots to be generated. + Elements not found in `timeseries_df` are skipped. + plot_titles : np.ndarray[str], optional - titles for each plot_type as displayed on each subplot, - by default [ "Seasonality Coefficient", "Vaccination Rate By Age", - "External Introductions by Strain (per 100k)", - "Strain Proportion of New Infections", - "Average Population Immunity Against Strains", - "Total Infection Incidence (per 100k)", - "Predicted Hospitalizations (per 100k)"] + Titles for each subplot corresponding to `plot_types`. + plot_normalizations : np.ndarray[int] - normalization factor for each plot type + Normalization factors for each plot type. + matplotlib_style: list[str] | str - matplotlib style to plot in, by default ["seaborn-v0_8-colorblind"] + Matplotlib style to use for plotting. Returns ------- - matplotlib.pyplot.Figure - matplotlib Figure containing subplots with a column for each state - and a row for each plot_type + plt.Figure + Matplotlib Figure containing subplots with one column per state + and one row per plot type. """ necessary_cols = ["date", "chain_particle", "state"] assert all( @@ -269,36 +256,27 @@ def plot_checkpoint_inference_correlation_pairs( "seaborn-v0_8-colorblind", ], ): - """Given a dictionary mapping a sampled parameter's name to its - posteriors samples, returns a figure plotting - the correlation of each sampled parameter with all other sampled parameters - on the upper half of the plot the correlation values, on the diagonal a - historgram of the posterior values, and on the bottom half a scatter - plot of the parameters against eachother along with a matching trend line. - + """Plot correlation pairs of sampled parameters with histograms and trend lines. Parameters ---------- - posteriors_in: dict[str , np.ndarray | list] - a dictionary (usually loaded from the checkpoint.json file) containing - the sampled posteriors for each chain in the shape - (num_chains, num_samples). All parameters generated with numpyro.plate - and thus have a third dimension (num_chains, num_samples, num_plates) - are flattened to the desired shape and displayed as - separate parameters with _i suffix for each i in num_plates. - max_samples_calculated: int - a max cap of posterior samples per chain on which - calculations such as correlations and plotting will be performed - set for efficiency of plot generation, - set to -1 to disable cap, by default 100 - matplotlib_style: list[str] | str - matplotlib style to plot in, by default ["seaborn-v0_8-colorblind"] + posteriors_in : dict[str, np.ndarray | list] + Dictionary mapping parameter names to their posterior samples + (shape: num_chains, num_samples). Parameters generated with + numpyro.plate are flattened and displayed as separate parameters + with _i suffix for each i in num_plates. + + max_samples_calculated : int + Maximum number of posterior samples per chain for calculations + such as correlations and plotting. Set to -1 to disable cap; default is 100. + + matplotlib_style : list[str] | str + Matplotlib style to use for plotting; default is ["seaborn-v0_8-colorblind"]. 
     Returns
     -------
-    matplotlib.pyplot.Figure
-        matplotlib figure containing the plots
+    plt.Figure
+        Matplotlib figure containing the plots.
     """
     # Determine the number of parameters and chains
     samples: dict[str, np.ndarray] = flatten_list_parameters(
@@ -485,30 +459,30 @@
 def plot_prior_distributions(
     priors,
     matplotlib_style: list[str] | str = [
         "seaborn-v0_8-colorblind",
     ],
     num_samples=5000,
     hist_kwargs={"bins": 50, "density": True},
 ) -> plt.Figure:
-    """Given a dictionary of parameter keys and possibly values of
-    numpyro.distribution objects, samples them a number of times
-    and returns a plot of those samples to help
-    visualize the range of values taken by that prior distribution.
+    """Visualize prior distributions by sampling from them and plotting the results.

     Parameters
     ----------
     priors : dict[str, Any]
-        a dictionary with str keys possibly containing distribution
-        objects as values. Each key with a distribution object type
-        key will be included in the plot
+        Dictionary with string keys and distribution
+        objects as values. Each key with a distribution object will be
+        included in the plot.
+
     matplotlib_style : list[str] | str, optional
-        matplotlib style to plot in by default ["seaborn-v0_8-colorblind"]
+        Matplotlib style to use for plotting;
+        default is ["seaborn-v0_8-colorblind"].
+
     num_samples : int, optional
-        the number of times to sample each distribution, mild impact on
-        figure performance. By default 50000
+        Number of times to sample each distribution, with a mild impact on
+        figure performance; default is 5000.
+
-    hist_kwargs: dict[str: Any]
-        additional kwargs passed to plt.hist(), by default {"bins": 50}
+    hist_kwargs : dict[str, Any]
+        Additional kwargs passed to `plt.hist()`;
+        default is {"bins": 50, "density": True}.

     Returns
     -------
     plt.Figure
-        matplotlib figure that is roughly square containing all distribution
-        keys found within priors.
+        Matplotlib figure containing all distribution keys found within `priors`.
     """
     dist_only = {}
     d = identify_distribution_indexes(priors)