Skip to content

Commit

Permalink
Remember compile-only clients for canonicalization
Browse files Browse the repository at this point in the history
  • Loading branch information
jaro-sevcik committed Nov 25, 2024
1 parent f442d40 commit ea681c3
Show file tree
Hide file tree
Showing 5 changed files with 125 additions and 4 deletions.
5 changes: 4 additions & 1 deletion jax/_src/interpreters/mlir.py
Original file line number Diff line number Diff line change
Expand Up @@ -1084,7 +1084,10 @@ def lower_jaxpr_to_module(
Handles the quirks of the argument/return value passing conventions of the
runtime.
"""
platforms = tuple(map(xb.canonicalize_platform, platforms))
backend = (backend_or_name if isinstance(backend_or_name, xb.XlaBackend) else
None)
platforms = tuple(map(lambda p: xb.canonicalize_platform(p, backend),
platforms))

in_avals = (jaxpr.in_avals if arg_shardings is None else
map(sharded_aval, jaxpr.in_avals, arg_shardings))
Expand Down
18 changes: 16 additions & 2 deletions jax/_src/xla_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import traceback
from typing import Any, Union
import warnings
import weakref

from jax._src import config
from jax._src import distributed
Expand Down Expand Up @@ -191,6 +192,8 @@ class BackendRegistration:
_backend_factories: dict[str, BackendRegistration] = {}
_default_backend: xla_client.Client | None = None
_backends : dict[str, xla_client.Client] = {}
_WeakClientToStrDict = weakref.WeakKeyDictionary[xla_client.Client, str]
_compile_only_backend_canonical_platform = _WeakClientToStrDict()
_backend_errors : dict[str, str] = {}
_backend_lock = threading.Lock()
_plugins_registered: bool = False
Expand Down Expand Up @@ -226,6 +229,11 @@ def register_backend_factory(name: str, factory: BackendFactory, *,
if make_topology is not None:
_topology_factories[name] = make_topology

def register_compile_only_backend(platform: str, client: xla_client.Client):
# Register the compile-only backend-platform mapping as long as
# the platform is not an alias.
if _alias_to_platforms.get(platform, None) == None:
_compile_only_backend_canonical_platform[client] = platform

def make_cpu_client(
collectives: xla_client._xla.CpuCollectives | None = None,
Expand Down Expand Up @@ -813,8 +821,8 @@ def is_known_platform(platform: str) -> bool:
# we've heard of it and it isn't, e.g., a typo.
return platform in known_platforms()


def canonicalize_platform(platform: str) -> str:
def canonicalize_platform(platform: str,
backend: xla_client.Client | None = None) -> str:
"""Replaces platform aliases with their concrete equivalent.
In particular, replaces "gpu" with either "cuda" or "rocm", depending on which
Expand All @@ -830,6 +838,12 @@ def canonicalize_platform(platform: str) -> str:
for p in platforms:
if p in b.keys():
return p

if (backend and
backend.platform == platform and
backend in _compile_only_backend_canonical_platform):
return _compile_only_backend_canonical_platform[backend]

raise RuntimeError(f"Unknown backend: '{platform}' requested, but no "
f"platforms that are instances of {platform} are present. "
"Platforms present are: " + ",".join(b.keys()))
Expand Down
5 changes: 4 additions & 1 deletion jax/experimental/topologies.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,10 @@ def get_topology_desc(
)
try:
topology = xb.make_pjrt_topology(platform, topology_name, **kwargs)
return TopologyDescription(topology._make_compile_only_devices())
devices = topology._make_compile_only_devices()
if platform:
xb.register_compile_only_backend(platform, devices[0].client)
return TopologyDescription(devices)
except xla_extension.XlaRuntimeError as e:
msg, *_ = e.args
if msg.startswith("UNIMPLEMENTED"):
Expand Down
13 changes: 13 additions & 0 deletions tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1055,6 +1055,19 @@ jax_multiplatform_test(
],
)

jax_multiplatform_test(
name = "deviceless_aot_test",
srcs = ["deviceless_aot_test.py"],
enable_backends = [ "gpu" ],
enable_configs = [
"gpu_a100",
],
deps = [
"//jax:experimental",
] + py_deps("numpy"),
)


jax_multiplatform_test(
name = "sparsify_test",
srcs = ["sparsify_test.py"],
Expand Down
88 changes: 88 additions & 0 deletions tests/deviceless_aot_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for deviceless AOT compilation on GPU."""

from absl.testing import absltest
import jax
from jax._src import test_util as jtu
from jax.experimental import topologies
from jax.experimental.serialize_executable import (
deserialize_and_load,
serialize,
)
import jax.numpy as jnp
from jax.lib import xla_client as xc
import multiprocessing

class DevicelessGpuAotTest(jtu.JaxTestCase):

def compile_worker(target_config):
# Run without GPU.
jax.config.update("jax_platforms", "cpu")

topo = topologies.get_topology_desc(
"topo",
"cuda",
target_config=target_config,
topology="1x1x1")

sharding = jax.sharding.SingleDeviceSharding(topo.devices[0])

# Function to compile.
@jax.jit
def fn(x):
return jnp.sum(x * x)

# Provide input shape(s).
x_shape = jax.ShapeDtypeStruct(
shape=(2, 2),
dtype=jnp.dtype('float32'),
sharding=sharding)

# Lower and compile.
compiled = fn.lower(x_shape).compile()

# Serialize the compilation results.
serialized, in_tree, out_tree = serialize(compiled)

return serialized

@jtu.skip_under_pytest("Test must run in an isolated process")
def test_serialize_deserialize_execute(self):
target_config = xc.get_topology_for_devices(jax.devices()).target_config

# Call the compilation in a different process so that we
# can start JAX without the GPU platform there.
multiprocessing.set_start_method("spawn")
pool = multiprocessing.Pool(processes = 1)
[serialized] = pool.map(DevicelessGpuAotTest.compile_worker,
[target_config])
pool.close()

# Provide the input pytree structure (0 stands for leaf).
_, in_tree = jax.tree_util.tree_flatten(((0,), {}))
# Provide the output pytree structure (here just one JAX array).
_, out_tree = jax.tree_util.tree_flatten(0)
# Deserialize the function.
compiled = deserialize_and_load(serialized, in_tree, out_tree)


# Call the deserialized function.
result = compiled(jnp.array([[0., 1.], [2., 3.]]))

self.assertEqual(result, 14)


if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())

0 comments on commit ea681c3

Please sign in to comment.