Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add custom_jvp / vjp, delete custom_transforms
Browse files Browse the repository at this point in the history
mattjj committed Feb 23, 2020
1 parent 8c3e3b2 commit b11c0b2
Showing 20 changed files with 1,294 additions and 1,493 deletions.
5 changes: 0 additions & 5 deletions docs/jax.rst
Original file line number Diff line number Diff line change
@@ -39,11 +39,6 @@ Automatic differentiation
.. autofunction:: jvp
.. autofunction:: linearize
.. autofunction:: vjp
.. autofunction:: custom_transforms
.. autofunction:: defjvp
.. autofunction:: defjvp_all
.. autofunction:: defvjp
.. autofunction:: defvjp_all
.. autofunction:: custom_gradient


2 changes: 1 addition & 1 deletion docs/notebooks/How_JAX_primitives_work.ipynb
Original file line number Diff line number Diff line change
@@ -1503,7 +1503,7 @@
" File \"/usr/local/lib/python3.6/dist-packages/jax/api.py\", line 611, in batched_fun\n",
" lambda: _flatten_axes(out_tree(), out_axes))\n",
" File \"/usr/local/lib/python3.6/dist-packages/jax/interpreters/batching.py\", line 41, in batch\n",
" out_vals, out_dims = batch_fun(fun, in_vals, in_dims)\n",
" out_vals, out_dims = batch2(fun, in_vals, in_dims)\n",
"NotImplementedError: Batching rule for 'multiply_add' not implemented\n"
],
"name": "stderr"
576 changes: 13 additions & 563 deletions jax/api.py

Large diffs are not rendered by default.

29 changes: 28 additions & 1 deletion jax/api_util.py
Original file line number Diff line number Diff line change
@@ -16,7 +16,8 @@
from .tree_util import (build_tree, tree_flatten, tree_unflatten,
treedef_is_leaf)
from . import linear_util as lu
from .util import safe_map, unzip2, partial, curry
from .util import safe_map, unzip2, partial, curry, WrapHashably, Hashable
from .core import unit

map = safe_map

@@ -70,3 +71,29 @@ def flatten_fun_nokwargs2(in_tree, *args_flat):
ans_flat, ans_tree = tree_flatten(ans)
aux_flat, aux_tree = tree_flatten(aux)
yield (ans_flat, aux_flat), (ans_tree, aux_tree)

def argnums_partial(f, dyn_argnums, args):
if isinstance(dyn_argnums, int):
dyn_argnums = (dyn_argnums,)
else:
dyn_argnums = tuple(dyn_argnums)
fixed_args = tuple([unit if i in dyn_argnums else wrap_hashably(arg)
for i, arg in enumerate(args)])
dyn_args = tuple(args[i] for i in dyn_argnums)
return _argnums_partial(f, dyn_argnums, fixed_args), dyn_args

def wrap_hashably(arg):
try:
hash(arg)
except TypeError:
return WrapHashably(arg) # e.g. ndarrays, DeviceArrays
else:
return Hashable(arg)

@lu.transformation
def _argnums_partial(dyn_argnums, fixed_args, *dyn_args, **kwargs):
args = [None if arg is unit else arg.val for arg in fixed_args]
for i, arg in zip(dyn_argnums, dyn_args):
args[i] = arg
ans = yield args, kwargs
yield ans
41 changes: 22 additions & 19 deletions jax/core.py
Original file line number Diff line number Diff line change
@@ -23,7 +23,7 @@
import types

from . import linear_util as lu
from .util import safe_zip, safe_map, partial, curry
from .util import safe_zip, safe_map, partial, curry, split_list
from .pprint_util import pp, vcat, hcat, pp_kv_pairs

# TODO(dougalm): the trace cache breaks the leak detector. Consisder solving.
@@ -280,13 +280,6 @@ def __init__(self, master, sublevel):
self.level = master.level
self.sublevel = sublevel

def escaped_tracer_error(self, detail):
msg = ("Encountered an unexpected tracer. Perhaps this tracer escaped "
"through global state from a previously traced function.\n"
"The functions being transformed should not save traced values to "
"global state.\nDetails: {}.")
raise ValueError(msg.format(detail))

def full_raise(self, val):
if not isinstance(val, Tracer):
return self.pure(val)
@@ -298,44 +291,54 @@ def full_raise(self, val):
elif val._trace.sublevel < sublevel:
return self.sublift(val)
else:
self.escaped_tracer_error(
escaped_tracer_error(
"Can't lift sublevels {} to {}".format(val._trace.sublevel, sublevel))
elif val._trace.level < level:
if val._trace.sublevel > sublevel:
self.escaped_tracer_error(
escaped_tracer_error(
"Incompatible sublevel: {}, {}".format(val._trace, (level, sublevel)))
return self.lift(val)
elif val._trace.level > level:
self.escaped_tracer_error(
escaped_tracer_error(
"Can't lift level {} to {}".format(val, self))
else: # val._trace.level == self.level:
self.escaped_tracer_error("Different traces at same level: {}, {}".format(val, self))

escaped_tracer_error("Different traces at same level: {}, {}".format(val, self))

def pure(self, val):
assert False
raise NotImplementedError("must override")

def lift(self, tracer):
assert False
raise NotImplementedError("must override")

def sublift(self, tracer):
assert False
raise NotImplementedError("must override")

def process_primitive(self, primitive, tracers, params):
assert False, "Must override"
raise NotImplementedError("must override")

def __repr__(self):
return '{}(level={}/{})'.format(
self.__class__.__name__, self.level, self.sublevel)

def escaped_tracer_error(detail):
msg = ("Encountered an unexpected tracer. Perhaps this tracer escaped "
"through global state from a previously traced function.\n"
"The functions being transformed should not save traced values to "
"global state.\nDetails: {}.")
raise UnexpectedTracerError(msg.format(detail))

class UnexpectedTracerError(Exception): pass

class Tracer(object):
__array_priority__ = 1000
__slots__ = ['_trace', '__weakref__']

def __array__(self, *args, **kw):
raise Exception("Tracer can't be used with raw numpy functions. "
"You might have\n import numpy as np\ninstead of\n import jax.numpy as np")
"You might have\n"
" import numpy as np\n"
"instead of\n"
" import jax.numpy as np")

def __init__(self, trace):
self._trace = trace
@@ -348,7 +351,7 @@ def __len__(self):

@property
def aval(self):
assert False
raise NotImplementedError("must override")

def __neg__(self): return self.aval._neg(self)
def __pos__(self): return self.aval._pos(self)
Loading

0 comments on commit b11c0b2

Please sign in to comment.