
jax.grad on fax.constrained.cga_ecp solutions does not work #6

Closed
lukasheinrich opened this issue Jan 28, 2020 · 7 comments

@lukasheinrich

Hello,

I'm trying to see whether I can use fax to find gradients of a fixed-point function (an optimization problem) with respect to the problem parameters.

Consider f(x,y) = -(x**2 + (y[0]-a[0])**2 + (y[1]-a[1])**2) and h(x,y) = x - 2:

```python
def func(a):
    return fax.constrained.cga_ecp(
        lambda x, y: -(x**2 + (y[0] - a[0])**2 + (y[1] - a[1])**2),
        lambda x, y: x - 2,
        (2., [3., 3.]),
        lr_func=0.4,
    ).value

>>> jax.jit(func)([0.2, 0.3])
(DeviceArray(1.9999994, dtype=float32),
 [DeviceArray(0.2, dtype=float32), DeviceArray(0.3, dtype=float32)])
```

but when I try to compute gradients, I get the following error

```python
>>> jax.jit(jax.grad(func))([0.2, 0.3])
...
NotImplementedError: Forward-mode differentiation rule for 'cond' not implemented
```

Is there any way around this?

@gehring gehring self-assigned this Jan 28, 2020
@gehring
Owner

gehring commented Jan 28, 2020

Can you tell me a bit more about what it is that you are trying to do? There might be another way to compute what you need (after all, this is kind of the core idea behind most implementations in this package).

For context, the functional control flow used in fax.constrained.cga_ecp for applying cga is the cause of this issue. jax doesn't support differentiating through some types of control flow "ops".

The current implementation doesn't explicitly define any special differentiation rules, so jax.grad will simply try to do reverse-mode differentiation. In this context, this would result in back-propagating through each solving step. This is far from ideal (imo) but if you really wanted to do that, you might have success using unroll=True here. Just make sure the number of iterations isn't too big since it will always run for the max number of steps.
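
(For a rough picture of what that unrolled fallback means, here is a generic toy, not fax's internals: the inner problem is solved with a fixed number of steps written as ordinary traced code, so jax.grad back-propagates through every iteration at a memory/compute cost that grows with the step count.)

```python
import jax
import jax.numpy as jnp

# Generic toy of the "unrolled" fallback (not fax's internals): the inner
# problem min_x (x - a)^2 is solved with a fixed number of gradient steps,
# and jax.grad simply back-propagates through all of them.
def solve_unrolled(a, x0, n_steps=200, lr=0.1):
    def step(x, _):
        return x - lr * 2.0 * (x - a), None
    x, _ = jax.lax.scan(step, x0, None, length=n_steps)
    return x

# derivative of the solution w.r.t. the problem parameter, by unrolling
g = jax.grad(lambda a: solve_unrolled(a, jnp.zeros(())))(1.5)  # ≈ 1.0
```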

Alternatively, you might be able to leverage some of the tools already implemented in fax to compute the derivatives you want more efficiently; something like fax.implicit.two_phase_solver might work.

Once I understand why you want to get the gradient of fax.constrained.cga_ecp, I might be able to help you find a more efficient way to do this. An implicit approach would first solve the constraint optimization then solve for the gradient much like a standard backprop approach. You might even be able to use a "competitive differentiation" approach which would allow you to solve the high-level optimization concurrently with the low-level subproblem/optimization without having to fully compute the gradient of fax.constrained.cga_ecp.

@lukasheinrich
Author

Thanks @gehring for the prompt reply. I was hoping that using jax.grad on cga_ecp would give me exactly an implicit formulation, but it seems I'm mistaken.

I'm investigating fixed-point / implicit models to find derivatives of solutions of optimization problems with respect to problem parameters, without differentiating through an optimization loop, in the context of particle physics inference: https://github.com/scikit-hep/pyhf

The main setup is that a likelihood function L(θ,x), where x is the observed data and θ are the likelihood parameters, is further parametrized by some model (hyper-)parameters m: L(θ,x;m). The maximum likelihood estimate for a given observation vector x0 is then a function of those hyper-parameters: θ*(m) = argmax_θ L(θ,x0;m).

I'd like to use fax to find the derivatives ∂θ*/∂m.
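
(To make that derivative concrete in the smooth, unconstrained case: at the maximum the stationarity condition ∇_θ L = 0 holds, and differentiating it with respect to m gives ∂θ*/∂m = -[∂²L/∂θ²]⁻¹ ∂²L/∂θ∂m. A toy sketch of that formula in jax, with a made-up quadratic standing in for the pyhf likelihood:)

```python
import jax
import jax.numpy as jnp

# Toy sketch of the implicit-derivative formula above (made-up quadratic
# "likelihood", not pyhf): theta*(m) = argmax_theta L(theta, m), and
#   dtheta*/dm = -[d^2L/dtheta^2]^{-1} d^2L/(dtheta dm)   at theta = theta*.
def logL(theta, m):
    return -0.5 * jnp.sum((theta - m) ** 2)   # maximizer is theta* = m

m = jnp.array([0.2, 0.3])
theta_star = m  # pretend this came from an optimizer such as SLSQP

H = jax.hessian(logL, argnums=0)(theta_star, m)                      # d2L/dtheta2
C = jax.jacfwd(jax.grad(logL, argnums=0), argnums=1)(theta_star, m)  # d2L/dtheta dm
dtheta_dm = -jnp.linalg.solve(H, C)  # the identity here, since theta* = m
```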

In pyhf, the θ* are currently found using the SLSQP method in scipy.optimize.minimize, since we have both parameter bounds and equality constraints, but both can be removed as a first step.

If there is a way to express the above implicitly, that would be great.

@lukasheinrich
Author

lukasheinrich commented Feb 3, 2020

hi @gehring, do you think something like the above would be feasible w/ fax? Do you need any more info?

edit:

seems like this gives me a good start

[screenshot]

@gehring
Owner

gehring commented Feb 3, 2020

Sorry for the delays in responding.

Keep in mind that when we started fax, its primary purpose was to provide tools to compute derivatives of fixed points. Many of the things we've implemented are more general than that, so we've allowed ourselves to structure the code for more general settings, but that original purpose is useful to keep in mind when trying to understand our implementations.

One way to transform these subproblems into something fax can handle is by formulating constraints which implicitly define a solution to the sub-problem, e.g., the KKT conditions. Unfortunately, for now, we only handle equality constraints, but we are looking into supporting inequalities.

If you can formulate the solution to your constrained optimization as a (possibly multi-dimensional) fixed point problem, fax can help you tackle it using a two-phase method. This involves first solving the sub-problem (however you like) then using the fixed point's equality constraint to solve for the derivatives. We've implemented a general method which handles arbitrary (contractive) fixed points, so as long as you can formulate the solution of your sub-problem as a fixed point, you can use our implementation. This is an example of implicit differentiation.
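
(To make the two-phase picture concrete, here is a generic hand-rolled sketch using jax.custom_vjp; it is not fax's two_phase_solver API. Phase one runs any solver to find the fixed point; phase two reuses the fixed-point equation to solve a linear adjoint problem for the derivatives, so nothing is back-propagated through the solver's iterations.)

```python
from functools import partial

import jax
import jax.numpy as jnp

# Generic sketch of the two-phase idea (not fax's API): solve x* = f(a, x*)
# however you like, then define the backward pass from the fixed-point
# equation itself instead of back-propagating through the solver loop.
def iterate(f, x0, n_steps=200):
    x = x0
    for _ in range(n_steps):
        x = f(x)
    return x

@partial(jax.custom_vjp, nondiff_argnums=(0,))
def fixed_point(f, a, x0):
    return iterate(lambda x: f(a, x), x0)

def fixed_point_fwd(f, a, x0):
    x_star = fixed_point(f, a, x0)
    return x_star, (a, x_star)

def fixed_point_bwd(f, res, v):
    a, x_star = res
    # phase 2: solve the linear adjoint fixed point w = v + (df/dx)^T w at x*,
    # then push w through (df/da)^T to get the cotangent for a
    _, vjp_x = jax.vjp(lambda x: f(a, x), x_star)
    w = iterate(lambda w: v + vjp_x(w)[0], v)
    _, vjp_a = jax.vjp(lambda a: f(a, x_star), a)
    return vjp_a(w)[0], jnp.zeros_like(x_star)

fixed_point.defvjp(fixed_point_fwd, fixed_point_bwd)

# usage: x* solves x = tanh(W x + a); gradient of sum(x*) w.r.t. a
W = 0.5 * jnp.eye(2)
f = lambda a, x: jnp.tanh(W @ x + a)
grads = jax.grad(lambda a: fixed_point(f, a, jnp.zeros(2)).sum())(jnp.array([0.1, 0.2]))
```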

However, this general way of handling implicit differentiation might not always be the best thing to do. You ought to leverage domain knowledge when you can, e.g., re-using factorizations from the "forward" pass in the "backward" pass (think of solving for 'x' in inv(A)x = y and its derivatives). I can point you to some examples/references if this is a direction you'd like to pursue, but this is a bit outside the scope of fax.

Otherwise, assuming you can formulate the solution to your constrained optimization as an equality constraint, the second approach would be what we call competitive differentiation, which incrementally solves both the outer problem and the sub-problem at the same time using the competitive descent method applied to the Lagrangian. This would be the cga approach. In this instance, the Lagrange multipliers can be thought of as approximating the derivative you would solve for when using a two-phase implicit method.

The advantage of this second approach is that you can start improving the "meta" parameters before having ever fully solved the sub-problem. This can lead to big improvements in some cases, but it isn't quite clear yet when it is or isn't favorable over a two-phase method. It is probably a good idea to explore both and see which approach seems most promising for your particular use case.
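
(A rough structural sketch of that simultaneous scheme, using plain gradient descent/ascent on the Lagrangian rather than the actual cga update fax implements; g and h here are hypothetical placeholders for the outer objective and the constraint encoding the sub-problem's solution:)

```python
import jax

# Structural sketch only: simultaneous gradient descent/ascent on the
# Lagrangian, *not* the cga update fax uses. g(m, theta) is the outer
# objective and h(theta, m) = 0 encodes the sub-problem's solution
# (e.g. its stationarity/KKT conditions).
def make_step(g, h, lr=1e-2):
    def lagrangian(m, theta, lam):
        return g(m, theta) + lam @ h(theta, m)

    @jax.jit
    def step(m, theta, lam):
        gm, gth, gl = jax.grad(lagrangian, argnums=(0, 1, 2))(m, theta, lam)
        # descend in the "meta" parameters m and in theta, ascend in the multipliers
        return m - lr * gm, theta - lr * gth, lam + lr * gl

    return step
```

Each step nudges m using the current multipliers, which is why the "meta" parameters can start improving before the sub-problem is fully solved.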

We've implemented some wrappers to help formulate the Lagrangian the way our implementation of cga expects it and to help apply our two-phase implementation, but in order to use them to compute implicit derivatives, you'll need to specify the correct constraints or fixed point. Since these wrappers are meant to compute your gradients for you, they were never intended to play nice with jax.grad. Additionally, jax.grad is asking to compute the adjoint using reverse-mode differentiation, which fits nicely with a two-phase formulation but doesn't really make sense with a competitive formulation.

You might find it helpful to look through how we used it in a control/"reinforcement learning" setting. You can find the repo here.

@f-t-s @pierrelux if you have a minute, I would appreciate if you could read through my blurb to make sure I didn't misrepresent anything or accidentally said something false or misleading!

@gehring
Owner

gehring commented Feb 3, 2020

> hi @gehring, do you think something like the above would be feasible w/ fax? Do you need any more info?
>
> edit:
>
> seems like this gives me a good start
>
> [screenshot]

Yes! That looks like the correct approach for using implicit differentiation in the contractive fixed-point case. This is the type of problem two_phase_solver is meant to solve.

@lukasheinrich
Author

I'm not sure I caught the right Twitter handles, but this is what came out of this:

https://twitter.com/lukasheinrich_/status/1235622557295849473

thanks again for this library!

We have some JIT'ing issues that might be interesting to the team:

jax-ml/jax#2346

@gehring
Owner

gehring commented Mar 5, 2020

@lukasheinrich Happy we could help!

Thanks for the mention! You have (now had?) the right handle, but seeing it compelled me to update it to something more meaningful (now @ClementGehring) that I'd be happy to use. I don't really use Twitter, but would there be a way to update the mention to my new handle? No worries if there isn't or if it's a hassle!
