fix pow_p jvp rule at x=0., y=0
fixes #14397

For autodiff purposes (and eventually for evaluation implementation purposes)
we need to distinguish between pow :: inexact -> int -> inexact (which is
differentiable at (0., 0)) and pow :: inexact -> inexact -> inexact (which
isn't); see #14397 (comment).

Instead of making a new primitive, we made the old one polymorphic: it now
switches its behavior on the element type of its second argument.

There were also other cases with special handling for algorithmic reasons
(e.g. binary exponentiation), so these autodiff cases had to be merged with
those algorithmic cases.
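
As a minimal sketch of the new behavior (grounded in the test added in
tests/lax_autodiff_test.py below, which asserts the 1.0 value), differentiating
an integer power at a zero base no longer produces NaN when the exponent traces
as an integer:

  import jax

  f = jax.jit(lambda x, n: x ** n)  # n traces as an integer, so the int-exponent rule applies
  jax.grad(f)(0., 0)                # 1.0 after this change (previously NaN; see #14397)
  jax.grad(f)(0., 2)                # 0.0, i.e. d/dx x**2 evaluated at x = 0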

Co-authored-by: Roy Frostig <frostig@google.com>
mattjj and froystig committed Jul 28, 2023
1 parent 9a21ff0 commit 77f1610
Showing 6 changed files with 101 additions and 44 deletions.
55 changes: 40 additions & 15 deletions jax/_src/lax/lax.py
@@ -55,13 +55,9 @@
from jax._src.interpreters.batching import RaggedAxis
from jax._src.lax import slicing
from jax._src.lax.utils import (
_input_dtype,
dtype_to_string,
standard_abstract_eval,
standard_multi_result_abstract_eval,
standard_named_shape_rule,
standard_primitive,
)
_input_dtype, dtype_to_string, standard_abstract_eval,
standard_multi_result_abstract_eval, standard_named_shape_rule,
standard_primitive)
from jax._src import xla_bridge
from jax._src.lib import xla_client
from jax._src.lib.mlir import ir
@@ -1508,7 +1504,8 @@ def zeros_like_array(x: ArrayLike) -> Array:
### primitives


_fixed_dtype = lambda dtype: lambda *args, **kwargs: dtypes.canonicalize_dtype(dtype)
_fixed_dtype = \
lambda dtype: lambda *args, **kwargs: dtypes.canonicalize_dtype(dtype)
_complex_basetype = lambda dtype: np.abs(np.zeros((), dtype)).dtype

_strip_weak_type = lambda *args, **_: False
@@ -1536,7 +1533,7 @@ def unop(result_dtype, accepted_dtypes, name):


def naryop_dtype_rule(result_dtype, accepted_dtypes, name, *avals,
allow_extended_dtype=False, **kwargs):
require_same=True, allow_extended_dtype=False, **kwargs):
del kwargs
assert len(avals) == len(accepted_dtypes), (avals, accepted_dtypes)
for i, aval in enumerate(avals):
@@ -1559,7 +1556,7 @@ def naryop_dtype_rule(result_dtype, accepted_dtypes, name, *avals,
typename = dtype_to_string(aval.dtype)
typenames = ', '.join(t.__name__ for t in types)
raise TypeError(msg.format(name, typename, i, i, typenames))
check_same_dtypes(name, *avals)
if require_same: check_same_dtypes(name, *avals)
return result_dtype(*avals)


@@ -1602,9 +1599,11 @@ def _naryop_weak_type_rule(name, *avals, **kwargs):
"taken a gradient with respect to an integer argument.")
return all(aval.weak_type for aval in avals)

def naryop(result_dtype, accepted_dtypes, name, allow_extended_dtype=False):
def naryop(result_dtype, accepted_dtypes, name, allow_extended_dtype=False,
require_same_dtypes=False):
dtype_rule = partial(naryop_dtype_rule, result_dtype, accepted_dtypes, name,
allow_extended_dtype=allow_extended_dtype)
allow_extended_dtype=allow_extended_dtype,
require_same=require_same_dtypes)
shape_rule = partial(broadcasting_shape_rule, name)
weak_type_rule = partial(_naryop_weak_type_rule, name)
prim = standard_primitive(shape_rule, dtype_rule, name,
@@ -1959,16 +1958,42 @@ def _abs_jvp_rule(g, ans, x):
lambda g, ans, x: mul(g, mul(_const(x, 1/3), integer_pow(ans, -2))))
mlir.register_lowering(cbrt_p, partial(_nary_lower_hlo, hlo.CbrtOp))

pow_p = standard_naryop([_float | _complex, _float | _complex], 'pow')
def _pow_dtype_rule(x, y):
if (dtypes.issubdtype(x.dtype, np.inexact) and
(dtypes.issubdtype(y.dtype, np.integer) or
dtypes.issubdtype(y.dtype, np.inexact))):
return x.dtype
raise TypeError("the first argument to pow must have an inexact dtype (float "
"or complex), and the second argument must have an inexact or"
f" integer dtype, but got {x.dtype} and {y.dtype}.")
pow_p = naryop(_pow_dtype_rule, [_float | _complex, _int | _float | _complex],
'pow', require_same_dtypes=False)

def _pow_jvp_lhs(g, ans, x, y):
return mul(g, mul(y, pow(x, sub(y, _ones(y)))))
y_dtype = dtypes.dtype(y)
x, y = jax._src.numpy.util.promote_dtypes_numeric(x, y) # TODO replace this
if dtypes.issubdtype(y_dtype, np.integer):
jac = select(eq(y, _const(y, 0)), _ones(y),
mul(_replace_zero(y), pow(x, sub(y, _ones(y)))))
else:
jac = mul(y, pow(x, sub(y, _ones(y))))
return mul(g, jac)

def _pow_jvp_rhs(g, ans, x, y):
# TODO cast result to y dtype?
return mul(g, mul(log(_replace_zero(x)), ans))

ad.defjvp2(pow_p, _pow_jvp_lhs, _pow_jvp_rhs)
mlir.register_lowering(pow_p, partial(_nary_lower_hlo, hlo.PowOp))

def _pow_lower(ctx, x, y):
x_aval, y_aval = ctx.avals_in
out_aval, = ctx.avals_out
dt = mlir.dtype_to_ir_type(out_aval.dtype)
x = hlo.ConvertOp(ir.RankedTensorType.get(x_aval.shape, dt), x).result
y = hlo.ConvertOp(ir.RankedTensorType.get(y_aval.shape, dt), y).result
ctx_ = ctx.replace(avals_in=[x_aval, y_aval.update(dtype=out_aval.dtype)])
return _nary_lower_hlo(hlo.PowOp, ctx_, x, y)
mlir.register_lowering(pow_p, _pow_lower)


def _integer_pow_dtype_rule(x, *, y):
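
Illustrative only (not part of the diff): with the dtype rule, JVP rules, and
lowering above, lax.pow now accepts an integer exponent and is differentiable
in its first argument at (0., 0), while the float-exponent behavior is
unchanged:

  import jax
  from jax import lax

  lax.pow(2.0, 3)              # integer exponent now type-checks; result dtype follows the base
  jax.grad(lax.pow)(0.0, 0)    # 1.0 via the integer-exponent branch of _pow_jvp_lhs
  jax.grad(lax.pow)(0.0, 0.0)  # nan: pow with a float exponent stays non-differentiable at (0., 0.)
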
4 changes: 2 additions & 2 deletions jax/_src/lax/utils.py
@@ -18,7 +18,6 @@

from functools import partial
import operator
from typing import Callable

from jax._src import core
from jax._src import dispatch
@@ -31,7 +30,8 @@

xops = xla_client.ops

_input_dtype: Callable = lambda *args, **_: dtypes.canonicalize_dtype(args[0].dtype, allow_extended_dtype=True)
def _input_dtype(x, *_, **__):
return dtypes.canonicalize_dtype(x.dtype, allow_extended_dtype=True)

def _argnum_weak_type(*argnums):
return lambda *args, **_: all(args[i].weak_type for i in argnums)
74 changes: 50 additions & 24 deletions jax/_src/numpy/ufuncs.py
@@ -32,7 +32,7 @@
from jax._src.numpy.util import (
check_arraylike, promote_args, promote_args_inexact,
promote_args_numeric, promote_dtypes_inexact, promote_dtypes_numeric,
promote_shapes, _where, _wraps)
promote_shapes, _where, _wraps, check_no_float0s)

_lax_const = lax._const

@@ -305,16 +305,60 @@ def _float_divmod(x1: ArrayLike, x2: ArrayLike) -> tuple[Array, Array]:
return lax.round(div), mod


@_wraps(np.power, module='numpy')
def power(x1: ArrayLike, x2: ArrayLike, /) -> Array:
check_arraylike("power", x1, x2)
check_no_float0s("power", x1, x2)
x1, x2 = promote_shapes("power", x1, x2) # not dtypes, see next comment

# We apply special cases, both for algorithmic and autodiff reasons:
# 1. for *concrete* integer scalar powers (and arbitrary bases), we use
# unrolled binary exponentiation specialized on the exponent, which is
# more precise for e.g. x ** 2 when x is a float (algorithmic reason!);
# 2. for integer bases and integer powers, use unrolled binary exponentiation
# where the number of steps is determined by a max bit width of 64
# (algorithmic reason!);
# 3. for integer powers and float/complex bases, we apply the lax primitive
# without any promotion of input types because in this case we want the
# function to be differentiable wrt its first argument at 0;
# 4. for other cases, perform jnp dtype promotion on the arguments then apply
# lax.pow.

# Case 1: concrete integer scalar powers:
if isinstance(core.get_aval(x2), core.ConcreteArray):
try:
x2 = operator.index(x2) # type: ignore[arg-type]
except TypeError:
pass
else:
x1, = promote_dtypes_numeric(x1)
return lax.integer_pow(x1, x2)

# Handle cases #2 and #3 under a jit:
return _power(x1, x2)

@partial(jit, inline=True)
def _power(x1: ArrayLike, x2: ArrayLike) -> Array:
x1, x2 = promote_args_numeric("power", x1, x2)
dtype = dtypes.dtype(x1)
if not dtypes.issubdtype(dtype, np.integer):
# Case 2: bool/integer base and bool/integer power:
d1, d2 = dtypes.dtype(x1), dtypes.dtype(x2)
if ((dtypes.issubdtype(d1, np.integer) or dtypes.issubdtype(d1, np.bool_)) and
(dtypes.issubdtype(d2, np.integer) or dtypes.issubdtype(d2, np.bool_))):
x1, x2 = promote_args_numeric("power", x1, x2)
assert np.iinfo(dtypes.dtype(x1)).bits <= 64 # _pow_int_int assumes <=64bit
return _pow_int_int(x1, x2)

# Case 3: float/complex base with integer power (special autodiff behavior)
if dtypes.issubdtype(d1, np.inexact) and dtypes.issubdtype(d2, np.integer):
return lax.pow(x1, x2)

# Integer power => use binary exponentiation.

# TODO(phawkins): add integer pow support to XLA.
# Case 4: do promotion first
x1, x2 = promote_args_numeric("power", x1, x2)
return lax.pow(x1, x2)

# TODO(phawkins): add integer pow support to XLA.
def _pow_int_int(x1, x2):
# Integer power => use binary exponentiation.
bits = 6 # Anything more would overflow for any x1 > 1
zero = _constant_like(x2, 0)
one = _constant_like(x2, 1)
@@ -327,24 +371,6 @@ def _power(x1: ArrayLike, x2: ArrayLike) -> Array:
return acc


@_wraps(np.power, module='numpy')
def power(x1: ArrayLike, x2: ArrayLike, /) -> Array:
check_arraylike("power", x1, x2)
# Special case for concrete integer scalars: use binary exponentiation.
# Using lax.pow may be imprecise for floating-point values; the goal of this
# code path is to make sure we end up with a precise output for the common
# pattern ``x ** 2`` or similar.
if isinstance(core.get_aval(x2), core.ConcreteArray):
try:
x2 = operator.index(x2) # type: ignore[arg-type]
except TypeError:
pass
else:
x1, = promote_dtypes_numeric(x1)
return lax.integer_pow(x1, x2)
return _power(x1, x2)


@custom_jvp
@_wraps(np.logaddexp, module='numpy')
@jit
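
A rough illustration (not part of the diff) of how the four cases described in
the power() comment above are dispatched; the routing shown is a sketch based
on the code as written, not a guarantee:

  import jax
  import jax.numpy as jnp

  jnp.power(3.0, 2)           # case 1: concrete integer exponent -> unrolled lax.integer_pow
  jax.jit(jnp.power)(2, 10)   # case 2: integer base and traced integer exponent -> _pow_int_int
  jax.jit(jnp.power)(0.0, 0)  # case 3: float base, traced integer exponent -> lax.pow
  jnp.power(2.0, 0.5)         # case 4: promote dtypes, then lax.pow
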
3 changes: 2 additions & 1 deletion jax/_src/numpy/util.py
@@ -337,7 +337,7 @@ def check_arraylike_or_none(fun_name: str, *args: Any):
raise TypeError(msg.format(fun_name, type(arg), pos))


def _check_no_float0s(fun_name: str, *args: Any):
def check_no_float0s(fun_name: str, *args: Any):
"""Check if none of the args have dtype float0."""
if any(dtypes.dtype(arg) == dtypes.float0 for arg in args):
raise TypeError(
@@ -348,6 +348,7 @@ def _check_no_float0s(fun_name: str, *args: Any):
"to cast a float0 array to a regular zeros array. \n"
"If you didn't expect to get a float0 you might have accidentally "
"taken a gradient with respect to an integer argument.")
_check_no_float0s = check_no_float0s


def promote_args(fun_name: str, *args: ArrayLike) -> list[Array]:
4 changes: 2 additions & 2 deletions tests/dynamic_api_test.py
@@ -1519,10 +1519,10 @@ def test_jumble_map_eltwise(self):
ins = lax.convert_element_type(jnp.array([3, 1, 4]), core.bint(5))
p = jax.vmap(partial(jnp.arange, dtype='int32'),
out_axes=batching.jumble_axis)(ins)
p = jumble_map(jax.jit(lambda x: x ** 2))(p)
p = jumble_map(jax.jit(lambda x: x * 3))(p)
self.assertIsInstance(p, batching.Jumble)
self.assertRegex(str(p.aval), r'Var[0-9]+:3 => i32\[bint\{≤5\}\[3\] with value: \[3 1 4\]\.Var[0-9]+\]')
data = jax.lax.broadcasted_iota('int32', (3, 5), 1) ** 2
data = jax.lax.broadcasted_iota('int32', (3, 5), 1) * 3
self.assertAllClose(p.data, data, check_dtypes=False)

def test_jumble_map_vector_dot(self):
5 changes: 5 additions & 0 deletions tests/lax_autodiff_test.py
@@ -533,6 +533,11 @@ def testPowSecondDerivative(self):
# self.assertEqual(result, 0.0)
self.assertAllClose(result, np.nan)

def testPowIntPowerAtZero(self):
# https://github.com/google/jax/issues/14397
ans = jax.grad(jax.jit(lambda x, n: x ** n))(0., 0)
self.assertAllClose(ans, 1., check_dtypes=False)

@jtu.sample_product(
[dict(arg_shape=arg_shape, pred_shape=pred_shape)
for arg_shape in [(), (3,), (2, 3)]