WIP: Automatic marginalization of finite discrete variables
1 parent 307913d · commit a21e1f8
Showing 2 changed files with 398 additions and 0 deletions.
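For context, the workflow this commit adds, as exercised by `test_marginalize` below, looks roughly like the following sketch (model and variable names are taken from the test file; the API is still WIP):

import aesara.tensor as at
import pymc as pm

from pymc_experimental.marginal_model import MarginalModel

with MarginalModel() as m:
    sigma = pm.HalfNormal("sigma")
    idx = pm.Categorical("idx", p=[0.1, 0.3, 0.6])
    mu = at.switch(at.eq(idx, 0), -1, at.switch(at.eq(idx, 1), 0, 1))
    y = pm.Normal("y", mu=mu, sigma=sigma)
    z = pm.Normal("z", y, observed=[2] * 5)

# Replaces y with a FiniteDiscreteMarginalRV and removes idx from the model;
# the returned dict maps each replaced RV to its marginalized counterpart.
replacements = m.marginalize([idx])
logp_fn = m.compile_logp()  # idx no longer appears in the logp graph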
pymc_experimental/marginal_model.py (new file, +249 lines):
from typing import Sequence, Tuple, Union

import aesara.tensor as at
import numpy as np
from aeppl import factorized_joint_logprob
from aeppl.logprob import _logprob
from aesara import clone_replace
from aesara.compile import SharedVariable
from aesara.compile.builders import OpFromGraph
from aesara.graph import Constant, FunctionGraph, ancestors
from aesara.tensor import TensorVariable
from aesara.tensor.elemwise import Elemwise
from pymc import SymbolicRandomVariable
from pymc.aesaraf import inputvars
from pymc.distributions.discrete import Bernoulli, Categorical, DiscreteUniform
from pymc.model import Model

def _replace_marginalized_subgraph(fgraph, rv_to_marginalize):
    # Check if marginalization is even valid (WIP: unfinished stub)
    temp_fgraph = FunctionGraph(inputs=[rv_to_marginalize], outputs=fgraph.outputs, clone=False)

class MarginalModel(Model):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        if self.parent is not None:
            self.marginalized_rvs = self.parent.marginalized_rvs
        else:
            self.marginalized_rvs = []

    def marginalize(self, rvs_to_marginalize: Union[TensorVariable, Sequence[TensorVariable]]):
        # TODO: this does not need to be a property of a Model
        if not isinstance(rvs_to_marginalize, Sequence):
            rvs_to_marginalize = (rvs_to_marginalize,)

        supported_dists = (Bernoulli, Categorical, DiscreteUniform)
        for rv_to_marginalize in rvs_to_marginalize:
            if rv_to_marginalize not in self.free_RVs:
                raise ValueError(
                    f"Marginalized RV {rv_to_marginalize} is not a free RV in the model"
                )
            if not isinstance(rv_to_marginalize.owner.op, supported_dists):
                raise NotImplementedError(
                    f"RV with distribution {rv_to_marginalize.owner.op} cannot be marginalized. "
                    f"Supported distributions include {supported_dists}"
                )

        if self.deterministics:
            # TODO: This should be fine if deterministics do not depend on marginalized RVs
            raise NotImplementedError("Models with deterministics cannot be marginalized")

        if self.potentials:
            raise NotImplementedError("Models with potentials cannot be marginalized")

        # Replace the subgraph that needs to be marginalized for each RV
        fg = FunctionGraph(outputs=self.basic_RVs, clone=False)
        toposort = fg.toposort()
        replacements = {}
        for rv_to_marginalize in sorted(
            rvs_to_marginalize, key=lambda rv: toposort.index(rv.owner)
        ):
            old_rvs, new_rvs = _replace_finite_discrete_marginal_subgraph(
                fg, rv_to_marginalize, self.rvs_to_values
            )
            # Update old mappings
            for old_rv, new_rv in zip(old_rvs, new_rvs):
                replacements[old_rv] = new_rv
                if old_rv in self.free_RVs:
                    index = self.free_RVs.index(old_rv)
                    self.free_RVs.pop(index)
                    self.free_RVs.insert(index, new_rv)
                else:
                    index = self.observed_RVs.index(old_rv)
                    self.observed_RVs.pop(index)
                    self.observed_RVs.insert(index, new_rv)
                self.rvs_to_values[new_rv] = value = self.rvs_to_values.pop(old_rv)
                self.values_to_rvs[value] = new_rv
                self.rvs_to_transforms[new_rv] = self.rvs_to_transforms.pop(old_rv)
                # TODO: Automatic imputation RVs do not seem to have a total_size mapping
                self.rvs_to_total_sizes[new_rv] = self.rvs_to_total_sizes.pop(old_rv, None)

            # This RV can now be safely ignored in the logp graph
            self.free_RVs.remove(rv_to_marginalize)
            value = self.rvs_to_values.pop(rv_to_marginalize)
            self.values_to_rvs.pop(value)
            self.rvs_to_transforms.pop(rv_to_marginalize)
            self.rvs_to_total_sizes.pop(rv_to_marginalize)

        return replacements

def _find_dependent_rvs(dependable_rv, all_rvs):
    # Find RVs that depend on dependable_rv
    dependent_rvs = []
    for rv in all_rvs:
        if rv is dependable_rv:
            continue
        blockers = [other_rv for other_rv in all_rvs if other_rv is not rv]
        if dependable_rv in ancestors([rv], blockers=blockers):
            dependent_rvs.append(rv)
    return dependent_rvs

def _find_input_rvs(output_rvs, all_rvs):
    blockers = [other_rv for other_rv in all_rvs if other_rv not in output_rvs]
    return [
        var
        for var in ancestors(output_rvs, blockers=blockers)
        if var in blockers
        or (var.owner is None and not isinstance(var, (Constant, SharedVariable)))
    ]

def _is_elemwise_subgraph(rv_to_marginalize, other_input_rvs, output_rvs):
    # TODO: No need to consider apply nodes outside the subgraph...
    fg = FunctionGraph(outputs=output_rvs, clone=False)

    non_elemwise_blockers = [
        o for node in fg.apply_nodes if not isinstance(node.op, Elemwise) for o in node.outputs
    ]
    blocker_candidates = [rv_to_marginalize] + other_input_rvs + non_elemwise_blockers
    blockers = [var for var in blocker_candidates if var not in output_rvs]

    # TODO: We could actually use these truncated inputs to
    # generate a smaller marginalized graph...
    truncated_inputs = [
        var
        for var in ancestors(output_rvs, blockers=blockers)
        if (
            var in blockers
            or (var.owner is None and not isinstance(var, (Constant, SharedVariable)))
        )
    ]

    # Check that we reach the marginalized RV following a pure Elemwise graph
    if rv_to_marginalize not in truncated_inputs:
        return False

    # Check that none of the truncated inputs depends on the marginalized RV
    other_truncated_inputs = [inp for inp in truncated_inputs if inp is not rv_to_marginalize]
    # TODO: We don't need to go all the way to the root variables
    if rv_to_marginalize in ancestors(
        other_truncated_inputs, blockers=[rv_to_marginalize, *other_input_rvs]
    ):
        return False
    return True

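# Illustration of the Elemwise requirement (mirrors the tests accompanying this
# commit): with idx = pm.Bernoulli("idx", p=p, size=2) and mu = at.constant([-1, 1]),
#   pm.math.switch(idx, mu[0], mu[1])  -> pure Elemwise path from idx to y: accepted
#   mu[idx]                            -> indexing (AdvancedSubtensor) path: rejected
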
class FiniteDiscreteMarginalRV(SymbolicRandomVariable):
    pass

def _replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, rvs_to_values):
    # TODO: This should eventually be integrated in a more general routine that can
    # identify other types of supported marginalization, of which finite discrete
    # RVs are just one

    dependent_rvs = _find_dependent_rvs(rv_to_marginalize, rvs_to_values)
    input_rvs = _find_input_rvs(dependent_rvs, rvs_to_values)
    other_input_rvs = [rv for rv in input_rvs if rv is not rv_to_marginalize]
    # We don't need to worry about batched graphs if the RV is scalar.
    # TODO: This eval is a bit hackish
    if np.prod(rv_to_marginalize.shape.eval()) > 1:
        if not _is_elemwise_subgraph(rv_to_marginalize, other_input_rvs, dependent_rvs):
            raise NotImplementedError(
                "The subgraph between a marginalized RV and its dependents includes non-Elemwise operations. "
                "This is currently not supported",
            )

    marginalization_op = FiniteDiscreteMarginalRV(
        inputs=[rv_to_marginalize, *other_input_rvs],
        outputs=dependent_rvs,
        ndim_supp=None,
    )
    # The marginalized RV's logp is accounted for in the marginal logp, so it can be safely ignored
    # rv_to_marginalize = ignore_logprob(rv_to_marginalize)
    marginalized_rvs = marginalization_op(rv_to_marginalize, *other_input_rvs)
    if not isinstance(marginalized_rvs, Sequence):
        marginalized_rvs = (marginalized_rvs,)
    fgraph.replace_all(tuple(zip(dependent_rvs, marginalized_rvs)))
    return dependent_rvs, marginalized_rvs

def _get_domain_of_finite_discrete_rv(rv: TensorVariable) -> Tuple[int, ...]:
    op = rv.owner.op
    if isinstance(op, Bernoulli):
        return (0, 1)
    elif isinstance(op, Categorical):
        p_param = rv.owner.inputs[3]
        return tuple(range(at.get_vector_length(p_param)))
    elif isinstance(op, DiscreteUniform):
        lower, upper = rv.owner.inputs[3:]
        # DiscreteUniform support is inclusive of upper
        return tuple(
            range(
                at.get_scalar_constant_value(lower),
                at.get_scalar_constant_value(upper) + 1,
            )
        )

    raise NotImplementedError(f"Cannot compute domain for op {op}")

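# The logp registered below implements, for an (effectively) scalar marginalized
# RV with finite domain D, the standard marginalization identity:
#
#   logp(y) = logsumexp over k in D of [logp(idx = k) + logp(y | idx = k)]
#
# where the inner joint logp is built by aeppl's factorized_joint_logprob with a
# dummy value variable standing in for the marginalized RV.
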
@_logprob.register(FiniteDiscreteMarginalRV)
def finite_discrete_marginal_rv_logp(op, values, *inputs, **kwargs):

    marginalized_rvs_node = op.make_node(*inputs)
    marginalized_rvs = clone_replace(
        op.inner_outputs,
        replace={u: v for u, v in zip(op.inner_inputs, marginalized_rvs_node.inputs)},
    )

    marginalized_rv, *other_inputs = inputs
    other_inputs = list(inputvars(other_inputs))

    rvs_to_values = {}
    dummy_marginalized_value = marginalized_rv.clone()
    rvs_to_values[marginalized_rv] = dummy_marginalized_value

    rvs_to_values.update(zip(marginalized_rvs, values))
    _logp = at.sum(
        [
            at.sum(factor)
            for factor in factorized_joint_logprob(
                rv_values=rvs_to_values, warn_missing_rvs=False, **kwargs
            ).values()
        ]
    )
    # OpFromGraph does not accept constant inputs...
    _values = [
        value
        for value in rvs_to_values.values()
        if not isinstance(value, (Constant, SharedVariable))
    ]
    # TODO: If we inline the logp graph, optimization becomes incredibly painful for
    # large domains... It would be great to find a way to vectorize the graph across
    # the domain values (when possible)
    logp_op = OpFromGraph([*_values, *other_inputs], [_logp], inline=False)

    # PyMC does not allow RVs in the logp graph... even if we are just using the shape
    # TODO: Find a better work-around
    marginalized_rv_shape = marginalized_rv.shape.eval()
    values = [value for value in values if not isinstance(value, (Constant, SharedVariable))]
    return at.logsumexp(
        [
            logp_op(np.full(marginalized_rv_shape, marginalized_rv_const), *values, *other_inputs)
            for marginalized_rv_const in _get_domain_of_finite_discrete_rv(marginalized_rv)
        ]
    )
Tests (new file, +149 lines):
import aesara.tensor as at
import numpy as np
import pandas as pd
import pymc as pm
import pytest
from aeppl.logprob import _logprob
from aesara.graph import ancestors

from pymc_experimental.marginal_model import FiniteDiscreteMarginalRV, MarginalModel

def test_marginalized_bernoulli_logp():
    """Test the logp of the IR FiniteDiscreteMarginalRV directly"""
    idx = pm.Bernoulli.dist(0.7, name="idx")
    mu = at.constant([-1, 1])[idx]
    y = pm.Normal.dist(mu=mu, sigma=1.0, name="y")
    marginal_y = FiniteDiscreteMarginalRV([idx], [y], ndim_supp=None)(idx)

    y_vv = y.clone()
    marginal_y_logp = _logprob(
        marginal_y.owner.op,
        (y_vv,),
        *marginal_y.owner.inputs,
    )

    ref_logp = pm.logp(pm.NormalMixture.dist(w=[0.3, 0.7], mu=[-1, 1], sigma=1.0), y_vv).sum()
    np.testing.assert_almost_equal(
        marginal_y_logp.eval({y_vv: 2}),
        ref_logp.eval({y_vv: 2}),
    )

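# The reference identity checked above: marginalizing idx ~ Bernoulli(0.7) out of
# y ~ Normal(mu[idx], 1) with mu = (-1, 1) gives the two-component mixture
#   p(y) = 0.3 * N(y | -1, 1) + 0.7 * N(y | 1, 1),
# which is exactly the NormalMixture density the test compares against.
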
def test_marginalize():
    data = [2] * 5

    with pm.Model() as m_ref:
        sigma = pm.HalfNormal("sigma")
        y = pm.NormalMixture("y", w=[0.1, 0.3, 0.6], mu=[-1, 0, 1], sigma=sigma)
        z = pm.Normal("z", y, observed=data)

    with MarginalModel() as m:
        sigma = pm.HalfNormal("sigma")
        idx = pm.Categorical("idx", p=[0.1, 0.3, 0.6])
        mu = at.switch(
            at.eq(idx, 0),
            -1,
            at.switch(
                at.eq(idx, 1),
                0,
                1,
            ),
        )
        y = pm.Normal("y", mu=mu, sigma=sigma)
        z = pm.Normal("z", y, observed=data)

    replacements = m.marginalize([idx])
    assert len(replacements) == 1

    assert y not in m.free_RVs
    assert idx not in m.free_RVs

    new_y = replacements[y]
    assert new_y in m.free_RVs
    assert new_y in ancestors([z])

    assert isinstance(new_y.owner.op, FiniteDiscreteMarginalRV)
    # Ignore RNGs
    assert new_y.owner.inputs[:2] == [idx, sigma]

    test_point = m_ref.initial_point()
    # TODO: Test that we don't get warnings about missing RVs
    np.testing.assert_almost_equal(
        m.compile_logp()(test_point),
        m_ref.compile_logp()(test_point),
    )

def test_marginalize_nested():
    raise NotImplementedError("Must write test")

def test_not_supported_marginalization():
    """Marginalized graphs with non-Elemwise operations are not supported, as they
    would violate the batched logp assumption"""

    mu = at.constant([-1, 1])

    # Allowed, as only Elemwise operations connect idx to y
    with MarginalModel() as m:
        p = pm.Beta("p", 1, 1)
        idx = pm.Bernoulli("idx", p=p, size=2)
        y = pm.Normal("y", mu=pm.math.switch(idx, 0, 1))
        assert m.marginalize([idx])

    # Allowed, as the index operation does not connect idx to y
    with MarginalModel() as m:
        p = pm.Beta("p", 1, 1)
        idx = pm.Bernoulli("idx", p=p, size=2)
        y = pm.Normal("y", mu=pm.math.switch(idx, mu[0], mu[1]))
        assert m.marginalize([idx])

    # Not allowed, as an index operation connects idx to y
    with MarginalModel() as m:
        p = pm.Beta("p", 1, 1)
        idx = pm.Bernoulli("idx", p=p, size=2)
        y = pm.Normal("y", mu=mu[idx])
        with pytest.raises(NotImplementedError):
            m.marginalize(idx)

    # Not allowed, as an index operation connects idx to y, even though there is also
    # a pure Elemwise connection between the two
    with MarginalModel() as m:
        p = pm.Beta("p", 1, 1)
        idx = pm.Bernoulli("idx", p=p, size=2)
        y = pm.Normal("y", mu=mu[idx] + idx)
        with pytest.raises(NotImplementedError):
            m.marginalize(idx)

def test_change_point_model():
    # fmt: off
    disaster_data = pd.Series(
        [4, 5, 4, 0, 1, 4, 3, 4, 0, 6, 3, 3, 4, 0, 2, 6,
         3, 3, 5, 4, 5, 3, 1, 4, 4, 1, 5, 5, 3, 4, 2, 5,
         2, 2, 3, 4, 2, 1, 3, np.nan, 2, 1, 1, 1, 1, 3, 0, 0,
         1, 0, 1, 1, 0, 0, 3, 1, 0, 3, 2, 2, 0, 1, 1, 1,
         0, 1, 0, 1, 0, 0, 0, 2, 1, 0, 0, 0, 1, 1, 0, 2,
         3, 3, 1, np.nan, 2, 1, 1, 1, 1, 2, 4, 2, 0, 0, 1, 4,
         0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1]
    )
    # fmt: on
    years = np.arange(1851, 1962)

    with MarginalModel() as disaster_model:
        switchpoint = pm.DiscreteUniform(
            "switchpoint", lower=years.min(), upper=years.max(), size=1
        )

        early_rate = pm.Exponential("early_rate", 1.0)
        late_rate = pm.Exponential("late_rate", 1.0)
        rate = pm.math.switch(switchpoint >= years, early_rate, late_rate)

        disasters = pm.Poisson("disasters", rate, observed=disaster_data)

    disaster_model.marginalize([switchpoint])
    disaster_model.compile_logp()(disaster_model.initial_point())

    raise NotImplementedError("Test not finished")