Skip to content

Commit

Permalink
[RUNTIME] Enable auto conversion from str to runtime::String in Packe…
Browse files Browse the repository at this point in the history
…dFunc, move dtype related handling to data_type.h
  • Loading branch information
tqchen committed Apr 6, 2020
1 parent f31df01 commit c899fbb
Show file tree
Hide file tree
Showing 3 changed files with 147 additions and 133 deletions.
136 changes: 136 additions & 0 deletions include/tvm/runtime/data_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include <tvm/runtime/c_runtime_api.h>
#include <dmlc/logging.h>
#include <type_traits>
#include <string>

namespace tvm {
namespace runtime {
Expand Down Expand Up @@ -263,6 +264,141 @@ inline bool TypeMatch(DLDataType t, int code, int bits, int lanes = 1) {
inline bool TypeEqual(DLDataType lhs, DLDataType rhs) {
return lhs.code == rhs.code && lhs.bits == rhs.bits && lhs.lanes == rhs.lanes;
}

/*!
* \brief Runtime utility for getting custom type name from code
* \param type_code Custom type code
* \return Custom type name
*/
TVM_DLL std::string GetCustomTypeName(uint8_t type_code);

/*!
* \brief Runtime utility for checking whether custom type is registered
* \param type_code Custom type code
* \return Bool representing whether type is registered
*/
TVM_DLL bool GetCustomTypeRegistered(uint8_t type_code);

/*!
* \brief Runtime utility for parsing string of the form "custom[<typename>]"
* \param s String to parse
* \param scan pointer to parsing pointer, which is scanning across s
* \return type code of custom type parsed
*/
TVM_DLL uint8_t ParseCustomDatatype(const std::string& s, const char** scan);

/*!
* \brief Convert type code to its name
* \param type_code The type code .
* \return The name of type code.
*/
inline const char* TypeCode2Str(int type_code);

/*!
* \brief convert a string to TVM type.
* \param s The string to be converted.
* \return The corresponding tvm type.
*/
inline DLDataType String2DLDataType(std::string s);

/*!
* \brief convert a TVM type to string.
* \param t The type to be converted.
* \return The corresponding tvm type in string.
*/
inline std::string DLDataType2String(DLDataType t);

// implementation details
inline const char* TypeCode2Str(int type_code) {
switch (type_code) {
case kDLInt: return "int";
case kDLUInt: return "uint";
case kDLFloat: return "float";
case kTVMStr: return "str";
case kTVMBytes: return "bytes";
case kTVMOpaqueHandle: return "handle";
case kTVMNullptr: return "NULL";
case kTVMDLTensorHandle: return "ArrayHandle";
case kTVMDataType: return "DLDataType";
case kTVMContext: return "TVMContext";
case kTVMPackedFuncHandle: return "FunctionHandle";
case kTVMModuleHandle: return "ModuleHandle";
case kTVMNDArrayHandle: return "NDArrayContainer";
case kTVMObjectHandle: return "Object";
default: LOG(FATAL) << "unknown type_code="
<< static_cast<int>(type_code); return "";
}
}

inline std::ostream& operator<<(std::ostream& os, DLDataType t) { // NOLINT(*)
if (t.bits == 1 && t.lanes == 1 && t.code == kDLUInt) {
os << "bool"; return os;
}
if (t.code < kTVMCustomBegin) {
os << TypeCode2Str(t.code);
} else {
os << "custom[" << GetCustomTypeName(t.code) << "]";
}
if (t.code == kTVMOpaqueHandle) return os;
os << static_cast<int>(t.bits);
if (t.lanes != 1) {
os << 'x' << static_cast<int>(t.lanes);
}
return os;
}

inline std::ostream& operator<<(std::ostream& os, const DataType& dtype) { // NOLINT(*)
return os << dtype.operator DLDataType();
}

inline std::string DLDataType2String(DLDataType t) {
if (t.bits == 0) return "";
std::ostringstream os;
os << t;
return os.str();
}

inline DLDataType String2DLDataType(std::string s) {
DLDataType t;
// handle None type
if (s.length() == 0) {
t.bits = 0; t.lanes = 0; t.code = kTVMOpaqueHandle;
return t;
}
t.bits = 32; t.lanes = 1;
const char* scan;
if (s.substr(0, 3) == "int") {
t.code = kDLInt; scan = s.c_str() + 3;
} else if (s.substr(0, 4) == "uint") {
t.code = kDLUInt; scan = s.c_str() + 4;
} else if (s.substr(0, 5) == "float") {
t.code = kDLFloat; scan = s.c_str() + 5;
} else if (s.substr(0, 6) == "handle") {
t.code = kTVMOpaqueHandle;
t.bits = 64; // handle uses 64 bit by default.
scan = s.c_str() + 6;
} else if (s == "bool") {
t.code = kDLUInt;
t.bits = 1;
t.lanes = 1;
return t;
} else if (s.substr(0, 6) == "custom") {
t.code = ParseCustomDatatype(s, &scan);
} else {
scan = s.c_str();
LOG(FATAL) << "unknown type " << s;
}
char* xdelim; // emulate sscanf("%ux%u", bits, lanes)
uint8_t bits = static_cast<uint8_t>(strtoul(scan, &xdelim, 10));
if (bits != 0) t.bits = bits;
char* endpt = xdelim;
if (*xdelim == 'x') {
t.lanes = static_cast<uint16_t>(strtoul(xdelim + 1, &endpt, 10));
}
CHECK(endpt == s.c_str() + s.length()) << "unknown type " << s;
return t;
}

} // namespace runtime

using DataType = runtime::DataType;
Expand Down
142 changes: 9 additions & 133 deletions include/tvm/runtime/packed_func.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
#include <tvm/runtime/ndarray.h>
#include <tvm/runtime/data_type.h>
#include <tvm/runtime/object.h>
#include <tvm/runtime/container.h>
#include <functional>
#include <tuple>
#include <vector>
Expand All @@ -52,28 +53,6 @@ class PrimExpr;

namespace runtime {

/*!
* \brief Runtime utility for getting custom type name from code
* \param type_code Custom type code
* \return Custom type name
*/
TVM_DLL std::string GetCustomTypeName(uint8_t type_code);

/*!
* \brief Runtime utility for checking whether custom type is registered
* \param type_code Custom type code
* \return Bool representing whether type is registered
*/
TVM_DLL bool GetCustomTypeRegistered(uint8_t type_code);

/*!
* \brief Runtime utility for parsing string of the form "custom[<typename>]"
* \param s String to parse
* \param scan pointer to parsing pointer, which is scanning across s
* \return type code of custom type parsed
*/
TVM_DLL uint8_t ParseCustomDatatype(const std::string& s, const char** scan);

// forward declarations
class TVMArgs;
class TVMArgValue;
Expand Down Expand Up @@ -359,27 +338,6 @@ class TVMArgs {
inline TVMArgValue operator[](int i) const;
};

/*!
* \brief Convert type code to its name
* \param type_code The type code .
* \return The name of type code.
*/
inline const char* TypeCode2Str(int type_code);

/*!
* \brief convert a string to TVM type.
* \param s The string to be converted.
* \return The corresponding tvm type.
*/
inline DLDataType String2DLDataType(std::string s);

/*!
* \brief convert a TVM type to string.
* \param t The type to be converted.
* \return The corresponding tvm type in string.
*/
inline std::string DLDataType2String(DLDataType t);

// macro to check type code.
#define TVM_CHECK_TYPE_CODE(CODE, T) \
CHECK_EQ(CODE, T) << " expected " \
Expand Down Expand Up @@ -554,6 +512,10 @@ class TVMArgValue : public TVMPODValue_ {
return std::string(value_.v_str);
}
}
operator tvm::runtime::String() const {
// directly use the std::string constructor for now.
return tvm::runtime::String(operator std::string());
}
operator DLDataType() const {
if (type_code_ == kTVMStr) {
return String2DLDataType(operator std::string());
Expand Down Expand Up @@ -642,6 +604,10 @@ class TVMRetValue : public TVMPODValue_ {
TVM_CHECK_TYPE_CODE(type_code_, kTVMStr);
return *ptr<std::string>();
}
operator tvm::runtime::String() const {
// directly use the std::string constructor for now.
return tvm::runtime::String(operator std::string());
}
operator DLDataType() const {
if (type_code_ == kTVMStr) {
return String2DLDataType(operator std::string());
Expand Down Expand Up @@ -994,96 +960,6 @@ class TVMRetValue : public TVMPODValue_ {
} \
}

// implementation details
inline const char* TypeCode2Str(int type_code) {
switch (type_code) {
case kDLInt: return "int";
case kDLUInt: return "uint";
case kDLFloat: return "float";
case kTVMStr: return "str";
case kTVMBytes: return "bytes";
case kTVMOpaqueHandle: return "handle";
case kTVMNullptr: return "NULL";
case kTVMDLTensorHandle: return "ArrayHandle";
case kTVMDataType: return "DLDataType";
case kTVMContext: return "TVMContext";
case kTVMPackedFuncHandle: return "FunctionHandle";
case kTVMModuleHandle: return "ModuleHandle";
case kTVMNDArrayHandle: return "NDArrayContainer";
case kTVMObjectHandle: return "Object";
default: LOG(FATAL) << "unknown type_code="
<< static_cast<int>(type_code); return "";
}
}

inline std::ostream& operator<<(std::ostream& os, DLDataType t) { // NOLINT(*)
if (t.bits == 1 && t.lanes == 1 && t.code == kDLUInt) {
os << "bool"; return os;
}
if (t.code < kTVMCustomBegin) {
os << TypeCode2Str(t.code);
} else {
os << "custom[" << GetCustomTypeName(t.code) << "]";
}
if (t.code == kTVMOpaqueHandle) return os;
os << static_cast<int>(t.bits);
if (t.lanes != 1) {
os << 'x' << static_cast<int>(t.lanes);
}
return os;
}

inline std::ostream& operator<<(std::ostream& os, const DataType& dtype) { // NOLINT(*)
return os << dtype.operator DLDataType();
}

inline std::string DLDataType2String(DLDataType t) {
if (t.bits == 0) return "";
std::ostringstream os;
os << t;
return os.str();
}

inline DLDataType String2DLDataType(std::string s) {
DLDataType t;
// handle None type
if (s.length() == 0) {
t.bits = 0; t.lanes = 0; t.code = kTVMOpaqueHandle;
return t;
}
t.bits = 32; t.lanes = 1;
const char* scan;
if (s.substr(0, 3) == "int") {
t.code = kDLInt; scan = s.c_str() + 3;
} else if (s.substr(0, 4) == "uint") {
t.code = kDLUInt; scan = s.c_str() + 4;
} else if (s.substr(0, 5) == "float") {
t.code = kDLFloat; scan = s.c_str() + 5;
} else if (s.substr(0, 6) == "handle") {
t.code = kTVMOpaqueHandle;
t.bits = 64; // handle uses 64 bit by default.
scan = s.c_str() + 6;
} else if (s == "bool") {
t.code = kDLUInt;
t.bits = 1;
t.lanes = 1;
return t;
} else if (s.substr(0, 6) == "custom") {
t.code = ParseCustomDatatype(s, &scan);
} else {
scan = s.c_str();
LOG(FATAL) << "unknown type " << s;
}
char* xdelim; // emulate sscanf("%ux%u", bits, lanes)
uint8_t bits = static_cast<uint8_t>(strtoul(scan, &xdelim, 10));
if (bits != 0) t.bits = bits;
char* endpt = xdelim;
if (*xdelim == 'x') {
t.lanes = static_cast<uint16_t>(strtoul(xdelim + 1, &endpt, 10));
}
CHECK(endpt == s.c_str() + s.length()) << "unknown type " << s;
return t;
}

inline TVMArgValue TVMArgs::operator[](int i) const {
CHECK_LT(i, num_args)
Expand Down
2 changes: 2 additions & 0 deletions tests/cpp/packed_func_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,8 @@ TEST(PackedFunc, str) {
CHECK(args.num_args == 1);
std::string x = args[0];
CHECK(x == "hello");
String y = args[0];
CHECK(y == "hello");
*rv = x;
})("hello");
}
Expand Down

0 comments on commit c899fbb

Please sign in to comment.