Skip to content

Commit

Permalink
Add default observation rate to compute_delay_ascertained_incidence (
Browse files Browse the repository at this point in the history
…#437)

* reorder arguments and add new test

* add test description
  • Loading branch information
damonbayer authored Sep 10, 2024
1 parent 68c2b8d commit 5e149b1
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 3 deletions.
3 changes: 2 additions & 1 deletion pyrenew/convolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
:py:func:`jax.lax.scan` with an
appropriate array to scan along.
"""

from __future__ import annotations

from typing import Callable
Expand Down Expand Up @@ -166,9 +167,9 @@ def _new_scanner(


def compute_delay_ascertained_incidence(
p_observed_given_incident: ArrayLike,
latent_incidence: ArrayLike,
delay_incidence_to_observation_pmf: ArrayLike,
p_observed_given_incident: ArrayLike = 1,
) -> ArrayLike:
"""
Computes incidences observed according
Expand Down
2 changes: 1 addition & 1 deletion pyrenew/latent/hospitaladmissions.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,9 +211,9 @@ def sample(
) = self.infection_to_admission_interval_rv(**kwargs)

latent_hospital_admissions = compute_delay_ascertained_incidence(
infection_hosp_rate.value,
latent_infections.value,
infection_to_admission_interval.value,
infection_hosp_rate.value,
)

# Applying the day of the week effect. For this we need to:
Expand Down
19 changes: 18 additions & 1 deletion test/test_incidence_observed_with_delay.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,12 @@
jnp.array([0.25, 0.5, 0.25]),
jnp.array([2]),
],
[
jnp.array([1.0]),
jnp.array([0, 2.0, 4.0]),
jnp.array([0.25, 0.5, 0.25]),
jnp.array([2]),
],
],
)
def test(obs_rate, latent_incidence, delay_interval, expected_output):
Expand All @@ -42,8 +48,19 @@ def test(obs_rate, latent_incidence, delay_interval, expected_output):
incidence observed with a delay
"""
result = compute_delay_ascertained_incidence(
obs_rate,
latent_incidence,
delay_interval,
obs_rate,
)
assert_array_equal(result, expected_output)


def test_default_obs_rate():
"""
Compute incidence observed with a delay and default observation rate
"""
result = compute_delay_ascertained_incidence(
jnp.array([1.0, 2.0, 3.0]),
jnp.array([1.0]),
)
assert_array_equal(result, jnp.array([1.0, 2.0, 3.0]))

0 comments on commit 5e149b1

Please sign in to comment.