Skip to content

Commit

Permalink
Fix JVP rule for lax.pow()
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Aug 22, 2022
1 parent da4e79a commit f98c0b4
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 2 deletions.
3 changes: 1 addition & 2 deletions jax/_src/lax/lax.py
Original file line number Diff line number Diff line change
Expand Up @@ -2001,8 +2001,7 @@ def _abs_jvp_rule(g, ans, x):
pow_p = standard_naryop([_float | _complex, _float | _complex], 'pow')

def _pow_jvp_lhs(g, ans, x, y):
jac = mul(y, pow(x, select(eq(y, _zeros(y)), _ones(y), sub(y, _ones(y)))))
return mul(g, jac)
return mul(g, mul(y, pow(x, sub(y, _ones(y)))))

def _pow_jvp_rhs(g, ans, x, y):
return mul(g, mul(log(_replace_zero(x)), ans))
Expand Down
11 changes: 11 additions & 0 deletions tests/lax_autodiff_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,6 +529,17 @@ def testReverseGrad(self):
check_grads(rev, (np.array([[6., 5., 4.], [3., 2., 1.]]),), 2,
rtol={np.float32: 3e-3})

def testPowSecondDerivative(self):
# https://github.com/google/jax/issues/12033
x, y = 4.0, 0.0
expected = ((0.0, 1/x), (1/x, np.log(x) ** 2))

result_fwd = jax.jacfwd(jax.jacfwd(lax.pow, (0, 1)), (0, 1))(x, y)
self.assertAllClose(result_fwd, expected)

result_rev = jax.jacrev(jax.jacrev(lax.pow, (0, 1)), (0, 1))(x, y)
self.assertAllClose(result_rev, expected)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_predshape={}_argshapes={}".format(
jtu.format_shape_dtype_string(pred_shape, np.bool_),
Expand Down

0 comments on commit f98c0b4

Please sign in to comment.