Skip to content

Commit

Permalink
Implement geometric logcdf and icdf
Browse files Browse the repository at this point in the history
  • Loading branch information
Ricardo committed Apr 11, 2022
1 parent fc6c63e commit a330723
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 0 deletions.
16 changes: 16 additions & 0 deletions aeppl/logprob.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,6 +501,22 @@ def geometric_logprob(op, values, *inputs, **kwargs):
return res


@_logcdf.register(arb.GeometricRV)
def geometric_logcdf(op, value, *inputs, **kwargs):
(p,) = inputs[3:]
res = at.switch(at.le(value, 0), -np.inf, at.log1mexp(at.log1p(-p) * value))
res = CheckParameterValue("0 <= p <= 1")(
res, at.all(at.le(0.0, p)), at.all(at.ge(1.0, p))
)
return res


@_icdf.register(arb.GeometricRV)
def geometric_icdf(op, value, *inputs, **kwargs):
(p,) = inputs[3:]
return at.ceil(at.log1p(-value) / at.log1p(-p)).astype(op.dtype)


@_logprob.register(arb.HyperGeometricRV)
def hypergeometric_logprob(op, values, *inputs, **kwargs):
(value,) = values
Expand Down
55 changes: 55 additions & 0 deletions tests/test_logprob.py
Original file line number Diff line number Diff line change
Expand Up @@ -795,6 +795,61 @@ def scipy_logprob(obs, p):
scipy_logprob_tester(x, obs, dist_params, test_fn=scipy_logprob)


@pytest.mark.parametrize(
"dist_params, obs, size, error",
[
((-1,), np.array([0, 1, 100, 10000], dtype=np.int64), (), True),
((0.1,), np.array([0, 1, 100, 10000], dtype=np.int64), (), False),
((1.0,), np.array([0, 1, 100, 10000], dtype=np.int64), (3, 2), False),
(
(np.array([0.01, 0.2, 0.8]),),
np.array([-1, 1, 84], dtype=np.int64),
(),
False,
),
],
)
def test_geometric_logcdf(dist_params, obs, size, error):

dist_params_at, obs_at, size_at = create_aesara_params(dist_params, obs, size)
dist_params = dict(zip(dist_params_at, dist_params))

x = at.random.geometric(*dist_params_at, size=size_at)

cm = contextlib.suppress() if not error else pytest.raises(ParameterValueError)

with cm:
scipy_logprob_tester(
x, obs, dist_params, test_fn=stats.geom.logcdf, test="logcdf"
)


@pytest.mark.parametrize(
"dist_params, obs, size",
[
((0.1,), np.array([-0.5, 0, 0.1, 0.5, 0.9, 1.0, 1.5], dtype=np.int64), ()),
((0.5,), np.array([-0.5, 0, 0.1, 0.5, 0.9, 1.0, 1.5], dtype=np.int64), (3, 2)),
(
(np.array([0.0, 0.2, 0.5, 1.0]),),
np.array([0.7, 0.7, 0.7, 0.7], dtype=np.int64),
(),
),
],
)
def test_geometric_icdf(dist_params, obs, size):

dist_params_at, obs_at, size_at = create_aesara_params(dist_params, obs, size)
dist_params = dict(zip(dist_params_at, dist_params))

x = at.random.geometric(*dist_params_at, size=size_at)

def scipy_geom_icdf(value, p):
# Scipy ppf returns floats
return stats.geom.ppf(value, p).astype(value.dtype)

scipy_logprob_tester(x, obs, dist_params, test_fn=scipy_geom_icdf, test="icdf")


@pytest.mark.parametrize(
"dist_params, obs, size, error",
[
Expand Down

0 comments on commit a330723

Please sign in to comment.