From 8611f8fc30ab43f9bcd8646682b9ecce9eb3ff69 Mon Sep 17 00:00:00 2001 From: Junpeng Lao Date: Wed, 30 Oct 2024 07:51:20 +0100 Subject: [PATCH] Add logjac to logdensity_fn --- docs/examples/howto_sample_multiple_chains.md | 3 ++- docs/examples/quickstart.md | 3 ++- tests/smc/__init__.py | 2 +- tests/smc/test_inner_kernel_tuning.py | 2 +- 4 files changed, 6 insertions(+), 4 deletions(-) diff --git a/docs/examples/howto_sample_multiple_chains.md b/docs/examples/howto_sample_multiple_chains.md index a5b6566f8..c2947e29f 100644 --- a/docs/examples/howto_sample_multiple_chains.md +++ b/docs/examples/howto_sample_multiple_chains.md @@ -57,8 +57,9 @@ observed = np.random.normal(loc, scale, size=1_000) def logdensity_fn(loc, log_scale, observed=observed): """Univariate Normal""" scale = jnp.exp(log_scale) + logjac = log_scale logpdf = stats.norm.logpdf(observed, loc, scale) - return jnp.sum(logpdf) + return logjac + jnp.sum(logpdf) def logdensity(x): diff --git a/docs/examples/quickstart.md b/docs/examples/quickstart.md index 870e5df9a..a290bfdad 100644 --- a/docs/examples/quickstart.md +++ b/docs/examples/quickstart.md @@ -48,8 +48,9 @@ observed = np.random.normal(loc, scale, size=1_000) def logdensity_fn(loc, log_scale, observed=observed): """Univariate Normal""" scale = jnp.exp(log_scale) + logjac = log_scale logpdf = stats.norm.logpdf(observed, loc, scale) - return jnp.sum(logpdf) + return logjac + jnp.sum(logpdf) logdensity = lambda x: logdensity_fn(**x) diff --git a/tests/smc/__init__.py b/tests/smc/__init__.py index 006d7ba38..71d59e529 100644 --- a/tests/smc/__init__.py +++ b/tests/smc/__init__.py @@ -9,7 +9,7 @@ def logdensity_by_observation(self, log_scale, coefs, preds, x): scale = jnp.exp(log_scale) y = jnp.dot(x, coefs) logpdf = stats.norm.logpdf(preds, y, scale) - return logpdf + return log_scale + logpdf def logdensity_fn(self, log_scale, coefs, preds, x): """Linear regression""" diff --git a/tests/smc/test_inner_kernel_tuning.py b/tests/smc/test_inner_kernel_tuning.py index 7d6190af5..080a02749 100644 --- a/tests/smc/test_inner_kernel_tuning.py +++ b/tests/smc/test_inner_kernel_tuning.py @@ -69,7 +69,7 @@ def logdensity_fn(self, log_scale, coefs, preds, x): scale = jnp.exp(log_scale) y = jnp.dot(x, coefs) logpdf = stats.norm.logpdf(preds, y, scale) - return jnp.sum(logpdf) + return log_scale + jnp.sum(logpdf) def test_smc_inner_kernel_adaptive_tempered(self): self.smc_inner_kernel_tuning_test_case(