Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement Inverse CDF dispatch functions #147

Merged
merged 5 commits into from
Aug 10, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 45 additions & 0 deletions aeppl/logprob.py
Original file line number Diff line number Diff line change
@@ -68,6 +68,14 @@ def logcdf(rv_var, rv_value, **kwargs):
return logcdf


def icdf(rv, value, **kwargs):
"""Create a graph for the inverse CDF of a `RandomVariable`."""
rv_icdf = _icdf(rv.owner.op, value, *rv.owner.inputs, **kwargs)
if rv.name:
rv_icdf.name = f"{rv.name}_icdf"
return rv_icdf


@singledispatch
def _logprob(
op: Op,
@@ -101,6 +109,21 @@ def _logcdf(
raise NotImplementedError(f"Logcdf method not implemented for {op}")


@singledispatch
def _icdf(
op: Op,
value: TensorVariable,
*inputs: TensorVariable,
**kwargs,
):
"""Create a graph for the inverse CDF of a `RandomVariable`.

This function dispatches on the type of `op`, which should be a subclass
of `RandomVariable`.
"""
raise NotImplementedError(f"icdf not implemented for {op}")


@_logprob.register(arb.UniformRV)
def uniform_logprob(op, values, *inputs, **kwargs):
(value,) = values
@@ -158,6 +181,12 @@ def normal_logcdf(op, value, *inputs, **kwargs):
return res


@_icdf.register(arb.NormalRV)
def normal_icdf(op, value, *inputs, **kwargs):
loc, scale = inputs[3:]
return loc + scale * -np.sqrt(2.0) * at.erfcinv(2 * value)


@_logprob.register(arb.HalfNormalRV)
def halfnormal_logprob(op, values, *inputs, **kwargs):
(value,) = values
@@ -470,6 +499,22 @@ def geometric_logprob(op, values, *inputs, **kwargs):
return res


@_logcdf.register(arb.GeometricRV)
def geometric_logcdf(op, value, *inputs, **kwargs):
(p,) = inputs[3:]
res = at.switch(at.le(value, 0), -np.inf, at.log1mexp(at.log1p(-p) * value))
res = CheckParameterValue("0 <= p <= 1")(
res, at.all(at.le(0.0, p)), at.all(at.ge(1.0, p))
)
return res


@_icdf.register(arb.GeometricRV)
def geometric_icdf(op, value, *inputs, **kwargs):
(p,) = inputs[3:]
return at.ceil(at.log1p(-value) / at.log1p(-p)).astype(op.dtype)


@_logprob.register(arb.HyperGeometricRV)
def hypergeometric_logprob(op, values, *inputs, **kwargs):
(value,) = values
2 changes: 1 addition & 1 deletion aeppl/truncation.py
Original file line number Diff line number Diff line change
@@ -97,7 +97,7 @@ def censor_logprob(op, values, base_rv, lower_bound, upper_bound, **kwargs):
logccdf = at.log1mexp(logcdf)
# For right censored discrete RVs, we need to add an extra term
# corresponding to the pmf at the upper bound
if base_rv_op.dtype == "int64":
if base_rv_op.dtype.startswith("int"):
logccdf = at.logaddexp(logccdf, logprob)

logprob = at.switch(
96 changes: 85 additions & 11 deletions tests/test_logprob.py
Original file line number Diff line number Diff line change
@@ -7,7 +7,7 @@
from aesara import function

from aeppl.dists import dirac_delta
from aeppl.logprob import ParameterValueError, logcdf, logprob
from aeppl.logprob import ParameterValueError, icdf, logcdf, logprob

# @pytest.fixture(scope="module", autouse=True)
# def set_aesara_flags():
@@ -48,7 +48,7 @@ def create_aesara_params(dist_params, obs, size):


def scipy_logprob_tester(
rv_var, obs, dist_params, test_fn=None, check_broadcastable=True, test_logcdf=False
rv_var, obs, dist_params, test_fn=None, check_broadcastable=True, test="logprob"
):
"""Test for correspondence between `RandomVariable` and NumPy shape and
broadcast dimensions.
@@ -61,10 +61,15 @@ def scipy_logprob_tester(

test_fn = getattr(stats, name)

if not test_logcdf:
if test == "logprob":
aesara_res = logprob(rv_var, at.as_tensor(obs))
else:
elif test == "logcdf":
aesara_res = logcdf(rv_var, at.as_tensor(obs))
elif test == "icdf":
aesara_res = icdf(rv_var, at.as_tensor(obs))
else:
raise ValueError(f"test must be one of (logprob, logcdf, icdf), got {test}")

aesara_res_val = aesara_res.eval(dist_params)

numpy_res = np.asarray(test_fn(obs, *dist_params.values()))
@@ -118,7 +123,7 @@ def test_uniform_logcdf(dist_params, obs, size):
def scipy_logcdf(obs, l, u):
return stats.uniform.logcdf(obs, loc=l, scale=u - l)

scipy_logprob_tester(x, obs, dist_params, test_fn=scipy_logcdf, test_logcdf=True)
scipy_logprob_tester(x, obs, dist_params, test_fn=scipy_logcdf, test="logcdf")


@pytest.mark.parametrize(
@@ -154,9 +159,25 @@ def test_normal_logcdf(dist_params, obs, size):

x = at.random.normal(*dist_params_at, size=size_at)

scipy_logprob_tester(
x, obs, dist_params, test_fn=stats.norm.logcdf, test_logcdf=True
)
scipy_logprob_tester(x, obs, dist_params, test_fn=stats.norm.logcdf, test="logcdf")


@pytest.mark.parametrize(
"dist_params, obs, size",
[
((0, 1), np.array([-0.5, 0, 0.3, 0.5, 1, 1.5], dtype=np.float64), ()),
((-1, 20), np.array([-0.5, 0, 0.3, 0.5, 1, 1.5], dtype=np.float64), ()),
((-1, 20), np.array([-0.5, 0, 0.3, 0.5, 1, 1.5], dtype=np.float64), (2, 3)),
],
)
def test_normal_icdf(dist_params, obs, size):

dist_params_at, obs_at, size_at = create_aesara_params(dist_params, obs, size)
dist_params = dict(zip(dist_params_at, dist_params))

x = at.random.normal(*dist_params_at, size=size_at)

scipy_logprob_tester(x, obs, dist_params, test_fn=stats.norm.ppf, test="icdf")


@pytest.mark.parametrize(
@@ -705,9 +726,7 @@ def scipy_logcdf(obs, mu):
return stats.poisson.logcdf(obs, mu)

with cm:
scipy_logprob_tester(
x, obs, dist_params, test_fn=scipy_logcdf, test_logcdf=True
)
scipy_logprob_tester(x, obs, dist_params, test_fn=scipy_logcdf, test="logcdf")


@pytest.mark.parametrize(
@@ -776,6 +795,61 @@ def scipy_logprob(obs, p):
scipy_logprob_tester(x, obs, dist_params, test_fn=scipy_logprob)


@pytest.mark.parametrize(
"dist_params, obs, size, error",
[
((-1,), np.array([0, 1, 100, 10000], dtype=np.int64), (), True),
((0.1,), np.array([0, 1, 100, 10000], dtype=np.int64), (), False),
((1.0,), np.array([0, 1, 100, 10000], dtype=np.int64), (3, 2), False),
(
(np.array([0.01, 0.2, 0.8]),),
np.array([-1, 1, 84], dtype=np.int64),
(),
False,
),
],
)
def test_geometric_logcdf(dist_params, obs, size, error):

dist_params_at, obs_at, size_at = create_aesara_params(dist_params, obs, size)
dist_params = dict(zip(dist_params_at, dist_params))

x = at.random.geometric(*dist_params_at, size=size_at)

cm = contextlib.suppress() if not error else pytest.raises(ParameterValueError)

with cm:
scipy_logprob_tester(
x, obs, dist_params, test_fn=stats.geom.logcdf, test="logcdf"
)


@pytest.mark.parametrize(
"dist_params, obs, size",
[
((0.1,), np.array([-0.5, 0, 0.1, 0.5, 0.9, 1.0, 1.5], dtype=np.int64), ()),
((0.5,), np.array([-0.5, 0, 0.1, 0.5, 0.9, 1.0, 1.5], dtype=np.int64), (3, 2)),
(
(np.array([0.0, 0.2, 0.5, 1.0]),),
np.array([0.7, 0.7, 0.7, 0.7], dtype=np.int64),
(),
),
],
)
def test_geometric_icdf(dist_params, obs, size):

dist_params_at, obs_at, size_at = create_aesara_params(dist_params, obs, size)
dist_params = dict(zip(dist_params_at, dist_params))

x = at.random.geometric(*dist_params_at, size=size_at)

def scipy_geom_icdf(value, p):
# Scipy ppf returns floats
return stats.geom.ppf(value, p).astype(value.dtype)

scipy_logprob_tester(x, obs, dist_params, test_fn=scipy_geom_icdf, test="icdf")


@pytest.mark.parametrize(
"dist_params, obs, size, error",
[
1 change: 1 addition & 0 deletions tests/test_scan.py
Original file line number Diff line number Diff line change
@@ -345,6 +345,7 @@ def scan_fn(mus_t, sigma_t, Y_t_val, S_t_val, Gamma_t):
assert np.allclose(y_logp_val, y_logp_ref_val)


@pytest.mark.xfail(reason="see #148")
@aesara.config.change_flags(compute_test_value="raise")
@pytest.mark.xfail(reason="see #148")
def test_initial_values():