
Add Stein Variational Gradient Descent #385

Closed
rlouf opened this issue Oct 22, 2022 · 12 comments
Labels
enhancement (New feature or request), good first issue (Good for newcomers), help wanted (Extra attention is needed), sampler (Issue related to samplers), vi (Variational Inference)

Comments

@rlouf
Member

rlouf commented Oct 22, 2022

https://arxiv.org/abs/1608.04471
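
For context, the update rule in that paper is deterministic: each particle is moved along a kernelized estimate of the steepest-descent direction on the KL divergence. A minimal JAX sketch of one update, assuming particles stored as an array of shape (num_particles, dim), a scalar-valued log_p, and an RBF kernel (all names here are illustrative, not blackjax API):

import jax
import jax.numpy as jnp

def rbf_kernel(x, y, length_scale=1.0):
    return jnp.exp(-jnp.sum((x - y) ** 2) / (2.0 * length_scale**2))

def svgd_update(particles, log_p, step_size=1e-2):
    # score of the target log density at every particle
    score = jax.vmap(jax.grad(log_p))(particles)

    def phi(xi):
        # k(x_j, x_i) and its gradient with respect to x_j, for all j
        k_vals = jax.vmap(lambda xj: rbf_kernel(xj, xi))(particles)
        grad_k = jax.vmap(jax.grad(lambda xj: rbf_kernel(xj, xi)))(particles)
        return (k_vals[:, None] * score + grad_k).mean(axis=0)

    return particles + step_size * jax.vmap(phi)(particles)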

@rlouf rlouf added the sampler, vi, help wanted, good first issue, and enhancement labels on Oct 22, 2022
@rlouf rlouf mentioned this issue Dec 11, 2022
@antotocar34
Contributor

antotocar34 commented Dec 30, 2022

I've been having a go at implementing this.

The problem I'm encountering is that the VIAlgorithm API is not really suitable for this algorithm.
SVGD doesn't have separate approximation and sampling phases. In fact, no random sampling happens
at all; SVGD is a deterministic algorithm.

SVGD is really just a gradient-descent-type procedure, which means the user will probably want the following
exposed to them:

  1. A function that does one step of the optimization procedure
  2. A way to add an arbitrary optimizer (adam, adagrad, etc...)
  3. A way to take batched gradients of the log posterior (in parallel if desired)

Here are my thoughts:

  1. The existing mcmc kernel API is almost suitable for SVGD, except that there is no need for RNG keys
     since the algorithm is deterministic.
  2. The optax library provides many state-of-the-art optimizers, so it would be nice to have some compatibility
     with it. We could define an Optimizer Protocol that implements the same methods as optax optimizers
     (namely init and update), and have users who want a custom optimizer adhere to this Protocol.
  3. I believe this can be done as suggested in #319 (see the vmap sketch just below).
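
For point 3, a minimal sketch of batched (and trivially parallelisable) gradients of the log posterior, assuming one particle maps to a scalar log density (the toy log_p below is just for illustration):

import jax
import jax.numpy as jnp

# toy example: standard normal log density over one particle (a 1-D array)
log_p = lambda x: -0.5 * jnp.sum(x**2)

batch_grad = jax.vmap(jax.grad(log_p))      # one gradient per particle
gradients = batch_grad(jnp.ones((10, 3)))   # shape (10, 3): 10 particles in 3 dimensions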

Proposed implementation

(To be clear, this would not be the final implementation, just to pin down ideas)

In ./blackjax/base.py

class SvgdStepFn(Protocol):
    def __call__(self, state: State) -> State:
        ...

class SVGDAlgorithm(NamedTuple):
    init: InitFn
    step: SvgdStepFn

In ./blackjax/svgd.py

class SVGDState(NamedTuple):
    position: PyTree
    kernel_parameters: PyTree
    opt_state: Any

def init(initial_position: PyTree, kernel_parameters: PyTree, optimizer: Optimizer) -> SVGDState:
    opt_state = optimizer.init(initial_position)
    return SVGDState(initial_position, kernel_parameters, opt_state)

def step(state: SVGDState, log_p: Callable, kernel: Callable, optimizer: Optimizer) -> SVGDState:
    # Computation where we call the kernel as kernel(x, y, **state.kernel_parameters)
    ...

In ./blackjax/kernels.py

class Optimizer(Protocol):
    init: Callable
    update: Callable 

class svgd:
    init = staticmethod(vi.svgd.init)
    step = staticmethod(vi.svgd.step)

    def __new__(cls, log_p: Callable, kernel: Callable, optimizer: Optimizer) -> SVGDAlgorithm:

        def init_fn(initial_position: PyTree, kernel_parameters: Dict):
            return cls.init(initial_position, kernel_parameters, optimizer)

        def step_fn(state: SVGDState):
            return cls.step(state, log_p, kernel, optimizer)

        return SVGDAlgorithm(init_fn, step_fn)

An example from the user's point of view

svgd = svgd(log_p, kernel, optimizer)  # example optimizer: optax.adam(0.1)
state = svgd.init(initial_position, kernel_parameters)
step = jax.jit(svgd.step)

def update_kernel_parameters(state: SVGDState) -> SVGDState:
    position, kernel_parameters, opt_state = state
    # f could be the median heuristic for example
    return SVGDState(position, f(kernel_parameters, position), opt_state)


num_iterations = 5000
for _ in range(num_iterations):
    state = step(state)
    state = update_kernel_parameters(state)
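
As an aside, one concrete choice for f above is the median heuristic. A rough sketch, assuming the particle positions form an array of shape (num_particles, dim) and the kernel parameter is a scalar RBF bandwidth (this is one common variant of the heuristic, not necessarily what the final implementation would use):

import jax.numpy as jnp

def median_heuristic(kernel_parameters, position):
    # the previous kernel_parameters are ignored; the bandwidth is recomputed from the particles
    sq_dists = jnp.sum((position[:, None, :] - position[None, :, :]) ** 2, axis=-1)
    # common variant: bandwidth = median of pairwise squared distances / log(n)
    return jnp.median(sq_dists) / jnp.log(position.shape[0] + 1.0)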

@junpenglao
Member

Thanks for the detailed analysis!
In terms of the high-level API, it sounds like it might be useful to create a new meta API. How about an Approximation class? I am thinking of algorithms like maximum a posteriori estimation (MAP), the Laplace approximation, INLA, etc.
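
Purely to pin down the idea, a speculative sketch of what such a meta API could look like, mirroring the SVGDAlgorithm proposal above (these names are illustrative, not existing blackjax API):

from typing import Callable, NamedTuple

class Approximation(NamedTuple):
    init: Callable    # set up the state of the approximation
    step: Callable    # one optimisation step (MAP, Laplace, SVGD, ...)
    sample: Callable  # draw from the fitted approximation, when one exists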

@rlouf
Member Author

rlouf commented Jan 6, 2023

@antotocar34 Your proposal sounds about right. I think Pathfinder's current implementation, and the implementation of mean-field VI, will need to be updated to come closer to the approach you suggested.

We could maybe start with implementing the kernelized Stein discrepancy #384 in a first PR. I need it for another idea 😊

Any reason why update_kernel_parameters is called outside of the step function?

@antotocar34
Contributor

@rlouf

Any reason why update_kernel_parameters is called outside of the step function?

If the f function (inside update_kernel_parameters) is not jit-able, then putting update_kernel_parameters inside step means that step is not jit-able anymore, right?
Meaning that if the user wants to adapt their kernel parameters with a non-jit-able function, they can't jit the step function.

@rlouf
Member Author

rlouf commented Feb 24, 2023

If f is not jittable then the user can wrap it in a host_callback (see this issue).

Assuming f returns a scalar:

import jax
import jax.numpy as jnp
import jax.experimental.host_callback as hcb

def f_non_jit(x):
    # call back to the non-jittable Python function f; result_shape describes its scalar output
    return hcb.call(f, x, result_shape=jax.ShapeDtypeStruct((), jnp.float32))

Do you have any other blockers on this? If so you can always open a draft PR!

@albcab
Member

albcab commented Feb 24, 2023

It seems to me like the main distinction here is that SVGD is doing gradient descent on a bunch of particles: it is not really approximating a distribution in itself, just producing samples from that distribution. In that sense SVGD is closer to SMC than to VI, even if theoretically it is minimizing a distance between the target distribution and (samples from) an approximate distribution.

@antotocar34's proposal to code it as a sampling algorithm rather than an approximate inference algorithm is best. But other VI algorithms should be implemented like mean-field VI currently is, because there is value in having access to the approximate distribution and not only to samples from it.

@rlouf
Member Author

rlouf commented Feb 25, 2023

I think it also makes sense to give fine-grained control over the optimisation for VI algorithms, where the step function performs a one-step update of the approximation by sampling the ELBO and taking an optimisation step. Among other things, it allows you to run the procedure for a few more steps if you realise the approximation hasn't quite converged.
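
To make that concrete, a speculative sketch of such a one-step VI update (not the current blackjax API; elbo_fn is an assumed helper that estimates the ELBO by sampling from the approximation defined by params), assuming an optax-style optimizer:

import jax
import optax

def vi_step(rng_key, params, opt_state, elbo_fn, optimizer):
    # estimate the negative ELBO for the current approximation and differentiate it
    loss, grads = jax.value_and_grad(lambda p: -elbo_fn(rng_key, p))(params)
    # one optimisation step; calling this in a loop gives fine-grained control over convergence
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss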

@antotocar34
Contributor

antotocar34 commented Mar 7, 2023

My blocker is the following:

I initially wrote the algorithm assuming I could flatten the PyTree of particles into an array of shape (num_particles, parameter_dimension). That implementation seems to work.

The issue is that getting the particles array back into the original PyTree structure is not straightforward.

For example

import jax
import jax.numpy as jnp
from jax.tree_util import tree_flatten, tree_unflatten

def pytree_to_array(particles_pytree):
    _particle_array, pytree_def = tree_flatten(particles_pytree)
    particle_array = jnp.stack(_particle_array).squeeze().T
    return particle_array, pytree_def

def array_to_pytree(particle_array, pytree_def):
    arr = particle_array.T     # Works for 1, not for 2 and 3
    # arr = [particle_array.T]   # Works for 2 and 3, not for 1
    # arr = arr                  # Works for 3, not for 1 and 2
    return tree_unflatten(
            pytree_def,
            arr # What to put here?
            )

particles_1 = {
        "mu": jnp.array([1,2,-3,-1/2]),
        "sigma": jnp.array([1.1,2.3,0.1,2])
        }

particles_2 = {
        "theta": jnp.array([0.6,.54,0.3,0.1]),
        }

particles_3 = jnp.array([ [0.6,.54,0.3,0.1], [0.4,.1,.3,.4] ])

particle_array, pytree_def = pytree_to_array(particles_1)
pytree = array_to_pytree(particle_array, pytree_def)
print(pytree)

Playing with the above code should reveal the problem I'm facing.

I tried checking what the pytree looks like in the array_to_pytree function and assigning arr conditionally based on that, but this is not possible, as any conditional will be eagerly evaluated when jitted.

Possible solutions:

  1. Is there a better way to go from the particles PyTree to an array and back?
  2. Rewrite the algorithm keeping the particles as a PyTree. This doesn't seem like the right approach, however, as I assume indexing PyTrees will be substantially slower, and not being able to use vmap is quite inconvenient.
  3. Maybe I'm missing something obvious, please tell me :)

@antotocar34
Contributor

For reference, here is a current version of the algorithm.

https://gist.github.com/antotocar34/3e6a762df1427a7db6105cfe72f66185

@albcab
Member

albcab commented Mar 8, 2023

This works (pretty sure) for all pytrees whose leaves share the same leading dimension, i.e. the number of particles:

import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

def pytree_to_array(particles_pytree):
    particle_array = jax.vmap(lambda p: ravel_pytree(p)[0])(particles_pytree)
    example_particle = jax.tree_util.tree_map(lambda p: p[0], particles_pytree)
    _, unravel_fn = ravel_pytree(example_particle)
    return particle_array, unravel_fn

def array_to_pytree(particle_array, unravel_fn):
    return jax.vmap(lambda p: unravel_fn(p))(particle_array)

particles_1 = {
        "mu": jnp.array([1,2,-3,-1/2]),
        "sigma": jnp.array([1.1,2.3,0.1,2])
        }

particles_2 = {
        "theta": jnp.array([0.6,.54,0.3,0.1]),
        }

particles_3 = jnp.array([ [0.6,.54,0.3,0.1], [0.4,.1,.3,.4] ])

particles_4 = {
        "mu": jnp.zeros((4, 4)),
        "diag_cov": jnp.ones((4, 4))
        }
        
particles_5 = {
        "mu": jnp.zeros((4, 5)),
        "dense_cov": jnp.ones((4, 5, 5))
        }

particle_array, unravel_fn = pytree_to_array(particles_5)
pytree = array_to_pytree(particle_array, unravel_fn)
print(pytree)
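
A hypothetical follow-up, showing how the flat representation composes with vmapped gradients (log_p here stands for an assumed per-particle log density; it is not defined above):

# batched gradients in flat space, then mapped back to the original pytree structure
particle_array, unravel_fn = pytree_to_array(particles_4)
flat_grads = jax.vmap(jax.grad(lambda flat: log_p(unravel_fn(flat))))(particle_array)
grads_pytree = array_to_pytree(flat_grads, unravel_fn)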

@antotocar34
Contributor

This is much better, thanks a million Alberto.

@albcab
Member

albcab commented Mar 9, 2023

No problem Antoine. Let us know here if you have any other blockers (I'm interested in having SVGD implemented in blackjax) 🚀

@antotocar34 antotocar34 mentioned this issue Mar 24, 2023