Add SPSA optimization method #357

Open
ankit27kh opened this issue Jun 9, 2022 · 8 comments
Labels: enhancement (New feature or request), help wanted (Extra attention is needed)

Comments

@ankit27kh

Simultaneous Perturbation Stochastic Approximation (SPSA) is a fast, gradient-free optimization method.

If the number of parameters being optimized is p, the finite-difference method requires 2p measurements of the objective function at each iteration (to form one gradient approximation), while SPSA requires only two measurements, regardless of p.

It is also naturally suited for noisy measurements. Thus, it will be useful when simulating noisy systems.

The theory (and implementation) for SPSA is:

Furthermore, it is implemented:

More information:
https://www.jhuapl.edu/SPSA/
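
A minimal sketch of the two-measurement idea in JAX (illustrative only, not a proposed Optax API; the name `spsa_grad_estimate`, the step size `eps`, and the toy objective are placeholders):

```python
import jax
import jax.numpy as jnp

def spsa_grad_estimate(f, params, key, eps=0.01):
    """Estimate the gradient of f at params from only two measurements of f."""
    # Random +/-1 (Rademacher) perturbation of every parameter simultaneously.
    delta = jax.random.rademacher(key, shape=params.shape, dtype=params.dtype)
    # Two measurements of the objective, independent of the number of parameters p.
    f_plus = f(params + eps * delta)
    f_minus = f(params - eps * delta)
    # SPSA estimate: central difference divided by each perturbation entry.
    # For +/-1 entries, dividing by delta is the same as multiplying by delta.
    return (f_plus - f_minus) / (2 * eps) * delta

# Toy usage: the objective is a quadratic, so the true gradient at x is 2x.
f = lambda x: jnp.sum(x ** 2)
g = spsa_grad_estimate(f, jnp.ones(5), jax.random.PRNGKey(0))
```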

@mtthss
Collaborator

mtthss commented Jul 14, 2022

Sounds like a nice contribution, do you want to take a stab at it?

mtthss added the help wanted and enhancement labels on Aug 23, 2022
@ankit27kh
Author

@mtthss if no one is working on it, I would like to try. Could you give me some general pointers before I start, such as which files to update and what to take care of, as I haven't contributed to Optax before?

@lockwo

lockwo commented Feb 1, 2023

Are there any updates on this? I have previously worked with SPSA in TF (tensorflow/quantum#653) and would be interested in working on this, but I don't want to do redundant labor.

@ankit27kh
Author

Hi @lockwo, are you still interested? If you can implement SPSA, it'll be of great help!

@fabianp
Member

fabianp commented Feb 8, 2024

@ankit27kh: since there hasn't been any activity on this in a year, I think it's safe for you to take over.

If you end up contributing this, please do so in the contrib/ directory. Thanks!

@carlosgmartin
Contributor

carlosgmartin commented Sep 19, 2024

@fabianp I've created the following implementation of a pseudo-gradient estimator:

https://gist.github.com/carlosgmartin/0ee29182a17b35baf7d402ebdc797486

As noted in the function's docstring:

  • SPSA corresponds to the case where sampler=random.rademacher.
  • Gaussian smoothing corresponds to the case where sampler=random.normal.

I'd be happy to contribute this implementation to Optax. I could put it under optax.tree_utils with the other tree functions, if desired.

I also welcome any feedback on the code.

Note that this pseudo-gradient estimator can be used in combination with any existing Optax optimizer: its only role is to determine the gradient that is fed into the optimizer. Thus it acts as a replacement or analogue for jax.grad(f)(x, key).
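
For illustration, here is roughly how such an estimator could slot into a standard Optax training loop in place of jax.grad (a hedged sketch; `pseudo_grad` and its signature are stand-ins for the gist's function, not its actual API):

```python
import jax
import jax.numpy as jnp
import optax

def pseudo_grad(f, x, key, eps=0.01, sampler=jax.random.rademacher):
    # Stand-in for the gist's estimator: perturb along one random direction drawn
    # from `sampler` (rademacher -> SPSA, normal -> Gaussian smoothing).
    delta = sampler(key, shape=x.shape, dtype=x.dtype)
    return (f(x + eps * delta) - f(x - eps * delta)) / (2 * eps) * delta

f = lambda x: jnp.sum(x ** 2)
opt = optax.adam(learning_rate=0.1)
x = jnp.ones(5)
state = opt.init(x)

key = jax.random.PRNGKey(0)
for _ in range(100):
    key, subkey = jax.random.split(key)
    g = pseudo_grad(f, x, subkey)          # used where jax.grad(f)(x) would go
    updates, state = opt.update(g, state)
    x = optax.apply_updates(x, updates)
```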

@carlosgmartin
Contributor

It would also be nice to have utility functions for estimating the gradient via forward and central finite differences:

https://gist.github.com/carlosgmartin/a147b43f39633dcb0a985b51a5b1af0c

I'd be happy to contribute these as well.
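
For reference, a minimal sketch of what such helpers could look like for flat (1-D) parameter arrays (not the gist's actual code; names, signatures, and the step size are illustrative):

```python
import jax
import jax.numpy as jnp

def central_difference_grad(f, x, eps=1e-4):
    """Central finite differences: 2 * len(x) evaluations of f."""
    def partial(i):
        e = jnp.zeros_like(x).at[i].set(eps)
        return (f(x + e) - f(x - e)) / (2 * eps)
    return jax.vmap(partial)(jnp.arange(x.size))

def forward_difference_grad(f, x, eps=1e-4):
    """Forward finite differences: len(x) + 1 evaluations of f."""
    f0 = f(x)
    def partial(i):
        e = jnp.zeros_like(x).at[i].set(eps)
        return (f(x + e) - f0) / eps
    return jax.vmap(partial)(jnp.arange(x.size))

# Usage on a simple quadratic, where the true gradient at x is 2x.
f = lambda x: jnp.sum(x ** 2)
x = jnp.arange(3.0)
print(central_difference_grad(f, x))  # approximately [0., 2., 4.]
```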

@vroulet
Collaborator

vroulet commented Sep 23, 2024

Thanks for looking into this @carlosgmartin.
@q-berthet is contributing to a similar approach in #827 (which should be merged soon).
Also, the stochastic gradient estimators part of the codebase may be relevant (https://optax.readthedocs.io/en/latest/api/stochastic_gradient_estimators.html). Unfortunately, the original authors of that part of the codebase have left, and we are not sure we will keep maintaining it given its limited support and adoption. But you may find interesting links there.
It would be great for the two of you, @q-berthet and @carlosgmartin, to discuss how to integrate these efforts.

About the forward/central difference schemes: similar discussions have happened in JAX (see e.g. jax-ml/jax#15425), where other users have expressed similar needs. If some libraries already provide such tools, it may be better to use those rather than reinvent them (and perhaps also check whether JAX ended up adding such a module).

Thanks again @carlosgmartin
