Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions docs/arch/device_target_interactions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,10 @@ then be registered with the following steps.

#. Register the function to the tvm registry::

TVM_FFI_REGISTER_GLOBAL("device_api.foo").set_body_typed(FooDeviceAPI::Global);
TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("device_api.foo", FooDeviceAPI::Global);
});

.. _base.h: https://github.com/apache/tvm/blob/main/include/tvm/runtime/base.h

Expand All @@ -164,7 +167,7 @@ then be registered with the following steps.

#. Add a case in ``DeviceName`` in `device_api.h`_ to convert from the
enum value to a string representation. This string representation
should match the name given to ``TVM_FFI_REGISTER_GLOBAL``.
should match the name given to ``GlobalDef().def``.

#. Add entries to the ``DEVICE_TYPE_TO_NAME`` and ``DEVICE_NAME_TO_TYPE`` dictionaries of
:py:class:`tvm.runtime.Device` for the new enum value.
Expand Down Expand Up @@ -225,7 +228,10 @@ the same name as was used in the ``TVM_REGISTER_TARGET_KIND``
definition above. ::

tvm::runtime::Module GeneratorFooCode(IRModule mod, Target target);
TVM_FFI_REGISTER_GLOBAL("target.build.foo").set_body_typed(GeneratorFooCode);
TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("target.build.foo", GeneratorFooCode);
});

The code generator takes two arguments. The first is the ``IRModule``
to compile, and the second is the ``Target`` that describes the device
Expand Down
6 changes: 4 additions & 2 deletions docs/arch/pass_infra.rst
Original file line number Diff line number Diff line change
Expand Up @@ -376,8 +376,10 @@ Python when needed.
return CreateFunctionPass(pass_func, 0, "FoldConstant", {});
}

TVM_FFI_REGISTER_GLOBAL("relax.transform.FoldConstant")
.set_body_typed(FoldConstant);
TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("relax.transform.FoldConstant", FoldConstant);
});

} // namespace transform

Expand Down
16 changes: 10 additions & 6 deletions docs/arch/runtime.rst
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,10 @@ The following example registers PackedFunc in C++ and calls from python.
.. code:: c

// register a global packed function in c++
TVM_FFI_REGISTER_GLOBAL("myadd")
.set_body_packed(MyAdd);
TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def_packed("myadd", MyAdd);
});

.. code:: python

Expand Down Expand Up @@ -110,10 +112,12 @@ we can pass functions from python (as PackedFunc) to C++.

.. code:: c

TVM_FFI_REGISTER_GLOBAL("callhello")
.set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) {
PackedFunc f = args[0];
f("hello world");
TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def_packed("callhello", [](ffi::PackedArgs args, ffi::Any* rv) {
ffi::Function f = args[0].cast<ffi::Function>();
f("hello world");
});
});

.. code:: python
Expand Down
144 changes: 0 additions & 144 deletions ffi/include/tvm/ffi/function.h
Original file line number Diff line number Diff line change
Expand Up @@ -756,135 +756,6 @@ struct TypeTraits<TypedFunction<FType>> : public TypeTraitsBase {
TVM_FFI_INLINE static std::string TypeStr() { return details::FunctionInfo<FType>::Sig(); }
};

