WIP Automatic marginalization finite discrete variables
ricardoV94 committed Nov 14, 2022

1 parent 307913d commit a21e1f8
Showing 2 changed files with 398 additions and 0 deletions.
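In brief: for a random variable with finite support (here Bernoulli, Categorical, or DiscreteUniform), the subgraph between it and its dependent RVs is wrapped in a FiniteDiscreteMarginalRV whose logp sums the joint density over every value the marginalized RV can take,

    logp(y) = logsumexp_k [ logp(idx = k) + logp(y | idx = k) ],  for k in the support of idx

so the discrete RV never has to be sampled.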
249 changes: 249 additions & 0 deletions pymc_experimental/marginal_model.py
@@ -0,0 +1,249 @@
from typing import Sequence, Tuple, Union

import aesara.tensor as at
import numpy as np
from aeppl import factorized_joint_logprob
from aeppl.logprob import _logprob
from aesara import clone_replace
from aesara.compile import SharedVariable
from aesara.compile.builders import OpFromGraph
from aesara.graph import Constant, FunctionGraph, ancestors
from aesara.tensor import TensorVariable
from aesara.tensor.elemwise import Elemwise
from pymc import SymbolicRandomVariable
from pymc.aesaraf import inputvars
from pymc.distributions.discrete import Bernoulli, Categorical, DiscreteUniform
from pymc.model import Model


def _replace_marginalized_subgraph(fgraph, rv_to_marginalize):
    # WIP: check whether the replacement is even valid before attempting it
    temp_fgraph = FunctionGraph(inputs=[rv_to_marginalize], outputs=fgraph.outputs, clone=False)


class MarginalModel(Model):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if self.parent is not None:
self.marginalized_rvs = self.parent.marginalized_rvs
else:
self.marginalized_rvs = []

def marginalize(self, rvs_to_marginalize: Union[TensorVariable, Sequence[TensorVariable]]):
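        """Marginalize over the given RVs, removing them from the model.

        Returns a dict mapping each replaced dependent RV to its new
        FiniteDiscreteMarginalRV counterpart.
        """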
# TODO: this does not need to be a property of a Model
if not isinstance(rvs_to_marginalize, Sequence):
rvs_to_marginalize = (rvs_to_marginalize,)

supported_dists = (Bernoulli, Categorical, DiscreteUniform)
for rv_to_marginalize in rvs_to_marginalize:
if rv_to_marginalize not in self.free_RVs:
raise ValueError(
f"Marginalized RV {rv_to_marginalize} is not a free RV in the model"
)
if not isinstance(rv_to_marginalize.owner.op, supported_dists):
raise NotImplementedError(
f"RV with distribution {rv_to_marginalize.owner.op} cannot be marginalized. "
f"Supported distribution include {supported_dists}"
)

if self.deterministics:
# TODO: This should be fine if deterministics do not depend on marginalized RVs
raise NotImplementedError("Models with deterministics cannot be marginalized")

if self.potentials:
raise NotImplementedError("Models with potentials cannot be marginalized")

        # Replace the subgraph that needs to be marginalized for each RV
fg = FunctionGraph(outputs=self.basic_RVs, clone=False)
toposort = fg.toposort()
replacements = {}
for rv_to_marginalize in sorted(
rvs_to_marginalize, key=lambda rv: toposort.index(rv.owner)
):
old_rvs, new_rvs = _replace_finite_discrete_marginal_subgraph(
fg, rv_to_marginalize, self.rvs_to_values
)
# Update old mappings
for old_rv, new_rv in zip(old_rvs, new_rvs):
replacements[old_rv] = new_rv
if old_rv in self.free_RVs:
index = self.free_RVs.index(old_rv)
self.free_RVs.pop(index)
self.free_RVs.insert(index, new_rv)
else:
index = self.observed_RVs.index(old_rv)
self.observed_RVs.pop(index)
self.observed_RVs.insert(index, new_rv)
self.rvs_to_values[new_rv] = value = self.rvs_to_values.pop(old_rv)
self.values_to_rvs[value] = new_rv
self.rvs_to_transforms[new_rv] = self.rvs_to_transforms.pop(old_rv)
# TODO: Automatic imputation RV does not seem to have total_size mapping
self.rvs_to_total_sizes[new_rv] = self.rvs_to_total_sizes.pop(old_rv, None)

# This RV can now be safely ignored in the logp graph
self.free_RVs.remove(rv_to_marginalize)
value = self.rvs_to_values.pop(rv_to_marginalize)
self.values_to_rvs.pop(value)
self.rvs_to_transforms.pop(rv_to_marginalize)
self.rvs_to_total_sizes.pop(rv_to_marginalize)

return replacements


def _find_dependent_rvs(dependable_rv, all_rvs):
    # Find RVs that depend on dependable_rv
dependent_rvs = []
for rv in all_rvs:
if rv is dependable_rv:
continue
blockers = [other_rv for other_rv in all_rvs if other_rv is not rv]
if dependable_rv in ancestors([rv], blockers=blockers):
dependent_rvs.append(rv)
return dependent_rvs


def _find_input_rvs(output_rvs, all_rvs):
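    """Return the model RVs and non-constant root inputs that output_rvs depend on."""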
blockers = [other_rv for other_rv in all_rvs if other_rv not in output_rvs]
return [
var
for var in ancestors(output_rvs, blockers=blockers)
if var in blockers
or (var.owner is None and not isinstance(var, (Constant, SharedVariable)))
]


def _is_elemwise_subgraph(rv_to_marginalize, other_input_rvs, output_rvs):
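    """Check that rv_to_marginalize is connected to output_rvs exclusively through
    Elemwise operations, which is required for the batched logp evaluation over
    the marginalized domain to remain valid.
    """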
# TODO: No need to consider apply nodes outside the subgraph...
fg = FunctionGraph(outputs=output_rvs, clone=False)

non_elemwise_blockers = [
o for node in fg.apply_nodes if not isinstance(node.op, Elemwise) for o in node.outputs
]
blocker_candidates = [rv_to_marginalize] + other_input_rvs + non_elemwise_blockers
blockers = [var for var in blocker_candidates if var not in output_rvs]

# TODO: We could actually use these truncated inputs to
# generate a smaller Marginalized graph...
truncated_inputs = [
var
for var in ancestors(output_rvs, blockers=blockers)
if (
var in blockers
or (var.owner is None and not isinstance(var, (Constant, SharedVariable)))
)
]

# Check that we reach the marginalized rv following a pure elemwise graph
if rv_to_marginalize not in truncated_inputs:
return False

# Check that none of the truncated inputs depends on the marginalized_rv
other_truncated_inputs = [inp for inp in truncated_inputs if inp is not rv_to_marginalize]
# TODO: We don't need to go all the way to the root variables
if rv_to_marginalize in ancestors(
other_truncated_inputs, blockers=[rv_to_marginalize, *other_input_rvs]
):
return False
return True


class FiniteDiscreteMarginalRV(SymbolicRandomVariable):
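    """Marker Op wrapping a marginalized RV and its dependent subgraph, so that
    the specialized `_logprob` implementation below can be dispatched on it.
    """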
pass


def _replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, rvs_to_values):
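    """Replace the dependent RVs of rv_to_marginalize in fgraph by the outputs of
    a FiniteDiscreteMarginalRV Op wrapping the marginalized subgraph.

    Returns the replaced RVs and their replacements.
    """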
# TODO: This should eventually be integrated in a more general routine that can
# identify other types of supported marginalization, of which finite discrete
# RVs is just one

dependent_rvs = _find_dependent_rvs(rv_to_marginalize, rvs_to_values)
input_rvs = _find_input_rvs(dependent_rvs, rvs_to_values)
other_input_rvs = [rv for rv in input_rvs if rv is not rv_to_marginalize]
# We don't need to worry about batched graphs if the RV is scalar.
# TODO: This eval is a bit hackish
if np.prod(rv_to_marginalize.shape.eval()) > 1:
if not _is_elemwise_subgraph(rv_to_marginalize, other_input_rvs, dependent_rvs):
raise NotImplementedError(
"The subgraph between a marginalized RV and its dependents includes non Elemwise operations. "
"This is currently not supported",
)

marginalization_op = FiniteDiscreteMarginalRV(
inputs=[rv_to_marginalize, *other_input_rvs],
outputs=dependent_rvs,
ndim_supp=None,
)
    # The marginalized RV's logp is accounted for in the marginal logp, so it can be safely ignored
# rv_to_marginalize = ignore_logprob(rv_to_marginalize)
marginalized_rvs = marginalization_op(rv_to_marginalize, *other_input_rvs)
if not isinstance(marginalized_rvs, Sequence):
marginalized_rvs = (marginalized_rvs,)
fgraph.replace_all(tuple(zip(dependent_rvs, marginalized_rvs)))
return dependent_rvs, marginalized_rvs


def _get_domain_of_finite_discrete_rv(rv: TensorVariable) -> Tuple[int, ...]:
op = rv.owner.op
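    # RandomVariable inputs are laid out as (rng, size, dtype, *dist_params),
    # so the distribution parameters start at index 3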
if isinstance(op, Bernoulli):
return (0, 1)
elif isinstance(op, Categorical):
p_param = rv.owner.inputs[3]
return tuple(range(at.get_vector_length(p_param)))
elif isinstance(op, DiscreteUniform):
lower, upper = rv.owner.inputs[3:]
return tuple(
range(
at.get_scalar_constant_value(lower),
at.get_scalar_constant_value(upper),
)
)

raise NotImplementedError(f"Cannot compute domain for op {op}")


@_logprob.register(FiniteDiscreteMarginalRV)
def finite_discrete_marginal_rv_logp(op, values, *inputs, **kwargs):
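    """Compute the marginal logp of the dependent RVs by evaluating their joint
    logp at every value in the marginalized RV's finite domain and reducing the
    results with a logsumexp.
    """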

marginalized_rvs_node = op.make_node(*inputs)
marginalized_rvs = clone_replace(
op.inner_outputs,
replace={u: v for u, v in zip(op.inner_inputs, marginalized_rvs_node.inputs)},
)

marginalized_rv, *other_inputs = inputs
other_inputs = list(inputvars(other_inputs))

rvs_to_values = {}
dummy_marginalized_value = marginalized_rv.clone()
rvs_to_values[marginalized_rv] = dummy_marginalized_value

rvs_to_values.update(zip(marginalized_rvs, values))
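    # Joint logp of the dependent RVs, with the marginalized RV's value left as a
    # dummy input that is swept over its domain further below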
_logp = at.sum(
[
at.sum(factor)
for factor in factorized_joint_logprob(
rv_values=rvs_to_values, warn_missing_rvs=False, **kwargs
).values()
]
)
# OpFromGraph does not accept constant inputs...
_values = [
value
for value in rvs_to_values.values()
if not isinstance(value, (Constant, SharedVariable))
]
# TODO: If we inline the logp graph, optimization becomes incredibly painful for
# large domains... Would be great to find a way to vectorize the graph across
# the domain values (when possible)
logp_op = OpFromGraph([*_values, *other_inputs], [_logp], inline=False)

# PyMC does not allow RVs in the logp graph... Even if we are just using the shape
# TODO: Get better work-around
marginalized_rv_shape = marginalized_rv.shape.eval()
values = [value for value in values if not isinstance(value, (Constant, SharedVariable))]
return at.logsumexp(
[
logp_op(np.full(marginalized_rv_shape, marginalized_rv_const), *values, *other_inputs)
for marginalized_rv_const in _get_domain_of_finite_discrete_rv(marginalized_rv)
]
)
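
For intuition, here is a minimal NumPy/SciPy sketch (not part of this commit) of the reduction performed by finite_discrete_marginal_rv_logp, for the Bernoulli case exercised in test_marginalized_bernoulli_logp below; all names are illustrative:

import numpy as np
from scipy.special import logsumexp
from scipy.stats import bernoulli, norm

p, mus, obs = 0.7, np.array([-1.0, 1.0]), 2.0
# Evaluate the joint logp at each point of the marginalized domain {0, 1}...
terms = [bernoulli.logpmf(k, p) + norm.logpdf(obs, loc=mus[k]) for k in (0, 1)]
# ...and reduce with logsumexp: log(0.3 * N(2 | -1, 1) + 0.7 * N(2 | 1, 1))
marginal_logp = logsumexp(terms)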
149 changes: 149 additions & 0 deletions pymc_experimental/tests/test_marginal_model.py
@@ -0,0 +1,149 @@
import aesara.tensor as at
import numpy as np
import pandas as pd
import pymc as pm
import pytest
from aeppl.logprob import _logprob
from aesara.graph import ancestors

from pymc_experimental.marginal_model import FiniteDiscreteMarginalRV, MarginalModel


def test_marginalized_bernoulli_logp():
"""Test logp of IR TestFiniteMarginalDiscreteRV directly"""
idx = pm.Bernoulli.dist(0.7, name="idx")
mu = at.constant([-1, 1])[idx]
y = pm.Normal.dist(mu=mu, sigma=1.0, name="y")
marginal_y = FiniteDiscreteMarginalRV([idx], [y], ndim_supp=None)(idx)

y_vv = y.clone()
marginal_y_logp = _logprob(
marginal_y.owner.op,
(y_vv,),
*marginal_y.owner.inputs,
)

ref_logp = pm.logp(pm.NormalMixture.dist(w=[0.3, 0.7], mu=[-1, 1], sigma=1.0), y_vv).sum()
np.testing.assert_almost_equal(
marginal_y_logp.eval({y_vv: 2}),
ref_logp.eval({y_vv: 2}),
)


def test_marginalize():
data = [2] * 5

with pm.Model() as m_ref:
sigma = pm.HalfNormal("sigma")
y = pm.NormalMixture("y", w=[0.1, 0.3, 0.6], mu=[-1, 0, 1], sigma=sigma)
z = pm.Normal("z", y, observed=data)

with MarginalModel() as m:
sigma = pm.HalfNormal("sigma")
idx = pm.Categorical("idx", p=[0.1, 0.3, 0.6])
mu = at.switch(
at.eq(idx, 0),
-1,
at.switch(
at.eq(idx, 1),
0,
1,
),
)
y = pm.Normal("y", mu=mu, sigma=sigma)
z = pm.Normal("z", y, observed=data)

replacements = m.marginalize([idx])
assert len(replacements) == 1

assert y not in m.free_RVs
assert idx not in m.free_RVs

new_y = replacements[y]
assert new_y in m.free_RVs
assert new_y in ancestors([z])

assert isinstance(new_y.owner.op, FiniteDiscreteMarginalRV)
# Ignore RNGs
assert new_y.owner.inputs[:2] == [idx, sigma]

test_point = m_ref.initial_point()
# TODO: Test we don't get warnings with missing RVs
np.testing.assert_almost_equal(
m.compile_logp()(test_point),
m_ref.compile_logp()(test_point),
)


def test_marginalize_nested():
raise NotImplementedError("Must write test")


def test_not_supported_marginalization():
"""Marginalized graphs with non-Elemwise Operations are not supported as they
would violate the batching logp assumption"""

mu = at.constant([-1, 1])

# Allowed, as only elemwise operations connect idx to y
with MarginalModel() as m:
p = pm.Beta("p", 1, 1)
idx = pm.Bernoulli("idx", p=p, size=2)
y = pm.Normal("y", mu=pm.math.switch(idx, 0, 1))
assert m.marginalize([idx])

    # Allowed, as the index operation does not connect idx to y
with MarginalModel() as m:
p = pm.Beta("p", 1, 1)
idx = pm.Bernoulli("idx", p=p, size=2)
y = pm.Normal("y", mu=pm.math.switch(idx, mu[0], mu[1]))
assert m.marginalize([idx])

# Not allowed, as index operation connects idx to y
with MarginalModel() as m:
p = pm.Beta("p", 1, 1)
idx = pm.Bernoulli("idx", p=p, size=2)
        y = pm.Normal("y", mu=mu[idx])
with pytest.raises(NotImplementedError):
m.marginalize(idx)

# Not allowed, as index operation connects idx to y, even though there is a
# pure Elemwise connection between the two
with MarginalModel() as m:
p = pm.Beta("p", 1, 1)
idx = pm.Bernoulli("idx", p=p, size=2)
y = pm.Normal("y", mu=mu[idx] + idx)
with pytest.raises(NotImplementedError):
m.marginalize(idx)


def test_change_point_model():
# fmt: off
disaster_data = pd.Series(
[4, 5, 4, 0, 1, 4, 3, 4, 0, 6, 3, 3, 4, 0, 2, 6,
3, 3, 5, 4, 5, 3, 1, 4, 4, 1, 5, 5, 3, 4, 2, 5,
2, 2, 3, 4, 2, 1, 3, np.nan, 2, 1, 1, 1, 1, 3, 0, 0,
1, 0, 1, 1, 0, 0, 3, 1, 0, 3, 2, 2, 0, 1, 1, 1,
0, 1, 0, 1, 0, 0, 0, 2, 1, 0, 0, 0, 1, 1, 0, 2,
3, 3, 1, np.nan, 2, 1, 1, 1, 1, 2, 4, 2, 0, 0, 1, 4,
0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1]
)
# fmt: on
years = np.arange(1851, 1962)

with MarginalModel() as disaster_model:
switchpoint = pm.DiscreteUniform(
"switchpoint", lower=years.min(), upper=years.max(), size=1
)

early_rate = pm.Exponential("early_rate", 1.0)
late_rate = pm.Exponential("late_rate", 1.0)
rate = pm.math.switch(switchpoint >= years, early_rate, late_rate)

disasters = pm.Poisson("disasters", rate, observed=disaster_data)

disaster_model.marginalize([switchpoint])
disaster_model.compile_logp()(disaster_model.initial_point())

raise NotImplementedError("Test not finished")
