Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 125 additions & 8 deletions ffi/include/tvm/ffi/reflection/reflection.h
Original file line number Diff line number Diff line change
Expand Up @@ -118,20 +118,47 @@ class ReflectionDefBase {
info->doc = TVMFFIByteArray{value, std::char_traits<char>::length(value)};
}
}

template <typename Class, typename R, typename... Args>
static TVM_FFI_INLINE Function GetMethod(std::string name, R (Class::*func)(Args...)) {
auto fwrap = [func](const Class* target, Args... params) -> R {
return (const_cast<Class*>(target)->*func)(std::forward<Args>(params)...);
};
return ffi::Function::FromTyped(fwrap, name);
static_assert(std::is_base_of_v<ObjectRef, Class> || std::is_base_of_v<Object, Class>,
"Class must be derived from ObjectRef or Object");
if constexpr (std::is_base_of_v<ObjectRef, Class>) {
auto fwrap = [func](Class target, Args... params) -> R {
// call method pointer
return (target.*func)(std::forward<Args>(params)...);
};
return ffi::Function::FromTyped(fwrap, name);
}

if constexpr (std::is_base_of_v<Object, Class>) {
auto fwrap = [func](const Class* target, Args... params) -> R {
// call method pointer
return (const_cast<Class*>(target)->*func)(std::forward<Args>(params)...);
};
return ffi::Function::FromTyped(fwrap, name);
}
}

template <typename Class, typename R, typename... Args>
static TVM_FFI_INLINE Function GetMethod(std::string name, R (Class::*func)(Args...) const) {
auto fwrap = [func](const Class* target, Args... params) -> R {
return (target->*func)(std::forward<Args>(params)...);
};
return ffi::Function::FromTyped(fwrap, name);
static_assert(std::is_base_of_v<ObjectRef, Class> || std::is_base_of_v<Object, Class>,
"Class must be derived from ObjectRef or Object");
if constexpr (std::is_base_of_v<ObjectRef, Class>) {
auto fwrap = [func](const Class target, Args... params) -> R {
// call method pointer
return (target.*func)(std::forward<Args>(params)...);
};
return ffi::Function::FromTyped(fwrap, name);
}

if constexpr (std::is_base_of_v<Object, Class>) {
auto fwrap = [func](const Class* target, Args... params) -> R {
// call method pointer
return (target->*func)(std::forward<Args>(params)...);
};
return ffi::Function::FromTyped(fwrap, name);
}
}

template <typename Class, typename Func>
Expand All @@ -140,6 +167,96 @@ class ReflectionDefBase {
}
};

class GlobalDef : public ReflectionDefBase {
public:
/*
* \brief Define a global function.
*
* \tparam Func The function type.
* \tparam Extra The extra arguments.
*
* \param name The name of the function.
* \param func The function to be registered.
* \param extra The extra arguments that can be docstring.
*
* \return The reflection definition.
*/
template <typename Func, typename... Extra>
GlobalDef& def(const char* name, Func&& func, Extra&&... extra) {
RegisterFunc(name, ffi::Function::FromTyped(std::forward<Func>(func), std::string(name)),
std::forward<Extra>(extra)...);
return *this;
}

/*
* \brief Define a global function in ffi::PackedArgs format.
*
* \tparam Func The function type.
* \tparam Extra The extra arguments.
*
* \param name The name of the function.
* \param func The function to be registered.
* \param extra The extra arguments that can be docstring.
*
* \return The reflection definition.
*/
template <typename Func, typename... Extra>
GlobalDef& def_packed(const char* name, Func func, Extra&&... extra) {
RegisterFunc(name, ffi::Function::FromPacked(func), std::forward<Extra>(extra)...);
return *this;
}

/*
* \brief Expose a class method as a global function.
*
* An argument will be added to the first position if the function is not static.
*
* \tparam Class The class type.
* \tparam Func The function type.
*
* \param name The name of the method.
* \param func The function to be registered.
*
* \return The reflection definition.
*/
template <typename Func, typename... Extra>
GlobalDef& def_method(const char* name, Func&& func, Extra&&... extra) {
RegisterFunc(name, GetMethod_(std::string(name), std::forward<Func>(func)),
std::forward<Extra>(extra)...);
return *this;
}

private:
template <typename Func>
static TVM_FFI_INLINE Function GetMethod_(std::string name, Func&& func) {
return ffi::Function::FromTyped(std::forward<Func>(func), name);
}

template <typename Class, typename R, typename... Args>
static TVM_FFI_INLINE Function GetMethod_(std::string name, R (Class::*func)(Args...) const) {
return GetMethod<Class>(std::string(name), func);
}

template <typename Class, typename R, typename... Args>
static TVM_FFI_INLINE Function GetMethod_(std::string name, R (Class::*func)(Args...)) {
return GetMethod<Class>(std::string(name), func);
}

template <typename... Extra>
void RegisterFunc(const char* name, ffi::Function func, Extra&&... extra) {
TVMFFIMethodInfo info;
info.name = TVMFFIByteArray{name, std::char_traits<char>::length(name)};
info.doc = TVMFFIByteArray{nullptr, 0};
info.type_schema = TVMFFIByteArray{nullptr, 0};
info.flags = 0;
// obtain the method function
info.method = AnyView(func).CopyToTVMFFIAny();
// apply method info traits
((ApplyMethodInfoTrait(&info, std::forward<Extra>(extra)), ...));
TVM_FFI_CHECK_SAFE_CALL(TVMFFIFunctionSetGlobalFromMethodInfo(&info, 0));
}
};