/*! \brief Registry for global function */
class Function::Registry {
public:
/*! \brief constructor */
explicit Registry(const char* name) : name_(name) {}

/*!
* \brief Set body to be to use the packed convention.
*
* \tparam FLambda The signature of the function.
* \param f The body of the function.
*/
template <typename FLambda>
Registry& set_body_packed(FLambda f) {
return Register(ffi::Function::FromPacked(f));
}
/*!
* \brief set the body of the function to the given function.
* Note that this will ignore default arg values and always require all arguments to be
* provided.
*
* \code
*
* int multiply(int x, int y) {
* return x * y;
* }
*
* TVM_FFI_REGISTER_GLOBAL("multiply")
* .set_body_typed(multiply); // will have type int(int, int)
*
* // will have type int(int, int)
* TVM_FFI_REGISTER_GLOBAL("sub")
* .set_body_typed([](int a, int b) -> int { return a - b; });
*
* \endcode
*
* \param f The function to forward to.
* \tparam FLambda The signature of the function.
*/
template <typename FLambda>
Registry& set_body_typed(FLambda f) {
return Register(Function::FromTyped(f, name_));
}

/*!
* \brief set the body of the function to be the passed method pointer.
* Note that this will ignore default arg values and always require all arguments to be
* provided.
*
* \code
*
* // objectRef subclass:
* struct Example : ObjectRef {
* int DoThing(int x);
* }
* TVM_FFI_REGISTER_GLOBAL("Example_DoThing")
* .set_body_method(&Example::DoThing); // will have type int(self, int)
*
* // Object subclass:
* struct Example : Object {
* int DoThing(int x);
* }
*
* TVM_FFI_REGISTER_GLOBAL("Example_DoThing")
* .set_body_method(&Example::DoThing); // will have type int(self, int)
*
* \endcode
*
* \param f the method pointer to forward to.
* \tparam T the type containing the method (inferred).
* \tparam R the return type of the function (inferred).
* \tparam Args the argument types of the function (inferred).
*/
template <typename T, typename R, typename... Args>
Registry& set_body_method(R (T::*f)(Args...)) {
static_assert(std::is_base_of_v<ObjectRef, T> || std::is_base_of_v<Object, T>,
"T must be derived from ObjectRef or Object");
if constexpr (std::is_base_of_v<ObjectRef, T>) {
auto fwrap = [f](T target, Args... params) -> R {
// call method pointer
return (target.*f)(std::forward<Args>(params)...);
};
return Register(ffi::Function::FromTyped(fwrap, name_));
}
if constexpr (std::is_base_of_v<Object, T>) {
auto fwrap = [f](const T* target, Args... params) -> R {
// call method pointer
return (const_cast<T*>(target)->*f)(std::forward<Args>(params)...);
};
return Register(ffi::Function::FromTyped(fwrap, name_));
}
return *this;
}

template <typename T, typename R, typename... Args>
Registry& set_body_method(R (T::*f)(Args...) const) {
static_assert(std::is_base_of_v<ObjectRef, T> || std::is_base_of_v<Object, T>,
"T must be derived from ObjectRef or Object");
if constexpr (std::is_base_of_v<ObjectRef, T>) {
auto fwrap = [f](const T target, Args... params) -> R {
// call method pointer
return (target.*f)(std::forward<Args>(params)...);
};
return Register(ffi::Function::FromTyped(fwrap, name_));
}
if constexpr (std::is_base_of_v<Object, T>) {
auto fwrap = [f](const T* target, Args... params) -> R {
// call method pointer
return (target->*f)(std::forward<Args>(params)...);
};
return Register(ffi::Function::FromTyped(fwrap, name_));
}
return *this;
}

protected:
/*!
* \brief set the body of the function to be f
* \param f The body of the function.
*/
Registry& Register(Function f) {
Function::SetGlobal(name_, f);
return *this;
}

/*! \brief name of the function */
const char* name_;
};

/*!
* \brief helper function to get type index from key
*/
Expand All @@ -895,21 +766,6 @@ inline int32_t TypeKeyToIndex(std::string_view type_key) {
return type_index;
}

#define TVM_FFI_FUNC_REG_VAR_DEF \
TVM_FFI_ATTRIBUTE_UNUSED static inline ::tvm::ffi::Function::Registry& __##TVMFFIFuncReg

/*!
* \brief Register a function globally.
* \code
* TVM_FFI_REGISTER_GLOBAL("MyAdd")
* .set_body_typed([](int a, int b) {
* return a + b;
* });
* \endcode
*/
#define TVM_FFI_REGISTER_GLOBAL(OpName) \
TVM_FFI_STR_CONCAT(TVM_FFI_FUNC_REG_VAR_DEF, __COUNTER__) = ::tvm::ffi::Function::Registry(OpName)

