Skip to content

Commit

Permalink
Add gufunc_signature to SymbolicRandomVariables
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Feb 28, 2024
1 parent a2988c7 commit d81ec05
Show file tree
Hide file tree
Showing 6 changed files with 171 additions and 36 deletions.
3 changes: 2 additions & 1 deletion pymc/distributions/censored.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ class CensoredRV(SymbolicRandomVariable):
"""Censored random variable"""

inline_logprob = True
signature = "(),(),()->()"
ndim_supp = 0
_print_name = ("Censored", "\\operatorname{Censored}")


Expand Down Expand Up @@ -115,7 +117,6 @@ def rv_op(cls, dist, lower=None, upper=None, size=None):
return CensoredRV(
inputs=[dist_, lower_, upper_],
outputs=[censored_rv_],
ndim_supp=0,
)(dist, lower, upper)


Expand Down
140 changes: 114 additions & 26 deletions pymc/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,13 @@
from pytensor.graph.utils import MetaType
from pytensor.scan.op import Scan
from pytensor.tensor.basic import as_tensor_variable
from pytensor.tensor.blockwise import safe_signature
from pytensor.tensor.random.op import RandomVariable
from pytensor.tensor.random.rewriting import local_subtensor_rv_lift
from pytensor.tensor.random.type import RandomGeneratorType, RandomType
from pytensor.tensor.random.utils import normalize_size_param
from pytensor.tensor.rewriting.shape import ShapeFeature
from pytensor.tensor.utils import _parse_gufunc_signature
from pytensor.tensor.variable import TensorVariable
from typing_extensions import TypeAlias

Expand Down Expand Up @@ -261,6 +263,12 @@ class SymbolicRandomVariable(OpFromGraph):
(0 for scalar, 1 for vector, ...)
"""

ndims_params: Optional[Sequence[int]] = None
"""Number of core dimensions of the distribution's parameters."""

signature: str = None
"""Numpy-like vectorized signature of the distribution."""

inline_logprob: bool = False
"""Specifies whether the logprob function is derived automatically by introspection
of the inner graph.
Expand All @@ -271,9 +279,25 @@ class SymbolicRandomVariable(OpFromGraph):
_print_name: tuple[str, str] = ("Unknown", "\\operatorname{Unknown}")
"""Tuple of (name, latex name) used for for pretty-printing variables of this type"""

def __init__(self, *args, ndim_supp, **kwargs):
"""Initialitze a SymbolicRandomVariable class."""
self.ndim_supp = ndim_supp
def __init__(
self,
*args,
**kwargs,
):
"""Initialize a SymbolicRandomVariable class."""
if self.signature is None:
self.signature = kwargs.get("signature", None)

if self.signature is not None:
inputs_sig, outputs_sig = _parse_gufunc_signature(self.signature)
self.ndims_params = [len(sig) for sig in inputs_sig]
self.ndim_supp = max(len(out_sig) for out_sig in outputs_sig)

if self.ndim_supp is None:
self.ndim_supp = kwargs.get("ndim_supp", None)
if self.ndim_supp is None:
raise ValueError("ndim_supp or gufunc_signature must be provided")

Check warning on line 299 in pymc/distributions/distribution.py

View check run for this annotation

Codecov / codecov/patch

pymc/distributions/distribution.py#L299

Added line #L299 was not covered by tests

kwargs.setdefault("inline", True)
super().__init__(*args, **kwargs)

Expand All @@ -286,6 +310,11 @@ def update(self, node: Node):
"""
return {}

def batch_ndim(self, node: Node) -> int:
"""Number of dimensions of the distribution's batch shape."""
out_ndim = max(getattr(out.type, "ndim", 0) for out in node.outputs)
return out_ndim - self.ndim_supp

Check warning on line 316 in pymc/distributions/distribution.py

View check run for this annotation

Codecov / codecov/patch

pymc/distributions/distribution.py#L315-L316

Added lines #L315 - L316 were not covered by tests


