Make infer_discrete work with scan #991

Open

wants to merge 11 commits into master
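For context, here is a minimal sketch of the usage this PR enables, mirroring the smoke test added below (the toy HMM and its `hidden_dim` default are illustrative; imports assume numpyro's public `contrib` API):

```python
import jax.numpy as jnp
from jax import random

import numpyro
import numpyro.distributions as dist
from numpyro.contrib.control_flow import scan
from numpyro.contrib.funsor import config_enumerate, infer_discrete

def hmm(data, hidden_dim=10):
    transition = 0.3 / hidden_dim + 0.7 * jnp.eye(hidden_dim)
    means = jnp.arange(float(hidden_dim))

    def transition_fn(state, y):
        state = numpyro.sample("states", dist.Categorical(transition[state]))
        numpyro.sample("obs", dist.Normal(means[state], 1.0), obs=y)
        return state, state

    _, states = scan(transition_fn, 0, data)
    return states

data = jnp.array([0.0, 1.0, 2.0, 9.0])
# temperature=0 decodes the MAP state sequence; temperature=1 draws a
# posterior sample of the discrete states.
decoder = infer_discrete(
    config_enumerate(hmm), temperature=0, rng_key=random.PRNGKey(1)
)
states = decoder(data)  # inferred hidden states, shape (4,)
```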
106 changes: 40 additions & 66 deletions numpyro/contrib/funsor/discrete.py
@@ -1,16 +1,16 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

from collections import OrderedDict
from collections import OrderedDict, defaultdict
import functools

from jax import random
import jax.numpy as jnp

import funsor
from numpyro.contrib.funsor.enum_messenger import enum, trace as packed_trace
from numpyro.contrib.funsor.infer_util import plate_to_enum_plate
from numpyro.distributions.util import is_identically_one
from numpyro.handlers import block, replay, seed, trace
from numpyro.contrib.funsor.enum_messenger import enum
from numpyro.contrib.funsor.infer_util import _enum_log_density, _get_shift, _shift_name
from numpyro.handlers import block, seed, substitute, trace
from numpyro.infer.util import _guess_max_plate_nesting


@@ -38,46 +38,6 @@ def _get_support_value_delta(funsor_dist, name, **kwargs):
return OrderedDict(funsor_dist.terms)[name][0]


def terms_from_trace(tr):
"""Helper function to extract elbo components from execution traces."""
log_factors = {}
log_measures = {}
sum_vars, prod_vars = frozenset(), frozenset()
for site in tr.values():
if site["type"] == "sample":
value = site["value"]
intermediates = site["intermediates"]
scale = site["scale"]
if intermediates:
log_prob = site["fn"].log_prob(value, intermediates)
else:
log_prob = site["fn"].log_prob(value)

if (scale is not None) and (not is_identically_one(scale)):
log_prob = scale * log_prob

dim_to_name = site["infer"]["dim_to_name"]
log_prob_factor = funsor.to_funsor(
log_prob, output=funsor.Real, dim_to_name=dim_to_name
)

if site["is_observed"]:
log_factors[site["name"]] = log_prob_factor
else:
log_measures[site["name"]] = log_prob_factor
sum_vars |= frozenset({site["name"]})
prod_vars |= frozenset(
f.name for f in site["cond_indep_stack"] if f.dim is not None
)

return {
"log_factors": log_factors,
"log_measures": log_measures,
"measure_vars": sum_vars,
"plate_vars": prod_vars,
}


def _sample_posterior(
model, first_available_dim, temperature, rng_key, *args, **kwargs
):
@@ -97,27 +57,14 @@ def _sample_posterior(
model_trace = trace(seed(model, rng_key)).get_trace(*args, **kwargs)
first_available_dim = -_guess_max_plate_nesting(model_trace) - 1

with block(), enum(first_available_dim=first_available_dim):
with plate_to_enum_plate():
model_tr = packed_trace(model).get_trace(*args, **kwargs)

terms = terms_from_trace(model_tr)
# terms["log_factors"] = [log p(x) for each observed or latent sample site x]
# terms["log_measures"] = [log p(z) or other Dice factor
# for each latent sample site z]

with funsor.interpretations.lazy:
log_prob = funsor.sum_product.sum_product(
sum_op,
prod_op,
list(terms["log_factors"].values()) + list(terms["log_measures"].values()),
eliminate=terms["measure_vars"] | terms["plate_vars"],
plates=terms["plate_vars"],
)
log_prob = funsor.optimizer.apply_optimizer(log_prob)
with funsor.adjoint.AdjointTape() as tape:
with block(), enum(first_available_dim=first_available_dim):
log_prob, model_tr, log_measures = _enum_log_density(
model, args, kwargs, {}, sum_op, prod_op
)

with approx:
approx_factors = funsor.adjoint.adjoint(sum_op, prod_op, log_prob)
approx_factors = tape.adjoint(sum_op, prod_op, log_prob)
Member: this is the line where Scatter and Deltas are introduced

# construct a result trace to replay against the model
sample_tr = model_tr.copy()
@@ -138,13 +85,40 @@ def _sample_posterior(
value, name_to_dim=node["infer"]["name_to_dim"]
)
else:
log_measure = approx_factors[terms["log_measures"][name]]
log_measure = approx_factors[log_measures[name]]
sample_subs[name] = _get_support_value(log_measure, name)
node["value"] = funsor.to_data(
sample_subs[name], name_to_dim=node["infer"]["name_to_dim"]
)

with replay(guide_trace=sample_tr):
data = {
name: site["value"]
for name, site in sample_tr.items()
if site["type"] == "sample"
}

# concatenate _PREV_foo to foo
time_vars = defaultdict(list)
for name in data:
if name.startswith("_PREV_"):
root_name = _shift_name(name, -_get_shift(name))
time_vars[root_name].append(name)
for name in time_vars:
if name in data:
time_vars[name].append(name)
time_vars[name] = sorted(time_vars[name], key=len, reverse=True)
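# The sort above puts longer names (more "_PREV_" prefixes, i.e. larger
# shifts and earlier time steps) first, so the concatenation below
# reassembles the full sequence in time order.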

for root_name, vars in time_vars.items():
prototype_shape = model_trace[root_name]["value"].shape
values = [data.pop(name) for name in vars]
if len(values) == 1:
data[root_name] = values[0].reshape(prototype_shape)
else:
assert len(prototype_shape) >= 1
values = [v.reshape((-1,) + prototype_shape[1:]) for v in values]
data[root_name] = jnp.concatenate(values)

with substitute(data=data):
return model(*args, **kwargs)


69 changes: 42 additions & 27 deletions numpyro/contrib/funsor/infer_util.py
@@ -100,6 +100,8 @@ def compute_markov_factors(
sum_vars,
prod_vars,
history,
sum_op,
prod_op,
):
"""
:param dict time_to_factors: a map from time variable to the log prob factors.
@@ -119,8 +121,8 @@
eliminate_vars = (sum_vars | prod_vars) - time_to_markov_dims[time_var]
with funsor.interpretations.lazy:
lazy_result = funsor.sum_product.sum_product(
funsor.ops.logaddexp,
funsor.ops.add,
sum_op,
prod_op,
log_factors,
eliminate=eliminate_vars,
plates=prod_vars,
@@ -136,41 +138,22 @@
)
markov_factors.append(
funsor.sum_product.sarkka_bilmes_product(
funsor.ops.logaddexp, funsor.ops.add, trans, time_var, global_vars
sum_op, prod_op, trans, time_var, global_vars
)
)
else:
# remove `_PREV_` prefix to convert prev to curr
prev_to_curr = {k: _shift_name(k, -_get_shift(k)) for k in prev_vars}
markov_factors.append(
funsor.sum_product.sequential_sum_product(
funsor.ops.logaddexp, funsor.ops.add, trans, time_var, prev_to_curr
sum_op, prod_op, trans, time_var, prev_to_curr
)
)
return markov_factors
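A note on the new `sum_op`/`prod_op` parameters: threading them through lets `_sample_posterior` run the same sum-product machinery under a max-product semiring for MAP decoding. A hedged sketch of the operator choice (`pick_ops` is a hypothetical helper; the actual selection happens inside `_sample_posterior` in `discrete.py`):

```python
import funsor

def pick_ops(temperature):
    # temperature == 0: Viterbi-style MAP decoding over discrete sites;
    # otherwise: ordinary log-space marginalization / posterior sampling.
    if temperature == 0:
        return funsor.ops.max, funsor.ops.add
    return funsor.ops.logaddexp, funsor.ops.add

sum_op, prod_op = pick_ops(0)  # -> (max, add)
```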


def log_density(model, model_args, model_kwargs, params):
"""
Similar to :func:`numpyro.infer.util.log_density` but works for models
with discrete latent variables. Internally, this uses :mod:`funsor`
to marginalize discrete latent sites and evaluate the joint log probability.

:param model: Python callable containing NumPyro primitives. Typically,
the model has been enumerated by using
:class:`~numpyro.contrib.funsor.enum_messenger.enum` handler::

def model(*args, **kwargs):
...

log_joint = log_density(enum(config_enumerate(model)), args, kwargs, params)

:param tuple model_args: args provided to the model.
:param dict model_kwargs: kwargs provided to the model.
:param dict params: dictionary of current parameter values keyed by site
name.
:return: log of joint density and a corresponding model trace
"""
def _enum_log_density(model, model_args, model_kwargs, params, sum_op, prod_op):
"""Helper function to compute elbo and extract its components from execution traces."""
model = substitute(model, data=params)
with plate_to_enum_plate():
model_trace = packed_trace(model).get_trace(*model_args, **model_kwargs)
@@ -180,6 +163,7 @@ def model(*args, **kwargs):
time_to_markov_dims = defaultdict(frozenset) # dimensions at markov sites
sum_vars, prod_vars = frozenset(), frozenset()
history = 1
log_measures = {}
for site in model_trace.values():
if site["type"] == "sample":
value = site["value"]
@@ -214,7 +198,9 @@ def model(*args, **kwargs):
log_factors.append(log_prob_factor)

if not site["is_observed"]:
log_measures[site["name"]] = log_prob_factor
sum_vars |= frozenset({site["name"]})

prod_vars |= frozenset(
f.name for f in site["cond_indep_stack"] if f.dim is not None
)
@@ -236,13 +222,15 @@ def model(*args, **kwargs):
sum_vars,
prod_vars,
history,
sum_op,
prod_op,
)
log_factors = log_factors + markov_factors

with funsor.interpretations.lazy:
lazy_result = funsor.sum_product.sum_product(
funsor.ops.logaddexp,
funsor.ops.add,
sum_op,
prod_op,
log_factors,
eliminate=sum_vars | prod_vars,
plates=prod_vars,
@@ -255,4 +243,31 @@
result.data.shape, {k.split("__BOUND")[0] for k in result.inputs}
)
)
return result, model_trace, log_measures


def log_density(model, model_args, model_kwargs, params):
"""
Similar to :func:`numpyro.infer.util.log_density` but works for models
with discrete latent variables. Internally, this uses :mod:`funsor`
to marginalize discrete latent sites and evaluate the joint log probability.

:param model: Python callable containing NumPyro primitives. Typically,
the model has been enumerated by using
:class:`~numpyro.contrib.funsor.enum_messenger.enum` handler::

def model(*args, **kwargs):
...

log_joint = log_density(enum(config_enumerate(model)), args, kwargs, params)

:param tuple model_args: args provided to the model.
:param dict model_kwargs: kwargs provided to the model.
:param dict params: dictionary of current parameter values keyed by site
name.
:return: log of joint density and a corresponding model trace
"""
result, model_trace, _ = _enum_log_density(
model, model_args, model_kwargs, params, funsor.ops.logaddexp, funsor.ops.add
)
return result.data, model_trace
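The public entry point keeps its old signature and now delegates to `_enum_log_density`. A minimal sketch of calling it on an enumerated model (the two-component model is hypothetical; the call pattern follows the docstring above, and the imports assume `numpyro.contrib.funsor` re-exports these helpers):

```python
import jax.numpy as jnp

import numpyro
import numpyro.distributions as dist
from numpyro.contrib.funsor import config_enumerate, enum, log_density

def model(data):
    # discrete latent site that enumeration marginalizes out
    c = numpyro.sample("c", dist.Categorical(jnp.array([0.4, 0.6])))
    numpyro.sample("obs", dist.Normal(jnp.array([-1.0, 1.0])[c], 1.0), obs=data)

log_joint, model_trace = log_density(
    enum(config_enumerate(model)), (jnp.array(0.5),), {}, {}
)
```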
33 changes: 33 additions & 0 deletions test/contrib/test_infer_discrete.py
@@ -12,6 +12,7 @@

import numpyro
from numpyro import handlers, infer
from numpyro.contrib.control_flow import scan
import numpyro.distributions as dist
from numpyro.distributions.util import is_identically_one

@@ -81,6 +82,38 @@ def hmm(data, hidden_dim=10):
logger.info("inferred states: {}".format(list(map(int, inferred_states))))


@pytest.mark.parametrize("length", [1, 2, 10])
@pytest.mark.parametrize("temperature", [0, 1])
def test_scan_hmm_smoke(length, temperature):

# This should match the example in the infer_discrete docstring.
def hmm(data, hidden_dim=10):
transition = 0.3 / hidden_dim + 0.7 * jnp.eye(hidden_dim)
means = jnp.arange(float(hidden_dim))

def transition_fn(state, y):
state = numpyro.sample("states", dist.Categorical(transition[state]))
y = numpyro.sample("obs", dist.Normal(means[state], 1.0), obs=y)
return state, (state, y)

_, (states, data) = scan(transition_fn, 0, data, length=length)

return [0] + [s for s in states], data

true_states, data = handlers.seed(hmm, 0)(None)
assert len(data) == length
assert len(true_states) == 1 + len(data)

decoder = infer_discrete(
config_enumerate(hmm), temperature=temperature, rng_key=random.PRNGKey(1)
)
inferred_states, _ = decoder(data)
assert len(inferred_states) == len(true_states)

logger.info("true states: {}".format(list(map(int, true_states))))
logger.info("inferred states: {}".format(list(map(int, inferred_states))))


def vectorize_model(model, size, dim):
def fn(*args, **kwargs):
with numpyro.plate("particles", size=size, dim=dim):