From 5e149b120dd2f6b3379ca9eaaa85fd68843f6741 Mon Sep 17 00:00:00 2001 From: Damon Bayer Date: Tue, 10 Sep 2024 11:28:16 -0500 Subject: [PATCH] Add default observation rate to `compute_delay_ascertained_incidence` (#437) * reorder arguments and add new test * add test description --- pyrenew/convolve.py | 3 ++- pyrenew/latent/hospitaladmissions.py | 2 +- test/test_incidence_observed_with_delay.py | 19 ++++++++++++++++++- 3 files changed, 21 insertions(+), 3 deletions(-) diff --git a/pyrenew/convolve.py b/pyrenew/convolve.py index 44d98e4c..55b2227f 100755 --- a/pyrenew/convolve.py +++ b/pyrenew/convolve.py @@ -11,6 +11,7 @@ :py:func:`jax.lax.scan` with an appropriate array to scan along. """ + from __future__ import annotations from typing import Callable @@ -166,9 +167,9 @@ def _new_scanner( def compute_delay_ascertained_incidence( - p_observed_given_incident: ArrayLike, latent_incidence: ArrayLike, delay_incidence_to_observation_pmf: ArrayLike, + p_observed_given_incident: ArrayLike = 1, ) -> ArrayLike: """ Computes incidences observed according diff --git a/pyrenew/latent/hospitaladmissions.py b/pyrenew/latent/hospitaladmissions.py index 57090528..1fcfc581 100644 --- a/pyrenew/latent/hospitaladmissions.py +++ b/pyrenew/latent/hospitaladmissions.py @@ -211,9 +211,9 @@ def sample( ) = self.infection_to_admission_interval_rv(**kwargs) latent_hospital_admissions = compute_delay_ascertained_incidence( - infection_hosp_rate.value, latent_infections.value, infection_to_admission_interval.value, + infection_hosp_rate.value, ) # Applying the day of the week effect. For this we need to: diff --git a/test/test_incidence_observed_with_delay.py b/test/test_incidence_observed_with_delay.py index 9bff64ad..bcb1ff66 100644 --- a/test/test_incidence_observed_with_delay.py +++ b/test/test_incidence_observed_with_delay.py @@ -34,6 +34,12 @@ jnp.array([0.25, 0.5, 0.25]), jnp.array([2]), ], + [ + jnp.array([1.0]), + jnp.array([0, 2.0, 4.0]), + jnp.array([0.25, 0.5, 0.25]), + jnp.array([2]), + ], ], ) def test(obs_rate, latent_incidence, delay_interval, expected_output): @@ -42,8 +48,19 @@ def test(obs_rate, latent_incidence, delay_interval, expected_output): incidence observed with a delay """ result = compute_delay_ascertained_incidence( - obs_rate, latent_incidence, delay_interval, + obs_rate, ) assert_array_equal(result, expected_output) + + +def test_default_obs_rate(): + """ + Compute incidence observed with a delay and default observation rate + """ + result = compute_delay_ascertained_incidence( + jnp.array([1.0, 2.0, 3.0]), + jnp.array([1.0]), + ) + assert_array_equal(result, jnp.array([1.0, 2.0, 3.0]))