Skip to content

Commit

Permalink
Merge pull request #16419 from mattjj:pow-jvp
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 559266945
  • Loading branch information
jax authors committed Aug 23, 2023
2 parents abff9d2 + 1f8fb2c commit af42359
Show file tree
Hide file tree
Showing 6 changed files with 109 additions and 46 deletions.
65 changes: 48 additions & 17 deletions jax/_src/lax/lax.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,13 +57,9 @@
from jax._src.interpreters.batching import RaggedAxis
from jax._src.lax import slicing
from jax._src.lax.utils import (
_input_dtype,
dtype_to_string,
standard_abstract_eval,
standard_multi_result_abstract_eval,
standard_named_shape_rule,
standard_primitive,
)
_input_dtype, dtype_to_string, standard_abstract_eval,
standard_multi_result_abstract_eval, standard_named_shape_rule,
standard_primitive)
from jax._src import xla_bridge
from jax._src.lib import xla_client
from jax._src.lib.mlir import ir
Expand Down Expand Up @@ -1515,7 +1511,8 @@ def zeros_like_array(x: ArrayLike) -> Array:
### primitives


_fixed_dtype = lambda dtype: lambda *args, **kwargs: dtypes.canonicalize_dtype(dtype)
_fixed_dtype = \
lambda dtype: lambda *args, **kwargs: dtypes.canonicalize_dtype(dtype)
_complex_basetype = lambda dtype: np.abs(np.zeros((), dtype)).dtype

_strip_weak_type = lambda *args, **_: False
Expand Down Expand Up @@ -1543,7 +1540,7 @@ def unop(result_dtype, accepted_dtypes, name):


def naryop_dtype_rule(result_dtype, accepted_dtypes, name, *avals,
allow_extended_dtype=False, **kwargs):
require_same=True, allow_extended_dtype=False, **kwargs):
del kwargs
assert len(avals) == len(accepted_dtypes), (avals, accepted_dtypes)
for i, aval in enumerate(avals):
Expand All @@ -1566,7 +1563,7 @@ def naryop_dtype_rule(result_dtype, accepted_dtypes, name, *avals,
typename = dtype_to_string(aval.dtype)
typenames = ', '.join(t.__name__ for t in types)
raise TypeError(msg.format(name, typename, i, i, typenames))
check_same_dtypes(name, *avals)
if require_same: check_same_dtypes(name, *avals)
return result_dtype(*avals)


Expand Down Expand Up @@ -1609,9 +1606,11 @@ def _naryop_weak_type_rule(name, *avals, **kwargs):
"taken a gradient with respect to an integer argument.")
return all(aval.weak_type for aval in avals)

def naryop(result_dtype, accepted_dtypes, name, allow_extended_dtype=False):
def naryop(result_dtype, accepted_dtypes, name, allow_extended_dtype=False,
require_same_dtypes=False):
dtype_rule = partial(naryop_dtype_rule, result_dtype, accepted_dtypes, name,
allow_extended_dtype=allow_extended_dtype)
allow_extended_dtype=allow_extended_dtype,
require_same=require_same_dtypes)
shape_rule = partial(broadcasting_shape_rule, name)
weak_type_rule = partial(_naryop_weak_type_rule, name)
prim = standard_primitive(shape_rule, dtype_rule, name,
Expand Down Expand Up @@ -1973,16 +1972,48 @@ def _abs_jvp_rule(g, ans, x):
lambda g, ans, x: mul(g, mul(_const(x, 1/3), integer_pow(ans, -2))))
mlir.register_lowering(cbrt_p, partial(_nary_lower_hlo, hlo.CbrtOp))

pow_p = standard_naryop([_float | _complex, _float | _complex], 'pow')
def _pow_dtype_rule(x, y):
if (dtypes.issubdtype(x.dtype, np.inexact) and
dtypes.issubdtype(y.dtype, np.integer)):
return x.dtype
if x.dtype == y.dtype:
return x.dtype
raise TypeError("the first argument to pow must have an inexact dtype (float "
"or complex), and the second argument must have an inexact or"
" integer dtype, and two inexact dtypes must match, but got "
f"{x.dtype} and {y.dtype} respectively.")
pow_p = naryop(_pow_dtype_rule, [_float | _complex, _int | _float | _complex],
'pow', require_same_dtypes=False)

def _pow_jvp_lhs(g, ans, x, y):
return mul(g, mul(y, pow(x, sub(y, _ones(y)))))
y_dtype = dtypes.dtype(y)
x, y = jax._src.numpy.util.promote_dtypes_numeric(x, y) # TODO replace this
if dtypes.issubdtype(y_dtype, np.integer):
jac = select(eq(y, _const(y, 0)), _ones(y),
mul(_replace_zero(y), pow(x, sub(y, _ones(y)))))
else:
jac = mul(y, pow(x, sub(y, _ones(y))))
return mul(g, jac)

