Skip to content

Commit

Permalink
Adapt to GPUCompiler 0.18
Browse files Browse the repository at this point in the history
  • Loading branch information
vchuravy committed Mar 15, 2023
1 parent 2ccf4b7 commit db587cf
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 21 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
CEnum = "0.4"
EnzymeCore = "0.2.1"
Enzyme_jll = "0.0.51"
GPUCompiler = "0.16.7, 0.17"
GPUCompiler = "0.18"
LLVM = "4.14"
ObjectFile = "0.3"
julia = "1.6"
42 changes: 22 additions & 20 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2441,7 +2441,7 @@ function nested_codegen!(mode::API.CDerivativeMode, mod::LLVM.Module, f, tt)
else
ctx = ctxToThreadSafe[ctx]
end
funcspec = FunctionSpec(f, tt, #=kernel=# false, #=name=# nothing)
funcspec = FunctionSpec(typeof(f), tt; kernel=false)

# 3) Use the MI to create the correct augmented fwd/reverse
# TODO:
Expand Down Expand Up @@ -5673,15 +5673,19 @@ import .Interpreter: isKWCallSignature
"""
Create the `FunctionSpec` pair, and lookup the primal return type.
"""
@inline function fspec(@nospecialize(F), @nospecialize(TT))
# Entry for the cache look-up
adjoint = FunctionSpec(F, TT, #=kernel=# false, #=name=# nothing)

@inline function fspec(@nospecialize(F), @nospecialize(TT); world=nothing)
# primal function. Inferred here to get return type
_tt = (TT.parameters...,)

primal_tt = Tuple{map(eltype, _tt)...}
primal = FunctionSpec(F, primal_tt, #=kernel=# false, #=name=# nothing)
if world === nothing
world = GPUCompiler.get_world(F, primal_tt)
end

primal = FunctionSpec(F, primal_tt; world, kernel=false, name=nothing)

# Entry for the cache look-up
adjoint = FunctionSpec(F, TT; world, kernel=false, name=nothing)

return primal, adjoint
end
Expand Down Expand Up @@ -8118,18 +8122,25 @@ const cache_lock = ReentrantLock()
end
end

@inline function thunk(f::F,df::DF, ::Type{A}, tt::Type{TT},::Val{Mode}, ::Val{width}, ::Val{ModifiedBetween}, ::Val{ReturnPrimal}=Val(false), ::Val{ShadowInit}=Val(false)) where {F, DF, A<:Annotation, TT, Mode, width, ModifiedBetween, ReturnPrimal, ShadowInit}
_tt = (TT.parameters...,)
primal_tt = Tuple{map(eltype, _tt)...}

world = GPUCompiler.get_world(F, primal_tt)

genthunk(Val(world), F, f, df, A, TT, Val(Mode), Val(ModifiedBetween), Val(width), Val(ReturnPrimal), Val(ShadowInit))
end


@generated function genthunk(::Type{F}, f::Fn, df::DF, ::Type{A}, tt::Type{TT},::Val{Mode}, ::Val{ModifiedBetween}, ::Val{width}, ::Val{specid}, ::Val{ReturnPrimal}, ::Val{ShadowInit}) where {F, Fn, DF, A<:Annotation, TT, Mode, ModifiedBetween, width, specid, ReturnPrimal, ShadowInit}
primal, adjoint = fspec(F, TT)
@generated function genthunk(::Val{World}, ::Type{F}, f::Fn, df::DF, ::Type{A}, tt::Type{TT},::Val{Mode}, ::Val{ModifiedBetween}, ::Val{width}, ::Val{ReturnPrimal}, ::Val{ShadowInit}) where {World, F, Fn, DF, A<:Annotation, TT, Mode, ModifiedBetween, width, ReturnPrimal, ShadowInit}
primal, adjoint = fspec(F, TT; world=World)

target = Compiler.EnzymeTarget()
params = Compiler.EnzymeCompilerParams(adjoint, Mode, width, A, true, DF != Nothing, #=abiwrap=#true, ModifiedBetween, ReturnPrimal, ShadowInit)
job = Compiler.CompilerJob(target, primal, params)

sig = Tuple{F, map(eltype, TT.parameters)...}

# world = ...

interp = Core.Compiler.NativeInterpreter(job.source.world)

# TODO check compile return here, early
Expand Down Expand Up @@ -8168,7 +8179,7 @@ end
# invalidations of the primal, which is managed by GPUCompiler.


thunk = cached_compilation(job, hash(hash(hash(hash(adjoint, hash(rt, UInt64(Mode))), UInt64(width)), hash(ModifiedBetween)), UInt64(ReturnPrimal)), specid)::Thunk
thunk = cached_compilation(job, hash(hash(hash(hash(adjoint, hash(rt, UInt64(Mode))), UInt64(width)), hash(ModifiedBetween)), UInt64(ReturnPrimal)), World)::Thunk
if Mode == API.DEM_ReverseModePrimal || Mode == API.DEM_ReverseModeGradient
TapeType = thunk.TapeType
AugT = AugmentedForwardThunk{F, rt, adjoint.tt, Val{width} , DF, Val(ReturnPrimal), TapeType}
Expand All @@ -8193,16 +8204,7 @@ end
end
end

@inline function thunk(f::F,df::DF, ::Type{A}, tt::Type{TT},::Val{Mode}, ::Val{width}, ::Val{ModifiedBetween}, ::Val{ReturnPrimal}=Val(false), ::Val{ShadowInit}=Val(false)) where {F, DF, A<:Annotation, TT, Mode, width, ModifiedBetween, ReturnPrimal, ShadowInit}
primal, adjoint = fspec(Core.Typeof(f), TT)
target = Compiler.EnzymeTarget()
params = Compiler.EnzymeCompilerParams(adjoint, Mode, width, A, true, DF != Nothing, #=abiwrap=#true, ModifiedBetween, ReturnPrimal, ShadowInit)
job = Compiler.CompilerJob(target, primal, params)

specid = GPUCompiler.specialization_id(job)

genthunk(Core.Typeof(f), f, df, A, TT, Val(Mode), Val(ModifiedBetween), Val(width), Val(specid), Val(ReturnPrimal), Val(ShadowInit))
end

import GPUCompiler: deferred_codegen_jobs

Expand Down

0 comments on commit db587cf

Please sign in to comment.