Skip to content

Commit

Permalink
in custom_jvp/vjp stop_gradient on nondiff_argnums (#2804)
Browse files Browse the repository at this point in the history
fixes #2784
  • Loading branch information
mattjj authored Apr 23, 2020
1 parent 6b5e367 commit 8ccb907
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 2 deletions.
20 changes: 18 additions & 2 deletions jax/custom_derivatives.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import itertools as it
import operator as op

import jax
from . import core
from . import linear_util as lu
from .tree_util import tree_flatten, tree_unflatten, tree_map, tree_multimap
Expand Down Expand Up @@ -82,6 +83,15 @@ def sum_tangents(x, *xs):
def zeros_like_pytree(x):
return tree_map(lambda _: zero, x)

def stop_gradient(x):
return tree_map(_stop_gradient, x)

def _stop_gradient(x):
if isinstance(x, core.Tracer) or core.valid_jaxtype(x):
return jax.lax.stop_gradient(x)
else:
return x


### JVPs

Expand Down Expand Up @@ -199,7 +209,10 @@ def __call__(self, *args, **kwargs):
raise AttributeError(msg.format(self.__name__))
args = _resolve_kwargs(self.fun, args, kwargs)
if self.nondiff_argnums:
dyn_argnums = [i for i in range(len(args)) if i not in self.nondiff_argnums]
is_nondiff = [False] * len(args)
for i in self.nondiff_argnums: is_nondiff[i] = True
args = [stop_gradient(x) if b else x for b, x in zip(is_nondiff, args)]
dyn_argnums = [i for i, b in enumerate(is_nondiff) if not b]
f_, dyn_args = argnums_partial(lu.wrap_init(self.fun), dyn_argnums, args)
static_args = [args[i] for i in self.nondiff_argnums]
jvp = _add_args(lu.wrap_init(self.jvp), static_args, left=True)
Expand Down Expand Up @@ -436,7 +449,10 @@ def __call__(self, *args, **kwargs):
raise AttributeError(msg.format(self.__name__))
args = _resolve_kwargs(self.fun, args, kwargs)
if self.nondiff_argnums:
dyn_argnums = [i for i in range(len(args)) if i not in self.nondiff_argnums]
is_nondiff = [False] * len(args)
for i in self.nondiff_argnums: is_nondiff[i] = True
args = [stop_gradient(x) if b else x for b, x in zip(is_nondiff, args)]
dyn_argnums = [i for i, b in enumerate(is_nondiff) if not b]
f_, dyn_args = argnums_partial(lu.wrap_init(self.fun), dyn_argnums, args)
static_args = [args[i] for i in self.nondiff_argnums]
fwd, _ = argnums_partial(lu.wrap_init(self.fwd), dyn_argnums, args)
Expand Down
22 changes: 22 additions & 0 deletions tests/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2853,6 +2853,28 @@ def g(f, x):

jax.grad(g, argnums=(1,))(F(2.0), 0.) # doesn't crash

def test_nondiff_argnums_stop_gradient(self):
# https://github.com/google/jax/issues/2784
@partial(api.custom_vjp, nondiff_argnums=(0, 1))
def _clip_gradient(lo, hi, x):
return x # identity function

def clip_gradient_fwd(lo, hi, x):
# return x, None
return x, (hi, )

def clip_gradient_bwd(lo, hi, _, g):
return (np.clip(g, lo, hi),)

_clip_gradient.defvjp(clip_gradient_fwd, clip_gradient_bwd)

def clip_gradient(x):
lo = -1
hi = x + 1 # causes things to break
return _clip_gradient(lo, hi, x)

jax.grad(clip_gradient)(1.) # doesn't crash


class DeprecatedCustomTransformsTest(jtu.JaxTestCase):

Expand Down

0 comments on commit 8ccb907

Please sign in to comment.