Skip to content

Commit

Permalink
comment concise abstract_parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
arik-shurygin committed Jan 13, 2025
1 parent 94215e9 commit 57f52cb
Showing 1 changed file with 107 additions and 117 deletions.
224 changes: 107 additions & 117 deletions src/dynode/abstract_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,22 +74,22 @@ def __init__(self) -> None:
def _solve_runner(
self, parameters: dict, tf: int, runner: MechanisticRunner
) -> Solution:
"""runs the runner (ode-solver) for `tf` days using
parameters defined in `parameters` returning a Diffrax Solution object
"""
Run the ODE solver for a specified number of days using given parameters.
Parameters
----------
parameters : dict
parameters object containing parameters required by the runner ODEs
Dictionary containing parameters required by the runner's ODEs.
tf : int
number of days to run the runner for
Number of days to run the ODE solver.
runner : MechanisticRunner
runner class designated with solving ODEs
Instance of MechanisticRunner designated for solving ODEs.
Returns
-------
Solution
diffrax solution object returned from `runner.run()`
Diffrax solution object returned from `runner.run()`.
"""
if "INITIAL_INFECTIONS_SCALE" in parameters.keys():
initial_state = self.scale_initial_infections(
Expand All @@ -106,19 +106,19 @@ def _solve_runner(

def _get_upstream_parameters(self) -> dict:
"""
returns a dictionary containing self.UPSTREAM_PARAMETERS,
erroring if any of the parameters are not in self.config.
Samples any parameters which are of type(numpyro.distribution).
Retrieve upstream parameters from the configuration, sampling any distributions.
Returns
------------
dict[str: Any]
returns a dictionary where keys map to parameters within
self.UPSTREAM_PARAMETERS and the values are the value of that
parameter within self.config. numpyro.distribution objects are sampled
and replaced with a jax ArrayLike sample from that distribution.
-------
dict[str, Any]
Dictionary mapping keys to parameters within `self.UPSTREAM_PARAMETERS`.
Values are taken from `self.config`, with numpyro.distribution objects
sampled and replaced with JAX ArrayLike samples.
Raises
------
RuntimeError
If any parameter in `self.UPSTREAM_PARAMETERS` is not found in `self.config`.
"""
# multiple chains of MCMC calling get_parameters()
# should not share references, deep copy, GH issue for this created
Expand All @@ -139,24 +139,24 @@ def _get_upstream_parameters(self) -> dict:
return parameters

def generate_downstream_parameters(self, parameters: dict) -> dict:
"""takes an existing parameters object and attempts to generate
downstream dependent parameters, based on the values contained within
`parameters`.
Raises RuntimeError if a downstream parameter
does not find the necessary values it needs within `parameters`
"""
Generate downstream dependent parameters based on upstream values.
Parameters
----------
parameters : dict
parameters dictionary generated by `self._get_upstream_parameters()`
containing static or sampled values on which
downstream parameters may depend
Dictionary generated by `self._get_upstream_parameters()` containing
static or sampled values that downstream parameters may depend on.
Returns
-------
dict
an appended onto version of `parameters` with additional downstream parameters added.
Updated version of `parameters` with additional downstream parameters added.
Raises
------
RuntimeError
If a downstream parameter cannot find the necessary upstream values within `parameters`.
"""
try:
# create or re-recreate parameters based on other possibly sampled parameters
Expand Down Expand Up @@ -207,18 +207,15 @@ def generate_downstream_parameters(self, parameters: dict) -> dict:

