Canonicalize platform name for compile-only backends #25033

Draft · wants to merge 1 commit into main
5 changes: 4 additions & 1 deletion jax/_src/interpreters/mlir.py
@@ -1084,7 +1084,10 @@ def lower_jaxpr_to_module(
   Handles the quirks of the argument/return value passing conventions of the
   runtime.
   """
-  platforms = tuple(map(xb.canonicalize_platform, platforms))
+  backend = (backend_or_name if isinstance(backend_or_name, xb.XlaBackend) else
+             None)
+  platforms = tuple(map(lambda p: xb.canonicalize_platform(p, backend),
+                        platforms))
 
   in_avals = (jaxpr.in_avals if arg_shardings is None else
               map(sharded_aval, jaxpr.in_avals, arg_shardings))
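For context, this hunk is exercised during deviceless AOT lowering, where the backend handed to lower_jaxpr_to_module is a compile-only client rather than a platform name. A minimal sketch of that path, assuming a compile-only CUDA topology; on a host without CUDA devices a target_config argument would also be required (elided here):

import jax
import jax.numpy as jnp
from jax.experimental import topologies

# Compile-only topology (target_config elided; see the new test below).
topo = topologies.get_topology_desc("topo", "cuda", topology="1x1x1")
sharding = jax.sharding.SingleDeviceSharding(topo.devices[0])
x = jax.ShapeDtypeStruct((2, 2), jnp.float32, sharding=sharding)

# Lowering passes the topology's compile-only client down as `backend`,
# which the canonicalization in the hunk above can then consult.
lowered = jax.jit(lambda a: jnp.sum(a * a)).lower(x)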
18 changes: 16 additions & 2 deletions jax/_src/xla_bridge.py
@@ -34,6 +34,7 @@
 import traceback
 from typing import Any, Union
 import warnings
+import weakref
 
 from jax._src import config
 from jax._src import distributed
@@ -191,6 +192,8 @@ class BackendRegistration:
 _backend_factories: dict[str, BackendRegistration] = {}
 _default_backend: xla_client.Client | None = None
 _backends : dict[str, xla_client.Client] = {}
+_WeakClientToStrDict = weakref.WeakKeyDictionary[xla_client.Client, str]
+_compile_only_backend_canonical_platform = _WeakClientToStrDict()
 _backend_errors : dict[str, str] = {}
 _backend_lock = threading.Lock()
 _plugins_registered: bool = False
@@ -226,6 +229,11 @@ def register_backend_factory(name: str, factory: BackendFactory, *,
   if make_topology is not None:
     _topology_factories[name] = make_topology
 
+def register_compile_only_backend(platform: str, client: xla_client.Client):
+  # Register the compile-only backend->platform mapping, but only when
+  # the platform is a concrete name rather than an alias.
+  if _alias_to_platforms.get(platform) is None:
+    _compile_only_backend_canonical_platform[client] = platform
 
 def make_cpu_client(
     collectives: xla_client._xla.CpuCollectives | None = None,
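Why a WeakKeyDictionary rather than a plain dict: the registry must not keep compile-only clients alive, so an entry disappears as soon as its client is garbage-collected. A standalone illustration of that behavior (plain Python; FakeClient is a stand-in, not the real xla_client.Client):

import gc
import weakref

class FakeClient:  # stand-in for xla_client.Client
  pass

registry = weakref.WeakKeyDictionary()
client = FakeClient()
registry[client] = "cuda"
assert registry[client] == "cuda"

del client     # drop the last strong reference
gc.collect()   # CPython typically reclaims it immediately anyway
assert len(registry) == 0   # the entry vanished along with the client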
@@ -813,8 +821,8 @@ def is_known_platform(platform: str) -> bool:
   # we've heard of it and it isn't, e.g., a typo.
   return platform in known_platforms()
 
-
-def canonicalize_platform(platform: str) -> str:
+def canonicalize_platform(platform: str,
+                          backend: xla_client.Client | None = None) -> str:
   """Replaces platform aliases with their concrete equivalent.
 
   In particular, replaces "gpu" with either "cuda" or "rocm", depending on which
@@ -830,6 +838,12 @@ def canonicalize_platform(platform: str) -> str:
   for p in platforms:
     if p in b.keys():
       return p
+
+  if (backend and
+      backend.platform == platform and
+      backend in _compile_only_backend_canonical_platform):
+    return _compile_only_backend_canonical_platform[backend]
+
   raise RuntimeError(f"Unknown backend: '{platform}' requested, but no "
                      f"platforms that are instances of {platform} are present. "
                      "Platforms present are: " + ",".join(b.keys()))
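The resolution order after this change: a concrete platform name passes through unchanged; an alias resolves to the first live concrete backend; failing that, a registered compile-only client may disambiguate. A self-contained model of that logic (illustrative only, not the real implementation; it omits the backend.platform check and other details):

ALIASES = {"gpu": ("cuda", "rocm")}

def canonicalize(platform, live_backends, backend=None, registry=None):
  expansions = ALIASES.get(platform)
  if expansions is None:        # concrete names pass through
    return platform
  for p in expansions:          # a live concrete backend wins
    if p in live_backends:
      return p
  if backend is not None and registry and backend in registry:
    return registry[backend]    # compile-only fallback
  raise RuntimeError(f"Unknown backend: {platform!r}")

class FakeClient: pass
client = FakeClient()
assert canonicalize("cuda", {"cpu"}) == "cuda"
assert canonicalize("gpu", {"rocm"}) == "rocm"
assert canonicalize("gpu", {"cpu"}, client, {client: "cuda"}) == "cuda"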
5 changes: 4 additions & 1 deletion jax/experimental/topologies.py
@@ -45,7 +45,10 @@ def get_topology_desc(
   )
   try:
     topology = xb.make_pjrt_topology(platform, topology_name, **kwargs)
-    return TopologyDescription(topology._make_compile_only_devices())
+    devices = topology._make_compile_only_devices()
+    if platform:
+      xb.register_compile_only_backend(platform, devices[0].client)
+    return TopologyDescription(devices)
   except xla_extension.XlaRuntimeError as e:
     msg, *_ = e.args
     if msg.startswith("UNIMPLEMENTED"):
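With this registration in place, alias resolution works end to end in a deviceless setup. A sketch, assuming a CPU-only process and that the compile-only client reports the alias "gpu" as its platform (the case this change targets); target_config again elided:

from jax._src import xla_bridge as xb
from jax.experimental import topologies

# Registers the client -> "cuda" mapping as a side effect.
topo = topologies.get_topology_desc("topo", "cuda", topology="1x1x1")
client = topo.devices[0].client

# Without the client this raised RuntimeError on a CPU-only host, since
# no live cuda/rocm backend exists to expand the "gpu" alias.
assert xb.canonicalize_platform("gpu", client) == "cuda"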
13 changes: 13 additions & 0 deletions tests/BUILD
@@ -1055,6 +1055,19 @@ jax_multiplatform_test(
     ],
 )
 
+jax_multiplatform_test(
+    name = "deviceless_aot_test",
+    srcs = ["deviceless_aot_test.py"],
+    enable_backends = ["gpu"],
+    enable_configs = [
+        "gpu_a100",
+    ],
+    deps = [
+        "//jax:experimental",
+    ] + py_deps("numpy"),
+)
+
+
 jax_multiplatform_test(
     name = "sparsify_test",
     srcs = ["sparsify_test.py"],
88 changes: 88 additions & 0 deletions tests/deviceless_aot_test.py
@@ -0,0 +1,88 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for deviceless AOT compilation on GPU."""

import multiprocessing

from absl.testing import absltest
import jax
from jax._src import test_util as jtu
from jax.experimental import topologies
from jax.experimental.serialize_executable import (
    deserialize_and_load,
    serialize,
)
import jax.numpy as jnp
from jax.lib import xla_client as xc


class DevicelessGpuAotTest(jtu.JaxTestCase):

  @staticmethod
  def compile_worker(target_config):
    # Run without GPU.
    jax.config.update("jax_platforms", "cpu")

    topo = topologies.get_topology_desc(
        "topo",
        "cuda",
        target_config=target_config,
        topology="1x1x1")

    sharding = jax.sharding.SingleDeviceSharding(topo.devices[0])

    # Function to compile.
    @jax.jit
    def fn(x):
      return jnp.sum(x * x)

    # Provide input shape(s).
    x_shape = jax.ShapeDtypeStruct(
        shape=(2, 2),
        dtype=jnp.dtype('float32'),
        sharding=sharding)

    # Lower and compile.
    compiled = fn.lower(x_shape).compile()

    # Serialize the compilation results.
    serialized, in_tree, out_tree = serialize(compiled)

    return serialized

  @jtu.skip_under_pytest("Test must run in an isolated process")
  def test_serialize_deserialize_execute(self):
    target_config = xc.get_topology_for_devices(jax.devices()).target_config

    # Run the compilation in a separate process so that we can start JAX
    # without the GPU platform there.
    multiprocessing.set_start_method("spawn")
    pool = multiprocessing.Pool(processes=1)
    [serialized] = pool.map(DevicelessGpuAotTest.compile_worker,
                            [target_config])
    pool.close()

    # Provide the input pytree structure (0 stands for a leaf).
    _, in_tree = jax.tree_util.tree_flatten(((0,), {}))
    # Provide the output pytree structure (here just one JAX array).
    _, out_tree = jax.tree_util.tree_flatten(0)
    # Deserialize the function.
    compiled = deserialize_and_load(serialized, in_tree, out_tree)

    # Call the deserialized function.
    result = compiled(jnp.array([[0., 1.], [2., 3.]]))

    self.assertEqual(result, 14)


if __name__ == '__main__':
  absltest.main(testLoader=jtu.JaxTestLoader())