Skip to content

Commit

Permalink
Revert "Change Certain numpyros To npro (#316)"
Browse files Browse the repository at this point in the history
This reverts commit ddb45cd.
  • Loading branch information
damonbayer authored Jul 25, 2024
1 parent ddb45cd commit 0a3113f
Show file tree
Hide file tree
Showing 17 changed files with 53 additions and 53 deletions.
8 changes: 4 additions & 4 deletions docs/source/tutorials/extending_pyrenew.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ Before we start, let's simulate the model with the original `InfectionsWithFeedb
import jax
import jax.numpy as jnp
import numpy as np
import numpyro
import numpyro as npro
import numpyro.distributions as dist
from pyrenew.deterministic import DeterministicPMF, DeterministicVariable
from pyrenew.latent import InfectionsWithFeedback
Expand Down Expand Up @@ -90,7 +90,7 @@ And simulate from it:
# | label: simulate1
# Sampling and fitting model 0 (with no obs for infections)
np.random.seed(223)
with numpyro.handlers.seed(rng_seed=np.random.randint(1, 600)):
with npro.handlers.seed(rng_seed=np.random.randint(1, 600)):
model0_samp = model0.sample(n_datapoints=30)
```

Expand Down Expand Up @@ -222,7 +222,7 @@ class InfFeedback(RandomVariable):
)
# Storing adjusted Rt for future use
numpyro.deterministic("Rt_adjusted", Rt_adj)
npro.deterministic("Rt_adjusted", Rt_adj)
# Preparing theoutput
Expand Down Expand Up @@ -259,7 +259,7 @@ model1 = RtInfectionsRenewalModel(
# Sampling and fitting model 0 (with no obs for infections)
np.random.seed(223)
with numpyro.handlers.seed(rng_seed=np.random.randint(1, 600)):
with npro.handlers.seed(rng_seed=np.random.randint(1, 600)):
model1_samp = model1.sample(n_datapoints=30)
```

Expand Down
6 changes: 3 additions & 3 deletions docs/source/tutorials/periodic_effects.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ The `RtPeriodicDiff` and `RtWeeklyDiff` classes use `PeriodicBroadcaster` to rep
# | warning: false
import jax.numpy as jnp
import numpy as np
import numpyro
import numpyro as npro
from pyrenew import process, deterministic
```

Expand All @@ -40,7 +40,7 @@ rt_proc = process.RtWeeklyDiffProcess(
```

```{python}
with numpyro.handlers.seed(rng_seed=20):
with npro.handlers.seed(rng_seed=20):
sim_data = rt_proc(duration=30)
# Plotting the Rt values
Expand Down Expand Up @@ -84,7 +84,7 @@ dayofweek = process.DayOfWeekEffect(
Like before, we can use the `sample` method to generate samples from the day of week effect:

```{python}
with numpyro.handlers.seed(rng_seed=20):
with npro.handlers.seed(rng_seed=20):
sim_data = dayofweek(duration=30)
# Plotting the effect values
Expand Down
4 changes: 2 additions & 2 deletions model/src/pyrenew/deterministic/deterministic.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from __future__ import annotations

import jax.numpy as jnp
import numpyro
import numpyro as npro
from jax.typing import ArrayLike
from pyrenew.metaclass import RandomVariable

