From c885c321111aa171cb7e9db056d0824b70b35d66 Mon Sep 17 00:00:00 2001
From: Dhruvanshu-Joshi
Date: Tue, 7 Nov 2023 18:14:45 +0530
Subject: [PATCH] Fix minimum discrete formula and discrete cdf/icdf transforms

---
Reviewer note: measurable_transform_logcdf and measurable_transform_icdf
raise NotImplementedError for integer-typed inputs. A bare
`return NotImplementedError` would hand the exception class back to the
caller as if it were a computed result.

 pymc/logprob/order.py            | 15 ++++++++++----
 pymc/logprob/transforms.py       | 18 ++++++++++++++---
 tests/logprob/test_order.py      | 34 +++++++++++++++++++++++---------
 tests/logprob/test_transforms.py | 13 ++++++++++++
 4 files changed, 64 insertions(+), 16 deletions(-)

diff --git a/pymc/logprob/order.py b/pymc/logprob/order.py
index 46586305bef..0433dece0a7 100644
--- a/pymc/logprob/order.py
+++ b/pymc/logprob/order.py
@@ -270,13 +270,20 @@ def max_neg_logprob_discrete(op, values, base_rv, **kwargs):
         \ln(P_{(n)}(x)) = \ln((1 - F(x - 1))^n - (1 - F(x))^n)
     where $P_{(n)}(x)$ represents the p.m.f of the maximum statistic and $F(x)$ represents the c.d.f of the i.i.d. variables.
     """
+
     (value,) = values
-    logcdf = _logcdf_helper(base_rv, value)
-    logcdf_prev = _logcdf_helper(base_rv, value - 1)
+
+    # The cdf of a negative variable is the survival at the negated value
+    logcdf = pt.log1mexp(_logcdf_helper(base_rv, -value))
+    logcdf_prev = pt.log1mexp(_logcdf_helper(base_rv, -(value + 1)))
 
     [n] = constant_fold([base_rv.size])
 
-    # logprob = logdiffexp(1-n * logcdf_prev, n * logcdf)
-    logprob = pt.log((1 - pt.exp(logcdf_prev)) ** n - (1 - pt.exp(logcdf)) ** n)
+    # Now we can use the same expression as the discrete max
+    logprob = pt.where(
+        pt.and_(pt.eq(logcdf, -pt.inf), pt.eq(logcdf_prev, -pt.inf)),
+        -pt.inf,
+        logdiffexp(n * logcdf_prev, n * logcdf),
+    )
 
     return logprob
diff --git a/pymc/logprob/transforms.py b/pymc/logprob/transforms.py
index e9092ea6d8f..c73095fa2ba 100644
--- a/pymc/logprob/transforms.py
+++ b/pymc/logprob/transforms.py
@@ -127,7 +127,11 @@
     cleanup_ir_rewrites_db,
     measurable_ir_rewrites_db,
 )
-from pymc.logprob.utils import CheckParameterValue, check_negation, check_potential_measurability
+from pymc.logprob.utils import (
+    CheckParameterValue,
+    check_negation,
+    check_potential_measurability,
+)
 
 
 class TransformedVariable(Op):
@@ -469,6 +473,10 @@ def measurable_transform_logcdf(op: MeasurableTransform, value, *inputs, **kwarg
     other_inputs = list(inputs)
     measurable_input = other_inputs.pop(op.measurable_input_idx)
 
+    # Do not apply rewrite to discrete variables
+    if measurable_input.type.dtype.startswith("int"):
+        raise NotImplementedError
+
     backward_value = op.transform_elemwise.backward(value, *other_inputs)
 
     # Fail if transformation is not injective
@@ -513,6 +521,10 @@ def measurable_transform_icdf(op: MeasurableTransform, value, *inputs, **kwargs)
     other_inputs = list(inputs)
     measurable_input = other_inputs.pop(op.measurable_input_idx)
 
+    # Do not apply rewrite to discrete variables
+    if measurable_input.type.dtype.startswith("int"):
+        raise NotImplementedError
+
     if isinstance(op.scalar_op, MONOTONICALLY_INCREASING_OPS):
         pass
     elif isinstance(op.scalar_op, MONOTONICALLY_DECREASING_OPS):
@@ -672,8 +684,8 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[Li
 
     # Do not apply rewrite to discrete variables
     if measurable_input.type.dtype.startswith("int"):
-        if check_negation(node.op.scalar_op, node.inputs[0]) is False and not isinstance(
-            node.op.scalar_op, Add
+        if not (
+            check_negation(node.op.scalar_op, node.inputs[0]) or isinstance(node.op.scalar_op, Add)
         ):
             return None
 
diff --git a/tests/logprob/test_order.py b/tests/logprob/test_order.py
index 4937405ac30..a7bbc97f053 100644
--- a/tests/logprob/test_order.py
+++ b/tests/logprob/test_order.py
@@ -37,6 +37,7 @@
 import re
 
 import numpy as np
+import pytensor
 import pytensor.tensor as pt
 import pytest
 import scipy.stats as sp
@@ -257,23 +258,38 @@ def test_max_discrete(mu, size, value, axis):
 
 
 @pytest.mark.parametrize(
-    "mu, size, value, axis",
+    "mu, n, test_value, axis",
     [(2, 3, 1, -1), (2, 3, 1, 0), (1, 2, 2, None), (0, 4, 0, 0)],
 )
-def test_min_discrete(mu, size, value, axis):
-    x = pm.Poisson.dist(name="x", mu=mu, size=(size))
+def test_min_discrete(mu, n, test_value, axis):
+    x = pm.Poisson.dist(name="x", mu=mu, size=(n,))
     x_min = pt.min(x, axis=axis)
     x_min_value = pt.scalar("x_min_value")
     x_min_logprob = logp(x_min, x_min_value)
 
-    test_value = value
+    sf_before = 1 - sp.poisson(mu).cdf(test_value - 1)
+    sf = 1 - sp.poisson(mu).cdf(test_value)
 
-    n = size
-    exp_rv = (1 - sp.poisson(mu).cdf(test_value)) ** n
-    exp_rv_prev = (1 - sp.poisson(mu).cdf(test_value - 1)) ** n
+    expected_logp = np.log(sf_before**n - sf**n)
 
     np.testing.assert_allclose(
-        (np.log(exp_rv_prev - exp_rv)),
-        (x_min_logprob.eval({x_min_value: (test_value)})),
+        x_min_logprob.eval({x_min_value: test_value}),
+        expected_logp,
         rtol=1e-06,
     )
+
+
+def test_min_max_bernoulli():
+    p = 0.7
+    q = 1 - p
+    n = 3
+    x = pm.Bernoulli.dist(p=p, shape=(n,))
+    value = pt.scalar("value", dtype=int)
+
+    max_logp_fn = pytensor.function([value], pm.logp(pt.max(x), value))
+    np.testing.assert_allclose(max_logp_fn(0), np.log(q**n))
+    np.testing.assert_allclose(max_logp_fn(1), np.log(1 - q**n))
+
+    min_logp_fn = pytensor.function([value], pm.logp(pt.min(x), value))
+    np.testing.assert_allclose(min_logp_fn(1), np.log(p**n))
+    np.testing.assert_allclose(min_logp_fn(0), np.log(1 - p**n))
diff --git a/tests/logprob/test_transforms.py b/tests/logprob/test_transforms.py
index 32924a37d20..ea6a0f4d1f3 100644
--- a/tests/logprob/test_transforms.py
+++ b/tests/logprob/test_transforms.py
@@ -47,6 +47,8 @@
 from pytensor.graph.fg import FunctionGraph
 from pytensor.scan import scan
 
+import pymc as pm
+
 from pymc.distributions.continuous import Cauchy
 from pymc.distributions.transforms import _default_transform, log, logodds
 from pymc.logprob.abstract import MeasurableVariable, _logprob
@@ -1262,3 +1264,14 @@ def test_invalid_broadcasted_transform_rv_fails():
     # This logp derivation should fail or count only once the values that are broadcasted
     logprob = logp(y_rv, y_vv)
     assert logprob.eval({y_vv: [0, 0, 0, 0], loc: [0, 0, 0, 0]}).shape == ()
+
+
+def test_discrete_measurable_cdf_icdf():
+    p = 0.7
+    rv = -pm.Bernoulli.dist(p=p)
+
+    # A negated Bernoulli has pmf {p if x == -1; 1-p if x == 0; 0 otherwise}
+    assert pm.logp(rv, -2).eval() == -np.inf  # Correct
+    assert pm.logp(rv, -1).eval() == np.log(p)  # Correct
+    assert pm.logp(rv, 0).eval() == np.log(1 - p)  # Correct
+    assert pm.logp(rv, 1).eval() == -np.inf  # Correct