Skip to content

Commit

Permalink
Fixes in the FAQ for RST (#2761)
Browse files Browse the repository at this point in the history
  • Loading branch information
gnecula authored Apr 19, 2020
1 parent a1b4fe4 commit 3ca7f6e
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 19 deletions.
10 changes: 8 additions & 2 deletions docs/CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@ These are the release notes for JAX.
jax 0.1.64 (unreleased)
---------------------------

* `GitHub commits <https://github.com/google/jax/compare/jax-v0.1.63...master>`_.
* Improves error message for reverse-mode differentiation of :func:`lax.while_loop`
`#2129 <https://github.com/google/jax/issues/2129>`_.

jaxlib 0.1.45 (unreleased)
------------------------------

Expand All @@ -29,6 +33,7 @@ jaxlib 0.1.44 (April 16, 2020)
jax 0.1.63
---------------------------

* `GitHub commits <https://github.com/google/jax/compare/jax-v0.1.62...jax-v0.1.63>`_.
* Added ``jax.custom_jvp`` and ``jax.custom_vjp`` from `#2026 <https://github.com/google/jax/pull/2026>`_, see the `tutorial notebook <https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html>`_. Deprecated ``jax.custom_transforms`` and removed it from the docs (though it still works).
* Add ``scipy.sparse.linalg.cg`` `#2566 <https://github.com/google/jax/pull/2566>`_.
* Changed how Tracers are printed to show more useful information for debugging `#2591 <https://github.com/google/jax/pull/2591>`_.
Expand All @@ -51,6 +56,7 @@ jaxlib 0.1.43 (March 31, 2020)
jax 0.1.62 (March 21, 2020)
---------------------------

* `GitHub commits <https://github.com/google/jax/compare/jax-v0.1.61...jax-v0.1.62>`_.
* JAX has dropped support for Python 3.5. Please upgrade to Python 3.6 or newer.
* Removed the internal function ``lax._safe_mul``, which implemented the
convention ``0. * nan == 0.``. This change means some programs when
Expand All @@ -69,14 +75,14 @@ jaxlib 0.1.42 (March 19, 2020)

jax 0.1.61 (March 17, 2020)
---------------------------

* `GitHub commits <https://github.com/google/jax/compare/jax-v0.1.60...jax-v0.1.61>`_.
* Fixes Python 3.5 support. This will be the last JAX or jaxlib release that
supports Python 3.5.

jax 0.1.60 (March 17, 2020)
---------------------------

* `GitHub commits <https://github.com/google/jax/compare/jax-v0.1.59...master>`_.
* `GitHub commits <https://github.com/google/jax/compare/jax-v0.1.59...jax-v0.1.60>`_.
* New features:

* :py:func:`jax.pmap` has ``static_broadcast_argnums`` argument which allows
Expand Down
4 changes: 3 additions & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,9 @@
'notebooks/score_matching.ipynb',
'notebooks/maml.ipynb',
# Fails with shape error in XL
'notebooks/XLA_in_Python.ipynb'
'notebooks/XLA_in_Python.ipynb',
# Sometimes sphinx reads its own outputs as inputs!
'build/html',
]

# The name of the Pygments (syntax highlighting) style to use.
Expand Down
32 changes: 16 additions & 16 deletions docs/faq.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ JAX's NumPy::
import numpy as np
np.array([0] * int(1e6))

The reason is that in NumPy the `numpy.array` function is implemented in C, while
the `jax.numpy.array` is implemented in Python, and it needs to iterate over a long
The reason is that in NumPy the ``numpy.array`` function is implemented in C, while
the :func:`jax.numpy.array` is implemented in Python, and it needs to iterate over a long
list to convert each list element to an array element.

An alternative would be to create the array with original NumPy and then convert
Expand All @@ -23,13 +23,12 @@ it to a JAX array::
from jax import numpy as jnp
jnp.array(np.array([0] * int(1e6)))


`jit` changes the behavior of my function
-----------------------------------------

If you have a Python function that changes behavior after using `jit`, perhaps
If you have a Python function that changes behavior after using :func:`jax.jit`, perhaps
your function uses global state, or has side-effects. In the following code, the
`impure_func` uses the global `y` and has a side-effect due to `print`::
``impure_func`` uses the global ``y`` and has a side-effect due to ``print``::

y = 0

Expand All @@ -41,7 +40,7 @@ your function uses global state, or has side-effects. In the following code, the
for y in range(3):
print("Result:", impure_func(y))

Without `jit` the output is::
Without ``jit`` the output is::

Inside: 0
Result: 0
Expand All @@ -50,27 +49,28 @@ Without `jit` the output is::
Inside: 2
Result: 4

and with `jit` it is:
and with ``jit`` it is::

Inside: 0
Result: 0
Result: 1
Result: 2

For `jit` the function is executed once using the Python interpreter, at which time the
`Inside` printing happens, and the first value of `y` is observed. Then the function
is compiled and cached, and executed multiple times with different values of `x`, but
with the same first value of `y`.
For :func:`jax.jit`, the function is executed once using the Python interpreter, at which time the
``Inside`` printing happens, and the first value of ``y`` is observed. Then the function
is compiled and cached, and executed multiple times with different values of ``x``, but
with the same first value of ``y``.

Additional reading:

* [JAX - The Sharp Bits: Pure Functions](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#%F0%9F%94%AA-Pure-functions)
* `JAX - The Sharp Bits: Pure Functions <https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#%F0%9F%94%AA-Pure-functions>`_.


Gradients contain `NaN` where using ``where``
------------------------------------------------

If you define a function using ``where`` to avoid an undefined value, if you
are not careful you may obtain a `NaN` for reverse differentiation::
are not careful you may obtain a ``NaN`` for reverse differentiation::

def my_log(x):
return np.where(x > 0., np.log(x), 0.)
Expand All @@ -90,7 +90,7 @@ that the adjoint is always finite::
safe_for_grad_log(0.) ==> 0. # Ok
jax.grad(safe_for_grad_log)(0.) ==> 0. # Ok

The inner ``np.where`` may be needed in addition to the original one, e.g.:
The inner ``np.where`` may be needed in addition to the original one, e.g.::

def my_log_or_y(x, y):
"""Return log(x) if x > 0 or y"""
Expand All @@ -99,5 +99,5 @@ The inner ``np.where`` may be needed in addition to the original one, e.g.:

Additional reading:

* [Issue: gradients through np.where when one of branches is nan](https://github.com/google/jax/issues/1052#issuecomment-514083352)
* [How to avoid NaN gradients when using ``where``](https://github.com/tensorflow/probability/blob/master/discussion/where-nan.pdf)
* `Issue: gradients through np.where when one of branches is nan <https://github.com/google/jax/issues/1052#issuecomment-514083352>`_.
* `How to avoid NaN gradients when using where <https://github.com/tensorflow/probability/blob/master/discussion/where-nan.pdf>`_.

0 comments on commit 3ca7f6e

Please sign in to comment.