template <typename Class>
class ObjectDef : public ReflectionDefBase {
public:
Expand Down
64 changes: 29 additions & 35 deletions ffi/src/ffi/container.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,40 +25,11 @@
#include <tvm/ffi/container/map.h>
#include <tvm/ffi/container/shape.h>
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/reflection.h>

namespace tvm {
namespace ffi {

TVM_FFI_REGISTER_GLOBAL("ffi.Array").set_body_packed([](ffi::PackedArgs args, Any* ret) {
*ret = Array<Any>(args.data(), args.data() + args.size());
});

TVM_FFI_REGISTER_GLOBAL("ffi.ArrayGetItem")
.set_body_typed([](const ffi::ArrayObj* n, int64_t i) -> Any { return n->at(i); });

TVM_FFI_REGISTER_GLOBAL("ffi.ArraySize").set_body_typed([](const ffi::ArrayObj* n) -> int64_t {
return static_cast<int64_t>(n->size());
});
// Map
TVM_FFI_REGISTER_GLOBAL("ffi.Map").set_body_packed([](ffi::PackedArgs args, Any* ret) {
TVM_FFI_ICHECK_EQ(args.size() % 2, 0);
Map<Any, Any> data;
for (int i = 0; i < args.size(); i += 2) {
data.Set(args[i], args[i + 1]);
}
*ret = data;
});

TVM_FFI_REGISTER_GLOBAL("ffi.MapSize").set_body_typed([](const ffi::MapObj* n) -> int64_t {
return static_cast<int64_t>(n->size());
});

TVM_FFI_REGISTER_GLOBAL("ffi.MapGetItem")
.set_body_typed([](const ffi::MapObj* n, const Any& k) -> Any { return n->at(k); });

TVM_FFI_REGISTER_GLOBAL("ffi.MapCount")
.set_body_typed([](const ffi::MapObj* n, const Any& k) -> int64_t { return n->count(k); });

// Favor struct outside function scope as MSVC may have bug for in fn scope struct.
class MapForwardIterFunctor {
public:
Expand Down Expand Up @@ -86,10 +57,33 @@ class MapForwardIterFunctor {
ffi::MapObj::iterator end_;
};

TVM_FFI_REGISTER_GLOBAL("ffi.MapForwardIterFunctor")
.set_body_typed([](const ffi::MapObj* n) -> ffi::Function {
return ffi::Function::FromTyped(MapForwardIterFunctor(n->begin(), n->end()));
});

TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection;
refl::GlobalDef()
.def_packed("ffi.Array",
[](ffi::PackedArgs args, Any* ret) {
*ret = Array<Any>(args.data(), args.data() + args.size());
})
.def("ffi.ArrayGetItem", [](const ffi::ArrayObj* n, int64_t i) -> Any { return n->at(i); })
.def("ffi.ArraySize",
[](const ffi::ArrayObj* n) -> int64_t { return static_cast<int64_t>(n->size()); })
.def_packed("ffi.Map",
[](ffi::PackedArgs args, Any* ret) {
TVM_FFI_ICHECK_EQ(args.size() % 2, 0);
Map<Any, Any> data;
for (int i = 0; i < args.size(); i += 2) {
data.Set(args[i], args[i + 1]);
}
*ret = data;
})
.def("ffi.MapSize",
[](const ffi::MapObj* n) -> int64_t { return static_cast<int64_t>(n->size()); })
.def("ffi.MapGetItem", [](const ffi::MapObj* n, const Any& k) -> Any { return n->at(k); })
.def("ffi.MapCount",
[](const ffi::MapObj* n, const Any& k) -> int64_t { return n->count(k); })
.def("ffi.MapForwardIterFunctor", [](const ffi::MapObj* n) -> ffi::Function {
return ffi::Function::FromTyped(MapForwardIterFunctor(n->begin(), n->end()));
});
});
} // namespace ffi
} // namespace tvm
54 changes: 27 additions & 27 deletions ffi/src/ffi/function.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include <tvm/ffi/error.h>
#include <tvm/ffi/function.h>
#include <tvm/ffi/memory.h>
#include <tvm/ffi/reflection/reflection.h>
#include <tvm/ffi/string.h>

