Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add gufunc signature to SymbolicRandomVariables #7159

Merged
merged 1 commit into from
Feb 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion pymc/distributions/censored.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ class CensoredRV(SymbolicRandomVariable):
"""Censored random variable"""

inline_logprob = True
signature = "(),(),()->()"
ndim_supp = 0
_print_name = ("Censored", "\\operatorname{Censored}")


Expand Down Expand Up @@ -115,7 +117,6 @@ def rv_op(cls, dist, lower=None, upper=None, size=None):
return CensoredRV(
inputs=[dist_, lower_, upper_],
outputs=[censored_rv_],
ndim_supp=0,
)(dist, lower, upper)


Expand Down
140 changes: 114 additions & 26 deletions pymc/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,13 @@
from pytensor.graph.utils import MetaType
from pytensor.scan.op import Scan
from pytensor.tensor.basic import as_tensor_variable
from pytensor.tensor.blockwise import safe_signature
from pytensor.tensor.random.op import RandomVariable
from pytensor.tensor.random.rewriting import local_subtensor_rv_lift
from pytensor.tensor.random.type import RandomGeneratorType, RandomType
from pytensor.tensor.random.utils import normalize_size_param
from pytensor.tensor.rewriting.shape import ShapeFeature
from pytensor.tensor.utils import _parse_gufunc_signature
from pytensor.tensor.variable import TensorVariable
from typing_extensions import TypeAlias

Expand Down Expand Up @@ -261,6 +263,12 @@
(0 for scalar, 1 for vector, ...)
"""

ndims_params: Optional[Sequence[int]] = None
"""Number of core dimensions of the distribution's parameters."""

signature: str = None
"""Numpy-like vectorized signature of the distribution."""

inline_logprob: bool = False
"""Specifies whether the logprob function is derived automatically by introspection
of the inner graph.
Expand All @@ -271,9 +279,25 @@
_print_name: tuple[str, str] = ("Unknown", "\\operatorname{Unknown}")
"""Tuple of (name, latex name) used for for pretty-printing variables of this type"""

def __init__(self, *args, ndim_supp, **kwargs):
"""Initialitze a SymbolicRandomVariable class."""
self.ndim_supp = ndim_supp
def __init__(
self,
*args,
**kwargs,
):
"""Initialize a SymbolicRandomVariable class."""
if self.signature is None:
self.signature = kwargs.get("signature", None)

if self.signature is not None:
inputs_sig, outputs_sig = _parse_gufunc_signature(self.signature)
self.ndims_params = [len(sig) for sig in inputs_sig]
self.ndim_supp = max(len(out_sig) for out_sig in outputs_sig)

if self.ndim_supp is None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is ndim_supp going to be deprecated? If so should we include a warning?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

May be deprecated but I didn't want to decide yet on that. Won't be deprecated without a warning beforehand

self.ndim_supp = kwargs.get("ndim_supp", None)
if self.ndim_supp is None:
raise ValueError("ndim_supp or gufunc_signature must be provided")

Check warning on line 299 in pymc/distributions/distribution.py

View check run for this annotation

Codecov / codecov/patch

pymc/distributions/distribution.py#L299

Added line #L299 was not covered by tests

kwargs.setdefault("inline", True)
super().__init__(*args, **kwargs)

Expand All @@ -286,6 +310,11 @@
"""
return {}

def batch_ndim(self, node: Node) -> int:
"""Number of dimensions of the distribution's batch shape."""
out_ndim = max(getattr(out.type, "ndim", 0) for out in node.outputs)
return out_ndim - self.ndim_supp

Check warning on line 316 in pymc/distributions/distribution.py

View check run for this annotation

Codecov / codecov/patch

pymc/distributions/distribution.py#L315-L316

Added lines #L315 - L316 were not covered by tests


class Distribution(metaclass=DistributionMeta):
"""Statistical distribution"""
Expand Down Expand Up @@ -558,23 +587,29 @@
logcdf: Optional[Callable] = None,
random: Optional[Callable] = None,
support_point: Optional[Callable] = None,
ndim_supp: int = 0,
ndim_supp: Optional[int] = None,
ndims_params: Optional[Sequence[int]] = None,
signature: Optional[str] = None,
dtype: str = "floatX",
class_name: str = "CustomDist",
**kwargs,
):
if ndim_supp is None or ndims_params is None:
if signature is None:
ndim_supp = 0
ndims_params = [0] * len(dist_params)
else:
inputs, outputs = _parse_gufunc_signature(signature)
ndim_supp = max(len(out) for out in outputs)
ndims_params = [len(inp) for inp in inputs]

Check warning on line 604 in pymc/distributions/distribution.py

View check run for this annotation

Codecov / codecov/patch

pymc/distributions/distribution.py#L602-L604

Added lines #L602 - L604 were not covered by tests

if ndim_supp > 0:
raise NotImplementedError(
"CustomDist with ndim_supp > 0 and without a `dist` function are not supported."
)

dist_params = [as_tensor_variable(param) for param in dist_params]

# Assume scalar ndims_params
if ndims_params is None:
ndims_params = [0] * len(dist_params)

if logp is None:
logp = default_not_implemented(class_name, "logp")

Expand Down Expand Up @@ -614,7 +649,7 @@
random: Optional[Callable],
support_point: Optional[Callable],
ndim_supp: int,
ndims_params: Optional[Sequence[int]],
ndims_params: Sequence[int],
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this no longer Optional while everywhere else ndim becomes optional

Copy link
Member Author

@ricardoV94 ricardoV94 Feb 21, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Users don't call this method, the user facing one dist and __new__ always defines ndim from the optional arguments and then passes it here

dtype: str,
class_name: str,
**kwargs,
Expand Down Expand Up @@ -702,7 +737,9 @@
logp: Optional[Callable] = None,
logcdf: Optional[Callable] = None,
support_point: Optional[Callable] = None,
ndim_supp: int = 0,
ndim_supp: Optional[int] = None,
ndims_params: Optional[Sequence[int]] = None,
signature: Optional[str] = None,
dtype: str = "floatX",
class_name: str = "CustomDist",
**kwargs,
Expand All @@ -712,14 +749,24 @@
if logcdf is None:
logcdf = default_not_implemented(class_name, "logcdf")

if signature is None:
if ndim_supp is None:
ndim_supp = 0
if ndims_params is None:
ndims_params = [0] * len(dist_params)
signature = safe_signature(
core_inputs=[pt.tensor(shape=(None,) * ndim_param) for ndim_param in ndims_params],
core_outputs=[pt.tensor(shape=(None,) * ndim_supp)],
)

return super().dist(
dist_params,
class_name=class_name,
logp=logp,
logcdf=logcdf,
dist=dist,
support_point=support_point,
ndim_supp=ndim_supp,
signature=signature,
**kwargs,
)

Expand All @@ -732,7 +779,7 @@
logcdf: Optional[Callable],
support_point: Optional[Callable],
size=None,
ndim_supp: int,
signature: str,
class_name: str,
):
size = normalize_size_param(size)
Expand All @@ -745,6 +792,10 @@
dummy_params = [dummy_size_param, *dummy_dist_params]
dummy_updates_dict = collect_default_updates(inputs=dummy_params, outputs=(dummy_rv,))

signature = cls._infer_final_signature(
signature, len(dummy_params), len(dummy_updates_dict)
)