/*!
* \brief Export typed function as a SafeCallType symbol.
*
Expand Down
7 changes: 5 additions & 2 deletions include/tvm/runtime/profiling.h
Original file line number Diff line number Diff line change
Expand Up @@ -134,8 +134,11 @@ class Timer : public ObjectRef {
* };
* TVM_REGISTER_OBJECT_TYPE(CPUTimerNode);
*
* TVM_FFI_REGISTER_GLOBAL("profiling.timer.cpu").set_body_typed([](Device dev) {
* return Timer(make_object<CPUTimerNode>());
* TVM_FFI_STATIC_INIT_BLOCK({
* namespace refl = tvm::ffi::reflection;
* refl::GlobalDef().def("profiling.timer.cpu", [](Device dev) {
* return Timer(make_object<CPUTimerNode>());
* });
* });
* \endcode
*/
Expand Down
12 changes: 12 additions & 0 deletions python/tvm/contrib/msc/plugin/codegen/sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -684,6 +684,17 @@ class TVMUtils {
return cuda_dev;
}
};

#define TVM_MSC_PLUGIN_REGISTER_GLOBAL_DEF(FuncName, Body) \
TVM_FFI_STATIC_INIT_BLOCK({ \
tvm::ffi::reflection::GlobalDef().def(FuncName, Body); \
})

#define TVM_MSC_PLUGIN_REGISTER_GLOBAL_DEF_PACKED(FuncName, Body) \
TVM_FFI_STATIC_INIT_BLOCK({ \
tvm::ffi::reflection::GlobalDef().def_packed(FuncName, Body); \
})

#endif // PLUGIN_SUPPORT_TVM
"""

Expand Down Expand Up @@ -1101,6 +1112,7 @@ def get_plugin_utils_h_code() -> str:

#ifdef PLUGIN_SUPPORT_TVM
#include <tvm/relax/expr.h>
#include <tvm/ffi/reflection/registry.h>

#include "tvm/../../src/contrib/msc/core/transform/layout_utils.h"
#include "tvm/../../src/contrib/msc/core/utils.h"
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relax/frontend/nn/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -2087,7 +2087,7 @@ def extern(
out: OutType,
) -> OutType:
"""Invoke an extern function during runtime. The extern function must be registered with the "
TVM runtime using `TVM_FFI_REGISTER_GLOBAL` (C++), or `tvm.register_func` (Python).
TVM runtime using `reflection::GlobalDef().def` (C++), or `tvm.register_func` (Python).

Parameters
----------
Expand Down
3 changes: 1 addition & 2 deletions python/tvm/runtime/_ffi_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,5 @@
"""FFI APIs for tvm.runtime"""
import tvm.ffi

# Exports functions registered via TVM_FFI_REGISTER_GLOBAL with the "runtime" prefix.
# e.g. TVM_FFI_REGISTER_GLOBAL("runtime.ModuleLoadFromFile")
# Exports functions registered in runtime namespace.
tvm.ffi._init_api("runtime", __name__)
4 changes: 1 addition & 3 deletions python/tvm/runtime/_ffi_node_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
# The implementations below are default ones when the corresponding
# functions are not available in the runtime only mode.
# They will be overriden via _init_api to the ones registered
# via TVM_FFI_REGISTER_GLOBAL in the compiler mode.
def AsRepr(obj):
return type(obj).__name__ + "(" + obj.__ctypes_handle__().value + ")"

Expand All @@ -45,6 +44,5 @@ def LoadJSON(json_str):
raise RuntimeError("Do not support object serialization in runtime only mode")


