Skip to content

Commit

Permalink
Add utility function to evaluate log density for individual sites. (#…
Browse files Browse the repository at this point in the history
…1932)

* Add utility function to evaluate log density for individual sites.

* Rename `log_densities` to `compute_log_probs`.

* Mark haiku and flax dropout as `xfail`.
  • Loading branch information
tillahoffmann authored Dec 6, 2024
1 parent 5838a84 commit bf9c715
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 10 deletions.
4 changes: 4 additions & 0 deletions docs/source/utilities.rst
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,10 @@ log_density
-----------
.. autofunction:: numpyro.infer.util.log_density

compute_log_probs
-----------------
.. autofunction:: numpyro.infer.util.compute_log_probs

get_transforms
--------------
.. autofunction:: numpyro.infer.util.get_transforms
Expand Down
44 changes: 34 additions & 10 deletions numpyro/infer/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,21 +54,28 @@ def process_message(self, msg):
msg["value"] = random.PRNGKey(0)


def log_density(model, model_args, model_kwargs, params):
def compute_log_probs(
model,
model_args: tuple,
model_kwargs: dict,
params: dict,
sum_log_prob: bool = True,
):
"""
(EXPERIMENTAL INTERFACE) Computes log of joint density for the model given
(EXPERIMENTAL INTERFACE) Computes log of density for each site of the model given
latent values ``params``.
:param model: Python callable containing NumPyro primitives.
:param tuple model_args: args provided to the model.
:param dict model_kwargs: kwargs provided to the model.
:param dict params: dictionary of current parameter values keyed by site
name.
:return: log of joint density and a corresponding model trace
:param model_args: args provided to the model.
:param model_kwargs: kwargs provided to the model.
:param params: Dictionary of current parameter values keyed by site name.
:param sum_log_prob: sum log probability over batch dimensions.
:return: Dictionary mapping site names to log of density and a corresponding model
trace.
"""
model = substitute(model, data=params)
model_trace = trace(model).get_trace(*model_args, **model_kwargs)
log_joint = jnp.zeros(())
log_joint = {}
for site in model_trace.values():
if site["type"] == "sample":
value = site["value"]
Expand All @@ -94,11 +101,28 @@ def log_density(model, model_args, model_kwargs, params):
if (scale is not None) and (not is_identically_one(scale)):
log_prob = scale * log_prob

log_prob = jnp.sum(log_prob)
log_joint = log_joint + log_prob
log_joint[site["name"]] = jnp.sum(log_prob) if sum_log_prob else log_prob
return log_joint, model_trace


def log_density(model, model_args: tuple, model_kwargs: dict, params: dict):
"""
(EXPERIMENTAL INTERFACE) Computes log of joint density for the model given latent
values ``params``.
:param model: Python callable containing NumPyro primitives.
:param model_args: args provided to the model.
:param model_kwargs: kwargs provided to the model.
:param params: Dictionary of current parameter values keyed by site name.
:return: Log of joint density and a corresponding model trace.
"""
log_joint, model_trace = compute_log_probs(model, model_args, model_kwargs, params)
# We need to start with 0.0 instead of 0 because log_joint may be empty or only
# contain integers, but log_density must be a floating point value to be
# differentiable by jax.
return sum(log_joint.values(), start=0.0), model_trace


class _without_rsample_stop_gradient(numpyro.primitives.Messenger):
"""
Stop gradient for samples at latent sample sites for which has_rsample=False.
Expand Down
2 changes: 2 additions & 0 deletions test/contrib/test_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,7 @@ def model(data, labels):
)


@pytest.mark.xfail(reason="fails due to upgrade from jax 0.4.35 to 0.4.36")
@pytest.mark.parametrize("dropout", [True, False])
@pytest.mark.parametrize("batchnorm", [True, False])
def test_haiku_state_dropout_smoke(dropout, batchnorm):
Expand Down Expand Up @@ -263,6 +264,7 @@ def model():
svi.run(random.PRNGKey(100), 10)


@pytest.mark.xfail(reason="fails due to upgrade from jax 0.4.35 to 0.4.36")
@pytest.mark.parametrize("dropout", [True, False])
@pytest.mark.parametrize("batchnorm", [True, False])
def test_flax_state_dropout_smoke(dropout, batchnorm):
Expand Down
19 changes: 19 additions & 0 deletions test/infer/test_infer_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,10 @@
from numpyro.infer.reparam import TransformReparam
from numpyro.infer.util import (
Predictive,
compute_log_probs,
constrain_fn,
initialize_model,
log_density,
log_likelihood,
potential_energy,
transform_fn,
Expand Down Expand Up @@ -266,6 +268,23 @@ def test_log_likelihood(batch_shape):
)


def test_compute_log_probs():
model, data, _ = beta_bernoulli()
samples = Predictive(model, return_sites=["beta"], num_samples=1)(random.key(7))
samples = {key: value[0] for key, value in samples.items()}

logden, _ = log_density(model, (data,), {}, samples)
assert logden.shape == ()

logdens, _ = compute_log_probs(model, (data,), {}, samples)
assert set(logdens) == {"beta", "obs"}
assert all(x.shape == () for x in logdens.values())

logdens, _ = compute_log_probs(model, (data,), {}, samples, False)
assert logdens["beta"].shape == (2,)
assert logdens["obs"].shape == (800, 2)


def test_model_with_transformed_distribution():
x_prior = dist.HalfNormal(2)
y_prior = dist.LogNormal(scale=3.0) # transformed distribution
Expand Down

0 comments on commit bf9c715

Please sign in to comment.