Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Process returns tuple with array equal to duration (patch to PR 123) #143

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 3 additions & 6 deletions model/docs/example-with-datasets.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -198,11 +198,11 @@ import matplotlib.pyplot as plt
fig, axs = plt.subplots(1, 2)

# Rt plot
axs[0].plot(range(0, timeframe + 1), sim_data.Rt)
axs[0].plot(range(0, timeframe), sim_data.Rt)
axs[0].set_ylabel('Rt')

# Infections plot
axs[1].plot(range(0, timeframe + 1), sim_data.sampled_admissions)
axs[1].plot(range(0, timeframe), sim_data.sampled_admissions)
axs[1].set_ylabel('Infections')
axs[1].set_yscale('log')

Expand All @@ -214,7 +214,7 @@ plt.show()

## Fitting the model

We can fit the model to the data. We will use the `run` method of the model object. The two inputs this model requires are `n_timepoints` and `observed_admissions`
We can fit the model to the data. We will use the `run` method of the model object:


```{python}
Expand All @@ -224,7 +224,6 @@ import jax
hosp_model.run(
num_samples=2000,
num_warmup=2000,
n_timepoints=dat.shape[0] - 1,
observed_admissions=dat["daily_hosp_admits"].to_numpy(),
rng_key=jax.random.PRNGKey(54),
mcmc_args=dict(progress_bar=False),
Expand Down Expand Up @@ -263,7 +262,6 @@ dat_w_padding = np.hstack((np.repeat(np.nan, days_to_impute), dat_w_padding))
hosp_model.run(
num_samples=2000,
num_warmup=2000,
n_timepoints=dat_w_padding.shape[0] - 1,
observed_admissions=dat_w_padding,
rng_key=jax.random.PRNGKey(54),
mcmc_args=dict(progress_bar=False),
Expand Down Expand Up @@ -361,7 +359,6 @@ Running the model (with the same padding as before):
hosp_model_weekday.run(
num_samples=2000,
num_warmup=2000,
n_timepoints=dat_w_padding.shape[0] - 1,
observed_admissions=dat_w_padding,
rng_key=jax.random.PRNGKey(54),
mcmc_args=dict(progress_bar=False),
Expand Down
5 changes: 2 additions & 3 deletions model/docs/getting-started.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -154,11 +154,11 @@ import matplotlib.pyplot as plt
fig, axs = plt.subplots(1, 2)

# Rt plot
axs[0].plot(range(0, 31), sim_data.Rt)
axs[0].plot(range(0, len(sim_data.Rt)), sim_data.Rt)
axs[0].set_ylabel('Rt')

# Infections plot
axs[1].plot(range(0, 31), sim_data.sampled_infections)
axs[1].plot(range(0, len(sim_data.Rt)), sim_data.sampled_infections)
axs[1].set_ylabel('Infections')

fig.suptitle('Basic renewal model')
Expand All @@ -177,7 +177,6 @@ model1.run(
num_warmup=2000,
num_samples=1000,
observed_infections=sim_data.sampled_infections,
n_timepoints = len(sim_data[1])-1,
rng_key=jax.random.PRNGKey(54),
mcmc_args=dict(progress_bar=False),
)
Expand Down
1 change: 0 additions & 1 deletion model/docs/pyrenew_demo.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,6 @@ hospmodel.run(
num_warmup=1000,
num_samples=1000,
observed_admissions=x.sampled_admissions,
n_timepoints = len(x.sampled_admissions)-1,
rng_key=jax.random.PRNGKey(54),
mcmc_args=dict(progress_bar=False),
)
Expand Down
6 changes: 3 additions & 3 deletions model/src/pyrenew/deterministic/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class DeterministicProcess(DeterministicVariable):

def sample(
self,
n_timepoints: int,
duration: int,
**kwargs,
) -> tuple:
"""
Expand All @@ -34,9 +34,9 @@ def sample(

res, *_ = super().sample(**kwargs)

dif = n_timepoints - res.shape[0]
dif = duration - res.shape[0]

if dif > 0:
return (jnp.hstack([res, jnp.repeat(res[-1], dif)]),)

return (res[:n_timepoints],)
return (res[:duration],)
23 changes: 21 additions & 2 deletions model/src/pyrenew/model/admissionsmodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@

def sample(
self,
n_timepoints: int,
n_timepoints: int | None = None,
observed_admissions: ArrayLike | None = None,
padding: int = 0,
**kwargs,
Expand All @@ -194,7 +194,7 @@

Parameters
----------
n_timepoints : int
n_timepoints : int, optional
Number of timepoints to sample (passed to the basic renewal model).
observed_admissions : ArrayLike, optional
The observed hospitalization data (passed to the basic renewal
Expand All @@ -206,6 +206,12 @@
Additional keyword arguments passed through to internal sample()
calls, should there be any.

Notes
-----
When `observed_admissions` is None, `n_timepoints` must be specified.
If both are specified, they must have the same length, otherwise an
exception is raised.

Returns
-------
HospModelSample
Expand All @@ -217,6 +223,19 @@
sample_observed_admissions : For sampling observed hospital admissions
"""

if n_timepoints is None:
if observed_admissions is not None:
n_timepoints = len(observed_admissions)
else:
raise ValueError(

Check warning on line 230 in model/src/pyrenew/model/admissionsmodel.py

View check run for this annotation

Codecov / codecov/patch

model/src/pyrenew/model/admissionsmodel.py#L230

Added line #L230 was not covered by tests
"n_timepoints must be specified if observed_admissions is None"
)
elif observed_admissions is not None:
if n_timepoints != len(observed_admissions):
raise ValueError(

Check warning on line 235 in model/src/pyrenew/model/admissionsmodel.py

View check run for this annotation

Codecov / codecov/patch

model/src/pyrenew/model/admissionsmodel.py#L234-L235

Added lines #L234 - L235 were not covered by tests
"n_timepoints must match length of observed_admissions"
)

# Getting the initial quantities from the basic model
basic_model = self.basic_renewal.sample(
n_timepoints=n_timepoints,
Expand Down
29 changes: 26 additions & 3 deletions model/src/pyrenew/model/rtinfectionsrenewalmodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,11 @@
sampled_infections: ArrayLike | None = None

def __repr__(self):
return f"RtInfectionsRenewalSample(Rt={self.Rt}, latent_infections={self.latent_infections}, sampled_infections={self.sampled_infections})"
return (

Check warning on line 35 in model/src/pyrenew/model/rtinfectionsrenewalmodel.py

View check run for this annotation

Codecov / codecov/patch

model/src/pyrenew/model/rtinfectionsrenewalmodel.py#L35

Added line #L35 was not covered by tests
f"RtInfectionsRenewalSample(Rt={self.Rt}, "
f"latent_infections={self.latent_infections}, "
f"sampled_infections={self.sampled_infections})"
)


class RtInfectionsRenewalModel(Model):
Expand Down Expand Up @@ -246,7 +250,7 @@

def sample(
self,
n_timepoints: int,
n_timepoints: int | None = None,
observed_infections: ArrayLike | None = None,
padding: int = 0,
**kwargs,
Expand All @@ -256,7 +260,7 @@

Parameters
----------
n_timepoints : int
n_timepoints : int, optional
Number of timepoints to sample.
observed_infections : ArrayLike | None, optional
Observed infections. Defaults to None.
Expand All @@ -267,11 +271,30 @@
Additional keyword arguments passed through to internal sample()
calls, if any

Notes
-----
When `observed_admissions` is None, `n_timepoints` must be specified.
If both are specified, they must have the same length, otherwise an
exception is raised.

Returns
-------
RtInfectionsRenewalSample
"""

if n_timepoints is None:
if observed_infections is not None:
n_timepoints = len(observed_infections)

Check warning on line 287 in model/src/pyrenew/model/rtinfectionsrenewalmodel.py

View check run for this annotation

Codecov / codecov/patch

model/src/pyrenew/model/rtinfectionsrenewalmodel.py#L286-L287

Added lines #L286 - L287 were not covered by tests
else:
raise ValueError(

Check warning on line 289 in model/src/pyrenew/model/rtinfectionsrenewalmodel.py

View check run for this annotation

Codecov / codecov/patch

model/src/pyrenew/model/rtinfectionsrenewalmodel.py#L289

Added line #L289 was not covered by tests
"n_timepoints or observed_infections must be provided."
)
elif observed_infections is not None:
if n_timepoints != len(observed_infections):
raise ValueError(

Check warning on line 294 in model/src/pyrenew/model/rtinfectionsrenewalmodel.py

View check run for this annotation

Codecov / codecov/patch

model/src/pyrenew/model/rtinfectionsrenewalmodel.py#L294

Added line #L294 was not covered by tests
"n_timepoints and observed_infections must have the same length."
)

# Sampling from Rt (possibly with a given Rt, depending on
# the Rt_process (RandomVariable) object.)
Rt, *_ = self.sample_rt(
Expand Down
4 changes: 3 additions & 1 deletion model/src/pyrenew/process/ar.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ def sample(
Returns
-------
tuple
With a single array of shape (duration,).
"""
order = self.autoreg.shape[0]
if inits is None:
Expand All @@ -85,7 +86,8 @@ def _ar_scanner(carry, next): # numpydoc ignore=GL08
return new_carry, new_term

noise = numpyro.sample(
name + "_noise", dist.Normal(0, self.noise_sd).expand((duration,))
name + "_noise",
dist.Normal(0, self.noise_sd).expand((duration - inits.size,)),
)

last, ts = lax.scan(_ar_scanner, inits - self.mean, noise)
Expand Down
1 change: 1 addition & 0 deletions model/src/pyrenew/process/firstdifferencear.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def sample(
Returns
-------
tuple
With a single array of shape (duration,).
"""
rocs, *_ = self.rate_of_change_proc.sample(
duration, inits=init_rate_of_change, name=name + "_rate_of_change"
Expand Down
1 change: 1 addition & 0 deletions model/src/pyrenew/process/rtrandomwalk.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ def sample(
Returns
-------
tuple
With a single array of shape (duration,).
"""

Rt0 = npro.sample("Rt0", self.Rt0_dist)
Expand Down
3 changes: 2 additions & 1 deletion model/src/pyrenew/process/simplerandomwalk.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,13 @@ def sample(
Returns
-------
tuple
With a single array of shape (duration,).
"""

if init is None:
init = npro.sample(name + "_init", self.error_distribution)
diffs = npro.sample(
name + "_diffs", self.error_distribution.expand((duration,))
name + "_diffs", self.error_distribution.expand((duration - 1,))
)

return (init + jnp.cumsum(jnp.pad(diffs, [1, 0], constant_values=0)),)
Expand Down
Binary file modified model/src/test/baseline/test_model_basicrenewal_plot.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
4 changes: 2 additions & 2 deletions model/src/test/test_deterministic.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,12 @@ def test_deterministic():
jnp.array([0.25, 0.25, 0.2, 0.3]),
)
testing.assert_array_equal(
var3.sample(n_timepoints=5)[0],
var3.sample(duration=5)[0],
jnp.array([1, 2, 3, 4, 4]),
)

testing.assert_array_equal(
var3.sample(n_timepoints=3)[0],
var3.sample(duration=3)[0],
jnp.array(
[
1,
Expand Down
8 changes: 6 additions & 2 deletions model/src/test/test_first_difference_ar.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,13 @@ def test_fd_ar_can_be_sampled():
with numpyro.handlers.seed(rng_seed=62):
# can sample with and without inits
# for the rate of change
ar_fd.sample(3532, init_val=jnp.array([50.0]))
ar_fd.sample(
ans0 = ar_fd.sample(3532, init_val=jnp.array([50.0]))
ans1 = ar_fd.sample(
3532,
init_val=jnp.array([50.0]),
init_rate_of_change=jnp.array([0.25]),
)

# Checking proper shape
assert ans0[0].shape == (3532,)
assert ans1[0].shape == (3532,)
10 changes: 4 additions & 6 deletions model/src/test/test_model_hospitalizations.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,6 @@ def test_model_hosp_with_obs_model():
num_samples=500,
rng_key=jax.random.PRNGKey(272),
observed_admissions=model1_samp.sampled_admissions,
n_timepoints=30,
)

inf = model1.spread_draws(["predicted_admissions"])
Expand Down Expand Up @@ -286,7 +285,6 @@ def test_model_hosp_with_obs_model_weekday_phosp_2():
num_samples=500,
rng_key=jax.random.PRNGKey(272),
observed_admissions=model1_samp.sampled_admissions,
n_timepoints=30,
)

inf = model1.spread_draws(["predicted_admissions"])
Expand All @@ -307,6 +305,7 @@ def test_model_hosp_with_obs_model_weekday_phosp():
"""

gen_int = DeterministicPMF(jnp.array([0.25, 0.25, 0.25, 0.25]))
n_obs_to_generate = 30

I0 = DistributionalRV(dist=dist.LogNormal(0, 1), name="I0")

Expand Down Expand Up @@ -343,15 +342,15 @@ def test_model_hosp_with_obs_model_weekday_phosp():
weekday = jnp.array([1, 1, 1, 1, 2, 2])
weekday = jnp.tile(weekday, 10)
weekday = weekday / weekday.sum()
weekday = weekday[:31]
weekday = weekday[:n_obs_to_generate]

weekday = DeterministicVariable(weekday)

hosp_report_prob_dist = jnp.array([0.9, 0.8, 0.7, 0.7, 0.6, 0.4])
hosp_report_prob_dist = jnp.tile(hosp_report_prob_dist, 10)
hosp_report_prob_dist = hosp_report_prob_dist / hosp_report_prob_dist.sum()

hosp_report_prob_dist = hosp_report_prob_dist[:31]
hosp_report_prob_dist = hosp_report_prob_dist[:n_obs_to_generate]

hosp_report_prob_dist = DeterministicVariable(vars=hosp_report_prob_dist)

Expand All @@ -376,7 +375,7 @@ def test_model_hosp_with_obs_model_weekday_phosp():
# Sampling and fitting model 0 (with no obs for infections)
np.random.seed(223)
with npro.handlers.seed(rng_seed=np.random.randint(1, 600)):
model1_samp = model1.sample(n_timepoints=30)
model1_samp = model1.sample(n_timepoints=n_obs_to_generate)

obs = jnp.hstack(
[jnp.repeat(jnp.nan, 5), model1_samp.sampled_admissions[5:]]
Expand All @@ -388,7 +387,6 @@ def test_model_hosp_with_obs_model_weekday_phosp():
num_samples=500,
rng_key=jax.random.PRNGKey(272),
observed_admissions=obs,
n_timepoints=30,
padding=5,
)

Expand Down
11 changes: 9 additions & 2 deletions model/src/test/test_random_walk.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,12 @@ def test_rw_can_be_sampled():

with numpyro.handlers.seed(rng_seed=62):
# can sample with and without inits
rw_normal.sample(3532, init=jnp.array([50.0]))
rw_normal.sample(5023)
ans0 = rw_normal.sample(3532, init=jnp.array([50.0]))
ans1 = rw_normal.sample(5023)

# check that the samples are of the right shape
assert ans0[0].shape == (3532,)
assert ans1[0].shape == (5023,)


def test_rw_samples_correctly_distributed():
Expand All @@ -35,6 +39,9 @@ def test_rw_samples_correctly_distributed():
with numpyro.handlers.seed(rng_seed=62):
samples, *_ = rw_normal.sample(n_samples, init=init_arr)

# Checking the shape
assert samples.shape == (n_samples,)

# diffs should not be greater than
# 4 sigma
diffs = jnp.diff(samples)
Expand Down