diff --git a/docs/arch/device_target_interactions.rst b/docs/arch/device_target_interactions.rst index 09867f88fa36..6015c4351076 100644 --- a/docs/arch/device_target_interactions.rst +++ b/docs/arch/device_target_interactions.rst @@ -153,7 +153,10 @@ then be registered with the following steps. #. Register the function to the tvm registry:: - TVM_FFI_REGISTER_GLOBAL("device_api.foo").set_body_typed(FooDeviceAPI::Global); + TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("device_api.foo", FooDeviceAPI::Global); + }); .. _base.h: https://github.com/apache/tvm/blob/main/include/tvm/runtime/base.h @@ -164,7 +167,7 @@ then be registered with the following steps. #. Add a case in ``DeviceName`` in `device_api.h`_ to convert from the enum value to a string representation. This string representation - should match the name given to ``TVM_FFI_REGISTER_GLOBAL``. + should match the name given to ``GlobalDef().def``. #. Add entries to the ``DEVICE_TYPE_TO_NAME`` and ``DEVICE_NAME_TO_TYPE`` dictionaries of :py:class:`tvm.runtime.Device` for the new enum value. @@ -225,7 +228,10 @@ the same name as was used in the ``TVM_REGISTER_TARGET_KIND`` definition above. :: tvm::runtime::Module GeneratorFooCode(IRModule mod, Target target); - TVM_FFI_REGISTER_GLOBAL("target.build.foo").set_body_typed(GeneratorFooCode); + TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("target.build.foo", GeneratorFooCode); + }); The code generator takes two arguments. The first is the ``IRModule`` to compile, and the second is the ``Target`` that describes the device diff --git a/docs/arch/pass_infra.rst b/docs/arch/pass_infra.rst index c54ba18b0add..4bf3abceb0ca 100644 --- a/docs/arch/pass_infra.rst +++ b/docs/arch/pass_infra.rst @@ -376,8 +376,10 @@ Python when needed. 
return CreateFunctionPass(pass_func, 0, "FoldConstant", {}); } - TVM_FFI_REGISTER_GLOBAL("relax.transform.FoldConstant") - .set_body_typed(FoldConstant); + TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.transform.FoldConstant", FoldConstant); + }); } // namespace transform diff --git a/docs/arch/runtime.rst b/docs/arch/runtime.rst index 613c7d86e19e..8c2a9a0995f9 100644 --- a/docs/arch/runtime.rst +++ b/docs/arch/runtime.rst @@ -80,8 +80,10 @@ The following example registers PackedFunc in C++ and calls from python. .. code:: c // register a global packed function in c++ - TVM_FFI_REGISTER_GLOBAL("myadd") - .set_body_packed(MyAdd); + TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def_packed("myadd", MyAdd); + }); .. code:: python @@ -110,10 +112,12 @@ we can pass functions from python (as PackedFunc) to C++. .. code:: c - TVM_FFI_REGISTER_GLOBAL("callhello") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - PackedFunc f = args[0]; - f("hello world"); + TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def_packed("callhello", [](ffi::PackedArgs args, ffi::Any* rv) { + ffi::Function f = args[0].cast(); + f("hello world"); + }); }); .. code:: python diff --git a/ffi/include/tvm/ffi/function.h b/ffi/include/tvm/ffi/function.h index 34d994b64cc5..5a30f25a7b5b 100644 --- a/ffi/include/tvm/ffi/function.h +++ b/ffi/include/tvm/ffi/function.h @@ -756,135 +756,6 @@ struct TypeTraits> : public TypeTraitsBase { TVM_FFI_INLINE static std::string TypeStr() { return details::FunctionInfo::Sig(); } }; -/*! \brief Registry for global function */ -class Function::Registry { - public: - /*! \brief constructor */ - explicit Registry(const char* name) : name_(name) {} - - /*! - * \brief Set body to be to use the packed convention. - * - * \tparam FLambda The signature of the function. - * \param f The body of the function. 
- */ - template - Registry& set_body_packed(FLambda f) { - return Register(ffi::Function::FromPacked(f)); - } - /*! - * \brief set the body of the function to the given function. - * Note that this will ignore default arg values and always require all arguments to be - * provided. - * - * \code - * - * int multiply(int x, int y) { - * return x * y; - * } - * - * TVM_FFI_REGISTER_GLOBAL("multiply") - * .set_body_typed(multiply); // will have type int(int, int) - * - * // will have type int(int, int) - * TVM_FFI_REGISTER_GLOBAL("sub") - * .set_body_typed([](int a, int b) -> int { return a - b; }); - * - * \endcode - * - * \param f The function to forward to. - * \tparam FLambda The signature of the function. - */ - template - Registry& set_body_typed(FLambda f) { - return Register(Function::FromTyped(f, name_)); - } - - /*! - * \brief set the body of the function to be the passed method pointer. - * Note that this will ignore default arg values and always require all arguments to be - * provided. - * - * \code - * - * // objectRef subclass: - * struct Example : ObjectRef { - * int DoThing(int x); - * } - * TVM_FFI_REGISTER_GLOBAL("Example_DoThing") - * .set_body_method(&Example::DoThing); // will have type int(self, int) - * - * // Object subclass: - * struct Example : Object { - * int DoThing(int x); - * } - * - * TVM_FFI_REGISTER_GLOBAL("Example_DoThing") - * .set_body_method(&Example::DoThing); // will have type int(self, int) - * - * \endcode - * - * \param f the method pointer to forward to. - * \tparam T the type containing the method (inferred). - * \tparam R the return type of the function (inferred). - * \tparam Args the argument types of the function (inferred). - */ - template - Registry& set_body_method(R (T::*f)(Args...)) { - static_assert(std::is_base_of_v || std::is_base_of_v, - "T must be derived from ObjectRef or Object"); - if constexpr (std::is_base_of_v) { - auto fwrap = [f](T target, Args... 
params) -> R { - // call method pointer - return (target.*f)(std::forward(params)...); - }; - return Register(ffi::Function::FromTyped(fwrap, name_)); - } - if constexpr (std::is_base_of_v) { - auto fwrap = [f](const T* target, Args... params) -> R { - // call method pointer - return (const_cast(target)->*f)(std::forward(params)...); - }; - return Register(ffi::Function::FromTyped(fwrap, name_)); - } - return *this; - } - - template - Registry& set_body_method(R (T::*f)(Args...) const) { - static_assert(std::is_base_of_v || std::is_base_of_v, - "T must be derived from ObjectRef or Object"); - if constexpr (std::is_base_of_v) { - auto fwrap = [f](const T target, Args... params) -> R { - // call method pointer - return (target.*f)(std::forward(params)...); - }; - return Register(ffi::Function::FromTyped(fwrap, name_)); - } - if constexpr (std::is_base_of_v) { - auto fwrap = [f](const T* target, Args... params) -> R { - // call method pointer - return (target->*f)(std::forward(params)...); - }; - return Register(ffi::Function::FromTyped(fwrap, name_)); - } - return *this; - } - - protected: - /*! - * \brief set the body of the function to be f - * \param f The body of the function. - */ - Registry& Register(Function f) { - Function::SetGlobal(name_, f); - return *this; - } - - /*! \brief name of the function */ - const char* name_; -}; - /*! * \brief helper function to get type index from key */ @@ -895,21 +766,6 @@ inline int32_t TypeKeyToIndex(std::string_view type_key) { return type_index; } -#define TVM_FFI_FUNC_REG_VAR_DEF \ - TVM_FFI_ATTRIBUTE_UNUSED static inline ::tvm::ffi::Function::Registry& __##TVMFFIFuncReg - -/*! - * \brief Register a function globally. - * \code - * TVM_FFI_REGISTER_GLOBAL("MyAdd") - * .set_body_typed([](int a, int b) { - * return a + b; - * }); - * \endcode - */ -#define TVM_FFI_REGISTER_GLOBAL(OpName) \ - TVM_FFI_STR_CONCAT(TVM_FFI_FUNC_REG_VAR_DEF, __COUNTER__) = ::tvm::ffi::Function::Registry(OpName) - /*! 
* \brief Export typed function as a SafeCallType symbol. * diff --git a/include/tvm/runtime/profiling.h b/include/tvm/runtime/profiling.h index c43543cc3863..cf467870c60c 100644 --- a/include/tvm/runtime/profiling.h +++ b/include/tvm/runtime/profiling.h @@ -134,8 +134,11 @@ class Timer : public ObjectRef { * }; * TVM_REGISTER_OBJECT_TYPE(CPUTimerNode); * - * TVM_FFI_REGISTER_GLOBAL("profiling.timer.cpu").set_body_typed([](Device dev) { - * return Timer(make_object()); + * TVM_FFI_STATIC_INIT_BLOCK({ + * namespace refl = tvm::ffi::reflection; + * refl::GlobalDef().def("profiling.timer.cpu", [](Device dev) { + * return Timer(make_object()); + * }); * }); * \endcode */ diff --git a/python/tvm/contrib/msc/plugin/codegen/sources.py b/python/tvm/contrib/msc/plugin/codegen/sources.py index b507d7b82557..a4e89ad7ecd2 100644 --- a/python/tvm/contrib/msc/plugin/codegen/sources.py +++ b/python/tvm/contrib/msc/plugin/codegen/sources.py @@ -684,6 +684,17 @@ class TVMUtils { return cuda_dev; } }; + +#define TVM_MSC_PLUGIN_REGISTER_GLOBAL_DEF(FuncName, Body) \ + TVM_FFI_STATIC_INIT_BLOCK({ \ + tvm::ffi::reflection::GlobalDef().def(FuncName, Body); \ + }) + +#define TVM_MSC_PLUGIN_REGISTER_GLOBAL_DEF_PACKED(FuncName, Body) \ + TVM_FFI_STATIC_INIT_BLOCK({ \ + tvm::ffi::reflection::GlobalDef().def_packed(FuncName, Body); \ + }) + #endif // PLUGIN_SUPPORT_TVM """ @@ -1101,6 +1112,7 @@ def get_plugin_utils_h_code() -> str: #ifdef PLUGIN_SUPPORT_TVM #include +#include #include "tvm/../../src/contrib/msc/core/transform/layout_utils.h" #include "tvm/../../src/contrib/msc/core/utils.h" diff --git a/python/tvm/relax/frontend/nn/op.py b/python/tvm/relax/frontend/nn/op.py index ab416ef14176..1e42c862fee6 100644 --- a/python/tvm/relax/frontend/nn/op.py +++ b/python/tvm/relax/frontend/nn/op.py @@ -2087,7 +2087,7 @@ def extern( out: OutType, ) -> OutType: """Invoke an extern function during runtime. 
The extern function must be registered with the " - TVM runtime using `TVM_FFI_REGISTER_GLOBAL` (C++), or `tvm.register_func` (Python). + TVM runtime using `reflection::GlobalDef().def` (C++), or `tvm.register_func` (Python). Parameters ---------- diff --git a/python/tvm/runtime/_ffi_api.py b/python/tvm/runtime/_ffi_api.py index 71f96983ee18..88a49f3a63d9 100644 --- a/python/tvm/runtime/_ffi_api.py +++ b/python/tvm/runtime/_ffi_api.py @@ -17,6 +17,5 @@ """FFI APIs for tvm.runtime""" import tvm.ffi -# Exports functions registered via TVM_FFI_REGISTER_GLOBAL with the "runtime" prefix. -# e.g. TVM_FFI_REGISTER_GLOBAL("runtime.ModuleLoadFromFile") +# Exports functions registered in the "runtime" namespace. tvm.ffi._init_api("runtime", __name__) diff --git a/python/tvm/runtime/_ffi_node_api.py b/python/tvm/runtime/_ffi_node_api.py index 493dfceab59e..aef9ded9cc0d 100644 --- a/python/tvm/runtime/_ffi_node_api.py +++ b/python/tvm/runtime/_ffi_node_api.py @@ -24,7 +24,6 @@ # The implementations below are default ones when the corresponding # functions are not available in the runtime only mode. -# They will be overriden via _init_api to the ones registered -# via TVM_FFI_REGISTER_GLOBAL in the compiler mode. +# They will be overridden via _init_api by the functions registered in the compiler mode. def AsRepr(obj): return type(obj).__name__ + "(" + obj.__ctypes_handle__().value + ")" @@ -45,6 +44,5 @@ def LoadJSON(json_str): raise RuntimeError("Do not support object serialization in runtime only mode") -# Exports functions registered via TVM_FFI_REGISTER_GLOBAL with the "node" prefix. -# e.g. TVM_FFI_REGISTER_GLOBAL("node.AsRepr") +# Exports functions registered in the "node" namespace.
tvm.ffi._init_api("node", __name__) diff --git a/src/contrib/msc/plugin/tvm_codegen.cc b/src/contrib/msc/plugin/tvm_codegen.cc index 08cd3d7da6e0..a3861aabe75e 100644 --- a/src/contrib/msc/plugin/tvm_codegen.cc +++ b/src/contrib/msc/plugin/tvm_codegen.cc @@ -215,14 +215,12 @@ void TVMPluginCodeGen::CodeGenOpDefine(const Plugin& plugin) { stack_.func_end("infer_output"); // register funcs - stack_.func_call("TVM_FFI_REGISTER_GLOBAL") + stack_.func_call("TVM_MSC_PLUGIN_REGISTER_GLOBAL_DEF") .call_arg(DocUtils::ToStr("msc.plugin.op.InferStructInfo" + plugin->name)) - .method_call("set_body_typed") .call_arg("InferStructInfo" + plugin->name) .line() - .func_call("TVM_FFI_REGISTER_GLOBAL") + .func_call("TVM_MSC_PLUGIN_REGISTER_GLOBAL_DEF") .call_arg(DocUtils::ToStr("msc.plugin.op.InferLayout" + plugin->name)) - .method_call("set_body_typed") .call_arg("InferLayout" + plugin->name) .line(); } @@ -262,9 +260,8 @@ void TVMPluginCodeGen::CodeGenOpRuntime(const Plugin& plugin) { CodeGenCompute(plugin, "cpu"); stack_.cond_end().func_end(); // register the compute - stack_.func_call("TVM_FFI_REGISTER_GLOBAL") + stack_.func_call("TVM_MSC_PLUGIN_REGISTER_GLOBAL_DEF_PACKED") .call_arg(DocUtils::ToStr(plugin->name)) - .method_call("set_body") .call_arg(func_name) .line(); } diff --git a/src/relax/op/op_common.h b/src/relax/op/op_common.h index d7d50f8fa714..4da8b18fcb13 100644 --- a/src/relax/op/op_common.h +++ b/src/relax/op/op_common.h @@ -181,7 +181,8 @@ std::tuple GetArgStructInfo(const Call& call, const BlockBuilder& c static const Op& op = Op::Get("relax." OpRegName); \ return Call(op, {std::move(x)}, Attrs(), {}); \ } \ - TVM_FFI_REGISTER_GLOBAL("relax.op." OpRegName).set_body_typed(OpName) + TVM_FFI_STATIC_INIT_BLOCK( \ + { tvm::ffi::reflection::GlobalDef().def("relax.op." 
OpRegName, OpName); }) /************ Utilities ************/ diff --git a/src/relax/op/tensor/binary.h b/src/relax/op/tensor/binary.h index ae36d45b3683..f612ec0598a9 100644 --- a/src/relax/op/tensor/binary.h +++ b/src/relax/op/tensor/binary.h @@ -42,7 +42,8 @@ namespace relax { static const Op& op = Op::Get("relax." #OpName); \ return Call(op, {x1, x2}, Attrs(), {}); \ } \ - TVM_FFI_REGISTER_GLOBAL("relax.op." #OpName).set_body_typed(OpName); \ + TVM_FFI_STATIC_INIT_BLOCK( \ + { tvm::ffi::reflection::GlobalDef().def("relax.op." #OpName, OpName); }); \ TVM_REGISTER_OP("relax." #OpName) \ .set_num_inputs(2) \ .add_argument("x1", "Tensor", "The first input tensor.") \ diff --git a/src/relax/op/tensor/search.cc b/src/relax/op/tensor/search.cc index a3d17be1359b..6808acdedf3a 100644 --- a/src/relax/op/tensor/search.cc +++ b/src/relax/op/tensor/search.cc @@ -256,7 +256,8 @@ StructInfo InferStructInfoArgmaxArgmin(const Call& call, const BlockBuilder& ctx static const Op& op = Op::Get("relax." #OpName); \ return Call(op, {std::move(x)}, Attrs(attrs)); \ } \ - TVM_FFI_REGISTER_GLOBAL("relax.op." #OpName).set_body_typed(OpName); \ + TVM_FFI_STATIC_INIT_BLOCK( \ + { tvm::ffi::reflection::GlobalDef().def("relax.op." #OpName, OpName); }); \ TVM_REGISTER_OP("relax." #OpName) \ .set_num_inputs(1) \ .add_argument("x", "Tensor", "The input data tensor") \ diff --git a/src/relax/op/tensor/statistical.h b/src/relax/op/tensor/statistical.h index 331562454efe..e79ce1d4aeaa 100644 --- a/src/relax/op/tensor/statistical.h +++ b/src/relax/op/tensor/statistical.h @@ -50,7 +50,8 @@ namespace relax { static const Op& op = Op::Get("relax." #OpName); \ return Call(op, {std::move(x)}, Attrs{attrs}, {}); \ } \ - TVM_FFI_REGISTER_GLOBAL("relax.op." #OpName).set_body_typed(OpName); \ + TVM_FFI_STATIC_INIT_BLOCK( \ + { tvm::ffi::reflection::GlobalDef().def("relax.op." #OpName, OpName); }); \ TVM_REGISTER_OP("relax." 
#OpName) \ .set_num_inputs(1) \ .add_argument("x", "Tensor", "The input data tensor") \ diff --git a/src/runtime/device_api.cc b/src/runtime/device_api.cc index 3eb817fd0f3b..06eb0284f9d0 100644 --- a/src/runtime/device_api.cc +++ b/src/runtime/device_api.cc @@ -193,19 +193,17 @@ TVM_FFI_STATIC_INIT_BLOCK({ }); }); -// set device api -TVM_FFI_REGISTER_GLOBAL(tvm::runtime::symbol::tvm_set_device) - .set_body_packed([](tvm::ffi::PackedArgs args, tvm::ffi::Any* ret) { - DLDevice dev; - dev.device_type = static_cast(args[0].cast()); - dev.device_id = args[1].cast(); - DeviceAPIManager::Get(dev)->SetDevice(dev); - }); - // set device api TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef() + .def_packed(tvm::runtime::symbol::tvm_set_device, + [](tvm::ffi::PackedArgs args, tvm::ffi::Any* ret) { + DLDevice dev; + dev.device_type = static_cast(args[0].cast()); + dev.device_id = args[1].cast(); + DeviceAPIManager::Get(dev)->SetDevice(dev); + }) .def_packed("runtime.GetDeviceAttr", [](tvm::ffi::PackedArgs args, tvm::ffi::Any* ret) { DLDevice dev; diff --git a/src/runtime/disco/nccl/nccl.cc b/src/runtime/disco/nccl/nccl.cc index f3f79c9ccc55..9e41bbd0deb5 100644 --- a/src/runtime/disco/nccl/nccl.cc +++ b/src/runtime/disco/nccl/nccl.cc @@ -331,64 +331,49 @@ TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("runtime.disco.compiled_ccl", - []() -> String { return TVM_DISCO_CCL_NAME; }); + []() -> String { return TVM_DISCO_CCL_NAME; }) + .def("runtime.disco." TVM_DISCO_CCL_NAME ".init_ccl", InitCCL) + .def("runtime.disco." TVM_DISCO_CCL_NAME ".init_ccl_per_worker", InitCCLPerWorker) + .def("runtime.disco." TVM_DISCO_CCL_NAME ".allreduce", + [](NDArray send, int kind, bool in_group, NDArray recv) { + CHECK(0 <= kind && kind <= 4) << "ValueError: Unknown ReduceKind: " << kind; + nccl::AllReduce(send, static_cast(kind), in_group, recv); + }) + .def("runtime.disco."
TVM_DISCO_CCL_NAME ".allgather", + [](NDArray send, bool in_group, NDArray recv) { nccl::AllGather(send, in_group, recv); }) + .def("runtime.disco." TVM_DISCO_CCL_NAME ".broadcast_from_worker0", BroadcastFromWorker0) + .def("runtime.disco." TVM_DISCO_CCL_NAME ".scatter_from_worker0", ScatterFromWorker0) + .def("runtime.disco." TVM_DISCO_CCL_NAME ".gather_to_worker0", GatherToWorker0) + .def("runtime.disco." TVM_DISCO_CCL_NAME ".recv_from_worker0", RecvFromWorker0) + .def("runtime.disco." TVM_DISCO_CCL_NAME ".send_to_next_group", SendToNextGroup) + .def("runtime.disco." TVM_DISCO_CCL_NAME ".recv_from_prev_group", RecvFromPrevGroup) + .def("runtime.disco." TVM_DISCO_CCL_NAME ".send_to_worker", SendToWorker) + .def("runtime.disco." TVM_DISCO_CCL_NAME ".recv_from_worker", RecvFromWorker) + .def("runtime.disco." TVM_DISCO_CCL_NAME ".sync_worker", SyncWorker) + .def("runtime.disco." TVM_DISCO_CCL_NAME ".test_send_to_next_group_recv_from_prev_group", + [](NDArray buffer) { + CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get(); + CHECK_EQ(ctx->worker->num_workers, 4) << "The test requires the world size to be 4."; + CHECK_EQ(ctx->worker->num_groups, 2) << "The test requires the group size to be 2."; + int group_size = ctx->worker->num_workers / ctx->worker->num_groups; + int group_id = ctx->worker->worker_id / group_size; + if (group_id == 0) { + tvm::runtime::nccl::SendToNextGroup(buffer); + } else { + tvm::runtime::nccl::RecvFromPrevGroup(buffer); + } + }) + .def("runtime.disco." 
TVM_DISCO_CCL_NAME ".test_worker2_sends_to_worker0", + [](NDArray buffer) { + CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get(); + CHECK_EQ(ctx->worker->num_workers, 4) << "The test requires the world size to be 4."; + CHECK_EQ(ctx->worker->num_groups, 2) << "The test requires the group size to be 2."; + if (ctx->worker->worker_id == 2) { + tvm::runtime::nccl::SendToWorker(buffer, 0); + } else if (ctx->worker->worker_id == 0) { + tvm::runtime::nccl::RecvFromWorker(buffer, 2); + } + }); }); -TVM_FFI_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".init_ccl").set_body_typed(InitCCL); -TVM_FFI_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".init_ccl_per_worker") - .set_body_typed(InitCCLPerWorker); -TVM_FFI_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".allreduce") - .set_body_typed([](NDArray send, int kind, bool in_group, NDArray recv) { - CHECK(0 <= kind && kind <= 4) << "ValueError: Unknown ReduceKind: " << kind; - nccl::AllReduce(send, static_cast(kind), in_group, recv); - }); -TVM_FFI_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".allgather") - .set_body_typed([](NDArray send, bool in_group, NDArray recv) { - nccl::AllGather(send, in_group, recv); - }); -TVM_FFI_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".broadcast_from_worker0") - .set_body_typed(BroadcastFromWorker0); -TVM_FFI_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".scatter_from_worker0") - .set_body_typed(ScatterFromWorker0); -TVM_FFI_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".gather_to_worker0") - .set_body_typed(GatherToWorker0); -TVM_FFI_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".recv_from_worker0") - .set_body_typed(RecvFromWorker0); -TVM_FFI_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".send_to_next_group") - .set_body_typed(SendToNextGroup); -TVM_FFI_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".recv_from_prev_group") - .set_body_typed(RecvFromPrevGroup); -TVM_FFI_REGISTER_GLOBAL("runtime.disco." 
TVM_DISCO_CCL_NAME ".send_to_worker") - .set_body_typed(SendToWorker); -TVM_FFI_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".recv_from_worker") - .set_body_typed(RecvFromWorker); -TVM_FFI_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".sync_worker") - .set_body_typed(SyncWorker); - -TVM_FFI_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME - ".test_send_to_next_group_recv_from_prev_group") - .set_body_typed([](NDArray buffer) { - CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get(); - CHECK_EQ(ctx->worker->num_workers, 4) << "The test requires the world size to be 4."; - CHECK_EQ(ctx->worker->num_groups, 2) << "The test requires the group size to be 2."; - int group_size = ctx->worker->num_workers / ctx->worker->num_groups; - int group_id = ctx->worker->worker_id / group_size; - if (group_id == 0) { - tvm::runtime::nccl::SendToNextGroup(buffer); - } else { - tvm::runtime::nccl::RecvFromPrevGroup(buffer); - } - }); - -TVM_FFI_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".test_worker2_sends_to_worker0") - .set_body_typed([](NDArray buffer) { - CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get(); - CHECK_EQ(ctx->worker->num_workers, 4) << "The test requires the world size to be 4."; - CHECK_EQ(ctx->worker->num_groups, 2) << "The test requires the group size to be 2."; - if (ctx->worker->worker_id == 2) { - tvm::runtime::nccl::SendToWorker(buffer, 0); - } else if (ctx->worker->worker_id == 0) { - tvm::runtime::nccl::RecvFromWorker(buffer, 2); - } - }); } // namespace nccl } // namespace runtime diff --git a/src/script/ir_builder/tir/ir.cc b/src/script/ir_builder/tir/ir.cc index 998b468f4bc0..e8c8d62c9b23 100644 --- a/src/script/ir_builder/tir/ir.cc +++ b/src/script/ir_builder/tir/ir.cc @@ -760,107 +760,117 @@ TVM_FFI_STATIC_INIT_BLOCK({ #define TVM_TMP_STR(x) #x -#define TVM_FFI_REGISTER_GLOBAL_SIZE(Prefix, DType) \ - TVM_FFI_REGISTER_GLOBAL(Prefix TVM_TMP_STR(8)).set_body_typed(DType##8); \ - TVM_FFI_REGISTER_GLOBAL(Prefix 
TVM_TMP_STR(16)).set_body_typed(DType##16); \ - TVM_FFI_REGISTER_GLOBAL(Prefix TVM_TMP_STR(32)).set_body_typed(DType##32); \ - TVM_FFI_REGISTER_GLOBAL(Prefix TVM_TMP_STR(64)).set_body_typed(DType##64); - -TVM_FFI_REGISTER_GLOBAL_SIZE("script.ir_builder.tir.Float", Float); -TVM_FFI_REGISTER_GLOBAL_SIZE("script.ir_builder.tir.UInt", UInt); -TVM_FFI_REGISTER_GLOBAL_SIZE("script.ir_builder.tir.Int", Int); - -#define TVM_FFI_REGISTER_GLOBAL_LANES(Prefix, Func) \ - TVM_FFI_REGISTER_GLOBAL(Prefix TVM_TMP_STR(x4)).set_body_typed(Func##x4); \ - TVM_FFI_REGISTER_GLOBAL(Prefix TVM_TMP_STR(x8)).set_body_typed(Func##x8); \ - TVM_FFI_REGISTER_GLOBAL(Prefix TVM_TMP_STR(x16)).set_body_typed(Func##x16); \ - TVM_FFI_REGISTER_GLOBAL(Prefix TVM_TMP_STR(x32)).set_body_typed(Func##x32); \ - TVM_FFI_REGISTER_GLOBAL(Prefix TVM_TMP_STR(x64)).set_body_typed(Func##x64); - -#define TVM_FFI_REGISTER_GLOBAL_SIZES_LANES(Prefix, DType) \ - TVM_FFI_REGISTER_GLOBAL_LANES(Prefix TVM_TMP_STR(8), DType##8); \ - TVM_FFI_REGISTER_GLOBAL_LANES(Prefix TVM_TMP_STR(16), DType##16); \ - TVM_FFI_REGISTER_GLOBAL_LANES(Prefix TVM_TMP_STR(32), DType##32); \ - TVM_FFI_REGISTER_GLOBAL_LANES(Prefix TVM_TMP_STR(64), DType##64); - -TVM_FFI_REGISTER_GLOBAL_SIZES_LANES("script.ir_builder.tir.Float", Float); -TVM_FFI_REGISTER_GLOBAL_SIZES_LANES("script.ir_builder.tir.UInt", UInt); -TVM_FFI_REGISTER_GLOBAL_SIZES_LANES("script.ir_builder.tir.Int", Int); +#define TVM_FFI_REFL_DEF_GLOBAL_SIZE(Prefix, DType) \ + def(Prefix TVM_TMP_STR(8), DType##8) \ + .def(Prefix TVM_TMP_STR(16), DType##16) \ + .def(Prefix TVM_TMP_STR(32), DType##32) \ + .def(Prefix TVM_TMP_STR(64), DType##64) + +#define TVM_FFI_REFL_DEF_GLOBAL_LANES(Prefix, Func) \ + def(Prefix TVM_TMP_STR(x4), Func##x4) \ + .def(Prefix TVM_TMP_STR(x8), Func##x8) \ + .def(Prefix TVM_TMP_STR(x16), Func##x16) \ + .def(Prefix TVM_TMP_STR(x32), Func##x32) \ + .def(Prefix TVM_TMP_STR(x64), Func##x64) + +#define TVM_FFI_REFL_DEF_GLOBAL_SIZES_LANES(Prefix, DType) \ + 
TVM_FFI_REFL_DEF_GLOBAL_LANES(Prefix TVM_TMP_STR(8), DType##8) \ + .TVM_FFI_REFL_DEF_GLOBAL_LANES(Prefix TVM_TMP_STR(16), DType##16) \ + .TVM_FFI_REFL_DEF_GLOBAL_LANES(Prefix TVM_TMP_STR(32), DType##32) \ + .TVM_FFI_REFL_DEF_GLOBAL_LANES(Prefix TVM_TMP_STR(64), DType##64) TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("script.ir_builder.tir.BFloat16", BFloat16); + refl::GlobalDef() + .def("script.ir_builder.tir.BFloat16", BFloat16) + .TVM_FFI_REFL_DEF_GLOBAL_SIZE("script.ir_builder.tir.Float", Float) + .TVM_FFI_REFL_DEF_GLOBAL_SIZE("script.ir_builder.tir.UInt", UInt) + .TVM_FFI_REFL_DEF_GLOBAL_SIZE("script.ir_builder.tir.Int", Int) + .TVM_FFI_REFL_DEF_GLOBAL_SIZES_LANES("script.ir_builder.tir.Float", Float) + .TVM_FFI_REFL_DEF_GLOBAL_SIZES_LANES("script.ir_builder.tir.UInt", UInt) + .TVM_FFI_REFL_DEF_GLOBAL_SIZES_LANES("script.ir_builder.tir.Int", Int) + .TVM_FFI_REFL_DEF_GLOBAL_LANES("script.ir_builder.tir.BFloat16", BFloat16); }); -TVM_FFI_REGISTER_GLOBAL_LANES("script.ir_builder.tir.BFloat16", BFloat16); // Float8 variants TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("script.ir_builder.tir.Float8E3M4", Float8E3M4); + refl::GlobalDef() + .def("script.ir_builder.tir.Float8E3M4", Float8E3M4) + .TVM_FFI_REFL_DEF_GLOBAL_LANES("script.ir_builder.tir.Float8E3M4", Float8E3M4); }); -TVM_FFI_REGISTER_GLOBAL_LANES("script.ir_builder.tir.Float8E3M4", Float8E3M4); TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("script.ir_builder.tir.Float8E4M3", Float8E4M3); + refl::GlobalDef() + .def("script.ir_builder.tir.Float8E4M3", Float8E4M3) + .TVM_FFI_REFL_DEF_GLOBAL_LANES("script.ir_builder.tir.Float8E4M3", Float8E4M3); }); -TVM_FFI_REGISTER_GLOBAL_LANES("script.ir_builder.tir.Float8E4M3", Float8E4M3); TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("script.ir_builder.tir.Float8E4M3B11FNUZ", 
Float8E4M3B11FNUZ); + refl::GlobalDef() + .def("script.ir_builder.tir.Float8E4M3B11FNUZ", Float8E4M3B11FNUZ) + .TVM_FFI_REFL_DEF_GLOBAL_LANES("script.ir_builder.tir.Float8E4M3B11FNUZ", Float8E4M3B11FNUZ); }); -TVM_FFI_REGISTER_GLOBAL_LANES("script.ir_builder.tir.Float8E4M3B11FNUZ", Float8E4M3B11FNUZ); TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("script.ir_builder.tir.Float8E4M3FN", Float8E4M3FN); + refl::GlobalDef() + .def("script.ir_builder.tir.Float8E4M3FN", Float8E4M3FN) + .TVM_FFI_REFL_DEF_GLOBAL_LANES("script.ir_builder.tir.Float8E4M3FN", Float8E4M3FN); }); -TVM_FFI_REGISTER_GLOBAL_LANES("script.ir_builder.tir.Float8E4M3FN", Float8E4M3FN); TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("script.ir_builder.tir.Float8E4M3FNUZ", Float8E4M3FNUZ); + refl::GlobalDef() + .def("script.ir_builder.tir.Float8E4M3FNUZ", Float8E4M3FNUZ) + .TVM_FFI_REFL_DEF_GLOBAL_LANES("script.ir_builder.tir.Float8E4M3FNUZ", Float8E4M3FNUZ); }); -TVM_FFI_REGISTER_GLOBAL_LANES("script.ir_builder.tir.Float8E4M3FNUZ", Float8E4M3FNUZ); TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("script.ir_builder.tir.Float8E5M2", Float8E5M2); + refl::GlobalDef() + .def("script.ir_builder.tir.Float8E5M2", Float8E5M2) + .TVM_FFI_REFL_DEF_GLOBAL_LANES("script.ir_builder.tir.Float8E5M2", Float8E5M2); }); -TVM_FFI_REGISTER_GLOBAL_LANES("script.ir_builder.tir.Float8E5M2", Float8E5M2); TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("script.ir_builder.tir.Float8E5M2FNUZ", Float8E5M2FNUZ); + refl::GlobalDef() + .def("script.ir_builder.tir.Float8E5M2FNUZ", Float8E5M2FNUZ) + .TVM_FFI_REFL_DEF_GLOBAL_LANES("script.ir_builder.tir.Float8E5M2FNUZ", Float8E5M2FNUZ); }); -TVM_FFI_REGISTER_GLOBAL_LANES("script.ir_builder.tir.Float8E5M2FNUZ", Float8E5M2FNUZ); TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; - 
refl::GlobalDef().def("script.ir_builder.tir.Float8E8M0FNU", Float8E8M0FNU); + refl::GlobalDef() + .def("script.ir_builder.tir.Float8E8M0FNU", Float8E8M0FNU) + .TVM_FFI_REFL_DEF_GLOBAL_LANES("script.ir_builder.tir.Float8E8M0FNU", Float8E8M0FNU); }); -TVM_FFI_REGISTER_GLOBAL_LANES("script.ir_builder.tir.Float8E8M0FNU", Float8E8M0FNU); // Float6 variants TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("script.ir_builder.tir.Float6E2M3FN", Float6E2M3FN); + refl::GlobalDef() + .def("script.ir_builder.tir.Float6E2M3FN", Float6E2M3FN) + .TVM_FFI_REFL_DEF_GLOBAL_LANES("script.ir_builder.tir.Float6E2M3FN", Float6E2M3FN); }); -TVM_FFI_REGISTER_GLOBAL_LANES("script.ir_builder.tir.Float6E2M3FN", Float6E2M3FN); TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("script.ir_builder.tir.Float6E3M2FN", Float6E3M2FN); + refl::GlobalDef() + .def("script.ir_builder.tir.Float6E3M2FN", Float6E3M2FN) + .TVM_FFI_REFL_DEF_GLOBAL_LANES("script.ir_builder.tir.Float6E3M2FN", Float6E3M2FN); }); -TVM_FFI_REGISTER_GLOBAL_LANES("script.ir_builder.tir.Float6E3M2FN", Float6E3M2FN); // Float4 variant TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("script.ir_builder.tir.Float4E2M1FN", Float4E2M1FN); + refl::GlobalDef() + .def("script.ir_builder.tir.Float4E2M1FN", Float4E2M1FN) + .TVM_FFI_REFL_DEF_GLOBAL_LANES("script.ir_builder.tir.Float4E2M1FN", Float4E2M1FN); }); -TVM_FFI_REGISTER_GLOBAL_LANES("script.ir_builder.tir.Float4E2M1FN", Float4E2M1FN); TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; diff --git a/src/target/datatype/registry.h b/src/target/datatype/registry.h index b1a1a4a7f50b..363494e0fda4 100644 --- a/src/target/datatype/registry.h +++ b/src/target/datatype/registry.h @@ -37,7 +37,7 @@ namespace datatype { * directly---see the TVM globals registered in the corresponding .cc file. 
* Currently, user should manually choose a type name and a type code, * ensuring that neither conflict with existing types. - * 2. Use TVM_FFI_REGISTER_GLOBAL to register the lowering functions needed to + * 2. Register the lowering functions needed to * lower the custom datatype. In general, these will look like: * For Casts: tvm.datatype.lower..Cast.. * Example: tvm.datatype.lower.llvm.Cast.myfloat.float for a Cast from diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc index 03b22de8382e..9ced6f556cb0 100644 --- a/src/tir/op/op.cc +++ b/src/tir/op/op.cc @@ -255,7 +255,10 @@ PrimExpr thread_return(Span span) { return tir::Call(DataType::Void(), tir::builtin::thread_return(), {}, span); } -TVM_FFI_REGISTER_GLOBAL("tir.thread_return").set_body_typed(thread_return); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.thread_return", thread_return); +}); // maximum and min limits PrimExpr max_value(const DataType& dtype, Span span) { @@ -1158,54 +1161,22 @@ TVM_FFI_STATIC_INIT_BLOCK({ }); // operator overloading, smarter than make -#define REGISTER_MAKE_BINARY_OP(Node, Func) \ - TVM_FFI_REGISTER_GLOBAL("tir." #Node).set_body_typed([](PrimExpr a, PrimExpr b, Span span) { \ - return (Func(a, b, span)); \ +#define DEF_MAKE_BINARY_OP(Node, Func) \ + def("tir." #Node, [](PrimExpr a, PrimExpr b, Span span) { return (Func(a, b, span)); }) + +#define DEF_MAKE_BIT_OP(Node, Func) \ + def_packed("tir." 
#Node, [](ffi::PackedArgs args, ffi::Any* ret) { \ + bool lhs_is_int = args[0].type_index() == ffi::TypeIndex::kTVMFFIInt; \ + bool rhs_is_int = args[1].type_index() == ffi::TypeIndex::kTVMFFIInt; \ + if (lhs_is_int) { \ + *ret = (Func(args[0].cast(), args[1].cast(), args[2].cast())); \ + } else if (rhs_is_int) { \ + *ret = (Func(args[0].cast(), args[1].cast(), args[2].cast())); \ + } else { \ + *ret = (Func(args[0].cast(), args[1].cast(), args[2].cast())); \ + } \ }) -#define REGISTER_MAKE_BIT_OP(Node, Func) \ - TVM_FFI_REGISTER_GLOBAL("tir." #Node).set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { \ - bool lhs_is_int = args[0].type_index() == ffi::TypeIndex::kTVMFFIInt; \ - bool rhs_is_int = args[1].type_index() == ffi::TypeIndex::kTVMFFIInt; \ - if (lhs_is_int) { \ - *ret = (Func(args[0].cast(), args[1].cast(), args[2].cast())); \ - } else if (rhs_is_int) { \ - *ret = (Func(args[0].cast(), args[1].cast(), args[2].cast())); \ - } else { \ - *ret = (Func(args[0].cast(), args[1].cast(), args[2].cast())); \ - } \ - }) - -REGISTER_MAKE_BINARY_OP(_OpAdd, add); -REGISTER_MAKE_BINARY_OP(_OpSub, sub); -REGISTER_MAKE_BINARY_OP(_OpMul, mul); -REGISTER_MAKE_BINARY_OP(_OpDiv, div); -REGISTER_MAKE_BINARY_OP(_OpMod, truncmod); -REGISTER_MAKE_BINARY_OP(_OpIndexDiv, indexdiv); -REGISTER_MAKE_BINARY_OP(_OpIndexMod, indexmod); -REGISTER_MAKE_BINARY_OP(_OpFloorDiv, floordiv); -REGISTER_MAKE_BINARY_OP(_OpLogAddExp, logaddexp); -REGISTER_MAKE_BINARY_OP(_OpFloorMod, floormod); -REGISTER_MAKE_BINARY_OP(_OpTruncDiv, truncdiv); -REGISTER_MAKE_BINARY_OP(_OpTruncMod, truncmod); -REGISTER_MAKE_BINARY_OP(_OpCeilDiv, ceildiv); -REGISTER_MAKE_BINARY_OP(_OpPow, pow); -REGISTER_MAKE_BINARY_OP(_OpMin, min); -REGISTER_MAKE_BINARY_OP(_OpMax, max); -REGISTER_MAKE_BINARY_OP(_OpEQ, equal); -REGISTER_MAKE_BINARY_OP(_OpNE, not_equal); -REGISTER_MAKE_BINARY_OP(_OpLT, less); // NOLINT(*) -REGISTER_MAKE_BINARY_OP(_OpLE, less_equal); // NOLINT(*) -REGISTER_MAKE_BINARY_OP(_OpGT, greater); // 
NOLINT(*) -REGISTER_MAKE_BINARY_OP(_OpGE, greater_equal); -REGISTER_MAKE_BINARY_OP(_OpAnd, logical_and); -REGISTER_MAKE_BINARY_OP(_OpOr, logical_or); -REGISTER_MAKE_BIT_OP(bitwise_and, bitwise_and); -REGISTER_MAKE_BIT_OP(bitwise_or, bitwise_or); -REGISTER_MAKE_BIT_OP(bitwise_xor, bitwise_xor); -REGISTER_MAKE_BIT_OP(left_shift, left_shift); // NOLINT(*) -REGISTER_MAKE_BIT_OP(right_shift, right_shift); - TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef() @@ -1213,7 +1184,36 @@ TVM_FFI_STATIC_INIT_BLOCK({ [](PrimExpr cond, PrimExpr true_value, PrimExpr false_value, Span span) { return if_then_else(cond, true_value, false_value, span); }) - .def("tir.const_true", [](DataType t, Span span) { return const_true(t.lanes(), span); }); + .def("tir.const_true", [](DataType t, Span span) { return const_true(t.lanes(), span); }) + .DEF_MAKE_BINARY_OP(_OpAdd, add) + .DEF_MAKE_BINARY_OP(_OpSub, sub) + .DEF_MAKE_BINARY_OP(_OpMul, mul) + .DEF_MAKE_BINARY_OP(_OpDiv, div) + .DEF_MAKE_BINARY_OP(_OpMod, truncmod) + .DEF_MAKE_BINARY_OP(_OpIndexDiv, indexdiv) + .DEF_MAKE_BINARY_OP(_OpIndexMod, indexmod) + .DEF_MAKE_BINARY_OP(_OpFloorDiv, floordiv) + .DEF_MAKE_BINARY_OP(_OpLogAddExp, logaddexp) + .DEF_MAKE_BINARY_OP(_OpFloorMod, floormod) + .DEF_MAKE_BINARY_OP(_OpTruncDiv, truncdiv) + .DEF_MAKE_BINARY_OP(_OpTruncMod, truncmod) + .DEF_MAKE_BINARY_OP(_OpCeilDiv, ceildiv) + .DEF_MAKE_BINARY_OP(_OpPow, pow) + .DEF_MAKE_BINARY_OP(_OpMin, min) + .DEF_MAKE_BINARY_OP(_OpMax, max) + .DEF_MAKE_BINARY_OP(_OpEQ, equal) + .DEF_MAKE_BINARY_OP(_OpNE, not_equal) + .DEF_MAKE_BINARY_OP(_OpLT, less) // NOLINT(*) + .DEF_MAKE_BINARY_OP(_OpLE, less_equal) // NOLINT(*) + .DEF_MAKE_BINARY_OP(_OpGT, greater) // NOLINT(*) + .DEF_MAKE_BINARY_OP(_OpGE, greater_equal) + .DEF_MAKE_BINARY_OP(_OpAnd, logical_and) + .DEF_MAKE_BINARY_OP(_OpOr, logical_or) + .DEF_MAKE_BIT_OP(bitwise_and, bitwise_and) + .DEF_MAKE_BIT_OP(bitwise_or, bitwise_or) + .DEF_MAKE_BIT_OP(bitwise_xor, bitwise_xor) + 
.DEF_MAKE_BIT_OP(left_shift, left_shift) // NOLINT(*) + .DEF_MAKE_BIT_OP(right_shift, right_shift); }); PrimExpr fast_erf_float_expr(PrimExpr arg, int bits) { diff --git a/src/topi/broadcast.cc b/src/topi/broadcast.cc index 99b43b82e9fc..1ca901c6fbf5 100644 --- a/src/topi/broadcast.cc +++ b/src/topi/broadcast.cc @@ -32,52 +32,53 @@ namespace topi { using namespace tvm; using namespace tvm::runtime; -#define TOPI_REGISTER_BCAST_OP(OpName, Op) \ - TVM_FFI_REGISTER_GLOBAL(OpName).set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { \ - bool lhs_is_tensor = args[0].as().has_value(); \ - bool rhs_is_tensor = args[1].as().has_value(); \ - if (lhs_is_tensor && rhs_is_tensor) { \ - *rv = Op(args[0].cast(), args[1].cast()); \ - } else if (!lhs_is_tensor && rhs_is_tensor) { \ - *rv = Op(args[0].cast(), args[1].cast()); \ - } else if (lhs_is_tensor && !rhs_is_tensor) { \ - *rv = Op(args[0].cast(), args[1].cast()); \ - } else if (!lhs_is_tensor && !rhs_is_tensor) { \ - *rv = Op(args[0].cast(), args[1].cast()); \ - } \ - }); - -TOPI_REGISTER_BCAST_OP("topi.add", topi::add); -TOPI_REGISTER_BCAST_OP("topi.subtract", topi::subtract); -TOPI_REGISTER_BCAST_OP("topi.multiply", topi::multiply); -TOPI_REGISTER_BCAST_OP("topi.divide", topi::divide); -TOPI_REGISTER_BCAST_OP("topi.floor_divide", topi::floor_divide); -TOPI_REGISTER_BCAST_OP("topi.log_add_exp", topi::log_add_exp); -TOPI_REGISTER_BCAST_OP("topi.mod", topi::mod); -TOPI_REGISTER_BCAST_OP("topi.floor_mod", topi::floor_mod); -TOPI_REGISTER_BCAST_OP("topi.maximum", topi::maximum); -TOPI_REGISTER_BCAST_OP("topi.minimum", topi::minimum); -TOPI_REGISTER_BCAST_OP("topi.power", topi::power); -TOPI_REGISTER_BCAST_OP("topi.left_shift", topi::left_shift); -TOPI_REGISTER_BCAST_OP("topi.logical_and", topi::logical_and); -TOPI_REGISTER_BCAST_OP("topi.logical_or", topi::logical_or); -TOPI_REGISTER_BCAST_OP("topi.logical_xor", topi::logical_xor); -TOPI_REGISTER_BCAST_OP("topi.bitwise_and", topi::bitwise_and); 
-TOPI_REGISTER_BCAST_OP("topi.bitwise_or", topi::bitwise_or); -TOPI_REGISTER_BCAST_OP("topi.bitwise_xor", topi::bitwise_xor); -TOPI_REGISTER_BCAST_OP("topi.right_shift", topi::right_shift); -TOPI_REGISTER_BCAST_OP("topi.greater", topi::greater); -TOPI_REGISTER_BCAST_OP("topi.less", topi::less); -TOPI_REGISTER_BCAST_OP("topi.equal", topi::equal); -TOPI_REGISTER_BCAST_OP("topi.not_equal", topi::not_equal); -TOPI_REGISTER_BCAST_OP("topi.greater_equal", topi::greater_equal); -TOPI_REGISTER_BCAST_OP("topi.less_equal", topi::less_equal); +#define TOPI_DEF_BCAST_OP(OpName, Op) \ + def_packed(OpName, [](ffi::PackedArgs args, ffi::Any* rv) { \ + bool lhs_is_tensor = args[0].as().has_value(); \ + bool rhs_is_tensor = args[1].as().has_value(); \ + if (lhs_is_tensor && rhs_is_tensor) { \ + *rv = Op(args[0].cast(), args[1].cast()); \ + } else if (!lhs_is_tensor && rhs_is_tensor) { \ + *rv = Op(args[0].cast(), args[1].cast()); \ + } else if (lhs_is_tensor && !rhs_is_tensor) { \ + *rv = Op(args[0].cast(), args[1].cast()); \ + } else if (!lhs_is_tensor && !rhs_is_tensor) { \ + *rv = Op(args[0].cast(), args[1].cast()); \ + } \ + }) TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def_packed("topi.broadcast_to", [](ffi::PackedArgs args, ffi::Any* rv) { - *rv = broadcast_to(args[0].cast(), args[1].cast>()); - }); + refl::GlobalDef() + .def_packed("topi.broadcast_to", + [](ffi::PackedArgs args, ffi::Any* rv) { + *rv = broadcast_to(args[0].cast(), args[1].cast>()); + }) + .TOPI_DEF_BCAST_OP("topi.add", topi::add) + .TOPI_DEF_BCAST_OP("topi.subtract", topi::subtract) + .TOPI_DEF_BCAST_OP("topi.multiply", topi::multiply) + .TOPI_DEF_BCAST_OP("topi.divide", topi::divide) + .TOPI_DEF_BCAST_OP("topi.floor_divide", topi::floor_divide) + .TOPI_DEF_BCAST_OP("topi.log_add_exp", topi::log_add_exp) + .TOPI_DEF_BCAST_OP("topi.mod", topi::mod) + .TOPI_DEF_BCAST_OP("topi.floor_mod", topi::floor_mod) + .TOPI_DEF_BCAST_OP("topi.maximum", topi::maximum) + 
.TOPI_DEF_BCAST_OP("topi.minimum", topi::minimum) + .TOPI_DEF_BCAST_OP("topi.power", topi::power) + .TOPI_DEF_BCAST_OP("topi.left_shift", topi::left_shift) + .TOPI_DEF_BCAST_OP("topi.logical_and", topi::logical_and) + .TOPI_DEF_BCAST_OP("topi.logical_or", topi::logical_or) + .TOPI_DEF_BCAST_OP("topi.logical_xor", topi::logical_xor) + .TOPI_DEF_BCAST_OP("topi.bitwise_and", topi::bitwise_and) + .TOPI_DEF_BCAST_OP("topi.bitwise_or", topi::bitwise_or) + .TOPI_DEF_BCAST_OP("topi.bitwise_xor", topi::bitwise_xor) + .TOPI_DEF_BCAST_OP("topi.right_shift", topi::right_shift) + .TOPI_DEF_BCAST_OP("topi.greater", topi::greater) + .TOPI_DEF_BCAST_OP("topi.less", topi::less) + .TOPI_DEF_BCAST_OP("topi.equal", topi::equal) + .TOPI_DEF_BCAST_OP("topi.not_equal", topi::not_equal) + .TOPI_DEF_BCAST_OP("topi.greater_equal", topi::greater_equal) + .TOPI_DEF_BCAST_OP("topi.less_equal", topi::less_equal); }); } // namespace topi diff --git a/tests/python/contrib/test_hexagon/README_RPC.md b/tests/python/contrib/test_hexagon/README_RPC.md index 955cd58dc2ae..8d185fcbebeb 100644 --- a/tests/python/contrib/test_hexagon/README_RPC.md +++ b/tests/python/contrib/test_hexagon/README_RPC.md @@ -80,12 +80,15 @@ Which eventually jumps to the following line in C++, which creates a RPC client [https://github.com/apache/tvm/blob/2cca934aad1635e3a83b712958ea83ff65704316/src/runtime/rpc/rpc_socket_impl.cc#L123-L129](https://github.com/apache/tvm/blob/2cca934aad1635e3a83b712958ea83ff65704316/src/runtime/rpc/rpc_socket_impl.cc#L123-L129) ```cpp -TVM_FFI_REGISTER_GLOBAL("rpc.Connect").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - auto url = args[0].cast(); - int port = args[1].cast(); - auto key = args[2].cast(); - *rv = RPCClientConnect(url, port, key, - ffi::PackedArgs(args.values + 3, args.type_codes + 3, args.size() - 3)); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def_packed("rpc.Connect", [](ffi::PackedArgs args, ffi::Any* rv) { + auto url 
= args[0].cast(); + int port = args[1].cast(); + auto key = args[2].cast(); + *rv = RPCClientConnect(url, port, key, + ffi::PackedArgs(args.values + 3, args.type_codes + 3, args.size() - 3)); + }); }); ``` @@ -94,8 +97,11 @@ TVM_FFI_REGISTER_GLOBAL("rpc.Connect").set_body_packed([](ffi::PackedArgs args, [https://github.com/apache/tvm/blob/cd2fa69677516048e165e84a88c774dfb0ee65d1/src/runtime/hexagon/rpc/android/session.cc#L106](https://github.com/apache/tvm/blob/cd2fa69677516048e165e84a88c774dfb0ee65d1/src/runtime/hexagon/rpc/android/session.cc#L106) ```cpp -TVM_FFI_REGISTER_GLOBAL("tvm.contrib.hexagon.create_hexagon_session") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { + +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def_packed( + "tvm.contrib.hexagon.create_hexagon_session", [](ffi::PackedArgs args, ffi::Any* rv) { auto session_name = args[0].cast(); int remote_stack_size_bytes = args[1].cast(); HexagonTransportChannel* hexagon_channel = @@ -105,6 +111,7 @@ TVM_FFI_REGISTER_GLOBAL("tvm.contrib.hexagon.create_hexagon_session") auto sess = CreateClientSession(ep); *rv = CreateRPCSessionModule(sess); }); +}); ``` `HexagonTransportChannel` is the one that actually knows how to talk to Hexagon. It uses functions such as `hexagon_rpc_send`, `hexagon_rpc_receive` defined in