From 6088c26b771bf37fa61e714c40d64230a0252976 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Fri, 30 Sep 2022 10:31:53 +0200 Subject: [PATCH] Pass size to specialized truncated dispatch --- pymc/distributions/truncated.py | 6 +++--- pymc/tests/distributions/test_truncated.py | 23 ++++++++++++++++++++-- 2 files changed, 24 insertions(+), 5 deletions(-) diff --git a/pymc/distributions/truncated.py b/pymc/distributions/truncated.py index 004b4f3c0a6..d187597a89b 100644 --- a/pymc/distributions/truncated.py +++ b/pymc/distributions/truncated.py @@ -53,7 +53,7 @@ def update(self, node: Node): @singledispatch -def _truncated(op: Op, lower, upper, *params): +def _truncated(op: Op, lower, upper, size, *params): """Return the truncated equivalent of another `RandomVariable`.""" raise NotImplementedError(f"{op} does not have an equivalent truncated version implemented") @@ -150,7 +150,7 @@ def rv_op(cls, dist, lower, upper, max_n_steps, size=None): # Try to use specialized Op try: - return _truncated(dist.owner.op, lower, upper, *dist.owner.inputs) + return _truncated(dist.owner.op, lower, upper, size, *dist.owner.inputs) except NotImplementedError: pass @@ -339,7 +339,7 @@ def truncated_logprob(op, values, *inputs, **kwargs): @_truncated.register(NormalRV) -def _truncated_normal(op, lower, upper, rng, size, dtype, mu, sigma): +def _truncated_normal(op, lower, upper, size, rng, old_size, dtype, mu, sigma): return TruncatedNormal.dist( mu=mu, sigma=sigma, diff --git a/pymc/tests/distributions/test_truncated.py b/pymc/tests/distributions/test_truncated.py index dc589d8255c..f0048b674d6 100644 --- a/pymc/tests/distributions/test_truncated.py +++ b/pymc/tests/distributions/test_truncated.py @@ -53,12 +53,31 @@ def _icdf_not_implemented(*args, **kwargs): raise NotImplementedError() -def test_truncation_specialized_op(): +@pytest.mark.parametrize("shape_info", ("shape", "dims", "observed")) +def test_truncation_specialized_op(shape_info): rng = aesara.shared(np.random.default_rng()) x = at.random.normal(0, 10, rng=rng, name="x") - xt = Truncated.dist(x, lower=5, upper=15, shape=(100,)) + with Model(coords={"dim": range(100)}) as m: + if shape_info == "shape": + xt = Truncated("xt", dist=x, lower=5, upper=15, shape=(100,)) + elif shape_info == "dims": + xt = Truncated("xt", dist=x, lower=5, upper=15, dims=("dim",)) + elif shape_info == "observed": + xt = Truncated( + "xt", + dist=x, + lower=5, + upper=15, + observed=np.empty( + 100, + ), + ) + else: + raise ValueError(f"Not a valid shape_info parametrization: {shape_info}") + assert isinstance(xt.owner.op, TruncatedNormalRV) + assert xt.shape.eval() == (100,) # Test RNG is not reused assert xt.owner.inputs[0] is not rng