Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change Order Of The name Argument To Match Numpyro #324

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions docs/source/tutorials/basic_renewal_model.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -118,12 +118,12 @@ To initialize these five components within the renewal modeling framework, we es
# | label: creating-elements
# (1) The generation interval (deterministic)
pmf_array = jnp.array([0.4, 0.3, 0.2, 0.1])
gen_int = DeterministicPMF(pmf_array, name="gen_int")
gen_int = DeterministicPMF(name="gen_int", vars=pmf_array)

# (2) Initial infections (inferred with a prior)
I0 = InfectionInitializationProcess(
"I0_initialization",
DistributionalRV(dist=dist.LogNormal(2.5, 1), name="I0"),
DistributionalRV(name="I0", dist=dist.LogNormal(2.5, 1)),
InitializeInfectionsZeroPad(pmf_array.size),
t_unit=1,
)
Expand All @@ -144,12 +144,13 @@ class MyRt(RandomVariable):
base_rv=SimpleRandomWalkProcess(
name="log_rt",
step_rv=DistributionalRV(
dist.Normal(0, sd_rt),
"rw_step_rv",
name="rw_step_rv",
dist=dist.Normal(0, sd_rt),
reparam=LocScaleReparam(0),
),
init_rv=DistributionalRV(
dist.Normal(jnp.log(1), jnp.log(1.2)), "init_log_Rt_rv"
name="init_log_Rt_rv",
dist=dist.Normal(jnp.log(1), jnp.log(1.2)),
),
),
transforms=t.ExpTransform(),
Expand Down
16 changes: 10 additions & 6 deletions docs/source/tutorials/extending_pyrenew.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,15 @@ The following code-chunk defines the model components. Notice that for both the
```{python}
# | label: model-components
gen_int_array = jnp.array([0.25, 0.5, 0.15, 0.1])
gen_int = DeterministicPMF(gen_int_array, name="gen_int")
feedback_strength = DeterministicVariable(0.05, name="feedback_strength")
gen_int = DeterministicPMF(name="gen_int", vars=gen_int_array)
feedback_strength = DeterministicVariable(name="feedback_strength", vars=0.05)

I0 = InfectionInitializationProcess(
"I0_initialization",
DistributionalRV(dist=dist.LogNormal(0, 1), name="I0"),
DistributionalRV(name="I0", dist=dist.LogNormal(0, 1)),
InitializeInfectionsExponentialGrowth(
gen_int_array.size,
DeterministicVariable(0.5, name="rate"),
DeterministicVariable(name="rate", vars=0.5),
),
t_unit=1,
)
Expand All @@ -64,8 +64,12 @@ rt = TransformedRandomVariable(
"Rt_rv",
base_rv=SimpleRandomWalkProcess(
name="log_rt",
step_rv=DistributionalRV(dist.Normal(0, 0.025), "rw_step_rv"),
init_rv=DistributionalRV(dist.Normal(0, 0.2), "init_log_Rt_rv"),
step_rv=DistributionalRV(
name="rw_step_rv", dist=dist.Normal(0, 0.025)
),
init_rv=DistributionalRV(
name="init_log_Rt_rv", dist=dist.Normal(0, 0.2)
),
),
transforms=t.ExpTransform(),
)
Expand Down
22 changes: 12 additions & 10 deletions docs/source/tutorials/hospital_admissions_model.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -142,12 +142,11 @@ import jax.numpy as jnp
import numpyro.distributions as dist

inf_hosp_int = deterministic.DeterministicPMF(
inf_hosp_int, name="inf_hosp_int"
name="inf_hosp_int", vars=inf_hosp_int
)

hosp_rate = metaclass.DistributionalRV(
dist=dist.LogNormal(jnp.log(0.05), jnp.log(1.1)),
name="IHR",
name="IHR", dist=dist.LogNormal(jnp.log(0.05), jnp.log(1.1))
)

latent_hosp = latent.HospitalAdmissions(
Expand All @@ -172,17 +171,17 @@ latent_inf = latent.Infections()
I0 = InfectionInitializationProcess(
"I0_initialization",
metaclass.DistributionalRV(
dist=dist.LogNormal(loc=jnp.log(100), scale=jnp.log(1.75)), name="I0"
name="I0", dist=dist.LogNormal(loc=jnp.log(100), scale=jnp.log(1.75))
),
InitializeInfectionsExponentialGrowth(
gen_int_array.size,
deterministic.DeterministicVariable(0.05, name="rate"),
deterministic.DeterministicVariable(name="rate", vars=0.05),
),
t_unit=1,
)

# Generation interval and Rt
gen_int = deterministic.DeterministicPMF(gen_int, name="gen_int")
gen_int = deterministic.DeterministicPMF(name="gen_int", vars=gen_int)


class MyRt(metaclass.RandomVariable):
Expand All @@ -200,10 +199,10 @@ class MyRt(metaclass.RandomVariable):
base_rv=process.SimpleRandomWalkProcess(
name="log_rt",
step_rv=metaclass.DistributionalRV(
dist.Normal(0, sd_rt), "rw_step_rv"
name="rw_step_rv", dist=dist.Normal(0, sd_rt)
),
init_rv=metaclass.DistributionalRV(
dist.Normal(0, 0.2), "init_log_Rt_rv"
name="init_log_Rt_rv", dist=dist.Normal(0, 0.2)
),
),
transforms=transformation.ExpTransform(),
Expand All @@ -213,7 +212,9 @@ class MyRt(metaclass.RandomVariable):


rtproc = MyRt(
metaclass.DistributionalRV(dist.HalfNormal(0.025), "Rt_random_walk_sd")
metaclass.DistributionalRV(
name="Rt_random_walk_sd", dist=dist.HalfNormal(0.025)
)
)

# The observation model
Expand All @@ -223,7 +224,8 @@ rtproc = MyRt(
nb_conc_rv = metaclass.TransformedRandomVariable(
"concentration",
metaclass.DistributionalRV(
dist.TruncatedNormal(loc=0, scale=1, low=0.01), "concentration_raw"
name="concentration_raw",
dist=dist.TruncatedNormal(loc=0, scale=1, low=0.01),
),
transformation.PowerTransform(-2),
)
Expand Down
10 changes: 6 additions & 4 deletions docs/source/tutorials/periodic_effects.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,13 @@ rt_proc = process.RtWeeklyDiffProcess(
name="rt_weekly_diff",
offset=0,
log_rt_prior=deterministic.DeterministicVariable(
jnp.array([0.1, 0.2]), name="log_rt_prior"
name="log_rt_prior", vars=jnp.array([0.1, 0.2])
),
autoreg=deterministic.DeterministicVariable(
jnp.array([0.7]), name="autoreg"
name="autoreg", vars=jnp.array([0.7])
),
periodic_diff_sd=deterministic.DeterministicVariable(
jnp.array([0.1]), name="periodic_diff_sd"
name="periodic_diff_sd", vars=jnp.array([0.1])
),
)
```
Expand Down Expand Up @@ -76,7 +76,9 @@ mysimplex = dist.TransformedDistribution(
# Constructing the day of week effect
dayofweek = process.DayOfWeekEffect(
offset=0,
quantity_to_broadcast=metaclass.DistributionalRV(mysimplex, "simp"),
quantity_to_broadcast=metaclass.DistributionalRV(
name="simp", dist=mysimplex
),
t_start=0,
)
```
Expand Down
10 changes: 5 additions & 5 deletions model/src/pyrenew/deterministic/deterministic.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,26 +17,26 @@ class DeterministicVariable(RandomVariable):

def __init__(
self,
vars: ArrayLike,
name: str,
vars: ArrayLike,
) -> None:
"""Default constructor

Parameters
----------
vars : ArrayLike
A tuple with arraylike objects.
name : str
A name to assign to the process.
vars : ArrayLike
An ArrayLike object.

Returns
-------
None
"""

self.validate(vars)
self.vars = jnp.atleast_1d(vars)
self.name = name
self.vars = jnp.atleast_1d(vars)
self.validate(vars)

return None

Expand Down
10 changes: 5 additions & 5 deletions model/src/pyrenew/deterministic/deterministicpmf.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ class DeterministicPMF(RandomVariable):

def __init__(
self,
vars: ArrayLike,
name: str,
vars: ArrayLike,
tol: float = 1e-5,
) -> None:
"""
Expand All @@ -29,10 +29,10 @@ def __init__(

Parameters
----------
vars : tuple
A tuple with arraylike objects.
name : str
A name to assign to the process.
A name to assign to the variable.
vars : tuple
An ArrayLike object.
tol : float, optional
Passed to pyrenew.distutil.validate_discrete_dist_vector. Defaults
to 1e-5.
Expand All @@ -46,7 +46,7 @@ def __init__(
tol=tol,
)

self.basevar = DeterministicVariable(vars, name)
self.basevar = DeterministicVariable(name=name, vars=vars)

return None

Expand Down
8 changes: 6 additions & 2 deletions model/src/pyrenew/latent/hospitaladmissions.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,13 @@ def __init__(
"""

if day_of_week_effect_rv is None:
day_of_week_effect_rv = DeterministicVariable(1, "weekday_effect")
day_of_week_effect_rv = DeterministicVariable(
name="weekday_effect", vars=1
)
if hosp_report_prob_rv is None:
hosp_report_prob_rv = DeterministicVariable(1, "hosp_report_prob")
hosp_report_prob_rv = DeterministicVariable(
name="hosp_report_prob", vars=1
)

HospitalAdmissions.validate(
infect_hosp_rate_rv,
Expand Down
20 changes: 7 additions & 13 deletions model/src/pyrenew/metaclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,6 @@ def set_timeseries(
model time. It could be negative, indicating
that the `sample()` method returns timepoints
that occur prior to the model t = 0.

t_unit : int
The unit of the time series relative
to the model's fundamental (smallest)
Expand Down Expand Up @@ -219,20 +218,19 @@ class DistributionalRV(RandomVariable):

def __init__(
self,
dist: numpyro.distributions.Distribution,
name: str,
dist: numpyro.distributions.Distribution,
reparam: Reparam = None,
) -> None:
"""
Default constructor for DistributionalRV.

Parameters
----------
dist : numpyro.distributions.Distribution
Distribution of the random variable.
name : str
Name of the random variable.

dist : numpyro.distributions.Distribution
Distribution of the random variable.
reparam : numpyro.infer.reparam.Reparam
If not None, reparameterize sampling
from the distribution according to the
Expand All @@ -243,10 +241,9 @@ def __init__(
None
"""

self.name = name
self.validate(dist)

self.dist = dist
self.name = name
if reparam is not None:
self.reparam_dict = {self.name: reparam}
else:
Expand Down Expand Up @@ -603,19 +600,16 @@ def __init__(

Parameters
----------

name : str
A name for the random variable instance

A name for the random variable instance.
base_rv : RandomVariable
The underlying (untransformed) RandomVariable

The underlying (untransformed) RandomVariable.
transforms : Transform
Transformation or tuple of transformations
to apply to the output of
`base_rv.sample()`; single values will be coerced to
a length-one tuple. If a tuple, should be the same
length as the tuple returned by `base_rv.sample()`
length as the tuple returned by `base_rv.sample()`.

Returns
-------
Expand Down
2 changes: 1 addition & 1 deletion model/src/pyrenew/process/firstdifferencear.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,9 @@ def sample(
With a single array of shape (duration,).
"""
rates_of_change, *_ = self.rate_of_change_proc.sample(
name=self.name + "_rate_of_change",
duration=duration,
inits=jnp.atleast_1d(init_rate_of_change),
name=self.name + "_rate_of_change",
)
return (init_val + jnp.cumsum(rates_of_change.flatten()),)

Expand Down
10 changes: 6 additions & 4 deletions model/src/test/test_deterministic.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,17 @@ def test_deterministic():
"""

var1 = DeterministicVariable(
jnp.array(
name="var1",
vars=jnp.array(
[
1,
]
),
name="var1",
)
var2 = DeterministicPMF(jnp.array([0.25, 0.25, 0.2, 0.3]), name="var2")
var3 = DeterministicProcess(jnp.array([1, 2, 3, 4]), name="var3")
var2 = DeterministicPMF(
name="var2", vars=jnp.array([0.25, 0.25, 0.2, 0.3])
)
var3 = DeterministicProcess(name="var3", vars=jnp.array([1, 2, 3, 4]))
var4 = NullVariable()
var5 = NullProcess()

Expand Down
16 changes: 10 additions & 6 deletions model/src/test/test_forecast.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,21 +21,25 @@
def test_forecast():
"""Check that forecasts are the right length and match the posterior up until forecast begins."""
pmf_array = jnp.array([0.25, 0.25, 0.25, 0.25])
gen_int = DeterministicPMF(pmf_array, name="gen_int")
gen_int = DeterministicPMF(name="gen_int", vars=pmf_array)
I0 = InfectionInitializationProcess(
"I0_initialization",
DistributionalRV(dist=dist.LogNormal(0, 1), name="I0"),
DistributionalRV(name="I0", dist=dist.LogNormal(0, 1)),
InitializeInfectionsZeroPad(n_timepoints=gen_int.size()),
t_unit=1,
)
latent_infections = Infections()
observed_infections = PoissonObservation("poisson_rv")
observed_infections = PoissonObservation(name="poisson_rv")
rt = TransformedRandomVariable(
"Rt_rv",
name="Rt_rv",
base_rv=SimpleRandomWalkProcess(
name="log_rt",
step_rv=DistributionalRV(dist.Normal(0, 0.025), "rw_step_rv"),
init_rv=DistributionalRV(dist.Normal(0, 0.2), "init_log_Rt_rv"),
step_rv=DistributionalRV(
name="rw_step_rv", dist=dist.Normal(0, 0.025)
),
init_rv=DistributionalRV(
name="init_log_Rt_rv", dist=dist.Normal(0, 0.2)
),
),
transforms=t.ExpTransform(),
)
Expand Down
Loading