Skip to content

Commit

Permalink
implement jet rules by lowering to other primitives (#2816)
Browse files Browse the repository at this point in the history
merge jet_test

add jet rules

use lax.square
  • Loading branch information
jacobjinkelly authored Apr 24, 2020
1 parent 2518343 commit fc4203c
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 0 deletions.
35 changes: 35 additions & 0 deletions jax/experimental/jet.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,41 @@ def _log_taylor(primals_in, series_in):
return primal_out, series_out
jet_rules[lax.log_p] = _log_taylor

def _sqrt_taylor(primals_in, series_in):
return jet(lambda x: x ** 0.5, primals_in, series_in)
jet_rules[lax.sqrt_p] = _sqrt_taylor

def _rsqrt_taylor(primals_in, series_in):
return jet(lambda x: x ** -0.5, primals_in, series_in)
jet_rules[lax.rsqrt_p] = _rsqrt_taylor

def _asinh_taylor(primals_in, series_in):
return jet(lambda x: lax.log(x + lax.sqrt(lax.square(x) + 1)), primals_in, series_in)
jet_rules[lax.asinh_p] = _asinh_taylor

def _acosh_taylor(primals_in, series_in):
return jet(lambda x: lax.log(x + lax.sqrt(lax.square(x) - 1)), primals_in, series_in)
jet_rules[lax.acosh_p] = _acosh_taylor

def _atanh_taylor(primals_in, series_in):
return jet(lambda x: 0.5 * lax.log(lax.div(1 + x, 1 - x)), primals_in, series_in)
jet_rules[lax.atanh_p] = _atanh_taylor

def _atan2_taylor(primals_in, series_in):
x, y = primals_in
primal_out = lax.atan2(x, y)

x, series = jet(lax.div, primals_in, series_in)
c0, cs = jet(lambda x: lax.div(1, 1 + lax.square(x)), (x, ), (series, ))
c = [c0] + cs
u = [x] + series
v = [primal_out] + [None] * len(series)
for k in range(1, len(v)):
v[k] = fact(k-1) * sum(_scale(k, j) * c[k-j] * u[j] for j in range(1, k + 1))
primal_out, *series_out = v
return primal_out, series_out
jet_rules[lax.atan2_p] = _atan2_taylor

def _log1p_taylor(primals_in, series_in):
x, = primals_in
series, = series_in
Expand Down
12 changes: 12 additions & 0 deletions tests/jet_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,16 @@ def test_tanh(self): self.unary_check(np.tanh, lims=[-500, 500], order=5)
def test_expit(self): self.unary_check(jax.scipy.special.expit, lims=[-500, 500], order=5)
@jtu.skip_on_devices("tpu")
def test_expit2(self): self.expit_check(lims=[-500, 500], order=5)
@jtu.skip_on_devices("tpu")
def test_sqrt(self): self.unary_check(np.sqrt, lims=[0, 5.])
@jtu.skip_on_devices("tpu")
def test_rsqrt(self): self.unary_check(lax.rsqrt, lims=[0, 5000.])
@jtu.skip_on_devices("tpu")
def test_asinh(self): self.unary_check(lax.asinh, lims=[-100, 100])
@jtu.skip_on_devices("tpu")
def test_acosh(self): self.unary_check(lax.acosh, lims=[-100, 100])
@jtu.skip_on_devices("tpu")
def test_atanh(self): self.unary_check(lax.atanh, lims=[-1, 1])

@jtu.skip_on_devices("tpu")
def test_div(self): self.binary_check(lambda x, y: x / y, lims=[0.8, 4.0])
Expand Down Expand Up @@ -245,6 +255,8 @@ def test_xor(self): self.binary_check(lambda x, y: np.logical_xor(x, y))
@jtu.skip_on_devices("tpu")
@jtu.ignore_warning(message="overflow encountered in power")
def test_pow(self): self.binary_check(lambda x, y: x ** y, lims=([0.2, 500], [-500, 500]), finite=False)
@jtu.skip_on_devices("tpu")
def test_atan2(self): self.binary_check(lax.atan2, lims=[-40, 40])

def test_process_call(self):
def f(x):
Expand Down

0 comments on commit fc4203c

Please sign in to comment.