Skip to content

Commit

Permalink
Add Uniform pareto conjugates
Browse files Browse the repository at this point in the history
  • Loading branch information
Jing Xie committed Mar 23, 2023
1 parent 64b0e50 commit 18a1c16
Show file tree
Hide file tree
Showing 2 changed files with 158 additions and 1 deletion.
86 changes: 85 additions & 1 deletion aemcmc/conjugates.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from aesara.graph.rewriting.basic import in2out, node_rewriter
from aesara.graph.rewriting.db import LocalGroupDB
from aesara.graph.rewriting.unify import eval_if_etuple
from aesara.tensor.random.basic import BinomialRV, NegBinomialRV, PoissonRV
from aesara.tensor.random.basic import BinomialRV, NegBinomialRV, PoissonRV, UniformRV
from etuples import etuple, etuplize
from kanren import eq, lall, run
from unification import var
Expand Down Expand Up @@ -268,13 +268,97 @@ def local_beta_negative_binomial_posterior(fgraph, node):
return rv_var.owner.outputs


def uniform_pareto_conjugateo(observed_val, observed_rv_expr, posterior_expr):
r"""Produce a goal that represents the application of Bayes theorem
for a pareto prior with a uniform with 0 as the lower bound observation model.
.. math::
Y \sim \operatorname{Uniform}\left(0, \theta\right)
Parameters
----------
observed_val
The observed value.
observed_rv_expr
An expression that represents the observed variable.
posterior_exp
An expression that represents the posterior distribution of the latent
variable.
"""
# beta-negative_binomial observation model
x_lv, k_lv = var(), var()
theta_rng_lv = var()
theta_size_lv = var()
theta_type_idx_lv = var()
theta_et = etuple(
etuplize(at.random.pareto),
theta_rng_lv,
theta_size_lv,
theta_type_idx_lv,
k_lv,
x_lv,
)
Y_et = etuple(etuplize(at.random.uniform), var(), var(), var(), var(), theta_et)

new_x_et = etuple(at.math.max, observed_val)
new_k_et = etuple(etuplize(at.add), k_lv, 1)

theta_posterior_et = etuple(
etuplize(at.random.pareto),
new_k_et,
new_x_et,
rng=theta_rng_lv,
size=theta_size_lv,
dtype=theta_type_idx_lv,
)
return lall(
eq(observed_rv_expr, Y_et),
eq(posterior_expr, theta_posterior_et),
)


@node_rewriter([UniformRV])
def local_uniform_pareto_posterior(fgraph, node):
sampler_mappings = getattr(fgraph, "sampler_mappings", None)

rv_var = node.outputs[1]
key = ("local_beta_negative_binomial_posterior", rv_var)

if sampler_mappings is None or key in sampler_mappings.rvs_seen:
return None # pragma: no cover

q = var()

rv_et = etuplize(rv_var)

res = run(None, q, uniform_pareto_conjugateo(rv_var, rv_et, q))
res = next(res, None)

if res is None:
return None # pragma: no cover

pareto_rv = rv_et[-1].evaled_obj
pareto_posterior = eval_if_etuple(res)

sampler_mappings.rvs_to_samplers.setdefault(pareto_rv, []).append(
("local_uniform_pareto_posterior", pareto_posterior, None)
)
sampler_mappings.rvs_seen.add(key)

return rv_var.owner.outputs


conjugates_db = LocalGroupDB(apply_all_rewrites=True)
conjugates_db.name = "conjugates_db"
conjugates_db.register("beta_binomial", local_beta_binomial_posterior, "basic")
conjugates_db.register("gamma_poisson", local_gamma_poisson_posterior, "basic")
conjugates_db.register(
"negative_binomial", local_beta_negative_binomial_posterior, "basic"
)
conjugates_db.register("uniform", local_uniform_pareto_posterior, "basic")


sampler_finder_db.register(
Expand Down
73 changes: 73 additions & 0 deletions tests/test_conjugates.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@
import pytest
from aesara.graph.rewriting.unify import eval_if_etuple
from aesara.tensor.random import RandomStream
from etuples import etuple, etuplize
from kanren import run
from unification import var

from aemcmc.conjugates import (
beta_binomial_conjugateo,
beta_negative_binomial_conjugateo,
gamma_poisson_conjugateo,
uniform_pareto_conjugateo,
)


Expand Down Expand Up @@ -157,3 +159,74 @@ def test_beta_negative_binomial_conjugate_expand():
expanded = eval_if_etuple(expanded_expr)

assert isinstance(expanded.owner.op, type(at.random.beta))


def test_uniform_pareto_conjugate_contract():
"""Produce the closed-form posterior for the uniform observation model with
a pareto prior.
"""
srng = RandomStream(0)

xm_tt = at.scalar("xm")
k_tt = at.scalar("k")
theta_rv = srng.pareto(k_tt, xm_tt, name="theta")

# zero = at.iscalar("zero")
Y_rv = srng.uniform(0, theta_rv)
y_vv = Y_rv.clone()
y_vv.tag.name = "y"

q_lv = var()
(posterior_expr,) = run(1, q_lv, uniform_pareto_conjugateo(y_vv, Y_rv, q_lv))
posterior = eval_if_etuple(posterior_expr)

assert isinstance(posterior.owner.op, type(at.random.pareto))

# Build the sampling function and check the results on limiting cases.
sample_fn = aesara.function((xm_tt, k_tt, y_vv), posterior)
assert sample_fn(1.0, 1000, 1) == pytest.approx(1.0, abs=0.01) # k = 1000
assert sample_fn(1.0, 1, 0) == pytest.approx(0.0, abs=0.01) # all zeros


def test_uniform_pareto_binomial_conjugate_expand():
"""Expand a contracted beta-binomial observation model."""

srng = RandomStream(0)

k_tt = at.scalar("k")
y_vv = at.iscalar("y")
n_tt = at.scalar("n")

Y_rv = srng.pareto(at.max(y_vv), k_tt + n_tt)
etuplize(Y_rv)

# e_lv = var()
# (expanded_expr,) = run(1, e_lv, uniform_pareto_conjugateo(e_lv, y_vv, Y_rv))
# expanded = eval_if_etuple(expanded_expr)

# assert isinstance(expanded.owner.op, type(at.random.pareto))
from aesara.tensor.math import MaxAndArgmax
from kanren import eq, run
from unification import var

observed_val = var()
axis_lv = var()
new_x_et = etuple(etuple(MaxAndArgmax, axis_lv), observed_val)

k_lv, n_lv = var(), var()
new_k_et = etuple(etuplize(at.add), k_lv, n_lv)

theta_rng_lv = var()
theta_size_lv = var()
theta_type_idx_lv = var()
theta_posterior_et = etuple(
etuplize(at.random.pareto),
theta_rng_lv,
theta_size_lv,
theta_type_idx_lv,
new_x_et,
new_k_et,
)

run(0, (new_x_et, new_k_et), eq(Y_rv, theta_posterior_et))

0 comments on commit 18a1c16

Please sign in to comment.