Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

First Pass Documentation (docstrings) For Top Level Pyrenew Files #89

Merged
merged 13 commits into from
Apr 22, 2024
70 changes: 65 additions & 5 deletions model/src/pyrenew/convolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,28 +10,88 @@
jax.lax.scan. Factories generate functions
that can be passed to scan with an
appropriate array to scan.

Notes
-----
TODO: Look into adding blocks for Functions and Examples in this
docstring.
gvegayon marked this conversation as resolved.
Show resolved Hide resolved
"""
from typing import Callable, Tuple

import jax.numpy as jnp
from jax.typing import ArrayLike


def new_convolve_scanner(discrete_dist_flipped: ArrayLike) -> Callable:
"""
Factory function to create a scanner function for convolving a discrete distribution
over a time series data subset.

Parameters
----------
discrete_dist_flipped : ArrayLike
A 1D jax array representing the discrete distribution flipped for convolution.

Returns
-------
Callable
A scanner function that can be used with jax.lax.scan for convolution.
This function takes a history subset and a multiplier, computes the dot product,
then updates and returns the new history subset and the convolution result.

def new_convolve_scanner(discrete_dist_flipped):
def _new_scanner(history_subset, multiplier):
Notes
-----
TODO: Add Example.
TODO: Clarification on Returns description.
gvegayon marked this conversation as resolved.
Show resolved Hide resolved
"""

def _new_scanner(
history_subset: ArrayLike, multiplier: float
) -> Tuple[ArrayLike, float]:
new_val = multiplier * jnp.dot(discrete_dist_flipped, history_subset)
latest = jnp.hstack([history_subset[1:], new_val])
return latest, new_val

return _new_scanner


def new_double_scanner(dists, transforms):
def new_double_scanner(
dists: Tuple[ArrayLike, ArrayLike], transforms: Tuple[Callable, Callable]
) -> Callable:
"""
Factory function to create a scanner function that applies two sequential transformations
and convolutions using two discrete distributions.

Parameters
----------
dists : Tuple[ArrayLike, ArrayLike]
A tuple of two 1D jax arrays, each representing a discrete distribution for the
two stages of convolution.
transforms : Tuple[Callable, Callable]
A tuple of two functions, each transforming the output of the dot product at each
convolution stage.

Returns
-------
Callable
A scanner function that applies two sequential convolutions and transformations.
It takes a history subset and a tuple of multipliers, computes the transformations
and dot products, and returns the updated history subset and a tuple of new values.

Notes
-----
TODO: Add Example
gvegayon marked this conversation as resolved.
Show resolved Hide resolved
"""
d1, d2 = dists
t1, t2 = transforms

def _new_scanner(history_subset, multipliers):
def _new_scanner(

Check warning on line 88 in model/src/pyrenew/convolve.py

View check run for this annotation

Codecov / codecov/patch

model/src/pyrenew/convolve.py#L88

Added line #L88 was not covered by tests
history_subset: jnp.ndarray, multipliers: Tuple[float, float]
) -> (jnp.ndarray, Tuple[float, float]):
m1, m2 = multipliers
m_net1 = t1(m1 * jnp.dot(d1, history_subset))
new_val = t2(m2 * m_net1 * jnp.dot(d2, history_subset))
latest = jnp.hstack([history_subset[1:], new_val])
return (latest, (new_val, m_net1))
return latest, (new_val, m_net1)

Check warning on line 95 in model/src/pyrenew/convolve.py

View check run for this annotation

Codecov / codecov/patch

model/src/pyrenew/convolve.py#L95

Added line #L95 was not covered by tests

return _new_scanner
38 changes: 35 additions & 3 deletions model/src/pyrenew/distutil.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,37 @@
such as discrete time-to-event distributions
"""
import jax.numpy as jnp
from jax.typing import ArrayLike


