Skip to content

Commit

Permalink
Fix some tests for initial_point
Browse files Browse the repository at this point in the history
  • Loading branch information
aseyboldt committed Sep 21, 2021
1 parent a10245d commit f5aa61a
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 19 deletions.
24 changes: 7 additions & 17 deletions pymc3/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,10 @@
from pymc3.util import (
UNSET,
WithMemoization,
get_transformed_name,
get_var_name,
treedict,
treelist,
get_transformed_name,
)
from pymc3.vartypes import continuous_types, discrete_types, typefilter

Expand Down Expand Up @@ -943,7 +943,8 @@ def test_point(self) -> Dict[str, np.ndarray]:
@property
def initial_point(self) -> Dict[str, np.ndarray]:
"""Maps free variable names to transformed, numeric initial values."""
return self.recompute_initial_point()
seed = self.rng_seeder.randint(2 ** 30, dtype=np.int64)
return self.recompute_initial_point(seed)

def recompute_initial_point(self, seed) -> Dict[str, np.ndarray]:
"""Recomputes the initial point of the model.
Expand All @@ -959,9 +960,8 @@ def recompute_initial_point(self, seed) -> Dict[str, np.ndarray]:
def _make_initial_point_expression(
self,
*,
rng: Optional[Union[int, np.random.SeedSequence]] = None,
jitter_rvs: Set[TensorVariable] = None,
default_strategy: str = "moment",
default_strategy: str = "prior",
return_transformed: bool = False,
) -> List[TensorVariable]:
"""Recomputes numeric initial values for all free model variables.
Expand Down Expand Up @@ -999,33 +999,23 @@ def _make_initial_point_expression(
elif strategy == "prior":
value = variable
else:
value = strategy
value = at.as_tensor_variable(strategy)

transform = getattr(self.rvs_to_values[variable].tag, "transform", None)

print("-----")

if transform is not None:
value = transform.forward(variable, value)

print("transform", transform)
print("value", value)

if variable in jitter_rvs:
jitter = at.random.uniform(at.zeros_like(variable) - 1, at.zeros_like(variable) + 1)
jitter.name = "jitter"
value = value + jitter

print("after_jitter", value)

initial_values_transformed.append(value)

if transform is not None:
value = transform.backward(variable, value)

print("final", value)
print()

initial_values.append(value)

# Copy the outputs in a way that we still know what corresponds to the free_RVS
Expand Down Expand Up @@ -1065,9 +1055,9 @@ def make_initial_point_fn(
self,
*,
jitter_rvs: Set[TensorVariable] = None,
default_strategy: str = "moment",
default_strategy: str = "prior",
return_transformed: bool,
) -> Callable[Union[int, np.random.SeedSequence], Dict[str, np.ndarray]]:
) -> Callable[[Union[int, np.random.SeedSequence]], Dict[str, np.ndarray]]:
"""Recomputes numeric initial values for all free model variables.
Parameters
Expand Down
1 change: 0 additions & 1 deletion pymc3/tests/test_initvals.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,6 @@ def test_adds_jitter(self):
assert iv["A"] == 0
# Moment of the HalfFlat is 1, but HalfFlat is log-transformed by default
# so the transformed initial value with jitter will be
print(iv)
b_transformed = iv["B_log__"]
b_untransformed = transform_back(B, b_transformed)
assert b_transformed != 0
Expand Down
2 changes: 1 addition & 1 deletion pymc3/tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -517,7 +517,7 @@ def test_initial_point():
assert a in model.initial_values
assert x in model.initial_values
assert model.initial_values[b] == b_initval
assert model.recompute_initial_point()["b_interval__"] == b_initval_trans
assert model.recompute_initial_point(0)["b_interval__"] == b_initval_trans
assert model.initial_values[y] == y_initval


Expand Down

0 comments on commit f5aa61a

Please sign in to comment.