Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Renaming The Infections Seeding Initialization #261

Merged
merged 10 commits into from
Jul 12, 2024
8 changes: 4 additions & 4 deletions docs/source/msei_reference/latent.rst
Original file line number Diff line number Diff line change
Expand Up @@ -25,18 +25,18 @@ Infection Functions
:undoc-members:
:show-inheritance:

Infection Seeding Process
Infection Initialization Process
-------------------------

.. automodule:: pyrenew.latent.infection_seeding_process
.. automodule:: pyrenew.latent.infection_initialization_process
:members:
:undoc-members:
:show-inheritance:

Infection Seeding Method
Infection Initialization Method
------------------------

.. automodule:: pyrenew.latent.infection_seeding_method
.. automodule:: pyrenew.latent.infection_initialization_method
:members:
:undoc-members:
:show-inheritance:
Expand Down
14 changes: 7 additions & 7 deletions docs/source/tutorials/basic_renewal_model.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ import numpyro.distributions as dist
from pyrenew.process import RtRandomWalkProcess
from pyrenew.latent import (
Infections,
InfectionSeedingProcess,
SeedInfectionsZeroPad,
InfectionInitializationProcess,
InitializeInfectionsZeroPad,
)
from pyrenew.observation import PoissonObservation
from pyrenew.deterministic import DeterministicPMF
Expand Down Expand Up @@ -99,7 +99,7 @@ To initialize these five components within the renewal modeling framework, we es

(1) In this example, the generation interval is not estimated but passed as a deterministic instance of `RandomVariable`

(2) an instance of the `InfectionSeedingProcess` class, where the number of latent infections immediately before the renewal process begins follows a log-normal distribution with mean = 0 and standard deviation = 1. By specifying `SeedInfectionsZeroPad`, the latent infections before this time are assumed to be 0.
(2) an instance of the `InfectionInitializationProcess` class, where the number of latent infections immediately before the renewal process begins follows a log-normal distribution with mean = 0 and standard deviation = 1. By specifying `InitializeInfectionsZeroPad`, the latent infections before this time are assumed to be 0.

(3) an instance of the `RtRandomWalkProcess` class with default values

Expand All @@ -114,10 +114,10 @@ pmf_array = jnp.array([0.25, 0.25, 0.25, 0.25])
gen_int = DeterministicPMF(pmf_array, name="gen_int")

