From 3c5fbe72b611cfbfa0c3e9da6f0199ab51213717 Mon Sep 17 00:00:00 2001 From: Damon Bayer Date: Fri, 6 Sep 2024 15:48:12 -0500 Subject: [PATCH] Reworks `InitializeInfectionsExponentialGrowth` to work with `numpyro.plate` (#432) --- .../latent/infection_initialization_method.py | 22 ++- test/test_infection_initialization_method.py | 134 ++++++++++++------ 2 files changed, 101 insertions(+), 55 deletions(-) diff --git a/pyrenew/latent/infection_initialization_method.py b/pyrenew/latent/infection_initialization_method.py index f4a0c0e7..67d25fb3 100644 --- a/pyrenew/latent/infection_initialization_method.py +++ b/pyrenew/latent/infection_initialization_method.py @@ -143,7 +143,7 @@ class InitializeInfectionsExponentialGrowth(InfectionInitializationMethod): def __init__( self, n_timepoints: int, - rate: RandomVariable, + rate_rv: RandomVariable, t_pre_init: int | None = None, ): """Default constructor for the ``InitializeInfectionsExponentialGrowth`` class. @@ -152,13 +152,13 @@ def __init__( ---------- n_timepoints : int the number of time points to generate initial infections for - rate : RandomVariable + rate_rv : RandomVariable A random variable representing the rate of exponential growth t_pre_init : int | None, optional The time point whose number of infections is described by ``I_pre_init``. Defaults to ``n_timepoints - 1``. """ super().__init__(n_timepoints) - self.rate = rate + self.rate_rv = rate_rv if t_pre_init is None: t_pre_init = n_timepoints - 1 self.t_pre_init = t_pre_init @@ -177,15 +177,9 @@ def initialize_infections(self, I_pre_init: ArrayLike): An array of length ``n_timepoints`` with the number of initialized infections at each time point. """ I_pre_init = jnp.array(I_pre_init) - if I_pre_init.size != 1: - raise ValueError( - f"I_pre_init must be an array of size 1. Got size {I_pre_init.size}." - ) - rate = jnp.array(self.rate()[0].value) - if rate.size != 1: - raise ValueError( - f"rate must be an array of size 1. Got size {rate.size}." - ) - return I_pre_init * jnp.exp( - rate * (jnp.arange(self.n_timepoints) - self.t_pre_init) + rate = jnp.array(self.rate_rv()[0].value) + initial_infections = I_pre_init * jnp.exp( + rate + * (jnp.arange(self.n_timepoints)[:, jnp.newaxis] - self.t_pre_init) ) + return jnp.squeeze(initial_infections) diff --git a/test/test_infection_initialization_method.py b/test/test_infection_initialization_method.py index 4ca77510..ce9796f5 100644 --- a/test/test_infection_initialization_method.py +++ b/test/test_infection_initialization_method.py @@ -14,64 +14,116 @@ def test_initialize_infections_exponential(): """Check that the InitializeInfectionsExponentialGrowth class generates the correct number of infections at each time point.""" n_timepoints = 10 - rate_RV = DeterministicVariable(name="rate_RV", value=0.5) - I_pre_init_RV = DeterministicVariable(name="I_pre_init_RV", value=10.0) default_t_pre_init = n_timepoints - 1 + t_pre_init = 6 - (I_pre_init,) = I_pre_init_RV() - (rate,) = rate_RV() + rate_RV = DeterministicVariable(name="rate_RV", value=np.array([0.5, 0.1])) + rate_scalar_RV = DeterministicVariable(name="rate_RV", value=0.5) - I_pre_init = I_pre_init.value - rate = rate.value - infections_default_t_pre_init = InitializeInfectionsExponentialGrowth( - n_timepoints, rate=rate_RV + rate = rate_RV()[0].value + rate_scalar = rate_scalar_RV()[0].value + + I_pre_init = np.array([5.0, 10.0]) + I_pre_init_scalar = 5.0 + + # both rate and I_pre_init are arrays with default t_pre_init + result = InitializeInfectionsExponentialGrowth( + n_timepoints, rate_rv=rate_RV ).initialize_infections(I_pre_init) - infections_default_t_pre_init_manual = I_pre_init * np.exp( - rate * (np.arange(n_timepoints) - default_t_pre_init) - ) - testing.assert_array_almost_equal( - infections_default_t_pre_init, infections_default_t_pre_init_manual + manual_result = np.column_stack( + [ + I_pre_init[0] + * np.exp(rate[0] * (np.arange(n_timepoints) - default_t_pre_init)), + I_pre_init[1] + * np.exp(rate[1] * (np.arange(n_timepoints) - default_t_pre_init)), + ] ) - # assert that infections at default t_pre_init is I_pre_init - assert infections_default_t_pre_init[default_t_pre_init] == I_pre_init + ## check that the result is as expected + testing.assert_array_almost_equal(result, manual_result) - # test for failure with non-scalar rate or I_pre_init - rate_RV_2 = DeterministicVariable( - name="rate_RV", value=np.array([0.5, 0.5]) + ## check that infections at default t_pre_init is I_pre_init + testing.assert_array_equal(result[default_t_pre_init], I_pre_init) + + # both rate and I_pre_init are arrays with custom t_pre_init + result = InitializeInfectionsExponentialGrowth( + n_timepoints, rate_rv=rate_RV, t_pre_init=t_pre_init + ).initialize_infections(I_pre_init) + + manual_result = np.column_stack( + [ + I_pre_init[0] + * np.exp(rate[0] * (np.arange(n_timepoints) - t_pre_init)), + I_pre_init[1] + * np.exp(rate[1] * (np.arange(n_timepoints) - t_pre_init)), + ] ) - with pytest.raises(ValueError): - InitializeInfectionsExponentialGrowth( - n_timepoints, rate=rate_RV_2 - ).initialize_infections(I_pre_init) - I_pre_init_RV_2 = DeterministicVariable( - name="I_pre_init_RV", - value=np.array([10.0, 10.0]), + ## check that the result is as expected + testing.assert_array_almost_equal(result, manual_result) + + ## check that infections at t_pre_init is I_pre_init + testing.assert_array_equal(result[t_pre_init], I_pre_init) + + # rate is array, I_pre_init is scalar with default t_pre_init + result = InitializeInfectionsExponentialGrowth( + n_timepoints, rate_rv=rate_RV + ).initialize_infections(I_pre_init_scalar) + + manual_result = np.column_stack( + [ + I_pre_init_scalar + * np.exp(rate[0] * (np.arange(n_timepoints) - default_t_pre_init)), + I_pre_init_scalar + * np.exp(rate[1] * (np.arange(n_timepoints) - default_t_pre_init)), + ] ) - (I_pre_init_2,) = I_pre_init_RV_2() - with pytest.raises(ValueError): - InitializeInfectionsExponentialGrowth( - n_timepoints, rate=rate_RV - ).initialize_infections(I_pre_init_2.value) + ## check that the result is as expected + testing.assert_array_almost_equal(result, manual_result) - # test non-default t_pre_init - t_pre_init = 6 - infections_custom_t_pre_init = InitializeInfectionsExponentialGrowth( - n_timepoints, rate=rate_RV, t_pre_init=t_pre_init + ## check that infections at default t_pre_init is I_pre_init + testing.assert_array_equal(result[default_t_pre_init], I_pre_init_scalar) + + # rate is scalar, I_pre_init is array with default t_pre_init + result = InitializeInfectionsExponentialGrowth( + n_timepoints, rate_rv=rate_scalar_RV ).initialize_infections(I_pre_init) - infections_custom_t_pre_init_manual = I_pre_init * np.exp( - rate * (np.arange(n_timepoints) - t_pre_init) + + manual_result = np.column_stack( + [ + I_pre_init[0] + * np.exp( + rate_scalar * (np.arange(n_timepoints) - default_t_pre_init) + ), + I_pre_init[1] + * np.exp( + rate_scalar * (np.arange(n_timepoints) - default_t_pre_init) + ), + ] ) - testing.assert_array_almost_equal( - infections_custom_t_pre_init, - infections_custom_t_pre_init_manual, - decimal=5, + + ## check that the result is as expected + testing.assert_array_almost_equal(result, manual_result) + + ## check that infections at default t_pre_init is I_pre_init + testing.assert_array_equal(result[default_t_pre_init], I_pre_init) + + # both rate and I_pre_init are scalar with default t_pre_init + result = InitializeInfectionsExponentialGrowth( + n_timepoints, rate_rv=rate_scalar_RV + ).initialize_infections(I_pre_init_scalar) + + manual_result = I_pre_init_scalar * np.exp( + rate_scalar * (np.arange(n_timepoints) - default_t_pre_init) ) - assert infections_custom_t_pre_init[t_pre_init] == I_pre_init + ## check that the result is as expected + testing.assert_array_almost_equal(result, manual_result) + + ## check that infections at default t_pre_init is I_pre_init + testing.assert_array_equal(result[default_t_pre_init], I_pre_init_scalar) def test_initialize_infections_zero_pad():