Fix a bug where full did not work with use_mesh outside of jit: the shard passed to make_array_from_callback was sharded across all devices instead of being placed on a single device.
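The sketch below illustrates the contract the fix relies on. It assumes only the public jax.make_array_from_callback API. The helpers full_sharded and _shard_shape are hypothetical and are not the code this change touches. The callback returns host data sized to exactly one shard, so each shard ends up on one device instead of arriving as an array that is already sharded over every device.

```python
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P


def _shard_shape(index, shape):
    # Hypothetical helper: index is a tuple of slices that selects one
    # shard of the global array; compute that shard's shape.
    return tuple(len(range(*s.indices(d))) for s, d in zip(index, shape))


def full_sharded(shape, fill_value, sharding, dtype=np.float32):
    # Hypothetical helper, not JAX's own full implementation.
    def data_callback(index):
        # Return a host (NumPy) array holding only this shard's data, so JAX
        # places it on the single device that owns the shard rather than
        # receiving an array that is already sharded over all devices.
        return np.full(_shard_shape(index, shape), fill_value, dtype=dtype)

    return jax.make_array_from_callback(shape, sharding, data_callback)


devices = np.array(jax.devices())
mesh = Mesh(devices, ("x",))
sharding = NamedSharding(mesh, P("x"))
shape = (2 * devices.size, 3)  # first axis divides evenly across the mesh
arr = full_sharded(shape, 7.0, sharding)
```

Each entry in arr.addressable_shards then holds a single-device array, which is the property the bug violated.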