diff --git a/chex/_src/asserts_test.py b/chex/_src/asserts_test.py index 74eb019..83c67c8 100644 --- a/chex/_src/asserts_test.py +++ b/chex/_src/asserts_test.py @@ -1103,7 +1103,7 @@ def test_assert_tree_has_only_ndarrays(self): asserts.assert_tree_has_only_ndarrays({'a': jnp.zeros(101), 'b': [1, 2]}) def test_assert_tree_is_on_host(self): - cpu = jax.devices('cpu')[0] + cpu = jax.local_devices(backend='cpu')[0] # Check Numpy arrays. for flag in (False, True): @@ -1173,7 +1173,7 @@ def test_assert_tree_is_on_host(self): def test_assert_tree_is_on_device(self): # Check CPU platform. - cpu = jax.devices('cpu')[0] + cpu = jax.local_devices(backend='cpu')[0] to_cpu = lambda x: jax.device_put(x, cpu) cpu_tree = {'a': to_cpu(np.zeros(1)), 'b': to_cpu(np.ones(3))} @@ -1262,7 +1262,7 @@ def _format(*devs): return re.escape(f'{devs}') # Check single-device case. - cpu = jax.devices('cpu')[0] + cpu = jax.local_devices(backend='cpu')[0] cpu_tree = jax.device_put_replicated(np_tree, (cpu,)) asserts.assert_tree_is_sharded(cpu_tree, devices=(cpu,))