Skip to content

Commit c72dacb

Browse files
Provide better error message if calling thunk with wrong types (#474)
* Provide better error message if calling thunk with wrong types * Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * add new line * Update src/Compiler.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * fix test * bump enzyme --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent 765e992 commit c72dacb

File tree

4 files changed

+63
-20
lines changed

4 files changed

+63
-20
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ ArrayInterface = "7.17.1"
5454
CEnum = "0.5"
5555
CUDA = "5.5"
5656
Downloads = "1.6"
57-
Enzyme = "0.13.22"
57+
Enzyme = "0.13.28"
5858
EnzymeCore = "0.8.8"
5959
GPUArraysCore = "0.1.6, 0.2"
6060
LinearAlgebra = "1.10"

src/Compiler.jl

Lines changed: 58 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,10 @@ import ..Reactant:
1919
ancestor,
2020
TracedType
2121

22-
@inline traced_getfield(@nospecialize(obj), field) = Base.getfield(obj, field)
22+
@inline function traced_getfield(@nospecialize(obj), field)
23+
return Base.getfield(obj, field)
24+
end
25+
2326
@inline function traced_getfield(@nospecialize(obj::AbstractArray{T}), field) where {T}
2427
(isbitstype(T) || ancestor(obj) isa RArray) && return Base.getfield(obj, field)
2528
return Base.getindex(obj, field)
@@ -905,37 +908,75 @@ function compile(f, args; client=nothing, optimize=true, sync=false, no_nan=fals
905908
end
906909

907910
fname = gensym(Symbol(Symbol(f), :_reactant))
908-
expr = :(function $(fname)(args...)
909-
$(
910-
# if `f` is a closure, then prepend the closure into `args`
911-
# the closure fields will be correctly extracted from it as the tracer has already passed through it
912-
if !(closure_ty <: Nothing)
913-
:(args = ($fnwrap, args...))
914-
end
915-
)
911+
912+
body = quote
916913
$(flatten_code...)
917914
$(xla_call_code)
918915
$(sync_call)
919916
$(unflatten_code...)
920917
return result
921-
end)
918+
end
922919

923-
body = expr.args[2]
924-
return register_thunk(fname, body)
920+
return register_thunk(fname, Tuple{map(Core.Typeof, args)...}, body, f, isclosure)
925921
end
926922

927923
# inspired by RuntimeGeneratedFunction.jl
928924
const __thunk_body_cache = Dict{Symbol,Expr}()
929925

930-
struct Thunk{tag} end
926+
struct Thunk{FTy,tag,IsClosure,ArgTypes}
927+
f::FTy
928+
end
929+
930+
struct MisMatchedThunkTypeError{ThunkTy,FoundTypes} <: Base.Exception end
931931

932-
@generated function (thunk::Thunk{tag})(args...) where {tag}
933-
return __thunk_body_cache[tag]
932+
function Base.showerror(
933+
io::IO, ece::MisMatchedThunkTypeError{Thunk{FTy,tag,ArgTypes,IsClosure},FoundTypes}
934+
) where {FTy,tag,ArgTypes,FoundTypes,IsClosure}
935+
print(
936+
io,
937+
"\nThe Reactant-compiled function `$(Thunk{FTy, tag, ArgTypes, IsClosure})` exists, but no method is defined for this combination of argument types.",
938+
)
939+
print(
940+
io,
941+
"\nYou passed in arguments with types\n\t(" *
942+
join(FoundTypes.parameters, ", ") *
943+
")",
944+
)
945+
return print(
946+
io,
947+
"\nHowever the method you are calling was compiled for arguments with types\n\t(" *
948+
join(ArgTypes.parameters, ", ") *
949+
")",
950+
)
934951
end
935952

936-
function register_thunk(tag, body)
953+
@generated function (thunk::Thunk{FTy,tag,ArgTypes,IsClosure})(
954+
args...
955+
) where {FTy,tag,ArgTypes,IsClosure}
956+
FoundTypes = Tuple{args...}
957+
if ArgTypes != FoundTypes
958+
return quote
959+
throw(
960+
$(MisMatchedThunkTypeError{Thunk{FTy,tag,ArgTypes,IsClosure},FoundTypes}())
961+
)
962+
end
963+
end
964+
body = __thunk_body_cache[tag]
965+
if IsClosure
966+
return quote
967+
args = (thunk.f, args...)
968+
$body
969+
end
970+
else
971+
return body
972+
end
973+
end
974+
975+
function register_thunk(
976+
tag::Symbol, @nospecialize(argtys::Type), body::Expr, @nospecialize(f), isclosure::Bool
977+
)
937978
__thunk_body_cache[tag] = body
938-
return Thunk{tag}()
979+
return Thunk{Core.Typeof(f),tag,argtys,isclosure}(f)
939980
end
940981

941982
end

src/Interpreter.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ end
6666
world,
6767
false, #=forward_rules=#
6868
false, #=reverse_rules=#
69+
false, #=inactive_rules=#
6970
false, #=broadcast_rewrite=#
7071
set_reactant_abi,
7172
)
@@ -81,7 +82,8 @@ else
8182
REACTANT_METHOD_TABLE,
8283
world,
8384
false, #=forward_rules=#
84-
false, #=forward_rules=#
85+
false, #=reverse_rules=#
86+
false, #=inactive_rules=#
8587
false, #=broadcast_rewrite=#
8688
set_reactant_abi,
8789
)

test/nn/nnlib.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ using NNlib, Reactant, Enzyme
1515
@testset "Activation: $act" for act in (
1616
identity, relu, sigmoid, tanh, tanh_fast, sigmoid_fast, gelu, abs2, relu6
1717
)
18-
f_compile = Reactant.compile(sumabs2, (act, x_act))
18+
f_compile = Reactant.compile(sumabs2, (act, x_act_ca))
1919

2020
y_simple = sumabs2(act, x_act)
2121
y_compile = f_compile(act, x_act_ca)

0 commit comments

Comments
 (0)