Skip to content

Commit

Permalink
[REFACTOR][DTYPE] Isolate dtype to runtime (#4560)
Browse files Browse the repository at this point in the history
dtype.h -> runtime/data_type.h

Changes:
- Rename all old reference of tvm::Type to DataType
- ExprNode.type -> ExprNode.dtype
- Expr.type() -> Expr.dtype()
- Change Expr related functions to expr_operator.
  - DataType::min() -> min_value(DataType)
  - DataType::max() -> max_value(DataType)
- Move type constructor Int, UInt, Float, Handle, Bool into DataType.
  - Int(bits) -> DataType::Int(bits)
  - UInt(bits) -> DataType::UInt(bits)
  • Loading branch information
tqchen authored Dec 22, 2019
1 parent ad81796 commit 7fa8aab
Show file tree
Hide file tree
Showing 203 changed files with 2,003 additions and 1,947 deletions.
8 changes: 4 additions & 4 deletions include/tvm/attrs.h
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ class AttrsEqual {
bool operator()(const std::string& lhs, const std::string& rhs) const {
return lhs == rhs;
}
bool operator()(const Type& lhs, const Type& rhs) const {
bool operator()(const DataType& lhs, const DataType& rhs) const {
return lhs == rhs;
}
// node comparator
Expand Down Expand Up @@ -506,8 +506,8 @@ inline void SetValue<std::string>(std::string* ptr, const TVMArgValue& val) {
}
}
template<>
inline void SetValue(Type* ptr, const TVMArgValue& val) {
*ptr = val.operator Type();
inline void SetValue(DataType* ptr, const TVMArgValue& val) {
*ptr = val.operator DataType();
}
template<>
inline void SetValue<double>(double* ptr, const TVMArgValue& val) {
Expand Down Expand Up @@ -611,7 +611,7 @@ struct TypeName<uint64_t> {
};

template<>
struct TypeName<Type> {
struct TypeName<DataType> {
static constexpr const char* value = "Type";
};

Expand Down
18 changes: 10 additions & 8 deletions include/tvm/buffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,14 +74,16 @@ class Buffer : public NodeRef {
* \param content_lanes The number of lanes for the (data) type.
* \param offset The offset of ptr.
*/
TVM_DLL Expr access_ptr(int access_mask, Type ptr_type = Handle(),
int content_lanes = 1, Expr offset = make_const(Int(32), 0)) const;
TVM_DLL Expr access_ptr(int access_mask,
DataType ptr_type = DataType::Handle(),
int content_lanes = 1,
Expr offset = make_const(DataType::Int(32), 0)) const;
/*!
* \brief Create an Expr that does a vector load at begin index.
* \param begin The beginning index
* \param dtype The data type to be loaded.
*/
TVM_DLL Expr vload(Array<Expr> begin, Type dtype) const;
TVM_DLL Expr vload(Array<Expr> begin, DataType dtype) const;
/*!
* \brief Create a Stmt that does a vector store at begin index.
* \param begin The beginning index
Expand All @@ -108,7 +110,7 @@ class BufferNode : public Node {
*/
Var data;
/*! \brief data type in the content of the tensor */
Type dtype;
DataType dtype;
/*! \brief The shape of the buffer */
Array<Expr> shape;
/*!
Expand Down Expand Up @@ -149,14 +151,14 @@ class BufferNode : public Node {
}

/*! \return preferred index type for this buffer node */
Type DefaultIndexType() const {
return shape.size() != 0 ? shape[0].type() : Int(32);
DataType DefaultIndexType() const {
return shape.size() != 0 ? shape[0].dtype() : DataType::Int(32);
}

// User can specify data_alignment and offset_factor to be 0
// A default value will be picked.
TVM_DLL static Buffer make(Var ptr,
Type dtype,
DataType dtype,
Array<Expr> shape,
Array<Expr> strides,
Expr elem_offset,
Expand All @@ -183,7 +185,7 @@ inline const BufferNode* Buffer::operator->() const {
* \sa BufferNode::make for complete constructor.
*/
TVM_DLL Buffer decl_buffer(Array<Expr> shape,
Type dtype = Float(32),
DataType dtype = DataType::Float(32),
std::string name = "buffer");
} // namespace tvm
#endif // TVM_BUFFER_H_
4 changes: 2 additions & 2 deletions include/tvm/channel.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,14 +52,14 @@ struct ChannelNode : public Node {
/*! \brief Variable to channel handle */
Var handle_var;
/*! \brief default data type in read/write */
Type dtype;
DataType dtype;
// visit all attributes
void VisitAttrs(AttrVisitor* v) {
v->Visit("handle_var", &handle_var);
v->Visit("dtype", &dtype);
}

static Channel make(Var handle_var, Type dtype);
static Channel make(Var handle_var, DataType dtype);
static constexpr const char* _type_key = "Channel";

TVM_DECLARE_NODE_TYPE_INFO(ChannelNode, Node);
Expand Down
18 changes: 9 additions & 9 deletions include/tvm/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,19 +29,19 @@
#include <unordered_map>
#include <iostream>
#include "base.h"
#include "dtype.h"
#include "node/node.h"
#include "node/container.h"
#include "node/functor.h"
#include "runtime/c_runtime_api.h"
#include "runtime/data_type.h"

namespace tvm {

/*! \brief Base node of all expressions. */
class ExprNode : public Node {
public:
/*! \brief The data type of the expression. */
DataType type;
DataType dtype;

static constexpr const char* _type_key = "Expr";
TVM_DECLARE_BASE_NODE_INFO(ExprNode, Node);
Expand Down Expand Up @@ -69,8 +69,8 @@ class Expr : public NodeRef {
TVM_DLL Expr(std::string str); // NOLINT(*)

/*! \return the data type of this expression. */
DataType type() const {
return static_cast<const ExprNode*>(get())->type;
DataType dtype() const {
return static_cast<const ExprNode*>(get())->dtype;
}

/*! \brief type indicate the container type */
Expand Down Expand Up @@ -113,7 +113,7 @@ class Variable : public ExprNode {
static Var make(DataType dtype, std::string name_hint);

void VisitAttrs(AttrVisitor* v) {
v->Visit("dtype", &type);
v->Visit("dtype", &dtype);
v->Visit("name", &name_hint);
}

Expand All @@ -126,14 +126,14 @@ class Var : public Expr {
public:
explicit Var(ObjectPtr<Object> n) : Expr(n) {}
TVM_DLL explicit Var(std::string name_hint = "v",
Type t = Int(32));
DataType t = DataType::Int(32));
/*!
* \brief Make a new copy of var with same type, append suffix
* \param suffix The suffix to be appended.
* \return the new Var copy
*/
Var copy_with_suffix(const std::string& suffix) const {
return Var((*this)->name_hint + suffix, (*this)->type);
return Var((*this)->name_hint + suffix, (*this)->dtype);
}
/*!
* \brief Get pointer to the internal value.
Expand Down Expand Up @@ -167,7 +167,7 @@ class IntImm : public ExprNode {
int64_t value;

void VisitAttrs(AttrVisitor* v) {
v->Visit("dtype", &type);
v->Visit("dtype", &dtype);
v->Visit("value", &value);
}

Expand Down Expand Up @@ -452,7 +452,7 @@ inline const char* IterVarType2String(IterVarType t) {
* \param name_hint The name hint for the expression
* \param t The type of the expression
*/
TVM_DLL Var var(std::string name_hint, Type t = Int(32));
TVM_DLL Var var(std::string name_hint, DataType t = DataType::Int(32));

/*
* \brief Template function to convert Map to unordered_map
Expand Down
46 changes: 30 additions & 16 deletions include/tvm/expr_operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,28 +44,28 @@ namespace tvm {
*/
template<typename ValueType,
typename = typename std::enable_if<std::is_pod<ValueType>::value>::type>
inline Expr make_const(Type t, ValueType value);
inline Expr make_const(DataType t, ValueType value);
/*!
* \brief Make a const zero expr.
* \param t The target type.
* \return the result expression.
*/
inline Expr make_zero(Type t);
inline Expr make_zero(DataType t);
/*!
* \brief Make a constant true expression.
* \param lanes The number of lanes in the bool
* \return The result expression.
*/
inline Expr const_true(int lanes = 1) {
return make_const(UInt(1, lanes), 1);
return make_const(DataType::UInt(1, lanes), 1);
}
/*!
* \brief Make a constant false expression.
* \param lanes The number of lanes in the bool
* \return The result expression.
*/
inline Expr const_false(int lanes = 1) {
return make_const(UInt(1, lanes), 0);
return make_const(DataType::UInt(1, lanes), 0);
}
/*!
* \brief Get x as constant int expression.
Expand Down Expand Up @@ -139,6 +139,20 @@ inline bool is_zero(const Expr& x) {
*/
inline bool is_const(const Expr& x);

/*!
* Query the maximum possible value of dtype.
* \param dtype The data type.
* \return the maximum possible value in this format.
*/
TVM_DLL Expr max_value(const DataType& dtype);

/*!
* Query the minimum possible value of dtype.
* \param dtype The data type.
* \return the minimum possible value in this format.
*/
TVM_DLL Expr min_value(const DataType& dtype);

/*!
* \brief Check whether x is a constant power of two
* If x is power of two, write the power to the shift.
Expand All @@ -157,7 +171,7 @@ TVM_DLL bool is_const_power_of_two_integer(const Expr& x, int* shift);
* \return The result expression.
* \note This function may return value if the type is the same.
*/
TVM_DLL Expr cast(const Type& t, Expr value);
TVM_DLL Expr cast(const DataType& t, Expr value);
/*!
* \brief perform reinterpret cast value to type.
*
Expand All @@ -166,7 +180,7 @@ TVM_DLL Expr cast(const Type& t, Expr value);
* \return The result expression.
* \note This function may return value if the type is the same.
*/
TVM_DLL Expr reinterpret(const Type& t, Expr value);
TVM_DLL Expr reinterpret(const DataType& t, Expr value);
/*!
* \brief add operator
*
Expand Down Expand Up @@ -586,7 +600,7 @@ TVM_DLL Expr trunc(Expr x);
// Intrinsic operators
#define TVM_DECLARE_INTRIN_UNARY(OpName) \
inline Expr OpName(Expr x) { \
return ir::Call::make(x.type(), #OpName, {x}, ir::Call::PureIntrinsic); \
return ir::Call::make(x.dtype(), #OpName, {x}, ir::Call::PureIntrinsic); \
} \

TVM_DECLARE_INTRIN_UNARY(exp);
Expand Down Expand Up @@ -657,7 +671,7 @@ inline bool is_no_op(const Stmt& stmt) {
}

template<typename ValueType>
inline Expr MakeConstScalar(Type t, ValueType value) {
inline Expr MakeConstScalar(DataType t, ValueType value) {
if (t.is_int()) return ir::IntImm::make(t, static_cast<int64_t>(value));
if (t.is_uint()) return ir::UIntImm::make(t, static_cast<uint64_t>(value));
if (t.is_float()) return ir::FloatImm::make(t, static_cast<double>(value));
Expand All @@ -672,7 +686,7 @@ inline Expr MakeConstScalar(Type t, ValueType value) {
}

template<typename ValueType, typename>
inline Expr make_const(Type t, ValueType value) {
inline Expr make_const(DataType t, ValueType value) {
if (t.lanes() == 1) {
return MakeConstScalar(t, value);
} else {
Expand All @@ -681,9 +695,9 @@ inline Expr make_const(Type t, ValueType value) {
}
}

inline Expr make_zero(Type t) {
inline Expr make_zero(DataType t) {
if (t.is_handle()) {
return reinterpret(t, make_const(UInt(64), 0));
return reinterpret(t, make_const(DataType::UInt(64), 0));
}
return make_const(t, 0);
}
Expand All @@ -703,13 +717,13 @@ inline Expr make_zero(Type t) {
return Name(Expr(a), b); \
} \
inline Expr Name(int a, const Expr& b) { \
return Name(make_const(b.type(), a), b); \
return Name(make_const(b.dtype(), a), b); \
} \
inline Expr Name(const Expr& a, int b) { \
return Name(a, make_const(a.type(), b)); \
return Name(a, make_const(a.dtype(), b)); \
} \
inline Expr Name(const Expr& a, double b) { \
return Name(a, make_const(Float(64), b)); \
return Name(a, make_const(DataType::Float(64), b)); \
}

#define TVM_DEFINE_LOGICAL_OP_CONST_VAL_OVERLOAD(Name) \
Expand All @@ -722,10 +736,10 @@ inline Expr make_zero(Type t) {

#define TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(Name) \
inline Expr Name(const Expr& a, int b) { \
return Name(a, make_const(a.type(), b)); \
return Name(a, make_const(a.dtype(), b)); \
} \
inline Expr Name(int a, const Expr& b) { \
return Name(make_const(b.type(), a), b); \
return Name(make_const(b.dtype(), a), b); \
}


Expand Down
Loading

0 comments on commit 7fa8aab

Please sign in to comment.