-
-
Notifications
You must be signed in to change notification settings - Fork 2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add gufunc signature to SymbolicRandomVariables #7159
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -33,11 +33,13 @@ | |
from pytensor.graph.utils import MetaType | ||
from pytensor.scan.op import Scan | ||
from pytensor.tensor.basic import as_tensor_variable | ||
from pytensor.tensor.blockwise import safe_signature | ||
from pytensor.tensor.random.op import RandomVariable | ||
from pytensor.tensor.random.rewriting import local_subtensor_rv_lift | ||
from pytensor.tensor.random.type import RandomGeneratorType, RandomType | ||
from pytensor.tensor.random.utils import normalize_size_param | ||
from pytensor.tensor.rewriting.shape import ShapeFeature | ||
from pytensor.tensor.utils import _parse_gufunc_signature | ||
from pytensor.tensor.variable import TensorVariable | ||
from typing_extensions import TypeAlias | ||
|
||
|
@@ -261,6 +263,12 @@ | |
(0 for scalar, 1 for vector, ...) | ||
""" | ||
|
||
ndims_params: Optional[Sequence[int]] = None | ||
"""Number of core dimensions of the distribution's parameters.""" | ||
|
||
signature: str = None | ||
"""Numpy-like vectorized signature of the distribution.""" | ||
|
||
inline_logprob: bool = False | ||
"""Specifies whether the logprob function is derived automatically by introspection | ||
of the inner graph. | ||
|
@@ -271,9 +279,25 @@ | |
_print_name: tuple[str, str] = ("Unknown", "\\operatorname{Unknown}") | ||
"""Tuple of (name, latex name) used for for pretty-printing variables of this type""" | ||
|
||
def __init__(self, *args, ndim_supp, **kwargs): | ||
"""Initialitze a SymbolicRandomVariable class.""" | ||
self.ndim_supp = ndim_supp | ||
def __init__( | ||
self, | ||
*args, | ||
**kwargs, | ||
): | ||
"""Initialize a SymbolicRandomVariable class.""" | ||
if self.signature is None: | ||
self.signature = kwargs.get("signature", None) | ||
|
||
if self.signature is not None: | ||
inputs_sig, outputs_sig = _parse_gufunc_signature(self.signature) | ||
self.ndims_params = [len(sig) for sig in inputs_sig] | ||
self.ndim_supp = max(len(out_sig) for out_sig in outputs_sig) | ||
|
||
if self.ndim_supp is None: | ||
self.ndim_supp = kwargs.get("ndim_supp", None) | ||
if self.ndim_supp is None: | ||
raise ValueError("ndim_supp or gufunc_signature must be provided") | ||
|
||
kwargs.setdefault("inline", True) | ||
super().__init__(*args, **kwargs) | ||
|
||
|
@@ -286,6 +310,11 @@ | |
""" | ||
return {} | ||
|
||
def batch_ndim(self, node: Node) -> int: | ||
"""Number of dimensions of the distribution's batch shape.""" | ||
out_ndim = max(getattr(out.type, "ndim", 0) for out in node.outputs) | ||
return out_ndim - self.ndim_supp | ||
|
||
|
||
class Distribution(metaclass=DistributionMeta): | ||
"""Statistical distribution""" | ||
|
@@ -558,23 +587,29 @@ | |
logcdf: Optional[Callable] = None, | ||
random: Optional[Callable] = None, | ||
support_point: Optional[Callable] = None, | ||
ndim_supp: int = 0, | ||
ndim_supp: Optional[int] = None, | ||
ndims_params: Optional[Sequence[int]] = None, | ||
signature: Optional[str] = None, | ||
dtype: str = "floatX", | ||
class_name: str = "CustomDist", | ||
**kwargs, | ||
): | ||
if ndim_supp is None or ndims_params is None: | ||
if signature is None: | ||
ndim_supp = 0 | ||
ndims_params = [0] * len(dist_params) | ||
else: | ||
inputs, outputs = _parse_gufunc_signature(signature) | ||
ndim_supp = max(len(out) for out in outputs) | ||
ndims_params = [len(inp) for inp in inputs] | ||
|
||
if ndim_supp > 0: | ||
raise NotImplementedError( | ||
"CustomDist with ndim_supp > 0 and without a `dist` function are not supported." | ||
) | ||
|
||
dist_params = [as_tensor_variable(param) for param in dist_params] | ||
|
||
# Assume scalar ndims_params | ||
if ndims_params is None: | ||
ndims_params = [0] * len(dist_params) | ||
|
||
if logp is None: | ||
logp = default_not_implemented(class_name, "logp") | ||
|
||
|
@@ -614,7 +649,7 @@ | |
random: Optional[Callable], | ||
support_point: Optional[Callable], | ||
ndim_supp: int, | ||
ndims_params: Optional[Sequence[int]], | ||
ndims_params: Sequence[int], | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why is this no longer There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Users don't call this method, the user facing one |
||
dtype: str, | ||
class_name: str, | ||
**kwargs, | ||
|
@@ -702,7 +737,9 @@ | |
logp: Optional[Callable] = None, | ||
logcdf: Optional[Callable] = None, | ||
support_point: Optional[Callable] = None, | ||
ndim_supp: int = 0, | ||
ndim_supp: Optional[int] = None, | ||
ndims_params: Optional[Sequence[int]] = None, | ||
signature: Optional[str] = None, | ||
dtype: str = "floatX", | ||
class_name: str = "CustomDist", | ||
**kwargs, | ||
|
@@ -712,14 +749,24 @@ | |
if logcdf is None: | ||
logcdf = default_not_implemented(class_name, "logcdf") | ||
|
||
if signature is None: | ||
if ndim_supp is None: | ||
ndim_supp = 0 | ||
if ndims_params is None: | ||
ndims_params = [0] * len(dist_params) | ||
signature = safe_signature( | ||
core_inputs=[pt.tensor(shape=(None,) * ndim_param) for ndim_param in ndims_params], | ||
core_outputs=[pt.tensor(shape=(None,) * ndim_supp)], | ||
) | ||
|
||
return super().dist( | ||
dist_params, | ||
class_name=class_name, | ||
logp=logp, | ||
logcdf=logcdf, | ||
dist=dist, | ||
support_point=support_point, | ||
ndim_supp=ndim_supp, | ||
signature=signature, | ||
**kwargs, | ||
) | ||
|
||
|
@@ -732,7 +779,7 @@ | |
logcdf: Optional[Callable], | ||
support_point: Optional[Callable], | ||
size=None, | ||
ndim_supp: int, | ||
signature: str, | ||
class_name: str, | ||
): | ||
size = normalize_size_param(size) | ||
|
@@ -745,6 +792,10 @@ | |
dummy_params = [dummy_size_param, *dummy_dist_params] | ||
dummy_updates_dict = collect_default_updates(inputs=dummy_params, outputs=(dummy_rv,)) | ||
|
||
signature = cls._infer_final_signature( | ||
signature, len(dummy_params), len(dummy_updates_dict) | ||
) | ||
|
||
rv_type = type( | ||
class_name, | ||
(CustomSymbolicDistRV,), | ||
|
@@ -802,7 +853,7 @@ | |
new_rv_op = rv_type( | ||
inputs=dummy_params, | ||
outputs=[*dummy_updates_dict.values(), dummy_rv], | ||
ndim_supp=ndim_supp, | ||
signature=signature, | ||
) | ||
new_rv = new_rv_op(new_size, *dist_params) | ||
|
||
|
@@ -811,10 +862,30 @@ | |
rv_op = rv_type( | ||
inputs=dummy_params, | ||
outputs=[*dummy_updates_dict.values(), dummy_rv], | ||
ndim_supp=ndim_supp, | ||
signature=signature, | ||
) | ||
return rv_op(size, *dist_params) | ||
|
||
@staticmethod | ||
def _infer_final_signature(signature: str, n_inputs, n_updates) -> str: | ||
"""Add size and updates to user provided gufunc signature if they are missing.""" | ||
input_sig, output_sig = signature.split("->") | ||
# Numpy parser does not accept (constant) functions without inputs like "->()" | ||
# We work around as this makes sense for distributions like Flat that have no inputs | ||
if input_sig.strip() == "": | ||
inputs = () | ||
_, outputs = _parse_gufunc_signature("()" + signature) | ||
else: | ||
inputs, outputs = _parse_gufunc_signature(signature) | ||
if len(inputs) == n_inputs - 1: | ||
# Assume size is missing | ||
input_sig = ("()," if input_sig else "()") + input_sig | ||
if len(outputs) == 1: | ||
# Assume updates are missing | ||
output_sig = "()," * n_updates + output_sig | ||
signature = "->".join((input_sig, output_sig)) | ||
return signature | ||
|
||
|
||
class CustomDist: | ||
"""A helper class to create custom distributions | ||
|
@@ -828,12 +899,12 @@ | |
when not provided by the user. | ||
|
||
Alternatively, a user can provide a `random` function that returns numerical | ||
draws (e.g., via NumPy routines), and a `logp` function that must return an | ||
Python graph that represents the logp graph when evaluated. This is used for | ||
draws (e.g., via NumPy routines), and a `logp` function that must return a | ||
PyTensor graph that represents the logp graph when evaluated. This is used for | ||
mcmc sampling. | ||
|
||
Additionally, a user can provide a `logcdf` and `support_point` functions that must return | ||
an PyTensor graph that computes those quantities. These may be used by other PyMC | ||
PyTensor graphs that computes those quantities. These may be used by other PyMC | ||
routines. | ||
|
||
Parameters | ||
|
@@ -894,14 +965,18 @@ | |
distribution parameters, in the same order as they were supplied when the | ||
CustomDist was created. If ``None``, a default ``support_point`` function will be | ||
assigned that will always return 0, or an array of zeros. | ||
ndim_supp : int | ||
The number of dimensions in the support of the distribution. Defaults to assuming | ||
a scalar distribution, i.e. ``ndim_supp = 0``. | ||
ndim_supp : Optional[int] | ||
The number of dimensions in the support of the distribution. | ||
Inferred from signature, if provided. Defaults to assuming | ||
a scalar distribution, i.e. ``ndim_supp = 0`` | ||
ndims_params : Optional[Sequence[int]] | ||
The list of number of dimensions in the support of each of the distribution's | ||
parameters. If ``None``, it is assumed that all parameters are scalars, hence | ||
the number of dimensions of their support will be 0. This is not needed if an | ||
PyTensor dist function is provided. | ||
parameters. Inferred from signature, if provided. Defaults to assuming | ||
all parameters are scalars, i.e. ``ndims_params=[0, ...]``. | ||
signature : Optional[str] | ||
ricardoV94 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
A numpy vectorize-like signature that indicates the number and core dimensionality | ||
of the input parameters and sample outputs of the CustomDist. | ||
When specified, `ndim_supp` and `ndims_params` are not needed. See examples below. | ||
dtype : str | ||
The dtype of the distribution. All draws and observations passed into the | ||
distribution will be cast onto this dtype. This is not needed if an PyTensor | ||
|
@@ -939,6 +1014,7 @@ | |
|
||
Provide a random function that return numerical draws. This allows one to use a | ||
CustomDist in prior and posterior predictive sampling. | ||
A gufunc signature was also provided, which may be used by other routines. | ||
|
||
.. code-block:: python | ||
|
||
|
@@ -965,6 +1041,7 @@ | |
mu, | ||
logp=logp, | ||
random=random, | ||
signature="()->()", | ||
observed=np.random.randn(100, 3), | ||
size=(100, 3), | ||
) | ||
|
@@ -973,6 +1050,7 @@ | |
Provide a dist function that creates a PyTensor graph built from other | ||
PyMC distributions. PyMC can automatically infer that the logp of this | ||
variable corresponds to a shifted Exponential distribution. | ||
A gufunc signature was also provided, which may be used by other routines. | ||
|
||
.. code-block:: python | ||
|
||
|
@@ -994,6 +1072,7 @@ | |
lam, | ||
shift, | ||
dist=dist, | ||
signature="(),()->()", | ||
observed=[-1, -1, 0], | ||
) | ||
|
||
|
@@ -1040,10 +1119,11 @@ | |
random: Optional[Callable] = None, | ||
logp: Optional[Callable] = None, | ||
logcdf: Optional[Callable] = None, | ||
moment: Optional[Callable] = None, | ||
support_point: Optional[Callable] = None, | ||
ndim_supp: int = 0, | ||
# TODO: Deprecate ndim_supp / ndims_params in favor of signature? | ||
ndim_supp: Optional[int] = None, | ||
ndims_params: Optional[Sequence[int]] = None, | ||
signature: Optional[str] = None, | ||
dtype: str = "floatX", | ||
**kwargs, | ||
): | ||
|
@@ -1057,6 +1137,7 @@ | |
) | ||
dist_params = cls.parse_dist_params(dist_params) | ||
cls.check_valid_dist_random(dist, random, dist_params) | ||
moment = kwargs.pop("moment", None) | ||
if moment is not None: | ||
warnings.warn( | ||
"`moment` argument is deprecated. Use `support_point` instead.", | ||
|
@@ -1073,6 +1154,8 @@ | |
logcdf=logcdf, | ||
support_point=support_point, | ||
ndim_supp=ndim_supp, | ||
ndims_params=ndims_params, | ||
signature=signature, | ||
**kwargs, | ||
) | ||
else: | ||
|
@@ -1086,6 +1169,7 @@ | |
support_point=support_point, | ||
ndim_supp=ndim_supp, | ||
ndims_params=ndims_params, | ||
signature=signature, | ||
dtype=dtype, | ||
**kwargs, | ||
) | ||
|
@@ -1099,8 +1183,9 @@ | |
logp: Optional[Callable] = None, | ||
logcdf: Optional[Callable] = None, | ||
support_point: Optional[Callable] = None, | ||
ndim_supp: int = 0, | ||
ndim_supp: Optional[int] = None, | ||
ndims_params: Optional[Sequence[int]] = None, | ||
signature: Optional[str] = None, | ||
dtype: str = "floatX", | ||
**kwargs, | ||
): | ||
|
@@ -1114,6 +1199,8 @@ | |
logcdf=logcdf, | ||
support_point=support_point, | ||
ndim_supp=ndim_supp, | ||
ndims_params=ndims_params, | ||
signature=signature, | ||
**kwargs, | ||
) | ||
else: | ||
|
@@ -1125,6 +1212,7 @@ | |
support_point=support_point, | ||
ndim_supp=ndim_supp, | ||
ndims_params=ndims_params, | ||
signature=signature, | ||
dtype=dtype, | ||
**kwargs, | ||
) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is
ndim_supp
going to be deprecated? If so should we include a warning?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
May be deprecated but I didn't want to decide yet on that. Won't be deprecated without a warning beforehand