def _pow_jvp_rhs(g, ans, x, y):
return mul(g, mul(log(_replace_zero(x)), ans))

y_dtype = dtypes.dtype(y)
assert dtypes.issubdtype(y_dtype, np.inexact)
return convert_element_type(mul(g, mul(log(_replace_zero(x)), ans)), y_dtype)
ad.defjvp2(pow_p, _pow_jvp_lhs, _pow_jvp_rhs)
mlir.register_lowering(pow_p, partial(_nary_lower_hlo, hlo.PowOp))

def _pow_lower(ctx, x, y):
x_aval, y_aval = ctx.avals_in
out_aval, = ctx.avals_out
convert = mlir.lower_fun(
partial(convert_element_type, new_dtype=out_aval.dtype), False)
x_aval_ = x_aval.update(dtype=out_aval.dtype)
y_aval_ = y_aval.update(dtype=out_aval.dtype)
[(x_,)] = convert(ctx.replace(avals_in=[x_aval], avals_out=[x_aval_]), x)
[(y_,)] = convert(ctx.replace(avals_in=[y_aval], avals_out=[y_aval_]), y)
ctx_ = ctx.replace(avals_in=[x_aval_, y_aval_])
return _nary_lower_hlo(hlo.PowOp, ctx_, x_, y_)
mlir.register_lowering(pow_p, _pow_lower)



def _integer_pow_dtype_rule(x, *, y):
Expand Down
4 changes: 2 additions & 2 deletions jax/_src/lax/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@

from functools import partial
import operator
from typing import Callable

from jax._src import core
from jax._src import dispatch
Expand All @@ -30,7 +29,8 @@

xops = xla_client.ops

_input_dtype: Callable = lambda *args, **_: dtypes.canonicalize_dtype(args[0].dtype, allow_extended_dtype=True)
def _input_dtype(x, *_, **__):
return dtypes.canonicalize_dtype(x.dtype, allow_extended_dtype=True)

def _argnum_weak_type(*argnums):
return lambda *args, **_: all(args[i].weak_type for i in argnums)
Expand Down
74 changes: 50 additions & 24 deletions jax/_src/numpy/ufuncs.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from jax._src.numpy.util import (
check_arraylike, promote_args, promote_args_inexact,
promote_args_numeric, promote_dtypes_inexact, promote_dtypes_numeric,
promote_shapes, _where, _wraps)
promote_shapes, _where, _wraps, check_no_float0s)

_lax_const = lax._const

Expand Down Expand Up @@ -305,16 +305,60 @@ def _float_divmod(x1: ArrayLike, x2: ArrayLike) -> tuple[Array, Array]:
return lax.round(div), mod


@_wraps(np.power, module='numpy')
def power(x1: ArrayLike, x2: ArrayLike, /) -> Array:
check_arraylike("power", x1, x2)
check_no_float0s("power", x1, x2)

# We apply special cases, both for algorithmic and autodiff reasons:
# 1. for *concrete* integer scalar powers (and arbitrary bases), we use
# unrolled binary exponentiation specialized on the exponent, which is
# more precise for e.g. x ** 2 when x is a float (algorithmic reason!);
# 2. for integer bases and integer powers, use unrolled binary exponentiation
# where the number of steps is determined by a max bit width of 64
# (algorithmic reason!);
# 3. for integer powers and float/complex bases, we apply the lax primitive
# without any promotion of input types because in this case we want the
# function to be differentiable wrt its first argument at 0;
# 3. for other cases, perform jnp dtype promotion on the arguments then apply
# lax.pow.

# Case 1: concrete integer scalar powers:
if isinstance(core.get_aval(x2), core.ConcreteArray):
try:
x2 = operator.index(x2) # type: ignore[arg-type]
except TypeError:
pass
else:
x1, = promote_dtypes_numeric(x1)
return lax.integer_pow(x1, x2)

# Handle cases #2 and #3 under a jit:
return _power(x1, x2)

@partial(jit, inline=True)
def _power(x1: ArrayLike, x2: ArrayLike) -> Array:
x1, x2 = promote_args_numeric("power", x1, x2)
dtype = dtypes.dtype(x1)
if not dtypes.issubdtype(dtype, np.integer):
x1, x2 = promote_shapes("power", x1, x2) # not dtypes

