JIT-ed calculation of Hessian [grad(grad)] fails with JAX #2163

Open
1 task done
quantshah opened this issue Feb 3, 2022 · 15 comments

@quantshah
Contributor

quantshah commented Feb 3, 2022

Expected behavior

I was trying to compute the Hessian and saw that the JAX interface breaks down when JIT is on. Without JIT, it works fine. The error seems to come from the missing JVP rule in the host_callback bridge between PL and JAX. As a workaround, remove the @jax.jit decorator from the definition of the circuit.

@josh146 and I discussed this over Slack and it seems a bit strange to have something that works with the JIT off not work when we simply JIT things.

Actual behavior

can't apply forward-mode autodiff (jvp) to a custom_vjp function.
JVP rule is implemented only for id_tap, not for call.

Additional information

No response

Source code

import jax
import pennylane as qml

dev = qml.device("default.qubit.jax", wires=1, shots=100)


@jax.jit
@qml.qnode(dev, interface="jax")
def circuit(a):
    qml.RX(a, wires=0)
    return qml.expval(qml.PauliZ(0))

hess = jax.grad(jax.grad(circuit))
print("hessian", hess(0.5))

Tracebacks

Traceback (most recent call last):
  File "/Users/shahnawaz/Dropbox/dev/pennylane/tests/devices/test_jit_vs_no_jit.py", line 14, in <module>
    print("hessian", hess(0.5))
  File "/Users/shahnawaz/Dropbox/dev/pennylane/pennylane/qnode.py", line 549, in __call__
    res = qml.execute(
  File "/Users/shahnawaz/Dropbox/dev/pennylane/pennylane/interfaces/batch/__init__.py", line 412, in execute
    res = _execute(
  File "/Users/shahnawaz/Dropbox/dev/pennylane/pennylane/interfaces/batch/jax_jit.py", line 83, in execute
    return _execute(
  File "/Users/shahnawaz/Dropbox/dev/pennylane/pennylane/interfaces/batch/jax_jit.py", line 218, in _execute
    return wrapped_exec(params)
jax._src.source_info_util.JaxStackTraceBeforeTransformation: NotImplementedError: JVP rule is implemented only for id_tap, not for call.

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/Users/shahnawaz/Dropbox/dev/pennylane/tests/devices/test_jit_vs_no_jit.py", line 14, in <module>
    print("hessian", hess(0.5))
  File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/_src/api.py", line 918, in grad_f
    _, g = value_and_grad_f(*args, **kwargs)
  File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/_src/api.py", line 993, in value_and_grad_f
    ans, vjp_py = _vjp(f_partial, *dyn_args, reduce_axes=reduce_axes)
  File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/_src/api.py", line 2312, in _vjp
    out_primal, out_vjp = ad.vjp(
  File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/interpreters/ad.py", line 116, in vjp
    out_primals, pvals, jaxpr, consts = linearize(traceable, *primals)
  File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/interpreters/ad.py", line 103, in linearize
    jaxpr, out_pvals, consts = pe.trace_to_jaxpr(jvpfun_flat, in_pvals)
  File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 513, in trace_to_jaxpr
    jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
  File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/linear_util.py", line 166, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/_src/api.py", line 918, in grad_f
    _, g = value_and_grad_f(*args, **kwargs)
  File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/_src/api.py", line 1000, in value_and_grad_f
    g = vjp_py(np.ones((), dtype=dtype))
  File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/_src/tree_util.py", line 326, in <lambda>
    func = lambda *args, **kw: original_func(*args, **kw)
  File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/_src/api.py", line 2219, in _vjp_pullback_wrapper
    ans = fun(*args)
  File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/_src/tree_util.py", line 326, in <lambda>
    func = lambda *args, **kw: original_func(*args, **kw)
  File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/interpreters/ad.py", line 123, in unbound_vjp
    arg_cts = backward_pass(jaxpr, reduce_axes, consts, dummy_args, cts)
  File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/interpreters/ad.py", line 222, in backward_pass
    cts_out = get_primitive_transpose(eqn.primitive)(
  File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/interpreters/ad.py", line 558, in call_transpose
    out_flat = primitive.bind(fun, *all_args, **new_params)
  File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/core.py", line 1632, in bind
    return call_bind(self, fun, *args, **params)
  File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/core.py", line 1623, in call_bind
    outs = primitive.process(top_trace, fun, tracers, params)
  File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/core.py", line 1635, in process
    return trace.process_call(self, fun, tracers, params)
  File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/interpreters/ad.py", line 323, in process_call
    result = call_primitive.bind(f_jvp, *primals, *nonzero_tangents, **new_params)
  File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/core.py", line 1632, in bind
    return call_bind(self, fun, *args, **params)
  File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/core.py", line 1623, in call_bind
    outs = primitive.process(top_trace, fun, tracers, params)
  File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/core.py", line 1635, in process
    return trace.process_call(self, fun, tracers, params)
  File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 202, in process_call
    jaxpr, out_pvals, consts, env_tracers = self.partial_eval(
  File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 311, in partial_eval
    out_flat, (out_avals, jaxpr, env) = app(f, *in_consts), aux()
  File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/core.py", line 1632, in bind
    return call_bind(self, fun, *args, **params)
  File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/core.py", line 1623, in call_bind
    outs = primitive.process(top_trace, fun, tracers, params)
  File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/core.py", line 1635, in process
    return trace.process_call(self, fun, tracers, params)
  File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/core.py", line 627, in process_call
    return primitive.impl(f, *tracers, **params)
  File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/interpreters/xla.py", line 687, in _xla_call_impl
    compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
  File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/linear_util.py", line 263, in memoized_fun
    ans = call(fun, *args)
  File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/interpreters/xla.py", line 759, in _xla_callable_uncached
    return lower_xla_callable(fun, device, backend, name, donated_invars,
  File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/interpreters/xla.py", line 771, in lower_xla_callable
    jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(
  File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 1542, in trace_to_jaxpr_final
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
  File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 1520, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers)
  File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/linear_util.py", line 166, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/interpreters/ad.py", line 228, in backward_pass
    cts_out = get_primitive_transpose(eqn.primitive)(cts_in, *invals,
  File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/interpreters/ad.py", line 690, in _custom_lin_transpose
    cts_in = bwd.call_wrapped(*res, *cts_out)
  File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/linear_util.py", line 166, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/Users/shahnawaz/Dropbox/dev/pennylane/pennylane/interfaces/batch/jax_jit.py", line 180, in wrapped_exec_bwd
    vjps = host_callback.call(
  File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/experimental/host_callback.py", line 658, in call
    return _call(callback_func, arg, result_shape=result_shape,
  File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/experimental/host_callback.py", line 708, in _call
    flat_results = outside_call_p.bind(*flat_args, **params)
  File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/core.py", line 272, in bind
    out = top_trace.process_primitive(self, tracers, params)
  File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/interpreters/ad.py", line 288, in process_primitive
    primal_out, tangent_out = jvp(primals_in, tangents_in, **params)
  File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/experimental/host_callback.py", line 1204, in _outside_call_jvp_rule
    raise NotImplementedError("JVP rule is implemented only for id_tap, not for call.")
jax._src.traceback_util.UnfilteredStackTrace: NotImplementedError: JVP rule is implemented only for id_tap, not for call.

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/Users/shahnawaz/Dropbox/dev/pennylane/tests/devices/test_jit_vs_no_jit.py", line 14, in <module>
    print("hessian", hess(0.5))
  File "/Users/shahnawaz/Dropbox/dev/pennylane/pennylane/interfaces/batch/jax_jit.py", line 180, in wrapped_exec_bwd
    vjps = host_callback.call(
  File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/experimental/host_callback.py", line 658, in call
    return _call(callback_func, arg, result_shape=result_shape,
  File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/experimental/host_callback.py", line 708, in _call
    flat_results = outside_call_p.bind(*flat_args, **params)
  File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/experimental/host_callback.py", line 1204, in _outside_call_jvp_rule
    raise NotImplementedError("JVP rule is implemented only for id_tap, not for call.")
NotImplementedError: JVP rule is implemented only for id_tap, not for call.

System information

Name: PennyLane
Version: 0.21.0.dev0
Summary: PennyLane is a Python quantum machine learning library by Xanadu Inc.
Home-page: https://github.com/XanaduAI/pennylane
Author: 
Author-email: 
License: Apache License 2.0
Location: /Users/shahnawaz/Dropbox/dev/pennylane
Requires: appdirs, autograd, autoray, cachetools, networkx, numpy, pennylane-lightning, retworkx, scipy, semantic_version, toml
Required-by: PennyLane-Forest, PennyLane-Lightning
Platform info:           macOS-11.2.3-x86_64-i386-64bit
Python version:          3.9.7
Numpy version:           1.21.4
Scipy version:           1.7.3
Installed devices:
- lightning.qubit (PennyLane-Lightning-0.20.2)
- forest.numpy_wavefunction (PennyLane-Forest-0.20.0)
- forest.qvm (PennyLane-Forest-0.20.0)
- forest.wavefunction (PennyLane-Forest-0.20.0)
- default.gaussian (PennyLane-0.21.0.dev0)
- default.mixed (PennyLane-0.21.0.dev0)
- default.qubit (PennyLane-0.21.0.dev0)
- default.qubit.autograd (PennyLane-0.21.0.dev0)
- default.qubit.jax (PennyLane-0.21.0.dev0)
- default.qubit.tf (PennyLane-0.21.0.dev0)
- default.qubit.torch (PennyLane-0.21.0.dev0)
None

Existing GitHub issues

  • I have searched existing GitHub issues to make sure the issue does not already exist.
@quantshah quantshah added the bug 🐛 Something isn't working label Feb 3, 2022
@antalszava
Contributor

antalszava commented Feb 3, 2022

Hi @quantshah

it seems a bit strange to have something that works with the JIT off not work when we simply JIT things.

This very much comes down to how the JAX JIT interface uses host_callback.call at the moment (see newer comment here).

This very much comes down to what JAX offers at the moment:

Maybe there's a way of using host_callback.id_tap with jax.jit in the way we'd like to, but so far it doesn't seem to be the case.

@antalszava antalszava removed the bug 🐛 Something isn't working label Feb 3, 2022
@antalszava antalszava changed the title [BUG] JIT-ed calculation of Hessian [grad(grad)] fails JIT-ed calculation of Hessian [grad(grad)] fails with JAX Feb 3, 2022
@quantshah
Contributor Author

Thanks @antalszava for the explanation. Feel free to leave this issue open till there is a resolution or close it since this is an issue with Jax and not PL.

@antalszava
Contributor

Sure :) I might leave it open, just so that there's a point of reference if this becomes a question for others too.

@antalszava
Contributor

Note: this issue could be potentially resolved by a refactor to the JAX JIT interface. We have this on our radar and would like to look into a resolution in the coming weeks.

Specifically, at the moment the cotangent value g is being passed as an input parameter here. Instead, g should be applied to the result of host_callback.call.

@josh146
Member

josh146 commented Apr 18, 2022

@antalszava just curious: if we move towards having the quantum device itself compute the VJP, would this require passing the cotangent vector g through the host_callback? One example would be the adjoint method with lightning.qubit.

@antalszava
Contributor

Likely not. JAX seems to assume that g is applied to the residuals to yield the jacobian.

Specifically for adjoint with mode="forward", the jacobian could be "passed" to the registered backward function as a residual using the following pattern (as suggested on this discussion thread here):

import jax
import jax.numpy as jnp

params = jnp.array([0.1, 0.2])

@jax.custom_vjp
def wrapped_exec(params):
    y = params ** 2, params ** 3
    # no need to compute the jacobians here
    return y

def wrapped_exec_fwd(params):
    y = wrapped_exec(params)
    jacs = jnp.diag(2 * params), jnp.diag(3 * params ** 2)  # compute here
    return y, jacs  # the jacobians are the residuals; params are not needed

def wrapped_exec_bwd(res, g):
    jac1, jac2 = res
    g1, g2 = g
    # apply the cotangents to the stored jacobians
    return (g1 @ jac1) + (g2 @ jac2),

wrapped_exec.defvjp(wrapped_exec_fwd, wrapped_exec_bwd)
jax.jacobian(wrapped_exec)(params)

@quantshah
Contributor Author

Hi everyone, getting back to this thread as I saw that in Jax, there is a possibility to implement higher order gradients (VJPs) with host_callback using an outside implementation for the gradient computation (e.g., TensorFlow). See the discussion here: https://jax.readthedocs.io/en/latest/jax.experimental.host_callback.html#using-call-to-call-a-tensorflow-function-with-reverse-mode-autodiff-support

I had a look at the implementation here: https://github.com/google/jax/blob/main/tests/host_callback_to_tf_test.py#L100 but haven't figured out completely what is happening in the custom backward pass that allows one to compute higher order gradients with host_callback. It feels like somehow they are just hooking up the TensorFlow autodiff mechanism to the Jax custom_vjp definitions and it works all the way up to the higher order derivative.

Just putting this out here for reference in the future in case we look into this again and it is helpful.

@PhilipVinc
Contributor

I just stumbled on this issue. I think a way to avoid it would be to define a new JAX primitive operation, jax.core.Primitive("qml_expval"), define its CPU and GPU implementations, and then also define how to differentiate it.
Then qml_expval will be perfectly equivalent to any other jax native operation (so everything can be supported).

The main 'complication' is that I'm not sure if you can feed a host_callback as the implementation.
I think it could be possible, but should be tried.
I'm sure you could feed a C function (because that's what it natively supports), which then trampolines back into python code.

The interface is quite stable and has not changed in the last 2 years.

Ps: If there's any interest for that, I don't have the time to put in, but I can provide some guidance. I did that already for two different packages.

@quantshah
Contributor Author

This is interesting. There is a nice example here of how this is done all the way up to JITing: https://jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html

But the problem still remains that we cannot use XLA operations to do the expval evaluation as the expval evaluation happens purely on the quantum device right? So there has to be a break in the computational graph somewhere (what host_callback does). So I don’t know if this would still work.

Unless there is a way to use the XLA custom call (https://www.tensorflow.org/xla/custom_call) to get the value of the expval.

@antalszava
Contributor

Hi @PhilipVinc, thank you for the suggestion! 🎉

Personally, I'm new to creating a custom JAX primitive, any help and guidance here would definitely be appreciated. 🙂

Also wondering about the question that Shahnawaz mentioned: how could introducing the new primitive help with the specific error originally reported? It would seem that the issue is specific to the invocation of host_callback.call. Assuming that jax.core.Primitive is compatible with host_callback.call and jax.core.Primitive("qml_expval") is defined, wouldn't the same issue arise with its custom gradient?

@PhilipVinc
Contributor

But the problem still remains that we cannot use XLA operations to do the expval evaluation as the expval evaluation happens purely on the quantum device right? So there has to be a break in the computational graph somewhere (what host_callback does). So I don’t know if this would still work.

There are two graphs at play here. The first is the one used during function transformations, pre-jax.jit (which is more of a tape than a graph, but whatever). For this one, a primitive is a node in the graph. Much like host_callback, it must specify its input and output shapes (primitive.def_abstract_eval) and how it transforms under passes like vmap (batching.primitive_batchers), vjp and jvp (ad.primitive_jvps), but it does not need to specify anything else.

The issue with host_callback is that, due to a bunch of issues, you cannot customise how host_callback transforms under vmap, vjp and jvp. But you can do that for your custom primitive.

Then, you must also tell JAX which XLA operation the primitive corresponds to when it compiles (xla.backend_specific_translations["cpu"][primitive]).
XLA operations contain a reference to a C function that performs that operation.

You can always call ANY C-code.
For example, I used this mechanism to support MPI operations inside of JAX-jitted functions with mpi4jax (a good, self-contained file implementing a primitive is this one here) or to support numba operations inside of jitted functions with numba4jax.

But, I guess, you could also enqueue a host_callback operation.
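To make the "transformation graph" half of this concrete, here is a minimal sketch of a primitive with its own differentiation rule. Several things in it are assumptions rather than anything from PennyLane or this thread: the "device" is just NumPy evaluating cos(a) (the expectation value of PauliZ after RX(a)), the JVP rule uses a two-term parameter-shift formula, and the import location of Primitive is guessed with a fallback because it depends on the JAX version. The XLA lowering step is deliberately left out, so this runs eagerly only and does not address jax.jit.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.interpreters import ad

try:  # the home of Primitive depends on the JAX version
    from jax.extend.core import Primitive
except ImportError:
    from jax.core import Primitive

# Hypothetical primitive, named after the suggestion in this thread.
qml_expval_p = Primitive("qml_expval")


def qml_expval(a):
    """User-facing wrapper that binds the primitive."""
    return qml_expval_p.bind(jnp.asarray(a))


def _impl(a):
    # Stand-in for a device execution: <Z> after RX(a) equals cos(a).
    return jnp.asarray(np.cos(np.asarray(a)), dtype=a.dtype)


def _abstract_eval(a):
    # Output has the same shape and dtype as the input.
    return jax.core.ShapedArray(a.shape, a.dtype)


def _jvp(primals, tangents):
    (a,), (t,) = primals, tangents
    shift = jnp.pi / 2
    out = qml_expval_p.bind(a)
    # Parameter-shift derivative, expressed with the same primitive,
    # so the rule can be differentiated again.
    deriv = (qml_expval_p.bind(a + shift) - qml_expval_p.bind(a - shift)) / 2
    return out, deriv * t


qml_expval_p.def_impl(_impl)
qml_expval_p.def_abstract_eval(_abstract_eval)
ad.primitive_jvps[qml_expval_p] = _jvp

hess = jax.grad(jax.grad(qml_expval))
print("hessian", hess(jnp.float32(0.5)))
```

Because the JVP rule only ever binds the primitive itself, nested jax.grad calls keep hitting a rule that is defined, which is the property host_callback.call lacks.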

@PhilipVinc
Contributor

PhilipVinc commented May 17, 2022

Assuming that jax.core.Primitive is compatible with host_callback.call and jax.core.Primitive("qml_expval") is defined, wouldn't the same issue arise with its custom gradient?

What issue exactly? I'm not familiar with the depth of pennylane's source, so if you have a short example, even in pseudocode, that would help clarify

@antalszava
Contributor

What issue exactly? I'm not familiar with the depth of pennylane's source, so if you have a short example, even in pseudocode, that would help clarify

Sure. 🙂 At the moment, the use of @jax.custom_vjp in PennyLane is not ideal because we are passing the cotangent vectors (g) along with the input parameters to the host_callback.call invocation:

args = tuple(params) + (g,)
vjps = host_callback.call(
    non_diff_wrapper,
    args,
    result_shape=jax.ShapeDtypeStruct((total_params,), dtype),
)

Passing g helped with using the qml.gradients.batch_vjp function we have internally, a function called in the non_diff_wrapper function and shared across other machine learning frameworks too (including TensorFlow and PyTorch).

At the same time, passing g along with the other arguments creates issues because g may become a BatchTrace object when using certain transforms (e.g., jax.jacobian that uses jax.vmap) and this seems to be the culprit for the original error:

NotImplementedError: JVP rule is implemented only for id_tap, not for call.

It would seem that there may be two components to a solution here:

  • g should not be passed as an argument, but rather be used to mutate the output of host_callback.call (as per how custom_vjp should work in JAX);
  • The logic in PennyLane for computing batches of VJPs is adjusted (either updating qml.gradients.batch_vjp or implementing it to be aligned with the logic here).

With those changes we should have no BatchTrace objects flow through the host_callback.call's invocation and should be able to implement the logic for supporting the original Hessian computation.
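The first component can be sketched as follows. This is not PennyLane's implementation: it swaps in jax.pure_callback for host_callback.call, uses NumPy's cos as a stand-in for the device, and the helper names host_cos_derivative and make_expval are made up for the illustration. The callback only ever receives the parameters; g is applied to its result inside JAX. Each backward pass calls another custom_vjp function one derivative order higher, so the second jax.grad finds a rule to use.

```python
import jax
import jax.numpy as jnp
import numpy as np


def host_cos_derivative(order):
    """Hypothetical host-side 'device': the n-th derivative of cos(a)
    is cos(a + n * pi / 2)."""
    def run(a):
        a = np.asarray(a)
        return np.asarray(np.cos(a + order * np.pi / 2), dtype=a.dtype)
    return run


def make_expval(order=0, max_order=2):
    """Build a custom_vjp function whose backward pass is itself
    a custom_vjp function (up to max_order derivatives)."""

    @jax.custom_vjp
    def f(a):
        out_type = jax.ShapeDtypeStruct(a.shape, a.dtype)
        return jax.pure_callback(host_cos_derivative(order), out_type, a)

    def f_fwd(a):
        return f(a), a  # only the parameters are kept as residuals

    def f_bwd(a, g):
        if order >= max_order:
            raise NotImplementedError("derivative order exceeds max_order")
        # The callback sees only `a`; the cotangent g is applied outside it.
        jac = make_expval(order + 1, max_order)(a)
        return (g * jac,)

    f.defvjp(f_fwd, f_bwd)
    return f


circuit = make_expval()
hess = jax.jit(jax.grad(jax.grad(circuit)))
print("hessian", hess(jnp.float32(0.5)))
```

The max_order argument plays a role loosely similar to a cap on differentiation depth; with a real device each order would presumably be a batch of shifted circuit executions rather than a closed-form derivative.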

@PhilipVinc
Contributor

At the same time, passing g along with the other arguments creates issues because g may become a BatchTrace object when using certain transforms (e.g., jax.jacobian that uses jax.vmap) and this seems to be the culprit for the original error:

Are you aware of jax.custom_batching.custom_vmap? If you define a custom_vmap rule for your custom_vjp you might sidestep the issue entirely.
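A rough sketch of what that could look like, under several assumptions: jax.pure_callback stands in for host_callback.call, the "device" is an elementwise NumPy cos, and the names vjp_on_host, _host_vjp and _host_expval are invented for the example. Here g is still passed to the host (as in the current PennyLane code path), but the custom_vmap rule decides what happens when jax.jacobian batches it: the whole batch is handed to the host in a single call.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.custom_batching import custom_vmap


def _host_expval(a):
    # Stand-in for a device execution, elementwise cos.
    a = np.asarray(a)
    return np.asarray(np.cos(a), dtype=a.dtype)


def _host_vjp(a, g):
    # Elementwise VJP of cos; broadcasting also covers a leading batch axis.
    a, g = np.asarray(a), np.asarray(g)
    return np.asarray(-np.sin(a) * g, dtype=g.dtype)


@custom_vmap
def vjp_on_host(a, g):
    out_type = jax.ShapeDtypeStruct(g.shape, g.dtype)
    return jax.pure_callback(_host_vjp, out_type, a, g)


@vjp_on_host.def_vmap
def _vjp_on_host_vmap(axis_size, in_batched, a, g):
    a_batched, g_batched = in_batched
    # Batched arguments arrive with the batch axis first; broadcast the rest.
    if not a_batched:
        a = jnp.broadcast_to(a, (axis_size,) + a.shape)
    if not g_batched:
        g = jnp.broadcast_to(g, (axis_size,) + g.shape)
    out_type = jax.ShapeDtypeStruct(g.shape, g.dtype)
    out = jax.pure_callback(_host_vjp, out_type, a, g)
    return out, True  # the output carries the batch axis


@jax.custom_vjp
def circuit(a):
    out_type = jax.ShapeDtypeStruct(a.shape, a.dtype)
    return jax.pure_callback(_host_expval, out_type, a)


def circuit_fwd(a):
    return circuit(a), a


def circuit_bwd(a, g):
    return (vjp_on_host(a, g),)


circuit.defvjp(circuit_fwd, circuit_bwd)

params = jnp.array([0.1, 0.2, 0.3], dtype=jnp.float32)
jac = jax.jit(jax.jacobian(circuit))(params)
print(jac)
```

This covers the jax.jacobian case only; it does not by itself make the backward pass differentiable a second time.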

@antalszava
Contributor

Wasn't aware of it! 😲 Will try this out, thank you. 🙂 I see it's functionality that is still in the works, but it should be worthwhile to try because of the jax.jacobian support. 👍


4 participants