diff --git a/Project.toml b/Project.toml
index 4240461703..7ed9ab945b 100644
--- a/Project.toml
+++ b/Project.toml
@@ -54,7 +54,7 @@ ArrayInterface = "7.17.1"
 CEnum = "0.5"
 CUDA = "5.5"
 Downloads = "1.6"
-Enzyme = "0.13.22"
+Enzyme = "0.13.28"
 EnzymeCore = "0.8.8"
 GPUArraysCore = "0.1.6, 0.2"
 LinearAlgebra = "1.10"
diff --git a/src/Compiler.jl b/src/Compiler.jl
index 900c360be7..a5051b3fa4 100644
--- a/src/Compiler.jl
+++ b/src/Compiler.jl
@@ -19,7 +19,10 @@ import ..Reactant:
     ancestor,
     TracedType
 
-@inline traced_getfield(@nospecialize(obj), field) = Base.getfield(obj, field)
+@inline function traced_getfield(@nospecialize(obj), field)
+    return Base.getfield(obj, field)
+end
+
 @inline function traced_getfield(@nospecialize(obj::AbstractArray{T}), field) where {T}
     (isbitstype(T) || ancestor(obj) isa RArray) && return Base.getfield(obj, field)
     return Base.getindex(obj, field)
@@ -905,37 +908,86 @@ function compile(f, args; client=nothing, optimize=true, sync=false, no_nan=fals
     end
 
     fname = gensym(Symbol(Symbol(f), :_reactant))
-    expr = :(function $(fname)(args...)
-        $(
-            # if `f` is a closure, then prepend the closure into `args`
-            # the closure fields will be correctly extracted from it as the tracer has already passed through it
-            if !(closure_ty <: Nothing)
-                :(args = ($fnwrap, args...))
-            end
-        )
+
+    body = quote
         $(flatten_code...)
         $(xla_call_code)
         $(sync_call)
         $(unflatten_code...)
         return result
-    end)
+    end
 
-    body = expr.args[2]
-    return register_thunk(fname, body)
+    return register_thunk(fname, Tuple{map(Core.Typeof, args)...}, body, f, isclosure)
 end
 
 # inspired by RuntimeGeneratedFunction.jl
 const __thunk_body_cache = Dict{Symbol,Expr}()
 
-struct Thunk{tag} end
+# Callable handle for a compiled function. Type parameters, in positional order:
+#   FTy       - type of the stored function `f`
+#   tag       - key of the generated body in `__thunk_body_cache`
+#   ArgTypes  - `Tuple` type of the arguments the thunk was compiled for
+#   IsClosure - whether `f` is prepended to `args` when the thunk is called
+struct Thunk{FTy,tag,ArgTypes,IsClosure}
+    f::FTy
+end
+
+# Thrown when a `Thunk` is called with argument types other than the ones it was
+# compiled for; both type sets are carried as type parameters for `showerror`.
+struct MisMatchedThunkTypeError{ThunkTy,FoundTypes} <: Base.Exception end
 
-@generated function (thunk::Thunk{tag})(args...) where {tag}
-    return __thunk_body_cache[tag]
+function Base.showerror(
+    io::IO, ece::MisMatchedThunkTypeError{Thunk{FTy,tag,ArgTypes,IsClosure},FoundTypes}
+) where {FTy,tag,ArgTypes,FoundTypes,IsClosure}
+    print(
+        io,
+        "\nThe Reactant-compiled function `$(Thunk{FTy, tag, ArgTypes, IsClosure})` exists, but no method is defined for this combination of argument types.",
+    )
+    print(
+        io,
+        "\nYou passed in arguments with types\n\t(" *
+        join(FoundTypes.parameters, ", ") *
+        ")",
+    )
+    return print(
+        io,
+        "\nHowever the method you are calling was compiled for arguments with types\n\t(" *
+        join(ArgTypes.parameters, ", ") *
+        ")",
+    )
 end
 
-function register_thunk(tag, body)
+# Calling a thunk splices in the cached body for `tag`, after checking that the
+# call-site argument types match `ArgTypes` exactly.
+@generated function (thunk::Thunk{FTy,tag,ArgTypes,IsClosure})(
+    args...
+) where {FTy,tag,ArgTypes,IsClosure}
+    # In a generated function `args` is bound to the argument types, not values.
+    FoundTypes = Tuple{args...}
+    if ArgTypes != FoundTypes
+        return quote
+            throw(
+                $(MisMatchedThunkTypeError{Thunk{FTy,tag,ArgTypes,IsClosure},FoundTypes}())
+            )
+        end
+    end
+    body = __thunk_body_cache[tag]
+    if IsClosure
+        return quote
+            args = (thunk.f, args...)
+            $body
+        end
+    else
+        return body
+    end
+end
+
+# Cache `body` under `tag` and return the `Thunk` that expands to it when called.
+function register_thunk(
+    tag::Symbol, @nospecialize(argtys::Type), body::Expr, @nospecialize(f), isclosure::Bool
+)
     __thunk_body_cache[tag] = body
-    return Thunk{tag}()
+    return Thunk{Core.Typeof(f),tag,argtys,isclosure}(f)
 end
 
 end
diff --git a/src/Interpreter.jl b/src/Interpreter.jl
index d68041d2ac..06d8883459 100644
--- a/src/Interpreter.jl
+++ b/src/Interpreter.jl
@@ -66,6 +66,7 @@ end
             world,
             false, #=forward_rules=#
             false, #=reverse_rules=#
+            false, #=inactive_rules=#
             false, #=broadcast_rewrite=#
             set_reactant_abi,
         )
@@ -81,7 +82,8 @@ else
             REACTANT_METHOD_TABLE,
             world,
             false, #=forward_rules=#
-            false, #=forward_rules=#
+            false, #=reverse_rules=#
+            false, #=inactive_rules=#
             false, #=broadcast_rewrite=#
             set_reactant_abi,
         )
diff --git a/test/nn/nnlib.jl b/test/nn/nnlib.jl
index 7359bca97d..2f58fd633d 100644
--- a/test/nn/nnlib.jl
+++ b/test/nn/nnlib.jl
@@ -15,7 +15,7 @@ using NNlib, Reactant, Enzyme
     @testset "Activation: $act" for act in (
         identity, relu, sigmoid, tanh, tanh_fast, sigmoid_fast, gelu, abs2, relu6
     )
-        f_compile = Reactant.compile(sumabs2, (act, x_act))
+        f_compile = Reactant.compile(sumabs2, (act, x_act_ca))
 
         y_simple = sumabs2(act, x_act)
         y_compile = f_compile(act, x_act_ca)