Skip to content

Commit

Permalink
Add logjac to logdensity_fn
Browse files Browse the repository at this point in the history
  • Loading branch information
junpenglao committed Oct 30, 2024
1 parent b107f9f commit 8611f8f
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 4 deletions.
3 changes: 2 additions & 1 deletion docs/examples/howto_sample_multiple_chains.md
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,9 @@ observed = np.random.normal(loc, scale, size=1_000)
def logdensity_fn(loc, log_scale, observed=observed):
"""Univariate Normal"""
scale = jnp.exp(log_scale)
logjac = log_scale
logpdf = stats.norm.logpdf(observed, loc, scale)
return jnp.sum(logpdf)
return logjac + jnp.sum(logpdf)
def logdensity(x):
Expand Down
3 changes: 2 additions & 1 deletion docs/examples/quickstart.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,9 @@ observed = np.random.normal(loc, scale, size=1_000)
def logdensity_fn(loc, log_scale, observed=observed):
"""Univariate Normal"""
scale = jnp.exp(log_scale)
logjac = log_scale
logpdf = stats.norm.logpdf(observed, loc, scale)
return jnp.sum(logpdf)
return logjac + jnp.sum(logpdf)
logdensity = lambda x: logdensity_fn(**x)
Expand Down
2 changes: 1 addition & 1 deletion tests/smc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ def logdensity_by_observation(self, log_scale, coefs, preds, x):
scale = jnp.exp(log_scale)
y = jnp.dot(x, coefs)
logpdf = stats.norm.logpdf(preds, y, scale)
return logpdf
return log_scale + logpdf

def logdensity_fn(self, log_scale, coefs, preds, x):
"""Linear regression"""
Expand Down
2 changes: 1 addition & 1 deletion tests/smc/test_inner_kernel_tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def logdensity_fn(self, log_scale, coefs, preds, x):
scale = jnp.exp(log_scale)
y = jnp.dot(x, coefs)
logpdf = stats.norm.logpdf(preds, y, scale)
return jnp.sum(logpdf)
return log_scale + jnp.sum(logpdf)

def test_smc_inner_kernel_adaptive_tempered(self):
self.smc_inner_kernel_tuning_test_case(
Expand Down

0 comments on commit 8611f8f

Please sign in to comment.