rv_type = type(
class_name,
(CustomSymbolicDistRV,),
Expand Down Expand Up @@ -802,7 +853,7 @@
new_rv_op = rv_type(
inputs=dummy_params,
outputs=[*dummy_updates_dict.values(), dummy_rv],
ndim_supp=ndim_supp,
signature=signature,
)
new_rv = new_rv_op(new_size, *dist_params)

Expand All @@ -811,10 +862,30 @@
rv_op = rv_type(
inputs=dummy_params,
outputs=[*dummy_updates_dict.values(), dummy_rv],
ndim_supp=ndim_supp,
signature=signature,
)
return rv_op(size, *dist_params)

@staticmethod
def _infer_final_signature(signature: str, n_inputs, n_updates) -> str:
"""Add size and updates to user provided gufunc signature if they are missing."""
input_sig, output_sig = signature.split("->")
# Numpy parser does not accept (constant) functions without inputs like "->()"
# We work around as this makes sense for distributions like Flat that have no inputs
if input_sig.strip() == "":
inputs = ()
_, outputs = _parse_gufunc_signature("()" + signature)
else:
inputs, outputs = _parse_gufunc_signature(signature)
if len(inputs) == n_inputs - 1:
# Assume size is missing
input_sig = ("()," if input_sig else "()") + input_sig
if len(outputs) == 1:
# Assume updates are missing
output_sig = "()," * n_updates + output_sig
signature = "->".join((input_sig, output_sig))
return signature


class CustomDist:
"""A helper class to create custom distributions
Expand All @@ -828,12 +899,12 @@
when not provided by the user.

Alternatively, a user can provide a `random` function that returns numerical
draws (e.g., via NumPy routines), and a `logp` function that must return an
Python graph that represents the logp graph when evaluated. This is used for
draws (e.g., via NumPy routines), and a `logp` function that must return a
PyTensor graph that represents the logp graph when evaluated. This is used for
mcmc sampling.

Additionally, a user can provide a `logcdf` and `support_point` functions that must return
an PyTensor graph that computes those quantities. These may be used by other PyMC
PyTensor graphs that computes those quantities. These may be used by other PyMC
routines.

Parameters
Expand Down Expand Up @@ -894,14 +965,18 @@
distribution parameters, in the same order as they were supplied when the
CustomDist was created. If ``None``, a default ``support_point`` function will be
assigned that will always return 0, or an array of zeros.
ndim_supp : int
The number of dimensions in the support of the distribution. Defaults to assuming
a scalar distribution, i.e. ``ndim_supp = 0``.
ndim_supp : Optional[int]
The number of dimensions in the support of the distribution.
Inferred from signature, if provided. Defaults to assuming
a scalar distribution, i.e. ``ndim_supp = 0``
ndims_params : Optional[Sequence[int]]
The list of number of dimensions in the support of each of the distribution's
parameters. If ``None``, it is assumed that all parameters are scalars, hence
the number of dimensions of their support will be 0. This is not needed if an
PyTensor dist function is provided.
parameters. Inferred from signature, if provided. Defaults to assuming
all parameters are scalars, i.e. ``ndims_params=[0, ...]``.
signature : Optional[str]
ricardoV94 marked this conversation as resolved.
Show resolved Hide resolved
A numpy vectorize-like signature that indicates the number and core dimensionality
of the input parameters and sample outputs of the CustomDist.
When specified, `ndim_supp` and `ndims_params` are not needed. See examples below.
dtype : str
The dtype of the distribution. All draws and observations passed into the
distribution will be cast onto this dtype. This is not needed if an PyTensor
Expand Down Expand Up @@ -939,6 +1014,7 @@

Provide a random function that return numerical draws. This allows one to use a
CustomDist in prior and posterior predictive sampling.
A gufunc signature was also provided, which may be used by other routines.

.. code-block:: python

Expand All @@ -965,6 +1041,7 @@
mu,
logp=logp,
random=random,
signature="()->()",
observed=np.random.randn(100, 3),
size=(100, 3),
)
Expand All @@ -973,6 +1050,7 @@
Provide a dist function that creates a PyTensor graph built from other
PyMC distributions. PyMC can automatically infer that the logp of this
variable corresponds to a shifted Exponential distribution.
A gufunc signature was also provided, which may be used by other routines.

.. code-block:: python

Expand All @@ -994,6 +1072,7 @@
lam,
shift,
dist=dist,
signature="(),()->()",
observed=[-1, -1, 0],
)

Expand Down Expand Up @@ -1040,10 +1119,11 @@
random: Optional[Callable] = None,
logp: Optional[Callable] = None,
logcdf: Optional[Callable] = None,
moment: Optional[Callable] = None,
support_point: Optional[Callable] = None,
ndim_supp: int = 0,
# TODO: Deprecate ndim_supp / ndims_params in favor of signature?
ndim_supp: Optional[int] = None,
ndims_params: Optional[Sequence[int]] = None,
signature: Optional[str] = None,
dtype: str = "floatX",
**kwargs,
):
Expand All @@ -1057,6 +1137,7 @@
)
dist_params = cls.parse_dist_params(dist_params)
cls.check_valid_dist_random(dist, random, dist_params)
moment = kwargs.pop("moment", None)
if moment is not None:
warnings.warn(
"`moment` argument is deprecated. Use `support_point` instead.",
Expand All @@ -1073,6 +1154,8 @@
logcdf=logcdf,
support_point=support_point,
ndim_supp=ndim_supp,
ndims_params=ndims_params,
signature=signature,
**kwargs,
)
else:
Expand All @@ -1086,6 +1169,7 @@
support_point=support_point,
ndim_supp=ndim_supp,
ndims_params=ndims_params,
signature=signature,
dtype=dtype,
**kwargs,
)
Expand All @@ -1099,8 +1183,9 @@
logp: Optional[Callable] = None,
logcdf: Optional[Callable] = None,
support_point: Optional[Callable] = None,
ndim_supp: int = 0,
ndim_supp: Optional[int] = None,
ndims_params: Optional[Sequence[int]] = None,
signature: Optional[str] = None,
dtype: str = "floatX",
**kwargs,
):
Expand All @@ -1114,6 +1199,8 @@
logcdf=logcdf,
support_point=support_point,
ndim_supp=ndim_supp,
ndims_params=ndims_params,
signature=signature,
**kwargs,
)
else:
Expand All @@ -1125,6 +1212,7 @@
support_point=support_point,
ndim_supp=ndim_supp,
ndims_params=ndims_params,
signature=signature,
dtype=dtype,
**kwargs,
)
Expand Down
9 changes: 8 additions & 1 deletion pymc/distributions/mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,10 +296,17 @@ def rv_op(cls, weights, *components, size=None):
# Output mix_indexes rng update so that it can be updated in place
mix_indexes_rng_next_ = mix_indexes_.owner.outputs[0]

s = ",".join(f"s{i}" for i in range(components[0].owner.op.ndim_supp))
if len(components) == 1:
comp_s = ",".join((*s, "w"))
signature = f"(),(w),({comp_s})->({s})"
else:
comps_s = ",".join(f"({s})" for _ in components)
signature = f"(),(w),{comps_s}->({s})"
mix_op = MarginalMixtureRV(
inputs=[mix_indexes_rng_, weights_, *components_],
outputs=[mix_indexes_rng_next_, mix_out_],
ndim_supp=components[0].owner.op.ndim_supp,
signature=signature,
)

# Create the actual MarginalMixture variable
Expand Down
Loading
Loading