Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update infectionswithfeedback process #440

Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
a10b95c
add function to pad using edge values and tests
sbidari Sep 10, 2024
f89aa9d
modify convolve and infectionwithfeedback to work with 2d array
sbidari Sep 10, 2024
b7fb437
Merge branch 'main' into 435-update-infectionswithfeedback-process-to…
sbidari Sep 10, 2024
b0e8a3d
test convolve scanner for 2d arrays
sbidari Sep 11, 2024
1862aee
Merge branch 'main' of https://github.com/CDCgov/PyRenew into 435-upd…
sbidari Sep 11, 2024
114af6a
add tests for 2d array, infectionsrtfeedback
sbidari Sep 11, 2024
a625fde
more tests
sbidari Sep 11, 2024
91be92f
remove pad to match functions
sbidari Sep 12, 2024
396c2d7
remove tests
sbidari Sep 12, 2024
8197059
use list comprehension in test_infectionsrtfeedback
sbidari Sep 12, 2024
97d19d0
revert to using for loop
sbidari Sep 12, 2024
d933df4
add more test for convolve scanner functions
sbidari Sep 12, 2024
39667c1
Merge branch 'main' into 435-update-infectionswithfeedback-process-to…
sbidari Sep 12, 2024
c210606
add check for initial infections and Rt ndims
sbidari Sep 12, 2024
5e91418
Merge branch '435-update-infectionswithfeedback-process-to-handle-bat…
sbidari Sep 12, 2024
f5557a3
remove typos and superfluous print statements
sbidari Sep 12, 2024
46e32e8
add more tests
sbidari Sep 12, 2024
728aad3
change value of inf_feedback in tests to cover diff scenarios
sbidari Sep 12, 2024
54c0527
add input array tests for infections.py
sbidari Sep 12, 2024
a6ce6c6
add test with plate
sbidari Sep 13, 2024
f37a229
Merge branch 'main' into 435-update-infectionswithfeedback-process-to…
sbidari Sep 13, 2024
cc61d0b
code review changes
sbidari Sep 13, 2024
a4a74f0
Merge branch '435-update-infectionswithfeedback-process-to-handle-bat…
sbidari Sep 13, 2024
0cd56ad
code review suggestion
sbidari Sep 13, 2024
0f3e2fd
replace jnp.dot with einsum
sbidari Sep 13, 2024
542ad84
code review suggestions
sbidari Sep 16, 2024
bafe3b0
code review changes
sbidari Sep 17, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions pyrenew/latent/infectionswithfeedback.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,10 @@ def sample(
**kwargs,
)
)

if inf_feedback_strength.ndim < Rt.ndim:
inf_feedback_strength = jnp.expand_dims(inf_feedback_strength, 0)
sbidari marked this conversation as resolved.
Show resolved Hide resolved
damonbayer marked this conversation as resolved.
Show resolved Hide resolved

# Making sure inf_feedback_strength spans the Rt length
if inf_feedback_strength.shape[0] == 1:
inf_feedback_strength = au.pad_edges_to_match(
Expand Down
79 changes: 79 additions & 0 deletions test/test_infection_and_infectionwithfeedback.py
damonbayer marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
"""
Test to verify Infection and InfectionsWithFeedback class
return error when input array shape for I0 and Rt are invalid
"""

import jax.numpy as jnp
import numpy as np
import numpyro
import pytest

import pyrenew.latent as latent
from pyrenew.deterministic import DeterministicPMF, DeterministicVariable


def test_infections_with_feedback_invalid_inputs():
"""
Test the InfectionsWithFeedback class cannot
be sampled when Rt and I0 have invalid input shapes
"""
I0_1d = jnp.array([0.5, 0.6, 0.7, 0.8])
I0_2d = jnp.array(
np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0] * 3)
).reshape((7, -1))
Rt = jnp.ones(10)
gen_int = jnp.array([0.4, 0.25, 0.25, 0.1, 0.0, 0.0, 0.0])

inf_feed_strength = DeterministicVariable(
name="inf_feed_strength", value=0.5
)
inf_feedback_pmf = DeterministicPMF(name="inf_feedback_pmf", value=gen_int)

# Test the InfectionsWithFeedback class
InfectionsWithFeedback = latent.InfectionsWithFeedback(
infection_feedback_strength=inf_feed_strength,
infection_feedback_pmf=inf_feedback_pmf,
)

infections = latent.Infections()

with numpyro.handlers.seed(rng_seed=0):
with pytest.raises(
ValueError,
match="Initial infections must be at least as long as the generation interval.",
):
InfectionsWithFeedback(
gen_int=gen_int,
Rt=Rt,
I0=I0_1d,
)

