Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove Usage Of npro In Favor Of Just numpyro #316

Merged
merged 2 commits into from
Jul 25, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions docs/source/tutorials/basic_renewal_model.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ We start by loading the needed components to build a basic renewal model:
# | warning: false
import jax.numpy as jnp
import numpy as np
import numpyro
import numpyro as npro
import numpyro.distributions as dist
from pyrenew.process import SimpleRandomWalkProcess
from pyrenew.latent import (
Expand All @@ -33,7 +33,7 @@ from pyrenew.metaclass import (
import pyrenew.transformation as t
from numpyro.infer.reparam import LocScaleReparam

numpyro.set_host_device_count(2)
npro.set_host_device_count(2)
```

## Architecture of `RtInfectionsRenewalModel`
Expand Down Expand Up @@ -137,7 +137,7 @@ class MyRt(RandomVariable):
pass

def sample(self, n_steps: int, **kwargs) -> tuple:
sd_rt = numpyro.sample("Rt_random_walk_sd", dist.HalfNormal(0.025))
sd_rt = npro.sample("Rt_random_walk_sd", dist.HalfNormal(0.025))

rt_rv = TransformedRandomVariable(
"Rt_rv",
Expand Down Expand Up @@ -204,7 +204,7 @@ Using `numpyro`, we can simulate data using the `sample()` member function of `R

```{python}
# | label: simulate
with numpyro.handlers.seed(rng_seed=53):
with npro.handlers.seed(rng_seed=53):
sim_data = model1.sample(n_datapoints=40)

sim_data
Expand Down
4 changes: 2 additions & 2 deletions docs/source/tutorials/extending_pyrenew.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ plt.show()

### Fundamentals

All instances of PyRenew's `RandomVariable` should have at least three functions: `__init__()`, `validate()`, and `sample()`. The `__init__()` function is the constructor and initializes the class. The `validate()` function checks if the class is correctly initialized. Finally, the `sample()` method contains the core of the class; it should return a tuple or named tuple. The following is a minimal example of a `RandomVariable` class based on `numpyro.distributions.Normal`:
All instances of PyRenew's `RandomVariable` should have at least three functions: `__init__()`, `validate()`, and `sample()`. The `__init__()` function is the constructor and initializes the class. The `validate()` function checks if the class is correctly initialized. Finally, the `sample()` method contains the core of the class; it should return a tuple or named tuple. The following is a minimal example of a `RandomVariable` class based on `npro.distributions.Normal`:

```{python}
from pyrenew.metaclass import RandomVariable
Expand Down Expand Up @@ -238,7 +238,7 @@ The core of the class is implemented in the `sample()` method. Things to highlig

2. **Calls to `RandomVariable()`**: All calls to `RandomVariable()` are expected to return a tuple or named tuple. In our implementation, we capture the output of `infection_feedback_strength()` and `infection_feedback_pmf()` in the variables `inf_feedback_strength` and `inf_feedback_pmf`, respectively, disregarding the other outputs (i.e., using `*_`).

3. **Saving computed quantities**: Since `Rt_adj` is not generated via `numpyro.sample()`, we use `numpyro.deterministic()` to record the quantity to a site; allowing us to access it later.
3. **Saving computed quantities**: Since `Rt_adj` is not generated via `npro.sample()`, we use `numpyro.deterministic()` to record the quantity to a site; allowing us to access it later.

4. **Return type of `InfFeedback()`**: As said before, the `sample()` method should return a tuple or named tuple. In our case, we return a named tuple `InfFeedbackSample` with two fields: `infections` and `rt`.

Expand Down
4 changes: 2 additions & 2 deletions docs/source/tutorials/getting_started.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,15 @@ pip install git+https://github.com/CDCgov/multisignal-epi-inference@main#subdire

## The fundamentals

`pyrenew`'s core components are the metaclasses `RandomVariable` and `Model` (in Python, a _metaclass_ is a class whose instances are also classes, where a _class_ is a template for making objects). Within the `pyrenew` package, a `RandomVariable` is a quantity that models can estimate and sample from, **including deterministic quantities**. The benefit of this design is that the definition of the `sample()` function can be arbitrary, allowing the user to either sample from a distribution using `numpyro.sample()`, compute fixed quantities (like a mechanistic equation), or return a fixed value (like a pre-computed PMF.) For instance, when estimating a PMF, the `RandomVariable` sampling function may roughly be defined as:
`pyrenew`'s core components are the metaclasses `RandomVariable` and `Model` (in Python, a _metaclass_ is a class whose instances are also classes, where a _class_ is a template for making objects). Within the `pyrenew` package, a `RandomVariable` is a quantity that models can estimate and sample from, **including deterministic quantities**. The benefit of this design is that the definition of the `sample()` function can be arbitrary, allowing the user to either sample from a distribution using `npro.sample()`, compute fixed quantities (like a mechanistic equation), or return a fixed value (like a pre-computed PMF.) For instance, when estimating a PMF, the `RandomVariable` sampling function may roughly be defined as:

```python
# define a new class called MyRandVar that inherits from the RandomVariable class
class MyRandVar(RandomVariable):
#define a method called sample that returns an object of type ArrayLike
def sample(...) -> ArrayLike:
# calls sample function from NumPyro package
return numpyro.sample(...)
return npro.sample(...)
```

Whereas, in some other cases, we may instead use a fixed quantity for that variable (like a pre-computed PMF), where the `RandomVariable`'s sample function could instead be defined as:
Expand Down
6 changes: 3 additions & 3 deletions docs/source/tutorials/hospital_admissions_model.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@ engine: jupyter
```{python}
# | label: numpyro setup
# | echo: false
import numpyro
import numpyro as npro

numpyro.set_host_device_count(2)
npro.set_host_device_count(2)
```

This document illustrates how a hospital admissions-only model can be fitted using data from the Pyrenew package, particularly the wastewater dataset. The CFA wastewater team created this dataset, which contains simulated data.
Expand Down Expand Up @@ -258,7 +258,7 @@ import numpy as np
timeframe = 120


with numpyro.handlers.seed(rng_seed=223):
with npro.handlers.seed(rng_seed=223):
simulated_data = hosp_model.sample(n_datapoints=timeframe)
```

Expand Down
36 changes: 18 additions & 18 deletions model/src/pyrenew/metaclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import jax.random as jr
import matplotlib.pyplot as plt
import numpy as np
import numpyro
import numpyro as npro
import polars as pl
from jax.typing import ArrayLike
from numpyro.infer import MCMC, NUTS, Predictive
Expand Down Expand Up @@ -214,12 +214,12 @@ def __call__(self, **kwargs):
class DistributionalRV(RandomVariable):
"""
Wrapper class for random variables that sample
from a single :class:`numpyro.distributions.Distribution`.
from a single :class:`npro.distributions.Distribution`.
"""

def __init__(
self,
dist: numpyro.distributions.Distribution,
dist: npro.distributions.Distribution,
name: str,
reparam: Reparam = None,
) -> None:
Expand All @@ -228,12 +228,12 @@ def __init__(

Parameters
----------
dist : numpyro.distributions.Distribution
dist : npro.distributions.Distribution
Distribution of the random variable.
name : str
Name of the random variable.

reparam : numpyro.infer.reparam.Reparam
reparam : npro.infer.reparam.Reparam
If not None, reparameterize sampling
from the distribution according to the
given numpyro reparameterizer
Expand All @@ -259,10 +259,10 @@ def validate(dist: any) -> None:
"""
Validation of the distribution to be implemented in subclasses.
"""
if not isinstance(dist, numpyro.distributions.Distribution):
if not isinstance(dist, npro.distributions.Distribution):
raise ValueError(
"dist should be an instance of "
f"numpyro.distributions.Distribution, got {dist}"
f"npro.distributions.Distribution, got {dist}"
)

return None
Expand All @@ -279,7 +279,7 @@ def sample(
----------
obs : ArrayLike, optional
Observations passed as the `obs` argument to
:fun:`numpyro.sample()`. Default `None`.
:fun:`npro.sample()`. Default `None`.
**kwargs : dict, optional
Additional keyword arguments passed through
to internal sample calls, should there be any.
Expand All @@ -289,8 +289,8 @@ def sample(
tuple
Containing the sampled from the distribution.
"""
with numpyro.handlers.reparam(config=self.reparam_dict):
sample = numpyro.sample(
with npro.handlers.reparam(config=self.reparam_dict):
sample = npro.sample(
name=self.name,
fn=self.dist,
obs=obs,
Expand Down Expand Up @@ -411,11 +411,11 @@ def run(
----------
nuts_args : dict, optional
Dictionary of arguments passed to the
:class:`numpyro.infer.NUTS` kernel.
:class:`npro.infer.NUTS` kernel.
Defaults to None.
mcmc_args : dict, optional
Dictionary of arguments passed to the
:class:`numpyro.infer.MCMC` constructor.
:class:`npro.infer.MCMC` constructor.
Defaults to None.

Returns
Expand Down Expand Up @@ -445,7 +445,7 @@ def print_summary(
exclude_deterministic: bool = True,
) -> None:
"""
A wrapper of :meth:`numpyro.infer.MCMC.print_summary`
A wrapper of :meth:`npro.infer.MCMC.print_summary`

Parameters
----------
Expand Down Expand Up @@ -508,7 +508,7 @@ def posterior_predictive(
**kwargs,
) -> dict:
"""
A wrapper for :class:`numpyro.infer.Predictive` to generate
A wrapper for :class:`npro.infer.Predictive` to generate
posterior predictive samples.

Parameters
Expand All @@ -517,10 +517,10 @@ def posterior_predictive(
Random key for the Predictive function call. Defaults to None.
numpyro_predictive_args : dict, optional
Dictionary of arguments to be passed to the
:class:`numpyro.inference.Predictive` constructor.
:class:`npro.infer.Predictive` constructor.
**kwargs
Additional named arguments passed to the
`__call__()` method of :class:`numpyro.infer.Predictive`
`__call__()` method of :class:`npro.infer.Predictive`

Returns
-------
Expand Down Expand Up @@ -559,9 +559,9 @@ def prior_predictive(
rng_key : ArrayLike, optional
Random key for the Predictive function call. Defaults to None.
numpyro_predictive_args : dict, optional
Dictionary of arguments to be passed to the numpyro.inference.Predictive constructor.
Dictionary of arguments to be passed to the numpyro.infer.Predictive constructor.
**kwargs
Additional named arguments passed to the `__call__()` method of numpyro.inference.Predictive
Additional named arguments passed to the `__call__()` method of numpyro.infer.Predictive

Returns
-------
Expand Down
6 changes: 3 additions & 3 deletions model/src/pyrenew/model/rtinfectionsrenewalmodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from typing import NamedTuple

import jax.numpy as jnp
import numpyro
import numpyro as npro
import pyrenew.arrayutils as au
from numpy.typing import ArrayLike
from pyrenew.deterministic import NullObservation
Expand Down Expand Up @@ -217,7 +217,7 @@ def sample(
all_latent_infections = jnp.hstack(
[I0, post_initialization_latent_infections]
)
numpyro.deterministic("all_latent_infections", all_latent_infections)
npro.deterministic("all_latent_infections", all_latent_infections)

if observed_infections is not None:
observed_infections = au.pad_x_to_match_y(
Expand All @@ -233,7 +233,7 @@ def sample(
jnp.nan,
pad_direction="start",
)
numpyro.deterministic("Rt", Rt)
npro.deterministic("Rt", Rt)

return RtInfectionsRenewalSample(
Rt=Rt,
Expand Down
4 changes: 2 additions & 2 deletions model/src/pyrenew/observation/negativebinomial.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from __future__ import annotations

import numpyro
import numpyro as npro
import numpyro.distributions as dist
from jax.typing import ArrayLike
from pyrenew.metaclass import RandomVariable
Expand Down Expand Up @@ -89,7 +89,7 @@ def sample(
"""
concentration, *_ = self.concentration_rv.sample()

negative_binomial_sample = numpyro.sample(
negative_binomial_sample = npro.sample(
name=self.name,
fn=dist.NegativeBinomial2(
mean=mu + self.eps,
Expand Down
6 changes: 3 additions & 3 deletions model/src/pyrenew/observation/poisson.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from __future__ import annotations

import numpyro
import numpyro as npro
import numpyro.distributions as dist
from jax.typing import ArrayLike
from pyrenew.metaclass import RandomVariable
Expand All @@ -25,7 +25,7 @@ def __init__(
Parameters
----------
name : str
Passed to numpyro.sample.
Passed to npro.sample.
eps : float, optional
Small value added to the rate parameter to avoid zero values.
Defaults to 1e-8.
Expand Down Expand Up @@ -67,7 +67,7 @@ def sample(
tuple
"""

poisson_sample = numpyro.sample(
poisson_sample = npro.sample(
name=self.name,
fn=dist.Poisson(rate=mu + self.eps),
obs=obs,
Expand Down
8 changes: 4 additions & 4 deletions model/src/pyrenew/process/ar.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from __future__ import annotations

import jax.numpy as jnp
import numpyro
import numpyro as npro
import numpyro.distributions as dist
from jax import lax
from jax.typing import ArrayLike
Expand All @@ -31,7 +31,7 @@ def __init__(
Parameters
----------
name : str
Name of the parameter passed to numpyro.sample.
Name of the parameter passed to npro.sample.
mean: float
Mean parameter.
autoreg : ArrayLike
Expand Down Expand Up @@ -75,7 +75,7 @@ def sample(
"""
order = self.autoreg.shape[0]
if inits is None:
inits = numpyro.sample(
inits = npro.sample(
self.name + "_sampled_inits",
dist.Normal(0, self.noise_sd).expand((order,)),
)
Expand All @@ -85,7 +85,7 @@ def _ar_scanner(carry, next): # numpydoc ignore=GL08
new_carry = jnp.hstack([new_term, carry[: (order - 1)]])
return new_carry, new_term

noise = numpyro.sample(
noise = npro.sample(
self.name + "_noise",
dist.Normal(0, self.noise_sd).expand((duration - inits.size,)),
)
Expand Down
2 changes: 1 addition & 1 deletion model/src/pyrenew/process/simplerandomwalk.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def __init__(
----------
name : str
A name for the random variable, used to
name sites within it in :fun :`numpyro.sample()`
name sites within it in :fun :`npro.sample()`
calls.
step_rv : RandomVariable
RandomVariable representing the step distribution.
Expand Down
14 changes: 7 additions & 7 deletions model/src/pyrenew/regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from abc import ABCMeta, abstractmethod

import numpyro
import numpyro as npro
import numpyro.distributions as dist
import pyrenew.transformation as t
from jax.typing import ArrayLike
Expand Down Expand Up @@ -56,7 +56,7 @@ def __init__(
name : str
The name of the observation process,
which will be used to name the constituent
sampled parameters in calls to `numpyro.sample`
sampled parameters in calls to `npro.sample`

fixed_predictor_values : ArrayLike (n_predictors, n_observations)
Matrix of fixed values of the predictor variables
Expand All @@ -70,7 +70,7 @@ def __init__(
Prior distribution for the regression intercept
value

coefficient_priors : numpyro.distributions.Distribution
coefficient_priors : npro.distributions.Distribution
Vectorized prior distribution for the regression
coefficient values

Expand All @@ -82,11 +82,11 @@ def __init__(

intercept_suffix : str, optional
Suffix for naming the intercept random variable in
class to numpyro.sample(). Default `"_intercept"`.
class to npro.sample(). Default `"_intercept"`.

coefficient_suffix : str, optional
Suffix for naming the regression coefficient
random variables in calls to numpyro.sample().
random variables in calls to npro.sample().
Default `"_coefficients"`.
"""
if transform is None:
Expand Down Expand Up @@ -135,10 +135,10 @@ def sample(self) -> dict:
A dictionary containing transformed predictions, and
the intercept and coefficients sample distributions.
"""
intercept = numpyro.sample(
intercept = npro.sample(
self.name + self.intercept_suffix, self.intercept_prior
)
coefficients = numpyro.sample(
coefficients = npro.sample(
self.name + self.coefficient_suffix, self.coefficient_priors
)
prediction = self.predict(intercept, coefficients)
Expand Down
Loading