From 533395694bc514366c424627b75c0bb8c33e09e3 Mon Sep 17 00:00:00 2001 From: tqchen Date: Mon, 8 Sep 2025 09:40:33 -0400 Subject: [PATCH] [FFI][REFACTOR] Introduce UnsafeInit and enhance ObjectRef nullptr safety This PR enhances the nullptr and general type-safe of ObjectRef types. Previously ObjectRef relies on constructor from ObjectPtr for casting and initialize from nullptr. We introduce a tag ffi::UnsafeInit, which explicitly states the intent that the initialization is unsafe and may initialize non-nullable Ref to null. Such tag should only be used in controlled scenarios. Now the general RefType(ObjectPtr) is removed. We still keep RefType(ObjectPtr) for nullable objects, but removes the default definition from non-nullable types, knowing that user can always explicitly add it to class impl (ensuring null checking). --- ffi/include/tvm/ffi/cast.h | 10 +-- ffi/include/tvm/ffi/container/array.h | 4 ++ ffi/include/tvm/ffi/container/map.h | 4 ++ ffi/include/tvm/ffi/container/shape.h | 17 ++++- ffi/include/tvm/ffi/container/tensor.h | 4 +- ffi/include/tvm/ffi/container/tuple.h | 14 ++-- ffi/include/tvm/ffi/container/variant.h | 2 +- ffi/include/tvm/ffi/extra/module.h | 17 ++++- ffi/include/tvm/ffi/function.h | 2 +- ffi/include/tvm/ffi/function_details.h | 2 +- ffi/include/tvm/ffi/object.h | 62 +++++++++++++++--- ffi/include/tvm/ffi/optional.h | 17 +++-- ffi/include/tvm/ffi/reflection/access_path.h | 4 ++ ffi/include/tvm/ffi/reflection/registry.h | 10 +++ ffi/include/tvm/ffi/rvalue_ref.h | 9 ++- ffi/include/tvm/ffi/type_traits.h | 17 +++-- ffi/src/ffi/tensor.cc | 2 +- ffi/tests/cpp/test_object.cc | 8 +++ ffi/tests/cpp/testing_object.h | 10 +-- include/tvm/ir/attrs.h | 6 +- include/tvm/ir/env_func.h | 8 +++ include/tvm/ir/expr.h | 6 +- include/tvm/ir/module.h | 6 +- include/tvm/ir/transform.h | 11 +++- include/tvm/meta_schedule/builder.h | 7 ++ include/tvm/meta_schedule/database.h | 10 ++- include/tvm/meta_schedule/runner.h | 6 +- include/tvm/meta_schedule/space_generator.h | 7 ++ include/tvm/meta_schedule/task_scheduler.h | 5 +- include/tvm/meta_schedule/tune_context.h | 7 ++ include/tvm/node/cast.h | 11 ++-- include/tvm/relax/dataflow_pattern.h | 5 ++ include/tvm/relax/expr.h | 3 +- include/tvm/relax/struct_info.h | 6 ++ include/tvm/runtime/disco/session.h | 1 + include/tvm/runtime/object.h | 2 +- include/tvm/runtime/tensor.h | 3 +- include/tvm/script/ir_builder/base.h | 1 + include/tvm/script/ir_builder/ir/frame.h | 3 + include/tvm/script/ir_builder/relax/frame.h | 24 +++++++ include/tvm/script/ir_builder/tir/frame.h | 65 +++++++++++++++++++ include/tvm/script/printer/doc.h | 39 ++++++----- include/tvm/script/printer/ir_docsifier.h | 2 +- include/tvm/target/target_kind.h | 3 + include/tvm/te/tensor.h | 1 + include/tvm/tir/block_scope.h | 7 ++ include/tvm/tir/schedule/state.h | 2 +- include/tvm/tir/var.h | 6 +- src/contrib/msc/core/printer/msc_doc.h | 8 +-- src/ir/source_map.cc | 4 +- src/meta_schedule/database/database.cc | 4 +- src/meta_schedule/database/json_database.cc | 4 +- .../disallow_async_strided_mem_copy.cc | 2 +- .../rewrite_parallel_vectorize_unroll.cc | 2 +- src/meta_schedule/postproc/verify_gpu_code.cc | 6 +- src/meta_schedule/schedule/cpu/winograd.cc | 2 +- .../schedule/cuda/thread_bind.cc | 4 +- src/meta_schedule/schedule/cuda/winograd.cc | 6 +- .../schedule_rule/cross_thread_reduction.cc | 18 ++--- .../multi_level_tiling_tensor_core.cc | 2 +- .../search_strategy/evolutionary_search.cc | 8 +-- src/meta_schedule/utils.h | 2 +- src/relax/ir/py_expr_functor.cc | 6 ++ src/relax/transform/few_shot_tuning.cc | 2 +- src/relax/transform/meta_schedule.cc | 2 +- src/runtime/rpc/rpc_session.h | 3 + src/script/printer/relax/call.cc | 2 +- src/script/printer/tir/block.cc | 2 +- src/script/printer/tir/expr.cc | 12 ++-- src/script/printer/tir/for_loop.cc | 2 +- src/script/printer/tir/ir.cc | 2 +- src/script/printer/tir/stmt.cc | 4 +- src/target/target.cc | 18 ++--- src/tir/ir/py_functor.cc | 6 ++ src/tir/schedule/analysis.h | 7 ++ src/tir/schedule/concrete_schedule.cc | 4 +- .../memhammer_tensorcore_rewrite.cc | 4 +- 77 files changed, 473 insertions(+), 153 deletions(-) diff --git a/ffi/include/tvm/ffi/cast.h b/ffi/include/tvm/ffi/cast.h index f70df9fe7ca2..398953ad6508 100644 --- a/ffi/include/tvm/ffi/cast.h +++ b/ffi/include/tvm/ffi/cast.h @@ -44,18 +44,20 @@ namespace ffi { */ template inline RefType GetRef(const ObjectType* ptr) { - static_assert(std::is_base_of_v, + using ContainerType = typename RefType::ContainerType; + static_assert(std::is_base_of_v, "Can only cast to the ref of same container type"); if constexpr (is_optional_type_v || RefType::_type_is_nullable) { if (ptr == nullptr) { - return RefType(ObjectPtr(nullptr)); + return details::ObjectUnsafe::ObjectRefFromObjectPtr(nullptr); } } else { TVM_FFI_ICHECK_NOTNULL(ptr); } - return RefType(details::ObjectUnsafe::ObjectPtrFromUnowned( - const_cast(static_cast(ptr)))); + return details::ObjectUnsafe::ObjectRefFromObjectPtr( + details::ObjectUnsafe::ObjectPtrFromUnowned( + const_cast(static_cast(ptr)))); } /*! diff --git a/ffi/include/tvm/ffi/container/array.h b/ffi/include/tvm/ffi/container/array.h index 7dbcc1f0189e..8fab30b8be56 100644 --- a/ffi/include/tvm/ffi/container/array.h +++ b/ffi/include/tvm/ffi/container/array.h @@ -362,6 +362,10 @@ class Array : public ObjectRef { /*! \brief The value type of the array */ using value_type = T; // constructors + /*! + * \brief Construct an Array with UnsafeInit + */ + explicit Array(UnsafeInit tag) : ObjectRef(tag) {} /*! * \brief default constructor */ diff --git a/ffi/include/tvm/ffi/container/map.h b/ffi/include/tvm/ffi/container/map.h index 27928d20c5cf..bea2688f7f20 100644 --- a/ffi/include/tvm/ffi/container/map.h +++ b/ffi/include/tvm/ffi/container/map.h @@ -1381,6 +1381,10 @@ class Map : public ObjectRef { using mapped_type = V; /*! \brief The iterator type of the map */ class iterator; + /*! + * \brief Construct an Map with UnsafeInit + */ + explicit Map(UnsafeInit tag) : ObjectRef(tag) {} /*! * \brief default constructor */ diff --git a/ffi/include/tvm/ffi/container/shape.h b/ffi/include/tvm/ffi/container/shape.h index 39c3ec273963..f5e88d6bb796 100644 --- a/ffi/include/tvm/ffi/container/shape.h +++ b/ffi/include/tvm/ffi/container/shape.h @@ -94,13 +94,13 @@ TVM_FFI_INLINE ObjectPtr MakeInplaceShape(IterType begin, IterType end return p; } -TVM_FFI_INLINE ObjectPtr MakeStridesFromShape(int64_t ndim, int64_t* shape) { +TVM_FFI_INLINE ObjectPtr MakeStridesFromShape(const int64_t* data, int64_t ndim) { int64_t* strides_data; ObjectPtr strides = details::MakeEmptyShape(ndim, &strides_data); int64_t stride = 1; for (int i = ndim - 1; i >= 0; --i) { strides_data[i] = stride; - stride *= shape[i]; + stride *= data[i]; } return strides; } @@ -150,6 +150,16 @@ class Shape : public ObjectRef { Shape(std::vector other) // NOLINT(*) : ObjectRef(make_object(std::move(other))) {} + /*! + * \brief Create a strides from a shape. + * \param data The shape data. + * \param ndim The number of dimensions. + * \return The strides. + */ + static Shape StridesFromShape(const int64_t* data, int64_t ndim) { + return Shape(details::MakeStridesFromShape(data, ndim)); + } + /*! * \brief Return the data pointer * @@ -204,6 +214,9 @@ class Shape : public ObjectRef { /// \cond Doxygen_Suppress TVM_FFI_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Shape, ObjectRef, ShapeObj); /// \endcond + + private: + explicit Shape(ObjectPtr ptr) : ObjectRef(ptr) {} }; inline std::ostream& operator<<(std::ostream& os, const Shape& shape) { diff --git a/ffi/include/tvm/ffi/container/tensor.h b/ffi/include/tvm/ffi/container/tensor.h index 99fb29d10830..21c67decfcd5 100644 --- a/ffi/include/tvm/ffi/container/tensor.h +++ b/ffi/include/tvm/ffi/container/tensor.h @@ -203,7 +203,7 @@ class TensorObjFromNDAlloc : public TensorObj { this->ndim = static_cast(shape.size()); this->dtype = dtype; this->shape = const_cast(shape.data()); - Shape strides = Shape(details::MakeStridesFromShape(this->ndim, this->shape)); + Shape strides = Shape::StridesFromShape(this->shape, this->ndim); this->strides = const_cast(strides.data()); this->byte_offset = 0; this->shape_data_ = std::move(shape); @@ -224,7 +224,7 @@ class TensorObjFromDLPack : public TensorObj { explicit TensorObjFromDLPack(TDLPackManagedTensor* tensor) : tensor_(tensor) { *static_cast(this) = tensor_->dl_tensor; if (tensor_->dl_tensor.strides == nullptr) { - Shape strides = Shape(details::MakeStridesFromShape(ndim, shape)); + Shape strides = Shape::StridesFromShape(tensor_->dl_tensor.shape, tensor_->dl_tensor.ndim); this->strides = const_cast(strides.data()); this->strides_data_ = std::move(strides); } diff --git a/ffi/include/tvm/ffi/container/tuple.h b/ffi/include/tvm/ffi/container/tuple.h index 0cb80b963e9e..75342409eabb 100644 --- a/ffi/include/tvm/ffi/container/tuple.h +++ b/ffi/include/tvm/ffi/container/tuple.h @@ -47,6 +47,10 @@ class Tuple : public ObjectRef { "All types used in Tuple<...> must be compatible with Any"); /*! \brief Default constructor */ Tuple() : ObjectRef(MakeDefaultTupleNode()) {} + /*! + * \brief Constructor with UnsafeInit + */ + explicit Tuple(UnsafeInit tag) : ObjectRef(tag) {} /*! \brief Copy constructor */ Tuple(const Tuple& other) : ObjectRef(other) {} /*! \brief Move constructor */ @@ -128,13 +132,6 @@ class Tuple : public ObjectRef { return *this; } - /*! - * \brief Constructor ObjectPtr - * \param ptr The ObjectPtr - * \tparam The enable_if_t type - */ - explicit Tuple(ObjectPtr ptr) : ObjectRef(ptr) {} - /*! * \brief Get I-th element of the tuple * @@ -283,7 +280,8 @@ struct TypeTraits> : public ObjectRefTypeTraitsBase arr = TypeTraits>::CopyFromAnyViewAfterCheck(src); Any* ptr = arr.CopyOnWrite()->MutableBegin(); if (TryConvertElements<0, Types...>(ptr)) { - return Tuple(details::ObjectUnsafe::ObjectPtrFromObjectRef(arr)); + return details::ObjectUnsafe::ObjectRefFromObjectPtr>( + details::ObjectUnsafe::ObjectPtrFromObjectRef(arr)); } return std::nullopt; } diff --git a/ffi/include/tvm/ffi/container/variant.h b/ffi/include/tvm/ffi/container/variant.h index 5f66d73a1845..cae5a673b8ce 100644 --- a/ffi/include/tvm/ffi/container/variant.h +++ b/ffi/include/tvm/ffi/container/variant.h @@ -68,7 +68,7 @@ class VariantBase : public ObjectRef { explicit VariantBase(const T& other) : ObjectRef(other) {} template explicit VariantBase(T&& other) : ObjectRef(std::move(other)) {} - explicit VariantBase(ObjectPtr ptr) : ObjectRef(ptr) {} + explicit VariantBase(UnsafeInit tag) : ObjectRef(tag) {} explicit VariantBase(Any other) : ObjectRef(details::AnyUnsafe::MoveFromAnyAfterCheck(std::move(other))) {} diff --git a/ffi/include/tvm/ffi/extra/module.h b/ffi/include/tvm/ffi/extra/module.h index 89e0c287a3fe..a1dc91eebc08 100644 --- a/ffi/include/tvm/ffi/extra/module.h +++ b/ffi/include/tvm/ffi/extra/module.h @@ -36,6 +36,7 @@ class Module; /*! * \brief A module that can dynamically load ffi::Functions or exportable source code. + * \sa Module */ class TVM_FFI_EXTRA_CXX_API ModuleObj : public Object { public: @@ -168,6 +169,16 @@ class TVM_FFI_EXTRA_CXX_API ModuleObj : public Object { /*! * \brief Reference to module object. + * + * When invoking a function on a ModuleObj, such as GetFunction, + * use operator-> to get the ModuleObj pointer and invoke the member functions. + * + * \code + * ffi::Module mod = ffi::Module::LoadFromFile("path/to/module.so"); + * ffi::Function func = mod->GetFunction(name); + * \endcode + * + * \sa ModuleObj which contains most of the function implementations. */ class Module : public ObjectRef { public: @@ -202,7 +213,11 @@ class Module : public ObjectRef { */ kCompilationExportable = 0b100 }; - + /*! + * \brief Constructor from ObjectPtr. + * \param ptr The object pointer. + */ + explicit Module(ObjectPtr ptr) : ObjectRef(ptr) { TVM_FFI_ICHECK(ptr != nullptr); } /*! * \brief Load a module from file. * \param file_name The name of the host function module. diff --git a/ffi/include/tvm/ffi/function.h b/ffi/include/tvm/ffi/function.h index 884e46fa44cd..d27cfc0b6155 100644 --- a/ffi/include/tvm/ffi/function.h +++ b/ffi/include/tvm/ffi/function.h @@ -403,7 +403,7 @@ class Function : public ObjectRef { TVM_FFI_CHECK_SAFE_CALL(TVMFFIFunctionGetGlobal(&name_arr, &handle)); if (handle != nullptr) { return Function( - details::ObjectUnsafe::ObjectPtrFromOwned(static_cast(handle))); + details::ObjectUnsafe::ObjectPtrFromOwned(static_cast(handle))); } else { return std::nullopt; } diff --git a/ffi/include/tvm/ffi/function_details.h b/ffi/include/tvm/ffi/function_details.h index d029c19dd107..20ca44cbcb72 100644 --- a/ffi/include/tvm/ffi/function_details.h +++ b/ffi/include/tvm/ffi/function_details.h @@ -193,7 +193,7 @@ TVM_FFI_INLINE static Error MoveFromSafeCallRaised() { TVMFFIObjectHandle handle; TVMFFIErrorMoveFromRaised(&handle); // handle is owned by caller - return Error( + return details::ObjectUnsafe::ObjectRefFromObjectPtr( details::ObjectUnsafe::ObjectPtrFromOwned(static_cast(handle))); } diff --git a/ffi/include/tvm/ffi/object.h b/ffi/include/tvm/ffi/object.h index c1ab9d16d919..478bb27a8f20 100644 --- a/ffi/include/tvm/ffi/object.h +++ b/ffi/include/tvm/ffi/object.h @@ -44,6 +44,24 @@ using TypeIndex = TVMFFITypeIndex; */ using TypeInfo = TVMFFITypeInfo; +/*! + * \brief Helper tag to explicitly request unsafe initialization. + * + * Constructing an ObjectRefType with UnsafeInit{} will set the data_ member to nullptr. + * + * When initializing Object fields, ObjectRef fields can be set to UnsafeInit. + * This enables the "construct with UnsafeInit then set all fields" pattern + * when the object does not have a default constructor. + * + * Used for initialization in controlled scenarios where such unsafe + * initialization is known to be safe. + * + * Each ObjectRefType should have a constructor that takes an UnsafeInit tag. + * + * \note As the name suggests, do not use it in normal code paths. + */ +struct UnsafeInit {}; + /*! * \brief Known type keys for pre-defined types. */ @@ -702,6 +720,8 @@ class ObjectRef { ObjectRef& operator=(ObjectRef&& other) = default; /*! \brief Constructor from existing object ptr */ explicit ObjectRef(ObjectPtr data) : data_(data) {} + /*! \brief Constructor from UnsafeInit */ + explicit ObjectRef(UnsafeInit) : data_(nullptr) {} /*! * \brief Comparator * \param other Another object ref. @@ -774,7 +794,9 @@ class ObjectRef { TVM_FFI_INLINE std::optional as() const { if (data_ != nullptr) { if (data_->IsInstance()) { - return ObjectRefType(data_); + ObjectRefType ref(UnsafeInit{}); + ref.data_ = data_; + return ref; } else { return std::nullopt; } @@ -782,6 +804,7 @@ class ObjectRef { return std::nullopt; } } + /*! * \brief Get the type index of the ObjectRef * \return The type index of the ObjectRef @@ -914,7 +937,8 @@ struct ObjectPtrEqual { */ #define TVM_FFI_DEFINE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName) \ TypeName() = default; \ - explicit TypeName(::tvm::ffi::ObjectPtr<::tvm::ffi::Object> n) : ParentType(n) {} \ + explicit TypeName(::tvm::ffi::ObjectPtr n) : ParentType(n) {} \ + explicit TypeName(::tvm::ffi::UnsafeInit tag) : ParentType(tag) {} \ TVM_FFI_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName) \ const ObjectName* operator->() const { return static_cast(data_.get()); } \ const ObjectName* get() const { return operator->(); } \ @@ -928,7 +952,7 @@ struct ObjectPtrEqual { * \param ObjectName The type name of the object. */ #define TVM_FFI_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName) \ - explicit TypeName(::tvm::ffi::ObjectPtr<::tvm::ffi::Object> n) : ParentType(n) {} \ + explicit TypeName(::tvm::ffi::UnsafeInit tag) : ParentType(tag) {} \ TVM_FFI_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName) \ const ObjectName* operator->() const { return static_cast(data_.get()); } \ const ObjectName* get() const { return operator->(); } \ @@ -943,11 +967,12 @@ struct ObjectPtrEqual { * \note We recommend making objects immutable when possible. * This macro is only reserved for objects that stores runtime states. */ -#define TVM_FFI_DEFINE_MUTABLE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName) \ - TypeName() = default; \ - TVM_FFI_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName) \ - explicit TypeName(::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n) : ParentType(n) {} \ - ObjectName* operator->() const { return static_cast(data_.get()); } \ +#define TVM_FFI_DEFINE_MUTABLE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName) \ + TypeName() = default; \ + explicit TypeName(::tvm::ffi::UnsafeInit tag) : ParentType(tag) {} \ + TVM_FFI_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName) \ + explicit TypeName(::tvm::ffi::ObjectPtr n) : ParentType(n) {} \ + ObjectName* operator->() const { return static_cast(data_.get()); } \ using ContainerType = ObjectName /*! @@ -958,7 +983,7 @@ struct ObjectPtrEqual { * \param ObjectName The type name of the object. */ #define TVM_FFI_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName) \ - explicit TypeName(::tvm::ffi::ObjectPtr<::tvm::ffi::Object> n) : ParentType(n) {} \ + explicit TypeName(::tvm::ffi::UnsafeInit tag) : ParentType(tag) {} \ TVM_FFI_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName) \ ObjectName* operator->() const { return static_cast(data_.get()); } \ ObjectName* get() const { return operator->(); } \ @@ -1021,6 +1046,20 @@ struct ObjectUnsafe { reinterpret_cast(&(static_cast(nullptr)->header_))); } + template + TVM_FFI_INLINE static T ObjectRefFromObjectPtr(const ObjectPtr& ptr) { + T ref(UnsafeInit{}); + ref.data_ = ptr; + return ref; + } + + template + TVM_FFI_INLINE static T ObjectRefFromObjectPtr(ObjectPtr&& ptr) { + T ref(UnsafeInit{}); + ref.data_ = std::move(ptr); + return ref; + } + template TVM_FFI_INLINE static ObjectPtr ObjectPtrFromObjectRef(const ObjectRef& ref) { if constexpr (std::is_same_v) { @@ -1035,7 +1074,10 @@ struct ObjectUnsafe { if constexpr (std::is_same_v) { return std::move(ref.data_); } else { - return tvm::ffi::ObjectPtr(std::move(ref.data_.data_)); + ObjectPtr result; + result.data_ = std::move(ref.data_.data_); + ref.data_.data_ = nullptr; + return result; } } diff --git a/ffi/include/tvm/ffi/optional.h b/ffi/include/tvm/ffi/optional.h index f93a0f0d555f..f370a178502e 100644 --- a/ffi/include/tvm/ffi/optional.h +++ b/ffi/include/tvm/ffi/optional.h @@ -262,7 +262,7 @@ class Optional>> : public Object Optional() = default; Optional(const Optional& other) : ObjectRef(other.data_) {} Optional(Optional&& other) : ObjectRef(std::move(other.data_)) {} - explicit Optional(ObjectPtr ptr) : ObjectRef(ptr) {} + explicit Optional(ffi::UnsafeInit tag) : ObjectRef(tag) {} // nullopt hanlding Optional(std::nullopt_t) {} // NOLINT(*) @@ -300,19 +300,20 @@ class Optional>> : public Object if (data_ == nullptr) { TVM_FFI_THROW(RuntimeError) << "Back optional access"; } - return T(data_); + return details::ObjectUnsafe::ObjectRefFromObjectPtr(data_); } TVM_FFI_INLINE T value() && { if (data_ == nullptr) { TVM_FFI_THROW(RuntimeError) << "Back optional access"; } - return T(std::move(data_)); + return details::ObjectUnsafe::ObjectRefFromObjectPtr(std::move(data_)); } template > TVM_FFI_INLINE T value_or(U&& default_value) const { - return data_ != nullptr ? T(data_) : T(std::forward(default_value)); + return data_ != nullptr ? details::ObjectUnsafe::ObjectRefFromObjectPtr(data_) + : T(std::forward(default_value)); } TVM_FFI_INLINE explicit operator bool() const { return data_ != nullptr; } @@ -324,14 +325,18 @@ class Optional>> : public Object * \return the const reference to the stored value. * \note only use this function after checking has_value() */ - TVM_FFI_INLINE T operator*() const& noexcept { return T(data_); } + TVM_FFI_INLINE T operator*() const& noexcept { + return details::ObjectUnsafe::ObjectRefFromObjectPtr(data_); + } /*! * \brief Direct access to the value. * \return the const reference to the stored value. * \note only use this function after checking has_value() */ - TVM_FFI_INLINE T operator*() && noexcept { return T(std::move(data_)); } + TVM_FFI_INLINE T operator*() && noexcept { + return details::ObjectUnsafe::ObjectRefFromObjectPtr(std::move(data_)); + } TVM_FFI_INLINE bool operator==(std::nullptr_t) const noexcept { return !has_value(); } TVM_FFI_INLINE bool operator!=(std::nullptr_t) const noexcept { return has_value(); } diff --git a/ffi/include/tvm/ffi/reflection/access_path.h b/ffi/include/tvm/ffi/reflection/access_path.h index c614d4ca28d8..e7aed0a8fcbf 100644 --- a/ffi/include/tvm/ffi/reflection/access_path.h +++ b/ffi/include/tvm/ffi/reflection/access_path.h @@ -360,6 +360,10 @@ class AccessPath : public ObjectRef { /// \cond Doxygen_Suppress TVM_FFI_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(AccessPath, ObjectRef, AccessPathObj); /// \endcond + + private: + friend class AccessPathObj; + explicit AccessPath(ObjectPtr ptr) : ObjectRef(ptr) {} }; /*! diff --git a/ffi/include/tvm/ffi/reflection/registry.h b/ffi/include/tvm/ffi/reflection/registry.h index ba723fa394d7..6a1a9b55d2b0 100644 --- a/ffi/include/tvm/ffi/reflection/registry.h +++ b/ffi/include/tvm/ffi/reflection/registry.h @@ -148,6 +148,14 @@ class ReflectionDefBase { TVM_FFI_SAFE_CALL_END(); } + template + static int ObjectCreatorUnsafeInit(TVMFFIObjectHandle* result) { + TVM_FFI_SAFE_CALL_BEGIN(); + ObjectPtr obj = make_object(UnsafeInit{}); + *result = details::ObjectUnsafe::MoveObjectPtrToTVMFFIObjectPtr(std::move(obj)); + TVM_FFI_SAFE_CALL_END(); + } + template TVM_FFI_INLINE static void ApplyFieldInfoTrait(TVMFFIFieldInfo* info, const T& value) { if constexpr (std::is_base_of_v>) { @@ -413,6 +421,8 @@ class ObjectDef : public ReflectionDefBase { info.doc = TVMFFIByteArray{nullptr, 0}; if constexpr (std::is_default_constructible_v) { info.creator = ObjectCreatorDefault; + } else if constexpr (std::is_constructible_v) { + info.creator = ObjectCreatorUnsafeInit; } // apply extra info traits ((ApplyExtraInfoTrait(&info, std::forward(extra_args)), ...)); diff --git a/ffi/include/tvm/ffi/rvalue_ref.h b/ffi/include/tvm/ffi/rvalue_ref.h index 7c89038cc24e..ebbec582e62a 100644 --- a/ffi/include/tvm/ffi/rvalue_ref.h +++ b/ffi/include/tvm/ffi/rvalue_ref.h @@ -71,15 +71,17 @@ namespace ffi { template >> class RValueRef { public: + /*! \brief the container type of the rvalue ref */ + using ContainerType = typename TObjRef::ContainerType; /*! \brief only allow move constructor from rvalue of T */ explicit RValueRef(TObjRef&& data) - : data_(details::ObjectUnsafe::ObjectPtrFromObjectRef(std::move(data))) {} + : data_(details::ObjectUnsafe::ObjectPtrFromObjectRef(std::move(data))) {} /*! \brief return the data as rvalue */ TObjRef operator*() && { return TObjRef(std::move(data_)); } private: - mutable ObjectPtr data_; + mutable ObjectPtr data_; template friend struct TypeTraits; @@ -125,7 +127,8 @@ struct TypeTraits> : public TypeTraitsBase { tmp_any.v_obj = reinterpret_cast(rvalue_ref->get()); // fast path, storage type matches, direct move the rvalue ref if (TypeTraits::CheckAnyStrict(&tmp_any)) { - return RValueRef(TObjRef(std::move(*rvalue_ref))); + return RValueRef( + details::ObjectUnsafe::ObjectRefFromObjectPtr(std::move(*rvalue_ref))); } if (std::optional opt = TypeTraits::TryCastFromAnyView(&tmp_any)) { // object type does not match up, we need to try to convert the object diff --git a/ffi/include/tvm/ffi/type_traits.h b/ffi/include/tvm/ffi/type_traits.h index 1812448ecc09..0f1971945a4b 100644 --- a/ffi/include/tvm/ffi/type_traits.h +++ b/ffi/include/tvm/ffi/type_traits.h @@ -551,34 +551,37 @@ struct ObjectRefTypeTraitsBase : public TypeTraitsBase { TVM_FFI_INLINE static TObjRef CopyFromAnyViewAfterCheck(const TVMFFIAny* src) { if constexpr (TObjRef::_type_is_nullable) { if (src->type_index == TypeIndex::kTVMFFINone) { - return TObjRef(ObjectPtr(nullptr)); + return details::ObjectUnsafe::ObjectRefFromObjectPtr(nullptr); } } - return TObjRef(details::ObjectUnsafe::ObjectPtrFromUnowned(src->v_obj)); + return details::ObjectUnsafe::ObjectRefFromObjectPtr( + details::ObjectUnsafe::ObjectPtrFromUnowned(src->v_obj)); } TVM_FFI_INLINE static TObjRef MoveFromAnyAfterCheck(TVMFFIAny* src) { if constexpr (TObjRef::_type_is_nullable) { if (src->type_index == TypeIndex::kTVMFFINone) { - return TObjRef(ObjectPtr(nullptr)); + return details::ObjectUnsafe::ObjectRefFromObjectPtr(nullptr); } } // move out the object pointer - ObjectPtr obj_ptr = details::ObjectUnsafe::ObjectPtrFromOwned(src->v_obj); + ObjectPtr obj_ptr = + details::ObjectUnsafe::ObjectPtrFromOwned(src->v_obj); // reset the src to nullptr TypeTraits::MoveToAny(nullptr, src); - return TObjRef(std::move(obj_ptr)); + return details::ObjectUnsafe::ObjectRefFromObjectPtr(std::move(obj_ptr)); } TVM_FFI_INLINE static std::optional TryCastFromAnyView(const TVMFFIAny* src) { if constexpr (TObjRef::_type_is_nullable) { if (src->type_index == TypeIndex::kTVMFFINone) { - return TObjRef(ObjectPtr(nullptr)); + return details::ObjectUnsafe::ObjectRefFromObjectPtr(nullptr); } } if (src->type_index >= TypeIndex::kTVMFFIStaticObjectBegin) { if (details::IsObjectInstance(src->type_index)) { - return TObjRef(details::ObjectUnsafe::ObjectPtrFromUnowned(src->v_obj)); + return details::ObjectUnsafe::ObjectRefFromObjectPtr( + details::ObjectUnsafe::ObjectPtrFromUnowned(src->v_obj)); } } return std::nullopt; diff --git a/ffi/src/ffi/tensor.cc b/ffi/src/ffi/tensor.cc index 7b44e4586b4b..c166c296c8a4 100644 --- a/ffi/src/ffi/tensor.cc +++ b/ffi/src/ffi/tensor.cc @@ -40,7 +40,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ TVM_FFI_THROW(ValueError) << "Expect shape to take list of int arguments"; } } - *ret = Shape(shape); + *ret = details::ObjectUnsafe::ObjectRefFromObjectPtr(shape); }); }); diff --git a/ffi/tests/cpp/test_object.cc b/ffi/tests/cpp/test_object.cc index 1d7de990f01a..ec5c54c4d77a 100644 --- a/ffi/tests/cpp/test_object.cc +++ b/ffi/tests/cpp/test_object.cc @@ -97,6 +97,14 @@ TEST(ObjectRef, as) { EXPECT_EQ(b.as()->value, 20); } +TEST(ObjectRef, UnsafeInit) { + ObjectRef a(UnsafeInit{}); + EXPECT_TRUE(a.get() == nullptr); + + TInt b(UnsafeInit{}); + EXPECT_TRUE(b.get() == nullptr); +} + TEST(Object, CAPIAccessor) { ObjectRef a = TInt(10); TVMFFIObjectHandle obj = details::ObjectUnsafe::RawObjectPtrFromObjectRef(a); diff --git a/ffi/tests/cpp/testing_object.h b/ffi/tests/cpp/testing_object.h index fe3ba1b013c0..1f6e67822641 100644 --- a/ffi/tests/cpp/testing_object.h +++ b/ffi/tests/cpp/testing_object.h @@ -59,8 +59,8 @@ class TIntObj : public TNumberObj { public: int64_t value; - TIntObj() = default; TIntObj(int64_t value) : value(value) {} + explicit TIntObj(UnsafeInit) {} int64_t GetValue() const { return value; } @@ -165,9 +165,9 @@ class TVarObj : public Object { public: std::string name; - // need default constructor for json serialization - TVarObj() = default; TVarObj(std::string name) : name(name) {} + // need unsafe init constructor for json serialization + explicit TVarObj(UnsafeInit) {} static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -193,8 +193,8 @@ class TFuncObj : public Object { Array body; Optional comment; - // need default constructor for json serialization - TFuncObj() = default; + // need unsafe init constructor or default constructor for json serialization + explicit TFuncObj(UnsafeInit) {} TFuncObj(Array params, Array body, Optional comment) : params(params), body(body), comment(comment) {} diff --git a/include/tvm/ir/attrs.h b/include/tvm/ir/attrs.h index 55576549169c..5c02db36f72e 100644 --- a/include/tvm/ir/attrs.h +++ b/include/tvm/ir/attrs.h @@ -54,7 +54,7 @@ namespace tvm { template inline TObjectRef NullValue() { static_assert(TObjectRef::_type_is_nullable, "Can only get NullValue for nullable types"); - return TObjectRef(ObjectPtr(nullptr)); + return TObjectRef(ObjectPtr(nullptr)); } template <> @@ -165,6 +165,10 @@ class DictAttrsNode : public BaseAttrsNode { */ class DictAttrs : public Attrs { public: + /*! + * \brief constructor with UnsafeInit + */ + explicit DictAttrs(ffi::UnsafeInit tag) : Attrs(tag) {} /*! * \brief Consruct a Attrs backed by DictAttrsNode. * \param dict The attributes. diff --git a/include/tvm/ir/env_func.h b/include/tvm/ir/env_func.h index e43575d486eb..e42cce527900 100644 --- a/include/tvm/ir/env_func.h +++ b/include/tvm/ir/env_func.h @@ -71,6 +71,10 @@ class EnvFunc : public ObjectRef { public: EnvFunc() {} explicit EnvFunc(ObjectPtr n) : ObjectRef(n) {} + /*! + * \brief constructor with UnsafeInit + */ + explicit EnvFunc(ffi::UnsafeInit tag) : ObjectRef(tag) {} /*! \return The internal global function pointer */ const EnvFuncNode* operator->() const { return static_cast(get()); } /*! @@ -117,6 +121,10 @@ class TypedEnvFunc : public ObjectRef { using TSelf = TypedEnvFunc; TypedEnvFunc() {} explicit TypedEnvFunc(ObjectPtr n) : ObjectRef(n) {} + /*! + * \brief constructor with UnsafeInit + */ + explicit TypedEnvFunc(ffi::UnsafeInit tag) : ObjectRef(tag) {} /*! * \brief Assign global function to a TypedEnvFunc * \param other Another global function. diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h index 65954b83ac9d..d7e4e0f0d2ef 100644 --- a/include/tvm/ir/expr.h +++ b/include/tvm/ir/expr.h @@ -613,7 +613,11 @@ class Integer : public IntImm { /*! * \brief constructor from node. */ - explicit Integer(ObjectPtr node) : IntImm(node) {} + explicit Integer(ObjectPtr node) : IntImm(node) {} + /*! + * \brief constructor with UnsafeInit + */ + explicit Integer(ffi::UnsafeInit tag) : IntImm(tag) {} /*! * \brief Construct integer from int value. */ diff --git a/include/tvm/ir/module.h b/include/tvm/ir/module.h index 5da00fb0b377..3deef6fed1f1 100644 --- a/include/tvm/ir/module.h +++ b/include/tvm/ir/module.h @@ -273,7 +273,11 @@ class IRModule : public ObjectRef { * \brief constructor * \param n The object pointer. */ - explicit IRModule(ObjectPtr n) : ObjectRef(n) {} + explicit IRModule(ObjectPtr n) : ObjectRef(n) {} + /*! + * \brief constructor with UnsafeInit + */ + explicit IRModule(ffi::UnsafeInit tag) : ObjectRef(tag) {} /*! \return mutable pointers to the node. */ IRModuleNode* operator->() const { auto* ptr = get_mutable(); diff --git a/include/tvm/ir/transform.h b/include/tvm/ir/transform.h index e501ace15997..e283234cb071 100644 --- a/include/tvm/ir/transform.h +++ b/include/tvm/ir/transform.h @@ -156,7 +156,14 @@ class PassContextNode : public Object { class PassContext : public ObjectRef { public: PassContext() {} - explicit PassContext(ObjectPtr n) : ObjectRef(n) {} + /*! + * \brief constructor with UnsafeInit + */ + explicit PassContext(ffi::UnsafeInit tag) : ObjectRef(tag) {} + /*! + * \brief constructor with ObjectPtr + */ + explicit PassContext(ObjectPtr n) : ObjectRef(n) {} /*! * \brief const accessor. * \return const access pointer. @@ -512,7 +519,7 @@ class Sequential : public Pass { TVM_DLL Sequential(ffi::Array passes, ffi::String name = "sequential"); Sequential() = default; - explicit Sequential(ObjectPtr n) : Pass(n) {} + explicit Sequential(ObjectPtr n) : Pass(n) {} const SequentialNode* operator->() const; using ContainerType = SequentialNode; diff --git a/include/tvm/meta_schedule/builder.h b/include/tvm/meta_schedule/builder.h index 6a6df2950271..0a527ad42585 100644 --- a/include/tvm/meta_schedule/builder.h +++ b/include/tvm/meta_schedule/builder.h @@ -136,6 +136,13 @@ class BuilderNode : public runtime::Object { */ class Builder : public runtime::ObjectRef { public: + /*! + * \brief Constructor from ObjectPtr. + * \param data The object pointer. + */ + explicit Builder(ObjectPtr data) : ObjectRef(data) { + TVM_FFI_ICHECK(data != nullptr); + } /*! * \brief Create a builder with customized build method on the python-side. * \param f_build The packed function to the `Build` function.. diff --git a/include/tvm/meta_schedule/database.h b/include/tvm/meta_schedule/database.h index fbb09d7852c6..07686077311a 100644 --- a/include/tvm/meta_schedule/database.h +++ b/include/tvm/meta_schedule/database.h @@ -71,6 +71,7 @@ class WorkloadNode : public runtime::Object { class Workload : public runtime::ObjectRef { public: using THashCode = WorkloadNode::THashCode; + explicit Workload(ObjectPtr data) : ObjectRef(data) {} /*! * \brief Constructor of Workload. * \param mod The workload's IRModule. @@ -117,7 +118,7 @@ class TuningRecordNode : public runtime::Object { /*! \brief The trace tuned. */ tir::Trace trace; /*! \brief The workload. */ - Workload workload{nullptr}; + Workload workload{ffi::UnsafeInit()}; /*! \brief The profiling result in seconds. */ ffi::Optional> run_secs; /*! \brief The target for tuning. */ @@ -466,6 +467,13 @@ class PyDatabaseNode : public DatabaseNode { */ class Database : public runtime::ObjectRef { public: + /*! + * \brief Constructor from ObjectPtr. + * \param data The object pointer. + */ + explicit Database(ObjectPtr data) : ObjectRef(data) { + TVM_FFI_ICHECK(data != nullptr); + } /*! * \brief An in-memory database. * \param mod_eq_name A string to specify the module equality testing and hashing method. diff --git a/include/tvm/meta_schedule/runner.h b/include/tvm/meta_schedule/runner.h index 2d42b5e590d4..f2753964ec63 100644 --- a/include/tvm/meta_schedule/runner.h +++ b/include/tvm/meta_schedule/runner.h @@ -207,7 +207,11 @@ class RunnerNode : public runtime::Object { class Runner : public runtime::ObjectRef { public: using FRun = RunnerNode::FRun; - + /*! + * \brief Constructor from ObjectPtr. + * \param data The object pointer. + */ + explicit Runner(ObjectPtr data) : ObjectRef(data) { TVM_FFI_ICHECK(data != nullptr); } /*! * \brief Create a runner with customized build method on the python-side. * \param f_run The packed function to run the built artifacts and get runner futures. diff --git a/include/tvm/meta_schedule/space_generator.h b/include/tvm/meta_schedule/space_generator.h index f013934e2342..a2bf7a394932 100644 --- a/include/tvm/meta_schedule/space_generator.h +++ b/include/tvm/meta_schedule/space_generator.h @@ -123,6 +123,13 @@ class SpaceGeneratorNode : public runtime::Object { */ class SpaceGenerator : public runtime::ObjectRef { public: + /*! + * \brief Constructor from ObjectPtr. + * \param data The object pointer. + */ + explicit SpaceGenerator(ObjectPtr data) : ObjectRef(data) { + TVM_FFI_ICHECK(data != nullptr); + } /*! * \brief The function type of `InitializeWithTuneContext` method. * \param context The tuning context for initialization. diff --git a/include/tvm/meta_schedule/task_scheduler.h b/include/tvm/meta_schedule/task_scheduler.h index 0c88cb12c8cc..a6a53becad00 100644 --- a/include/tvm/meta_schedule/task_scheduler.h +++ b/include/tvm/meta_schedule/task_scheduler.h @@ -40,7 +40,7 @@ namespace meta_schedule { class TaskRecordNode : public runtime::Object { public: /*! \brief The tune context of the task. */ - TuneContext ctx{nullptr}; + TuneContext ctx{ffi::UnsafeInit()}; /*! \brief The weight of the task */ double task_weight{1.0}; /*! \brief The FLOP count of the task */ @@ -261,6 +261,9 @@ class PyTaskSchedulerNode : public TaskSchedulerNode { */ class TaskScheduler : public runtime::ObjectRef { public: + explicit TaskScheduler(ObjectPtr data) : runtime::ObjectRef(data) { + TVM_FFI_ICHECK(data != nullptr); + } /*! * \brief Create a task scheduler that fetches tasks in a round-robin fashion. * \param logger The tuning task's logging function. diff --git a/include/tvm/meta_schedule/tune_context.h b/include/tvm/meta_schedule/tune_context.h index cd9b8f1b5ad2..50bdb2586fc6 100644 --- a/include/tvm/meta_schedule/tune_context.h +++ b/include/tvm/meta_schedule/tune_context.h @@ -98,6 +98,13 @@ class TuneContextNode : public runtime::Object { class TuneContext : public runtime::ObjectRef { public: using TRandState = support::LinearCongruentialEngine::TRandState; + /*! + * \brief Constructor from ObjectPtr. + * \param data The object pointer. + */ + explicit TuneContext(ObjectPtr data) : ObjectRef(data) { + TVM_FFI_ICHECK(data != nullptr); + } /*! * \brief Constructor. * \param mod The workload to be tuned. diff --git a/include/tvm/node/cast.h b/include/tvm/node/cast.h index 4ed5f4178c8b..32d4be721656 100644 --- a/include/tvm/node/cast.h +++ b/include/tvm/node/cast.h @@ -45,18 +45,19 @@ namespace tvm { template >> inline SubRef Downcast(BaseRef ref) { + using ContainerType = typename SubRef::ContainerType; if (ref.defined()) { - if (!ref->template IsInstance()) { + if (!ref->template IsInstance()) { TVM_FFI_THROW(TypeError) << "Downcast from " << ref->GetTypeKey() << " to " << SubRef::ContainerType::_type_key << " failed."; } - return SubRef(ffi::details::ObjectUnsafe::ObjectPtrFromObjectRef(std::move(ref))); + return ffi::details::ObjectUnsafe::ObjectRefFromObjectPtr( + ffi::details::ObjectUnsafe::ObjectPtrFromObjectRef(std::move(ref))); } else { if constexpr (ffi::is_optional_type_v || SubRef::_type_is_nullable) { - return SubRef(ffi::ObjectPtr(nullptr)); + return ffi::details::ObjectUnsafe::ObjectRefFromObjectPtr(nullptr); } - TVM_FFI_THROW(TypeError) << "Downcast from undefined(nullptr) to `" - << SubRef::ContainerType::_type_key + TVM_FFI_THROW(TypeError) << "Downcast from undefined(nullptr) to `" << ContainerType::_type_key << "` is not allowed. Use Downcast> instead."; TVM_FFI_UNREACHABLE(); } diff --git a/include/tvm/relax/dataflow_pattern.h b/include/tvm/relax/dataflow_pattern.h index 4a7fd73c6ac0..7c4ee4e43e57 100644 --- a/include/tvm/relax/dataflow_pattern.h +++ b/include/tvm/relax/dataflow_pattern.h @@ -280,6 +280,7 @@ class PatternContextNode : public Object { */ class PatternContext : public ObjectRef { public: + explicit PatternContext(ffi::UnsafeInit tag) : ObjectRef(tag) {} TVM_DLL explicit PatternContext(ObjectPtr n) : ObjectRef(n) {} TVM_DLL explicit PatternContext(bool incremental = false); @@ -778,6 +779,10 @@ class WildcardPatternNode : public DFPatternNode { class WildcardPattern : public DFPattern { public: WildcardPattern(); + explicit WildcardPattern(ObjectPtr data) : DFPattern(ffi::UnsafeInit{}) { + TVM_FFI_ICHECK(data != nullptr); + data_ = std::move(data); + } // Declaring WildcardPattern declared as non-nullable avoids the // default zero-parameter constructor for ObjectRef with `data_ = diff --git a/include/tvm/relax/expr.h b/include/tvm/relax/expr.h index e0e2f4770fe9..80fe1e671091 100644 --- a/include/tvm/relax/expr.h +++ b/include/tvm/relax/expr.h @@ -607,7 +607,8 @@ class Binding : public ObjectRef { Binding() = default; public: - explicit Binding(ObjectPtr n) : ObjectRef(n) {} + explicit Binding(ObjectPtr n) : ObjectRef(n) {} + explicit Binding(ffi::UnsafeInit tag) : ObjectRef(tag) {} TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(Binding); const BindingNode* operator->() const { return static_cast(data_.get()); } const BindingNode* get() const { return operator->(); } diff --git a/include/tvm/relax/struct_info.h b/include/tvm/relax/struct_info.h index 8a97658330df..059292806de4 100644 --- a/include/tvm/relax/struct_info.h +++ b/include/tvm/relax/struct_info.h @@ -27,6 +27,8 @@ #include #include +#include + namespace tvm { namespace relax { @@ -317,6 +319,10 @@ class FuncStructInfoNode : public StructInfoNode { */ class FuncStructInfo : public StructInfo { public: + explicit FuncStructInfo(ObjectPtr data) : StructInfo(ffi::UnsafeInit{}) { + TVM_FFI_ICHECK(data != nullptr); + data_ = std::move(data); + } /*! * \brief Constructor from parameter struct info and return value struct info. * \param params The struct info of function parameters. diff --git a/include/tvm/runtime/disco/session.h b/include/tvm/runtime/disco/session.h index 1506d2548f1f..671e4bbd67f7 100644 --- a/include/tvm/runtime/disco/session.h +++ b/include/tvm/runtime/disco/session.h @@ -170,6 +170,7 @@ class DRefObj : public Object { */ class DRef : public ObjectRef { public: + explicit DRef(ObjectPtr data) : ObjectRef(data) { TVM_FFI_ICHECK(data != nullptr); } TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(DRef, ObjectRef, DRefObj); }; diff --git a/include/tvm/runtime/object.h b/include/tvm/runtime/object.h index e04a800400f1..cf5d93eae64e 100644 --- a/include/tvm/runtime/object.h +++ b/include/tvm/runtime/object.h @@ -128,7 +128,7 @@ static_assert(static_cast(TypeIndex::kCustomStaticIndex) >= */ #define TVM_DEFINE_OBJECT_REF_METHODS_WITHOUT_DEFAULT_CONSTRUCTOR(TypeName, ParentType, \ ObjectName) \ - explicit TypeName(::tvm::ffi::ObjectPtr<::tvm::ffi::Object> n) : ParentType(n) {} \ + explicit TypeName(::tvm::ffi::ObjectPtr n) : ParentType(n) {} \ TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName); \ const ObjectName* operator->() const { return static_cast(data_.get()); } \ const ObjectName* get() const { return operator->(); } \ diff --git a/include/tvm/runtime/tensor.h b/include/tvm/runtime/tensor.h index 71f8d27be008..97af218a1809 100644 --- a/include/tvm/runtime/tensor.h +++ b/include/tvm/runtime/tensor.h @@ -58,7 +58,8 @@ class Tensor : public tvm::ffi::Tensor { * \brief constructor. * \param data ObjectPtr to the data container. */ - explicit Tensor(ObjectPtr data) : tvm::ffi::Tensor(data) {} + explicit Tensor(ObjectPtr data) : tvm::ffi::Tensor(data) {} + explicit Tensor(ffi::UnsafeInit tag) : tvm::ffi::Tensor(tag) {} Tensor(ffi::Tensor&& other) : tvm::ffi::Tensor(std::move(other)) {} // NOLINT(*) Tensor(const ffi::Tensor& other) : tvm::ffi::Tensor(other) {} // NOLINT(*) diff --git a/include/tvm/script/ir_builder/base.h b/include/tvm/script/ir_builder/base.h index b2586e938719..75e6fd8061ea 100644 --- a/include/tvm/script/ir_builder/base.h +++ b/include/tvm/script/ir_builder/base.h @@ -107,6 +107,7 @@ class IRBuilderFrame : public runtime::ObjectRef { protected: /*! \brief Disallow direct construction of this object. */ IRBuilderFrame() = default; + explicit IRBuilderFrame(ObjectPtr data) : ObjectRef(data) {} public: /*! diff --git a/include/tvm/script/ir_builder/ir/frame.h b/include/tvm/script/ir_builder/ir/frame.h index e9f98d4a8ea6..767986fdf77f 100644 --- a/include/tvm/script/ir_builder/ir/frame.h +++ b/include/tvm/script/ir_builder/ir/frame.h @@ -75,6 +75,9 @@ class IRModuleFrameNode : public IRBuilderFrameNode { */ class IRModuleFrame : public IRBuilderFrame { public: + explicit IRModuleFrame(ObjectPtr data) : IRBuilderFrame(data) { + TVM_FFI_ICHECK(data != nullptr); + } TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(IRModuleFrame, IRBuilderFrame, IRModuleFrameNode); }; diff --git a/include/tvm/script/ir_builder/relax/frame.h b/include/tvm/script/ir_builder/relax/frame.h index 053f84285f6e..7ea8c439bf37 100644 --- a/include/tvm/script/ir_builder/relax/frame.h +++ b/include/tvm/script/ir_builder/relax/frame.h @@ -26,6 +26,8 @@ #include #include +#include + namespace tvm { namespace script { namespace ir_builder { @@ -45,6 +47,10 @@ class RelaxFrameNode : public IRBuilderFrameNode { class RelaxFrame : public IRBuilderFrame { public: + explicit RelaxFrame(ObjectPtr data) : IRBuilderFrame(ffi::UnsafeInit{}) { + TVM_FFI_ICHECK(data != nullptr); + data_ = std::move(data); + } TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(RelaxFrame, IRBuilderFrame, RelaxFrameNode); protected: @@ -78,6 +84,9 @@ class SeqExprFrameNode : public RelaxFrameNode { class SeqExprFrame : public RelaxFrame { public: + explicit SeqExprFrame(ObjectPtr data) : RelaxFrame(data) { + TVM_FFI_ICHECK(data != nullptr); + } TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(SeqExprFrame, RelaxFrame, SeqExprFrameNode); }; @@ -134,6 +143,9 @@ class FunctionFrameNode : public SeqExprFrameNode { class FunctionFrame : public SeqExprFrame { public: + explicit FunctionFrame(ObjectPtr data) : SeqExprFrame(data) { + TVM_FFI_ICHECK(data != nullptr); + } TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(FunctionFrame, SeqExprFrame, FunctionFrameNode); }; @@ -175,6 +187,9 @@ class BlockFrameNode : public RelaxFrameNode { class BlockFrame : public RelaxFrame { public: + explicit BlockFrame(ObjectPtr data) : RelaxFrame(data) { + TVM_FFI_ICHECK(data != nullptr); + } TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(BlockFrame, RelaxFrame, BlockFrameNode); }; @@ -229,6 +244,9 @@ class IfFrameNode : public RelaxFrameNode { */ class IfFrame : public RelaxFrame { public: + explicit IfFrame(ObjectPtr data) : RelaxFrame(data) { + TVM_FFI_ICHECK(data != nullptr); + } TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(IfFrame, RelaxFrame, IfFrameNode); }; @@ -267,6 +285,9 @@ class ThenFrameNode : public SeqExprFrameNode { */ class ThenFrame : public SeqExprFrame { public: + explicit ThenFrame(ObjectPtr data) : SeqExprFrame(data) { + TVM_FFI_ICHECK(data != nullptr); + } TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(ThenFrame, SeqExprFrame, ThenFrameNode); }; @@ -305,6 +326,9 @@ class ElseFrameNode : public SeqExprFrameNode { */ class ElseFrame : public SeqExprFrame { public: + explicit ElseFrame(ObjectPtr data) : SeqExprFrame(data) { + TVM_FFI_ICHECK(data != nullptr); + } TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(ElseFrame, SeqExprFrame, ElseFrameNode); }; diff --git a/include/tvm/script/ir_builder/tir/frame.h b/include/tvm/script/ir_builder/tir/frame.h index 1c3e19959024..fa42ea9911c7 100644 --- a/include/tvm/script/ir_builder/tir/frame.h +++ b/include/tvm/script/ir_builder/tir/frame.h @@ -23,6 +23,8 @@ #include #include +#include + namespace tvm { namespace script { namespace ir_builder { @@ -58,6 +60,7 @@ class TIRFrame : public IRBuilderFrame { protected: TIRFrame() = default; + explicit TIRFrame(ObjectPtr data) : IRBuilderFrame(data) {} }; /*! @@ -115,6 +118,10 @@ class PrimFuncFrameNode : public TIRFrameNode { */ class PrimFuncFrame : public TIRFrame { public: + explicit PrimFuncFrame(ObjectPtr data) : TIRFrame(ffi::UnsafeInit{}) { + TVM_FFI_ICHECK(data != nullptr); + data_ = std::move(data); + } TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(PrimFuncFrame, TIRFrame, PrimFuncFrameNode); }; @@ -186,6 +193,10 @@ class BlockFrameNode : public TIRFrameNode { class BlockFrame : public TIRFrame { public: + explicit BlockFrame(ObjectPtr data) : TIRFrame(ffi::UnsafeInit{}) { + TVM_FFI_ICHECK(data != nullptr); + data_ = std::move(data); + } TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(BlockFrame, TIRFrame, BlockFrameNode); }; @@ -224,6 +235,10 @@ class BlockInitFrameNode : public TIRFrameNode { */ class BlockInitFrame : public TIRFrame { public: + explicit BlockInitFrame(ObjectPtr data) : TIRFrame(ffi::UnsafeInit{}) { + TVM_FFI_ICHECK(data != nullptr); + data_ = std::move(data); + } TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(BlockInitFrame, TIRFrame, BlockInitFrameNode); }; @@ -277,6 +292,10 @@ class ForFrameNode : public TIRFrameNode { */ class ForFrame : public TIRFrame { public: + explicit ForFrame(ObjectPtr data) : TIRFrame(ffi::UnsafeInit{}) { + TVM_FFI_ICHECK(data != nullptr); + data_ = std::move(data); + } TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(ForFrame, TIRFrame, ForFrameNode); }; @@ -318,6 +337,10 @@ class AssertFrameNode : public TIRFrameNode { */ class AssertFrame : public TIRFrame { public: + explicit AssertFrame(ObjectPtr data) : TIRFrame(ffi::UnsafeInit{}) { + TVM_FFI_ICHECK(data != nullptr); + data_ = std::move(data); + } TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(AssertFrame, TIRFrame, AssertFrameNode); }; @@ -358,6 +381,10 @@ class LetFrameNode : public TIRFrameNode { */ class LetFrame : public TIRFrame { public: + explicit LetFrame(ObjectPtr data) : TIRFrame(ffi::UnsafeInit{}) { + TVM_FFI_ICHECK(data != nullptr); + data_ = std::move(data); + } TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(LetFrame, TIRFrame, LetFrameNode); }; @@ -400,6 +427,10 @@ class LaunchThreadFrameNode : public TIRFrameNode { */ class LaunchThreadFrame : public TIRFrame { public: + explicit LaunchThreadFrame(ObjectPtr data) : TIRFrame(ffi::UnsafeInit{}) { + TVM_FFI_ICHECK(data != nullptr); + data_ = std::move(data); + } TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(LaunchThreadFrame, TIRFrame, LaunchThreadFrameNode); }; @@ -444,6 +475,10 @@ class RealizeFrameNode : public TIRFrameNode { */ class RealizeFrame : public TIRFrame { public: + explicit RealizeFrame(ObjectPtr data) : TIRFrame(ffi::UnsafeInit{}) { + TVM_FFI_ICHECK(data != nullptr); + data_ = std::move(data); + } TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(RealizeFrame, TIRFrame, RealizeFrameNode); }; @@ -496,6 +531,10 @@ class AllocateFrameNode : public TIRFrameNode { */ class AllocateFrame : public TIRFrame { public: + explicit AllocateFrame(ObjectPtr data) : TIRFrame(ffi::UnsafeInit{}) { + TVM_FFI_ICHECK(data != nullptr); + data_ = std::move(data); + } TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(AllocateFrame, TIRFrame, AllocateFrameNode); }; @@ -545,6 +584,11 @@ class AllocateConstFrameNode : public TIRFrameNode { */ class AllocateConstFrame : public TIRFrame { public: + explicit AllocateConstFrame(ObjectPtr data) + : TIRFrame(ffi::UnsafeInit{}) { + TVM_FFI_ICHECK(data != nullptr); + data_ = std::move(data); + } TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(AllocateConstFrame, TIRFrame, AllocateConstFrameNode); }; @@ -588,6 +632,10 @@ class AttrFrameNode : public TIRFrameNode { */ class AttrFrame : public TIRFrame { public: + explicit AttrFrame(ObjectPtr data) : TIRFrame(ffi::UnsafeInit{}) { + TVM_FFI_ICHECK(data != nullptr); + data_ = std::move(data); + } TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(AttrFrame, TIRFrame, AttrFrameNode); }; @@ -624,6 +672,10 @@ class WhileFrameNode : public TIRFrameNode { */ class WhileFrame : public TIRFrame { public: + explicit WhileFrame(ObjectPtr data) : TIRFrame(ffi::UnsafeInit{}) { + TVM_FFI_ICHECK(data != nullptr); + data_ = std::move(data); + } TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(WhileFrame, TIRFrame, WhileFrameNode); }; @@ -667,6 +719,9 @@ class IfFrameNode : public TIRFrameNode { */ class IfFrame : public TIRFrame { public: + explicit IfFrame(ObjectPtr data) : TIRFrame(data) { + TVM_FFI_ICHECK(data != nullptr); + } TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(IfFrame, TIRFrame, IfFrameNode); }; @@ -705,6 +760,9 @@ class ThenFrameNode : public TIRFrameNode { */ class ThenFrame : public TIRFrame { public: + explicit ThenFrame(ObjectPtr data) : TIRFrame(data) { + TVM_FFI_ICHECK(data != nullptr); + } TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(ThenFrame, TIRFrame, ThenFrameNode); }; @@ -743,6 +801,10 @@ class ElseFrameNode : public TIRFrameNode { */ class ElseFrame : public TIRFrame { public: + explicit ElseFrame(ObjectPtr data) : TIRFrame(data) { + TVM_FFI_ICHECK(data != nullptr); + } + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(ElseFrame, TIRFrame, ElseFrameNode); }; @@ -769,6 +831,9 @@ class DeclBufferFrameNode : public TIRFrameNode { class DeclBufferFrame : public TIRFrame { public: + explicit DeclBufferFrame(ObjectPtr data) : TIRFrame(data) { + TVM_FFI_ICHECK(data != nullptr); + } TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(DeclBufferFrame, TIRFrame, DeclBufferFrameNode); }; diff --git a/include/tvm/script/printer/doc.h b/include/tvm/script/printer/doc.h index 976e3183a16e..296df345246a 100644 --- a/include/tvm/script/printer/doc.h +++ b/include/tvm/script/printer/doc.h @@ -88,6 +88,7 @@ class DocNode : public Object { class Doc : public ObjectRef { protected: Doc() = default; + explicit Doc(ObjectPtr data) : ObjectRef(data) {} public: TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Doc, ObjectRef, DocNode); @@ -156,6 +157,8 @@ class ExprDoc : public Doc { */ ExprDoc operator[](ffi::Array indices) const; + explicit ExprDoc(ObjectPtr data) : Doc(data) { TVM_FFI_ICHECK(data != nullptr); } + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ExprDoc, Doc, ExprDocNode); }; @@ -378,7 +381,7 @@ class IdDoc : public ExprDoc { class AttrAccessDocNode : public ExprDocNode { public: /*! \brief The target expression to be accessed */ - ExprDoc value{nullptr}; + ExprDoc value{ffi::UnsafeInit()}; /*! \brief The attribute to be accessed */ ffi::String name; @@ -418,7 +421,7 @@ class AttrAccessDoc : public ExprDoc { class IndexDocNode : public ExprDocNode { public: /*! \brief The container value to be accessed */ - ExprDoc value{nullptr}; + ExprDoc value{ffi::UnsafeInit()}; /*! * \brief The indices to access * @@ -464,7 +467,7 @@ class IndexDoc : public ExprDoc { class CallDocNode : public ExprDocNode { public: /*! \brief The callee of this function call */ - ExprDoc callee{nullptr}; + ExprDoc callee{ffi::UnsafeInit()}; /*! \brief The positional arguments */ ffi::Array args; /*! \brief The keys of keyword arguments */ @@ -604,7 +607,7 @@ class LambdaDocNode : public ExprDocNode { /*! \brief The arguments of this anonymous function */ ffi::Array args; /*! \brief The body of this anonymous function */ - ExprDoc body{nullptr}; + ExprDoc body{ffi::UnsafeInit()}; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -664,7 +667,7 @@ class TupleDoc : public ExprDoc { /*! * \brief Create an empty TupleDoc */ - TupleDoc() : TupleDoc(ffi::make_object()) {} + TupleDoc() : ExprDoc(ffi::make_object()) {} /*! * \brief Constructor of TupleDoc * \param elements Elements of tuple. @@ -703,7 +706,7 @@ class ListDoc : public ExprDoc { /*! * \brief Create an empty ListDoc */ - ListDoc() : ListDoc(ffi::make_object()) {} + ListDoc() : ExprDoc(ffi::make_object()) {} /*! * \brief Constructor of ListDoc * \param elements Elements of list. @@ -751,7 +754,7 @@ class DictDoc : public ExprDoc { /*! * \brief Create an empty dictionary */ - DictDoc() : DictDoc(ffi::make_object()) {} + DictDoc() : ExprDoc(ffi::make_object()) {} /*! * \brief Constructor of DictDoc * \param keys Keys of dictionary. @@ -816,7 +819,7 @@ class SliceDoc : public Doc { class AssignDocNode : public StmtDocNode { public: /*! \brief The left hand side of the assignment */ - ExprDoc lhs{nullptr}; + ExprDoc lhs{ffi::UnsafeInit()}; /*! * \brief The right hand side of the assignment. * @@ -864,7 +867,7 @@ class AssignDoc : public StmtDoc { class IfDocNode : public StmtDocNode { public: /*! \brief The predicate of the if-then-else statement. */ - ExprDoc predicate{nullptr}; + ExprDoc predicate{ffi::UnsafeInit()}; /*! \brief The then branch of the if-then-else statement. */ ffi::Array then_branch; /*! \brief The else branch of the if-then-else statement. */ @@ -909,7 +912,7 @@ class IfDoc : public StmtDoc { class WhileDocNode : public StmtDocNode { public: /*! \brief The predicate of the while statement. */ - ExprDoc predicate{nullptr}; + ExprDoc predicate{ffi::UnsafeInit()}; /*! \brief The body of the while statement. */ ffi::Array body; @@ -953,9 +956,9 @@ class WhileDoc : public StmtDoc { class ForDocNode : public StmtDocNode { public: /*! \brief The left hand side of the assignment of iterating variable. */ - ExprDoc lhs{nullptr}; + ExprDoc lhs{ffi::UnsafeInit()}; /*! \brief The right hand side of the assignment of iterating variable. */ - ExprDoc rhs{nullptr}; + ExprDoc rhs{ffi::UnsafeInit()}; /*! \brief The body of the for statement. */ ffi::Array body; @@ -1004,7 +1007,7 @@ class ScopeDocNode : public StmtDocNode { /*! \brief The name of the scoped variable. */ ffi::Optional lhs{std::nullopt}; /*! \brief The value of the scoped variable. */ - ExprDoc rhs{nullptr}; + ExprDoc rhs{ffi::UnsafeInit()}; /*! \brief The body of the scope doc. */ ffi::Array body; @@ -1054,7 +1057,7 @@ class ScopeDoc : public StmtDoc { class ExprStmtDocNode : public StmtDocNode { public: /*! \brief The expression represented by this doc. */ - ExprDoc expr{nullptr}; + ExprDoc expr{ffi::UnsafeInit()}; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -1089,7 +1092,7 @@ class ExprStmtDoc : public StmtDoc { class AssertDocNode : public StmtDocNode { public: /*! \brief The expression to test. */ - ExprDoc test{nullptr}; + ExprDoc test{ffi::UnsafeInit()}; /*! \brief The optional error message when assertion failed. */ ffi::Optional msg{std::nullopt}; @@ -1129,7 +1132,7 @@ class AssertDoc : public StmtDoc { class ReturnDocNode : public StmtDocNode { public: /*! \brief The value to return. */ - ExprDoc value{nullptr}; + ExprDoc value{ffi::UnsafeInit()}; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -1164,7 +1167,7 @@ class ReturnDoc : public StmtDoc { class FunctionDocNode : public StmtDocNode { public: /*! \brief The name of function. */ - IdDoc name{nullptr}; + IdDoc name{ffi::UnsafeInit{}}; /*! * \brief The arguments of function. * @@ -1223,7 +1226,7 @@ class FunctionDoc : public StmtDoc { class ClassDocNode : public StmtDocNode { public: /*! \brief The name of class. */ - IdDoc name{nullptr}; + IdDoc name{ffi::UnsafeInit{}}; /*! \brief Decorators of class. */ ffi::Array decorators; /*! \brief The body of class. */ diff --git a/include/tvm/script/printer/ir_docsifier.h b/include/tvm/script/printer/ir_docsifier.h index 6e6be57f9ce5..a2fc1097ac36 100644 --- a/include/tvm/script/printer/ir_docsifier.h +++ b/include/tvm/script/printer/ir_docsifier.h @@ -132,7 +132,7 @@ class IRDocsifierNode : public Object { ffi::Optional name; }; /*! \brief The configuration of the printer */ - PrinterConfig cfg{nullptr}; + PrinterConfig cfg{ffi::UnsafeInit()}; /*! * \brief The stack of frames. * \sa FrameNode diff --git a/include/tvm/target/target_kind.h b/include/tvm/target/target_kind.h index ad167ce08bcc..f468f9cbac1b 100644 --- a/include/tvm/target/target_kind.h +++ b/include/tvm/target/target_kind.h @@ -127,6 +127,9 @@ class TargetKindNode : public Object { class TargetKind : public ObjectRef { public: TargetKind() = default; + explicit TargetKind(ObjectPtr data) : ObjectRef(data) { + TVM_FFI_ICHECK(data != nullptr); + } /*! \brief Get the attribute map given the attribute name */ template static inline TargetKindAttrMap GetAttrMap(const ffi::String& attr_name); diff --git a/include/tvm/te/tensor.h b/include/tvm/te/tensor.h index 8bcad6950f4d..68b2bbf71504 100644 --- a/include/tvm/te/tensor.h +++ b/include/tvm/te/tensor.h @@ -50,6 +50,7 @@ class Operation : public ObjectRef { /*! \brief default constructor */ Operation() {} explicit Operation(ObjectPtr n) : ObjectRef(n) {} + explicit Operation(ffi::UnsafeInit tag) : ObjectRef(tag) {} /*! * \brief access the internal node container * \return the pointer to the internal node container diff --git a/include/tvm/tir/block_scope.h b/include/tvm/tir/block_scope.h index 3fc2515d0812..f79a45650045 100644 --- a/include/tvm/tir/block_scope.h +++ b/include/tvm/tir/block_scope.h @@ -297,6 +297,13 @@ class BlockScopeNode : public Object { */ class BlockScope : public ObjectRef { public: + /*! + * \brief Constructor from ObjectPtr. + * \param data The object pointer. + */ + explicit BlockScope(ObjectPtr data) : ObjectRef(data) { + TVM_FFI_ICHECK(data != nullptr); + } /*! \brief The constructor creating an empty block scope with on dependency information */ TVM_DLL BlockScope(); /*! diff --git a/include/tvm/tir/schedule/state.h b/include/tvm/tir/schedule/state.h index 8cb0053df79c..22c4c7d7bd78 100644 --- a/include/tvm/tir/schedule/state.h +++ b/include/tvm/tir/schedule/state.h @@ -43,7 +43,7 @@ namespace tir { */ struct BlockInfo { /*! \brief Property of a block scope rooted at the block, storing dependencies in the scope */ - BlockScope scope{nullptr}; + BlockScope scope{ffi::UnsafeInit()}; // The properties below are information about the current block realization under its parent scope /*! \brief Property of a block, indicating the block realization binding is quasi-affine */ bool affine_binding{false}; diff --git a/include/tvm/tir/var.h b/include/tvm/tir/var.h index 578b00fc08d4..51100c2292e2 100644 --- a/include/tvm/tir/var.h +++ b/include/tvm/tir/var.h @@ -77,7 +77,8 @@ class VarNode : public PrimExprNode { /*! \brief a named variable in TIR */ class Var : public PrimExpr { public: - explicit Var(ObjectPtr n) : PrimExpr(n) {} + explicit Var(ffi::UnsafeInit tag) : PrimExpr(tag) {} + explicit Var(ObjectPtr n) : PrimExpr(n) {} /*! * \brief Constructor * \param name_hint variable name @@ -143,7 +144,8 @@ class SizeVarNode : public VarNode { /*! \brief a named variable represents a tensor index size */ class SizeVar : public Var { public: - explicit SizeVar(ObjectPtr n) : Var(n) {} + explicit SizeVar(ObjectPtr n) : Var(n) {} + explicit SizeVar(ffi::UnsafeInit tag) : Var(tag) {} /*! * \brief constructor * \param name_hint variable name diff --git a/src/contrib/msc/core/printer/msc_doc.h b/src/contrib/msc/core/printer/msc_doc.h index ea1cee396ba6..6433f3de9a2e 100644 --- a/src/contrib/msc/core/printer/msc_doc.h +++ b/src/contrib/msc/core/printer/msc_doc.h @@ -45,7 +45,7 @@ class DeclareDocNode : public ExprDocNode { /*! \brief The type of the variable */ ffi::Optional type; /*! \brief The variable */ - ExprDoc variable{nullptr}; + ExprDoc variable{ffi::UnsafeInit{}}; /*! \brief The init arguments for the variable. */ ffi::Array init_args; /*! \brief Whether to use constructor(otherwise initializer) */ @@ -164,7 +164,7 @@ class PointerDoc : public ExprDoc { class StructDocNode : public StmtDocNode { public: /*! \brief The name of class. */ - IdDoc name{nullptr}; + IdDoc name{ffi::UnsafeInit{}}; /*! \brief Decorators of class. */ ffi::Array decorators; /*! \brief The body of class. */ @@ -207,7 +207,7 @@ class StructDoc : public StmtDoc { class ConstructorDocNode : public StmtDocNode { public: /*! \brief The name of function. */ - IdDoc name{nullptr}; + IdDoc name{ffi::UnsafeInit{}}; /*! * \brief The arguments of function. * @@ -300,7 +300,7 @@ class SwitchDoc : public StmtDoc { class LambdaDocNode : public StmtDocNode { public: /*! \brief The name of lambda. */ - IdDoc name{nullptr}; + IdDoc name{ffi::UnsafeInit{}}; /*! * \brief The arguments of lambda. * diff --git a/src/ir/source_map.cc b/src/ir/source_map.cc index 26fbe07cf6d3..47727d5297a0 100644 --- a/src/ir/source_map.cc +++ b/src/ir/source_map.cc @@ -46,7 +46,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ .def("__data_from_json__", SourceName::Get); }); -ObjectPtr GetSourceNameNode(const ffi::String& name) { +ObjectPtr GetSourceNameNode(const ffi::String& name) { // always return pointer as the reference can change as map re-allocate. // or use another level of indirection by creating a unique_ptr static std::unordered_map> source_map; @@ -62,7 +62,7 @@ ObjectPtr GetSourceNameNode(const ffi::String& name) { } } -ObjectPtr GetSourceNameNodeByStr(const std::string& name) { +ObjectPtr GetSourceNameNodeByStr(const std::string& name) { return GetSourceNameNode(name); } diff --git a/src/meta_schedule/database/database.cc b/src/meta_schedule/database/database.cc index b3c02607bddc..8094449bfb97 100644 --- a/src/meta_schedule/database/database.cc +++ b/src/meta_schedule/database/database.cc @@ -50,7 +50,7 @@ ObjectRef WorkloadNode::AsJSON() const { } Workload Workload::FromJSON(const ObjectRef& json_obj) { - IRModule mod{nullptr}; + IRModule mod{ffi::UnsafeInit()}; THashCode shash = 0; try { const ffi::ArrayObj* json_array = json_obj.as(); @@ -133,7 +133,7 @@ bool TuningRecordNode::IsValid() const { } TuningRecord TuningRecord::FromJSON(const ObjectRef& json_obj, const Workload& workload) { - tir::Trace trace{nullptr}; + tir::Trace trace{ffi::UnsafeInit()}; ffi::Optional> run_secs; ffi::Optional target; ffi::Optional> args_info; diff --git a/src/meta_schedule/database/json_database.cc b/src/meta_schedule/database/json_database.cc index cef4b6437ba2..56e179585e5e 100644 --- a/src/meta_schedule/database/json_database.cc +++ b/src/meta_schedule/database/json_database.cc @@ -185,11 +185,11 @@ Database Database::JSONDatabase(ffi::String path_workload, ffi::String path_tuni { std::vector json_objs = JSONFileReadLines(path_tuning_record, num_threads, allow_missing); std::vector records; - records.resize(json_objs.size(), TuningRecord{nullptr}); + records.resize(json_objs.size(), TuningRecord{ffi::UnsafeInit()}); support::parallel_for_dynamic( 0, json_objs.size(), num_threads, [&](int thread_id, int task_id) { auto json_obj = json_objs[task_id].cast(); - Workload workload{nullptr}; + Workload workload{ffi::UnsafeInit()}; try { const ffi::ArrayObj* arr = json_obj.as(); ICHECK_EQ(arr->size(), 2); diff --git a/src/meta_schedule/postproc/disallow_async_strided_mem_copy.cc b/src/meta_schedule/postproc/disallow_async_strided_mem_copy.cc index 88b6c2c649fb..a8ac2f05c41e 100644 --- a/src/meta_schedule/postproc/disallow_async_strided_mem_copy.cc +++ b/src/meta_schedule/postproc/disallow_async_strided_mem_copy.cc @@ -133,7 +133,7 @@ class DisallowAsyncStridedMemCopyNode : public PostprocNode { const GlobalVar& g_var = kv.first; const BaseFunc& base_func = kv.second; if (const auto* prim_func = base_func.as()) { - IRModule lowered{nullptr}; + IRModule lowered{ffi::UnsafeInit()}; try { auto pass_list = ffi::Array(); pass_list.push_back(tir::transform::BindTarget(this->target)); diff --git a/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc b/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc index f0047d688a80..5b250a6d2bdd 100644 --- a/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc +++ b/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc @@ -415,7 +415,7 @@ class RewriteParallelVectorizeUnrollNode : public PostprocNode { bool Apply(const Schedule& sch) final { tir::ParsedAnnotation parsed_root; - tir::BlockRV root_rv{nullptr}; + tir::BlockRV root_rv{ffi::UnsafeInit()}; while (tir::FindAnnotatedRootBlock(sch, &parsed_root, &root_rv)) { for (tir::BlockRV block_rv : sch->GetChildBlocks(root_rv)) { ffi::Array loop_rvs = sch->GetLoops(block_rv); diff --git a/src/meta_schedule/postproc/verify_gpu_code.cc b/src/meta_schedule/postproc/verify_gpu_code.cc index 5aaf756d43bb..7e660dc7cf30 100644 --- a/src/meta_schedule/postproc/verify_gpu_code.cc +++ b/src/meta_schedule/postproc/verify_gpu_code.cc @@ -114,8 +114,8 @@ Integer Extract(const Target& target, const char* name) { /*! \brief Verify the correctness of the generated GPU code. */ class VerifyGPUCodeNode : public PostprocNode { public: - Target target_{nullptr}; - ffi::Map target_constraints_{nullptr}; + Target target_{ffi::UnsafeInit()}; + ffi::Map target_constraints_{ffi::UnsafeInit()}; int thread_warp_size_ = -1; void InitializeWithTuneContext(const TuneContext& context) final { @@ -150,7 +150,7 @@ class VerifyGPUCodeNode : public PostprocNode { if (!tir::ThreadExtentChecker::Check(prim_func->body, thread_warp_size_)) { return false; } - IRModule lowered{nullptr}; + IRModule lowered{ffi::UnsafeInit()}; try { auto pass_list = ffi::Array(); // Phase 1 diff --git a/src/meta_schedule/schedule/cpu/winograd.cc b/src/meta_schedule/schedule/cpu/winograd.cc index e8afb71d6b7f..6a2b82aa426c 100644 --- a/src/meta_schedule/schedule/cpu/winograd.cc +++ b/src/meta_schedule/schedule/cpu/winograd.cc @@ -31,7 +31,7 @@ static ffi::Array ScheduleDataPack(tir::Schedule sch, tir::BlockRV using namespace tvm::tir; ICHECK_EQ(tiled.size(), 2); ICHECK_EQ(unrolled.size(), 4); - ffi::Array factors{nullptr}; + ffi::Array factors{ffi::UnsafeInit()}; ffi::Array loops = sch->GetLoops(block); ICHECK_EQ(loops.size(), 6); diff --git a/src/meta_schedule/schedule/cuda/thread_bind.cc b/src/meta_schedule/schedule/cuda/thread_bind.cc index b71ea9164ecf..2a042553d6b9 100644 --- a/src/meta_schedule/schedule/cuda/thread_bind.cc +++ b/src/meta_schedule/schedule/cuda/thread_bind.cc @@ -141,11 +141,11 @@ void BindBlockThreadIdx(tir::Schedule sch, tir::BlockRV block_rv, // ICHECK(false) << "Unsupported case, where blockIdx is bound but threadIdx is not"; throw; } - LoopRV loop_rv{nullptr}; + LoopRV loop_rv{ffi::UnsafeInit()}; { ffi::Array loop_rvs = sch->GetLoops(block_rv); if (i_spatial_loop == -1) { - LoopRV spatial_loop_rv{nullptr}; + LoopRV spatial_loop_rv{ffi::UnsafeInit()}; if (loop_rvs.empty()) { spatial_loop_rv = sch->AddUnitLoop(block_rv); } else { diff --git a/src/meta_schedule/schedule/cuda/winograd.cc b/src/meta_schedule/schedule/cuda/winograd.cc index 759ab9fc721c..2b9f4f78df0e 100644 --- a/src/meta_schedule/schedule/cuda/winograd.cc +++ b/src/meta_schedule/schedule/cuda/winograd.cc @@ -35,7 +35,7 @@ static ffi::Array ScheduleDataPack(tir::Schedule sch, tir::BlockRV using namespace tvm::tir; ICHECK_EQ(tiled.size(), 2); ICHECK_EQ(unrolled.size(), 4); - ffi::Array factors{nullptr}; + ffi::Array factors{ffi::UnsafeInit()}; ffi::Array loops = sch->GetLoops(block); ICHECK_EQ(loops.size(), 6); @@ -109,7 +109,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ int64_t max_threads_per_block = 1024; BlockRV input_tile = GetWinogradProducerAndInlineConst(sch, data_pack); BlockRV data_pad = GetWinogradProducerAndInlineConst(sch, input_tile); - LoopRV outer{nullptr}; + LoopRV outer{ffi::UnsafeInit()}; { ffi::Array loops = sch->GetLoops(data_pack); ICHECK_EQ(loops.size(), 6); @@ -139,7 +139,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ // loops on top of the inverse block: [CO, P, tile_size, tile_size, alpha, alpha] int64_t tile_size = Downcast(sch->Get(inverse)->writes[0]->buffer->shape[2])->value; - LoopRV outer{nullptr}; + LoopRV outer{ffi::UnsafeInit()}; { BlockRV output = sch->GetConsumers(inverse)[0]; ffi::Array nchw = sch->GetLoops(output); diff --git a/src/meta_schedule/schedule_rule/cross_thread_reduction.cc b/src/meta_schedule/schedule_rule/cross_thread_reduction.cc index 219e05254e2f..d39951779186 100644 --- a/src/meta_schedule/schedule_rule/cross_thread_reduction.cc +++ b/src/meta_schedule/schedule_rule/cross_thread_reduction.cc @@ -171,7 +171,7 @@ class CrossThreadReductionNode : public ScheduleRuleNode { * \return The extent of "threadIdx.x" in the input schedule */ tir::ExprRV GetThreadIdxExtentFromTrace(const tir::Trace& trace) { - tir::ExprRV extent{nullptr}; + tir::ExprRV extent{ffi::UnsafeInit()}; for (const tir::Instruction& inst : trace->insts) { if (inst->kind->name == "Bind" && Downcast(inst->attrs[0]) == "threadIdx.x") { if (GetLoopRVExtentSource(trace, Downcast(inst->inputs[0]), &extent)) { @@ -198,8 +198,8 @@ class CrossThreadReductionNode : public ScheduleRuleNode { // Step 0. Due to technical reason of some primitives (e.g., compute-at), if the block is doing // a tuple reduction, fusion is temporarily not supported. if (sch->Get(block_rv)->writes.size() != 1) { - return std::make_tuple(false, tir::LoopRV{nullptr}, tir::BlockRV{nullptr}, - tir::LoopRV{nullptr}); + return std::make_tuple(false, tir::LoopRV{ffi::UnsafeInit()}, tir::BlockRV{ffi::UnsafeInit()}, + tir::LoopRV{ffi::UnsafeInit()}); } // Step 1. Get all the consumers of the input block. @@ -208,8 +208,8 @@ class CrossThreadReductionNode : public ScheduleRuleNode { // Step 2. If the block has no consumer or the first consumer needs multi-level tiling, it is // not fusible. if (consumers.empty() || tir::NeedsMultiLevelTiling(sch->state(), sch->GetSRef(consumers[0]))) { - return std::make_tuple(false, tir::LoopRV{nullptr}, tir::BlockRV{nullptr}, - tir::LoopRV{nullptr}); + return std::make_tuple(false, tir::LoopRV{ffi::UnsafeInit()}, tir::BlockRV{ffi::UnsafeInit()}, + tir::LoopRV{ffi::UnsafeInit()}); } // Step 3. Calculate the lowest common ancestor of all the consumers. @@ -221,8 +221,8 @@ class CrossThreadReductionNode : public ScheduleRuleNode { const tir::StmtSRef& lca_sref = tir::GetSRefLowestCommonAncestor(tir::BlockRVs2StmtSRefs(sch, consumers)); if (consumers.size() > 1 && lca_sref->StmtAs() != nullptr) { - return std::make_tuple(false, tir::LoopRV{nullptr}, tir::BlockRV{nullptr}, - tir::LoopRV{nullptr}); + return std::make_tuple(false, tir::LoopRV{ffi::UnsafeInit()}, tir::BlockRV{ffi::UnsafeInit()}, + tir::LoopRV{ffi::UnsafeInit()}); } // Step 4. Get the outer loops of the target block, and get the compute-at position index. @@ -231,8 +231,8 @@ class CrossThreadReductionNode : public ScheduleRuleNode { // Step 5. A negative position index means not fusible, and vice-versa. if (pos < 0) { - return std::make_tuple(false, tir::LoopRV{nullptr}, tir::BlockRV{nullptr}, - tir::LoopRV{nullptr}); + return std::make_tuple(false, tir::LoopRV{ffi::UnsafeInit()}, tir::BlockRV{ffi::UnsafeInit()}, + tir::LoopRV{ffi::UnsafeInit()}); } else { return std::make_tuple(true, tgt_block_loops[pos], consumers[0], tgt_block_loops.back()); } diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc b/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc index 0bbccbdffe7a..741f0b6db444 100644 --- a/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc +++ b/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc @@ -77,7 +77,7 @@ class TensorCoreStateNode : public StateNode { /*! \brief The tensor core intrinsic group. */ TensorCoreIntrinGroup intrin_group; /*! \brief The auto tensorization maping info. */ - tir::AutoTensorizeMappingInfo mapping_info{nullptr}; + tir::AutoTensorizeMappingInfo mapping_info{ffi::UnsafeInit()}; /*! \brief The Tensor Core reindex block A for Tensor Core computation */ tir::BlockRV tensor_core_reindex_A; /*! \brief The Tensor Core reindex block B for Tensor Core computation */ diff --git a/src/meta_schedule/search_strategy/evolutionary_search.cc b/src/meta_schedule/search_strategy/evolutionary_search.cc index 306a3634d9d1..456fbbf129af 100644 --- a/src/meta_schedule/search_strategy/evolutionary_search.cc +++ b/src/meta_schedule/search_strategy/evolutionary_search.cc @@ -112,7 +112,7 @@ class SizedHeap { }; struct PerThreadData { - IRModule mod{nullptr}; + IRModule mod{ffi::UnsafeInit()}; TRandState rand_state{-1}; std::function trace_sampler = nullptr; std::function()> mutator_sampler = nullptr; @@ -270,11 +270,11 @@ class EvolutionarySearchNode : public SearchStrategyNode { * */ IRModuleSet measured_workloads_; /*! \brief A Database for selecting useful candidates. */ - Database database_{nullptr}; + Database database_{ffi::UnsafeInit()}; /*! \brief A cost model helping to explore the search space */ - CostModel cost_model_{nullptr}; + CostModel cost_model_{ffi::UnsafeInit()}; /*! \brief The token registered for the given workload in database. */ - Workload token_{nullptr}; + Workload token_{ffi::UnsafeInit()}; explicit State(EvolutionarySearchNode* self, int max_trials, int num_trials_per_iter, ffi::Array design_space_schedules, Database database, diff --git a/src/meta_schedule/utils.h b/src/meta_schedule/utils.h index 732a3a083d03..ee94b1d2ab5e 100644 --- a/src/meta_schedule/utils.h +++ b/src/meta_schedule/utils.h @@ -360,7 +360,7 @@ struct ThreadedTraceApply { /*! \brief A helper data structure that stores the fail count for each postprocessor. */ struct Item { /*! \brief The postprocessor. */ - Postproc postproc{nullptr}; + Postproc postproc{ffi::UnsafeInit()}; /*! \brief The thread-safe postprocessor failure counter. */ std::atomic fail_counter{0}; }; diff --git a/src/relax/ir/py_expr_functor.cc b/src/relax/ir/py_expr_functor.cc index 11867dee6db4..a97c5f784dc9 100644 --- a/src/relax/ir/py_expr_functor.cc +++ b/src/relax/ir/py_expr_functor.cc @@ -177,6 +177,9 @@ class PyExprVisitorNode : public Object, public ExprVisitor { */ class PyExprVisitor : public ObjectRef { public: + explicit PyExprVisitor(ObjectPtr data) : ObjectRef(data) { + TVM_FFI_ICHECK(data != nullptr); + } /*! * \brief Create a PyExprVisitor with customized methods on the python-side. * \param f_visit_expr The packed function of `VisitExpr(const Expr& expr)`. @@ -461,6 +464,9 @@ class PyExprMutatorNode : public Object, public ExprMutator { */ class PyExprMutator : public ObjectRef { public: + explicit PyExprMutator(ObjectPtr data) : ObjectRef(data) { + TVM_FFI_ICHECK(data != nullptr); + } /*! * \brief Create a PyExprMutator with customized methods on the python-side. * \param f_visit_expr The packed function of `VisitExpr(const Expr& expr)`. diff --git a/src/relax/transform/few_shot_tuning.cc b/src/relax/transform/few_shot_tuning.cc index 091247272a64..7deffaa9f58e 100644 --- a/src/relax/transform/few_shot_tuning.cc +++ b/src/relax/transform/few_shot_tuning.cc @@ -34,7 +34,7 @@ tir::PrimFunc FewShotTunePrimFunc(const tir::PrimFunc& prim_func, const Target& meta_schedule::Builder builder = f_get_local_builder().cast(); ICHECK(builder.defined()) << "ValueError: The local builder is not defined!"; // fetch a local runner - meta_schedule::Runner runner{nullptr}; + meta_schedule::Runner runner{ffi::UnsafeInit()}; if (benchmark) { static const auto f_get_local_runner = tvm::ffi::Function::GetGlobalRequired("meta_schedule.runner.get_local_runner"); diff --git a/src/relax/transform/meta_schedule.cc b/src/relax/transform/meta_schedule.cc index 2d24f0785a15..295937084d86 100644 --- a/src/relax/transform/meta_schedule.cc +++ b/src/relax/transform/meta_schedule.cc @@ -81,7 +81,7 @@ Pass MetaScheduleApplyDatabase(ffi::Optional work_dir, bool enable_ ICHECK(normalize_mod_func_.has_value()) << "Normalization function is not found."; auto pass_func = [=](IRModule mod, PassContext ctx) { - Database database{nullptr}; + Database database{ffi::UnsafeInit()}; if (Database::Current().defined()) { database = Database::Current().value(); } else { diff --git a/src/runtime/rpc/rpc_session.h b/src/runtime/rpc/rpc_session.h index 265c58f4af63..4c456b861e9d 100644 --- a/src/runtime/rpc/rpc_session.h +++ b/src/runtime/rpc/rpc_session.h @@ -333,6 +333,9 @@ class RPCObjectRefObj : public Object { */ class RPCObjectRef : public ObjectRef { public: + explicit RPCObjectRef(ObjectPtr data) : ObjectRef(data) { + TVM_FFI_ICHECK(data != nullptr); + } TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(RPCObjectRef, ObjectRef, RPCObjectRefObj); }; diff --git a/src/script/printer/relax/call.cc b/src/script/printer/relax/call.cc index 9b0d2b966a4d..666b3839ea0e 100644 --- a/src/script/printer/relax/call.cc +++ b/src/script/printer/relax/call.cc @@ -264,7 +264,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) if (ffi::Optional doc = PrintRelaxPrint(n, n_p, d)) { return doc.value(); } - ExprDoc prefix{nullptr}; + ExprDoc prefix{ffi::UnsafeInit()}; ffi::Array args; ffi::Array kwargs_keys; ffi::Array kwargs_values; diff --git a/src/script/printer/tir/block.cc b/src/script/printer/tir/block.cc index 587520d72fe5..1a33d760a9d5 100644 --- a/src/script/printer/tir/block.cc +++ b/src/script/printer/tir/block.cc @@ -83,7 +83,7 @@ Doc PrintBlock(IRDocsifier d, tir::Block block, AccessPath block_p, // LOG(FATAL) << "ValueError: Unknown IterVarType in block signature: " << tir::IterVarType2String(iter_var->iter_type); } - ExprDoc dom{nullptr}; + ExprDoc dom{ffi::UnsafeInit()}; if (tir::is_zero(iter_var->dom->min)) { ExprDoc extent = d->AsDoc(iter_var->dom->extent, // iter_var_p->Attr("dom")->Attr("extent")); diff --git a/src/script/printer/tir/expr.cc b/src/script/printer/tir/expr.cc index ddcf1b64f1a1..da525aa35fc2 100644 --- a/src/script/printer/tir/expr.cc +++ b/src/script/printer/tir/expr.cc @@ -27,7 +27,7 @@ namespace printer { ExprDoc PrintVarCreation(const tir::Var& var, const AccessPath& var_p, const IRDocsifier& d) { Type type = var->type_annotation; AccessPath type_p = var_p->Attr("type_annotation"); - ExprDoc rhs{nullptr}; + ExprDoc rhs{ffi::UnsafeInit()}; ffi::Array kwargs_keys; ffi::Array kwargs_values; @@ -169,7 +169,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // "", [](tir::CommReducer r, AccessPath p, IRDocsifier d) -> Doc { ICHECK_EQ(r->lhs.size(), r->rhs.size()); - LambdaDoc lambda{nullptr}; + ffi::Optional lambda; { With f(d, r); int n_vars = r->lhs.size(); @@ -194,7 +194,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) } } ExprDoc id = d->AsDoc(r->identity_element, p->Attr("identity_element")); - return TIR(d, "comm_reducer")->Call({lambda, id}); + return TIR(d, "comm_reducer")->Call({lambda.value(), id}); }); LambdaDoc PrintIndexMap(const ObjectRef& map, const ffi::Array& vs, @@ -244,7 +244,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) static const OpAttrMap dtype_locations = Op::GetAttrMap("TScriptDtypePrintLocation"); tir::ScriptDtypePrintLocation dtype_print_location = tir::ScriptDtypePrintLocation::kNone; - ExprDoc prefix{nullptr}; + ffi::Optional prefix; if (auto optional_op = call->op.as()) { auto op = optional_op.value(); ffi::String name = op_names.get(op, op->name); @@ -279,7 +279,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) if (dtype_print_location == tir::ScriptDtypePrintLocation::kLast) { args.push_back(LiteralDoc::DataType(call->dtype, call_p->Attr("dtype"))); } - return prefix->Call(args); + return prefix.value()->Call(args); } } else if (call->op.as()) { prefix = d->AsDoc(call->op, call_p->Attr("op")); @@ -299,7 +299,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) if (dtype_print_location == tir::ScriptDtypePrintLocation::kLast) { args.push_back(LiteralDoc::DataType(call->dtype, call_p->Attr("dtype"))); } - return prefix->Call(args); + return prefix.value()->Call(args); }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) diff --git a/src/script/printer/tir/for_loop.cc b/src/script/printer/tir/for_loop.cc index 10bb6f756df2..742d23f69cdd 100644 --- a/src/script/printer/tir/for_loop.cc +++ b/src/script/printer/tir/for_loop.cc @@ -78,7 +78,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) if (!loop->annotations.empty()) { annotations = d->AsDoc(loop->annotations, loop_p->Attr("annotations")); } - ExprDoc prefix{nullptr}; + ExprDoc prefix{ffi::UnsafeInit()}; if (loop->kind == tir::ForKind::kSerial) { if (loop->annotations.empty()) { prefix = IdDoc("range"); diff --git a/src/script/printer/tir/ir.cc b/src/script/printer/tir/ir.cc index 0cd38d4c6a49..797c726c7c1a 100644 --- a/src/script/printer/tir/ir.cc +++ b/src/script/printer/tir/ir.cc @@ -66,7 +66,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch("", [](PointerType ty, AccessPath ty_p, IRDocsifier d) -> Doc { - ExprDoc element_type{nullptr}; + ExprDoc element_type{ffi::UnsafeInit()}; if (const auto* prim_type = ty->element_type.as()) { element_type = LiteralDoc::DataType(prim_type->dtype, // ty_p->Attr("element_type")->Attr("dtype")); diff --git a/src/script/printer/tir/stmt.cc b/src/script/printer/tir/stmt.cc index 228fbbc78556..1b0774be3686 100644 --- a/src/script/printer/tir/stmt.cc +++ b/src/script/printer/tir/stmt.cc @@ -284,7 +284,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) ffi::Array args; ffi::Array kwargs_keys; ffi::Array kwargs_values; - ExprDoc data_doc{nullptr}; + ExprDoc data_doc{ffi::UnsafeInit()}; if (stmt->dtype.is_int()) { if (stmt->dtype.bits() == 8) { data_doc = PrintTensor(stmt->data.value()); @@ -377,7 +377,7 @@ ExprDoc DocsifyLaunchThread(const tir::AttrStmt& attr_stmt, const AccessPath& at tir::IterVar iter_var = Downcast(attr_stmt->node); AccessPath iter_var_p = attr_stmt_p->Attr("node"); - ExprDoc var_doc{nullptr}; + ExprDoc var_doc{ffi::UnsafeInit()}; if (d->IsVarDefined(iter_var->var)) { var_doc = d->AsDoc(iter_var->var, iter_var_p->Attr("var")); } else if (IsAncestorOfAllVarUse(attr_stmt, iter_var->var, d)) { diff --git a/src/target/target.cc b/src/target/target.cc index b2c3e8fe8c1b..e2013aba7218 100644 --- a/src/target/target.cc +++ b/src/target/target.cc @@ -56,10 +56,10 @@ class TargetInternal { const ffi::Map& attrs); static Any ParseType(const std::string& str, const TargetKindNode::ValueTypeInfo& info); static Any ParseType(const Any& obj, const TargetKindNode::ValueTypeInfo& info); - static ObjectPtr FromString(const ffi::String& tag_or_config_or_target_str); - static ObjectPtr FromConfigString(const ffi::String& config_str); - static ObjectPtr FromRawString(const ffi::String& target_str); - static ObjectPtr FromConfig(ffi::Map config); + static ObjectPtr FromString(const ffi::String& tag_or_config_or_target_str); + static ObjectPtr FromConfigString(const ffi::String& config_str); + static ObjectPtr FromRawString(const ffi::String& target_str); + static ObjectPtr FromConfig(ffi::Map config); static void ConstructorDispatcher(ffi::PackedArgs args, ffi::Any* rv); static Target WithHost(const Target& target, const Target& target_host) { ObjectPtr n = ffi::make_object(*target.get()); @@ -771,10 +771,10 @@ void TargetInternal::ConstructorDispatcher(ffi::PackedArgs args, ffi::Any* rv) { LOG(FATAL) << "ValueError: Invalid number of arguments. Expect 1 or 2, but gets: " << args.size(); } -ObjectPtr TargetInternal::FromString(const ffi::String& tag_or_config_or_target_str) { +ObjectPtr TargetInternal::FromString(const ffi::String& tag_or_config_or_target_str) { if (ffi::Optional target = TargetTag::Get(tag_or_config_or_target_str)) { Target value = target.value(); - return ffi::details::ObjectUnsafe::ObjectPtrFromObjectRef(value); + return ffi::details::ObjectUnsafe::ObjectPtrFromObjectRef(value); } if (!tag_or_config_or_target_str.empty() && tag_or_config_or_target_str.data()[0] == '{') { return TargetInternal::FromConfigString(tag_or_config_or_target_str); @@ -782,7 +782,7 @@ ObjectPtr TargetInternal::FromString(const ffi::String& tag_or_config_or return TargetInternal::FromRawString(tag_or_config_or_target_str); } -ObjectPtr TargetInternal::FromConfigString(const ffi::String& config_str) { +ObjectPtr TargetInternal::FromConfigString(const ffi::String& config_str) { const auto loader = tvm::ffi::Function::GetGlobal("target._load_config_dict"); ICHECK(loader.has_value()) << "AttributeError: \"target._load_config_dict\" is not registered. Please check " @@ -794,7 +794,7 @@ ObjectPtr TargetInternal::FromConfigString(const ffi::String& config_str return TargetInternal::FromConfig({config.value().begin(), config.value().end()}); } -ObjectPtr TargetInternal::FromRawString(const ffi::String& target_str) { +ObjectPtr TargetInternal::FromRawString(const ffi::String& target_str) { ICHECK_GT(target_str.length(), 0) << "Cannot parse empty target string"; // Split the string by empty spaces std::vector options = SplitString(std::string(target_str), ' '); @@ -826,7 +826,7 @@ ObjectPtr TargetInternal::FromRawString(const ffi::String& target_str) { return TargetInternal::FromConfig(config); } -ObjectPtr TargetInternal::FromConfig(ffi::Map config) { +ObjectPtr TargetInternal::FromConfig(ffi::Map config) { const ffi::String kKind = "kind"; const ffi::String kTag = "tag"; const ffi::String kKeys = "keys"; diff --git a/src/tir/ir/py_functor.cc b/src/tir/ir/py_functor.cc index 871452aeb946..26b55d3bb922 100644 --- a/src/tir/ir/py_functor.cc +++ b/src/tir/ir/py_functor.cc @@ -342,6 +342,9 @@ class PyStmtExprVisitorNode : public Object, public StmtExprVisitor { */ class PyStmtExprVisitor : public ObjectRef { public: + explicit PyStmtExprVisitor(ObjectPtr data) : ObjectRef(data) { + TVM_FFI_ICHECK(data != nullptr); + } TVM_DLL static PyStmtExprVisitor MakePyStmtExprVisitor(ffi::Function f_visit_stmt, // ffi::Function f_visit_expr, // ffi::Function f_visit_let_stmt, // @@ -702,6 +705,9 @@ class PyStmtExprMutatorNode : public Object, public StmtExprMutator { /*! \brief Managed reference to PyStmtExprMutatorNode. */ class PyStmtExprMutator : public ObjectRef { public: + explicit PyStmtExprMutator(ObjectPtr data) : ObjectRef(data) { + TVM_FFI_ICHECK(data != nullptr); + } /*! * \brief Create a PyStmtExprMutator with customized methods on the python-side. * \return The PyStmtExprMutator created. diff --git a/src/tir/schedule/analysis.h b/src/tir/schedule/analysis.h index 8f3372b0ca17..910c22aae0b2 100644 --- a/src/tir/schedule/analysis.h +++ b/src/tir/schedule/analysis.h @@ -761,6 +761,9 @@ class TensorizeInfoNode : public Object { class TensorizeInfo : public ObjectRef { public: + explicit TensorizeInfo(ObjectPtr data) : ObjectRef(data) { + TVM_FFI_ICHECK(data != nullptr); + } TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TensorizeInfo, ObjectRef, TensorizeInfoNode); }; @@ -810,6 +813,10 @@ class AutoTensorizeMappingInfoNode : public Object { class AutoTensorizeMappingInfo : public ObjectRef { public: + explicit AutoTensorizeMappingInfo(ObjectPtr data) + : ObjectRef(data) { + TVM_FFI_ICHECK(data != nullptr); + } TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(AutoTensorizeMappingInfo, ObjectRef, AutoTensorizeMappingInfoNode); }; diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc index b33333177816..89ece537713d 100644 --- a/src/tir/schedule/concrete_schedule.cc +++ b/src/tir/schedule/concrete_schedule.cc @@ -604,7 +604,7 @@ void ConcreteScheduleNode::ReorderBlockIterVar(const BlockRV& block_rv, } LoopRV ConcreteScheduleNode::AddUnitLoop(const BlockRV& block_rv) { - LoopRV result{nullptr}; + LoopRV result{ffi::UnsafeInit()}; TVM_TIR_SCHEDULE_BEGIN(); result = CreateRV(tir::AddUnitLoop(state_, GetSRef(block_rv))); TVM_TIR_SCHEDULE_END("add-unit-loop", this->error_render_level_); @@ -613,7 +613,7 @@ LoopRV ConcreteScheduleNode::AddUnitLoop(const BlockRV& block_rv) { } LoopRV ConcreteScheduleNode::AddUnitLoop(const LoopRV& loop_rv) { - LoopRV result{nullptr}; + LoopRV result{ffi::UnsafeInit()}; TVM_TIR_SCHEDULE_BEGIN(); result = CreateRV(tir::AddUnitLoop(state_, GetSRef(loop_rv))); TVM_TIR_SCHEDULE_END("add-unit-loop", this->error_render_level_); diff --git a/src/tir/transforms/memhammer_tensorcore_rewrite.cc b/src/tir/transforms/memhammer_tensorcore_rewrite.cc index c1b303e0731b..e16c51877188 100644 --- a/src/tir/transforms/memhammer_tensorcore_rewrite.cc +++ b/src/tir/transforms/memhammer_tensorcore_rewrite.cc @@ -334,7 +334,7 @@ class WmmaToGlobalRewriter : public StmtExprMutator { Stmt WmmaToGlobal::Rewrite(const Stmt& stmt, const ConstraintSet& constraints, OutputSet* output) const { Stmt body{nullptr}; - ffi::Optional compute_location{nullptr}; + ffi::Optional compute_location; std::tie(body, compute_location) = TileWmmaBlock(stmt); SeqStmt seq{nullptr}; Buffer cache_buffer; @@ -543,7 +543,7 @@ class MmaToGlobalRewriter : public StmtExprMutator { Stmt MmaToGlobal::Rewrite(const Stmt& stmt, const ConstraintSet& constraints, OutputSet* output) const { Stmt body{nullptr}; - ffi::Optional compute_location{nullptr}; + ffi::Optional compute_location; std::tie(body, compute_location) = TileMmaToGlobalBlock(stmt); SeqStmt seq{nullptr}; Buffer cache_buffer;