diff --git a/.travis.yml b/.travis.yml
index 5bdc183678f2..8df31c498b19 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -44,10 +44,8 @@ script:
- if [ "$JAX_ONLY_DOCUMENTATION" = true ]; then
sphinx-build -b html -D nbsphinx_execute=always docs docs/build/html ;
elif [ "$JAX_ONLY_CHECK_TYPES" = true ]; then
- echo "===== Checking with mypy ====" &&
- time mypy --config-file=mypy.ini jax &&
- echo "===== Checking with pytype ====" &&
- time pytype jax ;
+ echo "===== Checking with mypy ===="
+ time mypy --config-file=mypy.ini jax ;
else
pytest -n 1 tests examples -W ignore ;
fi
diff --git a/README.md b/README.md
index 59d21f4ebb33..0f639cc4237f 100644
--- a/README.md
+++ b/README.md
@@ -136,7 +136,7 @@ print(grad(grad(grad(tanh)))(1.0))
For more advanced autodiff, you can use
[`jax.vjp`](https://jax.readthedocs.io/en/latest/jax.html#jax.vjp) for
reverse-mode vector-Jacobian products and
-[`jax.jvp`](https://jax.readthedocs.io/en/latest/jax.html#jax.defjvp) for
+[`jax.jvp`](https://jax.readthedocs.io/en/latest/jax.html#jax.jvp) for
forward-mode Jacobian-vector products. The two can be composed arbitrarily with
one another, and with other JAX transformations. Here's one way to compose those
to make a function that efficiently computes [full Hessian
diff --git a/design_notes/custom_derivatives.md b/design_notes/custom_derivatives.md
new file mode 100644
index 000000000000..41fb6df644d3
--- /dev/null
+++ b/design_notes/custom_derivatives.md
@@ -0,0 +1,472 @@
+# Custom JVP/VJP rules for JAX-transformable functions
+
+This is a design document, explaining some of the thinking behind the design and
+implementation of `jax.custom_jvp` and `jax.custom_vjp`. For user-oriented
+documentation, see [the tutorial notebook](https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html).
+
+There are two ways to define differentiation rules in JAX:
+1. using `jax.custom_jvp` and `jax.custom_vjp` to define custom differentiation
+ rules for Python functions that are already JAX-transformable; and
+2. defining new `core.Primitive` instances along with all their transformation
+ rules, for example to call into functions from other systems like solvers,
+ simulators, or general numerical computing systems.
+
+This document is about #1 only.
+
+### Contents
+
+* [Goals](#goals)
+* [Non-goals](#non-goals)
+* [Main problem descriptions](#main-problem-descriptions)
+ * [The vmap-removes-custom-jvp semantics problem](#the-vmap-removes-custom-jvp-semantics-problem)
+ * [The Python flexibility problem](#the-python-flexibility-problem)
+* [Solution idea](#solution-idea)
+* [Implementation notes](#implementation-notes)
+
+## Goals
+
+We want **users** to customize the forward- and/or reverse-mode differentiation
+behavior of their code. This customization
+1. should have a _clear and consistent semantics_ in how it works and how it
+ composes with other JAX transformations; and
+2. should be _flexible_ in supporting use cases and workflows like in
+ [Autograd](https://github.com/hips/autograd) and
+ [PyTorch](https://pytorch.org), including cases involving differentiation of
+ Python control flow and workflows for NaN debugging.
+
+As **JAX developers** we want to write library functions, like
+[`logit`](https://github.com/google/jax/blob/01039299304b148b405ef9b9fa5e82bbb527471d/jax/scipy/special.py#L83)
+and
+[`expit`](https://github.com/google/jax/blob/01039299304b148b405ef9b9fa5e82bbb527471d/jax/scipy/special.py#L91),
+that are defined in terms of other primitives, but for the purposes of
+differentiation have primitive-like behavior in the sense that we want to define
+custom differentiation rules for them, which may be more numerically stable or
+performant. In particular, we don't want to have to specify `vmap` or `jit`
+rules for functions like `logit` and `expit`.
+
+As a stretch goal, we’d like to make JAX a great environment for power users
+looking to add custom differentiation rules for higher-order functions like
+`fixed_point`, `odeint`, etc.; this design doc won’t solve that problem, but we
+want to be confident we’re not going to preclude good solutions to that problem.
+
+That is, our primary goals are
+1. solve the vmap-removes-custom-jvp semantics problem ([#1249](https://github.com/google/jax/issues/1249)), and
+2. allow Python in custom VJPs, e.g. to debug NaNs
+ ([#1275](https://github.com/google/jax/issues/1275)).
+
+Secondary goals are
+3. clean up and simplify user experience (symbolic zeros, kwargs, etc)
+4. make progress towards a world where users can easily add `fixed_point`,
+ `odeint`, `root`, etc.
+
+Overall, we want to close
+[#116](https://github.com/google/jax/issues/116),
+[#1097](https://github.com/google/jax/issues/1097),
+[#1249](https://github.com/google/jax/issues/1249),
+[#1275](https://github.com/google/jax/issues/1275),
+[#1366](https://github.com/google/jax/issues/1366),
+[#1723](https://github.com/google/jax/issues/1723),
+[#1670](https://github.com/google/jax/issues/1670),
+[#1875](https://github.com/google/jax/issues/1875),
+[#1938](https://github.com/google/jax/issues/1938),
+and replace the custom_transforms machinery (from
+[#636](https://github.com/google/jax/issues/636),
+[#818](https://github.com/google/jax/issues/818),
+and others).
+
+## Non-goals
+
+Here are objectives we're **not** aiming to achieve:
+1. The `custom_transforms` machinery aimed to provide a transformation-generic
+ mechanism for customizing behavior, in principle (though never really used in
+ practice) allowing users to customize rules for any transformation while
+ somehow inheriting the “transparent” behavior for others. **We are instead
+ only going to solve the customization problem for differentiation (JVP and
+ VJP, separately).** Differentiation is the only case actually requested, and
+ by specializing to differentiation we can reduce complexity and improve
+ flexibility. To control all rules one can just write a primitive.
+2. **We’re not going to prioritize mathematical aesthetics** over flexibility
+ and clarity on the user side, and simplicity on the implementation side. In
+ particular, while the custom VJP signature `a -> (b, CT b --o CT a)` is
+ mathematically pleasing, if it’s hard to implement in a Python mechanism
+ because of the closure in the return type, we’re fine doing something that
+ handles residuals more explicitly.
+3. **Serialization support**, of the form where the staged-out serialized
+ program representation can be loaded and further JAX-transformed as opposed
+ to just evaluated, is currently out of scope for these custom JVP/VJP
+ transformation rules. Serialization may be useful not only for researchers
+ who want to save some representation of their computation (and transform it
+ after loading it), but also for future considerations like having jaxpr
+ transformations implemented outside Python, or having jaxprs as an MLIR
+ dialect. By defining this as a non-goal for the purpose of this design, we
+ have fewer constraints on where we can stash Python callables.
+
+## Main problem descriptions
+
+### The vmap-removes-custom-jvp semantics problem
+
+The vmap-removes-custom-jvp semantics problem is that vmap does not compose
+properly with differentiation of functions with `custom_transforms` rules:
+
+```python
+# old custom_transforms api to be replaced
+@jax.custom_transforms
+def f(x):
+ return 2. * x
+
+# f_vjp :: a -> (b, CT b --o CT a)
+def f_vjp(x):
+ return f(x), lambda g: 3. * x # 3 instead of 2
+
+jax.defvjp_all(f, f_vjp)
+
+grad(f)(1.) # 3.
+vmap(grad(f))(np.ones(4)) # [3., 3., 3., 3.]
+grad(lambda x: vmap(f)(x).sum())(np.ones(4)) # [2., 2., 2., 2.]
+```
+
+The last grad-of-vmap line has an unexpected result! In general, applying
+`vmap`, or really any non-differentiation transformation, has the effect of
+removing the custom differentiation rule. (Applying `jvp` causes a failure when
+a custom VJP rule is defined.)
+
+The problem exists because transformations are like rewrites, and the `vmap`
+transformation effectively rewrites the function to no longer call the
+newly-introduced primitive for which there is a custom rule (and hence `grad`
+then doesn’t produce the custom rule’s result). In more detail, the
+`custom_transforms` machinery sets things up so that evaluating `f(x)` applies
+the function
+
+```
+{ lambda ; ; a.
+ let b = f_primitive a
+ in [b] }
+```
+
+where `f_primitive` is a new primitive (introduced for every `custom_transforms`
+function and in fact for every call of the function) to which the custom VJP
+rule is associated. When we evaluate `grad(f)(x)`, the differentiation machinery
+encounters `f_primitive` and processes it with the custom rule.
+
+However, because `f_primitive` is _transparent_ to `vmap`, in the sense that
+`vmap` operates on (effectively by inlining) the definition of `f_primitive`,
+the function `vmap(f)` is effectively
+
+```
+{ lambda ; ; a.
+ let b = mul 2. a
+ in [b] }
+```
+
+In words, `vmap` rewrites the function in terms of its underlying primitives and
+their transformation rules, removing `f_primitive` entirely.
+
+
+More generally, **because `vmap(f)` has semantics defined in terms of calls to
+f, it is semantically inconsistent to remove the custom derivative rule**. That
+is, since we define
+
+```python
+vmap(f)(xs) == np.stack([f(x) for x in xs])
+```
+
+we must have
+
+```python
+jvp(vmap(f))(xs) == jvp(lambda xs: np.stack([f(x) for x in xs]))
+```
+
+yet this property is not observed when `f` has a custom derivative rule defined,
+as the custom derivative rule is used in the right-hand version but not the
+left-hand one.
+
+This issue isn’t specific to `vmap`; it applies to all transformations for which
+the semantics of transforming a function `f` are defined in terms of calls to
+the function `f`, rather than rewriting it into another function. The `mask`
+transformation also falls into this class. Differentiation transforms and the
+hypothetical all-unary-functions-become-cosine transform are not in this class.
+
+(The interaction between additional custom rules, like custom `vmap` rules, is
+likely to get even more complex, suggesting the problem framing of
+`custom_transforms` is too broad.)
+
+### The Python flexibility problem
+
+In JAX, as in [Autograd](https://github.com/hips/autograd) and
+[PyTorch](https://pytorch.org) but not TF1, differentiation of a Python function
+is performed while the function is being executed and traced. This behavior
+delights users for a few reasons.
+
+**First and most importantly, it enables pdb-based workflows, e.g. for
+inspecting numerics or catching NaNs.** That is, users can employ the standard
+Python debugger and other Python-native tools to debug their code, even being
+able to inspect runtime values to understand numerical behavior on examples and
+to catch fundamentally runtime errors like NaNs. In fact, just while working on
+the PR corresponding to this design, especially on the `odeint` primitive, I
+used runtime value inspection to debug issues many times, increasing my
+confidence that this is a key user workflow in Python. One especially handy
+trick, which I’ve used in both JAX and Autograd many times, is the ability to
+insert a debugger breakpoint in a custom VJP rule to enter a debugger at a
+specific point in the backward pass.
+
+**Second, it allows differentiation of Python native control flow.** We’re not
+sure how often this is used in practice in finalized software artifacts, but
+when users first poke around JAX or Autograd they’re often impressed by this
+freedom. There’s a reason we include it at the top of our JAX and Autograd
+READMEs, slide decks, and demos. Ceding this capability would be a step backward
+from Autograd. We want JAX to have the best automatic differentiation.
+
+However, the `custom_transforms` machinery does not provide this Python-support
+flexibility. That is, because it’s implemented in terms of up-front jaxpr
+formation from the Python code for both the user function and custom
+differentiation rules, code like this leads to an abstract value tracing error:
+
+```python
+# old custom_transforms api to be replaced
+@jax.custom_transforms
+def f(x):
+ if x > 0:
+ return x
+ else:
+ return 0.
+
+def f_vjp(x):
+ return ...
+
+jax.defvjp_all(f, f_vjp)
+
+grad(f)(1.) # Error!
+```
+
+## Solution idea
+
+The main idea is that **[dougalm@](https://github.com/dougalm) already solved
+these problems with `core.call`**. That is, we can frame the task of specifying
+a custom JVP rule for a user function in terms of a new Python-level call
+primitive (not to be added to the jaxpr language; see below). This new call
+primitive has a user Python function associated with it just like `core.call`,
+but additionally has a second Python callable representing the JVP rule. Let’s
+refer to this new call primitive as `custom_jvp_call`.
+
+Transformations like `vmap` interact with `custom_jvp_call` as with `core.call`:
+they effectively pass right through it and are applied to the underlying Python
+callables. Schematically, writing in terms of curried versions of the primitives
+for convenience, analogously to how `vmap` interacts with `core.call` by
+applying to the function to be called:
+
+```python
+vmap(call(f)) == call(vmap(f))
+```
+
+for the new primitive `custom_jvp_call` we simply apply `vmap` to the two
+functions it entails:
+
+```python
+vmap(custom_jvp_call(f, f_jvp)) == custom_jvp_call(vmap(f), vmap(f_jvp))
+```
+
+This behavior means we’ve solved the [vmap-removes-custom-jvp semantics
+problem](the-vmap-removes-custom-jvp-semantics-problem).
+
+The `jvp` transformation interacts as one might expect: it just calls `f_jvp`,
+
+
+```python
+jvp(call(f)) == call(jvp(f))
+
+jvp(custom_jvp_call(f, f_jvp)) == f_jvp
+```
+
+Because `custom_jvp_call` acts like `core.call` (and not like `xla.xla_call`) in
+that it doesn’t raise the abstraction level of its inputs (because it’s not
+delaying anything or staging anything out), it means we’ve solved [the Python
+flexibility problem](the-python-flexibility-problem): there are no constraints
+on the user Python function (above the usual functional programming constraints
+required by `jvp` or `vjp`).
+
+What about evaluation and compilation? These are two ways to “exit” the JAX
+system, in the sense that no additional transformations can be applied after
+these steps. As a result, their rules are trivial:
+
+```python
+eval(call(f)) == eval(f)
+jit(call(f)) == hlo_call(jit(f))
+
+eval(custom_jvp_call(f, f_jvp)) == eval(f)
+jit(custom_jvp_call(f, f_jvp)) == hlo_call(jit(f))
+```
+
+In words, if a JVP rule hasn’t already rewritten `custom_jvp_call(f, f_jvp)`
+into `f_jvp`, when we get to the point of evaluation with `eval` or staging out
+to XLA with `jit`, differentiation is never going to be applied, so we just
+ignore `f_jvp` and behave just like `core.call`. However, due to the wrinkle
+discussed next, the partial eval rule for `custom_jvp_call` must be a bit more
+complex, since partial evaluation isn’t just used to stage out to XLA with
+`jit`.
+
+The only remaining wrinkle has to do with “initial-style” jaxpr-forming
+primitives, like `lax.scan`, and their transformation rules. These represent a
+different kind of “staging out to a jaxpr” than that for compilation because we
+can perform additional transformations on the staged-out jaxpr. That is, when
+`lax.scan` forms a jaxpr, it does not exit the transformation system, since when
+we apply a jvp or vmap to a `lax.scan` we need to apply it to the function
+represented by the jaxpr.
+
+Another way to state the wrinkle is that initial-style primitives like `lax.scan`
+rely on the ability to round-trip to a jaxpr and back to a Python callable while
+preserving semantics. That must mean preserving custom differentiation rule
+semantics too.
+
+The solution is for the partial evaluation rule for `custom_jvp_call` to stage
+out an initial-style call-like primitive that can be still be processed
+correctly by `eval`, `jit`, `jvp` and/or `vmap` transformations. That means a
+staged-out call-like primitive that carries with it enough information about `f`
+and `f_jvp` to support all these transformations. We refer to this additional
+primitive as `custom_jvp_call_jaxpr`. It is similar to `custom_jvp_call` except
+it’s parameterized by a jaxpr for the primal function f rather than a Python
+callable. The jaxpr for `f` is formed up-front before binding the primitive,
+similar to other initial-style primitives.
+
+(Three footnotes. First, we could refer to both the Python trace-time primitive
+`custom_jvp_call`, which takes a wrapped Python callable as an argument, and the
+jaxpr language primitive `custom_jvp_call_jaxpr`, which has a jaxpr as a
+parameter, as simply "`custom_jvp_call`", analogously to how we refer to both
+versions of `xla_call` as just "`xla_call`", but here we chose to use different
+names to make the distinction more explicit. Second, for implementation
+simplicity, both `custom_jvp_call` and `custom_jvp_call_jaxpr` have partial eval
+rules that don’t do any nontrivial partial evaluation and instead stage
+everything out. That doesn’t constrain automatic differentiation because
+`custom_jvp_call_jaxpr`'s JVP rule doesn’t itself bind a call primitive but
+instead just invokes the custom JVP rule callable. Third, we don’t form a jaxpr
+for the JVP rule callable up-front, and instead keep it as a Python callable, to
+avoid a recursion problem: in the common case that the JVP rule itself calls the
+underlying custom-JVP function, we can’t trace the JVP rule up-front without
+getting an infinite recursion. By not forming a jaxpr, we’re solving this in the
+same way we always do: rules are Python callbacks invoked when a transformation
+is applied, not part of the primitive, and though the rule here is associated
+directly with the primitive, rather than being in a global dict, that’s just an
+implementation detail.)
+
+If we gave up on [the Python flexibility
+problem](the-python-flexibility-problem), we could get away with only having
+`custom_jvp_call_jaxpr` and not having the separate Python-level primitive
+`custom_jvp_call`. One way to view the relationship between the two primitives
+is in this schematic:
+
+
+
![]()
+
+
+## API
+
+The custom JVP for an `a -> b` function is specified with an `(a, Ta) -> (b, T
+b)` function:
+
+```python
+# f :: a -> b
+@jax.custom_jvp
+def f(x):
+ return np.sin(x)
+
+# f_jvp :: (a, T a) -> (b, T b)
+def f_jvp(primals, tangents):
+ x, = primals
+ t, = tangents
+ return f(x), np.cos(x) * t
+
+f.defjvp(f_jvp)
+```
+
+(Interesting autodiff aside: for the rule to apply to higher-order
+differentiation, one must call `f` in the body of `f_jvp`; that precludes some
+kinds of work sharing between the internals of `f` and the tangent calculation.)
+
+The custom VJP for an `a -> b` function is specified with an `a -> (b, c)` forward
+pass function paired with a `(c, CT b) -> CT` a backward pass function:
+
+```python
+# f :: a -> b
+@jax.custom_vjp
+def f(x):
+ return np.sin(x)
+
+# f_fwd :: a -> (b, c)
+def f_fwd(x):
+ return f(x), np.cos(x)
+
+# f_bwd :: (c, CT b) -> CT a
+def f_bwd(cos_x, g):
+ return (cos_x * g,)
+
+f.defvjp(f_fwd, f_bwd)
+```
+
+The signature `a -> (b, CT b --o CT a)` is more aesthetically pleasing, but
+supporting it would make the implementation more complex and might require
+compromising expressibility desiderata. The basic reason that Python callables
+are opaque (unless we trace them to a jaxpr eagerly, which places expressiveness
+constraints), and in this case we may be returning a callable with `vmap` tracers
+inside its closure that we need to know about during the forward pass.
+
+We could add convenience wrappers, for example to define the JVP rule for a
+single argument at a time (like we do internally for primitives). But because
+this proposal is complicated enough as it is, I decided against convenience
+layers; let’s keep things minimal for now.
+
+There are some other bells and whistles to the API:
+* Inputs and output types `a`, `b`, and `c` can be arbitrary pytrees of
+ jaxtypes.
+* Passing arguments by name (keyword arguments) is supported when they can be
+ resolved to positions using the `inspect` module. This is a bit of an experiment
+ with Python 3’s improved ability to programmatically inspect argument
+ signatures. I believe it is sound but not complete, which is a fine place to be.
+ (See also [#2069](https://github.com/google/jax/issues/2069).)
+* Arguments can be marked non-differentiable using `nondiff_argnums`, and as with
+ `jit`’s `static_argnums` these arguments don’t have to be JAX types. We need to
+ set a convention for how these arguments are passed to the rules. For a primal
+ function with type signature `(d, a) -> b` where `d` represents the
+ non-differentiable type, the JVP rule’s signature is `(a, T a, d) -> T b` and
+ the VJP rule’s reverse component signature is `(d, c, CT b) -> CT a`. That is,
+ the non-differentiable arguments are passed in order after `primals` and
+ `tangents` for a custom JVP rule, and passed in order preceding the residuals in
+ a custom VJP rule’s reverse function.
+
+## Implementation notes
+
+* Updated `jax.experimental.odeint`
+ * Since `odeint` is a pretty complex user of a custom VJP rule, in addition to
+ just updating it to work at all, I wanted to revise it to be a canonical
+ user of the new custom VJP API as a way to test that the API was a good one.
+ * Along the way I made other improvements to the `odeint` implementation:
+ * remove raveling/unraveling boilerplate
+ * make use of `lax.scan` to remove the index-update logic
+ * speed up by 20+% on the simple pendulum benchmark
+* Added a custom bind method on each transform for the custom derivative call
+ primitives, `custom_jvp_call` and `custom_vjp_call`. It’s like
+ `core.call_bind`, except we don’t process env traces: those are just errors.
+* Added `custom_lin` primitive, which gets staged out into linear jaxprs to be
+ transposed when using a custom VJP rule.
+ * Because our reverse-mode autodiff is decomposed into linearization, partial
+ evaluation, and transposition, our custom VJP rules are processed in two
+ separate steps: one during linearization and one during transposition.
+ * The linearization step, i.e. the JVP rule for `custom_vjp_call`, applies
+ `custom_lin` to the tangent values; `custom_lin` carries with it the user’s
+ custom backward-pass function, and as a primitive it only has a transpose
+ rule.
+ * This mechanism is described more in [#636](https://github.com/google/jax/issues/636).
+* Added a variant of `transformation_with_aux` called
+ `transformation_with_equal_aux` to allow repeated stores of equal values due
+ to running the same function multiple times.
+ * The custom rules functions, like `f_jvp` and `f_fwd`/`f_bwd` in the examples
+ above, are not “linear” in the sense of linear_util.py when used in
+ `custom_jvp_call_jaxpr` and `custom_vjp_call_jaxpr`, respectively. They may be
+ invoked multiple times as a jaxpr is processed in initial style. It’s
+ usually fine for rules to be invoked multiple times, but these rules must
+ plumb aux data out to the api.py-level caller, namely output pytree aux
+ data.
+ * (Recall from a footnote above that we can’t solve this by forming jaxprs for
+ the rules up-front because that can lead to infinite recursion.)
+
+
diff --git a/docs/CHANGELOG.rst b/docs/CHANGELOG.rst
index 5086fedf2e99..a029eea4931f 100644
--- a/docs/CHANGELOG.rst
+++ b/docs/CHANGELOG.rst
@@ -77,7 +77,7 @@ jax 0.1.59 (February 11, 2020)
* Simplified :py:class:`Jaxpr` by removing the ``Jaxpr.freevars`` and
``Jaxpr.bound_subjaxprs``. The call primitives (``xla_call``, ``xla_pmap``,
``sharded_call``, and ``remat_call``) get a new parameter ``call_jaxpr`` with a
- fully-closed (no ``constvars``) JAXPR. Also, added a new field ``call_primitive``
+ fully-closed (no ``constvars``) jaxpr. Also, added a new field ``call_primitive``
to primitives.
* New features:
diff --git a/docs/index.rst b/docs/index.rst
index d148479c0d23..0966b98fca26 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -13,7 +13,8 @@ For an introduction to JAX, start at the
notebooks/quickstart
notebooks/autodiff_cookbook
- Training a Simple Neural Network, with PyTorch Data Loading
+ notebooks/vmapped_log_probs
+ Training a Simple Neural Network, with Tensorflow Datasets Data Loading
.. toctree::
@@ -21,14 +22,11 @@ For an introduction to JAX, start at the
:caption: Advanced JAX Tutorials
notebooks/Common_Gotchas_in_JAX
- notebooks/XLA_in_Python
+ notebooks/Custom_derivative_rules_for_Python_code
notebooks/JAX_pytrees
+ notebooks/XLA_in_Python
notebooks/How_JAX_primitives_work
notebooks/Writing_custom_interpreters_in_Jax.ipynb
- Training a Simple Neural Network, with Tensorflow Datasets Data Loading
- notebooks/maml
- notebooks/score_matching
- notebooks/vmapped_log_probs
.. toctree::
:maxdepth: 1
diff --git a/docs/jax.rst b/docs/jax.rst
index 0d3f53c6f6cc..c7a876059035 100644
--- a/docs/jax.rst
+++ b/docs/jax.rst
@@ -41,11 +41,6 @@ Automatic differentiation
.. autofunction:: jvp
.. autofunction:: linearize
.. autofunction:: vjp
-.. autofunction:: custom_transforms
-.. autofunction:: defjvp
-.. autofunction:: defjvp_all
-.. autofunction:: defvjp
-.. autofunction:: defvjp_all
.. autofunction:: custom_gradient
diff --git a/docs/jaxpr.rst b/docs/jaxpr.rst
index b189ed507c60..034e776761a8 100644
--- a/docs/jaxpr.rst
+++ b/docs/jaxpr.rst
@@ -1,4 +1,4 @@
-Understanding JAXPR
+Understanding jaxprs
====================
Updated: February 14, 2020 (for commit 9e6fe64).
@@ -8,29 +8,29 @@ Updated: February 14, 2020 (for commit 9e6fe64).
Conceptually, one can think of JAX transformations as first tracing the Python
function to be transformed into a small and well-behaved intermediate form,
-the JAXPR, that is then transformed accordingly, and ultimately compiled and executed.
+the jaxpr, that is then transformed accordingly, and ultimately compiled and executed.
One of the reasons JAX can pack so much power into such a small software package
is that it starts with a familiar and flexible programming interface (Python with NumPy)
and it uses the actual Python interpreter to do most of the heavy lifting to distill the
essence of the computation into a simple statically-typed expression language
-with limited higher-order features: the JAXPR language.
+with limited higher-order features: the jaxpr language.
Not all Python programs can be processed this way, but it turns out that many
scientific computing and machine learning programs do have this property.
Before we proceed, it is important to point out that not all JAX transformations
-materialize a JAXPR as described above; some, e.g., differentiation,
+materialize a jaxpr as described above; some, e.g., differentiation,
will apply transformations incrementally during tracing.
Nevertheless, if one wants to understand how JAX works internally, or to
-make use of the result of JAX tracing, it is useful to understand JAXPR.
+make use of the result of JAX tracing, it is useful to understand jaxpr.
-A JAXPR instance represents a function with one of more typed parameters (input variables)
+A jaxpr instance represents a function with one of more typed parameters (input variables)
and one or more typed results. The results depend only on the input
variables; there are no free variables captured from enclosing scopes.
The inputs and outputs have types, which in JAX are represented as abstract
-values. There are two related representations in the code for JAXPRs. The main
+values. There are two related representations in the code for jaxprs. The main
one is :py:class:`jax.core.TypedJaxpr` and is what you obtain when you
-use :py:func:`jax.make_jaxpr` to inspect JAXPRs. It has the following
+use :py:func:`jax.make_jaxpr` to inspect jaxprs. It has the following
fields:
* ``jaxpr``: is the actual computation content of the actual function (described below).
@@ -49,12 +49,12 @@ The most interesting part of the TypedJaxpr is the actual execution content,
represented as a :py:class:`jax.core.Jaxpr` as printed using the following
grammar::
- JAXPR ::= { lambda Var* ; Var+.
+ jaxpr ::= { lambda Var* ; Var+.
let Eqn*
in [Expr+] }
where:
- * The parameter of the JAXPR are shown as two lists of variables separated by
+ * The parameter of the jaxpr are shown as two lists of variables separated by
``;``. The first set of variables are the ones that have been introduced
to stand for constants that have been hoisted out. These are called the
`constvars`. The second list of variables are the real input variables.
@@ -62,7 +62,7 @@ where:
intermediate expressions. Each equation defines one or more variables as the
result of applying a primitive on some atomic expressions. Each equation uses only
input variables and intermediate variables defined by previous equations.
- * ``Expr+``: is a list of output atomic expressions for the JAXPR.
+ * ``Expr+``: is a list of output atomic expressions for the jaxpr.
Equations are printed as follows::
@@ -79,14 +79,14 @@ where:
square brackets. Each parameter is shown as ``Name = Value``.
-Most JAXPR primitives are first-order (they take just one or more Expr as arguments)::
+Most jaxpr primitives are first-order (they take just one or more Expr as arguments)::
Primitive := add | sub | sin | mul | ...
-The JAXPR primitives are documented in the :py:mod:`jax.lax` module.
+The jaxpr primitives are documented in the :py:mod:`jax.lax` module.
-For example, here is the JAXPR produced for the function ``func1`` below::
+For example, here is the jaxpr produced for the function ``func1`` below::
from jax import numpy as jnp
def func1(first, second):
@@ -110,12 +110,12 @@ The ``reduce_sum`` primitive has named parameters ``axes`` and ``input_shape``,
addition to the operand ``e``.
Note that JAX traces through Python-level control-flow and higher-order functions
-when it extracts the JAXPR. This means that just because a Python program contains
-functions and control-flow, the resulting JAXPR does not have
+when it extracts the jaxpr. This means that just because a Python program contains
+functions and control-flow, the resulting jaxpr does not have
to contain control-flow or higher-order features.
For example, when tracing the function ``func3`` JAX will inline the call to
``inner`` and the conditional ``if second.shape[0] > 4``, and will produce the same
-JAXPR as before::
+jaxpr as before::
def func2(inner, first, second):
temp = first + inner(second) * 3.
@@ -142,13 +142,13 @@ JAXPR as before::
Handling PyTrees
----------------
-In JAXPR there are no tuple types; instead primitives take multiple inputs
+In jaxpr there are no tuple types; instead primitives take multiple inputs
and produce multiple outputs. When processing a function that has structured
-inputs or outputs, JAX will flatten those and in JAXPR they will appear as lists
+inputs or outputs, JAX will flatten those and in jaxpr they will appear as lists
of inputs and outputs. For more details, please see the documentation for
PyTrees (:doc:`notebooks/JAX_pytrees`).
-For example, the following code produces an identical JAXPR to what we saw
+For example, the following code produces an identical jaxpr to what we saw
before (with two input vars, one for each element of the input tuple)::
@@ -184,7 +184,7 @@ from the Python program, or from constant-folding. For example, the function
print(api.make_jaxpr(func6)(jnp.ones(8)))
-JAX produces the following JAXPR::
+JAX produces the following jaxpr::
{ lambda b d a.
let c = add a b
@@ -196,13 +196,13 @@ When tracing ``func6``, the function ``func5`` is invoked with a constant value
``jnp.sin(second) * 3.`` is constant-folded.
There are two ConstVars, ``b`` (standing for ``jnp.sin(second) * 3.``) and ``d``
(standing for ``jnp.ones(8)``). Unfortunately, it is not easy to tell from the
-JAXPR notation what constants the constant variables stand for.
+jaxpr notation what constants the constant variables stand for.
Higher-order primitives
-----------------------
-JAXPR includes several higher-order primitives. They are more complicated because
-they include sub-JAXPRs.
+jaxpr includes several higher-order primitives. They are more complicated because
+they include sub-jaxprs.
Cond
^^^^
@@ -238,7 +238,7 @@ For example::
The cond primitive has a number of parameters:
- * `true_jaxpr` and `false_jaxpr` are JAXPRs that correspond to the true
+ * `true_jaxpr` and `false_jaxpr` are jaxprs that correspond to the true
and false branch functionals. In this example, those functionals take each
one input variable, corresponding to ``xtrue`` and ``xfalse`` respectively.
* `linear` is a tuple of booleans that is used internally by the auto-differentiation
@@ -273,7 +273,7 @@ contains a constant ``jnp.ones(1)`` that is hoisted as a `constvar`::
in a } ] d b c e b c
in f }
-The top-level JAXPR has one `constvar` ``e`` (corresponding to ``jnp.ones(1)`` from the
+The top-level jaxpr has one `constvar` ``e`` (corresponding to ``jnp.ones(1)`` from the
body of the ``false_jaxpr``) and three input variables ``a b c`` (corresponding to ``arg1``
and the two elements of ``arg2``; note that ``arg2`` has been flattened).
The ``true_jaxpr`` has two input variables (corresponding to the two elements of ``arg2``
@@ -286,10 +286,10 @@ The actual operands to the cond primitive are: ``d b c e b c``, which correspond
* 1 operand for the predicate,
* 2 operands for ``true_jaxpr``, i.e., ``b`` and ``c``, which are input vars,
- corresponding to ``arg2`` for the top-level JAXPR,
- * 1 constant for ``false_jaxpr``, i.e., ``e``, which is a consvar for the top-level JAXPR,
+ corresponding to ``arg2`` for the top-level jaxpr,
+ * 1 constant for ``false_jaxpr``, i.e., ``e``, which is a consvar for the top-level jaxpr,
* 2 operands for ``true_jaxpr``, i.e., ``b`` and ``c``, which are the input vars
- corresponding to ``arg2`` for the top-level JAXPR.
+ corresponding to ``arg2`` for the top-level jaxpr.
While
^^^^^
@@ -328,7 +328,7 @@ For example, here is an example fori loop::
cond_nconsts=0 ] c a 0 b e
in h }
-The top-level JAXPR has two constvars: ``c`` (corresponding to ``ones * 3.`` from the body
+The top-level jaxpr has two constvars: ``c`` (corresponding to ``ones * 3.`` from the body
of the loop) and ``d`` (corresponding to the use of ``ones`` in the initial carry).
There are also two input variables (``a`` corresponding to ``arg`` and ``b`` corresponding
to ``n``).
@@ -386,7 +386,7 @@ For the example consider the function ``func11`` below::
num_consts=1 ] b 0.0 a * c
in (d, e) }
-The top-level JAXPR has one constvar ``c`` corresponding to the ``ones`` constant,
+The top-level jaxpr has one constvar ``c`` corresponding to the ``ones`` constant,
and two input variables corresponding to the arguments ``arr`` and ``extra``.
The body of the scan has 5 input variables, of which:
@@ -413,7 +413,7 @@ XLA_call
^^^^^^^^
The call primitive arises from JIT compilation, and it encapsulates
-a sub-JAXPR along with parameters the specify the backend and the device the
+a sub-jaxpr along with parameters the specify the backend and the device the
computation should run. For example::
def func12(arg):
@@ -438,7 +438,7 @@ computation should run. For example::
The top-level constvar ``b`` refers to the ``jnp.ones(1)`` constant, and
the top-level input variable `a` refers to the ``arg`` parameter of ``func12``.
The ``xla_call`` primitive stands for a call to the jitted ``inner`` function.
-The primitive has the function body in the ``call_jaxpr`` parameter, a JAXPR
+The primitive has the function body in the ``call_jaxpr`` parameter, a jaxpr
with 3 input parameters:
* ``c`` is a constvar and stands for the ``ones`` constant,
diff --git a/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb b/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb
new file mode 100644
index 000000000000..9de88eda0b2b
--- /dev/null
+++ b/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb
@@ -0,0 +1,2344 @@
+{
+ "nbformat": 4,
+ "nbformat_minor": 0,
+ "metadata": {
+ "colab": {
+ "name": "Custom derivative rules for Python code.ipynb",
+ "provenance": [],
+ "collapsed_sections": [],
+ "toc_visible": true
+ },
+ "kernelspec": {
+ "name": "python3",
+ "display_name": "Python 3"
+ },
+ "accelerator": "GPU"
+ },
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "LqiaKasFjH82",
+ "colab_type": "text"
+ },
+ "source": [
+ "# Custom derivative rules for JAX-transformable Python functions\n",
+ "\n",
+ "*mattjj@ Mar 19 2020*\n",
+ "\n",
+ "There are two ways to define differentiation rules in JAX:\n",
+ "\n",
+ "1. using `jax.custom_jvp` and `jax.custom_vjp` to define custom differentiation rules for Python functions that are already JAX-transformable; and\n",
+ "2. defining new `core.Primitive` instances along with all their transformation rules, for example to call into functions from other systems like solvers, simulators, or general numerical computing systems.\n",
+ "\n",
+ "This notebook is about #1. To read instead about #2, see the [notebook on adding primitives](https://jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html).\n",
+ "\n",
+ "For an introduction to JAX's automatic differentiation API, see [The Autodiff Cookbook](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html). This notebook assumes some familiarity with [`jax.jvp`](https://jax.readthedocs.io/en/latest/jax.html#jax.jvp) and [`jax.grad`](https://jax.readthedocs.io/en/latest/jax.html#jax.vjp), and the mathematical meaning of JVPs and VJPs."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "9Fg3NFNY-2RY",
+ "colab_type": "text"
+ },
+ "source": [
+ "## TL;DR"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "zXic8tr--1PK",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "import jax.numpy as np\n",
+ "from jax import custom_jvp\n",
+ "\n",
+ "@custom_jvp\n",
+ "def f(x, y):\n",
+ " return np.sin(x) * y\n",
+ "\n",
+ "@f.defjvp\n",
+ "def f_jvp(primals, tangents):\n",
+ " x, y = primals\n",
+ " x_dot, y_dot = tangents\n",
+ " primal_out = f(x, y)\n",
+ " tangent_out = np.cos(x) * x_dot * y - np.sin(x) * y_dot\n",
+ " return primal_out, tangent_out"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "RrNf588X_kJF",
+ "colab_type": "code",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 85
+ },
+ "outputId": "246fe70d-2348-4e3e-f58d-766ef16304bc"
+ },
+ "source": [
+ "from jax import jvp, grad\n",
+ "\n",
+ "print(f(2., 3.))\n",
+ "y, y_dot = jvp(f, (2., 3.), (1., 0.))\n",
+ "print(y)\n",
+ "print(y_dot)\n",
+ "print(grad(f)(2., 3.))"
+ ],
+ "execution_count": 3,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "text": [
+ "2.7278922\n",
+ "2.7278922\n",
+ "-1.2484405\n",
+ "-1.2484405\n"
+ ],
+ "name": "stdout"
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "35ScHqhrBwPh",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "from jax import custom_vjp\n",
+ "\n",
+ "@custom_vjp\n",
+ "def f(x, y):\n",
+ " return np.sin(x) * y\n",
+ "\n",
+ "def f_fwd(x, y):\n",
+ " return f(x, y), (np.cos(x), np.sin(x), y)\n",
+ "\n",
+ "def f_bwd(res, g):\n",
+ " cos_x, sin_x, y = res\n",
+ " return (cos_x * g * y, -sin_x * g)\n",
+ "\n",
+ "f.defvjp(f_fwd, f_bwd)"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "HpSozxKUCXgp",
+ "colab_type": "code",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 34
+ },
+ "outputId": "889eb046-19bf-49c7-d8a9-c037b8325fe9"
+ },
+ "source": [
+ "print(grad(f)(2., 3.))"
+ ],
+ "execution_count": 5,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "text": [
+ "-1.2484405\n"
+ ],
+ "name": "stdout"
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "p5ypWA7XlZpu",
+ "colab_type": "text"
+ },
+ "source": [
+ "## Example problems\n",
+ "\n",
+ "To get an idea of what problems `jax.custom_jvp` and `jax.custom_vjp` are meant to solve, let's go over a few examples. A more thorough introduction to the `jax.custom_jvp` and `jax.custom_vjp` APIs is in [the next section](#scrollTo=Dr0aNkBslfQf).\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "AR02eyd1GQhC",
+ "colab_type": "text"
+ },
+ "source": [
+ "\n",
+ "\n",
+ "### Numerical stability\n",
+ "\n",
+ "One application of `jax.custom_jvp` is to improve the numerical stability of differentiation."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "GksPXslaGPaW",
+ "colab_type": "text"
+ },
+ "source": [
+ "\n",
+ "Say we want to write a function called `log1pexp`, which computes $x \\mapsto \\log ( 1 + e^x )$. We can write that using `jax.numpy`:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "6lWbTvs40ET-",
+ "colab_type": "code",
+ "outputId": "b7b9d021-5a34-42ee-cba6-90d3a3e1ee55",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 34
+ }
+ },
+ "source": [
+ "import jax.numpy as np\n",
+ "\n",
+ "def log1pexp(x):\n",
+ " return np.log(1. + np.exp(x))\n",
+ "\n",
+ "log1pexp(3.)"
+ ],
+ "execution_count": 6,
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "DeviceArray(3.0485873, dtype=float32)"
+ ]
+ },
+ "metadata": {
+ "tags": []
+ },
+ "execution_count": 6
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "PL36r_cD0oE8",
+ "colab_type": "text"
+ },
+ "source": [
+ "Since it's written in terms of `jax.numpy`, it's JAX-transformable:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "XgtGKFld02UD",
+ "colab_type": "code",
+ "outputId": "06691c5d-a3c6-4632-a5c7-96eb7379c978",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 68
+ }
+ },
+ "source": [
+ "from jax import jit, grad, vmap\n",
+ "\n",
+ "print(jit(log1pexp)(3.))\n",
+ "print(jit(grad(log1pexp))(3.))\n",
+ "print(vmap(jit(grad(log1pexp)))(np.arange(3.)))"
+ ],
+ "execution_count": 7,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "text": [
+ "3.0485873\n",
+ "0.95257413\n",
+ "[0.5 0.7310586 0.88079715]\n"
+ ],
+ "name": "stdout"
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "o56Nr3V61PKS",
+ "colab_type": "text"
+ },
+ "source": [
+ "But there's a numerical stability problem lurking here:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "sVM6iwIO22sB",
+ "colab_type": "code",
+ "outputId": "39338c73-dd3a-4915-de26-8c634ba96375",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 34
+ }
+ },
+ "source": [
+ "print(grad(log1pexp)(100.))"
+ ],
+ "execution_count": 8,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "text": [
+ "nan\n"
+ ],
+ "name": "stdout"
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "Zu9sR2I73wuO",
+ "colab_type": "text"
+ },
+ "source": [
+ "That doesn't seem right! After all, the derivative of $x \\mapsto \\log (1 + e^x)$ is $x \\mapsto \\frac{e^x}{1 + e^x}$, and so for large values of $x$ we'd expect the value to be about 1.\n",
+ "\n",
+ "We can get a bit more insight into what's going on by looking at the jaxpr for the gradient computation:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "dO6uZlYR4TVp",
+ "colab_type": "code",
+ "outputId": "3f573ed5-a7a6-49f8-bb03-a26abc4fd4ef",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 119
+ }
+ },
+ "source": [
+ "from jax import make_jaxpr\n",
+ "\n",
+ "make_jaxpr(grad(log1pexp))(100.)"
+ ],
+ "execution_count": 9,
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "{ lambda ; a.\n",
+ " let b = exp a\n",
+ " c = add b 1.0\n",
+ " d = div 1.0 c\n",
+ " e = mul d b\n",
+ " in (e,) }"
+ ]
+ },
+ "metadata": {
+ "tags": []
+ },
+ "execution_count": 9
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "52HR5EW26PEt",
+ "colab_type": "text"
+ },
+ "source": [
+ "Stepping through how the jaxpr would be evaluated, we can see that the last line would involve multiplying values that floating point math will round to 0 and $\\infty$, respectively, which is never a good idea. That is, we're effectively evaluating `lambda x: (1 / (1 + np.exp(x))) * np.exp(x)` for large `x`, which effectively turns into `0. * np.inf`.\n",
+ "\n",
+ "Instead of generating such large and small values, hoping for a cancellation that floats can't always provide, we'd rather just express the derivative function as a more numerically stable program. In particular, we can write a program that more closely evaluates the equal mathematical expression $1 - \\frac{1}{1 + e^x}$, with no cancellation in sight.\n",
+ "\n",
+ "This problem is interesting because even though our definition of `log1pexp` could already be JAX-differentiated (and transformed with `jit`, `vmap`, ...), we're not happy with the result of applying standard autodiff rules to the primitives comprising `log1pexp` and composing the result. Instead, we'd like to specify how the whole function `log1pexp` should be differentiated, as a unit, and thus arrange those exponentials better.\n",
+ "\n",
+ "This is one application of custom derivative rules for Python functions that are already JAX transformable: specifying how a composite function should be differentiated, while still using its original Python definition for other transformations (like `jit`, `vmap`, ...).\n",
+ "\n",
+ "Here's a solution using `jax.custom_jvp`:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "XQt6MAuTJewG",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "from jax import custom_jvp\n",
+ "\n",
+ "@custom_jvp\n",
+ "def log1pexp(x):\n",
+ " return np.log(1. + np.exp(x))\n",
+ "\n",
+ "@log1pexp.defjvp\n",
+ "def log1pexp_jvp(primals, tangents):\n",
+ " x, = primals\n",
+ " x_dot, = tangents\n",
+ " ans = log1pexp(x)\n",
+ " ans_dot = (1 - 1/(1 + np.exp(x))) * x_dot\n",
+ " return ans, ans_dot"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "rhiMHulfKBIF",
+ "colab_type": "code",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 34
+ },
+ "outputId": "36c4065b-5e77-4a4c-eedb-857f00a24cf7"
+ },
+ "source": [
+ "print(grad(log1pexp)(100.))"
+ ],
+ "execution_count": 11,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "text": [
+ "1.0\n"
+ ],
+ "name": "stdout"
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "9cLDuAo6KGUu",
+ "colab_type": "code",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 68
+ },
+ "outputId": "c0f25e98-e4a0-47fe-ee74-a843f3697d55"
+ },
+ "source": [
+ "print(jit(log1pexp)(3.))\n",
+ "print(jit(grad(log1pexp))(3.))\n",
+ "print(vmap(jit(grad(log1pexp)))(np.arange(3.)))"
+ ],
+ "execution_count": 12,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "text": [
+ "3.0485873\n",
+ "0.95257413\n",
+ "[0.5 0.7310586 0.8807971]\n"
+ ],
+ "name": "stdout"
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "V9tHAfrSF1N-",
+ "colab_type": "text"
+ },
+ "source": [
+ "### Enforcing a differentiation convention\n",
+ "\n",
+ "A related application is to enforce a differentiation convention, perhaps at a boundary."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "l_6tdb-QGK-H",
+ "colab_type": "text"
+ },
+ "source": [
+ "\n",
+ "Consider the function $f : \\mathbb{R}_+ \\mapsto \\mathbb{R}_+$ with $f(x) = \\frac{x}{1 + \\sqrt{x}}$, where we take $\\mathbb{R}_+ = [0, \\infty)$. We might implement $f$ as a program like this:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "AfF5P7x_GaSe",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "def f(x):\n",
+ " return x / (1 + np.sqrt(x))"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "BVcEkF3ZGgv1",
+ "colab_type": "text"
+ },
+ "source": [
+ "As a mathematical function on $\\mathbb{R}$ (the full real line), $f$ is not differentiable at zero (because the limit defining the derivative doesn't exist from the left). Correspondingly, autodiff produces a `nan` value:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "piI0u5MiHhQh",
+ "colab_type": "code",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 0
+ },
+ "outputId": "5e1ed5f0-271a-4c35-860a-d69cd069b439"
+ },
+ "source": [
+ "print(grad(f)(0.))"
+ ],
+ "execution_count": 14,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "text": [
+ "nan\n"
+ ],
+ "name": "stdout"
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "IP0H2b7ZHkzD",
+ "colab_type": "text"
+ },
+ "source": [
+ "But mathematically if we think of $f$ as a function on $\\mathbb{R}_+$ then it is differentiable at 0 [Rudin Definition 5.1, or Tao Definition 10.1.1]. Alternatively, we might say as a convention we want to consider the directional derivative from the right. So there is a sensible value for the Python function `grad(f)` to return at `0.0`, namely `1.0`, even though JAX's machinery for differentiation over reals doesn't produce it.\n",
+ "\n",
+ "We can use a custom JVP rule! In particular, we can define the JVP rule in terms of the derivative function $x \\mapsto \\frac{\\sqrt{x} + 2}{2(\\sqrt{x} + 1)^2}$ on $\\mathbb{R}_+$,"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "ksHmCkcSKQJr",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "@custom_jvp\n",
+ "def f(x):\n",
+ " return x / (1 + np.sqrt(x))\n",
+ "\n",
+ "@f.defjvp\n",
+ "def f_jvp(primals, tangents):\n",
+ " x, = primals\n",
+ " x_dot, = tangents\n",
+ " ans = f(x)\n",
+ " ans_dot = ((np.sqrt(x) + 2) / (2 * (np.sqrt(x) + 1)**2)) * x_dot\n",
+ " return ans, ans_dot"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "Gsh9ZvMTKi1O",
+ "colab_type": "code",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 0
+ },
+ "outputId": "4ce69b63-166b-4e18-cf79-6550046b510a"
+ },
+ "source": [
+ "print(grad(f)(0.))"
+ ],
+ "execution_count": 16,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "text": [
+ "1.0\n"
+ ],
+ "name": "stdout"
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "CICQuI86WK4_"
+ },
+ "source": [
+ "### Python debugging\n",
+ "\n",
+ "Another application that is motivated by development workflow rather than numerics is to set a `pdb` debugger trace in the backward pass of reverse-mode autodiff.\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "cgxMjNTrGjJn",
+ "colab_type": "text"
+ },
+ "source": [
+ "\n",
+ "When trying to track down the source of a `nan` runtime error, or just examine carefully the cotangent (gradient) values being propagated, it can be useful to insert a debugger at a point in the backward pass that corresponds to a specific point in the primal computation.\n",
+ "\n",
+ "We'll defer an example until the next section."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "IC7tEcr1-Fc5",
+ "colab_type": "text"
+ },
+ "source": [
+ "### Implicit function differentiation of iterative implementations\n",
+ "\n",
+ "This example gets pretty deep in the mathematical weeds!"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "szAt97t80hew",
+ "colab_type": "text"
+ },
+ "source": [
+ "Another application for `jax.custom_vjp` is reverse-mode differentiation of functions that are JAX-transformable (by `jit`, `vmap`, ...) but not efficiently JAX-differentiable for some reason, perhaps because they involve `lax.while_loop`. (It's not possible to produce an XLA HLO program that efficiently computes the reverse-mode derivative of an XLA HLO While loop because that would require a program with unbounded memory use, which isn't possible to express in XLA HLO, at least without side-effecting interactions through infeed/outfeed.)\n",
+ "\n",
+ "For example, consider this `fixed_point` routine which computes a fixed point by iteratively applying a function in a `while_loop`:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "2uA8X2izXH2b",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "from jax.lax import while_loop\n",
+ "\n",
+ "def fixed_point(f, a, x_guess):\n",
+ " def cond_fun(carry):\n",
+ " x_prev, x = carry\n",
+ " return np.abs(x_prev - x) > 1e-6\n",
+ "\n",
+ " def body_fun(carry):\n",
+ " _, x = carry\n",
+ " return x, f(a, x)\n",
+ "\n",
+ " _, x_star = while_loop(cond_fun, body_fun, (x_guess, f(a, x_guess)))\n",
+ " return x_star"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "p2xFQAte19sF",
+ "colab_type": "text"
+ },
+ "source": [
+ "This is an iterative procedure for numerically solving the equation $x = f(a, x)$ for $x$, by iterating $x_{t+1} = f(a, x_t)$ until $x_{t+1}$ is sufficiently close to $x_t$. The result $x^*$ depends on the parameters $a$, and so we can think of there being a function $a \\mapsto x^*(a)$ that is implicity defined by equation $x = f(a, x)$.\n",
+ "\n",
+ "We can use `fixed_point` to run iterative procedures to convergence, for example running Newton's method to calculate square roots while only executing adds, multiplies, and divides:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "rDDwM8bYYzRT",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "def newton_sqrt(a):\n",
+ " update = lambda a, x: 0.5 * (x + a / x)\n",
+ " return fixed_point(update, a, a)"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "42Ydd7_6aLXU",
+ "colab_type": "code",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 34
+ },
+ "outputId": "29833920-bea5-4853-d55b-10c66a27fd32"
+ },
+ "source": [
+ "print(newton_sqrt(2.))"
+ ],
+ "execution_count": 19,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "text": [
+ "1.4142135\n"
+ ],
+ "name": "stdout"
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "-yFtYWH13QWm",
+ "colab_type": "text"
+ },
+ "source": [
+ "We can `vmap` or `jit` the function as well:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "t_YSXieT3Yyk",
+ "colab_type": "code",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 34
+ },
+ "outputId": "a6b00ffe-cba1-4177-c50d-0f1b55ceed31"
+ },
+ "source": [
+ "print(jit(vmap(newton_sqrt))(np.array([1., 2., 3., 4.])))"
+ ],
+ "execution_count": 20,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "text": [
+ "[1. 1.4142135 1.7320508 2. ]\n"
+ ],
+ "name": "stdout"
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "emwWIt3d3h1T",
+ "colab_type": "text"
+ },
+ "source": [
+ "We can't apply reverse-mode automatic differentiation because of the `while_loop`, but it turns out we wouldn't want to anyway: instead of differentiating through the implementation of `fixed_point` and all its iterations, we can exploit the mathematical structure to do something that is much more memory-efficient (and FLOP-efficient in this case, too!). We can instead use the implicit function theorem [Prop A.25 of Bertsekas's Nonlinear Programming, 2nd ed.], which guarantees (under some conditions) the existence of the mathematical objects we're about to use. In essence, we linearize at the solution and solve those linear equations iteratively to compute the derivatives we want.\n",
+ "\n",
+ "Consider again the equation $x = f(a, x)$ and the function $x^*$. We want to evaluate vector-Jacobian products like $v^\\mathsf{T} \\mapsto v^\\mathsf{T} \\partial x^*(a_0)$.\n",
+ "\n",
+ "At least in an open neighborhood around the point $a_0$ at which we want to differentiate, let's assume that the equation $x^*(a) = f(a, x^*(a))$ holds for all $a$. Since the two sides are equal as functions of $a$, their derivatives must be equal as well, so let's differentiate both sides:\n",
+ "\n",
+ "$\\qquad \\partial x^*(a) = \\partial_0 f(a, x^*(a)) + \\partial_1 f(a, x^*(a)) \\partial x^*(a)$.\n",
+ "\n",
+ "Setting $A = \\partial_1 f(a_0, x^*(a_0))$ and $B = \\partial_0 f(a_0, x^*(a_0))$, we can write the quantity we're after more simply as\n",
+ "\n",
+ "$\\qquad \\partial x^*(a_0) = B + A \\partial x^*(a_0)$,\n",
+ "\n",
+ "or, by rearranging,\n",
+ "\n",
+ "$\\qquad \\partial x^*(a_0) = (I - A)^{-1} B$.\n",
+ "\n",
+ "That means we can evaluate vector-Jacobian products like\n",
+ "\n",
+ "$\\qquad v^\\mathsf{T} \\partial x^*(a_0) = v^\\mathsf{T} (I - A)^{-1} B = w^\\mathsf{T} B$,\n",
+ "\n",
+ "where $w^\\mathsf{T} = v^\\mathsf{T} (I - A)^{-1}$, or equivalently $w^\\mathsf{T} = v^\\mathsf{T} + w^\\mathsf{T} A$, or equivalently $w^\\mathsf{T}$ is the fixed point of the map $u^\\mathsf{T} \\mapsto v^\\mathsf{T} + u^\\mathsf{T} A$. That last characterization gives us a way to write the VJP for `fixed_point` in terms of a call to `fixed_point`! Moreover, after expanding $A$ and $B$ back out, we can see we need only to evaluate VJPs of $f$ at $(a_0, x^*(a_0))$.\n",
+ "\n",
+ "Here's the upshot:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "g4jo-xlvdiym",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "from functools import partial\n",
+ "\n",
+ "from jax import custom_vjp\n",
+ "from jax import vjp\n",
+ "\n",
+ "@partial(custom_vjp, nondiff_argnums=(0,))\n",
+ "def fixed_point(f, a, x_guess):\n",
+ " def cond_fun(carry):\n",
+ " x_prev, x = carry\n",
+ " return np.abs(x_prev - x) > 1e-6\n",
+ "\n",
+ " def body_fun(carry):\n",
+ " _, x = carry\n",
+ " return x, f(a, x)\n",
+ "\n",
+ " _, x_star = while_loop(cond_fun, body_fun, (x_guess, f(a, x_guess)))\n",
+ " return x_star\n",
+ "\n",
+ "def fixed_point_fwd(f, a, x_init):\n",
+ " x_star = fixed_point(f, a, x_init)\n",
+ " return x_star, (a, x_star)\n",
+ "\n",
+ "def fixed_point_rev(f, res, x_star_bar):\n",
+ " a, x_star = res\n",
+ " _, vjp_a = vjp(lambda a: f(a, x_star), a)\n",
+ " a_bar, = vjp_a(fixed_point(partial(rev_iter, f),\n",
+ " (a, x_star, x_star_bar),\n",
+ " x_star_bar))\n",
+ " return a_bar, np.zeros_like(x_star)\n",
+ " \n",
+ "def rev_iter(f, packed, u):\n",
+ " a, x_star, x_star_bar = packed\n",
+ " _, vjp_x = vjp(lambda x: f(a, x), x_star)\n",
+ " return x_star_bar + vjp_x(u)[0]\n",
+ "\n",
+ "fixed_point.defvjp(fixed_point_fwd, fixed_point_rev)"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "iKzfT6d_mEoB",
+ "colab_type": "code",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 34
+ },
+ "outputId": "99482a53-f7b5-4715-ffcb-66fb0346a7b3"
+ },
+ "source": [
+ "print(newton_sqrt(2.))"
+ ],
+ "execution_count": 22,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "text": [
+ "1.4142135\n"
+ ],
+ "name": "stdout"
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "Hmcpjr6gmtkO",
+ "colab_type": "code",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 51
+ },
+ "outputId": "f76cc1d0-de93-4e1c-c1f3-509d3f04df32"
+ },
+ "source": [
+ "print(grad(newton_sqrt)(2.))\n",
+ "print(grad(grad(newton_sqrt))(2.))"
+ ],
+ "execution_count": 23,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "text": [
+ "0.35355335\n",
+ "-0.088388346\n"
+ ],
+ "name": "stdout"
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "DvVmlaPD7W-4",
+ "colab_type": "text"
+ },
+ "source": [
+ "We can check our answers by differentiating `np.sqrt`, which uses a totally different implementation:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "jj_JnI9Pm4jg",
+ "colab_type": "code",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 51
+ },
+ "outputId": "bae61d39-5f13-452b-b700-897d1f9a38d3"
+ },
+ "source": [
+ "print(grad(np.sqrt)(2.))\n",
+ "print(grad(grad(np.sqrt))(2.))"
+ ],
+ "execution_count": 24,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "text": [
+ "0.35355338\n",
+ "-0.08838835\n"
+ ],
+ "name": "stdout"
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "HowvqayEuy-H",
+ "colab_type": "text"
+ },
+ "source": [
+ "A limitation to this approach is that the argument `f` can't close over any values involved in differentiation. That is, you might notice that we kept the parameter `a` explicit in the argument list of `fixed_point`. While other JAX mechanisms can handle closed-over transformation-traced values in the arguments to higher-order functions (as is done for the control flow primitives like `lax.cond`, `lax.scan`, and `lax.while_loop` itself), `jax.custom_vjp` used as above cannot. A `fixed_point` routine that used a bit more of JAX's internals could have a more convenient and robust API."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "Dr0aNkBslfQf",
+ "colab_type": "text"
+ },
+ "source": [
+ "## Basic usage of `jax.custom_jvp` and `jax.custom_vjp` APIs\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "MojTOg4tmQNT",
+ "colab_type": "text"
+ },
+ "source": [
+ "\n",
+ "### Use `jax.custom_jvp` to define forward-mode (and, indirectly, reverse-mode) rules\n",
+ "\n",
+ "Here's a canonical basic example of using `jax.custom_jvp`:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "nVkhbIFAOGZk",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "from jax import custom_jvp\n",
+ "import jax.numpy as np\n",
+ "\n",
+ "# f :: a -> b\n",
+ "@custom_jvp\n",
+ "def f(x):\n",
+ " return np.sin(x)\n",
+ "\n",
+ "# f_jvp :: (a, T a) -> (b, T b)\n",
+ "def f_jvp(primals, tangents):\n",
+ " x, = primals\n",
+ " t, = tangents\n",
+ " return f(x), np.cos(x) * t\n",
+ "\n",
+ "f.defjvp(f_jvp)"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "fxhlECvW7Krj",
+ "colab_type": "code",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 68
+ },
+ "outputId": "3dc9640f-ae25-458d-854b-e60330a1812d"
+ },
+ "source": [
+ "from jax import jvp\n",
+ "\n",
+ "print(f(3.))\n",
+ "\n",
+ "y, y_dot = jvp(f, (3.,), (1.,))\n",
+ "print(y)\n",
+ "print(y_dot)"
+ ],
+ "execution_count": 26,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "text": [
+ "0.14112\n",
+ "0.14112\n",
+ "-0.9899925\n"
+ ],
+ "name": "stdout"
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "JaoQVRzSQ9Qd",
+ "colab_type": "text"
+ },
+ "source": [
+ "In words, we start with a a primal function `f` that takes inputs of type `a` and produces outputs of type `b`. We associate with it a JVP rule function `f_jvp` that takes a pair of inputs representing the primal inputs of type `a` and the corresponding tangent inputs of type `T a`, and produces a pair of outputs representing the primal outputs of type `b` and tangent outputs of type `T b`. The tangent outputs should be a linear function of the tangent inputs."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "1xGky7yMOavq",
+ "colab_type": "text"
+ },
+ "source": [
+ "You can also use `f.defjvp` as a decorator, as in\n",
+ "\n",
+ "```python\n",
+ "@custom_jvp\n",
+ "def f(x):\n",
+ " ...\n",
+ "\n",
+ "@f.defjvp\n",
+ "def f_jvp(primals, tangents):\n",
+ " ...\n",
+ "```"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "e9R-ppvdQIOC",
+ "colab_type": "text"
+ },
+ "source": [
+ "Even though we defined only a JVP rule and no VJP rule, we can use both forward- and reverse-mode differentiation on `f`. JAX will automatically transpose the linear computation on tangent values from our custom JVP rule, computing the VJP as efficiently as if we had written the rule by hand:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "hl9Io86pQD6s",
+ "colab_type": "code",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 51
+ },
+ "outputId": "3d0ce20f-823c-4cca-8f7b-b40cf3f0c3bd"
+ },
+ "source": [
+ "from jax import grad\n",
+ "\n",
+ "print(grad(f)(3.))\n",
+ "print(grad(grad(f))(3.))"
+ ],
+ "execution_count": 27,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "text": [
+ "-0.9899925\n",
+ "-0.14112\n"
+ ],
+ "name": "stdout"
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "MRlKe5D90svj",
+ "colab_type": "text"
+ },
+ "source": [
+ "For automatic transposition to work, the JVP rule's output tangents must be linear as a function of the input tangents. Otherwise a transposition error is raised."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "GRu-0yg96lXE",
+ "colab_type": "text"
+ },
+ "source": [
+ "Multiple arguments work like this:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "JFLXlXuq6pRf",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "@custom_jvp\n",
+ "def f(x, y):\n",
+ " return x ** 2 * y\n",
+ "\n",
+ "@f.defjvp\n",
+ "def f_jvp(primals, tangents):\n",
+ " x, y = primals\n",
+ " x_dot, y_dot = tangents\n",
+ " primal_out = f(x, y)\n",
+ " tangent_out = 2 * x * y * x_dot + x ** 2 * y_dot\n",
+ " return primal_out, tangent_out"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "QpKwA0oA8DfE",
+ "colab_type": "code",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 34
+ },
+ "outputId": "53871c47-509d-4f4f-cdae-75ffc414467f"
+ },
+ "source": [
+ "print(grad(f)(2., 3.))"
+ ],
+ "execution_count": 29,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "text": [
+ "12.0\n"
+ ],
+ "name": "stdout"
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "kZ0yc-Ihoezk",
+ "colab_type": "text"
+ },
+ "source": [
+ "Calling a `jax.custom_jvp` function with keyword arguments, or writing a `jax.custom_jvp` function definition with default arguments, are both allowed so long as they can be unambiguosly mapped to positional arguments based on the function signature retrieved by the standard library `inspect.signature` mechanism."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "3FGwfT67PDs9",
+ "colab_type": "text"
+ },
+ "source": [
+ "When you're not performing differentiation, the function `f` is called just as if it weren't decorated by `jax.custom_jvp`:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "b-tB3xCHPRFt",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "@custom_jvp\n",
+ "def f(x):\n",
+ " print('called f!') # a harmless side-effect\n",
+ " return np.sin(x)\n",
+ "\n",
+ "@f.defjvp\n",
+ "def f_jvp(primals, tangents):\n",
+ " print('called f_jvp!')\n",
+ " x, = primals\n",
+ " t, = tangents\n",
+ " return f(x), np.cos(x) * t"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "xAlRea95PjA5",
+ "colab_type": "code",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 51
+ },
+ "outputId": "26fcf10c-22bb-4bae-94a6-be6f22d5ef34"
+ },
+ "source": [
+ "from jax import vmap, jit\n",
+ "\n",
+ "print(f(3.))"
+ ],
+ "execution_count": 31,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "text": [
+ "called f!\n",
+ "0.14112\n"
+ ],
+ "name": "stdout"
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "dyD2ow4NmpI-",
+ "colab_type": "code",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 85
+ },
+ "outputId": "b88e980a-6fe0-4c17-d76b-e03e4692f16f"
+ },
+ "source": [
+ "print(vmap(f)(np.arange(3.)))\n",
+ "print(jit(f)(3.))"
+ ],
+ "execution_count": 32,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "text": [
+ "called f!\n",
+ "[0. 0.841471 0.9092974]\n",
+ "called f!\n",
+ "0.14112\n"
+ ],
+ "name": "stdout"
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "EzB75KZ5Pz7m",
+ "colab_type": "text"
+ },
+ "source": [
+ "The custom JVP rule is invoked during differentiation, whether forward or reverse:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "hKF0xyAxPyLZ",
+ "colab_type": "code",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 68
+ },
+ "outputId": "2ec1142f-0273-4c21-9b28-5d4490bcba12"
+ },
+ "source": [
+ "y, y_dot = jvp(f, (3.,), (1.,))\n",
+ "print(y_dot)"
+ ],
+ "execution_count": 33,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "text": [
+ "called f_jvp!\n",
+ "called f!\n",
+ "-0.9899925\n"
+ ],
+ "name": "stdout"
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "Z1KaEgA58MEG",
+ "colab_type": "code",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 68
+ },
+ "outputId": "f231d379-b296-4a36-8fc3-924476c57edb"
+ },
+ "source": [
+ "print(grad(f)(3.))"
+ ],
+ "execution_count": 34,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "text": [
+ "called f_jvp!\n",
+ "called f!\n",
+ "-0.9899925\n"
+ ],
+ "name": "stdout"
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "o8JFxk3lQhOs",
+ "colab_type": "text"
+ },
+ "source": [
+ "Notice that `f_jvp` calls `f` to compute the primal outputs. In the context of higher-order differentiation, each application of a differentiation transform will use the custom JVP rule if and only if the rule calls the original `f` to compute the primal outputs. (This represents a kind of fundamental tradeoff, where we can't make use of intermediate values from the evaluation of `f` in our rule _and also_ have the rule apply in all orders of higher-order differentiation.)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "B6PLJooTQgVp",
+ "colab_type": "code",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 85
+ },
+ "outputId": "10b1dd9f-f6fa-4774-bb77-7d222a6bc2c4"
+ },
+ "source": [
+ "grad(grad(f))(3.)"
+ ],
+ "execution_count": 35,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "text": [
+ "called f_jvp!\n",
+ "called f_jvp!\n",
+ "called f!\n"
+ ],
+ "name": "stdout"
+ },
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "DeviceArray(-0.14112, dtype=float32)"
+ ]
+ },
+ "metadata": {
+ "tags": []
+ },
+ "execution_count": 35
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "XNxAmFSsaaro",
+ "colab_type": "text"
+ },
+ "source": [
+ "You can use Python control flow with `jax.custom_jvp`:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "kkXlSJL6adU2",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "@custom_jvp\n",
+ "def f(x):\n",
+ " if x > 0:\n",
+ " return np.sin(x)\n",
+ " else:\n",
+ " return np.cos(x)\n",
+ "\n",
+ "@f.defjvp\n",
+ "def f_jvp(primals, tangents):\n",
+ " x, = primals\n",
+ " x_dot, = tangents\n",
+ " ans = f(x)\n",
+ " if x > 0:\n",
+ " return ans, 2 * x_dot\n",
+ " else:\n",
+ " return ans, 3 * x_dot"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "QCHmJ56Na2G3",
+ "colab_type": "code",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 51
+ },
+ "outputId": "13c86bfa-03f7-45db-aefd-2958b2b60578"
+ },
+ "source": [
+ "print(grad(f)(1.))\n",
+ "print(grad(f)(-1.))"
+ ],
+ "execution_count": 37,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "text": [
+ "2.0\n",
+ "3.0\n"
+ ],
+ "name": "stdout"
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "9cVdgR7ilt8l",
+ "colab_type": "text"
+ },
+ "source": [
+ "### Use `jax.custom_vjp` to define custom reverse-mode-only rules\n",
+ "\n",
+ "While `jax.custom_jvp` suffices for controlling both forward- and, via JAX's automatic transposition, reverse-mode differentiation behavior, in some cases we may want to directly control a VJP rule, for example in the latter two example problems presented above. We can do that with `jax.custom_vjp`:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "zAZk1n3dUw76",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "from jax import custom_vjp\n",
+ "import jax.numpy as np\n",
+ "\n",
+ "# f :: a -> b\n",
+ "@custom_vjp\n",
+ "def f(x):\n",
+ " return np.sin(x)\n",
+ "\n",
+ "# f_fwd :: a -> (b, c)\n",
+ "def f_fwd(x):\n",
+ " return f(x), np.cos(x)\n",
+ "\n",
+ "# f_bwd :: (c, CT b) -> CT a\n",
+ "def f_bwd(cos_x, y_bar):\n",
+ " return (cos_x * y_bar,)\n",
+ "\n",
+ "f.defvjp(f_fwd, f_bwd)"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "E8W-H2S0Ngdr",
+ "colab_type": "code",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 51
+ },
+ "outputId": "d0f25cc0-8454-42a4-b13b-10863beeb6ea"
+ },
+ "source": [
+ "from jax import grad\n",
+ "\n",
+ "print(f(3.))\n",
+ "print(grad(f)(3.))"
+ ],
+ "execution_count": 39,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "text": [
+ "0.14112\n",
+ "-0.9899925\n"
+ ],
+ "name": "stdout"
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "yLING7qEVGGN",
+ "colab_type": "text"
+ },
+ "source": [
+ "In words, we again start with a a primal function `f` that takes inputs of type `a` and produces outputs of type `b`. We associate with it two functions, `f_fwd` and `f_bwd`, which describe how to perform the forward- and backward-passes of reverse-mode autodiff, respectively.\n",
+ "\n",
+ "The function `f_fwd` describes the forward pass, not only the primal computation but also what values to save for use on the backward pass. Its input signature is just like that of the primal function `f`, in that it takes a primal input of type `a`. But as output it produces a pair, where the first element is the primal output `b` and the second element is any \"residual\" data of type `c` to be stored for use by the backward pass. (This second output is analogous to [PyTorch's `save_for_backward` mechanism](https://pytorch.org/tutorials/beginner/examples_autograd/two_layer_net_custom_function.html).)\n",
+ "\n",
+ "The function `f_bwd` describes the backward pass. It takes two inputs, where the first is the residual data of type `c` produced by `f_fwd` and the second is the output cotangents of type `CT b` corresponding to the output of the primal function. It produces an output of type `CT a` representing the cotangents corresponding to the input of the primal function. In particular, the output of `f_bwd` must be a sequence (e.g. a tuple) of length equal to the number of arguments to the primal function."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "d1b5v67Oncfz",
+ "colab_type": "text"
+ },
+ "source": [
+ "So multiple arguments work like this:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "IhMb64gkngAt",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "from jax import custom_vjp\n",
+ "\n",
+ "@custom_vjp\n",
+ "def f(x, y):\n",
+ " return np.sin(x) * y\n",
+ "\n",
+ "def f_fwd(x, y):\n",
+ " return f(x, y), (np.cos(x), np.sin(x), y)\n",
+ "\n",
+ "def f_bwd(res, g):\n",
+ " cos_x, sin_x, y = res\n",
+ " return (cos_x * g * y, -sin_x * g)\n",
+ "\n",
+ "f.defvjp(f_fwd, f_bwd)"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "EnRtIhhLnkry",
+ "colab_type": "code",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 34
+ },
+ "outputId": "b949c30c-62b8-4ed5-ea76-6910aca4da69"
+ },
+ "source": [
+ "print(grad(f)(2., 3.))"
+ ],
+ "execution_count": 41,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "text": [
+ "-1.2484405\n"
+ ],
+ "name": "stdout"
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "GwC26P9kn8qw",
+ "colab_type": "text"
+ },
+ "source": [
+ "Calling a `jax.custom_vjp` function with keyword arguments, or writing a `jax.custom_vjp` function definition with default arguments, are both allowed so long as they can be unambiguosly mapped to positional arguments based on the function signature retrieved by the standard library `inspect.signature` mechanism."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "XfH-ae8bYt6-",
+ "colab_type": "text"
+ },
+ "source": [
+ "As with `jax.custom_jvp`, the custom VJP rule comprised by `f_fwd` and `f_bwd` is not invoked if differentiation is not applied. If function is evaluated, or transformed with `jit`, `vmap`, or other non-differentiation transformations, then only `f` is called."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "s-_Dbqi-N5Ij",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "@custom_vjp\n",
+ "def f(x):\n",
+ " print(\"called f!\")\n",
+ " return np.sin(x)\n",
+ "\n",
+ "def f_fwd(x):\n",
+ " print(\"called f_fwd!\")\n",
+ " return f(x), np.cos(x)\n",
+ "\n",
+ "def f_bwd(cos_x, y_bar):\n",
+ " print(\"called f_bwd!\")\n",
+ " return (cos_x * y_bar,)\n",
+ "\n",
+ "f.defvjp(f_fwd, f_bwd)"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "r0aZ79OmOAR5",
+ "colab_type": "code",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 51
+ },
+ "outputId": "0b06b579-ba48-405a-98eb-c54069ac2f59"
+ },
+ "source": [
+ "print(f(3.))"
+ ],
+ "execution_count": 43,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "text": [
+ "called f!\n",
+ "0.14112\n"
+ ],
+ "name": "stdout"
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "7ToB9BYlm6uN",
+ "colab_type": "code",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 85
+ },
+ "outputId": "8413a918-95b3-40ed-9062-202002550af4"
+ },
+ "source": [
+ "print(grad(f)(3.))"
+ ],
+ "execution_count": 44,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "text": [
+ "called f_fwd!\n",
+ "called f!\n",
+ "called f_bwd!\n",
+ "-0.9899925\n"
+ ],
+ "name": "stdout"
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "s1Pn_qCIODcF",
+ "colab_type": "code",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 68
+ },
+ "outputId": "e9921a2e-4488-4cf9-bb0e-e2f2cae98806"
+ },
+ "source": [
+ "from jax import vjp\n",
+ "\n",
+ "y, f_vjp = vjp(f, 3.)\n",
+ "print(y)"
+ ],
+ "execution_count": 45,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "text": [
+ "called f_fwd!\n",
+ "called f!\n",
+ "0.14112\n"
+ ],
+ "name": "stdout"
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "dvgQtDHaOHuo",
+ "colab_type": "code",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 51
+ },
+ "outputId": "cbc26c6f-871a-46fc-a0b5-2c855b236f4b"
+ },
+ "source": [
+ "print(f_vjp(1.))"
+ ],
+ "execution_count": 46,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "text": [
+ "called f_bwd!\n",
+ "(DeviceArray(-0.9899925, dtype=float32),)\n"
+ ],
+ "name": "stdout"
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "qFIIpkFcZCNP",
+ "colab_type": "text"
+ },
+ "source": [
+ "**Forward-mode autodiff cannot be used on the `jax.custom_vjp` function** and will raise an error:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "3RGQRbI_OSEX",
+ "colab_type": "code",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 68
+ },
+ "outputId": "2c4a7602-566b-4bb0-a591-5f2568a5427e"
+ },
+ "source": [
+ "from jax import jvp\n",
+ "\n",
+ "try:\n",
+ " jvp(f, (3.,), (1.,))\n",
+ "except TypeError as e:\n",
+ " print('ERROR! {}'.format(e))"
+ ],
+ "execution_count": 47,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "text": [
+ "called f_fwd!\n",
+ "called f!\n",
+ "ERROR! can't apply forward-mode autodiff (jvp) to a custom_vjp function.\n"
+ ],
+ "name": "stdout"
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "u04I9j2dntAU",
+ "colab_type": "text"
+ },
+ "source": [
+ "If you want to use both forward- and reverse-mode, use `jax.custom_jvp` instead."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "YN97y7LEZbWV",
+ "colab_type": "text"
+ },
+ "source": [
+ "We can use `jax.custom_vjp` together with `pdb` to insert a debugger trace in the backward pass:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "-DvRKsHPZk_g",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "import pdb\n",
+ "\n",
+ "@custom_vjp\n",
+ "def debug(x):\n",
+ " return x # acts like identity\n",
+ "\n",
+ "def debug_fwd(x):\n",
+ " return x, x\n",
+ "\n",
+ "def debug_bwd(x, g):\n",
+ " import pdb; pdb.set_trace()\n",
+ " return g\n",
+ "\n",
+ "debug.defvjp(debug_fwd, debug_bwd)"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "49GdkP4pZ2IV",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "def foo(x):\n",
+ " y = x ** 2\n",
+ " y = debug(y) # insert pdb in corresponding backward pass step\n",
+ " return np.sin(y)"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "sGLnRcPwaKoX",
+ "colab_type": "text"
+ },
+ "source": [
+ "```python\n",
+ "jax.grad(foo)(3.)\n",
+ "\n",
+ "> (12)debug_bwd()\n",
+ "-> return g\n",
+ "(Pdb) p x\n",
+ "DeviceArray(9., dtype=float32)\n",
+ "(Pdb) p g\n",
+ "DeviceArray(-0.91113025, dtype=float32)\n",
+ "(Pdb) q\n",
+ "```"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "DaTfAJLAl1Lb",
+ "colab_type": "text"
+ },
+ "source": [
+ "## More features and details\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "LQF_UDApl_UV",
+ "colab_type": "text"
+ },
+ "source": [
+ "### Working with `list` / `tuple` / `dict` containers (and other pytrees)\n",
+ "\n",
+ "You should expect standard Python containers like lists, tuples, namedtuples, and dicts to just work, along with nested versions of those. In general, any [pytrees](https://github.com/google/jax/blob/master/docs/notebooks/JAX_pytrees.ipynb) are permissible, so long as their structures are consistent according to the type constraints. \n",
+ "\n",
+ "Here's a contrived example with `jax.custom_jvp`:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "6sDLZ3dAn3P2",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "from collections import namedtuple\n",
+ "Point = namedtuple(\"Point\", [\"x\", \"y\"])\n",
+ "\n",
+ "@custom_jvp\n",
+ "def f(pt):\n",
+ " x, y = pt.x, pt.y\n",
+ " return {'a': x ** 2,\n",
+ " 'b': (np.sin(x), np.cos(y))}\n",
+ "\n",
+ "@f.defjvp\n",
+ "def f_jvp(primals, tangents):\n",
+ " pt, = primals\n",
+ " pt_dot, = tangents\n",
+ " ans = f(pt)\n",
+ " ans_dot = {'a': 2 * pt.x * pt_dot.x,\n",
+ " 'b': (np.cos(pt.x) * pt_dot.x, -np.sin(pt.y) * pt_dot.y)}\n",
+ " return ans, ans_dot\n",
+ "\n",
+ "def fun(pt):\n",
+ " dct = f(pt)\n",
+ " return dct['a'] + dct['b'][0]"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "My8pbOlPppJj",
+ "colab_type": "code",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 34
+ },
+ "outputId": "2ef4c080-7ecb-425f-dbf9-8e2807529365"
+ },
+ "source": [
+ "pt = Point(1., 2.)\n",
+ "\n",
+ "print(f(pt))"
+ ],
+ "execution_count": 51,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "text": [
+ "{'a': 1.0, 'b': (DeviceArray(0.841471, dtype=float32), DeviceArray(-0.4161468, dtype=float32))}\n"
+ ],
+ "name": "stdout"
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "a9qyiCAhqLd3",
+ "colab_type": "code",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 34
+ },
+ "outputId": "36dfdb67-f5ca-4833-e37e-4e2488b96c0d"
+ },
+ "source": [
+ "print(grad(fun)(pt))"
+ ],
+ "execution_count": 52,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "text": [
+ "Point(x=DeviceArray(2.5403023, dtype=float32), y=array(0.))\n"
+ ],
+ "name": "stdout"
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "BWLN9tu4qWQd",
+ "colab_type": "text"
+ },
+ "source": [
+ "And an analogous contrived example with `jax.custom_vjp`:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "QkdbwGkJqS3J",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "@custom_vjp\n",
+ "def f(pt):\n",
+ " x, y = pt.x, pt.y\n",
+ " return {'a': x ** 2,\n",
+ " 'b': (np.sin(x), np.cos(y))}\n",
+ "\n",
+ "def f_fwd(pt):\n",
+ " return f(pt), pt\n",
+ "\n",
+ "def f_bwd(pt, g):\n",
+ " a_bar, (b0_bar, b1_bar) = g['a'], g['b']\n",
+ " x_bar = 2 * pt.x * a_bar + np.cos(pt.x) * b0_bar\n",
+ " y_bar = -np.sin(pt.y) * b1_bar\n",
+ " return (Point(x_bar, y_bar),)\n",
+ "\n",
+ "f.defvjp(f_fwd, f_bwd)\n",
+ "\n",
+ "def fun(pt):\n",
+ " dct = f(pt)\n",
+ " return dct['a'] + dct['b'][0]"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "3onW7t6nrJ4E",
+ "colab_type": "code",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 34
+ },
+ "outputId": "ad13ae57-72e4-4b90-ac25-d4d6aa0f8161"
+ },
+ "source": [
+ "pt = Point(1., 2.)\n",
+ "\n",
+ "print(f(pt))"
+ ],
+ "execution_count": 54,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "text": [
+ "{'a': 1.0, 'b': (DeviceArray(0.841471, dtype=float32), DeviceArray(-0.4161468, dtype=float32))}\n"
+ ],
+ "name": "stdout"
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "ryyeKIXtrNpd",
+ "colab_type": "code",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 34
+ },
+ "outputId": "4ce36200-6e27-4ba5-912a-8e86e7500bd0"
+ },
+ "source": [
+ "print(grad(fun)(pt))"
+ ],
+ "execution_count": 55,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "text": [
+ "Point(x=DeviceArray(2.5403023, dtype=float32), y=DeviceArray(-0., dtype=float32))\n"
+ ],
+ "name": "stdout"
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "JKTNivxbmKWO",
+ "colab_type": "text"
+ },
+ "source": [
+ "### Handling non-differentiable arguments"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "7g9sXSp_uc36",
+ "colab_type": "text"
+ },
+ "source": [
+ "Some use cases, like the final example problem, call for non-differentiable arguments to be passed to functions with custom differentiation rules, and for those arguments to also be passed to the rules themselves. In the case of `fixed_point`, the function argument `f` was such a non-differentiable argument. A similar situation arises with `jax.experimental.odeint`.\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "9yNIOzyBCvE5",
+ "colab_type": "text"
+ },
+ "source": [
+ "#### `jax.custom_jvp` with `nondiff_argnums`\n",
+ "\n",
+ "Use the optional `nondiff_argnums` parameter to `jax.custom_jvp` to indicate arguments like these. Here's an example with `jax.custom_jvp`:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "b3YMxxTBvy0I",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "from functools import partial\n",
+ "\n",
+ "@partial(custom_jvp, nondiff_argnums=(0,))\n",
+ "def app(f, x):\n",
+ " return f(x)\n",
+ "\n",
+ "@app.defjvp\n",
+ "def app_jvp(f, primals, tangents):\n",
+ " x, = primals\n",
+ " x_dot, = tangents\n",
+ " return f(x), 2. * x_dot"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "5W-yEw9IB34S",
+ "colab_type": "code",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 34
+ },
+ "outputId": "d6bea705-49c9-4821-e1d5-f0a6e431130f"
+ },
+ "source": [
+ "print(app(lambda x: x ** 3, 3.))"
+ ],
+ "execution_count": 57,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "text": [
+ "27.0\n"
+ ],
+ "name": "stdout"
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "zbVIlOmqB7_O",
+ "colab_type": "code",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 34
+ },
+ "outputId": "d690ebab-e3b9-4836-adf8-2be0a743a5d6"
+ },
+ "source": [
+ "print(grad(app, 1)(lambda x: x ** 3, 3.))"
+ ],
+ "execution_count": 58,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "text": [
+ "2.0\n"
+ ],
+ "name": "stdout"
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "-b_B_4WaBI2D",
+ "colab_type": "text"
+ },
+ "source": [
+ "Notice the gotcha here: no matter where in the argument list these parameters appear, they're placed at the *start* of the signature of the corresponding JVP rule. Here's another example:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "9hokWmyHBgKK",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "@partial(custom_jvp, nondiff_argnums=(0, 2))\n",
+ "def app2(f, x, g):\n",
+ " return f(g((x)))\n",
+ "\n",
+ "@app2.defjvp\n",
+ "def app2_jvp(f, g, primals, tangents):\n",
+ " x, = primals\n",
+ " x_dot, = tangents\n",
+ " return f(g(x)), 3. * x_dot"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "J7GsvJTgCfS0",
+ "colab_type": "code",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 34
+ },
+ "outputId": "76b6754c-9465-493a-b191-492915c0614c"
+ },
+ "source": [
+ "print(app2(lambda x: x ** 3, 3., lambda y: 5 * y))"
+ ],
+ "execution_count": 60,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "text": [
+ "3375.0\n"
+ ],
+ "name": "stdout"
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "kPP8Jt1CCb1X",
+ "colab_type": "code",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 34
+ },
+ "outputId": "152e181f-6ede-4fc6-9dbe-10bdc6483d64"
+ },
+ "source": [
+ "print(grad(app2, 1)(lambda x: x ** 3, 3., lambda y: 5 * y))"
+ ],
+ "execution_count": 61,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "text": [
+ "3.0\n"
+ ],
+ "name": "stdout"
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "ECbalHIkC4ts",
+ "colab_type": "text"
+ },
+ "source": [
+ "#### `jax.custom_vjp` with `nondiff_argnums`"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "0u0jn4aWC8k1",
+ "colab_type": "text"
+ },
+ "source": [
+ "A similar option exists for `jax.custom_vjp`, and similarly the convention is that the non-differentiable arguments are passed as the first arguments to the rules, no matter where they appear in the original function's signature. Here's an example:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "yCdu-_9GClWs",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "@partial(custom_vjp, nondiff_argnums=(0,))\n",
+ "def app(f, x):\n",
+ " return f(x)\n",
+ "\n",
+ "def app_fwd(f, x):\n",
+ " return f(x), x\n",
+ "\n",
+ "def app_bwd(f, x, g):\n",
+ " return (5 * g,)\n",
+ "\n",
+ "app.defvjp(app_fwd, app_bwd)"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "qSgcWa1eDj4r",
+ "colab_type": "code",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 34
+ },
+ "outputId": "28289e33-a96e-4f2d-e8c3-479232dfb36d"
+ },
+ "source": [
+ "print(app(lambda x: x ** 2, 4.))"
+ ],
+ "execution_count": 63,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "text": [
+ "16.0\n"
+ ],
+ "name": "stdout"
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "tccagflcDmaz",
+ "colab_type": "code",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 34
+ },
+ "outputId": "12eae540-65a9-4c17-ea82-e0965fae924e"
+ },
+ "source": [
+ "print(grad(app, 1)(lambda x: x ** 2, 4.))"
+ ],
+ "execution_count": 64,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "text": [
+ "5.0\n"
+ ],
+ "name": "stdout"
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "BTEnNTk5D0sM",
+ "colab_type": "text"
+ },
+ "source": [
+ "See `fixed_point` above for another usage example."
+ ]
+ }
+ ]
+}
\ No newline at end of file
diff --git a/docs/notebooks/How_JAX_primitives_work.ipynb b/docs/notebooks/How_JAX_primitives_work.ipynb
index 924dffe8425c..59b46bb48273 100644
--- a/docs/notebooks/How_JAX_primitives_work.ipynb
+++ b/docs/notebooks/How_JAX_primitives_work.ipynb
@@ -981,7 +981,7 @@
" * Why is JAX using ConcreteArray in square_add_prim? There is no abstract evaluation going on here.\n",
" * Not sure how to explain that multiply_add_prim is invoked with ConcreteValue, yet\n",
" we do not call the multiply_add_abstract_eval.\n",
- " * I think it would be useful to show the JAXPR here\n",
+ " * I think it would be useful to show the jaxpr here\n",
" "
]
},
@@ -1503,7 +1503,7 @@
" File \"/usr/local/lib/python3.6/dist-packages/jax/api.py\", line 611, in batched_fun\n",
" lambda: _flatten_axes(out_tree(), out_axes))\n",
" File \"/usr/local/lib/python3.6/dist-packages/jax/interpreters/batching.py\", line 41, in batch\n",
- " out_vals, out_dims = batch_fun(fun, in_vals, in_dims)\n",
+ " out_vals, out_dims = batch2(fun, in_vals, in_dims)\n",
"NotImplementedError: Batching rule for 'multiply_add' not implemented\n"
],
"name": "stderr"
diff --git a/examples/control_test.py b/examples/control_test.py
index 3c1272876503..b172b2752be5 100644
--- a/examples/control_test.py
+++ b/examples/control_test.py
@@ -215,8 +215,8 @@ def testMpcWithLqrProblem(self):
self.assertAllClose(U[1:], np.zeros((T - 1, 2)), check_dtypes=True)
+ @jtu.skip_on_devices("cpu") # TODO(mattjj,froystig): only fails on travis?
def testMpcWithLqrProblemSpecifiedGenerally(self):
- raise SkipTest # TODO(froystig)
randn = onp.random.RandomState(0).randn
dim, T, num_iters = 2, 10, 3
p = one_step_control(dim, T)
@@ -229,6 +229,7 @@ def testMpcWithLqrProblemSpecifiedGenerally(self):
self.assertAllClose(U[1:], np.zeros((T - 1, 2)), check_dtypes=True)
+ @jtu.skip_on_devices("cpu") # TODO(mattjj,froystig): only fails on travis?
def testMpcWithNonlinearProblem(self):
def cost(t, x, u):
return (x[0] ** 2. + 1e-3 * u[0] ** 2.) / (t + 1.)
diff --git a/images/custom_jvp_schematic.png b/images/custom_jvp_schematic.png
new file mode 100644
index 000000000000..a06f0800ed88
Binary files /dev/null and b/images/custom_jvp_schematic.png differ
diff --git a/jax/api.py b/jax/api.py
index 4989ba281a9e..429f09f23e89 100644
--- a/jax/api.py
+++ b/jax/api.py
@@ -40,12 +40,12 @@
from . import dtypes
from .core import eval_jaxpr
from .api_util import (wraps, flatten_fun, apply_flat_fun, flatten_fun_nokwargs,
- flatten_fun_nokwargs2)
+ flatten_fun_nokwargs2, argnums_partial)
from .tree_util import (tree_map, tree_flatten, tree_unflatten, tree_structure,
tree_transpose, tree_leaves, tree_multimap,
treedef_is_leaf, _replace_nones)
-from .util import (unzip2, curry, partial, safe_map, safe_zip,
- WrapHashably, Hashable, prod, split_list, extend_name_stack, wrap_name)
+from .util import (unzip2, curry, partial, safe_map, safe_zip, prod,
+ split_list, extend_name_stack, wrap_name)
from .lib import xla_bridge as xb
# Unused imports to be exported
from .lib.xla_bridge import (device_count, local_device_count, devices, local_devices,
@@ -59,6 +59,7 @@
from .interpreters import parallel
from .interpreters import masking
from .interpreters.masking import ensure_poly
+from .custom_derivatives import custom_jvp, custom_vjp
from .config import flags, config, bool_env
AxisName = Any
@@ -131,7 +132,7 @@ def jit(fun: Callable, static_argnums: Union[int, Iterable[int]] = (),
@wraps(fun)
def f_jitted(*args, **kwargs):
- if _thread_local_state.jit_is_disabled or config.read('jax_disable_jit'):
+ if _jit_is_disabled():
return fun(*args, **kwargs)
if static_argnums and max(static_argnums) >= len(args):
msg = ("Jitted function has static_argnums={} but was called with only {}"
@@ -140,7 +141,7 @@ def f_jitted(*args, **kwargs):
f = lu.wrap_init(fun)
if static_argnums:
dyn_argnums = [i for i in range(len(args)) if i not in static_argnums]
- f, dyn_args = _argnums_partial(f, dyn_argnums, args)
+ f, dyn_args = argnums_partial(f, dyn_argnums, args)
else:
dyn_args = args
args_flat, in_tree = tree_flatten((dyn_args, kwargs))
@@ -196,6 +197,9 @@ def disable_jit():
finally:
_thread_local_state.jit_is_disabled = prev_val
+def _jit_is_disabled():
+ return _thread_local_state.jit_is_disabled or config.read('jax_disable_jit')
+
def xla_computation(fun: Callable,
static_argnums: Union[int, Iterable[int]] = (),
@@ -413,7 +417,7 @@ def value_and_grad_f(*args, **kwargs):
raise TypeError(msg.format(argnums, max_argnum + 1, len(args)))
f = lu.wrap_init(fun, kwargs)
- f_partial, dyn_args = _argnums_partial(f, argnums, args)
+ f_partial, dyn_args = argnums_partial(f, argnums, args)
if not has_aux:
ans, vjp_py = _vjp(f_partial, *dyn_args)
else:
@@ -476,7 +480,7 @@ def jacfwd(fun: Callable, argnums: Union[int, Sequence[int]] = 0,
def jacfun(*args, **kwargs):
f = lu.wrap_init(fun, kwargs)
- f_partial, dyn_args = _argnums_partial(f, argnums, args)
+ f_partial, dyn_args = argnums_partial(f, argnums, args)
holomorphic or tree_map(_check_real_input_jacfwd, dyn_args)
pushfwd = partial(_jvp, f_partial, dyn_args)
y, jac = vmap(pushfwd, out_axes=(None, batching.last))(_std_basis(dyn_args))
@@ -521,7 +525,7 @@ def jacrev(fun: Callable, argnums: Union[int, Sequence[int]] = 0,
"""
def jacfun(*args, **kwargs):
f = lu.wrap_init(fun, kwargs)
- f_partial, dyn_args = _argnums_partial(f, argnums, args)
+ f_partial, dyn_args = argnums_partial(f, argnums, args)
y, pullback = _vjp(f_partial, *dyn_args)
holomorphic or tree_map(_check_real_output_jacrev, y)
jac = vmap(pullback)(_std_basis(y))
@@ -914,7 +918,7 @@ def f_pmapped(*args, **kwargs):
f = lu.wrap_init(fun)
if static_broadcasted_argnums:
dyn_argnums = [i for i in range(len(args)) if i not in static_broadcasted_argnums]
- f, dyn_args = _argnums_partial(f, dyn_argnums, args)
+ f, dyn_args = argnums_partial(f, dyn_argnums, args)
else:
dyn_args = args
args, in_tree = tree_flatten((dyn_args, kwargs))
@@ -1441,32 +1445,6 @@ def device_get(x):
return tree_map(_device_get, x)
-def _argnums_partial(f: lu.WrappedFun, dyn_argnums, args):
- if isinstance(dyn_argnums, int):
- dyn_argnums = (dyn_argnums,)
- else:
- dyn_argnums = tuple(dyn_argnums)
- fixed_args = tuple([core.unit if i in dyn_argnums else _wrap_hashably(arg)
- for i, arg in enumerate(args)])
- dyn_args = tuple(args[i] for i in dyn_argnums)
- return _argnums_partial_(f, dyn_argnums, fixed_args), dyn_args
-
-def _wrap_hashably(arg):
- try:
- hash(arg)
- except TypeError:
- return WrapHashably(arg) # e.g. ndarrays, DeviceArrays
- else:
- return Hashable(arg)
-
-@lu.transformation
-def _argnums_partial_(dyn_argnums, fixed_args, *dyn_args, **kwargs):
- args = [None if arg is core.unit else arg.val for arg in fixed_args]
- for i, arg in zip(dyn_argnums, dyn_args):
- args[i] = arg
- ans = yield args, kwargs
- yield ans
-
def _check_args(args):
for arg in args:
if not (isinstance(arg, core.Tracer) or _valid_jaxtype(arg)):
@@ -1482,534 +1460,6 @@ def _valid_jaxtype(arg):
return True
-class CustomTransformsFunction(object):
- def __init__(self, fun, prim):
- self.fun = fun
- self.prim = prim
- wraps(fun)(self)
-
- def __repr__(self):
- return ''.format(fun=self.__name__)
-
- def __call__(self, *args):
- # TODO(mattjj): instead of tracing to a jaxpr, use process_call
- args_flat, in_tree = tree_flatten(args)
- flat_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(self.fun), in_tree)
- in_pvals = [pe.PartialVal((raise_to_shaped(core.get_aval(x)), core.unit))
- for x in args_flat]
- jaxpr, _, consts = pe.trace_to_jaxpr(flat_fun, in_pvals, instantiate=True)
- outs = self.prim.bind(*it.chain(consts, args_flat), jaxpr=jaxpr,
- in_tree=in_tree, out_tree=out_tree(),
- num_consts=len(consts))
- return tree_unflatten(out_tree(), outs)
-
-def custom_transforms(fun):
- """Wraps a function so that its transformation behavior can be controlled.
-
- A primary use case of ``custom_transforms`` is defining custom VJP rules (aka
- custom gradients) for a Python function, while still supporting other
- transformations like ``jax.jit`` and ``jax.vmap``. Custom differentiation
- rules can be supplied using the ``jax.defjvp`` and ``jax.defvjp`` functions.
-
- The ``custom_transforms`` decorator wraps ``fun`` so that its transformation
- behavior can be overridden, but not all transformation rules need to be
- specified manually. The default behavior is retained for any non-overridden
- rules.
-
- The function ``fun`` must satisfy the same constraints required for jit
- compilation. In particular the shapes of arrays in the computation of ``fun``
- may depend on the shapes of ``fun``'s arguments, but not their values.
- Value dependent Python control flow is also not yet supported.
-
- Args:
- fun: a Python callable. Must be functionally pure. Its arguments and return
- value should be arrays, scalars, or (nested) standard Python containers
- (tuple/list/dict) thereof.
-
- Returns:
- A Python callable with the same input/output and transformation behavior as
- ``fun``, but for which custom transformation rules can be supplied, e.g.
- using ``jax.defvjp``.
-
- For example:
-
- >>> @jax.custom_transforms
- ... def f(x):
- ... return np.sin(x ** 2)
- ...
- >>> print(f(3.))
- 0.4121185
- >>> print(jax.grad(f)(3.))
- -5.4667816
- >>> jax.defvjp(f, lambda g, x: g * x)
- >>> print(jax.grad(f)(3.))
- 3.0
- """
- name = getattr(fun, '__name__', '')
- fun_p = core.Primitive(name)
- fun_p.multiple_results = True
-
- def fun_impl(*args, **params):
- consts, args = split_list(args, [params['num_consts']])
- return core.eval_jaxpr(params['jaxpr'], consts, *args)
- fun_p.def_impl(fun_impl)
-
- def fun_jvp(primals, tangents, **params):
- return ad.jvp(lu.wrap_init(fun_impl, params)).call_wrapped(primals, tangents)
- ad.primitive_jvps[fun_p] = fun_jvp
-
- def fun_batch(args, dims, **params):
- return batching.batch_fun(lu.wrap_init(fun_impl, params), args, dims)
- batching.primitive_batchers[fun_p] = fun_batch
-
- def fun_abstract_eval(*avals, **params):
- return pe.abstract_eval_fun(fun_impl, *avals, **params)
- fun_p.def_abstract_eval(fun_abstract_eval)
-
- def fun_translation(c, *xla_args, **params):
- return xla.lower_fun(fun_impl)(c, *xla_args, **params)
- xla.translations[fun_p] = fun_translation
-
- return CustomTransformsFunction(fun, fun_p)
-
-def _check_custom_transforms_type(name, fun):
- if type(fun) is not CustomTransformsFunction:
- msg = ("{} requires a custom_transforms function as its first argument, "
- "but got type {}.")
- raise TypeError(msg.format(name, type(fun)))
-
-def defjvp_all(fun, custom_jvp):
- """Define a custom JVP rule for a ``custom_transforms`` function.
-
- If ``fun`` represents a function with signature ``a -> b``, then
- ``custom_jvp`` represents a function with signature ``(a, T a) -> (b, T b)``,
- where we use ``T x`` to represent a tangent type for the type ``x``.
-
- In more detail, ``custom_jvp`` must take two arguments, both tuples of length
- equal to the number of positional arguments to ``fun``. The first argument to
- ``custom_jvp`` represents the input primal values, and the second represents
- the input tangent values. ``custom_jvp`` must return a pair where the first
- element represents the output primal value and the second element represents
- the output tangent value.
-
- Defining a custom JVP rule also affects the default VJP rule, which is derived
- from the JVP rule automatically via transposition.
-
- Args:
- fun: a custom_transforms function.
- custom_jvp: a Python callable specifying the JVP rule, taking two tuples as
- arguments specifying the input primal values and tangent values,
- respectively. The tuple elements can be arrays, scalars, or (nested)
- standard Python containers (tuple/list/dict) thereof. The output must be a
- pair representing the primal output and tangent output, which can be
- arrays, scalars, or (nested) standard Python containers. Must be
- functionally pure.
-
- Returns:
- None. A side-effect is that ``fun`` is associated with the JVP rule
- specified by ``custom_jvp``.
-
- For example:
-
- >>> @jax.custom_transforms
- ... def f(x):
- ... return np.sin(x ** 2)
- ...
- >>> print(f(3.))
- 0.4121185
- >>> out_primal, out_tangent = jax.jvp(f, (3.,), (2.,))
- >>> print(out_primal)
- 0.4121185
- >>> print(out_tangent)
- -10.933563
- >>> jax.defjvp_all(f, lambda ps, ts: (np.sin(ps[0] ** 2), 8. * ts[0]))
- >>> out_primal, out_tangent = jax.jvp(f, (3.,), (2.,))
- >>> print(out_primal)
- 0.4121185
- >>> print(out_tangent)
- 16.0
- """
- _check_custom_transforms_type("defjvp_all", fun)
- def custom_transforms_jvp(primals, tangents, **params):
- num_consts, in_tree = params['num_consts'], params['in_tree']
- _, args_flat = split_list(primals, [num_consts])
- consts_dot, args_dot_flat = split_list(tangents, [num_consts])
- if not all(t is ad_util.zero for t in consts_dot):
- msg = ("Detected differentiation with respect to closed-over values with "
- "custom JVP rule, which isn't supported.")
- raise ValueError(msg)
- args = tree_unflatten(in_tree, args_flat)
- args_dot = tree_unflatten(in_tree, args_dot_flat)
- out, out_dot = custom_jvp(args, args_dot)
- out_flat, out_tree = tree_flatten(out)
- out_dot_flat, out_tree2 = tree_flatten(out_dot)
- if out_tree != out_tree2:
- msg = ("Custom JVP rule returned different tree structures for primals "
- "and tangents, but they must be equal: {} and {}.")
- raise TypeError(msg.format(out_tree, out_tree2))
- return out_flat, out_dot_flat
- ad.primitive_jvps[fun.prim] = custom_transforms_jvp
-
-def defjvp(fun, *jvprules):
- """Definine JVP rules for each argument separately.
-
- This function is a convenience wrapper around ``jax.defjvp_all`` for
- separately defining JVP rules for each of the function's arguments. This
- convenience wrapper does not provide a mechanism for depending on anything
- other than the function arguments and its primal output value, though
- depending on intermediate results is possible using ``jax.defjvp_all``.
-
- The signature of each component JVP rule is ``lambda g, ans, *primals: ...``
- where ``g`` represents the tangent of the corresponding positional argument,
- ``ans`` represents the output primal, and ``*primals`` represents all the
- primal positional arguments.
-
- Defining a custom JVP rule also affects the default VJP rule, which is derived
- from the JVP rule automatically via transposition.
-
- Args:
- fun: a custom_transforms function.
- *jvprules: a sequence of functions or Nones specifying the JVP rule for each
- corresponding positional argument. When an element is None, it indicates
- that the Jacobian from the corresponding input to the output is zero.
-
- Returns:
- None. A side-effect is that ``fun`` is associated with the JVP rule
- specified by ``*jvprules``.
-
- For example:
-
- >>> @jax.custom_transforms
- ... def f(x):
- ... return np.sin(x ** 2)
- ...
- >>> print(f(3.))
- 0.4121185
- >>> out_primal, out_tangent = jax.jvp(f, (3.,), (2.,))
- >>> print(out_primal)
- 0.4121185
- >>> print(out_tangent)
- -10.933563
- >>> jax.defjvp(f, lambda g, ans, x: 8. * g + ans)
- >>> out_primal, out_tangent = jax.jvp(f, (3.,), (2.,))
- >>> print(out_primal)
- 0.4121185
- >>> print(out_tangent)
- 16.412119
- """
- _check_custom_transforms_type("defjvp", fun)
- def custom_jvp(primals, tangents):
- ans = fun(*primals)
- tangents_out = [rule(t, ans, *primals) for rule, t in zip(jvprules, tangents)
- if rule is not None and t is not ad_util.zero]
- return ans, functools.reduce(ad.add_tangents, tangents_out, ad_util.zero)
- defjvp_all(fun, custom_jvp)
-
-def defvjp_all(fun, custom_vjp):
- """Define a custom VJP rule for a ``custom_transforms`` function.
-
- If ``fun`` represents a function with signature ``a -> b``, then
- ``custom_vjp`` represents a function with signature ``a -> (b, CT b -> CT a)``
- where we use ``CT x`` to represent a cotangent type for the type ``x``. That
- is, ``custom_vjp`` should take the same arguments as ``fun`` and return a pair
- where the first element represents the primal value of ``fun`` applied to the
- arguments, and the second element is a VJP function that maps from output
- cotangents to input cotangents, returning a tuple with length equal to the
- number of positional arguments supplied to ``fun``.
-
- The VJP function returned as the second element of the output of
- ``custom_vjp`` can close over intermediate values computed when evaluating the
- primal value of ``fun``. That is, use lexical closure to share work between
- the forward pass and the backward pass of reverse-mode automatic
- differentiation.
-
- See also ``jax.custom_gradient``.
-
- Args:
- fun: a custom_transforms function.
- custom_vjp: a Python callable specifying the VJP rule, taking the same
- arguments as ``fun`` and returning a pair where the first element is the
- value of ``fun`` applied to the arguments and the second element is a
- Python callable representing the VJP map from output cotangents to input
- cotangents. The returned VJP function must accept a value with the same
- shape as the value of ``fun`` applied to the arguments and must return a
- tuple with length equal to the number of positional arguments to ``fun``.
- Arguments can be arrays, scalars, or (nested) standard Python containers
- (tuple/list/dict) thereof. Must be functionally pure.
-
- Returns:
- None. A side-effect is that ``fun`` is associated with the VJP rule
- specified by ``custom_vjp``.
-
- For example:
-
- >>> @jax.custom_transforms
- ... def f(x):
- ... return np.sin(x ** 2)
- ...
- >>> print(f(3.))
- 0.4121185
- >>> print(jax.grad(f)(3.))
- -5.4667816
- >>> jax.defvjp_all(f, lambda x: (np.sin(x ** 2), lambda g: (g * x,)))
- >>> print(f(3.))
- 0.4121185
- >>> print(jax.grad(f)(3.))
- 3.0
-
- An example with a function on two arguments, so that the VJP function must
- return a tuple of length two:
-
- >>> @jax.custom_transforms
- ... def f(x, y):
- ... return x * y
- ...
- >>> jax.defvjp_all(f, lambda x, y: (x * y, lambda g: (y, x)))
- >>> print(f(3., 4.))
- 12.0
- >>> print(jax.grad(f, argnums=(0, 1))(3., 4.))
- (4.0, 3.0)
- """
- _check_custom_transforms_type("defvjp_all", fun)
- def custom_transforms_vjp(*consts_and_args, **params):
- num_consts, in_tree = params['num_consts'], params['in_tree']
- consts, args_flat = split_list(consts_and_args, [num_consts])
- args = tree_unflatten(params['in_tree'], args_flat)
- out, vjp = custom_vjp(*args)
- out_flat, out_tree = tree_flatten(out)
- if out_tree != params['out_tree']:
- msg = (
- "First output of `custom_vjp`: {} doesn't match the structure of "
- "the output of `fun`: {}\n"
- "{}\n"
- "vs\n"
- "{}\n".format(custom_vjp, fun, out_tree, params['out_tree'])
- )
- raise TypeError(msg)
- def vjp_flat(*cts_flat):
- cts = tree_unflatten(out_tree, cts_flat)
- args_cts_flat, in_tree2 = tree_flatten(vjp(cts))
- if in_tree != in_tree2:
- msg = (
- "Output of the `vjp`: {} doesn't match the structure of args of "
- "`fun`: {}\n"
- "{}\n"
- "vs\n"
- "{}\n".format(vjp, fun, in_tree2, in_tree)
- )
- raise TypeError(msg)
- return [core.unit] * num_consts + list(args_cts_flat)
- return out_flat, vjp_flat
- ad.defvjp_all(fun.prim, custom_transforms_vjp)
-
-def defvjp(fun, *vjprules):
- """Define VJP rules for each argument separately.
-
- This function is a convenience wrapper around ``jax.defvjp_all`` for
- separately defining VJP rules for each of the function's arguments. This
- convenience wrapper does not provide a mechanism for depending on anything
- other than the function arguments and its primal output value, though
- depending on intermediate results is possible using ``jax.defvjp_all``.
-
- The signature of each component VJP rule is ``lambda g, ans, *primals: ...``
- where ``g`` represents the output cotangent, ``ans`` represents the output
- primal, and ``*primals`` represents all the primal positional arguments.
-
- Args:
- fun: a custom_transforms function.
- *vjprules: a sequence of functions or Nones specifying the VJP rule for each
- corresponding positional argument. When an element is None, it indicates
- that the Jacobian from the corresponding input to the output is zero.
-
- Returns:
- None. A side-effect is that ``fun`` is associated with the VJP rule
- specified by ``*vjprules``.
-
- For example:
-
- >>> @jax.custom_transforms
- ... def f(x, y):
- ... return np.sin(x ** 2 + y)
- ...
- >>> print(f(3., 4.))
- 0.42016703
- >>> print(jax.grad(f)(3., 4.))
- 5.4446807
- >>> print(jax.grad(f, 1)(3., 4.))
- 0.9074468
- >>> jax.defvjp(f, None, lambda g, ans, x, y: g + x + y + ans)
- >>> print(jax.grad(f)(3., 4.))
- 0.0
- >>> print(jax.grad(f, 1)(3., 4.))
- 8.420167
- """
- _check_custom_transforms_type("defvjp", fun)
- def custom_vjp(*primals):
- ans = fun(*primals)
- # TODO(mattjj): avoid instantiating zeros?
- def vjpfun(ct):
- return tuple(vjp(ct, ans, *primals) if vjp else ad_util.zeros_like_jaxval(x)
- for x, vjp in zip(primals, vjprules))
- return ans, vjpfun
- defvjp_all(fun, custom_vjp)
-
-def custom_gradient(fun):
- """Convenience function for defining custom VJP rules (aka custom gradients).
-
- While the canonical way to define custom VJP rules is via ``jax.defvjp_all``
- and its convenience wrappers, the ``custom_gradient`` convenience wrapper
- follows TensorFlow's ``tf.custom_gradient`` API. The difference here is that
- ``custom_gradient`` can be used as a decorator on one function that returns
- both the primal value (representing the output of the mathematical function to
- be differentiated) and the VJP (gradient) function.
-
- See https://www.tensorflow.org/api_docs/python/tf/custom_gradient.
-
- If the mathematical function to be differentiated has type signature
- ``a -> b``, then the Python callable ``fun`` should have signature
- ``a -> (b, CT b -> CT a)`` where we use ``CT x`` to denote a cotangent type
- for ``x``. See the example below. That is, ``fun`` should return a pair where
- the first element represents the value of the mathematical function to be
- differentiated and the second element is a function that represents the custom
- VJP rule.
-
- The custom VJP function returned as the second element of the output of ``fun``
- can close over intermediate values computed when evaluating the function to be
- differentiated. That is, use lexical closure to share work between the forward
- pass and the backward pass of reverse-mode automatic differentiation.
-
- Args:
- fun: a Python callable specifying both the mathematical function to be
- differentiated and its reverse-mode differentiation rule. It should return
- a pair consisting of an output value and a Python callable that represents
- the custom gradient function.
-
- Returns:
- A Python callable with signature ``a -> b``, i.e. that returns the output
- value specified by the first element of ``fun``'s output pair. A side effect
- is that under-the-hood ``jax.defvjp_all`` is called to set up the returned
- Python callable with the custom VJP rule specified by the second element
- of ``fun``'s output pair.
-
- For example:
-
- >>> @jax.custom_gradient
- ... def f(x):
- ... return x ** 2, lambda g: (g * x,)
- ...
- >>> print(f(3.))
- 9.0
- >>> print(jax.grad(f)(3.))
- 3.0
-
- An example with a function on two arguments, so that the VJP function must
- return a tuple of length two:
-
- >>> @jax.custom_gradient
- ... def f(x, y):
- ... return x * y, lambda g: (y, x)
- ...
- >>> print(f(3., 4.))
- 12.0
- >>> print(jax.grad(f, argnums=(0, 1))(3., 4.))
- (4.0, 3.0)
- """
- def primal_fun(*args, **kwargs):
- ans, _ = fun(*args, **kwargs)
- return ans
- primal_fun = custom_transforms(primal_fun)
- defvjp_all(primal_fun, fun)
- return primal_fun
-
-
-def jarrett(fun):
- new_fun = custom_transforms(fun)
-
- def elementwise_jvp(primals, tangents):
- pushfwd = partial(jvp, fun, primals)
- y, jacs = vmap(pushfwd, out_axes=(None, 0))(_elementwise_std_basis(tangents))
- flat_tangents, _ = tree_flatten(tangents)
- out_tangent = sum([t * jac for t, jac in zip(flat_tangents, jacs)])
- return y, out_tangent
- defjvp_all(new_fun, elementwise_jvp)
-
- return new_fun
-
-def _elementwise_std_basis(pytree):
- leaves, _ = tree_flatten(pytree)
- arity = len(leaves)
- dims = map(onp.size, leaves)
- # TODO(mattjj): use symbolic constants
- dtype = dtypes.result_type(*leaves)
- if not dtypes.issubdtype(dtype, onp.floating):
- msg = ("Jacobian only defined for functions with floating input and output "
- "dtypes (i.e. dtypes that model real numbers), got {}.")
- raise TypeError(msg.format(dtype)) # TODO(mattjj, dougalm): handle complex
- basis_array = onp.stack([onp.concatenate(
- [onp.ones(dims[j], dtype) if i == j else onp.zeros(dims[j], dtype)
- for j in range(arity)]) for i in range(arity)])
- return _unravel_array_into_pytree(pytree, 1, basis_array)
-
-
-# This function mostly exists for making slides about JAX.
-def _make_graphviz(fun):
- """Adapts `fun` to return a graphviz dot string of its program representation.
-
- Args:
- fun: The function whose `jaxpr` is to be rendered into graphviz dot. Its
- positional arguments and return value should be arrays, scalars, or
- standard Python containers (tuple/list/dict) thereof.
-
- Returns:
- A wrapped version of `fun`, set up to return a graphviz dot string.
-
- See make_jaxpr for a related function.
- """
- # TODO(mattjj): handle eqn.restructure
- # TODO(mattjj): handle subjaxprs
-
- def pv_like(x):
- aval = xla.abstractify(x)
- return pe.PartialVal((aval, core.unit))
-
- id_names = ("id{}".format(i) for i in it.count())
-
- def jaxpr_to_graphviz(jaxpr, consts):
- fragment = []
-
- fragment.extend(map(invar_node, jaxpr.invars, jaxpr.invars))
- fragment.extend(map(constant_node, jaxpr.constvars, consts))
-
- for eqn in jaxpr.eqns:
- id_name = next(id_names)
- fragment.append(function_node(id_name, eqn.primitive.name))
- fragment.extend(edge(invar, id_name) for invar in eqn.invars)
- fragment.extend(edge(id_name, outvar) for outvar in eqn.outvars)
- for ov in jaxpr.outvars:
- fragment.append(outvar_node(ov, "out"))
- return graph(''.join(fragment))
-
- edge = '{} -> {} [color=gray30];\n'.format
- function_node = '{} [label="{}", shape=box, color=lightskyblue, style=filled];\n'.format
- invar_node = '{} [rank=2, label="{}", color=mediumspringgreen, style=filled];\n'.format
- outvar_node = '{} [label="{}", fillcolor=indianred1, style="filled,dashed", color=black];\n'.format
- constant_node = '{} [rank=2, label="{}", color=goldenrod1, style=filled];\n'.format
- freevar_node = '{} [rank=2, label="{}", color=palegreen, style=filled];\n'.format
- graph = 'digraph G {{{}}}'.format
-
- @wraps(fun)
- def graphviz_maker(*args, **kwargs):
- wrapped = lu.wrap_init(fun, kwargs)
- jax_args, in_tree = tree_flatten((args, kwargs))
- jaxtree_fun, out_tree = flatten_fun(wrapped, in_tree)
- pvals = map(pv_like, jax_args)
- jaxpr, _, consts = pe.trace_to_jaxpr(jaxtree_fun, pvals)
- return jaxpr_to_graphviz(jaxpr, consts)
-
- graphviz_maker.__name__ = "make_graphviz({})".format(graphviz_maker.__name__)
- return graphviz_maker
-
-
class ShapeDtypeStruct(object):
__slots__ = ["shape", "dtype"]
def __init__(self, shape, dtype):
diff --git a/jax/api_util.py b/jax/api_util.py
index 8e1b46623ac0..f9fc6cd6663d 100644
--- a/jax/api_util.py
+++ b/jax/api_util.py
@@ -16,7 +16,8 @@
from .tree_util import (build_tree, tree_flatten, tree_unflatten,
treedef_is_leaf)
from . import linear_util as lu
-from .util import safe_map, unzip2, partial, curry
+from .util import safe_map, unzip2, partial, curry, WrapHashably, Hashable
+from .core import unit
map = safe_map
@@ -70,3 +71,29 @@ def flatten_fun_nokwargs2(in_tree, *args_flat):
ans_flat, ans_tree = tree_flatten(ans)
aux_flat, aux_tree = tree_flatten(aux)
yield (ans_flat, aux_flat), (ans_tree, aux_tree)
+
+def argnums_partial(f, dyn_argnums, args):
+ if isinstance(dyn_argnums, int):
+ dyn_argnums = (dyn_argnums,)
+ else:
+ dyn_argnums = tuple(dyn_argnums)
+ fixed_args = tuple([unit if i in dyn_argnums else wrap_hashably(arg)
+ for i, arg in enumerate(args)])
+ dyn_args = tuple(args[i] for i in dyn_argnums)
+ return _argnums_partial(f, dyn_argnums, fixed_args), dyn_args
+
+def wrap_hashably(arg):
+ try:
+ hash(arg)
+ except TypeError:
+ return WrapHashably(arg) # e.g. ndarrays, DeviceArrays
+ else:
+ return Hashable(arg)
+
+@lu.transformation
+def _argnums_partial(dyn_argnums, fixed_args, *dyn_args, **kwargs):
+ args = [None if arg is unit else arg.val for arg in fixed_args]
+ for i, arg in zip(dyn_argnums, dyn_args):
+ args[i] = arg
+ ans = yield args, kwargs
+ yield ans
diff --git a/jax/core.py b/jax/core.py
index 799a09d13425..265093465c8c 100644
--- a/jax/core.py
+++ b/jax/core.py
@@ -300,13 +300,6 @@ def __init__(self, master, sublevel):
self.level = master.level
self.sublevel = sublevel
- def escaped_tracer_error(self, detail):
- msg = ("Encountered an unexpected tracer. Perhaps this tracer escaped "
- "through global state from a previously traced function.\n"
- "The functions being transformed should not save traced values to "
- "global state.\nDetails: {}.")
- raise ValueError(msg.format(detail))
-
def full_raise(self, val):
if not isinstance(val, Tracer):
return self.pure(val)
@@ -318,36 +311,43 @@ def full_raise(self, val):
elif val._trace.sublevel < sublevel:
return self.sublift(val)
else:
- self.escaped_tracer_error(
+ escaped_tracer_error(
"Can't lift sublevels {} to {}".format(val._trace.sublevel, sublevel))
elif val._trace.level < level:
if val._trace.sublevel > sublevel:
- self.escaped_tracer_error(
+ escaped_tracer_error(
"Incompatible sublevel: {}, {}".format(val._trace, (level, sublevel)))
return self.lift(val)
elif val._trace.level > level:
- self.escaped_tracer_error(
+ escaped_tracer_error(
"Can't lift level {} to {}".format(val, self))
else: # val._trace.level == self.level:
- self.escaped_tracer_error("Different traces at same level: {}, {}".format(val, self))
-
+ escaped_tracer_error("Different traces at same level: {}, {}".format(val, self))
def pure(self, val):
- assert False
+ raise NotImplementedError("must override")
def lift(self, tracer):
- assert False
+ raise NotImplementedError("must override")
def sublift(self, tracer):
- assert False
+ raise NotImplementedError("must override")
def process_primitive(self, primitive, tracers, params):
- assert False, "Must override"
+ raise NotImplementedError("must override")
def __repr__(self):
return '{}(level={}/{})'.format(
self.__class__.__name__, self.level, self.sublevel)
+def escaped_tracer_error(detail):
+ msg = ("Encountered an unexpected tracer. Perhaps this tracer escaped "
+ "through global state from a previously traced function.\n"
+ "The functions being transformed should not save traced values to "
+ "global state.\nDetails: {}.")
+ raise UnexpectedTracerError(msg.format(detail))
+
+class UnexpectedTracerError(Exception): pass
class Tracer(object):
__array_priority__ = 1000
@@ -355,7 +355,10 @@ class Tracer(object):
def __array__(self, *args, **kw):
raise Exception("Tracer can't be used with raw numpy functions. "
- "You might have\n import numpy as np\ninstead of\n import jax.numpy as np")
+ "You might have\n"
+ " import numpy as np\n"
+ "instead of\n"
+ " import jax.numpy as np")
def __init__(self, trace):
self._trace = trace
@@ -368,7 +371,7 @@ def __len__(self):
@property
def aval(self):
- assert False
+ raise NotImplementedError("must override")
def __neg__(self): return self.aval._neg(self)
def __pos__(self): return self.aval._pos(self)
diff --git a/jax/custom_derivatives.py b/jax/custom_derivatives.py
new file mode 100644
index 000000000000..a8709a40919a
--- /dev/null
+++ b/jax/custom_derivatives.py
@@ -0,0 +1,542 @@
+# coding=utf-8
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from functools import partial, update_wrapper
+import inspect
+import itertools as it
+
+from . import core
+from . import linear_util as lu
+from .tree_util import tree_flatten, tree_unflatten
+from .util import safe_zip, safe_map, unzip2, split_list, curry
+from .api_util import flatten_fun_nokwargs, argnums_partial, wrap_hashably
+from .abstract_arrays import raise_to_shaped
+from .ad_util import zero
+from .interpreters import partial_eval as pe
+from .interpreters import ad
+from .interpreters import batching
+from .interpreters import xla
+
+map = safe_map
+zip = safe_zip
+
+
+### util
+
+def _resolve_kwargs(fun, args, kwargs):
+ ba = inspect.signature(fun).bind(*args, **kwargs)
+ ba.apply_defaults()
+ if ba.kwargs:
+ raise TypeError("keyword arguments could not be resolved to positions")
+ else:
+ return ba.args
+
+def _initial_style_jaxpr(fun, in_avals):
+ in_pvals = [pe.PartialVal((aval, core.unit)) for aval in in_avals]
+ jaxpr, out_pvals, consts = pe.trace_to_jaxpr(fun, in_pvals, instantiate=True,
+ stage_out_calls=True)
+ out_avals = map(raise_to_shaped, unzip2(out_pvals)[0])
+ const_avals = [raise_to_shaped(core.get_aval(c)) for c in consts]
+ typed_jaxpr = core.TypedJaxpr(pe.convert_constvars_jaxpr(jaxpr),
+ (), const_avals + in_avals, out_avals)
+ return typed_jaxpr, consts
+
+def _add_args(f, extra_args, left):
+ return _add_args_(f, tuple(map(wrap_hashably, extra_args)), left)
+
+@lu.transformation
+def _add_args_(extra_args, left, *args, **kwargs):
+ extra_args = tuple([arg.val for arg in extra_args])
+ args = (extra_args + args) if left else (args + extra_args)
+ yield (yield args, kwargs)
+
+@curry
+def transformation_with_equal_aux(gen, fun: lu.WrappedFun, *gen_static_args):
+ out_store = StoreEqualValues()
+ out_thunk = lambda: out_store.val
+ return fun.wrap(gen, gen_static_args, out_store), out_thunk
+
+class StoreEqualValues(lu.Store):
+ """A Store that allows storing equal values multiple times."""
+ def store(self, val):
+ if self._val is not lu._EMPTY_STORE_VALUE:
+ try:
+ same = self._val == val
+ except:
+ same = False
+ if not same:
+ raise lu.StoreException("Store occupied")
+ self._val = val
+
+
+### JVPs
+
+class custom_jvp:
+ """Set up a JAX-transformable function for a custom JVP rule definition.
+
+ This class is meant to be used as a function decorator. Instances are
+ callables that behave similarly to the underlying function to which the
+ decorator was applied, except when a differentiation transformation (like
+ ``jax.jvp`` or ``jax.grad``) is applied, in which case a custom user-supplied
+ JVP rule function is used instead of tracing into and performing automatic
+ differentiation of the underlying function's implementation. There is a single
+ instance method, ``defjvp``, which defines the custom JVP rule.
+
+ For example:
+
+ import jax.numpy as np
+
+ @jax.custom_jvp
+ def f(x, y):
+ return np.sin(x) * y
+
+ @f.defjvp
+ def f_jvp(primals, tangents):
+ x, y = primals
+ x_dot, y_dot = tangents
+ primal_out = f(x, y)
+ tangent_out = np.cos(x) * x_dot * y - np.sin(x) * y_dot
+ return primal_out, tangent_out
+
+ For a more detailed introduction, see the tutorial_.
+
+ .. _tutorial: https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html
+ """
+
+ def __init__(self, fun, nondiff_argnums=()):
+ self.fun = fun
+ self.nondiff_argnums = nondiff_argnums
+ self.jvp = None
+ update_wrapper(self, fun)
+
+ def defjvp(self, jvp):
+ """Define a custom JVP rule for the function represented by this instance.
+
+ Args:
+ jvp: a Python callable representing the custom JVP rule. When there are no
+ ``nondiff_argnums``, the ``jvp`` function should accept two arguments,
+ where the first is a tuple of primal inputs and the second is a tuple of
+ tangent inputs. The lengths of both tuples is equal to the number of
+ parameters of the ``custom_jvp`` function. The ``jvp`` function should
+ produce as output a pair where the first element is the primal output
+ and the second element is the tangent output. Elements of the input and
+ output tuples may be arrays or any nested tuples/lists/dicts thereof.
+
+ Returns:
+ None.
+
+ Example:
+
+ import jax.numpy as np
+
+ @jax.custom_jvp
+ def f(x, y):
+ return np.sin(x) * y
+
+ @f.defjvp
+ def f_jvp(primals, tangents):
+ x, y = primals
+ x_dot, y_dot = tangents
+ primal_out = f(x, y)
+ tangent_out = np.cos(x) * x_dot * y - np.sin(x) * y_dot
+ return primal_out, tangent_out
+ """
+ self.jvp = jvp
+
+ def __call__(self, *args, **kwargs):
+ if not self.jvp:
+ msg = "No JVP defined for custom_jvp function {} using defjvp."
+ raise AttributeError(msg.format(self.__name__)) from None
+ args = _resolve_kwargs(self.fun, args, kwargs)
+ if self.nondiff_argnums:
+ dyn_argnums = [i for i in range(len(args)) if i not in self.nondiff_argnums]
+ f_, dyn_args = argnums_partial(lu.wrap_init(self.fun), dyn_argnums, args)
+ static_args = [args[i] for i in self.nondiff_argnums]
+ jvp = _add_args(lu.wrap_init(self.jvp), static_args, left=True)
+ else:
+ f_, dyn_args = lu.wrap_init(self.fun), args
+ jvp = lu.wrap_init(self.jvp)
+ args_flat, in_tree = tree_flatten(dyn_args)
+ flat_fun, out_tree1 = flatten_fun_nokwargs(f_, in_tree)
+ flat_jvp, out_tree2 = _flatten_jvp(jvp, in_tree)
+ out_flat = custom_jvp_call(flat_fun, *args_flat, jvp=flat_jvp)
+ try: out_tree = out_tree1()
+ except lu.StoreException: out_tree = out_tree2()
+ return tree_unflatten(out_tree, out_flat)
+
+@transformation_with_equal_aux
+def _flatten_jvp(in_tree, *args):
+ primals_in, tangents_in = split_list(args, [len(args) // 2])
+ py_primals = tree_unflatten(in_tree, primals_in)
+ py_tangents = tree_unflatten(in_tree, tangents_in)
+ py_primals_out, py_tangents_out = yield (py_primals, py_tangents), {}
+ primals_out, out_tree = tree_flatten(py_primals_out)
+ tangents_out, out_tree2 = tree_flatten(py_tangents_out)
+ if out_tree != out_tree2:
+ msg = ("Custom JVP rule must produce primal and tangent outputs with equal "
+ "container (pytree) structures, but got {} and {}.")
+ raise TypeError(msg.format(out_tree, out_tree2)) from None
+ yield primals_out + tangents_out, out_tree
+
+def _custom_deriv_call_bind(primitive, f, *args, **params):
+ top_trace = core.find_top_trace(args)
+ level = (core.trace_state.trace_stack.next_level(True)
+ if top_trace is None else top_trace.level)
+ if top_trace is None:
+ with core.new_sublevel():
+ return primitive.impl(f, *args, **params)
+ else:
+ tracers = map(top_trace.full_raise, args)
+ outs = top_trace.process_call(primitive, f, tracers, params)
+ outs = map(core.full_lower, outs)
+ return outs
+
+def _custom_call_impl(f, *args, **params):
+ return f.call_wrapped(*args)
+
+custom_jvp_call_p = core.Primitive('custom_jvp_call')
+custom_jvp_call_p.multiple_results = True
+custom_jvp_call = partial(_custom_deriv_call_bind, custom_jvp_call_p)
+custom_jvp_call_p.def_custom_bind(custom_jvp_call)
+custom_jvp_call_p.def_impl(_custom_call_impl)
+
+def _custom_jvp_call_jvp(trace, call_primitive, fun, tracers, params):
+ primals_in, tangents_in = unzip2((t.primal, t.tangent) for t in tracers)
+ primals_in = map(core.full_lower, primals_in)
+ tangents_in = map(ad.instantiate_zeros, primals_in, tangents_in)
+ outs = params['jvp'].call_wrapped(*it.chain(primals_in, tangents_in))
+ primals_out, tangents_out = split_list(outs, [len(outs) // 2])
+ return map(partial(ad.JVPTracer, trace), primals_out, tangents_out)
+ad.call_jvp_rules[custom_jvp_call_p] = _custom_jvp_call_jvp
+
+def _custom_jvp_call_vmap(trace, call_primitive, fun, tracers, params):
+ in_vals, in_dims = unzip2((t.val, t.batch_dim) for t in tracers)
+ jvp = params['jvp']
+ fun, out_dims = batching.batch_subtrace(fun, trace.master, in_dims)
+ jvp, out_dims2 = batching.batch_subtrace(jvp, trace.master, in_dims * 2)
+ out_vals = custom_jvp_call(fun, *in_vals, jvp=jvp)
+ try: out_dims = out_dims()
+ except lu.StoreException: out_dims = out_dims2()[:len(out_vals)]
+ return [batching.BatchTracer(trace, v, d) for v, d in zip(out_vals, out_dims)]
+batching.call_batching_rules[custom_jvp_call_p] = _custom_jvp_call_vmap
+
+def _custom_jvp_call_partial_eval(trace, call_primitive, fun, tracers, params):
+ return custom_jvp_call_jaxpr(fun, params['jvp'], *tracers)
+pe.call_partial_eval_rules[custom_jvp_call_p] = _custom_jvp_call_partial_eval
+
+
+def custom_jvp_call_jaxpr(fun, jvp, *args):
+ in_avals = [raise_to_shaped(core.get_aval(x)) for x in args]
+ jaxpr, consts = _initial_style_jaxpr(fun, in_avals)
+ return custom_jvp_call_jaxpr_p.bind(*it.chain(consts, args), jaxpr=jaxpr,
+ jvp=jvp, num_consts=len(consts))
+
+def _custom_call_jaxpr_impl(*args, jaxpr, **kwargs):
+ del kwargs
+ return core.jaxpr_as_fun(jaxpr)(*args)
+
+def _custom_call_jaxpr_abstract_eval(*args, jaxpr, **kwargs):
+ del kwargs
+ return jaxpr.out_avals
+
+def _custom_jvp_call_jaxpr_jvp(primals, tangents, jaxpr, jvp, num_consts):
+ _, primals = split_list(primals, [num_consts])
+ zero_tangents, tangents = split_list(tangents, [num_consts])
+ assert all(t is zero for t in zero_tangents)
+ outs = jvp.call_wrapped(*(primals + tangents))
+ primals_out, tangents_out = split_list(outs, [len(outs) // 2])
+ return primals_out, tangents_out
+
+def _custom_jvp_call_jaxpr_vmap(args, in_dims, jaxpr, jvp, num_consts):
+ size, = {x.shape[d] for x, d in zip(args, in_dims)
+ if d is not batching.not_mapped}
+ args = [batching.moveaxis(x, d, 0) if d is not batching.not_mapped and d != 0
+ else x for x, d in zip(args, in_dims)]
+ in_batched = [d is not batching.not_mapped for d in in_dims]
+ del in_dims
+ batched_jaxpr, out_batched = batching.batch_jaxpr(jaxpr, size, in_batched, False)
+ out_dims = [0 if b else batching.not_mapped for b in out_batched]
+
+ jvp_in_dims = [0 if b else batching.not_mapped for b in in_batched] * 2
+ batched_jvp = batching.batch_fun(jvp, jvp_in_dims, lambda: out_dims * 2)
+
+ batched_outs = custom_jvp_call_jaxpr_p.bind(
+ *args, jaxpr=batched_jaxpr, jvp=batched_jvp, num_consts=num_consts)
+ return batched_outs, out_dims
+
+# If a (multi)linear function is defined with a custom jvp, then
+# custom_jvp_call_jaxpr can appear in jaxprs to be transposed. We transpose it
+# like a core.call.
+def _custom_jvp_call_jaxpr_transpose(cts, *args, jaxpr, **kwargs):
+ name = 'custom_jvp_call_jaxpr_linear'
+ return ad.call_transpose(core.call_p, dict(name=name), jaxpr.jaxpr,
+ tuple(jaxpr.literals) + args, cts)
+
+custom_jvp_call_jaxpr_p = core.Primitive('custom_jvp_call_jaxpr')
+custom_jvp_call_jaxpr_p.multiple_results = True
+custom_jvp_call_jaxpr_p.def_impl(_custom_call_jaxpr_impl)
+custom_jvp_call_jaxpr_p.def_abstract_eval(_custom_call_jaxpr_abstract_eval)
+ad.primitive_jvps[custom_jvp_call_jaxpr_p] = _custom_jvp_call_jaxpr_jvp
+ad.primitive_transposes[custom_jvp_call_jaxpr_p] = _custom_jvp_call_jaxpr_transpose
+batching.primitive_batchers[custom_jvp_call_jaxpr_p] = _custom_jvp_call_jaxpr_vmap
+xla.initial_style_translations[custom_jvp_call_jaxpr_p] = \
+ xla.lower_fun_initial_style(_custom_call_jaxpr_impl)
+
+
+### VJPs
+
+class custom_vjp:
+ """Set up a JAX-transformable function for a custom VJP rule definition.
+
+ This class is meant to be used as a function decorator. Instances are
+ callables that behave similarly to the underlying function to which the
+ decorator was applied, except when a reverse-mode differentiation
+ transformation (like ``jax.grad``) is applied, in which case a custom
+ user-supplied VJP rule function is used instead of tracing into and performing
+ automatic differentiation of the underlying function's implementation. There
+ is a single instance method, ``defvjp``, which defines the custom VJP rule.
+
+ This decorator precludes the use of forward-mode automatic differentiation.
+
+ For example:
+
+ import jax.numpy as np
+
+ @jax.custom_vjp
+ def f(x, y):
+ return np.sin(x) * y
+
+ def f_fwd(x, y):
+ return f(x, y), (np.cos(x), np.sin(x), y)
+
+ def f_bwd(res, g):
+ cos_x, sin_x, y = res
+ return (cos_x * g * y, -sin_x * g)
+
+ f.defvjp(f_fwd, f_bwd)
+
+ For a more detailed introduction, see the tutorial_.
+
+ .. _tutorial: https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html
+ """
+
+ def __init__(self, fun, nondiff_argnums=()):
+ self.fun = fun
+ self.nondiff_argnums = nondiff_argnums
+ self.fwd = None
+ self.bwd = None
+ update_wrapper(self, fun)
+
+ def defvjp(self, fwd, bwd):
+ """Define a custom VJP rule for the function represented by this instance.
+
+ Args:
+ fwd: a Python callable representing the forward pass of the custom VJP
+ rule. When there are no ``nondiff_argnums``, the ``fwd`` function has
+ the same input signature as the underlying primal function. It should
+ return as output a pair, where the first element represents the primal
+ output and the second element represents any "residual" values to store
+ from the forward pass for use on the backward pass by the function
+ ``bwd``. Input arguments and elements of the output pair may be arrays
+ or nested tuples/lists/dicts thereof.
+ bwd: a Python callable representing the backward pass of the custom VJP
+ rule. When there are no ``nondiff_argnums``, the ``bwd`` function takes
+ two arguments, where the first is the "residual" values produced on the
+ forward pass by ``fwd``, and the second is the output cotangent with the
+ same structure as the primal function output. The output of ``bwd`` must
+ be a tuple of length equal to the number of arguments of the primal
+ function, and the tuple elements may be arrays or nested
+ tuples/lists/dicts thereof so as to match the structure of the primal
+ input arguments.
+
+ Returns:
+ None.
+
+ Example:
+
+ import jax.numpy as np
+
+ @jax.custom_vjp
+ def f(x, y):
+ return np.sin(x) * y
+
+ def f_fwd(x, y):
+ return f(x, y), (np.cos(x), np.sin(x), y)
+
+ def f_bwd(res, g):
+ cos_x, sin_x, y = res
+ return (cos_x * g * y, -sin_x * g)
+
+ f.defvjp(f_fwd, f_bwd)
+ """
+ self.fwd = fwd
+ self.bwd = bwd
+
+ def __call__(self, *args, **kwargs):
+ if not self.fwd or not self.bwd:
+ msg = "No VJP defined for custom_vjp function {} using defvjp."
+ raise AttributeError(msg.format(self.__name__))
+ args = _resolve_kwargs(self.fun, args, kwargs)
+ if self.nondiff_argnums:
+ dyn_argnums = [i for i in range(len(args)) if i not in self.nondiff_argnums]
+ f_, dyn_args = argnums_partial(lu.wrap_init(self.fun), dyn_argnums, args)
+ static_args = [args[i] for i in self.nondiff_argnums]
+ fwd, _ = argnums_partial(lu.wrap_init(self.fwd), dyn_argnums, args)
+ bwd = _add_args(lu.wrap_init(self.bwd), static_args, left=True)
+ else:
+ f_, dyn_args = lu.wrap_init(self.fun), args
+ fwd, bwd = lu.wrap_init(self.fwd), lu.wrap_init(self.bwd)
+ args_flat, in_tree = tree_flatten(dyn_args)
+ flat_fun, out_tree = flatten_fun_nokwargs(f_, in_tree)
+ flat_fwd, out_trees = _flatten_fwd(fwd, in_tree)
+ flat_bwd = _flatten_bwd(bwd, in_tree, out_trees)
+ out_flat = custom_vjp_call(flat_fun, *args_flat, fwd=flat_fwd, bwd=flat_bwd,
+ out_trees=out_trees)
+ try: out_tree = out_tree()
+ except lu.StoreException: out_tree, _ = out_trees()
+ return tree_unflatten(out_tree, out_flat)
+
+custom_vjp_call_p = core.Primitive('custom_vjp_call')
+custom_vjp_call_p.multiple_results = True
+custom_vjp_call = partial(_custom_deriv_call_bind, custom_vjp_call_p)
+custom_vjp_call_p.def_custom_bind(custom_vjp_call)
+custom_vjp_call_p.def_impl(_custom_call_impl)
+
+@transformation_with_equal_aux
+def _flatten_fwd(in_tree, *args):
+ py_args = tree_unflatten(in_tree, args)
+ py_outs, res = yield py_args, {}
+ out, out_tree = tree_flatten(py_outs)
+ res, res_tree = tree_flatten(res)
+ yield res + out, (out_tree, res_tree)
+
+@lu.transformation
+def _flatten_bwd(in_tree, out_trees, *args):
+ out_tree, res_tree = out_trees()
+ res, cts_out = split_list(args, [res_tree.num_leaves])
+ py_res = tree_unflatten(res_tree, res)
+ py_cts_out = tree_unflatten(out_tree, cts_out)
+ py_cts_in = yield (py_res, py_cts_out), {}
+ cts_in, in_tree2 = tree_flatten(py_cts_in)
+ if in_tree != in_tree2:
+ msg = ("Custom VJP rule must produce an output with the same container "
+ "(pytree) structure as the args tuple of the primal function, "
+ "and in particular must produce a tuple of length equal to the "
+ "number of arguments to the primal function, but got VJP output "
+ "structure {} for primal input structure {}.")
+ raise TypeError(msg.format(in_tree2, in_tree)) from None
+ yield cts_in
+
+def _custom_vjp_call_jvp(trace, call_primitive, fun, tracers, params):
+ primals_in, tangents_in = unzip2((t.primal, t.tangent) for t in tracers)
+ tangents_in = map(ad.instantiate_zeros, primals_in, tangents_in)
+ fwd, bwd, out_trees = params['fwd'], params['bwd'], params['out_trees']
+ res_and_primals_out = fwd.call_wrapped(*map(core.full_lower, primals_in))
+ out_tree, res_tree = out_trees()
+ res, primals_out = split_list(res_and_primals_out, [res_tree.num_leaves])
+ avals_out = [raise_to_shaped(core.get_aval(x)) for x in primals_out]
+ tangents_out = custom_lin_p.bind(
+ *it.chain(res, tangents_in), num_res=res_tree.num_leaves, bwd=bwd,
+ avals_out=avals_out)
+ return map(partial(ad.JVPTracer, trace), primals_out, tangents_out)
+ad.call_jvp_rules[custom_vjp_call_p] = _custom_vjp_call_jvp
+
+def _custom_vjp_call_vmap(trace, call_primitive, fun, tracers, params):
+ in_vals, in_dims = unzip2((t.val, t.batch_dim) for t in tracers)
+ fwd, bwd, out_trees = params['fwd'], params['bwd'], params['out_trees']
+ fun, out_dims = batching.batch_subtrace(fun, trace.master, in_dims)
+ fwd, out_dims2 = batching.batch_subtrace(fwd, trace.master, in_dims)
+ bwd = batching.batch_fun(bwd, out_dims2, in_dims)
+ out_vals = custom_vjp_call(fun, *in_vals, fwd=fwd, bwd=bwd,
+ out_trees=out_trees)
+ try: out_dims = out_dims()
+ except lu.StoreException: out_dims = out_dims2()
+ out_dims = out_dims[-len(out_vals) % len(out_dims):]
+ return [batching.BatchTracer(trace, v, d) for v, d in zip(out_vals, out_dims)]
+batching.call_batching_rules[custom_vjp_call_p] = _custom_vjp_call_vmap
+
+def _custom_vjp_call_partial_eval(trace, call_primitive, fun, tracers, params):
+ return custom_vjp_call_jaxpr(fun, params['fwd'], params['bwd'],
+ params['out_trees'], *tracers)
+pe.call_partial_eval_rules[custom_vjp_call_p] = _custom_vjp_call_partial_eval
+
+
+custom_lin_p = core.Primitive('custom_lin')
+custom_lin_p.def_abstract_eval(lambda *_, avals_out, **__: avals_out)
+custom_lin_p.multiple_results = True
+
+def _raise_custom_vjp_error_on_jvp(*args, **kwargs):
+ raise TypeError("can't apply forward-mode autodiff (jvp) to a custom_vjp "
+ "function.")
+custom_lin_p.def_impl(_raise_custom_vjp_error_on_jvp)
+
+def _custom_lin_transpose(cts_out, *invals, num_res, bwd, avals_out):
+ res, _ = split_list(invals, [num_res])
+ cts_out = map(ad.instantiate_zeros_aval, avals_out, cts_out)
+ cts_in = bwd.call_wrapped(*(res + cts_out))
+ cts_in_flat, in_tree = tree_flatten(cts_in)
+ return [None] * num_res + cts_in_flat
+ad.primitive_transposes[custom_lin_p] = _custom_lin_transpose
+
+
+def custom_vjp_call_jaxpr(fun, fwd, bwd, out_trees, *args):
+ in_avals = [raise_to_shaped(core.get_aval(x)) for x in args]
+ jaxpr, consts = _initial_style_jaxpr(fun, in_avals)
+ return custom_vjp_call_jaxpr_p.bind(
+ *it.chain(consts, args), jaxpr=jaxpr, fwd=fwd, bwd=bwd,
+ out_trees=out_trees, num_consts=len(consts))
+
+def _custom_vjp_call_jaxpr_jvp(primals, tangents, jaxpr, fwd, bwd, out_trees,
+ num_consts):
+ _, primals = split_list(primals, [num_consts])
+ zero_tangents, tangents = split_list(tangents, [num_consts])
+ assert all(t is zero for t in zero_tangents)
+ res_and_primals_out = fwd.call_wrapped(*primals)
+ out_tree, res_tree = out_trees()
+ res, primals_out = split_list(res_and_primals_out, [res_tree.num_leaves])
+ avals_out = [raise_to_shaped(core.get_aval(x)) for x in primals_out]
+ tangents_out = custom_lin_p.bind(
+ *it.chain(res, tangents), num_res=res_tree.num_leaves, bwd=bwd,
+ avals_out=avals_out)
+ return primals_out, tangents_out
+
+def _custom_vjp_call_jaxpr_vmap(args, in_dims, jaxpr, fwd, bwd, out_trees,
+ num_consts):
+ size, = {x.shape[d] for x, d in zip(args, in_dims)
+ if d is not batching.not_mapped}
+ args = [batching.moveaxis(x, d, 0) if d is not batching.not_mapped and d != 0
+ else x for x, d in zip(args, in_dims)]
+ in_batched = [d is not batching.not_mapped for d in in_dims]
+ del in_dims
+ batched_jaxpr, out_batched = batching.batch_jaxpr(jaxpr, size, in_batched, False)
+ out_dims = [0 if b else batching.not_mapped for b in out_batched]
+
+ fwd_in_dims = [0 if b else batching.not_mapped for b in in_batched]
+ batched_fwd, fwd_out_dims = batching.batch_fun2(fwd, fwd_in_dims)
+ batched_bwd = batching.batch_fun(bwd, fwd_out_dims, fwd_in_dims)
+
+ batched_outs = custom_vjp_call_jaxpr_p.bind(
+ *args, jaxpr=batched_jaxpr, fwd=batched_fwd, bwd=batched_bwd,
+ out_trees=out_trees, num_consts=num_consts)
+ return batched_outs, out_dims
+
+custom_vjp_call_jaxpr_p = core.Primitive('custom_vjp_call_jaxpr')
+custom_vjp_call_jaxpr_p.multiple_results = True
+custom_vjp_call_jaxpr_p.def_impl(_custom_call_jaxpr_impl)
+custom_vjp_call_jaxpr_p.def_abstract_eval(_custom_call_jaxpr_abstract_eval)
+ad.primitive_jvps[custom_vjp_call_jaxpr_p] = _custom_vjp_call_jaxpr_jvp
+batching.primitive_batchers[custom_vjp_call_jaxpr_p] = _custom_vjp_call_jaxpr_vmap
+xla.initial_style_translations[custom_vjp_call_jaxpr_p] = \
+ xla.lower_fun_initial_style(_custom_call_jaxpr_impl)
diff --git a/jax/experimental/loops.py b/jax/experimental/loops.py
index f21394abc335..a12f453add2d 100644
--- a/jax/experimental/loops.py
+++ b/jax/experimental/loops.py
@@ -370,7 +370,7 @@ def end_tracing_body(self):
in_tracers=in_tracers,
out_tracers=body_out_tracers,
trace=self.trace)
- except ValueError as e:
+ except core.UnexpectedTracerError as e:
if "Tracer not among input tracers" in str(e):
raise ValueError("Body of cond_range or while_range should not use the "
"index variable returned by iterator.") from e
diff --git a/jax/experimental/ode.py b/jax/experimental/ode.py
index 719c0dad5397..a52d7a81b3c3 100644
--- a/jax/experimental/ode.py
+++ b/jax/experimental/ode.py
@@ -23,20 +23,36 @@
"""
-import functools
+from functools import partial
+import operator as op
import time
import jax
-from jax.flatten_util import ravel_pytree
-import jax.lax
import jax.numpy as np
-import jax.ops
-from jax.test_util import check_vjp
+from jax import lax
+from jax import ops
+from jax.util import safe_map, safe_zip
+from jax.flatten_util import ravel_pytree
+from jax.test_util import check_grads
+from jax.tree_util import tree_map
+from jax import linear_util as lu
import numpy as onp
import scipy.integrate as osp_integrate
+map = safe_map
+zip = safe_zip
+
+
+def ravel_first_arg(f, unravel):
+ return ravel_first_arg_(lu.wrap_init(f), unravel).call_wrapped
+
+@lu.transformation
+def ravel_first_arg_(unravel, y_flat, *args):
+ y = unravel(y_flat)
+ ans = yield (y,) + args, {}
+ ans_flat, _ = ravel_pytree(ans)
+ yield ans_flat
-@jax.jit
def interp_fit_dopri(y0, y1, k, dt):
# Fit a polynomial to the results of a Runge-Kutta step.
dps_c_mid = np.array([
@@ -46,562 +62,234 @@ def interp_fit_dopri(y0, y1, k, dt):
y_mid = y0 + dt * np.dot(dps_c_mid, k)
return np.array(fit_4th_order_polynomial(y0, y1, y_mid, k[0], k[-1], dt))
-
-@jax.jit
def fit_4th_order_polynomial(y0, y1, y_mid, dy0, dy1, dt):
- """Fit fourth order polynomial over function interval.
-
- Args:
- y0: function value at the start of the interval.
- y1: function value at the end of the interval.
- y_mid: function value at the mid-point of the interval.
- dy0: derivative value at the start of the interval.
- dy1: derivative value at the end of the interval.
- dt: width of the interval.
- Returns:
- Coefficients `[a, b, c, d, e]` for the polynomial
- p = a * x ** 4 + b * x ** 3 + c * x ** 2 + d * x + e
- """
- v = np.stack([dy0, dy1, y0, y1, y_mid])
- a = np.dot(np.hstack([-2. * dt, 2. * dt, np.array([-8., -8., 16.])]), v)
- b = np.dot(np.hstack([5. * dt, -3. * dt, np.array([18., 14., -32.])]), v)
- c = np.dot(np.hstack([-4. * dt, dt, np.array([-11., -5., 16.])]), v)
+ a = -2.*dt*dy0 + 2.*dt*dy1 - 8.*y0 - 8.*y1 + 16.*y_mid
+ b = 5.*dt*dy0 - 3.*dt*dy1 + 18.*y0 + 14.*y1 - 32.*y_mid
+ c = -4.*dt*dy0 + dt*dy1 - 11.*y0 - 5.*y1 + 16.*y_mid
d = dt * dy0
e = y0
return a, b, c, d, e
-
-@functools.partial(jax.jit, static_argnums=(0,))
def initial_step_size(fun, t0, y0, order, rtol, atol, f0):
- """Empirically choose initial step size.
-
- Args:
- fun: Function to evaluate like `func(y, t)` to compute the time
- derivative of `y`.
- t0: initial time.
- y0: initial value for the state.
- order: order of interpolation
- rtol: relative local error tolerance for solver.
- atol: absolute local error tolerance for solver.
- f0: initial value for the derivative, computed from `func(t0, y0)`.
- Returns:
- Initial step size for odeint algorithm.
-
- Algorithm from:
- E. Hairer, S. P. Norsett G. Wanner,
- Solving Ordinary Differential Equations I: Nonstiff Problems, Sec. II.4.
- """
+ # Algorithm from:
+ # E. Hairer, S. P. Norsett G. Wanner,
+ # Solving Ordinary Differential Equations I: Nonstiff Problems, Sec. II.4.
scale = atol + np.abs(y0) * rtol
d0 = np.linalg.norm(y0 / scale)
d1 = np.linalg.norm(f0 / scale)
- order_pow = (1. / (order + 1.))
- h0 = np.where(np.any(np.asarray([d0 < 1e-5, d1 < 1e-5])),
- 1e-6,
- 0.01 * d0 / d1)
+ h0 = np.where((d0 < 1e-5) | (d1 < 1e-5), 1e-6, 0.01 * d0 / d1)
y1 = y0 + h0 * f0
f1 = fun(y1, t0 + h0)
d2 = np.linalg.norm((f1 - f0) / scale) / h0
- h1 = np.where(np.all(np.asarray([d1 <= 1e-15, d2 <= 1e-15])),
+ h1 = np.where((d1 <= 1e-15) & (d2 <= 1e-15),
np.maximum(1e-6, h0 * 1e-3),
- (0.01 / np.max(d1 + d2))**order_pow)
+ (0.01 / np.max(d1 + d2)) ** (1. / (order + 1.)))
return np.minimum(100. * h0, h1)
-
-@functools.partial(jax.jit, static_argnums=(0,))
def runge_kutta_step(func, y0, f0, t0, dt):
- """Take an arbitrary Runge-Kutta step and estimate error.
-
- Args:
- func: Function to evaluate like `func(y, t)` to compute the time
- derivative of `y`.
- y0: initial value for the state.
- f0: initial value for the derivative, computed from `func(t0, y0)`.
- t0: initial time.
- dt: time step.
- alpha, beta, c: Butcher tableau describing how to take the Runge-Kutta
- step.
-
- Returns:
- y1: estimated function at t1 = t0 + dt
- f1: derivative of the state at t1
- y1_error: estimated error at t1
- k: list of Runge-Kutta coefficients `k` used for calculating these terms.
- """
# Dopri5 Butcher tableaux
alpha = np.array([1 / 5, 3 / 10, 4 / 5, 8 / 9, 1., 1., 0])
- beta = np.array(
- [[1 / 5, 0, 0, 0, 0, 0, 0],
- [3 / 40, 9 / 40, 0, 0, 0, 0, 0],
- [44 / 45, -56 / 15, 32 / 9, 0, 0, 0, 0],
- [19372 / 6561, -25360 / 2187, 64448 / 6561, -212 / 729, 0, 0, 0],
- [9017 / 3168, -355 / 33, 46732 / 5247, 49 / 176, -5103 / 18656, 0, 0],
- [35 / 384, 0, 500 / 1113, 125 / 192, -2187 / 6784, 11 / 84, 0]])
- c_sol = np.array([35 / 384, 0, 500 / 1113, 125 / 192, -2187 / 6784, 11 / 84,
- 0])
+ beta = np.array([
+ [1 / 5, 0, 0, 0, 0, 0, 0],
+ [3 / 40, 9 / 40, 0, 0, 0, 0, 0],
+ [44 / 45, -56 / 15, 32 / 9, 0, 0, 0, 0],
+ [19372 / 6561, -25360 / 2187, 64448 / 6561, -212 / 729, 0, 0, 0],
+ [9017 / 3168, -355 / 33, 46732 / 5247, 49 / 176, -5103 / 18656, 0, 0],
+ [35 / 384, 0, 500 / 1113, 125 / 192, -2187 / 6784, 11 / 84, 0]
+ ])
+ c_sol = np.array([35 / 384, 0, 500 / 1113, 125 / 192, -2187 / 6784, 11 / 84, 0])
c_error = np.array([35 / 384 - 1951 / 21600, 0, 500 / 1113 - 22642 / 50085,
125 / 192 - 451 / 720, -2187 / 6784 - -12231 / 42400,
11 / 84 - 649 / 6300, -1. / 60.])
- def _fori_body_fun(i, val):
+ def body_fun(i, k):
ti = t0 + dt * alpha[i-1]
- yi = y0 + dt * np.dot(beta[i-1, :], val)
+ yi = y0 + dt * np.dot(beta[i-1, :], k)
ft = func(yi, ti)
- return jax.ops.index_update(val, jax.ops.index[i, :], ft)
+ return ops.index_update(k, jax.ops.index[i, :], ft)
- k = jax.lax.fori_loop(
- 1,
- 7,
- _fori_body_fun,
- jax.ops.index_update(np.zeros((7, f0.shape[0])), jax.ops.index[0, :], f0))
+ k = ops.index_update(np.zeros((7, f0.shape[0])), ops.index[0, :], f0)
+ k = lax.fori_loop(1, 7, body_fun, k)
y1 = dt * np.dot(c_sol, k) + y0
y1_error = dt * np.dot(c_error, k)
f1 = k[-1]
return y1, f1, y1_error, k
-
-@jax.jit
def error_ratio(error_estimate, rtol, atol, y0, y1):
err_tol = atol + rtol * np.maximum(np.abs(y0), np.abs(y1))
err_ratio = error_estimate / err_tol
- return np.mean(err_ratio**2)
+ return np.mean(err_ratio ** 2)
-
-@jax.jit
-def optimal_step_size(last_step,
- mean_error_ratio,
- safety=0.9,
- ifactor=10.0,
- dfactor=0.2,
- order=5.0):
+def optimal_step_size(last_step, mean_error_ratio, safety=0.9, ifactor=10.0,
+ dfactor=0.2, order=5.0):
"""Compute optimal Runge-Kutta stepsize."""
mean_error_ratio = np.max(mean_error_ratio)
- dfactor = np.where(mean_error_ratio < 1,
- 1.0,
- dfactor)
+ dfactor = np.where(mean_error_ratio < 1, 1.0, dfactor)
err_ratio = np.sqrt(mean_error_ratio)
factor = np.maximum(1.0 / ifactor,
- np.minimum(err_ratio**(1.0 / order) / safety,
- 1.0 / dfactor))
- return np.where(mean_error_ratio == 0,
- last_step * ifactor,
- last_step / factor,)
-
+ np.minimum(err_ratio**(1.0 / order) / safety, 1.0 / dfactor))
+ return np.where(mean_error_ratio == 0, last_step * ifactor, last_step / factor)
-@functools.partial(jax.jit, static_argnums=(0,))
-def odeint(ofunc, y0, t, *args, **kwargs):
+def odeint(func, y0, t, *args, rtol=1.4e-8, atol=1.4e-8, mxstep=np.inf):
"""Adaptive stepsize (Dormand-Prince) Runge-Kutta odeint implementation.
Args:
- ofunc: Function to evaluate `yt = ofunc(y, t, *args)` that
- returns the time derivative of `y`.
- y0: initial value for the state.
- t: Timespan for `ofunc` evaluation like `np.linspace(0., 10., 101)`.
- *args: Additional arguments to `ofunc` beyond y0 and t.
- **kwargs: Two relevant keyword arguments:
- 'rtol': Relative local error tolerance for solver.
- 'atol': Absolute local error tolerance for solver.
- 'mxstep': Maximum number of steps to take for each timepoint.
+ func: function to evaluate the time derivative of the solution `y` at time
+ `t` as `func(y, t, *args)`, producing the same shape/structure as `y0`.
+ y0: array or pytree of arrays representing the initial value for the state.
+ t: array of float times for evaluation, like `np.linspace(0., 10., 101)`,
+ in which the values must be strictly increasing.
+ *args: tuple of additional arguments for `func`.
+ rtol: float, relative local error tolerance for solver (optional).
+ atol: float, absolute local error tolerance for solver (optional).
+ mxstep: int, maximum number of steps to take for each timepoint (optional).
Returns:
- Integrated system values at each timepoint.
+ Values of the solution `y` (i.e. integrated system values) at each time
+ point in `t`, represented as an array (or pytree of arrays) with the same
+ shape/structure as `y0` except with a new leading axis of length `len(t)`.
"""
- rtol = kwargs.get('rtol', 1.4e-8)
- atol = kwargs.get('atol', 1.4e-8)
- mxstep = kwargs.get('mxstep', np.inf)
-
- @functools.partial(jax.jit, static_argnums=(0,))
- def _fori_body_fun(func, i, val):
- """Internal fori_loop body to interpolate an integral at each timestep."""
- t, cur_y, cur_f, cur_t, dt, last_t, interp_coeff, solution = val
- cur_y, cur_f, cur_t, dt, last_t, interp_coeff, _ = jax.lax.while_loop(
- lambda x: (x[2] < t[i]) & (x[-1] < mxstep),
- functools.partial(_while_body_fun, func),
- (cur_y, cur_f, cur_t, dt, last_t, interp_coeff, 0.))
-
- relative_output_time = (t[i] - last_t) / (cur_t - last_t)
- out_x = np.polyval(interp_coeff, relative_output_time)
-
- return (t, cur_y, cur_f, cur_t, dt, last_t, interp_coeff,
- jax.ops.index_update(solution,
- jax.ops.index[i, :],
- out_x))
-
- @functools.partial(jax.jit, static_argnums=(0,))
- def _while_body_fun(func, x):
- """Internal while_loop body to determine interpolation coefficients."""
- cur_y, cur_f, cur_t, dt, last_t, interp_coeff, j = x
- next_t = cur_t + dt
- next_y, next_f, next_y_error, k = runge_kutta_step(
- func, cur_y, cur_f, cur_t, dt)
- error_ratios = error_ratio(next_y_error, rtol, atol, cur_y, next_y)
- new_interp_coeff = interp_fit_dopri(cur_y, next_y, k, dt)
- dt = optimal_step_size(dt, error_ratios)
-
- next_j = j + 1
- new_rav, unravel = ravel_pytree(
- (next_y, next_f, next_t, dt, cur_t, new_interp_coeff, next_j))
- old_rav, _ = ravel_pytree(
- (cur_y, cur_f, cur_t, dt, last_t, interp_coeff, next_j))
-
- return unravel(np.where(np.all(error_ratios <= 1.),
- new_rav,
- old_rav))
-
- func = lambda y, t: ofunc(y, t, *args)
- f0 = func(y0, t[0])
- dt = initial_step_size(func, t[0], y0, 4, rtol, atol, f0)
+ return _odeint_wrapper(func, rtol, atol, mxstep, y0, t, *args)
+
+@partial(jax.jit, static_argnums=(0, 1, 2, 3))
+def _odeint_wrapper(func, rtol, atol, mxstep, y0, ts, *args):
+ y0, unravel = ravel_pytree(y0)
+ func = ravel_first_arg(func, unravel)
+ out = _odeint(func, rtol, atol, mxstep, y0, ts, *args)
+ return jax.vmap(unravel)(out)
+
+@partial(jax.custom_vjp, nondiff_argnums=(0, 1, 2, 3))
+def _odeint(func, rtol, atol, mxstep, y0, ts, *args):
+ func_ = lambda y, t: func(y, t, *args)
+
+ def scan_fun(carry, target_t):
+
+ def cond_fun(state):
+ i, _, _, t, _, _, _ = state
+ return (t < target_t) & (i < mxstep)
+
+ def body_fun(state):
+ i, y, f, t, dt, last_t, interp_coeff = state
+ next_y, next_f, next_y_error, k = runge_kutta_step(func_, y, f, t, dt)
+ next_t = t + dt
+ error_ratios = error_ratio(next_y_error, rtol, atol, y, next_y)
+ new_interp_coeff = interp_fit_dopri(y, next_y, k, dt)
+ dt = optimal_step_size(dt, error_ratios)
+
+ new = [i + 1, next_y, next_f, next_t, dt, t, new_interp_coeff]
+ old = [i + 1, y, f, t, dt, last_t, interp_coeff]
+ return map(partial(np.where, np.all(error_ratios <= 1.)), new, old)
+
+ _, *carry = lax.while_loop(cond_fun, body_fun, [0] + carry)
+ _, _, t, _, last_t, interp_coeff = carry
+ relative_output_time = (target_t - last_t) / (t - last_t)
+ y_target = np.polyval(interp_coeff, relative_output_time)
+ return carry, y_target
+
+ f0 = func_(y0, ts[0])
+ dt = initial_step_size(func_, ts[0], y0, 4, rtol, atol, f0)
interp_coeff = np.array([y0] * 5)
+ init_carry = [y0, f0, ts[0], dt, ts[0], interp_coeff]
+ _, ys = lax.scan(scan_fun, init_carry, ts[1:])
+ return np.concatenate((y0[None], ys))
- return jax.lax.fori_loop(1,
- t.shape[0],
- functools.partial(_fori_body_fun, func),
- (t, y0, f0, t[0], dt, t[0], interp_coeff,
- jax.ops.index_update(
- np.zeros((t.shape[0], y0.shape[0])),
- jax.ops.index[0, :],
- y0)))[-1]
-
+def _odeint_fwd(func, rtol, atol, mxstep, y0, ts, *args):
+ ys = _odeint(func, rtol, atol, mxstep, y0, ts, *args)
+ return ys, (ys, ts, args)
-def vjp_odeint(ofunc, y0, t, *args, **kwargs):
- """Return a function that calculates `vjp(odeint(func(y, t, *args))`.
+def _odeint_rev(func, rtol, atol, mxstep, res, g):
+ ys, ts, args = res
- Args:
- ofunc: Function `ydot = ofunc(y, t, *args)` to compute the time
- derivative of `y`.
- y0: initial value for the state.
- t: Timespan for `ofunc` evaluation like `np.linspace(0., 10., 101)`.
- *args: Additional arguments to `ofunc` beyond y0 and t.
- **kwargs: Two relevant keyword arguments:
- 'rtol': Relative local error tolerance for solver.
- 'atol': Absolute local error tolerance for solver.
- 'mxstep': Maximum number of steps to take for each timepoint.
-
- Returns:
- VJP function `vjp = vjp_all(g)` where `yt = ofunc(y, t, *args)`
- and g is used for VJP calculation. To evaluate the gradient w/ the VJP,
- supply `g = np.ones_like(yt)`. To evaluate the reverse Jacobian do a vmap
- over the standard basis of yt.
- """
- rtol = kwargs.get('rtol', 1.4e-8)
- atol = kwargs.get('atol', 1.4e-8)
- mxstep = kwargs.get('mxstep', np.inf)
-
- flat_args, unravel_args = ravel_pytree(args)
- flat_func = lambda y, t, flat_args: ofunc(y, t, *unravel_args(flat_args))
-
- @jax.jit
- def aug_dynamics(augmented_state, t, flat_args):
+ def aug_dynamics(augmented_state, t, *args):
"""Original system augmented with vjp_y, vjp_t and vjp_args."""
- state_len = int(np.floor_divide(
- augmented_state.shape[0] - flat_args.shape[0] - 1, 2))
- y = augmented_state[:state_len]
- adjoint = augmented_state[state_len:2*state_len]
- dy_dt, vjpfun = jax.vjp(flat_func, y, t, flat_args)
- return np.hstack([np.ravel(dy_dt), np.hstack(vjpfun(-adjoint))])
-
- rev_aug_dynamics = lambda y, t, flat_args: -aug_dynamics(y, -t, flat_args)
-
- @jax.jit
- def _fori_body_fun(i, val):
- """fori_loop function for VJP calculation."""
- rev_yt, rev_t, rev_tarray, rev_gi, vjp_y, vjp_t0, vjp_args, time_vjp_list = val
- this_yt = rev_yt[i, :]
- this_t = rev_t[i]
- this_tarray = rev_tarray[i, :]
- this_gi = rev_gi[i, :]
- # this is g[i-1, :] when g has been reversed
- this_gim1 = rev_gi[i+1, :]
- state_len = this_yt.shape[0]
- vjp_cur_t = np.dot(flat_func(this_yt, this_t, flat_args), this_gi)
- vjp_t0 = vjp_t0 - vjp_cur_t
- # Run augmented system backwards to the previous observation.
- aug_y0 = np.hstack((this_yt, vjp_y, vjp_t0, vjp_args))
- aug_ans = odeint(rev_aug_dynamics,
- aug_y0,
- this_tarray,
- flat_args,
- rtol=rtol,
- atol=atol,
- mxstep=mxstep)
- vjp_y = aug_ans[1][state_len:2*state_len] + this_gim1
- vjp_t0 = aug_ans[1][2*state_len]
- vjp_args = aug_ans[1][2*state_len+1:]
- time_vjp_list = jax.ops.index_update(time_vjp_list, i, vjp_cur_t)
- return rev_yt, rev_t, rev_tarray, rev_gi, vjp_y, vjp_t0, vjp_args, time_vjp_list
-
- @jax.jit
- def vjp_all(g, yt, t):
- """Calculate the VJP g * Jac(odeint(ofunc, y0, t, *args))."""
- rev_yt = yt[-1::-1, :]
- rev_t = t[-1::-1]
- rev_tarray = -np.array([t[-1:0:-1], t[-2::-1]]).T
- rev_gi = g[-1::-1, :]
-
- vjp_y = g[-1, :]
- vjp_t0 = 0.
- vjp_args = np.zeros_like(flat_args)
- time_vjp_list = np.zeros_like(t)
-
- result = jax.lax.fori_loop(0,
- rev_t.shape[0]-1,
- _fori_body_fun,
- (rev_yt,
- rev_t,
- rev_tarray,
- rev_gi,
- vjp_y,
- vjp_t0,
- vjp_args,
- time_vjp_list))
-
- time_vjp_list = jax.ops.index_update(result[-1], -1, result[-3])
- vjp_times = np.hstack(time_vjp_list)[::-1]
- vjp_args = unravel_args(result[-2])
- return (result[-4], vjp_times, *vjp_args)
-
- primals_out = odeint(flat_func, y0, t, flat_args, rtol=rtol, atol=atol, mxstep=mxstep)
- vjp_fun = lambda g: vjp_all(g, primals_out, t)
-
- return primals_out, vjp_fun
-
-
-def build_odeint(ofunc, rtol=1.4e-8, atol=1.4e-8, mxstep=onp.inf):
- """Return `f(y0, t, args) = odeint(ofunc(y, t, *args), y0, t, args)`.
-
- Given the function ofunc(y, t, *args), return the jitted function
- `f(y0, t, args) = odeint(ofunc(y, t, *args), y0, t, args)` with
- the VJP of `f` defined using `vjp_odeint`, where:
-
- `y0` is the initial condition of the ODE integration,
- `t` is the time course of the integration, and
- `*args` are all other arguments to `ofunc`.
-
- Args:
- ofunc: The function to be wrapped into an ODE integration.
- rtol: relative local error tolerance for solver.
- atol: absolute local error tolerance for solver.
- mxstep: Maximum number of steps to take for each timepoint.
-
- Returns:
- `f(y0, t, args) = odeint(ofunc(y, t, *args), y0, t, args)`
- """
- ct_odeint = jax.custom_transforms(
- lambda y0, t, *args: odeint(ofunc, y0, t, *args, rtol=rtol, atol=atol, mxstep=mxstep))
-
- v = lambda y0, t, *args: vjp_odeint(ofunc, y0, t, *args, rtol=rtol, atol=atol, mxstep=mxstep)
- jax.defvjp_all(ct_odeint, v)
-
- return jax.jit(ct_odeint)
-
-
-def my_odeint_grad(fun):
- """Calculate the Jacobian of an odeint."""
- @jax.jit
- def _gradfun(*args, **kwargs):
- ys, pullback = vjp_odeint(fun, *args, **kwargs)
- my_grad = pullback(np.ones_like(ys))
- return my_grad
- return _gradfun
-
-
-def my_odeint_jacrev(fun):
- """Calculate the Jacobian of an odeint."""
- @jax.jit
- def _jacfun(*args, **kwargs):
- ys, pullback = vjp_odeint(fun, *args, **kwargs)
- my_jac = jax.vmap(pullback)(jax.api._std_basis(ys))
- my_jac = jax.api.tree_map(
- functools.partial(jax.api._unravel_array_into_pytree, ys, 0), my_jac)
- my_jac = jax.api.tree_transpose(
- jax.api.tree_structure(args), jax.api.tree_structure(ys), my_jac)
- return my_jac
- return _jacfun
-
-
-def nd(f, x, eps=0.0001):
- flat_x, unravel = ravel_pytree(x)
- dim = len(flat_x)
- g = onp.zeros_like(flat_x)
- for i in range(dim):
- d = onp.zeros_like(flat_x)
- d[i] = eps
- g[i] = (f(unravel(flat_x + d)) - f(unravel(flat_x - d))) / (2.0 * eps)
- return g
-
-
-def test_grad_vjp_odeint():
- """Compare numerical and exact differentiation of a simple odeint."""
-
- def f(y, t, arg1, arg2):
- return -np.sqrt(t) - y + arg1 - np.mean((y + arg2)**2)
-
- def onearg_odeint(args):
- return np.sum(
- odeint(f, *args, atol=1e-8, rtol=1e-8))
-
- dim = 10
- t0 = 0.1
- t1 = 0.2
- y0 = np.linspace(0.1, 0.9, dim)
- arg1 = 0.1
- arg2 = 0.2
- wrap_args = (y0, np.array([t0, t1]), arg1, arg2)
-
- numerical_grad = nd(onearg_odeint, wrap_args)
- exact_grad, _ = ravel_pytree(my_odeint_grad(f)(*wrap_args))
-
- assert np.allclose(numerical_grad, exact_grad)
-
-
-def plot_gradient_field(ax, func, xlimits, ylimits, numticks=30):
- """Plot the gradient field of `func` on `ax`."""
- x = np.linspace(*xlimits, num=numticks)
- y = np.linspace(*ylimits, num=numticks)
- x_mesh, y_mesh = np.meshgrid(x, y)
- zs = jax.vmap(func)(y_mesh.ravel(), x_mesh.ravel())
- z_mesh = zs.reshape(x_mesh.shape)
- ax.quiver(x_mesh, y_mesh, np.ones(z_mesh.shape), z_mesh)
- ax.set_xlim(xlimits)
- ax.set_ylim(ylimits)
-
-
-@jax.jit
-def pend(y, t, arg1, arg2):
- """Simple pendulum system for odeint testing."""
- del t
+ y, y_bar, *_ = augmented_state
+ y_dot, vjpfun = jax.vjp(func, y, -t, *args)
+ return (-y_dot, *vjpfun(y_bar))
+
+ y_bar = g[-1]
+ ts_bar = []
+ t0_bar = 0.
+
+ def scan_fun(carry, i):
+ y_bar, t0_bar, args_bar = carry
+ # Compute effect of moving measurement time
+ t_bar = np.dot(func(ys[i], ts[i], *args), g[i])
+ t0_bar = t0_bar - t_bar
+ # Run augmented system backwards to previous observation
+ _, y_bar, t0_bar, args_bar = odeint(
+ aug_dynamics, (ys[i], y_bar, t0_bar, args_bar), np.array([ts[i - 1], ts[i]]),
+ *args, rtol=rtol, atol=atol, mxstep=mxstep)
+ y_bar, t0_bar, args_bar = tree_map(op.itemgetter(1), (y_bar, t0_bar, args_bar))
+ # Add gradient from current output
+ y_bar = y_bar + g[i - 1]
+ return (y_bar, t0_bar, args_bar), t_bar
+
+ init_carry = (g[-1], 0., tree_map(np.zeros_like, args))
+ (y_bar, t0_bar, args_bar), rev_ts_bar = lax.scan(
+ scan_fun, init_carry, np.arange(len(ts) - 1, 0, -1))
+ ts_bar = np.concatenate([np.array([t0_bar]), rev_ts_bar[::-1]])
+ return (y_bar, ts_bar, *args_bar)
+
+_odeint.defvjp(_odeint_fwd, _odeint_rev)
+
+
+def pend(np, y, _, m, g):
theta, omega = y
- dydt = np.array([omega, -arg1*omega - arg2*np.sin(theta)])
- return dydt
-
-
-@jax.jit
-def swoop(y, t, arg1, arg2):
- return np.array(y - np.sin(t) - np.cos(t) * arg1 + arg2)
-
-
-@jax.jit
-def decay(y, t, arg1, arg2):
- return -np.sqrt(t) - y + arg1 - np.mean((y + arg2)**2)
-
+ return [omega, -m * omega - g * np.sin(theta)]
def benchmark_odeint(fun, y0, tspace, *args):
"""Time performance of JAX odeint method against scipy.integrate.odeint."""
- n_trials = 5
+ n_trials = 10
+ n_repeat = 100
+ y0, tspace = onp.array(y0), onp.array(tspace)
+ onp_fun = partial(fun, onp)
+ scipy_times = []
for k in range(n_trials):
start = time.time()
- scipy_result = osp_integrate.odeint(fun, y0, tspace, args)
+ for _ in range(n_repeat):
+ scipy_result = osp_integrate.odeint(onp_fun, y0, tspace, args)
end = time.time()
- print('scipy odeint elapsed time ({} of {}): {}'.format(
- k+1, n_trials, end-start))
+ print('scipy odeint elapsed time ({} of {}): {}'.format(k+1, n_trials, end-start))
+ scipy_times.append(end - start)
+ y0, tspace = np.array(y0), np.array(tspace)
+ jax_fun = partial(fun, np)
+ jax_times = []
for k in range(n_trials):
start = time.time()
- jax_result = odeint(fun, np.array(y0), np.array(tspace), *args)
+ for _ in range(n_repeat):
+ jax_result = odeint(jax_fun, y0, tspace, *args)
jax_result.block_until_ready()
end = time.time()
- print('JAX odeint elapsed time ({} of {}): {}'.format(
- k+1, n_trials, end-start))
+ print('JAX odeint elapsed time ({} of {}): {}'.format(k+1, n_trials, end-start))
+ jax_times.append(end - start)
+ print('(avg scipy time) / (avg jax time) = {}'.format(
+ onp.mean(scipy_times[1:]) / onp.mean(jax_times[1:])))
print('norm(scipy result-jax result): {}'.format(
np.linalg.norm(np.asarray(scipy_result) - jax_result)))
-
return scipy_result, jax_result
-
def pend_benchmark_odeint():
- _, _ = benchmark_odeint(pend,
- (onp.pi - 0.1, 0.0),
- onp.linspace(0., 10., 101),
- 0.25,
- 9.8)
-
-
-def test_odeint_grad():
- """Test the gradient behavior of various ODE integrations."""
- def _test_odeint_grad(func, *args):
- def onearg_odeint(fargs):
- return np.sum(odeint(func, *fargs))
-
- numerical_grad = nd(onearg_odeint, args)
- exact_grad, _ = ravel_pytree(my_odeint_grad(func)(*args))
- assert np.allclose(numerical_grad, exact_grad)
-
- ts = np.array((0.1, 0.2))
- y0 = np.linspace(0.1, 0.9, 10)
- big_y0 = np.linspace(1.1, 10.9, 10)
-
- # check pend()
- for cond in (
- (np.array((onp.pi - 0.1, 0.0)), ts, 0.25, 0.98),
- (np.array((onp.pi * 0.1, 0.0)), ts, 0.1, 0.4),
- ):
- _test_odeint_grad(pend, *cond)
-
- # check swoop
- for cond in (
- (y0, ts, 0.1, 0.2),
- (big_y0, ts, 0.1, 0.3),
- ):
- _test_odeint_grad(swoop, *cond)
-
- # check decay
- for cond in (
- (y0, ts, 0.1, 0.2),
- (big_y0, ts, 0.1, 0.3),
- ):
- _test_odeint_grad(decay, *cond)
-
-
-def test_odeint_vjp():
- """Use check_vjp to check odeint VJP calculations."""
-
- # check pend()
- y = np.array([np.pi - 0.1, 0.0])
- t = np.linspace(0., 10., 11)
- b = 0.25
- c = 9.8
- wrap_args = (y, t, b, c)
- pend_odeint_wrap = lambda y, t, *args: odeint(pend, y, t, *args)
- pend_vjp_wrap = lambda y, t, *args: vjp_odeint(pend, y, t, *args)
- check_vjp(pend_odeint_wrap, pend_vjp_wrap, wrap_args)
-
- # check swoop()
- y = np.array([0.1])
- t = np.linspace(0., 10., 11)
- arg1 = 0.1
- arg2 = 0.2
- wrap_args = (y, t, arg1, arg2)
- swoop_odeint_wrap = lambda y, t, *args: odeint(swoop, y, t, *args)
- swoop_vjp_wrap = lambda y, t, *args: vjp_odeint(swoop, y, t, *args)
- check_vjp(swoop_odeint_wrap, swoop_vjp_wrap, wrap_args)
-
- # decay() check_vjp hangs!
-
-
-def test_defvjp_all():
- """Use build_odeint to check odeint VJP calculations."""
- n_trials = 5
- swoop_build = build_odeint(swoop)
- jacswoop = jax.jit(jax.jacrev(swoop_build))
- y = np.array([0.1])
- t = np.linspace(0., 2., 11)
- arg1 = 0.1
- arg2 = 0.2
- wrap_args = (y, t, arg1, arg2)
- for k in range(n_trials):
- start = time.time()
- rslt = jacswoop(*wrap_args)
- rslt.block_until_ready()
- end = time.time()
- print('JAX jacrev elapsed time ({} of {}): {}'.format(
- k+1, n_trials, end-start))
+ _, _ = benchmark_odeint(pend, [np.pi - 0.1, 0.0], np.linspace(0., 10., 101),
+ 0.25, 9.8)
+def pend_check_grads():
+ def f(y0, ts, *args):
+ return odeint(partial(pend, np), y0, ts, *args)
-if __name__ == '__main__':
+ y0 = [np.pi - 0.1, 0.0]
+ ts = np.linspace(0., 1., 11)
+ args = (0.25, 9.8)
- test_odeint_grad()
- test_odeint_vjp()
+ check_grads(f, (y0, ts, *args), modes=["rev"], order=2,
+ atol=1e-1, rtol=1e-1)
+
+
+if __name__ == '__main__':
+ pend_benchmark_odeint()
+ pend_check_grads()
diff --git a/jax/interpreters/ad.py b/jax/interpreters/ad.py
index ef718f228fc2..687423bbd92e 100644
--- a/jax/interpreters/ad.py
+++ b/jax/interpreters/ad.py
@@ -277,8 +277,8 @@ def get_primitive_transpose(p):
return primitive_transposes[p]
except KeyError as err:
raise NotImplementedError(
- "Reverse-mode differentiation rule for '{}' not implemented".format(p)
- ) from err
+ "Transpose rule (for reverse-mode differentiation) for '{}' "
+ "not implemented".format(p)) from err
class JVPTrace(Trace):
@@ -307,15 +307,19 @@ def process_primitive(self, primitive, tracers, params):
def process_call(self, call_primitive, f: lu.WrappedFun, tracers, params):
assert call_primitive.multiple_results
- primals = [t.primal for t in tracers]
- tangents = [t.tangent for t in tracers]
- nonzero_tangents, in_tree_def = tree_flatten(tangents)
- f_jvp, out_tree_def = traceable(jvp_subtrace(f, self.master), len(primals), in_tree_def)
- name = params.get('name', f.__name__)
- params = dict(params, name=wrap_name(name, 'jvp'))
- result = call_primitive.bind(f_jvp, *(primals + nonzero_tangents), **params)
- primal_out, tangent_out = tree_unflatten(out_tree_def(), result)
- return [JVPTracer(self, p, t) for p, t in zip(primal_out, tangent_out)]
+ if call_primitive in call_jvp_rules:
+ return call_jvp_rules[call_primitive](self, call_primitive, f, tracers, params)
+ else:
+ primals = [t.primal for t in tracers]
+ tangents = [t.tangent for t in tracers]
+ nonzero_tangents, in_tree_def = tree_flatten(tangents)
+ f_jvp, out_tree_def = traceable(jvp_subtrace(f, self.master),
+ len(primals), in_tree_def)
+ name = params.get('name', f.__name__)
+ params = dict(params, name=wrap_name(name, 'jvp'))
+ result = call_primitive.bind(f_jvp, *(primals + nonzero_tangents), **params)
+ primal_out, tangent_out = tree_unflatten(out_tree_def(), result)
+ return [JVPTracer(self, p, t) for p, t in zip(primal_out, tangent_out)]
def post_process_call(self, call_primitive, out_tracers, params):
primals, tangents = unzip2((t.primal, t.tangent) for t in out_tracers)
@@ -371,7 +375,8 @@ def _primal_tangent_shapes_match(primal, tangent):
# -------------------- Primitives --------------------
-primitive_jvps: Dict[core.Primitive, Callable] = {}
+primitive_jvps : Dict[core.Primitive, Callable] = {}
+call_jvp_rules : Dict[core.Primitive, Callable] = {}
primitive_transposes: Dict[core.Primitive, Callable] = {}
@@ -430,65 +435,6 @@ def add_tangents(x, y):
return add_jaxvals(x, y)
-def defvjp_all(prim, custom_vjp):
- # see https://github.com/google/jax/pull/636
- name = prim.name
-
- def fun_jvp(xs, ts, **params):
- ts = map(instantiate_zeros, xs, ts)
- primals_and_tangents = fun_jvp_p.bind(*it.chain(xs, ts), **params)
- primals, tangents = split_list(primals_and_tangents, [len(primals_and_tangents) // 2])
- if prim.multiple_results:
- return primals, tangents
- else:
- primal, = primals
- tangent, = tangents
- return primal, tangent
- primitive_jvps[prim] = fun_jvp
-
- fun_jvp_p = core.Primitive('{name}_jvp'.format(name=name))
- fun_jvp_p.multiple_results = True
- def fun_jvp_partial_eval(trace, *tracers, **params):
- primals, tangents = split_list(tracers, [len(tracers) // 2])
- primals_out, vjp_py = custom_vjp(*primals, **params)
- if not prim.multiple_results:
- primals_out = [primals_out]
- out_avals = [raise_to_shaped(get_aval(x)) for x in primals_out]
- ct_pvals = [pe.PartialVal((aval, core.unit)) for aval in out_avals]
- jaxpr, _, res = pe.trace_to_jaxpr(lu.wrap_init(vjp_py), ct_pvals, instantiate=True)
- tangents_out = fun_lin_p.bind(*it.chain(res, tangents), trans_jaxpr=jaxpr,
- num_res=len(res), out_avals=out_avals)
- return primals_out + tangents_out
- pe.custom_partial_eval_rules[fun_jvp_p] = fun_jvp_partial_eval
-
- fun_lin_p = core.Primitive('{name}_lin'.format(name=name))
- fun_lin_p.multiple_results = True
- fun_lin_p.def_abstract_eval(lambda *_, **kwargs: kwargs['out_avals'])
- def fun_lin_transpose(cts, *args, **kwargs):
- num_res, trans_jaxpr = kwargs['num_res'], kwargs['trans_jaxpr']
- res, _ = split_list(args, [num_res])
- cts = map(instantiate_zeros_aval, kwargs['out_avals'], cts)
- outs = core.eval_jaxpr(trans_jaxpr, res, *cts)
- return [None] * num_res + outs
- primitive_transposes[fun_lin_p] = fun_lin_transpose
-
-def defvjp(prim, *vjps):
- def vjpmaker(*primals):
- ans = prim.bind(*primals)
- vjpfun = lambda ct: [vjp(ct, *primals) if vjp else zeros_like_jaxval(x)
- for x, vjp in zip(primals, vjps)]
- return ans, vjpfun
- defvjp_all(prim, vjpmaker)
-
-def defvjp2(prim, *vjps):
- def vjpmaker(*primals):
- ans = prim.bind(*primals)
- vjpfun = lambda ct: [vjp(ct, ans, *primals) if vjp else zeros_like_jaxval(x)
- for x, vjp in zip(primals, vjps)]
- return ans, vjpfun
- defvjp_all(prim, vjpmaker)
-
-
def defbilinear_broadcasting(bcast, prim, lhs_rule, rhs_rule):
assert isinstance(prim, Primitive)
lhs_jvp = lambda g, x, y, **kwargs: prim.bind(bcast(g, y), y, **kwargs)
diff --git a/jax/interpreters/batching.py b/jax/interpreters/batching.py
index befd76f07fab..83b61ae2b8b9 100644
--- a/jax/interpreters/batching.py
+++ b/jax/interpreters/batching.py
@@ -28,28 +28,49 @@
map = safe_map
-def batch(fun: lu.WrappedFun, in_vals, in_dims, out_dim_dests):
- size, = {x.shape[d] for x, d in zip(in_vals, in_dims) if d is not not_mapped}
- out_vals, out_dims = batch_fun(fun, in_vals, in_dims)
- return map(partial(matchaxis, size), out_dims, out_dim_dests(), out_vals)
-
-def batch_fun(fun: lu.WrappedFun, in_vals, in_dims):
- with new_master(BatchTrace) as master:
- fun, out_dims = batch_subtrace(fun, master, in_dims)
- out_vals = fun.call_wrapped(*in_vals)
- del master
- return out_vals, out_dims()
+def batch(fun : lu.WrappedFun, in_vals, in_dims, out_dim_dests):
+ # executes a batched version of `fun` following out_dim_dests
+ batched_fun = batch_fun(fun, in_dims, out_dim_dests)
+ return batched_fun.call_wrapped(*in_vals)
@lu.transformation_with_aux
-def batch_subtrace(master, in_dims, *in_vals):
+def batch_subtrace(master, in_dims, *in_vals, **params):
trace = BatchTrace(master, core.cur_sublevel())
in_tracers = [BatchTracer(trace, val, dim) if dim is not None else val
for val, dim in zip(in_vals, in_dims)]
- outs = yield in_tracers, {}
+ outs = yield in_tracers, params
out_tracers = map(trace.full_raise, outs)
out_vals, out_dims = unzip2((t.val, t.batch_dim) for t in out_tracers)
yield out_vals, out_dims
+def batch_fun(fun : lu.WrappedFun, in_dims, out_dim_dests):
+ # transformation version of batch, which doesn't call the function
+ fun, out_dims = batch_subtrace(fun)
+ return _batch_fun(fun, in_dims, out_dims, out_dim_dests)
+
+@lu.transformation
+def _batch_fun(in_dims, out_dims, out_dim_dests, *in_vals, **params):
+ in_dims = in_dims() if callable(in_dims) else in_dims
+ size, = {x.shape[d] for x, d in zip(in_vals, in_dims) if d is not not_mapped}
+ with new_master(BatchTrace) as master:
+ out_vals = yield (master, in_dims,) + in_vals, params
+ del master
+ out_dim_dests = out_dim_dests() if callable(out_dim_dests) else out_dim_dests
+ out_vals = map(partial(matchaxis, size), out_dims(), out_dim_dests, out_vals)
+ yield out_vals
+
+def batch_fun2(fun : lu.WrappedFun, in_dims):
+ # like `batch_fun` but returns output batch dims (so no out_dim_dests)
+ fun, out_dims = batch_subtrace(fun)
+ return _batch_fun2(fun, in_dims), out_dims
+
+@lu.transformation
+def _batch_fun2(in_dims, *in_vals, **params):
+ with new_master(BatchTrace) as master:
+ out_vals = yield (master, in_dims,) + in_vals, params
+ del master
+ yield out_vals
+
### tracer
@@ -112,17 +133,19 @@ def process_primitive(self, primitive, tracers, params):
def process_call(self, call_primitive, f: lu.WrappedFun, tracers, params):
assert call_primitive.multiple_results
- name = params.get('name', f.__name__)
- params = dict(params, name=wrap_name(name, 'vmap'))
- if call_primitive in pe.map_primitives:
+ params = dict(params, name=wrap_name(params.get('name', f.__name__), 'vmap'))
+ if call_primitive in call_batching_rules:
+ return call_batching_rules[call_primitive](self, call_primitive, f, tracers, params)
+ elif call_primitive in pe.map_primitives:
return self.process_map(call_primitive, f, tracers, params)
- vals, dims = unzip2((t.val, t.batch_dim) for t in tracers)
- if all(bdim is not_mapped for bdim in dims):
- return call_primitive.bind(f, *vals, **params)
else:
- f, dims_out = batch_subtrace(f, self.master, dims)
- vals_out = call_primitive.bind(f, *vals, **params)
- return [BatchTracer(self, v, d) for v, d in zip(vals_out, dims_out())]
+ vals, dims = unzip2((t.val, t.batch_dim) for t in tracers)
+ if all(bdim is not_mapped for bdim in dims):
+ return call_primitive.bind(f, *vals, **params)
+ else:
+ f, dims_out = batch_subtrace(f, self.master, dims)
+ vals_out = call_primitive.bind(f, *vals, **params)
+ return [BatchTracer(self, v, d) for v, d in zip(vals_out, dims_out())]
def process_map(self, map_primitive, f: lu.WrappedFun, tracers, params):
vals, dims = unzip2((t.val, t.batch_dim) for t in tracers)
@@ -151,7 +174,8 @@ def todo(x):
### primitives
BatchingRule = Callable[..., Tuple[Any, Union[int, Tuple[int, ...]]]]
-primitive_batchers: Dict[core.Primitive, BatchingRule] = {}
+primitive_batchers : Dict[core.Primitive, BatchingRule] = {}
+call_batching_rules : Dict[core.Primitive, BatchingRule] = {}
def get_primitive_batcher(p):
try:
diff --git a/jax/interpreters/partial_eval.py b/jax/interpreters/partial_eval.py
index ec5a247f11c6..b5ac96507248 100644
--- a/jax/interpreters/partial_eval.py
+++ b/jax/interpreters/partial_eval.py
@@ -124,7 +124,7 @@ def process_call(self, call_primitive, f: lu.WrappedFun, tracers, params):
name = wrap_name(name, 'pe')
params = dict(params, name=name)
if call_primitive in call_partial_eval_rules:
- return call_partial_eval_rules[call_primitive](self, f, tracers, params)
+ return call_partial_eval_rules[call_primitive](self, call_primitive, f, tracers, params)
if call_primitive in map_primitives:
return self.process_map(call_primitive, f, tracers, params)
in_pvs, in_consts = unzip2([t.pval for t in tracers])
@@ -285,7 +285,7 @@ def __init__(self, trace, pval, recipe):
assert isinstance(pval, PartialVal)
pv, const = pval
if isinstance(const, Tracer) and const._trace.level >= trace.level:
- trace.escaped_tracer_error(
+ core.escaped_tracer_error(
"Tracer from a higher level: {} in trace {}".format(const, trace))
self._trace = trace
self.pval = pval
@@ -348,10 +348,8 @@ def partial_val_aval(pv, const):
else:
raise TypeError(pv)
-
def trace_to_jaxpr(fun: lu.WrappedFun, pvals: Sequence[PartialVal],
instantiate=False, stage_out_calls=False, bottom=False):
- """Traces a function, given abstract inputs, to a jaxpr."""
trace_type = StagingJaxprTrace if stage_out_calls else JaxprTrace
with new_master(trace_type, bottom=bottom) as master:
fun = trace_to_subjaxpr(fun, master, instantiate)
@@ -454,7 +452,7 @@ def getconstvar(c):
processed_eqn_ids.add(recipe.eqn_id)
elif isinstance(recipe, LambdaBinding):
if not any(t is in_tracer for in_tracer in in_tracers):
- t._trace.escaped_tracer_error("Tracer not among input tracers {}".format(t))
+ core.escaped_tracer_error("Tracer not among input tracers {}".format(t))
assert in_tracers, "Lambda binding with no args"
elif isinstance(recipe, FreeVar):
env[getvar(t)] = recipe.val
@@ -539,7 +537,7 @@ def _split_aval(unknown, aval):
remat_call_p.def_impl(core.call_impl)
remat_call_p.multiple_results = True
-def _remat_partial_eval(trace, f, tracers, params):
+def _remat_partial_eval(trace, _, f, tracers, params):
concrete = params['concrete']
# Unlike JaxprTrace.process_call, we want to form a jaxpr for the entirety of
diff --git a/jax/lax/lax.py b/jax/lax/lax.py
index d03dea529df5..2e05b0b2b1a3 100644
--- a/jax/lax/lax.py
+++ b/jax/lax/lax.py
@@ -1853,7 +1853,9 @@ def _pow_jvp_rhs(g, ans, x, y):
ad.defjvp_zero(xor_p)
def _add_transpose(t, x, y):
- # assert ad.is_undefined_primal(x) and ad.is_undefined_primal(y) # not affine
+ # The following linearity assertion is morally true, but because in some cases we
+ # instantiate zeros for convenience, it doesn't always hold.
+ # assert ad.is_undefined_primal(x) and ad.is_undefined_primal(y)
return [t, t]
add_p = standard_naryop([_num, _num], 'add')
@@ -1862,7 +1864,9 @@ def _add_transpose(t, x, y):
def _sub_transpose(t, x, y):
- assert ad.is_undefined_primal(x) and ad.is_undefined_primal(y)
+ # The following linearity assertion is morally true, but because in some cases
+ # we instantiate zeros for convenience, it doesn't always hold.
+ # assert ad.is_undefined_primal(x) and ad.is_undefined_primal(y)
return [t, neg(t) if t is not ad_util.zero else ad_util.zero]
sub_p = standard_naryop([_num, _num], 'sub')
diff --git a/jax/lax/lax_control_flow.py b/jax/lax/lax_control_flow.py
index 3a391c00feb7..2dd657e9dfd3 100644
--- a/jax/lax/lax_control_flow.py
+++ b/jax/lax/lax_control_flow.py
@@ -26,7 +26,7 @@
import numpy as onp
-from jax import api
+import jax
from jax import core
from jax import dtypes
from jax.lax import lax
@@ -44,7 +44,8 @@
from jax.util import (partial, unzip2, safe_map, safe_zip, split_list,
split_dict, cache, extend_name_stack)
from jax.tree_util import (tree_flatten, tree_unflatten, treedef_is_leaf,
- treedef_children, treedef_tuple)
+ treedef_children, treedef_tuple, tree_leaves,
+ tree_multimap)
from jax import ad_util
_map = safe_map
@@ -78,7 +79,11 @@ def typematch(aval1, aval2):
return (raise_to_shaped(aval1).strip_weak_type() ==
raise_to_shaped(aval2).strip_weak_type())
-class FixedPointError(Exception): pass
+def _disable_jit_impl(prim, interp, *args, **kwargs):
+ if jax.api._jit_is_disabled():
+ return interp(*args, **kwargs)
+ else:
+ return xla.apply_primitive(prim, *args, **kwargs)
### fori_loop and while_loop
@@ -210,6 +215,12 @@ def while_loop(cond_fun, body_fun, init_val):
Returns:
The output from the final iteration of body_fun, of type ``a``.
"""
+ if jax.api._jit_is_disabled():
+ val = init_val
+ while cond_fun(val):
+ val = body_fun(val)
+ return val
+
init_vals, in_tree = tree_flatten((init_val,))
init_avals = tuple(_map(_abstractify, init_vals))
cond_jaxpr, cond_consts, cond_tree = _initial_style_jaxpr(cond_fun, in_tree, init_avals)
@@ -430,6 +441,13 @@ def cond(pred, true_operand, true_fun, false_operand, false_fun):
else:
msg = ("Pred type must be either boolean or number, got {}.")
raise TypeError(msg.format(pred_dtype))
+
+ if jax.api._jit_is_disabled():
+ if pred:
+ return true_fun(true_operand)
+ else:
+ return false_fun(false_operand)
+
true_ops, true_tree = tree_flatten((true_operand,))
true_avals = tuple(_map(_abstractify, true_ops))
true_jaxpr, true_consts, true_out_tree = _initial_style_jaxpr(true_fun, true_tree, true_avals)
@@ -783,7 +801,7 @@ def scan(f, init, xs, length=None):
the second output of ``f`` when scanned over the leading axis of the inputs.
"""
init_flat, init_tree = tree_flatten(init)
- xs_flat, _ = tree_flatten(xs)
+ xs_flat, xs_tree = tree_flatten(xs)
in_flat, in_tree = tree_flatten((init, xs))
try:
@@ -811,6 +829,17 @@ def scan(f, init, xs, length=None):
else:
length, = unique_lengths
+ if jax.api._jit_is_disabled():
+ carry = init
+ ys = []
+ for i in range(length):
+ xs_slice = [_index_array(i, core.get_aval(x), x) for x in xs_flat]
+ carry, y = f(carry, tree_unflatten(xs_tree, xs_slice))
+ ys.append(y)
+ stack = lambda y, *ys: (y if core.get_aval(y) is core.abstract_unit
+ else jax.numpy.stack((y, *ys)))
+ return carry, tree_multimap(stack, *ys)
+
carry_avals = tuple(_map(_abstractify, init_flat))
x_shapes = [masking.padded_shape_as_value(x.shape[1:]) for x in xs_flat]
x_dtypes = [x.dtype for x in xs_flat]
@@ -1363,7 +1392,7 @@ def custom_root(f, initial_guess, solve, tangent_solve):
_check_tree("solve", "initial_guess", solution_tree, in_tree)
def linearize_and_solve(x, b):
- unchecked_zeros, f_jvp = api.linearize(f, x)
+ unchecked_zeros, f_jvp = jax.linearize(f, x)
return tangent_solve(f_jvp, b)
l_and_s_jaxpr, l_and_s_consts, out_tree = _initial_style_jaxpr(
@@ -1445,7 +1474,7 @@ def _transpose_function(linear_fun, primals):
# TODO(shoyer): can we use something more direct than the vjp machinery?
# It's particularly awkward that we need the second argument to give
# particular values of the primals, which are entirely arbitrary.
- _, vjp_fun = api.vjp(linear_fun, primals)
+ _, vjp_fun = jax.vjp(linear_fun, primals)
def transposed_fun(x):
(y,) = vjp_fun(x)
diff --git a/jax/linear_util.py b/jax/linear_util.py
index 534eebafbbcc..6373201a9d6d 100644
--- a/jax/linear_util.py
+++ b/jax/linear_util.py
@@ -81,7 +81,8 @@ def __init__(self):
self._val = _EMPTY_STORE_VALUE
def store(self, val):
- assert self._val is _EMPTY_STORE_VALUE, "Store occupied"
+ if self._val is not _EMPTY_STORE_VALUE:
+ raise StoreException("Store occupied")
self._val = val
@property
diff --git a/jax/nn/functions.py b/jax/nn/functions.py
index 051d7f6d1b52..0d34409a8b09 100644
--- a/jax/nn/functions.py
+++ b/jax/nn/functions.py
@@ -17,15 +17,15 @@
import numpy as onp
+from jax import custom_jvp
from jax import dtypes
-from jax import custom_transforms, defjvp
from jax import lax
from jax.scipy.special import expit
import jax.numpy as np
# activations
-@custom_transforms
+@custom_jvp
def relu(x):
r"""Rectified linear unit activation function.
@@ -35,7 +35,11 @@ def relu(x):
\mathrm{relu}(x) = \max(x, 0)
"""
return np.maximum(x, 0)
-defjvp(relu, lambda g, ans, x: lax.select(x > 0, g, lax.full_like(g, 0)))
+def _relu_jvp(primals, tangents):
+ x, = primals
+ t, = tangents
+ return relu(x), lax.select(x > 0, t, lax.full_like(t, 0))
+relu.defjvp(_relu_jvp)
def softplus(x):
r"""Softplus activation function.
diff --git a/jax/numpy/lax_numpy.py b/jax/numpy/lax_numpy.py
index 9856b3fb3397..0443b2d27647 100644
--- a/jax/numpy/lax_numpy.py
+++ b/jax/numpy/lax_numpy.py
@@ -38,7 +38,7 @@
import numpy as onp
import opt_einsum
-from jax import jit, device_put, custom_transforms, defjvp
+from jax import jit, device_put
from .. import core
from .. import dtypes
from ..abstract_arrays import UnshapedArray, ShapedArray, ConcreteArray
@@ -423,6 +423,7 @@ def fn(x1, x2):
arctan = _one_to_one_unop(onp.arctan, lax.atan, True)
sinh = _one_to_one_unop(onp.sinh, lax.sinh, True)
cosh = _one_to_one_unop(onp.cosh, lax.cosh, True)
+arcsinh = _one_to_one_unop(onp.arcsinh, lax.asinh, True)
tanh = _one_to_one_unop(onp.tanh, lax.tanh, True)
arcsinh = _one_to_one_unop(onp.arcsinh, lax.asinh, True)
arccosh = _one_to_one_unop(onp.arccosh, lax.acosh, True)
diff --git a/jax/numpy/linalg.py b/jax/numpy/linalg.py
index 512adb35af6a..728ee83ea1e6 100644
--- a/jax/numpy/linalg.py
+++ b/jax/numpy/linalg.py
@@ -20,15 +20,15 @@
import operator
from typing import Tuple, Union, cast
-from jax import jit, ops, vmap
+from jax import jit, vmap, custom_jvp
from .. import lax
+from .. import ops
from .. import lax_linalg
from .. import dtypes
from .lax_numpy import _not_implemented
from .lax_numpy import _wraps
from .vectorize import vectorize
from . import lax_numpy as np
-from ..api import custom_transforms, defjvp
from ..util import get_module_functions
from ..third_party.numpy.linalg import cond, tensorinv, tensorsolve
@@ -110,15 +110,8 @@ def matrix_rank(M, tol=None):
return np.sum(S > tol)
-# TODO(pfau): make this work for complex types
-def _jvp_slogdet(g, ans, x):
- jvp_sign = np.zeros(x.shape[:-2])
- jvp_logdet = np.trace(solve(x, g), axis1=-1, axis2=-2)
- return jvp_sign, jvp_logdet
-
-
+@custom_jvp
@_wraps(onp.linalg.slogdet)
-@custom_transforms
@jit
def slogdet(a):
a = _promote_arg_dtypes(np.asarray(a))
@@ -143,7 +136,15 @@ def slogdet(a):
is_zero, np.array(-np.inf, dtype=dtype),
np.sum(np.log(np.abs(diag)), axis=-1))
return sign, np.real(logdet)
-defjvp(slogdet, _jvp_slogdet)
+def _slogdet_jvp(primals, tangents):
+ x, = primals
+ g, = tangents
+ if np.issubdtype(np._dtype(x), np.complexfloating):
+ raise NotImplementedError # TODO(pfau): make this work for complex types
+ sign, ans = slogdet(x)
+ sign_dot, ans_dot = np.zeros_like(sign), np.trace(solve(x, g), axis1=-1, axis2=-2)
+ return (sign, ans), (sign_dot, ans_dot)
+slogdet.defjvp(_slogdet_jvp)
@_wraps(onp.linalg.det)
diff --git a/jax/random.py b/jax/random.py
index 7c61801d0282..3af4b54c72ef 100644
--- a/jax/random.py
+++ b/jax/random.py
@@ -30,7 +30,7 @@
from . import numpy as np
from . import tree_util
from . import dtypes
-from .api import custom_transforms, defjvp, jit, vmap
+from .api import jit, vmap
from .numpy.lax_numpy import _constant_like, asarray, stack
from jax.lib import xla_bridge
from jax.lib import cuda_prng
diff --git a/jax/scipy/special.py b/jax/scipy/special.py
index a443e153c6bb..cd29c118f5f0 100644
--- a/jax/scipy/special.py
+++ b/jax/scipy/special.py
@@ -12,13 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from functools import partial
import numpy as np
import scipy.special as osp_special
-from .. import lax
from .. import util
-from ..api import custom_transforms, defjvp
+from .. import lax
+from .. import api
from ..numpy import lax_numpy as jnp
from ..numpy.lax_numpy import (_wraps, asarray, _reduction_dims, _constant_like,
_promote_args_inexact)
@@ -78,21 +79,26 @@ def erfinv(x):
return lax.erf_inv(x)
-@_wraps(osp_special.logit, update_doc=False)
-@custom_transforms
+@api.custom_jvp
def logit(x):
- x = asarray(x)
return lax.log(lax.div(x, lax.sub(lax._const(x, 1), x)))
-defjvp(logit, lambda g, ans, x: g / (x * (1 - x)))
+def _logit_jvp(primals, tangents):
+ (x,), (t,) = primals, tangents
+ ans = logit(x)
+ t_out = lax.div(lax.mul(x, lax.sub(lax._const(x, 1), x)))
+ return ans, t_out
+logit.defjvp(_logit_jvp)
-@_wraps(osp_special.expit, update_doc=False)
-@custom_transforms
+@api.custom_jvp
def expit(x):
- x = asarray(x)
- one = lax._const(x, 1)
- return lax.div(one, lax.add(one, lax.exp(lax.neg(x))))
-defjvp(expit, lambda g, ans, x: g * ans * (lax._const(ans, 1) - ans))
+ return 1 / (1 + lax.exp(-x))
+def _expit_jvp(primals, tangents):
+ (x,), (t,) = primals, tangents
+ ans = expit(x)
+ t_out = t * ans * (1 - ans)
+ return ans, t_out
+expit.defjvp(_expit_jvp)
@_wraps(osp_special.logsumexp)
@@ -405,7 +411,7 @@ def _create_polynomial(var, coeffs):
return x_nan_replaced
-@custom_transforms
+@partial(api.custom_jvp, nondiff_argnums=(1,))
def log_ndtr(x, series_order=3):
r"""Log Normal distribution function.
@@ -507,6 +513,12 @@ def log_ndtr(x, series_order=3):
_log_ndtr_lower(lax.min(x, lower_segment),
series_order)))
+def _log_ndtr_jvp(series_order, primals, tangents):
+ (x,), (t,) = primals, tangents
+ ans = log_ndtr(x, series_order=series_order)
+ t_out = lax.mul(t, lax.exp(lax.sub(_norm_logpdf(x), ans)))
+ return ans, t_out
+log_ndtr.defjvp(_log_ndtr_jvp)
def _log_ndtr_lower(x, series_order):
"""Asymptotic expansion version of `Log[cdf(x)]`, appropriate for `x<<-1`."""
@@ -548,9 +560,6 @@ def _norm_logpdf(x):
log_normalizer = _constant_like(x, _norm_logpdf_constant)
return lax.sub(lax.mul(neg_half, lax.square(x)), log_normalizer)
-defjvp(log_ndtr,
- lambda g, ans, x: lax.mul(g, lax.exp(lax.sub(_norm_logpdf(x), ans))))
-
@_wraps(osp_special.i0e)
def i0e(x):
return lax.bessel_i0e(x)
diff --git a/tests/api_test.py b/tests/api_test.py
index 22cc102f4b74..64492a2b70ce 100644
--- a/tests/api_test.py
+++ b/tests/api_test.py
@@ -271,7 +271,7 @@ def foo(x):
ad.defjvp(foo_p, lambda g, x: foo(g))
jtu.check_raises(lambda: grad(foo)(1.0), NotImplementedError,
- "Reverse-mode differentiation rule for 'foo' not implemented")
+ "Transpose rule (for reverse-mode differentiation) for 'foo' not implemented")
def test_device_put_and_get(self):
x = onp.arange(12.).reshape((3, 4)).astype("float32")
@@ -540,39 +540,6 @@ def test_vjp_mismatched_arguments(self):
"Type of cotangent input to vjp pullback.*does not match type",
lambda: pullback((onp.float16(42))))
- def test_jarrett_jvps(self):
- def f1(x):
- return np.sin(np.sin(np.sin(x)))
- f2 = api.jarrett(f1)
-
- for x in [3., onp.array([2., 3., 4.])]:
- self.assertAllClose(f1(x), f2(x), check_dtypes=True)
-
- _, f1_vjp = api.vjp(f1, x)
- _, f2_vjp = api.vjp(f2, x)
- self.assertAllClose(f1_vjp(x), f2_vjp(x), check_dtypes=True)
-
- # TODO(mattjj): test that constants/literals are set up properly
- # jaxpr2 = api.make_jaxpr(f2_vjp)(x)
- # assert len(jaxpr2.constvars) == 1
-
- def test_jarrett_jvps2(self):
- def f1(x, y):
- return np.sin(x) * np.cos(y) * np.sin(x) * np.cos(y)
- f2 = api.jarrett(f1)
-
- # TODO(mattjj): doesn't work for (3., onp.array([4., 5.]))
- for x, y in [(3., 4.), (onp.array([5., 6.]), onp.array([7., 8.]))]:
- self.assertAllClose(f1(x, y), f2(x, y), check_dtypes=True)
-
- _, f1_vjp = api.vjp(f1, x, y)
- _, f2_vjp = api.vjp(f2, x, y)
- self.assertAllClose(f1_vjp(y), f2_vjp(y), check_dtypes=True)
-
- # TODO(mattjj): test that constants/literals are set up properly
- # jaxpr2 = api.make_jaxpr(f2_vjp)(y)
- # assert len(jaxpr2.constvars) == 2
-
def test_jvp_jit_cached(self):
"""Bug in caching in presence of JVP and JIT."""
@@ -629,199 +596,6 @@ def f(z):
def test_complex_input_jacfwd_raises_error(self):
self.assertRaises(TypeError, lambda: jacfwd(lambda x: np.sin(x))(1 + 2j))
- def test_defvjp_all(self):
- foo_p = Primitive('foo')
- def foo(x): return 2. * foo_p.bind(x)
-
- ad.defvjp_all(foo_p, lambda x: (x**2, lambda g: (4 * g * np.sin(x),)))
- val_ans, grad_ans = api.value_and_grad(foo)(3.)
- self.assertAllClose(val_ans, 2 * 3.**2, check_dtypes=False)
- self.assertAllClose(grad_ans, 4 * 2 * onp.sin(3.), check_dtypes=False)
-
- def test_defvjp_all_const(self):
- foo_p = Primitive('foo')
- def foo(x): return foo_p.bind(x)
-
- ad.defvjp_all(foo_p, lambda x: (x**2, lambda g: (12.,)))
- val_ans, grad_ans = api.value_and_grad(foo)(3.)
- self.assertAllClose(val_ans, 9., check_dtypes=False)
- self.assertAllClose(grad_ans, 12., check_dtypes=True)
-
- def test_defvjp_all_higher_order_revmode(self):
- foo_p = Primitive('foo')
- def foo(x): return 2. * foo_p.bind(x)
-
- ad.defvjp_all(foo_p, lambda x: (x**2, lambda g: (g * x ** 2,)))
- ans = api.grad(api.grad(foo))(3.)
- self.assertAllClose(ans, 2 * 2 * 3., check_dtypes=False)
-
- def test_defvjp_all_multiple_arguments(self):
- # also tests passing in symbolic zero tangents b/c we differentiate wrt only
- # the first argument in one case
-
- foo_p = Primitive('foo')
- def foo(x, y): return foo_p.bind(x, y)
-
- def vjpfun(x, y):
- out = x**2 + y**3
- vjp = lambda g: (g + x + y, g * x * 9.)
- return out, vjp
-
- ad.defvjp_all(foo_p, vjpfun)
- val_ans, grad_ans = api.value_and_grad(foo)(3., 4.)
- self.assertAllClose(val_ans, 3.**2 + 4.**3, check_dtypes=False)
- self.assertAllClose(grad_ans, 1. + 3. + 4., check_dtypes=False)
-
- ans = api.grad(foo, (0, 1))(3., 4.)
- self.assertAllClose(ans, (1. + 3. + 4., 1. * 3. * 9.), check_dtypes=False)
-
- def test_defvjp_all_custom_transforms(self):
- @api.custom_transforms
- def foo(x):
- return np.sin(x)
-
- api.defvjp_all(foo, lambda x: (np.sin(x), lambda g: (g * x,)))
- val_ans, grad_ans = api.value_and_grad(foo)(3.)
- self.assertAllClose(val_ans, onp.sin(3.), check_dtypes=False)
- self.assertAllClose(grad_ans, 3., check_dtypes=False)
-
- # TODO(mattjj): add defvjp_all test with pytree arguments
-
- def test_defvjp(self):
- @api.custom_transforms
- def foo(x, y):
- return np.sin(x * y)
-
- api.defvjp(foo, None, lambda g, _, x, y: g * x * y)
- val_ans, grad_ans = api.value_and_grad(foo)(3., 4.)
- self.assertAllClose(val_ans, onp.sin(3. * 4.), check_dtypes=False)
- self.assertAllClose(grad_ans, 0., check_dtypes=False)
-
- ans_0, ans_1 = api.grad(foo, (0, 1))(3., 4.)
- self.assertAllClose(ans_0, 0., check_dtypes=False)
- self.assertAllClose(ans_1, 3. * 4., check_dtypes=False)
-
- def test_defvjp_higher_order(self):
- @api.custom_transforms
- def foo(x):
- return np.sin(2. * x)
-
- api.defvjp(foo, lambda g, _, x: g * np.cos(x))
- ans = api.grad(api.grad(foo))(2.)
- expected = api.grad(api.grad(np.sin))(2.)
- self.assertAllClose(ans, expected, check_dtypes=False)
-
- def test_defvjp_use_ans(self):
- @api.custom_transforms
- def foo(x, y):
- return np.sin(x * y)
-
- api.defvjp(foo, None, lambda g, ans, x, y: g * x * y + np.cos(ans))
- val_ans, grad_ans = api.value_and_grad(foo, 1)(3., 4.)
- self.assertAllClose(val_ans, onp.sin(3. * 4.), check_dtypes=False)
- self.assertAllClose(grad_ans, 3. * 4. + onp.cos(onp.sin(3. * 4)),
- check_dtypes=False)
-
- # TODO
- # def test_defjvp_closure_error(self):
- # def foo(x):
- # @api.custom_transforms
- # def bar(y):
- # return x * y
-
- # api.defjvp(bar, lambda y_dot, ans, y: x * y)
- # return bar(x)
- # jtu.check_raises(
- # lambda: api.jvp(foo, (1.,), (1.,)), ValueError,
- # "Detected differentiation with respect to closed-over values with "
- # "custom JVP rule, which isn't supported.")
-
- # TODO
- # def test_defvjp_closure_error(self):
- # def foo(x):
- # @api.custom_transforms
- # def bar(y):
- # return x * y
-
- # api.defvjp(bar, lambda g, ans, y: x * y)
- # return bar(x)
- # jtu.check_raises(
- # lambda: grad(foo)(1.,), ValueError,
- # "Detected differentiation w.r.t. variables from outside "
- # "the scope of , but defvjp and "
- # "defvjp_all only support differentiation w.r.t. positional arguments.")
-
- def test_custom_transforms_eval_with_pytrees(self):
- @api.custom_transforms
- def f(x):
- a, b = x[0], x[1]
- return {'hi': 2 * a, 'bye': 2 * b}
-
- ans = f((1, 2))
- self.assertEqual(ans, {'hi': 2 * 1, 'bye': 2 * 2})
-
- def test_custom_transforms_jit_with_pytrees(self):
- @api.custom_transforms
- def f(x):
- a, b = x[0], x[1]
- return {'hi': 2 * a, 'bye': 2 * b}
-
- ans = jit(f)((1, 2))
- self.assertEqual(ans, {'hi': 2 * 1, 'bye': 2 * 2})
-
- def test_custom_transforms_jit_with_pytrees_consts(self):
- # The purpose of this test is to exercise the custom_transforms default
- # translation rule in how it deals with constants that are too large to be
- # treated as literals (at the time of writing).
- z = onp.arange(10.)
-
- @api.custom_transforms
- def f(x):
- a, b = x[0], x[1]
- return {'hi': 2 * a, 'bye': z * b}
-
- ans = jit(f)((1, 2))
- self.assertAllClose(ans, {'hi': 2 * 1, 'bye': z * 2}, check_dtypes=False)
-
- def test_custom_transforms_jvp_with_pytrees(self):
- @api.custom_transforms
- def f(x):
- a, b = x[0], x[1]
- return {'hi': 2 * a, 'bye': 2 * b}
-
- ans, out_tangent = api.jvp(f, ((1, 2),), ((3, 4),))
- self.assertEqual(ans, {'hi': 2 * 1, 'bye': 2 * 2})
- self.assertEqual(out_tangent, {'hi': 2 * 3, 'bye': 2 * 4})
-
- def test_custom_transforms_vmap_with_pytrees(self):
- @api.custom_transforms
- def f(x):
- a, b = x[0], x[1]
- return {'hi': 2 * a, 'bye': 2 * b}
-
- ans = api.vmap(f)((onp.arange(3), onp.ones((3, 2))))
- expected = {'hi': 2 * onp.arange(3), 'bye': 2 * onp.ones((3, 2))}
- self.assertAllClose(ans, expected, check_dtypes=False)
-
- def test_custom_transforms_jvp_with_closure(self):
- def f(x):
- @api.custom_transforms
- def g(y):
- return x * y
- return g(x)
-
- ans = api.grad(f)(1.)
- expected = 2.
- self.assertAllClose(ans, expected, check_dtypes=False)
-
- def test_custom_gradient(self):
- @api.custom_gradient
- def f(x):
- return x ** 2, lambda g: (g * x,)
-
- self.assertAllClose(f(3.), 9., check_dtypes=False)
- self.assertAllClose(api.grad(f)(3.), 3., check_dtypes=False)
-
def test_legacy_devicearray_repr(self):
dx = device_put(3.)
str(dx.item()) # doesn't crash
@@ -1182,55 +956,6 @@ def check_warning(warn, nowarn):
check_warning(lambda: np.tri(2, dtype="float64"),
lambda: np.tri(2, dtype="float32"))
- def test_custom_vjp_zeros(self):
- @api.custom_transforms
- def f(x, y):
- return 2 * x, 3 * y
-
- def f_vjp(x, y):
- return (2 * x, 3 * y), lambda ts: (4 * ts[0], 5 * ts[1])
-
- api.defvjp_all(f, f_vjp, )
- api.grad(lambda x, y: f(x, y)[0])(1., 2.) # doesn't crash
-
- def test_custom_transforms_vjp_nones(self):
- # issue rasied by jsnoek@ and jumper@
- @jax.custom_transforms
- def solve(a, b):
- return np.dot(np.linalg.inv(a), b)
- # print(solve(a, b))
-
- def solve_vjp(a, b):
- x = solve(a, b)
- def vjp(x_tangent):
- dx = np.dot(solve(a, x_tangent), x.T)
- out = (dx, b * 0.)
- return out
- return x, vjp
- jax.defvjp_all(solve, solve_vjp)
- gf = grad(lambda a,b: np.sum(solve(a, b)))
-
- n = 3
- a_in = np.linspace(0, 1, n)[:, None]
- a = np.dot(a_in, a_in.T) + np.eye(n) * 0.1
- real_x = onp.random.RandomState(0).randn(n)
- b = np.dot(a + np.eye(a.shape[0]), real_x)
- print(gf(a, b)) # doesn't crash
-
- def test_vmap_in_axes_list(self):
- # https://github.com/google/jax/issues/2367
- dictionary = {'a': 5., 'b': np.ones(2)}
- x = np.zeros(3)
- y = np.arange(3.)
-
-
- def f(dct, x, y):
- return dct['a'] + dct['b'] + x + y
-
- out1 = api.vmap(f, (None, 0, 0))(dictionary, x, y)
- out2 = api.vmap(f, [None, 0, 0])(dictionary, x, y)
- self.assertAllClose(out1, out2, check_dtypes=True)
-
def test_vmap_in_axes_tree_prefix_error(self):
# https://github.com/google/jax/issues/795
self.assertRaisesRegex(
@@ -1772,7 +1497,7 @@ def helper_save_tracer(self, x):
def test_escaped_tracers_diffent_top_level_traces(self):
api.jit(self.helper_save_tracer)(0.)
with self.assertRaisesRegex(
- ValueError,
+ core.UnexpectedTracerError,
re.compile(
"Encountered an unexpected tracer.*Different traces at same level",
re.DOTALL)):
@@ -1781,7 +1506,7 @@ def test_escaped_tracers_diffent_top_level_traces(self):
def test_escaped_tracers_cant_lift_sublevels(self):
api.jit(self.helper_save_tracer)(0.)
with self.assertRaisesRegex(
- ValueError,
+ core.UnexpectedTracerError,
re.compile(
"Encountered an unexpected tracer.*Can't lift sublevels 1 to 0",
re.DOTALL)):
@@ -1790,7 +1515,7 @@ def test_escaped_tracers_cant_lift_sublevels(self):
def test_escaped_tracers_tracer_from_higher_level(self):
api.grad(self.helper_save_tracer)(0.)
with self.assertRaisesRegex(
- ValueError,
+ core.UnexpectedTracerError,
re.compile(
"Encountered an unexpected tracer.*Tracer from a higher level",
re.DOTALL)):
@@ -1802,7 +1527,7 @@ def func1(x):
# Use the tracer
return x + self._saved_tracer
with self.assertRaisesRegex(
- ValueError,
+ core.UnexpectedTracerError,
re.compile("Encountered an unexpected tracer.*Incompatible sublevel",
re.DOTALL)):
api.jit(func1)(2.)
@@ -1812,8 +1537,9 @@ def func1(x):
api.grad(self.helper_save_tracer)(0.)
return x + self._saved_tracer
with self.assertRaisesRegex(
- ValueError, re.compile("Encountered an unexpected tracer.*Can't lift",
- re.DOTALL)):
+ core.UnexpectedTracerError,
+ re.compile("Encountered an unexpected tracer.*Can't lift",
+ re.DOTALL)):
api.grad(func1)(2.)
def test_escaped_tracers_not_among_input_tracers(self):
@@ -1823,7 +1549,8 @@ def func1(x):
return x + self._saved_tracer
with self.assertRaisesRegex(
- ValueError, re.compile(
+ core.UnexpectedTracerError,
+ re.compile(
"Encountered an unexpected tracer.*Tracer not among input tracers",
re.DOTALL)):
api.jit(func1)(2.)
@@ -1870,7 +1597,7 @@ def f(x):
""", str(jaxpr))
def testExamplesJaxprDoc(self):
- """Tests examples included in the Understanding JAXPRs doc (docs/jaxpr.rst)."""
+ """Tests examples included in the Understanding jaxprs doc (docs/jaxpr.rst)."""
from jax import numpy as jnp
def func1(first, second):
temp = first + jnp.sin(second) * 3.
@@ -2221,6 +1948,607 @@ def test_zeros_ones_compilation(self):
self.assertAllClose(x, onp.ones(3), check_dtypes=False)
self.assertAllClose(y, onp.ones(3) + onp.ones(3), check_dtypes=False)
+class CustomJVPTest(jtu.JaxTestCase):
+
+ def test_basic(self):
+ @api.custom_jvp
+ def f(x):
+ return np.sin(x)
+ def f_jvp(primals, tangents):
+ x, = primals
+ g, = tangents
+ return f(x), 2 * np.cos(x) * g
+ f.defjvp(f_jvp)
+
+ x = 3.
+ self.assertAllClose(f(x), np.sin(x), check_dtypes=True)
+ self.assertAllClose(api.jvp(f, (x,), (1.,)),
+ (np.sin(x), 2 * np.cos(x)),
+ check_dtypes=True)
+ self.assertAllClose(api.grad(f)(x), 2 * np.cos(x), check_dtypes=True)
+
+ def test_invariance(self):
+ @api.custom_jvp
+ def f(x):
+ return np.cos(2 * x) / 2.
+ def f_jvp(primals, tangents):
+ x, = primals
+ g, = tangents
+ return (f(x), 3 * g)
+ f.defjvp(f_jvp)
+ def f2(x):
+ y, _ = api.jvp(f, (x,), (x,))
+ return y
+ def f3(x):
+ y, _ = api.jvp(f2, (x,), (x,))
+ return y
+ x = 1.
+ self.assertAllClose(api.jvp(f, (x,), (x,)),
+ api.jvp(f2, (x,), (x,)),
+ check_dtypes=False)
+ self.assertAllClose(api.jvp(f, (x,), (x,)),
+ api.jvp(f3, (x,), (x,)),
+ check_dtypes=False)
+
+ def test_python_control_flow(self):
+ @api.custom_jvp
+ def f(x):
+ if x > 0:
+ return np.sin(x)
+ else:
+ return np.cos(x)
+ def f_jvp(primals, tangents):
+ x, = primals
+ g, = tangents
+ if x > 0:
+ return f(x), 2 * g
+ else:
+ return f(x), 3 * g
+ f.defjvp(f_jvp)
+ x = 2.
+ self.assertAllClose(f(x), np.sin(x), check_dtypes=True)
+ self.assertAllClose(f(-x), np.cos(-x), check_dtypes=True)
+ self.assertAllClose(api.jvp(f, (x,), (1.,)),
+ (np.sin(x), 2.),
+ check_dtypes=False)
+ self.assertAllClose(api.jvp(f, (-x,), (1.,)),
+ (np.cos(-x), 3.),
+ check_dtypes=False)
+ self.assertAllClose(api.grad(f)(x), 2., check_dtypes=False)
+ self.assertAllClose(api.grad(f)(-x), 3., check_dtypes=False)
+
+ def test_vmap(self):
+ @api.custom_jvp
+ def f(x):
+ assert np.ndim(x) == 0
+ return np.sin(x)
+ def f_jvp(primals, tangents):
+ x, = primals
+ g, = tangents
+ assert np.ndim(x) == np.ndim(g) == 0
+ return f(x), 2 * np.cos(x) * g
+ f.defjvp(f_jvp)
+
+ x = np.arange(3.)
+ xx = np.arange(6.).reshape(2, 3)
+
+ # vmap of f
+ self.assertAllClose(api.vmap(f)(x), np.sin(x), check_dtypes=True)
+ self.assertAllClose(api.vmap(api.vmap(f))(xx), np.sin(xx), check_dtypes=True)
+
+ # vmap of jvp of f
+ self.assertAllClose(api.vmap(lambda x: api.jvp(f, (x,), (x,)))(x),
+ (np.sin(x), 2 * np.cos(x) * x),
+ check_dtypes=True)
+ self.assertAllClose(api.vmap(api.vmap(lambda x: api.jvp(f, (x,), (x,))))(xx),
+ (np.sin(xx), 2 * np.cos(xx) * xx),
+ check_dtypes=True)
+
+ # jvp of vmap of f
+ self.assertAllClose(api.jvp(api.vmap(f), (x,), (x,)),
+ (np.sin(x), 2 * np.cos(x) * x),
+ check_dtypes=True)
+ self.assertAllClose(api.jvp(api.vmap(api.vmap(f)), (xx,), (xx,)),
+ (np.sin(xx), 2 * np.cos(xx) * xx),
+ check_dtypes=True)
+
+ # vmap of jvp of vmap of f
+ self.assertAllClose(api.vmap(lambda x: api.jvp(api.vmap(f), (x,), (x,)))(xx),
+ (np.sin(xx), 2 * np.cos(xx) * xx),
+ check_dtypes=True)
+
+ def test_jit(self):
+ @api.custom_jvp
+ def f(x):
+ return np.sin(x)
+ def f_jvp(primals, tangents):
+ x, = primals
+ g, = tangents
+ return f(x), 2 * np.cos(x) * g
+ f.defjvp(f_jvp)
+
+ x = 3.
+
+ # jit
+ self.assertAllClose(api.jit(f)(x), np.sin(x), check_dtypes=True)
+ self.assertAllClose(api.jit(api.jit(f))(x), np.sin(x), check_dtypes=True)
+
+ # jit of jvp
+ self.assertAllClose(api.jit(lambda x: api.jvp(f, (x,), (x,)))(x),
+ (np.sin(x), 2 * np.cos(x) * x),
+ check_dtypes=False)
+
+ # jvp of jit
+ self.assertAllClose(api.jvp(api.jit(f), (x,), (x,)),
+ (np.sin(x), 2 * np.cos(x) * x),
+ check_dtypes=False)
+
+ def test_pytrees(self):
+ @api.custom_jvp
+ def f(x):
+ return {'b': np.sin(x['a'])}
+ def f_jvp(primals, tangents):
+ x, = primals
+ g, = tangents
+ return f(x), {'b': 2 * np.cos(x['a']) * g['a']}
+ f.defjvp(f_jvp)
+ x = {'a': 3.}
+ self.assertAllClose(f(x)['b'], np.sin(x['a']), check_dtypes=True)
+ self.assertAllClose(api.jvp(f, (x,), (x,)),
+ ({'b': np.sin(x['a'])},
+ {'b': 2 * np.cos(x['a']) * x['a']}),
+ check_dtypes=False)
+
+ def test_kwargs(self):
+ # from https://github.com/google/jax/issues/1938
+ @api.custom_jvp
+ def my_fun(x, y, c=1.):
+ return c * (x + y)
+ def my_jvp(primals, tangents):
+ x, y, c = primals
+ t_x, t_y, t_c = tangents
+ return my_fun(x, y, c), t_c
+ my_fun.defjvp(my_jvp)
+ f = lambda x, y: np.square(my_fun(x, y, c=2.)).sum()
+ f(10., 5.) # doesn't crash
+ api.jvp(f, (10., 5.), (1., 1.)) # doesn't crash
+
+ def test_initial_style(self):
+ @api.custom_jvp
+ def f(x):
+ return 3 * x
+ def f_jvp(primals, tangents):
+ x, = primals
+ g, = tangents
+ return f(x), 2 * g
+ f.defjvp(f_jvp)
+
+ def foo(x):
+ out, _ = lax.scan(lambda c, _: (f(c), None), x, None, length=1)
+ return out
+
+ ans = api.grad(foo)(3.)
+ expected = 2.
+ self.assertAllClose(ans, expected, check_dtypes=False)
+
+ ans = api.grad(api.grad(foo))(3.)
+ expected = 0.
+ self.assertAllClose(ans, expected, check_dtypes=False)
+
+ def test_initial_style_vmap(self):
+ @api.custom_jvp
+ def f(x):
+ assert np.ndim(x) == 0
+ return 3 * x
+ def f_jvp(primals, tangents):
+ x, = primals
+ g, = tangents
+ return f(x), 2 * g
+ f.defjvp(f_jvp)
+
+ def foo(x):
+ out, _ = lax.scan(lambda c, _: (f(c), None), x, None, length=1)
+ return out
+
+ ans = api.vmap(foo)(np.ones(3))
+ expected = 3. * np.ones(3)
+ self.assertAllClose(ans, expected, check_dtypes=False)
+
+ ans = api.grad(lambda x: api.vmap(foo)(x).sum())(np.ones(3))
+ expected = 2. * np.ones(3)
+ self.assertAllClose(ans, expected, check_dtypes=False)
+
+ def test_closed_over_tracers_error_message(self):
+ def f(x):
+ @api.custom_jvp
+ def g(y):
+ return x + y
+ def g_jvp(primals, tangents):
+ (y,), (t,) = primals, tangents
+ return g(x), 2 * y
+ g.defjvp(g_jvp)
+ return g(1.)
+
+ self.assertRaises(
+ core.UnexpectedTracerError, lambda: api.jvp(f, (3.,), (1.,)))
+ self.assertRaises(
+ core.UnexpectedTracerError, lambda: api.grad(f)(3.))
+
+ def test_nondiff_arg(self):
+ @partial(api.custom_jvp, nondiff_argnums=(0,))
+ def app(f, x):
+ return f(x)
+ def app_jvp(f, primals, tangents):
+ (x,), (t,) = primals, tangents
+ return app(f, x), 3 * t
+ app.defjvp(app_jvp)
+
+ ans = app(lambda x: 2 * x, 1)
+ expected = 2
+ self.assertAllClose(ans, expected, check_dtypes=False)
+
+ ans = api.jvp(lambda x: app(lambda y: 2 * y, x), (1.,), (1.,))
+ expected = (2., 3.)
+ self.assertAllClose(ans, expected, check_dtypes=False)
+
+ def test_nondiff_arg_tracer(self):
+ @partial(api.custom_jvp, nondiff_argnums=(0,))
+ def f(x, y):
+ return x * y
+ def f_jvp(x, primals, tangents):
+ (y,), (t_y,) = primals, tangents
+ return f(x, y), 5 * t_y
+ f.defjvp(f_jvp)
+
+ @jit
+ def g(x, y):
+ return f(x, y)
+
+ ans = api.jvp(lambda y: g(2., y), (3.,), (1.,))
+ expected = (6., 5.)
+ self.assertAllClose(ans, expected, check_dtypes=False)
+
+ def test_vmap_axes(self):
+ raise unittest.SkipTest("TODO") # TODO(mattjj): write test
+
+ def test_pmap(self):
+ raise unittest.SkipTest("TODO") # TODO(mattjj): write test
+
+ def test_missing_jvp_rule_error(self):
+ @api.custom_jvp
+ def foo(x):
+ return x ** 2
+
+ self.assertRaisesRegex(
+ AttributeError,
+ r"No JVP defined for custom_jvp function foo using defjvp.",
+ lambda: foo(2))
+ self.assertRaisesRegex(
+ AttributeError,
+ r"No JVP defined for custom_jvp function foo using defjvp.",
+ lambda: api.jvp(foo, (2.,), (1.,)))
+ self.assertRaisesRegex(
+ AttributeError,
+ r"No JVP defined for custom_jvp function foo using defjvp.",
+ lambda: api.grad(foo)(2.))
+
+ def test_jvp_rule_inconsistent_pytree_structures_error(self):
+ @api.custom_jvp
+ def f(x):
+ return (x**2,)
+
+ @f.defjvp
+ def foo_jvp(primals, tangents):
+ x, = primals
+ t, = tangents
+ return f(x), [2 * x * t, x]
+
+ f(2.) # doesn't crash
+ self.assertRaisesRegex(
+ TypeError,
+ re.escape(
+ "Custom JVP rule must produce primal and tangent outputs "
+ "with equal container (pytree) structures, but got "
+ "{} and {}.".format(
+ tree_util.tree_structure((1,)),
+ tree_util.tree_structure([1, 2]))
+ ),
+ lambda: api.jvp(f, (2.,), (1.,)))
+
+
+class CustomVJPTest(jtu.JaxTestCase):
+
+ def test_basic(self):
+ @api.custom_vjp
+ def f(x):
+ return np.sin(x)
+ def f_fwd(x):
+ return f(x), np.cos(x)
+ def f_rev(cos_x, g):
+ return (2 * cos_x * g,)
+ f.defvjp(f_fwd, f_rev)
+
+ x = 3.
+ self.assertAllClose(f(x), np.sin(x), check_dtypes=True)
+ self.assertAllClose(api.grad(f)(x), 2 * np.cos(x), check_dtypes=True)
+ self.assertAllClose(api.value_and_grad(f)(x),
+ (np.sin(x), 2 * np.cos(x)),
+ check_dtypes=True)
+
+ def test_invariance(self):
+ @api.custom_vjp
+ def f(x):
+ return np.cos(2 * x) / 2.
+ def f_fwd(x):
+ return (f(x), x)
+ def f_rev(x, g):
+ return (g * 3,)
+ f.defvjp(f_fwd, f_rev)
+ def f2(x):
+ y, _ = api.value_and_grad(f)(x)
+ return y
+ def f3(x):
+ y, _ = api.value_and_grad(f2)(x)
+ return y
+ x = 1.
+ self.assertAllClose(f(x), f2(x), check_dtypes=False)
+ self.assertAllClose(f(x), f3(x), check_dtypes=False)
+ self.assertAllClose(api.grad(f)(x), api.grad(f2)(x),
+ check_dtypes=False)
+ self.assertAllClose(api.grad(f)(x), api.grad(f3)(x),
+ check_dtypes=False)
+
+ def test_python_control_flow(self):
+ @api.custom_vjp
+ def f(x):
+ if x > 0:
+ return np.sin(x)
+ else:
+ return np.cos(x)
+ def f_fwd(x):
+ if x > 0:
+ return f(x), x
+ else:
+ return f(x), x
+ def f_rev(x, g):
+ if x > 0:
+ return (2 * g,)
+ else:
+ return (3 * g,)
+ f.defvjp(f_fwd, f_rev)
+
+ x = 2.
+ self.assertAllClose(f(x), np.sin(x), check_dtypes=True)
+ self.assertAllClose(f(-x), np.cos(-x), check_dtypes=True)
+ self.assertAllClose(api.value_and_grad(f)(x), (np.sin(x), 2.),
+ check_dtypes=False)
+ self.assertAllClose(api.value_and_grad(f)(-x), (np.cos(-x), 3.),
+ check_dtypes=False)
+
+ def test_vmap(self):
+ @api.custom_vjp
+ def f(x):
+ assert np.ndim(x) == 0
+ return np.sin(x)
+ def f_fwd(x):
+ assert np.ndim(x) == 0
+ return f(x), np.cos(x)
+ def f_rev(cos_x, g):
+ return (2 * cos_x * g,)
+ f.defvjp(f_fwd, f_rev)
+
+ x = np.arange(3.)
+ xx = np.arange(6.).reshape(2, 3)
+
+ # vmap of f
+ self.assertAllClose(api.vmap(f)(x), np.sin(x), check_dtypes=True)
+ self.assertAllClose(api.vmap(api.vmap(f))(xx), np.sin(xx), check_dtypes=True)
+
+ # vmap of grad of f
+ self.assertAllClose(api.vmap(api.grad(f))(x), 2 * np.cos(x),
+ check_dtypes=True)
+ self.assertAllClose(api.vmap(api.value_and_grad(f))(x),
+ (np.sin(x), 2 * np.cos(x)),
+ check_dtypes=True)
+ self.assertAllClose(api.vmap(api.vmap(api.grad(f)))(xx), 2 * np.cos(xx),
+ check_dtypes=True)
+ self.assertAllClose(api.vmap(api.vmap(api.value_and_grad(f)))(xx),
+ (np.sin(xx), 2 * np.cos(xx)),
+ check_dtypes=True)
+
+ # grad of vmap of f
+ self.assertAllClose(api.grad(lambda x: api.vmap(f)(x).sum())(x),
+ 2 * np.cos(x),
+ check_dtypes=True)
+ self.assertAllClose(api.grad(lambda x: api.vmap(api.vmap(f))(x).sum())(xx),
+ 2 * np.cos(xx),
+ check_dtypes=True)
+
+ # vmap of grad of vmap of f
+ self.assertAllClose(api.vmap(api.grad(lambda x: api.vmap(f)(x).sum()))(xx),
+ 2 * np.cos(xx),
+ check_dtypes=True)
+
+ def test_jit(self):
+ @api.custom_vjp
+ def f(x):
+ return np.sin(x)
+ def f_fwd(x):
+ return f(x), np.cos(x)
+ def f_rev(cos_x, g):
+ return (2 * cos_x * g,)
+ f.defvjp(f_fwd, f_rev)
+
+ x = 3.
+
+ # jit
+ self.assertAllClose(api.jit(f)(x), np.sin(x), check_dtypes=True)
+ self.assertAllClose(api.jit(api.jit(f))(x), np.sin(x), check_dtypes=True)
+
+ # jit of grad
+ self.assertAllClose(api.jit(api.grad(f))(x), 2 * np.cos(x),
+ check_dtypes=False)
+
+ # grad of jit
+ self.assertAllClose(api.grad(api.jit(f))(x), 2 * np.cos(x),
+ check_dtypes=False)
+
+ def test_pytrees(self):
+ @api.custom_vjp
+ def f(x):
+ return {'b': np.sin(x['a'])}
+ def f_fwd(x):
+ return f(x), {'r': np.cos(x['a'])}
+ def f_bwd(res, g):
+ cos_x = res['r']
+ return ({'a': 2 * cos_x * g['b']},)
+ f.defvjp(f_fwd, f_bwd)
+ x = {'a': 3.}
+ self.assertAllClose(f(x)['b'], np.sin(x['a']), check_dtypes=True)
+ self.assertAllClose(api.grad(lambda x: f(x)['b'])(x),
+ {'a': 2 * np.cos(x['a'])},
+ check_dtypes=True)
+
+ def test_jvp_error(self):
+ @api.custom_vjp
+ def f(x):
+ return np.sin(x)
+ def f_fwd(x):
+ return f(x), np.cos(x)
+ def f_rev(cos_x, g):
+ return (2 * cos_x * g,)
+ f.defvjp(f_fwd, f_rev)
+
+ self.assertRaisesRegex(
+ TypeError,
+ r"can't apply forward-mode autodiff \(jvp\) to a custom_vjp function.",
+ lambda: api.jvp(f, (3.,), (1.,)))
+ self.assertRaisesRegex(
+ TypeError,
+ r"can't apply forward-mode autodiff \(jvp\) to a custom_vjp function.",
+ lambda: api.jvp(api.vmap(f), (np.arange(3.),), (np.ones(3),)))
+
+ def test_kwargs(self):
+ # from https://github.com/google/jax/issues/1938
+ @api.custom_vjp
+ def my_fun(x, y, c=1.):
+ return c * (x + y)
+ my_fun.defvjp(lambda x, y, c=1.: (my_fun(c, y, c), None),
+ lambda _, g: (g, g, g))
+ f = lambda x, y: np.square(my_fun(x, y, c=2.)).sum()
+ f(10., 5.) # doesn't crash
+ api.grad(f)(10., 5.) # doesn't crash
+
+ def test_initial_style(self):
+ @api.custom_vjp
+ def f(x):
+ return np.sin(x)
+ def f_fwd(x):
+ return f(x), np.cos(x)
+ def f_rev(cos_x, g):
+ return (2 * cos_x * g,)
+ f.defvjp(f_fwd, f_rev)
+
+ def foo(x):
+ out, _ = lax.scan(lambda c, _: (f(c), None), x, None, length=1)
+ return out
+
+ ans = api.grad(foo)(3.)
+ expected = 2. * np.cos(3.)
+ self.assertAllClose(ans, expected, check_dtypes=False)
+
+ ans = api.grad(api.grad(foo))(3.)
+ expected = -2. * np.sin(3.)
+ self.assertAllClose(ans, expected, check_dtypes=True)
+
+ def test_initial_style_vmap(self):
+ @api.custom_vjp
+ def f(x):
+ assert np.ndim(x) == 0
+ return 3 * x
+ def f_fwd(x):
+ return f(x), np.cos(x)
+ def f_rev(cos_x, g):
+ return (2 * cos_x * g,)
+ f.defvjp(f_fwd, f_rev)
+
+ def foo(x):
+ out, _ = lax.scan(lambda c, _: (f(c), None), x, None, length=1)
+ return out
+
+ ans = api.vmap(foo)(np.arange(3.))
+ expected = 3. * np.arange(3.)
+ self.assertAllClose(ans, expected, check_dtypes=False)
+
+ ans = api.grad(lambda x: api.vmap(foo)(x).sum())(np.arange(3.))
+ expected = 2. * np.cos(np.arange(3.))
+ self.assertAllClose(ans, expected, check_dtypes=False)
+
+ def test_nondiff_arg(self):
+ @partial(api.custom_vjp, nondiff_argnums=(0,))
+ def app(f, x):
+ return f(x)
+ def app_fwd(f, x):
+ return app(f, x), np.cos(x)
+ def app_rev(f, cos_x, g):
+ return (cos_x * g,)
+ app.defvjp(app_fwd, app_rev)
+
+ ans = app(lambda x: 2 * x, 1)
+ expected = 2
+ self.assertAllClose(ans, expected, check_dtypes=False)
+
+ ans = api.value_and_grad(lambda x: app(lambda y: 2 * y, x))(1.)
+ expected = (2., np.cos(1.))
+ self.assertAllClose(ans, expected, check_dtypes=False)
+
+ def test_vmap_axes(self):
+ raise unittest.SkipTest("TODO") # TODO(mattjj): write test
+
+ def test_pmap(self):
+ raise unittest.SkipTest("TODO") # TODO(mattjj): write test
+
+ def test_missing_vjp_rule_error(self):
+ @api.custom_vjp
+ def foo(x):
+ return x ** 2
+
+ self.assertRaisesRegex(
+ AttributeError,
+ r"No VJP defined for custom_vjp function foo using defvjp.",
+ lambda: foo(2))
+ self.assertRaisesRegex(
+ AttributeError,
+ r"No VJP defined for custom_vjp function foo using defvjp.",
+ lambda: api.grad(foo)(2.))
+
+ def test_vjp_rule_inconsistent_pytree_structures_error(self):
+ @api.custom_vjp
+ def f(x):
+ return x
+
+ def foo_fwd(x):
+ return x, None
+
+ def foo_bwd(_, g):
+ return g
+
+ f.defvjp(foo_fwd, foo_bwd)
+
+ f(2) # doesn't crash
+ self.assertRaisesRegex(
+ TypeError,
+ re.escape(
+ "Custom VJP rule must produce an output with the same container "
+ "(pytree) structure as the args tuple of the primal function, "
+ "and in particular must produce a tuple of length equal to the "
+ "number of arguments to the primal function, but got VJP output "
+ "structure {} for primal input structure {}.".format(
+ tree_util.tree_structure(1),
+ tree_util.tree_structure((1,)))
+ ),
+ lambda: api.grad(f)(2.))
+
if __name__ == '__main__':
absltest.main()