Skip to content

Commit

Permalink
Introduce Model.initial_values and deprecate testval in favor of initval
Browse files Browse the repository at this point in the history
  • Loading branch information
brandonwillard committed Jun 1, 2021
1 parent 53f6f43 commit aa8a913
Show file tree
Hide file tree
Showing 5 changed files with 100 additions and 48 deletions.
23 changes: 17 additions & 6 deletions pymc3/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,22 +143,33 @@ def __new__(cls, name, *args, **kwargs):
if "shape" in kwargs:
raise DeprecationWarning("The `shape` keyword is deprecated; use `size`.")

testval = kwargs.pop("testval", None)

if testval:
warnings.warn(
"The `testval` argument is deprecated; use `initval`.",
DeprecationWarning,
stacklevel=2,
)

initval = kwargs.pop("initval", testval)

transform = kwargs.pop("transform", UNSET)

rv_out = cls.dist(*args, rng=rng, **kwargs)

return model.register_rv(rv_out, name, data, total_size, dims=dims, transform=transform)
if testval is not None:
rv_out.tag.test_value = testval

return model.register_rv(
rv_out, name, data, total_size, dims=dims, transform=transform, initval=initval
)

@classmethod
def dist(cls, dist_params, rng=None, **kwargs):

testval = kwargs.pop("testval", None)

rv_var = cls.rv_op(*dist_params, rng=rng, **kwargs)

if testval is not None:
rv_var.tag.test_value = testval

