Description
With a local GPU device, one can compile ahead-of-time even for a different GPU and topology, as illustrated by the code at the end of this report (this only works with recent XLA - openxla/xla#16913). For deserialization code, see https://gist.github.com/jaro-sevcik/3495718bb04c6096c0f998fc29220c2b.
However, the same program fails without a device because the XLA/JAX runtime performs some renaming between "gpu" and "cuda"/"rocm" platforms, presumably for compatibility with legacy scripts. Here is the error message we get when compiling deviceless:
RuntimeError: Unknown backend: 'gpu' requested, but no platforms that are instances of gpu are present. ...
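For context, the failing lookup works roughly as in the sketch below: "gpu" is only an alias that expands to the concrete "cuda"/"rocm" platform names, and canonicalize_platform then picks the first of those that has an initialized backend, so a process without a GPU device never finds a match. (Simplified sketch for illustration; the alias table and helper name here are approximations, not the exact jax._src.xla_bridge code.)

# Simplified sketch of the alias resolution (not the exact xla_bridge code).
_platform_aliases = {"gpu": ["cuda", "rocm"]}  # "gpu" is an alias, not a real platform

def canonicalize_platform_sketch(platform, initialized_backends):
    platforms = _platform_aliases.get(platform, [platform])
    for p in platforms:
        if p in initialized_backends:  # deviceless: no "cuda"/"rocm" backend exists
            return p
    raise RuntimeError(f"Unknown backend: '{platform}' requested, but no "
                       f"platforms that are instances of {platform} are present.")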
If we bypass some of the renaming, for example with the patch below, the deviceless compilation and serialization succeed.
Here is the patch that makes deviceless compilation succeed for NVIDIA GPUs:
diff --git a/jax/_src/xla_bridge.py b/jax/_src/xla_bridge.py
index 1d3c50403..d1d56334d 100644
--- a/jax/_src/xla_bridge.py
+++ b/jax/_src/xla_bridge.py
@@ -798,6 +798,7 @@ def canonicalize_platform(platform: str) -> str:
   for p in platforms:
     if p in b.keys():
       return p
+  return "cuda"
   raise RuntimeError(f"Unknown backend: '{platform}' requested, but no "
                      f"platforms that are instances of {platform} are present. "
                      "Platforms present are: " + ",".join(b.keys()))
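For anyone who cannot carry a patched jaxlib, a runtime-only variant of the same idea is to wrap canonicalize_platform before lowering. This is an untested sketch with the same assumption as the patch above (falling back to "cuda"), and it only affects call sites that look the function up on the module rather than importing it by value:

# Untested sketch: make the "gpu" alias fall back to "cuda" when no GPU backend
# is initialized (deviceless AOT compilation).
import jax._src.xla_bridge as xb

_orig_canonicalize = xb.canonicalize_platform

def _canonicalize_with_fallback(platform):
    try:
        return _orig_canonicalize(platform)
    except RuntimeError:
        if platform == "gpu":
            return "cuda"  # assumption: NVIDIA target, mirrors the patch above
        raise

xb.canonicalize_platform = _canonicalize_with_fallback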
Alternatively, if one removes the renaming in XLA, deviceless AOT compilation also passes:
diff --git a/xla/python/py_client.h b/xla/python/py_client.h
index 374b7f6d2e..73543e91a6 100644
--- a/xla/python/py_client.h
+++ b/xla/python/py_client.h
@@ -96,16 +96,7 @@ class PyClient {
   }
   std::string_view platform_name() const {
-    // TODO(phawkins): this is a temporary backwards compatibility shim. We
-    // changed the name PJRT reports for GPU platforms to "cuda" or "rocm", but
-    // we haven't yet updated JAX clients that expect "gpu". Migrate users and
-    // remove this code.
-    if (ifrt_client_->platform_name() == "cuda" ||
-        ifrt_client_->platform_name() == "rocm") {
-      return "gpu";
-    } else {
-      return ifrt_client_->platform_name();
-    }
+    return ifrt_client_->platform_name();
   }
   std::string_view platform_version() const {
     return ifrt_client_->platform_version();
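Removing the shim changes the platform name that Python-level code observes, which is presumably why the TODO above asks to migrate JAX clients first. A hedged illustration of the user-visible difference (assuming the XLA patch is applied and a real GPU is attached):

# Hedged illustration: with the shim removed, the client reports the concrete
# platform, so legacy checks against "gpu" would need to accept the new names.
import jax

dev = jax.devices()[0]
print(dev.platform)                               # "cuda" or "rocm" instead of "gpu"
is_gpu = dev.platform in ("gpu", "cuda", "rocm")  # alias-tolerant check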
Ahead-of-time compilation and serialization code:
import jax
import jax.numpy as jp
import jax.experimental.topologies as topologies
from jax.sharding import Mesh, PartitionSpec as P, NamedSharding
from jax.experimental.serialize_executable import (
    deserialize_and_load,
    serialize,
)

# Contents of https://github.com/openxla/xla/blob/main/xla/tools/hlo_opt/gpu_specs/a100_pcie_80.txtpb
target_config_proto = """gpu_device_info {
  threads_per_block_limit: 1024
  threads_per_warp: 32
  shared_memory_per_block: 49152
  shared_memory_per_core: 167936
  threads_per_core_limit: 2048
  core_count: 108
  fpus_per_core: 64
  block_dim_limit_x: 2147483647
  block_dim_limit_y: 65535
  block_dim_limit_z: 65535
  memory_bandwidth: 2039000000000
  l2_cache_size: 41943040
  clock_rate_ghz: 1.1105
  device_memory_size: 79050250240
  shared_memory_per_block_optin: 166912
  cuda_compute_capability {
    major: 8
  }
  registers_per_core_limit: 65536
  registers_per_block_limit: 65536
}
platform_name: "CUDA"
dnn_version_info {
  major: 8
  minor: 3
  patch: 2
}
device_description_str: "A100 80GB"
"""

# Requested topology:
#   1 machine
#   1 process per machine
#   2 devices per process
topo = topologies.get_topology_desc(
    "topo",
    "cuda",
    target_config=target_config_proto,
    topology="1x1x2")

# Create the mesh and sharding.
mesh = Mesh(topo.devices, ('x',))
s = NamedSharding(mesh, P('x', None))

def fn(x):
    return jp.sum(x * x)

# JIT with fully specified shardings.
fn = jax.jit(fn, in_shardings=s, out_shardings=NamedSharding(mesh, P()))

# Provide input shape(s).
x_shape = jax.ShapeDtypeStruct(
    shape=(16, 16),
    dtype=jp.dtype('float32'),
    sharding=s)

# Lower and compile.
lowered = fn.lower(x_shape)
compiled = lowered.compile()

# Serialize the compilation results.
serialized, in_tree, out_tree = serialize(compiled)
print("Executable compiled and serialized")

# Write the serialized code to a file.
fname = "square.xla.bin"
with open(fname, "wb") as binary_file:
    binary_file.write(serialized)
print(f"Executable saved to '{fname}'")
System info (python version, jaxlib version, accelerator, etc.)
@hawkinsp The same problem also prevents us from using the GPU client via the IFRT proxy client, because the canonicalize_platform function does not find the 'cuda' backend for the 'gpu' platform alias (if we register the IFRT client under the name 'proxy').