class Distribution(metaclass=DistributionMeta):
"""Statistical distribution"""
Expand Down Expand Up @@ -558,23 +587,29 @@ def dist(
logcdf: Optional[Callable] = None,
random: Optional[Callable] = None,
support_point: Optional[Callable] = None,
ndim_supp: int = 0,
ndim_supp: Optional[int] = None,
ndims_params: Optional[Sequence[int]] = None,
signature: Optional[str] = None,
dtype: str = "floatX",
class_name: str = "CustomDist",
**kwargs,
):
if ndim_supp is None or ndims_params is None:
if signature is None:
ndim_supp = 0
ndims_params = [0] * len(dist_params)
else:
inputs, outputs = _parse_gufunc_signature(signature)
ndim_supp = max(len(out) for out in outputs)
ndims_params = [len(inp) for inp in inputs]

Check warning on line 604 in pymc/distributions/distribution.py

View check run for this annotation

Codecov / codecov/patch

pymc/distributions/distribution.py#L602-L604

Added lines #L602 - L604 were not covered by tests

if ndim_supp > 0:
raise NotImplementedError(
"CustomDist with ndim_supp > 0 and without a `dist` function are not supported."
)

dist_params = [as_tensor_variable(param) for param in dist_params]

# Assume scalar ndims_params
if ndims_params is None:
ndims_params = [0] * len(dist_params)

if logp is None:
logp = default_not_implemented(class_name, "logp")

