Skip to content

Commit

Permalink
[RUNTIME][IR] Allow non-nullable ObjectRef, introduce Optional<T>. (#…
Browse files Browse the repository at this point in the history
…5314)

* [RUNTIME] Allow non-nullable ObjectRef, introduce Optional<T>.

We use ObjectRef and their sub-classes extensively throughout our codebase.
Each of ObjectRef's sub-classes are nullable, which means they can hold nullptr
as their values.

While in some places we need nullptr as an alternative value. The implicit support
for nullptr in all ObjectRef creates additional burdens for the developer
to explicitly check defined in many places of the codebase.

Moreover, it is unclear from the API's intentional point of view whether
we want a nullable object or not-null version(many cases we want the later).

Borrowing existing wisdoms from languages like Rust. We propose to
introduce non-nullable ObjectRef, and Optional<T> container that
represents a nullable variant.

To keep backward compatiblity, we will start by allowing most ObjectRef to be nullable.
However, we should start to use Optional<T> as the type in places where
we know nullable is a requirement. Gradually, we will move most of the ObjectRef
to be non-nullable and use Optional<T> in the nullable cases.

Such explicitness in typing can help reduce the potential problems
in our codebase overall.

Changes in this PR:
- Introduce _type_is_nullable attribute to ObjectRef
- Introduce Optional<T>
- Change String to be non-nullable.
- Change the API of function->GetAttr to return Optional<T>

* Address review comments

* Upgrade all compiler flags to c++14

* Update as per review comment
  • Loading branch information
tqchen authored Apr 13, 2020
1 parent 3df8d56 commit fc75de9
Show file tree
Hide file tree
Showing 56 changed files with 432 additions and 105 deletions.
2 changes: 1 addition & 1 deletion apps/android_camera/app/src/main/jni/Application.mk
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ include $(config)
APP_ABI ?= all
APP_STL := c++_shared

APP_CPPFLAGS += -DDMLC_LOG_STACK_TRACE=0 -DTVM4J_ANDROID=1 -std=c++11 -Oz -frtti
APP_CPPFLAGS += -DDMLC_LOG_STACK_TRACE=0 -DTVM4J_ANDROID=1 -std=c++14 -Oz -frtti
ifeq ($(USE_OPENCL), 1)
APP_CPPFLAGS += -DTVM_OPENCL_RUNTIME=1
endif
Expand Down
4 changes: 2 additions & 2 deletions apps/android_deploy/app/src/main/jni/Application.mk
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ include $(config)

APP_STL := c++_static

APP_CPPFLAGS += -DDMLC_LOG_STACK_TRACE=0 -DTVM4J_ANDROID=1 -std=c++11 -Oz -frtti
ifeq ($(USE_OPENCL), 1)
APP_CPPFLAGS += -DDMLC_LOG_STACK_TRACE=0 -DTVM4J_ANDROID=1 -std=c++14 -Oz -frtti
ifeq ($(USE_OPENCL), 1)
APP_CPPFLAGS += -DTVM_OPENCL_RUNTIME=1
endif
2 changes: 1 addition & 1 deletion apps/android_rpc/app/src/main/jni/Application.mk
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ include $(config)
APP_ABI ?= armeabi-v7a arm64-v8a x86 x86_64 mips
APP_STL := c++_shared

APP_CPPFLAGS += -DDMLC_LOG_STACK_TRACE=0 -DTVM4J_ANDROID=1 -std=c++11 -Oz -frtti
APP_CPPFLAGS += -DDMLC_LOG_STACK_TRACE=0 -DTVM4J_ANDROID=1 -std=c++14 -Oz -frtti
ifeq ($(USE_OPENCL), 1)
APP_CPPFLAGS += -DTVM_OPENCL_RUNTIME=1
endif
Expand Down
2 changes: 1 addition & 1 deletion apps/cpp_rpc/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ else
LINK_PTHREAD=
endif

PKG_CFLAGS = -std=c++11 -O2 -fPIC -Wall\
PKG_CFLAGS = -std=c++14 -O2 -fPIC -Wall\
-I${TVM_ROOT}/include\
-I${DMLC_CORE}/include\
-I${TVM_ROOT}/3rdparty/dlpack/include
Expand Down
2 changes: 1 addition & 1 deletion apps/dso_plugin_module/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# under the License.

TVM_ROOT=$(shell cd ../..; pwd)
PKG_CFLAGS = -std=c++11 -O2 -fPIC\
PKG_CFLAGS = -std=c++14 -O2 -fPIC\
-I${TVM_ROOT}/include\
-I${TVM_ROOT}/3rdparty/dmlc-core/include\
-I${TVM_ROOT}/3rdparty/dlpack/include
Expand Down
2 changes: 1 addition & 1 deletion apps/extension/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

