diff --git a/include/tvm/attrs.h b/include/tvm/attrs.h index 2fbb9e6a866e..8810c4e4a0df 100644 --- a/include/tvm/attrs.h +++ b/include/tvm/attrs.h @@ -159,7 +159,7 @@ class AttrsEqual { bool operator()(const std::string& lhs, const std::string& rhs) const { return lhs == rhs; } - bool operator()(const Type& lhs, const Type& rhs) const { + bool operator()(const DataType& lhs, const DataType& rhs) const { return lhs == rhs; } // node comparator @@ -506,8 +506,8 @@ inline void SetValue(std::string* ptr, const TVMArgValue& val) { } } template<> -inline void SetValue(Type* ptr, const TVMArgValue& val) { - *ptr = val.operator Type(); +inline void SetValue(DataType* ptr, const TVMArgValue& val) { + *ptr = val.operator DataType(); } template<> inline void SetValue(double* ptr, const TVMArgValue& val) { @@ -611,7 +611,7 @@ struct TypeName { }; template<> -struct TypeName { +struct TypeName { static constexpr const char* value = "Type"; }; diff --git a/include/tvm/buffer.h b/include/tvm/buffer.h index d2c2b40661e2..fac18a9b1753 100644 --- a/include/tvm/buffer.h +++ b/include/tvm/buffer.h @@ -74,14 +74,16 @@ class Buffer : public NodeRef { * \param content_lanes The number of lanes for the (data) type. * \param offset The offset of ptr. */ - TVM_DLL Expr access_ptr(int access_mask, Type ptr_type = Handle(), - int content_lanes = 1, Expr offset = make_const(Int(32), 0)) const; + TVM_DLL Expr access_ptr(int access_mask, + DataType ptr_type = DataType::Handle(), + int content_lanes = 1, + Expr offset = make_const(DataType::Int(32), 0)) const; /*! * \brief Create an Expr that does a vector load at begin index. * \param begin The beginning index * \param dtype The data type to be loaded. */ - TVM_DLL Expr vload(Array begin, Type dtype) const; + TVM_DLL Expr vload(Array begin, DataType dtype) const; /*! * \brief Create a Stmt that does a vector store at begin index. * \param begin The beginning index @@ -108,7 +110,7 @@ class BufferNode : public Node { */ Var data; /*! \brief data type in the content of the tensor */ - Type dtype; + DataType dtype; /*! \brief The shape of the buffer */ Array shape; /*! @@ -149,14 +151,14 @@ class BufferNode : public Node { } /*! \return preferred index type for this buffer node */ - Type DefaultIndexType() const { - return shape.size() != 0 ? shape[0].type() : Int(32); + DataType DefaultIndexType() const { + return shape.size() != 0 ? shape[0].dtype() : DataType::Int(32); } // User can specify data_alignment and offset_factor to be 0 // A default value will be picked. TVM_DLL static Buffer make(Var ptr, - Type dtype, + DataType dtype, Array shape, Array strides, Expr elem_offset, @@ -183,7 +185,7 @@ inline const BufferNode* Buffer::operator->() const { * \sa BufferNode::make for complete constructor. */ TVM_DLL Buffer decl_buffer(Array shape, - Type dtype = Float(32), + DataType dtype = DataType::Float(32), std::string name = "buffer"); } // namespace tvm #endif // TVM_BUFFER_H_ diff --git a/include/tvm/channel.h b/include/tvm/channel.h index 3a40a787d891..25ee7a9e531c 100644 --- a/include/tvm/channel.h +++ b/include/tvm/channel.h @@ -52,14 +52,14 @@ struct ChannelNode : public Node { /*! \brief Variable to channel handle */ Var handle_var; /*! \brief default data type in read/write */ - Type dtype; + DataType dtype; // visit all attributes void VisitAttrs(AttrVisitor* v) { v->Visit("handle_var", &handle_var); v->Visit("dtype", &dtype); } - static Channel make(Var handle_var, Type dtype); + static Channel make(Var handle_var, DataType dtype); static constexpr const char* _type_key = "Channel"; TVM_DECLARE_NODE_TYPE_INFO(ChannelNode, Node); diff --git a/include/tvm/expr.h b/include/tvm/expr.h index fc52421d903b..f27cb9879fb7 100644 --- a/include/tvm/expr.h +++ b/include/tvm/expr.h @@ -29,11 +29,11 @@ #include #include #include "base.h" -#include "dtype.h" #include "node/node.h" #include "node/container.h" #include "node/functor.h" #include "runtime/c_runtime_api.h" +#include "runtime/data_type.h" namespace tvm { @@ -41,7 +41,7 @@ namespace tvm { class ExprNode : public Node { public: /*! \brief The data type of the expression. */ - DataType type; + DataType dtype; static constexpr const char* _type_key = "Expr"; TVM_DECLARE_BASE_NODE_INFO(ExprNode, Node); @@ -69,8 +69,8 @@ class Expr : public NodeRef { TVM_DLL Expr(std::string str); // NOLINT(*) /*! \return the data type of this expression. */ - DataType type() const { - return static_cast(get())->type; + DataType dtype() const { + return static_cast(get())->dtype; } /*! \brief type indicate the container type */ @@ -113,7 +113,7 @@ class Variable : public ExprNode { static Var make(DataType dtype, std::string name_hint); void VisitAttrs(AttrVisitor* v) { - v->Visit("dtype", &type); + v->Visit("dtype", &dtype); v->Visit("name", &name_hint); } @@ -126,14 +126,14 @@ class Var : public Expr { public: explicit Var(ObjectPtr n) : Expr(n) {} TVM_DLL explicit Var(std::string name_hint = "v", - Type t = Int(32)); + DataType t = DataType::Int(32)); /*! * \brief Make a new copy of var with same type, append suffix * \param suffix The suffix to be appended. * \return the new Var copy */ Var copy_with_suffix(const std::string& suffix) const { - return Var((*this)->name_hint + suffix, (*this)->type); + return Var((*this)->name_hint + suffix, (*this)->dtype); } /*! * \brief Get pointer to the internal value. @@ -167,7 +167,7 @@ class IntImm : public ExprNode { int64_t value; void VisitAttrs(AttrVisitor* v) { - v->Visit("dtype", &type); + v->Visit("dtype", &dtype); v->Visit("value", &value); } @@ -452,7 +452,7 @@ inline const char* IterVarType2String(IterVarType t) { * \param name_hint The name hint for the expression * \param t The type of the expression */ -TVM_DLL Var var(std::string name_hint, Type t = Int(32)); +TVM_DLL Var var(std::string name_hint, DataType t = DataType::Int(32)); /* * \brief Template function to convert Map to unordered_map diff --git a/include/tvm/expr_operator.h b/include/tvm/expr_operator.h index 625ee8e49286..41e7aa5b7796 100644 --- a/include/tvm/expr_operator.h +++ b/include/tvm/expr_operator.h @@ -44,20 +44,20 @@ namespace tvm { */ template::value>::type> -inline Expr make_const(Type t, ValueType value); +inline Expr make_const(DataType t, ValueType value); /*! * \brief Make a const zero expr. * \param t The target type. * \return the result expression. */ -inline Expr make_zero(Type t); +inline Expr make_zero(DataType t); /*! * \brief Make a constant true expression. * \param lanes The number of lanes in the bool * \return The result expression. */ inline Expr const_true(int lanes = 1) { - return make_const(UInt(1, lanes), 1); + return make_const(DataType::UInt(1, lanes), 1); } /*! * \brief Make a constant false expression. @@ -65,7 +65,7 @@ inline Expr const_true(int lanes = 1) { * \return The result expression. */ inline Expr const_false(int lanes = 1) { - return make_const(UInt(1, lanes), 0); + return make_const(DataType::UInt(1, lanes), 0); } /*! * \brief Get x as constant int expression. @@ -139,6 +139,20 @@ inline bool is_zero(const Expr& x) { */ inline bool is_const(const Expr& x); +/*! + * Query the maximum possible value of dtype. + * \param dtype The data type. + * \return the maximum possible value in this format. + */ +TVM_DLL Expr max_value(const DataType& dtype); + +/*! + * Query the minimum possible value of dtype. + * \param dtype The data type. + * \return the minimum possible value in this format. + */ +TVM_DLL Expr min_value(const DataType& dtype); + /*! * \brief Check whether x is a constant power of two * If x is power of two, write the power to the shift. @@ -157,7 +171,7 @@ TVM_DLL bool is_const_power_of_two_integer(const Expr& x, int* shift); * \return The result expression. * \note This function may return value if the type is the same. */ -TVM_DLL Expr cast(const Type& t, Expr value); +TVM_DLL Expr cast(const DataType& t, Expr value); /*! * \brief perform reinterpret cast value to type. * @@ -166,7 +180,7 @@ TVM_DLL Expr cast(const Type& t, Expr value); * \return The result expression. * \note This function may return value if the type is the same. */ -TVM_DLL Expr reinterpret(const Type& t, Expr value); +TVM_DLL Expr reinterpret(const DataType& t, Expr value); /*! * \brief add operator * @@ -586,7 +600,7 @@ TVM_DLL Expr trunc(Expr x); // Intrinsic operators #define TVM_DECLARE_INTRIN_UNARY(OpName) \ inline Expr OpName(Expr x) { \ - return ir::Call::make(x.type(), #OpName, {x}, ir::Call::PureIntrinsic); \ + return ir::Call::make(x.dtype(), #OpName, {x}, ir::Call::PureIntrinsic); \ } \ TVM_DECLARE_INTRIN_UNARY(exp); @@ -657,7 +671,7 @@ inline bool is_no_op(const Stmt& stmt) { } template -inline Expr MakeConstScalar(Type t, ValueType value) { +inline Expr MakeConstScalar(DataType t, ValueType value) { if (t.is_int()) return ir::IntImm::make(t, static_cast(value)); if (t.is_uint()) return ir::UIntImm::make(t, static_cast(value)); if (t.is_float()) return ir::FloatImm::make(t, static_cast(value)); @@ -672,7 +686,7 @@ inline Expr MakeConstScalar(Type t, ValueType value) { } template -inline Expr make_const(Type t, ValueType value) { +inline Expr make_const(DataType t, ValueType value) { if (t.lanes() == 1) { return MakeConstScalar(t, value); } else { @@ -681,9 +695,9 @@ inline Expr make_const(Type t, ValueType value) { } } -inline Expr make_zero(Type t) { +inline Expr make_zero(DataType t) { if (t.is_handle()) { - return reinterpret(t, make_const(UInt(64), 0)); + return reinterpret(t, make_const(DataType::UInt(64), 0)); } return make_const(t, 0); } @@ -703,13 +717,13 @@ inline Expr make_zero(Type t) { return Name(Expr(a), b); \ } \ inline Expr Name(int a, const Expr& b) { \ - return Name(make_const(b.type(), a), b); \ + return Name(make_const(b.dtype(), a), b); \ } \ inline Expr Name(const Expr& a, int b) { \ - return Name(a, make_const(a.type(), b)); \ + return Name(a, make_const(a.dtype(), b)); \ } \ inline Expr Name(const Expr& a, double b) { \ - return Name(a, make_const(Float(64), b)); \ + return Name(a, make_const(DataType::Float(64), b)); \ } #define TVM_DEFINE_LOGICAL_OP_CONST_VAL_OVERLOAD(Name) \ @@ -722,10 +736,10 @@ inline Expr make_zero(Type t) { #define TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(Name) \ inline Expr Name(const Expr& a, int b) { \ - return Name(a, make_const(a.type(), b)); \ + return Name(a, make_const(a.dtype(), b)); \ } \ inline Expr Name(int a, const Expr& b) { \ - return Name(make_const(b.type(), a), b); \ + return Name(make_const(b.dtype(), a), b); \ } diff --git a/include/tvm/ir.h b/include/tvm/ir.h index 53eb94ec4ff5..33aa72b50805 100644 --- a/include/tvm/ir.h +++ b/include/tvm/ir.h @@ -46,7 +46,7 @@ class UIntImm : public ExprNode { uint64_t value; void VisitAttrs(AttrVisitor* v) { - v->Visit("dtype", &type); + v->Visit("dtype", &dtype); v->Visit("value", &value); } @@ -63,7 +63,7 @@ class FloatImm : public ExprNode { double value; void VisitAttrs(AttrVisitor* v) { - v->Visit("dtype", &type); + v->Visit("dtype", &dtype); v->Visit("value", &value); } @@ -80,7 +80,7 @@ class StringImm : public ExprNode { std::string value; void VisitAttrs(AttrVisitor* v) { - v->Visit("dtype", &type); + v->Visit("dtype", &dtype); v->Visit("value", &value); } @@ -100,7 +100,7 @@ class Cast : public ExprNode { Expr value; void VisitAttrs(AttrVisitor* v) { - v->Visit("dtype", &type); + v->Visit("dtype", &dtype); v->Visit("value", &value); } @@ -123,7 +123,7 @@ class BinaryOpNode : public ExprNode { Expr b; void VisitAttrs(AttrVisitor* v) { - v->Visit("dtype", &(this->type)); + v->Visit("dtype", &(this->dtype)); v->Visit("a", &a); v->Visit("b", &b); } @@ -131,9 +131,9 @@ class BinaryOpNode : public ExprNode { static Expr make(Expr a, Expr b) { CHECK(a.defined()) << "ValueError: a is undefined\n"; CHECK(b.defined()) << "ValueError: b is undefined\n"; - CHECK(a.type() == b.type()) << "TypeError: mismatched types\n"; + CHECK(a.dtype() == b.dtype()) << "TypeError: mismatched types\n"; NodePtr node = make_node(); - node->type = a.type(); + node->dtype = a.dtype(); node->a = std::move(a); node->b = std::move(b); return Expr(node); @@ -215,7 +215,7 @@ class CmpOpNode : public ExprNode { Expr b; void VisitAttrs(AttrVisitor* v) { - v->Visit("dtype", &(this->type)); + v->Visit("dtype", &(this->dtype)); v->Visit("a", &a); v->Visit("b", &b); } @@ -223,9 +223,9 @@ class CmpOpNode : public ExprNode { static Expr make(Expr a, Expr b) { CHECK(a.defined()) << "ValueError: a is undefined\n"; CHECK(b.defined()) << "ValueError: b is undefined\n"; - CHECK(a.type() == b.type()) << "TypeError: mismatched types\n"; + CHECK(a.dtype() == b.dtype()) << "TypeError: mismatched types\n"; NodePtr node = make_node(); - node->type = Bool(a.type().lanes()); + node->dtype = DataType::Bool(a.dtype().lanes()); node->a = std::move(a); node->b = std::move(b); return Expr(node); @@ -279,7 +279,7 @@ class And : public ExprNode { Expr b; void VisitAttrs(AttrVisitor* v) { - v->Visit("dtype", &(this->type)); + v->Visit("dtype", &(this->dtype)); v->Visit("a", &a); v->Visit("b", &b); } @@ -299,7 +299,7 @@ class Or : public ExprNode { Expr b; void VisitAttrs(AttrVisitor* v) { - v->Visit("dtype", &type); + v->Visit("dtype", &dtype); v->Visit("a", &a); v->Visit("b", &b); } @@ -317,7 +317,7 @@ class Not : public ExprNode { Expr a; void VisitAttrs(AttrVisitor* v) { - v->Visit("dtype", &type); + v->Visit("dtype", &dtype); v->Visit("a", &a); } @@ -344,7 +344,7 @@ class Select : public ExprNode { Expr false_value; void VisitAttrs(AttrVisitor* v) { - v->Visit("dtype", &type); + v->Visit("dtype", &dtype); v->Visit("condition", &condition); v->Visit("true_value", &true_value); v->Visit("false_value", &false_value); @@ -381,13 +381,13 @@ class Load : public ExprNode { Expr predicate; void VisitAttrs(AttrVisitor* v) { - v->Visit("dtype", &type); + v->Visit("dtype", &dtype); v->Visit("buffer_var", &buffer_var); v->Visit("index", &index); v->Visit("predicate", &predicate); } - TVM_DLL static Expr make(DataType type, Var buffer_var, Expr index, Expr predicate); + TVM_DLL static Expr make(DataType dtype, Var buffer_var, Expr index, Expr predicate); static constexpr const char* _type_key = "Load"; TVM_DECLARE_NODE_TYPE_INFO(Load, ExprNode); @@ -412,7 +412,7 @@ class Ramp : public ExprNode { int lanes; void VisitAttrs(AttrVisitor* v) { - v->Visit("dtype", &type); + v->Visit("dtype", &dtype); v->Visit("base", &base); v->Visit("stride", &stride); v->Visit("lanes", &lanes); @@ -433,7 +433,7 @@ class Broadcast : public ExprNode { int lanes; void VisitAttrs(AttrVisitor* v) { - v->Visit("dtype", &type); + v->Visit("dtype", &dtype); v->Visit("value", &value); v->Visit("lanes", &lanes); } @@ -457,7 +457,7 @@ class Let : public ExprNode { Expr body; void VisitAttrs(AttrVisitor* v) { - v->Visit("dtype", &type); + v->Visit("dtype", &dtype); v->Visit("var", &var); v->Visit("value", &value); v->Visit("body", &body); @@ -523,7 +523,7 @@ class Call : public ExprNode { int value_index{0}; void VisitAttrs(AttrVisitor* v) { - v->Visit("dtype", &type); + v->Visit("dtype", &dtype); v->Visit("name", &name); v->Visit("args", &args); v->Visit("call_type", &call_type); @@ -531,7 +531,7 @@ class Call : public ExprNode { v->Visit("value_index", &value_index); } - TVM_DLL static Expr make(DataType type, + TVM_DLL static Expr make(DataType dtype, std::string name, Array args, CallType call_type, @@ -695,7 +695,7 @@ class Reduce : public ExprNode { int value_index); void VisitAttrs(AttrVisitor* v) { - v->Visit("dtype", &type); + v->Visit("dtype", &dtype); v->Visit("combiner", &combiner); v->Visit("source", &source); v->Visit("axis", &axis); @@ -713,7 +713,7 @@ class Any : public ExprNode { void VisitAttrs(AttrVisitor* v) {} /*! \brief Convert to var. */ Var ToVar() const { - return Variable::make(Int(32), "any_dim"); + return Variable::make(DataType::Int(32), "any_dim"); } TVM_DLL static Expr make(); @@ -917,7 +917,7 @@ class Allocate : public StmtNode { /*! \brief The buffer variable. */ Var buffer_var; /*! \brief The type of the buffer. */ - DataType type; + DataType dtype; /*! \brief The extents of the buffer. */ Array extents; /*! \brief Only allocate buffer when condition is satisfied. */ @@ -931,14 +931,14 @@ class Allocate : public StmtNode { void VisitAttrs(AttrVisitor* v) { v->Visit("buffer_var", &buffer_var); - v->Visit("dtype", &type); + v->Visit("dtype", &dtype); v->Visit("extents", &extents); v->Visit("condition", &condition); v->Visit("body", &body); } TVM_DLL static Stmt make(Var buffer_var, - DataType type, + DataType dtype, Array extents, Expr condition, Stmt body, @@ -993,7 +993,7 @@ class Realize : public StmtNode { /*! \brief The output value index if func's value is a tuple. */ int value_index; /*! \brief The data type of the array. */ - DataType type; + DataType dtype; /*! \brief Bounds to be realized. */ Region bounds; /*! \brief Only realize if condition holds. */ @@ -1004,7 +1004,7 @@ class Realize : public StmtNode { void VisitAttrs(AttrVisitor* v) { v->Visit("func", &func); v->Visit("value_index", &value_index); - v->Visit("dtype", &type); + v->Visit("dtype", &dtype); v->Visit("bounds", &bounds); v->Visit("condition", &condition); v->Visit("body", &body); @@ -1012,7 +1012,7 @@ class Realize : public StmtNode { TVM_DLL static Stmt make(FunctionRef func, int value_index, - DataType type, + DataType dtype, Region bounds, Expr condition, Stmt body); @@ -1165,20 +1165,20 @@ class Prefetch : public StmtNode { /*! \brief The output value index if func's value is a tuple. */ int value_index; /*! \brief The data type of the array. */ - DataType type; + DataType dtype; /*! \brief Bounds to be prefetched. */ Region bounds; void VisitAttrs(AttrVisitor* v) { v->Visit("func", &func); v->Visit("value_index", &value_index); - v->Visit("type", &type); + v->Visit("dtype", &dtype); v->Visit("bounds", &bounds); } TVM_DLL static Stmt make(FunctionRef func, int value_index, - DataType type, + DataType dtype, Region bounds); static constexpr const char* _type_key = "Prefetch"; @@ -1620,7 +1620,7 @@ constexpr const char* tvm_store_matrix_sync = "tvm_store_matrix_sync"; * \param dtype The data type * \return Expr a expression with dtype. */ -inline Expr TypeAnnotation(Type dtype) { +inline Expr TypeAnnotation(DataType dtype) { return ir::Call::make(dtype, "type_annotation", {}, ir::Call::PureIntrinsic); diff --git a/include/tvm/node/reflection.h b/include/tvm/node/reflection.h index 35a8e1d4a657..daffeb859668 100644 --- a/include/tvm/node/reflection.h +++ b/include/tvm/node/reflection.h @@ -28,6 +28,7 @@ #include #include #include +#include #include #include @@ -35,8 +36,6 @@ namespace tvm { // forward declaration -class DataType; - using runtime::Object; using runtime::ObjectPtr; using runtime::ObjectRef; diff --git a/include/tvm/operation.h b/include/tvm/operation.h index f53c1ce56a93..34f584b63261 100644 --- a/include/tvm/operation.h +++ b/include/tvm/operation.h @@ -75,7 +75,7 @@ class OperationNode : public ir::FunctionBaseNode { * \param i The output index. * \return type of i-th output. */ - virtual Type output_dtype(size_t i) const = 0; + virtual DataType output_dtype(size_t i) const = 0; /*! * \brief Get shape of i-th output tensor. * \param i The output index. @@ -160,11 +160,11 @@ class PlaceholderOpNode : public OperationNode { /*! \brief The shape of the input */ Array shape; /*! \brief The data type of the input. */ - Type dtype; + DataType dtype; // override behavior. int num_outputs() const final; Array root_iter_vars() const final; - Type output_dtype(size_t i) const final; + DataType output_dtype(size_t i) const final; Array output_shape(size_t i) const final; Array InputTensors() const final; Operation ReplaceInputs( @@ -197,7 +197,7 @@ class PlaceholderOpNode : public OperationNode { } static Operation make(std::string name, Array shape, - Type dtype); + DataType dtype); static constexpr const char* _type_key = "PlaceholderOp"; TVM_DECLARE_NODE_TYPE_INFO(PlaceholderOpNode, OperationNode); @@ -243,7 +243,7 @@ class TVM_DLL ComputeOpNode : public BaseComputeOpNode { ComputeOpNode() {} // override functions int num_outputs() const final; - Type output_dtype(size_t i) const final; + DataType output_dtype(size_t i) const final; Array InputTensors() const final; Operation ReplaceInputs( const Operation& self, @@ -296,7 +296,7 @@ class TensorComputeOpNode : public BaseComputeOpNode { TensorComputeOpNode() {} // override functions int num_outputs() const final; - Type output_dtype(size_t i) const final; + DataType output_dtype(size_t i) const final; Array InputTensors() const final; Operation ReplaceInputs( const Operation& self, @@ -370,7 +370,7 @@ class ScanOpNode : public OperationNode { // override behavior. int num_outputs() const final; Array root_iter_vars() const final; - Type output_dtype(size_t i) const final; + DataType output_dtype(size_t i) const final; Array output_shape(size_t i) const final; Array InputTensors() const final; Operation ReplaceInputs( @@ -437,7 +437,7 @@ class ExternOpNode : public OperationNode { // override functions int num_outputs() const final; Array root_iter_vars() const final; - Type output_dtype(size_t i) const final; + DataType output_dtype(size_t i) const final; Array output_shape(size_t i) const final; Array InputTensors() const final; Operation ReplaceInputs( @@ -505,7 +505,7 @@ class HybridOpNode : public OperationNode { // override functions int num_outputs() const final; Array root_iter_vars() const final; - Type output_dtype(size_t i) const final; + DataType output_dtype(size_t i) const final; Array output_shape(size_t i) const final; Array InputTensors() const final; Operation ReplaceInputs( @@ -562,7 +562,7 @@ using FBatchCompute = std::function (const Array& i)>; * \param name The name of the Tensor. */ TVM_DLL Tensor placeholder(Array shape, - Type dtype = Float(32), + DataType dtype = DataType::Float(32), std::string name = "placeholder"); /*! diff --git a/include/tvm/packed_func_ext.h b/include/tvm/packed_func_ext.h index 71f8f55b2655..93b7ac33f155 100644 --- a/include/tvm/packed_func_ext.h +++ b/include/tvm/packed_func_ext.h @@ -208,23 +208,6 @@ inline TObjectRef TVMRetValue::AsObjectRef() const { return TObjectRef(ObjectPtr(ptr)); } -// type related stuffs -inline TVMRetValue& TVMRetValue::operator=(const DataType& t) { - return this->operator=(t.operator DLDataType()); -} - -inline TVMRetValue::operator tvm::DataType() const { - return DataType(operator DLDataType()); -} - -inline TVMArgValue::operator tvm::DataType() const { - return DataType(operator DLDataType()); -} - -inline void TVMArgsSetter::operator()( - size_t i, const DataType& t) const { - this->operator()(i, t.operator DLDataType()); -} } // namespace runtime } // namespace tvm #endif // TVM_PACKED_FUNC_EXT_H_ diff --git a/include/tvm/relay/attrs/memory.h b/include/tvm/relay/attrs/memory.h index 2e279a56bbde..c74b6487de54 100644 --- a/include/tvm/relay/attrs/memory.h +++ b/include/tvm/relay/attrs/memory.h @@ -43,7 +43,7 @@ struct AllocTensorAttrs : public tvm::AttrsNode { TVM_ATTR_FIELD(dtype) .describe( "The dtype of the tensor to allocate.") - .set_default(Float(32, 1)); + .set_default(DataType::Float(32, 1)); TVM_ATTR_FIELD(const_shape) .describe( "The shape of constant used to aid in type inference."); diff --git a/include/tvm/relay/base.h b/include/tvm/relay/base.h index 42a01f009b10..32f9c32f468a 100644 --- a/include/tvm/relay/base.h +++ b/include/tvm/relay/base.h @@ -63,7 +63,7 @@ using NodeRef = tvm::NodeRef; /*! * \brief Content data type. */ -using DataType = ::tvm::Type; +using DataType = ::tvm::DataType; /*! * \brief Symbolic expression for tensor shape. diff --git a/include/tvm/dtype.h b/include/tvm/runtime/data_type.h similarity index 57% rename from include/tvm/dtype.h rename to include/tvm/runtime/data_type.h index 9f7902deb960..5b222ac6b442 100644 --- a/include/tvm/dtype.h +++ b/include/tvm/runtime/data_type.h @@ -17,23 +17,35 @@ * under the License. */ /* - * \file tvm/dtype.h - * \brief Data type used in IR. + * \file tvm/runtime/data_type.h + * \brief Primitive runtime data type. */ // Acknowledgement: DataType structure design originates from Halide. -#ifndef TVM_DTYPE_H_ -#define TVM_DTYPE_H_ +#ifndef TVM_RUNTIME_DATA_TYPE_H_ +#define TVM_RUNTIME_DATA_TYPE_H_ -#include "runtime/packed_func.h" +#include +#include +#include -namespace tvm { -class Expr; +namespace tvm { +namespace runtime { /*! - * \brief Primitive data types in tvm. + * \brief Runtime primitive data type. + * + * This class is a thin wrapper of DLDataType. + * We also make use of DataType in compiler to store quick hint */ class DataType { public: + /*! \brief Type code for the DataType. */ + enum TypeCode { + kInt = kDLInt, + kUInt = kDLUInt, + kFloat = kDLFloat, + kHandle = TVMTypeCode::kHandle, + }; /*! \brief default constructor */ DataType() {} /*! @@ -75,23 +87,23 @@ class DataType { } /*! \return whether type is a scalar type. */ bool is_bool() const { - return code() == kDLUInt && bits() == 1; + return code() == DataType::kUInt && bits() == 1; } /*! \return whether type is a float type. */ bool is_float() const { - return code() == kDLFloat; + return code() == DataType::kFloat; } /*! \return whether type is an int type. */ bool is_int() const { - return code() == kDLInt; + return code() == DataType::kInt; } /*! \return whether type is an uint type. */ bool is_uint() const { - return code() == kDLUInt; + return code() == DataType::kUInt; } /*! \return whether type is a handle type. */ bool is_handle() const { - return code() == kHandle; + return code() == DataType::kHandle; } /*! \return whether type is a vector type. */ bool is_vector() const { @@ -120,107 +132,93 @@ class DataType { DataType element_of() const { return with_lanes(1); } - // operator overloadings + /*! + * \brief Equal comparator. + * \param other The data type to compre against. + * \return The comparison resilt. + */ bool operator==(const DataType& other) const { return data_.code == other.data_.code && data_.bits == other.data_.bits && data_.lanes == other.data_.lanes; } + /*! + * \brief NotEqual comparator. + * \param other The data type to compre against. + * \return The comparison resilt. + */ bool operator!=(const DataType& other) const { return !operator==(other); } + /*! + * \brief Converter to DLDataType + * \return the result. + */ operator DLDataType () const { return data_; } - /*! \return the maximum possible value in this format. */ - TVM_DLL Expr max() const; - /*! \return the minimum possible value in this format. */ - TVM_DLL Expr min() const; + + /*! + * \brief Construct an int type. + * \param bits The number of bits in the type. + * \param lanes The number of lanes. + * \return The constructed data type. + */ + static DataType Int(int bits, int lanes = 1) { + return DataType(kDLInt, bits, lanes); + } + /*! + * \brief Construct an uint type. + * \param bits The number of bits in the type. + * \param lanes The number of lanes + * \return The constructed data type. + */ + static DataType UInt(int bits, int lanes = 1) { + return DataType(kDLUInt, bits, lanes); + } + /*! + * \brief Construct an uint type. + * \param bits The number of bits in the type. + * \param lanes The number of lanes + * \return The constructed data type. + */ + static DataType Float(int bits, int lanes = 1) { + return DataType(kDLFloat, bits, lanes); + } + /*! + * \brief Construct a bool type. + * \param lanes The number of lanes + * \return The constructed data type. + */ + static DataType Bool(int lanes = 1) { + return DataType::UInt(1, lanes); + } + /*! + * \brief Construct a handle type. + * \param bits The number of bits in the type. + * \param lanes The number of lanes + * \return The constructed data type. + */ + static DataType Handle(int bits = 64, int lanes = 1) { + return DataType(kHandle, bits, lanes); + } + /*! + * \brief Get the corresponding type of TVMShapeIndex. + * \return The type of TVM shape index. + */ + static DataType ShapeIndex() { + if (std::is_signed::value) { + return DataType::Int(sizeof(tvm_index_t) * 8); + } else { + return DataType::UInt(sizeof(tvm_index_t) * 8); + } + } private: DLDataType data_; }; -/*! - * \brief Construct an int type. - * \param bits The number of bits in the type. - * \param lanes The number of lanes. - * \return The constructed data type. - */ -inline DataType Int(int bits, int lanes = 1) { - return DataType(kDLInt, bits, lanes); -} - -/*! - * \brief Construct an uint type. - * \param bits The number of bits in the type. - * \param lanes The number of lanes - * \return The constructed data type. - */ -inline DataType UInt(int bits, int lanes = 1) { - return DataType(kDLUInt, bits, lanes); -} - -/*! - * \brief Construct a bool type. - * \param lanes The number of lanes - * \return The constructed data type. - */ -inline DataType Bool(int lanes = 1) { - return UInt(1, lanes); -} - -/*! - * \brief Construct an uint type. - * \param bits The number of bits in the type. - * \param lanes The number of lanes - * \return The constructed data type. - */ -inline DataType Float(int bits, int lanes = 1) { - return DataType(kDLFloat, bits, lanes); -} - -/*! - * \brief Construct a handle type. - * \param bits The number of bits in the type. - * \param lanes The number of lanes - * \return The constructed data type. - */ -inline DataType Handle(int bits = 64, int lanes = 1) { - return DataType(kHandle, bits, lanes); -} - -/*! - * \brief Get the corresponding type of TVMShapeIndex. - * \return The type of TVM shape index. - */ -inline DataType TVMShapeIndexType() { - if (std::is_signed::value) { - return Int(sizeof(tvm_index_t) * 8); - } else { - return UInt(sizeof(tvm_index_t) * 8); - } -} - -/*! - * \brief Convert DLDataType to DataType. - * \param t The original type. - * \return The conversion result. - */ -inline DataType TVMType2Type(DLDataType t) { - return DataType(t.code, t.bits, t.lanes); -} - -/*! - * \brief Convert DataType to DataType. - * \param t The original type. - * \return The conversion result. - */ -inline DLDataType Type2TVMType(DataType t) { - return t.operator DLDataType(); -} - /*! * \brief Get the number of bytes needed in a vector. * \param dtype The data type. @@ -229,19 +227,15 @@ inline DLDataType Type2TVMType(DataType t) { inline int GetVectorBytes(DataType dtype) { int data_bits = dtype.bits() * dtype.lanes(); // allow bool to exist - if (dtype == Bool()) return 1; + if (dtype == DataType::Bool()) return 1; CHECK_EQ(data_bits % 8, 0U) << "Need to load/store by multiple of bytes"; return data_bits / 8; } -// Overload print function. -inline std::ostream& operator<<(std::ostream& os, DataType dtype) { // NOLINT(*) - using namespace tvm::runtime; - return os << dtype.operator DLDataType(); -} +} // namespace runtime + +using DataType = runtime::DataType; -// Backward compatibility -using Type = DataType; } // namespace tvm -#endif // TVM_DTYPE_H_ +#endif // TVM_RUNTIME_DATA_TYPE_H_ diff --git a/include/tvm/runtime/packed_func.h b/include/tvm/runtime/packed_func.h index 57c4291907c0..1d7db66ec570 100644 --- a/include/tvm/runtime/packed_func.h +++ b/include/tvm/runtime/packed_func.h @@ -28,6 +28,11 @@ #include #endif #include +#include +#include +#include +#include +#include #include #include #include @@ -36,10 +41,7 @@ #include #include #include -#include "c_runtime_api.h" -#include "module.h" -#include "ndarray.h" -#include "object.h" + // Whether use TVM runtime in header only mode. #ifndef TVM_RUNTIME_HEADER_ONLY @@ -49,7 +51,6 @@ namespace tvm { // forward declarations class Integer; -class DataType; class Expr; namespace runtime { @@ -629,7 +630,7 @@ class TVMArgValue : public TVMPODValue_ { typename = typename std::enable_if< std::is_class::value>::type> inline operator T() const; - inline operator tvm::DataType() const; + inline operator DataType() const; inline operator tvm::Expr() const; inline operator tvm::Integer() const; }; @@ -834,8 +835,8 @@ class TVMRetValue : public TVMPODValue_ { template inline TObjectRef AsObjectRef() const; // type related - inline operator tvm::DataType() const; - inline TVMRetValue& operator=(const tvm::DataType& other); + inline operator DataType() const; + inline TVMRetValue& operator=(const DataType& other); private: template @@ -1048,6 +1049,10 @@ inline TVMType String2TVMType(std::string s) { return t; } +inline std::ostream& operator<<(std::ostream& os, const DataType& dtype) { // NOLINT(*) + return os << dtype.operator DLDataType(); +} + inline TVMArgValue TVMArgs::operator[](int i) const { CHECK_LT(i, num_args) << "not enough argument passed, " @@ -1198,7 +1203,7 @@ class TVMArgsSetter { typename = typename std::enable_if< extension_type_info::code != 0>::type> inline void operator()(size_t i, const T& value) const; - inline void operator()(size_t i, const tvm::DataType& t) const; + inline void operator()(size_t i, const DataType& t) const; private: /*! \brief The values fields */ @@ -1362,6 +1367,24 @@ inline void TVMArgsSetter::operator()(size_t i, const T& value) const { values_[i].v_handle = const_cast(&value); } +// PackedFunc support +inline TVMRetValue& TVMRetValue::operator=(const DataType& t) { + return this->operator=(t.operator DLDataType()); +} + +inline TVMRetValue::operator DataType() const { + return DataType(operator DLDataType()); +} + +inline TVMArgValue::operator DataType() const { + return DataType(operator DLDataType()); +} + +inline void TVMArgsSetter::operator()( + size_t i, const DataType& t) const { + this->operator()(i, t.operator DLDataType()); +} + // extension type handling template struct ExtTypeInfo { diff --git a/include/tvm/tensor.h b/include/tvm/tensor.h index 599d6ff657d1..f44498a0aa7a 100644 --- a/include/tvm/tensor.h +++ b/include/tvm/tensor.h @@ -163,7 +163,7 @@ class TensorNode : public Node { /*! \brief The shape of the tensor */ Array shape; /*! \brief data type in the content of the tensor */ - Type dtype; + DataType dtype; /*! \brief the source operation, can be None */ Operation op; /*! \brief the output index from source operation */ @@ -178,7 +178,7 @@ class TensorNode : public Node { v->Visit("value_index", &value_index); } TVM_DLL static Tensor make(Array shape, - Type dtype, + DataType dtype, Operation op, int value_index); diff --git a/nnvm/include/nnvm/compiler/util.h b/nnvm/include/nnvm/compiler/util.h index f108ff131d66..63d065576213 100644 --- a/nnvm/include/nnvm/compiler/util.h +++ b/nnvm/include/nnvm/compiler/util.h @@ -41,7 +41,7 @@ namespace compiler { inline tvm::Array ShapeToArray(TShape shape) { tvm::Array result; for (auto i : shape) { - result.push_back(tvm::make_const(tvm::Int(32), i)); + result.push_back(tvm::make_const(tvm::DataType::Int(32), i)); } return result; } diff --git a/nnvm/src/compiler/alter_op_layout.cc b/nnvm/src/compiler/alter_op_layout.cc index abc0022c2a79..8a6694f166d4 100644 --- a/nnvm/src/compiler/alter_op_layout.cc +++ b/nnvm/src/compiler/alter_op_layout.cc @@ -46,7 +46,7 @@ tvm::Array GetTensorInfo(const IndexedGraph& idx_graph, tvm::Array shape; for (int64_t x : shape_vec[idx_graph.entry_id(nid, i)]) { CHECK_LE(x, static_cast(std::numeric_limits::max())); - shape.push_back(tvm::make_const(tvm::Int(32), x)); + shape.push_back(tvm::make_const(tvm::DataType::Int(32), x)); } vec.push_back(tvm::placeholder( shape, GetTVMType(dtype_vec[idx_graph.entry_id(nid, i)]))); diff --git a/nnvm/src/compiler/compile_engine.cc b/nnvm/src/compiler/compile_engine.cc index 67af852ab393..82d8ff31612e 100644 --- a/nnvm/src/compiler/compile_engine.cc +++ b/nnvm/src/compiler/compile_engine.cc @@ -47,52 +47,52 @@ using namespace tvm; * \param type the tvm type. * \return corresponding DLDataType */ -int GetTypeFlag(tvm::Type type) { - if (type == tvm::Float(32)) return 0; - if (type == tvm::Float(64)) return 1; - if (type == tvm::Float(16)) return 2; - if (type == tvm::UInt(8)) return 3; - if (type == tvm::Int(32)) return 4; - if (type == tvm::Int(8)) return 5; - if (type == tvm::Int(64)) return 6; - if (type == tvm::Int(16)) return 7; - if (type == tvm::UInt(16)) return 8; - if (type == tvm::UInt(32)) return 9; - if (type == tvm::UInt(64)) return 10; - if (type == tvm::UInt(1)) return 11; +int GetTypeFlag(tvm::DataType type) { + if (type == tvm::DataType::Float(32)) return 0; + if (type == tvm::DataType::Float(64)) return 1; + if (type == tvm::DataType::Float(16)) return 2; + if (type == tvm::DataType::UInt(8)) return 3; + if (type == tvm::DataType::Int(32)) return 4; + if (type == tvm::DataType::Int(8)) return 5; + if (type == tvm::DataType::Int(64)) return 6; + if (type == tvm::DataType::Int(16)) return 7; + if (type == tvm::DataType::UInt(16)) return 8; + if (type == tvm::DataType::UInt(32)) return 9; + if (type == tvm::DataType::UInt(64)) return 10; + if (type == tvm::DataType::UInt(1)) return 11; LOG(FATAL) << "cannot convert " << type; return 0; } // convert from type flag to tvm type. -Type GetTVMType(int type_flag) { +DataType GetTVMType(int type_flag) { switch (type_flag) { case 0: - return tvm::Float(32); + return tvm::DataType::Float(32); case 1: - return tvm::Float(64); + return tvm::DataType::Float(64); case 2: - return tvm::Float(16); + return tvm::DataType::Float(16); case 3: - return tvm::UInt(8); + return tvm::DataType::UInt(8); case 4: - return tvm::Int(32); + return tvm::DataType::Int(32); case 5: - return tvm::Int(8); + return tvm::DataType::Int(8); case 6: - return tvm::Int(64); + return tvm::DataType::Int(64); case 7: - return tvm::Int(16); + return tvm::DataType::Int(16); case 8: - return tvm::UInt(16); + return tvm::DataType::UInt(16); case 9: - return tvm::UInt(32); + return tvm::DataType::UInt(32); case 10: - return tvm::UInt(64); + return tvm::DataType::UInt(64); case 11: - return tvm::UInt(1); + return tvm::DataType::UInt(1); default: LOG(FATAL) << "unknown type_flag=" << type_flag; - return Float(32); + return DataType::Float(32); } } @@ -216,7 +216,7 @@ class CompileEngine { Array shape; for (int64_t x : shape_vec[idx.entry_id(nid, i)]) { CHECK_LE(x, static_cast(std::numeric_limits::max())); - shape.push_back(make_const(Int(32), x)); + shape.push_back(make_const(DataType::Int(32), x)); } out_info.push_back( placeholder(shape, diff --git a/nnvm/src/compiler/compile_engine.h b/nnvm/src/compiler/compile_engine.h index 8151f6ced478..b4fec104bbcb 100644 --- a/nnvm/src/compiler/compile_engine.h +++ b/nnvm/src/compiler/compile_engine.h @@ -117,7 +117,7 @@ GraphFunc GraphLower(Graph graph, * \param type the tvm type * \return corresponding DLDataType */ -int GetTypeFlag(tvm::Type type); +int GetTypeFlag(tvm::DataType type); /*! * \brief Get TVM Type from type flag @@ -125,7 +125,7 @@ int GetTypeFlag(tvm::Type type); * \param type_flag the type flag * \return corresponding TVM type */ -tvm::Type GetTVMType(int type_flag); +tvm::DataType GetTVMType(int type_flag); } // namespace compiler } // namespace nnvm diff --git a/nnvm/src/compiler/graph_fuse.cc b/nnvm/src/compiler/graph_fuse.cc index 1b4a8e117555..f6c1332dd79c 100644 --- a/nnvm/src/compiler/graph_fuse.cc +++ b/nnvm/src/compiler/graph_fuse.cc @@ -352,17 +352,17 @@ nnvm::Graph GraphFuse(nnvm::Graph g) { prod *= x; } CHECK_LE(prod, static_cast(std::numeric_limits::max())); - shape.push_back(make_const(Int(32), prod)); + shape.push_back(make_const(DataType::Int(32), prod)); } else { for (int64_t x : shape_vec[idx.entry_id(e)]) { CHECK_LE(x, static_cast(std::numeric_limits::max())); - shape.push_back(make_const(Int(32), x)); + shape.push_back(make_const(DataType::Int(32), x)); } } std::ostringstream os_name; os_name << "input" << fe.imap.size(); Tensor data = placeholder( - shape, TVMType2Type(GetDLType(dtype_vec[idx.entry_id(e)])), + shape, DataType(GetDLType(dtype_vec[idx.entry_id(e)])), os_name.str()); NodeEntry garg = Symbol::CreateVariable(os_name.str()).outputs[0]; fe.imap[e] = garg; diff --git a/nnvm/src/compiler/graph_fuse.h b/nnvm/src/compiler/graph_fuse.h index ce7da828e301..dd8d5d57f66a 100644 --- a/nnvm/src/compiler/graph_fuse.h +++ b/nnvm/src/compiler/graph_fuse.h @@ -47,7 +47,7 @@ enum class FuseRule { * \return corresponding DLDataType */ inline DLDataType GetDLType(int type_flag) { - return tvm::Type2TVMType(GetTVMType(type_flag)); + return GetTVMType(type_flag); } struct INodeEntryHash { diff --git a/nnvm/src/top/nn/nn.cc b/nnvm/src/top/nn/nn.cc index 63b2d45b18dc..1864ccd3506f 100644 --- a/nnvm/src/top/nn/nn.cc +++ b/nnvm/src/top/nn/nn.cc @@ -631,11 +631,11 @@ NNVM_REGISTER_OP(pad) << "Illegal pad_width"; Array pad_before; for (size_t i = 0; i < pad_width.ndim(); ++i) { - pad_before.push_back(tvm::make_const(tvm::Int(32), pad_width[i][0])); + pad_before.push_back(tvm::make_const(tvm::DataType::Int(32), pad_width[i][0])); } Array pad_after; for (size_t i = 0; i < pad_width.ndim(); ++i) { - pad_after.push_back(tvm::make_const(tvm::Int(32), pad_width[i][1])); + pad_after.push_back(tvm::make_const(tvm::DataType::Int(32), pad_width[i][1])); } return Array{ topi::pad(inputs[0], pad_before, pad_after, tvm::make_const(inputs[0]->dtype, param.pad_value)) }; diff --git a/nnvm/src/top/tensor/elemwise.cc b/nnvm/src/top/tensor/elemwise.cc index 7a79db041755..5ac6d91dc141 100644 --- a/nnvm/src/top/tensor/elemwise.cc +++ b/nnvm/src/top/tensor/elemwise.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -482,7 +482,7 @@ NNVM_REGISTER_INIT_OP(full) const Array& out_info) { const InitOpWithScalarParam& param = nnvm::get(attrs.parsed); Array shape = ShapeToArray(param.shape); - Type dtype = GetTVMType(param.dtype); + DataType dtype = GetTVMType(param.dtype); Expr fill_value = tvm::make_const(dtype, param.fill_value); return Array{ topi::full(shape, dtype, fill_value) }; }) @@ -505,7 +505,7 @@ NNVM_REGISTER_INIT_OP(zeros) const Array& out_info) { const InitOpParam& param = nnvm::get(attrs.parsed); Array shape = ShapeToArray(param.shape); - Type dtype = GetTVMType(param.dtype); + DataType dtype = GetTVMType(param.dtype); Expr fill_value = tvm::make_const(dtype, 0); return Array{ topi::full(shape, dtype, fill_value) }; }) @@ -528,7 +528,7 @@ NNVM_REGISTER_INIT_OP(ones) const Array& out_info) { const InitOpParam& param = nnvm::get(attrs.parsed); Array shape = ShapeToArray(param.shape); - Type dtype = GetTVMType(param.dtype); + DataType dtype = GetTVMType(param.dtype); Expr fill_value = tvm::make_const(dtype, 1); return Array{ topi::full(shape, dtype, fill_value) }; }) @@ -950,8 +950,8 @@ Example:: const Array& out_info) { const ClipParam params = get(attrs.parsed); return Array{ - topi::clip(inputs[0], tvm::make_const(tvm::Float(32), params.a_min), - tvm::make_const(tvm::Float(32), params.a_max)) }; + topi::clip(inputs[0], tvm::make_const(tvm::DataType::Float(32), params.a_min), + tvm::make_const(tvm::DataType::Float(32), params.a_max)) }; }) .add_argument("data", "NDArray-or-Symbol", "Input array.") .add_arguments(ClipParam::__FIELDS__()) diff --git a/nnvm/src/top/tensor/transform.cc b/nnvm/src/top/tensor/transform.cc index 8b85c4e31ad2..a83f447a60f4 100644 --- a/nnvm/src/top/tensor/transform.cc +++ b/nnvm/src/top/tensor/transform.cc @@ -477,7 +477,7 @@ NNVM_REGISTER_OP(cast) const Array& inputs, const Array& out_info) { const CastParam& param = nnvm::get(attrs.parsed); - Type dtype = GetTVMType(param.dtype); + DataType dtype = GetTVMType(param.dtype); return Array{ topi::cast(inputs[0], dtype) }; }) .set_support_level(1); @@ -1261,8 +1261,8 @@ NNVM_REGISTER_OP(slice_like) Array target_shape = inputs[1]->shape; Array begin_idx, end_idx, strides; for (size_t i = 0; i < src_shape.size(); ++i) { - begin_idx.push_back(make_const(tvm::Int(32), 0)); - strides.push_back(make_const(tvm::Int(32), 1)); + begin_idx.push_back(make_const(tvm::DataType::Int(32), 0)); + strides.push_back(make_const(tvm::DataType::Int(32), 1)); } end_idx = Array(src_shape); if (param.axis.ndim() == 0) { diff --git a/src/api/api_ir.cc b/src/api/api_ir.cc index 9312c5532302..03f37b171782 100644 --- a/src/api/api_ir.cc +++ b/src/api/api_ir.cc @@ -30,7 +30,7 @@ namespace tvm { namespace ir { TVM_REGISTER_API("_Var") -.set_body_typed([](std::string s, Type t) { +.set_body_typed([](std::string s, DataType t) { return Variable::make(t, s); }); @@ -75,7 +75,7 @@ TVM_REGISTER_API("make.For") TVM_REGISTER_API("make.Load") .set_body([](TVMArgs args, TVMRetValue *ret) { - Type t = args[0]; + DataType t = args[0]; if (args.size() == 3) { *ret = Load::make(t, args[1], args[2], const_true(t.lanes())); } else { @@ -87,7 +87,7 @@ TVM_REGISTER_API("make.Store") .set_body([](TVMArgs args, TVMRetValue *ret) { Expr value = args[1]; if (args.size() == 3) { - *ret = Store::make(args[0], value, args[2], const_true(value.type().lanes())); + *ret = Store::make(args[0], value, args[2], const_true(value.dtype().lanes())); } else { *ret = Store::make(args[0], value, args[2], args[3]); } @@ -97,8 +97,8 @@ TVM_REGISTER_API("make.Realize") .set_body_typed(Realize::make); TVM_REGISTER_API("make.Call") -.set_body_typed, int, FunctionRef, int)>([]( - Type type, std::string name, +.set_body_typed, int, FunctionRef, int)>([]( + DataType type, std::string name, Array args, int call_type, FunctionRef func, int value_index ) { @@ -166,8 +166,8 @@ TVM_REGISTER_API("make.Block") // has default args TVM_REGISTER_API("make.Allocate") - .set_body_typed, Expr, Stmt)>([]( - VarExpr buffer_var, Type type, Array extents, Expr condition, Stmt body + .set_body_typed, Expr, Stmt)>([]( + VarExpr buffer_var, DataType type, Array extents, Expr condition, Stmt body ){ return Allocate::make(buffer_var, type, extents, condition, body); }); diff --git a/src/api/api_lang.cc b/src/api/api_lang.cc index f3d6c5f6ab62..9cb797fa45e4 100644 --- a/src/api/api_lang.cc +++ b/src/api/api_lang.cc @@ -35,10 +35,10 @@ namespace tvm { TVM_REGISTER_API("_min_value") -.set_body_method(&DataType::min); +.set_body_typed(min_value); TVM_REGISTER_API("_max_value") -.set_body_method(&DataType::max); +.set_body_typed(max_value); TVM_REGISTER_API("_const") .set_body([](TVMArgs args, TVMRetValue* ret) { @@ -287,8 +287,8 @@ TVM_REGISTER_API("_TensorHash") }); TVM_REGISTER_API("_Placeholder") -.set_body_typed, Type, std::string)>([]( - Array shape, Type dtype, std::string name +.set_body_typed, DataType, std::string)>([]( + Array shape, DataType dtype, std::string name ) { return placeholder(shape, dtype, name); }); diff --git a/src/arithmetic/bound_deducer.cc b/src/arithmetic/bound_deducer.cc index 31fedcc72cde..19f045241915 100644 --- a/src/arithmetic/bound_deducer.cc +++ b/src/arithmetic/bound_deducer.cc @@ -132,7 +132,7 @@ class BoundDeducer: public IRVisitor { Expr target_var = left ? op->a : op->b; SignType sign_operand; - if (operand.type().is_uint()) { + if (operand.dtype().is_uint()) { sign_operand = kPositive; } else { sign_operand = expr_map_[operand].sign_type(); diff --git a/src/arithmetic/canonical_simplify.cc b/src/arithmetic/canonical_simplify.cc index 1b576a645824..022dd8e94dbb 100644 --- a/src/arithmetic/canonical_simplify.cc +++ b/src/arithmetic/canonical_simplify.cc @@ -115,7 +115,7 @@ class SplitExprNode : public CanonicalExprNode { Expr NormalizeWithScale(int64_t sscale) const { Expr res = this->index; - Type dtype = this->type; + DataType dtype = this->dtype; if (this->scale == 0) { return make_const(dtype, 0); } @@ -190,9 +190,9 @@ class SumExprNode : public CanonicalExprNode { Expr Normalize() const final { // quick path 1. if (this->args.size() == 0) { - return make_const(this->type, this->base); + return make_const(this->dtype, this->base); } - return Normalize_(this->type, + return Normalize_(this->dtype, SimplifySplitExprs(args), base); } @@ -379,7 +379,7 @@ class SumExprNode : public CanonicalExprNode { std::stable_sort(args.begin(), args.end(), fcompare); return args; } - static Expr Normalize_(Type dtype, + static Expr Normalize_(DataType dtype, const std::vector& args, int64_t base) { // Positive scales first @@ -508,7 +508,7 @@ class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl { expr = op->Normalize(); } NodePtr n = make_node(); - n->type = expr.type(); + n->dtype = expr.dtype(); n->index = std::move(expr); n->div_mode = kTruncDiv; return SplitExpr(n); @@ -545,7 +545,7 @@ class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl { return GetRef(op); } NodePtr n = make_node(); - n->type = expr.type(); + n->dtype = expr.dtype(); if (const auto* op = expr.as()) { n->base = op->value; return SumExpr(n); @@ -560,7 +560,7 @@ class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl { Expr CanonicalSimplifier::Impl:: Mutate_(const Add* op, const Expr& self) { - if (!IsIndexType(op->type)) { + if (!IsIndexType(op->dtype)) { return Rewriter::Mutate_(op, self); } // normalize @@ -586,7 +586,7 @@ Mutate_(const Add* op, const Expr& self) { Expr CanonicalSimplifier::Impl:: Mutate_(const Sub* op, const Expr& self) { - if (!IsIndexType(op->type)) { + if (!IsIndexType(op->dtype)) { return Rewriter::Mutate_(op, self); } // normalize @@ -613,7 +613,7 @@ Mutate_(const Sub* op, const Expr& self) { Expr CanonicalSimplifier::Impl:: Mutate_(const Mul* op, const Expr& self) { - if (!IsIndexType(op->type)) { + if (!IsIndexType(op->dtype)) { return Rewriter::Mutate_(op, self); } // normalize @@ -657,8 +657,8 @@ SeparateDivisibleParts(const SumExprNode* psum, SumExpr* out_non_divisible) { auto divisible = make_node(); auto non_divisible = make_node(); - divisible->type = psum->type; - non_divisible->type = psum->type; + divisible->dtype = psum->dtype; + non_divisible->dtype = psum->dtype; if (psum->base % coeff == 0) { divisible->base = psum->base; @@ -698,11 +698,11 @@ SplitDivConst(SplitExpr lhs, int64_t cval, DivMode div_mode) { return lhs; } else if (lhs->upper_factor <= (lhs->lower_factor * scaled_cval)) { // (x % c1) / c2 => 0 when c2 >= c1 - return ToSplitExpr(make_zero(lhs.type())); + return ToSplitExpr(make_zero(lhs.dtype())); } else { // move the upper_factor modular into index. lhs.CopyOnWrite()->index = - ModImpl(lhs->index, make_const(lhs.type(), lhs->upper_factor), div_mode); + ModImpl(lhs->index, make_const(lhs.dtype(), lhs->upper_factor), div_mode); lhs.CopyOnWrite()->upper_factor = SplitExprNode::kPosInf; lhs.CopyOnWrite()->scale = 1; lhs.CopyOnWrite()->lower_factor *= scaled_cval; @@ -720,7 +720,7 @@ SplitDivConst(SplitExpr lhs, int64_t cval, DivMode div_mode) { Expr CanonicalSimplifier::Impl:: Mutate_(const Div* op, const Expr& self) { - if (!IsIndexType(op->type)) { + if (!IsIndexType(op->dtype)) { return Rewriter::Mutate_(op, self); } @@ -764,7 +764,7 @@ Mutate_(const Div* op, const Expr& self) { // if a >= 0 && a < cval, then result == 0 auto cbound = analyzer_->const_int_bound(Normalize(a)); if (cbound->min_value >= 0 && cbound->max_value < cval) { - return make_zero(a.type()); + return make_zero(a.dtype()); } } return SplitDivConst(ToSplitExpr(std::move(a)), cval, kTruncDiv); @@ -781,7 +781,7 @@ Mutate_(const Div* op, const Expr& self) { Expr CanonicalSimplifier::Impl:: Mutate_(const FloorDiv* op, const Expr& self) { - if (!IsIndexType(op->type)) { + if (!IsIndexType(op->dtype)) { return Rewriter::Mutate_(op, self); } Expr a = this->CanonicalMutate(op->a); @@ -820,7 +820,7 @@ Mutate_(const FloorDiv* op, const Expr& self) { // if a >= 0 && a < cval, then result == 0 auto cbound = analyzer_->const_int_bound(Normalize(a)); if (cbound->min_value >= 0 && cbound->max_value < cval) { - return make_zero(a.type()); + return make_zero(a.dtype()); } } return SplitDivConst(ToSplitExpr(std::move(a)), cval, kFloorDiv); @@ -859,7 +859,7 @@ SplitModConst(SplitExpr lhs, int64_t cval, DivMode div_mode) { if (new_upper_factor < lhs->upper_factor && lhs->upper_factor != SplitExprNode::kPosInf) { auto updated = ToSplitExpr(Mutate(ModImpl( - lhs->index, make_const(lhs.type(), new_upper_factor), div_mode))); + lhs->index, make_const(lhs.dtype(), new_upper_factor), div_mode))); // re-apply the lower_factor if (lhs->lower_factor != 1) { return SplitDivConst(updated, lhs->lower_factor, div_mode); @@ -887,7 +887,7 @@ SplitModConst(SplitExpr lhs, int64_t cval, DivMode div_mode) { Expr CanonicalSimplifier::Impl:: Mutate_(const Mod* op, const Expr& self) { - if (!IsIndexType(op->type)) { + if (!IsIndexType(op->dtype)) { return Rewriter::Mutate_(op, self); } // normalize @@ -906,7 +906,7 @@ Mutate_(const Mod* op, const Expr& self) { SumExpr lhs, extra; SeparateDivisibleParts(psum, cval, &lhs, &extra); if (extra->IsZero()) { - return make_zero(a.type()); + return make_zero(a.dtype()); } // both lhs and extra are non-negative if (analyzer_->CanProveGreaterEqual(lhs->Normalize(), 0) && @@ -957,7 +957,7 @@ Mutate_(const Mod* op, const Expr& self) { Expr CanonicalSimplifier::Impl:: Mutate_(const FloorMod* op, const Expr& self) { - if (!IsIndexType(op->type)) { + if (!IsIndexType(op->dtype)) { return Rewriter::Mutate_(op, self); } // normalize diff --git a/src/arithmetic/compute_expr.h b/src/arithmetic/compute_expr.h index 4b001cfb8610..806587ab75aa 100644 --- a/src/arithmetic/compute_expr.h +++ b/src/arithmetic/compute_expr.h @@ -56,7 +56,7 @@ inline Expr ComputeReduce( const Array& values, Expr empty_value); inline bool GetConst(Expr e, int64_t* out) { - if (e.type().is_vector()) return false; + if (e.dtype().is_vector()) return false; const int64_t* v = as_const_int(e); if (v) { *out = *v; return true; diff --git a/src/arithmetic/const_fold.h b/src/arithmetic/const_fold.h index 86f1927f2abe..93bf708a113f 100644 --- a/src/arithmetic/const_fold.h +++ b/src/arithmetic/const_fold.h @@ -70,7 +70,7 @@ inline Expr TryConstFold(Expr a); * \param type The type to represent index. * \return the checked result. */ -inline bool IsIndexType(const Type& type) { +inline bool IsIndexType(const DataType& type) { return type.is_int() && type.lanes() == 1 && (type.bits() == 32 || type.bits() == 64); } @@ -92,8 +92,8 @@ inline bool IsIndexType(const Type& type) { using ir::UIntImm; \ const IntImm* pa = a.as(); \ const IntImm* pb = b.as(); \ - const Type& ta = a.type(); \ - const Type& tb = b.type(); \ + const DataType& ta = a.dtype(); \ + const DataType& tb = b.dtype(); \ if (arith::IsIndexType(ta) && arith::IsIndexType(tb)) { \ BODY; \ } \ @@ -103,7 +103,7 @@ inline bool IsIndexType(const Type& type) { template<> inline Expr TryConstFold(Expr a, Expr b) { TVM_ARITH_CONST_PROPAGATION({ - const Type& rtype = a.type(); + const DataType& rtype = a.dtype(); if (pa && pb) return IntImm::make(rtype, pa->value + pb->value); if (pa && pa->value == 0) return b; if (pb && pb->value == 0) return a; @@ -117,7 +117,7 @@ inline Expr TryConstFold(Expr a, Expr b) { template<> inline Expr TryConstFold(Expr a, Expr b) { TVM_ARITH_CONST_PROPAGATION({ - const Type& rtype = a.type(); + const DataType& rtype = a.dtype(); if (pa && pb) return IntImm::make(rtype, pa->value - pb->value); if (pb && pb->value == 0) return a; if (fa && fb) return FloatImm::make(rtype, fa->value - fb->value); @@ -129,7 +129,7 @@ inline Expr TryConstFold(Expr a, Expr b) { template<> inline Expr TryConstFold(Expr a, Expr b) { TVM_ARITH_CONST_PROPAGATION({ - const Type& rtype = a.type(); + const DataType& rtype = a.dtype(); if (pa && pb) return IntImm::make(rtype, pa->value * pb->value); if (pa) { if (pa->value == 1) return b; @@ -155,7 +155,7 @@ inline Expr TryConstFold(Expr a, Expr b) { template<> inline Expr TryConstFold(Expr a, Expr b) { TVM_ARITH_CONST_PROPAGATION({ - const Type& rtype = a.type(); + const DataType& rtype = a.dtype(); if (pa && pb) { // due to division and mod can have different modes // NOTE: this will assumes truc div. @@ -184,7 +184,7 @@ inline Expr TryConstFold(Expr a, Expr b) { template<> inline Expr TryConstFold(Expr a, Expr b) { TVM_INDEX_CONST_PROPAGATION({ - const Type& rtype = a.type(); + const DataType& rtype = a.dtype(); if (pa && pb) { return IntImm::make(rtype, pa->value % pb->value); } @@ -202,7 +202,7 @@ inline Expr TryConstFold(Expr a, Expr b) { template<> inline Expr TryConstFold(Expr a, Expr b) { TVM_ARITH_CONST_PROPAGATION({ - const Type& rtype = a.type(); + const DataType& rtype = a.dtype(); if (pa && pb) { CHECK_NE(pb->value, 0) << "Divide by zero"; return IntImm::make(rtype, arith::floordiv(pa->value, pb->value)); @@ -229,7 +229,7 @@ inline Expr TryConstFold(Expr a, Expr b) { template<> inline Expr TryConstFold(Expr a, Expr b) { TVM_INDEX_CONST_PROPAGATION({ - const Type& rtype = a.type(); + const DataType& rtype = a.dtype(); if (pa && pb) { return IntImm::make(rtype, arith::floormod(pa->value, pb->value)); } @@ -247,7 +247,7 @@ inline Expr TryConstFold(Expr a, Expr b) { template<> inline Expr TryConstFold(Expr a, Expr b) { TVM_ARITH_CONST_PROPAGATION({ - const Type& rtype = a.type(); + const DataType& rtype = a.dtype(); if (pa && pb) return IntImm::make(rtype, std::min(pa->value, pb->value)); if (fa && fb) return FloatImm::make(rtype, std::min(fa->value, fb->value)); }); @@ -258,7 +258,7 @@ inline Expr TryConstFold(Expr a, Expr b) { template<> inline Expr TryConstFold(Expr a, Expr b) { TVM_ARITH_CONST_PROPAGATION({ - const Type& rtype = a.type(); + const DataType& rtype = a.dtype(); if (pa && pb) return IntImm::make(rtype, std::max(pa->value, pb->value)); if (fa && fb) return FloatImm::make(rtype, std::max(fa->value, fb->value)); }); @@ -269,8 +269,8 @@ inline Expr TryConstFold(Expr a, Expr b) { template<> inline Expr TryConstFold(Expr a, Expr b) { TVM_ARITH_CONST_PROPAGATION({ - if (pa && pb) return UIntImm::make(UInt(1), pa->value > pb->value); - if (fa && fb) return UIntImm::make(UInt(1), fa->value > fb->value); + if (pa && pb) return UIntImm::make(DataType::UInt(1), pa->value > pb->value); + if (fa && fb) return UIntImm::make(DataType::UInt(1), fa->value > fb->value); }); return Expr(); } @@ -278,8 +278,8 @@ inline Expr TryConstFold(Expr a, Expr b) { template<> inline Expr TryConstFold(Expr a, Expr b) { TVM_ARITH_CONST_PROPAGATION({ - if (pa && pb) return UIntImm::make(UInt(1), pa->value >= pb->value); - if (fa && fb) return UIntImm::make(UInt(1), fa->value >= fb->value); + if (pa && pb) return UIntImm::make(DataType::UInt(1), pa->value >= pb->value); + if (fa && fb) return UIntImm::make(DataType::UInt(1), fa->value >= fb->value); }); return Expr(); } @@ -287,8 +287,8 @@ inline Expr TryConstFold(Expr a, Expr b) { template<> inline Expr TryConstFold(Expr a, Expr b) { TVM_ARITH_CONST_PROPAGATION({ - if (pa && pb) return UIntImm::make(UInt(1), pa->value < pb->value); - if (fa && fb) return UIntImm::make(UInt(1), fa->value < fb->value); + if (pa && pb) return UIntImm::make(DataType::UInt(1), pa->value < pb->value); + if (fa && fb) return UIntImm::make(DataType::UInt(1), fa->value < fb->value); }); return Expr(); } @@ -296,8 +296,8 @@ inline Expr TryConstFold(Expr a, Expr b) { template<> inline Expr TryConstFold(Expr a, Expr b) { TVM_ARITH_CONST_PROPAGATION({ - if (pa && pb) return UIntImm::make(UInt(1), pa->value <= pb->value); - if (fa && fb) return UIntImm::make(UInt(1), fa->value <= fb->value); + if (pa && pb) return UIntImm::make(DataType::UInt(1), pa->value <= pb->value); + if (fa && fb) return UIntImm::make(DataType::UInt(1), fa->value <= fb->value); }); return Expr(); } @@ -305,8 +305,8 @@ inline Expr TryConstFold(Expr a, Expr b) { template<> inline Expr TryConstFold(Expr a, Expr b) { TVM_ARITH_CONST_PROPAGATION({ - if (pa && pb) return UIntImm::make(UInt(1), pa->value == pb->value); - if (fa && fb) return UIntImm::make(UInt(1), fa->value == fb->value); + if (pa && pb) return UIntImm::make(DataType::UInt(1), pa->value == pb->value); + if (fa && fb) return UIntImm::make(DataType::UInt(1), fa->value == fb->value); }); return Expr(); } @@ -314,8 +314,8 @@ inline Expr TryConstFold(Expr a, Expr b) { template<> inline Expr TryConstFold(Expr a, Expr b) { TVM_ARITH_CONST_PROPAGATION({ - if (pa && pb) return UIntImm::make(UInt(1), pa->value != pb->value); - if (fa && fb) return UIntImm::make(UInt(1), fa->value != fb->value); + if (pa && pb) return UIntImm::make(DataType::UInt(1), pa->value != pb->value); + if (fa && fb) return UIntImm::make(DataType::UInt(1), fa->value != fb->value); }); return Expr(); } @@ -349,7 +349,7 @@ inline Expr TryConstFold(Expr a) { using ir::UIntImm; const UIntImm* pa = a.as(); if (pa) { - return UIntImm::make(UInt(1), !(pa->value)); + return UIntImm::make(DataType::UInt(1), !(pa->value)); } return Expr(); } diff --git a/src/arithmetic/const_int_bound.cc b/src/arithmetic/const_int_bound.cc index 6e119695a8c8..c0519107d5b8 100644 --- a/src/arithmetic/const_int_bound.cc +++ b/src/arithmetic/const_int_bound.cc @@ -125,7 +125,7 @@ class ConstIntBoundAnalyzer::Impl : // Override visitor behaviors Entry VisitExprDefault_(const Node* op) final { return Everything( - static_cast(op)->type); + static_cast(op)->dtype); } Entry VisitExpr(const Expr& expr) final { @@ -142,7 +142,7 @@ class ConstIntBoundAnalyzer::Impl : Entry VisitExpr_(const Cast* op) final { Entry a = VisitExpr(op->value); - Entry b = Everything(op->type); + Entry b = Everything(op->dtype); return Intersect(a, b); } @@ -154,7 +154,7 @@ class ConstIntBoundAnalyzer::Impl : if (op->value <= static_cast(kPosInf)) { return MakeBound(op->value, op->value); } else { - return Everything(op->type); + return Everything(op->dtype); } } @@ -211,7 +211,7 @@ class ConstIntBoundAnalyzer::Impl : CHECK(!b.is_const(0)) << "mod by zero"; // mod by negative value is rare, // and we just use the simpliest rule. - return Everything(op->type); + return Everything(op->dtype); } } @@ -242,7 +242,7 @@ class ConstIntBoundAnalyzer::Impl : CHECK(!b.is_const(0)) << "floormod by zero"; // mod by negative value is rare, // and we just use the simpliest rule. - return Everything(op->type); + return Everything(op->dtype); } } @@ -278,7 +278,7 @@ class ConstIntBoundAnalyzer::Impl : } else if (op->is_intrinsic(Call::bitwise_and)) { return VisitBitwiseAnd(op); } else { - return Everything(op->type); + return Everything(op->dtype); } } @@ -288,7 +288,7 @@ class ConstIntBoundAnalyzer::Impl : if (it != var_map_.end()) { return it->second; } else { - return Everything(op->type); + return Everything(op->dtype); } } @@ -311,7 +311,7 @@ class ConstIntBoundAnalyzer::Impl : if (a.min_value >= 0) { return MakeBound(0, a.max_value); } - return Everything(op->type); + return Everything(op->dtype); } } @@ -466,7 +466,7 @@ class ConstIntBoundAnalyzer::Impl : * \param dtype The data type. * \return Bound that represent everything dtype can represent. */ - static Entry Everything(Type dtype) { + static Entry Everything(DataType dtype) { if (!dtype.is_int() && !dtype.is_uint()) { return MakeBound(kNegInf, kPosInf); } diff --git a/src/arithmetic/detect_linear_equation.cc b/src/arithmetic/detect_linear_equation.cc index 8c7f4f2bb738..cf37545502ba 100644 --- a/src/arithmetic/detect_linear_equation.cc +++ b/src/arithmetic/detect_linear_equation.cc @@ -53,10 +53,10 @@ class LinearEqDetector *ret = VisitExpr(e, e); if (fail_) return false; if (!ret->base.defined()) { - ret->base = make_zero(var_.type()); + ret->base = make_zero(var_.dtype()); } if (!ret->coeff.defined()) { - ret->coeff = make_zero(var_.type()); + ret->coeff = make_zero(var_.dtype()); } return true; } @@ -100,7 +100,7 @@ class LinearEqDetector LinearEqEntry VisitExpr_(const Variable* op, const Expr& e) final { LinearEqEntry ret; if (op == var_.get()) { - ret.coeff = make_const(op->type, 1); + ret.coeff = make_const(op->dtype, 1); } else { ret.base = e; } @@ -190,16 +190,16 @@ bool DetectClipBound( // canonical form: exp >= 0 Expr canonical; if (const LT* op = cond.as()) { - if (!op->a.type().is_int()) return false; - canonical = op->b - op->a - make_const(op->a.type(), 1); + if (!op->a.dtype().is_int()) return false; + canonical = op->b - op->a - make_const(op->a.dtype(), 1); } else if (const LE* op = cond.as()) { - if (!op->a.type().is_int()) return false; + if (!op->a.dtype().is_int()) return false; canonical = op->b - op->a; } else if (const GT* op = cond.as()) { - if (!op->a.type().is_int()) return false; - canonical = op->a - op->b - make_const(op->a.type(), 1); + if (!op->a.dtype().is_int()) return false; + canonical = op->a - op->b - make_const(op->a.dtype(), 1); } else if (const GE* op = cond.as()) { - if (!op->a.type().is_int()) return false; + if (!op->a.dtype().is_int()) return false; canonical = op->a - op->b; } else { return false; diff --git a/src/arithmetic/domain_touched.cc b/src/arithmetic/domain_touched.cc index c28346ed2e33..947f0050c6cb 100644 --- a/src/arithmetic/domain_touched.cc +++ b/src/arithmetic/domain_touched.cc @@ -72,7 +72,7 @@ class FuncTouchedDomain final : public IRVisitor { const IterVarNode* thread_axis = op->node.as(); CHECK(thread_axis); const Variable* var = thread_axis->var.get(); - dom_map_[var] = IntSet::range(Range(make_zero(op->value.type()), op->value)); + dom_map_[var] = IntSet::range(Range(make_zero(op->value.dtype()), op->value)); IRVisitor::Visit_(op); dom_map_.erase(var); } else { diff --git a/src/arithmetic/int_set.cc b/src/arithmetic/int_set.cc index 9f8effb6c612..e4f2042a19d7 100644 --- a/src/arithmetic/int_set.cc +++ b/src/arithmetic/int_set.cc @@ -33,8 +33,8 @@ namespace tvm { namespace arith { -Expr SymbolicLimits::pos_inf_ = Var("pos_inf", Handle()); -Expr SymbolicLimits::neg_inf_ = Var("neg_inf", Handle()); +Expr SymbolicLimits::pos_inf_ = Var("pos_inf", DataType::Handle()); +Expr SymbolicLimits::neg_inf_ = Var("neg_inf", DataType::Handle()); IntervalSet::IntervalSet(Expr min_value, Expr max_value) { auto node = make_node(); @@ -54,8 +54,8 @@ TVM_REGISTER_API("arith._make_IntervalSet") IntervalSet Intersect(Analyzer* analyzer, IntervalSet a, IntervalSet b) { Expr max_value = min(a->max_value, b->max_value); Expr min_value = max(a->min_value, b->min_value); - if ((max_value.type().is_int() || max_value.type().is_uint()) && - (min_value.type().is_int() || min_value.type().is_uint()) && + if ((max_value.dtype().is_int() || max_value.dtype().is_uint()) && + (min_value.dtype().is_int() || min_value.dtype().is_uint()) && analyzer->CanProveGreaterEqual(min_value - max_value, 1)) { return IntervalSet::Empty(); } else { @@ -105,8 +105,8 @@ inline IntervalSet Combine(Analyzer* analyzer, return IntervalSet::SinglePoint(res); } if (is_logical_op::value) { - return IntervalSet(make_const(a->min_value.type(), 0), - make_const(a->min_value.type(), 1)); + return IntervalSet(make_const(a->min_value.dtype(), 0), + make_const(a->min_value.dtype(), 1)); } if (a->IsEmpty()) return a; if (b->IsEmpty()) return b; @@ -177,7 +177,7 @@ inline IntervalSet Combine(Analyzer* analyzer, return IntervalSet(min_value, max_value); } else if (a->HasUpperBound() && a->HasLowerBound()) { using ir::Select; - Expr sign = b->min_value >= make_zero(b->min_value.type().element_of()); + Expr sign = b->min_value >= make_zero(b->min_value.dtype().element_of()); Expr e1 = a->min_value * b->min_value; Expr e2 = a->max_value * b->min_value; return IntervalSet(Select::make(sign, e1, e2), Select::make(sign, e2, e1)); @@ -212,7 +212,7 @@ inline IntervalSet Combine(Analyzer* analyzer, return IntervalSet(min_value, max_value); } else if (a->HasUpperBound() && a->HasLowerBound()) { using ir::Select; - Expr sign = b->min_value >= make_zero(b->min_value.type().element_of()); + Expr sign = b->min_value >= make_zero(b->min_value.dtype().element_of()); Expr e1 = a->min_value / b->min_value; Expr e2 = a->max_value / b->min_value; return IntervalSet(Select::make(sign, e1, e2), Select::make(sign, e2, e1)); @@ -242,7 +242,7 @@ inline IntervalSet Combine(Analyzer* analyzer, // is the case of our application. // TODO(tqchen): add bound constraints for a. if (analyzer->CanProveGreaterEqual(divisor, 0)) { - return IntervalSet(make_zero(divisor.type()), divisor - 1); + return IntervalSet(make_zero(divisor.dtype()), divisor - 1); } else { Expr bound = abs(divisor) - 1; return IntervalSet(-bound, bound); @@ -278,7 +278,7 @@ inline IntervalSet Combine(Analyzer* analyzer, return IntervalSet(min_value, max_value); } else if (a->HasUpperBound() && a->HasLowerBound()) { using ir::Select; - Expr sign = b->min_value >= make_zero(b->min_value.type().element_of()); + Expr sign = b->min_value >= make_zero(b->min_value.dtype().element_of()); Expr e1 = floordiv(a->min_value, b->min_value); Expr e2 = floordiv(a->max_value, b->min_value); return IntervalSet(Select::make(sign, e1, e2), Select::make(sign, e2, e1)); @@ -304,7 +304,7 @@ inline IntervalSet Combine(Analyzer* analyzer, LOG(FATAL) << "Modular by zero in CombineInterval Mod"; } if (analyzer->CanProveGreaterEqual(divisor, 0)) { - return IntervalSet(make_zero(divisor.type()), divisor - 1); + return IntervalSet(make_zero(divisor.dtype()), divisor - 1); } else { Expr bound = abs(divisor) - 1; return IntervalSet(-bound, bound); @@ -476,7 +476,7 @@ class IntervalSetEvaluator : IntervalSet base = Eval(op->base); PVar stride; if (stride.Match(op->stride)) { - Type t = op->base.type(); + DataType t = op->base.dtype(); int64_t vstride = stride.Eval()->value; if (vstride> 0) { return Combine( diff --git a/src/arithmetic/ir_mutator_with_analyzer.cc b/src/arithmetic/ir_mutator_with_analyzer.cc index cda9d585ace1..0d4b8f26b18b 100644 --- a/src/arithmetic/ir_mutator_with_analyzer.cc +++ b/src/arithmetic/ir_mutator_with_analyzer.cc @@ -140,7 +140,7 @@ Mutate_(const Call* op, const Expr& self) { false_value.same_as(op->args[2])) { return self; } else { - return Call::make(op->type, op->name, + return Call::make(op->dtype, op->name, {cond, true_value, false_value}, op->call_type); } diff --git a/src/arithmetic/pattern_match.h b/src/arithmetic/pattern_match.h index f7d5483cf6de..fd07a377e955 100644 --- a/src/arithmetic/pattern_match.h +++ b/src/arithmetic/pattern_match.h @@ -291,7 +291,7 @@ class PConstWithTypeLike : } Expr Eval() const { - return make_const(ref_.Eval().type(), value_); + return make_const(ref_.Eval().dtype(), value_); } private: @@ -474,7 +474,7 @@ class PCastExpr : bool Match_(const NodeRef& node) const { if (const ir::Cast* ptr = node.as()) { - if (!dtype_.Match_(ptr->type)) return false; + if (!dtype_.Match_(ptr->dtype)) return false; if (!value_.Match_(ptr->value)) return false; return true; } else { @@ -730,7 +730,7 @@ class PCallExpr : #define TVM_PATTERN_BINARY_INTRIN(FuncName, OpName, IntrinStr) \ struct OpName { \ static Expr Eval(Array args) { \ - return ir::Call::make(args[0].type(), kName, args, \ + return ir::Call::make(args[0].dtype(), kName, args, \ ir::Call::PureIntrinsic); \ } \ static constexpr const char* kName = IntrinStr; \ @@ -751,7 +751,7 @@ TVM_PATTERN_BINARY_INTRIN(operator^, PBitwiseXorOp, "bitwise_xor"); #define TVM_PATTERN_UNARY_INTRIN(FuncName, OpName, IntrinStr) \ struct OpName { \ static Expr Eval(Array args) { \ - return ir::Call::make(args[0].type(), kName, args, \ + return ir::Call::make(args[0].dtype(), kName, args, \ ir::Call::PureIntrinsic); \ } \ static constexpr const char* kName = IntrinStr; \ @@ -768,7 +768,7 @@ TVM_PATTERN_UNARY_INTRIN(operator~, PBitwiseNotOp, "bitwise_not"); struct PIfThenElseOp { static Expr Eval(Array args) { return ir::Call::make( - args[1].type(), kName, args, + args[1].dtype(), kName, args, ir::Call::PureIntrinsic); } static constexpr const char* kName = "tvm_if_then_else"; diff --git a/src/arithmetic/rewrite_simplify.cc b/src/arithmetic/rewrite_simplify.cc index b26f8335055a..235306cc7bf8 100644 --- a/src/arithmetic/rewrite_simplify.cc +++ b/src/arithmetic/rewrite_simplify.cc @@ -129,7 +129,7 @@ Mutate_(const Add* op, const Expr& self) { // Pattern var for lanes in broadcast and ramp PVar lanes; // Vector rules - if (op->type.lanes() != 1) { + if (op->dtype.lanes() != 1) { TVM_TRY_REWRITE(ramp(b1, s1, lanes) + ramp(b2, s2, lanes), ramp(b1 + b2, s1 + s2, lanes)); TVM_TRY_REWRITE(ramp(b1, s1, lanes) + broadcast(x, lanes), @@ -140,7 +140,7 @@ Mutate_(const Add* op, const Expr& self) { broadcast(x + y, lanes)); } - if (IsIndexType(op->type)) { + if (IsIndexType(op->dtype)) { // Index rules // cancelation rules TVM_TRY_REWRITE((x - y) + y, x); @@ -244,7 +244,7 @@ Mutate_(const Sub* op, const Expr& self) { // Pattern var for lanes in broadcast and ramp PVar lanes; // Vector rules - if (op->type.lanes() != 1) { + if (op->dtype.lanes() != 1) { TVM_TRY_REWRITE(ramp(b1, s1, lanes) - ramp(b2, s2, lanes), ramp(b1 - b2, s1 - s2, lanes)); TVM_TRY_REWRITE(ramp(b1, s1, lanes) - broadcast(x, lanes), @@ -255,7 +255,7 @@ Mutate_(const Sub* op, const Expr& self) { broadcast(x - y, lanes)); } - if (IsIndexType(op->type)) { + if (IsIndexType(op->dtype)) { // Index rules // cancelation rules TVM_TRY_REWRITE((x + y) - y, x); @@ -443,7 +443,7 @@ Mutate_(const Mul* op, const Expr& self) { // Pattern var for lanes in broadcast and ramp PVar lanes; // Vector rules - if (op->type.lanes() != 1) { + if (op->dtype.lanes() != 1) { TVM_TRY_REWRITE(broadcast(x, lanes) * broadcast(y, lanes), broadcast(x * y, lanes)); TVM_TRY_REWRITE(ramp(b1, s1, lanes) * broadcast(x, lanes), @@ -452,7 +452,7 @@ Mutate_(const Mul* op, const Expr& self) { ramp(b1 * x, s1 * x, lanes)); } - if (IsIndexType(op->type)) { + if (IsIndexType(op->dtype)) { // constant simplification rule TVM_TRY_REWRITE((x + c1) * c2, x * c2 + c1 * c2); TVM_TRY_REWRITE((x * c1) * c2, x * (c1 * c2)); @@ -484,12 +484,12 @@ Mutate_(const Div* op, const Expr& self) { // x / 2.0 = x * 0.5 if (const FloatImm* ptr = op->b.as()) { - CHECK(op->type.is_float()); - return op->a * make_const(op->b.type(), 1.0 / ptr->value); + CHECK(op->dtype.is_float()); + return op->a * make_const(op->b.dtype(), 1.0 / ptr->value); } // Vector rules - if (op->type.lanes() != 1) { + if (op->dtype.lanes() != 1) { // NOTE: use div as the pattern also works for float. TVM_TRY_REWRITE(div(broadcast(x, lanes), broadcast(y, lanes)), broadcast(div(x, y), lanes)); @@ -512,7 +512,7 @@ Mutate_(const Div* op, const Expr& self) { } } - if (IsIndexType(op->type)) { + if (IsIndexType(op->dtype)) { // Be-aware of the division rules: // We adopt the default C division uses truncation instead of floordiv. // This means most rules need to check non-negativeness of the operands. @@ -524,7 +524,7 @@ Mutate_(const Div* op, const Expr& self) { if (truncdiv(c1, c2).Match(ret)) { int64_t c1val = c1.Eval()->value; int64_t c2val = c2.Eval()->value; - return make_const(op->type, truncdiv(c1val, c2val)); + return make_const(op->dtype, truncdiv(c1val, c2val)); } // while it is always true for trunc div @@ -706,7 +706,7 @@ Mutate_(const Mod* op, const Expr& self) { PVar lanes; // Vector rules - if (op->type.lanes() != 1) { + if (op->dtype.lanes() != 1) { TVM_TRY_REWRITE(truncmod(broadcast(x, lanes), broadcast(y, lanes)), broadcast(truncmod(x, y), lanes)); @@ -734,7 +734,7 @@ Mutate_(const Mod* op, const Expr& self) { } } - if (IsIndexType(op->type)) { + if (IsIndexType(op->dtype)) { // Be-aware of the division rules: // We adopt the default C division uses truncation instead of floordiv. // This means most rules need to check non-negativeness of the operands. @@ -762,9 +762,10 @@ Mutate_(const Mod* op, const Expr& self) { // canonicalization: x % c == x % (-c) for truncated division // NOTE: trunc div required - TVM_TRY_RECURSIVE_REWRITE_IF(truncmod(x, c1), - truncmod(x, PConst(make_const(op->type, -c1.Eval()->value))), - c1.Eval()->value < 0); + TVM_TRY_RECURSIVE_REWRITE_IF( + truncmod(x, c1), + truncmod(x, PConst(make_const(op->dtype, -c1.Eval()->value))), + c1.Eval()->value < 0); // try modular analysis if (truncmod(x, c1).Match(ret)) { @@ -794,7 +795,7 @@ Mutate_(const FloorDiv* op, const Expr& self) { PVar lanes; // Vector rules - if (op->type.lanes() != 1) { + if (op->dtype.lanes() != 1) { TVM_TRY_REWRITE(floordiv(broadcast(x, lanes), broadcast(y, lanes)), broadcast(floordiv(x, y), lanes)); // ramp // bcast @@ -814,7 +815,7 @@ Mutate_(const FloorDiv* op, const Expr& self) { } } - if (IsIndexType(op->type)) { + if (IsIndexType(op->dtype)) { // Be-aware of the division rules: this is floor division. TVM_TRY_REWRITE_IF(floordiv(floordiv(x, c1), c2), floordiv(x, c1 * c2), c1.Eval()->value > 0 && c2.Eval()->value > 0); @@ -939,7 +940,7 @@ Mutate_(const FloorMod* op, const Expr& self) { PVar lanes; // Vector rules - if (op->type.lanes() != 1) { + if (op->dtype.lanes() != 1) { TVM_TRY_REWRITE(floormod(broadcast(x, lanes), broadcast(y, lanes)), broadcast(floormod(x, y), lanes)); @@ -964,7 +965,7 @@ Mutate_(const FloorMod* op, const Expr& self) { } } - if (IsIndexType(op->type)) { + if (IsIndexType(op->dtype)) { // Be-aware of the division rules: we use floordiv/floormod here TVM_TRY_REWRITE_IF(floormod(x * c1, c2), ZeroWithTypeLike(x), c2.Eval()->value != 0 && @@ -1008,13 +1009,13 @@ Mutate_(const Min* op, const Expr& self) { PVar lanes; // vector rule - if (op->type.lanes() != 1) { + if (op->dtype.lanes() != 1) { TVM_TRY_REWRITE(min(broadcast(x, lanes), broadcast(y, lanes)), broadcast(min(x, y), lanes)); TVM_TRY_REWRITE(min(min(x, broadcast(y, lanes)), broadcast(z, lanes)), min(x, broadcast(min(y, z), lanes))); } - if (IsIndexType(op->type)) { + if (IsIndexType(op->dtype)) { TVM_TRY_REWRITE(min(x, x), x); // constant int bound @@ -1193,13 +1194,13 @@ Mutate_(const Max* op, const Expr& self) { PVar lanes; // vector rule - if (op->type.lanes() != 1) { + if (op->dtype.lanes() != 1) { TVM_TRY_REWRITE(max(broadcast(x, lanes), broadcast(y, lanes)), broadcast(max(x, y), lanes)); TVM_TRY_REWRITE(max(max(x, broadcast(y, lanes)), broadcast(z, lanes)), max(x, broadcast(max(y, z), lanes))); } - if (IsIndexType(op->type)) { + if (IsIndexType(op->dtype)) { TVM_TRY_REWRITE(max(x, x), x); // constant int bound @@ -1366,17 +1367,17 @@ Mutate_(const EQ* op, const Expr& self) { PVar lanes; // vector rule - if (op->type.lanes() != 1) { + if (op->dtype.lanes() != 1) { TVM_TRY_REWRITE(broadcast(x, lanes) == broadcast(y, lanes), broadcast(x == y, lanes)); } - if (IsIndexType(op->a.type())) { + if (IsIndexType(op->a.dtype())) { CompareResult result = TryCompare(op->a - op->b, 0); if (result == kEQ) { - return make_const(op->type, true); + return make_const(op->dtype, true); } else if (result == kNE || result == kGT || result == kLT) { - return make_const(op->type, false); + return make_const(op->dtype, false); } TVM_TRY_REWRITE(x - c1 == 0, x == c1); TVM_TRY_REWRITE(c1 - x == 0, x == c1); @@ -1420,20 +1421,20 @@ Mutate_(const LT* op, const Expr& self) { PVar lanes; // vector rule - if (op->type.lanes() != 1) { + if (op->dtype.lanes() != 1) { TVM_TRY_REWRITE(broadcast(x, lanes) < broadcast(y, lanes), broadcast(x < y, lanes)); TVM_TRY_REWRITE(ramp(x, s1, lanes) < ramp(y, s1, lanes), broadcast(x < y, lanes)); } - if (IsIndexType(op->a.type())) { + if (IsIndexType(op->a.dtype())) { CompareResult result = TryCompare(op->a - op->b, 0); if (result == kLT) { - return make_const(op->type, true); + return make_const(op->dtype, true); } if (result == kEQ || result == kGT || result == kGE) { - return make_const(op->type, false); + return make_const(op->dtype, false); } TVM_TRY_REWRITE(x + y < x + z, y < z); @@ -1571,7 +1572,7 @@ Mutate_(const Not* op, const Expr& self) { // Pattern var to match any expression PVar x, y; PVar lanes; - if (op->type.lanes() != 1) { + if (op->dtype.lanes() != 1) { TVM_TRY_REWRITE(!broadcast(x, lanes), broadcast(!x, lanes)); } @@ -1600,12 +1601,12 @@ Mutate_(const And* op, const Expr& self) { PVar c1, c2; PVar lanes; - if (op->type.lanes() != 1) { + if (op->dtype.lanes() != 1) { TVM_TRY_REWRITE(broadcast(x, lanes) && broadcast(y, lanes), broadcast(x && y, lanes)); } - auto cfalse = PConst(make_const(op->type, false)); + auto cfalse = PConst(make_const(op->dtype, false)); TVM_TRY_REWRITE(x == y && x != y, cfalse); TVM_TRY_REWRITE(x != y && x == y, cfalse); TVM_TRY_REWRITE(x && !x, cfalse); @@ -1649,12 +1650,12 @@ Mutate_(const Or* op, const Expr& self) { PVar c1, c2; PVar lanes; - if (op->type.lanes() != 1) { + if (op->dtype.lanes() != 1) { TVM_TRY_REWRITE(broadcast(x, lanes) || broadcast(y, lanes), broadcast(x || y, lanes)); } - auto ctrue = PConst(make_const(op->type, true)); + auto ctrue = PConst(make_const(op->dtype, true)); TVM_TRY_REWRITE(x == y || x != y, ctrue); TVM_TRY_REWRITE(x != y || x == y, ctrue); @@ -1720,7 +1721,7 @@ Mutate_(const Call* op, const Expr& self) { for (const auto& constraint : literal_constraints_) { // Cases such as for (i, 0, bound) {if (likely(iter_var < bound)) { .. } } if (Equal(constraint, op->args[0])) { - return make_const(op->type, true); + return make_const(op->dtype, true); } } } @@ -1741,7 +1742,7 @@ Expr RewriteSimplifier::Impl:: Mutate_(const Cast* op, const Expr& self) { Expr ret = IRMutator::Mutate_(op, self); op = ret.as(); - return cast(op->type, op->value); + return cast(op->dtype, op->value); } Expr RewriteSimplifier::Impl:: diff --git a/src/autotvm/touch_extractor.cc b/src/autotvm/touch_extractor.cc index 101d8f1aa57f..f66a724595c6 100644 --- a/src/autotvm/touch_extractor.cc +++ b/src/autotvm/touch_extractor.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -255,10 +255,10 @@ void GetItervarFeature(Stmt stmt, bool take_log, Array > > *re feature_row.push_back(Array{std::string("_itervar_"), var}); Array attr{std::string("_attr_"), - FloatImm::make(Float(32), trans(fea.length)), - IntImm::make(Int(32), fea.nest_level), - FloatImm::make(Float(32), trans(fea.topdown_product)), - FloatImm::make(Float(32), trans(fea.bottomup_product)), + FloatImm::make(DataType::Float(32), trans(fea.length)), + IntImm::make(DataType::Int(32), fea.nest_level), + FloatImm::make(DataType::Float(32), trans(fea.topdown_product)), + FloatImm::make(DataType::Float(32), trans(fea.bottomup_product)), }; // one hot annotation for (int i = 0; i < kNum; i++) { @@ -268,9 +268,9 @@ void GetItervarFeature(Stmt stmt, bool take_log, Array > > *re // arithmetic feature_row.push_back(Array{std::string("_arith_"), - FloatImm::make(Float(32), trans(fea.add_ct)), - FloatImm::make(Float(32), trans(fea.mul_ct)), - FloatImm::make(Float(32), trans(fea.div_ct)), + FloatImm::make(DataType::Float(32), trans(fea.add_ct)), + FloatImm::make(DataType::Float(32), trans(fea.mul_ct)), + FloatImm::make(DataType::Float(32), trans(fea.div_ct)), }); // touch map @@ -282,12 +282,12 @@ void GetItervarFeature(Stmt stmt, bool take_log, Array > > *re for (auto k : bufs) { TouchPattern &v = fea.touch_feature[k]; feature_row.push_back(Array{k, - FloatImm::make(Float(32), trans(v.stride)), - FloatImm::make(Float(32), trans(v.mod)), - FloatImm::make(Float(32), trans(v.count)), - FloatImm::make(Float(32), trans(v.reuse)), - FloatImm::make(Float(32), trans(v.thread_count)), - FloatImm::make(Float(32), trans(v.thread_reuse)), + FloatImm::make(DataType::Float(32), trans(v.stride)), + FloatImm::make(DataType::Float(32), trans(v.mod)), + FloatImm::make(DataType::Float(32), trans(v.count)), + FloatImm::make(DataType::Float(32), trans(v.reuse)), + FloatImm::make(DataType::Float(32), trans(v.thread_count)), + FloatImm::make(DataType::Float(32), trans(v.thread_reuse)), }); } diff --git a/src/autotvm/touch_extractor.h b/src/autotvm/touch_extractor.h index e6690641edc6..1028b0144e12 100644 --- a/src/autotvm/touch_extractor.h +++ b/src/autotvm/touch_extractor.h @@ -91,31 +91,31 @@ class TouchExtractor : public FeatureVisitor { // arithmetic stats void Visit_(const Add *op) { - if (op->type.is_float()) + if (op->dtype.is_float()) itervar_map[itervar_stack_.back()].add_ct++; IRVisitor::Visit_(op); } void Visit_(const Sub *op) { - if (op->type.is_float()) + if (op->dtype.is_float()) itervar_map[itervar_stack_.back()].add_ct++; IRVisitor::Visit_(op); } void Visit_(const Mul *op) { - if (op->type.is_float()) + if (op->dtype.is_float()) itervar_map[itervar_stack_.back()].mul_ct++; IRVisitor::Visit_(op); } void Visit_(const Div *op) { - if (op->type.is_float()) + if (op->dtype.is_float()) itervar_map[itervar_stack_.back()].div_ct++; IRVisitor::Visit_(op); } void Visit_(const Mod *op) { - if (op->type.is_float()) + if (op->dtype.is_float()) itervar_map[itervar_stack_.back()].div_ct++; IRVisitor::Visit_(op); } diff --git a/src/codegen/build_common.h b/src/codegen/build_common.h index 8a21aeea7eee..b2c895348a46 100644 --- a/src/codegen/build_common.h +++ b/src/codegen/build_common.h @@ -39,7 +39,7 @@ ExtractFuncInfo(const Array& funcs) { for (LoweredFunc f : funcs) { runtime::FunctionInfo info; for (size_t i = 0; i < f->args.size(); ++i) { - info.arg_types.push_back(Type2TVMType(f->args[i].type())); + info.arg_types.push_back(f->args[i].dtype()); } for (size_t i = 0; i < f->thread_axis.size(); ++i) { info.thread_axis_tags.push_back(f->thread_axis[i]->thread_tag); diff --git a/src/codegen/build_module.cc b/src/codegen/build_module.cc index a7325a92f50a..ca25731cafef 100644 --- a/src/codegen/build_module.cc +++ b/src/codegen/build_module.cc @@ -334,12 +334,12 @@ Target DefaultTargetHost(Target target) { } Buffer BufferWithOffsetAlignment(Array shape, - Type dtype, + DataType dtype, std::string name, int data_alignment, int offset_factor, bool compact) { - auto data = Var(name, Handle()); + auto data = Var(name, DataType::Handle()); bool has_any = false; if (!compact) { for (const auto& it : shape) { @@ -353,7 +353,7 @@ Buffer BufferWithOffsetAlignment(Array shape, Expr elem_offset; if (offset_factor != 0) { - elem_offset = Var(name + "_elem_offset", shape[0].type()); + elem_offset = Var(name + "_elem_offset", shape[0].dtype()); } else { elem_offset = Expr(); } diff --git a/src/codegen/codegen_c.cc b/src/codegen/codegen_c.cc index eab542dd3e08..4b95e2caf1aa 100644 --- a/src/codegen/codegen_c.cc +++ b/src/codegen/codegen_c.cc @@ -79,7 +79,7 @@ void CodeGenC::AddFunction(LoweredFunc f) { ReserveKeywordsAsUnique(); // add to alloc buffer type. for (const auto & kv : f->handle_data_type) { - RegisterHandleType(kv.first.get(), kv.second.type()); + RegisterHandleType(kv.first.get(), kv.second.dtype()); } this->stream << "void " << f->name << "("; @@ -87,7 +87,7 @@ void CodeGenC::AddFunction(LoweredFunc f) { Var v = f->args[i]; std::string vid = AllocVarID(v.get()); if (i != 0) stream << ", "; - if (v.type().is_handle()) { + if (v.dtype().is_handle()) { auto it = alloc_storage_scope_.find(v.get()); if (it != alloc_storage_scope_.end()) PrintStorageScope(it->second, stream); @@ -104,7 +104,7 @@ void CodeGenC::AddFunction(LoweredFunc f) { stream << ' ' << restrict_keyword_; } } else { - PrintType(v.type(), stream); + PrintType(v.dtype(), stream); } stream << ' ' << vid; } @@ -125,14 +125,14 @@ void CodeGenC::PrintExpr(const Expr& n, std::ostream& os) { // NOLINT(*) if (print_ssa_form_) { std::ostringstream temp; VisitExpr(n, temp); - os << SSAGetID(temp.str(), n.type()); + os << SSAGetID(temp.str(), n.dtype()); } else { VisitExpr(n, os); } } void CodeGenC::PrintSSAAssign( - const std::string& target, const std::string& src, Type t) { + const std::string& target, const std::string& src, DataType t) { PrintType(t, stream); stream << ' ' << target << " = "; if (src.length() > 3 && @@ -146,7 +146,7 @@ void CodeGenC::PrintSSAAssign( // Print a reference expression to a buffer. std::string CodeGenC::GetBufferRef( - Type t, const Variable* buffer, Expr index) { + DataType t, const Variable* buffer, Expr index) { std::ostringstream os; std::string vid = GetVarID(buffer); std::string scope; @@ -213,7 +213,7 @@ std::string CodeGenC::GetBufferRef( // Print a reference expression to a buffer. std::string CodeGenC::GetStructRef( - Type t, const Expr& buffer, const Expr& index, int kind) { + DataType t, const Expr& buffer, const Expr& index, int kind) { if (kind < intrinsic::kArrKindBound_) { std::ostringstream os; os << "(((TVMArray*)"; @@ -265,13 +265,13 @@ std::string CodeGenC::GetStructRef( } -bool CodeGenC::HandleTypeMatch(const Variable* buf_var, Type t) const { +bool CodeGenC::HandleTypeMatch(const Variable* buf_var, DataType t) const { auto it = handle_data_type_.find(buf_var); if (it == handle_data_type_.end()) return false; return it->second == t; } -void CodeGenC::RegisterHandleType(const Variable* buf_var, Type t) { +void CodeGenC::RegisterHandleType(const Variable* buf_var, DataType t) { auto it = handle_data_type_.find(buf_var); if (it == handle_data_type_.end()) { handle_data_type_[buf_var] = t; @@ -282,13 +282,13 @@ void CodeGenC::RegisterHandleType(const Variable* buf_var, Type t) { } void CodeGenC::PrintVecElemLoad(const std::string& vec, - Type t, int i, + DataType t, int i, std::ostream& os) { // NOLINT(*) os << vec << ".s" << std::hex << i << std::dec; } void CodeGenC::PrintVecElemStore(const std::string& vec, - Type t, int i, + DataType t, int i, const std::string& value) { this->PrintIndent(); stream << vec << ".s" << std::hex << i @@ -296,19 +296,19 @@ void CodeGenC::PrintVecElemStore(const std::string& vec, } std::string CodeGenC::GetVecLoad( - Type t, const Variable* buffer, Expr base) { + DataType t, const Variable* buffer, Expr base) { return GetBufferRef(t, buffer, base); } void CodeGenC::PrintVecStore(const Variable* buffer, - Type t, Expr base, + DataType t, Expr base, const std::string& value) { std::string ref = GetBufferRef(t, buffer, base); this->PrintIndent(); stream << ref << " = " << value << ";\n"; } -std::string CodeGenC::CastFromTo(std::string value, Type from, Type target) { +std::string CodeGenC::CastFromTo(std::string value, DataType from, DataType target) { if (from == target) return value; std::ostringstream os; os << "(("; @@ -328,7 +328,7 @@ void CodeGenC::PrintStorageScope(const std::string& scope, std::ostream& os) { / CHECK_EQ(scope, "global"); } -void CodeGenC::PrintType(Type t, std::ostream& os) { // NOLINT(*) +void CodeGenC::PrintType(DataType t, std::ostream& os) { // NOLINT(*) CHECK_EQ(t.lanes(), 1) << "do not yet support vector types"; if (t.is_handle()) { @@ -360,48 +360,48 @@ void CodeGenC::PrintType(Type t, std::ostream& os) { // NOLINT(*) inline void PrintConst(const IntImm* op, std::ostream& os, CodeGenC* p) { // NOLINT(*) - if (op->type == Int(32)) { + if (op->dtype == DataType::Int(32)) { std::ostringstream temp; temp << op->value; p->MarkConst(temp.str()); os << temp.str(); } else { os << "("; - p->PrintType(op->type, os); + p->PrintType(op->dtype, os); os << ")" << op->value; } } inline void PrintConst(const UIntImm* op, std::ostream& os, CodeGenC* p) { // NOLINT(*) - if (op->type == UInt(32)) { + if (op->dtype == DataType::UInt(32)) { std::ostringstream temp; temp << op->value << "U"; p->MarkConst(temp.str()); os << temp.str(); } else { os << "("; - p->PrintType(op->type, os); + p->PrintType(op->dtype, os); os << ")" << op->value; } } inline void PrintConst(const FloatImm* op, std::ostream& os, CodeGenC* p) { // NOLINT(*) - switch (op->type.bits()) { + switch (op->dtype.bits()) { case 64: case 32: { std::ostringstream temp; temp << std::scientific << op->value; - if (op->type.bits() == 32) temp << 'f'; + if (op->dtype.bits() == 32) temp << 'f'; p->MarkConst(temp.str()); os << temp.str(); break; } case 16: { os << '('; - p->PrintType(op->type, os); + p->PrintType(op->dtype, os); os << ')' << std::scientific <value << 'f'; break; } - default: LOG(FATAL) << "Bad bit-width for float: " << op->type << "\n"; + default: LOG(FATAL) << "Bad bit-width for float: " << op->dtype << "\n"; } } @@ -423,7 +423,7 @@ inline void PrintBinaryExpr(const T* op, const char *opstr, std::ostream& os, // NOLINT(*) CodeGenC* p) { - if (op->type.lanes() == 1) { + if (op->dtype.lanes() == 1) { if (isalpha(opstr[0])) { os << opstr << '('; p->PrintExpr(op->a, os); @@ -438,7 +438,7 @@ inline void PrintBinaryExpr(const T* op, os << ')'; } } else { - p->PrintVecBinaryOp(opstr, op->type, op->a, op->b, os); + p->PrintVecBinaryOp(opstr, op->dtype, op->a, op->b, os); } } @@ -446,7 +446,7 @@ inline void PrintBinaryIntrinsic(const Call* op, const char *opstr, std::ostream& os, // NOLINT(*) CodeGenC* p) { - if (op->type.lanes() == 1) { + if (op->dtype.lanes() == 1) { CHECK_EQ(op->args.size(), 2U); os << '('; p->PrintExpr(op->args[0], os); @@ -454,13 +454,13 @@ inline void PrintBinaryIntrinsic(const Call* op, p->PrintExpr(op->args[1], os); os << ')'; } else { - p->PrintVecBinaryOp(opstr, op->type, op->args[0], op->args[1], os); + p->PrintVecBinaryOp(opstr, op->dtype, op->args[0], op->args[1], os); } } void CodeGenC::VisitExpr_(const Cast *op, std::ostream& os) { // NOLINT(*) std::stringstream value; this->PrintExpr(op->value, value); - os << CastFromTo(value.str(), op->value.type(), op->type); + os << CastFromTo(value.str(), op->value.dtype(), op->dtype); } void CodeGenC::VisitExpr_(const Variable *op, std::ostream& os) { // NOLINT(*) os << GetVarID(op); @@ -553,7 +553,7 @@ void CodeGenC::VisitExpr_(const Call *op, std::ostream& os) { // NOLINT(*) const Load *l = op->args[0].as(); CHECK(op->args.size() == 1 && l); os << "(("; - this->PrintType(l->type.element_of(), os); + this->PrintType(l->dtype.element_of(), os); os << " *)" << this->GetVarID(l->buffer_var.get()) << " + "; this->PrintExpr(l->index, os); @@ -561,7 +561,7 @@ void CodeGenC::VisitExpr_(const Call *op, std::ostream& os) { // NOLINT(*) } else if (op->is_intrinsic(intrinsic::tvm_struct_get)) { CHECK_EQ(op->args.size(), 3U); os << GetStructRef( - op->type, op->args[0], op->args[1], + op->dtype, op->args[0], op->args[1], op->args[2].as()->value); } else if (op->is_intrinsic(intrinsic::tvm_handle_is_null)) { CHECK_EQ(op->args.size(), 1U); @@ -571,7 +571,7 @@ void CodeGenC::VisitExpr_(const Call *op, std::ostream& os) { // NOLINT(*) } else if (op->is_intrinsic(Call::reinterpret)) { // generate (*( TYPE *)(&(ARG))) os << "(*("; - this->PrintType(op->type, os); + this->PrintType(op->dtype, os); os << " *)(&("; this->PrintExpr(op->args[0], os); os << ")))"; @@ -585,7 +585,7 @@ void CodeGenC::VisitExpr_(const Call *op, std::ostream& os) { // NOLINT(*) if (op->call_type == Call::Intrinsic || op->call_type == Call::PureIntrinsic) { LOG(FATAL) << "Unresolved intrinsic " << op->name - << " with return type " << op->type; + << " with return type " << op->dtype; } else { LOG(FATAL) << "Unresolved call type " << op->call_type; } @@ -593,7 +593,7 @@ void CodeGenC::VisitExpr_(const Call *op, std::ostream& os) { // NOLINT(*) } void CodeGenC::PrintVecBinaryOp( - const std::string& op, Type t, + const std::string& op, DataType t, Expr lhs, Expr rhs, std::ostream& os) { // NOLINT(*) if (isalpha(op[0])) { os << op << "("; @@ -611,17 +611,17 @@ void CodeGenC::PrintVecBinaryOp( } void CodeGenC::VisitExpr_(const Load* op, std::ostream& os) { // NOLINT(*) - int lanes = op->type.lanes(); + int lanes = op->dtype.lanes(); // delcare type. - if (op->type.lanes() == 1) { - std::string ref = GetBufferRef(op->type, op->buffer_var.get(), op->index); + if (op->dtype.lanes() == 1) { + std::string ref = GetBufferRef(op->dtype, op->buffer_var.get(), op->index); os << ref; } else { CHECK(is_one(op->predicate)) << "predicated load is not supported"; Expr base; - if (GetRamp1Base(op->index, op->type.lanes(), &base)) { - std::string ref = GetVecLoad(op->type, op->buffer_var.get(), base); + if (GetRamp1Base(op->index, op->dtype.lanes(), &base)) { + std::string ref = GetVecLoad(op->dtype, op->buffer_var.get(), base); os << ref; } else { // The assignment below introduces side-effect, and the resulting value cannot @@ -631,16 +631,16 @@ void CodeGenC::VisitExpr_(const Load* op, std::ostream& os) { // NOLINT(*) // load seperately. std::string svalue = GetUniqueName("_"); this->PrintIndent(); - this->PrintType(op->type, stream); + this->PrintType(op->dtype, stream); stream << ' ' << svalue << ";\n"; - std::string sindex = SSAGetID(PrintExpr(op->index), op->index.type()); + std::string sindex = SSAGetID(PrintExpr(op->index), op->index.dtype()); std::string vid = GetVarID(op->buffer_var.get()); - Type elem_type = op->type.element_of(); + DataType elem_type = op->dtype.element_of(); for (int i = 0; i < lanes; ++i) { std::ostringstream value_temp; if (!HandleTypeMatch(op->buffer_var.get(), elem_type)) { value_temp << "(("; - if (op->buffer_var.get()->type.is_handle()) { + if (op->buffer_var.get()->dtype.is_handle()) { auto it = alloc_storage_scope_.find(op->buffer_var.get()); if (it != alloc_storage_scope_.end()) { PrintStorageScope(it->second, value_temp); @@ -653,9 +653,9 @@ void CodeGenC::VisitExpr_(const Load* op, std::ostream& os) { // NOLINT(*) value_temp << vid; } value_temp << '['; - PrintVecElemLoad(sindex, op->index.type(), i, value_temp); + PrintVecElemLoad(sindex, op->index.dtype(), i, value_temp); value_temp << ']'; - PrintVecElemStore(svalue, op->type, i, value_temp.str()); + PrintVecElemStore(svalue, op->dtype, i, value_temp.str()); } os << svalue; EndScope(vec_scope); @@ -664,7 +664,7 @@ void CodeGenC::VisitExpr_(const Load* op, std::ostream& os) { // NOLINT(*) } void CodeGenC::VisitStmt_(const Store* op) { - Type t = op->value.type(); + DataType t = op->value.dtype(); if (t.lanes() == 1) { std::string value = this->PrintExpr(op->value); std::string ref = this->GetBufferRef(t, op->buffer_var.get(), op->index); @@ -683,15 +683,15 @@ void CodeGenC::VisitStmt_(const Store* op) { int vec_scope = BeginScope(); // store elements seperately - std::string index = SSAGetID(PrintExpr(op->index), op->index.type()); - std::string value = SSAGetID(PrintExpr(op->value), op->value.type()); + std::string index = SSAGetID(PrintExpr(op->index), op->index.dtype()); + std::string value = SSAGetID(PrintExpr(op->value), op->value.dtype()); std::string vid = GetVarID(op->buffer_var.get()); for (int i = 0; i < t.lanes(); ++i) { this->PrintIndent(); - Type elem_type = t.element_of(); + DataType elem_type = t.element_of(); if (!HandleTypeMatch(op->buffer_var.get(), elem_type)) { stream << "(("; - if (op->buffer_var.get()->type.is_handle()) { + if (op->buffer_var.get()->dtype.is_handle()) { auto it = alloc_storage_scope_.find(op->buffer_var.get()); if (it != alloc_storage_scope_.end()) { PrintStorageScope(it->second, stream); @@ -704,9 +704,9 @@ void CodeGenC::VisitStmt_(const Store* op) { stream << vid; } stream << '['; - PrintVecElemLoad(index, op->index.type(), i, stream); + PrintVecElemLoad(index, op->index.dtype(), i, stream); stream << "] = "; - PrintVecElemLoad(value, op->value.type(), i, stream); + PrintVecElemLoad(value, op->value.dtype(), i, stream); stream << ";\n"; } EndScope(vec_scope); @@ -723,7 +723,7 @@ void CodeGenC::VisitExpr_(const Let* op, std::ostream& os) { // NOLINT(*) void CodeGenC::VisitExpr_(const Ramp* op, std::ostream& os) { // NOLINT(*) // constraint of current logic - CHECK_EQ(op->base.type(), Int(32)); + CHECK_EQ(op->base.dtype(), DataType::Int(32)); os << "((int" << op->lanes << ")("; for (int i = 0; i < op->lanes; i++) { os << "(" << PrintExpr(op->base) << ")" << "+(" << PrintExpr(op->stride) << "*" << i <<")"; @@ -758,7 +758,7 @@ void CodeGenC::VisitStmt_(const LetStmt* op) { var_idmap_[op->var.get()] = value; } else { PrintIndent(); - if (op->var.type() == Handle() && + if (op->var.dtype() == DataType::Handle() && handle_data_type_.count(op->var.get())) { PrintType(handle_data_type_.at(op->var.get()), stream); stream << "* " @@ -767,7 +767,7 @@ void CodeGenC::VisitStmt_(const LetStmt* op) { PrintType(handle_data_type_.at(op->var.get()), stream); stream << "*)" << value << ";\n"; } else { - PrintType(op->var.type(), this->stream); + PrintType(op->var.dtype(), this->stream); this->stream << ' ' << AllocVarID(op->var.get()) << " = " << value << ";\n"; @@ -784,7 +784,7 @@ void CodeGenC::VisitStmt_(const Allocate* op) { CHECK_EQ(op->free_function, "nop"); std::string new_data = PrintExpr(op->new_expr); this->PrintIndent(); - PrintType(op->type, stream); + PrintType(op->dtype, stream); stream << "* "<< vid << '=' << new_data << ";\n"; } else { this->PrintIndent(); @@ -795,11 +795,11 @@ void CodeGenC::VisitStmt_(const Allocate* op) { std::string scope = alloc_storage_scope_.at(buffer); PrintStorageScope(scope, stream); stream << ' '; - PrintType(op->type, stream); + PrintType(op->dtype, stream); stream << ' '<< vid << '[' << constant_size << "];\n"; } - RegisterHandleType(op->buffer_var.get(), op->type); + RegisterHandleType(op->buffer_var.get(), op->dtype); this->PrintStmt(op->body); } @@ -841,7 +841,7 @@ void CodeGenC::VisitStmt_(const For* op) { std::string vid = AllocVarID(op->loop_var.get()); CHECK(is_zero(op->min)); stream << "for ("; - PrintType(op->loop_var.type(), stream); + PrintType(op->loop_var.dtype(), stream); stream << ' ' << vid << " = 0; " << vid << " < " << extent << "; ++" << vid << ") {\n"; @@ -890,7 +890,7 @@ void CodeGenC::VisitStmt_(const Evaluate *op) { CHECK_EQ(call->args.size(), 4); std::string value = PrintExpr(call->args[3]); std::string ref = GetStructRef( - call->args[3].type(), + call->args[3].dtype(), call->args[0], call->args[1], call->args[2].as()->value); diff --git a/src/codegen/codegen_c.h b/src/codegen/codegen_c.h index 8701cda1e14c..b8d357051998 100644 --- a/src/codegen/codegen_c.h +++ b/src/codegen/codegen_c.h @@ -147,7 +147,7 @@ class CodeGenC : * \param t The type representation. * \param os The stream to print the ctype into */ - virtual void PrintType(Type t, std::ostream& os); // NOLINT(*) + virtual void PrintType(DataType t, std::ostream& os); // NOLINT(*) /*! * \brief Print expr representing the thread tag * \param IterVar iv The thread index to be binded; @@ -157,51 +157,51 @@ class CodeGenC : virtual void PrintStorageSync(const Call* op); // NOLINT(*) // Binary vector op. virtual void PrintVecBinaryOp( - const std::string&op, Type op_type, + const std::string&op, DataType op_type, Expr lhs, Expr rhs, std::ostream& os); // NOLINT(*) // print vector load - virtual std::string GetVecLoad(Type t, const Variable* buffer, Expr base); + virtual std::string GetVecLoad(DataType t, const Variable* buffer, Expr base); // print vector store virtual void PrintVecStore(const Variable* buffer, - Type t, Expr base, + DataType t, Expr base, const std::string& value); // NOLINT(*) // print load of single element virtual void PrintVecElemLoad( - const std::string& vec, Type t, int i, std::ostream& os); // NOLINT(*) + const std::string& vec, DataType t, int i, std::ostream& os); // NOLINT(*) // print store of single element. virtual void PrintVecElemStore( - const std::string& vec, Type t, int i, const std::string& value); + const std::string& vec, DataType t, int i, const std::string& value); // Get a cast type from to - virtual std::string CastFromTo(std::string value, Type from, Type target); + virtual std::string CastFromTo(std::string value, DataType from, DataType target); protected: // Print reference to struct location std::string GetStructRef( - Type t, const Expr& buffer, const Expr& index, int kind); + DataType t, const Expr& buffer, const Expr& index, int kind); // print reference to a buffer as type t in index. virtual std::string GetBufferRef( - Type t, const Variable* buffer, Expr index); + DataType t, const Variable* buffer, Expr index); /*! * \brief If buffer is allocated as type t. * \param buf_var The buffer variable. * \param t The type to be checked. */ - bool HandleTypeMatch(const Variable* buf_var, Type t) const; + bool HandleTypeMatch(const Variable* buf_var, DataType t) const; /*! * \brief Register the data type of buf_var * \param buf_var The buffer variable. * \param t The type to be checked. */ - void RegisterHandleType(const Variable* buf_var, Type t); + void RegisterHandleType(const Variable* buf_var, DataType t); // override void PrintSSAAssign( - const std::string& target, const std::string& src, Type t) final; + const std::string& target, const std::string& src, DataType t) final; /*! \brief restrict keyword */ std::string restrict_keyword_{""}; /*! \brief the storage scope of allocation */ std::unordered_map alloc_storage_scope_; /*! \brief the data type of allocated buffers */ - std::unordered_map handle_data_type_; + std::unordered_map handle_data_type_; /*! \brief reserves common C keywords */ void ReserveKeywordsAsUnique(); diff --git a/src/codegen/codegen_c_host.cc b/src/codegen/codegen_c_host.cc index 9c099a425fd6..f2c54c2700c9 100644 --- a/src/codegen/codegen_c_host.cc +++ b/src/codegen/codegen_c_host.cc @@ -48,7 +48,7 @@ void CodeGenCHost::AddFunction(LoweredFunc f) { ReserveKeywordsAsUnique(); // add to alloc buffer type. for (const auto & kv : f->handle_data_type) { - RegisterHandleType(kv.first.get(), kv.second.type()); + RegisterHandleType(kv.first.get(), kv.second.dtype()); } this->stream << "#ifdef __cplusplus\n"; @@ -59,7 +59,7 @@ void CodeGenCHost::AddFunction(LoweredFunc f) { Var v = f->args[i]; std::string vid = AllocVarID(v.get()); if (i != 0) stream << ", "; - if (v.type().is_handle()) { + if (v.dtype().is_handle()) { auto it = alloc_storage_scope_.find(v.get()); if (it != alloc_storage_scope_.end()) { PrintStorageScope(it->second, stream); @@ -77,7 +77,7 @@ void CodeGenCHost::AddFunction(LoweredFunc f) { stream << ' ' << restrict_keyword_; } } else { - PrintType(v.type(), stream); + PrintType(v.dtype(), stream); } stream << ' ' << vid; } @@ -96,14 +96,14 @@ std::string CodeGenCHost::Finish() { return CodeGenC::Finish(); } -void CodeGenCHost::PrintType(Type t, std::ostream& os) { // NOLINT(*) +void CodeGenCHost::PrintType(DataType t, std::ostream& os) { // NOLINT(*) int lanes = t.lanes(); if (t.is_handle()) { CHECK_EQ(lanes, 1) << "does not support vector types"; os << "void*"; return; } - if (t == Bool()) { + if (t == DataType::Bool()) { os << "bool"; return; } bool fail = false; @@ -145,7 +145,7 @@ void CodeGenCHost::PrintType(Type t, std::ostream& os) { // NOLINT(*) void CodeGenCHost::VisitExpr_(const Broadcast* op, std::ostream& os) { // NOLINT(*) std::string v = PrintExpr(op->value); os << "(("; - PrintType(op->type, os); + PrintType(op->dtype, os); os << ")("; for (int i = 0; i < op->lanes; ++i) { if (i != 0) os << ", "; @@ -268,10 +268,10 @@ inline void CodeGenCHost::PrintTernaryCondExpr(const T* op, std::ostream& os) { // NOLINT(*) std::ostringstream temp_a; VisitExpr(op->a, temp_a); - std::string a_id = SSAGetID(temp_a.str(), op->a.type()); + std::string a_id = SSAGetID(temp_a.str(), op->a.dtype()); std::ostringstream temp_b; VisitExpr(op->b, temp_b); - std::string b_id = SSAGetID(temp_b.str(), op->b.type()); + std::string b_id = SSAGetID(temp_b.str(), op->b.dtype()); os << "((" << a_id << ") " << compare << " (" << b_id << ") " << "? (" << a_id << ") : (" << b_id << "))"; diff --git a/src/codegen/codegen_c_host.h b/src/codegen/codegen_c_host.h index 80e359c33ce0..44f838536627 100644 --- a/src/codegen/codegen_c_host.h +++ b/src/codegen/codegen_c_host.h @@ -39,7 +39,7 @@ class CodeGenCHost final : public CodeGenC { void AddFunction(LoweredFunc f); std::string Finish(); - void PrintType(Type t, std::ostream& os) final; // NOLINT(*) + void PrintType(DataType t, std::ostream& os) final; // NOLINT(*) // overload visitor functions void VisitExpr_(const Broadcast* op, std::ostream& os) final; // NOLINT(*) diff --git a/src/codegen/codegen_cuda.cc b/src/codegen/codegen_cuda.cc index 6656fa07740d..06b542a66323 100644 --- a/src/codegen/codegen_cuda.cc +++ b/src/codegen/codegen_cuda.cc @@ -105,10 +105,10 @@ void CodeGenCUDA::VisitStmt_(const ir::For* op) { void CodeGenCUDA::BindThreadIndex(const IterVar& iv) { CHECK(!var_idmap_.count(iv->var.get())); var_idmap_[iv->var.get()] = - CastFromTo(iv->thread_tag, UInt(32), iv->var.type()); + CastFromTo(iv->thread_tag, DataType::UInt(32), iv->var.dtype()); } -void CodeGenCUDA::PrintType(Type t, std::ostream& os) { // NOLINT(*) +void CodeGenCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*) int lanes = t.lanes(); if (t.is_handle()) { CHECK_EQ(lanes, 1) @@ -137,7 +137,7 @@ void CodeGenCUDA::PrintType(Type t, std::ostream& os) { // NOLINT(*) if (!fail && (lanes >= 2 && lanes <= 4)) { os << lanes; return; } - } else if (t == Bool()) { + } else if (t == DataType::Bool()) { os << "bool"; return; } else if (t.is_uint() || t.is_int()) { if (t.is_uint()) { @@ -199,7 +199,7 @@ void CodeGenCUDA::PrintType(Type t, std::ostream& os) { // NOLINT(*) } void CodeGenCUDA::PrintVecBinaryOp( - const std::string&op, Type t, + const std::string&op, DataType t, Expr lhs, Expr rhs, std::ostream& os) { // NOLINT(*) // unpacking operations. int lanes = t.lanes(); @@ -210,8 +210,8 @@ void CodeGenCUDA::PrintVecBinaryOp( int vec_scope = BeginScope(); // default: unpack into individual ops. - std::string vlhs = SSAGetID(PrintExpr(lhs), lhs.type()); - std::string vrhs = SSAGetID(PrintExpr(rhs), rhs.type()); + std::string vlhs = SSAGetID(PrintExpr(lhs), lhs.dtype()); + std::string vrhs = SSAGetID(PrintExpr(rhs), rhs.dtype()); std::string sret = GetUniqueName("_"); { // delcare type. @@ -223,15 +223,15 @@ void CodeGenCUDA::PrintVecBinaryOp( std::ostringstream value_temp; if (isalpha(op[0])) { value_temp << op << "("; - PrintVecElemLoad(vlhs, lhs.type(), i, value_temp); + PrintVecElemLoad(vlhs, lhs.dtype(), i, value_temp); value_temp << ", "; - PrintVecElemLoad(vrhs, rhs.type(), i, value_temp); + PrintVecElemLoad(vrhs, rhs.dtype(), i, value_temp); value_temp << ")"; } else { value_temp << "("; - PrintVecElemLoad(vlhs, lhs.type(), i, value_temp); + PrintVecElemLoad(vlhs, lhs.dtype(), i, value_temp); value_temp << op; - PrintVecElemLoad(vrhs, rhs.type(), i, value_temp); + PrintVecElemLoad(vrhs, rhs.dtype(), i, value_temp); value_temp << ")"; } PrintVecElemStore(sret, t, i, value_temp.str()); @@ -242,7 +242,7 @@ void CodeGenCUDA::PrintVecBinaryOp( } void CodeGenCUDA::PrintVecElemLoad( - const std::string& vec, Type t, int i, std::ostream& os) { // NOLINT(*) + const std::string& vec, DataType t, int i, std::ostream& os) { // NOLINT(*) static const char access[] = {'x', 'y', 'z', 'w'}; CHECK(i >= 0 && i < 4); if (t.is_int() && t.bits() == 8) { @@ -253,7 +253,7 @@ void CodeGenCUDA::PrintVecElemLoad( } void CodeGenCUDA::PrintVecElemStore( - const std::string& vec, Type t, int i, const std::string& value) { + const std::string& vec, DataType t, int i, const std::string& value) { this->PrintIndent(); static const char access[] = {'x', 'y', 'z', 'w'}; CHECK(i >= 0 && i < 4); @@ -390,7 +390,7 @@ void CodeGenCUDA::VisitStmt_(const Allocate* op) { CHECK_EQ(op->free_function, "nop"); std::string new_data = PrintExpr(op->new_expr); this->PrintIndent(); - PrintType(op->type, stream); + PrintType(op->dtype, stream); stream << "* "<< vid << '=' << new_data << ";\n"; } else { this->PrintIndent(); @@ -401,23 +401,27 @@ void CodeGenCUDA::VisitStmt_(const Allocate* op) { std::string scope = alloc_storage_scope_.at(buffer); if (scope.find("wmma.") == 0) { if (scope == "wmma.matrix_a" || scope == "wmma.matrix_b") { - CHECK(op->type == Float(16) || op->type == Int(8) || op->type == UInt(8)) + CHECK(op->dtype == DataType::Float(16) || + op->dtype == DataType::Int(8) || + op->dtype == DataType::UInt(8)) << "Matrix_a and matrix_b only support half or char or unsigned char type for now"; } else { - CHECK(op->type == Float(16) || op->type == Float(32) || op->type == Int(32)) + CHECK(op->dtype == DataType::Float(16) || + op->dtype == DataType::Float(32) || + op->dtype == DataType::Int(32)) << "Accumulator only support half, float and int type for now"; } constant_size = GetWmmaFragmentSize(scope, buffer, constant_size); - PrintWmmaScope(scope, op->type, buffer, stream); + PrintWmmaScope(scope, op->dtype, buffer, stream); } else { PrintStorageScope(scope, stream); stream << ' '; - PrintType(op->type, stream); + PrintType(op->dtype, stream); } stream << ' '<< vid << '[' << constant_size << "];\n"; } - RegisterHandleType(op->buffer_var.get(), op->type); + RegisterHandleType(op->buffer_var.get(), op->dtype); this->PrintStmt(op->body); } @@ -449,7 +453,7 @@ void CodeGenCUDA::VisitExpr_(const Ramp* op, std::ostream& os) { } void CodeGenCUDA::VisitExpr_(const Broadcast* op, std::ostream& os) { // NOLINT(*) - if (op->type.is_int() && op->type.bits() == 8 && op->lanes == 4) { + if (op->dtype.is_int() && op->dtype.bits() == 8 && op->lanes == 4) { // make_int8x4 const int64_t *p = as_const_int(op->value); CHECK(p); @@ -461,7 +465,7 @@ void CodeGenCUDA::VisitExpr_(const Broadcast* op, std::ostream& os) { // NOLIN std::string v = PrintExpr(op->value); os << "make_"; - PrintType(op->type, os); + PrintType(op->dtype, os); os << '('; for (int i = 0; i < op->lanes; ++i) { if (i != 0) os << ", "; @@ -473,11 +477,11 @@ void CodeGenCUDA::VisitExpr_(const Broadcast* op, std::ostream& os) { // NOLIN void CodeGenCUDA::VisitExpr_(const Shuffle* op, std::ostream &os) { std::vector to_shuffle(op->vectors.size()); for (int i = 0, e = op->vectors.size(); i < e; ++i) { - CHECK(op->vectors[i].type().lanes() == 1) << "Only scalars can be shuffled in CUDA!"; + CHECK(op->vectors[i].dtype().lanes() == 1) << "Only scalars can be shuffled in CUDA!"; to_shuffle[i] = PrintExpr(op->vectors[i]); } os << "make_"; - PrintType(op->type, os); + PrintType(op->dtype, os); os << '('; for (int i = 0, e = op->indices.size(); i < e; ++i) { const int64_t *val = as_const_int(op->indices[i]); @@ -489,21 +493,21 @@ void CodeGenCUDA::VisitExpr_(const Shuffle* op, std::ostream &os) { } inline void PrintConst(const FloatImm* op, std::ostream& os, CodeGenCUDA* p) { // NOLINT(*) - switch (op->type.bits()) { + switch (op->dtype.bits()) { case 64: case 32: { std::ostringstream temp; if (std::isinf(op->value)) { if (op->value < 0) { temp << "-"; } - temp << ((op->type.bits() == 32) ? "CUDART_INF_F" : "CUDART_INF"); + temp << ((op->dtype.bits() == 32) ? "CUDART_INF_F" : "CUDART_INF"); p->need_math_constants_h_ = true; } else if (std::isnan(op->value)) { - temp << ((op->type.bits() == 32) ? "CUDART_NAN_F" : "CUDART_NAN"); + temp << ((op->dtype.bits() == 32) ? "CUDART_NAN_F" : "CUDART_NAN"); p->need_math_constants_h_ = true; } else { temp << std::scientific << op->value; - if (op->type.bits() == 32) temp << 'f'; + if (op->dtype.bits() == 32) temp << 'f'; } p->MarkConst(temp.str()); os << temp.str(); @@ -514,7 +518,7 @@ inline void PrintConst(const FloatImm* op, std::ostream& os, CodeGenCUDA* p) { / os << '(' << std::scientific << op->value << 'f' << ')'; break; } - default: LOG(FATAL) << "Bad bit-width for float: " << op->type << "\n"; + default: LOG(FATAL) << "Bad bit-width for float: " << op->dtype << "\n"; } } @@ -523,7 +527,7 @@ void CodeGenCUDA::VisitExpr_(const FloatImm *op, std::ostream& os) { // NOLINT(* PrintConst(op, os, this); } -void CodeGenCUDA::PrintWmmaScope(const std::string &scope, Type t, +void CodeGenCUDA::PrintWmmaScope(const std::string &scope, DataType t, const Variable* variable, std::ostream &os) { std::stringstream type; PrintType(t, type); diff --git a/src/codegen/codegen_cuda.h b/src/codegen/codegen_cuda.h index efb300415b56..74d6fba35fc7 100644 --- a/src/codegen/codegen_cuda.h +++ b/src/codegen/codegen_cuda.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -47,13 +47,13 @@ class CodeGenCUDA final : public CodeGenC { void PrintStorageSync(const Call* op) final; void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*) void PrintVecBinaryOp( - const std::string&op, Type t, + const std::string&op, DataType t, Expr lhs, Expr rhs, std::ostream& os) final; // NOLINT(*) - void PrintType(Type t, std::ostream& os) final; // NOLINT(*) + void PrintType(DataType t, std::ostream& os) final; // NOLINT(*) void PrintVecElemLoad( - const std::string& vec, Type t, int i, std::ostream& os) final; // NOLINT(*) + const std::string& vec, DataType t, int i, std::ostream& os) final; // NOLINT(*) void PrintVecElemStore( - const std::string& vec, Type t, int i, const std::string& value) final; + const std::string& vec, DataType t, int i, const std::string& value) final; void BindThreadIndex(const IterVar& iv) final; // NOLINT(*) // overload visitor void VisitExpr_(const Ramp* op, std::ostream& os) final; // NOLINT(*) @@ -84,8 +84,10 @@ class CodeGenCUDA final : public CodeGenC { std::unordered_map fragment_shapes; std::unordered_map fragment_layouts; friend void PrintConst(const FloatImm* op, std::ostream& os, CodeGenCUDA* p); - void PrintWmmaScope(const std::string& scope, Type t, const Variable* variable, std::ostream& os); - int32_t GetWmmaFragmentSize(const std::string &scope, const Variable* variable, int32_t size); + void PrintWmmaScope( + const std::string& scope, DataType t, const Variable* variable, std::ostream& os); + int32_t GetWmmaFragmentSize( + const std::string &scope, const Variable* variable, int32_t size); }; } // namespace codegen diff --git a/src/codegen/codegen_metal.cc b/src/codegen/codegen_metal.cc index 311bdcbfa8d4..f4ff014c7297 100644 --- a/src/codegen/codegen_metal.cc +++ b/src/codegen/codegen_metal.cc @@ -36,7 +36,7 @@ void CodeGenMetal::InitFuncState(LoweredFunc f) { CodeGenC::InitFuncState(f); // analyze the data; for (Var arg : f->args) { - if (arg.type().is_handle()) { + if (arg.dtype().is_handle()) { alloc_storage_scope_[arg.get()] = "global"; } } @@ -57,7 +57,7 @@ void CodeGenMetal::AddFunction(LoweredFunc f) { GetUniqueName("_"); // add to alloc buffer type. for (const auto & kv : f->handle_data_type) { - RegisterHandleType(kv.first.get(), kv.second.type()); + RegisterHandleType(kv.first.get(), kv.second.dtype()); } // Function header. this->stream << "kernel void " << f->name << "(\n"; @@ -65,7 +65,7 @@ void CodeGenMetal::AddFunction(LoweredFunc f) { size_t num_buffer = 0; for (size_t i = 0; i < f->args.size(); ++i, ++num_buffer) { Var v = f->args[i]; - if (!v.type().is_handle()) break; + if (!v.dtype().is_handle()) break; stream << " "; std::string vid = AllocVarID(v.get()); auto it = alloc_storage_scope_.find(v.get()); @@ -76,7 +76,7 @@ void CodeGenMetal::AddFunction(LoweredFunc f) { PrintType(handle_data_type_.at(v.get()), stream); stream << "*"; } else { - PrintType(v.type(), stream); + PrintType(v.dtype(), stream); } stream << ' ' << vid << " [[ buffer(" << i << ") ]],\n"; @@ -92,19 +92,19 @@ void CodeGenMetal::AddFunction(LoweredFunc f) { decl_stream << "struct " << arg_buf_type << " {\n"; for (size_t i = num_buffer; i < f->args.size(); ++i) { Var v = f->args[i]; - CHECK(!v.type().is_handle()); + CHECK(!v.dtype().is_handle()); std::string vid = AllocVarID(v.get()); std::ostringstream vref; - if (v.type().bits() == 32) { + if (v.dtype().bits() == 32) { decl_stream << " "; - PrintType(v.type(), decl_stream); + PrintType(v.dtype(), decl_stream); decl_stream << " " << vid << ";\n"; vref << varg << "." << vid; } else { // For non 32bit type, ref through arg union. decl_stream << " __TVMArgUnion " << vid << ";\n"; vref << varg << "." << vid << ".v_"; - PrintType(v.type(), vref); + PrintType(v.dtype(), vref); } var_idmap_[v.get()] = vref.str(); } @@ -121,10 +121,10 @@ void CodeGenMetal::AddFunction(LoweredFunc f) { if (work_dim != 0) { // use ushort by default for now stream << " "; - PrintType(UInt(thread_index_bits_, work_dim), stream); + PrintType(DataType::UInt(thread_index_bits_, work_dim), stream); stream << " blockIdx [[threadgroup_position_in_grid]],\n"; stream << " "; - PrintType(UInt(thread_index_bits_, work_dim), stream); + PrintType(DataType::UInt(thread_index_bits_, work_dim), stream); stream << " threadIdx [[thread_position_in_threadgroup]]\n"; } // bind thread axis @@ -135,7 +135,7 @@ void CodeGenMetal::AddFunction(LoweredFunc f) { vname = vname.substr(0, iv->thread_tag.length() - 2); } var_idmap_[iv->var.get()] = - CastFromTo(vname, UInt(thread_index_bits_), iv->var.type()); + CastFromTo(vname, DataType::UInt(thread_index_bits_), iv->var.dtype()); } // the function scope. stream << ") {\n"; @@ -149,17 +149,17 @@ void CodeGenMetal::AddFunction(LoweredFunc f) { void CodeGenMetal::BindThreadIndex(const IterVar& iv) { CHECK(!var_idmap_.count(iv->var.get())); var_idmap_[iv->var.get()] = - CastFromTo(iv->thread_tag, UInt(thread_index_bits_), iv->var.type()); + CastFromTo(iv->thread_tag, DataType::UInt(thread_index_bits_), iv->var.dtype()); } -void CodeGenMetal::PrintType(Type t, std::ostream& os) { // NOLINT(*) +void CodeGenMetal::PrintType(DataType t, std::ostream& os) { // NOLINT(*) int lanes = t.lanes(); if (t.is_handle()) { CHECK_EQ(lanes, 1) << "do not yet support vector types"; os << "void*"; return; } - if (t == Bool()) { + if (t == DataType::Bool()) { os << "bool"; return; } bool fail = false; @@ -210,13 +210,13 @@ void CodeGenMetal::PrintStorageSync(const Call* op) { } void CodeGenMetal::PrintVecElemLoad(const std::string& vec, - Type t, int i, + DataType t, int i, std::ostream& os) { // NOLINT(*) os << vec << "[" << i << "]"; } void CodeGenMetal::PrintVecElemStore(const std::string& vec, - Type t, int i, + DataType t, int i, const std::string& value) { this->PrintIndent(); stream << vec << "[" << i << "]" @@ -236,7 +236,7 @@ void CodeGenMetal::PrintStorageScope( void CodeGenMetal::VisitExpr_(const Broadcast* op, std::ostream& os) { // NOLINT(*) std::string v = PrintExpr(op->value); - PrintType(op->type, os); + PrintType(op->dtype, os); os << "("; for (int i = 0; i < op->lanes; ++i) { if (i != 0) os << ", "; @@ -249,7 +249,7 @@ void CodeGenMetal::VisitExpr_(const Call* op, std::ostream& os) { // NOLINT(*) if (op->is_intrinsic(Call::reinterpret)) { // generate as_type(ARG) os << "(as_type<"; - this->PrintType(op->type, os); + this->PrintType(op->dtype, os); os << ">("; this->PrintExpr(op->args[0], os); os << "))"; diff --git a/src/codegen/codegen_metal.h b/src/codegen/codegen_metal.h index c009cd1e9169..728e3e07a916 100644 --- a/src/codegen/codegen_metal.h +++ b/src/codegen/codegen_metal.h @@ -41,14 +41,14 @@ class CodeGenMetal final : public CodeGenC { void InitFuncState(LoweredFunc f) final; void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*) void PrintStorageSync(const Call* op) final; // NOLINT(*) - void PrintType(Type t, std::ostream& os) final; // NOLINT(*) + void PrintType(DataType t, std::ostream& os) final; // NOLINT(*) void BindThreadIndex(const IterVar& iv) final; // NOLINT(*) // print load of single element void PrintVecElemLoad( - const std::string& vec, Type t, int i, std::ostream& os) final; // NOLINT(*) + const std::string& vec, DataType t, int i, std::ostream& os) final; // NOLINT(*) // print store of single element. void PrintVecElemStore( - const std::string& vec, Type t, int i, const std::string& value) final; + const std::string& vec, DataType t, int i, const std::string& value) final; // overload visitor void VisitExpr_(const Broadcast* op, std::ostream& os) final; // NOLINT(*) diff --git a/src/codegen/codegen_opencl.cc b/src/codegen/codegen_opencl.cc index 49dccb173ed3..ae434197400f 100644 --- a/src/codegen/codegen_opencl.cc +++ b/src/codegen/codegen_opencl.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -39,7 +39,7 @@ CodeGenOpenCL::CodeGenOpenCL() { void CodeGenOpenCL::InitFuncState(LoweredFunc f) { CodeGenC::InitFuncState(f); for (Var arg : f->args) { - if (arg.type().is_handle()) { + if (arg.dtype().is_handle()) { alloc_storage_scope_[arg.get()] = "global"; } } @@ -89,17 +89,17 @@ void CodeGenOpenCL::BindThreadIndex(const IterVar& iv) { os << "get_group_id(" << ts.dim_index << ")"; } var_idmap_[iv->var.get()] = - CastFromTo(os.str(), UInt(64), iv->var.type()); + CastFromTo(os.str(), DataType::UInt(64), iv->var.dtype()); } -void CodeGenOpenCL::PrintType(Type t, std::ostream& os) { // NOLINT(*) +void CodeGenOpenCL::PrintType(DataType t, std::ostream& os) { // NOLINT(*) int lanes = t.lanes(); if (t.is_handle()) { CHECK_EQ(lanes, 1) << "do not yet support vector types"; os << "void*"; return; } - if (t == Bool()) { + if (t == DataType::Bool()) { os << "bool"; return; } bool fail = false; @@ -144,7 +144,7 @@ void CodeGenOpenCL::PrintType(Type t, std::ostream& os) { // NOLINT(*) LOG(FATAL) << "Cannot convert type " << t << " to OpenCL type"; } -void CodeGenOpenCL::PrintVecAddr(const Variable* buffer, Type t, +void CodeGenOpenCL::PrintVecAddr(const Variable* buffer, DataType t, Expr base, std::ostream& os) { // NOLINT(*) if (!HandleTypeMatch(buffer, t.element_of())) { os << '('; @@ -160,7 +160,7 @@ void CodeGenOpenCL::PrintVecAddr(const Variable* buffer, Type t, PrintExpr(base, os); } std::string CodeGenOpenCL::GetVecLoad( - Type t, const Variable* buffer, Expr base) { + DataType t, const Variable* buffer, Expr base) { std::ostringstream os; os << "vload" << t.lanes() << "(0, "; PrintVecAddr(buffer, t, base, os); @@ -169,7 +169,7 @@ std::string CodeGenOpenCL::GetVecLoad( } void CodeGenOpenCL::PrintVecStore(const Variable* buffer, - Type t, Expr base, + DataType t, Expr base, const std::string& value) { this->PrintIndent(); stream << "vstore" << t.lanes() << "(" << value << ", 0, "; @@ -199,7 +199,7 @@ void CodeGenOpenCL::PrintStorageScope( } } -std::string CodeGenOpenCL::CastFromTo(std::string value, Type from, Type target) { +std::string CodeGenOpenCL::CastFromTo(std::string value, DataType from, DataType target) { if (from == target) return value; std::ostringstream os; if (target.lanes() == 1) { @@ -218,7 +218,7 @@ std::string CodeGenOpenCL::CastFromTo(std::string value, Type from, Type target) void CodeGenOpenCL::VisitExpr_(const Broadcast* op, std::ostream& os) { // NOLINT(*) std::string v = PrintExpr(op->value); os << "(("; - PrintType(op->type, os); + PrintType(op->dtype, os); os << ")("; for (int i = 0; i < op->lanes; ++i) { if (i != 0) os << ", "; @@ -232,7 +232,7 @@ void CodeGenOpenCL::VisitExpr_(const Call *op, std::ostream& os) { // NOLINT(*) * add a cast */ if (op->is_intrinsic(intrinsic::tvm_if_then_else)) { os << "("; - PrintType(op->args[2].type(), os); + PrintType(op->args[2].dtype(), os); os << ")"; } CodeGenC::VisitExpr_(op, os); @@ -242,7 +242,7 @@ void CodeGenOpenCL::VisitExpr_(const Select* op, std::ostream& os) { // NOLINT( /* Return type of ternary expression is not always same as its sub-expressions, * add a cast */ os << "("; - PrintType(op->true_value.type(), os); + PrintType(op->true_value.dtype(), os); os << ")"; CodeGenC::VisitExpr_(op, os); } diff --git a/src/codegen/codegen_opencl.h b/src/codegen/codegen_opencl.h index 32f4501276e7..36324eb431ae 100644 --- a/src/codegen/codegen_opencl.h +++ b/src/codegen/codegen_opencl.h @@ -43,16 +43,16 @@ class CodeGenOpenCL final : public CodeGenC { void BindThreadIndex(const IterVar& iv) final; // NOLINT(*) void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*) void PrintStorageSync(const Call* op) final; // NOLINT(*) - void PrintType(Type t, std::ostream& os) final; // NOLINT(*) - std::string GetVecLoad(Type t, const Variable* buffer, + void PrintType(DataType t, std::ostream& os) final; // NOLINT(*) + std::string GetVecLoad(DataType t, const Variable* buffer, Expr base) final; void PrintVecStore(const Variable* buffer, - Type t, Expr base, + DataType t, Expr base, const std::string& value) final; // NOLINT(*) // the address of load/store - void PrintVecAddr(const Variable* buffer, Type t, + void PrintVecAddr(const Variable* buffer, DataType t, Expr base, std::ostream& os); // NOLINT(*) - std::string CastFromTo(std::string value, Type from, Type target); // NOLINT(*) + std::string CastFromTo(std::string value, DataType from, DataType target); // NOLINT(*) // overload visitor void VisitExpr_(const Broadcast* op, std::ostream& os) final; // NOLINT(*) diff --git a/src/codegen/codegen_opengl.cc b/src/codegen/codegen_opengl.cc index 52e04db12480..db14be3b395e 100644 --- a/src/codegen/codegen_opengl.cc +++ b/src/codegen/codegen_opengl.cc @@ -59,7 +59,7 @@ void CodeGenOpenGL::AddFunction(LoweredFunc f) { GetUniqueName("_"); // add to alloc buffer type. for (const auto& kv : f->handle_data_type) { - RegisterHandleType(kv.first.get(), kv.second.type()); + RegisterHandleType(kv.first.get(), kv.second.dtype()); } // Allocate argument names. Store in `var_idmap_`. @@ -93,7 +93,7 @@ void CodeGenOpenGL::AddFunction(LoweredFunc f) { auto type_it = this->handle_data_type_.find(arg.get()); CHECK(type_it != this->handle_data_type_.cend()) << "Cannot find type."; - auto type = Type2TVMType(type_it->second); + DLDataType type = type_it->second; CHECK_EQ(type.lanes, 1) << "Vector type not supported."; switch (type.code) { @@ -129,7 +129,7 @@ void CodeGenOpenGL::AddFunction(LoweredFunc f) { // Format: "uniform {type} {name};" auto arg_name = GetVarID(arg.get()); - auto type = arg.get()->type; + auto type = arg.get()->dtype; this->decl_stream << "uniform "; PrintType(type, this->decl_stream); @@ -207,7 +207,7 @@ std::string CodeGenOpenGL::TexelFetch(const Variable* buffer, Expr index) { // Print a reference expression to a buffer. // Format: texelFetch(buffer, index, 0).r std::string CodeGenOpenGL::GetBufferRef( - Type t, const Variable* buffer, Expr index) { + DataType t, const Variable* buffer, Expr index) { CHECK_EQ(t.lanes(), 1) << "Vector type not supported."; CHECK(HandleTypeMatch(buffer, t)) << "Type mismatch not supported."; @@ -221,7 +221,7 @@ std::string CodeGenOpenGL::GetBufferRef( } } -void CodeGenOpenGL::PrintType(Type t, std::ostream& os) { +void CodeGenOpenGL::PrintType(DataType t, std::ostream& os) { switch (t.code()) { case kDLInt: CHECK_EQ(t.bits(), 32) << "Only support 32-bit int."; @@ -243,17 +243,17 @@ void CodeGenOpenGL::PrintType(Type t, std::ostream& os) { // Codegen for immediate values void CodeGenOpenGL::VisitExpr_(const IntImm* op, std::ostream& os) { - CHECK_EQ(op->type, Int(32)) << "GLSL 3.0 only supports 32-bit ints."; + CHECK_EQ(op->dtype, DataType::Int(32)) << "GLSL 3.0 only supports 32-bit ints."; CodeGenC::VisitExpr_(op, os); } void CodeGenOpenGL::VisitExpr_(const UIntImm* op, std::ostream& os) { - CHECK_EQ(op->type, UInt(32)) << "GLSL 3.0 only supports 32-bit uints."; + CHECK_EQ(op->dtype, DataType::UInt(32)) << "GLSL 3.0 only supports 32-bit uints."; CodeGenC::VisitExpr_(op, os); } void CodeGenOpenGL::VisitExpr_(const FloatImm* op, std::ostream& os) { - CHECK_EQ(op->type, Float(32)) << "GLSL 3.0 only supports 32-bit floats."; + CHECK_EQ(op->dtype, DataType::Float(32)) << "GLSL 3.0 only supports 32-bit floats."; CodeGenC::VisitExpr_(op, os); } @@ -273,7 +273,7 @@ void CodeGenOpenGL::VisitStmt_(const Evaluate* op) { auto value = call->args[1]; // Doesn't support store to vector. - auto type = value.type(); + auto type = value.dtype(); CHECK_EQ(type.lanes(), 1) << "Vectorized store not implemented, type = " << type; diff --git a/src/codegen/codegen_opengl.h b/src/codegen/codegen_opengl.h index d18052f5f46c..46e87a8165c1 100644 --- a/src/codegen/codegen_opengl.h +++ b/src/codegen/codegen_opengl.h @@ -45,8 +45,8 @@ class CodeGenOpenGL final : public CodeGenC { void BindThreadIndex(const IterVar& iv) final; void VisitStmt_(const Store* op) final; std::string TexelFetch(const Variable* buffer, Expr index); - std::string GetBufferRef(Type t, const Variable* buffer, Expr index) final; - void PrintType(Type t, std::ostream& os) final; // NOLINT(*) + std::string GetBufferRef(DataType t, const Variable* buffer, Expr index) final; + void PrintType(DataType t, std::ostream& os) final; // NOLINT(*) // Codegen for immediate values void VisitExpr_(const IntImm* op, std::ostream& os) final; // NOLINT(*) diff --git a/src/codegen/codegen_source_base.cc b/src/codegen/codegen_source_base.cc index 9a9f525d40f1..7c4ed5b91c8b 100644 --- a/src/codegen/codegen_source_base.cc +++ b/src/codegen/codegen_source_base.cc @@ -52,7 +52,7 @@ std::string CodeGenSourceBase::GetUniqueName(std::string prefix) { return prefix; } -std::string CodeGenSourceBase::SSAGetID(std::string src, Type t) { +std::string CodeGenSourceBase::SSAGetID(std::string src, DataType t) { if (name_alloc_map_.count(src)) return src; auto it = ssa_assign_map_.find(src); if (it != ssa_assign_map_.end()) { diff --git a/src/codegen/codegen_source_base.h b/src/codegen/codegen_source_base.h index e0608c6afbde..7fd0eef98a90 100644 --- a/src/codegen/codegen_source_base.h +++ b/src/codegen/codegen_source_base.h @@ -79,7 +79,7 @@ class CodeGenSourceBase { * \param src The source expression * \param t The type of the expression. */ - std::string SSAGetID(std::string src, Type t); + std::string SSAGetID(std::string src, DataType t); /*! * \brief get a unique name with the corresponding prefix * \param prefix The prefix of the name @@ -103,7 +103,7 @@ class CodeGenSourceBase { * \param t The type of target. */ virtual void PrintSSAAssign( - const std::string& target, const std::string& src, Type t) = 0; + const std::string& target, const std::string& src, DataType t) = 0; /*! \brief the declaration stream */ std::ostringstream decl_stream; diff --git a/src/codegen/codegen_vhls.cc b/src/codegen/codegen_vhls.cc index 84329f90ddfc..40550d9f9916 100644 --- a/src/codegen/codegen_vhls.cc +++ b/src/codegen/codegen_vhls.cc @@ -37,7 +37,7 @@ void CodeGenVivadoHLS::Init(bool output_ssa) { this->stream << "#include \n\n"; } -void CodeGenVivadoHLS::PrintType(Type t, std::ostream& os) { +void CodeGenVivadoHLS::PrintType(DataType t, std::ostream& os) { if (t.is_uint()) { switch (t.bits()) { case 8: @@ -78,7 +78,7 @@ void CodeGenVivadoHLS::PreFunctionBody(LoweredFunc f) { for (size_t i = 0; i < f->args.size(); ++i) { Var v = f->args[i]; std::string vid = GetVarID(v.get()); - if (v.type().is_handle()) { + if (v.dtype().is_handle()) { this->stream << "#pragma HLS INTERFACE m_axi port=" << vid << " offset=slave bundle=gmem\n"; } this->stream << "#pragma HLS INTERFACE s_axilite port=" << vid << " bundle=control\n"; @@ -100,8 +100,8 @@ inline void PrintBinaryExpr(const T* op, void CodeGenVivadoHLS::VisitExpr_(const Min *op, std::ostream& os) { // NOLINT(*) const char *opstr = "std::min"; - if (op->type.is_float()) { - switch (op->type.bits()) { + if (op->dtype.is_float()) { + switch (op->dtype.bits()) { case 32: opstr = "fminf"; break; case 64: @@ -114,8 +114,8 @@ void CodeGenVivadoHLS::VisitExpr_(const Min *op, std::ostream& os) { // NOLINT( void CodeGenVivadoHLS::VisitExpr_(const Max *op, std::ostream& os) { // NOLINT(*) const char *opstr = "std::max"; - if (op->type.is_float()) { - switch (op->type.bits()) { + if (op->dtype.is_float()) { + switch (op->dtype.bits()) { case 32: opstr = "fmaxf"; break; case 64: diff --git a/src/codegen/codegen_vhls.h b/src/codegen/codegen_vhls.h index 4ec7b105385d..e678edb05198 100644 --- a/src/codegen/codegen_vhls.h +++ b/src/codegen/codegen_vhls.h @@ -35,7 +35,7 @@ namespace codegen { class CodeGenVivadoHLS final : public CodeGenC { public: void Init(bool output_ssa); - void PrintType(Type t, std::ostream& os); + void PrintType(DataType t, std::ostream& os); void AddFunction(LoweredFunc f); void PreFunctionBody(LoweredFunc f); void VisitExpr_(const Min *op, std::ostream& os); diff --git a/src/codegen/intrin_rule.cc b/src/codegen/intrin_rule.cc index f765c0095ce1..219b485387d5 100644 --- a/src/codegen/intrin_rule.cc +++ b/src/codegen/intrin_rule.cc @@ -57,7 +57,7 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.rsqrt") const Call* call = e.as(); CHECK(call != nullptr); - auto one = make_const(call->args[0].type(), 1); + auto one = make_const(call->args[0].dtype(), 1); *rv = one / sqrt(call->args[0]); }); @@ -70,7 +70,7 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sigmoid") const Call* call = e.as(); CHECK(call != nullptr); - auto one = make_const(call->args[0].type(), 1); + auto one = make_const(call->args[0].dtype(), 1); *rv = one / (one + exp(-call->args[0])); }); diff --git a/src/codegen/intrin_rule.h b/src/codegen/intrin_rule.h index 9f3bd793dd39..581387da69cf 100644 --- a/src/codegen/intrin_rule.h +++ b/src/codegen/intrin_rule.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -37,10 +37,10 @@ using namespace ir; // Add float suffix to the intrinsics struct FloatSuffix { - std::string operator()(Type t, std::string name) const { - if (t == Float(32)) { + std::string operator()(DataType t, std::string name) const { + if (t == DataType::Float(32)) { return name + 'f'; - } else if (t == Float(64)) { + } else if (t == DataType::Float(64)) { return name; } else { return ""; @@ -50,7 +50,7 @@ struct FloatSuffix { // Return the intrinsic name struct Direct { - std::string operator()(Type t, std::string name) const { + std::string operator()(DataType t, std::string name) const { return name; } }; @@ -61,10 +61,10 @@ inline void DispatchExtern(const TVMArgs& args, TVMRetValue* rv) { Expr e = args[0]; const Call* call = e.as(); CHECK(call != nullptr); - std::string name = T()(call->type, call->name); + std::string name = T()(call->dtype, call->name); if (name.length() != 0) { *rv = Call::make( - call->type, name, call->args, Call::PureExtern); + call->dtype, name, call->args, Call::PureExtern); } else { *rv = e; } diff --git a/src/codegen/intrin_rule_cuda.cc b/src/codegen/intrin_rule_cuda.cc index 4fed20fce51d..3f6bc7ba1d06 100644 --- a/src/codegen/intrin_rule_cuda.cc +++ b/src/codegen/intrin_rule_cuda.cc @@ -28,7 +28,7 @@ namespace codegen { namespace intrin { // Add float suffix to the intrinsics, CUDA fast math. struct CUDAMath { - std::string operator()(Type t, std::string name) const { + std::string operator()(DataType t, std::string name) const { if (t.lanes() == 1) { if (t.is_float()) { switch (t.bits()) { @@ -44,7 +44,7 @@ struct CUDAMath { }; struct CUDAFastMath : public CUDAMath { - std::string operator()(Type t, std::string name) const { + std::string operator()(DataType t, std::string name) const { if (t.lanes() == 1 && t.is_float() && t.bits() == 32) { return "__" + name + 'f'; } else { @@ -55,7 +55,7 @@ struct CUDAFastMath : public CUDAMath { }; struct CUDAPopcount { - std::string operator()(Type t, std::string name) const { + std::string operator()(DataType t, std::string name) const { if (t.lanes() == 1 && t.is_uint()) { switch (t.bits()) { case 32: return "__popc"; @@ -68,7 +68,7 @@ struct CUDAPopcount { }; struct CUDAShuffle { - std::string operator()(Type t, std::string name) const { + std::string operator()(DataType t, std::string name) const { return "__shfl"; } }; diff --git a/src/codegen/intrin_rule_opencl.cc b/src/codegen/intrin_rule_opencl.cc index 246747cc361d..4b1d4033c16f 100644 --- a/src/codegen/intrin_rule_opencl.cc +++ b/src/codegen/intrin_rule_opencl.cc @@ -66,7 +66,7 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.fmod") // There is no warp shuffle instruction in standard OpenCL // When shuffle is used, we assume it is intel's shuffle extension struct IntelShuffle { - std::string operator()(Type t, std::string name) const { + std::string operator()(DataType t, std::string name) const { return "intel_sub_group_shuffle"; } }; diff --git a/src/codegen/llvm/codegen_amdgpu.cc b/src/codegen/llvm/codegen_amdgpu.cc index 491a304983c6..f57a3ca869ef 100644 --- a/src/codegen/llvm/codegen_amdgpu.cc +++ b/src/codegen/llvm/codegen_amdgpu.cc @@ -82,7 +82,7 @@ class CodeGenAMDGPU : public CodeGenLLVM { << "Can only handle constant size stack allocation in GPU"; StorageInfo& info = alloc_storage_info_[op->buffer_var.get()]; if (constant_size % 4 == 0 && info.alignment == 0) { - info.alignment = GetTempAllocaAlignment(op->type, constant_size); + info.alignment = GetTempAllocaAlignment(op->dtype, constant_size); } // maximum necessary alignment in the AMD devices if (info.alignment > 16) { @@ -93,7 +93,7 @@ class CodeGenAMDGPU : public CodeGenLLVM { // TODO(tqchen): for higher version of LLVM, local address space can be set. llvm::AllocaInst* alloca = WithFunctionEntry([&]() { return builder_->CreateAlloca( - LLVMType(op->type), ConstInt32(constant_size)); + LLVMType(op->dtype), ConstInt32(constant_size)); }); if (alloca->getAlignment() < static_cast(info.alignment)) { #if TVM_LLVM_VERSION >= 100 @@ -108,7 +108,7 @@ class CodeGenAMDGPU : public CodeGenLLVM { << "Can only allocate shared or local memory inside kernel"; // Shared memory: address space == 3 const unsigned shared_address_space = 3; - llvm::Type* type = llvm::ArrayType::get(LLVMType(op->type), constant_size); + llvm::Type* type = llvm::ArrayType::get(LLVMType(op->dtype), constant_size); // Allocate shared memory in global, address_space = 3 llvm::GlobalVariable *global = new llvm::GlobalVariable( *module_, type, false, llvm::GlobalValue::PrivateLinkage, 0, ".shared", @@ -122,7 +122,7 @@ class CodeGenAMDGPU : public CodeGenLLVM { } } buf = builder_->CreatePointerCast( - buf, LLVMType(op->type)->getPointerTo( + buf, LLVMType(op->dtype)->getPointerTo( buf->getType()->getPointerAddressSpace())); CHECK(!var_map_.count(op->buffer_var.get())); var_map_[op->buffer_var.get()] = buf; diff --git a/src/codegen/llvm/codegen_arm.cc b/src/codegen/llvm/codegen_arm.cc index 9b21455605c3..4c092dfe377a 100644 --- a/src/codegen/llvm/codegen_arm.cc +++ b/src/codegen/llvm/codegen_arm.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -61,14 +61,14 @@ Expr CodeGenARM::ARMPopcount(const Call *call) { ::llvm::Intrinsic::ID vpaddlu_id = ::llvm::Intrinsic::arm_neon_vpaddlu; // Fallback to default llvm lowering rule if input type not a full vector or half vector length - int total_size = call->type.bits() * call->type.lanes(); - if (!call->type.is_vector() || call->type.bits() == 8 || + int total_size = call->dtype.bits() * call->dtype.lanes(); + if (!call->dtype.is_vector() || call->dtype.bits() == 8 || (total_size != 128 && total_size != 64)) { Array vcnt_args; - vcnt_args.push_back(ir::UIntImm::make(UInt(32), ctpop_id)); - vcnt_args.push_back(ir::UIntImm::make(UInt(32), 1)); + vcnt_args.push_back(ir::UIntImm::make(DataType::UInt(32), ctpop_id)); + vcnt_args.push_back(ir::UIntImm::make(DataType::UInt(32), 1)); vcnt_args.push_back(e); - return ir::Call::make(call->type, "llvm_intrin", vcnt_args, Call::PureIntrinsic); + return ir::Call::make(call->dtype, "llvm_intrin", vcnt_args, Call::PureIntrinsic); } // Popcount lowering rule: @@ -77,9 +77,12 @@ Expr CodeGenARM::ARMPopcount(const Call *call) { // to return back to original input type // Dvisions are always divisible (number of bits = 64 or 128) - Type uint8_type = Type(e.type().code(), 8, e.type().bits() * e.type().lanes() / 8); - Type uint16_type = Type(uint8_type.code(), 16, uint8_type.bits() * uint8_type.lanes() / 16); - Type uint32_type = Type(uint16_type.code(), 32, uint8_type.bits() * uint8_type.lanes() / 32); + DataType uint8_type = DataType( + e.dtype().code(), 8, e.dtype().bits() * e.dtype().lanes() / 8); + DataType uint16_type = DataType( + uint8_type.code(), 16, uint8_type.bits() * uint8_type.lanes() / 16); + DataType uint32_type = DataType( + uint16_type.code(), 32, uint8_type.bits() * uint8_type.lanes() / 32); // Interpret input as vector of 8bit values Expr input8 = reinterpret(uint8_type, e); @@ -87,37 +90,37 @@ Expr CodeGenARM::ARMPopcount(const Call *call) { const Call* c0 = input8.as(); CHECK(c0 != nullptr); Array vcnt8_args; - vcnt8_args.push_back(ir::UIntImm::make(UInt(32), ctpop_id)); - vcnt8_args.push_back(ir::UIntImm::make(UInt(32), 1)); + vcnt8_args.push_back(ir::UIntImm::make(DataType::UInt(32), ctpop_id)); + vcnt8_args.push_back(ir::UIntImm::make(DataType::UInt(32), 1)); vcnt8_args.push_back(input8); Expr vcnt8 = ir::Call::make(uint8_type, "llvm_intrin", vcnt8_args, Call::PureIntrinsic); // Accumulation 8->16bit Array vcnt16_args; - vcnt16_args.push_back(ir::UIntImm::make(UInt(32), vpaddlu_id)); - vcnt16_args.push_back(ir::UIntImm::make(UInt(32), 1)); + vcnt16_args.push_back(ir::UIntImm::make(DataType::UInt(32), vpaddlu_id)); + vcnt16_args.push_back(ir::UIntImm::make(DataType::UInt(32), 1)); vcnt16_args.push_back(vcnt8); Expr vcnt16 = ir::Call::make(uint16_type, "llvm_intrin", vcnt16_args, Call::PureIntrinsic); - if (call->type.bits() == 16) { + if (call->dtype.bits() == 16) { return vcnt16; } // Accumulation 16->32bit Array vcnt32_args; - vcnt32_args.push_back(ir::UIntImm::make(UInt(32), vpaddlu_id)); - vcnt32_args.push_back(ir::UIntImm::make(UInt(32), 1)); + vcnt32_args.push_back(ir::UIntImm::make(DataType::UInt(32), vpaddlu_id)); + vcnt32_args.push_back(ir::UIntImm::make(DataType::UInt(32), 1)); vcnt32_args.push_back(vcnt16); Expr vcnt32 = ir::Call::make(uint32_type, "llvm_intrin", vcnt32_args, Call::PureIntrinsic); - if (call->type.bits() == 32) { + if (call->dtype.bits() == 32) { return vcnt32; } // Accumulation 32->64bit Array vcnt64_args; - vcnt64_args.push_back(ir::UIntImm::make(UInt(32), vpaddlu_id)); - vcnt64_args.push_back(ir::UIntImm::make(UInt(32), 1)); + vcnt64_args.push_back(ir::UIntImm::make(DataType::UInt(32), vpaddlu_id)); + vcnt64_args.push_back(ir::UIntImm::make(DataType::UInt(32), 1)); vcnt64_args.push_back(vcnt32); - return ir::Call::make(call->type, "llvm_intrin", vcnt64_args, Call::PureIntrinsic); + return ir::Call::make(call->dtype, "llvm_intrin", vcnt64_args, Call::PureIntrinsic); } TVM_REGISTER_GLOBAL("tvm.codegen.llvm.target_arm") diff --git a/src/codegen/llvm/codegen_cpu.cc b/src/codegen/llvm/codegen_cpu.cc index 0ba0c584a590..9f1a2926f002 100644 --- a/src/codegen/llvm/codegen_cpu.cc +++ b/src/codegen/llvm/codegen_cpu.cc @@ -43,7 +43,7 @@ void CodeGenCPU::Init(const std::string& module_name, func_handle_map_.clear(); export_system_symbols_.clear(); // TVM runtime types - t_tvm_shape_index_ = llvm::Type::getIntNTy(*ctx, TVMShapeIndexType().bits()); + t_tvm_shape_index_ = llvm::Type::getIntNTy(*ctx, DataType::ShapeIndex().bits()); t_tvm_context_ = llvm::StructType::create({t_int_, t_int_}); t_tvm_type_ = llvm::StructType::create({t_int8_, t_int8_, t_int16_}); t_tvm_func_handle_ = t_void_p_; @@ -252,7 +252,7 @@ std::unique_ptr CodeGenCPU::Finish() { return CodeGenLLVM::Finish(); } llvm::Value* CodeGenCPU::CreateStructRefPtr( - Type t, llvm::Value* buf, llvm::Value* index, int kind) { + DataType t, llvm::Value* buf, llvm::Value* index, int kind) { if (kind < intrinsic::kArrKindBound_) { if (buf->getType() == t_void_p_) { buf = builder_->CreatePointerCast(buf, t_tvm_array_->getPointerTo()); @@ -329,7 +329,7 @@ llvm::Value* CodeGenCPU::CreateCallExtern(const Call* op) { arg_types.push_back(v->getType()); } llvm::FunctionType* ftype = llvm::FunctionType::get( - LLVMType(op->type), arg_types, false); + LLVMType(op->dtype), arg_types, false); // Check if it is available in global function table as injected function. auto it = gv_func_map_.find(op->name); if (it != gv_func_map_.end()) { @@ -448,7 +448,7 @@ void CodeGenCPU::CreateComputeScope(const AttrStmt* op) { llvm::Argument* v = &(*it); const Var& var = vargs[idx]; new_vmap[var.get()] = v; - if (var.type().is_handle() && !alias_var_set_.count(var.get())) { + if (var.dtype().is_handle() && !alias_var_set_.count(var.get())) { // set non alias. #if TVM_LLVM_VERSION >= 50 fcompute->addParamAttr(idx, llvm::Attribute::NoAlias); @@ -532,8 +532,8 @@ void CodeGenCPU::CreateParallelLaunch(const Stmt& body, int num_task) { UnpackClosureData(cdata, vfields, &new_vmap); // setup parallel env ParallelEnv par_env; - par_env.task_id = Var("task_id", Int(32)); - par_env.num_task = Var("num_task", Int(32)); + par_env.task_id = Var("task_id", DataType::Int(32)); + par_env.num_task = Var("num_task", DataType::Int(32)); new_vmap[par_env.task_id.get()] = task_id; new_vmap[par_env.num_task.get()] = builder_->CreateLoad( builder_->CreateInBoundsGEP( @@ -670,7 +670,7 @@ llvm::Value* CodeGenCPU::GetPackedFuncHandle(const std::string& fname) { llvm::BasicBlock * CodeGenCPU::MakeCallPacked(const Array &args, llvm::Value **rvalue, - llvm::Value **ret_tcode, const Type &r_type, + llvm::Value **ret_tcode, const DataType &r_type, const int64_t begin, const int64_t end) { using llvm::BasicBlock; std::string func_name = args[0].as()->value; @@ -684,15 +684,15 @@ CodeGenCPU::MakeCallPacked(const Array &args, llvm::Value **rvalue, builder_->CreatePointerCast(stack_value, t_tvm_value_->getPointerTo()), ConstInt32(begin)); llvm::Value *arg_tcode = - CreateBufferPtr(Int(32), stack_tcode, ConstInt32(begin)); + CreateBufferPtr(DataType::Int(32), stack_tcode, ConstInt32(begin)); llvm::Value *ret_value = builder_->CreateInBoundsGEP( builder_->CreatePointerCast(stack_value, t_tvm_value_->getPointerTo()), ConstInt32(end)); - *ret_tcode = CreateBufferPtr(Int(32), stack_tcode, ConstInt32(end)); + *ret_tcode = CreateBufferPtr(DataType::Int(32), stack_tcode, ConstInt32(end)); BasicBlock *end_block = CheckCallSuccess(builder_->CreateCall( RuntimeTVMFuncCall(), {handle, arg_value, arg_tcode, ConstInt32(nargs), ret_value, *ret_tcode})); - Type r_api_type = ir::APIType(r_type); + DataType r_api_type = ir::APIType(r_type); *rvalue = builder_->CreateAlignedLoad( builder_->CreatePointerCast(ret_value, LLVMType(r_api_type)->getPointerTo()), @@ -705,7 +705,7 @@ llvm::Value *CodeGenCPU::CreateCallPacked(const Call *op) { CHECK_EQ(op->args.size(), 5U); llvm::Value *rvalue = nullptr; llvm::Value *ret_tcode = nullptr; - MakeCallPacked(op->args, &rvalue, &ret_tcode, op->type, + MakeCallPacked(op->args, &rvalue, &ret_tcode, op->dtype, op->args[3].as()->value, op->args[4].as()->value); return rvalue; @@ -717,7 +717,7 @@ llvm::Value *CodeGenCPU::CreateCallTracePacked(const Call *op) { llvm::Value *rvalue = nullptr; llvm::Value *ret_tcode = nullptr; BasicBlock *end_block = MakeCallPacked( - op->args, &rvalue, &ret_tcode, op->type, op->args[3].as()->value, + op->args, &rvalue, &ret_tcode, op->dtype, op->args[3].as()->value, op->args[4].as()->value); // Get traced value. llvm::Value *traced_value = MakeValue(op->args[5]); @@ -800,7 +800,7 @@ llvm::Value* CodeGenCPU::CreateIntrinsic(const Call* op) { CHECK_EQ(op->args.size(), 3U); int kind = op->args[2].as()->value; llvm::Value* ref = this->CreateStructRefPtr( - op->type, MakeValue(op->args[0]), + op->dtype, MakeValue(op->args[0]), MakeValue(op->args[1]), kind); if (kind == intrinsic::kArrAddr) { return builder_->CreatePointerCast(ref, t_void_p_); @@ -812,7 +812,7 @@ llvm::Value* CodeGenCPU::CreateIntrinsic(const Call* op) { int kind = op->args[2].as()->value; llvm::Value* value = MakeValue(op->args[3]); llvm::Value* ref = this->CreateStructRefPtr( - op->args[3].type(), MakeValue(op->args[0]), + op->args[3].dtype(), MakeValue(op->args[0]), MakeValue(op->args[1]), kind); CHECK(kind != intrinsic::kArrAddr); if (value->getType()->isPointerTy()) { @@ -922,7 +922,7 @@ void CodeGenCPU::VisitStmt_(const For* op) { CHECK(parallel_env_.task_id.defined()); CHECK(parallel_env_.num_task.defined()); CHECK(parallel_env_.penv != nullptr); - Type t = op->extent.type(); + DataType t = op->extent.dtype(); Expr num_task = cast(t, parallel_env_.num_task); Expr task_id = cast(t, parallel_env_.task_id); CHECK(!parallel_env_.in_parallel_loop) diff --git a/src/codegen/llvm/codegen_cpu.h b/src/codegen/llvm/codegen_cpu.h index 52e6f6c6ef90..b9e127557e1a 100644 --- a/src/codegen/llvm/codegen_cpu.h +++ b/src/codegen/llvm/codegen_cpu.h @@ -96,14 +96,14 @@ class CodeGenCPU : public CodeGenLLVM { llvm::Value* CreateStaticHandle(); llvm::Value* GetPackedFuncHandle(const std::string& str); llvm::Value* PackClosureData(const Array& fields, uint64_t *num_bytes); - llvm::Value* CreateStructRefPtr(Type t, llvm::Value* buffer, llvm::Value* index, int kind); + llvm::Value* CreateStructRefPtr(DataType t, llvm::Value* buffer, llvm::Value* index, int kind); void UnpackClosureData(llvm::Value*cdata, const Array& fields, std::unordered_map* vmap); // Make packed call. llvm::BasicBlock *MakeCallPacked(const Array &args, llvm::Value **rvalue, - llvm::Value **ret_tcode, const Type &r_type, + llvm::Value **ret_tcode, const DataType &r_type, const int64_t begin, const int64_t end); // create call into tvm packed function. llvm::Value* CreateCallPacked(const Call* op); diff --git a/src/codegen/llvm/codegen_llvm.cc b/src/codegen/llvm/codegen_llvm.cc index 2cff88b0bbf4..94ad8b76c9c9 100644 --- a/src/codegen/llvm/codegen_llvm.cc +++ b/src/codegen/llvm/codegen_llvm.cc @@ -115,11 +115,11 @@ void CodeGenLLVM::AddFunctionInternal(const LoweredFunc& f, bool ret_void) { std::vector arg_types; is_restricted_ = f->is_restricted; for (Var arg : f->args) { - Type t = arg.type(); + DataType t = arg.dtype(); if (t.is_handle()) { auto it = f->handle_data_type.find(arg); if (it != f->handle_data_type.end()) { - arg_types.push_back(LLVMType((*it).second.type()) + arg_types.push_back(LLVMType((*it).second.dtype()) ->getPointerTo(GetGlobalAddressSpace())); } else { arg_types.push_back(t_int8_->getPointerTo(GetGlobalAddressSpace())); @@ -128,7 +128,7 @@ void CodeGenLLVM::AddFunctionInternal(const LoweredFunc& f, bool ret_void) { alias_var_set_.insert(arg.get()); } } else { - arg_types.push_back(LLVMType(arg.type())); + arg_types.push_back(LLVMType(arg.dtype())); } } llvm::FunctionType* ftype = llvm::FunctionType::get( @@ -147,7 +147,7 @@ void CodeGenLLVM::AddFunctionInternal(const LoweredFunc& f, bool ret_void) { const Var& var = f->args[i]; var_map_[var.get()] = v; if (is_restricted_) { - if (var.type().is_handle() && !alias_var_set_.count(var.get())) { + if (var.dtype().is_handle() && !alias_var_set_.count(var.get())) { // set non alias. #if TVM_LLVM_VERSION >= 50 function_->addParamAttr(i, llvm::Attribute::NoAlias); @@ -302,7 +302,7 @@ unsigned CodeGenLLVM::GetGlobalAddressSpace() { return 0; } -llvm::Type* CodeGenLLVM::LLVMType(const Type& t) const { +llvm::Type* CodeGenLLVM::LLVMType(const DataType& t) const { if (t.is_handle()) { CHECK_EQ(t.lanes(), 1); return t_void_p_; @@ -335,7 +335,7 @@ llvm::Type* CodeGenLLVM::LLVMType(const Type& t) const { void CodeGenLLVM::AddAliasInfo(llvm::Instruction* inst, const Variable* buffer, Expr index, - Type type) { + DataType type) { if (alias_var_set_.count(buffer) != 0) { // Mark all possibly aliased pointer as same type. llvm::MDNode* meta = md_tbaa_alias_set_; @@ -387,7 +387,7 @@ void CodeGenLLVM::AddAliasInfo(llvm::Instruction* inst, md_builder_->createTBAAStructTagNode(meta, meta, 0)); } -void CodeGenLLVM::GetAlignment(Type t, +void CodeGenLLVM::GetAlignment(DataType t, const Variable* buf_var, const Expr& index, int* p_alignment, @@ -474,7 +474,7 @@ llvm::Value* CodeGenLLVM::CreateVecFlip(llvm::Value* vec) { } llvm::Value* CodeGenLLVM::CreateVecPad(llvm::Value* vec, int target_lanes) { - llvm::Value* mask = llvm::UndefValue::get(LLVMType(Int(32, target_lanes))); + llvm::Value* mask = llvm::UndefValue::get(LLVMType(DataType::Int(32, target_lanes))); int num_elems = static_cast(vec->getType()->getVectorNumElements()); if (num_elems == target_lanes) return vec; CHECK_LT(num_elems, target_lanes); @@ -542,19 +542,19 @@ void CodeGenLLVM::CreateSerialFor(llvm::Value* begin, loop_value->addIncoming(begin, pre_block); CHECK(!var_map_.count(loop_var.get())); var_map_[loop_var.get()] = loop_value; - builder_->CreateCondBr(CreateLT(loop_var.type(), loop_value, end), + builder_->CreateCondBr(CreateLT(loop_var.dtype(), loop_value, end), for_body, for_end, md_very_likely_branch_); builder_->SetInsertPoint(for_body); this->VisitStmt(body); var_map_.erase(loop_var.get()); - llvm::Value* loop_next = CreateAdd(loop_var.type(), loop_value, stride); + llvm::Value* loop_next = CreateAdd(loop_var.dtype(), loop_value, stride); loop_value->addIncoming(loop_next, builder_->GetInsertBlock()); builder_->CreateBr(for_begin); builder_->SetInsertPoint(for_end); } // cast operatpr -llvm::Value* CodeGenLLVM::CreateCast(Type from, Type to, llvm::Value* value) { +llvm::Value* CodeGenLLVM::CreateCast(DataType from, DataType to, llvm::Value* value) { llvm::Type * target = LLVMType(to); if (value->getType() == target) return value; if (to.is_handle()) { @@ -609,7 +609,7 @@ llvm::Value* CodeGenLLVM::GetConstString(const std::string& str) { } llvm::Value* CodeGenLLVM::CreateBufferPtr( - Type t, llvm::Value* buffer, llvm::Value* index) { + DataType t, llvm::Value* buffer, llvm::Value* index) { CHECK_EQ(t.lanes(), 1); llvm::PointerType* btype = llvm::dyn_cast(buffer->getType()); CHECK(btype != nullptr); @@ -622,7 +622,7 @@ llvm::Value* CodeGenLLVM::CreateBufferPtr( } llvm::Value* CodeGenLLVM::CreateBufferVecPtr( - Type t, llvm::Value* buffer, llvm::Value* index) { + DataType t, llvm::Value* buffer, llvm::Value* index) { CHECK_GT(t.lanes(), 1); llvm::PointerType* btype = llvm::dyn_cast(buffer->getType()); CHECK(btype != nullptr); @@ -647,7 +647,7 @@ llvm::Value* CodeGenLLVM::CreateCallExtern(const Call* op) { arg_type.push_back(arg_value.back()->getType()); } llvm::FunctionType* ftype = llvm::FunctionType::get( - LLVMType(op->type), arg_type, false); + LLVMType(op->dtype), arg_type, false); llvm::Function* f = module_->getFunction(op->name); if (f == nullptr) { f = llvm::Function::Create( @@ -674,7 +674,7 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const Call* op) { sig_type.push_back(arg_value.back()->getType()); } } - llvm::Type *return_type = LLVMType(op->type); + llvm::Type *return_type = LLVMType(op->dtype); if (sig_type.size() > 0 && return_type != sig_type[0]) { sig_type.insert(sig_type.begin(), return_type); } @@ -692,7 +692,7 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const Call* op) { } else if (op->is_intrinsic(Call::shift_left)) { return builder_->CreateShl(MakeValue(op->args[0]), MakeValue(op->args[1])); } else if (op->is_intrinsic(Call::shift_right)) { - if (op->args[0].type().is_int()) { + if (op->args[0].dtype().is_int()) { return builder_->CreateAShr(MakeValue(op->args[0]), MakeValue(op->args[1])); } else { return builder_->CreateLShr(MakeValue(op->args[0]), MakeValue(op->args[1])); @@ -707,13 +707,13 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const Call* op) { unsigned addrspace; if (!r) { ptr = CreateBufferPtr( - l->type, MakeValue(l->buffer_var), MakeValue(l->index)); + l->dtype, MakeValue(l->buffer_var), MakeValue(l->index)); addrspace = llvm::dyn_cast( ptr->getType())->getAddressSpace(); } else { - Expr index = r->base / make_const(Int(32), r->lanes); + Expr index = r->base / make_const(DataType::Int(32), r->lanes); ptr = CreateBufferVecPtr( - l->type, MakeValue(l->buffer_var), MakeValue(index)); + l->dtype, MakeValue(l->buffer_var), MakeValue(index)); addrspace = llvm::dyn_cast( ptr->getType())->getAddressSpace(); } @@ -723,7 +723,7 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const Call* op) { } else if (op->is_intrinsic(intrinsic::tvm_handle_is_null)) { return builder_->CreateIsNull(MakeValue(op->args[0])); } else if (op->is_intrinsic(intrinsic::tvm_if_then_else)) { - CHECK_EQ(op->args[0].type().lanes(), 1) + CHECK_EQ(op->args[0].dtype().lanes(), 1) << "if_then_else can only take scalar condition"; using llvm::BasicBlock; BasicBlock* then_block = BasicBlock::Create( @@ -747,7 +747,7 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const Call* op) { value->addIncoming(else_value, else_value_block); return value; } else if (op->is_intrinsic(Call::reinterpret)) { - llvm::Type * target = LLVMType(op->type); + llvm::Type * target = LLVMType(op->dtype); return builder_->CreateBitCast(MakeValue(op->args[0]), target); } else if (op->is_intrinsic(Call::isnan)) { // TODO(hgt312): set fast math flag @@ -779,13 +779,13 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const Call* op) { void CodeGenLLVM::Scalarize(const Expr& e, std::function f) { if (const Ramp* ramp = e.as()) { - for (int i = 0; i < ramp->type.lanes(); ++i) { + for (int i = 0; i < ramp->dtype.lanes(); ++i) { Expr offset = ramp->base + (ramp->stride * i); f(i, MakeValue(offset)); } } else { llvm::Value* value = MakeValue(e); - for (int i = 0; i < e.type().lanes(); ++i) { + for (int i = 0; i < e.dtype().lanes(); ++i) { f(i, builder_->CreateExtractElement(value, i)); } } @@ -798,18 +798,18 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const Variable* op) { } llvm::Value* CodeGenLLVM::VisitExpr_(const Cast* op) { - return CreateCast(op->value.type(), op->type, MakeValue(op->value)); + return CreateCast(op->value.dtype(), op->dtype, MakeValue(op->value)); } llvm::Value* CodeGenLLVM::VisitExpr_(const IntImm* op) { - return llvm::ConstantInt::getSigned(LLVMType(op->type), op->value); + return llvm::ConstantInt::getSigned(LLVMType(op->dtype), op->value); } llvm::Value* CodeGenLLVM::VisitExpr_(const UIntImm* op) { - return llvm::ConstantInt::get(LLVMType(op->type), op->value); + return llvm::ConstantInt::get(LLVMType(op->dtype), op->value); } llvm::Value* CodeGenLLVM::VisitExpr_(const FloatImm* op) { - return llvm::ConstantFP::get(LLVMType(op->type), op->value); + return llvm::ConstantFP::get(LLVMType(op->dtype), op->value); } llvm::Value* CodeGenLLVM::VisitExpr_(const StringImm* op) { @@ -818,7 +818,7 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const StringImm* op) { #define DEFINE_CODEGEN_BINARY_OP(Op) \ llvm::Value* CodeGenLLVM::Create ## Op( \ - Type t, llvm::Value* a, llvm::Value *b) { \ + DataType t, llvm::Value* a, llvm::Value *b) { \ if (t.is_int()) { \ if (t.bits() >= 32) { \ return builder_->CreateNSW ## Op (a, b); \ @@ -837,7 +837,7 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const StringImm* op) { } \ } \ llvm::Value* CodeGenLLVM::VisitExpr_(const Op* op) { \ - return Create ## Op(op->type, MakeValue(op->a), MakeValue(op->b)); \ + return Create ## Op(op->dtype, MakeValue(op->a), MakeValue(op->b)); \ } DEFINE_CODEGEN_BINARY_OP(Add); @@ -846,7 +846,7 @@ DEFINE_CODEGEN_BINARY_OP(Mul); #define DEFINE_CODEGEN_CMP_OP(Op) \ llvm::Value* CodeGenLLVM::Create ## Op( \ - Type t, llvm::Value* a, llvm::Value* b) { \ + DataType t, llvm::Value* a, llvm::Value* b) { \ if (t.is_int()) { \ return builder_->CreateICmpS ## Op (a, b); \ } else if (t.is_uint()) { \ @@ -857,7 +857,7 @@ DEFINE_CODEGEN_BINARY_OP(Mul); } \ } \ llvm::Value* CodeGenLLVM::VisitExpr_(const Op* op) { \ - return Create ## Op(op->a.type(), MakeValue(op->a), MakeValue(op->b)); \ + return Create ## Op(op->a.dtype(), MakeValue(op->a), MakeValue(op->b)); \ } DEFINE_CODEGEN_CMP_OP(LT); @@ -868,12 +868,12 @@ DEFINE_CODEGEN_CMP_OP(GE); llvm::Value* CodeGenLLVM::VisitExpr_(const Div* op) { llvm::Value* a = MakeValue(op->a); llvm::Value* b = MakeValue(op->b); - if (op->type.is_int()) { + if (op->dtype.is_int()) { return builder_->CreateSDiv(a, b); - } else if (op->type.is_uint()) { + } else if (op->dtype.is_uint()) { return builder_->CreateUDiv(a, b); } else { - CHECK(op->type.is_float()); + CHECK(op->dtype.is_float()); return builder_->CreateFDiv(a, b); } } @@ -881,12 +881,12 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const Div* op) { llvm::Value* CodeGenLLVM::VisitExpr_(const Mod* op) { llvm::Value* a = MakeValue(op->a); llvm::Value* b = MakeValue(op->b); - if (op->type.is_int()) { + if (op->dtype.is_int()) { return builder_->CreateSRem(a, b); - } else if (op->type.is_uint()) { + } else if (op->dtype.is_uint()) { return builder_->CreateURem(a, b); } else { - CHECK(op->type.is_float()); + CHECK(op->dtype.is_float()); return builder_->CreateFRem(a, b); } } @@ -894,19 +894,19 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const Mod* op) { llvm::Value* CodeGenLLVM::VisitExpr_(const Min* op) { llvm::Value* a = MakeValue(op->a); llvm::Value* b = MakeValue(op->b); - return builder_->CreateSelect(CreateLT(op->a.type(), a, b), a, b); + return builder_->CreateSelect(CreateLT(op->a.dtype(), a, b), a, b); } llvm::Value* CodeGenLLVM::VisitExpr_(const Max* op) { llvm::Value* a = MakeValue(op->a); llvm::Value* b = MakeValue(op->b); - return builder_->CreateSelect(CreateGT(op->a.type(), a, b), a, b); + return builder_->CreateSelect(CreateGT(op->a.dtype(), a, b), a, b); } llvm::Value* CodeGenLLVM::VisitExpr_(const EQ* op) { llvm::Value* a = MakeValue(op->a); llvm::Value* b = MakeValue(op->b); - if (op->a.type().is_int() || op->a.type().is_uint()) { + if (op->a.dtype().is_int() || op->a.dtype().is_uint()) { return builder_->CreateICmpEQ(a, b); } else { return builder_->CreateFCmpOEQ(a, b); @@ -916,7 +916,7 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const EQ* op) { llvm::Value* CodeGenLLVM::VisitExpr_(const NE* op) { llvm::Value* a = MakeValue(op->a); llvm::Value* b = MakeValue(op->b); - if (op->a.type().is_int() || op->a.type().is_uint()) { + if (op->a.dtype().is_int() || op->a.dtype().is_uint()) { return builder_->CreateICmpNE(a, b); } else { return builder_->CreateFCmpONE(a, b); @@ -950,7 +950,7 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const Let* op) { } llvm::Value* CodeGenLLVM::VisitExpr_(const Load* op) { - Type t = op->type; + DataType t = op->dtype; bool is_volatile = volatile_buf_.count(op->buffer_var.get()); llvm::Value* buffer = MakeValue(op->buffer_var); llvm::Value* index = MakeValue(op->index); @@ -1010,10 +1010,10 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const Call* op) { } llvm::Value* CodeGenLLVM::VisitExpr_(const Ramp* op) { - llvm::Value* vec = llvm::UndefValue::get(LLVMType(op->type)); + llvm::Value* vec = llvm::UndefValue::get(LLVMType(op->dtype)); for (int i = 0; i < op->lanes; ++i) { vec = builder_->CreateInsertElement( - vec, MakeValue(op->base + op->stride * make_const(op->stride.type(), i)), + vec, MakeValue(op->base + op->stride * make_const(op->stride.dtype(), i)), ConstInt32(i)); } return vec; @@ -1024,7 +1024,7 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const Shuffle* op) { int total_lanes = 0; for (int i = 0, e = op->vectors.size(); i < e; ++i) { vecs[i] = VisitExpr(op->vectors[i]); - total_lanes += op->vectors[i].type().lanes(); + total_lanes += op->vectors[i].dtype().lanes(); } llvm::Value* v0 = CreateVecConcat(vecs); std::vector idx(op->indices.size()); @@ -1045,7 +1045,7 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const Broadcast* op) { void CodeGenLLVM::VisitStmt_(const Store* op) { CHECK(is_one(op->predicate)); - Type t = op->value.type(); + DataType t = op->value.dtype(); bool is_volatile = volatile_buf_.count(op->buffer_var.get()); llvm::Value* buffer = MakeValue(op->buffer_var); llvm::Value* index = MakeValue(op->index); @@ -1056,7 +1056,7 @@ void CodeGenLLVM::VisitStmt_(const Store* op) { GetAlignment(t, op->buffer_var.get(), op->index, &alignment, &native_bits); llvm::Value* ptr = CreateBufferPtr(t, buffer, index); llvm::StoreInst* store = builder_->CreateAlignedStore(value, ptr, alignment, is_volatile); - AddAliasInfo(store, op->buffer_var.get(), op->index, op->value.type()); + AddAliasInfo(store, op->buffer_var.get(), op->index, op->value.dtype()); return; } else { // vector store @@ -1071,7 +1071,7 @@ void CodeGenLLVM::VisitStmt_(const Store* op) { t.element_of(), buffer, MakeValue(ramp->base)); ptr = builder_->CreatePointerCast(ptr, LLVMType(t)->getPointerTo(addrspace)); llvm::StoreInst* store = builder_->CreateAlignedStore(value, ptr, alignment, is_volatile); - AddAliasInfo(store, op->buffer_var.get(), op->index, op->value.type()); + AddAliasInfo(store, op->buffer_var.get(), op->index, op->value.dtype()); return; } } @@ -1084,7 +1084,7 @@ void CodeGenLLVM::VisitStmt_(const Store* op) { llvm::StoreInst* store = builder_->CreateAlignedStore( builder_->CreateExtractElement(value, i), ptr, basic_align, is_volatile); - AddAliasInfo(store, op->buffer_var.get(), Expr(), op->value.type()); + AddAliasInfo(store, op->buffer_var.get(), Expr(), op->value.dtype()); }; this->Scalarize(op->index, f); } @@ -1142,7 +1142,7 @@ void CodeGenLLVM::VisitStmt_(const Allocate* op) { << "Can only handle constant size stack allocation"; StorageInfo& info = alloc_storage_info_[op->buffer_var.get()]; if (constant_size % 4 == 0 && info.alignment == 0) { - info.alignment = GetTempAllocaAlignment(op->type, constant_size); + info.alignment = GetTempAllocaAlignment(op->dtype, constant_size); } // maximum necessary alignment in the NV devices if (info.alignment > 16) { @@ -1150,7 +1150,7 @@ void CodeGenLLVM::VisitStmt_(const Allocate* op) { } llvm::AllocaInst* alloca = WithFunctionEntry([&]() { return builder_->CreateAlloca( - LLVMType(op->type), ConstInt32(constant_size)); + LLVMType(op->dtype), ConstInt32(constant_size)); }); if (alloca->getAlignment() < static_cast(info.alignment)) { #if TVM_LLVM_VERSION >= 100 @@ -1163,7 +1163,7 @@ void CodeGenLLVM::VisitStmt_(const Allocate* op) { buf = alloca; } buf = builder_->CreatePointerCast( - buf, LLVMType(op->type)->getPointerTo( + buf, LLVMType(op->dtype)->getPointerTo( buf->getType()->getPointerAddressSpace())); CHECK(!var_map_.count(op->buffer_var.get())); var_map_[op->buffer_var.get()] = buf; @@ -1204,7 +1204,7 @@ void CodeGenLLVM::VisitStmt_(const AssertStmt* op) { void CodeGenLLVM::VisitStmt_(const LetStmt* op) { CHECK(!var_map_.count(op->var.get())); - if (op->var.type().is_handle()) { + if (op->var.dtype().is_handle()) { if (!is_restricted_) { alias_var_set_.insert(op->var.get()); } diff --git a/src/codegen/llvm/codegen_llvm.h b/src/codegen/llvm/codegen_llvm.h index b7d091b3921b..08c836adf9d0 100644 --- a/src/codegen/llvm/codegen_llvm.h +++ b/src/codegen/llvm/codegen_llvm.h @@ -206,12 +206,12 @@ class CodeGenLLVM : * \param t The original type. * \return LLVM type of t */ - llvm::Type* LLVMType(const Type& t) const; + llvm::Type* LLVMType(const DataType& t) const; // initialize the function state. void InitFuncState(); // Get alignment given index. void GetAlignment( - Type t, const Variable* buf_var, const Expr& index, + DataType t, const Variable* buf_var, const Expr& index, int* p_alignment, int* p_native_bits); // Get constant string llvm::Value* GetConstString(const std::string& str); @@ -221,19 +221,19 @@ class CodeGenLLVM : // handle module import void HandleImport(const std::string& code); // cast operatpr - llvm::Value* CreateCast(Type from, Type to, llvm::Value* value); + llvm::Value* CreateCast(DataType from, DataType to, llvm::Value* value); // comparison op llvm::Value* GetVarValue(const Variable* v) const; - llvm::Value* CreateLT(Type t, llvm::Value* a, llvm::Value* b); - llvm::Value* CreateLE(Type t, llvm::Value* a, llvm::Value* b); - llvm::Value* CreateGT(Type t, llvm::Value* a, llvm::Value* b); - llvm::Value* CreateGE(Type t, llvm::Value* a, llvm::Value* b); - llvm::Value* CreateAdd(Type t, llvm::Value* a, llvm::Value* b); - llvm::Value* CreateSub(Type t, llvm::Value* a, llvm::Value* b); - llvm::Value* CreateMul(Type t, llvm::Value* a, llvm::Value* b); + llvm::Value* CreateLT(DataType t, llvm::Value* a, llvm::Value* b); + llvm::Value* CreateLE(DataType t, llvm::Value* a, llvm::Value* b); + llvm::Value* CreateGT(DataType t, llvm::Value* a, llvm::Value* b); + llvm::Value* CreateGE(DataType t, llvm::Value* a, llvm::Value* b); + llvm::Value* CreateAdd(DataType t, llvm::Value* a, llvm::Value* b); + llvm::Value* CreateSub(DataType t, llvm::Value* a, llvm::Value* b); + llvm::Value* CreateMul(DataType t, llvm::Value* a, llvm::Value* b); llvm::Value* CreateBroadcast(llvm::Value* value, int lanes); - llvm::Value* CreateBufferPtr(Type t, llvm::Value* buffer, llvm::Value* index); - llvm::Value* CreateBufferVecPtr(Type t, llvm::Value* buffer, llvm::Value* index); + llvm::Value* CreateBufferPtr(DataType t, llvm::Value* buffer, llvm::Value* index); + llvm::Value* CreateBufferVecPtr(DataType t, llvm::Value* buffer, llvm::Value* index); // Vector concatenation. llvm::Value* CreateVecSlice(llvm::Value* vec, int begin, int extent); llvm::Value* CreateVecFlip(llvm::Value* vec); @@ -245,7 +245,7 @@ class CodeGenLLVM : llvm::Value* stride, const VarExpr& loop_var, const Stmt& body); // add alias information. - void AddAliasInfo(llvm::Instruction* load, const Variable* buffer, Expr index, Type type); + void AddAliasInfo(llvm::Instruction* load, const Variable* buffer, Expr index, DataType type); // The IRBuilder. using IRBuilder = llvm::IRBuilder; // The current function diff --git a/src/codegen/llvm/codegen_nvptx.cc b/src/codegen/llvm/codegen_nvptx.cc index b6bc6ef952fd..372408c5e666 100644 --- a/src/codegen/llvm/codegen_nvptx.cc +++ b/src/codegen/llvm/codegen_nvptx.cc @@ -58,7 +58,7 @@ class CodeGenNVPTX : public CodeGenLLVM { << "Can only handle constant size stack allocation in GPU"; StorageInfo& info = alloc_storage_info_[op->buffer_var.get()]; if (constant_size % 4 == 0 && info.alignment == 0) { - info.alignment = GetTempAllocaAlignment(op->type, constant_size); + info.alignment = GetTempAllocaAlignment(op->dtype, constant_size); } // maximum necessary alignment in the NV devices if (info.alignment > 16) { @@ -69,7 +69,7 @@ class CodeGenNVPTX : public CodeGenLLVM { // TODO(tqchen): for higher version of LLVM, local address space can be set. llvm::AllocaInst* alloca = WithFunctionEntry([&]() { return builder_->CreateAlloca( - LLVMType(op->type), ConstInt32(constant_size)); + LLVMType(op->dtype), ConstInt32(constant_size)); }); if (alloca->getAlignment() < static_cast(info.alignment)) { #if TVM_LLVM_VERSION >= 100 @@ -84,7 +84,7 @@ class CodeGenNVPTX : public CodeGenLLVM { << "Can only allocate shared or local memory inside kernel"; // Shared memory: address space == 3 const unsigned shared_address_space = 3; - llvm::Type* type = llvm::ArrayType::get(LLVMType(op->type), constant_size); + llvm::Type* type = llvm::ArrayType::get(LLVMType(op->dtype), constant_size); // Allocate shared memory in global, address_space = 3 llvm::GlobalVariable *global = new llvm::GlobalVariable( *module_, type, false, llvm::GlobalValue::PrivateLinkage, 0, ".shared", @@ -98,7 +98,7 @@ class CodeGenNVPTX : public CodeGenLLVM { } } buf = builder_->CreatePointerCast( - buf, LLVMType(op->type)->getPointerTo( + buf, LLVMType(op->dtype)->getPointerTo( buf->getType()->getPointerAddressSpace())); CHECK(!var_map_.count(op->buffer_var.get())); var_map_[op->buffer_var.get()] = buf; diff --git a/src/codegen/llvm/codegen_x86_64.cc b/src/codegen/llvm/codegen_x86_64.cc index 804d9b2f1b37..5d72b56df376 100644 --- a/src/codegen/llvm/codegen_x86_64.cc +++ b/src/codegen/llvm/codegen_x86_64.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -74,8 +74,8 @@ llvm::Value* CodeGenX86_64::VisitExpr_(const Cast* op) { // LLVM does not automatically generate the correct instruction sequences for // half -> float conversion (i.e. using AVX2/AVX-512 vectorized variants of // vcvtph2ps), so we explicitly generate them ourselves. - const auto from = op->value.type(); - const auto to = op->type; + const auto from = op->value.dtype(); + const auto to = op->dtype; if (from.is_float() && to.is_float() && from.bits() == 16 && to.bits() == 32) { CHECK_EQ(from.lanes(), to.lanes()); CHECK_NOTNULL(target_machine_); @@ -85,21 +85,25 @@ llvm::Value* CodeGenX86_64::VisitExpr_(const Cast* op) { if (from.lanes() >= 16 && has_avx512) { return CallVectorIntrin( - ::llvm::Intrinsic::x86_avx512_mask_vcvtph2ps_512, 16, LLVMType(Float(32, from.lanes())), + ::llvm::Intrinsic::x86_avx512_mask_vcvtph2ps_512, 16, + LLVMType(DataType::Float(32, from.lanes())), { - MakeValue(ir::Call::make(Int(16, from.lanes()), ir::Call::reinterpret, {op->value}, - ir::Call::PureIntrinsic)), - MakeValue(ir::Broadcast::make(ir::FloatImm::make(Float(32), 0), from.lanes())), - /*mask=*/MakeValue(ir::IntImm::make(Int(16), -1)), - /*rounding-mode=*/MakeValue(ir::IntImm::make(Int(32), 4)), + MakeValue(ir::Call::make( + DataType::Int(16, from.lanes()), ir::Call::reinterpret, {op->value}, + ir::Call::PureIntrinsic)), + MakeValue( + ir::Broadcast::make(ir::FloatImm::make(DataType::Float(32), 0), from.lanes())), + /*mask=*/MakeValue(ir::IntImm::make(DataType::Int(16), -1)), + /*rounding-mode=*/MakeValue(ir::IntImm::make(DataType::Int(32), 4)), }); } if (from.lanes() >= 8 && has_f16c) { return CallVectorIntrin( - ::llvm::Intrinsic::x86_vcvtph2ps_256, 8, LLVMType(Float(32, from.lanes())), - {MakeValue(ir::Call::make(Int(16, from.lanes()), ir::Call::reinterpret, {op->value}, - ir::Call::PureIntrinsic))}); + ::llvm::Intrinsic::x86_vcvtph2ps_256, 8, LLVMType(DataType::Float(32, from.lanes())), + {MakeValue(ir::Call::make( + DataType::Int(16, from.lanes()), ir::Call::reinterpret, {op->value}, + ir::Call::PureIntrinsic))}); } } diff --git a/src/codegen/llvm/intrin_rule_llvm.cc b/src/codegen/llvm/intrin_rule_llvm.cc index fd28d7e4594a..da07ff324b20 100644 --- a/src/codegen/llvm/intrin_rule_llvm.cc +++ b/src/codegen/llvm/intrin_rule_llvm.cc @@ -67,19 +67,19 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.tanh") const ir::Call* call = e.as(); CHECK(call != nullptr); const Expr& x = call->args[0]; - Expr one = make_const(x.type(), 1); - Expr two = make_const(x.type(), 2); - Expr neg_two = make_const(x.type(), -2); + Expr one = make_const(x.dtype(), 1); + Expr two = make_const(x.dtype(), 2); + Expr neg_two = make_const(x.dtype(), -2); Expr exp_neg2x = ir::Call::make( - x.type(), "exp", {neg_two * x}, ir::Call::PureIntrinsic); + x.dtype(), "exp", {neg_two * x}, ir::Call::PureIntrinsic); Expr exp_pos2x = ir::Call::make( - x.type(), "exp", {two * x}, ir::Call::PureIntrinsic); + x.dtype(), "exp", {two * x}, ir::Call::PureIntrinsic); Expr tanh_pos = (one - exp_neg2x) / (one + exp_neg2x); Expr tanh_neg = (exp_pos2x - one) / (exp_pos2x + one); *rv = ir::Select::make( - x >= make_zero(x.type()), tanh_pos, tanh_neg); + x >= make_zero(x.dtype()), tanh_pos, tanh_neg); }); TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.pow") diff --git a/src/codegen/llvm/intrin_rule_llvm.h b/src/codegen/llvm/intrin_rule_llvm.h index c0b5241e8876..7863a3dd7a96 100644 --- a/src/codegen/llvm/intrin_rule_llvm.h +++ b/src/codegen/llvm/intrin_rule_llvm.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -41,14 +41,14 @@ inline void DispatchLLVMPureIntrin(const TVMArgs& targs, TVMRetValue* rv) { CHECK(call != nullptr); Array cargs; // intrin id. - cargs.push_back(ir::UIntImm::make(UInt(32), id)); - cargs.push_back(ir::UIntImm::make(UInt(32), num_signature)); + cargs.push_back(ir::UIntImm::make(DataType::UInt(32), id)); + cargs.push_back(ir::UIntImm::make(DataType::UInt(32), num_signature)); for (Expr arg : call->args) { cargs.push_back(arg); } *rv = ir::Call::make( - call->type, "llvm_intrin", cargs, ir::Call::PureIntrinsic); + call->dtype, "llvm_intrin", cargs, ir::Call::PureIntrinsic); } template @@ -58,13 +58,13 @@ inline void DispatchLLVMIntrin(const TVMArgs& targs, TVMRetValue* rv) { CHECK(call != nullptr); Array cargs; // intrin id. - cargs.push_back(ir::UIntImm::make(UInt(32), id)); - cargs.push_back(ir::UIntImm::make(UInt(32), num_signature)); + cargs.push_back(ir::UIntImm::make(DataType::UInt(32), id)); + cargs.push_back(ir::UIntImm::make(DataType::UInt(32), num_signature)); for (Expr arg : call->args) { cargs.push_back(arg); } *rv = ir::Call::make( - call->type, "llvm_intrin", cargs, ir::Call::Intrinsic); + call->dtype, "llvm_intrin", cargs, ir::Call::Intrinsic); } } // namespace codegen diff --git a/src/codegen/llvm/intrin_rule_nvptx.cc b/src/codegen/llvm/intrin_rule_nvptx.cc index 4718cf78062e..862d06b73a5f 100644 --- a/src/codegen/llvm/intrin_rule_nvptx.cc +++ b/src/codegen/llvm/intrin_rule_nvptx.cc @@ -35,11 +35,11 @@ inline void DispatchExternLibDevice(const TVMArgs& args, TVMRetValue* rv) { using namespace ir; const Call* call = e.as(); CHECK(call != nullptr); - CHECK(call->type.bits() == 32 || call->type.bits() == 64) << "Only support float32 or float64."; + CHECK(call->dtype.bits() == 32 || call->dtype.bits() == 64) << "Only support float32 or float64."; std::ostringstream intrinsic_name; intrinsic_name << "__nv_" << call->name; - if (call->type.bits() == 32) intrinsic_name << "f"; - *rv = Call::make(call->type, intrinsic_name.str(), call->args, + if (call->dtype.bits() == 32) intrinsic_name << "f"; + *rv = Call::make(call->dtype, intrinsic_name.str(), call->args, Call::PureExtern); } diff --git a/src/codegen/llvm/intrin_rule_rocm.cc b/src/codegen/llvm/intrin_rule_rocm.cc index 5ad5261c81bf..22b324545825 100644 --- a/src/codegen/llvm/intrin_rule_rocm.cc +++ b/src/codegen/llvm/intrin_rule_rocm.cc @@ -36,8 +36,8 @@ inline void DispatchExternOCML(const TVMArgs& args, TVMRetValue* rv) { const Call* call = e.as(); CHECK(call != nullptr); std::ostringstream intrinsic_name; - intrinsic_name << "__ocml_" << call->name << "_f" << call->type.bits(); - *rv = Call::make(call->type, intrinsic_name.str(), call->args, + intrinsic_name << "__ocml_" << call->name << "_f" << call->dtype.bits(); + *rv = Call::make(call->dtype, intrinsic_name.str(), call->args, Call::PureExtern); } diff --git a/src/codegen/spirv/codegen_spirv.cc b/src/codegen/spirv/codegen_spirv.cc index be2b6cc668eb..7800e47319e0 100644 --- a/src/codegen/spirv/codegen_spirv.cc +++ b/src/codegen/spirv/codegen_spirv.cc @@ -37,11 +37,11 @@ std::vector CodeGenSPIRV::BuildFunction(const LoweredFunc& f) { std::vector pod_args; uint32_t num_buffer = 0; for (Var arg : f->args) { - Type t = arg.type(); + DataType t = arg.dtype(); if (t.is_handle()) { auto it = f->handle_data_type.find(arg); if (it != f->handle_data_type.end()) { - Type value_type = (*it).second.type(); + DataType value_type = (*it).second.dtype(); spirv::Value arg_value = builder_->BufferArgument( builder_->GetSType(value_type), 0, num_buffer); storage_info_[arg.get()].UpdateContentType(value_type); @@ -61,7 +61,7 @@ std::vector CodeGenSPIRV::BuildFunction(const LoweredFunc& f) { if (pod_args.size() != 0) { std::vector value_types; for (size_t i = 0; i < pod_args.size(); ++i) { - value_types.push_back(builder_->GetSType(pod_args[i].type())); + value_types.push_back(builder_->GetSType(pod_args[i].dtype())); } spirv::Value ptr = builder_->DeclarePushConstant(value_types); for (size_t i = 0; i < pod_args.size(); ++i) { @@ -103,7 +103,7 @@ spirv::Value CodeGenSPIRV::GetThreadIndex( } else { v = builder_->GetWorkgroupID(ts.dim_index); } - return builder_->Cast(builder_->GetSType(iv->var.type()), v); + return builder_->Cast(builder_->GetSType(iv->var.dtype()), v); } spirv::Value CodeGenSPIRV::CreateStorageSync(const Call* op) { @@ -112,7 +112,7 @@ spirv::Value CodeGenSPIRV::CreateStorageSync(const Call* op) { if (sync == "warp") { return value; } else if (sync == "shared") { - auto type_int = builder_->GetSType(Int(32)); + auto type_int = builder_->GetSType(DataType::Int(32)); builder_->MakeInst( spv::OpControlBarrier, builder_->IntImm(type_int, static_cast(spv::ScopeWorkgroup)), @@ -133,15 +133,15 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const Variable* op) { } spirv::Value CodeGenSPIRV::VisitExpr_(const IntImm* op) { - return builder_->IntImm(builder_->GetSType(op->type), op->value); + return builder_->IntImm(builder_->GetSType(op->dtype), op->value); } spirv::Value CodeGenSPIRV::VisitExpr_(const UIntImm* op) { - return builder_->UIntImm(builder_->GetSType(op->type), op->value); + return builder_->UIntImm(builder_->GetSType(op->dtype), op->value); } spirv::Value CodeGenSPIRV::VisitExpr_(const FloatImm* op) { - return builder_->FloatImm(builder_->GetSType(op->type), op->value); + return builder_->FloatImm(builder_->GetSType(op->dtype), op->value); } spirv::Value CodeGenSPIRV::VisitExpr_(const StringImm* op) { @@ -150,7 +150,7 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const StringImm* op) { } spirv::Value CodeGenSPIRV::VisitExpr_(const Cast* op) { - return builder_->Cast(builder_->GetSType(op->type), MakeValue(op->value)); + return builder_->Cast(builder_->GetSType(op->dtype), MakeValue(op->value)); } spirv::Value CodeGenSPIRV::VisitExpr_(const Add* op) { @@ -248,7 +248,7 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const Call* op) { values.push_back(MakeValue(op->args[i])); } return builder_->CallGLSL450( - builder_->GetSType(op->type), inst_id, values); + builder_->GetSType(op->dtype), inst_id, values); } else if (op->is_intrinsic(Call::bitwise_and)) { CHECK_EQ(op->args.size(), 2U); spirv::Value a = MakeValue(op->args[0]); @@ -277,13 +277,13 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const Call* op) { CHECK_EQ(op->args.size(), 2U); spirv::Value a = MakeValue(op->args[0]); spirv::Value b = MakeValue(op->args[1]); - if (op->args[0].type().is_int()) { + if (op->args[0].dtype().is_int()) { return builder_->MakeValue(spv::OpShiftRightArithmetic, a.stype, a, b); } else { return builder_->MakeValue(spv::OpShiftRightLogical, a.stype, a, b); } } else if (op->is_intrinsic(Call::reinterpret)) { - return builder_->MakeValue(spv::OpBitcast, builder_->GetSType(op->type), + return builder_->MakeValue(spv::OpBitcast, builder_->GetSType(op->dtype), MakeValue(op->args[0])); } else if (op->is_intrinsic(intrinsic::tvm_storage_sync)) { return this->CreateStorageSync(op); @@ -316,17 +316,17 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const Call* op) { } else if (op->is_intrinsic("popcount")) { return builder_->MakeValue( spv::OpBitCount, - builder_->GetSType(op->type), + builder_->GetSType(op->dtype), MakeValue(op->args[0])); } else { if (op->call_type == Call::Intrinsic || op->call_type == Call::PureIntrinsic) { LOG(FATAL) << "Unresolved intrinsic " << op->name - << " with return type " << op->type; + << " with return type " << op->dtype; } else if (op->call_type == Call::Extern || op->call_type == Call::PureExtern) { LOG(FATAL) << "Unresolved extern " << op->name - << " with return type " << op->type; + << " with return type " << op->dtype; } else { LOG(FATAL) << "Unresolved call type " << op->call_type; } @@ -341,7 +341,7 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const Ramp* op) { spirv::Value v = base; if (i != 0) { spirv::Value offset = MakeValue( - make_const(op->stride.type(), i) * op->stride); + make_const(op->stride.dtype(), i) * op->stride); v = builder_->Add(v, offset); } values.push_back(v); @@ -364,7 +364,7 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const Load* op) { CHECK(it != storage_info_.end()); StorageInfo& info = it->second; if (!info.content_fixed) { - info.UpdateContentType(op->type); + info.UpdateContentType(op->dtype); } spirv::SType content_type = builder_->GetSType(info.content_type); @@ -376,15 +376,15 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const Load* op) { if (info.is_volatile) { mask |= spv::MemoryAccessVolatileMask; } - if (op->type.lanes() == 1) { - CHECK_EQ(info.content_type, op->type) + if (op->dtype.lanes() == 1) { + CHECK_EQ(info.content_type, op->dtype) << "Vulkan only allow one type access to the same buffer"; spirv::Value index = MakeValue(op->index); spirv::Value ptr = builder_->StructArrayAccess( ptr_type, buffer, index); return builder_->MakeValue(spv::OpLoad, content_type, ptr, mask); } else { - if (op->type.element_of() == info.content_type) { + if (op->dtype.element_of() == info.content_type) { // because content type is element type, we can only do scalarize load. std::vector values; auto f = [&](int i, spirv::Value index) { @@ -398,13 +398,13 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const Load* op) { } else { if (const Ramp* ramp = op->index.as()) { if (is_one(ramp->stride)) { - CHECK_EQ(ramp->lanes, op->type.lanes()); + CHECK_EQ(ramp->lanes, op->dtype.lanes()); arith::ModularSet me = analyzer_->modular_set(ramp->base); CHECK((me->coeff % ramp->lanes) == 0 && (me->base % ramp->lanes) == 0) << "Only aligned vector access is allowed in SPIRV"; Expr vec_index = ir::Simplify( - ramp->base / make_const(ramp->base.type(), ramp->lanes)); + ramp->base / make_const(ramp->base.dtype(), ramp->lanes)); spirv::Value ptr = builder_->StructArrayAccess( ptr_type, buffer, MakeValue(vec_index)); return builder_->MakeValue(spv::OpLoad, content_type, ptr, mask); @@ -420,14 +420,14 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const Load* op) { void CodeGenSPIRV::Scalarize(const Expr& e, std::function f) { if (const Ramp* ramp = e.as()) { - for (int i = 0; i < ramp->type.lanes(); ++i) { + for (int i = 0; i < ramp->dtype.lanes(); ++i) { Expr offset = ramp->base + ramp->stride * i; f(i, MakeValue(offset)); } } else { - spirv::SType etype = builder_->GetSType(e.type().element_of()); + spirv::SType etype = builder_->GetSType(e.dtype().element_of()); spirv::Value value = MakeValue(e); - for (int i = 0; i < e.type().lanes(); ++i) { + for (int i = 0; i < e.dtype().lanes(); ++i) { f(i, builder_->MakeValue( spv::OpCompositeExtract, etype, value, i)); } @@ -441,7 +441,7 @@ void CodeGenSPIRV::VisitStmt_(const Store* op) { StorageInfo& info = it->second; if (!info.content_fixed) { - info.UpdateContentType(op->value.type()); + info.UpdateContentType(op->value.dtype()); } spirv::SType content_type = builder_->GetSType(info.content_type); @@ -455,15 +455,15 @@ void CodeGenSPIRV::VisitStmt_(const Store* op) { mask |= spv::MemoryAccessVolatileMask; } - if (op->value.type().lanes() == 1) { - CHECK_EQ(info.content_type, op->value.type()) + if (op->value.dtype().lanes() == 1) { + CHECK_EQ(info.content_type, op->value.dtype()) << "Vulkan only allow one type access to the same buffer"; spirv::Value index = MakeValue(op->index); spirv::Value ptr = builder_->StructArrayAccess( ptr_type, buffer, index); builder_->MakeInst(spv::OpStore, ptr, value, mask); } else { - if (op->value.type().element_of() == info.content_type) { + if (op->value.dtype().element_of() == info.content_type) { // because content type is element type, we can only do scalarize load. auto f = [&](int i, spirv::Value index) { spirv::Value elem = builder_->MakeValue( @@ -476,13 +476,13 @@ void CodeGenSPIRV::VisitStmt_(const Store* op) { } else { if (const Ramp* ramp = op->index.as()) { if (is_one(ramp->stride)) { - CHECK_EQ(ramp->lanes, op->value.type().lanes()); + CHECK_EQ(ramp->lanes, op->value.dtype().lanes()); arith::ModularSet me = analyzer_->modular_set(ramp->base); CHECK((me->coeff % ramp->lanes) == 0 && (me->base % ramp->lanes) == 0) << "Only aligned vector access is allowed in SPIRV"; Expr vec_index = ir::Simplify( - ramp->base / make_const(ramp->base.type(), ramp->lanes)); + ramp->base / make_const(ramp->base.dtype(), ramp->lanes)); spirv::Value ptr = builder_->StructArrayAccess( ptr_type, buffer, MakeValue(vec_index)); builder_->MakeInst(spv::OpStore, ptr, value, mask); @@ -530,7 +530,7 @@ void CodeGenSPIRV::VisitStmt_(const For* op) { // loop continue builder_->StartLabel(continue_label); spirv::Value one = - op->loop_var.type().is_int() ? + op->loop_var.dtype().is_int() ? builder_->IntImm(loop_var.stype, 1) : builder_->UIntImm(loop_var.stype, 1); spirv::Value next_value = builder_->Add(loop_var, one); @@ -576,13 +576,13 @@ void CodeGenSPIRV::VisitStmt_(const IfThenElse* op) { void CodeGenSPIRV::VisitStmt_(const Allocate* op) { CHECK(!is_zero(op->condition)); CHECK(!op->new_expr.defined()); - CHECK(!op->type.is_handle()); + CHECK(!op->dtype.is_handle()); int32_t constant_size = op->constant_allocation_size(); CHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation in GPU"; spirv::Value buf; StorageInfo& info = storage_info_[op->buffer_var.get()]; - spirv::SType etype = builder_->GetSType(op->type); + spirv::SType etype = builder_->GetSType(op->dtype); if (info.scope.rank == runtime::StorageRank::kLocal) { buf = builder_->Allocate( etype, static_cast(constant_size), @@ -597,7 +597,7 @@ void CodeGenSPIRV::VisitStmt_(const Allocate* op) { spv::StorageClassWorkgroup); } CHECK(!info.content_fixed); - info.UpdateContentType(op->type); + info.UpdateContentType(op->dtype); CHECK(!var_map_.count(op->buffer_var.get())); var_map_[op->buffer_var.get()] = buf; this->VisitStmt(op->body); @@ -632,7 +632,7 @@ void CodeGenSPIRV::VisitStmt_(const AssertStmt* op) { void CodeGenSPIRV::VisitStmt_(const LetStmt* op) { CHECK(!var_map_.count(op->var.get())); - CHECK(!op->var.type().is_handle()); + CHECK(!op->var.dtype().is_handle()); var_map_[op->var.get()] = MakeValue(op->value); analyzer_->Bind(op->var, op->value); this->VisitStmt(op->body); diff --git a/src/codegen/spirv/codegen_spirv.h b/src/codegen/spirv/codegen_spirv.h index eca361493e80..3d16377271c4 100644 --- a/src/codegen/spirv/codegen_spirv.h +++ b/src/codegen/spirv/codegen_spirv.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -112,10 +112,10 @@ class CodeGenSPIRV: /*! \brief Whether it is volatile */ bool content_fixed{false}; /*! \brief Current content type */ - Type content_type{Handle()}; + DataType content_type{DataType::Handle()}; // Update content type if it hasn't beenupdated. - void UpdateContentType(Type type) { + void UpdateContentType(DataType type) { if (content_fixed) { CHECK_EQ(type, content_type) << "Cannot use two different content type in GLSL model"; diff --git a/src/codegen/spirv/intrin_rule_spirv.cc b/src/codegen/spirv/intrin_rule_spirv.cc index fca9aa203f80..7a347e5e8dbc 100644 --- a/src/codegen/spirv/intrin_rule_spirv.cc +++ b/src/codegen/spirv/intrin_rule_spirv.cc @@ -39,13 +39,13 @@ inline void DispatchGLSLPureIntrin(const TVMArgs& targs, TVMRetValue* rv) { CHECK(call != nullptr); Array cargs; // intrin id. - cargs.push_back(ir::UIntImm::make(UInt(32), id)); + cargs.push_back(ir::UIntImm::make(DataType::UInt(32), id)); for (Expr arg : call->args) { cargs.push_back(arg); } *rv = ir::Call::make( - call->type, "spirv_glsl450", cargs, ir::Call::PureIntrinsic); + call->dtype, "spirv_glsl450", cargs, ir::Call::PureIntrinsic); } TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.floor") diff --git a/src/codegen/spirv/ir_builder.cc b/src/codegen/spirv/ir_builder.cc index 35d57d7cc3f8..6f8d96e148c1 100644 --- a/src/codegen/spirv/ir_builder.cc +++ b/src/codegen/spirv/ir_builder.cc @@ -53,10 +53,10 @@ void IRBuilder::InitHeader() { void IRBuilder::InitPreDefs() { ext_glsl450_ = ExtInstImport("GLSL.std.450"); - t_int32_ = DeclareType(Int(32)); - t_uint32_ = DeclareType(UInt(32)); - t_bool_ = DeclareType(UInt(1)); - t_fp32_ = DeclareType(Float(32)); + t_int32_ = DeclareType(DataType::Int(32)); + t_uint32_ = DeclareType(DataType::UInt(32)); + t_bool_ = DeclareType(DataType::UInt(1)); + t_fp32_ = DeclareType(DataType::Float(32)); const_i32_zero_ = IntImm(t_int32_, 0); // declare void, and void functions t_void_.id = id_counter_++; @@ -66,14 +66,14 @@ void IRBuilder::InitPreDefs() { .AddSeq(t_void_func_, t_void_).Commit(&global_); } -SType IRBuilder::GetSType(const Type& dtype) { - if (dtype == Int(32)) { +SType IRBuilder::GetSType(const DataType& dtype) { + if (dtype == DataType::Int(32)) { return t_int32_; - } else if (dtype == UInt(1)) { + } else if (dtype == DataType::UInt(1)) { return t_bool_; - } else if (dtype == Float(32)) { + } else if (dtype == DataType::Float(32)) { return t_fp32_; - } else if (dtype == UInt(32)) { + } else if (dtype == DataType::UInt(32)) { return t_uint32_; } uint32_t type_key; @@ -99,7 +99,7 @@ SType IRBuilder::GetPointerType(const SType& value_type, } SType t; t.id = id_counter_++; - t.type = Handle(); + t.type = DataType::Handle(); t.element_type_id = value_type.id; t.storage_class = storage_class; ib_.Begin(spv::OpTypePointer) @@ -118,11 +118,11 @@ SType IRBuilder::GetStructArrayType(const SType& value_type, SType arr_type; arr_type.id = id_counter_++; - arr_type.type = Handle(); + arr_type.type = DataType::Handle(); arr_type.element_type_id = value_type.id; if (num_elems != 0) { - Value length = UIntImm(GetSType(UInt(32)), num_elems); + Value length = UIntImm(GetSType(DataType::UInt(32)), num_elems); ib_.Begin(spv::OpTypeArray) .AddSeq(arr_type, value_type, length).Commit(&global_); } else { @@ -138,7 +138,7 @@ SType IRBuilder::GetStructArrayType(const SType& value_type, // declare struct of array SType struct_type; struct_type.id = id_counter_++; - struct_type.type = Handle(); + struct_type.type = DataType::Handle(); struct_type.element_type_id = value_type.id; ib_.Begin(spv::OpTypeStruct) .AddSeq(struct_type, arr_type).Commit(&global_); @@ -183,7 +183,7 @@ Value IRBuilder::FloatImm(const SType& dtype, double value) { } else { CHECK_EQ(dtype.type.bits(), 16); return Cast(dtype, - FloatImm(GetSType(Float(32)), value)); + FloatImm(GetSType(DataType::Float(32)), value)); } } @@ -206,7 +206,7 @@ Value IRBuilder::DeclarePushConstant(const std::vector& value_types) { CHECK_EQ(push_const_.id, 0); SType struct_type; struct_type.id = id_counter_++; - struct_type.type = Handle(); + struct_type.type = DataType::Handle(); ib_.Begin(spv::OpTypeStruct).Add(struct_type); for (const SType& vtype : value_types) { ib_.Add(vtype); @@ -218,7 +218,7 @@ Value IRBuilder::DeclarePushConstant(const std::vector& value_types) { ib_.Begin(spv::OpMemberDecorate) .AddSeq(struct_type, i, spv::DecorationOffset, offset) .Commit(&decorate_); - Type t = value_types[i].type; + DataType t = value_types[i].type; uint32_t nbits = t.bits() * t.lanes(); CHECK_EQ(nbits % 8 , 0); offset += nbits / 8; @@ -296,7 +296,7 @@ Value IRBuilder::Allocate(const SType& value_type, Value IRBuilder::GetWorkgroupID(uint32_t dim_index) { if (workgroup_id_.id == 0) { - SType vec3_type = this->GetSType(Int(32).with_lanes(3)); + SType vec3_type = this->GetSType(DataType::Int(32).with_lanes(3)); SType ptr_type = this->GetPointerType( vec3_type, spv::StorageClassInput); workgroup_id_ = NewValue(ptr_type, kVectorPtr); @@ -315,7 +315,7 @@ Value IRBuilder::GetWorkgroupID(uint32_t dim_index) { Value IRBuilder::GetLocalID(uint32_t dim_index) { if (local_id_.id == 0) { - SType vec3_type = this->GetSType(Int(32).with_lanes(3)); + SType vec3_type = this->GetSType(DataType::Int(32).with_lanes(3)); SType ptr_type = this->GetPointerType(vec3_type, spv::StorageClassInput); local_id_ = NewValue(ptr_type, kVectorPtr); ib_.Begin(spv::OpVariable) @@ -339,7 +339,7 @@ Value IRBuilder::GetConst_(const SType& dtype, const uint64_t* pvalue) { } CHECK_LE(dtype.type.bits(), 64); Value ret = NewValue(dtype, kConstant); - if (dtype.type == UInt(1)) { + if (dtype.type == DataType::UInt(1)) { // bool types. if (*pvalue) { ib_.Begin(spv::OpConstantTrue).AddSeq(ret); @@ -367,7 +367,7 @@ Value IRBuilder::GetConst_(const SType& dtype, const uint64_t* pvalue) { return ret; } -SType IRBuilder::DeclareType(const Type& dtype) { +SType IRBuilder::DeclareType(const DataType& dtype) { if (dtype.lanes() == 1) { SType t; t.id = id_counter_++; @@ -426,7 +426,7 @@ Value IRBuilder::CallGLSL450(const SType& ret_type, Value IRBuilder::Concat(const std::vector& vec) { bool is_const = vec[0].flag == kConstant; - Type etype = vec[0].stype.type; + DataType etype = vec[0].stype.type; int lanes = etype.lanes(); for (size_t i = 1; i < vec.size(); ++i) { CHECK_EQ(etype, vec[i].stype.type.element_of()) @@ -456,10 +456,10 @@ Value IRBuilder::Concat(const std::vector& vec) { Value IRBuilder::Cast(const SType& dst_type, spirv::Value value) { CHECK_NE(value.stype.id, 0U); if (value.stype.id == dst_type.id) return value; - const tvm::Type& from = value.stype.type; - const tvm::Type& to = dst_type.type; + const tvm::DataType& from = value.stype.type; + const tvm::DataType& to = dst_type.type; CHECK_EQ(from.lanes(), to.lanes()); - if (from == Bool()) { + if (from == DataType::Bool()) { if (to.is_int()) { return Select(value, IntImm(dst_type, 1), IntImm(dst_type, 0)); } else if (to.is_uint()) { @@ -471,7 +471,7 @@ Value IRBuilder::Cast(const SType& dst_type, spirv::Value value) { LOG(FATAL) << "cannot cast from " << from << " to " << to; return Value(); } - } else if (to == Bool()) { + } else if (to == DataType::Bool()) { if (from.is_int()) { return NE(value, IntImm(value.stype, 0)); } else if (to.is_uint()) { @@ -558,7 +558,7 @@ Value IRBuilder::Mod(Value a, Value b) { Value IRBuilder::_OpName(Value a, Value b) { \ CHECK_EQ(a.stype.id, b.stype.id); \ CHECK_EQ(a.stype.type.lanes(), b.stype.type.lanes()); \ - const auto& bool_type = this->GetSType(UInt(1).with_lanes(a.stype.type.lanes())); \ + const auto& bool_type = this->GetSType(DataType::UInt(1).with_lanes(a.stype.type.lanes())); \ if (a.stype.type.is_int()) { \ return MakeValue(spv::OpS##_Op, bool_type, a, b); \ } else if (a.stype.type.is_uint()) { \ @@ -578,7 +578,7 @@ DEFINE_BUILDER_CMP_OP(GE, GreaterThanEqual); Value IRBuilder::_OpName(Value a, Value b) { \ CHECK_EQ(a.stype.id, b.stype.id); \ CHECK_EQ(a.stype.type.lanes(), b.stype.type.lanes()); \ - const auto& bool_type = this->GetSType(UInt(1).with_lanes(a.stype.type.lanes())); \ + const auto& bool_type = this->GetSType(DataType::UInt(1).with_lanes(a.stype.type.lanes())); \ if (a.stype.type.is_int() || a.stype.type.is_uint()) { \ return MakeValue(spv::OpI##_Op, bool_type, a, b); \ } else { \ @@ -592,7 +592,7 @@ DEFINE_BUILDER_CMP_UOP(NE, NotEqual); Value IRBuilder::Select(Value cond, Value a, Value b) { CHECK_EQ(a.stype.id, b.stype.id); - CHECK_EQ(cond.stype.type.element_of(), UInt(1)); + CHECK_EQ(cond.stype.type.element_of(), DataType::UInt(1)); return MakeValue(spv::OpSelect, a.stype, cond, a, b); } diff --git a/src/codegen/spirv/ir_builder.h b/src/codegen/spirv/ir_builder.h index c04af743fbb8..3843cbb3c6a9 100644 --- a/src/codegen/spirv/ir_builder.h +++ b/src/codegen/spirv/ir_builder.h @@ -45,7 +45,7 @@ struct SType { /*! \brief The Id to represent type */ uint32_t id{0}; /*! \brief corresponding TVM type */ - tvm::Type type; + tvm::DataType type; /*! \brief content type id if it is a pointer/struct-array class */ uint32_t element_type_id{0}; /*! \brief The storage class, if it is a pointer */ @@ -424,7 +424,7 @@ class IRBuilder { * \param dtype The data type. * \return The corresponding spirv type. */ - SType GetSType(const tvm::Type& dtype); + SType GetSType(const tvm::DataType& dtype); /*! * \brief Get the pointer type that points to value_type * \param value_type. @@ -575,7 +575,7 @@ class IRBuilder { // get constant given value encoded in uint64_t Value GetConst_(const SType& dtype, const uint64_t* pvalue); // declare type - SType DeclareType(const Type& dtype); + SType DeclareType(const DataType& dtype); /*! \brief internal instruction builder */ InstrBuilder ib_; /*! \brief Current label */ diff --git a/src/codegen/stackvm/codegen_stackvm.cc b/src/codegen/stackvm/codegen_stackvm.cc index fd2a5f764ff6..52cabaf0b6eb 100644 --- a/src/codegen/stackvm/codegen_stackvm.cc +++ b/src/codegen/stackvm/codegen_stackvm.cc @@ -100,12 +100,12 @@ int CodeGenStackVM::GetVarID(const Variable* v) const { void CodeGenStackVM::VisitExpr_(const Load* op) { this->Push(op->buffer_var); - StackVM::OpCode code = StackVM::GetLoad(Type2TVMType(op->type)); + StackVM::OpCode code = StackVM::GetLoad(op->dtype); if (const IntImm* index = op->index.as()) { this->PushOp(code, index->value); } else { this->Push(op->index); - this->PushOp(StackVM::PUSH_I64, op->type.element_of().bytes()); + this->PushOp(StackVM::PUSH_I64, op->dtype.element_of().bytes()); this->PushOp(StackVM::MUL_I64); this->PushOp(StackVM::ADDR_ADD); this->PushOp(code, 0); @@ -114,13 +114,13 @@ void CodeGenStackVM::VisitExpr_(const Load* op) { void CodeGenStackVM::VisitStmt_(const Store* op) { this->Push(op->buffer_var); - StackVM::OpCode code = StackVM::GetStore(Type2TVMType(op->value.type())); + StackVM::OpCode code = StackVM::GetStore(op->value.dtype()); if (const IntImm* index = op->index.as()) { this->Push(op->value); this->PushOp(code, index->value); } else { this->Push(op->index); - this->PushOp(StackVM::PUSH_I64, op->value.type().element_of().bytes()); + this->PushOp(StackVM::PUSH_I64, op->value.dtype().element_of().bytes()); this->PushOp(StackVM::MUL_I64); this->PushOp(StackVM::ADDR_ADD); this->Push(op->value); @@ -147,7 +147,7 @@ void CodeGenStackVM::VisitExpr_(const Call* op) { CHECK(op->args.size() == 1 && l); this->PushOp(StackVM::LOAD_HEAP, GetVarID(l->buffer_var.get())); this->Push(l->index); - this->PushOp(StackVM::PUSH_I64, l->type.element_of().bytes()); + this->PushOp(StackVM::PUSH_I64, l->dtype.element_of().bytes()); this->PushOp(StackVM::MUL_I64); this->PushOp(StackVM::ADDR_ADD); } else if (op->is_intrinsic(Call::reinterpret)) { @@ -248,7 +248,7 @@ void CodeGenStackVM::PushBinary(StackVM::OpCode op_int64, const Expr& b) { this->Push(a); this->Push(b); - Type t = a.type(); + DataType t = a.dtype(); if (t.is_int()) { this->PushOp(op_int64); } else if (t.is_uint()) { @@ -258,7 +258,7 @@ void CodeGenStackVM::PushBinary(StackVM::OpCode op_int64, } } -void CodeGenStackVM::PushCast(Type dst, Type src) { +void CodeGenStackVM::PushCast(DataType dst, DataType src) { if (dst.is_int()) { if (src.is_int() || src.is_uint()) return; } else if (dst.is_uint()) { @@ -297,7 +297,7 @@ void CodeGenStackVM::VisitExpr_(const Variable *op) { void CodeGenStackVM::VisitExpr_(const Cast *op) { this->Push(op->value); - PushCast(op->type, op->value.type()); + PushCast(op->dtype, op->value.dtype()); } void CodeGenStackVM::VisitExpr_(const Add *op) { diff --git a/src/codegen/stackvm/codegen_stackvm.h b/src/codegen/stackvm/codegen_stackvm.h index 1e6dd64476aa..dcae072c102d 100644 --- a/src/codegen/stackvm/codegen_stackvm.h +++ b/src/codegen/stackvm/codegen_stackvm.h @@ -108,7 +108,7 @@ class CodeGenStackVM const Expr& a, const Expr& b); // push cast; - void PushCast(Type dst, Type src); + void PushCast(DataType dst, DataType src); // overloadable functions // expression void VisitExpr_(const Variable* op) final; diff --git a/src/contrib/hybrid/codegen_hybrid.cc b/src/contrib/hybrid/codegen_hybrid.cc index 9e55d9be13d5..2bb86093e2f8 100644 --- a/src/contrib/hybrid/codegen_hybrid.cc +++ b/src/contrib/hybrid/codegen_hybrid.cc @@ -57,7 +57,7 @@ std::string CodeGenHybrid::Finish() { return stream.str(); } -void CodeGenHybrid::PrintType(Type t, std::ostream &os) { +void CodeGenHybrid::PrintType(DataType t, std::ostream &os) { if (t.is_float()) { os << "float"; CHECK(t.bits() == 16 || t.bits() == 32 || t.bits() == 64); @@ -76,11 +76,11 @@ void CodeGenHybrid::VisitExpr_(const IntImm *op, std::ostream& os) { // NOLINT( os << op->value; } void CodeGenHybrid::VisitExpr_(const UIntImm *op, std::ostream& os) { // NOLINT(*) - PrintType(op->type, os); + PrintType(op->dtype, os); os << "(" << op->value << ")"; } void CodeGenHybrid::VisitExpr_(const FloatImm *op, std::ostream& os) { // NOLINT(*) - PrintType(op->type, os); + PrintType(op->dtype, os); os << "(" << std::setprecision(20) << op->value << ")"; } void CodeGenHybrid::VisitExpr_(const StringImm *op, std::ostream& os) { // NOLINT(*) @@ -92,7 +92,7 @@ inline void PrintBinaryExpr(const T* op, const char *opstr, std::ostream& os, // NOLINT(*) CodeGenHybrid* p) { - CHECK(op->type.lanes() == 1) << "vec bin op not implemented"; + CHECK(op->dtype.lanes() == 1) << "vec bin op not implemented"; if (isalpha(opstr[0])) { os << opstr << '('; p->PrintExpr(op->a, os); @@ -114,7 +114,7 @@ inline void PrintBinaryIntrinsitc(const Call* op, const char *opstr, std::ostream& os, // NOLINT(*) CodeGenHybrid* p) { - CHECK(op->type.lanes() == 1) << "vec bin intrin not implemented"; + CHECK(op->dtype.lanes() == 1) << "vec bin intrin not implemented"; CHECK_EQ(op->args.size(), 2U); os << '('; p->PrintExpr(op->args[0], os); @@ -124,10 +124,10 @@ inline void PrintBinaryIntrinsitc(const Call* op, } void CodeGenHybrid::VisitExpr_(const Cast *op, std::ostream& os) { // NOLINT(*) - if (op->type == op->value.type()) { + if (op->dtype == op->value.dtype()) { PrintExpr(op->value, stream); } else { - PrintType(op->type, os); + PrintType(op->dtype, os); os << "("; PrintExpr(op->value, os); os << ")"; @@ -148,14 +148,14 @@ void CodeGenHybrid::VisitExpr_(const Mul *op, std::ostream& os) { // NOLINT(*) } void CodeGenHybrid::VisitExpr_(const Div *op, std::ostream& os) { // NOLINT(*) - if (op->type.is_int()) + if (op->dtype.is_int()) PrintBinaryExpr(op, "//", os, this); else PrintBinaryExpr(op, "/", os, this); } void CodeGenHybrid::VisitExpr_(const FloorDiv *op, std::ostream& os) { // NOLINT(*) - if (op->type.is_int()) + if (op->dtype.is_int()) PrintBinaryExpr(op, "//", os, this); else PrintBinaryExpr(op, "/", os, this); @@ -320,7 +320,7 @@ void CodeGenHybrid::VisitStmt_(const Realize *op) { } if (op->bounds.size() == 1) stream << ", "; stream << "), '"; - PrintType(op->type, stream); + PrintType(op->dtype, stream); stream << "', '"; stream << alloc_storage_scope_[op->func] << "')\n"; } diff --git a/src/contrib/hybrid/codegen_hybrid.h b/src/contrib/hybrid/codegen_hybrid.h index 866756996f8d..2c719b0b3ecf 100644 --- a/src/contrib/hybrid/codegen_hybrid.h +++ b/src/contrib/hybrid/codegen_hybrid.h @@ -138,7 +138,7 @@ class CodeGenHybrid : * \param t The type representation. * \param os The stream to print the ctype into */ - virtual void PrintType(Type t, std::ostream& os); // NOLINT(*) + virtual void PrintType(DataType t, std::ostream& os); // NOLINT(*) private: /*! \brief The current indent of the code dump. */ diff --git a/src/lang/attrs.cc b/src/lang/attrs.cc index 007a68b1e629..b83734beacb3 100644 --- a/src/lang/attrs.cc +++ b/src/lang/attrs.cc @@ -177,7 +177,7 @@ bool AttrsEqualHandler::VisitAttr_(const Not* lhs, const ObjectRef& other) { bool AttrsEqualHandler::VisitAttr_(const Cast* lhs, const ObjectRef& other) { if (const auto* rhs = other.as()) { - if (lhs->type != rhs->type) return false; + if (lhs->dtype != rhs->dtype) return false; return Equal(lhs->value, rhs->value); } else { return false; @@ -188,7 +188,7 @@ bool AttrsEqualHandler::VisitAttr_(const Call* lhs, const ObjectRef& other) { if (const auto* rhs = other.as()) { return lhs->name == rhs->name && - lhs->type == rhs->type && + lhs->dtype == rhs->dtype && lhs->call_type == rhs->call_type && Equal(lhs->args, rhs->args); } else { @@ -290,7 +290,7 @@ size_t AttrsHashHandler::VisitAttr_(const Cast* op) { static size_t key = std::hash()(Cast::_type_key); AttrsHash hasher; size_t res = key; - res = Combine(res, hasher(op->type)); + res = Combine(res, hasher(op->dtype)); res = Combine(res, Hash(op->value)); return res; } @@ -300,7 +300,7 @@ size_t AttrsHashHandler::VisitAttr_(const Call* op) { AttrsHash hasher; size_t res = key; res = Combine(res, hasher(op->name)); - res = Combine(res, hasher(op->type)); + res = Combine(res, hasher(op->dtype)); res = Combine(res, Hash(op->args)); return res; } diff --git a/src/lang/buffer.cc b/src/lang/buffer.cc index 77e741086a59..eb5d87efbbfa 100644 --- a/src/lang/buffer.cc +++ b/src/lang/buffer.cc @@ -42,10 +42,10 @@ Array SimplifyArray(Array array) { } Buffer decl_buffer(Array shape, - Type dtype, + DataType dtype, std::string name) { return BufferNode::make( - Var(name, Handle()), + Var(name, DataType::Handle()), dtype, shape, Array(), @@ -279,30 +279,30 @@ inline Expr ElemOffset(const BufferNode* n, Array index) { return base; } -inline Expr BufferOffset(const BufferNode* n, Array index, Type dtype) { +inline Expr BufferOffset(const BufferNode* n, Array index, DataType dtype) { Expr offset = ElemOffset(n, index); if (n->dtype.lanes() != 1) { - offset = offset * make_const(offset.type(), dtype.lanes()); + offset = offset * make_const(offset.dtype(), dtype.lanes()); } if (dtype.lanes() != 1) { - return ir::Ramp::make(offset, make_const(offset.type(), 1), dtype.lanes()); + return ir::Ramp::make(offset, make_const(offset.dtype(), 1), dtype.lanes()); } else { return offset; } } -Expr Buffer::vload(Array begin, Type dtype) const { - // specially handle bool, stored as Int(8) +Expr Buffer::vload(Array begin, DataType dtype) const { + // specially handle bool, stored asDataType::Int(8) const BufferNode* n = operator->(); CHECK(dtype.element_of() == n->dtype.element_of() && dtype.lanes() % n->dtype.lanes() == 0) << "Cannot load " << dtype << " from buffer of " << n->dtype; - if (dtype == Bool()) { + if (dtype == DataType::Bool()) { return ir::Cast::make( - Bool(), + DataType::Bool(), ir::Load::make( - Int(8), n->data, BufferOffset(n, begin, Int(8)), + DataType::Int(8), n->data, BufferOffset(n, begin, DataType::Int(8)), const_true())); } else { return ir::Load::make( @@ -312,17 +312,17 @@ Expr Buffer::vload(Array begin, Type dtype) const { } Stmt Buffer::vstore(Array begin, Expr value) const { - // specially handle bool, stored as Int(8) + // specially handle bool, stored asDataType::Int(8) const BufferNode* n = operator->(); - Type dtype = value.type(); + DataType dtype = value.dtype(); CHECK(dtype.element_of() == n->dtype.element_of() && dtype.lanes() % n->dtype.lanes() == 0) << "Cannot load " << dtype << " from buffer of " << n->dtype; - if (value.type() == Bool()) { + if (value.dtype() == DataType::Bool()) { return ir::Store::make(n->data, - ir::Cast::make(Int(8), value), - BufferOffset(n, begin, Int(8)), + ir::Cast::make(DataType::Int(8), value), + BufferOffset(n, begin, DataType::Int(8)), const_true()); } else { return ir::Store::make(n->data, value, BufferOffset(n, begin, dtype), @@ -381,7 +381,7 @@ Buffer Buffer::MakeSlice(Array begins, Array extents) const { n->buffer_type); } -Expr Buffer::access_ptr(int access_mask, Type ptr_type, int content_lanes, Expr offset) const { +Expr Buffer::access_ptr(int access_mask, DataType ptr_type, int content_lanes, Expr offset) const { const BufferNode* self = operator->(); Expr e_dtype; Expr extent; @@ -396,21 +396,21 @@ Expr Buffer::access_ptr(int access_mask, Type ptr_type, int content_lanes, Expr Expr elem_offset = self->elem_offset + offset; if (content_lanes > 1) { e_dtype = ir::TypeAnnotation(self->dtype.with_lanes(content_lanes)); - extent = extent / make_const(self->elem_offset.type(), content_lanes); - elem_offset = self->elem_offset / make_const(self->elem_offset.type(), + extent = extent / make_const(self->elem_offset.dtype(), content_lanes); + elem_offset = self->elem_offset / make_const(self->elem_offset.dtype(), content_lanes); } else { e_dtype = ir::TypeAnnotation(self->dtype); } Array acc_args{ e_dtype, self->data, elem_offset, - extent, make_const(Int(32), access_mask)}; + extent, make_const(DataType::Int(32), access_mask)}; return ir::Call::make( ptr_type, ir::intrinsic::tvm_access_ptr, acc_args, ir::Call::Intrinsic); } Buffer BufferNode::make(Var data, - Type dtype, + DataType dtype, Array shape, Array strides, Expr elem_offset, diff --git a/src/lang/channel.cc b/src/lang/channel.cc index cb3e2f566c77..555562aef065 100644 --- a/src/lang/channel.cc +++ b/src/lang/channel.cc @@ -24,7 +24,7 @@ namespace tvm { -Channel ChannelNode::make(Var handle_var, Type dtype) { +Channel ChannelNode::make(Var handle_var, DataType dtype) { auto n = make_node(); n->handle_var = handle_var; n->dtype = dtype; diff --git a/src/lang/expr.cc b/src/lang/expr.cc index 6a69fdaa20c4..997c15177546 100644 --- a/src/lang/expr.cc +++ b/src/lang/expr.cc @@ -29,70 +29,11 @@ namespace tvm { -// maximum and min values -Expr DataType::max() const { - using namespace ir; - CHECK_EQ(lanes(), 1); - if (is_int()) { - if (bits() == 64) { - return IntImm::make(*this, std::numeric_limits::max()); - } else if (bits() < 64) { - int64_t val = 1; - val = (val << (bits() - 1)) - 1; - return IntImm::make(*this, val); - } - } else if (is_uint()) { - if (bits() == 64) { - return UIntImm::make(*this, std::numeric_limits::max()); - } else if (bits() < 64) { - uint64_t val = 1; - val = (val << static_cast(bits())) - 1; - return UIntImm::make(*this, val); - } - } else if (is_float()) { - if (bits() == 64) { - return FloatImm::make(*this, std::numeric_limits::max()); - } else if (bits() == 32) { - return FloatImm::make(*this, std::numeric_limits::max()); - } else if (bits() == 16) { - return FloatImm::make(*this, 65504.0); - } - } - LOG(FATAL) << "Cannot decide max_value for type" << *this; - return Expr(); -} - -Expr DataType::min() const { - using namespace ir; - CHECK_EQ(lanes(), 1); - if (is_int()) { - if (bits() == 64) { - return IntImm::make(*this, std::numeric_limits::lowest()); - } else if (bits() < 64) { - int64_t val = 1; - val = -(val << (bits() - 1)); - return IntImm::make(*this, val); - } - } else if (is_uint()) { - return UIntImm::make(*this, 0); - } else if (is_float()) { - if (bits() == 64) { - return FloatImm::make(*this, std::numeric_limits::lowest()); - } else if (bits() == 32) { - return FloatImm::make(*this, std::numeric_limits::lowest()); - } else if (bits() == 16) { - return FloatImm::make(*this, -65504.0); - } - } - LOG(FATAL) << "Cannot decide min_value for type" << *this; - return Expr(); -} - Expr::Expr(int32_t value) - : Expr(IntImm::make(Int(32), value)) {} + : Expr(IntImm::make(DataType::Int(32), value)) {} Expr::Expr(float value) - : Expr(ir::FloatImm::make(Float(32), value)) {} + : Expr(ir::FloatImm::make(DataType::Float(32), value)) {} Expr::Expr(std::string str) : Expr(ir::StringImm::make(str)) {} @@ -102,7 +43,7 @@ Var::Var(std::string name_hint, DataType t) Var Variable::make(DataType t, std::string name_hint) { NodePtr node = make_node(); - node->type = t; + node->dtype = t; node->name_hint = std::move(name_hint); return Var(node); } @@ -113,11 +54,11 @@ Range::Range(Expr begin, Expr end) is_zero(begin) ? end : (end - begin))) { } -Integer IntImm::make(Type t, int64_t value) { +Integer IntImm::make(DataType t, int64_t value) { CHECK(t.is_int() && t.is_scalar()) << "ValueError: IntImm can only take scalar."; NodePtr node = make_node(); - node->type = t; + node->dtype = t; node->value = value; return Integer(node); } @@ -152,7 +93,7 @@ void Dump(const NodeRef& n) { std::cerr << n << "\n"; } -Var var(std::string name_hint, Type t) { +Var var(std::string name_hint, DataType t) { return Var(name_hint, t); } @@ -184,10 +125,10 @@ IRPrinter::FType& IRPrinter::vtable() { TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) .set_dispatch([](const ObjectRef& node, IRPrinter* p) { auto* op = static_cast(node.get()); - if (op->type == Int(32)) { + if (op->dtype == DataType::Int(32)) { p->stream << op->value; } else { - p->stream << "(" << op->type << ")" << op->value; + p->stream << "(" << op->dtype << ")" << op->value; } }); diff --git a/src/lang/expr_operator.cc b/src/lang/expr_operator.cc index 220d4378cc97..1166e7eef976 100644 --- a/src/lang/expr_operator.cc +++ b/src/lang/expr_operator.cc @@ -30,16 +30,16 @@ namespace tvm { // simple cast that only checks if type matches and cast -inline Expr SimpleCast(const Type& t, Expr value) { - if (value.type() == t) return value; +inline Expr SimpleCast(const DataType& t, Expr value) { + if (value.dtype() == t) return value; return ir::Cast::make(t, value); } // The public function with a quick checking path. void BinaryOpMatchTypes(Expr& lhs, Expr& rhs) { // NOLINT(*) - if (lhs.type() == rhs.type()) return; - Type ltype = lhs.type(); - Type rtype = rhs.type(); + if (lhs.dtype() == rhs.dtype()) return; + DataType ltype = lhs.dtype(); + DataType rtype = rhs.dtype(); if (ltype.lanes() == 1 && rtype.lanes() != 1) { lhs = ir::Broadcast::make(lhs, rtype.lanes()); } else if (rtype.lanes() == 1 && ltype.lanes() != 1) { @@ -48,37 +48,96 @@ void BinaryOpMatchTypes(Expr& lhs, Expr& rhs) { // NOLINT(*) CHECK(ltype.lanes() == rtype.lanes()) << "Cannot match type " << ltype << " vs " << rtype; } - if (lhs.type() == rhs.type()) return; + if (lhs.dtype() == rhs.dtype()) return; // Only do very simple type coversion - // int->float, int(32)->int(64) + // int->float, DataType::Int(32)->int(64) // require the types to be relatively consistent // This will the reduce amount code generated by operators // and also help user to find potential type conversion problems. - if (!lhs.type().is_float() && rhs.type().is_float()) { + if (!lhs.dtype().is_float() && rhs.dtype().is_float()) { // int->float - lhs = cast(rhs.type(), lhs); - } else if (lhs.type().is_float() && !rhs.type().is_float()) { + lhs = cast(rhs.dtype(), lhs); + } else if (lhs.dtype().is_float() && !rhs.dtype().is_float()) { // int->float - rhs = cast(lhs.type(), rhs); - } else if ((lhs.type().is_int() && rhs.type().is_int()) || - (lhs.type().is_uint() && rhs.type().is_uint())) { + rhs = cast(lhs.dtype(), rhs); + } else if ((lhs.dtype().is_int() && rhs.dtype().is_int()) || + (lhs.dtype().is_uint() && rhs.dtype().is_uint())) { // promote int to higher bits - if (lhs.type().bits() < rhs.type().bits()) { - lhs = cast(rhs.type(), lhs); + if (lhs.dtype().bits() < rhs.dtype().bits()) { + lhs = cast(rhs.dtype(), lhs); } else { - rhs = cast(lhs.type(), rhs); + rhs = cast(lhs.dtype(), rhs); } - } else if ((lhs.type().is_int() && rhs.type().is_uint()) || - (lhs.type().is_uint() && rhs.type().is_int())) { - int bits = std::max(lhs.type().bits(), rhs.type().bits()); - lhs = SimpleCast(Int(bits, lhs.type().lanes()), lhs); - rhs = SimpleCast(Int(bits, rhs.type().lanes()), rhs); + } else if ((lhs.dtype().is_int() && rhs.dtype().is_uint()) || + (lhs.dtype().is_uint() && rhs.dtype().is_int())) { + int bits = std::max(lhs.dtype().bits(), rhs.dtype().bits()); + lhs = SimpleCast(DataType::Int(bits, lhs.dtype().lanes()), lhs); + rhs = SimpleCast(DataType::Int(bits, rhs.dtype().lanes()), rhs); } else { LOG(FATAL) << "Cannot match type " << ltype << " vs " << rtype; } } +// maximum and min limits +Expr max_value(const DataType& dtype) { + using namespace ir; + CHECK_EQ(dtype.lanes(), 1); + if (dtype.is_int()) { + if (dtype.bits() == 64) { + return IntImm::make(dtype, std::numeric_limits::max()); + } else if (dtype.bits() < 64) { + int64_t val = 1; + val = (val << (dtype.bits() - 1)) - 1; + return IntImm::make(dtype, val); + } + } else if (dtype.is_uint()) { + if (dtype.bits() == 64) { + return UIntImm::make(dtype, std::numeric_limits::max()); + } else if (dtype.bits() < 64) { + uint64_t val = 1; + val = (val << static_cast(dtype.bits())) - 1; + return UIntImm::make(dtype, val); + } + } else if (dtype.is_float()) { + if (dtype.bits() == 64) { + return FloatImm::make(dtype, std::numeric_limits::max()); + } else if (dtype.bits() == 32) { + return FloatImm::make(dtype, std::numeric_limits::max()); + } else if (dtype.bits() == 16) { + return FloatImm::make(dtype, 65504.0); + } + } + LOG(FATAL) << "Cannot decide max_value for type" << dtype; + return Expr(); +} + +Expr min_value(const DataType& dtype) { + using namespace ir; + CHECK_EQ(dtype.lanes(), 1); + if (dtype.is_int()) { + if (dtype.bits() == 64) { + return IntImm::make(dtype, std::numeric_limits::lowest()); + } else if (dtype.bits() < 64) { + int64_t val = 1; + val = -(val << (dtype.bits() - 1)); + return IntImm::make(dtype, val); + } + } else if (dtype.is_uint()) { + return UIntImm::make(dtype, 0); + } else if (dtype.is_float()) { + if (dtype.bits() == 64) { + return FloatImm::make(dtype, std::numeric_limits::lowest()); + } else if (dtype.bits() == 32) { + return FloatImm::make(dtype, std::numeric_limits::lowest()); + } else if (dtype.bits() == 16) { + return FloatImm::make(dtype, -65504.0); + } + } + LOG(FATAL) << "Cannot decide min_value for type" << dtype; + return Expr(); +} + template inline bool ConstPowerHelper(ValueType val, int *shift) { if (val <= 0) return false; @@ -103,11 +162,11 @@ bool is_const_power_of_two_integer(const Expr& x, int* shift) { } } -Expr cast(const Type& t, Expr value) { +Expr cast(const DataType& t, Expr value) { using ir::IntImm; using ir::UIntImm; using ir::FloatImm; - if (value.type() == t) return value; + if (value.dtype() == t) return value; // const fold IntImm as they are used in index computations if (t.lanes() == 1) { if (const IntImm* op = value.as()) { @@ -119,10 +178,10 @@ Expr cast(const Type& t, Expr value) { } return ir::Cast::make(t, value); } else { - if (value.type().lanes() == 1) { + if (value.dtype().lanes() == 1) { // manually unroll cast - Type vtype = t.element_of(); - if (value.type() != vtype) { + DataType vtype = t.element_of(); + if (value.dtype() != vtype) { if (const IntImm* op = value.as()) { value = make_const(vtype, op->value); } else if (const UIntImm* op = value.as()) { @@ -135,14 +194,14 @@ Expr cast(const Type& t, Expr value) { } return ir::Broadcast::make(value, t.lanes()); } else { - CHECK(value.type().lanes() == t.lanes()); + CHECK(value.dtype().lanes() == t.lanes()); return ir::Cast::make(t, value); } } } -Expr reinterpret(const Type& t, Expr value) { - if (value.type() == t) return value; +Expr reinterpret(const DataType& t, Expr value) { + if (value.dtype() == t) return value; return ir::Call::make(t, ir::Call::reinterpret, { value }, ir::Call::PureIntrinsic); } @@ -159,9 +218,9 @@ Expr operator-(Expr a) { using ir::FloatImm; const IntImm* pa = a.as(); const FloatImm* fa = a.as(); - if (pa) return ir::IntImm::make(a.type(), -pa->value); - if (fa) return ir::FloatImm::make(a.type(), -fa->value); - return make_zero(a.type()) - a; + if (pa) return ir::IntImm::make(a.dtype(), -pa->value); + if (fa) return ir::FloatImm::make(a.dtype(), -fa->value); + return make_zero(a.dtype()) - a; } Expr operator-(Expr a, Expr b) { @@ -186,8 +245,8 @@ Expr div(Expr a, Expr b) { } Expr truncdiv(Expr a, Expr b) { - CHECK(a.type().is_int() || a.type().is_uint()); - CHECK(b.type().is_int() || b.type().is_uint()); + CHECK(a.dtype().is_int() || a.dtype().is_uint()); + CHECK(b.dtype().is_int() || b.dtype().is_uint()); return div(a, b); } @@ -216,8 +275,8 @@ Expr indexmod(Expr a, Expr b) { } Expr floordiv(Expr a, Expr b) { - CHECK(a.type().is_int() || a.type().is_uint()); - CHECK(b.type().is_int() || b.type().is_uint()); + CHECK(a.dtype().is_int() || a.dtype().is_uint()); + CHECK(b.dtype().is_int() || b.dtype().is_uint()); BinaryOpMatchTypes(a, b); Expr ret = arith::TryConstFold(a, b); if (ret.defined()) return ret; @@ -225,8 +284,8 @@ Expr floordiv(Expr a, Expr b) { } Expr floormod(Expr a, Expr b) { - CHECK(a.type().is_int() || a.type().is_uint()); - CHECK(b.type().is_int() || b.type().is_uint()); + CHECK(a.dtype().is_int() || a.dtype().is_uint()); + CHECK(b.dtype().is_int() || b.dtype().is_uint()); BinaryOpMatchTypes(a, b); Expr ret = arith::TryConstFold(a, b); if (ret.defined()) return ret; @@ -264,7 +323,7 @@ Expr max(Expr a, Expr b) { Expr if_then_else(Expr cond, Expr true_value, Expr false_value) { using ir::IntImm; using ir::UIntImm; - CHECK(cond.type() == Bool(1)) + CHECK(cond.dtype() == DataType::Bool(1)) << "if_then_else only accept the condition to be boolean type."; BinaryOpMatchTypes(true_value, false_value); if (const UIntImm* op = cond.as()) { @@ -281,7 +340,7 @@ Expr if_then_else(Expr cond, Expr true_value, Expr false_value) { } } return ir::Call::make( - true_value.type(), + true_value.dtype(), ir::intrinsic::tvm_if_then_else, {cond, true_value, false_value}, ir::Call::PureIntrinsic); @@ -289,7 +348,7 @@ Expr if_then_else(Expr cond, Expr true_value, Expr false_value) { Expr likely(Expr cond) { if (is_const(cond)) return cond; - return ir::Call::make(cond.type(), ir::Call::likely, { cond }, ir::Call::PureIntrinsic); + return ir::Call::make(cond.dtype(), ir::Call::likely, { cond }, ir::Call::PureIntrinsic); } Expr operator>(Expr a, Expr b) { @@ -335,23 +394,23 @@ Expr operator!=(Expr a, Expr b) { } Expr operator&&(Expr a, Expr b) { - CHECK(a.type().is_bool()); - CHECK(b.type().is_bool()); + CHECK(a.dtype().is_bool()); + CHECK(b.dtype().is_bool()); Expr ret = arith::TryConstFold(a, b); if (ret.defined()) return ret; return ir::And::make(a, b); } Expr operator||(Expr a, Expr b) { - CHECK(a.type().is_bool()); - CHECK(b.type().is_bool()); + CHECK(a.dtype().is_bool()); + CHECK(b.dtype().is_bool()); Expr ret = arith::TryConstFold(a, b); if (ret.defined()) return ret; return ir::Or::make(a, b); } Expr operator!(Expr a) { - CHECK(a.type().is_bool()); + CHECK(a.dtype().is_bool()); Expr ret = arith::TryConstFold(a); if (ret.defined()) return ret; return ir::Not::make(a); @@ -360,211 +419,211 @@ Expr operator!(Expr a) { Expr operator>>(Expr a, Expr b) { BinaryOpMatchTypes(a, b); TVM_INDEX_CONST_PROPAGATION({ - const Type& rtype = a.type(); + const DataType& rtype = a.dtype(); if (pa && pb) return IntImm::make(rtype, (pa->value >> pb->value)); if (pb) { if (pb->value == 0) return a; } }); - return ir::Call::make(a.type(), ir::Call::shift_right, { a, b }, ir::Call::PureIntrinsic); + return ir::Call::make(a.dtype(), ir::Call::shift_right, { a, b }, ir::Call::PureIntrinsic); } Expr operator<<(Expr a, Expr b) { BinaryOpMatchTypes(a, b); TVM_INDEX_CONST_PROPAGATION({ - const Type& rtype = a.type(); + const DataType& rtype = a.dtype(); if (pa && pb) return IntImm::make(rtype, (pa->value << pb->value)); if (pb) { if (pb->value == 0) return a; } }); - return ir::Call::make(a.type(), ir::Call::shift_left, { a, b }, ir::Call::PureIntrinsic); + return ir::Call::make(a.dtype(), ir::Call::shift_left, { a, b }, ir::Call::PureIntrinsic); } Expr operator&(Expr a, Expr b) { BinaryOpMatchTypes(a, b); TVM_INDEX_CONST_PROPAGATION({ - const Type& rtype = a.type(); + const DataType& rtype = a.dtype(); if (pa && pb) return IntImm::make(rtype, (pa->value & pb->value)); }); - return ir::Call::make(a.type(), ir::Call::bitwise_and, { a, b }, ir::Call::PureIntrinsic); + return ir::Call::make(a.dtype(), ir::Call::bitwise_and, { a, b }, ir::Call::PureIntrinsic); } Expr operator|(Expr a, Expr b) { BinaryOpMatchTypes(a, b); TVM_INDEX_CONST_PROPAGATION({ - const Type& rtype = a.type(); + const DataType& rtype = a.dtype(); if (pa && pb) return IntImm::make(rtype, (pa->value | pb->value)); }); - return ir::Call::make(a.type(), ir::Call::bitwise_or, { a, b }, ir::Call::PureIntrinsic); + return ir::Call::make(a.dtype(), ir::Call::bitwise_or, { a, b }, ir::Call::PureIntrinsic); } Expr operator^(Expr a, Expr b) { BinaryOpMatchTypes(a, b); TVM_INDEX_CONST_PROPAGATION({ - const Type& rtype = a.type(); + const DataType& rtype = a.dtype(); if (pa && pb) return IntImm::make(rtype, (pa->value ^ pb->value)); }); - return ir::Call::make(a.type(), ir::Call::bitwise_xor, { a, b }, ir::Call::PureIntrinsic); + return ir::Call::make(a.dtype(), ir::Call::bitwise_xor, { a, b }, ir::Call::PureIntrinsic); } Expr operator~(Expr a) { - CHECK(a.type().is_int() || a.type().is_uint()); - return ir::Call::make(a.type(), ir::Call::bitwise_not, { a }, ir::Call::PureIntrinsic); + CHECK(a.dtype().is_int() || a.dtype().is_uint()); + return ir::Call::make(a.dtype(), ir::Call::bitwise_not, { a }, ir::Call::PureIntrinsic); } Expr pow(Expr x, Expr y) { BinaryOpMatchTypes(x, y); - CHECK(x.type().is_float()) << "power only applies to float"; - return ir::Call::make(x.type(), "pow", { x, y }, ir::Call::PureIntrinsic); + CHECK(x.dtype().is_float()) << "power only applies to float"; + return ir::Call::make(x.dtype(), "pow", { x, y }, ir::Call::PureIntrinsic); } Expr abs(Expr x) { - if (x.type().is_int()) { + if (x.dtype().is_int()) { using ir::IntImm; const IntImm* px = x.as(); if (px) { - return ir::IntImm::make(x.type(), std::abs(px->value)); + return ir::IntImm::make(x.dtype(), std::abs(px->value)); } - return ir::Select::make(x >= make_zero(x.type()), x, -x); - } else if (x.type().is_float()) { + return ir::Select::make(x >= make_zero(x.dtype()), x, -x); + } else if (x.dtype().is_float()) { using ir::FloatImm; const FloatImm* fx = x.as(); if (fx) { - return ir::FloatImm::make(x.type(), std::fabs(fx->value)); + return ir::FloatImm::make(x.dtype(), std::fabs(fx->value)); } - return ir::Call::make(x.type(), "fabs", {x}, ir::Call::PureIntrinsic); - } else if (x.type().is_uint()) { + return ir::Call::make(x.dtype(), "fabs", {x}, ir::Call::PureIntrinsic); + } else if (x.dtype().is_uint()) { return x; } else { - LOG(FATAL) << "Data type " << x.type() + LOG(FATAL) << "Data type " << x.dtype() <<" not supported for absolute op. Skipping absolute op..."; return x; } } Expr isnan(Expr x) { - Type t = Bool(x.type().lanes()); - if (x.type().is_int() || x.type().is_uint()) { + DataType t = DataType::Bool(x.dtype().lanes()); + if (x.dtype().is_int() || x.dtype().is_uint()) { return make_const(t, false); - } else if (x.type().is_float()) { + } else if (x.dtype().is_float()) { using ir::FloatImm; const FloatImm* fx = x.as(); if (fx) { return make_const(t, std::isnan(fx->value)); } - if (x.type().bits() == 16) { + if (x.dtype().bits() == 16) { return ir::Call::make(t, ir::Call::isnan, - {cast(Float(32, t.lanes()), std::move(x))}, + {cast(DataType::Float(32, t.lanes()), std::move(x))}, ir::Call::PureIntrinsic); } else { return ir::Call::make(t, ir::Call::isnan, {x}, ir::Call::PureIntrinsic); } } else { - LOG(FATAL) << "Data type " << x.type() + LOG(FATAL) << "Data type " << x.dtype() <<" not supported for isnan op. Skipping isnan op..."; return x; } } Expr sum(Expr source, Array rdom) { - Var x("x", source.type()), y("y", source.type()); + Var x("x", source.dtype()), y("y", source.dtype()); Expr result = ir::Add::make(x, y); - Expr identity_element = make_zero(source.type()); + Expr identity_element = make_zero(source.dtype()); ir::CommReducer combiner = ir::CommReducerNode::make({x}, {y}, {result}, {identity_element}); - return ir::Reduce::make(combiner, {source}, rdom, make_const(Bool(1), true), 0); + return ir::Reduce::make(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0); } Expr all(Expr source, Array rdom) { - CHECK(source.type().is_bool()); - Var x("x", source.type()), y("y", source.type()); + CHECK(source.dtype().is_bool()); + Var x("x", source.dtype()), y("y", source.dtype()); Expr result = ir::And::make(x, y); - Expr identity_element = make_const(source.type(), true); + Expr identity_element = make_const(source.dtype(), true); ir::CommReducer combiner = ir::CommReducerNode::make({x}, {y}, {result}, {identity_element}); - return ir::Reduce::make(combiner, {source}, rdom, make_const(Bool(1), true), 0); + return ir::Reduce::make(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0); } Expr any(Expr source, Array rdom) { - CHECK(source.type().is_bool()); - Var x("x", source.type()), y("y", source.type()); + CHECK(source.dtype().is_bool()); + Var x("x", source.dtype()), y("y", source.dtype()); Expr result = ir::Or::make(x, y); - Expr identity_element = make_const(source.type(), false); + Expr identity_element = make_const(source.dtype(), false); ir::CommReducer combiner = ir::CommReducerNode::make({x}, {y}, {result}, {identity_element}); - return ir::Reduce::make(combiner, {source}, rdom, make_const(Bool(1), true), 0); + return ir::Reduce::make(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0); } Expr max(Expr source, Array rdom) { - Var x("x", source.type()), y("y", source.type()); + Var x("x", source.dtype()), y("y", source.dtype()); Expr result = ir::Max::make(x, y); - Expr identity_element = source.type().min(); + Expr identity_element = min_value(source.dtype()); ir::CommReducer combiner = ir::CommReducerNode::make({x}, {y}, {result}, {identity_element}); - return ir::Reduce::make(combiner, {source}, rdom, make_const(Bool(1), true), 0); + return ir::Reduce::make(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0); } Expr min(Expr source, Array rdom) { - Var x("x", source.type()), y("y", source.type()); + Var x("x", source.dtype()), y("y", source.dtype()); Expr result = ir::Min::make(x, y); - Expr identity_element = source.type().max(); + Expr identity_element = max_value(source.dtype()); ir::CommReducer combiner = ir::CommReducerNode::make({x}, {y}, {result}, {identity_element}); - return ir::Reduce::make(combiner, {source}, rdom, make_const(Bool(1), true), 0); + return ir::Reduce::make(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0); } Expr prod(Expr source, Array rdom) { - Var x("x", source.type()), y("y", source.type()); + Var x("x", source.dtype()), y("y", source.dtype()); Expr result = ir::Mul::make(x, y); - Expr identity_element = make_const(source.type(), 1); + Expr identity_element = make_const(source.dtype(), 1); ir::CommReducer combiner = ir::CommReducerNode::make({x}, {y}, {result}, {identity_element}); - return ir::Reduce::make(combiner, {source}, rdom, make_const(Bool(1), true), 0); + return ir::Reduce::make(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0); } Expr fmod(Expr x, Expr y) { BinaryOpMatchTypes(x, y); - CHECK(x.type().is_float()) << "fmod only applies to float"; - return ir::Call::make(x.type(), "fmod", { x, y }, ir::Call::PureIntrinsic); + CHECK(x.dtype().is_float()) << "fmod only applies to float"; + return ir::Call::make(x.dtype(), "fmod", { x, y }, ir::Call::PureIntrinsic); } Expr floor(Expr x) { using ir::FloatImm; const FloatImm* fx = x.as(); - if (fx) return FloatImm::make(x.type(), std::floor(fx->value)); - return ir::Call::make(x.type(), "floor", {x}, ir::Call::PureIntrinsic); + if (fx) return FloatImm::make(x.dtype(), std::floor(fx->value)); + return ir::Call::make(x.dtype(), "floor", {x}, ir::Call::PureIntrinsic); } Expr ceil(Expr x) { using ir::FloatImm; const FloatImm* fx = x.as(); - if (fx) return FloatImm::make(x.type(), std::ceil(fx->value)); - return ir::Call::make(x.type(), "ceil", {x}, ir::Call::PureIntrinsic); + if (fx) return FloatImm::make(x.dtype(), std::ceil(fx->value)); + return ir::Call::make(x.dtype(), "ceil", {x}, ir::Call::PureIntrinsic); } Expr round(Expr x) { using ir::FloatImm; const FloatImm* fx = x.as(); - if (fx) return FloatImm::make(x.type(), std::nearbyint(fx->value)); - return ir::Call::make(x.type(), "round", {x}, ir::Call::PureIntrinsic); + if (fx) return FloatImm::make(x.dtype(), std::nearbyint(fx->value)); + return ir::Call::make(x.dtype(), "round", {x}, ir::Call::PureIntrinsic); } Expr nearbyint(Expr x) { using ir::FloatImm; const FloatImm* fx = x.as(); - if (fx) return FloatImm::make(x.type(), std::nearbyint(fx->value)); - return ir::Call::make(x.type(), "nearbyint", {x}, ir::Call::PureIntrinsic); + if (fx) return FloatImm::make(x.dtype(), std::nearbyint(fx->value)); + return ir::Call::make(x.dtype(), "nearbyint", {x}, ir::Call::PureIntrinsic); } Expr trunc(Expr x) { using ir::FloatImm; const FloatImm* fx = x.as(); if (fx) { - return FloatImm::make(x.type(), (fx->value < 0 ? std::ceil(fx->value) : + return FloatImm::make(x.dtype(), (fx->value < 0 ? std::ceil(fx->value) : std::floor(fx->value))); } - return ir::Call::make(x.type(), "trunc", {x}, ir::Call::PureIntrinsic); + return ir::Call::make(x.dtype(), "trunc", {x}, ir::Call::PureIntrinsic); } } // namespace tvm diff --git a/src/lang/ir.cc b/src/lang/ir.cc index bb8401dae843..427e026bc728 100644 --- a/src/lang/ir.cc +++ b/src/lang/ir.cc @@ -35,7 +35,7 @@ Expr UIntImm::make(DataType t, uint64_t value) { CHECK(t.is_uint() && t.lanes() == 1) << "ValueError: UIntImm can only take scalar"; NodePtr node = make_node(); - node->type = t; + node->dtype = t; node->value = value; return Expr(node); } @@ -44,23 +44,23 @@ Expr FloatImm::make(DataType t, double value) { CHECK_EQ(t.lanes(), 1) << "ValueError: FloatImm can only take scalar"; NodePtr node = make_node(); - node->type = t; + node->dtype = t; node->value = value; return Expr(node); } Expr StringImm::make(std::string value) { NodePtr node = make_node(); - node->type = Handle(); + node->dtype = DataType::Handle(); node->value = std::move(value); return Expr(node); } Expr Cast::make(DataType t, Expr value) { CHECK(value.defined()); - CHECK_EQ(t.lanes(), value.type().lanes()); + CHECK_EQ(t.lanes(), value.dtype().lanes()); NodePtr node = make_node(); - node->type = t; + node->dtype = t; node->value = std::move(value); return Expr(node); } @@ -68,12 +68,12 @@ Expr Cast::make(DataType t, Expr value) { Expr And::make(Expr a, Expr b) { CHECK(a.defined()) << "ValueError: a is undefined"; CHECK(b.defined()) << "ValueError: b is undefined"; - CHECK(a.type().is_bool()); - CHECK(b.type().is_bool()); - CHECK(a.type() == b.type()) << "TypeError: mismatched types"; + CHECK(a.dtype().is_bool()); + CHECK(b.dtype().is_bool()); + CHECK(a.dtype() == b.dtype()) << "TypeError: mismatched types"; NodePtr node = make_node(); - node->type = Bool(a.type().lanes()); + node->dtype = DataType::Bool(a.dtype().lanes()); node->a = std::move(a); node->b = std::move(b); return Expr(node); @@ -82,12 +82,12 @@ Expr And::make(Expr a, Expr b) { Expr Or::make(Expr a, Expr b) { CHECK(a.defined()) << "ValueError: a is undefined"; CHECK(b.defined()) << "ValueError: b is undefined"; - CHECK(a.type().is_bool()); - CHECK(b.type().is_bool()); - CHECK(a.type() == b.type()) << "TypeError: mismatched types"; + CHECK(a.dtype().is_bool()); + CHECK(b.dtype().is_bool()); + CHECK(a.dtype() == b.dtype()) << "TypeError: mismatched types"; NodePtr node = make_node(); - node->type = Bool(a.type().lanes()); + node->dtype = DataType::Bool(a.dtype().lanes()); node->a = std::move(a); node->b = std::move(b); return Expr(node); @@ -95,10 +95,10 @@ Expr Or::make(Expr a, Expr b) { Expr Not::make(Expr a) { CHECK(a.defined()) << "ValueError: a is undefined"; - CHECK(a.type().is_bool()); + CHECK(a.dtype().is_bool()); NodePtr node = make_node(); - node->type = Bool(a.type().lanes()); + node->dtype = DataType::Bool(a.dtype().lanes()); node->a = std::move(a); return Expr(node); } @@ -107,27 +107,27 @@ Expr Select::make(Expr condition, Expr true_value, Expr false_value) { CHECK(condition.defined()) << "ValueError: condition is undefined"; CHECK(true_value.defined()) << "ValueError: true_value is undefined"; CHECK(false_value.defined()) << "ValueError: true_value is undefined"; - CHECK(condition.type().is_bool()); - CHECK_EQ(condition.type().lanes(), true_value.type().lanes()); - CHECK(false_value.type() == true_value.type()) << "TypeError: mismatched types"; + CHECK(condition.dtype().is_bool()); + CHECK_EQ(condition.dtype().lanes(), true_value.dtype().lanes()); + CHECK(false_value.dtype() == true_value.dtype()) << "TypeError: mismatched types"; NodePtr(); - node->type = true_value.type(); + node->dtype = true_value.dtype(); node->condition = std::move(condition); node->true_value = std::move(true_value); node->false_value = std::move(false_value); return Expr(node); } -Expr Load::make(DataType type, Var buffer_var, Expr index, Expr predicate) { +Expr Load::make(DataType dtype, Var buffer_var, Expr index, Expr predicate) { CHECK(buffer_var.defined()); CHECK(predicate.defined()); CHECK(index.defined()); - CHECK_EQ(type.lanes(), index.type().lanes()); - CHECK_EQ(type.lanes(), predicate.type().lanes()); + CHECK_EQ(dtype.lanes(), index.dtype().lanes()); + CHECK_EQ(dtype.lanes(), predicate.dtype().lanes()); NodePtr node = make_node(); - node->type = type; + node->dtype = dtype; node->buffer_var = std::move(buffer_var); node->index = std::move(index); node->predicate = std::move(predicate); @@ -138,13 +138,13 @@ Expr Load::make(DataType type, Var buffer_var, Expr index, Expr predicate) { Expr Ramp::make(Expr base, Expr stride, int lanes) { CHECK(base.defined()); CHECK(stride.defined()); - CHECK(base.type().is_scalar()); - CHECK(stride.type().is_scalar()); + CHECK(base.dtype().is_scalar()); + CHECK(stride.dtype().is_scalar()); CHECK_GT(lanes, 1); - CHECK_EQ(stride.type(), base.type()); + CHECK_EQ(stride.dtype(), base.dtype()); NodePtr node = make_node(); - node->type = base.type().with_lanes(lanes); + node->dtype = base.dtype().with_lanes(lanes); node->base = base; node->stride = stride; node->lanes = lanes; @@ -153,11 +153,11 @@ Expr Ramp::make(Expr base, Expr stride, int lanes) { Expr Broadcast::make(Expr value, int lanes) { CHECK(value.defined()); - CHECK(value.type().is_scalar()); + CHECK(value.dtype().is_scalar()); CHECK_GT(lanes, 1); NodePtr node = make_node(); - node->type = value.type().with_lanes(lanes); + node->dtype = value.dtype().with_lanes(lanes); node->value = std::move(value); node->lanes = lanes; return Expr(node); @@ -166,10 +166,10 @@ Expr Broadcast::make(Expr value, int lanes) { Expr Let::make(Var var, Expr value, Expr body) { CHECK(value.defined()); CHECK(body.defined()); - CHECK_EQ(value.type(), var.type()); + CHECK_EQ(value.dtype(), var.dtype()); NodePtr node = make_node(); - node->type = body.type(); + node->dtype = body.dtype(); node->var = std::move(var); node->value = std::move(value); node->body = std::move(body); @@ -192,7 +192,7 @@ bool Call::is_vectorizable() const { return false; } -Expr Call::make(DataType type, +Expr Call::make(DataType dtype, std::string name, Array args, CallType call_type, @@ -204,12 +204,12 @@ Expr Call::make(DataType type, if (call_type == Halide) { for (size_t i = 0; i < args.size(); ++i) { - CHECK(args[i].type().is_int()); + CHECK(args[i].dtype().is_int()); } } NodePtr node = make_node(); - node->type = type; + node->dtype = dtype; node->name = std::move(name); node->args = std::move(args); node->call_type = call_type; @@ -223,17 +223,17 @@ Expr Shuffle::make(Array vectors, CHECK_NE(vectors.size(), 0U); CHECK_NE(indices.size(), 0U); - Type base_type = vectors[0].type().element_of(); + DataType base_type = vectors[0].dtype().element_of(); int total_lanes = 0; for (Expr val : vectors) { - CHECK(val.type().element_of() == base_type); - total_lanes += val.type().lanes(); + CHECK(val.dtype().element_of() == base_type); + total_lanes += val.dtype().lanes(); } CHECK_LE(indices.size(), static_cast(total_lanes)); NodePtr node = make_node(); - node->type = base_type.with_lanes(static_cast(indices.size())); + node->dtype = base_type.with_lanes(static_cast(indices.size())); node->vectors = std::move(vectors); node->indices = std::move(indices); return Expr(node); @@ -247,8 +247,8 @@ Expr Shuffle::make_concat(Array vectors) { Array indices; int index = 0; for (const Expr& e : vectors) { - for (int i = 0; i < e.type().lanes(); ++i) { - indices.push_back(IntImm::make(Int(32), index++)); + for (int i = 0; i < e.dtype().lanes(); ++i) { + indices.push_back(IntImm::make(DataType::Int(32), index++)); } } return make(vectors, indices); @@ -298,7 +298,7 @@ Expr Reduce::make(CommReducer combiner, Array source, for (size_t i = 0; i < axis.size(); ++i) { CHECK(axis[i].defined()); } - n->type = source[value_index].type(); + n->dtype = source[value_index].dtype(); n->combiner = std::move(combiner); n->source = std::move(source); n->axis = std::move(axis); @@ -315,7 +315,7 @@ Expr Any::make() { Stmt LetStmt::make(Var var, Expr value, Stmt body) { CHECK(value.defined()); CHECK(body.defined()); - CHECK_EQ(value.type(), var.type()); + CHECK_EQ(value.dtype(), var.dtype()); NodePtr node = make_node(); node->var = std::move(var); @@ -338,7 +338,7 @@ Stmt AttrStmt::make(NodeRef node, Stmt AssertStmt::make(Expr condition, Expr message, Stmt body) { CHECK(condition.defined()); - CHECK(message.type() == Int(32) || + CHECK(message.dtype() == DataType::Int(32) || message.as()) << "TypeError: AssertStmt message must be an int or string:" << message << "\n"; @@ -368,9 +368,9 @@ Stmt For::make(Var loop_var, Stmt body) { CHECK(min.defined()); CHECK(extent.defined()); - CHECK(min.type().is_scalar()); - CHECK(extent.type().is_scalar()); - CHECK(loop_var.type().is_scalar()); + CHECK(min.dtype().is_scalar()); + CHECK(extent.dtype().is_scalar()); + CHECK(loop_var.dtype().is_scalar()); CHECK(body.defined()); NodePtr node = make_node(); @@ -387,8 +387,8 @@ Stmt Store::make(Var buffer_var, Expr value, Expr index, Expr predicate) { CHECK(value.defined()); CHECK(index.defined()); CHECK(predicate.defined()); - CHECK_EQ(value.type().lanes(), index.type().lanes()); - CHECK_EQ(value.type().lanes(), predicate.type().lanes()); + CHECK_EQ(value.dtype().lanes(), index.dtype().lanes()); + CHECK_EQ(value.dtype().lanes(), predicate.dtype().lanes()); NodePtr node = make_node(); node->buffer_var = std::move(buffer_var); @@ -416,7 +416,7 @@ Stmt Provide::make(FunctionRef func, int value_index, Expr value, Array ar } Stmt Allocate::make(Var buffer_var, - DataType type, + DataType dtype, Array extents, Expr condition, Stmt body, @@ -424,15 +424,15 @@ Stmt Allocate::make(Var buffer_var, std::string free_function) { for (size_t i = 0; i < extents.size(); ++i) { CHECK(extents[i].defined()); - CHECK(extents[i].type().is_scalar()); + CHECK(extents[i].dtype().is_scalar()); } CHECK(body.defined()); CHECK(condition.defined()); - CHECK(condition.type().is_bool()); + CHECK(condition.dtype().is_bool()); NodePtr node = make_node(); node->buffer_var = std::move(buffer_var); - node->type = type; + node->dtype = dtype; node->extents = std::move(extents); node->condition = std::move(condition); node->body = std::move(body); @@ -464,42 +464,42 @@ Stmt Free::make(Var buffer_var) { Stmt Realize::make(FunctionRef func, int value_index, - DataType type, + DataType dtype, Region bounds, Expr condition, Stmt body) { for (size_t i = 0; i < bounds.size(); ++i) { CHECK(bounds[i]->min.defined()); CHECK(bounds[i]->extent.defined()); - CHECK(bounds[i]->min.type().is_scalar()); - CHECK(bounds[i]->extent.type().is_scalar()); + CHECK(bounds[i]->min.dtype().is_scalar()); + CHECK(bounds[i]->extent.dtype().is_scalar()); } CHECK(body.defined()); CHECK(condition.defined()); - CHECK(condition.type().is_bool()); + CHECK(condition.dtype().is_bool()); NodePtr node = make_node(); node->func = std::move(func); node->value_index = value_index; - node->type = type; + node->dtype = dtype; node->bounds = std::move(bounds); node->condition = std::move(condition); node->body = std::move(body); return Stmt(node); } -Stmt Prefetch::make(FunctionRef func, int value_index, DataType type, Region bounds) { +Stmt Prefetch::make(FunctionRef func, int value_index, DataType dtype, Region bounds) { for (size_t i = 0; i < bounds.size(); ++i) { CHECK(bounds[i]->min.defined()); CHECK(bounds[i]->extent.defined()); - CHECK(bounds[i]->min.type().is_scalar()); - CHECK(bounds[i]->extent.type().is_scalar()); + CHECK(bounds[i]->min.dtype().is_scalar()); + CHECK(bounds[i]->extent.dtype().is_scalar()); } NodePtr node = make_node(); node->func = std::move(func); node->value_index = value_index; - node->type = type; + node->dtype = dtype; node->bounds = std::move(bounds); return Stmt(node); } @@ -555,14 +555,14 @@ Stmt Evaluate::make(Expr value) { TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) .set_dispatch([](const ObjectRef& node, IRPrinter* p) { auto* op = static_cast(node.get()); - p->stream << "(" << op->type << ")" << op->value; + p->stream << "(" << op->dtype << ")" << op->value; }); TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) .set_dispatch([](const ObjectRef& node, IRPrinter* p) { auto* op = static_cast(node.get()); auto& stream = p->stream; - switch (op->type.bits()) { + switch (op->dtype.bits()) { case 64: stream << op->value; break; @@ -573,7 +573,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) stream << op->value << 'h'; break; default: - LOG(FATAL) << "Unknown float type bits=" << op->type.bits(); + LOG(FATAL) << "Unknown float type bits=" << op->dtype.bits(); } }); @@ -616,7 +616,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) .set_dispatch([](const ObjectRef& node, IRPrinter* p) { auto* op = static_cast(node.get()); - p->stream << op->type << '('; + p->stream << op->dtype << '('; p->Print(op->value); p->stream << ')'; }) @@ -959,7 +959,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) .set_dispatch([](const ObjectRef& node, IRPrinter* p) { auto* op = static_cast(node.get()); p->PrintIndent(); - p->stream << "allocate " << op->buffer_var << "[" << op->type; + p->stream << "allocate " << op->buffer_var << "[" << op->dtype; for (size_t i = 0; i < op->extents.size(); ++i) { p->stream << " * "; p->Print(op->extents[i]); diff --git a/src/lang/tensor.cc b/src/lang/tensor.cc index 05ba6f7a08bd..1c110936b3ef 100644 --- a/src/lang/tensor.cc +++ b/src/lang/tensor.cc @@ -56,7 +56,7 @@ Tensor Operation::output(size_t i) const { } Tensor TensorNode::make(Array shape, - Type dtype, + DataType dtype, Operation op, int value_index) { auto n = make_node(); diff --git a/src/node/reflection.cc b/src/node/reflection.cc index e92ca92834a2..f53583723f24 100644 --- a/src/node/reflection.cc +++ b/src/node/reflection.cc @@ -61,7 +61,7 @@ class AttrGetter : public AttrVisitor { void Visit(const char* key, void** value) final { if (skey == key) *ret = static_cast(value[0]); } - void Visit(const char* key, Type* value) final { + void Visit(const char* key, DataType* value) final { if (skey == key) *ret = value[0]; } void Visit(const char* key, std::string* value) final { @@ -135,7 +135,7 @@ class AttrDir : public AttrVisitor { void Visit(const char* key, void** value) final { names->push_back(key); } - void Visit(const char* key, Type* value) final { + void Visit(const char* key, DataType* value) final { names->push_back(key); } void Visit(const char* key, std::string* value) final { diff --git a/src/node/serialization.cc b/src/node/serialization.cc index cb310eb2cda9..5a991aa3ad1b 100644 --- a/src/node/serialization.cc +++ b/src/node/serialization.cc @@ -39,11 +39,11 @@ namespace tvm { inline std::string Type2String(const DataType& t) { - return runtime::TVMType2String(Type2TVMType(t)); + return runtime::TVMType2String(t); } -inline Type String2Type(std::string s) { - return TVMType2Type(runtime::String2TVMType(s)); +inline DataType String2Type(std::string s) { + return DataType(runtime::String2TVMType(s)); } // indexer to index all the nodes diff --git a/src/op/compute_op.cc b/src/op/compute_op.cc index 5f5d2d4f475b..bd129ac33058 100644 --- a/src/op/compute_op.cc +++ b/src/op/compute_op.cc @@ -70,9 +70,9 @@ Array BaseComputeOpNode::root_iter_vars() const { return ret; } -Type ComputeOpNode::output_dtype(size_t idx) const { +DataType ComputeOpNode::output_dtype(size_t idx) const { CHECK_LT(idx, num_outputs()); - return body[idx].type(); + return body[idx].dtype(); } Array BaseComputeOpNode::output_shape(size_t idx) const { @@ -100,7 +100,7 @@ Tensor compute(Array shape, std::ostringstream os; os << "ax" << i; axis.emplace_back(IterVarNode::make( - Range(0, shape[i]), Var(os.str(), shape[i].type()), kDataPar)); + Range(0, shape[i]), Var(os.str(), shape[i].dtype()), kDataPar)); args.push_back(axis.back()->var); } @@ -122,7 +122,7 @@ Array compute(Array shape, std::ostringstream os; os << "ax" << i; axis.emplace_back(IterVarNode::make( - Range(0, shape[i]), Var(os.str(), shape[i].type()), kDataPar)); + Range(0, shape[i]), Var(os.str(), shape[i].dtype()), kDataPar)); args.push_back(axis.back()->var); } @@ -190,7 +190,7 @@ Operation ComputeOpNode::ReplaceInputs( for (size_t k = 0; k < this->body.size(); ++k) { auto n = make_node(*r); n->value_index = static_cast(k); - n->type = r->source[k].type(); + n->dtype = r->source[k].dtype(); arr.push_back(Expr(n)); } } else { @@ -229,7 +229,7 @@ void ComputeOpNode::PropBoundToInputs( IntSet arg_intset = EvalSet(call->args[i], dom_map); const arith::IntervalSetNode* arg_interval = arg_intset.as(); if (arg_interval) { - Expr shape_i_min_value = make_zero(t->shape[i].type()); + Expr shape_i_min_value = make_zero(t->shape[i].dtype()); Expr shape_i_max_value = t->shape[i] - 1; Expr min_value = arg_interval->min_value; Expr max_value = arg_interval->max_value; @@ -295,7 +295,7 @@ Stmt BaseComputeOpNode::BuildRealize( attr->dim_align_offset}; realize = ir::AttrStmt::make( t, ir::attr::buffer_dim_align, - Call::make(Handle(), ir::intrinsic::tvm_tuple, tuple, Call::Intrinsic), + Call::make(DataType::Handle(), ir::intrinsic::tvm_tuple, tuple, Call::Intrinsic), realize); } } diff --git a/src/op/cross_thread_reduction.cc b/src/op/cross_thread_reduction.cc index 818acb912f9c..4a3aa54ccc6d 100644 --- a/src/op/cross_thread_reduction.cc +++ b/src/op/cross_thread_reduction.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -57,14 +57,14 @@ Stmt MakeCrossThreadReduction( cond = cond && v; } Array freduce_args; - freduce_args.push_back(make_const(UInt(32), static_cast(size))); + freduce_args.push_back(make_const(DataType::UInt(32), static_cast(size))); for (size_t i = 0; i < size; ++i) { freduce_args.push_back(reduces[0]->source[i]); } freduce_args.push_back(cond); std::vector res_handles(size); for (size_t idx = 0; idx < size; ++idx) { - res_handles[idx] = Var("reduce_temp" + std::to_string(idx), Handle()); + res_handles[idx] = Var("reduce_temp" + std::to_string(idx), DataType::Handle()); freduce_args.push_back(res_handles[idx]); } @@ -85,17 +85,17 @@ Stmt MakeCrossThreadReduction( } Stmt reduce_body = Evaluate::make(Call::make( - Handle(), + DataType::Handle(), ir::intrinsic::tvm_thread_allreduce, freduce_args, Call::Intrinsic)); reduce_body = AttrStmt::make( reduces[0]->combiner, attr::reduce_scope, - make_zero(Handle()), + make_zero(DataType::Handle()), reduce_body); std::vector assigns(size); for (size_t idx = 0; idx < size; ++idx) { - Type t = reduces[idx]->type; + DataType t = reduces[idx]->dtype; assigns[idx] = Provide::make( stage->op, idx, Load::make(t, res_handles[idx], 0, const_true(t.lanes())), args); @@ -106,7 +106,7 @@ Stmt MakeCrossThreadReduction( Stmt body = Block::make(reduce_body, assign_body); for (size_t idx = size; idx != 0; --idx) { body = Allocate::make( - res_handles[idx - 1], reduces[idx - 1]->type, {1}, const_true(), body); + res_handles[idx - 1], reduces[idx - 1]->dtype, {1}, const_true(), body); body = AttrStmt::make( res_handles[idx - 1], attr::storage_scope, StringImm::make("local"), body); } diff --git a/src/op/extern_op.cc b/src/op/extern_op.cc index 35fe469fbe16..883ebdc4a0f7 100644 --- a/src/op/extern_op.cc +++ b/src/op/extern_op.cc @@ -46,7 +46,7 @@ Array ExternOpNode::root_iter_vars() const { return {}; } -Type ExternOpNode::output_dtype(size_t i) const { +DataType ExternOpNode::output_dtype(size_t i) const { return output_placeholders[i]->dtype; } @@ -122,7 +122,7 @@ void ExternOpNode::PropBoundToInputs( for (size_t i = 0; i < t->shape.size(); ++i) { dom.data[i].emplace_back(IntSet::range( Range::make_by_min_extent( - make_const(t->shape[i].type(), 0), t->shape[i]))); + make_const(t->shape[i].dtype(), 0), t->shape[i]))); } } } @@ -145,7 +145,7 @@ Stmt ExternOpNode::BuildRealize( for (size_t i = 0; i < t->shape.size(); ++i) { bounds.push_back( Range::make_by_min_extent( - make_const(t->shape[i].type(), 0), t->shape[i])); + make_const(t->shape[i].dtype(), 0), t->shape[i])); } realize_body = ir::Realize::make( t->op, t->value_index, t->dtype, @@ -159,19 +159,19 @@ Stmt ExternOpNode::BuildProvide( const std::unordered_map& dom_map, bool debug_keep_trivial_loop) const { CHECK_EQ(stage->op.operator->(), this); - Stmt ret = AttrStmt::make(make_zero(Int(32)), attr::extern_scope, 0, this->body); + Stmt ret = AttrStmt::make(make_zero(DataType::Int(32)), attr::extern_scope, 0, this->body); auto f_push_bind = [&ret](Buffer buffer, Tensor tensor) { Array bind_spec; Array tuple; bind_spec.push_back(buffer); bind_spec.push_back(tensor); for (size_t k = 0; k < buffer->shape.size(); ++k) { - tuple.push_back(make_const(buffer->shape[k].type(), 0)); + tuple.push_back(make_const(buffer->shape[k].dtype(), 0)); tuple.push_back(buffer->shape[k]); } ret = AttrStmt::make( bind_spec, attr::buffer_bind_scope, - Call::make(Handle(), intrinsic::tvm_tuple, tuple, Call::Intrinsic), ret); + Call::make(DataType::Handle(), intrinsic::tvm_tuple, tuple, Call::Intrinsic), ret); }; for (size_t i = output_placeholders.size(); i != 0; --i) { f_push_bind(output_placeholders[i - 1], stage->op.output(i - 1)); diff --git a/src/op/hybrid_op.cc b/src/op/hybrid_op.cc index 7a99ea10b74d..1e1a81423b69 100644 --- a/src/op/hybrid_op.cc +++ b/src/op/hybrid_op.cc @@ -52,7 +52,7 @@ Array HybridOpNode::root_iter_vars() const { return this->axis; } -Type HybridOpNode::output_dtype(size_t i) const { +DataType HybridOpNode::output_dtype(size_t i) const { return outputs[i]->dtype; } @@ -138,7 +138,7 @@ void HybridOpNode::PropBoundToInputs( for (size_t i = 0; i < t->shape.size(); ++i) { dom.data[i].emplace_back(IntSet::range( Range::make_by_min_extent( - make_const(t->shape[i].type(), 0), t->shape[i]))); + make_const(t->shape[i].dtype(), 0), t->shape[i]))); } } } @@ -166,7 +166,7 @@ Stmt HybridOpNode::BuildRealize( for (size_t i = 0; i < t->shape.size(); ++i) { bounds.push_back( Range::make_by_min_extent( - make_const(t->shape[i].type(), 0), t->shape[i])); + make_const(t->shape[i].dtype(), 0), t->shape[i])); } realize_body = ir::Realize::make( t->op, t->value_index, t->dtype, @@ -180,7 +180,7 @@ Stmt HybridOpNode::BuildProvide( const std::unordered_map &dom_map, bool debug_keep_trivial_loop) const { CHECK_EQ(stage->op.operator->(), this); - Stmt ret = AttrStmt::make(make_zero(Int(32)), attr::extern_scope, 0, this->body); + Stmt ret = AttrStmt::make(make_zero(DataType::Int(32)), attr::extern_scope, 0, this->body); std::unordered_map rmap; for (int i = 0; i < this->num_outputs(); ++i) { rmap[outputs[i]] = stage->op.output(i); diff --git a/src/op/op_util.cc b/src/op/op_util.cc index 691603157b1c..cd3b168d810b 100644 --- a/src/op/op_util.cc +++ b/src/op/op_util.cc @@ -74,7 +74,7 @@ MakeLoopNest(const Stage& stage, if (bind_iv->thread_tag.length() == 0) { // Only generate new loop if we're not bound to a thread. if (new_loop_var) { - var = Var(iv->var->name_hint + ".init", bind_iv->var.type()); + var = Var(iv->var->name_hint + ".init", bind_iv->var.dtype()); } ForType for_type = ForType::Serial; @@ -98,7 +98,7 @@ MakeLoopNest(const Stage& stage, const std::string& pkey = it_attr->pragma_keys[k].as()->value; Expr pvalue = it_attr->pragma_values[k]; if (!pvalue.defined()) { - pvalue = make_const(Int(32), 1); + pvalue = make_const(DataType::Int(32), 1); } nest[i + 1].emplace_back( AttrStmt::make(iv, ir::attr::pragma_scope_prefix + pkey, pvalue, no_op)); @@ -114,7 +114,7 @@ MakeLoopNest(const Stage& stage, for_type, DeviceAPI::None, no_op)); value_map[iv] = var; } else { - Var idx(bind_iv->var->name_hint + ".idx", bind_iv->var.type()); + Var idx(bind_iv->var->name_hint + ".idx", bind_iv->var.dtype()); nest[i + 1].emplace_back( For::make(idx, 0, dom->extent, for_type, DeviceAPI::None, no_op)); @@ -197,7 +197,7 @@ class TensorReplacer : public ir::IRMutator { auto it = vmap_.find(t); if (it != vmap_.end()) { Expr ret = ir::Call::make( - op->type, it->second->op->name, op->args, + op->dtype, it->second->op->name, op->args, op->call_type, it->second->op, it->second->value_index); found = true; return IRMutator::Mutate_(ret.as(), ret); diff --git a/src/op/placeholder_op.cc b/src/op/placeholder_op.cc index 91b0589e3dd0..6910f63b44d3 100644 --- a/src/op/placeholder_op.cc +++ b/src/op/placeholder_op.cc @@ -42,7 +42,7 @@ Array PlaceholderOpNode::root_iter_vars() const { return {}; } -Type PlaceholderOpNode::output_dtype(size_t i) const { +DataType PlaceholderOpNode::output_dtype(size_t i) const { CHECK_EQ(i, 0U); return dtype; } @@ -54,7 +54,7 @@ Array PlaceholderOpNode::output_shape(size_t i) const { Operation PlaceholderOpNode::make(std::string name, Array shape, - Type dtype) { + DataType dtype) { auto n = make_node(); n->name = name; n->shape = shape; @@ -62,7 +62,7 @@ Operation PlaceholderOpNode::make(std::string name, return Operation(n); } -Tensor placeholder(Array shape, Type dtype, std::string name) { +Tensor placeholder(Array shape, DataType dtype, std::string name) { return PlaceholderOpNode::make(name, shape, dtype).output(0); } diff --git a/src/op/scan_op.cc b/src/op/scan_op.cc index b02073b5357e..e83a23194cf8 100644 --- a/src/op/scan_op.cc +++ b/src/op/scan_op.cc @@ -53,7 +53,7 @@ Array ScanOpNode::root_iter_vars() const { return ret; } -Type ScanOpNode::output_dtype(size_t i) const { +DataType ScanOpNode::output_dtype(size_t i) const { return update[i]->dtype; } diff --git a/src/op/tensor_compute_op.cc b/src/op/tensor_compute_op.cc index 83cdd76c2b2a..e59f90f4948e 100644 --- a/src/op/tensor_compute_op.cc +++ b/src/op/tensor_compute_op.cc @@ -46,7 +46,7 @@ int TensorComputeOpNode::num_outputs() const { return static_cast(this->intrin->buffers.size() - this->inputs.size()); } -Type TensorComputeOpNode::output_dtype(size_t i) const { +DataType TensorComputeOpNode::output_dtype(size_t i) const { return this->intrin->buffers[this->inputs.size() + i]->dtype; } @@ -155,7 +155,7 @@ Stmt TensorComputeOpNode::BuildProvide( } input_bind_nest.emplace_back(AttrStmt::make( bind_spec, ir::attr::buffer_bind_scope, - Call::make(Handle(), ir::intrinsic::tvm_tuple, tuple, Call::Intrinsic), nop)); + Call::make(DataType::Handle(), ir::intrinsic::tvm_tuple, tuple, Call::Intrinsic), nop)); } // output binding @@ -179,7 +179,7 @@ Stmt TensorComputeOpNode::BuildProvide( output_bind_nest.emplace_back(AttrStmt::make( bind_spec, ir::attr::buffer_bind_scope, - Call::make(Handle(), ir::intrinsic::tvm_tuple, tuple, Call::Intrinsic), nop)); + Call::make(DataType::Handle(), ir::intrinsic::tvm_tuple, tuple, Call::Intrinsic), nop)); } // Check variable remap diff --git a/src/op/tensorize.cc b/src/op/tensorize.cc index c4abf0b04141..b7f32de8b5ad 100644 --- a/src/op/tensorize.cc +++ b/src/op/tensorize.cc @@ -173,7 +173,7 @@ class TensorIntrinMatcher final : public IRMutator { args.push_back(op->args[i] - e.region[i]->min); } return Call::make( - op->type, e.tensor->op->name, args, + op->dtype, e.tensor->op->name, args, op->call_type, e.tensor->op, e.tensor->value_index); } } @@ -341,12 +341,12 @@ void VerifyTensorizeBody( lhs = CanonicalSimplify(lhs, compute_intrin_iter_space); Expr rhs = Simplify(intrin_compute->body[i], compute_intrin_iter_space); rhs = CanonicalSimplify(rhs, compute_intrin_iter_space); - if (lhs.type() != rhs.type()) { + if (lhs.dtype() != rhs.dtype()) { LOG(FATAL) << "Failed to match the data type with TensorIntrin " << intrin->name << "'s declaration " - << " provided=" << lhs.type() - << ", intrin=" << rhs.type(); + << " provided=" << lhs.dtype() + << ", intrin=" << rhs.dtype(); } CHECK(Equal(lhs, rhs)) << "Failed to match the compute with TensorIntrin " @@ -390,7 +390,7 @@ Stmt MakeTensorize(const ComputeOpNode* self, } input_bind_nest.emplace_back(AttrStmt::make( bind_spec, ir::attr::buffer_bind_scope, - Call::make(Handle(), ir::intrinsic::tvm_tuple, tuple, Call::Intrinsic), nop)); + Call::make(DataType::Handle(), ir::intrinsic::tvm_tuple, tuple, Call::Intrinsic), nop)); } // output binding const ComputeOpNode* intrin_compute = intrin->op.as(); @@ -410,7 +410,7 @@ Stmt MakeTensorize(const ComputeOpNode* self, Array bind_spec{buffer, tensor}; output_bind_nest.emplace_back(AttrStmt::make( bind_spec, ir::attr::buffer_bind_scope, - Call::make(Handle(), ir::intrinsic::tvm_tuple, tuple, Call::Intrinsic), nop)); + Call::make(DataType::Handle(), ir::intrinsic::tvm_tuple, tuple, Call::Intrinsic), nop)); } // Check variable remap std::unordered_map vmap; @@ -430,7 +430,7 @@ Stmt MakeTensorize(const ComputeOpNode* self, IterVar target = intrin_compute->reduce_axis[i - start]; auto it = out_dom.find(iv); CHECK(it != out_dom.end()); - binder.Bind(target->dom->min, make_const(iv->dom->min.type(), 0), + binder.Bind(target->dom->min, make_const(iv->dom->min.dtype(), 0), "tensir_intrin.reduction.min"); binder.Bind(target->dom->extent, it->second->extent, "tensir_intrin.reduction.extent"); diff --git a/src/pass/arg_binder.cc b/src/pass/arg_binder.cc index f892b6b957f8..e4ff9cb457a5 100644 --- a/src/pass/arg_binder.cc +++ b/src/pass/arg_binder.cc @@ -50,7 +50,7 @@ bool ArgBinder::Bind_(const Expr& arg, const Expr& value, const std::string& arg_name, bool with_lets) { - CHECK_EQ(arg.type(), value.type()); + CHECK_EQ(arg.dtype(), value.dtype()); if (const Variable* v = arg.as()) { auto it = def_map_->find(v); if (it == def_map_->end()) { @@ -118,8 +118,8 @@ void ArgBinder::BindBuffer(const Buffer& arg, if (Bind_(arg->elem_offset, value->elem_offset, arg_name + ".elem_offset", false)) { if (arg->offset_factor > 1) { Expr offset = value->elem_offset; - Expr factor = make_const(offset.type(), arg->offset_factor); - Expr zero = make_zero(offset.type()); + Expr factor = make_const(offset.dtype(), arg->offset_factor); + Expr zero = make_zero(offset.dtype()); BinderAddAssert(truncmod(offset, factor) == zero, arg_name + ".elem_offset", &asserts_); } @@ -153,7 +153,7 @@ void ArgBinder::BindBuffer(const Buffer& arg, } } -inline Expr TVMArrayGet(Type t, Var arr, intrinsic::TVMStructFieldKind kind) { +inline Expr TVMArrayGet(DataType t, Var arr, intrinsic::TVMStructFieldKind kind) { return TVMStructGet(t, arr, 0, kind); } @@ -162,8 +162,8 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, const Expr& device_id, const Var& handle, const std::string& arg_name) { - const Type tvm_shape_type = TVMShapeIndexType(); - const Type tvm_ndim_type = Int(32); + const DataType tvm_shape_type = DataType::ShapeIndex(); + const DataType tvm_ndim_type = DataType::Int(32); const Stmt nop = Evaluate::make(0); // dimension checks Expr v_ndim = TVMArrayGet(tvm_ndim_type, handle, intrinsic::kArrNDim); @@ -175,52 +175,52 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, << buffer->shape.size(); asserts_.emplace_back(AssertStmt::make(a_ndim == v_ndim, ndim_err_msg.str(), nop)); // type checks - Type dtype = buffer->dtype; + DataType dtype = buffer->dtype; std::ostringstream type_err_msg; type_err_msg << arg_name << ".dtype is expected to be " << dtype; - Expr cond = (TVMArrayGet(UInt(8), handle, intrinsic::kArrTypeCode) == - UIntImm::make(UInt(8), dtype.code()) && - TVMArrayGet(UInt(8), handle, intrinsic::kArrTypeBits) == - UIntImm::make(UInt(8), dtype.bits()) && - TVMArrayGet(UInt(16), handle, intrinsic::kArrTypeLanes) == - UIntImm::make(UInt(16), dtype.lanes())); + Expr cond = (TVMArrayGet(DataType::UInt(8), handle, intrinsic::kArrTypeCode) == + UIntImm::make(DataType::UInt(8), dtype.code()) && + TVMArrayGet(DataType::UInt(8), handle, intrinsic::kArrTypeBits) == + UIntImm::make(DataType::UInt(8), dtype.bits()) && + TVMArrayGet(DataType::UInt(16), handle, intrinsic::kArrTypeLanes) == + UIntImm::make(DataType::UInt(16), dtype.lanes())); asserts_.emplace_back(AssertStmt::make(cond, type_err_msg.str(), nop)); // data field - if (Bind_(buffer->data, TVMArrayGet(Handle(), handle, intrinsic::kArrData), + if (Bind_(buffer->data, TVMArrayGet(DataType::Handle(), handle, intrinsic::kArrData), arg_name + ".data", true)) { Var vptr(buffer->data); def_handle_dtype_.Set(vptr, ir::TypeAnnotation(buffer->dtype)); // mark alignment of external bufs init_nest_.emplace_back(AttrStmt::make( vptr, ir::attr::storage_alignment, - IntImm::make(Int(32), buffer->data_alignment), nop)); + IntImm::make(DataType::Int(32), buffer->data_alignment), nop)); } - Var v_shape(arg_name + ".shape", Handle()); + Var v_shape(arg_name + ".shape", DataType::Handle()); def_handle_dtype_.Set(v_shape, make_const(tvm_shape_type, 0)); init_nest_.emplace_back(LetStmt::make( - v_shape, TVMArrayGet(Handle(), handle, intrinsic::kArrShape), nop)); + v_shape, TVMArrayGet(DataType::Handle(), handle, intrinsic::kArrShape), nop)); for (size_t k = 0; k < buffer->shape.size(); ++k) { std::ostringstream field_name; field_name << v_shape->name_hint << '[' << k << ']'; Bind_(buffer->shape[k], - cast(buffer->shape[k].type(), + cast(buffer->shape[k].dtype(), Load::make(tvm_shape_type, v_shape, - IntImm::make(Int(32), k), const_true(1))), + IntImm::make(DataType::Int(32), k), const_true(1))), field_name.str(), true); } // strides field - Var v_strides(arg_name + ".strides", Handle()); + Var v_strides(arg_name + ".strides", DataType::Handle()); def_handle_dtype_.Set(v_strides, ir::TypeAnnotation(tvm_shape_type)); init_nest_.emplace_back(LetStmt::make( - v_strides, TVMArrayGet(Handle(), handle, intrinsic::kArrStrides), + v_strides, TVMArrayGet(DataType::Handle(), handle, intrinsic::kArrStrides), nop)); Expr is_null = Call::make( - Bool(1), intrinsic::tvm_handle_is_null, + DataType::Bool(1), intrinsic::tvm_handle_is_null, {v_strides}, Call::PureIntrinsic); if (buffer->strides.size() == 0) { // Assert the buffer is compact - Type stype = buffer->DefaultIndexType(); + DataType stype = buffer->DefaultIndexType(); Expr expect_stride = make_const(stype, 1); Array conds; for (size_t i = buffer->shape.size(); i != 0; --i) { @@ -228,7 +228,7 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, Expr svalue = cast( stype, Load::make(tvm_shape_type, v_strides, - IntImm::make(Int(32), k), const_true(1))); + IntImm::make(DataType::Int(32), k), const_true(1))); conds.push_back(expect_stride == svalue); expect_stride = expect_stride * buffer->shape[k]; } @@ -243,15 +243,15 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, asserts_.emplace_back(Block::make(check, Evaluate::make(0))); } } else if (buffer->buffer_type == kAutoBroadcast) { - Type stype = buffer->DefaultIndexType(); + DataType stype = buffer->DefaultIndexType(); Expr stride = make_const(stype, 1); for (size_t i = buffer->shape.size(); i != 0; --i) { size_t k = i - 1; std::ostringstream field_name; field_name << v_strides->name_hint << '[' << k << ']'; - Expr value = cast(buffer->shape[k].type(), + Expr value = cast(buffer->shape[k].dtype(), Load::make(tvm_shape_type, v_strides, - IntImm::make(Int(32), k), const_true(1))); + IntImm::make(DataType::Int(32), k), const_true(1))); value = tvm::if_then_else(is_null, stride, value); value = tvm::if_then_else(buffer->shape[k] == 1, 0, value); Bind_(buffer->strides[k], value, field_name.str(), true); @@ -266,9 +266,9 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, std::ostringstream field_name; field_name << v_strides->name_hint << '[' << k << ']'; Bind_(buffer->strides[k], - cast(buffer->shape[k].type(), + cast(buffer->shape[k].dtype(), Load::make(tvm_shape_type, v_strides, - IntImm::make(Int(32), k), const_true(1))), + IntImm::make(DataType::Int(32), k), const_true(1))), field_name.str(), true); } } @@ -276,29 +276,29 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, int data_bytes = GetVectorBytes(buffer->dtype); int64_t const_offset; if (arith::GetConst(buffer->elem_offset, &const_offset)) { - Bind_(make_const(UInt(64), const_offset * data_bytes), - TVMArrayGet(UInt(64), handle, intrinsic::kArrByteOffset), + Bind_(make_const(DataType::UInt(64), const_offset * data_bytes), + TVMArrayGet(DataType::UInt(64), handle, intrinsic::kArrByteOffset), arg_name + ".byte_offset", true); } else { if (Bind_(buffer->elem_offset, - cast(buffer->elem_offset.type(), - (TVMArrayGet(UInt(64), handle, intrinsic::kArrByteOffset) / - make_const(UInt(64), data_bytes))), + cast(buffer->elem_offset.dtype(), + (TVMArrayGet(DataType::UInt(64), handle, intrinsic::kArrByteOffset) / + make_const(DataType::UInt(64), data_bytes))), arg_name + ".elem_offset", true)) { if (buffer->offset_factor > 1) { Expr offset = buffer->elem_offset; - Expr factor = make_const(offset.type(), buffer->offset_factor); - Expr zero = make_zero(offset.type()); + Expr factor = make_const(offset.dtype(), buffer->offset_factor); + Expr zero = make_zero(offset.dtype()); BinderAddAssert(truncmod(offset, factor) == zero, arg_name + ".elem_offset", &asserts_); } } } // device info. Bind_(device_type, - TVMArrayGet(Int(32), handle, intrinsic::kArrDeviceType), + TVMArrayGet(DataType::Int(32), handle, intrinsic::kArrDeviceType), arg_name + ".device_type", true); Bind_(device_id, - TVMArrayGet(Int(32), handle, intrinsic::kArrDeviceId), + TVMArrayGet(DataType::Int(32), handle, intrinsic::kArrDeviceId), arg_name + ".device_id", true); } diff --git a/src/pass/bound_checker.cc b/src/pass/bound_checker.cc index 55f98474994a..648302e9740a 100644 --- a/src/pass/bound_checker.cc +++ b/src/pass/bound_checker.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -58,7 +58,7 @@ class BoundChecker : public IRMutator { Stmt Mutate_(const Allocate *op, const Stmt &s) final { // If the shape was updated we should update the hashtable. if (UpdateIsNeeded(op->buffer_var)) { - Update(op->buffer_var, op->extents, op->type); + Update(op->buffer_var, op->extents, op->dtype); } return IRMutator::Mutate_(op, s); } @@ -108,26 +108,26 @@ class BoundChecker : public IRMutator { } void Update(const VarExpr &buffer_var, const Array &new_shape, - const Type &type) { + const DataType &type) { // Sanity check at first. if (!new_shape.size()) { return; } for (size_t i = 0; i < new_shape.size(); ++i) { - if (!new_shape[0].defined() || !new_shape[i].type().is_scalar() || + if (!new_shape[0].defined() || !new_shape[i].dtype().is_scalar() || is_negative_const(new_shape[i])) { return; } } // Scalarize the shape. - Expr shape = Mul::make(make_const(UInt(64), type.lanes()), - Cast::make(UInt(64), new_shape[0])); + Expr shape = Mul::make(make_const(DataType::UInt(64), type.lanes()), + Cast::make(DataType::UInt(64), new_shape[0])); for (size_t i = 1; i < new_shape.size(); ++i) { // Cast to unsigned to avoid integer overlow at frist. - shape = Mul::make(shape, Mul::make(make_const(UInt(64), type.lanes()), - Cast::make(UInt(64), new_shape[i]))); + shape = Mul::make(shape, Mul::make(make_const(DataType::UInt(64), type.lanes()), + Cast::make(DataType::UInt(64), new_shape[i]))); } mem_to_shape_[buffer_var.get()] = shape; } @@ -139,9 +139,9 @@ class BoundChecker : public IRMutator { if (const Ramp *ramp_index = index.as()) { return ramp_index->base.defined() && - ramp_index->base.type().is_scalar() && + ramp_index->base.dtype().is_scalar() && ramp_index->stride.defined() && - ramp_index->stride.type().is_scalar() && (ramp_index->lanes > 0); + ramp_index->stride.dtype().is_scalar() && (ramp_index->lanes > 0); } return true; } @@ -168,7 +168,7 @@ class BoundChecker : public IRMutator { // Non inclusive range. index = Add::make( ramp_index->base, - Mul::make(ramp_index->stride, make_const(ramp_index->stride.type(), + Mul::make(ramp_index->stride, make_const(ramp_index->stride.dtype(), ramp_index->lanes - 1))); } @@ -177,11 +177,11 @@ class BoundChecker : public IRMutator { upper_bound = ir::Simplify(upper_bound); // Cast to the same type - signed, to be able to check lower bound. - index = Cast::make(Int(64), index); - upper_bound = Cast::make(Int(64), upper_bound); + index = Cast::make(DataType::Int(64), index); + upper_bound = Cast::make(DataType::Int(64), upper_bound); // Looks like a lower bound should always be zero after normalization. - Expr lower_bound = make_zero(Int(64)); + Expr lower_bound = make_zero(DataType::Int(64)); Expr current_condition = And::make(GE::make(index, lower_bound), LT::make(index, upper_bound)); diff --git a/src/pass/combine_context_call.cc b/src/pass/combine_context_call.cc index d7fb77961b4b..f1cb8fe10a4b 100644 --- a/src/pass/combine_context_call.cc +++ b/src/pass/combine_context_call.cc @@ -48,14 +48,14 @@ class ContextCallCombiner final : public IRMutator { if (it != ctx_map_.end()) { return it->second; } else { - CHECK(ctx.type().is_handle()); + CHECK(ctx.dtype().is_handle()); std::string name; if (const Call* call = ctx.as()) { name = call->name + "_cache"; } else { name = "ctx_cache_"; } - Var ctx_var(name, ctx.type()); + Var ctx_var(name, ctx.dtype()); ctx_map_[ctx] = ctx_var; return std::move(ctx_var); } diff --git a/src/pass/coproc_sync.cc b/src/pass/coproc_sync.cc index 3dacb6d5bff7..4aa8879f679b 100644 --- a/src/pass/coproc_sync.cc +++ b/src/pass/coproc_sync.cc @@ -198,7 +198,7 @@ class CoProcSyncPlanner : public StorageAccessVisitor { std::vector GetSync(std::string sync_name) { return {Evaluate::make(Call::make( - Int(32), + DataType::Int(32), sync_name, {}, Call::Intrinsic))}; } @@ -345,7 +345,7 @@ class CoProcBarrierDetector : public StorageAccessVisitor { Expr min = r->min; Expr extent = r->extent; return Evaluate::make(Call::make( - Int(32), func, + DataType::Int(32), func, {wvec[0].buffer, wvec[0].dtype.bits(), r->min, r->extent}, Call::Intrinsic)); } // Write barrier name @@ -588,14 +588,14 @@ class CoProcInstDepDetector : public IRVisitor { Stmt MakePush(int from, int to) { return Evaluate::make(Call::make( - Int(32), sync_push_name_, - {make_const(Int(32), from), make_const(Int(32), to)}, + DataType::Int(32), sync_push_name_, + {make_const(DataType::Int(32), from), make_const(DataType::Int(32), to)}, Call::Intrinsic)); } Stmt MakePop(int from, int to) { return Evaluate::make(Call::make( - Int(32), sync_pop_name_, - {make_const(Int(32), from), make_const(Int(32), to)}, + DataType::Int(32), sync_pop_name_, + {make_const(DataType::Int(32), from), make_const(DataType::Int(32), to)}, Call::Intrinsic)); } // sync states. diff --git a/src/pass/detect_device.cc b/src/pass/detect_device.cc index 92e368b62d20..cd7c979171a6 100644 --- a/src/pass/detect_device.cc +++ b/src/pass/detect_device.cc @@ -28,7 +28,7 @@ namespace tvm { namespace ir { Stmt DecorateDeviceScope(Stmt stmt) { - Stmt body = AttrStmt::make(make_zero(Int(32)), + Stmt body = AttrStmt::make(make_zero(DataType::Int(32)), ir::attr::device_scope, 0, stmt); diff --git a/src/pass/inject_copy_intrin.cc b/src/pass/inject_copy_intrin.cc index 3b148361fbfc..7b7c5df48236 100644 --- a/src/pass/inject_copy_intrin.cc +++ b/src/pass/inject_copy_intrin.cc @@ -88,7 +88,7 @@ class CopyIntrinInjector : public IRMutator { load = cast->value.as(); } if (load == nullptr) return false; - if (load->type.lanes() != 1) return false; + if (load->dtype.lanes() != 1) return false; Array loop_vars; for (const For* op : loops) { loop_vars.push_back(op->loop_var); @@ -101,7 +101,7 @@ class CopyIntrinInjector : public IRMutator { Array dst_shape; const size_t loop_var_size = loop_vars.size(); if (loop_var_size == 0) { - dst_shape.push_back(make_const(Int(32), 1)); + dst_shape.push_back(make_const(DataType::Int(32), 1)); } else { for (const For* op : loops) { dst_shape.push_back(op->extent); @@ -121,7 +121,7 @@ class CopyIntrinInjector : public IRMutator { for (size_t i = 0; i < src_shape.size(); ++i) { Expr min_value = clip_bound[2 * i]; Expr max_value = clip_bound[2 * i + 1]; - Type t = loop_vars[i].type(); + DataType t = loop_vars[i].dtype(); Expr svalue = src_shape[i]; if (min_value.defined()) { Expr pbefore = Simplify(Max::make(min_value, make_zero(t))); @@ -148,12 +148,12 @@ class CopyIntrinInjector : public IRMutator { Array src_strides(load_strides.begin(), load_strides.begin() + loop_var_size); Array dst_strides(store_strides.begin(), store_strides.begin() + loop_var_size); if (loop_var_size == 0) { - src_strides.push_back(make_const(Int(32), 1)); - dst_strides.push_back(make_const(Int(32), 1)); + src_strides.push_back(make_const(DataType::Int(32), 1)); + dst_strides.push_back(make_const(DataType::Int(32), 1)); } Buffer dst = BufferNode::make( store->buffer_var, - store->value.type(), + store->value.dtype(), dst_shape, dst_strides, store_strides[loop_var_size], @@ -162,7 +162,7 @@ class CopyIntrinInjector : public IRMutator { 0, 0, kDefault); Buffer src = BufferNode::make( load->buffer_var, - load->type, + load->dtype, src_shape, src_strides, src_elem_offset, diff --git a/src/pass/inject_double_buffer.cc b/src/pass/inject_double_buffer.cc index 065bbd4e4db3..78d3305d3e17 100644 --- a/src/pass/inject_double_buffer.cc +++ b/src/pass/inject_double_buffer.cc @@ -100,10 +100,10 @@ class DoubleBufferInjector : public IRMutator { auto it = dbuffer_info_.find(op->buffer_var.get()); if (it != dbuffer_info_.end()) { it->second.stride = arith::ComputeReduce( - op->extents, Expr()) * op->type.lanes(); + op->extents, Expr()) * op->dtype.lanes(); Stmt stmt = IRMutator::Mutate_(op, s); op = stmt.as(); - Array new_extents{make_const(op->extents[0].type(), 2)}; + Array new_extents{make_const(op->extents[0].dtype(), 2)}; for (Expr e : op->extents) { new_extents.push_back(e); } @@ -114,7 +114,7 @@ class DoubleBufferInjector : public IRMutator { StringImm::make(it->second.scope), Evaluate::make(0))); alloc_nest.emplace_back(Allocate::make( - op->buffer_var, op->type, new_extents, op->condition, + op->buffer_var, op->dtype, new_extents, op->condition, Evaluate::make(0))); return op->body; } else { @@ -135,15 +135,15 @@ class DoubleBufferInjector : public IRMutator { CHECK(is_zero(old_loop->min)); Expr zero = old_loop->min; Expr new_ext = - old_loop->extent - make_const(old_loop->loop_var.type(), 1); - Expr factor = make_const(new_ext.type(), split_loop_); + old_loop->extent - make_const(old_loop->loop_var.dtype(), 1); + Expr factor = make_const(new_ext.dtype(), split_loop_); Expr outer_ext = new_ext / factor; Expr tail_base = outer_ext * factor; - Var outer_var(old_loop->loop_var->name_hint + ".outer", old_loop->loop_var.type()); + Var outer_var(old_loop->loop_var->name_hint + ".outer", old_loop->loop_var.dtype()); std::unordered_map vmap; std::vector loop_seq; for (int32_t i = 0; i < split_loop_; ++i) { - vmap[old_loop->loop_var.get()] = outer_var * factor + make_const(factor.type(), i); + vmap[old_loop->loop_var.get()] = outer_var * factor + make_const(factor.dtype(), i); loop_seq.emplace_back(Substitute(old_loop->body, vmap)); } Stmt loop = For::make( @@ -153,7 +153,7 @@ class DoubleBufferInjector : public IRMutator { std::vector tail_seq; Stmt tail_body = StripDoubleBufferWrite().Mutate(old_loop->body); for (int32_t i = 0; i < split_loop_; ++i) { - Expr idx = tail_base + make_const(tail_base.type(), i); + Expr idx = tail_base + make_const(tail_base.dtype(), i); vmap[old_loop->loop_var.get()] = idx; tail_seq.emplace_back( IfThenElse::make(idx < old_loop->extent, @@ -196,7 +196,7 @@ class DoubleBufferInjector : public IRMutator { const StorageEntry& e = it->second; CHECK(e.stride.defined()); CHECK(e.switch_read_var.defined()); - return Load::make(op->type, + return Load::make(op->dtype, op->buffer_var, e.switch_read_var * e.stride + op->index, op->predicate); @@ -222,12 +222,12 @@ class DoubleBufferInjector : public IRMutator { } StorageEntry& e = it->second; e.loop = loop_nest_.back(); - Expr zero = make_const(e.loop->loop_var.type(), 0); - Expr one = make_const(e.loop->loop_var.type(), 1); - Expr two = make_const(e.loop->loop_var.type(), 2); + Expr zero = make_const(e.loop->loop_var.dtype(), 0); + Expr one = make_const(e.loop->loop_var.dtype(), 1); + Expr two = make_const(e.loop->loop_var.dtype(), 2); Expr loop_shift = e.loop->loop_var + one; e.switch_write_var = Var(e.loop->loop_var->name_hint + ".db", - e.loop->loop_var.type()); + e.loop->loop_var.dtype()); e.switch_read_var = indexmod(e.loop->loop_var, two); in_double_buffer_scope_ = true; Stmt body = Mutate(op->body); diff --git a/src/pass/inject_virtual_thread.cc b/src/pass/inject_virtual_thread.cc index eafe5a928cd7..c80c7fcdaa8c 100644 --- a/src/pass/inject_virtual_thread.cc +++ b/src/pass/inject_virtual_thread.cc @@ -222,7 +222,7 @@ class VTInjector : public IRMutator { } auto it = alloc_remap_.find(op->buffer_var.get()); if (it != alloc_remap_.end()) { - return Load::make(op->type, op->buffer_var, + return Load::make(op->dtype, op->buffer_var, RewriteIndex(op->index, it->second), op->predicate); } else { @@ -233,7 +233,7 @@ class VTInjector : public IRMutator { Expr Mutate_(const Call* op, const Expr& e) final { if (op->is_intrinsic(intrinsic::tvm_access_ptr)) { CHECK_EQ(op->args.size(), 5U); - Type dtype = op->args[0].type(); + DataType dtype = op->args[0].dtype(); const Variable* buffer = op->args[1].as(); auto it = alloc_remap_.find(buffer); if (it == alloc_remap_.end()) return IRMutator::Mutate_(op, e); @@ -241,10 +241,10 @@ class VTInjector : public IRMutator { Expr offset = Mutate(op->args[2]); Expr extent = Mutate(op->args[3]); Expr stride = - it->second / make_const(offset.type(), dtype.lanes()); + it->second / make_const(offset.dtype(), dtype.lanes()); offset = stride * var_ + offset; return Call::make( - op->type, op->name, + op->dtype, op->name, {op->args[0], op->args[1], offset, extent, op->args[4]}, op->call_type); } else if (op->is_intrinsic(intrinsic::tvm_context_id)) { @@ -395,9 +395,9 @@ class VTInjector : public IRMutator { if (touched_var_.count(op->buffer_var.get()) || !allow_share_) { // place v on highest dimension. Expr stride = arith::ComputeReduce( - op->extents, Expr()) * op->type.lanes(); + op->extents, Expr()) * op->dtype.lanes(); Array other; - other.push_back(make_const(op->extents[0].type(), num_threads_)); + other.push_back(make_const(op->extents[0].dtype(), num_threads_)); for (Expr e : extents) { other.push_back(e); } @@ -417,7 +417,7 @@ class VTInjector : public IRMutator { return s; } else { return Allocate::make( - op->buffer_var, op->type, + op->buffer_var, op->dtype, extents, condition, body, op->new_expr, op->free_function); } @@ -439,19 +439,19 @@ class VTInjector : public IRMutator { // only unroll if number of vthreads are small if (max_loop_depth_ == 0 && num_threads_ < 16) { // do unrolling if it is inside innermost content. - Stmt blk = Substitute(stmt, {{var_, make_zero(var_.type())}}); + Stmt blk = Substitute(stmt, {{var_, make_zero(var_.dtype())}}); for (int i = 1; i < num_threads_; ++i) { blk = Block::make( - blk, Substitute(stmt, {{var_, make_const(var_.type(), i)}})); + blk, Substitute(stmt, {{var_, make_const(var_.dtype(), i)}})); } return blk; } else { // insert a for loop - Var idx(var_->name_hint + ".s", var_->type); + Var idx(var_->name_hint + ".s", var_->dtype); Map values{{var_, idx}}; stmt = Substitute(stmt, values); - return For::make(idx, make_zero(idx.type()), - make_const(idx.type(), num_threads_), + return For::make(idx, make_zero(idx.dtype()), + make_const(idx.dtype(), num_threads_), ForType::Serial, DeviceAPI::None, stmt); } } diff --git a/src/pass/ir_deep_compare.cc b/src/pass/ir_deep_compare.cc index cb859d07f07b..e399e7f2c54f 100644 --- a/src/pass/ir_deep_compare.cc +++ b/src/pass/ir_deep_compare.cc @@ -63,7 +63,7 @@ class IRDeepCompare : if (order_ != 0) return; if (n.same_as(other)) return; if (CompareValue(n->type_index(), other->type_index()) != 0) return; - if (CompareType(n.type(), other.type()) != 0) return; + if (CompareType(n.dtype(), other.dtype()) != 0) return; ExprComparator::VisitExpr(n, other); } @@ -119,7 +119,7 @@ class IRDeepCompare : } else { if (CompareExpr(op->buffer_var, rhs->buffer_var) != 0) return; } - if (CompareType(op->type, rhs->type) != 0) return; + if (CompareType(op->dtype, rhs->dtype) != 0) return; if (CompareArray(op->extents, rhs->extents) != 0) return; if (CompareExpr(op->condition, rhs->condition) != 0) return; if (CompareStmt(op->body, rhs->body) != 0) return; @@ -166,7 +166,7 @@ class IRDeepCompare : const Realize* rhs = other.as(); if (CompareNodeRef(op->func, rhs->func) != 0) return; if (CompareValue(op->value_index, rhs->value_index) != 0) return; - if (CompareType(op->type, rhs->type) != 0) return; + if (CompareType(op->dtype, rhs->dtype) != 0) return; if (CompareRegion(op->bounds, rhs->bounds) != 0) return; if (CompareStmt(op->body, rhs->body) != 0) return; } @@ -175,7 +175,7 @@ class IRDeepCompare : const Prefetch* rhs = other.as(); if (CompareNodeRef(op->func, rhs->func) != 0) return; if (CompareValue(op->value_index, rhs->value_index) != 0) return; - if (CompareType(op->type, rhs->type) != 0) return; + if (CompareType(op->dtype, rhs->dtype) != 0) return; if (CompareRegion(op->bounds, rhs->bounds) != 0) return; } @@ -369,7 +369,7 @@ class IRDeepCompare : return order_; } - int CompareType(const Type& lhs, const Type& rhs) { + int CompareType(const DataType& lhs, const DataType& rhs) { if (order_ != 0) return order_; if (lhs == rhs) return order_; if (CompareValue(lhs.code(), rhs.code()) != 0) return order_; diff --git a/src/pass/ir_mutator.cc b/src/pass/ir_mutator.cc index f79a1ab8fe3b..139c467155ba 100644 --- a/src/pass/ir_mutator.cc +++ b/src/pass/ir_mutator.cc @@ -179,7 +179,7 @@ Stmt IRMutator::Mutate_(const Allocate* op, const Stmt& s) { return s; } else { return Allocate::make( - op->buffer_var, op->type, + op->buffer_var, op->dtype, new_extents, condition, body, new_expr, op->free_function); } @@ -247,7 +247,7 @@ Stmt IRMutator::Mutate_(const Realize* op, const Stmt& s) { return s; } else { return Realize::make(op->func, op->value_index, - op->type, new_bounds, + op->dtype, new_bounds, condition, body); } } @@ -273,7 +273,7 @@ Stmt IRMutator::Mutate_(const Prefetch* op, const Stmt& s) { return s; } else { return Prefetch::make(op->func, op->value_index, - op->type, new_bounds); + op->dtype, new_bounds); } } @@ -358,7 +358,7 @@ Expr IRMutator::Mutate_(const Load *op, const Expr& e) { if (index.same_as(op->index) && pred.same_as(op->predicate)) { return e; } else { - return Load::make(op->type, op->buffer_var, index, pred); + return Load::make(op->dtype, op->buffer_var, index, pred); } } @@ -378,7 +378,7 @@ Expr IRMutator::Mutate_(const Call* op, const Expr& e) { if (op->args.same_as(new_args)) { return e; } else { - return Call::make(op->type, op->name, new_args, op->call_type, + return Call::make(op->dtype, op->name, new_args, op->call_type, op->func, op->value_index); } } @@ -432,7 +432,7 @@ Expr IRMutator::Mutate_(const Cast *op, const Expr& e) { if (value.same_as(op->value)) { return e; } else { - return Cast::make(op->type, value); + return Cast::make(op->dtype, value); } } diff --git a/src/pass/ir_util.h b/src/pass/ir_util.h index 690feca135ef..0f8bb990c2d3 100644 --- a/src/pass/ir_util.h +++ b/src/pass/ir_util.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -89,12 +89,12 @@ inline Array UpdateArray(Array arr, F fupdate) { * \return the get expression. */ inline Expr TVMStructGet( - Type dtype, Var handle, int index, + DataType dtype, Var handle, int index, intrinsic::TVMStructFieldKind kind) { Array args ={ handle, - make_const(Int(32), index), - make_const(Int(32), static_cast(kind))}; + make_const(DataType::Int(32), index), + make_const(DataType::Int(32), static_cast(kind))}; return Call::make(dtype, intrinsic::tvm_struct_get, args, Call::PureIntrinsic); } @@ -104,10 +104,10 @@ inline Expr TVMStructGet( * \param dtype The data type. * \param offset the offset index. */ -inline Expr AddressOffset(Var handle, Type dtype, int offset) { +inline Expr AddressOffset(Var handle, DataType dtype, int offset) { return Call::make( - Handle(), intrinsic::tvm_address_of, - {Load::make(dtype, handle, make_const(Int(32), offset * dtype.lanes()), + DataType::Handle(), intrinsic::tvm_address_of, + {Load::make(dtype, handle, make_const(DataType::Int(32), offset * dtype.lanes()), const_true(dtype.lanes()))}, Call::PureIntrinsic); } @@ -118,13 +118,13 @@ inline Expr AddressOffset(Var handle, Type dtype, int offset) { * \param dtype The data type. * \param offset the offset index. */ -inline Expr AddressOffset(Var handle, Type dtype, Expr offset) { +inline Expr AddressOffset(Var handle, DataType dtype, Expr offset) { if (dtype.lanes() != 1) { - offset = offset * make_const(offset.type(), dtype.lanes()); - offset = Ramp::make(offset, make_const(offset.type(), 1), dtype.lanes()); + offset = offset * make_const(offset.dtype(), dtype.lanes()); + offset = Ramp::make(offset, make_const(offset.dtype(), 1), dtype.lanes()); } return Call::make( - Handle(), intrinsic::tvm_address_of, + DataType::Handle(), intrinsic::tvm_address_of, {Load::make(dtype, handle, offset, const_true(dtype.lanes()))}, Call::PureIntrinsic); @@ -143,11 +143,11 @@ inline Stmt TVMStructSet( intrinsic::TVMStructFieldKind kind, Expr value) { Array args ={ handle, - make_const(Int(32), index), - make_const(Int(32), static_cast(kind)), + make_const(DataType::Int(32), index), + make_const(DataType::Int(32), static_cast(kind)), value}; return Evaluate::make( - Call::make(Int(32), intrinsic::tvm_struct_set, args, Call::Intrinsic)); + Call::make(DataType::Int(32), intrinsic::tvm_struct_set, args, Call::Intrinsic)); } /*! @@ -155,13 +155,13 @@ inline Stmt TVMStructSet( * \param t The original type. * \return The corresponding API type. */ -inline Type APIType(Type t) { +inline DataType APIType(DataType t) { if (t.is_handle()) return t; CHECK_EQ(t.lanes(), 1) << "Cannot pass vector type through packed API."; - if (t.is_uint() || t.is_int()) return Int(64); + if (t.is_uint() || t.is_int()) return DataType::Int(64); CHECK(t.is_float()); - return Float(64); + return DataType::Float(64); } /*! @@ -170,7 +170,7 @@ inline Type APIType(Type t) { * \param const_size The constant size of the array. * \return the alignment */ -inline int GetTempAllocaAlignment(Type type, int32_t const_size) { +inline int GetTempAllocaAlignment(DataType type, int32_t const_size) { int align = runtime::kTempAllocaAlignment; if (const_size > 0) { int64_t const_s = static_cast(const_size) * type.bits() * type.lanes() / 8; diff --git a/src/pass/lift_attr_scope.cc b/src/pass/lift_attr_scope.cc index adcaaebd6d6e..cfc6e5a7fc68 100644 --- a/src/pass/lift_attr_scope.cc +++ b/src/pass/lift_attr_scope.cc @@ -57,7 +57,7 @@ class AttrScopeLifter : public IRMutator { attr_node_ = NodeRef(); attr_value_ = Expr(); return Allocate::make( - op->buffer_var, op->type, + op->buffer_var, op->dtype, op->extents, op->condition, body, op->new_expr, op->free_function); } else { @@ -198,7 +198,7 @@ class AttrScopeLifter : public IRMutator { static bool ValueSame(const Expr& a, const Expr& b) { if (a.same_as(b)) return true; if (a->type_index() != b->type_index()) return false; - if (a.type() != b.type()) return false; + if (a.dtype() != b.dtype()) return false; if (const IntImm* op = a.as()) { return op->value == b.as()->value; } diff --git a/src/pass/loop_partition.cc b/src/pass/loop_partition.cc index ef5cc9c4fa9f..1ac386767ae3 100644 --- a/src/pass/loop_partition.cc +++ b/src/pass/loop_partition.cc @@ -181,7 +181,7 @@ class PartitionFinder : public IRVisitor { const IterVarNode* thread_axis = op->node.as(); CHECK(thread_axis); const Variable* var = thread_axis->var.get(); - IntSet dom = IntSet::range(Range(make_zero(op->value.type()), op->value)); + IntSet dom = IntSet::range(Range(make_zero(op->value.dtype()), op->value)); hint_map_.insert({var, dom}); relax_map_.insert({var, dom}); IRVisitor::Visit_(op); @@ -351,12 +351,12 @@ class LoopPartitioner : public IRMutator { if (scope.rank == 1) { // threadIdx should be put into relax map, in case of divergence. relax_map_.insert({var.get(), - IntSet::interval(make_zero(var.type()), op->value - 1)}); + IntSet::interval(make_zero(var.dtype()), op->value - 1)}); res = IRMutator::Mutate_(op, stmt); relax_map_.erase(var.get()); } else { hint_map_.insert({var.get(), - IntSet::interval(make_zero(var.type()), op->value - 1)}); + IntSet::interval(make_zero(var.dtype()), op->value - 1)}); res = IRMutator::Mutate_(op, stmt); hint_map_.erase(var.get()); } @@ -595,9 +595,9 @@ Stmt LoopPartitioner::TryPartition(const Node* node, inline Stmt LoopPartitioner::MakeFor(const Node *node, Expr extent, Stmt body) { const For *for_node = static_cast(node); CHECK(for_node); - if (analyzer_.CanProve(extent == make_const(Int(32), 1))) { + if (analyzer_.CanProve(extent == make_const(DataType::Int(32), 1))) { // If the loop extent is 1, do not create the loop anymore - return Substitute(body, {{Var{for_node->loop_var}, make_const(Int(32), 0)}}); + return Substitute(body, {{Var{for_node->loop_var}, make_const(DataType::Int(32), 0)}}); } else { return For::make(for_node->loop_var, 0, extent, for_node->for_type, for_node->device_api, body); diff --git a/src/pass/lower_custom_datatypes.cc b/src/pass/lower_custom_datatypes.cc index 3e71868ce3bc..e24cddd97f25 100644 --- a/src/pass/lower_custom_datatypes.cc +++ b/src/pass/lower_custom_datatypes.cc @@ -42,8 +42,8 @@ class CustomDatatypesLowerer : public IRMutator { explicit CustomDatatypesLowerer(const std::string& target) : target_(target) {} inline Expr Mutate_(const Cast* op, const Expr& e) final { - auto type_code = op->type.code(); - auto src_type_code = op->value.type().code(); + auto type_code = op->dtype.code(); + auto src_type_code = op->value.dtype().code(); // If either datatype is a registered custom datatype, we must lower. bool toBeLowered = datatype::Registry::Global()->GetTypeRegistered(type_code) || datatype::Registry::Global()->GetTypeRegistered(src_type_code); @@ -60,7 +60,7 @@ class CustomDatatypesLowerer : public IRMutator { } inline Expr Mutate_(const FloatImm* imm, const Expr& e) final { - auto type_code = imm->type.code(); + auto type_code = imm->dtype.code(); if (datatype::Registry::Global()->GetTypeRegistered(type_code)) { auto lower = datatype::GetFloatImmLowerFunc(target_, type_code); CHECK(lower) << "FloatImm lowering function for target " << target_ << " type " @@ -71,12 +71,12 @@ class CustomDatatypesLowerer : public IRMutator { } inline Stmt Mutate_(const Allocate* allocate, const Stmt& s) final { - bool toBeLowered = datatype::Registry::Global()->GetTypeRegistered(allocate->type.code()); + bool toBeLowered = datatype::Registry::Global()->GetTypeRegistered(allocate->dtype.code()); Stmt stmt = IRMutator::Mutate_(allocate, s); allocate = stmt.as(); if (toBeLowered) { - auto new_allocate_type = UInt(allocate->type.bits(), allocate->type.lanes()); + auto new_allocate_type = DataType::UInt(allocate->dtype.bits(), allocate->dtype.lanes()); return Allocate::make(allocate->buffer_var, new_allocate_type, allocate->extents, allocate->condition, allocate->body, allocate->new_expr, allocate->free_function); @@ -85,11 +85,11 @@ class CustomDatatypesLowerer : public IRMutator { } inline Expr Mutate_(const Load* load, const Expr& e) final { - bool toBeLowered = datatype::Registry::Global()->GetTypeRegistered(load->type.code()); + bool toBeLowered = datatype::Registry::Global()->GetTypeRegistered(load->dtype.code()); Expr expr = IRMutator::Mutate_(load, e); load = expr.as(); if (toBeLowered) { - auto new_load_type = UInt(load->type.bits()); + auto new_load_type = DataType::UInt(load->dtype.bits()); return Load::make(new_load_type, load->buffer_var, load->index, load->predicate); } return expr; @@ -97,7 +97,7 @@ class CustomDatatypesLowerer : public IRMutator { #define DEFINE_MUTATE__(OP) \ inline Expr Mutate_(const OP* op, const Expr& e) final { \ - auto type_code = op->type.code(); \ + auto type_code = op->dtype.code(); \ bool toBeLowered = datatype::Registry::Global()->GetTypeRegistered(type_code); \ Expr expr = IRMutator::Mutate_(op, e); \ op = expr.as(); \ diff --git a/src/pass/lower_intrin.cc b/src/pass/lower_intrin.cc index c2a2fe6f5942..f0b0b3c36d42 100644 --- a/src/pass/lower_intrin.cc +++ b/src/pass/lower_intrin.cc @@ -76,7 +76,7 @@ class IntrinInjecter : public arith::IRMutatorWithAnalyzer { op = ret.as(); if (op == nullptr) return ret; int shift; - const DataType& dtype = op->type; + const DataType& dtype = op->dtype; CHECK(dtype.is_int() || dtype.is_uint()); if (support_bitwise_op_ && @@ -97,7 +97,7 @@ class IntrinInjecter : public arith::IRMutatorWithAnalyzer { // condition on b >= 0. // truncmod(a, b) < 0 will implies ceildiv, // So we need to correct these cases. - if ((dtype == Int(32) || dtype == Int(64)) && support_bitwise_op_) { + if ((dtype == DataType::Int(32) || dtype == DataType::Int(64)) && support_bitwise_op_) { // equivalent to rdiv + (rmod >= 0 ? 0: -1); return rdiv + (rmod >> make_const(dtype, dtype.bits() - 1)); } else { @@ -123,7 +123,7 @@ class IntrinInjecter : public arith::IRMutatorWithAnalyzer { if (op == nullptr) return ret; // Lower floordiv to native truncdiv. int shift; - const DataType& dtype = op->type; + const DataType& dtype = op->dtype; CHECK(dtype.is_int() || dtype.is_uint()); if (support_bitwise_op_ && @@ -144,7 +144,7 @@ class IntrinInjecter : public arith::IRMutatorWithAnalyzer { // mod(a, b) < 0 will imply we are doing ceildiv, // So we need to correct these cases. Expr rmod = truncmod(op->a, op->b); - if ((dtype == Int(32) || dtype == Int(64)) && support_bitwise_op_) { + if ((dtype == DataType::Int(32) || dtype == DataType::Int(64)) && support_bitwise_op_) { // (rmod >> shift) & b // -> (rmod >= 0 ? 0: -1) & b // -> rmod >= 0 ? 0 : b @@ -207,23 +207,23 @@ class IntrinInjecter : public arith::IRMutatorWithAnalyzer { if (const Cast* cast = bcast->value.as()) { auto should_swap = [&]() { // Maintain behaviour (int8 -> int16, fp16 -> fp32). - if (cast->type.bits() == cast->value.type().bits() * 2) { + if (cast->dtype.bits() == cast->value.dtype().bits() * 2) { return true; } // Check both operands are integer-like. - if (!cast->type.is_uint() && !cast->type.is_int()) { + if (!cast->dtype.is_uint() && !cast->dtype.is_int()) { return false; } - if (!cast->value.type().is_uint() && !cast->value.type().is_int()) { + if (!cast->value.dtype().is_uint() && !cast->value.dtype().is_int()) { return false; } // If both are integer-like, swap if we have a widening cast. - return cast->type.bits() > cast->value.type().bits(); + return cast->dtype.bits() > cast->value.dtype().bits(); }; if (should_swap()) { Expr new_bcast = Broadcast::make(cast->value, bcast->lanes); - return Cast::make(bcast->type, new_bcast); + return Cast::make(bcast->dtype, new_bcast); } } } @@ -236,9 +236,9 @@ class IntrinInjecter : public arith::IRMutatorWithAnalyzer { Expr lhs = SwapBroadcastCast(a); Expr rhs = SwapBroadcastCast(b); - if (fma_ != nullptr && op->type.is_float()) { + if (fma_ != nullptr && op->dtype.is_float()) { Expr r = (*fma_)(Call::make( - op->type, "fma", {lhs, rhs, c}, Call::PureIntrinsic)); + op->dtype, "fma", {lhs, rhs, c}, Call::PureIntrinsic)); if (r.defined()) return this->Mutate(r); } else { if (!lhs.same_as(a) || !rhs.same_as(b)) { diff --git a/src/pass/lower_thread_allreduce.cc b/src/pass/lower_thread_allreduce.cc index e8ea52e886cc..2a121180d695 100644 --- a/src/pass/lower_thread_allreduce.cc +++ b/src/pass/lower_thread_allreduce.cc @@ -83,7 +83,7 @@ class ThreadAllreduceBuilder final : public IRMutator { stmt = AttrStmt::make( repl->buffer_var, attr::volatile_scope, 1, op->body); stmt = Allocate::make( - repl->buffer_var, repl->type, + repl->buffer_var, repl->dtype, repl->extents, repl->condition, stmt); stmt = AttrStmt::make( repl->buffer_var, attr::storage_scope, @@ -125,14 +125,14 @@ class ThreadAllreduceBuilder final : public IRMutator { CHECK_EQ(size, size_of_args->value); Array inits = combiner->identity_element; std::vector values(size); - std::vector types(size); + std::vector types(size); Expr cond = call->args[size+1]; for (size_t idx = 0; idx < size; ++idx) { values[idx] = call->args[1+idx]; if (!is_one(cond)) { values[idx] = Select::make(cond, values[idx], inits[idx]); } - types[idx] = values[idx].type(); + types[idx] = values[idx].dtype(); } std::vector buffers(size); for (size_t idx = 0; idx < size; ++idx) { @@ -197,7 +197,7 @@ class ThreadAllreduceBuilder final : public IRMutator { // previous iteration on the same buffer. seq.emplace_back(SyncThread("shared")); for (size_t idx = 0; idx < size; ++idx) { - shared_bufs[idx] = Var("red_buf"+std::to_string(idx), Handle()); + shared_bufs[idx] = Var("red_buf"+std::to_string(idx), DataType::Handle()); Expr pred = const_true(types[idx].lanes()); seq.emplace_back(Store::make( shared_bufs[idx], values[idx], @@ -212,7 +212,7 @@ class ThreadAllreduceBuilder final : public IRMutator { Expr pred = const_true(types[idx].lanes()); load_remap_[buffers[idx]] = Load::make( types[idx], shared_bufs[idx], - BufIndex(make_zero(reduce_index.type()), group_index, reduce_extent), pred); + BufIndex(make_zero(reduce_index.dtype()), group_index, reduce_extent), pred); alloc_remap_[buffers[idx]] = Allocate::make( shared_bufs[idx], types[idx], {Expr(group_extent), Expr(reduce_extent)}, @@ -222,7 +222,7 @@ class ThreadAllreduceBuilder final : public IRMutator { } // make allreduce. Stmt MakeBufAllreduce(const CommReducerNode *combiner, - const std::vector& types, + const std::vector& types, const Array& shared_bufs, Expr reduce_index, Expr group_index, @@ -293,7 +293,7 @@ class ThreadAllreduceBuilder final : public IRMutator { int& total_extent = *out_total_extent; total_extent = 1; if (tvec.size() == 0) { - return make_zero(Int(32)); + return make_zero(DataType::Int(32)); } Expr ret; @@ -311,7 +311,7 @@ class ThreadAllreduceBuilder final : public IRMutator { // sync thread op. static Stmt SyncThread(const std::string& sync) { return Evaluate::make( - Call::make(Int(32), intrinsic::tvm_storage_sync, + Call::make(DataType::Int(32), intrinsic::tvm_storage_sync, {StringImm::make(sync)}, Call::Intrinsic)); } diff --git a/src/pass/lower_tvm_builtin.cc b/src/pass/lower_tvm_builtin.cc index e73956cb3d62..c8c8fa9c62d0 100644 --- a/src/pass/lower_tvm_builtin.cc +++ b/src/pass/lower_tvm_builtin.cc @@ -33,12 +33,12 @@ namespace ir { inline Expr ConstInt32(size_t index) { CHECK_LE(index, std::numeric_limits::max()); - return make_const(Int(32), static_cast(index)); + return make_const(DataType::Int(32), static_cast(index)); } inline Expr StackAlloca(std::string type, size_t num) { Array args = {StringImm::make(type), ConstInt32(num)}; - return Call::make(Handle(), intrinsic::tvm_stack_alloca, args, Call::Intrinsic); + return Call::make(DataType::Handle(), intrinsic::tvm_stack_alloca, args, Call::Intrinsic); } // Calculate the statistics of packed function. @@ -46,10 +46,10 @@ inline Expr StackAlloca(std::string type, size_t num) { class BuiltinLower : public IRMutator { public: Stmt Build(Stmt stmt) { - stack_shape_ = Var("stack_shape", Handle()); - stack_array_ = Var("stack_array", Handle()); - stack_value_ = Var("stack_value", Handle()); - stack_tcode_ = Var("stack_tcode", Handle()); + stack_shape_ = Var("stack_shape", DataType::Handle()); + stack_array_ = Var("stack_array", DataType::Handle()); + stack_value_ = Var("stack_value", DataType::Handle()); + stack_tcode_ = Var("stack_tcode", DataType::Handle()); stmt = this->Mutate(stmt); if (max_shape_stack_ != 0) { stmt = LetStmt::make( @@ -86,7 +86,7 @@ class BuiltinLower : public IRMutator { if (op->new_expr.defined()) return stmt; // Get constant allocation bound. int64_t dev_type; - int64_t nbytes = GetVectorBytes(op->type); + int64_t nbytes = GetVectorBytes(op->dtype); if (device_type_.defined()) { if (arith::GetConst(device_type_, &dev_type)) { if (dev_type == kDLCPU) { @@ -97,18 +97,18 @@ class BuiltinLower : public IRMutator { } } } - Expr total_bytes = make_const(op->extents[0].type(), nbytes); + Expr total_bytes = make_const(op->extents[0].dtype(), nbytes); for (size_t i = 0; i < op->extents.size(); ++i) { total_bytes = total_bytes * op->extents[i]; } CHECK(device_type_.defined()) << "Unknown device type in current IR"; CHECK(device_id_.defined()) << "Unknown device id in current IR"; - Stmt throw_last_error = Evaluate::make(Call::make(Int(32), + Stmt throw_last_error = Evaluate::make(Call::make(DataType::Int(32), intrinsic::tvm_throw_last_error, {}, Call::Intrinsic)); Stmt body = Block::make( - IfThenElse::make(Call::make(Bool(1), + IfThenElse::make(Call::make(DataType::Bool(1), intrinsic::tvm_handle_is_null, {op->buffer_var}, Call::PureIntrinsic), throw_last_error), @@ -116,27 +116,27 @@ class BuiltinLower : public IRMutator { Stmt alloca = LetStmt::make( op->buffer_var, - Call::make(op->buffer_var.type(), + Call::make(op->buffer_var.dtype(), "TVMBackendAllocWorkspace", - {cast(Int(32), device_type_), - cast(Int(32), device_id_), - cast(UInt(64), total_bytes), - IntImm::make(Int(32), op->type.code()), - IntImm::make(Int(32), op->type.bits())}, + {cast(DataType::Int(32), device_type_), + cast(DataType::Int(32), device_id_), + cast(DataType::UInt(64), total_bytes), + IntImm::make(DataType::Int(32), op->dtype.code()), + IntImm::make(DataType::Int(32), op->dtype.bits())}, Call::Extern), body); - Expr free_op = Call::make(Int(32), + Expr free_op = Call::make(DataType::Int(32), "TVMBackendFreeWorkspace", - {cast(Int(32), device_type_), - cast(Int(32), device_id_), + {cast(DataType::Int(32), device_type_), + cast(DataType::Int(32), device_id_), op->buffer_var}, Call::Extern); - Stmt free_stmt = IfThenElse::make(free_op != make_zero(Int(32)), throw_last_error); + Stmt free_stmt = IfThenElse::make(free_op != make_zero(DataType::Int(32)), throw_last_error); body = Block::make(alloca, free_stmt); body = AttrStmt::make( op->buffer_var, attr::storage_alignment, - make_const(Int(32), runtime::kTempAllocaAlignment), + make_const(DataType::Int(32), runtime::kTempAllocaAlignment), body); return body; } @@ -164,7 +164,7 @@ class BuiltinLower : public IRMutator { } else if (op->is_intrinsic(intrinsic::tvm_stack_make_array)) { return MakeArray(op, e); } else if (op->is_intrinsic(intrinsic::tvm_context_id)) { - return make_zero(op->type); + return make_zero(op->dtype); } else { return IRMutator::Mutate_(op, e); } @@ -177,10 +177,10 @@ class BuiltinLower : public IRMutator { op = expr.as(); for (size_t i = 0; i < op->args.size(); ++i) { prep_seq_.emplace_back( - Store::make(stack_shape_, cast(Int(64), op->args[i]), + Store::make(stack_shape_, cast(DataType::Int(64), op->args[i]), ConstInt32(stack_begin +i), const_true(1))); } - return AddressOffset(stack_shape_, Int(64), stack_begin); + return AddressOffset(stack_shape_, DataType::Int(64), stack_begin); } // make array Expr MakeArray(const Call* op, const Expr& e) { @@ -194,40 +194,40 @@ class BuiltinLower : public IRMutator { TVMStructSet(stack_array_, idx, intrinsic::kArrShape, op->args[1])); Expr strides = op->args[2]; if (!strides.defined() || is_zero(strides)) { - strides = make_zero(Handle()); + strides = make_zero(DataType::Handle()); } prep_seq_.emplace_back( TVMStructSet(stack_array_, idx, intrinsic::kArrStrides, strides)); prep_seq_.emplace_back( TVMStructSet(stack_array_, idx, intrinsic::kArrNDim, op->args[3])); - Type dtype = op->args[4].type(); + DataType dtype = op->args[4].dtype(); prep_seq_.emplace_back( TVMStructSet(stack_array_, idx, intrinsic::kArrTypeCode, - make_const(UInt(8), static_cast(dtype.code())))); + make_const(DataType::UInt(8), static_cast(dtype.code())))); prep_seq_.emplace_back( TVMStructSet(stack_array_, idx, intrinsic::kArrTypeBits, - make_const(UInt(8), dtype.bits()))); + make_const(DataType::UInt(8), dtype.bits()))); prep_seq_.emplace_back( TVMStructSet(stack_array_, idx, intrinsic::kArrTypeLanes, - make_const(UInt(16), dtype.lanes()))); + make_const(DataType::UInt(16), dtype.lanes()))); // set byte offset int data_bytes = GetVectorBytes(dtype); Expr byte_offset = op->args[5]; if (!is_zero(byte_offset)) { - byte_offset = byte_offset * make_const(byte_offset.type(), data_bytes); + byte_offset = byte_offset * make_const(byte_offset.dtype(), data_bytes); } prep_seq_.emplace_back( TVMStructSet(stack_array_, idx, intrinsic::kArrByteOffset, - cast(UInt(64), byte_offset))); + cast(DataType::UInt(64), byte_offset))); CHECK(device_type_.defined()) << "Unknown device type in current IR"; CHECK(device_id_.defined()) << "Unknown device id in current IR"; prep_seq_.emplace_back( TVMStructSet(stack_array_, idx, intrinsic::kArrDeviceId, - cast(Int(32), device_id_))); + cast(DataType::Int(32), device_id_))); prep_seq_.emplace_back( TVMStructSet(stack_array_, idx, intrinsic::kArrDeviceType, - cast(Int(32), device_type_))); - return TVMStructGet(Handle(), stack_array_, idx, intrinsic::kArrAddr); + cast(DataType::Int(32), device_type_))); + return TVMStructGet(DataType::Handle(), stack_array_, idx, intrinsic::kArrAddr); } // call packed. Expr MakeCallPacked(const Call* op, const Expr& e) { @@ -241,8 +241,8 @@ class BuiltinLower : public IRMutator { for (size_t i = 1; i < op->args.size(); ++i) { Expr stack_index = ConstInt32(arg_stack_begin + i - 1); Expr arg = op->args[i]; - Type t = arg.type(); - Type api_type = APIType(t); + DataType t = arg.dtype(); + DataType api_type = APIType(t); if (t != api_type) { arg = Cast::make(api_type, arg); } @@ -274,7 +274,7 @@ class BuiltinLower : public IRMutator { ConstInt32(arg_stack_begin + op->args.size() - 1) }; return Call::make( - Int(32), intrinsic::tvm_call_packed_lowered, + DataType::Int(32), intrinsic::tvm_call_packed_lowered, packed_args, Call::Intrinsic); } @@ -290,8 +290,8 @@ class BuiltinLower : public IRMutator { for (size_t i = 1; i < op->args.size(); ++i) { Expr stack_index = ConstInt32(arg_stack_begin + i - 1); Expr arg = op->args[i]; - Type t = arg.type(); - Type api_type = APIType(t); + DataType t = arg.dtype(); + DataType api_type = APIType(t); if (t != api_type) { arg = Cast::make(api_type, arg); } @@ -324,7 +324,7 @@ class BuiltinLower : public IRMutator { op->args[args_size - 1] }; return Call::make( - op->type, intrinsic::tvm_call_trace_packed_lowered, + op->dtype, intrinsic::tvm_call_trace_packed_lowered, packed_args, Call::Intrinsic); } diff --git a/src/pass/lower_warp_memory.cc b/src/pass/lower_warp_memory.cc index 393605e85b8a..0ed2b6232fc1 100644 --- a/src/pass/lower_warp_memory.cc +++ b/src/pass/lower_warp_memory.cc @@ -94,11 +94,11 @@ class WarpStoreCoeffFinder : private IRVisitor { /// Visitor implementation void Visit_(const Store *op) final { if (op->buffer_var.get() == buffer_) { - if (op->value.type().lanes() == 1) { + if (op->value.dtype().lanes() == 1) { UpdatePattern(op->index); } else { Expr base; - CHECK(GetRamp1Base(op->index, op->value.type().lanes(), &base)) + CHECK(GetRamp1Base(op->index, op->value.dtype().lanes(), &base)) << "LowerWarpMemory failed due to store index=" << op->index << ", can only handle continuous store"; UpdatePattern(base); @@ -196,7 +196,7 @@ class WarpAccessRewriter : protected IRMutator { int alloc_size = op->constant_allocation_size(); CHECK_GT(alloc_size, 0) << "warp memory only support constant alloc size"; - alloc_size *= op->type.lanes(); + alloc_size *= op->dtype.lanes(); warp_index_ = WarpIndexFinder(warp_size_).Find(op->body)->var; warp_coeff_ = WarpStoreCoeffFinder( buffer_, warp_index_, analyzer_).Find(op->body); @@ -205,8 +205,8 @@ class WarpAccessRewriter : protected IRMutator { warp_group_ = alloc_size / (warp_size_ * warp_coeff_); return Allocate::make( op->buffer_var, - op->type, - {make_const(Int(32), alloc_size / warp_size_)}, + op->dtype, + {make_const(DataType::Int(32), alloc_size / warp_size_)}, op->condition, this->Mutate(op->body)); } @@ -237,8 +237,8 @@ class WarpAccessRewriter : protected IRMutator { << "LowerWarpMemory failed to rewrite load to shuffle for index " << op->index << " local_index=" << local_index; Expr load_value = Load::make( - op->type, op->buffer_var, local_index, op->predicate); - return Call::make(load_value.type(), + op->dtype, op->buffer_var, local_index, op->predicate); + return Call::make(load_value.dtype(), intrinsic::tvm_warp_shuffle, {load_value, group}, Call::Intrinsic); @@ -252,15 +252,15 @@ class WarpAccessRewriter : protected IRMutator { // source index is the corresponding source index // in this access pattern. std::pair SplitIndexByGroup(const Expr& index) { - if (index.type().lanes() != 1) { + if (index.dtype().lanes() != 1) { Expr base, local_index, group; - CHECK(GetRamp1Base(index, index.type().lanes(), &base)); + CHECK(GetRamp1Base(index, index.dtype().lanes(), &base)); std::tie(local_index, group) = SplitIndexByGroup(base); local_index = - Ramp::make(local_index, make_const(local_index.type(), 1), index.type().lanes()); + Ramp::make(local_index, make_const(local_index.dtype(), 1), index.dtype().lanes()); return std::make_pair(local_index, group); } - Expr m = make_const(index.type(), warp_coeff_); + Expr m = make_const(index.dtype(), warp_coeff_); // simple case, warp index is on the highest. if (warp_group_ == 1) { @@ -269,9 +269,9 @@ class WarpAccessRewriter : protected IRMutator { return std::make_pair(x, z); } else { Expr x = analyzer_->canonical_simplify(indexmod(index, m)); - Expr y = index / make_const(index.type(), warp_coeff_ * warp_size_); + Expr y = index / make_const(index.dtype(), warp_coeff_ * warp_size_); y = y * m + x; - Expr z = indexdiv(indexmod(index, make_const(index.type(), warp_coeff_ * warp_size_)), + Expr z = indexdiv(indexmod(index, make_const(index.dtype(), warp_coeff_ * warp_size_)), m); return std::make_pair(analyzer_->canonical_simplify(y), analyzer_->canonical_simplify(z)); diff --git a/src/pass/make_api.cc b/src/pass/make_api.cc index 4d9c92bb428e..74b8f891299a 100644 --- a/src/pass/make_api.cc +++ b/src/pass/make_api.cc @@ -51,9 +51,9 @@ LoweredFunc MakeAPI(Stmt body, int num_packed_args = num_args - num_unpacked_args; // Data field definitions // The packed fields - Var v_packed_args("args", Handle()); - Var v_packed_arg_type_ids("arg_type_ids", Handle()); - Var v_num_packed_args("num_args", Int(32)); + Var v_packed_args("args", DataType::Handle()); + Var v_packed_arg_type_ids("arg_type_ids", DataType::Handle()); + Var v_num_packed_args("num_args", DataType::Int(32)); // The arguments of the function. Array args; // The device context @@ -66,12 +66,12 @@ LoweredFunc MakeAPI(Stmt body, // --------------------------- // local function definitions // load i-th argument as type t - auto f_arg_value = [&](Type t, int i) { + auto f_arg_value = [&](DataType t, int i) { Array call_args{v_packed_args, - IntImm::make(Int(32), i), - IntImm::make(Int(32), intrinsic::kTVMValueContent)}; + IntImm::make(DataType::Int(32), i), + IntImm::make(DataType::Int(32), intrinsic::kTVMValueContent)}; // load 64 bit version - Type api_type = APIType(t); + DataType api_type = APIType(t); Expr res = Call::make( api_type, intrinsic::tvm_struct_get, call_args, Call::PureIntrinsic); @@ -86,7 +86,7 @@ LoweredFunc MakeAPI(Stmt body, std::ostringstream os; os << "arg" << i; const Variable* v = api_args[i].as(); - return Var(os.str(), v ? v->type: Handle()); + return Var(os.str(), v ? v->dtype: DataType::Handle()); }; // --------------------------- // start of logics @@ -110,14 +110,15 @@ LoweredFunc MakeAPI(Stmt body, if (i < num_packed_args) { // Value loads seq_init.emplace_back(LetStmt::make( - v_arg, f_arg_value(v_arg.type(), i), nop)); + v_arg, f_arg_value(v_arg.dtype(), i), nop)); // type code checks - Var tcode(v_arg->name_hint + ".code", Int(32)); + Var tcode(v_arg->name_hint + ".code", DataType::Int(32)); seq_init.emplace_back(LetStmt::make( tcode, Load::make( - Int(32), v_packed_arg_type_ids, IntImm::make(Int(32), i), const_true(1)), + DataType::Int(32), v_packed_arg_type_ids, + IntImm::make(DataType::Int(32), i), const_true(1)), nop)); - Type t = v_arg.type(); + DataType t = v_arg.dtype(); if (t.is_handle()) { std::ostringstream msg; msg << name << ": Expect arg[" << i << "] to be pointer"; @@ -174,7 +175,7 @@ LoweredFunc MakeAPI(Stmt body, n->is_packed_func = num_unpacked_args == 0; n->is_restricted = is_restricted; body = AttrStmt::make( - make_zero(Int(32)), attr::compute_scope, + make_zero(DataType::Int(32)), attr::compute_scope, StringImm::make(name + "_compute_"), body); // Set device context if (vmap.count(device_id.get())) { @@ -186,7 +187,7 @@ LoweredFunc MakeAPI(Stmt body, node, attr::device_context_type, device_type, nop)); Stmt set_device = IfThenElse::make( device_type != kDLCPU, Evaluate::make(Call::make( - Int(32), intrinsic::tvm_call_packed, + DataType::Int(32), intrinsic::tvm_call_packed, {StringImm::make(runtime::symbol::tvm_set_device), device_type, device_id}, Call::Intrinsic))); body = Block::make(set_device, body); @@ -215,7 +216,7 @@ class DeviceTypeBinder: public IRMutator { if (op->attr_key == attr::device_context_type) { if (const Variable* var = op->value.as()) { var_ = var; - Expr value = make_const(op->value.type(), device_type_); + Expr value = make_const(op->value.dtype(), device_type_); Stmt body = IRMutator::Mutate_(op, s); var_ = nullptr; std::ostringstream os; @@ -245,14 +246,14 @@ class DeviceTypeBinder: public IRMutator { Expr res = IRMutator::Mutate_(op, e); op = res.as(); if (ir::Equal(op->a, op->b)) { - return make_const(op->type, false); + return make_const(op->dtype, false); } return res; } Expr Mutate_(const Variable* op, const Expr& e) final { if (op == var_) { - return make_const(op->type, device_type_); + return make_const(op->dtype, device_type_); } else { return e; } diff --git a/src/pass/narrow_channel_access.cc b/src/pass/narrow_channel_access.cc index 13c4e5141e8d..6687512ec739 100644 --- a/src/pass/narrow_channel_access.cc +++ b/src/pass/narrow_channel_access.cc @@ -93,7 +93,7 @@ class ChannelAccessIndexRewriter : public IRMutator { op = expr.as(); if (read_access_ && buf_var_ == op->buffer_var.get()) { return Load::make( - op->type, op->buffer_var, ir::Simplify(op->index - min_), + op->dtype, op->buffer_var, ir::Simplify(op->index - min_), op->predicate); } else { return expr; diff --git a/src/pass/rewrite_unsafe_select.cc b/src/pass/rewrite_unsafe_select.cc index 25ed03963524..43e3005aef64 100644 --- a/src/pass/rewrite_unsafe_select.cc +++ b/src/pass/rewrite_unsafe_select.cc @@ -115,12 +115,12 @@ class UnsafeSelectRewriter : public IRMutator { Expr expr = IRMutator::Mutate_(op, e); op = expr.as