Skip to content

Commit

Permalink
add tanh rule (#2653)
Browse files Browse the repository at this point in the history
change expit taylor rule

add manual expit check, check stability of expit and tanh
  • Loading branch information
jacobjinkelly authored Apr 23, 2020
1 parent 8fe3c59 commit 59bdb1f
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 0 deletions.
26 changes: 26 additions & 0 deletions jax/experimental/jet.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import numpy as onp

import jax
from jax import core
from jax.util import unzip2
from jax.tree_util import (register_pytree_node, tree_structure,
Expand Down Expand Up @@ -217,6 +218,9 @@ def fact(n):
def _scale(k, j):
return 1. / (fact(k - j) * fact(j - 1))

def _scale2(k, j):
return 1. / (fact(k - j) * fact(j))

def _exp_taylor(primals_in, series_in):
x, = primals_in
series, = series_in
Expand Down Expand Up @@ -253,6 +257,28 @@ def _pow_taylor(primals_in, series_in):
return primal_out, series_out
jet_rules[lax.pow_p] = _pow_taylor

def _expit_taylor(primals_in, series_in):
x, = primals_in
series, = series_in
u = [x] + series
v = [jax.scipy.special.expit(x)] + [None] * len(series)
e = [v[0] * (1 - v[0])] + [None] * len(series) # terms for sigmoid' = sigmoid * (1 - sigmoid)
for k in range(1, len(v)):
v[k] = fact(k-1) * sum([_scale(k, j) * e[k-j] * u[j] for j in range(1, k+1)])
e[k] = (1 - v[0]) * v[k] - fact(k) * sum([_scale2(k, j)* v[j] * v[k-j] for j in range(1, k+1)])

primal_out, *series_out = v
return primal_out, series_out

def _tanh_taylor(primals_in, series_in):
x, = primals_in
series, = series_in
u = [2*x] + [2 * series_ for series_ in series]
primals_in, *series_in = u
primal_out, series_out = _expit_taylor((primals_in, ), (series_in, ))
series_out = [2 * series_ for series_ in series_out]
return 2 * primal_out - 1, series_out
jet_rules[lax.tanh_p] = _tanh_taylor

def _log_taylor(primals_in, series_in):
x, = primals_in
Expand Down
28 changes: 28 additions & 0 deletions tests/jet_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

from jax import test_util as jtu
import jax.numpy as np
import jax.scipy.special
from jax import random
from jax import jacfwd, jit
from jax.experimental import stax
Expand Down Expand Up @@ -153,6 +154,27 @@ def binary_check(self, fun, lims=[-2, 2], order=3, finite=True):
else:
self.check_jet_finite(fun, primal_in, series_in, atol=1e-4, rtol=1e-4)

def expit_check(self, lims=[-2, 2], order=3):
dims = 2, 3
rng = onp.random.RandomState(0)
primal_in = transform(lims, rng.rand(*dims))
terms_in = [rng.randn(*dims) for _ in range(order)]

primals = (primal_in, )
series = (terms_in, )

y, terms = jax.experimental.jet._expit_taylor(primals, series)
expected_y, expected_terms = jvp_taylor(jax.scipy.special.expit, primals, series)

atol = 1e-4
rtol = 1e-4
self.assertAllClose(y, expected_y, atol=atol, rtol=rtol,
check_dtypes=True)

self.assertAllClose(terms, expected_terms, atol=atol, rtol=rtol,
check_dtypes=True)


@jtu.skip_on_devices("tpu")
def test_exp(self): self.unary_check(np.exp)
@jtu.skip_on_devices("tpu")
Expand Down Expand Up @@ -187,6 +209,12 @@ def test_fft(self): self.unary_check(np.fft.fft)
def test_log1p(self): self.unary_check(np.log1p, lims=[0, 4.])
@jtu.skip_on_devices("tpu")
def test_expm1(self): self.unary_check(np.expm1)
@jtu.skip_on_devices("tpu")
def test_tanh(self): self.unary_check(np.tanh, lims=[-500, 500], order=5)
@jtu.skip_on_devices("tpu")
def test_expit(self): self.unary_check(jax.scipy.special.expit, lims=[-500, 500], order=5)
@jtu.skip_on_devices("tpu")
def test_expit2(self): self.expit_check(lims=[-500, 500], order=5)

@jtu.skip_on_devices("tpu")
def test_div(self): self.binary_check(lambda x, y: x / y, lims=[0.8, 4.0])
Expand Down

0 comments on commit 59bdb1f

Please sign in to comment.