@@ -3723,10 +3723,11 @@ def _sin_complex(x):
3723
3723
# 2 * cosh(x) = exp(x) - 1 + (exp(-x) - 1) + 2 = expm1(x) + expm1(-x) + 2
3724
3724
a , b = real (x ), imag (x )
3725
3725
a_is_zero = eq (a , _const (a , 0 ))
3726
+ two = _const (a , 2 )
3726
3727
sn , cs = sin (a ), cos (a )
3727
- e1m , e2m = expm1 (b ), expm1 (- b )
3728
- snh , csh = ( e1m - e2m ) / 2 , ( e1m + e2m + 2 ) / 2
3729
- re , im = sn * csh , cs * snh
3728
+ e1m , e2m = expm1 (b ), expm1 (neg ( b ) )
3729
+ snh , csh = div ( sub ( e1m , e2m ), two ), div ( add ( add ( e1m , e2m ), two ), two )
3730
+ re , im = mul ( sn , csh ), mul ( cs , snh )
3730
3731
# avoid nan value when real(x) is zero and abs(x) is so large that abs(expm1(x)) is inf
3731
3732
return select (a_is_zero , complex (_const (a , 0 ), im ), complex (re , im ))
3732
3733
@@ -3752,10 +3753,11 @@ def _cos_complex(x):
3752
3753
# see also _sin_complex
3753
3754
a , b = real (x ), imag (x )
3754
3755
a_is_zero = eq (a , _const (a , 0 ))
3756
+ two = _const (a , 2 )
3755
3757
sn , cs = sin (a ), cos (a )
3756
- e1m , e2m = expm1 (b ), expm1 (- b )
3757
- snh , csh = ( e1m - e2m ) / 2 , ( e1m + e2m + 2 ) / 2
3758
- re , im = cs * csh , - sn * snh
3758
+ e1m , e2m = expm1 (b ), expm1 (neg ( b ) )
3759
+ snh , csh = div ( sub ( e1m , e2m ), two ), div ( add ( add ( e1m , e2m ), two ), two )
3760
+ re , im = mul ( cs , csh ), mul ( neg ( sn ), snh )
3759
3761
return select (a_is_zero , complex (re , _const (a , 0 )), complex (re , im ))
3760
3762
3761
3763
def _cos_lowering (ctx , x ):
@@ -3769,28 +3771,28 @@ def _cos_lowering(ctx, x):
3769
3771
mlir .register_lowering (cos_p , _cos_lowering )
3770
3772
3771
3773
tan_p = standard_unop (_float | _complex , 'tan' )
3772
- ad .defjvp2 (tan_p , lambda g , ans , x : mul (g , _const (x , 1 ) + square (ans )))
3774
+ ad .defjvp2 (tan_p , lambda g , ans , x : mul (g , add ( _const (x , 1 ), square (ans ) )))
3773
3775
mlir .register_lowering (tan_p , partial (_nary_lower_hlo , hlo .tan ))
3774
3776
3775
3777
asin_p = standard_unop (_float | _complex , 'asin' )
3776
- ad .defjvp (asin_p , lambda g , x : mul (g , rsqrt (_const (x , 1 ) - square (x ))))
3778
+ ad .defjvp (asin_p , lambda g , x : mul (g , rsqrt (sub ( _const (x , 1 ), square (x ) ))))
3777
3779
mlir .register_lowering (asin_p , partial (_nary_lower_hlo , chlo .asin ))
3778
3780
3779
3781
acos_p = standard_unop (_float | _complex , 'acos' )
3780
- ad .defjvp (acos_p , lambda g , x : mul (g , - rsqrt (_const (x , 1 ) - square (x ))))
3782
+ ad .defjvp (acos_p , lambda g , x : mul (g , neg ( rsqrt (sub ( _const (x , 1 ), square (x )) ))))
3781
3783
mlir .register_lowering (acos_p , partial (_nary_lower_hlo , chlo .acos ))
3782
3784
3783
3785
def atan_impl (x ):
3784
3786
return atan2 (x , _const (x , 1 ))
3785
3787
3786
3788
atan_p = standard_unop (_float | _complex , 'atan' )
3787
- ad .defjvp (atan_p , lambda g , x : div (g , _const (x , 1 ) + square (x )))
3789
+ ad .defjvp (atan_p , lambda g , x : div (g , add ( _const (x , 1 ), square (x ) )))
3788
3790
mlir .register_lowering (atan_p , partial (_nary_lower_hlo , chlo .atan ))
3789
3791
3790
3792
atan2_p = standard_naryop ([_float | _complex , _float | _complex ], 'atan2' )
3791
3793
ad .defjvp (atan2_p ,
3792
- lambda g , x , y : g * ( y / (square (x ) + square (y ))),
3793
- lambda g , x , y : g * - x / (square (x ) + square (y )))
3794
+ lambda g , x , y : mul ( g , div ( y , add (square (x ), square (y ) ))),
3795
+ lambda g , x , y : mul ( g , div ( neg ( x ), add (square (x ), square (y )) )))
3794
3796
mlir .register_lowering (atan2_p , partial (_nary_lower_hlo , hlo .atan2 ))
3795
3797
3796
3798
sinh_p = standard_unop (_float | _complex , 'sinh' )
@@ -3802,17 +3804,17 @@ def atan_impl(x):
3802
3804
mlir .register_lowering (cosh_p , partial (_nary_lower_hlo , chlo .cosh ))
3803
3805
3804
3806
asinh_p = standard_unop (_float | _complex , 'asinh' )
3805
- ad .defjvp (asinh_p , lambda g , x : mul (g , rsqrt (square (x ) + _one (x ))))
3807
+ ad .defjvp (asinh_p , lambda g , x : mul (g , rsqrt (add ( square (x ), _one (x ) ))))
3806
3808
mlir .register_lowering (asinh_p , partial (_nary_lower_hlo , chlo .asinh ))
3807
3809
3808
3810
acosh_p = standard_unop (_float | _complex , 'acosh' )
3809
3811
ad .defjvp (acosh_p ,
3810
- lambda g , x : mul (g , rsqrt (( x - _one (x )) * ( x + _one (x )))))
3812
+ lambda g , x : mul (g , rsqrt (mul ( sub ( x , _one (x )), add ( x , _one (x ) )))))
3811
3813
mlir .register_lowering (acosh_p , partial (_nary_lower_hlo , chlo .acosh ))
3812
3814
3813
3815
atanh_p = standard_unop (_float | _complex , 'atanh' )
3814
3816
ad .defjvp (atanh_p ,
3815
- lambda g , x : mul (reciprocal (_one (x ) + x ) , div (g , (_one (x ) - x ))))
3817
+ lambda g , x : mul (reciprocal (add ( _one (x ), x )) , div (g , sub (_one (x ), x ))))
3816
3818
mlir .register_lowering (atanh_p , partial (_nary_lower_hlo , chlo .atanh ))
3817
3819
3818
3820
real_p = unop (_complex_basetype , _complex , 'real' )
@@ -3906,11 +3908,11 @@ def _square_complex(x):
3906
3908
a , b = real (x ), imag (x )
3907
3909
# zero square(x).real is handled explicitly for abs(a)==abs(b) cases
3908
3910
# where for finite a, 2 * a is non-finite:
3909
- zero_re = is_finite (a ) & (eq (a , b ) | eq (a , - b ))
3911
+ zero_re = is_finite (a ) & (eq (a , b ) | eq (a , neg ( b ) ))
3910
3912
# equivalent to a**2 - b**2 but avoids overflow errors for large a
3911
3913
# and large b cases:
3912
- re = ( a - b ) * ( a + b )
3913
- im = a * b * 2
3914
+ re = mul ( sub ( a , b ), add ( a , b ) )
3915
+ im = mul ( mul ( a , b ), _const ( a , 2 ))
3914
3916
return select (zero_re , complex (_const (a , 0 ), im ), complex (re , im ))
3915
3917
3916
3918
def _square_lower_hlo (ctx , x ):
@@ -5276,7 +5278,7 @@ def _ragged_dot_jvp_rule(
5276
5278
if type (dy ) is not ad_util .Zero
5277
5279
else _zeros (primal_out )
5278
5280
)
5279
- tangent_out = dx_out + dy_out
5281
+ tangent_out = add ( dx_out , dy_out )
5280
5282
5281
5283
return primal_out , tangent_out
5282
5284
0 commit comments