Skip to content

Commit

Permalink
Fixing broken tests
Browse files Browse the repository at this point in the history
  • Loading branch information
gvegayon committed May 23, 2024
1 parent b63acbe commit 219e710
Show file tree
Hide file tree
Showing 6 changed files with 38 additions and 8 deletions.
9 changes: 6 additions & 3 deletions model/src/pyrenew/latent/i0.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# numpydoc ignore=GL08


import jax.numpy as jnp
import numpyro as npro
import numpyro.distributions as dist
from pyrenew.metaclass import RandomVariable
Expand Down Expand Up @@ -80,8 +81,10 @@ def sample(
Tuple with the initial infections.
"""
return (
npro.sample(
name=self.name,
fn=self.i0_dist,
jnp.atleast_1d(
npro.sample(
name=self.name,
fn=self.i0_dist,
)
),
)
4 changes: 4 additions & 0 deletions model/src/pyrenew/model/rtinfectionsrenewalmodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import NamedTuple

import jax.numpy as jnp
import pyrenew.datautils as du
from numpy.typing import ArrayLike
from pyrenew.deterministic import NullObservation
from pyrenew.metaclass import Model, RandomVariable, _assert_sample_and_rtype
Expand Down Expand Up @@ -282,6 +283,9 @@ def sample(
# Sampling initial infections
i0, *_ = self.sample_i0(**kwargs)

# Padding i0 to match gen_int
i0 = du.pad_x_to_match_y(x=i0, y=gen_int, fill_value=0.0)

# Sampling from the latent process
latent, *_ = self.sample_infections_latent(
Rt=Rt,
Expand Down
Binary file modified model/src/test/baseline/test_model_basicrenewal_plot.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
6 changes: 5 additions & 1 deletion model/src/test/test_latent_infections.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,12 @@ def test_infections_as_deterministic():

inf1 = Infections()

obs = dict(
Rt=sim_rt,
I0=jnp.repeat(0, repeats=gen_int.size),
gen_int=gen_int,
)
with npro.handlers.seed(rng_seed=np.random.randint(1, 600)):
obs = dict(Rt=sim_rt, I0=10, gen_int=gen_int)
inf_sampled1 = inf1.sample(**obs)
inf_sampled2 = inf1.sample(**obs)

Expand Down
24 changes: 23 additions & 1 deletion model/src/test/test_model_basic_renewal.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,26 @@ def test_model_basicrenewal_with_obs_model():


@pytest.mark.mpl_image_compare
def test_model_basicrenewal_plot() -> plt.Figure: # numpydoc ignore=GL08
def test_model_basicrenewal_plot() -> plt.Figure:
"""
Check that the posterior sample looks the same (reproducibility)
Returns
-------
plt.Figure
The figure object
Notes
-----
IMPORTANT: If this test breaks, then it could be that you need
to regenerate the figures. To do so, you can the test using the following
command:
poetry run pytest --mpl-generate-path=src/test/baseline
This will skip validating the figure and save the new figure in the
`src/test/baseline` folder.
"""
gen_int = DeterministicPMF(jnp.array([0.25, 0.25, 0.25, 0.25]))

I0 = Infections0(I0_dist=dist.LogNormal(0, 1))
Expand Down Expand Up @@ -214,3 +233,6 @@ def test_model_basicrenewal_padding() -> None: # numpydoc ignore=GL08
# For now the assertion is only about the expected number of rows
# It should be about the MCMC inference.
assert inf_mean.to_numpy().shape[0] == 500


test_model_basicrenewal_plot()
3 changes: 0 additions & 3 deletions model/src/test/test_model_hospitalizations.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,3 @@ def test_model_hosp_with_obs_model_weekday_phosp():
# For now the assertion is only about the expected number of rows
# It should be about the MCMC inference.
assert inf_mean.to_numpy().shape[0] == 500


test_model_hosp_no_obs_model()

0 comments on commit 219e710

Please sign in to comment.