@@ -41,7 +41,7 @@ include("PJRT/PJRT.jl")
4141end
4242
4343function Base. getproperty (bs:: BackendState , sym:: Symbol )
44- (sym === :initialized || bs. initialized) && return getfield (bs, sym)
44+ (sym === :initialized || bs. initialized) && return getfield (bs, sym)
4545 initialize_default_clients! (bs)
4646 return getfield (bs, sym)
4747end
@@ -56,6 +56,7 @@ const global_backend_state = BackendState()
5656
5757client (backend:: String ) = global_backend_state. clients[backend]
5858default_backend () = global_backend_state. default_client
59+ process_index () = process_index (default_backend ())
5960
6061function set_default_backend (backend:: AbstractClient )
6162 global_backend_state. default_client = backend
@@ -141,6 +142,13 @@ function initialize_default_clients!(state::BackendState)
141142 else
142143 if ! Reactant. precompiling ()
143144 try
145+ distributed_runtime_client = if PJRT. global_state. num_processes > 1
146+ @assert PJRT. global_state. client != = nothing
147+ PJRT. global_state. client
148+ else
149+ nothing
150+ end
151+
144152 if was_initialized && haskey (state. clients, " gpu" )
145153 XLA. free_client (state. clients[" gpu" ])
146154 XLA. PJRT. gpu_client_count[] -= 1
@@ -149,6 +157,7 @@ function initialize_default_clients!(state::BackendState)
149157 PJRT. global_state. process_id,
150158 PJRT. global_state. num_processes;
151159 allowed_devices= PJRT. global_state. local_device_ids,
160+ distributed_runtime_client,
152161 )
153162 state. clients[" gpu" ] = gpu
154163 state. default_client = gpu
0 commit comments