diff --git a/model/src/pyrenew/convolve.py b/model/src/pyrenew/convolve.py index ae838d2d..44797805 100755 --- a/model/src/pyrenew/convolve.py +++ b/model/src/pyrenew/convolve.py @@ -7,9 +7,11 @@ calculating convolutions of timeseries with discrete distributions of times-to-event using -jax.lax.scan. Factories generate functions -that can be passed to scan with an -appropriate array to scan. +:py:func:`jax.lax.scan`. +Factories generate functions +that can be passed to +:py:func:`jax.lax.scan` with an +appropriate array to scan along. """ from __future__ import annotations @@ -19,59 +21,137 @@ from jax.typing import ArrayLike -def new_convolve_scanner(discrete_dist_flipped: ArrayLike) -> Callable: - """ - Factory function to create a scanner function for convolving a discrete distribution - over a time series data subset. +def new_convolve_scanner( + array_to_convolve: ArrayLike, + transform: Callable, +) -> Callable: + r""" + Factory function to create a "scanner" function + that can be used with :py:func:`jax.lax.scan` to + construct an array via backward-looking iterative + convolution. Parameters ---------- - discrete_dist_flipped : ArrayLike - A 1D jax array representing the discrete distribution flipped for convolution. + array_to_convolve : ArrayLike + A 1D jax array to convolve with subsets of the + iteratively constructed history array. + + transform : Callable + A transformation to apply to the result + of the dot product and multiplication. Returns ------- Callable - A scanner function that can be used with jax.lax.scan for convolution. - This function takes a history subset and a multiplier, computes the dot product, - then updates and returns the new history subset and the convolution result. + A scanner function that can be used with + :py:func:`jax.lax.scan` for convolution. + This function takes a history subset array and + a scalar, computes the dot product of + the supplied convolution array with the history + subset array, multiplies by the scalar, and + returns the resulting value and a new history subset + array formed by the 2nd-through-last entries + of the old history subset array followed by that same + resulting value. + + Notes + ----- + The following iterative operation is found often + in renewal processes: + + .. math:: + X(t) = f\left(m(t) \begin{bmatrix} X(t - n) \\ X(t - n + 1) \\ + \vdots{} \\ X(t - 1)\end{bmatrix} \cdot{} \mathbf{d} \right) + + Where :math:`\mathbf{d}` is a vector of length :math:`n`, + :math:`m(t)` is a scalar for each value of time :math:`t`, + and :math:`f` is a scalar-valued function. + + Given :math:`\mathbf{d}`, and optionally :math:`f`, + this factory function returns a new function that + peforms one step of this process while scanning along + an array of multipliers (i.e. an array + giving the values of :math:`m(t)`) using :py:func:`jax.lax.scan`. """ def _new_scanner( history_subset: ArrayLike, multiplier: float ) -> tuple[ArrayLike, float]: # numpydoc ignore=GL08 - new_val = multiplier * jnp.dot(discrete_dist_flipped, history_subset) + new_val = transform( + multiplier * jnp.dot(array_to_convolve, history_subset) + ) latest = jnp.hstack([history_subset[1:], new_val]) return latest, new_val return _new_scanner -def new_double_scanner( - dists: tuple[ArrayLike, ArrayLike], +def new_double_convolve_scanner( + arrays_to_convolve: tuple[ArrayLike, ArrayLike], transforms: tuple[Callable, Callable], ) -> Callable: - """ - Factory function to create a scanner function that applies two sequential transformations - and convolutions using two discrete distributions. + r""" + Factory function to create a scanner function + that iteratively constructs arrays by applying + the dot-product/multiply/transform operation + twice per history subset, with the first yielding + operation yielding an additional scalar multiplier + for the second. Parameters ---------- - dists : tuple[ArrayLike, ArrayLike] - A tuple of two 1D jax arrays, each representing a discrete distribution for the - two stages of convolution. + arrays_to_convolve : tuple[ArrayLike, ArrayLike] + A tuple of two 1D jax arrays, one for + each of the two stages of convolution. + The first entry in the arrays_to_convolve + tuple will be convolved with the + current history subset array first, the + the second entry will be convolved with + it second. transforms : tuple[Callable, Callable] - A tuple of two functions, each transforming the output of the dot product at each - convolution stage. + A tuple of two functions, each transforming the + output of the dot product at each + convolution stage. The first entry in the transforms + tuple will be applied first, then the second will + be applied. Returns ------- Callable - A scanner function that applies two sequential convolutions and transformations. - It takes a history subset and a tuple of multipliers, computes the transformations - and dot products, and returns the updated history subset and a tuple of new values. + A scanner function that applies two sets of + convolution, multiply, and transform operations + in sequence to construct a new array by scanning + along a pair of input arrays that are equal in + length to each other. + + Notes + ----- + Using the same notation as in the documentation for + :func:`new_convolve_scanner`, this function aids in + applying the iterative operation: + + .. math:: + \begin{aligned} + Y(t) &= f_1 \left(m_1(t) + \begin{bmatrix} + X(t - n) \\ + X(t - n + 1) \\ + \vdots{} \\ + X(t - 1) + \end{bmatrix} \cdot{} \mathbf{d}_1 \right) \\ \\ + X(t) &= f_2 \left( + m_2(t) Y(t) + \begin{bmatrix} X(t - n) \\ X(t - n + 1) \\ + \vdots{} \\ X(t - 1)\end{bmatrix} \cdot{} \mathbf{d}_2 \right) + \end{aligned} + + Where :math:`\mathbf{d}_1` and :math:`\mathbf{d}_2` are vectors of + length :math:`n`, :math:`m_1(t)` and :math:`m_2(t)` are scalars + for each value of time :math:`t`, and :math:`f_1` and :math:`f_2` + are scalar-valued functions. """ - d1, d2 = dists + arr1, arr2 = arrays_to_convolve t1, t2 = transforms def _new_scanner( @@ -79,8 +159,8 @@ def _new_scanner( multipliers: tuple[float, float], ) -> tuple[ArrayLike, tuple[float, float]]: # numpydoc ignore=GL08 m1, m2 = multipliers - m_net1 = t1(m1 * jnp.dot(d1, history_subset)) - new_val = t2(m2 * m_net1 * jnp.dot(d2, history_subset)) + m_net1 = t1(m1 * jnp.dot(arr1, history_subset)) + new_val = t2(m2 * m_net1 * jnp.dot(arr2, history_subset)) latest = jnp.hstack([history_subset[1:], new_val]) return latest, (new_val, m_net1) diff --git a/model/src/pyrenew/latent/infection_functions.py b/model/src/pyrenew/latent/infection_functions.py index 565e2d54..91622360 100755 --- a/model/src/pyrenew/latent/infection_functions.py +++ b/model/src/pyrenew/latent/infection_functions.py @@ -6,7 +6,8 @@ import jax import jax.numpy as jnp from jax.typing import ArrayLike -from pyrenew.convolve import new_convolve_scanner, new_double_scanner +from pyrenew.convolve import new_convolve_scanner, new_double_convolve_scanner +from pyrenew.transformation import ExpTransform, IdentityTransform def compute_infections_from_rt( @@ -40,7 +41,9 @@ def compute_infections_from_rt( ArrayLike The timeseries of infections, as a JAX array """ - incidence_func = new_convolve_scanner(reversed_generation_interval_pmf) + incidence_func = new_convolve_scanner( + reversed_generation_interval_pmf, IdentityTransform() + ) latest, all_infections = jax.lax.scan(f=incidence_func, init=I0, xs=Rt) @@ -75,6 +78,8 @@ def logistic_susceptibility_adjustment( float The adjusted value of I(t) + References + ---------- .. [1] Bhatt, Samir, et al. "Semi-mechanistic Bayesian modelling of COVID-19 with renewal processes." @@ -169,12 +174,12 @@ def compute_infections_from_rt_with_feedback( reductions in contact rate due to awareness of high incidence, et cetera. """ - feedback_scanner = new_double_scanner( - dists=( + feedback_scanner = new_double_convolve_scanner( + arrays_to_convolve=( reversed_infection_feedback_pmf, reversed_generation_interval_pmf, ), - transforms=(jnp.exp, lambda x: x), + transforms=(ExpTransform(), IdentityTransform()), ) latest, infs_and_R_adj = jax.lax.scan( f=feedback_scanner, diff --git a/model/src/test/test_convolve_scanners.py b/model/src/test/test_convolve_scanners.py new file mode 100644 index 00000000..068b165a --- /dev/null +++ b/model/src/test/test_convolve_scanners.py @@ -0,0 +1,54 @@ +""" +Unit tests for the iterative convolution +scanner function factories found in pyrenew.convolve +""" + +import jax +import jax.numpy as jnp +import numpy as np +import pyrenew.convolve as pc +from numpy.testing import assert_array_equal + + +def test_double_scanner_reduces_to_single(): + """ + Test that new_double_scanner() yields a function + that is equivalent to a single scanner if the first + scan is chosen appropriately + """ + inits = jnp.array([0.352, 5.2, -3]) + to_scan_a = jnp.array([0.5, 0.3, 0.2]) + + multipliers = jnp.array(np.random.normal(0, 0.5, size=500)) + + def transform_a(x: any): + """ + transformation associated with + array to_scan_a + + Parameters + ---------- + x: any + input value + + Returns + ------- + The result of 4 * x + 0.025, where x is the input + value + """ + return 4 * x + 0.025 + + scanner_a = pc.new_convolve_scanner(to_scan_a, transform_a) + + double_scanner_a = pc.new_double_convolve_scanner( + (jnp.array([523, 2, -0.5233]), to_scan_a), (lambda x: 1, transform_a) + ) + + _, result_a = jax.lax.scan(f=scanner_a, init=inits, xs=multipliers) + + _, result_a_double = jax.lax.scan( + f=double_scanner_a, init=inits, xs=(multipliers * 0.2352, multipliers) + ) + + assert_array_equal(result_a_double[1], jnp.ones_like(multipliers)) + assert_array_equal(result_a_double[0], result_a)