From 8547ee9f7b52dad3113c43c91ae4ad97a71a1649 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Wed, 5 Aug 2020 18:57:25 -0700 Subject: [PATCH] [RUNTIME] Enable auto conversion String->DLDataType (#6214) --- include/tvm/runtime/container.h | 39 +--------- include/tvm/runtime/packed_func.h | 73 +++++++++++++++---- tests/python/unittest/test_node_reflection.py | 2 + 3 files changed, 63 insertions(+), 51 deletions(-) diff --git a/include/tvm/runtime/container.h b/include/tvm/runtime/container.h index f8fa09dd2108e..423ea896aae8e 100644 --- a/include/tvm/runtime/container.h +++ b/include/tvm/runtime/container.h @@ -27,7 +27,6 @@ #include #include #include -#include #include #include @@ -74,6 +73,9 @@ class StringRef; namespace tvm { namespace runtime { +// Forward declare TVMArgValue +class TVMArgValue; + /*! \brief String-aware ObjectRef equal functor */ struct ObjectHash { /*! @@ -1289,9 +1291,7 @@ class String : public ObjectRef { * \param val The value to be checked * \return A boolean indicating if val can be converted to String */ - static bool CanConvertFrom(const TVMArgValue& val) { - return val.type_code() == kTVMStr || val.IsObjectRef(); - } + inline static bool CanConvertFrom(const TVMArgValue& val); /*! * \brief Hash the binary bytes @@ -1523,25 +1523,6 @@ inline bool ObjectEqual::operator()(const ObjectRef& a, const ObjectRef& b) cons return false; } -template <> -struct PackedFuncValueConverter<::tvm::runtime::String> { - static String From(const TVMArgValue& val) { - if (val.IsObjectRef()) { - return val.AsObjectRef(); - } else { - return tvm::runtime::String(val.operator std::string()); - } - } - - static String From(const TVMRetValue& val) { - if (val.IsObjectRef()) { - return val.AsObjectRef(); - } else { - return tvm::runtime::String(val.operator std::string()); - } - } -}; - /*! \brief Helper to represent nullptr for optional. */ struct NullOptType {}; @@ -1659,18 +1640,6 @@ class Optional : public ObjectRef { static constexpr bool _type_is_nullable = true; }; -template -struct PackedFuncValueConverter> { - static Optional From(const TVMArgValue& val) { - if (val.type_code() == kTVMNullptr) return Optional(nullptr); - return PackedFuncValueConverter::From(val); - } - static Optional From(const TVMRetValue& val) { - if (val.type_code() == kTVMNullptr) return Optional(nullptr); - return PackedFuncValueConverter::From(val); - } -}; - /*! * \brief An object representing a closure. This object is used by both the * Relay VM and interpreter. diff --git a/include/tvm/runtime/packed_func.h b/include/tvm/runtime/packed_func.h index 32312174c1ea1..d2450c4420484 100644 --- a/include/tvm/runtime/packed_func.h +++ b/include/tvm/runtime/packed_func.h @@ -26,6 +26,7 @@ #include #include +#include #include #include #include @@ -492,22 +493,6 @@ class TVMArgValue : public TVMPODValue_ { return std::string(value_.v_str); } } - operator DLDataType() const { - if (type_code_ == kTVMStr) { - return String2DLDataType(operator std::string()); - } - // None type - if (type_code_ == kTVMNullptr) { - DLDataType t; - t.code = kTVMOpaqueHandle; - t.bits = 0; - t.lanes = 0; - return t; - } - TVM_CHECK_TYPE_CODE(type_code_, kTVMDataType); - return value_.v_type; - } - operator DataType() const { return DataType(operator DLDataType()); } operator PackedFunc() const { if (type_code_ == kTVMNullptr) return PackedFunc(); TVM_CHECK_TYPE_CODE(type_code_, kTVMPackedFuncHandle); @@ -521,6 +506,8 @@ class TVMArgValue : public TVMPODValue_ { template ::value>::type> inline operator T() const; + inline operator DLDataType() const; + inline operator DataType() const; }; /*! @@ -1473,6 +1460,60 @@ inline PackedFunc Module::GetFunction(const std::string& name, bool query_import return (*this)->GetFunction(name, query_imports); } +// specializations of PackedFuncValueConverter +template <> +struct PackedFuncValueConverter<::tvm::runtime::String> { + static String From(const TVMArgValue& val) { + if (val.IsObjectRef()) { + return val.AsObjectRef(); + } else { + return tvm::runtime::String(val.operator std::string()); + } + } + + static String From(const TVMRetValue& val) { + if (val.IsObjectRef()) { + return val.AsObjectRef(); + } else { + return tvm::runtime::String(val.operator std::string()); + } + } +}; + +template +struct PackedFuncValueConverter> { + static Optional From(const TVMArgValue& val) { + if (val.type_code() == kTVMNullptr) return Optional(nullptr); + return PackedFuncValueConverter::From(val); + } + static Optional From(const TVMRetValue& val) { + if (val.type_code() == kTVMNullptr) return Optional(nullptr); + return PackedFuncValueConverter::From(val); + } +}; + +inline bool String::CanConvertFrom(const TVMArgValue& val) { + return val.type_code() == kTVMStr || val.IsObjectRef(); +} + +inline TVMArgValue::operator DLDataType() const { + if (String::CanConvertFrom(*this)) { + return String2DLDataType(PackedFuncValueConverter::From(*this).operator std::string()); + } + // None type + if (type_code_ == kTVMNullptr) { + DLDataType t; + t.code = kTVMOpaqueHandle; + t.bits = 0; + t.lanes = 0; + return t; + } + TVM_CHECK_TYPE_CODE(type_code_, kTVMDataType); + return value_.v_type; +} + +inline TVMArgValue::operator DataType() const { return DataType(operator DLDataType()); } + } // namespace runtime } // namespace tvm #endif // TVM_RUNTIME_PACKED_FUNC_H_ diff --git a/tests/python/unittest/test_node_reflection.py b/tests/python/unittest/test_node_reflection.py index d375fa0f75c68..edf8b426f7373 100644 --- a/tests/python/unittest/test_node_reflection.py +++ b/tests/python/unittest/test_node_reflection.py @@ -65,6 +65,8 @@ def test_make_node(): assert AA.op == A.op assert AA.value_index == A.value_index + y = tvm.ir.make_node("IntImm", dtype=tvm.runtime.String("int32"), value=10) + def test_make_sum(): A = te.placeholder((2, 10), name='A')