Skip to content

Commit 14b215f

Browse files
Merge pull request #27032 from dfm:lax-dtype
PiperOrigin-RevId: 735424674
2 parents ab0ce8a + 4eada56 commit 14b215f

File tree

1 file changed

+21
-19
lines changed

1 file changed

+21
-19
lines changed

jax/_src/lax/lax.py

+21 -19
Original file line number | Diff line number | Diff line change
@@ -3723,10 +3723,11 @@ def _sin_complex(x):
37233723
# 2 * cosh(x) = exp(x) - 1 + (exp(-x) - 1) + 2 = expm1(x) + expm1(-x) + 2
37243724
a, b = real(x), imag(x)
37253725
a_is_zero = eq(a, _const(a, 0))
3726+
two = _const(a, 2)
37263727
sn, cs = sin(a), cos(a)
3727-
e1m, e2m = expm1(b), expm1(-b)
3728-
snh, csh = (e1m - e2m) / 2, (e1m + e2m + 2) / 2
3729-
re, im = sn * csh, cs * snh
3728+
e1m, e2m = expm1(b), expm1(neg(b))
3729+
snh, csh = div(sub(e1m, e2m), two), div(add(add(e1m, e2m), two), two)
3730+
re, im = mul(sn, csh), mul(cs, snh)
37303731
# avoid nan value when real(x) is zero and abs(x) is so large that abs(expm1(x)) is inf
37313732
return select(a_is_zero, complex(_const(a, 0), im), complex(re, im))
37323733

@@ -3752,10 +3753,11 @@ def _cos_complex(x):
37523753
# see also _sin_complex
37533754
a, b = real(x), imag(x)
37543755
a_is_zero = eq(a, _const(a, 0))
3756+
two = _const(a, 2)
37553757
sn, cs = sin(a), cos(a)
3756-
e1m, e2m = expm1(b), expm1(-b)
3757-
snh, csh = (e1m - e2m) / 2, (e1m + e2m + 2) / 2
3758-
re, im = cs * csh, -sn * snh
3758+
e1m, e2m = expm1(b), expm1(neg(b))
3759+
snh, csh = div(sub(e1m, e2m), two), div(add(add(e1m, e2m), two), two)
3760+
re, im = mul(cs, csh), mul(neg(sn), snh)
37593761
return select(a_is_zero, complex(re, _const(a, 0)), complex(re, im))
37603762

37613763
def _cos_lowering(ctx, x):
@@ -3769,28 +3771,28 @@ def _cos_lowering(ctx, x):
37693771
mlir.register_lowering(cos_p, _cos_lowering)
37703772

37713773
tan_p = standard_unop(_float | _complex, 'tan')
3772-
ad.defjvp2(tan_p, lambda g, ans, x: mul(g, _const(x, 1) + square(ans)))
3774+
ad.defjvp2(tan_p, lambda g, ans, x: mul(g, add(_const(x, 1), square(ans))))
37733775
mlir.register_lowering(tan_p, partial(_nary_lower_hlo, hlo.tan))
37743776

37753777
asin_p = standard_unop(_float | _complex, 'asin')
3776-
ad.defjvp(asin_p, lambda g, x: mul(g, rsqrt(_const(x, 1) - square(x))))
3778+
ad.defjvp(asin_p, lambda g, x: mul(g, rsqrt(sub(_const(x, 1), square(x)))))
37773779
mlir.register_lowering(asin_p, partial(_nary_lower_hlo, chlo.asin))
37783780

37793781
acos_p = standard_unop(_float | _complex, 'acos')
3780-
ad.defjvp(acos_p, lambda g, x: mul(g, -rsqrt(_const(x, 1) - square(x))))
3782+
ad.defjvp(acos_p, lambda g, x: mul(g, neg(rsqrt(sub(_const(x, 1), square(x))))))
37813783
mlir.register_lowering(acos_p, partial(_nary_lower_hlo, chlo.acos))
37823784

37833785
def atan_impl(x):
37843786
return atan2(x, _const(x, 1))
37853787