# Case 2: bool/integer result
x1_, x2_ = promote_args_numeric("power", x1, x2)
if (dtypes.issubdtype(dtypes.dtype(x1_), np.integer) or
dtypes.issubdtype(dtypes.dtype(x1_), np.bool_)):
assert np.iinfo(dtypes.dtype(x1_)).bits <= 64 # _pow_int_int assumes <=64bit
return _pow_int_int(x1_, x2_)

# Case 3: float/complex base with integer power (special autodiff behavior)
d1, d2 = dtypes.dtype(x1), dtypes.dtype(x2)
if dtypes.issubdtype(d1, np.inexact) and dtypes.issubdtype(d2, np.integer):
return lax.pow(x1, x2)

# Integer power => use binary exponentiation.

# TODO(phawkins): add integer pow support to XLA.
# Case 4: do promotion first
return lax.pow(x1_, x2_)

# TODO(phawkins): add integer pow support to XLA.
def _pow_int_int(x1, x2):
# Integer power => use binary exponentiation.
bits = 6 # Anything more would overflow for any x1 > 1
zero = _constant_like(x2, 0)
one = _constant_like(x2, 1)
Expand All @@ -327,24 +371,6 @@ def _power(x1: ArrayLike, x2: ArrayLike) -> Array:
return acc


@_wraps(np.power, module='numpy')
def power(x1: ArrayLike, x2: ArrayLike, /) -> Array:
check_arraylike("power", x1, x2)
# Special case for concrete integer scalars: use binary exponentiation.
# Using lax.pow may be imprecise for floating-point values; the goal of this
# code path is to make sure we end up with a precise output for the common
# pattern ``x ** 2`` or similar.
if isinstance(core.get_aval(x2), core.ConcreteArray):
try:
x2 = operator.index(x2) # type: ignore[arg-type]
except TypeError:
pass
else:
x1, = promote_dtypes_numeric(x1)
return lax.integer_pow(x1, x2)
return _power(x1, x2)


@custom_jvp
@_wraps(np.logaddexp, module='numpy')
@jit
Expand Down
3 changes: 2 additions & 1 deletion jax/_src/numpy/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,7 @@ def check_arraylike_or_none(fun_name: str, *args: Any):
raise TypeError(msg.format(fun_name, type(arg), pos))


def _check_no_float0s(fun_name: str, *args: Any):
def check_no_float0s(fun_name: str, *args: Any):
"""Check if none of the args have dtype float0."""
if any(dtypes.dtype(arg) == dtypes.float0 for arg in args):
raise TypeError(
Expand All @@ -348,6 +348,7 @@ def _check_no_float0s(fun_name: str, *args: Any):
"to cast a float0 array to a regular zeros array. \n"
"If you didn't expect to get a float0 you might have accidentally "
"taken a gradient with respect to an integer argument.")
_check_no_float0s = check_no_float0s


def promote_args(fun_name: str, *args: ArrayLike) -> list[Array]:
Expand Down
4 changes: 2 additions & 2 deletions tests/dynamic_api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1519,10 +1519,10 @@ def test_jumble_map_eltwise(self):
ins = lax.convert_element_type(jnp.array([3, 1, 4]), core.bint(5))
p = jax.vmap(partial(jnp.arange, dtype='int32'),
out_axes=batching.jumble_axis)(ins)
p = jumble_map(jax.jit(lambda x: x ** 2))(p)
p = jumble_map(jax.jit(lambda x: x * 3))(p)
self.assertIsInstance(p, batching.Jumble)
self.assertRegex(str(p.aval), r'Var[0-9]+:3 => i32\[bint\{≤5\}\[3\] with value: \[3 1 4\]\.Var[0-9]+\]')
data = jax.lax.broadcasted_iota('int32', (3, 5), 1) ** 2
data = jax.lax.broadcasted_iota('int32', (3, 5), 1) * 3
self.assertAllClose(p.data, data, check_dtypes=False)

def test_jumble_map_vector_dot(self):
Expand Down
5 changes: 5 additions & 0 deletions tests/lax_autodiff_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,6 +538,11 @@ def testPowSecondDerivative(self):
# self.assertEqual(result, 0.0)
self.assertAllClose(result, np.nan)

def testPowIntPowerAtZero(self):
# https://github.com/google/jax/issues/14397
ans = jax.grad(jax.jit(lambda x, n: x ** n))(0., 0)
self.assertAllClose(ans, 1., check_dtypes=False)

@jtu.sample_product(
[dict(arg_shape=arg_shape, pred_shape=pred_shape)
for arg_shape in [(), (3,), (2, 3)]
Expand Down

0 comments on commit af42359

Please sign in to comment.