Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions ffi/include/tvm/ffi/memory.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ class ObjAllocatorBase {
* \param args The arguments.
*/
template <typename T, typename... Args>
inline ObjectPtr<T> make_object(Args&&... args) {
ObjectPtr<T> make_object(Args&&... args) {
using Handler = typename Derived::template Handler<T>;
static_assert(std::is_base_of<Object, T>::value, "make can only be used to create Object");
T* ptr = Handler::New(static_cast<Derived*>(this), std::forward<Args>(args)...);
Expand All @@ -89,7 +89,7 @@ class ObjAllocatorBase {
* \param args The arguments.
*/
template <typename ArrayType, typename ElemType, typename... Args>
inline ObjectPtr<ArrayType> make_inplace_array(size_t num_elems, Args&&... args) {
ObjectPtr<ArrayType> make_inplace_array(size_t num_elems, Args&&... args) {
using Handler = typename Derived::template ArrayHandler<ArrayType, ElemType>;
static_assert(std::is_base_of<Object, ArrayType>::value,
"make_inplace_array can only be used to create Object");
Expand All @@ -109,7 +109,9 @@ class SimpleObjAllocator : public ObjAllocatorBase<SimpleObjAllocator> {
template <typename T>
class Handler {
public:
using StorageType = typename std::aligned_storage<sizeof(T), alignof(T)>::type;
struct alignas(T) StorageType {
char data[sizeof(T)];
};

template <typename... Args>
static T* New(SimpleObjAllocator*, Args&&... args) {
Expand Down
159 changes: 98 additions & 61 deletions ffi/include/tvm/ffi/reflection/reflection.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class DefaultValue : public FieldInfoTrait {
public:
explicit DefaultValue(Any value) : value_(value) {}

void Apply(TVMFFIFieldInfo* info) const {
TVM_FFI_INLINE void Apply(TVMFFIFieldInfo* info) const {
info->default_value = AnyView(value_).CopyToTVMFFIAny();
info->flags |= kTVMFFIFieldFlagBitMaskHasDefault;
}
Expand All @@ -65,16 +65,89 @@ class DefaultValue : public FieldInfoTrait {
* \returns The byteoffset
*/
template <typename Class, typename T>
inline int64_t GetFieldByteOffsetToObject(T Class::*field_ptr) {
TVM_FFI_INLINE int64_t GetFieldByteOffsetToObject(T Class::*field_ptr) {
int64_t field_offset_to_class =
reinterpret_cast<int64_t>(&(static_cast<Class*>(nullptr)->*field_ptr));
return field_offset_to_class - details::ObjectUnsafe::GetObjectOffsetToSubclass<Class>();
}

class ReflectionDefBase {
protected:
template <typename T>
static int FieldGetter(void* field, TVMFFIAny* result) {
TVM_FFI_SAFE_CALL_BEGIN();
*result = details::AnyUnsafe::MoveAnyToTVMFFIAny(Any(*reinterpret_cast<T*>(field)));
TVM_FFI_SAFE_CALL_END();
}

template <typename T>
static int FieldSetter(void* field, const TVMFFIAny* value) {
TVM_FFI_SAFE_CALL_BEGIN();
*reinterpret_cast<T*>(field) = AnyView::CopyFromTVMFFIAny(*value).cast<T>();
TVM_FFI_SAFE_CALL_END();
}

template <typename T>
static int ObjectCreatorDefault(TVMFFIObjectHandle* result) {
TVM_FFI_SAFE_CALL_BEGIN();
ObjectPtr<T> obj = make_object<T>();
*result = details::ObjectUnsafe::MoveObjectPtrToTVMFFIObjectPtr(std::move(obj));
TVM_FFI_SAFE_CALL_END();
}

template <typename T>
static TVM_FFI_INLINE void ApplyFieldInfoTrait(TVMFFIFieldInfo* info, const T& value) {
if constexpr (std::is_base_of_v<FieldInfoTrait, std::decay_t<T>>) {
value.Apply(info);
}
if constexpr (std::is_same_v<std::decay_t<T>, char*>) {
info->doc = TVMFFIByteArray{value, std::char_traits<char>::length(value)};
}
}

template <typename T>
static TVM_FFI_INLINE void ApplyMethodInfoTrait(TVMFFIMethodInfo* info, const T& value) {
if constexpr (std::is_same_v<std::decay_t<T>, char*>) {
info->doc = TVMFFIByteArray{value, std::char_traits<char>::length(value)};
}
}

template <typename T>
static TVM_FFI_INLINE void ApplyExtraInfoTrait(TVMFFITypeExtraInfo* info, const T& value) {
if constexpr (std::is_same_v<std::decay_t<T>, char*>) {
info->doc = TVMFFIByteArray{value, std::char_traits<char>::length(value)};
}
}
template <typename Class, typename R, typename... Args>
static TVM_FFI_INLINE Function GetMethod(std::string name, R (Class::*func)(Args...)) {
auto fwrap = [func](const Class* target, Args... params) -> R {
return (const_cast<Class*>(target)->*func)(std::forward<Args>(params)...);
};
return ffi::Function::FromTyped(fwrap, name);
}

template <typename Class, typename R, typename... Args>
static TVM_FFI_INLINE Function GetMethod(std::string name, R (Class::*func)(Args...) const) {
auto fwrap = [func](const Class* target, Args... params) -> R {
return (target->*func)(std::forward<Args>(params)...);
};
return ffi::Function::FromTyped(fwrap, name);
}

template <typename Class, typename Func>
static TVM_FFI_INLINE Function GetMethod(std::string name, Func&& func) {
return ffi::Function::FromTyped(std::forward<Func>(func), name);
}
};

template <typename Class>
class ObjectDef {
class ObjectDef : public ReflectionDefBase {
public:
ObjectDef() : type_index_(Class::_GetOrAllocRuntimeTypeIndex()), type_key_(Class::_type_key) {}
template <typename... ExtraArgs>
explicit ObjectDef(ExtraArgs&&... extra_args)
: type_index_(Class::_GetOrAllocRuntimeTypeIndex()), type_key_(Class::_type_key) {
RegisterExtraInfo(std::forward<ExtraArgs>(extra_args)...);
}

/*!
* \brief Define a readonly field.
Expand All @@ -90,7 +163,7 @@ class ObjectDef {
* \return The reflection definition.
*/
template <typename T, typename... Extra>
ObjectDef& def_ro(const char* name, T Class::*field_ptr, Extra&&... extra) {
TVM_FFI_INLINE ObjectDef& def_ro(const char* name, T Class::*field_ptr, Extra&&... extra) {
RegisterField(name, field_ptr, false, std::forward<Extra>(extra)...);
return *this;
}
Expand All @@ -109,7 +182,8 @@ class ObjectDef {
* \return The reflection definition.
*/
template <typename T, typename... Extra>
ObjectDef& def_rw(const char* name, T Class::*field_ptr, Extra&&... extra) {
TVM_FFI_INLINE ObjectDef& def_rw(const char* name, T Class::*field_ptr, Extra&&... extra) {
static_assert(Class::_type_mutable, "Only mutable classes are supported for writable fields");
RegisterField(name, field_ptr, true, std::forward<Extra>(extra)...);
return *this;
}
Expand All @@ -127,7 +201,7 @@ class ObjectDef {
* \return The reflection definition.
*/
template <typename Func, typename... Extra>
ObjectDef& def(const char* name, Func&& func, Extra&&... extra) {
TVM_FFI_INLINE ObjectDef& def(const char* name, Func&& func, Extra&&... extra) {
RegisterMethod(name, false, std::forward<Func>(func), std::forward<Extra>(extra)...);
return *this;
}
Expand All @@ -145,12 +219,26 @@ class ObjectDef {
* \return The reflection definition.
*/
template <typename Func, typename... Extra>
ObjectDef& def_static(const char* name, Func&& func, Extra&&... extra) {
TVM_FFI_INLINE ObjectDef& def_static(const char* name, Func&& func, Extra&&... extra) {
RegisterMethod(name, true, std::forward<Func>(func), std::forward<Extra>(extra)...);
return *this;
}

private:
template <typename... ExtraArgs>
void RegisterExtraInfo(ExtraArgs&&... extra_args) {
TVMFFITypeExtraInfo info;
info.total_size = sizeof(Class);
info.creator = nullptr;
info.doc = TVMFFIByteArray{nullptr, 0};
if constexpr (std::is_default_constructible_v<Class>) {
info.creator = ObjectCreatorDefault<Class>;
}
// apply extra info traits
((ApplyExtraInfoTrait(&info, std::forward<ExtraArgs>(extra_args)), ...));
TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeRegisterExtraInfo(type_index_, &info));
}

template <typename T, typename... ExtraArgs>
void RegisterField(const char* name, T Class::*field_ptr, bool writable,
ExtraArgs&&... extra_args) {
Expand Down Expand Up @@ -178,30 +266,6 @@ class ObjectDef {
TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeRegisterField(type_index_, &info));
}

template <typename T>
static int FieldGetter(void* field, TVMFFIAny* result) {
TVM_FFI_SAFE_CALL_BEGIN();
*result = details::AnyUnsafe::MoveAnyToTVMFFIAny(Any(*reinterpret_cast<T*>(field)));
TVM_FFI_SAFE_CALL_END();
}

template <typename T>
static int FieldSetter(void* field, const TVMFFIAny* value) {
TVM_FFI_SAFE_CALL_BEGIN();
*reinterpret_cast<T*>(field) = AnyView::CopyFromTVMFFIAny(*value).cast<T>();
TVM_FFI_SAFE_CALL_END();
}

template <typename T>
static void ApplyFieldInfoTrait(TVMFFIFieldInfo* info, const T& value) {
if constexpr (std::is_base_of_v<FieldInfoTrait, std::decay_t<T>>) {
value.Apply(info);
}
if constexpr (std::is_same_v<std::decay_t<T>, char*>) {
info->doc = TVMFFIByteArray{value, std::char_traits<char>::length(value)};
}
}

// register a method
template <typename Func, typename... Extra>
void RegisterMethod(const char* name, bool is_static, Func&& func, Extra&&... extra) {
Expand All @@ -214,41 +278,14 @@ class ObjectDef {
info.flags |= kTVMFFIFieldFlagBitMaskIsStaticMethod;
}
// obtain the method function
Function method = GetMethod(std::string(type_key_) + "." + name, std::forward<Func>(func));
Function method =
GetMethod<Class>(std::string(type_key_) + "." + name, std::forward<Func>(func));
info.method = AnyView(method).CopyToTVMFFIAny();
// apply method info traits
((ApplyMethodInfoTrait(&info, std::forward<Extra>(extra)), ...));
TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeRegisterMethod(type_index_, &info));
}

template <typename T>
static void ApplyMethodInfoTrait(TVMFFIMethodInfo* info, const T& value) {
if constexpr (std::is_same_v<std::decay_t<T>, char*>) {
info->doc = TVMFFIByteArray{value, std::char_traits<char>::length(value)};
}
}

template <typename R, typename... Args>
static Function GetMethod(std::string name, R (Class::*func)(Args...)) {
auto fwrap = [func](const Class* target, Args... params) -> R {
return (const_cast<Class*>(target)->*func)(std::forward<Args>(params)...);
};
return ffi::Function::FromTyped(fwrap, name);
}

template <typename R, typename... Args>
static Function GetMethod(std::string name, R (Class::*func)(Args...) const) {
auto fwrap = [func](const Class* target, Args... params) -> R {
return (target->*func)(std::forward<Args>(params)...);
};
return ffi::Function::FromTyped(fwrap, name);
}

template <typename Func>
static Function GetMethod(std::string name, Func&& func) {
return ffi::Function::FromTyped(std::forward<Func>(func), name);
}

int32_t type_index_;
const char* type_key_;
};
Expand Down
12 changes: 12 additions & 0 deletions ffi/include/tvm/ffi/string.h
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,18 @@ class String : public ObjectRef {
return Bytes::memncmp(data(), other, size(), std::strlen(other));
}

/*!
* \brief Compares this to other
*
* \param other The TVMFFIByteArray to compare with.
*
* \return zero if both char sequences compare equal. negative if this appear
* before other, positive otherwise.
*/
int compare(const TVMFFIByteArray& other) const {
return Bytes::memncmp(data(), other.data, size(), other.size);
}

/*!
* \brief Returns a pointer to the char array in the string.
*
Expand Down
77 changes: 77 additions & 0 deletions ffi/src/ffi/object.cc
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,83 @@ class TypeTable {
Map<String, int64_t> type_key2index_;
std::vector<Any> any_pool_;
};

void MakeObjectFromPackedArgs(ffi::PackedArgs args, Any* ret) {
String type_key = args[0].cast<String>();
TVM_FFI_ICHECK(args.size() % 2 == 1);

int32_t type_index;
TVMFFIByteArray type_key_array = TVMFFIByteArray{type_key.data(), type_key.size()};
TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeKeyToIndex(&type_key_array, &type_index));
const TVMFFITypeInfo* type_info = TVMFFIGetTypeInfo(type_index);
if (type_info == nullptr) {
TVM_FFI_THROW(RuntimeError) << "Cannot find type `" << type_key << "`";
}

if (type_info->extra_info == nullptr || type_info->extra_info->creator == nullptr) {
TVM_FFI_THROW(RuntimeError) << "Type `" << type_key << "` does not support reflection creation";
}
TVMFFIObjectHandle handle;
TVM_FFI_CHECK_SAFE_CALL(type_info->extra_info->creator(&handle));
ObjectPtr<Object> ptr =
details::ObjectUnsafe::ObjectPtrFromOwned<Object>(static_cast<TVMFFIObject*>(handle));

std::vector<String> keys;
std::vector<bool> keys_found;

for (int i = 1; i < args.size(); i += 2) {
keys.push_back(args[i].cast<String>());
}
keys_found.resize(keys.size(), false);

auto search_field = [&](const TVMFFIByteArray& field_name) {
for (size_t i = 0; i < keys.size(); ++i) {
if (keys_found[i]) continue;
if (keys[i].compare(field_name) == 0) {
return i;
}
}
return keys.size();
};

auto update_fields = [&](const TVMFFITypeInfo* tinfo) {
for (int i = 0; i < tinfo->num_fields; ++i) {
const TVMFFIFieldInfo* field_info = tinfo->fields + i;
size_t arg_index = search_field(field_info->name);
void* field_addr = reinterpret_cast<char*>(ptr.get()) + field_info->offset;
if (arg_index < keys.size()) {
AnyView field_value = args[arg_index * 2 + 2];
field_info->setter(field_addr, reinterpret_cast<const TVMFFIAny*>(&field_value));
keys_found[arg_index] = true;
} else if (field_info->flags & kTVMFFIFieldFlagBitMaskHasDefault) {
field_info->setter(field_addr, &(field_info->default_value));
} else {
TVM_FFI_THROW(TypeError) << "Required field `"
<< String(field_info->name.data, field_info->name.size)
<< "` not set in type `" << type_key << "`";
}
}
};

// iterate through acenstors in parent to child order
// skip the first one since it is always the root object
TVM_FFI_ICHECK(type_info->type_acenstors[0] == TypeIndex::kTVMFFIObject);
for (int i = 1; i < type_info->type_depth; ++i) {
update_fields(TVMFFIGetTypeInfo(type_info->type_acenstors[i]));
}
update_fields(type_info);

for (size_t i = 0; i < keys.size(); ++i) {
if (!keys_found[i]) {
TVM_FFI_THROW(TypeError) << "Type `" << type_key << "` does not have field `" << keys[i]
<< "`";
}
}
*ret = ObjectRef(ptr);
}

TVM_FFI_REGISTER_GLOBAL("ffi.MakeObjectFromPackedArgs").set_body_packed(MakeObjectFromPackedArgs);

} // namespace ffi
} // namespace tvm

Expand Down
Loading
Loading