Skip to content

Commit

Permalink
define jax jvp new_tapes before args tuple definition
Browse files Browse the repository at this point in the history
  • Loading branch information
timmysilv committed Apr 17, 2023
1 parent d651950 commit eb176e3
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion pennylane/interfaces/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,8 +459,9 @@ def execute_wrapper_jvp(primals, tangents):
"""Primals[0] are parameters as Jax tracers and tangents[0] is a list of tangent vectors as Jax tracers."""
if isinstance(gradient_fn, qml.gradients.gradient_transform):
at_max_diff = _n == max_diff
new_tapes = set_parameters_on_copy_and_unwrap(tapes, primals[0], unwrap=at_max_diff)
_args = (
set_parameters_on_copy_and_unwrap(tapes, primals[0], unwrap=at_max_diff),
new_tapes,
tangents[0],
gradient_fn,
device.shot_vector,
Expand Down

0 comments on commit eb176e3

Please sign in to comment.