Skip to content

Commit

Permalink
Correction in logprob derivation of discrete distributions
Browse files Browse the repository at this point in the history
  • Loading branch information
shreyas3156 committed Apr 18, 2023
1 parent 597e44e commit a5c48c2
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 8 deletions.
10 changes: 4 additions & 6 deletions pymc/logprob/binary.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,12 +99,10 @@ def comparison_logprob(op, values, base_rv, operand, **kwargs):
else:
raise TypeError(f"Unsupported scalar_op {op.scalar_op}")

if base_rv.dtype.startswith("int"):
logp_point = _logprob_helper(base_rv, operand, **kwargs)
if isinstance(op.scalar_op, GT):
logprob = pt.switch(condn_exp, pt.logaddexp(logprob, logp_point), logprob)
elif isinstance(op.scalar_op, LT):
logprob = pt.switch(condn_exp, logprob, pt.logaddexp(logprob, logp_point))
if base_rv.dtype.startswith("int") and isinstance(op.scalar_op, LT):
logpmf = _logprob_helper(base_rv, operand, **kwargs)
logcdf_lt_true = _logcdf_helper(base_rv, operand - 1, **kwargs)
logprob = pt.switch(condn_exp, logcdf_lt_true, pt.logaddexp(logprob, logpmf))

if base_rv_op.name:
logprob.name = f"{base_rv_op}_logprob"
Expand Down
4 changes: 2 additions & 2 deletions tests/logprob/test_binary.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,12 @@ def test_continuous_rv_comparison(comparison_op, exp_logp_true, exp_logp_false):
[
(
pt.lt,
st.poisson(2).logcdf,
lambda x: st.poisson(2).logcdf(x - 1),
lambda x: np.logaddexp(st.poisson(2).logsf(x), st.poisson(2).logpmf(x)),
),
(
pt.gt,
lambda x: np.logaddexp(st.poisson(2).logsf(x), st.poisson(2).logpmf(x)),
st.poisson(2).logsf,
st.poisson(2).logcdf,
),
],
Expand Down

0 comments on commit a5c48c2

Please sign in to comment.