Skip to content

Commit d442b45

Browse files
authored
[torchax] Remove safe_zip (#9525)
1 parent 9ccd4e0 commit d442b45

File tree

1 file changed

+3
-4
lines changed

1 file changed

+3
-4
lines changed

torchax/torchax/interop.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -225,8 +225,7 @@ def j2t_autograd(fn, call_jax=call_jax):
225225

226226
@wraps(fn)
227227
def inner(*args, **kwargs):
228-
from jax.tree_util import tree_flatten, tree_unflatten
229-
from jax.util import safe_zip
228+
from jax.tree_util import tree_flatten
230229

231230
class JaxFun(torch.autograd.Function):
232231

@@ -261,8 +260,8 @@ def backward(ctx, *grad_out):
261260
# The subsequent gradients correspond to flat_inputs.
262261
# We need to put a None for inputs that did not require gradients.
263262
final_grads = [None]
264-
for needs_grad, grad in safe_zip(ctx.needs_input_grad[1:],
265-
input_grads_structured):
263+
for needs_grad, grad in zip(
264+
ctx.needs_input_grad[1:], input_grads_structured, strict=True):
266265
final_grads.append(grad if needs_grad else None)
267266

268267
return tuple(final_grads)

0 commit comments

Comments
 (0)