Skip to content

Commit

Permalink
Remember compile-only clients for canonicalization
Browse files Browse the repository at this point in the history
  • Loading branch information
jaro-sevcik committed Nov 21, 2024
1 parent f442d40 commit 9893da3
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 3 deletions.
6 changes: 4 additions & 2 deletions jax/_src/interpreters/pxla.py
Original file line number Diff line number Diff line change
Expand Up @@ -856,7 +856,8 @@ def lower_parallel_callable(
tuple_args = dispatch.should_tuple_args(len(shards.global_sharded_avals),
backend.platform)
module_name = f"pmap_{fun.__name__}"
platforms = lowering_platforms or (backend.platform,)
platforms = lowering_platforms or (
xb.canonical_platform_for_backend(backend),)
with core.extend_axis_env_nd([(axis_name, global_axis_size)]):
ordered_effects = list(
effects.ordered_effects.filter_in(closed_jaxpr.effects))
Expand Down Expand Up @@ -2213,7 +2214,8 @@ def lower_sharding_computation(
out_shardings = _concretize_abstract_shardings(
out_shardings, global_out_avals, device_assignment)

platforms = lowering_platforms or (backend.platform,)
platforms = lowering_platforms or (
xb.canonical_platform_for_backend(backend),)

committed = bool(
devices_from_context or
Expand Down
9 changes: 9 additions & 0 deletions jax/_src/xla_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import traceback
from typing import Any, Union
import warnings
import weakref

from jax._src import config
from jax._src import distributed
Expand Down Expand Up @@ -191,6 +192,8 @@ class BackendRegistration:
_backend_factories: dict[str, BackendRegistration] = {}
_default_backend: xla_client.Client | None = None
_backends : dict[str, xla_client.Client] = {}
# Alias for a weak-keyed mapping from a client object to a platform-name string.
_WeakClientToStrDict = weakref.WeakKeyDictionary[xla_client.Client, str]
# Registry of clients that have been explicitly associated with a platform
# name. Written by `register_compile_only_backend` and consulted by
# `canonical_platform_for_backend`. Keys are held weakly, so an entry does not
# keep its client alive and is dropped once the client is garbage collected.
_compile_only_backend_canonical_platform = _WeakClientToStrDict()
_backend_errors : dict[str, str] = {}
_backend_lock = threading.Lock()
_plugins_registered: bool = False
Expand Down Expand Up @@ -226,6 +229,8 @@ def register_backend_factory(name: str, factory: BackendFactory, *,
if make_topology is not None:
_topology_factories[name] = make_topology

def register_compile_only_backend(
    platform: str, client: xla_client.Client
) -> None:
  """Associate a compile-only ``client`` with the platform name ``platform``.

  The association is stored in the module-level
  ``_compile_only_backend_canonical_platform`` mapping, which
  ``canonical_platform_for_backend`` consults before falling back to the
  client's own ``platform`` attribute. Registering the same client again
  replaces the previously stored name.

  Args:
    platform: Platform name to report for ``client``. It is stored as given.
    client: The client object used as the lookup key.
  """
  registry = _compile_only_backend_canonical_platform
  registry[client] = platform

def make_cpu_client(
collectives: xla_client._xla.CpuCollectives | None = None,
Expand Down Expand Up @@ -813,6 +818,10 @@ def is_known_platform(platform: str) -> bool:
# we've heard of it and it isn't, e.g., a typo.
return platform in known_platforms()

def canonical_platform_for_backend(backend: xla_client.Client) -> str:
  """Return the platform name to use when lowering for ``backend``.

  If ``backend`` was registered through ``register_compile_only_backend``,
  the platform name recorded at registration time is returned unchanged.
  Otherwise the result is ``canonicalize_platform(backend.platform)``.

  Args:
    backend: The client whose platform name is requested.

  Returns:
    The registered platform name, or the canonicalized ``backend.platform``.
  """
  registered = _compile_only_backend_canonical_platform
  if backend not in registered:
    # No explicit registration: derive the name from the client itself.
    return canonicalize_platform(backend.platform)
  return registered[backend]

def canonicalize_platform(platform: str) -> str:
"""Replaces platform aliases with their concrete equivalent.
Expand Down
5 changes: 4 additions & 1 deletion jax/experimental/topologies.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,10 @@ def get_topology_desc(
)
try:
topology = xb.make_pjrt_topology(platform, topology_name, **kwargs)
return TopologyDescription(topology._make_compile_only_devices())
devices = topology._make_compile_only_devices()
if platform:
xb.register_compile_only_backend(platform, devices[0].client)
return TopologyDescription(devices)
except xla_extension.XlaRuntimeError as e:
msg, *_ = e.args
if msg.startswith("UNIMPLEMENTED"):
Expand Down

0 comments on commit 9893da3

Please sign in to comment.