Skip to content

Commit

Permalink
Correction in logprob derivation of discrete distributions
Browse files Browse the repository at this point in the history
  • Loading branch information
shreyas3156 committed Apr 19, 2023
1 parent 597e44e commit 19ab8ad
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 10 deletions.
14 changes: 6 additions & 8 deletions pymc/logprob/binary.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,17 +95,15 @@ def comparison_logprob(op, values, base_rv, operand, **kwargs):
if isinstance(op.scalar_op, GT):
logprob = pt.switch(condn_exp, logccdf, logcdf)
elif isinstance(op.scalar_op, LT):
logprob = pt.switch(condn_exp, logcdf, logccdf)
if base_rv.dtype.startswith("int"):
logpmf = _logprob_helper(base_rv, operand, **kwargs)
logcdf_lt_true = _logcdf_helper(base_rv, operand - 1, **kwargs)
logprob = pt.switch(condn_exp, logcdf_lt_true, pt.logaddexp(logccdf, logpmf))
else:
logprob = pt.switch(condn_exp, logcdf, logccdf)
else:
raise TypeError(f"Unsupported scalar_op {op.scalar_op}")

if base_rv.dtype.startswith("int"):
logp_point = _logprob_helper(base_rv, operand, **kwargs)
if isinstance(op.scalar_op, GT):
logprob = pt.switch(condn_exp, pt.logaddexp(logprob, logp_point), logprob)
elif isinstance(op.scalar_op, LT):
logprob = pt.switch(condn_exp, logprob, pt.logaddexp(logprob, logp_point))

if base_rv_op.name:
logprob.name = f"{base_rv_op}_logprob"
logcdf.name = f"{base_rv_op}_logcdf"
Expand Down
4 changes: 2 additions & 2 deletions tests/logprob/test_binary.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,12 @@ def test_continuous_rv_comparison(comparison_op, exp_logp_true, exp_logp_false):
[
(
pt.lt,
st.poisson(2).logcdf,
lambda x: st.poisson(2).logcdf(x - 1),
lambda x: np.logaddexp(st.poisson(2).logsf(x), st.poisson(2).logpmf(x)),
),
(
pt.gt,
lambda x: np.logaddexp(st.poisson(2).logsf(x), st.poisson(2).logpmf(x)),
st.poisson(2).logsf,
st.poisson(2).logcdf,
),
],
Expand Down

0 comments on commit 19ab8ad

Please sign in to comment.