Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove LLVM IR introspection from init #881

Merged
merged 2 commits into from
Jun 3, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 27 additions & 105 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6225,111 +6225,38 @@ function from_tape_type(::Type{B}, ctx) where {B<:Tuple}
end
const Tracked = 10


const task_offset=Ref{Int}(0)
function current_task_offset()
if task_offset[] == 0
f = Core.Typeof(Base.current_task)
world = Base.get_world_counter()
ctx = JuliaContext()
target = Enzyme.Compiler.DefaultCompilerTarget()
params = Enzyme.Compiler.PrimalCompilerParams(Enzyme.API.CDerivativeMode(0))
funcspec = GPUCompiler.methodinstance(f, Tuple{}, world)
job = CompilerJob(funcspec, CompilerConfig(target, params; kernel=false), world)

otherMod, meta = GPUCompiler.codegen(:llvm, job; optimize=false, cleanup=false, validate=false, ctx)

fn = only((f for f in functions(otherMod) if !isempty(LLVM.blocks(f))))
bb = only(blocks(fn))

todel = LLVM.Instruction[]
for i in instructions(bb)
if isa(i, LLVM.CallInst) && isa(value_type(i), LLVM.VoidType)
push!(todel, i)
elseif isa(i, LLVM.FenceInst)
push!(todel, i)
end
end
for i in todel
unsafe_delete!(bb, i)
end

LLVM.ModulePassManager() do pm
dce!(pm)
run!(pm, otherMod)
end

gep = only((i for i in instructions(bb) if isa(i, LLVM.GetElementPtrInst)))
op = only(operands(gep)[2:end])
off = convert(Int, op)
task_offset[] = off
# See get_current_task_from_pgcstack (used from 1.7+)
if VERSION >= v"1.9.1"
current_task_offset() = -(unsafe_load(cglobal(:jl_task_gcstack_offset, Cint)) ÷ sizeof(Ptr{Cvoid}))
elseif VERSION >= v"1.9.0"
if Sys.WORD_SIZE == 64
current_task_offset() = -13
else
current_task_offset() = -18
end
else
if Sys.WORD_SIZE == 64
current_task_offset() = -12 #1.8/1.7
else
current_task_offset() = -17 #1.8/1.7
end
return task_offset[]
end

@static if VERSION < v"1.7.0"
# See get_current_ptls_from_task (used from 1.7+)
if VERSION >= v"1.9.1"
current_ptls_offset() = unsafe_load(cglobal(:jl_task_ptls_offset, Cint)) ÷ sizeof(Ptr{Cvoid})
elseif VERSION >= v"1.9.0"
if Sys.WORD_SIZE == 64
current_ptls_offset() = 15
else
current_ptls_offset() = 20
end
else
const ptls_offset=Ref{Int}(0)
function current_ptls_offset()
if ptls_offset[] == 0
@static if VERSION < v"1.8.0"
f = Core.Typeof(Base.Ref)
world = Base.get_world_counter()
ctx = JuliaContext()
target = Enzyme.Compiler.DefaultCompilerTarget()
params = Enzyme.Compiler.PrimalCompilerParams(Enzyme.API.CDerivativeMode(0))
funcspec = GPUCompiler.methodinstance(f, Tuple{Float64}, world)
job = CompilerJob(funcspec, CompilerConfig(target, params; kernel=false), world)
otherMod, meta = GPUCompiler.codegen(:llvm, job; optimize=false, cleanup=false, validate=false, ctx)

fn = only((f for f in functions(otherMod) if !isempty(LLVM.blocks(f))))
bb = only(blocks(fn))
gep = only((i for i in instructions(bb) if LLVM.name(i) == "ptls_field"))
op = only(operands(gep)[2:end])
off = convert(Int, op)
ptls_offset[] = off
else
task_off = current_task_offset()
mod = """
declare {}*** @julia.get_pgcstack()
declare noalias nonnull {} addrspace(10)* @julia.gc_alloc_obj({}**, i64, {} addrspace(10)*)

define void @gc_alloc_lowering() {
top:
%pgc = call {}*** @julia.get_pgcstack()
%t = bitcast {}*** %pgc to {}**
%current_task = getelementptr inbounds {}*, {}** %t, i64 $task_off
%v = call noalias {} addrspace(10)* @julia.gc_alloc_obj({}** %current_task, i64 8, {} addrspace(10)* undef)
ret void
}
"""

ctx = LLVM.Context()
otherMod = parse(LLVM.Module, mod; ctx)

LLVM.ModulePassManager() do pm
LLVM.Interop.late_lower_gc_frame!(pm)
run!(pm, otherMod)
end
fn = only((f for f in functions(otherMod) if !isempty(LLVM.blocks(f))))
bb = only(blocks(fn))
off = 0
for i in instructions(bb)
if LLVM.name(i) == "ptls_field"
op = only(operands(i)[2:end])
off_n = convert(Int, op)
if off != 0
@assert off == off_n
end
off = off_n
end
end
@assert off != 0
ptls_offset[] = off
end
if Sys.WORD_SIZE == 64
current_ptls_offset() = 14 # 1.8/1.7
else
current_ptls_offset() = 19
end
return ptls_offset[]
end
end

function get_julia_inner_types(B, p, startvals...; added=[])
Expand Down Expand Up @@ -6763,11 +6690,6 @@ function emit_inacterror(B, V, orig)
end

function __init__()
current_task_offset()
@static if VERSION < v"1.7.0"
else
current_ptls_offset()
end
API.EnzymeSetHandler(@cfunction(julia_error, Cvoid, (Cstring, LLVM.API.LLVMValueRef, API.ErrorType, Ptr{Cvoid}, LLVM.API.LLVMValueRef)))
API.EnzymeSetSanitizeDerivatives(@cfunction(julia_sanitize, LLVM.API.LLVMValueRef, (LLVM.API.LLVMValueRef, LLVM.API.LLVMValueRef, LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef)));
if API.EnzymeHasCustomInactiveSupport()
Expand Down