Skip to content

Commit

Permalink
Merge pull request #20688 from pearu:pearu/tan
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 623851102
  • Loading branch information
jax authors committed Apr 11, 2024
2 parents 2be7205 + fc04ba9 commit 36bedee
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 10 deletions.
28 changes: 28 additions & 0 deletions jax/_src/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -1754,6 +1754,34 @@ def log1p(self, x):
return ctx.make_mpc(((-x.real)._mpf_, (3 * pi / 4)._mpf_))
return ctx.log1p(x)

def tan(self, x):
ctx = x.context
if isinstance(x, ctx.mpc):
# Workaround mpmath 1.3 bug in tan(+-inf+-infj) evaluation (see mpmath/mpmath#781).
# TODO(pearu): remove this function when mpmath 1.4 or newer
# will be the required test dependency.
if ctx.isinf(x.imag) and (ctx.isinf(x.real) or ctx.isfinite(x.real)):
if x.imag > 0:
return ctx.make_mpc((ctx.zero._mpf_, ctx.one._mpf_))
return ctx.make_mpc((ctx.zero._mpf_, (-ctx.one)._mpf_))
if ctx.isinf(x.real) and ctx.isfinite(x.imag):
return ctx.make_mpc((ctx.nan._mpf_, ctx.nan._mpf_))
return ctx.tan(x)

def tanh(self, x):
ctx = x.context
if isinstance(x, ctx.mpc):
# Workaround mpmath 1.3 bug in tanh(+-inf+-infj) evaluation (see mpmath/mpmath#781).
# TODO(pearu): remove this function when mpmath 1.4 or newer
# will be the required test dependency.
if ctx.isinf(x.imag) and (ctx.isinf(x.real) or ctx.isfinite(x.real)):
if x.imag > 0:
return ctx.make_mpc((ctx.zero._mpf_, ctx.one._mpf_))
return ctx.make_mpc((ctx.zero._mpf_, (-ctx.one)._mpf_))
if ctx.isinf(x.real) and ctx.isfinite(x.imag):
return ctx.make_mpc((ctx.nan._mpf_, ctx.nan._mpf_))
return ctx.tanh(x)

def log2(self, x):
return x.context.ln(x) / x.context.ln2

Expand Down
20 changes: 10 additions & 10 deletions tests/lax_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3616,11 +3616,9 @@ def regions_with_inaccuracies_keep(*to_keep):
elif name == 'log10':
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'ninf.imag', 'pinf.imag', 'ninfj.imag', 'pinfj.imag', 'zero.imag')

elif name == 'log1p':
regions_with_inaccuracies_keep('q1.real', 'q2.real', 'q3.real', 'q4.real', 'neg.real', 'pos.real',
'negj.real', 'posj.real', 'ninf.real', 'ninfj.real', 'pinfj.real')
# TODO(pearu): after landing openxla/xla#10503, switch to
# regions_with_inaccuracies_keep('ninf', 'pinf', 'ninfj', 'pinfj')
elif name == 'log1p' and xla_extension_version < 254:
regions_with_inaccuracies_keep('q1.real', 'q2.real', 'q3.real', 'q4.real', 'neg.real', 'pos.real',
'negj.real', 'posj.real', 'ninf.real', 'ninfj.real', 'pinfj.real')

elif name == 'exp':
regions_with_inaccuracies_keep('pos.imag', 'pinf.imag', 'mpos.imag')
Expand All @@ -3640,9 +3638,10 @@ def regions_with_inaccuracies_keep(*to_keep):
'ninf.imag', 'pinf.imag', 'ninfj.real', 'pinfj.real')

elif name == 'tan':
# TODO(pearu): eliminate this if-block when openxla/xla#10525 lands
regions_with_inaccuracies_keep('q1.imag', 'q2.imag', 'q3.imag', 'q4.imag', 'negj.imag', 'posj.imag',
'ninfj.imag', 'pinfj.imag', 'mq1.imag', 'mq2.imag', 'mq3.imag', 'mq4.imag', 'mnegj.imag', 'mposj.imag',
'ninf.imag', 'pinf.imag')
'ninf.imag', 'pinf.imag', 'ninf.real', 'pinf.real', 'ninfj.real', 'pinfj.real')

elif name == 'sinh':
if is_cuda:
Expand Down Expand Up @@ -3695,14 +3694,15 @@ def regions_with_inaccuracies_keep(*to_keep):
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'ninfj', 'pinfj')

elif name == 'arctanh':
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'ninfj', 'pinfj', 'mpos.imag')
# TODO(pearu): after landing openxla/xla#10503, switch to
# regions_with_inaccuracies_keep('pos', 'ninf', 'pinf', 'ninfj', 'pinfj', 'mpos')
if xla_extension_version < 254:
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'ninfj', 'pinfj', 'mpos.imag')
else:
regions_with_inaccuracies_keep('pos.imag', 'ninf', 'pinf', 'ninfj', 'pinfj', 'mpos.imag')

elif name in {'cos', 'sin'}:
regions_with_inaccuracies_keep('ninf.imag', 'pinf.imag')

elif name in {'positive', 'negative', 'conjugate', 'sin', 'cos', 'sqrt', 'expm1'}:
elif name in {'positive', 'negative', 'conjugate', 'sin', 'cos', 'sqrt', 'expm1', 'log1p'}:
regions_with_inaccuracies.clear()
else:
assert 0 # unreachable
Expand Down

0 comments on commit 36bedee

Please sign in to comment.