Skip to content

Commit

Permalink
[RUNTIME] Enable auto conversion String->DLDataType (apache#6214)
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen authored and Trevor Morris committed Aug 26, 2020
1 parent f15bc50 commit a7c7a9f
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 51 deletions.
39 changes: 4 additions & 35 deletions include/tvm/runtime/container.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
#include <dmlc/logging.h>
#include <tvm/runtime/memory.h>
#include <tvm/runtime/object.h>
#include <tvm/runtime/packed_func.h>

#include <algorithm>
#include <cstring>
Expand Down Expand Up @@ -74,6 +73,9 @@ class StringRef;
namespace tvm {
namespace runtime {

// Forward declare TVMArgValue
class TVMArgValue;

/*! \brief String-aware ObjectRef equal functor */
struct ObjectHash {
/*!
Expand Down Expand Up @@ -1289,9 +1291,7 @@ class String : public ObjectRef {
* \param val The value to be checked
* \return A boolean indicating if val can be converted to String
*/
static bool CanConvertFrom(const TVMArgValue& val) {
return val.type_code() == kTVMStr || val.IsObjectRef<tvm::runtime::String>();
}
inline static bool CanConvertFrom(const TVMArgValue& val);

/*!
* \brief Hash the binary bytes
Expand Down Expand Up @@ -1523,25 +1523,6 @@ inline bool ObjectEqual::operator()(const ObjectRef& a, const ObjectRef& b) cons
return false;
}

template <>
struct PackedFuncValueConverter<::tvm::runtime::String> {
static String From(const TVMArgValue& val) {
if (val.IsObjectRef<tvm::runtime::String>()) {
return val.AsObjectRef<tvm::runtime::String>();
} else {
return tvm::runtime::String(val.operator std::string());
}
}

static String From(const TVMRetValue& val) {
if (val.IsObjectRef<tvm::runtime::String>()) {
return val.AsObjectRef<tvm::runtime::String>();
} else {
return tvm::runtime::String(val.operator std::string());
}
}
};

/*! \brief Helper to represent nullptr for optional. */
struct NullOptType {};

Expand Down Expand Up @@ -1659,18 +1640,6 @@ class Optional : public ObjectRef {
static constexpr bool _type_is_nullable = true;
};

template <typename T>
struct PackedFuncValueConverter<Optional<T>> {
static Optional<T> From(const TVMArgValue& val) {
if (val.type_code() == kTVMNullptr) return Optional<T>(nullptr);
return PackedFuncValueConverter<T>::From(val);
}
static Optional<T> From(const TVMRetValue& val) {
if (val.type_code() == kTVMNullptr) return Optional<T>(nullptr);
return PackedFuncValueConverter<T>::From(val);
}
};

/*!
* \brief An object representing a closure. This object is used by both the
* Relay VM and interpreter.
Expand Down
73 changes: 57 additions & 16 deletions include/tvm/runtime/packed_func.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

#include <dmlc/logging.h>
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/container.h>
#include <tvm/runtime/data_type.h>
#include <tvm/runtime/module.h>
#include <tvm/runtime/ndarray.h>
Expand Down Expand Up @@ -492,22 +493,6 @@ class TVMArgValue : public TVMPODValue_ {
return std::string(value_.v_str);
}
}
operator DLDataType() const {
if (type_code_ == kTVMStr) {
return String2DLDataType(operator std::string());
}
// None type
if (type_code_ == kTVMNullptr) {
DLDataType t;
t.code = kTVMOpaqueHandle;
t.bits = 0;
t.lanes = 0;
return t;
}
TVM_CHECK_TYPE_CODE(type_code_, kTVMDataType);
return value_.v_type;
}
operator DataType() const { return DataType(operator DLDataType()); }
operator PackedFunc() const {
if (type_code_ == kTVMNullptr) return PackedFunc();
TVM_CHECK_TYPE_CODE(type_code_, kTVMPackedFuncHandle);
Expand All @@ -521,6 +506,8 @@ class TVMArgValue : public TVMPODValue_ {

template <typename T, typename = typename std::enable_if<std::is_class<T>::value>::type>
inline operator T() const;
inline operator DLDataType() const;
inline operator DataType() const;
};

/*!
Expand Down Expand Up @@ -1473,6 +1460,60 @@ inline PackedFunc Module::GetFunction(const std::string& name, bool query_import
return (*this)->GetFunction(name, query_imports);
}

// specializations of PackedFuncValueConverter
template <>
struct PackedFuncValueConverter<::tvm::runtime::String> {
static String From(const TVMArgValue& val) {
if (val.IsObjectRef<tvm::runtime::String>()) {
return val.AsObjectRef<tvm::runtime::String>();
} else {
return tvm::runtime::String(val.operator std::string());
}
}

static String From(const TVMRetValue& val) {
if (val.IsObjectRef<tvm::runtime::String>()) {
return val.AsObjectRef<tvm::runtime::String>();
} else {
return tvm::runtime::String(val.operator std::string());
}
}
};

template <typename T>
struct PackedFuncValueConverter<Optional<T>> {
static Optional<T> From(const TVMArgValue& val) {
if (val.type_code() == kTVMNullptr) return Optional<T>(nullptr);
return PackedFuncValueConverter<T>::From(val);
}
static Optional<T> From(const TVMRetValue& val) {
if (val.type_code() == kTVMNullptr) return Optional<T>(nullptr);
return PackedFuncValueConverter<T>::From(val);
}
};

inline bool String::CanConvertFrom(const TVMArgValue& val) {
return val.type_code() == kTVMStr || val.IsObjectRef<tvm::runtime::String>();
}

inline TVMArgValue::operator DLDataType() const {
if (String::CanConvertFrom(*this)) {
return String2DLDataType(PackedFuncValueConverter<String>::From(*this).operator std::string());
}
// None type
if (type_code_ == kTVMNullptr) {
DLDataType t;
t.code = kTVMOpaqueHandle;
t.bits = 0;
t.lanes = 0;
return t;
}
TVM_CHECK_TYPE_CODE(type_code_, kTVMDataType);
return value_.v_type;
}

inline TVMArgValue::operator DataType() const { return DataType(operator DLDataType()); }

} // namespace runtime
} // namespace tvm
#endif // TVM_RUNTIME_PACKED_FUNC_H_
2 changes: 2 additions & 0 deletions tests/python/unittest/test_node_reflection.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ def test_make_node():
assert AA.op == A.op
assert AA.value_index == A.value_index

y = tvm.ir.make_node("IntImm", dtype=tvm.runtime.String("int32"), value=10)


def test_make_sum():
A = te.placeholder((2, 10), name='A')
Expand Down

0 comments on commit a7c7a9f

Please sign in to comment.