diff --git a/pymc3/model.py b/pymc3/model.py index d68705148c1..95140d30af4 100644 --- a/pymc3/model.py +++ b/pymc3/model.py @@ -67,10 +67,10 @@ from pymc3.util import ( UNSET, WithMemoization, + get_transformed_name, get_var_name, treedict, treelist, - get_transformed_name, ) from pymc3.vartypes import continuous_types, discrete_types, typefilter @@ -943,7 +943,8 @@ def test_point(self) -> Dict[str, np.ndarray]: @property def initial_point(self) -> Dict[str, np.ndarray]: """Maps free variable names to transformed, numeric initial values.""" - return self.recompute_initial_point() + seed = self.rng_seeder.randint(2 ** 30, dtype=np.int64) + return self.recompute_initial_point(seed) def recompute_initial_point(self, seed) -> Dict[str, np.ndarray]: """Recomputes the initial point of the model. @@ -959,9 +960,8 @@ def recompute_initial_point(self, seed) -> Dict[str, np.ndarray]: def _make_initial_point_expression( self, *, - rng: Optional[Union[int, np.random.SeedSequence]] = None, jitter_rvs: Set[TensorVariable] = None, - default_strategy: str = "moment", + default_strategy: str = "prior", return_transformed: bool = False, ) -> List[TensorVariable]: """Recomputes numeric initial values for all free model variables. @@ -999,33 +999,23 @@ def _make_initial_point_expression( elif strategy == "prior": value = variable else: - value = strategy + value = at.as_tensor_variable(strategy) transform = getattr(self.rvs_to_values[variable].tag, "transform", None) - print("-----") - if transform is not None: value = transform.forward(variable, value) - print("transform", transform) - print("value", value) - if variable in jitter_rvs: jitter = at.random.uniform(at.zeros_like(variable) - 1, at.zeros_like(variable) + 1) jitter.name = "jitter" value = value + jitter - print("after_jitter", value) - initial_values_transformed.append(value) if transform is not None: value = transform.backward(variable, value) - print("final", value) - print() - initial_values.append(value) # Copy the outputs in a way that we still know what corresponds to the free_RVS @@ -1065,9 +1055,9 @@ def make_initial_point_fn( self, *, jitter_rvs: Set[TensorVariable] = None, - default_strategy: str = "moment", + default_strategy: str = "prior", return_transformed: bool, - ) -> Callable[Union[int, np.random.SeedSequence], Dict[str, np.ndarray]]: + ) -> Callable[[Union[int, np.random.SeedSequence]], Dict[str, np.ndarray]]: """Recomputes numeric initial values for all free model variables. Parameters diff --git a/pymc3/tests/test_initvals.py b/pymc3/tests/test_initvals.py index 9ff8edd3291..41c0fe7b36c 100644 --- a/pymc3/tests/test_initvals.py +++ b/pymc3/tests/test_initvals.py @@ -124,7 +124,6 @@ def test_adds_jitter(self): assert iv["A"] == 0 # Moment of the HalfFlat is 1, but HalfFlat is log-transformed by default # so the transformed initial value with jitter will be - print(iv) b_transformed = iv["B_log__"] b_untransformed = transform_back(B, b_transformed) assert b_transformed != 0 diff --git a/pymc3/tests/test_model.py b/pymc3/tests/test_model.py index 0537ada9537..5e30dccbecd 100644 --- a/pymc3/tests/test_model.py +++ b/pymc3/tests/test_model.py @@ -517,7 +517,7 @@ def test_initial_point(): assert a in model.initial_values assert x in model.initial_values assert model.initial_values[b] == b_initval - assert model.recompute_initial_point()["b_interval__"] == b_initval_trans + assert model.recompute_initial_point(0)["b_interval__"] == b_initval_trans assert model.initial_values[y] == y_initval