
Commit

Add gufunc_signature to SymbolicRandomVariables
ricardoV94 committed Feb 18, 2024
1 parent 0541065 commit 1a3744c
Showing 6 changed files with 157 additions and 32 deletions.
3 changes: 2 additions & 1 deletion pymc/distributions/censored.py
@@ -30,6 +30,8 @@ class CensoredRV(SymbolicRandomVariable):
"""Censored random variable"""

inline_logprob = True
signature = "(),(),()->()"
ndim_supp = 0
_print_name = ("Censored", "\\operatorname{Censored}")


@@ -115,7 +117,6 @@ def rv_op(cls, dist, lower=None, upper=None, size=None):
return CensoredRV(
inputs=[dist_, lower_, upper_],
outputs=[censored_rv_],
ndim_supp=0,
)(dist, lower, upper)


122 changes: 100 additions & 22 deletions pymc/distributions/distribution.py
@@ -33,11 +33,13 @@
from pytensor.graph.utils import MetaType
from pytensor.scan.op import Scan
from pytensor.tensor.basic import as_tensor_variable
from pytensor.tensor.blockwise import safe_signature
from pytensor.tensor.random.op import RandomVariable
from pytensor.tensor.random.rewriting import local_subtensor_rv_lift
from pytensor.tensor.random.type import RandomGeneratorType, RandomType
from pytensor.tensor.random.utils import normalize_size_param
from pytensor.tensor.rewriting.shape import ShapeFeature
from pytensor.tensor.utils import _parse_gufunc_signature
from pytensor.tensor.variable import TensorVariable
from typing_extensions import TypeAlias

@@ -249,6 +251,12 @@ class SymbolicRandomVariable(OpFromGraph):
(0 for scalar, 1 for vector, ...)
"""

ndims_params: Optional[Sequence[int]] = None
"""Number of core dimensions of the distribution's parameters."""

signature: Optional[str] = None
"""Numpy-like vectorized signature of the distribution."""

inline_logprob: bool = False
"""Specifies whether the logprob function is derived automatically by introspection
of the inner graph.
@@ -259,9 +267,25 @@ class SymbolicRandomVariable(OpFromGraph):
_print_name: tuple[str, str] = ("Unknown", "\\operatorname{Unknown}")
"""Tuple of (name, latex name) used for for pretty-printing variables of this type"""

def __init__(self, *args, ndim_supp, **kwargs):
"""Initialitze a SymbolicRandomVariable class."""
self.ndim_supp = ndim_supp
def __init__(
self,
*args,
**kwargs,
):
"""Initialize a SymbolicRandomVariable class."""
if self.signature is None:
self.signature = kwargs.get("signature", None)

if self.signature is not None:
inputs_sig, outputs_sig = _parse_gufunc_signature(self.signature)
self.ndims_params = [len(sig) for sig in inputs_sig]
self.ndim_supp = max(len(out_sig) for out_sig in outputs_sig)

if self.ndim_supp is None:
self.ndim_supp = kwargs.get("ndim_supp", None)
if self.ndim_supp is None:
raise ValueError("ndim_supp or gufunc_signature must be provided")

kwargs.setdefault("inline", True)
super().__init__(*args, **kwargs)
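
A minimal sketch of the intended usage (PairedDrawRV is a hypothetical subclass, not part of this commit): declare a signature and let __init__ derive ndims_params and ndim_supp.

    from pymc.distributions.distribution import SymbolicRandomVariable

    class PairedDrawRV(SymbolicRandomVariable):
        # two scalar-core parameters -> one vector-core output;
        # __init__ infers ndims_params == [0, 0] and ndim_supp == 1
        signature = "(),()->(n)"
        _print_name = ("PairedDraw", "\\operatorname{PairedDraw}")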

@@ -274,6 +298,11 @@ def update(self, node: Node):
"""
return {}

def batch_ndim(self, node: Node) -> int:
"""Number of dimensions of the distribution's batch shape."""
out_ndim = max(getattr(out.type, "ndim", 0) for out in node.outputs)
return out_ndim - self.ndim_supp
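
As a worked example: a vector-support variable (ndim_supp == 1) drawn with size (4, 3) produces an output of ndim 3, so batch_ndim(node) == 3 - 1 == 2, the number of leading batch dimensions.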


class Distribution(metaclass=DistributionMeta):
"""Statistical distribution"""
@@ -538,23 +567,29 @@ def dist(
logcdf: Optional[Callable] = None,
random: Optional[Callable] = None,
moment: Optional[Callable] = None,
ndim_supp: int = 0,
ndim_supp: Optional[int] = None,
ndims_params: Optional[Sequence[int]] = None,
signature: Optional[str] = None,
dtype: str = "floatX",
class_name: str = "CustomDist",
**kwargs,
):
if ndim_supp is None or ndims_params is None:
if signature is None:
ndim_supp = 0
ndims_params = [0] * len(dist_params)
else:
inputs, outputs = _parse_gufunc_signature(signature)
ndim_supp = max(len(out) for out in outputs)
ndims_params = [len(inp) for inp in inputs]

if ndim_supp > 0:
raise NotImplementedError(
"CustomDist with ndim_supp > 0 and without a `dist` function are not supported."
)
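
A usage sketch of this branch (a random callable but no dist function), assuming the CustomDist API as extended by this commit; the signature kwarg replaces explicit ndim_supp/ndims_params:

    import pymc as pm

    def draw(mu, sigma, rng=None, size=None):
        return rng.normal(loc=mu, scale=sigma, size=size)

    with pm.Model():
        # "(),()->()" marks both parameters and the draw as scalar-core,
        # so ndim_supp == 0 and ndims_params == [0, 0] are inferred
        x = pm.CustomDist("x", 0.0, 1.0, random=draw, signature="(),()->()")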

dist_params = [as_tensor_variable(param) for param in dist_params]

# Assume scalar ndims_params
if ndims_params is None:
ndims_params = [0] * len(dist_params)

if logp is None:
logp = default_not_implemented(class_name, "logp")

@@ -594,7 +629,7 @@ def rv_op(
random: Optional[Callable],
moment: Optional[Callable],
ndim_supp: int,
ndims_params: Optional[Sequence[int]],
ndims_params: Sequence[int],
dtype: str,
class_name: str,
**kwargs,
@@ -682,7 +717,9 @@ def dist(
logp: Optional[Callable] = None,
logcdf: Optional[Callable] = None,
moment: Optional[Callable] = None,
ndim_supp: int = 0,
ndim_supp: Optional[int] = None,
ndims_params: Optional[Sequence[int]] = None,
signature: Optional[str] = None,
dtype: str = "floatX",
class_name: str = "CustomDist",
**kwargs,
@@ -692,14 +729,24 @@
if logcdf is None:
logcdf = default_not_implemented(class_name, "logcdf")

if signature is None:
if ndim_supp is None:
ndim_supp = 0
if ndims_params is None:
ndims_params = [0] * len(dist_params)
signature = safe_signature(
core_inputs=[pt.tensor(shape=(None,) * ndim_param) for ndim_param in ndims_params],
core_outputs=[pt.tensor(shape=(None,) * ndim_supp)],
)
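
Roughly what this fallback produces (a sketch with made-up dimension names; the exact names safe_signature generates may differ):

    def sketch_safe_signature(ndims_params, ndim_supp):
        # one entry like "(i10,i11)" per parameter plus a single core output
        ins = ",".join(
            "(" + ",".join(f"i{n}{k}" for k in range(nd)) + ")"
            for n, nd in enumerate(ndims_params)
        )
        out = "(" + ",".join(f"o0{k}" for k in range(ndim_supp)) + ")"
        return f"{ins}->{out}"

    sketch_safe_signature([0, 1], 1)  # -> "(),(i10)->(o00)"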

return super().dist(
dist_params,
class_name=class_name,
logp=logp,
logcdf=logcdf,
dist=dist,
moment=moment,
ndim_supp=ndim_supp,
signature=signature,
**kwargs,
)

@@ -712,7 +759,7 @@ def rv_op(
logcdf: Optional[Callable],
moment: Optional[Callable],
size=None,
ndim_supp: int,
signature: str,
class_name: str,
):
size = normalize_size_param(size)
@@ -725,6 +772,24 @@
dummy_params = [dummy_size_param, *dummy_dist_params]
dummy_updates_dict = collect_default_updates(inputs=dummy_params, outputs=(dummy_rv,))

# Add size and updates to the gufunc signature if they are missing
input_sig, output_sig = signature.split("->")
# Numpy parser does not accept (constant) functions without inputs like "->()"
# We work around this, since such signatures make sense for distributions like Flat that have no inputs
if input_sig.strip() == "":
inputs = ()
_, outputs = _parse_gufunc_signature("()" + signature)
else:
inputs, outputs = _parse_gufunc_signature(signature)
if len(inputs) == len(dummy_params) - 1:
# Assume size is missing
input_sig = ("()," if input_sig else "()") + input_sig
n_updates = len(dummy_updates_dict)
if len(outputs) == 1:
# Assume updates are missing
output_sig = "()," * n_updates + output_sig
signature = "->".join((input_sig, output_sig))

rv_type = type(
class_name,
(CustomSymbolicDistRV,),
@@ -782,7 +847,7 @@ def change_custom_symbolic_dist_size(op, rv, new_size, expand):
new_rv_op = rv_type(
inputs=dummy_params,
outputs=[*dummy_updates_dict.values(), dummy_rv],
ndim_supp=ndim_supp,
signature=signature,
)
new_rv = new_rv_op(new_size, *dist_params)

@@ -791,7 +856,7 @@ def change_custom_symbolic_dist_size(op, rv, new_size, expand):
rv_op = rv_type(
inputs=dummy_params,
outputs=[*dummy_updates_dict.values(), dummy_rv],
ndim_supp=ndim_supp,
signature=signature,
)
return rv_op(size, *dist_params)

@@ -874,14 +939,18 @@ class CustomDist:
distribution parameters, in the same order as they were supplied when the
CustomDist was created. If ``None``, a default ``moment`` function will be
assigned that will always return 0, or an array of zeros.
ndim_supp : int
The number of dimensions in the support of the distribution. Defaults to assuming
a scalar distribution, i.e. ``ndim_supp = 0``.
ndim_supp : Optional[int]
The number of dimensions in the support of the distribution.
Inferred from signature, if provided. Defaults to assuming
a scalar distribution, i.e. ``ndim_supp = 0``.
ndims_params : Optional[Sequence[int]]
The list of number of dimensions in the support of each of the distribution's
parameters. If ``None``, it is assumed that all parameters are scalars, hence
the number of dimensions of their support will be 0. This is not needed if an
PyTensor dist function is provided.
parameters. Inferred from signature, if provided. Defaults to assuming
all parameters are scalars, i.e. ``ndims_params=[0, ...]``.
signature : Optional[str]
A numpy vectorize-like signature that indicates the number and core dimensionality
of the inputs and outputs of the CustomDist. When specified, `ndim_supp` and `ndims_params`
are not needed.
dtype : str
The dtype of the distribution. All draws and observations passed into the
distribution will be cast to this dtype. This is not needed if a PyTensor
@@ -1021,8 +1090,10 @@ def __new__(
logp: Optional[Callable] = None,
logcdf: Optional[Callable] = None,
moment: Optional[Callable] = None,
ndim_supp: int = 0,
# TODO: Deprecate ndim_supp / ndims_params in favor of signature?
ndim_supp: Optional[int] = None,
ndims_params: Optional[Sequence[int]] = None,
signature: Optional[str] = None,
dtype: str = "floatX",
**kwargs,
):
@@ -1046,6 +1117,8 @@ def __new__(
logcdf=logcdf,
moment=moment,
ndim_supp=ndim_supp,
ndims_params=ndims_params,
signature=signature,
**kwargs,
)
else:
@@ -1059,6 +1132,7 @@ def __new__(
moment=moment,
ndim_supp=ndim_supp,
ndims_params=ndims_params,
signature=signature,
dtype=dtype,
**kwargs,
)
@@ -1072,8 +1146,9 @@ def dist(
logp: Optional[Callable] = None,
logcdf: Optional[Callable] = None,
moment: Optional[Callable] = None,
ndim_supp: int = 0,
ndim_supp: Optional[int] = None,
ndims_params: Optional[Sequence[int]] = None,
signature: Optional[str] = None,
dtype: str = "floatX",
**kwargs,
):
@@ -1087,6 +1162,8 @@ def dist(
logcdf=logcdf,
moment=moment,
ndim_supp=ndim_supp,
ndims_params=ndims_params,
signature=signature,
**kwargs,
)
else:
@@ -1098,6 +1175,7 @@ def dist(
moment=moment,
ndim_supp=ndim_supp,
ndims_params=ndims_params,
signature=signature,
dtype=dtype,
**kwargs,
)
9 changes: 8 additions & 1 deletion pymc/distributions/mixture.py
@@ -296,10 +296,17 @@ def rv_op(cls, weights, *components, size=None):
# Output mix_indexes rng update so that it can be updated in place
mix_indexes_rng_next_ = mix_indexes_.owner.outputs[0]

s = ",".join(f"s{i}" for i in range(components[0].owner.op.ndim_supp))
if len(components) == 1:
    # The single component carries the mixture axis just before its support dims,
    # so prepend "w" to the support dims (joining (*s, "w") would split s into characters)
    comp_s = ",".join(("w", s)) if s else "w"
    signature = f"(),(w),({comp_s})->({s})"
else:
comps_s = ",".join(f"({s})" for _ in components)
signature = f"(),(w),{comps_s}->({s})"
mix_op = MarginalMixtureRV(
inputs=[mix_indexes_rng_, weights_, *components_],
outputs=[mix_indexes_rng_next_, mix_out_],
ndim_supp=components[0].owner.op.ndim_supp,
signature=signature,
)

# Create the actual MarginalMixture variable
9 changes: 6 additions & 3 deletions pymc/distributions/multivariate.py
@@ -1163,6 +1163,8 @@ def rng_fn(self, rng, n, eta, D, size):
# be safely resized. Because of this, we add the thin SymbolicRandomVariable wrapper
class _LKJCholeskyCovRV(SymbolicRandomVariable):
default_output = 1
signature = "(),(),(),(n)->(),(n)"
ndim_supp = 1
_print_name = ("_lkjcholeskycov", "\\operatorname{_lkjcholeskycov}")

def update(self, node):
@@ -1218,7 +1220,6 @@ def rv_op(cls, n, eta, sd_dist, size=None):
return _LKJCholeskyCovRV(
inputs=[rng_, n_, eta_, sd_dist_],
outputs=[next_rng_, lkjcov_],
ndim_supp=1,
)(rng, n, eta, sd_dist)


@@ -2787,10 +2788,12 @@ def rv_op(cls, sigma, n_zerosum_axes, support_shape, size=None):
for axis in range(n_zerosum_axes):
zerosum_rv_ -= zerosum_rv_.mean(axis=-axis - 1, keepdims=True)

support_str = ",".join([f"d{i}" for i in range(n_zerosum_axes)])
signature = f"({support_str}),(),(s)->({support_str})"
return ZeroSumNormalRV(
inputs=[normal_dist_, sigma_, support_shape_],
outputs=[zerosum_rv_, support_shape_],
ndim_supp=n_zerosum_axes,
outputs=[zerosum_rv_],
signature=signature,
)(normal_dist, sigma, support_shape)
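
A worked example of the signature built above: n_zerosum_axes == 2 gives support_str == "d0,d1" and signature == "(d0,d1),(),(s)->(d0,d1)", mapping the base normal draw, sigma, and the support_shape vector to the zero-sum-constrained draw (ndim_supp == 2 is now inferred rather than passed explicitly).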

