Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow setting default_device with platform names. #24751

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 1 addition & 6 deletions jax/_src/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1555,13 +1555,11 @@ def _update_x64_thread_local(val):
def _update_default_device_global(val):
lib.jax_jit.global_state().default_device = val


def _update_default_device_thread_local(val):
lib.jax_jit.thread_local_state().default_device = val


def _validate_default_device(val):
if val is not None and not isinstance(val, xla_client.Device):
if val is not None and not isinstance(val, xla_client.Device) and val not in ['cpu', 'gpu', 'tpu']:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure if this is a good idea. Parameterizing default device on platform name doesn't sound like the right API.

What problem is this solving?

# TODO(skyewm): this is a workaround for non-PJRT Device types. Remove when
# all JAX backends use a single C++ device interface.
if 'Device' in str(type(val)):
Expand All @@ -1572,9 +1570,6 @@ def _validate_default_device(val):
raise ValueError('jax.default_device must be passed a Device object (e.g. '
f"`jax.devices('cpu')[0]`), got: {val!r}")


# TODO(skye): default_device only accepts devices for now. Make it work with
# platform names as well (e.g. "cpu" to mean the same as jax.devices("cpu")[0]).
default_device = string_or_object_state(
name='jax_default_device',
default=None,
Expand Down
7 changes: 6 additions & 1 deletion jax/_src/interpreters/pxla.py
Original file line number Diff line number Diff line change
Expand Up @@ -1683,7 +1683,10 @@ class DeviceAssignmentMismatchError(Exception):


def _get_default_device() -> xc.Device:
return config.default_device.value or xb.local_devices()[0]
if isinstance(config.default_device.value, str):
return xb.get_backend(config.default_device.value).devices()[0]
else:
return config.default_device.value or xb.local_devices()[0]


def _get_and_check_device_assignment(
Expand Down Expand Up @@ -1715,6 +1718,7 @@ def _get_and_check_device_assignment(
raise DeviceAssignmentMismatchError([
DeviceAssignmentMismatch(devices, MismatchType.CONTEXT_DEVICES, None),
DeviceAssignmentMismatch(arr_device_assignment, s_type, source_info)])

if first_sharding_info is None and devices:
final_device_assignment = devices
elif first_sharding_info is None:
Expand Down Expand Up @@ -2163,6 +2167,7 @@ def lower_sharding_computation(
assert len(out_shardings) == len(out_layouts) == len(global_out_avals), (
len(out_shardings), len(out_layouts), len(global_out_avals))


devices_from_context = (None if context_mesh is None or context_mesh.empty
else context_mesh._flat_devices_tuple)
# Device assignment across all inputs, outputs and shardings inside jaxpr
Expand Down
11 changes: 7 additions & 4 deletions tests/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,11 +289,14 @@ def test_jit_default_device(self, module):

# TODO(skye): make this work!
def test_jit_default_platform(self):
with self.assertRaisesWithLiteralMatch(
ValueError, "jax.default_device must be passed a Device object "
"(e.g. `jax.devices('cpu')[0]`), got: 'cpu'"):

with jax.default_device("cpu"):
jax.jit(lambda x: x + 1)(1)
result = jax.jit(lambda x: x + 1)(1)
self.assertEqual(result.device.platform, "cpu")

result = jax.jit(lambda x: x + 1)(1)
self.assertEqual(result.device.platform, jax.default_backend())


def test_complex_support(self):
self.assertEqual(jit(lambda x: x + 1)(1 + 1j), 2 + 1j)
Expand Down