if (
rv_var.owner
and isinstance(rv_var.owner.op, RandomVariable)
Expand Down
68 changes: 36 additions & 32 deletions pymc3/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
from aesara.compile.sharedvalue import SharedVariable
from aesara.gradient import grad
from aesara.graph.basic import Constant, Variable, graph_inputs
from aesara.graph.fg import FunctionGraph, MissingInputError
from aesara.graph.fg import FunctionGraph
from aesara.tensor.random.opt import local_subtensor_rv_lift
from aesara.tensor.random.var import RandomStateSharedVariable
from aesara.tensor.sharedvar import ScalarSharedVariable
Expand Down Expand Up @@ -572,7 +572,7 @@ def __init__(self, mean=0, sigma=1, name='', model=None):
Normal('v2', mu=mean, sigma=sd)
# something more complex is allowed, too
half_cauchy = HalfCauchy('sd', beta=10, testval=1.)
half_cauchy = HalfCauchy('sd', beta=10, initval=1.)
Normal('v3', mu=mean, sigma=half_cauchy)
# Deterministic variables can be used in usual way
Expand Down Expand Up @@ -649,6 +649,7 @@ def __init__(

# The sequence of model-generated RNGs
self.rng_seq = []
self.initial_values = {}

if self.parent is not None:
self.named_vars = treedict(parent=self.parent.named_vars)
Expand Down Expand Up @@ -914,35 +915,7 @@ def test_point(self):

@property
def initial_point(self):
points = []
for rv_var in self.free_RVs:
value_var = rv_var.tag.value_var
var_value = getattr(value_var.tag, "test_value", None)

if var_value is None:

rv_var_value = getattr(rv_var.tag, "test_value", None)

if rv_var_value is None:
try:
rv_var_value = rv_var.eval()
except MissingInputError:
raise MissingInputError(f"Couldn't generate an initial value for {rv_var}")

transform = getattr(value_var.tag, "transform", None)

if transform:
try:
rv_var_value = transform.forward(rv_var, rv_var_value).eval()
except MissingInputError:
raise MissingInputError(f"Couldn't generate an initial value for {rv_var}")

var_value = rv_var_value
value_var.tag.test_value = var_value

points.append((value_var, var_value))

return Point(points, model=self)
return Point(list(self.initial_values.items()), model=self)

@property
def disc_vars(self):
Expand All @@ -954,6 +927,32 @@ def cont_vars(self):
"""All the continuous variables in the model"""
return list(typefilter(self.value_vars, continuous_types))

def set_initval(self, rv_var, initval):
initval = initval or getattr(rv_var.tag, "test_value", None)

rv_value_var = self.rvs_to_values[rv_var]
transform = getattr(rv_value_var.tag, "transform", None)

if initval is None or transform:
# Sample/evaluate this using the existing initial values, and
# with the least amount of affect on the RNGs involved (i.e. no
# in-placing)
from aesara.compile.mode import Mode, get_mode

mode = get_mode(None)
opt_qry = mode.provided_optimizer.excluding("random_make_inplace")
mode = Mode(linker=mode.linker, optimizer=opt_qry)

if transform:
rv_var = transform.forward(rv_var, initval or rv_var)

initval_fn = aesara.function(
[], rv_var, mode=mode, givens=self.initial_values, on_unused_input="ignore"
)
initval = initval_fn()

self.initial_values[rv_value_var] = initval

def next_rng(self) -> RandomStateSharedVariable:
"""Generate a new ``RandomStateSharedVariable``.
Expand Down Expand Up @@ -1116,7 +1115,9 @@ def set_data(

shared_object.set_value(values)

def register_rv(self, rv_var, name, data=None, total_size=None, dims=None, transform=UNSET):
def register_rv(
self, rv_var, name, data=None, total_size=None, dims=None, transform=UNSET, initval=None
):
"""Register an (un)observed random variable with the model.
Parameters
Expand All @@ -1132,6 +1133,8 @@ def register_rv(self, rv_var, name, data=None, total_size=None, dims=None, trans
Dimension names for the variable.
transform
A transform for the random variable in log-likelihood space.
initval
The initial value of the random variable.
Returns
-------
Expand All @@ -1145,6 +1148,7 @@ def register_rv(self, rv_var, name, data=None, total_size=None, dims=None, trans
self.free_RVs.append(rv_var)
self.create_value_var(rv_var, transform)
self.add_random_variable(rv_var, dims)
self.set_initval(rv_var, initval)
else:
if (
isinstance(data, Variable)
Expand Down
25 changes: 24 additions & 1 deletion pymc3/tests/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -594,6 +594,7 @@ def check_logp(
n_samples=100,
extra_args=None,
scipy_args=None,
skip_params_fn=lambda x: False,
):
"""
Generic test for PyMC3 logp methods
Expand Down Expand Up @@ -625,6 +626,9 @@ def check_logp(
the pymc3 distribution logp is calculated
scipy_args : Dictionary with extra arguments needed to call scipy logp method
Usually the same as extra_args
skip_params_fn: Callable
A function that takes a ``dict`` of the test points and returns a
boolean indicating whether or not to perform the test.
"""
if decimal is None:
decimal = select_by_precision(float64=6, float32=3)
Expand All @@ -646,6 +650,8 @@ def logp_reference(args):
domains["value"] = domain
for pt in product(domains, n_samples=n_samples):
pt = dict(pt)
if skip_params_fn(pt):
continue
pt_d = self._model_input_dict(model, param_vars, pt)
pt_logp = Point(pt_d, model=model)
pt_ref = Point(pt, filter_model_vars=False, model=model)
Expand Down Expand Up @@ -690,6 +696,7 @@ def check_logcdf(
n_samples=100,
skip_paramdomain_inside_edge_test=False,
skip_paramdomain_outside_edge_test=False,
skip_params_fn=lambda x: False,
):
"""
Generic test for PyMC3 logcdf methods
Expand Down Expand Up @@ -730,6 +737,9 @@ def check_logcdf(
skip_paramdomain_outside_edge_test : Bool
Whether to run test 2., which checks that pymc3 distribution logcdf
returns -inf for invalid parameter values outside the supported domain edge
skip_params_fn: Callable
A function that takes a ``dict`` of the test points and returns a
boolean indicating whether or not to perform the test.
Returns
-------
Expand All @@ -745,6 +755,8 @@ def check_logcdf(

for pt in product(domains, n_samples=n_samples):
params = dict(pt)
if skip_params_fn(params):
continue
scipy_cdf = scipy_logcdf(**params)
value = params.pop("value")
with Model() as m:
Expand Down Expand Up @@ -825,7 +837,13 @@ def check_logcdf(
)

def check_selfconsistency_discrete_logcdf(
self, distribution, domain, paramdomains, decimal=None, n_samples=100
self,
distribution,
domain,
paramdomains,
decimal=None,
n_samples=100,
skip_params_fn=lambda x: False,
):
"""
Check that logcdf of discrete distributions matches sum of logps up to value
Expand All @@ -836,6 +854,8 @@ def check_selfconsistency_discrete_logcdf(
decimal = select_by_precision(float64=6, float32=3)
for pt in product(domains, n_samples=n_samples):
params = dict(pt)
if skip_params_fn(params):
continue
value = params.pop("value")
values = np.arange(domain.lower, value + 1)
dist = distribution.dist(**params)
Expand Down Expand Up @@ -1189,17 +1209,20 @@ def modified_scipy_hypergeom_logcdf(value, N, k, n):
Nat,
{"N": NatSmall, "k": NatSmall, "n": NatSmall},
modified_scipy_hypergeom_logpmf,
skip_params_fn=lambda x: x["N"] < x["n"] or x["N"] < x["k"],
)
self.check_logcdf(
HyperGeometric,
Nat,
{"N": NatSmall, "k": NatSmall, "n": NatSmall},
modified_scipy_hypergeom_logcdf,
skip_params_fn=lambda x: x["N"] < x["n"] or x["N"] < x["k"],
)
self.check_selfconsistency_discrete_logcdf(
HyperGeometric,
Nat,
{"N": NatSmall, "k": NatSmall, "n": NatSmall},
skip_params_fn=lambda x: x["N"] < x["n"] or x["N"] < x["k"],
)

def test_negative_binomial(self):
Expand Down
12 changes: 4 additions & 8 deletions pymc3/tests/test_distributions_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,17 +327,13 @@ def test_distribution(self):

def _instantiate_pymc_rv(self, dist_params=None):
params = dist_params if dist_params else self.pymc_dist_params
with pm.Model():
self.pymc_rv = self.pymc_dist(
**params,
size=self.size,
rng=aesara.shared(self.get_random_state(reset=True)),
name=f"{self.pymc_dist.rv_op.name}_test",
)
self.pymc_rv = self.pymc_dist.dist(
**params, size=self.size, rng=aesara.shared(self.get_random_state(reset=True))
)

def check_pymc_draws_match_reference(self):
# need to re-instantiate it to make sure that the order of drawings match the reference distribution one
self._instantiate_pymc_rv()
# self._instantiate_pymc_rv()
assert_array_almost_equal(
self.pymc_rv.eval(), self.reference_dist_draws, decimal=self.decimal
)
Expand Down
20 changes: 19 additions & 1 deletion pymc3/tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,13 +498,31 @@ def test_initial_point():

with pm.Model() as model:
a = pm.Uniform("a")
pm.Normal("x", a)
x = pm.Normal("x", a)

with pytest.warns(DeprecationWarning):
initial_point = model.test_point

assert all(var.name in initial_point for var in model.value_vars)

b_initval = np.array(0.3)

with pytest.warns(DeprecationWarning), model:
b = pm.Uniform("b", testval=b_initval)

b_value_var = model.rvs_to_values[b]
b_initval_trans = b_value_var.tag.transform.forward(b, b_initval).eval()

y_initval = np.array(-2.4)

with model:
y = pm.Normal("y", initval=y_initval)

assert model.rvs_to_values[a] in model.initial_values
assert model.rvs_to_values[x] in model.initial_values
assert model.initial_values[b_value_var] == b_initval_trans
assert model.initial_values[model.rvs_to_values[y]] == y_initval


def test_point_logps():

Expand Down

0 comments on commit aa8a913

Please sign in to comment.