Skip to content

Commit

Permalink
add custom_jvp / vjp, delete custom_transforms
Browse files Browse the repository at this point in the history
  • Loading branch information
mattjj committed Feb 12, 2020
1 parent 5e77789 commit 0864ece
Show file tree
Hide file tree
Showing 15 changed files with 749 additions and 1,010 deletions.
5 changes: 0 additions & 5 deletions docs/jax.rst
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,6 @@ Automatic differentiation
.. autofunction:: jvp
.. autofunction:: linearize
.. autofunction:: vjp
.. autofunction:: custom_transforms
.. autofunction:: defjvp
.. autofunction:: defjvp_all
.. autofunction:: defvjp
.. autofunction:: defvjp_all
.. autofunction:: custom_gradient


Expand Down
2 changes: 1 addition & 1 deletion docs/notebooks/How_JAX_primitives_work.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1503,7 +1503,7 @@
" File \"/usr/local/lib/python3.6/dist-packages/jax/api.py\", line 611, in batched_fun\n",
" lambda: _flatten_axes(out_tree(), out_axes))\n",
" File \"/usr/local/lib/python3.6/dist-packages/jax/interpreters/batching.py\", line 41, in batch\n",
" out_vals, out_dims = batch_fun(fun, in_vals, in_dims)\n",
" out_vals, out_dims = batch2(fun, in_vals, in_dims)\n",
"NotImplementedError: Batching rule for 'multiply_add' not implemented\n"
],
"name": "stderr"
Expand Down
Loading

0 comments on commit 0864ece

Please sign in to comment.