diff --git a/model/src/pyrenew/convolve.py b/model/src/pyrenew/convolve.py index 6c894c20..2dc869c9 100755 --- a/model/src/pyrenew/convolve.py +++ b/model/src/pyrenew/convolve.py @@ -10,12 +10,44 @@ jax.lax.scan. Factories generate functions that can be passed to scan with an appropriate array to scan. + +Notes +----- +TODO: Look into adding blocks for Functions and Examples in this +docstring. """ +from typing import Callable, Tuple + import jax.numpy as jnp +from jax.typing import ArrayLike + + +def new_convolve_scanner(discrete_dist_flipped: ArrayLike) -> Callable: + """ + Factory function to create a scanner function for convolving a discrete distribution + over a time series data subset. + + Parameters + ---------- + discrete_dist_flipped : ArrayLike + A 1D jax array representing the discrete distribution flipped for convolution. + Returns + ------- + Callable + A scanner function that can be used with jax.lax.scan for convolution. + This function takes a history subset and a multiplier, computes the dot product, + then updates and returns the new history subset and the convolution result. -def new_convolve_scanner(discrete_dist_flipped): - def _new_scanner(history_subset, multiplier): + Notes + ----- + TODO: Add Example. + TODO: Clarification on Returns description. + """ + + def _new_scanner( + history_subset: ArrayLike, multiplier: float + ) -> Tuple[ArrayLike, float]: new_val = multiplier * jnp.dot(discrete_dist_flipped, history_subset) latest = jnp.hstack([history_subset[1:], new_val]) return latest, new_val @@ -23,15 +55,43 @@ def _new_scanner(history_subset, multiplier): return _new_scanner -def new_double_scanner(dists, transforms): +def new_double_scanner( + dists: Tuple[ArrayLike, ArrayLike], transforms: Tuple[Callable, Callable] +) -> Callable: + """ + Factory function to create a scanner function that applies two sequential transformations + and convolutions using two discrete distributions. + + Parameters + ---------- + dists : Tuple[ArrayLike, ArrayLike] + A tuple of two 1D jax arrays, each representing a discrete distribution for the + two stages of convolution. + transforms : Tuple[Callable, Callable] + A tuple of two functions, each transforming the output of the dot product at each + convolution stage. + + Returns + ------- + Callable + A scanner function that applies two sequential convolutions and transformations. + It takes a history subset and a tuple of multipliers, computes the transformations + and dot products, and returns the updated history subset and a tuple of new values. + + Notes + ----- + TODO: Add Example + """ d1, d2 = dists t1, t2 = transforms - def _new_scanner(history_subset, multipliers): + def _new_scanner( + history_subset: jnp.ndarray, multipliers: Tuple[float, float] + ) -> (jnp.ndarray, Tuple[float, float]): m1, m2 = multipliers m_net1 = t1(m1 * jnp.dot(d1, history_subset)) new_val = t2(m2 * m_net1 * jnp.dot(d2, history_subset)) latest = jnp.hstack([history_subset[1:], new_val]) - return (latest, (new_val, m_net1)) + return latest, (new_val, m_net1) return _new_scanner diff --git a/model/src/pyrenew/distutil.py b/model/src/pyrenew/distutil.py index a718b56b..315b3405 100755 --- a/model/src/pyrenew/distutil.py +++ b/model/src/pyrenew/distutil.py @@ -9,15 +9,37 @@ such as discrete time-to-event distributions """ import jax.numpy as jnp +from jax.typing import ArrayLike def validate_discrete_dist_vector( - discrete_dist: jnp.ndarray, tol: float = 1e-20 -) -> bool: + discrete_dist: ArrayLike, tol: float = 1e-20 +) -> ArrayLike: """ Validate that a vector represents a discrete probability distribution to within a specified tolerance, raising a ValueError if not. + + Parameters + ---------- + discrete_dist : ArrayLike + An jax array containing non-negative values that + represent a discrete probability distribution. The values + must sum to 1 within the specified tolerance. + tol : float, optional + The tolerance within which the sum of the distribution must + be 1. Defaults to 1e-20. + + Returns + ------- + ArrayLike + The normalized distribution array if the input is valid. + + Raises + ------ + ValueError + If any value in discrete_dist is negative or if the sum of the + distribution does not equal 1 within the specified tolerance. """ discrete_dist = discrete_dist.flatten() if not jnp.all(discrete_dist >= 0): @@ -39,10 +61,20 @@ def validate_discrete_dist_vector( return discrete_dist / dist_norm -def reverse_discrete_dist_vector(dist): +def reverse_discrete_dist_vector(dist: ArrayLike) -> ArrayLike: """ Reverse a discrete distribution vector (useful for discrete time-to-event distributions). + + Parameters + ---------- + dist : ArrayLike + A discrete distribution vector (likely discrete time-to-event distribution) + + Returns + ------- + ArrayLike + A reversed (jnp.flip) discrete distribution vector """ return jnp.flip(dist) diff --git a/model/src/pyrenew/math.py b/model/src/pyrenew/math.py index 4a4d59de..91b35284 100755 --- a/model/src/pyrenew/math.py +++ b/model/src/pyrenew/math.py @@ -6,11 +6,13 @@ a given renewal process. """ + import jax.numpy as jnp +from jax.typing import ArrayLike from pyrenew.distutil import validate_discrete_dist_vector -def get_leslie_matrix(R, generation_interval_pmf): +def get_leslie_matrix(R: float, generation_interval_pmf: ArrayLike) -> float: """ Create the Leslie matrix corresponding to a basic @@ -28,7 +30,7 @@ def get_leslie_matrix(R, generation_interval_pmf): mass vector of the renewal process Returns - -------- + ------- The Leslie matrix for the renewal process, as a jax array. """ @@ -44,7 +46,9 @@ def get_leslie_matrix(R, generation_interval_pmf): return jnp.vstack([R * generation_interval_pmf, aging_matrix]) -def get_asymptotic_growth_rate_and_age_dist(R, generation_interval_pmf): +def get_asymptotic_growth_rate_and_age_dist( + R: float, generation_interval_pmf: ArrayLike +) -> tuple[float, ArrayLike]: """ Get the asymptotic per-timestep growth rate of the renewal process (the dominant @@ -52,6 +56,7 @@ def get_asymptotic_growth_rate_and_age_dist(R, generation_interval_pmf): associated stable age distribution (a normalized eigenvector associated to that eigenvalue). + Parameters ---------- R : float @@ -61,11 +66,17 @@ def get_asymptotic_growth_rate_and_age_dist(R, generation_interval_pmf): mass vector of the renewal process Returns - -------- - A tuple consisting of the asymptotic growth rate of - the process, as jax float, and the stable age distribution - of the process, as a jax array probability vector of the - same shape as the generation interval probability vector. + ------- + tuple[float, ArrayLike] + A tuple consisting of the asymptotic growth rate of + the process, as jax float, and the stable age distribution + of the process, as a jax array probability vector of the + same shape as the generation interval probability vector. + + Raises + ------ + ValueError + If an age distribution vector with non-zero imaginary part is produced. """ L = get_leslie_matrix(R, generation_interval_pmf) eigenvals, eigenvecs = jnp.linalg.eig(L) @@ -92,7 +103,9 @@ def get_asymptotic_growth_rate_and_age_dist(R, generation_interval_pmf): return d_val_real, d_vec_norm -def get_stable_age_distribution(R, generation_interval_pmf): +def get_stable_age_distribution( + R: float, generation_interval_pmf: ArrayLike +) -> ArrayLike: """ Get the stable age distribution for a renewal process with a given value of @@ -114,18 +127,21 @@ def get_stable_age_distribution(R, generation_interval_pmf): mass vector of the renewal process Returns - -------- - The stable age distribution for the - process, as a jax array probability vector of - the same shape as the generation interval - probability vector. + ------- + ArrayLike + The stable age distribution for the + process, as a jax array probability vector of + the same shape as the generation interval + probability vector. """ return get_asymptotic_growth_rate_and_age_dist(R, generation_interval_pmf)[ 1 ] -def get_asymptotic_growth_rate(R, generation_interval_pmf): +def get_asymptotic_growth_rate( + R: float, generation_interval_pmf: ArrayLike +) -> float: """ Get the asymptotic per timestep growth rate for a renewal process with a given value of @@ -145,9 +161,10 @@ def get_asymptotic_growth_rate(R, generation_interval_pmf): mass vector of the renewal process Returns - -------- - The asymptotic growth rate of the renewal process, - as a jax float. + ------- + float + The asymptotic growth rate of the renewal process, + as a jax float. """ return get_asymptotic_growth_rate_and_age_dist(R, generation_interval_pmf)[ 0 diff --git a/model/src/pyrenew/mcmcutils.py b/model/src/pyrenew/mcmcutils.py index 3b081958..733d7c75 100644 --- a/model/src/pyrenew/mcmcutils.py +++ b/model/src/pyrenew/mcmcutils.py @@ -10,7 +10,7 @@ def spread_draws( posteriors: dict, - variables_names: list, + variables_names: list[str] | list[tuple], ) -> pl.DataFrame: """Get nicely shaped draws from the posterior @@ -29,7 +29,8 @@ def spread_draws( Returns ------- - polars.DataFrame + pl.DataFrame + A dataframe of draw-indexed """ for i_var, v in enumerate(variables_names): diff --git a/model/src/pyrenew/metaclass.py b/model/src/pyrenew/metaclass.py index d00c5ace..5b51584d 100644 --- a/model/src/pyrenew/metaclass.py +++ b/model/src/pyrenew/metaclass.py @@ -17,19 +17,26 @@ def _assert_sample_and_rtype( ) -> None: """Return type-checking for RandomVariable's sample function - Objects passed as `RandomVariable` should (a) have a `sample()` method that + Objects passed as `RandomVariable` should (a) have a sample() method that (b) returns either a tuple or a named tuple. Parameters ---------- rp : RandomVariable Random variable to check. - skip_if_none: bool - When `True` it returns if `rp` is None. + skip_if_none: bool, optional + When `True` it returns if `rp` is None. Defaults to True Returns ------- None + + Raises + ------ + Exception + If rp is not a RandomVariable, does not have a sample function, or + does not return a tuple. Also occurs if rettype does not initialized + properly. """ # Addressing the None case @@ -101,7 +108,7 @@ def sample( ---------- **kwargs : dict, optional Additional keyword arguments passed through to internal `sample()` - calls, if any + calls, should there be any. Notes ----- @@ -117,6 +124,9 @@ def sample( @staticmethod @abstractmethod def validate(**kwargs) -> None: + """ + Validation of kwargs to be implemented in subclasses. + """ pass @@ -149,7 +159,7 @@ def sample( ---------- **kwargs : dict, optional Additional keyword arguments passed through to internal `sample()` - calls, if any + calls, should there be any. Notes ----- @@ -244,9 +254,34 @@ def print_summary( prob: float = 0.9, exclude_deterministic: bool = True, ) -> None: - """A wrapper of MCMC.print_summary""" + """ + A wrapper of MCMC.print_summary + + Parameters + ---------- + prob : float, optional + The acceptance probability of print_summary. Defaults to 0.9 + exclude_deterministic : bool, optional. + Whether to print deterministic variables in the summary. + Defaults to True. + + Returns + ------- + None + """ return self.mcmc.print_summary(prob, exclude_deterministic) def spread_draws(self, variables_names: list) -> pl.DataFrame: - """A wrapper of mcmcutils.spread_draws""" + """A wrapper of mcmcutils.spread_draws + + Parameters + ---------- + variables_names : list + A list of variable names to create a table of samples. + + Returns + ------- + pl.DataFrame + """ + return spread_draws(self.mcmc.get_samples(), variables_names) diff --git a/model/src/pyrenew/regression.py b/model/src/pyrenew/regression.py index 67772bc6..d3e6789e 100755 --- a/model/src/pyrenew/regression.py +++ b/model/src/pyrenew/regression.py @@ -78,11 +78,11 @@ def __init__( If `None`, use an identity transform. Default `None`. - intercept_suffix : str + intercept_suffix : str, optional Suffix for naming the intercept random variable in class to numpyro.sample(). Default `"_intercept"`. - coefficient_suffix : str + coefficient_suffix : str, optional Suffix for naming the regression coefficient random variables in calls to numpyro.sample(). Default `"_coefficients"`. @@ -99,12 +99,38 @@ class to numpyro.sample(). Default `"_intercept"`. self.coefficient_suffix = coefficient_suffix def predict(self, intercept, coefficients): + """ + Generates a transformed prediction w/ intercept, coefficients, and + fixed predictor values + + Parameters + ---------- + intercept : ArrayLike + Sampled numpyro distribution generated from intercept priors. + coefficients : ArrayLike + Sampled prediction coefficients distribution generated + from coefficients priors. + + Returns + ------- + ArrayLike + Array of transformed predictions. + """ transformed_prediction = ( intercept + self.fixed_predictor_values @ coefficients ) return self.transform.inverse(transformed_prediction) def sample(self): + """ + Sample generalized linear model + + Returns + ------- + dict + A dictionary containing transformed predictions, and + the intercept and coefficients sample distributions. + """ intercept = numpyro.sample( self.name + self.intercept_suffix, self.intercept_prior ) diff --git a/model/src/pyrenew/transform.py b/model/src/pyrenew/transform.py index 0701eb2f..bca9a792 100755 --- a/model/src/pyrenew/transform.py +++ b/model/src/pyrenew/transform.py @@ -7,6 +7,7 @@ import jax import jax.numpy as jnp +from jax.typing import ArrayLike class AbstractTransform(metaclass=ABCMeta): @@ -19,10 +20,16 @@ def __call__(self, x): @abstractmethod def transform(self, x): + """ + Transform generated predictions + """ pass @abstractmethod def inverse(self, x): + """ + Take the inverse of transformed predictions + """ pass @@ -35,10 +42,32 @@ class IdentityTransform(AbstractTransform): f^-1(x) = x """ - def transform(self, x): + def transform(self, x: any): + """ + Parameters + ---------- + x : any + Input, usually ArrayLike + + Returns + ------- + any + The same object that was inputted. + """ return x - def inverse(self, x): + def inverse(self, x: any): + """ + Parameters + ---------- + x : any + Input, usually ArrayLike + + Returns + ------- + any + The same object that was inputted. + """ return x @@ -51,10 +80,32 @@ class LogTransform(AbstractTransform): f^-1(x) = exp(x) """ - def transform(self, x): + def transform(self, x: ArrayLike): + """ + Parameters + ---------- + x : ArrayLike + Input, usually predictions array.. + + Returns + ------- + ArrayLike + Log-transformed input + """ return jnp.log(x) - def inverse(self, x): + def inverse(self, x: ArrayLike): + """ + Parameters + ---------- + x : ArrayLike + Input, usually log-scale predictions array. + + Returns + ------- + ArrayLike + Exponentiated input + """ return jnp.exp(x) @@ -68,10 +119,32 @@ class LogitTransform(AbstractTransform): f^-1(x) = 1 / (1 + exp(-x)) """ - def transform(self, x): + def transform(self, x: ArrayLike): + """ + Parameters + ---------- + x : ArrayLike + Input, usually predictions array. + + Returns + ------- + ArrayLike + Logit transformed input. + """ return jax.scipy.special.logit(x) - def inverse(self, x): + def inverse(self, x: ArrayLike): + """ + Parameters + ---------- + x : ArrayLike + Input, usually logit-transformed predictions array. + + Returns + ------- + ArrayLike + Inversed logit transformed input. + """ return jax.scipy.special.expit(x) @@ -97,8 +170,30 @@ def __init__(self, x_max: float): """ self.x_max = x_max - def transform(self, x): + def transform(self, x: ArrayLike): + """ + Parameters + ---------- + x : ArrayLike + Input, usually predictions array. + + Returns + ------- + ArrayLike + x_max scaled logit transformed input. + """ return jax.scipy.special.logit(x / self.x_max) - def inverse(self, x): + def inverse(self, x: ArrayLike): + """ + Parameters + ---------- + x : ArrayLike + Input, usually scaled logit predictions array. + + Returns + ------- + ArrayLike + Inverse of x_max scaled logit transformed input. + """ return self.x_max * jax.scipy.special.expit(x)