Skip to content

Commit

Permalink
fix deferred (EnzymeAD#1426)
Browse files Browse the repository at this point in the history
* fix deferred

* fixup

* no escaping
  • Loading branch information
wsmoses authored May 11, 2024
1 parent 9668147 commit 9ca0d9d
Show file tree
Hide file tree
Showing 4 changed files with 110 additions and 8 deletions.
38 changes: 33 additions & 5 deletions src/Enzyme.jl
Original file line number Diff line number Diff line change
Expand Up @@ -795,7 +795,7 @@ result, ∂v, ∂A
(7.26, 2.2, [3.3])
```
"""
@inline function autodiff_deferred_thunk(::ReverseModeSplit{ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT, RABI}, ::Type{TapeType}, ::Type{FA}, ::Type{A2}, args::Vararg{Type{<:Annotation}, Nargs}) where {FA<:Annotation, A2<:Annotation, TapeType, ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT, RABI<:ABI, Nargs}
@inline function autodiff_deferred_thunk(mode::ReverseModeSplit{ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT, RABI}, tt::Type{TapeType}, fa::Type{FA}, a2::Type{A2}, args::Vararg{Type{<:Annotation}, Nargs}) where {FA<:Annotation, A2<:Annotation, TapeType, ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT, RABI<:ABI, Nargs}
@assert RABI == FFIABI
width = if Width == 0
w = same_or_one(1, args...)
Expand All @@ -819,10 +819,38 @@ result, ∂v, ∂A
primal_tt = Tuple{map(eltype, args)...}
world = codegen_world_age(eltype(FA), primal_tt)

primal_ptr = Compiler.deferred_codegen(Val(world), FA, Val(TT), Val(A2), Val(API.DEM_ReverseModePrimal), Val(width), ModifiedBetween, Val(ReturnPrimal), #=ShadowInit=#Val(false), TapeType)
adjoint_ptr = Compiler.deferred_codegen(Val(world), FA, Val(TT), Val(A2), Val(API.DEM_ReverseModeGradient), Val(width), ModifiedBetween, Val(ReturnPrimal), #=ShadowInit=#Val(false), TapeType)
aug_thunk = Compiler.AugmentedForwardThunk{Ptr{Cvoid}, FA, A2, TT, Val{width}, Val(ReturnPrimal), TapeType}(primal_ptr)
adj_thunk = Compiler.AdjointThunk{Ptr{Cvoid}, FA, A2, TT, Val{width}, TapeType}(adjoint_ptr)
primal_ptr = Compiler.deferred_codegen(Val(world), FA, Val(TT), Val(Compiler.remove_innerty(A2)), Val(API.DEM_ReverseModePrimal), Val(width), ModifiedBetween, Val(ReturnPrimal), #=ShadowInit=#Val(false), TapeType)
adjoint_ptr = Compiler.deferred_codegen(Val(world), FA, Val(TT), Val(Compiler.remove_innerty(A2)), Val(API.DEM_ReverseModeGradient), Val(width), ModifiedBetween, Val(ReturnPrimal), #=ShadowInit=#Val(false), TapeType)

RT = if A2 <: Duplicated && width != 1
if A2 isa UnionAll
BatchDuplicated{T, width} where T
else
BatchDuplicated{eltype(A2), width}
end
elseif A2 <: DuplicatedNoNeed && width != 1
if A2 isa UnionAll
BatchDuplicatedNoNeed{T, width} where T
else
BatchDuplicatedNoNeed{eltype(A2), width}
end
else
A2
end

rt = if RT isa UnionAll
@static if VERSION < v"1.8-"
throw(MethodError(autodiff_deferred_thunk, (mode, tt, fa, a2, args...)))
else
RT{Core.Compiler.return_type(Tuple{eltype(FA), map(eltype, args)...})}
end
else
@assert RT isa DataType
RT
end

aug_thunk = Compiler.AugmentedForwardThunk{Ptr{Cvoid}, FA, rt, TT, Val{width}, Val(ReturnPrimal), TapeType}(primal_ptr)
adj_thunk = Compiler.AdjointThunk{Ptr{Cvoid}, FA, rt, TT, Val{width}, TapeType}(adjoint_ptr)
aug_thunk, adj_thunk
end

Expand Down
38 changes: 36 additions & 2 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2361,6 +2361,7 @@ function zero_allocation(B::LLVM.IRBuilder, jlType, LLVMType, obj, AlignedSize,
wrapper_f = LLVM.Function(mod, "zeroType", LLVM.FunctionType(LLVM.VoidType(), [value_type(obj), T_int8, value_type(Size)]))
push!(function_attributes(wrapper_f), StringAttribute("enzyme_math", "enzyme_zerotype"))
push!(function_attributes(wrapper_f), StringAttribute("enzyme_inactive"))
push!(function_attributes(wrapper_f), StringAttribute("enzyme_no_escaping_allocation"))
push!(function_attributes(wrapper_f), EnumAttribute("alwaysinline", 0))
push!(function_attributes(wrapper_f), EnumAttribute("nofree", 0))
push!(function_attributes(wrapper_f), EnumAttribute("argmemonly", 0))
Expand Down Expand Up @@ -2774,7 +2775,22 @@ function annotate!(mod, mode)
end
end

for fname in ("julia.get_pgcstack", "julia.ptls_states", "jl_get_ptls_states", "julia.safepoint", "ijl_throw")
for fname in ("julia.get_pgcstack", "julia.ptls_states", "jl_get_ptls_states", "julia.safepoint", "ijl_throw", "julia.pointer_from_objref",
"ijl_array_grow_end", "jl_array_grow_end", "ijl_array_del_end", "jl_array_del_end",
"ijl_array_grow_beg", "jl_array_grow_beg", "ijl_array_del_beg", "jl_array_del_beg",
"ijl_array_grow_at", "jl_array_grow_at",
"ijl_array_del_at", "jl_array_del_at",
"ijl_pop_handler", "jl_pop_handler",
"ijl_push_handler", "jl_push_handler",
"ijl_module_name", "jl_module_name",
"ijl_restore_excstack", "jl_restore_excstack",
"julia.except_enter",
"ijl_get_nth_field_checked", "jl_get_nth_field_checked",
"jl_egal__unboxed",
"ijl_reshape_array", "jl_reshape_array",
"ijl_eqtable_get", "jl_eqtable_get",
"jl_gc_run_pending_finalizers",
)
if haskey(fns, fname)
fn = fns[fname]
push!(function_attributes(fn), no_escaping_alloc)
Expand Down Expand Up @@ -2826,7 +2842,7 @@ function annotate!(mod, mode)
continue
end
LLVM.API.LLVMAddCallSiteAttribute(c, LLVM.API.LLVMAttributeReturnIndex, LLVM.EnumAttribute("noalias", 0))
LLVM.API.LLVMAddCallSiteAttribute(c, LLVM.API.LLVMAttributeReturnIndex, no_escaping_alloc)
LLVM.API.LLVMAddCallSiteAttribute(c, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), no_escaping_alloc)
if !(boxfn in ("jl_array_copy", "ijl_array_copy", "jl_idtable_rehash", "ijl_idtable_rehash"))
LLVM.API.LLVMAddCallSiteAttribute(c, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), LLVM.EnumAttribute("inaccessiblememonly", 0))
end
Expand Down Expand Up @@ -4324,6 +4340,9 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function
if kind(prev) == kind(StringAttribute("enzyme_inactive"))
push!(attributes, prev)
end
if kind(prev) == kind(StringAttribute("enzyme_no_escaping_allocation"))
push!(attributes, prev)
end
end

if LLVM.API.LLVMVerifyFunction(wrapper_f, LLVM.API.LLVMReturnStatusAction) != 0
Expand Down Expand Up @@ -4794,6 +4813,7 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget};
EnumAttribute("speculatable", 0),
StringAttribute("enzyme_shouldrecompute"),
StringAttribute("enzyme_inactive"),
StringAttribute("enzyme_no_escaping_allocation")
])
continue
end
Expand All @@ -4815,6 +4835,20 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget};
for bb in blocks(llvmfn)
for inst in instructions(bb)
if isa(inst, LLVM.CallInst)
LLVM.API.LLVMAddCallSiteAttribute(inst, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), StringAttribute("no_escaping_allocation"))
LLVM.API.LLVMAddCallSiteAttribute(inst, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), StringAttribute("enzyme_inactive"))
LLVM.API.LLVMAddCallSiteAttribute(inst, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), EnumAttribute("nofree"))
end
end
end
continue
end
if func === typeof(Base.match)
handleCustom(llvmfn, "base_match", [StringAttribute("enzyme_inactive"), EnumAttribute("nofree"), StringAttribute("enzyme_no_escaping_allocation")], false, false)
for bb in blocks(llvmfn)
for inst in instructions(bb)
if isa(inst, LLVM.CallInst)
LLVM.API.LLVMAddCallSiteAttribute(inst, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), StringAttribute("no_escaping_allocation"))
LLVM.API.LLVMAddCallSiteAttribute(inst, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), StringAttribute("enzyme_inactive"))
LLVM.API.LLVMAddCallSiteAttribute(inst, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), EnumAttribute("nofree"))
end
Expand Down
8 changes: 8 additions & 0 deletions src/compiler/validation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -478,12 +478,16 @@ function check_ir!(job, errors, imported, inst::LLVM.CallInst, calls)
LLVM.API.LLVMAddCallSiteAttribute(inst, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), inactive)
nofree = LLVM.EnumAttribute("nofree")
LLVM.API.LLVMAddCallSiteAttribute(inst, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), nofree)
no_escaping_alloc = LLVM.StringAttribute("enzyme_no_escaping_allocation")
LLVM.API.LLVMAddCallSiteAttribute(inst, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), no_escaping_alloc)
end
if funclib == Base.tuple && length(operands(inst)) == 4+1+1 && Base.isconcretetype(GT) && Enzyme.Compiler.guaranteed_const_nongen(GT, world)
inactive = LLVM.StringAttribute("enzyme_inactive", "")
LLVM.API.LLVMAddCallSiteAttribute(inst, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), inactive)
nofree = LLVM.EnumAttribute("nofree")
LLVM.API.LLVMAddCallSiteAttribute(inst, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), nofree)
no_escaping_alloc = LLVM.StringAttribute("enzyme_no_escaping_allocation")
LLVM.API.LLVMAddCallSiteAttribute(inst, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), no_escaping_alloc)
end
end
end
Expand Down Expand Up @@ -515,6 +519,8 @@ function check_ir!(job, errors, imported, inst::LLVM.CallInst, calls)
LLVM.API.LLVMAddCallSiteAttribute(inst, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), inactive)
nofree = LLVM.EnumAttribute("nofree")
LLVM.API.LLVMAddCallSiteAttribute(inst, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), nofree)
no_escaping_alloc = LLVM.StringAttribute("enzyme_no_escaping_allocation")
LLVM.API.LLVMAddCallSiteAttribute(inst, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), no_escaping_alloc)
end
end
end
Expand Down Expand Up @@ -596,6 +602,8 @@ function check_ir!(job, errors, imported, inst::LLVM.CallInst, calls)
LLVM.API.LLVMAddCallSiteAttribute(inst, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), inactive)
nofree = LLVM.EnumAttribute("nofree")
LLVM.API.LLVMAddCallSiteAttribute(inst, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), nofree)
no_escaping_alloc = LLVM.StringAttribute("enzyme_no_escaping_allocation")
LLVM.API.LLVMAddCallSiteAttribute(inst, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), no_escaping_alloc)
end
end
end
Expand Down
34 changes: 33 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -321,11 +321,16 @@ end
Const{typeof(dot)}, Active, Duplicated{typeof(thunk_A)}
)
@test Tuple{Float64,Float64} === TapeType
Ret = if VERSION < v"1.8-"
Active{Float64}
else
Active
end
fwd, rev = Enzyme.autodiff_deferred_thunk(
ReverseSplitWithPrimal,
TapeType,
Const{typeof(dot)},
Active{Float64},
Ret,
Duplicated{typeof(thunk_A)}
)
tape, primal, _ = fwd(Const(dot), dup)
Expand All @@ -335,6 +340,33 @@ end
@test all(dA .== [6.0, 10.0])
@test all(dA .== def_dA)
@test all(dA .== thunk_dA)

@static if VERSION < v"1.8-"
else
# Square the entries `A[1]` through `A[len]` in place: each element is multiplied by itself.
# The trailing `for` loop is the last expression, so the function returns `nothing`.
# This is the function passed as `f` to `aug_fwd` further down in this test.
function kernel(len, A)
for i in 1:len
A[i] *= A[i]
end
end

A = Array{Float64}(undef, 64)
dA = Array{Float64}(undef, 64)

A .= (1:1:64)
dA .= 1

# Build the deferred split-reverse thunks for `f` and run only the forward thunk.
#
# - `ctx`: wrapped as `Const(ctx)` and passed right after `Const(f)` in the forward call.
# - `f`: the function being differentiated; wrapped as `Const(f)`.
# - `::Val{ModifiedBetween}`: the value arrives as a type parameter and is re-wrapped as
#   `Val(ModifiedBetween)` to configure the mode through `ReverseSplitModified`.
# - `args...`: forwarded unchanged; their types are taken with `Core.Typeof`, so the caller
#   supplies them already annotated (the call below passes `Duplicated(A, dA)`).
#
# `Const` is passed without a type parameter in the position where the earlier call in this
# testset passes `Ret`. Always returns `nothing`.
function aug_fwd(ctx, f::FT, ::Val{ModifiedBetween}, args...) where {ModifiedBetween, FT}
# Tape type for this mode and signature; handed to `autodiff_deferred_thunk` on the next line.
TapeType = Enzyme.tape_type(ReverseSplitModified(ReverseSplitWithPrimal, Val(ModifiedBetween)), Const{Core.Typeof(f)}, Const, Const{Core.Typeof(ctx)}, map(Core.Typeof, args)...)
# `reverse` is bound but never called here: only the forward thunk is exercised.
forward, reverse = Enzyme.autodiff_deferred_thunk(ReverseSplitModified(ReverseSplitWithPrimal, Val(ModifiedBetween)), TapeType, Const{Core.Typeof(f)}, Const, Const{Core.Typeof(ctx)}, map(Core.Typeof, args)...)
# The first element of the forward result is indexed and then discarded
# (presumably the tape, as in the `tape, primal, _ = fwd(...)` call earlier in this testset).
forward(Const(f), Const(ctx), args...)[1]
return nothing
end

ModifiedBetween = Val((false, false, true))

aug_fwd(64, kernel, ModifiedBetween, Duplicated(A, dA))
end

end

@testset "Simple Complex tests" begin
Expand Down

0 comments on commit 9ca0d9d

Please sign in to comment.