Unadjusted HMC, and generalizing the MH code #747

Open
reubenharry opened this issue Oct 15, 2024 · 5 comments

Comments

@reubenharry
Contributor

Current behavior

The current HMC implementation performs n steps of discretized Hamiltonian dynamics, followed by an MH accept/reject step.

Desired behavior

It would be useful to also have unadjusted HMC, i.e. the same algorithm without the MH adjustment. This just amounts to running the integrator, so it is easy to implement. The question is whether to wrap it and present it as a separate algorithm (along with a tuning algorithm in the vein of MCLMC).

Relatedly, if we did that, we could then write standard (i.e. adjusted) HMC as a function of the unadjusted kernel, i.e. in pseudocode:

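def adjusted_kernel(base_kernel, num_steps):
    def kernel(state):
        # Propose by running the unadjusted kernel for num_steps steps.
        new_state, info = run_n_steps_of(base_kernel, state, num_steps)
        # MH accept/reject decision, computed from the trajectory info.
        accept = some_function_of(info)
        return new_state if accept else state
    return kernel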

Note that blackjax does something pretty similar already (see hmc_proposal).
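To make the pseudocode above concrete, here is a minimal sketch in plain Python. It is not blackjax code: the names `State`, `Info`, `unadjusted_hmc` and `adjusted` are hypothetical, and it assumes a 1-D target, unit mass, and a leapfrog integrator. The unadjusted kernel reports the energy change, and the adjusted kernel is built from it by adding only the accept/reject step.

```python
import math
import random
from typing import NamedTuple


class State(NamedTuple):
    position: float
    momentum: float


class Info(NamedTuple):
    energy_change: float  # H(end of trajectory) - H(start of trajectory)


def leapfrog(state, grad_logdensity, step_size):
    """One leapfrog step with unit mass (assumption: 1-D position)."""
    p_half = state.momentum + 0.5 * step_size * grad_logdensity(state.position)
    q_new = state.position + step_size * p_half
    p_new = p_half + 0.5 * step_size * grad_logdensity(q_new)
    return State(q_new, p_new)


def energy(state, logdensity):
    """Hamiltonian: potential (-log density) plus kinetic energy."""
    return -logdensity(state.position) + 0.5 * state.momentum ** 2


def unadjusted_hmc(logdensity, grad_logdensity, step_size, num_steps):
    """Hypothetical unadjusted kernel: refresh momentum, then integrate."""
    def kernel(rng, state):
        start = State(state.position, rng.gauss(0.0, 1.0))
        end = start
        for _ in range(num_steps):
            end = leapfrog(end, grad_logdensity, step_size)
        info = Info(energy(end, logdensity) - energy(start, logdensity))
        return end, info
    return kernel


def adjusted(base_kernel):
    """Hypothetical wrapper: add an MH accept/reject to any base kernel
    whose info carries the energy change of its proposal."""
    def kernel(rng, state):
        proposal, info = base_kernel(rng, state)
        # Accept with probability min(1, exp(-energy_change)).
        accept = rng.random() < math.exp(min(0.0, -info.energy_change))
        return (proposal if accept else state), info
    return kernel
```

With this split, exposing unadjusted HMC as its own algorithm is only a matter of whether `unadjusted_hmc` gets a public wrapper; the adjusted version needs nothing from the integrator beyond the energy change.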

@gil2rok
Contributor

gil2rok commented Oct 15, 2024

As far as I can tell, the HMC code implements the default HMC algorithm, which does involve an accept/reject step. If you want an unadjusted HMC-like algorithm without an accept/reject step, it should be easy to build in Blackjax. If I'm understanding correctly, this is more of an API question (should we allow unadjusted HMC in the main HMC algorithm?) than a limitation of the Blackjax code.

@junpenglao
Member

Maybe the easiest approach is to refactor hmc_proposal

def hmc_proposal(

so that the user can control the behavior of sample_proposal:

  1. compare the initial state and the last state (current behavior)
  2. take the last state (unadjusted HMC)
  3. sample from the whole trajectory (Implement multinomial HMC #383)
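Options 1 and 2 in the list above can be sketched as two interchangeable selection functions. This is an illustration in plain Python, not the actual hmc_proposal or sample_proposal signature: `metropolis_select`, `take_last` and `make_proposal` are hypothetical names, and a "proposal" is assumed to be a `(position, energy)` pair.

```python
import math
import random
from typing import Callable, Tuple

# Assumption: a proposal is a (position, energy) pair.
Proposal = Tuple[float, float]


def metropolis_select(rng, initial: Proposal, final: Proposal):
    """Option 1: compare the initial and last states with an MH test."""
    delta_energy = final[1] - initial[1]
    accept = rng.random() < math.exp(min(0.0, -delta_energy))
    return (final if accept else initial), accept


def take_last(rng, initial: Proposal, final: Proposal):
    """Option 2: unadjusted HMC, always keep the end of the trajectory."""
    return final, True


def make_proposal(integrate: Callable[[Proposal], Proposal],
                  sample_proposal: Callable):
    """Build a proposal generator from an integrator and a selection rule."""
    def propose(rng, initial: Proposal):
        final = integrate(initial)
        return sample_proposal(rng, initial, final)
    return propose
```

Under this design, switching between adjusted and unadjusted HMC is a one-argument change, e.g. `make_proposal(integrate, take_last)` instead of `make_proposal(integrate, metropolis_select)`.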

@junpenglao
Member

Actually, probably not 3, because there are more efficient ways to implement multinomial HMC without storing the full trajectory.

@reubenharry
Contributor Author

Yes, this is more of an API question about whether to expose unadjusted HMC as an inference algorithm. Junpeng's suggestion seems reasonable. In my opinion, it would be ideal if sample_proposal were refactored to be general enough to work for HMC, adjusted MCHMC (on branch) and MALA, but this isn't critical.

@reubenharry
Contributor Author

Something to note is that the code for MCLMC (https://github.com/blackjax-devs/blackjax/blob/main/blackjax/mcmc/mclmc.py) differs only very slightly from what you'd do for unadjusted HMC (the integrator is different), so perhaps this should be generalized appropriately.
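One way to picture that generalization is an unadjusted kernel that takes the one-step integrator as an argument. The sketch below is plain Python with hypothetical names (`unadjusted_kernel`, `leapfrog_step`, `symplectic_euler_step`). Symplectic Euler is only a stand-in for "a different integrator"; it is not the integrator MCLMC uses, and momentum refreshment is left out for brevity.

```python
from typing import Callable, Tuple

Phase = Tuple[float, float]  # (position, momentum), 1-D for simplicity


def leapfrog_step(grad_logdensity: Callable[[float], float], eps: float):
    """Leapfrog (kick-drift-kick) step with unit mass."""
    def step(q: float, p: float) -> Phase:
        p = p + 0.5 * eps * grad_logdensity(q)
        q = q + eps * p
        p = p + 0.5 * eps * grad_logdensity(q)
        return q, p
    return step


def symplectic_euler_step(grad_logdensity: Callable[[float], float], eps: float):
    """Stand-in alternative integrator (first order, kick then drift)."""
    def step(q: float, p: float) -> Phase:
        p = p + eps * grad_logdensity(q)
        q = q + eps * p
        return q, p
    return step


def unadjusted_kernel(step: Callable[[float, float], Phase], num_steps: int):
    """Generic unadjusted kernel: apply any one-step integrator num_steps times."""
    def kernel(q: float, p: float) -> Phase:
        for _ in range(num_steps):
            q, p = step(q, p)
        return q, p
    return kernel
```

For example, `unadjusted_kernel(leapfrog_step(grad, 0.1), 10)` and `unadjusted_kernel(symplectic_euler_step(grad, 0.1), 10)` share all of the kernel code and differ only in the integrator passed in.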


3 participants