Skip to content

Commit 2e11fd5

Browse files
hawkinspChexDev
authored andcommitted
[JAX] Replace uses of jax.devices("cpu") with jax.local_devices(backend="cpu").
An upcoming change to JAX will include non-local (addressable) CPU devices in jax.devices() when JAX is used multicontroller-style, where there are multiple Python processes. This change preserves the current behavior by replacing uses of jax.devices("cpu"), which previously only returned local devices, with jax.local_devices("cpu"), which will return local devices both now and in the future. This change is always be safe (i.e., it should always preserve the previous behavior) but it may sometimes be unnecessary if code is never used in a multicontroller setting. PiperOrigin-RevId: 582427762
1 parent cd1cebb commit 2e11fd5

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

chex/_src/asserts_test.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1103,7 +1103,7 @@ def test_assert_tree_has_only_ndarrays(self):
11031103
asserts.assert_tree_has_only_ndarrays({'a': jnp.zeros(101), 'b': [1, 2]})
11041104

11051105
def test_assert_tree_is_on_host(self):
1106-
cpu = jax.devices('cpu')[0]
1106+
cpu = jax.local_devices(backend='cpu')[0]
11071107

11081108
# Check Numpy arrays.
11091109
for flag in (False, True):
@@ -1173,7 +1173,7 @@ def test_assert_tree_is_on_host(self):
11731173

11741174
def test_assert_tree_is_on_device(self):
11751175
# Check CPU platform.
1176-
cpu = jax.devices('cpu')[0]
1176+
cpu = jax.local_devices(backend='cpu')[0]
11771177
to_cpu = lambda x: jax.device_put(x, cpu)
11781178

11791179
cpu_tree = {'a': to_cpu(np.zeros(1)), 'b': to_cpu(np.ones(3))}
@@ -1262,7 +1262,7 @@ def _format(*devs):
12621262
return re.escape(f'{devs}')
12631263

12641264
# Check single-device case.
1265-
cpu = jax.devices('cpu')[0]
1265+
cpu = jax.local_devices(backend='cpu')[0]
12661266
cpu_tree = jax.device_put_replicated(np_tree, (cpu,))
12671267

12681268
asserts.assert_tree_is_sharded(cpu_tree, devices=(cpu,))

0 commit comments

Comments
 (0)