Skip to content

Commit

Permalink
fix wrapper order
Browse files Browse the repository at this point in the history
  • Loading branch information
juanitorduz committed Feb 8, 2024
1 parent ef3a3f7 commit a64c092
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions numpyro/optim.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,10 @@

def _value_and_grad(f, x, forward_mode_differentiation=False):
if forward_mode_differentiation:
def _wrapper(h, x):
out, aux = h(x)
def _wrapper(x):
out, aux = f(x)
return out, (out, aux)
grads, (out, aux) = _wrapper(jacfwd(f, has_aux=True), x)
grads, (out, aux) = jacfwd(_wrapper, has_aux=True)(x)
return (out, aux), grads
else:
return value_and_grad(f, has_aux=True)(x)
Expand Down

0 comments on commit a64c092

Please sign in to comment.