Add QMC marginalization
ricardoV94 committed Jun 27, 2024
1 parent 83ebb80 commit 0a97875
Showing 2 changed files with 119 additions and 5 deletions.
107 changes: 102 additions & 5 deletions pymc_experimental/model/marginal_model.py
@@ -2,11 +2,12 @@
 from typing import Sequence, Union
 
 import numpy as np
 import pymc
 import pytensor.tensor as pt
+import scipy
 from arviz import InferenceData, dict_to_dataset
-from pymc import SymbolicRandomVariable
+from pymc import SymbolicRandomVariable, icdf
 from pymc.backends.arviz import coords_and_dims_for_inferencedata, dataset_to_point_list
+from pymc.distributions.continuous import Continuous, Normal
 from pymc.distributions.discrete import Bernoulli, Categorical, DiscreteUniform
 from pymc.distributions.transforms import Chain
 from pymc.logprob.abstract import _logprob
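
For reference, pymc.icdf (newly imported above) builds the symbolic inverse-CDF graph of a random variable; it is what the new logp rule further down uses to turn uniform Sobol' points into Gaussian draws. A minimal standalone sketch, with illustrative names not taken from this diff:

import pymc as pm
from pymc import icdf

x = pm.Normal.dist()
quantile = icdf(x, 0.975)  # symbolic quantile function of a standard Normal
print(quantile.eval())  # ~1.96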
@@ -159,7 +160,11 @@ def _marginalize(self, user_warnings=False):
                         f"Cannot marginalize {rv_to_marginalize} due to dependent Potential {pot}"
                     )
 
-            old_rvs, new_rvs = replace_finite_discrete_marginal_subgraph(
+            if isinstance(rv_to_marginalize.owner.op, Continuous):
+                subgraph_builder_fn = replace_continuous_marginal_subgraph
+            else:
+                subgraph_builder_fn = replace_finite_discrete_marginal_subgraph
+            old_rvs, new_rvs = subgraph_builder_fn(
                 fg, rv_to_marginalize, self.basic_RVs + rvs_left_to_marginalize
             )
 
@@ -267,7 +272,11 @@ def marginalize(
             )
 
         rv_op = rv_to_marginalize.owner.op
-        if isinstance(rv_op, DiscreteMarkovChain):
+
+        if isinstance(rv_op, (Bernoulli, Categorical, DiscreteUniform)):
+            pass
+
+        elif isinstance(rv_op, DiscreteMarkovChain):
             if rv_op.n_lags > 1:
                 raise NotImplementedError(
                     "Marginalization for DiscreteMarkovChain with n_lags > 1 is not supported"
@@ -276,7 +285,11 @@
                 raise NotImplementedError(
                     "Marginalization for DiscreteMarkovChain with non-matrix transition probability is not supported"
                 )
-        elif not isinstance(rv_op, (Bernoulli, Categorical, DiscreteUniform)):
+
+        elif isinstance(rv_op, Normal):
+            pass
+
+        else:
             raise NotImplementedError(
                 f"Marginalization of RV with distribution {rv_to_marginalize.owner.op} is not supported"
             )
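
A minimal usage sketch, not part of the diff, of what the new dispatch accepts: a plain Normal RV now passes marginalize() instead of raising NotImplementedError, with the actual graph rewrite happening later in _marginalize:

import numpy as np
import pymc as pm
from pymc_experimental.model.marginal_model import MarginalModel

with MarginalModel() as m:
    x = pm.Normal("x")
    y = pm.Normal("y", mu=x, sigma=1, observed=np.zeros(3))

m.marginalize([x])  # accepted because x's op is a Normal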
@@ -549,6 +562,16 @@ class DiscreteMarginalMarkovChainRV(MarginalRV):
     """Base class for Discrete Marginal Markov Chain RVs"""
 
 
+class QMCMarginalNormalRV(MarginalRV):
+    """Base class for QMC Marginalized RVs"""
+
+    __props__ = ("qmc_order",)
+
+    def __init__(self, *args, qmc_order: int, **kwargs):
+        self.qmc_order = qmc_order
+        super().__init__(*args, **kwargs)
+
+
 def static_shape_ancestors(vars):
     """Identify ancestors Shape Ops of static shapes (therefore constant in a valid graph)."""
     return [
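
qmc_order is the base-2 exponent of the number of quasi-random points: the logp rule below takes 2**qmc_order Sobol' draws, following the convention that Sobol' sequences only achieve their balance properties at powers of two. A standalone illustration of that convention (the seed and dimension here are arbitrary):

import scipy.stats

qmc_order = 13  # the value hardcoded by replace_continuous_marginal_subgraph below
sobol = scipy.stats.qmc.Sobol(d=1, seed=42)
u = sobol.random_base2(m=qmc_order)  # 2**13 = 8192 points in [0, 1)
assert u.shape == (2**qmc_order, 1)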
@@ -707,6 +730,36 @@ def replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, all_rvs
     return rvs_to_marginalize, marginalized_rvs
 
 
+def replace_continuous_marginal_subgraph(fgraph, rv_to_marginalize, all_rvs):
+    dependent_rvs = find_conditional_dependent_rvs(rv_to_marginalize, all_rvs)
+    if not dependent_rvs:
+        raise ValueError(f"No RVs depend on marginalized RV {rv_to_marginalize}")
+
+    marginalized_rv_input_rvs = find_conditional_input_rvs([rv_to_marginalize], all_rvs)
+    dependent_rvs_input_rvs = [
+        rv
+        for rv in find_conditional_input_rvs(dependent_rvs, all_rvs)
+        if rv is not rv_to_marginalize
+    ]
+
+    input_rvs = [*marginalized_rv_input_rvs, *dependent_rvs_input_rvs]
+    rvs_to_marginalize = [rv_to_marginalize, *dependent_rvs]
+
+    outputs = rvs_to_marginalize
+    # We are strict about shared variables in SymbolicRandomVariables
+    inputs = input_rvs + collect_shared_vars(rvs_to_marginalize, blockers=input_rvs)
+
+    marginalized_rvs = QMCMarginalNormalRV(
+        inputs=inputs,
+        outputs=outputs,
+        ndim_supp=max([rv.owner.op.ndim_supp for rv in dependent_rvs]),
+        qmc_order=13,
+    )(*inputs)
+
+    fgraph.replace_all(tuple(zip(rvs_to_marginalize, marginalized_rvs)))
+    return rvs_to_marginalize, marginalized_rvs
+
+
 def get_domain_of_finite_discrete_rv(rv: TensorVariable) -> tuple[int, ...]:
     op = rv.owner.op
     dist_params = rv.owner.op.dist_params(rv.owner)
@@ -870,3 +923,47 @@ def step_alpha(logp_emission, log_alpha, log_P):
     # return is the joint probability of everything together, but PyMC still expects one logp for each one.
     dummy_logps = (pt.constant(0),) * (len(values) - 1)
     return joint_logp, *dummy_logps
+
+
+@_logprob.register(QMCMarginalNormalRV)
+def qmc_marginal_rv_logp(op, values, *inputs, **kwargs):
+    # Clone the inner RV graph of the Marginalized RV
+    marginalized_rvs_node = op.make_node(*inputs)
+    marginalized_rv, *inner_rvs = clone_replace(
+        op.inner_outputs,
+        replace={u: v for u, v in zip(op.inner_inputs, marginalized_rvs_node.inputs)},
+    )
+
+    marginalized_rv_node = marginalized_rv.owner
+    marginalized_rv_op = marginalized_rv_node.op
+
+    # Get QMC draws from the marginalized RV
+    # TODO: Make this an Op
+    rng = marginalized_rv_op.rng_param(marginalized_rv_node)
+    shape = constant_fold(tuple(marginalized_rv.shape))
+    size = np.prod(shape).astype(int)
+    n_draws = 2**op.qmc_order
+    qmc_engine = scipy.stats.qmc.Sobol(d=size, seed=rng.get_value(borrow=False))
+    uniform_draws = qmc_engine.random(n_draws).reshape((n_draws, *shape))
+    qmc_draws = icdf(marginalized_rv, uniform_draws)
+    qmc_draws.name = f"QMC_{op.name}_draws"
+
+    # Obtain the logp of the dependent variables.
+    # We need to include the marginalized RV for correctness; we remove its term later.
+    inner_rv_values = dict(zip(inner_rvs, values))
+    marginalized_vv = marginalized_rv.clone()
+    rv_values = inner_rv_values | {marginalized_rv: marginalized_vv}
+    logps_dict = conditional_logp(rv_values=rv_values, **kwargs)
+    # Pop the logp term corresponding to the marginalized RV
+    # (it is already accounted for by taking the QMC draws from its own distribution)
+    logps_dict.pop(marginalized_vv)
+
+    # Vectorize across QMC draws and take the mean on the log scale
+    core_marginalized_logps = list(logps_dict.values())
+    batched_marginalized_logps = vectorize_graph(
+        core_marginalized_logps, replace={marginalized_vv: qmc_draws}
+    )
+    return tuple(
+        pt.logsumexp(batched_marginalized_logp, axis=0) - pt.log(n_draws)
+        for batched_marginalized_logp in batched_marginalized_logps
+    )
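
To make the estimator above concrete: it approximates log p(y) = log int p(y | x) p(x) dx by pushing Sobol' points through the marginalized RV's inverse CDF and averaging the conditional densities on the log scale. A minimal NumPy/SciPy sketch of the same computation outside PyTensor, using the model from the new test below (Y ~ Normal(2X + 1, 1), X ~ Normal(0, 1)); all names here are illustrative:

import numpy as np
import scipy.special
import scipy.stats

n_draws = 2**13  # matches qmc_order=13
u = scipy.stats.qmc.Sobol(d=1, seed=123).random(n_draws)  # uniforms, shape (n_draws, 1)
x = scipy.stats.norm.ppf(u)  # QMC draws from p(x), via the inverse CDF

y = np.array([1.0, 2.0, 3.0])
logp_cond = scipy.stats.norm.logpdf(y, loc=2 * x + 1, scale=1)  # log p(y | x_i), shape (n_draws, 3)
# log mean_i p(y | x_i) = logsumexp_i log p(y | x_i) - log(n_draws)
logp_marginal = scipy.special.logsumexp(logp_cond, axis=0) - np.log(n_draws)

# Should agree closely with the analytic marginal Y ~ Normal(1, sqrt(5))
np.testing.assert_allclose(logp_marginal, scipy.stats.norm.logpdf(y, 1, np.sqrt(5)), rtol=1e-3)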
17 changes: 17 additions & 0 deletions pymc_experimental/tests/model/test_marginal_model.py
@@ -6,6 +6,7 @@
 import pymc as pm
 import pytensor.tensor as pt
 import pytest
+import scipy
 from arviz import InferenceData, dict_to_dataset
 from pymc.distributions import transforms
 from pymc.logprob.abstract import _logprob
@@ -802,3 +803,19 @@ def create_model(model_class):
         marginal_m.compile_logp()(ip),
         reference_m.compile_logp()(ip),
     )
+
+
+def test_marginalize_normal_via_qmc():
+    with MarginalModel() as m:
+        SD = pm.HalfNormal("SD", default_transform=None)
+        X = pm.Normal("X", sigma=SD)
+        Y = pm.Normal("Y", mu=(2 * X + 1), sigma=1, observed=[1, 2, 3])
+
+    m.marginalize([X])  # ideally method="qmc"
+
+    # P(Y=[1, 2, 3] | SD=1) = int_x P(Y=[1, 2, 3] | SD=1, X=x) P(X=x | SD=1) dx = Norm([1, 2, 3], 1, sqrt(5))
+    [logp_eval] = m.compile_logp(vars=[Y], sum=False)({"SD": 1})
+    np.testing.assert_allclose(
+        logp_eval,
+        scipy.stats.norm.logpdf([1, 2, 3], 1, np.sqrt(5)),
+    )
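
The closed form in the assertion follows from Gaussian marginalization: with SD = 1, E[Y] = 2 * E[X] + 1 = 1 and Var(Y) = 2**2 * Var(X) + 1**2 = 5, so P(Y | SD=1) = Norm(1, sqrt(5)), which is what the QMC estimate is checked against.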
