diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index cda0ff2bd202..78d52cef4789 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -57,13 +57,9 @@ from jax._src.interpreters.batching import RaggedAxis from jax._src.lax import slicing from jax._src.lax.utils import ( - _input_dtype, - dtype_to_string, - standard_abstract_eval, - standard_multi_result_abstract_eval, - standard_named_shape_rule, - standard_primitive, -) + _input_dtype, dtype_to_string, standard_abstract_eval, + standard_multi_result_abstract_eval, standard_named_shape_rule, + standard_primitive) from jax._src import xla_bridge from jax._src.lib import xla_client from jax._src.lib.mlir import ir @@ -1515,7 +1511,8 @@ def zeros_like_array(x: ArrayLike) -> Array: ### primitives -_fixed_dtype = lambda dtype: lambda *args, **kwargs: dtypes.canonicalize_dtype(dtype) +_fixed_dtype = \ + lambda dtype: lambda *args, **kwargs: dtypes.canonicalize_dtype(dtype) _complex_basetype = lambda dtype: np.abs(np.zeros((), dtype)).dtype _strip_weak_type = lambda *args, **_: False @@ -1543,7 +1540,7 @@ def unop(result_dtype, accepted_dtypes, name): def naryop_dtype_rule(result_dtype, accepted_dtypes, name, *avals, - allow_extended_dtype=False, **kwargs): + require_same=True, allow_extended_dtype=False, **kwargs): del kwargs assert len(avals) == len(accepted_dtypes), (avals, accepted_dtypes) for i, aval in enumerate(avals): @@ -1566,7 +1563,7 @@ def naryop_dtype_rule(result_dtype, accepted_dtypes, name, *avals, typename = dtype_to_string(aval.dtype) typenames = ', '.join(t.__name__ for t in types) raise TypeError(msg.format(name, typename, i, i, typenames)) - check_same_dtypes(name, *avals) + if require_same: check_same_dtypes(name, *avals) return result_dtype(*avals) @@ -1609,9 +1606,11 @@ def _naryop_weak_type_rule(name, *avals, **kwargs): "taken a gradient with respect to an integer argument.") return all(aval.weak_type for aval in avals) -def naryop(result_dtype, accepted_dtypes, name, allow_extended_dtype=False): +def naryop(result_dtype, accepted_dtypes, name, allow_extended_dtype=False, + require_same_dtypes=False): dtype_rule = partial(naryop_dtype_rule, result_dtype, accepted_dtypes, name, - allow_extended_dtype=allow_extended_dtype) + allow_extended_dtype=allow_extended_dtype, + require_same=require_same_dtypes) shape_rule = partial(broadcasting_shape_rule, name) weak_type_rule = partial(_naryop_weak_type_rule, name) prim = standard_primitive(shape_rule, dtype_rule, name, @@ -1973,16 +1972,48 @@ def _abs_jvp_rule(g, ans, x): lambda g, ans, x: mul(g, mul(_const(x, 1/3), integer_pow(ans, -2)))) mlir.register_lowering(cbrt_p, partial(_nary_lower_hlo, hlo.CbrtOp)) -pow_p = standard_naryop([_float | _complex, _float | _complex], 'pow') +def _pow_dtype_rule(x, y): + if (dtypes.issubdtype(x.dtype, np.inexact) and + dtypes.issubdtype(y.dtype, np.integer)): + return x.dtype + if x.dtype == y.dtype: + return x.dtype + raise TypeError("the first argument to pow must have an inexact dtype (float " + "or complex), and the second argument must have an inexact or" + " integer dtype, and two inexact dtypes must match, but got " + f"{x.dtype} and {y.dtype} respectively.") +pow_p = naryop(_pow_dtype_rule, [_float | _complex, _int | _float | _complex], + 'pow', require_same_dtypes=False) def _pow_jvp_lhs(g, ans, x, y): - return mul(g, mul(y, pow(x, sub(y, _ones(y))))) + y_dtype = dtypes.dtype(y) + x, y = jax._src.numpy.util.promote_dtypes_numeric(x, y) # TODO replace this + if dtypes.issubdtype(y_dtype, np.integer): + jac = select(eq(y, _const(y, 0)), _ones(y), + mul(_replace_zero(y), pow(x, sub(y, _ones(y))))) + else: + jac = mul(y, pow(x, sub(y, _ones(y)))) + return mul(g, jac) def _pow_jvp_rhs(g, ans, x, y): - return mul(g, mul(log(_replace_zero(x)), ans)) - + y_dtype = dtypes.dtype(y) + assert dtypes.issubdtype(y_dtype, np.inexact) + return convert_element_type(mul(g, mul(log(_replace_zero(x)), ans)), y_dtype) ad.defjvp2(pow_p, _pow_jvp_lhs, _pow_jvp_rhs) -mlir.register_lowering(pow_p, partial(_nary_lower_hlo, hlo.PowOp)) + +def _pow_lower(ctx, x, y): + x_aval, y_aval = ctx.avals_in + out_aval, = ctx.avals_out + convert = mlir.lower_fun( + partial(convert_element_type, new_dtype=out_aval.dtype), False) + x_aval_ = x_aval.update(dtype=out_aval.dtype) + y_aval_ = y_aval.update(dtype=out_aval.dtype) + [(x_,)] = convert(ctx.replace(avals_in=[x_aval], avals_out=[x_aval_]), x) + [(y_,)] = convert(ctx.replace(avals_in=[y_aval], avals_out=[y_aval_]), y) + ctx_ = ctx.replace(avals_in=[x_aval_, y_aval_]) + return _nary_lower_hlo(hlo.PowOp, ctx_, x_, y_) +mlir.register_lowering(pow_p, _pow_lower) + def _integer_pow_dtype_rule(x, *, y): diff --git a/jax/_src/lax/utils.py b/jax/_src/lax/utils.py index 4a7c9ea81f1b..94c1e7d6bed5 100644 --- a/jax/_src/lax/utils.py +++ b/jax/_src/lax/utils.py @@ -18,7 +18,6 @@ from functools import partial import operator -from typing import Callable from jax._src import core from jax._src import dispatch @@ -30,7 +29,8 @@ xops = xla_client.ops -_input_dtype: Callable = lambda *args, **_: dtypes.canonicalize_dtype(args[0].dtype, allow_extended_dtype=True) +def _input_dtype(x, *_, **__): + return dtypes.canonicalize_dtype(x.dtype, allow_extended_dtype=True) def _argnum_weak_type(*argnums): return lambda *args, **_: all(args[i].weak_type for i in argnums) diff --git a/jax/_src/numpy/ufuncs.py b/jax/_src/numpy/ufuncs.py index 699bae1dfe8c..b0dcd8a0db48 100644 --- a/jax/_src/numpy/ufuncs.py +++ b/jax/_src/numpy/ufuncs.py @@ -32,7 +32,7 @@ from jax._src.numpy.util import ( check_arraylike, promote_args, promote_args_inexact, promote_args_numeric, promote_dtypes_inexact, promote_dtypes_numeric, - promote_shapes, _where, _wraps) + promote_shapes, _where, _wraps, check_no_float0s) _lax_const = lax._const @@ -305,16 +305,60 @@ def _float_divmod(x1: ArrayLike, x2: ArrayLike) -> tuple[Array, Array]: return lax.round(div), mod +@_wraps(np.power, module='numpy') +def power(x1: ArrayLike, x2: ArrayLike, /) -> Array: + check_arraylike("power", x1, x2) + check_no_float0s("power", x1, x2) + + # We apply special cases, both for algorithmic and autodiff reasons: + # 1. for *concrete* integer scalar powers (and arbitrary bases), we use + # unrolled binary exponentiation specialized on the exponent, which is + # more precise for e.g. x ** 2 when x is a float (algorithmic reason!); + # 2. for integer bases and integer powers, use unrolled binary exponentiation + # where the number of steps is determined by a max bit width of 64 + # (algorithmic reason!); + # 3. for integer powers and float/complex bases, we apply the lax primitive + # without any promotion of input types because in this case we want the + # function to be differentiable wrt its first argument at 0; + # 3. for other cases, perform jnp dtype promotion on the arguments then apply + # lax.pow. + + # Case 1: concrete integer scalar powers: + if isinstance(core.get_aval(x2), core.ConcreteArray): + try: + x2 = operator.index(x2) # type: ignore[arg-type] + except TypeError: + pass + else: + x1, = promote_dtypes_numeric(x1) + return lax.integer_pow(x1, x2) + + # Handle cases #2 and #3 under a jit: + return _power(x1, x2) + @partial(jit, inline=True) def _power(x1: ArrayLike, x2: ArrayLike) -> Array: - x1, x2 = promote_args_numeric("power", x1, x2) - dtype = dtypes.dtype(x1) - if not dtypes.issubdtype(dtype, np.integer): + x1, x2 = promote_shapes("power", x1, x2) # not dtypes + + # Case 2: bool/integer result + x1_, x2_ = promote_args_numeric("power", x1, x2) + if (dtypes.issubdtype(dtypes.dtype(x1_), np.integer) or + dtypes.issubdtype(dtypes.dtype(x1_), np.bool_)): + assert np.iinfo(dtypes.dtype(x1_)).bits <= 64 # _pow_int_int assumes <=64bit + return _pow_int_int(x1_, x2_) + + # Case 3: float/complex base with integer power (special autodiff behavior) + d1, d2 = dtypes.dtype(x1), dtypes.dtype(x2) + if dtypes.issubdtype(d1, np.inexact) and dtypes.issubdtype(d2, np.integer): return lax.pow(x1, x2) - # Integer power => use binary exponentiation. - # TODO(phawkins): add integer pow support to XLA. + # Case 4: do promotion first + return lax.pow(x1_, x2_) + +# TODO(phawkins): add integer pow support to XLA. +def _pow_int_int(x1, x2): + # Integer power => use binary exponentiation. bits = 6 # Anything more would overflow for any x1 > 1 zero = _constant_like(x2, 0) one = _constant_like(x2, 1) @@ -327,24 +371,6 @@ def _power(x1: ArrayLike, x2: ArrayLike) -> Array: return acc -@_wraps(np.power, module='numpy') -def power(x1: ArrayLike, x2: ArrayLike, /) -> Array: - check_arraylike("power", x1, x2) - # Special case for concrete integer scalars: use binary exponentiation. - # Using lax.pow may be imprecise for floating-point values; the goal of this - # code path is to make sure we end up with a precise output for the common - # pattern ``x ** 2`` or similar. - if isinstance(core.get_aval(x2), core.ConcreteArray): - try: - x2 = operator.index(x2) # type: ignore[arg-type] - except TypeError: - pass - else: - x1, = promote_dtypes_numeric(x1) - return lax.integer_pow(x1, x2) - return _power(x1, x2) - - @custom_jvp @_wraps(np.logaddexp, module='numpy') @jit diff --git a/jax/_src/numpy/util.py b/jax/_src/numpy/util.py index 10625b037c25..9946b942cb12 100644 --- a/jax/_src/numpy/util.py +++ b/jax/_src/numpy/util.py @@ -337,7 +337,7 @@ def check_arraylike_or_none(fun_name: str, *args: Any): raise TypeError(msg.format(fun_name, type(arg), pos)) -def _check_no_float0s(fun_name: str, *args: Any): +def check_no_float0s(fun_name: str, *args: Any): """Check if none of the args have dtype float0.""" if any(dtypes.dtype(arg) == dtypes.float0 for arg in args): raise TypeError( @@ -348,6 +348,7 @@ def _check_no_float0s(fun_name: str, *args: Any): "to cast a float0 array to a regular zeros array. \n" "If you didn't expect to get a float0 you might have accidentally " "taken a gradient with respect to an integer argument.") +_check_no_float0s = check_no_float0s def promote_args(fun_name: str, *args: ArrayLike) -> list[Array]: diff --git a/tests/dynamic_api_test.py b/tests/dynamic_api_test.py index 4e99d612a290..bd1495c1c92a 100644 --- a/tests/dynamic_api_test.py +++ b/tests/dynamic_api_test.py @@ -1519,10 +1519,10 @@ def test_jumble_map_eltwise(self): ins = lax.convert_element_type(jnp.array([3, 1, 4]), core.bint(5)) p = jax.vmap(partial(jnp.arange, dtype='int32'), out_axes=batching.jumble_axis)(ins) - p = jumble_map(jax.jit(lambda x: x ** 2))(p) + p = jumble_map(jax.jit(lambda x: x * 3))(p) self.assertIsInstance(p, batching.Jumble) self.assertRegex(str(p.aval), r'Var[0-9]+:3 => i32\[bint\{≤5\}\[3\] with value: \[3 1 4\]\.Var[0-9]+\]') - data = jax.lax.broadcasted_iota('int32', (3, 5), 1) ** 2 + data = jax.lax.broadcasted_iota('int32', (3, 5), 1) * 3 self.assertAllClose(p.data, data, check_dtypes=False) def test_jumble_map_vector_dot(self): diff --git a/tests/lax_autodiff_test.py b/tests/lax_autodiff_test.py index feaae69bc6c3..faa22afb1f3a 100644 --- a/tests/lax_autodiff_test.py +++ b/tests/lax_autodiff_test.py @@ -538,6 +538,11 @@ def testPowSecondDerivative(self): # self.assertEqual(result, 0.0) self.assertAllClose(result, np.nan) + def testPowIntPowerAtZero(self): + # https://github.com/google/jax/issues/14397 + ans = jax.grad(jax.jit(lambda x, n: x ** n))(0., 0) + self.assertAllClose(ans, 1., check_dtypes=False) + @jtu.sample_product( [dict(arg_shape=arg_shape, pred_shape=pred_shape) for arg_shape in [(), (3,), (2, 3)]