From 3d7796721c247fa499c3da1da4a20646459839ab Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?upx3=E2=80=94TM=20=28CFA=29?= <127630341+AFg6K7h4fhy2@users.noreply.github.com> Date: Thu, 25 Jul 2024 17:57:05 -0400 Subject: [PATCH] Change Order Of The `name` Argument To Match Numpyro (#324) * fix name order * fix test and remaining renamings * Update model/src/pyrenew/deterministic/deterministic.py Co-authored-by: Dylan H. Morris * Update model/src/pyrenew/deterministic/deterministicpmf.py Co-authored-by: Dylan H. Morris * Update model/src/pyrenew/deterministic/deterministicpmf.py Co-authored-by: Dylan H. Morris * Update model/src/pyrenew/deterministic/deterministicpmf.py Co-authored-by: Dylan H. Morris --------- Co-authored-by: Dylan H. Morris --- docs/source/tutorials/basic_renewal_model.qmd | 11 +-- docs/source/tutorials/extending_pyrenew.qmd | 16 ++-- .../tutorials/hospital_admissions_model.qmd | 22 ++--- docs/source/tutorials/periodic_effects.qmd | 10 ++- .../pyrenew/deterministic/deterministic.py | 10 +-- .../pyrenew/deterministic/deterministicpmf.py | 10 +-- .../src/pyrenew/latent/hospitaladmissions.py | 8 +- model/src/pyrenew/metaclass.py | 20 ++--- .../src/pyrenew/process/firstdifferencear.py | 2 +- model/src/test/test_deterministic.py | 10 ++- model/src/test/test_forecast.py | 16 ++-- .../src/test/test_infection_seeding_method.py | 15 ++-- .../test/test_infection_seeding_process.py | 10 +-- model/src/test/test_infectionsrtfeedback.py | 8 +- model/src/test/test_latent_admissions.py | 16 ++-- model/src/test/test_latent_infections.py | 8 +- model/src/test/test_model_basic_renewal.py | 32 +++++--- model/src/test/test_model_hosp_admissions.py | 80 +++++++++++-------- .../test/test_observation_negativebinom.py | 4 +- model/src/test/test_periodiceffect.py | 4 +- model/src/test/test_predictive.py | 12 ++- model/src/test/test_random_key.py | 12 ++- model/src/test/test_random_walk.py | 18 +++-- model/src/test/test_rtperiodicdiff.py | 37 ++++++--- model/src/test/test_transformed_rv_class.py | 4 +- 25 files changed, 233 insertions(+), 162 deletions(-) diff --git a/docs/source/tutorials/basic_renewal_model.qmd b/docs/source/tutorials/basic_renewal_model.qmd index 7c5de48a..554f4a77 100644 --- a/docs/source/tutorials/basic_renewal_model.qmd +++ b/docs/source/tutorials/basic_renewal_model.qmd @@ -121,12 +121,12 @@ To initialize these five components within the renewal modeling framework, we es # | label: creating-elements # (1) The generation interval (deterministic) pmf_array = jnp.array([0.4, 0.3, 0.2, 0.1]) -gen_int = DeterministicPMF(pmf_array, name="gen_int") +gen_int = DeterministicPMF(name="gen_int", vars=pmf_array) # (2) Initial infections (inferred with a prior) I0 = InfectionInitializationProcess( "I0_initialization", - DistributionalRV(dist=dist.LogNormal(2.5, 1), name="I0"), + DistributionalRV(name="I0", dist=dist.LogNormal(2.5, 1)), InitializeInfectionsZeroPad(pmf_array.size), t_unit=1, ) @@ -147,12 +147,13 @@ class MyRt(RandomVariable): base_rv=SimpleRandomWalkProcess( name="log_rt", step_rv=DistributionalRV( - dist.Normal(0, sd_rt), - "rw_step_rv", + name="rw_step_rv", + dist=dist.Normal(0, sd_rt), reparam=LocScaleReparam(0), ), init_rv=DistributionalRV( - dist.Normal(jnp.log(1), jnp.log(1.2)), "init_log_Rt_rv" + name="init_log_Rt_rv", + dist=dist.Normal(jnp.log(1), jnp.log(1.2)), ), ), transforms=t.ExpTransform(), diff --git a/docs/source/tutorials/extending_pyrenew.qmd b/docs/source/tutorials/extending_pyrenew.qmd index 5b0cf197..8a525ffd 100644 --- a/docs/source/tutorials/extending_pyrenew.qmd +++ b/docs/source/tutorials/extending_pyrenew.qmd @@ -42,15 +42,15 @@ The following code-chunk defines the model components. Notice that for both the ```{python} # | label: model-components gen_int_array = jnp.array([0.25, 0.5, 0.15, 0.1]) -gen_int = DeterministicPMF(gen_int_array, name="gen_int") -feedback_strength = DeterministicVariable(0.05, name="feedback_strength") +gen_int = DeterministicPMF(name="gen_int", vars=gen_int_array) +feedback_strength = DeterministicVariable(name="feedback_strength", vars=0.05) I0 = InfectionInitializationProcess( "I0_initialization", - DistributionalRV(dist=dist.LogNormal(0, 1), name="I0"), + DistributionalRV(name="I0", dist=dist.LogNormal(0, 1)), InitializeInfectionsExponentialGrowth( gen_int_array.size, - DeterministicVariable(0.5, name="rate"), + DeterministicVariable(name="rate", vars=0.5), ), t_unit=1, ) @@ -64,8 +64,12 @@ rt = TransformedRandomVariable( "Rt_rv", base_rv=SimpleRandomWalkProcess( name="log_rt", - step_rv=DistributionalRV(dist.Normal(0, 0.025), "rw_step_rv"), - init_rv=DistributionalRV(dist.Normal(0, 0.2), "init_log_Rt_rv"), + step_rv=DistributionalRV( + name="rw_step_rv", dist=dist.Normal(0, 0.025) + ), + init_rv=DistributionalRV( + name="init_log_Rt_rv", dist=dist.Normal(0, 0.2) + ), ), transforms=t.ExpTransform(), ) diff --git a/docs/source/tutorials/hospital_admissions_model.qmd b/docs/source/tutorials/hospital_admissions_model.qmd index 1bee6d91..93b16477 100644 --- a/docs/source/tutorials/hospital_admissions_model.qmd +++ b/docs/source/tutorials/hospital_admissions_model.qmd @@ -142,12 +142,11 @@ import jax.numpy as jnp import numpyro.distributions as dist inf_hosp_int = deterministic.DeterministicPMF( - inf_hosp_int, name="inf_hosp_int" + name="inf_hosp_int", vars=inf_hosp_int ) hosp_rate = metaclass.DistributionalRV( - dist=dist.LogNormal(jnp.log(0.05), jnp.log(1.1)), - name="IHR", + name="IHR", dist=dist.LogNormal(jnp.log(0.05), jnp.log(1.1)) ) latent_hosp = latent.HospitalAdmissions( @@ -172,17 +171,17 @@ latent_inf = latent.Infections() I0 = InfectionInitializationProcess( "I0_initialization", metaclass.DistributionalRV( - dist=dist.LogNormal(loc=jnp.log(100), scale=jnp.log(1.75)), name="I0" + name="I0", dist=dist.LogNormal(loc=jnp.log(100), scale=jnp.log(1.75)) ), InitializeInfectionsExponentialGrowth( gen_int_array.size, - deterministic.DeterministicVariable(0.05, name="rate"), + deterministic.DeterministicVariable(name="rate", vars=0.05), ), t_unit=1, ) # Generation interval and Rt -gen_int = deterministic.DeterministicPMF(gen_int, name="gen_int") +gen_int = deterministic.DeterministicPMF(name="gen_int", vars=gen_int) class MyRt(metaclass.RandomVariable): @@ -200,10 +199,10 @@ class MyRt(metaclass.RandomVariable): base_rv=process.SimpleRandomWalkProcess( name="log_rt", step_rv=metaclass.DistributionalRV( - dist.Normal(0, sd_rt), "rw_step_rv" + name="rw_step_rv", dist=dist.Normal(0, sd_rt) ), init_rv=metaclass.DistributionalRV( - dist.Normal(0, 0.2), "init_log_Rt_rv" + name="init_log_Rt_rv", dist=dist.Normal(0, 0.2) ), ), transforms=transformation.ExpTransform(), @@ -213,7 +212,9 @@ class MyRt(metaclass.RandomVariable): rtproc = MyRt( - metaclass.DistributionalRV(dist.HalfNormal(0.025), "Rt_random_walk_sd") + metaclass.DistributionalRV( + name="Rt_random_walk_sd", dist=dist.HalfNormal(0.025) + ) ) # The observation model @@ -223,7 +224,8 @@ rtproc = MyRt( nb_conc_rv = metaclass.TransformedRandomVariable( "concentration", metaclass.DistributionalRV( - dist.TruncatedNormal(loc=0, scale=1, low=0.01), "concentration_raw" + name="concentration_raw", + dist=dist.TruncatedNormal(loc=0, scale=1, low=0.01), ), transformation.PowerTransform(-2), ) diff --git a/docs/source/tutorials/periodic_effects.qmd b/docs/source/tutorials/periodic_effects.qmd index 3c1f964f..8683e976 100644 --- a/docs/source/tutorials/periodic_effects.qmd +++ b/docs/source/tutorials/periodic_effects.qmd @@ -28,13 +28,13 @@ rt_proc = process.RtWeeklyDiffProcess( name="rt_weekly_diff", offset=0, log_rt_prior=deterministic.DeterministicVariable( - jnp.array([0.1, 0.2]), name="log_rt_prior" + name="log_rt_prior", vars=jnp.array([0.1, 0.2]) ), autoreg=deterministic.DeterministicVariable( - jnp.array([0.7]), name="autoreg" + name="autoreg", vars=jnp.array([0.7]) ), periodic_diff_sd=deterministic.DeterministicVariable( - jnp.array([0.1]), name="periodic_diff_sd" + name="periodic_diff_sd", vars=jnp.array([0.1]) ), ) ``` @@ -76,7 +76,9 @@ mysimplex = dist.TransformedDistribution( # Constructing the day of week effect dayofweek = process.DayOfWeekEffect( offset=0, - quantity_to_broadcast=metaclass.DistributionalRV(mysimplex, "simp"), + quantity_to_broadcast=metaclass.DistributionalRV( + name="simp", dist=mysimplex + ), t_start=0, ) ``` diff --git a/model/src/pyrenew/deterministic/deterministic.py b/model/src/pyrenew/deterministic/deterministic.py index b30d1903..66a77029 100644 --- a/model/src/pyrenew/deterministic/deterministic.py +++ b/model/src/pyrenew/deterministic/deterministic.py @@ -17,26 +17,26 @@ class DeterministicVariable(RandomVariable): def __init__( self, - vars: ArrayLike, name: str, + vars: ArrayLike, ) -> None: """Default constructor Parameters ---------- - vars : ArrayLike - A tuple with arraylike objects. name : str A name to assign to the process. + vars : ArrayLike + An ArrayLike object. Returns ------- None """ - self.validate(vars) - self.vars = jnp.atleast_1d(vars) self.name = name + self.vars = jnp.atleast_1d(vars) + self.validate(vars) return None diff --git a/model/src/pyrenew/deterministic/deterministicpmf.py b/model/src/pyrenew/deterministic/deterministicpmf.py index 3fd85420..c7342bdc 100644 --- a/model/src/pyrenew/deterministic/deterministicpmf.py +++ b/model/src/pyrenew/deterministic/deterministicpmf.py @@ -15,8 +15,8 @@ class DeterministicPMF(RandomVariable): def __init__( self, - vars: ArrayLike, name: str, + vars: ArrayLike, tol: float = 1e-5, ) -> None: """ @@ -29,10 +29,10 @@ def __init__( Parameters ---------- - vars : tuple - A tuple with arraylike objects. name : str - A name to assign to the process. + A name to assign to the variable. + vars : tuple + An ArrayLike object. tol : float, optional Passed to pyrenew.distutil.validate_discrete_dist_vector. Defaults to 1e-5. @@ -46,7 +46,7 @@ def __init__( tol=tol, ) - self.basevar = DeterministicVariable(vars, name) + self.basevar = DeterministicVariable(name=name, vars=vars) return None diff --git a/model/src/pyrenew/latent/hospitaladmissions.py b/model/src/pyrenew/latent/hospitaladmissions.py index 3d5a84cd..f92c5c90 100644 --- a/model/src/pyrenew/latent/hospitaladmissions.py +++ b/model/src/pyrenew/latent/hospitaladmissions.py @@ -90,9 +90,13 @@ def __init__( """ if day_of_week_effect_rv is None: - day_of_week_effect_rv = DeterministicVariable(1, "weekday_effect") + day_of_week_effect_rv = DeterministicVariable( + name="weekday_effect", vars=1 + ) if hosp_report_prob_rv is None: - hosp_report_prob_rv = DeterministicVariable(1, "hosp_report_prob") + hosp_report_prob_rv = DeterministicVariable( + name="hosp_report_prob", vars=1 + ) HospitalAdmissions.validate( infect_hosp_rate_rv, diff --git a/model/src/pyrenew/metaclass.py b/model/src/pyrenew/metaclass.py index 21d38312..fd058d24 100644 --- a/model/src/pyrenew/metaclass.py +++ b/model/src/pyrenew/metaclass.py @@ -142,7 +142,6 @@ def set_timeseries( model time. It could be negative, indicating that the `sample()` method returns timepoints that occur prior to the model t = 0. - t_unit : int The unit of the time series relative to the model's fundamental (smallest) @@ -219,8 +218,8 @@ class DistributionalRV(RandomVariable): def __init__( self, - dist: numpyro.distributions.Distribution, name: str, + dist: numpyro.distributions.Distribution, reparam: Reparam = None, ) -> None: """ @@ -228,11 +227,10 @@ def __init__( Parameters ---------- - dist : numpyro.distributions.Distribution - Distribution of the random variable. name : str Name of the random variable. - + dist : numpyro.distributions.Distribution + Distribution of the random variable. reparam : numpyro.infer.reparam.Reparam If not None, reparameterize sampling from the distribution according to the @@ -243,10 +241,9 @@ def __init__( None """ + self.name = name self.validate(dist) - self.dist = dist - self.name = name if reparam is not None: self.reparam_dict = {self.name: reparam} else: @@ -603,19 +600,16 @@ def __init__( Parameters ---------- - name : str - A name for the random variable instance - + A name for the random variable instance. base_rv : RandomVariable - The underlying (untransformed) RandomVariable - + The underlying (untransformed) RandomVariable. transforms : Transform Transformation or tuple of transformations to apply to the output of `base_rv.sample()`; single values will be coerced to a length-one tuple. If a tuple, should be the same - length as the tuple returned by `base_rv.sample()` + length as the tuple returned by `base_rv.sample()`. Returns ------- diff --git a/model/src/pyrenew/process/firstdifferencear.py b/model/src/pyrenew/process/firstdifferencear.py index afbf768e..e3e594cd 100644 --- a/model/src/pyrenew/process/firstdifferencear.py +++ b/model/src/pyrenew/process/firstdifferencear.py @@ -71,9 +71,9 @@ def sample( With a single array of shape (duration,). """ rates_of_change, *_ = self.rate_of_change_proc.sample( + name=self.name + "_rate_of_change", duration=duration, inits=jnp.atleast_1d(init_rate_of_change), - name=self.name + "_rate_of_change", ) return (init_val + jnp.cumsum(rates_of_change.flatten()),) diff --git a/model/src/test/test_deterministic.py b/model/src/test/test_deterministic.py index eae44f17..b20d7da4 100644 --- a/model/src/test/test_deterministic.py +++ b/model/src/test/test_deterministic.py @@ -18,15 +18,17 @@ def test_deterministic(): """ var1 = DeterministicVariable( - jnp.array( + name="var1", + vars=jnp.array( [ 1, ] ), - name="var1", ) - var2 = DeterministicPMF(jnp.array([0.25, 0.25, 0.2, 0.3]), name="var2") - var3 = DeterministicProcess(jnp.array([1, 2, 3, 4]), name="var3") + var2 = DeterministicPMF( + name="var2", vars=jnp.array([0.25, 0.25, 0.2, 0.3]) + ) + var3 = DeterministicProcess(name="var3", vars=jnp.array([1, 2, 3, 4])) var4 = NullVariable() var5 = NullProcess() diff --git a/model/src/test/test_forecast.py b/model/src/test/test_forecast.py index 148627f2..259fdff7 100644 --- a/model/src/test/test_forecast.py +++ b/model/src/test/test_forecast.py @@ -21,21 +21,25 @@ def test_forecast(): """Check that forecasts are the right length and match the posterior up until forecast begins.""" pmf_array = jnp.array([0.25, 0.25, 0.25, 0.25]) - gen_int = DeterministicPMF(pmf_array, name="gen_int") + gen_int = DeterministicPMF(name="gen_int", vars=pmf_array) I0 = InfectionInitializationProcess( "I0_initialization", - DistributionalRV(dist=dist.LogNormal(0, 1), name="I0"), + DistributionalRV(name="I0", dist=dist.LogNormal(0, 1)), InitializeInfectionsZeroPad(n_timepoints=gen_int.size()), t_unit=1, ) latent_infections = Infections() - observed_infections = PoissonObservation("poisson_rv") + observed_infections = PoissonObservation(name="poisson_rv") rt = TransformedRandomVariable( - "Rt_rv", + name="Rt_rv", base_rv=SimpleRandomWalkProcess( name="log_rt", - step_rv=DistributionalRV(dist.Normal(0, 0.025), "rw_step_rv"), - init_rv=DistributionalRV(dist.Normal(0, 0.2), "init_log_Rt_rv"), + step_rv=DistributionalRV( + name="rw_step_rv", dist=dist.Normal(0, 0.025) + ), + init_rv=DistributionalRV( + name="init_log_Rt_rv", dist=dist.Normal(0, 0.2) + ), ), transforms=t.ExpTransform(), ) diff --git a/model/src/test/test_infection_seeding_method.py b/model/src/test/test_infection_seeding_method.py index 0b7efe7f..29abcce6 100644 --- a/model/src/test/test_infection_seeding_method.py +++ b/model/src/test/test_infection_seeding_method.py @@ -13,8 +13,8 @@ def test_initialize_infections_exponential(): """Check that the InitializeInfectionsExponentialGrowth class generates the correct number of infections at each time point.""" n_timepoints = 10 - rate_RV = DeterministicVariable(0.5, name="rate_RV") - I_pre_init_RV = DeterministicVariable(10.0, name="I_pre_init_RV") + rate_RV = DeterministicVariable(name="rate_RV", vars=0.5) + I_pre_init_RV = DeterministicVariable(name="I_pre_init_RV", vars=10.0) default_t_pre_init = n_timepoints - 1 (I_pre_init,) = I_pre_init_RV() @@ -35,14 +35,17 @@ def test_initialize_infections_exponential(): assert infections_default_t_pre_init[default_t_pre_init] == I_pre_init # test for failure with non-scalar rate or I_pre_init - rate_RV_2 = DeterministicVariable(np.array([0.5, 0.5]), name="rate_RV") + rate_RV_2 = DeterministicVariable( + name="rate_RV", vars=np.array([0.5, 0.5]) + ) with pytest.raises(ValueError): InitializeInfectionsExponentialGrowth( n_timepoints, rate=rate_RV_2 ).initialize_infections(I_pre_init) I_pre_init_RV_2 = DeterministicVariable( - np.array([10.0, 10.0]), name="I_pre_init_RV" + name="I_pre_init_RV", + vars=np.array([10.0, 10.0]), ) (I_pre_init_2,) = I_pre_init_RV_2() @@ -72,7 +75,7 @@ def test_initialize_infections_zero_pad(): """Check that the InitializeInfectionsZeroPad class generates the correct number of infections at each time point.""" n_timepoints = 10 - I_pre_init_RV = DeterministicVariable(10.0, name="I_pre_init_RV") + I_pre_init_RV = DeterministicVariable(name="I_pre_init_RV", vars=10.0) (I_pre_init,) = I_pre_init_RV() infections = InitializeInfectionsZeroPad( @@ -83,7 +86,7 @@ def test_initialize_infections_zero_pad(): ) I_pre_init_RV_2 = DeterministicVariable( - np.array([10.0, 10.0]), name="I_pre_init_RV" + name="I_pre_init_RV", vars=np.array([10.0, 10.0]) ) (I_pre_init_2,) = I_pre_init_RV_2() diff --git a/model/src/test/test_infection_seeding_process.py b/model/src/test/test_infection_seeding_process.py index f159fcf1..685ba61f 100644 --- a/model/src/test/test_infection_seeding_process.py +++ b/model/src/test/test_infection_seeding_process.py @@ -19,23 +19,23 @@ def test_infection_initialization_process(): zero_pad_model = InfectionInitializationProcess( "zero_pad_model", - DistributionalRV(dist=dist.LogNormal(0, 1), name="I0"), + DistributionalRV(name="I0", dist=dist.LogNormal(0, 1)), InitializeInfectionsZeroPad(n_timepoints), t_unit=1, ) exp_model = InfectionInitializationProcess( "exp_model", - DistributionalRV(dist=dist.LogNormal(0, 1), name="I0"), + DistributionalRV(name="I0", dist=dist.LogNormal(0, 1)), InitializeInfectionsExponentialGrowth( - n_timepoints, DeterministicVariable(0.5, name="rate") + n_timepoints, DeterministicVariable(name="rate", vars=0.5) ), t_unit=1, ) vec_model = InfectionInitializationProcess( "vec_model", - DeterministicVariable(jnp.arange(n_timepoints), name="I0"), + DeterministicVariable(name="I0", vars=jnp.arange(n_timepoints)), InitializeInfectionsFromVec(n_timepoints), t_unit=1, ) @@ -56,7 +56,7 @@ def test_infection_initialization_process(): with pytest.raises(TypeError): InfectionInitializationProcess( "vec_model", - DeterministicVariable(jnp.arange(n_timepoints), name="I0"), + DeterministicVariable(name="I0", vars=jnp.arange(n_timepoints)), 3, t_unit=1, ) diff --git a/model/src/test/test_infectionsrtfeedback.py b/model/src/test/test_infectionsrtfeedback.py index f4b6d164..e602a3a3 100644 --- a/model/src/test/test_infectionsrtfeedback.py +++ b/model/src/test/test_infectionsrtfeedback.py @@ -69,9 +69,9 @@ def test_infectionsrtfeedback(): # By doing the infection feedback strength 0, Rt = Rt_adjusted # So infection should be equal in both inf_feed_strength = DeterministicVariable( - jnp.zeros_like(Rt), name="inf_feed_strength" + name="inf_feed_strength", vars=jnp.zeros_like(Rt) ) - inf_feedback_pmf = DeterministicPMF(gen_int, name="inf_feedback_pmf") + inf_feedback_pmf = DeterministicPMF(name="inf_feedback_pmf", vars=gen_int) # Test the InfectionsWithFeedback class InfectionsWithFeedback = latent.InfectionsWithFeedback( @@ -113,9 +113,9 @@ def test_infectionsrtfeedback_feedback(): gen_int = jnp.array([0.4, 0.25, 0.25, 0.1, 0.0, 0.0, 0.0]) inf_feed_strength = DeterministicVariable( - jnp.repeat(0.5, len(Rt)), name="inf_feed_strength" + name="inf_feed_strength", vars=jnp.repeat(0.5, len(Rt)) ) - inf_feedback_pmf = DeterministicPMF(gen_int, name="inf_feedback_pmf") + inf_feedback_pmf = DeterministicPMF(name="inf_feedback_pmf", vars=gen_int) # Test the InfectionsWithFeedback class InfectionsWithFeedback = latent.InfectionsWithFeedback( diff --git a/model/src/test/test_latent_admissions.py b/model/src/test/test_latent_admissions.py index eb0c03e3..6b211b7c 100644 --- a/model/src/test/test_latent_admissions.py +++ b/model/src/test/test_latent_admissions.py @@ -21,11 +21,15 @@ def test_admissions_sample(): # Generating Rt and Infections to compute the hospital admissions rt = TransformedRandomVariable( - "Rt_rv", + name="Rt_rv", base_rv=SimpleRandomWalkProcess( name="log_rt", - step_rv=DistributionalRV(dist.Normal(0, 0.025), "rw_step_rv"), - init_rv=DistributionalRV(dist.Normal(0, 0.2), "init_log_Rt_rv"), + step_rv=DistributionalRV( + name="rw_step_rv", dist=dist.Normal(0, 0.025) + ), + init_rv=DistributionalRV( + name="init_log_Rt_rv", dist=dist.Normal(0, 0.2) + ), ), transforms=t.ExpTransform(), ) @@ -43,7 +47,8 @@ def test_admissions_sample(): # Testing the hospital admissions inf_hosp = DeterministicPMF( - jnp.array( + name="inf_hosp", + vars=jnp.array( [ 0, 0, @@ -65,13 +70,12 @@ def test_admissions_sample(): 0.05, ] ), - name="inf_hosp", ) hosp1 = HospitalAdmissions( infection_to_admission_interval_rv=inf_hosp, infect_hosp_rate_rv=DistributionalRV( - dist=dist.LogNormal(jnp.log(0.05), 0.05), name="IHR" + name="IHR", dist=dist.LogNormal(jnp.log(0.05), 0.05) ), ) diff --git a/model/src/test/test_latent_infections.py b/model/src/test/test_latent_infections.py index 3d75bb1c..d472a1d7 100755 --- a/model/src/test/test_latent_infections.py +++ b/model/src/test/test_latent_infections.py @@ -22,8 +22,12 @@ def test_infections_as_deterministic(): "Rt_rv", base_rv=SimpleRandomWalkProcess( name="log_rt", - step_rv=DistributionalRV(dist.Normal(0, 0.025), "rw_step_rv"), - init_rv=DistributionalRV(dist.Normal(0, 0.2), "init_log_Rt_rv"), + step_rv=DistributionalRV( + name="rw_step_rv", dist=dist.Normal(0, 0.025) + ), + init_rv=DistributionalRV( + name="init_log_Rt_rv", dist=dist.Normal(0, 0.2) + ), ), transforms=t.ExpTransform(), ) diff --git a/model/src/test/test_model_basic_renewal.py b/model/src/test/test_model_basic_renewal.py index 98bf033d..d0330ae0 100644 --- a/model/src/test/test_model_basic_renewal.py +++ b/model/src/test/test_model_basic_renewal.py @@ -37,8 +37,12 @@ def get_default_rt(): "Rt_rv", base_rv=SimpleRandomWalkProcess( name="log_rt", - step_rv=DistributionalRV(dist.Normal(0, 0.025), "rw_step_rv"), - init_rv=DistributionalRV(dist.Normal(0, 0.2), "init_log_Rt_rv"), + step_rv=DistributionalRV( + name="rw_step_rv", dist=dist.Normal(0, 0.025) + ), + init_rv=DistributionalRV( + name="init_log_Rt_rv", dist=dist.Normal(0, 0.2) + ), ), transforms=t.ExpTransform(), ) @@ -52,10 +56,10 @@ def test_model_basicrenewal_no_timepoints_or_observations(): """ gen_int = DeterministicPMF( - jnp.array([0.25, 0.25, 0.25, 0.25]), name="gen_int" + name="gen_int", vars=jnp.array([0.25, 0.25, 0.25, 0.25]) ) - I0 = DistributionalRV(dist=dist.LogNormal(0, 1), name="I0") + I0 = DistributionalRV(name="I0", dist=dist.LogNormal(0, 1)) latent_infections = Infections() @@ -82,10 +86,11 @@ def test_model_basicrenewal_both_timepoints_and_observations(): """ gen_int = DeterministicPMF( - jnp.array([0.25, 0.25, 0.25, 0.25]), name="gen_int" + name="gen_int", + vars=jnp.array([0.25, 0.25, 0.25, 0.25]), ) - I0 = DistributionalRV(dist=dist.LogNormal(0, 1), name="I0") + I0 = DistributionalRV(name="I0", dist=dist.LogNormal(0, 1)) latent_infections = Infections() @@ -116,15 +121,16 @@ def test_model_basicrenewal_no_obs_model(): """ gen_int = DeterministicPMF( - jnp.array([0.25, 0.25, 0.25, 0.25]), name="gen_int" + name="gen_int", + vars=jnp.array([0.25, 0.25, 0.25, 0.25]), ) with pytest.raises(ValueError): - I0 = DistributionalRV(dist=1, name="I0") + I0 = DistributionalRV(name="I0", dist=1) I0 = InfectionInitializationProcess( "I0_initialization", - DistributionalRV(dist=dist.LogNormal(0, 1), name="I0"), + DistributionalRV(name="I0", dist=dist.LogNormal(0, 1)), InitializeInfectionsZeroPad(n_timepoints=gen_int.size()), t_unit=1, ) @@ -189,12 +195,12 @@ def test_model_basicrenewal_with_obs_model(): """ gen_int = DeterministicPMF( - jnp.array([0.25, 0.25, 0.25, 0.25]), name="gen_int" + name="gen_int", vars=jnp.array([0.25, 0.25, 0.25, 0.25]) ) I0 = InfectionInitializationProcess( "I0_initialization", - DistributionalRV(dist=dist.LogNormal(0, 1), name="I0"), + DistributionalRV(name="I0", dist=dist.LogNormal(0, 1)), InitializeInfectionsZeroPad(n_timepoints=gen_int.size()), t_unit=1, ) @@ -238,12 +244,12 @@ def test_model_basicrenewal_with_obs_model(): def test_model_basicrenewal_padding() -> None: # numpydoc ignore=GL08 gen_int = DeterministicPMF( - jnp.array([0.25, 0.25, 0.25, 0.25]), name="gen_int" + name="gen_int", vars=jnp.array([0.25, 0.25, 0.25, 0.25]) ) I0 = InfectionInitializationProcess( "I0_initialization", - DistributionalRV(dist=dist.LogNormal(0, 1), name="I0"), + DistributionalRV(name="I0", dist=dist.LogNormal(0, 1)), InitializeInfectionsZeroPad(n_timepoints=gen_int.size()), t_unit=1, ) diff --git a/model/src/test/test_model_hosp_admissions.py b/model/src/test/test_model_hosp_admissions.py index 05d20183..632e5e9c 100644 --- a/model/src/test/test_model_hosp_admissions.py +++ b/model/src/test/test_model_hosp_admissions.py @@ -46,8 +46,12 @@ def get_default_rt(): "Rt_rv", base_rv=SimpleRandomWalkProcess( name="log_rt", - step_rv=DistributionalRV(dist.Normal(0, 0.025), "rw_step_rv"), - init_rv=DistributionalRV(dist.Normal(0, 0.2), "init_log_Rt_rv"), + step_rv=DistributionalRV( + name="rw_step_rv", dist=dist.Normal(0, 0.025) + ), + init_rv=DistributionalRV( + name="init_log_Rt_rv", dist=dist.Normal(0, 0.2) + ), ), transforms=t.ExpTransform(), ) @@ -78,10 +82,10 @@ def test_model_hosp_no_timepoints_or_observations(): """ gen_int = DeterministicPMF( - jnp.array([0.25, 0.25, 0.25, 0.25]), name="gen_int" + name="gen_int", vars=jnp.array([0.25, 0.25, 0.25, 0.25]) ) - I0 = DistributionalRV(dist=dist.LogNormal(0, 1), name="I0") + I0 = DistributionalRV(name="I0", dist=dist.LogNormal(0, 1)) latent_infections = Infections() Rt_process = get_default_rt() @@ -89,7 +93,8 @@ def test_model_hosp_no_timepoints_or_observations(): observed_admissions = PoissonObservation("poisson_rv") inf_hosp = DeterministicPMF( - jnp.array( + name="inf_hosp", + vars=jnp.array( [ 0, 0, @@ -111,13 +116,12 @@ def test_model_hosp_no_timepoints_or_observations(): 0.05, ], ), - name="inf_hosp", ) latent_admissions = HospitalAdmissions( infection_to_admission_interval_rv=inf_hosp, infect_hosp_rate_rv=DistributionalRV( - dist=dist.LogNormal(jnp.log(0.05), 0.05), name="IHR" + name="IHR", dist=dist.LogNormal(jnp.log(0.05), 0.05) ), ) @@ -142,10 +146,11 @@ def test_model_hosp_both_timepoints_and_observations(): """ gen_int = DeterministicPMF( - jnp.array([0.25, 0.25, 0.25, 0.25]), name="gen_int" + name="gen_int", + vars=jnp.array([0.25, 0.25, 0.25, 0.25]), ) - I0 = DistributionalRV(dist=dist.LogNormal(0, 1), name="I0") + I0 = DistributionalRV(name="I0", dist=dist.LogNormal(0, 1)) latent_infections = Infections() Rt_process = get_default_rt() @@ -153,7 +158,8 @@ def test_model_hosp_both_timepoints_and_observations(): observed_admissions = PoissonObservation("poisson_rv") inf_hosp = DeterministicPMF( - jnp.array( + name="inf_hosp", + vars=jnp.array( [ 0, 0, @@ -175,13 +181,12 @@ def test_model_hosp_both_timepoints_and_observations(): 0.05, ], ), - name="inf_hosp", ) latent_admissions = HospitalAdmissions( infection_to_admission_interval_rv=inf_hosp, infect_hosp_rate_rv=DistributionalRV( - dist=dist.LogNormal(jnp.log(0.05), 0.05), name="IHR" + name="IHR", dist=dist.LogNormal(jnp.log(0.05), 0.05) ), ) @@ -209,12 +214,13 @@ def test_model_hosp_no_obs_model(): """ gen_int = DeterministicPMF( - jnp.array([0.25, 0.25, 0.25, 0.25]), name="gen_int" + name="gen_int", + vars=jnp.array([0.25, 0.25, 0.25, 0.25]), ) I0 = InfectionInitializationProcess( "I0_initialization", - DistributionalRV(dist=dist.LogNormal(0, 1), name="I0"), + DistributionalRV(name="I0", dist=dist.LogNormal(0, 1)), InitializeInfectionsZeroPad(n_timepoints=gen_int.size()), t_unit=1, ) @@ -223,7 +229,8 @@ def test_model_hosp_no_obs_model(): Rt_process = get_default_rt() inf_hosp = DeterministicPMF( - jnp.array( + name="inf_hosp", + vars=jnp.array( [ 0, 0, @@ -245,13 +252,13 @@ def test_model_hosp_no_obs_model(): 0.05, ] ), - name="inf_hosp", ) latent_admissions = HospitalAdmissions( infection_to_admission_interval_rv=inf_hosp, infect_hosp_rate_rv=DistributionalRV( - dist=dist.LogNormal(jnp.log(0.05), 0.05), name="IHR" + name="IHR", + dist=dist.LogNormal(jnp.log(0.05), 0.05), ), ) @@ -313,12 +320,12 @@ def test_model_hosp_with_obs_model(): """ gen_int = DeterministicPMF( - jnp.array([0.25, 0.25, 0.25, 0.25]), name="gen_int" + name="gen_int", vars=jnp.array([0.25, 0.25, 0.25, 0.25]) ) I0 = InfectionInitializationProcess( "I0_initialization", - DistributionalRV(dist=dist.LogNormal(0, 1), name="I0"), + DistributionalRV(name="I0", dist=dist.LogNormal(0, 1)), InitializeInfectionsZeroPad(n_timepoints=gen_int.size()), t_unit=1, ) @@ -328,7 +335,8 @@ def test_model_hosp_with_obs_model(): observed_admissions = PoissonObservation("poisson_rv") inf_hosp = DeterministicPMF( - jnp.array( + name="inf_hosp", + vars=jnp.array( [ 0, 0, @@ -350,13 +358,13 @@ def test_model_hosp_with_obs_model(): 0.05, ], ), - name="inf_hosp", ) latent_admissions = HospitalAdmissions( infection_to_admission_interval_rv=inf_hosp, infect_hosp_rate_rv=DistributionalRV( - dist=dist.LogNormal(jnp.log(0.05), 0.05), name="IHR" + name="IHR", + dist=dist.LogNormal(jnp.log(0.05), 0.05), ), ) @@ -398,12 +406,13 @@ def test_model_hosp_with_obs_model_weekday_phosp_2(): """ gen_int = DeterministicPMF( - jnp.array([0.25, 0.25, 0.25, 0.25]), name="gen_int" + name="gen_int", + vars=jnp.array([0.25, 0.25, 0.25, 0.25]), ) I0 = InfectionInitializationProcess( "I0_initialization", - DistributionalRV(dist=dist.LogNormal(0, 1), name="I0"), + DistributionalRV(name="I0", dist=dist.LogNormal(0, 1)), InitializeInfectionsZeroPad(n_timepoints=gen_int.size()), t_unit=1, ) @@ -413,7 +422,8 @@ def test_model_hosp_with_obs_model_weekday_phosp_2(): observed_admissions = PoissonObservation("poisson_rv") inf_hosp = DeterministicPMF( - jnp.array( + name="inf_hosp", + vars=jnp.array( [ 0, 0, @@ -435,7 +445,6 @@ def test_model_hosp_with_obs_model_weekday_phosp_2(): 0.05, ], ), - name="inf_hosp", ) # Other random components @@ -452,7 +461,7 @@ def test_model_hosp_with_obs_model_weekday_phosp_2(): day_of_week_effect_rv=weekday, hosp_report_prob_rv=hosp_report_prob_dist, infect_hosp_rate_rv=DistributionalRV( - dist=dist.LogNormal(jnp.log(0.05), 0.05), name="IHR" + name="IHR", dist=dist.LogNormal(jnp.log(0.05), 0.05) ), ) @@ -494,14 +503,15 @@ def test_model_hosp_with_obs_model_weekday_phosp(): """ gen_int = DeterministicPMF( - jnp.array([0.25, 0.25, 0.25, 0.25]), name="gen_int" + name="gen_int", + vars=jnp.array([0.25, 0.25, 0.25, 0.25]), ) n_obs_to_generate = 30 pad_size = 5 I0 = InfectionInitializationProcess( "I0_initialization", - DistributionalRV(dist=dist.LogNormal(0, 1), name="I0"), + DistributionalRV(name="I0", dist=dist.LogNormal(0, 1)), InitializeInfectionsZeroPad(n_timepoints=gen_int.size()), t_unit=1, ) @@ -512,7 +522,8 @@ def test_model_hosp_with_obs_model_weekday_phosp(): observed_admissions = PoissonObservation("poisson_rv") inf_hosp = DeterministicPMF( - jnp.array( + name="inf_hosp", + vars=jnp.array( [ 0, 0, @@ -534,7 +545,6 @@ def test_model_hosp_with_obs_model_weekday_phosp(): 0.05, ], ), - name="inf_hosp", ) # Other random components @@ -544,7 +554,7 @@ def test_model_hosp_with_obs_model_weekday_phosp(): weekday = jnp.tile(weekday, 10) weekday = weekday[:total_length] - weekday = DeterministicVariable(weekday, name="weekday") + weekday = DeterministicVariable(name="weekday", vars=weekday) hosp_report_prob_dist = jnp.array([0.9, 0.8, 0.7, 0.7, 0.6, 0.4]) hosp_report_prob_dist = jnp.tile(hosp_report_prob_dist, 10) @@ -552,7 +562,8 @@ def test_model_hosp_with_obs_model_weekday_phosp(): hosp_report_prob_dist = hosp_report_prob_dist / hosp_report_prob_dist.sum() hosp_report_prob_dist = DeterministicVariable( - vars=hosp_report_prob_dist, name="hosp_report_prob_dist" + name="hosp_report_prob_dist", + vars=hosp_report_prob_dist, ) latent_admissions = HospitalAdmissions( @@ -560,7 +571,8 @@ def test_model_hosp_with_obs_model_weekday_phosp(): day_of_week_effect_rv=weekday, hosp_report_prob_rv=hosp_report_prob_dist, infect_hosp_rate_rv=DistributionalRV( - dist=dist.LogNormal(jnp.log(0.05), 0.05), name="IHR" + name="IHR", + dist=dist.LogNormal(jnp.log(0.05), 0.05), ), ) diff --git a/model/src/test/test_observation_negativebinom.py b/model/src/test/test_observation_negativebinom.py index 845dac4a..9e17e73b 100644 --- a/model/src/test/test_observation_negativebinom.py +++ b/model/src/test/test_observation_negativebinom.py @@ -16,7 +16,7 @@ def test_negativebinom_deterministic_obs(): negb = NegativeBinomialObservation( "negbinom_rv", - concentration_rv=DeterministicVariable(10, name="concentration"), + concentration_rv=DeterministicVariable(name="concentration", vars=10), ) rates = np.random.randint(1, 5, size=10) @@ -42,7 +42,7 @@ def test_negativebinom_random_obs(): negb = NegativeBinomialObservation( "negbinom_rv", - concentration_rv=DeterministicVariable(10, "concentration"), + concentration_rv=DeterministicVariable(name="concentration", vars=10), ) rates = np.repeat(5, 20000) diff --git a/model/src/test/test_periodiceffect.py b/model/src/test/test_periodiceffect.py index fa541e38..9712911a 100644 --- a/model/src/test/test_periodiceffect.py +++ b/model/src/test/test_periodiceffect.py @@ -13,7 +13,7 @@ def test_periodiceffect() -> None: """Checks basic functionality of the process""" x = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]) - rv = DeterministicVariable(x, name="weekly-sample") + rv = DeterministicVariable(name="weekly-sample", vars=x) params = { "offset": 0, @@ -58,7 +58,7 @@ def test_weeklyeffect() -> None: """Checks basic functionality of the process""" x = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]) - rv = DeterministicVariable(x, name="weekly-sample") + rv = DeterministicVariable(name="weekly-sample", vars=x) params = { "offset": 2, diff --git a/model/src/test/test_predictive.py b/model/src/test/test_predictive.py index 67f713cf..2bc1995c 100644 --- a/model/src/test/test_predictive.py +++ b/model/src/test/test_predictive.py @@ -20,10 +20,10 @@ from pyrenew.process import SimpleRandomWalkProcess pmf_array = jnp.array([0.25, 0.1, 0.2, 0.45]) -gen_int = DeterministicPMF(pmf_array, name="gen_int") +gen_int = DeterministicPMF(name="gen_int", vars=pmf_array) I0 = InfectionInitializationProcess( "I0_initialization", - DistributionalRV(dist=dist.LogNormal(0, 1), name="I0"), + DistributionalRV(name="I0", dist=dist.LogNormal(0, 1)), InitializeInfectionsZeroPad(n_timepoints=gen_int.size()), t_unit=1, ) @@ -33,8 +33,12 @@ "Rt_rv", base_rv=SimpleRandomWalkProcess( name="log_rt", - step_rv=DistributionalRV(dist.Normal(0, 0.025), "rw_step_rv"), - init_rv=DistributionalRV(dist.Normal(0, 0.2), "init_log_Rt_rv"), + step_rv=DistributionalRV( + name="rw_step_rv", dist=dist.Normal(0, 0.025) + ), + init_rv=DistributionalRV( + name="init_log_Rt_rv", dist=dist.Normal(0, 0.2) + ), ), transforms=t.ExpTransform(), ) diff --git a/model/src/test/test_random_key.py b/model/src/test/test_random_key.py index adeb34cc..b75b999a 100644 --- a/model/src/test/test_random_key.py +++ b/model/src/test/test_random_key.py @@ -25,10 +25,10 @@ def create_test_model(): # numpydoc ignore=GL08 pmf_array = jnp.array([0.25, 0.25, 0.25, 0.25]) - gen_int = DeterministicPMF(pmf_array, name="gen_int") + gen_int = DeterministicPMF(name="gen_int", vars=pmf_array) I0 = InfectionInitializationProcess( "I0_initialization", - DistributionalRV(dist=dist.LogNormal(0, 1), name="I0"), + DistributionalRV(name="I0", dist=dist.LogNormal(0, 1)), InitializeInfectionsZeroPad(n_timepoints=gen_int.size()), t_unit=1, ) @@ -38,8 +38,12 @@ def create_test_model(): # numpydoc ignore=GL08 "Rt_rv", base_rv=SimpleRandomWalkProcess( name="log_rt", - step_rv=DistributionalRV(dist.Normal(0, 0.025), "rw_step_rv"), - init_rv=DistributionalRV(dist.Normal(0, 0.2), "init_log_Rt_rv"), + step_rv=DistributionalRV( + name="rw_step_rv", dist=dist.Normal(0, 0.025) + ), + init_rv=DistributionalRV( + name="init_log_Rt_rv", dist=dist.Normal(0, 0.2) + ), ), transforms=t.ExpTransform(), ) diff --git a/model/src/test/test_random_walk.py b/model/src/test/test_random_walk.py index 66be96db..1ca04bc5 100755 --- a/model/src/test/test_random_walk.py +++ b/model/src/test/test_random_walk.py @@ -14,10 +14,16 @@ def test_rw_can_be_sampled(): Check that a simple random walk can be initialized and sampled from """ - init_rv_rand = DistributionalRV(dist.Normal(1, 0.5), "init_rv_rand") - init_rv_fixed = DeterministicVariable(50.0, "init_rv_fixed") + init_rv_rand = DistributionalRV( + name="init_rv_rand", + dist=dist.Normal(1, 0.5), + ) + init_rv_fixed = DeterministicVariable(name="init_rv_fixed", vars=50.0) - step_rv = DistributionalRV(dist.Normal(0, 1), "rw_step") + step_rv = DistributionalRV( + name="rw_step", + dist=dist.Normal(0, 1), + ) rw_init_rand = SimpleRandomWalkProcess( "rw_rand_init", step_rv=step_rv, init_rv=init_rv_rand @@ -56,10 +62,12 @@ def test_rw_samples_correctly_distributed(): rw_normal = SimpleRandomWalkProcess( name="rw_normal_test", step_rv=DistributionalRV( - dist=dist.Normal(loc=step_mean, scale=step_sd), name="rw_normal_dist", + dist=dist.Normal(loc=step_mean, scale=step_sd), + ), + init_rv=DeterministicVariable( + name="init_rv_fixed", vars=rw_init_val ), - init_rv=DeterministicVariable(rw_init_val, "init_rv_fixed"), ) with numpyro.handlers.seed(rng_seed=62): diff --git a/model/src/test/test_rtperiodicdiff.py b/model/src/test/test_rtperiodicdiff.py index 3d4fb11e..4184a763 100644 --- a/model/src/test/test_rtperiodicdiff.py +++ b/model/src/test/test_rtperiodicdiff.py @@ -53,11 +53,13 @@ def test_rtweeklydiff() -> None: "name": "test", "offset": 0, "log_rt_prior": DeterministicVariable( - jnp.array([0.1, 0.2]), name="log_rt_prior" + name="log_rt_prior", vars=jnp.array([0.1, 0.2]) + ), + "autoreg": DeterministicVariable( + name="autoreg", vars=jnp.array([0.7]) ), - "autoreg": DeterministicVariable(jnp.array([0.7]), name="autoreg"), "periodic_diff_sd": DeterministicVariable( - jnp.array([0.1]), name="periodic_diff_sd" + name="periodic_diff_sd", vars=jnp.array([0.1]) ), } duration = 30 @@ -98,12 +100,15 @@ def test_rtweeklydiff_no_autoregressive() -> None: "name": "test", "offset": 0, "log_rt_prior": DeterministicVariable( - jnp.array([0.0, 0.0]), name="log_rt_prior" + name="log_rt_prior", vars=jnp.array([0.0, 0.0]) ), # No autoregression! - "autoreg": DeterministicVariable(jnp.array([0.0]), name="autoreg"), + "autoreg": DeterministicVariable( + name="autoreg", vars=jnp.array([0.0]) + ), "periodic_diff_sd": DeterministicVariable( - jnp.array([0.1]), name="periodic_diff_sd" + name="periodic_diff_sd", + vars=jnp.array([0.1]), ), } @@ -135,11 +140,15 @@ def test_rtweeklydiff_manual_reconstruction() -> None: "name": "test", "offset": 0, "log_rt_prior": DeterministicVariable( - jnp.array([0.1, 0.2]), name="log_rt_prior" + name="log_rt_prior", + vars=jnp.array([0.1, 0.2]), + ), + "autoreg": DeterministicVariable( + name="autoreg", vars=jnp.array([0.7]) ), - "autoreg": DeterministicVariable(jnp.array([0.7]), name="autoreg"), "periodic_diff_sd": DeterministicVariable( - jnp.array([0.1]), name="periodic_diff_sd" + name="periodic_diff_sd", + vars=jnp.array([0.1]), ), } @@ -170,11 +179,15 @@ def test_rtperiodicdiff_smallsample(): "name": "test", "offset": 0, "log_rt_prior": DeterministicVariable( - jnp.array([0.1, 0.2]), name="log_rt_prior" + name="log_rt_prior", + vars=jnp.array([0.1, 0.2]), + ), + "autoreg": DeterministicVariable( + name="autoreg", vars=jnp.array([0.7]) ), - "autoreg": DeterministicVariable(jnp.array([0.7]), name="autoreg"), "periodic_diff_sd": DeterministicVariable( - jnp.array([0.1]), name="periodic_diff_sd" + name="periodic_diff_sd", + vars=jnp.array([0.1]), ), } diff --git a/model/src/test/test_transformed_rv_class.py b/model/src/test/test_transformed_rv_class.py index cf52b487..134a9180 100644 --- a/model/src/test/test_transformed_rv_class.py +++ b/model/src/test/test_transformed_rv_class.py @@ -63,7 +63,7 @@ def test_transform_rv_validation(): works as expected. """ - base_rv = DistributionalRV(dist.Normal(0, 1), "test_normal") + base_rv = DistributionalRV(name="test_normal", dist=dist.Normal(0, 1)) base_rv.sample_length = lambda: 1 # numpydoc ignore=GL08 l2_rv = LengthTwoRV() @@ -105,7 +105,7 @@ def test_transforms_applied_at_sampling(): instances correctly apply their specified transformations at sampling """ - norm_rv = DistributionalRV(dist.Normal(0, 1), "test_normal") + norm_rv = DistributionalRV(name="test_normal", dist=dist.Normal(0, 1)) norm_rv.sample_length = lambda: 1 l2_rv = LengthTwoRV()