Skip to content

Commit

Permalink
Speed up Schrodinger Follmer test (#741)
Browse files Browse the repository at this point in the history
* Plotting BlackJAX with BlackJAX

* Plotting BlackJAX with BlackJAX

* Update blackjax/mcmc/metrics.py

Co-authored-by: Junpeng Lao <junpenglao@gmail.com>

* Update blackjax/mcmc/metrics.py

Co-authored-by: Junpeng Lao <junpenglao@gmail.com>

* Merged comments from Junpeng

* Speed up Follmer

---------

Co-authored-by: Junpeng Lao <junpenglao@gmail.com>
  • Loading branch information
AdrienCorenflos and junpenglao authored Sep 23, 2024
1 parent e1d816a commit 51625a8
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion tests/vi/test_schrodinger_follmer.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def logp_unnormalized_posterior(x, observed, prior_mu, prior_prec, true_cov):

# Simulate the data
observed = jax.random.multivariate_normal(
rng_key_observed, true_mu, true_cov, shape=(10_000,)
rng_key_observed, true_mu, true_cov, shape=(25,)
)

logp_model = functools.partial(
Expand Down

0 comments on commit 51625a8

Please sign in to comment.