diff --git a/jax/nn/functions.py b/jax/nn/functions.py
index 0d34409a8b09..026e70703f82 100644
--- a/jax/nn/functions.py
+++ b/jax/nn/functions.py
@@ -17,7 +17,7 @@
 
 import numpy as onp
 
-from jax import custom_jvp
+from jax import custom_transforms, defjvp
 from jax import dtypes
 from jax import lax
 from jax.scipy.special import expit
@@ -25,7 +25,7 @@
 
 # activations
 
-@custom_jvp
+@custom_transforms
 def relu(x):
   r"""Rectified linear unit activation function.
 
@@ -35,11 +35,7 @@ def relu(x):
     \mathrm{relu}(x) = \max(x, 0)
   """
   return np.maximum(x, 0)
-def _relu_jvp(primals, tangents):
-  x, = primals
-  t, = tangents
-  return relu(x), lax.select(x > 0, t, lax.full_like(t, 0))
-relu.defjvp(_relu_jvp)
+defjvp(relu, lambda g, ans, x: lax.select(x > 0, g, lax.full_like(g, 0)))
 
 def softplus(x):
   r"""Softplus activation function.
diff --git a/jax/numpy/linalg.py b/jax/numpy/linalg.py
index 728ee83ea1e6..ff0668017c50 100644
--- a/jax/numpy/linalg.py
+++ b/jax/numpy/linalg.py
@@ -20,7 +20,7 @@
 import operator
 from typing import Tuple, Union, cast
 
-from jax import jit, vmap, custom_jvp
+from jax import jit, vmap
 from .. import lax
 from .. import ops
 from .. import lax_linalg
@@ -29,6 +29,7 @@
 from .lax_numpy import _wraps
 from .vectorize import vectorize
 from . import lax_numpy as np
+from ..api import custom_transforms, defjvp
 from ..util import get_module_functions
 from ..third_party.numpy.linalg import cond, tensorinv, tensorsolve
 
@@ -110,8 +111,8 @@ def matrix_rank(M, tol=None):
   return np.sum(S > tol)
 
 
-@custom_jvp
 @_wraps(onp.linalg.slogdet)
+@custom_transforms
 @jit
 def slogdet(a):
   a = _promote_arg_dtypes(np.asarray(a))
@@ -136,15 +137,11 @@ def slogdet(a):
       is_zero, np.array(-np.inf, dtype=dtype),
       np.sum(np.log(np.abs(diag)), axis=-1))
   return sign, np.real(logdet)
-def _slogdet_jvp(primals, tangents):
-  x, = primals
-  g, = tangents
-  if np.issubdtype(np._dtype(x), np.complexfloating):
-    raise NotImplementedError  # TODO(pfau): make this work for complex types
-  sign, ans = slogdet(x)
-  sign_dot, ans_dot = np.zeros_like(sign), np.trace(solve(x, g), axis1=-1, axis2=-2)
-  return (sign, ans), (sign_dot, ans_dot)
-slogdet.defjvp(_slogdet_jvp)
+def _jvp_slogdet(g, ans, x):
+  jvp_sign = np.zeros(x.shape[:-2])
+  jvp_logdet = np.trace(solve(x, g), axis1=-1, axis2=-2)
+  return jvp_sign, jvp_logdet
+defjvp(slogdet, _jvp_slogdet)
 
 
 @_wraps(onp.linalg.det)
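The old-style `defjvp(fun, *jvprules)` convention restored above takes one rule per positional argument; each rule receives the input tangent `g`, the primal output `ans`, and the primal inputs, and returns the output tangent directly. For `slogdet` this encodes Jacobi's formula, d log|det A| = tr(A^{-1} dA), via `np.trace(solve(x, g), axis1=-1, axis2=-2)`. Below is a minimal standalone sketch of the same pattern (not part of the patch), assuming a JAX version that still exports `jax.custom_transforms` and `jax.defjvp`:

import jax
import jax.numpy as np
from jax import custom_transforms, defjvp, lax

@custom_transforms
def relu(x):
  return np.maximum(x, 0)

# One rule for the single argument x: g is x's tangent, ans the primal output.
defjvp(relu, lambda g, ans, x: lax.select(x > 0, g, lax.full_like(g, 0)))

print(jax.jvp(relu, (2.,), (1.,)))  # (2.0, 1.0): tangent passes through
print(jax.grad(relu)(-2.))          # 0.0: the same rule drives reverse mode
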
diff --git a/jax/scipy/special.py b/jax/scipy/special.py
index cd29c118f5f0..aa11c5df9233 100644
--- a/jax/scipy/special.py
+++ b/jax/scipy/special.py
@@ -20,6 +20,7 @@
 from .. import util
 from .. import lax
 from .. import api
+from ..api import custom_transforms, defjvp
 from ..numpy import lax_numpy as jnp
 from ..numpy.lax_numpy import (_wraps, asarray, _reduction_dims, _constant_like,
                                _promote_args_inexact)
@@ -79,26 +80,21 @@ def erfinv(x):
   return lax.erf_inv(x)
 
 
-@api.custom_jvp
+@_wraps(osp_special.logit, update_doc=False)
+@custom_transforms
 def logit(x):
+  x = asarray(x)
   return lax.log(lax.div(x, lax.sub(lax._const(x, 1), x)))
-def _logit_jvp(primals, tangents):
-  (x,), (t,) = primals, tangents
-  ans = logit(x)
-  t_out = lax.div(t, lax.mul(x, lax.sub(lax._const(x, 1), x)))
-  return ans, t_out
-logit.defjvp(_logit_jvp)
+defjvp(logit, lambda g, ans, x: g / (x * (1 - x)))
 
 
-@api.custom_jvp
+@_wraps(osp_special.expit, update_doc=False)
+@custom_transforms
 def expit(x):
-  return 1 / (1 + lax.exp(-x))
-def _expit_jvp(primals, tangents):
-  (x,), (t,) = primals, tangents
-  ans = expit(x)
-  t_out = t * ans * (1 - ans)
-  return ans, t_out
-expit.defjvp(_expit_jvp)
+  x = asarray(x)
+  one = lax._const(x, 1)
+  return lax.div(one, lax.add(one, lax.exp(lax.neg(x))))
+defjvp(expit, lambda g, ans, x: g * ans * (lax._const(ans, 1) - ans))
 
 
 @_wraps(osp_special.logsumexp)
@@ -411,7 +407,7 @@ def _create_polynomial(var, coeffs):
   return x_nan_replaced
 
 
-@partial(api.custom_jvp, nondiff_argnums=(1,))
+@custom_transforms
 def log_ndtr(x, series_order=3):
   r"""Log Normal distribution function.
 
@@ -512,13 +508,7 @@ def log_ndtr(x, series_order=3):
                      lax.log(_ndtr(lax.max(x, lower_segment))),
                      _log_ndtr_lower(lax.min(x, lower_segment), series_order)))
-
-def _log_ndtr_jvp(series_order, primals, tangents):
-  (x,), (t,) = primals, tangents
-  ans = log_ndtr(x, series_order=series_order)
-  t_out = lax.mul(t, lax.exp(lax.sub(_norm_logpdf(x), ans)))
-  return ans, t_out
-log_ndtr.defjvp(_log_ndtr_jvp)
+defjvp(log_ndtr, lambda g, ans, x: lax.mul(g, lax.exp(lax.sub(_norm_logpdf(x), ans))))
 
 
 def _log_ndtr_lower(x, series_order):
   """Asymptotic expansion version of `Log[cdf(x)]`, appropriate for `x<<-1`."""
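The `scipy.special` rules follow the same convention: the `expit` rule encodes d/dx sigmoid(x) = ans * (1 - ans) in terms of the primal output, and the `log_ndtr` rule encodes d/dx log Phi(x) = exp(norm_logpdf(x) - log_ndtr(x)). A quick numerical sanity check of the `expit` rule, again a sketch assuming the pre-`custom_jvp` API rather than part of the patch:

import jax
import jax.numpy as np
from jax import custom_transforms, defjvp

@custom_transforms
def expit(x):
  return 1. / (1. + np.exp(-x))

# Derivative of the logistic sigmoid, written in terms of its own output ans.
defjvp(expit, lambda g, ans, x: g * ans * (1 - ans))

ans, tangent = jax.jvp(expit, (0.3,), (1.0,))
# Compare against a centered finite-difference estimate.
eps = 1e-4
fd = (expit(0.3 + eps) - expit(0.3 - eps)) / (2 * eps)
print(tangent, fd)  # both approximately 0.2445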