diff --git a/docs/source/msei_reference/latent.rst b/docs/source/msei_reference/latent.rst index e6aa0859..eb1f9b79 100644 --- a/docs/source/msei_reference/latent.rst +++ b/docs/source/msei_reference/latent.rst @@ -25,18 +25,18 @@ Infection Functions :undoc-members: :show-inheritance: -Infection Seeding Process +Infection Initialization Process ------------------------- -.. automodule:: pyrenew.latent.infection_seeding_process +.. automodule:: pyrenew.latent.infection_initialization_process :members: :undoc-members: :show-inheritance: -Infection Seeding Method +Infection Initialization Method ------------------------ -.. automodule:: pyrenew.latent.infection_seeding_method +.. automodule:: pyrenew.latent.infection_initialization_method :members: :undoc-members: :show-inheritance: diff --git a/docs/source/tutorials/basic_renewal_model.qmd b/docs/source/tutorials/basic_renewal_model.qmd index c58eec3c..49cdba88 100644 --- a/docs/source/tutorials/basic_renewal_model.qmd +++ b/docs/source/tutorials/basic_renewal_model.qmd @@ -19,8 +19,8 @@ import numpyro.distributions as dist from pyrenew.process import RtRandomWalkProcess from pyrenew.latent import ( Infections, - InfectionSeedingProcess, - SeedInfectionsZeroPad, + InfectionInitializationProcess, + InitializeInfectionsZeroPad, ) from pyrenew.observation import PoissonObservation from pyrenew.deterministic import DeterministicPMF @@ -99,7 +99,7 @@ To initialize these five components within the renewal modeling framework, we es (1) In this example, the generation interval is not estimated but passed as a deterministic instance of `RandomVariable` -(2) an instance of the `InfectionSeedingProcess` class, where the number of latent infections immediately before the renewal process begins follows a log-normal distribution with mean = 0 and standard deviation = 1. By specifying `SeedInfectionsZeroPad`, the latent infections before this time are assumed to be 0. +(2) an instance of the `InfectionInitializationProcess` class, where the number of latent infections immediately before the renewal process begins follows a log-normal distribution with mean = 0 and standard deviation = 1. By specifying `InitializeInfectionsZeroPad`, the latent infections before this time are assumed to be 0. (3) an instance of the `RtRandomWalkProcess` class with default values @@ -114,10 +114,10 @@ pmf_array = jnp.array([0.25, 0.25, 0.25, 0.25]) gen_int = DeterministicPMF(pmf_array, name="gen_int") # (2) Initial infections (inferred with a prior) -I0 = InfectionSeedingProcess( - "I0_seeding", +I0 = InfectionInitializationProcess( + "I0_initialization", DistributionalRV(dist=dist.LogNormal(0, 1), name="I0"), - SeedInfectionsZeroPad(pmf_array.size), + InitializeInfectionsZeroPad(pmf_array.size), t_unit=1, ) @@ -155,7 +155,7 @@ The following diagram summarizes how the modules interact via composition; notab %%| include: true flowchart TB genint["(1) gen_int\n(DetermnisticPMF)"] - i0["(2) I0\n(InfectionSeedingProcess)"] + i0["(2) I0\n(InfectionInitializationProcess)"] rt["(3) rt_proc\n(RtRandomWalkProcess)"] inf["(4) latent_infections\n(Infections)"] obs["(5) observation_process\n(PoissonObservation)"] diff --git a/docs/source/tutorials/extending_pyrenew.qmd b/docs/source/tutorials/extending_pyrenew.qmd index 3918545b..cbb989ec 100644 --- a/docs/source/tutorials/extending_pyrenew.qmd +++ b/docs/source/tutorials/extending_pyrenew.qmd @@ -31,8 +31,8 @@ from pyrenew.model import RtInfectionsRenewalModel from pyrenew.process import RtRandomWalkProcess from pyrenew.metaclass import DistributionalRV from pyrenew.latent import ( - InfectionSeedingProcess, - SeedInfectionsExponentialGrowth, + InfectionInitializationProcess, + InitializeInfectionsExponentialGrowth, ) import pyrenew.transformation as t ``` @@ -45,10 +45,10 @@ gen_int_array = jnp.array([0.25, 0.5, 0.15, 0.1]) gen_int = DeterministicPMF(gen_int_array, name="gen_int") feedback_strength = DeterministicVariable(0.05, name="feedback_strength") -I0 = InfectionSeedingProcess( - "I0_seeding", +I0 = InfectionInitializationProcess( + "I0_initialization", DistributionalRV(dist=dist.LogNormal(0, 1), name="I0"), - SeedInfectionsExponentialGrowth( + InitializeInfectionsExponentialGrowth( gen_int_array.size, DeterministicVariable(0.5, name="rate"), ), diff --git a/docs/source/tutorials/hospital_admissions_model.qmd b/docs/source/tutorials/hospital_admissions_model.qmd index f88a884e..af2d3f84 100644 --- a/docs/source/tutorials/hospital_admissions_model.qmd +++ b/docs/source/tutorials/hospital_admissions_model.qmd @@ -149,18 +149,18 @@ The `inf_hosp_int` is a `DeterministicPMF` object that takes the infection to ho # | label: initializing-rest-of-model from pyrenew import model, process, observation, metaclass, transformation from pyrenew.latent import ( - InfectionSeedingProcess, - SeedInfectionsExponentialGrowth, + InfectionInitializationProcess, + InitializeInfectionsExponentialGrowth, ) # Infection process latent_inf = latent.Infections() -I0 = InfectionSeedingProcess( - "I0_seeding", +I0 = InfectionInitializationProcess( + "I0_initialization", metaclass.DistributionalRV( dist=dist.LogNormal(loc=jnp.log(100), scale=0.5), name="I0" ), - SeedInfectionsExponentialGrowth( + InitializeInfectionsExponentialGrowth( gen_int_array.size, deterministic.DeterministicVariable(0.5, name="rate"), ), diff --git a/docs/source/tutorials/pyrenew_demo.qmd b/docs/source/tutorials/pyrenew_demo.qmd index 8a12e988..af42e813 100644 --- a/docs/source/tutorials/pyrenew_demo.qmd +++ b/docs/source/tutorials/pyrenew_demo.qmd @@ -67,7 +67,10 @@ from pyrenew.observation import PoissonObservation from pyrenew.deterministic import DeterministicPMF, DeterministicVariable from pyrenew.model import HospitalAdmissionsModel from pyrenew.process import RtRandomWalkProcess -from pyrenew.latent import InfectionSeedingProcess, SeedInfectionsZeroPad +from pyrenew.latent import ( + InfectionInitializationProcess, + InitializeInfectionsZeroPad, +) import pyrenew.transformation as t ``` @@ -93,10 +96,10 @@ pmf_array = jnp.array([0.25, 0.25, 0.25, 0.25]) gen_int = DeterministicPMF(pmf_array, name="gen_int") # 2) Initial infections -I0 = InfectionSeedingProcess( - "I0_seeding", +I0 = InfectionInitializationProcess( + "I0_initialization", DistributionalRV(dist=dist.LogNormal(0, 1), name="I0"), - SeedInfectionsZeroPad(pmf_array.size), + InitializeInfectionsZeroPad(pmf_array.size), t_unit=1, ) diff --git a/model/src/pyrenew/latent/__init__.py b/model/src/pyrenew/latent/__init__.py index b557d1c5..046160fd 100644 --- a/model/src/pyrenew/latent/__init__.py +++ b/model/src/pyrenew/latent/__init__.py @@ -8,13 +8,15 @@ compute_infections_from_rt_with_feedback, logistic_susceptibility_adjustment, ) -from pyrenew.latent.infection_seeding_method import ( - InfectionSeedMethod, - SeedInfectionsExponentialGrowth, - SeedInfectionsFromVec, - SeedInfectionsZeroPad, +from pyrenew.latent.infection_initialization_method import ( + InfectionInitializationMethod, + InitializeInfectionsExponentialGrowth, + InitializeInfectionsFromVec, + InitializeInfectionsZeroPad, +) +from pyrenew.latent.infection_initialization_process import ( + InfectionInitializationProcess, ) -from pyrenew.latent.infection_seeding_process import InfectionSeedingProcess from pyrenew.latent.infections import Infections from pyrenew.latent.infectionswithfeedback import InfectionsWithFeedback @@ -24,10 +26,10 @@ "logistic_susceptibility_adjustment", "compute_infections_from_rt", "compute_infections_from_rt_with_feedback", - "InfectionSeedMethod", - "SeedInfectionsExponentialGrowth", - "SeedInfectionsFromVec", - "SeedInfectionsZeroPad", - "InfectionSeedingProcess", + "InfectionInitializationMethod", + "InitializeInfectionsExponentialGrowth", + "InitializeInfectionsFromVec", + "InitializeInfectionsZeroPad", + "InfectionInitializationProcess", "InfectionsWithFeedback", ] diff --git a/model/src/pyrenew/latent/infection_seeding_method.py b/model/src/pyrenew/latent/infection_initialization_method.py similarity index 80% rename from model/src/pyrenew/latent/infection_seeding_method.py rename to model/src/pyrenew/latent/infection_initialization_method.py index ce1cfcfc..b09e4c7d 100644 --- a/model/src/pyrenew/latent/infection_seeding_method.py +++ b/model/src/pyrenew/latent/infection_initialization_method.py @@ -7,16 +7,16 @@ from pyrenew.metaclass import RandomVariable -class InfectionSeedMethod(metaclass=ABCMeta): +class InfectionInitializationMethod(metaclass=ABCMeta): """Method for seeding initial infections in a renewal process.""" def __init__(self, n_timepoints: int): - """Default constructor for the ``InfectionSeedMethod`` class. + """Default constructor for the ``InfectionInitializationMethod`` class. Parameters ---------- n_timepoints : int - the number of time points to generate seed infections for + the number of time points to generate initial infections for Returns ------- @@ -27,12 +27,12 @@ def __init__(self, n_timepoints: int): @staticmethod def validate(n_timepoints: int) -> None: - """Validate inputs for the ``InfectionSeedMethod`` class constructor + """Validate inputs for the ``InfectionInitializationMethod`` class constructor Parameters ---------- n_timepoints : int - the number of time points to generate seed infections for + the number of time points to generate initial infections for Returns ------- @@ -54,7 +54,7 @@ def seed_infections(self, I_pre_seed: ArrayLike): Parameters ---------- I_pre_seed : ArrayLike - An array representing some number of latent infections to be used with the specified ``InfectionSeedMethod``. + An array representing some number of latent infections to be used with the specified ``InfectionInitializationMethod``. Returns ------- @@ -66,15 +66,15 @@ def __call__(self, I_pre_seed: ArrayLike): return self.seed_infections(I_pre_seed) -class SeedInfectionsZeroPad(InfectionSeedMethod): +class InitializeInfectionsZeroPad(InfectionInitializationMethod): """ - Create a seed infection vector of specified length by + Create an initial infection vector of specified length by padding a shorter vector with an appropriate number of zeros at the beginning of the time series. """ def seed_infections(self, I_pre_seed: ArrayLike): - """Pad the seed infections with zeros at the beginning of the time series. + """Pad the initial infections with zeros at the beginning of the time series. Parameters ---------- @@ -95,16 +95,16 @@ def seed_infections(self, I_pre_seed: ArrayLike): return jnp.pad(I_pre_seed, (self.n_timepoints - I_pre_seed.size, 0)) -class SeedInfectionsFromVec(InfectionSeedMethod): - """Create seed infections from a vector of infections.""" +class InitializeInfectionsFromVec(InfectionInitializationMethod): + """Create initial infections from a vector of infections.""" def seed_infections(self, I_pre_seed: ArrayLike): - """Create seed infections from a vector of infections. + """Create initial infections from a vector of infections. Parameters ---------- I_pre_seed : ArrayLike - An array with the same length as ``n_timepoints`` to be used as the seed infections. + An array with the same length as ``n_timepoints`` to be used as the initial infections. Returns ------- @@ -120,8 +120,8 @@ def seed_infections(self, I_pre_seed: ArrayLike): return jnp.array(I_pre_seed) -class SeedInfectionsExponentialGrowth(InfectionSeedMethod): - r"""Generate seed infections according to exponential growth. +class InitializeInfectionsExponentialGrowth(InfectionInitializationMethod): + r"""Generate initial infections according to exponential growth. Notes ----- @@ -142,12 +142,12 @@ def __init__( rate: RandomVariable, t_pre_seed: int | None = None, ): - """Default constructor for the ``SeedInfectionsExponentialGrowth`` class. + """Default constructor for the ``InitializeInfectionsExponentialGrowth`` class. Parameters ---------- n_timepoints : int - the number of time points to generate seed infections for + the number of time points to generate initial infections for rate : RandomVariable A random variable representing the rate of exponential growth t_pre_seed : int | None, optional @@ -160,7 +160,7 @@ def __init__( self.t_pre_seed = t_pre_seed def seed_infections(self, I_pre_seed: ArrayLike): - """Generate seed infections according to exponential growth. + """Generate initial infections according to exponential growth. Parameters ---------- diff --git a/model/src/pyrenew/latent/infection_seeding_process.py b/model/src/pyrenew/latent/infection_initialization_process.py similarity index 71% rename from model/src/pyrenew/latent/infection_seeding_process.py rename to model/src/pyrenew/latent/infection_initialization_process.py index 67110a30..a23da038 100644 --- a/model/src/pyrenew/latent/infection_seeding_process.py +++ b/model/src/pyrenew/latent/infection_initialization_process.py @@ -1,22 +1,24 @@ # -*- coding: utf-8 -*- # numpydoc ignore=GL08 import numpyro as npro -from pyrenew.latent.infection_seeding_method import InfectionSeedMethod +from pyrenew.latent.infection_initialization_method import ( + InfectionInitializationMethod, +) from pyrenew.metaclass import RandomVariable -class InfectionSeedingProcess(RandomVariable): +class InfectionInitializationProcess(RandomVariable): """Generate an initial infection history""" def __init__( self, name, I_pre_seed_rv: RandomVariable, - infection_seed_method: InfectionSeedMethod, + infection_seed_method: InfectionInitializationMethod, t_unit: int, t_start: int | None = None, ) -> None: - """Default class constructor for InfectionSeedingProcess + """Default class constructor for InfectionInitializationProcess Parameters ---------- @@ -24,8 +26,8 @@ def __init__( A name to assign to the RandomVariable. I_pre_seed_rv : RandomVariable A RandomVariable representing the number of infections that occur at some time before the renewal process begins. Each `infection_seed_method` uses this random variable in different ways. - infection_seed_method : InfectionSeedMethod - An `InfectionSeedMethod` that generates the seed infections for the renewal process. + infection_seed_method : InfectionInitializationMethod + An `InfectionInitializationMethod` that generates the initial infections for the renewal process. t_unit : int The unit of time for the time series passed to `RandomVariable.set_timeseries`. t_start : int, optional @@ -39,7 +41,9 @@ def __init__( ------- None """ - InfectionSeedingProcess.validate(I_pre_seed_rv, infection_seed_method) + InfectionInitializationProcess.validate( + I_pre_seed_rv, infection_seed_method + ) self.I_pre_seed_rv = I_pre_seed_rv self.infection_seed_method = infection_seed_method @@ -55,16 +59,16 @@ def __init__( @staticmethod def validate( I_pre_seed_rv: RandomVariable, - infection_seed_method: InfectionSeedMethod, + infection_seed_method: InfectionInitializationMethod, ) -> None: - """Validate the input arguments to the InfectionSeedingProcess class constructor + """Validate the input arguments to the InfectionInitializationProcess class constructor Parameters ---------- I_pre_seed_rv : RandomVariable A random variable representing the number of infections that occur at some time before the renewal process begins. - infection_seed_method : InfectionSeedMethod - An method to generate the seed infections. + infection_seed_method : InfectionInitializationMethod + An method to generate the initial infections. Returns ------- @@ -75,14 +79,16 @@ def validate( "I_pre_seed_rv must be an instance of RandomVariable" f"Got {type(I_pre_seed_rv)}" ) - if not isinstance(infection_seed_method, InfectionSeedMethod): + if not isinstance( + infection_seed_method, InfectionInitializationMethod + ): raise TypeError( - "infection_seed_method must be an instance of InfectionSeedMethod" + "infection_seed_method must be an instance of InfectionInitializationMethod" f"Got {type(infection_seed_method)}" ) def sample(self) -> tuple: - """Sample the infection seeding process. + """Sample the Infection Initialization Process. Returns ------- diff --git a/model/src/pyrenew/latent/infections.py b/model/src/pyrenew/latent/infections.py index 400e6886..c3110f18 100644 --- a/model/src/pyrenew/latent/infections.py +++ b/model/src/pyrenew/latent/infections.py @@ -17,14 +17,14 @@ class InfectionsSample(NamedTuple): Attributes ---------- - post_seed_infections : ArrayLike | None, optional + post_initialization_infections : ArrayLike | None, optional The estimated latent infections. Defaults to None. """ - post_seed_infections: ArrayLike | None = None + post_initialization_infections: ArrayLike | None = None def __repr__(self): - return f"InfectionsSample(post_seed_infections={self.post_seed_infections})" + return f"InfectionsSample(post_initialization_infections={self.post_initialization_infections})" class Infections(RandomVariable): @@ -91,10 +91,10 @@ def sample( gen_int_rev = jnp.flip(gen_int) recent_I0 = I0[-gen_int_rev.size :] - post_seed_infections = inf.compute_infections_from_rt( + post_initialization_infections = inf.compute_infections_from_rt( I0=recent_I0, Rt=Rt, reversed_generation_interval_pmf=gen_int_rev, ) - return InfectionsSample(post_seed_infections) + return InfectionsSample(post_initialization_infections) diff --git a/model/src/pyrenew/latent/infectionswithfeedback.py b/model/src/pyrenew/latent/infectionswithfeedback.py index c6be2ec3..d0a87707 100644 --- a/model/src/pyrenew/latent/infectionswithfeedback.py +++ b/model/src/pyrenew/latent/infectionswithfeedback.py @@ -17,17 +17,17 @@ class InfectionsRtFeedbackSample(NamedTuple): Attributes ---------- - post_seed_infections : ArrayLike | None, optional + post_initialization_infections : ArrayLike | None, optional The estimated latent infections. Defaults to None. rt : ArrayLike | None, optional The adjusted reproduction number. Defaults to None. """ - post_seed_infections: ArrayLike | None = None + post_initialization_infections: ArrayLike | None = None rt: ArrayLike | None = None def __repr__(self): - return f"InfectionsSample(post_seed_infections={self.post_seed_infections}, rt={self.rt})" + return f"InfectionsSample(post_initialization_infections={self.post_initialization_infections}, rt={self.rt})" class InfectionsWithFeedback(RandomVariable): @@ -180,7 +180,7 @@ def sample( inf_fb_pmf_rev = jnp.flip(inf_feedback_pmf) ( - post_seed_infections, + post_initialization_infections, Rt_adj, ) = inf.compute_infections_from_rt_with_feedback( I0=I0, @@ -195,6 +195,6 @@ def sample( npro.deterministic("Rt_adjusted", Rt_adj) return InfectionsRtFeedbackSample( - post_seed_infections=post_seed_infections, + post_initialization_infections=post_initialization_infections, rt=Rt_adj, ) diff --git a/model/src/pyrenew/model/rtinfectionsrenewalmodel.py b/model/src/pyrenew/model/rtinfectionsrenewalmodel.py index afbdb8d0..ebb19eac 100644 --- a/model/src/pyrenew/model/rtinfectionsrenewalmodel.py +++ b/model/src/pyrenew/model/rtinfectionsrenewalmodel.py @@ -203,7 +203,10 @@ def sample( # Sampling initial infections I0, *_ = self.I0_rv.sample(**kwargs) # Sampling from the latent process - post_seed_latent_infections, *_ = self.latent_infections_rv.sample( + ( + post_initialization_latent_infections, + *_, + ) = self.latent_infections_rv.sample( Rt=Rt, gen_int=gen_int, I0=I0, @@ -211,12 +214,14 @@ def sample( ) observed_infections, *_ = self.infection_obs_process_rv.sample( - mu=post_seed_latent_infections[padding:], + mu=post_initialization_latent_infections[padding:], obs=data_observed_infections, **kwargs, ) - all_latent_infections = jnp.hstack([I0, post_seed_latent_infections]) + all_latent_infections = jnp.hstack( + [I0, post_initialization_latent_infections] + ) npro.deterministic("all_latent_infections", all_latent_infections) observed_infections = au.pad_x_to_match_y( diff --git a/model/src/test/test_forecast.py b/model/src/test/test_forecast.py index 677755c5..90d4b2fd 100644 --- a/model/src/test/test_forecast.py +++ b/model/src/test/test_forecast.py @@ -9,9 +9,9 @@ from numpy.testing import assert_array_equal from pyrenew.deterministic import DeterministicPMF from pyrenew.latent import ( + InfectionInitializationProcess, Infections, - InfectionSeedingProcess, - SeedInfectionsZeroPad, + InitializeInfectionsZeroPad, ) from pyrenew.metaclass import DistributionalRV from pyrenew.model import RtInfectionsRenewalModel @@ -23,10 +23,10 @@ def test_forecast(): """Check that forecasts are the right length and match the posterior up until forecast begins.""" pmf_array = jnp.array([0.25, 0.25, 0.25, 0.25]) gen_int = DeterministicPMF(pmf_array, name="gen_int") - I0 = InfectionSeedingProcess( - "I0_seeding", + I0 = InfectionInitializationProcess( + "I0_initialization", DistributionalRV(dist=dist.LogNormal(0, 1), name="I0"), - SeedInfectionsZeroPad(n_timepoints=gen_int.size()), + InitializeInfectionsZeroPad(n_timepoints=gen_int.size()), t_unit=1, ) latent_infections = Infections() diff --git a/model/src/test/test_infection_seeding_method.py b/model/src/test/test_infection_seeding_method.py index 527ac96f..3322bbb4 100644 --- a/model/src/test/test_infection_seeding_method.py +++ b/model/src/test/test_infection_seeding_method.py @@ -4,14 +4,14 @@ import pytest from pyrenew.deterministic import DeterministicVariable from pyrenew.latent import ( - SeedInfectionsExponentialGrowth, - SeedInfectionsFromVec, - SeedInfectionsZeroPad, + InitializeInfectionsExponentialGrowth, + InitializeInfectionsFromVec, + InitializeInfectionsZeroPad, ) def test_seed_infections_exponential(): - """Check that the SeedInfectionsExponentialGrowth class generates the correct number of infections at each time point.""" + """Check that the InitializeInfectionsExponentialGrowth class generates the correct number of infections at each time point.""" n_timepoints = 10 rate_RV = DeterministicVariable(0.5, name="rate_RV") I_pre_seed_RV = DeterministicVariable(10.0, name="I_pre_seed_RV") @@ -20,7 +20,7 @@ def test_seed_infections_exponential(): (I_pre_seed,) = I_pre_seed_RV.sample() (rate,) = rate_RV.sample() - infections_default_t_pre_seed = SeedInfectionsExponentialGrowth( + infections_default_t_pre_seed = InitializeInfectionsExponentialGrowth( n_timepoints, rate=rate_RV ).seed_infections(I_pre_seed) infections_default_t_pre_seed_manual = I_pre_seed * np.exp( @@ -37,7 +37,7 @@ def test_seed_infections_exponential(): # test for failure with non-scalar rate or I_pre_seed rate_RV_2 = DeterministicVariable(np.array([0.5, 0.5]), name="rate_RV") with pytest.raises(ValueError): - SeedInfectionsExponentialGrowth( + InitializeInfectionsExponentialGrowth( n_timepoints, rate=rate_RV_2 ).seed_infections(I_pre_seed) @@ -47,13 +47,13 @@ def test_seed_infections_exponential(): (I_pre_seed_2,) = I_pre_seed_RV_2.sample() with pytest.raises(ValueError): - SeedInfectionsExponentialGrowth( + InitializeInfectionsExponentialGrowth( n_timepoints, rate=rate_RV ).seed_infections(I_pre_seed_2) # test non-default t_pre_seed t_pre_seed = 6 - infections_custom_t_pre_seed = SeedInfectionsExponentialGrowth( + infections_custom_t_pre_seed = InitializeInfectionsExponentialGrowth( n_timepoints, rate=rate_RV, t_pre_seed=t_pre_seed ).seed_infections(I_pre_seed) infections_custom_t_pre_seed_manual = I_pre_seed * np.exp( @@ -69,13 +69,13 @@ def test_seed_infections_exponential(): def test_seed_infections_zero_pad(): - """Check that the SeedInfectionsZeroPad class generates the correct number of infections at each time point.""" + """Check that the InitializeInfectionsZeroPad class generates the correct number of infections at each time point.""" n_timepoints = 10 I_pre_seed_RV = DeterministicVariable(10.0, name="I_pre_seed_RV") (I_pre_seed,) = I_pre_seed_RV.sample() - infections = SeedInfectionsZeroPad(n_timepoints).seed_infections( + infections = InitializeInfectionsZeroPad(n_timepoints).seed_infections( I_pre_seed ) testing.assert_array_equal( @@ -87,7 +87,7 @@ def test_seed_infections_zero_pad(): ) (I_pre_seed_2,) = I_pre_seed_RV_2.sample() - infections_2 = SeedInfectionsZeroPad(n_timepoints).seed_infections( + infections_2 = InitializeInfectionsZeroPad(n_timepoints).seed_infections( I_pre_seed_2 ) testing.assert_array_equal( @@ -95,29 +95,33 @@ def test_seed_infections_zero_pad(): np.pad(I_pre_seed_2, (n_timepoints - I_pre_seed_2.size, 0)), ) - # Check that the SeedInfectionsZeroPad class raises an error when the length of I_pre_seed is greater than n_timepoints. + # Check that the InitializeInfectionsZeroPad class raises an error when the length of I_pre_seed is greater than n_timepoints. with pytest.raises(ValueError): - SeedInfectionsZeroPad(1).seed_infections(I_pre_seed_2) + InitializeInfectionsZeroPad(1).seed_infections(I_pre_seed_2) def test_seed_infections_from_vec(): - """Check that the SeedInfectionsFromVec class generates the correct number of infections at each time point.""" + """Check that the InitializeInfectionsFromVec class generates the correct number of infections at each time point.""" n_timepoints = 10 I_pre_seed = np.arange(n_timepoints) - infections = SeedInfectionsFromVec(n_timepoints).seed_infections( + infections = InitializeInfectionsFromVec(n_timepoints).seed_infections( I_pre_seed ) testing.assert_array_equal(infections, I_pre_seed) I_pre_seed_2 = np.arange(n_timepoints - 1) with pytest.raises(ValueError): - SeedInfectionsFromVec(n_timepoints).seed_infections(I_pre_seed_2) + InitializeInfectionsFromVec(n_timepoints).seed_infections(I_pre_seed_2) n_timepoints_float = 10.0 with pytest.raises(TypeError): - SeedInfectionsFromVec(n_timepoints_float).seed_infections(I_pre_seed) + InitializeInfectionsFromVec(n_timepoints_float).seed_infections( + I_pre_seed + ) n_timepoints_neg = -10 with pytest.raises(ValueError): - SeedInfectionsFromVec(n_timepoints_neg).seed_infections(I_pre_seed) + InitializeInfectionsFromVec(n_timepoints_neg).seed_infections( + I_pre_seed + ) diff --git a/model/src/test/test_infection_seeding_process.py b/model/src/test/test_infection_seeding_process.py index 48d97422..a37bf4ad 100644 --- a/model/src/test/test_infection_seeding_process.py +++ b/model/src/test/test_infection_seeding_process.py @@ -5,38 +5,38 @@ import pytest from pyrenew.deterministic import DeterministicVariable from pyrenew.latent import ( - InfectionSeedingProcess, - SeedInfectionsExponentialGrowth, - SeedInfectionsFromVec, - SeedInfectionsZeroPad, + InfectionInitializationProcess, + InitializeInfectionsExponentialGrowth, + InitializeInfectionsFromVec, + InitializeInfectionsZeroPad, ) from pyrenew.metaclass import DistributionalRV -def test_infection_seeding_process(): - """Check that the InfectionSeedingProcess class generates can be sampled from with all InfectionSeedMethods.""" +def test_infection_initialization_process(): + """Check that the InfectionInitializationProcess class generates can be sampled from with all InfectionInitializationMethods.""" n_timepoints = 10 - zero_pad_model = InfectionSeedingProcess( + zero_pad_model = InfectionInitializationProcess( "zero_pad_model", DistributionalRV(dist=dist.LogNormal(0, 1), name="I0"), - SeedInfectionsZeroPad(n_timepoints), + InitializeInfectionsZeroPad(n_timepoints), t_unit=1, ) - exp_model = InfectionSeedingProcess( + exp_model = InfectionInitializationProcess( "exp_model", DistributionalRV(dist=dist.LogNormal(0, 1), name="I0"), - SeedInfectionsExponentialGrowth( + InitializeInfectionsExponentialGrowth( n_timepoints, DeterministicVariable(0.5, name="rate") ), t_unit=1, ) - vec_model = InfectionSeedingProcess( + vec_model = InfectionInitializationProcess( "vec_model", DeterministicVariable(jnp.arange(n_timepoints), name="I0"), - SeedInfectionsFromVec(n_timepoints), + InitializeInfectionsFromVec(n_timepoints), t_unit=1, ) @@ -44,17 +44,17 @@ def test_infection_seeding_process(): with npro.handlers.seed(rng_seed=1): model.sample() - # Check that the InfectionSeedingProcess class raises an error when the wrong type of I0 is passed + # Check that the InfectionInitializationProcess class raises an error when the wrong type of I0 is passed with pytest.raises(TypeError): - InfectionSeedingProcess( + InfectionInitializationProcess( "vec_model", jnp.arange(n_timepoints), - SeedInfectionsFromVec(n_timepoints), + InitializeInfectionsFromVec(n_timepoints), t_unit=1, ) with pytest.raises(TypeError): - InfectionSeedingProcess( + InfectionInitializationProcess( "vec_model", DeterministicVariable(jnp.arange(n_timepoints), name="I0"), 3, diff --git a/model/src/test/test_infectionsrtfeedback.py b/model/src/test/test_infectionsrtfeedback.py index 9467c116..f98ded87 100644 --- a/model/src/test/test_infectionsrtfeedback.py +++ b/model/src/test/test_infectionsrtfeedback.py @@ -54,7 +54,7 @@ def _infection_w_feedback_alt( I_vec[t : t + len_gen], np.flip(gen_int) ) - return {"post_seed_infections": I_vec[I0.size :], "rt": Rt_adj} + return {"post_initialization_infections": I_vec[I0.size :], "rt": Rt_adj} def test_infectionsrtfeedback(): @@ -94,7 +94,10 @@ def test_infectionsrtfeedback(): I0=I0, ) - assert_array_equal(samp1.post_seed_infections, samp2.post_seed_infections) + assert_array_equal( + samp1.post_initialization_infections, + samp2.post_initialization_infections, + ) assert_array_equal(samp1.rt, Rt) return None @@ -144,10 +147,12 @@ def test_infectionsrtfeedback_feedback(): ) assert not jnp.array_equal( - samp1.post_seed_infections, samp2.post_seed_infections + samp1.post_initialization_infections, + samp2.post_initialization_infections, ) assert_array_almost_equal( - samp1.post_seed_infections, res["post_seed_infections"] + samp1.post_initialization_infections, + res["post_initialization_infections"], ) assert_array_almost_equal(samp1.rt, res["rt"]) diff --git a/model/src/test/test_latent_infections.py b/model/src/test/test_latent_infections.py index 1df985d9..30f73091 100755 --- a/model/src/test/test_latent_infections.py +++ b/model/src/test/test_latent_infections.py @@ -41,7 +41,8 @@ def test_infections_as_deterministic(): inf_sampled2 = inf1.sample(**obs) testing.assert_array_equal( - inf_sampled1.post_seed_infections, inf_sampled2.post_seed_infections + inf_sampled1.post_initialization_infections, + inf_sampled2.post_initialization_infections, ) # Check that Initial infections vector must be at least as long as the generation interval. diff --git a/model/src/test/test_model_basic_renewal.py b/model/src/test/test_model_basic_renewal.py index f79a7e4c..22c5d992 100644 --- a/model/src/test/test_model_basic_renewal.py +++ b/model/src/test/test_model_basic_renewal.py @@ -12,9 +12,9 @@ import pytest from pyrenew.deterministic import DeterministicPMF, NullObservation from pyrenew.latent import ( + InfectionInitializationProcess, Infections, - InfectionSeedingProcess, - SeedInfectionsZeroPad, + InitializeInfectionsZeroPad, ) from pyrenew.metaclass import DistributionalRV from pyrenew.model import RtInfectionsRenewalModel @@ -110,10 +110,10 @@ def test_model_basicrenewal_no_obs_model(): with pytest.raises(ValueError): I0 = DistributionalRV(dist=1, name="I0") - I0 = InfectionSeedingProcess( - "I0_seeding", + I0 = InfectionInitializationProcess( + "I0_initialization", DistributionalRV(dist=dist.LogNormal(0, 1), name="I0"), - SeedInfectionsZeroPad(n_timepoints=gen_int.size()), + InitializeInfectionsZeroPad(n_timepoints=gen_int.size()), t_unit=1, ) @@ -186,10 +186,10 @@ def test_model_basicrenewal_with_obs_model(): jnp.array([0.25, 0.25, 0.25, 0.25]), name="gen_int" ) - I0 = InfectionSeedingProcess( - "I0_seeding", + I0 = InfectionInitializationProcess( + "I0_initialization", DistributionalRV(dist=dist.LogNormal(0, 1), name="I0"), - SeedInfectionsZeroPad(n_timepoints=gen_int.size()), + InitializeInfectionsZeroPad(n_timepoints=gen_int.size()), t_unit=1, ) @@ -240,10 +240,10 @@ def test_model_basicrenewal_padding() -> None: # numpydoc ignore=GL08 jnp.array([0.25, 0.25, 0.25, 0.25]), name="gen_int" ) - I0 = InfectionSeedingProcess( - "I0_seeding", + I0 = InfectionInitializationProcess( + "I0_initialization", DistributionalRV(dist=dist.LogNormal(0, 1), name="I0"), - SeedInfectionsZeroPad(n_timepoints=gen_int.size()), + InitializeInfectionsZeroPad(n_timepoints=gen_int.size()), t_unit=1, ) diff --git a/model/src/test/test_model_hospitalizations.py b/model/src/test/test_model_hospitalizations.py index dc2b090a..056b1798 100644 --- a/model/src/test/test_model_hospitalizations.py +++ b/model/src/test/test_model_hospitalizations.py @@ -17,9 +17,9 @@ ) from pyrenew.latent import ( HospitalAdmissions, + InfectionInitializationProcess, Infections, - InfectionSeedingProcess, - SeedInfectionsZeroPad, + InitializeInfectionsZeroPad, ) from pyrenew.metaclass import DistributionalRV, RandomVariable from pyrenew.model import HospitalAdmissionsModel @@ -192,10 +192,10 @@ def test_model_hosp_no_obs_model(): jnp.array([0.25, 0.25, 0.25, 0.25]), name="gen_int" ) - I0 = InfectionSeedingProcess( - "I0_seeding", + I0 = InfectionInitializationProcess( + "I0_initialization", DistributionalRV(dist=dist.LogNormal(0, 1), name="I0"), - SeedInfectionsZeroPad(n_timepoints=gen_int.size()), + InitializeInfectionsZeroPad(n_timepoints=gen_int.size()), t_unit=1, ) @@ -302,10 +302,10 @@ def test_model_hosp_with_obs_model(): jnp.array([0.25, 0.25, 0.25, 0.25]), name="gen_int" ) - I0 = InfectionSeedingProcess( - "I0_seeding", + I0 = InfectionInitializationProcess( + "I0_initialization", DistributionalRV(dist=dist.LogNormal(0, 1), name="I0"), - SeedInfectionsZeroPad(n_timepoints=gen_int.size()), + InitializeInfectionsZeroPad(n_timepoints=gen_int.size()), t_unit=1, ) @@ -392,10 +392,10 @@ def test_model_hosp_with_obs_model_weekday_phosp_2(): jnp.array([0.25, 0.25, 0.25, 0.25]), name="gen_int" ) - I0 = InfectionSeedingProcess( - "I0_seeding", + I0 = InfectionInitializationProcess( + "I0_initialization", DistributionalRV(dist=dist.LogNormal(0, 1), name="I0"), - SeedInfectionsZeroPad(n_timepoints=gen_int.size()), + InitializeInfectionsZeroPad(n_timepoints=gen_int.size()), t_unit=1, ) @@ -495,10 +495,10 @@ def test_model_hosp_with_obs_model_weekday_phosp(): n_obs_to_generate = 30 pad_size = 5 - I0 = InfectionSeedingProcess( - "I0_seeding", + I0 = InfectionInitializationProcess( + "I0_initialization", DistributionalRV(dist=dist.LogNormal(0, 1), name="I0"), - SeedInfectionsZeroPad(n_timepoints=gen_int.size()), + InitializeInfectionsZeroPad(n_timepoints=gen_int.size()), t_unit=1, ) diff --git a/model/src/test/test_predictive.py b/model/src/test/test_predictive.py index 5fbbeee2..693cf638 100644 --- a/model/src/test/test_predictive.py +++ b/model/src/test/test_predictive.py @@ -10,9 +10,9 @@ import pytest from pyrenew.deterministic import DeterministicPMF from pyrenew.latent import ( + InfectionInitializationProcess, Infections, - InfectionSeedingProcess, - SeedInfectionsZeroPad, + InitializeInfectionsZeroPad, ) from pyrenew.metaclass import DistributionalRV from pyrenew.model import RtInfectionsRenewalModel @@ -21,10 +21,10 @@ pmf_array = jnp.array([0.25, 0.25, 0.25, 0.25]) gen_int = DeterministicPMF(pmf_array, name="gen_int") -I0 = InfectionSeedingProcess( - "I0_seeding", +I0 = InfectionInitializationProcess( + "I0_initialization", DistributionalRV(dist=dist.LogNormal(0, 1), name="I0"), - SeedInfectionsZeroPad(n_timepoints=gen_int.size()), + InitializeInfectionsZeroPad(n_timepoints=gen_int.size()), t_unit=1, ) latent_infections = Infections() diff --git a/model/src/test/test_random_key.py b/model/src/test/test_random_key.py index a2342552..c181dfbd 100644 --- a/model/src/test/test_random_key.py +++ b/model/src/test/test_random_key.py @@ -14,9 +14,9 @@ from numpy.testing import assert_array_equal, assert_raises from pyrenew.deterministic import DeterministicPMF from pyrenew.latent import ( + InfectionInitializationProcess, Infections, - InfectionSeedingProcess, - SeedInfectionsZeroPad, + InitializeInfectionsZeroPad, ) from pyrenew.metaclass import DistributionalRV from pyrenew.model import RtInfectionsRenewalModel @@ -27,10 +27,10 @@ def create_test_model(): # numpydoc ignore=GL08 pmf_array = jnp.array([0.25, 0.25, 0.25, 0.25]) gen_int = DeterministicPMF(pmf_array, name="gen_int") - I0 = InfectionSeedingProcess( - "I0_seeding", + I0 = InfectionInitializationProcess( + "I0_initialization", DistributionalRV(dist=dist.LogNormal(0, 1), name="I0"), - SeedInfectionsZeroPad(n_timepoints=gen_int.size()), + InitializeInfectionsZeroPad(n_timepoints=gen_int.size()), t_unit=1, ) latent_infections = Infections()