Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions ffi/include/tvm/ffi/cast.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,18 +44,20 @@ namespace ffi {
*/
template <typename RefType, typename ObjectType>
inline RefType GetRef(const ObjectType* ptr) {
static_assert(std::is_base_of_v<typename RefType::ContainerType, ObjectType>,
using ContainerType = typename RefType::ContainerType;
static_assert(std::is_base_of_v<ContainerType, ObjectType>,
"Can only cast to the ref of same container type");

if constexpr (is_optional_type_v<RefType> || RefType::_type_is_nullable) {
if (ptr == nullptr) {
return RefType(ObjectPtr<Object>(nullptr));
return details::ObjectUnsafe::ObjectRefFromObjectPtr<RefType>(nullptr);
}
} else {
TVM_FFI_ICHECK_NOTNULL(ptr);
}
return RefType(details::ObjectUnsafe::ObjectPtrFromUnowned<Object>(
const_cast<Object*>(static_cast<const Object*>(ptr))));
return details::ObjectUnsafe::ObjectRefFromObjectPtr<RefType>(
details::ObjectUnsafe::ObjectPtrFromUnowned<Object>(
const_cast<Object*>(static_cast<const Object*>(ptr))));
}

/*!
Expand Down
4 changes: 4 additions & 0 deletions ffi/include/tvm/ffi/container/array.h
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,10 @@ class Array : public ObjectRef {
/*! \brief The value type of the array */
using value_type = T;
// constructors
/*!
* \brief Construct an Array with UnsafeInit
*/
explicit Array(UnsafeInit tag) : ObjectRef(tag) {}
/*!
* \brief default constructor
*/
Expand Down
4 changes: 4 additions & 0 deletions ffi/include/tvm/ffi/container/map.h
Original file line number Diff line number Diff line change
Expand Up @@ -1381,6 +1381,10 @@ class Map : public ObjectRef {
using mapped_type = V;
/*! \brief The iterator type of the map */
class iterator;
/*!
* \brief Construct an Map with UnsafeInit
*/
explicit Map(UnsafeInit tag) : ObjectRef(tag) {}
/*!
* \brief default constructor
*/
Expand Down
17 changes: 15 additions & 2 deletions ffi/include/tvm/ffi/container/shape.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,13 +94,13 @@ TVM_FFI_INLINE ObjectPtr<ShapeObj> MakeInplaceShape(IterType begin, IterType end
return p;
}

TVM_FFI_INLINE ObjectPtr<ShapeObj> MakeStridesFromShape(int64_t ndim, int64_t* shape) {
TVM_FFI_INLINE ObjectPtr<ShapeObj> MakeStridesFromShape(const int64_t* data, int64_t ndim) {
int64_t* strides_data;
ObjectPtr<ShapeObj> strides = details::MakeEmptyShape(ndim, &strides_data);
int64_t stride = 1;
for (int i = ndim - 1; i >= 0; --i) {
strides_data[i] = stride;
stride *= shape[i];
stride *= data[i];
}
return strides;
}
Expand Down Expand Up @@ -150,6 +150,16 @@ class Shape : public ObjectRef {
Shape(std::vector<int64_t> other) // NOLINT(*)
: ObjectRef(make_object<details::ShapeObjStdImpl>(std::move(other))) {}

/*!
* \brief Create a strides from a shape.
* \param data The shape data.
* \param ndim The number of dimensions.
* \return The strides.
*/
static Shape StridesFromShape(const int64_t* data, int64_t ndim) {
return Shape(details::MakeStridesFromShape(data, ndim));
}

/*!
* \brief Return the data pointer
*
Expand Down Expand Up @@ -204,6 +214,9 @@ class Shape : public ObjectRef {
/// \cond Doxygen_Suppress
TVM_FFI_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Shape, ObjectRef, ShapeObj);
/// \endcond

private:
explicit Shape(ObjectPtr<ShapeObj> ptr) : ObjectRef(ptr) {}
};

inline std::ostream& operator<<(std::ostream& os, const Shape& shape) {
Expand Down
4 changes: 2 additions & 2 deletions ffi/include/tvm/ffi/container/tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ class TensorObjFromNDAlloc : public TensorObj {
this->ndim = static_cast<int>(shape.size());
this->dtype = dtype;
this->shape = const_cast<int64_t*>(shape.data());
Shape strides = Shape(details::MakeStridesFromShape(this->ndim, this->shape));
Shape strides = Shape::StridesFromShape(this->shape, this->ndim);
this->strides = const_cast<int64_t*>(strides.data());
this->byte_offset = 0;
this->shape_data_ = std::move(shape);
Expand All @@ -224,7 +224,7 @@ class TensorObjFromDLPack : public TensorObj {
explicit TensorObjFromDLPack(TDLPackManagedTensor* tensor) : tensor_(tensor) {
*static_cast<DLTensor*>(this) = tensor_->dl_tensor;
if (tensor_->dl_tensor.strides == nullptr) {
Shape strides = Shape(details::MakeStridesFromShape(ndim, shape));
Shape strides = Shape::StridesFromShape(tensor_->dl_tensor.shape, tensor_->dl_tensor.ndim);
this->strides = const_cast<int64_t*>(strides.data());
this->strides_data_ = std::move(strides);
}
Expand Down
14 changes: 6 additions & 8 deletions ffi/include/tvm/ffi/container/tuple.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ class Tuple : public ObjectRef {
"All types used in Tuple<...> must be compatible with Any");
/*! \brief Default constructor */
Tuple() : ObjectRef(MakeDefaultTupleNode()) {}
/*!
* \brief Constructor with UnsafeInit
*/
explicit Tuple(UnsafeInit tag) : ObjectRef(tag) {}
/*! \brief Copy constructor */
Tuple(const Tuple<Types...>& other) : ObjectRef(other) {}
/*! \brief Move constructor */
Expand Down Expand Up @@ -128,13 +132,6 @@ class Tuple : public ObjectRef {
return *this;
}

/*!
* \brief Constructor ObjectPtr
* \param ptr The ObjectPtr
* \tparam The enable_if_t type
*/
explicit Tuple(ObjectPtr<Object> ptr) : ObjectRef(ptr) {}

/*!
* \brief Get I-th element of the tuple
*
Expand Down Expand Up @@ -283,7 +280,8 @@ struct TypeTraits<Tuple<Types...>> : public ObjectRefTypeTraitsBase<Tuple<Types.
Array<Any> arr = TypeTraits<Array<Any>>::CopyFromAnyViewAfterCheck(src);
Any* ptr = arr.CopyOnWrite()->MutableBegin();
if (TryConvertElements<0, Types...>(ptr)) {
return Tuple<Types...>(details::ObjectUnsafe::ObjectPtrFromObjectRef<Object>(arr));
return details::ObjectUnsafe::ObjectRefFromObjectPtr<Tuple<Types...>>(
details::ObjectUnsafe::ObjectPtrFromObjectRef<Object>(arr));
}
return std::nullopt;
}
Expand Down
2 changes: 1 addition & 1 deletion ffi/include/tvm/ffi/container/variant.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ class VariantBase<true> : public ObjectRef {
explicit VariantBase(const T& other) : ObjectRef(other) {}
template <typename T>
explicit VariantBase(T&& other) : ObjectRef(std::move(other)) {}
explicit VariantBase(ObjectPtr<Object> ptr) : ObjectRef(ptr) {}
explicit VariantBase(UnsafeInit tag) : ObjectRef(tag) {}
explicit VariantBase(Any other)
: ObjectRef(details::AnyUnsafe::MoveFromAnyAfterCheck<ObjectRef>(std::move(other))) {}

Expand Down
17 changes: 16 additions & 1 deletion ffi/include/tvm/ffi/extra/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ class Module;

/*!
* \brief A module that can dynamically load ffi::Functions or exportable source code.
* \sa Module
*/
class TVM_FFI_EXTRA_CXX_API ModuleObj : public Object {
public:
Expand Down Expand Up @@ -168,6 +169,16 @@ class TVM_FFI_EXTRA_CXX_API ModuleObj : public Object {

/*!
* \brief Reference to module object.
*
* When invoking a function on a ModuleObj, such as GetFunction,
* use operator-> to get the ModuleObj pointer and invoke the member functions.
*
* \code
* ffi::Module mod = ffi::Module::LoadFromFile("path/to/module.so");
* ffi::Function func = mod->GetFunction(name);
* \endcode
*
* \sa ModuleObj which contains most of the function implementations.
*/
class Module : public ObjectRef {
public:
Expand Down Expand Up @@ -202,7 +213,11 @@ class Module : public ObjectRef {
*/
kCompilationExportable = 0b100
};

/*!
* \brief Constructor from ObjectPtr<ModuleObj>.
* \param ptr The object pointer.
*/
explicit Module(ObjectPtr<ModuleObj> ptr) : ObjectRef(ptr) { TVM_FFI_ICHECK(ptr != nullptr); }
/*!
* \brief Load a module from file.
* \param file_name The name of the host function module.
Expand Down
2 changes: 1 addition & 1 deletion ffi/include/tvm/ffi/function.h
Original file line number Diff line number Diff line change
Expand Up @@ -403,7 +403,7 @@ class Function : public ObjectRef {
TVM_FFI_CHECK_SAFE_CALL(TVMFFIFunctionGetGlobal(&name_arr, &handle));
if (handle != nullptr) {
return Function(
details::ObjectUnsafe::ObjectPtrFromOwned<Object>(static_cast<Object*>(handle)));
details::ObjectUnsafe::ObjectPtrFromOwned<FunctionObj>(static_cast<Object*>(handle)));
} else {
return std::nullopt;
}
Expand Down
2 changes: 1 addition & 1 deletion ffi/include/tvm/ffi/function_details.h
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ TVM_FFI_INLINE static Error MoveFromSafeCallRaised() {
TVMFFIObjectHandle handle;
TVMFFIErrorMoveFromRaised(&handle);
// handle is owned by caller
return Error(
return details::ObjectUnsafe::ObjectRefFromObjectPtr<Error>(
details::ObjectUnsafe::ObjectPtrFromOwned<Object>(static_cast<TVMFFIObject*>(handle)));
}

Expand Down
62 changes: 52 additions & 10 deletions ffi/include/tvm/ffi/object.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,24 @@ using TypeIndex = TVMFFITypeIndex;
*/
using TypeInfo = TVMFFITypeInfo;

/*!
* \brief Helper tag to explicitly request unsafe initialization.
*
* Constructing an ObjectRefType with UnsafeInit{} will set the data_ member to nullptr.
*
* When initializing Object fields, ObjectRef fields can be set to UnsafeInit.
* This enables the "construct with UnsafeInit then set all fields" pattern
* when the object does not have a default constructor.
*
* Used for initialization in controlled scenarios where such unsafe
* initialization is known to be safe.
*
* Each ObjectRefType should have a constructor that takes an UnsafeInit tag.
*
* \note As the name suggests, do not use it in normal code paths.
*/
struct UnsafeInit {};

/*!
* \brief Known type keys for pre-defined types.
*/
Expand Down Expand Up @@ -702,6 +720,8 @@ class ObjectRef {
ObjectRef& operator=(ObjectRef&& other) = default;
/*! \brief Constructor from existing object ptr */
explicit ObjectRef(ObjectPtr<Object> data) : data_(data) {}
/*! \brief Constructor from UnsafeInit */
explicit ObjectRef(UnsafeInit) : data_(nullptr) {}
/*!
* \brief Comparator
* \param other Another object ref.
Expand Down Expand Up @@ -774,14 +794,17 @@ class ObjectRef {
TVM_FFI_INLINE std::optional<ObjectRefType> as() const {
if (data_ != nullptr) {
if (data_->IsInstance<typename ObjectRefType::ContainerType>()) {
return ObjectRefType(data_);
ObjectRefType ref(UnsafeInit{});
ref.data_ = data_;
return ref;
} else {
return std::nullopt;
}
} else {
return std::nullopt;
}
}

/*!
* \brief Get the type index of the ObjectRef
* \return The type index of the ObjectRef
Expand Down Expand Up @@ -914,7 +937,8 @@ struct ObjectPtrEqual {
*/
#define TVM_FFI_DEFINE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName) \
TypeName() = default; \
explicit TypeName(::tvm::ffi::ObjectPtr<::tvm::ffi::Object> n) : ParentType(n) {} \
explicit TypeName(::tvm::ffi::ObjectPtr<ObjectName> n) : ParentType(n) {} \
explicit TypeName(::tvm::ffi::UnsafeInit tag) : ParentType(tag) {} \
TVM_FFI_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName) \
const ObjectName* operator->() const { return static_cast<const ObjectName*>(data_.get()); } \
const ObjectName* get() const { return operator->(); } \
Expand All @@ -928,7 +952,7 @@ struct ObjectPtrEqual {
* \param ObjectName The type name of the object.
*/
#define TVM_FFI_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName) \
explicit TypeName(::tvm::ffi::ObjectPtr<::tvm::ffi::Object> n) : ParentType(n) {} \
explicit TypeName(::tvm::ffi::UnsafeInit tag) : ParentType(tag) {} \
TVM_FFI_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName) \
const ObjectName* operator->() const { return static_cast<const ObjectName*>(data_.get()); } \
const ObjectName* get() const { return operator->(); } \
Expand All @@ -943,11 +967,12 @@ struct ObjectPtrEqual {
* \note We recommend making objects immutable when possible.
* This macro is only reserved for objects that stores runtime states.
*/
#define TVM_FFI_DEFINE_MUTABLE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName) \
TypeName() = default; \
TVM_FFI_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName) \
explicit TypeName(::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n) : ParentType(n) {} \
ObjectName* operator->() const { return static_cast<ObjectName*>(data_.get()); } \
#define TVM_FFI_DEFINE_MUTABLE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName) \
TypeName() = default; \
explicit TypeName(::tvm::ffi::UnsafeInit tag) : ParentType(tag) {} \
TVM_FFI_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName) \
explicit TypeName(::tvm::ffi::ObjectPtr<ObjectName> n) : ParentType(n) {} \
ObjectName* operator->() const { return static_cast<ObjectName*>(data_.get()); } \
using ContainerType = ObjectName

/*!
Expand All @@ -958,7 +983,7 @@ struct ObjectPtrEqual {
* \param ObjectName The type name of the object.
*/
#define TVM_FFI_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName) \
explicit TypeName(::tvm::ffi::ObjectPtr<::tvm::ffi::Object> n) : ParentType(n) {} \
explicit TypeName(::tvm::ffi::UnsafeInit tag) : ParentType(tag) {} \
TVM_FFI_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName) \
ObjectName* operator->() const { return static_cast<ObjectName*>(data_.get()); } \
ObjectName* get() const { return operator->(); } \
Expand Down Expand Up @@ -1021,6 +1046,20 @@ struct ObjectUnsafe {
reinterpret_cast<int64_t>(&(static_cast<Object*>(nullptr)->header_)));
}

template <typename T>
TVM_FFI_INLINE static T ObjectRefFromObjectPtr(const ObjectPtr<Object>& ptr) {
T ref(UnsafeInit{});
ref.data_ = ptr;
return ref;
}

template <typename T>
TVM_FFI_INLINE static T ObjectRefFromObjectPtr(ObjectPtr<Object>&& ptr) {
T ref(UnsafeInit{});
ref.data_ = std::move(ptr);
return ref;
}

template <typename T>
TVM_FFI_INLINE static ObjectPtr<T> ObjectPtrFromObjectRef(const ObjectRef& ref) {
if constexpr (std::is_same_v<T, Object>) {
Expand All @@ -1035,7 +1074,10 @@ struct ObjectUnsafe {
if constexpr (std::is_same_v<T, Object>) {
return std::move(ref.data_);
} else {
return tvm::ffi::ObjectPtr<T>(std::move(ref.data_.data_));
ObjectPtr<T> result;
result.data_ = std::move(ref.data_.data_);
ref.data_.data_ = nullptr;
return result;
}
}

Expand Down
17 changes: 11 additions & 6 deletions ffi/include/tvm/ffi/optional.h
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ class Optional<T, std::enable_if_t<use_ptr_based_optional_v<T>>> : public Object
Optional() = default;
Optional(const Optional<T>& other) : ObjectRef(other.data_) {}
Optional(Optional<T>&& other) : ObjectRef(std::move(other.data_)) {}
explicit Optional(ObjectPtr<Object> ptr) : ObjectRef(ptr) {}
explicit Optional(ffi::UnsafeInit tag) : ObjectRef(tag) {}
// nullopt hanlding
Optional(std::nullopt_t) {} // NOLINT(*)

Expand Down Expand Up @@ -300,19 +300,20 @@ class Optional<T, std::enable_if_t<use_ptr_based_optional_v<T>>> : public Object
if (data_ == nullptr) {
TVM_FFI_THROW(RuntimeError) << "Back optional access";
}
return T(data_);
return details::ObjectUnsafe::ObjectRefFromObjectPtr<T>(data_);
}

TVM_FFI_INLINE T value() && {
if (data_ == nullptr) {
TVM_FFI_THROW(RuntimeError) << "Back optional access";
}
return T(std::move(data_));
return details::ObjectUnsafe::ObjectRefFromObjectPtr<T>(std::move(data_));
}

template <typename U = std::remove_cv_t<T>>
TVM_FFI_INLINE T value_or(U&& default_value) const {
return data_ != nullptr ? T(data_) : T(std::forward<U>(default_value));
return data_ != nullptr ? details::ObjectUnsafe::ObjectRefFromObjectPtr<T>(data_)
: T(std::forward<U>(default_value));
}

TVM_FFI_INLINE explicit operator bool() const { return data_ != nullptr; }
Expand All @@ -324,14 +325,18 @@ class Optional<T, std::enable_if_t<use_ptr_based_optional_v<T>>> : public Object
* \return the const reference to the stored value.
* \note only use this function after checking has_value()
*/
TVM_FFI_INLINE T operator*() const& noexcept { return T(data_); }
TVM_FFI_INLINE T operator*() const& noexcept {
return details::ObjectUnsafe::ObjectRefFromObjectPtr<T>(data_);
}

/*!
* \brief Direct access to the value.
* \return the const reference to the stored value.
* \note only use this function after checking has_value()
*/
TVM_FFI_INLINE T operator*() && noexcept { return T(std::move(data_)); }
TVM_FFI_INLINE T operator*() && noexcept {
return details::ObjectUnsafe::ObjectRefFromObjectPtr<T>(std::move(data_));
}

TVM_FFI_INLINE bool operator==(std::nullptr_t) const noexcept { return !has_value(); }
TVM_FFI_INLINE bool operator!=(std::nullptr_t) const noexcept { return has_value(); }
Expand Down
Loading
Loading