Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor convolve scanner factory functions #161

Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
69f59b6
Shorten line lengths in convolve.py docstrings
dylanhmorris Jun 5, 2024
73f18ce
Add transform to new_convolve_scanner, improve documentation
dylanhmorris Jun 5, 2024
e77361a
Default transforms for new_double_scanner
dylanhmorris Jun 5, 2024
bdc31ea
Correct new_double_scanner docstring
dylanhmorris Jun 5, 2024
6a94335
Merge branch 'main' into 147-add-transforms-to-new_convolve_scanner-m…
dylanhmorris Jun 5, 2024
9a7c825
Add unit test that double scanner reduces to single, use pyrenew.tran…
dylanhmorris Jun 5, 2024
5f61459
Rename new_double_scanner to new_double_convolve_scanner for clarity
dylanhmorris Jun 6, 2024
8783086
precommit autofixes
dylanhmorris Jun 6, 2024
bf59672
Satisfy numpydoc precommit
dylanhmorris Jun 6, 2024
78204db
convert new_convolve_scanner docstring to raw and remove escapes for …
dylanhmorris Jun 6, 2024
b3e8d98
fix math rendering in documentation
damonbayer Jun 7, 2024
e93f414
Rename arguments to convolve scanners, improve docs, add arraylike va…
dylanhmorris Jun 10, 2024
671fc5a
Merge main into 147
dylanhmorris Jun 10, 2024
547c2b4
Autoformat changed source files
dylanhmorris Jun 10, 2024
df3f29d
Minor docstring/sphinx tweaks
dylanhmorris Jun 10, 2024
b92cd2d
Fix displaymath rendering errors and improve visual
dylanhmorris Jun 10, 2024
8c4a098
Add math notes to new_double_convolve_scanner
dylanhmorris Jun 10, 2024
3e5bf40
Merge branch 'main' into 147-add-transforms-to-new_convolve_scanner-m…
dylanhmorris Jun 11, 2024
2ed2c68
Remove defaults for convolve scanner factories
dylanhmorris Jun 11, 2024
c9e2376
Remove arrayutils call
dylanhmorris Jun 11, 2024
879be18
Remove validate arraylike
dylanhmorris Jun 11, 2024
9558dda
merge fix branch into PR branch
dylanhmorris Jun 11, 2024
1b1d9cf
Revert bug-introducting change in arrayutils
dylanhmorris Jun 11, 2024
b5178a6
Autoformat files
dylanhmorris Jun 11, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 68 additions & 20 deletions model/src/pyrenew/convolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,62 +17,110 @@

import jax.numpy as jnp
from jax.typing import ArrayLike
from pyrenew.transformation import IdentityTransform
dylanhmorris marked this conversation as resolved.
Show resolved Hide resolved


def new_convolve_scanner(discrete_dist_flipped: ArrayLike) -> Callable:
"""
Factory function to create a scanner function for convolving a discrete distribution
over a time series data subset.
def new_convolve_scanner(
discrete_dist_flipped: ArrayLike, transform: Callable = None
) -> Callable:
r"""
Factory function to create a "scanner" function
that can be used with :py:func:`jax.lax.scan` to
construct an array via backward-looking iterative
convolution.

Parameters
----------
discrete_dist_flipped : ArrayLike
A 1D jax array representing the discrete distribution flipped for convolution.
A 1D jax array representing the discrete
distribution flipped for convolution.

transform : Callable
A transformation to apply to the result
of the dot product and multiplication.
If None, use the identity transformation.
Default None.

Returns
-------
Callable
A scanner function that can be used with jax.lax.scan for convolution.
This function takes a history subset and a multiplier, computes the dot product,
then updates and returns the new history subset and the convolution result.
A scanner function that can be used with
jax.lax.scan for convolution.
This function takes a history subset and
a multiplier, computes the dot product,
then updates and returns the new history
subset and the convolution result.

Notes
-----
The following iterative operation is found often
in renewal processes:

.. math::
X(t) = f\left(m(t) * \left[X(t - n),
X(t - n + 1), ... X(t - 1)\right] \dot \vec{d} \right)

Where `math`:\vec{d} is a vector of length `math`:n,
`math`:m(t) is a scalar for each value of time `math`:t,
and `math`:f is a scalar-valued function.

