
Commit

temporarily revert parts of jax-ml#2026 pending bug fix
mattjj authored and NeilGirdhar committed Apr 13, 2020
1 parent 6a796c4 commit 458d71e
Showing 3 changed files with 24 additions and 41 deletions.
10 changes: 3 additions & 7 deletions jax/nn/functions.py
@@ -17,15 +17,15 @@

 import numpy as onp

-from jax import custom_jvp
+from jax import custom_transforms, defjvp
 from jax import dtypes
 from jax import lax
 from jax.scipy.special import expit
 import jax.numpy as np

 # activations

-@custom_jvp
+@custom_transforms
 def relu(x):
   r"""Rectified linear unit activation function.
@@ -35,11 +35,7 @@ def relu(x):
     \mathrm{relu}(x) = \max(x, 0)
   """
   return np.maximum(x, 0)
-def _relu_jvp(primals, tangents):
-  x, = primals
-  t, = tangents
-  return relu(x), lax.select(x > 0, t, lax.full_like(t, 0))
-relu.defjvp(_relu_jvp)
+defjvp(relu, lambda g, ans, x: lax.select(x > 0, g, lax.full_like(g, 0)))

 def softplus(x):
   r"""Softplus activation function.
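
For context, the two JVP-definition styles this commit switches between can be compared side by side. The sketch below is illustrative only (relu_new and relu_old are made-up names, not part of the commit) and targets a JAX version contemporaneous with this commit, since custom_transforms and the module-level defjvp were removed in later releases.

import jax
import jax.numpy as np
from jax import custom_jvp, custom_transforms, defjvp, lax

# Newer custom_jvp API (the one being reverted here): the rule receives packed
# primals and tangents and returns (primal_out, tangent_out).
@custom_jvp
def relu_new(x):
  return np.maximum(x, 0)

def _relu_jvp(primals, tangents):
  x, = primals
  t, = tangents
  return relu_new(x), lax.select(x > 0, t, lax.full_like(t, 0))
relu_new.defjvp(_relu_jvp)

# Older custom_transforms API (the one being restored): defjvp takes one rule
# per positional argument, each with signature (g, ans, *primals).
@custom_transforms
def relu_old(x):
  return np.maximum(x, 0)
defjvp(relu_old, lambda g, ans, x: lax.select(x > 0, g, lax.full_like(g, 0)))

# Either way, jax.jvp(relu_old, (2.0,), (1.0,)) evaluates to (2.0, 1.0).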
19 changes: 8 additions & 11 deletions jax/numpy/linalg.py
@@ -20,7 +20,7 @@
 import operator
 from typing import Tuple, Union, cast

-from jax import jit, vmap, custom_jvp
+from jax import jit, vmap
 from .. import lax
 from .. import ops
 from .. import lax_linalg
@@ -29,6 +29,7 @@
 from .lax_numpy import _wraps
 from .vectorize import vectorize
 from . import lax_numpy as np
+from ..api import custom_transforms, defjvp
 from ..util import get_module_functions
 from ..third_party.numpy.linalg import cond, tensorinv, tensorsolve

@@ -110,8 +111,8 @@ def matrix_rank(M, tol=None):
   return np.sum(S > tol)


-@custom_jvp
 @_wraps(onp.linalg.slogdet)
+@custom_transforms
 @jit
 def slogdet(a):
   a = _promote_arg_dtypes(np.asarray(a))
@@ -136,15 +137,11 @@ def slogdet(a):
       is_zero, np.array(-np.inf, dtype=dtype),
       np.sum(np.log(np.abs(diag)), axis=-1))
   return sign, np.real(logdet)
-def _slogdet_jvp(primals, tangents):
-  x, = primals
-  g, = tangents
-  if np.issubdtype(np._dtype(x), np.complexfloating):
-    raise NotImplementedError  # TODO(pfau): make this work for complex types
-  sign, ans = slogdet(x)
-  sign_dot, ans_dot = np.zeros_like(sign), np.trace(solve(x, g), axis1=-1, axis2=-2)
-  return (sign, ans), (sign_dot, ans_dot)
-slogdet.defjvp(_slogdet_jvp)
+def _jvp_slogdet(g, ans, x):
+  jvp_sign = np.zeros(x.shape[:-2])
+  jvp_logdet = np.trace(solve(x, g), axis1=-1, axis2=-2)
+  return jvp_sign, jvp_logdet
+defjvp(slogdet, _jvp_slogdet)


 @_wraps(onp.linalg.det)
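
The restored _jvp_slogdet encodes the identity d log|det A| = tr(A^-1 dA). Below is a quick numerical sanity check, written against the current public JAX API (jax.jvp and jax.numpy.linalg) rather than the internal defjvp wiring above; it is a sketch for illustration, not part of the commit.

import jax
import jax.numpy as jnp

key1, key2 = jax.random.split(jax.random.PRNGKey(0))
a = jax.random.normal(key1, (4, 4))   # random square matrix
g = jax.random.normal(key2, (4, 4))   # random tangent direction

# Tangent of logdet through autodiff ...
_, (_, logdet_dot) = jax.jvp(jnp.linalg.slogdet, (a,), (g,))
# ... versus the closed form tr(a^-1 g) that the rule above computes.
expected = jnp.trace(jnp.linalg.solve(a, g))
print(jnp.allclose(logdet_dot, expected, rtol=1e-3, atol=1e-3))  # True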
36 changes: 13 additions & 23 deletions jax/scipy/special.py
@@ -20,6 +20,7 @@
 from .. import util
 from .. import lax
 from .. import api
+from ..api import custom_transforms, defjvp
 from ..numpy import lax_numpy as jnp
 from ..numpy.lax_numpy import (_wraps, asarray, _reduction_dims, _constant_like,
                                _promote_args_inexact)
@@ -79,26 +80,21 @@ def erfinv(x):
   return lax.erf_inv(x)


-@api.custom_jvp
 @_wraps(osp_special.logit, update_doc=False)
+@custom_transforms
 def logit(x):
   x = asarray(x)
   return lax.log(lax.div(x, lax.sub(lax._const(x, 1), x)))
-def _logit_jvp(primals, tangents):
-  (x,), (t,) = primals, tangents
-  ans = logit(x)
-  t_out = lax.div(t, lax.mul(x, lax.sub(lax._const(x, 1), x)))
-  return ans, t_out
-logit.defjvp(_logit_jvp)
+defjvp(logit, lambda g, ans, x: g / (x * (1 - x)))


-@api.custom_jvp
 @_wraps(osp_special.expit, update_doc=False)
+@custom_transforms
 def expit(x):
-  return 1 / (1 + lax.exp(-x))
-def _expit_jvp(primals, tangents):
-  (x,), (t,) = primals, tangents
-  ans = expit(x)
-  t_out = t * ans * (1 - ans)
-  return ans, t_out
-expit.defjvp(_expit_jvp)
+  x = asarray(x)
+  one = lax._const(x, 1)
+  return lax.div(one, lax.add(one, lax.exp(lax.neg(x))))
+defjvp(expit, lambda g, ans, x: g * ans * (lax._const(ans, 1) - ans))


 @_wraps(osp_special.logsumexp)
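
The restored rules state d/dx expit(x) = expit(x) * (1 - expit(x)) and d/dx logit(x) = 1 / (x * (1 - x)). Here is a short check via jax.grad, using the public jax.scipy.special functions rather than the internal names in the diff; it is a sketch for illustration only.

import jax
import jax.numpy as jnp
from jax.scipy.special import expit, logit

x = 0.3
# Sigmoid derivative: expit'(x) == expit(x) * (1 - expit(x))
print(jnp.allclose(jax.grad(expit)(x), expit(x) * (1 - expit(x)), rtol=1e-4))  # True
# Logit derivative: logit'(x) == 1 / (x * (1 - x))
print(jnp.allclose(jax.grad(logit)(x), 1.0 / (x * (1 - x)), rtol=1e-4))        # True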
@@ -411,7 +407,7 @@ def _create_polynomial(var, coeffs):
   return x_nan_replaced


-@partial(api.custom_jvp, nondiff_argnums=(1,))
+@custom_transforms
 def log_ndtr(x, series_order=3):
   r"""Log Normal distribution function.
@@ -512,13 +508,7 @@ def log_ndtr(x, series_order=3):
                   lax.log(_ndtr(lax.max(x, lower_segment))),
                   _log_ndtr_lower(lax.min(x, lower_segment),
                                   series_order)))
-
-def _log_ndtr_jvp(series_order, primals, tangents):
-  (x,), (t,) = primals, tangents
-  ans = log_ndtr(x, series_order=series_order)
-  t_out = lax.mul(t, lax.exp(lax.sub(_norm_logpdf(x), ans)))
-  return ans, t_out
-log_ndtr.defjvp(_log_ndtr_jvp)
+defjvp(log_ndtr, lambda g, ans, x: lax.mul(g, lax.exp(lax.sub(_norm_logpdf(x), ans))))

 def _log_ndtr_lower(x, series_order):
   """Asymptotic expansion version of `Log[cdf(x)]`, appropriate for `x<<-1`."""
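
The restored log_ndtr rule uses d/dx log Phi(x) = exp(norm_logpdf(x) - log_ndtr(x)); staying in log space keeps the ratio phi(x) / Phi(x) stable far into the left tail. A sanity check written against the current public jax.scipy.special.log_ndtr and jax.scipy.stats.norm (an assumption, not the internal helpers shown in the diff):

import jax
import jax.numpy as jnp
from jax.scipy.special import log_ndtr
from jax.scipy.stats import norm

for x in (-5.0, -1.0, 0.0, 2.0):
  lhs = jax.grad(log_ndtr)(x)                   # autodiff through log_ndtr
  rhs = jnp.exp(norm.logpdf(x) - log_ndtr(x))   # closed-form ratio phi / Phi
  print(jnp.allclose(lhs, rhs, rtol=1e-3))      # expected True at each point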
