Skip to content

Commit

Permalink
Rename Deterministic vars Attribute To value (#331)
Browse files Browse the repository at this point in the history
fix vars to value
  • Loading branch information
AFg6K7h4fhy2 authored Jul 25, 2024
1 parent b11f270 commit 7fd138d
Show file tree
Hide file tree
Showing 21 changed files with 91 additions and 91 deletions.
2 changes: 1 addition & 1 deletion docs/source/tutorials/basic_renewal_model.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ To initialize these five components within the renewal modeling framework, we es
# | label: creating-elements
# (1) The generation interval (deterministic)
pmf_array = jnp.array([0.4, 0.3, 0.2, 0.1])
gen_int = DeterministicPMF(name="gen_int", vars=pmf_array)
gen_int = DeterministicPMF(name="gen_int", value=pmf_array)
# (2) Initial infections (inferred with a prior)
I0 = InfectionInitializationProcess(
Expand Down
6 changes: 3 additions & 3 deletions docs/source/tutorials/extending_pyrenew.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,15 @@ The following code-chunk defines the model components. Notice that for both the
```{python}
# | label: model-components
gen_int_array = jnp.array([0.25, 0.5, 0.15, 0.1])
gen_int = DeterministicPMF(name="gen_int", vars=gen_int_array)
feedback_strength = DeterministicVariable(name="feedback_strength", vars=0.05)
gen_int = DeterministicPMF(name="gen_int", value=gen_int_array)
feedback_strength = DeterministicVariable(name="feedback_strength", value=0.05)
I0 = InfectionInitializationProcess(
"I0_initialization",
DistributionalRV(name="I0", dist=dist.LogNormal(0, 1)),
InitializeInfectionsExponentialGrowth(
gen_int_array.size,
DeterministicVariable(name="rate", vars=0.5),
DeterministicVariable(name="rate", value=0.5),
),
t_unit=1,
)
Expand Down
6 changes: 3 additions & 3 deletions docs/source/tutorials/hospital_admissions_model.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ import jax.numpy as jnp
import numpyro.distributions as dist
inf_hosp_int = deterministic.DeterministicPMF(
name="inf_hosp_int", vars=inf_hosp_int
name="inf_hosp_int", value=inf_hosp_int
)
hosp_rate = metaclass.DistributionalRV(
Expand Down Expand Up @@ -175,13 +175,13 @@ I0 = InfectionInitializationProcess(
),
InitializeInfectionsExponentialGrowth(
gen_int_array.size,
deterministic.DeterministicVariable(name="rate", vars=0.05),
deterministic.DeterministicVariable(name="rate", value=0.05),
),
t_unit=1,
)
# Generation interval and Rt
gen_int = deterministic.DeterministicPMF(name="gen_int", vars=gen_int)
gen_int = deterministic.DeterministicPMF(name="gen_int", value=gen_int)
class MyRt(metaclass.RandomVariable):
Expand Down
6 changes: 3 additions & 3 deletions docs/source/tutorials/periodic_effects.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,13 @@ rt_proc = process.RtWeeklyDiffProcess(
name="rt_weekly_diff",
offset=0,
log_rt_prior=deterministic.DeterministicVariable(
name="log_rt_prior", vars=jnp.array([0.1, 0.2])
name="log_rt_prior", value=jnp.array([0.1, 0.2])
),
autoreg=deterministic.DeterministicVariable(
name="autoreg", vars=jnp.array([0.7])
name="autoreg", value=jnp.array([0.7])
),
periodic_diff_sd=deterministic.DeterministicVariable(
name="periodic_diff_sd", vars=jnp.array([0.1])
name="periodic_diff_sd", value=jnp.array([0.1])
),
)
```
Expand Down
22 changes: 11 additions & 11 deletions model/src/pyrenew/deterministic/deterministic.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,15 @@ class DeterministicVariable(RandomVariable):
def __init__(
self,
name: str,
vars: ArrayLike,
value: ArrayLike,
) -> None:
"""Default constructor
Parameters
----------
name : str
A name to assign to the process.
vars : ArrayLike
value : ArrayLike
An ArrayLike object.
Returns
Expand All @@ -35,19 +35,19 @@ def __init__(
"""

self.name = name
self.vars = jnp.atleast_1d(vars)
self.validate(vars)
self.value = jnp.atleast_1d(value)
self.validate(value)

return None

@staticmethod
def validate(vars: ArrayLike) -> None:
def validate(value: ArrayLike) -> None:
"""
Validates input to DeterministicPMF
Parameters
----------
vars : ArrayLike
value : ArrayLike
An ArrayLike object.
Returns
Expand All @@ -57,10 +57,10 @@ def validate(vars: ArrayLike) -> None:
Raises
------
Exception
If the input vars object is not a ArrayLike.
If the input value object is not a ArrayLike.
"""
if not isinstance(vars, ArrayLike):
raise Exception("vars is not a ArrayLike")
if not isinstance(value, ArrayLike):
raise Exception("value is not a ArrayLike")

return None

Expand All @@ -86,5 +86,5 @@ def sample(
Containing the stored values during construction.
"""
if record:
numpyro.deterministic(self.name, self.vars)
return (self.vars,)
numpyro.deterministic(self.name, self.value)
return (self.value,)
18 changes: 9 additions & 9 deletions model/src/pyrenew/deterministic/deterministicpmf.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@ class DeterministicPMF(RandomVariable):
def __init__(
self,
name: str,
vars: ArrayLike,
value: ArrayLike,
tol: float = 1e-5,
) -> None:
"""
Default constructor
Automatically checks that the elements in `vars` can be indeed
Automatically checks that the elements in `value` can be indeed
considered to be a PMF by calling
pyrenew.distutil.validate_discrete_dist_vector on each one of its
entries.
Expand All @@ -31,7 +31,7 @@ def __init__(
----------
name : str
A name to assign to the variable.
vars : tuple
value : tuple
An ArrayLike object.
tol : float, optional
Passed to pyrenew.distutil.validate_discrete_dist_vector. Defaults
Expand All @@ -41,23 +41,23 @@ def __init__(
-------
None
"""
vars = validate_discrete_dist_vector(
discrete_dist=vars,
value = validate_discrete_dist_vector(
discrete_dist=value,
tol=tol,
)

self.basevar = DeterministicVariable(name=name, vars=vars)
self.basevar = DeterministicVariable(name=name, value=value)

return None

@staticmethod
def validate(vars: ArrayLike) -> None:
def validate(value: ArrayLike) -> None:
"""
Validates input to DeterministicPMF
Parameters
----------
vars : ArrayLike
value : ArrayLike
An ArrayLike object.
Returns
Expand Down Expand Up @@ -97,4 +97,4 @@ def size(self) -> int:
The size of the PMF
"""

return self.basevar.vars.size
return self.basevar.value.size
4 changes: 2 additions & 2 deletions model/src/pyrenew/latent/hospitaladmissions.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,11 +91,11 @@ def __init__(

if day_of_week_effect_rv is None:
day_of_week_effect_rv = DeterministicVariable(
name="weekday_effect", vars=1
name="weekday_effect", value=1
)
if hosp_report_prob_rv is None:
hosp_report_prob_rv = DeterministicVariable(
name="hosp_report_prob", vars=1
name="hosp_report_prob", value=1
)

HospitalAdmissions.validate(
Expand Down
6 changes: 3 additions & 3 deletions model/src/test/test_deterministic.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,16 @@ def test_deterministic():

var1 = DeterministicVariable(
name="var1",
vars=jnp.array(
value=jnp.array(
[
1,
]
),
)
var2 = DeterministicPMF(
name="var2", vars=jnp.array([0.25, 0.25, 0.2, 0.3])
name="var2", value=jnp.array([0.25, 0.25, 0.2, 0.3])
)
var3 = DeterministicProcess(name="var3", vars=jnp.array([1, 2, 3, 4]))
var3 = DeterministicProcess(name="var3", value=jnp.array([1, 2, 3, 4]))
var4 = NullVariable()
var5 = NullProcess()

Expand Down
2 changes: 1 addition & 1 deletion model/src/test/test_forecast.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
def test_forecast():
"""Check that forecasts are the right length and match the posterior up until forecast begins."""
pmf_array = jnp.array([0.25, 0.25, 0.25, 0.25])
gen_int = DeterministicPMF(name="gen_int", vars=pmf_array)
gen_int = DeterministicPMF(name="gen_int", value=pmf_array)
I0 = InfectionInitializationProcess(
"I0_initialization",
DistributionalRV(name="I0", dist=dist.LogNormal(0, 1)),
Expand Down
12 changes: 6 additions & 6 deletions model/src/test/test_infection_seeding_method.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
def test_initialize_infections_exponential():
"""Check that the InitializeInfectionsExponentialGrowth class generates the correct number of infections at each time point."""
n_timepoints = 10
rate_RV = DeterministicVariable(name="rate_RV", vars=0.5)
I_pre_init_RV = DeterministicVariable(name="I_pre_init_RV", vars=10.0)
rate_RV = DeterministicVariable(name="rate_RV", value=0.5)
I_pre_init_RV = DeterministicVariable(name="I_pre_init_RV", value=10.0)
default_t_pre_init = n_timepoints - 1

(I_pre_init,) = I_pre_init_RV()
Expand All @@ -36,7 +36,7 @@ def test_initialize_infections_exponential():

# test for failure with non-scalar rate or I_pre_init
rate_RV_2 = DeterministicVariable(
name="rate_RV", vars=np.array([0.5, 0.5])
name="rate_RV", value=np.array([0.5, 0.5])
)
with pytest.raises(ValueError):
InitializeInfectionsExponentialGrowth(
Expand All @@ -45,7 +45,7 @@ def test_initialize_infections_exponential():

I_pre_init_RV_2 = DeterministicVariable(
name="I_pre_init_RV",
vars=np.array([10.0, 10.0]),
value=np.array([10.0, 10.0]),
)
(I_pre_init_2,) = I_pre_init_RV_2()

Expand Down Expand Up @@ -75,7 +75,7 @@ def test_initialize_infections_zero_pad():
"""Check that the InitializeInfectionsZeroPad class generates the correct number of infections at each time point."""

n_timepoints = 10
I_pre_init_RV = DeterministicVariable(name="I_pre_init_RV", vars=10.0)
I_pre_init_RV = DeterministicVariable(name="I_pre_init_RV", value=10.0)
(I_pre_init,) = I_pre_init_RV()

infections = InitializeInfectionsZeroPad(
Expand All @@ -86,7 +86,7 @@ def test_initialize_infections_zero_pad():
)

I_pre_init_RV_2 = DeterministicVariable(
name="I_pre_init_RV", vars=np.array([10.0, 10.0])
name="I_pre_init_RV", value=np.array([10.0, 10.0])
)
(I_pre_init_2,) = I_pre_init_RV_2()

Expand Down
6 changes: 3 additions & 3 deletions model/src/test/test_infection_seeding_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,14 @@ def test_infection_initialization_process():
"exp_model",
DistributionalRV(name="I0", dist=dist.LogNormal(0, 1)),
InitializeInfectionsExponentialGrowth(
n_timepoints, DeterministicVariable(name="rate", vars=0.5)
n_timepoints, DeterministicVariable(name="rate", value=0.5)
),
t_unit=1,
)

vec_model = InfectionInitializationProcess(
"vec_model",
DeterministicVariable(name="I0", vars=jnp.arange(n_timepoints)),
DeterministicVariable(name="I0", value=jnp.arange(n_timepoints)),
InitializeInfectionsFromVec(n_timepoints),
t_unit=1,
)
Expand All @@ -56,7 +56,7 @@ def test_infection_initialization_process():
with pytest.raises(TypeError):
InfectionInitializationProcess(
"vec_model",
DeterministicVariable(name="I0", vars=jnp.arange(n_timepoints)),
DeterministicVariable(name="I0", value=jnp.arange(n_timepoints)),
3,
t_unit=1,
)
8 changes: 4 additions & 4 deletions model/src/test/test_infectionsrtfeedback.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,9 @@ def test_infectionsrtfeedback():
# By doing the infection feedback strength 0, Rt = Rt_adjusted
# So infection should be equal in both
inf_feed_strength = DeterministicVariable(
name="inf_feed_strength", vars=jnp.zeros_like(Rt)
name="inf_feed_strength", value=jnp.zeros_like(Rt)
)
inf_feedback_pmf = DeterministicPMF(name="inf_feedback_pmf", vars=gen_int)
inf_feedback_pmf = DeterministicPMF(name="inf_feedback_pmf", value=gen_int)

# Test the InfectionsWithFeedback class
InfectionsWithFeedback = latent.InfectionsWithFeedback(
Expand Down Expand Up @@ -113,9 +113,9 @@ def test_infectionsrtfeedback_feedback():
gen_int = jnp.array([0.4, 0.25, 0.25, 0.1, 0.0, 0.0, 0.0])

inf_feed_strength = DeterministicVariable(
name="inf_feed_strength", vars=jnp.repeat(0.5, len(Rt))
name="inf_feed_strength", value=jnp.repeat(0.5, len(Rt))
)
inf_feedback_pmf = DeterministicPMF(name="inf_feedback_pmf", vars=gen_int)
inf_feedback_pmf = DeterministicPMF(name="inf_feedback_pmf", value=gen_int)

# Test the InfectionsWithFeedback class
InfectionsWithFeedback = latent.InfectionsWithFeedback(
Expand Down
2 changes: 1 addition & 1 deletion model/src/test/test_latent_admissions.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def test_admissions_sample():
# Testing the hospital admissions
inf_hosp = DeterministicPMF(
name="inf_hosp",
vars=jnp.array(
value=jnp.array(
[
0,
0,
Expand Down
10 changes: 5 additions & 5 deletions model/src/test/test_model_basic_renewal.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def test_model_basicrenewal_no_timepoints_or_observations():
"""

gen_int = DeterministicPMF(
name="gen_int", vars=jnp.array([0.25, 0.25, 0.25, 0.25])
name="gen_int", value=jnp.array([0.25, 0.25, 0.25, 0.25])
)

I0 = DistributionalRV(name="I0", dist=dist.LogNormal(0, 1))
Expand Down Expand Up @@ -87,7 +87,7 @@ def test_model_basicrenewal_both_timepoints_and_observations():

gen_int = DeterministicPMF(
name="gen_int",
vars=jnp.array([0.25, 0.25, 0.25, 0.25]),
value=jnp.array([0.25, 0.25, 0.25, 0.25]),
)

I0 = DistributionalRV(name="I0", dist=dist.LogNormal(0, 1))
Expand Down Expand Up @@ -122,7 +122,7 @@ def test_model_basicrenewal_no_obs_model():

gen_int = DeterministicPMF(
name="gen_int",
vars=jnp.array([0.25, 0.25, 0.25, 0.25]),
value=jnp.array([0.25, 0.25, 0.25, 0.25]),
)

with pytest.raises(ValueError):
Expand Down Expand Up @@ -195,7 +195,7 @@ def test_model_basicrenewal_with_obs_model():
"""

gen_int = DeterministicPMF(
name="gen_int", vars=jnp.array([0.25, 0.25, 0.25, 0.25])
name="gen_int", value=jnp.array([0.25, 0.25, 0.25, 0.25])
)

I0 = InfectionInitializationProcess(
Expand Down Expand Up @@ -244,7 +244,7 @@ def test_model_basicrenewal_with_obs_model():

def test_model_basicrenewal_padding() -> None: # numpydoc ignore=GL08
gen_int = DeterministicPMF(
name="gen_int", vars=jnp.array([0.25, 0.25, 0.25, 0.25])
name="gen_int", value=jnp.array([0.25, 0.25, 0.25, 0.25])
)

I0 = InfectionInitializationProcess(
Expand Down
Loading

0 comments on commit 7fd138d

Please sign in to comment.