diff --git a/pymc3/distributions/continuous.py b/pymc3/distributions/continuous.py
index cf68949121f..8089cee2705 100644
--- a/pymc3/distributions/continuous.py
+++ b/pymc3/distributions/continuous.py
@@ -332,10 +332,12 @@ class Flat(Continuous):
     rv_op = flat
 
     @classmethod
-    def dist(cls, *, size=None, testval=None, **kwargs):
-        if testval is None:
-            testval = np.full(size, floatX(0.0))
-        return super().dist([], size=size, testval=testval, **kwargs)
+    def dist(cls, *, size=None, initval=None, **kwargs):
+        if initval is None:
+            initval = np.full(size, floatX(0.0))
+        res = super().dist([], size=size, **kwargs)
+        res.tag.test_value = initval
+        return res
 
     def logp(value):
         """
@@ -394,10 +396,12 @@ class HalfFlat(PositiveContinuous):
     rv_op = halfflat
 
     @classmethod
-    def dist(cls, *, size=None, testval=None, **kwargs):
-        if testval is None:
-            testval = np.full(size, floatX(1.0))
-        return super().dist([], size=size, testval=testval, **kwargs)
+    def dist(cls, *, size=None, initval=None, **kwargs):
+        if initval is None:
+            initval = np.full(size, floatX(1.0))
+        res = super().dist([], size=size, **kwargs)
+        res.tag.test_value = initval
+        return res
 
     def logp(value):
         """
diff --git a/pymc3/distributions/distribution.py b/pymc3/distributions/distribution.py
index 24b667df1d1..781fe19699f 100644
--- a/pymc3/distributions/distribution.py
+++ b/pymc3/distributions/distribution.py
@@ -143,22 +143,33 @@ def __new__(cls, name, *args, **kwargs):
         if "shape" in kwargs:
             raise DeprecationWarning("The `shape` keyword is deprecated; use `size`.")
 
+        testval = kwargs.pop("testval", None)
+
+        if testval is not None:
+            warnings.warn(
+                "The `testval` argument is deprecated; use `initval`.",
+                DeprecationWarning,
+                stacklevel=2,
+            )
+
+        initval = kwargs.pop("initval", testval)
+
         transform = kwargs.pop("transform", UNSET)
 
         rv_out = cls.dist(*args, rng=rng, **kwargs)
 
-        return model.register_rv(rv_out, name, data, total_size, dims=dims, transform=transform)
+        if testval is not None:
+            rv_out.tag.test_value = testval
+
+        return model.register_rv(
+            rv_out, name, data, total_size, dims=dims, transform=transform, initval=initval
+        )
 
     @classmethod
     def dist(cls, dist_params, rng=None, **kwargs):
 
-        testval = kwargs.pop("testval", None)
-
         rv_var = cls.rv_op(*dist_params, rng=rng, **kwargs)
 
-        if testval is not None:
-            rv_var.tag.test_value = testval
-
         if (
             rv_var.owner
             and isinstance(rv_var.owner.op, RandomVariable)
diff --git a/pymc3/model.py b/pymc3/model.py
index 36a46d5a610..0da38c7537f 100644
--- a/pymc3/model.py
+++ b/pymc3/model.py
@@ -40,7 +40,7 @@
 from aesara.compile.sharedvalue import SharedVariable
 from aesara.gradient import grad
 from aesara.graph.basic import Constant, Variable, graph_inputs
-from aesara.graph.fg import FunctionGraph, MissingInputError
+from aesara.graph.fg import FunctionGraph
 from aesara.tensor.random.opt import local_subtensor_rv_lift
 from aesara.tensor.random.var import RandomStateSharedVariable
 from aesara.tensor.sharedvar import ScalarSharedVariable
@@ -572,7 +572,7 @@ def __init__(self, mean=0, sigma=1, name='', model=None):
                 Normal('v2', mu=mean, sigma=sd)
 
                 # something more complex is allowed, too
-                half_cauchy = HalfCauchy('sd', beta=10, testval=1.)
+                half_cauchy = HalfCauchy('sd', beta=10, initval=1.)
                 Normal('v3', mu=mean, sigma=half_cauchy)
 
                 # Deterministic variables can be used in usual way
@@ -649,6 +649,7 @@ def __init__(
 
         # The sequence of model-generated RNGs
         self.rng_seq = []
+        self.initial_values = {}
 
         if self.parent is not None:
             self.named_vars = treedict(parent=self.parent.named_vars)
@@ -914,35 +915,7 @@ def test_point(self):
 
     @property
     def initial_point(self):
-        points = []
-        for rv_var in self.free_RVs:
-            value_var = rv_var.tag.value_var
-            var_value = getattr(value_var.tag, "test_value", None)
-
-            if var_value is None:
-
-                rv_var_value = getattr(rv_var.tag, "test_value", None)
-
-                if rv_var_value is None:
-                    try:
-                        rv_var_value = rv_var.eval()
-                    except MissingInputError:
-                        raise MissingInputError(f"Couldn't generate an initial value for {rv_var}")
-
-                transform = getattr(value_var.tag, "transform", None)
-
-                if transform:
-                    try:
-                        rv_var_value = transform.forward(rv_var, rv_var_value).eval()
-                    except MissingInputError:
-                        raise MissingInputError(f"Couldn't generate an initial value for {rv_var}")
-
-                var_value = rv_var_value
-                value_var.tag.test_value = var_value
-
-            points.append((value_var, var_value))
-
-        return Point(points, model=self)
+        return Point(list(self.initial_values.items()), model=self)
 
     @property
     def disc_vars(self):
@@ -954,6 +927,33 @@ def cont_vars(self):
         """All the continuous variables in the model"""
         return list(typefilter(self.value_vars, continuous_types))
 
+    def set_initval(self, rv_var, initval):
+        initval = initval if initval is not None else getattr(rv_var.tag, "test_value", None)
+
+        rv_value_var = self.rvs_to_values[rv_var]
+        transform = getattr(rv_value_var.tag, "transform", None)
+
+        if initval is None or transform:
+            # Sample/evaluate this using the existing initial values, and
+            # with the least amount of effect on the RNGs involved (i.e. no
+            # in-placing)
+            from aesara.compile.mode import Mode, get_mode
+
+            mode = get_mode(None)
+            opt_qry = mode.provided_optimizer.excluding("random_make_inplace")
+            mode = Mode(linker=mode.linker, optimizer=opt_qry)
+
+            if transform:
+                value = initval if initval is not None else rv_var
+                rv_var = transform.forward(rv_var, value)
+
+            initval_fn = aesara.function(
+                [], rv_var, mode=mode, givens=self.initial_values, on_unused_input="ignore"
+            )
+            initval = initval_fn()
+
+        self.initial_values[rv_value_var] = initval
+
     def next_rng(self) -> RandomStateSharedVariable:
         """Generate a new ``RandomStateSharedVariable``.
@@ -1116,7 +1116,9 @@ def set_data(
 
         shared_object.set_value(values)
 
-    def register_rv(self, rv_var, name, data=None, total_size=None, dims=None, transform=UNSET):
+    def register_rv(
+        self, rv_var, name, data=None, total_size=None, dims=None, transform=UNSET, initval=None
+    ):
         """Register an (un)observed random variable with the model.
 
         Parameters
@@ -1132,6 +1134,8 @@ def register_rv(self, rv_var, name, data=None, total_size=None, dims=None, trans
            Dimension names for the variable.
         transform
             A transform for the random variable in log-likelihood space.
+        initval
+            The initial value of the random variable.
 
         Returns
         -------
@@ -1145,6 +1149,7 @@ def register_rv(self, rv_var, name, data=None, total_size=None, dims=None, trans
             self.free_RVs.append(rv_var)
             self.create_value_var(rv_var, transform)
             self.add_random_variable(rv_var, dims)
+            self.set_initval(rv_var, initval)
         else:
             if (
                 isinstance(data, Variable)
diff --git a/pymc3/tests/test_distributions.py b/pymc3/tests/test_distributions.py
index 832e61381e1..12f08eab323 100644
--- a/pymc3/tests/test_distributions.py
+++ b/pymc3/tests/test_distributions.py
@@ -594,6 +594,7 @@ def check_logp(
         n_samples=100,
         extra_args=None,
         scipy_args=None,
+        skip_params_fn=lambda x: False,
     ):
         """
         Generic test for PyMC3 logp methods
@@ -625,6 +626,9 @@ def check_logp(
            the pymc3 distribution logp is calculated
         scipy_args : Dictionary with extra arguments needed to call scipy logp method
            Usually the same as extra_args
+        skip_params_fn: Callable
+            A function that takes a ``dict`` of the drawn test point and returns
+            ``True`` when that parameter combination should be skipped.
         """
         if decimal is None:
             decimal = select_by_precision(float64=6, float32=3)
@@ -646,6 +650,8 @@ def logp_reference(args):
         domains["value"] = domain
         for pt in product(domains, n_samples=n_samples):
             pt = dict(pt)
+            if skip_params_fn(pt):
+                continue
             pt_d = self._model_input_dict(model, param_vars, pt)
             pt_logp = Point(pt_d, model=model)
             pt_ref = Point(pt, filter_model_vars=False, model=model)
@@ -690,6 +696,7 @@ def check_logcdf(
         n_samples=100,
         skip_paramdomain_inside_edge_test=False,
         skip_paramdomain_outside_edge_test=False,
+        skip_params_fn=lambda x: False,
     ):
         """
         Generic test for PyMC3 logcdf methods
@@ -730,6 +737,9 @@ def check_logcdf(
         skip_paramdomain_outside_edge_test : Bool
             Whether to run test 2., which checks that pymc3 distribution logcdf
             returns -inf for invalid parameter values outside the supported domain edge
+        skip_params_fn: Callable
+            A function that takes a ``dict`` of the drawn test point and returns
+            ``True`` when that parameter combination should be skipped.
 
         Returns
         -------
@@ -745,6 +755,8 @@ def check_logcdf(
 
         for pt in product(domains, n_samples=n_samples):
             params = dict(pt)
+            if skip_params_fn(params):
+                continue
             scipy_cdf = scipy_logcdf(**params)
             value = params.pop("value")
             with Model() as m:
@@ -825,7 +837,13 @@ def check_logcdf(
         )
 
     def check_selfconsistency_discrete_logcdf(
-        self, distribution, domain, paramdomains, decimal=None, n_samples=100
+        self,
+        distribution,
+        domain,
+        paramdomains,
+        decimal=None,
+        n_samples=100,
+        skip_params_fn=lambda x: False,
     ):
         """
         Check that logcdf of discrete distributions matches sum of logps up to value
@@ -836,6 +854,8 @@ def check_selfconsistency_discrete_logcdf(
             decimal = select_by_precision(float64=6, float32=3)
         for pt in product(domains, n_samples=n_samples):
             params = dict(pt)
+            if skip_params_fn(params):
+                continue
             value = params.pop("value")
             values = np.arange(domain.lower, value + 1)
             dist = distribution.dist(**params)
@@ -1187,17 +1207,20 @@ def modified_scipy_hypergeom_logcdf(value, N, k, n):
             Nat,
             {"N": NatSmall, "k": NatSmall, "n": NatSmall},
             modified_scipy_hypergeom_logpmf,
+            skip_params_fn=lambda x: x["N"] < x["n"] or x["N"] < x["k"],
         )
         self.check_logcdf(
             HyperGeometric,
             Nat,
             {"N": NatSmall, "k": NatSmall, "n": NatSmall},
             modified_scipy_hypergeom_logcdf,
+            skip_params_fn=lambda x: x["N"] < x["n"] or x["N"] < x["k"],
        )
         self.check_selfconsistency_discrete_logcdf(
             HyperGeometric,
             Nat,
             {"N": NatSmall, "k": NatSmall, "n": NatSmall},
+            skip_params_fn=lambda x: x["N"] < x["n"] or x["N"] < x["k"],
         )
 
     def test_negative_binomial(self):
diff --git a/pymc3/tests/test_distributions_random.py b/pymc3/tests/test_distributions_random.py
index 88e56aa4806..18a864cb11c 100644
--- a/pymc3/tests/test_distributions_random.py
+++ b/pymc3/tests/test_distributions_random.py
@@ -327,17 +327,12 @@ def test_distribution(self):
 
     def _instantiate_pymc_rv(self, dist_params=None):
         params = dist_params if dist_params else self.pymc_dist_params
-        with pm.Model():
-            self.pymc_rv = self.pymc_dist(
-                **params,
-                size=self.size,
-                rng=aesara.shared(self.get_random_state(reset=True)),
-                name=f"{self.pymc_dist.rv_op.name}_test",
-            )
+        self.pymc_rv = self.pymc_dist.dist(
+            **params, size=self.size, rng=aesara.shared(self.get_random_state(reset=True))
+        )
 
     def check_pymc_draws_match_reference(self):
-        # need to re-instantiate it to make sure that the order of drawings match the reference distribution one
-        self._instantiate_pymc_rv()
         assert_array_almost_equal(
             self.pymc_rv.eval(), self.reference_dist_draws, decimal=self.decimal
         )
diff --git a/pymc3/tests/test_model.py b/pymc3/tests/test_model.py
index d479c98f320..e13c9cf4ac9 100644
--- a/pymc3/tests/test_model.py
+++ b/pymc3/tests/test_model.py
@@ -498,13 +498,31 @@ def test_initial_point():
     with pm.Model() as model:
         a = pm.Uniform("a")
-        pm.Normal("x", a)
+        x = pm.Normal("x", a)
 
     with pytest.warns(DeprecationWarning):
         initial_point = model.test_point
 
     assert all(var.name in initial_point for var in model.value_vars)
 
+    b_initval = np.array(0.3)
+
+    with pytest.warns(DeprecationWarning), model:
+        b = pm.Uniform("b", testval=b_initval)
+
+    b_value_var = model.rvs_to_values[b]
+    b_initval_trans = b_value_var.tag.transform.forward(b, b_initval).eval()
+
+    y_initval = np.array(-2.4)
+
+    with model:
+        y = pm.Normal("y", initval=y_initval)
+
+    assert model.rvs_to_values[a] in model.initial_values
+    assert model.rvs_to_values[x] in model.initial_values
+    assert model.initial_values[b_value_var] == b_initval_trans
+    assert model.initial_values[model.rvs_to_values[y]] == y_initval
+
 
 def test_point_logps():
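
For reviewers, a minimal sketch of the user-facing behavior these hunks add, assuming the branch above is applied (the variable names are illustrative, not part of the diff):

    import numpy as np
    import pymc3 as pm

    with pm.Model() as model:
        # New keyword: `initval` pins the starting point used by samplers.
        x = pm.Normal("x", 0.0, 1.0, initval=np.array(-2.4))
        # The old keyword still works, but now emits a DeprecationWarning
        # and is forwarded to the same `initval` machinery.
        y = pm.Normal("y", 0.0, 1.0, testval=1.0)

    # Initial values are now computed eagerly at registration time and
    # stored per value variable; `initial_point` is just a view of them.
    assert model.initial_values[model.rvs_to_values[x]] == -2.4
    print(model.initial_point)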
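
One subtlety worth calling out: for transformed variables, `Model.set_initval` stores the value on the transformed (value-variable) scale, which is what the updated `test_initial_point` asserts via `transform.forward`. A sketch, under the same assumptions as above:

    import numpy as np
    import pymc3 as pm

    with pm.Model() as model:
        b = pm.Uniform("b", initval=0.3)

    b_value_var = model.rvs_to_values[b]
    # Uniform(0, 1) is interval-transformed, so the stored initial value
    # is forward(0.3) (a logit), not 0.3 itself.
    expected = b_value_var.tag.transform.forward(b, 0.3).eval()
    assert np.isclose(model.initial_values[b_value_var], expected)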
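
Similarly, the contract of the new `skip_params_fn` hook in the `check_*` test helpers: it receives the dict of values drawn for one test point and returns True when that point should be skipped (presumably to avoid parameter combinations the scipy reference cannot handle). The predicate below is the HyperGeometric one from the diff; the surrounding filter just restates the loop added above:

    skip_params_fn = lambda x: x["N"] < x["n"] or x["N"] < x["k"]

    points = [
        {"N": 10, "k": 3, "n": 5, "value": 2},  # consistent: tested
        {"N": 4, "k": 3, "n": 5, "value": 2},   # n > N: skipped
    ]
    # The helpers do the equivalent of this filter inside their loops:
    kept = [pt for pt in points if not skip_params_fn(pt)]
    assert kept == [points[0]]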