# Minimum Makefile for the extension package
TVM_ROOT=$(shell cd ../..; pwd)
PKG_CFLAGS = -std=c++11 -O2 -fPIC\
PKG_CFLAGS = -std=c++14 -O2 -fPIC\
-I${TVM_ROOT}/include\
-I${TVM_ROOT}/3rdparty/dmlc-core/include\
-I${TVM_ROOT}/3rdparty/dlpack/include
Expand Down
2 changes: 1 addition & 1 deletion apps/howto_deploy/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
TVM_ROOT=$(shell cd ../..; pwd)
DMLC_CORE=${TVM_ROOT}/3rdparty/dmlc-core

PKG_CFLAGS = -std=c++11 -O2 -fPIC\
PKG_CFLAGS = -std=c++14 -O2 -fPIC\
-I${TVM_ROOT}/include\
-I${DMLC_CORE}/include\
-I${TVM_ROOT}/3rdparty/dlpack/include\
Expand Down
2 changes: 1 addition & 1 deletion apps/howto_deploy/tvm_runtime_pack.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
* include in your project.
*
* - Copy this file into your project which depends on tvm runtime.
* - Compile with -std=c++11
* - Compile with -std=c++14
* - Add the following include path
* - /path/to/tvm/include/
* - /path/to/tvm/3rdparty/dmlc-core/include/
Expand Down
2 changes: 1 addition & 1 deletion apps/rocm_rpc/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ ROCM_PATH=/opt/rocm
TVM_ROOT=$(shell cd ../..; pwd)
DMLC_CORE=${TVM_ROOT}/3rdparty/dmlc-core

PKG_CFLAGS = -std=c++11 -O2 -fPIC\
PKG_CFLAGS = -std=c++14 -O2 -fPIC\
-I${TVM_ROOT}/include\
-I${DMLC_CORE}/include\
-I${TVM_ROOT}/3rdparty/dlpack/include\
Expand Down
2 changes: 1 addition & 1 deletion apps/tf_tvmdsoop/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
cmake_minimum_required(VERSION 3.2)
project(tf_tvmdsoop C CXX)

set(TFTVM_COMPILE_FLAGS -std=c++11)
set(TFTVM_COMPILE_FLAGS -std=c++14)
set(BUILD_TVMDSOOP_ONLY ON)
set(CMAKE_CURRENT_SOURCE_DIR ${TVM_ROOT})
set(CMAKE_CURRENT_BINARY_DIR ${TVM_ROOT}/build)
Expand Down
2 changes: 1 addition & 1 deletion golang/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ NATIVE_SRC = tvm_runtime_pack.cc
GOPATH=$(CURDIR)/gopath
GOPATHDIR=${GOPATH}/src/${TARGET}/
CGO_CPPFLAGS="-I. -I${TVM_BASE}/ -I${TVM_BASE}/3rdparty/dmlc-core/include -I${TVM_BASE}/include -I${TVM_BASE}/3rdparty/dlpack/include/"
CGO_CXXFLAGS="-std=c++11"
CGO_CXXFLAGS="-std=c++14"
CGO_CFLAGS="-I${TVM_BASE}"
CGO_LDFLAGS="-ldl -lm"

Expand Down
2 changes: 2 additions & 0 deletions include/tvm/ir/attrs.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,8 @@ namespace tvm {
*/
template<typename TObjectRef>
inline TObjectRef NullValue() {
static_assert(TObjectRef::_type_is_nullable,
"Can only get NullValue for nullable types");
return TObjectRef(ObjectPtr<Object>(nullptr));
}

Expand Down
67 changes: 63 additions & 4 deletions include/tvm/ir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,47 @@ class FloatImm : public PrimExpr {
TVM_DEFINE_OBJECT_REF_METHODS(FloatImm, PrimExpr, FloatImmNode);
};

/*!
* \brief Boolean constant.
*
* This reference type is useful to add additional compile-time
* type checks and helper functions for Integer equal comparisons.
*/
class Bool : public IntImm {
public:
explicit Bool(bool value)
: IntImm(DataType::Bool(), value) {
}
Bool operator!() const {
return Bool((*this)->value == 0);
}
operator bool() const {
return (*this)->value != 0;
}

TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Bool, IntImm, IntImmNode);
};

