Skip to content

Commit 9bb572b

Browse files
Avik Palavik-pal
authored andcommitted
fix: correctly set kv_store
1 parent 8111623 commit 9bb572b

File tree

2 files changed

+22
-3
lines changed

2 files changed

+22
-3
lines changed

src/xla/Client.jl

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,17 +23,26 @@ function CPUClient(cfunc, node_id=0, num_nodes=1; asynchronous=true)
2323
return client
2424
end
2525

26-
function GPUClient(cfunc, node_id=0, num_nodes=1, platform="gpu"; allowed_devices=nothing)
26+
function GPUClient(
27+
cfunc,
28+
node_id=0,
29+
num_nodes=1,
30+
platform="gpu";
31+
allowed_devices=nothing,
32+
distributed_runtime_client=nothing,
33+
)
2734
f = Libdl.dlsym(Reactant_jll.libReactantExtra_handle, string(cfunc))
2835
refstr = Ref{Cstring}()
2936

3037
num_allowed_devices = allowed_devices === nothing ? 0 : length(allowed_devices)
3138
allowed_devices = allowed_devices === nothing ? C_NULL : allowed_devices
39+
distributed_runtime_client =
40+
distributed_runtime_client === nothing ? C_NULL : distributed_runtime_client.client
3241

3342
client = ccall(
3443
f,
3544
Ptr{Cvoid},
36-
(Cint, Cint, Ptr{Cvoid}, Cint, Cdouble, Bool, Cstring, Ptr{Cstring}),
45+
(Cint, Cint, Ptr{Cvoid}, Cint, Cdouble, Bool, Cstring, Ptr{Cstring}, Ptr{Cvoid}),
3746
node_id,
3847
num_nodes,
3948
allowed_devices,
@@ -42,6 +51,7 @@ function GPUClient(cfunc, node_id=0, num_nodes=1, platform="gpu"; allowed_device
4251
false,
4352
platform,
4453
refstr,
54+
distributed_runtime_client,
4555
)
4656
client == C_NULL && throw(AssertionError(unsafe_string(refstr[])))
4757
LLVMclopts("-nvptx-fma-level=1")

src/xla/XLA.jl

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ include("PJRT/PJRT.jl")
4141
end
4242

4343
function Base.getproperty(bs::BackendState, sym::Symbol)
44-
(sym === :initialized || bs.initialized) && return getfield(bs, sym)
44+
(sym === :initialized || bs.initialized) && return getfield(bs, sym)
4545
initialize_default_clients!(bs)
4646
return getfield(bs, sym)
4747
end
@@ -56,6 +56,7 @@ const global_backend_state = BackendState()
5656

5757
client(backend::String) = global_backend_state.clients[backend]
5858
default_backend() = global_backend_state.default_client
59+
process_index() = process_index(default_backend())
5960

6061
function set_default_backend(backend::AbstractClient)
6162
global_backend_state.default_client = backend
@@ -141,6 +142,13 @@ function initialize_default_clients!(state::BackendState)
141142
else
142143
if !Reactant.precompiling()
143144
try
145+
distributed_runtime_client = if PJRT.global_state.num_processes > 1
146+
@assert PJRT.global_state.client !== nothing
147+
PJRT.global_state.client
148+
else
149+
nothing
150+
end
151+
144152
if was_initialized && haskey(state.clients, "gpu")
145153
XLA.free_client(state.clients["gpu"])
146154
XLA.PJRT.gpu_client_count[] -= 1
@@ -149,6 +157,7 @@ function initialize_default_clients!(state::BackendState)
149157
PJRT.global_state.process_id,
150158
PJRT.global_state.num_processes;
151159
allowed_devices=PJRT.global_state.local_device_ids,
160+
distributed_runtime_client,
152161
)
153162
state.clients["gpu"] = gpu
154163
state.default_client = gpu

0 commit comments

Comments
 (0)