Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow more distributions to be truncated #7476

Merged
merged 3 commits into from
Aug 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pymc/distributions/censored.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ class Censored(Distribution):
rv_op = CensoredRV.rv_op

@classmethod
def dist(cls, dist, lower, upper, **kwargs):
def dist(cls, dist, lower=-np.inf, upper=np.inf, **kwargs):
if not isinstance(dist, TensorVariable) or not isinstance(
dist.owner.op, RandomVariable | SymbolicRandomVariable
):
Expand Down
4 changes: 2 additions & 2 deletions pymc/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -2941,8 +2941,8 @@ class ExGaussian(Continuous):
rv_op = ExGaussianRV.rv_op

@classmethod
def dist(cls, mu=0.0, sigma=None, nu=None, *args, **kwargs):
return super().dist([mu, sigma, nu], *args, **kwargs)
def dist(cls, mu=0.0, sigma=1.0, *, nu, **kwargs):
return super().dist([mu, sigma, nu], **kwargs)

def support_point(rv, size, mu, sigma, nu):
mu, nu, _ = pt.broadcast_arrays(mu, nu, sigma)
Expand Down
20 changes: 13 additions & 7 deletions pymc/distributions/truncated.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
from pytensor.tensor.random.type import RandomType

from pymc.distributions.continuous import TruncatedNormal, bounded_cont_transform
from pymc.distributions.custom import CustomSymbolicDistRV
from pymc.distributions.dist_math import check_parameters
from pymc.distributions.distribution import (
Distribution,
Expand Down Expand Up @@ -302,17 +301,24 @@ class Truncated(Distribution):
def dist(cls, dist, lower=None, upper=None, max_n_steps: int = 10_000, **kwargs):
if not (
isinstance(dist, TensorVariable)
and isinstance(dist.owner.op, RandomVariable | CustomSymbolicDistRV)
and dist.owner is not None
and isinstance(dist.owner.op, RandomVariable | SymbolicRandomVariable)
):
if isinstance(dist.owner.op, SymbolicRandomVariable):
raise NotImplementedError(
f"Truncation not implemented for SymbolicRandomVariable {dist.owner.op}.\n"
f"You can try wrapping the distribution inside a CustomDist instead."
)
raise ValueError(
f"Truncation dist must be a distribution created via the `.dist()` API, got {type(dist)}"
)

if (
isinstance(dist.owner.op, SymbolicRandomVariable)
and "[size]" not in dist.owner.op.extended_signature
):
# Truncation needs to wrap the underlying dist, but not all SymbolicRandomVariables encapsulate the whole
# random graph and as such we don't know where the actual inputs begin. This happens mostly for
# distribution factories like `Censored` and `Mixture` which would have a very complex signature if they
# encapsulated the random components instead of taking them as inputs like they do now.
# SymbolicRandomVariables that encapsulate the whole random graph can be identified for having a size parameter.
raise NotImplementedError(f"Truncation not implemented for {dist.owner.op}")

if dist.owner.op.ndim_supp > 0:
raise NotImplementedError("Truncation not implemented for multivariate distributions")

Expand Down
21 changes: 19 additions & 2 deletions tests/distributions/test_truncated.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from pytensor.tensor.random.basic import GeometricRV, NormalRV
from pytensor.tensor.random.type import RandomType

from pymc import Model, draw, find_MAP
from pymc import ExGaussian, Model, Normal, draw, find_MAP
from pymc.distributions import (
Censored,
ChiSquared,
Expand Down Expand Up @@ -342,7 +342,7 @@ def test_truncation_exceptions():
# Truncation does not work with SymbolicRV inputs
with pytest.raises(
NotImplementedError,
match="Truncation not implemented for SymbolicRandomVariable CensoredRV",
match="Truncation not implemented for CensoredRV",
):
Truncated.dist(Censored.dist(pt.random.normal(), lower=-1, upper=1), -1, 1)

Expand Down Expand Up @@ -599,3 +599,20 @@ def dist(scale, size):
rv_out = Truncated.dist(latent, upper=7)

assert np.ptp(draw(rv_out, draws=100)) < 7


@pytest.mark.parametrize(
"dist_fn",
[
lambda: ExGaussian.dist(nu=3),
pytest.param(
lambda: Censored.dist(Normal.dist(), lower=1),
marks=pytest.mark.xfail(raises=NotImplementedError),
),
],
)
def test_truncated_symbolic_rv(dist_fn):
dist = dist_fn()
trunc_dist = Truncated.dist(dist, lower=1, upper=3)
assert 1 <= draw(trunc_dist) <= 3
assert (logp(trunc_dist, 2.5) > logp(dist, 2.5)).eval()
Loading