# (2) Initial infections (inferred with a prior)
I0 = InfectionSeedingProcess(
"I0_seeding",
I0 = InfectionInitializationProcess(
"I0_initialization",
DistributionalRV(dist=dist.LogNormal(0, 1), name="I0"),
SeedInfectionsZeroPad(pmf_array.size),
InitializeInfectionsZeroPad(pmf_array.size),
t_unit=1,
)

Expand Down Expand Up @@ -155,7 +155,7 @@ The following diagram summarizes how the modules interact via composition; notab
%%| include: true
flowchart TB
genint["(1) gen_int\n(DetermnisticPMF)"]
i0["(2) I0\n(InfectionSeedingProcess)"]
i0["(2) I0\n(InfectionInitializationProcess)"]
rt["(3) rt_proc\n(RtRandomWalkProcess)"]
inf["(4) latent_infections\n(Infections)"]
obs["(5) observation_process\n(PoissonObservation)"]
Expand Down
10 changes: 5 additions & 5 deletions docs/source/tutorials/extending_pyrenew.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ from pyrenew.model import RtInfectionsRenewalModel
from pyrenew.process import RtRandomWalkProcess
from pyrenew.metaclass import DistributionalRV
from pyrenew.latent import (
InfectionSeedingProcess,
SeedInfectionsExponentialGrowth,
InfectionInitializationProcess,
InitializeInfectionsExponentialGrowth,
)
import pyrenew.transformation as t
```
Expand All @@ -45,10 +45,10 @@ gen_int_array = jnp.array([0.25, 0.5, 0.15, 0.1])
gen_int = DeterministicPMF(gen_int_array, name="gen_int")
feedback_strength = DeterministicVariable(0.05, name="feedback_strength")

I0 = InfectionSeedingProcess(
"I0_seeding",
I0 = InfectionInitializationProcess(
"I0_initialization",
DistributionalRV(dist=dist.LogNormal(0, 1), name="I0"),
SeedInfectionsExponentialGrowth(
InitializeInfectionsExponentialGrowth(
gen_int_array.size,
DeterministicVariable(0.5, name="rate"),
),
Expand Down
10 changes: 5 additions & 5 deletions docs/source/tutorials/hospital_admissions_model.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -149,18 +149,18 @@ The `inf_hosp_int` is a `DeterministicPMF` object that takes the infection to ho
# | label: initializing-rest-of-model
from pyrenew import model, process, observation, metaclass, transformation
from pyrenew.latent import (
InfectionSeedingProcess,
SeedInfectionsExponentialGrowth,
InfectionInitializationProcess,
InitializeInfectionsExponentialGrowth,
)

# Infection process
latent_inf = latent.Infections()
I0 = InfectionSeedingProcess(
"I0_seeding",
I0 = InfectionInitializationProcess(
"I0_initialization",
metaclass.DistributionalRV(
dist=dist.LogNormal(loc=jnp.log(100), scale=0.5), name="I0"
),
SeedInfectionsExponentialGrowth(
InitializeInfectionsExponentialGrowth(
gen_int_array.size,
deterministic.DeterministicVariable(0.5, name="rate"),
),
Expand Down
11 changes: 7 additions & 4 deletions docs/source/tutorials/pyrenew_demo.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,10 @@ from pyrenew.observation import PoissonObservation
from pyrenew.deterministic import DeterministicPMF, DeterministicVariable
from pyrenew.model import HospitalAdmissionsModel
from pyrenew.process import RtRandomWalkProcess
from pyrenew.latent import InfectionSeedingProcess, SeedInfectionsZeroPad
from pyrenew.latent import (
InfectionInitializationProcess,
InitializeInfectionsZeroPad,
)
import pyrenew.transformation as t
```

Expand All @@ -93,10 +96,10 @@ pmf_array = jnp.array([0.25, 0.25, 0.25, 0.25])
gen_int = DeterministicPMF(pmf_array, name="gen_int")

# 2) Initial infections
I0 = InfectionSeedingProcess(
"I0_seeding",
I0 = InfectionInitializationProcess(
"I0_initialization",
DistributionalRV(dist=dist.LogNormal(0, 1), name="I0"),
SeedInfectionsZeroPad(pmf_array.size),
InitializeInfectionsZeroPad(pmf_array.size),
t_unit=1,
)

Expand Down
24 changes: 13 additions & 11 deletions model/src/pyrenew/latent/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,15 @@
compute_infections_from_rt_with_feedback,
logistic_susceptibility_adjustment,
)
from pyrenew.latent.infection_seeding_method import (
InfectionSeedMethod,
SeedInfectionsExponentialGrowth,
SeedInfectionsFromVec,
SeedInfectionsZeroPad,
from pyrenew.latent.infection_initialization_method import (
InfectionInitializationMethod,
InitializeInfectionsExponentialGrowth,
InitializeInfectionsFromVec,
InitializeInfectionsZeroPad,
)
from pyrenew.latent.infection_initialization_process import (
InfectionInitializationProcess,
)
from pyrenew.latent.infection_seeding_process import InfectionSeedingProcess
from pyrenew.latent.infections import Infections
from pyrenew.latent.infectionswithfeedback import InfectionsWithFeedback

Expand All @@ -24,10 +26,10 @@
"logistic_susceptibility_adjustment",
"compute_infections_from_rt",
"compute_infections_from_rt_with_feedback",
"InfectionSeedMethod",
"SeedInfectionsExponentialGrowth",
"SeedInfectionsFromVec",
"SeedInfectionsZeroPad",
"InfectionSeedingProcess",
"InfectionInitializationMethod",
"InitializeInfectionsExponentialGrowth",
"InitializeInfectionsFromVec",
"InitializeInfectionsZeroPad",
"InfectionInitializationProcess",
"InfectionsWithFeedback",
]
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,16 @@
from pyrenew.metaclass import RandomVariable


class InfectionSeedMethod(metaclass=ABCMeta):
class InfectionInitializationMethod(metaclass=ABCMeta):
"""Method for seeding initial infections in a renewal process."""

def __init__(self, n_timepoints: int):
"""Default constructor for the ``InfectionSeedMethod`` class.
"""Default constructor for the ``InfectionInitializationMethod`` class.

Parameters
----------
n_timepoints : int
the number of time points to generate seed infections for
the number of time points to generate initial infections for

Returns
-------
Expand All @@ -27,12 +27,12 @@ def __init__(self, n_timepoints: int):

@staticmethod
def validate(n_timepoints: int) -> None:
"""Validate inputs for the ``InfectionSeedMethod`` class constructor
"""Validate inputs for the ``InfectionInitializationMethod`` class constructor

Parameters
----------
n_timepoints : int
the number of time points to generate seed infections for
the number of time points to generate initial infections for

Returns
-------
Expand All @@ -54,7 +54,7 @@ def seed_infections(self, I_pre_seed: ArrayLike):
Parameters
----------
I_pre_seed : ArrayLike
An array representing some number of latent infections to be used with the specified ``InfectionSeedMethod``.
An array representing some number of latent infections to be used with the specified ``InfectionInitializationMethod``.

Returns
-------
Expand All @@ -66,15 +66,15 @@ def __call__(self, I_pre_seed: ArrayLike):
return self.seed_infections(I_pre_seed)


class SeedInfectionsZeroPad(InfectionSeedMethod):
class InitializeInfectionsZeroPad(InfectionInitializationMethod):
"""
Create a seed infection vector of specified length by
Create an initial infection vector of specified length by
padding a shorter vector with an appropriate number of
zeros at the beginning of the time series.
"""

def seed_infections(self, I_pre_seed: ArrayLike):
"""Pad the seed infections with zeros at the beginning of the time series.
"""Pad the initial infections with zeros at the beginning of the time series.

Parameters
----------
Expand All @@ -95,16 +95,16 @@ def seed_infections(self, I_pre_seed: ArrayLike):
return jnp.pad(I_pre_seed, (self.n_timepoints - I_pre_seed.size, 0))


class SeedInfectionsFromVec(InfectionSeedMethod):
"""Create seed infections from a vector of infections."""
class InitializeInfectionsFromVec(InfectionInitializationMethod):
"""Create initial infections from a vector of infections."""

def seed_infections(self, I_pre_seed: ArrayLike):
"""Create seed infections from a vector of infections.
"""Create initial infections from a vector of infections.

Parameters
----------
I_pre_seed : ArrayLike
An array with the same length as ``n_timepoints`` to be used as the seed infections.
An array with the same length as ``n_timepoints`` to be used as the initial infections.

Returns
-------
Expand All @@ -120,8 +120,8 @@ def seed_infections(self, I_pre_seed: ArrayLike):
return jnp.array(I_pre_seed)


class SeedInfectionsExponentialGrowth(InfectionSeedMethod):
r"""Generate seed infections according to exponential growth.
class InitializeInfectionsExponentialGrowth(InfectionInitializationMethod):
r"""Generate initial infections according to exponential growth.

Notes
-----
Expand All @@ -142,12 +142,12 @@ def __init__(
rate: RandomVariable,
t_pre_seed: int | None = None,
):
"""Default constructor for the ``SeedInfectionsExponentialGrowth`` class.
"""Default constructor for the ``InitializeInfectionsExponentialGrowth`` class.

Parameters
----------
n_timepoints : int
the number of time points to generate seed infections for
the number of time points to generate initial infections for
rate : RandomVariable
A random variable representing the rate of exponential growth
t_pre_seed : int | None, optional
Expand All @@ -160,7 +160,7 @@ def __init__(
self.t_pre_seed = t_pre_seed

def seed_infections(self, I_pre_seed: ArrayLike):
"""Generate seed infections according to exponential growth.
"""Generate initial infections according to exponential growth.

Parameters
----------
Expand Down
Original file line number Diff line number Diff line change
@@ -1,31 +1,33 @@
# -*- coding: utf-8 -*-
# numpydoc ignore=GL08
import numpyro as npro
from pyrenew.latent.infection_seeding_method import InfectionSeedMethod
from pyrenew.latent.infection_initialization_method import (
InfectionInitializationMethod,
)
from pyrenew.metaclass import RandomVariable


class InfectionSeedingProcess(RandomVariable):
class InfectionInitializationProcess(RandomVariable):
"""Generate an initial infection history"""

def __init__(
self,
name,
I_pre_seed_rv: RandomVariable,
infection_seed_method: InfectionSeedMethod,
infection_seed_method: InfectionInitializationMethod,
t_unit: int,
t_start: int | None = None,
) -> None:
"""Default class constructor for InfectionSeedingProcess
"""Default class constructor for InfectionInitializationProcess

Parameters
----------
name : str
A name to assign to the RandomVariable.
I_pre_seed_rv : RandomVariable
A RandomVariable representing the number of infections that occur at some time before the renewal process begins. Each `infection_seed_method` uses this random variable in different ways.
infection_seed_method : InfectionSeedMethod
An `InfectionSeedMethod` that generates the seed infections for the renewal process.
infection_seed_method : InfectionInitializationMethod
An `InfectionInitializationMethod` that generates the initial infections for the renewal process.
t_unit : int
The unit of time for the time series passed to `RandomVariable.set_timeseries`.
t_start : int, optional
Expand All @@ -39,7 +41,9 @@ def __init__(
-------
None
"""
InfectionSeedingProcess.validate(I_pre_seed_rv, infection_seed_method)
InfectionInitializationProcess.validate(
I_pre_seed_rv, infection_seed_method
)

self.I_pre_seed_rv = I_pre_seed_rv
self.infection_seed_method = infection_seed_method
Expand All @@ -55,16 +59,16 @@ def __init__(
@staticmethod
def validate(
I_pre_seed_rv: RandomVariable,
infection_seed_method: InfectionSeedMethod,
infection_seed_method: InfectionInitializationMethod,
) -> None:
"""Validate the input arguments to the InfectionSeedingProcess class constructor
"""Validate the input arguments to the InfectionInitializationProcess class constructor

Parameters
----------
I_pre_seed_rv : RandomVariable
A random variable representing the number of infections that occur at some time before the renewal process begins.
infection_seed_method : InfectionSeedMethod
An method to generate the seed infections.
infection_seed_method : InfectionInitializationMethod
An method to generate the initial infections.

Returns
-------
Expand All @@ -75,14 +79,16 @@ def validate(
"I_pre_seed_rv must be an instance of RandomVariable"
f"Got {type(I_pre_seed_rv)}"
)
if not isinstance(infection_seed_method, InfectionSeedMethod):
if not isinstance(
infection_seed_method, InfectionInitializationMethod
):
raise TypeError(
"infection_seed_method must be an instance of InfectionSeedMethod"
"infection_seed_method must be an instance of InfectionInitializationMethod"
f"Got {type(infection_seed_method)}"
)

def sample(self) -> tuple:
"""Sample the infection seeding process.
"""Sample the Infection Initialization Process.

Returns
-------
Expand Down
Loading