Expand Down Expand Up @@ -86,5 +86,5 @@ def sample(
Containing the stored values during construction.
"""
if record:
numpyro.deterministic(self.name, self.vars)
npro.deterministic(self.name, self.vars)
return (self.vars,)
4 changes: 2 additions & 2 deletions model/src/pyrenew/latent/hospitaladmissions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from typing import Any, NamedTuple

import jax.numpy as jnp
import numpyro
import numpyro as npro
from jax.typing import ArrayLike
from pyrenew.deterministic import DeterministicVariable
from pyrenew.metaclass import RandomVariable
Expand Down Expand Up @@ -191,7 +191,7 @@ def sample(
latent_hospital_admissions * self.hosp_report_prob_rv(**kwargs)[0]
)

numpyro.deterministic(
npro.deterministic(
"latent_hospital_admissions", latent_hospital_admissions
)

Expand Down
4 changes: 2 additions & 2 deletions model/src/pyrenew/latent/infection_initialization_process.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
# numpydoc ignore=GL08
import numpyro
import numpyro as npro
from pyrenew.latent.infection_initialization_method import (
InfectionInitializationMethod,
)
Expand Down Expand Up @@ -97,6 +97,6 @@ def sample(self) -> tuple:
"""
(I_pre_init,) = self.I_pre_init_rv()
infection_initialization = self.infection_init_method(I_pre_init)
numpyro.deterministic(self.name, infection_initialization)
npro.deterministic(self.name, infection_initialization)

return (infection_initialization,)
4 changes: 2 additions & 2 deletions model/src/pyrenew/latent/infectionswithfeedback.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import NamedTuple

import jax.numpy as jnp
import numpyro
import numpyro as npro
import pyrenew.arrayutils as au
import pyrenew.latent.infection_functions as inf
from numpy.typing import ArrayLike
Expand Down Expand Up @@ -192,7 +192,7 @@ def sample(

# Appending initial infections to the infections

numpyro.deterministic("Rt_adjusted", Rt_adj)
npro.deterministic("Rt_adjusted", Rt_adj)

return InfectionsRtFeedbackSample(
post_initialization_infections=post_initialization_infections,
Expand Down
6 changes: 3 additions & 3 deletions model/src/pyrenew/metaclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -517,7 +517,7 @@ def posterior_predictive(
Random key for the Predictive function call. Defaults to None.
numpyro_predictive_args : dict, optional
Dictionary of arguments to be passed to the
:class:`numpyro.infer.Predictive` constructor.
:class:`numpyro.inference.Predictive` constructor.
**kwargs
Additional named arguments passed to the
`__call__()` method of :class:`numpyro.infer.Predictive`
Expand Down Expand Up @@ -559,9 +559,9 @@ def prior_predictive(
rng_key : ArrayLike, optional
Random key for the Predictive function call. Defaults to None.
numpyro_predictive_args : dict, optional
Dictionary of arguments to be passed to the numpyro.infer.Predictive constructor.
Dictionary of arguments to be passed to the numpyro.inference.Predictive constructor.
**kwargs
Additional named arguments passed to the `__call__()` method of numpyro.infer.Predictive
Additional named arguments passed to the `__call__()` method of numpyro.inference.Predictive
Returns
-------
Expand Down
4 changes: 2 additions & 2 deletions model/src/test/test_forecast.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import jax.numpy as jnp
import jax.random as jr
import numpy as np
import numpyro
import numpyro as npro
import numpyro.distributions as dist
import pyrenew.transformation as t
from numpy.testing import assert_array_equal
Expand Down Expand Up @@ -51,7 +51,7 @@ def test_forecast():

n_datapoints = 30
n_forecast_points = 10
with numpyro.handlers.seed(rng_seed=np.random.randint(1, 600)):
with npro.handlers.seed(rng_seed=np.random.randint(1, 600)):
model_sample = model.sample(n_datapoints=n_datapoints)

model.run(
Expand Down
4 changes: 2 additions & 2 deletions model/src/test/test_infection_seeding_process.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# numpydoc ignore=GL08
import jax.numpy as jnp
import numpyro
import numpyro as npro
import numpyro.distributions as dist
import pytest
from pyrenew.deterministic import DeterministicVariable
Expand Down Expand Up @@ -41,7 +41,7 @@ def test_infection_initialization_process():
)

for model in [zero_pad_model, exp_model, vec_model]:
with numpyro.handlers.seed(rng_seed=1):
with npro.handlers.seed(rng_seed=1):
model()

# Check that the InfectionInitializationProcess class raises an error when the wrong type of I0 is passed
Expand Down
6 changes: 3 additions & 3 deletions model/src/test/test_infectionsrtfeedback.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import jax.numpy as jnp
import numpy as np
import numpyro
import numpyro as npro
import pyrenew.latent as latent
from jax.typing import ArrayLike
from numpy.testing import assert_array_almost_equal, assert_array_equal
Expand Down Expand Up @@ -81,7 +81,7 @@ def test_infectionsrtfeedback():

infections = latent.Infections()

with numpyro.handlers.seed(rng_seed=0):
with npro.handlers.seed(rng_seed=0):
samp1 = InfectionsWithFeedback(
gen_int=gen_int,
Rt=Rt,
Expand Down Expand Up @@ -125,7 +125,7 @@ def test_infectionsrtfeedback_feedback():

infections = latent.Infections()

with numpyro.handlers.seed(rng_seed=0):
with npro.handlers.seed(rng_seed=0):
samp1 = InfectionsWithFeedback(
gen_int=gen_int,
Rt=Rt,
Expand Down
8 changes: 4 additions & 4 deletions model/src/test/test_latent_admissions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import jax.numpy as jnp
import numpy as np
import numpy.testing as testing
import numpyro
import numpyro as npro
import numpyro.distributions as dist
from pyrenew import transformation as t
from pyrenew.deterministic import DeterministicPMF
Expand Down Expand Up @@ -32,15 +32,15 @@ def test_admissions_sample():
transforms=t.ExpTransform(),
)

with numpyro.handlers.seed(rng_seed=np.random.randint(1, 600)):
with npro.handlers.seed(rng_seed=np.random.randint(1, 600)):
sim_rt, *_ = rt(n_steps=30)

gen_int = jnp.array([0.5, 0.1, 0.1, 0.2, 0.1])
i0 = 10 * jnp.ones_like(gen_int)

inf1 = Infections()

with numpyro.handlers.seed(rng_seed=np.random.randint(1, 600)):
with npro.handlers.seed(rng_seed=np.random.randint(1, 600)):
inf_sampled1 = inf1(Rt=sim_rt, gen_int=gen_int, I0=i0)

# Testing the hospital admissions
Expand Down Expand Up @@ -77,7 +77,7 @@ def test_admissions_sample():
),
)

with numpyro.handlers.seed(rng_seed=np.random.randint(1, 600)):
with npro.handlers.seed(rng_seed=np.random.randint(1, 600)):
sim_hosp_1 = hosp1(latent_infections=inf_sampled1[0])

testing.assert_array_less(
Expand Down
8 changes: 4 additions & 4 deletions model/src/test/test_latent_infections.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import jax.numpy as jnp
import numpy as np
import numpy.testing as testing
import numpyro
import numpyro as npro
import numpyro.distributions as dist
import pyrenew.transformation as t
import pytest
Expand All @@ -30,7 +30,7 @@ def test_infections_as_deterministic():
transforms=t.ExpTransform(),
)

with numpyro.handlers.seed(rng_seed=np.random.randint(1, 600)):
with npro.handlers.seed(rng_seed=np.random.randint(1, 600)):
sim_rt, *_ = rt(n_steps=30)

gen_int = jnp.array([0.25, 0.25, 0.25, 0.25])
Expand All @@ -42,7 +42,7 @@ def test_infections_as_deterministic():
I0=jnp.zeros(gen_int.size),
gen_int=gen_int,
)
with numpyro.handlers.seed(rng_seed=np.random.randint(1, 600)):
with npro.handlers.seed(rng_seed=np.random.randint(1, 600)):
inf_sampled1 = inf1(**obs)
inf_sampled2 = inf1(**obs)

Expand All @@ -52,7 +52,7 @@ def test_infections_as_deterministic():
)

# Check that Initial infections vector must be at least as long as the generation interval.
with numpyro.handlers.seed(rng_seed=np.random.randint(1, 600)):
with npro.handlers.seed(rng_seed=np.random.randint(1, 600)):
with pytest.raises(ValueError):
obs["I0"] = jnp.array([1])
inf1(**obs)
14 changes: 7 additions & 7 deletions model/src/test/test_model_basic_renewal.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import jax.numpy as jnp
import jax.random as jr
import numpy as np
import numpyro
import numpyro as npro
import numpyro.distributions as dist
import polars as pl
import pyrenew.transformation as t
Expand Down Expand Up @@ -72,7 +72,7 @@ def test_model_basicrenewal_no_timepoints_or_observations():
)

np.random.seed(2203)
with numpyro.handlers.seed(rng_seed=np.random.randint(1, 600)):
with npro.handlers.seed(rng_seed=np.random.randint(1, 600)):
with pytest.raises(ValueError, match="Either"):
model1.sample(n_datapoints=None, data_observed_infections=None)

Expand Down Expand Up @@ -103,7 +103,7 @@ def test_model_basicrenewal_both_timepoints_and_observations():
)

