Skip to content

Commit

Permalink
Change Order Of The name Argument To Match Numpyro (#324)
Browse files Browse the repository at this point in the history
* fix name order

* fix test and remaining renamings

* Update model/src/pyrenew/deterministic/deterministic.py

Co-authored-by: Dylan H. Morris <dylanhmorris@users.noreply.github.com>

* Update model/src/pyrenew/deterministic/deterministicpmf.py

Co-authored-by: Dylan H. Morris <dylanhmorris@users.noreply.github.com>

* Update model/src/pyrenew/deterministic/deterministicpmf.py

Co-authored-by: Dylan H. Morris <dylanhmorris@users.noreply.github.com>

* Update model/src/pyrenew/deterministic/deterministicpmf.py

Co-authored-by: Dylan H. Morris <dylanhmorris@users.noreply.github.com>

---------

Co-authored-by: Dylan H. Morris <dylanhmorris@users.noreply.github.com>
  • Loading branch information
AFg6K7h4fhy2 and dylanhmorris authored Jul 25, 2024
1 parent e703749 commit 3d77967
Show file tree
Hide file tree
Showing 25 changed files with 233 additions and 162 deletions.
11 changes: 6 additions & 5 deletions docs/source/tutorials/basic_renewal_model.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -121,12 +121,12 @@ To initialize these five components within the renewal modeling framework, we es
# | label: creating-elements
# (1) The generation interval (deterministic)
pmf_array = jnp.array([0.4, 0.3, 0.2, 0.1])
gen_int = DeterministicPMF(pmf_array, name="gen_int")
gen_int = DeterministicPMF(name="gen_int", vars=pmf_array)
# (2) Initial infections (inferred with a prior)
I0 = InfectionInitializationProcess(
"I0_initialization",
DistributionalRV(dist=dist.LogNormal(2.5, 1), name="I0"),
DistributionalRV(name="I0", dist=dist.LogNormal(2.5, 1)),
InitializeInfectionsZeroPad(pmf_array.size),
t_unit=1,
)
Expand All @@ -147,12 +147,13 @@ class MyRt(RandomVariable):
base_rv=SimpleRandomWalkProcess(
name="log_rt",
step_rv=DistributionalRV(
dist.Normal(0, sd_rt),
"rw_step_rv",
name="rw_step_rv",
dist=dist.Normal(0, sd_rt),
reparam=LocScaleReparam(0),
),
init_rv=DistributionalRV(
dist.Normal(jnp.log(1), jnp.log(1.2)), "init_log_Rt_rv"
name="init_log_Rt_rv",
dist=dist.Normal(jnp.log(1), jnp.log(1.2)),
),
),
transforms=t.ExpTransform(),
Expand Down
16 changes: 10 additions & 6 deletions docs/source/tutorials/extending_pyrenew.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,15 @@ The following code-chunk defines the model components. Notice that for both the
```{python}
# | label: model-components
gen_int_array = jnp.array([0.25, 0.5, 0.15, 0.1])
gen_int = DeterministicPMF(gen_int_array, name="gen_int")
feedback_strength = DeterministicVariable(0.05, name="feedback_strength")
gen_int = DeterministicPMF(name="gen_int", vars=gen_int_array)
feedback_strength = DeterministicVariable(name="feedback_strength", vars=0.05)
I0 = InfectionInitializationProcess(
"I0_initialization",
DistributionalRV(dist=dist.LogNormal(0, 1), name="I0"),
DistributionalRV(name="I0", dist=dist.LogNormal(0, 1)),
InitializeInfectionsExponentialGrowth(
gen_int_array.size,
DeterministicVariable(0.5, name="rate"),
DeterministicVariable(name="rate", vars=0.5),
),
t_unit=1,
)
Expand All @@ -64,8 +64,12 @@ rt = TransformedRandomVariable(
"Rt_rv",
base_rv=SimpleRandomWalkProcess(
name="log_rt",
step_rv=DistributionalRV(dist.Normal(0, 0.025), "rw_step_rv"),
init_rv=DistributionalRV(dist.Normal(0, 0.2), "init_log_Rt_rv"),
step_rv=DistributionalRV(
name="rw_step_rv", dist=dist.Normal(0, 0.025)
),
init_rv=DistributionalRV(
name="init_log_Rt_rv", dist=dist.Normal(0, 0.2)
),
),
transforms=t.ExpTransform(),
)
Expand Down
22 changes: 12 additions & 10 deletions docs/source/tutorials/hospital_admissions_model.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -142,12 +142,11 @@ import jax.numpy as jnp
import numpyro.distributions as dist
inf_hosp_int = deterministic.DeterministicPMF(
inf_hosp_int, name="inf_hosp_int"
name="inf_hosp_int", vars=inf_hosp_int
)
hosp_rate = metaclass.DistributionalRV(
dist=dist.LogNormal(jnp.log(0.05), jnp.log(1.1)),
name="IHR",
name="IHR", dist=dist.LogNormal(jnp.log(0.05), jnp.log(1.1))
)
latent_hosp = latent.HospitalAdmissions(
Expand All @@ -172,17 +171,17 @@ latent_inf = latent.Infections()
I0 = InfectionInitializationProcess(
"I0_initialization",
metaclass.DistributionalRV(
dist=dist.LogNormal(loc=jnp.log(100), scale=jnp.log(1.75)), name="I0"
name="I0", dist=dist.LogNormal(loc=jnp.log(100), scale=jnp.log(1.75))
),
InitializeInfectionsExponentialGrowth(
gen_int_array.size,
deterministic.DeterministicVariable(0.05, name="rate"),
deterministic.DeterministicVariable(name="rate", vars=0.05),
),
t_unit=1,
)
# Generation interval and Rt
gen_int = deterministic.DeterministicPMF(gen_int, name="gen_int")
gen_int = deterministic.DeterministicPMF(name="gen_int", vars=gen_int)
class MyRt(metaclass.RandomVariable):
Expand All @@ -200,10 +199,10 @@ class MyRt(metaclass.RandomVariable):
base_rv=process.SimpleRandomWalkProcess(
name="log_rt",
step_rv=metaclass.DistributionalRV(
dist.Normal(0, sd_rt), "rw_step_rv"
name="rw_step_rv", dist=dist.Normal(0, sd_rt)
),
init_rv=metaclass.DistributionalRV(
dist.Normal(0, 0.2), "init_log_Rt_rv"
name="init_log_Rt_rv", dist=dist.Normal(0, 0.2)
),
),
transforms=transformation.ExpTransform(),
Expand All @@ -213,7 +212,9 @@ class MyRt(metaclass.RandomVariable):
rtproc = MyRt(
metaclass.DistributionalRV(dist.HalfNormal(0.025), "Rt_random_walk_sd")
metaclass.DistributionalRV(
name="Rt_random_walk_sd", dist=dist.HalfNormal(0.025)
)
)
# The observation model
Expand All @@ -223,7 +224,8 @@ rtproc = MyRt(
nb_conc_rv = metaclass.TransformedRandomVariable(
"concentration",
metaclass.DistributionalRV(
dist.TruncatedNormal(loc=0, scale=1, low=0.01), "concentration_raw"
name="concentration_raw",
dist=dist.TruncatedNormal(loc=0, scale=1, low=0.01),
),
transformation.PowerTransform(-2),
)
Expand Down
10 changes: 6 additions & 4 deletions docs/source/tutorials/periodic_effects.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,13 @@ rt_proc = process.RtWeeklyDiffProcess(
name="rt_weekly_diff",
offset=0,
log_rt_prior=deterministic.DeterministicVariable(
jnp.array([0.1, 0.2]), name="log_rt_prior"
name="log_rt_prior", vars=jnp.array([0.1, 0.2])
),
autoreg=deterministic.DeterministicVariable(
jnp.array([0.7]), name="autoreg"
name="autoreg", vars=jnp.array([0.7])
),
periodic_diff_sd=deterministic.DeterministicVariable(
jnp.array([0.1]), name="periodic_diff_sd"
name="periodic_diff_sd", vars=jnp.array([0.1])
),
)
```
Expand Down Expand Up @@ -76,7 +76,9 @@ mysimplex = dist.TransformedDistribution(
# Constructing the day of week effect
dayofweek = process.DayOfWeekEffect(
offset=0,
quantity_to_broadcast=metaclass.DistributionalRV(mysimplex, "simp"),
quantity_to_broadcast=metaclass.DistributionalRV(
name="simp", dist=mysimplex
),
t_start=0,
)
```
Expand Down
10 changes: 5 additions & 5 deletions model/src/pyrenew/deterministic/deterministic.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,26 +17,26 @@ class DeterministicVariable(RandomVariable):

def __init__(
self,
vars: ArrayLike,
name: str,
vars: ArrayLike,
) -> None:
"""Default constructor
Parameters
----------
vars : ArrayLike
A tuple with arraylike objects.
name : str
A name to assign to the process.
vars : ArrayLike
An ArrayLike object.
Returns
-------
None
"""

self.validate(vars)
self.vars = jnp.atleast_1d(vars)
self.name = name
self.vars = jnp.atleast_1d(vars)
self.validate(vars)

return None

Expand Down
10 changes: 5 additions & 5 deletions model/src/pyrenew/deterministic/deterministicpmf.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ class DeterministicPMF(RandomVariable):

def __init__(
self,
vars: ArrayLike,
name: str,
vars: ArrayLike,
tol: float = 1e-5,
) -> None:
"""
Expand All @@ -29,10 +29,10 @@ def __init__(
Parameters
----------
vars : tuple
A tuple with arraylike objects.
name : str
A name to assign to the process.
A name to assign to the variable.
vars : tuple
An ArrayLike object.
tol : float, optional
Passed to pyrenew.distutil.validate_discrete_dist_vector. Defaults
to 1e-5.
Expand All @@ -46,7 +46,7 @@ def __init__(
tol=tol,
)

self.basevar = DeterministicVariable(vars, name)
self.basevar = DeterministicVariable(name=name, vars=vars)

return None

Expand Down
8 changes: 6 additions & 2 deletions model/src/pyrenew/latent/hospitaladmissions.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,13 @@ def __init__(
"""

if day_of_week_effect_rv is None:
day_of_week_effect_rv = DeterministicVariable(1, "weekday_effect")
day_of_week_effect_rv = DeterministicVariable(
name="weekday_effect", vars=1
)
if hosp_report_prob_rv is None:
hosp_report_prob_rv = DeterministicVariable(1, "hosp_report_prob")
hosp_report_prob_rv = DeterministicVariable(
name="hosp_report_prob", vars=1
)

HospitalAdmissions.validate(
infect_hosp_rate_rv,
Expand Down
20 changes: 7 additions & 13 deletions model/src/pyrenew/metaclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,6 @@ def set_timeseries(
model time. It could be negative, indicating
that the `sample()` method returns timepoints
that occur prior to the model t = 0.
t_unit : int
The unit of the time series relative
to the model's fundamental (smallest)
Expand Down Expand Up @@ -219,20 +218,19 @@ class DistributionalRV(RandomVariable):

def __init__(
self,
dist: numpyro.distributions.Distribution,
name: str,
dist: numpyro.distributions.Distribution,
reparam: Reparam = None,
) -> None:
"""
Default constructor for DistributionalRV.
Parameters
----------
dist : numpyro.distributions.Distribution
Distribution of the random variable.
name : str
Name of the random variable.
dist : numpyro.distributions.Distribution
Distribution of the random variable.
reparam : numpyro.infer.reparam.Reparam
If not None, reparameterize sampling
from the distribution according to the
Expand All @@ -243,10 +241,9 @@ def __init__(
None
"""

self.name = name
self.validate(dist)

self.dist = dist
self.name = name
if reparam is not None:
self.reparam_dict = {self.name: reparam}
else:
Expand Down Expand Up @@ -603,19 +600,16 @@ def __init__(
Parameters
----------
name : str
A name for the random variable instance
A name for the random variable instance.
base_rv : RandomVariable
The underlying (untransformed) RandomVariable
The underlying (untransformed) RandomVariable.
transforms : Transform
Transformation or tuple of transformations
to apply to the output of
`base_rv.sample()`; single values will be coerced to
a length-one tuple. If a tuple, should be the same
length as the tuple returned by `base_rv.sample()`
length as the tuple returned by `base_rv.sample()`.
Returns
-------
Expand Down
2 changes: 1 addition & 1 deletion model/src/pyrenew/process/firstdifferencear.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,9 @@ def sample(
With a single array of shape (duration,).
"""
rates_of_change, *_ = self.rate_of_change_proc.sample(
name=self.name + "_rate_of_change",
duration=duration,
inits=jnp.atleast_1d(init_rate_of_change),
name=self.name + "_rate_of_change",
)
return (init_val + jnp.cumsum(rates_of_change.flatten()),)

Expand Down
10 changes: 6 additions & 4 deletions model/src/test/test_deterministic.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,17 @@ def test_deterministic():
"""

var1 = DeterministicVariable(
jnp.array(
name="var1",
vars=jnp.array(
[
1,
]
),
name="var1",
)
var2 = DeterministicPMF(jnp.array([0.25, 0.25, 0.2, 0.3]), name="var2")
var3 = DeterministicProcess(jnp.array([1, 2, 3, 4]), name="var3")
var2 = DeterministicPMF(
name="var2", vars=jnp.array([0.25, 0.25, 0.2, 0.3])
)
var3 = DeterministicProcess(name="var3", vars=jnp.array([1, 2, 3, 4]))
var4 = NullVariable()
var5 = NullProcess()

Expand Down
16 changes: 10 additions & 6 deletions model/src/test/test_forecast.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,21 +21,25 @@
def test_forecast():
"""Check that forecasts are the right length and match the posterior up until forecast begins."""
pmf_array = jnp.array([0.25, 0.25, 0.25, 0.25])
gen_int = DeterministicPMF(pmf_array, name="gen_int")
gen_int = DeterministicPMF(name="gen_int", vars=pmf_array)
I0 = InfectionInitializationProcess(
"I0_initialization",
DistributionalRV(dist=dist.LogNormal(0, 1), name="I0"),
DistributionalRV(name="I0", dist=dist.LogNormal(0, 1)),
InitializeInfectionsZeroPad(n_timepoints=gen_int.size()),
t_unit=1,
)
latent_infections = Infections()
observed_infections = PoissonObservation("poisson_rv")
observed_infections = PoissonObservation(name="poisson_rv")
rt = TransformedRandomVariable(
"Rt_rv",
name="Rt_rv",
base_rv=SimpleRandomWalkProcess(
name="log_rt",
step_rv=DistributionalRV(dist.Normal(0, 0.025), "rw_step_rv"),
init_rv=DistributionalRV(dist.Normal(0, 0.2), "init_log_Rt_rv"),
step_rv=DistributionalRV(
name="rw_step_rv", dist=dist.Normal(0, 0.025)
),
init_rv=DistributionalRV(
name="init_log_Rt_rv", dist=dist.Normal(0, 0.2)
),
),
transforms=t.ExpTransform(),
)
Expand Down
Loading

0 comments on commit 3d77967

Please sign in to comment.