From 69ad4df9a53ee6bc8d7e7f9cf589562cbffc2bd6 Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Wed, 14 Jun 2023 18:30:52 -0700 Subject: [PATCH 1/2] fix pow_p jvp rule at x=0. y=0 fixes #14397 For autodiff purposes (and eventually for evaluation implementation purposes) we need to distinguish between pow :: inexact -> int -> inexact (which is differentiable at (0., 0)) and pow :: inexact -> inexact -> inexact (which isn't); see https://github.com/google/jax/issues/14397#issuecomment-1426386290. Instead of making a new primitive, we made the old one polymorphic and switch its behavior on the element type of its second argument. There were also some other cases with special handling for algorithmic reasons (e.g. doing binary exponentiation), so these autodiff cases had to be merged with those algorithmic cases. Co-authored-by: Roy Frostig --- jax/_src/lax/lax.py | 61 ++++++++++++++++++++++--------- jax/_src/lax/utils.py | 4 +-- jax/_src/numpy/ufuncs.py | 74 +++++++++++++++++++++++++------------- jax/_src/numpy/util.py | 3 +- tests/dynamic_api_test.py | 4 +-- tests/lax_autodiff_test.py | 5 +++ 6 files changed, 105 insertions(+), 46 deletions(-) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 5e9da108cae5..8984539c3fbd 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -55,13 +55,9 @@ from jax._src.interpreters.batching import RaggedAxis from jax._src.lax import slicing from jax._src.lax.utils import ( - _input_dtype, - dtype_to_string, - standard_abstract_eval, - standard_multi_result_abstract_eval, - standard_named_shape_rule, - standard_primitive, -) + _input_dtype, dtype_to_string, standard_abstract_eval, + standard_multi_result_abstract_eval, standard_named_shape_rule, + standard_primitive) from jax._src import xla_bridge from jax._src.lib import xla_client from jax._src.lib.mlir import ir @@ -1508,7 +1504,8 @@ def zeros_like_array(x: ArrayLike) -> Array: ### primitives -_fixed_dtype = lambda dtype: lambda *args, **kwargs: dtypes.canonicalize_dtype(dtype) +_fixed_dtype = \ + lambda dtype: lambda *args, **kwargs: dtypes.canonicalize_dtype(dtype) _complex_basetype = lambda dtype: np.abs(np.zeros((), dtype)).dtype _strip_weak_type = lambda *args, **_: False @@ -1536,7 +1533,7 @@ def unop(result_dtype, accepted_dtypes, name): def naryop_dtype_rule(result_dtype, accepted_dtypes, name, *avals, - allow_extended_dtype=False, **kwargs): + require_same=True, allow_extended_dtype=False, **kwargs): del kwargs assert len(avals) == len(accepted_dtypes), (avals, accepted_dtypes) for i, aval in enumerate(avals): @@ -1559,7 +1556,7 @@ def naryop_dtype_rule(result_dtype, accepted_dtypes, name, *avals, typename = dtype_to_string(aval.dtype) typenames = ', '.join(t.__name__ for t in types) raise TypeError(msg.format(name, typename, i, i, typenames)) - check_same_dtypes(name, *avals) + if require_same: check_same_dtypes(name, *avals) return result_dtype(*avals) @@ -1602,9 +1599,11 @@ def _naryop_weak_type_rule(name, *avals, **kwargs): "taken a gradient with respect to an integer argument.") return all(aval.weak_type for aval in avals) -def naryop(result_dtype, accepted_dtypes, name, allow_extended_dtype=False): +def naryop(result_dtype, accepted_dtypes, name, allow_extended_dtype=False, + require_same_dtypes=False): dtype_rule = partial(naryop_dtype_rule, result_dtype, accepted_dtypes, name, - allow_extended_dtype=allow_extended_dtype) + allow_extended_dtype=allow_extended_dtype, + require_same=require_same_dtypes) shape_rule = partial(broadcasting_shape_rule, name) weak_type_rule = partial(_naryop_weak_type_rule, name) prim = standard_primitive(shape_rule, dtype_rule, name, @@ -1959,16 +1958,44 @@ def _abs_jvp_rule(g, ans, x): lambda g, ans, x: mul(g, mul(_const(x, 1/3), integer_pow(ans, -2)))) mlir.register_lowering(cbrt_p, partial(_nary_lower_hlo, hlo.CbrtOp)) -pow_p = standard_naryop([_float | _complex, _float | _complex], 'pow') +def _pow_dtype_rule(x, y): + if (dtypes.issubdtype(x.dtype, np.inexact) and + dtypes.issubdtype(y.dtype, np.integer)): + return x.dtype + if x.dtype == y.dtype: + return x.dtype + raise TypeError("the first argument to pow must have an inexact dtype (float " + "or complex), and the second argument must have an inexact or" + " integer dtype, and two inexact dtypes must match, but got " + f"{x.dtype} and {y.dtype} respectively.") +pow_p = naryop(_pow_dtype_rule, [_float | _complex, _int | _float | _complex], + 'pow', require_same_dtypes=False) def _pow_jvp_lhs(g, ans, x, y): - return mul(g, mul(y, pow(x, sub(y, _ones(y))))) + y_dtype = dtypes.dtype(y) + x, y = jax._src.numpy.util.promote_dtypes_numeric(x, y) # TODO replace this + if dtypes.issubdtype(y_dtype, np.integer): + jac = select(eq(y, _const(y, 0)), _ones(y), + mul(_replace_zero(y), pow(x, sub(y, _ones(y))))) + else: + jac = mul(y, pow(x, sub(y, _ones(y)))) + return mul(g, jac) def _pow_jvp_rhs(g, ans, x, y): - return mul(g, mul(log(_replace_zero(x)), ans)) - + y_dtype = dtypes.dtype(y) + assert dtypes.issubdtype(y_dtype, np.inexact) + return convert_element_type(mul(g, mul(log(_replace_zero(x)), ans)), y_dtype) ad.defjvp2(pow_p, _pow_jvp_lhs, _pow_jvp_rhs) -mlir.register_lowering(pow_p, partial(_nary_lower_hlo, hlo.PowOp)) + +def _pow_lower(ctx, x, y): + x_aval, y_aval = ctx.avals_in + out_aval, = ctx.avals_out + dt = mlir.dtype_to_ir_type(out_aval.dtype) + x = hlo.ConvertOp(ir.RankedTensorType.get(x_aval.shape, dt), x).result + y = hlo.ConvertOp(ir.RankedTensorType.get(y_aval.shape, dt), y).result + ctx_ = ctx.replace(avals_in=[x_aval, y_aval.update(dtype=out_aval.dtype)]) + return _nary_lower_hlo(hlo.PowOp, ctx_, x, y) +mlir.register_lowering(pow_p, _pow_lower) def _integer_pow_dtype_rule(x, *, y): diff --git a/jax/_src/lax/utils.py b/jax/_src/lax/utils.py index d92379eaac44..9f7663e3d266 100644 --- a/jax/_src/lax/utils.py +++ b/jax/_src/lax/utils.py @@ -18,7 +18,6 @@ from functools import partial import operator -from typing import Callable from jax._src import core from jax._src import dispatch @@ -31,7 +30,8 @@ xops = xla_client.ops -_input_dtype: Callable = lambda *args, **_: dtypes.canonicalize_dtype(args[0].dtype, allow_extended_dtype=True) +def _input_dtype(x, *_, **__): + return dtypes.canonicalize_dtype(x.dtype, allow_extended_dtype=True) def _argnum_weak_type(*argnums): return lambda *args, **_: all(args[i].weak_type for i in argnums) diff --git a/jax/_src/numpy/ufuncs.py b/jax/_src/numpy/ufuncs.py index 65d09ebc7e26..1f37568e2a06 100644 --- a/jax/_src/numpy/ufuncs.py +++ b/jax/_src/numpy/ufuncs.py @@ -32,7 +32,7 @@ from jax._src.numpy.util import ( check_arraylike, promote_args, promote_args_inexact, promote_args_numeric, promote_dtypes_inexact, promote_dtypes_numeric, - promote_shapes, _where, _wraps) + promote_shapes, _where, _wraps, check_no_float0s) _lax_const = lax._const @@ -305,16 +305,60 @@ def _float_divmod(x1: ArrayLike, x2: ArrayLike) -> tuple[Array, Array]: return lax.round(div), mod +@_wraps(np.power, module='numpy') +def power(x1: ArrayLike, x2: ArrayLike, /) -> Array: + check_arraylike("power", x1, x2) + check_no_float0s("power", x1, x2) + + # We apply special cases, both for algorithmic and autodiff reasons: + # 1. for *concrete* integer scalar powers (and arbitrary bases), we use + # unrolled binary exponentiation specialized on the exponent, which is + # more precise for e.g. x ** 2 when x is a float (algorithmic reason!); + # 2. for integer bases and integer powers, use unrolled binary exponentiation + # where the number of steps is determined by a max bit width of 64 + # (algorithmic reason!); + # 3. for integer powers and float/complex bases, we apply the lax primitive + # without any promotion of input types because in this case we want the + # function to be differentiable wrt its first argument at 0; + # 3. for other cases, perform jnp dtype promotion on the arguments then apply + # lax.pow. + + # Case 1: concrete integer scalar powers: + if isinstance(core.get_aval(x2), core.ConcreteArray): + try: + x2 = operator.index(x2) # type: ignore[arg-type] + except TypeError: + pass + else: + x1, = promote_dtypes_numeric(x1) + return lax.integer_pow(x1, x2) + + # Handle cases #2 and #3 under a jit: + return _power(x1, x2) + @partial(jit, inline=True) def _power(x1: ArrayLike, x2: ArrayLike) -> Array: - x1, x2 = promote_args_numeric("power", x1, x2) - dtype = dtypes.dtype(x1) - if not dtypes.issubdtype(dtype, np.integer): + x1, x2 = promote_shapes("power", x1, x2) # not dtypes + + # Case 2: bool/integer result + x1_, x2_ = promote_args_numeric("power", x1, x2) + if (dtypes.issubdtype(dtypes.dtype(x1_), np.integer) or + dtypes.issubdtype(dtypes.dtype(x1_), np.bool_)): + assert np.iinfo(dtypes.dtype(x1_)).bits <= 64 # _pow_int_int assumes <=64bit + return _pow_int_int(x1_, x2_) + + # Case 3: float/complex base with integer power (special autodiff behavior) + d1, d2 = dtypes.dtype(x1), dtypes.dtype(x2) + if dtypes.issubdtype(d1, np.inexact) and dtypes.issubdtype(d2, np.integer): return lax.pow(x1, x2) - # Integer power => use binary exponentiation. - # TODO(phawkins): add integer pow support to XLA. + # Case 4: do promotion first + return lax.pow(x1_, x2_) + +# TODO(phawkins): add integer pow support to XLA. +def _pow_int_int(x1, x2): + # Integer power => use binary exponentiation. bits = 6 # Anything more would overflow for any x1 > 1 zero = _constant_like(x2, 0) one = _constant_like(x2, 1) @@ -327,24 +371,6 @@ def _power(x1: ArrayLike, x2: ArrayLike) -> Array: return acc -@_wraps(np.power, module='numpy') -def power(x1: ArrayLike, x2: ArrayLike, /) -> Array: - check_arraylike("power", x1, x2) - # Special case for concrete integer scalars: use binary exponentiation. - # Using lax.pow may be imprecise for floating-point values; the goal of this - # code path is to make sure we end up with a precise output for the common - # pattern ``x ** 2`` or similar. - if isinstance(core.get_aval(x2), core.ConcreteArray): - try: - x2 = operator.index(x2) # type: ignore[arg-type] - except TypeError: - pass - else: - x1, = promote_dtypes_numeric(x1) - return lax.integer_pow(x1, x2) - return _power(x1, x2) - - @custom_jvp @_wraps(np.logaddexp, module='numpy') @jit diff --git a/jax/_src/numpy/util.py b/jax/_src/numpy/util.py index 10625b037c25..9946b942cb12 100644 --- a/jax/_src/numpy/util.py +++ b/jax/_src/numpy/util.py @@ -337,7 +337,7 @@ def check_arraylike_or_none(fun_name: str, *args: Any): raise TypeError(msg.format(fun_name, type(arg), pos)) -def _check_no_float0s(fun_name: str, *args: Any): +def check_no_float0s(fun_name: str, *args: Any): """Check if none of the args have dtype float0.""" if any(dtypes.dtype(arg) == dtypes.float0 for arg in args): raise TypeError( @@ -348,6 +348,7 @@ def _check_no_float0s(fun_name: str, *args: Any): "to cast a float0 array to a regular zeros array. \n" "If you didn't expect to get a float0 you might have accidentally " "taken a gradient with respect to an integer argument.") +_check_no_float0s = check_no_float0s def promote_args(fun_name: str, *args: ArrayLike) -> list[Array]: diff --git a/tests/dynamic_api_test.py b/tests/dynamic_api_test.py index 4e99d612a290..bd1495c1c92a 100644 --- a/tests/dynamic_api_test.py +++ b/tests/dynamic_api_test.py @@ -1519,10 +1519,10 @@ def test_jumble_map_eltwise(self): ins = lax.convert_element_type(jnp.array([3, 1, 4]), core.bint(5)) p = jax.vmap(partial(jnp.arange, dtype='int32'), out_axes=batching.jumble_axis)(ins) - p = jumble_map(jax.jit(lambda x: x ** 2))(p) + p = jumble_map(jax.jit(lambda x: x * 3))(p) self.assertIsInstance(p, batching.Jumble) self.assertRegex(str(p.aval), r'Var[0-9]+:3 => i32\[bint\{≤5\}\[3\] with value: \[3 1 4\]\.Var[0-9]+\]') - data = jax.lax.broadcasted_iota('int32', (3, 5), 1) ** 2 + data = jax.lax.broadcasted_iota('int32', (3, 5), 1) * 3 self.assertAllClose(p.data, data, check_dtypes=False) def test_jumble_map_vector_dot(self): diff --git a/tests/lax_autodiff_test.py b/tests/lax_autodiff_test.py index 3ece195b0535..bdbef80830ae 100644 --- a/tests/lax_autodiff_test.py +++ b/tests/lax_autodiff_test.py @@ -533,6 +533,11 @@ def testPowSecondDerivative(self): # self.assertEqual(result, 0.0) self.assertAllClose(result, np.nan) + def testPowIntPowerAtZero(self): + # https://github.com/google/jax/issues/14397 + ans = jax.grad(jax.jit(lambda x, n: x ** n))(0., 0) + self.assertAllClose(ans, 1., check_dtypes=False) + @jtu.sample_product( [dict(arg_shape=arg_shape, pred_shape=pred_shape) for arg_shape in [(), (3,), (2, 3)] From 1f8fb2c8bd62a47b10682a1ff94b0927957f6efa Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Tue, 22 Aug 2023 16:48:11 -0700 Subject: [PATCH 2/2] change lowering rule to satisfy jax2tf --- jax/_src/lax/lax.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 8984539c3fbd..bdc51425d815 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -1990,14 +1990,18 @@ def _pow_jvp_rhs(g, ans, x, y): def _pow_lower(ctx, x, y): x_aval, y_aval = ctx.avals_in out_aval, = ctx.avals_out - dt = mlir.dtype_to_ir_type(out_aval.dtype) - x = hlo.ConvertOp(ir.RankedTensorType.get(x_aval.shape, dt), x).result - y = hlo.ConvertOp(ir.RankedTensorType.get(y_aval.shape, dt), y).result - ctx_ = ctx.replace(avals_in=[x_aval, y_aval.update(dtype=out_aval.dtype)]) - return _nary_lower_hlo(hlo.PowOp, ctx_, x, y) + convert = mlir.lower_fun( + partial(convert_element_type, new_dtype=out_aval.dtype), False) + x_aval_ = x_aval.update(dtype=out_aval.dtype) + y_aval_ = y_aval.update(dtype=out_aval.dtype) + [(x_,)] = convert(ctx.replace(avals_in=[x_aval], avals_out=[x_aval_]), x) + [(y_,)] = convert(ctx.replace(avals_in=[y_aval], avals_out=[y_aval_]), y) + ctx_ = ctx.replace(avals_in=[x_aval_, y_aval_]) + return _nary_lower_hlo(hlo.PowOp, ctx_, x_, y_) mlir.register_lowering(pow_p, _pow_lower) + def _integer_pow_dtype_rule(x, *, y): dtype = unop_dtype_rule(_identity, _int | _float | _complex, 'integer_pow', x) if y < 0 and dtypes.issubdtype(dtype, np.integer):