Skip to content

Commit

Permalink
fix pow_p jvp rule at x=0. y=0
Browse files Browse the repository at this point in the history
fixes #14397

For autodiff purposes (and eventually for evaluation implementation purposes)
we need to distinguish between pow :: inexact -> int -> inexact (which is
differentiable at (0., 0)) and pow :: inexact -> inexact -> inexact (which
isn't); see #14397 (comment).

Instead of making a new primitive, we made the old one polymorphic and switch
its behavior on the element type of its second argument.

Co-authored-by: Roy Frostig <frostig@google.com>
  • Loading branch information
mattjj and froystig committed Jul 10, 2023
1 parent f4eed78 commit 1ed8c06
Show file tree
Hide file tree
Showing 5 changed files with 71 additions and 41 deletions.
37 changes: 24 additions & 13 deletions jax/_src/lax/lax.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,13 +54,9 @@
from jax._src.interpreters.batching import RaggedAxis
from jax._src.lax import slicing
from jax._src.lax.utils import (
_input_dtype,
dtype_to_string,
standard_abstract_eval,
standard_multi_result_abstract_eval,
standard_named_shape_rule,
standard_primitive,
)
_input_dtype, dtype_to_string, standard_abstract_eval,
standard_multi_result_abstract_eval, standard_named_shape_rule,
standard_primitive,)
from jax._src.lib import pytree
from jax._src import xla_bridge
from jax._src.lib import xla_client
Expand Down Expand Up @@ -1532,7 +1528,7 @@ def unop(result_dtype, accepted_dtypes, name):


def naryop_dtype_rule(result_dtype, accepted_dtypes, name, *avals,
allow_opaque_dtype=False, **kwargs):
require_same=True, allow_opaque_dtype=False, **kwargs):
del kwargs
assert len(avals) == len(accepted_dtypes), (avals, accepted_dtypes)
for i, aval in enumerate(avals):
Expand All @@ -1555,7 +1551,7 @@ def naryop_dtype_rule(result_dtype, accepted_dtypes, name, *avals,
typename = dtype_to_string(aval.dtype)
typenames = ', '.join(t.__name__ for t in types)
raise TypeError(msg.format(name, typename, i, i, typenames))
check_same_dtypes(name, *avals)
if require_same: check_same_dtypes(name, *avals)
return result_dtype(*avals)


Expand Down Expand Up @@ -1598,8 +1594,10 @@ def _naryop_weak_type_rule(name, *avals, **kwargs):
"taken a gradient with respect to an integer argument.")
return all(aval.weak_type for aval in avals)

def naryop(result_dtype, accepted_dtypes, name, allow_opaque_dtype=False):
def naryop(result_dtype, accepted_dtypes, name, allow_opaque_dtype=False,
require_same_dtypes=False):
dtype_rule = partial(naryop_dtype_rule, result_dtype, accepted_dtypes, name,
require_same=require_same_dtypes,
allow_opaque_dtype=allow_opaque_dtype)
shape_rule = partial(broadcasting_shape_rule, name)
weak_type_rule = partial(_naryop_weak_type_rule, name)
Expand Down Expand Up @@ -1955,16 +1953,29 @@ def _abs_jvp_rule(g, ans, x):
lambda g, ans, x: mul(g, mul(_const(x, 1/3), integer_pow(ans, -2))))
mlir.register_lowering(cbrt_p, partial(_nary_lower_hlo, hlo.CbrtOp))

pow_p = standard_naryop([_float | _complex, _float | _complex], 'pow')
pow_p = standard_naryop([_float | _complex, _int | _float | _complex], 'pow')

def _pow_jvp_lhs(g, ans, x, y):
return mul(g, mul(y, pow(x, sub(y, _ones(y)))))
y_dtype = dtypes.dtype(y)
x, y = jax._src.numpy.util.promote_dtypes_numeric(x, y)
if dtypes.issubdtype(y_dtype, np.integer):
jac = select(eq(y, _const(y, 0)), _ones(y),
mul(_replace_zero(y), pow(x, sub(y, _ones(y)))))
else:
jac = mul(y, pow(x, sub(y, _ones(y))))
return mul(g, jac)

def _pow_jvp_rhs(g, ans, x, y):
# TODO cast result to y dtype?
return mul(g, mul(log(_replace_zero(x)), ans))

ad.defjvp2(pow_p, _pow_jvp_lhs, _pow_jvp_rhs)
mlir.register_lowering(pow_p, partial(_nary_lower_hlo, hlo.PowOp))

def _pow_lower(ctx, x, y):
(x,), (y,) = mlir.lower_fun(jax._src.numpy.util.promote_dtypes_numeric,
multiple_results=True)(ctx, x, y)
return _nary_lower_hlo(hlo.PowOp, ctx, x, y)
mlir.register_lowering(pow_p, _pow_lower)


def _integer_pow_dtype_rule(x, *, y):
Expand Down
4 changes: 2 additions & 2 deletions jax/_src/lax/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@

from functools import partial
import operator
from typing import Callable

from jax._src import core
from jax._src import dispatch
Expand All @@ -31,7 +30,8 @@

xops = xla_client.ops

_input_dtype: Callable = lambda *args, **_: dtypes.canonicalize_dtype(args[0].dtype, allow_opaque_dtype=True)
def _input_dtype(x, *_, **__):
return dtypes.canonicalize_dtype(x.dtype, allow_opaque_dtype=True)