namespace tvm {
Expand Down Expand Up @@ -307,31 +308,30 @@ int TVMFFIEnvRegisterCAPI(const TVMFFIByteArray* name, void* symbol) {
TVM_FFI_SAFE_CALL_END();
}

TVM_FFI_REGISTER_GLOBAL("ffi.FunctionRemoveGlobal")
.set_body_typed([](const tvm::ffi::String& name) -> bool {
return tvm::ffi::GlobalFunctionTable::Global()->Remove(name);
});

TVM_FFI_REGISTER_GLOBAL("ffi.FunctionListGlobalNamesFunctor").set_body_typed([]() {
// NOTE: we return functor instead of array
// so list global function names do not need to depend on array
// this is because list global function names usually is a core api that happens
// before array ffi functions are available.
tvm::ffi::Array<tvm::ffi::String> names = tvm::ffi::GlobalFunctionTable::Global()->ListNames();
auto return_functor = [names](int64_t i) -> tvm::ffi::Any {
if (i < 0) {
return names.size();
} else {
return names[i];
}
};
return tvm::ffi::Function::FromTyped(return_functor);
});

TVM_FFI_REGISTER_GLOBAL("ffi.String").set_body_typed([](tvm::ffi::String val) -> tvm::ffi::String {
return val;
});

TVM_FFI_REGISTER_GLOBAL("ffi.Bytes").set_body_typed([](tvm::ffi::Bytes val) -> tvm::ffi::Bytes {
return val;
TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection;
refl::GlobalDef()
.def("ffi.FunctionRemoveGlobal",
[](const tvm::ffi::String& name) -> bool {
return tvm::ffi::GlobalFunctionTable::Global()->Remove(name);
})
.def("ffi.FunctionListGlobalNamesFunctor",
[]() {
// NOTE: we return functor instead of array
// so list global function names do not need to depend on array
// this is because list global function names usually is a core api that happens
// before array ffi functions are available.
tvm::ffi::Array<tvm::ffi::String> names =
tvm::ffi::GlobalFunctionTable::Global()->ListNames();
auto return_functor = [names](int64_t i) -> tvm::ffi::Any {
if (i < 0) {
return names.size();
} else {
return names[i];
}
};
return tvm::ffi::Function::FromTyped(return_functor);
})
.def("ffi.String", [](tvm::ffi::String val) -> tvm::ffi::String { return val; })
.def("ffi.Bytes", [](tvm::ffi::Bytes val) -> tvm::ffi::Bytes { return val; });
});
26 changes: 15 additions & 11 deletions ffi/src/ffi/ndarray.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,23 +23,27 @@
#include <tvm/ffi/c_api.h>
#include <tvm/ffi/container/ndarray.h>
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/reflection.h>

namespace tvm {
namespace ffi {

// Shape
TVM_FFI_REGISTER_GLOBAL("ffi.Shape").set_body_packed([](ffi::PackedArgs args, Any* ret) {
int64_t* mutable_data;
ObjectPtr<ShapeObj> shape = details::MakeEmptyShape(args.size(), &mutable_data);
for (int i = 0; i < args.size(); ++i) {
if (auto opt_int = args[i].try_cast<int64_t>()) {
mutable_data[i] = *opt_int;
} else {
TVM_FFI_THROW(ValueError) << "Expect shape to take list of int arguments";
TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def_packed("ffi.Shape", [](ffi::PackedArgs args, Any* ret) {
int64_t* mutable_data;
ObjectPtr<ShapeObj> shape = details::MakeEmptyShape(args.size(), &mutable_data);
for (int i = 0; i < args.size(); ++i) {
if (auto opt_int = args[i].try_cast<int64_t>()) {
mutable_data[i] = *opt_int;
} else {
TVM_FFI_THROW(ValueError) << "Expect shape to take list of int arguments";
}
}
}
*ret = Shape(shape);
*ret = Shape(shape);
});
});

} // namespace ffi
} // namespace tvm

Expand Down
6 changes: 5 additions & 1 deletion ffi/src/ffi/object.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include <tvm/ffi/container/map.h>
#include <tvm/ffi/error.h>
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/reflection.h>
#include <tvm/ffi/string.h>

#include <memory>
Expand Down Expand Up @@ -404,7 +405,10 @@ void MakeObjectFromPackedArgs(ffi::PackedArgs args, Any* ret) {
*ret = ObjectRef(ptr);
}

TVM_FFI_REGISTER_GLOBAL("ffi.MakeObjectFromPackedArgs").set_body_packed(MakeObjectFromPackedArgs);
TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def_packed("ffi.MakeObjectFromPackedArgs", MakeObjectFromPackedArgs);
});

} // namespace ffi
} // namespace tvm
Expand Down
Loading
Loading