Skip to content

Commit

Permalink
Add a note about jax.pmap when leading dim is smaller than num devices.
Browse files Browse the repository at this point in the history
  • Loading branch information
tomhennigan committed May 4, 2020
1 parent 9802d73 commit c121792
Showing 1 changed file with 4 additions and 0 deletions.
4 changes: 4 additions & 0 deletions jax/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -909,6 +909,10 @@ def pmap(fun: Callable, axis_name: Optional[AxisName] = None, *, in_axes=0,
>>> out = pmap(lambda x: x ** 2)(np.arange(8))
>>> print(out)
[0, 1, 4, 9, 16, 25, 36, 49]
When the leading dimension is smaller than the number of available devices JAX
will simply run on a subset of devices:
>>> x = np.arange(3 * 2 * 2.).reshape((3, 2, 2))
>>> y = np.arange(3 * 2 * 2.).reshape((3, 2, 2)) ** 2
>>> out = pmap(np.dot)(x, y)
Expand Down

0 comments on commit c121792

Please sign in to comment.