Ahead-of-time compilation on GPU fails without a GPU device #23971

Open
jaro-sevcik opened this issue Sep 27, 2024 · 5 comments · May be fixed by #25033
Labels
bug Something isn't working

Comments

@jaro-sevcik
Contributor

Description

With a local GPU device present, one can compile ahead-of-time even for a different GPU and topology, as illustrated by the code at the end of this report (this only works with a recent XLA; see openxla/xla#16913). For deserialization code, see https://gist.github.com/jaro-sevcik/3495718bb04c6096c0f998fc29220c2b.

However, the same program fails without a device because the XLA/JAX runtime performs some renaming between "gpu" and "cuda"/"rocm" platforms, presumably for compatibility with legacy scripts. Here is the error message we get when compiling deviceless:

RuntimeError: Unknown backend: 'gpu' requested, but no platforms that are instances of gpu are present. ...
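
For context, here is a rough sketch of the alias handling involved (paraphrased from jax/_src/xla_bridge.py; names and details are approximate, not the exact source). The PJRT client reports its platform as "gpu" (see the py_client.h shim below), and "gpu" is treated as an alias that only resolves to a concrete "cuda"/"rocm" backend when such a backend has actually been initialized, i.e. when a device is present:

# Approximate sketch of the alias resolution in jax/_src/xla_bridge.py,
# not the exact source.
_alias_to_platforms = {"gpu": ["cuda", "rocm"]}

def canonicalize_platform(platform: str) -> str:
  platforms = _alias_to_platforms.get(platform)
  if platforms is None:
    return platform  # already a concrete platform name
  b = backends()  # maps platform name -> initialized backend
  for p in platforms:
    if p in b.keys():
      return p
  raise RuntimeError(f"Unknown backend: '{platform}' requested, but no "
                     f"platforms that are instances of {platform} are present. "
                     "Platforms present are: " + ",".join(b.keys()))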

If we bypass some of the renaming, deviceless compilation and serialization succeed. For example, here is a patch that makes deviceless compilation succeed for NVIDIA GPUs:

diff --git a/jax/_src/xla_bridge.py b/jax/_src/xla_bridge.py
index 1d3c50403..d1d56334d 100644
--- a/jax/_src/xla_bridge.py
+++ b/jax/_src/xla_bridge.py
@@ -798,6 +798,7 @@ def canonicalize_platform(platform: str) -> str:
   for p in platforms:
     if p in b.keys():
       return p
+  return "cuda"
   raise RuntimeError(f"Unknown backend: '{platform}' requested, but no "
                      f"platforms that are instances of {platform} are present. "
                      "Platforms present are: " + ",".join(b.keys()))

Alternatively, if one removes the renaming in XLA, deviceless AOT compilation also passes:

diff --git a/xla/python/py_client.h b/xla/python/py_client.h
index 374b7f6d2e..73543e91a6 100644
--- a/xla/python/py_client.h
+++ b/xla/python/py_client.h
@@ -96,16 +96,7 @@ class PyClient {
   }
 
   std::string_view platform_name() const {
-    // TODO(phawkins): this is a temporary backwards compatibility shim. We
-    // changed the name PJRT reports for GPU platforms to "cuda" or "rocm", but
-    // we haven't yet updated JAX clients that expect "gpu". Migrate users and
-    // remove this code.
-    if (ifrt_client_->platform_name() == "cuda" ||
-        ifrt_client_->platform_name() == "rocm") {
-      return "gpu";
-    } else {
-      return ifrt_client_->platform_name();
-    }
+    return ifrt_client_->platform_name();
   }
   std::string_view platform_version() const {
     return ifrt_client_->platform_version();

Ahead-of-time compilation and serialization code:

import jax
import jax.numpy as jp
import jax.experimental.topologies as topologies
from jax.sharding import Mesh, PartitionSpec as P, NamedSharding
from jax.experimental.serialize_executable import (
    deserialize_and_load,
    serialize,
)

# Contents of https://github.com/openxla/xla/blob/main/xla/tools/hlo_opt/gpu_specs/a100_pcie_80.txtpb
target_config_proto = """gpu_device_info {
  threads_per_block_limit: 1024
  threads_per_warp: 32
  shared_memory_per_block: 49152
  shared_memory_per_core: 167936
  threads_per_core_limit: 2048
  core_count: 108
  fpus_per_core: 64
  block_dim_limit_x: 2147483647
  block_dim_limit_y: 65535
  block_dim_limit_z: 65535
  memory_bandwidth: 2039000000000
  l2_cache_size: 41943040
  clock_rate_ghz: 1.1105
  device_memory_size: 79050250240
  shared_memory_per_block_optin: 166912
  cuda_compute_capability {
    major: 8
  }
  registers_per_core_limit: 65536
  registers_per_block_limit: 65536
}
platform_name: "CUDA"
dnn_version_info {
  major: 8
  minor: 3
  patch: 2
}
device_description_str: "A100 80GB"
"""

# Requested topology:
# 1 machine
# 1 process per machine
# 2 devices per process
topo = topologies.get_topology_desc(
  "topo",
  "cuda",
  target_config=target_config_proto,
  topology="1x1x2")

# Create the mesh and sharding.
mesh = Mesh(topo.devices, ('x',))
s = NamedSharding(mesh, P('x', None))

def fn(x):
  return jp.sum(x * x)

# JIT with fully specified shardings.
fn = jax.jit(fn, in_shardings=s, out_shardings=NamedSharding(mesh, P()))

# Provide input shape(s).
x_shape = jax.ShapeDtypeStruct(
          shape=(16, 16),
          dtype=jp.dtype('float32'),
          sharding=s)

# Lower and compile.
lowered = fn.lower(x_shape)
compiled = lowered.compile()

# Serialize the compilation results.
serialized, in_tree, out_tree = serialize(compiled)
print("Executable compiled and serialized")

# Write the serialized code to a file.
fname = "square.xla.bin"
with open(fname, "wb") as binary_file:
    binary_file.write(serialized)

print(f"Executable saved to '{fname}'")

System info (python version, jaxlib version, accelerator, etc.)

>>> import jax; jax.print_environment_info()
jax:    0.4.34.dev20240926+b6d668e0d
jaxlib: 0.4.34.dev20240927
numpy:  1.26.4
python: 3.10.12 (main, Sep 11 2024, 15:47:36) [GCC 11.4.0]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='...', release='5.4.0-92-generic', version='#103-Ubuntu SMP Fri Nov 26 16:13:00 UTC 2021', machine='x86_64')
@jaro-sevcik added the bug label on Sep 27, 2024
@jaro-sevcik
Contributor Author

@hawkinsp, could you take a look?

@jaro-sevcik
Contributor Author

@hawkinsp The same problem also prevents us from using the GPU client via the IFRT proxy client, because the canonicalize_platform function likewise fails to find the 'cuda' backend for the 'gpu' platform alias (if we register the IFRT client under the name 'proxy').

@mjsML
Collaborator

mjsML commented Nov 17, 2024

@jaro-sevcik should this be closed?

@jaro-sevcik
Contributor Author

@mjsML Why do you think it should be closed? As far as I can see, the TODO(phawkins) is still present in the offending code.

@mjsML
Collaborator

mjsML commented Nov 18, 2024

@jaro-sevcik, asking during daily triage because this was last updated a month ago :) ... Let's sync about this offline.
