@@ -19,7 +19,10 @@ import ..Reactant:
1919 ancestor,
2020 TracedType
2121
22- @inline traced_getfield (@nospecialize (obj), field) = Base. getfield (obj, field)
22+ @inline function traced_getfield (@nospecialize (obj), field)
23+ return Base. getfield (obj, field)
24+ end
25+
2326@inline function traced_getfield (@nospecialize (obj:: AbstractArray{T} ), field) where {T}
2427 (isbitstype (T) || ancestor (obj) isa RArray) && return Base. getfield (obj, field)
2528 return Base. getindex (obj, field)
@@ -905,37 +908,75 @@ function compile(f, args; client=nothing, optimize=true, sync=false, no_nan=fals
905908 end
906909
907910 fname = gensym (Symbol (Symbol (f), :_reactant ))
908- expr = :(function $ (fname)(args... )
909- $ (
910- # if `f` is a closure, then prepend the closure into `args`
911- # the closure fields will be correctly extracted from it as the tracer has already passed through it
912- if ! (closure_ty <: Nothing )
913- :(args = ($ fnwrap, args... ))
914- end
915- )
911+
912+ body = quote
916913 $ (flatten_code... )
917914 $ (xla_call_code)
918915 $ (sync_call)
919916 $ (unflatten_code... )
920917 return result
921- end )
918+ end
922919
923- body = expr. args[2 ]
924- return register_thunk (fname, body)
920+ return register_thunk (fname, Tuple{map (Core. Typeof, args)... }, body, f, isclosure)
925921end
926922
927923# inspired by RuntimeGeneratedFunction.jl
928924const __thunk_body_cache = Dict {Symbol,Expr} ()
929925
930- struct Thunk{tag} end
926+ struct Thunk{FTy,tag,IsClosure,ArgTypes}
927+ f:: FTy
928+ end
929+
930+ struct MisMatchedThunkTypeError{ThunkTy,FoundTypes} <: Base.Exception end
931931
932- @generated function (thunk:: Thunk{tag} )(args... ) where {tag}
933- return __thunk_body_cache[tag]
932+ function Base. showerror (
933+ io:: IO , ece:: MisMatchedThunkTypeError{Thunk{FTy,tag,ArgTypes,IsClosure},FoundTypes}
934+ ) where {FTy,tag,ArgTypes,FoundTypes,IsClosure}
935+ print (
936+ io,
937+ " \n The Reactant-compiled function `$(Thunk{FTy, tag, ArgTypes, IsClosure}) ` exists, but no method is defined for this combination of argument types." ,
938+ )
939+ print (
940+ io,
941+ " \n You passed in arguments with types\n\t (" *
942+ join (FoundTypes. parameters, " , " ) *
943+ " )" ,
944+ )
945+ return print (
946+ io,
947+ " \n However the method you are calling was compiled for arguments with types\n\t (" *
948+ join (ArgTypes. parameters, " , " ) *
949+ " )" ,
950+ )
934951end
935952
936- function register_thunk (tag, body)
953+ @generated function (thunk:: Thunk{FTy,tag,ArgTypes,IsClosure} )(
954+ args...
955+ ) where {FTy,tag,ArgTypes,IsClosure}
956+ FoundTypes = Tuple{args... }
957+ if ArgTypes != FoundTypes
958+ return quote
959+ throw (
960+ $ (MisMatchedThunkTypeError {Thunk{FTy,tag,ArgTypes,IsClosure},FoundTypes} ())
961+ )
962+ end
963+ end
964+ body = __thunk_body_cache[tag]
965+ if IsClosure
966+ return quote
967+ args = (thunk. f, args... )
968+ $ body
969+ end
970+ else
971+ return body
972+ end
973+ end
974+
975+ function register_thunk (
976+ tag:: Symbol , @nospecialize (argtys:: Type ), body:: Expr , @nospecialize (f), isclosure:: Bool
977+ )
937978 __thunk_body_cache[tag] = body
938- return Thunk {tag} ( )
979+ return Thunk {Core.Typeof(f), tag,argtys,isclosure} (f )
939980end
940981
941982end
0 commit comments