Skip to content

Commit

Permalink
Rename Model initial_values to rvs_to_initial_values
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 authored and wrongu committed Dec 1, 2022
1 parent 42d07d2 commit 291c761
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 11 deletions.
2 changes: 1 addition & 1 deletion pymc/initial_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def make_initial_point_fn(

sdict_overrides = convert_str_to_rv_dict(model, overrides or {})
initval_strats = {
**model.initial_values,
**model.rvs_to_initial_values,
**sdict_overrides,
}

Expand Down
11 changes: 7 additions & 4 deletions pymc/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -550,15 +550,14 @@ def __init__(
self.name = self._validate_name(name)
self.check_bounds = check_bounds

self._initial_values: Dict[TensorVariable, Optional[Union[np.ndarray, Variable, str]]] = {}

if self.parent is not None:
self.named_vars = treedict(parent=self.parent.named_vars)
self.named_vars_to_dims = treedict(parent=self.parent.named_vars_to_dims)
self.values_to_rvs = treedict(parent=self.parent.values_to_rvs)
self.rvs_to_values = treedict(parent=self.parent.rvs_to_values)
self.rvs_to_transforms = treedict(parent=self.parent.rvs_to_transforms)
self.rvs_to_total_sizes = treedict(parent=self.parent.rvs_to_total_sizes)
self.rvs_to_initial_values = treedict(parent=self.parent.rvs_to_initial_values)
self.free_RVs = treelist(parent=self.parent.free_RVs)
self.observed_RVs = treelist(parent=self.parent.observed_RVs)
self.auto_deterministics = treelist(parent=self.parent.auto_deterministics)
Expand All @@ -573,6 +572,7 @@ def __init__(
self.rvs_to_values = treedict()
self.rvs_to_transforms = treedict()
self.rvs_to_total_sizes = treedict()
self.rvs_to_initial_values = treedict()
self.free_RVs = treelist()
self.observed_RVs = treelist()
self.auto_deterministics = treelist()
Expand Down Expand Up @@ -1128,15 +1128,18 @@ def initial_values(self) -> Dict[TensorVariable, Optional[Union[np.ndarray, Vari
Keys are the random variables (as returned by e.g. ``pm.Uniform()``) and
values are the numeric/symbolic initial values, strings denoting the strategy to get them, or None.
"""
return self._initial_values
warnings.warn(
"Model.initial_values is deprecated. Use Model.rvs_to_initial_values instead."
)
return self.rvs_to_initial_values

def set_initval(self, rv_var, initval):
"""Sets an initial value (strategy) for a random variable."""
if initval is not None and not isinstance(initval, (Variable, str)):
# Convert scalars or array-like inputs to ndarrays
initval = rv_var.type.filter(initval)

self.initial_values[rv_var] = initval
self.rvs_to_initial_values[rv_var] = initval

def set_data(
self,
Expand Down
4 changes: 2 additions & 2 deletions pymc/tests/test_initial_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def test_dependent_initvals(self):
assert ip["B2_interval__"] == 0

# Modify initval of L and re-evaluate
pmodel.initial_values[U] = 9.9
pmodel.rvs_to_initial_values[U] = 9.9
ip = pmodel.initial_point(random_seed=0)
assert ip["B1_interval__"] < 0
assert ip["B2_interval__"] == 0
Expand All @@ -108,7 +108,7 @@ def test_nested_initvals(self):
ip_vals = list(make_initial_point_fn(model=pmodel, return_transformed=False)(0).values())
assert np.allclose(ip_vals, [1, 2, 4, 8, 16, 32], rtol=1e-3)

pmodel.initial_values[four] = 1
pmodel.rvs_to_initial_values[four] = 1

ip_vals = list(make_initial_point_fn(model=pmodel, return_transformed=True)(0).values())
assert np.allclose(np.exp(ip_vals), [1, 2, 4, 1, 2, 4], rtol=1e-3)
Expand Down
8 changes: 4 additions & 4 deletions pymc/tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -586,11 +586,11 @@ def test_initial_point():
with model:
y = pm.Normal("y", initval=y_initval)

assert a in model.initial_values
assert x in model.initial_values
assert model.initial_values[b] == b_initval
assert a in model.rvs_to_initial_values
assert x in model.rvs_to_initial_values
assert model.rvs_to_initial_values[b] == b_initval
assert model.initial_point(0)["b_interval__"] == b_initval_trans
assert model.initial_values[y] == y_initval
assert model.rvs_to_initial_values[y] == y_initval


def test_point_logps():
Expand Down

0 comments on commit 291c761

Please sign in to comment.