File tree Expand file tree Collapse file tree 1 file changed +3
-4
lines changed Expand file tree Collapse file tree 1 file changed +3
-4
lines changed Original file line number Diff line number Diff line change @@ -225,8 +225,7 @@ def j2t_autograd(fn, call_jax=call_jax):
225225
226226 @wraps (fn )
227227 def inner (* args , ** kwargs ):
228- from jax .tree_util import tree_flatten , tree_unflatten
229- from jax .util import safe_zip
228+ from jax .tree_util import tree_flatten
230229
231230 class JaxFun (torch .autograd .Function ):
232231
@@ -261,8 +260,8 @@ def backward(ctx, *grad_out):
261260 # The subsequent gradients correspond to flat_inputs.
262261 # We need to put a None for inputs that did not require gradients.
263262 final_grads = [None ]
264- for needs_grad , grad in safe_zip ( ctx . needs_input_grad [ 1 :],
265- input_grads_structured ):
263+ for needs_grad , grad in zip (
264+ ctx . needs_input_grad [ 1 :], input_grads_structured , strict = True ):
266265 final_grads .append (grad if needs_grad else None )
267266
268267 return tuple (final_grads )
You can’t perform that action at this time.
0 commit comments