diff --git a/apps/cpp_rpc/rpc_env.cc b/apps/cpp_rpc/rpc_env.cc index 146b9cf05071..88ad99e47af2 100644 --- a/apps/cpp_rpc/rpc_env.cc +++ b/apps/cpp_rpc/rpc_env.cc @@ -128,10 +128,10 @@ RPCEnv::RPCEnv(const std::string& wd) { ffi::Function::SetGlobal( "tvm.rpc.server.workpath", - ffi::Function::FromUnpacked([this](const std::string& path) { return this->GetPath(path); })); + ffi::Function::FromTyped([this](const std::string& path) { return this->GetPath(path); })); ffi::Function::SetGlobal("tvm.rpc.server.listdir", - ffi::Function::FromUnpacked([this](const std::string& path) { + ffi::Function::FromTyped([this](const std::string& path) { std::string dir = this->GetPath(path); std::ostringstream os; for (auto d : ListDir(dir)) { @@ -141,7 +141,7 @@ RPCEnv::RPCEnv(const std::string& wd) { })); ffi::Function::SetGlobal("tvm.rpc.server.load_module", - ffi::Function::FromUnpacked([this](const std::string& path) { + ffi::Function::FromTyped([this](const std::string& path) { std::string file_name = this->GetPath(path); file_name = BuildSharedLibrary(file_name); LOG(INFO) << "Load module from " << file_name << " ..."; @@ -149,7 +149,7 @@ RPCEnv::RPCEnv(const std::string& wd) { })); ffi::Function::SetGlobal("tvm.rpc.server.download_linked_module", - ffi::Function::FromUnpacked([this](const std::string& path) { + ffi::Function::FromTyped([this](const std::string& path) { std::string file_name = this->GetPath(path); file_name = BuildSharedLibrary(file_name); std::string bin; diff --git a/apps/hexagon_launcher/launcher_core.cc b/apps/hexagon_launcher/launcher_core.cc index 5b39ace9ba6c..f4fc9fb365a8 100644 --- a/apps/hexagon_launcher/launcher_core.cc +++ b/apps/hexagon_launcher/launcher_core.cc @@ -137,27 +137,25 @@ Model::Model(tvm::runtime::Module executor, tvm::runtime::Module module, std::st run = get_module_func(model_executor, "run"); } -const tvm::runtime::PackedFunc get_runtime_func(const std::string& name) { +const tvm::ffi::Function get_runtime_func(const std::string& name) { if (auto pf = tvm::ffi::Function::GetGlobal(name)) { return *pf; } - return tvm::runtime::PackedFunc(); + return tvm::ffi::Function(); } -const tvm::runtime::PackedFunc get_module_func(tvm::runtime::Module module, - const std::string& name) { +const tvm::ffi::Function get_module_func(tvm::runtime::Module module, const std::string& name) { return module.GetFunction(name, false); } void reset_device_api() { - const tvm::runtime::PackedFunc api = get_runtime_func("device_api.hexagon"); + const tvm::ffi::Function api = get_runtime_func("device_api.hexagon"); tvm::ffi::Function::SetGlobal("device_api.cpu", api, true); } tvm::runtime::Module load_module(const std::string& file_name) { - static const tvm::runtime::PackedFunc loader = - get_runtime_func("runtime.module.loadfile_hexagon"); - tvm::runtime::TVMRetValue rv = loader(file_name); + static const tvm::ffi::Function loader = get_runtime_func("runtime.module.loadfile_hexagon"); + tvm::ffi::Any rv = loader(file_name); if (rv.type_code() == kTVMModuleHandle) { ICHECK_EQ(rv.type_code(), kTVMModuleHandle) << __func__ << ": loaded " << file_name << ", but did not get module handle"; @@ -180,7 +178,7 @@ tvm::runtime::Module create_graph_executor(const std::string& graph_json, tvm::runtime::Module graph_module, tvm::Device device) { std::string launcher_name = "tvm.graph_executor.create"; - const tvm::runtime::PackedFunc create_executor = get_runtime_func(launcher_name); + const tvm::ffi::Function create_executor = get_runtime_func(launcher_name); uint64_t device_type = device.device_type; uint64_t device_id = device.device_id; @@ -188,18 +186,18 @@ tvm::runtime::Module create_graph_executor(const std::string& graph_json, LOG(ERROR) << __func__ << ": graph executor requires graph JSON"; return tvm::runtime::Module(); } - tvm::runtime::TVMRetValue rv = create_executor(graph_json, graph_module, device_type, device_id); + tvm::ffi::Any rv = create_executor(graph_json, graph_module, device_type, device_id); return rv.operator tvm::runtime::Module(); } tvm::runtime::Module create_aot_executor(tvm::runtime::Module factory_module, tvm::Device device) { - tvm::runtime::PackedFunc list_modules = get_module_func(factory_module, "list_module_names"); + tvm::ffi::Function list_modules = get_module_func(factory_module, "list_module_names"); tvm::Array module_names = list_modules(); if (module_names.size() != 1) { LOG(WARNING) << __func__ << ": expecting single module, got: " << module_names << ", using " << module_names[0]; } - tvm::runtime::PackedFunc f = get_module_func(factory_module, module_names[0]); + tvm::ffi::Function f = get_module_func(factory_module, module_names[0]); if (f.get() == nullptr) { LOG(ERROR) << __func__ << ": failed to obtain function " << module_names[0]; return tvm::runtime::Module(); diff --git a/apps/hexagon_launcher/launcher_core.h b/apps/hexagon_launcher/launcher_core.h index da0dfcbbd5a6..ae9e4108cd57 100644 --- a/apps/hexagon_launcher/launcher_core.h +++ b/apps/hexagon_launcher/launcher_core.h @@ -90,7 +90,7 @@ struct Model { static tvm::Device device() { return tvm::Device{static_cast(kDLHexagon), 0}; } static tvm::Device external() { return tvm::Device{static_cast(kDLCPU), 0}; } - tvm::runtime::PackedFunc run; + tvm::ffi::Function run; }; struct ExecutionSession { @@ -123,9 +123,8 @@ void reset_device_api(); tvm::runtime::Module load_module(const std::string& file_name); -const tvm::runtime::PackedFunc get_runtime_func(const std::string& name); -const tvm::runtime::PackedFunc get_module_func(tvm::runtime::Module module, - const std::string& name); +const tvm::ffi::Function get_runtime_func(const std::string& name); +const tvm::ffi::Function get_module_func(tvm::runtime::Module module, const std::string& name); tvm::runtime::Module create_aot_executor(tvm::runtime::Module factory_module, tvm::Device device); tvm::runtime::Module create_graph_executor(const std::string& graph_json, diff --git a/apps/hexagon_launcher/launcher_hexagon.cc b/apps/hexagon_launcher/launcher_hexagon.cc index 63659cb5044d..bd1df4aa62ad 100644 --- a/apps/hexagon_launcher/launcher_hexagon.cc +++ b/apps/hexagon_launcher/launcher_hexagon.cc @@ -47,7 +47,7 @@ static AEEResult error_too_small(const std::string& func_name, const std::string int __QAIC_HEADER(launcher_rpc_open)(const char* uri, remote_handle64* handle) { *handle = 0; // Just use any value. reset_device_api(); - static const tvm::runtime::PackedFunc acq_res = + static const tvm::ffi::Function acq_res = get_runtime_func("device_api.hexagon.acquire_resources"); acq_res(); return AEE_SUCCESS; @@ -55,7 +55,7 @@ int __QAIC_HEADER(launcher_rpc_open)(const char* uri, remote_handle64* handle) { int __QAIC_HEADER(launcher_rpc_close)(remote_handle64 handle) { // Comment to stop clang-format from single-lining this function. - static const tvm::runtime::PackedFunc rel_res = + static const tvm::ffi::Function rel_res = get_runtime_func("device_api.hexagon.release_resources"); rel_res(); return AEE_SUCCESS; @@ -104,8 +104,7 @@ AEEResult __QAIC_HEADER(launcher_rpc_get_num_inputs)(remote_handle64 handle, int return AEE_EBADSTATE; } - tvm::runtime::PackedFunc get_num_inputs = - get_module_func(TheModel->model_executor, "get_num_inputs"); + tvm::ffi::Function get_num_inputs = get_module_func(TheModel->model_executor, "get_num_inputs"); *num_inputs = get_num_inputs(); return AEE_SUCCESS; } @@ -140,7 +139,7 @@ AEEResult __QAIC_HEADER(launcher_rpc_set_input)(remote_handle64 handle, int inpu auto input = tvm::runtime::NDArray::FromDLPack(&managed); - tvm::runtime::PackedFunc set_input = get_module_func(TheModel->model_executor, "set_input"); + tvm::ffi::Function set_input = get_module_func(TheModel->model_executor, "set_input"); set_input(input_idx, input); return AEE_SUCCESS; @@ -152,8 +151,7 @@ AEEResult __QAIC_HEADER(launcher_rpc_get_num_outputs)(remote_handle64 handle, in return AEE_EBADSTATE; } - tvm::runtime::PackedFunc get_num_outputs = - get_module_func(TheModel->model_executor, "get_num_outputs"); + tvm::ffi::Function get_num_outputs = get_module_func(TheModel->model_executor, "get_num_outputs"); *num_outputs = get_num_outputs(); return AEE_SUCCESS; } @@ -173,7 +171,7 @@ AEEResult __QAIC_HEADER(launcher_rpc_get_output)(remote_handle64 handle, int out return AEE_EBADPARM; } - tvm::runtime::PackedFunc get_output = get_module_func(TheModel->model_executor, "get_output"); + tvm::ffi::Function get_output = get_module_func(TheModel->model_executor, "get_output"); tvm::runtime::NDArray output = get_output(output_idx); std::vector shape_vec{output->shape, output->shape + output->ndim}; diff --git a/apps/ios_rpc/tvmrpc/RPCServer.mm b/apps/ios_rpc/tvmrpc/RPCServer.mm index 284a2cfcd9ee..3dc2fb0c192a 100644 --- a/apps/ios_rpc/tvmrpc/RPCServer.mm +++ b/apps/ios_rpc/tvmrpc/RPCServer.mm @@ -51,7 +51,7 @@ * 2: need to write * 0: shutdown */ -using FEventHandler = PackedFunc; +using FEventHandler = ffi::Function; /*! * \brief Create a server event handler. @@ -68,7 +68,7 @@ FEventHandler CreateServerEventHandler(NSOutputStream* outputStream, std::string << "You are using tvm_runtime module built without RPC support. " << "Please rebuild it with USE_RPC flag."; - PackedFunc writer_func([outputStream](TVMArgs args, TVMRetValue* rv) { + ffi::Function writer_func([outputStream](ffi::PackedArgs args, ffi::Any* rv) { TVMByteArray* data = args[0].ptr(); int64_t nbytes = [outputStream write:reinterpret_cast(data->data) maxLength:data->size]; diff --git a/apps/ios_rpc/tvmrpc/TVMRuntime.mm b/apps/ios_rpc/tvmrpc/TVMRuntime.mm index a14bb50e0b2c..243e4819d025 100644 --- a/apps/ios_rpc/tvmrpc/TVMRuntime.mm +++ b/apps/ios_rpc/tvmrpc/TVMRuntime.mm @@ -51,14 +51,15 @@ void LogMessageImpl(const std::string& file, int lineno, int level, const std::s } // namespace detail -TVM_REGISTER_GLOBAL("tvm.rpc.server.workpath").set_body_packed([](TVMArgs args, TVMRetValue* rv) { - static const std::string base_ = NSTemporaryDirectory().UTF8String; - const auto path = args[0].cast(); - *rv = base_ + "/" + path; -}); +TVM_REGISTER_GLOBAL("tvm.rpc.server.workpath") + .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { + static const std::string base_ = NSTemporaryDirectory().UTF8String; + const auto path = args[0].cast(); + *rv = base_ + "/" + path; + }); TVM_REGISTER_GLOBAL("tvm.rpc.server.load_module") - .set_body_packed([](TVMArgs args, TVMRetValue* rv) { + .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { auto name = args[0].cast(); std::string fmt = GetFileFormat(name, ""); NSString* base; @@ -109,7 +110,7 @@ void Init(const std::string& name) { // Add UnsignedDSOLoader plugin in global registry TVM_REGISTER_GLOBAL("runtime.module.loadfile_dylib_custom") - .set_body_packed([](TVMArgs args, TVMRetValue* rv) { + .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { auto n = make_object(); n->Init(args[0]); *rv = CreateModuleFromLibrary(n); diff --git a/docs/arch/index.rst b/docs/arch/index.rst index c13a518a6263..f6880351cf48 100644 --- a/docs/arch/index.rst +++ b/docs/arch/index.rst @@ -198,7 +198,7 @@ tvm/runtime The runtime serves as the foundation of the TVM stack. It provides the mechanism to load and execute compiled artifacts. The runtime defines a stable standard set of C APIs to interface with frontend languages such as Python and Rust. -`runtime::Object` is one of the primary data structures in TVM runtime besides the `runtime::PackedFunc`. +`runtime::Object` is one of the primary data structures in TVM runtime besides the `ffi::Function`. It is a reference-counted base class with a type index to support runtime type checking and downcasting. The object system allows the developer to introduce new data structures to the runtime, such as Array, Map, and new IR data structures. diff --git a/docs/arch/runtime.rst b/docs/arch/runtime.rst index 90e02a4f0ce7..f797039ee386 100644 --- a/docs/arch/runtime.rst +++ b/docs/arch/runtime.rst @@ -54,7 +54,7 @@ The following code block provides an example in C++ #include - void MyAdd(TVMArgs args, TVMRetValue* rv) { + void MyAdd(ffi::PackedArgs args, ffi::Any* rv) { // automatically convert arguments to desired type. int a = args[0].cast(); int b = args[1].cast(); @@ -71,8 +71,8 @@ The following code block provides an example in C++ In the above codeblock, we defined a PackedFunc MyAdd. It takes two arguments : ``args`` represents input arguments and ``rv`` represents return value. The function is type-erased, which means that the function signature does not restrict which input type to pass in or type to return. -Under the hood, when we call a PackedFunc, it packs the input arguments to TVMArgs on stack, -and gets the result back via TVMRetValue. +Under the hood, when we call a PackedFunc, it packs the input arguments to ffi::PackedArgs on stack, +and gets the result back via ffi::Any. Thanks to template tricks in C++, we can call a PackedFunc just like a normal function. Because of its type-erased nature, we can call a PackedFunc from dynamic languages like python, without additional glue code for each new type function created. The following example registers PackedFunc in C++ and calls from python. @@ -91,7 +91,7 @@ The following example registers PackedFunc in C++ and calls from python. # prints 3 print(myadd(1, 2)) -Most of the magic of PackedFunc lies in ``TVMArgs`` and ``TVMRetValue`` structure. +Most of the magic of PackedFunc lies in ``ffi::PackedArgs`` and ``ffi::Any`` structure. We restrict a list of possible types which can be passed. Here are the common ones: @@ -111,7 +111,7 @@ we can pass functions from python (as PackedFunc) to C++. .. code:: c TVM_REGISTER_GLOBAL("callhello") - .set_body_packed([](TVMArgs args, TVMRetValue* rv) { + .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { PackedFunc f = args[0]; f("hello world"); }); diff --git a/ffi/include/tvm/ffi/function.h b/ffi/include/tvm/ffi/function.h index 708495b58199..bdcca7b73ead 100644 --- a/ffi/include/tvm/ffi/function.h +++ b/ffi/include/tvm/ffi/function.h @@ -127,7 +127,7 @@ class FunctionObjImpl : public FunctionObj { /*! \brief The type of derived object class */ using TSelf = FunctionObjImpl; /*! - * \brief Derived object class for constructing PackedFuncObj. + * \brief Derived object class for constructing ffi::FunctionObj. * \param callable The type-erased callable object. */ explicit FunctionObjImpl(TCallable callable) : callable_(callable) { @@ -292,7 +292,7 @@ class PackedArgs { /*! * \brief ffi::Function is a type-erased function. - * The arguments are passed by packed format. + * The arguments are passed by "packed format" via AnyView */ class Function : public ObjectRef { public: @@ -300,7 +300,7 @@ class Function : public ObjectRef { Function(std::nullptr_t) : ObjectRef(nullptr) {} // NOLINT(*) /*! * \brief Constructing a packed function from a callable type - * whose signature is consistent with `PackedFunc` + * whose signature is consistent with `ffi::Function` * \param packed_call The packed function signature * \note legacy purpose, should change to Function::FromPacked for mostfuture use. */ @@ -310,7 +310,7 @@ class Function : public ObjectRef { } /*! * \brief Constructing a packed function from a callable type - * whose signature is consistent with `PackedFunc` + * whose signature is consistent with `ffi::Function` * \param packed_call The packed function signature */ template @@ -462,7 +462,7 @@ class Function : public ObjectRef { * \param callable the internal container of packed function. */ template - static Function FromUnpacked(TCallable callable) { + static Function FromTyped(TCallable callable) { using FuncInfo = details::FunctionInfo; auto call_packed = [callable](const AnyView* args, int32_t num_args, Any* rv) mutable -> void { details::unpack_call( @@ -477,7 +477,7 @@ class Function : public ObjectRef { * \param name optional name attacked to the function. */ template - static Function FromUnpacked(TCallable callable, std::string name) { + static Function FromTyped(TCallable callable, std::string name) { using FuncInfo = details::FunctionInfo; auto call_packed = [callable, name](const AnyView* args, int32_t num_args, Any* rv) mutable -> void { @@ -540,7 +540,7 @@ class Function : public ObjectRef { private: /*! * \brief Constructing a packed function from a callable type - * whose signature is consistent with `PackedFunc` + * whose signature is consistent with `ffi::Function` * \param packed_call The packed function signature */ template @@ -560,16 +560,16 @@ class TypedFunction; /*! * \anchor TypedFunctionAnchor - * \brief A PackedFunc wrapper to provide typed function signature. - * It is backed by a PackedFunc internally. + * \brief A ffi::Function wrapper to provide typed function signature. + * It is backed by a ffi::Function internally. * * TypedFunction enables compile time type checking. * TypedFunction works with the runtime system: - * - It can be passed as an argument of PackedFunc. - * - It can be assigned to TVMRetValue. - * - It can be directly converted to a type-erased PackedFunc. + * - It can be passed as an argument of ffi::Function. + * - It can be assigned to ffi::Any. + * - It can be directly converted to a type-erased ffi::Function. * - * Developers should prefer TypedFunction over PackedFunc in C++ code + * Developers should prefer TypedFunction over ffi::Function in C++ code * as it enables compile time checking. * We can construct a TypedFunction from a lambda function * with the same signature. @@ -584,8 +584,8 @@ class TypedFunction; * TypedFunction ftyped(addone); * // invoke the function. * int y = ftyped(1); - * // Can be directly converted to PackedFunc - * PackedFunc packed = ftype; + * // Can be directly converted to ffi::Function + * ffi::Function packed = ftype; * \endcode * \tparam R The return value of the function. * \tparam Args The argument signature of the function. @@ -623,7 +623,7 @@ class TypedFunction { template >::value>::type> TypedFunction(FLambda typed_lambda, std::string name) { // NOLINT(*) - packed_ = Function::FromUnpacked(typed_lambda, name); + packed_ = Function::FromTyped(typed_lambda, name); } /*! * \brief construct from a lambda function with the same signature. @@ -646,7 +646,7 @@ class TypedFunction { template >::value>::type> TypedFunction(const FLambda& typed_lambda) { // NOLINT(*) - packed_ = Function::FromUnpacked(typed_lambda); + packed_ = Function::FromTyped(typed_lambda); } /*! * \brief copy assignment operator from typed lambda @@ -668,11 +668,11 @@ class TypedFunction { std::is_convertible>::value>::type> TSelf& operator=(FLambda typed_lambda) { // NOLINT(*) - packed_ = Function::FromUnpacked(typed_lambda); + packed_ = Function::FromTyped(typed_lambda); return *this; } /*! - * \brief copy assignment operator from PackedFunc. + * \brief copy assignment operator from ffi::Function. * \param packed The packed function. * \returns reference to self. */ @@ -698,16 +698,16 @@ class TypedFunction { } } /*! - * \brief convert to PackedFunc - * \return the internal PackedFunc + * \brief convert to ffi::Function + * \return the internal ffi::Function */ operator Function() const { return packed(); } /*! - * \return reference the internal PackedFunc + * \return reference the internal ffi::Function */ const Function& packed() const& { return packed_; } /*! - * \return r-value reference the internal PackedFunc + * \return r-value reference the internal ffi::Function */ constexpr Function&& packed() && { return std::move(packed_); } /*! \return Whether the packed function is nullptr */ @@ -797,7 +797,7 @@ class Function::Registry { */ template Registry& set_body_typed(FLambda f) { - return Register(Function::FromUnpacked(f, name_)); + return Register(Function::FromTyped(f, name_)); } /*! @@ -838,14 +838,14 @@ class Function::Registry { // call method pointer return (target.*f)(params...); }; - return Register(ffi::Function::FromUnpacked(fwrap, name_)); + return Register(ffi::Function::FromTyped(fwrap, name_)); } if constexpr (std::is_base_of_v) { auto fwrap = [f](const T* target, Args... params) -> R { // call method pointer return (const_cast(target)->*f)(params...); }; - return Register(ffi::Function::FromUnpacked(fwrap, name_)); + return Register(ffi::Function::FromTyped(fwrap, name_)); } return *this; } @@ -859,14 +859,14 @@ class Function::Registry { // call method pointer return (target.*f)(params...); }; - return Register(ffi::Function::FromUnpacked(fwrap, name_)); + return Register(ffi::Function::FromTyped(fwrap, name_)); } if constexpr (std::is_base_of_v) { auto fwrap = [f](const T* target, Args... params) -> R { // call method pointer return (target->*f)(params...); }; - return Register(ffi::Function::FromUnpacked(fwrap, name_)); + return Register(ffi::Function::FromTyped(fwrap, name_)); } return *this; } diff --git a/ffi/include/tvm/ffi/function_details.h b/ffi/include/tvm/ffi/function_details.h index f47a253a5872..3e7f9be140c7 100644 --- a/ffi/include/tvm/ffi/function_details.h +++ b/ffi/include/tvm/ffi/function_details.h @@ -73,7 +73,7 @@ struct FuncFunctorImpl { static constexpr size_t num_args = sizeof...(Args); // MSVC is not that friendly to in-template nested bool evaluation #ifndef _MSC_VER - /*! \brief Whether this function can be converted to ffi::Function via FromUnpacked */ + /*! \brief Whether this function can be converted to ffi::Function via FromTyped */ static constexpr bool unpacked_supported = (ArgSupported && ...) && (RetSupported); #endif @@ -108,7 +108,7 @@ struct FunctionInfo : FuncFunctorImpl {}; template struct FunctionInfo : FuncFunctorImpl {}; -/*! \brief Using static function to output TypedPackedFunc signature */ +/*! \brief Using static function to output typed function signature */ typedef std::string (*FGetFuncSignature)(); /*! diff --git a/ffi/include/tvm/ffi/rvalue_ref.h b/ffi/include/tvm/ffi/rvalue_ref.h index a4a363fa17cd..b300436feee6 100644 --- a/ffi/include/tvm/ffi/rvalue_ref.h +++ b/ffi/include/tvm/ffi/rvalue_ref.h @@ -53,7 +53,7 @@ namespace ffi { * \code * * void Example() { - * auto append = Function::FromUnpacked([](RValueRef> ref, int val) -> Array { + * auto append = Function::FromTyped([](RValueRef> ref, int val) -> Array { * Array arr = *std::move(ref); * assert(arr.unique()); * arr.push_back(val); diff --git a/ffi/src/ffi/container.cc b/ffi/src/ffi/container.cc index 563505b74c5f..0ca2034aa219 100644 --- a/ffi/src/ffi/container.cc +++ b/ffi/src/ffi/container.cc @@ -88,7 +88,7 @@ class MapForwardIterFunctor { TVM_FFI_REGISTER_GLOBAL("ffi.MapForwardIterFunctor") .set_body_typed([](const ffi::MapObj* n) -> ffi::Function { - return ffi::Function::FromUnpacked(MapForwardIterFunctor(n->begin(), n->end())); + return ffi::Function::FromTyped(MapForwardIterFunctor(n->begin(), n->end())); }); } // namespace ffi diff --git a/ffi/src/ffi/function.cc b/ffi/src/ffi/function.cc index ed10ea59c9bc..594d3e01fbd0 100644 --- a/ffi/src/ffi/function.cc +++ b/ffi/src/ffi/function.cc @@ -279,7 +279,7 @@ TVM_FFI_REGISTER_GLOBAL("ffi.FunctionListGlobalNamesFunctor").set_body_typed([]( return names[i]; } }; - return tvm::ffi::Function::FromUnpacked(return_functor); + return tvm::ffi::Function::FromTyped(return_functor); }); TVM_FFI_REGISTER_GLOBAL("ffi.String").set_body_typed([](tvm::ffi::String val) -> tvm::ffi::String { diff --git a/ffi/tests/cpp/test_array.cc b/ffi/tests/cpp/test_array.cc index bb0b062c328a..cb42f32c6cfe 100644 --- a/ffi/tests/cpp/test_array.cc +++ b/ffi/tests/cpp/test_array.cc @@ -185,8 +185,7 @@ TEST(Array, InsertEraseRange) { } TEST(Array, FuncArrayAnyArg) { - Function fadd_one = - Function::FromUnpacked([](Array a) -> Any { return a[0].cast() + 1; }); + Function fadd_one = Function::FromTyped([](Array a) -> Any { return a[0].cast() + 1; }); EXPECT_EQ(fadd_one(Array{1}).cast(), 2); } diff --git a/ffi/tests/cpp/test_function.cc b/ffi/tests/cpp/test_function.cc index fbdc580f3b2b..526e1ad03e96 100644 --- a/ffi/tests/cpp/test_function.cc +++ b/ffi/tests/cpp/test_function.cc @@ -72,9 +72,9 @@ TEST(Func, PackedArgs) { EXPECT_EQ(data[2].cast()->value, 12); } -TEST(Func, FromUnpacked) { +TEST(Func, FromTyped) { // try decution - Function fadd1 = Function::FromUnpacked([](const int32_t& a) -> int { return a + 1; }); + Function fadd1 = Function::FromTyped([](const int32_t& a) -> int { return a + 1; }); int b = fadd1(1).cast(); EXPECT_EQ(b, 2); @@ -109,14 +109,14 @@ TEST(Func, FromUnpacked) { ::tvm::ffi::Error); // try decution - Function fpass_and_return = Function::FromUnpacked( + Function fpass_and_return = Function::FromTyped( [](TInt x, int value, AnyView z) -> Function { EXPECT_EQ(x.use_count(), 2); EXPECT_EQ(x->value, value); if (auto opt = z.as()) { EXPECT_EQ(value, *opt); } - return Function::FromUnpacked([value](int x) -> int { return x + value; }); + return Function::FromTyped([value](int x) -> int { return x + value; }); }, "fpass_and_return"); TInt a(11); @@ -139,18 +139,18 @@ TEST(Func, FromUnpacked) { ::tvm::ffi::Error); Function fconcact = - Function::FromUnpacked([](const String& a, const String& b) -> String { return a + b; }); + Function::FromTyped([](const String& a, const String& b) -> String { return a + b; }); EXPECT_EQ(fconcact("abc", "def").cast(), "abcdef"); } TEST(Func, PassReturnAny) { - Function fadd_one = Function::FromUnpacked([](Any a) -> Any { return a.cast() + 1; }); + Function fadd_one = Function::FromTyped([](Any a) -> Any { return a.cast() + 1; }); EXPECT_EQ(fadd_one(1).cast(), 2); } TEST(Func, Global) { Function::SetGlobal("testing.add1", - Function::FromUnpacked([](const int32_t& a) -> int { return a + 1; })); + Function::FromTyped([](const int32_t& a) -> int { return a + 1; })); auto fadd1 = Function::GetGlobalRequired("testing.add1"); int b = fadd1(1).cast(); EXPECT_EQ(b, 2); @@ -199,7 +199,7 @@ TEST(Func, TypedFunctionAsAnyView) { TEST(Func, ObjectRefWithFallbackTraits) { // test cases to test automatic type conversion via ObjectRefWithFallbackTraits // through TPrimExpr - Function freturn_primexpr = Function::FromUnpacked([](TPrimExpr a) -> TPrimExpr { return a; }); + Function freturn_primexpr = Function::FromTyped([](TPrimExpr a) -> TPrimExpr { return a; }); auto result_int = freturn_primexpr(1).cast(); EXPECT_EQ(result_int->dtype, "int64"); diff --git a/ffi/tests/cpp/test_map.cc b/ffi/tests/cpp/test_map.cc index 1c43230bbc1f..bd0b58b7c46e 100644 --- a/ffi/tests/cpp/test_map.cc +++ b/ffi/tests/cpp/test_map.cc @@ -243,9 +243,9 @@ TEST(Map, AnyConvertCheck) { ::tvm::ffi::Error); } -TEST(Map, PackedFuncGetItem) { - Function f = Function::FromUnpacked([](const MapObj* n, const Any& k) -> Any { return n->at(k); }, - "map_get_item"); +TEST(Map, ffi::FunctionGetItem) { + Function f = Function::FromTyped([](const MapObj* n, const Any& k) -> Any { return n->at(k); }, + "map_get_item"); Map map{{"x", 1}, {"y", 2}}; Any k("x"); Any v = f(map, k); diff --git a/ffi/tests/cpp/test_rvalue_ref.cc b/ffi/tests/cpp/test_rvalue_ref.cc index ac81208d48ba..7cbd5c627b55 100644 --- a/ffi/tests/cpp/test_rvalue_ref.cc +++ b/ffi/tests/cpp/test_rvalue_ref.cc @@ -31,7 +31,7 @@ using namespace tvm::ffi::testing; TEST(RValueRef, Basic) { auto append = - Function::FromUnpacked([](RValueRef> ref, int val, bool is_unique) -> Array { + Function::FromTyped([](RValueRef> ref, int val, bool is_unique) -> Array { Array arr = *std::move(ref); EXPECT_EQ(arr.unique(), is_unique); arr.push_back(val); @@ -48,7 +48,7 @@ TEST(RValueRef, Basic) { TEST(RValueRef, ParamChecking) { // try decution - Function fadd1 = Function::FromUnpacked([](TInt a) -> int64_t { return a->value + 1; }); + Function fadd1 = Function::FromTyped([](TInt a) -> int64_t { return a->value + 1; }); // convert that triggers error EXPECT_THROW( @@ -65,7 +65,7 @@ TEST(RValueRef, ParamChecking) { }, ::tvm::ffi::Error); - Function fadd2 = Function::FromUnpacked([](RValueRef> a) -> int { + Function fadd2 = Function::FromTyped([](RValueRef> a) -> int { Array arr = *std::move(a); return arr[0] + 1; }); @@ -86,7 +86,7 @@ TEST(RValueRef, ParamChecking) { }, ::tvm::ffi::Error); // triggered a rvalue based conversion - Function func3 = Function::FromUnpacked([](RValueRef a) -> String { + Function func3 = Function::FromTyped([](RValueRef a) -> String { TPrimExpr expr = *std::move(a); return expr->dtype; }); diff --git a/ffi/tests/cpp/test_tuple.cc b/ffi/tests/cpp/test_tuple.cc index 42c8e6aacc2b..79eeb488643d 100644 --- a/ffi/tests/cpp/test_tuple.cc +++ b/ffi/tests/cpp/test_tuple.cc @@ -84,9 +84,9 @@ TEST(Tuple, AnyConvert) { EXPECT_EQ(tuple2.get<1>()->value, 2); } -TEST(Tuple, FromUnpacked) { +TEST(Tuple, FromTyped) { // try decution - Function fadd1 = Function::FromUnpacked([](const Tuple& a) -> int { + Function fadd1 = Function::FromTyped([](const Tuple& a) -> int { return a.get<0>() + static_cast(a.get<1>()->value); }); int b = fadd1(Tuple(1, 2)).cast(); diff --git a/ffi/tests/cpp/test_variant.cc b/ffi/tests/cpp/test_variant.cc index 94cbcd491a6e..ee49ac75d15f 100644 --- a/ffi/tests/cpp/test_variant.cc +++ b/ffi/tests/cpp/test_variant.cc @@ -72,9 +72,9 @@ TEST(Variant, ObjectPtrHashEqual) { EXPECT_TRUE(!ObjectPtrEqual()(v0, v2)); } -TEST(Variant, FromUnpacked) { +TEST(Variant, FromTyped) { // try decution - Function fadd1 = Function::FromUnpacked([](const Variant& a) -> int64_t { + Function fadd1 = Function::FromTyped([](const Variant& a) -> int64_t { if (auto opt_int = a.as()) { return opt_int.value() + 1; } else { @@ -100,7 +100,7 @@ TEST(Variant, FromUnpacked) { }, ::tvm::ffi::Error); - Function fadd2 = Function::FromUnpacked([](const Array>& a) -> int64_t { + Function fadd2 = Function::FromTyped([](const Array>& a) -> int64_t { if (auto opt_int = a[0].as()) { return opt_int.value() + 1; } else { diff --git a/golang/src/gotvm.cc b/golang/src/gotvm.cc index 7bfaf51dbaaa..d8919dafbfcb 100644 --- a/golang/src/gotvm.cc +++ b/golang/src/gotvm.cc @@ -168,7 +168,7 @@ void _TVMValueNativeGet(void* to_ptr, void* from_ptr, int ind) { extern int goTVMCallback(void*, void*, int, void*, void*); /*! - * \brief _TVMCallback is the TVM runtime callback function for PackedFunction system. + * \brief _TVMCallback is the TVM runtime callback function for ffi::Functiontion system. * * \param args is an array of TVMValue * \param type_codes is an array of int diff --git a/include/tvm/ir/attrs.h b/include/tvm/ir/attrs.h index eef32b4773f0..6e004415f8e9 100644 --- a/include/tvm/ir/attrs.h +++ b/include/tvm/ir/attrs.h @@ -139,8 +139,6 @@ class AttrFieldInfo : public ObjectRef { */ class BaseAttrsNode : public Object { public: - using TVMArgs = runtime::TVMArgs; - using TVMRetValue = runtime::TVMRetValue; /*! \brief virtual destructor */ virtual ~BaseAttrsNode() {} // visit function @@ -176,7 +174,8 @@ class BaseAttrsNode : public Object { * \param allow_unknown Whether allow additional unknown fields. * \note This function throws when the required field is not present. */ - TVM_DLL virtual void InitByPackedArgs(const TVMArgs& kwargs, bool allow_unknown = false) = 0; + TVM_DLL virtual void InitByPackedArgs(const ffi::PackedArgs& kwargs, + bool allow_unknown = false) = 0; static constexpr const bool _type_has_method_sequal_reduce = true; static constexpr const bool _type_has_method_shash_reduce = true; @@ -213,7 +212,7 @@ class DictAttrsNode : public BaseAttrsNode { // implementations void VisitAttrs(AttrVisitor* v) final; void VisitNonDefaultAttrs(AttrVisitor* v) final; - void InitByPackedArgs(const runtime::TVMArgs& args, bool allow_unknown) final; + void InitByPackedArgs(const ffi::PackedArgs& args, bool allow_unknown) final; Array ListFieldInfo() const final; // type info @@ -450,7 +449,8 @@ inline TFunc WithoutAttr(TFunc input, const std::string& attr_key) { // Namespace containing detail implementations namespace detail { -using runtime::TVMArgValue; + +using tvm::ffi::AnyView; // helper entry that does nothing in set_default/bound/describe calls. struct AttrNopEntry { @@ -878,7 +878,7 @@ class AttrsNode : public BaseAttrsNode { self()->_tvm_VisitAttrs(vis); } - void InitByPackedArgs(const runtime::TVMArgs& args, bool allow_unknown) final { + void InitByPackedArgs(const ffi::PackedArgs& args, bool allow_unknown) final { ICHECK_EQ(args.size() % 2, 0); const int kLinearSearchBound = 16; int hit_count = 0; @@ -899,7 +899,7 @@ class AttrsNode : public BaseAttrsNode { hit_count = vis.hit_count_; } else { // construct a map then do lookup. - std::unordered_map kwargs; + std::unordered_map kwargs; for (int i = 0; i < args.size(); i += 2) { kwargs[args[i].cast()] = args[i + 1]; } @@ -959,8 +959,8 @@ class AttrsNode : public BaseAttrsNode { template inline void BaseAttrsNode::InitBySeq(Args&&... args) { - runtime::PackedFunc pf( - [this](const TVMArgs& args, TVMRetValue* rv) { this->InitByPackedArgs(args); }); + ffi::Function pf( + [this](const ffi::PackedArgs& args, ffi::Any* rv) { this->InitByPackedArgs(args); }); pf(std::forward(args)...); } diff --git a/include/tvm/ir/diagnostic.h b/include/tvm/ir/diagnostic.h index add96d713bf8..6527643fbf61 100644 --- a/include/tvm/ir/diagnostic.h +++ b/include/tvm/ir/diagnostic.h @@ -33,8 +33,6 @@ namespace tvm { -using tvm::runtime::TypedPackedFunc; - /*! \brief The diagnostic level, controls the printing of the message. */ enum class DiagnosticLevel : int { kBug = 10, @@ -165,7 +163,7 @@ class DiagnosticContext; */ class DiagnosticRendererNode : public Object { public: - TypedPackedFunc renderer; + ffi::TypedFunction renderer; // override attr visitor void VisitAttrs(AttrVisitor* v) {} @@ -176,9 +174,9 @@ class DiagnosticRendererNode : public Object { class DiagnosticRenderer : public ObjectRef { public: - TVM_DLL DiagnosticRenderer(TypedPackedFunc render); + TVM_DLL DiagnosticRenderer(ffi::TypedFunction render); TVM_DLL DiagnosticRenderer() - : DiagnosticRenderer(TypedPackedFunc()) {} + : DiagnosticRenderer(ffi::TypedFunction()) {} void Render(const DiagnosticContext& ctx); diff --git a/include/tvm/ir/env_func.h b/include/tvm/ir/env_func.h index c44e102ccadd..52fab116360c 100644 --- a/include/tvm/ir/env_func.h +++ b/include/tvm/ir/env_func.h @@ -34,7 +34,7 @@ namespace tvm { /*! * \brief A serializable function backed by TVM's global environment. * - * This is a wrapper to enable serializable global PackedFunc. + * This is a wrapper to enable serializable global ffi::Function. * An EnvFunc is saved by its name in the global registry * under the assumption that the same function is registered during load. * \sa EnvFunc @@ -44,7 +44,7 @@ class EnvFuncNode : public Object { /*! \brief Unique name of the global function */ String name; /*! \brief The internal packed function */ - runtime::PackedFunc func; + ffi::Function func; /*! \brief constructor */ EnvFuncNode() {} @@ -82,7 +82,7 @@ class EnvFunc : public ObjectRef { * \returns The return value. */ template - runtime::TVMRetValue operator()(Args&&... args) const { + ffi::Any operator()(Args&&... args) const { const EnvFuncNode* n = operator->(); ICHECK(n != nullptr); return n->func(std::forward(args)...); diff --git a/include/tvm/ir/function.h b/include/tvm/ir/function.h index cee94d37a5c0..e19e3f3af124 100644 --- a/include/tvm/ir/function.h +++ b/include/tvm/ir/function.h @@ -50,16 +50,16 @@ enum class CallingConv : int { */ kDefault = 0, /*! - * \brief PackedFunc that exposes a CPackedFunc signature. + * \brief ffi::Function that exposes a Cffi::Function signature. * - * - Calling by PackedFunc calling convention. - * - Implementation: Expose a function with the CPackedFunc signature. + * - Calling by ffi::Function calling convention. + * - Implementation: Expose a function with the Cffi::Function signature. */ kCPackedFunc = 1, /*! * \brief Device kernel launch * - * - Call by PackedFunc calling convention. + * - Call by ffi::Function calling convention. * - Implementation: defined by device runtime(e.g. runtime/cuda) */ kDeviceKernelLaunch = 2, diff --git a/include/tvm/ir/op.h b/include/tvm/ir/op.h index a37485700544..8eaa62a98120 100644 --- a/include/tvm/ir/op.h +++ b/include/tvm/ir/op.h @@ -269,7 +269,7 @@ class OpRegEntry { // return internal pointer to op. inline OpNode* get(); // update the attribute OpAttrMap - TVM_DLL void UpdateAttr(const String& key, runtime::TVMRetValue value, int plevel); + TVM_DLL void UpdateAttr(const String& key, ffi::Any value, int plevel); }; /*! diff --git a/include/tvm/meta_schedule/builder.h b/include/tvm/meta_schedule/builder.h index cb8bd5ed9a68..76883feda76c 100644 --- a/include/tvm/meta_schedule/builder.h +++ b/include/tvm/meta_schedule/builder.h @@ -118,7 +118,7 @@ class BuilderNode : public runtime::Object { * \param build_inputs The inputs to be built. * \return The build results. */ - using FBuild = runtime::TypedPackedFunc(const Array&)>; + using FBuild = ffi::TypedFunction(const Array&)>; static constexpr const char* _type_key = "meta_schedule.Builder"; TVM_DECLARE_BASE_OBJECT_INFO(BuilderNode, runtime::Object); diff --git a/include/tvm/meta_schedule/cost_model.h b/include/tvm/meta_schedule/cost_model.h index db0f896d77ed..48a27340bff1 100644 --- a/include/tvm/meta_schedule/cost_model.h +++ b/include/tvm/meta_schedule/cost_model.h @@ -86,12 +86,12 @@ class PyCostModelNode : public CostModelNode { * \brief Load the cost model from given file location. * \param path The file path. */ - using FLoad = runtime::TypedPackedFunc; + using FLoad = ffi::TypedFunction; /*! * \brief Save the cost model to given file location. * \param path The file path. */ - using FSave = runtime::TypedPackedFunc; + using FSave = ffi::TypedFunction; /*! * \brief Update the cost model given running results. * \param context The tuning context. @@ -99,21 +99,21 @@ class PyCostModelNode : public CostModelNode { * \param results The running results of the measure candidates. * \return Whether cost model was updated successfully. */ - using FUpdate = runtime::TypedPackedFunc&, - const Array&)>; + using FUpdate = ffi::TypedFunction&, + const Array&)>; /*! * \brief Predict the running results of given measure candidates. * \param context The tuning context. * \param candidates The measure candidates. * \param p_addr The address to save the estimated running results. */ - using FPredict = runtime::TypedPackedFunc&, - void* p_addr)>; + using FPredict = + ffi::TypedFunction&, void* p_addr)>; /*! * \brief Get the cost model as string with name. * \return The string representation of the cost model. */ - using FAsString = runtime::TypedPackedFunc; + using FAsString = ffi::TypedFunction; /*! \brief The packed function to the `Load` function. */ FLoad f_load; diff --git a/include/tvm/meta_schedule/database.h b/include/tvm/meta_schedule/database.h index 744e4cfd54c3..45c4a241e29d 100644 --- a/include/tvm/meta_schedule/database.h +++ b/include/tvm/meta_schedule/database.h @@ -301,30 +301,30 @@ class PyDatabaseNode : public DatabaseNode { * \param mod The IRModule to be searched for. * \return Whether the database has the given workload. */ - using FHasWorkload = runtime::TypedPackedFunc; + using FHasWorkload = ffi::TypedFunction; /*! * \brief The function type of `CommitWorkload` method. * \param mod The IRModule to be searched for or added. * \return The workload corresponding to the given IRModule. */ - using FCommitWorkload = runtime::TypedPackedFunc; + using FCommitWorkload = ffi::TypedFunction; /*! * \brief The function type of `CommitTuningRecord` method. * \param record The tuning record to be added. */ - using FCommitTuningRecord = runtime::TypedPackedFunc; + using FCommitTuningRecord = ffi::TypedFunction; /*! * \brief The function type of `GetTopK` method. * \param workload The workload to be searched for. * \param top_k The number of top records to be returned. * \return An array of top K tuning records for the given workload. */ - using FGetTopK = runtime::TypedPackedFunc(const Workload&, int)>; + using FGetTopK = ffi::TypedFunction(const Workload&, int)>; /*! * \brief The function type of `GetAllTuningRecords` method. * \return An Array of all the tuning records in the database. */ - using FGetAllTuningRecords = runtime::TypedPackedFunc()>; + using FGetAllTuningRecords = ffi::TypedFunction()>; /*! * \brief The function type of `QueryTuningRecord` method. * \param mod The IRModule to be searched for. @@ -332,8 +332,8 @@ class PyDatabaseNode : public DatabaseNode { * \param workload_name The name of the workload to be searched for. * \return The best record of the given workload; NullOpt if not found. */ - using FQueryTuningRecord = runtime::TypedPackedFunc( - const IRModule&, const Target&, const String&)>; + using FQueryTuningRecord = + ffi::TypedFunction(const IRModule&, const Target&, const String&)>; /*! * \brief The function type of `QuerySchedule` method. * \param mod The IRModule to be searched for. @@ -341,8 +341,8 @@ class PyDatabaseNode : public DatabaseNode { * \param workload_name The name of the workload to be searched for. * \return The schedule in the best schedule of the given workload; NullOpt if not found. */ - using FQuerySchedule = runtime::TypedPackedFunc( - const IRModule&, const Target&, const String&)>; + using FQuerySchedule = + ffi::TypedFunction(const IRModule&, const Target&, const String&)>; /*! * \brief The function type of `QueryIRModule` method. * \param mod The IRModule to be searched for. @@ -351,12 +351,12 @@ class PyDatabaseNode : public DatabaseNode { * \return The IRModule in the best IRModule of the given workload; NullOpt if not found. */ using FQueryIRModule = - runtime::TypedPackedFunc(const IRModule&, const Target&, const String&)>; + ffi::TypedFunction(const IRModule&, const Target&, const String&)>; /*! * \brief The function type of `Size` method. * \return The size of the database. */ - using FSize = runtime::TypedPackedFunc; + using FSize = ffi::TypedFunction; /*! \brief The packed function to the `HasWorkload` function. */ FHasWorkload f_has_workload; @@ -378,7 +378,7 @@ class PyDatabaseNode : public DatabaseNode { FSize f_size; void VisitAttrs(tvm::AttrVisitor* v) { - // PackedFuncs are all not visited, because the reflection system doesn't take care of them, + // ffi::Functions are all not visited, because the reflection system doesn't take care of them, // so it cannot be accessible on the python side. If there is such need from the future, // we can then add corresponding accessor methods to help access on python. // `f_has_workload` is not visited @@ -472,8 +472,8 @@ class Database : public runtime::ObjectRef { * and returns a boolean indicating if the schedule is successful. * \param mod_eq_name A string to specify the module equality testing and hashing method. */ - TVM_DLL static Database ScheduleFnDatabase( - runtime::TypedPackedFunc schedule_fn, String mod_eq_name = "structural"); + TVM_DLL static Database ScheduleFnDatabase(ffi::TypedFunction schedule_fn, + String mod_eq_name = "structural"); /*! * \brief Create a default database that uses JSON file for tuning records. * \param path_workload The path to the workload table. diff --git a/include/tvm/meta_schedule/feature_extractor.h b/include/tvm/meta_schedule/feature_extractor.h index 4165e5efe0fd..3e01faccaf28 100644 --- a/include/tvm/meta_schedule/feature_extractor.h +++ b/include/tvm/meta_schedule/feature_extractor.h @@ -63,13 +63,13 @@ class PyFeatureExtractorNode : public FeatureExtractorNode { * \param candidates The measure candidates to extract features from. * \return The feature ndarray extracted. */ - using FExtractFrom = runtime::TypedPackedFunc( + using FExtractFrom = ffi::TypedFunction( const TuneContext& context, const Array& candidates)>; /*! * \brief Get the feature extractor as string with name. * \return The string of the feature extractor. */ - using FAsString = runtime::TypedPackedFunc; + using FAsString = ffi::TypedFunction; /*! \brief The packed function to the `ExtractFrom` function. */ FExtractFrom f_extract_from; diff --git a/include/tvm/meta_schedule/measure_callback.h b/include/tvm/meta_schedule/measure_callback.h index 30d1c2cd3ee0..10356b6f5fb0 100644 --- a/include/tvm/meta_schedule/measure_callback.h +++ b/include/tvm/meta_schedule/measure_callback.h @@ -74,17 +74,16 @@ class PyMeasureCallbackNode : public MeasureCallbackNode { * \param results The runner results by running the built measure candidates. * \return Whether the measure callback was successfully applied. */ - using FApply = - runtime::TypedPackedFunc& measure_candidates, // - const Array& builds, // - const Array& results)>; + using FApply = ffi::TypedFunction& measure_candidates, // + const Array& builds, // + const Array& results)>; /*! * \brief Get the measure callback function as string with name. * \return The string of the measure callback function. */ - using FAsString = runtime::TypedPackedFunc; + using FAsString = ffi::TypedFunction; /*! \brief The packed function to the `Apply` function. */ FApply f_apply; diff --git a/include/tvm/meta_schedule/mutator.h b/include/tvm/meta_schedule/mutator.h index 08a8248dfdbc..f8bf69180db5 100644 --- a/include/tvm/meta_schedule/mutator.h +++ b/include/tvm/meta_schedule/mutator.h @@ -78,24 +78,24 @@ class Mutator : public runtime::ObjectRef { * \brief The function type of `InitializeWithTuneContext` method. * \param context The tuning context for initialization. */ - using FInitializeWithTuneContext = runtime::TypedPackedFunc; + using FInitializeWithTuneContext = ffi::TypedFunction; /*! * \brief Apply the mutator function to the given trace. * \param trace The given trace for mutation. * \return None if mutator failed, otherwise return the mutated trace. */ - using FApply = runtime::TypedPackedFunc( + using FApply = ffi::TypedFunction( const tir::Trace&, support::LinearCongruentialEngine::TRandState rand_state)>; /*! * \brief Clone the mutator. * \return The cloned mutator. */ - using FClone = runtime::TypedPackedFunc; + using FClone = ffi::TypedFunction; /*! * \brief Get the mutator as string with name. * \return The string of the mutator. */ - using FAsString = runtime::TypedPackedFunc; + using FAsString = ffi::TypedFunction; /*! \brief Create a Mutator that mutates the decision of instruction Sample-Perfect-Tile */ TVM_DLL static Mutator MutateTileSize(); /*! diff --git a/include/tvm/meta_schedule/postproc.h b/include/tvm/meta_schedule/postproc.h index 47b6b80ea43e..5a2b96caf81f 100644 --- a/include/tvm/meta_schedule/postproc.h +++ b/include/tvm/meta_schedule/postproc.h @@ -75,23 +75,23 @@ class Postproc : public runtime::ObjectRef { * \brief The function type of `InitializeWithTuneContext` method. * \param context The tuning context for initialization. */ - using FInitializeWithTuneContext = runtime::TypedPackedFunc; + using FInitializeWithTuneContext = ffi::TypedFunction; /*! * \brief Apply a postprocessor to the given schedule. * \param sch The schedule to be post processed. * \return Whether the postprocessor was successfully applied. */ - using FApply = runtime::TypedPackedFunc; + using FApply = ffi::TypedFunction; /*! * \brief Clone the postprocessor. * \return The cloned postprocessor. */ - using FClone = runtime::TypedPackedFunc; + using FClone = ffi::TypedFunction; /*! * \brief Get the postprocessor function as string with name. * \return The string of the postprocessor function. */ - using FAsString = runtime::TypedPackedFunc; + using FAsString = ffi::TypedFunction; /*! * \brief Create a postprocessor with customized methods on the python-side. * \param f_initialize_with_tune_context The packed function of `InitializeWithTuneContext`. diff --git a/include/tvm/meta_schedule/profiler.h b/include/tvm/meta_schedule/profiler.h index 0f6572cca98b..91b7bfc45c09 100644 --- a/include/tvm/meta_schedule/profiler.h +++ b/include/tvm/meta_schedule/profiler.h @@ -47,8 +47,8 @@ class ScopedTimer { private: friend class Profiler; - explicit ScopedTimer(runtime::TypedPackedFunc deferred) : deferred_(deferred) {} - runtime::TypedPackedFunc deferred_; + explicit ScopedTimer(ffi::TypedFunction deferred) : deferred_(deferred) {} + ffi::TypedFunction deferred_; }; /*! \brief A generic profiler */ @@ -57,7 +57,7 @@ class ProfilerNode : public runtime::Object { /*! \brief The segments that are already profiled */ std::unordered_map stats_sec; /*! \brief Counter for the total time used */ - runtime::PackedFunc total_timer; + ffi::Function total_timer; void VisitAttrs(tvm::AttrVisitor* v) { // `stats_sec` is not visited. diff --git a/include/tvm/meta_schedule/runner.h b/include/tvm/meta_schedule/runner.h index c09572836931..0335f81cc16c 100644 --- a/include/tvm/meta_schedule/runner.h +++ b/include/tvm/meta_schedule/runner.h @@ -110,12 +110,12 @@ class RunnerFutureNode : public runtime::Object { * \brief The function type to check whether the runner has finished. * \return Whether the runner's output is ready. */ - using FDone = runtime::TypedPackedFunc; + using FDone = ffi::TypedFunction; /*! * \brief The function type to fetch runner output if it is ready. * \return The runner's output. */ - using FResult = runtime::TypedPackedFunc; + using FResult = ffi::TypedFunction; /*! \brief The packed function to check whether the runner has finished. */ FDone f_done; @@ -176,7 +176,7 @@ class RunnerNode : public runtime::Object { * \return The runner futures. * \sa RunnerFuture */ - using FRun = runtime::TypedPackedFunc(Array)>; + using FRun = ffi::TypedFunction(Array)>; /*! \brief Default destructor */ virtual ~RunnerNode() = default; diff --git a/include/tvm/meta_schedule/schedule_rule.h b/include/tvm/meta_schedule/schedule_rule.h index b4da978b56c2..974254afc1b8 100644 --- a/include/tvm/meta_schedule/schedule_rule.h +++ b/include/tvm/meta_schedule/schedule_rule.h @@ -80,7 +80,7 @@ class ScheduleRule : public runtime::ObjectRef { * \brief The function type of `InitializeWithTuneContext` method. * \param context The tuning context for initialization. */ - using FInitializeWithTuneContext = runtime::TypedPackedFunc; + using FInitializeWithTuneContext = ffi::TypedFunction; /*! * \brief The function type of `Apply` method. * \param sch The schedule to be modified. @@ -88,17 +88,17 @@ class ScheduleRule : public runtime::ObjectRef { * \return The list of schedules generated by applying the schedule rule. */ using FApply = - runtime::TypedPackedFunc(const tir::Schedule&, const tir::BlockRV&)>; + ffi::TypedFunction(const tir::Schedule&, const tir::BlockRV&)>; /*! * \brief Get the schedule rule as string with name. * \return The string of the schedule rule. */ - using FAsString = runtime::TypedPackedFunc; + using FAsString = ffi::TypedFunction; /*! * \brief The function type of `Clone` method. * \return The cloned schedule rule. */ - using FClone = runtime::TypedPackedFunc; + using FClone = ffi::TypedFunction; /*! * \brief Create a rule that applies customized rules registered using block attribute * `schedule_rule`. The rule will be dispatched according to target keys. @@ -160,7 +160,7 @@ class ScheduleRule : public runtime::ObjectRef { Optional> vector_load_lens, // Optional> reuse_read, // Optional> reuse_write, - Optional filter_fn = NullOpt); + Optional filter_fn = NullOpt); /*! * \brief Extension of MultiLevelTiling for auto-tensorization with a single intrinsic. diff --git a/include/tvm/meta_schedule/search_strategy.h b/include/tvm/meta_schedule/search_strategy.h index aeef1bff306c..ca7ee7ec8407 100644 --- a/include/tvm/meta_schedule/search_strategy.h +++ b/include/tvm/meta_schedule/search_strategy.h @@ -143,31 +143,31 @@ class SearchStrategy : public runtime::ObjectRef { * \brief The function type of `InitializeWithTuneContext` method. * \param context The tuning context for initialization. */ - using FInitializeWithTuneContext = runtime::TypedPackedFunc; + using FInitializeWithTuneContext = ffi::TypedFunction; /*! * \brief The function type of `PreTuning` method. */ - using FPreTuning = runtime::TypedPackedFunc&, - const Optional&, const Optional&)>; + using FPreTuning = + ffi::TypedFunction&, + const Optional&, const Optional&)>; /*! \brief The function type of `PostTuning` method. */ - using FPostTuning = runtime::TypedPackedFunc; + using FPostTuning = ffi::TypedFunction; /*! * \brief The function type of `GenerateMeasureCandidates` method. * \return The measure candidates generated, nullptr if finished. */ - using FGenerateMeasureCandidates = runtime::TypedPackedFunc>()>; + using FGenerateMeasureCandidates = ffi::TypedFunction>()>; /*! * \brief The function type of `NotifyRunnerResults` method. * \param results The measurement results from the runner. */ using FNotifyRunnerResults = - runtime::TypedPackedFunc&, const Array&)>; + ffi::TypedFunction&, const Array&)>; /*! * \brief The function type of `Clone` method. * \return The cloned search strategy. */ - using FClone = runtime::TypedPackedFunc; + using FClone = ffi::TypedFunction; /*! * \brief Create a search strategy with customized methods on the python-side. * \param f_initialize_with_tune_context The packed function of `InitializeWithTuneContext`. diff --git a/include/tvm/meta_schedule/space_generator.h b/include/tvm/meta_schedule/space_generator.h index 650320d1e21c..b626b3e7739f 100644 --- a/include/tvm/meta_schedule/space_generator.h +++ b/include/tvm/meta_schedule/space_generator.h @@ -125,18 +125,18 @@ class SpaceGenerator : public runtime::ObjectRef { * \brief The function type of `InitializeWithTuneContext` method. * \param context The tuning context for initialization. */ - using FInitializeWithTuneContext = runtime::TypedPackedFunc; + using FInitializeWithTuneContext = ffi::TypedFunction; /*! * \brief The function type of `GenerateDesignSpace` method. * \param mod The module used for design space generation. * \return The generated design spaces, i.e., schedules. */ - using FGenerateDesignSpace = runtime::TypedPackedFunc(const IRModule&)>; + using FGenerateDesignSpace = ffi::TypedFunction(const IRModule&)>; /*! * \brief The function type of `Clone` method. * \return The cloned space generator. */ - using FClone = runtime::TypedPackedFunc; + using FClone = ffi::TypedFunction; protected: SpaceGenerator() = default; @@ -167,7 +167,7 @@ class SpaceGenerator : public runtime::ObjectRef { * \param postprocs The postprocessors. * \param mutator_probs The probability of using certain mutator. */ - TVM_DLL static SpaceGenerator ScheduleFn(PackedFunc schedule_fn, + TVM_DLL static SpaceGenerator ScheduleFn(ffi::Function schedule_fn, Optional> sch_rules, Optional> postprocs, Optional> mutator_probs); @@ -192,7 +192,7 @@ class SpaceGenerator : public runtime::ObjectRef { * \param mutator_probs The probability of using certain mutator. * \return The design space generator created. */ - TVM_DLL static SpaceGenerator PostOrderApply(runtime::PackedFunc f_block_filter, + TVM_DLL static SpaceGenerator PostOrderApply(ffi::Function f_block_filter, Optional> sch_rules, Optional> postprocs, Optional> mutator_probs); diff --git a/include/tvm/meta_schedule/task_scheduler.h b/include/tvm/meta_schedule/task_scheduler.h index 8cc3595d68d8..e75059116dc2 100644 --- a/include/tvm/meta_schedule/task_scheduler.h +++ b/include/tvm/meta_schedule/task_scheduler.h @@ -128,7 +128,7 @@ class TaskRecord : public runtime::ObjectRef { class TaskSchedulerNode : public runtime::Object { public: /*! \brief The tuning task's logging function. */ - PackedFunc logger; + ffi::Function logger; /*! \brief Records for each task */ Array tasks_; /*! \brief The list of measure callbacks of the scheduler. */ @@ -212,23 +212,23 @@ class PyTaskSchedulerNode : public TaskSchedulerNode { * \brief The function type of `NextTaskId` method. * \return The next task id. */ - using FNextTaskId = runtime::TypedPackedFunc; + using FNextTaskId = ffi::TypedFunction; /*! * \brief The function type of `JoinRunningTask` method. * \param task_id The task id to be joined. */ - using FJoinRunningTask = runtime::TypedPackedFunc(int)>; + using FJoinRunningTask = ffi::TypedFunction(int)>; /*! \brief The function type of `Tune` method. */ - using FTune = runtime::TypedPackedFunc tasks, // - Array task_weights, // - int max_trials_global, // - int max_trials_per_task, // - int num_trials_per_iter, // - Builder builder, // - Runner runner, // - Array measure_callbacks, // - Optional database, // - Optional cost_model)>; + using FTune = ffi::TypedFunction tasks, // + Array task_weights, // + int max_trials_global, // + int max_trials_per_task, // + int num_trials_per_iter, // + Builder builder, // + Runner runner, // + Array measure_callbacks, // + Optional database, // + Optional cost_model)>; /*! \brief The packed function to the `NextTaskId` function. */ FNextTaskId f_next_task_id; @@ -266,7 +266,7 @@ class TaskScheduler : public runtime::ObjectRef { * \param logger The tuning task's logging function. * \return The task scheduler created. */ - TVM_DLL static TaskScheduler RoundRobin(PackedFunc logger); + TVM_DLL static TaskScheduler RoundRobin(ffi::Function logger); /*! * \brief Create a task scheduler that fetches tasks in a gradient based fashion. * \param logger The tuning task's logging function. @@ -275,7 +275,7 @@ class TaskScheduler : public runtime::ObjectRef { * \param seed The random seed. * \return The task scheduler created. */ - TVM_DLL static TaskScheduler GradientBased(PackedFunc logger, double alpha, int window_size, + TVM_DLL static TaskScheduler GradientBased(ffi::Function logger, double alpha, int window_size, support::LinearCongruentialEngine::TRandState seed); /*! * \brief Create a task scheduler with customized methods on the python-side. @@ -286,7 +286,7 @@ class TaskScheduler : public runtime::ObjectRef { * \return The task scheduler created. */ TVM_DLL static TaskScheduler PyTaskScheduler( - PackedFunc logger, PyTaskSchedulerNode::FNextTaskId f_next_task_id, + ffi::Function logger, PyTaskSchedulerNode::FNextTaskId f_next_task_id, PyTaskSchedulerNode::FJoinRunningTask f_join_running_task, PyTaskSchedulerNode::FTune f_tune); TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(TaskScheduler, ObjectRef, TaskSchedulerNode); }; diff --git a/include/tvm/meta_schedule/tune_context.h b/include/tvm/meta_schedule/tune_context.h index 15f3cba30b95..9eacf499f405 100644 --- a/include/tvm/meta_schedule/tune_context.h +++ b/include/tvm/meta_schedule/tune_context.h @@ -62,7 +62,7 @@ class TuneContextNode : public runtime::Object { /*! \brief The random state. */ TRandState rand_state; /*! \brief The tuning task's logging function. t*/ - PackedFunc logger; + ffi::Function logger; void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("mod", &mod); @@ -109,7 +109,7 @@ class TuneContext : public runtime::ObjectRef { TVM_DLL explicit TuneContext(Optional mod, Optional target, Optional space_generator, Optional search_strategy, Optional task_name, - int num_threads, TRandState rand_state, PackedFunc logger); + int num_threads, TRandState rand_state, ffi::Function logger); TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(TuneContext, ObjectRef, TuneContextNode); }; diff --git a/include/tvm/node/attr_registry_map.h b/include/tvm/node/attr_registry_map.h index 909d0ad0d4f9..4e075c7e56da 100644 --- a/include/tvm/node/attr_registry_map.h +++ b/include/tvm/node/attr_registry_map.h @@ -55,7 +55,7 @@ class AttrRegistryMapContainerMap { * \param key The key to the map * \return the const reference to the content value. */ - const runtime::TVMRetValue& operator[](const KeyType& key) const { + const ffi::Any& operator[](const KeyType& key) const { ICHECK(key.defined()); const uint32_t idx = key->AttrRegistryIndex(); ICHECK(idx < data_.size() && data_[idx].second != 0) @@ -88,7 +88,7 @@ class AttrRegistryMapContainerMap { /*! \brief The name of the attr field */ String attr_name_; /*! \brief The internal data. */ - std::vector> data_; + std::vector> data_; /*! \brief The constructor */ AttrRegistryMapContainerMap() = default; template diff --git a/include/tvm/node/node.h b/include/tvm/node/node.h index ad4fb1e1c27a..12598cb156c2 100644 --- a/include/tvm/node/node.h +++ b/include/tvm/node/node.h @@ -49,18 +49,18 @@ namespace tvm { +using ffi::Any; +using ffi::AnyView; +using ffi::Object; +using ffi::ObjectPtr; +using ffi::ObjectPtrEqual; +using ffi::ObjectPtrHash; +using ffi::ObjectRef; +using ffi::PackedArgs; +using ffi::TypeIndex; using runtime::Downcast; using runtime::GetRef; using runtime::make_object; -using runtime::Object; -using runtime::ObjectPtr; -using runtime::ObjectPtrEqual; -using runtime::ObjectPtrHash; -using runtime::ObjectRef; -using runtime::PackedFunc; -using runtime::TVMArgs; -using runtime::TVMRetValue; -using runtime::TypeIndex; } // namespace tvm #endif // TVM_NODE_NODE_H_ diff --git a/include/tvm/node/reflection.h b/include/tvm/node/reflection.h index 0ee09e70f474..0938f2c56ad2 100644 --- a/include/tvm/node/reflection.h +++ b/include/tvm/node/reflection.h @@ -156,7 +156,7 @@ class ReflectionVTable { * \param kwargs the arguments in format key1, value1, ..., key_n, value_n. * \return The created object. */ - TVM_DLL ObjectRef CreateObject(const std::string& type_key, const runtime::TVMArgs& kwargs); + TVM_DLL ObjectRef CreateObject(const std::string& type_key, const ffi::PackedArgs& kwargs); /*! * \brief Create an object by giving kwargs about its fields. * @@ -172,7 +172,7 @@ class ReflectionVTable { * \return The corresponding attribute value. * \note This function will throw an exception if the object does not contain the field. */ - TVM_DLL runtime::TVMRetValue GetAttr(Object* self, const String& attr_name) const; + TVM_DLL ffi::Any GetAttr(Object* self, const String& attr_name) const; /*! * \brief List all the fields in the object. diff --git a/include/tvm/relax/analysis.h b/include/tvm/relax/analysis.h index 48cc35fcb886..85a7d0eee356 100644 --- a/include/tvm/relax/analysis.h +++ b/include/tvm/relax/analysis.h @@ -214,7 +214,7 @@ enum class BaseCheckResult { * - (b) We automatically insert match_cast at function boundary, so * we can erase (int)->int argument as (object)->int. * The input shape/type mismatch will be detected by runtime checks at function boundary. - * This behavior is also consistent with the PackedFunc behavior. + * This behavior is also consistent with the ffi::Function behavior. * * \note This level means there is no problem about static known information. * It is OK for the checker to do best effort and return this value. diff --git a/include/tvm/relax/dataflow_matcher.h b/include/tvm/relax/dataflow_matcher.h index 15d1c9c5fbda..fd2fa72a2410 100644 --- a/include/tvm/relax/dataflow_matcher.h +++ b/include/tvm/relax/dataflow_matcher.h @@ -70,8 +70,7 @@ TVM_DLL Optional> MatchGraph(const PatternContext& ctx, */ TVM_DLL Function RewriteBindings( const PatternContext& ctx, - runtime::TypedPackedFunc(Map, Map)> rewriter, - Function f); + ffi::TypedFunction(Map, Map)> rewriter, Function f); /** * \brief Rewrite a function with the given pattern and the rewriter function. @@ -97,7 +96,7 @@ TVM_DLL Function RewriteBindings( * \return The updated function, if any updates were applied. */ TVM_DLL Function RewriteCall(const DFPattern& pattern, - runtime::TypedPackedFunc)> rewriter, + ffi::TypedFunction)> rewriter, Function func); } // namespace relax diff --git a/include/tvm/relax/exec_builder.h b/include/tvm/relax/exec_builder.h index c5103719d028..2cee3bca631b 100644 --- a/include/tvm/relax/exec_builder.h +++ b/include/tvm/relax/exec_builder.h @@ -118,7 +118,7 @@ class ExecBuilderNode : public Object { */ template vm::Instruction::Arg ConvertConstant(T value) { - TVMRetValue rv; + ffi::Any rv; rv = value; return ConvertConstant_(rv); } @@ -151,7 +151,7 @@ class ExecBuilderNode : public Object { * \param obj The constant value to be emitted * \return An Arg that represents the result of constant argument. */ - vm::Instruction::Arg ConvertConstant_(TVMRetValue obj); + vm::Instruction::Arg ConvertConstant_(ffi::Any obj); /*! * \brief A helper function to check if an executable is legal by checking if registers are used diff --git a/include/tvm/relax/op_attr_types.h b/include/tvm/relax/op_attr_types.h index 0c215f023e28..bd9c59da3acb 100644 --- a/include/tvm/relax/op_attr_types.h +++ b/include/tvm/relax/op_attr_types.h @@ -59,8 +59,7 @@ enum OpPatternKind { * \param call The call expression to be derived. * \param ctx The builder context. */ -using FInferStructInfo = - runtime::TypedPackedFunc; +using FInferStructInfo = ffi::TypedFunction; /*! * \brief Packed function implementation for operators. The relax operator will be lowered to @@ -88,7 +87,7 @@ using FCallPacked = String; * \param call The call to be normalized. It is provided by-value, to * avoid copies for the common case where the call is already normalized. */ -using FNormalize = runtime::TypedPackedFunc; +using FNormalize = ffi::TypedFunction; /*! * \brief The function type of a validation function. @@ -107,7 +106,7 @@ using FNormalize = runtime::TypedPackedFunc; +using FValidate = ffi::TypedFunction; /*! \brief The function type of a legalization function. * @@ -123,7 +122,7 @@ using FValidate = runtime::TypedPackedFunc; * \param bb The BlockBuilder context. * \param call The call to be legalized. */ -using FLegalize = runtime::TypedPackedFunc; +using FLegalize = ffi::TypedFunction; /*! \brief The function type of a function to lower the runtime builtin. * @@ -132,7 +131,7 @@ using FLegalize = runtime::TypedPackedFunc; +using FLowerBuiltin = ffi::TypedFunction; /*! * \brief Gradient for a specific op. @@ -143,7 +142,7 @@ using FLowerBuiltin = runtime::TypedPackedFunc( +using FPrimalGradient = ffi::TypedFunction( const Var& orig_var, const Call& orig_call, const Var& output_grad, const BlockBuilder& ctx)>; } // namespace relax diff --git a/include/tvm/relax/tir_pattern.h b/include/tvm/relax/tir_pattern.h index 1b29cb03582d..6a00487d69a0 100644 --- a/include/tvm/relax/tir_pattern.h +++ b/include/tvm/relax/tir_pattern.h @@ -69,7 +69,7 @@ class MatchResult : public ObjectRef { TVM_DEFINE_OBJECT_REF_METHODS(MatchResult, ObjectRef, MatchResultNode); }; -using FCodegen = runtime::TypedPackedFunc(Array match_results)>; +using FCodegen = ffi::TypedFunction(Array match_results)>; } // namespace relax } // namespace tvm #endif // TVM_RELAX_TIR_PATTERN_H_ diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h index 2da2ba53c701..98aa2673c23b 100644 --- a/include/tvm/relax/transform.h +++ b/include/tvm/relax/transform.h @@ -246,7 +246,7 @@ TVM_DLL Pass FoldConstant(); * showing up in the database. * \return The Pass. */ -TVM_DLL Pass LegalizeOps(Optional> cmap, bool enable_warning = false); +TVM_DLL Pass LegalizeOps(Optional> cmap, bool enable_warning = false); /*! * \brief Propagate virtual device information. @@ -383,7 +383,7 @@ class FusionPatternNode : public Object { * It should have signature * bool(const PatternCheckContext& context) */ - Optional check; + Optional check; /*! * \brief The function to get attributes for fused function @@ -391,7 +391,7 @@ class FusionPatternNode : public Object { * It should have signature * Map(const Map& context) */ - Optional attrs_getter; + Optional attrs_getter; void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("name", &name); @@ -408,7 +408,7 @@ class FusionPatternNode : public Object { class FusionPattern : public ObjectRef { public: FusionPattern(String name, DFPattern pattern, Map annotation_patterns, - Optional check, Optional attrs_getter); + Optional check, Optional attrs_getter); FusionPattern(String name, DFPattern pattern) : FusionPattern(name, pattern, {}, NullOpt, NullOpt) {} diff --git a/include/tvm/relax/tuning_api.h b/include/tvm/relax/tuning_api.h index bcbfad2c7ac9..c18a8cfb54a7 100644 --- a/include/tvm/relax/tuning_api.h +++ b/include/tvm/relax/tuning_api.h @@ -35,8 +35,8 @@ namespace relax { /*! \brief Helper function to unpack arguments in the array as parameters for the given packed * function. */ -TVM_ALWAYS_INLINE TVMRetValue CallPackedWithArgsInArray(const runtime::PackedFunc f, - const Array& args) { +TVM_ALWAYS_INLINE ffi::Any CallPackedWithArgsInArray(const ffi::Function f, + const Array& args) { size_t num_args = args.size(); std::vector packed_args(num_args); for (size_t i = 0; i < num_args; ++i) { @@ -68,14 +68,14 @@ class ChoiceNode : public runtime::Object { } /*! \brief Getter for constr_func. */ - const runtime::PackedFunc GetConstrFunc() { + const ffi::Function GetConstrFunc() { const auto constr_func = tvm::ffi::Function::GetGlobal(constr_func_key); ICHECK(constr_func.has_value()) << "constr_func_key is not registered: " << constr_func_key; return *std::move(constr_func); } /*! \brief Getter for transform_func. */ - const runtime::PackedFunc GetTransformFunc() { + const ffi::Function GetTransformFunc() { auto transform_func = tvm::ffi::Function::GetGlobal(transform_func_key); ICHECK(transform_func.has_value()) << "transform_func_key is not registered: " << transform_func_key; diff --git a/include/tvm/runtime/c_backend_api.h b/include/tvm/runtime/c_backend_api.h index 8fde5948f993..eb8d7270b137 100644 --- a/include/tvm/runtime/c_backend_api.h +++ b/include/tvm/runtime/c_backend_api.h @@ -105,7 +105,7 @@ TVM_DLL int TVMBackendFreeWorkspace(int device_type, int device_id, void* ptr); * specific C APIs. * * \note We only register the C API function when absolutely necessary (e.g. when signal handler - * cannot trap back into python). In most cases we should use the PackedFunc FFI. + * cannot trap back into python). In most cases we should use the ffi::Function FFI. * * \param name The name of the symbol * \param ptr The symbol address. diff --git a/include/tvm/runtime/c_runtime_api.h b/include/tvm/runtime/c_runtime_api.h index 8238c7f148a2..b802dbc22839 100644 --- a/include/tvm/runtime/c_runtime_api.h +++ b/include/tvm/runtime/c_runtime_api.h @@ -180,7 +180,7 @@ typedef enum { kTVMNDArrayHandle = 13U, kTVMObjectRValueRefArg = 14U, kTVMArgBool = 15U, - // Extension codes for other frameworks to integrate TVM PackedFunc. + // Extension codes for other frameworks to integrate TVM ffi::Function. // To make sure each framework's id do not conflict, use first and // last sections to mark ranges. // Open an issue at the repo if you need a section of code. diff --git a/include/tvm/runtime/disco/disco_worker.h b/include/tvm/runtime/disco/disco_worker.h index c9c85b7dbfed..c7aeb4e284ad 100644 --- a/include/tvm/runtime/disco/disco_worker.h +++ b/include/tvm/runtime/disco/disco_worker.h @@ -65,7 +65,7 @@ class DiscoWorker { /*! \brief Get the worker instance on the current thread */ TVM_DLL static DiscoWorker* ThreadLocal(); /*! \brief Set the specific register to a specific value */ - void SetRegister(int reg_id, TVMArgValue value); + void SetRegister(int reg_id, ffi::AnyView value); /*! \brief The id of the worker.*/ int worker_id; @@ -92,7 +92,7 @@ class DiscoWorker { */ DiscoChannel* channel; /*! \brief The registers in the worker */ - std::vector register_file; + std::vector register_file; struct Impl; friend struct DiscoWorker::Impl; diff --git a/include/tvm/runtime/disco/session.h b/include/tvm/runtime/disco/session.h index 32f148853073..fb21f79882ad 100644 --- a/include/tvm/runtime/disco/session.h +++ b/include/tvm/runtime/disco/session.h @@ -51,7 +51,7 @@ * * **Control plane.** The controler broadcasts commands to all the workers as control signals. * For example, the control may ask all workers to load a library or call a function respectively. - * Common control signals include: shutdown, retrievel a global PackedFunc, call packed function, + * Common control signals include: shutdown, retrievel a global ffi::Function, call packed function, * etc. The controler is assumed to keep a message channel to each worker to implement the broadcast * behavior, and the message channel may vary depends on usecases. * @@ -67,7 +67,7 @@ * * **Channel.** Disco channel is a bi-directional communication channel between the controler and * workers for exchanging control signals. It is no different from a generic RPC channel, but - * adopts TVM's PackedFunc calling convention to support polymorphic and variadic arguments. + * adopts TVM's ffi::Function calling convention to support polymorphic and variadic arguments. */ #ifndef TVM_RUNTIME_DISCO_SESSION_H_ #define TVM_RUNTIME_DISCO_SESSION_H_ @@ -138,13 +138,13 @@ class DRefObj : public Object { * \param worker_id The id of the worker to be fetched from. * \return The value of the register. */ - inline TVMRetValue DebugGetFromRemote(int worker_id); + inline ffi::Any DebugGetFromRemote(int worker_id); /*! * \brief Copy from the NDArray provided to a remote worker. * \param worker_id The id of the worker to be copied to. * \param source The NDArray to be copied. */ - inline void DebugCopyFrom(int worker_id, TVMArgValue source); + inline void DebugCopyFrom(int worker_id, ffi::AnyView source); static constexpr const char* _type_key = "runtime.disco.DRef"; static constexpr const uint32_t _type_index = TypeIndex::kRuntimeDiscoDRef; @@ -169,13 +169,13 @@ class DRef : public ObjectRef { /*! * \brief A Disco interactive session. It allows users to interact with the Disco command queue with - * various PackedFunc calling convention. + * various ffi::Function calling convention. */ class SessionObj : public Object { public: virtual ~SessionObj() = default; /*! - * \brief Call a PackedFunc on workers providing variadic arguments. + * \brief Call a ffi::Function on workers providing variadic arguments. * \tparam Args In the variadic arguments, the supported types include: * - integers and floating point numbers; * - DataType; @@ -184,7 +184,7 @@ class SessionObj : public Object { * - DRef. * Examples of unsupported types: * - NDArray, DLTensor; - * - TVM Objects, including PackedFunc, Module and String; + * - TVM Objects, including ffi::Function, Module and String; * \param func The function to be called. * \param args The variadic arguments. * \return The return value of function call @@ -197,7 +197,7 @@ class SessionObj : public Object { * The second element must be 0, which will later be updated by the session to return reg_id * The thirtd element is the function to be called. */ - TVM_DLL virtual DRef CallWithPacked(const TVMArgs& args) = 0; + TVM_DLL virtual DRef CallWithPacked(const ffi::PackedArgs& args) = 0; /*! \brief Get the number of workers in the session. */ TVM_DLL virtual int64_t GetNumWorkers() = 0; /*! \brief Get a global functions on workers. */ @@ -236,7 +236,7 @@ class SessionObj : public Object { * \param worker_id The id of the worker to be fetched from. * \return The value of the register. */ - TVM_DLL virtual TVMRetValue DebugGetFromRemote(int64_t reg_id, int worker_id) = 0; + TVM_DLL virtual ffi::Any DebugGetFromRemote(int64_t reg_id, int worker_id) = 0; /*! * \brief Set the value of a register on a remote worker. * \param reg_id The id of the register to be set. @@ -273,7 +273,7 @@ class Session : public ObjectRef { * \param num_workers The number of workers. * \param num_groups The number of worker groups. * \param process_pool_creator The name of a global function that takes `num_workers` as an input, - * and returns a PackedFunc, which takes an integer `worker_id` as the input and returns None. + * and returns a ffi::Function, which takes an integer `worker_id` as the input and returns None. * When `worker-id` is 0, it shuts down the process pool; Otherwise, it retursn a tuple * (read_fd, writefd) used to communicate with the corresponding worker. * \param entrypoint The entrypoint of DiscoWorker main worker function. @@ -294,13 +294,13 @@ class DiscoChannel { public: virtual ~DiscoChannel() = default; /*! \brief Send a packed sequence to the receiver */ - virtual void Send(const TVMArgs& args) = 0; + virtual void Send(const ffi::PackedArgs& args) = 0; /*! \brief Receive a packed sequence from worker */ - virtual TVMArgs Recv() = 0; + virtual ffi::PackedArgs Recv() = 0; /*! \brief Reply a packed sequence to the sender */ - virtual void Reply(const TVMArgs& args) = 0; + virtual void Reply(const ffi::PackedArgs& args) = 0; /*! \brief Receive a reply from the worker */ - virtual TVMArgs RecvReply() = 0; + virtual ffi::PackedArgs RecvReply() = 0; }; /*! @@ -326,11 +326,11 @@ DRefObj::~DRefObj() { } } -TVMRetValue DRefObj::DebugGetFromRemote(int worker_id) { +ffi::Any DRefObj::DebugGetFromRemote(int worker_id) { return Downcast(this->session)->DebugGetFromRemote(this->reg_id, worker_id); } -void DRefObj::DebugCopyFrom(int worker_id, TVMArgValue value) { +void DRefObj::DebugCopyFrom(int worker_id, ffi::AnyView value) { return Downcast(this->session)->DebugSetRegister(this->reg_id, value, worker_id); } diff --git a/include/tvm/runtime/module.h b/include/tvm/runtime/module.h index 69e8d4283ede..37ab906dd422 100644 --- a/include/tvm/runtime/module.h +++ b/include/tvm/runtime/module.h @@ -42,8 +42,6 @@ namespace tvm { namespace runtime { -using PackedFunc = ffi::Function; - /*! * \brief Property of runtime module * We classify the property of runtime module into the following categories. @@ -89,10 +87,10 @@ class Module : public ObjectRef { * \param name The name of the function. * \param query_imports Whether also query dependency modules. * \return The result function. - * This function will return PackedFunc(nullptr) if function do not exist. + * This function will return ffi::Function(nullptr) if function do not exist. * \note Implemented in packed_func.cc */ - inline PackedFunc GetFunction(const String& name, bool query_imports = false); + inline ffi::Function GetFunction(const String& name, bool query_imports = false); // The following functions requires link with runtime. /*! * \brief Import another module into this module. @@ -151,9 +149,9 @@ class TVM_DLL ModuleNode : public Object { */ virtual const char* type_key() const = 0; /*! - * \brief Get a PackedFunc from module. + * \brief Get a ffi::Function from module. * - * The PackedFunc may not be fully initialized, + * The ffi::Function may not be fully initialized, * there might still be first time running overhead when * executing the function on certain devices. * For benchmarking, use prepare to eliminate @@ -161,13 +159,13 @@ class TVM_DLL ModuleNode : public Object { * \param name the name of the function. * \param sptr_to_self The ObjectPtr that points to this module node. * - * \return PackedFunc(nullptr) when it is not available. + * \return ffi::Function(nullptr) when it is not available. * * \note The function will always remain valid. * If the function need resource from the module(e.g. late linking), * it should capture sptr_to_self. */ - virtual PackedFunc GetFunction(const String& name, const ObjectPtr& sptr_to_self) = 0; + virtual ffi::Function GetFunction(const String& name, const ObjectPtr& sptr_to_self) = 0; /*! * \brief Save the module to file. * \param file_name The file to be saved to. @@ -199,10 +197,10 @@ class TVM_DLL ModuleNode : public Object { * \param name The name of the function. * \param query_imports Whether also query dependency modules. * \return The result function. - * This function will return PackedFunc(nullptr) if function do not exist. + * This function will return ffi::Function(nullptr) if function do not exist. * \note Implemented in packed_func.cc */ - PackedFunc GetFunction(const String& name, bool query_imports = false); + ffi::Function GetFunction(const String& name, bool query_imports = false); /*! * \brief Import another module into this module. * \param other The module to be imported. @@ -218,7 +216,7 @@ class TVM_DLL ModuleNode : public Object { * \param name name of the function. * \return The corresponding function. */ - const PackedFunc* GetFuncFromEnv(const String& name); + const ffi::Function* GetFuncFromEnv(const String& name); /*! \brief Clear all imports of the module. */ void ClearImports() { imports_.clear(); } @@ -269,7 +267,7 @@ class TVM_DLL ModuleNode : public Object { private: /*! \brief Cache used by GetImport */ - std::unordered_map> import_cache_; + std::unordered_map> import_cache_; std::mutex mutex_; }; @@ -281,13 +279,13 @@ class TVM_DLL ModuleNode : public Object { TVM_DLL bool RuntimeEnabled(const String& target); // implementation of Module::GetFunction -inline PackedFunc Module::GetFunction(const String& name, bool query_imports) { +inline ffi::Function Module::GetFunction(const String& name, bool query_imports) { return (*this)->GetFunction(name, query_imports); } /*! \brief namespace for constant symbols */ namespace symbol { -/*! \brief A PackedFunc that retrieves exported metadata. */ +/*! \brief A ffi::Function that retrieves exported metadata. */ constexpr const char* tvm_get_c_metadata = "get_c_metadata"; /*! \brief Global variable to store module context. */ constexpr const char* tvm_module_ctx = "__tvm_module_ctx"; @@ -303,7 +301,7 @@ constexpr const char* tvm_prepare_global_barrier = "__tvm_prepare_global_barrier constexpr const char* tvm_module_main = "__tvm_main__"; /*! \brief Prefix for parameter symbols emitted into the main program. */ constexpr const char* tvm_param_prefix = "__tvm_param__"; -/*! \brief A PackedFunc that looks up linked parameters by storage_id. */ +/*! \brief A ffi::Function that looks up linked parameters by storage_id. */ constexpr const char* tvm_lookup_linked_param = "_lookup_linked_param"; /*! \brief Model entrypoint generated as an interface to the AOT function outside of TIR */ constexpr const char* tvm_entrypoint_suffix = "run"; diff --git a/include/tvm/runtime/object.h b/include/tvm/runtime/object.h index d9e1ac25177d..1c4dfb39e247 100644 --- a/include/tvm/runtime/object.h +++ b/include/tvm/runtime/object.h @@ -51,7 +51,6 @@ using tvm::ffi::GetRef; enum TypeIndex : int32_t { // Standard static index assignments, // Frontends can take benefit of these constants. - /*! \brief runtime::Module. */ kRuntimeModule = TVMFFITypeIndex::kTVMFFIModule, /*! \brief runtime::NDArray. */ @@ -60,7 +59,7 @@ enum TypeIndex : int32_t { kRuntimeShapeTuple = TVMFFITypeIndex::kTVMFFIShape, // Extra builtin static index here kCustomStaticIndex = TVMFFITypeIndex::kTVMFFIStaticObjectEnd, - /*! \brief runtime::PackedFunc. */ + /*! \brief ffi::Function. */ kRuntimePackedFunc = kCustomStaticIndex + 1, /*! \brief runtime::DRef for disco distributed runtime */ kRuntimeDiscoDRef = kCustomStaticIndex + 2, diff --git a/include/tvm/runtime/packed_func.h b/include/tvm/runtime/packed_func.h index 3609987d5585..235bdcf3e32f 100644 --- a/include/tvm/runtime/packed_func.h +++ b/include/tvm/runtime/packed_func.h @@ -42,7 +42,7 @@ using ffi::Any; using ffi::AnyView; /*! - * \brief Utility function to convert legacy TVMArgValue to AnyView + * \brief Utility function to convert legacy ffi::AnyView to AnyView * \note This routine is not fastest, but serves purpose to do transition of ABI. */ inline TVMFFIAny LegacyTVMArgValueToFFIAny(TVMValue value, int type_code) { @@ -132,7 +132,7 @@ inline TVMFFIAny LegacyTVMArgValueToFFIAny(TVMValue value, int type_code) { } /*! - * \brief Utility function to convert legacy TVMArgValue to AnyView + * \brief Utility function to convert legacy ffi::AnyView to AnyView * \note This routine is not fastest, but serves purpose to do transition of ABI. */ inline AnyView LegacyTVMArgValueToAnyView(TVMValue value, int type_code) { @@ -140,7 +140,7 @@ inline AnyView LegacyTVMArgValueToAnyView(TVMValue value, int type_code) { } /*! - * \brief Utility function to convert legacy TVMArgValue to Any + * \brief Utility function to convert legacy ffi::AnyView to Any * \note This routine is not fastest, but serves purpose to do transition of ABI. */ inline Any MoveLegacyTVMArgValueToAny(TVMValue value, int type_code) { @@ -249,7 +249,7 @@ inline void MoveAnyToLegacyTVMValue(Any&& src, TVMValue* value, int* type_code) } /*! - * \brief Translate legacy TVMArgs to PackedArgs + * \brief Translate legacy ffi::PackedArgs to PackedArgs * \param value The TVMValue array * \param type_code The type code array * \param num_args The number of arguments @@ -263,7 +263,7 @@ inline void LegacyTVMArgsToPackedArgs(const TVMValue* value, const int* type_cod } /*! - * \brief Translate legacy TVMArgs to PackedArgs + * \brief Translate legacy ffi::PackedArgs to PackedArgs * \param args The AnyView array * \param num_args The number of arguments * \param value The TVMValue array @@ -276,18 +276,6 @@ inline void PackedArgsToLegacyTVMArgs(const AnyView* args, int num_args, TVMValu } } -// redirect to ffi::PackedArgs -using TVMArgs = ffi::PackedArgs; -// redirect to ffi::AnyView and ffi::Any for ArgValue and RetValue -using TVMArgValue = ffi::AnyView; -using TVMRetValue = ffi::Any; - -// redirect to ffi::Function -using PackedFunc = ffi::Function; - -template -using TypedPackedFunc = ffi::TypedFunction; - /*! * \brief Convert argument type code to string. * \param type_code The input type code. @@ -341,7 +329,7 @@ struct ModuleVTableEntryHelper {}; template struct ModuleVTableEntryHelper { using MemFnType = R (T::*)(Args...) const; - static TVM_ALWAYS_INLINE void Call(TVMRetValue* rv, T* self, MemFnType f, TVMArgs args) { + static TVM_ALWAYS_INLINE void Call(ffi::Any* rv, T* self, MemFnType f, ffi::PackedArgs args) { auto wrapped = [self, f](Args... args) -> R { return (self->*f)(std::forward(args)...); }; ffi::details::unpack_call(std::make_index_sequence{}, nullptr, wrapped, args.data(), args.size(), rv); @@ -351,7 +339,7 @@ struct ModuleVTableEntryHelper { template struct ModuleVTableEntryHelper { using MemFnType = R (T::*)(Args...); - static TVM_ALWAYS_INLINE void Call(TVMRetValue* rv, T* self, MemFnType f, TVMArgs args) { + static TVM_ALWAYS_INLINE void Call(ffi::Any* rv, T* self, MemFnType f, ffi::PackedArgs args) { auto wrapped = [self, f](Args... args) -> R { return (self->*f)(std::forward(args)...); }; ffi::details::unpack_call(std::make_index_sequence{}, nullptr, wrapped, args.data(), args.size(), rv); @@ -361,7 +349,7 @@ struct ModuleVTableEntryHelper { template struct ModuleVTableEntryHelper { using MemFnType = void (T::*)(Args...) const; - static TVM_ALWAYS_INLINE void Call(TVMRetValue* rv, T* self, MemFnType f, TVMArgs args) { + static TVM_ALWAYS_INLINE void Call(ffi::Any* rv, T* self, MemFnType f, ffi::PackedArgs args) { auto wrapped = [self, f](Args... args) -> void { (self->*f)(std::forward(args)...); }; ffi::details::unpack_call(std::make_index_sequence{}, nullptr, wrapped, args.data(), args.size(), rv); @@ -371,7 +359,7 @@ struct ModuleVTableEntryHelper { template struct ModuleVTableEntryHelper { using MemFnType = void (T::*)(Args...); - static TVM_ALWAYS_INLINE void Call(TVMRetValue* rv, T* self, MemFnType f, TVMArgs args) { + static TVM_ALWAYS_INLINE void Call(ffi::Any* rv, T* self, MemFnType f, ffi::PackedArgs args) { auto wrapped = [self, f](Args... args) -> void { (self->*f)(std::forward(args)...); }; ffi::details::unpack_call(std::make_index_sequence{}, nullptr, wrapped, args.data(), args.size(), rv); @@ -379,12 +367,12 @@ struct ModuleVTableEntryHelper { }; } // namespace details -#define TVM_MODULE_VTABLE_BEGIN(TypeKey) \ - const char* type_key() const final { return TypeKey; } \ - PackedFunc GetFunction(const String& _name, const ObjectPtr& _self) override { \ +#define TVM_MODULE_VTABLE_BEGIN(TypeKey) \ + const char* type_key() const final { return TypeKey; } \ + ffi::Function GetFunction(const String& _name, const ObjectPtr& _self) override { \ using SelfPtr = std::remove_cv_t; -#define TVM_MODULE_VTABLE_END() \ - return PackedFunc(nullptr); \ +#define TVM_MODULE_VTABLE_END() \ + return ffi::Function(nullptr); \ } #define TVM_MODULE_VTABLE_END_WITH_DEFAULT(MemFunc) \ { \ @@ -400,15 +388,15 @@ struct ModuleVTableEntryHelper { Helper::Call(rv, self, MemFunc, args); \ }); \ } -#define TVM_MODULE_VTABLE_ENTRY_PACKED(Name, MemFunc) \ - if (_name == Name) { \ - return PackedFunc([_self](ffi::PackedArgs args, Any* rv) -> void { \ - (static_cast(_self.get())->*(MemFunc))(args, rv); \ - }); \ +#define TVM_MODULE_VTABLE_ENTRY_PACKED(Name, MemFunc) \ + if (_name == Name) { \ + return ffi::Function([_self](ffi::PackedArgs args, Any* rv) -> void { \ + (static_cast(_self.get())->*(MemFunc))(args, rv); \ + }); \ } /*! - * \brief Export typed function as a PackedFunc + * \brief Export typed function as a ffi::Function * that can be loaded by LibraryModule. * * \param ExportName The symbol name to be exported. @@ -416,7 +404,7 @@ struct ModuleVTableEntryHelper { * \note ExportName and Function must be different, * see code examples below. * - * \sa TypedPackedFunc + * \sa ffi::TypedFunction * * \code * @@ -455,7 +443,6 @@ struct ModuleVTableEntryHelper { TVM_FFI_SAFE_CALL_END(); \ } \ } - } // namespace runtime // NOLINT(*) using ffi::Any; using ffi::AnyView; diff --git a/include/tvm/runtime/profiling.h b/include/tvm/runtime/profiling.h index a86ffa8f94d5..9d6623d9ad95 100644 --- a/include/tvm/runtime/profiling.h +++ b/include/tvm/runtime/profiling.h @@ -154,7 +154,7 @@ Timer DefaultTimer(Device dev); namespace profiling { /*! \brief Wrapper for `Device` because `Device` is not passable across the - * PackedFunc interface. + * ffi::Function interface. */ struct DeviceWrapperNode : public Object { /*! The device */ @@ -515,7 +515,7 @@ String ShapeString(const std::vector& shape, DLDataType dtype); * Example usage: * \code{.cpp} * // Use PAPI to measure the number of floating point operations. - * PackedFunc profiler = ProfileModule( + * ffi::Function profiler = ProfileModule( * mod, "main", kDLCPU, 0, {CreatePAPIMetricCollector({{kDLCPU, 0}, {"PAPI_FP_OPS"}})}); * Report r = profiler(arg1, arg2, arg); * std::cout << r << std::endl; @@ -531,12 +531,12 @@ String ShapeString(const std::vector& shape, DLDataType dtype); * than 0 so that cache effects are consistent. * \param collectors List of different * ways to collect metrics. See MetricCollector. - * \returns A PackedFunc which takes the same arguments as the `mod[func_name]` + * \returns A ffi::Function which takes the same arguments as the `mod[func_name]` * and returns performance metrics as a `Map` where * values can be `CountNode`, `DurationNode`, `PercentNode`. */ -PackedFunc ProfileFunction(Module mod, std::string func_name, int device_type, int device_id, - int warmup_iters, Array collectors); +ffi::Function ProfileFunction(Module mod, std::string func_name, int device_type, int device_id, + int warmup_iters, Array collectors); /*! * \brief Wrap a timer function to measure the time cost of a given packed function. @@ -585,10 +585,10 @@ PackedFunc ProfileFunction(Module mod, std::string func_name, int device_type, i * evaluator. * \return f_timer A timer function. */ -PackedFunc WrapTimeEvaluator(PackedFunc f, Device dev, int number, int repeat, int min_repeat_ms, - int limit_zero_time_iterations, int cooldown_interval_ms, - int repeats_to_cooldown, int cache_flush_bytes = 0, - PackedFunc f_preproc = nullptr); +ffi::Function WrapTimeEvaluator(ffi::Function f, Device dev, int number, int repeat, + int min_repeat_ms, int limit_zero_time_iterations, + int cooldown_interval_ms, int repeats_to_cooldown, + int cache_flush_bytes = 0, ffi::Function f_preproc = nullptr); } // namespace profiling } // namespace runtime diff --git a/include/tvm/runtime/registry.h b/include/tvm/runtime/registry.h index e6ee664b5a93..c5c124dd6fb8 100644 --- a/include/tvm/runtime/registry.h +++ b/include/tvm/runtime/registry.h @@ -83,7 +83,7 @@ class WrappedPythonObject { * \brief Register a function globally. * \code * TVM_REGISTER_GLOBAL("MyPrint") - * .set_body_packed([](TVMArgs args, TVMRetValue* rv) { + * .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { * }); * \endcode */ diff --git a/include/tvm/runtime/relax_vm/executable.h b/include/tvm/runtime/relax_vm/executable.h index 7028a6232971..dc9d87025382 100644 --- a/include/tvm/runtime/relax_vm/executable.h +++ b/include/tvm/runtime/relax_vm/executable.h @@ -154,7 +154,7 @@ class VMExecutable : public runtime::ModuleNode { /*! \brief A map from globals (as strings) to their index in the function map. */ std::unordered_map func_map; /*! \brief The global constant pool. */ - std::vector constants; + std::vector constants; /*! \brief The offset of instruction. */ std::vector instr_offset; /*! \brief The byte data of instruction. */ diff --git a/include/tvm/runtime/relax_vm/vm.h b/include/tvm/runtime/relax_vm/vm.h index 67439c102a82..ce69548d7016 100644 --- a/include/tvm/runtime/relax_vm/vm.h +++ b/include/tvm/runtime/relax_vm/vm.h @@ -74,7 +74,7 @@ class VMClosureObj : public Object { * as the first argument. The rest of arguments follows * the same arguments as the normal function call. */ - PackedFunc impl; + ffi::Function impl; static constexpr const char* _type_key = "relax.vm.Closure"; TVM_DECLARE_FINAL_OBJECT_INFO(VMClosureObj, Object); @@ -83,18 +83,18 @@ class VMClosureObj : public Object { /*! \brief reference to closure. */ class VMClosure : public ObjectRef { public: - VMClosure(String func_name, PackedFunc impl); + VMClosure(String func_name, ffi::Function impl); TVM_DEFINE_OBJECT_REF_METHODS(VMClosure, ObjectRef, VMClosureObj); /*! - * \brief Create another PackedFunc with last arguments already bound to last_args. + * \brief Create another ffi::Function with last arguments already bound to last_args. * * This is a helper function to create captured closures. - * \param func The input func, can be a VMClosure or PackedFunc. + * \param func The input func, can be a VMClosure or ffi::Function. * \param last_args The arguments to bound to in the end of the function. * \note The new function takes in arguments and append the last_args in the end. */ - static PackedFunc BindLastArgs(PackedFunc func, std::vector last_args); + static ffi::Function BindLastArgs(ffi::Function func, std::vector last_args); }; /*! @@ -149,7 +149,7 @@ class VirtualMachine : public runtime::ModuleNode { */ virtual VMClosure GetClosure(const String& func_name) = 0; /*! - * \brief Invoke closure or packed function using PackedFunc convention. + * \brief Invoke closure or packed function using ffi::Function convention. * \param closure_or_packedfunc A VM closure or a packed_func. * \param args The input arguments. * \param rv The return value. @@ -164,7 +164,7 @@ class VirtualMachine : public runtime::ModuleNode { * * bool instrument(func, func_symbol, before_run, args...) * - * - func: Union[VMClosure, PackedFunc], the function object. + * - func: Union[VMClosure, ffi::Function], the function object. * - func_symbol: string, the symbol of the function. * - before_run: bool, whether it is before or after call. * - ret_value: Only valid in after run, otherwise it is null. @@ -175,7 +175,7 @@ class VirtualMachine : public runtime::ModuleNode { * * \param instrument The instrument function. */ - virtual void SetInstrument(PackedFunc instrument) = 0; + virtual void SetInstrument(ffi::Function instrument) = 0; /*! * \brief Get or create a VM extension. Once created, the extension will be stored in the VM @@ -209,7 +209,7 @@ class VirtualMachine : public runtime::ModuleNode { * \brief Helper function for vm closure functions to get the context ptr * \param arg The argument value. */ - static VirtualMachine* GetContextPtr(TVMArgValue arg) { + static VirtualMachine* GetContextPtr(ffi::AnyView arg) { return static_cast(arg.cast()); } diff --git a/include/tvm/script/ir_builder/base.h b/include/tvm/script/ir_builder/base.h index 8785c224650d..e50c67a37664 100644 --- a/include/tvm/script/ir_builder/base.h +++ b/include/tvm/script/ir_builder/base.h @@ -64,7 +64,7 @@ namespace ir_builder { class IRBuilderFrameNode : public runtime::Object { public: /*! \brief A list of callbacks used when exiting the frame. */ - std::vector> callbacks; + std::vector> callbacks; void VisitAttrs(tvm::AttrVisitor* v) { // `callbacks` is not visited. @@ -90,7 +90,7 @@ class IRBuilderFrameNode : public runtime::Object { * \brief Add a callback method invoked when exiting the RAII scope. * \param callback The callback to be added. */ - void AddCallback(runtime::TypedPackedFunc callback); + void AddCallback(ffi::TypedFunction callback); }; /*! diff --git a/include/tvm/script/ir_builder/tir/frame.h b/include/tvm/script/ir_builder/tir/frame.h index 0b69047a1cb7..171ac019dd03 100644 --- a/include/tvm/script/ir_builder/tir/frame.h +++ b/include/tvm/script/ir_builder/tir/frame.h @@ -236,7 +236,7 @@ class ForFrameNode : public TIRFrameNode { * \param loop_body The loop body * \return A stmt, the loop nest */ - using FMakeForLoop = runtime::TypedPackedFunc loop_vars, Array loop_extents, tvm::tir::Stmt loop_body)>; /*! \brief The loop variable. */ Array vars; diff --git a/include/tvm/script/printer/ir_docsifier.h b/include/tvm/script/printer/ir_docsifier.h index e316088c0664..db94064d538c 100644 --- a/include/tvm/script/printer/ir_docsifier.h +++ b/include/tvm/script/printer/ir_docsifier.h @@ -227,7 +227,7 @@ class IRDocsifierNode : public Object { * \param root The root of the AST. * \param is_var A function that returns true if the given object is considered a variable. */ - void SetCommonPrefix(const ObjectRef& root, runtime::TypedPackedFunc is_var); + void SetCommonPrefix(const ObjectRef& root, ffi::TypedFunction is_var); /*! * \brief Transform the input object into TDoc. * \param obj The object to be transformed. diff --git a/include/tvm/script/printer/ir_docsifier_functor.h b/include/tvm/script/printer/ir_docsifier_functor.h index 5ce105ea0963..40bb245e72f3 100644 --- a/include/tvm/script/printer/ir_docsifier_functor.h +++ b/include/tvm/script/printer/ir_docsifier_functor.h @@ -63,7 +63,7 @@ class IRDocsifierFunctor { template R operator()(const String& token, TObjectRef obj, Args... args) const { uint32_t type_index = obj.defined() ? obj->type_index() : 0; - const runtime::PackedFunc* pf = nullptr; + const ffi::Function* pf = nullptr; if ((pf = LookupDispatchTable(token, type_index)) != nullptr) { return (*pf)(obj, args...).template cast(); } @@ -91,12 +91,12 @@ class IRDocsifierFunctor { * This takes a type-erased packed function as input. It should be used * through FFI boundary, for example, registering dispatch function from Python. */ - TSelf& set_dispatch(String token, uint32_t type_index, runtime::PackedFunc f) { - std::vector* table = &dispatch_table_[token]; + TSelf& set_dispatch(String token, uint32_t type_index, ffi::Function f) { + std::vector* table = &dispatch_table_[token]; if (table->size() <= type_index) { table->resize(type_index + 1, nullptr); } - runtime::PackedFunc& slot = (*table)[type_index]; + ffi::Function& slot = (*table)[type_index]; if (slot != nullptr) { ICHECK(false) << "Dispatch for type is already registered: " << runtime::Object::TypeIndex2Key(type_index); @@ -105,7 +105,7 @@ class IRDocsifierFunctor { return *this; } - TSelf& set_fallback(runtime::PackedFunc f) { + TSelf& set_fallback(ffi::Function f) { ICHECK(!dispatch_fallback_.has_value()) << "Fallback is already defined"; dispatch_fallback_ = f; return *this; @@ -122,13 +122,13 @@ class IRDocsifierFunctor { typename = std::enable_if_t::value>> TSelf& set_dispatch(String token, TCallable f) { return set_dispatch(token, TObjectRef::ContainerType::RuntimeTypeIndex(), - runtime::TypedPackedFunc(f)); + ffi::TypedFunction(f)); } template ::value>> TSelf& set_fallback(TCallable f) { - runtime::PackedFunc func = runtime::TypedPackedFunc(f); + ffi::Function func = ffi::TypedFunction(f); return set_fallback(func); } @@ -141,7 +141,7 @@ class IRDocsifierFunctor { * those function should be removed before that language runtime shuts down. */ void remove_dispatch(String token, uint32_t type_index) { - std::vector* table = &dispatch_table_[token]; + std::vector* table = &dispatch_table_[token]; if (table->size() <= type_index) { return; } @@ -155,16 +155,16 @@ class IRDocsifierFunctor { * \param type_index The TVM object type index. * \return Returns the functor if the lookup succeeds, nullptr otherwise. */ - const runtime::PackedFunc* LookupDispatchTable(const String& token, uint32_t type_index) const { + const ffi::Function* LookupDispatchTable(const String& token, uint32_t type_index) const { auto it = dispatch_table_.find(token); if (it == dispatch_table_.end()) { return nullptr; } - const std::vector& tab = it->second; + const std::vector& tab = it->second; if (type_index >= tab.size()) { return nullptr; } - const PackedFunc* f = &tab[type_index]; + const ffi::Function* f = &tab[type_index]; if (f->defined()) { return f; } else { @@ -175,7 +175,7 @@ class IRDocsifierFunctor { /*! * \brief Look up the fallback to be used if no handler is registered */ - const runtime::PackedFunc* LookupFallback() const { + const ffi::Function* LookupFallback() const { if (dispatch_fallback_.has_value()) { return &*dispatch_fallback_; } else { @@ -187,10 +187,10 @@ class IRDocsifierFunctor { * This type alias and the following free functions are created to reduce the binary bloat * from template and also hide implementation details from this header */ - using DispatchTable = std::unordered_map>; + using DispatchTable = std::unordered_map>; /*! \brief The dispatch table. */ DispatchTable dispatch_table_; - std::optional dispatch_fallback_; + std::optional dispatch_fallback_; }; } // namespace printer diff --git a/include/tvm/target/codegen.h b/include/tvm/target/codegen.h index 0490e14e22a6..faa58f84870a 100644 --- a/include/tvm/target/codegen.h +++ b/include/tvm/target/codegen.h @@ -35,9 +35,9 @@ namespace tvm { /*! \brief namespace for target translation and codegen. */ namespace codegen { // use packed function from runtime. -using runtime::PackedFunc; -using runtime::TVMArgs; -using runtime::TVMRetValue; +using ffi::Any; +using ffi::Function; +using ffi::PackedArgs; /*! * \brief Build a module from array of lowered function. diff --git a/include/tvm/target/target_kind.h b/include/tvm/target/target_kind.h index d770c56b49ea..f652424800dc 100644 --- a/include/tvm/target/target_kind.h +++ b/include/tvm/target/target_kind.h @@ -49,7 +49,7 @@ using TargetFeatures = Map; * \return The transformed Target JSON object. */ using TargetJSON = Map; -using FTVMTargetParser = runtime::TypedPackedFunc; +using FTVMTargetParser = ffi::TypedFunction; namespace detail { template @@ -71,7 +71,7 @@ class TargetKindNode : public Object { /*! \brief Default keys of the target */ Array default_keys; /*! \brief Function used to preprocess on target creation */ - PackedFunc preprocessor; + ffi::Function preprocessor; /*! \brief Function used to parse a JSON target during creation */ FTVMTargetParser target_parser; @@ -253,7 +253,7 @@ class TargetKindRegEntry { * \param value The value to be set * \param plevel The priority level */ - TVM_DLL void UpdateAttr(const String& key, TVMRetValue value, int plevel); + TVM_DLL void UpdateAttr(const String& key, ffi::Any value, int plevel); template friend class AttrRegistry; friend class TargetKind; @@ -332,7 +332,7 @@ template inline TargetKindRegEntry& TargetKindRegEntry::set_attr(const String& attr_name, const ValueType& value, int plevel) { ICHECK_GT(plevel, 0) << "plevel in set_attr must be greater than 0"; - runtime::TVMRetValue rv; + ffi::Any rv; rv = value; UpdateAttr(attr_name, rv, plevel); return *this; @@ -351,7 +351,7 @@ inline TargetKindRegEntry& TargetKindRegEntry::set_default_keys(std::vector inline TargetKindRegEntry& TargetKindRegEntry::set_attrs_preprocessor(FLambda f) { LOG(WARNING) << "set_attrs_preprocessor is deprecated please use set_target_parser instead"; - kind_->preprocessor = ffi::Function::FromUnpacked(std::move(f)); + kind_->preprocessor = ffi::Function::FromTyped(std::move(f)); return *this; } diff --git a/include/tvm/tir/builtin.h b/include/tvm/tir/builtin.h index 6f7ce9de2bf6..2035a511c1bb 100644 --- a/include/tvm/tir/builtin.h +++ b/include/tvm/tir/builtin.h @@ -359,7 +359,7 @@ TVM_DLL const Op& tvm_stack_make_array(); * return_type tvm_call_packed(name, TVMFFIAny* args) { * TVMFFIAny result; * ModuleNode* env = GetCurrentEnv(); - * const PackedFunc* f = env->GetFuncFromEnv(name); + * const ffi::Function* f = env->GetFuncFromEnv(name); * (*f)(args, args, len(args), &result); * // return type can be int, float, handle. * return cast(return_type, result); @@ -383,7 +383,7 @@ TVM_DLL const Op& tvm_call_cpacked(); * * return_type tvm_call_trace_packed(name, TVMFFIAny* args) { * ModuleNode* env = GetCurrentEnv(); - * const PackedFunc* f = env->GetFuncFromEnv(name); + * const ffi::Function* f = env->GetFuncFromEnv(name); * (*f)(args, args, len(args)); * // return type can be int, float, handle. * return cast(return_type, result); @@ -420,10 +420,10 @@ TVM_DLL const Op& tvm_thread_invariant(); * int begin, * int end) { * ModuleNode* env = GetCurrentEnv(); - * const PackedFunc* f = env->GetFuncFromEnv(name); - * f->CallPacked(TVMArgs(value_stack[begin:end], + * const ffi::Function* f = env->GetFuncFromEnv(name); + * f->CallPacked(ffi::PackedArgs(value_stack[begin:end], * tcode_stack[begin:end]), - * TVMRetValue(value_stack + end, tcode_stack + end)); + * ffi::Any(value_stack + end, tcode_stack + end)); * // return type can be int, float, handle. * return cast(return_type, load_return_from(tcode_stack + end)) * } @@ -439,8 +439,8 @@ TVM_DLL const Op& tvm_call_packed_lowered(); * int begin, * int end, * void* self) { - * fname(TVMArgs(value_stack[begin:end], tcode_stack[begin:end]), - * TVMRetValue(value_stack + end, tcode_stack + end)); + * fname(ffi::PackedArgs(value_stack[begin:end], tcode_stack[begin:end]), + * ffi::Any(value_stack + end, tcode_stack + end)); * } */ TVM_DLL const Op& tvm_call_cpacked_lowered(); @@ -456,10 +456,10 @@ TVM_DLL const Op& tvm_call_cpacked_lowered(); * int begin, * int end) { * ModuleNode* env = GetCurrentEnv(); - * const PackedFunc* f = env->GetFuncFromEnv(name); - * f->CallPacked(TVMArgs(value_stack[begin:end], + * const ffi::Function* f = env->GetFuncFromEnv(name); + * f->CallPacked(ffi::PackedArgs(value_stack[begin:end], * tcode_stack[begin:end]), - * TVMRetValue(value_stack + end, tcode_stack + end)); + * ffi::Any(value_stack + end, tcode_stack + end)); * // return type can be int, float, handle. * return cast(return_type, load_return_from(tcode_stack + end)) * } diff --git a/include/tvm/tir/function.h b/include/tvm/tir/function.h index 2d3b2747195a..f85c0ed706ef 100644 --- a/include/tvm/tir/function.h +++ b/include/tvm/tir/function.h @@ -288,7 +288,7 @@ namespace attr { * Here n = len(arg), m = len(work_size) = len(launch_params)-1. * * The list of kernel launch params indicates which additional - * parameters will be provided to the PackedFunc by the calling + * parameters will be provided to the ffi::Function by the calling * scope. * * - "threadIdx.x", "threadIdx.y", "threadIdx.z" diff --git a/include/tvm/tir/index_map.h b/include/tvm/tir/index_map.h index 340d953ccf2f..319b5193e8f6 100644 --- a/include/tvm/tir/index_map.h +++ b/include/tvm/tir/index_map.h @@ -191,7 +191,7 @@ class IndexMap : public ObjectRef { * \param inverse_index_map The optional pre-defined inverse index map * \return The created index map */ - static IndexMap FromFunc(int ndim, runtime::TypedPackedFunc(Array)> func, + static IndexMap FromFunc(int ndim, ffi::TypedFunction(Array)> func, Optional inverse_index_map = NullOpt); /*! \brief Generate the inverse mapping. diff --git a/include/tvm/tir/op_attr_types.h b/include/tvm/tir/op_attr_types.h index b2a644f9546e..59d4cbbcd507 100644 --- a/include/tvm/tir/op_attr_types.h +++ b/include/tvm/tir/op_attr_types.h @@ -49,12 +49,12 @@ using TVectorizable = bool; /*! * \brief The intrinsic lowering function for given op. */ -using FLowerIntrinsic = runtime::TypedPackedFunc; +using FLowerIntrinsic = ffi::TypedFunction; /*! * \brief The legalization function for given tir op. */ -using FLegalize = runtime::TypedPackedFunc; +using FLegalize = ffi::TypedFunction; /*! * \brief The operator's name in TVMScript printer diff --git a/include/tvm/tir/schedule/instruction.h b/include/tvm/tir/schedule/instruction.h index 5ceb3a034eaf..fe054865b738 100644 --- a/include/tvm/tir/schedule/instruction.h +++ b/include/tvm/tir/schedule/instruction.h @@ -42,7 +42,7 @@ class Schedule; * \param decision Decisions made on the instruction * \return The functor returns an array of output random variables */ -using FInstructionApply = runtime::TypedPackedFunc( +using FInstructionApply = ffi::TypedFunction( Schedule sch, const Array& inputs, const Array& attrs, const Any& decision)>; /*! @@ -54,8 +54,8 @@ using FInstructionApply = runtime::TypedPackedFunc( * \return A string representing the python api call */ using FInstructionAsPython = - runtime::TypedPackedFunc& inputs, const Array& attrs, - const Any& decision, const Array& outputs)>; + ffi::TypedFunction& inputs, const Array& attrs, + const Any& decision, const Array& outputs)>; /*! * \brief Type of the functor that serialize its attributes to JSON @@ -63,7 +63,7 @@ using FInstructionAsPython = * \return An array, serialized attributes * \note This functor is nullable */ -using FInstructionAttrsAsJSON = runtime::TypedPackedFunc attrs)>; +using FInstructionAttrsAsJSON = ffi::TypedFunction attrs)>; /*! * \brief Type of the functor that deserialize its attributes from JSON @@ -71,7 +71,7 @@ using FInstructionAttrsAsJSON = runtime::TypedPackedFunc at * \return An array, deserialized attributes * \note This functor is nullable */ -using FInstructionAttrsFromJSON = runtime::TypedPackedFunc(ObjectRef json_attrs)>; +using FInstructionAttrsFromJSON = ffi::TypedFunction(ObjectRef json_attrs)>; /*! * \brief Kind of an instruction, e.g. Split, Reorder, etc. diff --git a/include/tvm/tir/schedule/trace.h b/include/tvm/tir/schedule/trace.h index 5d920db2a1a0..79a2f8b2a08e 100644 --- a/include/tvm/tir/schedule/trace.h +++ b/include/tvm/tir/schedule/trace.h @@ -37,8 +37,8 @@ class Trace; * \return A new decision */ using FTraceDecisionProvider = - runtime::TypedPackedFunc& inputs, - const Array& attrs, const Any& decision)>; + ffi::TypedFunction& inputs, + const Array& attrs, const Any& decision)>; /*! * \brief An execution trace of a scheduling program diff --git a/include/tvm/tir/stmt_functor.h b/include/tvm/tir/stmt_functor.h index 44f7af37eb39..9ce49610ee34 100644 --- a/include/tvm/tir/stmt_functor.h +++ b/include/tvm/tir/stmt_functor.h @@ -336,8 +336,7 @@ class StmtExprMutator : public StmtMutator, public ExprMutator { * If it is not null, preorder/postorder will only be called * when the IRNode's type key is in the list. */ -TVM_DLL Stmt IRTransform(Stmt stmt, const runtime::PackedFunc& preorder, - const runtime::PackedFunc& postorder, +TVM_DLL Stmt IRTransform(Stmt stmt, const ffi::Function& preorder, const ffi::Function& postorder, Optional> only_enable = NullOpt); /*! diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h index 17b85585b08d..b80d4456c0be 100644 --- a/include/tvm/tir/transform.h +++ b/include/tvm/tir/transform.h @@ -713,7 +713,7 @@ TVM_DLL Pass AnnotateEntryFunc(); * \brief Filter PrimFuncs with a given condition. * \return The pass. */ -TVM_DLL Pass Filter(runtime::TypedPackedFunc fcond); +TVM_DLL Pass Filter(ffi::TypedFunction fcond); /*! * \brief Pass to rewrite global to shared memory copy on CUDA with asyncronous copy. diff --git a/include/tvm/topi/detail/extern.h b/include/tvm/topi/detail/extern.h index fe3954059441..e54169ea2934 100644 --- a/include/tvm/topi/detail/extern.h +++ b/include/tvm/topi/detail/extern.h @@ -90,7 +90,7 @@ inline Array make_extern(const Array>& out_shapes, /*! * \brief This function is used to create a DLTensor structure on the stack to - * be able to pass a symbolic buffer as arguments to TVM PackedFunc + * be able to pass a symbolic buffer as arguments to TVM ffi::Function * * \param buf The buffer to pack * @@ -117,10 +117,10 @@ inline PrimExpr pack_buffer(Buffer buf) { } /*! - * \brief Construct an Expr representing the invocation of a PackedFunc + * \brief Construct an Expr representing the invocation of a ffi::Function * - * \param args An array containing the registered name of the PackedFunc followed - * by the arguments to pass to the PackedFunc when called. The first element of the + * \param args An array containing the registered name of the ffi::Function followed + * by the arguments to pass to the ffi::Function when called. The first element of the * array must be a constant string expression. * * \return An expression representing the invocation diff --git a/python/tvm/contrib/msc/plugin/codegen/sources.py b/python/tvm/contrib/msc/plugin/codegen/sources.py index 3806dabd0e1e..b507d7b82557 100644 --- a/python/tvm/contrib/msc/plugin/codegen/sources.py +++ b/python/tvm/contrib/msc/plugin/codegen/sources.py @@ -487,24 +487,24 @@ class TVMUtils { } } - static void AttrFromArg(const TVMArgValue& arg, std::string& target) { - target = arg.operator std::string(); + static void AttrFromArg(const ffi::AnyView& arg, std::string& target) { + target = arg.cast(); } - static void AttrFromArg(const TVMArgValue& arg, bool& target) { target = arg; } + static void AttrFromArg(const ffi::AnyView& arg, bool& target) { target = arg; } - static void AttrFromArg(const TVMArgValue& arg, int& target) { target = arg; } + static void AttrFromArg(const ffi::AnyView& arg, int& target) { target = arg; } - static void AttrFromArg(const TVMArgValue& arg, size_t& target) { target = int(arg); } + static void AttrFromArg(const ffi::AnyView& arg, size_t& target) { target = int(arg); } - static void AttrFromArg(const TVMArgValue& arg, long& target) { target = int64_t(arg); } + static void AttrFromArg(const ffi::AnyView& arg, long& target) { target = int64_t(arg); } - static void AttrFromArg(const TVMArgValue& arg, float& target) { target = double(arg); } + static void AttrFromArg(const ffi::AnyView& arg, float& target) { target = double(arg); } - static void AttrFromArg(const TVMArgValue& arg, double& target) { target = arg; } + static void AttrFromArg(const ffi::AnyView& arg, double& target) { target = arg; } template - static void AttrFromArgs(const TVMArgs& args, size_t start, size_t num, std::vector& target) { + static void AttrFromArgs(const ffi::PackedArgs& args, size_t start, size_t num, std::vector& target) { for (size_t i = 0; i < num; i++) { AttrFromArg(args[start + i], target[i]); } diff --git a/python/tvm/ir/base.py b/python/tvm/ir/base.py index 2097423cbed3..50b2c595b33f 100644 --- a/python/tvm/ir/base.py +++ b/python/tvm/ir/base.py @@ -98,7 +98,7 @@ def __call__(self, *args): @property def func(self): - return _ffi_api.EnvFuncGetPackedFunc(self) # type: ignore # pylint: disable=no-member + return _ffi_api.EnvFuncGetFunction(self) # type: ignore # pylint: disable=no-member @staticmethod def get(name): diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py index 38d3e384c475..f40c84bf7a6c 100644 --- a/python/tvm/tir/transform/transform.py +++ b/python/tvm/tir/transform/transform.py @@ -335,8 +335,8 @@ def MakePackedAPI(): Prior to this pass, the PrimFunc may have Buffer arguments defined in the `PrimFuncNode::buffer_map`. This pass consumes the - `buffer_map`, using it to generate `TVMArgs` and `TVMRetValue*` - arguments that implement the `PackedFunc` API. + `buffer_map`, using it to generate arguments that implement + the packed based TVM FFI API. For static shapes, the `BufferNode::shape`, `BufferNode::strides`, and `BufferNode::elem_offset` member variables are used to diff --git a/src/arith/analyzer.cc b/src/arith/analyzer.cc index 1edbaed08b03..602a198a2bf6 100644 --- a/src/arith/analyzer.cc +++ b/src/arith/analyzer.cc @@ -268,94 +268,96 @@ PrimExpr Analyzer::Simplify(const PrimExpr& expr, int steps) { return res; } -TVM_REGISTER_GLOBAL("arith.CreateAnalyzer").set_body_packed([](TVMArgs args, TVMRetValue* ret) { - using runtime::PackedFunc; - using runtime::TypedPackedFunc; - auto self = std::make_shared(); - auto f = [self](std::string name) -> PackedFunc { - if (name == "const_int_bound") { - return PackedFunc([self](TVMArgs args, TVMRetValue* ret) { - *ret = self->const_int_bound(args[0].cast()); - }); - } else if (name == "modular_set") { - return PackedFunc([self](TVMArgs args, TVMRetValue* ret) { - *ret = self->modular_set(args[0].cast()); - }); - } else if (name == "const_int_bound_update") { - return PackedFunc([self](TVMArgs args, TVMRetValue* ret) { - self->const_int_bound.Update(args[0].cast(), args[1].cast(), - args[2].cast()); - }); - } else if (name == "Simplify") { - return PackedFunc([self](TVMArgs args, TVMRetValue* ret) { - if (args.size() == 1) { - *ret = self->Simplify(args[0].cast()); - } else if (args.size() == 2) { - *ret = self->Simplify(args[0].cast(), args[1].cast()); - } else { - LOG(FATAL) << "Invalid size of argument (" << args.size() << ")"; +TVM_REGISTER_GLOBAL("arith.CreateAnalyzer") + .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { + using ffi::Function; + using ffi::TypedFunction; + auto self = std::make_shared(); + auto f = [self](std::string name) -> ffi::Function { + if (name == "const_int_bound") { + return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) { + *ret = self->const_int_bound(args[0].cast()); + }); + } else if (name == "modular_set") { + return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) { + *ret = self->modular_set(args[0].cast()); + }); + } else if (name == "const_int_bound_update") { + return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) { + self->const_int_bound.Update(args[0].cast(), args[1].cast(), + args[2].cast()); + }); + } else if (name == "Simplify") { + return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) { + if (args.size() == 1) { + *ret = self->Simplify(args[0].cast()); + } else if (args.size() == 2) { + *ret = self->Simplify(args[0].cast(), args[1].cast()); + } else { + LOG(FATAL) << "Invalid size of argument (" << args.size() << ")"; + } + }); + } else if (name == "rewrite_simplify") { + return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) { + *ret = self->rewrite_simplify(args[0].cast()); + }); + } else if (name == "get_rewrite_simplify_stats") { + return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) { + *ret = self->rewrite_simplify.GetStatsCounters(); + }); + } else if (name == "reset_rewrite_simplify_stats") { + return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) { + self->rewrite_simplify.ResetStatsCounters(); + }); + } else if (name == "canonical_simplify") { + return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) { + *ret = self->canonical_simplify(args[0].cast()); + }); + } else if (name == "int_set") { + return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) { + *ret = self->int_set(args[0].cast(), args[1].cast>()); + }); + } else if (name == "bind") { + return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) { + if (auto opt_range = args[1].as()) { + self->Bind(args[0].cast(), opt_range.value()); + } else { + self->Bind(args[0].cast(), args[1].cast()); + } + }); + } else if (name == "can_prove") { + return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) { + int strength = args[1].cast(); + *ret = self->CanProve(args[0].cast(), static_cast(strength)); + }); + } else if (name == "enter_constraint_context") { + return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) { + // can't use make_shared due to noexcept(false) decl in destructor, + // see https://stackoverflow.com/a/43907314 + auto ctx = std::shared_ptr>( + new With(self.get(), args[0].cast())); + auto fexit = [ctx](ffi::PackedArgs, ffi::Any*) mutable { ctx.reset(); }; + *ret = ffi::Function::FromPacked(fexit); + }); + } else if (name == "can_prove_equal") { + return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) { + *ret = self->CanProveEqual(args[0].cast(), args[1].cast()); + }); + } else if (name == "get_enabled_extensions") { + return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) { + *ret = static_cast(self->rewrite_simplify.GetEnabledExtensions()); + }); + } else if (name == "set_enabled_extensions") { + return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) { + int64_t flags = args[0].cast(); + self->rewrite_simplify.SetEnabledExtensions( + static_cast(flags)); + }); } - }); - } else if (name == "rewrite_simplify") { - return PackedFunc([self](TVMArgs args, TVMRetValue* ret) { - *ret = self->rewrite_simplify(args[0].cast()); - }); - } else if (name == "get_rewrite_simplify_stats") { - return PackedFunc([self](TVMArgs args, TVMRetValue* ret) { - *ret = self->rewrite_simplify.GetStatsCounters(); - }); - } else if (name == "reset_rewrite_simplify_stats") { - return PackedFunc( - [self](TVMArgs args, TVMRetValue* ret) { self->rewrite_simplify.ResetStatsCounters(); }); - } else if (name == "canonical_simplify") { - return PackedFunc([self](TVMArgs args, TVMRetValue* ret) { - *ret = self->canonical_simplify(args[0].cast()); - }); - } else if (name == "int_set") { - return PackedFunc([self](TVMArgs args, TVMRetValue* ret) { - *ret = self->int_set(args[0].cast(), args[1].cast>()); - }); - } else if (name == "bind") { - return PackedFunc([self](TVMArgs args, TVMRetValue* ret) { - if (auto opt_range = args[1].as()) { - self->Bind(args[0].cast(), opt_range.value()); - } else { - self->Bind(args[0].cast(), args[1].cast()); - } - }); - } else if (name == "can_prove") { - return PackedFunc([self](TVMArgs args, TVMRetValue* ret) { - int strength = args[1].cast(); - *ret = self->CanProve(args[0].cast(), static_cast(strength)); - }); - } else if (name == "enter_constraint_context") { - return PackedFunc([self](TVMArgs args, TVMRetValue* ret) { - // can't use make_shared due to noexcept(false) decl in destructor, - // see https://stackoverflow.com/a/43907314 - auto ctx = std::shared_ptr>( - new With(self.get(), args[0].cast())); - auto fexit = [ctx](TVMArgs, TVMRetValue*) mutable { ctx.reset(); }; - *ret = ffi::Function::FromPacked(fexit); - }); - } else if (name == "can_prove_equal") { - return PackedFunc([self](TVMArgs args, TVMRetValue* ret) { - *ret = self->CanProveEqual(args[0].cast(), args[1].cast()); - }); - } else if (name == "get_enabled_extensions") { - return PackedFunc([self](TVMArgs args, TVMRetValue* ret) { - *ret = static_cast(self->rewrite_simplify.GetEnabledExtensions()); - }); - } else if (name == "set_enabled_extensions") { - return PackedFunc([self](TVMArgs args, TVMRetValue* ret) { - int64_t flags = args[0].cast(); - self->rewrite_simplify.SetEnabledExtensions( - static_cast(flags)); - }); - } - return PackedFunc(); - }; - *ret = TypedPackedFunc(f); -}); + return ffi::Function(); + }; + *ret = ffi::TypedFunction(f); + }); } // namespace arith } // namespace tvm diff --git a/src/arith/int_constraints.cc b/src/arith/int_constraints.cc index 5d3bc074bcde..8c314992ab49 100644 --- a/src/arith/int_constraints.cc +++ b/src/arith/int_constraints.cc @@ -204,7 +204,7 @@ TVM_REGISTER_GLOBAL("arith.IntGroupBounds") TVM_REGISTER_GLOBAL("arith.IntGroupBounds_from_range").set_body_typed(IntGroupBounds::FromRange); TVM_REGISTER_GLOBAL("arith.IntGroupBounds_FindBestRange") - .set_body_packed([](TVMArgs args, TVMRetValue* ret) { + .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { ICHECK(args.size() == 1 || args.size() == 2); auto bounds = args[0].cast(); if (args.size() == 1) { diff --git a/src/arith/solve_linear_equation.cc b/src/arith/solve_linear_equation.cc index 3ad1cdbf5a4d..fb6250a778ef 100644 --- a/src/arith/solve_linear_equation.cc +++ b/src/arith/solve_linear_equation.cc @@ -455,7 +455,7 @@ IntConstraintsTransform SolveLinearEquations(const IntConstraints& system_to_sol } TVM_REGISTER_GLOBAL("arith.SolveLinearEquations") - .set_body_packed([](TVMArgs args, TVMRetValue* ret) { + .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { if (args.size() == 1) { *ret = SolveLinearEquations(args[0].cast()); } else if (args.size() == 3) { diff --git a/src/arith/solve_linear_inequality.cc b/src/arith/solve_linear_inequality.cc index 949ee97cc564..0e5e6d485e74 100644 --- a/src/arith/solve_linear_inequality.cc +++ b/src/arith/solve_linear_inequality.cc @@ -536,7 +536,7 @@ IntConstraintsTransform SolveInequalitiesDeskewRange(const IntConstraints& inequ } TVM_REGISTER_GLOBAL("arith.SolveInequalitiesAsCondition") - .set_body_packed([](TVMArgs args, TVMRetValue* ret) { + .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { IntConstraints problem; PartialSolvedInequalities ret_ineq; if (args.size() == 1) { @@ -554,7 +554,7 @@ TVM_REGISTER_GLOBAL("arith.SolveInequalitiesAsCondition") }); TVM_REGISTER_GLOBAL("arith.SolveInequalitiesToRange") - .set_body_packed([](TVMArgs args, TVMRetValue* ret) { + .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { if (args.size() == 1) { *ret = SolveInequalitiesToRange(args[0].cast()); } else if (args.size() == 3) { @@ -569,7 +569,7 @@ TVM_REGISTER_GLOBAL("arith.SolveInequalitiesToRange") }); TVM_REGISTER_GLOBAL("arith.SolveInequalitiesDeskewRange") - .set_body_packed([](TVMArgs args, TVMRetValue* ret) { + .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { if (args.size() == 1) { *ret = SolveInequalitiesDeskewRange(args[0].cast()); } else if (args.size() == 3) { diff --git a/src/contrib/msc/framework/tensorrt/transform_tensorrt.cc b/src/contrib/msc/framework/tensorrt/transform_tensorrt.cc index 58219501415c..1131b82172a0 100644 --- a/src/contrib/msc/framework/tensorrt/transform_tensorrt.cc +++ b/src/contrib/msc/framework/tensorrt/transform_tensorrt.cc @@ -68,8 +68,8 @@ const TensorRTTransConfig ParseConfig(const String& config_str) { } using FRewriteTensorRT = - runtime::TypedPackedFunc& new_calls, const String& config)>; + ffi::TypedFunction& new_calls, const String& config)>; const Array BroadcastShape(const Array& src_shape, const Array& out_shape) { diff --git a/src/contrib/msc/plugin/tvm_codegen.cc b/src/contrib/msc/plugin/tvm_codegen.cc index 08a62c53bf0a..e1d3c9960f6d 100644 --- a/src/contrib/msc/plugin/tvm_codegen.cc +++ b/src/contrib/msc/plugin/tvm_codegen.cc @@ -39,7 +39,7 @@ void TVMPluginCodeGen::CodeGenAttrDeclare(const Plugin& plugin) { // args to meta_attr stack_.comment("convert args to meta attrs method") .func_def(attr_name + "_from_args", "const " + attr_name) - .func_arg("args", "TVMArgs") + .func_arg("args", "ffi::PackedArgs") .func_arg("pos", "size_t&"); } @@ -62,7 +62,7 @@ void TVMPluginCodeGen::CodeGenAttrDefine(const Plugin& plugin) { // args to meta_attr stack_.comment("convert args to meta attrs method") .func_def(attr_name + "_from_args", "const " + attr_name) - .func_arg("args", "TVMArgs") + .func_arg("args", "ffi::PackedArgs") .func_arg("pos", "size_t&") .func_start() .declare(attr_name, "meta_attr"); @@ -240,7 +240,7 @@ void TVMPluginCodeGen::CodeGenOpRuntime(const Plugin& plugin) { device_cond = device_cond + "TVMUtils::OnDevice(" + plugin->inputs[i]->name + ", " + device_type + ")" + (i == plugin->inputs.size() - 1 ? "" : " && "); } - stack_.func_def(func_name).func_arg("args", "TVMArgs").func_arg("ret", "TVMRetValue*"); + stack_.func_def(func_name).func_arg("args", "ffi::PackedArgs").func_arg("ret", "ffi::Any*"); stack_.func_start().comment("define tensors"); for (size_t i = 0; i < plugin->inputs.size(); i++) { stack_.assign(plugin->inputs[i]->name, DocUtils::ToIndex("args", i), "DLTensor*"); diff --git a/src/ir/attrs.cc b/src/ir/attrs.cc index 2ebca45c884d..fd87c2bc8e0c 100644 --- a/src/ir/attrs.cc +++ b/src/ir/attrs.cc @@ -53,7 +53,7 @@ DictAttrs WithoutAttr(DictAttrs attrs, const std::string& key) { return attrs; } -void DictAttrsNode::InitByPackedArgs(const runtime::TVMArgs& args, bool allow_unknown) { +void DictAttrsNode::InitByPackedArgs(const ffi::PackedArgs& args, bool allow_unknown) { for (int i = 0; i < args.size(); i += 2) { String key = args[i].cast(); ffi::AnyView val = args[i + 1]; diff --git a/src/ir/diagnostic.cc b/src/ir/diagnostic.cc index 0e58a6847c6a..ec11f2c04f6c 100644 --- a/src/ir/diagnostic.cc +++ b/src/ir/diagnostic.cc @@ -100,14 +100,14 @@ TVM_REGISTER_NODE_TYPE(DiagnosticRendererNode); void DiagnosticRenderer::Render(const DiagnosticContext& ctx) { (*this)->renderer(ctx); } TVM_DLL DiagnosticRenderer::DiagnosticRenderer( - TypedPackedFunc renderer) { + ffi::TypedFunction renderer) { auto n = make_object(); n->renderer = renderer; data_ = std::move(n); } TVM_REGISTER_GLOBAL("diagnostics.DiagnosticRenderer") - .set_body_typed([](TypedPackedFunc renderer) { + .set_body_typed([](ffi::TypedFunction renderer) { return DiagnosticRenderer(renderer); }); @@ -177,15 +177,15 @@ static const char* OVERRIDE_RENDERER = "diagnostics.OverrideRenderer"; DiagnosticRenderer GetRenderer() { auto override_pf = tvm::ffi::Function::GetGlobal(OVERRIDE_RENDERER); - tvm::runtime::TypedPackedFunc pf; + tvm::ffi::TypedFunction pf; if (override_pf) { - pf = tvm::runtime::TypedPackedFunc(*override_pf); + pf = tvm::ffi::TypedFunction(*override_pf); } else { auto default_pf = tvm::ffi::Function::GetGlobal(DEFAULT_RENDERER); ICHECK(default_pf.has_value()) << "Can not find registered function for " << DEFAULT_RENDERER << "." << std::endl << "Either this is an internal error or the default function was overloaded incorrectly."; - pf = tvm::runtime::TypedPackedFunc(*default_pf); + pf = tvm::ffi::TypedFunction(*default_pf); } return Downcast(pf()); } diff --git a/src/ir/env_func.cc b/src/ir/env_func.cc index a94223668bdd..9713f88f7ddd 100644 --- a/src/ir/env_func.cc +++ b/src/ir/env_func.cc @@ -26,9 +26,9 @@ namespace tvm { -using runtime::PackedFunc; -using runtime::TVMArgs; -using runtime::TVMRetValue; +using ffi::Any; +using ffi::Function; +using ffi::PackedArgs; TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { @@ -49,13 +49,13 @@ EnvFunc EnvFunc::Get(const String& name) { return EnvFunc(CreateEnvNode(name)); TVM_REGISTER_GLOBAL("ir.EnvFuncGet").set_body_typed(EnvFunc::Get); -TVM_REGISTER_GLOBAL("ir.EnvFuncCall").set_body_packed([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("ir.EnvFuncCall").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { EnvFunc env = args[0].cast(); ICHECK_GE(args.size(), 1); env->func.CallPacked(args.Slice(1), rv); }); -TVM_REGISTER_GLOBAL("ir.EnvFuncGetPackedFunc").set_body_typed([](const EnvFunc& n) { +TVM_REGISTER_GLOBAL("ir.EnvFuncGetFunction").set_body_typed([](const EnvFunc& n) { return n->func; }); diff --git a/src/ir/instrument.cc b/src/ir/instrument.cc index 6701308fbfb7..ad66f2944891 100644 --- a/src/ir/instrument.cc +++ b/src/ir/instrument.cc @@ -39,19 +39,17 @@ namespace instrument { class BasePassInstrumentNode : public PassInstrumentNode { public: /*! \brief Callback to run when entering PassContext. */ - runtime::TypedPackedFunc enter_pass_ctx_callback; + ffi::TypedFunction enter_pass_ctx_callback; /*! \brief Callback to run when exiting PassContext. */ - runtime::TypedPackedFunc exit_pass_ctx_callback; + ffi::TypedFunction exit_pass_ctx_callback; /*! \brief Callback determines whether to run a pass or not. */ - runtime::TypedPackedFunc should_run_callback; + ffi::TypedFunction should_run_callback; /*! \brief Callback to run before a pass. */ - runtime::TypedPackedFunc - run_before_pass_callback; + ffi::TypedFunction run_before_pass_callback; /*! \brief Callback to run after a pass. */ - runtime::TypedPackedFunc - run_after_pass_callback; + ffi::TypedFunction run_after_pass_callback; /*! \brief Instrument when entering PassContext. */ void EnterPassContext() const final; @@ -109,26 +107,23 @@ class BasePassInstrument : public PassInstrument { * \param run_after_pass_callback Callback to call after a pass run. */ TVM_DLL BasePassInstrument( - String name, runtime::TypedPackedFunc enter_pass_ctx_callback, - runtime::TypedPackedFunc exit_pass_ctx_callback, - runtime::TypedPackedFunc - should_run_callback, - runtime::TypedPackedFunc + String name, ffi::TypedFunction enter_pass_ctx_callback, + ffi::TypedFunction exit_pass_ctx_callback, + ffi::TypedFunction should_run_callback, + ffi::TypedFunction run_before_pass_callback, - runtime::TypedPackedFunc + ffi::TypedFunction run_after_pass_callback); TVM_DEFINE_OBJECT_REF_METHODS(BasePassInstrument, PassInstrument, BasePassInstrumentNode); }; BasePassInstrument::BasePassInstrument( - String name, runtime::TypedPackedFunc enter_pass_ctx_callback, - runtime::TypedPackedFunc exit_pass_ctx_callback, - runtime::TypedPackedFunc should_run_callback, - runtime::TypedPackedFunc - run_before_pass_callback, - runtime::TypedPackedFunc - run_after_pass_callback) { + String name, ffi::TypedFunction enter_pass_ctx_callback, + ffi::TypedFunction exit_pass_ctx_callback, + ffi::TypedFunction should_run_callback, + ffi::TypedFunction run_before_pass_callback, + ffi::TypedFunction run_after_pass_callback) { auto pi = make_object(); pi->name = std::move(name); @@ -182,13 +177,11 @@ TVM_REGISTER_NODE_TYPE(BasePassInstrumentNode); TVM_REGISTER_GLOBAL("instrument.PassInstrument") .set_body_typed( - [](String name, runtime::TypedPackedFunc enter_pass_ctx, - runtime::TypedPackedFunc exit_pass_ctx, - runtime::TypedPackedFunc should_run, - runtime::TypedPackedFunc - run_before_pass, - runtime::TypedPackedFunc - run_after_pass) { + [](String name, ffi::TypedFunction enter_pass_ctx, + ffi::TypedFunction exit_pass_ctx, + ffi::TypedFunction should_run, + ffi::TypedFunction run_before_pass, + ffi::TypedFunction run_after_pass) { return BasePassInstrument(name, enter_pass_ctx, exit_pass_ctx, should_run, run_before_pass, run_after_pass); }); diff --git a/src/ir/op.cc b/src/ir/op.cc index d69b558a117c..70f7528e5e76 100644 --- a/src/ir/op.cc +++ b/src/ir/op.cc @@ -33,9 +33,9 @@ namespace tvm { -using runtime::PackedFunc; -using runtime::TVMArgs; -using runtime::TVMRetValue; +using ffi::Any; +using ffi::Function; +using ffi::PackedArgs; using tir::FLowerIntrinsic; using OpRegistry = AttrRegistry; @@ -70,7 +70,7 @@ void OpRegEntry::reset_attr(const std::string& attr_name) { OpRegistry::Global()->ResetAttr(attr_name, op_); } -void OpRegEntry::UpdateAttr(const String& key, TVMRetValue value, int plevel) { +void OpRegEntry::UpdateAttr(const String& key, ffi::Any value, int plevel) { OpRegistry::Global()->UpdateAttr(key, op_, value, plevel); } @@ -81,9 +81,9 @@ TVM_REGISTER_GLOBAL("ir.ListOpNames").set_body_typed([]() { TVM_REGISTER_GLOBAL("ir.GetOp").set_body_typed([](String name) -> Op { return Op::Get(name); }); -TVM_REGISTER_GLOBAL("ir.OpGetAttr").set_body_typed([](Op op, String attr_name) -> TVMRetValue { - auto op_map = Op::GetAttrMap(attr_name); - TVMRetValue rv; +TVM_REGISTER_GLOBAL("ir.OpGetAttr").set_body_typed([](Op op, String attr_name) -> ffi::Any { + auto op_map = Op::GetAttrMap(attr_name); + ffi::Any rv; if (op_map.count(op)) { rv = op_map[op]; } @@ -95,7 +95,7 @@ TVM_REGISTER_GLOBAL("ir.OpHasAttr").set_body_typed([](Op op, String attr_name) - }); TVM_REGISTER_GLOBAL("ir.OpSetAttr") - .set_body_typed([](Op op, String attr_name, runtime::TVMArgValue value, int plevel) { + .set_body_typed([](Op op, String attr_name, ffi::AnyView value, int plevel) { auto& reg = OpRegistry::Global()->RegisterOrGet(op->name).set_name(); reg.set_attr(attr_name, value, plevel); }); @@ -134,7 +134,7 @@ TVM_REGISTER_GLOBAL("ir.OpSetAttrsTypeKey").set_body_typed([](Op op, String key) }); TVM_REGISTER_GLOBAL("ir.RegisterOpAttr") - .set_body_typed([](String op_name, String attr_key, runtime::TVMArgValue value, int plevel) { + .set_body_typed([](String op_name, String attr_key, ffi::AnyView value, int plevel) { auto& reg = OpRegistry::Global()->RegisterOrGet(op_name).set_name(); // enable resgiteration and override of certain properties if (attr_key == "num_inputs" && plevel > 128) { @@ -147,13 +147,13 @@ TVM_REGISTER_GLOBAL("ir.RegisterOpAttr") }); TVM_REGISTER_GLOBAL("ir.RegisterOpLowerIntrinsic") - .set_body_typed([](String name, PackedFunc f, String target, int plevel) { + .set_body_typed([](String name, ffi::Function f, String target, int plevel) { tvm::OpRegEntry::RegisterOrGet(name).set_attr(target + ".FLowerIntrinsic", f, plevel); }); ObjectPtr CreateOp(const std::string& name) { - // Hack use TVMRetValue as exchange + // Hack use ffi::Any as exchange auto op = Op::Get(name); ICHECK(op.defined()) << "Cannot find op \'" << name << '\''; return ffi::details::ObjectUnsafe::ObjectPtrFromObjectRef(op); diff --git a/src/ir/transform.cc b/src/ir/transform.cc index 0756a86f52a2..0730faa7b4d7 100644 --- a/src/ir/transform.cc +++ b/src/ir/transform.cc @@ -43,8 +43,8 @@ namespace tvm { namespace transform { using tvm::ReprPrinter; -using tvm::runtime::TVMArgs; -using tvm::runtime::TVMRetValue; +using tvm::ffi::Any; +using tvm::ffi::PackedArgs; TVM_REGISTER_PASS_CONFIG_OPTION("testing.immutable_module", Bool); @@ -538,7 +538,7 @@ TVM_REGISTER_GLOBAL("transform.PassInfo") return PassInfo(opt_level, name, required, traceable); }); -TVM_REGISTER_GLOBAL("transform.Info").set_body_packed([](TVMArgs args, TVMRetValue* ret) { +TVM_REGISTER_GLOBAL("transform.Info").set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { Pass pass = args[0].cast(); *ret = pass->Info(); }); @@ -565,7 +565,7 @@ TVM_REGISTER_NODE_TYPE(ModulePassNode); TVM_REGISTER_GLOBAL("transform.MakeModulePass") .set_body_typed( - [](runtime::TypedPackedFunc, PassContext)> pass_func, + [](ffi::TypedFunction, PassContext)> pass_func, PassInfo pass_info) { auto wrapped_pass_func = [pass_func](IRModule mod, PassContext ctx) { return pass_func(ffi::RValueRef(std::move(mod)), ctx); @@ -586,15 +586,16 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) TVM_REGISTER_NODE_TYPE(SequentialNode); -TVM_REGISTER_GLOBAL("transform.Sequential").set_body_packed([](TVMArgs args, TVMRetValue* ret) { - auto passes = args[0].cast>(); - int opt_level = args[1].cast(); - std::string name = args[2].cast(); - auto required = args[3].cast>(); - bool traceable = args[4].cast(); - PassInfo pass_info = PassInfo(opt_level, name, required, /* traceable */ traceable); - *ret = Sequential(passes, pass_info); -}); +TVM_REGISTER_GLOBAL("transform.Sequential") + .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { + auto passes = args[0].cast>(); + int opt_level = args[1].cast(); + std::string name = args[2].cast(); + auto required = args[3].cast>(); + bool traceable = args[4].cast(); + PassInfo pass_info = PassInfo(opt_level, name, required, /* traceable */ traceable); + *ret = Sequential(passes, pass_info); + }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { diff --git a/src/meta_schedule/database/schedule_fn_database.cc b/src/meta_schedule/database/schedule_fn_database.cc index ed6d87f2448b..a1a1351812e2 100644 --- a/src/meta_schedule/database/schedule_fn_database.cc +++ b/src/meta_schedule/database/schedule_fn_database.cc @@ -25,7 +25,7 @@ class ScheduleFnDatabaseNode : public DatabaseNode { public: explicit ScheduleFnDatabaseNode(String mod_eq_name = "structural") : DatabaseNode(mod_eq_name) {} - runtime::TypedPackedFunc schedule_fn; + ffi::TypedFunction schedule_fn; void VisitAttrs(AttrVisitor* v) { // `schedule_fn` is not visited. @@ -91,7 +91,7 @@ class ScheduleFnDatabaseNode : public DatabaseNode { } }; -Database Database::ScheduleFnDatabase(runtime::TypedPackedFunc schedule_fn, +Database Database::ScheduleFnDatabase(ffi::TypedFunction schedule_fn, String mod_eq_name) { ObjectPtr n = make_object(mod_eq_name); n->schedule_fn = std::move(schedule_fn); diff --git a/src/meta_schedule/profiler.cc b/src/meta_schedule/profiler.cc index ecbb1c50cea7..b7f261519fd2 100644 --- a/src/meta_schedule/profiler.cc +++ b/src/meta_schedule/profiler.cc @@ -75,11 +75,11 @@ Profiler::Profiler() { data_ = n; } -PackedFunc ProfilerTimedScope(String name) { +ffi::Function ProfilerTimedScope(String name) { if (Optional opt_profiler = Profiler::Current()) { - return TypedPackedFunc([profiler = opt_profiler.value(), // - tik = std::chrono::high_resolution_clock::now(), // - name = std::move(name)]() { + return ffi::TypedFunction([profiler = opt_profiler.value(), // + tik = std::chrono::high_resolution_clock::now(), // + name = std::move(name)]() { auto tok = std::chrono::high_resolution_clock::now(); double duration = std::chrono::duration_cast(tok - tik).count() / 1e9; diff --git a/src/meta_schedule/schedule_rule/apply_custom_rule.cc b/src/meta_schedule/schedule_rule/apply_custom_rule.cc index b89dc824c5d1..15962baa927a 100644 --- a/src/meta_schedule/schedule_rule/apply_custom_rule.cc +++ b/src/meta_schedule/schedule_rule/apply_custom_rule.cc @@ -50,7 +50,7 @@ class ApplyCustomRuleNode : public ScheduleRuleNode { } std::ostringstream os; os << "Unknown schedule rule \"" << ann.value() << "\" for target keys \"" << keys - << "\". Checked PackedFuncs:"; + << "\". Checked ffi::Functions:"; for (const String& key : keys) { os << "\n " << GetCustomRuleName(ann.value(), key); } diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling.cc b/src/meta_schedule/schedule_rule/multi_level_tiling.cc index b7e7cd09644e..79cff3bad738 100644 --- a/src/meta_schedule/schedule_rule/multi_level_tiling.cc +++ b/src/meta_schedule/schedule_rule/multi_level_tiling.cc @@ -397,7 +397,7 @@ ScheduleRule ScheduleRule::MultiLevelTiling(String structure, Optional> vector_load_lens, Optional> reuse_read, Optional> reuse_write, - Optional filter_fn) { + Optional filter_fn) { auto node = MultiLevelTilingInitCommon( structure, tile_binds, max_innermost_factor, vector_load_lens, reuse_read, reuse_write); node->filter_fn_ = filter_fn; diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling.h b/src/meta_schedule/schedule_rule/multi_level_tiling.h index 374da793e8b7..b46eac23ad7e 100644 --- a/src/meta_schedule/schedule_rule/multi_level_tiling.h +++ b/src/meta_schedule/schedule_rule/multi_level_tiling.h @@ -216,9 +216,9 @@ class MultiLevelTilingNode : public ScheduleRuleNode { /*! \brief All available async pipeline stages. */ std::vector stages; /*! \brief The logging function */ - PackedFunc logger; + ffi::Function logger; /*! \brief The function to overwrite the default condition for applying MultiLevelTiling. */ - Optional filter_fn_; + Optional filter_fn_; void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("structure", &structure); diff --git a/src/meta_schedule/space_generator/post_order_apply.cc b/src/meta_schedule/space_generator/post_order_apply.cc index 1bde99869ed2..da2178c736a1 100644 --- a/src/meta_schedule/space_generator/post_order_apply.cc +++ b/src/meta_schedule/space_generator/post_order_apply.cc @@ -30,7 +30,7 @@ class PostOrderApplyNode : public SpaceGeneratorNode { /*! * \brief Optional block names to target. If not specified all blocks will have spaces generated. */ - runtime::PackedFunc f_block_filter_ = nullptr; + ffi::Function f_block_filter_ = nullptr; /*! \brief The random state. -1 means using random number. */ TRandState rand_state_ = -1; @@ -103,7 +103,7 @@ class PostOrderApplyNode : public SpaceGeneratorNode { TVM_DECLARE_FINAL_OBJECT_INFO(PostOrderApplyNode, SpaceGeneratorNode); }; -SpaceGenerator SpaceGenerator::PostOrderApply(runtime::PackedFunc f_block_filter, +SpaceGenerator SpaceGenerator::PostOrderApply(ffi::Function f_block_filter, Optional> sch_rules, Optional> postprocs, Optional> mutator_probs) { diff --git a/src/meta_schedule/space_generator/schedule_fn.cc b/src/meta_schedule/space_generator/schedule_fn.cc index 1c1b5078a50c..89a02876f3d9 100644 --- a/src/meta_schedule/space_generator/schedule_fn.cc +++ b/src/meta_schedule/space_generator/schedule_fn.cc @@ -27,7 +27,7 @@ class ScheduleFnNode : public SpaceGeneratorNode { /*! \brief The random state. -1 means using random number. */ TRandState rand_state_ = -1; /*! \brief The schedule function. */ - runtime::PackedFunc schedule_fn_; + ffi::Function schedule_fn_; void VisitAttrs(tvm::AttrVisitor* v) { SpaceGeneratorNode::VisitAttrs(v); @@ -45,7 +45,7 @@ class ScheduleFnNode : public SpaceGeneratorNode { /*rand_state=*/ForkSeed(&this->rand_state_), /*debug_mode=*/0, /*error_render_level=*/tir::ScheduleErrorRenderLevel::kDetail); - runtime::TVMRetValue rv; + ffi::Any rv; rv = this->schedule_fn_(sch); if (rv == nullptr) { return {sch}; @@ -84,7 +84,7 @@ class ScheduleFnNode : public SpaceGeneratorNode { TVM_DECLARE_FINAL_OBJECT_INFO(ScheduleFnNode, SpaceGeneratorNode); }; -SpaceGenerator SpaceGenerator::ScheduleFn(PackedFunc schedule_fn, +SpaceGenerator SpaceGenerator::ScheduleFn(ffi::Function schedule_fn, Optional> sch_rules, Optional> postprocs, Optional> mutator_probs) { diff --git a/src/meta_schedule/task_scheduler/gradient_based.cc b/src/meta_schedule/task_scheduler/gradient_based.cc index 5b261eec32a4..c750067ace9f 100644 --- a/src/meta_schedule/task_scheduler/gradient_based.cc +++ b/src/meta_schedule/task_scheduler/gradient_based.cc @@ -134,7 +134,7 @@ class GradientBasedNode final : public TaskSchedulerNode { } }; -TaskScheduler TaskScheduler::GradientBased(PackedFunc logger, double alpha, int window_size, +TaskScheduler TaskScheduler::GradientBased(ffi::Function logger, double alpha, int window_size, support::LinearCongruentialEngine::TRandState seed) { ObjectPtr n = make_object(); n->logger = logger; diff --git a/src/meta_schedule/task_scheduler/round_robin.cc b/src/meta_schedule/task_scheduler/round_robin.cc index d09f2c2ba791..d7c6f37e121d 100644 --- a/src/meta_schedule/task_scheduler/round_robin.cc +++ b/src/meta_schedule/task_scheduler/round_robin.cc @@ -55,7 +55,7 @@ class RoundRobinNode final : public TaskSchedulerNode { } }; -TaskScheduler TaskScheduler::RoundRobin(PackedFunc logger) { +TaskScheduler TaskScheduler::RoundRobin(ffi::Function logger) { ObjectPtr n = make_object(); n->logger = logger; n->task_id = -1; diff --git a/src/meta_schedule/task_scheduler/task_scheduler.cc b/src/meta_schedule/task_scheduler/task_scheduler.cc index f17cfd55cf08..ca5c6e4988a3 100644 --- a/src/meta_schedule/task_scheduler/task_scheduler.cc +++ b/src/meta_schedule/task_scheduler/task_scheduler.cc @@ -99,7 +99,7 @@ void TaskCleanUp(TaskRecordNode* self, int task_id, const Array& r ICHECK_EQ(self->runner_futures.value().size(), results.size()); int n = results.size(); std::string name = self->ctx->task_name.value(); - const PackedFunc& logger = self->ctx->logger; + const ffi::Function& logger = self->ctx->logger; for (int i = 0; i < n; ++i) { const BuilderResult& builder_result = self->builder_results.value()[i]; const MeasureCandidate& candidate = self->measure_candidates.value()[i]; @@ -322,7 +322,7 @@ void TaskSchedulerNode::PrintTuningStatistics() { } TaskScheduler TaskScheduler::PyTaskScheduler( - PackedFunc logger, PyTaskSchedulerNode::FNextTaskId f_next_task_id, + ffi::Function logger, PyTaskSchedulerNode::FNextTaskId f_next_task_id, PyTaskSchedulerNode::FJoinRunningTask f_join_running_task, PyTaskSchedulerNode::FTune f_tune) { CHECK(f_next_task_id != nullptr) << "ValueError: next_task_id is not defined"; ObjectPtr n = make_object(); diff --git a/src/meta_schedule/tune_context.cc b/src/meta_schedule/tune_context.cc index e9cb17eeee55..275f8d124cd1 100644 --- a/src/meta_schedule/tune_context.cc +++ b/src/meta_schedule/tune_context.cc @@ -26,7 +26,7 @@ namespace meta_schedule { TuneContext::TuneContext(Optional mod, Optional target, Optional space_generator, Optional search_strategy, Optional task_name, - int num_threads, TRandState rand_state, PackedFunc logger) { + int num_threads, TRandState rand_state, ffi::Function logger) { CHECK(rand_state == -1 || rand_state >= 0) << "ValueError: Invalid random state: " << rand_state; ObjectPtr n = make_object(); n->mod = mod; @@ -67,7 +67,8 @@ TVM_REGISTER_GLOBAL("meta_schedule.TuneContext") .set_body_typed([](Optional mod, Optional target, Optional space_generator, Optional search_strategy, Optional task_name, - int num_threads, TRandState rand_state, PackedFunc logger) -> TuneContext { + int num_threads, TRandState rand_state, + ffi::Function logger) -> TuneContext { return TuneContext(mod, target, space_generator, search_strategy, task_name, num_threads, rand_state, logger); }); diff --git a/src/meta_schedule/utils.h b/src/meta_schedule/utils.h index adf1334385ee..de777e305919 100644 --- a/src/meta_schedule/utils.h +++ b/src/meta_schedule/utils.h @@ -83,7 +83,7 @@ class PyLogMessage { // FATAL not included }; - explicit PyLogMessage(const char* filename, int lineno, PackedFunc logger, Level logging_level) + explicit PyLogMessage(const char* filename, int lineno, ffi::Function logger, Level logging_level) : filename_(filename), lineno_(lineno), logger_(logger), logging_level_(logging_level) {} TVM_NO_INLINE ~PyLogMessage() { @@ -115,7 +115,7 @@ class PyLogMessage { const char* filename_; int lineno_; std::ostringstream stream_; - PackedFunc logger_; + ffi::Function logger_; Level logging_level_; }; @@ -150,7 +150,7 @@ inline void print_interactive_table(const String& data) { * \param lineno The line number. * \param logging_func The logging function. */ -inline void clear_logging(const char* file, int lineno, PackedFunc logging_func) { +inline void clear_logging(const char* file, int lineno, ffi::Function logging_func) { if (const char* env_p = std::getenv("TVM_META_SCHEDULE_CLEAR_SCREEN")) { if (std::string(env_p) == "1") { if (logging_func.defined() && using_ipython()) { @@ -569,7 +569,7 @@ inline double Sum(const Array& arr) { class BlockCollector : public tir::StmtVisitor { public: static Array Collect(const tir::Schedule& sch, - const runtime::PackedFunc f_block_filter = nullptr) { // + const ffi::Function f_block_filter = nullptr) { // return BlockCollector(sch, f_block_filter).Run(); } @@ -603,8 +603,7 @@ class BlockCollector : public tir::StmtVisitor { return results; } /*! \brief Constructor */ - explicit BlockCollector(const tir::Schedule& sch, - const runtime::PackedFunc f_block_filter = nullptr) + explicit BlockCollector(const tir::Schedule& sch, const ffi::Function f_block_filter = nullptr) : sch_(sch), f_block_filter_(f_block_filter) {} /*! \brief Override the Stmt visiting behaviour */ void VisitStmt_(const tir::BlockNode* block) override { @@ -628,7 +627,7 @@ class BlockCollector : public tir::StmtVisitor { /*! \brief The schedule to be collected */ const tir::Schedule& sch_; /*! \brief An optional packed func that allows only certain blocks to be collected. */ - const runtime::PackedFunc f_block_filter_; + const ffi::Function f_block_filter_; /*! \brief The set of func name and block name pair */ std::unordered_set block_names_; /* \brief The list of blocks to collect in order */ diff --git a/src/node/attr_registry.h b/src/node/attr_registry.h index e150549bb937..9ec39e9f6aae 100644 --- a/src/node/attr_registry.h +++ b/src/node/attr_registry.h @@ -93,7 +93,7 @@ class AttrRegistry { * \param plevel The support level. */ void UpdateAttr(const String& attr_name, const KeyType& key, Any value, int plevel) { - using runtime::TVMRetValue; + using ffi::Any; auto& op_map = attrs_[attr_name]; if (op_map == nullptr) { op_map.reset(new AttrRegistryMapContainerMap()); @@ -104,7 +104,7 @@ class AttrRegistry { if (op_map->data_.size() <= index) { op_map->data_.resize(index + 1, std::make_pair(Any(), 0)); } - std::pair& p = op_map->data_[index]; + std::pair& p = op_map->data_[index]; ICHECK(p.second != plevel) << "Attribute " << attr_name << " of " << key->AttrRegistryName() << " is already registered with same plevel=" << plevel; ICHECK(value != nullptr) << "Registered packed_func is Null for " << attr_name @@ -126,7 +126,7 @@ class AttrRegistry { } uint32_t index = key->AttrRegistryIndex(); if (op_map->data_.size() > index) { - op_map->data_[index] = std::make_pair(TVMRetValue(), 0); + op_map->data_[index] = std::make_pair(ffi::Any(), 0); } } diff --git a/src/node/reflection.cc b/src/node/reflection.cc index 067f4abb2032..cf9f0dd3bd6e 100644 --- a/src/node/reflection.cc +++ b/src/node/reflection.cc @@ -28,17 +28,17 @@ namespace tvm { -using runtime::PackedFunc; -using runtime::TVMArgs; -using runtime::TVMRetValue; +using ffi::Any; +using ffi::Function; +using ffi::PackedArgs; // Attr getter. class AttrGetter : public AttrVisitor { public: const String& skey; - TVMRetValue* ret; + ffi::Any* ret; - AttrGetter(const String& skey, TVMRetValue* ret) : skey(skey), ret(ret) {} + AttrGetter(const String& skey, ffi::Any* ret) : skey(skey), ret(ret) {} bool found_ref_object{false}; @@ -95,8 +95,8 @@ class AttrGetter : public AttrVisitor { } }; -runtime::TVMRetValue ReflectionVTable::GetAttr(Object* self, const String& field_name) const { - runtime::TVMRetValue ret; +ffi::Any ReflectionVTable::GetAttr(Object* self, const String& field_name) const { + ffi::Any ret; AttrGetter getter(field_name, &ret); bool success; @@ -179,7 +179,7 @@ ObjectPtr ReflectionVTable::CreateInitObject(const std::string& type_key class NodeAttrSetter : public AttrVisitor { public: std::string type_key; - std::unordered_map attrs; + std::unordered_map attrs; void Visit(const char* key, double* value) final { *value = GetAttr(key).cast(); } void Visit(const char* key, int64_t* value) final { *value = GetAttr(key).cast(); } @@ -204,18 +204,18 @@ class NodeAttrSetter : public AttrVisitor { } private: - runtime::TVMArgValue GetAttr(const char* key) { + ffi::AnyView GetAttr(const char* key) { auto it = attrs.find(key); if (it == attrs.end()) { LOG(FATAL) << type_key << ": require field " << key; } - runtime::TVMArgValue v = it->second; + ffi::AnyView v = it->second; attrs.erase(it); return v; } }; -void InitNodeByPackedArgs(ReflectionVTable* reflection, Object* n, const TVMArgs& args) { +void InitNodeByPackedArgs(ReflectionVTable* reflection, Object* n, const ffi::PackedArgs& args) { NodeAttrSetter setter; setter.type_key = n->GetTypeKey(); ICHECK_EQ(args.size() % 2, 0); @@ -234,7 +234,8 @@ void InitNodeByPackedArgs(ReflectionVTable* reflection, Object* n, const TVMArgs } } -ObjectRef ReflectionVTable::CreateObject(const std::string& type_key, const TVMArgs& kwargs) { +ObjectRef ReflectionVTable::CreateObject(const std::string& type_key, + const ffi::PackedArgs& kwargs) { ObjectPtr n = this->CreateInitObject(type_key); if (n->IsInstance()) { static_cast(n.get())->InitByPackedArgs(kwargs); @@ -246,7 +247,7 @@ ObjectRef ReflectionVTable::CreateObject(const std::string& type_key, const TVMA ObjectRef ReflectionVTable::CreateObject(const std::string& type_key, const Map& kwargs) { - // Redirect to the TVMArgs version + // Redirect to the ffi::PackedArgs version // It is not the most efficient way, but CreateObject is not meant to be used // in a fast code-path and is mainly reserved as a flexible API for frontends. std::vector packed_args(kwargs.size() * 2); @@ -262,18 +263,18 @@ ObjectRef ReflectionVTable::CreateObject(const std::string& type_key, } // Expose to FFI APIs. -void NodeGetAttr(TVMArgs args, TVMRetValue* ret) { +void NodeGetAttr(ffi::PackedArgs args, ffi::Any* ret) { Object* self = const_cast(args[0].cast()); *ret = ReflectionVTable::Global()->GetAttr(self, args[1].cast()); } -void NodeListAttrNames(TVMArgs args, TVMRetValue* ret) { +void NodeListAttrNames(ffi::PackedArgs args, ffi::Any* ret) { Object* self = const_cast(args[0].cast()); auto names = std::make_shared>(ReflectionVTable::Global()->ListAttrNames(self)); - *ret = PackedFunc([names](TVMArgs args, TVMRetValue* rv) { + *ret = ffi::Function([names](ffi::PackedArgs args, ffi::Any* rv) { int64_t i = args[0].cast(); if (i == -1) { *rv = static_cast(names->size()); @@ -286,7 +287,7 @@ void NodeListAttrNames(TVMArgs args, TVMRetValue* ret) { // API function to make node. // args format: // key1, value1, ..., key_n, value_n -void MakeNode(const TVMArgs& args, TVMRetValue* rv) { +void MakeNode(const ffi::PackedArgs& args, ffi::Any* rv) { auto type_key = args[0].cast(); *rv = ReflectionVTable::Global()->CreateObject(type_key, args.Slice(1)); } diff --git a/src/relax/backend/vm/exec_builder.cc b/src/relax/backend/vm/exec_builder.cc index 85f013d98693..7bc5cafc0a9d 100644 --- a/src/relax/backend/vm/exec_builder.cc +++ b/src/relax/backend/vm/exec_builder.cc @@ -330,9 +330,9 @@ void ExecBuilderNode::Formalize() { TVM_REGISTER_GLOBAL("relax.ExecBuilderCreate").set_body_typed(ExecBuilderNode::Create); TVM_REGISTER_GLOBAL("relax.ExecBuilderConvertConstant") - .set_body_packed([](TVMArgs args, TVMRetValue* ret) { + .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { ExecBuilder builder = args[0].cast(); - TVMRetValue rt; + ffi::Any rt; rt = args[1]; *ret = builder->ConvertConstant(rt).data(); }); diff --git a/src/relax/ir/dataflow_block_rewriter.cc b/src/relax/ir/dataflow_block_rewriter.cc index 88efad86cfdc..661c43842db4 100644 --- a/src/relax/ir/dataflow_block_rewriter.cc +++ b/src/relax/ir/dataflow_block_rewriter.cc @@ -370,13 +370,13 @@ TVM_REGISTER_GLOBAL("relax.dpl.match_dfb") class PatternContextRewriterNode : public PatternMatchingRewriterNode { public: PatternContext pattern; - TypedPackedFunc(Map, Map)> rewriter_func; + ffi::TypedFunction(Map, Map)> rewriter_func; RewriteSpec RewriteBindings(const Array& bindings) const override; void VisitAttrs(AttrVisitor* visitor) { visitor->Visit("pattern", &pattern); - PackedFunc untyped_func = rewriter_func; + ffi::Function untyped_func = rewriter_func; visitor->Visit("rewriter_func", &untyped_func); } @@ -405,7 +405,7 @@ class PatternContextRewriter : public PatternMatchingRewriter { public: PatternContextRewriter( PatternContext pattern, - TypedPackedFunc(Map, Map)> rewriter_func); + ffi::TypedFunction(Map, Map)> rewriter_func); TVM_DEFINE_OBJECT_REF_METHODS(PatternContextRewriter, PatternMatchingRewriter, PatternContextRewriterNode); @@ -432,7 +432,7 @@ RewriteSpec PatternContextRewriterNode::RewriteBindings(const Array& bi PatternContextRewriter::PatternContextRewriter( PatternContext pattern, - TypedPackedFunc(Map, Map)> rewriter_func) { + ffi::TypedFunction(Map, Map)> rewriter_func) { auto node = make_object(); node->pattern = std::move(pattern); node->rewriter_func = std::move(rewriter_func); @@ -441,7 +441,8 @@ PatternContextRewriter::PatternContextRewriter( Function RewriteBindings( const PatternContext& ctx, - TypedPackedFunc(Map, Map)> rewriter, Function func) { + ffi::TypedFunction(Map, Map)> rewriter, + Function func) { // return BlockPatternRewriter::Run(ctx, rewriter, func); return Downcast(PatternContextRewriter(ctx, rewriter)(func)); } diff --git a/src/relax/ir/dataflow_expr_rewriter.cc b/src/relax/ir/dataflow_expr_rewriter.cc index 514116c5cadf..da1614f50b47 100644 --- a/src/relax/ir/dataflow_expr_rewriter.cc +++ b/src/relax/ir/dataflow_expr_rewriter.cc @@ -195,7 +195,7 @@ TVM_REGISTER_NODE_TYPE(PatternMatchingRewriterNode); TVM_REGISTER_GLOBAL("relax.dpl.PatternMatchingRewriterFromPattern") .set_body_typed([](DFPattern pattern, - TypedPackedFunc(Expr, Map)> func) { + ffi::TypedFunction(Expr, Map)> func) { return PatternMatchingRewriter::FromPattern(pattern, func); }); @@ -261,12 +261,12 @@ Optional ExprPatternRewriterNode::RewriteExpr(const Expr& expr, TVM_REGISTER_GLOBAL("relax.dpl.PatternRewriter") .set_body_typed([](DFPattern pattern, - TypedPackedFunc(Expr, Map)> func) { + ffi::TypedFunction(Expr, Map)> func) { return ExprPatternRewriter(pattern, func); }); ExprPatternRewriter::ExprPatternRewriter( - DFPattern pattern, TypedPackedFunc(Expr, Map)> func, + DFPattern pattern, ffi::TypedFunction(Expr, Map)> func, Optional> additional_bindings, Map new_subroutines) { auto node = make_object(); node->pattern = std::move(pattern); @@ -605,12 +605,12 @@ std::optional> TupleRewriterNode::TryMatchByBindingIndex( TVM_REGISTER_GLOBAL("relax.dpl.TupleRewriter") .set_body_typed([](Array patterns, - TypedPackedFunc(Expr, Map)> func) { + ffi::TypedFunction(Expr, Map)> func) { return TupleRewriter(patterns, func); }); TupleRewriter::TupleRewriter(Array patterns, - TypedPackedFunc(Expr, Map)> func, + ffi::TypedFunction(Expr, Map)> func, Optional> additional_bindings, Map new_subroutines) { auto node = make_object(); @@ -622,7 +622,7 @@ TupleRewriter::TupleRewriter(Array patterns, } PatternMatchingRewriter PatternMatchingRewriter::FromPattern( - DFPattern pattern, TypedPackedFunc(Expr, Map)> func, + DFPattern pattern, ffi::TypedFunction(Expr, Map)> func, Optional> additional_bindings, Map new_subroutines) { if (auto or_pattern = pattern.as()) { auto new_additional_bindings = additional_bindings.value_or({}); @@ -749,7 +749,7 @@ PatternMatchingRewriter PatternMatchingRewriter::FromModule(IRModule mod) { DFPattern top_pattern = make_pattern(func_pattern->body->body); - TypedPackedFunc(Expr, Map)> rewriter_func = + ffi::TypedFunction(Expr, Map)> rewriter_func = [param_wildcards = std::move(param_wildcards), orig_func_replacement = std::move(func_replacement)]( Expr expr, Map matches) -> Optional { @@ -1069,7 +1069,7 @@ tvm::transform::PassInfo PatternMatchingRewriterNode::Info() const { } Function RewriteCall(const DFPattern& pat, - TypedPackedFunc)> rewriter, Function func) { + ffi::TypedFunction)> rewriter, Function func) { return Downcast(PatternMatchingRewriter::FromPattern(pat, rewriter)(func)); } diff --git a/src/relax/ir/dataflow_matcher.cc b/src/relax/ir/dataflow_matcher.cc index a4b2ea3c0f08..1176f1eaee7e 100644 --- a/src/relax/ir/dataflow_matcher.cc +++ b/src/relax/ir/dataflow_matcher.cc @@ -145,7 +145,7 @@ bool DFPatternMatcher::VisitDFPattern_(const AttrPatternNode* attr_pattern, cons auto attr_name = kv.first; auto attr_value = kv.second; if (Op::HasAttrMap(attr_name)) { - auto op_map = Op::GetAttrMap(attr_name); + auto op_map = Op::GetAttrMap(attr_name); if (op_map.count(op)) { matches &= StructuralEqual()(attr_value, op_map[op]); } else { @@ -157,7 +157,7 @@ bool DFPatternMatcher::VisitDFPattern_(const AttrPatternNode* attr_pattern, cons } } else if (auto* op = expr.as()) { matches = true; - // TODO(mbrookhart): When OpNode Attrs move from TVMRetValue to the Object system, remove this + // TODO(mbrookhart): When OpNode Attrs move from ffi::Any to the Object system, remove this // and replace the whole thing with a Visitor-based approach ReflectionVTable* reflection = ReflectionVTable::Global(); auto attrs_node = const_cast(op->attrs.get()); diff --git a/src/relax/ir/dataflow_rewriter.h b/src/relax/ir/dataflow_rewriter.h index 53f934982c59..4eec98373d0c 100644 --- a/src/relax/ir/dataflow_rewriter.h +++ b/src/relax/ir/dataflow_rewriter.h @@ -66,7 +66,7 @@ class PatternMatchingRewriterNode : public tvm::transform::PassNode { class PatternMatchingRewriter : public tvm::transform::Pass { public: static PatternMatchingRewriter FromPattern( - DFPattern pattern, TypedPackedFunc(Expr, Map)> func, + DFPattern pattern, ffi::TypedFunction(Expr, Map)> func, Optional> additional_bindings = NullOpt, Map new_subroutines = {}); @@ -81,7 +81,7 @@ class PatternMatchingRewriter : public tvm::transform::Pass { class ExprPatternRewriterNode : public PatternMatchingRewriterNode { public: DFPattern pattern; - TypedPackedFunc(Expr, Map)> func; + ffi::TypedFunction(Expr, Map)> func; Optional> additional_bindings; Map new_subroutines; @@ -91,7 +91,7 @@ class ExprPatternRewriterNode : public PatternMatchingRewriterNode { void VisitAttrs(AttrVisitor* visitor) { visitor->Visit("pattern", &pattern); - PackedFunc untyped_func = func; + ffi::Function untyped_func = func; visitor->Visit("func", &untyped_func); } @@ -102,7 +102,7 @@ class ExprPatternRewriterNode : public PatternMatchingRewriterNode { class ExprPatternRewriter : public PatternMatchingRewriter { public: ExprPatternRewriter(DFPattern pattern, - TypedPackedFunc(Expr, Map)> func, + ffi::TypedFunction(Expr, Map)> func, Optional> additional_bindings = NullOpt, Map new_subroutines = {}); @@ -136,7 +136,7 @@ class OrRewriter : public PatternMatchingRewriter { class TupleRewriterNode : public PatternMatchingRewriterNode { public: Array patterns; - TypedPackedFunc(Expr, Map)> func; + ffi::TypedFunction(Expr, Map)> func; Optional> additional_bindings; Map new_subroutines; @@ -144,7 +144,7 @@ class TupleRewriterNode : public PatternMatchingRewriterNode { void VisitAttrs(AttrVisitor* visitor) { visitor->Visit("patterns", &patterns); - PackedFunc untyped_func = func; + ffi::Function untyped_func = func; visitor->Visit("func", &untyped_func); } @@ -169,7 +169,7 @@ class TupleRewriterNode : public PatternMatchingRewriterNode { class TupleRewriter : public PatternMatchingRewriter { public: TupleRewriter(Array patterns, - TypedPackedFunc(Expr, Map)> func, + ffi::TypedFunction(Expr, Map)> func, Optional> additional_bindings = NullOpt, Map new_subroutines = {}); diff --git a/src/relax/ir/expr_functor.cc b/src/relax/ir/expr_functor.cc index 3ee403a25cda..a450919decff 100644 --- a/src/relax/ir/expr_functor.cc +++ b/src/relax/ir/expr_functor.cc @@ -326,9 +326,10 @@ void PostOrderVisit(const Expr& e, std::function fvisit) { ExprApplyVisit(fvisit).VisitExpr(e); } -TVM_REGISTER_GLOBAL("relax.analysis.post_order_visit").set_body_typed([](Expr expr, PackedFunc f) { - PostOrderVisit(expr, [f](const Expr& n) { f(n); }); -}); +TVM_REGISTER_GLOBAL("relax.analysis.post_order_visit") + .set_body_typed([](Expr expr, ffi::Function f) { + PostOrderVisit(expr, [f](const Expr& n) { f(n); }); + }); // ================== // ExprMutatorBase diff --git a/src/relax/ir/py_expr_functor.cc b/src/relax/ir/py_expr_functor.cc index 1da1c29b2057..4a36bf214884 100644 --- a/src/relax/ir/py_expr_functor.cc +++ b/src/relax/ir/py_expr_functor.cc @@ -36,64 +36,64 @@ class PyExprVisitorNode : public Object, public ExprVisitor { public: /*! \brief The packed function to the `VisitExpr(const Expr& expr)` function. */ - PackedFunc f_visit_expr{nullptr}; + ffi::Function f_visit_expr{nullptr}; /*! \brief The packed function to the `VisitExpr_(const ConstantNode* op)` function. */ - PackedFunc f_visit_constant_{nullptr}; + ffi::Function f_visit_constant_{nullptr}; /*! \brief The packed function to the `VisitExpr_(const TupleNode* op)` function. */ - PackedFunc f_visit_tuple_{nullptr}; + ffi::Function f_visit_tuple_{nullptr}; /*! \brief The packed function to the `VisitExpr_(const VarNode* op)` function. */ - PackedFunc f_visit_var_{nullptr}; + ffi::Function f_visit_var_{nullptr}; /*! \brief The packed function to the `VisitExpr_(const DataflowVarNode* op)` function. */ - PackedFunc f_visit_dataflow_var_{nullptr}; + ffi::Function f_visit_dataflow_var_{nullptr}; /*! \brief The packed function to the `VisitExpr_(const ShapeExprNode* op)` function. */ - PackedFunc f_visit_shape_expr_{nullptr}; + ffi::Function f_visit_shape_expr_{nullptr}; /*! \brief The packed function to the `VisitExpr_(const ExternFuncNode* op)` function. */ - PackedFunc f_visit_extern_func_{nullptr}; + ffi::Function f_visit_extern_func_{nullptr}; /*! \brief The packed function to the `VisitExpr_(const GlobalVarNode* op)` function. */ - PackedFunc f_visit_global_var_{nullptr}; + ffi::Function f_visit_global_var_{nullptr}; /*! \brief The packed function to the `VisitExpr_(const FunctionNode* op)` function. */ - PackedFunc f_visit_function_{nullptr}; + ffi::Function f_visit_function_{nullptr}; /*! \brief The packed function to the `VisitExpr_(const CallNode* op)` function. */ - PackedFunc f_visit_call_{nullptr}; + ffi::Function f_visit_call_{nullptr}; /*! \brief The packed function to the `VisitExpr_(const SeqExprNode* op)` function. */ - PackedFunc f_visit_seq_expr_{nullptr}; + ffi::Function f_visit_seq_expr_{nullptr}; /*! \brief The packed function to the `VisitExpr_(const IfNode* op)` function. */ - PackedFunc f_visit_if_{nullptr}; + ffi::Function f_visit_if_{nullptr}; /*! \brief The packed function to the `VisitExpr_(const OpNode* op)` function. */ - PackedFunc f_visit_op_{nullptr}; + ffi::Function f_visit_op_{nullptr}; /*! \brief The packed function to the `VisitExpr_(const TupleGetItemNode* op)` function. */ - PackedFunc f_visit_tuple_getitem_{nullptr}; + ffi::Function f_visit_tuple_getitem_{nullptr}; /*! \brief The packed function to the `VisitExpr_(const PrimValueNode* op)` function. */ - PackedFunc f_visit_prim_value_{nullptr}; + ffi::Function f_visit_prim_value_{nullptr}; /*! \brief The packed function to the `VisitExpr_(const StringImmNode* op)` function. */ - PackedFunc f_visit_string_imm_{nullptr}; + ffi::Function f_visit_string_imm_{nullptr}; /*! \brief The packed function to the `VisitExpr_(const DataTypeImmNode* op)` function. */ - PackedFunc f_visit_data_type_imm_{nullptr}; + ffi::Function f_visit_data_type_imm_{nullptr}; /*! \brief The packed function to the `VisitBinding(const Binding& binding)` function. */ - PackedFunc f_visit_binding{nullptr}; + ffi::Function f_visit_binding{nullptr}; /*! \brief The packed function to the `VisitBinding_(const VarBindingNode* binding)` * function. */ - PackedFunc f_visit_var_binding_{nullptr}; + ffi::Function f_visit_var_binding_{nullptr}; /*! \brief The packed function to the `VisitBinding_(const MatchCastNode* binding)` * function. */ - PackedFunc f_visit_match_cast_{nullptr}; + ffi::Function f_visit_match_cast_{nullptr}; /*! \brief The packed function to the `VisitBindingBlock(const BindingBlock& block)` * function. */ - PackedFunc f_visit_binding_block{nullptr}; + ffi::Function f_visit_binding_block{nullptr}; /*! \brief The packed function to the `VisitBindingBlock_(const BindingBlockNode* block)` * function. */ - PackedFunc f_visit_binding_block_{nullptr}; + ffi::Function f_visit_binding_block_{nullptr}; /*! \brief The packed function to the `VisitBindingBlock_(const DataflowBlockNode* block)` * function. */ - PackedFunc f_visit_dataflow_block_{nullptr}; + ffi::Function f_visit_dataflow_block_{nullptr}; /*! \brief The packed function to the `VisitVarDef(const Var& var)` function. */ - PackedFunc f_visit_var_def{nullptr}; + ffi::Function f_visit_var_def{nullptr}; /*! \brief The packed function to the `VisitVarDef_(const VarNode* var)` function. */ - PackedFunc f_visit_var_def_{nullptr}; + ffi::Function f_visit_var_def_{nullptr}; /*! \brief The packed function to the `VisitVarDef_(const DataflowVarNode* var)` function. */ - PackedFunc f_visit_dataflow_var_def_{nullptr}; + ffi::Function f_visit_dataflow_var_def_{nullptr}; /*! \brief The packed function to the `VisitSpan(const Span& span)` function. */ - PackedFunc f_visit_span{nullptr}; + ffi::Function f_visit_span{nullptr}; void VisitExpr(const Expr& expr) { if (f_visit_expr != nullptr) { @@ -212,16 +212,19 @@ class PyExprVisitor : public ObjectRef { * \return The PyVisitor created. */ TVM_DLL static PyExprVisitor MakePyExprVisitor( - PackedFunc f_visit_expr, PackedFunc f_visit_constant_, PackedFunc f_visit_tuple_, - PackedFunc f_visit_var_, PackedFunc f_visit_dataflow_var_, PackedFunc f_visit_shape_expr_, - PackedFunc f_visit_extern_func_, PackedFunc f_visit_global_var_, PackedFunc f_visit_function_, - PackedFunc f_visit_call_, PackedFunc f_visit_seq_expr_, PackedFunc f_visit_if_, - PackedFunc f_visit_op_, PackedFunc f_visit_tuple_getitem_, PackedFunc f_visit_prim_value_, - PackedFunc f_visit_string_imm_, PackedFunc f_visit_data_type_imm_, PackedFunc f_visit_binding, - PackedFunc f_visit_var_binding_, PackedFunc f_visit_match_cast_, - PackedFunc f_visit_binding_block, PackedFunc f_visit_binding_block_, - PackedFunc f_visit_dataflow_block_, PackedFunc f_visit_var_def, PackedFunc f_visit_var_def_, - PackedFunc f_visit_dataflow_var_def_, PackedFunc f_visit_span) { + ffi::Function f_visit_expr, ffi::Function f_visit_constant_, ffi::Function f_visit_tuple_, + ffi::Function f_visit_var_, ffi::Function f_visit_dataflow_var_, + ffi::Function f_visit_shape_expr_, ffi::Function f_visit_extern_func_, + ffi::Function f_visit_global_var_, ffi::Function f_visit_function_, + ffi::Function f_visit_call_, ffi::Function f_visit_seq_expr_, ffi::Function f_visit_if_, + ffi::Function f_visit_op_, ffi::Function f_visit_tuple_getitem_, + ffi::Function f_visit_prim_value_, ffi::Function f_visit_string_imm_, + ffi::Function f_visit_data_type_imm_, ffi::Function f_visit_binding, + ffi::Function f_visit_var_binding_, ffi::Function f_visit_match_cast_, + ffi::Function f_visit_binding_block, ffi::Function f_visit_binding_block_, + ffi::Function f_visit_dataflow_block_, ffi::Function f_visit_var_def, + ffi::Function f_visit_var_def_, ffi::Function f_visit_dataflow_var_def_, + ffi::Function f_visit_span) { ObjectPtr n = make_object(); n->f_visit_expr = f_visit_expr; n->f_visit_binding = f_visit_binding; @@ -266,64 +269,64 @@ class PyExprMutatorNode : public Object, public ExprMutator { public: /*! \brief The packed function to the `VisitExpr(const Expr& expr)` function. */ - PackedFunc f_visit_expr{nullptr}; + ffi::Function f_visit_expr{nullptr}; /*! \brief The packed function to the `VisitExpr_(const ConstantNode* op)` function. */ - PackedFunc f_visit_constant_{nullptr}; + ffi::Function f_visit_constant_{nullptr}; /*! \brief The packed function to the `VisitExpr_(const TupleNode* op)` function. */ - PackedFunc f_visit_tuple_{nullptr}; + ffi::Function f_visit_tuple_{nullptr}; /*! \brief The packed function to the `VisitExpr_(const VarNode* op)` function. */ - PackedFunc f_visit_var_{nullptr}; + ffi::Function f_visit_var_{nullptr}; /*! \brief The packed function to the `VisitExpr_(const DataflowVarNode* op)` function. */ - PackedFunc f_visit_dataflow_var_{nullptr}; + ffi::Function f_visit_dataflow_var_{nullptr}; /*! \brief The packed function to the `VisitExpr_(const ShapeExprNode* op)` function. */ - PackedFunc f_visit_shape_expr_{nullptr}; + ffi::Function f_visit_shape_expr_{nullptr}; /*! \brief The packed function to the `VisitExpr_(const ExternFuncNode* op)` function. */ - PackedFunc f_visit_extern_func_{nullptr}; + ffi::Function f_visit_extern_func_{nullptr}; /*! \brief The packed function to the `VisitExpr_(const GlobalVarNode* op)` function. */ - PackedFunc f_visit_global_var_{nullptr}; + ffi::Function f_visit_global_var_{nullptr}; /*! \brief The packed function to the `VisitExpr_(const FunctionNode* op)` function. */ - PackedFunc f_visit_function_{nullptr}; + ffi::Function f_visit_function_{nullptr}; /*! \brief The packed function to the `VisitExpr_(const CallNode* op)` function. */ - PackedFunc f_visit_call_{nullptr}; + ffi::Function f_visit_call_{nullptr}; /*! \brief The packed function to the `VisitExpr_(const SeqExprNode* op)` function. */ - PackedFunc f_visit_seq_expr_{nullptr}; + ffi::Function f_visit_seq_expr_{nullptr}; /*! \brief The packed function to the `VisitExpr_(const IfNode* op)` function. */ - PackedFunc f_visit_if_{nullptr}; + ffi::Function f_visit_if_{nullptr}; /*! \brief The packed function to the `VisitExpr_(const OpNode* op)` function. */ - PackedFunc f_visit_op_{nullptr}; + ffi::Function f_visit_op_{nullptr}; /*! \brief The packed function to the `VisitExpr_(const TupleGetItemNode* op)` function. */ - PackedFunc f_visit_tuple_getitem_{nullptr}; + ffi::Function f_visit_tuple_getitem_{nullptr}; /*! \brief The packed function to the `VisitExpr_(const PrimValueNode* op)` function. */ - PackedFunc f_visit_prim_value_{nullptr}; + ffi::Function f_visit_prim_value_{nullptr}; /*! \brief The packed function to the `VisitExpr_(const StringImmNode* op)` function. */ - PackedFunc f_visit_string_imm_{nullptr}; + ffi::Function f_visit_string_imm_{nullptr}; /*! \brief The packed function to the `VisitExpr_(const DataTypeImmNode* op)` function. */ - PackedFunc f_visit_data_type_imm_{nullptr}; + ffi::Function f_visit_data_type_imm_{nullptr}; /*! \brief The packed function to the `VisitBinding(const Binding& binding)` function. */ - PackedFunc f_visit_binding{nullptr}; + ffi::Function f_visit_binding{nullptr}; /*! \brief The packed function to the `VisitBinding_(const VarBindingNode* binding)` * function. */ - PackedFunc f_visit_var_binding_{nullptr}; + ffi::Function f_visit_var_binding_{nullptr}; /*! \brief The packed function to the `VisitBinding_(const MatchCastNode* binding)` * function. */ - PackedFunc f_visit_match_cast_{nullptr}; + ffi::Function f_visit_match_cast_{nullptr}; /*! \brief The packed function to the `VisitBindingBlock(const BindingBlock& block)` * function. */ - PackedFunc f_visit_binding_block{nullptr}; + ffi::Function f_visit_binding_block{nullptr}; /*! \brief The packed function to the `VisitBindingBlock_(const BindingBlockNode* block)` * function. */ - PackedFunc f_visit_binding_block_{nullptr}; + ffi::Function f_visit_binding_block_{nullptr}; /*! \brief The packed function to the `VisitBindingBlock_(const DataflowBlockNode* block)` * function. */ - PackedFunc f_visit_dataflow_block_{nullptr}; + ffi::Function f_visit_dataflow_block_{nullptr}; /*! \brief The packed function to the `VisitVarDef(const Var& var)` function. */ - PackedFunc f_visit_var_def{nullptr}; + ffi::Function f_visit_var_def{nullptr}; /*! \brief The packed function to the `VisitVarDef_(const VarNode* var)` function. */ - PackedFunc f_visit_var_def_{nullptr}; + ffi::Function f_visit_var_def_{nullptr}; /*! \brief The packed function to the `VisitVarDef_(const DataflowVarNode* var)` function. */ - PackedFunc f_visit_dataflow_var_def_{nullptr}; + ffi::Function f_visit_dataflow_var_def_{nullptr}; /*! \brief The packed function to the `VisitSpan(const Span& span)` function. */ - PackedFunc f_visit_span{nullptr}; + ffi::Function f_visit_span{nullptr}; Expr VisitExpr(const Expr& expr) { if (f_visit_expr != nullptr) { @@ -490,17 +493,19 @@ class PyExprMutator : public ObjectRef { * \return The PyExprMutator created. */ TVM_DLL static PyExprMutator MakePyExprMutator( - BlockBuilder builder_, PackedFunc f_visit_expr, PackedFunc f_visit_constant_, - PackedFunc f_visit_tuple_, PackedFunc f_visit_var_, PackedFunc f_visit_dataflow_var_, - PackedFunc f_visit_shape_expr_, PackedFunc f_visit_extern_func_, - PackedFunc f_visit_global_var_, PackedFunc f_visit_function_, PackedFunc f_visit_call_, - PackedFunc f_visit_seq_expr_, PackedFunc f_visit_if_, PackedFunc f_visit_op_, - PackedFunc f_visit_tuple_getitem_, PackedFunc f_visit_prim_value_, - PackedFunc f_visit_string_imm_, PackedFunc f_visit_data_type_imm_, PackedFunc f_visit_binding, - PackedFunc f_visit_var_binding_, PackedFunc f_visit_match_cast_, - PackedFunc f_visit_binding_block, PackedFunc f_visit_binding_block_, - PackedFunc f_visit_dataflow_block_, PackedFunc f_visit_var_def, PackedFunc f_visit_var_def_, - PackedFunc f_visit_dataflow_var_def_, PackedFunc f_visit_span) { + BlockBuilder builder_, ffi::Function f_visit_expr, ffi::Function f_visit_constant_, + ffi::Function f_visit_tuple_, ffi::Function f_visit_var_, ffi::Function f_visit_dataflow_var_, + ffi::Function f_visit_shape_expr_, ffi::Function f_visit_extern_func_, + ffi::Function f_visit_global_var_, ffi::Function f_visit_function_, + ffi::Function f_visit_call_, ffi::Function f_visit_seq_expr_, ffi::Function f_visit_if_, + ffi::Function f_visit_op_, ffi::Function f_visit_tuple_getitem_, + ffi::Function f_visit_prim_value_, ffi::Function f_visit_string_imm_, + ffi::Function f_visit_data_type_imm_, ffi::Function f_visit_binding, + ffi::Function f_visit_var_binding_, ffi::Function f_visit_match_cast_, + ffi::Function f_visit_binding_block, ffi::Function f_visit_binding_block_, + ffi::Function f_visit_dataflow_block_, ffi::Function f_visit_var_def, + ffi::Function f_visit_var_def_, ffi::Function f_visit_dataflow_var_def_, + ffi::Function f_visit_span) { ObjectPtr n = make_object(); n->builder_ = builder_; n->f_visit_expr = f_visit_expr; diff --git a/src/relax/ir/transform.cc b/src/relax/ir/transform.cc index bf924240721d..d79d8b3fd50d 100644 --- a/src/relax/ir/transform.cc +++ b/src/relax/ir/transform.cc @@ -165,8 +165,7 @@ TVM_REGISTER_NODE_TYPE(FunctionPassNode); TVM_REGISTER_GLOBAL("relax.transform.MakeFunctionPass") .set_body_typed( - [](runtime::TypedPackedFunc, IRModule, PassContext)> - pass_func, + [](ffi::TypedFunction, IRModule, PassContext)> pass_func, PassInfo pass_info) { auto wrapped_pass_func = [pass_func](Function func, IRModule mod, PassContext ctx) { return pass_func(ffi::RValueRef(std::move(func)), mod, ctx); @@ -385,15 +384,15 @@ Pass CreateDataflowBlockPass( TVM_REGISTER_NODE_TYPE(DataflowBlockPassNode); TVM_REGISTER_GLOBAL("relax.transform.MakeDataflowBlockPass") - .set_body_typed([](runtime::TypedPackedFunc, - IRModule, PassContext)> - pass_func, - PassInfo pass_info) { - auto wrapped_pass_func = [pass_func](DataflowBlock func, IRModule mod, PassContext ctx) { - return pass_func(ffi::RValueRef(std::move(func)), mod, ctx); - }; - return DataflowBlockPass(wrapped_pass_func, pass_info); - }); + .set_body_typed( + [](ffi::TypedFunction, IRModule, PassContext)> + pass_func, + PassInfo pass_info) { + auto wrapped_pass_func = [pass_func](DataflowBlock func, IRModule mod, PassContext ctx) { + return pass_func(ffi::RValueRef(std::move(func)), mod, ctx); + }; + return DataflowBlockPass(wrapped_pass_func, pass_info); + }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { diff --git a/src/relax/transform/adjust_matmul_order.cc b/src/relax/transform/adjust_matmul_order.cc index b1ac176a28e4..79749cb41693 100644 --- a/src/relax/transform/adjust_matmul_order.cc +++ b/src/relax/transform/adjust_matmul_order.cc @@ -39,19 +39,19 @@ namespace tvm { namespace relax { namespace { -std::tuple)>> CreatePatterns( +std::tuple)>> CreatePatterns( const Function& func) { auto compile_time_arr = ComputableAtCompileTime(func); std::unordered_set compile_time_lookup(compile_time_arr.begin(), compile_time_arr.end()); - TypedPackedFunc is_compile_time = [compile_time_lookup](Expr arg) -> bool { + ffi::TypedFunction is_compile_time = [compile_time_lookup](Expr arg) -> bool { if (auto as_var = arg.as()) { return compile_time_lookup.count(as_var.value()); } else { return false; } }; - TypedPackedFunc is_runtime = [is_compile_time](Expr arg) -> bool { + ffi::TypedFunction is_runtime = [is_compile_time](Expr arg) -> bool { return !is_compile_time(arg); }; diff --git a/src/relax/transform/combine_parallel_matmul.cc b/src/relax/transform/combine_parallel_matmul.cc index 6c0329a33d06..cc55eaff0721 100644 --- a/src/relax/transform/combine_parallel_matmul.cc +++ b/src/relax/transform/combine_parallel_matmul.cc @@ -40,7 +40,7 @@ namespace relax { using runtime::Map; -using FCheck = runtime::TypedPackedFunc, Array, Map)>; +using FCheck = ffi::TypedFunction, Array, Map)>; /*! \brief Group shapes of the RHS matrices by rank. Matrices in a group whose batch sizes are compatible are combined. @@ -117,7 +117,7 @@ Patterns CreatePatterns(const BranchInfo& branch_info) { } /*! \brief Create a rewriter for the given parallel matmul branches. */ -runtime::TypedPackedFunc(Map, Map)> GetRewriter( +ffi::TypedFunction(Map, Map)> GetRewriter( const Patterns& patterns, const BranchInfo& branch_info, FCheck check) { auto batch_dims_compatible = [](size_t rhs_dim, const std::vector& indices, const std::vector>& rhs_shapes) { diff --git a/src/relax/transform/convert_layout.cc b/src/relax/transform/convert_layout.cc index 0898af1d7636..b41ad9ea29c4 100644 --- a/src/relax/transform/convert_layout.cc +++ b/src/relax/transform/convert_layout.cc @@ -343,7 +343,7 @@ DataflowBlock ConvertLayoutPass(const DataflowBlock& df_block, namespace transform { Pass ConvertLayout(Map> desired_layouts) { - runtime::TypedPackedFunc pass_func = + ffi::TypedFunction pass_func = [=](DataflowBlock df_block, IRModule m, PassContext pc) { return Downcast(ConvertLayoutPass(df_block, desired_layouts)); }; diff --git a/src/relax/transform/expand_matmul_of_sum.cc b/src/relax/transform/expand_matmul_of_sum.cc index e20f9c59b28b..134eca557264 100644 --- a/src/relax/transform/expand_matmul_of_sum.cc +++ b/src/relax/transform/expand_matmul_of_sum.cc @@ -40,7 +40,7 @@ namespace tvm { namespace relax { namespace { -std::tuple)>> CreatePatterns( +std::tuple)>> CreatePatterns( const Function& func) { auto compile_time_arr = ComputableAtCompileTime(func); std::unordered_set compile_time_lookup(compile_time_arr.begin(), compile_time_arr.end()); diff --git a/src/relax/transform/fold_constant.cc b/src/relax/transform/fold_constant.cc index af38645139d6..2cce9c8d7c26 100644 --- a/src/relax/transform/fold_constant.cc +++ b/src/relax/transform/fold_constant.cc @@ -99,7 +99,7 @@ class ConstantFolder : public ExprMutator { * \brief Get a cached build version of func * \return The cached func, nullopt if func cannot be built. */ - Optional GetCachedBuild(tir::PrimFunc func) { + Optional GetCachedBuild(tir::PrimFunc func) { // TODO(tvm-team): consider another way of bulk extract and build PrimFunc once // would be helpful for future cases where PrimFunc recursively call into each other Target eval_cpu_target{"llvm"}; @@ -108,7 +108,7 @@ class ConstantFolder : public ExprMutator { if (it != func_build_cache_.end()) { return it->second; } - Optional build_func = NullOpt; + Optional build_func = NullOpt; try { // Not all the primfunc can be directly built via llvm, for example, if a function is @@ -145,7 +145,7 @@ class ConstantFolder : public ExprMutator { Optional ConstEvaluateCallTIR(tir::PrimFunc tir_func, Array arr_args, runtime::ShapeTuple shape, DataType ret_type) { // obtain function from the cache. - Optional func = GetCachedBuild(tir_func); + Optional func = GetCachedBuild(tir_func); if (!func) return NullOpt; // here the vector size has an additional + 1 because we need to put ret_tensor at the end @@ -165,7 +165,7 @@ class ConstantFolder : public ExprMutator { // set return value packed_args[arg_offset++] = ret_tensor; - TVMRetValue ret; + ffi::Any ret; // invoke func.value().CallPacked(ffi::PackedArgs(packed_args.data(), packed_args.size()), &ret); return Constant(ret_tensor); @@ -196,7 +196,7 @@ class ConstantFolder : public ExprMutator { using ExprMutator::VisitExpr_; // TODO(@sunggg): - // Next PR will support fold with PackedFunc and MatchCast + // Next PR will support fold with ffi::Function and MatchCast // Until then, DecomposeOps() should be applied after // this pass to fold `tensor_to_shape` op. Expr VisitExpr_(const CallNode* call) final { @@ -280,8 +280,8 @@ class ConstantFolder : public ExprMutator { return ShapeExpr(shape_values); } } else if (op->name == "relax.shape_to_tensor") { - // Special handling for "relax.shape_to_tensor" since it is implemented in PackedFunc. - // TODO(sunggg): revisit this when we extend ConstantFolding to fold PackedFunc. + // Special handling for "relax.shape_to_tensor" since it is implemented in ffi::Function. + // TODO(sunggg): revisit this when we extend ConstantFolding to fold ffi::Function. Expr arg = post_call->args[0]; ShapeExpr shape = Downcast(arg); Array values = shape->values; @@ -313,7 +313,7 @@ class ConstantFolder : public ExprMutator { } // cache for function build, via structural equality - std::unordered_map, StructuralHash, StructuralEqual> + std::unordered_map, StructuralHash, StructuralEqual> func_build_cache_; }; diff --git a/src/relax/transform/fuse_ops.cc b/src/relax/transform/fuse_ops.cc index a1b651c3eb2d..4819cefb9ac3 100644 --- a/src/relax/transform/fuse_ops.cc +++ b/src/relax/transform/fuse_ops.cc @@ -1055,8 +1055,8 @@ class PatternBasedPartitioner : ExprVisitor { using GroupMap = OperatorFusor::GroupMap; using PatternCheckContext = transform::PatternCheckContext; using ExprVisitor::VisitExpr_; - using FCheckMatch = runtime::TypedPackedFunc; - using FAttrsGetter = runtime::TypedPackedFunc(const Map&)>; + using FCheckMatch = ffi::TypedFunction; + using FAttrsGetter = ffi::TypedFunction(const Map&)>; static GroupMap Run(String pattern_name, DFPattern pattern, Map annotation_patterns, FCheckMatch check, Expr expr, @@ -1383,8 +1383,8 @@ IRModule FuseOpsByPattern(const tvm::Array& patterns, namespace transform { FusionPattern::FusionPattern(String name, DFPattern pattern, - Map annotation_patterns, Optional check, - Optional attrs_getter) { + Map annotation_patterns, + Optional check, Optional attrs_getter) { ObjectPtr n = make_object(); n->name = std::move(name); n->pattern = std::move(pattern); @@ -1397,7 +1397,7 @@ FusionPattern::FusionPattern(String name, DFPattern pattern, TVM_REGISTER_NODE_TYPE(FusionPatternNode); TVM_REGISTER_GLOBAL("relax.transform.FusionPattern") .set_body_typed([](String name, DFPattern pattern, Map annotation_patterns, - Optional check, Optional attrs_getter) { + Optional check, Optional attrs_getter) { return FusionPattern(name, pattern, annotation_patterns, check, attrs_getter); }); diff --git a/src/relax/transform/infer_amp_utils.h b/src/relax/transform/infer_amp_utils.h index 52483998081d..3440afb7211d 100644 --- a/src/relax/transform/infer_amp_utils.h +++ b/src/relax/transform/infer_amp_utils.h @@ -73,7 +73,7 @@ using VarDTypeMap = std::unordered_map; // Call is a call node, out_dtype is the expected output_dtype using FInferMixedPrecision = - runtime::TypedPackedFunc; + ffi::TypedFunction; Array InferMixedPrecisionFollow(const Call& call, const DataType& out_dtype); diff --git a/src/relax/transform/infer_layout_utils.h b/src/relax/transform/infer_layout_utils.h index 951fe92cb8ac..d8666cc431da 100644 --- a/src/relax/transform/infer_layout_utils.h +++ b/src/relax/transform/infer_layout_utils.h @@ -149,7 +149,7 @@ using VarLayoutMap = Map; * \param desired_layouts The desired layouts of the operator. * \param var_layout_map The layout of the variables. */ -using FRelaxInferLayout = runtime::TypedPackedFunc>& desired_layouts, const VarLayoutMap& var_layout_map)>; diff --git a/src/relax/transform/legalize_ops.cc b/src/relax/transform/legalize_ops.cc index e5653c5da93d..b66132154aca 100644 --- a/src/relax/transform/legalize_ops.cc +++ b/src/relax/transform/legalize_ops.cc @@ -59,7 +59,7 @@ bool KnowAllShapeValues(const StructInfo& sinfo) { class LegalizeMutator : public ExprMutator { public: - explicit LegalizeMutator(const IRModule& mod, const Optional>& cmap, + explicit LegalizeMutator(const IRModule& mod, const Optional>& cmap, bool enable_warning) : ExprMutator(mod), mod_(std::move(mod)), enable_warning_(enable_warning) { if (cmap) { @@ -377,7 +377,7 @@ class LegalizeMutator : public ExprMutator { /*! \brief The context IRModule. */ IRModule mod_; /*! \brief The customized legalization function map. */ - Map cmap_; + Map cmap_; /*! \brief If VDevice annotations produced at least one PrimFunc with a Target attr*/ bool generated_tir_with_target_attr_{false}; /*! @@ -389,7 +389,7 @@ class LegalizeMutator : public ExprMutator { namespace transform { -Pass LegalizeOps(Optional> cmap, bool enable_warning) { +Pass LegalizeOps(Optional> cmap, bool enable_warning) { auto pass_func = [=](IRModule mod, PassContext pc) { bool apply_legalize_ops = pc->GetConfig("relax.transform.apply_legalize_ops").value_or(Bool(true))->value; diff --git a/src/relax/transform/meta_schedule.cc b/src/relax/transform/meta_schedule.cc index c0123a43abc0..875d26ea47a1 100644 --- a/src/relax/transform/meta_schedule.cc +++ b/src/relax/transform/meta_schedule.cc @@ -189,7 +189,7 @@ Pass MetaScheduleTuneIRMod(Map params, String work_dir Pass MetaScheduleTuneTIR(String work_dir, Integer max_trials_global) { Target target = Target::Current(false); - runtime::TypedPackedFunc pass_func = + ffi::TypedFunction pass_func = [=](tir::PrimFunc f, IRModule mod, PassContext ctx) { return MetaScheduleTuner(target, work_dir, max_trials_global, max_trials_global, NullOpt) .TuneTIR(f, ctx); diff --git a/src/relax/transform/reorder_permute_dims_after_concat.cc b/src/relax/transform/reorder_permute_dims_after_concat.cc index 5712ec6d2b81..a2023a068aa2 100644 --- a/src/relax/transform/reorder_permute_dims_after_concat.cc +++ b/src/relax/transform/reorder_permute_dims_after_concat.cc @@ -40,7 +40,7 @@ namespace tvm { namespace relax { namespace { -std::tuple)>> CreatePatterns() { +std::tuple)>> CreatePatterns() { // TODO(Lunderberg): Allow pattern-matching to handle a flexible // number of arguments, each of which matches the same type of // pattern. diff --git a/src/relax/transform/reorder_take_after_matmul.cc b/src/relax/transform/reorder_take_after_matmul.cc index 4b043787d6c7..28480a2296f3 100644 --- a/src/relax/transform/reorder_take_after_matmul.cc +++ b/src/relax/transform/reorder_take_after_matmul.cc @@ -40,7 +40,7 @@ namespace tvm { namespace relax { namespace { -std::tuple)>> CreatePatterns() { +std::tuple)>> CreatePatterns() { auto pat_lhs = WildcardPattern(); auto pat_weights = WildcardPattern(); diff --git a/src/relax/transform/update_param_struct_info.cc b/src/relax/transform/update_param_struct_info.cc index eefcf3ba1b64..062ac97a35f7 100644 --- a/src/relax/transform/update_param_struct_info.cc +++ b/src/relax/transform/update_param_struct_info.cc @@ -39,7 +39,7 @@ namespace relax { namespace { class ParamStructInfoMutator : public ExprMutator { public: - explicit ParamStructInfoMutator(TypedPackedFunc(Var)> sinfo_func) + explicit ParamStructInfoMutator(ffi::TypedFunction(Var)> sinfo_func) : sinfo_func_(sinfo_func) {} using ExprMutator::VisitExpr_; @@ -64,12 +64,12 @@ class ParamStructInfoMutator : public ExprMutator { return ExprMutator::VisitExpr_(func.get()); } - TypedPackedFunc(Var)> sinfo_func_; + ffi::TypedFunction(Var)> sinfo_func_; }; } // namespace namespace transform { -Pass UpdateParamStructInfo(TypedPackedFunc(Var)> sinfo_func) { +Pass UpdateParamStructInfo(ffi::TypedFunction(Var)> sinfo_func) { auto pass_func = [=](IRModule mod, PassContext pc) { ParamStructInfoMutator mutator(sinfo_func); diff --git a/src/runtime/c_runtime_api.cc b/src/runtime/c_runtime_api.cc index 0482f4ab7009..b76687522d69 100644 --- a/src/runtime/c_runtime_api.cc +++ b/src/runtime/c_runtime_api.cc @@ -22,6 +22,7 @@ * \brief Device specific implementations */ #include +#include #include #include #include @@ -508,7 +509,8 @@ int TVMModImport(TVMModuleHandle mod, TVMModuleHandle dep) { int TVMModGetFunction(TVMModuleHandle mod, const char* func_name, int query_imports, TVMFunctionHandle* func) { API_BEGIN(); - PackedFunc pf = ObjectInternal::GetModuleNode(mod)->GetFunction(func_name, query_imports != 0); + tvm::ffi::Function pf = + ObjectInternal::GetModuleNode(mod)->GetFunction(func_name, query_imports != 0); if (pf != nullptr) { tvm::ffi::Any ret = pf; TVMFFIAny val = tvm::ffi::details::AnyUnsafe::MoveAnyToTVMFFIAny(std::move(ret)); @@ -606,7 +608,7 @@ int TVMFuncCall(TVMFunctionHandle func, TVMValue* args, int* arg_type_codes, int int TVMCFuncSetReturn(TVMRetValueHandle ret, TVMValue* value, int* type_code, int num_ret) { API_BEGIN(); ICHECK_EQ(num_ret, 1); - TVMRetValue* rv = static_cast(ret); + tvm::ffi::Any* rv = static_cast(ret); *rv = LegacyTVMArgValueToAnyView(value[0], type_code[0]); API_END(); } @@ -615,17 +617,18 @@ int TVMFuncCreateFromCFunc(TVMPackedCFunc func, void* resource_handle, TVMPacked TVMFunctionHandle* out) { API_BEGIN(); if (fin == nullptr) { - tvm::runtime::TVMRetValue ret; - ret = tvm::ffi::Function::FromPacked([func, resource_handle](TVMArgs args, TVMRetValue* rv) { - // run ABI translation - std::vector values(args.size()); - std::vector type_codes(args.size()); - PackedArgsToLegacyTVMArgs(args.data(), args.size(), values.data(), type_codes.data()); - int ret = func(values.data(), type_codes.data(), args.size(), rv, resource_handle); - if (ret != 0) { - TVMThrowLastError(); - } - }); + tvm::ffi::Any ret; + ret = tvm::ffi::Function::FromPacked( + [func, resource_handle](tvm::ffi::PackedArgs args, tvm::ffi::Any* rv) { + // run ABI translation + std::vector values(args.size()); + std::vector type_codes(args.size()); + PackedArgsToLegacyTVMArgs(args.data(), args.size(), values.data(), type_codes.data()); + int ret = func(values.data(), type_codes.data(), args.size(), rv, resource_handle); + if (ret != 0) { + TVMThrowLastError(); + } + }); TVMValue val; int type_code; MoveAnyToLegacyTVMValue(std::move(ret), &val, &type_code); @@ -634,18 +637,19 @@ int TVMFuncCreateFromCFunc(TVMPackedCFunc func, void* resource_handle, TVMPacked // wrap it in a shared_ptr, with fin as deleter. // so fin will be called when the lambda went out of scope. std::shared_ptr rpack(resource_handle, fin); - tvm::runtime::TVMRetValue ret; - ret = PackedFunc([func, rpack](TVMArgs args, TVMRetValue* rv) { - // run ABI translation - std::vector values(args.size()); - std::vector type_codes(args.size()); - PackedArgsToLegacyTVMArgs(args.data(), args.size(), values.data(), type_codes.data()); - int ret = func(values.data(), type_codes.data(), args.size(), rv, rpack.get()); - - if (ret != 0) { - TVMThrowLastError(); - } - }); + tvm::ffi::Any ret; + ret = + tvm::ffi::Function::FromPacked([func, rpack](tvm::ffi::PackedArgs args, tvm::ffi::Any* rv) { + // run ABI translation + std::vector values(args.size()); + std::vector type_codes(args.size()); + PackedArgsToLegacyTVMArgs(args.data(), args.size(), values.data(), type_codes.data()); + int ret = func(values.data(), type_codes.data(), args.size(), rv, rpack.get()); + + if (ret != 0) { + TVMThrowLastError(); + } + }); TVMValue val; val.v_handle = nullptr; int type_code; @@ -771,7 +775,7 @@ int TVMDeviceCopyDataFromTo(DLTensor* from, DLTensor* to, TVMStreamHandle stream // set device api TVM_REGISTER_GLOBAL(tvm::runtime::symbol::tvm_set_device) - .set_body_packed([](TVMArgs args, TVMRetValue* ret) { + .set_body_packed([](tvm::ffi::PackedArgs args, tvm::ffi::Any* ret) { DLDevice dev; dev.device_type = static_cast(args[0].cast()); dev.device_id = args[1].cast(); @@ -779,22 +783,23 @@ TVM_REGISTER_GLOBAL(tvm::runtime::symbol::tvm_set_device) }); // set device api -TVM_REGISTER_GLOBAL("runtime.GetDeviceAttr").set_body_packed([](TVMArgs args, TVMRetValue* ret) { - DLDevice dev; - dev.device_type = static_cast(args[0].cast()); - dev.device_id = args[1].cast(); - - DeviceAttrKind kind = static_cast(args[2].cast()); - if (kind == kExist) { - DeviceAPI* api = DeviceAPIManager::Get(dev.device_type, true); - if (api != nullptr) { - api->GetAttr(dev, kind, ret); - } else { - *ret = 0; - } - } else { - DeviceAPIManager::Get(dev)->GetAttr(dev, kind, ret); - } -}); +TVM_REGISTER_GLOBAL("runtime.GetDeviceAttr") + .set_body_packed([](tvm::ffi::PackedArgs args, tvm::ffi::Any* ret) { + DLDevice dev; + dev.device_type = static_cast(args[0].cast()); + dev.device_id = args[1].cast(); + + DeviceAttrKind kind = static_cast(args[2].cast()); + if (kind == kExist) { + DeviceAPI* api = DeviceAPIManager::Get(dev.device_type, true); + if (api != nullptr) { + api->GetAttr(dev, kind, ret); + } else { + *ret = 0; + } + } else { + DeviceAPIManager::Get(dev)->GetAttr(dev, kind, ret); + } + }); TVM_REGISTER_GLOBAL("runtime.TVMSetStream").set_body_typed(TVMSetStream); diff --git a/src/runtime/const_loader_module.cc b/src/runtime/const_loader_module.cc index e202129f6ae7..2536847726c8 100644 --- a/src/runtime/const_loader_module.cc +++ b/src/runtime/const_loader_module.cc @@ -65,7 +65,7 @@ class ConstLoaderModuleNode : public ModuleNode { } } - PackedFunc GetFunction(const String& name, const ObjectPtr& sptr_to_self) final { + ffi::Function GetFunction(const String& name, const ObjectPtr& sptr_to_self) final { VLOG(1) << "ConstLoaderModuleNode::GetFunction(" << name << ")"; // Initialize and memoize the module. // Usually, we have some warmup runs. The module initialization should be @@ -76,7 +76,7 @@ class ConstLoaderModuleNode : public ModuleNode { } if (name == "get_const_var_ndarray") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + return ffi::Function([sptr_to_self, this](ffi::PackedArgs args, ffi::Any* rv) { Map ret_map; for (const auto& kv : const_var_ndarray_) { ret_map.Set(kv.first, kv.second); @@ -90,10 +90,10 @@ class ConstLoaderModuleNode : public ModuleNode { // symobl lookup overhead should be minimal. ICHECK(!this->imports().empty()); for (Module it : this->imports()) { - PackedFunc pf = it.GetFunction(name); + ffi::Function pf = it.GetFunction(name); if (pf != nullptr) return pf; } - return PackedFunc(nullptr); + return ffi::Function(nullptr); } const char* type_key() const final { return "const_loader"; } @@ -133,7 +133,7 @@ class ConstLoaderModuleNode : public ModuleNode { * found module accordingly by passing the needed constants into it. */ void InitSubModule(const std::string& symbol) { - PackedFunc init(nullptr); + ffi::Function init(nullptr); for (Module it : this->imports()) { // Get the initialization function from the imported modules. std::string init_name = "__init_" + symbol; diff --git a/src/runtime/contrib/amx/amx_config.cc b/src/runtime/contrib/amx/amx_config.cc index 80da3e71eb78..72225f39954f 100644 --- a/src/runtime/contrib/amx/amx_config.cc +++ b/src/runtime/contrib/amx/amx_config.cc @@ -76,20 +76,21 @@ void init_tile_config(__tilecfg_u* dst, uint16_t cols, uint8_t rows) { _tile_loadconfig(dst->a); } -TVM_REGISTER_GLOBAL("runtime.amx_tileconfig").set_body_packed([](TVMArgs args, TVMRetValue* rv) { - int rows = args[0].cast(); - int cols = args[1].cast(); - LOG(INFO) << "rows: " << rows << ", cols:" << cols; - // -----------Config for AMX tile resgister---------------------- - __tilecfg_u cfg; - init_tile_config(&cfg, cols, rows); - - *rv = 1; - return; -}); +TVM_REGISTER_GLOBAL("runtime.amx_tileconfig") + .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { + int rows = args[0].cast(); + int cols = args[1].cast(); + LOG(INFO) << "rows: " << rows << ", cols:" << cols; + // -----------Config for AMX tile resgister---------------------- + __tilecfg_u cfg; + init_tile_config(&cfg, cols, rows); + + *rv = 1; + return; + }); // register a global packed function in c++,to init the system for AMX config -TVM_REGISTER_GLOBAL("runtime.amx_init").set_body_packed([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("runtime.amx_init").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { // -----------Detect and request for AMX control---------------------- uint64_t bitmask = 0; int64_t status = syscall(SYS_arch_prctl, ARCH_GET_XCOMP_PERM, &bitmask); diff --git a/src/runtime/contrib/cblas/cblas.cc b/src/runtime/contrib/cblas/cblas.cc index edb4d7e0afae..155e1f05f197 100644 --- a/src/runtime/contrib/cblas/cblas.cc +++ b/src/runtime/contrib/cblas/cblas.cc @@ -123,18 +123,19 @@ struct CblasDgemmBatchIterativeOp { }; // matrix multiplication for row major -TVM_REGISTER_GLOBAL("tvm.contrib.cblas.matmul").set_body_packed([](TVMArgs args, TVMRetValue* ret) { - auto A = args[0].cast(); - ICHECK(TypeMatch(A->dtype, kDLFloat, 32) || TypeMatch(A->dtype, kDLFloat, 64)); +TVM_REGISTER_GLOBAL("tvm.contrib.cblas.matmul") + .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { + auto A = args[0].cast(); + ICHECK(TypeMatch(A->dtype, kDLFloat, 32) || TypeMatch(A->dtype, kDLFloat, 64)); - if (TypeMatch(A->dtype, kDLFloat, 32)) - CallGemm(args, ret, CblasSgemmOp()); - else - CallGemm(args, ret, CblasDgemmOp()); -}); + if (TypeMatch(A->dtype, kDLFloat, 32)) + CallGemm(args, ret, CblasSgemmOp()); + else + CallGemm(args, ret, CblasDgemmOp()); + }); TVM_REGISTER_GLOBAL("tvm.contrib.cblas.batch_matmul") - .set_body_packed([](TVMArgs args, TVMRetValue* ret) { + .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { auto A = args[0].cast(); ICHECK(TypeMatch(A->dtype, kDLFloat, 32) || TypeMatch(A->dtype, kDLFloat, 64)); if (TypeMatch(A->dtype, kDLFloat, 32)) { @@ -145,7 +146,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cblas.batch_matmul") }); TVM_REGISTER_GLOBAL("tvm.contrib.cblas.batch_matmul_iterative") - .set_body_packed([](TVMArgs args, TVMRetValue* ret) { + .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { auto A = args[0].cast(); ICHECK(TypeMatch(A->dtype, kDLFloat, 32) || TypeMatch(A->dtype, kDLFloat, 64)); if (TypeMatch(A->dtype, kDLFloat, 32)) { diff --git a/src/runtime/contrib/cblas/dnnl_blas.cc b/src/runtime/contrib/cblas/dnnl_blas.cc index a742d9493336..18840eb55db1 100644 --- a/src/runtime/contrib/cblas/dnnl_blas.cc +++ b/src/runtime/contrib/cblas/dnnl_blas.cc @@ -46,10 +46,11 @@ struct DNNLSgemmOp { }; // matrix multiplication for row major -TVM_REGISTER_GLOBAL("tvm.contrib.dnnl.matmul").set_body_packed([](TVMArgs args, TVMRetValue* ret) { - auto A = args[0].cast(); - ICHECK(TypeMatch(A->dtype, kDLFloat, 32)); - CallGemm(args, ret, DNNLSgemmOp()); -}); +TVM_REGISTER_GLOBAL("tvm.contrib.dnnl.matmul") + .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { + auto A = args[0].cast(); + ICHECK(TypeMatch(A->dtype, kDLFloat, 32)); + CallGemm(args, ret, DNNLSgemmOp()); + }); } // namespace contrib } // namespace tvm diff --git a/src/runtime/contrib/cblas/gemm_common.h b/src/runtime/contrib/cblas/gemm_common.h index d211d7c7ff69..14b74d4736fc 100644 --- a/src/runtime/contrib/cblas/gemm_common.h +++ b/src/runtime/contrib/cblas/gemm_common.h @@ -34,8 +34,8 @@ namespace tvm { namespace contrib { -using runtime::TVMArgs; -using runtime::TVMRetValue; +using ffi::Any; +using ffi::PackedArgs; using runtime::TypeMatch; inline int ColumnStride(const DLTensor* tensor) { @@ -73,7 +73,7 @@ inline int ColumnCount(const DLTensor* tensor, bool trans, int batch_offset = 0) // Call a column major blas. Note that data is stored in tvm as row // major, so this we switch the arguments. template -inline void CallGemm(TVMArgs args, TVMRetValue* ret, TGemmOp op) { +inline void CallGemm(ffi::PackedArgs args, ffi::Any* ret, TGemmOp op) { auto A = args[0].cast(); auto B = args[1].cast(); auto C = args[2].cast(); @@ -112,7 +112,7 @@ inline void CallGemm(TVMArgs args, TVMRetValue* ret, TGemmOp op) { // Call a column major blas. Note that data is stored in tvm as row // major, so this we switch the arguments. template -inline void CallU8S8S32Gemm(TVMArgs args, TVMRetValue* ret, TGemmOp op) { +inline void CallU8S8S32Gemm(ffi::PackedArgs args, ffi::Any* ret, TGemmOp op) { auto A = args[0].cast(); auto B = args[1].cast(); auto C = args[2].cast(); @@ -181,7 +181,7 @@ inline int BatchCount3D(DLTensor* tensor) { return tensor->shape[0]; } inline int RowCount3D(DLTensor* tensor, bool trans) { return tensor->shape[trans ? 2 : 1]; } inline int ColumnCount3D(DLTensor* tensor, bool trans) { return tensor->shape[trans ? 1 : 2]; } template -inline void CallBatchGemm(TVMArgs args, TVMRetValue* ret, TBatchGemmOp op) { +inline void CallBatchGemm(ffi::PackedArgs args, ffi::Any* ret, TBatchGemmOp op) { using DType = typename TBatchGemmOp::TDatatype; auto A = args[0].cast(); auto B = args[1].cast(); diff --git a/src/runtime/contrib/cblas/mkl.cc b/src/runtime/contrib/cblas/mkl.cc index 56a45b3f819a..f98df0c6d624 100644 --- a/src/runtime/contrib/cblas/mkl.cc +++ b/src/runtime/contrib/cblas/mkl.cc @@ -154,19 +154,20 @@ struct MKLDgemmBatchIterativeOp { }; // matrix multiplication for row major -TVM_REGISTER_GLOBAL("tvm.contrib.mkl.matmul").set_body_packed([](TVMArgs args, TVMRetValue* ret) { - auto A = args[0].cast(); - ICHECK(TypeMatch(A->dtype, kDLFloat, 32) || TypeMatch(A->dtype, kDLFloat, 64)); +TVM_REGISTER_GLOBAL("tvm.contrib.mkl.matmul") + .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { + auto A = args[0].cast(); + ICHECK(TypeMatch(A->dtype, kDLFloat, 32) || TypeMatch(A->dtype, kDLFloat, 64)); - if (TypeMatch(A->dtype, kDLFloat, 32)) - CallGemm(args, ret, MKLSgemmOp()); - else - CallGemm(args, ret, MKLDgemmOp()); -}); + if (TypeMatch(A->dtype, kDLFloat, 32)) + CallGemm(args, ret, MKLSgemmOp()); + else + CallGemm(args, ret, MKLDgemmOp()); + }); // integer matrix multiplication for row major TVM_REGISTER_GLOBAL("tvm.contrib.mkl.matmul_u8s8s32") - .set_body_packed([](TVMArgs args, TVMRetValue* ret) { + .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { auto A = args[0].cast(); auto B = args[1].cast(); auto C = args[2].cast(); @@ -177,7 +178,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.mkl.matmul_u8s8s32") }); TVM_REGISTER_GLOBAL("tvm.contrib.mkl.batch_matmul") - .set_body_packed([](TVMArgs args, TVMRetValue* ret) { + .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { auto A = args[0].cast(); ICHECK(TypeMatch(A->dtype, kDLFloat, 32) || TypeMatch(A->dtype, kDLFloat, 64)); if (TypeMatch(A->dtype, kDLFloat, 32)) { @@ -188,7 +189,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.mkl.batch_matmul") }); TVM_REGISTER_GLOBAL("tvm.contrib.mkl.batch_matmul_iterative") - .set_body_packed([](TVMArgs args, TVMRetValue* ret) { + .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { auto A = args[0].cast(); ICHECK(TypeMatch(A->dtype, kDLFloat, 32) || TypeMatch(A->dtype, kDLFloat, 64)); if (TypeMatch(A->dtype, kDLFloat, 32)) { diff --git a/src/runtime/contrib/coreml/coreml_runtime.h b/src/runtime/contrib/coreml/coreml_runtime.h index a29230b0d857..b3f7e846e0ec 100644 --- a/src/runtime/contrib/coreml/coreml_runtime.h +++ b/src/runtime/contrib/coreml/coreml_runtime.h @@ -19,7 +19,7 @@ /*! * \brief CoreML runtime that can run coreml model - * containing only tvm PackedFunc. + * containing only tvm ffi::Function. * \file coreml_runtime.h */ #ifndef TVM_RUNTIME_CONTRIB_COREML_COREML_RUNTIME_H_ @@ -93,7 +93,7 @@ class CoreMLModel { * \brief CoreML runtime. * * This runtime can be accessed in various language via - * TVM runtime PackedFunc API. + * TVM runtime ffi::Function API. */ class CoreMLRuntime : public ModuleNode { public: @@ -103,7 +103,7 @@ class CoreMLRuntime : public ModuleNode { * \param sptr_to_self The pointer to the module node. * \return The corresponding member function. */ - virtual PackedFunc GetFunction(const String& name, const ObjectPtr& sptr_to_self); + virtual ffi::Function GetFunction(const String& name, const ObjectPtr& sptr_to_self); /*! \brief Get the property of the runtime module .*/ int GetPropertyMask() const final { diff --git a/src/runtime/contrib/coreml/coreml_runtime.mm b/src/runtime/contrib/coreml/coreml_runtime.mm index 272e871ead0b..7b2733c4312e 100644 --- a/src/runtime/contrib/coreml/coreml_runtime.mm +++ b/src/runtime/contrib/coreml/coreml_runtime.mm @@ -128,22 +128,25 @@ model_ = std::unique_ptr(new CoreMLModel(url)); } -PackedFunc CoreMLRuntime::GetFunction(const String& name, const ObjectPtr& sptr_to_self) { +ffi::Function CoreMLRuntime::GetFunction(const String& name, + const ObjectPtr& sptr_to_self) { // Return member functions during query. if (name == "invoke" || name == "run") { - return PackedFunc([this](TVMArgs args, TVMRetValue* rv) { model_->Invoke(); }); + return ffi::Function([this](ffi::PackedArgs args, ffi::Any* rv) { model_->Invoke(); }); } else if (name == "set_input") { - return PackedFunc([this](TVMArgs args, TVMRetValue* rv) { + return ffi::Function([this](ffi::PackedArgs args, ffi::Any* rv) { const auto& input_name = args[0].operator std::string(); model_->SetInput(input_name, args[1]); }); } else if (name == "get_output") { - return PackedFunc([this](TVMArgs args, TVMRetValue* rv) { *rv = model_->GetOutput(args[0]); }); + return ffi::Function( + [this](ffi::PackedArgs args, ffi::Any* rv) { *rv = model_->GetOutput(args[0]); }); } else if (name == "get_num_outputs") { - return PackedFunc([this](TVMArgs args, TVMRetValue* rv) { *rv = model_->GetNumOutputs(); }); + return ffi::Function( + [this](ffi::PackedArgs args, ffi::Any* rv) { *rv = model_->GetNumOutputs(); }); } else if (name == symbol_) { // Return the packedfunc which executes the subgraph. - return PackedFunc([this](TVMArgs args, TVMRetValue* rv) { + return ffi::Function([this](ffi::PackedArgs args, ffi::Any* rv) { MLModelDescription* model_desc = [model_->model_ modelDescription]; NSString* metadata = [model_desc metadata][MLModelDescriptionKey]; NSData* data = [metadata dataUsingEncoding:NSUTF8StringEncoding]; @@ -179,7 +182,7 @@ *rv = out; }); } else { - return PackedFunc(); + return ffi::Function(); } } @@ -189,9 +192,10 @@ Module CoreMLRuntimeCreate(const std::string& symbol, const std::string& model_p return Module(exec); } -TVM_REGISTER_GLOBAL("tvm.coreml_runtime.create").set_body_packed([](TVMArgs args, TVMRetValue* rv) { - *rv = CoreMLRuntimeCreate(args[0], args[1]); -}); +TVM_REGISTER_GLOBAL("tvm.coreml_runtime.create") + .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { + *rv = CoreMLRuntimeCreate(args[0], args[1]); + }); void CoreMLRuntime::SaveToBinary(dmlc::Stream* stream) { NSURL* url = model_->url_; diff --git a/src/runtime/contrib/cublas/cublas.cc b/src/runtime/contrib/cublas/cublas.cc index 4907bc9a308a..e3222e3adc40 100644 --- a/src/runtime/contrib/cublas/cublas.cc +++ b/src/runtime/contrib/cublas/cublas.cc @@ -305,7 +305,8 @@ void CallCublasLt(cublasLtHandle_t hdl, cudaStream_t stream, cublasLtMatrixLayoutDestroy(C_desc); } -inline void CallLtIgemm(TVMArgs args, TVMRetValue* ret, cublasLtHandle_t hdl, cudaStream_t stream) { +inline void CallLtIgemm(ffi::PackedArgs args, ffi::Any* ret, cublasLtHandle_t hdl, + cudaStream_t stream) { auto A = args[0].cast(); auto B = args[1].cast(); auto C = args[2].cast(); @@ -377,7 +378,7 @@ inline void CallLtIgemm(TVMArgs args, TVMRetValue* ret, cublasLtHandle_t hdl, cu } #endif -inline void CallGemmEx(TVMArgs args, TVMRetValue* ret, cublasHandle_t hdl) { +inline void CallGemmEx(ffi::PackedArgs args, ffi::Any* ret, cublasHandle_t hdl) { auto A = args[0].cast(); auto B = args[1].cast(); auto C = args[2].cast(); @@ -435,7 +436,7 @@ inline void CallGemmEx(TVMArgs args, TVMRetValue* ret, cublasHandle_t hdl) { beta_ptr, C_data, cuda_out_type, ColumnStride(C), cuda_out_type, algo)); } -inline void CallBatchGemmEx(TVMArgs args, TVMRetValue* ret, cublasHandle_t hdl) { +inline void CallBatchGemmEx(ffi::PackedArgs args, ffi::Any* ret, cublasHandle_t hdl) { auto A = args[0].cast(); auto B = args[1].cast(); auto C = args[2].cast(); @@ -514,7 +515,7 @@ inline void CallBatchGemmEx(TVMArgs args, TVMRetValue* ret, cublasHandle_t hdl) // matrix multiplication for row major TVM_REGISTER_GLOBAL("tvm.contrib.cublas.matmul") - .set_body_packed([](TVMArgs args, TVMRetValue* ret) { + .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { auto A = args[0].cast(); auto C = args[2].cast(); @@ -539,7 +540,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cublas.matmul") #if CUDART_VERSION >= 10010 TVM_REGISTER_GLOBAL("tvm.contrib.cublaslt.matmul") - .set_body_packed([](TVMArgs args, TVMRetValue* ret) { + .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { auto A = args[0].cast(); CuBlasThreadEntry* entry_ptr = CuBlasThreadEntry::ThreadLocal(); @@ -557,7 +558,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cublaslt.matmul") #endif // CUDART_VERSION >= 10010 TVM_REGISTER_GLOBAL("tvm.contrib.cublas.batch_matmul") - .set_body_packed([](TVMArgs args, TVMRetValue* ret) { + .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { auto A = args[0].cast(); auto C = args[2].cast(); diff --git a/src/runtime/contrib/cublas/cublas_json_runtime.cc b/src/runtime/contrib/cublas/cublas_json_runtime.cc index 79ed696a8778..c9c6cf85c6ba 100644 --- a/src/runtime/contrib/cublas/cublas_json_runtime.cc +++ b/src/runtime/contrib/cublas/cublas_json_runtime.cc @@ -48,12 +48,12 @@ class CublasJSONRuntime : public JSONRuntimeBase { void Init(const Array& consts) override {} - PackedFunc GetFunction(const String& name, const ObjectPtr& sptr_to_self) override { + ffi::Function GetFunction(const String& name, const ObjectPtr& sptr_to_self) override { // JSONRuntimeBase::SetInputOutputBuffers(...) is not thread safe. Since CublasJSONRuntime // can be used by multiple GPUs running on different threads, we avoid using that function - // and directly call cuBLAS on the inputs from TVMArgs. + // and directly call cuBLAS on the inputs from ffi::PackedArgs. if (this->symbol_name_ == name) { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + return ffi::Function([sptr_to_self, this](ffi::PackedArgs args, ffi::Any* rv) { ICHECK(this->initialized_) << "The module has not been initialized"; this->Run(args); }); @@ -64,7 +64,7 @@ class CublasJSONRuntime : public JSONRuntimeBase { const char* type_key() const override { return "cublas_json"; } // May be overridden - void Run(TVMArgs args) { + void Run(ffi::PackedArgs args) { auto* entry_ptr = tvm::contrib::CuBlasLtThreadEntry::ThreadLocal(); auto func = tvm::ffi::Function::GetGlobalRequired("runtime.get_cuda_stream"); diff --git a/src/runtime/contrib/cudnn/conv_backward.cc b/src/runtime/contrib/cudnn/conv_backward.cc index be5ebdfbbf00..52c69c81cf08 100644 --- a/src/runtime/contrib/cudnn/conv_backward.cc +++ b/src/runtime/contrib/cudnn/conv_backward.cc @@ -63,7 +63,7 @@ void ConvolutionBackwardData(int mode, int format, int algo, int dims, int group void BackwardDataFindAlgo(int format, int dims, int groups, const int pad[], const int stride[], const int dilation[], const int dy_dim[], const int w_dim[], const int dx_dim[], const std::string& data_dtype, - const std::string& conv_dtype, bool verbose, TVMRetValue* ret) { + const std::string& conv_dtype, bool verbose, ffi::Any* ret) { CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal(); const int full_dims = dims + 2; std::vector dy_dim_int64(full_dims); @@ -140,7 +140,7 @@ void ConvolutionBackwardFilter(int mode, int format, int algo, int dims, int gro void BackwardFilterFindAlgo(int format, int dims, int groups, const int pad[], const int stride[], const int dilation[], const int dy_dim[], const int x_dim[], const int dw_dim[], const std::string& data_dtype, - const std::string& conv_dtype, bool verbose, TVMRetValue* ret) { + const std::string& conv_dtype, bool verbose, ffi::Any* ret) { CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal(); const int full_dims = dims + 2; std::vector x_dim_int64(full_dims); @@ -186,7 +186,7 @@ void BackwardFilterFindAlgo(int format, int dims, int groups, const int pad[], c } TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv2d.backward_data") - .set_body_packed([](TVMArgs args, TVMRetValue* ret) { + .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { int mode = args[0].cast(); int format = args[1].cast(); int algo = args[2].cast(); @@ -207,7 +207,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv2d.backward_data") }); TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv.backward_data_find_algo") - .set_body_packed([](TVMArgs args, TVMRetValue* ret) { + .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { int format = args[0].cast(); int dims = args[1].cast(); int* pad = static_cast(args[2].cast()); @@ -226,7 +226,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv.backward_data_find_algo") }); TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv2d.backward_filter") - .set_body_packed([](TVMArgs args, TVMRetValue* ret) { + .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { int mode = args[0].cast(); int format = args[1].cast(); int algo = args[2].cast(); @@ -247,7 +247,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv2d.backward_filter") }); TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv.backward_filter_find_algo") - .set_body_packed([](TVMArgs args, TVMRetValue* ret) { + .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { int format = args[0].cast(); int dims = args[1].cast(); int* pad = static_cast(args[2].cast()); diff --git a/src/runtime/contrib/cudnn/conv_forward.cc b/src/runtime/contrib/cudnn/conv_forward.cc index 1c2e815110de..87e6121e74c7 100644 --- a/src/runtime/contrib/cudnn/conv_forward.cc +++ b/src/runtime/contrib/cudnn/conv_forward.cc @@ -108,7 +108,7 @@ void ConvolutionBiasActivationForward(int mode, int format, int algo, int dims, void FindAlgo(int format, int dims, int groups, const int pad[], const int stride[], const int dilation[], const int x_dim[], const int w_dim[], const int y_dim[], const std::string& data_dtype, const std::string& conv_dtype, bool verbose, - TVMRetValue* ret) { + ffi::Any* ret) { CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal(); const int full_dims = dims + 2; std::vector x_dim_int64(full_dims); @@ -154,7 +154,7 @@ void FindAlgo(int format, int dims, int groups, const int pad[], const int strid } TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv2d.forward") - .set_body_packed([](TVMArgs args, TVMRetValue* ret) { + .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { int mode = args[0].cast(); int format = args[1].cast(); int algo = args[2].cast(); @@ -175,7 +175,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv2d.forward") }); TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv2d+bias+act.forward") - .set_body_packed([](TVMArgs args, TVMRetValue* ret) { + .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { int mode = args[0].cast(); int format = args[1].cast(); int algo = args[2].cast(); @@ -199,7 +199,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv2d+bias+act.forward") }); TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv3d.forward") - .set_body_packed([](TVMArgs args, TVMRetValue* ret) { + .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { int mode = args[0].cast(); int format = args[1].cast(); int algo = args[2].cast(); @@ -220,7 +220,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv3d.forward") }); TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv.forward_find_algo") - .set_body_packed([](TVMArgs args, TVMRetValue* ret) { + .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { int format = args[0].cast(); int dims = args[1].cast(); int* pad = static_cast(args[2].cast()); diff --git a/src/runtime/contrib/cudnn/cudnn_json_runtime.cc b/src/runtime/contrib/cudnn/cudnn_json_runtime.cc index 3811a0788fa9..08909a3150c2 100644 --- a/src/runtime/contrib/cudnn/cudnn_json_runtime.cc +++ b/src/runtime/contrib/cudnn/cudnn_json_runtime.cc @@ -150,7 +150,7 @@ class cuDNNJSONRuntime : public JSONRuntimeBase { int mode = CUDNN_CROSS_CORRELATION; // find best algo - TVMRetValue best_algo; + ffi::Any best_algo; tvm::contrib::FindAlgo(format, dims, groups, padding.data(), strides.data(), dilation.data(), input_dims.data(), kernel_dims.data(), output_dims.data(), conv_dtype, diff --git a/src/runtime/contrib/cudnn/cudnn_utils.h b/src/runtime/contrib/cudnn/cudnn_utils.h index cd8e0a9220b6..902b61532353 100644 --- a/src/runtime/contrib/cudnn/cudnn_utils.h +++ b/src/runtime/contrib/cudnn/cudnn_utils.h @@ -117,7 +117,7 @@ void SetConvDescriptors(CuDNNThreadEntry* entry_ptr, int format, int dims, int g void FindAlgo(int format, int dims, int groups, const int pad[], const int stride[], const int dilation[], const int x_dim[], const int w_dim[], const int y_dim[], const std::string& data_dtype, const std::string& conv_dtype, bool verbose, - runtime::TVMRetValue* ret); + ffi::Any* ret); void ConvolutionForward(int mode, int format, int algo, int dims, int groups, const int pad[], const int stride[], const int dilation[], const DLTensor* x, diff --git a/src/runtime/contrib/cudnn/softmax.cc b/src/runtime/contrib/cudnn/softmax.cc index 141a84f670a5..c2b3ac3db84c 100644 --- a/src/runtime/contrib/cudnn/softmax.cc +++ b/src/runtime/contrib/cudnn/softmax.cc @@ -31,7 +31,7 @@ namespace contrib { using namespace runtime; -void softmax_impl(cudnnSoftmaxAlgorithm_t alg, TVMArgs args, TVMRetValue* ret) { +void softmax_impl(cudnnSoftmaxAlgorithm_t alg, ffi::PackedArgs args, ffi::Any* ret) { auto x = args[0].cast(); auto y = args[1].cast(); int axis = args[2].cast(); @@ -78,12 +78,12 @@ void softmax_impl(cudnnSoftmaxAlgorithm_t alg, TVMArgs args, TVMRetValue* ret) { } TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.softmax.forward") - .set_body_packed([](TVMArgs args, TVMRetValue* ret) { + .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { softmax_impl(CUDNN_SOFTMAX_ACCURATE, args, ret); }); TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.log_softmax.forward") - .set_body_packed([](TVMArgs args, TVMRetValue* ret) { + .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { softmax_impl(CUDNN_SOFTMAX_LOG, args, ret); }); diff --git a/src/runtime/contrib/dnnl/dnnl.cc b/src/runtime/contrib/dnnl/dnnl.cc index 7fca0d6b26a5..3e73b19116ee 100644 --- a/src/runtime/contrib/dnnl/dnnl.cc +++ b/src/runtime/contrib/dnnl/dnnl.cc @@ -348,38 +348,39 @@ extern "C" void dnnl_binary_op(float* data, float* weight, float* out, int algo_ } // DNNL Conv2d single OP -TVM_REGISTER_GLOBAL("tvm.contrib.dnnl.conv2d").set_body_packed([](TVMArgs args, TVMRetValue* ret) { - auto input = args[0].cast(); - auto weights = args[1].cast(); - auto output = args[2].cast(); - int p_Ph0_ = args[3].cast(), p_Pw0_ = args[4].cast(), p_Ph1_ = args[5].cast(), - p_Pw1_ = args[6].cast(), p_Sh_ = args[7].cast(), p_Sw_ = args[8].cast(), - p_G_ = args[9].cast(); - bool channel_last = args[10].cast(); - bool pre_cast = args[11].cast(); - bool post_cast = args[12].cast(); - - int p_N_ = input->shape[0], p_C_ = input->shape[1], p_H_ = input->shape[2], - p_W_ = input->shape[3], p_O_ = output->shape[1], p_Kh_ = weights->shape[2], - p_Kw_ = weights->shape[3]; - - if (channel_last) { - p_N_ = input->shape[0]; - p_H_ = input->shape[1]; - p_W_ = input->shape[2]; - p_C_ = input->shape[3]; - p_O_ = output->shape[3]; - p_Kh_ = weights->shape[0]; - p_Kw_ = weights->shape[1]; - } - - std::vector bias(p_O_, 0); - primitive_attr attr; - return dnnl_conv2d_common(static_cast(input->data), static_cast(weights->data), - bias.data(), static_cast(output->data), p_N_, p_C_, p_H_, p_W_, - p_O_, p_G_, p_Ph0_, p_Pw0_, p_Ph1_, p_Pw1_, p_Kh_, p_Kw_, p_Sh_, p_Sw_, - attr, channel_last, pre_cast, post_cast); -}); +TVM_REGISTER_GLOBAL("tvm.contrib.dnnl.conv2d") + .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { + auto input = args[0].cast(); + auto weights = args[1].cast(); + auto output = args[2].cast(); + int p_Ph0_ = args[3].cast(), p_Pw0_ = args[4].cast(), p_Ph1_ = args[5].cast(), + p_Pw1_ = args[6].cast(), p_Sh_ = args[7].cast(), p_Sw_ = args[8].cast(), + p_G_ = args[9].cast(); + bool channel_last = args[10].cast(); + bool pre_cast = args[11].cast(); + bool post_cast = args[12].cast(); + + int p_N_ = input->shape[0], p_C_ = input->shape[1], p_H_ = input->shape[2], + p_W_ = input->shape[3], p_O_ = output->shape[1], p_Kh_ = weights->shape[2], + p_Kw_ = weights->shape[3]; + + if (channel_last) { + p_N_ = input->shape[0]; + p_H_ = input->shape[1]; + p_W_ = input->shape[2]; + p_C_ = input->shape[3]; + p_O_ = output->shape[3]; + p_Kh_ = weights->shape[0]; + p_Kw_ = weights->shape[1]; + } + + std::vector bias(p_O_, 0); + primitive_attr attr; + return dnnl_conv2d_common( + static_cast(input->data), static_cast(weights->data), bias.data(), + static_cast(output->data), p_N_, p_C_, p_H_, p_W_, p_O_, p_G_, p_Ph0_, p_Pw0_, + p_Ph1_, p_Pw1_, p_Kh_, p_Kw_, p_Sh_, p_Sw_, attr, channel_last, pre_cast, post_cast); + }); } // namespace contrib } // namespace runtime diff --git a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc index 323d00f4a646..b06b17c17d8e 100644 --- a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc +++ b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc @@ -72,7 +72,7 @@ class DNNLJSONRuntime : public JSONRuntimeBase { void Run() override { LOG(FATAL) << "Unreachable code"; } /* Thread safe implementation of Run. Keep runtime instance immutable */ - void Run(const TVMArgs& args) const { + void Run(const ffi::PackedArgs& args) const { auto arg_data_provider = makeIODataProvider(args); auto mem_solver = tensor_registry_.MakeSolver(arg_data_provider); // Execute primitives one by one @@ -99,9 +99,9 @@ class DNNLJSONRuntime : public JSONRuntimeBase { } /* Override GetFunction to reimplement Run method */ - PackedFunc GetFunction(const String& name, const ObjectPtr& sptr_to_self) override { + ffi::Function GetFunction(const String& name, const ObjectPtr& sptr_to_self) override { if (this->symbol_name_ == name) { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + return ffi::Function([sptr_to_self, this](ffi::PackedArgs args, ffi::Any* rv) { ICHECK(this->initialized_) << "The module has not been initialized"; ICHECK_EQ(args.size(), input_var_eid_.size() + outputs_.size()) @@ -115,7 +115,7 @@ class DNNLJSONRuntime : public JSONRuntimeBase { } /* Same as makeInitDataProvider but in case of InputOutput return real DLTensor */ - TensorRegistry::DLTensorProvider makeIODataProvider(const TVMArgs& args) const { + TensorRegistry::DLTensorProvider makeIODataProvider(const ffi::PackedArgs& args) const { std::map io_map; // eid to dl tensor map for (size_t i = 0; i < run_arg_eid_.size(); i++) { io_map[run_arg_eid_[i]] = args[i].cast(); diff --git a/src/runtime/contrib/edgetpu/edgetpu_runtime.cc b/src/runtime/contrib/edgetpu/edgetpu_runtime.cc index a0c11fea2299..2a2462786327 100644 --- a/src/runtime/contrib/edgetpu/edgetpu_runtime.cc +++ b/src/runtime/contrib/edgetpu/edgetpu_runtime.cc @@ -69,7 +69,7 @@ Module EdgeTPURuntimeCreate(const std::string& tflite_model_bytes, Device dev) { } TVM_REGISTER_GLOBAL("tvm.edgetpu_runtime.create") - .set_body_packed([](TVMArgs args, TVMRetValue* rv) { + .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = EdgeTPURuntimeCreate(args[0], args[1]); }); } // namespace runtime diff --git a/src/runtime/contrib/edgetpu/edgetpu_runtime.h b/src/runtime/contrib/edgetpu/edgetpu_runtime.h index 341062f1c492..2d5d10652fd0 100644 --- a/src/runtime/contrib/edgetpu/edgetpu_runtime.h +++ b/src/runtime/contrib/edgetpu/edgetpu_runtime.h @@ -19,7 +19,7 @@ /*! * \brief EdgeTPU runtime that can run tflite model compiled - * for EdgeTPU containing only tvm PackedFunc. + * for EdgeTPU containing only tvm ffi::Function. * \file edgetpu_runtime.h */ #ifndef TVM_RUNTIME_CONTRIB_EDGETPU_EDGETPU_RUNTIME_H_ @@ -40,7 +40,7 @@ namespace runtime { * \brief EdgeTPU runtime. * * This runtime can be accessed in various languages via - * the TVM runtime PackedFunc API. + * the TVM runtime ffi::Function API. */ class EdgeTPURuntime : public TFLiteRuntime { public: diff --git a/src/runtime/contrib/hipblas/hipblas.cc b/src/runtime/contrib/hipblas/hipblas.cc index a6d4b0448287..c85f15cc743a 100644 --- a/src/runtime/contrib/hipblas/hipblas.cc +++ b/src/runtime/contrib/hipblas/hipblas.cc @@ -272,7 +272,7 @@ void CallHipblasLt(hipblasLtHandle_t hdl, hipStream_t stream, hipblasLtMatrixLayoutDestroy(C_desc); } -inline void CallGemmEx(TVMArgs args, TVMRetValue* ret, hipblasHandle_t hdl) { +inline void CallGemmEx(ffi::PackedArgs args, ffi::Any* ret, hipblasHandle_t hdl) { auto A = args[0].cast(); auto B = args[1].cast(); auto C = args[2].cast(); @@ -330,7 +330,7 @@ inline void CallGemmEx(TVMArgs args, TVMRetValue* ret, hipblasHandle_t hdl) { beta_ptr, C_data, hip_out_type, ColumnStride(C), hip_out_type, algo)); } -inline void CallBatchGemmEx(TVMArgs args, TVMRetValue* ret, hipblasHandle_t hdl) { +inline void CallBatchGemmEx(ffi::PackedArgs args, ffi::Any* ret, hipblasHandle_t hdl) { auto A = args[0].cast(); auto B = args[1].cast(); auto C = args[2].cast(); @@ -408,7 +408,7 @@ inline void CallBatchGemmEx(TVMArgs args, TVMRetValue* ret, hipblasHandle_t hdl) // matrix multiplication for row major TVM_REGISTER_GLOBAL("tvm.contrib.hipblas.matmul") - .set_body_packed([](TVMArgs args, TVMRetValue* ret) { + .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { auto A = args[0].cast(); auto C = args[2].cast(); @@ -431,7 +431,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.hipblas.matmul") }); TVM_REGISTER_GLOBAL("tvm.contrib.hipblas.batch_matmul") - .set_body_packed([](TVMArgs args, TVMRetValue* ret) { + .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { auto A = args[0].cast(); auto C = args[2].cast(); diff --git a/src/runtime/contrib/hipblas/hipblas_json_runtime.cc b/src/runtime/contrib/hipblas/hipblas_json_runtime.cc index ac3e8b87d059..2cd1223bc654 100644 --- a/src/runtime/contrib/hipblas/hipblas_json_runtime.cc +++ b/src/runtime/contrib/hipblas/hipblas_json_runtime.cc @@ -46,12 +46,12 @@ class HipblasJSONRuntime : public JSONRuntimeBase { void Init(const Array& consts) override {} - PackedFunc GetFunction(const String& name, const ObjectPtr& sptr_to_self) override { + ffi::Function GetFunction(const String& name, const ObjectPtr& sptr_to_self) override { // JSONRuntimeBase::SetInputOutputBuffers(...) is not thread safe. Since HipblasJSONRuntime // can be used by multiple GPUs running on different threads, we avoid using that function - // and directly call hipBLAS on the inputs from TVMArgs. + // and directly call hipBLAS on the inputs from ffi::PackedArgs. if (this->symbol_name_ == name) { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + return ffi::Function([sptr_to_self, this](ffi::PackedArgs args, ffi::Any* rv) { ICHECK(this->initialized_) << "The module has not been initialized"; this->Run(args); }); @@ -62,7 +62,7 @@ class HipblasJSONRuntime : public JSONRuntimeBase { const char* type_key() const override { return "hipblas_json"; } // May be overridden - void Run(TVMArgs args) { + void Run(ffi::PackedArgs args) { auto* entry_ptr = tvm::contrib::HipBlasLtThreadEntry::ThreadLocal(); static auto func = tvm::ffi::Function::GetGlobalRequired("runtime.get_rocm_stream"); hipStream_t stream = static_cast(func().cast()); diff --git a/src/runtime/contrib/json/json_runtime.h b/src/runtime/contrib/json/json_runtime.h index e60132affdcd..3f42e109f839 100644 --- a/src/runtime/contrib/json/json_runtime.h +++ b/src/runtime/contrib/json/json_runtime.h @@ -95,15 +95,15 @@ class JSONRuntimeBase : public ModuleNode { * \param sptr_to_self The pointer to the module node. * \return The packed function. */ - PackedFunc GetFunction(const String& name, const ObjectPtr& sptr_to_self) override { + ffi::Function GetFunction(const String& name, const ObjectPtr& sptr_to_self) override { if (name == "get_symbol") { - return PackedFunc( - [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->symbol_name_; }); + return ffi::Function( + [sptr_to_self, this](ffi::PackedArgs args, ffi::Any* rv) { *rv = this->symbol_name_; }); } else if (name == "get_const_vars") { - return PackedFunc( - [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->const_names_; }); + return ffi::Function( + [sptr_to_self, this](ffi::PackedArgs args, ffi::Any* rv) { *rv = this->const_names_; }); } else if (this->symbol_name_ == name) { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + return ffi::Function([sptr_to_self, this](ffi::PackedArgs args, ffi::Any* rv) { ICHECK(this->initialized_) << "The module has not been initialized"; // Bind argument tensors to data entries. @@ -116,9 +116,9 @@ class JSONRuntimeBase : public ModuleNode { // NOTE: the current debug convention is not very compatible with // the FFI convention, consider clean up if (!this->CanDebug()) { - return PackedFunc(nullptr); + return ffi::Function(nullptr); } - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + return ffi::Function([sptr_to_self, this](ffi::PackedArgs args, ffi::Any* rv) { ICHECK(this->initialized_) << "The module has not been initialized"; // Bind argument tensors to data entries. @@ -138,7 +138,7 @@ class JSONRuntimeBase : public ModuleNode { }); } else if ("__init_" + this->symbol_name_ == name) { // The function to initialize constant tensors. - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + return ffi::Function([sptr_to_self, this](ffi::PackedArgs args, ffi::Any* rv) { ICHECK_EQ(args.size(), 1U); std::lock_guard guard(this->initialize_mutex_); if (!this->initialized_) { @@ -148,7 +148,7 @@ class JSONRuntimeBase : public ModuleNode { *rv = 0; }); } else { - return PackedFunc(nullptr); + return ffi::Function(nullptr); } } @@ -199,7 +199,7 @@ class JSONRuntimeBase : public ModuleNode { * * \param args The packed args. */ - void SetInputOutputBuffers(const TVMArgs& args) { + void SetInputOutputBuffers(const ffi::PackedArgs& args) { ICHECK_EQ(args.size(), input_var_eid_.size() + outputs_.size()) << "Found mismatch in the number of provided data entryies and required."; diff --git a/src/runtime/contrib/miopen/conv_forward.cc b/src/runtime/contrib/miopen/conv_forward.cc index bd309a948f96..19eec4a0a026 100644 --- a/src/runtime/contrib/miopen/conv_forward.cc +++ b/src/runtime/contrib/miopen/conv_forward.cc @@ -35,7 +35,7 @@ namespace miopen { using namespace runtime; TVM_REGISTER_GLOBAL("tvm.contrib.miopen.conv2d.setup") - .set_body_packed([](TVMArgs args, TVMRetValue* ret) { + .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { const int mode = args[0].cast(); const int dtype = args[1].cast(); const int pad_h = args[2].cast(); @@ -149,7 +149,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.miopen.conv2d.setup") }); TVM_REGISTER_GLOBAL("tvm.contrib.miopen.conv2d.forward") - .set_body_packed([](TVMArgs args, TVMRetValue* ret) { + .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { const int mode = args[0].cast(); const int dtype = args[1].cast(); const int pad_h = args[2].cast(); diff --git a/src/runtime/contrib/miopen/softmax.cc b/src/runtime/contrib/miopen/softmax.cc index 897372524dd0..021d0387defb 100644 --- a/src/runtime/contrib/miopen/softmax.cc +++ b/src/runtime/contrib/miopen/softmax.cc @@ -32,7 +32,7 @@ namespace miopen { using namespace runtime; -void softmax_impl(TVMArgs args, TVMRetValue* ret, miopenSoftmaxAlgorithm_t alg) { +void softmax_impl(ffi::PackedArgs args, ffi::Any* ret, miopenSoftmaxAlgorithm_t alg) { auto x = args[0].cast(); auto y = args[1].cast(); int axis = args[2].cast(); @@ -80,12 +80,12 @@ void softmax_impl(TVMArgs args, TVMRetValue* ret, miopenSoftmaxAlgorithm_t alg) } TVM_REGISTER_GLOBAL("tvm.contrib.miopen.softmax.forward") - .set_body_packed([](TVMArgs args, TVMRetValue* ret) { + .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { softmax_impl(args, ret, MIOPEN_SOFTMAX_ACCURATE); }); TVM_REGISTER_GLOBAL("tvm.contrib.miopen.log_softmax.forward") - .set_body_packed([](TVMArgs args, TVMRetValue* ret) { + .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { softmax_impl(args, ret, MIOPEN_SOFTMAX_LOG); }); diff --git a/src/runtime/contrib/mps/conv.mm b/src/runtime/contrib/mps/conv.mm index c8788950a261..4200477b2713 100644 --- a/src/runtime/contrib/mps/conv.mm +++ b/src/runtime/contrib/mps/conv.mm @@ -25,7 +25,7 @@ using namespace runtime; TVM_REGISTER_GLOBAL("tvm.contrib.mps.buffer2img") - .set_body_packed([](TVMArgs args, TVMRetValue* ret) { + .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { auto buf = args[0].cast(); auto img = args[1].cast(); // copy to temp @@ -58,7 +58,7 @@ }); TVM_REGISTER_GLOBAL("tvm.contrib.mps.img2buffer") - .set_body_packed([](TVMArgs args, TVMRetValue* ret) { + .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { auto img = args[0].cast(); auto buf = args[1].cast(); id mtlbuf = (__bridge id)(buf->data); @@ -76,86 +76,88 @@ buf -> dtype, nullptr); }); -TVM_REGISTER_GLOBAL("tvm.contrib.mps.conv2d").set_body_packed([](TVMArgs args, TVMRetValue* ret) { - // MPS-NHWC - auto data = args[0].cast(); - auto weight = args[1].cast(); - auto output = args[2].cast(); - int pad = args[3].cast(); - int stride = args[4].cast(); - - ICHECK_EQ(data->ndim, 4); - ICHECK_EQ(weight->ndim, 4); - ICHECK_EQ(output->ndim, 4); - ICHECK(output->strides == nullptr); - ICHECK(weight->strides == nullptr); - ICHECK(data->strides == nullptr); - - ICHECK_EQ(data->shape[0], 1); - ICHECK_EQ(output->shape[0], 1); - - int oCh = weight->shape[0]; - int kH = weight->shape[1]; - int kW = weight->shape[2]; - int iCh = weight->shape[3]; - - const auto f_buf2img = tvm::ffi::Function::GetGlobal("tvm.contrib.mps.buffer2img"); - const auto f_img2buf = tvm::ffi::Function::GetGlobal("tvm.contrib.mps.img2buffer"); - // Get Metal device API - MetalThreadEntry* entry_ptr = MetalThreadEntry::ThreadLocal(); - runtime::metal::MetalThreadEntry* rt = runtime::metal::MetalThreadEntry::ThreadLocal(); - id dev = entry_ptr->metal_api->GetDevice(data->device); - id queue = entry_ptr->metal_api->GetCommandQueue(data->device); - id cb = [queue commandBuffer]; - // data to MPSImage - DLTensor tmp_in; - (*f_buf2img)(data, &tmp_in); - MPSImage* tempA = (__bridge MPSImage*)tmp_in.data; - // weight to temp memory - id bufB = (__bridge id)(weight->data); - id tempB = rt->GetTempBuffer(weight->device, [bufB length]); - entry_ptr->metal_api->CopyDataFromTo((__bridge void*)bufB, 0, (__bridge void*)tempB, 0, - [bufB length], weight -> device, weight -> device, - tmp_in.dtype, nullptr); - float* ptr_w = (float*)[tempB contents]; - // output to MPSImage - DLTensor tmp_out; - (*f_buf2img)(output, &tmp_out); - MPSImage* tempC = (__bridge MPSImage*)tmp_out.data; - // conv desc - - MPSCNNConvolutionDescriptor* conv_desc = - [MPSCNNConvolutionDescriptor cnnConvolutionDescriptorWithKernelWidth:kW - kernelHeight:kH - inputFeatureChannels:iCh - outputFeatureChannels:oCh]; - [conv_desc setStrideInPixelsX:stride]; - [conv_desc setStrideInPixelsY:stride]; - - MPSCNNConvolution* conv = [[MPSCNNConvolution alloc] initWithDevice:dev - convolutionDescriptor:conv_desc - kernelWeights:ptr_w - biasTerms:nil - flags:MPSCNNConvolutionFlagsNone]; - if (pad == 0) { - conv.padding = [MPSNNDefaultPadding paddingWithMethod:MPSNNPaddingMethodAddRemainderToTopLeft | - MPSNNPaddingMethodAlignCentered | - MPSNNPaddingMethodSizeSame]; - } else if (pad == 1) { - conv.padding = [MPSNNDefaultPadding paddingWithMethod:MPSNNPaddingMethodAddRemainderToTopLeft | - MPSNNPaddingMethodAlignCentered | - MPSNNPaddingMethodSizeValidOnly]; - } - [conv encodeToCommandBuffer:cb sourceImage:tempA destinationImage:tempC]; - - [cb commit]; - id encoder = [cb blitCommandEncoder]; - [encoder synchronizeResource:tempC.texture]; - [encoder endEncoding]; - [cb waitUntilCompleted]; - - (*f_img2buf)(&tmp_out, output); -}); +TVM_REGISTER_GLOBAL("tvm.contrib.mps.conv2d") + .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { + // MPS-NHWC + auto data = args[0].cast(); + auto weight = args[1].cast(); + auto output = args[2].cast(); + int pad = args[3].cast(); + int stride = args[4].cast(); + + ICHECK_EQ(data->ndim, 4); + ICHECK_EQ(weight->ndim, 4); + ICHECK_EQ(output->ndim, 4); + ICHECK(output->strides == nullptr); + ICHECK(weight->strides == nullptr); + ICHECK(data->strides == nullptr); + + ICHECK_EQ(data->shape[0], 1); + ICHECK_EQ(output->shape[0], 1); + + int oCh = weight->shape[0]; + int kH = weight->shape[1]; + int kW = weight->shape[2]; + int iCh = weight->shape[3]; + + const auto f_buf2img = tvm::ffi::Function::GetGlobal("tvm.contrib.mps.buffer2img"); + const auto f_img2buf = tvm::ffi::Function::GetGlobal("tvm.contrib.mps.img2buffer"); + // Get Metal device API + MetalThreadEntry* entry_ptr = MetalThreadEntry::ThreadLocal(); + runtime::metal::MetalThreadEntry* rt = runtime::metal::MetalThreadEntry::ThreadLocal(); + id dev = entry_ptr->metal_api->GetDevice(data->device); + id queue = entry_ptr->metal_api->GetCommandQueue(data->device); + id cb = [queue commandBuffer]; + // data to MPSImage + DLTensor tmp_in; + (*f_buf2img)(data, &tmp_in); + MPSImage* tempA = (__bridge MPSImage*)tmp_in.data; + // weight to temp memory + id bufB = (__bridge id)(weight->data); + id tempB = rt->GetTempBuffer(weight->device, [bufB length]); + entry_ptr->metal_api->CopyDataFromTo((__bridge void*)bufB, 0, (__bridge void*)tempB, 0, + [bufB length], weight -> device, weight -> device, + tmp_in.dtype, nullptr); + float* ptr_w = (float*)[tempB contents]; + // output to MPSImage + DLTensor tmp_out; + (*f_buf2img)(output, &tmp_out); + MPSImage* tempC = (__bridge MPSImage*)tmp_out.data; + // conv desc + + MPSCNNConvolutionDescriptor* conv_desc = + [MPSCNNConvolutionDescriptor cnnConvolutionDescriptorWithKernelWidth:kW + kernelHeight:kH + inputFeatureChannels:iCh + outputFeatureChannels:oCh]; + [conv_desc setStrideInPixelsX:stride]; + [conv_desc setStrideInPixelsY:stride]; + + MPSCNNConvolution* conv = + [[MPSCNNConvolution alloc] initWithDevice:dev + convolutionDescriptor:conv_desc + kernelWeights:ptr_w + biasTerms:nil + flags:MPSCNNConvolutionFlagsNone]; + if (pad == 0) { + conv.padding = [MPSNNDefaultPadding + paddingWithMethod:MPSNNPaddingMethodAddRemainderToTopLeft | + MPSNNPaddingMethodAlignCentered | MPSNNPaddingMethodSizeSame]; + } else if (pad == 1) { + conv.padding = [MPSNNDefaultPadding + paddingWithMethod:MPSNNPaddingMethodAddRemainderToTopLeft | + MPSNNPaddingMethodAlignCentered | MPSNNPaddingMethodSizeValidOnly]; + } + [conv encodeToCommandBuffer:cb sourceImage:tempA destinationImage:tempC]; + + [cb commit]; + id encoder = [cb blitCommandEncoder]; + [encoder synchronizeResource:tempC.texture]; + [encoder endEncoding]; + [cb waitUntilCompleted]; + + (*f_img2buf)(&tmp_out, output); + }); } // namespace contrib } // namespace tvm diff --git a/src/runtime/contrib/mps/gemm.mm b/src/runtime/contrib/mps/gemm.mm index a87f437563e1..77eb6dd03dd3 100644 --- a/src/runtime/contrib/mps/gemm.mm +++ b/src/runtime/contrib/mps/gemm.mm @@ -24,72 +24,75 @@ using namespace runtime; -TVM_REGISTER_GLOBAL("tvm.contrib.mps.matmul").set_body_packed([](TVMArgs args, TVMRetValue* ret) { - auto A = args[0].cast(); - auto B = args[1].cast(); - auto C = args[2].cast(); - bool transa = args[3].cast(); - bool transb = args[4].cast(); - // call gemm for simple compact code. - ICHECK_EQ(A->ndim, 2); - ICHECK_EQ(B->ndim, 2); - ICHECK_EQ(C->ndim, 2); - ICHECK(C->strides == nullptr); - ICHECK(B->strides == nullptr); - ICHECK(A->strides == nullptr); - ICHECK(TypeMatch(A->dtype, kDLFloat, 32)); - ICHECK(TypeMatch(B->dtype, kDLFloat, 32)); - ICHECK(TypeMatch(C->dtype, kDLFloat, 32)); - // Get Metal device API - MetalThreadEntry* entry_ptr = MetalThreadEntry::ThreadLocal(); - // ICHECK_EQ(A->device, B->device); - // ICHECK_EQ(A->device, C->device); - id dev = entry_ptr->metal_api->GetDevice(A->device); - id queue = entry_ptr->metal_api->GetCommandQueue(A->device); - id cb = [queue commandBuffer]; - NSUInteger M = A->shape[0 + (transa ? 1 : 0)]; - NSUInteger N = B->shape[1 - (transb ? 1 : 0)]; - NSUInteger K = B->shape[0 + (transb ? 1 : 0)]; +TVM_REGISTER_GLOBAL("tvm.contrib.mps.matmul") + .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { + auto A = args[0].cast(); + auto B = args[1].cast(); + auto C = args[2].cast(); + bool transa = args[3].cast(); + bool transb = args[4].cast(); + // call gemm for simple compact code. + ICHECK_EQ(A->ndim, 2); + ICHECK_EQ(B->ndim, 2); + ICHECK_EQ(C->ndim, 2); + ICHECK(C->strides == nullptr); + ICHECK(B->strides == nullptr); + ICHECK(A->strides == nullptr); + ICHECK(TypeMatch(A->dtype, kDLFloat, 32)); + ICHECK(TypeMatch(B->dtype, kDLFloat, 32)); + ICHECK(TypeMatch(C->dtype, kDLFloat, 32)); + // Get Metal device API + MetalThreadEntry* entry_ptr = MetalThreadEntry::ThreadLocal(); + // ICHECK_EQ(A->device, B->device); + // ICHECK_EQ(A->device, C->device); + id dev = entry_ptr->metal_api->GetDevice(A->device); + id queue = entry_ptr->metal_api->GetCommandQueue(A->device); + id cb = [queue commandBuffer]; + NSUInteger M = A->shape[0 + (transa ? 1 : 0)]; + NSUInteger N = B->shape[1 - (transb ? 1 : 0)]; + NSUInteger K = B->shape[0 + (transb ? 1 : 0)]; - ICHECK_EQ(A->shape[1 - (transa ? 1 : 0)], K); - // mps a - MPSDataType dtype = MPSType::DLTypeToMPSType(A->dtype); - MPSMatrixDescriptor* descA = - [MPSMatrixDescriptor matrixDescriptorWithDimensions:M - columns:K - rowBytes:K * sizeof(MPSDataTypeFloat32) - dataType:MPSDataTypeFloat32]; - id bufA = (__bridge id)(A->data); - MPSMatrix* matrixA = [[MPSMatrix alloc] initWithBuffer:bufA descriptor:descA]; - // mps b - MPSMatrixDescriptor* descB = [MPSMatrixDescriptor matrixDescriptorWithDimensions:K - columns:N - rowBytes:N * sizeof(dtype) - dataType:dtype]; - id bufB = (__bridge id)(B->data); - MPSMatrix* matrixB = [[MPSMatrix alloc] initWithBuffer:bufB descriptor:descB]; - // mps c - MPSMatrixDescriptor* descC = [MPSMatrixDescriptor matrixDescriptorWithDimensions:M - columns:N - rowBytes:N * sizeof(dtype) - dataType:dtype]; - id bufC = (__bridge id)(C->data); - MPSMatrix* matrixC = [[MPSMatrix alloc] initWithBuffer:bufC descriptor:descC]; - // kernel + ICHECK_EQ(A->shape[1 - (transa ? 1 : 0)], K); + // mps a + MPSDataType dtype = MPSType::DLTypeToMPSType(A->dtype); + MPSMatrixDescriptor* descA = + [MPSMatrixDescriptor matrixDescriptorWithDimensions:M + columns:K + rowBytes:K * sizeof(MPSDataTypeFloat32) + dataType:MPSDataTypeFloat32]; + id bufA = (__bridge id)(A->data); + MPSMatrix* matrixA = [[MPSMatrix alloc] initWithBuffer:bufA descriptor:descA]; + // mps b + MPSMatrixDescriptor* descB = + [MPSMatrixDescriptor matrixDescriptorWithDimensions:K + columns:N + rowBytes:N * sizeof(dtype) + dataType:dtype]; + id bufB = (__bridge id)(B->data); + MPSMatrix* matrixB = [[MPSMatrix alloc] initWithBuffer:bufB descriptor:descB]; + // mps c + MPSMatrixDescriptor* descC = + [MPSMatrixDescriptor matrixDescriptorWithDimensions:M + columns:N + rowBytes:N * sizeof(dtype) + dataType:dtype]; + id bufC = (__bridge id)(C->data); + MPSMatrix* matrixC = [[MPSMatrix alloc] initWithBuffer:bufC descriptor:descC]; + // kernel - MPSMatrixMultiplication* mul_obj = [[MPSMatrixMultiplication alloc] init]; - MPSMatrixMultiplication* sgemm = [mul_obj initWithDevice:dev - transposeLeft:transa - transposeRight:transb - resultRows:M - resultColumns:N - interiorColumns:K - alpha:1.0f - beta:0.0f]; - ICHECK(sgemm != nil); - [sgemm encodeToCommandBuffer:cb leftMatrix:matrixA rightMatrix:matrixB resultMatrix:matrixC]; - [cb commit]; -}); + MPSMatrixMultiplication* mul_obj = [[MPSMatrixMultiplication alloc] init]; + MPSMatrixMultiplication* sgemm = [mul_obj initWithDevice:dev + transposeLeft:transa + transposeRight:transb + resultRows:M + resultColumns:N + interiorColumns:K + alpha:1.0f + beta:0.0f]; + ICHECK(sgemm != nil); + [sgemm encodeToCommandBuffer:cb leftMatrix:matrixA rightMatrix:matrixB resultMatrix:matrixC]; + [cb commit]; + }); } // namespace contrib } // namespace tvm diff --git a/src/runtime/contrib/mrvl/mrvl_hw_runtime.cc b/src/runtime/contrib/mrvl/mrvl_hw_runtime.cc index 44b88022a7ce..01cfb385c7f5 100644 --- a/src/runtime/contrib/mrvl/mrvl_hw_runtime.cc +++ b/src/runtime/contrib/mrvl/mrvl_hw_runtime.cc @@ -211,12 +211,12 @@ class MarvellHardwareModuleNode : public ModuleNode { * \param sptr_to_self The pointer to the module node. * \return The packed function. */ - virtual PackedFunc GetFunction(const String& name, const ObjectPtr& sptr_to_self) { + virtual ffi::Function GetFunction(const String& name, const ObjectPtr& sptr_to_self) { if (name == "get_symbol") { - return PackedFunc( - [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->symbol_name_; }); + return ffi::Function( + [sptr_to_self, this](ffi::PackedArgs args, ffi::Any* rv) { *rv = this->symbol_name_; }); } else if (name == "register_cb") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + return ffi::Function([sptr_to_self, this](ffi::PackedArgs args, ffi::Any* rv) { struct ml_dpdk_cb* a = static_cast(args[0].value().v_handle); memcpy(&dpdk_cb_, a, sizeof(struct ml_dpdk_cb)); device_handle = args[1].value().v_handle; @@ -224,22 +224,22 @@ class MarvellHardwareModuleNode : public ModuleNode { use_dpdk_cb = true; }); } else if (name == "get_const_vars") { - return PackedFunc( - [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = Array{}; }); + return ffi::Function( + [sptr_to_self, this](ffi::PackedArgs args, ffi::Any* rv) { *rv = Array{}; }); } else if (this->symbol_name_ == name) { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + return ffi::Function([sptr_to_self, this](ffi::PackedArgs args, ffi::Any* rv) { RunInference(args); *rv = 0; }); } else if ("__init_" + this->symbol_name_ == name) { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + return ffi::Function([sptr_to_self, this](ffi::PackedArgs args, ffi::Any* rv) { run_arg.device = device_handle; run_arg.model_id = model_id; load_and_initialize_model(); *rv = 0; }); } - return PackedFunc(nullptr); + return ffi::Function(nullptr); } virtual void SaveToBinary(dmlc::Stream* stream) { @@ -295,7 +295,7 @@ class MarvellHardwareModuleNode : public ModuleNode { struct run_args run_arg; static bool use_dpdk_cb; - void RunInference_TVMC(TVMArgs args) { + void RunInference_TVMC(ffi::PackedArgs args) { float* i_d_buf_float; float* o_d_buf_float; const DLTensor* tensor; @@ -371,7 +371,7 @@ class MarvellHardwareModuleNode : public ModuleNode { } } - void RunInference_DPDK(TVMArgs args) { + void RunInference_DPDK(ffi::PackedArgs args) { const DLTensor* tensor[64]; for (int in = 0; in < num_inputs_; in++) { @@ -404,7 +404,7 @@ class MarvellHardwareModuleNode : public ModuleNode { run_arg.o_q_buf, tensor); } - void RunInference(TVMArgs args) { + void RunInference(ffi::PackedArgs args) { if (use_dpdk_cb) RunInference_DPDK(args); else diff --git a/src/runtime/contrib/mrvl/mrvl_runtime.cc b/src/runtime/contrib/mrvl/mrvl_runtime.cc index 312cf7ab3ce7..186cc3b3a859 100644 --- a/src/runtime/contrib/mrvl/mrvl_runtime.cc +++ b/src/runtime/contrib/mrvl/mrvl_runtime.cc @@ -69,20 +69,20 @@ class MarvellSimulatorModuleNode : public ModuleNode { * \param sptr_to_self The pointer to the module node. * \return The packed function. */ - virtual PackedFunc GetFunction(const String& name, const ObjectPtr& sptr_to_self) { + virtual ffi::Function GetFunction(const String& name, const ObjectPtr& sptr_to_self) { if (name == "get_symbol") { - return PackedFunc( - [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->symbol_name_; }); + return ffi::Function( + [sptr_to_self, this](ffi::PackedArgs args, ffi::Any* rv) { *rv = this->symbol_name_; }); } else if (name == "get_const_vars") { - return PackedFunc( - [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = Array{}; }); + return ffi::Function( + [sptr_to_self, this](ffi::PackedArgs args, ffi::Any* rv) { *rv = Array{}; }); } else if (this->symbol_name_ == name) { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + return ffi::Function([sptr_to_self, this](ffi::PackedArgs args, ffi::Any* rv) { Run(args); *rv = 0; }); } - return PackedFunc(nullptr); + return ffi::Function(nullptr); } virtual void SaveToBinary(dmlc::Stream* stream) { @@ -123,7 +123,7 @@ class MarvellSimulatorModuleNode : public ModuleNode { size_t num_inputs_; size_t num_outputs_; - void Run(TVMArgs args) { + void Run(ffi::PackedArgs args) { ICHECK_EQ(args.size(), num_inputs_ + num_outputs_) << "Marvell-Compiler-ERROR-Internal::Mismatch in number of input & number of output args " "to subgraph"; diff --git a/src/runtime/contrib/mrvl/mrvl_sw_runtime_lib.cc b/src/runtime/contrib/mrvl/mrvl_sw_runtime_lib.cc index d993072aa2ee..b062a50dccb5 100644 --- a/src/runtime/contrib/mrvl/mrvl_sw_runtime_lib.cc +++ b/src/runtime/contrib/mrvl/mrvl_sw_runtime_lib.cc @@ -61,7 +61,7 @@ static void WriteBinToDisk(const std::string& bin_file, const std::string& bin_c for (auto byte : byte_array) file_out << byte; } -static void ReadInputsAndGenerateInputBin(TVMArgs args, const std::string& input_json, +static void ReadInputsAndGenerateInputBin(ffi::PackedArgs args, const std::string& input_json, const std::string& input_bin, const std::string& bin_directory, size_t num_inputs) { std::ofstream file_out; @@ -104,7 +104,7 @@ static void RunInferenceOnMlModel(const std::string& symbol_name, const std::str (*run_sim)(command, sim_directory); } -static void ReadOutputsAndUpdateRuntime(TVMArgs args, size_t num_inputs, +static void ReadOutputsAndUpdateRuntime(ffi::PackedArgs args, size_t num_inputs, const std::string& out_bin_prefix) { for (int out = num_inputs; out < args.size(); out++) { const DLTensor* outTensor; @@ -141,14 +141,15 @@ static void ReadOutputsAndUpdateRuntime(TVMArgs args, size_t num_inputs, } } -static void CleanUp(TVMArgs args, const std::string& bin_file, const std::string& input_json, - const std::string& input_bin, const std::string& out_bin_prefix, - size_t num_outputs) { +static void CleanUp(ffi::PackedArgs args, const std::string& bin_file, + const std::string& input_json, const std::string& input_bin, + const std::string& out_bin_prefix, size_t num_outputs) { const auto clean_up = tvm::ffi::Function::GetGlobal("tvm.mrvl.CleanUpSim"); (*clean_up)(bin_file, input_json, input_bin, out_bin_prefix, num_outputs); } -void tvm::runtime::contrib::mrvl::RunMarvellSimulator(TVMArgs args, const std::string& symbol_name, +void tvm::runtime::contrib::mrvl::RunMarvellSimulator(ffi::PackedArgs args, + const std::string& symbol_name, const std::string& bin_code, size_t num_inputs, size_t num_outputs) { // check $PATH for the presence of MRVL dependent tools/scripts diff --git a/src/runtime/contrib/mrvl/mrvl_sw_runtime_lib.h b/src/runtime/contrib/mrvl/mrvl_sw_runtime_lib.h index 4670487ed1c4..c5242a4e7bde 100644 --- a/src/runtime/contrib/mrvl/mrvl_sw_runtime_lib.h +++ b/src/runtime/contrib/mrvl/mrvl_sw_runtime_lib.h @@ -35,7 +35,7 @@ namespace runtime { namespace contrib { namespace mrvl { -void RunMarvellSimulator(tvm::runtime::TVMArgs args, const std::string& symbol_name, +void RunMarvellSimulator(tvm::ffi::PackedArgs args, const std::string& symbol_name, const std::string& bin_code, size_t num_inputs, size_t num_outputs); } } // namespace contrib diff --git a/src/runtime/contrib/random/random.cc b/src/runtime/contrib/random/random.cc index ebc60762c060..ed4e1a3fad38 100644 --- a/src/runtime/contrib/random/random.cc +++ b/src/runtime/contrib/random/random.cc @@ -70,7 +70,7 @@ RandomThreadLocalEntry* RandomThreadLocalEntry::ThreadLocal() { } TVM_REGISTER_GLOBAL("tvm.contrib.random.randint") - .set_body_packed([](TVMArgs args, TVMRetValue* ret) { + .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { RandomThreadLocalEntry* entry = RandomThreadLocalEntry::ThreadLocal(); int64_t low = args[0].cast(); int64_t high = args[1].cast(); @@ -104,7 +104,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.random.randint") }); TVM_REGISTER_GLOBAL("tvm.contrib.random.uniform") - .set_body_packed([](TVMArgs args, TVMRetValue* ret) { + .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { RandomThreadLocalEntry* entry = RandomThreadLocalEntry::ThreadLocal(); double low = args[0].cast(); double high = args[1].cast(); @@ -113,7 +113,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.random.uniform") }); TVM_REGISTER_GLOBAL("tvm.contrib.random.normal") - .set_body_packed([](TVMArgs args, TVMRetValue* ret) { + .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { RandomThreadLocalEntry* entry = RandomThreadLocalEntry::ThreadLocal(); double loc = args[0].cast(); double scale = args[1].cast(); @@ -122,14 +122,14 @@ TVM_REGISTER_GLOBAL("tvm.contrib.random.normal") }); TVM_REGISTER_GLOBAL("tvm.contrib.random.random_fill") - .set_body_packed([](TVMArgs args, TVMRetValue* ret) { + .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { RandomThreadLocalEntry* entry = RandomThreadLocalEntry::ThreadLocal(); auto out = args[0].cast(); entry->random_engine.RandomFill(out); }); TVM_REGISTER_GLOBAL("tvm.contrib.random.random_fill_for_measure") - .set_body_packed([](TVMArgs args, TVMRetValue* ret) -> void { + .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) -> void { const auto curand = tvm::ffi::Function::GetGlobal("runtime.contrib.curand.RandomFill"); auto out = args[0].cast(); if (curand.has_value() && out->device.device_type == DLDeviceType::kDLCUDA) { diff --git a/src/runtime/contrib/rocblas/rocblas.cc b/src/runtime/contrib/rocblas/rocblas.cc index bf3f5b998a0f..88c6071e1efd 100644 --- a/src/runtime/contrib/rocblas/rocblas.cc +++ b/src/runtime/contrib/rocblas/rocblas.cc @@ -66,7 +66,7 @@ typedef dmlc::ThreadLocalStore RocBlasThreadStore; // matrix multiplication for row major TVM_REGISTER_GLOBAL("tvm.contrib.rocblas.matmul") - .set_body_packed([](TVMArgs args, TVMRetValue* ret) { + .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { auto A = args[0].cast(); auto B = args[1].cast(); auto C = args[2].cast(); @@ -104,7 +104,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.rocblas.matmul") }); TVM_REGISTER_GLOBAL("tvm.contrib.rocblas.batch_matmul") - .set_body_packed([](TVMArgs args, TVMRetValue* ret) { + .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { auto A = args[0].cast(); auto B = args[1].cast(); auto C = args[2].cast(); diff --git a/src/runtime/contrib/sort/sort.cc b/src/runtime/contrib/sort/sort.cc index 1efaca220f60..f413af696661 100644 --- a/src/runtime/contrib/sort/sort.cc +++ b/src/runtime/contrib/sort/sort.cc @@ -78,7 +78,7 @@ struct float16 { // and sort axis is dk. sort_num should have dimension of // (d1, d2, ..., d(k-1), d(k+1), ..., dn). TVM_REGISTER_GLOBAL("tvm.contrib.sort.argsort_nms") - .set_body_packed([](TVMArgs args, TVMRetValue* ret) { + .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { auto input = args[0].cast(); auto sort_num = args[1].cast(); auto output = args[2].cast(); @@ -216,93 +216,94 @@ void sort(DLTensor* input, DLTensor* output, int32_t axis, bool is_ascend) { // If input tensor has dimension (d0, d1, ..., d(k-1), dk, d(k+1), ..., d(n-1)) // and sort axis is dk. sort_num should have dimension of // (d1, d2, ..., d(k-1), d(k+1), ..., dn). -TVM_REGISTER_GLOBAL("tvm.contrib.sort.argsort").set_body_packed([](TVMArgs args, TVMRetValue* ret) { - auto input = args[0].cast(); - auto output = args[1].cast(); - int32_t axis = args[2].cast(); - bool is_ascend = args[3].cast(); - if (axis < 0) { - axis = input->ndim + axis; - } - ICHECK_LT(axis, input->ndim) << "Axis out of boundary for " - "input ndim " - << input->ndim; - - auto data_dtype = DLDataTypeToString(input->dtype); - auto out_dtype = DLDataTypeToString(output->dtype); - - if (data_dtype == "float32") { - if (out_dtype == "int32") { - argsort(input, output, axis, is_ascend); - } else if (out_dtype == "int64") { - argsort(input, output, axis, is_ascend); - } else if (out_dtype == "float32") { - argsort(input, output, axis, is_ascend); - } else if (out_dtype == "float64") { - argsort(input, output, axis, is_ascend); - } else { - LOG(FATAL) << "Unsupported output dtype: " << out_dtype; - } - } else if (data_dtype == "float64") { - if (out_dtype == "int32") { - argsort(input, output, axis, is_ascend); - } else if (out_dtype == "int64") { - argsort(input, output, axis, is_ascend); - } else if (out_dtype == "float32") { - argsort(input, output, axis, is_ascend); - } else if (out_dtype == "float64") { - argsort(input, output, axis, is_ascend); - } else { - LOG(FATAL) << "Unsupported output dtype: " << out_dtype; - } +TVM_REGISTER_GLOBAL("tvm.contrib.sort.argsort") + .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { + auto input = args[0].cast(); + auto output = args[1].cast(); + int32_t axis = args[2].cast(); + bool is_ascend = args[3].cast(); + if (axis < 0) { + axis = input->ndim + axis; + } + ICHECK_LT(axis, input->ndim) << "Axis out of boundary for " + "input ndim " + << input->ndim; + + auto data_dtype = DLDataTypeToString(input->dtype); + auto out_dtype = DLDataTypeToString(output->dtype); + + if (data_dtype == "float32") { + if (out_dtype == "int32") { + argsort(input, output, axis, is_ascend); + } else if (out_dtype == "int64") { + argsort(input, output, axis, is_ascend); + } else if (out_dtype == "float32") { + argsort(input, output, axis, is_ascend); + } else if (out_dtype == "float64") { + argsort(input, output, axis, is_ascend); + } else { + LOG(FATAL) << "Unsupported output dtype: " << out_dtype; + } + } else if (data_dtype == "float64") { + if (out_dtype == "int32") { + argsort(input, output, axis, is_ascend); + } else if (out_dtype == "int64") { + argsort(input, output, axis, is_ascend); + } else if (out_dtype == "float32") { + argsort(input, output, axis, is_ascend); + } else if (out_dtype == "float64") { + argsort(input, output, axis, is_ascend); + } else { + LOG(FATAL) << "Unsupported output dtype: " << out_dtype; + } #if (__ARM_FEATURE_FP16_SCALAR_ARITHMETIC == 1) - } else if (data_dtype == "float16") { - if (out_dtype == "float16") { - argsort<__fp16, __fp16>(input, output, axis, is_ascend); - } else { - LOG(FATAL) << "Unsupported output dtype: " << out_dtype; - } + } else if (data_dtype == "float16") { + if (out_dtype == "float16") { + argsort<__fp16, __fp16>(input, output, axis, is_ascend); + } else { + LOG(FATAL) << "Unsupported output dtype: " << out_dtype; + } #endif - } else if (data_dtype == "int32") { - if (out_dtype == "int32") { - argsort(input, output, axis, is_ascend); - } else if (out_dtype == "int64") { - argsort(input, output, axis, is_ascend); - } else if (out_dtype == "float32") { - argsort(input, output, axis, is_ascend); - } else if (out_dtype == "float64") { - argsort(input, output, axis, is_ascend); - } else { - LOG(FATAL) << "Unsupported output dtype: " << out_dtype; - } - } else if (data_dtype == "int64") { - if (out_dtype == "int32") { - argsort(input, output, axis, is_ascend); - } else if (out_dtype == "int64") { - argsort(input, output, axis, is_ascend); - } else if (out_dtype == "float32") { - argsort(input, output, axis, is_ascend); - } else if (out_dtype == "float64") { - argsort(input, output, axis, is_ascend); - } else { - LOG(FATAL) << "Unsupported output dtype: " << out_dtype; - } - } else if (data_dtype == "float16") { - if (out_dtype == "int32") { - argsort(input, output, axis, is_ascend); - } else if (out_dtype == "int64") { - argsort(input, output, axis, is_ascend); - } else if (out_dtype == "float32") { - argsort(input, output, axis, is_ascend); - } else if (out_dtype == "float64") { - argsort(input, output, axis, is_ascend); - } else { - LOG(FATAL) << "Unsupported output dtype: " << out_dtype; - } - } else { - LOG(FATAL) << "Unsupported input dtype: " << data_dtype; - } -}); + } else if (data_dtype == "int32") { + if (out_dtype == "int32") { + argsort(input, output, axis, is_ascend); + } else if (out_dtype == "int64") { + argsort(input, output, axis, is_ascend); + } else if (out_dtype == "float32") { + argsort(input, output, axis, is_ascend); + } else if (out_dtype == "float64") { + argsort(input, output, axis, is_ascend); + } else { + LOG(FATAL) << "Unsupported output dtype: " << out_dtype; + } + } else if (data_dtype == "int64") { + if (out_dtype == "int32") { + argsort(input, output, axis, is_ascend); + } else if (out_dtype == "int64") { + argsort(input, output, axis, is_ascend); + } else if (out_dtype == "float32") { + argsort(input, output, axis, is_ascend); + } else if (out_dtype == "float64") { + argsort(input, output, axis, is_ascend); + } else { + LOG(FATAL) << "Unsupported output dtype: " << out_dtype; + } + } else if (data_dtype == "float16") { + if (out_dtype == "int32") { + argsort(input, output, axis, is_ascend); + } else if (out_dtype == "int64") { + argsort(input, output, axis, is_ascend); + } else if (out_dtype == "float32") { + argsort(input, output, axis, is_ascend); + } else if (out_dtype == "float64") { + argsort(input, output, axis, is_ascend); + } else { + LOG(FATAL) << "Unsupported output dtype: " << out_dtype; + } + } else { + LOG(FATAL) << "Unsupported input dtype: " << data_dtype; + } + }); // Sort implemented C library sort. // Return sorted tensor. @@ -311,41 +312,42 @@ TVM_REGISTER_GLOBAL("tvm.contrib.sort.argsort").set_body_packed([](TVMArgs args, // If input tensor has dimension (d0, d1, ..., d(k-1), dk, d(k+1), ..., d(n-1)) // and sort axis is dk. sort_num should have dimension of // (d1, d2, ..., d(k-1), d(k+1), ..., dn). -TVM_REGISTER_GLOBAL("tvm.contrib.sort.sort").set_body_packed([](TVMArgs args, TVMRetValue* ret) { - auto input = args[0].cast(); - auto output = args[1].cast(); - int32_t axis = args[2].cast(); - bool is_ascend = args[3].cast(); - if (axis < 0) { - axis = input->ndim + axis; - } - ICHECK_LT(axis, input->ndim) << "Axis out of boundary for " - "input ndim " - << input->ndim; +TVM_REGISTER_GLOBAL("tvm.contrib.sort.sort") + .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { + auto input = args[0].cast(); + auto output = args[1].cast(); + int32_t axis = args[2].cast(); + bool is_ascend = args[3].cast(); + if (axis < 0) { + axis = input->ndim + axis; + } + ICHECK_LT(axis, input->ndim) << "Axis out of boundary for " + "input ndim " + << input->ndim; - auto data_dtype = DLDataTypeToString(input->dtype); - auto out_dtype = DLDataTypeToString(output->dtype); + auto data_dtype = DLDataTypeToString(input->dtype); + auto out_dtype = DLDataTypeToString(output->dtype); - ICHECK_EQ(data_dtype, out_dtype); + ICHECK_EQ(data_dtype, out_dtype); - if (data_dtype == "float32") { - sort(input, output, axis, is_ascend); - } else if (data_dtype == "float64") { - sort(input, output, axis, is_ascend); + if (data_dtype == "float32") { + sort(input, output, axis, is_ascend); + } else if (data_dtype == "float64") { + sort(input, output, axis, is_ascend); #if (__ARM_FEATURE_FP16_SCALAR_ARITHMETIC == 1) - } else if (data_dtype == "float16") { - sort<__fp16>(input, output, axis, is_ascend); + } else if (data_dtype == "float16") { + sort<__fp16>(input, output, axis, is_ascend); #endif - } else if (data_dtype == "int32") { - sort(input, output, axis, is_ascend); - } else if (data_dtype == "int64") { - sort(input, output, axis, is_ascend); - } else if (data_dtype == "float16") { - sort(input, output, axis, is_ascend); - } else { - LOG(FATAL) << "Unsupported input dtype: " << data_dtype; - } -}); + } else if (data_dtype == "int32") { + sort(input, output, axis, is_ascend); + } else if (data_dtype == "int64") { + sort(input, output, axis, is_ascend); + } else if (data_dtype == "float16") { + sort(input, output, axis, is_ascend); + } else { + LOG(FATAL) << "Unsupported input dtype: " << data_dtype; + } + }); template void topk(DLTensor* input, DLTensor* out_values, DLTensor* out_indices, int k, int axis, @@ -440,124 +442,126 @@ void topk(DLTensor* input, DLTensor* out_values, DLTensor* out_indices, int k, i // If input tensor has dimension (d0, d1, ..., d(k-1), dk, d(k+1), ..., d(n-1)) // and sort axis is dk. sort_num should have dimension of // (d1, d2, ..., d(k-1), d(k+1), ..., dn). -TVM_REGISTER_GLOBAL("tvm.contrib.sort.topk").set_body_packed([](TVMArgs args, TVMRetValue* ret) { - auto input = args[0].cast(); - DLTensor* values_out = nullptr; - DLTensor* indices_out = nullptr; - int k = args[args.size() - 4].cast(); - int axis = args[args.size() - 3].cast(); - std::string ret_type = args[args.size() - 2].cast(); - bool is_ascend = args[args.size() - 1].cast(); - if (ret_type == "both") { - values_out = args[1].cast(); - indices_out = args[2].cast(); - } else if (ret_type == "values") { - values_out = args[1].cast(); - } else if (ret_type == "indices") { - indices_out = args[1].cast(); - } else { - LOG(FATAL) << "Unsupported ret type: " << ret_type; - } - if (axis < 0) { - axis = input->ndim + axis; - } - ICHECK(axis >= 0 && axis < input->ndim) << "Axis out of boundary for input ndim " << input->ndim; - - auto data_dtype = DLDataTypeToString(input->dtype); - auto out_dtype = (indices_out == nullptr) ? "int64" : DLDataTypeToString(indices_out->dtype); - - if (data_dtype == "float32") { - if (out_dtype == "int32") { - topk(input, values_out, indices_out, k, axis, is_ascend); - } else if (out_dtype == "int64") { - topk(input, values_out, indices_out, k, axis, is_ascend); - } else if (out_dtype == "float32") { - topk(input, values_out, indices_out, k, axis, is_ascend); - } else if (out_dtype == "float64") { - topk(input, values_out, indices_out, k, axis, is_ascend); - } else { - LOG(FATAL) << "Unsupported output dtype: " << out_dtype; - } - } else if (data_dtype == "float64") { - if (out_dtype == "int32") { - topk(input, values_out, indices_out, k, axis, is_ascend); - } else if (out_dtype == "int64") { - topk(input, values_out, indices_out, k, axis, is_ascend); - } else if (out_dtype == "float32") { - topk(input, values_out, indices_out, k, axis, is_ascend); - } else if (out_dtype == "float64") { - topk(input, values_out, indices_out, k, axis, is_ascend); - } else { - LOG(FATAL) << "Unsupported output dtype: " << out_dtype; - } - } else if (data_dtype == "uint8") { - if (out_dtype == "uint8") { - topk(input, values_out, indices_out, k, axis, is_ascend); - } else if (out_dtype == "int32") { - topk(input, values_out, indices_out, k, axis, is_ascend); - } else if (out_dtype == "int64") { - topk(input, values_out, indices_out, k, axis, is_ascend); - } else if (out_dtype == "float32") { - topk(input, values_out, indices_out, k, axis, is_ascend); - } else if (out_dtype == "float64") { - topk(input, values_out, indices_out, k, axis, is_ascend); - } else { - LOG(FATAL) << "Unsupported output dtype: " << out_dtype; - } - } else if (data_dtype == "int8") { - if (out_dtype == "int8") { - topk(input, values_out, indices_out, k, axis, is_ascend); - } else if (out_dtype == "int32") { - topk(input, values_out, indices_out, k, axis, is_ascend); - } else if (out_dtype == "int64") { - topk(input, values_out, indices_out, k, axis, is_ascend); - } else if (out_dtype == "float32") { - topk(input, values_out, indices_out, k, axis, is_ascend); - } else if (out_dtype == "float64") { - topk(input, values_out, indices_out, k, axis, is_ascend); - } else { - LOG(FATAL) << "Unsupported output dtype: " << out_dtype; - } - } else if (data_dtype == "int32") { - if (out_dtype == "int32") { - topk(input, values_out, indices_out, k, axis, is_ascend); - } else if (out_dtype == "int64") { - topk(input, values_out, indices_out, k, axis, is_ascend); - } else if (out_dtype == "float32") { - topk(input, values_out, indices_out, k, axis, is_ascend); - } else if (out_dtype == "float64") { - topk(input, values_out, indices_out, k, axis, is_ascend); - } else { - LOG(FATAL) << "Unsupported output dtype: " << out_dtype; - } - } else if (data_dtype == "int64") { - if (out_dtype == "int32") { - topk(input, values_out, indices_out, k, axis, is_ascend); - } else if (out_dtype == "int64") { - topk(input, values_out, indices_out, k, axis, is_ascend); - } else if (out_dtype == "float32") { - topk(input, values_out, indices_out, k, axis, is_ascend); - } else if (out_dtype == "float64") { - topk(input, values_out, indices_out, k, axis, is_ascend); - } else { - LOG(FATAL) << "Unsupported output dtype: " << out_dtype; - } - } else if (data_dtype == "float16") { - if (out_dtype == "int32") { - topk(input, values_out, indices_out, k, axis, is_ascend); - } else if (out_dtype == "int64") { - topk(input, values_out, indices_out, k, axis, is_ascend); - } else if (out_dtype == "float32") { - topk(input, values_out, indices_out, k, axis, is_ascend); - } else if (out_dtype == "float64") { - topk(input, values_out, indices_out, k, axis, is_ascend); - } else { - LOG(FATAL) << "Unsupported output dtype: " << out_dtype; - } - } else { - LOG(FATAL) << "Unsupported input dtype: " << data_dtype; - } -}); +TVM_REGISTER_GLOBAL("tvm.contrib.sort.topk") + .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { + auto input = args[0].cast(); + DLTensor* values_out = nullptr; + DLTensor* indices_out = nullptr; + int k = args[args.size() - 4].cast(); + int axis = args[args.size() - 3].cast(); + std::string ret_type = args[args.size() - 2].cast(); + bool is_ascend = args[args.size() - 1].cast(); + if (ret_type == "both") { + values_out = args[1].cast(); + indices_out = args[2].cast(); + } else if (ret_type == "values") { + values_out = args[1].cast(); + } else if (ret_type == "indices") { + indices_out = args[1].cast(); + } else { + LOG(FATAL) << "Unsupported ret type: " << ret_type; + } + if (axis < 0) { + axis = input->ndim + axis; + } + ICHECK(axis >= 0 && axis < input->ndim) + << "Axis out of boundary for input ndim " << input->ndim; + + auto data_dtype = DLDataTypeToString(input->dtype); + auto out_dtype = (indices_out == nullptr) ? "int64" : DLDataTypeToString(indices_out->dtype); + + if (data_dtype == "float32") { + if (out_dtype == "int32") { + topk(input, values_out, indices_out, k, axis, is_ascend); + } else if (out_dtype == "int64") { + topk(input, values_out, indices_out, k, axis, is_ascend); + } else if (out_dtype == "float32") { + topk(input, values_out, indices_out, k, axis, is_ascend); + } else if (out_dtype == "float64") { + topk(input, values_out, indices_out, k, axis, is_ascend); + } else { + LOG(FATAL) << "Unsupported output dtype: " << out_dtype; + } + } else if (data_dtype == "float64") { + if (out_dtype == "int32") { + topk(input, values_out, indices_out, k, axis, is_ascend); + } else if (out_dtype == "int64") { + topk(input, values_out, indices_out, k, axis, is_ascend); + } else if (out_dtype == "float32") { + topk(input, values_out, indices_out, k, axis, is_ascend); + } else if (out_dtype == "float64") { + topk(input, values_out, indices_out, k, axis, is_ascend); + } else { + LOG(FATAL) << "Unsupported output dtype: " << out_dtype; + } + } else if (data_dtype == "uint8") { + if (out_dtype == "uint8") { + topk(input, values_out, indices_out, k, axis, is_ascend); + } else if (out_dtype == "int32") { + topk(input, values_out, indices_out, k, axis, is_ascend); + } else if (out_dtype == "int64") { + topk(input, values_out, indices_out, k, axis, is_ascend); + } else if (out_dtype == "float32") { + topk(input, values_out, indices_out, k, axis, is_ascend); + } else if (out_dtype == "float64") { + topk(input, values_out, indices_out, k, axis, is_ascend); + } else { + LOG(FATAL) << "Unsupported output dtype: " << out_dtype; + } + } else if (data_dtype == "int8") { + if (out_dtype == "int8") { + topk(input, values_out, indices_out, k, axis, is_ascend); + } else if (out_dtype == "int32") { + topk(input, values_out, indices_out, k, axis, is_ascend); + } else if (out_dtype == "int64") { + topk(input, values_out, indices_out, k, axis, is_ascend); + } else if (out_dtype == "float32") { + topk(input, values_out, indices_out, k, axis, is_ascend); + } else if (out_dtype == "float64") { + topk(input, values_out, indices_out, k, axis, is_ascend); + } else { + LOG(FATAL) << "Unsupported output dtype: " << out_dtype; + } + } else if (data_dtype == "int32") { + if (out_dtype == "int32") { + topk(input, values_out, indices_out, k, axis, is_ascend); + } else if (out_dtype == "int64") { + topk(input, values_out, indices_out, k, axis, is_ascend); + } else if (out_dtype == "float32") { + topk(input, values_out, indices_out, k, axis, is_ascend); + } else if (out_dtype == "float64") { + topk(input, values_out, indices_out, k, axis, is_ascend); + } else { + LOG(FATAL) << "Unsupported output dtype: " << out_dtype; + } + } else if (data_dtype == "int64") { + if (out_dtype == "int32") { + topk(input, values_out, indices_out, k, axis, is_ascend); + } else if (out_dtype == "int64") { + topk(input, values_out, indices_out, k, axis, is_ascend); + } else if (out_dtype == "float32") { + topk(input, values_out, indices_out, k, axis, is_ascend); + } else if (out_dtype == "float64") { + topk(input, values_out, indices_out, k, axis, is_ascend); + } else { + LOG(FATAL) << "Unsupported output dtype: " << out_dtype; + } + } else if (data_dtype == "float16") { + if (out_dtype == "int32") { + topk(input, values_out, indices_out, k, axis, is_ascend); + } else if (out_dtype == "int64") { + topk(input, values_out, indices_out, k, axis, is_ascend); + } else if (out_dtype == "float32") { + topk(input, values_out, indices_out, k, axis, is_ascend); + } else if (out_dtype == "float64") { + topk(input, values_out, indices_out, k, axis, is_ascend); + } else { + LOG(FATAL) << "Unsupported output dtype: " << out_dtype; + } + } else { + LOG(FATAL) << "Unsupported input dtype: " << data_dtype; + } + }); } // namespace contrib } // namespace tvm diff --git a/src/runtime/contrib/tflite/tflite_runtime.cc b/src/runtime/contrib/tflite/tflite_runtime.cc index 09669ac370f2..990475069574 100644 --- a/src/runtime/contrib/tflite/tflite_runtime.cc +++ b/src/runtime/contrib/tflite/tflite_runtime.cc @@ -150,28 +150,30 @@ NDArray TFLiteRuntime::GetOutput(int index) const { return ret; } -PackedFunc TFLiteRuntime::GetFunction(const String& name, const ObjectPtr& sptr_to_self) { +ffi::Function TFLiteRuntime::GetFunction(const String& name, + const ObjectPtr& sptr_to_self) { // Return member functions during query. if (name == "set_input") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + return ffi::Function([sptr_to_self, this](ffi::PackedArgs args, ffi::Any* rv) { int in_idx = args[0].cast(); ICHECK_GE(in_idx, 0); this->SetInput(in_idx, args[1].cast()); }); } else if (name == "get_output") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + return ffi::Function([sptr_to_self, this](ffi::PackedArgs args, ffi::Any* rv) { *rv = this->GetOutput(args[0].cast()); }); } else if (name == "invoke") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { this->Invoke(); }); + return ffi::Function( + [sptr_to_self, this](ffi::PackedArgs args, ffi::Any* rv) { this->Invoke(); }); } else if (name == "set_num_threads") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + return ffi::Function([sptr_to_self, this](ffi::PackedArgs args, ffi::Any* rv) { int num_threads = args[0].cast(); CHECK_GE(num_threads, 1); this->SetNumThreads(num_threads); }); } else { - return PackedFunc(); + return ffi::Function(); } } @@ -181,9 +183,10 @@ Module TFLiteRuntimeCreate(const std::string& tflite_model_bytes, Device dev) { return Module(exec); } -TVM_REGISTER_GLOBAL("tvm.tflite_runtime.create").set_body_packed([](TVMArgs args, TVMRetValue* rv) { - *rv = TFLiteRuntimeCreate(args[0].cast(), args[1].cast()); -}); +TVM_REGISTER_GLOBAL("tvm.tflite_runtime.create") + .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { + *rv = TFLiteRuntimeCreate(args[0].cast(), args[1].cast()); + }); TVM_REGISTER_GLOBAL("target.runtime.tflite").set_body_typed(TFLiteRuntimeCreate); } // namespace runtime diff --git a/src/runtime/contrib/tflite/tflite_runtime.h b/src/runtime/contrib/tflite/tflite_runtime.h index eeba3e0a0e79..6557fa07975e 100644 --- a/src/runtime/contrib/tflite/tflite_runtime.h +++ b/src/runtime/contrib/tflite/tflite_runtime.h @@ -19,7 +19,7 @@ /*! * \brief Tflite runtime that can run tflite model - * containing only tvm PackedFunc. + * containing only tvm ffi::Function. * \file tflite_runtime.h */ #ifndef TVM_RUNTIME_CONTRIB_TFLITE_TFLITE_RUNTIME_H_ @@ -43,7 +43,7 @@ namespace runtime { * \brief Tflite runtime. * * This runtime can be accessed in various language via - * TVM runtime PackedFunc API. + * TVM runtime ffi::Function API. */ class TFLiteRuntime : public ModuleNode { public: @@ -53,7 +53,7 @@ class TFLiteRuntime : public ModuleNode { * \param sptr_to_self The pointer to the module node. * \return The corresponding member function. */ - virtual PackedFunc GetFunction(const String& name, const ObjectPtr& sptr_to_self); + virtual ffi::Function GetFunction(const String& name, const ObjectPtr& sptr_to_self); /*! * \return The type key of the executor. diff --git a/src/runtime/contrib/thrust/thrust.cu b/src/runtime/contrib/thrust/thrust.cu index 22a6cf425d99..aa1befaeef32 100644 --- a/src/runtime/contrib/thrust/thrust.cu +++ b/src/runtime/contrib/thrust/thrust.cu @@ -232,7 +232,8 @@ void thrust_sort_common(DLTensor* input, DLTensor* values_out, DLTensor* indices } } -TVM_REGISTER_GLOBAL("tvm.contrib.thrust.sort").set_body_packed([](TVMArgs args, TVMRetValue* ret) { +TVM_REGISTER_GLOBAL("tvm.contrib.thrust.sort") +.set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { ICHECK_GE(args.num_args, 4); auto input = args[0].cast(); auto values_out = args[1].cast(); @@ -279,7 +280,7 @@ void thrust_stable_sort_by_key(DLTensor* keys_in, DLTensor* values_in, DLTensor* } TVM_REGISTER_GLOBAL("tvm.contrib.thrust.stable_sort_by_key") - .set_body_packed([](TVMArgs args, TVMRetValue* ret) { + .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { ICHECK_GE(args.num_args, 5); auto keys_in = args[0].cast(); auto values_in = args[1].cast(); @@ -394,7 +395,7 @@ void thrust_scan(DLTensor* data, DLTensor* output, bool exclusive, DLTensor* wor } TVM_REGISTER_GLOBAL("tvm.contrib.thrust.sum_scan") -.set_body_packed([](TVMArgs args, TVMRetValue* ret) { +.set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { ICHECK(args.num_args == 2 || args.num_args == 3 || args.num_args == 4); auto data = args[0].cast(); auto output = args[1].cast(); diff --git a/src/runtime/cpu_device_api.cc b/src/runtime/cpu_device_api.cc index 867f2622648b..df2271e64732 100644 --- a/src/runtime/cpu_device_api.cc +++ b/src/runtime/cpu_device_api.cc @@ -52,7 +52,7 @@ namespace runtime { class CPUDeviceAPI final : public DeviceAPI { public: void SetDevice(Device dev) final {} - void GetAttr(Device dev, DeviceAttrKind kind, TVMRetValue* rv) final { + void GetAttr(Device dev, DeviceAttrKind kind, ffi::Any* rv) final { if (kind == kExist) { *rv = 1; } @@ -150,7 +150,7 @@ void CPUDeviceAPI::FreeWorkspace(Device dev, void* data) { dmlc::ThreadLocalStore::Get()->FreeWorkspace(dev, data); } -TVM_REGISTER_GLOBAL("device_api.cpu").set_body_packed([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("device_api.cpu").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { DeviceAPI* ptr = CPUDeviceAPI::Global(); *rv = static_cast(ptr); }); diff --git a/src/runtime/cuda/cuda_device_api.cc b/src/runtime/cuda/cuda_device_api.cc index 79f54494ce8d..1dc928e77801 100644 --- a/src/runtime/cuda/cuda_device_api.cc +++ b/src/runtime/cuda/cuda_device_api.cc @@ -38,7 +38,7 @@ namespace runtime { class CUDADeviceAPI final : public DeviceAPI { public: void SetDevice(Device dev) final { CUDA_CALL(cudaSetDevice(dev.device_id)); } - void GetAttr(Device dev, DeviceAttrKind kind, TVMRetValue* rv) final { + void GetAttr(Device dev, DeviceAttrKind kind, ffi::Any* rv) final { int value = 0; switch (kind) { case kExist: { @@ -286,12 +286,12 @@ CUDAThreadEntry::CUDAThreadEntry() : pool(kDLCUDA, CUDADeviceAPI::Global()) {} CUDAThreadEntry* CUDAThreadEntry::ThreadLocal() { return CUDAThreadStore::Get(); } -TVM_REGISTER_GLOBAL("device_api.cuda").set_body_packed([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("device_api.cuda").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { DeviceAPI* ptr = CUDADeviceAPI::Global(); *rv = static_cast(ptr); }); -TVM_REGISTER_GLOBAL("device_api.cuda_host").set_body_packed([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("device_api.cuda_host").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { DeviceAPI* ptr = CUDADeviceAPI::Global(); *rv = static_cast(ptr); }); diff --git a/src/runtime/cuda/cuda_module.cc b/src/runtime/cuda/cuda_module.cc index f54aefe8c4eb..db01b76cb531 100644 --- a/src/runtime/cuda/cuda_module.cc +++ b/src/runtime/cuda/cuda_module.cc @@ -70,7 +70,7 @@ class CUDAModuleNode : public runtime::ModuleNode { return ModulePropertyMask::kBinarySerializable | ModulePropertyMask::kRunnable; } - PackedFunc GetFunction(const String& name, const ObjectPtr& sptr_to_self) final; + ffi::Function GetFunction(const String& name, const ObjectPtr& sptr_to_self) final; void SaveToFile(const String& file_name, const String& format) final { std::string fmt = GetFileFormat(file_name, format); @@ -166,7 +166,7 @@ class CUDAWrappedFunc { launch_param_config_.Init(num_void_args, launch_param_tags); } // invoke the function with void arguments - void operator()(TVMArgs args, TVMRetValue* rv, void** void_args) const { + void operator()(ffi::PackedArgs args, ffi::Any* rv, void** void_args) const { int device_id; CUDA_CALL(cudaGetDevice(&device_id)); ThreadWorkLoad wl = launch_param_config_.Extract(args); @@ -227,7 +227,7 @@ class CUDAPrepGlobalBarrier { std::fill(pcache_.begin(), pcache_.end(), 0); } - void operator()(const TVMArgs& args, TVMRetValue* rv) const { + void operator()(const ffi::PackedArgs& args, ffi::Any* rv) const { int device_id; CUDA_CALL(cudaGetDevice(&device_id)); if (pcache_[device_id] == 0) { @@ -246,14 +246,15 @@ class CUDAPrepGlobalBarrier { mutable std::array pcache_; }; -PackedFunc CUDAModuleNode::GetFunction(const String& name, const ObjectPtr& sptr_to_self) { +ffi::Function CUDAModuleNode::GetFunction(const String& name, + const ObjectPtr& sptr_to_self) { ICHECK_EQ(sptr_to_self.get(), this); ICHECK_NE(name, symbol::tvm_module_main) << "Device function do not have main"; if (name == symbol::tvm_prepare_global_barrier) { - return PackedFunc(CUDAPrepGlobalBarrier(this, sptr_to_self)); + return ffi::Function(CUDAPrepGlobalBarrier(this, sptr_to_self)); } auto it = fmap_.find(name); - if (it == fmap_.end()) return PackedFunc(); + if (it == fmap_.end()) return ffi::Function(); const FunctionInfo& info = it->second; CUDAWrappedFunc f; f.Init(this, sptr_to_self, name, info.arg_types.size(), info.launch_param_tags); diff --git a/src/runtime/cuda/l2_cache_flush.cc b/src/runtime/cuda/l2_cache_flush.cc index 23eddac2be38..ae7c057be0cc 100644 --- a/src/runtime/cuda/l2_cache_flush.cc +++ b/src/runtime/cuda/l2_cache_flush.cc @@ -32,7 +32,7 @@ typedef dmlc::ThreadLocalStore L2FlushStore; L2Flush* L2Flush::ThreadLocal() { return L2FlushStore::Get(); } -TVM_REGISTER_GLOBAL("l2_cache_flush_cuda").set_body_packed([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("l2_cache_flush_cuda").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { ICHECK(L2Flush::ThreadLocal() != nullptr) << "L2Flush::ThreadLocal do not exist."; cudaStream_t stream = CUDAThreadEntry::ThreadLocal()->stream; L2Flush::ThreadLocal()->Flush(stream); diff --git a/src/runtime/disco/bcast_session.cc b/src/runtime/disco/bcast_session.cc index 2346f26bd0c4..81f10190e00b 100644 --- a/src/runtime/disco/bcast_session.cc +++ b/src/runtime/disco/bcast_session.cc @@ -77,7 +77,7 @@ void BcastSessionObj::InitCCL(String ccl, IntTuple device_ids) { void BcastSessionObj::SyncWorker(int worker_id) { BcastSessionObj::Internal::BroadcastUnpacked(this, DiscoAction::kSyncWorker, worker_id); - TVMArgs args = this->RecvReplyPacked(worker_id); + ffi::PackedArgs args = this->RecvReplyPacked(worker_id); ICHECK_EQ(args.size(), 2); DiscoAction action = static_cast(args[0].cast()); int ret_worker_id = args[1].cast(); @@ -85,7 +85,7 @@ void BcastSessionObj::SyncWorker(int worker_id) { ICHECK_EQ(ret_worker_id, worker_id); } -DRef BcastSessionObj::CallWithPacked(const TVMArgs& args) { +DRef BcastSessionObj::CallWithPacked(const ffi::PackedArgs& args) { // NOTE: this action is not safe unless we know args is not // used else where in this case it is oK AnyView* args_vec = const_cast(args.data()); diff --git a/src/runtime/disco/bcast_session.h b/src/runtime/disco/bcast_session.h index e71fcd77cfcb..bfb1ca24b565 100644 --- a/src/runtime/disco/bcast_session.h +++ b/src/runtime/disco/bcast_session.h @@ -30,7 +30,7 @@ namespace runtime { /*! * \brief A Disco interactive session. It allows users to interact with the Disco command queue with - * various PackedFunc calling convention. + * various ffi::Function calling convention. */ class BcastSessionObj : public SessionObj { public: @@ -42,14 +42,14 @@ class BcastSessionObj : public SessionObj { void SyncWorker(int worker_id) override; void Shutdown() override; void InitCCL(String ccl, IntTuple device_ids) override; - TVMRetValue DebugGetFromRemote(int64_t reg_id, int worker_id) override = 0; + ffi::Any DebugGetFromRemote(int64_t reg_id, int worker_id) override = 0; void DebugSetRegister(int64_t reg_id, AnyView value, int worker_id) override = 0; protected: /*! \brief Deallocate a register id, kill it on all workers, and append it to `free_regs_`. */ void DeallocReg(int reg_id) override; /*! \brief Call packed function on each worker using a packed sequence */ - DRef CallWithPacked(const TVMArgs& args) override; + DRef CallWithPacked(const ffi::PackedArgs& args) override; /*! \brief Allocate a register id, either from `free_regs_` or by incrementing `reg_count_` */ virtual int AllocateReg(); /*! @@ -59,12 +59,12 @@ class BcastSessionObj : public SessionObj { */ virtual void AppendHostNDArray(const NDArray& host_array); /*! - * \brief Broadcast a command to all workers via TVM's PackedFunc calling convention. + * \brief Broadcast a command to all workers via TVM's ffi::Function calling convention. * As part of the calling convention, The first argument in the packed sequence must be * the action, and the second argument must be the register id. - * \param TVMArgs The input arguments in TVM's PackedFunc calling convention + * \param ffi::PackedArgs The input arguments in TVM's ffi::Function calling convention */ - virtual void BroadcastPacked(const TVMArgs& args) = 0; + virtual void BroadcastPacked(const ffi::PackedArgs& args) = 0; /*! * \brief Send a packed sequence to a worker. This function is usually called by the controler to @@ -73,7 +73,7 @@ class BcastSessionObj : public SessionObj { * \param worker_id The worker id to send the packed sequence to. * \param args The packed sequence to send. */ - virtual void SendPacked(int worker_id, const TVMArgs& args) = 0; + virtual void SendPacked(int worker_id, const ffi::PackedArgs& args) = 0; /*! * \brief Receive a packed sequence from a worker. This function is usually called by the @@ -82,7 +82,7 @@ class BcastSessionObj : public SessionObj { * with the controler. Receiving from other workers may not be supported. * \return The packed sequence received. */ - virtual TVMArgs RecvReplyPacked(int worker_id) = 0; + virtual ffi::PackedArgs RecvReplyPacked(int worker_id) = 0; /*! \brief A side channel to communicate with worker-0 */ WorkerZeroData worker_zero_data_; diff --git a/src/runtime/disco/builtin.cc b/src/runtime/disco/builtin.cc index 889a152f6108..a58c840ea325 100644 --- a/src/runtime/disco/builtin.cc +++ b/src/runtime/disco/builtin.cc @@ -51,12 +51,12 @@ Module LoadVMModule(std::string path, Device device) { static DSOLibraryCache cache; Module dso_mod = cache.Open(path); device = UseDefaultDeviceIfNone(device); - PackedFunc vm_load_executable = dso_mod.GetFunction("vm_load_executable"); + ffi::Function vm_load_executable = dso_mod.GetFunction("vm_load_executable"); CHECK(vm_load_executable != nullptr) << "ValueError: File `" << path << "` is not built by RelaxVM, because `vm_load_executable` does not exist"; auto mod = vm_load_executable().cast(); - PackedFunc vm_initialization = mod.GetFunction("vm_initialization"); + ffi::Function vm_initialization = mod.GetFunction("vm_initialization"); CHECK(vm_initialization != nullptr) << "ValueError: File `" << path << "` is not built by RelaxVM, because `vm_initialization` does not exist"; @@ -70,7 +70,7 @@ NDArray DiscoEmptyNDArray(ShapeTuple shape, DataType dtype, Device device) { return NDArray::Empty(shape, dtype, UseDefaultDeviceIfNone(device)); } -PackedFunc GetCCLFunc(const char* name) { +ffi::Function GetCCLFunc(const char* name) { std::string ccl = DiscoWorker::ThreadLocal()->ccl; std::string pf_name = "runtime.disco." + ccl + "." + name; const auto pf = tvm::ffi::Function::GetGlobal(pf_name); diff --git a/src/runtime/disco/disco_worker.cc b/src/runtime/disco/disco_worker.cc index 486b9d2ae12d..7f98feacd83b 100644 --- a/src/runtime/disco/disco_worker.cc +++ b/src/runtime/disco/disco_worker.cc @@ -36,7 +36,7 @@ TVM_DLL DiscoWorker* DiscoWorker::ThreadLocal() { void DiscoWorker::SetRegister(int reg_id, AnyView value) { ICHECK(0 <= reg_id && reg_id < static_cast(register_file.size())); - TVMRetValue& rv = register_file.at(reg_id); + ffi::Any& rv = register_file.at(reg_id); if (rv.type_index() == ffi::TypeIndex::kTVMFFINDArray && value.type_index() == ffi::TypeIndex::kTVMFFINDArray) { NDArray dst = rv.cast(); @@ -52,7 +52,7 @@ struct DiscoWorker::Impl { ThreadLocalDiscoWorker::Get()->worker = self; using namespace tvm; while (true) { - TVMArgs args = self->channel->Recv(); + ffi::PackedArgs args = self->channel->Recv(); DiscoAction action = static_cast(args[0].cast()); int64_t reg_id = args[1].cast(); switch (action) { @@ -71,7 +71,7 @@ struct DiscoWorker::Impl { case DiscoAction::kCallPacked: { int func_reg_id = args[2].cast(); CHECK_LT(func_reg_id, self->register_file.size()); - PackedFunc func = GetReg(self, func_reg_id).cast(); + ffi::Function func = GetReg(self, func_reg_id).cast(); CHECK(func.defined()); CallPacked(self, reg_id, func, args.Slice(3)); break; @@ -147,7 +147,7 @@ struct DiscoWorker::Impl { static void DebugGetFromRemote(DiscoWorker* self, int reg_id, int worker_id) { if (worker_id == self->worker_id) { - TVMRetValue rv = GetReg(self, reg_id); + ffi::Any rv = GetReg(self, reg_id); if (rv.as()) { rv = DiscoDebugObject::Wrap(rv); } @@ -167,8 +167,8 @@ struct DiscoWorker::Impl { } } - static void CallPacked(DiscoWorker* self, int64_t ret_reg_id, PackedFunc func, - const TVMArgs& args) { + static void CallPacked(DiscoWorker* self, int64_t ret_reg_id, ffi::Function func, + const ffi::PackedArgs& args) { // NOTE: this action is not safe unless we know args is not // used else where in this case it is oK AnyView* args_vec = const_cast(args.data()); @@ -179,12 +179,12 @@ struct DiscoWorker::Impl { args_vec[i] = GetReg(self, dref->reg_id); } } - TVMRetValue rv; + ffi::Any rv; func.CallPacked(ffi::PackedArgs(args_vec, args.size()), &rv); GetReg(self, ret_reg_id) = std::move(rv); } - static TVMRetValue& GetReg(DiscoWorker* self, int64_t reg_id) { + static ffi::Any& GetReg(DiscoWorker* self, int64_t reg_id) { if (reg_id >= static_cast(self->register_file.size())) { self->register_file.resize(reg_id + 1); } diff --git a/src/runtime/disco/distributed/socket_session.cc b/src/runtime/disco/distributed/socket_session.cc index 4a0f0380946d..9c25d4abb68e 100644 --- a/src/runtime/disco/distributed/socket_session.cc +++ b/src/runtime/disco/distributed/socket_session.cc @@ -42,10 +42,10 @@ class DiscoSocketChannel : public DiscoChannel { DiscoSocketChannel(DiscoSocketChannel&& other) = delete; DiscoSocketChannel(const DiscoSocketChannel& other) = delete; - void Send(const TVMArgs& args) { message_queue_.Send(args); } - TVMArgs Recv() { return message_queue_.Recv(); } - void Reply(const TVMArgs& args) { message_queue_.Send(args); } - TVMArgs RecvReply() { return message_queue_.Recv(); } + void Send(const ffi::PackedArgs& args) { message_queue_.Send(args); } + ffi::PackedArgs Recv() { return message_queue_.Recv(); } + void Reply(const ffi::PackedArgs& args) { message_queue_.Send(args); } + ffi::PackedArgs RecvReply() { return message_queue_.Recv(); } private: TCPSocket socket_; @@ -96,7 +96,7 @@ class SocketSessionObj : public BcastSessionObj { int64_t GetNumWorkers() final { return num_nodes_ * num_workers_per_node_; } - TVMRetValue DebugGetFromRemote(int64_t reg_id, int worker_id) final { + ffi::Any DebugGetFromRemote(int64_t reg_id, int worker_id) final { int node_id = worker_id / num_workers_per_node_; if (node_id == 0) { return local_session_->DebugGetFromRemote(reg_id, worker_id); @@ -105,10 +105,10 @@ class SocketSessionObj : public BcastSessionObj { ffi::PackedArgs::Fill(packed_args, static_cast(DiscoSocketAction::kSend), worker_id, static_cast(DiscoAction::kDebugGetFromRemote), reg_id, worker_id); remote_channels_[node_id - 1]->Send(ffi::PackedArgs(packed_args, 5)); - TVMArgs args = this->RecvReplyPacked(worker_id); + ffi::PackedArgs args = this->RecvReplyPacked(worker_id); ICHECK_EQ(args.size(), 2); ICHECK(static_cast(args[0].cast()) == DiscoAction::kDebugGetFromRemote); - TVMRetValue result; + ffi::Any result; result = args[1]; return result; } @@ -131,14 +131,14 @@ class SocketSessionObj : public BcastSessionObj { value); remote_channels_[node_id - 1]->Send(ffi::PackedArgs(packed_args, 6)); } - TVMRetValue result; - TVMArgs args = this->RecvReplyPacked(worker_id); + ffi::Any result; + ffi::PackedArgs args = this->RecvReplyPacked(worker_id); ICHECK_EQ(args.size(), 1); ICHECK(static_cast(args[0].cast()) == DiscoAction::kDebugSetRegister); } } - void BroadcastPacked(const TVMArgs& args) final { + void BroadcastPacked(const ffi::PackedArgs& args) final { local_session_->BroadcastPacked(args); std::vector packed_args(args.size() + 2); ffi::PackedArgs::Fill(packed_args.data(), static_cast(DiscoSocketAction::kSend), -1); @@ -148,7 +148,7 @@ class SocketSessionObj : public BcastSessionObj { } } - void SendPacked(int worker_id, const TVMArgs& args) final { + void SendPacked(int worker_id, const ffi::PackedArgs& args) final { int node_id = worker_id / num_workers_per_node_; if (node_id == 0) { local_session_->SendPacked(worker_id, args); @@ -161,7 +161,7 @@ class SocketSessionObj : public BcastSessionObj { remote_channels_[node_id - 1]->Send(ffi::PackedArgs(packed_args.data(), packed_args.size())); } - TVMArgs RecvReplyPacked(int worker_id) final { + ffi::PackedArgs RecvReplyPacked(int worker_id) final { int node_id = worker_id / num_workers_per_node_; if (node_id == 0) { return local_session_->RecvReplyPacked(worker_id); @@ -220,7 +220,7 @@ class RemoteSocketSession { << ", errno = " << Socket::GetLastErrorCode(); } channel_ = std::make_unique(socket_); - TVMArgs metadata = channel_->Recv(); + ffi::PackedArgs metadata = channel_->Recv(); ICHECK_EQ(metadata.size(), 4); num_nodes_ = metadata[0].cast(); num_workers_per_node_ = metadata[1].cast(); @@ -232,7 +232,7 @@ class RemoteSocketSession { void MainLoop() { while (true) { - TVMArgs args = channel_->Recv(); + ffi::PackedArgs args = channel_->Recv(); DiscoSocketAction action = static_cast(args[0].cast()); int worker_id = args[1].cast(); int local_worker_id = worker_id - node_id_ * num_workers_per_node_; diff --git a/src/runtime/disco/loader.cc b/src/runtime/disco/loader.cc index 19d2b80f40e4..7a79f0c392ef 100644 --- a/src/runtime/disco/loader.cc +++ b/src/runtime/disco/loader.cc @@ -147,8 +147,8 @@ class ShardLoaderObj : public Object { const ParamRecord* param; ShardInfo shard_info; }; - /*! \brief The PackedFuncs being used during sharding */ - std::unordered_map shard_funcs_; + /*! \brief The ffi::Functions being used during sharding */ + std::unordered_map shard_funcs_; /*! \brief The metadata loaded from `ndarray-cache.json` */ NDArrayCacheMetadata metadata_; /*! \brief Sharding information for each weight */ @@ -179,7 +179,8 @@ TVM_REGISTER_OBJECT_TYPE(ShardLoaderObj); ObjectRef ShardLoaderObj::Create(const std::string& path_to_metadata, const std::string& metadata, std::string shard_info, Module mod) { if (shard_info.empty() && mod.defined()) { - if (PackedFunc get_shard_info = mod->GetFunction("get_shard_info"); get_shard_info != nullptr) { + if (ffi::Function get_shard_info = mod->GetFunction("get_shard_info"); + get_shard_info != nullptr) { shard_info = get_shard_info().cast(); } } @@ -196,7 +197,8 @@ ObjectRef ShardLoaderObj::Create(const std::string& path_to_metadata, const std: ShardInfo& shard_info = shards[name]; for (const ShardInfo::ShardFunc& shard_func : shard_info.funcs) { const std::string& name = shard_func.name; - if (PackedFunc f = mod.defined() ? mod->GetFunction(name, true) : nullptr; f != nullptr) { + if (ffi::Function f = mod.defined() ? mod->GetFunction(name, true) : nullptr; + f != nullptr) { n->shard_funcs_[name] = f; } else if (const auto f = tvm::ffi::Function::GetGlobal(name)) { n->shard_funcs_[name] = *f; @@ -214,7 +216,7 @@ NDArray ShardLoaderObj::ApplyShardFunc(const ShardInfo::ShardFunc& shard_func, const NDArray& param) const { Device device = param->device; NDArray o = NDArray::Empty(shard_func.output_info.shape, shard_func.output_info.dtype, device); - PackedFunc f = this->shard_funcs_.at(shard_func.name); + ffi::Function f = this->shard_funcs_.at(shard_func.name); int n = static_cast(shard_func.params.size()); std::vector packed_args(n + 2); const DLTensor* w_in = param.operator->(); diff --git a/src/runtime/disco/message_queue.h b/src/runtime/disco/message_queue.h index efd3b1735bf3..6b3600acbb97 100644 --- a/src/runtime/disco/message_queue.h +++ b/src/runtime/disco/message_queue.h @@ -36,7 +36,7 @@ class DiscoStreamMessageQueue : private dmlc::Stream, ~DiscoStreamMessageQueue() = default; - void Send(const TVMArgs& args) { + void Send(const ffi::PackedArgs& args) { // Run legacy ABI translation. std::vector values(args.size()); std::vector type_codes(args.size()); @@ -46,7 +46,7 @@ class DiscoStreamMessageQueue : private dmlc::Stream, CommitSendAndNotifyEnqueue(); } - TVMArgs Recv() { + ffi::PackedArgs Recv() { bool is_implicit_shutdown = DequeueNextPacket(); AnyView* packed_args = nullptr; int num_args = 0; diff --git a/src/runtime/disco/process_session.cc b/src/runtime/disco/process_session.cc index e0eb0215e254..4265bd21c43d 100644 --- a/src/runtime/disco/process_session.cc +++ b/src/runtime/disco/process_session.cc @@ -48,10 +48,10 @@ class DiscoProcessChannel final : public DiscoChannel { DiscoProcessChannel(DiscoProcessChannel&& other) = delete; DiscoProcessChannel(const DiscoProcessChannel& other) = delete; - void Send(const TVMArgs& args) { controler_to_worker_.Send(args); } - TVMArgs Recv() { return controler_to_worker_.Recv(); } - void Reply(const TVMArgs& args) { worker_to_controler_.Send(args); } - TVMArgs RecvReply() { return worker_to_controler_.Recv(); } + void Send(const ffi::PackedArgs& args) { controler_to_worker_.Send(args); } + ffi::PackedArgs Recv() { return controler_to_worker_.Recv(); } + void Reply(const ffi::PackedArgs& args) { worker_to_controler_.Send(args); } + ffi::PackedArgs RecvReply() { return worker_to_controler_.Recv(); } support::Pipe controller_to_worker_pipe_; support::Pipe worker_to_controller_pipe_; @@ -61,7 +61,7 @@ class DiscoProcessChannel final : public DiscoChannel { class ProcessSessionObj final : public BcastSessionObj { public: - explicit ProcessSessionObj(int num_workers, int num_groups, PackedFunc process_pool) + explicit ProcessSessionObj(int num_workers, int num_groups, ffi::Function process_pool) : process_pool_(process_pool), worker_0_( std::make_unique(0, num_workers, num_groups, &worker_zero_data_)) { @@ -94,7 +94,7 @@ class ProcessSessionObj final : public BcastSessionObj { int64_t GetNumWorkers() { return workers_.size() + 1; } - TVMRetValue DebugGetFromRemote(int64_t reg_id, int worker_id) { + ffi::Any DebugGetFromRemote(int64_t reg_id, int worker_id) { if (worker_id == 0) { this->SyncWorker(worker_id); return worker_0_->worker->register_file.at(reg_id); @@ -105,10 +105,10 @@ class ProcessSessionObj final : public BcastSessionObj { worker_id); workers_[worker_id - 1]->Send(ffi::PackedArgs(packed_args, 3)); } - TVMArgs args = this->RecvReplyPacked(worker_id); + ffi::PackedArgs args = this->RecvReplyPacked(worker_id); ICHECK_EQ(args.size(), 2); ICHECK(static_cast(args[0].cast()) == DiscoAction::kDebugGetFromRemote); - TVMRetValue result; + ffi::Any result; result = args[1]; return result; } @@ -130,20 +130,20 @@ class ProcessSessionObj final : public BcastSessionObj { worker_id, value); SendPacked(worker_id, ffi::PackedArgs(packed_args, 4)); } - TVMRetValue result; - TVMArgs args = this->RecvReplyPacked(worker_id); + ffi::Any result; + ffi::PackedArgs args = this->RecvReplyPacked(worker_id); ICHECK_EQ(args.size(), 1); ICHECK(static_cast(args[0].cast()) == DiscoAction::kDebugSetRegister); } - void BroadcastPacked(const TVMArgs& args) final { + void BroadcastPacked(const ffi::PackedArgs& args) final { worker_0_->channel->Send(args); for (std::unique_ptr& channel : workers_) { channel->Send(args); } } - void SendPacked(int worker_id, const TVMArgs& args) final { + void SendPacked(int worker_id, const ffi::PackedArgs& args) final { if (worker_id == 0) { worker_0_->channel->Send(args); } else { @@ -151,7 +151,7 @@ class ProcessSessionObj final : public BcastSessionObj { } } - TVMArgs RecvReplyPacked(int worker_id) final { + ffi::PackedArgs RecvReplyPacked(int worker_id) final { if (worker_id == 0) { return worker_0_->channel->RecvReply(); } @@ -165,7 +165,7 @@ class ProcessSessionObj final : public BcastSessionObj { return workers_.at(worker_id - 1).get(); } - PackedFunc process_pool_; + ffi::Function process_pool_; std::unique_ptr worker_0_; std::vector> workers_; @@ -183,7 +183,7 @@ Session Session::ProcessSession(int num_workers, int num_group, String process_p const auto pf = tvm::ffi::Function::GetGlobal(process_pool_creator); CHECK(pf) << "ValueError: Cannot find function " << process_pool_creator << " in the registry. Please check if it is registered."; - auto process_pool = (*pf)(num_workers, num_group, entrypoint).cast(); + auto process_pool = (*pf)(num_workers, num_group, entrypoint).cast(); auto n = make_object(num_workers, num_group, process_pool); return Session(n); } diff --git a/src/runtime/disco/protocol.h b/src/runtime/disco/protocol.h index 6023deaf9796..13fda2fecde4 100644 --- a/src/runtime/disco/protocol.h +++ b/src/runtime/disco/protocol.h @@ -93,18 +93,18 @@ struct DiscoProtocol { struct DiscoDebugObject : public Object { public: /*! \brief The data to be serialized */ - TVMRetValue data; + ffi::Any data; /*! \brief Wrap an NDArray or reflection-capable TVM object into the debug extension. */ - static ObjectRef Wrap(const TVMRetValue& data) { + static ObjectRef Wrap(const ffi::Any& data) { ObjectPtr n = make_object(); n->data = data; return ObjectRef(n); } /*! \brief Wrap an NDArray or reflection-capable TVM object into the debug extension. */ - static ObjectRef Wrap(const TVMArgValue& data) { - TVMRetValue rv; + static ObjectRef Wrap(const ffi::AnyView& data) { + ffi::Any rv; rv = data; return Wrap(std::move(rv)); } diff --git a/src/runtime/disco/session.cc b/src/runtime/disco/session.cc index ae94f24fbe7c..467888c65768 100644 --- a/src/runtime/disco/session.cc +++ b/src/runtime/disco/session.cc @@ -25,7 +25,7 @@ namespace tvm { namespace runtime { struct SessionObj::FFI { - static DRef CallWithPacked(Session sess, const TVMArgs& args) { + static DRef CallWithPacked(Session sess, const ffi::PackedArgs& args) { return sess->CallWithPacked(args); } }; @@ -48,7 +48,7 @@ TVM_REGISTER_GLOBAL("runtime.disco.SessionSyncWorker").set_body_method(&SessionO TVM_REGISTER_GLOBAL("runtime.disco.SessionInitCCL") // .set_body_method(&SessionObj::InitCCL); TVM_REGISTER_GLOBAL("runtime.disco.SessionCallPacked") - .set_body_packed([](TVMArgs args, TVMRetValue* rv) { + .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { Session self = args[0].cast(); *rv = SessionObj::FFI::CallWithPacked(self, args.Slice(1)); }); diff --git a/src/runtime/disco/threaded_session.cc b/src/runtime/disco/threaded_session.cc index b9239daf9596..f40fae007e50 100644 --- a/src/runtime/disco/threaded_session.cc +++ b/src/runtime/disco/threaded_session.cc @@ -40,7 +40,7 @@ namespace runtime { class DiscoThreadedMessageQueue : private dmlc::Stream, private DiscoProtocol { public: - void Send(const TVMArgs& args) { + void Send(const ffi::PackedArgs& args) { // Run legacy ABI translation. std::vector values(args.size()); std::vector type_codes(args.size()); @@ -50,7 +50,7 @@ class DiscoThreadedMessageQueue : private dmlc::Stream, CommitSendAndNotifyEnqueue(); } - TVMArgs Recv() { + ffi::PackedArgs Recv() { DequeueNextPacket(); TVMValue* values = nullptr; int* type_codes = nullptr; @@ -132,10 +132,10 @@ class DiscoThreadedMessageQueue : private dmlc::Stream, class DiscoThreadChannel final : public DiscoChannel { public: - void Send(const TVMArgs& args) { controler_to_worker_.Send(args); } - TVMArgs Recv() { return controler_to_worker_.Recv(); } - void Reply(const TVMArgs& args) { worker_to_controler_.Send(args); } - TVMArgs RecvReply() { return worker_to_controler_.Recv(); } + void Send(const ffi::PackedArgs& args) { controler_to_worker_.Send(args); } + ffi::PackedArgs Recv() { return controler_to_worker_.Recv(); } + void Reply(const ffi::PackedArgs& args) { worker_to_controler_.Send(args); } + ffi::PackedArgs RecvReply() { return worker_to_controler_.Recv(); } DiscoThreadedMessageQueue controler_to_worker_; DiscoThreadedMessageQueue worker_to_controler_; @@ -165,7 +165,7 @@ class ThreadedSessionObj final : public BcastSessionObj { int64_t GetNumWorkers() { return workers_.size(); } - TVMRetValue DebugGetFromRemote(int64_t reg_id, int worker_id) { + ffi::Any DebugGetFromRemote(int64_t reg_id, int worker_id) { this->SyncWorker(worker_id); return this->workers_.at(worker_id).worker->register_file.at(reg_id); } @@ -175,17 +175,17 @@ class ThreadedSessionObj final : public BcastSessionObj { this->workers_.at(worker_id).worker->SetRegister(reg_id, value); } - void BroadcastPacked(const TVMArgs& args) final { + void BroadcastPacked(const ffi::PackedArgs& args) final { for (const DiscoWorkerThread& worker : this->workers_) { worker.channel->Send(args); } } - void SendPacked(int worker_id, const TVMArgs& args) final { + void SendPacked(int worker_id, const ffi::PackedArgs& args) final { this->workers_.at(worker_id).channel->Send(args); } - TVMArgs RecvReplyPacked(int worker_id) final { + ffi::PackedArgs RecvReplyPacked(int worker_id) final { return this->workers_.at(worker_id).channel->RecvReply(); } diff --git a/src/runtime/hexagon/hexagon_common.cc b/src/runtime/hexagon/hexagon_common.cc index 69be065210c2..c959e39e1d39 100644 --- a/src/runtime/hexagon/hexagon_common.cc +++ b/src/runtime/hexagon/hexagon_common.cc @@ -90,7 +90,7 @@ void LogMessageImpl(const std::string& file, int lineno, int level, const std::s } // namespace detail TVM_REGISTER_GLOBAL("runtime.module.loadfile_hexagon") - .set_body_packed([](TVMArgs args, TVMRetValue* rv) { + .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { ObjectPtr n = CreateDSOLibraryObject(args[0].cast()); *rv = CreateModuleFromLibrary(n); }); diff --git a/src/runtime/hexagon/hexagon_device_api.cc b/src/runtime/hexagon/hexagon_device_api.cc index 1ebe83b5a4a9..40294324018b 100644 --- a/src/runtime/hexagon/hexagon_device_api.cc +++ b/src/runtime/hexagon/hexagon_device_api.cc @@ -43,7 +43,7 @@ HexagonDeviceAPI* HexagonDeviceAPI::Global() { return inst; } -void HexagonDeviceAPI::GetAttr(Device dev, DeviceAttrKind kind, TVMRetValue* rv) { +void HexagonDeviceAPI::GetAttr(Device dev, DeviceAttrKind kind, ffi::Any* rv) { if (kind == kExist) { *rv = 1; } @@ -191,7 +191,7 @@ void HexagonDeviceAPI::CopyDataFromTo(const void* from, size_t from_offset, void } TVM_REGISTER_GLOBAL("device_api.hexagon.dma_copy_dltensor") - .set_body_packed([](TVMArgs args, TVMRetValue* rv) { + .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { auto dst = args[0].cast(); auto src = args[1].cast(); int size = args[2].cast(); @@ -210,7 +210,7 @@ TVM_REGISTER_GLOBAL("device_api.hexagon.dma_copy_dltensor") }); TVM_REGISTER_GLOBAL("device_api.hexagon.dma_copy") - .set_body_packed([](TVMArgs args, TVMRetValue* rv) { + .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { uint32_t queue_id = args[0].cast(); void* dst = args[1].cast(); void* src = args[2].cast(); @@ -227,7 +227,7 @@ TVM_REGISTER_GLOBAL("device_api.hexagon.dma_copy") }); TVM_REGISTER_GLOBAL("device_api.hexagon.dma_wait") - .set_body_packed([](TVMArgs args, TVMRetValue* rv) { + .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { uint32_t queue_id = args[0].cast(); int inflight = args[1].cast(); ICHECK(inflight >= 0); @@ -236,21 +236,21 @@ TVM_REGISTER_GLOBAL("device_api.hexagon.dma_wait") }); TVM_REGISTER_GLOBAL("device_api.hexagon.dma_start_group") - .set_body_packed([](TVMArgs args, TVMRetValue* rv) { + .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { uint32_t queue_id = args[0].cast(); HexagonDeviceAPI::Global()->UserDMA()->StartGroup(queue_id); *rv = static_cast(0); }); TVM_REGISTER_GLOBAL("device_api.hexagon.dma_end_group") - .set_body_packed([](TVMArgs args, TVMRetValue* rv) { + .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { uint32_t queue_id = args[0].cast(); HexagonDeviceAPI::Global()->UserDMA()->EndGroup(queue_id); *rv = static_cast(0); }); TVM_REGISTER_GLOBAL("device_api.hexagon.alloc_nd") - .set_body_packed([](TVMArgs args, TVMRetValue* rv) { + .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { int32_t device_type = args[0].cast(); int32_t device_id = args[1].cast(); int32_t dtype_code_hint = args[2].cast(); @@ -275,7 +275,7 @@ TVM_REGISTER_GLOBAL("device_api.hexagon.alloc_nd") }); TVM_REGISTER_GLOBAL("device_api.hexagon.free_nd") - .set_body_packed([](TVMArgs args, TVMRetValue* rv) { + .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { int32_t device_type = args[0].cast(); int32_t device_id = args[1].cast(); auto scope = args[2].cast(); @@ -292,24 +292,24 @@ TVM_REGISTER_GLOBAL("device_api.hexagon.free_nd") }); TVM_REGISTER_GLOBAL("device_api.hexagon.acquire_resources") - .set_body_packed([](TVMArgs args, TVMRetValue* rv) { + .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { HexagonDeviceAPI* api = HexagonDeviceAPI::Global(); api->AcquireResources(); }); TVM_REGISTER_GLOBAL("device_api.hexagon.release_resources") - .set_body_packed([](TVMArgs args, TVMRetValue* rv) { + .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { HexagonDeviceAPI* api = HexagonDeviceAPI::Global(); api->ReleaseResources(); }); TVM_REGISTER_GLOBAL("device_api.hexagon.vtcm_device_bytes") - .set_body_packed([](TVMArgs args, TVMRetValue* rv) { + .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { HexagonDeviceAPI* api = HexagonDeviceAPI::Global(); *rv = static_cast(api->VtcmPool()->VtcmDeviceBytes()); }); -TVM_REGISTER_GLOBAL("device_api.hexagon").set_body_packed([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("device_api.hexagon").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { DeviceAPI* ptr = HexagonDeviceAPI::Global(); *rv = static_cast(ptr); }); diff --git a/src/runtime/hexagon/hexagon_device_api.h b/src/runtime/hexagon/hexagon_device_api.h index c4e87a957ade..e77e681dd434 100644 --- a/src/runtime/hexagon/hexagon_device_api.h +++ b/src/runtime/hexagon/hexagon_device_api.h @@ -97,7 +97,7 @@ class HexagonDeviceAPI final : public DeviceAPI { void SetDevice(Device dev) final{}; //! \brief Return the queried Hexagon device attribute. - void GetAttr(Device dev, DeviceAttrKind kind, TVMRetValue* rv) final; + void GetAttr(Device dev, DeviceAttrKind kind, ffi::Any* rv) final; //! \brief Currently unimplemented interface to synchronize a device stream. void StreamSync(Device dev, TVMStreamHandle stream) final {} diff --git a/src/runtime/hexagon/hexagon_module.cc b/src/runtime/hexagon/hexagon_module.cc index 26297247fcbc..6ed2a2757f68 100644 --- a/src/runtime/hexagon/hexagon_module.cc +++ b/src/runtime/hexagon/hexagon_module.cc @@ -42,8 +42,8 @@ HexagonModuleNode::HexagonModuleNode(std::string data, std::string fmt, std::string bc_str) : data_(data), fmt_(fmt), fmap_(fmap), asm_(asm_str), obj_(obj_str), ir_(ir_str), bc_(bc_str) {} -PackedFunc HexagonModuleNode::GetFunction(const String& name, - const ObjectPtr& sptr_to_self) { +ffi::Function HexagonModuleNode::GetFunction(const String& name, + const ObjectPtr& sptr_to_self) { LOG(FATAL) << "HexagonModuleNode::GetFunction is not implemented."; } diff --git a/src/runtime/hexagon/hexagon_module.h b/src/runtime/hexagon/hexagon_module.h index 0abe175e907c..b8a830bc7c29 100644 --- a/src/runtime/hexagon/hexagon_module.h +++ b/src/runtime/hexagon/hexagon_module.h @@ -59,7 +59,7 @@ class HexagonModuleNode : public runtime::ModuleNode { HexagonModuleNode(std::string data, std::string fmt, std::unordered_map fmap, std::string asm_str, std::string obj_str, std::string ir_str, std::string bc_str); - PackedFunc GetFunction(const String& name, const ObjectPtr& sptr_to_self) override; + ffi::Function GetFunction(const String& name, const ObjectPtr& sptr_to_self) override; String GetSource(const String& format) override; const char* type_key() const final { return "hexagon"; } /*! \brief Get the property of the runtime module .*/ diff --git a/src/runtime/hexagon/rpc/android/session.cc b/src/runtime/hexagon/rpc/android/session.cc index dc1fa9090984..265e5bb12e57 100644 --- a/src/runtime/hexagon/rpc/android/session.cc +++ b/src/runtime/hexagon/rpc/android/session.cc @@ -110,7 +110,7 @@ class HexagonTransportChannel : public RPCChannel { }; TVM_REGISTER_GLOBAL("tvm.contrib.hexagon.create_hexagon_session") - .set_body_packed([](TVMArgs args, TVMRetValue* rv) { + .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { ICHECK(args.size() >= 4) << args.size() << " is less than 4"; auto session_name = args[0].cast(); diff --git a/src/runtime/hexagon/rpc/hexagon/rpc_server.cc b/src/runtime/hexagon/rpc/hexagon/rpc_server.cc index 97b648bfa338..78d65fb8deeb 100644 --- a/src/runtime/hexagon/rpc/hexagon/rpc_server.cc +++ b/src/runtime/hexagon/rpc/hexagon/rpc_server.cc @@ -243,15 +243,15 @@ tvm::runtime::hexagon::HexagonRPCServer* get_hexagon_rpc_server( } } // namespace -const tvm::runtime::PackedFunc get_runtime_func(const std::string& name) { +const tvm::ffi::Function get_runtime_func(const std::string& name) { if (const auto pf = tvm::ffi::Function::GetGlobal(name)) { return *pf; } - return tvm::runtime::PackedFunc(); + return tvm::ffi::Function(); } void reset_device_api() { - const tvm::runtime::PackedFunc api = get_runtime_func("device_api.hexagon"); + const tvm::ffi::Function api = get_runtime_func("device_api.hexagon"); // Registering device_api.cpu as device_api.hexagon since we use hexagon as sub-target of LLVM. tvm::ffi::Function::SetGlobal("device_api.cpu", api, true); } @@ -330,14 +330,14 @@ __attribute__((weak)) void _Parse_fde_instr() {} } TVM_REGISTER_GLOBAL("tvm.hexagon.load_module") - .set_body_packed([](tvm::runtime::TVMArgs args, tvm::runtime::TVMRetValue* rv) { + .set_body_packed([](tvm::ffi::PackedArgs args, tvm::ffi::Any* rv) { auto soname = args[0].cast(); tvm::ObjectPtr n = tvm::runtime::CreateDSOLibraryObject(soname); *rv = CreateModuleFromLibrary(n); }); TVM_REGISTER_GLOBAL("tvm.hexagon.get_profile_output") - .set_body_packed([](tvm::runtime::TVMArgs args, tvm::runtime::TVMRetValue* rv) { + .set_body_packed([](tvm::ffi::PackedArgs args, tvm::ffi::Any* rv) { auto profiling_mode = args[0].cast(); auto out_file = args[1].cast(); if (profiling_mode.compare("lwp") == 0) { @@ -355,7 +355,7 @@ void SaveBinaryToFile(const std::string& file_name, const std::string& data) { } TVM_REGISTER_GLOBAL("tvm.rpc.server.upload") - .set_body_packed([](tvm::runtime::TVMArgs args, tvm::runtime::TVMRetValue* rv) { + .set_body_packed([](tvm::ffi::PackedArgs args, tvm::ffi::Any* rv) { auto file_name = args[0].cast(); auto data = args[1].cast(); SaveBinaryToFile(file_name, data); diff --git a/src/runtime/hexagon/rpc/simulator/rpc_server.cc b/src/runtime/hexagon/rpc/simulator/rpc_server.cc index b0d31bf59b2d..a98abe634e8b 100644 --- a/src/runtime/hexagon/rpc/simulator/rpc_server.cc +++ b/src/runtime/hexagon/rpc/simulator/rpc_server.cc @@ -333,14 +333,14 @@ __attribute__((weak)) void _Parse_fde_instr() {} } TVM_REGISTER_GLOBAL("tvm.hexagon.load_module") - .set_body_packed([](tvm::runtime::TVMArgs args, tvm::runtime::TVMRetValue* rv) { + .set_body_packed([](tvm::ffi::PackedArgs args, tvm::ffi::Any* rv) { auto soname = args[0].cast(); tvm::ObjectPtr n = tvm::runtime::CreateDSOLibraryObject(soname); *rv = CreateModuleFromLibrary(n); }); TVM_REGISTER_GLOBAL("tvm.hexagon.get_profile_output") - .set_body_packed([](tvm::runtime::TVMArgs args, tvm::runtime::TVMRetValue* rv) { + .set_body_packed([](tvm::ffi::PackedArgs args, tvm::ffi::Any* rv) { auto profiling_mode = args[0].cast(); auto out_file = args[1].cast(); if (profiling_mode.compare("lwp") == 0) { @@ -358,7 +358,7 @@ void SaveBinaryToFile(const std::string& file_name, const std::string& data) { } TVM_REGISTER_GLOBAL("tvm.rpc.server.upload") - .set_body_packed([](tvm::runtime::TVMArgs args, tvm::runtime::TVMRetValue* rv) { + .set_body_packed([](tvm::ffi::PackedArgs args, tvm::ffi::Any* rv) { auto file_name = args[0].cast(); auto data = args[1].cast(); SaveBinaryToFile(file_name, data); diff --git a/src/runtime/hexagon/rpc/simulator/session.cc b/src/runtime/hexagon/rpc/simulator/session.cc index 8c02f5546681..7366371b491a 100644 --- a/src/runtime/hexagon/rpc/simulator/session.cc +++ b/src/runtime/hexagon/rpc/simulator/session.cc @@ -1371,7 +1371,7 @@ std::optional SimulatorRPCChannel::to_nullptr(const detail::Mayb } TVM_REGISTER_GLOBAL("tvm.contrib.hexagon.create_hexagon_session") - .set_body_packed([](TVMArgs args, TVMRetValue* rv) { + .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { ICHECK(args.size() >= 4) << args.size() << " is less than 4"; auto session_name = args[0].cast(); diff --git a/src/runtime/library_module.cc b/src/runtime/library_module.cc index 00c5d719965f..f580a6d667f1 100644 --- a/src/runtime/library_module.cc +++ b/src/runtime/library_module.cc @@ -38,7 +38,7 @@ namespace runtime { // Library module that exposes symbols from a library. class LibraryModuleNode final : public ModuleNode { public: - explicit LibraryModuleNode(ObjectPtr lib, PackedFuncWrapper wrapper) + explicit LibraryModuleNode(ObjectPtr lib, FFIFunctionWrapper wrapper) : lib_(lib), packed_func_wrapper_(wrapper) {} const char* type_key() const final { return "library"; } @@ -48,7 +48,7 @@ class LibraryModuleNode final : public ModuleNode { return ModulePropertyMask::kBinarySerializable | ModulePropertyMask::kRunnable; }; - PackedFunc GetFunction(const String& name, const ObjectPtr& sptr_to_self) final { + ffi::Function GetFunction(const String& name, const ObjectPtr& sptr_to_self) final { TVMFFISafeCallType faddr; if (name == runtime::symbol::tvm_module_main) { const char* entry_name = @@ -59,16 +59,16 @@ class LibraryModuleNode final : public ModuleNode { } else { faddr = reinterpret_cast(lib_->GetSymbol(name.c_str())); } - if (faddr == nullptr) return PackedFunc(); + if (faddr == nullptr) return ffi::Function(); return packed_func_wrapper_(faddr, sptr_to_self); } private: ObjectPtr lib_; - PackedFuncWrapper packed_func_wrapper_; + FFIFunctionWrapper packed_func_wrapper_; }; -PackedFunc WrapPackedFunc(TVMFFISafeCallType faddr, const ObjectPtr& sptr_to_self) { +ffi::Function WrapFFIFunction(TVMFFISafeCallType faddr, const ObjectPtr& sptr_to_self) { return ffi::Function::FromPacked([faddr, sptr_to_self](ffi::PackedArgs args, ffi::Any* rv) { ICHECK_LT(rv->type_index(), ffi::TypeIndex::kTVMFFIStaticObjectBegin); TVM_FFI_CHECK_SAFE_CALL((*faddr)(nullptr, reinterpret_cast(args.data()), @@ -114,7 +114,7 @@ Module LoadModuleFromBinary(const std::string& type_key, dmlc::Stream* stream) { * \param dso_ctx_addr the output dso module */ void ProcessModuleBlob(const char* mblob, ObjectPtr lib, - PackedFuncWrapper packed_func_wrapper, runtime::Module* root_module, + FFIFunctionWrapper packed_func_wrapper, runtime::Module* root_module, runtime::ModuleNode** dso_ctx_addr = nullptr) { ICHECK(mblob != nullptr); uint64_t nbytes = 0; @@ -180,7 +180,7 @@ void ProcessModuleBlob(const char* mblob, ObjectPtr lib, } } -Module CreateModuleFromLibrary(ObjectPtr lib, PackedFuncWrapper packed_func_wrapper) { +Module CreateModuleFromLibrary(ObjectPtr lib, FFIFunctionWrapper packed_func_wrapper) { InitContextFunctions([lib](const char* fname) { return lib->GetSymbol(fname); }); auto n = make_object(lib, packed_func_wrapper); // Load the imported modules diff --git a/src/runtime/library_module.h b/src/runtime/library_module.h index c380227e15f9..ccc0b3193b87 100644 --- a/src/runtime/library_module.h +++ b/src/runtime/library_module.h @@ -71,7 +71,7 @@ class Library : public Object { * \param faddr The function address * \param mptr The module pointer node. */ -PackedFunc WrapPackedFunc(TVMFFISafeCallType faddr, const ObjectPtr& mptr); +ffi::Function WrapFFIFunction(TVMFFISafeCallType faddr, const ObjectPtr& mptr); /*! * \brief Utility to initialize conext function symbols during startup @@ -94,8 +94,8 @@ class ModuleInternal { * \param mptr The module pointer node. * \return Packed function that wraps the invocation of the function at faddr. */ -using PackedFuncWrapper = - std::function& mptr)>; +using FFIFunctionWrapper = + std::function& mptr)>; /*! \brief Return a library object interface over dynamic shared * libraries in Windows and Linux providing support for @@ -110,7 +110,7 @@ ObjectPtr CreateDSOLibraryObject(std::string library_path); * * \param lib The library. * \param wrapper Optional function used to wrap a TVMBackendPackedCFunc, - * by default WrapPackedFunc is used. + * by default WrapFFIFunction is used. * \param symbol_prefix Optional symbol prefix that can be used to search alternative symbols. * * \return The corresponding loaded module. @@ -118,7 +118,8 @@ ObjectPtr CreateDSOLibraryObject(std::string library_path); * \note This function can create multiple linked modules * by parsing the binary blob section of the library. */ -Module CreateModuleFromLibrary(ObjectPtr lib, PackedFuncWrapper wrapper = WrapPackedFunc); +Module CreateModuleFromLibrary(ObjectPtr lib, + FFIFunctionWrapper wrapper = WrapFFIFunction); } // namespace runtime } // namespace tvm #endif // TVM_RUNTIME_LIBRARY_MODULE_H_ diff --git a/src/runtime/metal/metal_common.h b/src/runtime/metal/metal_common.h index d68dd0b2cd3b..ab383732ea8c 100644 --- a/src/runtime/metal/metal_common.h +++ b/src/runtime/metal/metal_common.h @@ -162,7 +162,7 @@ class MetalWorkspace final : public DeviceAPI { } // override device API void SetDevice(Device dev) final; - void GetAttr(Device dev, DeviceAttrKind kind, TVMRetValue* rv) final; + void GetAttr(Device dev, DeviceAttrKind kind, ffi::Any* rv) final; void* AllocDataSpace(Device dev, size_t nbytes, size_t alignment, DLDataType type_hint) final; void FreeDataSpace(Device dev, void* ptr) final; TVMStreamHandle CreateStream(Device dev) final; diff --git a/src/runtime/metal/metal_device_api.mm b/src/runtime/metal/metal_device_api.mm index ead9c1804a77..83f2c38a2bd5 100644 --- a/src/runtime/metal/metal_device_api.mm +++ b/src/runtime/metal/metal_device_api.mm @@ -41,7 +41,7 @@ return inst; } -void MetalWorkspace::GetAttr(Device dev, DeviceAttrKind kind, TVMRetValue* rv) { +void MetalWorkspace::GetAttr(Device dev, DeviceAttrKind kind, ffi::Any* rv) { AUTORELEASEPOOL { size_t index = static_cast(dev.device_id); if (kind == kExist) { @@ -362,7 +362,7 @@ int GetWarpSize(id dev) { MetalThreadEntry* MetalThreadEntry::ThreadLocal() { return MetalThreadStore::Get(); } -TVM_REGISTER_GLOBAL("device_api.metal").set_body_packed([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("device_api.metal").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { DeviceAPI* ptr = MetalWorkspace::Global(); *rv = static_cast(ptr); }); diff --git a/src/runtime/metal/metal_module.mm b/src/runtime/metal/metal_module.mm index a9a7faaefa3c..cc25fd8b0daf 100644 --- a/src/runtime/metal/metal_module.mm +++ b/src/runtime/metal/metal_module.mm @@ -57,7 +57,7 @@ int GetPropertyMask() const final { return ModulePropertyMask::kBinarySerializable | ModulePropertyMask::kRunnable; } - PackedFunc GetFunction(const String& name, const ObjectPtr& sptr_to_self) final; + ffi::Function GetFunction(const String& name, const ObjectPtr& sptr_to_self) final; void SaveToFile(const String& file_name, const String& format) final { LOG(FATAL) << "Do not support save to file, use save to binary and export instead"; @@ -187,7 +187,7 @@ void Init(MetalModuleNode* m, ObjectPtr sptr, const std::string& func_na scache_[dev_id] = m->GetPipelineState(dev_id, func_name); } // invoke the function with void arguments - void operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion64* pack_args) const { + void operator()(ffi::PackedArgs args, ffi::Any* rv, const ArgUnion64* pack_args) const { AUTORELEASEPOOL { metal::MetalThreadEntry* t = metal::MetalThreadEntry::ThreadLocal(); int device_id = t->device.device_id; @@ -258,14 +258,15 @@ void operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion64* pack_args) cons LaunchParamConfig launch_param_config_; }; -PackedFunc MetalModuleNode::GetFunction(const String& name, const ObjectPtr& sptr_to_self) { - PackedFunc pf; +ffi::Function MetalModuleNode::GetFunction(const String& name, + const ObjectPtr& sptr_to_self) { + ffi::Function f; AUTORELEASEPOOL { ICHECK_EQ(sptr_to_self.get(), this); ICHECK_NE(name, symbol::tvm_module_main) << "Device function do not have main"; auto it = fmap_.find(name); if (it == fmap_.end()) { - pf = PackedFunc(); + f = ffi::Function(); return; } const FunctionInfo& info = it->second; diff --git a/src/runtime/module.cc b/src/runtime/module.cc index 7b8e7765027a..5ba2248f7627 100644 --- a/src/runtime/module.cc +++ b/src/runtime/module.cc @@ -57,9 +57,9 @@ void ModuleNode::Import(Module other) { this->imports_.emplace_back(std::move(other)); } -PackedFunc ModuleNode::GetFunction(const String& name, bool query_imports) { +ffi::Function ModuleNode::GetFunction(const String& name, bool query_imports) { ModuleNode* self = this; - PackedFunc pf = self->GetFunction(name, GetObjectPtr(this)); + ffi::Function pf = self->GetFunction(name, GetObjectPtr(this)); if (pf != nullptr) return pf; if (query_imports) { for (Module& m : self->imports_) { @@ -101,11 +101,11 @@ String ModuleNode::GetSource(const String& format) { LOG(FATAL) << "Module[" << type_key() << "] does not support GetSource"; } -const PackedFunc* ModuleNode::GetFuncFromEnv(const String& name) { +const ffi::Function* ModuleNode::GetFuncFromEnv(const String& name) { std::lock_guard lock(mutex_); auto it = import_cache_.find(name); if (it != import_cache_.end()) return it->second.get(); - PackedFunc pf; + ffi::Function pf; for (Module& m : this->imports_) { pf = m.GetFunction(name, true); if (pf != nullptr) break; @@ -117,10 +117,10 @@ const PackedFunc* ModuleNode::GetFuncFromEnv(const String& name) { << " If this involves ops from a contrib library like" << " cuDNN, ensure TVM was built with the relevant" << " library."; - import_cache_.insert(std::make_pair(name, std::make_shared(*f))); + import_cache_.insert(std::make_pair(name, std::make_shared(*f))); return import_cache_.at(name).get(); } else { - import_cache_.insert(std::make_pair(name, std::make_shared(pf))); + import_cache_.insert(std::make_pair(name, std::make_shared(pf))); return import_cache_.at(name).get(); } } diff --git a/src/runtime/opencl/opencl_common.h b/src/runtime/opencl/opencl_common.h index 94ab736f5ed5..6f2a9e610363 100644 --- a/src/runtime/opencl/opencl_common.h +++ b/src/runtime/opencl/opencl_common.h @@ -347,7 +347,7 @@ class OpenCLWorkspace : public DeviceAPI { cl_device_id GetCLDeviceID(int device_id); // override device API void SetDevice(Device dev) final; - void GetAttr(Device dev, DeviceAttrKind kind, TVMRetValue* rv) final; + void GetAttr(Device dev, DeviceAttrKind kind, ffi::Any* rv) final; void* AllocDataSpace(Device dev, size_t size, size_t alignment, DLDataType type_hint) final; void* AllocDataSpace(Device dev, int ndim, const int64_t* shape, DLDataType dtype, Optional mem_scope = NullOpt) final; @@ -479,7 +479,7 @@ class OpenCLModuleNodeBase : public ModuleNode { return ModulePropertyMask::kBinarySerializable | ModulePropertyMask::kRunnable; } - PackedFunc GetFunction(const String& name, const ObjectPtr& sptr_to_self) override; + ffi::Function GetFunction(const String& name, const ObjectPtr& sptr_to_self) override; // Initialize the programs virtual void Init() = 0; @@ -509,7 +509,7 @@ class OpenCLModuleNode : public OpenCLModuleNodeBase { std::unordered_map fmap, std::string source) : OpenCLModuleNodeBase(fmap), data_(data), fmt_(fmt), source_(source) {} - PackedFunc GetFunction(const String& name, const ObjectPtr& sptr_to_self) final; + ffi::Function GetFunction(const String& name, const ObjectPtr& sptr_to_self) final; // Return true if OpenCL program for the requested function and device was created bool IsProgramCreated(const std::string& func_name, int device_id); void SaveToFile(const String& file_name, const String& format) final; diff --git a/src/runtime/opencl/opencl_device_api.cc b/src/runtime/opencl/opencl_device_api.cc index 18c265697b8e..a436ede61bc9 100644 --- a/src/runtime/opencl/opencl_device_api.cc +++ b/src/runtime/opencl/opencl_device_api.cc @@ -132,7 +132,7 @@ cl_device_id OpenCLWorkspace::GetCLDeviceID(int device_id) { void OpenCLWorkspace::SetDevice(Device dev) { GetThreadEntry()->device.device_id = dev.device_id; } -void OpenCLWorkspace::GetAttr(Device dev, DeviceAttrKind kind, TVMRetValue* rv) { +void OpenCLWorkspace::GetAttr(Device dev, DeviceAttrKind kind, ffi::Any* rv) { this->Init(); size_t index = static_cast(dev.device_id); if (kind == kExist) { @@ -761,7 +761,7 @@ void OpenCLWorkspace::Init(const std::string& type_key, const std::string& devic } TVM_REGISTER_GLOBAL("device_api.opencl.alloc_nd") - .set_body_packed([](TVMArgs args, TVMRetValue* rv) { + .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { int32_t device_type = args[0].cast(); int32_t device_id = args[1].cast(); int32_t dtype_code_hint = args[2].cast(); @@ -788,21 +788,22 @@ TVM_REGISTER_GLOBAL("device_api.opencl.alloc_nd") String("global.texture")); }); -TVM_REGISTER_GLOBAL("device_api.opencl.free_nd").set_body_packed([](TVMArgs args, TVMRetValue* rv) { - int32_t device_type = args[0].cast(); - int32_t device_id = args[1].cast(); - auto scope = args[2].cast(); - CHECK(scope.find("texture") != std::string::npos); - void* data = args[3].cast(); - OpenCLWorkspace* ptr = OpenCLWorkspace::Global(); - Device dev; - dev.device_type = static_cast(device_type); - dev.device_id = device_id; - ptr->FreeDataSpace(dev, data); - *rv = static_cast(0); -}); +TVM_REGISTER_GLOBAL("device_api.opencl.free_nd") + .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { + int32_t device_type = args[0].cast(); + int32_t device_id = args[1].cast(); + auto scope = args[2].cast(); + CHECK(scope.find("texture") != std::string::npos); + void* data = args[3].cast(); + OpenCLWorkspace* ptr = OpenCLWorkspace::Global(); + Device dev; + dev.device_type = static_cast(device_type); + dev.device_id = device_id; + ptr->FreeDataSpace(dev, data); + *rv = static_cast(0); + }); -TVM_REGISTER_GLOBAL("device_api.opencl").set_body_packed([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("device_api.opencl").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { DeviceAPI* ptr = OpenCLWorkspace::Global(); *rv = static_cast(ptr); }); @@ -892,10 +893,11 @@ class OpenCLPooledAllocator final : public memory::PooledAllocator { } }; -TVM_REGISTER_GLOBAL("DeviceAllocator.opencl").set_body_packed([](TVMArgs args, TVMRetValue* rv) { - Allocator* alloc = new OpenCLPooledAllocator(); - *rv = static_cast(alloc); -}); +TVM_REGISTER_GLOBAL("DeviceAllocator.opencl") + .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { + Allocator* alloc = new OpenCLPooledAllocator(); + *rv = static_cast(alloc); + }); } // namespace cl size_t OpenCLTimerNode::count_timer_execs = 0; diff --git a/src/runtime/opencl/opencl_module.cc b/src/runtime/opencl/opencl_module.cc index 3b0dafc243a6..90cdcb48bf96 100644 --- a/src/runtime/opencl/opencl_module.cc +++ b/src/runtime/opencl/opencl_module.cc @@ -50,7 +50,7 @@ class OpenCLWrappedFunc { launch_param_config_.Init(arg_size.size(), launch_param_tags); } // invoke the function with void arguments - void operator()(TVMArgs args, TVMRetValue* rv, void** void_args) const { + void operator()(ffi::PackedArgs args, ffi::Any* rv, void** void_args) const { ICHECK(w_->devices.size() > 0) << "No OpenCL device"; cl::OpenCLThreadEntry* t = w_->GetThreadEntry(); // get the kernel from thread local kernel table. @@ -134,12 +134,12 @@ cl::OpenCLWorkspace* OpenCLModuleNodeBase::GetGlobalWorkspace() { return cl::OpenCLWorkspace::Global(); } -PackedFunc OpenCLModuleNodeBase::GetFunction(const String& name, - const ObjectPtr& sptr_to_self) { +ffi::Function OpenCLModuleNodeBase::GetFunction(const String& name, + const ObjectPtr& sptr_to_self) { ICHECK_EQ(sptr_to_self.get(), this); ICHECK_NE(name, symbol::tvm_module_main) << "Device function do not have main"; auto it = fmap_.find(name); - if (it == fmap_.end()) return PackedFunc(); + if (it == fmap_.end()) return ffi::Function(); const FunctionInfo& info = it->second; OpenCLWrappedFunc f; std::vector arg_size(info.arg_types.size()); @@ -345,15 +345,15 @@ std::string OpenCLModuleNode::GetPreCompiledPrograms() { return data; } -PackedFunc OpenCLModuleNode::GetFunction(const String& name, - const ObjectPtr& sptr_to_self) { +ffi::Function OpenCLModuleNode::GetFunction(const String& name, + const ObjectPtr& sptr_to_self) { ICHECK_EQ(sptr_to_self.get(), this); if (name == "opencl.GetPreCompiledPrograms") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + return ffi::Function([sptr_to_self, this](ffi::PackedArgs args, ffi::Any* rv) { *rv = this->GetPreCompiledPrograms(); }); } else if (name == "opencl.SetPreCompiledPrograms") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + return ffi::Function([sptr_to_self, this](ffi::PackedArgs args, ffi::Any* rv) { this->SetPreCompiledPrograms(args[0].cast()); }); } diff --git a/src/runtime/pack_args.h b/src/runtime/pack_args.h index 7cce38cde959..ec000524fa00 100644 --- a/src/runtime/pack_args.h +++ b/src/runtime/pack_args.h @@ -19,7 +19,7 @@ /*! * \file pack_args.h - * \brief Utility to pack TVMArgs to other type-erased fution calling convention. + * \brief Utility to pack ffi::PackedArgs to other type-erased fution calling convention. * * Two type erased function signatures are supported. * - cuda_style(void** args, int num_args); @@ -62,39 +62,39 @@ union ArgUnion64 { /*! * \brief Create a packed function from void addr types. * - * \param f with signiture (TVMArgs args, TVMRetValue* rv, void* void_args) + * \param f with signiture (ffi::PackedArgs args, ffi::Any* rv, void* void_args) * \param arg_types The arguments type information. * \tparam F the function type * * \return The wrapped packed function. */ template -inline PackedFunc PackFuncVoidAddr(F f, const std::vector& arg_types); +inline ffi::Function PackFuncVoidAddr(F f, const std::vector& arg_types); /*! * \brief Create a packed function that from function only packs buffer arguments. * - * \param f with signiture (TVMArgs args, TVMRetValue* rv, ArgUnion* pack_args) + * \param f with signiture (ffi::PackedArgs args, ffi::Any* rv, ArgUnion* pack_args) * \param arg_types The arguments type information. * \tparam F the function type * * \return The wrapped packed function. */ template -inline PackedFunc PackFuncNonBufferArg(F f, const std::vector& arg_types); +inline ffi::Function PackFuncNonBufferArg(F f, const std::vector& arg_types); /*! * \brief Create a packed function that from function that takes a packed arguments. * * This procedure ensures inserts padding to ensure proper alignment of struct fields * per C struct convention * - * \param f with signature (TVMArgs args, TVMRetValue* rv, void* pack_args, size_t nbytes) + * \param f with signature (ffi::PackedArgs args, ffi::Any* rv, void* pack_args, size_t nbytes) * \param arg_types The arguments that wish to get from * \tparam F the function type * * \return The wrapped packed function. */ template -inline PackedFunc PackFuncPackedArgAligned(F f, const std::vector& arg_types); +inline ffi::Function PackFuncPackedArgAligned(F f, const std::vector& arg_types); /*! * \brief Extract number of buffer argument from the argument types. * \param arg_types The argument types. @@ -150,9 +150,9 @@ inline ArgConvertCode GetArgConvertCode(DLDataType t) { } template -inline PackedFunc PackFuncVoidAddr_(F f, const std::vector& codes) { +inline ffi::Function PackFuncVoidAddr_(F f, const std::vector& codes) { int num_args = static_cast(codes.size()); - auto ret = [f, codes, num_args](TVMArgs args, TVMRetValue* ret) { + auto ret = [f, codes, num_args](ffi::PackedArgs args, ffi::Any* ret) { TempArray addr_(num_args); TempArray holder_(num_args); void** addr = addr_.data(); @@ -187,13 +187,14 @@ inline PackedFunc PackFuncVoidAddr_(F f, const std::vector& code } f(args, ret, addr); }; - return PackedFunc(ret); + return ffi::Function(ret); } template -inline PackedFunc PackFuncNonBufferArg_(F f, int base, const std::vector& codes) { +inline ffi::Function PackFuncNonBufferArg_(F f, int base, + const std::vector& codes) { int num_args = static_cast(codes.size()); - auto ret = [f, codes, base, num_args](TVMArgs args, TVMRetValue* ret) { + auto ret = [f, codes, base, num_args](ffi::PackedArgs args, ffi::Any* ret) { TempArray holder_(num_args); ArgUnion64* holder = holder_.data(); // NOTE: we need the real address of the args.data for some addr translation @@ -229,13 +230,13 @@ inline PackedFunc PackFuncNonBufferArg_(F f, int base, const std::vector -inline PackedFunc PackFuncPackedArgAligned_(F f, const std::vector& codes) { +inline ffi::Function PackFuncPackedArgAligned_(F f, const std::vector& codes) { int num_args = static_cast(codes.size()); - auto ret = [f, codes, num_args](TVMArgs args, TVMRetValue* ret) { + auto ret = [f, codes, num_args](ffi::PackedArgs args, ffi::Any* ret) { TempArray pack_(num_args); int32_t* pack = reinterpret_cast(pack_.data()); int32_t* ptr = pack; @@ -292,12 +293,12 @@ inline PackedFunc PackFuncPackedArgAligned_(F f, const std::vector -inline PackedFunc PackFuncVoidAddr(F f, const std::vector& arg_types) { +inline ffi::Function PackFuncVoidAddr(F f, const std::vector& arg_types) { std::vector codes(arg_types.size()); for (size_t i = 0; i < arg_types.size(); ++i) { codes[i] = detail::GetArgConvertCode(arg_types[i]); @@ -328,7 +329,7 @@ inline size_t NumBufferArgs(const std::vector& arg_types) { } template -inline PackedFunc PackFuncNonBufferArg(F f, const std::vector& arg_types) { +inline ffi::Function PackFuncNonBufferArg(F f, const std::vector& arg_types) { size_t num_buffer = NumBufferArgs(arg_types); std::vector codes; for (size_t i = num_buffer; i < arg_types.size(); ++i) { @@ -345,7 +346,7 @@ inline PackedFunc PackFuncNonBufferArg(F f, const std::vector& arg_t } template -inline PackedFunc PackFuncPackedArgAligned(F f, const std::vector& arg_types) { +inline ffi::Function PackFuncPackedArgAligned(F f, const std::vector& arg_types) { std::vector codes; for (size_t i = 0; i < arg_types.size(); ++i) { codes.push_back(detail::GetArgConvertCode(arg_types[i])); diff --git a/src/runtime/packed_func.cc b/src/runtime/packed_func.cc index 75a29e4398c7..63ec7bbc7d47 100644 --- a/src/runtime/packed_func.cc +++ b/src/runtime/packed_func.cc @@ -18,7 +18,7 @@ */ /* * \file src/runtime/packed_func.cc - * \brief Implementation of non-inlinable PackedFunc pieces. + * \brief Implementation of non-inlinable ffi::Function pieces. */ #include #include @@ -26,7 +26,7 @@ namespace tvm { namespace runtime { -TVM_REGISTER_OBJECT_TYPE(PackedFuncObj); +TVM_REGISTER_OBJECT_TYPE(ffi::FunctionObj); } // namespace runtime } // namespace tvm diff --git a/src/runtime/profiling.cc b/src/runtime/profiling.cc index 08fd1fb11a47..c073056fb320 100644 --- a/src/runtime/profiling.cc +++ b/src/runtime/profiling.cc @@ -798,11 +798,11 @@ TVM_REGISTER_GLOBAL("runtime.profiling.DeviceWrapper").set_body_typed([](Device return DeviceWrapper(dev); }); -PackedFunc ProfileFunction(Module mod, std::string func_name, int device_type, int device_id, - int warmup_iters, Array collectors) { +ffi::Function ProfileFunction(Module mod, std::string func_name, int device_type, int device_id, + int warmup_iters, Array collectors) { // Module::GetFunction is not const, so this lambda has to be mutable - return PackedFunc::FromPacked([=](const AnyView* args, int32_t num_args, Any* ret) mutable { - PackedFunc f = mod.GetFunction(func_name); + return ffi::Function::FromPacked([=](const AnyView* args, int32_t num_args, Any* ret) mutable { + ffi::Function f = mod.GetFunction(func_name); CHECK(f.defined()) << "There is no function called \"" << func_name << "\" in the module"; Device dev{static_cast(device_type), device_id}; @@ -844,11 +844,11 @@ PackedFunc ProfileFunction(Module mod, std::string func_name, int device_type, i } TVM_REGISTER_GLOBAL("runtime.profiling.ProfileFunction") - .set_body_typed)>([](Module mod, String func_name, - int device_type, int device_id, - int warmup_iters, - Array collectors) { + .set_body_typed)>([](Module mod, String func_name, + int device_type, int device_id, + int warmup_iters, + Array collectors) { if (mod->type_key() == std::string("rpc")) { LOG(FATAL) << "Profiling a module over RPC is not yet supported"; // because we can't send @@ -859,9 +859,10 @@ TVM_REGISTER_GLOBAL("runtime.profiling.ProfileFunction") } }); -PackedFunc WrapTimeEvaluator(PackedFunc pf, Device dev, int number, int repeat, int min_repeat_ms, - int limit_zero_time_iterations, int cooldown_interval_ms, - int repeats_to_cooldown, int cache_flush_bytes, PackedFunc f_preproc) { +ffi::Function WrapTimeEvaluator(ffi::Function pf, Device dev, int number, int repeat, + int min_repeat_ms, int limit_zero_time_iterations, + int cooldown_interval_ms, int repeats_to_cooldown, + int cache_flush_bytes, ffi::Function f_preproc) { ICHECK(pf != nullptr); auto ftimer = [pf, dev, number, repeat, min_repeat_ms, limit_zero_time_iterations, diff --git a/src/runtime/regex.cc b/src/runtime/regex.cc index 0a3215832e36..8b4df9e69395 100644 --- a/src/runtime/regex.cc +++ b/src/runtime/regex.cc @@ -33,7 +33,7 @@ bool regex_match(const std::string& match_against, const std::string& regex_patt const auto regex_match_func = tvm::ffi::Function::GetGlobal("tvm.runtime.regex_match"); CHECK(regex_match_func.has_value()) << "RuntimeError: " - << "The PackedFunc 'tvm.runtime.regex_match' has not been registered. " + << "The ffi::Function 'tvm.runtime.regex_match' has not been registered. " << "This can occur if the TVM Python library has not yet been imported."; return (*regex_match_func)(regex_pattern, match_against).cast(); } diff --git a/src/runtime/registry.cc b/src/runtime/registry.cc index 5ea1d7677a6f..1045dc20ef0d 100644 --- a/src/runtime/registry.cc +++ b/src/runtime/registry.cc @@ -44,7 +44,7 @@ namespace runtime { * we need for specific low-level handling(e.g. signal checking). * * We only stores the C API function when absolutely necessary (e.g. when signal handler - * cannot trap back into python). Always consider use the PackedFunc FFI when possible + * cannot trap back into python). Always consider use the ffi::Function FFI when possible * in other cases. */ class EnvCAPIRegistry { diff --git a/src/runtime/relax_vm/attn_backend.cc b/src/runtime/relax_vm/attn_backend.cc index 0b94d541c2dd..09f2d2c736fc 100644 --- a/src/runtime/relax_vm/attn_backend.cc +++ b/src/runtime/relax_vm/attn_backend.cc @@ -33,13 +33,13 @@ std::unique_ptr ConvertPagedPrefillFunc(Array args, String backend_name = Downcast(args[0]); if (backend_name == "tir") { CHECK_EQ(args.size(), 2); - PackedFunc attn_func = Downcast(args[1]); + ffi::Function attn_func = Downcast(args[1]); return std::make_unique(std::move(attn_func), attn_kind); } if (backend_name == "flashinfer") { CHECK_EQ(args.size(), 3); - PackedFunc attn_func = Downcast(args[1]); - PackedFunc plan_func = Downcast(args[2]); + ffi::Function attn_func = Downcast(args[1]); + ffi::Function plan_func = Downcast(args[2]); return std::make_unique(std::move(attn_func), std::move(plan_func), attn_kind); } @@ -55,13 +55,13 @@ std::unique_ptr ConvertRaggedPrefillFunc(Array arg String backend_name = Downcast(args[0]); if (backend_name == "tir") { CHECK_EQ(args.size(), 2); - PackedFunc attn_func = Downcast(args[1]); + ffi::Function attn_func = Downcast(args[1]); return std::make_unique(std::move(attn_func), attn_kind); } if (backend_name == "flashinfer") { CHECK_EQ(args.size(), 3); - PackedFunc attn_func = Downcast(args[1]); - PackedFunc plan_func = Downcast(args[2]); + ffi::Function attn_func = Downcast(args[1]); + ffi::Function plan_func = Downcast(args[2]); return std::make_unique(std::move(attn_func), std::move(plan_func), attn_kind); } @@ -76,13 +76,13 @@ std::unique_ptr ConvertPagedDecodeFunc(Array args, A String backend_name = Downcast(args[0]); if (backend_name == "tir") { CHECK_EQ(args.size(), 2); - PackedFunc attn_func = Downcast(args[1]); + ffi::Function attn_func = Downcast(args[1]); return std::make_unique(std::move(attn_func), attn_kind); } if (backend_name == "flashinfer") { CHECK_EQ(args.size(), 3); - PackedFunc attn_func = Downcast(args[1]); - PackedFunc plan_func = Downcast(args[2]); + ffi::Function attn_func = Downcast(args[1]); + ffi::Function plan_func = Downcast(args[2]); return std::make_unique(std::move(attn_func), std::move(plan_func), attn_kind); } @@ -98,7 +98,7 @@ std::unique_ptr ConvertPagedPrefillTreeMaskFunc(Array< String backend_name = Downcast(args[0]); if (backend_name == "tir") { CHECK_EQ(args.size(), 2); - PackedFunc attn_func = Downcast(args[1]); + ffi::Function attn_func = Downcast(args[1]); return std::make_unique(std::move(attn_func), attn_kind); } LOG(FATAL) << "Cannot reach here"; @@ -113,7 +113,7 @@ std::unique_ptr ConvertRaggedPrefillTreeMaskFunc(Arra String backend_name = Downcast(args[0]); if (backend_name == "tir") { CHECK_EQ(args.size(), 2); - PackedFunc attn_func = Downcast(args[1]); + ffi::Function attn_func = Downcast(args[1]); return std::make_unique(std::move(attn_func), attn_kind); } LOG(FATAL) << "Cannot reach here"; diff --git a/src/runtime/relax_vm/attn_backend.h b/src/runtime/relax_vm/attn_backend.h index 10142460a6ed..4d3a3ce9d832 100644 --- a/src/runtime/relax_vm/attn_backend.h +++ b/src/runtime/relax_vm/attn_backend.h @@ -49,13 +49,14 @@ enum class AttnBackendKind : int { /*! \brief The base class of attention backends. */ class AttnBackendFunc { public: - explicit AttnBackendFunc(PackedFunc attn_func, AttnKind attn_kind, AttnBackendKind backend_kind) + explicit AttnBackendFunc(ffi::Function attn_func, AttnKind attn_kind, + AttnBackendKind backend_kind) : attn_func_(std::move(attn_func)), attn_kind(attn_kind), backend_kind(backend_kind) {} virtual ~AttnBackendFunc() = default; protected: - PackedFunc attn_func_; + ffi::Function attn_func_; public: AttnKind attn_kind; @@ -65,7 +66,8 @@ class AttnBackendFunc { /*! \brief The paged prefill attention function base class. */ class PagedPrefillFunc : public AttnBackendFunc { public: - explicit PagedPrefillFunc(PackedFunc attn_func, AttnKind attn_kind, AttnBackendKind backend_kind) + explicit PagedPrefillFunc(ffi::Function attn_func, AttnKind attn_kind, + AttnBackendKind backend_kind) : AttnBackendFunc(std::move(attn_func), attn_kind, backend_kind) {} virtual void MHA(int depth, NDArray q, NDArray qo_indptr, NDArray pages, NDArray page_indptr, @@ -95,7 +97,7 @@ class PagedPrefillFunc : public AttnBackendFunc { /*! \brief The TIR-based paged prefill attention function class. */ class TIRPagedPrefillFunc : public PagedPrefillFunc { public: - explicit TIRPagedPrefillFunc(PackedFunc attn_func, AttnKind attn_kind) + explicit TIRPagedPrefillFunc(ffi::Function attn_func, AttnKind attn_kind) : PagedPrefillFunc(std::move(attn_func), attn_kind, AttnBackendKind::kTIR) {} void MHA(int depth, NDArray q, NDArray qo_indptr, NDArray pages, NDArray page_indptr, @@ -120,7 +122,7 @@ class TIRPagedPrefillFunc : public PagedPrefillFunc { /*! \brief The FlashInfer-based paged prefill attention function class. */ class FlashInferPagedPrefillFunc : public PagedPrefillFunc { public: - explicit FlashInferPagedPrefillFunc(PackedFunc attn_func, PackedFunc plan_func, + explicit FlashInferPagedPrefillFunc(ffi::Function attn_func, ffi::Function plan_func, AttnKind attn_kind) : PagedPrefillFunc(std::move(attn_func), attn_kind, AttnBackendKind::kFlashInfer), plan_func_(std::move(plan_func)) {} @@ -193,14 +195,15 @@ class FlashInferPagedPrefillFunc : public PagedPrefillFunc { } private: - PackedFunc plan_func_; + ffi::Function plan_func_; std::vector> cached_buffers_; }; /*! \brief The ragged prefill attention function base class. */ class RaggedPrefillFunc : public AttnBackendFunc { public: - explicit RaggedPrefillFunc(PackedFunc attn_func, AttnKind attn_kind, AttnBackendKind backend_kind) + explicit RaggedPrefillFunc(ffi::Function attn_func, AttnKind attn_kind, + AttnBackendKind backend_kind) : AttnBackendFunc(std::move(attn_func), attn_kind, backend_kind) {} virtual void MHA(NDArray q, NDArray k, NDArray v, NDArray qo_indptr, NDArray kv_indptr, @@ -222,7 +225,7 @@ class RaggedPrefillFunc : public AttnBackendFunc { /*! \brief The TIR-based ragged prefill attention function class. */ class TIRRaggedPrefillFunc : public RaggedPrefillFunc { public: - explicit TIRRaggedPrefillFunc(PackedFunc attn_func, AttnKind attn_kind) + explicit TIRRaggedPrefillFunc(ffi::Function attn_func, AttnKind attn_kind) : RaggedPrefillFunc(std::move(attn_func), attn_kind, AttnBackendKind::kTIR) {} void MHA(NDArray q, NDArray k, NDArray v, NDArray qo_indptr, NDArray kv_indptr, @@ -239,7 +242,7 @@ class TIRRaggedPrefillFunc : public RaggedPrefillFunc { /*! \brief The FlashInfer-based ragged prefill attention function class. */ class FlashInferRaggedPrefillFunc : public RaggedPrefillFunc { public: - explicit FlashInferRaggedPrefillFunc(PackedFunc attn_func, PackedFunc plan_func, + explicit FlashInferRaggedPrefillFunc(ffi::Function attn_func, ffi::Function plan_func, AttnKind attn_kind) : RaggedPrefillFunc(std::move(attn_func), attn_kind, AttnBackendKind::kFlashInfer), plan_func_(std::move(plan_func)) {} @@ -282,7 +285,7 @@ class FlashInferRaggedPrefillFunc : public RaggedPrefillFunc { } private: - PackedFunc plan_func_; + ffi::Function plan_func_; NDArray float_workspace_buffer_; NDArray int_workspace_buffer_; NDArray page_locked_int_workspace_buffer_; @@ -292,7 +295,8 @@ class FlashInferRaggedPrefillFunc : public RaggedPrefillFunc { /*! \brief The paged decode attention function base class. */ class PagedDecodeFunc : public AttnBackendFunc { public: - explicit PagedDecodeFunc(PackedFunc attn_func, AttnKind attn_kind, AttnBackendKind backend_kind) + explicit PagedDecodeFunc(ffi::Function attn_func, AttnKind attn_kind, + AttnBackendKind backend_kind) : AttnBackendFunc(std::move(attn_func), attn_kind, backend_kind) {} virtual void MHA(int depth, NDArray q, NDArray pages, NDArray page_indptr, NDArray page_indices, @@ -321,7 +325,7 @@ class PagedDecodeFunc : public AttnBackendFunc { /*! \brief The TIR-based paged decode attention function class. */ class TIRPagedDecodeFunc : public PagedDecodeFunc { public: - explicit TIRPagedDecodeFunc(PackedFunc attn_func, AttnKind attn_kind) + explicit TIRPagedDecodeFunc(ffi::Function attn_func, AttnKind attn_kind) : PagedDecodeFunc(std::move(attn_func), attn_kind, AttnBackendKind::kTIR) {} void MHA(int depth, NDArray q, NDArray pages, NDArray page_indptr, NDArray page_indices, @@ -344,7 +348,8 @@ class TIRPagedDecodeFunc : public PagedDecodeFunc { /*! \brief The FlashInfer-based paged decode attention function class. */ class FlashInferPagedDecodeFunc : public PagedDecodeFunc { public: - explicit FlashInferPagedDecodeFunc(PackedFunc attn_func, PackedFunc plan_func, AttnKind attn_kind) + explicit FlashInferPagedDecodeFunc(ffi::Function attn_func, ffi::Function plan_func, + AttnKind attn_kind) : PagedDecodeFunc(std::move(attn_func), attn_kind, AttnBackendKind::kFlashInfer), plan_func_(std::move(plan_func)) {} @@ -387,14 +392,14 @@ class FlashInferPagedDecodeFunc : public PagedDecodeFunc { } private: - PackedFunc plan_func_; + ffi::Function plan_func_; std::vector> cached_buffers_; }; /*! \brief The paged prefill with tree mask attention function base class. */ class PagedPrefillTreeMaskFunc : public AttnBackendFunc { public: - explicit PagedPrefillTreeMaskFunc(PackedFunc attn_func, AttnKind attn_kind, + explicit PagedPrefillTreeMaskFunc(ffi::Function attn_func, AttnKind attn_kind, AttnBackendKind backend_kind) : AttnBackendFunc(std::move(attn_func), attn_kind, backend_kind) {} @@ -425,7 +430,7 @@ class PagedPrefillTreeMaskFunc : public AttnBackendFunc { /*! \brief The TIR-based paged prefill with tree mask attention function class. */ class TIRPagedPrefillTreeMaskFunc : public PagedPrefillTreeMaskFunc { public: - explicit TIRPagedPrefillTreeMaskFunc(PackedFunc attn_func, AttnKind attn_kind) + explicit TIRPagedPrefillTreeMaskFunc(ffi::Function attn_func, AttnKind attn_kind) : PagedPrefillTreeMaskFunc(std::move(attn_func), attn_kind, AttnBackendKind::kTIR) {} void MHA(NDArray q, NDArray qo_indptr, NDArray pages, NDArray page_indptr, NDArray page_indices, @@ -443,7 +448,7 @@ class TIRPagedPrefillTreeMaskFunc : public PagedPrefillTreeMaskFunc { /*! \brief The ragged prefill with tree mask function base class. */ class RaggedPrefillTreeMaskFunc : public AttnBackendFunc { public: - explicit RaggedPrefillTreeMaskFunc(PackedFunc attn_func, AttnKind attn_kind, + explicit RaggedPrefillTreeMaskFunc(ffi::Function attn_func, AttnKind attn_kind, AttnBackendKind backend_kind) : AttnBackendFunc(std::move(attn_func), attn_kind, backend_kind) {} @@ -473,7 +478,7 @@ class RaggedPrefillTreeMaskFunc : public AttnBackendFunc { /*! \brief The TIR-based ragged prefill with tree mask attention function class. */ class TIRRaggedPrefillTreeMaskFunc : public RaggedPrefillTreeMaskFunc { public: - explicit TIRRaggedPrefillTreeMaskFunc(PackedFunc attn_func, AttnKind attn_kind) + explicit TIRRaggedPrefillTreeMaskFunc(ffi::Function attn_func, AttnKind attn_kind) : RaggedPrefillTreeMaskFunc(std::move(attn_func), attn_kind, AttnBackendKind::kTIR) {} void MHA(NDArray q, NDArray k, NDArray v, NDArray qo_indptr, NDArray kv_indptr, @@ -489,44 +494,44 @@ class TIRRaggedPrefillTreeMaskFunc : public RaggedPrefillTreeMaskFunc { /*! * \brief Create a PagedPrefillFunc from the given arguments and the attention kind. - * \param args The arguments that contains the backend kind and the runtime attention PackedFuncs. - * \param attn_kind The attention kind of the function. - * \return The created PagedPrefillFunc pointer. + * \param args The arguments that contains the backend kind and the runtime attention + * ffi::Functions. \param attn_kind The attention kind of the function. \return The created + * PagedPrefillFunc pointer. */ std::unique_ptr ConvertPagedPrefillFunc(Array args, AttnKind attn_kind); /*! * \brief Create a PagedDecodeFunc from the given arguments and the attention kind. - * \param args The arguments that contains the backend kind and the runtime attention PackedFuncs. - * \param attn_kind The attention kind of the function. - * \return The created PagedDecodeFunc pointer. + * \param args The arguments that contains the backend kind and the runtime attention + * ffi::Functions. \param attn_kind The attention kind of the function. \return The created + * PagedDecodeFunc pointer. */ std::unique_ptr ConvertPagedDecodeFunc(Array args, AttnKind attn_kind); /*! * \brief Create a RaggedPrefillFunc from the given arguments and the attention kind. - * \param args The arguments that contains the backend kind and the runtime attention PackedFuncs. - * \param attn_kind The attention kind of the function. - * \return The created RaggedPrefillFunc pointer. + * \param args The arguments that contains the backend kind and the runtime attention + * ffi::Functions. \param attn_kind The attention kind of the function. \return The created + * RaggedPrefillFunc pointer. */ std::unique_ptr ConvertRaggedPrefillFunc(Array args, AttnKind attn_kind); /*! * \brief Create a PagedPrefillTreeMaskFunc from the given arguments and the attention kind. - * \param args The arguments that contains the backend kind and the runtime attention PackedFuncs. - * \param attn_kind The attention kind of the function. - * \return The created PagedPrefillTreeMaskFunc pointer. + * \param args The arguments that contains the backend kind and the runtime attention + * ffi::Functions. \param attn_kind The attention kind of the function. \return The created + * PagedPrefillTreeMaskFunc pointer. */ std::unique_ptr ConvertPagedPrefillTreeMaskFunc(Array args, AttnKind attn_kind); /*! * \brief Create a RaggedPrefillTreeMaskFunc from the given arguments and the attention kind. - * \param args The arguments that contains the backend kind and the runtime attention PackedFuncs. - * \param attn_kind The attention kind of the function. - * \return The created RaggedPrefillTreeMaskFunc pointer. + * \param args The arguments that contains the backend kind and the runtime attention + * ffi::Functions. \param attn_kind The attention kind of the function. \return The created + * RaggedPrefillTreeMaskFunc pointer. */ std::unique_ptr ConvertRaggedPrefillTreeMaskFunc(Array args, AttnKind attn_kind); diff --git a/src/runtime/relax_vm/builtin.cc b/src/runtime/relax_vm/builtin.cc index 340792e63b7e..29adde567ad7 100644 --- a/src/runtime/relax_vm/builtin.cc +++ b/src/runtime/relax_vm/builtin.cc @@ -326,7 +326,7 @@ TVM_REGISTER_GLOBAL("vm.builtin.check_tuple_info").set_body_typed(CheckTupleInfo */ void CheckFuncInfo(ObjectRef arg, Optional err_ctx) { // a function that lazily get context for error reporting - bool is_func = arg.as() || arg.as(); + bool is_func = arg.as() || arg.as(); CHECK(is_func) << "TypeError: " << err_ctx.value_or("") << " expect a Function but get " << arg->GetTypeKey(); } @@ -571,7 +571,7 @@ extern "C" { /*! * \brief Backend function to get anylist item and set into Packed Func call arg stack. * - * \param anylist The handle to the anylist, backed by TVMRetValue* + * \param anylist The handle to the anylist, backed by ffi::Any* * \param int The index. * \param args The args stack. * \param arg_offset The offset of argument. @@ -582,7 +582,7 @@ TVM_DLL int TVMBackendAnyListSetPackedArg(void* anylist, int index, TVMFFIAny* a /*! * \brief Backend function to get anylist item and set into Packed Func call arg stack. * - * \param anylist The handle to the anylist, backed by TVMRetValue* + * \param anylist The handle to the anylist, backed by ffi::Any* * \param int The index. */ TVM_DLL int TVMBackendAnyListResetItem(void* anylist, int index); @@ -590,7 +590,7 @@ TVM_DLL int TVMBackendAnyListResetItem(void* anylist, int index); /*! * \brief Backend function to set anylist item by moving from packed func return. * - * \param anylist The handle to the anylist, backed by TVMRetValue* + * \param anylist The handle to the anylist, backed by ffi::Any* * \param int The index. * \param args The args stack. * \param type_codes The type codes stack. diff --git a/src/runtime/relax_vm/cuda/cuda_graph_builtin.cc b/src/runtime/relax_vm/cuda/cuda_graph_builtin.cc index 60905468f6af..0c0c8eda493c 100644 --- a/src/runtime/relax_vm/cuda/cuda_graph_builtin.cc +++ b/src/runtime/relax_vm/cuda/cuda_graph_builtin.cc @@ -168,17 +168,19 @@ class CUDAGraphExtensionNode : public VMExtensionNode { packed_args[i] = tuple_args[i]; } - TVMRetValue capture_func_rv; + ffi::Any capture_func_rv; // Run the function without CUDA graph. This is a warm up step to do necessary initialization // of the CUDA module such as loading module data, setting kernel attributes. - vm->InvokeClosurePacked(capture_func, TVMArgs(packed_args.data(), nargs), &capture_func_rv); + vm->InvokeClosurePacked(capture_func, ffi::PackedArgs(packed_args.data(), nargs), + &capture_func_rv); // Run the graph in capture mode cudaGraph_t graph; { CUDACaptureStream capture_stream(&graph); - vm->InvokeClosurePacked(capture_func, TVMArgs(packed_args.data(), nargs), &capture_func_rv); + vm->InvokeClosurePacked(capture_func, ffi::PackedArgs(packed_args.data(), nargs), + &capture_func_rv); } CUDAGraphCapturedState entry; @@ -205,8 +207,8 @@ class CUDAGraphExtensionNode : public VMExtensionNode { if (auto it = alloc_cache_.find(entry_index); it != alloc_cache_.end()) { return it->second; } - TVMRetValue alloc_func_rv; - vm->InvokeClosurePacked(alloc_func, TVMArgs(nullptr, 0), &alloc_func_rv); + ffi::Any alloc_func_rv; + vm->InvokeClosurePacked(alloc_func, ffi::PackedArgs(nullptr, 0), &alloc_func_rv); ObjectRef alloc_result = alloc_func_rv.cast(); alloc_cache_[entry_index] = alloc_result; return alloc_result; @@ -240,7 +242,7 @@ class CUDAGraphExtension : public VMExtension { }; TVM_REGISTER_GLOBAL("vm.builtin.cuda_graph.run_or_capture") - .set_body_packed([](TVMArgs args, TVMRetValue* rv) { + .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { ICHECK(args.size() == 5 || args.size() == 4); VirtualMachine* vm = VirtualMachine::GetContextPtr(args[0]); auto extension = vm->GetOrCreateExtension(); @@ -255,7 +257,7 @@ TVM_REGISTER_GLOBAL("vm.builtin.cuda_graph.run_or_capture") }); TVM_REGISTER_GLOBAL("vm.builtin.cuda_graph.get_cached_alloc") - .set_body_packed([](TVMArgs args, TVMRetValue* rv) { + .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { ICHECK_EQ(args.size(), 3); VirtualMachine* vm = VirtualMachine::GetContextPtr(args[0]); auto extension = vm->GetOrCreateExtension(); diff --git a/src/runtime/relax_vm/executable.cc b/src/runtime/relax_vm/executable.cc index 5f2f9778d8e5..07d37bb4d3ea 100644 --- a/src/runtime/relax_vm/executable.cc +++ b/src/runtime/relax_vm/executable.cc @@ -322,7 +322,7 @@ void VMExecutable::LoadConstantSection(dmlc::Stream* strm) { STREAM_CHECK(strm->Read(&constant_type, sizeof(constant_type)), "constant"); if (constant_type == ConstantType::kNDArray) { ndarray.Load(strm); - TVMRetValue cell; + ffi::Any cell; cell = ndarray; this->constants.push_back(cell); } else if (constant_type == ConstantType::kShapeTuple) { @@ -332,12 +332,12 @@ void VMExecutable::LoadConstantSection(dmlc::Stream* strm) { for (size_t i = 0; i < size; ++i) { strm->Read(&(data[i])); } - TVMRetValue cell; + ffi::Any cell; cell = ShapeTuple(data); this->constants.push_back(cell); } else if (constant_type == ConstantType::kDLDataType) { strm->Read(&dtype); - TVMRetValue cell; + ffi::Any cell; cell = dtype; this->constants.push_back(cell); } else if (constant_type == ConstantType::kString) { @@ -347,19 +347,19 @@ void VMExecutable::LoadConstantSection(dmlc::Stream* strm) { for (size_t i = 0; i < size; ++i) { strm->Read(&(data[i])); } - TVMRetValue cell; + ffi::Any cell; cell = String(std::string(data.begin(), data.end())); this->constants.push_back(cell); } else if (constant_type == ConstantType::kInt) { int64_t value; strm->Read(&value); - TVMRetValue cell; + ffi::Any cell; cell = value; this->constants.push_back(cell); } else if (constant_type == ConstantType::kFloat) { double value; strm->Read(&value); - TVMRetValue cell; + ffi::Any cell; cell = value; this->constants.push_back(cell); } else { diff --git a/src/runtime/relax_vm/hexagon/builtin.cc b/src/runtime/relax_vm/hexagon/builtin.cc index 586984dfc0d2..3cfa4db71744 100644 --- a/src/runtime/relax_vm/hexagon/builtin.cc +++ b/src/runtime/relax_vm/hexagon/builtin.cc @@ -33,7 +33,7 @@ namespace runtime { namespace relax_vm { TVM_REGISTER_GLOBAL("vm.builtin.hexagon.dma_copy") - .set_body_typed([](TVMArgValue vm_ptr, NDArray src_arr, NDArray dst_arr, int queue_id, + .set_body_typed([](ffi::AnyView vm_ptr, NDArray src_arr, NDArray dst_arr, int queue_id, bool bypass_cache) { const DLTensor* dptr = dst_arr.operator->(); const DLTensor* sptr = src_arr.operator->(); @@ -55,7 +55,7 @@ TVM_REGISTER_GLOBAL("vm.builtin.hexagon.dma_copy") }); TVM_REGISTER_GLOBAL("vm.builtin.hexagon.dma_wait") - .set_body_typed([](TVMArgValue vm_ptr, int queue_id, int inflight_dma, bool bypass_cache, + .set_body_typed([](ffi::AnyView vm_ptr, int queue_id, int inflight_dma, bool bypass_cache, [[maybe_unused]] NDArray src_arr, [[maybe_unused]] NDArray dst_arr) { ICHECK(inflight_dma >= 0); tvm::runtime::hexagon::HexagonDeviceAPI::Global()->UserDMA()->Wait(queue_id, inflight_dma); diff --git a/src/runtime/relax_vm/kv_state.cc b/src/runtime/relax_vm/kv_state.cc index 0a4ce1b41f2d..1af7cf78c944 100644 --- a/src/runtime/relax_vm/kv_state.cc +++ b/src/runtime/relax_vm/kv_state.cc @@ -38,7 +38,7 @@ TVM_REGISTER_GLOBAL("vm.builtin.kv_state_remove_sequence") TVM_REGISTER_GLOBAL("vm.builtin.kv_state_fork_sequence").set_body_method(&KVStateObj::ForkSequence); TVM_REGISTER_GLOBAL("vm.builtin.kv_state_popn").set_body_method(&KVStateObj::PopN); TVM_REGISTER_GLOBAL("vm.builtin.kv_state_begin_forward") - .set_body_packed([](TVMArgs args, TVMRetValue* rv) { + .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { CHECK(args.size() == 3 || args.size() == 4) << "KVState BeginForward only accepts 3 or 4 arguments"; KVState kv_state = args[0].cast(); diff --git a/src/runtime/relax_vm/lm_support.cc b/src/runtime/relax_vm/lm_support.cc index 797f19815fb6..44079b48d1c5 100644 --- a/src/runtime/relax_vm/lm_support.cc +++ b/src/runtime/relax_vm/lm_support.cc @@ -301,7 +301,7 @@ NDArray AttentionKVCacheView(AttentionKVCacheLegacy cache, ShapeTuple shape) { } TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_view") - .set_body_packed([](TVMArgs args, TVMRetValue* rv) { + .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { CHECK(args.size() == 1 || args.size() == 2) << "ValueError: `vm.builtin.attention_kv_cache_view` expects 1 or 2 arguments, but got " << args.size() << "."; diff --git a/src/runtime/relax_vm/ndarray_cache_support.cc b/src/runtime/relax_vm/ndarray_cache_support.cc index 5d554528a673..fc6eb6bf360f 100644 --- a/src/runtime/relax_vm/ndarray_cache_support.cc +++ b/src/runtime/relax_vm/ndarray_cache_support.cc @@ -300,12 +300,12 @@ class ParamModuleNode : public runtime::ModuleNode { public: const char* type_key() const final { return "param_module"; } - PackedFunc GetFunction(const String& name, const ObjectPtr& sptr_to_self) final { + ffi::Function GetFunction(const String& name, const ObjectPtr& sptr_to_self) final { if (name == "get_params") { auto params = params_; - return PackedFunc([params](TVMArgs args, TVMRetValue* rv) { *rv = params; }); + return ffi::Function([params](ffi::PackedArgs args, ffi::Any* rv) { *rv = params; }); } else { - return PackedFunc(); + return ffi::Function(); } } diff --git a/src/runtime/relax_vm/paged_kv_cache.cc b/src/runtime/relax_vm/paged_kv_cache.cc index 73b1fc922784..c9fc851ea772 100644 --- a/src/runtime/relax_vm/paged_kv_cache.cc +++ b/src/runtime/relax_vm/paged_kv_cache.cc @@ -241,11 +241,11 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { std::vector tree_attn_mask_view_; std::vector tree_attn_mn_indptr_view_; - Optional f_transpose_append_mha_; - Optional f_transpose_append_mla_; - Optional f_transfer_kv_; - Optional f_transfer_kv_page_to_page_ = NullOpt; - PackedFunc f_compact_copy_; + Optional f_transpose_append_mha_; + Optional f_transpose_append_mla_; + Optional f_transfer_kv_; + Optional f_transfer_kv_page_to_page_ = NullOpt; + ffi::Function f_compact_copy_; std::unique_ptr f_attention_prefill_ragged_; std::unique_ptr f_attention_prefill_; std::unique_ptr f_attention_decode_; @@ -254,10 +254,10 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { std::unique_ptr f_attention_prefill_with_tree_mask_paged_kv_; std::unique_ptr f_attention_prefill_with_tree_mask_; std::unique_ptr f_mla_prefill_; - Array f_merge_inplace_; - PackedFunc f_split_rotary_; - PackedFunc f_copy_single_page_; - Optional f_debug_get_kv_; + Array f_merge_inplace_; + ffi::Function f_split_rotary_; + ffi::Function f_copy_single_page_; + Optional f_debug_get_kv_; /*! \brief The device this PagedKVCache runs on. */ Device device_; @@ -277,16 +277,17 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { int64_t num_total_pages, int64_t prefill_chunk_size, bool support_sliding_window, RoPEMode rope_mode, double rotary_scale, double rotary_theta, Optional rope_ext_factors, bool enable_kv_transfer, DLDataType dtype, Device device, - Optional f_transpose_append_mha, Optional f_transpose_append_mla, - PackedFunc f_compact_copy, std::unique_ptr f_attention_prefill_ragged, + Optional f_transpose_append_mha, + Optional f_transpose_append_mla, ffi::Function f_compact_copy, + std::unique_ptr f_attention_prefill_ragged, std::unique_ptr f_attention_prefill, std::unique_ptr f_attention_decode, std::unique_ptr f_attention_prefill_sliding_window, std::unique_ptr f_attention_decode_sliding_window, std::unique_ptr f_attention_prefill_with_tree_mask_paged_kv, std::unique_ptr f_attention_prefill_with_tree_mask, - std::unique_ptr f_mla_prefill, Array f_merge_inplace, - PackedFunc f_split_rotary, PackedFunc f_copy_single_page, PackedFunc f_debug_get_kv) + std::unique_ptr f_mla_prefill, Array f_merge_inplace, + ffi::Function f_split_rotary, ffi::Function f_copy_single_page, ffi::Function f_debug_get_kv) : page_size_(page_size), num_layers_(num_layers), layer_id_begin_offset_(layer_id_begin_offset), @@ -2311,8 +2312,8 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create") double rotary_theta = args[10].cast(); Optional rope_ext_factors = NullOpt; // args[11] NDArray init = args[12].cast(); - Optional f_transpose_append_mha = NullOpt; // args[13] - Optional f_transpose_append_mla = NullOpt; // args[14] + Optional f_transpose_append_mha = NullOpt; // args[13] + Optional f_transpose_append_mla = NullOpt; // args[14] std::unique_ptr f_attention_prefill_ragged = ConvertRaggedPrefillFunc(args[15].cast>(), AttnKind::kMHA); std::unique_ptr f_attention_prefill = @@ -2329,17 +2330,17 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create") ConvertRaggedPrefillTreeMaskFunc(args[21].cast>(), AttnKind::kMHA); std::unique_ptr f_mla_prefill = ConvertPagedPrefillFunc(args[22].cast>(), AttnKind::kMLA); - Array f_merge_inplace = args[23].cast>(); - PackedFunc f_split_rotary = args[24].cast(); - PackedFunc f_copy_single_page = args[25].cast(); - PackedFunc f_debug_get_kv = args[26].cast(); - PackedFunc f_compact_copy = args[27].cast(); + Array f_merge_inplace = args[23].cast>(); + ffi::Function f_split_rotary = args[24].cast(); + ffi::Function f_copy_single_page = args[25].cast(); + ffi::Function f_debug_get_kv = args[26].cast(); + ffi::Function f_compact_copy = args[27].cast(); if (auto opt_nd = args[11].as()) { rope_ext_factors = opt_nd.value(); } - auto f_convert_optional_packed_func = [&args](int arg_idx) -> Optional { - if (auto opt_func = args[arg_idx].as()) { + auto f_convert_optional_packed_func = [&args](int arg_idx) -> Optional { + if (auto opt_func = args[arg_idx].as()) { return opt_func.value(); } return NullOpt; @@ -2366,7 +2367,7 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create") num_total_pages += reserved_num_seqs * 2; } // NOTE: We will remove this legacy construction after finishing the transition phase. - // Some `PackedFunc()` here are placeholders that will be filled. + // Some `ffi::Function()` here are placeholders that will be filled. ObjectPtr n = make_object( page_size, num_layers, layer_id_begin_offset, layer_id_end_offset, num_qo_heads, num_kv_heads, qk_head_dim, v_head_dim, attn_kinds_vec, reserved_num_seqs, num_total_pages, diff --git a/src/runtime/relax_vm/rnn_state.cc b/src/runtime/relax_vm/rnn_state.cc index bcc8fb74cf4a..9468a50d2071 100644 --- a/src/runtime/relax_vm/rnn_state.cc +++ b/src/runtime/relax_vm/rnn_state.cc @@ -138,7 +138,7 @@ class RNNStateImpObj : public RNNStateObj { * \note Each state data per layer may have different dtype and shape, so we use a * different function for each state data. */ - Array f_gets_; + Array f_gets_; /*! * \brief The function to set the state data to the storage. * The function signature is `f_set_(state, seq_slot_ids, history_slot_ids, data, max_history)`. @@ -149,16 +149,16 @@ class RNNStateImpObj : public RNNStateObj { * \note Each state data per layer may have different dtype and shape, so we use a * different function for each state data. */ - Array f_sets_; + Array f_sets_; public: /*! \brief Constructor. Take the cache configuration and initialize the NDArrays. */ - explicit RNNStateImpObj(int64_t num_layers, // - int64_t reserved_num_seqs, // - int64_t max_history, // - DLDevice device, // - Array f_gets, // - Array f_sets, // + explicit RNNStateImpObj(int64_t num_layers, // + int64_t reserved_num_seqs, // + int64_t max_history, // + DLDevice device, // + Array f_gets, // + Array f_sets, // Array init_layer_value) : num_layers_(num_layers), reserved_num_seqs_(reserved_num_seqs), @@ -465,11 +465,11 @@ TVM_REGISTER_OBJECT_TYPE(RNNStateImpObj); //------------------------------------------------- TVM_REGISTER_GLOBAL("vm.builtin.rnn_state_create") - .set_body_typed([](int64_t num_layers, // - int64_t reserved_num_seqs, // - int64_t max_history, // - Array f_gets, // - Array f_sets, // + .set_body_typed([](int64_t num_layers, // + int64_t reserved_num_seqs, // + int64_t max_history, // + Array f_gets, // + Array f_sets, // Array init_layer_value) { CHECK_GT(num_layers, 0) << "The number of layers should be greater than 0."; CHECK_GT(reserved_num_seqs, 0) diff --git a/src/runtime/relax_vm/vm.cc b/src/runtime/relax_vm/vm.cc index 5f9440d7f718..c92509923b7e 100644 --- a/src/runtime/relax_vm/vm.cc +++ b/src/runtime/relax_vm/vm.cc @@ -39,7 +39,7 @@ namespace relax_vm { //--------------------------------------------- TVM_REGISTER_OBJECT_TYPE(VMClosureObj); -VMClosure::VMClosure(String func_name, PackedFunc impl) { +VMClosure::VMClosure(String func_name, ffi::Function impl) { auto ptr = make_object(); ptr->func_name = func_name; ptr->impl = std::move(impl); @@ -47,13 +47,13 @@ VMClosure::VMClosure(String func_name, PackedFunc impl) { } /*! - * \brief Create another PackedFunc with last arguments already bound to last_args. - * \param func The input func, can be a VMClosure or PackedFunc. + * \brief Create another ffi::Function with last arguments already bound to last_args. + * \param func The input func, can be a VMClosure or ffi::Function. * \param last_args The arguments to bound to in the end of the function. * \note The new function takes in arguments and append the last_args in the end. */ -PackedFunc VMClosure::BindLastArgs(PackedFunc func, std::vector last_args) { - return PackedFunc([func, last_args](TVMArgs args, TVMRetValue* rv) { +ffi::Function VMClosure::BindLastArgs(ffi::Function func, std::vector last_args) { + return ffi::Function([func, last_args](ffi::PackedArgs args, ffi::Any* rv) { std::vector packed_args(args.size() + last_args.size()); std::copy(args.data(), args.data() + args.size(), packed_args.data()); for (size_t i = 0; i < last_args.size(); ++i) { @@ -68,7 +68,7 @@ PackedFunc VMClosure::BindLastArgs(PackedFunc func, std::vector last_args) //----------------------------------------------------------- // Use the args after `starting_arg_idx` as a series of indices into `obj`, // indexing into nested Array and returning the final indexed object. -Any IndexIntoNestedObject(Any obj, TVMArgs args, int starting_arg_idx) { +Any IndexIntoNestedObject(Any obj, ffi::PackedArgs args, int starting_arg_idx) { for (int i = starting_arg_idx; i < args.size(); i++) { // the object must be an Array to be able to index into it if (!obj.as()) { @@ -110,7 +110,7 @@ Any ConvertObjectToDevice(Any src, const Device& dev, Allocator* alloc) { } } -TVMRetValue ConvertArgToDevice(AnyView input, Device dev, Allocator* alloc) { +ffi::Any ConvertArgToDevice(AnyView input, Device dev, Allocator* alloc) { // in terms of memory-behavior. // To be extra careful, we copy DLTensor. // The developer can still explicitly allocate NDArray @@ -130,7 +130,7 @@ TVMRetValue ConvertArgToDevice(AnyView input, Device dev, Allocator* alloc) { return ret; } -TVMRetValue ConvertRegToDevice(TVMRetValue input, Device dev, Allocator* alloc) { +ffi::Any ConvertRegToDevice(ffi::Any input, Device dev, Allocator* alloc) { Any ret; if (auto opt_obj = input.as()) { ret = ConvertObjectToDevice(opt_obj.value(), dev, alloc); @@ -146,7 +146,7 @@ TVMRetValue ConvertRegToDevice(TVMRetValue input, Device dev, Allocator* alloc) /*! * \brief The register type. */ -using RegType = TVMRetValue; +using RegType = ffi::Any; /*! * \brief A representation of a stack frame. @@ -162,7 +162,7 @@ struct VMFrame { std::vector register_file; /*! \brief Register in caller's frame to put return value */ RegName caller_return_register; - // The following fields are used for PackedFunc call within + // The following fields are used for ffi::Function call within // a single function scope. The space is reused across multiple // packed func calls to increase cache locality and avoid re-allocation /*! \brief Temporary argument value stack for packed func call. */ @@ -200,25 +200,25 @@ class VirtualMachineImpl : public VirtualMachine { VMClosure GetClosure(const String& func_name) final { return this->GetClosureInternal(func_name, false).value(); } - void InvokeClosurePacked(const ObjectRef& closure_or_packedfunc, TVMArgs args, - TVMRetValue* rv) final; - void SetInstrument(PackedFunc instrument) final { this->instrument_ = instrument; } + void InvokeClosurePacked(const ObjectRef& closure_or_packedfunc, ffi::PackedArgs args, + ffi::Any* rv) final; + void SetInstrument(ffi::Function instrument) final { this->instrument_ = instrument; } //--------------------------------------------------- // Functions in the vtable of Module //--------------------------------------------------- - void _Init(TVMArgs args, TVMRetValue* rv); - void _SaveClosure(TVMArgs args, TVMRetValue* rv); - void _InvokeClosure(TVMArgs args, TVMRetValue* rv); + void _Init(ffi::PackedArgs args, ffi::Any* rv); + void _SaveClosure(ffi::PackedArgs args, ffi::Any* rv); + void _InvokeClosure(ffi::PackedArgs args, ffi::Any* rv); void _InvokeClosureStateful(std::string func_name); - void _SetInstrument(TVMArgs args, TVMRetValue* rv); - void _GetOutputArity(TVMArgs args, TVMRetValue* rv); - void _GetOutput(TVMArgs args, TVMRetValue* rv); - void _SetInputWithoutParamModule(TVMArgs args, TVMRetValue* rv); - void _SetInputWithParamModule(TVMArgs args, TVMRetValue* rv); + void _SetInstrument(ffi::PackedArgs args, ffi::Any* rv); + void _GetOutputArity(ffi::PackedArgs args, ffi::Any* rv); + void _GetOutput(ffi::PackedArgs args, ffi::Any* rv); + void _SetInputWithoutParamModule(ffi::PackedArgs args, ffi::Any* rv); + void _SetInputWithParamModule(ffi::PackedArgs args, ffi::Any* rv); int _GetFunctionArity(std::string func_name); std::string _GetFunctionParamName(std::string func_name, int index); - PackedFunc _LookupFunction(const String& name); + ffi::Function _LookupFunction(const String& name); TVM_MODULE_VTABLE_BEGIN("relax.VirtualMachine"); TVM_MODULE_VTABLE_ENTRY_PACKED("vm_initialization", &VirtualMachineImpl::_Init); @@ -257,7 +257,7 @@ class VirtualMachineImpl : public VirtualMachine { * the arguments to DLTensor, which is supported in RPC where remote could only have a minimal C * runtime. */ - void SetInput(std::string func_name, bool with_param_module, TVMArgs args); + void SetInput(std::string func_name, bool with_param_module, ffi::PackedArgs args); /*! * \brief Look up whether the VM has a function by the given name. @@ -285,7 +285,7 @@ class VirtualMachineImpl : public VirtualMachine { * \note This function is used by RPC server to help benchmarking. */ void SaveClosure(const String& func_name, const String& save_name, bool include_return, - TVMArgs args); + ffi::PackedArgs args); /*! * \brief Internal function to invoke a closure. * \param closure_or_packed The closure to be invoked. @@ -306,14 +306,14 @@ class VirtualMachineImpl : public VirtualMachine { /*! * \brief Get function by querying all of the current module's imports. * \param name The name of the function. - * \return The result function, can return PackedFunc(nullptr) if nothing is found. + * \return The result function, can return ffi::Function(nullptr) if nothing is found. */ - PackedFunc GetFuncFromImports(const String& name) { + ffi::Function GetFuncFromImports(const String& name) { for (auto& lib : this->imports_) { - PackedFunc func = lib->GetFunction(name, true); + ffi::Function func = lib->GetFunction(name, true); if (func.defined()) return func; } - return PackedFunc(nullptr); + return ffi::Function(nullptr); } /*! * \brief Initialize function pool. @@ -423,11 +423,11 @@ class VirtualMachineImpl : public VirtualMachine { /*! \brief The loaded executable. */ ObjectPtr exec_; /*! \brief The global constant pool */ - std::vector const_pool_; + std::vector const_pool_; /*! * \brief Function pool to cache functions in func_table */ - std::vector func_pool_; + std::vector func_pool_; //-------------------------------------------------------- // Executor interface support //-------------------------------------------------------- @@ -455,7 +455,7 @@ class VirtualMachineImpl : public VirtualMachine { /*! \brief The special return register. */ RegType return_value_; /*!\ brief instrument function. */ - PackedFunc instrument_ = nullptr; + ffi::Function instrument_ = nullptr; }; void VirtualMachineImpl::LoadExecutable(ObjectPtr exec) { @@ -503,7 +503,8 @@ RegType VirtualMachineImpl::LookupVMOutput(const std::string& func_name) { return outputs_[func_name]; } -void VirtualMachineImpl::SetInput(std::string func_name, bool with_param_module, TVMArgs args) { +void VirtualMachineImpl::SetInput(std::string func_name, bool with_param_module, + ffi::PackedArgs args) { const auto& m = exec_->func_map; if (m.find(func_name) != m.end()) { Index gf_idx = m.at(func_name); @@ -529,16 +530,16 @@ void VirtualMachineImpl::SetInput(std::string func_name, bool with_param_module, //------------------------------------------ // Closure handling //------------------------------------------ -void VirtualMachineImpl::InvokeClosurePacked(const ObjectRef& closure_or_packedfunc, TVMArgs args, - TVMRetValue* rv) { +void VirtualMachineImpl::InvokeClosurePacked(const ObjectRef& closure_or_packedfunc, + ffi::PackedArgs args, ffi::Any* rv) { // run packed call if it is a packed func. - if (auto* packed = closure_or_packedfunc.as()) { + if (auto* packed = closure_or_packedfunc.as()) { packed->CallPacked(args.data(), args.size(), rv); return; } // run closure call. auto* clo = closure_or_packedfunc.as(); - ICHECK(clo != nullptr) << "Function expects a closure or PackedFunc "; + ICHECK(clo != nullptr) << "Function expects a closure or ffi::Function "; std::vector packed_args(args.size() + 1); // per convention, ctx ptr must be VirtualMachine* casted to void. @@ -556,7 +557,7 @@ void VirtualMachineImpl::InvokeClosurePacked(const ObjectRef& closure_or_packedf RegType VirtualMachineImpl::InvokeClosureInternal(const ObjectRef& closure_or_packed, const std::vector& args) { RegType ret; - auto* packed = closure_or_packed.as(); + auto* packed = closure_or_packed.as(); auto* clo = closure_or_packed.as(); int clo_offset = clo != nullptr ? 1 : 0; @@ -579,16 +580,16 @@ RegType VirtualMachineImpl::InvokeClosureInternal(const ObjectRef& closure_or_pa } void VirtualMachineImpl::SaveClosure(const String& func_name, const String& save_name, - bool include_return, TVMArgs args) { + bool include_return, ffi::PackedArgs args) { VMClosure clo = this->GetClosure(func_name); std::vector inputs(args.size()); for (int i = 0; i < args.size(); ++i) { inputs[i] = ConvertArgToDevice(args[i], this->devices[0], this->allocators[0]); } - PackedFunc impl = VMClosure::BindLastArgs(clo->impl, inputs); + ffi::Function impl = VMClosure::BindLastArgs(clo->impl, inputs); if (!include_return) { - impl = PackedFunc([impl](TVMArgs args, TVMRetValue* rv) { - TVMRetValue temp; + impl = ffi::Function([impl](ffi::PackedArgs args, ffi::Any* rv) { + ffi::Any temp; impl.CallPacked(args, &temp); }); } @@ -613,7 +614,7 @@ Optional VirtualMachineImpl::GetClosureInternal(const String& func_na if (finfo.kind == VMFuncInfo::FuncKind::kVMFunc) { // NOTE: should not capture strong ref to self and avoid cyclic ref. - auto impl = PackedFunc([gf_idx](TVMArgs args, TVMRetValue* rv) { + auto impl = ffi::Function([gf_idx](ffi::PackedArgs args, ffi::Any* rv) { // Per convention, ctx ptr is a VirtualMachine* VirtualMachine* ctx_ptr = static_cast(args[0].cast()); @@ -627,17 +628,17 @@ Optional VirtualMachineImpl::GetClosureInternal(const String& func_na } else { ICHECK(finfo.kind == VMFuncInfo::FuncKind::kVMTIRFunc) << "Cannot support closure with function kind " << static_cast(finfo.kind); - PackedFunc tir_func = GetFuncFromImports("__vmtir__" + finfo.name); + ffi::Function tir_func = GetFuncFromImports("__vmtir__" + finfo.name); ICHECK(tir_func != nullptr) << "Cannot find underlying compiled tir function of VMTIRFunc " << finfo.name; - auto impl = PackedFunc([this, finfo, tir_func](TVMArgs args, TVMRetValue* rv) { + auto impl = ffi::Function([this, finfo, tir_func](ffi::PackedArgs args, ffi::Any* rv) { // Per convention, ctx ptr is a VirtualMachine* VirtualMachine* ctx_ptr = static_cast(args[0].cast()); ICHECK(ctx_ptr == this); ICHECK_EQ(args.size() - 1, finfo.num_args) << "Function " << finfo.name << " expects " << finfo.num_args << " arguments"; ICHECK_GE(finfo.register_file_size, finfo.num_args + 1); - std::vector reg_file(finfo.register_file_size); + std::vector reg_file(finfo.register_file_size); for (int64_t i = 0; i < finfo.num_args; ++i) { reg_file[i] = args[i + 1]; } @@ -703,14 +704,14 @@ void VirtualMachineImpl::InitFuncPool() { const VMFuncInfo& info = exec_->func_table[func_index]; if (info.kind == VMFuncInfo::FuncKind::kPackedFunc) { // only look through imports first - PackedFunc func = GetFuncFromImports(info.name); + ffi::Function func = GetFuncFromImports(info.name); if (!func.defined()) { const auto p_func = tvm::ffi::Function::GetGlobal(info.name); if (p_func.has_value()) func = *(p_func); } ICHECK(func.defined()) - << "Error: Cannot find PackedFunc " << info.name - << " in either Relax VM kernel library, or in TVM runtime PackedFunc registry, or in " + << "Error: Cannot find ffi::Function " << info.name + << " in either Relax VM kernel library, or in TVM runtime ffi::Function registry, or in " "global Relax functions of the VM executable"; func_pool_[func_index] = func; @@ -761,7 +762,7 @@ void VirtualMachineImpl::RunInstrCall(VMFrame* curr_frame, Instruction instr) { } } ffi::PackedArgs args(call_args.data() + args_begin_offset, instr.num_args); - TVMRetValue ret; + ffi::Any ret; ICHECK_LT(static_cast(instr.func_idx), this->func_pool_.size()); @@ -858,7 +859,7 @@ ObjectPtr VirtualMachine::Create() { return make_object devices; std::vector alloc_types; @@ -872,13 +873,13 @@ void VirtualMachineImpl::_Init(TVMArgs args, TVMRetValue* rv) { this->Init(devices, alloc_types); } -void VirtualMachineImpl::_SaveClosure(TVMArgs args, TVMRetValue* rv) { +void VirtualMachineImpl::_SaveClosure(ffi::PackedArgs args, ffi::Any* rv) { ICHECK_GE(args.size(), 3); std::string func_name = args[0].cast(); this->SaveClosure(func_name, args[1].cast(), args[2].cast(), args.Slice(3)); } -void VirtualMachineImpl::_InvokeClosure(TVMArgs args, TVMRetValue* rv) { +void VirtualMachineImpl::_InvokeClosure(ffi::PackedArgs args, ffi::Any* rv) { this->InvokeClosurePacked(args[0].cast(), args.Slice(1), rv); } @@ -896,20 +897,20 @@ void VirtualMachineImpl::_InvokeClosureStateful(std::string func_name) { inputs_[func_name]); } -void VirtualMachineImpl::_SetInstrument(TVMArgs args, TVMRetValue* rv) { +void VirtualMachineImpl::_SetInstrument(ffi::PackedArgs args, ffi::Any* rv) { if (args[0].as()) { - this->SetInstrument(args[0].cast()); + this->SetInstrument(args[0].cast()); } else { String func_name = args[0].cast(); const auto factory = tvm::ffi::Function::GetGlobal(func_name); CHECK(factory.has_value()) << "Cannot find factory " << func_name; - TVMRetValue rv; + ffi::Any rv; factory->CallPacked(args.Slice(1), &rv); - this->SetInstrument(rv.cast()); + this->SetInstrument(rv.cast()); } } -void VirtualMachineImpl::_GetOutputArity(TVMArgs args, TVMRetValue* rv) { +void VirtualMachineImpl::_GetOutputArity(ffi::PackedArgs args, ffi::Any* rv) { std::string func_name = args[0].cast(); RegType out = LookupVMOutput(func_name); Any obj = IndexIntoNestedObject(out, args, 1); @@ -920,7 +921,7 @@ void VirtualMachineImpl::_GetOutputArity(TVMArgs args, TVMRetValue* rv) { } } -void VirtualMachineImpl::_GetOutput(TVMArgs args, TVMRetValue* rv) { +void VirtualMachineImpl::_GetOutput(ffi::PackedArgs args, ffi::Any* rv) { std::string func_name = args[0].cast(); RegType out = LookupVMOutput(func_name); Any obj = IndexIntoNestedObject(out, args, 1); @@ -932,12 +933,12 @@ void VirtualMachineImpl::_GetOutput(TVMArgs args, TVMRetValue* rv) { *rv = obj; } -void VirtualMachineImpl::_SetInputWithoutParamModule(TVMArgs args, TVMRetValue* rv) { +void VirtualMachineImpl::_SetInputWithoutParamModule(ffi::PackedArgs args, ffi::Any* rv) { std::string func_name = args[0].cast(); this->SetInput(func_name, false, args.Slice(1)); } -void VirtualMachineImpl::_SetInputWithParamModule(TVMArgs args, TVMRetValue* rv) { +void VirtualMachineImpl::_SetInputWithParamModule(ffi::PackedArgs args, ffi::Any* rv) { std::string func_name = args[0].cast(); this->SetInput(func_name, true, args.Slice(1)); } @@ -956,16 +957,16 @@ std::string VirtualMachineImpl::_GetFunctionParamName(std::string func_name, int return vm_func.param_names[index]; } -PackedFunc VirtualMachineImpl::_LookupFunction(const String& name) { +ffi::Function VirtualMachineImpl::_LookupFunction(const String& name) { if (Optional opt = this->GetClosureInternal(name, true)) { - return PackedFunc( - [clo = opt.value(), _self = GetRef(this)](TVMArgs args, TVMRetValue* rv) -> void { - auto* self = const_cast(_self.as()); - ICHECK(self); - self->InvokeClosurePacked(clo, args, rv); - }); + return ffi::Function([clo = opt.value(), _self = GetRef(this)](ffi::PackedArgs args, + ffi::Any* rv) -> void { + auto* self = const_cast(_self.as()); + ICHECK(self); + self->InvokeClosurePacked(clo, args, rv); + }); } - return PackedFunc(nullptr); + return ffi::Function(nullptr); } //---------------------------------------------------------------- @@ -979,9 +980,9 @@ PackedFunc VirtualMachineImpl::_LookupFunction(const String& name) { */ class VirtualMachineProfiler : public VirtualMachineImpl { public: - PackedFunc GetFunction(const String& name, const ObjectPtr& sptr_to_self) override { + ffi::Function GetFunction(const String& name, const ObjectPtr& sptr_to_self) override { if (name == "profile") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + return ffi::Function([sptr_to_self, this](ffi::PackedArgs args, ffi::Any* rv) { std::string f_name = args[0].cast(); VMClosure clo = this->GetClosure(f_name); diff --git a/src/runtime/rocm/rocm_device_api.cc b/src/runtime/rocm/rocm_device_api.cc index afd3a55c22c1..67991717552e 100644 --- a/src/runtime/rocm/rocm_device_api.cc +++ b/src/runtime/rocm/rocm_device_api.cc @@ -37,7 +37,7 @@ namespace runtime { class ROCMDeviceAPI final : public DeviceAPI { public: void SetDevice(Device dev) final { ROCM_CALL(hipSetDevice(dev.device_id)); } - void GetAttr(Device device, DeviceAttrKind kind, TVMRetValue* rv) final { + void GetAttr(Device device, DeviceAttrKind kind, ffi::Any* rv) final { int value = 0; switch (kind) { case kExist: { @@ -251,12 +251,12 @@ ROCMThreadEntry::ROCMThreadEntry() : pool(kDLROCM, ROCMDeviceAPI::Global()) {} ROCMThreadEntry* ROCMThreadEntry::ThreadLocal() { return ROCMThreadStore::Get(); } -TVM_REGISTER_GLOBAL("device_api.rocm").set_body_packed([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("device_api.rocm").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { DeviceAPI* ptr = ROCMDeviceAPI::Global(); *rv = static_cast(ptr); }); -TVM_REGISTER_GLOBAL("device_api.rocm_host").set_body_packed([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("device_api.rocm_host").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { DeviceAPI* ptr = ROCMDeviceAPI::Global(); *rv = static_cast(ptr); }); diff --git a/src/runtime/rocm/rocm_module.cc b/src/runtime/rocm/rocm_module.cc index 96b5caa186bb..44c7483624e6 100644 --- a/src/runtime/rocm/rocm_module.cc +++ b/src/runtime/rocm/rocm_module.cc @@ -66,7 +66,7 @@ class ROCMModuleNode : public runtime::ModuleNode { int GetPropertyMask() const final { return ModulePropertyMask::kBinarySerializable | ModulePropertyMask::kRunnable; } - PackedFunc GetFunction(const String& name, const ObjectPtr& sptr_to_self) final; + ffi::Function GetFunction(const String& name, const ObjectPtr& sptr_to_self) final; void SaveToFile(const String& file_name, const String& format) final { std::string fmt = GetFileFormat(file_name, format); @@ -157,7 +157,8 @@ class ROCMWrappedFunc { launch_param_config_.Init(num_void_args, launch_param_tags); } // invoke the function with void arguments - void operator()(TVMArgs args, TVMRetValue* rv, void* packed_args, size_t packed_nbytes) const { + void operator()(ffi::PackedArgs args, ffi::Any* rv, void* packed_args, + size_t packed_nbytes) const { int device_id; ROCM_CALL(hipGetDevice(&device_id)); if (fcache_[device_id] == nullptr) { @@ -190,11 +191,12 @@ class ROCMWrappedFunc { LaunchParamConfig launch_param_config_; }; -PackedFunc ROCMModuleNode::GetFunction(const String& name, const ObjectPtr& sptr_to_self) { +ffi::Function ROCMModuleNode::GetFunction(const String& name, + const ObjectPtr& sptr_to_self) { ICHECK_EQ(sptr_to_self.get(), this); ICHECK_NE(name, symbol::tvm_module_main) << "Device function do not have main"; auto it = fmap_.find(name); - if (it == fmap_.end()) return PackedFunc(); + if (it == fmap_.end()) return ffi::Function(); const FunctionInfo& info = it->second; ROCMWrappedFunc f; f.Init(this, sptr_to_self, name, info.arg_types.size(), info.launch_param_tags); diff --git a/src/runtime/rpc/rpc_channel.h b/src/runtime/rpc/rpc_channel.h index 114bc0a2e7bd..62af2d92a8ac 100644 --- a/src/runtime/rpc/rpc_channel.h +++ b/src/runtime/rpc/rpc_channel.h @@ -69,7 +69,7 @@ class CallbackChannel final : public RPCChannel { * \param frecv The recv function, takes an expected maximum size, and return * a byte array with the actual amount of data received. */ - explicit CallbackChannel(PackedFunc fsend, PackedFunc frecv) + explicit CallbackChannel(ffi::Function fsend, ffi::Function frecv) : fsend_(std::move(fsend)), frecv_(std::move(frecv)) {} ~CallbackChannel() {} @@ -90,8 +90,8 @@ class CallbackChannel final : public RPCChannel { size_t Recv(void* data, size_t size) final; private: - PackedFunc fsend_; - PackedFunc frecv_; + ffi::Function fsend_; + ffi::Function frecv_; }; } // namespace runtime diff --git a/src/runtime/rpc/rpc_device_api.cc b/src/runtime/rpc/rpc_device_api.cc index 729cc10f0ba1..710965d07824 100644 --- a/src/runtime/rpc/rpc_device_api.cc +++ b/src/runtime/rpc/rpc_device_api.cc @@ -38,7 +38,7 @@ class RPCDeviceAPI final : public DeviceAPI { GetSess(dev)->GetDeviceAPI(remote_dev)->SetDevice(remote_dev); } - void GetAttr(Device dev, DeviceAttrKind kind, TVMRetValue* rv) final { + void GetAttr(Device dev, DeviceAttrKind kind, ffi::Any* rv) final { auto remote_dev = RemoveRPCSessionMask(dev); GetSess(dev)->GetDeviceAPI(remote_dev)->GetAttr(remote_dev, kind, rv); } @@ -150,7 +150,7 @@ class RPCDeviceAPI final : public DeviceAPI { } }; -TVM_REGISTER_GLOBAL("device_api.rpc").set_body_packed([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("device_api.rpc").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { static RPCDeviceAPI inst; DeviceAPI* ptr = &inst; *rv = static_cast(ptr); diff --git a/src/runtime/rpc/rpc_endpoint.cc b/src/runtime/rpc/rpc_endpoint.cc index 47fa2e3b4420..23edfa9bb520 100644 --- a/src/runtime/rpc/rpc_endpoint.cc +++ b/src/runtime/rpc/rpc_endpoint.cc @@ -394,9 +394,9 @@ class RPCEndpoint::EventHandler : public dmlc::Stream { /*! * \brief Receive incoming packed seq from the stream. * \return The received argments. - * \note The TVMArgs is available until we switchstate. + * \note The ffi::PackedArgs is available until we switchstate. */ - TVMArgs RecvPackedSeq() { + ffi::PackedArgs RecvPackedSeq() { TVMValue* values; int* tcodes; int num_args; @@ -425,7 +425,7 @@ class RPCEndpoint::EventHandler : public dmlc::Stream { * \brief Return a packed sequence to the remote. * \param args The arguments. */ - void ReturnPackedSeq(TVMArgs args) { + void ReturnPackedSeq(ffi::PackedArgs args) { // Legacy ABI translation // TODO(tqchen): remove this once we have upgraded to new ABI TVMValue* values = this->ArenaAlloc(args.size()); @@ -440,7 +440,7 @@ class RPCEndpoint::EventHandler : public dmlc::Stream { * \param setreturn The function to encode return. */ void HandleReturn(RPCCode code, RPCSession::FEncodeReturn setreturn) { - TVMArgs args = RecvPackedSeq(); + ffi::PackedArgs args = RecvPackedSeq(); if (code == RPCCode::kException) { // switch to the state before sending exception. this->SwitchToState(kRecvPacketNumBytes); @@ -483,8 +483,8 @@ class RPCEndpoint::EventHandler : public dmlc::Stream { fcopyack(data_ptr, data_bytes); } else { char* temp_data = this->ArenaAlloc(data_bytes); - auto on_copy_complete = [this, elem_bytes, data_bytes, temp_data, fcopyack](RPCCode status, - TVMArgs args) { + auto on_copy_complete = [this, elem_bytes, data_bytes, temp_data, fcopyack]( + RPCCode status, ffi::PackedArgs args) { if (status == RPCCode::kException) { this->ReturnException(args[0].cast()); this->SwitchToState(kRecvPacketNumBytes); @@ -528,7 +528,7 @@ class RPCEndpoint::EventHandler : public dmlc::Stream { dmlc::ByteSwap(temp_data, elem_bytes, data_bytes / elem_bytes); } - auto on_copy_complete = [this](RPCCode status, TVMArgs args) { + auto on_copy_complete = [this](RPCCode status, ffi::PackedArgs args) { if (status == RPCCode::kException) { this->ReturnException(args[0].cast()); this->SwitchToState(kRecvPacketNumBytes); @@ -548,11 +548,11 @@ class RPCEndpoint::EventHandler : public dmlc::Stream { uint64_t call_handle; this->Read(&call_handle); - TVMArgs args = RecvPackedSeq(); + ffi::PackedArgs args = RecvPackedSeq(); this->SwitchToState(kWaitForAsyncCallback); GetServingSession()->AsyncCallFunc(reinterpret_cast(call_handle), args, - [this](RPCCode status, TVMArgs args) { + [this](RPCCode status, ffi::PackedArgs args) { if (status == RPCCode::kException) { this->ReturnException(args[0].cast()); } else { @@ -571,7 +571,7 @@ class RPCEndpoint::EventHandler : public dmlc::Stream { client_protocol_ver.resize(len); this->Read(dmlc::BeginPtr(client_protocol_ver), len); - TVMArgs args = RecvPackedSeq(); + ffi::PackedArgs args = RecvPackedSeq(); try { ICHECK(serving_session_ == nullptr) << "Server has already been initialized"; @@ -595,7 +595,7 @@ class RPCEndpoint::EventHandler : public dmlc::Stream { auto fconstructor = tvm::ffi::Function::GetGlobal(constructor_name); ICHECK(fconstructor.has_value()) << " Cannot find session constructor " << constructor_name; - TVMRetValue con_ret; + ffi::Any con_ret; try { fconstructor->CallPacked(constructor_args, &con_ret); @@ -622,20 +622,21 @@ class RPCEndpoint::EventHandler : public dmlc::Stream { } void HandleSyscallStreamSync() { - TVMArgs args = RecvPackedSeq(); + ffi::PackedArgs args = RecvPackedSeq(); try { auto dev = args[0].cast(); TVMStreamHandle handle = args[1].cast(); this->SwitchToState(kWaitForAsyncCallback); - GetServingSession()->AsyncStreamWait(dev, handle, [this](RPCCode status, TVMArgs args) { - if (status == RPCCode::kException) { - this->ReturnException(args[0].cast()); - } else { - this->ReturnVoid(); - } - this->SwitchToState(kRecvPacketNumBytes); - }); + GetServingSession()->AsyncStreamWait(dev, handle, + [this](RPCCode status, ffi::PackedArgs args) { + if (status == RPCCode::kException) { + this->ReturnException(args[0].cast()); + } else { + this->ReturnVoid(); + } + this->SwitchToState(kRecvPacketNumBytes); + }); } catch (const std::exception& e) { this->ReturnException(e.what()); this->SwitchToState(kRecvPacketNumBytes); @@ -645,9 +646,9 @@ class RPCEndpoint::EventHandler : public dmlc::Stream { // Handler for special syscalls that have a specific RPCCode. template void SysCallHandler(F f) { - TVMArgs args = RecvPackedSeq(); + ffi::PackedArgs args = RecvPackedSeq(); try { - TVMRetValue rv; + ffi::Any rv; f(GetServingSession(), args, &rv); AnyView packed_args[1]; packed_args[0] = rv; @@ -740,7 +741,7 @@ void RPCEndpoint::Init() { handler_ = std::make_shared(&reader_, &writer_, name_, &remote_key_, flush_writer); // Quick function to for syscall remote. - syscall_remote_ = PackedFunc([this](TVMArgs all_args, TVMRetValue* rv) { + syscall_remote_ = ffi::Function([this](ffi::PackedArgs all_args, ffi::Any* rv) { std::lock_guard lock(mutex_); RPCCode code = static_cast(all_args[0].cast()); ffi::PackedArgs args = all_args.Slice(1); @@ -759,7 +760,7 @@ void RPCEndpoint::Init() { handler_->Write(code); handler_->SendPackedSeq(values, tcodes, args.size(), true); - code = HandleUntilReturnEvent(true, [rv](TVMArgs args) { + code = HandleUntilReturnEvent(true, [rv](ffi::PackedArgs args) { ICHECK_EQ(args.size(), 1); *rv = args[0]; }); @@ -778,7 +779,7 @@ void RPCEndpoint::Init() { */ std::shared_ptr RPCEndpoint::Create(std::unique_ptr channel, std::string name, std::string remote_key, - TypedPackedFunc fcleanup) { + ffi::TypedFunction fcleanup) { std::shared_ptr endpt = std::make_shared(); endpt->channel_ = std::move(channel); endpt->name_ = std::move(name); @@ -816,8 +817,8 @@ void RPCEndpoint::ServerLoop() { if (const auto f = tvm::ffi::Function::GetGlobal("tvm.rpc.server.start")) { (*f)(); } - TVMRetValue rv; - ICHECK(HandleUntilReturnEvent(false, [](TVMArgs) {}) == RPCCode::kShutdown); + ffi::Any rv; + ICHECK(HandleUntilReturnEvent(false, [](ffi::PackedArgs) {}) == RPCCode::kShutdown); if (const auto f = tvm::ffi::Function::GetGlobal("tvm.rpc.server.shutdown")) { (*f)(); } @@ -829,7 +830,7 @@ int RPCEndpoint::ServerAsyncIOEventHandler(const std::string& in_bytes, int even RPCCode code = RPCCode::kNone; if (in_bytes.length() != 0) { reader_.Write(in_bytes.c_str(), in_bytes.length()); - code = handler_->HandleNextEvent(false, true, [](TVMArgs) {}); + code = handler_->HandleNextEvent(false, true, [](ffi::PackedArgs) {}); } if ((event_flag & 2) != 0 && writer_.bytes_available() != 0) { writer_.ReadWithCallback( @@ -842,7 +843,7 @@ int RPCEndpoint::ServerAsyncIOEventHandler(const std::string& in_bytes, int even return 1; } -void RPCEndpoint::InitRemoteSession(TVMArgs args) { +void RPCEndpoint::InitRemoteSession(ffi::PackedArgs args) { std::lock_guard lock(mutex_); RPCCode code = RPCCode::kInitServer; std::string protocol_ver = kRPCProtocolVer; @@ -865,7 +866,7 @@ void RPCEndpoint::InitRemoteSession(TVMArgs args) { handler_->WriteArray(protocol_ver.data(), length); handler_->SendPackedSeq(values, tcodes, args.size(), true); - code = HandleUntilReturnEvent(true, [](TVMArgs args) {}); + code = HandleUntilReturnEvent(true, [](ffi::PackedArgs args) {}); ICHECK(code == RPCCode::kReturn) << "code=" << static_cast(code); } @@ -914,7 +915,7 @@ void RPCEndpoint::CopyToRemote(void* from_bytes, DLTensor* to, uint64_t nbytes) RPCReference::SendDLTensor(handler_, to); handler_->Write(nbytes); handler_->WriteArray(reinterpret_cast(from_bytes), nbytes); - ICHECK(HandleUntilReturnEvent(true, [](TVMArgs) {}) == RPCCode::kReturn); + ICHECK(HandleUntilReturnEvent(true, [](ffi::PackedArgs) {}) == RPCCode::kReturn); } void RPCEndpoint::CopyFromRemote(DLTensor* from, void* to_bytes, uint64_t nbytes) { @@ -933,29 +934,29 @@ void RPCEndpoint::CopyFromRemote(DLTensor* from, void* to_bytes, uint64_t nbytes handler_->Write(code); RPCReference::SendDLTensor(handler_, from); handler_->Write(nbytes); - ICHECK(HandleUntilReturnEvent(true, [](TVMArgs) {}) == RPCCode::kCopyAck); + ICHECK(HandleUntilReturnEvent(true, [](ffi::PackedArgs) {}) == RPCCode::kCopyAck); handler_->ReadArray(reinterpret_cast(to_bytes), nbytes); handler_->FinishCopyAck(); } // SysCallEventHandler functions -void RPCGetGlobalFunc(RPCSession* handler, TVMArgs args, TVMRetValue* rv) { +void RPCGetGlobalFunc(RPCSession* handler, ffi::PackedArgs args, ffi::Any* rv) { auto name = args[0].cast(); *rv = handler->GetFunction(name); } -void RPCFreeHandle(RPCSession* handler, TVMArgs args, TVMRetValue* rv) { +void RPCFreeHandle(RPCSession* handler, ffi::PackedArgs args, ffi::Any* rv) { void* handle = args[0].cast(); handler->FreeHandle(handle); } -void RPCDevSetDevice(RPCSession* handler, TVMArgs args, TVMRetValue* rv) { +void RPCDevSetDevice(RPCSession* handler, ffi::PackedArgs args, ffi::Any* rv) { auto dev = args[0].cast(); handler->GetDeviceAPI(dev)->SetDevice(dev); } -void RPCDevGetAttr(RPCSession* handler, TVMArgs args, TVMRetValue* rv) { +void RPCDevGetAttr(RPCSession* handler, ffi::PackedArgs args, ffi::Any* rv) { auto dev = args[0].cast(); DeviceAttrKind kind = static_cast(args[1].cast()); if (kind == kExist) { @@ -970,7 +971,7 @@ void RPCDevGetAttr(RPCSession* handler, TVMArgs args, TVMRetValue* rv) { } } -void RPCDevAllocData(RPCSession* handler, TVMArgs args, TVMRetValue* rv) { +void RPCDevAllocData(RPCSession* handler, ffi::PackedArgs args, ffi::Any* rv) { auto dev = args[0].cast(); uint64_t nbytes = args[1].cast(); uint64_t alignment = args[2].cast(); @@ -979,7 +980,7 @@ void RPCDevAllocData(RPCSession* handler, TVMArgs args, TVMRetValue* rv) { *rv = data; } -void RPCDevAllocDataWithScope(RPCSession* handler, TVMArgs args, TVMRetValue* rv) { +void RPCDevAllocDataWithScope(RPCSession* handler, ffi::PackedArgs args, ffi::Any* rv) { auto arr = args[0].cast(); Device dev = arr->device; int ndim = arr->ndim; @@ -990,13 +991,13 @@ void RPCDevAllocDataWithScope(RPCSession* handler, TVMArgs args, TVMRetValue* rv *rv = data; } -void RPCDevFreeData(RPCSession* handler, TVMArgs args, TVMRetValue* rv) { +void RPCDevFreeData(RPCSession* handler, ffi::PackedArgs args, ffi::Any* rv) { auto dev = args[0].cast(); void* ptr = args[1].cast(); handler->GetDeviceAPI(dev)->FreeDataSpace(dev, ptr); } -void RPCCopyAmongRemote(RPCSession* handler, TVMArgs args, TVMRetValue* rv) { +void RPCCopyAmongRemote(RPCSession* handler, ffi::PackedArgs args, ffi::Any* rv) { auto from = args[0].cast(); auto to = args[1].cast(); TVMStreamHandle stream = args[2].cast(); @@ -1011,25 +1012,25 @@ void RPCCopyAmongRemote(RPCSession* handler, TVMArgs args, TVMRetValue* rv) { handler->GetDeviceAPI(dev)->CopyDataFromTo(from, to, stream); } -void RPCDevCreateStream(RPCSession* handler, TVMArgs args, TVMRetValue* rv) { +void RPCDevCreateStream(RPCSession* handler, ffi::PackedArgs args, ffi::Any* rv) { auto dev = args[0].cast(); void* data = handler->GetDeviceAPI(dev)->CreateStream(dev); *rv = data; } -void RPCDevFreeStream(RPCSession* handler, TVMArgs args, TVMRetValue* rv) { +void RPCDevFreeStream(RPCSession* handler, ffi::PackedArgs args, ffi::Any* rv) { auto dev = args[0].cast(); TVMStreamHandle stream = args[1].cast(); handler->GetDeviceAPI(dev)->FreeStream(dev, stream); } -void RPCDevSetStream(RPCSession* handler, TVMArgs args, TVMRetValue* rv) { +void RPCDevSetStream(RPCSession* handler, ffi::PackedArgs args, ffi::Any* rv) { auto dev = args[0].cast(); TVMStreamHandle stream = args[1].cast(); handler->GetDeviceAPI(dev)->SetStream(dev, stream); } -void RPCDevGetCurrentStream(RPCSession* handler, TVMArgs args, TVMRetValue* rv) { +void RPCDevGetCurrentStream(RPCSession* handler, ffi::PackedArgs args, ffi::Any* rv) { auto dev = args[0].cast(); *rv = handler->GetDeviceAPI(dev)->GetCurrentStream(dev); } @@ -1162,7 +1163,7 @@ class RPCClientSession : public RPCSession, public DeviceAPI { void SetDevice(Device dev) final { endpoint_->SysCallRemote(RPCCode::kDevSetDevice, dev); } - void GetAttr(Device dev, DeviceAttrKind kind, TVMRetValue* rv) final { + void GetAttr(Device dev, DeviceAttrKind kind, ffi::Any* rv) final { if (dev.device_type == kDLCPU && kind == kExist) { // cpu always exists. *rv = 1; @@ -1241,7 +1242,7 @@ class RPCClientSession : public RPCSession, public DeviceAPI { if (rpc_func == nullptr) { rpc_chunk_max_size_bytes_ = (int64_t)kRPCMaxTransferSizeBytesDefault; } else { - CallFunc(rpc_func, ffi::PackedArgs(nullptr, 0), [this](TVMArgs args) { + CallFunc(rpc_func, ffi::PackedArgs(nullptr, 0), [this](ffi::PackedArgs args) { // Use args[1] as return value, args[0] is tcode // Look at RPCWrappedFunc in src/runtime/rpc/rpc_module.cc rpc_chunk_max_size_bytes_ = args[1].cast(); diff --git a/src/runtime/rpc/rpc_endpoint.h b/src/runtime/rpc/rpc_endpoint.h index 7da1340a7ef4..a420e6d92f41 100644 --- a/src/runtime/rpc/rpc_endpoint.h +++ b/src/runtime/rpc/rpc_endpoint.h @@ -128,7 +128,7 @@ class RPCEndpoint { * * \param session_constructor_args Optional sequence of the remote sesssion constructor. */ - void InitRemoteSession(TVMArgs session_constructor_args); + void InitRemoteSession(ffi::PackedArgs session_constructor_args); /*! * \brief Call into remote function @@ -168,7 +168,7 @@ class RPCEndpoint { * \return The returned remote value. */ template - inline TVMRetValue SysCallRemote(RPCCode fcode, Args&&... args); + inline ffi::Any SysCallRemote(RPCCode fcode, Args&&... args); /*! * \brief Create a RPC session with given channel. * \param channel The communication channel. @@ -180,7 +180,7 @@ class RPCEndpoint { */ static std::shared_ptr Create(std::unique_ptr channel, std::string name, std::string remote_key, - TypedPackedFunc fcleanup = nullptr); + ffi::TypedFunction fcleanup = nullptr); private: class EventHandler; @@ -199,13 +199,13 @@ class RPCEndpoint { // Event handler. std::shared_ptr handler_; // syscall remote with specified function code. - PackedFunc syscall_remote_; + ffi::Function syscall_remote_; // The name of the session. std::string name_; // The remote key std::string remote_key_; // Invoked when the RPC session is terminated - TypedPackedFunc fcleanup_; + ffi::TypedFunction fcleanup_; }; /*! @@ -217,7 +217,7 @@ std::shared_ptr CreateClientSession(std::shared_ptr end // implementation of inline functions template -inline TVMRetValue RPCEndpoint::SysCallRemote(RPCCode code, Args&&... args) { +inline ffi::Any RPCEndpoint::SysCallRemote(RPCCode code, Args&&... args) { return syscall_remote_(static_cast(code), std::forward(args)...); } diff --git a/src/runtime/rpc/rpc_event_impl.cc b/src/runtime/rpc/rpc_event_impl.cc index a47d5cd16c1d..97d62cd586fc 100644 --- a/src/runtime/rpc/rpc_event_impl.cc +++ b/src/runtime/rpc/rpc_event_impl.cc @@ -31,13 +31,14 @@ namespace tvm { namespace runtime { -PackedFunc CreateEventDrivenServer(PackedFunc fsend, std::string name, std::string remote_key) { - static PackedFunc frecv( - [](TVMArgs args, TVMRetValue* rv) { LOG(FATAL) << "Do not allow explicit receive"; }); +ffi::Function CreateEventDrivenServer(ffi::Function fsend, std::string name, + std::string remote_key) { + static ffi::Function frecv( + [](ffi::PackedArgs args, ffi::Any* rv) { LOG(FATAL) << "Do not allow explicit receive"; }); auto ch = std::make_unique(fsend, frecv); std::shared_ptr sess = RPCEndpoint::Create(std::move(ch), name, remote_key); - return PackedFunc([sess](TVMArgs args, TVMRetValue* rv) { + return ffi::Function([sess](ffi::PackedArgs args, ffi::Any* rv) { int ret = sess->ServerAsyncIOEventHandler(args[0].cast(), args[1].cast()); *rv = ret; }); diff --git a/src/runtime/rpc/rpc_local_session.cc b/src/runtime/rpc/rpc_local_session.cc index d924be327da0..5761828876e1 100644 --- a/src/runtime/rpc/rpc_local_session.cc +++ b/src/runtime/rpc/rpc_local_session.cc @@ -43,7 +43,7 @@ RPCSession::PackedFuncHandle LocalSession::GetFunction(const std::string& name) } } -void LocalSession::EncodeReturn(TVMRetValue rv, const FEncodeReturn& encode_return) { +void LocalSession::EncodeReturn(ffi::Any rv, const FEncodeReturn& encode_return) { AnyView packed_args[3]; // NOTE: this is the place that we need to handle special RPC-related // ABI convention for return value passing that is built on top of Any FFI. diff --git a/src/runtime/rpc/rpc_local_session.h b/src/runtime/rpc/rpc_local_session.h index 8625fe1f0430..4019552ebcd1 100644 --- a/src/runtime/rpc/rpc_local_session.h +++ b/src/runtime/rpc/rpc_local_session.h @@ -64,7 +64,7 @@ class LocalSession : public RPCSession { * \param rv The return value. * \param encode_return The encoding function. */ - void EncodeReturn(TVMRetValue rv, const FEncodeReturn& encode_return); + void EncodeReturn(ffi::Any rv, const FEncodeReturn& encode_return); }; } // namespace runtime diff --git a/src/runtime/rpc/rpc_module.cc b/src/runtime/rpc/rpc_module.cc index 8f9606100421..c50c92ee995a 100644 --- a/src/runtime/rpc/rpc_module.cc +++ b/src/runtime/rpc/rpc_module.cc @@ -45,7 +45,7 @@ namespace runtime { * \param handle A pointer valid on the remote end which should form the `data` field of the * underlying DLTensor. * \param template_tensor An empty DLTensor whose shape and dtype fields are used to fill the newly - * created array. Needed because it's difficult to pass a shape vector as a PackedFunc arg. + * created array. Needed because it's difficult to pass a shape vector as a ffi::Function arg. * \param dev Remote device used with this tensor. Must have non-zero RPCSessMask. * \param remote_ndarray_handle The handle returned by RPC server to identify the NDArray. */ @@ -74,13 +74,13 @@ NDArray NDArrayFromRemoteOpaqueHandle(std::shared_ptr sess, void* ha } /*! - * \brief A wrapped remote function as a PackedFunc. + * \brief A wrapped remote function as a ffi::Function. */ class RPCWrappedFunc : public Object { public: RPCWrappedFunc(void* handle, std::shared_ptr sess) : handle_(handle), sess_(sess) {} - void operator()(TVMArgs args, TVMRetValue* rv) const { + void operator()(ffi::PackedArgs args, ffi::Any* rv) const { std::vector packed_args(args.size()); std::vector> temp_dltensors; @@ -130,7 +130,7 @@ class RPCWrappedFunc : public Object { } } } - auto set_return = [this, rv](TVMArgs args) { this->WrapRemoteReturnToValue(args, rv); }; + auto set_return = [this, rv](ffi::PackedArgs args) { this->WrapRemoteReturnToValue(args, rv); }; sess_->CallFunc(handle_, ffi::PackedArgs(packed_args.data(), packed_args.size()), set_return); } @@ -149,9 +149,9 @@ class RPCWrappedFunc : public Object { std::shared_ptr sess_; // unwrap a remote value to the underlying handle. - void* UnwrapRemoteValueToHandle(const TVMArgValue& arg) const; + void* UnwrapRemoteValueToHandle(const ffi::AnyView& arg) const; // wrap a remote return via Set - void WrapRemoteReturnToValue(TVMArgs args, TVMRetValue* rv) const; + void WrapRemoteReturnToValue(ffi::PackedArgs args, ffi::Any* rv) const; // remove a remote session mask Device RemoveSessMask(Device dev) const { @@ -185,9 +185,9 @@ class RPCModuleNode final : public ModuleNode { /*! \brief Get the property of the runtime module .*/ int GetPropertyMask() const final { return ModulePropertyMask::kRunnable; } - PackedFunc GetFunction(const String& name, const ObjectPtr& sptr_to_self) final { + ffi::Function GetFunction(const String& name, const ObjectPtr& sptr_to_self) final { if (name == "CloseRPCConnection") { - return PackedFunc([this](TVMArgs, TVMRetValue*) { sess_->Shutdown(); }); + return ffi::Function([this](ffi::PackedArgs, ffi::Any*) { sess_->Shutdown(); }); } if (module_handle_ == nullptr) { @@ -203,10 +203,10 @@ class RPCModuleNode final : public ModuleNode { throw; } - PackedFunc GetTimeEvaluator(const std::string& name, Device dev, int number, int repeat, - int min_repeat_ms, int limit_zero_time_iterations, - int cooldown_interval_ms, int repeats_to_cooldown, - int cache_flush_bytes, const std::string& f_preproc_name) { + ffi::Function GetTimeEvaluator(const std::string& name, Device dev, int number, int repeat, + int min_repeat_ms, int limit_zero_time_iterations, + int cooldown_interval_ms, int repeats_to_cooldown, + int cache_flush_bytes, const std::string& f_preproc_name) { InitRemoteFunc(&remote_get_time_evaluator_, "runtime.RPCTimeEvaluator"); // Remove session mask because we pass dev by parts. ICHECK_EQ(GetRPCSessionIndex(dev), sess_->table_index()) @@ -249,10 +249,11 @@ class RPCModuleNode final : public ModuleNode { *func = WrapRemoteFunc(handle); } - PackedFunc WrapRemoteFunc(RPCSession::PackedFuncHandle handle) { - if (handle == nullptr) return PackedFunc(); + ffi::Function WrapRemoteFunc(RPCSession::PackedFuncHandle handle) { + if (handle == nullptr) return ffi::Function(); auto wf = std::make_shared(handle, sess_); - return PackedFunc([wf](TVMArgs args, TVMRetValue* rv) { return wf->operator()(args, rv); }); + return ffi::Function( + [wf](ffi::PackedArgs args, ffi::Any* rv) { return wf->operator()(args, rv); }); } // The module handle @@ -260,15 +261,15 @@ class RPCModuleNode final : public ModuleNode { // The local channel std::shared_ptr sess_; // remote function to get time evaluator - TypedPackedFunc, std::string, int, int, int, int, int, int, int, int, - int, std::string)> + ffi::TypedFunction, std::string, int, int, int, int, int, int, int, + int, int, std::string)> remote_get_time_evaluator_; // remote function getter for modules. - TypedPackedFunc remote_mod_get_function_; + ffi::TypedFunction remote_mod_get_function_; // remote function getter for load module - TypedPackedFunc remote_load_module_; + ffi::TypedFunction remote_load_module_; // remote function getter for load module - TypedPackedFunc remote_import_module_; + ffi::TypedFunction remote_import_module_; }; void* RPCWrappedFunc::UnwrapRemoteValueToHandle(const AnyView& arg) const { @@ -288,7 +289,7 @@ void* RPCWrappedFunc::UnwrapRemoteValueToHandle(const AnyView& arg) const { } } -void RPCWrappedFunc::WrapRemoteReturnToValue(TVMArgs args, TVMRetValue* rv) const { +void RPCWrappedFunc::WrapRemoteReturnToValue(ffi::PackedArgs args, ffi::Any* rv) const { int tcode = args[0].cast(); // TODO(tqchen): move to RPC to new ABI if (tcode == kTVMNullptr) { @@ -298,7 +299,8 @@ void RPCWrappedFunc::WrapRemoteReturnToValue(TVMArgs args, TVMRetValue* rv) cons ICHECK_EQ(args.size(), 2); void* handle = args[1].cast(); auto wf = std::make_shared(handle, sess_); - *rv = PackedFunc([wf](TVMArgs args, TVMRetValue* rv) { return wf->operator()(args, rv); }); + *rv = ffi::Function( + [wf](ffi::PackedArgs args, ffi::Any* rv) { return wf->operator()(args, rv); }); } else if (tcode == kTVMModuleHandle) { ICHECK_EQ(args.size(), 2); void* handle = args[1].cast(); @@ -378,7 +380,7 @@ inline void CPUCacheFlushImpl(const char* addr, unsigned int len) { #endif } -inline void CPUCacheFlush(int begin_index, const TVMArgs& args) { +inline void CPUCacheFlush(int begin_index, const ffi::PackedArgs& args) { for (int i = begin_index; i < args.size(); i++) { CPUCacheFlushImpl(static_cast((args[i].cast()->data)), GetDataSize(*(args[i].cast()))); @@ -402,14 +404,14 @@ TVM_REGISTER_GLOBAL("runtime.RPCTimeEvaluator") limit_zero_time_iterations, cooldown_interval_ms, repeats_to_cooldown, cache_flush_bytes, f_preproc_name); } else { - PackedFunc f_preproc; + ffi::Function f_preproc; if (!f_preproc_name.empty()) { auto pf_preproc = tvm::ffi::Function::GetGlobal(f_preproc_name); ICHECK(pf_preproc.has_value()) << "Cannot find " << f_preproc_name << " in the global function"; f_preproc = *pf_preproc; } - PackedFunc pf = m.GetFunction(name, true); + ffi::Function pf = m.GetFunction(name, true); CHECK(pf != nullptr) << "Cannot find " << name << "` in the global registry"; return profiling::WrapTimeEvaluator(pf, dev, number, repeat, min_repeat_ms, limit_zero_time_iterations, cooldown_interval_ms, @@ -418,7 +420,7 @@ TVM_REGISTER_GLOBAL("runtime.RPCTimeEvaluator") } else { auto pf = tvm::ffi::Function::GetGlobal(name); ICHECK(pf.has_value()) << "Cannot find " << name << " in the global function"; - PackedFunc f_preproc; + ffi::Function f_preproc; if (!f_preproc_name.empty()) { auto pf_preproc = tvm::ffi::Function::GetGlobal(f_preproc_name); ICHECK(pf_preproc.has_value()) @@ -432,7 +434,7 @@ TVM_REGISTER_GLOBAL("runtime.RPCTimeEvaluator") }); TVM_REGISTER_GLOBAL("cache_flush_cpu_non_first_arg") - .set_body_packed([](TVMArgs args, TVMRetValue* rv) { CPUCacheFlush(1, args); }); + .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { CPUCacheFlush(1, args); }); // server function registration. TVM_REGISTER_GLOBAL("tvm.rpc.server.ImportModule").set_body_typed([](Module parent, Module child) { @@ -457,7 +459,7 @@ TVM_REGISTER_GLOBAL("rpc.ImportRemoteModule").set_body_typed([](Module parent, M static_cast(parent.operator->())->ImportModule(child); }); -TVM_REGISTER_GLOBAL("rpc.SessTableIndex").set_body_packed([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("rpc.SessTableIndex").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { Module m = args[0].cast(); std::string tkey = m->type_key(); ICHECK_EQ(tkey, "rpc"); diff --git a/src/runtime/rpc/rpc_pipe_impl.cc b/src/runtime/rpc/rpc_pipe_impl.cc index f36b70f64162..25472de72777 100644 --- a/src/runtime/rpc/rpc_pipe_impl.cc +++ b/src/runtime/rpc/rpc_pipe_impl.cc @@ -108,11 +108,11 @@ Module CreatePipeClient(std::vector cmd) { auto endpt = RPCEndpoint::Create(std::make_unique(parent_read, parent_write, pid), "pipe", "pipe"); - endpt->InitRemoteSession(TVMArgs(nullptr, 0)); + endpt->InitRemoteSession(ffi::PackedArgs(nullptr, 0)); return CreateRPCSessionModule(CreateClientSession(endpt)); } -TVM_REGISTER_GLOBAL("rpc.CreatePipeClient").set_body_packed([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("rpc.CreatePipeClient").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { std::vector cmd; for (int i = 0; i < args.size(); ++i) { cmd.push_back(args[i].cast()); diff --git a/src/runtime/rpc/rpc_server_env.cc b/src/runtime/rpc/rpc_server_env.cc index d7529115f14d..823fa232a953 100644 --- a/src/runtime/rpc/rpc_server_env.cc +++ b/src/runtime/rpc/rpc_server_env.cc @@ -35,24 +35,27 @@ std::string RPCGetPath(const std::string& name) { return (*f)(name).cast(); } -TVM_REGISTER_GLOBAL("tvm.rpc.server.upload").set_body_packed([](TVMArgs args, TVMRetValue* rv) { - std::string file_name = RPCGetPath(args[0].cast()); - auto data = args[1].cast(); - SaveBinaryToFile(file_name, data); -}); - -TVM_REGISTER_GLOBAL("tvm.rpc.server.download").set_body_packed([](TVMArgs args, TVMRetValue* rv) { - std::string file_name = RPCGetPath(args[0].cast()); - std::string data; - LoadBinaryFromFile(file_name, &data); - LOG(INFO) << "Download " << file_name << "... nbytes=" << data.size(); - *rv = ffi::Bytes(data); -}); - -TVM_REGISTER_GLOBAL("tvm.rpc.server.remove").set_body_packed([](TVMArgs args, TVMRetValue* rv) { - std::string file_name = RPCGetPath(args[0].cast()); - RemoveFile(file_name); -}); +TVM_REGISTER_GLOBAL("tvm.rpc.server.upload") + .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { + std::string file_name = RPCGetPath(args[0].cast()); + auto data = args[1].cast(); + SaveBinaryToFile(file_name, data); + }); + +TVM_REGISTER_GLOBAL("tvm.rpc.server.download") + .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { + std::string file_name = RPCGetPath(args[0].cast()); + std::string data; + LoadBinaryFromFile(file_name, &data); + LOG(INFO) << "Download " << file_name << "... nbytes=" << data.size(); + *rv = ffi::Bytes(data); + }); + +TVM_REGISTER_GLOBAL("tvm.rpc.server.remove") + .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { + std::string file_name = RPCGetPath(args[0].cast()); + RemoveFile(file_name); + }); } // namespace runtime } // namespace tvm diff --git a/src/runtime/rpc/rpc_session.cc b/src/runtime/rpc/rpc_session.cc index e6f6852a0906..76e07e00fb2b 100644 --- a/src/runtime/rpc/rpc_session.cc +++ b/src/runtime/rpc/rpc_session.cc @@ -43,7 +43,7 @@ void RPCSession::AsyncCallFunc(PackedFuncHandle func, ffi::PackedArgs packed_arg FAsyncCallback callback) { try { this->CallFunc(func, packed_args, - [&callback](TVMArgs args) { callback(RPCCode::kReturn, args); }); + [&callback](ffi::PackedArgs args) { callback(RPCCode::kReturn, args); }); } catch (const std::exception& e) { this->SendException(callback, e.what()); } diff --git a/src/runtime/rpc/rpc_session.h b/src/runtime/rpc/rpc_session.h index 3fcc1eb3dbc4..271e26dfd04e 100644 --- a/src/runtime/rpc/rpc_session.h +++ b/src/runtime/rpc/rpc_session.h @@ -47,7 +47,7 @@ namespace runtime { */ class RPCSession { public: - /*! \brief PackedFunc Handle in the remote. */ + /*! \brief ffi::Function Handle in the remote. */ using PackedFuncHandle = void*; /*! \brief Module handle in the remote. */ @@ -62,13 +62,13 @@ class RPCSession { * \param encode_args The arguments that we can encode the return values into. * * Encoding convention (as list of arguments): - * - str/float/int/byte: [tcode: int, value: TVMValue] value follows PackedFunc convention. - * - PackedFunc/Module: [tcode: int, handle: void*] + * - str/float/int/byte: [tcode: int, value: TVMValue] value follows ffi::Function convention. + * - ffi::Function/Module: [tcode: int, handle: void*] * - NDArray: [tcode: int, meta: DLTensor*, nd_handle: void*] * DLTensor* contains the meta-data as well as handle into the remote data. * nd_handle can be used for deletion. */ - using FEncodeReturn = std::function; + using FEncodeReturn = std::function; /*! * \brief Callback to send an encoded return values via encode_args. @@ -76,7 +76,7 @@ class RPCSession { * \param status The return status, can be RPCCode::kReturn or RPCCode::kException. * \param encode_args The arguments that we can encode the return values into. */ - using FAsyncCallback = std::function; + using FAsyncCallback = std::function; /*! \brief Destructor.*/ virtual ~RPCSession() {} @@ -93,9 +93,9 @@ class RPCSession { * * Calling convention: * - * - type_code is follows the PackedFunc convention. - * - int/float/string/bytes follows the PackedFunc convention, all data are local. - * - PackedFunc/Module and future remote objects: pass remote handle instead. + * - type_code is follows the ffi::Function convention. + * - int/float/string/bytes follows the ffi::Function convention, all data are local. + * - ffi::Function/Module and future remote objects: pass remote handle instead. * - NDArray/DLTensor: pass a DLTensor pointer, the data field of DLTensor * points to a remote data handle returned by the Device API. * The meta-data of the DLTensor sits on local. @@ -106,7 +106,7 @@ class RPCSession { * if they want to do inplace modify and forward. * * The callee need to store the return value into ret_value. - * - PackedFunc/Module are stored as void* + * - ffi::Function/Module are stored as void* * - NDArray is stored as local NDArray, whose data field is a remote handle. * Notably the NDArray's deleter won't delete remote handle. * It is up to the user of the RPCSession to such wrapping. diff --git a/src/runtime/rpc/rpc_socket_impl.cc b/src/runtime/rpc/rpc_socket_impl.cc index 7b66725ce55f..286d143bad6c 100644 --- a/src/runtime/rpc/rpc_socket_impl.cc +++ b/src/runtime/rpc/rpc_socket_impl.cc @@ -65,7 +65,7 @@ class SockChannel final : public RPCChannel { }; std::shared_ptr RPCConnect(std::string url, int port, std::string key, - bool enable_logging, TVMArgs init_seq) { + bool enable_logging, ffi::PackedArgs init_seq) { support::TCPSocket sock; support::SockAddr addr(url.c_str(), port); sock.Create(addr.ss_family()); @@ -108,7 +108,7 @@ std::shared_ptr RPCConnect(std::string url, int port, std::string k } Module RPCClientConnect(std::string url, int port, std::string key, bool enable_logging, - TVMArgs init_seq) { + ffi::PackedArgs init_seq) { auto endpt = RPCConnect(url, port, "client:" + key, enable_logging, init_seq); return CreateRPCSessionModule(CreateClientSession(endpt)); } @@ -119,12 +119,12 @@ TVM_DLL void RPCServerLoop(int sockfd) { RPCEndpoint::Create(std::make_unique(sock), "SockServerLoop", "")->ServerLoop(); } -void RPCServerLoop(PackedFunc fsend, PackedFunc frecv) { +void RPCServerLoop(ffi::Function fsend, ffi::Function frecv) { RPCEndpoint::Create(std::make_unique(fsend, frecv), "SockServerLoop", "") ->ServerLoop(); } -TVM_REGISTER_GLOBAL("rpc.Connect").set_body_packed([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("rpc.Connect").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { auto url = args[0].cast(); int port = args[1].cast(); auto key = args[2].cast(); @@ -132,12 +132,11 @@ TVM_REGISTER_GLOBAL("rpc.Connect").set_body_packed([](TVMArgs args, TVMRetValue* *rv = RPCClientConnect(url, port, key, enable_logging, args.Slice(4)); }); -TVM_REGISTER_GLOBAL("rpc.ServerLoop").set_body_packed([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("rpc.ServerLoop").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { if (auto opt_int = args[0].as()) { RPCServerLoop(opt_int.value()); } else { - RPCServerLoop(args[0].cast(), - args[1].cast()); + RPCServerLoop(args[0].cast(), args[1].cast()); } }); diff --git a/src/runtime/static_library.cc b/src/runtime/static_library.cc index b9e63271cfa9..47633b326cba 100644 --- a/src/runtime/static_library.cc +++ b/src/runtime/static_library.cc @@ -48,9 +48,10 @@ class StaticLibraryNode final : public runtime::ModuleNode { const char* type_key() const final { return "static_library"; } - PackedFunc GetFunction(const String& name, const ObjectPtr& sptr_to_self) final { + ffi::Function GetFunction(const String& name, const ObjectPtr& sptr_to_self) final { if (name == "get_func_names") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = func_names_; }); + return ffi::Function( + [sptr_to_self, this](ffi::PackedArgs args, ffi::Any* rv) { *rv = func_names_; }); } else { return {}; } diff --git a/src/runtime/system_library.cc b/src/runtime/system_library.cc index d27c2d4c03c6..30fca708b8e8 100644 --- a/src/runtime/system_library.cc +++ b/src/runtime/system_library.cc @@ -108,11 +108,11 @@ class SystemLibModuleRegistry { std::mutex mutex_; // we need to make sure each lib map have an unique // copy through out the entire lifetime of the process - // so the cached PackedFunc in the system do not get out dated. + // so the cached ffi::Function in the system do not get out dated. std::unordered_map lib_map_; }; -TVM_REGISTER_GLOBAL("runtime.SystemLib").set_body_packed([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("runtime.SystemLib").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { std::string symbol_prefix = ""; if (args.size() != 0) { symbol_prefix = args[0].cast(); diff --git a/src/runtime/thread_pool.cc b/src/runtime/thread_pool.cc index de7f38ed280a..457f44799d7c 100644 --- a/src/runtime/thread_pool.cc +++ b/src/runtime/thread_pool.cc @@ -379,20 +379,21 @@ class ThreadPool { * \brief args[0] is the AffinityMode, args[1] is the number of threads. * args2 is a list of CPUs which is used to set the CPU affinity. */ -TVM_REGISTER_GLOBAL("runtime.config_threadpool").set_body_packed([](TVMArgs args, TVMRetValue* rv) { - threading::ThreadGroup::AffinityMode mode = - static_cast(args[0].cast()); - int nthreads = args[1].cast(); - std::vector cpus; - if (args.size() >= 3) { - auto cpu_array = args[2].cast>(); - for (auto cpu : cpu_array) { - ICHECK(IsNumber(cpu)) << "The CPU core information '" << cpu << "' is not a number."; - cpus.push_back(std::stoi(cpu)); - } - } - threading::Configure(mode, nthreads, cpus); -}); +TVM_REGISTER_GLOBAL("runtime.config_threadpool") + .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { + threading::ThreadGroup::AffinityMode mode = + static_cast(args[0].cast()); + int nthreads = args[1].cast(); + std::vector cpus; + if (args.size() >= 3) { + auto cpu_array = args[2].cast>(); + for (auto cpu : cpu_array) { + ICHECK(IsNumber(cpu)) << "The CPU core information '" << cpu << "' is not a number."; + cpus.push_back(std::stoi(cpu)); + } + } + threading::Configure(mode, nthreads, cpus); + }); TVM_REGISTER_GLOBAL("runtime.NumThreads").set_body_typed([]() -> int32_t { return threading::NumThreads(); diff --git a/src/runtime/thread_storage_scope.h b/src/runtime/thread_storage_scope.h index 972140489b90..049e6467d1fc 100644 --- a/src/runtime/thread_storage_scope.h +++ b/src/runtime/thread_storage_scope.h @@ -19,7 +19,7 @@ /*! * \file thread_storage_scope.h - * \brief Extract launch parameters configuration from TVMArgs. + * \brief Extract launch parameters configuration from ffi::PackedArgs. */ #ifndef TVM_RUNTIME_THREAD_STORAGE_SCOPE_H_ #define TVM_RUNTIME_THREAD_STORAGE_SCOPE_H_ @@ -261,7 +261,7 @@ class LaunchParamConfig { } } // extract workload from arguments. - ThreadWorkLoad Extract(TVMArgs args) const { + ThreadWorkLoad Extract(ffi::PackedArgs args) const { ThreadWorkLoad w; std::fill(w.work_size, w.work_size + 6, 1); const TVMFFIAny* raw_args = reinterpret_cast(args.data()); diff --git a/src/runtime/vulkan/vulkan_device_api.cc b/src/runtime/vulkan/vulkan_device_api.cc index 8ec8101d8305..fcb3e764bf86 100644 --- a/src/runtime/vulkan/vulkan_device_api.cc +++ b/src/runtime/vulkan/vulkan_device_api.cc @@ -91,7 +91,7 @@ int VulkanDeviceAPI::GetActiveDeviceID() { return active_device_id_per_thread.Ge VulkanDevice& VulkanDeviceAPI::GetActiveDevice() { return device(GetActiveDeviceID()); } -void VulkanDeviceAPI::GetAttr(Device dev, DeviceAttrKind kind, TVMRetValue* rv) { +void VulkanDeviceAPI::GetAttr(Device dev, DeviceAttrKind kind, ffi::Any* rv) { size_t index = static_cast(dev.device_id); if (kind == kExist) { *rv = static_cast(index < devices_.size()); @@ -177,7 +177,7 @@ void VulkanDeviceAPI::GetAttr(Device dev, DeviceAttrKind kind, TVMRetValue* rv) } } -void VulkanDeviceAPI::GetTargetProperty(Device dev, const std::string& property, TVMRetValue* rv) { +void VulkanDeviceAPI::GetTargetProperty(Device dev, const std::string& property, ffi::Any* rv) { size_t index = static_cast(dev.device_id); const auto& prop = device(index).device_properties; @@ -455,14 +455,14 @@ VulkanDevice& VulkanDeviceAPI::device(size_t device_id) { return const_cast(const_cast(this)->device(device_id)); } -TVM_REGISTER_GLOBAL("device_api.vulkan").set_body_packed([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("device_api.vulkan").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { DeviceAPI* ptr = VulkanDeviceAPI::Global(); *rv = static_cast(ptr); }); TVM_REGISTER_GLOBAL("device_api.vulkan.get_target_property") .set_body_typed([](Device dev, const std::string& property) { - TVMRetValue rv; + ffi::Any rv; VulkanDeviceAPI::Global()->GetTargetProperty(dev, property, &rv); return rv; }); diff --git a/src/runtime/vulkan/vulkan_device_api.h b/src/runtime/vulkan/vulkan_device_api.h index 35100ee62764..64ca0db701e8 100644 --- a/src/runtime/vulkan/vulkan_device_api.h +++ b/src/runtime/vulkan/vulkan_device_api.h @@ -44,7 +44,7 @@ class VulkanDeviceAPI final : public DeviceAPI { // Implement active device void SetDevice(Device dev) final; - void GetAttr(Device dev, DeviceAttrKind kind, TVMRetValue* rv) final; + void GetAttr(Device dev, DeviceAttrKind kind, ffi::Any* rv) final; // Implement memory management required by DeviceAPI void* AllocDataSpace(Device dev, size_t nbytes, size_t alignment, DLDataType type_hint) final; @@ -107,7 +107,7 @@ class VulkanDeviceAPI final : public DeviceAPI { * Returns the results of feature/property queries done during the * device initialization. */ - void GetTargetProperty(Device dev, const std::string& property, TVMRetValue* rv) final; + void GetTargetProperty(Device dev, const std::string& property, ffi::Any* rv) final; private: std::vector GetComputeQueueFamilies(VkPhysicalDevice phy_dev); diff --git a/src/runtime/vulkan/vulkan_wrapped_func.cc b/src/runtime/vulkan/vulkan_wrapped_func.cc index 15f41afe91eb..ab212c2eade4 100644 --- a/src/runtime/vulkan/vulkan_wrapped_func.cc +++ b/src/runtime/vulkan/vulkan_wrapped_func.cc @@ -42,7 +42,7 @@ void VulkanWrappedFunc::Init(VulkanModuleNode* m, ObjectPtr sptr, launch_param_config_.Init(num_buffer_args + num_pack_args, launch_param_tags); } -void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv, +void VulkanWrappedFunc::operator()(ffi::PackedArgs args, ffi::Any* rv, const ArgUnion64* pack_args) const { int device_id = VulkanDeviceAPI::Global()->GetActiveDeviceID(); auto& device = VulkanDeviceAPI::Global()->device(device_id); @@ -205,12 +205,12 @@ VulkanModuleNode::~VulkanModuleNode() { } } -PackedFunc VulkanModuleNode::GetFunction(const String& name, - const ObjectPtr& sptr_to_self) { +ffi::Function VulkanModuleNode::GetFunction(const String& name, + const ObjectPtr& sptr_to_self) { ICHECK_EQ(sptr_to_self.get(), this); ICHECK_NE(name, symbol::tvm_module_main) << "Device function do not have main"; auto it = fmap_.find(name); - if (it == fmap_.end()) return PackedFunc(); + if (it == fmap_.end()) return ffi::Function(); const FunctionInfo& info = it->second; VulkanWrappedFunc f; size_t num_buffer_args = NumBufferArgs(info.arg_types); diff --git a/src/runtime/vulkan/vulkan_wrapped_func.h b/src/runtime/vulkan/vulkan_wrapped_func.h index a983b3e70205..9b6f3703f34f 100644 --- a/src/runtime/vulkan/vulkan_wrapped_func.h +++ b/src/runtime/vulkan/vulkan_wrapped_func.h @@ -60,7 +60,7 @@ class VulkanWrappedFunc { size_t num_buffer_args, size_t num_pack_args, const std::vector& launch_param_tags); - void operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion64* pack_args) const; + void operator()(ffi::PackedArgs args, ffi::Any* rv, const ArgUnion64* pack_args) const; private: // internal module @@ -94,7 +94,7 @@ class VulkanModuleNode final : public runtime::ModuleNode { return ModulePropertyMask::kBinarySerializable | ModulePropertyMask::kRunnable; } - PackedFunc GetFunction(const String& name, const ObjectPtr& sptr_to_self) final; + ffi::Function GetFunction(const String& name, const ObjectPtr& sptr_to_self) final; std::shared_ptr GetPipeline(size_t device_id, const std::string& func_name, size_t num_pack_args); diff --git a/src/script/ir_builder/base.cc b/src/script/ir_builder/base.cc index f295b48f5158..51df91238954 100644 --- a/src/script/ir_builder/base.cc +++ b/src/script/ir_builder/base.cc @@ -36,7 +36,7 @@ void IRBuilderFrameNode::ExitWithScope() { IRBuilder::Current()->frames.pop_back(); } -void IRBuilderFrameNode::AddCallback(runtime::TypedPackedFunc callback) { +void IRBuilderFrameNode::AddCallback(ffi::TypedFunction callback) { if (IRBuilder::Current()->frames.empty()) { LOG(FATAL) << "ValueError: No frames in Builder to add callback"; } diff --git a/src/script/printer/ir_docsifier.cc b/src/script/printer/ir_docsifier.cc index 1a99448711b3..f74e08848920 100644 --- a/src/script/printer/ir_docsifier.cc +++ b/src/script/printer/ir_docsifier.cc @@ -103,7 +103,7 @@ void IRDocsifierNode::RemoveVar(const ObjectRef& obj) { } void IRDocsifierNode::SetCommonPrefix(const ObjectRef& root, - runtime::TypedPackedFunc is_var) { + ffi::TypedFunction is_var) { class Visitor : public AttrVisitor { public: inline void operator()(ObjectRef obj) { Visit("", &obj); } @@ -179,7 +179,7 @@ void IRDocsifierNode::SetCommonPrefix(const ObjectRef& root, std::unordered_set visited_; public: - runtime::TypedPackedFunc is_var; + ffi::TypedFunction is_var; std::unordered_map> common_prefix; }; Visitor visitor; diff --git a/src/support/array.h b/src/support/array.h index 057e1668b383..6fd30503f016 100644 --- a/src/support/array.h +++ b/src/support/array.h @@ -165,7 +165,7 @@ struct AsVectorImpl { template struct AsVectorImpl { inline std::vector operator()(const Array& array) const { - TVMRetValue ret_value; + ffi::Any ret_value; ret_value = array; Array as_int_vec = ret_value.cast>(); @@ -180,7 +180,7 @@ struct AsVectorImpl { template struct AsVectorImpl { inline std::vector operator()(const Array& array) const { - TVMRetValue ret_value; + ffi::Any ret_value; ret_value = array; Array as_int_vec = ret_value.cast>(); @@ -195,7 +195,7 @@ struct AsVectorImpl { template struct AsVectorImpl { inline std::vector operator()(const Array& array) const { - TVMRetValue ret_value; + ffi::Any ret_value; ret_value = array; Array as_int_vec = ret_value.cast>(); @@ -228,7 +228,7 @@ struct AsArrayImpl { Array result; result.reserve(vec.size()); for (auto x : vec) { - TVMRetValue ret_value; + ffi::Any ret_value; ret_value = x; result.push_back(ret_value.cast()); } @@ -242,7 +242,7 @@ struct AsArrayImpl { Array result; result.reserve(vec.size()); for (auto x : vec) { - TVMRetValue ret_value; + ffi::Any ret_value; ret_value = x; result.push_back(ret_value.cast()); } @@ -256,7 +256,7 @@ struct AsArrayImpl { Array result; result.reserve(vec.size()); for (auto x : vec) { - TVMRetValue ret_value; + ffi::Any ret_value; ret_value = x; result.push_back(ret_value.cast()); } diff --git a/src/support/ffi_testing.cc b/src/support/ffi_testing.cc index 9e727f06d4c2..4482272ced53 100644 --- a/src/support/ffi_testing.cc +++ b/src/support/ffi_testing.cc @@ -54,15 +54,15 @@ struct TestAttrs : public AttrsNode { TVM_REGISTER_NODE_TYPE(TestAttrs); TVM_REGISTER_GLOBAL("testing.test_wrap_callback") - .set_body_packed([](TVMArgs args, TVMRetValue* ret) { - PackedFunc pf = args[0].cast(); - *ret = runtime::TypedPackedFunc([pf]() { pf(); }); + .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { + ffi::Function pf = args[0].cast(); + *ret = ffi::TypedFunction([pf]() { pf(); }); }); TVM_REGISTER_GLOBAL("testing.test_wrap_callback_suppress_err") - .set_body_packed([](TVMArgs args, TVMRetValue* ret) { - PackedFunc pf = args[0].cast(); - auto result = runtime::TypedPackedFunc([pf]() { + .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { + ffi::Function pf = args[0].cast(); + auto result = ffi::TypedFunction([pf]() { try { pf(); } catch (std::exception& err) { @@ -72,13 +72,12 @@ TVM_REGISTER_GLOBAL("testing.test_wrap_callback_suppress_err") }); TVM_REGISTER_GLOBAL("testing.test_check_eq_callback") - .set_body_packed([](TVMArgs args, TVMRetValue* ret) { + .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { auto msg = args[0].cast(); - *ret = runtime::TypedPackedFunc( - [msg](int x, int y) { CHECK_EQ(x, y) << msg; }); + *ret = ffi::TypedFunction([msg](int x, int y) { CHECK_EQ(x, y) << msg; }); }); -TVM_REGISTER_GLOBAL("testing.device_test").set_body_packed([](TVMArgs args, TVMRetValue* ret) { +TVM_REGISTER_GLOBAL("testing.device_test").set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { auto dev = args[0].cast(); int dtype = args[1].cast(); int did = args[2].cast(); @@ -87,13 +86,14 @@ TVM_REGISTER_GLOBAL("testing.device_test").set_body_packed([](TVMArgs args, TVMR *ret = dev; }); -TVM_REGISTER_GLOBAL("testing.identity_cpp").set_body_packed([](TVMArgs args, TVMRetValue* ret) { - const auto identity_func = tvm::ffi::Function::GetGlobal("testing.identity_py"); - ICHECK(identity_func.has_value()) - << "AttributeError: \"testing.identity_py\" is not registered. Please check " - "if the python module is properly loaded"; - *ret = (*identity_func)(args[0]); -}); +TVM_REGISTER_GLOBAL("testing.identity_cpp") + .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { + const auto identity_func = tvm::ffi::Function::GetGlobal("testing.identity_py"); + ICHECK(identity_func.has_value()) + << "AttributeError: \"testing.identity_py\" is not registered. Please check " + "if the python module is properly loaded"; + *ret = (*identity_func)(args[0]); + }); // in src/api_test.cc void ErrorTest(int x, int y) { @@ -113,19 +113,19 @@ class FrontendTestModuleNode : public runtime::ModuleNode { static constexpr const char* kAddFunctionName = "__add_function"; - virtual PackedFunc GetFunction(const String& name, const ObjectPtr& sptr_to_self); + virtual ffi::Function GetFunction(const String& name, const ObjectPtr& sptr_to_self); private: - std::unordered_map functions_; + std::unordered_map functions_; }; constexpr const char* FrontendTestModuleNode::kAddFunctionName; -PackedFunc FrontendTestModuleNode::GetFunction(const String& name, - const ObjectPtr& sptr_to_self) { +ffi::Function FrontendTestModuleNode::GetFunction(const String& name, + const ObjectPtr& sptr_to_self) { if (name == kAddFunctionName) { - return runtime::TypedPackedFunc( - [this, sptr_to_self](std::string func_name, PackedFunc pf) { + return ffi::TypedFunction( + [this, sptr_to_self](std::string func_name, ffi::Function pf) { CHECK_NE(func_name, kAddFunctionName) << "func_name: cannot be special function " << kAddFunctionName; functions_[func_name] = pf; @@ -134,7 +134,7 @@ PackedFunc FrontendTestModuleNode::GetFunction(const String& name, auto it = functions_.find(name); if (it == functions_.end()) { - return PackedFunc(); + return ffi::Function(); } return it->second; @@ -197,10 +197,10 @@ TVM_REGISTER_GLOBAL("testing.AcceptsArrayOfPrimExpr") }); TVM_REGISTER_GLOBAL("testing.AcceptsArrayOfVariant") - .set_body_typed([](Array> arr) -> ObjectRef { + .set_body_typed([](Array> arr) -> ObjectRef { for (auto item : arr) { CHECK(item.as() || item.as()) - << "Array should contain either PrimExpr or PackedFunc"; + << "Array should contain either PrimExpr or ffi::Function"; } return arr; }); @@ -254,7 +254,7 @@ class TestingEventLogger { std::vector entries_; }; -TVM_REGISTER_GLOBAL("testing.record_event").set_body_packed([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("testing.record_event").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { if (args.size() != 0 && args[0].as()) { TestingEventLogger::ThreadLocal()->Record(args[0].cast()); } else { @@ -262,7 +262,7 @@ TVM_REGISTER_GLOBAL("testing.record_event").set_body_packed([](TVMArgs args, TVM } }); -TVM_REGISTER_GLOBAL("testing.reset_events").set_body_packed([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("testing.reset_events").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { TestingEventLogger::ThreadLocal()->Reset(); }); diff --git a/src/target/datatype/registry.cc b/src/target/datatype/registry.cc index 11b40a67263b..79065d0024c5 100644 --- a/src/target/datatype/registry.cc +++ b/src/target/datatype/registry.cc @@ -23,27 +23,27 @@ namespace tvm { namespace datatype { -using runtime::TVMArgs; -using runtime::TVMRetValue; +using ffi::Any; +using ffi::PackedArgs; TVM_REGISTER_GLOBAL("dtype.register_custom_type") - .set_body_packed([](TVMArgs args, TVMRetValue* ret) { + .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { datatype::Registry::Global()->Register(args[0].cast(), static_cast(args[1].cast())); }); TVM_REGISTER_GLOBAL("dtype.get_custom_type_code") - .set_body_packed([](TVMArgs args, TVMRetValue* ret) { + .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { *ret = datatype::Registry::Global()->GetTypeCode(args[0].cast()); }); TVM_REGISTER_GLOBAL("dtype.get_custom_type_name") - .set_body_packed([](TVMArgs args, TVMRetValue* ret) { + .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { *ret = Registry::Global()->GetTypeName(args[0].cast()); }); TVM_REGISTER_GLOBAL("runtime._datatype_get_type_registered") - .set_body_packed([](TVMArgs args, TVMRetValue* ret) { + .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { *ret = Registry::Global()->GetTypeRegistered(args[0].cast()); }); diff --git a/src/target/llvm/codegen_aarch64.cc b/src/target/llvm/codegen_aarch64.cc index 37a11d857699..1399fc083a08 100644 --- a/src/target/llvm/codegen_aarch64.cc +++ b/src/target/llvm/codegen_aarch64.cc @@ -107,7 +107,7 @@ void CodeGenAArch64::VisitStmt_(const AttrStmtNode* op) { } TVM_REGISTER_GLOBAL("tvm.codegen.llvm.target_aarch64") - .set_body_packed([](const TVMArgs& targs, TVMRetValue* rv) { + .set_body_packed([](const ffi::PackedArgs& targs, ffi::Any* rv) { *rv = static_cast(new CodeGenAArch64()); }); diff --git a/src/target/llvm/codegen_amdgpu.cc b/src/target/llvm/codegen_amdgpu.cc index 4cd833d2da69..a1cff52beb2a 100644 --- a/src/target/llvm/codegen_amdgpu.cc +++ b/src/target/llvm/codegen_amdgpu.cc @@ -68,7 +68,7 @@ static inline int DetectROCMmaxThreadsPerBlock() { tvm_dev.device_id = 0; tvm::runtime::DeviceAPI* api = tvm::runtime::DeviceAPI::Get(tvm_dev, true); if (api != nullptr) { - TVMRetValue val; + ffi::Any val; api->GetAttr(tvm_dev, tvm::runtime::kExist, &val); if (val.cast() == 1) { tvm::runtime::DeviceAPI::Get(tvm_dev)->GetAttr(tvm_dev, tvm::runtime::kMaxThreadsPerBlock, @@ -359,7 +359,7 @@ runtime::Module BuildAMDGPU(IRModule mod, Target target) { TVM_REGISTER_GLOBAL("target.build.rocm").set_body_typed(BuildAMDGPU); TVM_REGISTER_GLOBAL("tvm.codegen.llvm.target_rocm") - .set_body_packed([](const TVMArgs& targs, TVMRetValue* rv) { + .set_body_packed([](const ffi::PackedArgs& targs, ffi::Any* rv) { *rv = static_cast(new CodeGenAMDGPU()); }); diff --git a/src/target/llvm/codegen_arm.cc b/src/target/llvm/codegen_arm.cc index 0fd6673f49d0..3abebec2a36e 100644 --- a/src/target/llvm/codegen_arm.cc +++ b/src/target/llvm/codegen_arm.cc @@ -133,7 +133,7 @@ PrimExpr CodeGenARM::ARMPopcount(const CallNode* call) { } TVM_REGISTER_GLOBAL("tvm.codegen.llvm.target_arm") - .set_body_packed([](const TVMArgs& targs, TVMRetValue* rv) { + .set_body_packed([](const ffi::PackedArgs& targs, ffi::Any* rv) { *rv = static_cast(new CodeGenARM()); }); diff --git a/src/target/llvm/codegen_cpu.cc b/src/target/llvm/codegen_cpu.cc index 332650b7fe94..0a5223ae029b 100644 --- a/src/target/llvm/codegen_cpu.cc +++ b/src/target/llvm/codegen_cpu.cc @@ -506,7 +506,7 @@ void CodeGenCPU::CreateComputeScope(const AttrStmtNode* op) { // There are two reasons why we create another function for compute_scope // - Make sure the generated compute function is clearly separately(though it can get inlined) - // - Set noalias on all the pointer arguments, some of them are loaded from TVMArgs. + // - Set noalias on all the pointer arguments, some of them are loaded from ffi::PackedArgs. // This is easier than set the alias scope manually. Array vargs = tir::UndefinedVars(op->body, {}); std::vector arg_values; @@ -1159,7 +1159,7 @@ void CodeGenCPU::VisitStmt_(const ForNode* op) { } TVM_REGISTER_GLOBAL("tvm.codegen.llvm.target_cpu") - .set_body_packed([](const TVMArgs& targs, TVMRetValue* rv) { + .set_body_packed([](const ffi::PackedArgs& targs, ffi::Any* rv) { *rv = static_cast(new CodeGenCPU()); }); diff --git a/src/target/llvm/codegen_hexagon.cc b/src/target/llvm/codegen_hexagon.cc index 9b1025aff9d3..22708c61178a 100644 --- a/src/target/llvm/codegen_hexagon.cc +++ b/src/target/llvm/codegen_hexagon.cc @@ -592,7 +592,7 @@ runtime::Module BuildHexagon(IRModule mod, Target target) { TVM_REGISTER_GLOBAL("target.build.hexagon").set_body_typed(BuildHexagon); TVM_REGISTER_GLOBAL("tvm.codegen.llvm.target_hexagon") - .set_body_packed([](const TVMArgs& targs, TVMRetValue* rv) { + .set_body_packed([](const ffi::PackedArgs& targs, ffi::Any* rv) { *rv = static_cast(new CodeGenHexagon()); }); diff --git a/src/target/llvm/codegen_nvptx.cc b/src/target/llvm/codegen_nvptx.cc index 2e2f34ba8a84..fa035367408c 100644 --- a/src/target/llvm/codegen_nvptx.cc +++ b/src/target/llvm/codegen_nvptx.cc @@ -371,7 +371,7 @@ runtime::Module BuildNVPTX(IRModule mod, Target target) { TVM_REGISTER_GLOBAL("target.build.nvptx").set_body_typed(BuildNVPTX); TVM_REGISTER_GLOBAL("tvm.codegen.llvm.target_nvptx") - .set_body_packed([](const TVMArgs& targs, TVMRetValue* rv) { + .set_body_packed([](const ffi::PackedArgs& targs, ffi::Any* rv) { *rv = static_cast(new CodeGenNVPTX()); }); diff --git a/src/target/llvm/codegen_x86_64.cc b/src/target/llvm/codegen_x86_64.cc index 563b7be2caa3..954d4e7efd56 100644 --- a/src/target/llvm/codegen_x86_64.cc +++ b/src/target/llvm/codegen_x86_64.cc @@ -133,7 +133,7 @@ llvm::Value* CodeGenX86_64::CallVectorIntrin(llvm::Intrinsic::ID id, size_t intr } TVM_REGISTER_GLOBAL("tvm.codegen.llvm.target_x86-64") - .set_body_packed([](const TVMArgs& targs, TVMRetValue* rv) { + .set_body_packed([](const ffi::PackedArgs& targs, ffi::Any* rv) { *rv = static_cast(new CodeGenX86_64()); }); diff --git a/src/target/llvm/llvm_module.cc b/src/target/llvm/llvm_module.cc index 9f1fae0e7b64..396b02063b34 100644 --- a/src/target/llvm/llvm_module.cc +++ b/src/target/llvm/llvm_module.cc @@ -84,9 +84,9 @@ namespace tvm { namespace codegen { -using runtime::PackedFunc; -using runtime::TVMArgs; -using runtime::TVMRetValue; +using ffi::Any; +using ffi::Function; +using ffi::PackedArgs; class LLVMModuleNode final : public runtime::ModuleNode { public: @@ -94,7 +94,7 @@ class LLVMModuleNode final : public runtime::ModuleNode { const char* type_key() const final { return "llvm"; } - PackedFunc GetFunction(const String& name, const ObjectPtr& sptr_to_self) final; + ffi::Function GetFunction(const String& name, const ObjectPtr& sptr_to_self) final; /*! \brief Get the property of the runtime module .*/ // TODO(tvm-team): Make it serializable @@ -154,12 +154,13 @@ LLVMModuleNode::~LLVMModuleNode() { module_owning_ptr_.reset(); } -PackedFunc LLVMModuleNode::GetFunction(const String& name, const ObjectPtr& sptr_to_self) { +ffi::Function LLVMModuleNode::GetFunction(const String& name, + const ObjectPtr& sptr_to_self) { if (name == "__tvm_is_system_module") { bool flag = (module_->getFunction("__tvm_module_startup") != nullptr); - return PackedFunc([flag](TVMArgs args, TVMRetValue* rv) { *rv = flag; }); + return ffi::Function([flag](ffi::PackedArgs args, ffi::Any* rv) { *rv = flag; }); } else if (name == "__tvm_get_system_lib_prefix") { - return PackedFunc([this](TVMArgs args, TVMRetValue* rv) { + return ffi::Function([this](ffi::PackedArgs args, ffi::Any* rv) { auto* md = module_->getModuleFlag("tvm_system_lib_prefix"); if (md != nullptr) { *rv = llvm::cast(md)->getString().str(); @@ -168,15 +169,16 @@ PackedFunc LLVMModuleNode::GetFunction(const String& name, const ObjectPtrfunction_names_; }); + return ffi::Function( + [sptr_to_self, this](ffi::PackedArgs args, ffi::Any* rv) { *rv = this->function_names_; }); } else if (name == "get_symbol") { - return PackedFunc(nullptr); + return ffi::Function(nullptr); } else if (name == "get_const_vars") { - return PackedFunc(nullptr); + return ffi::Function(nullptr); } else if (name == "_get_target_string") { std::string target_string = LLVMTarget::GetTargetMetadata(*module_); - return PackedFunc([target_string](TVMArgs args, TVMRetValue* rv) { *rv = target_string; }); + return ffi::Function( + [target_string](ffi::PackedArgs args, ffi::Any* rv) { *rv = target_string; }); } ICHECK(jit_engine_.size()) << "JIT engine type is missing"; if ((jit_engine_ == "mcjit") && (mcjit_ee_ == nullptr)) InitMCJIT(); @@ -195,8 +197,8 @@ PackedFunc LLVMModuleNode::GetFunction(const String& name, const ObjectPtr(GetFunctionAddr(name, *llvm_target)); } - if (faddr == nullptr) return PackedFunc(); - return tvm::runtime::WrapPackedFunc(faddr, sptr_to_self); + if (faddr == nullptr) return ffi::Function(); + return tvm::runtime::WrapFFIFunction(faddr, sptr_to_self); } namespace { diff --git a/src/target/source/codegen_webgpu.cc b/src/target/source/codegen_webgpu.cc index 90be76663847..8d1ad91746b6 100644 --- a/src/target/source/codegen_webgpu.cc +++ b/src/target/source/codegen_webgpu.cc @@ -714,9 +714,9 @@ class WebGPUSourceModuleNode final : public runtime::ModuleNode { /*! \brief Get the property of the runtime module .*/ int GetPropertyMask() const final { return runtime::ModulePropertyMask::kBinarySerializable; } - PackedFunc GetFunction(const String& name, const ObjectPtr& sptr_to_self) final { + ffi::Function GetFunction(const String& name, const ObjectPtr& sptr_to_self) final { LOG(FATAL) << "WebGPUSourceModule is not directly runnable, export and run through tvmjs"; - return PackedFunc(nullptr); + return ffi::Function(nullptr); } void SaveToBinary(dmlc::Stream* stream) final { diff --git a/src/target/source/source_module.cc b/src/target/source/source_module.cc index 1a5de7aee484..ec3cadb8c8e4 100644 --- a/src/target/source/source_module.cc +++ b/src/target/source/source_module.cc @@ -41,9 +41,9 @@ namespace tvm { namespace codegen { -using runtime::PackedFunc; -using runtime::TVMArgs; -using runtime::TVMRetValue; +using ffi::Any; +using ffi::Function; +using ffi::PackedArgs; using runtime::FunctionInfo; using runtime::GetFileFormat; @@ -56,10 +56,10 @@ class SourceModuleNode : public runtime::ModuleNode { SourceModuleNode(std::string code, std::string fmt) : code_(code), fmt_(fmt) {} const char* type_key() const final { return "source"; } - PackedFunc GetFunction(const String& name, const ObjectPtr& sptr_to_self) final { + ffi::Function GetFunction(const String& name, const ObjectPtr& sptr_to_self) final { LOG(FATAL) << "Source module cannot execute, to get executable module" << " build TVM with \'" << fmt_ << "\' runtime support"; - return PackedFunc(); + return ffi::Function(); } String GetSource(const String& format) final { return code_; } @@ -84,22 +84,22 @@ class CSourceModuleNode : public runtime::ModuleNode { : code_(code), fmt_(fmt), const_vars_(const_vars), func_names_(func_names) {} const char* type_key() const final { return "c"; } - PackedFunc GetFunction(const String& name, const ObjectPtr& sptr_to_self) final { + ffi::Function GetFunction(const String& name, const ObjectPtr& sptr_to_self) final { // Currently c-source module is used as demonstration purposes with binary metadata module // that expects get_symbol interface. When c-source module is used as external module, it // will only contain one function. However, when its used as an internal module (e.g., target // "c") it can have many functions. if (name == "get_symbol") { - return PackedFunc( - [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->func_names_[0]; }); + return ffi::Function( + [sptr_to_self, this](ffi::PackedArgs args, ffi::Any* rv) { *rv = this->func_names_[0]; }); } else if (name == "get_const_vars") { - return PackedFunc( - [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->const_vars_; }); + return ffi::Function( + [sptr_to_self, this](ffi::PackedArgs args, ffi::Any* rv) { *rv = this->const_vars_; }); } else if (name == "get_func_names") { - return PackedFunc( - [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->func_names_; }); + return ffi::Function( + [sptr_to_self, this](ffi::PackedArgs args, ffi::Any* rv) { *rv = this->func_names_; }); } else { - return PackedFunc(nullptr); + return ffi::Function(nullptr); } } @@ -202,10 +202,10 @@ class DeviceSourceModuleNode final : public runtime::ModuleNode { std::function fget_source) : data_(data), fmt_(fmt), fmap_(fmap), type_key_(type_key), fget_source_(fget_source) {} - PackedFunc GetFunction(const String& name, const ObjectPtr& sptr_to_self) final { + ffi::Function GetFunction(const String& name, const ObjectPtr& sptr_to_self) final { LOG(FATAL) << "Source module cannot execute, to get executable module" << " build TVM with \'" << fmt_ << "\' runtime support"; - return PackedFunc(); + return ffi::Function(); } String GetSource(const String& format) final { diff --git a/src/target/target.cc b/src/target/target.cc index 360c56642b37..7a29ca2ef537 100644 --- a/src/target/target.cc +++ b/src/target/target.cc @@ -60,7 +60,7 @@ class TargetInternal { static ObjectPtr FromConfigString(const String& config_str); static ObjectPtr FromRawString(const String& target_str); static ObjectPtr FromConfig(Map config); - static void ConstructorDispatcher(TVMArgs args, TVMRetValue* rv); + static void ConstructorDispatcher(ffi::PackedArgs args, ffi::Any* rv); static Target WithHost(const Target& target, const Target& target_host) { ObjectPtr n = make_object(*target.get()); n->host = target_host; @@ -756,7 +756,7 @@ Target Target::Current(bool allow_not_defined) { /********** Creation **********/ -void TargetInternal::ConstructorDispatcher(TVMArgs args, TVMRetValue* rv) { +void TargetInternal::ConstructorDispatcher(ffi::PackedArgs args, ffi::Any* rv) { if (args.size() == 1) { const auto& arg = args[0]; if (auto opt_target = arg.as()) { @@ -923,7 +923,7 @@ ObjectPtr TargetInternal::FromConfig(Map config) { } // parse host if (config.count(kHost)) { - target->host = PackedFunc(ConstructorDispatcher)(config[kHost]).cast(); + target->host = ffi::Function(ConstructorDispatcher)(config[kHost]).cast(); config.erase(kHost); } else { target->host = NullOpt; @@ -987,7 +987,7 @@ std::unordered_map TargetInternal::QueryDevice(int device_id, return output; } - TVMRetValue ret; + ffi::Any ret; api->GetAttr(device, runtime::kExist, &ret); bool device_exists = ret.cast(); if (!device_exists) { @@ -1000,7 +1000,7 @@ std::unordered_map TargetInternal::QueryDevice(int device_id, for (const auto& kv : target->kind->key2vtype_) { const String& key = kv.first; - TVMRetValue ret; + ffi::Any ret; api->GetTargetProperty(device, key, &ret); output[key] = ret; } diff --git a/src/target/target_kind.cc b/src/target/target_kind.cc index bea7791aad4e..2a8dad4162cf 100644 --- a/src/target/target_kind.cc +++ b/src/target/target_kind.cc @@ -78,7 +78,7 @@ TargetKindRegEntry& TargetKindRegEntry::RegisterOrGet(const String& target_kind_ return TargetKindRegistry::Global()->RegisterOrGet(target_kind_name); } -void TargetKindRegEntry::UpdateAttr(const String& key, TVMRetValue value, int plevel) { +void TargetKindRegEntry::UpdateAttr(const String& key, ffi::Any value, int plevel) { TargetKindRegistry::Global()->UpdateAttr(key, kind_, value, plevel); } @@ -122,7 +122,7 @@ std::string ExtractStringWithPrefix(const std::string& str, const std::string& p * \param val The detected value * \return A boolean indicating if detection succeeds */ -static bool DetectDeviceFlag(Device device, runtime::DeviceAttrKind flag, TVMRetValue* val) { +static bool DetectDeviceFlag(Device device, runtime::DeviceAttrKind flag, ffi::Any* val) { using runtime::DeviceAPI; DeviceAPI* api = DeviceAPI::Get(device, true); // Check if compiled with the corresponding device api @@ -168,7 +168,7 @@ TargetJSON UpdateCUDAAttrs(TargetJSON target) { } else { // Use the compute version of the first CUDA GPU instead int archInt; - TVMRetValue version; + ffi::Any version; if (!DetectDeviceFlag({kDLCUDA, 0}, runtime::kComputeVersion, &version)) { LOG(WARNING) << "Unable to detect CUDA version, default to \"-arch=sm_50\" instead"; archInt = 50; @@ -196,7 +196,7 @@ TargetJSON UpdateNVPTXAttrs(TargetJSON target) { } else { // Use the compute version of the first CUDA GPU instead int arch; - TVMRetValue version; + ffi::Any version; if (!DetectDeviceFlag({kDLCUDA, 0}, runtime::kComputeVersion, &version)) { LOG(WARNING) << "Unable to detect CUDA version, default to \"-mcpu=sm_50\" instead"; arch = 50; @@ -222,7 +222,7 @@ TargetJSON UpdateROCmAttrs(TargetJSON target) { arch = ExtractStringWithPrefix(mcpu, "gfx"); ICHECK(!arch.empty()) << "ValueError: ROCm target gets an invalid GFX version: -mcpu=" << mcpu; } else { - TVMRetValue val; + ffi::Any val; if (const auto f_get_rocm_arch = tvm::ffi::Function::GetGlobal("tvm_callback_rocm_get_arch")) { arch = (*f_get_rocm_arch)().cast(); } @@ -232,7 +232,7 @@ TargetJSON UpdateROCmAttrs(TargetJSON target) { // Before ROCm 3.5 we needed code object v2, starting // with 3.5 we need v3 (this argument disables v3) - TVMRetValue val; + ffi::Any val; int version; if (!DetectDeviceFlag({kDLROCM, 0}, runtime::kApiVersion, &val)) { LOG(WARNING) << "Unable to detect ROCm version, assuming >= 3.5"; @@ -447,9 +447,9 @@ TVM_REGISTER_TARGET_KIND("test", kDLCPU) // line break /********** Registry **********/ TVM_REGISTER_GLOBAL("target.TargetKindGetAttr") - .set_body_typed([](TargetKind kind, String attr_name) -> TVMRetValue { - auto target_attr_map = TargetKind::GetAttrMap(attr_name); - TVMRetValue rv; + .set_body_typed([](TargetKind kind, String attr_name) -> ffi::Any { + auto target_attr_map = TargetKind::GetAttrMap(attr_name); + ffi::Any rv; if (target_attr_map.count(kind)) { rv = target_attr_map[kind]; } diff --git a/src/te/operation/create_primfunc.cc b/src/te/operation/create_primfunc.cc index 7f0f90302e79..d087f845cc0f 100644 --- a/src/te/operation/create_primfunc.cc +++ b/src/te/operation/create_primfunc.cc @@ -784,7 +784,7 @@ PrimFunc CreatePrimFunc(const Array& arg_list, return CreatePrimFuncWithConstants(arg_list, {}, index_dtype_override); } -TVM_REGISTER_GLOBAL("te.CreatePrimFunc").set_body_packed([](TVMArgs args, TVMRetValue* ret) { +TVM_REGISTER_GLOBAL("te.CreatePrimFunc").set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { Array arg_list = args[0].cast>(); std::optional index_dtype_override{std::nullopt}; // Add conversion to make std::optional compatible with FFI. diff --git a/src/tir/analysis/var_use_def_analysis.cc b/src/tir/analysis/var_use_def_analysis.cc index d2d03132981c..654d3332c755 100644 --- a/src/tir/analysis/var_use_def_analysis.cc +++ b/src/tir/analysis/var_use_def_analysis.cc @@ -200,7 +200,7 @@ Array UndefinedVars(const PrimExpr& expr, const Array& args) { } TVM_REGISTER_GLOBAL("tir.analysis.UndefinedVars") - .set_body_packed([](TVMArgs args, TVMRetValue* rv) { + .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { if (auto opt_stmt = args[0].as()) { *rv = UndefinedVars(opt_stmt.value(), args[1].cast>()); } else if (auto opt_expr = args[0].as()) { diff --git a/src/tir/ir/buffer.cc b/src/tir/ir/buffer.cc index e1626c08a837..3b94c2ae757a 100644 --- a/src/tir/ir/buffer.cc +++ b/src/tir/ir/buffer.cc @@ -640,7 +640,7 @@ tir::Buffer BufferWithOffsetAlignment(Array shape, DataType dtype, std TVM_REGISTER_NODE_TYPE(BufferNode); -TVM_REGISTER_GLOBAL("tir.Buffer").set_body_packed([](TVMArgs args, TVMRetValue* ret) { +TVM_REGISTER_GLOBAL("tir.Buffer").set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { ICHECK_EQ(args.size(), 11); auto buffer_type = args[8].cast(); BufferType type = (buffer_type == "auto_broadcast") ? kAutoBroadcast : kDefault; diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc index bec6c040856f..e89162239a63 100644 --- a/src/tir/ir/expr.cc +++ b/src/tir/ir/expr.cc @@ -127,8 +127,7 @@ Var Var::copy_with_dtype(DataType dtype) const { return Var(new_ptr); } -TVM_REGISTER_GLOBAL("tir.Var").set_body_typed([](String name_hint, runtime::TVMArgValue type, - Span span) { +TVM_REGISTER_GLOBAL("tir.Var").set_body_typed([](String name_hint, ffi::AnyView type, Span span) { if (type.as()) { return Var(name_hint, type.cast(), span); } else { diff --git a/src/tir/ir/index_map.cc b/src/tir/ir/index_map.cc index aed8361d04f1..ff948da01289 100644 --- a/src/tir/ir/index_map.cc +++ b/src/tir/ir/index_map.cc @@ -43,7 +43,7 @@ IndexMap::IndexMap(Array initial_indices, Array final_indices, data_ = std::move(n); } -IndexMap IndexMap::FromFunc(int ndim, runtime::TypedPackedFunc(Array)> func, +IndexMap IndexMap::FromFunc(int ndim, ffi::TypedFunction(Array)> func, Optional inverse_index_map) { Array initial_indices; initial_indices.reserve(ndim); diff --git a/src/tir/ir/stmt_functor.cc b/src/tir/ir/stmt_functor.cc index cf09ea306e1c..23dba3ef7233 100644 --- a/src/tir/ir/stmt_functor.cc +++ b/src/tir/ir/stmt_functor.cc @@ -592,7 +592,7 @@ void PostOrderVisit(const ObjectRef& node, std::function class IRTransformer final : public StmtExprMutator { public: - IRTransformer(const runtime::PackedFunc& f_preorder, const runtime::PackedFunc& f_postorder, + IRTransformer(const ffi::Function& f_preorder, const ffi::Function& f_postorder, const std::unordered_set& only_enable) : f_preorder_(f_preorder), f_postorder_(f_postorder), only_enable_(only_enable) {} @@ -627,14 +627,14 @@ class IRTransformer final : public StmtExprMutator { return new_node; } // The functions - const runtime::PackedFunc& f_preorder_; - const runtime::PackedFunc& f_postorder_; + const ffi::Function& f_preorder_; + const ffi::Function& f_postorder_; // type indices enabled. const std::unordered_set& only_enable_; }; -Stmt IRTransform(Stmt ir_node, const runtime::PackedFunc& f_preorder, - const runtime::PackedFunc& f_postorder, Optional> only_enable) { +Stmt IRTransform(Stmt ir_node, const ffi::Function& f_preorder, const ffi::Function& f_postorder, + Optional> only_enable) { std::unordered_set only_type_index; if (only_enable.defined()) { for (auto s : only_enable.value()) { @@ -894,11 +894,11 @@ PrimExpr SubstituteWithDataTypeLegalization(PrimExpr expr, TVM_REGISTER_GLOBAL("tir.IRTransform").set_body_typed(IRTransform); -TVM_REGISTER_GLOBAL("tir.PostOrderVisit").set_body_typed([](ObjectRef node, PackedFunc f) { +TVM_REGISTER_GLOBAL("tir.PostOrderVisit").set_body_typed([](ObjectRef node, ffi::Function f) { tir::PostOrderVisit(node, [f](const ObjectRef& n) { f(n); }); }); -TVM_REGISTER_GLOBAL("tir.PreOrderVisit").set_body_typed([](ObjectRef node, PackedFunc f) { +TVM_REGISTER_GLOBAL("tir.PreOrderVisit").set_body_typed([](ObjectRef node, ffi::Function f) { tir::PreOrderVisit(node, [f](const ObjectRef& n) { return f(n).cast(); }); }); diff --git a/src/tir/ir/transform.cc b/src/tir/ir/transform.cc index 85eaa2aed42a..f724b6a74598 100644 --- a/src/tir/ir/transform.cc +++ b/src/tir/ir/transform.cc @@ -146,8 +146,7 @@ TVM_REGISTER_NODE_TYPE(PrimFuncPassNode); TVM_REGISTER_GLOBAL("tir.transform.CreatePrimFuncPass") .set_body_typed( - [](runtime::TypedPackedFunc, IRModule, PassContext)> - pass_func, + [](ffi::TypedFunction, IRModule, PassContext)> pass_func, PassInfo pass_info) { auto wrapped_pass_func = [pass_func](PrimFunc func, IRModule mod, PassContext ctx) { return pass_func(ffi::RValueRef(std::move(func)), mod, ctx); diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc index 4c064185fd7d..838af436a6cd 100644 --- a/src/tir/op/op.cc +++ b/src/tir/op/op.cc @@ -1071,7 +1071,7 @@ TVM_TIR_REGISTER_OP("TVMBackendFreeWorkspace") .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); // expose basic functions to node namespace -TVM_REGISTER_GLOBAL("node._const").set_body_packed([](TVMArgs args, TVMRetValue* ret) { +TVM_REGISTER_GLOBAL("node._const").set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { if (auto opt = args[0].as()) { *ret = tir::make_const(args[1].cast(), *opt, args[2].cast()); } else if (auto opt = args[0].as()) { @@ -1121,7 +1121,7 @@ TVM_REGISTER_GLOBAL("tir.reinterpret").set_body_typed(tvm::reinterpret); }) #define REGISTER_MAKE_BIT_OP(Node, Func) \ - TVM_REGISTER_GLOBAL("tir." #Node).set_body_packed([](TVMArgs args, TVMRetValue* ret) { \ + TVM_REGISTER_GLOBAL("tir." #Node).set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { \ bool lhs_is_int = args[0].type_index() == ffi::TypeIndex::kTVMFFIInt; \ bool rhs_is_int = args[1].type_index() == ffi::TypeIndex::kTVMFFIInt; \ if (lhs_is_int) { \ diff --git a/src/tir/schedule/analysis.h b/src/tir/schedule/analysis.h index 715a012e6548..9cf96bbd6b68 100644 --- a/src/tir/schedule/analysis.h +++ b/src/tir/schedule/analysis.h @@ -522,7 +522,7 @@ std::tuple, Array> GetReducerAndCombinerL * \return The list of the registered reducer-getter functions * \sa ReducerRegistry */ -std::vector(Array)>> GetReducerGetters(); +std::vector(Array)>> GetReducerGetters(); /*! * \brief Given the input identities and the combiner BufferStores of a reduction, extract the diff --git a/src/tir/schedule/analysis/reducer.cc b/src/tir/schedule/analysis/reducer.cc index 5f8af8418668..d85be933820c 100644 --- a/src/tir/schedule/analysis/reducer.cc +++ b/src/tir/schedule/analysis/reducer.cc @@ -685,7 +685,7 @@ bool FromIdentityCombiner(const Array& identities, const Array(Array)>& reducer_getter : + for (const ffi::TypedFunction(Array)>& reducer_getter : GetReducerGetters()) { Optional reducer = reducer_getter(identities); if (!reducer.defined()) { diff --git a/src/tir/schedule/instruction_traits.h b/src/tir/schedule/instruction_traits.h index 29165750e8ec..9f24ee1a3e8b 100644 --- a/src/tir/schedule/instruction_traits.h +++ b/src/tir/schedule/instruction_traits.h @@ -176,7 +176,7 @@ struct UnpackedInstTraits { static TVM_ALWAYS_INLINE void _SetAttrs(AnyView* packed_args, const Array& attrs); template static TVM_ALWAYS_INLINE void _SetDecision(AnyView* packed_args, const Any& decision); - static TVM_ALWAYS_INLINE Array _ConvertOutputs(const TVMRetValue& rv); + static TVM_ALWAYS_INLINE Array _ConvertOutputs(const ffi::Any& rv); }; /*! @@ -316,7 +316,7 @@ Array UnpackedInstTraits::ApplyToSchedule(const Schedule& sch, TTraits::template _SetInputs<1>(packed_args, inputs); TTraits::template _SetAttrs<1 + kNumInputs>(packed_args, attrs); TTraits::template _SetDecision<1 + kNumInputs + kNumAttrs>(packed_args, decision); - PackedFunc pf([](const TVMArgs& args, TVMRetValue* rv) -> void { + ffi::Function pf([](const ffi::PackedArgs& args, ffi::Any* rv) -> void { constexpr size_t kNumArgs = details::NumArgs; ICHECK_EQ(args.size(), kNumArgs); ffi::details::unpack_call(std::make_index_sequence{}, nullptr, @@ -347,7 +347,7 @@ String UnpackedInstTraits::AsPython(const Array& inputs, const Arr TTraits::template _SetInputs<1>(packed_args, inputs); TTraits::template _SetAttrs<1 + kNumInputs>(packed_args, attrs); TTraits::template _SetDecision<1 + kNumInputs + kNumAttrs>(packed_args, decision); - PackedFunc pf([](const TVMArgs& args, TVMRetValue* rv) -> void { + ffi::Function pf([](const ffi::PackedArgs& args, ffi::Any* rv) -> void { constexpr size_t kNumArgs = details::NumArgs; ICHECK_EQ(args.size(), kNumArgs); ffi::details::unpack_call(std::make_index_sequence{}, nullptr, @@ -396,7 +396,7 @@ TVM_ALWAYS_INLINE void UnpackedInstTraits::_SetDecision(AnyView* packed } template -TVM_ALWAYS_INLINE Array UnpackedInstTraits::_ConvertOutputs(const TVMRetValue& rv) { +TVM_ALWAYS_INLINE Array UnpackedInstTraits::_ConvertOutputs(const ffi::Any& rv) { using method_type = decltype(TTraits::UnpackedApplyToSchedule); using return_type = details::ReturnType; constexpr int is_array = details::IsTVMArray; diff --git a/src/tir/schedule/primitive/reduction.cc b/src/tir/schedule/primitive/reduction.cc index c294f7092516..126832cc85fb 100644 --- a/src/tir/schedule/primitive/reduction.cc +++ b/src/tir/schedule/primitive/reduction.cc @@ -398,15 +398,15 @@ struct ReducerRegistry { })} {} static void RegisterReducer( - int n_buffers, TypedPackedFunc(Array, Array)> combiner_getter, - TypedPackedFunc(Array)> identity_getter) { + int n_buffers, ffi::TypedFunction(Array, Array)> combiner_getter, + ffi::TypedFunction(Array)> identity_getter) { ReducerRegistry::Global()->reducer_getters.push_back(ReducerRegistry::CreateReducerGetter( n_buffers, std::move(combiner_getter), std::move(identity_getter))); } - static TypedPackedFunc(Array)> CreateReducerGetter( - int n_buffers, TypedPackedFunc(Array, Array)> combiner_getter, - TypedPackedFunc(Array)> identity_getter) { + static ffi::TypedFunction(Array)> CreateReducerGetter( + int n_buffers, ffi::TypedFunction(Array, Array)> combiner_getter, + ffi::TypedFunction(Array)> identity_getter) { return [n_buffers, // combiner_getter = std::move(combiner_getter), // identity_getter = std::move(identity_getter) // @@ -429,10 +429,10 @@ struct ReducerRegistry { return &instance; } - std::vector(Array)>> reducer_getters; + std::vector(Array)>> reducer_getters; }; -std::vector(Array)>> GetReducerGetters() { +std::vector(Array)>> GetReducerGetters() { return ReducerRegistry::Global()->reducer_getters; } @@ -1345,7 +1345,8 @@ TVM_REGISTER_INST_KIND_TRAITS(DecomposeReductionTraits); /******** FFI ********/ TVM_REGISTER_GLOBAL("tir.schedule.RegisterReducer") - .set_body_typed([](int n_buffers, PackedFunc combiner_getter, PackedFunc identity_getter) { + .set_body_typed([](int n_buffers, ffi::Function combiner_getter, + ffi::Function identity_getter) { ReducerRegistry::RegisterReducer(n_buffers, std::move(combiner_getter), std::move(identity_getter)); }); diff --git a/src/tir/schedule/trace.cc b/src/tir/schedule/trace.cc index cde7a6fc05e0..a666a1d5902e 100644 --- a/src/tir/schedule/trace.cc +++ b/src/tir/schedule/trace.cc @@ -294,9 +294,9 @@ Optional TraceNode::Pop() { void TraceNode::ApplyToSchedule( Schedule sch, bool remove_postproc, - runtime::TypedPackedFunc& inputs, // - const Array& attrs, // - const Any& decision)> + ffi::TypedFunction& inputs, // + const Array& attrs, // + const Any& decision)> decision_provider) const { std::unordered_map rv_map; for (const Instruction& inst : this->insts) { diff --git a/src/tir/transforms/ir_utils.h b/src/tir/transforms/ir_utils.h index 0227ec514498..cc58f96b83fb 100644 --- a/src/tir/transforms/ir_utils.h +++ b/src/tir/transforms/ir_utils.h @@ -148,7 +148,7 @@ inline Stmt TVMStructSet(Var handle, int index, builtin::TVMStructFieldKind kind } /*! - * \brief Get the type that is passed around TVM PackedFunc API. + * \brief Get the type that is passed around TVM ffi::Function API. * \param t The original type. * \return The corresponding API type. */ diff --git a/src/tir/transforms/lower_device_kernel_launch.cc b/src/tir/transforms/lower_device_kernel_launch.cc index 45a3263a9cf1..6eb196d2520e 100644 --- a/src/tir/transforms/lower_device_kernel_launch.cc +++ b/src/tir/transforms/lower_device_kernel_launch.cc @@ -53,7 +53,7 @@ struct KernelInfo { Array launch_params; // Additional arguments which must be provided to the host-side - // PackedFunc. These may be in terms of the function's parameters + // ffi::Function. These may be in terms of the function's parameters // (e.g. a function that computes the average of `N` elements, and // which must be launched with `N` CUDA threads). Array launch_args; diff --git a/src/tir/transforms/lower_intrin.cc b/src/tir/transforms/lower_intrin.cc index 212ccf6e5616..c1b8a2e83a45 100644 --- a/src/tir/transforms/lower_intrin.cc +++ b/src/tir/transforms/lower_intrin.cc @@ -40,7 +40,7 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { public: using IRMutatorWithAnalyzer::VisitExpr_; using IRMutatorWithAnalyzer::VisitStmt_; - using FLowerGeneral = runtime::TypedPackedFunc; + using FLowerGeneral = ffi::TypedFunction; IntrinInjecter(arith::Analyzer* analyzer, std::string target, std::string mtriple = "") : IRMutatorWithAnalyzer(analyzer) { diff --git a/src/tir/transforms/make_packed_api.cc b/src/tir/transforms/make_packed_api.cc index d241d43c19ac..7f8dc60460b4 100644 --- a/src/tir/transforms/make_packed_api.cc +++ b/src/tir/transforms/make_packed_api.cc @@ -183,7 +183,7 @@ Optional RequiresPackedAPI(const PrimFunc& func) { } } - // Internal function calls do not need the PackedFunc API + // Internal function calls do not need the ffi::Function API auto global_symbol = func->GetAttr(tvm::attr::kGlobalSymbol); if (!global_symbol.defined()) { return NullOpt; diff --git a/src/tir/transforms/primfunc_utils.cc b/src/tir/transforms/primfunc_utils.cc index 1cd9bb21a7e0..7efa23bc322d 100644 --- a/src/tir/transforms/primfunc_utils.cc +++ b/src/tir/transforms/primfunc_utils.cc @@ -98,7 +98,7 @@ transform::Pass AnnotateEntryFunc() { return tvm::transform::CreateModulePass(fpass, 0, "tir.AnnotateEntryFunc", {}); } -transform::Pass Filter(runtime::TypedPackedFunc fcond) { +transform::Pass Filter(ffi::TypedFunction fcond) { auto fpass = [fcond](tir::PrimFunc f, IRModule m, transform::PassContext ctx) { if (fcond(f)) { return f; diff --git a/src/topi/broadcast.cc b/src/topi/broadcast.cc index 660cc241e512..6d6dc4edc5f6 100644 --- a/src/topi/broadcast.cc +++ b/src/topi/broadcast.cc @@ -32,19 +32,19 @@ namespace topi { using namespace tvm; using namespace tvm::runtime; -#define TOPI_REGISTER_BCAST_OP(OpName, Op) \ - TVM_REGISTER_GLOBAL(OpName).set_body_packed([](TVMArgs args, TVMRetValue* rv) { \ - bool lhs_is_tensor = args[0].as().has_value(); \ - bool rhs_is_tensor = args[1].as().has_value(); \ - if (lhs_is_tensor && rhs_is_tensor) { \ - *rv = Op(args[0].cast(), args[1].cast()); \ - } else if (!lhs_is_tensor && rhs_is_tensor) { \ - *rv = Op(args[0].cast(), args[1].cast()); \ - } else if (lhs_is_tensor && !rhs_is_tensor) { \ - *rv = Op(args[0].cast(), args[1].cast()); \ - } else if (!lhs_is_tensor && !rhs_is_tensor) { \ - *rv = Op(args[0].cast(), args[1].cast()); \ - } \ +#define TOPI_REGISTER_BCAST_OP(OpName, Op) \ + TVM_REGISTER_GLOBAL(OpName).set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { \ + bool lhs_is_tensor = args[0].as().has_value(); \ + bool rhs_is_tensor = args[1].as().has_value(); \ + if (lhs_is_tensor && rhs_is_tensor) { \ + *rv = Op(args[0].cast(), args[1].cast()); \ + } else if (!lhs_is_tensor && rhs_is_tensor) { \ + *rv = Op(args[0].cast(), args[1].cast()); \ + } else if (lhs_is_tensor && !rhs_is_tensor) { \ + *rv = Op(args[0].cast(), args[1].cast()); \ + } else if (!lhs_is_tensor && !rhs_is_tensor) { \ + *rv = Op(args[0].cast(), args[1].cast()); \ + } \ }); TOPI_REGISTER_BCAST_OP("topi.add", topi::add); @@ -73,7 +73,7 @@ TOPI_REGISTER_BCAST_OP("topi.not_equal", topi::not_equal); TOPI_REGISTER_BCAST_OP("topi.greater_equal", topi::greater_equal); TOPI_REGISTER_BCAST_OP("topi.less_equal", topi::less_equal); -TVM_REGISTER_GLOBAL("topi.broadcast_to").set_body_packed([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("topi.broadcast_to").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = broadcast_to(args[0].cast(), args[1].cast>()); }); diff --git a/src/topi/einsum.cc b/src/topi/einsum.cc index 8f9b9d953af3..e4ac103f14d6 100644 --- a/src/topi/einsum.cc +++ b/src/topi/einsum.cc @@ -355,7 +355,7 @@ Array InferEinsumShape(const std::string& subscripts, return einsum_builder.InferShape(); } -TVM_REGISTER_GLOBAL("topi.einsum").set_body_packed([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("topi.einsum").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = einsum(args[0].cast(), args[1].cast>()); }); diff --git a/src/topi/elemwise.cc b/src/topi/elemwise.cc index ce7f7f4eb102..e3a3411a9c6c 100644 --- a/src/topi/elemwise.cc +++ b/src/topi/elemwise.cc @@ -31,139 +31,139 @@ namespace topi { using namespace tvm; using namespace tvm::runtime; -TVM_REGISTER_GLOBAL("topi.acos").set_body_packed([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("topi.acos").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = acos(args[0].cast()); }); -TVM_REGISTER_GLOBAL("topi.acosh").set_body_packed([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("topi.acosh").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = acosh(args[0].cast()); }); -TVM_REGISTER_GLOBAL("topi.asin").set_body_packed([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("topi.asin").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = asin(args[0].cast()); }); -TVM_REGISTER_GLOBAL("topi.asinh").set_body_packed([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("topi.asinh").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = asinh(args[0].cast()); }); -TVM_REGISTER_GLOBAL("topi.atanh").set_body_packed([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("topi.atanh").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = atanh(args[0].cast()); }); -TVM_REGISTER_GLOBAL("topi.exp").set_body_packed([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("topi.exp").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = exp(args[0].cast()); }); -TVM_REGISTER_GLOBAL("topi.fast_exp").set_body_packed([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("topi.fast_exp").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = fast_exp(args[0].cast()); }); -TVM_REGISTER_GLOBAL("topi.erf").set_body_packed([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("topi.erf").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = erf(args[0].cast()); }); -TVM_REGISTER_GLOBAL("topi.fast_erf").set_body_packed([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("topi.fast_erf").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = fast_erf(args[0].cast()); }); -TVM_REGISTER_GLOBAL("topi.tan").set_body_packed([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("topi.tan").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = tan(args[0].cast()); }); -TVM_REGISTER_GLOBAL("topi.cos").set_body_packed([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("topi.cos").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = cos(args[0].cast()); }); -TVM_REGISTER_GLOBAL("topi.cosh").set_body_packed([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("topi.cosh").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = cosh(args[0].cast()); }); -TVM_REGISTER_GLOBAL("topi.sin").set_body_packed([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("topi.sin").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = sin(args[0].cast()); }); -TVM_REGISTER_GLOBAL("topi.sinh").set_body_packed([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("topi.sinh").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = sinh(args[0].cast()); }); -TVM_REGISTER_GLOBAL("topi.tanh").set_body_packed([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("topi.tanh").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = tanh(args[0].cast()); }); -TVM_REGISTER_GLOBAL("topi.fast_tanh").set_body_packed([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("topi.fast_tanh").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = fast_tanh(args[0].cast()); }); -TVM_REGISTER_GLOBAL("topi.atan").set_body_packed([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("topi.atan").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = atan(args[0].cast()); }); -TVM_REGISTER_GLOBAL("topi.sigmoid").set_body_packed([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("topi.sigmoid").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = sigmoid(args[0].cast()); }); -TVM_REGISTER_GLOBAL("topi.sqrt").set_body_packed([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("topi.sqrt").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = sqrt(args[0].cast()); }); -TVM_REGISTER_GLOBAL("topi.rsqrt").set_body_packed([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("topi.rsqrt").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = rsqrt(args[0].cast()); }); -TVM_REGISTER_GLOBAL("topi.log").set_body_packed([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("topi.log").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = log(args[0].cast()); }); -TVM_REGISTER_GLOBAL("topi.log2").set_body_packed([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("topi.log2").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = log2(args[0].cast()); }); -TVM_REGISTER_GLOBAL("topi.log10").set_body_packed([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("topi.log10").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = log10(args[0].cast()); }); -TVM_REGISTER_GLOBAL("topi.identity").set_body_packed([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("topi.identity").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = identity(args[0].cast()); }); -TVM_REGISTER_GLOBAL("topi.negative").set_body_packed([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("topi.negative").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = negative(args[0].cast()); }); -TVM_REGISTER_GLOBAL("topi.clip").set_body_packed([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("topi.clip").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = clip(args[0].cast(), args[1].cast(), args[2].cast()); }); -TVM_REGISTER_GLOBAL("topi.cast").set_body_packed([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("topi.cast").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = cast(args[0].cast(), args[1].cast()); }); -TVM_REGISTER_GLOBAL("topi.reinterpret").set_body_packed([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("topi.reinterpret").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = reinterpret(args[0].cast(), args[1].cast()); }); -TVM_REGISTER_GLOBAL("topi.elemwise_sum").set_body_packed([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("topi.elemwise_sum").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = elemwise_sum(args[0].cast>()); }); -TVM_REGISTER_GLOBAL("topi.sign").set_body_packed([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("topi.sign").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = sign(args[0].cast()); }); -TVM_REGISTER_GLOBAL("topi.full").set_body_packed([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("topi.full").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = full(args[0].cast>(), args[1].cast(), args[2].cast()); }); -TVM_REGISTER_GLOBAL("topi.full_like").set_body_packed([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("topi.full_like").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = full_like(args[0].cast(), args[1].cast()); }); -TVM_REGISTER_GLOBAL("topi.logical_not").set_body_packed([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("topi.logical_not").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = logical_not(args[0].cast()); }); -TVM_REGISTER_GLOBAL("topi.bitwise_not").set_body_packed([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("topi.bitwise_not").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = bitwise_not(args[0].cast()); }); diff --git a/src/topi/nn.cc b/src/topi/nn.cc index 58ef57f5080f..e4c9ae5f60e1 100644 --- a/src/topi/nn.cc +++ b/src/topi/nn.cc @@ -45,75 +45,79 @@ using namespace tvm; using namespace tvm::runtime; /* Ops from nn.h */ -TVM_REGISTER_GLOBAL("topi.nn.relu").set_body_packed([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("topi.nn.relu").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = relu(args[0].cast()); }); -TVM_REGISTER_GLOBAL("topi.nn.leaky_relu").set_body_packed([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("topi.nn.leaky_relu").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = leaky_relu(args[0].cast(), args[1].cast()); }); -TVM_REGISTER_GLOBAL("topi.nn.prelu").set_body_packed([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("topi.nn.prelu").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = prelu(args[0].cast(), args[1].cast(), args[2].cast()); }); -TVM_REGISTER_GLOBAL("topi.nn.pad").set_body_packed([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("topi.nn.pad").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = pad(args[0].cast(), args[1].cast>(), args[2].cast>(), args[3].cast()); }); -TVM_REGISTER_GLOBAL("topi.nn.space_to_batch_nd").set_body_packed([](TVMArgs args, TVMRetValue* rv) { - *rv = space_to_batch_nd(args[0].cast(), args[1].cast>(), - args[2].cast>(), args[3].cast>(), - args[4].cast()); -}); +TVM_REGISTER_GLOBAL("topi.nn.space_to_batch_nd") + .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { + *rv = space_to_batch_nd(args[0].cast(), args[1].cast>(), + args[2].cast>(), args[3].cast>(), + args[4].cast()); + }); -TVM_REGISTER_GLOBAL("topi.nn.batch_to_space_nd").set_body_packed([](TVMArgs args, TVMRetValue* rv) { - *rv = batch_to_space_nd(args[0].cast(), args[1].cast>(), - args[2].cast>(), args[3].cast>(), - args[4].cast()); -}); +TVM_REGISTER_GLOBAL("topi.nn.batch_to_space_nd") + .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { + *rv = batch_to_space_nd(args[0].cast(), args[1].cast>(), + args[2].cast>(), args[3].cast>(), + args[4].cast()); + }); -TVM_REGISTER_GLOBAL("topi.nn.nll_loss").set_body_packed([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("topi.nn.nll_loss").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = nll_loss(args[0].cast(), args[1].cast(), args[2].cast(), args[3].cast(), args[4].cast()); }); /* Ops from nn/dense.h */ -TVM_REGISTER_GLOBAL("topi.nn.dense").set_body_packed([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("topi.nn.dense").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = nn::dense(args[0].cast(), args[1].cast(), args[2].cast(), args[3].cast()); }); /* Ops from nn/bias_add.h */ -TVM_REGISTER_GLOBAL("topi.nn.bias_add").set_body_packed([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("topi.nn.bias_add").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = nn::bias_add(args[0].cast(), args[1].cast(), args[2].cast()); }); /* Ops from nn/dilate.h */ -TVM_REGISTER_GLOBAL("topi.nn.dilate").set_body_packed([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("topi.nn.dilate").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = nn::dilate(args[0].cast(), args[1].cast>(), args[2].cast()); }); /* Ops from nn/flatten.h */ -TVM_REGISTER_GLOBAL("topi.nn.flatten").set_body_packed([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("topi.nn.flatten").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = nn::flatten(args[0].cast()); }); /* Ops from nn/mapping.h */ -TVM_REGISTER_GLOBAL("topi.nn.scale_shift_nchw").set_body_packed([](TVMArgs args, TVMRetValue* rv) { - *rv = nn::scale_shift_nchw(args[0].cast(), args[1].cast(), - args[2].cast()); -}); - -TVM_REGISTER_GLOBAL("topi.nn.scale_shift_nhwc").set_body_packed([](TVMArgs args, TVMRetValue* rv) { - *rv = nn::scale_shift_nhwc(args[0].cast(), args[1].cast(), - args[2].cast()); -}); +TVM_REGISTER_GLOBAL("topi.nn.scale_shift_nchw") + .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { + *rv = nn::scale_shift_nchw(args[0].cast(), args[1].cast(), + args[2].cast()); + }); + +TVM_REGISTER_GLOBAL("topi.nn.scale_shift_nhwc") + .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { + *rv = nn::scale_shift_nhwc(args[0].cast(), args[1].cast(), + args[2].cast()); + }); /* Ops from nn/pooling.h */ -TVM_REGISTER_GLOBAL("topi.nn.pool_grad").set_body_packed([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("topi.nn.pool_grad").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = nn::pool_grad(args[0].cast(), args[1].cast(), args[2].cast>(), args[3].cast>(), @@ -121,44 +125,47 @@ TVM_REGISTER_GLOBAL("topi.nn.pool_grad").set_body_packed([](TVMArgs args, TVMRet args[6].cast(), args[7].cast(), args[8].cast()); }); -TVM_REGISTER_GLOBAL("topi.nn.global_pool").set_body_packed([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("topi.nn.global_pool").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = nn::global_pool(args[0].cast(), static_cast(args[1].cast()), args[2].cast()); }); -TVM_REGISTER_GLOBAL("topi.nn.adaptive_pool1d").set_body_packed([](TVMArgs args, TVMRetValue* rv) { - *rv = nn::adaptive_pool1d(args[0].cast(), args[1].cast>(), - static_cast(args[2].cast()), - args[3].cast()); -}); - -TVM_REGISTER_GLOBAL("topi.nn.adaptive_pool").set_body_packed([](TVMArgs args, TVMRetValue* rv) { - *rv = nn::adaptive_pool(args[0].cast(), args[1].cast>(), - static_cast(args[2].cast()), - args[3].cast()); -}); - -TVM_REGISTER_GLOBAL("topi.nn.adaptive_pool3d").set_body_packed([](TVMArgs args, TVMRetValue* rv) { - *rv = nn::adaptive_pool3d(args[0].cast(), args[1].cast>(), - static_cast(args[2].cast()), - args[3].cast()); -}); - -TVM_REGISTER_GLOBAL("topi.nn.pool1d").set_body_packed([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("topi.nn.adaptive_pool1d") + .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { + *rv = nn::adaptive_pool1d(args[0].cast(), args[1].cast>(), + static_cast(args[2].cast()), + args[3].cast()); + }); + +TVM_REGISTER_GLOBAL("topi.nn.adaptive_pool") + .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { + *rv = nn::adaptive_pool(args[0].cast(), args[1].cast>(), + static_cast(args[2].cast()), + args[3].cast()); + }); + +TVM_REGISTER_GLOBAL("topi.nn.adaptive_pool3d") + .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { + *rv = nn::adaptive_pool3d(args[0].cast(), args[1].cast>(), + static_cast(args[2].cast()), + args[3].cast()); + }); + +TVM_REGISTER_GLOBAL("topi.nn.pool1d").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = nn::pool1d(args[0].cast(), args[1].cast>(), args[2].cast>(), args[3].cast>(), args[4].cast>(), static_cast(args[5].cast()), args[6].cast(), args[7].cast(), args[8].cast()); }); -TVM_REGISTER_GLOBAL("topi.nn.pool2d").set_body_packed([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("topi.nn.pool2d").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = nn::pool2d(args[0].cast(), args[1].cast>(), args[2].cast>(), args[3].cast>(), args[4].cast>(), static_cast(args[5].cast()), args[6].cast(), args[7].cast(), args[8].cast()); }); -TVM_REGISTER_GLOBAL("topi.nn.pool3d").set_body_packed([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("topi.nn.pool3d").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = nn::pool3d(args[0].cast(), args[1].cast>(), args[2].cast>(), args[3].cast>(), args[4].cast>(), static_cast(args[5].cast()), @@ -166,51 +173,53 @@ TVM_REGISTER_GLOBAL("topi.nn.pool3d").set_body_packed([](TVMArgs args, TVMRetVal }); /* Ops from nn/softmax.h */ -TVM_REGISTER_GLOBAL("topi.nn.softmax").set_body_packed([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("topi.nn.softmax").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = nn::softmax(args[0].cast(), args[1].cast()); }); -TVM_REGISTER_GLOBAL("topi.nn.log_softmax").set_body_packed([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("topi.nn.log_softmax").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = nn::log_softmax(args[0].cast()); }); -TVM_REGISTER_GLOBAL("topi.nn.lrn").set_body_packed([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("topi.nn.lrn").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = nn::lrn(args[0].cast(), args[1].cast(), args[2].cast(), args[3].cast(), args[4].cast(), args[5].cast()); }); /* Ops from nn/bnn.h */ -TVM_REGISTER_GLOBAL("topi.nn.binarize_pack").set_body_packed([](TVMArgs args, TVMRetValue* rv) { - *rv = nn::binarize_pack(args[0].cast(), args[1].cast()); -}); +TVM_REGISTER_GLOBAL("topi.nn.binarize_pack") + .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { + *rv = nn::binarize_pack(args[0].cast(), args[1].cast()); + }); -TVM_REGISTER_GLOBAL("topi.nn.binary_dense").set_body_packed([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("topi.nn.binary_dense").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = nn::binary_dense(args[0].cast(), args[1].cast()); }); /* Ops from nn/layer_norm.h */ -TVM_REGISTER_GLOBAL("topi.nn.layer_norm").set_body_packed([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("topi.nn.layer_norm").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = nn::layer_norm(args[0].cast(), args[1].cast(), args[2].cast(), args[3].cast>(), args[4].cast()); }); /* Ops from nn/group_norm.h */ -TVM_REGISTER_GLOBAL("topi.nn.group_norm").set_body_packed([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("topi.nn.group_norm").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = nn::group_norm(args[0].cast(), args[1].cast(), args[2].cast(), args[3].cast(), args[4].cast(), args[5].cast>(), args[6].cast()); }); /* Ops from nn/instance_norm.h */ -TVM_REGISTER_GLOBAL("topi.nn.instance_norm").set_body_packed([](TVMArgs args, TVMRetValue* rv) { - *rv = nn::instance_norm(args[0].cast(), args[1].cast(), - args[2].cast(), args[3].cast>(), - args[4].cast()); -}); +TVM_REGISTER_GLOBAL("topi.nn.instance_norm") + .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { + *rv = nn::instance_norm(args[0].cast(), args[1].cast(), + args[2].cast(), args[3].cast>(), + args[4].cast()); + }); /* Ops from nn/rms_norm.h */ -TVM_REGISTER_GLOBAL("topi.nn.rms_norm").set_body_packed([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("topi.nn.rms_norm").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = nn::rms_norm(args[0].cast(), args[1].cast(), args[2].cast>(), args[3].cast()); }); diff --git a/src/topi/reduction.cc b/src/topi/reduction.cc index b9daa1deea7b..e1720cc0b6b0 100644 --- a/src/topi/reduction.cc +++ b/src/topi/reduction.cc @@ -32,41 +32,41 @@ namespace topi { using namespace tvm; using namespace tvm::runtime; -TVM_REGISTER_GLOBAL("topi.sum").set_body_packed([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("topi.sum").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = topi::sum(args[0].cast(), ArrayOrInt(args[1]), args[2].cast()); }); -TVM_REGISTER_GLOBAL("topi.min").set_body_packed([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("topi.min").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = topi::min(args[0].cast(), ArrayOrInt(args[1]), args[2].cast()); }); -TVM_REGISTER_GLOBAL("topi.max").set_body_packed([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("topi.max").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = topi::max(args[0].cast(), ArrayOrInt(args[1]), args[2].cast()); }); -TVM_REGISTER_GLOBAL("topi.argmin").set_body_packed([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("topi.argmin").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = topi::argmin(args[0].cast(), ArrayOrInt(args[1]), args[2].cast(), false, args[3].cast()); }); -TVM_REGISTER_GLOBAL("topi.argmax").set_body_packed([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("topi.argmax").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = topi::argmax(args[0].cast(), ArrayOrInt(args[1]), args[2].cast(), false, args[3].cast()); }); -TVM_REGISTER_GLOBAL("topi.prod").set_body_packed([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("topi.prod").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = topi::prod(args[0].cast(), ArrayOrInt(args[1]), args[2].cast()); }); -TVM_REGISTER_GLOBAL("topi.all").set_body_packed([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("topi.all").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = topi::all(args[0].cast(), ArrayOrInt(args[1]), args[2].cast()); }); -TVM_REGISTER_GLOBAL("topi.any").set_body_packed([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("topi.any").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = topi::any(args[0].cast(), ArrayOrInt(args[1]), args[2].cast()); }); -TVM_REGISTER_GLOBAL("topi.collapse_sum").set_body_packed([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("topi.collapse_sum").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = topi::collapse_sum(args[0].cast(), args[1].cast>()); }); diff --git a/src/topi/transform.cc b/src/topi/transform.cc index 2ab39be95cfa..cf86242d8491 100644 --- a/src/topi/transform.cc +++ b/src/topi/transform.cc @@ -37,54 +37,55 @@ namespace topi { using namespace tvm; using namespace tvm::runtime; -TVM_REGISTER_GLOBAL("topi.expand_dims").set_body_packed([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("topi.expand_dims").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = expand_dims(args[0].cast(), args[1].cast(), args[2].cast()); }); -TVM_REGISTER_GLOBAL("topi.transpose").set_body_packed([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("topi.transpose").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = transpose(args[0].cast(), args[1].cast>>()); }); -TVM_REGISTER_GLOBAL("topi.flip").set_body_packed([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("topi.flip").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { // pass empty seq_lengths tensor to reverse_sequence *rv = reverse_sequence(args[0].cast(), Tensor(), args[1].cast()); }); -TVM_REGISTER_GLOBAL("topi.reverse_sequence").set_body_packed([](TVMArgs args, TVMRetValue* rv) { - *rv = - reverse_sequence(args[0].cast(), args[1].cast(), args[2].cast()); -}); +TVM_REGISTER_GLOBAL("topi.reverse_sequence") + .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { + *rv = reverse_sequence(args[0].cast(), args[1].cast(), + args[2].cast()); + }); -TVM_REGISTER_GLOBAL("topi.reshape").set_body_packed([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("topi.reshape").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = reshape(args[0].cast(), args[1].cast>()); }); -TVM_REGISTER_GLOBAL("topi.sliding_window").set_body_packed([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("topi.sliding_window").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = sliding_window(args[0].cast(), args[1].cast(), args[2].cast>(), args[3].cast>()); }); -TVM_REGISTER_GLOBAL("topi.squeeze").set_body_packed([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("topi.squeeze").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = squeeze(args[0].cast(), ArrayOrInt(args[1])); }); -TVM_REGISTER_GLOBAL("topi.concatenate").set_body_packed([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("topi.concatenate").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = concatenate(args[0].cast>(), args[1].cast()); }); -TVM_REGISTER_GLOBAL("topi.stack").set_body_packed([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("topi.stack").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = stack(args[0].cast>(), args[1].cast()); }); -TVM_REGISTER_GLOBAL("topi.shape").set_body_packed([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("topi.shape").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = shape(args[0].cast(), args[1].cast()); }); -TVM_REGISTER_GLOBAL("topi.ndarray_size").set_body_packed([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("topi.ndarray_size").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = ndarray_size(args[0].cast(), args[1].cast()); }); -TVM_REGISTER_GLOBAL("topi.split").set_body_packed([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("topi.split").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { if (args[1].as()) { *rv = split_n_sections(args[0].cast(), args[1].cast(), args[2].cast()); } else { @@ -93,12 +94,13 @@ TVM_REGISTER_GLOBAL("topi.split").set_body_packed([](TVMArgs args, TVMRetValue* } }); -TVM_REGISTER_GLOBAL("topi.layout_transform").set_body_packed([](TVMArgs args, TVMRetValue* rv) { - *rv = layout_transform(args[0].cast(), args[1].cast(), - args[2].cast(), args[3].cast()); -}); +TVM_REGISTER_GLOBAL("topi.layout_transform") + .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { + *rv = layout_transform(args[0].cast(), args[1].cast(), + args[2].cast(), args[3].cast()); + }); -TVM_REGISTER_GLOBAL("topi.take").set_body_packed([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("topi.take").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { if (args.size() == 4) { auto mode = args[3].cast(); int batch_dims = args[2].cast(); @@ -113,52 +115,52 @@ TVM_REGISTER_GLOBAL("topi.take").set_body_packed([](TVMArgs args, TVMRetValue* r } }); -TVM_REGISTER_GLOBAL("topi.sequence_mask").set_body_packed([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("topi.sequence_mask").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { double pad_val = args[2].cast(); int axis = args[3].cast(); *rv = sequence_mask(args[0].cast(), args[1].cast(), pad_val, axis); }); -TVM_REGISTER_GLOBAL("topi.where").set_body_packed([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("topi.where").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = where(args[0].cast(), args[1].cast(), args[2].cast()); }); -TVM_REGISTER_GLOBAL("topi.arange").set_body_packed([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("topi.arange").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = arange(args[0].cast(), args[1].cast(), args[2].cast(), args[3].cast()); }); -TVM_REGISTER_GLOBAL("topi.meshgrid").set_body_packed([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("topi.meshgrid").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = meshgrid(args[0].cast>(), args[1].cast()); }); -TVM_REGISTER_GLOBAL("topi.repeat").set_body_packed([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("topi.repeat").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = repeat(args[0].cast(), args[1].cast(), args[2].cast()); }); -TVM_REGISTER_GLOBAL("topi.tile").set_body_packed([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("topi.tile").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = tile(args[0].cast(), args[1].cast>()); }); -TVM_REGISTER_GLOBAL("topi.gather").set_body_packed([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("topi.gather").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = gather(args[0].cast(), args[1].cast(), args[2].cast()); }); -TVM_REGISTER_GLOBAL("topi.gather_nd").set_body_packed([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("topi.gather_nd").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { int batch_dims = args[2].cast(); *rv = gather_nd(args[0].cast(), args[1].cast(), batch_dims); }); -TVM_REGISTER_GLOBAL("topi.unravel_index").set_body_packed([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("topi.unravel_index").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = unravel_index(args[0].cast(), args[1].cast()); }); -TVM_REGISTER_GLOBAL("topi.sparse_to_dense").set_body_packed([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("topi.sparse_to_dense").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = sparse_to_dense(args[0].cast(), args[1].cast>(), args[2].cast(), args[3].cast()); }); -TVM_REGISTER_GLOBAL("topi.matmul").set_body_packed([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("topi.matmul").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { switch (args.size()) { case 2: *rv = matmul(args[0].cast(), args[1].cast()); @@ -175,7 +177,7 @@ TVM_REGISTER_GLOBAL("topi.matmul").set_body_packed([](TVMArgs args, TVMRetValue* } }); -TVM_REGISTER_GLOBAL("topi.tensordot").set_body_packed([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("topi.tensordot").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { if (args.size() == 2) { *rv = tensordot(args[0].cast(), args[1].cast()); } else if (args.size() == 3) { @@ -187,7 +189,7 @@ TVM_REGISTER_GLOBAL("topi.tensordot").set_body_packed([](TVMArgs args, TVMRetVal } }); -TVM_REGISTER_GLOBAL("topi.strided_slice").set_body_packed([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("topi.strided_slice").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { Tensor x = args[0].cast(); Array begin = args[1].cast>(); Array end = args[2].cast>(); @@ -215,7 +217,7 @@ TVM_REGISTER_GLOBAL("topi.strided_slice").set_body_packed([](TVMArgs args, TVMRe }); TVM_REGISTER_GLOBAL("topi.dynamic_strided_slice") - .set_body_packed([](TVMArgs args, TVMRetValue* rv) { + .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { te::Tensor begin = args[1].cast(); te::Tensor end = args[2].cast(); te::Tensor strides = args[3].cast(); @@ -228,7 +230,7 @@ TVM_REGISTER_GLOBAL("topi.relax_dynamic_strided_slice") return relax::dynamic_strided_slice(x, begin, end, strides, output_shape); }); -TVM_REGISTER_GLOBAL("topi.one_hot").set_body_packed([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("topi.one_hot").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { int depth = args[3].cast(); int axis = args[4].cast(); DataType dtype = args[5].cast(); @@ -236,7 +238,7 @@ TVM_REGISTER_GLOBAL("topi.one_hot").set_body_packed([](TVMArgs args, TVMRetValue depth, axis, dtype); }); -TVM_REGISTER_GLOBAL("topi.matrix_set_diag").set_body_packed([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("topi.matrix_set_diag").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { int k1 = args[2].cast(); int k2 = args[3].cast(); bool super_diag_right_align = args[4].cast(); diff --git a/src/topi/utils.cc b/src/topi/utils.cc index a4084b570c04..c02744a4202d 100644 --- a/src/topi/utils.cc +++ b/src/topi/utils.cc @@ -28,19 +28,20 @@ namespace tvm { namespace topi { -TVM_REGISTER_GLOBAL("topi.utils.is_empty_shape").set_body_packed([](TVMArgs args, TVMRetValue* rv) { - *rv = topi::detail::is_empty_shape(args[0].cast>()); -}); +TVM_REGISTER_GLOBAL("topi.utils.is_empty_shape") + .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { + *rv = topi::detail::is_empty_shape(args[0].cast>()); + }); TVM_REGISTER_GLOBAL("topi.utils.bilinear_sample_nchw") - .set_body_packed([](TVMArgs args, TVMRetValue* rv) { + .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = detail::bilinear_sample_nchw(args[0].cast(), args[1].cast>(), args[2].cast(), args[3].cast()); }); TVM_REGISTER_GLOBAL("topi.utils.bilinear_sample_nhwc") - .set_body_packed([](TVMArgs args, TVMRetValue* rv) { + .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = detail::bilinear_sample_nhwc(args[0].cast(), args[1].cast>(), args[2].cast(), args[3].cast()); diff --git a/src/topi/vision.cc b/src/topi/vision.cc index 62644e584f70..dca44bf86c3c 100644 --- a/src/topi/vision.cc +++ b/src/topi/vision.cc @@ -31,7 +31,7 @@ namespace topi { using namespace tvm; using namespace tvm::runtime; -TVM_REGISTER_GLOBAL("topi.vision.reorg").set_body_packed([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("topi.vision.reorg").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = vision::reorg(args[0].cast(), args[1].cast()); }); diff --git a/tests/cpp-runtime/hexagon/run_all_tests.cc b/tests/cpp-runtime/hexagon/run_all_tests.cc index 0f9c1cb7b5f5..fa2a4aa45895 100644 --- a/tests/cpp-runtime/hexagon/run_all_tests.cc +++ b/tests/cpp-runtime/hexagon/run_all_tests.cc @@ -38,30 +38,31 @@ namespace tvm { namespace runtime { namespace hexagon { -TVM_REGISTER_GLOBAL("hexagon.run_all_tests").set_body_packed([](TVMArgs args, TVMRetValue* rv) { - // gtest args are passed into this packed func as a singular string - // split gtest args using delimiter and build argument vector - std::vector parsed_args = tvm::support::Split(args[0].cast(), ' '); - std::vector argv; +TVM_REGISTER_GLOBAL("hexagon.run_all_tests") + .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { + // gtest args are passed into this packed func as a singular string + // split gtest args using delimiter and build argument vector + std::vector parsed_args = tvm::support::Split(args[0].cast(), ' '); + std::vector argv; - // add executable name - argv.push_back(const_cast("hexagon_run_all_tests")); + // add executable name + argv.push_back(const_cast("hexagon_run_all_tests")); - // add parsed arguments - for (int i = 0; i < parsed_args.size(); ++i) { - argv.push_back(const_cast(parsed_args[i].data())); - } + // add parsed arguments + for (int i = 0; i < parsed_args.size(); ++i) { + argv.push_back(const_cast(parsed_args[i].data())); + } - // end of parsed arguments - argv.push_back(nullptr); + // end of parsed arguments + argv.push_back(nullptr); - // set argument count - int argc = argv.size() - 1; + // set argument count + int argc = argv.size() - 1; - // initialize gtest with arguments and run - ::testing::InitGoogleTest(&argc, argv.data()); - *rv = RUN_ALL_TESTS(); -}); + // initialize gtest with arguments and run + ::testing::InitGoogleTest(&argc, argv.data()); + *rv = RUN_ALL_TESTS(); + }); } // namespace hexagon } // namespace runtime diff --git a/tests/cpp-runtime/hexagon/run_unit_tests.cc b/tests/cpp-runtime/hexagon/run_unit_tests.cc index 59059fc803d1..d9331db28bee 100644 --- a/tests/cpp-runtime/hexagon/run_unit_tests.cc +++ b/tests/cpp-runtime/hexagon/run_unit_tests.cc @@ -80,42 +80,43 @@ class GtestPrinter : public testing::EmptyTestEventListener { std::string GetOutput() { return gtest_out_.str(); } }; -TVM_REGISTER_GLOBAL("hexagon.run_unit_tests").set_body_packed([](TVMArgs args, TVMRetValue* rv) { - // gtest args are passed into this packed func as a singular string - // split gtest args using delimiter and build argument vector - std::vector parsed_args = tvm::support::Split(args[0].cast(), ' '); - std::vector argv; - - // add executable name - argv.push_back(const_cast("hexagon_run_unit_tests")); - - // add parsed arguments - for (int i = 0; i < parsed_args.size(); ++i) { - argv.push_back(const_cast(parsed_args[i].data())); - } - - // end of parsed arguments - argv.push_back(nullptr); - - // set argument count - int argc = argv.size() - 1; - - // initialize gtest with arguments and run - ::testing::InitGoogleTest(&argc, argv.data()); - - // add printer to capture gtest output in a string - GtestPrinter* gprinter = new GtestPrinter(); - testing::TestEventListeners& listeners = testing::UnitTest::GetInstance()->listeners(); - listeners.Append(gprinter); - - int gtest_error_code = RUN_ALL_TESTS(); - std::string gtest_output = gprinter->GetOutput(); - std::stringstream gtest_error_code_and_output; - gtest_error_code_and_output << gtest_error_code << std::endl; - gtest_error_code_and_output << gtest_output; - *rv = gtest_error_code_and_output.str(); - delete gprinter; -}); +TVM_REGISTER_GLOBAL("hexagon.run_unit_tests") + .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { + // gtest args are passed into this packed func as a singular string + // split gtest args using delimiter and build argument vector + std::vector parsed_args = tvm::support::Split(args[0].cast(), ' '); + std::vector argv; + + // add executable name + argv.push_back(const_cast("hexagon_run_unit_tests")); + + // add parsed arguments + for (int i = 0; i < parsed_args.size(); ++i) { + argv.push_back(const_cast(parsed_args[i].data())); + } + + // end of parsed arguments + argv.push_back(nullptr); + + // set argument count + int argc = argv.size() - 1; + + // initialize gtest with arguments and run + ::testing::InitGoogleTest(&argc, argv.data()); + + // add printer to capture gtest output in a string + GtestPrinter* gprinter = new GtestPrinter(); + testing::TestEventListeners& listeners = testing::UnitTest::GetInstance()->listeners(); + listeners.Append(gprinter); + + int gtest_error_code = RUN_ALL_TESTS(); + std::string gtest_output = gprinter->GetOutput(); + std::stringstream gtest_error_code_and_output; + gtest_error_code_and_output << gtest_error_code << std::endl; + gtest_error_code_and_output << gtest_output; + *rv = gtest_error_code_and_output.str(); + delete gprinter; + }); } // namespace hexagon } // namespace runtime diff --git a/tests/python/contrib/test_hexagon/README_RPC.md b/tests/python/contrib/test_hexagon/README_RPC.md index 652237d4f53c..28300dfdea4e 100644 --- a/tests/python/contrib/test_hexagon/README_RPC.md +++ b/tests/python/contrib/test_hexagon/README_RPC.md @@ -80,12 +80,12 @@ Which eventually jumps to the following line in C++, which creates a RPC client [https://github.com/apache/tvm/blob/2cca934aad1635e3a83b712958ea83ff65704316/src/runtime/rpc/rpc_socket_impl.cc#L123-L129](https://github.com/apache/tvm/blob/2cca934aad1635e3a83b712958ea83ff65704316/src/runtime/rpc/rpc_socket_impl.cc#L123-L129) ```cpp -TVM_REGISTER_GLOBAL("rpc.Connect").set_body_packed([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("rpc.Connect").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { auto url = args[0].cast(); int port = args[1].cast(); auto key = args[2].cast(); *rv = RPCClientConnect(url, port, key, - TVMArgs(args.values + 3, args.type_codes + 3, args.size() - 3)); + ffi::PackedArgs(args.values + 3, args.type_codes + 3, args.size() - 3)); }); ``` @@ -95,7 +95,7 @@ TVM_REGISTER_GLOBAL("rpc.Connect").set_body_packed([](TVMArgs args, TVMRetValue* ```cpp TVM_REGISTER_GLOBAL("tvm.contrib.hexagon.create_hexagon_session") - .set_body_packed([](TVMArgs args, TVMRetValue* rv) { + .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { auto session_name = args[0].cast(); int remote_stack_size_bytes = args[1].cast(); HexagonTransportChannel* hexagon_channel = @@ -178,7 +178,7 @@ At first, it is not obvious where this `CopyDataFromTo` jumps to (initially I th [https://github.com/apache/tvm/blob/2cca934aad1635e3a83b712958ea83ff65704316/src/runtime/rpc/rpc_socket_impl.cc#L107](https://github.com/apache/tvm/blob/2cca934aad1635e3a83b712958ea83ff65704316/src/runtime/rpc/rpc_socket_impl.cc#L107) ```cpp -Module RPCClientConnect(std::string url, int port, std::string key, TVMArgs init_seq) { +Module RPCClientConnect(std::string url, int port, std::string key, ffi::PackedArgs init_seq) { auto endpt = RPCConnect(url, port, "client:" + key, init_seq); return CreateRPCSessionModule(CreateClientSession(endpt)); } @@ -228,7 +228,7 @@ The handler is passed to the following function [https://github.com/apache/tvm/blob/899bc064e1bf8df915bcadc979a6f37210cdce33/src/runtime/rpc/rpc_endpoint.cc#L909-L922](https://github.com/apache/tvm/blob/899bc064e1bf8df915bcadc979a6f37210cdce33/src/runtime/rpc/rpc_endpoint.cc#L909-L922) ```cpp -void RPCCopyAmongRemote(RPCSession* handler, TVMArgs args, TVMRetValue* rv) { +void RPCCopyAmongRemote(RPCSession* handler, ffi::PackedArgs args, ffi::Any* rv) { auto from = args[0].cast(); auto to = args[1].cast(); ... diff --git a/web/emcc/tvmjs_support.cc b/web/emcc/tvmjs_support.cc index fd6265f6c089..e50e6c37d34c 100644 --- a/web/emcc/tvmjs_support.cc +++ b/web/emcc/tvmjs_support.cc @@ -20,7 +20,7 @@ /* * \file tvmjs_support.cc * \brief Support functions to be linked with wasm_runtime to provide - * PackedFunc callbacks in tvmjs. + * ffi::Function callbacks in tvmjs. * We do not need to link this file in standalone wasm. */ @@ -53,9 +53,9 @@ TVM_DLL void* TVMWasmAllocSpace(int size); TVM_DLL void TVMWasmFreeSpace(void* data); /*! - * \brief Create PackedFunc from a resource handle. + * \brief Create ffi::Function from a resource handle. * \param resource_handle The handle to the resource. - * \param out The output PackedFunc. + * \param out The output ffi::Function. * \sa TVMWasmPackedCFunc, TVMWasmPackedCFuncFinalizer 3A * \return 0 if success. */ @@ -133,12 +133,12 @@ class AsyncLocalSession : public LocalSession { void AsyncCallFunc(PackedFuncHandle func, ffi::PackedArgs args, FAsyncCallback callback) final { auto it = async_func_set_.find(func); if (it != async_func_set_.end()) { - PackedFunc packed_callback([callback, this](TVMArgs args, TVMRetValue*) { + ffi::Function packed_callback([callback, this](ffi::PackedArgs args, ffi::Any*) { int code = args[0].cast(); - TVMRetValue rv; + ffi::Any rv; rv = args[1]; if (code == static_cast(RPCCode::kReturn)) { - this->EncodeReturn(std::move(rv), [&](TVMArgs encoded_args) { + this->EncodeReturn(std::move(rv), [&](ffi::PackedArgs encoded_args) { callback(RPCCode::kReturn, encoded_args); }); } else { @@ -157,13 +157,13 @@ class AsyncLocalSession : public LocalSession { } else if (func == get_time_eval_placeholder_.get()) { // special handle time evaluator. try { - PackedFunc retfunc = this->GetTimeEvaluator( + ffi::Function retfunc = this->GetTimeEvaluator( args[0].cast>(), args[1].cast(), args[2].cast(), args[3].cast(), args[4].cast(), args[5].cast(), args[6].cast(), args[7].cast(), args[8].cast(), args[9].cast()); - TVMRetValue rv; + ffi::Any rv; rv = retfunc; - this->EncodeReturn(std::move(rv), [&](TVMArgs encoded_args) { + this->EncodeReturn(std::move(rv), [&](ffi::PackedArgs encoded_args) { const void* pf = encoded_args[0].as(); ICHECK(pf != nullptr); // mark as async. @@ -225,7 +225,7 @@ class AsyncLocalSession : public LocalSession { async_wait_ = tvm::ffi::Function::GetGlobal("__async.wasm.WebGPUWaitForTasks"); } CHECK(async_wait_.has_value()); - PackedFunc packed_callback([on_complete](TVMArgs args, TVMRetValue*) { + ffi::Function packed_callback([on_complete](ffi::PackedArgs args, ffi::Any*) { int code = args[0].cast(); on_complete(static_cast(code), args.Slice(1)); }); @@ -237,14 +237,14 @@ class AsyncLocalSession : public LocalSession { private: std::unordered_set async_func_set_; - std::unique_ptr get_time_eval_placeholder_ = std::make_unique(); - std::optional async_wait_; + std::unique_ptr get_time_eval_placeholder_ = std::make_unique(); + std::optional async_wait_; // time evaluator - PackedFunc GetTimeEvaluator(Optional opt_mod, std::string name, int device_type, - int device_id, int number, int repeat, int min_repeat_ms, - int limit_zero_time_iterations, int cooldown_interval_ms, - int repeats_to_cooldown) { + ffi::Function GetTimeEvaluator(Optional opt_mod, std::string name, int device_type, + int device_id, int number, int repeat, int min_repeat_ms, + int limit_zero_time_iterations, int cooldown_interval_ms, + int repeats_to_cooldown) { Device dev; dev.device_type = static_cast(device_type); dev.device_id = device_id; @@ -265,29 +265,29 @@ class AsyncLocalSession : public LocalSession { } // time evaluator - PackedFunc WrapWasmTimeEvaluator(PackedFunc pf, Device dev, int number, int repeat, - int min_repeat_ms, int limit_zero_time_iterations, - int cooldown_interval_ms, int repeats_to_cooldown) { + ffi::Function WrapWasmTimeEvaluator(ffi::Function pf, Device dev, int number, int repeat, + int min_repeat_ms, int limit_zero_time_iterations, + int cooldown_interval_ms, int repeats_to_cooldown) { auto ftimer = [pf, dev, number, repeat, min_repeat_ms, limit_zero_time_iterations, - cooldown_interval_ms, repeats_to_cooldown](TVMArgs args, TVMRetValue* rv) { + cooldown_interval_ms, repeats_to_cooldown](ffi::PackedArgs args, ffi::Any* rv) { // the function is a async function. - PackedFunc on_complete = args[args.size() - 1].cast(); + ffi::Function on_complete = args[args.size() - 1].cast(); std::vector packed_args(args.data(), args.data() + args.size() - 1); auto finvoke = [pf, packed_args](int n) { - TVMRetValue temp; - TVMArgs invoke_args(packed_args.data(), packed_args.size()); + ffi::Any temp; + ffi::PackedArgs invoke_args(packed_args.data(), packed_args.size()); for (int i = 0; i < n; ++i) { pf.CallPacked(invoke_args, &temp); } }; auto time_exec = tvm::ffi::Function::GetGlobal("__async.wasm.TimeExecution"); CHECK(time_exec.has_value()) << "Cannot find wasm.GetTimer in the global function"; - (*time_exec)(TypedPackedFunc(finvoke), dev, number, repeat, min_repeat_ms, + (*time_exec)(ffi::TypedFunction(finvoke), dev, number, repeat, min_repeat_ms, limit_zero_time_iterations, cooldown_interval_ms, repeats_to_cooldown, /*cache_flush_bytes=*/0, on_complete); }; - return PackedFunc(ftimer); + return ffi::Function(ftimer); } }; diff --git a/web/emcc/wasm_runtime.cc b/web/emcc/wasm_runtime.cc index 1ffd0ed29cdc..b8ebadff4f5c 100644 --- a/web/emcc/wasm_runtime.cc +++ b/web/emcc/wasm_runtime.cc @@ -85,7 +85,7 @@ int TVMBackendParallelLaunch(FTVMParallelLambda flambda, void* cdata, int num_ta int TVMBackendParallelBarrier(int task_id, TVMParallelGroupEnv* penv) { return 0; } -// --- Environment PackedFuncs for testing --- +// --- Environment ffi::Functions for testing --- namespace tvm { namespace runtime { namespace detail { @@ -107,40 +107,44 @@ void LogMessageImpl(const std::string& file, int lineno, int level, const std::s } // namespace detail -TVM_REGISTER_GLOBAL("testing.echo").set_body_packed([](TVMArgs args, TVMRetValue* ret) { +TVM_REGISTER_GLOBAL("testing.echo").set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { *ret = args[0]; }); -TVM_REGISTER_GLOBAL("testing.call").set_body_packed([](TVMArgs args, TVMRetValue* ret) { - (args[0].cast()).CallPacked(args.Slice(1), ret); +TVM_REGISTER_GLOBAL("testing.call").set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { + (args[0].cast()).CallPacked(args.Slice(1), ret); }); -TVM_REGISTER_GLOBAL("testing.ret_string").set_body_packed([](TVMArgs args, TVMRetValue* ret) { +TVM_REGISTER_GLOBAL("testing.ret_string").set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { *ret = args[0].cast(); }); -TVM_REGISTER_GLOBAL("testing.log_info_str").set_body_packed([](TVMArgs args, TVMRetValue* ret) { - LOG(INFO) << args[0].cast(); -}); +TVM_REGISTER_GLOBAL("testing.log_info_str") + .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { + LOG(INFO) << args[0].cast(); + }); -TVM_REGISTER_GLOBAL("testing.log_fatal_str").set_body_packed([](TVMArgs args, TVMRetValue* ret) { - LOG(FATAL) << args[0].cast(); -}); +TVM_REGISTER_GLOBAL("testing.log_fatal_str") + .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { + LOG(FATAL) << args[0].cast(); + }); TVM_REGISTER_GLOBAL("testing.add_one").set_body_typed([](int x) { return x + 1; }); -TVM_REGISTER_GLOBAL("testing.wrap_callback").set_body_packed([](TVMArgs args, TVMRetValue* ret) { - PackedFunc pf = args[0].cast(); - *ret = runtime::TypedPackedFunc([pf]() { pf(); }); -}); +TVM_REGISTER_GLOBAL("testing.wrap_callback") + .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { + ffi::Function pf = args[0].cast(); + *ret = ffi::TypedFunction([pf]() { pf(); }); + }); // internal function used for debug and testing purposes -TVM_REGISTER_GLOBAL("testing.object_use_count").set_body_packed([](TVMArgs args, TVMRetValue* ret) { - auto obj = args[0].cast(); - // subtract the current one because we always copy - // and get another value. - *ret = (obj.use_count() - 1); -}); +TVM_REGISTER_GLOBAL("testing.object_use_count") + .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { + auto obj = args[0].cast(); + // subtract the current one because we always copy + // and get another value. + *ret = (obj.use_count() - 1); + }); void ArrayDecodeStorage(NDArray cpu_arr, std::string bytes, std::string format, std::string dtype) { if (format == "f32-to-bf16" && dtype == "float32") { @@ -167,7 +171,7 @@ TVM_REGISTER_GLOBAL("tvmjs.array.decode_storage").set_body_typed(ArrayDecodeStor // Concatenate n TVMArrays TVM_REGISTER_GLOBAL("tvmjs.runtime.ArrayConcat") - .set_body_packed([](TVMArgs args, TVMRetValue* ret) { + .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { std::vector data; for (int i = 0; i < args.size(); ++i) { // Get i-th TVMArray @@ -217,7 +221,7 @@ NDArray ConcatEmbeddings(const std::vector& embeddings) { // Concatenate n NDArrays TVM_REGISTER_GLOBAL("tvmjs.runtime.ConcatEmbeddings") - .set_body_packed([](TVMArgs args, TVMRetValue* ret) { + .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { std::vector embeddings; for (int i = 0; i < args.size(); ++i) { embeddings.push_back(args[i].cast()); diff --git a/web/emcc/webgpu_runtime.cc b/web/emcc/webgpu_runtime.cc index 1aafc272c385..3d74d77f14ce 100644 --- a/web/emcc/webgpu_runtime.cc +++ b/web/emcc/webgpu_runtime.cc @@ -60,7 +60,7 @@ class WebGPUDeviceAPI : public DeviceAPI { WebGPUDeviceAPI() { auto fp = tvm::ffi::Function::GetGlobal("wasm.WebGPUDeviceAPI"); CHECK(fp.has_value()) << "Cannot find wasm.WebGPUContext in the env"; - auto getter = TypedPackedFunc(*fp); + auto getter = ffi::TypedFunction(*fp); alloc_space_ = getter("deviceAllocDataSpace"); free_space_ = getter("deviceFreeDataSpace"); copy_to_gpu_ = getter("deviceCopyToGPU"); @@ -69,7 +69,7 @@ class WebGPUDeviceAPI : public DeviceAPI { } void SetDevice(Device dev) final {} - void GetAttr(Device dev, DeviceAttrKind kind, TVMRetValue* rv) final { + void GetAttr(Device dev, DeviceAttrKind kind, ffi::Any* rv) final { if (kind == kExist) { *rv = 1; } @@ -137,12 +137,13 @@ class WebGPUDeviceAPI : public DeviceAPI { private: // NOTE: js return number as double. - TypedPackedFunc alloc_space_; - TypedPackedFunc free_space_; - TypedPackedFunc copy_to_gpu_; - TypedPackedFunc copy_from_gpu_; - TypedPackedFunc + ffi::TypedFunction alloc_space_; + ffi::TypedFunction free_space_; + ffi::TypedFunction copy_to_gpu_; + ffi::TypedFunction + copy_from_gpu_; + ffi::TypedFunction copy_within_gpu_; }; @@ -165,26 +166,26 @@ class WebGPUModuleNode final : public runtime::ModuleNode { const char* type_key() const final { return "webgpu"; } - PackedFunc GetFunction(const String& name, const ObjectPtr& sptr_to_self) final { + ffi::Function GetFunction(const String& name, const ObjectPtr& sptr_to_self) final { // special function if (name == "webgpu.get_fmap") { - return PackedFunc([this](TVMArgs args, TVMRetValue* rv) { + return ffi::Function([this](ffi::PackedArgs args, ffi::Any* rv) { std::ostringstream os; dmlc::JSONWriter writer(&os); writer.Write(fmap_); *rv = os.str(); }); } else if (name == "webgpu.get_shader") { - return PackedFunc([this](TVMArgs args, TVMRetValue* rv) { + return ffi::Function([this](ffi::PackedArgs args, ffi::Any* rv) { auto name = args[0].cast(); auto it = smap_.find(name); ICHECK(it != smap_.end()) << "Cannot find code " << name; *rv = it->second; }); } else if (name == "webgpu.update_prebuild") { - return PackedFunc([this](TVMArgs args, TVMRetValue* rv) { + return ffi::Function([this](ffi::PackedArgs args, ffi::Any* rv) { auto name = args[0].cast(); - PackedFunc func = args[1].cast(); + ffi::Function func = args[1].cast(); prebuild_[name] = func; }); } @@ -203,7 +204,7 @@ class WebGPUModuleNode final : public runtime::ModuleNode { info.Save(&writer); return create_shader_(os.str(), it->second); } else { - return PackedFunc(nullptr); + return ffi::Function(nullptr); } } @@ -224,9 +225,9 @@ class WebGPUModuleNode final : public runtime::ModuleNode { // The source std::string source_; // prebuild_ functions - std::unordered_map prebuild_; + std::unordered_map prebuild_; // Callback to get the GPU function. - TypedPackedFunc create_shader_; + ffi::TypedFunction create_shader_; }; Module WebGPUModuleLoadBinary(void* strm) { @@ -242,7 +243,7 @@ Module WebGPUModuleLoadBinary(void* strm) { // for now webgpu is hosted via a vulkan module. TVM_REGISTER_GLOBAL("runtime.module.loadbinary_webgpu").set_body_typed(WebGPUModuleLoadBinary); -TVM_REGISTER_GLOBAL("device_api.webgpu").set_body_packed([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("device_api.webgpu").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { DeviceAPI* ptr = WebGPUDeviceAPI::Global(); *rv = static_cast(ptr); });