37863788
atan_p = standard_unop(_float | _complex, 'atan')
3787-
ad.defjvp(atan_p, lambda g, x: div(g, _const(x, 1) + square(x)))
3789+
ad.defjvp(atan_p, lambda g, x: div(g, add(_const(x, 1), square(x))))
37883790
mlir.register_lowering(atan_p, partial(_nary_lower_hlo, chlo.atan))
37893791

37903792
atan2_p = standard_naryop([_float | _complex, _float | _complex], 'atan2')
37913793
ad.defjvp(atan2_p,
3792-
lambda g, x, y: g * (y / (square(x) + square(y))),
3793-
lambda g, x, y: g * -x / (square(x) + square(y)))
3794+
lambda g, x, y: mul(g, div(y, add(square(x), square(y)))),
3795+
lambda g, x, y: mul(g, div(neg(x), add(square(x), square(y)))))
37943796
mlir.register_lowering(atan2_p, partial(_nary_lower_hlo, hlo.atan2))
37953797

37963798
sinh_p = standard_unop(_float | _complex, 'sinh')
@@ -3802,17 +3804,17 @@ def atan_impl(x):
38023804
mlir.register_lowering(cosh_p, partial(_nary_lower_hlo, chlo.cosh))
38033805

38043806
asinh_p = standard_unop(_float | _complex, 'asinh')
3805-
ad.defjvp(asinh_p, lambda g, x: mul(g, rsqrt(square(x) + _one(x))))
3807+
ad.defjvp(asinh_p, lambda g, x: mul(g, rsqrt(add(square(x), _one(x)))))
38063808
mlir.register_lowering(asinh_p, partial(_nary_lower_hlo, chlo.asinh))
38073809

38083810
acosh_p = standard_unop(_float | _complex, 'acosh')
38093811
ad.defjvp(acosh_p,
3810-
lambda g, x: mul(g, rsqrt((x - _one(x)) * (x + _one(x)))))
3812+
lambda g, x: mul(g, rsqrt(mul(sub(x, _one(x)), add(x, _one(x))))))
38113813
mlir.register_lowering(acosh_p, partial(_nary_lower_hlo, chlo.acosh))
38123814

38133815
atanh_p = standard_unop(_float | _complex, 'atanh')
38143816
ad.defjvp(atanh_p,
3815-
lambda g, x: mul(reciprocal(_one(x) + x), div(g, (_one(x) - x))))
3817+
lambda g, x: mul(reciprocal(add(_one(x), x)), div(g, sub(_one(x), x))))
38163818
mlir.register_lowering(atanh_p, partial(_nary_lower_hlo, chlo.atanh))
38173819

38183820
real_p = unop(_complex_basetype, _complex, 'real')
@@ -3906,11 +3908,11 @@ def _square_complex(x):
39063908
a, b = real(x), imag(x)
39073909
# zero square(x).real is handled explicitly for abs(a)==abs(b) cases
39083910
# where for finite a, 2 * a is non-finite:
3909-
zero_re = is_finite(a) & (eq(a, b) | eq(a, -b))
3911+
zero_re = is_finite(a) & (eq(a, b) | eq(a, neg(b)))
39103912
# equivalent to a**2 - b**2 but avoids overflow errors for large a
39113913
# and large b cases:
3912-
re = (a - b) * (a + b)
3913-
im = a * b * 2
3914+
re = mul(sub(a, b), add(a, b))
3915+
im = mul(mul(a, b), _const(a, 2))
39143916
return select(zero_re, complex(_const(a, 0), im), complex(re, im))
39153917

39163918
def _square_lower_hlo(ctx, x):
@@ -5276,7 +5278,7 @@ def _ragged_dot_jvp_rule(
52765278
if type(dy) is not ad_util.Zero
52775279
else _zeros(primal_out)
52785280
)
5279-
tangent_out = dx_out + dy_out
5281+
tangent_out = add(dx_out, dy_out)
52805282

52815283
return primal_out, tangent_out
52825284

0 commit comments

Comments
 (0)