Skip to content

Commit

Permalink
Merge pull request #17295 from jakevdp:lax-pow-jvp
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 560133324
  • Loading branch information
jax authors committed Aug 25, 2023
2 parents 3ea0a74 + 6cec5d4 commit c71eedf
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 0 deletions.
4 changes: 4 additions & 0 deletions jax/_src/lax/lax.py
Original file line number Diff line number Diff line change
Expand Up @@ -1989,6 +1989,10 @@ def _pow_jvp_lhs(g, ans, x, y):
y_dtype = dtypes.dtype(y)
x, y = jax._src.numpy.util.promote_dtypes_numeric(x, y) # TODO replace this
if dtypes.issubdtype(y_dtype, np.integer):
if x.shape != y.shape:
shape = broadcast_shapes(x.shape, y.shape)
x = _maybe_broadcast(shape, x)
y = _maybe_broadcast(shape, y)
jac = select(eq(y, _const(y, 0)), _ones(y),
mul(_replace_zero(y), pow(x, sub(y, _ones(y)))))
else:
Expand Down
8 changes: 8 additions & 0 deletions tests/lax_autodiff_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1140,6 +1140,14 @@ def f(x):
with self.assertRaises(NotImplementedError):
jax.jacrev(f)(x)

def testPowShapeMismatch(self):
# Regression test for https://github.com/google/jax/issues/17294
x = lax.iota('float32', 4)
y = 2
actual = jax.jacrev(jax.jit(jax.lax.pow))(x, y) # no error
expected = jax.numpy.diag(y * x ** (y - 1))
self.assertArraysEqual(actual, expected)


if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())

0 comments on commit c71eedf

Please sign in to comment.