Patch for differentiable code with dynamic shapes #1303
Comments
It will probably be slow on GPU for large problems, but this can also be considered:

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax import jit
from jax.lax import cond, fori_loop


@jax.jit
def expensive_function(x):
    # pretend this is some non-trivial operation
    return jnp.dot(x, x) * x


@jax.jit
def compute_nonzero_only(i, args):
    x, out = args

    def falseFun(args):
        # leave out[i] at zero
        _, out = args
        return out

    def trueFun(args):
        xi, out = args
        out = out.at[i].set(expensive_function(xi))
        return out

    # only evaluate expensive_function where x[i] > 0
    out = cond(x[i] > 0, trueFun, falseFun, (x[i], out))
    return (x, out)


@jax.jit
def fun(x):
    out = jnp.zeros(x.shape)
    _, out = fori_loop(0, x.shape[0], compute_nonzero_only, (x, out))
    return out


x = jnp.array([1.0, 0.0, 3.0])
example_optimization_gradient = jit(jax.grad(lambda x: fun(x).sum()))
np.testing.assert_allclose(example_optimization_gradient(x), [3, 0, 27])
```

It is jittable and doesn't require a custom derivative. You can try it with your problem and do some profiling. But this won't work properly if put in
^ A similar method to the one above, plus a chunked vmap-style vectorization, could be used to improve speed.
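As a rough illustration of that chunked-vmap idea (a sketch only; `chunked_apply` and the chunk size are made-up names, not DESC code), the work can be split into fixed-size chunks so shapes stay static, with `vmap` applied inside each chunk:

```python
import jax
import jax.numpy as jnp


def chunked_apply(f, x, chunk_size=64):
    """Apply ``f`` elementwise in fixed-size chunks (hypothetical sketch).

    Padding to a multiple of ``chunk_size`` keeps every chunk's shape static;
    ``lax.map`` runs chunks sequentially (bounded memory) while ``vmap``
    vectorizes within each chunk.
    """
    n = x.shape[0]
    pad = -n % chunk_size
    chunks = jnp.pad(x, (0, pad)).reshape(-1, chunk_size)
    out = jax.lax.map(jax.vmap(f), chunks)
    return out.reshape(-1)[:n]


# Example: cube the positive entries, zero elsewhere. Note that under vmap both
# branches of a where/cond are evaluated, so this trades the "skip the zeros"
# saving for vectorized throughput.
x = jnp.array([1.0, 0.0, 3.0])
result = chunked_apply(lambda xi: jnp.where(xi > 0, xi**3, 0.0), x)
```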
Some comments now that I have analyzed the memory effect of this. Regarding the bounce integral stuff: if we want more efficient automatically differentiable bounce integrals, then there are a few options.

Regarding options 2 and 3: now that the infrastructure, API, and testing to do this in DESC are done, it is simple to rewrite in PyTorch; just replace the relevant calls with their PyTorch equivalents. If there is a cleaner way to do option 1 with our AD stuff while not sacrificing GPU performance, that may be better. I have opened a question in JAX's discussions.
A fundamental limitation of our autodiff tool, JAX, is that it cannot handle dynamically sized arrays or jagged (nested) tensor operations. Both are supported by PyTorch, whose documentation I reference to better describe what these operations are. Support for these operations would improve the performance of bounce integrals.
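For reference, here is a small PyTorch snippet (illustration only, not DESC code) showing the two features referred to above: data-dependent output shapes from `nonzero`-style indexing, and jagged tensors via `torch.nested`:

```python
import torch

x = torch.tensor([1.0, 0.0, 3.0], requires_grad=True)

# 1. Dynamic-size result: the shape of a boolean/nonzero gather depends on the
#    data, and autograd still flows through the selected entries.
y = x[x > 0]                 # shape (2,), known only at runtime
y.pow(3).sum().backward()    # x.grad == [3., 0., 27.]

# 2. Jagged tensor: rows of different lengths held in a single object.
nt = torch.nested.nested_tensor([torch.arange(3.0), torch.arange(5.0)])
```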
Patch
So, unlike C code, in JAX we can't use `jnp.nonzero`, which constrains the algorithms we write. Sometimes it is possible to work around this by using a shape-independent algorithm, for example performing a least-squares regression to a variable number of points through equation 12 instead of 13: least squares example. In other cases, it is not possible to write a shape-independent algorithm without sacrificing performance.

We can implement a patch in `desc.backend` to work around JAX not being able to do 1 and 2. For example, `nonzero_where` is a differentiable version of `jnp.nonzero`; the intention is to have such a function in `desc.backend` so that we can replace this kind of logic with a call to it.
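A minimal sketch of what such a wrapper could look like (the name `apply_at_nonzero`, the separate `mask` argument, and the exact signature are hypothetical; the actual `nonzero_where` may differ):

```python
import numpy as np
import jax
import jax.numpy as jnp


def expensive_function(xi):
    # stand-in for some non-trivial per-element computation
    return xi**3


def apply_at_nonzero(f, x, mask):
    """Evaluate ``f`` only where ``mask`` is nonzero (hypothetical sketch).

    ``mask`` must be concrete (not a tracer) because ``np.nonzero`` needs its
    values to determine the output size; the gather/scatter on ``x`` is pure
    JAX, so the result is differentiable with respect to ``x``.
    """
    idx = np.nonzero(np.asarray(mask))[0]   # NumPy call -> not jittable/vmappable
    out = jnp.zeros_like(x)
    return out.at[idx].set(f(x[idx]))


x = jnp.array([1.0, 0.0, 3.0])
mask = np.asarray(x) > 0                    # computed eagerly with NumPy
grad = jax.grad(lambda x: apply_at_nonzero(expensive_function, x, mask).sum())(x)
# grad == [3., 0., 27.], matching the fori_loop/cond version in the comments above
```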
`np.nonzero` has built-in NumPy vectorization, so we could rely on that if we need to do multiple calls to `expensive_function`. However, we can't use JAX vectorization on `nonzero_where` because it uses a NumPy function. Hence, to vectorize, one is limited to Python for loops or list comprehensions (which use a C loop, but are still not ideal) instead of `vmap`, `map`, and `scan`. If the loop size is not large, this may not be that bad. (We already use list comprehensions in other parts of the code that are called when we compute things.)
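For example, batching the hypothetical wrapper above over the rows of a 2D array would fall back to a list comprehension rather than `vmap`:

```python
import numpy as np
import jax.numpy as jnp

# Reuses apply_at_nonzero, expensive_function, and x from the sketch above.
X = jnp.stack([x, 2 * x])
masks = np.asarray(X) > 0
batched = jnp.stack(
    [apply_at_nonzero(expensive_function, xi, mi) for xi, mi in zip(X, masks)]
)
```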