def _argnum_weak_type(*argnums):
return lambda *args, **_: all(args[i].weak_type for i in argnums)
Expand Down
63 changes: 38 additions & 25 deletions jax/_src/numpy/ufuncs.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from jax._src.numpy.util import (
check_arraylike, promote_args, promote_args_inexact,
promote_args_numeric, promote_dtypes_inexact, promote_dtypes_numeric,
promote_shapes, _where, _wraps)
promote_shapes, _where, _wraps, check_no_float0s)

_lax_const = lax._const

Expand Down Expand Up @@ -305,16 +305,47 @@ def _float_divmod(x1: ArrayLike, x2: ArrayLike) -> tuple[Array, Array]:
return lax.round(div), mod


@_wraps(np.power, module='numpy')
def power(x1: ArrayLike, x2: ArrayLike, /) -> Array:
check_arraylike("power", x1, x2)
check_no_float0s("power", x1, x2)
x1, x2 = promote_shapes("power", x1, x2) # not dtypes, see next comment

# We apply special cases, both for algorithmic and autodiff reasons:
# 1. for concrete integer scalar powers (and arbitrary bases), we use
# unrolled binary exponentiation specialized on the exponent, which is
# more precise for e.g. x ** 2 when x is a float;
# 2. for integer bases and integer powers, use unrolled binary exponentiation
# where the number of steps is determined by a max bit width of 64;
# 3. for other cases, call lax.pow.

# Case 1: concrete integer scalar powers:
if isinstance(core.get_aval(x2), core.ConcreteArray):
try:
x2 = operator.index(x2) # type: ignore[arg-type]
except TypeError:
pass
else:
x1, = promote_dtypes_numeric(x1)
return lax.integer_pow(x1, x2)

# Handle cases #2 and #3 under a jit:
return _power(x1, x2)

@partial(jit, inline=True)
def _power(x1: ArrayLike, x2: ArrayLike) -> Array:
x1, x2 = promote_args_numeric("power", x1, x2)
dtype = dtypes.dtype(x1)
if not dtypes.issubdtype(dtype, np.integer):
return lax.pow(x1, x2)
# Case 2: integer base and integer power:
d1, d2 = dtypes.dtype(x1), dtypes.dtype(x2)
if dtypes.issubdtype(d1, np.integer) and dtypes.issubdtype(d2, np.integer):
assert np.iinfo(d1).bits <= 64 # _pow_int_int assumes 64bit max
return _pow_int_int(*promote_args_numeric("power", x1, x2))

# Integer power => use binary exponentiation.
# Case 3: call lax.pow (promotion handled in lowering rule)
return lax.pow(x1, x2)

# TODO(phawkins): add integer pow support to XLA.
# TODO(phawkins): add integer pow support to XLA.
def _pow_int_int(x1, x2):
# Integer power => use binary exponentiation.
bits = 6 # Anything more would overflow for any x1 > 1
zero = _constant_like(x2, 0)
one = _constant_like(x2, 1)
Expand All @@ -327,24 +358,6 @@ def _power(x1: ArrayLike, x2: ArrayLike) -> Array:
return acc


@_wraps(np.power, module='numpy')
def power(x1: ArrayLike, x2: ArrayLike, /) -> Array:
check_arraylike("power", x1, x2)
# Special case for concrete integer scalars: use binary exponentiation.
# Using lax.pow may be imprecise for floating-point values; the goal of this
# code path is to make sure we end up with a precise output for the common
# pattern ``x ** 2`` or similar.
if isinstance(core.get_aval(x2), core.ConcreteArray):
try:
x2 = operator.index(x2) # type: ignore[arg-type]
except TypeError:
pass
else:
x1, = promote_dtypes_numeric(x1)
return lax.integer_pow(x1, x2)
return _power(x1, x2)


@custom_jvp
@_wraps(np.logaddexp, module='numpy')
@jit
Expand Down
3 changes: 2 additions & 1 deletion jax/_src/numpy/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,7 @@ def check_arraylike_or_none(fun_name: str, *args: Any):
raise TypeError(msg.format(fun_name, type(arg), pos))


def _check_no_float0s(fun_name: str, *args: Any):
def check_no_float0s(fun_name: str, *args: Any):
"""Check if none of the args have dtype float0."""
if any(dtypes.dtype(arg) == dtypes.float0 for arg in args):
raise TypeError(
Expand All @@ -347,6 +347,7 @@ def _check_no_float0s(fun_name: str, *args: Any):
"to cast a float0 array to a regular zeros array. \n"
"If you didn't expect to get a float0 you might have accidentally "
"taken a gradient with respect to an integer argument.")
_check_no_float0s = check_no_float0s


def promote_args(fun_name: str, *args: ArrayLike) -> list[Array]:
Expand Down
5 changes: 5 additions & 0 deletions tests/lax_autodiff_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -533,6 +533,11 @@ def testPowSecondDerivative(self):
# self.assertEqual(result, 0.0)
self.assertAllClose(result, np.nan)

def testPowIntPowerAtZero(self):
# https://github.com/google/jax/issues/14397
ans = jax.grad(jax.jit(lambda x, n: x ** n))(0., 0)
self.assertAllClose(ans, 1., check_dtypes=False)

@jtu.sample_product(
[dict(arg_shape=arg_shape, pred_shape=pred_shape)
for arg_shape in [(), (3,), (2, 3)]
Expand Down

0 comments on commit 1ed8c06

Please sign in to comment.