Skip to content

Commit

Permalink
Allow setting default_device with platform names.
Browse files Browse the repository at this point in the history
  • Loading branch information
Stella-S-Yan committed Nov 7, 2024
1 parent 37af100 commit a869033
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 11 deletions.
7 changes: 1 addition & 6 deletions jax/_src/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1555,13 +1555,11 @@ def _update_x64_thread_local(val):
def _update_default_device_global(val):
lib.jax_jit.global_state().default_device = val


def _update_default_device_thread_local(val):
lib.jax_jit.thread_local_state().default_device = val


def _validate_default_device(val):
if val is not None and not isinstance(val, xla_client.Device):
if val is not None and not isinstance(val, xla_client.Device) and val not in ['cpu', 'gpu', 'tpu']:
# TODO(skyewm): this is a workaround for non-PJRT Device types. Remove when
# all JAX backends use a single C++ device interface.
if 'Device' in str(type(val)):
Expand All @@ -1572,9 +1570,6 @@ def _validate_default_device(val):
raise ValueError('jax.default_device must be passed a Device object (e.g. '
f"`jax.devices('cpu')[0]`), got: {val!r}")


# TODO(skye): default_device only accepts devices for now. Make it work with
# platform names as well (e.g. "cpu" to mean the same as jax.devices("cpu")[0]).
default_device = string_or_object_state(
name='jax_default_device',
default=None,
Expand Down
7 changes: 6 additions & 1 deletion jax/_src/interpreters/pxla.py
Original file line number Diff line number Diff line change
Expand Up @@ -1683,7 +1683,10 @@ class DeviceAssignmentMismatchError(Exception):


def _get_default_device() -> xc.Device:
return config.default_device.value or xb.local_devices()[0]
if isinstance(config.default_device.value, str):
return xb.get_backend(config.default_device.value).devices()[0]
else:
return config.default_device.value or xb.local_devices()[0]


def _get_and_check_device_assignment(
Expand Down Expand Up @@ -1715,6 +1718,7 @@ def _get_and_check_device_assignment(
raise DeviceAssignmentMismatchError([
DeviceAssignmentMismatch(devices, MismatchType.CONTEXT_DEVICES, None),
DeviceAssignmentMismatch(arr_device_assignment, s_type, source_info)])

if first_sharding_info is None and devices:
final_device_assignment = devices
elif first_sharding_info is None:
Expand Down Expand Up @@ -2163,6 +2167,7 @@ def lower_sharding_computation(
assert len(out_shardings) == len(out_layouts) == len(global_out_avals), (
len(out_shardings), len(out_layouts), len(global_out_avals))


devices_from_context = (None if context_mesh is None or context_mesh.empty
else context_mesh._flat_devices_tuple)
# Device assignment across all inputs, outputs and shardings inside jaxpr
Expand Down
11 changes: 7 additions & 4 deletions tests/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,11 +289,14 @@ def test_jit_default_device(self, module):

# TODO(skye): make this work!
def test_jit_default_platform(self):
with self.assertRaisesWithLiteralMatch(
ValueError, "jax.default_device must be passed a Device object "
"(e.g. `jax.devices('cpu')[0]`), got: 'cpu'"):

with jax.default_device("cpu"):
jax.jit(lambda x: x + 1)(1)
result = jax.jit(lambda x: x + 1)(1)
self.assertEqual(result.device.platform, "cpu")

result = jax.jit(lambda x: x + 1)(1)
self.assertEqual(result.device.platform, jax.default_backend())


def test_complex_support(self):
self.assertEqual(jit(lambda x: x + 1)(1 + 1j), 2 + 1j)
Expand Down

0 comments on commit a869033

Please sign in to comment.