From 0f6171c5f68fa160795e6c5d787589528fb68f04 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Mon, 13 Apr 2020 10:49:48 -0700 Subject: [PATCH] [RUNTIME][IR] Allow non-nullable ObjectRef, introduce Optional. (#5314) * [RUNTIME] Allow non-nullable ObjectRef, introduce Optional. We use ObjectRef and their sub-classes extensively throughout our codebase. Each of ObjectRef's sub-classes are nullable, which means they can hold nullptr as their values. While in some places we need nullptr as an alternative value. The implicit support for nullptr in all ObjectRef creates additional burdens for the developer to explicitly check defined in many places of the codebase. Moreover, it is unclear from the API's intentional point of view whether we want a nullable object or not-null version(many cases we want the later). Borrowing existing wisdoms from languages like Rust. We propose to introduce non-nullable ObjectRef, and Optional container that represents a nullable variant. To keep backward compatiblity, we will start by allowing most ObjectRef to be nullable. However, we should start to use Optional as the type in places where we know nullable is a requirement. Gradually, we will move most of the ObjectRef to be non-nullable and use Optional in the nullable cases. Such explicitness in typing can help reduce the potential problems in our codebase overall. Changes in this PR: - Introduce _type_is_nullable attribute to ObjectRef - Introduce Optional - Change String to be non-nullable. - Change the API of function->GetAttr to return Optional * Address review comments * Upgrade all compiler flags to c++14 * Update as per review comment --- .../app/src/main/jni/Application.mk | 2 +- .../app/src/main/jni/Application.mk | 4 +- .../app/src/main/jni/Application.mk | 2 +- apps/cpp_rpc/Makefile | 2 +- apps/dso_plugin_module/Makefile | 2 +- apps/extension/Makefile | 2 +- apps/howto_deploy/Makefile | 2 +- apps/howto_deploy/tvm_runtime_pack.cc | 2 +- apps/rocm_rpc/Makefile | 2 +- apps/tf_tvmdsoop/CMakeLists.txt | 2 +- golang/Makefile | 2 +- include/tvm/ir/attrs.h | 2 + include/tvm/ir/expr.h | 67 +++++++- include/tvm/ir/function.h | 19 ++- include/tvm/node/node.h | 1 - include/tvm/runtime/container.h | 148 +++++++++++++++++- include/tvm/runtime/object.h | 64 ++++++-- include/tvm/runtime/packed_func.h | 8 +- python/setup.py | 2 +- src/driver/driver_api.cc | 10 +- src/relay/backend/compile_engine.cc | 6 +- .../backend/contrib/codegen_c/codegen_c.h | 2 +- src/relay/backend/vm/compiler.cc | 2 +- src/relay/backend/vm/lambda_lift.cc | 2 +- src/relay/ir/transform.cc | 4 +- src/relay/transforms/annotate_target.cc | 6 +- src/target/build_common.h | 6 +- src/target/llvm/codegen_cpu.cc | 2 +- src/target/llvm/codegen_llvm.cc | 4 +- src/target/llvm/llvm_module.cc | 2 +- src/target/opt/build_cuda_on.cc | 3 +- src/target/source/codegen_aocl.cc | 3 +- src/target/source/codegen_c.cc | 2 +- src/target/source/codegen_metal.cc | 12 +- src/target/source/codegen_opencl.cc | 3 +- src/target/source/codegen_opengl.cc | 5 +- src/target/source/codegen_vhls.cc | 5 +- src/target/spirv/build_vulkan.cc | 5 +- src/target/spirv/codegen_spirv.cc | 3 +- src/target/stackvm/codegen_stackvm.cc | 2 +- src/tir/analysis/verify_memory.cc | 2 +- src/tir/transforms/bind_device_type.cc | 2 +- src/tir/transforms/lower_custom_datatypes.cc | 2 +- src/tir/transforms/lower_intrin.cc | 2 +- src/tir/transforms/lower_thread_allreduce.cc | 2 +- src/tir/transforms/lower_warp_memory.cc | 2 +- src/tir/transforms/make_packed_api.cc | 9 +- src/tir/transforms/remap_thread_axis.cc | 5 +- src/tir/transforms/split_host_device.cc | 4 +- tests/cpp/container_test.cc | 69 ++++++++ tests/python/relay/test_external_codegen.py | 2 +- tests/python/relay/test_external_runtime.py | 4 +- .../python/relay/test_pass_annotate_target.py | 2 +- .../python/relay/test_pass_partition_graph.py | 2 +- .../unittest/test_runtime_module_export.py | 2 +- vta/python/vta/exec/rpc_server.py | 2 +- 56 files changed, 432 insertions(+), 105 deletions(-) diff --git a/apps/android_camera/app/src/main/jni/Application.mk b/apps/android_camera/app/src/main/jni/Application.mk index 95a5a9697bcc1..63a79458ef946 100644 --- a/apps/android_camera/app/src/main/jni/Application.mk +++ b/apps/android_camera/app/src/main/jni/Application.mk @@ -31,7 +31,7 @@ include $(config) APP_ABI ?= all APP_STL := c++_shared -APP_CPPFLAGS += -DDMLC_LOG_STACK_TRACE=0 -DTVM4J_ANDROID=1 -std=c++11 -Oz -frtti +APP_CPPFLAGS += -DDMLC_LOG_STACK_TRACE=0 -DTVM4J_ANDROID=1 -std=c++14 -Oz -frtti ifeq ($(USE_OPENCL), 1) APP_CPPFLAGS += -DTVM_OPENCL_RUNTIME=1 endif diff --git a/apps/android_deploy/app/src/main/jni/Application.mk b/apps/android_deploy/app/src/main/jni/Application.mk index ee13eb8a12137..a50a40bf5cd19 100644 --- a/apps/android_deploy/app/src/main/jni/Application.mk +++ b/apps/android_deploy/app/src/main/jni/Application.mk @@ -27,7 +27,7 @@ include $(config) APP_STL := c++_static -APP_CPPFLAGS += -DDMLC_LOG_STACK_TRACE=0 -DTVM4J_ANDROID=1 -std=c++11 -Oz -frtti -ifeq ($(USE_OPENCL), 1) +APP_CPPFLAGS += -DDMLC_LOG_STACK_TRACE=0 -DTVM4J_ANDROID=1 -std=c++14 -Oz -frtti +ifeq ($(USE_OPENCL), 1) APP_CPPFLAGS += -DTVM_OPENCL_RUNTIME=1 endif diff --git a/apps/android_rpc/app/src/main/jni/Application.mk b/apps/android_rpc/app/src/main/jni/Application.mk index 56288bde98984..54abdf771e2a1 100644 --- a/apps/android_rpc/app/src/main/jni/Application.mk +++ b/apps/android_rpc/app/src/main/jni/Application.mk @@ -31,7 +31,7 @@ include $(config) APP_ABI ?= armeabi-v7a arm64-v8a x86 x86_64 mips APP_STL := c++_shared -APP_CPPFLAGS += -DDMLC_LOG_STACK_TRACE=0 -DTVM4J_ANDROID=1 -std=c++11 -Oz -frtti +APP_CPPFLAGS += -DDMLC_LOG_STACK_TRACE=0 -DTVM4J_ANDROID=1 -std=c++14 -Oz -frtti ifeq ($(USE_OPENCL), 1) APP_CPPFLAGS += -DTVM_OPENCL_RUNTIME=1 endif diff --git a/apps/cpp_rpc/Makefile b/apps/cpp_rpc/Makefile index 9cd39b446acc4..927331ad00ea3 100644 --- a/apps/cpp_rpc/Makefile +++ b/apps/cpp_rpc/Makefile @@ -28,7 +28,7 @@ else LINK_PTHREAD= endif -PKG_CFLAGS = -std=c++11 -O2 -fPIC -Wall\ +PKG_CFLAGS = -std=c++14 -O2 -fPIC -Wall\ -I${TVM_ROOT}/include\ -I${DMLC_CORE}/include\ -I${TVM_ROOT}/3rdparty/dlpack/include diff --git a/apps/dso_plugin_module/Makefile b/apps/dso_plugin_module/Makefile index 2ee6189e2876b..c2ce3306870a6 100644 --- a/apps/dso_plugin_module/Makefile +++ b/apps/dso_plugin_module/Makefile @@ -16,7 +16,7 @@ # under the License. TVM_ROOT=$(shell cd ../..; pwd) -PKG_CFLAGS = -std=c++11 -O2 -fPIC\ +PKG_CFLAGS = -std=c++14 -O2 -fPIC\ -I${TVM_ROOT}/include\ -I${TVM_ROOT}/3rdparty/dmlc-core/include\ -I${TVM_ROOT}/3rdparty/dlpack/include diff --git a/apps/extension/Makefile b/apps/extension/Makefile index e178b661f4038..91d914aba63b7 100644 --- a/apps/extension/Makefile +++ b/apps/extension/Makefile @@ -17,7 +17,7 @@ # Minimum Makefile for the extension package TVM_ROOT=$(shell cd ../..; pwd) -PKG_CFLAGS = -std=c++11 -O2 -fPIC\ +PKG_CFLAGS = -std=c++14 -O2 -fPIC\ -I${TVM_ROOT}/include\ -I${TVM_ROOT}/3rdparty/dmlc-core/include\ -I${TVM_ROOT}/3rdparty/dlpack/include diff --git a/apps/howto_deploy/Makefile b/apps/howto_deploy/Makefile index a260e89bc042f..4ee243c2ce60c 100644 --- a/apps/howto_deploy/Makefile +++ b/apps/howto_deploy/Makefile @@ -19,7 +19,7 @@ TVM_ROOT=$(shell cd ../..; pwd) DMLC_CORE=${TVM_ROOT}/3rdparty/dmlc-core -PKG_CFLAGS = -std=c++11 -O2 -fPIC\ +PKG_CFLAGS = -std=c++14 -O2 -fPIC\ -I${TVM_ROOT}/include\ -I${DMLC_CORE}/include\ -I${TVM_ROOT}/3rdparty/dlpack/include\ diff --git a/apps/howto_deploy/tvm_runtime_pack.cc b/apps/howto_deploy/tvm_runtime_pack.cc index d166eaf756a56..81bab497bebb8 100644 --- a/apps/howto_deploy/tvm_runtime_pack.cc +++ b/apps/howto_deploy/tvm_runtime_pack.cc @@ -24,7 +24,7 @@ * include in your project. * * - Copy this file into your project which depends on tvm runtime. - * - Compile with -std=c++11 + * - Compile with -std=c++14 * - Add the following include path * - /path/to/tvm/include/ * - /path/to/tvm/3rdparty/dmlc-core/include/ diff --git a/apps/rocm_rpc/Makefile b/apps/rocm_rpc/Makefile index 36eb41596be87..971ca46033143 100644 --- a/apps/rocm_rpc/Makefile +++ b/apps/rocm_rpc/Makefile @@ -21,7 +21,7 @@ ROCM_PATH=/opt/rocm TVM_ROOT=$(shell cd ../..; pwd) DMLC_CORE=${TVM_ROOT}/3rdparty/dmlc-core -PKG_CFLAGS = -std=c++11 -O2 -fPIC\ +PKG_CFLAGS = -std=c++14 -O2 -fPIC\ -I${TVM_ROOT}/include\ -I${DMLC_CORE}/include\ -I${TVM_ROOT}/3rdparty/dlpack/include\ diff --git a/apps/tf_tvmdsoop/CMakeLists.txt b/apps/tf_tvmdsoop/CMakeLists.txt index cb601ef6d30d7..f4e83c5287014 100644 --- a/apps/tf_tvmdsoop/CMakeLists.txt +++ b/apps/tf_tvmdsoop/CMakeLists.txt @@ -17,7 +17,7 @@ cmake_minimum_required(VERSION 3.2) project(tf_tvmdsoop C CXX) -set(TFTVM_COMPILE_FLAGS -std=c++11) +set(TFTVM_COMPILE_FLAGS -std=c++14) set(BUILD_TVMDSOOP_ONLY ON) set(CMAKE_CURRENT_SOURCE_DIR ${TVM_ROOT}) set(CMAKE_CURRENT_BINARY_DIR ${TVM_ROOT}/build) diff --git a/golang/Makefile b/golang/Makefile index c54fd0e0992c5..6fd77996e119d 100644 --- a/golang/Makefile +++ b/golang/Makefile @@ -25,7 +25,7 @@ NATIVE_SRC = tvm_runtime_pack.cc GOPATH=$(CURDIR)/gopath GOPATHDIR=${GOPATH}/src/${TARGET}/ CGO_CPPFLAGS="-I. -I${TVM_BASE}/ -I${TVM_BASE}/3rdparty/dmlc-core/include -I${TVM_BASE}/include -I${TVM_BASE}/3rdparty/dlpack/include/" -CGO_CXXFLAGS="-std=c++11" +CGO_CXXFLAGS="-std=c++14" CGO_CFLAGS="-I${TVM_BASE}" CGO_LDFLAGS="-ldl -lm" diff --git a/include/tvm/ir/attrs.h b/include/tvm/ir/attrs.h index 0fc832e0fb7ae..d12f1b85114c7 100644 --- a/include/tvm/ir/attrs.h +++ b/include/tvm/ir/attrs.h @@ -85,6 +85,8 @@ namespace tvm { */ template inline TObjectRef NullValue() { + static_assert(TObjectRef::_type_is_nullable, + "Can only get NullValue for nullable types"); return TObjectRef(ObjectPtr(nullptr)); } diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h index 859a134cd5aa5..6630bf3ded20e 100644 --- a/include/tvm/ir/expr.h +++ b/include/tvm/ir/expr.h @@ -311,6 +311,47 @@ class FloatImm : public PrimExpr { TVM_DEFINE_OBJECT_REF_METHODS(FloatImm, PrimExpr, FloatImmNode); }; +/*! + * \brief Boolean constant. + * + * This reference type is useful to add additional compile-time + * type checks and helper functions for Integer equal comparisons. + */ +class Bool : public IntImm { + public: + explicit Bool(bool value) + : IntImm(DataType::Bool(), value) { + } + Bool operator!() const { + return Bool((*this)->value == 0); + } + operator bool() const { + return (*this)->value != 0; + } + + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Bool, IntImm, IntImmNode); +}; + +// Overload operators to make sure we have the most fine grained types. +inline Bool operator||(const Bool& a, bool b) { + return Bool(a.operator bool() || b); +} +inline Bool operator||(bool a, const Bool& b) { + return Bool(a || b.operator bool()); +} +inline Bool operator||(const Bool& a, const Bool& b) { + return Bool(a.operator bool() || b.operator bool()); +} +inline Bool operator&&(const Bool& a, bool b) { + return Bool(a.operator bool() && b); +} +inline Bool operator&&(bool a, const Bool& b) { + return Bool(a && b.operator bool()); +} +inline Bool operator&&(const Bool& a, const Bool& b) { + return Bool(a.operator bool() && b.operator bool()); +} + /*! * \brief Container of constant int that adds more constructors. * @@ -340,10 +381,10 @@ class Integer : public IntImm { * \tparam Enum The enum type. * \param value The enum value. */ - template::value>::type> - explicit Integer(ENum value) : Integer(static_cast(value)) { - static_assert(std::is_same::type>::value, + template::value>::type> + explicit Integer(Enum value) : Integer(static_cast(value)) { + static_assert(std::is_same::type>::value, "declare enum to be enum int to use visitor"); } /*! @@ -362,6 +403,24 @@ class Integer : public IntImm { << " Trying to reference a null Integer"; return (*this)->value; } + // comparators + Bool operator==(int other) const { + if (data_ == nullptr) return Bool(false); + return Bool((*this)->value == other); + } + Bool operator!=(int other) const { + return !(*this == other); + } + template::value>::type> + Bool operator==(Enum other) const { + return *this == static_cast(other); + } + template::value>::type> + Bool operator!=(Enum other) const { + return *this != static_cast(other); + } }; /*! \brief range over one dimension */ diff --git a/include/tvm/ir/function.h b/include/tvm/ir/function.h index dc7a2b2185687..d55656f34b006 100644 --- a/include/tvm/ir/function.h +++ b/include/tvm/ir/function.h @@ -26,6 +26,7 @@ #include #include +#include #include #include @@ -90,25 +91,31 @@ class BaseFuncNode : public RelayExprNode { * \code * * void GetAttrExample(const BaseFunc& f) { - * Integer value = f->GetAttr("AttrKey", 0); + * auto value = f->GetAttr("AttrKey", 0); * } * * \endcode */ template - TObjectRef GetAttr(const std::string& attr_key, - TObjectRef default_value = NullValue()) const { + Optional GetAttr( + const std::string& attr_key, + Optional default_value = Optional(nullptr)) const { static_assert(std::is_base_of::value, "Can only call GetAttr with ObjectRef types."); if (!attrs.defined()) return default_value; auto it = attrs->dict.find(attr_key); if (it != attrs->dict.end()) { - return Downcast((*it).second); + return Downcast>((*it).second); } else { return default_value; } } - + // variant that uses TObjectRef to enable implicit conversion to default value. + template + Optional GetAttr( + const std::string& attr_key, TObjectRef default_value) const { + return GetAttr(attr_key, Optional(default_value)); + } /*! * \brief Check whether the function has an non-zero integer attr. * @@ -129,7 +136,7 @@ class BaseFuncNode : public RelayExprNode { * \endcode */ bool HasNonzeroAttr(const std::string& attr_key) const { - return GetAttr(attr_key, 0)->value != 0; + return GetAttr(attr_key, 0) != 0; } static constexpr const char* _type_key = "BaseFunc"; diff --git a/include/tvm/node/node.h b/include/tvm/node/node.h index b39e3b4034213..471a0de361b7c 100644 --- a/include/tvm/node/node.h +++ b/include/tvm/node/node.h @@ -63,7 +63,6 @@ using runtime::make_object; using runtime::PackedFunc; using runtime::TVMArgs; using runtime::TVMRetValue; -using runtime::String; } // namespace tvm #endif // TVM_NODE_NODE_H_ diff --git a/include/tvm/runtime/container.h b/include/tvm/runtime/container.h index 8963f0921276e..8f426415ffeea 100644 --- a/include/tvm/runtime/container.h +++ b/include/tvm/runtime/container.h @@ -353,6 +353,10 @@ class StringObj : public Object { */ class String : public ObjectRef { public: + /*! + * \brief Construct an empty string. + */ + String() : String(std::string()) {} /*! * \brief Construct a new String object * @@ -467,9 +471,6 @@ class String : public ObjectRef { */ size_t size() const { const auto* ptr = get(); - if (ptr == nullptr) { - return 0; - } return ptr->size; } @@ -524,7 +525,7 @@ class String : public ObjectRef { /*! \return the internal StringObj pointer */ const StringObj* get() const { return operator->(); } - TVM_DEFINE_OBJECT_REF_METHODS(String, ObjectRef, StringObj); + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(String, ObjectRef, StringObj); private: /*! @@ -610,7 +611,146 @@ struct PackedFuncValueConverter<::tvm::runtime::String> { } }; +/*! + * \brief Optional container that to represent to a Nullable variant of T. + * \tparam T The original ObjectRef. + * + * \code + * + * Optional opt0 = nullptr; + * Optional opt1 = String("xyz"); + * CHECK(opt0 == nullptr); + * CHECK(opt1 == "xyz"); + * + * \endcode + */ +template +class Optional : public ObjectRef { + public: + using ContainerType = typename T::ContainerType; + static_assert(std::is_base_of::value, + "Optional is only defined for ObjectRef."); + // default constructors. + Optional() = default; + Optional(const Optional&) = default; + Optional(Optional&&) = default; + Optional& operator=(const Optional&) = default; + Optional& operator=(Optional&&) = default; + /*! + * \brief Construct from an ObjectPtr + * whose type already matches the ContainerType. + * \param ptr + */ + explicit Optional(ObjectPtr ptr) : ObjectRef(ptr) {} + // nullptr handling. + // disallow implicit conversion as 0 can be implicitly converted to nullptr_t + explicit Optional(std::nullptr_t) {} + Optional& operator=(std::nullptr_t) { + data_ = nullptr; + return *this; + } + // normal value handling. + Optional(T other) // NOLINT(*) + : ObjectRef(std::move(other)) { + } + Optional& operator=(T other) { + ObjectRef::operator=(std::move(other)); + return *this; + } + // delete the int constructor + // since Optional(0) is ambiguious + // 0 can be implicitly casted to nullptr_t + explicit Optional(int val) = delete; + Optional& operator=(int val) = delete; + /*! + * \return A not-null container value in the optional. + * \note This function performs not-null checking. + */ + T value() const { + CHECK(data_ != nullptr); + return T(data_); + } + /*! + * \return The contained value if the Optional is not null + * otherwise return the default_value. + */ + T value_or(T default_value) const { + return data_ != nullptr ? T(data_) : default_value; + } + /*! \return Whether the container is not nullptr.*/ + explicit operator bool() const { + return *this != nullptr; + } + // operator overloadings + bool operator==(std::nullptr_t) const { + return data_ == nullptr; + } + bool operator!=(std::nullptr_t) const { + return data_ != nullptr; + } + auto operator==(const Optional& other) const { + // support case where sub-class returns a symbolic ref type. + using RetType = decltype(value() == other.value()); + if (same_as(other)) return RetType(true); + if (*this != nullptr && other != nullptr) { + return value() == other.value(); + } else { + // one of them is nullptr. + return RetType(false); + } + } + auto operator!=(const Optional& other) const { + // support case where sub-class returns a symbolic ref type. + using RetType = decltype(value() != other.value()); + if (same_as(other)) return RetType(false); + if (*this != nullptr && other != nullptr) { + return value() != other.value(); + } else { + // one of them is nullptr. + return RetType(true); + } + } + auto operator==(const T& other) const { + using RetType = decltype(value() == other); + if (same_as(other)) return RetType(true); + if (*this != nullptr) return value() == other; + return RetType(false); + } + auto operator!=(const T& other) const { + return !(*this == other); + } + template + auto operator==(const U& other) const { + using RetType = decltype(value() == other); + if (*this == nullptr) return RetType(false); + return value() == other; + } + template + auto operator!=(const U& other) const { + using RetType = decltype(value() != other); + if (*this == nullptr) return RetType(true); + return value() != other; + } + static constexpr bool _type_is_nullable = true; +}; + +template +struct PackedFuncValueConverter> { + static Optional From(const TVMArgValue& val) { + if (val.type_code() == kTVMNullptr) return Optional(nullptr); + return PackedFuncValueConverter::From(val); + } + static Optional From(const TVMRetValue& val) { + if (val.type_code() == kTVMNullptr) return Optional(nullptr); + return PackedFuncValueConverter::From(val); + } +}; + } // namespace runtime + +// expose the functions to the root namespace. +using runtime::String; +using runtime::Optional; } // namespace tvm namespace std { diff --git a/include/tvm/runtime/object.h b/include/tvm/runtime/object.h index acbb9398b74c7..edca925baeb0d 100644 --- a/include/tvm/runtime/object.h +++ b/include/tvm/runtime/object.h @@ -546,7 +546,9 @@ class ObjectRef { bool operator<(const ObjectRef& other) const { return data_.get() < other.data_.get(); } - /*! \return whether the expression is null */ + /*! + * \return whether the object is defined(not null). + */ bool defined() const { return data_ != nullptr; } @@ -582,6 +584,8 @@ class ObjectRef { /*! \brief type indicate the container type. */ using ContainerType = Object; + // Default type properties for the reference class. + static constexpr bool _type_is_nullable = true; protected: /*! \brief Internal pointer that backs the reference. */ @@ -720,6 +724,17 @@ struct ObjectEqual { TVM_STR_CONCAT(TVM_OBJECT_REG_VAR_DEF, __COUNTER__) = \ TypeName::_GetOrAllocRuntimeTypeIndex() + +/* + * \brief Define the default copy/move constructor and assign opeator + * \param TypeName The class typename. + */ +#define TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName) \ + TypeName(const TypeName& other) = default; \ + TypeName(TypeName&& other) = default; \ + TypeName& operator=(const TypeName& other) = default; \ + TypeName& operator=(TypeName&& other) = default; \ + /* * \brief Define object reference methods. * \param TypeName The object type name @@ -727,15 +742,34 @@ struct ObjectEqual { * \param ObjectName The type name of the object. */ #define TVM_DEFINE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName) \ - TypeName() {} \ + TypeName() = default; \ explicit TypeName( \ ::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n) \ : ParentType(n) {} \ + TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName); \ const ObjectName* operator->() const { \ return static_cast(data_.get()); \ } \ using ContainerType = ObjectName; +/* + * \brief Define object reference methods that is not nullable. + * + * \param TypeName The object type name + * \param ParentType The parent type of the objectref + * \param ObjectName The type name of the object. + */ +#define TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName) \ + explicit TypeName( \ + ::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n) \ + : ParentType(n) {} \ + TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName); \ + const ObjectName* operator->() const { \ + return static_cast(data_.get()); \ + } \ + static constexpr bool _type_is_nullable = false; \ + using ContainerType = ObjectName; + /* * \brief Define object reference methods of whose content is mutable. * \param TypeName The object type name @@ -745,7 +779,8 @@ struct ObjectEqual { * This macro is only reserved for objects that stores runtime states. */ #define TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName) \ - TypeName() {} \ + TypeName() = default; \ + TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName); \ explicit TypeName( \ ::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n) \ : ParentType(n) {} \ @@ -869,11 +904,14 @@ inline const ObjectType* ObjectRef::as() const { } } -template -inline RelayRefType GetRef(const ObjType* ptr) { - static_assert(std::is_base_of::value, +template +inline RefType GetRef(const ObjType* ptr) { + static_assert(std::is_base_of::value, "Can only cast to the ref of same container type"); - return RelayRefType(ObjectPtr(const_cast(static_cast(ptr)))); + if (!RefType::_type_is_nullable) { + CHECK(ptr != nullptr); + } + return RefType(ObjectPtr(const_cast(static_cast(ptr)))); } template @@ -885,9 +923,15 @@ inline ObjectPtr GetObjectPtr(ObjType* ptr) { template inline SubRef Downcast(BaseRef ref) { - CHECK(!ref.defined() || ref->template IsInstance()) - << "Downcast from " << ref->GetTypeKey() << " to " - << SubRef::ContainerType::_type_key << " failed."; + if (ref.defined()) { + CHECK(ref->template IsInstance()) + << "Downcast from " << ref->GetTypeKey() << " to " + << SubRef::ContainerType::_type_key << " failed."; + } else { + CHECK(SubRef::_type_is_nullable) + << "Downcast from nullptr to not nullable reference of " + << SubRef::ContainerType::_type_key; + } return SubRef(std::move(ref.data_)); } diff --git a/include/tvm/runtime/packed_func.h b/include/tvm/runtime/packed_func.h index c5f0df57b10c5..3d5a7e865303a 100644 --- a/include/tvm/runtime/packed_func.h +++ b/include/tvm/runtime/packed_func.h @@ -352,7 +352,7 @@ template struct ObjectTypeChecker { static bool Check(const Object* ptr) { using ContainerType = typename T::ContainerType; - if (ptr == nullptr) return true; + if (ptr == nullptr) return T::_type_is_nullable; return ptr->IsInstance(); } static std::string TypeName() { @@ -1400,7 +1400,11 @@ inline TObjectRef TVMPODValue_::AsObjectRef() const { std::is_base_of::value, "Conversion only works for ObjectRef"); using ContainerType = typename TObjectRef::ContainerType; - if (type_code_ == kTVMNullptr) return TObjectRef(ObjectPtr(nullptr)); + if (type_code_ == kTVMNullptr) { + CHECK(TObjectRef::_type_is_nullable) + << "Expect a not null value of " << ContainerType::_type_key; + return TObjectRef(ObjectPtr(nullptr)); + } // NOTE: the following code can be optimized by constant folding. if (std::is_base_of::value) { // Casting to a sub-class of NDArray diff --git a/python/setup.py b/python/setup.py index 937d682e3c852..62f374923714a 100644 --- a/python/setup.py +++ b/python/setup.py @@ -96,7 +96,7 @@ def config_cython(): "../3rdparty/dmlc-core/include", "../3rdparty/dlpack/include", ], - extra_compile_args=["-std=c++11"], + extra_compile_args=["-std=c++14"], library_dirs=library_dirs, libraries=libraries, language="c++")) diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index f939f0a1e7d60..d7955a2ca6208 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -244,8 +244,9 @@ split_dev_host_funcs(IRModule mod_mixed, auto host_pass_list = { FilterBy([](const tir::PrimFunc& f) { - int64_t value = f->GetAttr(tvm::attr::kCallingConv, 0)->value; - return value != static_cast(CallingConv::kDeviceKernelLaunch); + return f->GetAttr( + tvm::attr::kCallingConv, + Integer(CallingConv::kDefault)) != CallingConv::kDeviceKernelLaunch; }), BindTarget(target_host), tir::transform::LowerTVMBuiltin(), @@ -259,8 +260,9 @@ split_dev_host_funcs(IRModule mod_mixed, // device pipeline auto device_pass_list = { FilterBy([](const tir::PrimFunc& f) { - int64_t value = f->GetAttr(tvm::attr::kCallingConv, 0)->value; - return value == static_cast(CallingConv::kDeviceKernelLaunch); + return f->GetAttr( + tvm::attr::kCallingConv, + Integer(CallingConv::kDefault)) == CallingConv::kDeviceKernelLaunch; }), BindTarget(target), tir::transform::LowerWarpMemory(), diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc index 9cb6b2efe28c0..4ed8fbc15abd7 100644 --- a/src/relay/backend/compile_engine.cc +++ b/src/relay/backend/compile_engine.cc @@ -620,14 +620,14 @@ class CompileEngineImpl : public CompileEngineNode { if (src_func->GetAttr(attr::kCompiler).defined()) { auto code_gen = src_func->GetAttr(attr::kCompiler); CHECK(code_gen.defined()) << "No external codegen is set"; - std::string code_gen_name = code_gen; + std::string code_gen_name = code_gen.value(); if (ext_mods.find(code_gen_name) == ext_mods.end()) { ext_mods[code_gen_name] = IRModule({}, {}); } auto symbol_name = src_func->GetAttr(tvm::attr::kGlobalSymbol); CHECK(symbol_name.defined()) << "No external symbol is set for:\n" << AsText(src_func, false); - auto gv = GlobalVar(std::string(symbol_name)); + auto gv = GlobalVar(std::string(symbol_name.value())); ext_mods[code_gen_name]->Add(gv, src_func); cached_ext_funcs.push_back(it.first); } @@ -698,7 +698,7 @@ class CompileEngineImpl : public CompileEngineNode { key->source_func->GetAttr(tvm::attr::kGlobalSymbol); CHECK(name_node.defined()) << "External function has not been attached a name yet."; - cache_node->func_name = std::string(name_node); + cache_node->func_name = std::string(name_node.value()); cache_node->target = tvm::target::ext_dev(); value->cached_func = CachedFunc(cache_node); return value; diff --git a/src/relay/backend/contrib/codegen_c/codegen_c.h b/src/relay/backend/contrib/codegen_c/codegen_c.h index 1b953f3c44671..7dfa4bac06ed8 100644 --- a/src/relay/backend/contrib/codegen_c/codegen_c.h +++ b/src/relay/backend/contrib/codegen_c/codegen_c.h @@ -72,7 +72,7 @@ class CSourceModuleCodegenBase { const auto name_node = func->GetAttr(tvm::attr::kGlobalSymbol); CHECK(name_node.defined()) << "Fail to retrieve external symbol."; - return std::string(name_node); + return std::string(name_node.value()); } }; diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index e2b0fffec8bde..8af6247fc810f 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -446,7 +446,7 @@ class VMFunctionCompiler : ExprFunctor { const Expr& outputs) { std::vector argument_registers; - CHECK_NE(func->GetAttr(attr::kPrimitive, 0)->value, 0) + CHECK(func->GetAttr(attr::kPrimitive, 0) != 0) << "internal error: invoke_tvm_op requires the first argument to be a relay::Function"; auto input_tuple = inputs.as(); diff --git a/src/relay/backend/vm/lambda_lift.cc b/src/relay/backend/vm/lambda_lift.cc index 59c549cabfee8..bfbefd57a3105 100644 --- a/src/relay/backend/vm/lambda_lift.cc +++ b/src/relay/backend/vm/lambda_lift.cc @@ -45,7 +45,7 @@ inline std::string GenerateName(const Function& func) { } bool IsClosure(const Function& func) { - return func->GetAttr(attr::kClosure, 0)->value != 0; + return func->GetAttr(attr::kClosure, 0) != 0; } Function MarkClosure(Function func) { diff --git a/src/relay/ir/transform.cc b/src/relay/ir/transform.cc index a06eb5a4d3472..06dd2b16661f1 100644 --- a/src/relay/ir/transform.cc +++ b/src/relay/ir/transform.cc @@ -145,8 +145,8 @@ IRModule FunctionPassNode::operator()(IRModule mod, } bool FunctionPassNode::SkipFunction(const Function& func) const { - return func->GetAttr(attr::kSkipOptimization, 0)->value != 0 || - (func->GetAttr(attr::kCompiler).defined()); + return (func->GetAttr(attr::kCompiler).defined()) || + func->GetAttr(attr::kSkipOptimization, 0) != 0; } Pass CreateFunctionPass( diff --git a/src/relay/transforms/annotate_target.cc b/src/relay/transforms/annotate_target.cc index 44d7b54e96374..2499982e321a1 100644 --- a/src/relay/transforms/annotate_target.cc +++ b/src/relay/transforms/annotate_target.cc @@ -158,9 +158,9 @@ class AnnotateTargetWrapper : public ExprMutator { // if it is in the target list. Function func = Downcast(cn->op); CHECK(func.defined()); - auto comp_name = func->GetAttr(attr::kComposite); - if (comp_name.defined()) { - std::string comp_name_str = comp_name; + + if (auto comp_name = func->GetAttr(attr::kComposite)) { + std::string comp_name_str = comp_name.value(); size_t i = comp_name_str.find('.'); if (i != std::string::npos) { std::string comp_target = comp_name_str.substr(0, i); diff --git a/src/target/build_common.h b/src/target/build_common.h index 5ba51da4ce672..93687c2578acc 100644 --- a/src/target/build_common.h +++ b/src/target/build_common.h @@ -51,14 +51,14 @@ ExtractFuncInfo(const IRModule& mod) { for (size_t i = 0; i < f->params.size(); ++i) { info.arg_types.push_back(f->params[i].dtype()); } - auto thread_axis = f->GetAttr>(tir::attr::kDeviceThreadAxis); - if (thread_axis.defined()) { + if (auto opt = f->GetAttr>(tir::attr::kDeviceThreadAxis)) { + auto thread_axis = opt.value(); for (size_t i = 0; i < thread_axis.size(); ++i) { info.thread_axis_tags.push_back(thread_axis[i]->thread_tag); } } auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); - fmap[static_cast(global_symbol)] = info; + fmap[static_cast(global_symbol.value())] = info; } return fmap; } diff --git a/src/target/llvm/codegen_cpu.cc b/src/target/llvm/codegen_cpu.cc index a863056e82267..ad09730e6db83 100644 --- a/src/target/llvm/codegen_cpu.cc +++ b/src/target/llvm/codegen_cpu.cc @@ -130,7 +130,7 @@ void CodeGenCPU::AddFunction(const PrimFunc& f) { CHECK(global_symbol.defined()) << "CodeGenLLVM: Expect PrimFunc to have the global_symbol attribute"; export_system_symbols_.emplace_back( - std::make_pair(global_symbol.operator std::string(), + std::make_pair(global_symbol.value().operator std::string(), builder_->CreatePointerCast(function_, t_void_p_))); } AddDebugInformation(function_); diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index 7112691de1bcb..604533933b922 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -131,12 +131,12 @@ void CodeGenLLVM::AddFunctionInternal(const PrimFunc& f, bool ret_void) { auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); CHECK(global_symbol.defined()) << "CodeGenLLVM: Expect PrimFunc to have the global_symbol attribute"; - CHECK(module_->getFunction(static_cast(global_symbol)) == nullptr) + CHECK(module_->getFunction(static_cast(global_symbol.value())) == nullptr) << "Function " << global_symbol << " already exist in module"; function_ = llvm::Function::Create( ftype, llvm::Function::ExternalLinkage, - global_symbol.operator std::string(), module_.get()); + global_symbol.value().operator std::string(), module_.get()); function_->setCallingConv(llvm::CallingConv::C); function_->setDLLStorageClass(llvm::GlobalValue::DLLStorageClassTypes::DLLExportStorageClass); diff --git a/src/target/llvm/llvm_module.cc b/src/target/llvm/llvm_module.cc index 52dccbaf5eb61..d1a244d01ff40 100644 --- a/src/target/llvm/llvm_module.cc +++ b/src/target/llvm/llvm_module.cc @@ -216,7 +216,7 @@ class LLVMModuleNode final : public runtime::ModuleNode { if (f->HasNonzeroAttr(tir::attr::kIsEntryFunc)) { auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); CHECK(global_symbol.defined()); - entry_func = global_symbol; + entry_func = global_symbol.value(); } funcs.push_back(f); } diff --git a/src/target/opt/build_cuda_on.cc b/src/target/opt/build_cuda_on.cc index 634fb9a57f27a..2d659e4487e31 100644 --- a/src/target/opt/build_cuda_on.cc +++ b/src/target/opt/build_cuda_on.cc @@ -138,8 +138,7 @@ runtime::Module BuildCUDA(IRModule mod) { << "CodeGenCUDA: Can only take PrimFunc"; auto f = Downcast(kv.second); auto calling_conv = f->GetAttr(tvm::attr::kCallingConv); - CHECK(calling_conv.defined() && - calling_conv->value == static_cast(CallingConv::kDeviceKernelLaunch)) + CHECK(calling_conv == CallingConv::kDeviceKernelLaunch) << "CodeGenCUDA: expect calling_conv equals CallingConv::kDeviceKernelLaunch"; cg.AddFunction(f); } diff --git a/src/target/source/codegen_aocl.cc b/src/target/source/codegen_aocl.cc index c6011cd4dc87f..64674e3360dd2 100644 --- a/src/target/source/codegen_aocl.cc +++ b/src/target/source/codegen_aocl.cc @@ -45,8 +45,7 @@ runtime::Module BuildAOCL(IRModule mod, << "CodegenOpenCL: Can only take PrimFunc"; auto f = Downcast(kv.second); auto calling_conv = f->GetAttr(tvm::attr::kCallingConv); - CHECK(calling_conv.defined() && - calling_conv->value == static_cast(CallingConv::kDeviceKernelLaunch)) + CHECK(calling_conv == CallingConv::kDeviceKernelLaunch) << "CodegenOpenCL: expect calling_conv equals CallingConv::kDeviceKernelLaunch"; cg.AddFunction(f); } diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc index a0e18a6120554..444dc996b10fa 100644 --- a/src/target/source/codegen_c.cc +++ b/src/target/source/codegen_c.cc @@ -84,7 +84,7 @@ void CodeGenC::AddFunction(const PrimFunc& f) { bool no_alias = f->HasNonzeroAttr(tir::attr::kNoAlias); this->PrintFuncPrefix(); - this->stream << " " << static_cast(global_symbol) << "("; + this->stream << " " << static_cast(global_symbol.value()) << "("; for (size_t i = 0; i < f->params.size(); ++i) { tir::Var v = f->params[i]; diff --git a/src/target/source/codegen_metal.cc b/src/target/source/codegen_metal.cc index 715c0ae92ddca..ea49d33351a06 100644 --- a/src/target/source/codegen_metal.cc +++ b/src/target/source/codegen_metal.cc @@ -61,7 +61,7 @@ void CodeGenMetal::AddFunction(const PrimFunc& f) { << "CodeGenC: Expect PrimFunc to have the global_symbol attribute"; // Function header. - this->stream << "kernel void " << static_cast(global_symbol) << "("; + this->stream << "kernel void " << static_cast(global_symbol.value()) << "("; // Buffer arguments size_t num_buffer = 0; @@ -91,7 +91,8 @@ void CodeGenMetal::AddFunction(const PrimFunc& f) { size_t nargs = f->params.size() - num_buffer; std::string varg = GetUniqueName("arg"); if (nargs != 0) { - std::string arg_buf_type = static_cast(global_symbol) + "_args_t"; + std::string arg_buf_type = + static_cast(global_symbol.value()) + "_args_t"; stream << " constant " << arg_buf_type << "& " << varg << " [[ buffer(" << num_buffer << ") ]],\n"; // declare the struct @@ -120,8 +121,8 @@ void CodeGenMetal::AddFunction(const PrimFunc& f) { CHECK_EQ(GetUniqueName("threadIdx"), "threadIdx"); CHECK_EQ(GetUniqueName("blockIdx"), "blockIdx"); int work_dim = 0; - auto thread_axis = f->GetAttr>(tir::attr::kDeviceThreadAxis); - CHECK(thread_axis.defined()); + auto thread_axis = f->GetAttr>( + tir::attr::kDeviceThreadAxis).value(); for (IterVar iv : thread_axis) { runtime::ThreadScope scope = runtime::ThreadScope::make(iv->thread_tag); @@ -278,8 +279,7 @@ runtime::Module BuildMetal(IRModule mod) { << "CodeGenMetal: Can only take PrimFunc"; auto f = Downcast(kv.second); auto calling_conv = f->GetAttr(tvm::attr::kCallingConv); - CHECK(calling_conv.defined() && - calling_conv->value == static_cast(CallingConv::kDeviceKernelLaunch)) + CHECK(calling_conv == CallingConv::kDeviceKernelLaunch) << "CodeGenMetal: expect calling_conv equals CallingConv::kDeviceKernelLaunch"; cg.AddFunction(f); } diff --git a/src/target/source/codegen_opencl.cc b/src/target/source/codegen_opencl.cc index 67761c17680a8..d5b89609e514c 100644 --- a/src/target/source/codegen_opencl.cc +++ b/src/target/source/codegen_opencl.cc @@ -249,8 +249,7 @@ runtime::Module BuildOpenCL(IRModule mod) { << "CodeGenOpenCL: Can only take PrimFunc"; auto f = Downcast(kv.second); auto calling_conv = f->GetAttr(tvm::attr::kCallingConv); - CHECK(calling_conv.defined() && - calling_conv->value == static_cast(CallingConv::kDeviceKernelLaunch)) + CHECK(calling_conv == CallingConv::kDeviceKernelLaunch) << "CodeGenOpenCL: expect calling_conv equals CallingConv::kDeviceKernelLaunch"; cg.AddFunction(f); } diff --git a/src/target/source/codegen_opengl.cc b/src/target/source/codegen_opengl.cc index 13d87d282e6cb..946b483a1dd9c 100644 --- a/src/target/source/codegen_opengl.cc +++ b/src/target/source/codegen_opengl.cc @@ -160,7 +160,7 @@ void CodeGenOpenGL::AddFunction(const PrimFunc& f) { CHECK(global_symbol.defined()) << "CodeGenOpenGL: Expect PrimFunc to have the global_symbol attribute"; - shaders_[static_cast(global_symbol)] = runtime::OpenGLShader( + shaders_[static_cast(global_symbol.value())] = runtime::OpenGLShader( this->decl_stream.str() + this->stream.str(), std::move(arg_names), std::move(arg_kinds), this->thread_extent_var_); @@ -299,8 +299,7 @@ runtime::Module BuildOpenGL(IRModule mod) { << "CodeGenOpenGL: Can only take PrimFunc"; auto f = Downcast(kv.second); auto calling_conv = f->GetAttr(tvm::attr::kCallingConv); - CHECK(calling_conv.defined() && - calling_conv->value == static_cast(CallingConv::kDeviceKernelLaunch)) + CHECK(calling_conv == CallingConv::kDeviceKernelLaunch) << "CodeGenOpenGL: expect calling_conv equals CallingConv::kDeviceKernelLaunch"; cg.AddFunction(f); } diff --git a/src/target/source/codegen_vhls.cc b/src/target/source/codegen_vhls.cc index 7486164444c4e..71c36264afa46 100644 --- a/src/target/source/codegen_vhls.cc +++ b/src/target/source/codegen_vhls.cc @@ -138,8 +138,7 @@ runtime::Module BuildSDAccel(IRModule mod, std::string target_str) { << "CodeGenVHLS: Can only take PrimFunc"; auto f = Downcast(kv.second); auto calling_conv = f->GetAttr(tvm::attr::kCallingConv); - CHECK(calling_conv.defined() && - calling_conv->value == static_cast(CallingConv::kDeviceKernelLaunch)) + CHECK(calling_conv == CallingConv::kDeviceKernelLaunch) << "CodeGenVLHS: expect calling_conv equals CallingConv::kDeviceKernelLaunch"; cg.AddFunction(f); } @@ -164,7 +163,7 @@ runtime::Module BuildSDAccel(IRModule mod, std::string target_str) { auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); CHECK(global_symbol.defined()) << "CodeGenC: Expect PrimFunc to have the global_symbol attribute"; - kernel_info.push_back({global_symbol, code}); + kernel_info.push_back({global_symbol.value(), code}); } std::string xclbin; diff --git a/src/target/spirv/build_vulkan.cc b/src/target/spirv/build_vulkan.cc index 58721414a6651..161c1ca3bab10 100644 --- a/src/target/spirv/build_vulkan.cc +++ b/src/target/spirv/build_vulkan.cc @@ -87,14 +87,13 @@ runtime::Module BuildSPIRV(IRModule mod) { << "CodeGenSPIRV: Can only take PrimFunc"; auto f = Downcast(kv.second); auto calling_conv = f->GetAttr(tvm::attr::kCallingConv); - CHECK(calling_conv.defined() && - calling_conv->value == static_cast(CallingConv::kDeviceKernelLaunch)) + CHECK(calling_conv == CallingConv::kDeviceKernelLaunch) << "CodeGenSPIRV: expect calling_conv equals CallingConv::kDeviceKernelLaunch"; auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); CHECK(global_symbol.defined()) << "CodeGenSPIRV: Expect PrimFunc to have the global_symbol attribute"; - std::string f_name = global_symbol; + std::string f_name = global_symbol.value(); f = PointerValueTypeRewrite(std::move(f)); VulkanShader shader; shader.data = cg.BuildFunction(f); diff --git a/src/target/spirv/codegen_spirv.cc b/src/target/spirv/codegen_spirv.cc index db2a2f359aa48..bfe21b024426c 100644 --- a/src/target/spirv/codegen_spirv.cc +++ b/src/target/spirv/codegen_spirv.cc @@ -82,7 +82,8 @@ std::vector CodeGenSPIRV::BuildFunction(const PrimFunc& f) { CHECK(global_symbol.defined()) << "CodeGenSPIRV: Expect PrimFunc to have the global_symbol attribute"; - builder_->CommitKernelFunction(func_ptr, static_cast(global_symbol)); + builder_->CommitKernelFunction( + func_ptr, static_cast(global_symbol.value())); return builder_->Finalize(); } diff --git a/src/target/stackvm/codegen_stackvm.cc b/src/target/stackvm/codegen_stackvm.cc index da75a70e91232..661fdabd3c32e 100644 --- a/src/target/stackvm/codegen_stackvm.cc +++ b/src/target/stackvm/codegen_stackvm.cc @@ -539,7 +539,7 @@ runtime::Module BuildStackVM(const IRModule& mod) { auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); CHECK(global_symbol.defined()) << "CodeGenStackVM: Expect PrimFunc to have the global_symbol attribute"; - std::string f_name = global_symbol; + std::string f_name = global_symbol.value(); StackVM vm = codegen::CodeGenStackVM().Compile(f); CHECK(!fmap.count(f_name)) << "Function name " << f_name << "already exist in list"; diff --git a/src/tir/analysis/verify_memory.cc b/src/tir/analysis/verify_memory.cc index d6a521f984870..9ff4f3d5b7384 100644 --- a/src/tir/analysis/verify_memory.cc +++ b/src/tir/analysis/verify_memory.cc @@ -195,7 +195,7 @@ void VerifyMemory(const IRModule& mod) { auto target = func->GetAttr(tvm::attr::kTarget); CHECK(target.defined()) << "LowerWarpMemory: Require the target attribute"; - MemoryAccessVerifier v(func, target->device_type); + MemoryAccessVerifier v(func, target.value()->device_type); v.Run(); if (v.Failed()) { LOG(FATAL) diff --git a/src/tir/transforms/bind_device_type.cc b/src/tir/transforms/bind_device_type.cc index 952d6635f582a..a6db9f9c6da84 100644 --- a/src/tir/transforms/bind_device_type.cc +++ b/src/tir/transforms/bind_device_type.cc @@ -99,7 +99,7 @@ Pass BindDeviceType() { auto target = f->GetAttr(tvm::attr::kTarget); CHECK(target.defined()) << "BindDeviceType: Require the target attribute"; - n->body = DeviceTypeBinder(target->device_type)(std::move(n->body)); + n->body = DeviceTypeBinder(target.value()->device_type)(std::move(n->body)); return f; }; return CreatePrimFuncPass(pass_func, 0, "tir.BindDeviceType", {}); diff --git a/src/tir/transforms/lower_custom_datatypes.cc b/src/tir/transforms/lower_custom_datatypes.cc index 6026f8c67567f..6cf9e3adce967 100644 --- a/src/tir/transforms/lower_custom_datatypes.cc +++ b/src/tir/transforms/lower_custom_datatypes.cc @@ -141,7 +141,7 @@ Pass LowerCustomDatatypes() { CHECK(target.defined()) << "LowerCustomDatatypes: Require the target attribute"; - n->body = CustomDatatypesLowerer(target->target_name)(std::move(n->body)); + n->body = CustomDatatypesLowerer(target.value()->target_name)(std::move(n->body)); return f; }; return CreatePrimFuncPass(pass_func, 0, "tir.LowerCustomDatatypes", {}); diff --git a/src/tir/transforms/lower_intrin.cc b/src/tir/transforms/lower_intrin.cc index 41a94937d4ce5..6ae638f334742 100644 --- a/src/tir/transforms/lower_intrin.cc +++ b/src/tir/transforms/lower_intrin.cc @@ -293,7 +293,7 @@ Pass LowerIntrin() { << "LowerIntrin: Require the target attribute"; arith::Analyzer analyzer; n->body = - IntrinInjecter(&analyzer, target->target_name)(std::move(n->body)); + IntrinInjecter(&analyzer, target.value()->target_name)(std::move(n->body)); return f; }; return CreatePrimFuncPass(pass_func, 0, "tir.LowerIntrin", {}); diff --git a/src/tir/transforms/lower_thread_allreduce.cc b/src/tir/transforms/lower_thread_allreduce.cc index c4df2dcdb868a..655a0074c7fde 100644 --- a/src/tir/transforms/lower_thread_allreduce.cc +++ b/src/tir/transforms/lower_thread_allreduce.cc @@ -348,7 +348,7 @@ Pass LowerThreadAllreduce() { auto target = f->GetAttr(tvm::attr::kTarget); CHECK(target.defined()) << "LowerThreadAllreduce: Require the target attribute"; - n->body = ThreadAllreduceBuilder(target->thread_warp_size)(n->body); + n->body = ThreadAllreduceBuilder(target.value()->thread_warp_size)(n->body); return f; }; return CreatePrimFuncPass(pass_func, 0, "tir.LowerThreadAllreduce", {}); diff --git a/src/tir/transforms/lower_warp_memory.cc b/src/tir/transforms/lower_warp_memory.cc index 1921db53cb060..612a8f4d9eef4 100644 --- a/src/tir/transforms/lower_warp_memory.cc +++ b/src/tir/transforms/lower_warp_memory.cc @@ -393,7 +393,7 @@ Pass LowerWarpMemory() { auto target = f->GetAttr(tvm::attr::kTarget); CHECK(target.defined()) << "LowerWarpMemory: Require the target attribute"; - n->body = WarpMemoryRewriter(target->thread_warp_size).Rewrite(std::move(n->body)); + n->body = WarpMemoryRewriter(target.value()->thread_warp_size).Rewrite(std::move(n->body)); return f; }; return CreatePrimFuncPass(pass_func, 0, "tir.LowerWarpMemory", {}); diff --git a/src/tir/transforms/make_packed_api.cc b/src/tir/transforms/make_packed_api.cc index b1dd235bce03f..dd4bd66426767 100644 --- a/src/tir/transforms/make_packed_api.cc +++ b/src/tir/transforms/make_packed_api.cc @@ -48,9 +48,9 @@ inline Stmt MakeAssertEQ(PrimExpr lhs, PrimExpr rhs, std::string msg) { PrimFunc MakePackedAPI(PrimFunc&& func, int num_unpacked_args) { auto global_symbol = func->GetAttr(tvm::attr::kGlobalSymbol); - CHECK(global_symbol.defined()) + CHECK(global_symbol) << "MakePackedAPI: Expect PrimFunc to have the global_symbol attribute"; - std::string name_hint = global_symbol; + std::string name_hint = global_symbol.value(); auto* func_ptr = func.CopyOnWrite(); const Stmt nop = EvaluateNode::make(0); @@ -240,8 +240,9 @@ Pass MakePackedAPI(int num_unpacked_args) { for (const auto& kv : mptr->functions) { if (auto* n = kv.second.as()) { PrimFunc func = GetRef(n); - if (func->GetAttr(tvm::attr::kCallingConv, 0)->value - == static_cast(CallingConv::kDefault)) { + if (func->GetAttr( + tvm::attr::kCallingConv, + Integer(CallingConv::kDefault)) == CallingConv::kDefault) { auto updated_func = MakePackedAPI(std::move(func), num_unpacked_args); updates.push_back({kv.first, updated_func}); } diff --git a/src/tir/transforms/remap_thread_axis.cc b/src/tir/transforms/remap_thread_axis.cc index f3663532e56ba..fdcfc4d4702e6 100644 --- a/src/tir/transforms/remap_thread_axis.cc +++ b/src/tir/transforms/remap_thread_axis.cc @@ -82,7 +82,10 @@ PrimFunc RemapThreadAxis(PrimFunc&& f, Map thread_map) tmap[kv.first] = kv.second; } - auto thread_axis = f->GetAttr >(tir::attr::kDeviceThreadAxis); + auto opt_thread_axis = f->GetAttr>(tir::attr::kDeviceThreadAxis); + CHECK(opt_thread_axis != nullptr) + << "Require attribute " << tir::attr::kDeviceThreadAxis; + auto thread_axis = opt_thread_axis.value(); auto* n = f.CopyOnWrite(); // replace the thread axis diff --git a/src/tir/transforms/split_host_device.cc b/src/tir/transforms/split_host_device.cc index 792a06157c09c..927536b5938e7 100644 --- a/src/tir/transforms/split_host_device.cc +++ b/src/tir/transforms/split_host_device.cc @@ -277,7 +277,9 @@ PrimFunc SplitHostDevice(PrimFunc&& func, IRModule* device_mod) { << "SplitHostDevice: Expect PrimFunc to have the global_symbol attribute"; HostDeviceSplitter splitter( - device_mod, target, static_cast(global_symbol)); + device_mod, + target.value(), + static_cast(global_symbol.value())); auto* n = func.CopyOnWrite(); n->body = splitter(std::move(n->body)); diff --git a/tests/cpp/container_test.cc b/tests/cpp/container_test.cc index 063247db09b66..c67df63e6e7e5 100644 --- a/tests/cpp/container_test.cc +++ b/tests/cpp/container_test.cc @@ -21,6 +21,7 @@ #include #include #include +#include #include #include @@ -401,6 +402,74 @@ TEST(String, Cast) { String s2 = Downcast(r); } + +TEST(Optional, Composition) { + Optional opt0(nullptr); + Optional opt1 = String("xyz"); + Optional opt2 = String("xyz1"); + // operator bool + CHECK(!opt0); + CHECK(opt1); + // comparison op + CHECK(opt0 != "xyz"); + CHECK(opt1 == "xyz"); + CHECK(opt1 != nullptr); + CHECK(opt0 == nullptr); + CHECK(opt0.value_or("abc") == "abc"); + CHECK(opt1.value_or("abc") == "xyz"); + CHECK(opt0 != opt1); + CHECK(opt1 == Optional(String("xyz"))); + CHECK(opt0 == Optional(nullptr)); + opt0 = opt1; + CHECK(opt0 == opt1); + CHECK(opt0.value().same_as(opt1.value())); + opt0 = std::move(opt2); + CHECK(opt0 != opt2); +} + +TEST(Optional, IntCmp) { + Integer val(CallingConv::kDefault); + Optional opt = Integer(0); + CHECK(0 == static_cast(CallingConv::kDefault)); + CHECK(val == CallingConv::kDefault); + CHECK(opt == CallingConv::kDefault); + + // check we can handle implicit 0 to nullptr conversion. + Optional opt1(nullptr); + CHECK(opt1 != 0); + CHECK(opt1 != false); + CHECK(!(opt1 == 0)); +} + +TEST(Optional, PackedCall) { + auto tf = [](Optional s, bool isnull) { + if (isnull) { + CHECK(s == nullptr); + } else { + CHECK(s != nullptr); + } + return s; + }; + auto func = TypedPackedFunc(Optional, bool)>(tf); + CHECK(func(String("xyz"), false) == "xyz"); + CHECK(func(Optional(nullptr), true) == nullptr); + + auto pf = [](TVMArgs args, TVMRetValue* rv) { + Optional s = args[0]; + bool isnull = args[1]; + if (isnull) { + CHECK(s == nullptr); + } else { + CHECK(s != nullptr); + } + *rv = s; + }; + auto packedfunc = PackedFunc(pf); + CHECK(packedfunc("xyz", false).operator String() == "xyz"); + CHECK(packedfunc("xyz", false).operator Optional() == "xyz"); + CHECK(packedfunc(nullptr, true).operator Optional() == nullptr); +} + int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); testing::FLAGS_gtest_death_test_style = "threadsafe"; diff --git a/tests/python/relay/test_external_codegen.py b/tests/python/relay/test_external_codegen.py index b4496bb044bad..3797910080a1d 100644 --- a/tests/python/relay/test_external_codegen.py +++ b/tests/python/relay/test_external_codegen.py @@ -39,7 +39,7 @@ def update_lib(lib): contrib_path = os.path.join(source_dir, "src", "runtime", "contrib") kwargs = {} - kwargs["options"] = ["-O2", "-std=c++11", "-I" + contrib_path] + kwargs["options"] = ["-O2", "-std=c++14", "-I" + contrib_path] tmp_path = util.tempdir() lib_name = 'lib.so' lib_path = tmp_path.relpath(lib_name) diff --git a/tests/python/relay/test_external_runtime.py b/tests/python/relay/test_external_runtime.py index 397c35db53d94..39209232f3d05 100644 --- a/tests/python/relay/test_external_runtime.py +++ b/tests/python/relay/test_external_runtime.py @@ -468,13 +468,13 @@ def run_extern(label, get_extern_src, **kwargs): def test_dso_extern(): - run_extern("lib", generate_csource_module, options=["-O2", "-std=c++11"]) + run_extern("lib", generate_csource_module, options=["-O2", "-std=c++14"]) def test_engine_extern(): run_extern("engine", generate_engine_module, - options=["-O2", "-std=c++11", "-I" + tmp_path.relpath("")]) + options=["-O2", "-std=c++14", "-I" + tmp_path.relpath("")]) def test_json_extern(): if not tvm.get_global_func("module.loadfile_examplejson", True): diff --git a/tests/python/relay/test_pass_annotate_target.py b/tests/python/relay/test_pass_annotate_target.py index 705a2614674af..01ba9b619205b 100644 --- a/tests/python/relay/test_pass_annotate_target.py +++ b/tests/python/relay/test_pass_annotate_target.py @@ -42,7 +42,7 @@ def update_lib(lib): contrib_path = os.path.join(source_dir, "src", "runtime", "contrib") kwargs = {} - kwargs["options"] = ["-O2", "-std=c++11", "-I" + contrib_path] + kwargs["options"] = ["-O2", "-std=c++14", "-I" + contrib_path] tmp_path = util.tempdir() lib_name = 'lib.so' lib_path = tmp_path.relpath(lib_name) diff --git a/tests/python/relay/test_pass_partition_graph.py b/tests/python/relay/test_pass_partition_graph.py index c7d9626931d0e..1d0cc5b79a44c 100644 --- a/tests/python/relay/test_pass_partition_graph.py +++ b/tests/python/relay/test_pass_partition_graph.py @@ -182,7 +182,7 @@ def update_lib(lib): contrib_path = os.path.join(source_dir, "src", "runtime", "contrib") kwargs = {} - kwargs["options"] = ["-O2", "-std=c++11", "-I" + contrib_path] + kwargs["options"] = ["-O2", "-std=c++14", "-I" + contrib_path] tmp_path = util.tempdir() lib_name = 'lib.so' lib_path = tmp_path.relpath(lib_name) diff --git a/tests/python/unittest/test_runtime_module_export.py b/tests/python/unittest/test_runtime_module_export.py index 35bafb4ba3c7b..fce7d2f350dcd 100644 --- a/tests/python/unittest/test_runtime_module_export.py +++ b/tests/python/unittest/test_runtime_module_export.py @@ -191,7 +191,7 @@ def verify_multi_c_mod_export(): path_lib = temp.relpath(file_name) resnet18_cpu_lib.import_module(f) resnet18_cpu_lib.import_module(engine_module) - kwargs = {"options": ["-O2", "-std=c++11", "-I" + header_file_dir_path.relpath("")]} + kwargs = {"options": ["-O2", "-std=c++14", "-I" + header_file_dir_path.relpath("")]} resnet18_cpu_lib.export_library(path_lib, fcompile=False, **kwargs) loaded_lib = tvm.runtime.load_module(path_lib) assert loaded_lib.type_key == "library" diff --git a/vta/python/vta/exec/rpc_server.py b/vta/python/vta/exec/rpc_server.py index 65c1274214e6d..9cfd50927041f 100644 --- a/vta/python/vta/exec/rpc_server.py +++ b/vta/python/vta/exec/rpc_server.py @@ -107,7 +107,7 @@ def reconfig_runtime(cfg_json): if pkg.same_config(old_cfg): logging.info("Skip reconfig_runtime due to same config.") return - cflags = ["-O2", "-std=c++11"] + cflags = ["-O2", "-std=c++14"] cflags += pkg.cflags ldflags = pkg.ldflags lib_name = dll_path