Skip to content

Commit

Permalink
Add gufunc signature to pre-build CustomSymbolicDistributions
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Feb 16, 2024
1 parent 0d8ddba commit 217c3c4
Show file tree
Hide file tree
Showing 5 changed files with 47 additions and 11 deletions.
2 changes: 1 addition & 1 deletion pymc/distributions/censored.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def rv_op(cls, dist, lower=None, upper=None, size=None):
return CensoredRV(
inputs=[dist_, lower_, upper_],
outputs=[censored_rv_],
ndim_supp=0,
gufunc_signature="(),(),()->()",
)(dist, lower, upper)


Expand Down
26 changes: 23 additions & 3 deletions pymc/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from pytensor.tensor.random.type import RandomGeneratorType, RandomType
from pytensor.tensor.random.utils import normalize_size_param
from pytensor.tensor.rewriting.shape import ShapeFeature
from pytensor.tensor.utils import _parse_gufunc_signature
from pytensor.tensor.variable import TensorVariable
from typing_extensions import TypeAlias

Expand Down Expand Up @@ -256,11 +257,28 @@ class SymbolicRandomVariable(OpFromGraph):
If `False`, a logprob function must be dispatched directly to the subclass type.
"""

signature: str = None
"""Numpy-like vectorized signature of the distribution."""

_print_name: tuple[str, str] = ("Unknown", "\\operatorname{Unknown}")
"""Tuple of (name, latex name) used for for pretty-printing variables of this type"""

def __init__(self, *args, ndim_supp, **kwargs):
"""Initialitze a SymbolicRandomVariable class."""
def __init__(
self,
*args,
ndim_supp: Optional[int] = None,
gufunc_signature: Optional[str] = None,
**kwargs,
):
"""Initialize a SymbolicRandomVariable class."""
if gufunc_signature is not None:
self.gufunc_signature = gufunc_signature
if ndim_supp is None:
if gufunc_signature is not None:
_, outputs_sig = _parse_gufunc_signature(gufunc_signature)
ndim_supp = max(len(out_sig) for out_sig in outputs_sig)
else:
raise ValueError("ndim_supp must be specified if gufunc_signature is not.")

Check warning on line 281 in pymc/distributions/distribution.py

View check run for this annotation

Codecov / codecov/patch

pymc/distributions/distribution.py#L281

Added line #L281 was not covered by tests
self.ndim_supp = ndim_supp
kwargs.setdefault("inline", True)
super().__init__(*args, **kwargs)
Expand Down Expand Up @@ -682,7 +700,8 @@ def dist(
logp: Optional[Callable] = None,
logcdf: Optional[Callable] = None,
moment: Optional[Callable] = None,
ndim_supp: int = 0,
ndim_supp: Optional[int] = None,
gufunc_signature: Optional[str] = None,
dtype: str = "floatX",
class_name: str = "CustomDist",
**kwargs,
Expand All @@ -700,6 +719,7 @@ def dist(
dist=dist,
moment=moment,
ndim_supp=ndim_supp,
gufunc_signature=gufunc_signature,
**kwargs,
)

Expand Down
9 changes: 8 additions & 1 deletion pymc/distributions/mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,10 +297,17 @@ def rv_op(cls, weights, *components, size=None):
# Output mix_indexes rng update so that it can be updated in place
mix_indexes_rng_next_ = mix_indexes_.owner.outputs[0]

s = ",".join(f"s{i}" for i in range(components[0].owner.op.ndim_supp))
if len(components) == 1:
comp_s = ",".join((*s, "w"))
gufunc_signature = f"(),(w),({comp_s})->({s})"
else:
comps_s = ",".join(f"({s})" for _ in components)
gufunc_signature = f"(),(w),{comps_s}->({s})"
mix_op = MarginalMixtureRV(
inputs=[mix_indexes_rng_, weights_, *components_],
outputs=[mix_indexes_rng_next_, mix_out_],
ndim_supp=components[0].owner.op.ndim_supp,
gufunc_signature=gufunc_signature,
)

# Create the actual MarginalMixture variable
Expand Down
8 changes: 5 additions & 3 deletions pymc/distributions/multivariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -1221,7 +1221,7 @@ def rv_op(cls, n, eta, sd_dist, size=None):
return _LKJCholeskyCovRV(
inputs=[rng_, n_, eta_, sd_dist_],
outputs=[next_rng_, lkjcov_],
ndim_supp=1,
gufunc_signature="(),(),(),(n)->(),(n)",
)(rng, n, eta, sd_dist)


Expand Down Expand Up @@ -2790,10 +2790,12 @@ def rv_op(cls, sigma, n_zerosum_axes, support_shape, size=None):
for axis in range(n_zerosum_axes):
zerosum_rv_ -= zerosum_rv_.mean(axis=-axis - 1, keepdims=True)

support_str = ",".join([f"d{i}" for i in range(n_zerosum_axes)])
gufunc_signature = f"({support_str}),(),(s)->({support_str})"
return ZeroSumNormalRV(
inputs=[normal_dist_, sigma_, support_shape_],
outputs=[zerosum_rv_, support_shape_],
ndim_supp=n_zerosum_axes,
outputs=[zerosum_rv_],
gufunc_signature=gufunc_signature,
)(normal_dist, sigma, support_shape)


Expand Down
13 changes: 10 additions & 3 deletions pymc/distributions/timeseries.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,12 +195,17 @@ def rv_op(cls, init_dist, innovation_dist, steps, size=None):
# shape = (B, T, S)
grw_ = pt.concatenate([init_dist_dimswapped_, innovation_dist_dimswapped_], axis=-ndim_supp)
grw_ = pt.cumsum(grw_, axis=-ndim_supp)

innov_supp_dims = [f"d{i}" for i in range(ndim_supp)]
innov_supp_str = ",".join(innov_supp_dims)
out_supp_str = ",".join(["t", *innov_supp_dims])
gufunc_signature = f"({innov_supp_str}),({innov_supp_str}),(s)->({out_supp_str})"

Check warning on line 202 in pymc/distributions/timeseries.py

View check run for this annotation

Codecov / codecov/patch

pymc/distributions/timeseries.py#L199-L202

Added lines #L199 - L202 were not covered by tests
return RandomWalkRV(
[init_dist_, innovation_dist_, steps_],
# We pass steps_ through just so we can keep a reference to it, even though
# it's no longer needed at this point
[grw_, steps_],
ndim_supp=ndim_supp,
[grw_],
gufunc_signature=gufunc_signature,
)(init_dist, innovation_dist, steps)


Expand Down Expand Up @@ -655,6 +660,7 @@ def step(*args):
outputs=[noise_next_rng, ar_],
ar_order=ar_order,
constant_term=constant_term,
gufunc_signature="(o),(),(o),(s)->(),(t)",
ndim_supp=1,
)

Expand Down Expand Up @@ -825,7 +831,7 @@ def step(prev_y, prev_sigma, omega, alpha_1, beta_1, rng):
garch11_op = GARCH11RV(
inputs=[omega_, alpha_1_, beta_1_, initial_vol_, init_, steps_],
outputs=[noise_next_rng, garch11_],
ndim_supp=1,
gufunc_signature="(),(),(),(),(),(s)->(),(t)",
)

garch11 = garch11_op(omega, alpha_1, beta_1, initial_vol, init_dist, steps)
Expand Down Expand Up @@ -1006,6 +1012,7 @@ def step(*prev_args):
outputs=[noise_next_rng, sde_out_],
dt=dt,
sde_fn=sde_fn,
gufunc_signature=f"(),(s),{','.join('()' for _ in sde_pars_)}->(),(t)",
ndim_supp=1,
)

Expand Down

0 comments on commit 217c3c4

Please sign in to comment.