JEP: Allow finite difference method for non-jax differentiable functions #15425
Comments
Thanks for the proposal! We've discussed something along these lines amongst ourselves at various points. It's a good idea! Do you have an idea for the API you'd propose? When this has come up before, our thought was to make it a convenience wrapper around …
I actually implemented it using full-on custom primitives following this guide. Since it seems like there's interest, I'll make the PR.
@mbmccoy any reason not to use a …
In particular …
Hello, for the context of API design, I have a small WIP library for finite differences; I implemented a finite difference version of jax.grad:

```python
import jax
from jax import numpy as jnp
import numpy as old_np  # Not jax-traceable
import finitediffx as fdx
import functools as ft
from jax.experimental import enable_x64

with enable_x64():

    @fdx.fgrad
    @fdx.fgrad
    def np_rosenbach2(x, y):
        """Compute the Rosenbach function for two variables."""
        return old_np.power(1 - x, 2) + 100 * old_np.power(y - old_np.power(x, 2), 2)

    @ft.partial(fdx.fgrad, derivative=2)
    def np2_rosenbach2(x, y):
        """Compute the Rosenbach function for two variables."""
        return old_np.power(1 - x, 2) + 100 * old_np.power(y - old_np.power(x, 2), 2)

    @jax.grad
    @jax.grad
    def jnp_rosenbach2(x, y):
        """Compute the Rosenbach function for two variables."""
        return jnp.power(1 - x, 2) + 100 * jnp.power(y - jnp.power(x, 2), 2)

    print(np_rosenbach2(1., 2.))
    print(np2_rosenbach2(1., 2.))
    print(jnp_rosenbach2(1., 2.))
```

```
402.0000951997936
402.0000951997936
402.0
```
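For readers unfamiliar with the approach, fdx.fgrad replaces tracing-based differentiation with a numerical approximation. Below is a minimal sketch of the underlying idea as a plain central difference; it is illustrative only (central_fgrad is a made-up name, not finitediffx's actual implementation):

```python
def central_fgrad(func, argnums=0, step_size=1e-6):
    """Approximate the derivative of func w.r.t. its argnums-th argument
    with a central difference: (f(x + h) - f(x - h)) / (2 h)."""
    def grad_func(*args):
        shifted = list(args)
        shifted[argnums] = args[argnums] + step_size
        hi = func(*shifted)
        shifted[argnums] = args[argnums] - step_size
        lo = func(*shifted)
        return (hi - lo) / (2 * step_size)
    return grad_func

# Check against the analytic derivative of x**3 at x = 2.0 (exact: 12.0).
print(central_fgrad(lambda x: x**3)(2.0))  # approximately 12.0
```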
As far as I understand we can also use jax.pure_callback together with jax.custom_jvp:

```python
import functools as ft

import jax
import jax.numpy as jnp
import numpy as onp
import finitediffx as fdx


# Wrap a non-traceable (plain numpy) function so JAX can call it via pure_callback.
def wrap_pure_callback(func):
    @ft.wraps(func)
    def wrapper(*args, **kwargs):
        args = [jnp.asarray(arg) for arg in args]
        func_ = lambda *args, **kwargs: func(*args, **kwargs).astype(args[0].dtype)
        result_shape_dtype = jax.ShapeDtypeStruct(
            shape=jnp.broadcast_shapes(*[arg.shape for arg in args]),
            dtype=args[0].dtype,
        )
        return jax.pure_callback(
            func_, result_shape_dtype, *args, **kwargs, vectorized=True
        )

    return wrapper


# Attach a finite-difference JVP rule so JAX transformations can differentiate the function.
def define_finitdiff_jvp(func):
    func = jax.custom_jvp(func)

    @func.defjvp
    def func_jvp(primals, tangents):
        primal_out = func(*primals)
        tangent_out = sum(
            fdx.fgrad(func, argnums=i)(*primals) * dot for i, dot in enumerate(tangents)
        )
        return jnp.array(primal_out), jnp.array(tangent_out)

    return func


@jax.jit
@define_finitdiff_jvp
@wrap_pure_callback
def np_rosenbach2(x, y):
    """Compute the Rosenbach function for two variables."""
    return onp.power(1 - x, 2) + 100 * onp.power(y - onp.power(x, 2), 2)


@jax.jit
def jnp_rosenbach2(x, y):
    """Compute the Rosenbach function for two variables."""
    return jnp.power(1 - x, 2) + 100 * jnp.power(y - jnp.power(x, 2), 2)


print(jax.value_and_grad(np_rosenbach2, argnums=0)(1.0, 2.0))
print(jax.value_and_grad(jnp_rosenbach2, argnums=0)(1.0, 2.0))
print(jax.value_and_grad(np_rosenbach2, argnums=1)(1.0, 2.0))
print(jax.value_and_grad(jnp_rosenbach2, argnums=1)(1.0, 2.0))
print(jax.vmap(jax.grad(np_rosenbach2), in_axes=(0, None))(jnp.array([1.0, 2.0, 3.0, 0.2]), 2.0))
print(jax.vmap(jax.grad(jnp_rosenbach2), in_axes=(0, None))(jnp.array([1.0, 2.0, 3.0, 0.2]), 2.0))
```

```
(Array(100., dtype=float32), Array(-399.9948, dtype=float32, weak_type=True))
(Array(100., dtype=float32, weak_type=True), Array(-400., dtype=float32, weak_type=True))
(Array(100., dtype=float32), Array(199.97772, dtype=float32, weak_type=True))
(Array(100., dtype=float32, weak_type=True), Array(200., dtype=float32, weak_type=True))
[-399.9948  1601.8411  8403.304   -158.45016]
[-400.      1602.      8404.      -158.40001]
```
So understated! Seems like you've written a package that does most of the work here. :) Given this, is an implementation of finite differences in the core JAX package (e.g., under jax.experimental) still of interest?

FYI, I'm happy to collaborate on a PR @ASEM000 if that's of interest. I'll try to get something up in the next day or so for more comment; I've had some busy days in the last week.

@mattjj I'll look into using that approach.
To be concrete about the API @froystig, here are some guiding principles that I have in mind:
API Examples

The "80%" use case

Most use cases should start with a simple decorator:

```python
import jax
from jax.experimental.jaxify import jaxify
from jax.experimental import enable_x64

@jaxify
def my_func(array1, array2):
    return some_complex_math(array1, array2)

print(jax.value_and_grad(my_func)(x, y))  # Warn about not using 64-bit math

with enable_x64():
    jax.value_and_grad(my_func)(x, y)  # No warning
```

Power use cases

```python
# The user wants control over the step size
@jaxify(step_size=1e-9, mode="forward")
def my_func(array1, array2):
    return some_complex_math(array1, array2)

# The user wants per-argument control over the step size
@jaxify(step_size=(1e-9, 1e-3), mode=("forward", "center"))
def my_func(array1, array2):
    return some_complex_math(array1, array2)
```
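As a side note on the `# Warn about not using 64-bit math` comment above, one simple way such a warning could be implemented is a dtype check on the inputs. A hedged sketch (the helper name is invented here, not part of any existing API):

```python
import warnings
import jax.numpy as jnp

def _warn_if_single_precision(*args, name="function"):
    # Possible implementation of the warning mentioned above: finite
    # differences computed in float32 can lose most of their accurate digits.
    if any(jnp.asarray(a).dtype == jnp.float32 for a in args):
        warnings.warn(
            f"Finite-differencing {name} in float32; enable 64-bit mode "
            "(e.g. jax.experimental.enable_x64) for more accurate derivatives."
        )

# Example: warns unless 64-bit mode is enabled.
_warn_if_single_precision(1.0, 2.0, name="my_func")
```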
@mbmccoy – Why bundle together (a) setting up a derivative rule based on FD with (b) setting up a pure_callback?
There are two reasons I can think of to support FD for the same class of functions that jax.pure_callback supports.

Equivalence: The class of functions theoretically supportable by a generic FD technique is precisely the class theoretically supportable by jax.pure_callback.

Think about it: almost by definition, the functions we'd want to finite-difference are not supported directly within JAX's JIT, so computing their values during a generic finite-difference routine requires the use of some mechanism like jax.pure_callback. Conversely, it's pretty easy to see that, at least in principle, we can write a wrapper that will apply finite differences to pure functions that accept and return numpy arrays. That's, of course, the challenge I've set out for myself here.

Consistency: Given the close link between the sets of functions supportable by jax.pure_callback and by finite differences, it would be most consistent for the two to target the same class of functions.

Example documentation

```python
"""Make a function differentiable with JAX using finite differences.
[...]
The function must be pure (that it, side-effect free and deterministic based on its inputs),
and both its inputs and outputs must be ``numpy`` arrays. These requirements are the
same as for the ``jax.pure_callback`` function.
[...]
""" Note: edited for clarity and added an example docstring. |
I'd also be very interested in helping with this. I've done a similar thing using … If I define a … Using …
@mbmccoy I added a new functionality, …
Motivation: Make JAX relevant to many more users
JAX provides a powerful general-purpose tool for automatic differentiation, but it usually requires that users write code that is JAX-traceable end-to-end.
A significant number of scientific and industrial applications involve large legacy codebases where the effort to port the system to end-to-end JAX is prohibitively high. In other cases, users are tied to proprietary software, or the underlying code is not written in Python, and they likewise cannot readily convert it to JAX.
Without JAX, the standard method for performing optimization involves computing derivatives using finite difference methods. While these can be integrated into JAX using custom functions, the process is cumbersome, which significantly limits the set of users able to integrate JAX into their work.
This JEP proposes a simple method for computing numerical derivatives in JAX. I expect that this change would expand the potential user base of JAX substantially, and could drive adoption of JAX across both academia and industry.
Proposal: A decorator that computes JAX derivatives using finite differences
Let's start with an example. By wrapping the function rosenbach2 in jax_finite_difference, it becomes completely compatible with JAX's automatic differentiation tooling and works with other JAX transformations such as vmap.
Additional options will be available for power users who want to specify the step size or choose forward, center, or backward mode.
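A sketch of what that example could look like. jax_finite_difference is the name proposed here, not an existing JAX API, so the snippet includes a minimal stand-in built on jax.pure_callback and jax.custom_jvp (as in the comments above) purely to make it self-contained:

```python
import jax
import jax.numpy as jnp
import numpy as onp  # stands in for arbitrary non-traceable code


def jax_finite_difference(func, step_size=1e-4):
    """Stand-in for the proposed decorator: values via pure_callback,
    derivatives via central differences in each argument."""

    def call(*args):
        dtype = jnp.result_type(*args)
        shape = jnp.broadcast_shapes(*[jnp.shape(a) for a in args])
        return jax.pure_callback(
            lambda *a: onp.asarray(func(*a), dtype=dtype),
            jax.ShapeDtypeStruct(shape, dtype),
            *args,
        )

    wrapped = jax.custom_jvp(call)

    @wrapped.defjvp
    def _jvp(primals, tangents):
        primal_out = wrapped(*primals)

        def partial_derivative(i):
            # Central difference in the i-th argument.
            plus = [p + step_size * (i == j) for j, p in enumerate(primals)]
            minus = [p - step_size * (i == j) for j, p in enumerate(primals)]
            return (wrapped(*plus) - wrapped(*minus)) / (2 * step_size)

        tangent_out = sum(partial_derivative(i) * t for i, t in enumerate(tangents))
        return primal_out, tangent_out

    return wrapped


def rosenbach2(x, y):
    # Plain numpy: JAX cannot trace this on its own.
    return onp.power(1 - x, 2) + 100 * onp.power(y - onp.power(x, 2), 2)


rosenbach2_fd = jax_finite_difference(rosenbach2)

print(jax.grad(rosenbach2_fd)(1.0, 2.0))  # close to the analytic -400.0
print(jax.vmap(jax.grad(rosenbach2_fd), in_axes=(0, None))(jnp.array([1.0, 2.0]), 2.0))
```

A real implementation would also need to handle array-valued outputs, batching, and the step-size and mode options sketched in the comments above.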
This is feasible.
I have working, tested code that does the above for any function that accepts and returns JAX Arrays. If there is interest in this JEP, I will happily make a PR.
Limits of this JEP
This proposal will not support XLA out of the box
The initial proposal does not aim to support XLA for finite differences. While it should be possible to overcome this limitation using a JAX Foreign Function Interface (FFI) [Issue #12632, PR], it would be best to wait until the FFI is finalized before implementing XLA for finite differences.
The downsides of float32 increase with finite differences
Using single-precision (32-bit) floating point numbers in finite differences may lead to unacceptably large errors in many cases. While this is not a foregone conclusion (many functions can be differentiated just fine with 32-bit floating point), we probably want to plan for mitigation strategies, e.g.:
The second strategy would likely be more important when using FFI in conjunction with XLA in later work. At this stage a warning may be all that's needed.
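For a rough sense of magnitude, here is a small self-contained illustration in plain numpy (forward differences with a near-optimal step for each precision; exact figures vary by machine and function):

```python
import numpy as onp

# Forward difference of sin at x = 1.0; the exact derivative is cos(1.0).
# Even with a near-optimal step size (sqrt of machine epsilon), float32
# typically retains only a few accurate digits, versus roughly eight for float64.
for dtype in (onp.float32, onp.float64):
    x = dtype(1.0)
    h = dtype(onp.sqrt(onp.finfo(dtype).eps))
    approx = (onp.sin(x + h) - onp.sin(x)) / h
    print(dtype.__name__, abs(approx - onp.cos(1.0)))
```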
Related JEPs
The proposed Foreign Function Interface [Issue #12632, PR] will provide a method that allows JAX code to call out to external code in the course of derivative computation. However, it does not create a method for computing derivatives; those must still be defined by the user.
That said, we expect that the FFI combined with our finite-difference method would enable "the dream": nearly arbitrary user code fully integrated with JAX using a single decorator.