diff --git a/jax/api.py b/jax/api.py index f50187e527bf..23f21196392d 100644 --- a/jax/api.py +++ b/jax/api.py @@ -909,6 +909,10 @@ def pmap(fun: Callable, axis_name: Optional[AxisName] = None, *, in_axes=0, >>> out = pmap(lambda x: x ** 2)(np.arange(8)) >>> print(out) [0, 1, 4, 9, 16, 25, 36, 49] + + When the leading dimension is smaller than the number of available devices JAX + will simply run on a subset of devices: + >>> x = np.arange(3 * 2 * 2.).reshape((3, 2, 2)) >>> y = np.arange(3 * 2 * 2.).reshape((3, 2, 2)) ** 2 >>> out = pmap(np.dot)(x, y) @@ -920,6 +924,12 @@ def pmap(fun: Callable, axis_name: Optional[AxisName] = None, *, in_axes=0, [[ 1412. 1737.] [ 1740. 2141.]]] + If your leading dimension is larger than the number of available devices you + will get an error: + + >>> pmap(lambda x: x ** 2)(np.arange(9)) + ValueError: ... requires 9 replicas, but only 8 XLA devices are available + As with ``vmap``, using ``None`` in ``in_axes`` indicates that an argument doesn't have an extra axis and should be broadcasted, rather than mapped, across the replicas: