-
Notifications
You must be signed in to change notification settings - Fork 107
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Extend function signature for InitFn #627
Conversation
@AdrienCorenflos for changes in blackjax/mcmc/marginal_latent_gaussian.py |
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## main #627 +/- ##
==========================================
+ Coverage 99.00% 99.05% +0.04%
==========================================
Files 58 58
Lines 2615 2632 +17
==========================================
+ Hits 2589 2607 +18
+ Misses 26 25 -1 ☔ View full report in Codecov by Sentry. |
Why did the delta parameter get out of the step function? This seems like an unrelated change. |
I am not familiar with |
The step is mostly used for calibration purposes, but creating a new kernel
is very expensive so you don't want the user to do that every time they
want to change it.
…On Mon, 11 Dec 2023, 14:24 Junpeng Lao, ***@***.***> wrote:
Why did the delta parameter get out of the step function? This seems like
an unrelated change.
I am not familiar with marginal_latent_gaussian, from the test it doesnt
seems it is needed to be in step, but if the algorithm is intend to have
different scale even for top level user API, I will revert those changes.
—
Reply to this email directly, view it on GitHub
<#627 (comment)>,
or unsubscribe
<https://github.com/notifications/unsubscribe-auth/AEYGFZYCN7KRC6ERQT6NP33YI4CQFAVCNFSM6AAAAABAPWU5U6VHI2DSMVQWIX3LMV43OSLTON2WKQ3PNVWWK3TUHMYTQNJQGA3TGMZWG4>
.
You are receiving this because you were mentioned.Message ID:
***@***.***>
|
Could user does this instead (which are in general how Blackjax sees top level API): from blackjax.mcmc import marginal_latent_gaussian as mgrad
init, kernel = mgrad.init_and_kernel(...)
# does tuning for delta
...
# Actual sampling, delta is now fixed.
algorithm = mgrad.mgrad_gaussian(logdensity_fn, cov, delta)
state = algorithm.init(...)
for ...
state, info = algorithm.step(rnd_key, state) |
I suppose, but this would be a waste of computational resources. The init_and_kernel does some handling of the covariance matrix (SVD) which is easily the most expensive part. Why would you want to do it again post calibration? |
To maintain the same API for what we consider "top-level". This is done to be easier to compare different sampler. I suppose your concern here is that the 2nd call to |
* Extend function signature for InitFn * Fix formatting
Close #619.
This PR introduce
rng_key
as optional input toinitFn
protocal.For sampler like
dynamic_hmc
andghmc
, theinit_fn
of the top level API does not follow the old patter, as it needs an rng_key to generate part of thestate
. While it is possible to set a default rng_key in the class__new__
, we actually wants to keep the rng_key as input toinit_fn
as we usually want to vmap init so that it takes a vector of PRNG_key to initialized parallel chains (see eg: in meads_adaptationblackjax/blackjax/adaptation/meads_adaptation.py
Line 209 in 08e0d75
blackjax/blackjax/adaptation/meads_adaptation.py
Lines 239 to 240 in 08e0d75
Thus, in this PR we introduce rng_key to the
InitFn
, with some minor refactoring to other top level API follows the same contract and easier to plug into utility sampling functionrun_inference_algorithm