Use an incrementally updated RNG state in Model #4729

Merged: 4 commits, Jun 2, 2021
1 change: 1 addition & 0 deletions RELEASE-NOTES.md
@@ -5,6 +5,7 @@
- ⚠ Theano-PyMC has been replaced with Aesara, so all external references to `theano`, `tt`, and `pymc3.theanof` need to be replaced with `aesara`, `at`, and `pymc3.aesaraf` (see [#4471](https://github.com/pymc-devs/pymc3/pull/4471)).
- ArviZ `plots` and `stats` *wrappers* were removed. The functions are now available under their original names (see [#4549](https://github.com/pymc-devs/pymc3/pull/4549) and the `3.11.2` release notes).
- The GLM submodule has been removed; please use [Bambi](https://bambinos.github.io/bambi/) instead.
- The `Distribution` keyword argument `testval` has been deprecated in favor of `initval`.
- ...
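As a quick illustration of the renamed keyword (a minimal, hypothetical model):

import pymc3 as pm

with pm.Model():
    # Formerly: pm.Normal("x", 0.0, 1.0, testval=0.5)
    x = pm.Normal("x", 0.0, 1.0, initval=0.5)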

### New Features
14 changes: 14 additions & 0 deletions pymc3/aesaraf.py
@@ -31,6 +31,7 @@
import scipy.sparse as sps

from aesara import config, scalar
from aesara.compile.mode import Mode, get_mode
from aesara.gradient import grad
from aesara.graph.basic import (
Apply,
@@ -861,3 +862,16 @@ def take_along_axis(arr, indices, axis=0):

# use the fancy index
return arr[_make_along_axis_idx(arr_shape, indices, _axis)]


def compile_rv_inplace(inputs, outputs, mode=None, **kwargs):
"""Use ``aesara.function`` with the random_make_inplace optimization always enabled.

Using this function ensures that compiled functions containing random
variables will produce new samples on each call.
"""
mode = get_mode(mode)
opt_qry = mode.provided_optimizer.including("random_make_inplace")
mode = Mode(linker=mode.linker, optimizer=opt_qry)
aesara_function = aesara.function(inputs, outputs, mode=mode, **kwargs)
return aesara_function
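A minimal usage sketch (assumes this branch; the manual `default_update` wiring mirrors what `Distribution.dist` does below):

import numpy as np

import aesara

from aesara.tensor.random.basic import normal
from pymc3.aesaraf import compile_rv_inplace

rng = aesara.shared(np.random.RandomState(123))
x = normal(0, 1, rng=rng)
# Feed the updated RNG state back into the shared variable so the
# in-place optimization is allowed to rewrite the graph.
rng.default_update = x.owner.outputs[0]

fn = compile_rv_inplace([], x)
print(fn(), fn())  # two different draws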
2 changes: 1 addition & 1 deletion pymc3/distributions/bart.py
@@ -27,7 +27,7 @@ def __init__(self, X, Y, m=200, alpha=0.25, split_prior=None, *args, **kwargs):

self.X, self.Y, self.missing_data = self.preprocess_XY(X, Y)

super().__init__(shape=X.shape[0], dtype="float64", testval=0, *args, **kwargs)
super().__init__(shape=X.shape[0], dtype="float64", initval=0, *args, **kwargs)

if self.X.ndim != 2:
raise ValueError("The design matrix X must have two dimensions")
8 changes: 4 additions & 4 deletions pymc3/distributions/bound.py
@@ -42,7 +42,7 @@ def __init__(self, distribution, lower, upper, default, *args, **kwargs):
super().__init__(
shape=self._wrapped.shape,
dtype=self._wrapped.dtype,
testval=self._wrapped.testval,
initval=self._wrapped.initval,
defaults=defaults,
transform=self._wrapped.transform,
)
@@ -252,15 +252,15 @@ class Bound:

with pm.Model():
NegativeNormal = pm.Bound(pm.Normal, upper=0.0)
par1 = NegativeNormal('par1', mu=0.0, sigma=1.0, testval=-0.5)
par1 = NegativeNormal('par1', mu=0.0, sigma=1.0, initval=-0.5)
# you can use the Bound object multiple times to
# create multiple bounded random variables
par1_1 = NegativeNormal('par1_1', mu=-1.0, sigma=1.0, testval=-1.5)
par1_1 = NegativeNormal('par1_1', mu=-1.0, sigma=1.0, initval=-1.5)

# you can also define a Bound implicitly, while applying
# it to a random variable
par2 = pm.Bound(pm.Normal, lower=-1.0, upper=1.0)(
'par2', mu=0.0, sigma=1.0, testval=1.0)
'par2', mu=0.0, sigma=1.0, initval=1.0)
"""

def __init__(self, distribution, lower=None, upper=None):
20 changes: 12 additions & 8 deletions pymc3/distributions/continuous.py
@@ -332,10 +332,12 @@ class Flat(Continuous):
rv_op = flat

@classmethod
def dist(cls, *, size=None, testval=None, **kwargs):
if testval is None:
testval = np.full(size, floatX(0.0))
return super().dist([], size=size, testval=testval, **kwargs)
def dist(cls, *, size=None, initval=None, **kwargs):
if initval is None:
initval = np.full(size, floatX(0.0))
res = super().dist([], size=size, **kwargs)
res.tag.test_value = initval
return res

def logp(value):
"""
@@ -394,10 +396,12 @@ class HalfFlat(PositiveContinuous):
rv_op = halfflat

@classmethod
def dist(cls, *, size=None, testval=None, **kwargs):
if testval is None:
testval = np.full(size, floatX(1.0))
return super().dist([], size=size, testval=testval, **kwargs)
def dist(cls, *, size=None, initval=None, **kwargs):
if initval is None:
initval = np.full(size, floatX(1.0))
res = super().dist([], size=size, **kwargs)
res.tag.test_value = initval
return res

def logp(value):
"""
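A small sketch of the `Flat`/`HalfFlat` defaults introduced above (assumes this branch):

import pymc3 as pm

print(pm.Flat.dist(size=3).tag.test_value)      # [0. 0. 0.]
print(pm.HalfFlat.dist(size=3).tag.test_value)  # [1. 1. 1.]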
73 changes: 52 additions & 21 deletions pymc3/distributions/distribution.py
@@ -19,12 +19,12 @@
import warnings

from abc import ABCMeta
from copy import copy
from typing import TYPE_CHECKING

import dill

from aesara.tensor.random.op import RandomVariable
from aesara.tensor.random.var import RandomStateSharedVariable

from pymc3.distributions import _logcdf, _logp

@@ -77,14 +77,6 @@ def _random(*args, **kwargs):
rv_type = None

if isinstance(rv_op, RandomVariable):
if not rv_op.inplace:
# TODO: This is a temporary work-around.
# Remove this once we know what we want regarding RNG states
# and their propagation.
rv_op = copy(rv_op)
rv_op.inplace = True
clsdict["rv_op"] = rv_op

rv_type = type(rv_op)

new_cls = super().__new__(cls, name, bases, clsdict)
@@ -137,7 +129,7 @@ def __new__(cls, name, *args, **kwargs):
rng = kwargs.pop("rng", None)

if rng is None:
rng = model.default_rng
rng = model.next_rng()

if not isinstance(name, string_types):
raise TypeError(f"Name needs to be a string but got: {name}")
@@ -151,21 +143,60 @@
if "shape" in kwargs:
raise DeprecationWarning("The `shape` keyword is deprecated; use `size`.")

testval = kwargs.pop("testval", None)

if testval is not None:
warnings.warn(
"The `testval` argument is deprecated; use `initval`.",
DeprecationWarning,
stacklevel=2,
)

initval = kwargs.pop("initval", testval)

transform = kwargs.pop("transform", UNSET)

rv_out = cls.dist(*args, rng=rng, **kwargs)

return model.register_rv(rv_out, name, data, total_size, dims=dims, transform=transform)
if testval is not None:
rv_out.tag.test_value = testval

return model.register_rv(
rv_out, name, data, total_size, dims=dims, transform=transform, initval=initval
)
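A sketch of the deprecation path above (hypothetical session; assumes this branch):

import warnings

import pymc3 as pm

with pm.Model():
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        # `testval` still works, but is routed through `initval`.
        x = pm.Normal("x", 0.0, 1.0, testval=0.1)
    assert any(issubclass(w.category, DeprecationWarning) for w in caught)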

@classmethod
def dist(cls, dist_params, **kwargs):
def dist(cls, dist_params, rng=None, **kwargs):

testval = kwargs.pop("testval", None)

rv_var = cls.rv_op(*dist_params, **kwargs)

if testval is not None:
rv_var.tag.test_value = testval
warnings.warn(
"The `testval` argument is deprecated. "
"Use `initval` to set initial values for a `Model`; "
"otherwise, set test values on Aesara parameters explicitly "
"when attempting to use Aesara's test value debugging features.",
DeprecationWarning,
stacklevel=2,
)

rv_var = cls.rv_op(*dist_params, rng=rng, **kwargs)

if (
rv_var.owner
and isinstance(rv_var.owner.op, RandomVariable)
and isinstance(rng, RandomStateSharedVariable)
and not getattr(rng, "default_update", None)
):
# This tells `aesara.function` that the shared RNG variable
# is mutable, which--in turn--tells the `FunctionGraph`
# `Supervisor` feature to allow in-place updates on the variable.
# Without it, the `RandomVariable`s could not be optimized to allow
# in-place RNG updates, forcing all sample results from compiled
# functions to be the same on repeated evaluations.
new_rng = rv_var.owner.outputs[0]
rv_var.update = (rng, new_rng)
rng.default_update = new_rng

return rv_var
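A sketch of the intended end-to-end behavior (hypothetical model; assumes this branch):

import pymc3 as pm

from pymc3.aesaraf import compile_rv_inplace

with pm.Model() as m:
    x = pm.Normal("x", 0.0, 1.0)  # internally draws its RNG from m.next_rng()

fn = compile_rv_inplace([], x)
print(fn(), fn())  # different draws, because the RNG state advances in place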

Expand Down Expand Up @@ -246,14 +277,14 @@ def __init__(
self,
shape,
dtype,
testval=None,
initval=None,
defaults=(),
parent_dist=None,
*args,
**kwargs,
):
super().__init__(
shape=shape, dtype=dtype, testval=testval, defaults=defaults, *args, **kwargs
shape=shape, dtype=dtype, initval=initval, defaults=defaults, *args, **kwargs
)
self.parent_dist = parent_dist

@@ -311,7 +342,7 @@ def __init__(
logp,
shape=(),
dtype=None,
testval=0,
initval=0,
random=None,
wrap_random_with_dist_shape=True,
check_shape_in_random=True,
@@ -332,8 +363,8 @@
a value here.
dtype: None, str (Optional)
The dtype of the distribution.
testval: number or array (Optional)
The ``testval`` of the RV's tensor that follows the ``DensityDist``
initval: number or array (Optional)
The ``initval`` of the RV's tensor that follows the ``DensityDist``
distribution.
args, kwargs: (Optional)
These are passed to the parent class' ``__init__``.
@@ -369,7 +400,7 @@ def __init__(
"""
if dtype is None:
dtype = aesara.config.floatX
super().__init__(shape, dtype, testval, *args, **kwargs)
super().__init__(shape, dtype, initval, *args, **kwargs)
self.logp = logp
if type(self.logp) == types.MethodType:
if PLATFORM != "linux":
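A hedged `DensityDist` sketch with the renamed keyword (the log-density and values are illustrative):

import pymc3 as pm

with pm.Model():
    # A standard-normal log-density, supplied directly (up to a constant).
    x = pm.DensityDist("x", lambda value: -0.5 * value ** 2, initval=0.5)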
4 changes: 2 additions & 2 deletions pymc3/distributions/mixture.py
@@ -609,7 +609,7 @@ class NormalMixture(Mixture):
10,
shape=n_components,
transform=pm.transforms.ordered,
testval=[1, 2, 3],
initval=[1, 2, 3],
)
σ = pm.HalfNormal("σ", 10, shape=n_components)
weights = pm.Dirichlet("w", np.ones(n_components))
@@ -684,7 +684,7 @@ def __init__(self, w, comp_dists, mixture_axis=-1, *args, **kwargs):
self.mixture_axis = mixture_axis
kwargs.setdefault("dtype", self.comp_dists.dtype)

# Compute the mode so we don't always have to pass a testval
# Compute the mode so we don't always have to pass an initval
defaults = kwargs.pop("defaults", [])
event_shape = self.comp_dists.shape[mixture_axis + 1 :]
_w = at.shape_padleft(
18 changes: 9 additions & 9 deletions pymc3/distributions/multivariate.py
@@ -840,7 +840,7 @@ def logp(self, X):
)


def WishartBartlett(name, S, nu, is_cholesky=False, return_cholesky=False, testval=None):
def WishartBartlett(name, S, nu, is_cholesky=False, return_cholesky=False, initval=None):
R"""
Bartlett decomposition of the Wishart distribution. As the Wishart
distribution requires the matrix to be symmetric positive semi-definite
@@ -875,7 +875,7 @@ def WishartBartlett(name, S, nu, is_cholesky=False, return_cholesky=False, testval=None):
Input matrix S is already Cholesky decomposed as S.T * S
return_cholesky: bool (default=False)
Only return the Cholesky decomposed matrix.
testval: ndarray
initval: ndarray
p x p positive definite matrix used to initialize

Notes
@@ -894,21 +894,21 @@ def WishartBartlett(name, S, nu, is_cholesky=False, return_cholesky=False, testval=None):
n_diag = len(diag_idx[0])
n_tril = len(tril_idx[0])

if testval is not None:
if initval is not None:
# Inverse transform
testval = np.dot(np.dot(np.linalg.inv(L), testval), np.linalg.inv(L.T))
testval = linalg.cholesky(testval, lower=True)
diag_testval = testval[diag_idx] ** 2
tril_testval = testval[tril_idx]
initval = np.dot(np.dot(np.linalg.inv(L), initval), np.linalg.inv(L.T))
initval = linalg.cholesky(initval, lower=True)
diag_testval = initval[diag_idx] ** 2
tril_testval = initval[tril_idx]
else:
diag_testval = None
tril_testval = None

c = at.sqrt(
ChiSquared("%s_c" % name, nu - np.arange(2, 2 + n_diag), shape=n_diag, testval=diag_testval)
ChiSquared("%s_c" % name, nu - np.arange(2, 2 + n_diag), shape=n_diag, initval=diag_testval)
)
pm._log.info("Added new variable %s_c to model diagonal of Wishart." % name)
z = Normal("%s_z" % name, 0.0, 1.0, shape=n_tril, testval=tril_testval)
z = Normal("%s_z" % name, 0.0, 1.0, shape=n_tril, initval=tril_testval)
pm._log.info("Added new variable %s_z to model off-diagonals of Wishart." % name)
# Construct A matrix
A = at.zeros(S.shape, dtype=np.float32)
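A usage sketch for the renamed `initval` argument of `WishartBartlett` (values are illustrative; assumes this branch):

import numpy as np

import pymc3 as pm

S = np.eye(3)
with pm.Model():
    W = pm.WishartBartlett("W", S=S, nu=5, initval=np.eye(3))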