with pytest.raises(
ValueError,
match="Initial infections vector must be at least as long as the generation interval.",
):
infections(
gen_int=gen_int,
Rt=Rt,
I0=I0_1d,
)

with pytest.raises(
ValueError,
match="Initial infections and Rt must have the same dimensions.",
):
InfectionsWithFeedback(
gen_int=gen_int,
Rt=Rt,
I0=I0_2d,
)

with pytest.raises(
ValueError,
match="Initial infections and Rt must have the same dimensions.",
):
infections(
gen_int=gen_int,
Rt=Rt,
I0=I0_2d,
)
69 changes: 1 addition & 68 deletions test/test_infectionsrtfeedback.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def test_infectionsrtfeedback(Rt, I0):
# By doing the infection feedback strength 0, Rt = Rt_adjusted
# So infection should be equal in both
inf_feed_strength = DeterministicVariable(
name="inf_feed_strength", value=jnp.zeros_like(Rt)
name="inf_feed_strength", value=jnp.array(0)
)
inf_feedback_pmf = DeterministicPMF(name="inf_feedback_pmf", value=gen_int)

Expand Down Expand Up @@ -193,70 +193,3 @@ def test_infectionsrtfeedback_feedback(Rt, I0):
assert_array_almost_equal(samp1.rt, res["rt"])

return None


def test_infections_with_feedback_invalid_inputs():
"""
Test the InfectionsWithFeedback class cannot
be sampled when Rt and I0 have invalid input shapes
"""
I0_1d = jnp.array([0.5, 0.6, 0.7, 0.8])
I0_2d = jnp.array(
np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0] * 3)
).reshape((7, -1))
Rt = jnp.ones(10)
gen_int = jnp.array([0.4, 0.25, 0.25, 0.1, 0.0, 0.0, 0.0])

inf_feed_strength = DeterministicVariable(
name="inf_feed_strength", value=0.5
)
inf_feedback_pmf = DeterministicPMF(name="inf_feedback_pmf", value=gen_int)

# Test the InfectionsWithFeedback class
InfectionsWithFeedback = latent.InfectionsWithFeedback(
infection_feedback_strength=inf_feed_strength,
infection_feedback_pmf=inf_feedback_pmf,
)

infections = latent.Infections()

with numpyro.handlers.seed(rng_seed=0):
with pytest.raises(
ValueError,
match="Initial infections must be at least as long as the generation interval.",
):
InfectionsWithFeedback(
gen_int=gen_int,
Rt=Rt,
I0=I0_1d,
)

with pytest.raises(
ValueError,
match="Initial infections vector must be at least as long as the generation interval.",
):
infections(
gen_int=gen_int,
Rt=Rt,
I0=I0_1d,
)

with pytest.raises(
ValueError,
match="Initial infections and Rt must have the same dimensions.",
):
InfectionsWithFeedback(
gen_int=gen_int,
Rt=Rt,
I0=I0_2d,
)

with pytest.raises(
ValueError,
match="Initial infections and Rt must have the same dimensions.",
):
infections(
gen_int=gen_int,
Rt=Rt,
I0=I0_2d,
)
44 changes: 44 additions & 0 deletions test/test_infectionwithfeedback_plate_compatibility.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
"""
Test the InfectionsWithFeedback class works well within numpyro plate
"""

import jax.numpy as jnp
import numpy as np
import numpyro
import numpyro.distributions as dist

import pyrenew.latent as latent
from pyrenew.deterministic import DeterministicPMF
from pyrenew.randomvariable import DistributionalVariable


def test_infections_with_feedback_plate_compatibility():
"""
Test the InfectionsWithFeedback matching the Infections class.
"""
I0 = jnp.array(
np.array([0.0, 0.0, 0.0, 0.5, 0.6, 0.7, 0.8] * 5).reshape(-1, 5)
)
Rt = jnp.ones((10, 5))
dylanhmorris marked this conversation as resolved.
Show resolved Hide resolved
gen_int = jnp.array([0.4, 0.25, 0.25, 0.1])

inf_feed_strength = DistributionalVariable(
"inf_feed_strength", dist.Beta(1, 1)
)
inf_feedback_pmf = DeterministicPMF(name="inf_feedback_pmf", value=gen_int)

# Test the InfectionsWithFeedback class
InfectionsWithFeedback = latent.InfectionsWithFeedback(
infection_feedback_strength=inf_feed_strength,
infection_feedback_pmf=inf_feedback_pmf,
)

with numpyro.handlers.seed(rng_seed=0):
with numpyro.plate("test_plate", 5):
samp = InfectionsWithFeedback(
gen_int=gen_int,
Rt=Rt,
I0=I0,
)

assert samp.rt.shape == Rt.shape