# Exports functions registered via TVM_FFI_REGISTER_GLOBAL with the "node" prefix.
# e.g. TVM_FFI_REGISTER_GLOBAL("node.AsRepr")
# Exports functions registered in node namespace.
tvm.ffi._init_api("node", __name__)
9 changes: 3 additions & 6 deletions src/contrib/msc/plugin/tvm_codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -215,14 +215,12 @@ void TVMPluginCodeGen::CodeGenOpDefine(const Plugin& plugin) {
stack_.func_end("infer_output");

// register funcs
stack_.func_call("TVM_FFI_REGISTER_GLOBAL")
stack_.func_call("TVM_MSC_PLUGIN_REGISTER_GLOBAL_DEF")
.call_arg(DocUtils::ToStr("msc.plugin.op.InferStructInfo" + plugin->name))
.method_call("set_body_typed")
.call_arg("InferStructInfo" + plugin->name)
.line()
.func_call("TVM_FFI_REGISTER_GLOBAL")
.func_call("TVM_MSC_PLUGIN_REGISTER_GLOBAL_DEF")
.call_arg(DocUtils::ToStr("msc.plugin.op.InferLayout" + plugin->name))
.method_call("set_body_typed")
.call_arg("InferLayout" + plugin->name)
.line();
}
Expand Down Expand Up @@ -262,9 +260,8 @@ void TVMPluginCodeGen::CodeGenOpRuntime(const Plugin& plugin) {
CodeGenCompute(plugin, "cpu");
stack_.cond_end().func_end();
// register the compute
stack_.func_call("TVM_FFI_REGISTER_GLOBAL")
stack_.func_call("TVM_MSC_PLUGIN_REGISTER_GLOBAL_DEF_PACKED")
.call_arg(DocUtils::ToStr(plugin->name))
.method_call("set_body")
.call_arg(func_name)
.line();
}
Expand Down
3 changes: 2 additions & 1 deletion src/relax/op/op_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,8 @@ std::tuple<ArgTypes...> GetArgStructInfo(const Call& call, const BlockBuilder& c
static const Op& op = Op::Get("relax." OpRegName); \
return Call(op, {std::move(x)}, Attrs(), {}); \
} \
TVM_FFI_REGISTER_GLOBAL("relax.op." OpRegName).set_body_typed(OpName)
TVM_FFI_STATIC_INIT_BLOCK( \
{ tvm::ffi::reflection::GlobalDef().def("relax.op." OpRegName, OpName); })

/************ Utilities ************/

Expand Down
3 changes: 2 additions & 1 deletion src/relax/op/tensor/binary.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ namespace relax {
static const Op& op = Op::Get("relax." #OpName); \
return Call(op, {x1, x2}, Attrs(), {}); \
} \
TVM_FFI_REGISTER_GLOBAL("relax.op." #OpName).set_body_typed(OpName); \
TVM_FFI_STATIC_INIT_BLOCK( \
{ tvm::ffi::reflection::GlobalDef().def("relax.op." #OpName, OpName); }); \
TVM_REGISTER_OP("relax." #OpName) \
.set_num_inputs(2) \
.add_argument("x1", "Tensor", "The first input tensor.") \
Expand Down
3 changes: 2 additions & 1 deletion src/relax/op/tensor/search.cc
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,8 @@ StructInfo InferStructInfoArgmaxArgmin(const Call& call, const BlockBuilder& ctx
static const Op& op = Op::Get("relax." #OpName); \
return Call(op, {std::move(x)}, Attrs(attrs)); \
} \
TVM_FFI_REGISTER_GLOBAL("relax.op." #OpName).set_body_typed(OpName); \
TVM_FFI_STATIC_INIT_BLOCK( \
{ tvm::ffi::reflection::GlobalDef().def("relax.op." #OpName, OpName); }); \
TVM_REGISTER_OP("relax." #OpName) \
.set_num_inputs(1) \
.add_argument("x", "Tensor", "The input data tensor") \
Expand Down
3 changes: 2 additions & 1 deletion src/relax/op/tensor/statistical.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ namespace relax {
static const Op& op = Op::Get("relax." #OpName); \
return Call(op, {std::move(x)}, Attrs{attrs}, {}); \
} \
TVM_FFI_REGISTER_GLOBAL("relax.op." #OpName).set_body_typed(OpName); \
TVM_FFI_STATIC_INIT_BLOCK( \
{ tvm::ffi::reflection::GlobalDef().def("relax.op." #OpName, OpName); }); \
TVM_REGISTER_OP("relax." #OpName) \
.set_num_inputs(1) \
.add_argument("x", "Tensor", "The input data tensor") \
Expand Down
Loading
Loading