def get_parameters(self) -> dict:
"""
Goes through parameters listed in self.UPSTREAM_PARAMETERS,
sampling them if they are distributions.
Then generates any downstream parameters that rely on those parameters
in self.generate_downstream_parameters().
Returning the resulting dictionary for use in ODEs
Retrieve parameters by sampling upstream distributions and
generating downstream parameters.
Returns
-----------
dict{str, Any}
dict containing a combination of `self.UPSTREAM_PARAMETERS` found
in `self.config` and downstream parameters from
`self.generate_downstream_parameters()`
-------
dict[str, Any]
Dictionary containing a combination of `self.UPSTREAM_PARAMETERS` found
in `self.config` and downstream parameters generated from
`self.generate_downstream_parameters()`.
"""
parameters = self._get_upstream_parameters()
parameters = self.generate_downstream_parameters(parameters)
Expand All @@ -234,54 +231,37 @@ def external_i(
introduction_pcts: jax.Array,
) -> jax.Array:
"""
Given some time t, returns jnp.array of shape
`self.INITIAL_STATE[self.config.COMPARTMENT_IDX.I]` representing
external infected persons interacting with the population.
CONTINUOUS AND DIFFERENTIABLE FOR ALL TIMES t.
The stratafication of the external population is decided by the
introduced strains, which are defined by 3 parallel lists.
peak intro rate date (`introduction_times`),
how quickly or slowly those individuals contact the tracked population
(`introduction_scales`), and the magnitude of external
infected individuals introduced as a % of the tracked population
(`introduction_pcts`)
Calculate the number of external infected individuals interacting
with the population at time t.
Parameters
----------
`t`: ArrayLike
current time in the model.
t : ArrayLike
Current time in the model.
`introduction_times`: jax.Array
a list representing the times at which external strains should peak
in their rate of external introduction.
if `len(introduction_times) < len(self.config.STRAIN_R0s)` earlier
strains are not introduced.
introduction_times : jax.Array
List representing times at which external strains should peak
in their rate of introduction.
`introduction_scales`:jax.Array
a list representing the standard deviation of the
curve that external strains are introduced with, in days
introduction_scales : jax.Array
List representing the standard deviation of the curve for introducing
external strains, in days.
`introduction_pcts`: jax.Array
a list representing the proportion of each age bin in
self.POPULATION[self.config.INTRODUCTION_AGE_MASK] that will be
exposed to the introduced strain over the whole curve
introduction_pcts : jax.Array
List representing the proportion of each age bin in
`self.POPULATION[self.config.INTRODUCTION_AGE_MASK]`
that will be exposed to the introduced strain over the entire curve.
Returns
-----------
-------
jax.Array
jnp.array(shape=(self.INITIAL_STATE[self.config.COMPARTMENT_IDX.I].shape))
of external individuals to the system interacting with tracked
susceptibles within the system, used to impact force of infection.
An array of shape matching `self.INITIAL_STATE[self.config.COMPARTMENT_IDX.I]`
representing external individuals interacting with tracked susceptibles.
Note
-----------
use the boolean list `self.config.INTRODUCTION_AGE_MASK`
to select which age bins the external populations will have.
External populations are not tracked, but still interact with
the contact matrix, meaning the age of external "travelers"
is still a factor in the spread of a new strain.
Notes
-----
Use `self.config.INTRODUCTION_AGE_MASK` to select which age bins are affected by external populations.
External populations are not tracked but still interact with the contact matrix, influencing spread dynamics.
"""

# define a function that returns 0 for non-introduced strains
Expand Down Expand Up @@ -335,31 +315,26 @@ def zero_function(_):
@partial(jax.jit, static_argnums=(0))
def vaccination_rate(self, t: ArrayLike) -> jax.Array:
"""
Given some time t, returns a jnp.array of shape
(self.config.NUM_AGE_GROUPS, self.config.MAX_VACCINATION_COUNT + 1)
representing the instantaneous age / vax history stratified vaccination
rates for an additional vaccine.
CONTINUOUS AND DIFFERENTIABLE FOR ALL TIMES `t`.
Calculate the instantaneous vaccination rates stratified by age and vaccination history.
Parameters
----------
t: ArrayLike
current time in the model.
t : ArrayLike
Current time in the model.
Returns
-----------
vaccination_rates: jnp.Array
`shape=(self.config.NUM_AGE_GROUPS, self.config.MAX_VACCINATION_COUNT + 1) `
of vaccination rates for each age bin and vax history strata.
-------
jax.Array
An array of shape (self.config.NUM_AGE_GROUPS, self.config.MAX_VACCINATION_COUNT + 1)
representing vaccination rates for each age bin and vaccination history strata.
Note
-----------
use `self.config.VACCINATION_MODEL_DAYS_SHIFT` param to shift t=0 for
specifically the vaccination_rates() function and not the whole model.
Notes
-----
Use `self.config.VACCINATION_MODEL_DAYS_SHIFT` to adjust t=0 specifically for this function.
Refer to `load_vaccination_model` for details on spline definitions and loading.
see `load_vaccination_model` for description on spline definitions and
loading.
The function is continuous and differentiable for all times `t`.
"""
# shifting splines if needed for multi-epochs, 0 by default
t_added = getattr(self.config, "VACCINATION_MODEL_DAYS_SHIFT", 0)
Expand All @@ -384,36 +359,44 @@ def vaccination_rate(self, t: ArrayLike) -> jax.Array:

