diff --git a/ffi/include/tvm/ffi/reflection/reflection.h b/ffi/include/tvm/ffi/reflection/reflection.h index 0a5e836e1aa6..ea079183b8a4 100644 --- a/ffi/include/tvm/ffi/reflection/reflection.h +++ b/ffi/include/tvm/ffi/reflection/reflection.h @@ -118,20 +118,47 @@ class ReflectionDefBase { info->doc = TVMFFIByteArray{value, std::char_traits::length(value)}; } } + template static TVM_FFI_INLINE Function GetMethod(std::string name, R (Class::*func)(Args...)) { - auto fwrap = [func](const Class* target, Args... params) -> R { - return (const_cast(target)->*func)(std::forward(params)...); - }; - return ffi::Function::FromTyped(fwrap, name); + static_assert(std::is_base_of_v || std::is_base_of_v, + "Class must be derived from ObjectRef or Object"); + if constexpr (std::is_base_of_v) { + auto fwrap = [func](Class target, Args... params) -> R { + // call method pointer + return (target.*func)(std::forward(params)...); + }; + return ffi::Function::FromTyped(fwrap, name); + } + + if constexpr (std::is_base_of_v) { + auto fwrap = [func](const Class* target, Args... params) -> R { + // call method pointer + return (const_cast(target)->*func)(std::forward(params)...); + }; + return ffi::Function::FromTyped(fwrap, name); + } } template static TVM_FFI_INLINE Function GetMethod(std::string name, R (Class::*func)(Args...) const) { - auto fwrap = [func](const Class* target, Args... params) -> R { - return (target->*func)(std::forward(params)...); - }; - return ffi::Function::FromTyped(fwrap, name); + static_assert(std::is_base_of_v || std::is_base_of_v, + "Class must be derived from ObjectRef or Object"); + if constexpr (std::is_base_of_v) { + auto fwrap = [func](const Class target, Args... params) -> R { + // call method pointer + return (target.*func)(std::forward(params)...); + }; + return ffi::Function::FromTyped(fwrap, name); + } + + if constexpr (std::is_base_of_v) { + auto fwrap = [func](const Class* target, Args... params) -> R { + // call method pointer + return (target->*func)(std::forward(params)...); + }; + return ffi::Function::FromTyped(fwrap, name); + } } template @@ -140,6 +167,96 @@ class ReflectionDefBase { } }; +class GlobalDef : public ReflectionDefBase { + public: + /* + * \brief Define a global function. + * + * \tparam Func The function type. + * \tparam Extra The extra arguments. + * + * \param name The name of the function. + * \param func The function to be registered. + * \param extra The extra arguments that can be docstring. + * + * \return The reflection definition. + */ + template + GlobalDef& def(const char* name, Func&& func, Extra&&... extra) { + RegisterFunc(name, ffi::Function::FromTyped(std::forward(func), std::string(name)), + std::forward(extra)...); + return *this; + } + + /* + * \brief Define a global function in ffi::PackedArgs format. + * + * \tparam Func The function type. + * \tparam Extra The extra arguments. + * + * \param name The name of the function. + * \param func The function to be registered. + * \param extra The extra arguments that can be docstring. + * + * \return The reflection definition. + */ + template + GlobalDef& def_packed(const char* name, Func func, Extra&&... extra) { + RegisterFunc(name, ffi::Function::FromPacked(func), std::forward(extra)...); + return *this; + } + + /* + * \brief Expose a class method as a global function. + * + * An argument will be added to the first position if the function is not static. + * + * \tparam Class The class type. + * \tparam Func The function type. + * + * \param name The name of the method. + * \param func The function to be registered. + * + * \return The reflection definition. + */ + template + GlobalDef& def_method(const char* name, Func&& func, Extra&&... extra) { + RegisterFunc(name, GetMethod_(std::string(name), std::forward(func)), + std::forward(extra)...); + return *this; + } + + private: + template + static TVM_FFI_INLINE Function GetMethod_(std::string name, Func&& func) { + return ffi::Function::FromTyped(std::forward(func), name); + } + + template + static TVM_FFI_INLINE Function GetMethod_(std::string name, R (Class::*func)(Args...) const) { + return GetMethod(std::string(name), func); + } + + template + static TVM_FFI_INLINE Function GetMethod_(std::string name, R (Class::*func)(Args...)) { + return GetMethod(std::string(name), func); + } + + template + void RegisterFunc(const char* name, ffi::Function func, Extra&&... extra) { + TVMFFIMethodInfo info; + info.name = TVMFFIByteArray{name, std::char_traits::length(name)}; + info.doc = TVMFFIByteArray{nullptr, 0}; + info.type_schema = TVMFFIByteArray{nullptr, 0}; + info.flags = 0; + // obtain the method function + info.method = AnyView(func).CopyToTVMFFIAny(); + // apply method info traits + ((ApplyMethodInfoTrait(&info, std::forward(extra)), ...)); + TVM_FFI_CHECK_SAFE_CALL(TVMFFIFunctionSetGlobalFromMethodInfo(&info, 0)); + } +}; + template class ObjectDef : public ReflectionDefBase { public: diff --git a/ffi/src/ffi/container.cc b/ffi/src/ffi/container.cc index 0ca2034aa219..a0dc660b4279 100644 --- a/ffi/src/ffi/container.cc +++ b/ffi/src/ffi/container.cc @@ -25,40 +25,11 @@ #include #include #include +#include namespace tvm { namespace ffi { -TVM_FFI_REGISTER_GLOBAL("ffi.Array").set_body_packed([](ffi::PackedArgs args, Any* ret) { - *ret = Array(args.data(), args.data() + args.size()); -}); - -TVM_FFI_REGISTER_GLOBAL("ffi.ArrayGetItem") - .set_body_typed([](const ffi::ArrayObj* n, int64_t i) -> Any { return n->at(i); }); - -TVM_FFI_REGISTER_GLOBAL("ffi.ArraySize").set_body_typed([](const ffi::ArrayObj* n) -> int64_t { - return static_cast(n->size()); -}); -// Map -TVM_FFI_REGISTER_GLOBAL("ffi.Map").set_body_packed([](ffi::PackedArgs args, Any* ret) { - TVM_FFI_ICHECK_EQ(args.size() % 2, 0); - Map data; - for (int i = 0; i < args.size(); i += 2) { - data.Set(args[i], args[i + 1]); - } - *ret = data; -}); - -TVM_FFI_REGISTER_GLOBAL("ffi.MapSize").set_body_typed([](const ffi::MapObj* n) -> int64_t { - return static_cast(n->size()); -}); - -TVM_FFI_REGISTER_GLOBAL("ffi.MapGetItem") - .set_body_typed([](const ffi::MapObj* n, const Any& k) -> Any { return n->at(k); }); - -TVM_FFI_REGISTER_GLOBAL("ffi.MapCount") - .set_body_typed([](const ffi::MapObj* n, const Any& k) -> int64_t { return n->count(k); }); - // Favor struct outside function scope as MSVC may have bug for in fn scope struct. class MapForwardIterFunctor { public: @@ -86,10 +57,33 @@ class MapForwardIterFunctor { ffi::MapObj::iterator end_; }; -TVM_FFI_REGISTER_GLOBAL("ffi.MapForwardIterFunctor") - .set_body_typed([](const ffi::MapObj* n) -> ffi::Function { - return ffi::Function::FromTyped(MapForwardIterFunctor(n->begin(), n->end())); - }); - +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def_packed("ffi.Array", + [](ffi::PackedArgs args, Any* ret) { + *ret = Array(args.data(), args.data() + args.size()); + }) + .def("ffi.ArrayGetItem", [](const ffi::ArrayObj* n, int64_t i) -> Any { return n->at(i); }) + .def("ffi.ArraySize", + [](const ffi::ArrayObj* n) -> int64_t { return static_cast(n->size()); }) + .def_packed("ffi.Map", + [](ffi::PackedArgs args, Any* ret) { + TVM_FFI_ICHECK_EQ(args.size() % 2, 0); + Map data; + for (int i = 0; i < args.size(); i += 2) { + data.Set(args[i], args[i + 1]); + } + *ret = data; + }) + .def("ffi.MapSize", + [](const ffi::MapObj* n) -> int64_t { return static_cast(n->size()); }) + .def("ffi.MapGetItem", [](const ffi::MapObj* n, const Any& k) -> Any { return n->at(k); }) + .def("ffi.MapCount", + [](const ffi::MapObj* n, const Any& k) -> int64_t { return n->count(k); }) + .def("ffi.MapForwardIterFunctor", [](const ffi::MapObj* n) -> ffi::Function { + return ffi::Function::FromTyped(MapForwardIterFunctor(n->begin(), n->end())); + }); +}); } // namespace ffi } // namespace tvm diff --git a/ffi/src/ffi/function.cc b/ffi/src/ffi/function.cc index b6bfe73af373..4df86a213a06 100644 --- a/ffi/src/ffi/function.cc +++ b/ffi/src/ffi/function.cc @@ -28,6 +28,7 @@ #include #include #include +#include #include namespace tvm { @@ -307,31 +308,30 @@ int TVMFFIEnvRegisterCAPI(const TVMFFIByteArray* name, void* symbol) { TVM_FFI_SAFE_CALL_END(); } -TVM_FFI_REGISTER_GLOBAL("ffi.FunctionRemoveGlobal") - .set_body_typed([](const tvm::ffi::String& name) -> bool { - return tvm::ffi::GlobalFunctionTable::Global()->Remove(name); - }); - -TVM_FFI_REGISTER_GLOBAL("ffi.FunctionListGlobalNamesFunctor").set_body_typed([]() { - // NOTE: we return functor instead of array - // so list global function names do not need to depend on array - // this is because list global function names usually is a core api that happens - // before array ffi functions are available. - tvm::ffi::Array names = tvm::ffi::GlobalFunctionTable::Global()->ListNames(); - auto return_functor = [names](int64_t i) -> tvm::ffi::Any { - if (i < 0) { - return names.size(); - } else { - return names[i]; - } - }; - return tvm::ffi::Function::FromTyped(return_functor); -}); - -TVM_FFI_REGISTER_GLOBAL("ffi.String").set_body_typed([](tvm::ffi::String val) -> tvm::ffi::String { - return val; -}); - -TVM_FFI_REGISTER_GLOBAL("ffi.Bytes").set_body_typed([](tvm::ffi::Bytes val) -> tvm::ffi::Bytes { - return val; +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("ffi.FunctionRemoveGlobal", + [](const tvm::ffi::String& name) -> bool { + return tvm::ffi::GlobalFunctionTable::Global()->Remove(name); + }) + .def("ffi.FunctionListGlobalNamesFunctor", + []() { + // NOTE: we return functor instead of array + // so list global function names do not need to depend on array + // this is because list global function names usually is a core api that happens + // before array ffi functions are available. + tvm::ffi::Array names = + tvm::ffi::GlobalFunctionTable::Global()->ListNames(); + auto return_functor = [names](int64_t i) -> tvm::ffi::Any { + if (i < 0) { + return names.size(); + } else { + return names[i]; + } + }; + return tvm::ffi::Function::FromTyped(return_functor); + }) + .def("ffi.String", [](tvm::ffi::String val) -> tvm::ffi::String { return val; }) + .def("ffi.Bytes", [](tvm::ffi::Bytes val) -> tvm::ffi::Bytes { return val; }); }); diff --git a/ffi/src/ffi/ndarray.cc b/ffi/src/ffi/ndarray.cc index f3c48c8ad56f..dc3a5fb1ecdd 100644 --- a/ffi/src/ffi/ndarray.cc +++ b/ffi/src/ffi/ndarray.cc @@ -23,23 +23,27 @@ #include #include #include +#include namespace tvm { namespace ffi { -// Shape -TVM_FFI_REGISTER_GLOBAL("ffi.Shape").set_body_packed([](ffi::PackedArgs args, Any* ret) { - int64_t* mutable_data; - ObjectPtr shape = details::MakeEmptyShape(args.size(), &mutable_data); - for (int i = 0; i < args.size(); ++i) { - if (auto opt_int = args[i].try_cast()) { - mutable_data[i] = *opt_int; - } else { - TVM_FFI_THROW(ValueError) << "Expect shape to take list of int arguments"; +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def_packed("ffi.Shape", [](ffi::PackedArgs args, Any* ret) { + int64_t* mutable_data; + ObjectPtr shape = details::MakeEmptyShape(args.size(), &mutable_data); + for (int i = 0; i < args.size(); ++i) { + if (auto opt_int = args[i].try_cast()) { + mutable_data[i] = *opt_int; + } else { + TVM_FFI_THROW(ValueError) << "Expect shape to take list of int arguments"; + } } - } - *ret = Shape(shape); + *ret = Shape(shape); + }); }); + } // namespace ffi } // namespace tvm diff --git a/ffi/src/ffi/object.cc b/ffi/src/ffi/object.cc index 9b193b757fe2..84a83e4b73e6 100644 --- a/ffi/src/ffi/object.cc +++ b/ffi/src/ffi/object.cc @@ -24,6 +24,7 @@ #include #include #include +#include #include #include @@ -404,7 +405,10 @@ void MakeObjectFromPackedArgs(ffi::PackedArgs args, Any* ret) { *ret = ObjectRef(ptr); } -TVM_FFI_REGISTER_GLOBAL("ffi.MakeObjectFromPackedArgs").set_body_packed(MakeObjectFromPackedArgs); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def_packed("ffi.MakeObjectFromPackedArgs", MakeObjectFromPackedArgs); +}); } // namespace ffi } // namespace tvm diff --git a/ffi/src/ffi/testing.cc b/ffi/src/ffi/testing.cc index 6bc7968eab06..a1747a279df6 100644 --- a/ffi/src/ffi/testing.cc +++ b/ffi/src/ffi/testing.cc @@ -54,6 +54,12 @@ class TestObjectDerived : public TestObjectBase { TVM_FFI_DECLARE_FINAL_OBJECT_INFO(TestObjectDerived, TestObjectBase); }; +void TestRaiseError(String kind, String msg) { + throw ffi::Error(kind, msg, TVM_FFI_TRACEBACK_HERE); +} + +void TestApply(Function f, PackedArgs args, Any* ret) { f.CallPacked(args, ret); } + TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; @@ -66,41 +72,27 @@ TVM_FFI_STATIC_INIT_BLOCK({ refl::ObjectDef() .def_ro("v_map", &TestObjectDerived::v_map) .def_ro("v_array", &TestObjectDerived::v_array); -}); - -void TestRaiseError(String kind, String msg) { - throw ffi::Error(kind, msg, TVM_FFI_TRACEBACK_HERE); -} - -TVM_FFI_REGISTER_GLOBAL("testing.test_raise_error").set_body_typed(TestRaiseError); - -TVM_FFI_REGISTER_GLOBAL("testing.nop").set_body_packed([](PackedArgs args, Any* ret) { - *ret = args[0]; -}); - -TVM_FFI_REGISTER_GLOBAL("testing.echo").set_body_packed([](PackedArgs args, Any* ret) { - *ret = args[0]; -}); - -void TestApply(Function f, PackedArgs args, Any* ret) { f.CallPacked(args, ret); } - -TVM_FFI_REGISTER_GLOBAL("testing.apply").set_body_packed([](PackedArgs args, Any* ret) { - auto f = args[0].cast(); - TestApply(f, args.Slice(1), ret); -}); - -TVM_FFI_REGISTER_GLOBAL("testing.run_check_signal").set_body_typed([](int nsec) { - for (int i = 0; i < nsec; ++i) { - if (TVMFFIEnvCheckSignals() != 0) { - throw ffi::EnvErrorAlreadySet(); - } - std::this_thread::sleep_for(std::chrono::seconds(1)); - } - std::cout << "Function finished without catching signal" << std::endl; -}); -TVM_FFI_REGISTER_GLOBAL("testing.object_use_count").set_body_typed([](const Object* obj) { - return obj->use_count(); + refl::GlobalDef() + .def("testing.test_raise_error", TestRaiseError) + .def_packed("testing.nop", [](PackedArgs args, Any* ret) { *ret = args[0]; }) + .def_packed("testing.echo", [](PackedArgs args, Any* ret) { *ret = args[0]; }) + .def_packed("testing.apply", + [](PackedArgs args, Any* ret) { + auto f = args[0].cast(); + TestApply(f, args.Slice(1), ret); + }) + .def("testing.run_check_signal", + [](int nsec) { + for (int i = 0; i < nsec; ++i) { + if (TVMFFIEnvCheckSignals() != 0) { + throw ffi::EnvErrorAlreadySet(); + } + std::this_thread::sleep_for(std::chrono::seconds(1)); + } + std::cout << "Function finished without catching signal" << std::endl; + }) + .def("testing.object_use_count", [](const Object* obj) { return obj->use_count(); }); }); } // namespace ffi diff --git a/ffi/tests/cpp/test_function.cc b/ffi/tests/cpp/test_function.cc index 526e1ad03e96..c3c484f33317 100644 --- a/ffi/tests/cpp/test_function.cc +++ b/ffi/tests/cpp/test_function.cc @@ -131,7 +131,7 @@ TEST(Func, FromTyped) { EXPECT_EQ(error.kind(), "TypeError"); EXPECT_EQ(error.message(), "Mismatched number of arguments when calling: " - "`fpass_and_return(0: test.Int, 1: int, 2: AnyView) -> object.Function`. " + "`fpass_and_return(0: test.Int, 1: int, 2: AnyView) -> ffi.Function`. " "Expected 3 but got 0 arguments"); throw; } @@ -236,11 +236,4 @@ TEST(Func, ObjectRefWithFallbackTraits) { ::tvm::ffi::Error); } -TVM_FFI_REGISTER_GLOBAL("testing.Int_GetValue").set_body_method(&TIntObj::GetValue); - -TEST(Func, Register) { - Function fget_value = Function::GetGlobalRequired("testing.Int_GetValue"); - TInt a(12); - EXPECT_EQ(fget_value(a).cast(), 12); -} } // namespace diff --git a/ffi/tests/cpp/test_reflection.cc b/ffi/tests/cpp/test_reflection.cc index 17494744ef65..ce15fc14c4be 100644 --- a/ffi/tests/cpp/test_reflection.cc +++ b/ffi/tests/cpp/test_reflection.cc @@ -153,4 +153,15 @@ TEST(Reflection, ForEachFieldInfo) { EXPECT_EQ(field_name_to_offset["z"], 16 + sizeof(TVMFFIObject)); } +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def_method("testing.Int_GetValue", &TIntObj::GetValue); +}); + +TEST(Reflection, FuncRegister) { + Function fget_value = Function::GetGlobalRequired("testing.Int_GetValue"); + TInt a(12); + EXPECT_EQ(fget_value(a).cast(), 12); +} + } // namespace