Skip to content

Commit e301a71

Browse files
committed
reviews
1 parent 3161fd9 commit e301a71

File tree

3 files changed

+24
-11
lines changed

3 files changed

+24
-11
lines changed

src/precompile.jl renamed to src/Precompile.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,13 @@ using PrecompileTools: @setup_workload, @compile_workload
44
@static if haskey(ENV, "REACTANT_TEST_GROUP")
55
return
66
end
7-
@info "enable precompilation" gethostname() Base.active_project()
87
@compile_workload begin
9-
Reactant.__init__()
8+
initialize_dialect()
109
cpu = XLA.CPUClient()
1110
x = Reactant.ConcreteRArray(randn(Float64, 2, 2); client=cpu)
1211
@code_hlo optimize = false sum(x)
12+
XLA.free_client(cpu)
13+
deinitialize_dialect()
1314
end
15+
XLA.cpuclientcount[] = 0
1416
end

src/Reactant.jl

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -114,14 +114,23 @@ include("Compiler.jl")
114114
using .Compiler: @compile, @code_hlo, @jit, traced_getfield, create_result, compile
115115
export ConcreteRArray, ConcreteRNumber, @compile, @code_hlo, @jit, @trace
116116

117-
const registry = Ref{MLIR.IR.DialectRegistry}()
118-
function __init__()
117+
const registry = Ref{Union{Nothing,MLIR.IR.DialectRegistry}}()
118+
119+
function initialize_dialect()
119120
registry[] = MLIR.IR.DialectRegistry()
120121
@ccall MLIR.API.mlir_c.InitializeRegistryAndPasses(
121122
registry[]::MLIR.API.MlirDialectRegistry
122123
)::Cvoid
123124
end
124125

126+
function deinitialize_dialect()
127+
return registry[] = nothing
128+
end
129+
130+
function __init__()
131+
return initialize_dialect()
132+
end
133+
125134
function set_default_backend(backend::XLA.Client)
126135
return XLA.default_backend[] = backend
127136
end
@@ -130,6 +139,6 @@ function set_default_backend(backend::String)
130139
return set_default_backend(XLA.backends[backend])
131140
end
132141

133-
include("precompile.jl")
142+
include("Precompile.jl")
134143

135144
end # module

src/XLA.jl

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,14 @@ mutable struct Client
77

88
function Client(client::Ptr{Cvoid})
99
@assert client != C_NULL
10-
client = new(client)
11-
#TODO: Client are also constructed from MLIR.API.mlir_c.BufferToClient so the pointer cannot be free when Client is cleaned
12-
#finalizer(client) do client
13-
# @ccall MLIR.API.mlir_c.FreeClient(client.client::Ptr{Cvoid})::Cvoid
14-
#end
15-
return client
10+
return new(client)
1611
end
1712
end
1813

14+
@inline function free_client(client::Client)
15+
@ccall MLIR.API.mlir_c.FreeClient(client.client::Ptr{Cvoid})::Cvoid
16+
end
17+
1918
function to_row_major(x::Array{T,N}) where {T,N}
2019
return permutedims(x, reverse(Base.OneTo(N)))
2120
end
@@ -42,8 +41,11 @@ end
4241

4342
SetLogLevel(x) = @ccall MLIR.API.mlir_c.SetLogLevel(x::Cint)::Cvoid
4443

44+
global cpuclientcount = Ref(0)
4545
# TODO synchronization when async is not working because `future` in `ConcreteRArray` is always `nothing`
4646
function CPUClient(asynchronous=false, node_id=0, num_nodes=1)
47+
@assert cpuclientcount[] == 0
48+
cpuclientcount[] += 1
4749
f = Libdl.dlsym(Reactant_jll.libReactantExtra_handle, "MakeCPUClient")
4850
client = ccall(f, Ptr{Cvoid}, (UInt, Cint, Cint), asynchronous, node_id, num_nodes)
4951
#client = @ccall MLIR.API.mlir_c.MakeCPUClient(asynchronous::UInt8, node_id::Cint, num_nodes::Cint)::Ptr{Cvoid}

0 commit comments

Comments
 (0)