@partial(jax.jit, static_argnums=(0))
def beta_coef(self, t: ArrayLike) -> ArrayLike:
"""Mechanism to directly influence transmission rate to account for
external factors. Defaults to 1.0. Modified via the `BETA_TIMES` and
`BETA_COEFICIENTS` config parameters. See Example for
behavior.
"""
Calculate the coefficient to modify the transmission rate based on external factors.
Parameters
----------
t: ArrayLike
current time in the model.
t : ArrayLike
Current time in the model.
Returns
----------
ArrayLike
Coefficient by which BETA can be multiplied to externally
increase or decrease the value.
-------
ArrayLike
Coefficient by which BETA can be multiplied to
externally increase or decrease its value.
Example
----------
multiple values of `t` passed in at once to save space.
-------
Multiple values of `t` being passed at once for this example only.
>>> self.config.BETA_COEFICIENTS
jnp.array([-1.0, 0.0, 1.0])
>>> self.config.BETA_TIMES
jnp.array([25, 50])
>>> self.beta_coef(t=[0, 24])
[-1. -1.]
>>> self.beta_coef(t=[25, 26, 49])
[0. 0. 0.]
>>> self.beta_coef(t=[50, 51, 99])
[1. 1. 1.]
Notes
-----
The function defaults to a coefficient of 1.0 if no
modifications are specified. It uses `BETA_TIMES` and
`BETA_COEFICIENTS` from the configuration for adjustments.
"""
# a smart lookup function that works with JAX just in time compilation
# if t > self.config.BETA_TIMES_i, return self.config.BETA_COEFICIENTS_i
Expand Down Expand Up @@ -740,24 +723,31 @@ def scale_initial_infections(
self, scale_factor: ArrayLike
) -> SEIC_Compartments:
"""
a function which modifies returns a modified version of
self.INITIAL_STATE scaling the number of initial infections by `scale_factor`.
Scale the number of initial infections by a specified factor.
Preserves the ratio of the Exposed/Infectious compartment population sizes.
Does not modified self.INITIAL_STATE, returns a copy.
This function returns a modified version of `self.INITIAL_STATE`,
scaling the number of initial infections while preserving the ratio
between the Exposed and Infectious compartments. The original
`self.INITIAL_STATE` remains unchanged.
Parameters
----------
scale_factor: float
a multiplier value >=0.0.
`scale_factor < 1` reduces number of initial infections,
`scale_factor == 1.0` leaves initial infections unchanged,
`scale_factor > 1` increases number of initial infections.
scale_factor : float
A multiplier value >= 0.0.
- `scale_factor < 1`: Reduces the number of initial infections.
- `scale_factor == 1.0`: Leaves initial infections unchanged.
- `scale_factor > 1`: Increases the number of initial infections.
Returns
---------
A copy of `self.INITIAL_INFECTIONS` with each compartment
scaled according to `scale_factor`
-------
SEIC_Compartments
A copy of `self.INITIAL_STATE` with each compartment
scaled up or down depending on `scale_factor`.
Notes
-----
The function ensures that the relative sizes of
Exposed and Infectious compartments are preserved during scaling.
"""
pop_counts_by_compartment = jnp.array(
[
Expand Down

0 comments on commit 57f52cb

Please sign in to comment.