def validate_discrete_dist_vector(
discrete_dist: jnp.ndarray, tol: float = 1e-20
) -> bool:
discrete_dist: ArrayLike, tol: float = 1e-20
) -> ArrayLike:
"""
Validate that a vector represents a discrete
probability distribution to within a specified
tolerance, raising a ValueError if not.

Parameters
----------
discrete_dist : ArrayLike
An jax array containing non-negative values that
represent a discrete probability distribution. The values
must sum to 1 within the specified tolerance.
tol : float, optional
The tolerance within which the sum of the distribution must
be 1. Defaults to 1e-20.

Returns
-------
ArrayLike
The normalized distribution array if the input is valid.

Raises
------
ValueError
If any value in discrete_dist is negative or if the sum of the
distribution does not equal 1 within the specified tolerance.
"""
discrete_dist = discrete_dist.flatten()
if not jnp.all(discrete_dist >= 0):
Expand All @@ -39,10 +61,20 @@ def validate_discrete_dist_vector(
return discrete_dist / dist_norm


def reverse_discrete_dist_vector(dist):
def reverse_discrete_dist_vector(dist: ArrayLike) -> ArrayLike:
"""
Reverse a discrete distribution
vector (useful for discrete
time-to-event distributions).

Parameters
----------
dist : ArrayLike
A discrete distribution vector (likely discrete time-to-event distribution)

Returns
-------
ArrayLike
A reversed (jnp.flip) discrete distribution vector
"""
return jnp.flip(dist)
53 changes: 35 additions & 18 deletions model/src/pyrenew/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@
a given renewal process.
"""


import jax.numpy as jnp
from jax.typing import ArrayLike
from pyrenew.distutil import validate_discrete_dist_vector


def get_leslie_matrix(R, generation_interval_pmf):
def get_leslie_matrix(R: float, generation_interval_pmf: ArrayLike) -> float:
"""
Create the Leslie matrix
corresponding to a basic
Expand All @@ -28,7 +30,7 @@ def get_leslie_matrix(R, generation_interval_pmf):
mass vector of the renewal process

Returns
--------
-------
The Leslie matrix for the
renewal process, as a jax array.
"""
Expand All @@ -44,14 +46,17 @@ def get_leslie_matrix(R, generation_interval_pmf):
return jnp.vstack([R * generation_interval_pmf, aging_matrix])


def get_asymptotic_growth_rate_and_age_dist(R, generation_interval_pmf):
def get_asymptotic_growth_rate_and_age_dist(
R: float, generation_interval_pmf: ArrayLike
) -> tuple[float, ArrayLike]:
"""
Get the asymptotic per-timestep growth
rate of the renewal process (the dominant
eigenvalue of its Leslie matrix) and the
associated stable age distribution
(a normalized eigenvector associated to
that eigenvalue).

Parameters
----------
R : float
Expand All @@ -61,11 +66,17 @@ def get_asymptotic_growth_rate_and_age_dist(R, generation_interval_pmf):
mass vector of the renewal process

Returns
--------
A tuple consisting of the asymptotic growth rate of
the process, as jax float, and the stable age distribution
of the process, as a jax array probability vector of the
same shape as the generation interval probability vector.
-------
tuple[float, ArrayLike]
A tuple consisting of the asymptotic growth rate of
the process, as jax float, and the stable age distribution
of the process, as a jax array probability vector of the
same shape as the generation interval probability vector.

Raises
------
ValueError
If an age distribution vector with non-zero imaginary part is produced.
"""
L = get_leslie_matrix(R, generation_interval_pmf)
eigenvals, eigenvecs = jnp.linalg.eig(L)
Expand All @@ -92,7 +103,9 @@ def get_asymptotic_growth_rate_and_age_dist(R, generation_interval_pmf):
return d_val_real, d_vec_norm


def get_stable_age_distribution(R, generation_interval_pmf):
def get_stable_age_distribution(
R: float, generation_interval_pmf: ArrayLike
) -> ArrayLike:
"""
Get the stable age distribution for a
renewal process with a given value of
Expand All @@ -114,18 +127,21 @@ def get_stable_age_distribution(R, generation_interval_pmf):
mass vector of the renewal process

Returns
--------
The stable age distribution for the
process, as a jax array probability vector of
the same shape as the generation interval
probability vector.
-------
ArrayLike
The stable age distribution for the
process, as a jax array probability vector of
the same shape as the generation interval
probability vector.
"""
return get_asymptotic_growth_rate_and_age_dist(R, generation_interval_pmf)[
1
]


def get_asymptotic_growth_rate(R, generation_interval_pmf):
def get_asymptotic_growth_rate(
R: float, generation_interval_pmf: ArrayLike
) -> float:
"""
Get the asymptotic per timestep growth rate
for a renewal process with a given value of
Expand All @@ -145,9 +161,10 @@ def get_asymptotic_growth_rate(R, generation_interval_pmf):
mass vector of the renewal process

Returns
--------
The asymptotic growth rate of the renewal process,
as a jax float.
-------
float
The asymptotic growth rate of the renewal process,
as a jax float.
"""
return get_asymptotic_growth_rate_and_age_dist(R, generation_interval_pmf)[
0
Expand Down
5 changes: 3 additions & 2 deletions model/src/pyrenew/mcmcutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

def spread_draws(
posteriors: dict,
variables_names: list,
variables_names: list[str] | list[tuple],
) -> pl.DataFrame:
"""Get nicely shaped draws from the posterior

Expand All @@ -29,7 +29,8 @@ def spread_draws(

Returns
-------
polars.DataFrame
pl.DataFrame
A dataframe of draw-indexed
"""

for i_var, v in enumerate(variables_names):
Expand Down
49 changes: 42 additions & 7 deletions model/src/pyrenew/metaclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,26 @@ def _assert_sample_and_rtype(
) -> None:
"""Return type-checking for RandomVariable's sample function

Objects passed as `RandomVariable` should (a) have a `sample()` method that
Objects passed as `RandomVariable` should (a) have a sample() method that
(b) returns either a tuple or a named tuple.

Parameters
----------
rp : RandomVariable
Random variable to check.
skip_if_none: bool
When `True` it returns if `rp` is None.
skip_if_none: bool, optional
When `True` it returns if `rp` is None. Defaults to True

Returns
-------
None

Raises
------
Exception
If rp is not a RandomVariable, does not have a sample function, or
does not return a tuple. Also occurs if rettype does not initialized
properly.
"""

# Addressing the None case
Expand Down Expand Up @@ -101,7 +108,7 @@ def sample(
----------
**kwargs : dict, optional
Additional keyword arguments passed through to internal `sample()`
calls, if any
calls, should there be any.

Notes
-----
Expand All @@ -117,6 +124,9 @@ def sample(
@staticmethod
@abstractmethod
def validate(**kwargs) -> None:
"""
Validation of kwargs to be implemented in subclasses.
"""
pass


Expand Down Expand Up @@ -149,7 +159,7 @@ def sample(
----------
**kwargs : dict, optional
Additional keyword arguments passed through to internal `sample()`
calls, if any
calls, should there be any.

Notes
-----
Expand Down Expand Up @@ -244,9 +254,34 @@ def print_summary(
prob: float = 0.9,
exclude_deterministic: bool = True,
) -> None:
"""A wrapper of MCMC.print_summary"""
"""
A wrapper of MCMC.print_summary

Parameters
----------
prob : float, optional
The acceptance probability of print_summary. Defaults to 0.9
exclude_deterministic : bool, optional.
Whether to print deterministic variables in the summary.
Defaults to True.

Returns
-------
None
"""
return self.mcmc.print_summary(prob, exclude_deterministic)

def spread_draws(self, variables_names: list) -> pl.DataFrame:
"""A wrapper of mcmcutils.spread_draws"""
"""A wrapper of mcmcutils.spread_draws

Parameters
----------
variables_names : list
A list of variable names to create a table of samples.

Returns
-------
pl.DataFrame
"""

return spread_draws(self.mcmc.get_samples(), variables_names)
Loading