From 087467ef625f33d94ad39ca0b04a38bf32d9218b Mon Sep 17 00:00:00 2001 From: "Dylan H. Morris" Date: Fri, 16 Aug 2024 15:52:42 -0400 Subject: [PATCH 1/4] Split distributional RVs into static and dynamic --- model/src/pyrenew/metaclass.py | 198 ++++++++++++++++++++++++++++++--- 1 file changed, 182 insertions(+), 16 deletions(-) diff --git a/model/src/pyrenew/metaclass.py b/model/src/pyrenew/metaclass.py index b4ac6529..6dd441fe 100644 --- a/model/src/pyrenew/metaclass.py +++ b/model/src/pyrenew/metaclass.py @@ -5,7 +5,7 @@ """ from abc import ABCMeta, abstractmethod -from typing import NamedTuple, get_type_hints +from typing import Callable, NamedTuple, get_type_hints import jax import jax.numpy as jnp @@ -13,6 +13,7 @@ import matplotlib.pyplot as plt import numpy as np import numpyro +import numpyro.distributions as dist import polars as pl from jax.typing import ArrayLike from numpyro.infer import MCMC, NUTS, Predictive @@ -126,7 +127,7 @@ def _assert_sample_and_rtype( class SampledValue(NamedTuple): """ - A container for a sampled value from a RandomVariable. + A container for a value sampled from a RandomVariable. Attributes ---------- @@ -135,7 +136,8 @@ class SampledValue(NamedTuple): t_start : int, optional The start time of the value. t_unit : int, optional - The unit of time relative to the model's fundamental (smallest) time unit. + The unit of time relative to the model's fundamental + (smallest) time unit. """ value: ArrayLike | None = None @@ -274,16 +276,127 @@ def __call__(self, **kwargs): return self.sample(**kwargs) -class DistributionalRV(RandomVariable): +class DynamicDistributionalRV(RandomVariable): """ Wrapper class for random variables that sample - from a single :class:`numpyro.distributions.Distribution`. + from a single :class:`numpyro.distributions.Distribution` + that is parameterized / instantiated at `sample()` time + (rather than at RandomVariable instantiation time). """ def __init__( self, name: str, - dist: numpyro.distributions.Distribution, + distribution_constructor: Callable, + reparam: Reparam = None, + ) -> None: + """ + Default constructor for DynamicDistributionalRV. + + Parameters + ---------- + name : str + Name of the random variable. + distribution_constructor : Callable + Callable that returns a concrete parametrized + numpyro.Distributions.distribution instance. + reparam : numpyro.infer.reparam.Reparam + If not None, reparameterize sampling + from the distribution according to the + given numpyro reparameterizer + + Returns + ------- + None + """ + + self.name = name + self.validate(distribution_constructor) + self.distribution_constructor = distribution_constructor + if reparam is not None: + self.reparam_dict = {self.name: reparam} + else: + self.reparam_dict = {} + + return None + + @staticmethod + def validate(distribution_constructor: any) -> None: + """ + Confirm that the distribution_constructor is + callable. + + Parameters + ---------- + distribution_constructor : any + Putative distribution_constructor to validate. + + Returns + ------- + None or raises a ValueError + """ + if not callable(distribution_constructor): + raise ValueError( + "To instantiate a DynamicDistributionalRV, ", + "one must provide a Callable that returns a " + "numpyro.distributions.Distribution as the " + "distribution_constructor argument. " + f"Got {type(distribution_constructor)}, which " + "does not appear to be callable", + ) + return None + + def sample( + self, + *args, + obs: ArrayLike = None, + **kwargs, + ) -> tuple: + """ + Sample from the distributional rv. + + Parameters + ---------- + *args : + Positional arguments passed to self.distribution_constructor + obs : ArrayLike, optional + Observations passed as the `obs` argument to + :fun:`numpyro.sample()`. Default `None`. + **kwargs : dict, optional + Keyword arguments passed to self.distribution_constructor + + Returns + ------- + SampledValue + Containing a sample from the distribution. + """ + with numpyro.handlers.reparam(config=self.reparam_dict): + sample = numpyro.sample( + name=self.name, + fn=self.distribution_constructor(*args, **kwargs), + obs=obs, + ) + return ( + SampledValue( + jnp.atleast_1d(sample), + t_start=self.t_start, + t_unit=self.t_unit, + ), + ) + + +class StaticDistributionalRV(RandomVariable): + """ + Wrapper class for random variables that sample + from a single :class:`numpyro.distributions.Distribution` + that is parameterized / instantiated at RandomVariable + instantiation time (rather than at `sample()`-ing time). + """ + + def __init__( + self, + name: str, + distribution: numpyro.distributions.Distribution, reparam: Reparam = None, ) -> None: """ @@ -293,7 +406,7 @@ def __init__( ---------- name : str Name of the random variable. - dist : numpyro.distributions.Distribution + distribution : numpyro.distributions.Distribution Distribution of the random variable. reparam : numpyro.infer.reparam.Reparam If not None, reparameterize sampling @@ -306,8 +419,8 @@ def __init__( """ self.name = name - self.validate(dist) - self.dist = dist + self.validate(distribution) + self.distribution = distribution if reparam is not None: self.reparam_dict = {self.name: reparam} else: @@ -316,14 +429,15 @@ def __init__( return None @staticmethod - def validate(dist: any) -> None: + def validate(distribution: any) -> None: """ Validation of the distribution to be implemented in subclasses. """ - if not isinstance(dist, numpyro.distributions.Distribution): + if not isinstance(distribution, numpyro.distributions.Distribution): raise ValueError( - "dist should be an instance of " - f"numpyro.distributions.Distribution, got {dist}" + "distribution should be an instance of " + "numpyro.distributions.Distribution, got " + "{type(distribution)}" ) return None @@ -347,13 +461,13 @@ def sample( Returns ------- - tuple - Containing the sampled from the distribution. + SampledValue + Containing a sample from the distribution. """ with numpyro.handlers.reparam(config=self.reparam_dict): sample = numpyro.sample( name=self.name, - fn=self.dist, + fn=self.distribution, obs=obs, ) return ( @@ -365,6 +479,58 @@ def sample( ) +def DistributionalRV( + name: str, + distribution: numpyro.distributions.Distribution | Callable, + reparam: Reparam = None, +) -> RandomVariable: + """ + Factory function to generate Distributional RandomVariables, + either static or dynamic. + + Parameters + ---------- + name : str + Name of the random variable. + + distribution: numpyro.distributions.Distribution | Callable + Either numpyro.distributions.Distribution instance + given the static distribution of the random variable or + a callable that returns a parameterized + numpyro.distributions.Distribution when called, which + allows for dynamically-parameterized DistributionalRVs, + e.g. a Normal distribution with an inferred location and + scale. + + reparam : numpyro.infer.reparam.Reparam + If not None, reparameterize sampling + from the distribution according to the + given numpyro reparameterizer + + Returns + ------- + DynamicDistributionalRV | StaticDistributionalRV or + raises a ValueError if a distribution cannot be constructed. + """ + if isinstance(distribution, dist.Distribution): + return StaticDistributionalRV( + name=name, distribution=distribution, reparam=reparam + ) + elif callable(distribution): + return DynamicDistributionalRV( + name=name, distribution_constructor=distribution, reparam=reparam + ) + else: + raise ValueError( + "distribution argument to DistributionalRV " + "must be either a numpyro.distributions.Distribution " + "(for instantiating a static DistributionalRV) " + "or a callable that returns a " + "numpyro.distributions.Distribution (for " + "a dynamic DistributionalRV" + ) + + class Model(metaclass=ABCMeta): """Abstract base class for models""" From 099450ab4877846be5bde2e452684df9ad66db0a Mon Sep 17 00:00:00 2001 From: "Dylan H. Morris" Date: Fri, 16 Aug 2024 16:01:17 -0400 Subject: [PATCH 2/4] Update DistributionalRV kwarg dist => distribution in all tests --- .../src/test/test_assert_sample_and_rtype.py | 2 +- model/src/test/test_assert_type.py | 2 +- model/src/test/test_forecast.py | 2 +- .../test/test_infection_seeding_process.py | 4 ++-- model/src/test/test_latent_admissions.py | 2 +- model/src/test/test_model_basic_renewal.py | 12 +++++----- model/src/test/test_model_hosp_admissions.py | 24 +++++++++---------- model/src/test/test_predictive.py | 2 +- model/src/test/test_random_key.py | 2 +- model/src/test/test_random_walk.py | 6 ++--- model/src/test/test_transformed_rv_class.py | 8 +++++-- model/src/test/utils.py | 4 ++-- 12 files changed, 37 insertions(+), 33 deletions(-) diff --git a/model/src/test/test_assert_sample_and_rtype.py b/model/src/test/test_assert_sample_and_rtype.py index 6fd2f774..fac5a41a 100644 --- a/model/src/test/test_assert_sample_and_rtype.py +++ b/model/src/test/test_assert_sample_and_rtype.py @@ -92,7 +92,7 @@ def test_input_rv(): # numpydoc ignore=GL08 valid_rv = [ NullObservation(), DeterministicVariable(name="rv1", value=jnp.array([1, 2, 3, 4])), - DistributionalRV(name="rv2", dist=dist.Normal(0, 1)), + DistributionalRV(name="rv2", distribution=dist.Normal(0, 1)), ] not_rv = jnp.array([1]) diff --git a/model/src/test/test_assert_type.py b/model/src/test/test_assert_type.py index ea66ddff..40385868 100644 --- a/model/src/test/test_assert_type.py +++ b/model/src/test/test_assert_type.py @@ -14,7 +14,7 @@ def test_valid_assertion_types(): 5, "Hello", (1,), - DistributionalRV(name="rv", dist=dist.Beta(1, 1)), + DistributionalRV(name="rv", distribution=dist.Beta(1, 1)), ] arg_names = ["input_int", "input_string", "input_tuple", "input_rv"] input_types = [int, str, tuple, RandomVariable] diff --git a/model/src/test/test_forecast.py b/model/src/test/test_forecast.py index 3feaf373..be9b53aa 100644 --- a/model/src/test/test_forecast.py +++ b/model/src/test/test_forecast.py @@ -24,7 +24,7 @@ def test_forecast(): gen_int = DeterministicPMF(name="gen_int", value=pmf_array) I0 = InfectionInitializationProcess( "I0_initialization", - DistributionalRV(name="I0", dist=dist.LogNormal(0, 1)), + DistributionalRV(name="I0", distribution=dist.LogNormal(0, 1)), InitializeInfectionsZeroPad(n_timepoints=gen_int.size()), t_unit=1, ) diff --git a/model/src/test/test_infection_seeding_process.py b/model/src/test/test_infection_seeding_process.py index b7fa6ef7..84b21428 100644 --- a/model/src/test/test_infection_seeding_process.py +++ b/model/src/test/test_infection_seeding_process.py @@ -19,14 +19,14 @@ def test_infection_initialization_process(): zero_pad_model = InfectionInitializationProcess( "zero_pad_model", - DistributionalRV(name="I0", dist=dist.LogNormal(0, 1)), + DistributionalRV(name="I0", distribution=dist.LogNormal(0, 1)), InitializeInfectionsZeroPad(n_timepoints), t_unit=1, ) exp_model = InfectionInitializationProcess( "exp_model", - DistributionalRV(name="I0", dist=dist.LogNormal(0, 1)), + DistributionalRV(name="I0", distribution=dist.LogNormal(0, 1)), InitializeInfectionsExponentialGrowth( n_timepoints, DeterministicVariable(name="rate", value=0.5) ), diff --git a/model/src/test/test_latent_admissions.py b/model/src/test/test_latent_admissions.py index 7648d70f..36c90161 100644 --- a/model/src/test/test_latent_admissions.py +++ b/model/src/test/test_latent_admissions.py @@ -63,7 +63,7 @@ def test_admissions_sample(): hosp1 = HospitalAdmissions( infection_to_admission_interval_rv=inf_hosp, infect_hosp_rate_rv=DistributionalRV( - name="IHR", dist=dist.LogNormal(jnp.log(0.05), 0.05) + name="IHR", distribution=dist.LogNormal(jnp.log(0.05), 0.05) ), ) diff --git a/model/src/test/test_model_basic_renewal.py b/model/src/test/test_model_basic_renewal.py index 9b5bdf46..f6230d6b 100644 --- a/model/src/test/test_model_basic_renewal.py +++ b/model/src/test/test_model_basic_renewal.py @@ -33,7 +33,7 @@ def test_model_basicrenewal_no_timepoints_or_observations(): name="gen_int", value=jnp.array([0.25, 0.25, 0.25, 0.25]) ) - I0 = DistributionalRV(name="I0", dist=dist.LogNormal(0, 1)) + I0 = DistributionalRV(name="I0", distribution=dist.LogNormal(0, 1)) latent_infections = Infections() @@ -64,7 +64,7 @@ def test_model_basicrenewal_both_timepoints_and_observations(): value=jnp.array([0.25, 0.25, 0.25, 0.25]), ) - I0 = DistributionalRV(name="I0", dist=dist.LogNormal(0, 1)) + I0 = DistributionalRV(name="I0", distribution=dist.LogNormal(0, 1)) latent_infections = Infections() @@ -100,11 +100,11 @@ def test_model_basicrenewal_no_obs_model(): ) with pytest.raises(ValueError): - I0 = DistributionalRV(name="I0", dist=1) + I0 = DistributionalRV(name="I0", distribution=1) I0 = InfectionInitializationProcess( "I0_initialization", - DistributionalRV(name="I0", dist=dist.LogNormal(0, 1)), + DistributionalRV(name="I0", distribution=dist.LogNormal(0, 1)), InitializeInfectionsZeroPad(n_timepoints=gen_int.size()), t_unit=1, ) @@ -175,7 +175,7 @@ def test_model_basicrenewal_with_obs_model(): I0 = InfectionInitializationProcess( "I0_initialization", - DistributionalRV(name="I0", dist=dist.LogNormal(0, 1)), + DistributionalRV(name="I0", distribution=dist.LogNormal(0, 1)), InitializeInfectionsZeroPad(n_timepoints=gen_int.size()), t_unit=1, ) @@ -224,7 +224,7 @@ def test_model_basicrenewal_padding() -> None: # numpydoc ignore=GL08 I0 = InfectionInitializationProcess( "I0_initialization", - DistributionalRV(name="I0", dist=dist.LogNormal(0, 1)), + DistributionalRV(name="I0", distribution=dist.LogNormal(0, 1)), InitializeInfectionsZeroPad(n_timepoints=gen_int.size()), t_unit=1, ) diff --git a/model/src/test/test_model_hosp_admissions.py b/model/src/test/test_model_hosp_admissions.py index 7e7e129b..f0034fb2 100644 --- a/model/src/test/test_model_hosp_admissions.py +++ b/model/src/test/test_model_hosp_admissions.py @@ -57,7 +57,7 @@ def test_model_hosp_no_timepoints_or_observations(): name="gen_int", value=jnp.array([0.25, 0.25, 0.25, 0.25]) ) - I0 = DistributionalRV(name="I0", dist=dist.LogNormal(0, 1)) + I0 = DistributionalRV(name="I0", distribution=dist.LogNormal(0, 1)) latent_infections = Infections() Rt_process = simple_rt() @@ -93,7 +93,7 @@ def test_model_hosp_no_timepoints_or_observations(): latent_admissions = HospitalAdmissions( infection_to_admission_interval_rv=inf_hosp, infect_hosp_rate_rv=DistributionalRV( - name="IHR", dist=dist.LogNormal(jnp.log(0.05), 0.05) + name="IHR", distribution=dist.LogNormal(jnp.log(0.05), 0.05) ), ) @@ -122,7 +122,7 @@ def test_model_hosp_both_timepoints_and_observations(): value=jnp.array([0.25, 0.25, 0.25, 0.25]), ) - I0 = DistributionalRV(name="I0", dist=dist.LogNormal(0, 1)) + I0 = DistributionalRV(name="I0", distribution=dist.LogNormal(0, 1)) latent_infections = Infections() Rt_process = simple_rt() @@ -158,7 +158,7 @@ def test_model_hosp_both_timepoints_and_observations(): latent_admissions = HospitalAdmissions( infection_to_admission_interval_rv=inf_hosp, infect_hosp_rate_rv=DistributionalRV( - name="IHR", dist=dist.LogNormal(jnp.log(0.05), 0.05) + name="IHR", distribution=dist.LogNormal(jnp.log(0.05), 0.05) ), ) @@ -192,7 +192,7 @@ def test_model_hosp_no_obs_model(): I0 = InfectionInitializationProcess( "I0_initialization", - DistributionalRV(name="I0", dist=dist.LogNormal(0, 1)), + DistributionalRV(name="I0", distribution=dist.LogNormal(0, 1)), InitializeInfectionsZeroPad(n_timepoints=gen_int.size()), t_unit=1, ) @@ -230,7 +230,7 @@ def test_model_hosp_no_obs_model(): infection_to_admission_interval_rv=inf_hosp, infect_hosp_rate_rv=DistributionalRV( name="IHR", - dist=dist.LogNormal(jnp.log(0.05), 0.05), + distribution=dist.LogNormal(jnp.log(0.05), 0.05), ), ) @@ -302,7 +302,7 @@ def test_model_hosp_with_obs_model(): I0 = InfectionInitializationProcess( "I0_initialization", - DistributionalRV(name="I0", dist=dist.LogNormal(0, 1)), + DistributionalRV(name="I0", distribution=dist.LogNormal(0, 1)), InitializeInfectionsZeroPad(n_timepoints=gen_int.size()), t_unit=1, ) @@ -341,7 +341,7 @@ def test_model_hosp_with_obs_model(): infection_to_admission_interval_rv=inf_hosp, infect_hosp_rate_rv=DistributionalRV( name="IHR", - dist=dist.LogNormal(jnp.log(0.05), 0.05), + distribution=dist.LogNormal(jnp.log(0.05), 0.05), ), ) @@ -389,7 +389,7 @@ def test_model_hosp_with_obs_model_weekday_phosp_2(): I0 = InfectionInitializationProcess( "I0_initialization", - DistributionalRV(name="I0", dist=dist.LogNormal(0, 1)), + DistributionalRV(name="I0", distribution=dist.LogNormal(0, 1)), InitializeInfectionsZeroPad(n_timepoints=gen_int.size()), t_unit=1, ) @@ -438,7 +438,7 @@ def test_model_hosp_with_obs_model_weekday_phosp_2(): day_of_week_effect_rv=weekday, hosp_report_prob_rv=hosp_report_prob_dist, infect_hosp_rate_rv=DistributionalRV( - name="IHR", dist=dist.LogNormal(jnp.log(0.05), 0.05) + name="IHR", distribution=dist.LogNormal(jnp.log(0.05), 0.05) ), ) @@ -488,7 +488,7 @@ def test_model_hosp_with_obs_model_weekday_phosp(): I0 = InfectionInitializationProcess( "I0_initialization", - DistributionalRV(name="I0", dist=dist.LogNormal(0, 1)), + DistributionalRV(name="I0", distribution=dist.LogNormal(0, 1)), InitializeInfectionsZeroPad(n_timepoints=gen_int.size()), t_unit=1, ) @@ -549,7 +549,7 @@ def test_model_hosp_with_obs_model_weekday_phosp(): hosp_report_prob_rv=hosp_report_prob_dist, infect_hosp_rate_rv=DistributionalRV( name="IHR", - dist=dist.LogNormal(jnp.log(0.05), 0.05), + distribution=dist.LogNormal(jnp.log(0.05), 0.05), ), ) diff --git a/model/src/test/test_predictive.py b/model/src/test/test_predictive.py index c848ccda..b36ed658 100644 --- a/model/src/test/test_predictive.py +++ b/model/src/test/test_predictive.py @@ -23,7 +23,7 @@ gen_int = DeterministicPMF(name="gen_int", value=pmf_array) I0 = InfectionInitializationProcess( "I0_initialization", - DistributionalRV(name="I0", dist=dist.LogNormal(0, 1)), + DistributionalRV(name="I0", distribution=dist.LogNormal(0, 1)), InitializeInfectionsZeroPad(n_timepoints=gen_int.size()), t_unit=1, ) diff --git a/model/src/test/test_random_key.py b/model/src/test/test_random_key.py index 247555ca..cd49504f 100644 --- a/model/src/test/test_random_key.py +++ b/model/src/test/test_random_key.py @@ -28,7 +28,7 @@ def create_test_model(): # numpydoc ignore=GL08 gen_int = DeterministicPMF(name="gen_int", value=pmf_array) I0 = InfectionInitializationProcess( "I0_initialization", - DistributionalRV(name="I0", dist=dist.LogNormal(0, 1)), + DistributionalRV(name="I0", distribution=dist.LogNormal(0, 1)), InitializeInfectionsZeroPad(n_timepoints=gen_int.size()), t_unit=1, ) diff --git a/model/src/test/test_random_walk.py b/model/src/test/test_random_walk.py index 242e0400..55c9e5f7 100755 --- a/model/src/test/test_random_walk.py +++ b/model/src/test/test_random_walk.py @@ -16,13 +16,13 @@ def test_rw_can_be_sampled(): """ init_rv_rand = DistributionalRV( name="init_rv_rand", - dist=dist.Normal(1, 0.5), + distribution=dist.Normal(1, 0.5), ) init_rv_fixed = DeterministicVariable(name="init_rv_fixed", value=50.0) step_rv = DistributionalRV( name="rw_step", - dist=dist.Normal(0, 1), + distribution=dist.Normal(0, 1), ) rw_init_rand = SimpleRandomWalkProcess( @@ -63,7 +63,7 @@ def test_rw_samples_correctly_distributed(): name="rw_normal_test", step_rv=DistributionalRV( name="rw_normal_dist", - dist=dist.Normal(loc=step_mean, scale=step_sd), + distribution=dist.Normal(loc=step_mean, scale=step_sd), ), init_rv=DeterministicVariable( name="init_rv_fixed", value=rw_init_val diff --git a/model/src/test/test_transformed_rv_class.py b/model/src/test/test_transformed_rv_class.py index 210041a5..9b009c42 100644 --- a/model/src/test/test_transformed_rv_class.py +++ b/model/src/test/test_transformed_rv_class.py @@ -67,7 +67,9 @@ def test_transform_rv_validation(): works as expected. """ - base_rv = DistributionalRV(name="test_normal", dist=dist.Normal(0, 1)) + base_rv = DistributionalRV( + name="test_normal", distribution=dist.Normal(0, 1) + ) base_rv.sample_length = lambda: 1 # numpydoc ignore=GL08 l2_rv = LengthTwoRV() @@ -109,7 +111,9 @@ def test_transforms_applied_at_sampling(): instances correctly apply their specified transformations at sampling """ - norm_rv = DistributionalRV(name="test_normal", dist=dist.Normal(0, 1)) + norm_rv = DistributionalRV( + name="test_normal", distribution=dist.Normal(0, 1) + ) norm_rv.sample_length = lambda: 1 l2_rv = LengthTwoRV() diff --git a/model/src/test/utils.py b/model/src/test/utils.py index 4bb05a52..76274e29 100644 --- a/model/src/test/utils.py +++ b/model/src/test/utils.py @@ -32,10 +32,10 @@ def simple_rt(arg_name: str = "Rt_rv"): base_rv=SimpleRandomWalkProcess( name="log_rt", step_rv=DistributionalRV( - name="rw_step_rv", dist=dist.Normal(0, 0.025) + name="rw_step_rv", distribution=dist.Normal(0, 0.025) ), init_rv=DistributionalRV( - name="init_log_rt", dist=dist.Normal(0, 0.2) + name="init_log_rt", distribution=dist.Normal(0, 0.2) ), ), transforms=t.ExpTransform(), From 8ca2dd2a377dfa246ab6d01e93ab2bec1e494abe Mon Sep 17 00:00:00 2001 From: "Dylan H. Morris" Date: Fri, 16 Aug 2024 16:07:11 -0400 Subject: [PATCH 3/4] update dist => distribution kwarg in DistributionalRV in all tutorials --- docs/source/tutorials/basic_renewal_model.qmd | 6 +++--- docs/source/tutorials/extending_pyrenew.qmd | 8 +++++--- docs/source/tutorials/hospital_admissions_model.qmd | 13 +++++++------ docs/source/tutorials/periodic_effects.qmd | 2 +- 4 files changed, 16 insertions(+), 13 deletions(-) diff --git a/docs/source/tutorials/basic_renewal_model.qmd b/docs/source/tutorials/basic_renewal_model.qmd index 65f1729c..b06de44a 100644 --- a/docs/source/tutorials/basic_renewal_model.qmd +++ b/docs/source/tutorials/basic_renewal_model.qmd @@ -126,7 +126,7 @@ gen_int = DeterministicPMF(name="gen_int", value=pmf_array) # (2) Initial infections (inferred with a prior) I0 = InfectionInitializationProcess( "I0_initialization", - DistributionalRV(name="I0", dist=dist.LogNormal(2.5, 1)), + DistributionalRV(name="I0", distribution=dist.LogNormal(2.5, 1)), InitializeInfectionsZeroPad(pmf_array.size), t_unit=1, ) @@ -148,12 +148,12 @@ class MyRt(RandomVariable): name="log_rt", step_rv=DistributionalRV( name="rw_step_rv", - dist=dist.Normal(0, sd_rt), + distribution=dist.Normal(0, sd_rt), reparam=LocScaleReparam(0), ), init_rv=DistributionalRV( name="init_log_rt", - dist=dist.Normal(jnp.log(1), jnp.log(1.2)), + distribution=dist.Normal(jnp.log(1), jnp.log(1.2)), ), ), transforms=t.ExpTransform(), diff --git a/docs/source/tutorials/extending_pyrenew.qmd b/docs/source/tutorials/extending_pyrenew.qmd index bcfc2749..5e854a4f 100644 --- a/docs/source/tutorials/extending_pyrenew.qmd +++ b/docs/source/tutorials/extending_pyrenew.qmd @@ -49,7 +49,7 @@ feedback_strength = DeterministicVariable(name="feedback_strength", value=0.01) I0 = InfectionInitializationProcess( "I0_initialization", - DistributionalRV(name="I0", dist=dist.LogNormal(0, 1)), + DistributionalRV(name="I0", distribution=dist.LogNormal(0, 1)), InitializeInfectionsExponentialGrowth( gen_int_array.size, DeterministicVariable(name="rate", value=0.05), @@ -67,9 +67,11 @@ rt = TransformedRandomVariable( base_rv=SimpleRandomWalkProcess( name="log_rt", step_rv=DistributionalRV( - name="rw_step_rv", dist=dist.Normal(0, 0.025) + name="rw_step_rv", distribution=dist.Normal(0, 0.025) + ), + init_rv=DistributionalRV( + name="init_log_rt", distribution=dist.Normal(0, 0.2) ), - init_rv=DistributionalRV(name="init_log_rt", dist=dist.Normal(0, 0.2)), ), transforms=t.ExpTransform(), ) diff --git a/docs/source/tutorials/hospital_admissions_model.qmd b/docs/source/tutorials/hospital_admissions_model.qmd index 57a18b5c..cbd73e39 100644 --- a/docs/source/tutorials/hospital_admissions_model.qmd +++ b/docs/source/tutorials/hospital_admissions_model.qmd @@ -146,7 +146,7 @@ inf_hosp_int = deterministic.DeterministicPMF( ) hosp_rate = metaclass.DistributionalRV( - name="IHR", dist=dist.LogNormal(jnp.log(0.05), jnp.log(1.1)) + name="IHR", distribution=dist.LogNormal(jnp.log(0.05), jnp.log(1.1)) ) latent_hosp = latent.HospitalAdmissions( @@ -171,7 +171,8 @@ latent_inf = latent.Infections() I0 = InfectionInitializationProcess( "I0_initialization", metaclass.DistributionalRV( - name="I0", dist=dist.LogNormal(loc=jnp.log(100), scale=jnp.log(1.75)) + name="I0", + distribution=dist.LogNormal(loc=jnp.log(100), scale=jnp.log(1.75)), ), InitializeInfectionsExponentialGrowth( gen_int_array.size, @@ -199,10 +200,10 @@ class MyRt(metaclass.RandomVariable): base_rv=process.SimpleRandomWalkProcess( name="log_rt", step_rv=metaclass.DistributionalRV( - name="rw_step_rv", dist=dist.Normal(0, sd_rt.value) + name="rw_step_rv", distribution=dist.Normal(0, sd_rt.value) ), init_rv=metaclass.DistributionalRV( - name="init_log_rt", dist=dist.Normal(0, 0.2) + name="init_log_rt", distribution=dist.Normal(0, 0.2) ), ), transforms=transformation.ExpTransform(), @@ -213,7 +214,7 @@ class MyRt(metaclass.RandomVariable): rtproc = MyRt( metaclass.DistributionalRV( - name="Rt_random_walk_sd", dist=dist.HalfNormal(0.025) + name="Rt_random_walk_sd", distribution=dist.HalfNormal(0.025) ) ) @@ -225,7 +226,7 @@ nb_conc_rv = metaclass.TransformedRandomVariable( "concentration", metaclass.DistributionalRV( name="concentration_raw", - dist=dist.TruncatedNormal(loc=0, scale=1, low=0.01), + distribution=dist.TruncatedNormal(loc=0, scale=1, low=0.01), ), transformation.PowerTransform(-2), ) diff --git a/docs/source/tutorials/periodic_effects.qmd b/docs/source/tutorials/periodic_effects.qmd index d1c5d55b..7586776b 100644 --- a/docs/source/tutorials/periodic_effects.qmd +++ b/docs/source/tutorials/periodic_effects.qmd @@ -77,7 +77,7 @@ mysimplex = dist.TransformedDistribution( dayofweek = process.DayOfWeekEffect( offset=0, quantity_to_broadcast=metaclass.DistributionalRV( - name="simp", dist=mysimplex + name="simp", distribution=mysimplex ), t_start=0, ) From 25fa99a9d7a68125fd4067ce8c53e1ccb8c4b1fe Mon Sep 17 00:00:00 2001 From: "Dylan H. Morris" Date: Fri, 16 Aug 2024 16:36:33 -0400 Subject: [PATCH 4/4] Add tests for DistributionalRV factory and classes --- model/src/test/test_distributional_rv.py | 107 +++++++++++++++++++++++ 1 file changed, 107 insertions(+) create mode 100644 model/src/test/test_distributional_rv.py diff --git a/model/src/test/test_distributional_rv.py b/model/src/test/test_distributional_rv.py new file mode 100644 index 00000000..66cb263c --- /dev/null +++ b/model/src/test/test_distributional_rv.py @@ -0,0 +1,107 @@ +""" +Tests for the distributional RV classes +""" +import jax.numpy as jnp +import numpyro +import numpyro.distributions as dist +import pytest +from numpy.testing import assert_array_equal +from pyrenew.metaclass import ( + DistributionalRV, + DynamicDistributionalRV, + StaticDistributionalRV, +) + + +class NonCallableTestClass: + """ + Generic non-callable object to test + callable checking for DynamicDistributionalRV. + """ + + def __init__(self): + """ + Initialization method for generic non-callable + object + """ + pass + + +@pytest.mark.parametrize("not_a_dist", [1, "test", NonCallableTestClass()]) +def test_invalid_constructor_args(not_a_dist): + """ + Test that the constructor errors + appropriately when given incorrect input + """ + + with pytest.raises( + ValueError, match="distribution argument to DistributionalRV" + ): + DistributionalRV(name="this should fail", distribution=not_a_dist) + with pytest.raises( + ValueError, + match=( + "distribution should be an instance of " + "numpyro.distributions.Distribution" + ), + ): + StaticDistributionalRV.validate(not_a_dist) + with pytest.raises(ValueError, match="must provide a Callable"): + DynamicDistributionalRV.validate(not_a_dist) + + +@pytest.mark.parametrize( + ["valid_static_dist_arg", "valid_dynamic_dist_arg"], + [ + [dist.Normal(0, 1), dist.Normal], + [dist.Cauchy(3.0, 5.0), dist.Cauchy], + [dist.Poisson(0.25), dist.Poisson], + ], +) +def test_factory_triage(valid_static_dist_arg, valid_dynamic_dist_arg): + """ + Test that passing a numpyro.distributions.Distribution + instance to the DistributionalRV factory instaniates + a StaticDistributionalRV, while passing a callable + instaniates a DynamicDistributionalRV + """ + static = DistributionalRV( + name="test static", distribution=valid_static_dist_arg + ) + assert isinstance(static, StaticDistributionalRV) + dynamic = DistributionalRV( + name="test dynamic", distribution=valid_dynamic_dist_arg + ) + assert isinstance(dynamic, DynamicDistributionalRV) + + +@pytest.mark.parametrize( + ["dist", "params"], + [ + [dist.Normal, {"loc": 0.0, "scale": 0.5}], + [dist.Poisson, {"rate": 0.35265}], + [ + dist.Cauchy, + { + "loc": jnp.array([1.0, 5.0, -0.25]), + "scale": jnp.array([0.02, 0.15, 2]), + }, + ], + ], +) +def test_sampling_equivalent(dist, params): + """ + Test that sampling a DynamicDistributionalRV + with a given parameterization is equivalent to + sampling a StaticDistributionalRV with the + same parameterization and the same random seed + """ + static = DistributionalRV(name="static", distribution=dist(**params)) + dynamic = DistributionalRV(name="dynamic", distribution=dist) + assert isinstance(static, StaticDistributionalRV) + assert isinstance(dynamic, DynamicDistributionalRV) + with numpyro.handlers.seed(rng_seed=5): + static_samp, *_ = static() + with numpyro.handlers.seed(rng_seed=5): + dynamic_samp, *_ = dynamic(**params) + assert_array_equal(static_samp.value, dynamic_samp.value)