From a64c092f2e630b99fc4833927e05292775519251 Mon Sep 17 00:00:00 2001 From: juanitorduz Date: Thu, 8 Feb 2024 15:41:55 +0100 Subject: [PATCH] fix wrapper order --- numpyro/optim.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/numpyro/optim.py b/numpyro/optim.py index 37d16dc4a..5ad451bab 100644 --- a/numpyro/optim.py +++ b/numpyro/optim.py @@ -36,10 +36,10 @@ def _value_and_grad(f, x, forward_mode_differentiation=False): if forward_mode_differentiation: - def _wrapper(h, x): - out, aux = h(x) + def _wrapper(x): + out, aux = f(x) return out, (out, aux) - grads, (out, aux) = _wrapper(jacfwd(f, has_aux=True), x) + grads, (out, aux) = jacfwd(_wrapper, has_aux=True)(x) return (out, aux), grads else: return value_and_grad(f, has_aux=True)(x)