Skip to content

Commit

Permalink
Rt with infection feedback (#123)
Browse files Browse the repository at this point in the history
---------

Co-authored-by: Dylan H. Morris <dzl1@cdc.gov>
Co-authored-by: Dylan H. Morris <dylanhmorris@users.noreply.github.com>
Co-authored-by: Damon Bayer <xum8@cdc.gov>
  • Loading branch information
4 people authored Jun 3, 2024
1 parent f9c057a commit f11ce38
Show file tree
Hide file tree
Showing 27 changed files with 1,399 additions and 128 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -342,3 +342,4 @@ replay_pid*
.DS_Store

/.quarto/
*_files
Empty file modified .pre-commit-rst-placeholder.sh
100644 → 100755
Empty file.
1 change: 1 addition & 0 deletions docs/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
build/
1 change: 1 addition & 0 deletions docs/source/tutorials/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,4 @@ This section contains tutorials that demonstrate how to use the `pyrenew` packag
getting-started
example-with-datasets
pyrenew_demo
extending_pyrenew
2 changes: 2 additions & 0 deletions model/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
/.quarto/
_compiled
20 changes: 12 additions & 8 deletions model/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ test:
poetry run pytest --mpl --mpl-default-tolerance=10

docs: docs/pyrenew_demo.md docs/getting-started.md \
docs/example-with-datasets.md
docs/example-with-datasets.md docs/extending_pyrenew.md

docs/pyrenew_demo.md: docs/pyrenew_demo.qmd
poetry run quarto render docs/pyrenew_demo.qmd
Expand All @@ -37,16 +37,20 @@ docs/getting-started.md: docs/getting-started.qmd
docs/example-with-datasets.md: docs/example-with-datasets.qmd
poetry run quarto render docs/example-with-datasets.qmd

docs/extending_pyrenew.md: docs/extending_pyrenew.qmd
poetry run quarto render docs/extending_pyrenew.qmd

docs/py: docs/notebooks
jupyter nbconvert --to python docs/pyrenew_demo.ipynb
jupyter nbconvert --to python docs/getting-started.ipynb
jupyter nbconvert --to python docs/example-with-datasets.ipynb
for i in docs/*.ipynb; do \
jupyter nbconvert --to python $$i ; \
done

docs/notebooks:
quarto convert docs/pyrenew_demo.qmd --output docs/pyrenew_demo.ipynb
quarto convert docs/getting-started.qmd --output docs/getting-started.ipynb
quarto convert docs/example-with-datasets.qmd --output \
docs/example-with-datasets.ipynb
for i in docs/*.qmd; do \
if [ $$i -nt $$(basename $$i .qmd).ipynb ]; then \
quarto convert $$i --output docs/$$(basename $$i .qmd).ipynb ; \
fi \
done

test_images:
echo "Generating reference images for tests"
Expand Down
3 changes: 3 additions & 0 deletions model/docs/.gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,6 @@

/.quarto/
_compiled

*.ipynb
*.py
256 changes: 256 additions & 0 deletions model/docs/extending_pyrenew.qmd
Original file line number Diff line number Diff line change
@@ -0,0 +1,256 @@
---
title: Extending pyrenew
format: gfm
---

This tutorial illustrates how to extend `pyrenew` with custom `RandomVariable` classes. We will use the `InfectionsWithFeedback` class as an example. The `InfectionsWithFeedback` class is a `RandomVariable` that models the number of infections at time $t$ as a function of the number of infections at time $t - \tau$ and the reproduction number at time $t$. The reproduction number at time $t$ is a function of the *unadjusted* reproduction number at time $t - \tau$ and the number of infections at time $t - \tau$:

$$
\begin{align*}
I(t) & = \mathcal{R}(t)\sum_{\tau=1}^{T_g}I(t - \tau)g(\tau) \\
\mathcal{R}(t) & = \mathcal{R}^u(t)\exp\left(-\gamma(t)\sum_{\tau=1}^{T_f}I(t - \tau)f(\tau)\right)
\end{align*}
$$

Where $\mathcal{R}^u(t)$ is the unadjusted reproduction number, $g(t)$ is the generation interval, $\gamma(t)$ is the infection feedback strength, and $f(t)$ is the infection feedback pmf.

## The expected outcome

Before we start, let's simulate the model with the original `InfectionsWithFeedback` class. To keep it simple, we will simulate the model with no observation process, in other words, only with latent infections. The following code-chunk loads the required libraries and defines the model components:

```{python}
#| label: setup
import jax
import jax.numpy as jnp
import numpy as np
import numpyro as npro
import numpyro.distributions as dist
from pyrenew.deterministic import DeterministicPMF, DeterministicVariable
from pyrenew.latent import InfectionsWithFeedback
from pyrenew.model import RtInfectionsRenewalModel
from pyrenew.process import RtRandomWalkProcess
from pyrenew.metaclass import DistributionalRV
```

The following code-chunk defines the model components. Notice that for both the generation interval and the infection feedback, we use a deterministic PMF with equal probabilities:

```{python}
#| label: model-components
gen_int = DeterministicPMF(jnp.array([0.25, 0.5, 0.15, 0.1]))
feedback_strength = DeterministicVariable(0.05)
I0 = DistributionalRV(dist=dist.LogNormal(0, 1), name="I0")
latent_infections = InfectionsWithFeedback(
infection_feedback_strength = feedback_strength,
infection_feedback_pmf = gen_int,
)
rt = RtRandomWalkProcess()
```

With all the components defined, we can build the model:

```{python}
#| label: build1
model0 = RtInfectionsRenewalModel(
gen_int=gen_int,
I0=I0,
latent_infections=latent_infections,
Rt_process=rt,
observation_process=None,
)
```

And simulate it from:

```{python}
#| label: simulate1
# Sampling and fitting model 0 (with no obs for infections)
np.random.seed(223)
with npro.handlers.seed(rng_seed=np.random.randint(1, 600)):
model0_samp = model0.sample(n_timepoints=30)
```

```{python}
#| label: fig-simulate1
#| fig-cap: Simulated infections with no observation process
import matplotlib.pyplot as plt
fig, ax = plt.subplots()
ax.plot(model0_samp.latent_infections)
ax.set_xlabel("Time")
ax.set_ylabel("Infections")
plt.show()
```

## Pyrenew's random variable class

### Fundamentals

All instances of PyRenew's `RandomVariable` should have at least three functions: `__init__()`, `validate()`, and `sample()`. The `__init__()` function is the constructor and initializes the class. The `validate()` function checks if the class is correctly initialized. Finally, the `sample()` method contains the core of the class; it should return a tuple or named tuple. The following is a minimal example of a `RandomVariable` class based on `numpyro.distributions.Normal`:

```python
from pyrenew.metaclass import RandomVariable

class MyNormal(RandomVariable):
def __init__(self, loc, scale):
self.validate(scale)
self.loc = loc
self.scale = scale
return None

@staticmethod
def validate(self):
if self.scale <= 0:
raise ValueError("Scale must be positive")
return None

def sample(self, **kwargs):
return (dist.Normal(loc=self.loc, scale=self.scale),)
```

The `@staticmethod` decorator exposes the `validate` function to be used outside the class. Next, we show how to build a more complex `RandomVariable` class; the `InfectionsWithFeedback` class.

### The `InfectionsWithFeedback` class

Although returning namedtuples is not strictly required, they are the recommended return type, as they make the code more readable. The following code-chunk shows how to create a named tuple for the `InfectionsWithFeedback` class:

```{python}
#| label: data-class
from collections import namedtuple
# Creating a tuple to store the output
InfFeedbackSample = namedtuple(
typename='InfFeedbackSample',
field_names=['infections', 'rt'],
defaults=(None, None)
)
```

The next step is to create the actual class. The bulk of its implementation lies in the function `pyrenew.latent.compute_infections_from_rt_with_feedback()`. We will also use the `pyrenew.datautils.pad_x_to_match_y()` function to ensure the passed vectors match their lengths. The following code-chunk shows most of the implementation of the `InfectionsWithFeedback` class:

```{python}
#| label: new-model-def
#| code-line-numbers: true
# Creating the class
from pyrenew.metaclass import RandomVariable
from pyrenew.latent import compute_infections_from_rt_with_feedback
from pyrenew import datautils as du
from jax.typing import ArrayLike
import jax.numpy as jnp
class InfFeedback(RandomVariable):
"""Latent infections"""
def __init__(
self,
infection_feedback_strength: RandomVariable,
infection_feedback_pmf: RandomVariable,
infections_mean_varname: str = "latent_infections",
) -> None:
"""Constructor"""
self.infection_feedback_strength = infection_feedback_strength
self.infection_feedback_pmf = infection_feedback_pmf
self.infections_mean_varname = infections_mean_varname
return None
def validate(self):
"""
Generally, this method should be more meaningful, but we will skip it for now
"""
return None
def sample(
self,
Rt: ArrayLike,
I0: ArrayLike,
gen_int: ArrayLike,
**kwargs,
) -> tuple:
"""Sample infections with feedback"""
# Generation interval
gen_int_rev = jnp.flip(gen_int)
# Baseline infections
I0_vec = I0[-gen_int_rev.size :]
# Sampling inf feedback strength and adjusting the shape
inf_feedback_strength, *_ = self.infection_feedback_strength.sample(
**kwargs,
)
inf_feedback_strength = du.pad_x_to_match_y(
x=inf_feedback_strength, y=Rt, fill_value=inf_feedback_strength[0]
)
# Sampling inf feedback and adjusting the shape
inf_feedback_pmf, *_ = self.infection_feedback_pmf.sample(**kwargs)
inf_fb_pmf_rev = jnp.flip(inf_feedback_pmf)
# Generating the infections with feedback
all_infections, Rt_adj = compute_infections_from_rt_with_feedback(
I0=I0_vec,
Rt_raw=Rt,
infection_feedback_strength=inf_feedback_strength,
reversed_generation_interval_pmf=gen_int_rev,
reversed_infection_feedback_pmf=inf_fb_pmf_rev,
)
# Storing adjusted Rt for future use
npro.deterministic("Rt_adjusted", Rt_adj)
# Preparing theoutput
return InfFeedbackSample(
infections=all_infections,
rt=Rt_adj,
)
```

The core of the class is implemented in the `sample()` method. Things to highlight from the above code:

1. **Arguments of `sample`**: The `InfFeedback` class will be used within `RtInfectionsRenewalModel` to generate latent infections. During the sampling process, `InfFeedback.sample()` will receive the reproduction number, the initial number of infections, and the generation interval. `RandomVariable.sample()` calls are expected to include the `**kwargs` argument, even if unused.

2. **Calls to `RandomVariable.sample()`**: All calls to `RandomVariable.sample()` are expected to return a tuple or named tuple. In our implementation, we capture the output of `infection_feedback_strength.sample()` and `infection_feedback_pmf.sample()` in the variables `inf_feedback_strength` and `inf_feedback_pmf`, respectively, disregarding the other outputs (i.e., using `*_`).

3. **Saving computed quantities**: Since `Rt_adj` is not generated via `numpyro.sample()`, we use `numpyro.deterministic()` to record the quantity to a site; allowing us to access it later.

4. **Return type of `InfFeedback.sample()`**: As said before, the `sample()` method should return a tuple or named tuple. In our case, we return a named tuple `InfFeedbackSample` with two fields: `infections` and `rt`.

```{python}
#| label: simulation2
latent_infections2 = InfFeedback(
infection_feedback_strength = feedback_strength,
infection_feedback_pmf = gen_int,
)
model1 = RtInfectionsRenewalModel(
gen_int=gen_int,
I0=I0,
latent_infections=latent_infections2,
Rt_process=rt,
observation_process=None,
)
# Sampling and fitting model 0 (with no obs for infections)
np.random.seed(223)
with npro.handlers.seed(rng_seed=np.random.randint(1, 600)):
model1_samp = model1.sample(n_timepoints=30)
```

Comparing `model0` with `model1`, these two should match:

```{python}
#| label: fig-model0-vs-model1
#| fig-cap: Comparing latent infections from model 0 and model 1
import matplotlib.pyplot as plt
fig, ax = plt.subplots(ncols=2)
ax[0].plot(model0_samp.latent_infections)
ax[1].plot(model1_samp.latent_infections)
ax[0].set_xlabel("Time (model 0)")
ax[1].set_xlabel("Time (model 1)")
ax[0].set_ylabel("Infections")
plt.show()
```
Loading

0 comments on commit f11ce38

Please sign in to comment.