Skip to content

Commit

Permalink
Add a note about jax.pmap when leading dim is smaller than num device…
Browse files Browse the repository at this point in the history
…s. (#2949)
  • Loading branch information
tomhennigan authored May 4, 2020
1 parent c9c653a commit 4c2c5ad
Showing 1 changed file with 10 additions and 0 deletions.
10 changes: 10 additions & 0 deletions jax/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -909,6 +909,10 @@ def pmap(fun: Callable, axis_name: Optional[AxisName] = None, *, in_axes=0,
>>> out = pmap(lambda x: x ** 2)(np.arange(8))
>>> print(out)
[0, 1, 4, 9, 16, 25, 36, 49]
When the leading dimension is smaller than the number of available devices JAX
will simply run on a subset of devices:
>>> x = np.arange(3 * 2 * 2.).reshape((3, 2, 2))
>>> y = np.arange(3 * 2 * 2.).reshape((3, 2, 2)) ** 2
>>> out = pmap(np.dot)(x, y)
Expand All @@ -920,6 +924,12 @@ def pmap(fun: Callable, axis_name: Optional[AxisName] = None, *, in_axes=0,
[[ 1412. 1737.]
[ 1740. 2141.]]]
If your leading dimension is larger than the number of available devices you
will get an error:
>>> pmap(lambda x: x ** 2)(np.arange(9))
ValueError: ... requires 9 replicas, but only 8 XLA devices are available
As with ``vmap``, using ``None`` in ``in_axes`` indicates that an argument
doesn't have an extra axis and should be broadcasted, rather than mapped,
across the replicas:
Expand Down

0 comments on commit 4c2c5ad

Please sign in to comment.