Given `math`:\vec{d}, and optionally `math`:f,
this factory function returns a new function that
peforms one step of this process while scanning along
an array of multipliers (i.e. an array
giving the values of `math`:m(t)) using :py:func:jax.lax.scan.
dylanhmorris marked this conversation as resolved.
Show resolved Hide resolved
"""
if transform is None:
transform = IdentityTransform()

def _new_scanner(
history_subset: ArrayLike, multiplier: float
) -> tuple[ArrayLike, float]: # numpydoc ignore=GL08
new_val = multiplier * jnp.dot(discrete_dist_flipped, history_subset)
new_val = transform(
multiplier * jnp.dot(discrete_dist_flipped, history_subset)
)
latest = jnp.hstack([history_subset[1:], new_val])
return latest, new_val

return _new_scanner


def new_double_scanner(
def new_double_convolve_scanner(
dists: tuple[ArrayLike, ArrayLike],
transforms: tuple[Callable, Callable],
transforms: tuple[Callable, Callable] = (None, None),
) -> Callable:
"""
Factory function to create a scanner function that applies two sequential transformations
and convolutions using two discrete distributions.
Factory function to create a scanner function
that iteratively constructs arrays by applying
the dot-product/multiply/transform operation
twice per history subset, with the first yielding
operation yielding an additional scalar multiplier
for the second.

Parameters
----------
dists : tuple[ArrayLike, ArrayLike]
A tuple of two 1D jax arrays, each representing a discrete distribution for the
A tuple of two 1D jax arrays, each representing a
discrete distribution for the
two stages of convolution.
transforms : tuple[Callable, Callable]
A tuple of two functions, each transforming the output of the dot product at each
convolution stage.
A tuple of two functions, each transforming the
output of the dot product at each
convolution stage. If either is None,
the identity transformation will be used
at that step. Default (None, None)

Returns
-------
Callable
A scanner function that applies two sequential convolutions and transformations.
It takes a history subset and a tuple of multipliers, computes the transformations
and dot products, and returns the updated history subset and a tuple of new values.
A scanner function that applies two sets of
convolution, multiply, and transform operations
in sequence to construct a new array by scanning
along a pair of input arrays that are equal in
length to each other.
"""
dylanhmorris marked this conversation as resolved.
Show resolved Hide resolved
d1, d2 = dists
t1, t2 = transforms
t1, t2 = [x if x is not None else IdentityTransform() for x in transforms]

def _new_scanner(
history_subset: ArrayLike,
Expand Down
7 changes: 4 additions & 3 deletions model/src/pyrenew/latent/infection_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
import jax
import jax.numpy as jnp
from jax.typing import ArrayLike
from pyrenew.convolve import new_convolve_scanner, new_double_scanner
from pyrenew.convolve import new_convolve_scanner, new_double_convolve_scanner
from pyrenew.transformation import ExpTransform, IdentityTransform
dylanhmorris marked this conversation as resolved.
Show resolved Hide resolved


def compute_infections_from_rt(
Expand Down Expand Up @@ -169,12 +170,12 @@ def compute_infections_from_rt_with_feedback(
reductions in contact rate due to awareness of high incidence,
et cetera.
"""
feedback_scanner = new_double_scanner(
feedback_scanner = new_double_convolve_scanner(
dists=(
reversed_infection_feedback_pmf,
reversed_generation_interval_pmf,
),
transforms=(jnp.exp, lambda x: x),
transforms=(ExpTransform(), IdentityTransform()),
dylanhmorris marked this conversation as resolved.
Show resolved Hide resolved
)
latest, infs_and_R_adj = jax.lax.scan(
f=feedback_scanner,
Expand Down
54 changes: 54 additions & 0 deletions model/src/test/test_convolve_scanners.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
"""
Unit tests for the iterative convolution
scanner function factories found in pyrenew.convolve
"""

import jax
import jax.numpy as jnp
import numpy as np
import pyrenew.convolve as pc
from numpy.testing import assert_array_equal


def test_double_scanner_reduces_to_single():
"""
Test that new_double_scanner() yields a function
that is equivalent to a single scanner if the first
scan is chosen appropriately
"""
inits = jnp.array([0.352, 5.2, -3])
to_scan_a = jnp.array([0.5, 0.3, 0.2])

multipliers = jnp.array(np.random.normal(0, 0.5, size=500))

def transform_a(x: any):
"""
transformation associated with
array to_scan_a

Parameters
----------
x: any
input value

Returns
-------
The result of 4 * x + 0.025, where x is the input
value
"""
return 4 * x + 0.025

scanner_a = pc.new_convolve_scanner(to_scan_a, transform_a)

double_scanner_a = pc.new_double_convolve_scanner(
(jnp.array([523, 2, -0.5233]), to_scan_a), (lambda x: 1, transform_a)
)

_, result_a = jax.lax.scan(f=scanner_a, init=inits, xs=multipliers)

_, result_a_double = jax.lax.scan(
f=double_scanner_a, init=inits, xs=(multipliers * 0.2352, multipliers)
)

assert_array_equal(result_a_double[1], jnp.ones_like(multipliers))
assert_array_equal(result_a_double[0], result_a)
Loading