Expand Down Expand Up @@ -614,7 +649,7 @@ def rv_op(
random: Optional[Callable],
support_point: Optional[Callable],
ndim_supp: int,
ndims_params: Optional[Sequence[int]],
ndims_params: Sequence[int],
dtype: str,
class_name: str,
**kwargs,
Expand Down Expand Up @@ -702,7 +737,9 @@ def dist(
logp: Optional[Callable] = None,
logcdf: Optional[Callable] = None,
support_point: Optional[Callable] = None,
ndim_supp: int = 0,
ndim_supp: Optional[int] = None,
ndims_params: Optional[Sequence[int]] = None,
signature: Optional[str] = None,
dtype: str = "floatX",
class_name: str = "CustomDist",
**kwargs,
Expand All @@ -712,14 +749,24 @@ def dist(
if logcdf is None:
logcdf = default_not_implemented(class_name, "logcdf")

if signature is None:
if ndim_supp is None:
ndim_supp = 0
if ndims_params is None:
ndims_params = [0] * len(dist_params)
signature = safe_signature(

Check warning on line 757 in pymc/distributions/distribution.py

View check run for this annotation

Codecov / codecov/patch

pymc/distributions/distribution.py#L752-L757

Added lines #L752 - L757 were not covered by tests
core_inputs=[pt.tensor(shape=(None,) * ndim_param) for ndim_param in ndims_params],
core_outputs=[pt.tensor(shape=(None,) * ndim_supp)],
)

return super().dist(
dist_params,
class_name=class_name,
logp=logp,
logcdf=logcdf,
dist=dist,
support_point=support_point,
ndim_supp=ndim_supp,
signature=signature,
**kwargs,
)

Expand All @@ -732,7 +779,7 @@ def rv_op(
logcdf: Optional[Callable],
support_point: Optional[Callable],
size=None,
ndim_supp: int,
signature: str,
class_name: str,
):
size = normalize_size_param(size)
Expand All @@ -745,6 +792,10 @@ def rv_op(
dummy_params = [dummy_size_param, *dummy_dist_params]
dummy_updates_dict = collect_default_updates(inputs=dummy_params, outputs=(dummy_rv,))

signature = cls._infer_final_signature(

Check warning on line 795 in pymc/distributions/distribution.py

View check run for this annotation

Codecov / codecov/patch

pymc/distributions/distribution.py#L795

Added line #L795 was not covered by tests
signature, len(dummy_params), len(dummy_updates_dict)
)

rv_type = type(
class_name,
(CustomSymbolicDistRV,),
Expand Down Expand Up @@ -802,7 +853,7 @@ def change_custom_symbolic_dist_size(op, rv, new_size, expand):
new_rv_op = rv_type(
inputs=dummy_params,
outputs=[*dummy_updates_dict.values(), dummy_rv],
ndim_supp=ndim_supp,
signature=signature,
)
new_rv = new_rv_op(new_size, *dist_params)

Expand All @@ -811,10 +862,30 @@ def change_custom_symbolic_dist_size(op, rv, new_size, expand):
rv_op = rv_type(
inputs=dummy_params,
outputs=[*dummy_updates_dict.values(), dummy_rv],
ndim_supp=ndim_supp,
signature=signature,
)
return rv_op(size, *dist_params)

@staticmethod
def _infer_final_signature(signature: str, n_inputs, n_updates) -> str:
"""Add size and updates to user provided gufunc signature if they are missing."""
input_sig, output_sig = signature.split("->")

Check warning on line 872 in pymc/distributions/distribution.py

View check run for this annotation

Codecov / codecov/patch

pymc/distributions/distribution.py#L872

Added line #L872 was not covered by tests
# Numpy parser does not accept (constant) functions without inputs like "->()"
# We work around as this makes sense for distributions like Flat that have no inputs
if input_sig.strip() == "":
inputs = ()
_, outputs = _parse_gufunc_signature("()" + signature)

Check warning on line 877 in pymc/distributions/distribution.py

View check run for this annotation

Codecov / codecov/patch

pymc/distributions/distribution.py#L875-L877

Added lines #L875 - L877 were not covered by tests
else:
inputs, outputs = _parse_gufunc_signature(signature)
if len(inputs) == n_inputs - 1:

Check warning on line 880 in pymc/distributions/distribution.py

View check run for this annotation

Codecov / codecov/patch

pymc/distributions/distribution.py#L879-L880

Added lines #L879 - L880 were not covered by tests
# Assume size is missing
input_sig = ("()," if input_sig else "()") + input_sig
if len(outputs) == 1:

Check warning on line 883 in pymc/distributions/distribution.py

View check run for this annotation

Codecov / codecov/patch

pymc/distributions/distribution.py#L882-L883

Added lines #L882 - L883 were not covered by tests
# Assume updates are missing
output_sig = "()," * n_updates + output_sig
signature = "->".join((input_sig, output_sig))
return signature

Check warning on line 887 in pymc/distributions/distribution.py

View check run for this annotation

Codecov / codecov/patch

pymc/distributions/distribution.py#L885-L887

Added lines #L885 - L887 were not covered by tests


class CustomDist:
"""A helper class to create custom distributions
Expand All @@ -828,12 +899,12 @@ class CustomDist:
when not provided by the user.
Alternatively, a user can provide a `random` function that returns numerical
draws (e.g., via NumPy routines), and a `logp` function that must return an
Python graph that represents the logp graph when evaluated. This is used for
draws (e.g., via NumPy routines), and a `logp` function that must return a
PyTensor graph that represents the logp graph when evaluated. This is used for
mcmc sampling.
Additionally, a user can provide a `logcdf` and `support_point` functions that must return
an PyTensor graph that computes those quantities. These may be used by other PyMC
PyTensor graphs that computes those quantities. These may be used by other PyMC
routines.
Parameters
Expand Down Expand Up @@ -894,14 +965,18 @@ class CustomDist:
distribution parameters, in the same order as they were supplied when the
CustomDist was created. If ``None``, a default ``support_point`` function will be
assigned that will always return 0, or an array of zeros.
ndim_supp : int
The number of dimensions in the support of the distribution. Defaults to assuming
a scalar distribution, i.e. ``ndim_supp = 0``.
ndim_supp : Optional[int]
The number of dimensions in the support of the distribution.
Inferred from signature, if provided. Defaults to assuming
a scalar distribution, i.e. ``ndim_supp = 0``
ndims_params : Optional[Sequence[int]]
The list of number of dimensions in the support of each of the distribution's
parameters. If ``None``, it is assumed that all parameters are scalars, hence
the number of dimensions of their support will be 0. This is not needed if an
PyTensor dist function is provided.
parameters. Inferred from signature, if provided. Defaults to assuming
all parameters are scalars, i.e. ``ndims_params=[0, ...]``.
signature : Optional[str]
A numpy vectorize-like signature that indicates the number and core dimensionality
of the input parameters and sample outputs of the CustomDist.
When specified, `ndim_supp` and `ndims_params` are not needed. See examples below.
dtype : str
The dtype of the distribution. All draws and observations passed into the
distribution will be cast onto this dtype. This is not needed if an PyTensor
Expand Down Expand Up @@ -939,6 +1014,7 @@ def logp(value: TensorVariable, mu: TensorVariable) -> TensorVariable:
Provide a random function that return numerical draws. This allows one to use a
CustomDist in prior and posterior predictive sampling.
A gufunc signature was also provided, which may be used by other routines.
.. code-block:: python
Expand All @@ -965,6 +1041,7 @@ def random(
mu,
logp=logp,
random=random,
signature="()->()",
observed=np.random.randn(100, 3),
size=(100, 3),
)
Expand All @@ -973,6 +1050,7 @@ def random(
Provide a dist function that creates a PyTensor graph built from other
PyMC distributions. PyMC can automatically infer that the logp of this
variable corresponds to a shifted Exponential distribution.
A gufunc signature was also provided, which may be used by other routines.
.. code-block:: python
Expand All @@ -994,6 +1072,7 @@ def dist(
lam,
shift,
dist=dist,
signature="(),()->()",
observed=[-1, -1, 0],
)
Expand Down Expand Up @@ -1040,10 +1119,11 @@ def __new__(
random: Optional[Callable] = None,
logp: Optional[Callable] = None,
logcdf: Optional[Callable] = None,
moment: Optional[Callable] = None,
support_point: Optional[Callable] = None,
ndim_supp: int = 0,
# TODO: Deprecate ndim_supp / ndims_params in favor of signature?
ndim_supp: Optional[int] = None,
ndims_params: Optional[Sequence[int]] = None,
signature: Optional[str] = None,
dtype: str = "floatX",
**kwargs,
):
Expand All @@ -1057,6 +1137,7 @@ def __new__(
)
dist_params = cls.parse_dist_params(dist_params)
cls.check_valid_dist_random(dist, random, dist_params)
moment = kwargs.get("moment", None)
if moment is not None:
warnings.warn(
"`moment` argument is deprecated. Use `support_point` instead.",
Expand All @@ -1073,6 +1154,8 @@ def __new__(
logcdf=logcdf,
support_point=support_point,
ndim_supp=ndim_supp,
ndims_params=ndims_params,
signature=signature,
**kwargs,
)
else:
Expand All @@ -1086,6 +1169,7 @@ def __new__(
support_point=support_point,
ndim_supp=ndim_supp,
ndims_params=ndims_params,
signature=signature,
dtype=dtype,
**kwargs,
)
Expand All @@ -1099,8 +1183,9 @@ def dist(
logp: Optional[Callable] = None,
logcdf: Optional[Callable] = None,
support_point: Optional[Callable] = None,
ndim_supp: int = 0,
ndim_supp: Optional[int] = None,
ndims_params: Optional[Sequence[int]] = None,
signature: Optional[str] = None,
dtype: str = "floatX",
**kwargs,
):
Expand All @@ -1114,6 +1199,8 @@ def dist(
logcdf=logcdf,
support_point=support_point,
ndim_supp=ndim_supp,
ndims_params=ndims_params,
signature=signature,
**kwargs,
)
else:
Expand All @@ -1125,6 +1212,7 @@ def dist(
support_point=support_point,
ndim_supp=ndim_supp,
ndims_params=ndims_params,
signature=signature,
dtype=dtype,
**kwargs,
)
Expand Down
9 changes: 8 additions & 1 deletion pymc/distributions/mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,10 +296,17 @@ def rv_op(cls, weights, *components, size=None):
# Output mix_indexes rng update so that it can be updated in place
mix_indexes_rng_next_ = mix_indexes_.owner.outputs[0]

s = ",".join(f"s{i}" for i in range(components[0].owner.op.ndim_supp))
if len(components) == 1:
comp_s = ",".join((*s, "w"))
signature = f"(),(w),({comp_s})->({s})"
else:
comps_s = ",".join(f"({s})" for _ in components)
signature = f"(),(w),{comps_s}->({s})"
mix_op = MarginalMixtureRV(
inputs=[mix_indexes_rng_, weights_, *components_],
outputs=[mix_indexes_rng_next_, mix_out_],
ndim_supp=components[0].owner.op.ndim_supp,
signature=signature,
)

# Create the actual MarginalMixture variable
Expand Down
Loading

0 comments on commit d81ec05

Please sign in to comment.