@@ -303,15 +303,16 @@ def body_fn(vals):
303
303
304
304
def igamma_impl (a , x , * , dtype ):
305
305
is_nan = bitwise_or (_isnan (a ), _isnan (x ))
306
- x_is_zero = eq (x , _const (x , 0 ))
307
306
x_is_infinity = eq (x , _const (x , float ('inf' )))
308
- domain_error = bitwise_or (lt (x , _const (x , 0 )), le (a , _const (a , 0 )))
309
- use_igammac = bitwise_and (gt (x , _const (x , 1 )), gt (x , a ))
307
+ a_is_zero = eq (a , _const (a , 0 ))
308
+ x_is_zero = eq (x , _const (x , 0 ))
309
+ domain_error = _reduce (bitwise_or , [lt (x , _const (x , 0 )), lt (a , _const (a , 0 )), bitwise_and (a_is_zero , x_is_zero )])
310
+
311
+ use_igammac = bitwise_and (ge (x , _const (x , 1 )), gt (x , a ))
310
312
ax = a * log (x ) - x - lgamma (a )
311
313
underflow = lt (ax , - log (dtypes .finfo (dtype ).max ))
312
314
ax = exp (ax )
313
- enabled = bitwise_not (
314
- _reduce (bitwise_or ,[x_is_zero , domain_error , underflow , is_nan ]))
315
+ enabled = bitwise_not (_reduce (bitwise_or , [x_is_zero , domain_error , underflow , is_nan , x_is_infinity ]))
315
316
316
317
output = select (
317
318
use_igammac ,
@@ -323,8 +324,7 @@ def igamma_impl(a, x, *, dtype):
323
324
)
324
325
output = select (x_is_zero , full_like (a , 0 ), output )
325
326
output = select (x_is_infinity , full_like (a , 1 ), output )
326
- output = select (bitwise_or (domain_error , is_nan ),
327
- full_like (a , float ('nan' )), output )
327
+ output = select (domain_error , full_like (a , float ('nan' )), output )
328
328
return output
329
329
330
330
def _igammac_continued_fraction (ax , x , a , enabled , dtype , mode ):
@@ -433,22 +433,26 @@ def body_fn(vals):
433
433
raise ValueError (f"Invalid mode: { mode } " )
434
434
435
435
def igammac_impl (a , x , * , dtype ):
436
- out_of_range = bitwise_or (le (x , _const (x , 0 )), le (a , _const (a , 0 )))
436
+ is_nan = bitwise_or (_isnan (a ), _isnan (x ))
437
+ a_is_zero = eq (a , _const (a , 0 ))
438
+ x_is_zero = eq (x , _const (x , 0 ))
439
+ x_is_infinity = eq (x , _const (x , float ('inf' )))
440
+ domain_error = _reduce (bitwise_or , [lt (x , _const (x , 0 )), lt (a , _const (a , 0 )), bitwise_and (a_is_zero , x_is_zero )])
437
441
use_igamma = bitwise_or (lt (x , _const (x , 1 )), lt (x , a ))
438
442
ax = a * log (x ) - x - lgamma (a )
439
443
underflow = lt (ax , - log (dtypes .finfo (dtype ).max ))
440
- enabled = bitwise_not (bitwise_or ( out_of_range , underflow ))
444
+ enabled = bitwise_not (_reduce ( bitwise_or , [ domain_error , underflow , is_nan , x_is_infinity , a_is_zero ] ))
441
445
ax = exp (ax )
442
446
443
447
igamma_call = _igamma_series (ax , x , a , bitwise_and (enabled , use_igamma ),
444
448
dtype , IgammaMode .VALUE )
445
449
igammac_cf_call = _igammac_continued_fraction (ax , x , a ,
446
450
bitwise_and (enabled , bitwise_not (use_igamma )), dtype , IgammaMode .VALUE )
447
451
448
- result = select (use_igamma , _const (a , 1 ) - igamma_call , igammac_cf_call )
449
- x_is_infinity = eq ( x , _const ( x , float ( 'inf' )) )
450
- result = select (x_is_infinity , full_like (result , 0 ), result )
451
- return select ( out_of_range , full_like ( a , 1 ), result )
452
+ output = select (use_igamma , _const (a , 1 ) - igamma_call , igammac_cf_call )
453
+ output = select ( bitwise_or ( x_is_infinity , a_is_zero ), full_like ( output , 0 ), output )
454
+ output = select (domain_error , full_like (a , float ( 'nan' )), output )
455
+ return output
452
456
453
457
def igamma_grad_a_impl (a , x , * , dtype ):
454
458
is_nan = bitwise_or (_isnan (a ), _isnan (x ))
0 commit comments