Skip to content

Commit

Permalink
Reworks InitializeInfectionsExponentialGrowth to work with `numpyro…
Browse files Browse the repository at this point in the history
….plate` (#432)
  • Loading branch information
damonbayer authored Sep 6, 2024
1 parent adf5aaf commit 3c5fbe7
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 55 deletions.
22 changes: 8 additions & 14 deletions pyrenew/latent/infection_initialization_method.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ class InitializeInfectionsExponentialGrowth(InfectionInitializationMethod):
def __init__(
self,
n_timepoints: int,
rate: RandomVariable,
rate_rv: RandomVariable,
t_pre_init: int | None = None,
):
"""Default constructor for the ``InitializeInfectionsExponentialGrowth`` class.
Expand All @@ -152,13 +152,13 @@ def __init__(
----------
n_timepoints : int
the number of time points to generate initial infections for
rate : RandomVariable
rate_rv : RandomVariable
A random variable representing the rate of exponential growth
t_pre_init : int | None, optional
The time point whose number of infections is described by ``I_pre_init``. Defaults to ``n_timepoints - 1``.
"""
super().__init__(n_timepoints)
self.rate = rate
self.rate_rv = rate_rv
if t_pre_init is None:
t_pre_init = n_timepoints - 1
self.t_pre_init = t_pre_init
Expand All @@ -177,15 +177,9 @@ def initialize_infections(self, I_pre_init: ArrayLike):
An array of length ``n_timepoints`` with the number of initialized infections at each time point.
"""
I_pre_init = jnp.array(I_pre_init)
if I_pre_init.size != 1:
raise ValueError(
f"I_pre_init must be an array of size 1. Got size {I_pre_init.size}."
)
rate = jnp.array(self.rate()[0].value)
if rate.size != 1:
raise ValueError(
f"rate must be an array of size 1. Got size {rate.size}."
)
return I_pre_init * jnp.exp(
rate * (jnp.arange(self.n_timepoints) - self.t_pre_init)
rate = jnp.array(self.rate_rv()[0].value)
initial_infections = I_pre_init * jnp.exp(
rate
* (jnp.arange(self.n_timepoints)[:, jnp.newaxis] - self.t_pre_init)
)
return jnp.squeeze(initial_infections)
134 changes: 93 additions & 41 deletions test/test_infection_initialization_method.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,64 +14,116 @@
def test_initialize_infections_exponential():
"""Check that the InitializeInfectionsExponentialGrowth class generates the correct number of infections at each time point."""
n_timepoints = 10
rate_RV = DeterministicVariable(name="rate_RV", value=0.5)
I_pre_init_RV = DeterministicVariable(name="I_pre_init_RV", value=10.0)
default_t_pre_init = n_timepoints - 1
t_pre_init = 6

(I_pre_init,) = I_pre_init_RV()
(rate,) = rate_RV()
rate_RV = DeterministicVariable(name="rate_RV", value=np.array([0.5, 0.1]))
rate_scalar_RV = DeterministicVariable(name="rate_RV", value=0.5)

I_pre_init = I_pre_init.value
rate = rate.value
infections_default_t_pre_init = InitializeInfectionsExponentialGrowth(
n_timepoints, rate=rate_RV
rate = rate_RV()[0].value
rate_scalar = rate_scalar_RV()[0].value

I_pre_init = np.array([5.0, 10.0])
I_pre_init_scalar = 5.0

# both rate and I_pre_init are arrays with default t_pre_init
result = InitializeInfectionsExponentialGrowth(
n_timepoints, rate_rv=rate_RV
).initialize_infections(I_pre_init)
infections_default_t_pre_init_manual = I_pre_init * np.exp(
rate * (np.arange(n_timepoints) - default_t_pre_init)
)

testing.assert_array_almost_equal(
infections_default_t_pre_init, infections_default_t_pre_init_manual
manual_result = np.column_stack(
[
I_pre_init[0]
* np.exp(rate[0] * (np.arange(n_timepoints) - default_t_pre_init)),
I_pre_init[1]
* np.exp(rate[1] * (np.arange(n_timepoints) - default_t_pre_init)),
]
)

# assert that infections at default t_pre_init is I_pre_init
assert infections_default_t_pre_init[default_t_pre_init] == I_pre_init
## check that the result is as expected
testing.assert_array_almost_equal(result, manual_result)

# test for failure with non-scalar rate or I_pre_init
rate_RV_2 = DeterministicVariable(
name="rate_RV", value=np.array([0.5, 0.5])
## check that infections at default t_pre_init is I_pre_init
testing.assert_array_equal(result[default_t_pre_init], I_pre_init)

# both rate and I_pre_init are arrays with custom t_pre_init
result = InitializeInfectionsExponentialGrowth(
n_timepoints, rate_rv=rate_RV, t_pre_init=t_pre_init
).initialize_infections(I_pre_init)

manual_result = np.column_stack(
[
I_pre_init[0]
* np.exp(rate[0] * (np.arange(n_timepoints) - t_pre_init)),
I_pre_init[1]
* np.exp(rate[1] * (np.arange(n_timepoints) - t_pre_init)),
]
)
with pytest.raises(ValueError):
InitializeInfectionsExponentialGrowth(
n_timepoints, rate=rate_RV_2
).initialize_infections(I_pre_init)

I_pre_init_RV_2 = DeterministicVariable(
name="I_pre_init_RV",
value=np.array([10.0, 10.0]),
## check that the result is as expected
testing.assert_array_almost_equal(result, manual_result)

## check that infections at t_pre_init is I_pre_init
testing.assert_array_equal(result[t_pre_init], I_pre_init)

# rate is array, I_pre_init is scalar with default t_pre_init
result = InitializeInfectionsExponentialGrowth(
n_timepoints, rate_rv=rate_RV
).initialize_infections(I_pre_init_scalar)

manual_result = np.column_stack(
[
I_pre_init_scalar
* np.exp(rate[0] * (np.arange(n_timepoints) - default_t_pre_init)),
I_pre_init_scalar
* np.exp(rate[1] * (np.arange(n_timepoints) - default_t_pre_init)),
]
)
(I_pre_init_2,) = I_pre_init_RV_2()

with pytest.raises(ValueError):
InitializeInfectionsExponentialGrowth(
n_timepoints, rate=rate_RV
).initialize_infections(I_pre_init_2.value)
## check that the result is as expected
testing.assert_array_almost_equal(result, manual_result)

# test non-default t_pre_init
t_pre_init = 6
infections_custom_t_pre_init = InitializeInfectionsExponentialGrowth(
n_timepoints, rate=rate_RV, t_pre_init=t_pre_init
## check that infections at default t_pre_init is I_pre_init
testing.assert_array_equal(result[default_t_pre_init], I_pre_init_scalar)

# rate is scalar, I_pre_init is array with default t_pre_init
result = InitializeInfectionsExponentialGrowth(
n_timepoints, rate_rv=rate_scalar_RV
).initialize_infections(I_pre_init)
infections_custom_t_pre_init_manual = I_pre_init * np.exp(
rate * (np.arange(n_timepoints) - t_pre_init)

manual_result = np.column_stack(
[
I_pre_init[0]
* np.exp(
rate_scalar * (np.arange(n_timepoints) - default_t_pre_init)
),
I_pre_init[1]
* np.exp(
rate_scalar * (np.arange(n_timepoints) - default_t_pre_init)
),
]
)
testing.assert_array_almost_equal(
infections_custom_t_pre_init,
infections_custom_t_pre_init_manual,
decimal=5,

## check that the result is as expected
testing.assert_array_almost_equal(result, manual_result)

## check that infections at default t_pre_init is I_pre_init
testing.assert_array_equal(result[default_t_pre_init], I_pre_init)

# both rate and I_pre_init are scalar with default t_pre_init
result = InitializeInfectionsExponentialGrowth(
n_timepoints, rate_rv=rate_scalar_RV
).initialize_infections(I_pre_init_scalar)

manual_result = I_pre_init_scalar * np.exp(
rate_scalar * (np.arange(n_timepoints) - default_t_pre_init)
)

assert infections_custom_t_pre_init[t_pre_init] == I_pre_init
## check that the result is as expected
testing.assert_array_almost_equal(result, manual_result)

## check that infections at default t_pre_init is I_pre_init
testing.assert_array_equal(result[default_t_pre_init], I_pre_init_scalar)


def test_initialize_infections_zero_pad():
Expand Down

0 comments on commit 3c5fbe7

Please sign in to comment.