// Overload operators to make sure we have the most fine grained types.
inline Bool operator||(const Bool& a, bool b) {
return Bool(a.operator bool() || b);
}
inline Bool operator||(bool a, const Bool& b) {
return Bool(a || b.operator bool());
}
inline Bool operator||(const Bool& a, const Bool& b) {
return Bool(a.operator bool() || b.operator bool());
}
inline Bool operator&&(const Bool& a, bool b) {
return Bool(a.operator bool() && b);
}
inline Bool operator&&(bool a, const Bool& b) {
return Bool(a && b.operator bool());
}
inline Bool operator&&(const Bool& a, const Bool& b) {
return Bool(a.operator bool() && b.operator bool());
}

/*!
* \brief Container of constant int that adds more constructors.
*
Expand Down Expand Up @@ -340,10 +381,10 @@ class Integer : public IntImm {
* \tparam Enum The enum type.
* \param value The enum value.
*/
template<typename ENum,
typename = typename std::enable_if<std::is_enum<ENum>::value>::type>
explicit Integer(ENum value) : Integer(static_cast<int>(value)) {
static_assert(std::is_same<int, typename std::underlying_type<ENum>::type>::value,
template<typename Enum,
typename = typename std::enable_if<std::is_enum<Enum>::value>::type>
explicit Integer(Enum value) : Integer(static_cast<int>(value)) {
static_assert(std::is_same<int, typename std::underlying_type<Enum>::type>::value,
"declare enum to be enum int to use visitor");
}
/*!
Expand All @@ -362,6 +403,24 @@ class Integer : public IntImm {
<< " Trying to reference a null Integer";
return (*this)->value;
}
// comparators
Bool operator==(int other) const {
if (data_ == nullptr) return Bool(false);
return Bool((*this)->value == other);
}
Bool operator!=(int other) const {
return !(*this == other);
}
template<typename Enum,
typename = typename std::enable_if<std::is_enum<Enum>::value>::type>
Bool operator==(Enum other) const {
return *this == static_cast<int>(other);
}
template<typename Enum,
typename = typename std::enable_if<std::is_enum<Enum>::value>::type>
Bool operator!=(Enum other) const {
return *this != static_cast<int>(other);
}
};

/*! \brief range over one dimension */
Expand Down
19 changes: 13 additions & 6 deletions include/tvm/ir/function.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

#include <tvm/ir/expr.h>
#include <tvm/ir/attrs.h>
#include <tvm/runtime/container.h>
#include <type_traits>
#include <string>

Expand Down Expand Up @@ -90,25 +91,31 @@ class BaseFuncNode : public RelayExprNode {
* \code
*
* void GetAttrExample(const BaseFunc& f) {
* Integer value = f->GetAttr<Integer>("AttrKey", 0);
* auto value = f->GetAttr<Integer>("AttrKey", 0);
* }
*
* \endcode
*/
template<typename TObjectRef>
TObjectRef GetAttr(const std::string& attr_key,
TObjectRef default_value = NullValue<TObjectRef>()) const {
Optional<TObjectRef> GetAttr(
const std::string& attr_key,
Optional<TObjectRef> default_value = Optional<TObjectRef>(nullptr)) const {
static_assert(std::is_base_of<ObjectRef, TObjectRef>::value,
"Can only call GetAttr with ObjectRef types.");
if (!attrs.defined()) return default_value;
auto it = attrs->dict.find(attr_key);
if (it != attrs->dict.end()) {
return Downcast<TObjectRef>((*it).second);
return Downcast<Optional<TObjectRef>>((*it).second);
} else {
return default_value;
}
}

// variant that uses TObjectRef to enable implicit conversion to default value.
template<typename TObjectRef>
Optional<TObjectRef> GetAttr(
const std::string& attr_key, TObjectRef default_value) const {
return GetAttr<TObjectRef>(attr_key, Optional<TObjectRef>(default_value));
}
/*!
* \brief Check whether the function has an non-zero integer attr.
*
Expand All @@ -129,7 +136,7 @@ class BaseFuncNode : public RelayExprNode {
* \endcode
*/
bool HasNonzeroAttr(const std::string& attr_key) const {
return GetAttr<Integer>(attr_key, 0)->value != 0;
return GetAttr<Integer>(attr_key, 0) != 0;
}

static constexpr const char* _type_key = "BaseFunc";
Expand Down
1 change: 0 additions & 1 deletion include/tvm/node/node.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@ using runtime::make_object;
using runtime::PackedFunc;
using runtime::TVMArgs;
using runtime::TVMRetValue;
using runtime::String;

} // namespace tvm
#endif // TVM_NODE_NODE_H_
Loading

0 comments on commit fc75de9

Please sign in to comment.