diff --git a/pymc/logprob/order.py b/pymc/logprob/order.py index b9391b5b43f..0dc78d0b0d5 100644 --- a/pymc/logprob/order.py +++ b/pymc/logprob/order.py @@ -41,10 +41,7 @@ from pytensor.graph.basic import Node from pytensor.graph.fg import FunctionGraph from pytensor.graph.rewriting.basic import node_rewriter -from pytensor.scalar.basic import Mul -from pytensor.tensor.basic import get_underlying_scalar_constant_value from pytensor.tensor.elemwise import Elemwise -from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.math import Max from pytensor.tensor.random.op import RandomVariable from pytensor.tensor.variable import TensorVariable @@ -56,6 +53,7 @@ _logprob_helper, ) from pymc.logprob.rewriting import measurable_ir_rewrites_db +from pymc.logprob.utils import find_negated_var from pymc.math import logdiffexp from pymc.pytensorf import constant_fold @@ -168,6 +166,13 @@ class MeasurableMaxNeg(Max): MeasurableVariable.register(MeasurableMaxNeg) +class MeasurableDiscreteMaxNeg(Max): + """A placeholder used to specify a log-likelihood for sub-graphs of negative maxima of discrete variables""" + + +MeasurableVariable.register(MeasurableDiscreteMaxNeg) + + @node_rewriter(tracks=[Max]) def find_measurable_max_neg(fgraph: FunctionGraph, node: Node) -> Optional[list[TensorVariable]]: rv_map_feature = getattr(fgraph, "preserve_rv_mappings", None) @@ -180,37 +185,20 @@ def find_measurable_max_neg(fgraph: FunctionGraph, node: Node) -> Optional[list[ base_var = node.inputs[0] - if base_var.owner is None: - return None - - if not rv_map_feature.request_measurable(node.inputs): - return None - # Min is the Max of the negation of the same distribution. Hence, op must be Elemwise - if not isinstance(base_var.owner.op, Elemwise): + if not (base_var.owner is not None and isinstance(base_var.owner.op, Elemwise)): return None + base_rv = find_negated_var(base_var) + # negation is rv * (-1). Hence the scalar_op must be Mul - try: - if not ( - isinstance(base_var.owner.op.scalar_op, Mul) - and len(base_var.owner.inputs) == 2 - and get_underlying_scalar_constant_value(base_var.owner.inputs[1]) == -1 - ): - return None - except NotScalarConstantError: + if base_rv is None: return None - base_rv = base_var.owner.inputs[0] - # Non-univariate distributions and non-RVs must be rejected if not (isinstance(base_rv.owner.op, RandomVariable) and base_rv.owner.op.ndim_supp == 0): return None - # TODO: We are currently only supporting continuous rvs - if isinstance(base_rv.owner.op, RandomVariable) and base_rv.owner.op.dtype.startswith("int"): - return None - # univariate i.i.d. test which also rules out other distributions for params in base_rv.owner.inputs[3:]: if params.type.ndim != 0: @@ -222,11 +210,16 @@ def find_measurable_max_neg(fgraph: FunctionGraph, node: Node) -> Optional[list[ if axis != base_var_dims: return None - measurable_min = MeasurableMaxNeg(list(axis)) - min_rv_node = measurable_min.make_node(base_var) - min_rv = min_rv_node.outputs + if not rv_map_feature.request_measurable([base_rv]): + return None - return min_rv + # distinguish measurable discrete and continuous (because logprob is different) + if base_rv.owner.op.dtype.startswith("int"): + measurable_min = MeasurableDiscreteMaxNeg(list(axis)) + else: + measurable_min = MeasurableMaxNeg(list(axis)) + + return measurable_min.make_node(base_rv).outputs measurable_ir_rewrites_db.register( @@ -238,14 +231,13 @@ def find_measurable_max_neg(fgraph: FunctionGraph, node: Node) -> Optional[list[ @_logprob.register(MeasurableMaxNeg) -def max_neg_logprob(op, values, base_var, **kwargs): +def max_neg_logprob(op, values, base_rv, **kwargs): r"""Compute the log-likelihood graph for the `Max` operation. The formula that we use here is : \ln(f_{(n)}(x)) = \ln(n) + (n-1) \ln(1 - F(x)) + \ln(f(x)) where f(x) represents the p.d.f and F(x) represents the c.d.f of the distribution respectively. """ (value,) = values - base_rv = base_var.owner.inputs[0] logprob = _logprob_helper(base_rv, -value) logcdf = _logcdf_helper(base_rv, -value) @@ -254,3 +246,31 @@ def max_neg_logprob(op, values, base_var, **kwargs): logprob = (n - 1) * pt.math.log(1 - pt.math.exp(logcdf)) + logprob + pt.math.log(n) return logprob + + +@_logprob.register(MeasurableDiscreteMaxNeg) +def discrete_max_neg_logprob(op, values, base_rv, **kwargs): + r"""Compute the log-likelihood graph for the `Max` operation. + + The formula that we use here is : + .. math:: + \ln(P_{(n)}(x)) = \ln((1 - F(x - 1))^n - (1 - F(x))^n) + where $P_{(n)}(x)$ represents the p.m.f of the maximum statistic and $F(x)$ represents the c.d.f of the i.i.d. variables. + """ + + (value,) = values + + # The cdf of a negative variable is the survival at the negated value + logcdf = pt.log1mexp(_logcdf_helper(base_rv, -value)) + logcdf_prev = pt.log1mexp(_logcdf_helper(base_rv, -(value + 1))) + + [n] = constant_fold([base_rv.size]) + + # Now we can use the same expression as the discrete max + logprob = pt.where( + pt.and_(pt.eq(logcdf, -pt.inf), pt.eq(logcdf_prev, -pt.inf)), + -pt.inf, + logdiffexp(n * logcdf_prev, n * logcdf), + ) + + return logprob diff --git a/pymc/logprob/utils.py b/pymc/logprob/utils.py index 5796c9443c7..49827f7a618 100644 --- a/pymc/logprob/utils.py +++ b/pymc/logprob/utils.py @@ -49,6 +49,10 @@ from pytensor.graph.op import HasInnerGraph from pytensor.link.c.type import CType from pytensor.raise_op import CheckAndRaise +from pytensor.scalar.basic import Mul +from pytensor.tensor.basic import get_underlying_scalar_constant_value +from pytensor.tensor.elemwise import Elemwise +from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.random.op import RandomVariable from pytensor.tensor.variable import TensorVariable @@ -296,3 +300,26 @@ def diracdelta_logprob(op, values, *inputs, **kwargs): (const_value,) = inputs values, const_value = pt.broadcast_arrays(values, const_value) return pt.switch(pt.isclose(values, const_value, rtol=op.rtol, atol=op.atol), 0.0, -np.inf) + + +def find_negated_var(var): + """Return a variable that is being multiplied by -1 or None otherwise.""" + + if ( + not (var.owner) + and isinstance(var.owner.op, Elemwise) + and isinstance(var.owner.op.scalar_op, Mul) + ): + return None + if len(var.owner.inputs) != 2: + return None + + inputs = var.owner.inputs + for mul_var, mul_const in (inputs, reversed(inputs)): + try: + if get_underlying_scalar_constant_value(mul_const) == -1: + return mul_var + except NotScalarConstantError: + continue + + return None diff --git a/tests/logprob/test_order.py b/tests/logprob/test_order.py index 78bba2a0bb0..4d15240375a 100644 --- a/tests/logprob/test_order.py +++ b/tests/logprob/test_order.py @@ -37,6 +37,7 @@ import re import numpy as np +import pytensor import pytensor.tensor as pt import pytest import scipy.stats as sp @@ -254,3 +255,41 @@ def test_max_discrete(mu, size, value, axis): (x_max_logprob.eval({x_max_value: test_value})), rtol=1e-06, ) + + +@pytest.mark.parametrize( + "mu, n, test_value, axis", + [(2, 3, 1, -1), (2, 3, 1, 0), (1, 2, 2, None), (0, 4, 0, 0)], +) +def test_min_discrete(mu, n, test_value, axis): + x = pm.Poisson.dist(name="x", mu=mu, size=(n,)) + x_min = pt.min(x, axis=axis) + x_min_value = pt.scalar("x_min_value") + x_min_logprob = logp(x_min, x_min_value) + + sf_before = 1 - sp.poisson(mu).cdf(test_value - 1) + sf = 1 - sp.poisson(mu).cdf(test_value) + + expected_logp = np.log(sf_before**n - sf**n) + + np.testing.assert_allclose( + x_min_logprob.eval({x_min_value: test_value}), + expected_logp, + rtol=1e-06, + ) + + +def test_min_max_bernoulli(): + p = 0.7 + q = 1 - p + n = 3 + x = pm.Bernoulli.dist(name="x", p=p, shape=(n,)) + value = pt.scalar("value", dtype=int) + + max_logp_fn = pytensor.function([value], pm.logp(pt.max(x), value)) + np.testing.assert_allclose(max_logp_fn(0), np.log(q**n)) + np.testing.assert_allclose(max_logp_fn(1), np.log(1 - q**n)) + + min_logp_fn = pytensor.function([value], pm.logp(pt.min(x), value)) + np.testing.assert_allclose(min_logp_fn(1), np.log(p**n)) + np.testing.assert_allclose(min_logp_fn(0), np.log(1 - p**n))