@@ -21,6 +21,7 @@ function LLVMclopts(opts...)
2121 ):: Cvoid
2222end
2323
24+ include (" Distributed.jl" )
2425include (" Client.jl" )
2526include (" Device.jl" )
2627include (" Sharding.jl" )
@@ -53,6 +54,7 @@ function Base.setproperty!(bs::BackendState, sym::Symbol, val)
5354end
5455
5556const global_backend_state = BackendState ()
57+ const global_state = State ()
5658
5759client (backend:: String ) = global_backend_state. clients[backend]
5860default_backend () = global_backend_state. default_client
@@ -68,6 +70,13 @@ function set_default_backend(backend::String)
6870 return nothing
6971end
7072
73+ function update_global_state! (args... ; kwargs... )
74+ update! (global_state, args... ; kwargs... )
75+ # We need to update the clients based on the new state
76+ initialize_default_clients! (global_backend_state)
77+ return nothing
78+ end
79+
7180function __init__ ()
7281 # This must be the very first thing initialized (otherwise we can't throw errors)
7382 errptr = cglobal ((:ReactantThrowError , MLIR. API. mlir_c), Ptr{Ptr{Cvoid}})
@@ -90,16 +99,17 @@ function __init__()
9099 @debug " XLA_REACTANT_GPU_PREALLOCATE: " XLA_REACTANT_GPU_PREALLOCATE[]
91100 end
92101
102+ if haskey (ENV , " REACTANT_VISIBLE_GPU_DEVICES" )
103+ global_state. local_device_ids =
104+ parse .(Int, split (ENV [" REACTANT_VISIBLE_GPU_DEVICES" ], " ," ))
105+ @debug " REACTANT_VISIBLE_GPU_DEVICES: " global_state. local_device_ids
106+ end
107+
93108 @ccall MLIR. API. mlir_c. RegisterEnzymeXLACPUHandler ():: Cvoid
94109 @ccall MLIR. API. mlir_c. RegisterEnzymeXLAGPUHandler ():: Cvoid
95110 return nothing
96111end
97112
98- function initialize_default_clients ()
99- initialize_default_clients! (global_backend_state)
100- return nothing
101- end
102-
103113function initialize_default_clients! (state:: BackendState )
104114 was_initialized = state. initialized
105115 state. initialized = true
@@ -109,7 +119,7 @@ function initialize_default_clients!(state::BackendState)
109119 XLA. free_client (state. clients[" cpu" ])
110120 XLA. PJRT. cpu_client_count[] -= 1
111121 end
112- cpu = PJRT. CPUClient (PJRT . global_state. process_id, PJRT . global_state. num_processes)
122+ cpu = PJRT. CPUClient (global_state. process_id, global_state. num_processes)
113123 state. clients[" cpu" ] = cpu
114124 state. default_client = cpu
115125
@@ -142,9 +152,9 @@ function initialize_default_clients!(state::BackendState)
142152 else
143153 if ! Reactant. precompiling ()
144154 try
145- distributed_runtime_client = if PJRT . global_state. num_processes > 1
146- @assert PJRT . global_state. client != = nothing
147- PJRT . global_state. client
155+ distributed_runtime_client = if global_state. num_processes > 1
156+ @assert global_state. client != = nothing
157+ global_state. client
148158 else
149159 nothing
150160 end
@@ -154,9 +164,9 @@ function initialize_default_clients!(state::BackendState)
154164 XLA. PJRT. gpu_client_count[] -= 1
155165 end
156166 gpu = PJRT. GPUClient (
157- PJRT . global_state. process_id,
158- PJRT . global_state. num_processes;
159- allowed_devices= PJRT . global_state. local_device_ids,
167+ global_state. process_id,
168+ global_state. num_processes;
169+ allowed_devices= global_state. local_device_ids,
160170 distributed_runtime_client,
161171 )
162172 state. clients[" gpu" ] = gpu
0 commit comments