Introduce core shape in RandomVariable Ops
ricardoV94 committed Apr 30, 2024
1 parent 5285723 commit b770467
Showing 10 changed files with 323 additions and 366 deletions.
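
At a high level, the commit splits the single `size` argument that backend dispatchers used to receive into an explicit `batch_shape` and `core_shape`, so multivariate `RandomVariable`s can pass their support (core) dimensions to the backends. As an illustrative sketch only (the helper below is not part of the diff), the draws of a random variable have shape `batch_shape + core_shape`:

```python
# Illustrative sketch, not from the commit: how batch and core shape combine.
# For a Dirichlet with 4 categories, core_shape is (4,); batch_shape plays the
# role the old `size` argument used to play.
import numpy as np

def output_shape(batch_shape, core_shape):
    return tuple(batch_shape) + tuple(core_shape)

alphas = np.ones(4)
assert output_shape(batch_shape=(2, 3), core_shape=alphas.shape) == (2, 3, 4)
```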
20 changes: 11 additions & 9 deletions pytensor/link/jax/dispatch/random.py
@@ -93,13 +93,13 @@ def jax_funcify_RandomVariable(op, node, **kwargs):
if None in out_size:
assert_size_argument_jax_compatible(node)

def sample_fn(rng, size, *parameters):
return jax_sample_fn(op)(rng, size, out_dtype, *parameters)
def sample_fn(rng, batch_shape, core_shape, *parameters):
return jax_sample_fn(op)(rng, batch_shape, out_dtype, *parameters)

else:

def sample_fn(rng, size, *parameters):
return jax_sample_fn(op)(rng, out_size, out_dtype, *parameters)
def sample_fn(rng, batch_shape, core_shape, *parameters):
return jax_sample_fn(op)(rng, batch_shape, out_dtype, *parameters)

return sample_fn
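
The JAX samplers now receive `(rng, batch_shape, core_shape, *parameters)` instead of `(rng, size, *parameters)`. A minimal hypothetical sampler under this convention (not part of the diff; the rng bookkeeping is simplified to a bare PRNG key, and a scalar RV simply ignores its empty `core_shape`) might look like:

```python
# Hypothetical sketch: a univariate sampler under the new calling convention.
import jax

def normal_sample_fn(rng_key, batch_shape, core_shape, loc, scale):
    # core_shape is () for a scalar-output RV, so only batch_shape matters.
    return loc + scale * jax.random.normal(rng_key, shape=batch_shape)

draws = normal_sample_fn(jax.random.PRNGKey(0), (3, 2), (), 0.0, 1.0)
assert draws.shape == (3, 2)
```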

@@ -305,7 +305,7 @@ def jax_sample_fn_binomial(op):

from numpyro.distributions.util import binomial

def sample_fn(rng, size, dtype, n, p):
def sample_fn(rng, size, core_shape, n, p):
rng_key = rng["jax_state"]
rng_key, sampling_key = jax.random.split(rng_key, 2)

@@ -328,11 +328,11 @@ def jax_sample_fn_multinomial(op):

from numpyro.distributions.util import multinomial

def sample_fn(rng, size, dtype, n, p):
def sample_fn(rng, batch_shape, core_shape, n, p):
rng_key = rng["jax_state"]
rng_key, sampling_key = jax.random.split(rng_key, 2)

sample = multinomial(key=sampling_key, n=n, p=p, shape=size)
sample = multinomial(key=sampling_key, n=n, p=p, shape=batch_shape)

rng["jax_state"] = rng_key

@@ -351,12 +351,14 @@ def jax_sample_fn_vonmises(op):

from numpyro.distributions.util import von_mises_centered

def sample_fn(rng, size, dtype, mu, kappa):
dtype = op.dtype

def sample_fn(rng, batch_shape, core_shape, mu, kappa):
rng_key = rng["jax_state"]
rng_key, sampling_key = jax.random.split(rng_key, 2)

sample = von_mises_centered(
key=sampling_key, concentration=kappa, shape=size, dtype=dtype
key=sampling_key, concentration=kappa, shape=batch_shape, dtype=dtype
)
sample = (sample + mu + np.pi) % (2.0 * np.pi) - np.pi

45 changes: 26 additions & 19 deletions pytensor/link/numba/dispatch/random.py
@@ -22,6 +22,7 @@
from pytensor.link.utils import (
compile_function_src,
)
from pytensor.tensor import NoneConst, get_vector_length
from pytensor.tensor.random.op import RandomVariable


@@ -62,7 +63,6 @@ def numba_core_rv_funcify(op: Op, node: Apply) -> Callable:
@numba_core_rv_funcify.register(ptr.BinomialRV)
@numba_core_rv_funcify.register(ptr.NegativeBinomialRV)
@numba_core_rv_funcify.register(ptr.MultinomialRV)
@numba_core_rv_funcify.register(ptr.DirichletRV)
@numba_core_rv_funcify.register(ptr.ChoiceRV) # the `p` argument is not supported
@numba_core_rv_funcify.register(ptr.PermutationRV)
def numba_core_rv_default(op, node):
@@ -145,6 +145,18 @@ def random_fn(rng, mean, cov, out):
return random_fn


@numba_core_rv_funcify.register(ptr.DirichletRV)
def core_DirichletRV(op, node):
@numba_basic.numba_njit
def random_fn(rng, alpha):
y = np.empty_like(alpha)
for i in range(len(alpha)):
y[i] = rng.gamma(alpha[i], 1.0)
return y / y.sum()

return random_fn
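
The new Numba core kernel uses the standard gamma-normalization construction of a Dirichlet draw. The same idea in plain NumPy, shown here only to illustrate what `core_DirichletRV` computes:

```python
# Plain-NumPy illustration of the gamma-normalization construction above:
# normalizing independent Gamma(alpha_i, 1) draws yields a Dirichlet(alpha) sample.
import numpy as np

rng = np.random.default_rng(0)
alpha = np.array([0.5, 1.0, 2.5])

y = rng.gamma(alpha, 1.0)   # one Gamma(alpha_i, 1) draw per component
sample = y / y.sum()        # lands on the probability simplex

assert np.isclose(sample.sum(), 1.0) and (sample >= 0).all()
```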


@numba_core_rv_funcify.register(ptr.GumbelRV)
def core_GumbelRV(op, node):
"""Code adapted from Numpy Implementation
@@ -229,20 +241,15 @@ def random_fn(rng, mu, kappa):

@numba_funcify.register(ptr.RandomVariable)
def numba_funcify_RandomVariable(op: RandomVariable, node, **kwargs):
size = op.size_param(node)
batch_shape = op.batch_shape_param(node)
core_shape = op.core_shape_param(node)
dist_params = op.dist_params(node)

# None sizes are represented as empty tuple for the time being
# https://github.com/pymc-devs/pytensor/issues/568
[size_len] = size.type.shape
size_is_None = size_len == 0

batch_shape_len = (
None if NoneConst.equals(batch_shape) else get_vector_length(batch_shape)
)
core_shape_len = get_vector_length(core_shape)
inplace = op.inplace

# TODO: Add core_shape to node.inputs
if op.ndim_supp > 0:
raise NotImplementedError("Multivariate RandomVariable not implemented yet")

core_op_fn = numba_core_rv_funcify(op, node)
if not getattr(core_op_fn, "handles_out", False):
nin = 1 + len(dist_params) # rng + params
@@ -260,7 +267,7 @@ def numba_funcify_RandomVariable(op: RandomVariable, node, **kwargs):
output_dtypes = encode_literals((node.default_output().type.dtype,))
inplace_pattern = encode_literals(())

def random_wrapper(rng, size, *inputs):
def random_wrapper(rng, batch_shape, core_shape, *dist_params):
if not inplace:
rng = copy(rng)

@@ -271,19 +278,19 @@ def random_wrapper(rng, size, *inputs):
output_dtypes,
inplace_pattern,
(rng,),
inputs,
((),), # TODO: correct core_shapes
dist_params,
(numba_ndarray.to_fixed_tuple(core_shape, core_shape_len),),
None
if size_is_None
else numba_ndarray.to_fixed_tuple(size, size_len), # size
if batch_shape_len is None
else numba_ndarray.to_fixed_tuple(batch_shape, batch_shape_len),
)
return rng, draws

def random(rng, size, *inputs):
def random(rng, batch_shape, core_shape, *dist_params):
pass

@overload(random)
def ov_random(rng, size, *inputs):
def ov_random(rng, batch_shape, core_shape, *dist_params):
return random_wrapper

return random
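
Numba needs shape tuples whose length is known at compile time, which is why the wrapper converts the `batch_shape` and `core_shape` vectors with `to_fixed_tuple` before dispatching. A rough pure-Python analogue, assuming the helper simply materializes the first `n` entries of a shape vector as a tuple (the real helper works inside Numba-compiled code):

```python
# Rough pure-Python analogue of the fixed-length tuple conversion, for
# illustration only.
import numpy as np

def to_fixed_tuple_analogue(x, length):
    return tuple(int(x[i]) for i in range(length))

assert to_fixed_tuple_analogue(np.array([3, 2]), 2) == (3, 2)
assert to_fixed_tuple_analogue(np.array([4]), 1) == (4,)
```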
63 changes: 34 additions & 29 deletions pytensor/tensor/random/basic.py
@@ -34,9 +34,8 @@ class ScipyRandomVariable(RandomVariable):
"""

@classmethod
@abc.abstractmethod
def rng_fn_scipy(cls, rng, *args, **kwargs):
def rng_fn_scipy(cls, *args, **kwargs):
r"""
`RandomVariable`\s implementations that want to use SciPy-based samplers
@@ -46,24 +45,30 @@ def rng_fn_scipy(cls, rng, *args, **kwargs):
"""

@classmethod
def rng_fn(cls, *args, **kwargs):
size = args[-1]
res = cls.rng_fn_scipy(*args, **kwargs)
def rng_fn(self, *args):
rng, *params, size, _ = args
return self.rng_fn_scipy(rng, *params, size)

def perform(self, node, inputs, outputs):
super().perform(node, inputs, outputs)

_, batch_shape, _, *params = inputs
_, draws_container = outputs
[draws] = draws_container

if np.ndim(res) == 0:
if np.ndim(draws) == 0:
# The sample is an `np.number`, and is not writeable, or non-NumPy
# type, so we need to clone/create a usable NumPy result
res = np.asarray(res)
draws = np.asarray(draws)

if size is None:
if batch_shape is None:
# SciPy will sometimes drop broadcastable dimensions; we need to
# check and, if necessary, add them back
exp_shape = broadcast_shapes(*[np.shape(a) for a in args[1:-1]])
if res.shape != exp_shape:
return np.broadcast_to(res, exp_shape).copy()
missing_ndim = node.outputs[1].type.ndim - draws.ndim
if missing_ndim:
draws = np.expand_dims(draws, tuple(range(missing_ndim)))

return res
draws_container[0] = draws
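
The `perform` override restores leading broadcastable dimensions that SciPy samplers sometimes drop when `batch_shape` is `None`. A standalone sketch of that dimension-restoring step (the example shapes are made up):

```python
# Standalone sketch of the expand_dims step above: if the draw has fewer
# dimensions than the output type expects, prepend axes until ndim matches.
import numpy as np

draws = np.array([0.3, 1.2, 0.7])   # e.g. what a SciPy sampler returned, shape (3,)
expected_ndim = 3                    # e.g. the output type expects shape (1, 1, 3)

missing_ndim = expected_ndim - draws.ndim
if missing_ndim:
    draws = np.expand_dims(draws, tuple(range(missing_ndim)))

assert draws.shape == (1, 1, 3)
```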


class UniformRV(RandomVariable):
@@ -423,7 +428,7 @@ class GammaRV(RandomVariable):
dtype = "floatX"
_print_name = ("Gamma", "\\operatorname{Gamma}")

def __call__(self, shape, scale, size=None, **kwargs):
def __call__(self, shape_param, scale, size=None, **kwargs):
r"""Draw samples from a gamma distribution.
Signature
@@ -433,7 +438,7 @@ def __call__(self, shape, scale, size=None, **kwargs):
Parameters
----------
shape
shape_param
The shape :math:`\alpha` of the gamma distribution. Must be positive.
scale
The scale :math:`1/\beta` of the gamma distribution. Must be positive.
@@ -444,7 +449,7 @@ def __call__(self, shape, scale, size=None, **kwargs):
is returned.
"""
return super().__call__(shape, scale, size=size, **kwargs)
return super().__call__(shape_param, scale, size=size, **kwargs)


_gamma = GammaRV()
@@ -672,7 +677,7 @@ class WeibullRV(RandomVariable):
dtype = "floatX"
_print_name = ("Weibull", "\\operatorname{Weibull}")

def __call__(self, shape, size=None, **kwargs):
def __call__(self, shape_param, size=None, **kwargs):
r"""Draw samples from a weibull distribution.
Signature
@@ -682,7 +687,7 @@ def __call__(self, shape, size=None, **kwargs):
Parameters
----------
shape
shape_param
The shape :math:`k` of the distribution. Must be positive.
size
Sample shape. If the given size is, e.g. `(m, n, k)` then `m * n * k`
@@ -691,7 +696,7 @@ def __call__(self, shape, size=None, **kwargs):
is returned.
"""
return super().__call__(shape, size=size, **kwargs)
return super().__call__(shape_param, size=size, **kwargs)


weibull = WeibullRV()
@@ -863,7 +868,7 @@ def __call__(self, mean=None, cov=None, size=None, **kwargs):
return super().__call__(mean, cov, size=size, **kwargs)

@classmethod
def rng_fn(cls, rng, mean, cov, size):
def rng_fn(cls, rng, mean, cov, size, core_shape=None):
if mean.ndim > 1 or cov.ndim > 2:
# Neither SciPy nor NumPy implement parameter broadcasting for
# multivariate normals (or any other multivariate distributions),
@@ -932,7 +937,7 @@ def __call__(self, alphas, size=None, **kwargs):
return super().__call__(alphas, size=size, **kwargs)

@classmethod
def rng_fn(cls, rng, alphas, size):
def rng_fn(cls, rng, alphas, size, core_shape=None):
if alphas.ndim > 1:
if size is None:
size = ()
@@ -1213,7 +1218,7 @@ class InvGammaRV(ScipyRandomVariable):
dtype = "floatX"
_print_name = ("InverseGamma", "\\operatorname{InverseGamma}")

def __call__(self, shape, scale, size=None, **kwargs):
def __call__(self, shape_param, scale, size=None, **kwargs):
r"""Draw samples from an inverse-gamma distribution.
Signature
@@ -1223,7 +1228,7 @@ def __call__(self, shape, scale, size=None, **kwargs):
Parameters
----------
shape
shape_param
Shape parameter :math:`\alpha` of the distribution. Must be positive.
scale
Scale parameter :math:`\beta` of the distribution. Must be
@@ -1234,7 +1239,7 @@ def __call__(self, shape, scale, size=None, **kwargs):
`None`, in which case a single sample is returned.
"""
return super().__call__(shape, scale, size=size, **kwargs)
return super().__call__(shape_param, scale, size=size, **kwargs)

@classmethod
def rng_fn_scipy(cls, rng, shape, scale, size):
@@ -1748,7 +1753,7 @@ def __call__(self, n, p, size=None, **kwargs):
return super().__call__(n, p, size=size, **kwargs)

@classmethod
def rng_fn(cls, rng, n, p, size):
def rng_fn(cls, rng, n, p, size, core_shape=None):
if n.ndim > 0 or p.ndim > 1:
size = tuple(size or ())

@@ -1812,7 +1817,7 @@ def __call__(self, p, size=None, **kwargs):
return super().__call__(p, size=size, **kwargs)

@classmethod
def rng_fn(cls, rng, p, size):
def rng_fn(cls, rng, p, size, core_shape=None):
if size is None:
size = p.shape[:-1]
else:
@@ -1901,10 +1906,10 @@ def __init__(self, *args, ndim_supp: int, p_none: bool, signature=None, **kwargs
def rng_fn(self, *params):
# Should we split into two Ops depending on p_none or not?
if self.p_none:
rng, a, replace, size = params
rng, a, replace, size, core_shape = params
p = None
else:
rng, a, p, replace, size = params
rng, a, p, replace, size, core_shape = params

batch_ndim = a.ndim - self.ndims_params[0]

Expand Down Expand Up @@ -1982,7 +1987,7 @@ class PermutationRV(RandomVariable):
_print_name = ("permutation", "\\operatorname{permutation}")

@classmethod
def rng_fn(cls, rng, x, size):
def rng_fn(cls, rng, x, size, core_shape=None):
return rng.permutation(x)

def __call__(self, x, dtype=None, **kwargs):