diff --git a/apps/android_camera/app/src/main/jni/Application.mk b/apps/android_camera/app/src/main/jni/Application.mk index 95a5a9697bcc..63a79458ef94 100644 --- a/apps/android_camera/app/src/main/jni/Application.mk +++ b/apps/android_camera/app/src/main/jni/Application.mk @@ -31,7 +31,7 @@ include $(config) APP_ABI ?= all APP_STL := c++_shared -APP_CPPFLAGS += -DDMLC_LOG_STACK_TRACE=0 -DTVM4J_ANDROID=1 -std=c++11 -Oz -frtti +APP_CPPFLAGS += -DDMLC_LOG_STACK_TRACE=0 -DTVM4J_ANDROID=1 -std=c++14 -Oz -frtti ifeq ($(USE_OPENCL), 1) APP_CPPFLAGS += -DTVM_OPENCL_RUNTIME=1 endif diff --git a/apps/android_deploy/app/src/main/jni/Application.mk b/apps/android_deploy/app/src/main/jni/Application.mk index ee13eb8a1213..a50a40bf5cd1 100644 --- a/apps/android_deploy/app/src/main/jni/Application.mk +++ b/apps/android_deploy/app/src/main/jni/Application.mk @@ -27,7 +27,7 @@ include $(config) APP_STL := c++_static -APP_CPPFLAGS += -DDMLC_LOG_STACK_TRACE=0 -DTVM4J_ANDROID=1 -std=c++11 -Oz -frtti -ifeq ($(USE_OPENCL), 1) +APP_CPPFLAGS += -DDMLC_LOG_STACK_TRACE=0 -DTVM4J_ANDROID=1 -std=c++14 -Oz -frtti +ifeq ($(USE_OPENCL), 1) APP_CPPFLAGS += -DTVM_OPENCL_RUNTIME=1 endif diff --git a/apps/android_rpc/app/src/main/jni/Application.mk b/apps/android_rpc/app/src/main/jni/Application.mk index 56288bde9898..54abdf771e2a 100644 --- a/apps/android_rpc/app/src/main/jni/Application.mk +++ b/apps/android_rpc/app/src/main/jni/Application.mk @@ -31,7 +31,7 @@ include $(config) APP_ABI ?= armeabi-v7a arm64-v8a x86 x86_64 mips APP_STL := c++_shared -APP_CPPFLAGS += -DDMLC_LOG_STACK_TRACE=0 -DTVM4J_ANDROID=1 -std=c++11 -Oz -frtti +APP_CPPFLAGS += -DDMLC_LOG_STACK_TRACE=0 -DTVM4J_ANDROID=1 -std=c++14 -Oz -frtti ifeq ($(USE_OPENCL), 1) APP_CPPFLAGS += -DTVM_OPENCL_RUNTIME=1 endif diff --git a/apps/cpp_rpc/Makefile b/apps/cpp_rpc/Makefile index 9cd39b446acc..927331ad00ea 100644 --- a/apps/cpp_rpc/Makefile +++ b/apps/cpp_rpc/Makefile @@ -28,7 +28,7 @@ else LINK_PTHREAD= endif -PKG_CFLAGS = -std=c++11 -O2 -fPIC -Wall\ +PKG_CFLAGS = -std=c++14 -O2 -fPIC -Wall\ -I${TVM_ROOT}/include\ -I${DMLC_CORE}/include\ -I${TVM_ROOT}/3rdparty/dlpack/include diff --git a/apps/dso_plugin_module/Makefile b/apps/dso_plugin_module/Makefile index 2ee6189e2876..c2ce3306870a 100644 --- a/apps/dso_plugin_module/Makefile +++ b/apps/dso_plugin_module/Makefile @@ -16,7 +16,7 @@ # under the License. TVM_ROOT=$(shell cd ../..; pwd) -PKG_CFLAGS = -std=c++11 -O2 -fPIC\ +PKG_CFLAGS = -std=c++14 -O2 -fPIC\ -I${TVM_ROOT}/include\ -I${TVM_ROOT}/3rdparty/dmlc-core/include\ -I${TVM_ROOT}/3rdparty/dlpack/include diff --git a/apps/extension/Makefile b/apps/extension/Makefile index e178b661f403..91d914aba63b 100644 --- a/apps/extension/Makefile +++ b/apps/extension/Makefile @@ -17,7 +17,7 @@ # Minimum Makefile for the extension package TVM_ROOT=$(shell cd ../..; pwd) -PKG_CFLAGS = -std=c++11 -O2 -fPIC\ +PKG_CFLAGS = -std=c++14 -O2 -fPIC\ -I${TVM_ROOT}/include\ -I${TVM_ROOT}/3rdparty/dmlc-core/include\ -I${TVM_ROOT}/3rdparty/dlpack/include diff --git a/apps/howto_deploy/Makefile b/apps/howto_deploy/Makefile index a260e89bc042..4ee243c2ce60 100644 --- a/apps/howto_deploy/Makefile +++ b/apps/howto_deploy/Makefile @@ -19,7 +19,7 @@ TVM_ROOT=$(shell cd ../..; pwd) DMLC_CORE=${TVM_ROOT}/3rdparty/dmlc-core -PKG_CFLAGS = -std=c++11 -O2 -fPIC\ +PKG_CFLAGS = -std=c++14 -O2 -fPIC\ -I${TVM_ROOT}/include\ -I${DMLC_CORE}/include\ -I${TVM_ROOT}/3rdparty/dlpack/include\ diff --git a/apps/howto_deploy/tvm_runtime_pack.cc b/apps/howto_deploy/tvm_runtime_pack.cc index d166eaf756a5..81bab497bebb 100644 --- a/apps/howto_deploy/tvm_runtime_pack.cc +++ b/apps/howto_deploy/tvm_runtime_pack.cc @@ -24,7 +24,7 @@ * include in your project. * * - Copy this file into your project which depends on tvm runtime. - * - Compile with -std=c++11 + * - Compile with -std=c++14 * - Add the following include path * - /path/to/tvm/include/ * - /path/to/tvm/3rdparty/dmlc-core/include/ diff --git a/apps/rocm_rpc/Makefile b/apps/rocm_rpc/Makefile index 36eb41596be8..971ca4603314 100644 --- a/apps/rocm_rpc/Makefile +++ b/apps/rocm_rpc/Makefile @@ -21,7 +21,7 @@ ROCM_PATH=/opt/rocm TVM_ROOT=$(shell cd ../..; pwd) DMLC_CORE=${TVM_ROOT}/3rdparty/dmlc-core -PKG_CFLAGS = -std=c++11 -O2 -fPIC\ +PKG_CFLAGS = -std=c++14 -O2 -fPIC\ -I${TVM_ROOT}/include\ -I${DMLC_CORE}/include\ -I${TVM_ROOT}/3rdparty/dlpack/include\ diff --git a/apps/tf_tvmdsoop/CMakeLists.txt b/apps/tf_tvmdsoop/CMakeLists.txt index cb601ef6d30d..f4e83c528701 100644 --- a/apps/tf_tvmdsoop/CMakeLists.txt +++ b/apps/tf_tvmdsoop/CMakeLists.txt @@ -17,7 +17,7 @@ cmake_minimum_required(VERSION 3.2) project(tf_tvmdsoop C CXX) -set(TFTVM_COMPILE_FLAGS -std=c++11) +set(TFTVM_COMPILE_FLAGS -std=c++14) set(BUILD_TVMDSOOP_ONLY ON) set(CMAKE_CURRENT_SOURCE_DIR ${TVM_ROOT}) set(CMAKE_CURRENT_BINARY_DIR ${TVM_ROOT}/build) diff --git a/golang/Makefile b/golang/Makefile index c54fd0e0992c..6fd77996e119 100644 --- a/golang/Makefile +++ b/golang/Makefile @@ -25,7 +25,7 @@ NATIVE_SRC = tvm_runtime_pack.cc GOPATH=$(CURDIR)/gopath GOPATHDIR=${GOPATH}/src/${TARGET}/ CGO_CPPFLAGS="-I. -I${TVM_BASE}/ -I${TVM_BASE}/3rdparty/dmlc-core/include -I${TVM_BASE}/include -I${TVM_BASE}/3rdparty/dlpack/include/" -CGO_CXXFLAGS="-std=c++11" +CGO_CXXFLAGS="-std=c++14" CGO_CFLAGS="-I${TVM_BASE}" CGO_LDFLAGS="-ldl -lm" diff --git a/include/tvm/ir/attrs.h b/include/tvm/ir/attrs.h index 0fc832e0fb7a..d12f1b85114c 100644 --- a/include/tvm/ir/attrs.h +++ b/include/tvm/ir/attrs.h @@ -85,6 +85,8 @@ namespace tvm { */ template inline TObjectRef NullValue() { + static_assert(TObjectRef::_type_is_nullable, + "Can only get NullValue for nullable types"); return TObjectRef(ObjectPtr(nullptr)); } diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h index 4e0a301156a3..9d77136e14ff 100644 --- a/include/tvm/ir/expr.h +++ b/include/tvm/ir/expr.h @@ -317,6 +317,47 @@ class FloatImm : public PrimExpr { TVM_DEFINE_OBJECT_REF_METHODS(FloatImm, PrimExpr, FloatImmNode); }; +/*! + * \brief Boolean constant. + * + * This reference type is useful to add additional compile-time + * type checks and helper functions for Integer equal comparisons. + */ +class Bool : public IntImm { + public: + explicit Bool(bool value) + : IntImm(DataType::Bool(), value) { + } + Bool operator!() const { + return Bool((*this)->value == 0); + } + operator bool() const { + return (*this)->value != 0; + } + + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Bool, IntImm, IntImmNode); +}; + +// Overload operators to make sure we have the most fine grained types. +inline Bool operator||(const Bool& a, bool b) { + return Bool(a.operator bool() || b); +} +inline Bool operator||(bool a, const Bool& b) { + return Bool(a || b.operator bool()); +} +inline Bool operator||(const Bool& a, const Bool& b) { + return Bool(a.operator bool() || b.operator bool()); +} +inline Bool operator&&(const Bool& a, bool b) { + return Bool(a.operator bool() && b); +} +inline Bool operator&&(bool a, const Bool& b) { + return Bool(a && b.operator bool()); +} +inline Bool operator&&(const Bool& a, const Bool& b) { + return Bool(a.operator bool() && b.operator bool()); +} + /*! * \brief Container of constant int that adds more constructors. * @@ -346,10 +387,10 @@ class Integer : public IntImm { * \tparam Enum The enum type. * \param value The enum value. */ - template::value>::type> - explicit Integer(ENum value) : Integer(static_cast(value)) { - static_assert(std::is_same::type>::value, + template::value>::type> + explicit Integer(Enum value) : Integer(static_cast(value)) { + static_assert(std::is_same::type>::value, "declare enum to be enum int to use visitor"); } /*! @@ -368,6 +409,24 @@ class Integer : public IntImm { << " Trying to reference a null Integer"; return (*this)->value; } + // comparators + Bool operator==(int other) const { + if (data_ == nullptr) return Bool(false); + return Bool((*this)->value == other); + } + Bool operator!=(int other) const { + return !(*this == other); + } + template::value>::type> + Bool operator==(Enum other) const { + return *this == static_cast(other); + } + template::value>::type> + Bool operator!=(Enum other) const { + return *this != static_cast(other); + } }; /*! \brief range over one dimension */ diff --git a/include/tvm/ir/function.h b/include/tvm/ir/function.h index dc7a2b218568..d55656f34b00 100644 --- a/include/tvm/ir/function.h +++ b/include/tvm/ir/function.h @@ -26,6 +26,7 @@ #include #include +#include #include #include @@ -90,25 +91,31 @@ class BaseFuncNode : public RelayExprNode { * \code * * void GetAttrExample(const BaseFunc& f) { - * Integer value = f->GetAttr("AttrKey", 0); + * auto value = f->GetAttr("AttrKey", 0); * } * * \endcode */ template - TObjectRef GetAttr(const std::string& attr_key, - TObjectRef default_value = NullValue()) const { + Optional GetAttr( + const std::string& attr_key, + Optional default_value = Optional(nullptr)) const { static_assert(std::is_base_of::value, "Can only call GetAttr with ObjectRef types."); if (!attrs.defined()) return default_value; auto it = attrs->dict.find(attr_key); if (it != attrs->dict.end()) { - return Downcast((*it).second); + return Downcast>((*it).second); } else { return default_value; } } - + // variant that uses TObjectRef to enable implicit conversion to default value. + template + Optional GetAttr( + const std::string& attr_key, TObjectRef default_value) const { + return GetAttr(attr_key, Optional(default_value)); + } /*! * \brief Check whether the function has an non-zero integer attr. * @@ -129,7 +136,7 @@ class BaseFuncNode : public RelayExprNode { * \endcode */ bool HasNonzeroAttr(const std::string& attr_key) const { - return GetAttr(attr_key, 0)->value != 0; + return GetAttr(attr_key, 0) != 0; } static constexpr const char* _type_key = "BaseFunc"; diff --git a/include/tvm/node/node.h b/include/tvm/node/node.h index b39e3b403421..471a0de361b7 100644 --- a/include/tvm/node/node.h +++ b/include/tvm/node/node.h @@ -63,7 +63,6 @@ using runtime::make_object; using runtime::PackedFunc; using runtime::TVMArgs; using runtime::TVMRetValue; -using runtime::String; } // namespace tvm #endif // TVM_NODE_NODE_H_ diff --git a/include/tvm/runtime/container.h b/include/tvm/runtime/container.h index 8963f0921276..8f426415ffee 100644 --- a/include/tvm/runtime/container.h +++ b/include/tvm/runtime/container.h @@ -353,6 +353,10 @@ class StringObj : public Object { */ class String : public ObjectRef { public: + /*! + * \brief Construct an empty string. + */ + String() : String(std::string()) {} /*! * \brief Construct a new String object * @@ -467,9 +471,6 @@ class String : public ObjectRef { */ size_t size() const { const auto* ptr = get(); - if (ptr == nullptr) { - return 0; - } return ptr->size; } @@ -524,7 +525,7 @@ class String : public ObjectRef { /*! \return the internal StringObj pointer */ const StringObj* get() const { return operator->(); } - TVM_DEFINE_OBJECT_REF_METHODS(String, ObjectRef, StringObj); + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(String, ObjectRef, StringObj); private: /*! @@ -610,7 +611,146 @@ struct PackedFuncValueConverter<::tvm::runtime::String> { } }; +/*! + * \brief Optional container that to represent to a Nullable variant of T. + * \tparam T The original ObjectRef. + * + * \code + * + * Optional opt0 = nullptr; + * Optional opt1 = String("xyz"); + * CHECK(opt0 == nullptr); + * CHECK(opt1 == "xyz"); + * + * \endcode + */ +template +class Optional : public ObjectRef { + public: + using ContainerType = typename T::ContainerType; + static_assert(std::is_base_of::value, + "Optional is only defined for ObjectRef."); + // default constructors. + Optional() = default; + Optional(const Optional&) = default; + Optional(Optional&&) = default; + Optional& operator=(const Optional&) = default; + Optional& operator=(Optional&&) = default; + /*! + * \brief Construct from an ObjectPtr + * whose type already matches the ContainerType. + * \param ptr + */ + explicit Optional(ObjectPtr ptr) : ObjectRef(ptr) {} + // nullptr handling. + // disallow implicit conversion as 0 can be implicitly converted to nullptr_t + explicit Optional(std::nullptr_t) {} + Optional& operator=(std::nullptr_t) { + data_ = nullptr; + return *this; + } + // normal value handling. + Optional(T other) // NOLINT(*) + : ObjectRef(std::move(other)) { + } + Optional& operator=(T other) { + ObjectRef::operator=(std::move(other)); + return *this; + } + // delete the int constructor + // since Optional(0) is ambiguious + // 0 can be implicitly casted to nullptr_t + explicit Optional(int val) = delete; + Optional& operator=(int val) = delete; + /*! + * \return A not-null container value in the optional. + * \note This function performs not-null checking. + */ + T value() const { + CHECK(data_ != nullptr); + return T(data_); + } + /*! + * \return The contained value if the Optional is not null + * otherwise return the default_value. + */ + T value_or(T default_value) const { + return data_ != nullptr ? T(data_) : default_value; + } + /*! \return Whether the container is not nullptr.*/ + explicit operator bool() const { + return *this != nullptr; + } + // operator overloadings + bool operator==(std::nullptr_t) const { + return data_ == nullptr; + } + bool operator!=(std::nullptr_t) const { + return data_ != nullptr; + } + auto operator==(const Optional& other) const { + // support case where sub-class returns a symbolic ref type. + using RetType = decltype(value() == other.value()); + if (same_as(other)) return RetType(true); + if (*this != nullptr && other != nullptr) { + return value() == other.value(); + } else { + // one of them is nullptr. + return RetType(false); + } + } + auto operator!=(const Optional& other) const { + // support case where sub-class returns a symbolic ref type. + using RetType = decltype(value() != other.value()); + if (same_as(other)) return RetType(false); + if (*this != nullptr && other != nullptr) { + return value() != other.value(); + } else { + // one of them is nullptr. + return RetType(true); + } + } + auto operator==(const T& other) const { + using RetType = decltype(value() == other); + if (same_as(other)) return RetType(true); + if (*this != nullptr) return value() == other; + return RetType(false); + } + auto operator!=(const T& other) const { + return !(*this == other); + } + template + auto operator==(const U& other) const { + using RetType = decltype(value() == other); + if (*this == nullptr) return RetType(false); + return value() == other; + } + template + auto operator!=(const U& other) const { + using RetType = decltype(value() != other); + if (*this == nullptr) return RetType(true); + return value() != other; + } + static constexpr bool _type_is_nullable = true; +}; + +template +struct PackedFuncValueConverter> { + static Optional From(const TVMArgValue& val) { + if (val.type_code() == kTVMNullptr) return Optional(nullptr); + return PackedFuncValueConverter::From(val); + } + static Optional From(const TVMRetValue& val) { + if (val.type_code() == kTVMNullptr) return Optional(nullptr); + return PackedFuncValueConverter::From(val); + } +}; + } // namespace runtime + +// expose the functions to the root namespace. +using runtime::String; +using runtime::Optional; } // namespace tvm namespace std { diff --git a/include/tvm/runtime/object.h b/include/tvm/runtime/object.h index acbb9398b74c..edca925baeb0 100644 --- a/include/tvm/runtime/object.h +++ b/include/tvm/runtime/object.h @@ -546,7 +546,9 @@ class ObjectRef { bool operator<(const ObjectRef& other) const { return data_.get() < other.data_.get(); } - /*! \return whether the expression is null */ + /*! + * \return whether the object is defined(not null). + */ bool defined() const { return data_ != nullptr; } @@ -582,6 +584,8 @@ class ObjectRef { /*! \brief type indicate the container type. */ using ContainerType = Object; + // Default type properties for the reference class. + static constexpr bool _type_is_nullable = true; protected: /*! \brief Internal pointer that backs the reference. */ @@ -720,6 +724,17 @@ struct ObjectEqual { TVM_STR_CONCAT(TVM_OBJECT_REG_VAR_DEF, __COUNTER__) = \ TypeName::_GetOrAllocRuntimeTypeIndex() + +/* + * \brief Define the default copy/move constructor and assign opeator + * \param TypeName The class typename. + */ +#define TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName) \ + TypeName(const TypeName& other) = default; \ + TypeName(TypeName&& other) = default; \ + TypeName& operator=(const TypeName& other) = default; \ + TypeName& operator=(TypeName&& other) = default; \ + /* * \brief Define object reference methods. * \param TypeName The object type name @@ -727,15 +742,34 @@ struct ObjectEqual { * \param ObjectName The type name of the object. */ #define TVM_DEFINE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName) \ - TypeName() {} \ + TypeName() = default; \ explicit TypeName( \ ::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n) \ : ParentType(n) {} \ + TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName); \ const ObjectName* operator->() const { \ return static_cast(data_.get()); \ } \ using ContainerType = ObjectName; +/* + * \brief Define object reference methods that is not nullable. + * + * \param TypeName The object type name + * \param ParentType The parent type of the objectref + * \param ObjectName The type name of the object. + */ +#define TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName) \ + explicit TypeName( \ + ::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n) \ + : ParentType(n) {} \ + TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName); \ + const ObjectName* operator->() const { \ + return static_cast(data_.get()); \ + } \ + static constexpr bool _type_is_nullable = false; \ + using ContainerType = ObjectName; + /* * \brief Define object reference methods of whose content is mutable. * \param TypeName The object type name @@ -745,7 +779,8 @@ struct ObjectEqual { * This macro is only reserved for objects that stores runtime states. */ #define TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName) \ - TypeName() {} \ + TypeName() = default; \ + TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName); \ explicit TypeName( \ ::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n) \ : ParentType(n) {} \ @@ -869,11 +904,14 @@ inline const ObjectType* ObjectRef::as() const { } } -template -inline RelayRefType GetRef(const ObjType* ptr) { - static_assert(std::is_base_of::value, +template +inline RefType GetRef(const ObjType* ptr) { + static_assert(std::is_base_of::value, "Can only cast to the ref of same container type"); - return RelayRefType(ObjectPtr(const_cast(static_cast(ptr)))); + if (!RefType::_type_is_nullable) { + CHECK(ptr != nullptr); + } + return RefType(ObjectPtr(const_cast(static_cast(ptr)))); } template @@ -885,9 +923,15 @@ inline ObjectPtr GetObjectPtr(ObjType* ptr) { template inline SubRef Downcast(BaseRef ref) { - CHECK(!ref.defined() || ref->template IsInstance()) - << "Downcast from " << ref->GetTypeKey() << " to " - << SubRef::ContainerType::_type_key << " failed."; + if (ref.defined()) { + CHECK(ref->template IsInstance()) + << "Downcast from " << ref->GetTypeKey() << " to " + << SubRef::ContainerType::_type_key << " failed."; + } else { + CHECK(SubRef::_type_is_nullable) + << "Downcast from nullptr to not nullable reference of " + << SubRef::ContainerType::_type_key; + } return SubRef(std::move(ref.data_)); } diff --git a/include/tvm/runtime/packed_func.h b/include/tvm/runtime/packed_func.h index c5f0df57b10c..3d5a7e865303 100644 --- a/include/tvm/runtime/packed_func.h +++ b/include/tvm/runtime/packed_func.h @@ -352,7 +352,7 @@ template struct ObjectTypeChecker { static bool Check(const Object* ptr) { using ContainerType = typename T::ContainerType; - if (ptr == nullptr) return true; + if (ptr == nullptr) return T::_type_is_nullable; return ptr->IsInstance(); } static std::string TypeName() { @@ -1400,7 +1400,11 @@ inline TObjectRef TVMPODValue_::AsObjectRef() const { std::is_base_of::value, "Conversion only works for ObjectRef"); using ContainerType = typename TObjectRef::ContainerType; - if (type_code_ == kTVMNullptr) return TObjectRef(ObjectPtr(nullptr)); + if (type_code_ == kTVMNullptr) { + CHECK(TObjectRef::_type_is_nullable) + << "Expect a not null value of " << ContainerType::_type_key; + return TObjectRef(ObjectPtr(nullptr)); + } // NOTE: the following code can be optimized by constant folding. if (std::is_base_of::value) { // Casting to a sub-class of NDArray diff --git a/python/setup.py b/python/setup.py index 937d682e3c85..62f374923714 100644 --- a/python/setup.py +++ b/python/setup.py @@ -96,7 +96,7 @@ def config_cython(): "../3rdparty/dmlc-core/include", "../3rdparty/dlpack/include", ], - extra_compile_args=["-std=c++11"], + extra_compile_args=["-std=c++14"], library_dirs=library_dirs, libraries=libraries, language="c++")) diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index f939f0a1e7d6..d7955a2ca620 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -244,8 +244,9 @@ split_dev_host_funcs(IRModule mod_mixed, auto host_pass_list = { FilterBy([](const tir::PrimFunc& f) { - int64_t value = f->GetAttr(tvm::attr::kCallingConv, 0)->value; - return value != static_cast(CallingConv::kDeviceKernelLaunch); + return f->GetAttr( + tvm::attr::kCallingConv, + Integer(CallingConv::kDefault)) != CallingConv::kDeviceKernelLaunch; }), BindTarget(target_host), tir::transform::LowerTVMBuiltin(), @@ -259,8 +260,9 @@ split_dev_host_funcs(IRModule mod_mixed, // device pipeline auto device_pass_list = { FilterBy([](const tir::PrimFunc& f) { - int64_t value = f->GetAttr(tvm::attr::kCallingConv, 0)->value; - return value == static_cast(CallingConv::kDeviceKernelLaunch); + return f->GetAttr( + tvm::attr::kCallingConv, + Integer(CallingConv::kDefault)) == CallingConv::kDeviceKernelLaunch; }), BindTarget(target), tir::transform::LowerWarpMemory(), diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc index 9cb6b2efe28c..4ed8fbc15abd 100644 --- a/src/relay/backend/compile_engine.cc +++ b/src/relay/backend/compile_engine.cc @@ -620,14 +620,14 @@ class CompileEngineImpl : public CompileEngineNode { if (src_func->GetAttr(attr::kCompiler).defined()) { auto code_gen = src_func->GetAttr(attr::kCompiler); CHECK(code_gen.defined()) << "No external codegen is set"; - std::string code_gen_name = code_gen; + std::string code_gen_name = code_gen.value(); if (ext_mods.find(code_gen_name) == ext_mods.end()) { ext_mods[code_gen_name] = IRModule({}, {}); } auto symbol_name = src_func->GetAttr(tvm::attr::kGlobalSymbol); CHECK(symbol_name.defined()) << "No external symbol is set for:\n" << AsText(src_func, false); - auto gv = GlobalVar(std::string(symbol_name)); + auto gv = GlobalVar(std::string(symbol_name.value())); ext_mods[code_gen_name]->Add(gv, src_func); cached_ext_funcs.push_back(it.first); } @@ -698,7 +698,7 @@ class CompileEngineImpl : public CompileEngineNode { key->source_func->GetAttr(tvm::attr::kGlobalSymbol); CHECK(name_node.defined()) << "External function has not been attached a name yet."; - cache_node->func_name = std::string(name_node); + cache_node->func_name = std::string(name_node.value()); cache_node->target = tvm::target::ext_dev(); value->cached_func = CachedFunc(cache_node); return value; diff --git a/src/relay/backend/contrib/codegen_c/codegen_c.h b/src/relay/backend/contrib/codegen_c/codegen_c.h index 1b953f3c4467..7dfa4bac06ed 100644 --- a/src/relay/backend/contrib/codegen_c/codegen_c.h +++ b/src/relay/backend/contrib/codegen_c/codegen_c.h @@ -72,7 +72,7 @@ class CSourceModuleCodegenBase { const auto name_node = func->GetAttr(tvm::attr::kGlobalSymbol); CHECK(name_node.defined()) << "Fail to retrieve external symbol."; - return std::string(name_node); + return std::string(name_node.value()); } }; diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index e2b0fffec8bd..8af6247fc810 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -446,7 +446,7 @@ class VMFunctionCompiler : ExprFunctor { const Expr& outputs) { std::vector argument_registers; - CHECK_NE(func->GetAttr(attr::kPrimitive, 0)->value, 0) + CHECK(func->GetAttr(attr::kPrimitive, 0) != 0) << "internal error: invoke_tvm_op requires the first argument to be a relay::Function"; auto input_tuple = inputs.as(); diff --git a/src/relay/backend/vm/lambda_lift.cc b/src/relay/backend/vm/lambda_lift.cc index 59c549cabfee..bfbefd57a310 100644 --- a/src/relay/backend/vm/lambda_lift.cc +++ b/src/relay/backend/vm/lambda_lift.cc @@ -45,7 +45,7 @@ inline std::string GenerateName(const Function& func) { } bool IsClosure(const Function& func) { - return func->GetAttr(attr::kClosure, 0)->value != 0; + return func->GetAttr(attr::kClosure, 0) != 0; } Function MarkClosure(Function func) { diff --git a/src/relay/ir/transform.cc b/src/relay/ir/transform.cc index a06eb5a4d347..06dd2b16661f 100644 --- a/src/relay/ir/transform.cc +++ b/src/relay/ir/transform.cc @@ -145,8 +145,8 @@ IRModule FunctionPassNode::operator()(IRModule mod, } bool FunctionPassNode::SkipFunction(const Function& func) const { - return func->GetAttr(attr::kSkipOptimization, 0)->value != 0 || - (func->GetAttr(attr::kCompiler).defined()); + return (func->GetAttr(attr::kCompiler).defined()) || + func->GetAttr(attr::kSkipOptimization, 0) != 0; } Pass CreateFunctionPass( diff --git a/src/relay/transforms/annotate_target.cc b/src/relay/transforms/annotate_target.cc index 44d7b54e9637..2499982e321a 100644 --- a/src/relay/transforms/annotate_target.cc +++ b/src/relay/transforms/annotate_target.cc @@ -158,9 +158,9 @@ class AnnotateTargetWrapper : public ExprMutator { // if it is in the target list. Function func = Downcast(cn->op); CHECK(func.defined()); - auto comp_name = func->GetAttr(attr::kComposite); - if (comp_name.defined()) { - std::string comp_name_str = comp_name; + + if (auto comp_name = func->GetAttr(attr::kComposite)) { + std::string comp_name_str = comp_name.value(); size_t i = comp_name_str.find('.'); if (i != std::string::npos) { std::string comp_target = comp_name_str.substr(0, i); diff --git a/src/target/build_common.h b/src/target/build_common.h index 5ba51da4ce67..93687c2578ac 100644 --- a/src/target/build_common.h +++ b/src/target/build_common.h @@ -51,14 +51,14 @@ ExtractFuncInfo(const IRModule& mod) { for (size_t i = 0; i < f->params.size(); ++i) { info.arg_types.push_back(f->params[i].dtype()); } - auto thread_axis = f->GetAttr>(tir::attr::kDeviceThreadAxis); - if (thread_axis.defined()) { + if (auto opt = f->GetAttr>(tir::attr::kDeviceThreadAxis)) { + auto thread_axis = opt.value(); for (size_t i = 0; i < thread_axis.size(); ++i) { info.thread_axis_tags.push_back(thread_axis[i]->thread_tag); } } auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); - fmap[static_cast(global_symbol)] = info; + fmap[static_cast(global_symbol.value())] = info; } return fmap; } diff --git a/src/target/llvm/codegen_cpu.cc b/src/target/llvm/codegen_cpu.cc index a863056e8226..ad09730e6db8 100644 --- a/src/target/llvm/codegen_cpu.cc +++ b/src/target/llvm/codegen_cpu.cc @@ -130,7 +130,7 @@ void CodeGenCPU::AddFunction(const PrimFunc& f) { CHECK(global_symbol.defined()) << "CodeGenLLVM: Expect PrimFunc to have the global_symbol attribute"; export_system_symbols_.emplace_back( - std::make_pair(global_symbol.operator std::string(), + std::make_pair(global_symbol.value().operator std::string(), builder_->CreatePointerCast(function_, t_void_p_))); } AddDebugInformation(function_); diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index 7112691de1bc..604533933b92 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -131,12 +131,12 @@ void CodeGenLLVM::AddFunctionInternal(const PrimFunc& f, bool ret_void) { auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); CHECK(global_symbol.defined()) << "CodeGenLLVM: Expect PrimFunc to have the global_symbol attribute"; - CHECK(module_->getFunction(static_cast(global_symbol)) == nullptr) + CHECK(module_->getFunction(static_cast(global_symbol.value())) == nullptr) << "Function " << global_symbol << " already exist in module"; function_ = llvm::Function::Create( ftype, llvm::Function::ExternalLinkage, - global_symbol.operator std::string(), module_.get()); + global_symbol.value().operator std::string(), module_.get()); function_->setCallingConv(llvm::CallingConv::C); function_->setDLLStorageClass(llvm::GlobalValue::DLLStorageClassTypes::DLLExportStorageClass); diff --git a/src/target/llvm/llvm_module.cc b/src/target/llvm/llvm_module.cc index 52dccbaf5eb6..d1a244d01ff4 100644 --- a/src/target/llvm/llvm_module.cc +++ b/src/target/llvm/llvm_module.cc @@ -216,7 +216,7 @@ class LLVMModuleNode final : public runtime::ModuleNode { if (f->HasNonzeroAttr(tir::attr::kIsEntryFunc)) { auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); CHECK(global_symbol.defined()); - entry_func = global_symbol; + entry_func = global_symbol.value(); } funcs.push_back(f); } diff --git a/src/target/opt/build_cuda_on.cc b/src/target/opt/build_cuda_on.cc index 634fb9a57f27..2d659e4487e3 100644 --- a/src/target/opt/build_cuda_on.cc +++ b/src/target/opt/build_cuda_on.cc @@ -138,8 +138,7 @@ runtime::Module BuildCUDA(IRModule mod) { << "CodeGenCUDA: Can only take PrimFunc"; auto f = Downcast(kv.second); auto calling_conv = f->GetAttr(tvm::attr::kCallingConv); - CHECK(calling_conv.defined() && - calling_conv->value == static_cast(CallingConv::kDeviceKernelLaunch)) + CHECK(calling_conv == CallingConv::kDeviceKernelLaunch) << "CodeGenCUDA: expect calling_conv equals CallingConv::kDeviceKernelLaunch"; cg.AddFunction(f); } diff --git a/src/target/source/codegen_aocl.cc b/src/target/source/codegen_aocl.cc index c6011cd4dc87..64674e3360dd 100644 --- a/src/target/source/codegen_aocl.cc +++ b/src/target/source/codegen_aocl.cc @@ -45,8 +45,7 @@ runtime::Module BuildAOCL(IRModule mod, << "CodegenOpenCL: Can only take PrimFunc"; auto f = Downcast(kv.second); auto calling_conv = f->GetAttr(tvm::attr::kCallingConv); - CHECK(calling_conv.defined() && - calling_conv->value == static_cast(CallingConv::kDeviceKernelLaunch)) + CHECK(calling_conv == CallingConv::kDeviceKernelLaunch) << "CodegenOpenCL: expect calling_conv equals CallingConv::kDeviceKernelLaunch"; cg.AddFunction(f); } diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc index a0e18a612055..444dc996b10f 100644 --- a/src/target/source/codegen_c.cc +++ b/src/target/source/codegen_c.cc @@ -84,7 +84,7 @@ void CodeGenC::AddFunction(const PrimFunc& f) { bool no_alias = f->HasNonzeroAttr(tir::attr::kNoAlias); this->PrintFuncPrefix(); - this->stream << " " << static_cast(global_symbol) << "("; + this->stream << " " << static_cast(global_symbol.value()) << "("; for (size_t i = 0; i < f->params.size(); ++i) { tir::Var v = f->params[i]; diff --git a/src/target/source/codegen_metal.cc b/src/target/source/codegen_metal.cc index 715c0ae92ddc..ea49d33351a0 100644 --- a/src/target/source/codegen_metal.cc +++ b/src/target/source/codegen_metal.cc @@ -61,7 +61,7 @@ void CodeGenMetal::AddFunction(const PrimFunc& f) { << "CodeGenC: Expect PrimFunc to have the global_symbol attribute"; // Function header. - this->stream << "kernel void " << static_cast(global_symbol) << "("; + this->stream << "kernel void " << static_cast(global_symbol.value()) << "("; // Buffer arguments size_t num_buffer = 0; @@ -91,7 +91,8 @@ void CodeGenMetal::AddFunction(const PrimFunc& f) { size_t nargs = f->params.size() - num_buffer; std::string varg = GetUniqueName("arg"); if (nargs != 0) { - std::string arg_buf_type = static_cast(global_symbol) + "_args_t"; + std::string arg_buf_type = + static_cast(global_symbol.value()) + "_args_t"; stream << " constant " << arg_buf_type << "& " << varg << " [[ buffer(" << num_buffer << ") ]],\n"; // declare the struct @@ -120,8 +121,8 @@ void CodeGenMetal::AddFunction(const PrimFunc& f) { CHECK_EQ(GetUniqueName("threadIdx"), "threadIdx"); CHECK_EQ(GetUniqueName("blockIdx"), "blockIdx"); int work_dim = 0; - auto thread_axis = f->GetAttr>(tir::attr::kDeviceThreadAxis); - CHECK(thread_axis.defined()); + auto thread_axis = f->GetAttr>( + tir::attr::kDeviceThreadAxis).value(); for (IterVar iv : thread_axis) { runtime::ThreadScope scope = runtime::ThreadScope::make(iv->thread_tag); @@ -278,8 +279,7 @@ runtime::Module BuildMetal(IRModule mod) { << "CodeGenMetal: Can only take PrimFunc"; auto f = Downcast(kv.second); auto calling_conv = f->GetAttr(tvm::attr::kCallingConv); - CHECK(calling_conv.defined() && - calling_conv->value == static_cast(CallingConv::kDeviceKernelLaunch)) + CHECK(calling_conv == CallingConv::kDeviceKernelLaunch) << "CodeGenMetal: expect calling_conv equals CallingConv::kDeviceKernelLaunch"; cg.AddFunction(f); } diff --git a/src/target/source/codegen_opencl.cc b/src/target/source/codegen_opencl.cc index 67761c17680a..d5b89609e514 100644 --- a/src/target/source/codegen_opencl.cc +++ b/src/target/source/codegen_opencl.cc @@ -249,8 +249,7 @@ runtime::Module BuildOpenCL(IRModule mod) { << "CodeGenOpenCL: Can only take PrimFunc"; auto f = Downcast(kv.second); auto calling_conv = f->GetAttr(tvm::attr::kCallingConv); - CHECK(calling_conv.defined() && - calling_conv->value == static_cast(CallingConv::kDeviceKernelLaunch)) + CHECK(calling_conv == CallingConv::kDeviceKernelLaunch) << "CodeGenOpenCL: expect calling_conv equals CallingConv::kDeviceKernelLaunch"; cg.AddFunction(f); } diff --git a/src/target/source/codegen_opengl.cc b/src/target/source/codegen_opengl.cc index 13d87d282e6c..946b483a1dd9 100644 --- a/src/target/source/codegen_opengl.cc +++ b/src/target/source/codegen_opengl.cc @@ -160,7 +160,7 @@ void CodeGenOpenGL::AddFunction(const PrimFunc& f) { CHECK(global_symbol.defined()) << "CodeGenOpenGL: Expect PrimFunc to have the global_symbol attribute"; - shaders_[static_cast(global_symbol)] = runtime::OpenGLShader( + shaders_[static_cast(global_symbol.value())] = runtime::OpenGLShader( this->decl_stream.str() + this->stream.str(), std::move(arg_names), std::move(arg_kinds), this->thread_extent_var_); @@ -299,8 +299,7 @@ runtime::Module BuildOpenGL(IRModule mod) { << "CodeGenOpenGL: Can only take PrimFunc"; auto f = Downcast(kv.second); auto calling_conv = f->GetAttr(tvm::attr::kCallingConv); - CHECK(calling_conv.defined() && - calling_conv->value == static_cast(CallingConv::kDeviceKernelLaunch)) + CHECK(calling_conv == CallingConv::kDeviceKernelLaunch) << "CodeGenOpenGL: expect calling_conv equals CallingConv::kDeviceKernelLaunch"; cg.AddFunction(f); } diff --git a/src/target/source/codegen_vhls.cc b/src/target/source/codegen_vhls.cc index 7486164444c4..71c36264afa4 100644 --- a/src/target/source/codegen_vhls.cc +++ b/src/target/source/codegen_vhls.cc @@ -138,8 +138,7 @@ runtime::Module BuildSDAccel(IRModule mod, std::string target_str) { << "CodeGenVHLS: Can only take PrimFunc"; auto f = Downcast(kv.second); auto calling_conv = f->GetAttr(tvm::attr::kCallingConv); - CHECK(calling_conv.defined() && - calling_conv->value == static_cast(CallingConv::kDeviceKernelLaunch)) + CHECK(calling_conv == CallingConv::kDeviceKernelLaunch) << "CodeGenVLHS: expect calling_conv equals CallingConv::kDeviceKernelLaunch"; cg.AddFunction(f); } @@ -164,7 +163,7 @@ runtime::Module BuildSDAccel(IRModule mod, std::string target_str) { auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); CHECK(global_symbol.defined()) << "CodeGenC: Expect PrimFunc to have the global_symbol attribute"; - kernel_info.push_back({global_symbol, code}); + kernel_info.push_back({global_symbol.value(), code}); } std::string xclbin; diff --git a/src/target/spirv/build_vulkan.cc b/src/target/spirv/build_vulkan.cc index 58721414a665..161c1ca3bab1 100644 --- a/src/target/spirv/build_vulkan.cc +++ b/src/target/spirv/build_vulkan.cc @@ -87,14 +87,13 @@ runtime::Module BuildSPIRV(IRModule mod) { << "CodeGenSPIRV: Can only take PrimFunc"; auto f = Downcast(kv.second); auto calling_conv = f->GetAttr(tvm::attr::kCallingConv); - CHECK(calling_conv.defined() && - calling_conv->value == static_cast(CallingConv::kDeviceKernelLaunch)) + CHECK(calling_conv == CallingConv::kDeviceKernelLaunch) << "CodeGenSPIRV: expect calling_conv equals CallingConv::kDeviceKernelLaunch"; auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); CHECK(global_symbol.defined()) << "CodeGenSPIRV: Expect PrimFunc to have the global_symbol attribute"; - std::string f_name = global_symbol; + std::string f_name = global_symbol.value(); f = PointerValueTypeRewrite(std::move(f)); VulkanShader shader; shader.data = cg.BuildFunction(f); diff --git a/src/target/spirv/codegen_spirv.cc b/src/target/spirv/codegen_spirv.cc index db2a2f359aa4..bfe21b024426 100644 --- a/src/target/spirv/codegen_spirv.cc +++ b/src/target/spirv/codegen_spirv.cc @@ -82,7 +82,8 @@ std::vector CodeGenSPIRV::BuildFunction(const PrimFunc& f) { CHECK(global_symbol.defined()) << "CodeGenSPIRV: Expect PrimFunc to have the global_symbol attribute"; - builder_->CommitKernelFunction(func_ptr, static_cast(global_symbol)); + builder_->CommitKernelFunction( + func_ptr, static_cast(global_symbol.value())); return builder_->Finalize(); } diff --git a/src/target/stackvm/codegen_stackvm.cc b/src/target/stackvm/codegen_stackvm.cc index da75a70e9123..661fdabd3c32 100644 --- a/src/target/stackvm/codegen_stackvm.cc +++ b/src/target/stackvm/codegen_stackvm.cc @@ -539,7 +539,7 @@ runtime::Module BuildStackVM(const IRModule& mod) { auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); CHECK(global_symbol.defined()) << "CodeGenStackVM: Expect PrimFunc to have the global_symbol attribute"; - std::string f_name = global_symbol; + std::string f_name = global_symbol.value(); StackVM vm = codegen::CodeGenStackVM().Compile(f); CHECK(!fmap.count(f_name)) << "Function name " << f_name << "already exist in list"; diff --git a/src/tir/analysis/verify_memory.cc b/src/tir/analysis/verify_memory.cc index d6a521f98487..9ff4f3d5b738 100644 --- a/src/tir/analysis/verify_memory.cc +++ b/src/tir/analysis/verify_memory.cc @@ -195,7 +195,7 @@ void VerifyMemory(const IRModule& mod) { auto target = func->GetAttr(tvm::attr::kTarget); CHECK(target.defined()) << "LowerWarpMemory: Require the target attribute"; - MemoryAccessVerifier v(func, target->device_type); + MemoryAccessVerifier v(func, target.value()->device_type); v.Run(); if (v.Failed()) { LOG(FATAL) diff --git a/src/tir/transforms/bind_device_type.cc b/src/tir/transforms/bind_device_type.cc index 952d6635f582..a6db9f9c6da8 100644 --- a/src/tir/transforms/bind_device_type.cc +++ b/src/tir/transforms/bind_device_type.cc @@ -99,7 +99,7 @@ Pass BindDeviceType() { auto target = f->GetAttr(tvm::attr::kTarget); CHECK(target.defined()) << "BindDeviceType: Require the target attribute"; - n->body = DeviceTypeBinder(target->device_type)(std::move(n->body)); + n->body = DeviceTypeBinder(target.value()->device_type)(std::move(n->body)); return f; }; return CreatePrimFuncPass(pass_func, 0, "tir.BindDeviceType", {}); diff --git a/src/tir/transforms/lower_custom_datatypes.cc b/src/tir/transforms/lower_custom_datatypes.cc index 6026f8c67567..6cf9e3adce96 100644 --- a/src/tir/transforms/lower_custom_datatypes.cc +++ b/src/tir/transforms/lower_custom_datatypes.cc @@ -141,7 +141,7 @@ Pass LowerCustomDatatypes() { CHECK(target.defined()) << "LowerCustomDatatypes: Require the target attribute"; - n->body = CustomDatatypesLowerer(target->target_name)(std::move(n->body)); + n->body = CustomDatatypesLowerer(target.value()->target_name)(std::move(n->body)); return f; }; return CreatePrimFuncPass(pass_func, 0, "tir.LowerCustomDatatypes", {}); diff --git a/src/tir/transforms/lower_intrin.cc b/src/tir/transforms/lower_intrin.cc index 41a94937d4ce..6ae638f33474 100644 --- a/src/tir/transforms/lower_intrin.cc +++ b/src/tir/transforms/lower_intrin.cc @@ -293,7 +293,7 @@ Pass LowerIntrin() { << "LowerIntrin: Require the target attribute"; arith::Analyzer analyzer; n->body = - IntrinInjecter(&analyzer, target->target_name)(std::move(n->body)); + IntrinInjecter(&analyzer, target.value()->target_name)(std::move(n->body)); return f; }; return CreatePrimFuncPass(pass_func, 0, "tir.LowerIntrin", {}); diff --git a/src/tir/transforms/lower_thread_allreduce.cc b/src/tir/transforms/lower_thread_allreduce.cc index c4df2dcdb868..655a0074c7fd 100644 --- a/src/tir/transforms/lower_thread_allreduce.cc +++ b/src/tir/transforms/lower_thread_allreduce.cc @@ -348,7 +348,7 @@ Pass LowerThreadAllreduce() { auto target = f->GetAttr(tvm::attr::kTarget); CHECK(target.defined()) << "LowerThreadAllreduce: Require the target attribute"; - n->body = ThreadAllreduceBuilder(target->thread_warp_size)(n->body); + n->body = ThreadAllreduceBuilder(target.value()->thread_warp_size)(n->body); return f; }; return CreatePrimFuncPass(pass_func, 0, "tir.LowerThreadAllreduce", {}); diff --git a/src/tir/transforms/lower_warp_memory.cc b/src/tir/transforms/lower_warp_memory.cc index 1921db53cb06..612a8f4d9eef 100644 --- a/src/tir/transforms/lower_warp_memory.cc +++ b/src/tir/transforms/lower_warp_memory.cc @@ -393,7 +393,7 @@ Pass LowerWarpMemory() { auto target = f->GetAttr(tvm::attr::kTarget); CHECK(target.defined()) << "LowerWarpMemory: Require the target attribute"; - n->body = WarpMemoryRewriter(target->thread_warp_size).Rewrite(std::move(n->body)); + n->body = WarpMemoryRewriter(target.value()->thread_warp_size).Rewrite(std::move(n->body)); return f; }; return CreatePrimFuncPass(pass_func, 0, "tir.LowerWarpMemory", {}); diff --git a/src/tir/transforms/make_packed_api.cc b/src/tir/transforms/make_packed_api.cc index b1dd235bce03..dd4bd6642676 100644 --- a/src/tir/transforms/make_packed_api.cc +++ b/src/tir/transforms/make_packed_api.cc @@ -48,9 +48,9 @@ inline Stmt MakeAssertEQ(PrimExpr lhs, PrimExpr rhs, std::string msg) { PrimFunc MakePackedAPI(PrimFunc&& func, int num_unpacked_args) { auto global_symbol = func->GetAttr(tvm::attr::kGlobalSymbol); - CHECK(global_symbol.defined()) + CHECK(global_symbol) << "MakePackedAPI: Expect PrimFunc to have the global_symbol attribute"; - std::string name_hint = global_symbol; + std::string name_hint = global_symbol.value(); auto* func_ptr = func.CopyOnWrite(); const Stmt nop = EvaluateNode::make(0); @@ -240,8 +240,9 @@ Pass MakePackedAPI(int num_unpacked_args) { for (const auto& kv : mptr->functions) { if (auto* n = kv.second.as()) { PrimFunc func = GetRef(n); - if (func->GetAttr(tvm::attr::kCallingConv, 0)->value - == static_cast(CallingConv::kDefault)) { + if (func->GetAttr( + tvm::attr::kCallingConv, + Integer(CallingConv::kDefault)) == CallingConv::kDefault) { auto updated_func = MakePackedAPI(std::move(func), num_unpacked_args); updates.push_back({kv.first, updated_func}); } diff --git a/src/tir/transforms/remap_thread_axis.cc b/src/tir/transforms/remap_thread_axis.cc index f3663532e56b..fdcfc4d4702e 100644 --- a/src/tir/transforms/remap_thread_axis.cc +++ b/src/tir/transforms/remap_thread_axis.cc @@ -82,7 +82,10 @@ PrimFunc RemapThreadAxis(PrimFunc&& f, Map thread_map) tmap[kv.first] = kv.second; } - auto thread_axis = f->GetAttr >(tir::attr::kDeviceThreadAxis); + auto opt_thread_axis = f->GetAttr>(tir::attr::kDeviceThreadAxis); + CHECK(opt_thread_axis != nullptr) + << "Require attribute " << tir::attr::kDeviceThreadAxis; + auto thread_axis = opt_thread_axis.value(); auto* n = f.CopyOnWrite(); // replace the thread axis diff --git a/src/tir/transforms/split_host_device.cc b/src/tir/transforms/split_host_device.cc index 792a06157c09..927536b5938e 100644 --- a/src/tir/transforms/split_host_device.cc +++ b/src/tir/transforms/split_host_device.cc @@ -277,7 +277,9 @@ PrimFunc SplitHostDevice(PrimFunc&& func, IRModule* device_mod) { << "SplitHostDevice: Expect PrimFunc to have the global_symbol attribute"; HostDeviceSplitter splitter( - device_mod, target, static_cast(global_symbol)); + device_mod, + target.value(), + static_cast(global_symbol.value())); auto* n = func.CopyOnWrite(); n->body = splitter(std::move(n->body)); diff --git a/tests/cpp/container_test.cc b/tests/cpp/container_test.cc index 063247db09b6..c67df63e6e7e 100644 --- a/tests/cpp/container_test.cc +++ b/tests/cpp/container_test.cc @@ -21,6 +21,7 @@ #include #include #include +#include #include #include @@ -401,6 +402,74 @@ TEST(String, Cast) { String s2 = Downcast(r); } + +TEST(Optional, Composition) { + Optional opt0(nullptr); + Optional opt1 = String("xyz"); + Optional opt2 = String("xyz1"); + // operator bool + CHECK(!opt0); + CHECK(opt1); + // comparison op + CHECK(opt0 != "xyz"); + CHECK(opt1 == "xyz"); + CHECK(opt1 != nullptr); + CHECK(opt0 == nullptr); + CHECK(opt0.value_or("abc") == "abc"); + CHECK(opt1.value_or("abc") == "xyz"); + CHECK(opt0 != opt1); + CHECK(opt1 == Optional(String("xyz"))); + CHECK(opt0 == Optional(nullptr)); + opt0 = opt1; + CHECK(opt0 == opt1); + CHECK(opt0.value().same_as(opt1.value())); + opt0 = std::move(opt2); + CHECK(opt0 != opt2); +} + +TEST(Optional, IntCmp) { + Integer val(CallingConv::kDefault); + Optional opt = Integer(0); + CHECK(0 == static_cast(CallingConv::kDefault)); + CHECK(val == CallingConv::kDefault); + CHECK(opt == CallingConv::kDefault); + + // check we can handle implicit 0 to nullptr conversion. + Optional opt1(nullptr); + CHECK(opt1 != 0); + CHECK(opt1 != false); + CHECK(!(opt1 == 0)); +} + +TEST(Optional, PackedCall) { + auto tf = [](Optional s, bool isnull) { + if (isnull) { + CHECK(s == nullptr); + } else { + CHECK(s != nullptr); + } + return s; + }; + auto func = TypedPackedFunc(Optional, bool)>(tf); + CHECK(func(String("xyz"), false) == "xyz"); + CHECK(func(Optional(nullptr), true) == nullptr); + + auto pf = [](TVMArgs args, TVMRetValue* rv) { + Optional s = args[0]; + bool isnull = args[1]; + if (isnull) { + CHECK(s == nullptr); + } else { + CHECK(s != nullptr); + } + *rv = s; + }; + auto packedfunc = PackedFunc(pf); + CHECK(packedfunc("xyz", false).operator String() == "xyz"); + CHECK(packedfunc("xyz", false).operator Optional() == "xyz"); + CHECK(packedfunc(nullptr, true).operator Optional() == nullptr); +} + int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); testing::FLAGS_gtest_death_test_style = "threadsafe"; diff --git a/tests/python/relay/test_external_codegen.py b/tests/python/relay/test_external_codegen.py index b4496bb044ba..3797910080a1 100644 --- a/tests/python/relay/test_external_codegen.py +++ b/tests/python/relay/test_external_codegen.py @@ -39,7 +39,7 @@ def update_lib(lib): contrib_path = os.path.join(source_dir, "src", "runtime", "contrib") kwargs = {} - kwargs["options"] = ["-O2", "-std=c++11", "-I" + contrib_path] + kwargs["options"] = ["-O2", "-std=c++14", "-I" + contrib_path] tmp_path = util.tempdir() lib_name = 'lib.so' lib_path = tmp_path.relpath(lib_name) diff --git a/tests/python/relay/test_external_runtime.py b/tests/python/relay/test_external_runtime.py index 397c35db53d9..39209232f3d0 100644 --- a/tests/python/relay/test_external_runtime.py +++ b/tests/python/relay/test_external_runtime.py @@ -468,13 +468,13 @@ def run_extern(label, get_extern_src, **kwargs): def test_dso_extern(): - run_extern("lib", generate_csource_module, options=["-O2", "-std=c++11"]) + run_extern("lib", generate_csource_module, options=["-O2", "-std=c++14"]) def test_engine_extern(): run_extern("engine", generate_engine_module, - options=["-O2", "-std=c++11", "-I" + tmp_path.relpath("")]) + options=["-O2", "-std=c++14", "-I" + tmp_path.relpath("")]) def test_json_extern(): if not tvm.get_global_func("module.loadfile_examplejson", True): diff --git a/tests/python/relay/test_pass_annotate_target.py b/tests/python/relay/test_pass_annotate_target.py index 705a2614674a..01ba9b619205 100644 --- a/tests/python/relay/test_pass_annotate_target.py +++ b/tests/python/relay/test_pass_annotate_target.py @@ -42,7 +42,7 @@ def update_lib(lib): contrib_path = os.path.join(source_dir, "src", "runtime", "contrib") kwargs = {} - kwargs["options"] = ["-O2", "-std=c++11", "-I" + contrib_path] + kwargs["options"] = ["-O2", "-std=c++14", "-I" + contrib_path] tmp_path = util.tempdir() lib_name = 'lib.so' lib_path = tmp_path.relpath(lib_name) diff --git a/tests/python/relay/test_pass_partition_graph.py b/tests/python/relay/test_pass_partition_graph.py index c7d9626931d0..1d0cc5b79a44 100644 --- a/tests/python/relay/test_pass_partition_graph.py +++ b/tests/python/relay/test_pass_partition_graph.py @@ -182,7 +182,7 @@ def update_lib(lib): contrib_path = os.path.join(source_dir, "src", "runtime", "contrib") kwargs = {} - kwargs["options"] = ["-O2", "-std=c++11", "-I" + contrib_path] + kwargs["options"] = ["-O2", "-std=c++14", "-I" + contrib_path] tmp_path = util.tempdir() lib_name = 'lib.so' lib_path = tmp_path.relpath(lib_name) diff --git a/tests/python/unittest/test_runtime_module_export.py b/tests/python/unittest/test_runtime_module_export.py index 35bafb4ba3c7..fce7d2f350dc 100644 --- a/tests/python/unittest/test_runtime_module_export.py +++ b/tests/python/unittest/test_runtime_module_export.py @@ -191,7 +191,7 @@ def verify_multi_c_mod_export(): path_lib = temp.relpath(file_name) resnet18_cpu_lib.import_module(f) resnet18_cpu_lib.import_module(engine_module) - kwargs = {"options": ["-O2", "-std=c++11", "-I" + header_file_dir_path.relpath("")]} + kwargs = {"options": ["-O2", "-std=c++14", "-I" + header_file_dir_path.relpath("")]} resnet18_cpu_lib.export_library(path_lib, fcompile=False, **kwargs) loaded_lib = tvm.runtime.load_module(path_lib) assert loaded_lib.type_key == "library" diff --git a/vta/python/vta/exec/rpc_server.py b/vta/python/vta/exec/rpc_server.py index 65c1274214e6..9cfd50927041 100644 --- a/vta/python/vta/exec/rpc_server.py +++ b/vta/python/vta/exec/rpc_server.py @@ -107,7 +107,7 @@ def reconfig_runtime(cfg_json): if pkg.same_config(old_cfg): logging.info("Skip reconfig_runtime due to same config.") return - cflags = ["-O2", "-std=c++11"] + cflags = ["-O2", "-std=c++14"] cflags += pkg.cflags ldflags = pkg.ldflags lib_name = dll_path