diff --git a/ffi/include/tvm/ffi/cast.h b/ffi/include/tvm/ffi/cast.h index f70df9fe7ca2..398953ad6508 100644 --- a/ffi/include/tvm/ffi/cast.h +++ b/ffi/include/tvm/ffi/cast.h @@ -44,18 +44,20 @@ namespace ffi { */ template inline RefType GetRef(const ObjectType* ptr) { - static_assert(std::is_base_of_v, + using ContainerType = typename RefType::ContainerType; + static_assert(std::is_base_of_v, "Can only cast to the ref of same container type"); if constexpr (is_optional_type_v || RefType::_type_is_nullable) { if (ptr == nullptr) { - return RefType(ObjectPtr(nullptr)); + return details::ObjectUnsafe::ObjectRefFromObjectPtr(nullptr); } } else { TVM_FFI_ICHECK_NOTNULL(ptr); } - return RefType(details::ObjectUnsafe::ObjectPtrFromUnowned( - const_cast(static_cast(ptr)))); + return details::ObjectUnsafe::ObjectRefFromObjectPtr( + details::ObjectUnsafe::ObjectPtrFromUnowned( + const_cast(static_cast(ptr)))); } /*! diff --git a/ffi/include/tvm/ffi/container/array.h b/ffi/include/tvm/ffi/container/array.h index 7dbcc1f0189e..8fab30b8be56 100644 --- a/ffi/include/tvm/ffi/container/array.h +++ b/ffi/include/tvm/ffi/container/array.h @@ -362,6 +362,10 @@ class Array : public ObjectRef { /*! \brief The value type of the array */ using value_type = T; // constructors + /*! + * \brief Construct an Array with UnsafeInit + */ + explicit Array(UnsafeInit tag) : ObjectRef(tag) {} /*! * \brief default constructor */ diff --git a/ffi/include/tvm/ffi/container/map.h b/ffi/include/tvm/ffi/container/map.h index 27928d20c5cf..bea2688f7f20 100644 --- a/ffi/include/tvm/ffi/container/map.h +++ b/ffi/include/tvm/ffi/container/map.h @@ -1381,6 +1381,10 @@ class Map : public ObjectRef { using mapped_type = V; /*! \brief The iterator type of the map */ class iterator; + /*! + * \brief Construct an Map with UnsafeInit + */ + explicit Map(UnsafeInit tag) : ObjectRef(tag) {} /*! * \brief default constructor */ diff --git a/ffi/include/tvm/ffi/container/shape.h b/ffi/include/tvm/ffi/container/shape.h index 39c3ec273963..f5e88d6bb796 100644 --- a/ffi/include/tvm/ffi/container/shape.h +++ b/ffi/include/tvm/ffi/container/shape.h @@ -94,13 +94,13 @@ TVM_FFI_INLINE ObjectPtr MakeInplaceShape(IterType begin, IterType end return p; } -TVM_FFI_INLINE ObjectPtr MakeStridesFromShape(int64_t ndim, int64_t* shape) { +TVM_FFI_INLINE ObjectPtr MakeStridesFromShape(const int64_t* data, int64_t ndim) { int64_t* strides_data; ObjectPtr strides = details::MakeEmptyShape(ndim, &strides_data); int64_t stride = 1; for (int i = ndim - 1; i >= 0; --i) { strides_data[i] = stride; - stride *= shape[i]; + stride *= data[i]; } return strides; } @@ -150,6 +150,16 @@ class Shape : public ObjectRef { Shape(std::vector other) // NOLINT(*) : ObjectRef(make_object(std::move(other))) {} + /*! + * \brief Create a strides from a shape. + * \param data The shape data. + * \param ndim The number of dimensions. + * \return The strides. + */ + static Shape StridesFromShape(const int64_t* data, int64_t ndim) { + return Shape(details::MakeStridesFromShape(data, ndim)); + } + /*! * \brief Return the data pointer * @@ -204,6 +214,9 @@ class Shape : public ObjectRef { /// \cond Doxygen_Suppress TVM_FFI_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Shape, ObjectRef, ShapeObj); /// \endcond + + private: + explicit Shape(ObjectPtr ptr) : ObjectRef(ptr) {} }; inline std::ostream& operator<<(std::ostream& os, const Shape& shape) { diff --git a/ffi/include/tvm/ffi/container/tensor.h b/ffi/include/tvm/ffi/container/tensor.h index 99fb29d10830..21c67decfcd5 100644 --- a/ffi/include/tvm/ffi/container/tensor.h +++ b/ffi/include/tvm/ffi/container/tensor.h @@ -203,7 +203,7 @@ class TensorObjFromNDAlloc : public TensorObj { this->ndim = static_cast(shape.size()); this->dtype = dtype; this->shape = const_cast(shape.data()); - Shape strides = Shape(details::MakeStridesFromShape(this->ndim, this->shape)); + Shape strides = Shape::StridesFromShape(this->shape, this->ndim); this->strides = const_cast(strides.data()); this->byte_offset = 0; this->shape_data_ = std::move(shape); @@ -224,7 +224,7 @@ class TensorObjFromDLPack : public TensorObj { explicit TensorObjFromDLPack(TDLPackManagedTensor* tensor) : tensor_(tensor) { *static_cast(this) = tensor_->dl_tensor; if (tensor_->dl_tensor.strides == nullptr) { - Shape strides = Shape(details::MakeStridesFromShape(ndim, shape)); + Shape strides = Shape::StridesFromShape(tensor_->dl_tensor.shape, tensor_->dl_tensor.ndim); this->strides = const_cast(strides.data()); this->strides_data_ = std::move(strides); } diff --git a/ffi/include/tvm/ffi/container/tuple.h b/ffi/include/tvm/ffi/container/tuple.h index 0cb80b963e9e..75342409eabb 100644 --- a/ffi/include/tvm/ffi/container/tuple.h +++ b/ffi/include/tvm/ffi/container/tuple.h @@ -47,6 +47,10 @@ class Tuple : public ObjectRef { "All types used in Tuple<...> must be compatible with Any"); /*! \brief Default constructor */ Tuple() : ObjectRef(MakeDefaultTupleNode()) {} + /*! + * \brief Constructor with UnsafeInit + */ + explicit Tuple(UnsafeInit tag) : ObjectRef(tag) {} /*! \brief Copy constructor */ Tuple(const Tuple& other) : ObjectRef(other) {} /*! \brief Move constructor */ @@ -128,13 +132,6 @@ class Tuple : public ObjectRef { return *this; } - /*! - * \brief Constructor ObjectPtr - * \param ptr The ObjectPtr - * \tparam The enable_if_t type - */ - explicit Tuple(ObjectPtr ptr) : ObjectRef(ptr) {} - /*! * \brief Get I-th element of the tuple * @@ -283,7 +280,8 @@ struct TypeTraits> : public ObjectRefTypeTraitsBase arr = TypeTraits>::CopyFromAnyViewAfterCheck(src); Any* ptr = arr.CopyOnWrite()->MutableBegin(); if (TryConvertElements<0, Types...>(ptr)) { - return Tuple(details::ObjectUnsafe::ObjectPtrFromObjectRef(arr)); + return details::ObjectUnsafe::ObjectRefFromObjectPtr>( + details::ObjectUnsafe::ObjectPtrFromObjectRef(arr)); } return std::nullopt; } diff --git a/ffi/include/tvm/ffi/container/variant.h b/ffi/include/tvm/ffi/container/variant.h index 5f66d73a1845..cae5a673b8ce 100644 --- a/ffi/include/tvm/ffi/container/variant.h +++ b/ffi/include/tvm/ffi/container/variant.h @@ -68,7 +68,7 @@ class VariantBase : public ObjectRef { explicit VariantBase(const T& other) : ObjectRef(other) {} template explicit VariantBase(T&& other) : ObjectRef(std::move(other)) {} - explicit VariantBase(ObjectPtr ptr) : ObjectRef(ptr) {} + explicit VariantBase(UnsafeInit tag) : ObjectRef(tag) {} explicit VariantBase(Any other) : ObjectRef(details::AnyUnsafe::MoveFromAnyAfterCheck(std::move(other))) {} diff --git a/ffi/include/tvm/ffi/extra/module.h b/ffi/include/tvm/ffi/extra/module.h index 89e0c287a3fe..a1dc91eebc08 100644 --- a/ffi/include/tvm/ffi/extra/module.h +++ b/ffi/include/tvm/ffi/extra/module.h @@ -36,6 +36,7 @@ class Module; /*! * \brief A module that can dynamically load ffi::Functions or exportable source code. + * \sa Module */ class TVM_FFI_EXTRA_CXX_API ModuleObj : public Object { public: @@ -168,6 +169,16 @@ class TVM_FFI_EXTRA_CXX_API ModuleObj : public Object { /*! * \brief Reference to module object. + * + * When invoking a function on a ModuleObj, such as GetFunction, + * use operator-> to get the ModuleObj pointer and invoke the member functions. + * + * \code + * ffi::Module mod = ffi::Module::LoadFromFile("path/to/module.so"); + * ffi::Function func = mod->GetFunction(name); + * \endcode + * + * \sa ModuleObj which contains most of the function implementations. */ class Module : public ObjectRef { public: @@ -202,7 +213,11 @@ class Module : public ObjectRef { */ kCompilationExportable = 0b100 }; - + /*! + * \brief Constructor from ObjectPtr. + * \param ptr The object pointer. + */ + explicit Module(ObjectPtr ptr) : ObjectRef(ptr) { TVM_FFI_ICHECK(ptr != nullptr); } /*! * \brief Load a module from file. * \param file_name The name of the host function module. diff --git a/ffi/include/tvm/ffi/function.h b/ffi/include/tvm/ffi/function.h index 884e46fa44cd..d27cfc0b6155 100644 --- a/ffi/include/tvm/ffi/function.h +++ b/ffi/include/tvm/ffi/function.h @@ -403,7 +403,7 @@ class Function : public ObjectRef { TVM_FFI_CHECK_SAFE_CALL(TVMFFIFunctionGetGlobal(&name_arr, &handle)); if (handle != nullptr) { return Function( - details::ObjectUnsafe::ObjectPtrFromOwned(static_cast(handle))); + details::ObjectUnsafe::ObjectPtrFromOwned(static_cast(handle))); } else { return std::nullopt; } diff --git a/ffi/include/tvm/ffi/function_details.h b/ffi/include/tvm/ffi/function_details.h index d029c19dd107..20ca44cbcb72 100644 --- a/ffi/include/tvm/ffi/function_details.h +++ b/ffi/include/tvm/ffi/function_details.h @@ -193,7 +193,7 @@ TVM_FFI_INLINE static Error MoveFromSafeCallRaised() { TVMFFIObjectHandle handle; TVMFFIErrorMoveFromRaised(&handle); // handle is owned by caller - return Error( + return details::ObjectUnsafe::ObjectRefFromObjectPtr( details::ObjectUnsafe::ObjectPtrFromOwned(static_cast(handle))); } diff --git a/ffi/include/tvm/ffi/object.h b/ffi/include/tvm/ffi/object.h index c1ab9d16d919..478bb27a8f20 100644 --- a/ffi/include/tvm/ffi/object.h +++ b/ffi/include/tvm/ffi/object.h @@ -44,6 +44,24 @@ using TypeIndex = TVMFFITypeIndex; */ using TypeInfo = TVMFFITypeInfo; +/*! + * \brief Helper tag to explicitly request unsafe initialization. + * + * Constructing an ObjectRefType with UnsafeInit{} will set the data_ member to nullptr. + * + * When initializing Object fields, ObjectRef fields can be set to UnsafeInit. + * This enables the "construct with UnsafeInit then set all fields" pattern + * when the object does not have a default constructor. + * + * Used for initialization in controlled scenarios where such unsafe + * initialization is known to be safe. + * + * Each ObjectRefType should have a constructor that takes an UnsafeInit tag. + * + * \note As the name suggests, do not use it in normal code paths. + */ +struct UnsafeInit {}; + /*! * \brief Known type keys for pre-defined types. */ @@ -702,6 +720,8 @@ class ObjectRef { ObjectRef& operator=(ObjectRef&& other) = default; /*! \brief Constructor from existing object ptr */ explicit ObjectRef(ObjectPtr data) : data_(data) {} + /*! \brief Constructor from UnsafeInit */ + explicit ObjectRef(UnsafeInit) : data_(nullptr) {} /*! * \brief Comparator * \param other Another object ref. @@ -774,7 +794,9 @@ class ObjectRef { TVM_FFI_INLINE std::optional as() const { if (data_ != nullptr) { if (data_->IsInstance()) { - return ObjectRefType(data_); + ObjectRefType ref(UnsafeInit{}); + ref.data_ = data_; + return ref; } else { return std::nullopt; } @@ -782,6 +804,7 @@ class ObjectRef { return std::nullopt; } } + /*! * \brief Get the type index of the ObjectRef * \return The type index of the ObjectRef @@ -914,7 +937,8 @@ struct ObjectPtrEqual { */ #define TVM_FFI_DEFINE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName) \ TypeName() = default; \ - explicit TypeName(::tvm::ffi::ObjectPtr<::tvm::ffi::Object> n) : ParentType(n) {} \ + explicit TypeName(::tvm::ffi::ObjectPtr n) : ParentType(n) {} \ + explicit TypeName(::tvm::ffi::UnsafeInit tag) : ParentType(tag) {} \ TVM_FFI_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName) \ const ObjectName* operator->() const { return static_cast(data_.get()); } \ const ObjectName* get() const { return operator->(); } \ @@ -928,7 +952,7 @@ struct ObjectPtrEqual { * \param ObjectName The type name of the object. */ #define TVM_FFI_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName) \ - explicit TypeName(::tvm::ffi::ObjectPtr<::tvm::ffi::Object> n) : ParentType(n) {} \ + explicit TypeName(::tvm::ffi::UnsafeInit tag) : ParentType(tag) {} \ TVM_FFI_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName) \ const ObjectName* operator->() const { return static_cast(data_.get()); } \ const ObjectName* get() const { return operator->(); } \ @@ -943,11 +967,12 @@ struct ObjectPtrEqual { * \note We recommend making objects immutable when possible. * This macro is only reserved for objects that stores runtime states. */ -#define TVM_FFI_DEFINE_MUTABLE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName) \ - TypeName() = default; \ - TVM_FFI_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName) \ - explicit TypeName(::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n) : ParentType(n) {} \ - ObjectName* operator->() const { return static_cast(data_.get()); } \ +#define TVM_FFI_DEFINE_MUTABLE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName) \ + TypeName() = default; \ + explicit TypeName(::tvm::ffi::UnsafeInit tag) : ParentType(tag) {} \ + TVM_FFI_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName) \ + explicit TypeName(::tvm::ffi::ObjectPtr n) : ParentType(n) {} \ + ObjectName* operator->() const { return static_cast(data_.get()); } \ using ContainerType = ObjectName /*! @@ -958,7 +983,7 @@ struct ObjectPtrEqual { * \param ObjectName The type name of the object. */ #define TVM_FFI_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName) \ - explicit TypeName(::tvm::ffi::ObjectPtr<::tvm::ffi::Object> n) : ParentType(n) {} \ + explicit TypeName(::tvm::ffi::UnsafeInit tag) : ParentType(tag) {} \ TVM_FFI_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName) \ ObjectName* operator->() const { return static_cast(data_.get()); } \ ObjectName* get() const { return operator->(); } \ @@ -1021,6 +1046,20 @@ struct ObjectUnsafe { reinterpret_cast(&(static_cast(nullptr)->header_))); } + template + TVM_FFI_INLINE static T ObjectRefFromObjectPtr(const ObjectPtr& ptr) { + T ref(UnsafeInit{}); + ref.data_ = ptr; + return ref; + } + + template + TVM_FFI_INLINE static T ObjectRefFromObjectPtr(ObjectPtr&& ptr) { + T ref(UnsafeInit{}); + ref.data_ = std::move(ptr); + return ref; + } + template TVM_FFI_INLINE static ObjectPtr ObjectPtrFromObjectRef(const ObjectRef& ref) { if constexpr (std::is_same_v) { @@ -1035,7 +1074,10 @@ struct ObjectUnsafe { if constexpr (std::is_same_v) { return std::move(ref.data_); } else { - return tvm::ffi::ObjectPtr(std::move(ref.data_.data_)); + ObjectPtr result; + result.data_ = std::move(ref.data_.data_); + ref.data_.data_ = nullptr; + return result; } } diff --git a/ffi/include/tvm/ffi/optional.h b/ffi/include/tvm/ffi/optional.h index f93a0f0d555f..f370a178502e 100644 --- a/ffi/include/tvm/ffi/optional.h +++ b/ffi/include/tvm/ffi/optional.h @@ -262,7 +262,7 @@ class Optional>> : public Object Optional() = default; Optional(const Optional& other) : ObjectRef(other.data_) {} Optional(Optional&& other) : ObjectRef(std::move(other.data_)) {} - explicit Optional(ObjectPtr ptr) : ObjectRef(ptr) {} + explicit Optional(ffi::UnsafeInit tag) : ObjectRef(tag) {} // nullopt hanlding Optional(std::nullopt_t) {} // NOLINT(*) @@ -300,19 +300,20 @@ class Optional>> : public Object if (data_ == nullptr) { TVM_FFI_THROW(RuntimeError) << "Back optional access"; } - return T(data_); + return details::ObjectUnsafe::ObjectRefFromObjectPtr(data_); } TVM_FFI_INLINE T value() && { if (data_ == nullptr) { TVM_FFI_THROW(RuntimeError) << "Back optional access"; } - return T(std::move(data_)); + return details::ObjectUnsafe::ObjectRefFromObjectPtr(std::move(data_)); } template > TVM_FFI_INLINE T value_or(U&& default_value) const { - return data_ != nullptr ? T(data_) : T(std::forward(default_value)); + return data_ != nullptr ? details::ObjectUnsafe::ObjectRefFromObjectPtr(data_) + : T(std::forward(default_value)); } TVM_FFI_INLINE explicit operator bool() const { return data_ != nullptr; } @@ -324,14 +325,18 @@ class Optional>> : public Object * \return the const reference to the stored value. * \note only use this function after checking has_value() */ - TVM_FFI_INLINE T operator*() const& noexcept { return T(data_); } + TVM_FFI_INLINE T operator*() const& noexcept { + return details::ObjectUnsafe::ObjectRefFromObjectPtr(data_); + } /*! * \brief Direct access to the value. * \return the const reference to the stored value. * \note only use this function after checking has_value() */ - TVM_FFI_INLINE T operator*() && noexcept { return T(std::move(data_)); } + TVM_FFI_INLINE T operator*() && noexcept { + return details::ObjectUnsafe::ObjectRefFromObjectPtr(std::move(data_)); + } TVM_FFI_INLINE bool operator==(std::nullptr_t) const noexcept { return !has_value(); } TVM_FFI_INLINE bool operator!=(std::nullptr_t) const noexcept { return has_value(); } diff --git a/ffi/include/tvm/ffi/reflection/access_path.h b/ffi/include/tvm/ffi/reflection/access_path.h index c614d4ca28d8..e7aed0a8fcbf 100644 --- a/ffi/include/tvm/ffi/reflection/access_path.h +++ b/ffi/include/tvm/ffi/reflection/access_path.h @@ -360,6 +360,10 @@ class AccessPath : public ObjectRef { /// \cond Doxygen_Suppress TVM_FFI_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(AccessPath, ObjectRef, AccessPathObj); /// \endcond + + private: + friend class AccessPathObj; + explicit AccessPath(ObjectPtr ptr) : ObjectRef(ptr) {} }; /*! diff --git a/ffi/include/tvm/ffi/reflection/registry.h b/ffi/include/tvm/ffi/reflection/registry.h index ba723fa394d7..6a1a9b55d2b0 100644 --- a/ffi/include/tvm/ffi/reflection/registry.h +++ b/ffi/include/tvm/ffi/reflection/registry.h @@ -148,6 +148,14 @@ class ReflectionDefBase { TVM_FFI_SAFE_CALL_END(); } + template + static int ObjectCreatorUnsafeInit(TVMFFIObjectHandle* result) { + TVM_FFI_SAFE_CALL_BEGIN(); + ObjectPtr obj = make_object(UnsafeInit{}); + *result = details::ObjectUnsafe::MoveObjectPtrToTVMFFIObjectPtr(std::move(obj)); + TVM_FFI_SAFE_CALL_END(); + } + template TVM_FFI_INLINE static void ApplyFieldInfoTrait(TVMFFIFieldInfo* info, const T& value) { if constexpr (std::is_base_of_v>) { @@ -413,6 +421,8 @@ class ObjectDef : public ReflectionDefBase { info.doc = TVMFFIByteArray{nullptr, 0}; if constexpr (std::is_default_constructible_v) { info.creator = ObjectCreatorDefault; + } else if constexpr (std::is_constructible_v) { + info.creator = ObjectCreatorUnsafeInit; } // apply extra info traits ((ApplyExtraInfoTrait(&info, std::forward(extra_args)), ...)); diff --git a/ffi/include/tvm/ffi/rvalue_ref.h b/ffi/include/tvm/ffi/rvalue_ref.h index 7c89038cc24e..ebbec582e62a 100644 --- a/ffi/include/tvm/ffi/rvalue_ref.h +++ b/ffi/include/tvm/ffi/rvalue_ref.h @@ -71,15 +71,17 @@ namespace ffi { template >> class RValueRef { public: + /*! \brief the container type of the rvalue ref */ + using ContainerType = typename TObjRef::ContainerType; /*! \brief only allow move constructor from rvalue of T */ explicit RValueRef(TObjRef&& data) - : data_(details::ObjectUnsafe::ObjectPtrFromObjectRef(std::move(data))) {} + : data_(details::ObjectUnsafe::ObjectPtrFromObjectRef(std::move(data))) {} /*! \brief return the data as rvalue */ TObjRef operator*() && { return TObjRef(std::move(data_)); } private: - mutable ObjectPtr data_; + mutable ObjectPtr data_; template friend struct TypeTraits; @@ -125,7 +127,8 @@ struct TypeTraits> : public TypeTraitsBase { tmp_any.v_obj = reinterpret_cast(rvalue_ref->get()); // fast path, storage type matches, direct move the rvalue ref if (TypeTraits::CheckAnyStrict(&tmp_any)) { - return RValueRef(TObjRef(std::move(*rvalue_ref))); + return RValueRef( + details::ObjectUnsafe::ObjectRefFromObjectPtr(std::move(*rvalue_ref))); } if (std::optional opt = TypeTraits::TryCastFromAnyView(&tmp_any)) { // object type does not match up, we need to try to convert the object diff --git a/ffi/include/tvm/ffi/type_traits.h b/ffi/include/tvm/ffi/type_traits.h index 1812448ecc09..0f1971945a4b 100644 --- a/ffi/include/tvm/ffi/type_traits.h +++ b/ffi/include/tvm/ffi/type_traits.h @@ -551,34 +551,37 @@ struct ObjectRefTypeTraitsBase : public TypeTraitsBase { TVM_FFI_INLINE static TObjRef CopyFromAnyViewAfterCheck(const TVMFFIAny* src) { if constexpr (TObjRef::_type_is_nullable) { if (src->type_index == TypeIndex::kTVMFFINone) { - return TObjRef(ObjectPtr(nullptr)); + return details::ObjectUnsafe::ObjectRefFromObjectPtr(nullptr); } } - return TObjRef(details::ObjectUnsafe::ObjectPtrFromUnowned(src->v_obj)); + return details::ObjectUnsafe::ObjectRefFromObjectPtr( + details::ObjectUnsafe::ObjectPtrFromUnowned(src->v_obj)); } TVM_FFI_INLINE static TObjRef MoveFromAnyAfterCheck(TVMFFIAny* src) { if constexpr (TObjRef::_type_is_nullable) { if (src->type_index == TypeIndex::kTVMFFINone) { - return TObjRef(ObjectPtr(nullptr)); + return details::ObjectUnsafe::ObjectRefFromObjectPtr(nullptr); } } // move out the object pointer - ObjectPtr obj_ptr = details::ObjectUnsafe::ObjectPtrFromOwned(src->v_obj); + ObjectPtr obj_ptr = + details::ObjectUnsafe::ObjectPtrFromOwned(src->v_obj); // reset the src to nullptr TypeTraits::MoveToAny(nullptr, src); - return TObjRef(std::move(obj_ptr)); + return details::ObjectUnsafe::ObjectRefFromObjectPtr(std::move(obj_ptr)); } TVM_FFI_INLINE static std::optional TryCastFromAnyView(const TVMFFIAny* src) { if constexpr (TObjRef::_type_is_nullable) { if (src->type_index == TypeIndex::kTVMFFINone) { - return TObjRef(ObjectPtr(nullptr)); + return details::ObjectUnsafe::ObjectRefFromObjectPtr(nullptr); } } if (src->type_index >= TypeIndex::kTVMFFIStaticObjectBegin) { if (details::IsObjectInstance(src->type_index)) { - return TObjRef(details::ObjectUnsafe::ObjectPtrFromUnowned(src->v_obj)); + return details::ObjectUnsafe::ObjectRefFromObjectPtr( + details::ObjectUnsafe::ObjectPtrFromUnowned(src->v_obj)); } } return std::nullopt; diff --git a/ffi/src/ffi/tensor.cc b/ffi/src/ffi/tensor.cc index 7b44e4586b4b..c166c296c8a4 100644 --- a/ffi/src/ffi/tensor.cc +++ b/ffi/src/ffi/tensor.cc @@ -40,7 +40,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ TVM_FFI_THROW(ValueError) << "Expect shape to take list of int arguments"; } } - *ret = Shape(shape); + *ret = details::ObjectUnsafe::ObjectRefFromObjectPtr(shape); }); }); diff --git a/ffi/tests/cpp/test_object.cc b/ffi/tests/cpp/test_object.cc index 1d7de990f01a..ec5c54c4d77a 100644 --- a/ffi/tests/cpp/test_object.cc +++ b/ffi/tests/cpp/test_object.cc @@ -97,6 +97,14 @@ TEST(ObjectRef, as) { EXPECT_EQ(b.as()->value, 20); } +TEST(ObjectRef, UnsafeInit) { + ObjectRef a(UnsafeInit{}); + EXPECT_TRUE(a.get() == nullptr); + + TInt b(UnsafeInit{}); + EXPECT_TRUE(b.get() == nullptr); +} + TEST(Object, CAPIAccessor) { ObjectRef a = TInt(10); TVMFFIObjectHandle obj = details::ObjectUnsafe::RawObjectPtrFromObjectRef(a); diff --git a/ffi/tests/cpp/testing_object.h b/ffi/tests/cpp/testing_object.h index fe3ba1b013c0..1f6e67822641 100644 --- a/ffi/tests/cpp/testing_object.h +++ b/ffi/tests/cpp/testing_object.h @@ -59,8 +59,8 @@ class TIntObj : public TNumberObj { public: int64_t value; - TIntObj() = default; TIntObj(int64_t value) : value(value) {} + explicit TIntObj(UnsafeInit) {} int64_t GetValue() const { return value; } @@ -165,9 +165,9 @@ class TVarObj : public Object { public: std::string name; - // need default constructor for json serialization - TVarObj() = default; TVarObj(std::string name) : name(name) {} + // need unsafe init constructor for json serialization + explicit TVarObj(UnsafeInit) {} static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -193,8 +193,8 @@ class TFuncObj : public Object { Array body; Optional comment; - // need default constructor for json serialization - TFuncObj() = default; + // need unsafe init constructor or default constructor for json serialization + explicit TFuncObj(UnsafeInit) {} TFuncObj(Array params, Array body, Optional comment) : params(params), body(body), comment(comment) {} diff --git a/include/tvm/ir/attrs.h b/include/tvm/ir/attrs.h index 55576549169c..5c02db36f72e 100644 --- a/include/tvm/ir/attrs.h +++ b/include/tvm/ir/attrs.h @@ -54,7 +54,7 @@ namespace tvm { template inline TObjectRef NullValue() { static_assert(TObjectRef::_type_is_nullable, "Can only get NullValue for nullable types"); - return TObjectRef(ObjectPtr(nullptr)); + return TObjectRef(ObjectPtr(nullptr)); } template <> @@ -165,6 +165,10 @@ class DictAttrsNode : public BaseAttrsNode { */ class DictAttrs : public Attrs { public: + /*! + * \brief constructor with UnsafeInit + */ + explicit DictAttrs(ffi::UnsafeInit tag) : Attrs(tag) {} /*! * \brief Consruct a Attrs backed by DictAttrsNode. * \param dict The attributes. diff --git a/include/tvm/ir/env_func.h b/include/tvm/ir/env_func.h index e43575d486eb..e42cce527900 100644 --- a/include/tvm/ir/env_func.h +++ b/include/tvm/ir/env_func.h @@ -71,6 +71,10 @@ class EnvFunc : public ObjectRef { public: EnvFunc() {} explicit EnvFunc(ObjectPtr n) : ObjectRef(n) {} + /*! + * \brief constructor with UnsafeInit + */ + explicit EnvFunc(ffi::UnsafeInit tag) : ObjectRef(tag) {} /*! \return The internal global function pointer */ const EnvFuncNode* operator->() const { return static_cast(get()); } /*! @@ -117,6 +121,10 @@ class TypedEnvFunc : public ObjectRef { using TSelf = TypedEnvFunc; TypedEnvFunc() {} explicit TypedEnvFunc(ObjectPtr n) : ObjectRef(n) {} + /*! + * \brief constructor with UnsafeInit + */ + explicit TypedEnvFunc(ffi::UnsafeInit tag) : ObjectRef(tag) {} /*! * \brief Assign global function to a TypedEnvFunc * \param other Another global function. diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h index 65954b83ac9d..d7e4e0f0d2ef 100644 --- a/include/tvm/ir/expr.h +++ b/include/tvm/ir/expr.h @@ -613,7 +613,11 @@ class Integer : public IntImm { /*! * \brief constructor from node. */ - explicit Integer(ObjectPtr node) : IntImm(node) {} + explicit Integer(ObjectPtr node) : IntImm(node) {} + /*! + * \brief constructor with UnsafeInit + */ + explicit Integer(ffi::UnsafeInit tag) : IntImm(tag) {} /*! * \brief Construct integer from int value. */ diff --git a/include/tvm/ir/module.h b/include/tvm/ir/module.h index 5da00fb0b377..3deef6fed1f1 100644 --- a/include/tvm/ir/module.h +++ b/include/tvm/ir/module.h @@ -273,7 +273,11 @@ class IRModule : public ObjectRef { * \brief constructor * \param n The object pointer. */ - explicit IRModule(ObjectPtr n) : ObjectRef(n) {} + explicit IRModule(ObjectPtr n) : ObjectRef(n) {} + /*! + * \brief constructor with UnsafeInit + */ + explicit IRModule(ffi::UnsafeInit tag) : ObjectRef(tag) {} /*! \return mutable pointers to the node. */ IRModuleNode* operator->() const { auto* ptr = get_mutable(); diff --git a/include/tvm/ir/transform.h b/include/tvm/ir/transform.h index e501ace15997..e283234cb071 100644 --- a/include/tvm/ir/transform.h +++ b/include/tvm/ir/transform.h @@ -156,7 +156,14 @@ class PassContextNode : public Object { class PassContext : public ObjectRef { public: PassContext() {} - explicit PassContext(ObjectPtr n) : ObjectRef(n) {} + /*! + * \brief constructor with UnsafeInit + */ + explicit PassContext(ffi::UnsafeInit tag) : ObjectRef(tag) {} + /*! + * \brief constructor with ObjectPtr + */ + explicit PassContext(ObjectPtr n) : ObjectRef(n) {} /*! * \brief const accessor. * \return const access pointer. @@ -512,7 +519,7 @@ class Sequential : public Pass { TVM_DLL Sequential(ffi::Array passes, ffi::String name = "sequential"); Sequential() = default; - explicit Sequential(ObjectPtr n) : Pass(n) {} + explicit Sequential(ObjectPtr n) : Pass(n) {} const SequentialNode* operator->() const; using ContainerType = SequentialNode; diff --git a/include/tvm/meta_schedule/builder.h b/include/tvm/meta_schedule/builder.h index 6a6df2950271..0a527ad42585 100644 --- a/include/tvm/meta_schedule/builder.h +++ b/include/tvm/meta_schedule/builder.h @@ -136,6 +136,13 @@ class BuilderNode : public runtime::Object { */ class Builder : public runtime::ObjectRef { public: + /*! + * \brief Constructor from ObjectPtr. + * \param data The object pointer. + */ + explicit Builder(ObjectPtr data) : ObjectRef(data) { + TVM_FFI_ICHECK(data != nullptr); + } /*! * \brief Create a builder with customized build method on the python-side. * \param f_build The packed function to the `Build` function.. diff --git a/include/tvm/meta_schedule/database.h b/include/tvm/meta_schedule/database.h index fbb09d7852c6..07686077311a 100644 --- a/include/tvm/meta_schedule/database.h +++ b/include/tvm/meta_schedule/database.h @@ -71,6 +71,7 @@ class WorkloadNode : public runtime::Object { class Workload : public runtime::ObjectRef { public: using THashCode = WorkloadNode::THashCode; + explicit Workload(ObjectPtr data) : ObjectRef(data) {} /*! * \brief Constructor of Workload. * \param mod The workload's IRModule. @@ -117,7 +118,7 @@ class TuningRecordNode : public runtime::Object { /*! \brief The trace tuned. */ tir::Trace trace; /*! \brief The workload. */ - Workload workload{nullptr}; + Workload workload{ffi::UnsafeInit()}; /*! \brief The profiling result in seconds. */ ffi::Optional> run_secs; /*! \brief The target for tuning. */ @@ -466,6 +467,13 @@ class PyDatabaseNode : public DatabaseNode { */ class Database : public runtime::ObjectRef { public: + /*! + * \brief Constructor from ObjectPtr. + * \param data The object pointer. + */ + explicit Database(ObjectPtr data) : ObjectRef(data) { + TVM_FFI_ICHECK(data != nullptr); + } /*! * \brief An in-memory database. * \param mod_eq_name A string to specify the module equality testing and hashing method. diff --git a/include/tvm/meta_schedule/runner.h b/include/tvm/meta_schedule/runner.h index 2d42b5e590d4..f2753964ec63 100644 --- a/include/tvm/meta_schedule/runner.h +++ b/include/tvm/meta_schedule/runner.h @@ -207,7 +207,11 @@ class RunnerNode : public runtime::Object { class Runner : public runtime::ObjectRef { public: using FRun = RunnerNode::FRun; - + /*! + * \brief Constructor from ObjectPtr. + * \param data The object pointer. + */ + explicit Runner(ObjectPtr data) : ObjectRef(data) { TVM_FFI_ICHECK(data != nullptr); } /*! * \brief Create a runner with customized build method on the python-side. * \param f_run The packed function to run the built artifacts and get runner futures. diff --git a/include/tvm/meta_schedule/space_generator.h b/include/tvm/meta_schedule/space_generator.h index f013934e2342..a2bf7a394932 100644 --- a/include/tvm/meta_schedule/space_generator.h +++ b/include/tvm/meta_schedule/space_generator.h @@ -123,6 +123,13 @@ class SpaceGeneratorNode : public runtime::Object { */ class SpaceGenerator : public runtime::ObjectRef { public: + /*! + * \brief Constructor from ObjectPtr. + * \param data The object pointer. + */ + explicit SpaceGenerator(ObjectPtr data) : ObjectRef(data) { + TVM_FFI_ICHECK(data != nullptr); + } /*! * \brief The function type of `InitializeWithTuneContext` method. * \param context The tuning context for initialization. diff --git a/include/tvm/meta_schedule/task_scheduler.h b/include/tvm/meta_schedule/task_scheduler.h index 0c88cb12c8cc..a6a53becad00 100644 --- a/include/tvm/meta_schedule/task_scheduler.h +++ b/include/tvm/meta_schedule/task_scheduler.h @@ -40,7 +40,7 @@ namespace meta_schedule { class TaskRecordNode : public runtime::Object { public: /*! \brief The tune context of the task. */ - TuneContext ctx{nullptr}; + TuneContext ctx{ffi::UnsafeInit()}; /*! \brief The weight of the task */ double task_weight{1.0}; /*! \brief The FLOP count of the task */ @@ -261,6 +261,9 @@ class PyTaskSchedulerNode : public TaskSchedulerNode { */ class TaskScheduler : public runtime::ObjectRef { public: + explicit TaskScheduler(ObjectPtr data) : runtime::ObjectRef(data) { + TVM_FFI_ICHECK(data != nullptr); + } /*! * \brief Create a task scheduler that fetches tasks in a round-robin fashion. * \param logger The tuning task's logging function. diff --git a/include/tvm/meta_schedule/tune_context.h b/include/tvm/meta_schedule/tune_context.h index cd9b8f1b5ad2..50bdb2586fc6 100644 --- a/include/tvm/meta_schedule/tune_context.h +++ b/include/tvm/meta_schedule/tune_context.h @@ -98,6 +98,13 @@ class TuneContextNode : public runtime::Object { class TuneContext : public runtime::ObjectRef { public: using TRandState = support::LinearCongruentialEngine::TRandState; + /*! + * \brief Constructor from ObjectPtr. + * \param data The object pointer. + */ + explicit TuneContext(ObjectPtr data) : ObjectRef(data) { + TVM_FFI_ICHECK(data != nullptr); + } /*! * \brief Constructor. * \param mod The workload to be tuned. diff --git a/include/tvm/node/cast.h b/include/tvm/node/cast.h index 4ed5f4178c8b..32d4be721656 100644 --- a/include/tvm/node/cast.h +++ b/include/tvm/node/cast.h @@ -45,18 +45,19 @@ namespace tvm { template >> inline SubRef Downcast(BaseRef ref) { + using ContainerType = typename SubRef::ContainerType; if (ref.defined()) { - if (!ref->template IsInstance()) { + if (!ref->template IsInstance()) { TVM_FFI_THROW(TypeError) << "Downcast from " << ref->GetTypeKey() << " to " << SubRef::ContainerType::_type_key << " failed."; } - return SubRef(ffi::details::ObjectUnsafe::ObjectPtrFromObjectRef(std::move(ref))); + return ffi::details::ObjectUnsafe::ObjectRefFromObjectPtr( + ffi::details::ObjectUnsafe::ObjectPtrFromObjectRef(std::move(ref))); } else { if constexpr (ffi::is_optional_type_v || SubRef::_type_is_nullable) { - return SubRef(ffi::ObjectPtr(nullptr)); + return ffi::details::ObjectUnsafe::ObjectRefFromObjectPtr(nullptr); } - TVM_FFI_THROW(TypeError) << "Downcast from undefined(nullptr) to `" - << SubRef::ContainerType::_type_key + TVM_FFI_THROW(TypeError) << "Downcast from undefined(nullptr) to `" << ContainerType::_type_key << "` is not allowed. Use Downcast> instead."; TVM_FFI_UNREACHABLE(); } diff --git a/include/tvm/relax/dataflow_pattern.h b/include/tvm/relax/dataflow_pattern.h index 4a7fd73c6ac0..7c4ee4e43e57 100644 --- a/include/tvm/relax/dataflow_pattern.h +++ b/include/tvm/relax/dataflow_pattern.h @@ -280,6 +280,7 @@ class PatternContextNode : public Object { */ class PatternContext : public ObjectRef { public: + explicit PatternContext(ffi::UnsafeInit tag) : ObjectRef(tag) {} TVM_DLL explicit PatternContext(ObjectPtr n) : ObjectRef(n) {} TVM_DLL explicit PatternContext(bool incremental = false); @@ -778,6 +779,10 @@ class WildcardPatternNode : public DFPatternNode { class WildcardPattern : public DFPattern { public: WildcardPattern(); + explicit WildcardPattern(ObjectPtr data) : DFPattern(ffi::UnsafeInit{}) { + TVM_FFI_ICHECK(data != nullptr); + data_ = std::move(data); + } // Declaring WildcardPattern declared as non-nullable avoids the // default zero-parameter constructor for ObjectRef with `data_ = diff --git a/include/tvm/relax/expr.h b/include/tvm/relax/expr.h index e0e2f4770fe9..80fe1e671091 100644 --- a/include/tvm/relax/expr.h +++ b/include/tvm/relax/expr.h @@ -607,7 +607,8 @@ class Binding : public ObjectRef { Binding() = default; public: - explicit Binding(ObjectPtr n) : ObjectRef(n) {} + explicit Binding(ObjectPtr n) : ObjectRef(n) {} + explicit Binding(ffi::UnsafeInit tag) : ObjectRef(tag) {} TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(Binding); const BindingNode* operator->() const { return static_cast(data_.get()); } const BindingNode* get() const { return operator->(); } diff --git a/include/tvm/relax/struct_info.h b/include/tvm/relax/struct_info.h index 8a97658330df..059292806de4 100644 --- a/include/tvm/relax/struct_info.h +++ b/include/tvm/relax/struct_info.h @@ -27,6 +27,8 @@ #include #include +#include + namespace tvm { namespace relax { @@ -317,6 +319,10 @@ class FuncStructInfoNode : public StructInfoNode { */ class FuncStructInfo : public StructInfo { public: + explicit FuncStructInfo(ObjectPtr data) : StructInfo(ffi::UnsafeInit{}) { + TVM_FFI_ICHECK(data != nullptr); + data_ = std::move(data); + } /*! * \brief Constructor from parameter struct info and return value struct info. * \param params The struct info of function parameters. diff --git a/include/tvm/runtime/disco/session.h b/include/tvm/runtime/disco/session.h index 1506d2548f1f..671e4bbd67f7 100644 --- a/include/tvm/runtime/disco/session.h +++ b/include/tvm/runtime/disco/session.h @@ -170,6 +170,7 @@ class DRefObj : public Object { */ class DRef : public ObjectRef { public: + explicit DRef(ObjectPtr data) : ObjectRef(data) { TVM_FFI_ICHECK(data != nullptr); } TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(DRef, ObjectRef, DRefObj); }; diff --git a/include/tvm/runtime/object.h b/include/tvm/runtime/object.h index e04a800400f1..cf5d93eae64e 100644 --- a/include/tvm/runtime/object.h +++ b/include/tvm/runtime/object.h @@ -128,7 +128,7 @@ static_assert(static_cast(TypeIndex::kCustomStaticIndex) >= */ #define TVM_DEFINE_OBJECT_REF_METHODS_WITHOUT_DEFAULT_CONSTRUCTOR(TypeName, ParentType, \ ObjectName) \ - explicit TypeName(::tvm::ffi::ObjectPtr<::tvm::ffi::Object> n) : ParentType(n) {} \ + explicit TypeName(::tvm::ffi::ObjectPtr n) : ParentType(n) {} \ TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName); \ const ObjectName* operator->() const { return static_cast(data_.get()); } \ const ObjectName* get() const { return operator->(); } \ diff --git a/include/tvm/runtime/tensor.h b/include/tvm/runtime/tensor.h index 71f8d27be008..97af218a1809 100644 --- a/include/tvm/runtime/tensor.h +++ b/include/tvm/runtime/tensor.h @@ -58,7 +58,8 @@ class Tensor : public tvm::ffi::Tensor { * \brief constructor. * \param data ObjectPtr to the data container. */ - explicit Tensor(ObjectPtr data) : tvm::ffi::Tensor(data) {} + explicit Tensor(ObjectPtr data) : tvm::ffi::Tensor(data) {} + explicit Tensor(ffi::UnsafeInit tag) : tvm::ffi::Tensor(tag) {} Tensor(ffi::Tensor&& other) : tvm::ffi::Tensor(std::move(other)) {} // NOLINT(*) Tensor(const ffi::Tensor& other) : tvm::ffi::Tensor(other) {} // NOLINT(*) diff --git a/include/tvm/script/ir_builder/base.h b/include/tvm/script/ir_builder/base.h index b2586e938719..75e6fd8061ea 100644 --- a/include/tvm/script/ir_builder/base.h +++ b/include/tvm/script/ir_builder/base.h @@ -107,6 +107,7 @@ class IRBuilderFrame : public runtime::ObjectRef { protected: /*! \brief Disallow direct construction of this object. */ IRBuilderFrame() = default; + explicit IRBuilderFrame(ObjectPtr data) : ObjectRef(data) {} public: /*! diff --git a/include/tvm/script/ir_builder/ir/frame.h b/include/tvm/script/ir_builder/ir/frame.h index e9f98d4a8ea6..767986fdf77f 100644 --- a/include/tvm/script/ir_builder/ir/frame.h +++ b/include/tvm/script/ir_builder/ir/frame.h @@ -75,6 +75,9 @@ class IRModuleFrameNode : public IRBuilderFrameNode { */ class IRModuleFrame : public IRBuilderFrame { public: + explicit IRModuleFrame(ObjectPtr data) : IRBuilderFrame(data) { + TVM_FFI_ICHECK(data != nullptr); + } TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(IRModuleFrame, IRBuilderFrame, IRModuleFrameNode); }; diff --git a/include/tvm/script/ir_builder/relax/frame.h b/include/tvm/script/ir_builder/relax/frame.h index 053f84285f6e..7ea8c439bf37 100644 --- a/include/tvm/script/ir_builder/relax/frame.h +++ b/include/tvm/script/ir_builder/relax/frame.h @@ -26,6 +26,8 @@ #include #include +#include + namespace tvm { namespace script { namespace ir_builder { @@ -45,6 +47,10 @@ class RelaxFrameNode : public IRBuilderFrameNode { class RelaxFrame : public IRBuilderFrame { public: + explicit RelaxFrame(ObjectPtr data) : IRBuilderFrame(ffi::UnsafeInit{}) { + TVM_FFI_ICHECK(data != nullptr); + data_ = std::move(data); + } TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(RelaxFrame, IRBuilderFrame, RelaxFrameNode); protected: @@ -78,6 +84,9 @@ class SeqExprFrameNode : public RelaxFrameNode { class SeqExprFrame : public RelaxFrame { public: + explicit SeqExprFrame(ObjectPtr data) : RelaxFrame(data) { + TVM_FFI_ICHECK(data != nullptr); + } TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(SeqExprFrame, RelaxFrame, SeqExprFrameNode); }; @@ -134,6 +143,9 @@ class FunctionFrameNode : public SeqExprFrameNode { class FunctionFrame : public SeqExprFrame { public: + explicit FunctionFrame(ObjectPtr data) : SeqExprFrame(data) { + TVM_FFI_ICHECK(data != nullptr); + } TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(FunctionFrame, SeqExprFrame, FunctionFrameNode); }; @@ -175,6 +187,9 @@ class BlockFrameNode : public RelaxFrameNode { class BlockFrame : public RelaxFrame { public: + explicit BlockFrame(ObjectPtr data) : RelaxFrame(data) { + TVM_FFI_ICHECK(data != nullptr); + } TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(BlockFrame, RelaxFrame, BlockFrameNode); }; @@ -229,6 +244,9 @@ class IfFrameNode : public RelaxFrameNode { */ class IfFrame : public RelaxFrame { public: + explicit IfFrame(ObjectPtr data) : RelaxFrame(data) { + TVM_FFI_ICHECK(data != nullptr); + } TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(IfFrame, RelaxFrame, IfFrameNode); }; @@ -267,6 +285,9 @@ class ThenFrameNode : public SeqExprFrameNode { */ class ThenFrame : public SeqExprFrame { public: + explicit ThenFrame(ObjectPtr data) : SeqExprFrame(data) { + TVM_FFI_ICHECK(data != nullptr); + } TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(ThenFrame, SeqExprFrame, ThenFrameNode); }; @@ -305,6 +326,9 @@ class ElseFrameNode : public SeqExprFrameNode { */ class ElseFrame : public SeqExprFrame { public: + explicit ElseFrame(ObjectPtr data) : SeqExprFrame(data) { + TVM_FFI_ICHECK(data != nullptr); + } TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(ElseFrame, SeqExprFrame, ElseFrameNode); }; diff --git a/include/tvm/script/ir_builder/tir/frame.h b/include/tvm/script/ir_builder/tir/frame.h index 1c3e19959024..fa42ea9911c7 100644 --- a/include/tvm/script/ir_builder/tir/frame.h +++ b/include/tvm/script/ir_builder/tir/frame.h @@ -23,6 +23,8 @@ #include #include +#include + namespace tvm { namespace script { namespace ir_builder { @@ -58,6 +60,7 @@ class TIRFrame : public IRBuilderFrame { protected: TIRFrame() = default; + explicit TIRFrame(ObjectPtr data) : IRBuilderFrame(data) {} }; /*! @@ -115,6 +118,10 @@ class PrimFuncFrameNode : public TIRFrameNode { */ class PrimFuncFrame : public TIRFrame { public: + explicit PrimFuncFrame(ObjectPtr data) : TIRFrame(ffi::UnsafeInit{}) { + TVM_FFI_ICHECK(data != nullptr); + data_ = std::move(data); + } TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(PrimFuncFrame, TIRFrame, PrimFuncFrameNode); }; @@ -186,6 +193,10 @@ class BlockFrameNode : public TIRFrameNode { class BlockFrame : public TIRFrame { public: + explicit BlockFrame(ObjectPtr data) : TIRFrame(ffi::UnsafeInit{}) { + TVM_FFI_ICHECK(data != nullptr); + data_ = std::move(data); + } TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(BlockFrame, TIRFrame, BlockFrameNode); }; @@ -224,6 +235,10 @@ class BlockInitFrameNode : public TIRFrameNode { */ class BlockInitFrame : public TIRFrame { public: + explicit BlockInitFrame(ObjectPtr data) : TIRFrame(ffi::UnsafeInit{}) { + TVM_FFI_ICHECK(data != nullptr); + data_ = std::move(data); + } TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(BlockInitFrame, TIRFrame, BlockInitFrameNode); }; @@ -277,6 +292,10 @@ class ForFrameNode : public TIRFrameNode { */ class ForFrame : public TIRFrame { public: + explicit ForFrame(ObjectPtr data) : TIRFrame(ffi::UnsafeInit{}) { + TVM_FFI_ICHECK(data != nullptr); + data_ = std::move(data); + } TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(ForFrame, TIRFrame, ForFrameNode); }; @@ -318,6 +337,10 @@ class AssertFrameNode : public TIRFrameNode { */ class AssertFrame : public TIRFrame { public: + explicit AssertFrame(ObjectPtr data) : TIRFrame(ffi::UnsafeInit{}) { + TVM_FFI_ICHECK(data != nullptr); + data_ = std::move(data); + } TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(AssertFrame, TIRFrame, AssertFrameNode); }; @@ -358,6 +381,10 @@ class LetFrameNode : public TIRFrameNode { */ class LetFrame : public TIRFrame { public: + explicit LetFrame(ObjectPtr data) : TIRFrame(ffi::UnsafeInit{}) { + TVM_FFI_ICHECK(data != nullptr); + data_ = std::move(data); + } TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(LetFrame, TIRFrame, LetFrameNode); }; @@ -400,6 +427,10 @@ class LaunchThreadFrameNode : public TIRFrameNode { */ class LaunchThreadFrame : public TIRFrame { public: + explicit LaunchThreadFrame(ObjectPtr data) : TIRFrame(ffi::UnsafeInit{}) { + TVM_FFI_ICHECK(data != nullptr); + data_ = std::move(data); + } TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(LaunchThreadFrame, TIRFrame, LaunchThreadFrameNode); }; @@ -444,6 +475,10 @@ class RealizeFrameNode : public TIRFrameNode { */ class RealizeFrame : public TIRFrame { public: + explicit RealizeFrame(ObjectPtr data) : TIRFrame(ffi::UnsafeInit{}) { + TVM_FFI_ICHECK(data != nullptr); + data_ = std::move(data); + } TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(RealizeFrame, TIRFrame, RealizeFrameNode); }; @@ -496,6 +531,10 @@ class AllocateFrameNode : public TIRFrameNode { */ class AllocateFrame : public TIRFrame { public: + explicit AllocateFrame(ObjectPtr data) : TIRFrame(ffi::UnsafeInit{}) { + TVM_FFI_ICHECK(data != nullptr); + data_ = std::move(data); + } TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(AllocateFrame, TIRFrame, AllocateFrameNode); }; @@ -545,6 +584,11 @@ class AllocateConstFrameNode : public TIRFrameNode { */ class AllocateConstFrame : public TIRFrame { public: + explicit AllocateConstFrame(ObjectPtr data) + : TIRFrame(ffi::UnsafeInit{}) { + TVM_FFI_ICHECK(data != nullptr); + data_ = std::move(data); + } TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(AllocateConstFrame, TIRFrame, AllocateConstFrameNode); }; @@ -588,6 +632,10 @@ class AttrFrameNode : public TIRFrameNode { */ class AttrFrame : public TIRFrame { public: + explicit AttrFrame(ObjectPtr data) : TIRFrame(ffi::UnsafeInit{}) { + TVM_FFI_ICHECK(data != nullptr); + data_ = std::move(data); + } TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(AttrFrame, TIRFrame, AttrFrameNode); }; @@ -624,6 +672,10 @@ class WhileFrameNode : public TIRFrameNode { */ class WhileFrame : public TIRFrame { public: + explicit WhileFrame(ObjectPtr data) : TIRFrame(ffi::UnsafeInit{}) { + TVM_FFI_ICHECK(data != nullptr); + data_ = std::move(data); + } TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(WhileFrame, TIRFrame, WhileFrameNode); }; @@ -667,6 +719,9 @@ class IfFrameNode : public TIRFrameNode { */ class IfFrame : public TIRFrame { public: + explicit IfFrame(ObjectPtr data) : TIRFrame(data) { + TVM_FFI_ICHECK(data != nullptr); + } TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(IfFrame, TIRFrame, IfFrameNode); }; @@ -705,6 +760,9 @@ class ThenFrameNode : public TIRFrameNode { */ class ThenFrame : public TIRFrame { public: + explicit ThenFrame(ObjectPtr data) : TIRFrame(data) { + TVM_FFI_ICHECK(data != nullptr); + } TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(ThenFrame, TIRFrame, ThenFrameNode); }; @@ -743,6 +801,10 @@ class ElseFrameNode : public TIRFrameNode { */ class ElseFrame : public TIRFrame { public: + explicit ElseFrame(ObjectPtr data) : TIRFrame(data) { + TVM_FFI_ICHECK(data != nullptr); + } + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(ElseFrame, TIRFrame, ElseFrameNode); }; @@ -769,6 +831,9 @@ class DeclBufferFrameNode : public TIRFrameNode { class DeclBufferFrame : public TIRFrame { public: + explicit DeclBufferFrame(ObjectPtr data) : TIRFrame(data) { + TVM_FFI_ICHECK(data != nullptr); + } TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(DeclBufferFrame, TIRFrame, DeclBufferFrameNode); }; diff --git a/include/tvm/script/printer/doc.h b/include/tvm/script/printer/doc.h index 976e3183a16e..296df345246a 100644 --- a/include/tvm/script/printer/doc.h +++ b/include/tvm/script/printer/doc.h @@ -88,6 +88,7 @@ class DocNode : public Object { class Doc : public ObjectRef { protected: Doc() = default; + explicit Doc(ObjectPtr data) : ObjectRef(data) {} public: TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Doc, ObjectRef, DocNode); @@ -156,6 +157,8 @@ class ExprDoc : public Doc { */ ExprDoc operator[](ffi::Array indices) const; + explicit ExprDoc(ObjectPtr data) : Doc(data) { TVM_FFI_ICHECK(data != nullptr); } + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ExprDoc, Doc, ExprDocNode); }; @@ -378,7 +381,7 @@ class IdDoc : public ExprDoc { class AttrAccessDocNode : public ExprDocNode { public: /*! \brief The target expression to be accessed */ - ExprDoc value{nullptr}; + ExprDoc value{ffi::UnsafeInit()}; /*! \brief The attribute to be accessed */ ffi::String name; @@ -418,7 +421,7 @@ class AttrAccessDoc : public ExprDoc { class IndexDocNode : public ExprDocNode { public: /*! \brief The container value to be accessed */ - ExprDoc value{nullptr}; + ExprDoc value{ffi::UnsafeInit()}; /*! * \brief The indices to access * @@ -464,7 +467,7 @@ class IndexDoc : public ExprDoc { class CallDocNode : public ExprDocNode { public: /*! \brief The callee of this function call */ - ExprDoc callee{nullptr}; + ExprDoc callee{ffi::UnsafeInit()}; /*! \brief The positional arguments */ ffi::Array args; /*! \brief The keys of keyword arguments */ @@ -604,7 +607,7 @@ class LambdaDocNode : public ExprDocNode { /*! \brief The arguments of this anonymous function */ ffi::Array args; /*! \brief The body of this anonymous function */ - ExprDoc body{nullptr}; + ExprDoc body{ffi::UnsafeInit()}; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -664,7 +667,7 @@ class TupleDoc : public ExprDoc { /*! * \brief Create an empty TupleDoc */ - TupleDoc() : TupleDoc(ffi::make_object()) {} + TupleDoc() : ExprDoc(ffi::make_object()) {} /*! * \brief Constructor of TupleDoc * \param elements Elements of tuple. @@ -703,7 +706,7 @@ class ListDoc : public ExprDoc { /*! * \brief Create an empty ListDoc */ - ListDoc() : ListDoc(ffi::make_object()) {} + ListDoc() : ExprDoc(ffi::make_object()) {} /*! * \brief Constructor of ListDoc * \param elements Elements of list. @@ -751,7 +754,7 @@ class DictDoc : public ExprDoc { /*! * \brief Create an empty dictionary */ - DictDoc() : DictDoc(ffi::make_object()) {} + DictDoc() : ExprDoc(ffi::make_object()) {} /*! * \brief Constructor of DictDoc * \param keys Keys of dictionary. @@ -816,7 +819,7 @@ class SliceDoc : public Doc { class AssignDocNode : public StmtDocNode { public: /*! \brief The left hand side of the assignment */ - ExprDoc lhs{nullptr}; + ExprDoc lhs{ffi::UnsafeInit()}; /*! * \brief The right hand side of the assignment. * @@ -864,7 +867,7 @@ class AssignDoc : public StmtDoc { class IfDocNode : public StmtDocNode { public: /*! \brief The predicate of the if-then-else statement. */ - ExprDoc predicate{nullptr}; + ExprDoc predicate{ffi::UnsafeInit()}; /*! \brief The then branch of the if-then-else statement. */ ffi::Array then_branch; /*! \brief The else branch of the if-then-else statement. */ @@ -909,7 +912,7 @@ class IfDoc : public StmtDoc { class WhileDocNode : public StmtDocNode { public: /*! \brief The predicate of the while statement. */ - ExprDoc predicate{nullptr}; + ExprDoc predicate{ffi::UnsafeInit()}; /*! \brief The body of the while statement. */ ffi::Array body; @@ -953,9 +956,9 @@ class WhileDoc : public StmtDoc { class ForDocNode : public StmtDocNode { public: /*! \brief The left hand side of the assignment of iterating variable. */ - ExprDoc lhs{nullptr}; + ExprDoc lhs{ffi::UnsafeInit()}; /*! \brief The right hand side of the assignment of iterating variable. */ - ExprDoc rhs{nullptr}; + ExprDoc rhs{ffi::UnsafeInit()}; /*! \brief The body of the for statement. */ ffi::Array body; @@ -1004,7 +1007,7 @@ class ScopeDocNode : public StmtDocNode { /*! \brief The name of the scoped variable. */ ffi::Optional lhs{std::nullopt}; /*! \brief The value of the scoped variable. */ - ExprDoc rhs{nullptr}; + ExprDoc rhs{ffi::UnsafeInit()}; /*! \brief The body of the scope doc. */ ffi::Array body; @@ -1054,7 +1057,7 @@ class ScopeDoc : public StmtDoc { class ExprStmtDocNode : public StmtDocNode { public: /*! \brief The expression represented by this doc. */ - ExprDoc expr{nullptr}; + ExprDoc expr{ffi::UnsafeInit()}; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -1089,7 +1092,7 @@ class ExprStmtDoc : public StmtDoc { class AssertDocNode : public StmtDocNode { public: /*! \brief The expression to test. */ - ExprDoc test{nullptr}; + ExprDoc test{ffi::UnsafeInit()}; /*! \brief The optional error message when assertion failed. */ ffi::Optional msg{std::nullopt}; @@ -1129,7 +1132,7 @@ class AssertDoc : public StmtDoc { class ReturnDocNode : public StmtDocNode { public: /*! \brief The value to return. */ - ExprDoc value{nullptr}; + ExprDoc value{ffi::UnsafeInit()}; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -1164,7 +1167,7 @@ class ReturnDoc : public StmtDoc { class FunctionDocNode : public StmtDocNode { public: /*! \brief The name of function. */ - IdDoc name{nullptr}; + IdDoc name{ffi::UnsafeInit{}}; /*! * \brief The arguments of function. * @@ -1223,7 +1226,7 @@ class FunctionDoc : public StmtDoc { class ClassDocNode : public StmtDocNode { public: /*! \brief The name of class. */ - IdDoc name{nullptr}; + IdDoc name{ffi::UnsafeInit{}}; /*! \brief Decorators of class. */ ffi::Array decorators; /*! \brief The body of class. */ diff --git a/include/tvm/script/printer/ir_docsifier.h b/include/tvm/script/printer/ir_docsifier.h index 6e6be57f9ce5..a2fc1097ac36 100644 --- a/include/tvm/script/printer/ir_docsifier.h +++ b/include/tvm/script/printer/ir_docsifier.h @@ -132,7 +132,7 @@ class IRDocsifierNode : public Object { ffi::Optional name; }; /*! \brief The configuration of the printer */ - PrinterConfig cfg{nullptr}; + PrinterConfig cfg{ffi::UnsafeInit()}; /*! * \brief The stack of frames. * \sa FrameNode diff --git a/include/tvm/target/target_kind.h b/include/tvm/target/target_kind.h index ad167ce08bcc..f468f9cbac1b 100644 --- a/include/tvm/target/target_kind.h +++ b/include/tvm/target/target_kind.h @@ -127,6 +127,9 @@ class TargetKindNode : public Object { class TargetKind : public ObjectRef { public: TargetKind() = default; + explicit TargetKind(ObjectPtr data) : ObjectRef(data) { + TVM_FFI_ICHECK(data != nullptr); + } /*! \brief Get the attribute map given the attribute name */ template static inline TargetKindAttrMap GetAttrMap(const ffi::String& attr_name); diff --git a/include/tvm/te/tensor.h b/include/tvm/te/tensor.h index 8bcad6950f4d..68b2bbf71504 100644 --- a/include/tvm/te/tensor.h +++ b/include/tvm/te/tensor.h @@ -50,6 +50,7 @@ class Operation : public ObjectRef { /*! \brief default constructor */ Operation() {} explicit Operation(ObjectPtr n) : ObjectRef(n) {} + explicit Operation(ffi::UnsafeInit tag) : ObjectRef(tag) {} /*! * \brief access the internal node container * \return the pointer to the internal node container diff --git a/include/tvm/tir/block_scope.h b/include/tvm/tir/block_scope.h index 3fc2515d0812..f79a45650045 100644 --- a/include/tvm/tir/block_scope.h +++ b/include/tvm/tir/block_scope.h @@ -297,6 +297,13 @@ class BlockScopeNode : public Object { */ class BlockScope : public ObjectRef { public: + /*! + * \brief Constructor from ObjectPtr. + * \param data The object pointer. + */ + explicit BlockScope(ObjectPtr data) : ObjectRef(data) { + TVM_FFI_ICHECK(data != nullptr); + } /*! \brief The constructor creating an empty block scope with on dependency information */ TVM_DLL BlockScope(); /*! diff --git a/include/tvm/tir/schedule/state.h b/include/tvm/tir/schedule/state.h index 8cb0053df79c..22c4c7d7bd78 100644 --- a/include/tvm/tir/schedule/state.h +++ b/include/tvm/tir/schedule/state.h @@ -43,7 +43,7 @@ namespace tir { */ struct BlockInfo { /*! \brief Property of a block scope rooted at the block, storing dependencies in the scope */ - BlockScope scope{nullptr}; + BlockScope scope{ffi::UnsafeInit()}; // The properties below are information about the current block realization under its parent scope /*! \brief Property of a block, indicating the block realization binding is quasi-affine */ bool affine_binding{false}; diff --git a/include/tvm/tir/var.h b/include/tvm/tir/var.h index 578b00fc08d4..51100c2292e2 100644 --- a/include/tvm/tir/var.h +++ b/include/tvm/tir/var.h @@ -77,7 +77,8 @@ class VarNode : public PrimExprNode { /*! \brief a named variable in TIR */ class Var : public PrimExpr { public: - explicit Var(ObjectPtr n) : PrimExpr(n) {} + explicit Var(ffi::UnsafeInit tag) : PrimExpr(tag) {} + explicit Var(ObjectPtr n) : PrimExpr(n) {} /*! * \brief Constructor * \param name_hint variable name @@ -143,7 +144,8 @@ class SizeVarNode : public VarNode { /*! \brief a named variable represents a tensor index size */ class SizeVar : public Var { public: - explicit SizeVar(ObjectPtr n) : Var(n) {} + explicit SizeVar(ObjectPtr n) : Var(n) {} + explicit SizeVar(ffi::UnsafeInit tag) : Var(tag) {} /*! * \brief constructor * \param name_hint variable name diff --git a/src/contrib/msc/core/printer/msc_doc.h b/src/contrib/msc/core/printer/msc_doc.h index ea1cee396ba6..6433f3de9a2e 100644 --- a/src/contrib/msc/core/printer/msc_doc.h +++ b/src/contrib/msc/core/printer/msc_doc.h @@ -45,7 +45,7 @@ class DeclareDocNode : public ExprDocNode { /*! \brief The type of the variable */ ffi::Optional type; /*! \brief The variable */ - ExprDoc variable{nullptr}; + ExprDoc variable{ffi::UnsafeInit{}}; /*! \brief The init arguments for the variable. */ ffi::Array init_args; /*! \brief Whether to use constructor(otherwise initializer) */ @@ -164,7 +164,7 @@ class PointerDoc : public ExprDoc { class StructDocNode : public StmtDocNode { public: /*! \brief The name of class. */ - IdDoc name{nullptr}; + IdDoc name{ffi::UnsafeInit{}}; /*! \brief Decorators of class. */ ffi::Array decorators; /*! \brief The body of class. */ @@ -207,7 +207,7 @@ class StructDoc : public StmtDoc { class ConstructorDocNode : public StmtDocNode { public: /*! \brief The name of function. */ - IdDoc name{nullptr}; + IdDoc name{ffi::UnsafeInit{}}; /*! * \brief The arguments of function. * @@ -300,7 +300,7 @@ class SwitchDoc : public StmtDoc { class LambdaDocNode : public StmtDocNode { public: /*! \brief The name of lambda. */ - IdDoc name{nullptr}; + IdDoc name{ffi::UnsafeInit{}}; /*! * \brief The arguments of lambda. * diff --git a/src/ir/source_map.cc b/src/ir/source_map.cc index 26fbe07cf6d3..47727d5297a0 100644 --- a/src/ir/source_map.cc +++ b/src/ir/source_map.cc @@ -46,7 +46,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ .def("__data_from_json__", SourceName::Get); }); -ObjectPtr GetSourceNameNode(const ffi::String& name) { +ObjectPtr GetSourceNameNode(const ffi::String& name) { // always return pointer as the reference can change as map re-allocate. // or use another level of indirection by creating a unique_ptr static std::unordered_map> source_map; @@ -62,7 +62,7 @@ ObjectPtr GetSourceNameNode(const ffi::String& name) { } } -ObjectPtr GetSourceNameNodeByStr(const std::string& name) { +ObjectPtr GetSourceNameNodeByStr(const std::string& name) { return GetSourceNameNode(name); } diff --git a/src/meta_schedule/database/database.cc b/src/meta_schedule/database/database.cc index b3c02607bddc..8094449bfb97 100644 --- a/src/meta_schedule/database/database.cc +++ b/src/meta_schedule/database/database.cc @@ -50,7 +50,7 @@ ObjectRef WorkloadNode::AsJSON() const { } Workload Workload::FromJSON(const ObjectRef& json_obj) { - IRModule mod{nullptr}; + IRModule mod{ffi::UnsafeInit()}; THashCode shash = 0; try { const ffi::ArrayObj* json_array = json_obj.as(); @@ -133,7 +133,7 @@ bool TuningRecordNode::IsValid() const { } TuningRecord TuningRecord::FromJSON(const ObjectRef& json_obj, const Workload& workload) { - tir::Trace trace{nullptr}; + tir::Trace trace{ffi::UnsafeInit()}; ffi::Optional> run_secs; ffi::Optional target; ffi::Optional> args_info; diff --git a/src/meta_schedule/database/json_database.cc b/src/meta_schedule/database/json_database.cc index cef4b6437ba2..56e179585e5e 100644 --- a/src/meta_schedule/database/json_database.cc +++ b/src/meta_schedule/database/json_database.cc @@ -185,11 +185,11 @@ Database Database::JSONDatabase(ffi::String path_workload, ffi::String path_tuni { std::vector json_objs = JSONFileReadLines(path_tuning_record, num_threads, allow_missing); std::vector records; - records.resize(json_objs.size(), TuningRecord{nullptr}); + records.resize(json_objs.size(), TuningRecord{ffi::UnsafeInit()}); support::parallel_for_dynamic( 0, json_objs.size(), num_threads, [&](int thread_id, int task_id) { auto json_obj = json_objs[task_id].cast(); - Workload workload{nullptr}; + Workload workload{ffi::UnsafeInit()}; try { const ffi::ArrayObj* arr = json_obj.as(); ICHECK_EQ(arr->size(), 2); diff --git a/src/meta_schedule/postproc/disallow_async_strided_mem_copy.cc b/src/meta_schedule/postproc/disallow_async_strided_mem_copy.cc index 88b6c2c649fb..a8ac2f05c41e 100644 --- a/src/meta_schedule/postproc/disallow_async_strided_mem_copy.cc +++ b/src/meta_schedule/postproc/disallow_async_strided_mem_copy.cc @@ -133,7 +133,7 @@ class DisallowAsyncStridedMemCopyNode : public PostprocNode { const GlobalVar& g_var = kv.first; const BaseFunc& base_func = kv.second; if (const auto* prim_func = base_func.as()) { - IRModule lowered{nullptr}; + IRModule lowered{ffi::UnsafeInit()}; try { auto pass_list = ffi::Array(); pass_list.push_back(tir::transform::BindTarget(this->target)); diff --git a/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc b/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc index f0047d688a80..5b250a6d2bdd 100644 --- a/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc +++ b/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc @@ -415,7 +415,7 @@ class RewriteParallelVectorizeUnrollNode : public PostprocNode { bool Apply(const Schedule& sch) final { tir::ParsedAnnotation parsed_root; - tir::BlockRV root_rv{nullptr}; + tir::BlockRV root_rv{ffi::UnsafeInit()}; while (tir::FindAnnotatedRootBlock(sch, &parsed_root, &root_rv)) { for (tir::BlockRV block_rv : sch->GetChildBlocks(root_rv)) { ffi::Array loop_rvs = sch->GetLoops(block_rv); diff --git a/src/meta_schedule/postproc/verify_gpu_code.cc b/src/meta_schedule/postproc/verify_gpu_code.cc index 5aaf756d43bb..7e660dc7cf30 100644 --- a/src/meta_schedule/postproc/verify_gpu_code.cc +++ b/src/meta_schedule/postproc/verify_gpu_code.cc @@ -114,8 +114,8 @@ Integer Extract(const Target& target, const char* name) { /*! \brief Verify the correctness of the generated GPU code. */ class VerifyGPUCodeNode : public PostprocNode { public: - Target target_{nullptr}; - ffi::Map target_constraints_{nullptr}; + Target target_{ffi::UnsafeInit()}; + ffi::Map target_constraints_{ffi::UnsafeInit()}; int thread_warp_size_ = -1; void InitializeWithTuneContext(const TuneContext& context) final { @@ -150,7 +150,7 @@ class VerifyGPUCodeNode : public PostprocNode { if (!tir::ThreadExtentChecker::Check(prim_func->body, thread_warp_size_)) { return false; } - IRModule lowered{nullptr}; + IRModule lowered{ffi::UnsafeInit()}; try { auto pass_list = ffi::Array(); // Phase 1 diff --git a/src/meta_schedule/schedule/cpu/winograd.cc b/src/meta_schedule/schedule/cpu/winograd.cc index e8afb71d6b7f..6a2b82aa426c 100644 --- a/src/meta_schedule/schedule/cpu/winograd.cc +++ b/src/meta_schedule/schedule/cpu/winograd.cc @@ -31,7 +31,7 @@ static ffi::Array ScheduleDataPack(tir::Schedule sch, tir::BlockRV using namespace tvm::tir; ICHECK_EQ(tiled.size(), 2); ICHECK_EQ(unrolled.size(), 4); - ffi::Array factors{nullptr}; + ffi::Array factors{ffi::UnsafeInit()}; ffi::Array loops = sch->GetLoops(block); ICHECK_EQ(loops.size(), 6); diff --git a/src/meta_schedule/schedule/cuda/thread_bind.cc b/src/meta_schedule/schedule/cuda/thread_bind.cc index b71ea9164ecf..2a042553d6b9 100644 --- a/src/meta_schedule/schedule/cuda/thread_bind.cc +++ b/src/meta_schedule/schedule/cuda/thread_bind.cc @@ -141,11 +141,11 @@ void BindBlockThreadIdx(tir::Schedule sch, tir::BlockRV block_rv, // ICHECK(false) << "Unsupported case, where blockIdx is bound but threadIdx is not"; throw; } - LoopRV loop_rv{nullptr}; + LoopRV loop_rv{ffi::UnsafeInit()}; { ffi::Array loop_rvs = sch->GetLoops(block_rv); if (i_spatial_loop == -1) { - LoopRV spatial_loop_rv{nullptr}; + LoopRV spatial_loop_rv{ffi::UnsafeInit()}; if (loop_rvs.empty()) { spatial_loop_rv = sch->AddUnitLoop(block_rv); } else { diff --git a/src/meta_schedule/schedule/cuda/winograd.cc b/src/meta_schedule/schedule/cuda/winograd.cc index 759ab9fc721c..2b9f4f78df0e 100644 --- a/src/meta_schedule/schedule/cuda/winograd.cc +++ b/src/meta_schedule/schedule/cuda/winograd.cc @@ -35,7 +35,7 @@ static ffi::Array ScheduleDataPack(tir::Schedule sch, tir::BlockRV using namespace tvm::tir; ICHECK_EQ(tiled.size(), 2); ICHECK_EQ(unrolled.size(), 4); - ffi::Array factors{nullptr}; + ffi::Array factors{ffi::UnsafeInit()}; ffi::Array loops = sch->GetLoops(block); ICHECK_EQ(loops.size(), 6); @@ -109,7 +109,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ int64_t max_threads_per_block = 1024; BlockRV input_tile = GetWinogradProducerAndInlineConst(sch, data_pack); BlockRV data_pad = GetWinogradProducerAndInlineConst(sch, input_tile); - LoopRV outer{nullptr}; + LoopRV outer{ffi::UnsafeInit()}; { ffi::Array loops = sch->GetLoops(data_pack); ICHECK_EQ(loops.size(), 6); @@ -139,7 +139,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ // loops on top of the inverse block: [CO, P, tile_size, tile_size, alpha, alpha] int64_t tile_size = Downcast(sch->Get(inverse)->writes[0]->buffer->shape[2])->value; - LoopRV outer{nullptr}; + LoopRV outer{ffi::UnsafeInit()}; { BlockRV output = sch->GetConsumers(inverse)[0]; ffi::Array nchw = sch->GetLoops(output); diff --git a/src/meta_schedule/schedule_rule/cross_thread_reduction.cc b/src/meta_schedule/schedule_rule/cross_thread_reduction.cc index 219e05254e2f..d39951779186 100644 --- a/src/meta_schedule/schedule_rule/cross_thread_reduction.cc +++ b/src/meta_schedule/schedule_rule/cross_thread_reduction.cc @@ -171,7 +171,7 @@ class CrossThreadReductionNode : public ScheduleRuleNode { * \return The extent of "threadIdx.x" in the input schedule */ tir::ExprRV GetThreadIdxExtentFromTrace(const tir::Trace& trace) { - tir::ExprRV extent{nullptr}; + tir::ExprRV extent{ffi::UnsafeInit()}; for (const tir::Instruction& inst : trace->insts) { if (inst->kind->name == "Bind" && Downcast(inst->attrs[0]) == "threadIdx.x") { if (GetLoopRVExtentSource(trace, Downcast(inst->inputs[0]), &extent)) { @@ -198,8 +198,8 @@ class CrossThreadReductionNode : public ScheduleRuleNode { // Step 0. Due to technical reason of some primitives (e.g., compute-at), if the block is doing // a tuple reduction, fusion is temporarily not supported. if (sch->Get(block_rv)->writes.size() != 1) { - return std::make_tuple(false, tir::LoopRV{nullptr}, tir::BlockRV{nullptr}, - tir::LoopRV{nullptr}); + return std::make_tuple(false, tir::LoopRV{ffi::UnsafeInit()}, tir::BlockRV{ffi::UnsafeInit()}, + tir::LoopRV{ffi::UnsafeInit()}); } // Step 1. Get all the consumers of the input block. @@ -208,8 +208,8 @@ class CrossThreadReductionNode : public ScheduleRuleNode { // Step 2. If the block has no consumer or the first consumer needs multi-level tiling, it is // not fusible. if (consumers.empty() || tir::NeedsMultiLevelTiling(sch->state(), sch->GetSRef(consumers[0]))) { - return std::make_tuple(false, tir::LoopRV{nullptr}, tir::BlockRV{nullptr}, - tir::LoopRV{nullptr}); + return std::make_tuple(false, tir::LoopRV{ffi::UnsafeInit()}, tir::BlockRV{ffi::UnsafeInit()}, + tir::LoopRV{ffi::UnsafeInit()}); } // Step 3. Calculate the lowest common ancestor of all the consumers. @@ -221,8 +221,8 @@ class CrossThreadReductionNode : public ScheduleRuleNode { const tir::StmtSRef& lca_sref = tir::GetSRefLowestCommonAncestor(tir::BlockRVs2StmtSRefs(sch, consumers)); if (consumers.size() > 1 && lca_sref->StmtAs() != nullptr) { - return std::make_tuple(false, tir::LoopRV{nullptr}, tir::BlockRV{nullptr}, - tir::LoopRV{nullptr}); + return std::make_tuple(false, tir::LoopRV{ffi::UnsafeInit()}, tir::BlockRV{ffi::UnsafeInit()}, + tir::LoopRV{ffi::UnsafeInit()}); } // Step 4. Get the outer loops of the target block, and get the compute-at position index. @@ -231,8 +231,8 @@ class CrossThreadReductionNode : public ScheduleRuleNode { // Step 5. A negative position index means not fusible, and vice-versa. if (pos < 0) { - return std::make_tuple(false, tir::LoopRV{nullptr}, tir::BlockRV{nullptr}, - tir::LoopRV{nullptr}); + return std::make_tuple(false, tir::LoopRV{ffi::UnsafeInit()}, tir::BlockRV{ffi::UnsafeInit()}, + tir::LoopRV{ffi::UnsafeInit()}); } else { return std::make_tuple(true, tgt_block_loops[pos], consumers[0], tgt_block_loops.back()); } diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc b/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc index 0bbccbdffe7a..741f0b6db444 100644 --- a/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc +++ b/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc @@ -77,7 +77,7 @@ class TensorCoreStateNode : public StateNode { /*! \brief The tensor core intrinsic group. */ TensorCoreIntrinGroup intrin_group; /*! \brief The auto tensorization maping info. */ - tir::AutoTensorizeMappingInfo mapping_info{nullptr}; + tir::AutoTensorizeMappingInfo mapping_info{ffi::UnsafeInit()}; /*! \brief The Tensor Core reindex block A for Tensor Core computation */ tir::BlockRV tensor_core_reindex_A; /*! \brief The Tensor Core reindex block B for Tensor Core computation */ diff --git a/src/meta_schedule/search_strategy/evolutionary_search.cc b/src/meta_schedule/search_strategy/evolutionary_search.cc index 306a3634d9d1..456fbbf129af 100644 --- a/src/meta_schedule/search_strategy/evolutionary_search.cc +++ b/src/meta_schedule/search_strategy/evolutionary_search.cc @@ -112,7 +112,7 @@ class SizedHeap { }; struct PerThreadData { - IRModule mod{nullptr}; + IRModule mod{ffi::UnsafeInit()}; TRandState rand_state{-1}; std::function trace_sampler = nullptr; std::function()> mutator_sampler = nullptr; @@ -270,11 +270,11 @@ class EvolutionarySearchNode : public SearchStrategyNode { * */ IRModuleSet measured_workloads_; /*! \brief A Database for selecting useful candidates. */ - Database database_{nullptr}; + Database database_{ffi::UnsafeInit()}; /*! \brief A cost model helping to explore the search space */ - CostModel cost_model_{nullptr}; + CostModel cost_model_{ffi::UnsafeInit()}; /*! \brief The token registered for the given workload in database. */ - Workload token_{nullptr}; + Workload token_{ffi::UnsafeInit()}; explicit State(EvolutionarySearchNode* self, int max_trials, int num_trials_per_iter, ffi::Array design_space_schedules, Database database, diff --git a/src/meta_schedule/utils.h b/src/meta_schedule/utils.h index 732a3a083d03..ee94b1d2ab5e 100644 --- a/src/meta_schedule/utils.h +++ b/src/meta_schedule/utils.h @@ -360,7 +360,7 @@ struct ThreadedTraceApply { /*! \brief A helper data structure that stores the fail count for each postprocessor. */ struct Item { /*! \brief The postprocessor. */ - Postproc postproc{nullptr}; + Postproc postproc{ffi::UnsafeInit()}; /*! \brief The thread-safe postprocessor failure counter. */ std::atomic fail_counter{0}; }; diff --git a/src/relax/ir/py_expr_functor.cc b/src/relax/ir/py_expr_functor.cc index 11867dee6db4..a97c5f784dc9 100644 --- a/src/relax/ir/py_expr_functor.cc +++ b/src/relax/ir/py_expr_functor.cc @@ -177,6 +177,9 @@ class PyExprVisitorNode : public Object, public ExprVisitor { */ class PyExprVisitor : public ObjectRef { public: + explicit PyExprVisitor(ObjectPtr data) : ObjectRef(data) { + TVM_FFI_ICHECK(data != nullptr); + } /*! * \brief Create a PyExprVisitor with customized methods on the python-side. * \param f_visit_expr The packed function of `VisitExpr(const Expr& expr)`. @@ -461,6 +464,9 @@ class PyExprMutatorNode : public Object, public ExprMutator { */ class PyExprMutator : public ObjectRef { public: + explicit PyExprMutator(ObjectPtr data) : ObjectRef(data) { + TVM_FFI_ICHECK(data != nullptr); + } /*! * \brief Create a PyExprMutator with customized methods on the python-side. * \param f_visit_expr The packed function of `VisitExpr(const Expr& expr)`. diff --git a/src/relax/transform/few_shot_tuning.cc b/src/relax/transform/few_shot_tuning.cc index 091247272a64..7deffaa9f58e 100644 --- a/src/relax/transform/few_shot_tuning.cc +++ b/src/relax/transform/few_shot_tuning.cc @@ -34,7 +34,7 @@ tir::PrimFunc FewShotTunePrimFunc(const tir::PrimFunc& prim_func, const Target& meta_schedule::Builder builder = f_get_local_builder().cast(); ICHECK(builder.defined()) << "ValueError: The local builder is not defined!"; // fetch a local runner - meta_schedule::Runner runner{nullptr}; + meta_schedule::Runner runner{ffi::UnsafeInit()}; if (benchmark) { static const auto f_get_local_runner = tvm::ffi::Function::GetGlobalRequired("meta_schedule.runner.get_local_runner"); diff --git a/src/relax/transform/meta_schedule.cc b/src/relax/transform/meta_schedule.cc index 2d24f0785a15..295937084d86 100644 --- a/src/relax/transform/meta_schedule.cc +++ b/src/relax/transform/meta_schedule.cc @@ -81,7 +81,7 @@ Pass MetaScheduleApplyDatabase(ffi::Optional work_dir, bool enable_ ICHECK(normalize_mod_func_.has_value()) << "Normalization function is not found."; auto pass_func = [=](IRModule mod, PassContext ctx) { - Database database{nullptr}; + Database database{ffi::UnsafeInit()}; if (Database::Current().defined()) { database = Database::Current().value(); } else { diff --git a/src/runtime/rpc/rpc_session.h b/src/runtime/rpc/rpc_session.h index 265c58f4af63..4c456b861e9d 100644 --- a/src/runtime/rpc/rpc_session.h +++ b/src/runtime/rpc/rpc_session.h @@ -333,6 +333,9 @@ class RPCObjectRefObj : public Object { */ class RPCObjectRef : public ObjectRef { public: + explicit RPCObjectRef(ObjectPtr data) : ObjectRef(data) { + TVM_FFI_ICHECK(data != nullptr); + } TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(RPCObjectRef, ObjectRef, RPCObjectRefObj); }; diff --git a/src/script/printer/relax/call.cc b/src/script/printer/relax/call.cc index 9b0d2b966a4d..666b3839ea0e 100644 --- a/src/script/printer/relax/call.cc +++ b/src/script/printer/relax/call.cc @@ -264,7 +264,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) if (ffi::Optional doc = PrintRelaxPrint(n, n_p, d)) { return doc.value(); } - ExprDoc prefix{nullptr}; + ExprDoc prefix{ffi::UnsafeInit()}; ffi::Array args; ffi::Array kwargs_keys; ffi::Array kwargs_values; diff --git a/src/script/printer/tir/block.cc b/src/script/printer/tir/block.cc index 587520d72fe5..1a33d760a9d5 100644 --- a/src/script/printer/tir/block.cc +++ b/src/script/printer/tir/block.cc @@ -83,7 +83,7 @@ Doc PrintBlock(IRDocsifier d, tir::Block block, AccessPath block_p, // LOG(FATAL) << "ValueError: Unknown IterVarType in block signature: " << tir::IterVarType2String(iter_var->iter_type); } - ExprDoc dom{nullptr}; + ExprDoc dom{ffi::UnsafeInit()}; if (tir::is_zero(iter_var->dom->min)) { ExprDoc extent = d->AsDoc(iter_var->dom->extent, // iter_var_p->Attr("dom")->Attr("extent")); diff --git a/src/script/printer/tir/expr.cc b/src/script/printer/tir/expr.cc index ddcf1b64f1a1..da525aa35fc2 100644 --- a/src/script/printer/tir/expr.cc +++ b/src/script/printer/tir/expr.cc @@ -27,7 +27,7 @@ namespace printer { ExprDoc PrintVarCreation(const tir::Var& var, const AccessPath& var_p, const IRDocsifier& d) { Type type = var->type_annotation; AccessPath type_p = var_p->Attr("type_annotation"); - ExprDoc rhs{nullptr}; + ExprDoc rhs{ffi::UnsafeInit()}; ffi::Array kwargs_keys; ffi::Array kwargs_values; @@ -169,7 +169,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // "", [](tir::CommReducer r, AccessPath p, IRDocsifier d) -> Doc { ICHECK_EQ(r->lhs.size(), r->rhs.size()); - LambdaDoc lambda{nullptr}; + ffi::Optional lambda; { With f(d, r); int n_vars = r->lhs.size(); @@ -194,7 +194,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) } } ExprDoc id = d->AsDoc(r->identity_element, p->Attr("identity_element")); - return TIR(d, "comm_reducer")->Call({lambda, id}); + return TIR(d, "comm_reducer")->Call({lambda.value(), id}); }); LambdaDoc PrintIndexMap(const ObjectRef& map, const ffi::Array& vs, @@ -244,7 +244,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) static const OpAttrMap dtype_locations = Op::GetAttrMap("TScriptDtypePrintLocation"); tir::ScriptDtypePrintLocation dtype_print_location = tir::ScriptDtypePrintLocation::kNone; - ExprDoc prefix{nullptr}; + ffi::Optional prefix; if (auto optional_op = call->op.as()) { auto op = optional_op.value(); ffi::String name = op_names.get(op, op->name); @@ -279,7 +279,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) if (dtype_print_location == tir::ScriptDtypePrintLocation::kLast) { args.push_back(LiteralDoc::DataType(call->dtype, call_p->Attr("dtype"))); } - return prefix->Call(args); + return prefix.value()->Call(args); } } else if (call->op.as()) { prefix = d->AsDoc(call->op, call_p->Attr("op")); @@ -299,7 +299,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) if (dtype_print_location == tir::ScriptDtypePrintLocation::kLast) { args.push_back(LiteralDoc::DataType(call->dtype, call_p->Attr("dtype"))); } - return prefix->Call(args); + return prefix.value()->Call(args); }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) diff --git a/src/script/printer/tir/for_loop.cc b/src/script/printer/tir/for_loop.cc index 10bb6f756df2..742d23f69cdd 100644 --- a/src/script/printer/tir/for_loop.cc +++ b/src/script/printer/tir/for_loop.cc @@ -78,7 +78,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) if (!loop->annotations.empty()) { annotations = d->AsDoc(loop->annotations, loop_p->Attr("annotations")); } - ExprDoc prefix{nullptr}; + ExprDoc prefix{ffi::UnsafeInit()}; if (loop->kind == tir::ForKind::kSerial) { if (loop->annotations.empty()) { prefix = IdDoc("range"); diff --git a/src/script/printer/tir/ir.cc b/src/script/printer/tir/ir.cc index 0cd38d4c6a49..797c726c7c1a 100644 --- a/src/script/printer/tir/ir.cc +++ b/src/script/printer/tir/ir.cc @@ -66,7 +66,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch("", [](PointerType ty, AccessPath ty_p, IRDocsifier d) -> Doc { - ExprDoc element_type{nullptr}; + ExprDoc element_type{ffi::UnsafeInit()}; if (const auto* prim_type = ty->element_type.as()) { element_type = LiteralDoc::DataType(prim_type->dtype, // ty_p->Attr("element_type")->Attr("dtype")); diff --git a/src/script/printer/tir/stmt.cc b/src/script/printer/tir/stmt.cc index 228fbbc78556..1b0774be3686 100644 --- a/src/script/printer/tir/stmt.cc +++ b/src/script/printer/tir/stmt.cc @@ -284,7 +284,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) ffi::Array args; ffi::Array kwargs_keys; ffi::Array kwargs_values; - ExprDoc data_doc{nullptr}; + ExprDoc data_doc{ffi::UnsafeInit()}; if (stmt->dtype.is_int()) { if (stmt->dtype.bits() == 8) { data_doc = PrintTensor(stmt->data.value()); @@ -377,7 +377,7 @@ ExprDoc DocsifyLaunchThread(const tir::AttrStmt& attr_stmt, const AccessPath& at tir::IterVar iter_var = Downcast(attr_stmt->node); AccessPath iter_var_p = attr_stmt_p->Attr("node"); - ExprDoc var_doc{nullptr}; + ExprDoc var_doc{ffi::UnsafeInit()}; if (d->IsVarDefined(iter_var->var)) { var_doc = d->AsDoc(iter_var->var, iter_var_p->Attr("var")); } else if (IsAncestorOfAllVarUse(attr_stmt, iter_var->var, d)) { diff --git a/src/target/target.cc b/src/target/target.cc index b2c3e8fe8c1b..e2013aba7218 100644 --- a/src/target/target.cc +++ b/src/target/target.cc @@ -56,10 +56,10 @@ class TargetInternal { const ffi::Map& attrs); static Any ParseType(const std::string& str, const TargetKindNode::ValueTypeInfo& info); static Any ParseType(const Any& obj, const TargetKindNode::ValueTypeInfo& info); - static ObjectPtr FromString(const ffi::String& tag_or_config_or_target_str); - static ObjectPtr FromConfigString(const ffi::String& config_str); - static ObjectPtr FromRawString(const ffi::String& target_str); - static ObjectPtr FromConfig(ffi::Map config); + static ObjectPtr FromString(const ffi::String& tag_or_config_or_target_str); + static ObjectPtr FromConfigString(const ffi::String& config_str); + static ObjectPtr FromRawString(const ffi::String& target_str); + static ObjectPtr FromConfig(ffi::Map config); static void ConstructorDispatcher(ffi::PackedArgs args, ffi::Any* rv); static Target WithHost(const Target& target, const Target& target_host) { ObjectPtr n = ffi::make_object(*target.get()); @@ -771,10 +771,10 @@ void TargetInternal::ConstructorDispatcher(ffi::PackedArgs args, ffi::Any* rv) { LOG(FATAL) << "ValueError: Invalid number of arguments. Expect 1 or 2, but gets: " << args.size(); } -ObjectPtr TargetInternal::FromString(const ffi::String& tag_or_config_or_target_str) { +ObjectPtr TargetInternal::FromString(const ffi::String& tag_or_config_or_target_str) { if (ffi::Optional target = TargetTag::Get(tag_or_config_or_target_str)) { Target value = target.value(); - return ffi::details::ObjectUnsafe::ObjectPtrFromObjectRef(value); + return ffi::details::ObjectUnsafe::ObjectPtrFromObjectRef(value); } if (!tag_or_config_or_target_str.empty() && tag_or_config_or_target_str.data()[0] == '{') { return TargetInternal::FromConfigString(tag_or_config_or_target_str); @@ -782,7 +782,7 @@ ObjectPtr TargetInternal::FromString(const ffi::String& tag_or_config_or return TargetInternal::FromRawString(tag_or_config_or_target_str); } -ObjectPtr TargetInternal::FromConfigString(const ffi::String& config_str) { +ObjectPtr TargetInternal::FromConfigString(const ffi::String& config_str) { const auto loader = tvm::ffi::Function::GetGlobal("target._load_config_dict"); ICHECK(loader.has_value()) << "AttributeError: \"target._load_config_dict\" is not registered. Please check " @@ -794,7 +794,7 @@ ObjectPtr TargetInternal::FromConfigString(const ffi::String& config_str return TargetInternal::FromConfig({config.value().begin(), config.value().end()}); } -ObjectPtr TargetInternal::FromRawString(const ffi::String& target_str) { +ObjectPtr TargetInternal::FromRawString(const ffi::String& target_str) { ICHECK_GT(target_str.length(), 0) << "Cannot parse empty target string"; // Split the string by empty spaces std::vector options = SplitString(std::string(target_str), ' '); @@ -826,7 +826,7 @@ ObjectPtr TargetInternal::FromRawString(const ffi::String& target_str) { return TargetInternal::FromConfig(config); } -ObjectPtr TargetInternal::FromConfig(ffi::Map config) { +ObjectPtr TargetInternal::FromConfig(ffi::Map config) { const ffi::String kKind = "kind"; const ffi::String kTag = "tag"; const ffi::String kKeys = "keys"; diff --git a/src/tir/ir/py_functor.cc b/src/tir/ir/py_functor.cc index 871452aeb946..26b55d3bb922 100644 --- a/src/tir/ir/py_functor.cc +++ b/src/tir/ir/py_functor.cc @@ -342,6 +342,9 @@ class PyStmtExprVisitorNode : public Object, public StmtExprVisitor { */ class PyStmtExprVisitor : public ObjectRef { public: + explicit PyStmtExprVisitor(ObjectPtr data) : ObjectRef(data) { + TVM_FFI_ICHECK(data != nullptr); + } TVM_DLL static PyStmtExprVisitor MakePyStmtExprVisitor(ffi::Function f_visit_stmt, // ffi::Function f_visit_expr, // ffi::Function f_visit_let_stmt, // @@ -702,6 +705,9 @@ class PyStmtExprMutatorNode : public Object, public StmtExprMutator { /*! \brief Managed reference to PyStmtExprMutatorNode. */ class PyStmtExprMutator : public ObjectRef { public: + explicit PyStmtExprMutator(ObjectPtr data) : ObjectRef(data) { + TVM_FFI_ICHECK(data != nullptr); + } /*! * \brief Create a PyStmtExprMutator with customized methods on the python-side. * \return The PyStmtExprMutator created. diff --git a/src/tir/schedule/analysis.h b/src/tir/schedule/analysis.h index 8f3372b0ca17..910c22aae0b2 100644 --- a/src/tir/schedule/analysis.h +++ b/src/tir/schedule/analysis.h @@ -761,6 +761,9 @@ class TensorizeInfoNode : public Object { class TensorizeInfo : public ObjectRef { public: + explicit TensorizeInfo(ObjectPtr data) : ObjectRef(data) { + TVM_FFI_ICHECK(data != nullptr); + } TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TensorizeInfo, ObjectRef, TensorizeInfoNode); }; @@ -810,6 +813,10 @@ class AutoTensorizeMappingInfoNode : public Object { class AutoTensorizeMappingInfo : public ObjectRef { public: + explicit AutoTensorizeMappingInfo(ObjectPtr data) + : ObjectRef(data) { + TVM_FFI_ICHECK(data != nullptr); + } TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(AutoTensorizeMappingInfo, ObjectRef, AutoTensorizeMappingInfoNode); }; diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc index b33333177816..89ece537713d 100644 --- a/src/tir/schedule/concrete_schedule.cc +++ b/src/tir/schedule/concrete_schedule.cc @@ -604,7 +604,7 @@ void ConcreteScheduleNode::ReorderBlockIterVar(const BlockRV& block_rv, } LoopRV ConcreteScheduleNode::AddUnitLoop(const BlockRV& block_rv) { - LoopRV result{nullptr}; + LoopRV result{ffi::UnsafeInit()}; TVM_TIR_SCHEDULE_BEGIN(); result = CreateRV(tir::AddUnitLoop(state_, GetSRef(block_rv))); TVM_TIR_SCHEDULE_END("add-unit-loop", this->error_render_level_); @@ -613,7 +613,7 @@ LoopRV ConcreteScheduleNode::AddUnitLoop(const BlockRV& block_rv) { } LoopRV ConcreteScheduleNode::AddUnitLoop(const LoopRV& loop_rv) { - LoopRV result{nullptr}; + LoopRV result{ffi::UnsafeInit()}; TVM_TIR_SCHEDULE_BEGIN(); result = CreateRV(tir::AddUnitLoop(state_, GetSRef(loop_rv))); TVM_TIR_SCHEDULE_END("add-unit-loop", this->error_render_level_); diff --git a/src/tir/transforms/memhammer_tensorcore_rewrite.cc b/src/tir/transforms/memhammer_tensorcore_rewrite.cc index c1b303e0731b..e16c51877188 100644 --- a/src/tir/transforms/memhammer_tensorcore_rewrite.cc +++ b/src/tir/transforms/memhammer_tensorcore_rewrite.cc @@ -334,7 +334,7 @@ class WmmaToGlobalRewriter : public StmtExprMutator { Stmt WmmaToGlobal::Rewrite(const Stmt& stmt, const ConstraintSet& constraints, OutputSet* output) const { Stmt body{nullptr}; - ffi::Optional compute_location{nullptr}; + ffi::Optional compute_location; std::tie(body, compute_location) = TileWmmaBlock(stmt); SeqStmt seq{nullptr}; Buffer cache_buffer; @@ -543,7 +543,7 @@ class MmaToGlobalRewriter : public StmtExprMutator { Stmt MmaToGlobal::Rewrite(const Stmt& stmt, const ConstraintSet& constraints, OutputSet* output) const { Stmt body{nullptr}; - ffi::Optional compute_location{nullptr}; + ffi::Optional compute_location; std::tie(body, compute_location) = TileMmaToGlobalBlock(stmt); SeqStmt seq{nullptr}; Buffer cache_buffer;