From 0296364d87c8b954994826204a6c4629f04c7ecd Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Wed, 2 Oct 2024 11:28:07 +0200 Subject: [PATCH] Mock Enzyme plugin --- test/plugin_testsetup.jl | 142 ++++++++++++++++++++++++++++++++++++++- test/ptx_tests.jl | 21 ++++++ 2 files changed, 161 insertions(+), 2 deletions(-) diff --git a/test/plugin_testsetup.jl b/test/plugin_testsetup.jl index 35a4abdc..9005479f 100644 --- a/test/plugin_testsetup.jl +++ b/test/plugin_testsetup.jl @@ -36,7 +36,7 @@ struct NeverInlineMeta <: InlineStateMeta end import GPUCompiler: abstract_call_known, GPUInterpreter import Core.Compiler: CallMeta, Effects, NoCallInfo, ArgInfo, StmtInfo, AbsIntState, EFFECTS_TOTAL, - MethodResultPure + MethodResultPure, CallInfo, IRCode function abstract_call_known(meta::InlineStateMeta, interp::GPUInterpreter, @nospecialize(f), arginfo::ArgInfo, si::StmtInfo, sv::AbsIntState, max_methods::Int) @@ -69,5 +69,143 @@ function inlining_handler(meta::InlineStateMeta, interp::GPUInterpreter, @nospec return nothing end +struct MockEnzymeMeta end -end \ No newline at end of file +# Having to define this function is annoying +# introduce `abstract type InferenceMeta` +function inlining_handler(meta::MockEnzymeMeta, interp::GPUInterpreter, @nospecialize(atype), callinfo) + return nothing +end + +function autodiff end + +import GPUCompiler: DeferredCallInfo +struct AutodiffCallInfo <: CallInfo + rt + info::DeferredCallInfo +end + +function abstract_call_known(meta::MockEnzymeMeta, interp::GPUInterpreter, @nospecialize(f), + arginfo::ArgInfo, si::StmtInfo, sv::AbsIntState, max_methods::Int) + (; fargs, argtypes) = arginfo + + if f === autodiff + if length(argtypes) <= 1 + @static if VERSION < v"1.11.0-" + return CallMeta(Union{}, Effects(), NoCallInfo()) + else + return CallMeta(Union{}, Union{}, Effects(), NoCallInfo()) + end + end + + other_fargs = fargs === nothing ? nothing : fargs[2:end] + other_arginfo = ArgInfo(other_fargs, argtypes[2:end]) + call = Core.Compiler.abstract_call(interp, other_arginfo, si, sv, max_methods) + callinfo = DeferredCallInfo(MockEnzymeMeta(), call.rt, call.info) + + # Real Enzyme must compute `rt` and `exct` according to enzyme semantics + # and likely perform a unwrapping of fargs... + rt = call.rt + + # TODO: Edges? Effects? + @static if VERSION < v"1.11.0-" + # Can't use call.effects since otherwise this call might be just replaced with rt + return CallMeta(rt, Effects(), AutodiffCallInfo(rt, callinfo)) + else + return CallMeta(rt, call.exct, Effects(), AutodiffCallInfo(rt, callinfo)) + end + end + + return nothing +end + +import Core.Compiler: insert_node!, NewInstruction, ReturnNode, Instruction, InliningState, Signature + +# We really need a Compiler stdlib +Base.getindex(ir::IRCode, i) = Core.Compiler.getindex(ir, i) +Base.setindex!(inst::Instruction, val, i) = Core.Compiler.setindex!(inst, val, i) + +const FlagType = VERSION >= v"1.11.0-" ? UInt32 : UInt8 +function Core.Compiler.handle_call!(todo::Vector{Pair{Int,Any}}, ir::IRCode, stmt_idx::Int, + stmt::Expr, info::AutodiffCallInfo, flag::FlagType, + sig::Signature, state::InliningState) + + # Goal: + # The IR we want to inline here is: + # unpack the args .. + # ptr = gpuc.deferred(MockEnzymeMeta(), f, primal_args...) + # ret = ccall("extern __autodiff", llvmcall, RT, Tuple{Ptr{Cvoid, args...}}, ptr, adjoint_args...) + + # 0. Obtain primal mi from DeferredCallInfo + # TODO: remove this code duplication + deferred_info = info.info + minfo = deferred_info.info + results = minfo.results + if length(results.matches) != 1 + return nothing + end + match = only(results.matches) + + # lookup the target mi with correct edge tracking + # TODO: Effects? + case = Core.Compiler.compileable_specialization( + match, Core.Compiler.Effects(), Core.Compiler.InliningEdgeTracker(state), info) + @assert case isa Core.Compiler.InvokeCase + @assert stmt.head === :call + + # Now create the IR we want to inline + ir = Core.Compiler.IRCode() # contains a placeholder + args = [Core.Compiler.Argument(i) for i in 2:length(stmt.args)] # f, args... + idx = 0 + + # 0. Enzyme proper: Desugar args + primal_args = args + primal_argtypes = match.spec_types.parameters[2:end] + + adjoint_rt = info.rt + adjoint_args = args # TODO + adjoint_argtypes = primal_argtypes + + # 1: Since Julia's inliner goes bottom up we need to pretend that we inlined the deferred call + expr = Expr(:foreigncall, + "extern gpuc.lookup", + Ptr{Cvoid}, + Core.svec(#=meta=# Any, #=mi=# Any, #=f=# Any, primal_argtypes...), # Must use Any for MethodInstance or ftype + 0, + QuoteNode(:llvmcall), + deferred_info.meta, + case.invoke, + primal_args... + ) + ptr = insert_node!(ir, (idx += 1), NewInstruction(expr, Ptr{Cvoid})) + + # 2. Call to magic `__autodiff` + expr = Expr(:foreigncall, + "extern __autodiff", + adjoint_rt, + Core.svec(Any, Ptr{Cvoid}, adjoint_argtypes...), + 0, + QuoteNode(:llvmcall), + ptr, + adjoint_args... + ) + ret = insert_node!(ir, idx, NewInstruction(expr, adjoint_rt)) + + # Finally replace placeholder return + ir[Core.SSAValue(1)][:inst] = Core.ReturnNode(ret) + ir[Core.SSAValue(1)][:type] = Ptr{Cvoid} + + ir = Core.Compiler.compact!(ir) + + # which mi to use here? + # push inlining todos + # TODO: Effects + # aviatesk mentioned using inlining_policy instead... + itodo = Core.Compiler.InliningTodo(case.invoke, ir, Core.Compiler.Effects()) + @assert itodo.linear_inline_eligible + push!(todo, (stmt_idx=>itodo)) + + return nothing +end + +end #module \ No newline at end of file diff --git a/test/ptx_tests.jl b/test/ptx_tests.jl index 6476afde..fdc99140 100644 --- a/test/ptx_tests.jl +++ b/test/ptx_tests.jl @@ -504,4 +504,25 @@ end ir = sprint(io->PTX.code_llvm(io, kernel_inline, Tuple{Ptr{Int64}, Int64}, meta=Plugin.NeverInlineMeta())) @test occursin("call fastcc i64 @julia_inline", ir) end + +@testset "Mock Enzyme" begin + function f(x) + x^2 + end + + function kernel(a, x) + y = Plugin.autodiff(f, x) + unsafe_store!(a, y) + nothing + end + + + @show PTX.code_typed(kernel, Tuple{Ptr{Float64}, Float64}, meta=Plugin.MockEnzymeMeta()) + + # FIXME: the fact that meta is necessary here almost invalidates that extension mechanism + # we somehow need to be able to add this kind of "autodiff" abs int handling. + ir = sprint(io->PTX.code_llvm(io, kernel, Tuple{Ptr{Float64}, Float64}, meta=Plugin.MockEnzymeMeta())) + @test occursin("call double @__autodiff", ir) +end + end #testitem