From 40de354ed3f363920f130be4e2589bc49cd94a59 Mon Sep 17 00:00:00 2001 From: shreyas3156 Date: Wed, 19 Apr 2023 16:07:00 +0530 Subject: [PATCH] Add logprob derivation for >= and <= operations --- pymc/logprob/binary.py | 27 +++++++++++++++------------ tests/logprob/test_binary.py | 30 +++++++++++++++++++++--------- 2 files changed, 36 insertions(+), 21 deletions(-) diff --git a/pymc/logprob/binary.py b/pymc/logprob/binary.py index a35673b4547..72224d394cf 100644 --- a/pymc/logprob/binary.py +++ b/pymc/logprob/binary.py @@ -19,8 +19,8 @@ from pytensor.graph.basic import Node from pytensor.graph.fg import FunctionGraph from pytensor.graph.rewriting.basic import node_rewriter -from pytensor.scalar.basic import GT, LT -from pytensor.tensor.math import gt, lt +from pytensor.scalar.basic import GE, GT, LE, LT +from pytensor.tensor.math import ge, gt, le, lt from pymc.logprob.abstract import ( MeasurableElemwise, @@ -36,10 +36,10 @@ class MeasurableComparison(MeasurableElemwise): """A placeholder used to specify a log-likelihood for a binary comparison RV sub-graph.""" - valid_scalar_types = (GT, LT) + valid_scalar_types = (GT, LT, GE, LE) -@node_rewriter(tracks=[gt, lt]) +@node_rewriter(tracks=[gt, lt, ge, le]) def find_measurable_comparisons( fgraph: FunctionGraph, node: Node ) -> Optional[List[MeasurableComparison]]: @@ -92,18 +92,21 @@ def comparison_logprob(op, values, base_rv, operand, **kwargs): condn_exp = pt.eq(value, np.array(True)) - if isinstance(op.scalar_op, GT): + if isinstance(op.scalar_op, (GT, GE)): logprob = pt.switch(condn_exp, logccdf, logcdf) - elif isinstance(op.scalar_op, LT): - if base_rv.dtype.startswith("int"): - logpmf = _logprob_helper(base_rv, operand, **kwargs) - logcdf_lt_true = _logcdf_helper(base_rv, operand - 1, **kwargs) - logprob = pt.switch(condn_exp, logcdf_lt_true, pt.logaddexp(logccdf, logpmf)) - else: - logprob = pt.switch(condn_exp, logcdf, logccdf) + elif isinstance(op.scalar_op, (LT, LE)): + logprob = pt.switch(condn_exp, logcdf, logccdf) else: raise TypeError(f"Unsupported scalar_op {op.scalar_op}") + if base_rv.dtype.startswith("int"): + logpmf = _logprob_helper(base_rv, operand, **kwargs) + logcdf_prev = _logcdf_helper(base_rv, operand - 1, **kwargs) + if isinstance(op.scalar_op, LT): + return pt.switch(condn_exp, logcdf_prev, pt.logaddexp(logccdf, logpmf)) + elif isinstance(op.scalar_op, GE): + return pt.switch(condn_exp, pt.logaddexp(logccdf, logpmf), logcdf_prev) + if base_rv_op.name: logprob.name = f"{base_rv_op}_logprob" logcdf.name = f"{base_rv_op}_logcdf" diff --git a/tests/logprob/test_binary.py b/tests/logprob/test_binary.py index c2dc6926598..0780dcf0936 100644 --- a/tests/logprob/test_binary.py +++ b/tests/logprob/test_binary.py @@ -27,23 +27,25 @@ @pytest.mark.parametrize( "comparison_op, exp_logp_true, exp_logp_false", [ - (pt.lt, st.norm(0, 1).logcdf, st.norm(0, 1).logsf), - (pt.gt, st.norm(0, 1).logsf, st.norm(0, 1).logcdf), + ((pt.lt, pt.le), "logcdf", "logsf"), + ((pt.gt, pt.ge), "logsf", "logcdf"), ], ) def test_continuous_rv_comparison(comparison_op, exp_logp_true, exp_logp_false): x_rv = pt.random.normal(0, 1) - comp_x_rv = comparison_op(x_rv, 0.5) + for op in comparison_op: + comp_x_rv = op(x_rv, 0.5) - comp_x_vv = comp_x_rv.clone() + comp_x_vv = comp_x_rv.clone() - logprob = logp(comp_x_rv, comp_x_vv) - assert_no_rvs(logprob) + logprob = logp(comp_x_rv, comp_x_vv) + assert_no_rvs(logprob) - logp_fn = pytensor.function([comp_x_vv], logprob) + logp_fn = pytensor.function([comp_x_vv], logprob) + ref_scipy = st.norm(0, 1) - assert np.isclose(logp_fn(0), exp_logp_false(0.5)) - assert np.isclose(logp_fn(1), exp_logp_true(0.5)) + assert np.isclose(logp_fn(0), getattr(ref_scipy, exp_logp_false)(0.5)) + assert np.isclose(logp_fn(1), getattr(ref_scipy, exp_logp_true)(0.5)) @pytest.mark.parametrize( @@ -54,11 +56,21 @@ def test_continuous_rv_comparison(comparison_op, exp_logp_true, exp_logp_false): lambda x: st.poisson(2).logcdf(x - 1), lambda x: np.logaddexp(st.poisson(2).logsf(x), st.poisson(2).logpmf(x)), ), + ( + pt.ge, + lambda x: np.logaddexp(st.poisson(2).logsf(x), st.poisson(2).logpmf(x)), + lambda x: st.poisson(2).logcdf(x - 1), + ), ( pt.gt, st.poisson(2).logsf, st.poisson(2).logcdf, ), + ( + pt.le, + st.poisson(2).logcdf, + st.poisson(2).logsf, + ), ], ) def test_discrete_rv_comparison(comparison_op, exp_logp_true, exp_logp_false):