Skip to content

Commit

Permalink
Allow truncation of self-contained SymbolicRandomVariables
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Aug 24, 2024
1 parent 3cf860d commit 4c0f4b7
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 8 deletions.
20 changes: 13 additions & 7 deletions pymc/distributions/truncated.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
from pytensor.tensor.random.type import RandomType

from pymc.distributions.continuous import TruncatedNormal, bounded_cont_transform
from pymc.distributions.custom import CustomSymbolicDistRV
from pymc.distributions.dist_math import check_parameters
from pymc.distributions.distribution import (
Distribution,
Expand Down Expand Up @@ -302,17 +301,24 @@ class Truncated(Distribution):
def dist(cls, dist, lower=None, upper=None, max_n_steps: int = 10_000, **kwargs):
if not (
isinstance(dist, TensorVariable)
and isinstance(dist.owner.op, RandomVariable | CustomSymbolicDistRV)
and dist.owner is not None
and isinstance(dist.owner.op, RandomVariable | SymbolicRandomVariable)
):
if isinstance(dist.owner.op, SymbolicRandomVariable):
raise NotImplementedError(
f"Truncation not implemented for SymbolicRandomVariable {dist.owner.op}.\n"
f"You can try wrapping the distribution inside a CustomDist instead."
)
raise ValueError(
f"Truncation dist must be a distribution created via the `.dist()` API, got {type(dist)}"
)

if (
isinstance(dist.owner.op, SymbolicRandomVariable)
and "[size]" not in dist.owner.op.extended_signature
):
# Truncation needs to wrap the underlying dist, but not all SymbolicRandomVariables encapsulate the whole
# random graph and as such we don't know where the actual inputs begin. This happens mostly for
# distribution factories like `Censored` and `Mixture` which would have a very complex signature if they
# encapsulated the random components instead of taking them as inputs like they do now.
# SymbolicRandomVariables that encapsulate the whole random graph can be identified for having a size parameter.
raise NotImplementedError(f"Truncation not implemented for {dist.owner.op}")

if dist.owner.op.ndim_supp > 0:
raise NotImplementedError("Truncation not implemented for multivariate distributions")

Expand Down
19 changes: 18 additions & 1 deletion tests/distributions/test_truncated.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from pytensor.tensor.random.basic import GeometricRV, NormalRV
from pytensor.tensor.random.type import RandomType

from pymc import Model, draw, find_MAP
from pymc import ExGaussian, Model, Normal, draw, find_MAP
from pymc.distributions import (
Censored,
ChiSquared,
Expand Down Expand Up @@ -599,3 +599,20 @@ def dist(scale, size):
rv_out = Truncated.dist(latent, upper=7)

assert np.ptp(draw(rv_out, draws=100)) < 7


@pytest.mark.parametrize(
"dist_fn",
[
lambda: ExGaussian.dist(nu=3),
pytest.param(
lambda: Censored.dist(Normal.dist(), lower=1),
marks=pytest.mark.xfail(raises=NotImplementedError),
),
],
)
def test_truncated_symbolic_rv(dist_fn):
dist = dist_fn()
trunc_dist = Truncated.dist(dist, lower=1, upper=3)
assert 1 <= draw(trunc_dist) <= 3
assert (logp(trunc_dist, 2.5) > logp(dist, 2.5)).eval()

0 comments on commit 4c0f4b7

Please sign in to comment.