Overlay of Core.throw_inexacterror results in bad codegen of kwarg call #48097

Closed
maleadt opened this issue Jan 3, 2023 · 10 comments · Fixed by #48116 or #48988
Labels
compiler:codegen Generation of LLVM IR and native code gpu Affects running Julia on a GPU regression Regression in behavior compared to a previous version

Comments

@maleadt
Member

maleadt commented Jan 3, 2023

Backstory: CUDA.jl and other GPU back-ends use method overlays to replace GPU-incompatible functionality, such as several exception-throwing functions. While we generally do support exceptions, replacing them at the LLVM level with custom reporting functions, some are inherently GPU-incompatible due to string interpolation or untyped fields (resulting in a GC frame we cannot support). One such example is InexactError (from boot.jl):

struct InexactError <: Exception
    func::Symbol
    T  # Type
    val
    InexactError(f::Symbol, @nospecialize(T), @nospecialize(val)) = (@noinline; new(f, T, val))
end
throw_inexacterror(f::Symbol, ::Type{T}, val) where {T} = (@noinline; throw(InexactError(f, T, val)))

To support this, we rewrite throw_inexacterror:

@device_override @noinline Core.throw_inexacterror(f::Symbol, ::Type{T}, val) where {T} =
    @gpu_print "Inexact conversion"
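
For readers unfamiliar with overlays, the mechanism can be sketched roughly as follows. This is a minimal illustration, assuming that `@device_override` wraps `Base.Experimental.@overlay` in the way the reduced example in this post suggests; `DEMO_METHOD_TABLE` is a hypothetical name and not part of CUDA.jl.

```julia
using Base.Experimental: @MethodTable, @overlay

# Hypothetical table name, for illustration only.
@MethodTable(DEMO_METHOD_TABLE)

# This method lives only in DEMO_METHOD_TABLE. Only a compiler that consults the
# table (e.g. through Core.Compiler.OverlayMethodTable) can dispatch to it.
@overlay DEMO_METHOD_TABLE @noinline Core.throw_inexacterror(f::Symbol, ::Type{T}, val) where {T} =
    return

# Ordinary host dispatch does not see the overlay, so a failing conversion
# still throws as usual.
host_err = try
    Int8(1000)
catch err
    err
end
println(typeof(host_err))
```

The point of the sketch is only that the replacement is invisible to regular host code; whether the overlay is actually picked up depends on the interpreter's `method_table`.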

This used to work fine, but with specific kernels (that contain a kwarg function call where the kwargs are heterogeneously typed) we now get LLVM IR that performs dynamic function calls:

child(; kwargs...) = return
function parent()
    child(; a=1f0, b=1.0)
    return
end

@overlay method_table @noinline Core.throw_inexacterror(f::Symbol, ::Type{T}, val) where {T} =
    return

define void @good() #0 {
top:
  %0 = call {}*** @julia.get_pgcstack()
  %1 = bitcast {}*** %0 to {}**
  %current_task = getelementptr inbounds {}*, {}** %1, i64 -12
  %2 = bitcast {}** %current_task to i64*
  %world_age = getelementptr inbounds i64, i64* %2, i64 13
  ret void
}

define void @bad() #0 {
top:
  ; blah blah blah
  %43 = call nonnull {} addrspace(10)* ({} addrspace(10)* ({} addrspace(10)*, {} addrspace(10)**, i32)*, {} addrspace(10)*, ...) @julia.call({} addrspace(10)* ({} addrspace(10)*, {} addrspace(10)**, i32)* @jl_f_apply_type, {} addrspace(10)* null, {} addrspace(10)* %36, {} addrspace(10)* %38, {} addrspace(10)* %28, {} addrspace(10)* %40, {} addrspace(10)* %42)
  ret void
}
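
To make "heterogeneously typed kwargs" concrete, here is a small host-side sketch. `capture` is a hypothetical helper; the only assumption is the documented behavior that `values(kwargs)` returns the NamedTuple backing the keyword arguments.

```julia
# Hypothetical helper that returns the NamedTuple backing the keyword arguments.
capture(; kwargs...) = values(kwargs)

nt = capture(; a=1f0, b=1.0)

# The field types differ (Float32 vs. Float64), so the value tuple is heterogeneous.
println(typeof(nt))

# The element type that has to be computed for the kwargs container;
# the exact result may differ between Julia versions.
println(eltype(typeof(nt)))
```

With homogeneous values (say `a=1.0, b=2.0`) the element type is trivially the single field type, which may be why only the mixed case triggers the problem.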

I bisected this to #44224, which is a bit weird, since that PR was back-ported to 1.8 too. It's possible that there's some interplay with other functionality only on 1.9, but that's unlikely as the PR was merged very shortly after 1.9 branched.

Anyway, the full MWE (reduced from GPUCompiler.jl):

# pieces lifted from GPUCompiler.jl

using LLVM

# helper type to deal with the transition from Context to ThreadSafeContext
export JuliaContext
if VERSION >= v"1.9.0-DEV.516"
    const JuliaContextType = ThreadSafeContext
else
    const JuliaContextType = Context
end
function JuliaContext()
    if VERSION >= v"1.9.0-DEV.115"
        JuliaContextType()
    else
        isboxed_ref = Ref{Bool}()
        typ = LLVMType(ccall(:jl_type_to_llvm, LLVM.API.LLVMTypeRef,
                       (Any, Ptr{Bool}), Any, isboxed_ref))
        context(typ)
    end
end
function JuliaContext(f)
    if VERSION >= v"1.9.0-DEV.115"
        JuliaContextType(f)
    else
        f(JuliaContext())
    end
end
if VERSION >= v"1.9.0-DEV.516"
unwrap_context(ctx::ThreadSafeContext) = context(ctx)
end
unwrap_context(ctx::Context) = ctx

# CI cache
using Core.Compiler: CodeInstance, MethodInstance, InferenceParams, OptimizationParams
struct CodeCache
    dict::IdDict{MethodInstance,Vector{CodeInstance}}
    CodeCache() = new(Dict{MethodInstance,Vector{CodeInstance}}())
end
function Core.Compiler.setindex!(cache::CodeCache, ci::CodeInstance, mi::MethodInstance)
    cis = get!(cache.dict, mi, CodeInstance[])
    push!(cis, ci)
end
const GLOBAL_CI_CACHE = CodeCache()

using Core.Compiler: WorldView
function Core.Compiler.haskey(wvc::WorldView{CodeCache}, mi::MethodInstance)
    Core.Compiler.get(wvc, mi, nothing) !== nothing
end
function Core.Compiler.get(wvc::WorldView{CodeCache}, mi::MethodInstance, default)
    for ci in get!(wvc.cache.dict, mi, CodeInstance[])
        if ci.min_world <= wvc.worlds.min_world && wvc.worlds.max_world <= ci.max_world
            src = if ci.inferred isa Vector{UInt8}
                ccall(:jl_uncompress_ir, Any, (Any, Ptr{Cvoid}, Any),
                       mi.def, C_NULL, ci.inferred)
            else
                ci.inferred
            end
            return ci
        end
    end
    return default
end
function Core.Compiler.getindex(wvc::WorldView{CodeCache}, mi::MethodInstance)
    r = Core.Compiler.get(wvc, mi, nothing)
    r === nothing && throw(KeyError(mi))
    return r::CodeInstance
end
function Core.Compiler.setindex!(wvc::WorldView{CodeCache}, ci::CodeInstance, mi::MethodInstance)
    src = if ci.inferred isa Vector{UInt8}
        ccall(:jl_uncompress_ir, Any, (Any, Ptr{Cvoid}, Any),
                mi.def, C_NULL, ci.inferred)
    else
        ci.inferred
    end
    Core.Compiler.setindex!(wvc.cache, ci, mi)
end
function ci_cache_populate(interp, mi, min_world, max_world)
    src = Core.Compiler.typeinf_ext_toplevel(interp, mi)
    wvc = WorldView(GLOBAL_CI_CACHE, min_world, max_world)
    @assert Core.Compiler.haskey(wvc, mi)
    ci = Core.Compiler.getindex(wvc, mi)
    if ci !== nothing && ci.inferred === nothing
        @static if VERSION >= v"1.9.0-DEV.1115"
            @atomic ci.inferred = src
        else
            ci.inferred = src
        end
    end
    return ci
end
function ci_cache_lookup(mi, min_world, max_world)
    wvc = WorldView(GLOBAL_CI_CACHE, min_world, max_world)
    ci = Core.Compiler.get(wvc, mi, nothing)
    if ci !== nothing && ci.inferred === nothing
        return nothing
    end
    return ci
end

# interpreter
using Core.Compiler: AbstractInterpreter, InferenceResult, InferenceParams, InferenceState, OptimizationParams
struct GPUInterpreter <: AbstractInterpreter
    local_cache::Vector{InferenceResult}
    GPUInterpreter() = new(Vector{InferenceResult}())
end
Core.Compiler.InferenceParams(interp::GPUInterpreter) = InferenceParams()
Core.Compiler.OptimizationParams(interp::GPUInterpreter) = OptimizationParams()
Core.Compiler.get_world_counter(interp::GPUInterpreter) = Base.get_world_counter()
Core.Compiler.get_inference_cache(interp::GPUInterpreter) = interp.local_cache
Core.Compiler.code_cache(interp::GPUInterpreter) = WorldView(GLOBAL_CI_CACHE, Base.get_world_counter())

using Base.Experimental: @overlay, @MethodTable
@MethodTable(GLOBAL_METHOD_TABLE)

using Core.Compiler: OverlayMethodTable
if v"1.8-beta2" <= VERSION < v"1.9-" || VERSION >= v"1.9.0-DEV.120"
Core.Compiler.method_table(interp::GPUInterpreter) =
    OverlayMethodTable(Base.get_world_counter(), GLOBAL_METHOD_TABLE)
else
Core.Compiler.method_table(interp::GPUInterpreter, sv::InferenceState) =
    OverlayMethodTable(Base.get_world_counter(), GLOBAL_METHOD_TABLE)
end

# disable ir interpretation due to issues with overlay tables
@static if VERSION >= v"1.9.0-DEV.1248"
function Core.Compiler.concrete_eval_eligible(interp::GPUInterpreter,
    @nospecialize(f), result::Core.Compiler.MethodCallResult, arginfo::Core.Compiler.ArgInfo)
    ret = @invoke Core.Compiler.concrete_eval_eligible(interp::AbstractInterpreter,
        f::Any, result::Core.Compiler.MethodCallResult, arginfo::Core.Compiler.ArgInfo)
    ret === false && return nothing
    return ret
end
end

function codegen(sig)
    mi, _ = emit_julia(sig)

    JuliaContext() do ctx
        ir, _ = emit_llvm(mi; ctx)
        strip_debuginfo!(ir)
        string(ir)
    end
end

function emit_julia(sig)
    meth = which(sig)
    (ti, env) = ccall(:jl_type_intersection_with_env, Any,
                        (Any, Any), sig, meth.sig)::Core.SimpleVector
    meth = Base.func_for_method_checked(meth, ti, env)
    method_instance = ccall(:jl_specializations_get_linfo, Ref{Core.MethodInstance},
                            (Any, Any, Any, UInt), meth, ti, env, Base.get_world_counter())
    return method_instance, ()
end

function emit_llvm(@nospecialize(method_instance); ctx::JuliaContextType)
    InitializeAllTargets()
    InitializeAllTargetInfos()
    InitializeAllAsmPrinters()
    InitializeAllAsmParsers()
    InitializeAllTargetMCs()

    ir, compiled = irgen(method_instance; ctx)
    entry_fn = compiled[method_instance].specfunc
    entry = functions(ir)[entry_fn]
    return ir, (; entry, compiled)
end

function irgen(method_instance::Core.MethodInstance;
               ctx::JuliaContextType)
    mod, compiled = compile_method_instance(method_instance; ctx)
    entry_fn = compiled[method_instance].specfunc
    entry = functions(mod)[entry_fn]
    return mod, compiled
end

function compile_method_instance(method_instance::MethodInstance; ctx::JuliaContextType)
    world = Base.get_world_counter()
    interp =  GPUInterpreter()
    if ci_cache_lookup(method_instance, world, typemax(Cint)) === nothing
        ci_cache_populate(interp, method_instance, world, typemax(Cint))
    end

    method_instances = []
    function lookup_fun(mi, min_world, max_world)
        push!(method_instances, mi)
        ci_cache_lookup(mi, min_world, max_world)
    end
    lookup_cb = @cfunction($lookup_fun, Any, (Any, UInt, UInt))
    params = Base.CodegenParams(; lookup = Base.unsafe_convert(Ptr{Nothing}, lookup_cb))
    GC.@preserve lookup_cb begin
        native_code = if VERSION >= v"1.9.0-DEV.516"
            mod = LLVM.Module("start"; ctx=unwrap_context(ctx))
            ts_mod = ThreadSafeModule(mod; ctx)
            ccall(:jl_create_native, Ptr{Cvoid},
                  (Vector{MethodInstance}, LLVM.API.LLVMOrcThreadSafeModuleRef, Ptr{Base.CodegenParams}, Cint),
                  [method_instance], ts_mod, Ref(params),  1)
        elseif VERSION >= v"1.9.0-DEV.115"
            ccall(:jl_create_native, Ptr{Cvoid},
                  (Vector{MethodInstance}, LLVM.API.LLVMContextRef, Ptr{Base.CodegenParams}, Cint),
                  [method_instance], ctx, Ref(params),  1)
        elseif VERSION >= v"1.8.0-DEV.661"
            @assert ctx == JuliaContext()
            ccall(:jl_create_native, Ptr{Cvoid},
                  (Vector{MethodInstance}, Ptr{Base.CodegenParams}, Cint),
                  [method_instance], Ref(params),  1)
        else
            @assert ctx == JuliaContext()
            ccall(:jl_create_native, Ptr{Cvoid},
                  (Vector{MethodInstance}, Base.CodegenParams, Cint),
                  [method_instance], params,  1)
        end
        @assert native_code != C_NULL
        llvm_mod_ref = if VERSION >= v"1.9.0-DEV.516"
            ccall(:jl_get_llvm_module, LLVM.API.LLVMOrcThreadSafeModuleRef,
                  (Ptr{Cvoid},), native_code)
        else
            ccall(:jl_get_llvm_module, LLVM.API.LLVMModuleRef,
                  (Ptr{Cvoid},), native_code)
        end
        @assert llvm_mod_ref != C_NULL
        if VERSION >= v"1.9.0-DEV.516"
            llvm_ts_mod = LLVM.ThreadSafeModule(llvm_mod_ref)
            llvm_mod = nothing
            llvm_ts_mod() do mod
                llvm_mod = mod
            end
        else
            llvm_mod = LLVM.Module(llvm_mod_ref)
        end
    end

    compiled = Dict()
    for mi in method_instances
        ci = ci_cache_lookup(mi, world, typemax(Cint))
        if ci !== nothing
            llvm_func_idx = Ref{Int32}(-1)
            llvm_specfunc_idx = Ref{Int32}(-1)
            ccall(:jl_get_function_id, Nothing,
                (Ptr{Cvoid}, Any, Ptr{Int32}, Ptr{Int32}),
                native_code, ci, llvm_func_idx, llvm_specfunc_idx)
            llvm_func = if llvm_func_idx[] != -1
                llvm_func_ref = ccall(:jl_get_llvm_function, LLVM.API.LLVMValueRef,
                                      (Ptr{Cvoid}, UInt32), native_code, llvm_func_idx[]-1)
                @assert llvm_func_ref != C_NULL
                LLVM.name(LLVM.Function(llvm_func_ref))
            else
                nothing
            end
            llvm_specfunc = if llvm_specfunc_idx[] != -1
                llvm_specfunc_ref = ccall(:jl_get_llvm_function, LLVM.API.LLVMValueRef,
                                        (Ptr{Cvoid}, UInt32), native_code, llvm_specfunc_idx[]-1)
                @assert llvm_specfunc_ref != C_NULL
                LLVM.name(LLVM.Function(llvm_specfunc_ref))
            else
                nothing
            end
            compiled[mi] = (; ci, func=llvm_func, specfunc=llvm_specfunc)
        end
    end

    return llvm_mod, compiled
end

############################################################################################

# compiler invocation

child(; kwargs...) = return
function parent()
    child(; a=1f0, b=1.0)
    return
end

# this override introduces a `jl_invoke`
@overlay GLOBAL_METHOD_TABLE @noinline Core.throw_inexacterror(f::Symbol, ::Type{T}, val) where {T} =
    return

println(codegen(Tuple{typeof(parent)}))

cc @aviatesk, you previously looked at GPU-related overlay issues

Putting this on the milestone as this breaks parts of CUDA.jl, but again, I'm also happy to adapt CUDA.jl (although dropping the overlay of throw_inexacterror will result in other code breaking, so that's not a great solution either).

@maleadt maleadt added regression Regression in behavior compared to a previous version compiler:codegen Generation of LLVM IR and native code gpu Affects running Julia on a GPU labels Jan 3, 2023
@maleadt maleadt added this to the 1.9 milestone Jan 3, 2023
@maleadt
Member Author

maleadt commented Jan 3, 2023

> I bisected this to #44224, which is a bit weird, since that PR was back-ported to 1.8 too. It's possible that there's some interplay with other functionality only on 1.9, but that's unlikely as the PR was merged very shortly after 1.9 branched.

Bisecting around on release-1.8, it looks like the backport of #44224 (5f0d551) introduced this bug there too; however, the backport of #44515 (e1e02f6) fixed it again. The same applies to master, so it probably regressed again some time later. I'll bisect again tomorrow.

@aviatesk
Member

aviatesk commented Jan 4, 2023

I guess another factor is #47157. The PR changes the internal representation of a call to kwfunc significantly, and indeed, it looks like this issue stems from eltype that is involved with the keyword argument handling:

on 1.8

julia> @code_typed interp=GPUInterpreter() parent()
CodeInfo(
1 ─     return nothing
) => Nothing

on 1.9 and master

julia> @code_typed interp=GPUInterpreter() parent()
CodeInfo(
1 ─ %1 = invoke Base._tuple_unique_fieldtypes(Tuple{Float32, Float64}::Any)::Core.SimpleVector
│   %2 = Core._apply_iterate(Base.iterate, Base.afoldl, (Base.var"#51#52"(),), %1)::Any
│        Core.apply_type(Base.Pairs, Symbol, %2, Tuple{Symbol, Symbol}, NamedTuple{(:a, :b), Tuple{Float32, Float64}})::Type{Base.Pairs{Symbol, _A, Tuple{Symbol, Symbol}, NamedTuple{(:a, :b), Tuple{Float32, Float64}}}} where _A
└──      return nothing
) => Nothing
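
For comparison, the same query can be made against the native interpreter without any of the MWE's infrastructure. This is only a sketch: `kw_child` and `kw_parent` are hypothetical names mirroring `child` and `parent`, and the printed IR is not shown because it varies between Julia versions.

```julia
# Hypothetical names mirroring `child` and `parent` from the MWE.
kw_child(; kwargs...) = return
function kw_parent()
    kw_child(; a=1f0, b=1.0)
    return
end

# code_typed returns one `CodeInfo => return type` pair per matching method.
ci, rt = only(code_typed(kw_parent, ()))
println(ci)
```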

@vtjnash
Member

vtjnash commented Jan 4, 2023

That _tuple_unique_fieldtypes function is marked @_total_meta, so why is it no longer evaluated for you in v1.9?

@aviatesk
Member

aviatesk commented Jan 4, 2023

julia> (Base.infer_effects(; interp=GPUInterpreter()) do
           Base._tuple_unique_fieldtypes(Tuple{Float32,Float64})
       end).nonoverlayed
false

probably because of dynamic dispatches.

@aviatesk
Member

aviatesk commented Jan 4, 2023

I tried to improve the effects with the following diff, but it looks like nonoverlayed is necessarily tainted after all, because Core.throw_inexacterror is deeply involved with integer conversion and thus is also involved with keyword arg handling.

diff --git a/base/essentials.jl b/base/essentials.jl
index 2093c792dd..61e9b1b258 100644
--- a/base/essentials.jl
+++ b/base/essentials.jl
@@ -386,8 +386,10 @@ function isvatuple(@nospecialize(t))
     return false
 end
 
-unwrapva(t::Core.TypeofVararg) = isdefined(t, :T) ? t.T : Any
-unwrapva(@nospecialize(t)) = t
+function unwrapva(@nospecialize(t))
+    isa(t, Core.TypeofVararg) || return t
+    return isdefined(t, :T) ? t.T : Any
+end
 
 function unconstrain_vararg_length(va::Core.TypeofVararg)
     # construct a new Vararg type where its length is unconstrained,
diff --git a/base/tuple.jl b/base/tuple.jl
index d59c923921..c1d96fd2ca 100644
--- a/base/tuple.jl
+++ b/base/tuple.jl
@@ -208,15 +208,15 @@ eltype(t::Type{<:Tuple}) = _compute_eltype(t)
 function _tuple_unique_fieldtypes(@nospecialize t)
     @_total_meta
     types = IdSet()
-    t´ = unwrap_unionall(t)
     # Given t = Tuple{Vararg{S}} where S<:Real, the various
     # unwrapping/wrapping/va-handling here will return Real
     if t isa Union
-        union!(types, _tuple_unique_fieldtypes(rewrap_unionall(t´.a, t)))
-        union!(types, _tuple_unique_fieldtypes(rewrap_unionall(t´.b, t)))
+        union!(types, _tuple_unique_fieldtypes(rewrap_unionall(t.a, t)))
+        union!(types, _tuple_unique_fieldtypes(rewrap_unionall(t.b, t)))
     else
         r = Union{}
-        for ti in (t´::DataType).parameters
+        t´ = unwrap_unionall(t)::DataType
+        for ti in t´.parameters
             r = push!(types, rewrap_unionall(unwrapva(ti), t))
         end
     end
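
To illustrate why Core.throw_inexacterror is hard to keep out of the picture, the following host-side sketch shows that an ordinary checked integer conversion can throw InexactError, and prints the effects the native interpreter infers for it. It assumes Julia 1.9 or later for `Base.infer_effects`; no overlay table is involved here.

```julia
# A checked narrowing conversion that fails ends up throwing InexactError.
err = try
    Int8(1000)
catch e
    e
end
println(typeof(err))

# Effects inferred by the native interpreter for Int8(::Int).
# Printed without an expected value because the representation varies by version.
effects = Base.infer_effects(Int8, (Int,))
println(effects)
```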

@aviatesk
Member

aviatesk commented Jan 4, 2023

@maleadt Is it okay to call the original non-overlayed version of Core.throw_inexacterror during GPU-code compilation? Currently, if a piece of code involves Core.throw_inexacterror, it is immediately ineligible for concrete evaluation, which can lose some of the inference accuracy that we usually have for native compilation.

If we're fine with seeing the overlayed behavior of Core.throw_inexacterror during the actual runtime only, then you can do something like:

# Essentially, we want to see the overlayed behavior of these functions at actual runtime only,
# and we're fine with seeing their original non-overlayed behavior during compilation when they
# are concrete-evaluated.
const RUNTIME_ONLY_OVERLAYS = Any[
    Base._tuple_unique_fieldtypes # necessary for accurate keyword func handling
]

function Core.Compiler.abstract_call_known(interp::GPUInterpreter, @nospecialize(f),
    arginfo::Core.Compiler.ArgInfo, si::Core.Compiler.StmtInfo, sv::Core.Compiler.InferenceState,
    max_methods::Int)
    ret = @invoke Core.Compiler.abstract_call_known(interp::Core.Compiler.AbstractInterpreter, f::Any,
        arginfo::Core.Compiler.ArgInfo, si::Core.Compiler.StmtInfo, sv::Core.Compiler.InferenceState,
        max_methods::Int)
    if f in RUNTIME_ONLY_OVERLAYS
        new_effects = Core.Compiler.Effects(ret.effects; nonoverlayed=true)
        ret = Core.Compiler.CallMeta(ret.rt, new_effects, ret.info)
    end
    return ret
end

@aviatesk
Member

aviatesk commented Jan 4, 2023

Well, on further thought, how does GPUCompiler call the overlayed Core.throw_inexacterror when it is not inlined? Does it use some mechanism like CassetteOverlay?

@maleadt
Member Author

maleadt commented Jan 4, 2023

> Well, on further thought, how does GPUCompiler call the overlayed Core.throw_inexacterror when it is not inlined? Does it use some mechanism like CassetteOverlay?

GPUCompiler.jl uses jl_create_native, which emits all functions of the call chain into a single module.

@maleadt
Member Author

maleadt commented Mar 13, 2023

This has surfaced again, ref JuliaGPU/GPUCompiler.jl#399

Updated MWE:

# pieces lifted from GPUCompiler.jl

using LLVM

# helper type to deal with the transition from Context to ThreadSafeContext
export JuliaContext
if VERSION >= v"1.9.0-DEV.516"
    const JuliaContextType = ThreadSafeContext
else
    const JuliaContextType = Context
end
function JuliaContext()
    if VERSION >= v"1.9.0-DEV.115"
        JuliaContextType()
    else
        isboxed_ref = Ref{Bool}()
        typ = LLVMType(ccall(:jl_type_to_llvm, LLVM.API.LLVMTypeRef,
                       (Any, Ptr{Bool}), Any, isboxed_ref))
        context(typ)
    end
end
function JuliaContext(f)
    if VERSION >= v"1.9.0-DEV.115"
        JuliaContextType(f)
    else
        f(JuliaContext())
    end
end
if VERSION >= v"1.9.0-DEV.516"
unwrap_context(ctx::ThreadSafeContext) = context(ctx)
end
unwrap_context(ctx::Context) = ctx

# CI cache
using Core.Compiler: CodeInstance, MethodInstance, InferenceParams, OptimizationParams
struct CodeCache
    dict::IdDict{MethodInstance,Vector{CodeInstance}}
    CodeCache() = new(Dict{MethodInstance,Vector{CodeInstance}}())
end
function Core.Compiler.setindex!(cache::CodeCache, ci::CodeInstance, mi::MethodInstance)
    cis = get!(cache.dict, mi, CodeInstance[])
    push!(cis, ci)
end
const GLOBAL_CI_CACHE = CodeCache()

using Core.Compiler: WorldView
function Core.Compiler.haskey(wvc::WorldView{CodeCache}, mi::MethodInstance)
    Core.Compiler.get(wvc, mi, nothing) !== nothing
end
function Core.Compiler.get(wvc::WorldView{CodeCache}, mi::MethodInstance, default)
    for ci in get!(wvc.cache.dict, mi, CodeInstance[])
        if ci.min_world <= wvc.worlds.min_world && wvc.worlds.max_world <= ci.max_world
            src = if ci.inferred isa Vector{UInt8}
                ccall(:jl_uncompress_ir, Any, (Any, Ptr{Cvoid}, Any),
                       mi.def, C_NULL, ci.inferred)
            else
                ci.inferred
            end
            return ci
        end
    end
    return default
end
function Core.Compiler.getindex(wvc::WorldView{CodeCache}, mi::MethodInstance)
    r = Core.Compiler.get(wvc, mi, nothing)
    r === nothing && throw(KeyError(mi))
    return r::CodeInstance
end
function Core.Compiler.setindex!(wvc::WorldView{CodeCache}, ci::CodeInstance, mi::MethodInstance)
    src = if ci.inferred isa Vector{UInt8}
        ccall(:jl_uncompress_ir, Any, (Any, Ptr{Cvoid}, Any),
                mi.def, C_NULL, ci.inferred)
    else
        ci.inferred
    end
    Core.Compiler.setindex!(wvc.cache, ci, mi)
end
function ci_cache_populate(interp, mi, min_world, max_world)
    src = Core.Compiler.typeinf_ext_toplevel(interp, mi)
    wvc = WorldView(GLOBAL_CI_CACHE, min_world, max_world)
    @assert Core.Compiler.haskey(wvc, mi)
    ci = Core.Compiler.getindex(wvc, mi)
    if ci !== nothing && ci.inferred === nothing
        @static if VERSION >= v"1.9.0-DEV.1115"
            @atomic ci.inferred = src
        else
            ci.inferred = src
        end
    end
    return ci
end
function ci_cache_lookup(mi, min_world, max_world)
    wvc = WorldView(GLOBAL_CI_CACHE, min_world, max_world)
    ci = Core.Compiler.get(wvc, mi, nothing)
    if ci !== nothing && ci.inferred === nothing
        return nothing
    end
    return ci
end

# interpreter
using Core.Compiler: AbstractInterpreter, InferenceResult, InferenceParams, InferenceState, OptimizationParams
struct GPUInterpreter <: AbstractInterpreter
    local_cache::Vector{InferenceResult}
    GPUInterpreter() = new(Vector{InferenceResult}())
end
Core.Compiler.InferenceParams(interp::GPUInterpreter) = InferenceParams()
Core.Compiler.OptimizationParams(interp::GPUInterpreter) = OptimizationParams()
Core.Compiler.get_world_counter(interp::GPUInterpreter) = Base.get_world_counter()
Core.Compiler.get_inference_cache(interp::GPUInterpreter) = interp.local_cache
Core.Compiler.code_cache(interp::GPUInterpreter) = WorldView(GLOBAL_CI_CACHE, Base.get_world_counter())

using Base.Experimental: @overlay, @MethodTable
@MethodTable(GLOBAL_METHOD_TABLE)

using Core.Compiler: OverlayMethodTable
if v"1.8-beta2" <= VERSION < v"1.9-" || VERSION >= v"1.9.0-DEV.120"
Core.Compiler.method_table(interp::GPUInterpreter) =
    OverlayMethodTable(Base.get_world_counter(), GLOBAL_METHOD_TABLE)
else
Core.Compiler.method_table(interp::GPUInterpreter, sv::InferenceState) =
    OverlayMethodTable(Base.get_world_counter(), GLOBAL_METHOD_TABLE)
end

# disable ir interpretation due to issues with overlay tables
@static if VERSION >= v"1.9.0-DEV.1248"
function Core.Compiler.concrete_eval_eligible(interp::GPUInterpreter,
    @nospecialize(f), result::Core.Compiler.MethodCallResult, arginfo::Core.Compiler.ArgInfo)
    ret = @invoke Core.Compiler.concrete_eval_eligible(interp::AbstractInterpreter,
        f::Any, result::Core.Compiler.MethodCallResult, arginfo::Core.Compiler.ArgInfo)
    ret === false && return nothing
    return ret
end
end

function codegen(sig)
    mi, _ = emit_julia(sig)

    JuliaContext() do ctx
        ir, _ = emit_llvm(mi; ctx)
        strip_debuginfo!(ir)
        string(ir)
    end
end

function emit_julia(sig)
    meth = which(sig)
    (ti, env) = ccall(:jl_type_intersection_with_env, Any,
                        (Any, Any), sig, meth.sig)::Core.SimpleVector
    meth = Base.func_for_method_checked(meth, ti, env)
    method_instance = ccall(:jl_specializations_get_linfo, Ref{Core.MethodInstance},
                            (Any, Any, Any, UInt), meth, ti, env, Base.get_world_counter())
    return method_instance, ()
end

function emit_llvm(@nospecialize(method_instance); ctx::JuliaContextType)
    InitializeAllTargets()
    InitializeAllTargetInfos()
    InitializeAllAsmPrinters()
    InitializeAllAsmParsers()
    InitializeAllTargetMCs()

    ir, compiled = irgen(method_instance; ctx)
    entry_fn = compiled[method_instance].specfunc
    entry = functions(ir)[entry_fn]
    return ir, (; entry, compiled)
end

function irgen(method_instance::Core.MethodInstance;
               ctx::JuliaContextType)
    mod, compiled = compile_method_instance(method_instance; ctx)
    entry_fn = compiled[method_instance].specfunc
    entry = functions(mod)[entry_fn]
    return mod, compiled
end

const _method_instances = Ref{Any}()
function _lookup_fun(mi, min_world, max_world)
    push!(_method_instances[], mi)
    ci_cache_lookup(mi, min_world, max_world)
end

function compile_method_instance(method_instance::MethodInstance; ctx::JuliaContextType)
    world = Base.get_world_counter()
    interp =  GPUInterpreter()
    if ci_cache_lookup(method_instance, world, typemax(Cint)) === nothing
        ci_cache_populate(interp, method_instance, world, typemax(Cint))
    end

    method_instances = []
    _method_instances[] = method_instances
    lookup_cb = @cfunction(_lookup_fun, Any, (Any, UInt, UInt))
    params = Base.CodegenParams(; lookup = Base.unsafe_convert(Ptr{Nothing}, lookup_cb))
    GC.@preserve lookup_cb begin
        native_code = if VERSION >= v"1.9.0-DEV.516"
            mod = LLVM.Module("start"; ctx=unwrap_context(ctx))
            ts_mod = ThreadSafeModule(mod; ctx)
            if VERSION >= v"1.9.0-beta4.23"
                ccall(:jl_create_native, Ptr{Cvoid},
                  (Vector{MethodInstance}, LLVM.API.LLVMOrcThreadSafeModuleRef, Ptr{Base.CodegenParams},
                   Cint, Cint, Cint, Csize_t),
                  [method_instance], ts_mod, Ref(params),
                  #=extern policy=# 1, #=imaging mode=# 0,  #=external linkage=# 0,
                  Base.get_world_counter())
            else
                ccall(:jl_create_native, Ptr{Cvoid},
                    (Vector{MethodInstance}, LLVM.API.LLVMOrcThreadSafeModuleRef, Ptr{Base.CodegenParams}, Cint),
                    [method_instance], ts_mod, Ref(params), #=extern policy=# 1)
            end
        elseif VERSION >= v"1.9.0-DEV.115"
            ccall(:jl_create_native, Ptr{Cvoid},
                  (Vector{MethodInstance}, LLVM.API.LLVMContextRef, Ptr{Base.CodegenParams}, Cint),
                  [method_instance], ctx, Ref(params),  1)
        elseif VERSION >= v"1.8.0-DEV.661"
            @assert ctx == JuliaContext()
            ccall(:jl_create_native, Ptr{Cvoid},
                  (Vector{MethodInstance}, Ptr{Base.CodegenParams}, Cint),
                  [method_instance], Ref(params),  1)
        else
            @assert ctx == JuliaContext()
            ccall(:jl_create_native, Ptr{Cvoid},
                  (Vector{MethodInstance}, Base.CodegenParams, Cint),
                  [method_instance], params,  1)
        end
        @assert native_code != C_NULL
        llvm_mod_ref = if VERSION >= v"1.9.0-DEV.516"
            ccall(:jl_get_llvm_module, LLVM.API.LLVMOrcThreadSafeModuleRef,
                  (Ptr{Cvoid},), native_code)
        else
            ccall(:jl_get_llvm_module, LLVM.API.LLVMModuleRef,
                  (Ptr{Cvoid},), native_code)
        end
        @assert llvm_mod_ref != C_NULL
        if VERSION >= v"1.9.0-DEV.516"
            llvm_ts_mod = LLVM.ThreadSafeModule(llvm_mod_ref)
            llvm_mod = nothing
            llvm_ts_mod() do mod
                llvm_mod = mod
            end
        else
            llvm_mod = LLVM.Module(llvm_mod_ref)
        end
    end

    compiled = Dict()
    for mi in method_instances
        ci = ci_cache_lookup(mi, world, typemax(Cint))
        if ci !== nothing
            llvm_func_idx = Ref{Int32}(-1)
            llvm_specfunc_idx = Ref{Int32}(-1)
            ccall(:jl_get_function_id, Nothing,
                (Ptr{Cvoid}, Any, Ptr{Int32}, Ptr{Int32}),
                native_code, ci, llvm_func_idx, llvm_specfunc_idx)
            llvm_func = if llvm_func_idx[] != -1
                llvm_func_ref = ccall(:jl_get_llvm_function, LLVM.API.LLVMValueRef,
                                      (Ptr{Cvoid}, UInt32), native_code, llvm_func_idx[]-1)
                @assert llvm_func_ref != C_NULL
                LLVM.name(LLVM.Function(llvm_func_ref))
            else
                nothing
            end
            llvm_specfunc = if llvm_specfunc_idx[] != -1
                llvm_specfunc_ref = ccall(:jl_get_llvm_function, LLVM.API.LLVMValueRef,
                                        (Ptr{Cvoid}, UInt32), native_code, llvm_specfunc_idx[]-1)
                @assert llvm_specfunc_ref != C_NULL
                LLVM.name(LLVM.Function(llvm_specfunc_ref))
            else
                nothing
            end
            compiled[mi] = (; ci, func=llvm_func, specfunc=llvm_specfunc)
        end
    end

    return llvm_mod, compiled
end

############################################################################################

# compiler invocation

child(; kwargs...) = return
function parent()
    child(; a=1f0, b=1.0)
    return
end

# this override introduces a `jl_invoke`
@overlay GLOBAL_METHOD_TABLE @noinline Core.throw_inexacterror(f::Symbol, ::Type{T}, val) where {T} =
    return

println(codegen(Tuple{typeof(parent)}))

Bisected to #48838; @aviatesk is taking a look. In the meantime, #48838 has been reverted on the backports branch.

@maleadt maleadt reopened this Mar 13, 2023
@maleadt maleadt removed this from the 1.9 milestone Mar 13, 2023
aviatesk added a commit that referenced this issue Mar 13, 2023
This test case is particularly crafted for #48097, which has occurred
a couple of times.
This should help GPUCompiler keep working with the nightly.

closes #48097
@maleadt
Member Author

maleadt commented Mar 13, 2023

This doesn't reproduce on master. I bisected the 'fix' there to #48116, which was the original workaround to this issue, implying that #48838 only made release-1.9 regress and not master. After checking with @aviatesk, it should be fine to just keep #48838 reverted for 1.9.

@maleadt maleadt closed this as completed Mar 13, 2023
aviatesk added a commit that referenced this issue Mar 13, 2023
This test case is particularly crafted for #48097, which has occurred
a couple of times. Although the issue doesn't reproduce on master,
having it here should help GPUCompiler keep working with the nightly.

closes #48097
aviatesk added a commit that referenced this issue Mar 14, 2023
This test case is particularly crafted for #48097, which has occurred
a couple of times. Although the issue doesn't reproduce on master,
having it here should help GPUCompiler keep working with the nightly.

closes #48097
maleadt pushed a commit that referenced this issue Mar 14, 2023
This test case is particularly crafted for #48097, which has occurred
a couple of times. Although the issue doesn't reproduce on master,
having it here should help GPUCompiler keep working with the nightly.

closes #48097
aviatesk added a commit that referenced this issue Apr 26, 2023
When we skip inference of a call on `throw` block, previously we only
checked if the call is overlayed or not. But the call may call an
overlayed method internally, thus we need to conservatively taint the
`:nonoverlayed` bit when `interp` uses overlay method table.

Nevertheless this will not introduce any regressions on GPUCompiler
stack (like #48097), since it defines `InferenceParams(::GPUInterpreter)`
overload to turn off `unoptimize_throw_blocks`
<https://github.com/JuliaGPU/GPUCompiler.jl/blob/d5086fb3d93bbc4795a96f6f1457898af46a24cb/src/interface.jl#L272>
aviatesk added a commit that referenced this issue Apr 27, 2023
…ll (#49518)

When we skip inference of a call on `throw` block, previously we only
checked if the call is overlayed or not. But the call may call an
overlayed method internally, thus we need to conservatively taint the
`:nonoverlayed` bit when `interp` uses overlay method table.

Nevertheless this will not introduce any regressions on GPUCompiler
stack (like #48097), since it defines `InferenceParams(::GPUInterpreter)`
overload to turn off `unoptimize_throw_blocks`
<https://github.com/JuliaGPU/GPUCompiler.jl/blob/d5086fb3d93bbc4795a96f6f1457898af46a24cb/src/interface.jl#L272>