np.random.seed(2203)
with numpyro.handlers.seed(rng_seed=np.random.randint(1, 600)):
with npro.handlers.seed(rng_seed=np.random.randint(1, 600)):
with pytest.raises(ValueError, match="Cannot pass both"):
model1.sample(
n_datapoints=30,
Expand Down Expand Up @@ -146,7 +146,7 @@ def test_model_basicrenewal_no_obs_model():

# Sampling and fitting model 0 (with no obs for infections)
np.random.seed(223)
with numpyro.handlers.seed(rng_seed=np.random.randint(1, 600)):
with npro.handlers.seed(rng_seed=np.random.randint(1, 600)):
model0_samp = model0.sample(n_datapoints=30)
model0_samp.Rt
model0_samp.latent_infections
Expand All @@ -155,7 +155,7 @@ def test_model_basicrenewal_no_obs_model():
# Generating
model0.infection_obs_process_rv = NullObservation()
np.random.seed(223)
with numpyro.handlers.seed(rng_seed=np.random.randint(1, 600)):
with npro.handlers.seed(rng_seed=np.random.randint(1, 600)):
model1_samp = model0.sample(n_datapoints=30)

np.testing.assert_array_equal(model0_samp.Rt, model1_samp.Rt)
Expand Down Expand Up @@ -219,7 +219,7 @@ def test_model_basicrenewal_with_obs_model():

# Sampling and fitting model 1 (with obs infections)
np.random.seed(2203)
with numpyro.handlers.seed(rng_seed=np.random.randint(1, 600)):
with npro.handlers.seed(rng_seed=np.random.randint(1, 600)):
model1_samp = model1.sample(n_datapoints=30)

model1.run(
Expand Down Expand Up @@ -270,7 +270,7 @@ def test_model_basicrenewal_padding() -> None: # numpydoc ignore=GL08
# Sampling and fitting model 1 (with obs infections)
np.random.seed(2203)
pad_size = 5
with numpyro.handlers.seed(rng_seed=np.random.randint(1, 600)):
with npro.handlers.seed(rng_seed=np.random.randint(1, 600)):
model1_samp = model1.sample(n_datapoints=30, padding=pad_size)

model1.run(
Expand Down
6 changes: 3 additions & 3 deletions model/src/test/test_observation_negativebinom.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import numpy as np
import numpy.testing as testing
import numpyro
import numpyro as npro
from jax.typing import ArrayLike
from pyrenew.deterministic import DeterministicVariable
from pyrenew.observation import NegativeBinomialObservation
Expand All @@ -21,7 +21,7 @@ def test_negativebinom_deterministic_obs():

np.random.seed(223)
rates = np.random.randint(1, 5, size=10)
with numpyro.handlers.seed(rng_seed=np.random.randint(1, 600)):
with npro.handlers.seed(rng_seed=np.random.randint(1, 600)):
sim_nb1 = negb(mu=rates, obs=rates)
sim_nb2 = negb(mu=rates, obs=rates)

Expand All @@ -48,7 +48,7 @@ def test_negativebinom_random_obs():

np.random.seed(223)
rates = np.repeat(5, 20000)
with numpyro.handlers.seed(rng_seed=np.random.randint(1, 600)):
with npro.handlers.seed(rng_seed=np.random.randint(1, 600)):
sim_nb1 = negb(mu=rates)
sim_nb2 = negb(mu=rates)
assert isinstance(sim_nb1, tuple)
Expand Down
4 changes: 2 additions & 2 deletions model/src/test/test_observation_poisson.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import jax.numpy as jnp
import numpy as np
import numpy.testing as testing
import numpyro
import numpyro as npro
from pyrenew.observation import PoissonObservation


Expand All @@ -17,7 +17,7 @@ def test_poisson_obs():

np.random.seed(223)
rates = np.random.randint(1, 5, size=10)
with numpyro.handlers.seed(rng_seed=np.random.randint(1, 600)):
with npro.handlers.seed(rng_seed=np.random.randint(1, 600)):
sim_pois, *_ = pois(mu=rates)

testing.assert_array_equal(sim_pois, jnp.ceil(sim_pois))
Loading

0 comments on commit 0a3113f

Please sign in to comment.