Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ ArrayInterface = "7.17.1"
CEnum = "0.5"
CUDA = "5.5"
Downloads = "1.6"
Enzyme = "0.13.22"
Enzyme = "0.13.28"
EnzymeCore = "0.8.8"
GPUArraysCore = "0.1.6, 0.2"
LinearAlgebra = "1.10"
Expand Down
75 changes: 58 additions & 17 deletions src/Compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,10 @@ import ..Reactant:
ancestor,
TracedType

@inline traced_getfield(@nospecialize(obj), field) = Base.getfield(obj, field)
@inline function traced_getfield(@nospecialize(obj), field)
return Base.getfield(obj, field)
end

@inline function traced_getfield(@nospecialize(obj::AbstractArray{T}), field) where {T}
(isbitstype(T) || ancestor(obj) isa RArray) && return Base.getfield(obj, field)
return Base.getindex(obj, field)
Expand Down Expand Up @@ -905,37 +908,75 @@ function compile(f, args; client=nothing, optimize=true, sync=false, no_nan=fals
end

fname = gensym(Symbol(Symbol(f), :_reactant))
expr = :(function $(fname)(args...)
$(
# if `f` is a closure, then prepend the closure into `args`
# the closure fields will be correctly extracted from it as the tracer has already passed through it
if !(closure_ty <: Nothing)
:(args = ($fnwrap, args...))
end
)

body = quote
$(flatten_code...)
$(xla_call_code)
$(sync_call)
$(unflatten_code...)
return result
end)
end

body = expr.args[2]
return register_thunk(fname, body)
return register_thunk(fname, Tuple{map(Core.Typeof, args)...}, body, f, isclosure)
end

# inspired by RuntimeGeneratedFunction.jl
const __thunk_body_cache = Dict{Symbol,Expr}()

struct Thunk{tag} end
struct Thunk{FTy,tag,IsClosure,ArgTypes}
f::FTy
end

struct MisMatchedThunkTypeError{ThunkTy,FoundTypes} <: Base.Exception end

@generated function (thunk::Thunk{tag})(args...) where {tag}
return __thunk_body_cache[tag]
function Base.showerror(
io::IO, ece::MisMatchedThunkTypeError{Thunk{FTy,tag,ArgTypes,IsClosure},FoundTypes}
) where {FTy,tag,ArgTypes,FoundTypes,IsClosure}
print(
io,
"\nThe Reactant-compiled function `$(Thunk{FTy, tag, ArgTypes, IsClosure})` exists, but no method is defined for this combination of argument types.",
)
print(
io,
"\nYou passed in arguments with types\n\t(" *
join(FoundTypes.parameters, ", ") *
")",
)
return print(
io,
"\nHowever the method you are calling was compiled for arguments with types\n\t(" *
join(ArgTypes.parameters, ", ") *
")",
)
end

function register_thunk(tag, body)
@generated function (thunk::Thunk{FTy,tag,ArgTypes,IsClosure})(
args...
) where {FTy,tag,ArgTypes,IsClosure}
FoundTypes = Tuple{args...}
if ArgTypes != FoundTypes
return quote
throw(
$(MisMatchedThunkTypeError{Thunk{FTy,tag,ArgTypes,IsClosure},FoundTypes}())
)
end
end
body = __thunk_body_cache[tag]
if IsClosure
return quote
args = (thunk.f, args...)
$body
end
else
return body
end
end

function register_thunk(
tag::Symbol, @nospecialize(argtys::Type), body::Expr, @nospecialize(f), isclosure::Bool
)
__thunk_body_cache[tag] = body
return Thunk{tag}()
return Thunk{Core.Typeof(f),tag,argtys,isclosure}(f)
end

end
4 changes: 3 additions & 1 deletion src/Interpreter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ end
world,
false, #=forward_rules=#
false, #=reverse_rules=#
false, #=inactive_rules=#
false, #=broadcast_rewrite=#
set_reactant_abi,
)
Expand All @@ -81,7 +82,8 @@ else
REACTANT_METHOD_TABLE,
world,
false, #=forward_rules=#
false, #=forward_rules=#
false, #=reverse_rules=#
false, #=inactive_rules=#
false, #=broadcast_rewrite=#
set_reactant_abi,
)
Expand Down
2 changes: 1 addition & 1 deletion test/nn/nnlib.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ using NNlib, Reactant, Enzyme
@testset "Activation: $act" for act in (
identity, relu, sigmoid, tanh, tanh_fast, sigmoid_fast, gelu, abs2, relu6
)
f_compile = Reactant.compile(sumabs2, (act, x_act))
f_compile = Reactant.compile(sumabs2, (act, x_act_ca))

y_simple = sumabs2(act, x_act)
y_compile = f_compile(act, x_act_ca)
Expand Down
Loading