Skip to content

Commit

Permalink
Add logprob derivation for >= and <= operations
Browse files Browse the repository at this point in the history
  • Loading branch information
shreyas3156 committed Apr 19, 2023
1 parent 9b712bf commit 40de354
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 21 deletions.
27 changes: 15 additions & 12 deletions pymc/logprob/binary.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
from pytensor.graph.basic import Node
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.rewriting.basic import node_rewriter
from pytensor.scalar.basic import GT, LT
from pytensor.tensor.math import gt, lt
from pytensor.scalar.basic import GE, GT, LE, LT
from pytensor.tensor.math import ge, gt, le, lt

from pymc.logprob.abstract import (
MeasurableElemwise,
Expand All @@ -36,10 +36,10 @@
class MeasurableComparison(MeasurableElemwise):
"""A placeholder used to specify a log-likelihood for a binary comparison RV sub-graph."""

valid_scalar_types = (GT, LT)
valid_scalar_types = (GT, LT, GE, LE)


@node_rewriter(tracks=[gt, lt])
@node_rewriter(tracks=[gt, lt, ge, le])
def find_measurable_comparisons(
fgraph: FunctionGraph, node: Node
) -> Optional[List[MeasurableComparison]]:
Expand Down Expand Up @@ -92,18 +92,21 @@ def comparison_logprob(op, values, base_rv, operand, **kwargs):

condn_exp = pt.eq(value, np.array(True))

if isinstance(op.scalar_op, GT):
if isinstance(op.scalar_op, (GT, GE)):
logprob = pt.switch(condn_exp, logccdf, logcdf)
elif isinstance(op.scalar_op, LT):
if base_rv.dtype.startswith("int"):
logpmf = _logprob_helper(base_rv, operand, **kwargs)
logcdf_lt_true = _logcdf_helper(base_rv, operand - 1, **kwargs)
logprob = pt.switch(condn_exp, logcdf_lt_true, pt.logaddexp(logccdf, logpmf))
else:
logprob = pt.switch(condn_exp, logcdf, logccdf)
elif isinstance(op.scalar_op, (LT, LE)):
logprob = pt.switch(condn_exp, logcdf, logccdf)
else:
raise TypeError(f"Unsupported scalar_op {op.scalar_op}")

if base_rv.dtype.startswith("int"):
logpmf = _logprob_helper(base_rv, operand, **kwargs)
logcdf_prev = _logcdf_helper(base_rv, operand - 1, **kwargs)
if isinstance(op.scalar_op, LT):
return pt.switch(condn_exp, logcdf_prev, pt.logaddexp(logccdf, logpmf))
elif isinstance(op.scalar_op, GE):
return pt.switch(condn_exp, pt.logaddexp(logccdf, logpmf), logcdf_prev)

if base_rv_op.name:
logprob.name = f"{base_rv_op}_logprob"
logcdf.name = f"{base_rv_op}_logcdf"
Expand Down
30 changes: 21 additions & 9 deletions tests/logprob/test_binary.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,23 +27,25 @@
@pytest.mark.parametrize(
"comparison_op, exp_logp_true, exp_logp_false",
[
(pt.lt, st.norm(0, 1).logcdf, st.norm(0, 1).logsf),
(pt.gt, st.norm(0, 1).logsf, st.norm(0, 1).logcdf),
((pt.lt, pt.le), "logcdf", "logsf"),
((pt.gt, pt.ge), "logsf", "logcdf"),
],
)
def test_continuous_rv_comparison(comparison_op, exp_logp_true, exp_logp_false):
x_rv = pt.random.normal(0, 1)
comp_x_rv = comparison_op(x_rv, 0.5)
for op in comparison_op:
comp_x_rv = op(x_rv, 0.5)

comp_x_vv = comp_x_rv.clone()
comp_x_vv = comp_x_rv.clone()

logprob = logp(comp_x_rv, comp_x_vv)
assert_no_rvs(logprob)
logprob = logp(comp_x_rv, comp_x_vv)
assert_no_rvs(logprob)

logp_fn = pytensor.function([comp_x_vv], logprob)
logp_fn = pytensor.function([comp_x_vv], logprob)
ref_scipy = st.norm(0, 1)

assert np.isclose(logp_fn(0), exp_logp_false(0.5))
assert np.isclose(logp_fn(1), exp_logp_true(0.5))
assert np.isclose(logp_fn(0), getattr(ref_scipy, exp_logp_false)(0.5))
assert np.isclose(logp_fn(1), getattr(ref_scipy, exp_logp_true)(0.5))


@pytest.mark.parametrize(
Expand All @@ -54,11 +56,21 @@ def test_continuous_rv_comparison(comparison_op, exp_logp_true, exp_logp_false):
lambda x: st.poisson(2).logcdf(x - 1),
lambda x: np.logaddexp(st.poisson(2).logsf(x), st.poisson(2).logpmf(x)),
),
(
pt.ge,
lambda x: np.logaddexp(st.poisson(2).logsf(x), st.poisson(2).logpmf(x)),
lambda x: st.poisson(2).logcdf(x - 1),
),
(
pt.gt,
st.poisson(2).logsf,
st.poisson(2).logcdf,
),
(
pt.le,
st.poisson(2).logcdf,
st.poisson(2).logsf,
),
],
)
def test_discrete_rv_comparison(comparison_op, exp_logp_true, exp_logp_false):
Expand Down

0 comments on commit 40de354

Please sign in to comment.