diff --git a/include/tvm/arithmetic.h b/include/tvm/arithmetic.h index d135d30a8fbb..12acfc3d4d7d 100644 --- a/include/tvm/arithmetic.h +++ b/include/tvm/arithmetic.h @@ -105,7 +105,7 @@ class ConstIntBoundAnalyzer { * \param expr The expression of interest. * \return the result of the analysis. */ - ConstIntBound operator()(const Expr& expr); + ConstIntBound operator()(const PrimExpr& expr); /*! * \brief Update constant int bound information of var. @@ -136,7 +136,7 @@ class ConstIntBoundAnalyzer { * * \return an exit function that must be called to cleanup the constraint can be nullptr. */ - std::function EnterConstraint(const Expr& constraint); + std::function EnterConstraint(const PrimExpr& constraint); struct Entry; class Impl; /*! \brief Internal impl */ @@ -192,7 +192,7 @@ class ModularSetAnalyzer { * \param expr The expression of interest. * \return the result of the analysis. */ - ModularSet operator()(const Expr& expr); + ModularSet operator()(const PrimExpr& expr); /*! * \brief Update constant int bound information of var. * @@ -215,7 +215,7 @@ class ModularSetAnalyzer { * * \return an exit function that must be called to cleanup the constraint can be nullptr. */ - std::function EnterConstraint(const Expr& constraint); + std::function EnterConstraint(const PrimExpr& constraint); struct Entry; class Impl; /*! \brief Internal impl */ @@ -232,7 +232,7 @@ class RewriteSimplifier { * \param expr The expression of interest. * \return the result of the analysis. */ - Expr operator()(const Expr& expr); + PrimExpr operator()(const PrimExpr& expr); /*! * \brief Update binding of var to a new expression. @@ -242,10 +242,10 @@ class RewriteSimplifier { * \param override Whether do we allow override of existing information. */ void Update(const Var& var, - const Expr& new_expr, + const PrimExpr& new_expr, bool override = false); - std::function EnterConstraint(const Expr& constraint); + std::function EnterConstraint(const PrimExpr& constraint); private: friend class Analyzer; @@ -268,7 +268,7 @@ class CanonicalSimplifier { * \param expr The expression of interest. * \return the result of the analysis. */ - Expr operator()(const Expr& expr); + PrimExpr operator()(const PrimExpr& expr); /*! * \brief Update binding of var to a new expression. @@ -278,7 +278,7 @@ class CanonicalSimplifier { * \param override Whether do we allow override of existing information. */ void Update(const Var& var, - const Expr& new_expr, + const PrimExpr& new_expr, bool override = false); private: @@ -316,7 +316,7 @@ class ConstraintContext { * \param analyzer The analyzer. * \param constraint The constraint to be applied. */ - ConstraintContext(Analyzer* analyzer, Expr constraint) + ConstraintContext(Analyzer* analyzer, PrimExpr constraint) : analyzer_(analyzer), constraint_(constraint) {} // enter the scope. void EnterWithScope(); @@ -325,7 +325,7 @@ class ConstraintContext { /*! \brief The analyzer */ Analyzer* analyzer_; /*! \brief The constraint */ - Expr constraint_; + PrimExpr constraint_; /*! \brief function to be called in recovery */ std::function exit_; }; @@ -375,9 +375,9 @@ class IntSet : public ObjectRef { */ Range cover_range(Range max_range) const; /*! \return Lower bound of the set */ - Expr min() const; + PrimExpr min() const; /*! \return upper bound of the set */ - Expr max() const; + PrimExpr max() const; /*! \return Whether the set represent nothing */ bool is_nothing() const; /*! \return Whether the set represent everything */ @@ -398,7 +398,7 @@ class IntSet : public ObjectRef { * \brief The single point value, call only if is_single_point is true * \return The point value. */ - Expr point_value() const; + PrimExpr point_value() const; /*! * \brief Try to match IntSet with range r. * @@ -415,13 +415,13 @@ class IntSet : public ObjectRef { * \param point The point in the set. * \return construct a single point set */ - static IntSet single_point(Expr point); + static IntSet single_point(PrimExpr point); /*! * \brief construct a integer set from vector expression. * \param vec The vector expression, can also be single point. * \return The result set containing the indices in the vector. */ - static IntSet vector(Expr vec); + static IntSet vector(PrimExpr vec); /*! * \brief Construct a set representing a range. * \param r The range @@ -434,7 +434,7 @@ class IntSet : public ObjectRef { * \param max The maximum value of the interval. * \return constructed set. */ - static IntSet interval(Expr min, Expr max); + static IntSet interval(PrimExpr min, PrimExpr max); }; /*! @@ -450,7 +450,7 @@ class IntSetAnalyzer { * \param dom_map The domain map to indicate which variable to relax. * \return the result of the analysis. */ - IntSet operator()(const Expr& expr, const Map& dom_map); + IntSet operator()(const PrimExpr& expr, const Map& dom_map); private: friend class Analyzer; @@ -499,7 +499,7 @@ class Analyzer { * \param var The variable. * \param expr The expression we bind to. */ - void Bind(const VarExpr& var, const Expr& expr); + void Bind(const Var& var, const PrimExpr& expr); /*! * \brief Notify all the sub-analyzers that var * is created and binded to a range. @@ -509,7 +509,7 @@ class Analyzer { * \param var The variable. * \param range The range we bind to. */ - void Bind(const VarExpr& var, const Range& range); + void Bind(const Var& var, const Range& range); /*! * \brief Whether can we prove expr >= val. @@ -522,7 +522,7 @@ class Analyzer { * * \note Analyzer will call into sub-analyzers to get the result. */ - bool CanProveGreaterEqual(const Expr& expr, int64_t lower_bound); + bool CanProveGreaterEqual(const PrimExpr& expr, int64_t lower_bound); /*! * \brief Whether can we prove condition. * @@ -531,7 +531,7 @@ class Analyzer { * * \note Analyzer will call into sub-analyzers to get the result. */ - bool CanProve(const Expr& cond); + bool CanProve(const PrimExpr& cond); /*! * \brief Simplify expr. * @@ -540,7 +540,7 @@ class Analyzer { * * \note Analyzer will call into sub-analyzers to get the result. */ - Expr Simplify(const Expr& expr); + PrimExpr Simplify(const PrimExpr& expr); }; //----------------------------------------------- @@ -554,7 +554,7 @@ class Analyzer { * \param dom_map The domain of each variable. * \return An integer set that can cover all the possible values of e. */ -IntSet EvalSet(Expr e, +IntSet EvalSet(PrimExpr e, const Map& dom_map); /*! * \brief Same as EvalSet, but takes unordered_map @@ -563,7 +563,7 @@ IntSet EvalSet(Expr e, * \param dom_map The domain of each variable. * \return An integer set that can cover all the possible values of e. */ -IntSet EvalSet(Expr e, +IntSet EvalSet(PrimExpr e, const std::unordered_map& dom_map); /*! @@ -598,7 +598,7 @@ IntSet EvalSet(Range r, const std::unordered_map& dom_map); /*! \brief Map from Expr to IntSet */ -using ExprIntSetMap = std::unordered_map; +using ExprIntSetMap = std::unordered_map; /*! * \brief Find the integer set of every sub-expression, given the * domain of each iteration variables. @@ -608,7 +608,7 @@ using ExprIntSetMap = std::unordered_map; * \return the map from the expression to its possible value. */ ExprIntSetMap EvalSetForEachSubExpr( - Expr e, + PrimExpr e, const std::unordered_map& dom_map); /*! @@ -640,7 +640,7 @@ IntSet Intersect(const Array& sets); * The deduce bound must implies e for all value in relax_map * \return An integer set that always satisfies the condition. */ -IntSet DeduceBound(Expr v, Expr cond, +IntSet DeduceBound(PrimExpr v, PrimExpr cond, const Map& hint_map, const Map& relax_map); /*! @@ -653,7 +653,7 @@ IntSet DeduceBound(Expr v, Expr cond, * The deduce bound mush implies e for all value in relax_map * \return An integer set that always satisfies the condition. */ -IntSet DeduceBound(Expr v, Expr cond, +IntSet DeduceBound(PrimExpr v, PrimExpr cond, const std::unordered_map& hint_map, const std::unordered_map& relax_map); @@ -676,7 +676,7 @@ Domain DomainTouched(Stmt body, const Tensor &tensor, bool consider_calls, bool * \param vars List of variables to be used in detection. * \return [coeff[i]] if it is possible, empty array if it is not. */ -Array DetectLinearEquation(const Expr& e, +Array DetectLinearEquation(const PrimExpr& e, const Array& vars); /*! @@ -687,7 +687,7 @@ Array DetectLinearEquation(const Expr& e, * \return concat([min_value[i], max_value[i]]), None is returned if there is no min or max value * return empty if the e does not match the pattern. */ -Array DetectClipBound(const Expr& e, +Array DetectClipBound(const PrimExpr& e, const Array& vars); // implementation diff --git a/include/tvm/attrs.h b/include/tvm/attrs.h index 13c8b30c14df..ab9a711d28d8 100644 --- a/include/tvm/attrs.h +++ b/include/tvm/attrs.h @@ -486,7 +486,7 @@ inline void SetIntValue(T* ptr, const TVMArgValue& val) { if (val.type_code() == kDLInt) { *ptr = static_cast(val.value().v_int64); } else { - Expr expr = val; + PrimExpr expr = val; CHECK(expr.defined()); if (const ir::IntImmNode* op = expr.as()) { *ptr = static_cast(op->value); @@ -502,7 +502,7 @@ inline void SetValue(std::string* ptr, const TVMArgValue& val) { if (val.type_code() == kStr) { *ptr = val.operator std::string(); } else { - Expr expr = val; + PrimExpr expr = val; const ir::StringImmNode* op = expr.as(); CHECK(op != nullptr); *ptr = op->value; @@ -517,7 +517,7 @@ inline void SetValue(double* ptr, const TVMArgValue& val) { if (val.type_code() == kDLFloat || val.type_code() == kDLInt) { *ptr = val.operator double(); } else { - Expr expr = val; + PrimExpr expr = val; CHECK(expr.defined()); if (const ir::IntImmNode* op = expr.as()) { *ptr = static_cast(op->value); diff --git a/include/tvm/buffer.h b/include/tvm/buffer.h index 44c791863153..284e37063ab6 100644 --- a/include/tvm/buffer.h +++ b/include/tvm/buffer.h @@ -66,7 +66,7 @@ class Buffer : public ObjectRef { * If stride is not needed in the slice, it won't be presented * \return the result buffer. */ - TVM_DLL Buffer MakeSlice(Array begins, Array extents) const; + TVM_DLL Buffer MakeSlice(Array begins, Array extents) const; /*! * \brief Get access ptr to the entire buffer. * \param access_mask The access mask @@ -74,22 +74,22 @@ class Buffer : public ObjectRef { * \param content_lanes The number of lanes for the (data) type. * \param offset The offset of ptr. */ - TVM_DLL Expr access_ptr(int access_mask, + TVM_DLL PrimExpr access_ptr(int access_mask, DataType ptr_type = DataType::Handle(), int content_lanes = 1, - Expr offset = make_const(DataType::Int(32), 0)) const; + PrimExpr offset = make_const(DataType::Int(32), 0)) const; /*! * \brief Create an Expr that does a vector load at begin index. * \param begin The beginning index * \param dtype The data type to be loaded. */ - TVM_DLL Expr vload(Array begin, DataType dtype) const; + TVM_DLL PrimExpr vload(Array begin, DataType dtype) const; /*! * \brief Create a Stmt that does a vector store at begin index. * \param begin The beginning index * \param value The value to be stored. */ - TVM_DLL Stmt vstore(Array begin, Expr value) const; + TVM_DLL Stmt vstore(Array begin, PrimExpr value) const; /*! * \brief access the internal node container * \return the pointer to the internal node container @@ -112,14 +112,14 @@ class BufferNode : public Object { /*! \brief data type in the content of the tensor */ DataType dtype; /*! \brief The shape of the buffer */ - Array shape; + Array shape; /*! * \brief The strides of each dimension * This can be an empty array, indicating array is contiguous */ - Array strides; + Array strides; /*! \brief The offset in terms of number of dtype elements (including lanes) */ - Expr elem_offset; + PrimExpr elem_offset; // Meta data /*! \brief optional name of the buffer */ std::string name; @@ -159,9 +159,9 @@ class BufferNode : public Object { // A default value will be picked. TVM_DLL static Buffer make(Var ptr, DataType dtype, - Array shape, - Array strides, - Expr elem_offset, + Array shape, + Array strides, + PrimExpr elem_offset, std::string name, std::string scope, int data_alignment, @@ -184,7 +184,7 @@ inline const BufferNode* Buffer::operator->() const { * \return The created buffer. * \sa BufferNode::make for complete constructor. */ -TVM_DLL Buffer decl_buffer(Array shape, +TVM_DLL Buffer decl_buffer(Array shape, DataType dtype = DataType::Float(32), std::string name = "buffer"); } // namespace tvm diff --git a/include/tvm/build_module.h b/include/tvm/build_module.h index 5078621e4bda..8b49fb78d6c3 100644 --- a/include/tvm/build_module.h +++ b/include/tvm/build_module.h @@ -52,11 +52,11 @@ class TargetNode : public Object { /*! \brief The warp size that should be used by the LowerThreadAllreduce pass */ int thread_warp_size = 1; /*! \brief Keys for this target */ - Array keys_array; + Array keys_array; /*! \brief Options for this target */ - Array options_array; + Array options_array; /*! \brief Collection of imported libs */ - Array libs_array; + Array libs_array; /*! \return the full device string to pass to codegen::Build */ TVM_DLL const std::string& str() const; diff --git a/include/tvm/data_layout.h b/include/tvm/data_layout.h index 8c7247ff860b..d49320c0629e 100644 --- a/include/tvm/data_layout.h +++ b/include/tvm/data_layout.h @@ -316,9 +316,9 @@ class BijectiveLayoutNode : public Object { /*! \brief Describes how source axes can be mapped to the destination axes, * e.g., [i0 / 16, i1, i0 % 16] can describe NC -> NC16n */ - Array forward_rule; + Array forward_rule; /*! \brief Describes how destination axes can be mapped to the source axes */ - Array backward_rule; + Array backward_rule; /*! \brief The source layout */ Layout src_layout; @@ -350,13 +350,13 @@ class BijectiveLayout : public ObjectRef { explicit BijectiveLayout(ObjectPtr n) : ObjectRef(n) {} // Given the source shape, infer the destination shape. - TVM_DLL Array ForwardShape(const Array& shape) const; + TVM_DLL Array ForwardShape(const Array& shape) const; // Given the destination shape, recover the source shape. - TVM_DLL Array BackwardShape(const Array& dst_shape) const; + TVM_DLL Array BackwardShape(const Array& dst_shape) const; // Given the destination indices, infer the destination indices. - TVM_DLL Array ForwardIndex(const Array& index) const; + TVM_DLL Array ForwardIndex(const Array& index) const; // Given the destination indices, recover the source indices. - TVM_DLL Array BackwardIndex(const Array& dst_index) const; + TVM_DLL Array BackwardIndex(const Array& dst_index) const; /*! * \brief access the internal node container diff --git a/include/tvm/expr.h b/include/tvm/expr.h index 64d7547dbad5..976af619256a 100644 --- a/include/tvm/expr.h +++ b/include/tvm/expr.h @@ -37,44 +37,57 @@ namespace tvm { -/*! \brief Base node of all expressions. */ -class ExprNode : public Object { +/*! + * \brief Base node of all primitive expressions. + * + * A primitive expression deals with low-level + * POD data types and handles without + * doing life-cycle management for objects. + * + * PrimExpr is used in the low-level code + * optimizations and integer analysis. + * + * \sa PrimExpr + */ +class PrimExprNode : public Object { public: /*! \brief The data type of the expression. */ DataType dtype; - static constexpr const char* _type_key = "Expr"; - TVM_DECLARE_BASE_OBJECT_INFO(ExprNode, Object); + static constexpr const char* _type_key = "PrimExpr"; + TVM_DECLARE_BASE_OBJECT_INFO(PrimExprNode, Object); }; -/*! \brief Container of all expressions. */ -class Expr : public ObjectRef { +/*! + * \brief Container of all primitive expressions. + * \sa PrimExprNode + */ +class PrimExpr : public ObjectRef { public: - Expr() {} - explicit Expr(ObjectPtr ptr) : ObjectRef(ptr) {} + PrimExpr() {} + explicit PrimExpr(ObjectPtr ptr) : ObjectRef(ptr) {} /*! * \brief construct from integer. * \param value The value to be constructed. */ - TVM_DLL Expr(int32_t value); // NOLINT(*) + TVM_DLL PrimExpr(int32_t value); // NOLINT(*) /*! * \brief construct from float. * \param value The value to be constructed. */ - TVM_DLL Expr(float value); // NOLINT(*) + TVM_DLL PrimExpr(float value); // NOLINT(*) /*! * \brief construct from string. * \param str The value to be constructed. */ - TVM_DLL Expr(std::string str); // NOLINT(*) + TVM_DLL PrimExpr(std::string str); // NOLINT(*) /*! \return the data type of this expression. */ DataType dtype() const { - return static_cast(get())->dtype; + return static_cast(get())->dtype; } - /*! \brief type indicate the container type */ - using ContainerType = ExprNode; + using ContainerType = PrimExprNode; }; /*! \brief Base node of all statements. */ @@ -102,7 +115,7 @@ class Var; * - Let * - LetStmt */ -class VarNode : public ExprNode { +class VarNode : public PrimExprNode { public: /*! * \brief The hint to the variable name. @@ -118,13 +131,13 @@ class VarNode : public ExprNode { } static constexpr const char* _type_key = "Variable"; - TVM_DECLARE_FINAL_OBJECT_INFO(VarNode, ExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(VarNode, PrimExprNode); }; /*! \brief a named variable in TVM */ -class Var : public Expr { +class Var : public PrimExpr { public: - explicit Var(ObjectPtr n) : Expr(n) {} + explicit Var(ObjectPtr n) : PrimExpr(n) {} TVM_DLL explicit Var(std::string name_hint = "v", DataType t = DataType::Int(32)); /*! @@ -153,15 +166,9 @@ class Var : public Expr { using ContainerType = VarNode; }; -// Backward compatibility, will be removed later. -using VarExpr = Var; -using BaseExprNode = ExprNode; -using ExprHash = ObjectHash; -using ExprEqual = ObjectEqual; - class Integer; /*! \brief ExprNode: constant integer. */ -class IntImmNode : public ExprNode { +class IntImmNode : public PrimExprNode { public: /*! \brief the Internal value. */ int64_t value; @@ -174,7 +181,7 @@ class IntImmNode : public ExprNode { TVM_DLL static Integer make(DataType t, int64_t value); static constexpr const char* _type_key = "IntImm"; - TVM_DECLARE_FINAL_OBJECT_INFO(IntImmNode, ExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(IntImmNode, PrimExprNode); }; /*! @@ -183,17 +190,17 @@ class IntImmNode : public ExprNode { * This is used to store and automate type check * attributes that must be constant integer. */ -class Integer : public Expr { +class Integer : public PrimExpr { public: - Integer() : Expr() {} + Integer() : PrimExpr() {} /*! * \brief constructor from node. */ - explicit Integer(ObjectPtr node) : Expr(node) {} + explicit Integer(ObjectPtr node) : PrimExpr(node) {} /*! * \brief Construct integer from int value. */ - Integer(int value) : Expr(value) {} // NOLINT(*) + Integer(int value) : PrimExpr(value) {} // NOLINT(*) /*! * \brief Assign an expression to integer. * \param other another expression. @@ -225,12 +232,12 @@ class Integer : public Expr { class RangeNode : public Object { public: /*! \brief beginning of the node */ - Expr min; + PrimExpr min; /*! \brief the extend of range */ - Expr extent; + PrimExpr extent; /*! \brief constructor */ RangeNode() {} - RangeNode(Expr min, Expr extent) : min(min), extent(extent) {} + RangeNode(PrimExpr min, PrimExpr extent) : min(min), extent(extent) {} void VisitAttrs(AttrVisitor* v) { v->Visit("min", &min); @@ -249,7 +256,7 @@ class Range : public ObjectRef { * \param begin The begin of the range. * \param end The end of the range. */ - TVM_DLL Range(Expr begin, Expr end); + TVM_DLL Range(PrimExpr begin, PrimExpr end); /*! * \brief construct a new range with min and extent * The corresponding constructor is removed, @@ -259,7 +266,7 @@ class Range : public ObjectRef { * \param min The minimum range. * \param extent The extent of the range. */ - static Range make_by_min_extent(Expr min, Expr extent); + static Range make_by_min_extent(PrimExpr min, PrimExpr extent); // declare range. TVM_DEFINE_OBJECT_REF_METHODS(Range, ObjectRef, RangeNode); }; @@ -357,7 +364,7 @@ class IterVar : public ObjectRef { /*! * \return the corresponding var in the IterVar. */ - inline operator Expr() const; + inline operator PrimExpr() const; /*! \brief specify container node */ using ContainerType = IterVarNode; }; @@ -428,7 +435,7 @@ inline const IterVarNode* IterVar::operator->() const { return static_cast(data_.get()); } -inline IterVar::operator Expr() const { +inline IterVar::operator PrimExpr() const { return (*this)->var; } diff --git a/include/tvm/expr_operator.h b/include/tvm/expr_operator.h index bf8b1a3f840f..2d8f37855856 100644 --- a/include/tvm/expr_operator.h +++ b/include/tvm/expr_operator.h @@ -44,19 +44,19 @@ namespace tvm { */ template::value>::type> -inline Expr make_const(DataType t, ValueType value); +inline PrimExpr make_const(DataType t, ValueType value); /*! * \brief Make a const zero expr. * \param t The target type. * \return the result expression. */ -inline Expr make_zero(DataType t); +inline PrimExpr make_zero(DataType t); /*! * \brief Make a constant true expression. * \param lanes The number of lanes in the bool * \return The result expression. */ -inline Expr const_true(int lanes = 1) { +inline PrimExpr const_true(int lanes = 1) { return make_const(DataType::UInt(1, lanes), 1); } /*! @@ -64,7 +64,7 @@ inline Expr const_true(int lanes = 1) { * \param lanes The number of lanes in the bool * \return The result expression. */ -inline Expr const_false(int lanes = 1) { +inline PrimExpr const_false(int lanes = 1) { return make_const(DataType::UInt(1, lanes), 0); } /*! @@ -73,7 +73,7 @@ inline Expr const_false(int lanes = 1) { * \return the address to the int expression, * return nullptr, if x is not IntImm. */ -inline const int64_t* as_const_int(const Expr& x) { +inline const int64_t* as_const_int(const PrimExpr& x) { if (!x.defined()) return nullptr; if (const ir::IntImmNode* op = x.as()) { return &(op->value); @@ -88,7 +88,7 @@ inline const int64_t* as_const_int(const Expr& x) { * \return the address to the int expression, * return nullptr, if x is not UIntImm. */ -inline const uint64_t* as_const_uint(const Expr& x) { +inline const uint64_t* as_const_uint(const PrimExpr& x) { if (!x.defined()) return nullptr; if (const ir::UIntImmNode* op = x.as()) { return &(op->value); @@ -103,7 +103,7 @@ inline const uint64_t* as_const_uint(const Expr& x) { * \param value the value to be compared against. * \return whether x is constant expression. */ -inline bool is_const_int(const Expr& x, int64_t value); +inline bool is_const_int(const PrimExpr& x, int64_t value); /*! * \brief Check whether stmt is nop. @@ -118,7 +118,7 @@ inline bool is_no_op(const Stmt& stmt); * \note This only return true for integer types. * \return whether x is constant 1 */ -inline bool is_one(const Expr& x) { +inline bool is_one(const PrimExpr& x) { return is_const_int(x, 1); } @@ -128,7 +128,7 @@ inline bool is_one(const Expr& x) { * \return whether x is constant 0 * \note This only return true for integer types. */ -inline bool is_zero(const Expr& x) { +inline bool is_zero(const PrimExpr& x) { return is_const_int(x, 0); } @@ -137,21 +137,21 @@ inline bool is_zero(const Expr& x) { * \note This only return true for integer types. * \return whether x is constant */ -inline bool is_const(const Expr& x); +inline bool is_const(const PrimExpr& x); /*! * Query the maximum possible value of dtype. * \param dtype The data type. * \return the maximum possible value in this format. */ -TVM_DLL Expr max_value(const DataType& dtype); +TVM_DLL PrimExpr max_value(const DataType& dtype); /*! * Query the minimum possible value of dtype. * \param dtype The data type. * \return the minimum possible value in this format. */ -TVM_DLL Expr min_value(const DataType& dtype); +TVM_DLL PrimExpr min_value(const DataType& dtype); /*! * \brief Check whether x is a constant power of two @@ -161,7 +161,7 @@ TVM_DLL Expr min_value(const DataType& dtype); * \param shift The output shift if x is power of two. * \return whether x is constant power of two */ -TVM_DLL bool is_const_power_of_two_integer(const Expr& x, int* shift); +TVM_DLL bool is_const_power_of_two_integer(const PrimExpr& x, int* shift); /*! * \brief cast value to type. @@ -171,7 +171,7 @@ TVM_DLL bool is_const_power_of_two_integer(const Expr& x, int* shift); * \return The result expression. * \note This function may return value if the type is the same. */ -TVM_DLL Expr cast(const DataType& t, Expr value); +TVM_DLL PrimExpr cast(const DataType& t, PrimExpr value); /*! * \brief perform reinterpret cast value to type. * @@ -180,7 +180,7 @@ TVM_DLL Expr cast(const DataType& t, Expr value); * \return The result expression. * \note This function may return value if the type is the same. */ -TVM_DLL Expr reinterpret(const DataType& t, Expr value); +TVM_DLL PrimExpr reinterpret(const DataType& t, PrimExpr value); /*! * \brief add operator * @@ -190,7 +190,7 @@ TVM_DLL Expr reinterpret(const DataType& t, Expr value); * \note this function does eager constant folding for * index types(int32, int64) when possible. */ -TVM_DLL Expr operator+(Expr a, Expr b); +TVM_DLL PrimExpr operator+(PrimExpr a, PrimExpr b); /*! * \brief subtraction operator * @@ -200,7 +200,7 @@ TVM_DLL Expr operator+(Expr a, Expr b); * \note this function does eager constant folding for * index types(int32, int64) when possible. */ -TVM_DLL Expr operator-(Expr a, Expr b); +TVM_DLL PrimExpr operator-(PrimExpr a, PrimExpr b); /*! * \brief negation. * @@ -209,7 +209,7 @@ TVM_DLL Expr operator-(Expr a, Expr b); * \note this function does eager constant folding for * index types(int32, int64) when possible. */ -TVM_DLL Expr operator-(Expr a); +TVM_DLL PrimExpr operator-(PrimExpr a); /*! * \brief multiplication operator * @@ -219,7 +219,7 @@ TVM_DLL Expr operator-(Expr a); * \note this function does eager constant folding for * index types(int32, int64) when possible. */ -TVM_DLL Expr operator*(Expr a, Expr b); +TVM_DLL PrimExpr operator*(PrimExpr a, PrimExpr b); /*! * \brief division operator * @@ -229,7 +229,7 @@ TVM_DLL Expr operator*(Expr a, Expr b); * \note this function does eager constant folding for * index types(int32, int64) when possible. */ -TVM_DLL Expr operator/(Expr a, Expr b); +TVM_DLL PrimExpr operator/(PrimExpr a, PrimExpr b); /*! * \brief left shift operator * @@ -239,7 +239,7 @@ TVM_DLL Expr operator/(Expr a, Expr b); * \note this function does eager constant folding for * index types(int32, int64) when possible. */ -TVM_DLL Expr operator<<(Expr a, Expr b); +TVM_DLL PrimExpr operator<<(PrimExpr a, PrimExpr b); /*! * \brief right shift operator * @@ -249,7 +249,7 @@ TVM_DLL Expr operator<<(Expr a, Expr b); * \note this function does eager constant folding for * index types(int32, int64) when possible. */ -TVM_DLL Expr operator>>(Expr a, Expr b); +TVM_DLL PrimExpr operator>>(PrimExpr a, PrimExpr b); /*! * \brief greater * @@ -259,7 +259,7 @@ TVM_DLL Expr operator>>(Expr a, Expr b); * \note this function does eager constant folding for * index types(int32, int64) when possible. */ -TVM_DLL Expr operator>(Expr a, Expr b); +TVM_DLL PrimExpr operator>(PrimExpr a, PrimExpr b); /*! * \brief greater_equal * @@ -269,7 +269,7 @@ TVM_DLL Expr operator>(Expr a, Expr b); * \note this function does eager constant folding for * index types(int32, int64) when possible. */ -TVM_DLL Expr operator>=(Expr a, Expr b); +TVM_DLL PrimExpr operator>=(PrimExpr a, PrimExpr b); /*! * \brief less * @@ -279,7 +279,7 @@ TVM_DLL Expr operator>=(Expr a, Expr b); * \note this function does eager constant folding for * index types(int32, int64) when possible. */ -TVM_DLL Expr operator<(Expr a, Expr b); +TVM_DLL PrimExpr operator<(PrimExpr a, PrimExpr b); /*! * \brief less_equal * @@ -289,7 +289,7 @@ TVM_DLL Expr operator<(Expr a, Expr b); * \note this function does eager constant folding for * index types(int32, int64) when possible. */ -TVM_DLL Expr operator<=(Expr a, Expr b); +TVM_DLL PrimExpr operator<=(PrimExpr a, PrimExpr b); /*! * \brief equal * @@ -299,7 +299,7 @@ TVM_DLL Expr operator<=(Expr a, Expr b); * \note this function does eager constant folding for * index types(int32, int64) when possible. */ -TVM_DLL Expr operator==(Expr a, Expr b); +TVM_DLL PrimExpr operator==(PrimExpr a, PrimExpr b); /*! * \brief not_equal * @@ -309,7 +309,7 @@ TVM_DLL Expr operator==(Expr a, Expr b); * \note this function does eager constant folding for * index types(int32, int64) when possible. */ -TVM_DLL Expr operator!=(Expr a, Expr b); +TVM_DLL PrimExpr operator!=(PrimExpr a, PrimExpr b); /*! * \brief and * @@ -318,7 +318,7 @@ TVM_DLL Expr operator!=(Expr a, Expr b); * \return The result expression. * \note This operator does eager constant folding. */ -TVM_DLL Expr operator&&(Expr a, Expr b); +TVM_DLL PrimExpr operator&&(PrimExpr a, PrimExpr b); /*! * \brief or * @@ -327,7 +327,7 @@ TVM_DLL Expr operator&&(Expr a, Expr b); * \return The result expression. * \note This operator does eager constant folding. */ -TVM_DLL Expr operator||(Expr a, Expr b); +TVM_DLL PrimExpr operator||(PrimExpr a, PrimExpr b); /*! * \brief not * @@ -335,7 +335,7 @@ TVM_DLL Expr operator||(Expr a, Expr b); * \return The result expression. * \note This operator does eager constant folding. */ -TVM_DLL Expr operator!(Expr a); +TVM_DLL PrimExpr operator!(PrimExpr a); /*! * \brief compute division in C semantics. * @@ -349,7 +349,7 @@ TVM_DLL Expr operator!(Expr a); * \note this function does eager constant folding for * index types(int32, int64) when possible. */ -TVM_DLL Expr div(Expr a, Expr b); +TVM_DLL PrimExpr div(PrimExpr a, PrimExpr b); /*! * \brief compute trunc(a / b) * @@ -361,7 +361,7 @@ TVM_DLL Expr div(Expr a, Expr b); * \note this function does eager constant folding for * index types(int32, int64) when possible. */ -TVM_DLL Expr truncdiv(Expr a, Expr b); +TVM_DLL PrimExpr truncdiv(PrimExpr a, PrimExpr b); /*! * \brief compute the remainder of truncdiv * @@ -373,7 +373,7 @@ TVM_DLL Expr truncdiv(Expr a, Expr b); * \note this function does eager constant folding for * index types(int32, int64) when possible. */ -TVM_DLL Expr truncmod(Expr a, Expr b); +TVM_DLL PrimExpr truncmod(PrimExpr a, PrimExpr b); /*! * \brief compute floor(a / b) where a and b are non-negative. * @@ -388,7 +388,7 @@ TVM_DLL Expr truncmod(Expr a, Expr b); * \note this function does eager constant folding for * index types(int32, int64) when possible. */ -TVM_DLL Expr indexdiv(Expr a, Expr b); +TVM_DLL PrimExpr indexdiv(PrimExpr a, PrimExpr b); /*! * \brief compute the remainder floor(a / b) where a and b are non-negative. * @@ -402,7 +402,7 @@ TVM_DLL Expr indexdiv(Expr a, Expr b); * \note this function does eager constant folding for * index types(int32, int64) when possible. */ -TVM_DLL Expr indexmod(Expr a, Expr b); +TVM_DLL PrimExpr indexmod(PrimExpr a, PrimExpr b); /*! * \brief compute floor(a / b) * @@ -412,7 +412,7 @@ TVM_DLL Expr indexmod(Expr a, Expr b); * \note this function does eager constant folding for * index types(int32, int64) when possible. */ -TVM_DLL Expr floordiv(Expr a, Expr b); +TVM_DLL PrimExpr floordiv(PrimExpr a, PrimExpr b); /*! * \brief compute the remainder of floordiv * @@ -422,7 +422,7 @@ TVM_DLL Expr floordiv(Expr a, Expr b); * \note this function does eager constant folding for * index types(int32, int64) when possible. */ -TVM_DLL Expr floormod(Expr a, Expr b); +TVM_DLL PrimExpr floormod(PrimExpr a, PrimExpr b); /*! * \brief take maximum of two values * @@ -432,7 +432,7 @@ TVM_DLL Expr floormod(Expr a, Expr b); * \note this function does eager constant folding for * index types(int32, int64) when possible. */ -TVM_DLL Expr max(Expr a, Expr b); +TVM_DLL PrimExpr max(PrimExpr a, PrimExpr b); /*! * \brief take minimum of two values * @@ -442,7 +442,7 @@ TVM_DLL Expr max(Expr a, Expr b); * \note this function does eager constant folding for * index types(int32, int64) when possible. */ -TVM_DLL Expr min(Expr a, Expr b); +TVM_DLL PrimExpr min(PrimExpr a, PrimExpr b); /*! * \brief take bitwise and of two values * @@ -452,7 +452,7 @@ TVM_DLL Expr min(Expr a, Expr b); * \note this function does eager constant folding for * index types(int32, int64) when possible. */ -TVM_DLL Expr operator&(Expr a, Expr b); +TVM_DLL PrimExpr operator&(PrimExpr a, PrimExpr b); /*! * \brief take bitwise or of two values * @@ -462,7 +462,7 @@ TVM_DLL Expr operator&(Expr a, Expr b); * \note this function does eager constant folding for * index types(int32, int64) when possible. */ -TVM_DLL Expr operator|(Expr a, Expr b); +TVM_DLL PrimExpr operator|(PrimExpr a, PrimExpr b); /*! * \brief take bitwise xor of two values * @@ -472,7 +472,7 @@ TVM_DLL Expr operator|(Expr a, Expr b); * \note this function does eager constant folding for * index types(int32, int64) when possible. */ -TVM_DLL Expr operator^(Expr a, Expr b); +TVM_DLL PrimExpr operator^(PrimExpr a, PrimExpr b); /*! * \brief take bitwise negation of two values * @@ -481,7 +481,7 @@ TVM_DLL Expr operator^(Expr a, Expr b); * \note this function does eager constant folding for * index types(int32, int64) when possible. */ -TVM_DLL Expr operator~(Expr a); +TVM_DLL PrimExpr operator~(PrimExpr a); /*! * \brief Conditional expression. * @@ -492,95 +492,95 @@ TVM_DLL Expr operator~(Expr a); * \note this function does eager constant folding for * index types(int32, int64) when possible. */ -TVM_DLL Expr if_then_else(Expr cond, Expr true_value, Expr false_value); +TVM_DLL PrimExpr if_then_else(PrimExpr cond, PrimExpr true_value, PrimExpr false_value); /*! * \brief Mark condition as likely. * \param cond The condition * \return The marked expression. */ -TVM_DLL Expr likely(Expr cond); +TVM_DLL PrimExpr likely(PrimExpr cond); /*! * \brief Calculate power(x, y) * \param x The left operand. * \param y The right operand. */ -TVM_DLL Expr pow(Expr x, Expr y); +TVM_DLL PrimExpr pow(PrimExpr x, PrimExpr y); /*! * \brief Calculate absolute value of x. * \param x The input data * * \return The aboslute value of input data x */ -TVM_DLL Expr abs(Expr x); +TVM_DLL PrimExpr abs(PrimExpr x); /*! * \brief Check if x is NaN. * \param x The input data * \return The result expression. */ -TVM_DLL Expr isnan(Expr x); +TVM_DLL PrimExpr isnan(PrimExpr x); /*! * \brief sum of of source expression over axis * \param source The source expression. * \param axis List of iteration variables that will be used for reduction. */ -TVM_DLL Expr sum(Expr source, Array axis); +TVM_DLL PrimExpr sum(PrimExpr source, Array axis); /*! * \brief logical And of of source expression over axis * \param source The source expression. * \param axis List of iteration variables that will be used for reduction. */ -TVM_DLL Expr all(Expr source, Array axis); +TVM_DLL PrimExpr all(PrimExpr source, Array axis); /*! * \brief logical Or of of source expression over axis * \param source The source expression. * \param axis List of iteration variables that will be used for reduction. */ -TVM_DLL Expr any(Expr source, Array axis); +TVM_DLL PrimExpr any(PrimExpr source, Array axis); /*! * \brief max of of source expression over axis * \param source The source expression. * \param axis List of iteration variables that will be used for reduction. */ -TVM_DLL Expr max(Expr source, Array axis); +TVM_DLL PrimExpr max(PrimExpr source, Array axis); /*! * \brief max of of source expression over axis * \param source The source expression. * \param axis List of iteration variables that will be used for reduction. */ -TVM_DLL Expr min(Expr source, Array axis); +TVM_DLL PrimExpr min(PrimExpr source, Array axis); /*! * \brief product of of source expression over axis * \param source The source expression. * \param axis List of iteration variables that will be used for reduction. */ -TVM_DLL Expr prod(Expr source, Array axis); +TVM_DLL PrimExpr prod(PrimExpr source, Array axis); /*! * \brief Calculate floor(x) * \param x The input expression. * \return The result expression. */ -TVM_DLL Expr floor(Expr x); +TVM_DLL PrimExpr floor(PrimExpr x); /*! * \brief Calculate ceil(x) * \param x The input expression. * \return The result expression. */ -TVM_DLL Expr ceil(Expr x); +TVM_DLL PrimExpr ceil(PrimExpr x); /*! * \brief Calculate round(x) * \param x The input expression. * \return The result expression. */ -TVM_DLL Expr round(Expr x); +TVM_DLL PrimExpr round(PrimExpr x); /*! * \brief Calculates std::nearbyint(x) @@ -588,20 +588,20 @@ TVM_DLL Expr round(Expr x); * \return The result expression. * This is a faster alternate to round. */ -TVM_DLL Expr nearbyint(Expr x); +TVM_DLL PrimExpr nearbyint(PrimExpr x); /*! * \brief Calculate trunc(x) * \param x The input expression. * \return The result expression. */ -TVM_DLL Expr trunc(Expr x); +TVM_DLL PrimExpr trunc(PrimExpr x); // Intrinsic operators -#define TVM_DECLARE_INTRIN_UNARY(OpName) \ - inline Expr OpName(Expr x) { \ +#define TVM_DECLARE_INTRIN_UNARY(OpName) \ + inline PrimExpr OpName(PrimExpr x) { \ return ir::CallNode::make(x.dtype(), #OpName, {x}, ir::CallNode::PureIntrinsic); \ - } \ + } \ TVM_DECLARE_INTRIN_UNARY(exp); TVM_DECLARE_INTRIN_UNARY(erf); @@ -616,11 +616,11 @@ TVM_DECLARE_INTRIN_UNARY(sin); TVM_DECLARE_INTRIN_UNARY(atan); // Implementation details after this -inline bool is_const(const Expr& x) { +inline bool is_const(const PrimExpr& x) { if (x.as() || x.as()) { return true; } else if (const auto* op = x.as()) { - const Expr& val = op->value; + const PrimExpr& val = op->value; if (val.as() || val.as()) { return true; } @@ -628,7 +628,7 @@ inline bool is_const(const Expr& x) { return false; } -inline bool is_positive_const(const Expr& a) { +inline bool is_positive_const(const PrimExpr& a) { if (const ir::IntImmNode* op = a.as()) { return op->value > 0; } else if (const ir::UIntImmNode* op = a.as()) { @@ -638,7 +638,7 @@ inline bool is_positive_const(const Expr& a) { } } -inline bool is_negative_const(const Expr& a) { +inline bool is_negative_const(const PrimExpr& a) { if (const ir::IntImmNode* op = a.as()) { return op->value < 0; } else { @@ -646,13 +646,13 @@ inline bool is_negative_const(const Expr& a) { } } -inline bool is_const_int(const Expr& x, int64_t value) { +inline bool is_const_int(const PrimExpr& x, int64_t value) { if (const auto* op = x.as()) { return op->value == value; } else if (const auto* op = x.as()) { return op->value == static_cast(value); } else if (const auto* op = x.as()) { - const Expr& val = op->value; + const PrimExpr& val = op->value; if (const auto* opv = val.as()) { return opv->value == value; } else if (const auto* opv = val.as()) { @@ -674,7 +674,7 @@ inline bool is_no_op(const Stmt& stmt) { } template -inline Expr MakeConstScalar(DataType t, ValueType value) { +inline PrimExpr MakeConstScalar(DataType t, ValueType value) { if (t.is_int()) return ir::IntImmNode::make(t, static_cast(value)); if (t.is_uint()) return ir::UIntImmNode::make(t, static_cast(value)); if (t.is_float()) return ir::FloatImmNode::make(t, static_cast(value)); @@ -685,11 +685,11 @@ inline Expr MakeConstScalar(DataType t, ValueType value) { if (static_cast(t.code()) >= static_cast(kCustomBegin)) return ir::FloatImmNode::make(t, static_cast(value)); LOG(FATAL) << "cannot make const for type " << t; - return Expr(); + return PrimExpr(); } template -inline Expr make_const(DataType t, ValueType value) { +inline PrimExpr make_const(DataType t, ValueType value) { if (t.lanes() == 1) { return MakeConstScalar(t, value); } else { @@ -698,7 +698,7 @@ inline Expr make_const(DataType t, ValueType value) { } } -inline Expr make_zero(DataType t) { +inline PrimExpr make_zero(DataType t) { if (t.is_handle()) { return reinterpret(t, make_const(DataType::UInt(64), 0)); } @@ -706,43 +706,43 @@ inline Expr make_zero(DataType t) { } // additional const expression overloading -#define TVM_DEFINE_ASSIGN_OP_OVERLOAD(Name, OpFunc) \ - inline Expr Name(Expr& a, Expr b) { \ - a = OpFunc(a, b); \ - return a; \ +#define TVM_DEFINE_ASSIGN_OP_OVERLOAD(Name, OpFunc) \ + inline PrimExpr Name(PrimExpr& a, PrimExpr b) {\ + a = OpFunc(a, b); \ + return a; \ } #define TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(Name) \ - inline Expr Name(const Expr& a, float b) { \ - return Name(a, Expr(b)); \ + inline PrimExpr Name(const PrimExpr& a, float b) { \ + return Name(a, PrimExpr(b)); \ } \ - inline Expr Name(float a, const Expr& b) { \ - return Name(Expr(a), b); \ + inline PrimExpr Name(float a, const PrimExpr& b) { \ + return Name(PrimExpr(a), b); \ } \ - inline Expr Name(int a, const Expr& b) { \ + inline PrimExpr Name(int a, const PrimExpr& b) { \ return Name(make_const(b.dtype(), a), b); \ } \ - inline Expr Name(const Expr& a, int b) { \ + inline PrimExpr Name(const PrimExpr& a, int b) { \ return Name(a, make_const(a.dtype(), b)); \ } \ - inline Expr Name(const Expr& a, double b) { \ - return Name(a, make_const(DataType::Float(64), b)); \ + inline PrimExpr Name(const PrimExpr& a, double b) {\ + return Name(a, make_const(DataType::Float(64), b)); \ } -#define TVM_DEFINE_LOGICAL_OP_CONST_VAL_OVERLOAD(Name) \ - inline Expr Name(const Expr& a, bool b) { \ - return Name(a, Expr(b)); \ - } \ - inline Expr Name(bool a, const Expr& b) { \ - return Name(Expr(a), b); \ +#define TVM_DEFINE_LOGICAL_OP_CONST_VAL_OVERLOAD(Name) \ + inline PrimExpr Name(const PrimExpr& a, bool b) { \ + return Name(a, PrimExpr(b)); \ + } \ + inline PrimExpr Name(bool a, const PrimExpr& b) { \ + return Name(PrimExpr(a), b); \ } -#define TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(Name) \ - inline Expr Name(const Expr& a, int b) { \ - return Name(a, make_const(a.dtype(), b)); \ - } \ - inline Expr Name(int a, const Expr& b) { \ - return Name(make_const(b.dtype(), a), b); \ +#define TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(Name) \ + inline PrimExpr Name(const PrimExpr& a, int b) { \ + return Name(a, make_const(a.dtype(), b)); \ + } \ + inline PrimExpr Name(int a, const PrimExpr& b) { \ + return Name(make_const(b.dtype(), a), b); \ } @@ -798,19 +798,19 @@ inline void DivAmbiguityError(const TA& a) { // The second template argument is necessary to make sure the // code compiles lazily by the compiler during invocation. template -inline Expr operator/(const Expr& a, const TB& b) { +inline PrimExpr operator/(const PrimExpr& a, const TB& b) { DivAmbiguityError(a); return a; } template -inline Expr operator/=(const Expr& a, const TB& b) { +inline PrimExpr operator/=(const PrimExpr& a, const TB& b) { DivAmbiguityError(a); return a; } template -inline Expr operator%(const Expr& a, const TB& b) { +inline PrimExpr operator%(const PrimExpr& a, const TB& b) { DivAmbiguityError(a); return a; } diff --git a/include/tvm/ir.h b/include/tvm/ir.h index 11ce09d2a697..84039485ae69 100644 --- a/include/tvm/ir.h +++ b/include/tvm/ir.h @@ -40,7 +40,7 @@ using IntImmNode = tvm::IntImmNode; using VarNode = tvm::VarNode; /*! \brief constant unsigned integer. */ -class UIntImmNode : public ExprNode { +class UIntImmNode : public PrimExprNode { public: /*! \brief The constant value content. */ uint64_t value; @@ -50,14 +50,14 @@ class UIntImmNode : public ExprNode { v->Visit("value", &value); } - TVM_DLL static Expr make(DataType t, uint64_t value); + TVM_DLL static PrimExpr make(DataType t, uint64_t value); static constexpr const char* _type_key = "UIntImm"; - TVM_DECLARE_FINAL_OBJECT_INFO(UIntImmNode, ExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(UIntImmNode, PrimExprNode); }; /*! \brief Floating point constants. */ -class FloatImmNode : public ExprNode { +class FloatImmNode : public PrimExprNode { public: /*! \brief The constant value content. */ double value; @@ -67,14 +67,14 @@ class FloatImmNode : public ExprNode { v->Visit("value", &value); } - TVM_DLL static Expr make(DataType t, double value); + TVM_DLL static PrimExpr make(DataType t, double value); static constexpr const char* _type_key = "FloatImm"; - TVM_DECLARE_FINAL_OBJECT_INFO(FloatImmNode, ExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(FloatImmNode, PrimExprNode); }; /*! \brief String constants, only used in asserts. */ -class StringImmNode : public ExprNode { +class StringImmNode : public PrimExprNode { public: /*! \brief The constant value content. */ std::string value; @@ -84,30 +84,30 @@ class StringImmNode : public ExprNode { v->Visit("value", &value); } - TVM_DLL Expr static make(std::string value); + TVM_DLL PrimExpr static make(std::string value); static constexpr const char* _type_key = "StringImm"; - TVM_DECLARE_FINAL_OBJECT_INFO(StringImmNode, ExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(StringImmNode, PrimExprNode); }; /*! * \brief Cast value from one data type to another. * \note The lanes of value should keep fixed. */ -class CastNode : public ExprNode { +class CastNode : public PrimExprNode { public: /*! \brief Original data type. */ - Expr value; + PrimExpr value; void VisitAttrs(AttrVisitor* v) { v->Visit("dtype", &dtype); v->Visit("value", &value); } - TVM_DLL static Expr make(DataType t, Expr v); + TVM_DLL static PrimExpr make(DataType t, PrimExpr v); static constexpr const char* _type_key = "Cast"; - TVM_DECLARE_FINAL_OBJECT_INFO(CastNode, ExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(CastNode, PrimExprNode); }; /*! @@ -115,12 +115,12 @@ class CastNode : public ExprNode { * \tparam T The type of the child class. */ template -class BinaryOpNode : public ExprNode { +class BinaryOpNode : public PrimExprNode { public: /*! \brief The left operand. */ - Expr a; + PrimExpr a; /*! \brief The right operand. */ - Expr b; + PrimExpr b; void VisitAttrs(AttrVisitor* v) { v->Visit("dtype", &(this->dtype)); @@ -128,7 +128,7 @@ class BinaryOpNode : public ExprNode { v->Visit("b", &b); } - static Expr make(Expr a, Expr b) { + static PrimExpr make(PrimExpr a, PrimExpr b) { CHECK(a.defined()) << "ValueError: a is undefined\n"; CHECK(b.defined()) << "ValueError: b is undefined\n"; CHECK(a.dtype() == b.dtype()) << "TypeError: mismatched types\n"; @@ -136,10 +136,10 @@ class BinaryOpNode : public ExprNode { node->dtype = a.dtype(); node->a = std::move(a); node->b = std::move(b); - return Expr(node); + return PrimExpr(node); } - TVM_DECLARE_FINAL_OBJECT_INFO(T, ExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(T, PrimExprNode); }; /*! \brief a + b */ @@ -207,12 +207,12 @@ class MaxNode : public BinaryOpNode { * \tparam T The type of the child class. */ template -class CmpOpNode : public ExprNode { +class CmpOpNode : public PrimExprNode { public: /*! \brief The left operand. */ - Expr a; + PrimExpr a; /*! \brief The right operand. */ - Expr b; + PrimExpr b; void VisitAttrs(AttrVisitor* v) { v->Visit("dtype", &(this->dtype)); @@ -220,7 +220,7 @@ class CmpOpNode : public ExprNode { v->Visit("b", &b); } - static Expr make(Expr a, Expr b) { + static PrimExpr make(PrimExpr a, PrimExpr b) { CHECK(a.defined()) << "ValueError: a is undefined\n"; CHECK(b.defined()) << "ValueError: b is undefined\n"; CHECK(a.dtype() == b.dtype()) << "TypeError: mismatched types\n"; @@ -228,10 +228,10 @@ class CmpOpNode : public ExprNode { node->dtype = DataType::Bool(a.dtype().lanes()); node->a = std::move(a); node->b = std::move(b); - return Expr(node); + return PrimExpr(node); } - TVM_DECLARE_FINAL_OBJECT_INFO(T, ExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(T, PrimExprNode); }; /*! \brief a == b */ @@ -271,12 +271,12 @@ class GENode : public CmpOpNode { }; /*! \brief a && b */ -class AndNode : public ExprNode { +class AndNode : public PrimExprNode { public: /*! \brief The left operand. */ - Expr a; + PrimExpr a; /*! \brief The right operand. */ - Expr b; + PrimExpr b; void VisitAttrs(AttrVisitor* v) { v->Visit("dtype", &(this->dtype)); @@ -284,19 +284,19 @@ class AndNode : public ExprNode { v->Visit("b", &b); } - TVM_DLL static Expr make(Expr a, Expr b); + TVM_DLL static PrimExpr make(PrimExpr a, PrimExpr b); static constexpr const char* _type_key = "And"; - TVM_DECLARE_FINAL_OBJECT_INFO(AndNode, ExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(AndNode, PrimExprNode); }; /*! \brief a || b */ -class OrNode : public ExprNode { +class OrNode : public PrimExprNode { public: /*! \brief The left operand. */ - Expr a; + PrimExpr a; /*! \brief The right operand. */ - Expr b; + PrimExpr b; void VisitAttrs(AttrVisitor* v) { v->Visit("dtype", &dtype); @@ -304,27 +304,27 @@ class OrNode : public ExprNode { v->Visit("b", &b); } - TVM_DLL static Expr make(Expr a, Expr b); + TVM_DLL static PrimExpr make(PrimExpr a, PrimExpr b); static constexpr const char* _type_key = "Or"; - TVM_DECLARE_FINAL_OBJECT_INFO(OrNode, ExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(OrNode, PrimExprNode); }; /*! \brief !a */ -class NotNode : public ExprNode { +class NotNode : public PrimExprNode { public: /*! \brief The input operand. */ - Expr a; + PrimExpr a; void VisitAttrs(AttrVisitor* v) { v->Visit("dtype", &dtype); v->Visit("a", &a); } - TVM_DLL static Expr make(Expr a); + TVM_DLL static PrimExpr make(PrimExpr a); static constexpr const char* _type_key = "Not"; - TVM_DECLARE_FINAL_OBJECT_INFO(NotNode, ExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(NotNode, PrimExprNode); }; /*! @@ -334,14 +334,14 @@ class NotNode : public ExprNode { * Do not use it to guard against out of bound access, * please use if_then_else instead. */ -class SelectNode : public ExprNode { +class SelectNode : public PrimExprNode { public: /*! \brief The condition */ - Expr condition; + PrimExpr condition; /*! \brief value to be returned when condition is true. */ - Expr true_value; + PrimExpr true_value; /*! \brief value to be returned when condition is false. */ - Expr false_value; + PrimExpr false_value; void VisitAttrs(AttrVisitor* v) { v->Visit("dtype", &dtype); @@ -350,10 +350,10 @@ class SelectNode : public ExprNode { v->Visit("false_value", &false_value); } - TVM_DLL static Expr make(Expr condition, Expr true_value, Expr false_value); + TVM_DLL static PrimExpr make(PrimExpr condition, PrimExpr true_value, PrimExpr false_value); static constexpr const char* _type_key = "Select"; - TVM_DECLARE_FINAL_OBJECT_INFO(SelectNode, ExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(SelectNode, PrimExprNode); }; /*! @@ -371,14 +371,14 @@ class SelectNode : public ExprNode { * * \endcode */ -class LoadNode : public ExprNode { +class LoadNode : public PrimExprNode { public: /*! \brief The buffer variable. */ Var buffer_var; /*! \brief The index locations to be loaded. */ - Expr index; + PrimExpr index; /*! \brief The predicate to mask which lanes would be loaded. */ - Expr predicate; + PrimExpr predicate; void VisitAttrs(AttrVisitor* v) { v->Visit("dtype", &dtype); @@ -387,10 +387,10 @@ class LoadNode : public ExprNode { v->Visit("predicate", &predicate); } - TVM_DLL static Expr make(DataType dtype, Var buffer_var, Expr index, Expr predicate); + TVM_DLL static PrimExpr make(DataType dtype, Var buffer_var, PrimExpr index, PrimExpr predicate); static constexpr const char* _type_key = "Load"; - TVM_DECLARE_FINAL_OBJECT_INFO(LoadNode, ExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(LoadNode, PrimExprNode); }; /*! @@ -402,12 +402,12 @@ class LoadNode : public ExprNode { * - ramp(0, 1, 3) = [0, 1, 2] * - ramp(1, 2, 4) = [1, 3, 5, 7] */ -class RampNode : public ExprNode { +class RampNode : public PrimExprNode { public: /*! \brief The base value. */ - Expr base; + PrimExpr base; /*! \brief The stride of each step. */ - Expr stride; + PrimExpr stride; /*! \brief Total number of lanes. */ int lanes; @@ -418,17 +418,17 @@ class RampNode : public ExprNode { v->Visit("lanes", &lanes); } - TVM_DLL static Expr make(Expr base, Expr stride, int lanes); + TVM_DLL static PrimExpr make(PrimExpr base, PrimExpr stride, int lanes); static constexpr const char* _type_key = "Ramp"; - TVM_DECLARE_FINAL_OBJECT_INFO(RampNode, ExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(RampNode, PrimExprNode); }; /*! \brief Create a vector where all the elements are value. */ -class BroadcastNode : public ExprNode { +class BroadcastNode : public PrimExprNode { public: /*! \brief The base value. */ - Expr value; + PrimExpr value; /*! \brief The number of lanes. */ int lanes; @@ -438,23 +438,23 @@ class BroadcastNode : public ExprNode { v->Visit("lanes", &lanes); } - TVM_DLL static Expr make(Expr value, int lanes); + TVM_DLL static PrimExpr make(PrimExpr value, int lanes); static constexpr const char* _type_key = "Broadcast"; - TVM_DECLARE_FINAL_OBJECT_INFO(BroadcastNode, ExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(BroadcastNode, PrimExprNode); }; /*! * \brief Let binding. Bind var to value then evaluate body. */ -class LetNode : public ExprNode { +class LetNode : public PrimExprNode { public: /*! \brief The variable. */ Var var; /*! \brief The value to be binded. */ - Expr value; + PrimExpr value; /*! \brief The result expression. */ - Expr body; + PrimExpr body; void VisitAttrs(AttrVisitor* v) { v->Visit("dtype", &dtype); @@ -463,10 +463,10 @@ class LetNode : public ExprNode { v->Visit("body", &body); } - TVM_DLL static Expr make(Var var, Expr value, Expr body); + TVM_DLL static PrimExpr make(Var var, PrimExpr value, PrimExpr body); static constexpr const char* _type_key = "Let"; - TVM_DECLARE_FINAL_OBJECT_INFO(LetNode, ExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(LetNode, PrimExprNode); }; // Call node, represent a function call or a multi-dimensional array load. @@ -494,7 +494,7 @@ class FunctionRef : public ObjectRef { /*! * \brief Call node. */ -class CallNode : public ExprNode { +class CallNode : public PrimExprNode { public: /*! \brief Possible types of calls. */ enum CallType : int { @@ -514,7 +514,7 @@ class CallNode : public ExprNode { /*! \brief The name of the function/intrinsic. */ std::string name; /*! \brief The arguments. */ - Array args; + Array args; /*! \brief Type of calls. */ CallType call_type; /*! \brief The function to be called. */ @@ -531,9 +531,9 @@ class CallNode : public ExprNode { v->Visit("value_index", &value_index); } - TVM_DLL static Expr make(DataType dtype, + TVM_DLL static PrimExpr make(DataType dtype, std::string name, - Array args, + Array args, CallType call_type, FunctionRef func = FunctionRef(), int value_index = 0); @@ -560,7 +560,7 @@ class CallNode : public ExprNode { bool is_vectorizable() const; static constexpr const char* _type_key = "Call"; - TVM_DECLARE_FINAL_OBJECT_INFO(CallNode, ExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(CallNode, PrimExprNode); // Build-in intrinsics static constexpr const char* reinterpret = "reinterpret"; @@ -585,24 +585,24 @@ class CallNode : public ExprNode { * vec = concat(vectors) * result = (vec[indices[0]], vec[indices[1]] ...) */ -class ShuffleNode : public ExprNode { +class ShuffleNode : public PrimExprNode { public: /*! \brief the input vectors. */ - Array vectors; + Array vectors; /*! \brief The indices of each element. */ - Array indices; + Array indices; void VisitAttrs(AttrVisitor* v) { v->Visit("vectors", &vectors); v->Visit("indices", &indices); } - TVM_DLL static Expr make(Array vectors, Array indices); - TVM_DLL static Expr make_concat(Array vectors); - TVM_DLL static Expr make_extract_element(Expr vector, int index); + TVM_DLL static PrimExpr make(Array vectors, Array indices); + TVM_DLL static PrimExpr make_concat(Array vectors); + TVM_DLL static PrimExpr make_extract_element(PrimExpr vector, int index); static constexpr const char* _type_key = "Shuffle"; - TVM_DECLARE_FINAL_OBJECT_INFO(ShuffleNode, ExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(ShuffleNode, PrimExprNode); }; // Reduce operator @@ -637,20 +637,20 @@ class CommReducerNode : public Object { /*! \brief The right argument of reducer */ Array rhs; /*! \brief The result of reducer */ - Array result; + Array result; /*! * \brief The identity element of reducer, which leaves other * elements unchanged when combined with it, with respect to * the binary operation of this reducer uses. */ - Array identity_element; + Array identity_element; /*! \brief Function call operator to combine a and b */ - Array operator()(Array a, Array b) const; + Array operator()(Array a, Array b) const; /*! \brief construct CommReducer from args, result and identity_element */ TVM_DLL static CommReducer make(Array lhs, Array rhs, - Array result, - Array identity_element); + Array result, + Array identity_element); void VisitAttrs(AttrVisitor* v) { v->Visit("lhs", &lhs); @@ -671,27 +671,27 @@ inline const CommReducerNode* CommReducer::operator->() const { } /*! \brief Reduction operator operator */ -class ReduceNode : public ExprNode { +class ReduceNode : public PrimExprNode { public: /*! \brief The commutative combiner */ CommReducer combiner; /*! \brief The source operand */ - Array source; + Array source; /*! \brief The reduction axis */ Array axis; /*! * \brief Predicate on the reduction * Only add the body to reduction if condition is true. */ - Expr condition; + PrimExpr condition; /*! \brief the index of this reduce node */ int value_index; /*! \brief construct expr from op and rdom */ - TVM_DLL static Expr make(CommReducer combiner, - Array src, + TVM_DLL static PrimExpr make(CommReducer combiner, + Array src, Array rdom, - Expr condition, + PrimExpr condition, int value_index); void VisitAttrs(AttrVisitor* v) { @@ -704,11 +704,11 @@ class ReduceNode : public ExprNode { } static constexpr const char* _type_key = "Reduce"; - TVM_DECLARE_FINAL_OBJECT_INFO(ReduceNode, ExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(ReduceNode, PrimExprNode); }; /*! \brief Any shape. */ -class AnyNode : public ExprNode { +class AnyNode : public PrimExprNode { public: void VisitAttrs(AttrVisitor* v) {} /*! \brief Convert to var. */ @@ -716,10 +716,10 @@ class AnyNode : public ExprNode { return VarNode::make(DataType::Int(32), "any_dim"); } - TVM_DLL static Expr make(); + TVM_DLL static PrimExpr make(); static constexpr const char* _type_key = "Any"; - TVM_DECLARE_FINAL_OBJECT_INFO(AnyNode, ExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(AnyNode, PrimExprNode); }; // Statements @@ -731,7 +731,7 @@ class LetStmtNode : public StmtNode { /*! \brief The variable. */ Var var; /*! \brief The value to be binded. */ - Expr value; + PrimExpr value; /*! \brief The body block. */ Stmt body; @@ -741,7 +741,7 @@ class LetStmtNode : public StmtNode { v->Visit("body", &body); } - TVM_DLL static Stmt make(Var var, Expr value, Stmt body); + TVM_DLL static Stmt make(Var var, PrimExpr value, Stmt body); static constexpr const char* _type_key = "LetStmt"; TVM_DECLARE_FINAL_OBJECT_INFO(LetStmtNode, StmtNode); @@ -764,7 +764,7 @@ class AttrStmtNode : public StmtNode { /*! \brief the type key of the attribute */ std::string attr_key; /*! \brief The attribute value, value is well defined at current scope. */ - Expr value; + PrimExpr value; /*! \brief The body statement to be executed */ Stmt body; @@ -777,7 +777,7 @@ class AttrStmtNode : public StmtNode { TVM_DLL static Stmt make(ObjectRef node, std::string type_key, - Expr value, + PrimExpr value, Stmt body); static constexpr const char* _type_key = "AttrStmt"; @@ -790,9 +790,9 @@ class AttrStmtNode : public StmtNode { class AssertStmtNode : public StmtNode { public: /*! \brief Condition to be checked. */ - Expr condition; + PrimExpr condition; /*! \brief Error message when assertion failed. */ - Expr message; + PrimExpr message; /*! * \brief Body which this assertion holds true. * Will be executed after the assertion. @@ -805,7 +805,7 @@ class AssertStmtNode : public StmtNode { v->Visit("body", &body); } - TVM_DLL static Stmt make(Expr condition, Expr message, Stmt body); + TVM_DLL static Stmt make(PrimExpr condition, PrimExpr message, Stmt body); static constexpr const char* _type_key = "AssertStmt"; TVM_DECLARE_FINAL_OBJECT_INFO(AssertStmtNode, StmtNode); @@ -857,11 +857,11 @@ class StoreNode : public StmtNode { /*! \brief The buffer variable. */ Var buffer_var; /*! \brief The value to be stored. */ - Expr value; + PrimExpr value; /*! \brief The index locations to be stored. */ - Expr index; + PrimExpr index; /*! \brief The predicate to mask which lanes would be stored. */ - Expr predicate; + PrimExpr predicate; void VisitAttrs(AttrVisitor* v) { v->Visit("buffer_var", &buffer_var); @@ -871,9 +871,9 @@ class StoreNode : public StmtNode { } TVM_DLL static Stmt make(Var buffer_var, - Expr value, - Expr index, - Expr predicate); + PrimExpr value, + PrimExpr index, + PrimExpr predicate); static constexpr const char* _type_key = "Store"; TVM_DECLARE_FINAL_OBJECT_INFO(StoreNode, StmtNode); @@ -889,9 +889,9 @@ class ProvideNode : public StmtNode { /*! \brief The output value index if func's value is a tuple. */ int value_index{0}; /*! \brief The value to be stored. */ - Expr value; + PrimExpr value; /*! \brief The index arguments of the function. */ - Array args; + Array args; void VisitAttrs(AttrVisitor* v) { v->Visit("func", &func); @@ -902,8 +902,8 @@ class ProvideNode : public StmtNode { TVM_DLL static Stmt make(FunctionRef func, int value_index, - Expr value, - Array args); + PrimExpr value, + Array args); static constexpr const char* _type_key = "Provide"; TVM_DECLARE_FINAL_OBJECT_INFO(ProvideNode, StmtNode); @@ -919,14 +919,14 @@ class AllocateNode : public StmtNode { /*! \brief The type of the buffer. */ DataType dtype; /*! \brief The extents of the buffer. */ - Array extents; + Array extents; /*! \brief Only allocate buffer when condition is satisfied. */ - Expr condition; + PrimExpr condition; /*! \brief The body to be executed. */ Stmt body; // The following two fields are deprecated // kept for backward compatibility and will be refactored later. - Expr new_expr; + PrimExpr new_expr; std::string free_function; void VisitAttrs(AttrVisitor* v) { @@ -939,10 +939,10 @@ class AllocateNode : public StmtNode { TVM_DLL static Stmt make(Var buffer_var, DataType dtype, - Array extents, - Expr condition, + Array extents, + PrimExpr condition, Stmt body, - Expr new_expr = Expr(), + PrimExpr new_expr = PrimExpr(), std::string free_function = std::string()); /*! @@ -960,7 +960,7 @@ class AllocateNode : public StmtNode { * \return The result. */ TVM_DLL static int32_t constant_allocation_size( - const Array& extents); + const Array& extents); static constexpr const char* _type_key = "Allocate"; TVM_DECLARE_FINAL_OBJECT_INFO(AllocateNode, StmtNode); @@ -997,7 +997,7 @@ class RealizeNode : public StmtNode { /*! \brief Bounds to be realized. */ Region bounds; /*! \brief Only realize if condition holds. */ - Expr condition; + PrimExpr condition; /*! \brief The body of realization. */ Stmt body; @@ -1014,7 +1014,7 @@ class RealizeNode : public StmtNode { int value_index, DataType dtype, Region bounds, - Expr condition, + PrimExpr condition, Stmt body); static constexpr const char* _type_key = "Realize"; @@ -1136,7 +1136,7 @@ class SeqStmt : public Stmt { class IfThenElseNode : public StmtNode { public: /*! \brief The condition. */ - Expr condition; + PrimExpr condition; /*! \brief The branch to be executed when condition is true. */ Stmt then_case; /*! \brief The branch to be executed when condition is false, can be null. */ @@ -1148,7 +1148,7 @@ class IfThenElseNode : public StmtNode { v->Visit("else_case", &else_case); } - TVM_DLL static Stmt make(Expr condition, Stmt then_case, Stmt else_case = Stmt()); + TVM_DLL static Stmt make(PrimExpr condition, Stmt then_case, Stmt else_case = Stmt()); static constexpr const char* _type_key = "IfThenElse"; TVM_DECLARE_FINAL_OBJECT_INFO(IfThenElseNode, StmtNode); @@ -1163,13 +1163,13 @@ class IfThenElseNode : public StmtNode { class EvaluateNode : public StmtNode { public: /*! \brief The expression to be evaluated. */ - Expr value; + PrimExpr value; void VisitAttrs(AttrVisitor* v) { v->Visit("value", &value); } - TVM_DLL static Stmt make(Expr v); + TVM_DLL static Stmt make(PrimExpr v); static constexpr const char* _type_key = "Evaluate"; TVM_DECLARE_FINAL_OBJECT_INFO(EvaluateNode, StmtNode); @@ -1209,9 +1209,9 @@ class ForNode : public StmtNode { /*! \brief The loop variable. */ Var loop_var; /*! \brief The minimum value of iteration. */ - Expr min; + PrimExpr min; /*! \brief The extent of the iteration. */ - Expr extent; + PrimExpr extent; /*! \brief The type of the for loop. */ ForType for_type; /*! @@ -1223,8 +1223,8 @@ class ForNode : public StmtNode { Stmt body; TVM_DLL static Stmt make(Var loop_var, - Expr min, - Expr extent, + PrimExpr min, + PrimExpr extent, ForType for_type, DeviceAPI device_api, Stmt body); @@ -1707,7 +1707,7 @@ constexpr const char* tvm_store_matrix_sync = "tvm_store_matrix_sync"; * \param dtype The data type * \return Expr a expression with dtype. */ -inline Expr TypeAnnotation(DataType dtype) { +inline PrimExpr TypeAnnotation(DataType dtype) { return ir::CallNode::make(dtype, "type_annotation", {}, ir::CallNode::PureIntrinsic); diff --git a/include/tvm/ir_functor_ext.h b/include/tvm/ir_functor_ext.h index d70c8dec7689..7d57564fd3df 100644 --- a/include/tvm/ir_functor_ext.h +++ b/include/tvm/ir_functor_ext.h @@ -102,9 +102,9 @@ class StmtFunctor; }); \ template -class ExprFunctor { +class ExprFunctor { private: - using TSelf = ExprFunctor; + using TSelf = ExprFunctor; using FType = NodeFunctor; public: @@ -118,7 +118,7 @@ class ExprFunctor { * \param args Additional arguments. * \return The result of the call */ - R operator()(const Expr& n, Args... args) { + R operator()(const PrimExpr& n, Args... args) { return VisitExpr(n, std::forward(args)...); } /*! @@ -127,7 +127,7 @@ class ExprFunctor { * \param args Additional arguments. * \return The result of the call */ - virtual R VisitExpr(const Expr& n, Args... args) { + virtual R VisitExpr(const PrimExpr& n, Args... args) { static FType vtable = InitVTable(); return vtable(n, this, std::forward(args)...); } @@ -291,7 +291,7 @@ class StmtFunctor { * \brief ExprVisitor */ class TVM_DLL ExprVisitor : - public ExprFunctor { + public ExprFunctor { public: using ExprFunctor::operator(); @@ -336,45 +336,45 @@ class TVM_DLL ExprVisitor : * \brief ExprMutator that mutates expressions. */ class TVM_DLL ExprMutator : - protected ExprFunctor { + protected ExprFunctor { public: using ExprFunctor::operator(); protected: using ExprFunctor::VisitExpr; // list of functions to override. - Expr VisitExpr_(const VarNode* op) override; - Expr VisitExpr_(const LoadNode* op) override; - Expr VisitExpr_(const LetNode* op) override; - Expr VisitExpr_(const CallNode* op) override; - Expr VisitExpr_(const AddNode* op) override; - Expr VisitExpr_(const SubNode* op) override; - Expr VisitExpr_(const MulNode* op) override; - Expr VisitExpr_(const DivNode* op) override; - Expr VisitExpr_(const ModNode* op) override; - Expr VisitExpr_(const FloorDivNode* op) override; - Expr VisitExpr_(const FloorModNode* op) override; - Expr VisitExpr_(const MinNode* op) override; - Expr VisitExpr_(const MaxNode* op) override; - Expr VisitExpr_(const EQNode* op) override; - Expr VisitExpr_(const NENode* op) override; - Expr VisitExpr_(const LTNode* op) override; - Expr VisitExpr_(const LENode* op) override; - Expr VisitExpr_(const GTNode* op) override; - Expr VisitExpr_(const GENode* op) override; - Expr VisitExpr_(const AndNode* op) override; - Expr VisitExpr_(const OrNode* op) override; - Expr VisitExpr_(const ReduceNode* op) override; - Expr VisitExpr_(const CastNode* op) override; - Expr VisitExpr_(const NotNode* op) override; - Expr VisitExpr_(const SelectNode* op) override; - Expr VisitExpr_(const RampNode* op) override; - Expr VisitExpr_(const BroadcastNode* op) override; - Expr VisitExpr_(const ShuffleNode* op) override; - Expr VisitExpr_(const IntImmNode* op) override; - Expr VisitExpr_(const UIntImmNode* op) override; - Expr VisitExpr_(const FloatImmNode* op) override; - Expr VisitExpr_(const StringImmNode* op) override; + PrimExpr VisitExpr_(const VarNode* op) override; + PrimExpr VisitExpr_(const LoadNode* op) override; + PrimExpr VisitExpr_(const LetNode* op) override; + PrimExpr VisitExpr_(const CallNode* op) override; + PrimExpr VisitExpr_(const AddNode* op) override; + PrimExpr VisitExpr_(const SubNode* op) override; + PrimExpr VisitExpr_(const MulNode* op) override; + PrimExpr VisitExpr_(const DivNode* op) override; + PrimExpr VisitExpr_(const ModNode* op) override; + PrimExpr VisitExpr_(const FloorDivNode* op) override; + PrimExpr VisitExpr_(const FloorModNode* op) override; + PrimExpr VisitExpr_(const MinNode* op) override; + PrimExpr VisitExpr_(const MaxNode* op) override; + PrimExpr VisitExpr_(const EQNode* op) override; + PrimExpr VisitExpr_(const NENode* op) override; + PrimExpr VisitExpr_(const LTNode* op) override; + PrimExpr VisitExpr_(const LENode* op) override; + PrimExpr VisitExpr_(const GTNode* op) override; + PrimExpr VisitExpr_(const GENode* op) override; + PrimExpr VisitExpr_(const AndNode* op) override; + PrimExpr VisitExpr_(const OrNode* op) override; + PrimExpr VisitExpr_(const ReduceNode* op) override; + PrimExpr VisitExpr_(const CastNode* op) override; + PrimExpr VisitExpr_(const NotNode* op) override; + PrimExpr VisitExpr_(const SelectNode* op) override; + PrimExpr VisitExpr_(const RampNode* op) override; + PrimExpr VisitExpr_(const BroadcastNode* op) override; + PrimExpr VisitExpr_(const ShuffleNode* op) override; + PrimExpr VisitExpr_(const IntImmNode* op) override; + PrimExpr VisitExpr_(const UIntImmNode* op) override; + PrimExpr VisitExpr_(const FloatImmNode* op) override; + PrimExpr VisitExpr_(const StringImmNode* op) override; }; /*! @@ -394,7 +394,7 @@ class TVM_DLL StmtVisitor : * or have a class sub-class both StmtVisitor and ExprVisitor * and redirect Visit to ExprMutator::VisitExpr(Expr) */ - virtual void VisitExpr(const Expr& e) {} + virtual void VisitExpr(const PrimExpr& e) {} // statement visitor void VisitStmt_(const AttrStmtNode* op) override; void VisitStmt_(const IfThenElseNode* op) override; @@ -486,7 +486,7 @@ class TVM_DLL StmtMutator : * or have a class sub-class both StmtMutator and ExprMutator * and redirect Mutate to ExprMutator::Mutate(Expr) */ - virtual Expr VisitExpr(const Expr& e) { + virtual PrimExpr VisitExpr(const PrimExpr& e) { return e; } // statement visitor @@ -537,7 +537,7 @@ class StmtExprVisitor : using StmtVisitor::VisitStmt; using ExprVisitor::VisitExpr; - void VisitExpr(const Expr& e) override { + void VisitExpr(const PrimExpr& e) override { return ExprVisitor::VisitExpr(e); } }; @@ -556,7 +556,7 @@ class StmtExprMutator : using StmtMutator::VisitExpr; using ExprMutator::VisitExpr; - Expr VisitExpr(const Expr& e) override { + PrimExpr VisitExpr(const PrimExpr& e) override { return ExprMutator::VisitExpr(e); } }; @@ -579,7 +579,7 @@ class StmtExprMutator : TVM_DLL Stmt IRTransform(Stmt node, const runtime::PackedFunc& preorder, const runtime::PackedFunc& postorder, - const Array& only_enable = {}); + const Array& only_enable = {}); /*! * \brief recursively visit the ir in post DFS order node, apply fvisit diff --git a/include/tvm/ir_pass.h b/include/tvm/ir_pass.h index aa1415ef206c..36ca03f5bed0 100644 --- a/include/tvm/ir_pass.h +++ b/include/tvm/ir_pass.h @@ -45,7 +45,7 @@ namespace ir { * \param vrange The range information about the variable. * \return Canonicalized statement. */ -TVM_DLL Expr Simplify(Expr expr, Map vrange = Map()); +TVM_DLL PrimExpr Simplify(PrimExpr expr, Map vrange = Map()); /*! * \brief Simplify the statement. @@ -70,7 +70,7 @@ Stmt CanonicalSimplify(Stmt stmt, * \param vrange The range information about the variable. * \return Canonicalized expression. */ -TVM_DLL Expr CanonicalSimplify(Expr expr, +TVM_DLL PrimExpr CanonicalSimplify(PrimExpr expr, Map vrange = Map()); /*! @@ -79,7 +79,7 @@ TVM_DLL Expr CanonicalSimplify(Expr expr, * \param rhs The right operand * \return The comparison result. */ -TVM_DLL bool Equal(const Expr& lhs, const Expr& rhs); +TVM_DLL bool Equal(const PrimExpr& lhs, const PrimExpr& rhs); /*! * \brief Deep compare lhs and rhs @@ -100,7 +100,7 @@ bool Equal(const Stmt& lhs, const Stmt& rhs); * \param rhs The right operand * \return The comparison result. */ -int Compare(const Expr& lhs, const Expr& rhs); +int Compare(const PrimExpr& lhs, const PrimExpr& rhs); /*! * \brief verifies whether the IR stmt or Expr is in SSA form. @@ -116,7 +116,7 @@ TVM_DLL bool VerifySSA(const Stmt& ir); * \brief Whether the expression have side effect. * \return whether expression have side effect */ -TVM_DLL bool HasSideEffect(const Expr& e); +TVM_DLL bool HasSideEffect(const PrimExpr& e); /*! * \brief Whether e expression used var. @@ -124,7 +124,7 @@ TVM_DLL bool HasSideEffect(const Expr& e); * \param v The variable. * \return Whether e uses v. */ -bool ExprUseVar(const Expr& e, const Var& v); +bool ExprUseVar(const PrimExpr& e, const Var& v); /*! * \brief Whether e expression used any var in variable set.. @@ -132,7 +132,7 @@ bool ExprUseVar(const Expr& e, const Var& v); * \param vset The variable set. * \return Whether e uses vset. */ -bool ExprUseVar(const Expr& e, const std::unordered_set& vset); +bool ExprUseVar(const PrimExpr& e, const std::unordered_set& vset); /*! * \brief Convert a IR node to be SSA form. @@ -148,7 +148,7 @@ TVM_DLL Stmt ConvertSSA(Stmt stmt); * \return The converted form. */ Stmt Substitute(Stmt stmt, - const std::unordered_map& value_map); + const std::unordered_map& value_map); /*! * \brief Substitute the var specified in key->var to be value. @@ -156,8 +156,8 @@ Stmt Substitute(Stmt stmt, * \param value_map The map of new values. * \return The converted expression. */ -Expr Substitute(Expr expr, - const std::unordered_map& value_map); +PrimExpr Substitute(PrimExpr expr, + const std::unordered_map& value_map); /*! * \brief Substitute the var specified in key->var to be value. @@ -165,7 +165,7 @@ Expr Substitute(Expr expr, * \param value_map The map of new values. * \return The converted form. */ -Stmt Substitute(Stmt stmt, const Map& value_map); +Stmt Substitute(Stmt stmt, const Map& value_map); /*! * \brief Substitute the var specified in key->var to be value. @@ -173,7 +173,7 @@ Stmt Substitute(Stmt stmt, const Map& value_map); * \param value_map The map of new values. * \return The converted expression. */ -Expr Substitute(Expr expr, const Map& value_map); +PrimExpr Substitute(PrimExpr expr, const Map& value_map); /*! * \brief inline all calls of f in stmt. @@ -189,7 +189,7 @@ Expr Substitute(Expr expr, const Map& value_map); Stmt Inline(Stmt stmt, FunctionRef f, Array args, - Expr body); + PrimExpr body); /*! * \brief Flatten the multi-dimensional read/write @@ -485,7 +485,7 @@ LoweredFunc LowerWarpMemory(LoweredFunc f, int warp_size); * \param axis_map The map from StringImm -> ItrVar * \return Transformed function. */ -LoweredFunc RemapThreadAxis(LoweredFunc f, Map axis_map); +LoweredFunc RemapThreadAxis(LoweredFunc f, Map axis_map); /*! * \brief Lower packed function call. @@ -590,7 +590,7 @@ bool VerifyMemory(LoweredFunc func, int device_type); * */ bool VerifyGPUCode(Stmt stmt, - Map constraints); + Map constraints); } // namespace ir diff --git a/include/tvm/lowered_func.h b/include/tvm/lowered_func.h index 3de6bfdbb087..310b454c15d3 100644 --- a/include/tvm/lowered_func.h +++ b/include/tvm/lowered_func.h @@ -98,7 +98,7 @@ class LoweredFuncNode : public ir::FunctionBaseNode { * \note Expr is used instead Type, because Type cannot be hold by Map. * constant Expr of given type is used. */ - Map handle_data_type; + Map handle_data_type; /*! \brief The type of the function */ LoweredFuncType func_type{kMixedFunc}; /*! \brief Whether this function is packed function */ diff --git a/include/tvm/operation.h b/include/tvm/operation.h index ad8f8259c016..039e26ecfe90 100644 --- a/include/tvm/operation.h +++ b/include/tvm/operation.h @@ -81,7 +81,7 @@ class OperationNode : public ir::FunctionBaseNode { * \param i The output index. * \return shape of i-th output. */ - virtual Array output_shape(size_t i) const = 0; + virtual Array output_shape(size_t i) const = 0; /*! * \brief List all the input Tensors. * \return List of input tensors. @@ -158,14 +158,14 @@ class OperationNode : public ir::FunctionBaseNode { class PlaceholderOpNode : public OperationNode { public: /*! \brief The shape of the input */ - Array shape; + Array shape; /*! \brief The data type of the input. */ DataType dtype; // override behavior. int num_outputs() const final; Array root_iter_vars() const final; DataType output_dtype(size_t i) const final; - Array output_shape(size_t i) const final; + Array output_shape(size_t i) const final; Array InputTensors() const final; Operation ReplaceInputs( const Operation& self, @@ -196,7 +196,7 @@ class PlaceholderOpNode : public OperationNode { v->Visit("dtype", &dtype); } static Operation make(std::string name, - Array shape, + Array shape, DataType dtype); static constexpr const char* _type_key = "PlaceholderOp"; @@ -216,7 +216,7 @@ class TVM_DLL BaseComputeOpNode : public OperationNode { Array reduce_axis; // override functions Array root_iter_vars() const final; - Array output_shape(size_t idx) const final; + Array output_shape(size_t idx) const final; void GatherBound( const Operation& self, const std::unordered_map& tensor_dom, @@ -238,7 +238,7 @@ class TVM_DLL BaseComputeOpNode : public OperationNode { class TVM_DLL ComputeOpNode : public BaseComputeOpNode { public: /*! \brief the compute expression */ - Array body; + Array body; /*! \brief constructor */ ComputeOpNode() {} // override functions @@ -271,7 +271,7 @@ class TVM_DLL ComputeOpNode : public BaseComputeOpNode { std::string tag, Map attrs, Array axis, - Array body); + Array body); static constexpr const char* _type_key = "ComputeOp"; TVM_DECLARE_FINAL_OBJECT_INFO(ComputeOpNode, BaseComputeOpNode); @@ -291,7 +291,7 @@ class TensorComputeOpNode : public BaseComputeOpNode { /*! \brief region of input tensors */ Array input_regions; /*! \brief scalar expression inputs */ - Array scalar_inputs; + Array scalar_inputs; /*! \brief constructor */ TensorComputeOpNode() {} // override functions @@ -331,7 +331,7 @@ class TensorComputeOpNode : public BaseComputeOpNode { TensorIntrin intrin, Array tensors, Array regions, - Array scalar_inputs); + Array scalar_inputs); static constexpr const char* _type_key = "TensorComputeOp"; TVM_DECLARE_FINAL_OBJECT_INFO(TensorComputeOpNode, BaseComputeOpNode); @@ -371,7 +371,7 @@ class ScanOpNode : public OperationNode { int num_outputs() const final; Array root_iter_vars() const final; DataType output_dtype(size_t i) const final; - Array output_shape(size_t i) const final; + Array output_shape(size_t i) const final; Array InputTensors() const final; Operation ReplaceInputs( const Operation& self, @@ -438,7 +438,7 @@ class ExternOpNode : public OperationNode { int num_outputs() const final; Array root_iter_vars() const final; DataType output_dtype(size_t i) const final; - Array output_shape(size_t i) const final; + Array output_shape(size_t i) const final; Array InputTensors() const final; Operation ReplaceInputs( const Operation& self, @@ -506,7 +506,7 @@ class HybridOpNode : public OperationNode { int num_outputs() const final; Array root_iter_vars() const final; DataType output_dtype(size_t i) const final; - Array output_shape(size_t i) const final; + Array output_shape(size_t i) const final; Array InputTensors() const final; Operation ReplaceInputs( const Operation& self, @@ -550,10 +550,10 @@ class HybridOpNode : public OperationNode { }; /*! \brief The compute function to specify the input source of a Tensor */ -using FCompute = std::function& i)>; +using FCompute = std::function& i)>; /*! \brief The compute function to specify the inputs source of Tensors */ -using FBatchCompute = std::function (const Array& i)>; +using FBatchCompute = std::function (const Array& i)>; /*! * \brief create a place holder tensor. @@ -561,7 +561,7 @@ using FBatchCompute = std::function (const Array& i)>; * \param dtype the data type of the tensor. * \param name The name of the Tensor. */ -TVM_DLL Tensor placeholder(Array shape, +TVM_DLL Tensor placeholder(Array shape, DataType dtype = DataType::Float(32), std::string name = "placeholder"); @@ -574,7 +574,7 @@ TVM_DLL Tensor placeholder(Array shape, * \param tag The optional tag of the tensor. * \param attrs Optional additional attributes of the compute. */ -TVM_DLL Tensor compute(Array shape, +TVM_DLL Tensor compute(Array shape, FCompute fcompute, std::string name = "tensor", std::string tag = "", @@ -589,7 +589,7 @@ TVM_DLL Tensor compute(Array shape, * \param tag The optional tag of the tensor. * \param attrs Optional additional attributes of the compute. */ -TVM_DLL Array compute(Array shape, +TVM_DLL Array compute(Array shape, FBatchCompute fcompute, std::string name = "tensor", std::string tag = "", @@ -616,32 +616,32 @@ TVM_DLL Array scan(Array init, Map attrs = {}); // same as compute, specialized for different fcompute function -inline Tensor compute(Array shape, - std::function f, +inline Tensor compute(Array shape, + std::function f, std::string name = "tensor", std::string tag = "", Map attrs = {}) { FCompute fc = [f] (const Array& i) { return f(i[0]); }; return compute(shape, fc, name, tag, attrs); } -inline Tensor compute(Array shape, - std::function f, +inline Tensor compute(Array shape, + std::function f, std::string name = "tensor", std::string tag = "", Map attrs = {}) { FCompute fc = [f] (const Array& i) { return f(i[0], i[1]); }; return compute(shape, fc, name, tag, attrs); } -inline Tensor compute(Array shape, - std::function f, +inline Tensor compute(Array shape, + std::function f, std::string name = "tensor", std::string tag = "", Map attrs = {}) { FCompute fc = [f] (const Array& i) { return f(i[0], i[1], i[2]); }; return compute(shape, fc, name, tag, attrs); } -inline Tensor compute(Array shape, - std::function f, +inline Tensor compute(Array shape, + std::function f, std::string name = "tensor", std::string tag = "", Map attrs = {}) { diff --git a/include/tvm/packed_func_ext.h b/include/tvm/packed_func_ext.h index b301a18ea313..fa532eae36fa 100644 --- a/include/tvm/packed_func_ext.h +++ b/include/tvm/packed_func_ext.h @@ -100,15 +100,15 @@ struct ObjectTypeChecker > { }; // extensions for tvm arg value -inline TVMPODValue_::operator tvm::Expr() const { - if (type_code_ == kNull) return Expr(); +inline TVMPODValue_::operator tvm::PrimExpr() const { + if (type_code_ == kNull) return PrimExpr(); if (type_code_ == kDLInt) { CHECK_LE(value_.v_int64, std::numeric_limits::max()); CHECK_GE(value_.v_int64, std::numeric_limits::min()); - return Expr(static_cast(value_.v_int64)); + return PrimExpr(static_cast(value_.v_int64)); } if (type_code_ == kDLFloat) { - return Expr(static_cast(value_.v_float64)); + return PrimExpr(static_cast(value_.v_float64)); } TVM_CHECK_TYPE_CODE(type_code_, kObjectHandle); @@ -120,10 +120,10 @@ inline TVMPODValue_::operator tvm::Expr() const { if (ptr->IsInstance()) { return Tensor(ObjectPtr(ptr))(); } - CHECK(ObjectTypeChecker::Check(ptr)) - << "Expect type " << ObjectTypeChecker::TypeName() + CHECK(ObjectTypeChecker::Check(ptr)) + << "Expect type " << ObjectTypeChecker::TypeName() << " but get " << ptr->GetTypeKey(); - return Expr(ObjectPtr(ptr)); + return PrimExpr(ObjectPtr(ptr)); } inline TVMPODValue_::operator tvm::Integer() const { @@ -136,7 +136,7 @@ inline TVMPODValue_::operator tvm::Integer() const { TVM_CHECK_TYPE_CODE(type_code_, kObjectHandle); Object* ptr = static_cast(value_.v_handle); CHECK(ObjectTypeChecker::Check(ptr)) - << "Expect type " << ObjectTypeChecker::TypeName() + << "Expect type " << ObjectTypeChecker::TypeName() << " but get " << ptr->GetTypeKey(); return Integer(ObjectPtr(ptr)); } diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h index 1c7fc1c45480..26637d5fd0f2 100644 --- a/include/tvm/relay/attrs/transform.h +++ b/include/tvm/relay/attrs/transform.h @@ -26,6 +26,7 @@ #include #include +#include #include namespace tvm { diff --git a/include/tvm/relay/base.h b/include/tvm/relay/base.h index b4164fbb578e..f2db652c555e 100644 --- a/include/tvm/relay/base.h +++ b/include/tvm/relay/base.h @@ -57,7 +57,7 @@ namespace relay { /*! * \brief Symbolic expression for tensor shape. */ -using IndexExpr = ::tvm::Expr; +using IndexExpr = ::tvm::PrimExpr; using SourceName = tvm::SourceName; using Span = tvm::Span; diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h index 2d1e45f8ee0f..1b6155f4c54c 100644 --- a/include/tvm/relay/transform.h +++ b/include/tvm/relay/transform.h @@ -95,9 +95,9 @@ class PassContextNode : public RelayNode { int fallback_device{static_cast(kDLCPU)}; /*! \brief The list of required passes. */ - tvm::Array required_pass; + tvm::Array required_pass; /*! \brief The list of disabled passes. */ - tvm::Array disabled_pass; + tvm::Array disabled_pass; PassContextNode() = default; @@ -192,7 +192,7 @@ class PassInfoNode : public RelayNode { std::string name; /*! \brief The passes that are required to perform the current pass. */ - tvm::Array required; + tvm::Array required; PassInfoNode() = default; @@ -204,7 +204,7 @@ class PassInfoNode : public RelayNode { TVM_DLL static PassInfo make(int opt_level, std::string name, - tvm::Array required); + tvm::Array required); static constexpr const char* _type_key = "relay.PassInfo"; TVM_DECLARE_FINAL_OBJECT_INFO(PassInfoNode, RelayNode); @@ -332,7 +332,7 @@ Pass CreateModulePass( const runtime::TypedPackedFunc& pass_func, int opt_level, const std::string& name, - const tvm::Array& required); + const tvm::Array& required); /* * \brief Create a function pass. @@ -348,7 +348,7 @@ TVM_DLL Pass CreateFunctionPass(const runtime::TypedPackedFunc< Function(Function, Module, PassContext)>& pass_func, int opt_level, const std::string& name, - const tvm::Array& required); + const tvm::Array& required); /*! \brief Remove expressions which does not effect the program result. * diff --git a/include/tvm/runtime/packed_func.h b/include/tvm/runtime/packed_func.h index 19c9434ed64e..afaff9c7364a 100644 --- a/include/tvm/runtime/packed_func.h +++ b/include/tvm/runtime/packed_func.h @@ -51,7 +51,7 @@ namespace tvm { // forward declarations class Integer; -class Expr; +class PrimExpr; namespace runtime { @@ -495,7 +495,7 @@ class TVMPODValue_ { template inline TObjectRef AsObjectRef() const; // ObjectRef Specializations - inline operator tvm::Expr() const; + inline operator tvm::PrimExpr() const; inline operator tvm::Integer() const; protected: @@ -542,7 +542,7 @@ class TVMArgValue : public TVMPODValue_ { using TVMPODValue_::operator Module; using TVMPODValue_::IsObjectRef; using TVMPODValue_::AsObjectRef; - using TVMPODValue_::operator tvm::Expr; + using TVMPODValue_::operator tvm::PrimExpr; using TVMPODValue_::operator tvm::Integer; // conversion operator. @@ -629,7 +629,7 @@ class TVMRetValue : public TVMPODValue_ { using TVMPODValue_::operator Module; using TVMPODValue_::IsObjectRef; using TVMPODValue_::AsObjectRef; - using TVMPODValue_::operator tvm::Expr; + using TVMPODValue_::operator tvm::PrimExpr; using TVMPODValue_::operator tvm::Integer; TVMRetValue(const TVMRetValue& other) : TVMPODValue_() { diff --git a/include/tvm/schedule.h b/include/tvm/schedule.h index 01caf5a02c91..7e1347536b07 100644 --- a/include/tvm/schedule.h +++ b/include/tvm/schedule.h @@ -112,7 +112,7 @@ class Stage : public ObjectRef { * \param predicate The condition to be checked. * \return reference to self. */ - TVM_DLL Stage& set_store_predicate(Expr predicate); + TVM_DLL Stage& set_store_predicate(PrimExpr predicate); /*! * \brief Specify environment threads that launched around the group's scope. * This can only be used in group stage. @@ -130,7 +130,7 @@ class Stage : public ObjectRef { * \param p_inner The result inner domain. * \return reference to self. */ - TVM_DLL Stage& split(IterVar parent, Expr factor, IterVar* p_outer, IterVar* p_inner); // NOLINT(*) + TVM_DLL Stage& split(IterVar parent, PrimExpr factor, IterVar* p_outer, IterVar* p_inner); // NOLINT(*) /*! * \brief Split the iteration with given number of parts. * @@ -140,7 +140,7 @@ class Stage : public ObjectRef { * \param p_inner The result inner domain. * \return reference to self. */ - TVM_DLL Stage& split_by_nparts(IterVar parent, Expr nparts, IterVar* p_outer, IterVar* p_inner); // NOLINT(*) + TVM_DLL Stage& split_by_nparts(IterVar parent, PrimExpr nparts, IterVar* p_outer, IterVar* p_inner); // NOLINT(*) /*! * \brief Fuse the inner outer domain to the target * \param outer The outer domain to be fused. @@ -185,7 +185,7 @@ class Stage : public ObjectRef { * \return reference to self. */ TVM_DLL Stage& tile(IterVar x_parent, IterVar y_parent, // NOLINT(*) - Expr x_factor, Expr y_factor, + PrimExpr x_factor, PrimExpr y_factor, IterVar* p_x_outer, IterVar* p_y_outer, IterVar* p_x_inner, IterVar* p_y_inner); /*! @@ -225,7 +225,7 @@ class Stage : public ObjectRef { */ TVM_DLL Stage& pragma(IterVar var, const std::string& pragma_type, - const Expr& pragma_value = Expr()); // NOLINT(*) + const PrimExpr& pragma_value = PrimExpr()); // NOLINT(*) /*! * \brief Fetch data in advance. * \param domain the tensor to be prefetched @@ -233,7 +233,7 @@ class Stage : public ObjectRef { * \param offset the number of iterations be to fetched in advance * \return reference to self */ - TVM_DLL Stage& prefetch(const Tensor &domain, IterVar var, Expr offset); //NOLINT(*) + TVM_DLL Stage& prefetch(const Tensor &domain, IterVar var, PrimExpr offset); //NOLINT(*) /*! * \brief Set alignment requirement for specific dimension. * @@ -468,7 +468,7 @@ class StageNode : public Object { * Use this when there can be duplicated threads doing the same store. * \note Experimental primitive: used by cross thread-reduction. */ - Expr store_predicate; + PrimExpr store_predicate; /*! \brief The relation bwteen of IterVars */ Array relations; /*! \brief additional attributes about iter var. */ @@ -598,7 +598,7 @@ class IterVarAttrNode : public Object { /*! \brief List of tensor to be prefetched in this loop */ Array prefetch_data; /*! \brief The offset used in each prefetch */ - Array prefetch_offset; + Array prefetch_offset; /*! * \brief Tensor intrinsic used in tensorization, * when the axis is marked as Tensorized @@ -611,11 +611,11 @@ class IterVarAttrNode : public Object { /*! * \brief Additional pragma keys, array of StringImm */ - Array pragma_keys; + Array pragma_keys; /*! * \brief Additional values of pragma, if any */ - Array pragma_values; + Array pragma_values; void VisitAttrs(AttrVisitor* v) { v->Visit("iter_type", &iter_type); @@ -653,9 +653,9 @@ class SplitNode : public IterVarRelationNode { /*! \brief The inner domain */ IterVar inner; /*! \brief The split factor */ - Expr factor; + PrimExpr factor; /*! \brief Number of parts, only factor or nparts can be given */ - Expr nparts; + PrimExpr nparts; void VisitAttrs(AttrVisitor* v) { v->Visit("parent", &parent); @@ -668,8 +668,8 @@ class SplitNode : public IterVarRelationNode { static IterVarRelation make(IterVar parent, IterVar outer, IterVar inner, - Expr factor, - Expr nparts); + PrimExpr factor, + PrimExpr nparts); static constexpr const char* _type_key = "Split"; TVM_DECLARE_FINAL_OBJECT_INFO(SplitNode, IterVarRelationNode); diff --git a/include/tvm/target_info.h b/include/tvm/target_info.h index 25fb7243eaf2..f7b36a06b9d6 100644 --- a/include/tvm/target_info.h +++ b/include/tvm/target_info.h @@ -45,7 +45,7 @@ struct MemoryInfoNode : public Object { * \brief head address of the buffer, if visible to CPU * This address can be None. */ - Expr head_address; + PrimExpr head_address; void VisitAttrs(AttrVisitor* v) { v->Visit("unit_bits", &unit_bits); diff --git a/include/tvm/tensor.h b/include/tvm/tensor.h index d6e93f567e50..ecadc3552923 100644 --- a/include/tvm/tensor.h +++ b/include/tvm/tensor.h @@ -76,8 +76,8 @@ class Tensor : public ObjectRef { * \return the result expression representing tensor read. */ template - inline Expr operator()(Args&& ...args) const { - Array indices{std::forward(args)...}; + inline PrimExpr operator()(Args&& ...args) const { + Array indices{std::forward(args)...}; return operator()(indices); } /*! @@ -85,13 +85,13 @@ class Tensor : public ObjectRef { * \param indices the indices. * \return the result expression representing tensor read. */ - TVM_DLL Expr operator()(Array indices) const; + TVM_DLL PrimExpr operator()(Array indices) const; /*! * \brief Take elements from the tensor * \param indices the indices. * \return the result expression representing tensor read. */ - TVM_DLL Expr operator()(Array indices) const; + TVM_DLL PrimExpr operator()(Array indices) const; /*! * \brief data structure to represent a slice that fixes first k coordinates. * This is used to enable syntax sugar of Tensor[x][y][z] to get the element. @@ -99,15 +99,15 @@ class Tensor : public ObjectRef { class Slice { public: // construct via tensor and indices - Slice(const Tensor& tensor, std::vector indices) + Slice(const Tensor& tensor, std::vector indices) : tensor_(tensor), indices_(indices) {} /*! * \brief get i-th slice from the current slice. * \param i the index of the coordinate * \return the subsequent slice. */ - inline Slice operator[](Expr i) { - std::vector other = indices_; + inline Slice operator[](PrimExpr i) { + std::vector other = indices_; other.emplace_back(i); return Slice(tensor_, other); } @@ -116,20 +116,20 @@ class Tensor : public ObjectRef { * This is only valid when all the coordinates are fully specified. * \return the corresponding expression of this slice. */ - inline operator Expr() const { + inline operator PrimExpr() const { return tensor_(indices_); } private: const Tensor& tensor_; - std::vector indices_; + std::vector indices_; }; /*! * \brief get i-th slice from the current Tensor. * \param i the index of the coordinate * \return the subsequent slice. */ - inline Slice operator[](Expr i) const { + inline Slice operator[](PrimExpr i) const { return Slice(*this, {i}); } /*! \brief specify container node */ @@ -161,7 +161,7 @@ class Operation : public ir::FunctionRef { class TensorNode : public Object { public: /*! \brief The shape of the tensor */ - Array shape; + Array shape; /*! \brief data type in the content of the tensor */ DataType dtype; /*! \brief the source operation, can be None */ @@ -177,7 +177,7 @@ class TensorNode : public Object { v->Visit("op", &op); v->Visit("value_index", &value_index); } - TVM_DLL static Tensor make(Array shape, + TVM_DLL static Tensor make(Array shape, DataType dtype, Operation op, int value_index); @@ -213,21 +213,21 @@ inline bool Tensor::operator!=(const Tensor& other) const { // macro to turn every operation of slice to expression #define DEFINE_OVERLOAD_SLICE_UNARY_OP(Op) \ - inline Expr operator Op (const Tensor::Slice& a) { \ - return Op a.operator Expr() ; \ + inline PrimExpr operator Op (const Tensor::Slice& a) { \ + return Op a.operator PrimExpr() ; \ } \ #define DEFINE_OVERLOAD_SLICE_BINARY_OP(Op) \ template \ - inline Expr operator Op (const Tensor::Slice& a, const T& b) { \ - return a.operator Expr() Op b; \ + inline PrimExpr operator Op (const Tensor::Slice& a, const T& b) { \ + return a.operator PrimExpr() Op b; \ } \ template \ - inline Expr operator Op (const T& a, const Tensor::Slice& b) { \ - return a Op b.operator Expr(); \ - } \ - inline Expr operator Op (const Tensor::Slice& a, const Tensor::Slice& b) { \ - return a.operator Expr() Op b.operator Expr(); \ + inline PrimExpr operator Op (const T& a, const Tensor::Slice& b) { \ + return a Op b.operator PrimExpr(); \ + } \ + inline PrimExpr operator Op (const Tensor::Slice& a, const Tensor::Slice& b) { \ + return a.operator PrimExpr() Op b.operator PrimExpr(); \ } DEFINE_OVERLOAD_SLICE_UNARY_OP(!); diff --git a/include/tvm/tensor_intrin.h b/include/tvm/tensor_intrin.h index f973909ae398..879e206c1365 100644 --- a/include/tvm/tensor_intrin.h +++ b/include/tvm/tensor_intrin.h @@ -150,7 +150,7 @@ class TensorIntrinCallNode : public Object { Array reduce_axis; /*! \brief scalar expression inputs */ - Array scalar_inputs; + Array scalar_inputs; void VisitAttrs(AttrVisitor* v) { v->Visit("intrin", &intrin); @@ -163,7 +163,7 @@ class TensorIntrinCallNode : public Object { Array tensors, Array regions, Array reduce_axis, - Array scalar_inputs); + Array scalar_inputs); static constexpr const char* _type_key = "TensorIntrinCall"; TVM_DECLARE_FINAL_OBJECT_INFO(TensorIntrinCallNode, Object); diff --git a/python/tvm/api.py b/python/tvm/api.py index ef121bc880b2..4d0e3472683c 100644 --- a/python/tvm/api.py +++ b/python/tvm/api.py @@ -256,7 +256,7 @@ def placeholder(shape, dtype=None, name="placeholder"): tensor: Tensor The created tensor """ - shape = (shape,) if isinstance(shape, _expr.Expr) else shape + shape = (shape,) if isinstance(shape, _expr.PrimExpr) else shape dtype = float32 if dtype is None else dtype return _api_internal._Placeholder( shape, dtype, name) @@ -293,7 +293,7 @@ def compute(shape, fcompute, name="compute", tag="", attrs=None): if tag != "": raise ValueError("nested tag is not allowed for now") tag = _tag.TagScope.get_current().tag - shape = (shape,) if isinstance(shape, _expr.Expr) else shape + shape = (shape,) if isinstance(shape, _expr.PrimExpr) else shape # for python3 shape = tuple([int(s) if isinstance(s, float) else s for s in shape]) ndim = len(shape) @@ -482,8 +482,8 @@ def extern(shape, if tag != "": raise ValueError("nested tag is not allowed for now") tag = _tag.TagScope.get_current().tag - shape = (shape,) if isinstance(shape, (_expr.Expr, _Integral)) else shape - if shape == () or isinstance(shape[0], (_expr.Expr, _Integral)): + shape = (shape,) if isinstance(shape, (_expr.PrimExpr, _Integral)) else shape + if shape == () or isinstance(shape[0], (_expr.PrimExpr, _Integral)): shape = [shape] if in_buffers is not None: in_buffers = [in_buffers] if not isinstance(in_buffers, list) else in_buffers @@ -518,7 +518,7 @@ def extern(shape, for shp, dt in zip(shape, dtype): output_placeholders.append(decl_buffer(shp, dt, name)) body = fcompute(input_placeholders, output_placeholders) - if isinstance(body, _expr.Expr): + if isinstance(body, _expr.PrimExpr): body = _make.Evaluate(body) op = _api_internal._ExternOp(name, tag, attrs, @@ -626,7 +626,7 @@ def decl_buffer(shape, If user pass a fully generic symbolic array to the strides, then the resulting function becomes fully generic. """ - shape = (shape,) if isinstance(shape, (_expr.Expr, _Integral)) else shape + shape = (shape,) if isinstance(shape, (_expr.PrimExpr, _Integral)) else shape dtype = float32 if dtype is None else dtype strides = () if strides is None else strides if offset_factor != 0 and elem_offset is None: @@ -827,7 +827,7 @@ def _make_reduce(expr, axis, where=None): result = fcombine(lhs, rhs) id_elem = fidentity(*dtypes) else: - assert isinstance(expr, _expr.Expr) + assert isinstance(expr, _expr.PrimExpr) size = 1 dtype = expr.dtype lvar = var(code.co_varnames[0], dtype) diff --git a/python/tvm/contrib/sparse.py b/python/tvm/contrib/sparse.py index b82da074576f..28c703da3262 100644 --- a/python/tvm/contrib/sparse.py +++ b/python/tvm/contrib/sparse.py @@ -167,7 +167,7 @@ def placeholder(shape, nonzeros=None, dtype=None, name="placeholder", stype=None tensor: SparsePlaceholderOp The created sparse tensor placeholder """ - shape = (shape,) if isinstance(shape, _expr.Expr) else shape + shape = (shape,) if isinstance(shape, _expr.PrimExpr) else shape nonzeros = 0 if nonzeros is None else nonzeros dtype = float32 if dtype is None else dtype stype = 'csr' if stype is None else stype diff --git a/python/tvm/expr.py b/python/tvm/expr.py index 733f57a68c56..c6b3d9b866e2 100644 --- a/python/tvm/expr.py +++ b/python/tvm/expr.py @@ -242,27 +242,27 @@ def asnode(self): return _make._OpNE(self.a, self.b) -class Expr(ExprOp, NodeBase): +class PrimExpr(ExprOp, NodeBase): """Base class of all tvm Expressions""" # In Python3, We have to explicitly tell interpreter to retain __hash__ if we overide __eq__ # https://docs.python.org/3.1/reference/datamodel.html#object.__hash__ __hash__ = NodeBase.__hash__ -class ConstExpr(Expr): +class ConstExpr(PrimExpr): pass -class BinaryOpExpr(Expr): +class BinaryOpExpr(PrimExpr): pass -class CmpExpr(Expr): +class CmpExpr(PrimExpr): pass -class LogicalExpr(Expr): +class LogicalExpr(PrimExpr): pass @register_node("Variable") -class Var(Expr): +class Var(PrimExpr): """Symbolic variable. Parameters @@ -279,7 +279,7 @@ def __init__(self, name, dtype): @register_node -class Reduce(Expr): +class Reduce(PrimExpr): """Reduce node. Parameters @@ -383,7 +383,7 @@ def __ne__(self, other): @register_node -class Cast(Expr): +class Cast(PrimExpr): """Cast expression. Parameters @@ -703,7 +703,7 @@ def __init__(self, a): @register_node -class Select(Expr): +class Select(PrimExpr): """Select node. Note @@ -731,7 +731,7 @@ def __init__(self, condition, true_value, false_value): @register_node -class Load(Expr): +class Load(PrimExpr): """Load node. Parameters @@ -754,7 +754,7 @@ def __init__(self, dtype, buffer_var, index, predicate): @register_node -class Ramp(Expr): +class Ramp(PrimExpr): """Ramp node. Parameters @@ -774,7 +774,7 @@ def __init__(self, base, stride, lanes): @register_node -class Broadcast(Expr): +class Broadcast(PrimExpr): """Broadcast node. Parameters @@ -791,7 +791,7 @@ def __init__(self, value, lanes): @register_node -class Shuffle(Expr): +class Shuffle(PrimExpr): """Shuffle node. Parameters @@ -808,7 +808,7 @@ def __init__(self, vectors, indices): @register_node -class Call(Expr): +class Call(PrimExpr): """Call node. Parameters @@ -843,7 +843,7 @@ def __init__(self, dtype, name, args, call_type, func, value_index): @register_node -class Let(Expr): +class Let(PrimExpr): """Let node. Parameters diff --git a/python/tvm/hybrid/calls.py b/python/tvm/hybrid/calls.py index 414822068e07..1d5612e67e80 100644 --- a/python/tvm/hybrid/calls.py +++ b/python/tvm/hybrid/calls.py @@ -91,7 +91,7 @@ def _allocate_tensor(func_id, args): "allocate's first argument should be a tuple of shape!") shape = args[0] for i in shape: - _internal_assert(isinstance(i, _expr.Expr), "The shape should be an expression") + _internal_assert(isinstance(i, _expr.PrimExpr), "The shape should be an expression") if n > 1: _internal_assert(isinstance(args[1], str), "The data type should be an str") @@ -125,7 +125,7 @@ def len(func_id, args): def _cast(func_id, args): - _internal_assert(args.__len__() == 1 and isinstance(args[0], _expr.Expr), \ + _internal_assert(args.__len__() == 1 and isinstance(args[0], _expr.PrimExpr), \ "Only one expression can be cast") return _make.Cast(func_id, args[0]) @@ -137,8 +137,8 @@ def _cast(func_id, args): def ceil_div(func_id, args): _internal_assert(func_id == "ceil_div", "This function cannot be directly invoked!") _internal_assert(args.__len__() == 2, "2 arguments expected for division!") - _internal_assert(isinstance(args[0], _expr.Expr), "Only expressions can div") - _internal_assert(isinstance(args[1], _expr.Expr), "Only expressions can div") + _internal_assert(isinstance(args[0], _expr.PrimExpr), "Only expressions can div") + _internal_assert(isinstance(args[1], _expr.PrimExpr), "Only expressions can div") a, b = args[0], args[1] return (a + b - 1) // b diff --git a/python/tvm/hybrid/parser.py b/python/tvm/hybrid/parser.py index 7e5659a8e9bb..06bcbcabe0c3 100644 --- a/python/tvm/hybrid/parser.py +++ b/python/tvm/hybrid/parser.py @@ -325,7 +325,7 @@ def visit_Assign(self, node): _internal_assert(len(node.targets) == 1, "So far only one-valued assignment is supported!") lhs = node.targets[0] - if isinstance(rhs, _expr.Expr): + if isinstance(rhs, _expr.PrimExpr): rhs = _ir_pass.Simplify(rhs) if isinstance(lhs, ast.Name): #TODO: support defined intermediate buffer later diff --git a/python/tvm/schedule.py b/python/tvm/schedule.py index 91fd291fb155..6b577c456fac 100644 --- a/python/tvm/schedule.py +++ b/python/tvm/schedule.py @@ -132,7 +132,7 @@ def vload(self, begin, dtype=None): load : Expr The corresponding load expression. """ - begin = (begin,) if isinstance(begin, (int, _expr.Expr)) else begin + begin = (begin,) if isinstance(begin, (int, _expr.PrimExpr)) else begin dtype = dtype if dtype else self.dtype return _api_internal._BufferVLoad(self, begin, dtype) @@ -152,7 +152,7 @@ def vstore(self, begin, value): store : Stmt The corresponding store stmt. """ - begin = (begin,) if isinstance(begin, (int, _expr.Expr)) else begin + begin = (begin,) if isinstance(begin, (int, _expr.PrimExpr)) else begin return _api_internal._BufferVStore(self, begin, value) diff --git a/python/tvm/tensor.py b/python/tvm/tensor.py index 1cadf0621823..e4a2f4f76e7b 100644 --- a/python/tvm/tensor.py +++ b/python/tvm/tensor.py @@ -65,7 +65,7 @@ def __call__(self, *indices): indices = convert_to_node(indices) args = [] for x in indices: - if isinstance(x, _expr.Expr): + if isinstance(x, _expr.PrimExpr): args.append(x) elif isinstance(x, iter_var_cls): args.append(x.var) diff --git a/python/tvm/tensor_intrin.py b/python/tvm/tensor_intrin.py index 2ef7a4bbb23e..378cfe51a7b7 100644 --- a/python/tvm/tensor_intrin.py +++ b/python/tvm/tensor_intrin.py @@ -133,9 +133,9 @@ def decl_tensor_intrin(op, else: body = fcompute(binds_list[:len(inputs)], binds_list[len(inputs):]) scalar_params = [] - if isinstance(body, (_expr.Expr, _stmt.Stmt)): + if isinstance(body, (_expr.PrimExpr, _stmt.Stmt)): body = [body] - body = [_make.Evaluate(x) if isinstance(x, _expr.Expr) else x for x in body] + body = [_make.Evaluate(x) if isinstance(x, _expr.PrimExpr) else x for x in body] if len(body) < 3: body += [None] * (3 - len(body)) return _api_internal._TensorIntrin( diff --git a/src/api/api_arith.cc b/src/api/api_arith.cc index a69fd5d436df..7150d2723ab7 100644 --- a/src/api/api_arith.cc +++ b/src/api/api_arith.cc @@ -49,7 +49,7 @@ TVM_REGISTER_GLOBAL("arith.DetectClipBound") TVM_REGISTER_GLOBAL("arith.DeduceBound") .set_body_typed([]( - Expr v, Expr cond, + PrimExpr v, PrimExpr cond, const Map hint_map, const Map relax_map ) { @@ -121,7 +121,7 @@ TVM_REGISTER_GLOBAL("arith._CreateAnalyzer") if (args[1].IsObjectRef()) { self->Bind(args[0], args[1].operator Range()); } else { - self->Bind(args[0], args[1].operator Expr()); + self->Bind(args[0], args[1].operator PrimExpr()); } }); } else if (name == "enter_constraint_context") { diff --git a/src/api/api_ir.cc b/src/api/api_ir.cc index ba04239dabb7..ca4823bc6b83 100644 --- a/src/api/api_ir.cc +++ b/src/api/api_ir.cc @@ -71,7 +71,7 @@ TVM_REGISTER_GLOBAL("make.SeqStmt") TVM_REGISTER_GLOBAL("make.For") .set_body_typed([]( - VarExpr loop_var, Expr min, Expr extent, + Var loop_var, PrimExpr min, PrimExpr extent, int for_type, int device_api, Stmt body) { return ForNode::make(loop_var, min, @@ -93,7 +93,7 @@ TVM_REGISTER_GLOBAL("make.Load") TVM_REGISTER_GLOBAL("make.Store") .set_body([](TVMArgs args, TVMRetValue *ret) { - Expr value = args[1]; + PrimExpr value = args[1]; if (args.size() == 3) { *ret = StoreNode::make(args[0], value, args[2], const_true(value.dtype().lanes())); } else { @@ -107,7 +107,7 @@ TVM_REGISTER_GLOBAL("make.Realize") TVM_REGISTER_GLOBAL("make.Call") .set_body_typed([]( DataType type, std::string name, - Array args, int call_type, + Array args, int call_type, FunctionRef func, int value_index ) { return CallNode::make(type, @@ -173,7 +173,7 @@ REGISTER_MAKE(Evaluate); // has default args TVM_REGISTER_GLOBAL("make.Allocate") .set_body_typed([]( - VarExpr buffer_var, DataType type, Array extents, Expr condition, Stmt body + Var buffer_var, DataType type, Array extents, PrimExpr condition, Stmt body ){ return AllocateNode::make(buffer_var, type, extents, condition, body); }); @@ -181,7 +181,7 @@ TVM_REGISTER_GLOBAL("make.Allocate") // operator overloading, smarter than make #define REGISTER_MAKE_BINARY_OP(Node, Func) \ TVM_REGISTER_GLOBAL("make."#Node) \ - .set_body_typed([](Expr a, Expr b) { \ + .set_body_typed([](PrimExpr a, PrimExpr b) { \ return (Func(a, b)); \ }) @@ -191,11 +191,11 @@ TVM_REGISTER_GLOBAL("make.Allocate") bool lhs_is_int = args[0].type_code() == kDLInt; \ bool rhs_is_int = args[1].type_code() == kDLInt; \ if (lhs_is_int) { \ - *ret = (Func(args[0].operator int(), args[1].operator Expr())); \ + *ret = (Func(args[0].operator int(), args[1].operator PrimExpr())); \ } else if (rhs_is_int) { \ - *ret = (Func(args[0].operator Expr(), args[1].operator int())); \ + *ret = (Func(args[0].operator PrimExpr(), args[1].operator int())); \ } else { \ - *ret = (Func(args[0].operator Expr(), args[1].operator Expr())); \ + *ret = (Func(args[0].operator PrimExpr(), args[1].operator PrimExpr())); \ } \ }) @@ -228,7 +228,7 @@ REGISTER_MAKE_BIT_OP(bitwise_xor, operator^); REGISTER_MAKE_BIT_OP(left_shift, operator<<); // NOLINT(*) REGISTER_MAKE_BIT_OP(right_shift, operator>>); TVM_REGISTER_GLOBAL("make._OpIfThenElse") -.set_body_typed([] (Expr cond, Expr true_value, Expr false_value) { +.set_body_typed([] (PrimExpr cond, PrimExpr true_value, PrimExpr false_value) { return if_then_else(cond, true_value, false_value); }); diff --git a/src/api/api_lang.cc b/src/api/api_lang.cc index 4e635ad4f67e..6a8bc58ad7d0 100644 --- a/src/api/api_lang.cc +++ b/src/api/api_lang.cc @@ -289,7 +289,7 @@ TVM_REGISTER_GLOBAL("_TensorHash") }); TVM_REGISTER_GLOBAL("_Placeholder") -.set_body_typed([](Array shape, DataType dtype, std::string name) { +.set_body_typed([](Array shape, DataType dtype, std::string name) { return placeholder(shape, dtype, name); }); @@ -337,14 +337,14 @@ TVM_REGISTER_GLOBAL("_StageBind") .set_body_method(&Stage::bind); TVM_REGISTER_GLOBAL("_StageSplitByFactor") -.set_body_typed([](Stage stage, IterVar parent, Expr factor) { +.set_body_typed([](Stage stage, IterVar parent, PrimExpr factor) { IterVar outer, inner; stage.split(parent, factor, &outer, &inner); return Array({outer, inner}); }); TVM_REGISTER_GLOBAL("_StageSplitByNParts") -.set_body_typed([](Stage stage, IterVar parent, Expr nparts) { +.set_body_typed([](Stage stage, IterVar parent, PrimExpr nparts) { IterVar outer, inner; stage.split_by_nparts(parent, nparts, &outer, &inner); return Array({outer, inner}); @@ -373,7 +373,7 @@ TVM_REGISTER_GLOBAL("_StageTile") .set_body_typed([]( Stage stage, IterVar x_parent, IterVar y_parent, - Expr x_factor, Expr y_factor + PrimExpr x_factor, PrimExpr y_factor ) { IterVar x_outer, y_outer, x_inner, y_inner; stage.tile(x_parent, y_parent, diff --git a/src/api/api_pass.cc b/src/api/api_pass.cc index 16c2b1bf44d9..ff30f5ed7216 100644 --- a/src/api/api_pass.cc +++ b/src/api/api_pass.cc @@ -43,9 +43,9 @@ TVM_REGISTER_GLOBAL("ir_pass.Simplify") } } else { if (args.size() > 1) { - *ret = Simplify(args[0].operator Expr(), args[1]); + *ret = Simplify(args[0].operator PrimExpr(), args[1]); } else { - *ret = Simplify(args[0].operator Expr()); + *ret = Simplify(args[0].operator PrimExpr()); } } }); @@ -60,9 +60,9 @@ TVM_REGISTER_GLOBAL("ir_pass.CanonicalSimplify") } } else { if (args.size() > 1) { - *ret = CanonicalSimplify(args[0].operator Expr(), args[1]); + *ret = CanonicalSimplify(args[0].operator PrimExpr(), args[1]); } else { - *ret = CanonicalSimplify(args[0].operator Expr()); + *ret = CanonicalSimplify(args[0].operator PrimExpr()); } } }); @@ -70,9 +70,9 @@ TVM_REGISTER_GLOBAL("ir_pass.CanonicalSimplify") TVM_REGISTER_GLOBAL("ir_pass.Substitute") .set_body([](TVMArgs args, TVMRetValue *ret) { if (args[0].IsObjectRef()) { - *ret = Substitute(args[0].operator Stmt(), args[1].operator Map()); + *ret = Substitute(args[0].operator Stmt(), args[1].operator Map()); } else { - *ret = Substitute(args[0].operator Expr(), args[1].operator Map()); + *ret = Substitute(args[0].operator PrimExpr(), args[1].operator Map()); } }); @@ -81,7 +81,7 @@ TVM_REGISTER_GLOBAL("ir_pass.Equal") if (args[0].IsObjectRef()) { *ret = Equal(args[0].operator Stmt(), args[1].operator Stmt()); } else { - *ret = Equal(args[0].operator Expr(), args[1].operator Expr()); + *ret = Equal(args[0].operator PrimExpr(), args[1].operator PrimExpr()); } }); @@ -114,7 +114,7 @@ TVM_REGISTER_GLOBAL("ir_pass.AttrsHash") TVM_REGISTER_GLOBAL("ir_pass.ExprUseVar") .set_body([](TVMArgs args, TVMRetValue *ret) { - *ret = ExprUseVar(args[0].operator Expr(), args[1].operator Var()); + *ret = ExprUseVar(args[0].operator PrimExpr(), args[1].operator Var()); }); TVM_REGISTER_GLOBAL("ir_pass.PostOrderVisit") diff --git a/src/api/api_test.cc b/src/api/api_test.cc index de37111c9f92..6f01c7a4fb2d 100644 --- a/src/api/api_test.cc +++ b/src/api/api_test.cc @@ -34,7 +34,7 @@ namespace tvm { struct TestAttrs : public AttrsNode { int axis; std::string name; - Array padding; + Array padding; TypedEnvFunc func; TVM_DECLARE_ATTRS(TestAttrs, "attrs.TestAttrs") { @@ -47,7 +47,7 @@ struct TestAttrs : public AttrsNode { .describe("name"); TVM_ATTR_FIELD(padding) .describe("padding of input") - .set_default(Array({0, 0})); + .set_default(Array({0, 0})); TVM_ATTR_FIELD(func) .describe("some random env function") .set_default(TypedEnvFunc(nullptr)); diff --git a/src/arithmetic/analyzer.cc b/src/arithmetic/analyzer.cc index 68e0b0548196..7a3baa678352 100644 --- a/src/arithmetic/analyzer.cc +++ b/src/arithmetic/analyzer.cc @@ -35,8 +35,8 @@ Analyzer::Analyzer() int_set(this) { } -void Analyzer::Bind(const VarExpr& var, const Expr& expr) { - Expr new_expr = expr; +void Analyzer::Bind(const Var& var, const PrimExpr& expr) { + PrimExpr new_expr = expr; new_expr = this->canonical_simplify(new_expr); new_expr = this->rewrite_simplify(new_expr); @@ -46,7 +46,7 @@ void Analyzer::Bind(const VarExpr& var, const Expr& expr) { this->canonical_simplify.Update(var, new_expr); } -void Analyzer::Bind(const VarExpr& var, const Range& range) { +void Analyzer::Bind(const Var& var, const Range& range) { CHECK(range.defined()); if (is_one(range->extent)) { this->Bind(var, range->min); @@ -77,7 +77,7 @@ void ConstraintContext::ExitWithScope() { exit_(); } -bool Analyzer::CanProveGreaterEqual(const Expr& expr, int64_t lower_bound) { +bool Analyzer::CanProveGreaterEqual(const PrimExpr& expr, int64_t lower_bound) { if (const auto* ptr = expr.as()) { return ptr->value >= lower_bound; } @@ -86,7 +86,7 @@ bool Analyzer::CanProveGreaterEqual(const Expr& expr, int64_t lower_bound) { return false; } -bool Analyzer::CanProve(const Expr& expr) { +bool Analyzer::CanProve(const PrimExpr& expr) { if (const auto* ptr = expr.as()) { return ptr->value != 0; } @@ -101,7 +101,7 @@ bool Analyzer::CanProve(const Expr& expr) { return false; } -Expr Analyzer::Simplify(const Expr& expr) { +PrimExpr Analyzer::Simplify(const PrimExpr& expr) { if (is_const(expr)) return expr; auto res = this->rewrite_simplify(expr); if (is_const(res)) return res; diff --git a/src/arithmetic/bound_deducer.cc b/src/arithmetic/bound_deducer.cc index 40f86de7561a..1ba0293fca8a 100644 --- a/src/arithmetic/bound_deducer.cc +++ b/src/arithmetic/bound_deducer.cc @@ -42,9 +42,9 @@ using namespace ir; // from a expression. class VariablePathFinder: public ExprVisitor { public: - explicit VariablePathFinder(Expr target) : target_(target) {} + explicit VariablePathFinder(PrimExpr target) : target_(target) {} - void VisitExpr(const Expr& node) final { + void VisitExpr(const PrimExpr& node) final { if (visited_.count(node.get()) != 0) return; visited_.insert(node.get()); @@ -58,13 +58,13 @@ class VariablePathFinder: public ExprVisitor { private: bool found_{false}; - Expr target_; + PrimExpr target_; std::unordered_set visited_; }; // get the path to the variable, // return empty vector to represent failure -std::vector GetPath(Expr target, Expr expr) { +std::vector GetPath(PrimExpr target, PrimExpr expr) { VariablePathFinder v(target); v(expr); return v.path_; @@ -77,14 +77,14 @@ class BoundDeducer: public ExprVisitor { public: friend class BoundDeduceInputChecker; friend class Converter; - BoundDeducer(Expr target, Expr expr, + BoundDeducer(PrimExpr target, PrimExpr expr, const std::unordered_map& hint_map, const std::unordered_map& relax_map) : target_(target), expr_(expr), hint_map_(hint_map), relax_map_(relax_map) {} void Deduce(); - void VisitExpr(const Expr& e) final { + void VisitExpr(const PrimExpr& e) final { if (!success_) return; if (e.get() == path_[iter_++]) { ExprVisitor::VisitExpr(e); @@ -130,8 +130,8 @@ class BoundDeducer: public ExprVisitor { void VisitExpr_(const MulNode* op) final { bool left = op->a.get() == path_[iter_]; - Expr operand = left ? op->b : op->a; - Expr target_var = left ? op->a : op->b; + PrimExpr operand = left ? op->b : op->a; + PrimExpr target_var = left ? op->a : op->b; SignType sign_operand; if (operand.dtype().is_uint()) { @@ -176,7 +176,7 @@ class BoundDeducer: public ExprVisitor { this->VisitExpr(left ? op->a : op->b); } - Expr result_; + PrimExpr result_; CompareOp comp_op{kGreater}; bool success_{true}; @@ -185,8 +185,8 @@ class BoundDeducer: public ExprVisitor { void Transform(); void Relax(); CompareOp ReverseOp(CompareOp comp_op); - Expr target_; - Expr expr_; + PrimExpr target_; + PrimExpr expr_; const std::unordered_map& hint_map_; const std::unordered_map& relax_map_; ExprIntSetMap expr_map_; @@ -204,7 +204,7 @@ class BoundDeduceInputChecker: public ExprVisitor { return target_count == 1; } - void VisitExpr(const Expr& e) final { + void VisitExpr(const PrimExpr& e) final { if (e.same_as(deducer_->target_)) ++target_count; ExprVisitor::VisitExpr(e); } @@ -329,13 +329,13 @@ void BoundDeducer::Relax() { result_ = (comp_op == kGreater) ? b.max() : b.min(); } -IntSet DeduceBound(Expr v, Expr e, +IntSet DeduceBound(PrimExpr v, PrimExpr e, const std::unordered_map& hint_map, const std::unordered_map& relax_map) { BoundDeducer d(v, e, hint_map, relax_map); d.Deduce(); if (!d.success_) return IntSet::nothing(); - Expr min = neg_inf(), max = pos_inf(); + PrimExpr min = neg_inf(), max = pos_inf(); if (d.comp_op == kEqual) { min = d.result_; max = d.result_; @@ -349,7 +349,7 @@ IntSet DeduceBound(Expr v, Expr e, // assuming e >= 0, deduce the bound of variable from it. // return empty set to represent deduce failure. -IntSet DeduceBound(Expr v, Expr e, +IntSet DeduceBound(PrimExpr v, PrimExpr e, const Map& hint_map, const Map& relax_map) { std::unordered_map hmap; diff --git a/src/arithmetic/canonical_simplify.cc b/src/arithmetic/canonical_simplify.cc index e33b0c516b4a..5f721d7a1f94 100644 --- a/src/arithmetic/canonical_simplify.cc +++ b/src/arithmetic/canonical_simplify.cc @@ -40,7 +40,7 @@ class SplitExpr; * \brief Base class of all temporary expression introduced * for canonicalization. */ -class CanonicalExprNode : public BaseExprNode { +class CanonicalExprNode : public PrimExprNode { public: virtual ~CanonicalExprNode() {} /*! @@ -48,14 +48,14 @@ class CanonicalExprNode : public BaseExprNode { * \note Can mutate the internal data structure. * \return The normal expression. */ - virtual Expr Normalize() const = 0; + virtual PrimExpr Normalize() const = 0; // overrides void VisitAttrs(tvm::AttrVisitor* v) { } static constexpr const char* _type_key = "arith.CanonicalExpr"; - TVM_DECLARE_BASE_OBJECT_INFO(CanonicalExprNode, BaseExprNode); + TVM_DECLARE_BASE_OBJECT_INFO(CanonicalExprNode, PrimExprNode); }; enum DivMode { @@ -65,7 +65,7 @@ enum DivMode { kFloorDiv }; -inline Expr ModImpl(Expr a, Expr b, DivMode mode) { +inline PrimExpr ModImpl(PrimExpr a, PrimExpr b, DivMode mode) { if (mode == kTruncDiv) { return truncmod(a, b); } else { @@ -74,7 +74,7 @@ inline Expr ModImpl(Expr a, Expr b, DivMode mode) { } } -inline Expr DivImpl(Expr a, Expr b, DivMode mode) { +inline PrimExpr DivImpl(PrimExpr a, PrimExpr b, DivMode mode) { if (mode == kTruncDiv) { return truncdiv(a, b); } else { @@ -94,7 +94,7 @@ inline Expr DivImpl(Expr a, Expr b, DivMode mode) { class SplitExprNode : public CanonicalExprNode { public: /*! \brief The base index expression. */ - Expr index; + PrimExpr index; /*! \brief The division factor ratio. */ int64_t lower_factor{1}; /*! @@ -112,8 +112,8 @@ class SplitExprNode : public CanonicalExprNode { CHECK(upper_factor == kPosInf || upper_factor % lower_factor == 0); } - Expr NormalizeWithScale(int64_t sscale) const { - Expr res = this->index; + PrimExpr NormalizeWithScale(int64_t sscale) const { + PrimExpr res = this->index; DataType dtype = this->dtype; if (this->scale == 0) { return make_const(dtype, 0); @@ -132,7 +132,7 @@ class SplitExprNode : public CanonicalExprNode { return res; } - Expr Normalize() const final { + PrimExpr Normalize() const final { return NormalizeWithScale(1); } @@ -149,9 +149,9 @@ class SplitExprNode : public CanonicalExprNode { TVM_DECLARE_FINAL_OBJECT_INFO(SplitExprNode, CanonicalExprNode); }; -class SplitExpr : public Expr { +class SplitExpr : public PrimExpr { public: - TVM_DEFINE_OBJECT_REF_METHODS(SplitExpr, Expr, SplitExprNode); + TVM_DEFINE_OBJECT_REF_METHODS(SplitExpr, PrimExpr, SplitExprNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(SplitExprNode); }; @@ -190,7 +190,7 @@ class SumExprNode : public CanonicalExprNode { * \brief Return the normal Expr that is equivalent to self. * \return The normal expression. */ - Expr Normalize() const final { + PrimExpr Normalize() const final { // quick path 1. if (this->args.size() == 0) { return make_const(this->dtype, this->base); @@ -382,11 +382,11 @@ class SumExprNode : public CanonicalExprNode { std::stable_sort(args.begin(), args.end(), fcompare); return args; } - static Expr Normalize_(DataType dtype, + static PrimExpr Normalize_(DataType dtype, const std::vector& args, int64_t base) { // Positive scales first - Expr res = make_const(dtype, 0); + PrimExpr res = make_const(dtype, 0); for (size_t i = 0; i < args.size(); ++i) { if (args[i]->scale > 0) { res = res + args[i]->Normalize(); @@ -408,9 +408,9 @@ class SumExprNode : public CanonicalExprNode { } }; -class SumExpr : public Expr { +class SumExpr : public PrimExpr { public: - TVM_DEFINE_OBJECT_REF_METHODS(SumExpr, Expr, SumExprNode); + TVM_DEFINE_OBJECT_REF_METHODS(SumExpr, PrimExpr, SumExprNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(SumExprNode); }; @@ -433,31 +433,31 @@ class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl { : Rewriter(parent) {} - Expr CanonicalSimplify(Expr expr) { + PrimExpr CanonicalSimplify(PrimExpr expr) { expr = operator()(expr); return expr; } // override the original mutate function. - Expr VisitExpr(const Expr& input_expr) final { + PrimExpr VisitExpr(const PrimExpr& input_expr) final { auto expr = Rewriter::VisitExpr(input_expr); return Normalize(expr); } // Normal mutation without normalization. - Expr CanonicalMutate(Expr expr) { + PrimExpr CanonicalMutate(PrimExpr expr) { return Rewriter::VisitExpr(expr); } using Rewriter::VisitExpr_; - Expr VisitExpr_(const AddNode* op) final; - Expr VisitExpr_(const SubNode* op) final; - Expr VisitExpr_(const MulNode* op) final; - Expr VisitExpr_(const DivNode* op) final; - Expr VisitExpr_(const ModNode* op) final; - Expr VisitExpr_(const FloorDivNode* op) final; - Expr VisitExpr_(const FloorModNode* op) final; - Expr VisitExpr_(const ReduceNode* op) final; + PrimExpr VisitExpr_(const AddNode* op) final; + PrimExpr VisitExpr_(const SubNode* op) final; + PrimExpr VisitExpr_(const MulNode* op) final; + PrimExpr VisitExpr_(const DivNode* op) final; + PrimExpr VisitExpr_(const ModNode* op) final; + PrimExpr VisitExpr_(const FloorDivNode* op) final; + PrimExpr VisitExpr_(const FloorModNode* op) final; + PrimExpr VisitExpr_(const ReduceNode* op) final; private: /*! @@ -492,7 +492,7 @@ class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl { * \param expr The input expression. * \return Normalized expr. */ - Expr Normalize(Expr expr) { + PrimExpr Normalize(PrimExpr expr) { if (const auto* op = expr.as()) { return op->Normalize(); } else { @@ -504,7 +504,7 @@ class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl { * \param expr The input expr. * \return The transformed SplitExpr. */ - SplitExpr ToSplitExpr(Expr expr) { + SplitExpr ToSplitExpr(PrimExpr expr) { if (const auto* op = expr.as()) { return GetRef(op); } @@ -547,7 +547,7 @@ class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl { * \param expr The input expr. * \return The transformed SumExpr. */ - SumExpr ToSumExpr(Expr expr) { + SumExpr ToSumExpr(PrimExpr expr) { if (const auto* op = expr.as()) { return GetRef(op); } @@ -562,20 +562,20 @@ class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl { } } // Simplify the combiner used in reduce. - Expr SimplifyReduceCombiner(const ReduceNode* op); + PrimExpr SimplifyReduceCombiner(const ReduceNode* op); }; -Expr CanonicalSimplifier::Impl:: +PrimExpr CanonicalSimplifier::Impl:: VisitExpr_(const AddNode* op) { if (!IsIndexType(op->dtype)) { return Rewriter::VisitExpr_(op); } // normalize - Expr a = this->CanonicalMutate(op->a); - Expr b = this->CanonicalMutate(op->b); + PrimExpr a = this->CanonicalMutate(op->a); + PrimExpr b = this->CanonicalMutate(op->b); // const folding - Expr const_res = TryConstFold(a, b); + PrimExpr const_res = TryConstFold(a, b); if (const_res.defined()) return const_res; // canonical form simplification. @@ -591,17 +591,17 @@ VisitExpr_(const AddNode* op) { return std::move(ret); } -Expr CanonicalSimplifier::Impl:: +PrimExpr CanonicalSimplifier::Impl:: VisitExpr_(const SubNode* op) { if (!IsIndexType(op->dtype)) { return Rewriter::VisitExpr_(op); } // normalize - Expr a = this->CanonicalMutate(op->a); - Expr b = this->CanonicalMutate(op->b); + PrimExpr a = this->CanonicalMutate(op->a); + PrimExpr b = this->CanonicalMutate(op->b); // const folding - Expr const_res = TryConstFold(a, b); + PrimExpr const_res = TryConstFold(a, b); if (const_res.defined()) return const_res; // canonical form simplification. @@ -618,17 +618,17 @@ VisitExpr_(const SubNode* op) { } -Expr CanonicalSimplifier::Impl:: +PrimExpr CanonicalSimplifier::Impl:: VisitExpr_(const MulNode* op) { if (!IsIndexType(op->dtype)) { return Rewriter::VisitExpr_(op); } // normalize - Expr a = this->CanonicalMutate(op->a); - Expr b = this->CanonicalMutate(op->b); + PrimExpr a = this->CanonicalMutate(op->a); + PrimExpr b = this->CanonicalMutate(op->b); // const folding - Expr const_res = TryConstFold(a, b); + PrimExpr const_res = TryConstFold(a, b); if (const_res.defined()) return const_res; // x * c @@ -651,7 +651,7 @@ VisitExpr_(const MulNode* op) { a = Normalize(a); b = Normalize(b); if (op->a.same_as(a) && op->b.same_as(b)) { - return GetRef(op); + return GetRef(op); } else { return MulNode::make(a, b); } @@ -725,17 +725,17 @@ SplitDivConst(SplitExpr lhs, int64_t cval, DivMode div_mode) { return lhs; } -Expr CanonicalSimplifier::Impl:: +PrimExpr CanonicalSimplifier::Impl:: VisitExpr_(const DivNode* op) { if (!IsIndexType(op->dtype)) { return Rewriter::VisitExpr_(op); } - Expr a = this->CanonicalMutate(op->a); - Expr b = this->CanonicalMutate(op->b); + PrimExpr a = this->CanonicalMutate(op->a); + PrimExpr b = this->CanonicalMutate(op->b); // const folding - Expr const_res = TryConstFold(a, b); + PrimExpr const_res = TryConstFold(a, b); if (const_res.defined()) return const_res; PVar c1; // x / c1 @@ -755,7 +755,7 @@ VisitExpr_(const DivNode* op) { if (analyzer_->CanProveGreaterEqual(lhs->Normalize(), 0) && analyzer_->CanProveGreaterEqual(extra->Normalize(), 0)) { lhs.CopyOnWrite()->DivideBy(cval); - Expr temp = Normalize(extra); + PrimExpr temp = Normalize(extra); if (const auto* pconst = temp.as()) { lhs.CopyOnWrite()->AddToSelf(pconst->value / cval); } else { @@ -780,22 +780,22 @@ VisitExpr_(const DivNode* op) { a = Normalize(a); b = Normalize(b); if (op->a.same_as(a) && op->b.same_as(b)) { - return GetRef(op); + return GetRef(op); } else { return DivNode::make(a, b); } } -Expr CanonicalSimplifier::Impl:: +PrimExpr CanonicalSimplifier::Impl:: VisitExpr_(const FloorDivNode* op) { if (!IsIndexType(op->dtype)) { return Rewriter::VisitExpr_(op); } - Expr a = this->CanonicalMutate(op->a); - Expr b = this->CanonicalMutate(op->b); + PrimExpr a = this->CanonicalMutate(op->a); + PrimExpr b = this->CanonicalMutate(op->b); // const folding - Expr const_res = TryConstFold(a, b); + PrimExpr const_res = TryConstFold(a, b); if (const_res.defined()) return const_res; PVar c1; // x / c1 @@ -812,7 +812,7 @@ VisitExpr_(const FloorDivNode* op) { } // continue simplification. lhs.CopyOnWrite()->DivideBy(cval); - Expr temp = Normalize(extra); + PrimExpr temp = Normalize(extra); if (const auto* pconst = temp.as()) { lhs.CopyOnWrite()->AddToSelf(floordiv(pconst->value, cval)); } else { @@ -836,7 +836,7 @@ VisitExpr_(const FloorDivNode* op) { a = Normalize(a); b = Normalize(b); if (op->a.same_as(a) && op->b.same_as(b)) { - return GetRef(op); + return GetRef(op); } else { return FloorDivNode::make(a, b); } @@ -892,17 +892,17 @@ SplitModConst(SplitExpr lhs, int64_t cval, DivMode div_mode) { return lhs; } -Expr CanonicalSimplifier::Impl:: +PrimExpr CanonicalSimplifier::Impl:: VisitExpr_(const ModNode* op) { if (!IsIndexType(op->dtype)) { return Rewriter::VisitExpr_(op); } // normalize - Expr a = this->CanonicalMutate(op->a); - Expr b = this->CanonicalMutate(op->b); + PrimExpr a = this->CanonicalMutate(op->a); + PrimExpr b = this->CanonicalMutate(op->b); // const folding - Expr const_res = TryConstFold(a, b); + PrimExpr const_res = TryConstFold(a, b); if (const_res.defined()) return const_res; PVar c1; @@ -918,7 +918,7 @@ VisitExpr_(const ModNode* op) { // both lhs and extra are non-negative if (analyzer_->CanProveGreaterEqual(lhs->Normalize(), 0) && analyzer_->CanProveGreaterEqual(extra->Normalize(), 0)) { - Expr temp = Normalize(extra); + PrimExpr temp = Normalize(extra); if (temp.as()) { return truncmod(temp, c1.Eval()); } else { @@ -956,23 +956,23 @@ VisitExpr_(const ModNode* op) { a = Normalize(a); b = Normalize(b); if (op->a.same_as(a) && op->b.same_as(b)) { - return GetRef(op); + return GetRef(op); } else { return ModNode::make(a, b); } } -Expr CanonicalSimplifier::Impl:: +PrimExpr CanonicalSimplifier::Impl:: VisitExpr_(const FloorModNode* op) { if (!IsIndexType(op->dtype)) { return Rewriter::VisitExpr_(op); } // normalize - Expr a = this->CanonicalMutate(op->a); - Expr b = this->CanonicalMutate(op->b); + PrimExpr a = this->CanonicalMutate(op->a); + PrimExpr b = this->CanonicalMutate(op->b); // const folding - Expr const_res = TryConstFold(a, b); + PrimExpr const_res = TryConstFold(a, b); if (const_res.defined()) return const_res; PVar c1; @@ -982,7 +982,7 @@ VisitExpr_(const FloorModNode* op) { if (const auto* psum = a.as()) { SumExpr lhs, extra; SeparateDivisibleParts(psum, cval, &lhs, &extra); - Expr temp = Normalize(extra); + PrimExpr temp = Normalize(extra); if (temp.as()) { return floormod(temp, c1.Eval()); } else { @@ -1016,19 +1016,19 @@ VisitExpr_(const FloorModNode* op) { a = Normalize(a); b = Normalize(b); if (op->a.same_as(a) && op->b.same_as(b)) { - return GetRef(op); + return GetRef(op); } else { return FloorModNode::make(a, b); } } // Simplify reduce expression. -Expr CanonicalSimplifier::Impl:: +PrimExpr CanonicalSimplifier::Impl:: SimplifyReduceCombiner(const ReduceNode* op) { // First simplify the results - Array simplified_result; + Array simplified_result; for (const auto& res : op->combiner->result) { - Expr new_res = this->VisitExpr(res); + PrimExpr new_res = this->VisitExpr(res); simplified_result.push_back(new_res); } @@ -1066,11 +1066,11 @@ SimplifyReduceCombiner(const ReduceNode* op) { } int new_value_index = op->value_index; - Array new_result; - Array new_identity; + Array new_result; + Array new_identity; Array new_lhs; Array new_rhs; - Array new_source; + Array new_source; // new stuff is old stuff which is used for (size_t i = 0; i < used.size(); ++i) { @@ -1093,10 +1093,10 @@ SimplifyReduceCombiner(const ReduceNode* op) { new_combiner, new_source, op->axis, op->condition, new_value_index); } -Expr CanonicalSimplifier::Impl:: +PrimExpr CanonicalSimplifier::Impl:: VisitExpr_(const ReduceNode* op) { // Recursively call simplification when necessary. - Expr ret = RewriteSimplifier::Impl::VisitExpr_(op); + PrimExpr ret = RewriteSimplifier::Impl::VisitExpr_(op); op = ret.as(); // already been simplified by const reduction axis removal if (op == nullptr) return ret; @@ -1115,12 +1115,12 @@ VisitExpr_(const ReduceNode* op) { return ret; } -Expr CanonicalSimplifier::operator()(const Expr& expr) { +PrimExpr CanonicalSimplifier::operator()(const PrimExpr& expr) { return impl_->CanonicalSimplify(expr); } void CanonicalSimplifier::Update(const Var& var, - const Expr& info, + const PrimExpr& info, bool override) { impl_->Update(var, info, override); } diff --git a/src/arithmetic/compute_expr.h b/src/arithmetic/compute_expr.h index aca26e85375a..d78838f85ae5 100644 --- a/src/arithmetic/compute_expr.h +++ b/src/arithmetic/compute_expr.h @@ -39,7 +39,7 @@ namespace arith { * \return The result. */ template -inline Expr Compute(Expr lhs, Expr rhs) { +inline PrimExpr Compute(PrimExpr lhs, PrimExpr rhs) { return OP::make(lhs, rhs); } @@ -52,10 +52,10 @@ inline Expr Compute(Expr lhs, Expr rhs) { * \return The result. */ template -inline Expr ComputeReduce( - const Array& values, Expr empty_value); +inline PrimExpr ComputeReduce( + const Array& values, PrimExpr empty_value); -inline bool GetConst(Expr e, int64_t* out) { +inline bool GetConst(PrimExpr e, int64_t* out) { if (e.dtype().is_vector()) return false; const int64_t* v = as_const_int(e); if (v) { @@ -66,7 +66,7 @@ inline bool GetConst(Expr e, int64_t* out) { } // get a small constant int -inline bool GetConstInt(Expr e, int* out) { +inline bool GetConstInt(PrimExpr e, int* out) { int64_t v1 = 0; if (GetConst(e, &v1)) { if (v1 > static_cast( @@ -77,47 +77,47 @@ inline bool GetConstInt(Expr e, int* out) { } template<> -inline Expr Compute(Expr a, Expr b) { +inline PrimExpr Compute(PrimExpr a, PrimExpr b) { return a + b; } template<> -inline Expr Compute(Expr a, Expr b) { +inline PrimExpr Compute(PrimExpr a, PrimExpr b) { return a - b; } template<> -inline Expr Compute(Expr a, Expr b) { +inline PrimExpr Compute(PrimExpr a, PrimExpr b) { return a * b; } template<> -inline Expr Compute(Expr a, Expr b) { +inline PrimExpr Compute(PrimExpr a, PrimExpr b) { return truncdiv(a, b); } template<> -inline Expr Compute(Expr a, Expr b) { +inline PrimExpr Compute(PrimExpr a, PrimExpr b) { return truncmod(a, b); } template<> -inline Expr Compute(Expr a, Expr b) { +inline PrimExpr Compute(PrimExpr a, PrimExpr b) { return max(a, b); } template<> -inline Expr Compute(Expr a, Expr b) { +inline PrimExpr Compute(PrimExpr a, PrimExpr b) { return min(a, b); } template -inline Expr ComputeReduce(const Array& values, Expr empty_value) { +inline PrimExpr ComputeReduce(const Array& values, PrimExpr empty_value) { if (values.size() == 0U) { CHECK(empty_value.defined()); return empty_value; } - Expr res = values[0]; + PrimExpr res = values[0]; for (size_t i = 1; i < values.size(); ++i) { res = Compute(res, values[i]); } diff --git a/src/arithmetic/const_fold.h b/src/arithmetic/const_fold.h index db98a7eede54..55c156d898f9 100644 --- a/src/arithmetic/const_fold.h +++ b/src/arithmetic/const_fold.h @@ -44,8 +44,8 @@ namespace arith { * \return nullptr if constant fold fails, otherwise return folded result. */ template -inline Expr TryConstFold(Expr a, Expr b) { - return Expr(); +inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { + return PrimExpr(); } /*! @@ -58,7 +58,7 @@ inline Expr TryConstFold(Expr a, Expr b) { * \return nullptr if constant fold fails, otherwise return folded result. */ template -inline Expr TryConstFold(Expr a); +inline PrimExpr TryConstFold(PrimExpr a); /*! * \brief Check whether type is used to represent index. @@ -100,7 +100,7 @@ inline bool IsIndexType(const DataType& type) { // specialization of constant folders. template<> -inline Expr TryConstFold(Expr a, Expr b) { +inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); if (pa && pb) return IntImmNode::make(rtype, pa->value + pb->value); @@ -110,11 +110,11 @@ inline Expr TryConstFold(Expr a, Expr b) { if (fa && fa->value == 0) return b; if (fb && fb->value == 0) return a; }); - return Expr(); + return PrimExpr(); } template<> -inline Expr TryConstFold(Expr a, Expr b) { +inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); if (pa && pb) return IntImmNode::make(rtype, pa->value - pb->value); @@ -122,11 +122,11 @@ inline Expr TryConstFold(Expr a, Expr b) { if (fa && fb) return FloatImmNode::make(rtype, fa->value - fb->value); if (fb && fb->value == 0) return a; }); - return Expr(); + return PrimExpr(); } template<> -inline Expr TryConstFold(Expr a, Expr b) { +inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); if (pa && pb) return IntImmNode::make(rtype, pa->value * pb->value); @@ -148,11 +148,11 @@ inline Expr TryConstFold(Expr a, Expr b) { if (fb->value == 0) return b; } }); - return Expr(); + return PrimExpr(); } template<> -inline Expr TryConstFold(Expr a, Expr b) { +inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); if (pa && pb) { @@ -177,11 +177,11 @@ inline Expr TryConstFold(Expr a, Expr b) { CHECK_NE(fb->value, 0) << "Divide by zero"; } }); - return Expr(); + return PrimExpr(); } template<> -inline Expr TryConstFold(Expr a, Expr b) { +inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_INDEX_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); if (pa && pb) { @@ -195,11 +195,11 @@ inline Expr TryConstFold(Expr a, Expr b) { CHECK_NE(pb->value, 0) << "Divide by zero"; } }); - return Expr(); + return PrimExpr(); } template<> -inline Expr TryConstFold(Expr a, Expr b) { +inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); if (pa && pb) { @@ -222,11 +222,11 @@ inline Expr TryConstFold(Expr a, Expr b) { CHECK_NE(fb->value, 0) << "Divide by zero"; } }); - return Expr(); + return PrimExpr(); } template<> -inline Expr TryConstFold(Expr a, Expr b) { +inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_INDEX_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); if (pa && pb) { @@ -240,87 +240,87 @@ inline Expr TryConstFold(Expr a, Expr b) { CHECK_NE(pb->value, 0) << "Divide by zero"; } }); - return Expr(); + return PrimExpr(); } template<> -inline Expr TryConstFold(Expr a, Expr b) { +inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); if (pa && pb) return IntImmNode::make(rtype, std::min(pa->value, pb->value)); if (fa && fb) return FloatImmNode::make(rtype, std::min(fa->value, fb->value)); }); if (a.same_as(b)) return a; - return Expr(); + return PrimExpr(); } template<> -inline Expr TryConstFold(Expr a, Expr b) { +inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); if (pa && pb) return IntImmNode::make(rtype, std::max(pa->value, pb->value)); if (fa && fb) return FloatImmNode::make(rtype, std::max(fa->value, fb->value)); }); if (a.same_as(b)) return a; - return Expr(); + return PrimExpr(); } template<> -inline Expr TryConstFold(Expr a, Expr b) { +inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ if (pa && pb) return UIntImmNode::make(DataType::UInt(1), pa->value > pb->value); if (fa && fb) return UIntImmNode::make(DataType::UInt(1), fa->value > fb->value); }); - return Expr(); + return PrimExpr(); } template<> -inline Expr TryConstFold(Expr a, Expr b) { +inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ if (pa && pb) return UIntImmNode::make(DataType::UInt(1), pa->value >= pb->value); if (fa && fb) return UIntImmNode::make(DataType::UInt(1), fa->value >= fb->value); }); - return Expr(); + return PrimExpr(); } template<> -inline Expr TryConstFold(Expr a, Expr b) { +inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ if (pa && pb) return UIntImmNode::make(DataType::UInt(1), pa->value < pb->value); if (fa && fb) return UIntImmNode::make(DataType::UInt(1), fa->value < fb->value); }); - return Expr(); + return PrimExpr(); } template<> -inline Expr TryConstFold(Expr a, Expr b) { +inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ if (pa && pb) return UIntImmNode::make(DataType::UInt(1), pa->value <= pb->value); if (fa && fb) return UIntImmNode::make(DataType::UInt(1), fa->value <= fb->value); }); - return Expr(); + return PrimExpr(); } template<> -inline Expr TryConstFold(Expr a, Expr b) { +inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ if (pa && pb) return UIntImmNode::make(DataType::UInt(1), pa->value == pb->value); if (fa && fb) return UIntImmNode::make(DataType::UInt(1), fa->value == fb->value); }); - return Expr(); + return PrimExpr(); } template<> -inline Expr TryConstFold(Expr a, Expr b) { +inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ if (pa && pb) return UIntImmNode::make(DataType::UInt(1), pa->value != pb->value); if (fa && fb) return UIntImmNode::make(DataType::UInt(1), fa->value != fb->value); }); - return Expr(); + return PrimExpr(); } template<> -inline Expr TryConstFold(Expr a, Expr b) { +inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { using ir::UIntImmNode; const UIntImmNode* pa = a.as(); const UIntImmNode* pb = b.as(); @@ -328,11 +328,11 @@ inline Expr TryConstFold(Expr a, Expr b) { if (pa && !pa->value) return a; if (pb && pb->value) return a; if (pb && !pb->value) return b; - return Expr(); + return PrimExpr(); } template<> -inline Expr TryConstFold(Expr a, Expr b) { +inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { using ir::UIntImmNode; const UIntImmNode* pa = a.as(); const UIntImmNode* pb = b.as(); @@ -340,25 +340,25 @@ inline Expr TryConstFold(Expr a, Expr b) { if (pa && !pa->value) return b; if (pb && pb->value) return b; if (pb && !pb->value) return a; - return Expr(); + return PrimExpr(); } template<> -inline Expr TryConstFold(Expr a) { +inline PrimExpr TryConstFold(PrimExpr a) { using ir::UIntImmNode; const UIntImmNode* pa = a.as(); if (pa) { return UIntImmNode::make(DataType::UInt(1), !(pa->value)); } - return Expr(); + return PrimExpr(); } /*! \brief Helper namespace for symbolic value limits */ struct SymbolicLimits { /*! \brief positive infinity */ - static Expr pos_inf_; + static PrimExpr pos_inf_; /*! \brief negative infinity */ - static Expr neg_inf_; + static PrimExpr neg_inf_; }; /*! @@ -369,7 +369,7 @@ struct SymbolicLimits { * * \return positive infinity. */ -inline Expr pos_inf() { +inline PrimExpr pos_inf() { return SymbolicLimits::pos_inf_; } @@ -379,7 +379,7 @@ inline Expr pos_inf() { * * \return The check result. */ -inline bool is_pos_inf(const Expr& value) { +inline bool is_pos_inf(const PrimExpr& value) { return value.same_as(SymbolicLimits::pos_inf_); } @@ -391,7 +391,7 @@ inline bool is_pos_inf(const Expr& value) { * * \return negative infinity. */ -inline Expr neg_inf() { +inline PrimExpr neg_inf() { return SymbolicLimits::neg_inf_; } @@ -401,7 +401,7 @@ inline Expr neg_inf() { * * \return The check result. */ -inline bool is_neg_inf(const Expr& value) { +inline bool is_neg_inf(const PrimExpr& value) { return value.same_as(SymbolicLimits::neg_inf_); } diff --git a/src/arithmetic/const_int_bound.cc b/src/arithmetic/const_int_bound.cc index d3f885ae8d8c..a041e40abf46 100644 --- a/src/arithmetic/const_int_bound.cc +++ b/src/arithmetic/const_int_bound.cc @@ -76,17 +76,17 @@ struct ConstIntBoundAnalyzer::Entry { }; class ConstIntBoundAnalyzer::Impl : - public ExprFunctor { + public ExprFunctor { public: /*! \brief additional bound info about expr \in bound */ struct BoundInfo { /*! \brief The expr */ - Expr expr; + PrimExpr expr; /*! \brief The additional bound */ Entry bound; BoundInfo() {} - BoundInfo(Expr expr, Entry bound) + BoundInfo(PrimExpr expr, Entry bound) : expr(expr), bound(bound) { } }; @@ -125,10 +125,10 @@ class ConstIntBoundAnalyzer::Impl : // Override visitor behaviors Entry VisitExprDefault_(const Object* op) final { return Everything( - static_cast(op)->dtype); + static_cast(op)->dtype); } - Entry VisitExpr(const Expr& expr) final { + Entry VisitExpr(const PrimExpr& expr) final { Entry res = ExprFunctor::VisitExpr(expr); // a linear search over additional info // assume we won't have a lot of conditions @@ -315,7 +315,7 @@ class ConstIntBoundAnalyzer::Impl : } } - std::function EnterConstraint(const Expr& constraint) { + std::function EnterConstraint(const PrimExpr& constraint) { std::vector info = DetectBoundInfo(constraint); if (info.size() == 0) return nullptr; size_t old_size = additional_info_.size(); @@ -330,7 +330,7 @@ class ConstIntBoundAnalyzer::Impl : private: // internal variable map - std::unordered_map var_map_; + std::unordered_map var_map_; // additional bound info std::vector additional_info_; // constants: the limit value means umlimited @@ -494,8 +494,8 @@ class ConstIntBoundAnalyzer::Impl : * \param cond The constraint condition. * \return List of detected bounds. */ - static std::vector DetectBoundInfo(const Expr& cond) { - PVar x, y; + static std::vector DetectBoundInfo(const PrimExpr& cond) { + PVar x, y; PVar c; // NOTE: canonical form always use <= or < if ((c <= x).Match(cond)) { @@ -520,7 +520,7 @@ class ConstIntBoundAnalyzer::Impl : } }; -ConstIntBound ConstIntBoundAnalyzer::operator()(const Expr& expr) { +ConstIntBound ConstIntBoundAnalyzer::operator()(const PrimExpr& expr) { Entry ret = impl_->VisitExpr(expr); return ConstIntBound(ret.min_value, ret.max_value); } @@ -535,7 +535,7 @@ void ConstIntBoundAnalyzer::Bind(const Var& var, const Range& range) { impl_->Bind(var, range); } -std::function ConstIntBoundAnalyzer::EnterConstraint(const Expr& constraint) { +std::function ConstIntBoundAnalyzer::EnterConstraint(const PrimExpr& constraint) { return impl_->EnterConstraint(constraint); } diff --git a/src/arithmetic/detect_linear_equation.cc b/src/arithmetic/detect_linear_equation.cc index 7785801a5520..3de555765293 100644 --- a/src/arithmetic/detect_linear_equation.cc +++ b/src/arithmetic/detect_linear_equation.cc @@ -33,22 +33,22 @@ using namespace ir; // Linear equation, the components can be undefined. struct LinearEqEntry { - Expr base; - Expr coeff; + PrimExpr base; + PrimExpr coeff; }; struct IntervalEntry { - Expr min_value; - Expr max_value; + PrimExpr min_value; + PrimExpr max_value; }; class LinearEqDetector - : public ExprFunctor { + : public ExprFunctor { public: explicit LinearEqDetector(Var var) : var_(var) {} - bool Detect(const Expr& e, LinearEqEntry* ret) { + bool Detect(const PrimExpr& e, LinearEqEntry* ret) { *ret = VisitExpr(e, e); if (fail_) return false; if (!ret->base.defined()) { @@ -60,7 +60,7 @@ class LinearEqDetector return true; } - LinearEqEntry VisitExpr_(const AddNode* op, const Expr& e) final { + LinearEqEntry VisitExpr_(const AddNode* op, const PrimExpr& e) final { if (fail_) return LinearEqEntry(); LinearEqEntry a = VisitExpr(op->a, op->a); LinearEqEntry b = VisitExpr(op->b, op->b); @@ -70,7 +70,7 @@ class LinearEqDetector return ret; } - LinearEqEntry VisitExpr_(const SubNode* op, const Expr& e) final { + LinearEqEntry VisitExpr_(const SubNode* op, const PrimExpr& e) final { if (fail_) return LinearEqEntry(); LinearEqEntry a = VisitExpr(op->a, op->a); LinearEqEntry b = VisitExpr(op->b, op->b); @@ -80,7 +80,7 @@ class LinearEqDetector return ret; } - LinearEqEntry VisitExpr_(const MulNode* op, const Expr& e) final { + LinearEqEntry VisitExpr_(const MulNode* op, const PrimExpr& e) final { if (fail_) return LinearEqEntry(); LinearEqEntry a = VisitExpr(op->a, op->a); LinearEqEntry b = VisitExpr(op->b, op->b); @@ -96,7 +96,7 @@ class LinearEqDetector ret.coeff = MulCombine(a.base, b.coeff); return ret; } - LinearEqEntry VisitExpr_(const VarNode* op, const Expr& e) final { + LinearEqEntry VisitExpr_(const VarNode* op, const PrimExpr& e) final { LinearEqEntry ret; if (op == var_.get()) { ret.coeff = make_const(op->dtype, 1); @@ -105,7 +105,7 @@ class LinearEqDetector } return ret; } - LinearEqEntry VisitExprDefault_(const Object* op, const Expr& e) final { + LinearEqEntry VisitExprDefault_(const Object* op, const PrimExpr& e) final { if (fail_) return LinearEqEntry(); if (ExprUseVar(e, var_)) { fail_ = true; @@ -121,32 +121,33 @@ class LinearEqDetector Var var_; bool fail_{false}; // Combine by add - Expr AddCombine(Expr a, Expr b) { + PrimExpr AddCombine(PrimExpr a, PrimExpr b) { if (!a.defined()) return b; if (!b.defined()) return a; return a + b; } - Expr SubCombine(Expr a, Expr b) { + PrimExpr SubCombine(PrimExpr a, PrimExpr b) { // Check b first in case they are both undefined if (!b.defined()) return a; if (!a.defined()) return -b; return a - b; } - Expr MulCombine(Expr a, Expr b) { + PrimExpr MulCombine(PrimExpr a, PrimExpr b) { if (!a.defined()) return a; if (!b.defined()) return b; return a * b; } }; -Array DetectLinearEquation(const Expr& e, const Array& vars) { - Expr base = e; - Array coeff; +Array DetectLinearEquation(const PrimExpr& e, + const Array& vars) { + PrimExpr base = e; + Array coeff; for (Var v : vars) { LinearEqEntry ret; if (!LinearEqDetector(v).Detect(base, &ret)) { - return Array(); + return Array(); } coeff.push_back(ret.coeff); base = std::move(ret.base); @@ -157,7 +158,7 @@ Array DetectLinearEquation(const Expr& e, const Array& vars) { vset.insert(vars[i - 1].get()); // The previous coeff contains the variable if (ExprUseVar(coeff[i - 2], vset)) { - return Array(); + return Array(); } } coeff.push_back(base); @@ -166,7 +167,7 @@ Array DetectLinearEquation(const Expr& e, const Array& vars) { // Detect clip condition as min max value bool DetectClipBound( - const Expr& cond, + const PrimExpr& cond, std::unordered_map* bmap) { int flag = 0; Var var; @@ -187,7 +188,7 @@ bool DetectClipBound( PostOrderVisit(cond, fvisit); if (flag != 1) return false; // canonical form: exp >= 0 - Expr canonical; + PrimExpr canonical; if (const LTNode* op = cond.as()) { if (!op->a.dtype().is_int()) return false; canonical = op->b - op->a - make_const(op->a.dtype(), 1); @@ -230,7 +231,7 @@ bool DetectClipBound( template -void SplitCommExpr(const Expr& e, std::vector* ret) { +void SplitCommExpr(const PrimExpr& e, std::vector* ret) { if (const OP* op = e.as()) { SplitCommExpr(op->a, ret); SplitCommExpr(op->b, ret); @@ -241,17 +242,17 @@ void SplitCommExpr(const Expr& e, std::vector* ret) { // Detect the lower and upper bound from the expression. // e must be connected by and. -Array DetectClipBound(const Expr& e, const Array& vars) { - std::vector splits; +Array DetectClipBound(const PrimExpr& e, const Array& vars) { + std::vector splits; SplitCommExpr(e, &splits); std::unordered_map rmap; for (Var v : vars) { rmap[v.get()] = IntervalEntry(); } - for (Expr cond : splits) { - if (!DetectClipBound(cond, &rmap)) return Array(); + for (PrimExpr cond : splits) { + if (!DetectClipBound(cond, &rmap)) return Array(); } - Array ret; + Array ret; for (Var v : vars) { IntervalEntry e = rmap[v.get()]; if (e.min_value.defined()) { diff --git a/src/arithmetic/domain_touched.cc b/src/arithmetic/domain_touched.cc index 1821c1651ab4..3889cd2cf918 100644 --- a/src/arithmetic/domain_touched.cc +++ b/src/arithmetic/domain_touched.cc @@ -99,7 +99,7 @@ class FuncTouchedDomain final : public StmtExprVisitor { } private: - void Touch(const Array& args) { + void Touch(const Array& args) { if (args.size() > bounds_.size()) { bounds_.resize(args.size()); } diff --git a/src/arithmetic/int_set.cc b/src/arithmetic/int_set.cc index c60c8254c80c..ceaa976469e8 100644 --- a/src/arithmetic/int_set.cc +++ b/src/arithmetic/int_set.cc @@ -35,17 +35,17 @@ namespace tvm { namespace arith { -Expr SymbolicLimits::pos_inf_ = Var("pos_inf", DataType::Handle()); -Expr SymbolicLimits::neg_inf_ = Var("neg_inf", DataType::Handle()); +PrimExpr SymbolicLimits::pos_inf_ = Var("pos_inf", DataType::Handle()); +PrimExpr SymbolicLimits::neg_inf_ = Var("neg_inf", DataType::Handle()); -IntervalSet::IntervalSet(Expr min_value, Expr max_value) { +IntervalSet::IntervalSet(PrimExpr min_value, PrimExpr max_value) { auto node = make_object(); node->min_value = std::move(min_value); node->max_value = std::move(max_value); data_ = std::move(node); } -IntervalSet MakeIntervalSet(Expr min_value, Expr max_value) { +IntervalSet MakeIntervalSet(PrimExpr min_value, PrimExpr max_value) { return IntervalSet(min_value, max_value); } @@ -54,8 +54,8 @@ TVM_REGISTER_GLOBAL("arith._make_IntervalSet") IntervalSet Intersect(Analyzer* analyzer, IntervalSet a, IntervalSet b) { - Expr max_value = min(a->max_value, b->max_value); - Expr min_value = max(a->min_value, b->min_value); + PrimExpr max_value = min(a->max_value, b->max_value); + PrimExpr min_value = max(a->min_value, b->min_value); if ((max_value.dtype().is_int() || max_value.dtype().is_uint()) && (min_value.dtype().is_int() || min_value.dtype().is_uint()) && analyzer->CanProveGreaterEqual(min_value - max_value, 1)) { @@ -66,8 +66,8 @@ IntervalSet Intersect(Analyzer* analyzer, IntervalSet a, IntervalSet b) { } IntervalSet Union(Analyzer* analyzer, IntervalSet a, IntervalSet b) { - Expr max_value = max(a->max_value, b->max_value); - Expr min_value = min(a->min_value, b->min_value); + PrimExpr max_value = max(a->max_value, b->max_value); + PrimExpr min_value = min(a->min_value, b->min_value); return IntervalSet(min_value, max_value); } @@ -102,7 +102,7 @@ inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, IntervalSet b) { if (a->IsSinglePoint() && b->IsSinglePoint()) { - Expr res = TryConstFold(a->min_value, b->min_value); + PrimExpr res = TryConstFold(a->min_value, b->min_value); if (!res.defined()) res = Op::make(a->min_value, b->min_value); return IntervalSet::SinglePoint(res); } @@ -126,10 +126,10 @@ inline IntervalSet Combine(Analyzer* analyer, } if (a->IsEmpty()) return a; if (b->IsEmpty()) return b; - Expr min_value = + PrimExpr min_value = a->HasLowerBound() && b->HasLowerBound() ? a->min_value + b->min_value : neg_inf(); - Expr max_value = + PrimExpr max_value = a->HasUpperBound() && b->HasUpperBound() ? a->max_value + b->max_value : pos_inf(); return IntervalSet(min_value, max_value); @@ -144,10 +144,10 @@ inline IntervalSet Combine(Analyzer* analyer, } if (a->IsEmpty()) return a; if (b->IsEmpty()) return b; - Expr min_value = + PrimExpr min_value = a->HasLowerBound() && b->HasUpperBound() ? a->min_value - b->max_value : neg_inf(); - Expr max_value = + PrimExpr max_value = a->HasUpperBound() && b->HasLowerBound() ? a->max_value - b->min_value : pos_inf(); return IntervalSet(min_value, max_value); @@ -170,18 +170,18 @@ inline IntervalSet Combine(Analyzer* analyzer, if (is_zero(b->min_value)) return b; if (is_one(b->min_value)) return a; if (analyzer->CanProveGreaterEqual(b->min_value, 0)) { - Expr min_value = a->HasLowerBound() ? a->min_value * b->min_value : neg_inf(); - Expr max_value = a->HasUpperBound() ? a->max_value * b->min_value : pos_inf(); + PrimExpr min_value = a->HasLowerBound() ? a->min_value * b->min_value : neg_inf(); + PrimExpr max_value = a->HasUpperBound() ? a->max_value * b->min_value : pos_inf(); return IntervalSet(min_value, max_value); } else if (analyzer->CanProveGreaterEqual(-b->min_value, 1)) { - Expr min_value = a->HasUpperBound() ? a->max_value * b->min_value : neg_inf(); - Expr max_value = a->HasLowerBound() ? a->min_value * b->min_value : pos_inf(); + PrimExpr min_value = a->HasUpperBound() ? a->max_value * b->min_value : neg_inf(); + PrimExpr max_value = a->HasLowerBound() ? a->min_value * b->min_value : pos_inf(); return IntervalSet(min_value, max_value); } else if (a->HasUpperBound() && a->HasLowerBound()) { using ir::SelectNode; - Expr sign = b->min_value >= make_zero(b->min_value.dtype().element_of()); - Expr e1 = a->min_value * b->min_value; - Expr e2 = a->max_value * b->min_value; + PrimExpr sign = b->min_value >= make_zero(b->min_value.dtype().element_of()); + PrimExpr e1 = a->min_value * b->min_value; + PrimExpr e2 = a->max_value * b->min_value; return IntervalSet(SelectNode::make(sign, e1, e2), SelectNode::make(sign, e2, e1)); } } @@ -205,18 +205,18 @@ inline IntervalSet Combine(Analyzer* analyzer, if (is_one(b->min_value)) return a; // no relaxation is needed in here due to set is inclusive if (analyzer->CanProveGreaterEqual(b->min_value, 0)) { - Expr min_value = a->HasLowerBound() ? a->min_value / b->min_value : neg_inf(); - Expr max_value = a->HasUpperBound() ? a->max_value / b->min_value : pos_inf(); + PrimExpr min_value = a->HasLowerBound() ? a->min_value / b->min_value : neg_inf(); + PrimExpr max_value = a->HasUpperBound() ? a->max_value / b->min_value : pos_inf(); return IntervalSet(min_value, max_value); } else if (analyzer->CanProveGreaterEqual(-b->min_value, 1)) { - Expr min_value = a->HasUpperBound() ? a->max_value / b->min_value : neg_inf(); - Expr max_value = a->HasLowerBound() ? a->min_value / b->min_value : pos_inf(); + PrimExpr min_value = a->HasUpperBound() ? a->max_value / b->min_value : neg_inf(); + PrimExpr max_value = a->HasLowerBound() ? a->min_value / b->min_value : pos_inf(); return IntervalSet(min_value, max_value); } else if (a->HasUpperBound() && a->HasLowerBound()) { using ir::SelectNode; - Expr sign = b->min_value >= make_zero(b->min_value.dtype().element_of()); - Expr e1 = a->min_value / b->min_value; - Expr e2 = a->max_value / b->min_value; + PrimExpr sign = b->min_value >= make_zero(b->min_value.dtype().element_of()); + PrimExpr e1 = a->min_value / b->min_value; + PrimExpr e2 = a->max_value / b->min_value; return IntervalSet(SelectNode::make(sign, e1, e2), SelectNode::make(sign, e2, e1)); } } @@ -235,7 +235,7 @@ inline IntervalSet Combine(Analyzer* analyzer, if (b->IsEmpty()) return b; if (b->IsSinglePoint()) { - const Expr& divisor = b->min_value; + const PrimExpr& divisor = b->min_value; if (is_zero(divisor)) { LOG(FATAL) << "Modular by zero in CombineInterval Mod"; } @@ -246,7 +246,7 @@ inline IntervalSet Combine(Analyzer* analyzer, if (analyzer->CanProveGreaterEqual(divisor, 0)) { return IntervalSet(make_zero(divisor.dtype()), divisor - 1); } else { - Expr bound = abs(divisor) - 1; + PrimExpr bound = abs(divisor) - 1; return IntervalSet(-bound, bound); } } @@ -271,18 +271,18 @@ inline IntervalSet Combine(Analyzer* analyzer, if (is_one(b->min_value)) return a; // no relaxation is needed in here due to set is inclusive if (analyzer->CanProveGreaterEqual(b->min_value, 0)) { - Expr min_value = a->HasLowerBound() ? floordiv(a->min_value, b->min_value) : neg_inf(); - Expr max_value = a->HasUpperBound() ? floordiv(a->max_value, b->min_value) : pos_inf(); + PrimExpr min_value = a->HasLowerBound() ? floordiv(a->min_value, b->min_value) : neg_inf(); + PrimExpr max_value = a->HasUpperBound() ? floordiv(a->max_value, b->min_value) : pos_inf(); return IntervalSet(min_value, max_value); } else if (analyzer->CanProveGreaterEqual(-b->min_value, 1)) { - Expr min_value = a->HasUpperBound() ? floordiv(a->max_value, b->min_value) : neg_inf(); - Expr max_value = a->HasLowerBound() ? floordiv(a->min_value, b->min_value) : pos_inf(); + PrimExpr min_value = a->HasUpperBound() ? floordiv(a->max_value, b->min_value) : neg_inf(); + PrimExpr max_value = a->HasLowerBound() ? floordiv(a->min_value, b->min_value) : pos_inf(); return IntervalSet(min_value, max_value); } else if (a->HasUpperBound() && a->HasLowerBound()) { using ir::SelectNode; - Expr sign = b->min_value >= make_zero(b->min_value.dtype().element_of()); - Expr e1 = floordiv(a->min_value, b->min_value); - Expr e2 = floordiv(a->max_value, b->min_value); + PrimExpr sign = b->min_value >= make_zero(b->min_value.dtype().element_of()); + PrimExpr e1 = floordiv(a->min_value, b->min_value); + PrimExpr e2 = floordiv(a->max_value, b->min_value); return IntervalSet(SelectNode::make(sign, e1, e2), SelectNode::make(sign, e2, e1)); } } @@ -301,14 +301,14 @@ inline IntervalSet Combine(Analyzer* analyzer, if (b->IsEmpty()) return b; if (b->IsSinglePoint()) { - const Expr& divisor = b->min_value; + const PrimExpr& divisor = b->min_value; if (is_zero(divisor)) { LOG(FATAL) << "Modular by zero in CombineInterval Mod"; } if (analyzer->CanProveGreaterEqual(divisor, 0)) { return IntervalSet(make_zero(divisor.dtype()), divisor - 1); } else { - Expr bound = abs(divisor) - 1; + PrimExpr bound = abs(divisor) - 1; return IntervalSet(-bound, bound); } } @@ -356,7 +356,7 @@ using namespace ir; // Simplified version of int set evaluator that operates on IntervalSet // We might use better set analysis in the future to replace the intervalset. class IntervalSetEvaluator : - public ExprFunctor { + public ExprFunctor { public: IntervalSetEvaluator(Analyzer* analyzer, const Map& dom_map, @@ -366,7 +366,7 @@ class IntervalSetEvaluator : eval_vec_(eval_vec) { } - IntervalSet Eval(const Expr& val) { + IntervalSet Eval(const PrimExpr& val) { return this->VisitExpr(val); } // evaluate and relax the set @@ -381,11 +381,11 @@ class IntervalSetEvaluator : } IntervalSet VisitExpr_(const IntImmNode* op) final { - return IntervalSet::SinglePoint(GetRef(op)); + return IntervalSet::SinglePoint(GetRef(op)); } IntervalSet VisitExpr_(const UIntImmNode* op) final { - return IntervalSet::SinglePoint(GetRef(op)); + return IntervalSet::SinglePoint(GetRef(op)); } IntervalSet VisitExpr_(const VarNode* op) final { @@ -492,7 +492,7 @@ class IntervalSetEvaluator : IntervalSet(make_const(t, vstride * op->lanes + 1), make_zero(t))); } } - DLOG(WARNING) << "cannot evaluate set on expression " << GetRef(op); + DLOG(WARNING) << "cannot evaluate set on expression " << GetRef(op); return IntervalSet::Everything(); } @@ -515,7 +515,7 @@ class IntervalSetEvaluator : private: // whether set is exactly single point that equals value. bool MatchPoint(const IntervalSet& set, - const Expr& value) const { + const PrimExpr& value) const { return set->min_value.same_as(value) && set->max_value.same_as(value); } @@ -524,7 +524,7 @@ class IntervalSetEvaluator : IntervalSet a = this->Eval(op->a); IntervalSet b = this->Eval(op->b); if (MatchPoint(a, op->a) && MatchPoint(b, op->b)) { - return IntervalSet::SinglePoint(GetRef(op)); + return IntervalSet::SinglePoint(GetRef(op)); } return Combine(analyzer_, a, b); } @@ -543,7 +543,7 @@ class IntSetAnalyzer::Impl { : analyzer_(analyzer) { } - IntSet Eval(const Expr& expr, const Map& dom_map) const { + IntSet Eval(const PrimExpr& expr, const Map& dom_map) const { return IntervalSetEvaluator(analyzer_, dom_map).Eval(expr); } @@ -559,7 +559,7 @@ IntSetAnalyzer::~IntSetAnalyzer() { delete impl_; } -IntSet IntSetAnalyzer::operator()(const Expr& expr, +IntSet IntSetAnalyzer::operator()(const PrimExpr& expr, const Map& dom_map) { return impl_->Eval(expr, dom_map); } @@ -577,13 +577,13 @@ Range IntSet::cover_range(Range max_range) const { return max_range; } -Expr IntSet::min() const { +PrimExpr IntSet::min() const { const IntervalSetNode* s_int = (*this).as(); CHECK(s_int); return s_int->min_value; } -Expr IntSet::max() const { +PrimExpr IntSet::max() const { const IntervalSetNode* s_int = (*this).as(); CHECK(s_int); return s_int->max_value; @@ -641,7 +641,7 @@ SignType IntSet::sign_type() const { return kUnknown; } } -Expr IntSet::point_value() const { +PrimExpr IntSet::point_value() const { const IntervalSetNode* s_int = (*this).as(); CHECK(s_int && s_int->IsSinglePoint()); return s_int->min_value; @@ -655,11 +655,11 @@ IntSet IntSet::everything() { return IntervalSet::Everything(); } -IntSet IntSet::single_point(Expr x) { +IntSet IntSet::single_point(PrimExpr x) { return IntervalSet::SinglePoint(x); } -IntSet IntSet::interval(Expr min, Expr max) { +IntSet IntSet::interval(PrimExpr min, PrimExpr max) { if (min.same_as(max)) { return IntSet::single_point(min); } @@ -667,7 +667,7 @@ IntSet IntSet::interval(Expr min, Expr max) { } // Range related code -inline bool ProveEqual(Expr lhs, Expr rhs) { +inline bool ProveEqual(PrimExpr lhs, PrimExpr rhs) { return is_zero(ir::Simplify(lhs - rhs)); } @@ -728,24 +728,24 @@ Map ConvertDomMap( return dmap; } -IntSet EvalSet(Expr e, +IntSet EvalSet(PrimExpr e, const Map& dom_map) { Analyzer ana; return IntervalSetEvaluator(&ana, dom_map, false).Eval(e); } -IntSet IntSet::vector(Expr x) { +IntSet IntSet::vector(PrimExpr x) { Analyzer ana; Map dmap; return IntervalSetEvaluator(&ana, dmap, true).Eval(x); } -IntSet EvalSet(Expr e, +IntSet EvalSet(PrimExpr e, const Map& dom_map) { return EvalSet(e, ConvertDomMap(dom_map)); } -IntSet EvalSet(Expr e, +IntSet EvalSet(PrimExpr e, const std::unordered_map& dom_map) { return EvalSet(e, ConvertDomMap(dom_map)); } @@ -755,7 +755,7 @@ IntSet EvalSet(Range r, Analyzer ana; IntervalSetEvaluator m(&ana, dom_map); // Simplifying first can give tighter bounds if r->min and r->extent share variables - Expr sum = r->min + r->extent - 1; + PrimExpr sum = r->min + r->extent - 1; auto res = m.Eval(IntervalSet(r->min, Simplify(sum))); return std::move(res); } @@ -771,9 +771,9 @@ IntSet EvalSet(IntSet s, auto dmap = ConvertDomMap(dom_map); IntervalSetEvaluator m(&ana, dmap); const IntervalSetNode* s_int = s.as(); - Expr vmax = s_int->HasUpperBound() ? + PrimExpr vmax = s_int->HasUpperBound() ? m.Eval(s_int->max_value).max() : s_int->max_value; - Expr vmin = s_int->HasLowerBound() ? + PrimExpr vmin = s_int->HasLowerBound() ? m.Eval(s_int->min_value).min() : s_int->min_value; return IntervalSet(vmin, vmax); } @@ -785,7 +785,7 @@ class SubExprIntervalSetEvaluator : public IntervalSetEvaluator { const Map& dom_map) : IntervalSetEvaluator(analyzer, dom_map) {} - IntervalSet VisitExpr(const Expr& n) final { + IntervalSet VisitExpr(const PrimExpr& n) final { IntervalSet ret = IntervalSetEvaluator::VisitExpr(n); expr_map[n] = ret; return ret; @@ -795,7 +795,7 @@ class SubExprIntervalSetEvaluator : public IntervalSetEvaluator { }; ExprIntSetMap EvalSetForEachSubExpr( - Expr e, + PrimExpr e, const std::unordered_map& dom_map) { Analyzer ana; auto dmap = ConvertDomMap(dom_map); diff --git a/src/arithmetic/int_set.h b/src/arithmetic/int_set.h index 2e072127b449..b28f1cb4d3a4 100644 --- a/src/arithmetic/int_set.h +++ b/src/arithmetic/int_set.h @@ -42,9 +42,9 @@ namespace arith { class IntervalSetNode : public IntSetNode { public: /*! \brief Minimum value in the interval. */ - Expr min_value; + PrimExpr min_value; /*! \brief Maximum value in the interval. */ - Expr max_value; + PrimExpr max_value; // visitor overload. void VisitAttrs(tvm::AttrVisitor* v) { @@ -90,14 +90,14 @@ class IntervalSet : public IntSet { * \param max_value The maximum value in the interval. * \return The created set. */ - TVM_DLL IntervalSet(Expr min_value, Expr max_value); + TVM_DLL IntervalSet(PrimExpr min_value, PrimExpr max_value); /*! * \brief Create an IntervalSet that represents a single point. * \param value The value to be represented. * \return The result set. */ - static IntervalSet SinglePoint(Expr value) { + static IntervalSet SinglePoint(PrimExpr value) { return IntervalSet(value, value); } /*! diff --git a/src/arithmetic/ir_mutator_with_analyzer.cc b/src/arithmetic/ir_mutator_with_analyzer.cc index 961c476113c6..1345e7e7a137 100644 --- a/src/arithmetic/ir_mutator_with_analyzer.cc +++ b/src/arithmetic/ir_mutator_with_analyzer.cc @@ -38,7 +38,7 @@ VisitStmt_(const ForNode* op) { Stmt IRMutatorWithAnalyzer:: VisitStmt_(const LetStmtNode* op) { - Expr value = this->VisitExpr(op->value); + PrimExpr value = this->VisitExpr(op->value); if (!ir::HasSideEffect(value)) { analyzer_->Bind(op->var, value); } @@ -58,7 +58,7 @@ VisitStmt_(const LetStmtNode* op) { Stmt IRMutatorWithAnalyzer:: VisitStmt_(const IfThenElseNode* op) { - Expr condition = this->VisitExpr(op->condition); + PrimExpr condition = this->VisitExpr(op->condition); Stmt then_case, else_case; { With ctx(analyzer_, condition); @@ -107,8 +107,8 @@ VisitStmt_(const AttrStmtNode* op) { Stmt IRMutatorWithAnalyzer:: VisitStmt_(const AssertStmtNode* op) { - Expr condition = this->VisitExpr(op->condition); - Expr message = this->VisitExpr(op->message); + PrimExpr condition = this->VisitExpr(op->condition); + PrimExpr message = this->VisitExpr(op->message); With ctx(analyzer_, condition); Stmt body = this->VisitStmt(op->body); @@ -125,12 +125,12 @@ VisitStmt_(const AssertStmtNode* op) { } } -Expr IRMutatorWithAnalyzer:: +PrimExpr IRMutatorWithAnalyzer:: VisitExpr_(const CallNode* op) { // add condition context to if_then_else if (op->is_intrinsic(ir::intrinsic::tvm_if_then_else)) { - Expr cond = this->VisitExpr(op->args[0]); - Expr true_value, false_value; + PrimExpr cond = this->VisitExpr(op->args[0]); + PrimExpr true_value, false_value; { With constraint(analyzer_, cond); true_value = this->VisitExpr(op->args[1]); @@ -149,7 +149,7 @@ VisitExpr_(const CallNode* op) { if (cond.same_as(op->args[0]) && true_value.same_as(op->args[1]) && false_value.same_as(op->args[2])) { - return GetRef(op); + return GetRef(op); } else { return CallNode::make(op->dtype, op->name, {cond, true_value, false_value}, @@ -159,27 +159,27 @@ VisitExpr_(const CallNode* op) { return StmtExprMutator::VisitExpr_(op); } -Expr IRMutatorWithAnalyzer:: +PrimExpr IRMutatorWithAnalyzer:: VisitExpr_(const LetNode* op) { - Expr value = this->VisitExpr(op->value); + PrimExpr value = this->VisitExpr(op->value); if (!ir::HasSideEffect(value)) { analyzer_->Bind(op->var, value); } // We keep the let-binding here // as sub-class may or maynot choose to replace it. - Expr body = this->VisitExpr(op->body); + PrimExpr body = this->VisitExpr(op->body); if (value.same_as(op->value) && body.same_as(op->body)) { - return GetRef(op); + return GetRef(op); } else { return LetNode::make(op->var, value, body); } } -Expr IRMutatorWithAnalyzer:: +PrimExpr IRMutatorWithAnalyzer:: VisitExpr_(const SelectNode* op) { - Expr cond = this->VisitExpr(op->condition); - Expr true_value, false_value; + PrimExpr cond = this->VisitExpr(op->condition); + PrimExpr true_value, false_value; { With constraint(analyzer_, cond); true_value = VisitExpr(op->true_value); @@ -199,13 +199,13 @@ VisitExpr_(const SelectNode* op) { if (cond.same_as(op->condition) && true_value.same_as(op->true_value) && false_value.same_as(op->false_value)) { - return GetRef(op); + return GetRef(op); } else { return SelectNode::make(cond, true_value, false_value); } } -Expr IRMutatorWithAnalyzer:: +PrimExpr IRMutatorWithAnalyzer:: VisitExpr_(const ReduceNode* op) { // Setup the domain information before simplification. for (const IterVar& iv : op->axis) { diff --git a/src/arithmetic/ir_mutator_with_analyzer.h b/src/arithmetic/ir_mutator_with_analyzer.h index 1e96c0a290af..a2297cb3c04e 100644 --- a/src/arithmetic/ir_mutator_with_analyzer.h +++ b/src/arithmetic/ir_mutator_with_analyzer.h @@ -54,10 +54,10 @@ class IRMutatorWithAnalyzer : public ir::StmtExprMutator { Stmt VisitStmt_(const ir::IfThenElseNode* op) override; Stmt VisitStmt_(const ir::AttrStmtNode* op) override; Stmt VisitStmt_(const ir::AssertStmtNode* op) override; - Expr VisitExpr_(const ir::LetNode* op) override; - Expr VisitExpr_(const ir::SelectNode* op) override; - Expr VisitExpr_(const ir::CallNode* op) override; - Expr VisitExpr_(const ir::ReduceNode* op) override; + PrimExpr VisitExpr_(const ir::LetNode* op) override; + PrimExpr VisitExpr_(const ir::SelectNode* op) override; + PrimExpr VisitExpr_(const ir::CallNode* op) override; + PrimExpr VisitExpr_(const ir::ReduceNode* op) override; protected: /*! \brief internal analyzer field. */ diff --git a/src/arithmetic/ir_visitor_with_analyzer.h b/src/arithmetic/ir_visitor_with_analyzer.h index 07ec1866eea7..08be59b8c423 100644 --- a/src/arithmetic/ir_visitor_with_analyzer.h +++ b/src/arithmetic/ir_visitor_with_analyzer.h @@ -34,7 +34,7 @@ namespace ir { class IRVisitorWithAnalyzer final : public StmtExprVisitor { public: - Expr Simplify(const Expr& expr) { + PrimExpr Simplify(const PrimExpr& expr) { return analyzer_.Simplify(expr); } diff --git a/src/arithmetic/modular_set.cc b/src/arithmetic/modular_set.cc index 8e2e065f1e7b..01dd2e8e499e 100644 --- a/src/arithmetic/modular_set.cc +++ b/src/arithmetic/modular_set.cc @@ -85,7 +85,7 @@ struct ModularSetAnalyzer::Entry { }; class ModularSetAnalyzer::Impl : - public ExprFunctor { + public ExprFunctor { public: explicit Impl(Analyzer* parent) : parent_(parent) {} @@ -107,7 +107,7 @@ class ModularSetAnalyzer::Impl : } // Detect useful constraints and use them in the analysis scope. - std::function EnterConstraint(const Expr& constraint) { + std::function EnterConstraint(const PrimExpr& constraint) { PVar var; PVar coeff, base; // pattern match interesting constraints @@ -168,7 +168,7 @@ class ModularSetAnalyzer::Impl : return Entry(coeff, a.base * b.base); } - Entry DivByConst(const Expr& lhs, + Entry DivByConst(const PrimExpr& lhs, int64_t val, bool round_down) { Entry a = VisitExpr(lhs); @@ -255,7 +255,7 @@ class ModularSetAnalyzer::Impl : /*! \brief pointer to parent. */ Analyzer* parent_{nullptr}; // internal variable map - std::unordered_map var_map_; + std::unordered_map var_map_; /*! * \brief Update var by intersecting entry with var's current set. * \param var The variable. @@ -398,7 +398,7 @@ class ModularSetAnalyzer::Impl : } }; -ModularSet ModularSetAnalyzer::operator()(const Expr& expr) { +ModularSet ModularSetAnalyzer::operator()(const PrimExpr& expr) { Entry ret = impl_->VisitExpr(expr); return ModularSet(ret.coeff, ret.base); } @@ -409,7 +409,7 @@ void ModularSetAnalyzer::Update(const Var& var, impl_->Update(var, info, override); } -std::function ModularSetAnalyzer::EnterConstraint(const Expr& constraint) { +std::function ModularSetAnalyzer::EnterConstraint(const PrimExpr& constraint) { return impl_->EnterConstraint(constraint); } diff --git a/src/arithmetic/pattern_match.h b/src/arithmetic/pattern_match.h index e964abb4d7be..733dcf41ce94 100644 --- a/src/arithmetic/pattern_match.h +++ b/src/arithmetic/pattern_match.h @@ -131,9 +131,9 @@ class PEqualChecker { }; template<> -class PEqualChecker { +class PEqualChecker { public: - bool operator()(const Expr& lhs, const Expr& rhs) const { + bool operator()(const PrimExpr& lhs, const PrimExpr& rhs) const { if (lhs.same_as(rhs)) return true; return ir::Equal(lhs, rhs); } @@ -260,10 +260,10 @@ class PBinaryExpr : } } - Expr Eval() const { - Expr lhs = a_.Eval(); - Expr rhs = b_.Eval(); - Expr ret = TryConstFold(lhs, rhs); + PrimExpr Eval() const { + PrimExpr lhs = a_.Eval(); + PrimExpr rhs = b_.Eval(); + PrimExpr ret = TryConstFold(lhs, rhs); if (ret.defined()) return ret; return NodeType::make(lhs, rhs); } @@ -290,7 +290,7 @@ class PConstWithTypeLike : } } - Expr Eval() const { + PrimExpr Eval() const { return make_const(ref_.Eval().dtype(), value_); } @@ -373,7 +373,7 @@ class PNotExpr : public Pattern > { } } - Expr Eval() const { + PrimExpr Eval() const { return ir::NotNode::make(value_.Eval()); } @@ -421,7 +421,7 @@ class PSelectExpr : } } - Expr Eval() const { + PrimExpr Eval() const { return ir::SelectNode::make( condition_.Eval(), true_value_.Eval(), false_value_.Eval()); } @@ -482,7 +482,7 @@ class PCastExpr : } } - Expr Eval() const { + PrimExpr Eval() const { return ir::CastNode::make(dtype_.Eval(), value_.Eval()); } @@ -541,7 +541,7 @@ class PRampExpr : } } - Expr Eval() const { + PrimExpr Eval() const { return ir::RampNode::make(base_.Eval(), stride_.Eval(), lanes_.Eval()); } @@ -602,7 +602,7 @@ class PBroadcastExpr : } } - Expr Eval() const { + PrimExpr Eval() const { return ir::BroadcastNode::make(value_.Eval(), lanes_.Eval()); } @@ -675,7 +675,7 @@ struct PCallExprMatchFunctor { }; struct PCallExprEvalArgsFunctor { - Array args_; + Array args_; template void operator()(size_t i, const T& pattern) { @@ -716,7 +716,7 @@ class PCallExpr : } } - Expr Eval() const { + PrimExpr Eval() const { detail::PCallExprEvalArgsFunctor feval_args; detail::tuple_for_each(feval_args, args_); return Op::Eval(feval_args.args_); @@ -729,7 +729,7 @@ class PCallExpr : // arithemetic intrinsics #define TVM_PATTERN_BINARY_INTRIN(FuncName, OpName, IntrinStr) \ struct OpName { \ - static Expr Eval(Array args) { \ + static PrimExpr Eval(Array args) { \ return ir::CallNode::make(args[0].dtype(), kName, args, \ ir::CallNode::PureIntrinsic); \ } \ @@ -750,7 +750,7 @@ TVM_PATTERN_BINARY_INTRIN(operator^, PBitwiseXorOp, "bitwise_xor"); // unary intrinsics #define TVM_PATTERN_UNARY_INTRIN(FuncName, OpName, IntrinStr) \ struct OpName { \ - static Expr Eval(Array args) { \ + static PrimExpr Eval(Array args) { \ return ir::CallNode::make(args[0].dtype(), kName, args, \ ir::CallNode::PureIntrinsic); \ } \ @@ -766,7 +766,7 @@ TVM_PATTERN_UNARY_INTRIN(operator~, PBitwiseNotOp, "bitwise_not"); // if_then_else struct PIfThenElseOp { - static Expr Eval(Array args) { + static PrimExpr Eval(Array args) { return ir::CallNode::make( args[1].dtype(), kName, args, ir::CallNode::PureIntrinsic); diff --git a/src/arithmetic/rewrite_simplify.cc b/src/arithmetic/rewrite_simplify.cc index 2421e10b6e37..94d951da51db 100644 --- a/src/arithmetic/rewrite_simplify.cc +++ b/src/arithmetic/rewrite_simplify.cc @@ -67,8 +67,8 @@ using namespace ir; // try to prove x equals val RewriteSimplifier::Impl::CompareResult RewriteSimplifier::Impl:: -TryCompare(const Expr& x, int64_t val) { - Expr diff = this->VisitExpr(x); +TryCompare(const PrimExpr& x, int64_t val) { + PrimExpr diff = this->VisitExpr(x); if (const auto* ptr = diff.as()) { if (ptr->value == val) { return kEQ; @@ -101,7 +101,7 @@ TryCompare(const Expr& x, int64_t val) { } void RewriteSimplifier::Impl:: -Update(const Var& var, const Expr& info, bool override) { +Update(const Var& var, const PrimExpr& info, bool override) { if (!override) { auto it = var_map_.find(var); if (it != var_map_.end()) { @@ -115,14 +115,14 @@ Update(const Var& var, const Expr& info, bool override) { var_map_[var] = info; } -Expr RewriteSimplifier::Impl:: +PrimExpr RewriteSimplifier::Impl:: VisitExpr_(const AddNode* op) { - Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op); + PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); - Expr const_res = TryConstFold(op->a, op->b); + PrimExpr const_res = TryConstFold(op->a, op->b); if (const_res.defined()) return const_res; // Pattern var to match any expression - PVar x, y, z, b1, b2, s1, s2; + PVar x, y, z, b1, b2, s1, s2; // Pattern var match IntImm PVar c1, c2, c3; // Pattern var for lanes in broadcast and ramp @@ -219,7 +219,7 @@ VisitExpr_(const AddNode* op) { return ret; } -std::function RewriteSimplifier::Impl::EnterConstraint(const Expr& constraint) { +std::function RewriteSimplifier::Impl::EnterConstraint(const PrimExpr& constraint) { size_t old_literal_size = literal_constraints_.size(); literal_constraints_.push_back(constraint); size_t new_literal_size = literal_constraints_.size(); @@ -230,14 +230,14 @@ std::function RewriteSimplifier::Impl::EnterConstraint(const Expr& const return frecover; } -Expr RewriteSimplifier::Impl:: +PrimExpr RewriteSimplifier::Impl:: VisitExpr_(const SubNode* op) { - Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op); + PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); - Expr const_res = TryConstFold(op->a, op->b); + PrimExpr const_res = TryConstFold(op->a, op->b); if (const_res.defined()) return const_res; // Pattern var to match any expression - PVar x, y, z, b1, b2, s1, s2; + PVar x, y, z, b1, b2, s1, s2; // Pattern var match IntImm PVar c1, c2, c3; // Pattern var for lanes in broadcast and ramp @@ -429,14 +429,14 @@ VisitExpr_(const SubNode* op) { return ret; } -Expr RewriteSimplifier::Impl:: +PrimExpr RewriteSimplifier::Impl:: VisitExpr_(const MulNode* op) { - Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op); + PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); - Expr const_res = TryConstFold(op->a, op->b); + PrimExpr const_res = TryConstFold(op->a, op->b); if (const_res.defined()) return const_res; // Pattern var to match any expression - PVar x, y, z, b1, b2, s1, s2; + PVar x, y, z, b1, b2, s1, s2; // Pattern var match IntImm PVar c1, c2; // Pattern var for lanes in broadcast and ramp @@ -468,14 +468,14 @@ VisitExpr_(const MulNode* op) { return ret; } -Expr RewriteSimplifier::Impl:: +PrimExpr RewriteSimplifier::Impl:: VisitExpr_(const DivNode* op) { - Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op); + PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); - Expr const_res = TryConstFold(op->a, op->b); + PrimExpr const_res = TryConstFold(op->a, op->b); if (const_res.defined()) return const_res; // Pattern var to match any expression - PVar x, y, z, b1; + PVar x, y, z, b1; // Pattern var match IntImm PVar c1, c2, c3; // Pattern var for lanes in broadcast and ramp @@ -690,15 +690,15 @@ VisitExpr_(const DivNode* op) { return ret; } -Expr RewriteSimplifier::Impl:: +PrimExpr RewriteSimplifier::Impl:: VisitExpr_(const ModNode* op) { - Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op); + PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); - Expr const_res = TryConstFold(op->a, op->b); + PrimExpr const_res = TryConstFold(op->a, op->b); if (const_res.defined()) return const_res; // Pattern var to match any expression - PVar x, y, z, b1; + PVar x, y, z, b1; // Pattern var match IntImm PVar c1, c2; // Pattern var for lanes in broadcast and ramp @@ -763,7 +763,7 @@ VisitExpr_(const ModNode* op) { // NOTE: trunc div required TVM_TRY_RECURSIVE_REWRITE_IF( truncmod(x, c1), - truncmod(x, PConst(make_const(op->dtype, -c1.Eval()->value))), + truncmod(x, PConst(make_const(op->dtype, -c1.Eval()->value))), c1.Eval()->value < 0); // try modular analysis @@ -780,14 +780,14 @@ VisitExpr_(const ModNode* op) { return ret; } -Expr RewriteSimplifier::Impl:: +PrimExpr RewriteSimplifier::Impl:: VisitExpr_(const FloorDivNode* op) { - Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op); + PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); - Expr const_res = TryConstFold(op->a, op->b); + PrimExpr const_res = TryConstFold(op->a, op->b); if (const_res.defined()) return const_res; // Pattern var to match any expression - PVar x, y, z, b1; + PVar x, y, z, b1; // Pattern var match IntImm PVar c1, c2, c3; // Pattern var for lanes in broadcast and ramp @@ -924,15 +924,15 @@ VisitExpr_(const FloorDivNode* op) { return ret; } -Expr RewriteSimplifier::Impl:: +PrimExpr RewriteSimplifier::Impl:: VisitExpr_(const FloorModNode* op) { - Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op); + PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); - Expr const_res = TryConstFold(op->a, op->b); + PrimExpr const_res = TryConstFold(op->a, op->b); if (const_res.defined()) return const_res; // Pattern var to match any expression - PVar x, y, z, b1; + PVar x, y, z, b1; // Pattern var match IntImm PVar c1, c2; // Pattern var for lanes in broadcast and ramp @@ -994,15 +994,15 @@ VisitExpr_(const FloorModNode* op) { return ret; } -Expr RewriteSimplifier::Impl:: +PrimExpr RewriteSimplifier::Impl:: VisitExpr_(const MinNode* op) { - Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op); + PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); - Expr const_res = TryConstFold(op->a, op->b); + PrimExpr const_res = TryConstFold(op->a, op->b); if (const_res.defined()) return const_res; // Pattern var to match any expression - PVar x, y, z, s1, s2; + PVar x, y, z, s1, s2; // Pattern var match IntImm PVar c1, c2; PVar lanes; @@ -1179,15 +1179,15 @@ VisitExpr_(const MinNode* op) { return ret; } -Expr RewriteSimplifier::Impl:: +PrimExpr RewriteSimplifier::Impl:: VisitExpr_(const MaxNode* op) { - Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op); + PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); - Expr const_res = TryConstFold(op->a, op->b); + PrimExpr const_res = TryConstFold(op->a, op->b); if (const_res.defined()) return const_res; // Pattern var to match any expression - PVar x, y, z, s1, s2; + PVar x, y, z, s1, s2; // Pattern var match IntImm PVar c1, c2; PVar lanes; @@ -1352,15 +1352,15 @@ VisitExpr_(const MaxNode* op) { return ret; } -Expr RewriteSimplifier::Impl:: +PrimExpr RewriteSimplifier::Impl:: VisitExpr_(const EQNode* op) { - Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op); + PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); - Expr const_res = TryConstFold(op->a, op->b); + PrimExpr const_res = TryConstFold(op->a, op->b); if (const_res.defined()) return const_res; // Pattern var to match any expression - PVar x, y; + PVar x, y; // Pattern var match IntImm PVar c1; PVar lanes; @@ -1386,35 +1386,35 @@ VisitExpr_(const EQNode* op) { return ret; } -Expr RewriteSimplifier::Impl:: +PrimExpr RewriteSimplifier::Impl:: VisitExpr_(const NENode* op) { return this->VisitExpr(NotNode::make(op->a == op->b)); } -Expr RewriteSimplifier::Impl:: +PrimExpr RewriteSimplifier::Impl:: VisitExpr_(const LENode* op) { return this->VisitExpr(NotNode::make(op->b < op->a)); } -Expr RewriteSimplifier::Impl:: +PrimExpr RewriteSimplifier::Impl:: VisitExpr_(const GTNode* op) { return this->VisitExpr(op->b < op->a); } -Expr RewriteSimplifier::Impl:: +PrimExpr RewriteSimplifier::Impl:: VisitExpr_(const GENode* op) { return this->VisitExpr(NotNode::make(op->a < op->b)); } -Expr RewriteSimplifier::Impl:: +PrimExpr RewriteSimplifier::Impl:: VisitExpr_(const LTNode* op) { - Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op); + PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); - Expr const_res = TryConstFold(op->a, op->b); + PrimExpr const_res = TryConstFold(op->a, op->b); if (const_res.defined()) return const_res; // Pattern var to match any expression - PVar x, y, z, s1, s2; + PVar x, y, z, s1, s2; // Pattern var match IntImm PVar c1, c2; PVar lanes; @@ -1562,14 +1562,14 @@ VisitExpr_(const LTNode* op) { return ret; } -Expr RewriteSimplifier::Impl:: +PrimExpr RewriteSimplifier::Impl:: VisitExpr_(const NotNode* op) { - Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op); + PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); - Expr const_res = TryConstFold(op->a); + PrimExpr const_res = TryConstFold(op->a); if (const_res.defined()) return const_res; // Pattern var to match any expression - PVar x, y; + PVar x, y; PVar lanes; if (op->dtype.lanes() != 1) { TVM_TRY_REWRITE(!broadcast(x, lanes), broadcast(!x, lanes)); @@ -1587,15 +1587,15 @@ VisitExpr_(const NotNode* op) { return ret; } -Expr RewriteSimplifier::Impl:: +PrimExpr RewriteSimplifier::Impl:: VisitExpr_(const AndNode* op) { - Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op); + PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); - Expr const_res = TryConstFold(op->a, op->b); + PrimExpr const_res = TryConstFold(op->a, op->b); if (const_res.defined()) return const_res; // Pattern var to match any expression - PVar x, y; + PVar x, y; // Pattern var match IntImm PVar c1, c2; PVar lanes; @@ -1605,7 +1605,7 @@ VisitExpr_(const AndNode* op) { broadcast(x && y, lanes)); } - auto cfalse = PConst(make_const(op->dtype, false)); + auto cfalse = PConst(make_const(op->dtype, false)); TVM_TRY_REWRITE(x == y && x != y, cfalse); TVM_TRY_REWRITE(x != y && x == y, cfalse); TVM_TRY_REWRITE(x && !x, cfalse); @@ -1636,15 +1636,15 @@ VisitExpr_(const AndNode* op) { return ret; } -Expr RewriteSimplifier::Impl:: +PrimExpr RewriteSimplifier::Impl:: VisitExpr_(const OrNode* op) { - Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op); + PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); - Expr const_res = TryConstFold(op->a, op->b); + PrimExpr const_res = TryConstFold(op->a, op->b); if (const_res.defined()) return const_res; // Pattern var to match any expression - PVar x, y; + PVar x, y; // Pattern var match IntImm PVar c1, c2; PVar lanes; @@ -1654,7 +1654,7 @@ VisitExpr_(const OrNode* op) { broadcast(x || y, lanes)); } - auto ctrue = PConst(make_const(op->dtype, true)); + auto ctrue = PConst(make_const(op->dtype, true)); TVM_TRY_REWRITE(x == y || x != y, ctrue); TVM_TRY_REWRITE(x != y || x == y, ctrue); @@ -1686,21 +1686,21 @@ VisitExpr_(const OrNode* op) { return ret; } -Expr RewriteSimplifier::Impl:: +PrimExpr RewriteSimplifier::Impl:: VisitExpr_(const SelectNode* op) { - Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op); + PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); if (op == nullptr) return ret; // Pattern var to match any expression - PVar x, y; + PVar x, y; TVM_TRY_REWRITE(select(x, y, y), y); return ret; } -Expr RewriteSimplifier::Impl:: +PrimExpr RewriteSimplifier::Impl:: VisitExpr_(const CallNode* op) { // add condition context to if_then_else - Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op); + PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); if (op == nullptr) return ret; if (op->is_intrinsic(CallNode::likely) && is_const(op->args[0])) { @@ -1727,47 +1727,47 @@ VisitExpr_(const CallNode* op) { return ret; } -Expr RewriteSimplifier::Impl:: +PrimExpr RewriteSimplifier::Impl:: VisitExpr_(const VarNode* op) { Var var = GetRef(op); auto it = var_map_.find(var); if (it != var_map_.end()) { return it->second; } - return GetRef(op); + return GetRef(op); } -Expr RewriteSimplifier::Impl:: +PrimExpr RewriteSimplifier::Impl:: VisitExpr_(const CastNode* op) { - Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op); + PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); return cast(op->dtype, op->value); } -Expr RewriteSimplifier::Impl:: +PrimExpr RewriteSimplifier::Impl:: VisitExpr_(const LetNode* op) { - Expr value = this->VisitExpr(op->value); + PrimExpr value = this->VisitExpr(op->value); if (!ir::HasSideEffect(value)) { // it is fine to discard the let binding // because the value will always be inlined in the simplifier. analyzer_->Bind(op->var, value); return this->VisitExpr(op->body); } - Expr body = this->VisitExpr(op->body); + PrimExpr body = this->VisitExpr(op->body); if (value.same_as(op->value) && body.same_as(op->body)) { - return GetRef(op); + return GetRef(op); } else { return LetNode::make(op->var, value, body); } } -Expr RewriteSimplifier::operator()(const Expr& expr) { +PrimExpr RewriteSimplifier::operator()(const PrimExpr& expr) { // Run simplification in post order - Expr res = expr; + PrimExpr res = expr; int max_iter = 2; for (int i = 0; i < max_iter; ++i) { - Expr new_expr = impl_->operator()(res); + PrimExpr new_expr = impl_->operator()(res); if (new_expr.same_as(res)) return res; res = new_expr; } @@ -1775,12 +1775,12 @@ Expr RewriteSimplifier::operator()(const Expr& expr) { } void RewriteSimplifier::Update(const Var& var, - const Expr& info, + const PrimExpr& info, bool override) { impl_->Update(var, info, override); } -std::function RewriteSimplifier::EnterConstraint(const Expr& constraint) { +std::function RewriteSimplifier::EnterConstraint(const PrimExpr& constraint) { return impl_->EnterConstraint(constraint); } diff --git a/src/arithmetic/rewrite_simplify.h b/src/arithmetic/rewrite_simplify.h index f2659a95ab0a..6b4193c3ff59 100644 --- a/src/arithmetic/rewrite_simplify.h +++ b/src/arithmetic/rewrite_simplify.h @@ -49,32 +49,32 @@ class RewriteSimplifier::Impl : public IRMutatorWithAnalyzer { explicit Impl(Analyzer* parent) : IRMutatorWithAnalyzer(parent) {} - void Update(const Var& var, const Expr& info, bool override_info); - Expr VisitExpr_(const AddNode* op) override; - Expr VisitExpr_(const SubNode* op) override; - Expr VisitExpr_(const MulNode* op) override; - Expr VisitExpr_(const DivNode* op) override; - Expr VisitExpr_(const ModNode* op) override; - Expr VisitExpr_(const FloorDivNode* op) override; - Expr VisitExpr_(const FloorModNode* op) override; - Expr VisitExpr_(const MinNode* op) override; - Expr VisitExpr_(const MaxNode* op) override; - Expr VisitExpr_(const EQNode* op) override; - Expr VisitExpr_(const NENode* op) override; - Expr VisitExpr_(const LTNode* op) override; - Expr VisitExpr_(const LENode* op) override; - Expr VisitExpr_(const GTNode* op) override; - Expr VisitExpr_(const GENode* op) override; - Expr VisitExpr_(const AndNode* op) override; - Expr VisitExpr_(const OrNode* op) override; - Expr VisitExpr_(const NotNode* op) override; - Expr VisitExpr_(const SelectNode* op) override; - Expr VisitExpr_(const CallNode* op) override; - Expr VisitExpr_(const VarNode* op) override; - Expr VisitExpr_(const CastNode* op) override; - Expr VisitExpr_(const LetNode* op) override; - - std::function EnterConstraint(const Expr& constraint); + void Update(const Var& var, const PrimExpr& info, bool override_info); + PrimExpr VisitExpr_(const AddNode* op) override; + PrimExpr VisitExpr_(const SubNode* op) override; + PrimExpr VisitExpr_(const MulNode* op) override; + PrimExpr VisitExpr_(const DivNode* op) override; + PrimExpr VisitExpr_(const ModNode* op) override; + PrimExpr VisitExpr_(const FloorDivNode* op) override; + PrimExpr VisitExpr_(const FloorModNode* op) override; + PrimExpr VisitExpr_(const MinNode* op) override; + PrimExpr VisitExpr_(const MaxNode* op) override; + PrimExpr VisitExpr_(const EQNode* op) override; + PrimExpr VisitExpr_(const NENode* op) override; + PrimExpr VisitExpr_(const LTNode* op) override; + PrimExpr VisitExpr_(const LENode* op) override; + PrimExpr VisitExpr_(const GTNode* op) override; + PrimExpr VisitExpr_(const GENode* op) override; + PrimExpr VisitExpr_(const AndNode* op) override; + PrimExpr VisitExpr_(const OrNode* op) override; + PrimExpr VisitExpr_(const NotNode* op) override; + PrimExpr VisitExpr_(const SelectNode* op) override; + PrimExpr VisitExpr_(const CallNode* op) override; + PrimExpr VisitExpr_(const VarNode* op) override; + PrimExpr VisitExpr_(const CastNode* op) override; + PrimExpr VisitExpr_(const LetNode* op) override; + + std::function EnterConstraint(const PrimExpr& constraint); protected: /*! \brief internal structure for comparison. */ @@ -90,9 +90,9 @@ class RewriteSimplifier::Impl : public IRMutatorWithAnalyzer { // counter to record recursive rewrite depth. int recur_depth_{0}; // internal variable map - std::unordered_map var_map_; + std::unordered_map var_map_; - std::vector literal_constraints_; + std::vector literal_constraints_; // maximum number of recursion allowed during a single pass. static const constexpr int kMaxRecurDepth = 5; @@ -103,15 +103,15 @@ class RewriteSimplifier::Impl : public IRMutatorWithAnalyzer { * \param val The constant value. * \return comparison result. */ - CompareResult TryCompare(const Expr& x, int64_t val); + CompareResult TryCompare(const PrimExpr& x, int64_t val); private: // Whether x >= val - bool CanProveGreaterEqual(const Expr& x, int64_t val) { + bool CanProveGreaterEqual(const PrimExpr& x, int64_t val) { return analyzer_->CanProveGreaterEqual(x, val); } // Whether x == val - bool CanProveEqual(const Expr& x, int64_t val) { + bool CanProveEqual(const PrimExpr& x, int64_t val) { // TODO(tqchen) refer back to super-analyzer. return TryCompare(x, val) == kEQ; } @@ -119,10 +119,10 @@ class RewriteSimplifier::Impl : public IRMutatorWithAnalyzer { // Recursive rewrite x // we limit maximum depth of recursive rewrite allowed to // avoid infinite loop - Expr RecursiveRewrite(const Expr& x) { + PrimExpr RecursiveRewrite(const PrimExpr& x) { if (recur_depth_ >= kMaxRecurDepth) return x; ++recur_depth_; - Expr res = this->VisitExpr(x); + PrimExpr res = this->VisitExpr(x); --recur_depth_; return res; } diff --git a/src/arithmetic/stmt_simplify.cc b/src/arithmetic/stmt_simplify.cc index 73b5dce31c4e..dcc7e5dfbb53 100644 --- a/src/arithmetic/stmt_simplify.cc +++ b/src/arithmetic/stmt_simplify.cc @@ -42,7 +42,7 @@ class StmtSimplifier : public IRMutatorWithAnalyzer { using Parent::VisitStmt; using Parent::VisitStmt_; - Expr VisitExpr(const Expr& expr) final { + PrimExpr VisitExpr(const PrimExpr& expr) final { return analyzer_->Simplify(expr); } @@ -58,7 +58,7 @@ class StmtSimplifier : public IRMutatorWithAnalyzer { } Stmt VisitStmt_(const LetStmtNode* op) { - Expr value = this->VisitExpr(op->value); + PrimExpr value = this->VisitExpr(op->value); if (!ir::HasSideEffect(value)) { // it is fine to discard the let binding // because the call to simplify will always inline the var. @@ -103,7 +103,7 @@ Stmt CanonicalSimplify(Stmt stmt, Map vrange) { return arith::StmtSimplifier(&analyzer).Simplify(std::move(stmt)); } -Expr CanonicalSimplify(Expr expr, Map vrange) { +PrimExpr CanonicalSimplify(PrimExpr expr, Map vrange) { arith::Analyzer analyzer; for (auto kv : vrange) { analyzer.Bind(kv.first, kv.second); @@ -111,7 +111,7 @@ Expr CanonicalSimplify(Expr expr, Map vrange) { return analyzer.canonical_simplify(expr); } -Expr Simplify(Expr expr, Map vrange) { +PrimExpr Simplify(PrimExpr expr, Map vrange) { arith::Analyzer analyzer; for (auto kv : vrange) { analyzer.Bind(kv.first, kv.second); diff --git a/src/autotvm/feature_visitor.cc b/src/autotvm/feature_visitor.cc index 11452d31e44b..a83d248dc0df 100644 --- a/src/autotvm/feature_visitor.cc +++ b/src/autotvm/feature_visitor.cc @@ -60,7 +60,7 @@ void FeatureVisitor::VisitStmt_(const ForNode* op) { void FeatureVisitor::VisitStmt_(const AttrStmtNode* op) { if (op->attr_key == attr::thread_extent || op->attr_key == attr::virtual_thread) { - VarExpr var = op->node.as()->var; + Var var = op->node.as()->var; const auto *extent = op->value.as(); CHECK(extent); diff --git a/src/autotvm/feature_visitor.h b/src/autotvm/feature_visitor.h index 9f65fb48aa92..b2ea80f0c29f 100644 --- a/src/autotvm/feature_visitor.h +++ b/src/autotvm/feature_visitor.h @@ -69,7 +69,7 @@ class FeatureVisitor : public StmtExprVisitor { * \param ann_type The type for the for loop * \return skip Whether skip this node */ - virtual bool EnterItervar_(tvm::VarExpr var, int64_t length, AnnotationType ann_type) = 0; + virtual bool EnterItervar_(tvm::Var var, int64_t length, AnnotationType ann_type) = 0; /*! \brief Exit a for loop subtree */ virtual void ExitItervar_() = 0; /*! @@ -77,7 +77,7 @@ class FeatureVisitor : public StmtExprVisitor { * \param buffer_var The buffer to access. * \param index Index expression */ - virtual void EnterMem_(tvm::VarExpr buffer_var, tvm::Expr index) = 0; + virtual void EnterMem_(tvm::Var buffer_var, tvm::PrimExpr index) = 0; /*! \brief Exit a memory access node */ virtual void ExitMem_() = 0; }; diff --git a/src/autotvm/touch_extractor.cc b/src/autotvm/touch_extractor.cc index 0ee4b1196c10..cf138edd494e 100644 --- a/src/autotvm/touch_extractor.cc +++ b/src/autotvm/touch_extractor.cc @@ -46,7 +46,7 @@ int ParallelLevel(AnnotationType ann) { // get touch pattern from index expression class IndexParser: public ExprVisitor { public: - void Parse(Expr expr) { + void Parse(PrimExpr expr) { pattern_map.clear(); this->VisitExpr(expr); } @@ -76,7 +76,7 @@ class IndexParser: public ExprVisitor { }; // extract iter vars and their touch pattern from ir -bool TouchExtractor::EnterItervar_(VarExpr var, int64_t length, AnnotationType ann_type) { +bool TouchExtractor::EnterItervar_(Var var, int64_t length, AnnotationType ann_type) { // do not insert duplicated occurrences of virtual thread if (ann_type == kVirtualThread && itervar_map.count(var) != 0) { skip_stack_size_.push_back(itervar_stack_.size()); @@ -90,7 +90,7 @@ bool TouchExtractor::EnterItervar_(VarExpr var, int64_t length, AnnotationType a // these happens when we create tvm.thread_axis("threadIdx.x") once and // bind it twice. Here we treat them as two axes // so we create a snapshot for the old one and freeze it - VarExpr old = VarExpr(var.get()->name_hint); + Var old = Var(var.get()->name_hint); itervar_map.insert({old, itervar_map[var]}); itervar_map.erase(var); } @@ -110,7 +110,7 @@ void TouchExtractor::ExitItervar_() { skip_stack_size_.pop_back(); return; } - VarExpr var = itervar_stack_.back(); + Var var = itervar_stack_.back(); // update count and reuse ratio for upper iter vars (includes self) for (auto kv : itervar_map[var].touch_feature) { @@ -169,7 +169,7 @@ void TouchExtractor::ExitItervar_() { } } -void TouchExtractor::EnterMem_(VarExpr buffer_var, Expr index) { +void TouchExtractor::EnterMem_(Var buffer_var, PrimExpr index) { std::string name = buffer_var.get()->name_hint; TouchedBuffer buf = name + "_" + std::to_string(buffer_counter_[name]++); @@ -219,17 +219,17 @@ void TouchExtractor::ExitMem_() { * \note If you want to flatten these features as the input of your model, * You can use the faster one GetItervarFeatureFlatten below. */ -void GetItervarFeature(Stmt stmt, bool take_log, Array > > *ret_feature) { +void GetItervarFeature(Stmt stmt, bool take_log, Array > > *ret_feature) { // extract TouchExtractor touch_analyzer; touch_analyzer.Analyze(stmt); // sort according to order - std::vector vars; + std::vector vars; for (auto kv : touch_analyzer.itervar_map) { vars.push_back(kv.first); } - std::sort(vars.begin(), vars.end(), [&](const VarExpr &lhs, const VarExpr &rhs) -> bool { + std::sort(vars.begin(), vars.end(), [&](const Var &lhs, const Var &rhs) -> bool { return touch_analyzer.itervar_map[lhs].order < touch_analyzer.itervar_map[rhs].order; }); @@ -250,11 +250,11 @@ void GetItervarFeature(Stmt stmt, bool take_log, Array > > *re // serialize for front end for (auto var : vars) { - Array > feature_row; + Array > feature_row; ItervarFeature &fea = touch_analyzer.itervar_map[var]; - feature_row.push_back(Array{std::string("_itervar_"), var}); + feature_row.push_back(Array{std::string("_itervar_"), var}); - Array attr{std::string("_attr_"), + Array attr{std::string("_attr_"), FloatImmNode::make(DataType::Float(32), trans(fea.length)), IntImmNode::make(DataType::Int(32), fea.nest_level), FloatImmNode::make(DataType::Float(32), trans(fea.topdown_product)), @@ -267,7 +267,7 @@ void GetItervarFeature(Stmt stmt, bool take_log, Array > > *re feature_row.push_back(attr); // arithmetic - feature_row.push_back(Array{std::string("_arith_"), + feature_row.push_back(Array{std::string("_arith_"), FloatImmNode::make(DataType::Float(32), trans(fea.add_ct)), FloatImmNode::make(DataType::Float(32), trans(fea.mul_ct)), FloatImmNode::make(DataType::Float(32), trans(fea.div_ct)), @@ -282,7 +282,7 @@ void GetItervarFeature(Stmt stmt, bool take_log, Array > > *re for (auto k : bufs) { TouchPattern &v = fea.touch_feature[k]; feature_row.push_back( - Array{k, + Array{k, FloatImmNode::make(DataType::Float(32), trans(v.stride)), FloatImmNode::make(DataType::Float(32), trans(v.mod)), FloatImmNode::make(DataType::Float(32), trans(v.count)), @@ -311,11 +311,11 @@ void GetItervarFeatureFlatten(Stmt stmt, bool take_log, std::vector *ret_ touch_analyzer.Analyze(stmt); // sort according to order - std::vector vars; + std::vector vars; for (auto kv : touch_analyzer.itervar_map) { vars.push_back(kv.first); } - std::sort(vars.begin(), vars.end(), [&](const VarExpr &lhs, const VarExpr &rhs) -> bool { + std::sort(vars.begin(), vars.end(), [&](const Var &lhs, const Var &rhs) -> bool { return touch_analyzer.itervar_map[lhs].order < touch_analyzer.itervar_map[rhs].order; }); @@ -383,11 +383,11 @@ void GetCurveSampleFeatureFlatten(Stmt stmt, int sample_n, std::vector *r touch_ext.Analyze(stmt); // sort according to order - std::vector vars; + std::vector vars; for (auto kv : touch_ext.itervar_map) { vars.push_back(kv.first); } - std::sort(vars.begin(), vars.end(), [&](const VarExpr &lhs, const VarExpr &rhs) -> bool { + std::sort(vars.begin(), vars.end(), [&](const Var &lhs, const Var &rhs) -> bool { return touch_ext.itervar_map[lhs].order < touch_ext.itervar_map[rhs].order; }); @@ -490,7 +490,7 @@ TVM_REGISTER_GLOBAL("autotvm.feature.GetItervarFeature") .set_body([](TVMArgs args, TVMRetValue *ret) { Stmt stmt = args[0]; bool take_log = args[1]; - Array > > ret_feature; + Array > > ret_feature; GetItervarFeature(stmt, take_log, &ret_feature); diff --git a/src/autotvm/touch_extractor.h b/src/autotvm/touch_extractor.h index 5265aad9df06..3af368dca6f4 100644 --- a/src/autotvm/touch_extractor.h +++ b/src/autotvm/touch_extractor.h @@ -56,7 +56,7 @@ struct TouchPattern { // all the feature of an iter var struct ItervarFeature { - ItervarFeature(VarExpr var, + ItervarFeature(Var var, int64_t extent, int nest, AnnotationType ann_type, @@ -122,18 +122,18 @@ class TouchExtractor : public FeatureVisitor { FeatureVisitor::VisitExpr_(op); } - std::unordered_map itervar_map; + std::unordered_map itervar_map; private: - bool EnterItervar_(VarExpr var, int64_t length, AnnotationType ann_type); + bool EnterItervar_(Var var, int64_t length, AnnotationType ann_type); void ExitItervar_(); - void EnterMem_(VarExpr buffer_var, Expr index); + void EnterMem_(Var buffer_var, PrimExpr index); void ExitMem_(); int64_t topdown_product_{1}; std::map buffer_counter_; size_t itervar_counter_{0}; - std::deque itervar_stack_; // use deque instead of stack for indexing + std::deque itervar_stack_; // use deque instead of stack for indexing std::deque skip_stack_size_; using FeatureVisitor::VisitExpr_; diff --git a/src/codegen/build_module.cc b/src/codegen/build_module.cc index 77b1c9d33faf..9f793424d233 100644 --- a/src/codegen/build_module.cc +++ b/src/codegen/build_module.cc @@ -338,7 +338,7 @@ Target DefaultTargetHost(Target target) { } } -Buffer BufferWithOffsetAlignment(Array shape, +Buffer BufferWithOffsetAlignment(Array shape, DataType dtype, std::string name, int data_alignment, @@ -356,14 +356,14 @@ Buffer BufferWithOffsetAlignment(Array shape, } BufferType buffer_type = has_any ? kAutoBroadcast : kDefault; - Expr elem_offset; + PrimExpr elem_offset; if (offset_factor != 0) { elem_offset = Var(name + "_elem_offset", shape[0].dtype()); } else { - elem_offset = Expr(); + elem_offset = PrimExpr(); } - return BufferNode::make(data, dtype, shape, Array(), elem_offset, name, "", + return BufferNode::make(data, dtype, shape, Array(), elem_offset, name, "", data_alignment, offset_factor, buffer_type); } @@ -855,7 +855,7 @@ TVM_REGISTER_GLOBAL("_GenericFuncRegisterFunc") GenericFunc generic_func = args[0]; // Intentionally copy and not de-allocate it, to avoid free pyobject during shutdown PackedFunc* func = new PackedFunc(args[1].operator PackedFunc()); - Array tags = args[2]; + Array tags = args[2]; bool allow_override = args[3]; std::vector tags_vector; diff --git a/src/codegen/codegen_c.cc b/src/codegen/codegen_c.cc index ea9d7bab2f70..777ad6203008 100644 --- a/src/codegen/codegen_c.cc +++ b/src/codegen/codegen_c.cc @@ -121,7 +121,7 @@ std::string CodeGenC::Finish() { return decl_stream.str() + stream.str(); } -void CodeGenC::PrintExpr(const Expr& n, std::ostream& os) { // NOLINT(*) +void CodeGenC::PrintExpr(const PrimExpr& n, std::ostream& os) { // NOLINT(*) if (print_ssa_form_) { std::ostringstream temp; VisitExpr(n, temp); @@ -146,7 +146,7 @@ void CodeGenC::PrintSSAAssign( // Print a reference expression to a buffer. std::string CodeGenC::GetBufferRef( - DataType t, const VarNode* buffer, Expr index) { + DataType t, const VarNode* buffer, PrimExpr index) { std::ostringstream os; std::string vid = GetVarID(buffer); std::string scope; @@ -213,7 +213,7 @@ std::string CodeGenC::GetBufferRef( // Print a reference expression to a buffer. std::string CodeGenC::GetStructRef( - DataType t, const Expr& buffer, const Expr& index, int kind) { + DataType t, const PrimExpr& buffer, const PrimExpr& index, int kind) { if (kind < intrinsic::kArrKindBound_) { std::ostringstream os; os << "(((TVMArray*)"; @@ -296,12 +296,12 @@ void CodeGenC::PrintVecElemStore(const std::string& vec, } std::string CodeGenC::GetVecLoad( - DataType t, const VarNode* buffer, Expr base) { + DataType t, const VarNode* buffer, PrimExpr base) { return GetBufferRef(t, buffer, base); } void CodeGenC::PrintVecStore(const VarNode* buffer, - DataType t, Expr base, + DataType t, PrimExpr base, const std::string& value) { std::string ref = GetBufferRef(t, buffer, base); this->PrintIndent(); @@ -594,7 +594,7 @@ void CodeGenC::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*) void CodeGenC::PrintVecBinaryOp( const std::string& op, DataType t, - Expr lhs, Expr rhs, std::ostream& os) { // NOLINT(*) + PrimExpr lhs, PrimExpr rhs, std::ostream& os) { // NOLINT(*) if (isalpha(op[0])) { os << op << "("; this->PrintExpr(lhs, os); @@ -619,7 +619,7 @@ void CodeGenC::VisitExpr_(const LoadNode* op, std::ostream& os) { // NOLINT(*) } else { CHECK(is_one(op->predicate)) << "predicated load is not supported"; - Expr base; + PrimExpr base; if (GetRamp1Base(op->index, op->dtype.lanes(), &base)) { std::string ref = GetVecLoad(op->dtype, op->buffer_var.get(), base); os << ref; @@ -673,7 +673,7 @@ void CodeGenC::VisitStmt_(const StoreNode* op) { } else { CHECK(is_one(op->predicate)) << "Predicated store is not supported"; - Expr base; + PrimExpr base; if (GetRamp1Base(op->index, t.lanes(), &base)) { std::string value = this->PrintExpr(op->value); this->PrintVecStore(op->buffer_var.get(), t, base, value); diff --git a/src/codegen/codegen_c.h b/src/codegen/codegen_c.h index 593bbcd8d525..cb092c566322 100644 --- a/src/codegen/codegen_c.h +++ b/src/codegen/codegen_c.h @@ -49,7 +49,7 @@ using namespace ir; * a vector of 3 `int`s. For native C code generator, see `CodeGenLLVM`. */ class CodeGenC : - public ExprFunctor, + public ExprFunctor, public StmtFunctor, public CodeGenSourceBase { public: @@ -80,12 +80,12 @@ class CodeGenC : * \param n The expression to be printed. * \param os The output stream */ - void PrintExpr(const Expr& n, std::ostream& os); + void PrintExpr(const PrimExpr& n, std::ostream& os); /*! * \brief Same as PrintExpr, but simply returns result string * \param n The expression to be printed. */ - std::string PrintExpr(const Expr& n) { + std::string PrintExpr(const PrimExpr& n) { std::ostringstream os; PrintExpr(n, os); return os.str(); @@ -158,12 +158,12 @@ class CodeGenC : // Binary vector op. virtual void PrintVecBinaryOp( const std::string&op, DataType op_type, - Expr lhs, Expr rhs, std::ostream& os); // NOLINT(*) + PrimExpr lhs, PrimExpr rhs, std::ostream& os); // NOLINT(*) // print vector load - virtual std::string GetVecLoad(DataType t, const VarNode* buffer, Expr base); + virtual std::string GetVecLoad(DataType t, const VarNode* buffer, PrimExpr base); // print vector store virtual void PrintVecStore(const VarNode* buffer, - DataType t, Expr base, + DataType t, PrimExpr base, const std::string& value); // NOLINT(*) // print load of single element virtual void PrintVecElemLoad( @@ -177,10 +177,10 @@ class CodeGenC : protected: // Print reference to struct location std::string GetStructRef( - DataType t, const Expr& buffer, const Expr& index, int kind); + DataType t, const PrimExpr& buffer, const PrimExpr& index, int kind); // print reference to a buffer as type t in index. virtual std::string GetBufferRef( - DataType t, const VarNode* buffer, Expr index); + DataType t, const VarNode* buffer, PrimExpr index); /*! * \brief If buffer is allocated as type t. * \param buf_var The buffer variable. diff --git a/src/codegen/codegen_cuda.cc b/src/codegen/codegen_cuda.cc index 53a008dcaf2f..d06e9aa2a0ad 100644 --- a/src/codegen/codegen_cuda.cc +++ b/src/codegen/codegen_cuda.cc @@ -200,7 +200,7 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*) void CodeGenCUDA::PrintVecBinaryOp( const std::string&op, DataType t, - Expr lhs, Expr rhs, std::ostream& os) { // NOLINT(*) + PrimExpr lhs, PrimExpr rhs, std::ostream& os) { // NOLINT(*) // unpacking operations. int lanes = t.lanes(); diff --git a/src/codegen/codegen_cuda.h b/src/codegen/codegen_cuda.h index fc2e6aed6c7e..23fbf7febf4e 100644 --- a/src/codegen/codegen_cuda.h +++ b/src/codegen/codegen_cuda.h @@ -48,7 +48,7 @@ class CodeGenCUDA final : public CodeGenC { void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*) void PrintVecBinaryOp( const std::string&op, DataType t, - Expr lhs, Expr rhs, std::ostream& os) final; // NOLINT(*) + PrimExpr lhs, PrimExpr rhs, std::ostream& os) final; // NOLINT(*) void PrintType(DataType t, std::ostream& os) final; // NOLINT(*) void PrintVecElemLoad( const std::string& vec, DataType t, int i, std::ostream& os) final; // NOLINT(*) diff --git a/src/codegen/codegen_opencl.cc b/src/codegen/codegen_opencl.cc index 8914db84a89f..ef90cfc69bd5 100644 --- a/src/codegen/codegen_opencl.cc +++ b/src/codegen/codegen_opencl.cc @@ -145,7 +145,7 @@ void CodeGenOpenCL::PrintType(DataType t, std::ostream& os) { // NOLINT(*) } void CodeGenOpenCL::PrintVecAddr(const VarNode* buffer, DataType t, - Expr base, std::ostream& os) { // NOLINT(*) + PrimExpr base, std::ostream& os) { // NOLINT(*) if (!HandleTypeMatch(buffer, t.element_of())) { os << '('; auto it = alloc_storage_scope_.find(buffer); @@ -160,7 +160,7 @@ void CodeGenOpenCL::PrintVecAddr(const VarNode* buffer, DataType t, PrintExpr(base, os); } std::string CodeGenOpenCL::GetVecLoad( - DataType t, const VarNode* buffer, Expr base) { + DataType t, const VarNode* buffer, PrimExpr base) { std::ostringstream os; os << "vload" << t.lanes() << "(0, "; PrintVecAddr(buffer, t, base, os); @@ -169,7 +169,7 @@ std::string CodeGenOpenCL::GetVecLoad( } void CodeGenOpenCL::PrintVecStore(const VarNode* buffer, - DataType t, Expr base, + DataType t, PrimExpr base, const std::string& value) { this->PrintIndent(); stream << "vstore" << t.lanes() << "(" << value << ", 0, "; diff --git a/src/codegen/codegen_opencl.h b/src/codegen/codegen_opencl.h index a606a3a76eb1..07b28fd00573 100644 --- a/src/codegen/codegen_opencl.h +++ b/src/codegen/codegen_opencl.h @@ -45,13 +45,13 @@ class CodeGenOpenCL final : public CodeGenC { void PrintStorageSync(const CallNode* op) final; // NOLINT(*) void PrintType(DataType t, std::ostream& os) final; // NOLINT(*) std::string GetVecLoad(DataType t, const VarNode* buffer, - Expr base) final; + PrimExpr base) final; void PrintVecStore(const VarNode* buffer, - DataType t, Expr base, + DataType t, PrimExpr base, const std::string& value) final; // NOLINT(*) // the address of load/store void PrintVecAddr(const VarNode* buffer, DataType t, - Expr base, std::ostream& os); // NOLINT(*) + PrimExpr base, std::ostream& os); // NOLINT(*) std::string CastFromTo(std::string value, DataType from, DataType target); // NOLINT(*) // overload visitor diff --git a/src/codegen/codegen_opengl.cc b/src/codegen/codegen_opengl.cc index 5666de3908d4..7967c1847ac2 100644 --- a/src/codegen/codegen_opengl.cc +++ b/src/codegen/codegen_opengl.cc @@ -194,7 +194,7 @@ void CodeGenOpenGL::VisitStmt_(const StoreNode* op) { } // texelFetch(tex, ivec2(idx & kTextureRowMask, idx >> kTextureRowBits), 0).r -std::string CodeGenOpenGL::TexelFetch(const VarNode* buffer, Expr index) { +std::string CodeGenOpenGL::TexelFetch(const VarNode* buffer, PrimExpr index) { std::ostringstream os; os << "texelFetch(" << GetVarID(buffer) << ", ivec2(int("; PrintExpr(index, os); @@ -207,7 +207,7 @@ std::string CodeGenOpenGL::TexelFetch(const VarNode* buffer, Expr index) { // Print a reference expression to a buffer. // Format: texelFetch(buffer, index, 0).r std::string CodeGenOpenGL::GetBufferRef( - DataType t, const VarNode* buffer, Expr index) { + DataType t, const VarNode* buffer, PrimExpr index) { CHECK_EQ(t.lanes(), 1) << "Vector type not supported."; CHECK(HandleTypeMatch(buffer, t)) << "Type mismatch not supported."; diff --git a/src/codegen/codegen_opengl.h b/src/codegen/codegen_opengl.h index bb6936591054..cd1ec83360c6 100644 --- a/src/codegen/codegen_opengl.h +++ b/src/codegen/codegen_opengl.h @@ -44,8 +44,8 @@ class CodeGenOpenGL final : public CodeGenC { void InitFuncState(LoweredFunc f) final; void BindThreadIndex(const IterVar& iv) final; void VisitStmt_(const StoreNode* op) final; - std::string TexelFetch(const VarNode* buffer, Expr index); - std::string GetBufferRef(DataType t, const VarNode* buffer, Expr index) final; + std::string TexelFetch(const VarNode* buffer, PrimExpr index); + std::string GetBufferRef(DataType t, const VarNode* buffer, PrimExpr index) final; void PrintType(DataType t, std::ostream& os) final; // NOLINT(*) // Codegen for immediate values diff --git a/src/codegen/codegen_vhls.cc b/src/codegen/codegen_vhls.cc index e7231a16c280..a1bd9f013cc9 100644 --- a/src/codegen/codegen_vhls.cc +++ b/src/codegen/codegen_vhls.cc @@ -140,7 +140,7 @@ runtime::Module BuildSDAccel(Array funcs, std::string target_str) { std::string whole_code = cg.Finish(); // Generate source code for compilation. - Array > kernel_info; + Array > kernel_info; for (LoweredFunc f : funcs) { CodeGenVivadoHLS cg; cg.Init(output_ssa); @@ -149,7 +149,7 @@ runtime::Module BuildSDAccel(Array funcs, std::string target_str) { if (const auto* f = runtime::Registry::Get("tvm_callback_vhls_postproc")) { code = (*f)(code).operator std::string(); } - kernel_info.push_back(Array({f->name, code})); + kernel_info.push_back(Array({f->name, code})); } std::string xclbin; diff --git a/src/codegen/intrin_rule.cc b/src/codegen/intrin_rule.cc index 571ec523a328..0609989a3362 100644 --- a/src/codegen/intrin_rule.cc +++ b/src/codegen/intrin_rule.cc @@ -53,7 +53,7 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sqrt") TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.rsqrt") .set_body([](const TVMArgs& args, TVMRetValue* rv){ - Expr e = args[0]; + PrimExpr e = args[0]; const CallNode* call = e.as(); CHECK(call != nullptr); @@ -66,7 +66,7 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.pow") TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sigmoid") .set_body([](const TVMArgs& args, TVMRetValue* rv){ - Expr e = args[0]; + PrimExpr e = args[0]; const CallNode* call = e.as(); CHECK(call != nullptr); diff --git a/src/codegen/intrin_rule.h b/src/codegen/intrin_rule.h index a0665bf5c5b9..56ba225631df 100644 --- a/src/codegen/intrin_rule.h +++ b/src/codegen/intrin_rule.h @@ -60,7 +60,7 @@ struct Direct { // Call pure extern function. template inline void DispatchExtern(const TVMArgs& args, TVMRetValue* rv) { - Expr e = args[0]; + PrimExpr e = args[0]; const CallNode* call = e.as(); CHECK(call != nullptr); std::string name = T()(call->dtype, call->name); diff --git a/src/codegen/llvm/codegen_amdgpu.cc b/src/codegen/llvm/codegen_amdgpu.cc index 397f9d3593ae..fb7abc394bb8 100644 --- a/src/codegen/llvm/codegen_amdgpu.cc +++ b/src/codegen/llvm/codegen_amdgpu.cc @@ -231,7 +231,7 @@ runtime::Module BuildAMDGPU(Array funcs, std::string target) { const auto *find_rocm_bitcodes = tvm::runtime::Registry::Get("tvm_callback_rocm_bitcode_path"); - Array bitcode_files = (*find_rocm_bitcodes)(); + Array bitcode_files = (*find_rocm_bitcodes)(); for (auto &bitcode : bitcode_files) { std::string path = bitcode.as()->value; diff --git a/src/codegen/llvm/codegen_arm.cc b/src/codegen/llvm/codegen_arm.cc index fdc1b421fee0..6879fd5f8542 100644 --- a/src/codegen/llvm/codegen_arm.cc +++ b/src/codegen/llvm/codegen_arm.cc @@ -42,7 +42,7 @@ class CodeGenARM final : public CodeGenCPU { llvm::Value* CreateIntrinsic(const CallNode* op) override; private: - Expr ARMPopcount(const CallNode* op); + PrimExpr ARMPopcount(const CallNode* op); }; llvm::Value* CodeGenARM::CreateIntrinsic(const CallNode* op) { @@ -50,16 +50,16 @@ llvm::Value* CodeGenARM::CreateIntrinsic(const CallNode* op) { llvm::Intrinsic::ID id = static_cast( op->args[0].as()->value); if (id == ::llvm::Intrinsic::ctpop) { - Expr e = ARMPopcount(op); + PrimExpr e = ARMPopcount(op); return CodeGenCPU::CreateIntrinsic(e.as()); } } return CodeGenCPU::CreateIntrinsic(op); } -Expr CodeGenARM::ARMPopcount(const CallNode *call) { +PrimExpr CodeGenARM::ARMPopcount(const CallNode *call) { using namespace ir; - const Expr& e = call->args[2]; + const PrimExpr& e = call->args[2]; ::llvm::Intrinsic::ID ctpop_id = ::llvm::Intrinsic::ctpop; ::llvm::Intrinsic::ID vpaddlu_id = ::llvm::Intrinsic::arm_neon_vpaddlu; @@ -67,7 +67,7 @@ Expr CodeGenARM::ARMPopcount(const CallNode *call) { int total_size = call->dtype.bits() * call->dtype.lanes(); if (!call->dtype.is_vector() || call->dtype.bits() == 8 || (total_size != 128 && total_size != 64)) { - Array vcnt_args; + Array vcnt_args; vcnt_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), ctpop_id)); vcnt_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), 1)); vcnt_args.push_back(e); @@ -88,41 +88,41 @@ Expr CodeGenARM::ARMPopcount(const CallNode *call) { uint16_type.code(), 32, uint8_type.bits() * uint8_type.lanes() / 32); // Interpret input as vector of 8bit values - Expr input8 = reinterpret(uint8_type, e); + PrimExpr input8 = reinterpret(uint8_type, e); // Popcount 8bit->8bit const CallNode* c0 = input8.as(); CHECK(c0 != nullptr); - Array vcnt8_args; + Array vcnt8_args; vcnt8_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), ctpop_id)); vcnt8_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), 1)); vcnt8_args.push_back(input8); - Expr vcnt8 = ir::CallNode::make( + PrimExpr vcnt8 = ir::CallNode::make( uint8_type, "llvm_intrin", vcnt8_args, CallNode::PureIntrinsic); // Accumulation 8->16bit - Array vcnt16_args; + Array vcnt16_args; vcnt16_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), vpaddlu_id)); vcnt16_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), 1)); vcnt16_args.push_back(vcnt8); - Expr vcnt16 = ir::CallNode::make( + PrimExpr vcnt16 = ir::CallNode::make( uint16_type, "llvm_intrin", vcnt16_args, CallNode::PureIntrinsic); if (call->dtype.bits() == 16) { return vcnt16; } // Accumulation 16->32bit - Array vcnt32_args; + Array vcnt32_args; vcnt32_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), vpaddlu_id)); vcnt32_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), 1)); vcnt32_args.push_back(vcnt16); - Expr vcnt32 = ir::CallNode::make( + PrimExpr vcnt32 = ir::CallNode::make( uint32_type, "llvm_intrin", vcnt32_args, CallNode::PureIntrinsic); if (call->dtype.bits() == 32) { return vcnt32; } // Accumulation 32->64bit - Array vcnt64_args; + Array vcnt64_args; vcnt64_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), vpaddlu_id)); vcnt64_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), 1)); vcnt64_args.push_back(vcnt32); diff --git a/src/codegen/llvm/codegen_cpu.cc b/src/codegen/llvm/codegen_cpu.cc index 0622269e2aad..39d8c7fb28e4 100644 --- a/src/codegen/llvm/codegen_cpu.cc +++ b/src/codegen/llvm/codegen_cpu.cc @@ -669,7 +669,7 @@ llvm::Value* CodeGenCPU::GetPackedFuncHandle(const std::string& fname) { } llvm::BasicBlock * -CodeGenCPU::MakeCallPacked(const Array &args, llvm::Value **rvalue, +CodeGenCPU::MakeCallPacked(const Array &args, llvm::Value **rvalue, llvm::Value **ret_tcode, const DataType &r_type, const int64_t begin, const int64_t end) { using llvm::BasicBlock; @@ -923,8 +923,8 @@ void CodeGenCPU::VisitStmt_(const ForNode* op) { CHECK(parallel_env_.num_task.defined()); CHECK(parallel_env_.penv != nullptr); DataType t = op->extent.dtype(); - Expr num_task = cast(t, parallel_env_.num_task); - Expr task_id = cast(t, parallel_env_.task_id); + PrimExpr num_task = cast(t, parallel_env_.num_task); + PrimExpr task_id = cast(t, parallel_env_.task_id); CHECK(!parallel_env_.in_parallel_loop) << "Nested parallel loop is not supported by threadpool, try fuse them instead"; parallel_env_.in_parallel_loop = true; @@ -935,9 +935,9 @@ void CodeGenCPU::VisitStmt_(const ForNode* op) { op->loop_var, op->body); } else { - Expr step = (op->extent + num_task - make_const(t, 1)) / num_task; - Expr begin = MinNode::make(task_id * step, op->extent); - Expr end = MinNode::make((task_id + make_const(t, 1)) * step, op->extent); + PrimExpr step = (op->extent + num_task - make_const(t, 1)) / num_task; + PrimExpr begin = MinNode::make(task_id * step, op->extent); + PrimExpr end = MinNode::make((task_id + make_const(t, 1)) * step, op->extent); CreateSerialFor(MakeValue(begin), MakeValue(end), ConstInt32(1), diff --git a/src/codegen/llvm/codegen_cpu.h b/src/codegen/llvm/codegen_cpu.h index 46f3f96f082b..6ee0085abe4e 100644 --- a/src/codegen/llvm/codegen_cpu.h +++ b/src/codegen/llvm/codegen_cpu.h @@ -77,8 +77,8 @@ class CodeGenCPU : public CodeGenLLVM { private: // the parallel group information struct ParallelEnv { - VarExpr task_id; - VarExpr num_task; + Var task_id; + Var num_task; bool stride_pattern{false}; bool in_parallel_loop{false}; int parallel_loop_count{0}; @@ -101,7 +101,7 @@ class CodeGenCPU : public CodeGenLLVM { const Array& fields, std::unordered_map* vmap); // Make packed call. - llvm::BasicBlock *MakeCallPacked(const Array &args, + llvm::BasicBlock *MakeCallPacked(const Array &args, llvm::Value **rvalue, llvm::Value **ret_tcode, const DataType &r_type, const int64_t begin, const int64_t end); diff --git a/src/codegen/llvm/codegen_llvm.cc b/src/codegen/llvm/codegen_llvm.cc index e2ba19a3680f..c04a023aefad 100644 --- a/src/codegen/llvm/codegen_llvm.cc +++ b/src/codegen/llvm/codegen_llvm.cc @@ -334,7 +334,7 @@ llvm::Type* CodeGenLLVM::LLVMType(const DataType& t) const { // void CodeGenLLVM::AddAliasInfo(llvm::Instruction* inst, const VarNode* buffer, - Expr index, + PrimExpr index, DataType type) { if (alias_var_set_.count(buffer) != 0) { // Mark all possibly aliased pointer as same type. @@ -389,7 +389,7 @@ void CodeGenLLVM::AddAliasInfo(llvm::Instruction* inst, void CodeGenLLVM::GetAlignment(DataType t, const VarNode* buf_var, - const Expr& index, + const PrimExpr& index, int* p_alignment, int* p_native_bits) { int max_align_bits = t.bits(); @@ -526,7 +526,7 @@ llvm::Value* CodeGenLLVM::CreateVecConcat(std::vector vecs) { void CodeGenLLVM::CreateSerialFor(llvm::Value* begin, llvm::Value* end, llvm::Value* stride, - const VarExpr& loop_var, + const Var& loop_var, const Stmt& body) { using llvm::BasicBlock; BasicBlock* pre_block = builder_->GetInsertBlock(); @@ -711,7 +711,7 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) { addrspace = llvm::dyn_cast( ptr->getType())->getAddressSpace(); } else { - Expr index = r->base / make_const(DataType::Int(32), r->lanes); + PrimExpr index = r->base / make_const(DataType::Int(32), r->lanes); ptr = CreateBufferVecPtr( l->dtype, MakeValue(l->buffer_var), MakeValue(index)); addrspace = llvm::dyn_cast( @@ -776,11 +776,11 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) { } } -void CodeGenLLVM::Scalarize(const Expr& e, +void CodeGenLLVM::Scalarize(const PrimExpr& e, std::function f) { if (const RampNode* ramp = e.as()) { for (int i = 0; i < ramp->dtype.lanes(); ++i) { - Expr offset = ramp->base + (ramp->stride * i); + PrimExpr offset = ramp->base + (ramp->stride * i); f(i, MakeValue(offset)); } } else { @@ -988,7 +988,7 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const LoadNode* op) { llvm::LoadInst* load = builder_->CreateAlignedLoad( ptr, basic_align, is_volatile); ret = builder_->CreateInsertElement(ret, load, ConstInt32(i)); - AddAliasInfo(load, op->buffer_var.get(), Expr(), t); + AddAliasInfo(load, op->buffer_var.get(), PrimExpr(), t); }; this->Scalarize(op->index, f); return ret; @@ -1084,7 +1084,7 @@ void CodeGenLLVM::VisitStmt_(const StoreNode* op) { llvm::StoreInst* store = builder_->CreateAlignedStore( builder_->CreateExtractElement(value, i), ptr, basic_align, is_volatile); - AddAliasInfo(store, op->buffer_var.get(), Expr(), op->value.dtype()); + AddAliasInfo(store, op->buffer_var.get(), PrimExpr(), op->value.dtype()); }; this->Scalarize(op->index, f); } diff --git a/src/codegen/llvm/codegen_llvm.h b/src/codegen/llvm/codegen_llvm.h index 67ca7c110310..34c3ee723e18 100644 --- a/src/codegen/llvm/codegen_llvm.h +++ b/src/codegen/llvm/codegen_llvm.h @@ -47,7 +47,7 @@ using namespace ir; * \brief A base class to generate a LLVM. */ class CodeGenLLVM : - public ExprFunctor, + public ExprFunctor, public StmtFunctor { public: /*! @@ -95,7 +95,7 @@ class CodeGenLLVM : * \param e The expression to be created value for. * \return created value. */ - llvm::Value* MakeValue(const Expr& e) { + llvm::Value* MakeValue(const PrimExpr& e) { return VisitExpr(e); } // Short hande code to get a constant int 32 @@ -184,7 +184,7 @@ class CodeGenLLVM : virtual void InitPassManagerBuilder(llvm::PassManagerBuilder* builder); // Scalarize by iterating elements of e. // f is a callback that takes index and v. - virtual void Scalarize(const Expr& e, + virtual void Scalarize(const PrimExpr& e, std::function f); // Initialize target virtual void InitTarget(llvm::TargetMachine* tm); @@ -211,7 +211,7 @@ class CodeGenLLVM : void InitFuncState(); // Get alignment given index. void GetAlignment( - DataType t, const VarNode* buf_var, const Expr& index, + DataType t, const VarNode* buf_var, const PrimExpr& index, int* p_alignment, int* p_native_bits); // Get constant string llvm::Value* GetConstString(const std::string& str); @@ -243,9 +243,9 @@ class CodeGenLLVM : void CreateSerialFor(llvm::Value* begin, llvm::Value* end, llvm::Value* stride, - const VarExpr& loop_var, const Stmt& body); + const Var& loop_var, const Stmt& body); // add alias information. - void AddAliasInfo(llvm::Instruction* load, const VarNode* buffer, Expr index, DataType type); + void AddAliasInfo(llvm::Instruction* load, const VarNode* buffer, PrimExpr index, DataType type); // The IRBuilder. using IRBuilder = llvm::IRBuilder; // The current function diff --git a/src/codegen/llvm/intrin_rule_llvm.cc b/src/codegen/llvm/intrin_rule_llvm.cc index 10774ec1fc57..b05185bafd9c 100644 --- a/src/codegen/llvm/intrin_rule_llvm.cc +++ b/src/codegen/llvm/intrin_rule_llvm.cc @@ -63,21 +63,21 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.nearbyint") TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.tanh") .set_body([](const TVMArgs& targs, TVMRetValue* rv) { - Expr e = targs[0]; + PrimExpr e = targs[0]; const ir::CallNode* call = e.as(); CHECK(call != nullptr); - const Expr& x = call->args[0]; - Expr one = make_const(x.dtype(), 1); - Expr two = make_const(x.dtype(), 2); - Expr neg_two = make_const(x.dtype(), -2); + const PrimExpr& x = call->args[0]; + PrimExpr one = make_const(x.dtype(), 1); + PrimExpr two = make_const(x.dtype(), 2); + PrimExpr neg_two = make_const(x.dtype(), -2); - Expr exp_neg2x = ir::CallNode::make( + PrimExpr exp_neg2x = ir::CallNode::make( x.dtype(), "exp", {neg_two * x}, ir::CallNode::PureIntrinsic); - Expr exp_pos2x = ir::CallNode::make( + PrimExpr exp_pos2x = ir::CallNode::make( x.dtype(), "exp", {two * x}, ir::CallNode::PureIntrinsic); - Expr tanh_pos = (one - exp_neg2x) / (one + exp_neg2x); - Expr tanh_neg = (exp_pos2x - one) / (exp_pos2x + one); + PrimExpr tanh_pos = (one - exp_neg2x) / (one + exp_neg2x); + PrimExpr tanh_neg = (exp_pos2x - one) / (exp_pos2x + one); *rv = ir::SelectNode::make( x >= make_zero(x.dtype()), tanh_pos, tanh_neg); }); diff --git a/src/codegen/llvm/intrin_rule_llvm.h b/src/codegen/llvm/intrin_rule_llvm.h index a870385359b6..b3ab557ee215 100644 --- a/src/codegen/llvm/intrin_rule_llvm.h +++ b/src/codegen/llvm/intrin_rule_llvm.h @@ -38,15 +38,15 @@ namespace codegen { // num_signature means number of arguments used to query signature template inline void DispatchLLVMPureIntrin(const TVMArgs& targs, TVMRetValue* rv) { - Expr e = targs[0]; + PrimExpr e = targs[0]; const ir::CallNode* call = e.as(); CHECK(call != nullptr); - Array cargs; + Array cargs; // intrin id. cargs.push_back(ir::UIntImmNode::make(DataType::UInt(32), id)); cargs.push_back(ir::UIntImmNode::make(DataType::UInt(32), num_signature)); - for (Expr arg : call->args) { + for (PrimExpr arg : call->args) { cargs.push_back(arg); } *rv = ir::CallNode::make( @@ -55,14 +55,14 @@ inline void DispatchLLVMPureIntrin(const TVMArgs& targs, TVMRetValue* rv) { template inline void DispatchLLVMIntrin(const TVMArgs& targs, TVMRetValue* rv) { - Expr e = targs[0]; + PrimExpr e = targs[0]; const ir::CallNode* call = e.as(); CHECK(call != nullptr); - Array cargs; + Array cargs; // intrin id. cargs.push_back(ir::UIntImmNode::make(DataType::UInt(32), id)); cargs.push_back(ir::UIntImmNode::make(DataType::UInt(32), num_signature)); - for (Expr arg : call->args) { + for (PrimExpr arg : call->args) { cargs.push_back(arg); } *rv = ir::CallNode::make( diff --git a/src/codegen/llvm/intrin_rule_nvptx.cc b/src/codegen/llvm/intrin_rule_nvptx.cc index 00824bbc9046..fcd8a1a2b664 100644 --- a/src/codegen/llvm/intrin_rule_nvptx.cc +++ b/src/codegen/llvm/intrin_rule_nvptx.cc @@ -33,7 +33,7 @@ namespace tvm { namespace codegen { inline void DispatchExternLibDevice(const TVMArgs& args, TVMRetValue* rv) { - Expr e = args[0]; + PrimExpr e = args[0]; using namespace ir; const CallNode* call = e.as(); CHECK(call != nullptr); diff --git a/src/codegen/llvm/intrin_rule_rocm.cc b/src/codegen/llvm/intrin_rule_rocm.cc index 09de88f1d887..41035af71e62 100644 --- a/src/codegen/llvm/intrin_rule_rocm.cc +++ b/src/codegen/llvm/intrin_rule_rocm.cc @@ -33,7 +33,7 @@ namespace tvm { namespace codegen { inline void DispatchExternOCML(const TVMArgs& args, TVMRetValue* rv) { - Expr e = args[0]; + PrimExpr e = args[0]; using namespace ir; const CallNode* call = e.as(); CHECK(call != nullptr); diff --git a/src/codegen/spirv/codegen_spirv.cc b/src/codegen/spirv/codegen_spirv.cc index 254e436c8e05..a749424892e2 100644 --- a/src/codegen/spirv/codegen_spirv.cc +++ b/src/codegen/spirv/codegen_spirv.cc @@ -90,7 +90,7 @@ void CodeGenSPIRV::InitFuncState() { } spirv::Value CodeGenSPIRV::GetThreadIndex( - const IterVar& iv, const Expr& extent) { + const IterVar& iv, const PrimExpr& extent) { runtime::ThreadScope ts = runtime::ThreadScope::make(iv->thread_tag); spirv::Value v; if (ts.rank == 1) { @@ -403,7 +403,7 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const LoadNode* op) { CHECK((me->coeff % ramp->lanes) == 0 && (me->base % ramp->lanes) == 0) << "Only aligned vector access is allowed in SPIRV"; - Expr vec_index = ir::Simplify( + PrimExpr vec_index = ir::Simplify( ramp->base / make_const(ramp->base.dtype(), ramp->lanes)); spirv::Value ptr = builder_->StructArrayAccess( ptr_type, buffer, MakeValue(vec_index)); @@ -417,11 +417,11 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const LoadNode* op) { return spirv::Value(); } -void CodeGenSPIRV::Scalarize(const Expr& e, +void CodeGenSPIRV::Scalarize(const PrimExpr& e, std::function f) { if (const RampNode* ramp = e.as()) { for (int i = 0; i < ramp->dtype.lanes(); ++i) { - Expr offset = ramp->base + ramp->stride * i; + PrimExpr offset = ramp->base + ramp->stride * i; f(i, MakeValue(offset)); } } else { @@ -481,7 +481,7 @@ void CodeGenSPIRV::VisitStmt_(const StoreNode* op) { CHECK((me->coeff % ramp->lanes) == 0 && (me->base % ramp->lanes) == 0) << "Only aligned vector access is allowed in SPIRV"; - Expr vec_index = ir::Simplify( + PrimExpr vec_index = ir::Simplify( ramp->base / make_const(ramp->base.dtype(), ramp->lanes)); spirv::Value ptr = builder_->StructArrayAccess( ptr_type, buffer, MakeValue(vec_index)); diff --git a/src/codegen/spirv/codegen_spirv.h b/src/codegen/spirv/codegen_spirv.h index b72cd5bc97fb..3804bda0f2e0 100644 --- a/src/codegen/spirv/codegen_spirv.h +++ b/src/codegen/spirv/codegen_spirv.h @@ -44,7 +44,7 @@ using namespace ir; * \brief Code generator into SPIRV */ class CodeGenSPIRV: - public ExprFunctor, + public ExprFunctor, public StmtFunctor { public: /*! @@ -58,7 +58,7 @@ class CodeGenSPIRV: * \param e The expression to be created value for. * \return created value. */ - spirv::Value MakeValue(const Expr& e) { + spirv::Value MakeValue(const PrimExpr& e) { return VisitExpr(e); } // override codegen @@ -128,9 +128,9 @@ class CodeGenSPIRV: // Reset the state so it works for a new function. void InitFuncState(); // Get the thread index - spirv::Value GetThreadIndex(const IterVar& iv, const Expr& extent); + spirv::Value GetThreadIndex(const IterVar& iv, const PrimExpr& extent); spirv::Value CreateStorageSync(const CallNode* op); - void Scalarize(const Expr& e, + void Scalarize(const PrimExpr& e, std::function f); // The builder std::unique_ptr builder_; diff --git a/src/codegen/spirv/intrin_rule_spirv.cc b/src/codegen/spirv/intrin_rule_spirv.cc index 69d20142b98c..d41d96db5165 100644 --- a/src/codegen/spirv/intrin_rule_spirv.cc +++ b/src/codegen/spirv/intrin_rule_spirv.cc @@ -34,14 +34,14 @@ using namespace runtime; // num_signature means number of arguments used to query signature template inline void DispatchGLSLPureIntrin(const TVMArgs& targs, TVMRetValue* rv) { - Expr e = targs[0]; + PrimExpr e = targs[0]; const ir::CallNode* call = e.as(); CHECK(call != nullptr); - Array cargs; + Array cargs; // intrin id. cargs.push_back(ir::UIntImmNode::make(DataType::UInt(32), id)); - for (Expr arg : call->args) { + for (PrimExpr arg : call->args) { cargs.push_back(arg); } *rv = ir::CallNode::make( diff --git a/src/codegen/stackvm/codegen_stackvm.cc b/src/codegen/stackvm/codegen_stackvm.cc index 3da083be87a6..eccff6c74c2e 100644 --- a/src/codegen/stackvm/codegen_stackvm.cc +++ b/src/codegen/stackvm/codegen_stackvm.cc @@ -244,8 +244,8 @@ void CodeGenStackVM::VisitExpr_(const CallNode* op) { } void CodeGenStackVM::PushBinary(StackVM::OpCode op_int64, - const Expr& a, - const Expr& b) { + const PrimExpr& a, + const PrimExpr& b) { this->Push(a); this->Push(b); DataType t = a.dtype(); diff --git a/src/codegen/stackvm/codegen_stackvm.h b/src/codegen/stackvm/codegen_stackvm.h index 1f00eccb5596..07989b2062e1 100644 --- a/src/codegen/stackvm/codegen_stackvm.h +++ b/src/codegen/stackvm/codegen_stackvm.h @@ -46,7 +46,7 @@ using runtime::StackVM; * into device function when only device JIT is available. */ class CodeGenStackVM - : public ExprFunctor, + : public ExprFunctor, public StmtFunctor { public: /*! @@ -60,7 +60,7 @@ class CodeGenStackVM /*! \brief Push stmt to generate new code */ void Push(const Stmt& n); /*! \brief Push expr to generate new code */ - void Push(const Expr& n) { + void Push(const PrimExpr& n) { VisitExpr(n); } /*! @@ -105,8 +105,8 @@ class CodeGenStackVM int GetVarID(const VarNode* v) const; // Push binary operator void PushBinary(StackVM::OpCode op_int64, - const Expr& a, - const Expr& b); + const PrimExpr& a, + const PrimExpr& b); // push cast; void PushCast(DataType dst, DataType src); // overloadable functions diff --git a/src/contrib/hybrid/codegen_hybrid.h b/src/contrib/hybrid/codegen_hybrid.h index 01696b2a8fcb..89a1ece577f9 100644 --- a/src/contrib/hybrid/codegen_hybrid.h +++ b/src/contrib/hybrid/codegen_hybrid.h @@ -46,7 +46,7 @@ using namespace ir; * For runtime support, please refer the decorator in ``tvm/python/hybrid/api.py``. */ class CodeGenHybrid : - public ExprFunctor, + public ExprFunctor, public StmtFunctor { public: /*! @@ -77,14 +77,14 @@ class CodeGenHybrid : * \param n The expression to be printed. * \param os The output stream */ - void PrintExpr(const Expr &n, std::ostream &os) { + void PrintExpr(const PrimExpr &n, std::ostream &os) { this->VisitExpr(n, os); } /*! * \brief Same as PrintExpr, but simply returns result string * \param n The expression to be printed. */ - std::string PrintExpr(const Expr &n) { + std::string PrintExpr(const PrimExpr &n) { std::ostringstream os; PrintExpr(n, os); return os.str(); diff --git a/src/lang/attrs.cc b/src/lang/attrs.cc index 6264e0f95fc5..1d3e767a5b71 100644 --- a/src/lang/attrs.cc +++ b/src/lang/attrs.cc @@ -44,9 +44,9 @@ void DictAttrsNode::InitByPackedArgs( if (val.IsObjectRef()) { dict.Set(key, val.operator ObjectRef()); } else if (val.type_code() == kStr) { - dict.Set(key, Expr(val.operator std::string())); + dict.Set(key, PrimExpr(val.operator std::string())); } else { - dict.Set(key, val.operator Expr()); + dict.Set(key, val.operator PrimExpr()); } } } diff --git a/src/lang/buffer.cc b/src/lang/buffer.cc index d96033d3c02d..925e3dbaffc6 100644 --- a/src/lang/buffer.cc +++ b/src/lang/buffer.cc @@ -34,22 +34,22 @@ namespace tvm { using IndexMod = ir::FloorModNode; using IndexDiv = ir::FloorDivNode; -Array SimplifyArray(Array array) { +Array SimplifyArray(Array array) { for (size_t i = 0; i < array.size(); ++i) { array.Set(i, ir::Simplify(array[i])); } return array; } -Buffer decl_buffer(Array shape, +Buffer decl_buffer(Array shape, DataType dtype, std::string name) { return BufferNode::make( Var(name, DataType::Handle()), dtype, shape, - Array(), - Expr(), + Array(), + PrimExpr(), name, "", 0, 0, @@ -57,13 +57,13 @@ Buffer decl_buffer(Array shape, } // Split the given expression w.r.t the add operator -inline std::vector ExprSplitAddition(const Expr &expr) { +inline std::vector ExprSplitAddition(const PrimExpr &expr) { using namespace ir; - std::vector ret; - std::stack split_buffer; + std::vector ret; + std::stack split_buffer; split_buffer.push(&expr); while (!split_buffer.empty()) { - const Expr* top_ele = split_buffer.top(); + const PrimExpr* top_ele = split_buffer.top(); split_buffer.pop(); auto expr_add_match = top_ele->as(); if (expr_add_match) { @@ -84,14 +84,14 @@ inline std::vector ExprSplitAddition(const Expr &expr) { // If it can be optimized, returns (true, (a1 + a2 + ... + aj) * kt * ... * ki + c) // Currently the we will not search the add/mult combinations exhaustively // as it will take too much computation. -inline std::pair MergeMulModInner(const Expr &mult_expr, - const Expr &mod_l_expr, - const Expr &mod_r_expr) { +inline std::pair MergeMulModInner(const PrimExpr &mult_expr, + const PrimExpr &mod_l_expr, + const PrimExpr &mod_r_expr) { using namespace ir; const MulNode* mult_ptr = mult_expr.as(); - if (!mult_ptr) return std::make_pair(false, Expr()); - Expr mult_outer = mult_ptr->b; - const Expr* inner = &(mult_ptr->a); + if (!mult_ptr) return std::make_pair(false, PrimExpr()); + PrimExpr mult_outer = mult_ptr->b; + const PrimExpr* inner = &(mult_ptr->a); // 1. Calculate the outer multiplier while (true) { mult_ptr = inner->as(); @@ -108,32 +108,32 @@ inline std::pair MergeMulModInner(const Expr &mult_expr, // If Mult is found, we will expand the inner multiplication factor // If Div is found, we will go on testing whether lhs matches the lhs of mod expr // and returns the optimization result. - const Expr* search_ptr = inner; - Expr mult_inner; // The inner multiplication factor - Expr no_opt_sum; // Sum of the exprs that cannot be optimized + const PrimExpr* search_ptr = inner; + PrimExpr mult_inner; // The inner multiplication factor + PrimExpr no_opt_sum; // Sum of the exprs that cannot be optimized while (true) { auto inner_div_ptr = search_ptr->as(); auto inner_mult_ptr = search_ptr->as(); auto inner_add_ptr = search_ptr->as(); if (!inner_div_ptr && !inner_mult_ptr && !inner_add_ptr) { - return std::make_pair(false, Expr()); + return std::make_pair(false, PrimExpr()); } else if (inner_div_ptr) { - Expr overall_mult = mult_inner.get() ? mult_inner * mult_outer : mult_outer; + PrimExpr overall_mult = mult_inner.get() ? mult_inner * mult_outer : mult_outer; if (Equal(overall_mult, inner_div_ptr->b) && Equal(overall_mult, mod_r_expr) && Equal(inner_div_ptr->a, mod_l_expr)) { // Found! - Expr ret = no_opt_sum.get() ? no_opt_sum * mult_outer + mod_l_expr : mod_l_expr; + PrimExpr ret = no_opt_sum.get() ? no_opt_sum * mult_outer + mod_l_expr : mod_l_expr; return std::make_pair(true, ret); } else { - return std::make_pair(false, Expr()); + return std::make_pair(false, PrimExpr()); } } else if (inner_mult_ptr) { mult_inner = mult_inner.get() ? inner_mult_ptr->b * mult_inner : inner_mult_ptr->b; search_ptr = &(inner_mult_ptr->a); } else if (inner_add_ptr) { if (mult_inner.get()) { - return std::make_pair(false, Expr()); + return std::make_pair(false, PrimExpr()); } no_opt_sum = no_opt_sum.get() ? no_opt_sum + inner_add_ptr->a : inner_add_ptr->a; search_ptr = &(inner_add_ptr->b); @@ -142,23 +142,23 @@ inline std::pair MergeMulModInner(const Expr &mult_expr, break; } } - return std::make_pair(false, Expr()); + return std::make_pair(false, PrimExpr()); } // Insert the elements into the corresponding mult_exprs and mod_exprs. // If the element is found to match Mul, it will be pushed to the mult_exprs. // If the element it found to match Mod, it will be pused to the mod_exprs. // Otherwise, the elements will be added to the no_opt_sum variable -inline void MergeMulModInsertElements(const std::vector& eles, - std::list* mult_exprs, - std::list >* mod_exprs, - Expr* no_opt_sum, +inline void MergeMulModInsertElements(const std::vector& eles, + std::list* mult_exprs, + std::list >* mod_exprs, + PrimExpr* no_opt_sum, bool* has_mult, bool* has_mod) { using namespace ir; *has_mult = false; *has_mod = false; - for (const Expr* ele : eles) { + for (const PrimExpr* ele : eles) { auto mod_ptr = ele->as(); auto mult_ptr = ele->as(); if (mod_ptr) { @@ -180,30 +180,30 @@ inline void MergeMulModInsertElements(const std::vector& eles, // The search will be performed repeatively until no pattern is found. // Return: a pair with (false, Expr()) if cannot be optimized. // a pair with (true, optimized_expr) if can be optimized -inline Expr MergeMulMod(const Expr &base) { +inline PrimExpr MergeMulMod(const PrimExpr &base) { using namespace ir; // 1. Prepare the lists. // We store two lists, a list that contain all the elements that match Mul and // a list that contain all the elements that match Mod. // The elements in the Mod will be used to match against the elements in Mul. // The result will then be split and pushed back to these two lists. - Expr simplified_base = Simplify(base); - std::vector eles = ExprSplitAddition(simplified_base); - std::list mult_exprs; - std::list > mod_exprs; - Expr no_opt_sum; + PrimExpr simplified_base = Simplify(base); + std::vector eles = ExprSplitAddition(simplified_base); + std::list mult_exprs; + std::list > mod_exprs; + PrimExpr no_opt_sum; bool has_mult; bool has_mod; MergeMulModInsertElements(eles, &mult_exprs, &mod_exprs, &no_opt_sum, &has_mult, &has_mod); bool find_opt = false; - std::list >::iterator search_mod_it = mod_exprs.begin(); + std::list >::iterator search_mod_it = mod_exprs.begin(); // 2. Exhaustive Search while (search_mod_it != mod_exprs.end()) { - std::list::iterator mult_it = mult_exprs.begin(); + std::list::iterator mult_it = mult_exprs.begin(); bool inner_find_opt = false; while (mult_it != mult_exprs.end()) { - std::pair ret = MergeMulModInner(*mult_it, + std::pair ret = MergeMulModInner(*mult_it, search_mod_it->first, search_mod_it->second); if (ret.first) { @@ -212,7 +212,7 @@ inline Expr MergeMulMod(const Expr &base) { ++search_mod_it; mod_exprs.erase(temp_mod_it); mult_exprs.erase(mult_it); - std::vector ret_eles = ExprSplitAddition(ret.second); + std::vector ret_eles = ExprSplitAddition(ret.second); MergeMulModInsertElements(ret_eles, &mult_exprs, &mod_exprs, &no_opt_sum, &has_mult, &has_mod); if (has_mult) { @@ -233,10 +233,10 @@ inline Expr MergeMulMod(const Expr &base) { if (!find_opt) { return simplified_base; } - for (std::list::iterator it = mult_exprs.begin(); it != mult_exprs.end(); ++it) { + for (std::list::iterator it = mult_exprs.begin(); it != mult_exprs.end(); ++it) { no_opt_sum = no_opt_sum.get() ? no_opt_sum + *it : *it; } - for (std::list >::iterator it = mod_exprs.begin(); + for (std::list >::iterator it = mod_exprs.begin(); it != mod_exprs.end(); ++it) { no_opt_sum = no_opt_sum.get() ? no_opt_sum + indexmod(it->first, it->second) : indexmod(it->first, it->second); @@ -247,8 +247,8 @@ inline Expr MergeMulMod(const Expr &base) { // The buffer offset in convention of number of elements of // original data ignoring number of lanes. // We also perform optimization to simplify the indexing expression. -inline Expr ElemOffset(const BufferNode* n, Array index) { - Expr base = n->elem_offset; +inline PrimExpr ElemOffset(const BufferNode* n, Array index) { + PrimExpr base = n->elem_offset; if (n->strides.size() == 0) { // Scalar case if (n->shape.size() == 0 && index.size() == 1) { @@ -258,7 +258,7 @@ inline Expr ElemOffset(const BufferNode* n, Array index) { } else { CHECK_EQ(n->shape.size(), index.size()); if (index.size() > 0) { - Expr offset = index[0]; + PrimExpr offset = index[0]; for (size_t i = 1; i < index.size(); ++i) { offset = MergeMulMod(offset * n->shape[i] + index[i]); } @@ -279,8 +279,8 @@ inline Expr ElemOffset(const BufferNode* n, Array index) { return base; } -inline Expr BufferOffset(const BufferNode* n, Array index, DataType dtype) { - Expr offset = ElemOffset(n, index); +inline PrimExpr BufferOffset(const BufferNode* n, Array index, DataType dtype) { + PrimExpr offset = ElemOffset(n, index); if (n->dtype.lanes() != 1) { offset = offset * make_const(offset.dtype(), dtype.lanes()); } @@ -291,7 +291,7 @@ inline Expr BufferOffset(const BufferNode* n, Array index, DataType dtype) } } -Expr Buffer::vload(Array begin, DataType dtype) const { +PrimExpr Buffer::vload(Array begin, DataType dtype) const { // specially handle bool, stored asDataType::Int(8) const BufferNode* n = operator->(); CHECK(dtype.element_of() == n->dtype.element_of() && @@ -311,7 +311,7 @@ Expr Buffer::vload(Array begin, DataType dtype) const { } } -Stmt Buffer::vstore(Array begin, Expr value) const { +Stmt Buffer::vstore(Array begin, PrimExpr value) const { // specially handle bool, stored asDataType::Int(8) const BufferNode* n = operator->(); DataType dtype = value.dtype(); @@ -333,9 +333,9 @@ Stmt Buffer::vstore(Array begin, Expr value) const { Buffer Buffer::MakeStrideView() const { if ((*this)->strides.size() != 0) return *this; if ((*this)->shape.size() == 0) return *this; - std::vector temp; + std::vector temp; auto n = make_object(*operator->()); - Expr acc = make_const(n->DefaultIndexType(), 1); + PrimExpr acc = make_const(n->DefaultIndexType(), 1); for (size_t i = n->shape.size(); i != 0 ; --i) { temp.push_back(acc); acc = acc * n->shape[i - 1]; @@ -346,11 +346,11 @@ Buffer Buffer::MakeStrideView() const { return Buffer(n); } -Buffer Buffer::MakeSlice(Array begins, Array extents) const { +Buffer Buffer::MakeSlice(Array begins, Array extents) const { const BufferNode* n = operator->(); begins = SimplifyArray(begins); - Expr elem_offset = ir::Simplify(ElemOffset(n, begins)); - Array strides = n->strides; + PrimExpr elem_offset = ir::Simplify(ElemOffset(n, begins)); + Array strides = n->strides; if (strides.size() == 0) { bool can_relax = true; bool need_stride = false; @@ -381,19 +381,22 @@ Buffer Buffer::MakeSlice(Array begins, Array extents) const { n->buffer_type); } -Expr Buffer::access_ptr(int access_mask, DataType ptr_type, int content_lanes, Expr offset) const { +PrimExpr Buffer::access_ptr(int access_mask, + DataType ptr_type, + int content_lanes, + PrimExpr offset) const { const BufferNode* self = operator->(); - Expr e_dtype; - Expr extent; + PrimExpr e_dtype; + PrimExpr extent; if (self->shape.size() == 0) { extent = make_const(self->DefaultIndexType(), 1); } else if (self->strides.size() == self->shape.size()) { int highest_dim = 0; extent = self->strides[highest_dim] * self->shape[highest_dim] - offset; } else { - extent = arith::ComputeReduce(self->shape, Expr()) - offset; + extent = arith::ComputeReduce(self->shape, PrimExpr()) - offset; } - Expr elem_offset = self->elem_offset + offset; + PrimExpr elem_offset = self->elem_offset + offset; if (content_lanes > 1) { e_dtype = ir::TypeAnnotation(self->dtype.with_lanes(content_lanes)); extent = extent / make_const(self->elem_offset.dtype(), content_lanes); @@ -402,7 +405,7 @@ Expr Buffer::access_ptr(int access_mask, DataType ptr_type, int content_lanes, E } else { e_dtype = ir::TypeAnnotation(self->dtype); } - Array acc_args{ + Array acc_args{ e_dtype, self->data, elem_offset, extent, make_const(DataType::Int(32), access_mask)}; return ir::CallNode::make( @@ -411,9 +414,9 @@ Expr Buffer::access_ptr(int access_mask, DataType ptr_type, int content_lanes, E Buffer BufferNode::make(Var data, DataType dtype, - Array shape, - Array strides, - Expr elem_offset, + Array shape, + Array strides, + PrimExpr elem_offset, std::string name, std::string scope, int data_alignment, diff --git a/src/lang/data_layout.cc b/src/lang/data_layout.cc index c30f34476709..ba5e4adeb66f 100644 --- a/src/lang/data_layout.cc +++ b/src/lang/data_layout.cc @@ -103,13 +103,13 @@ Layout::Layout(const std::string& name) { // NOLINT(*) << " before dimension " << c; std::string shape_name("_shape"); shape_name.insert(0, 1, c); - IterVar axis = IterVarNode::make(Range(Expr(0), Var(shape_name)), + IterVar axis = IterVarNode::make(Range(PrimExpr(0), Var(shape_name)), Var(std::string(1, c)), kDataPar); node->axes.push_back(axis); } else if (c >= 'a' && c <= 'z') { CHECK_GT(factor, 0) << "Invalid layout " << name << ": invalid factor size " << factor << " for dimension " << c; - IterVar axis = IterVarNode::make(Range(Expr(0), Expr(factor)), + IterVar axis = IterVarNode::make(Range(PrimExpr(0), PrimExpr(factor)), Var(std::string(1, c)), kDataPar); node->axes.push_back(axis); factor = 0; @@ -171,7 +171,7 @@ Layout Layout::Split(const LayoutAxis &axis, size_t target_pos, int32_t factor) Array new_layout; for (size_t i = 0; i <= this->ndim(); ++i) { if (i == target_pos) { - new_layout.push_back(IterVarNode::make(Range(Expr(0), Expr(factor)), + new_layout.push_back(IterVarNode::make(Range(PrimExpr(0), PrimExpr(factor)), Var(axis.ToSubordinate().name()), kDataPar)); } if (i == this->ndim()) break; @@ -200,7 +200,7 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) p->stream << "Layout(" << l->name << ")"; }); -inline bool GetStoreRule(Array* rule, +inline bool GetStoreRule(Array* rule, const Layout& src_layout, const Layout& dst_layout) { if (!src_layout.defined() || src_layout.name().empty() || @@ -210,17 +210,17 @@ inline bool GetStoreRule(Array* rule, for (size_t i = 0; i < dst_layout.ndim(); ++i) { const auto& store_axis = dst_layout[i]; const IterVar& store_axis_impl = dst_layout->axes[i]; - Expr store(0); + PrimExpr store(0); for (size_t j = 0; j < src_layout.ndim(); ++j) { const auto& orig_axis = src_layout[j]; const IterVar& orig_axis_impl = src_layout->axes[j]; if (store_axis.ToPrimal() == orig_axis.ToPrimal()) { if (orig_axis.IsPrimal()) { - Expr orig_var = orig_axis_impl->var; + PrimExpr orig_var = orig_axis_impl->var; const int32_t factor = src_layout.FactorOf(orig_axis); if (factor > 0) { - orig_var = orig_var * Expr(factor); + orig_var = orig_var * PrimExpr(factor); } store = store + orig_var; } else { @@ -236,7 +236,7 @@ inline bool GetStoreRule(Array* rule, if (store_axis.IsPrimal()) { const int32_t factor = dst_layout.FactorOf(store_axis); if (factor > 0) { - store = indexdiv(store, Expr(factor)); + store = indexdiv(store, PrimExpr(factor)); } } else { store = indexmod(store, store_axis_impl->dom->extent); @@ -247,21 +247,21 @@ inline bool GetStoreRule(Array* rule, return true; } -inline Array TransformIndex(const Array& src_index, +inline Array TransformIndex(const Array& src_index, const Array& src_axis, - const Array& transform_rule) { - Array result; - std::unordered_map bind_map; + const Array& transform_rule) { + Array result; + std::unordered_map bind_map; for (size_t i = 0; i < src_index.size(); ++i) { bind_map[src_axis[i]->var.get()] = src_index[i]; } - for (Expr rule : transform_rule) { + for (PrimExpr rule : transform_rule) { result.push_back(ir::Simplify(ir::Substitute(rule, bind_map))); } return result; } -Array BijectiveLayout::ForwardIndex(const Array& src_index) const { +Array BijectiveLayout::ForwardIndex(const Array& src_index) const { CHECK(defined()) << "Cannot operate on an undefined bijective layout."; const BijectiveLayoutNode* self = operator->(); CHECK_EQ(src_index.size(), self->src_layout->axes.size()) @@ -270,7 +270,7 @@ Array BijectiveLayout::ForwardIndex(const Array& src_index) const { } -Array BijectiveLayout::BackwardIndex(const Array& dst_index) const { +Array BijectiveLayout::BackwardIndex(const Array& dst_index) const { CHECK(defined()) << "Cannot operate on an undefined bijective layout."; const BijectiveLayoutNode* self = operator->(); CHECK_EQ(dst_index.size(), self->dst_layout->axes.size()) @@ -278,19 +278,19 @@ Array BijectiveLayout::BackwardIndex(const Array& dst_index) const { return TransformIndex(dst_index, self->dst_layout->axes, self->backward_rule); } -inline Array TransformShape(const Array& src_shape, +inline Array TransformShape(const Array& src_shape, const Array& src_axis, const Array& target_axis, - const Array& transform_rule) { + const Array& transform_rule) { CHECK_EQ(src_shape.size(), src_axis.size()); // bind variables for original axes // for major-axis, bind the corresponding size // for minor-axis, simply bind it as 0, so that we can reuse forward/backward_rule, // e.g., (C * 16 + c) / 32 - std::unordered_map bind_map; + std::unordered_map bind_map; std::unordered_set symbolic_var_set; for (size_t i = 0; i < src_shape.size(); ++i) { - Expr orig_shape = src_shape[i]; + PrimExpr orig_shape = src_shape[i]; IterVar orig_axis = src_axis[i]; if (orig_shape.as()) { symbolic_var_set.insert(i); @@ -305,7 +305,7 @@ inline Array TransformShape(const Array& src_shape, << orig_axis->dom->extent << ", get " << orig_shape; } } - bind_map[orig_axis->var.get()] = Expr(0); + bind_map[orig_axis->var.get()] = PrimExpr(0); } else { bind_map[orig_axis->var.get()] = orig_shape; } @@ -313,10 +313,10 @@ inline Array TransformShape(const Array& src_shape, // infer the target shape, // for major-axis, use the forward/backward_rule directly, // for minor-axis, simply use the extent. - Array result; + Array result; CHECK_EQ(transform_rule.size(), target_axis.size()); for (size_t i = 0; i < transform_rule.size(); ++i) { - Expr rule = transform_rule[i]; + PrimExpr rule = transform_rule[i]; IterVar axis = target_axis[i]; if (!LayoutAxis::Get(axis).IsPrimal()) { result.push_back(axis->dom->extent); @@ -331,14 +331,14 @@ inline Array TransformShape(const Array& src_shape, return result; } -Array BijectiveLayout::ForwardShape(const Array& shape) const { +Array BijectiveLayout::ForwardShape(const Array& shape) const { CHECK(defined()) << "Cannot operate on an undefined bijective layout."; const BijectiveLayoutNode* self = operator->(); return TransformShape(shape, self->src_layout->axes, self->dst_layout->axes, self->forward_rule); } -Array BijectiveLayout::BackwardShape(const Array& shape) const { +Array BijectiveLayout::BackwardShape(const Array& shape) const { CHECK(defined()) << "Cannot operate on an undefined bijective layout."; const BijectiveLayoutNode* self = operator->(); return TransformShape(shape, self->dst_layout->axes, diff --git a/src/lang/expr.cc b/src/lang/expr.cc index 58a97ed91742..a7289369bcd4 100644 --- a/src/lang/expr.cc +++ b/src/lang/expr.cc @@ -29,14 +29,14 @@ namespace tvm { -Expr::Expr(int32_t value) - : Expr(IntImmNode::make(DataType::Int(32), value)) {} +PrimExpr::PrimExpr(int32_t value) + : PrimExpr(IntImmNode::make(DataType::Int(32), value)) {} -Expr::Expr(float value) - : Expr(ir::FloatImmNode::make(DataType::Float(32), value)) {} +PrimExpr::PrimExpr(float value) + : PrimExpr(ir::FloatImmNode::make(DataType::Float(32), value)) {} -Expr::Expr(std::string str) - : Expr(ir::StringImmNode::make(str)) {} +PrimExpr::PrimExpr(std::string str) + : PrimExpr(ir::StringImmNode::make(str)) {} Var::Var(std::string name_hint, DataType t) : Var(VarNode::make(t, name_hint)) {} @@ -48,7 +48,7 @@ Var VarNode::make(DataType t, std::string name_hint) { return Var(node); } -Range::Range(Expr begin, Expr end) +Range::Range(PrimExpr begin, PrimExpr end) : Range(make_object( begin, is_zero(begin) ? end : (end - begin))) { @@ -63,7 +63,7 @@ Integer IntImmNode::make(DataType t, int64_t value) { return Integer(node); } -Range Range::make_by_min_extent(Expr min, Expr extent) { +Range Range::make_by_min_extent(PrimExpr min, PrimExpr extent) { return Range(make_object(min, extent)); } diff --git a/src/lang/expr_operator.cc b/src/lang/expr_operator.cc index 34fac72a0ab5..078ca628ad24 100644 --- a/src/lang/expr_operator.cc +++ b/src/lang/expr_operator.cc @@ -30,13 +30,13 @@ namespace tvm { // simple cast that only checks if type matches and cast -inline Expr SimpleCast(const DataType& t, Expr value) { +inline PrimExpr SimpleCast(const DataType& t, PrimExpr value) { if (value.dtype() == t) return value; return ir::CastNode::make(t, value); } // The public function with a quick checking path. -void BinaryOpMatchTypes(Expr& lhs, Expr& rhs) { // NOLINT(*) +void BinaryOpMatchTypes(PrimExpr& lhs, PrimExpr& rhs) { // NOLINT(*) if (lhs.dtype() == rhs.dtype()) return; DataType ltype = lhs.dtype(); DataType rtype = rhs.dtype(); @@ -80,7 +80,7 @@ void BinaryOpMatchTypes(Expr& lhs, Expr& rhs) { // NOLINT(*) // maximum and min limits -Expr max_value(const DataType& dtype) { +PrimExpr max_value(const DataType& dtype) { using namespace ir; CHECK_EQ(dtype.lanes(), 1); if (dtype.is_int()) { @@ -109,10 +109,10 @@ Expr max_value(const DataType& dtype) { } } LOG(FATAL) << "Cannot decide max_value for type" << dtype; - return Expr(); + return PrimExpr(); } -Expr min_value(const DataType& dtype) { +PrimExpr min_value(const DataType& dtype) { using namespace ir; CHECK_EQ(dtype.lanes(), 1); if (dtype.is_int()) { @@ -135,7 +135,7 @@ Expr min_value(const DataType& dtype) { } } LOG(FATAL) << "Cannot decide min_value for type" << dtype; - return Expr(); + return PrimExpr(); } template @@ -152,7 +152,7 @@ inline bool ConstPowerHelper(ValueType val, int *shift) { return true; } -bool is_const_power_of_two_integer(const Expr& x, int* shift) { +bool is_const_power_of_two_integer(const PrimExpr& x, int* shift) { if (const auto* op = x.as()) { return ConstPowerHelper(op->value, shift); } else if (const auto* op = x.as()) { @@ -162,7 +162,7 @@ bool is_const_power_of_two_integer(const Expr& x, int* shift) { } } -Expr cast(const DataType& t, Expr value) { +PrimExpr cast(const DataType& t, PrimExpr value) { using ir::IntImmNode; using ir::UIntImmNode; using ir::FloatImmNode; @@ -200,21 +200,21 @@ Expr cast(const DataType& t, Expr value) { } } -Expr reinterpret(const DataType& t, Expr value) { +PrimExpr reinterpret(const DataType& t, PrimExpr value) { if (value.dtype() == t) return value; return ir::CallNode::make( t, ir::CallNode::reinterpret, { value }, ir::CallNode::PureIntrinsic); } -Expr operator+(Expr a, Expr b) { +PrimExpr operator+(PrimExpr a, PrimExpr b) { BinaryOpMatchTypes(a, b); - Expr ret = arith::TryConstFold(a, b); + PrimExpr ret = arith::TryConstFold(a, b); if (ret.defined()) return ret; return ir::AddNode::make(a, b); } // negation -Expr operator-(Expr a) { +PrimExpr operator-(PrimExpr a) { using ir::IntImmNode; using ir::FloatImmNode; const IntImmNode* pa = a.as(); @@ -224,76 +224,76 @@ Expr operator-(Expr a) { return make_zero(a.dtype()) - a; } -Expr operator-(Expr a, Expr b) { +PrimExpr operator-(PrimExpr a, PrimExpr b) { BinaryOpMatchTypes(a, b); - Expr ret = arith::TryConstFold(a, b); + PrimExpr ret = arith::TryConstFold(a, b); if (ret.defined()) return ret; return ir::SubNode::make(a, b); } -Expr operator*(Expr a, Expr b) { +PrimExpr operator*(PrimExpr a, PrimExpr b) { BinaryOpMatchTypes(a, b); - Expr ret = arith::TryConstFold(a, b); + PrimExpr ret = arith::TryConstFold(a, b); if (ret.defined()) return ret; return ir::MulNode::make(a, b); } -Expr div(Expr a, Expr b) { +PrimExpr div(PrimExpr a, PrimExpr b) { BinaryOpMatchTypes(a, b); - Expr ret = arith::TryConstFold(a, b); + PrimExpr ret = arith::TryConstFold(a, b); if (ret.defined()) return ret; return ir::DivNode::make(a, b); } -Expr truncdiv(Expr a, Expr b) { +PrimExpr truncdiv(PrimExpr a, PrimExpr b) { CHECK(a.dtype().is_int() || a.dtype().is_uint()); CHECK(b.dtype().is_int() || b.dtype().is_uint()); return div(a, b); } -Expr truncmod(Expr a, Expr b) { +PrimExpr truncmod(PrimExpr a, PrimExpr b) { BinaryOpMatchTypes(a, b); - Expr ret = arith::TryConstFold(a, b); + PrimExpr ret = arith::TryConstFold(a, b); if (ret.defined()) return ret; return ir::ModNode::make(a, b); } -Expr operator/(Expr a, Expr b) { +PrimExpr operator/(PrimExpr a, PrimExpr b) { return div(a, b); } -Expr operator%(Expr a, Expr b) { +PrimExpr operator%(PrimExpr a, PrimExpr b) { return truncmod(a, b); } // TODO(tqchen): switch to floordiv -Expr indexdiv(Expr a, Expr b) { +PrimExpr indexdiv(PrimExpr a, PrimExpr b) { return floordiv(a, b); } -Expr indexmod(Expr a, Expr b) { +PrimExpr indexmod(PrimExpr a, PrimExpr b) { return floormod(a, b); } -Expr floordiv(Expr a, Expr b) { +PrimExpr floordiv(PrimExpr a, PrimExpr b) { CHECK(a.dtype().is_int() || a.dtype().is_uint()); CHECK(b.dtype().is_int() || b.dtype().is_uint()); BinaryOpMatchTypes(a, b); - Expr ret = arith::TryConstFold(a, b); + PrimExpr ret = arith::TryConstFold(a, b); if (ret.defined()) return ret; return ir::FloorDivNode::make(a, b); } -Expr floormod(Expr a, Expr b) { +PrimExpr floormod(PrimExpr a, PrimExpr b) { CHECK(a.dtype().is_int() || a.dtype().is_uint()); CHECK(b.dtype().is_int() || b.dtype().is_uint()); BinaryOpMatchTypes(a, b); - Expr ret = arith::TryConstFold(a, b); + PrimExpr ret = arith::TryConstFold(a, b); if (ret.defined()) return ret; return ir::FloorModNode::make(a, b); } -Expr min(Expr a, Expr b) { +PrimExpr min(PrimExpr a, PrimExpr b) { // inf-aware simplificaiton using arith::is_pos_inf; using arith::is_neg_inf; @@ -302,12 +302,12 @@ Expr min(Expr a, Expr b) { if (is_pos_inf(b)) return a; if (is_neg_inf(b)) return b; BinaryOpMatchTypes(a, b); - Expr ret = arith::TryConstFold(a, b); + PrimExpr ret = arith::TryConstFold(a, b); if (ret.defined()) return ret; return ir::MinNode::make(a, b); } -Expr max(Expr a, Expr b) { +PrimExpr max(PrimExpr a, PrimExpr b) { // inf-aware simplificaiton using arith::is_pos_inf; using arith::is_neg_inf; @@ -316,12 +316,12 @@ Expr max(Expr a, Expr b) { if (is_pos_inf(b)) return b; if (is_neg_inf(b)) return a; BinaryOpMatchTypes(a, b); - Expr ret = arith::TryConstFold(a, b); + PrimExpr ret = arith::TryConstFold(a, b); if (ret.defined()) return ret; return ir::MaxNode::make(a, b); } -Expr if_then_else(Expr cond, Expr true_value, Expr false_value) { +PrimExpr if_then_else(PrimExpr cond, PrimExpr true_value, PrimExpr false_value) { using ir::IntImmNode; using ir::UIntImmNode; CHECK(cond.dtype() == DataType::Bool(1)) @@ -347,7 +347,7 @@ Expr if_then_else(Expr cond, Expr true_value, Expr false_value) { ir::CallNode::PureIntrinsic); } -Expr likely(Expr cond) { +PrimExpr likely(PrimExpr cond) { if (is_const(cond)) return cond; return ir::CallNode::make(cond.dtype(), ir::CallNode::likely, @@ -355,72 +355,72 @@ Expr likely(Expr cond) { ir::CallNode::PureIntrinsic); } -Expr operator>(Expr a, Expr b) { +PrimExpr operator>(PrimExpr a, PrimExpr b) { BinaryOpMatchTypes(a, b); - Expr ret = arith::TryConstFold(a, b); + PrimExpr ret = arith::TryConstFold(a, b); if (ret.defined()) return ret; return ir::GTNode::make(a, b); } -Expr operator>=(Expr a, Expr b) { +PrimExpr operator>=(PrimExpr a, PrimExpr b) { BinaryOpMatchTypes(a, b); - Expr ret = arith::TryConstFold(a, b); + PrimExpr ret = arith::TryConstFold(a, b); if (ret.defined()) return ret; return ir::GENode::make(a, b); } -Expr operator<(Expr a, Expr b) { +PrimExpr operator<(PrimExpr a, PrimExpr b) { BinaryOpMatchTypes(a, b); - Expr ret = arith::TryConstFold(a, b); + PrimExpr ret = arith::TryConstFold(a, b); if (ret.defined()) return ret; return ir::LTNode::make(a, b); } -Expr operator<=(Expr a, Expr b) { +PrimExpr operator<=(PrimExpr a, PrimExpr b) { BinaryOpMatchTypes(a, b); - Expr ret = arith::TryConstFold(a, b); + PrimExpr ret = arith::TryConstFold(a, b); if (ret.defined()) return ret; return ir::LENode::make(a, b); } -Expr operator==(Expr a, Expr b) { +PrimExpr operator==(PrimExpr a, PrimExpr b) { BinaryOpMatchTypes(a, b); - Expr ret = arith::TryConstFold(a, b); + PrimExpr ret = arith::TryConstFold(a, b); if (ret.defined()) return ret; return ir::EQNode::make(a, b); } -Expr operator!=(Expr a, Expr b) { +PrimExpr operator!=(PrimExpr a, PrimExpr b) { BinaryOpMatchTypes(a, b); - Expr ret = arith::TryConstFold(a, b); + PrimExpr ret = arith::TryConstFold(a, b); if (ret.defined()) return ret; return ir::NENode::make(a, b); } -Expr operator&&(Expr a, Expr b) { +PrimExpr operator&&(PrimExpr a, PrimExpr b) { CHECK(a.dtype().is_bool()); CHECK(b.dtype().is_bool()); - Expr ret = arith::TryConstFold(a, b); + PrimExpr ret = arith::TryConstFold(a, b); if (ret.defined()) return ret; return ir::AndNode::make(a, b); } -Expr operator||(Expr a, Expr b) { +PrimExpr operator||(PrimExpr a, PrimExpr b) { CHECK(a.dtype().is_bool()); CHECK(b.dtype().is_bool()); - Expr ret = arith::TryConstFold(a, b); + PrimExpr ret = arith::TryConstFold(a, b); if (ret.defined()) return ret; return ir::OrNode::make(a, b); } -Expr operator!(Expr a) { +PrimExpr operator!(PrimExpr a) { CHECK(a.dtype().is_bool()); - Expr ret = arith::TryConstFold(a); + PrimExpr ret = arith::TryConstFold(a); if (ret.defined()) return ret; return ir::NotNode::make(a); } -Expr operator>>(Expr a, Expr b) { +PrimExpr operator>>(PrimExpr a, PrimExpr b) { BinaryOpMatchTypes(a, b); TVM_INDEX_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); @@ -433,7 +433,7 @@ Expr operator>>(Expr a, Expr b) { a.dtype(), ir::CallNode::shift_right, { a, b }, ir::CallNode::PureIntrinsic); } -Expr operator<<(Expr a, Expr b) { +PrimExpr operator<<(PrimExpr a, PrimExpr b) { BinaryOpMatchTypes(a, b); TVM_INDEX_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); @@ -446,7 +446,7 @@ Expr operator<<(Expr a, Expr b) { a.dtype(), ir::CallNode::shift_left, { a, b }, ir::CallNode::PureIntrinsic); } -Expr operator&(Expr a, Expr b) { +PrimExpr operator&(PrimExpr a, PrimExpr b) { BinaryOpMatchTypes(a, b); TVM_INDEX_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); @@ -456,7 +456,7 @@ Expr operator&(Expr a, Expr b) { a.dtype(), ir::CallNode::bitwise_and, { a, b }, ir::CallNode::PureIntrinsic); } -Expr operator|(Expr a, Expr b) { +PrimExpr operator|(PrimExpr a, PrimExpr b) { BinaryOpMatchTypes(a, b); TVM_INDEX_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); @@ -466,7 +466,7 @@ Expr operator|(Expr a, Expr b) { a.dtype(), ir::CallNode::bitwise_or, { a, b }, ir::CallNode::PureIntrinsic); } -Expr operator^(Expr a, Expr b) { +PrimExpr operator^(PrimExpr a, PrimExpr b) { BinaryOpMatchTypes(a, b); TVM_INDEX_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); @@ -476,20 +476,20 @@ Expr operator^(Expr a, Expr b) { a.dtype(), ir::CallNode::bitwise_xor, { a, b }, ir::CallNode::PureIntrinsic); } -Expr operator~(Expr a) { +PrimExpr operator~(PrimExpr a) { CHECK(a.dtype().is_int() || a.dtype().is_uint()); return ir::CallNode::make( a.dtype(), ir::CallNode::bitwise_not, { a }, ir::CallNode::PureIntrinsic); } -Expr pow(Expr x, Expr y) { +PrimExpr pow(PrimExpr x, PrimExpr y) { BinaryOpMatchTypes(x, y); CHECK(x.dtype().is_float()) << "power only applies to float"; return ir::CallNode::make( x.dtype(), "pow", { x, y }, ir::CallNode::PureIntrinsic); } -Expr abs(Expr x) { +PrimExpr abs(PrimExpr x) { if (x.dtype().is_int()) { using ir::IntImmNode; const IntImmNode* px = x.as(); @@ -513,7 +513,7 @@ Expr abs(Expr x) { } } -Expr isnan(Expr x) { +PrimExpr isnan(PrimExpr x) { DataType t = DataType::Bool(x.dtype().lanes()); if (x.dtype().is_int() || x.dtype().is_uint()) { return make_const(t, false); @@ -537,97 +537,97 @@ Expr isnan(Expr x) { } } -Expr sum(Expr source, Array rdom) { +PrimExpr sum(PrimExpr source, Array rdom) { Var x("x", source.dtype()), y("y", source.dtype()); - Expr result = ir::AddNode::make(x, y); - Expr identity_element = make_zero(source.dtype()); + PrimExpr result = ir::AddNode::make(x, y); + PrimExpr identity_element = make_zero(source.dtype()); ir::CommReducer combiner = ir::CommReducerNode::make({x}, {y}, {result}, {identity_element}); return ir::ReduceNode::make(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0); } -Expr all(Expr source, Array rdom) { +PrimExpr all(PrimExpr source, Array rdom) { CHECK(source.dtype().is_bool()); Var x("x", source.dtype()), y("y", source.dtype()); - Expr result = ir::AndNode::make(x, y); - Expr identity_element = make_const(source.dtype(), true); + PrimExpr result = ir::AndNode::make(x, y); + PrimExpr identity_element = make_const(source.dtype(), true); ir::CommReducer combiner = ir::CommReducerNode::make({x}, {y}, {result}, {identity_element}); return ir::ReduceNode::make(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0); } -Expr any(Expr source, Array rdom) { +PrimExpr any(PrimExpr source, Array rdom) { CHECK(source.dtype().is_bool()); Var x("x", source.dtype()), y("y", source.dtype()); - Expr result = ir::OrNode::make(x, y); - Expr identity_element = make_const(source.dtype(), false); + PrimExpr result = ir::OrNode::make(x, y); + PrimExpr identity_element = make_const(source.dtype(), false); ir::CommReducer combiner = ir::CommReducerNode::make({x}, {y}, {result}, {identity_element}); return ir::ReduceNode::make(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0); } -Expr max(Expr source, Array rdom) { +PrimExpr max(PrimExpr source, Array rdom) { Var x("x", source.dtype()), y("y", source.dtype()); - Expr result = ir::MaxNode::make(x, y); - Expr identity_element = min_value(source.dtype()); + PrimExpr result = ir::MaxNode::make(x, y); + PrimExpr identity_element = min_value(source.dtype()); ir::CommReducer combiner = ir::CommReducerNode::make({x}, {y}, {result}, {identity_element}); return ir::ReduceNode::make(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0); } -Expr min(Expr source, Array rdom) { +PrimExpr min(PrimExpr source, Array rdom) { Var x("x", source.dtype()), y("y", source.dtype()); - Expr result = ir::MinNode::make(x, y); - Expr identity_element = max_value(source.dtype()); + PrimExpr result = ir::MinNode::make(x, y); + PrimExpr identity_element = max_value(source.dtype()); ir::CommReducer combiner = ir::CommReducerNode::make({x}, {y}, {result}, {identity_element}); return ir::ReduceNode::make(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0); } -Expr prod(Expr source, Array rdom) { +PrimExpr prod(PrimExpr source, Array rdom) { Var x("x", source.dtype()), y("y", source.dtype()); - Expr result = ir::MulNode::make(x, y); - Expr identity_element = make_const(source.dtype(), 1); + PrimExpr result = ir::MulNode::make(x, y); + PrimExpr identity_element = make_const(source.dtype(), 1); ir::CommReducer combiner = ir::CommReducerNode::make({x}, {y}, {result}, {identity_element}); return ir::ReduceNode::make(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0); } -Expr fmod(Expr x, Expr y) { +PrimExpr fmod(PrimExpr x, PrimExpr y) { BinaryOpMatchTypes(x, y); CHECK(x.dtype().is_float()) << "fmod only applies to float"; return ir::CallNode::make(x.dtype(), "fmod", { x, y }, ir::CallNode::PureIntrinsic); } -Expr floor(Expr x) { +PrimExpr floor(PrimExpr x) { using ir::FloatImmNode; const FloatImmNode* fx = x.as(); if (fx) return FloatImmNode::make(x.dtype(), std::floor(fx->value)); return ir::CallNode::make(x.dtype(), "floor", {x}, ir::CallNode::PureIntrinsic); } -Expr ceil(Expr x) { +PrimExpr ceil(PrimExpr x) { using ir::FloatImmNode; const FloatImmNode* fx = x.as(); if (fx) return FloatImmNode::make(x.dtype(), std::ceil(fx->value)); return ir::CallNode::make(x.dtype(), "ceil", {x}, ir::CallNode::PureIntrinsic); } -Expr round(Expr x) { +PrimExpr round(PrimExpr x) { using ir::FloatImmNode; const FloatImmNode* fx = x.as(); if (fx) return FloatImmNode::make(x.dtype(), std::nearbyint(fx->value)); return ir::CallNode::make(x.dtype(), "round", {x}, ir::CallNode::PureIntrinsic); } -Expr nearbyint(Expr x) { +PrimExpr nearbyint(PrimExpr x) { using ir::FloatImmNode; const FloatImmNode* fx = x.as(); if (fx) return FloatImmNode::make(x.dtype(), std::nearbyint(fx->value)); return ir::CallNode::make(x.dtype(), "nearbyint", {x}, ir::CallNode::PureIntrinsic); } -Expr trunc(Expr x) { +PrimExpr trunc(PrimExpr x) { using ir::FloatImmNode; const FloatImmNode* fx = x.as(); if (fx) { diff --git a/src/lang/ir.cc b/src/lang/ir.cc index 6b777cc5e887..ad7f260226bd 100644 --- a/src/lang/ir.cc +++ b/src/lang/ir.cc @@ -31,41 +31,41 @@ namespace tvm { namespace ir { // constructors -Expr UIntImmNode::make(DataType t, uint64_t value) { +PrimExpr UIntImmNode::make(DataType t, uint64_t value) { CHECK(t.is_uint() && t.lanes() == 1) << "ValueError: UIntImm can only take scalar"; ObjectPtr node = make_object(); node->dtype = t; node->value = value; - return Expr(node); + return PrimExpr(node); } -Expr FloatImmNode::make(DataType t, double value) { +PrimExpr FloatImmNode::make(DataType t, double value) { CHECK_EQ(t.lanes(), 1) << "ValueError: FloatImm can only take scalar"; ObjectPtr node = make_object(); node->dtype = t; node->value = value; - return Expr(node); + return PrimExpr(node); } -Expr StringImmNode::make(std::string value) { +PrimExpr StringImmNode::make(std::string value) { ObjectPtr node = make_object(); node->dtype = DataType::Handle(); node->value = std::move(value); - return Expr(node); + return PrimExpr(node); } -Expr CastNode::make(DataType t, Expr value) { +PrimExpr CastNode::make(DataType t, PrimExpr value) { CHECK(value.defined()); CHECK_EQ(t.lanes(), value.dtype().lanes()); ObjectPtr node = make_object(); node->dtype = t; node->value = std::move(value); - return Expr(node); + return PrimExpr(node); } -Expr AndNode::make(Expr a, Expr b) { +PrimExpr AndNode::make(PrimExpr a, PrimExpr b) { CHECK(a.defined()) << "ValueError: a is undefined"; CHECK(b.defined()) << "ValueError: b is undefined"; CHECK(a.dtype().is_bool()); @@ -76,10 +76,10 @@ Expr AndNode::make(Expr a, Expr b) { node->dtype = DataType::Bool(a.dtype().lanes()); node->a = std::move(a); node->b = std::move(b); - return Expr(node); + return PrimExpr(node); } -Expr OrNode::make(Expr a, Expr b) { +PrimExpr OrNode::make(PrimExpr a, PrimExpr b) { CHECK(a.defined()) << "ValueError: a is undefined"; CHECK(b.defined()) << "ValueError: b is undefined"; CHECK(a.dtype().is_bool()); @@ -90,20 +90,20 @@ Expr OrNode::make(Expr a, Expr b) { node->dtype = DataType::Bool(a.dtype().lanes()); node->a = std::move(a); node->b = std::move(b); - return Expr(node); + return PrimExpr(node); } -Expr NotNode::make(Expr a) { +PrimExpr NotNode::make(PrimExpr a) { CHECK(a.defined()) << "ValueError: a is undefined"; CHECK(a.dtype().is_bool()); ObjectPtr node = make_object(); node->dtype = DataType::Bool(a.dtype().lanes()); node->a = std::move(a); - return Expr(node); + return PrimExpr(node); } -Expr SelectNode::make(Expr condition, Expr true_value, Expr false_value) { +PrimExpr SelectNode::make(PrimExpr condition, PrimExpr true_value, PrimExpr false_value) { CHECK(condition.defined()) << "ValueError: condition is undefined"; CHECK(true_value.defined()) << "ValueError: true_value is undefined"; CHECK(false_value.defined()) << "ValueError: true_value is undefined"; @@ -116,10 +116,10 @@ Expr SelectNode::make(Expr condition, Expr true_value, Expr false_value) { node->condition = std::move(condition); node->true_value = std::move(true_value); node->false_value = std::move(false_value); - return Expr(node); + return PrimExpr(node); } -Expr LoadNode::make(DataType dtype, Var buffer_var, Expr index, Expr predicate) { +PrimExpr LoadNode::make(DataType dtype, Var buffer_var, PrimExpr index, PrimExpr predicate) { CHECK(buffer_var.defined()); CHECK(predicate.defined()); CHECK(index.defined()); @@ -132,10 +132,10 @@ Expr LoadNode::make(DataType dtype, Var buffer_var, Expr index, Expr predicate) node->index = std::move(index); node->predicate = std::move(predicate); - return Expr(node); + return PrimExpr(node); } -Expr RampNode::make(Expr base, Expr stride, int lanes) { +PrimExpr RampNode::make(PrimExpr base, PrimExpr stride, int lanes) { CHECK(base.defined()); CHECK(stride.defined()); CHECK(base.dtype().is_scalar()); @@ -148,10 +148,10 @@ Expr RampNode::make(Expr base, Expr stride, int lanes) { node->base = base; node->stride = stride; node->lanes = lanes; - return Expr(node); + return PrimExpr(node); } -Expr BroadcastNode::make(Expr value, int lanes) { +PrimExpr BroadcastNode::make(PrimExpr value, int lanes) { CHECK(value.defined()); CHECK(value.dtype().is_scalar()); CHECK_GT(lanes, 1); @@ -160,10 +160,10 @@ Expr BroadcastNode::make(Expr value, int lanes) { node->dtype = value.dtype().with_lanes(lanes); node->value = std::move(value); node->lanes = lanes; - return Expr(node); + return PrimExpr(node); } -Expr LetNode::make(Var var, Expr value, Expr body) { +PrimExpr LetNode::make(Var var, PrimExpr value, PrimExpr body) { CHECK(value.defined()); CHECK(body.defined()); CHECK_EQ(value.dtype(), var.dtype()); @@ -173,7 +173,7 @@ Expr LetNode::make(Var var, Expr value, Expr body) { node->var = std::move(var); node->value = std::move(value); node->body = std::move(body); - return Expr(node); + return PrimExpr(node); } const char* CallNode::vectorizable_intrinsics[] = { @@ -192,9 +192,9 @@ bool CallNode::is_vectorizable() const { return false; } -Expr CallNode::make(DataType dtype, +PrimExpr CallNode::make(DataType dtype, std::string name, - Array args, + Array args, CallType call_type, FunctionRef func, int value_index) { @@ -215,18 +215,18 @@ Expr CallNode::make(DataType dtype, node->call_type = call_type; node->func = std::move(func); node->value_index = value_index; - return Expr(node); + return PrimExpr(node); } -Expr ShuffleNode::make(Array vectors, - Array indices) { +PrimExpr ShuffleNode::make(Array vectors, + Array indices) { CHECK_NE(vectors.size(), 0U); CHECK_NE(indices.size(), 0U); DataType base_type = vectors[0].dtype().element_of(); int total_lanes = 0; - for (Expr val : vectors) { + for (PrimExpr val : vectors) { CHECK(val.dtype().element_of() == base_type); total_lanes += val.dtype().lanes(); } @@ -236,17 +236,17 @@ Expr ShuffleNode::make(Array vectors, node->dtype = base_type.with_lanes(static_cast(indices.size())); node->vectors = std::move(vectors); node->indices = std::move(indices); - return Expr(node); + return PrimExpr(node); } -Expr ShuffleNode::make_concat(Array vectors) { +PrimExpr ShuffleNode::make_concat(Array vectors) { CHECK_NE(vectors.size(), 0); if (vectors.size() == 1) { return vectors[0]; } - Array indices; + Array indices; int index = 0; - for (const Expr& e : vectors) { + for (const PrimExpr& e : vectors) { for (int i = 0; i < e.dtype().lanes(); ++i) { indices.push_back(IntImmNode::make(DataType::Int(32), index++)); } @@ -254,14 +254,14 @@ Expr ShuffleNode::make_concat(Array vectors) { return make(vectors, indices); } -Expr ShuffleNode::make_extract_element(Expr vector, int index) { +PrimExpr ShuffleNode::make_extract_element(PrimExpr vector, int index) { return make({vector}, {Integer(index)}); } CommReducer CommReducerNode::make(Array lhs, Array rhs, - Array result, - Array identity_element) { + Array result, + Array identity_element) { auto node = make_object(); node->lhs = lhs; node->rhs = rhs; @@ -270,22 +270,22 @@ CommReducer CommReducerNode::make(Array lhs, return CommReducer(node); } -Array CommReducerNode::operator()(Array a, Array b) const { +Array CommReducerNode::operator()(Array a, Array b) const { CHECK_EQ(a.size(), b.size()); CHECK_EQ(lhs.size(), a.size()); CHECK_EQ(rhs.size(), b.size()); - Map value_map; + Map value_map; for (size_t i = 0; i < a.size(); ++i) { value_map.Set(lhs[i], a[i]); value_map.Set(rhs[i], b[i]); } - return UpdateArray(result, [&value_map] (const Expr& e) { + return UpdateArray(result, [&value_map] (const PrimExpr& e) { return Substitute(e, value_map); }); } -Expr ReduceNode::make(CommReducer combiner, Array source, - Array axis, Expr condition, int value_index) { +PrimExpr ReduceNode::make(CommReducer combiner, Array source, + Array axis, PrimExpr condition, int value_index) { for (size_t i = 0; i < axis.size(); ++i) { CHECK_EQ(axis[i]->iter_type, kCommReduce) << "Can only take axis created by reduce_axis"; @@ -304,15 +304,15 @@ Expr ReduceNode::make(CommReducer combiner, Array source, n->axis = std::move(axis); n->condition = condition; n->value_index = value_index; - return Expr(n); + return PrimExpr(n); } -Expr AnyNode::make() { +PrimExpr AnyNode::make() { auto n = make_object(); - return Expr(n); + return PrimExpr(n); } -Stmt LetStmtNode::make(Var var, Expr value, Stmt body) { +Stmt LetStmtNode::make(Var var, PrimExpr value, Stmt body) { CHECK(value.defined()); CHECK(body.defined()); CHECK_EQ(value.dtype(), var.dtype()); @@ -326,7 +326,7 @@ Stmt LetStmtNode::make(Var var, Expr value, Stmt body) { Stmt AttrStmtNode::make(ObjectRef node, std::string attr_key, - Expr value, + PrimExpr value, Stmt body) { auto n = make_object(); n->node = node; @@ -336,7 +336,7 @@ Stmt AttrStmtNode::make(ObjectRef node, return Stmt(n); } -Stmt AssertStmtNode::make(Expr condition, Expr message, Stmt body) { +Stmt AssertStmtNode::make(PrimExpr condition, PrimExpr message, Stmt body) { CHECK(condition.defined()); CHECK(message.dtype() == DataType::Int(32) || message.as()) @@ -361,8 +361,8 @@ Stmt ProducerConsumerNode::make(FunctionRef func, bool is_producer, Stmt body) { } Stmt ForNode::make(Var loop_var, - Expr min, - Expr extent, + PrimExpr min, + PrimExpr extent, ForType for_type, DeviceAPI device_api, Stmt body) { @@ -383,7 +383,7 @@ Stmt ForNode::make(Var loop_var, return Stmt(node); } -Stmt StoreNode::make(Var buffer_var, Expr value, Expr index, Expr predicate) { +Stmt StoreNode::make(Var buffer_var, PrimExpr value, PrimExpr index, PrimExpr predicate) { CHECK(value.defined()); CHECK(index.defined()); CHECK(predicate.defined()); @@ -398,7 +398,7 @@ Stmt StoreNode::make(Var buffer_var, Expr value, Expr index, Expr predicate) { return Stmt(node); } -Stmt ProvideNode::make(FunctionRef func, int value_index, Expr value, Array args) { +Stmt ProvideNode::make(FunctionRef func, int value_index, PrimExpr value, Array args) { CHECK(value_index >=0 && value_index < func->num_outputs()) << "value index output function return value bound"; CHECK(value.defined()) << "Provide of undefined value\n"; @@ -417,10 +417,10 @@ Stmt ProvideNode::make(FunctionRef func, int value_index, Expr value, Array extents, - Expr condition, + Array extents, + PrimExpr condition, Stmt body, - Expr new_expr, + PrimExpr new_expr, std::string free_function) { for (size_t i = 0; i < extents.size(); ++i) { CHECK(extents[i].defined()); @@ -441,7 +441,7 @@ Stmt AllocateNode::make(Var buffer_var, return Stmt(node); } -int32_t AllocateNode::constant_allocation_size(const Array& extents) { +int32_t AllocateNode::constant_allocation_size(const Array& extents) { int64_t result = 1; for (size_t i = 0; i < extents.size(); ++i) { if (const IntImmNode *int_size = extents[i].as()) { @@ -466,7 +466,7 @@ Stmt RealizeNode::make(FunctionRef func, int value_index, DataType dtype, Region bounds, - Expr condition, + PrimExpr condition, Stmt body) { for (size_t i = 0; i < bounds.size(); ++i) { CHECK(bounds[i]->min.defined()); @@ -510,7 +510,7 @@ SeqStmt::SeqStmt(Array seq) { data_ = std::move(node); } -Stmt IfThenElseNode::make(Expr condition, Stmt then_case, Stmt else_case) { +Stmt IfThenElseNode::make(PrimExpr condition, Stmt then_case, Stmt else_case) { CHECK(condition.defined()); CHECK(then_case.defined()); // else_case may be null. @@ -522,7 +522,7 @@ Stmt IfThenElseNode::make(Expr condition, Stmt then_case, Stmt else_case) { return Stmt(node); } -Stmt EvaluateNode::make(Expr value) { +Stmt EvaluateNode::make(PrimExpr value) { CHECK(value.defined()); ObjectPtr node = make_object(); diff --git a/src/lang/tensor.cc b/src/lang/tensor.cc index f797700394e8..35b40295e7a9 100644 --- a/src/lang/tensor.cc +++ b/src/lang/tensor.cc @@ -28,12 +28,12 @@ namespace tvm { // Tensor -Expr Tensor::operator()(Array indices) const { - Array arr(indices.begin(), indices.end()); +PrimExpr Tensor::operator()(Array indices) const { + Array arr(indices.begin(), indices.end()); return operator()(arr); } -Expr Tensor::operator()(Array indices) const { +PrimExpr Tensor::operator()(Array indices) const { using ir::CallNode; if (ndim() != 0) { CHECK_EQ(ndim(), indices.size()) @@ -55,7 +55,7 @@ Tensor Operation::output(size_t i) const { return Tensor(node); } -Tensor TensorNode::make(Array shape, +Tensor TensorNode::make(Array shape, DataType dtype, Operation op, int value_index) { @@ -114,7 +114,7 @@ TensorIntrinCall TensorIntrinCallNode::make(TensorIntrin intrin, Array tensors, Array regions, Array reduce_axis, - Array scalar_inputs) { + Array scalar_inputs) { auto n = make_object(); n->intrin = std::move(intrin); n->tensors = std::move(tensors); diff --git a/src/op/compute_op.cc b/src/op/compute_op.cc index 0ad68b1cba40..7c8427a1f42f 100644 --- a/src/op/compute_op.cc +++ b/src/op/compute_op.cc @@ -75,10 +75,10 @@ DataType ComputeOpNode::output_dtype(size_t idx) const { return body[idx].dtype(); } -Array BaseComputeOpNode::output_shape(size_t idx) const { +Array BaseComputeOpNode::output_shape(size_t idx) const { CHECK_LT(idx, num_outputs()); // for now, all outputs of a BaseComputeOp have the same shape - Array shape; + Array shape; for (const auto& ivar : this->axis) { const Range& r = ivar->dom; shape.push_back(r->extent); @@ -86,7 +86,7 @@ Array BaseComputeOpNode::output_shape(size_t idx) const { return shape; } -Tensor compute(Array shape, +Tensor compute(Array shape, FCompute fcompute, std::string name, std::string tag, @@ -108,7 +108,7 @@ Tensor compute(Array shape, name, tag, attrs, axis, {fcompute(args)}).output(0); } -Array compute(Array shape, +Array compute(Array shape, FBatchCompute fcompute, std::string name, std::string tag, @@ -138,7 +138,7 @@ Operation ComputeOpNode::make(std::string name, std::string tag, Map attrs, Array axis, - Array body) { + Array body) { if (!attrs.defined()) { attrs = Map(); } @@ -180,24 +180,24 @@ Operation ComputeOpNode::ReplaceInputs( const std::unordered_map& rmap) const { CHECK_EQ(self.operator->(), this); VerifyComputeOp(this); - Array arr; + Array arr; if (this->body[0]->IsInstance()) { // Specially handle reduce so the replaced op // still share all the components - Expr new_reduce = op::ReplaceTensor(this->body[0], rmap); + PrimExpr new_reduce = op::ReplaceTensor(this->body[0], rmap); if (!new_reduce.same_as(this->body[0])) { const ir::ReduceNode* r = new_reduce.as(); for (size_t k = 0; k < this->body.size(); ++k) { auto n = make_object(*r); n->value_index = static_cast(k); n->dtype = r->source[k].dtype(); - arr.push_back(Expr(n)); + arr.push_back(PrimExpr(n)); } } else { arr = this->body; } } else { - arr = UpdateArray(this->body, [&rmap] (const Expr& e) { + arr = UpdateArray(this->body, [&rmap] (const PrimExpr& e) { return op::ReplaceTensor(e, rmap); }); } @@ -229,10 +229,10 @@ void ComputeOpNode::PropBoundToInputs( IntSet arg_intset = EvalSet(call->args[i], dom_map); const arith::IntervalSetNode* arg_interval = arg_intset.as(); if (arg_interval) { - Expr shape_i_min_value = make_zero(t->shape[i].dtype()); - Expr shape_i_max_value = t->shape[i] - 1; - Expr min_value = arg_interval->min_value; - Expr max_value = arg_interval->max_value; + PrimExpr shape_i_min_value = make_zero(t->shape[i].dtype()); + PrimExpr shape_i_max_value = t->shape[i] - 1; + PrimExpr min_value = arg_interval->min_value; + PrimExpr max_value = arg_interval->max_value; // Prefer the shape bounds only when we can prove they are tighter. if (arith::is_neg_inf(min_value) || analyzer->CanProve(shape_i_min_value >= min_value)) { @@ -290,7 +290,7 @@ Stmt BaseComputeOpNode::BuildRealize( if (it != stage->iter_var_attrs.end()) { IterVarAttr attr = (*it).second; if (attr->dim_align_factor != 0) { - Array tuple = {static_cast(i), + Array tuple = {static_cast(i), attr->dim_align_factor, attr->dim_align_offset}; realize = ir::AttrStmtNode::make( @@ -315,7 +315,7 @@ void MakeReduction(const ComputeOpNode* op, const Array& tensors, Stmt* init, Stmt* provide) { - Array args; + Array args; for (IterVar iv : op->axis) { args.push_back(iv->var); } @@ -326,12 +326,12 @@ void MakeReduction(const ComputeOpNode* op, CHECK(reduce); const CommReducerNode* combiner = reduce->combiner.as(); CHECK(combiner); - Array lhs; + Array lhs; for (size_t i = 0; i < size; ++i) { lhs.push_back(tensors[i](args)); } - Array init_value = combiner->identity_element; - Array update_value = (*combiner)(lhs, reduce->source); + Array init_value = combiner->identity_element; + Array update_value = (*combiner)(lhs, reduce->source); for (size_t i = 0; i < size; ++i) { Tensor t = tensors[i]; inits.emplace_back(ProvideNode::make( @@ -349,7 +349,7 @@ void MakeReduction(const ComputeOpNode* op, // Normal computation. Stmt MakeProvide(const ComputeOpNode* op, const Tensor& t) { - Array args; + Array args; for (IterVar iv : op->axis) { args.push_back(iv->var); } @@ -555,7 +555,7 @@ class ComputeVerifier final : protected ir::ExprVisitor { /// Interface to perform compute verification void Run() { - for (const Expr e : compute_->body) { + for (const PrimExpr e : compute_->body) { // Check for consistency of top level reductions const ir::ReduceNode* reduce = e.as(); CHECK((reduce && reduce_) || (!reduce && !reduce_)) @@ -576,7 +576,7 @@ class ComputeVerifier final : protected ir::ExprVisitor { protected: /// Visitor implementation //@{ - void VisitExpr(const Expr& n) final { + void VisitExpr(const PrimExpr& n) final { ++level_; ExprVisitor::VisitExpr(n); --level_; @@ -608,7 +608,7 @@ Stmt TransformUpdate(const Stage& stage, const ComputeLoopNest& n, Stmt body, Stmt update) { - Array conds; + Array conds; std::unordered_set banned; for (size_t i = 0; i < stage->leaf_iter_vars.size(); ++i) { IterVar iv = stage->leaf_iter_vars[i]; @@ -627,7 +627,7 @@ Stmt TransformUpdate(const Stage& stage, banned.insert(iv->var.get()); } } - for (const Expr& pred : n.main_predicates) { + for (const PrimExpr& pred : n.main_predicates) { if (ir::ExprUseVar(pred, banned)) { LOG(FATAL) << "Tensorize update transform failed, the condition " << pred << " has a conflict with the reset condition"; diff --git a/src/op/compute_op.h b/src/op/compute_op.h index f5735d887333..3fe98e869541 100644 --- a/src/op/compute_op.h +++ b/src/op/compute_op.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -38,17 +38,17 @@ struct ComputeLoopNest { // The common number of loops between init and main size_t num_common_loop; // predicates for the initialize loop - std::vector init_predicates; + std::vector init_predicates; // Initialization nest involved. std::vector > init_nest; // Value map for the init code - std::unordered_map init_vmap; + std::unordered_map init_vmap; // Predicates for the main update loop - std::vector main_predicates; + std::vector main_predicates; // The general loop nest std::vector > main_nest; // Value map for the IterVar. - std::unordered_map main_vmap; + std::unordered_map main_vmap; /*! * \brief constructor to build ComputeOpNest diff --git a/src/op/cross_thread_reduction.cc b/src/op/cross_thread_reduction.cc index 89d0ca756b71..9de4bde19b17 100644 --- a/src/op/cross_thread_reduction.cc +++ b/src/op/cross_thread_reduction.cc @@ -33,11 +33,11 @@ Stmt MakeCrossThreadReduction( const Stage& stage, const std::unordered_map& dom_map, bool debug_keep_trivial_loop) { - Array args; + Array args; for (IterVar iv : self->axis) { args.push_back(iv->var); } - std::unordered_map value_map; + std::unordered_map value_map; auto nest = op::MakeLoopNest( stage, dom_map, 0, false, std::unordered_set(), &value_map, debug_keep_trivial_loop); auto conds = schedule::MakeBoundCheck( @@ -52,11 +52,11 @@ Stmt MakeCrossThreadReduction( CHECK(reduce); reduces[i] = reduce; } - Expr cond = reduces[0]->condition; - for (Expr v : conds) { + PrimExpr cond = reduces[0]->condition; + for (PrimExpr v : conds) { cond = cond && v; } - Array freduce_args; + Array freduce_args; freduce_args.push_back(make_const(DataType::UInt(32), static_cast(size))); for (size_t i = 0; i < size; ++i) { freduce_args.push_back(reduces[0]->source[i]); @@ -79,7 +79,7 @@ Stmt MakeCrossThreadReduction( } } // Checks for the thread. - std::vector thread_head_check; + std::vector thread_head_check; if (stage->store_predicate.defined()) { thread_head_check.emplace_back(stage->store_predicate); } diff --git a/src/op/extern_op.cc b/src/op/extern_op.cc index ee958da25de6..6fc54a8bd506 100644 --- a/src/op/extern_op.cc +++ b/src/op/extern_op.cc @@ -50,7 +50,7 @@ DataType ExternOpNode::output_dtype(size_t i) const { return output_placeholders[i]->dtype; } -Array ExternOpNode::output_shape(size_t i) const { +Array ExternOpNode::output_shape(size_t i) const { return output_placeholders[i]->shape; } @@ -162,7 +162,7 @@ Stmt ExternOpNode::BuildProvide( Stmt ret = AttrStmtNode::make(make_zero(DataType::Int(32)), attr::extern_scope, 0, this->body); auto f_push_bind = [&ret](Buffer buffer, Tensor tensor) { Array bind_spec; - Array tuple; + Array tuple; bind_spec.push_back(buffer); bind_spec.push_back(tensor); for (size_t k = 0; k < buffer->shape.size(); ++k) { diff --git a/src/op/hybrid_op.cc b/src/op/hybrid_op.cc index 5364c38d0d81..c3be2346b55b 100644 --- a/src/op/hybrid_op.cc +++ b/src/op/hybrid_op.cc @@ -56,7 +56,7 @@ DataType HybridOpNode::output_dtype(size_t i) const { return outputs[i]->dtype; } -Array HybridOpNode::output_shape(size_t i) const { +Array HybridOpNode::output_shape(size_t i) const { return outputs[i]->shape; } @@ -222,7 +222,7 @@ namespace op { Stmt ApplyLoopShapes(const Stage &stage, const std::unordered_map &dom_map, Stmt stmt) { class LoopSpliter : public StmtExprMutator { - Expr factor; + PrimExpr factor; const VarNode *parent; IterVar inner, outer; @@ -249,14 +249,14 @@ Stmt ApplyLoopShapes(const Stage &stage, Stmt VisitStmt_(const ForNode *op) final { if (op->loop_var.get() == parent) { - std::unordered_map rmap; + std::unordered_map rmap; rmap[op->loop_var.get()] = inner + outer * factor; Stmt ret = ir::Substitute(op->body, rmap); - Expr cond = likely(outer * factor < (op->extent - inner)); + PrimExpr cond = likely(outer * factor < (op->extent - inner)); ret = IfThenElseNode::make(cond, ret); - ret = ForNode::make(inner->var, Expr(0), inner->dom->extent, + ret = ForNode::make(inner->var, PrimExpr(0), inner->dom->extent, IterVarTypeToForType(inner->iter_type), op->device_api, ret); - ret = ForNode::make(outer->var, Expr(0), outer->dom->extent, + ret = ForNode::make(outer->var, PrimExpr(0), outer->dom->extent, IterVarTypeToForType(outer->iter_type), op->device_api, ret); splitted = true; return ret; @@ -270,7 +270,7 @@ Stmt ApplyLoopShapes(const Stage &stage, const VarNode *inner; const VarNode *outer; bool under_outer; - Expr extent; + PrimExpr extent; public: bool fused; @@ -283,7 +283,7 @@ Stmt ApplyLoopShapes(const Stage &stage, Stmt VisitStmt_(const ForNode* op) final { if (op->loop_var.get() == inner) { CHECK(under_outer); - std::unordered_map rmap; + std::unordered_map rmap; rmap[op->loop_var.get()] = indexmod(parent, op->extent); extent = op->extent; fused = true; @@ -291,15 +291,15 @@ Stmt ApplyLoopShapes(const Stage &stage, } else if (op->loop_var.get() == outer) { under_outer = true; Stmt body = this->VisitStmt(op->body); - std::unordered_map rmap; + std::unordered_map rmap; rmap[op->loop_var.get()] = indexdiv(parent, extent); body = ir::Substitute(body, rmap); under_outer = false; - return ForNode::make(parent->var, Expr(0), extent * op->extent, + return ForNode::make(parent->var, PrimExpr(0), extent * op->extent, op->for_type, op->device_api, body); } else if (under_outer) { Stmt body = this->VisitStmt(op->body); - std::unordered_map rmap; + std::unordered_map rmap; rmap[op->loop_var.get()] = indexmod(indexdiv(parent, extent), op->extent); body = ir::Substitute(body, rmap); extent = extent * op->extent; @@ -342,7 +342,7 @@ Stmt ApplyLoopAnnotations(const Stage &stage, CHECK(Equal(iter_var->dom->extent, op->extent)) << "Thread extent and loop extent mismatch!\n"; } - std::unordered_map rmap; + std::unordered_map rmap; rmap[op->loop_var.get()] = iter_var; Stmt body = ir::Substitute(op->body, rmap); return AttrStmtNode::make(iter_var, "thread_extent", op->extent, body); diff --git a/src/op/op_util.cc b/src/op/op_util.cc index 31d736d98306..e108ad312b0a 100644 --- a/src/op/op_util.cc +++ b/src/op/op_util.cc @@ -42,14 +42,14 @@ MakeLoopNest(const Stage& stage, size_t begin_iter_pos, bool new_loop_var, const std::unordered_set& skip_iter, - std::unordered_map* p_value_map, + std::unordered_map* p_value_map, bool debug_keep_trivial_loop) { auto leaf_iter_vars = stage->leaf_iter_vars; Stmt no_op = EvaluateNode::make(0); // create the loop nest std::vector > nest; nest.resize(leaf_iter_vars.size() + 1); - std::unordered_map& value_map = *p_value_map; + std::unordered_map& value_map = *p_value_map; for (size_t i = begin_iter_pos; i < leaf_iter_vars.size(); ++i) { auto iv = leaf_iter_vars[i]; @@ -96,7 +96,7 @@ MakeLoopNest(const Stage& stage, CHECK_EQ(it_attr->pragma_keys.size(), it_attr->pragma_values.size()); for (size_t k = 0; k < it_attr->pragma_keys.size(); ++k) { const std::string& pkey = it_attr->pragma_keys[k].as()->value; - Expr pvalue = it_attr->pragma_values[k]; + PrimExpr pvalue = it_attr->pragma_values[k]; if (!pvalue.defined()) { pvalue = make_const(DataType::Int(32), 1); } @@ -118,7 +118,7 @@ MakeLoopNest(const Stage& stage, nest[i + 1].emplace_back( ForNode::make(idx, 0, dom->extent, for_type, DeviceAPI::None, no_op)); - Expr new_value = dom->min + idx; + PrimExpr new_value = dom->min + idx; value_map[iv] = new_value; nest[i + 1].emplace_back( LetStmtNode::make(var, new_value, no_op)); @@ -176,10 +176,10 @@ MakeLoopNest(const Stage& stage, return nest; } -std::vector MakeIfNest(const std::vector& predicates) { +std::vector MakeIfNest(const std::vector& predicates) { Stmt no_op = EvaluateNode::make(0); std::vector nest; - for (const Expr& cond : predicates) { + for (const PrimExpr& cond : predicates) { nest.emplace_back(IfThenElseNode::make(cond, no_op)); } return nest; @@ -191,12 +191,12 @@ class TensorReplacer : public ir::StmtExprMutator { explicit TensorReplacer(const std::unordered_map& vmap) : vmap_(vmap) {} - Expr VisitExpr_(const ir::CallNode* op) final { + PrimExpr VisitExpr_(const ir::CallNode* op) final { if (op->call_type == ir::CallNode::Halide) { Tensor t = Downcast(op->func).output(op->value_index); auto it = vmap_.find(t); if (it != vmap_.end()) { - Expr ret = ir::CallNode::make( + PrimExpr ret = ir::CallNode::make( op->dtype, it->second->op->name, op->args, op->call_type, it->second->op, it->second->value_index); found = true; @@ -219,17 +219,17 @@ Stmt ReplaceTensor(Stmt stmt, Stmt ret = repl(stmt); return repl.found ? ret : stmt; } -Expr ReplaceTensor(Expr expr, +PrimExpr ReplaceTensor(PrimExpr expr, const std::unordered_map& replace) { TensorReplacer repl(replace); - Expr ret = repl(expr); + PrimExpr ret = repl(expr); return repl.found ? ret : expr; } Stmt Substitute(Stmt s, - const std::unordered_map& value_map) { - std::unordered_map init; + const std::unordered_map& value_map) { + std::unordered_map init; for (const auto& kv : value_map) { init[kv.first->var.get()] = kv.second; } diff --git a/src/op/op_util.h b/src/op/op_util.h index b57000f66e6b..cea050b2695d 100644 --- a/src/op/op_util.h +++ b/src/op/op_util.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -55,7 +55,7 @@ MakeLoopNest(const Stage& stage, size_t begin_iter_pos, bool new_loop_var, const std::unordered_set& skip_iter, - std::unordered_map* p_value_map, + std::unordered_map* p_value_map, bool debug_keep_trivial_loop); /*! @@ -64,7 +64,7 @@ MakeLoopNest(const Stage& stage, * \param predicates The predicates to be checked. * \return List of If nest that checks the predicates. */ -std::vector MakeIfNest(const std::vector& predicates); +std::vector MakeIfNest(const std::vector& predicates); /*! * \brief Replace the tensor reference (especially in Call's) in stmt by the replace map. @@ -78,7 +78,7 @@ Stmt ReplaceTensor(Stmt stmt, * \param expr The expression to be processed. * \param replace The replacement rule. */ -Expr ReplaceTensor(Expr expr, +PrimExpr ReplaceTensor(PrimExpr expr, const std::unordered_map& replace); /*! @@ -88,7 +88,7 @@ Expr ReplaceTensor(Expr expr, * \return Substituted result. */ Stmt Substitute(Stmt stmt, - const std::unordered_map& value_map); + const std::unordered_map& value_map); /*! * \brief Converts Halide ForType to its corresponding IterVarType diff --git a/src/op/placeholder_op.cc b/src/op/placeholder_op.cc index 2ec10caf07a9..22e0ad4de34e 100644 --- a/src/op/placeholder_op.cc +++ b/src/op/placeholder_op.cc @@ -47,13 +47,13 @@ DataType PlaceholderOpNode::output_dtype(size_t i) const { return dtype; } -Array PlaceholderOpNode::output_shape(size_t i) const { +Array PlaceholderOpNode::output_shape(size_t i) const { CHECK_EQ(i, 0U); return shape; } Operation PlaceholderOpNode::make(std::string name, - Array shape, + Array shape, DataType dtype) { auto n = make_object(); n->name = name; @@ -62,7 +62,7 @@ Operation PlaceholderOpNode::make(std::string name, return Operation(n); } -Tensor placeholder(Array shape, DataType dtype, std::string name) { +Tensor placeholder(Array shape, DataType dtype, std::string name) { return PlaceholderOpNode::make(name, shape, dtype).output(0); } diff --git a/src/op/scan_op.cc b/src/op/scan_op.cc index c4c0960c5fae..f7b16f28acbf 100644 --- a/src/op/scan_op.cc +++ b/src/op/scan_op.cc @@ -38,7 +38,7 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) }); TVM_REGISTER_NODE_TYPE(ScanOpNode); -inline bool prove_equal(Expr lhs, Expr rhs) { +inline bool prove_equal(PrimExpr lhs, PrimExpr rhs) { return is_zero(ir::Simplify(lhs - rhs)); } @@ -57,7 +57,7 @@ DataType ScanOpNode::output_dtype(size_t i) const { return update[i]->dtype; } -Array ScanOpNode::output_shape(size_t i) const { +Array ScanOpNode::output_shape(size_t i) const { CHECK_LT(i, state_placeholder.size()); return state_placeholder[i]->shape; } @@ -232,7 +232,7 @@ void ScanOpNode::GatherBound( Range r = arith::Union(time_dom).cover_range(sdom); (*out_dom_map)[this->scan_axis] = Range::make_by_min_extent( sdom->min, ir::Simplify(r->extent + r->min - sdom->min)); - Map fix_pt = ScanFixPointAnalysis(self); + Map fix_pt = ScanFixPointAnalysis(self); // Update for spatial axis. size_t sp_idx = 0; for (size_t i = 0; i < output.size(); ++i) { @@ -295,7 +295,7 @@ Stmt ScanOpNode::BuildProvide( begin_scan = i + 1; } } - std::unordered_map vmap; + std::unordered_map vmap; std::unordered_set empty; auto nest = op::MakeLoopNest( stage, dom_map, 0, false, empty, &vmap, debug_keep_trivial_loop); diff --git a/src/op/tensor_compute_op.cc b/src/op/tensor_compute_op.cc index e0656ea0462d..08df8b7f64ed 100644 --- a/src/op/tensor_compute_op.cc +++ b/src/op/tensor_compute_op.cc @@ -57,7 +57,7 @@ Operation TensorComputeOpNode::make(std::string name, TensorIntrin intrin, Array tensors, Array regions, - Array scalar_inputs) { + Array scalar_inputs) { auto n = make_object(); n->name = std::move(name); n->tag = std::move(tag); @@ -147,7 +147,7 @@ Stmt TensorComputeOpNode::BuildProvide( Buffer buffer = this->intrin->buffers[i]; Array bind_spec{buffer, tensor}; - Array tuple; + Array tuple; for (size_t i = 0; i < region.size(); ++i) { tuple.push_back(region[i]->min); tuple.push_back(region[i]->extent); @@ -165,7 +165,7 @@ Stmt TensorComputeOpNode::BuildProvide( Buffer buffer = this->intrin->buffers[num_inputs + i]; Array bind_spec{buffer, tensor}; - Array tuple; + Array tuple; for (size_t i = 0; i < this->axis.size(); ++i) { auto ivar = this->axis[i]; if (i < static_cast(this->schedulable_ndim)) { @@ -186,16 +186,16 @@ Stmt TensorComputeOpNode::BuildProvide( } // Check variable remap - std::unordered_map vmap; + std::unordered_map vmap; ir::ArgBinder binder(&vmap); // Map the expressions passed in the call to the TensorIntrin, to the placeholder // variables - Array user_expr = this->scalar_inputs; + Array user_expr = this->scalar_inputs; Array scalar_params = this->intrin->scalar_params; - Array sp_expr; + Array sp_expr; for (auto sp : scalar_params) { - Expr esp = sp; + PrimExpr esp = sp; sp_expr.push_back(esp); } CHECK_EQ(sp_expr.size(), user_expr.size()); diff --git a/src/op/tensorize.cc b/src/op/tensorize.cc index 601c4446ac87..4460d905a54d 100644 --- a/src/op/tensorize.cc +++ b/src/op/tensorize.cc @@ -144,13 +144,13 @@ void VerifyTensorizeLoopNest(const ComputeOpNode* self, } } } - for (const Expr& pred : n.main_predicates) { + for (const PrimExpr& pred : n.main_predicates) { if (ir::ExprUseVar(pred, banned)) { LOG(FATAL) << "Tensorize failed, split condition " << pred << " relies on var defined inside tensorize scope"; } } - for (const Expr& pred : n.init_predicates) { + for (const PrimExpr& pred : n.init_predicates) { if (ir::ExprUseVar(pred, banned)) { LOG(FATAL) << "Tensorize failed, split condition " << pred << " relies on var defined inside tensorize scope"; @@ -161,8 +161,8 @@ void VerifyTensorizeLoopNest(const ComputeOpNode* self, // Remap the tensor placeholder, index and inline things. class TensorIntrinMatcher final : public StmtExprMutator { public: - Expr VisitExpr_(const CallNode* op) final { - Expr expr = StmtExprMutator::VisitExpr_(op); + PrimExpr VisitExpr_(const CallNode* op) final { + PrimExpr expr = StmtExprMutator::VisitExpr_(op); op = expr.as(); if (op->call_type == CallNode::Halide) { Tensor t = Downcast(op->func).output(op->value_index); @@ -170,7 +170,7 @@ class TensorIntrinMatcher final : public StmtExprMutator { if (it != in_remap_.end()) { const InputEntry& e = it->second; CHECK_EQ(op->args.size(), e.region.size()); - Array args; + Array args; for (size_t i = e.start; i < e.region.size(); ++i) { args.push_back(op->args[i] - e.region[i]->min); } @@ -182,17 +182,17 @@ class TensorIntrinMatcher final : public StmtExprMutator { return expr; } - Expr VisitExpr_(const VarNode* op) final { + PrimExpr VisitExpr_(const VarNode* op) final { auto it = var_remap_.find(op); if (it != var_remap_.end()) { return it->second; } else { - return GetRef(op); + return GetRef(op); } } - Expr VisitExpr_(const ReduceNode* op) final { - Expr expr = StmtExprMutator::VisitExpr_(op); + PrimExpr VisitExpr_(const ReduceNode* op) final { + PrimExpr expr = StmtExprMutator::VisitExpr_(op); op = expr.as(); Array axis; for (size_t i = 0; i < op->axis.size(); ++i) { @@ -301,13 +301,13 @@ class TensorIntrinMatcher final : public StmtExprMutator { // input data remap std::unordered_map in_remap_; // variable remap. - std::unordered_map var_remap_; + std::unordered_map var_remap_; // IterVar remap. std::unordered_map axis_remap_; }; // Try to match tensor dataflow of the stage with the intrinsic -Array MatchTensorizeBody( +Array MatchTensorizeBody( const ComputeOpNode* self, const Stage& stage, const std::unordered_map& dom_map, @@ -317,8 +317,8 @@ Array MatchTensorizeBody( Map* compute_intrin_iter_space) { TensorIntrinMatcher matcher; matcher.Init(self, stage, dom_map, out_dom, in_region, intrin, compute_intrin_iter_space); - Array ret; - for (Expr expr : self->body) { + Array ret; + for (PrimExpr expr : self->body) { ret.push_back(matcher(expr)); } return ret; @@ -332,16 +332,16 @@ void VerifyTensorizeBody( const std::unordered_map >& in_region, const TensorIntrin& intrin) { Map compute_intrin_iter_space; - Array body = MatchTensorizeBody(self, stage, dom_map, out_dom, in_region, intrin, + Array body = MatchTensorizeBody(self, stage, dom_map, out_dom, in_region, intrin, &compute_intrin_iter_space); const ComputeOpNode* intrin_compute = intrin->op.as(); CHECK(intrin_compute) << "Only support compute intrinsic for now"; CHECK_EQ(body.size(), intrin_compute->body.size()) << "Tensorize failed: body size mismatch"; for (size_t i = 0; i < body.size(); ++i) { - Expr lhs = Simplify(body[i], compute_intrin_iter_space); + PrimExpr lhs = Simplify(body[i], compute_intrin_iter_space); lhs = CanonicalSimplify(lhs, compute_intrin_iter_space); - Expr rhs = Simplify(intrin_compute->body[i], compute_intrin_iter_space); + PrimExpr rhs = Simplify(intrin_compute->body[i], compute_intrin_iter_space); rhs = CanonicalSimplify(rhs, compute_intrin_iter_space); if (lhs.dtype() != rhs.dtype()) { LOG(FATAL) @@ -385,7 +385,7 @@ Stmt MakeTensorize(const ComputeOpNode* self, auto it = in_region.find(tensor); CHECK(it != in_region.end()); const Array& region = it->second; - Array tuple; + Array tuple; for (const Range r : region) { tuple.push_back(r->min); tuple.push_back(r->extent); @@ -401,7 +401,7 @@ Stmt MakeTensorize(const ComputeOpNode* self, CHECK(intrin_compute) << "Only support compute intrinsic for now"; CHECK_EQ(intrin->inputs.size() + intrin_compute->body.size(), intrin->buffers.size()); CHECK_EQ(intrin_compute->body.size(), self->body.size()); - Array tuple; + Array tuple; for (IterVar iv : self->axis) { auto it = out_dom.find(iv); CHECK(it != out_dom.end()); @@ -419,7 +419,7 @@ Stmt MakeTensorize(const ComputeOpNode* self, tuple, CallNode::Intrinsic), nop)); } // Check variable remap - std::unordered_map vmap; + std::unordered_map vmap; ir::ArgBinder binder(&vmap); CHECK_GE(self->reduce_axis.size(), intrin_compute->reduce_axis.size()) << "Tensorization fail: reduction axis size do not match"; diff --git a/src/pass/arg_binder.cc b/src/pass/arg_binder.cc index 340f3a83608e..2c04de3710fa 100644 --- a/src/pass/arg_binder.cc +++ b/src/pass/arg_binder.cc @@ -31,10 +31,10 @@ namespace tvm { namespace ir { -void BinderAddAssert(Expr cond, +void BinderAddAssert(PrimExpr cond, const std::string& arg_name, std::vector* asserts) { - Expr scond = Simplify(cond); + PrimExpr scond = Simplify(cond); if (is_zero(scond)) { LOG(FATAL) << "Bind have an unmet assertion: " << cond << ", " << " on argument " << arg_name; @@ -46,8 +46,8 @@ void BinderAddAssert(Expr cond, } } -bool ArgBinder::Bind_(const Expr& arg, - const Expr& value, +bool ArgBinder::Bind_(const PrimExpr& arg, + const PrimExpr& value, const std::string& arg_name, bool with_lets) { CHECK_EQ(arg.dtype(), value.dtype()); @@ -72,15 +72,15 @@ bool ArgBinder::Bind_(const Expr& arg, return false; } -void ArgBinder::Bind(const Expr& arg, - const Expr& value, +void ArgBinder::Bind(const PrimExpr& arg, + const PrimExpr& value, const std::string& arg_name, bool with_let) { Bind_(arg, value, arg_name, with_let); } -void ArgBinder::BindArray(const Array& arg, - const Array& value, +void ArgBinder::BindArray(const Array& arg, + const Array& value, const std::string& arg_name) { CHECK_EQ(arg.size(), value.size()) << "Argument " << arg_name << " array size mismatch"; @@ -117,9 +117,9 @@ void ArgBinder::BindBuffer(const Buffer& arg, this->Bind(arg->data, value->data, arg_name + ".data"); if (Bind_(arg->elem_offset, value->elem_offset, arg_name + ".elem_offset", false)) { if (arg->offset_factor > 1) { - Expr offset = value->elem_offset; - Expr factor = make_const(offset.dtype(), arg->offset_factor); - Expr zero = make_zero(offset.dtype()); + PrimExpr offset = value->elem_offset; + PrimExpr factor = make_const(offset.dtype(), arg->offset_factor); + PrimExpr zero = make_zero(offset.dtype()); BinderAddAssert(truncmod(offset, factor) == zero, arg_name + ".elem_offset", &asserts_); } @@ -153,21 +153,21 @@ void ArgBinder::BindBuffer(const Buffer& arg, } } -inline Expr TVMArrayGet(DataType t, Var arr, intrinsic::TVMStructFieldKind kind) { +inline PrimExpr TVMArrayGet(DataType t, Var arr, intrinsic::TVMStructFieldKind kind) { return TVMStructGet(t, arr, 0, kind); } void ArgBinder::BindDLTensor(const Buffer& buffer, - const Expr& device_type, - const Expr& device_id, + const PrimExpr& device_type, + const PrimExpr& device_id, const Var& handle, const std::string& arg_name) { const DataType tvm_shape_type = DataType::ShapeIndex(); const DataType tvm_ndim_type = DataType::Int(32); const Stmt nop = EvaluateNode::make(0); // dimension checks - Expr v_ndim = TVMArrayGet(tvm_ndim_type, handle, intrinsic::kArrNDim); - Expr a_ndim = make_const(tvm_ndim_type, + PrimExpr v_ndim = TVMArrayGet(tvm_ndim_type, handle, intrinsic::kArrNDim); + PrimExpr a_ndim = make_const(tvm_ndim_type, static_cast(buffer->shape.size())); std::ostringstream ndim_err_msg; ndim_err_msg << arg_name @@ -178,7 +178,7 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, DataType dtype = buffer->dtype; std::ostringstream type_err_msg; type_err_msg << arg_name << ".dtype is expected to be " << dtype; - Expr cond = (TVMArrayGet(DataType::UInt(8), handle, intrinsic::kArrTypeCode) == + PrimExpr cond = (TVMArrayGet(DataType::UInt(8), handle, intrinsic::kArrTypeCode) == UIntImmNode::make(DataType::UInt(8), dtype.code()) && TVMArrayGet(DataType::UInt(8), handle, intrinsic::kArrTypeBits) == UIntImmNode::make(DataType::UInt(8), dtype.bits()) && @@ -215,17 +215,17 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, init_nest_.emplace_back(LetStmtNode::make( v_strides, TVMArrayGet(DataType::Handle(), handle, intrinsic::kArrStrides), nop)); - Expr is_null = CallNode::make( + PrimExpr is_null = CallNode::make( DataType::Bool(1), intrinsic::tvm_handle_is_null, {v_strides}, CallNode::PureIntrinsic); if (buffer->strides.size() == 0) { // Assert the buffer is compact DataType stype = buffer->DefaultIndexType(); - Expr expect_stride = make_const(stype, 1); - Array conds; + PrimExpr expect_stride = make_const(stype, 1); + Array conds; for (size_t i = buffer->shape.size(); i != 0; --i) { size_t k = i - 1; - Expr svalue = cast( + PrimExpr svalue = cast( stype, LoadNode::make(tvm_shape_type, v_strides, IntImmNode::make(DataType::Int(32), k), const_true(1))); @@ -237,19 +237,19 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, << " expected to be compact array"; if (conds.size() != 0) { Stmt check = - AssertStmtNode::make(arith::ComputeReduce(conds, Expr()), + AssertStmtNode::make(arith::ComputeReduce(conds, PrimExpr()), stride_err_msg.str(), EvaluateNode::make(0)); check = IfThenElseNode::make(NotNode::make(is_null), check, Stmt()); asserts_.emplace_back(SeqStmt({check, EvaluateNode::make(0)})); } } else if (buffer->buffer_type == kAutoBroadcast) { DataType stype = buffer->DefaultIndexType(); - Expr stride = make_const(stype, 1); + PrimExpr stride = make_const(stype, 1); for (size_t i = buffer->shape.size(); i != 0; --i) { size_t k = i - 1; std::ostringstream field_name; field_name << v_strides->name_hint << '[' << k << ']'; - Expr value = cast(buffer->shape[k].dtype(), + PrimExpr value = cast(buffer->shape[k].dtype(), LoadNode::make(tvm_shape_type, v_strides, IntImmNode::make(DataType::Int(32), k), const_true(1))); value = tvm::if_then_else(is_null, stride, value); @@ -288,9 +288,9 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, make_const(DataType::UInt(64), data_bytes))), arg_name + ".elem_offset", true)) { if (buffer->offset_factor > 1) { - Expr offset = buffer->elem_offset; - Expr factor = make_const(offset.dtype(), buffer->offset_factor); - Expr zero = make_zero(offset.dtype()); + PrimExpr offset = buffer->elem_offset; + PrimExpr factor = make_const(offset.dtype(), buffer->offset_factor); + PrimExpr zero = make_zero(offset.dtype()); BinderAddAssert(truncmod(offset, factor) == zero, arg_name + ".elem_offset", &asserts_); } } diff --git a/src/pass/arg_binder.h b/src/pass/arg_binder.h index 55d8c22edfe7..75006d68dfba 100644 --- a/src/pass/arg_binder.h +++ b/src/pass/arg_binder.h @@ -62,7 +62,7 @@ class ArgBinder { * ArgBinder will update this def_map when adding new definitions. */ explicit ArgBinder( - std::unordered_map* def_map) + std::unordered_map* def_map) : def_map_(def_map) { } /*! @@ -72,8 +72,8 @@ class ArgBinder { * \param arg_name argument name. * \param with_let Whether add lets during bind */ - void Bind(const Expr& arg, - const Expr& value, + void Bind(const PrimExpr& arg, + const PrimExpr& value, const std::string& arg_name, bool with_let = false); /*! @@ -82,8 +82,8 @@ class ArgBinder { * \param value The target expression value * \param arg_name argument name. */ - void BindArray(const Array& arg, - const Array& value, + void BindArray(const Array& arg, + const Array& value, const std::string& arg_name); /*! * \brief Bind symbolic buffer to another symbolic buffer @@ -105,8 +105,8 @@ class ArgBinder { * \param arg_name argument name. */ void BindDLTensor(const Buffer& buffer, - const Expr& device_type, - const Expr& device_id, + const PrimExpr& device_type, + const PrimExpr& device_id, const Var& handle, const std::string& arg_name); @@ -133,24 +133,24 @@ class ArgBinder { return init_nest_; } /*! \return Handle data type of the data */ - const Map& def_handle_dtype() const { + const Map& def_handle_dtype() const { return def_handle_dtype_; } private: // Internal bind function - bool Bind_(const Expr& arg, - const Expr& value, + bool Bind_(const PrimExpr& arg, + const PrimExpr& value, const std::string& arg_name, bool with_lets); /*! \brief The definition map, can be uses to substitute */ - std::unordered_map* def_map_; + std::unordered_map* def_map_; /*! \brief defs generated in the current binder */ std::vector defs_; /*! \brief Initialize nest */ std::vector init_nest_; /*! \brief handle data type in the defintiions */ - Map def_handle_dtype_; + Map def_handle_dtype_; /*! \brief asserts generated */ std::vector asserts_; }; diff --git a/src/pass/bound_checker.cc b/src/pass/bound_checker.cc index 84939fc12426..439c8862c9c6 100644 --- a/src/pass/bound_checker.cc +++ b/src/pass/bound_checker.cc @@ -45,13 +45,13 @@ class BoundCollector : public StmtVisitor { StmtVisitor::VisitStmt_(op); } // Hashtable which maps buffer_var to shape. - std::unordered_map mem_to_shape; + std::unordered_map mem_to_shape; }; class BoundChecker : public StmtExprMutator { public: explicit BoundChecker( - const std::unordered_map &mem_to_shape) + const std::unordered_map &mem_to_shape) : mem_to_shape_(mem_to_shape) {} Stmt VisitStmt_(const AllocateNode* op) final { @@ -62,7 +62,7 @@ class BoundChecker : public StmtExprMutator { return StmtExprMutator::VisitStmt_(op); } - Expr VisitExpr_(const CallNode* op) final { + PrimExpr VisitExpr_(const CallNode* op) final { if (process_store_ && op->is_intrinsic(intrinsic::tvm_if_then_else)) { unsafe_rewritten_ = true; } @@ -80,7 +80,7 @@ class BoundChecker : public StmtExprMutator { } // The collector should has at least one item. if (store_scope_bound_collector_.size()) { - Expr condition = MakeCondition(); + PrimExpr condition = MakeCondition(); if (!condition.as()) { Stmt nop = EvaluateNode::make(1); Stmt then_case = @@ -94,7 +94,7 @@ class BoundChecker : public StmtExprMutator { return GetRef(op); } - Expr VisitExpr_(const LoadNode* op) final { + PrimExpr VisitExpr_(const LoadNode* op) final { if (CanInstrument(op->index, op->buffer_var)) { Collect(op->index, op->buffer_var); } @@ -102,12 +102,12 @@ class BoundChecker : public StmtExprMutator { } private: - bool UpdateIsNeeded(const VarExpr& buffer_var) const { + bool UpdateIsNeeded(const Var& buffer_var) const { return (buffer_var.defined() && mem_to_shape_.count(buffer_var.get())); } - void Update(const VarExpr& buffer_var, - const Array& new_shape, + void Update(const Var& buffer_var, + const Array& new_shape, const DataType& type) { // Sanity check at first. if (!new_shape.size()) { @@ -122,7 +122,7 @@ class BoundChecker : public StmtExprMutator { } // Scalarize the shape. - Expr shape = MulNode::make(make_const(DataType::UInt(64), type.lanes()), + PrimExpr shape = MulNode::make(make_const(DataType::UInt(64), type.lanes()), CastNode::make(DataType::UInt(64), new_shape[0])); for (size_t i = 1; i < new_shape.size(); ++i) { // Cast to unsigned to avoid integer overlow at frist. @@ -132,7 +132,7 @@ class BoundChecker : public StmtExprMutator { mem_to_shape_[buffer_var.get()] = shape; } - bool IndexIsValid(const Expr& index) const { + bool IndexIsValid(const PrimExpr& index) const { if (!index.defined()) { return false; } @@ -146,22 +146,22 @@ class BoundChecker : public StmtExprMutator { return true; } - bool CanInstrument(const Expr& index, const VarExpr& buffer_var) const { + bool CanInstrument(const PrimExpr& index, const Var& buffer_var) const { return buffer_var.defined() && mem_to_shape_.count(buffer_var.get()) && IndexIsValid(index) && !unsafe_rewritten_; } - void Collect(Expr index, VarExpr buffer_var) { + void Collect(PrimExpr index, Var buffer_var) { store_scope_bound_collector_.push_back( std::make_pair(index, mem_to_shape_[buffer_var.get()])); } - Expr MakeCondition() { - Expr condition; + PrimExpr MakeCondition() { + PrimExpr condition; for (size_t i = 0; i < store_scope_bound_collector_.size(); ++i) { - std::pair buffer_to_mem = store_scope_bound_collector_[i]; - Expr index = buffer_to_mem.first; - Expr upper_bound = buffer_to_mem.second; + std::pair buffer_to_mem = store_scope_bound_collector_[i]; + PrimExpr index = buffer_to_mem.first; + PrimExpr upper_bound = buffer_to_mem.second; if (const RampNode *ramp_index = index.as()) { // In case index is base + stride * i. @@ -181,9 +181,9 @@ class BoundChecker : public StmtExprMutator { upper_bound = CastNode::make(DataType::Int(64), upper_bound); // Looks like a lower bound should always be zero after normalization. - Expr lower_bound = make_zero(DataType::Int(64)); + PrimExpr lower_bound = make_zero(DataType::Int(64)); - Expr current_condition = + PrimExpr current_condition = AndNode::make(GENode::make(index, lower_bound), LTNode::make(index, upper_bound)); condition = !i ? current_condition : AndNode::make(condition, current_condition); @@ -196,11 +196,11 @@ class BoundChecker : public StmtExprMutator { // Whether we face tvm_if_then_else intrinsic. bool unsafe_rewritten_{false}; // Pool which collects the pair of index and shape for specific store/load. - std::vector> store_scope_bound_collector_; + std::vector> store_scope_bound_collector_; // Error message. const char *const error_message_ = "OUT OF THE BOUNDS"; // Hashtable which maps buffer_var to shape. - std::unordered_map mem_to_shape_; + std::unordered_map mem_to_shape_; }; Stmt InstrumentBoundCheckers(Stmt stmt) { diff --git a/src/pass/combine_context_call.cc b/src/pass/combine_context_call.cc index 62ceedeca7d8..4561dba5469c 100644 --- a/src/pass/combine_context_call.cc +++ b/src/pass/combine_context_call.cc @@ -35,15 +35,15 @@ namespace ir { class ContextCallCombiner final : public StmtExprMutator { public: struct CompareExpr { - bool operator()(const Expr& lhs, const Expr& rhs) const { + bool operator()(const PrimExpr& lhs, const PrimExpr& rhs) const { return Compare(lhs, rhs) < 0; } }; - Expr VisitExpr_(const CallNode* op) final { + PrimExpr VisitExpr_(const CallNode* op) final { if (op->is_intrinsic(intrinsic::tvm_thread_context)) { CHECK_EQ(op->args.size(), 1U); - Expr ctx = op->args[0]; + PrimExpr ctx = op->args[0]; auto it = ctx_map_.find(ctx); if (it != ctx_map_.end()) { return it->second; @@ -68,7 +68,7 @@ class ContextCallCombiner final : public StmtExprMutator { if (op->attr_key == attr::thread_extent || op->attr_key == attr::coproc_uop_scope) { // Map of comparison expression to variable - std::map temp; + std::map temp; std::swap(temp, ctx_map_); Stmt stmt = StmtExprMutator::VisitStmt_(op); std::swap(temp, ctx_map_); @@ -81,7 +81,7 @@ class ContextCallCombiner final : public StmtExprMutator { Stmt VisitStmt_(const ForNode* op) final { if (op->for_type == ForType::Parallel) { // Map of comparison expression to variable - std::map temp; + std::map temp; std::swap(temp, ctx_map_); Stmt stmt = StmtExprMutator::VisitStmt_(op); std::swap(temp, ctx_map_); @@ -96,7 +96,7 @@ class ContextCallCombiner final : public StmtExprMutator { } private: - static Stmt BuildContext(const std::map& cmap, + static Stmt BuildContext(const std::map& cmap, Stmt body) { for (const auto& kv : cmap) { body = LetStmtNode::make(kv.second, kv.first, body); @@ -104,7 +104,7 @@ class ContextCallCombiner final : public StmtExprMutator { return body; } // Map of comparison expression to variable - std::map ctx_map_; + std::map ctx_map_; }; LoweredFunc CombineContextCall(LoweredFunc f) { diff --git a/src/pass/coproc_sync.cc b/src/pass/coproc_sync.cc index a7afd46ffd0d..4e68793cc875 100644 --- a/src/pass/coproc_sync.cc +++ b/src/pass/coproc_sync.cc @@ -341,8 +341,8 @@ class CoProcBarrierDetector : public StorageAccessVisitor { Range r = arith::Union(wset).cover_range(none); CHECK(r.defined()) << "Cannot deduce write range of " << wvec[0].buffer; - Expr min = r->min; - Expr extent = r->extent; + PrimExpr min = r->min; + PrimExpr extent = r->extent; return EvaluateNode::make(CallNode::make( DataType::Int(32), func, {wvec[0].buffer, wvec[0].dtype.bits(), r->min, r->extent}, CallNode::Intrinsic)); diff --git a/src/pass/hoist_if_then_else.cc b/src/pass/hoist_if_then_else.cc index 789877f850fb..aff7d3de9eb2 100644 --- a/src/pass/hoist_if_then_else.cc +++ b/src/pass/hoist_if_then_else.cc @@ -161,7 +161,7 @@ Stmt update_for(const Stmt& parent_for_stmt, const Stmt& new_if_stmt) { }); return IRTransform(parent_for_stmt, nullptr, replace_target_for, - {Expr("For")}); + {PrimExpr("For")}); } // Remove IfThenElse node from a For node. @@ -188,10 +188,10 @@ std::pair RemoveIf(const Stmt& for_stmt, const Stmt& if_stmt) { }); then_for = IRTransform(for_stmt, nullptr, replace_then_case, - {Expr("IfThenElse")}); + {PrimExpr("IfThenElse")}); if (if_stmt.as()->else_case) { else_for = IRTransform(for_stmt, nullptr, replace_else_case, - {Expr("IfThenElse")}); + {PrimExpr("IfThenElse")}); } return std::make_pair(then_for, else_for); @@ -412,7 +412,7 @@ Stmt IfThenElseHoist::PostOrderMutate(const Stmt& stmt) { *ret = new_for; } }); - return IRTransform(stmt, nullptr, replace_top_for, {Expr("For")}); + return IRTransform(stmt, nullptr, replace_top_for, {PrimExpr("For")}); } Stmt HoistIfThenElse(Stmt stmt) { diff --git a/src/pass/infer_fragment.cc b/src/pass/infer_fragment.cc index 8f6c06d8abc9..6dfa509345be 100644 --- a/src/pass/infer_fragment.cc +++ b/src/pass/infer_fragment.cc @@ -185,7 +185,7 @@ class InferFragmenter : public StmtMutator { std::string shape = std::to_string(info.m) + ", " + std::to_string(info.n) + ", " + std::to_string(info.k); - Expr shape_expr = StringImmNode::make(shape); + PrimExpr shape_expr = StringImmNode::make(shape); Stmt shape_attr = AttrStmtNode::make(op->buffer_var, attr::fragment_shape, shape_expr, stmt); if (info.layout != "") { // Add shape attribute to matrix_a and matrix_b diff --git a/src/pass/inject_copy_intrin.cc b/src/pass/inject_copy_intrin.cc index 0a19c692db96..e41a868a6372 100644 --- a/src/pass/inject_copy_intrin.cc +++ b/src/pass/inject_copy_intrin.cc @@ -69,7 +69,7 @@ class CopyIntrinInjector : public StmtMutator { if (store == nullptr) return false; // Expr sel_cond, sel_true_value, sel_false_value; // match select or if - PVar sel_cond, sel_true_value, sel_false_value; + PVar sel_cond, sel_true_value, sel_false_value; bool has_cond = if_then_else(sel_cond, sel_true_value, sel_false_value).Match(store->value) || select(sel_cond, sel_true_value, sel_false_value).Match(store->value); @@ -93,12 +93,12 @@ class CopyIntrinInjector : public StmtMutator { for (const ForNode* op : loops) { loop_vars.push_back(op->loop_var); } - Array store_strides = + Array store_strides = arith::DetectLinearEquation(store->index, loop_vars); - Array load_strides = + Array load_strides = arith::DetectLinearEquation(load->index, loop_vars); if (load_strides.size() == 0 || store_strides.size() == 0) return false; - Array dst_shape; + Array dst_shape; const size_t loop_var_size = loop_vars.size(); if (loop_var_size == 0) { dst_shape.push_back(make_const(DataType::Int(32), 1)); @@ -107,24 +107,24 @@ class CopyIntrinInjector : public StmtMutator { dst_shape.push_back(op->extent); } } - Array src_shape = dst_shape; - Array pad_before, pad_after; - Expr pad_value; - Expr src_elem_offset = load_strides[loop_var_size]; + Array src_shape = dst_shape; + Array pad_before, pad_after; + PrimExpr pad_value; + PrimExpr src_elem_offset = load_strides[loop_var_size]; if (has_cond) { - Array clip_bound = + Array clip_bound = arith::DetectClipBound(sel_cond.Eval(), loop_vars); pad_value = sel_false_value.Eval(); if (clip_bound.size() == 0) return false; CHECK_EQ(src_shape.size(), loop_vars.size()); CHECK_EQ(clip_bound.size(), loop_vars.size() * 2); for (size_t i = 0; i < src_shape.size(); ++i) { - Expr min_value = clip_bound[2 * i]; - Expr max_value = clip_bound[2 * i + 1]; + PrimExpr min_value = clip_bound[2 * i]; + PrimExpr max_value = clip_bound[2 * i + 1]; DataType t = loop_vars[i].dtype(); - Expr svalue = src_shape[i]; + PrimExpr svalue = src_shape[i]; if (min_value.defined()) { - Expr pbefore = Simplify(MaxNode::make(min_value, make_zero(t))); + PrimExpr pbefore = Simplify(MaxNode::make(min_value, make_zero(t))); src_elem_offset = src_elem_offset + pbefore * load_strides[i]; svalue = svalue - pbefore; pad_before.push_back(pbefore); @@ -132,7 +132,7 @@ class CopyIntrinInjector : public StmtMutator { pad_before.push_back(make_zero(t)); } if (max_value.defined()) { - Expr pafter = Simplify(MaxNode::make(loops[i]->extent - max_value - make_const(t, 1), + PrimExpr pafter = Simplify(MaxNode::make(loops[i]->extent - max_value - make_const(t, 1), make_zero(t))); svalue = svalue - pafter; pad_after.push_back(pafter); @@ -145,8 +145,8 @@ class CopyIntrinInjector : public StmtMutator { } CHECK_EQ(load_strides.size(), store_strides.size()); CHECK_EQ(load_strides.size(), loop_var_size + 1); - Array src_strides(load_strides.begin(), load_strides.begin() + loop_var_size); - Array dst_strides(store_strides.begin(), store_strides.begin() + loop_var_size); + Array src_strides(load_strides.begin(), load_strides.begin() + loop_var_size); + Array dst_strides(store_strides.begin(), store_strides.begin() + loop_var_size); if (loop_var_size == 0) { src_strides.push_back(make_const(DataType::Int(32), 1)); dst_strides.push_back(make_const(DataType::Int(32), 1)); diff --git a/src/pass/inject_double_buffer.cc b/src/pass/inject_double_buffer.cc index 4bd431ebff9c..9ed56062dacc 100644 --- a/src/pass/inject_double_buffer.cc +++ b/src/pass/inject_double_buffer.cc @@ -99,11 +99,11 @@ class DoubleBufferInjector : public StmtExprMutator { auto it = dbuffer_info_.find(op->buffer_var.get()); if (it != dbuffer_info_.end()) { it->second.stride = arith::ComputeReduce( - op->extents, Expr()) * op->dtype.lanes(); + op->extents, PrimExpr()) * op->dtype.lanes(); Stmt stmt = StmtExprMutator::VisitStmt_(op); op = stmt.as(); - Array new_extents{make_const(op->extents[0].dtype(), 2)}; - for (Expr e : op->extents) { + Array new_extents{make_const(op->extents[0].dtype(), 2)}; + for (PrimExpr e : op->extents) { new_extents.push_back(e); } CHECK(it->second.loop != nullptr); @@ -132,14 +132,14 @@ class DoubleBufferInjector : public StmtExprMutator { CHECK(split_loop_ % 2 == 0 || split_loop_ == 1) << "It is better to split with multiple of 2"; CHECK(is_zero(old_loop->min)); - Expr zero = old_loop->min; - Expr new_ext = + PrimExpr zero = old_loop->min; + PrimExpr new_ext = old_loop->extent - make_const(old_loop->loop_var.dtype(), 1); - Expr factor = make_const(new_ext.dtype(), split_loop_); - Expr outer_ext = new_ext / factor; - Expr tail_base = outer_ext * factor; + PrimExpr factor = make_const(new_ext.dtype(), split_loop_); + PrimExpr outer_ext = new_ext / factor; + PrimExpr tail_base = outer_ext * factor; Var outer_var(old_loop->loop_var->name_hint + ".outer", old_loop->loop_var.dtype()); - std::unordered_map vmap; + std::unordered_map vmap; std::vector loop_seq; for (int32_t i = 0; i < split_loop_; ++i) { vmap[old_loop->loop_var.get()] = outer_var * factor + make_const(factor.dtype(), i); @@ -152,7 +152,7 @@ class DoubleBufferInjector : public StmtExprMutator { std::vector tail_seq; Stmt tail_body = StripDoubleBufferWrite()(old_loop->body); for (int32_t i = 0; i < split_loop_; ++i) { - Expr idx = tail_base + make_const(tail_base.dtype(), i); + PrimExpr idx = tail_base + make_const(tail_base.dtype(), i); vmap[old_loop->loop_var.get()] = idx; tail_seq.emplace_back( IfThenElseNode::make(idx < old_loop->extent, @@ -187,8 +187,8 @@ class DoubleBufferInjector : public StmtExprMutator { } } - Expr VisitExpr_(const LoadNode* op) final { - Expr expr = StmtExprMutator::VisitExpr_(op); + PrimExpr VisitExpr_(const LoadNode* op) final { + PrimExpr expr = StmtExprMutator::VisitExpr_(op); op = expr.as(); auto it = dbuffer_info_.find(op->buffer_var.get()); if (it != dbuffer_info_.end()) { @@ -204,14 +204,14 @@ class DoubleBufferInjector : public StmtExprMutator { } } - Expr VisitExpr_(const VarNode* op) final { + PrimExpr VisitExpr_(const VarNode* op) final { CHECK(!dbuffer_info_.count(op)); - return GetRef(op); + return GetRef(op); } private: Stmt MakeProducer(const AttrStmtNode* op) { - const VarExpr buffer = Downcast(op->node); + const Var buffer = Downcast(op->node); CHECK_NE(loop_nest_.size(), 0U) << "Double buffer scope must be inside a loop"; auto it = dbuffer_info_.find(buffer.get()); @@ -221,17 +221,17 @@ class DoubleBufferInjector : public StmtExprMutator { } StorageEntry& e = it->second; e.loop = loop_nest_.back(); - Expr zero = make_const(e.loop->loop_var.dtype(), 0); - Expr one = make_const(e.loop->loop_var.dtype(), 1); - Expr two = make_const(e.loop->loop_var.dtype(), 2); - Expr loop_shift = e.loop->loop_var + one; + PrimExpr zero = make_const(e.loop->loop_var.dtype(), 0); + PrimExpr one = make_const(e.loop->loop_var.dtype(), 1); + PrimExpr two = make_const(e.loop->loop_var.dtype(), 2); + PrimExpr loop_shift = e.loop->loop_var + one; e.switch_write_var = Var(e.loop->loop_var->name_hint + ".db", e.loop->loop_var.dtype()); e.switch_read_var = indexmod(e.loop->loop_var, two); in_double_buffer_scope_ = true; Stmt body = this->VisitStmt(op->body); in_double_buffer_scope_ = false; - std::unordered_map vmap; + std::unordered_map vmap; vmap[e.switch_write_var.get()] = zero; vmap[e.loop->loop_var.get()] = zero; loop_pre_[e.loop].emplace_back(Substitute(body, vmap)); @@ -245,13 +245,13 @@ class DoubleBufferInjector : public StmtExprMutator { // Storage entry for those who need double buffering. struct StorageEntry { // The size of the buffer - Expr stride; + PrimExpr stride; // The loop we need const ForNode* loop{nullptr}; // The switch variable. - VarExpr switch_write_var; + Var switch_write_var; // The switch variable for reading. - Expr switch_read_var; + PrimExpr switch_read_var; // The storage scope. std::string scope; }; diff --git a/src/pass/inject_prefetch.cc b/src/pass/inject_prefetch.cc index c58a91dbb0de..d7abed8c6062 100644 --- a/src/pass/inject_prefetch.cc +++ b/src/pass/inject_prefetch.cc @@ -79,7 +79,7 @@ class PrefetchInjector : public StmtMutator { } private: - std::vector loop_nest_; + std::vector loop_nest_; std::unordered_map vectorized_; static const Range none; }; diff --git a/src/pass/inject_virtual_thread.cc b/src/pass/inject_virtual_thread.cc index 8eeee9d09742..83fc582cd4f1 100644 --- a/src/pass/inject_virtual_thread.cc +++ b/src/pass/inject_virtual_thread.cc @@ -36,7 +36,7 @@ class ExprTouched final : public StmtExprVisitor { bool check_write) : touched_var_(touched), check_write_(check_write) {} - void VisitExpr(const Expr& n) final { + void VisitExpr(const PrimExpr& n) final { // early stopping if (expr_touched_ && !check_write_) return; StmtExprVisitor::VisitExpr(n); @@ -205,20 +205,20 @@ class VTInjector : public StmtExprMutator { return stmt; } // Variable - Expr VisitExpr_(const VarNode* op) final { + PrimExpr VisitExpr_(const VarNode* op) final { CHECK(!alloc_remap_.count(op)) << "Buffer address may get rewritten in virtual thread"; if (touched_var_.count(op)) { visit_touched_var_ = true; } - return GetRef(op); + return GetRef(op); } - Expr RewriteIndex(Expr index, Expr alloc_extent) const { + PrimExpr RewriteIndex(PrimExpr index, PrimExpr alloc_extent) const { return index + var_ * alloc_extent; } // Load - Expr VisitExpr_(const LoadNode* op) final { - Expr expr = StmtExprMutator::VisitExpr_(op); + PrimExpr VisitExpr_(const LoadNode* op) final { + PrimExpr expr = StmtExprMutator::VisitExpr_(op); op = expr.as(); if (touched_var_.count(op->buffer_var.get())) { visit_touched_var_ = true; @@ -233,7 +233,7 @@ class VTInjector : public StmtExprMutator { } } // Expression. - Expr VisitExpr_(const CallNode* op) final { + PrimExpr VisitExpr_(const CallNode* op) final { if (op->is_intrinsic(intrinsic::tvm_access_ptr)) { CHECK_EQ(op->args.size(), 5U); DataType dtype = op->args[0].dtype(); @@ -241,9 +241,9 @@ class VTInjector : public StmtExprMutator { auto it = alloc_remap_.find(buffer); if (it == alloc_remap_.end()) return StmtExprMutator::VisitExpr_(op); visit_touched_var_ = true; - Expr offset = this->VisitExpr(op->args[2]); - Expr extent = this->VisitExpr(op->args[3]); - Expr stride = + PrimExpr offset = this->VisitExpr(op->args[2]); + PrimExpr extent = this->VisitExpr(op->args[3]); + PrimExpr stride = it->second / make_const(offset.dtype(), dtype.lanes()); offset = stride * var_ + offset; return CallNode::make( @@ -251,7 +251,7 @@ class VTInjector : public StmtExprMutator { {op->args[0], op->args[1], offset, extent, op->args[4]}, op->call_type); } else if (op->is_intrinsic(intrinsic::tvm_context_id)) { - return allow_share_ ? GetRef(op) : var_; + return allow_share_ ? GetRef(op) : var_; } else { return StmtExprMutator::VisitExpr_(op); } @@ -280,7 +280,7 @@ class VTInjector : public StmtExprMutator { } // Attribute Stmt VisitStmt_(const AttrStmtNode* op) final { - Expr value = this->VisitExpr(op->value); + PrimExpr value = this->VisitExpr(op->value); if (visit_touched_var_ && !vt_loop_injected_) { return InjectVTLoop(GetRef(op), true); } else if (!allow_share_ && !vt_loop_injected_ && @@ -299,7 +299,7 @@ class VTInjector : public StmtExprMutator { } // LetStmt Stmt VisitStmt_(const LetStmtNode* op) final { - Expr value = this->VisitExpr(op->value); + PrimExpr value = this->VisitExpr(op->value); if (visit_touched_var_ && !vt_loop_injected_) { return InjectVTLoop(GetRef(op), true); } @@ -315,7 +315,7 @@ class VTInjector : public StmtExprMutator { // For Stmt VisitStmt_(const ForNode* op) final { CHECK(is_zero(op->min)); - Expr extent = this->VisitExpr(op->extent); + PrimExpr extent = this->VisitExpr(op->extent); if (visit_touched_var_ && !vt_loop_injected_) { Stmt stmt = InjectVTLoop(GetRef(op), true); ++max_loop_depth_; @@ -334,7 +334,7 @@ class VTInjector : public StmtExprMutator { } // IfThenElse Stmt VisitStmt_(const IfThenElseNode* op) final { - Expr condition = this->VisitExpr(op->condition); + PrimExpr condition = this->VisitExpr(op->condition); if (visit_touched_var_ && !vt_loop_injected_) { return InjectVTLoop(GetRef(op), true); } @@ -374,15 +374,15 @@ class VTInjector : public StmtExprMutator { if (op->new_expr.defined() && !vt_loop_injected_) { return InjectVTLoop(GetRef(op), true); } - Expr condition = this->VisitExpr(op->condition); + PrimExpr condition = this->VisitExpr(op->condition); if (visit_touched_var_ && !vt_loop_injected_) { return InjectVTLoop(GetRef(op), true); } bool changed = false; - Array extents; + Array extents; for (size_t i = 0; i < op->extents.size(); i++) { - Expr new_ext = this->VisitExpr(op->extents[i]); + PrimExpr new_ext = this->VisitExpr(op->extents[i]); if (visit_touched_var_ && !vt_loop_injected_) { return InjectVTLoop(GetRef(op), true); } @@ -395,11 +395,11 @@ class VTInjector : public StmtExprMutator { // always rewrite if not allow sharing. if (touched_var_.count(op->buffer_var.get()) || !allow_share_) { // place v on highest dimension. - Expr stride = arith::ComputeReduce( - op->extents, Expr()) * op->dtype.lanes(); - Array other; + PrimExpr stride = arith::ComputeReduce( + op->extents, PrimExpr()) * op->dtype.lanes(); + Array other; other.push_back(make_const(op->extents[0].dtype(), num_threads_)); - for (Expr e : extents) { + for (PrimExpr e : extents) { other.push_back(e); } extents = other; @@ -448,7 +448,7 @@ class VTInjector : public StmtExprMutator { } else { // insert a for loop Var idx(var_->name_hint + ".s", var_->dtype); - Map values{{var_, idx}}; + Map values{{var_, idx}}; stmt = Substitute(stmt, values); return ForNode::make(idx, make_zero(idx.dtype()), make_const(idx.dtype(), num_threads_), @@ -474,7 +474,7 @@ class VTInjector : public StmtExprMutator { // Whether allow shareding. bool allow_share_; // The allocations that get touched -> extent - std::unordered_map alloc_remap_; + std::unordered_map alloc_remap_; }; diff --git a/src/pass/inline.cc b/src/pass/inline.cc index 4a087dd8ebda..fad3f1766872 100644 --- a/src/pass/inline.cc +++ b/src/pass/inline.cc @@ -32,11 +32,11 @@ namespace ir { // ConvertSSA need to be applied after this pass class IRInline final : public StmtExprMutator { public: - IRInline(FunctionRef f, Array args, Expr body) + IRInline(FunctionRef f, Array args, PrimExpr body) : f_(f), args_(args), body_(body) {} - Expr VisitExpr_(const CallNode* op) final { - Expr expr = StmtExprMutator::VisitExpr_(op); + PrimExpr VisitExpr_(const CallNode* op) final { + PrimExpr expr = StmtExprMutator::VisitExpr_(op); op = expr.as(); if (op->func == f_) { @@ -53,7 +53,7 @@ class IRInline final : public StmtExprMutator { expr = LetNode::make(args_[i], op->args[i], expr); } } else { - Map vmap; + Map vmap; for (size_t i = 0; i < args_.size(); ++i) { vmap.Set(args_[i], op->args[i]); } @@ -69,13 +69,13 @@ class IRInline final : public StmtExprMutator { private: FunctionRef f_; Array args_; - Expr body_; + PrimExpr body_; }; Stmt Inline(Stmt stmt, FunctionRef f, Array args, - Expr body) { + PrimExpr body) { CHECK_EQ(f->num_outputs(), 1) << "can only inline output single value operation"; Stmt ret = IRInline(f, args, body)(std::move(stmt)); diff --git a/src/pass/ir_deep_compare.cc b/src/pass/ir_deep_compare.cc index a1218f26868a..6eacb145b29b 100644 --- a/src/pass/ir_deep_compare.cc +++ b/src/pass/ir_deep_compare.cc @@ -26,14 +26,14 @@ namespace tvm { namespace ir { -using ExprComparator = ExprFunctor; +using ExprComparator = ExprFunctor; using StmtComparator = StmtFunctor; #define DEFINE_BIOP_EXPR_CMP_(OP) \ - void VisitExpr_(const OP* op, const Expr& other) final { \ + void VisitExpr_(const OP* op, const PrimExpr& other) final { \ const OP* rhs = other.as(); \ - if (CompareExpr(op->a, rhs->a) != 0) return; \ - if (CompareExpr(op->b, rhs->b) != 0) return; \ + if (CompareExpr(op->a, rhs->a) != 0) return; \ + if (CompareExpr(op->b, rhs->b) != 0) return; \ } // Deep comparison to check if two IR graph are equivalent @@ -47,19 +47,19 @@ class IRDeepCompare : return order_ == 0; } - bool Equal(const Expr& lhs, const Expr& rhs) { + bool Equal(const PrimExpr& lhs, const PrimExpr& rhs) { tie_def_ = true; VisitExpr(lhs, rhs); return order_ == 0; } - int Compare(const Expr& lhs, const Expr& rhs) { + int Compare(const PrimExpr& lhs, const PrimExpr& rhs) { tie_def_ = false; VisitExpr(lhs, rhs); return order_; } - void VisitExpr(const Expr& n, const Expr& other) override { + void VisitExpr(const PrimExpr& n, const PrimExpr& other) override { if (order_ != 0) return; if (n.same_as(other)) return; if (CompareValue(n->type_index(), other->type_index()) != 0) return; @@ -193,7 +193,7 @@ class IRDeepCompare : } // Exprs - void VisitExpr_(const VarNode* op, const Expr& other) final { + void VisitExpr_(const VarNode* op, const PrimExpr& other) final { const VarNode* rhs = other.as(); auto it = vmap_.find(op); if (it != vmap_.end()) op = it->second; @@ -203,14 +203,14 @@ class IRDeepCompare : order_ = +1; } } - void VisitExpr_(const LoadNode* op, const Expr& other) final { + void VisitExpr_(const LoadNode* op, const PrimExpr& other) final { const LoadNode* rhs = other.as(); if (CompareExpr(op->buffer_var, rhs->buffer_var) != 0) return; if (CompareExpr(op->index, rhs->index) != 0) return; if (CompareExpr(op->predicate, rhs->predicate) != 0) return; } - void VisitExpr_(const LetNode* op, const Expr& other) final { + void VisitExpr_(const LetNode* op, const PrimExpr& other) final { const LetNode* rhs = other.as(); if (tie_def_) { vmap_[op->var.get()] = rhs->var.get(); @@ -221,7 +221,7 @@ class IRDeepCompare : if (CompareExpr(op->body, rhs->body) != 0) return; } - void VisitExpr_(const CallNode* op, const Expr& other) final { + void VisitExpr_(const CallNode* op, const PrimExpr& other) final { const CallNode* rhs = other.as(); if (CompareString(op->name, rhs->name)) return; if (CompareArray(op->args, rhs->args)) return; @@ -230,7 +230,7 @@ class IRDeepCompare : if (CompareValue(op->value_index, rhs->value_index) != 0) return; } - void VisitExpr_(const ReduceNode *op, const Expr& other) final { + void VisitExpr_(const ReduceNode *op, const PrimExpr& other) final { const ReduceNode* rhs = other.as(); if (CompareCommReducer(op->combiner, rhs->combiner) != 0) return; if (CompareValue(op->axis.size(), rhs->axis.size()) != 0) return; @@ -248,51 +248,51 @@ class IRDeepCompare : if (CompareArray(op->source, rhs->source) != 0) return; } - void VisitExpr_(const IntImmNode *op, const Expr& other) final { + void VisitExpr_(const IntImmNode *op, const PrimExpr& other) final { CompareValue(op->value, other.as()->value); } - void VisitExpr_(const UIntImmNode *op, const Expr& other) final { + void VisitExpr_(const UIntImmNode *op, const PrimExpr& other) final { CompareValue(op->value, other.as()->value); } - void VisitExpr_(const FloatImmNode *op, const Expr& other) final { + void VisitExpr_(const FloatImmNode *op, const PrimExpr& other) final { CompareValue(op->value, other.as()->value); } - void VisitExpr_(const StringImmNode *op, const Expr& other) final { + void VisitExpr_(const StringImmNode *op, const PrimExpr& other) final { CompareString(op->value, other.as()->value); } - void VisitExpr_(const CastNode *op, const Expr& other) final { + void VisitExpr_(const CastNode *op, const PrimExpr& other) final { CompareExpr(op->value, other.as()->value); } - void VisitExpr_(const NotNode *op, const Expr& other) final { + void VisitExpr_(const NotNode *op, const PrimExpr& other) final { CompareExpr(op->a, other.as()->a); } - void VisitExpr_(const SelectNode *op, const Expr& other) final { + void VisitExpr_(const SelectNode *op, const PrimExpr& other) final { const SelectNode* rhs = other.as(); if (CompareExpr(op->condition, rhs->condition) != 0) return; if (CompareExpr(op->true_value, rhs->true_value) != 0) return; if (CompareExpr(op->false_value, rhs->false_value) != 0) return; } - void VisitExpr_(const RampNode *op, const Expr& other) final { + void VisitExpr_(const RampNode *op, const PrimExpr& other) final { const RampNode* rhs = other.as(); if (CompareExpr(op->base, rhs->base) != 0) return; if (CompareExpr(op->stride, rhs->stride) != 0) return; if (CompareValue(op->lanes, rhs->lanes) != 0) return; } - void VisitExpr_(const BroadcastNode *op, const Expr& other) final { + void VisitExpr_(const BroadcastNode *op, const PrimExpr& other) final { const BroadcastNode* rhs = other.as(); if (CompareExpr(op->value, rhs->value) != 0) return; if (CompareValue(op->lanes, rhs->lanes) != 0) return; } - void VisitExpr_(const ShuffleNode *op, const Expr& other) final { + void VisitExpr_(const ShuffleNode *op, const PrimExpr& other) final { const ShuffleNode* rhs = other.as(); if (CompareArray(op->vectors, rhs->vectors) != 0) return; if (CompareArray(op->indices, rhs->indices) != 0) return; @@ -317,7 +317,7 @@ class IRDeepCompare : DEFINE_BIOP_EXPR_CMP_(OrNode) private: - int CompareExpr(const Expr& lhs, const Expr& rhs) { + int CompareExpr(const PrimExpr& lhs, const PrimExpr& rhs) { if (order_ != 0) return order_; if (!lhs.defined() && rhs.defined()) { order_ = -1; return order_; @@ -341,7 +341,7 @@ class IRDeepCompare : return order_; } - int CompareArray(const Array& lhs, const Array& rhs) { + int CompareArray(const Array& lhs, const Array& rhs) { if (order_ != 0) return order_; if (CompareValue(lhs.size(), rhs.size()) != 0) return order_; for (size_t i = 0; i < lhs.size(); ++i) { @@ -438,7 +438,7 @@ bool Equal(const Stmt& lhs, const Stmt& rhs) { return IRDeepCompare().Equal(lhs, rhs); } -bool Equal(const Expr& lhs, const Expr& rhs) { +bool Equal(const PrimExpr& lhs, const PrimExpr& rhs) { // quick pass for constant expressions. if (const int64_t *a = as_const_int(lhs)) { if (const int64_t *b = as_const_int(rhs)) { @@ -455,7 +455,7 @@ bool Equal(const Expr& lhs, const Expr& rhs) { return IRDeepCompare().Equal(lhs, rhs); } -int Compare(const Expr& lhs, const Expr& rhs) { +int Compare(const PrimExpr& lhs, const PrimExpr& rhs) { return IRDeepCompare().Compare(lhs, rhs); } diff --git a/src/pass/ir_functor.cc b/src/pass/ir_functor.cc index b7a73623ab76..67acec674630 100644 --- a/src/pass/ir_functor.cc +++ b/src/pass/ir_functor.cc @@ -31,7 +31,7 @@ class IRApplyVisit : public: explicit IRApplyVisit(std::function f) : f_(f) {} - void VisitExpr(const Expr& node) final { + void VisitExpr(const PrimExpr& node) final { if (visited_.count(node.get()) != 0) return; visited_.insert(node.get()); ExprVisitor::VisitExpr(node); @@ -57,7 +57,7 @@ void PostOrderVisit(const ObjectRef& node, visitor(Downcast(node)); } else { IRApplyVisit visitor(fvisit); - visitor(Downcast(node)); + visitor(Downcast(node)); } } @@ -77,8 +77,8 @@ class IRTransformer final : return this->BaseVisitStmt(s); }); } - Expr VisitExpr(const Expr& expr) final { - return MutateInternal(expr, [this](const Expr& e) { + PrimExpr VisitExpr(const PrimExpr& expr) final { + return MutateInternal(expr, [this](const PrimExpr& e) { return this->BaseVisitExpr(e); }); } @@ -89,7 +89,7 @@ class IRTransformer final : Stmt BaseVisitStmt(const Stmt& s) { return StmtMutator::VisitStmt(s); } - Expr BaseVisitExpr(const Expr& e) { + PrimExpr BaseVisitExpr(const PrimExpr& e) { return ExprMutator::VisitExpr(e); } @@ -120,9 +120,9 @@ class IRTransformer final : Stmt IRTransform(Stmt ir_node, const runtime::PackedFunc& f_preorder, const runtime::PackedFunc& f_postorder, - const Array& only_enable) { + const Array& only_enable) { std::unordered_set only_type_index; - for (Expr s : only_enable) { + for (PrimExpr s : only_enable) { only_type_index.insert(Object::TypeKey2Index(s.as()->value.c_str())); } IRTransformer transform(f_preorder, f_postorder, only_type_index); @@ -154,7 +154,7 @@ void StmtVisitor::VisitStmt_(const ForNode* op) { } void StmtVisitor::VisitStmt_(const AllocateNode* op) { - VisitArray(op->extents, [this](const Expr& e) { this->VisitExpr(e); }); + VisitArray(op->extents, [this](const PrimExpr& e) { this->VisitExpr(e); }); this->VisitStmt(op->body); this->VisitExpr(op->condition); if (op->new_expr.defined()) { @@ -189,7 +189,7 @@ void StmtVisitor::VisitStmt_(const ProducerConsumerNode* op) { } void StmtVisitor::VisitStmt_(const ProvideNode* op) { - VisitArray(op->args, [this](const Expr& e) { this->VisitExpr(e); }); + VisitArray(op->args, [this](const PrimExpr& e) { this->VisitExpr(e); }); this->VisitExpr(op->value); } @@ -232,7 +232,7 @@ void ExprVisitor::VisitExpr_(const LetNode* op) { } void ExprVisitor::VisitExpr_(const CallNode* op) { - VisitArray(op->args, [this](const Expr& e) { this->VisitExpr(e); }); + VisitArray(op->args, [this](const PrimExpr& e) { this->VisitExpr(e); }); } #define DEFINE_BINOP_VISIT_(OP) \ @@ -269,7 +269,7 @@ void ExprVisitor::VisitExpr_(const ReduceNode* op) { this->VisitExpr(r->dom->min); this->VisitExpr(r->dom->extent); }); - VisitArray(op->source, [this](const Expr& e) { this->VisitExpr(e); }); + VisitArray(op->source, [this](const PrimExpr& e) { this->VisitExpr(e); }); this->VisitExpr(op->condition); } @@ -293,8 +293,8 @@ void ExprVisitor::VisitExpr_(const RampNode* op) { } void ExprVisitor::VisitExpr_(const ShuffleNode* op) { - VisitArray(op->indices, [this](const Expr& e) { this->VisitExpr(e); }); - VisitArray(op->vectors, [this](const Expr& e) { this->VisitExpr(e); }); + VisitArray(op->indices, [this](const PrimExpr& e) { this->VisitExpr(e); }); + VisitArray(op->vectors, [this](const PrimExpr& e) { this->VisitExpr(e); }); } void ExprVisitor::VisitExpr_(const BroadcastNode* op) { @@ -320,8 +320,8 @@ inline Array MutateArray(const Array& arr, class StmtMutator::Internal { public: - static Array Mutate(StmtMutator* self, const Array& arr) { - auto fmutate = [self](const Expr& e) { return self->VisitExpr(e); }; + static Array Mutate(StmtMutator* self, const Array& arr) { + auto fmutate = [self](const PrimExpr& e) { return self->VisitExpr(e); }; return MutateArray(arr, fmutate, self->allow_copy_on_write_); } @@ -332,8 +332,8 @@ class StmtMutator::Internal { static Array Mutate(StmtMutator* self, const Array& arr) { auto fmutate = [self](const Range& r) { - Expr min = self->VisitExpr(r->min); - Expr extent = self->VisitExpr(r->extent); + PrimExpr min = self->VisitExpr(r->min); + PrimExpr extent = self->VisitExpr(r->extent); if (min.same_as(r->min) && extent.same_as(r->extent)) { return r; } else { @@ -345,7 +345,7 @@ class StmtMutator::Internal { }; Stmt StmtMutator::VisitStmt_(const AttrStmtNode* op) { - Expr value = this->VisitExpr(op->value); + PrimExpr value = this->VisitExpr(op->value); Stmt body = this->VisitStmt(op->body); if (value.same_as(op->value) && body.same_as(op->body)) { @@ -359,7 +359,7 @@ Stmt StmtMutator::VisitStmt_(const AttrStmtNode* op) { } Stmt StmtMutator::VisitStmt_(const LetStmtNode* op) { - Expr value = this->VisitExpr(op->value); + PrimExpr value = this->VisitExpr(op->value); Stmt body = this->VisitStmt(op->body); if (value.same_as(op->value) && body.same_as(op->body)) { @@ -373,8 +373,8 @@ Stmt StmtMutator::VisitStmt_(const LetStmtNode* op) { } Stmt StmtMutator::VisitStmt_(const ForNode* op) { - Expr min = this->VisitExpr(op->min); - Expr extent = this->VisitExpr(op->extent); + PrimExpr min = this->VisitExpr(op->min); + PrimExpr extent = this->VisitExpr(op->extent); Stmt body = this->VisitStmt(op->body); if (min.same_as(op->min) && extent.same_as(op->extent) && @@ -390,10 +390,10 @@ Stmt StmtMutator::VisitStmt_(const ForNode* op) { } Stmt StmtMutator::VisitStmt_(const AllocateNode* op) { - Array extents = Internal::Mutate(this, op->extents); + Array extents = Internal::Mutate(this, op->extents); Stmt body = this->VisitStmt(op->body); - Expr condition = this->VisitExpr(op->condition); - Expr new_expr; + PrimExpr condition = this->VisitExpr(op->condition); + PrimExpr new_expr; if (op->new_expr.defined()) { new_expr = this->VisitExpr(op->new_expr); } @@ -413,7 +413,7 @@ Stmt StmtMutator::VisitStmt_(const AllocateNode* op) { } Stmt StmtMutator::VisitStmt_(const IfThenElseNode* op) { - Expr condition = this->VisitExpr(op->condition); + PrimExpr condition = this->VisitExpr(op->condition); Stmt then_case = this->VisitStmt(op->then_case); Stmt else_case; if (op->else_case.defined()) { @@ -433,9 +433,9 @@ Stmt StmtMutator::VisitStmt_(const IfThenElseNode* op) { } Stmt StmtMutator::VisitStmt_(const StoreNode* op) { - Expr value = this->VisitExpr(op->value); - Expr index = this->VisitExpr(op->index); - Expr predicate = this->VisitExpr(op->predicate); + PrimExpr value = this->VisitExpr(op->value); + PrimExpr index = this->VisitExpr(op->index); + PrimExpr predicate = this->VisitExpr(op->predicate); if (value.same_as(op->value) && index.same_as(op->index) && predicate.same_as(op->predicate)) { @@ -450,8 +450,8 @@ Stmt StmtMutator::VisitStmt_(const StoreNode* op) { } Stmt StmtMutator::VisitStmt_(const ProvideNode* op) { - Array args = Internal::Mutate(this, op->args); - Expr value = this->VisitExpr(op->value); + Array args = Internal::Mutate(this, op->args); + PrimExpr value = this->VisitExpr(op->value); if (args.same_as(op->args) && value.same_as(op->value)) { return GetRef(op); @@ -466,7 +466,7 @@ Stmt StmtMutator::VisitStmt_(const ProvideNode* op) { Stmt StmtMutator::VisitStmt_(const RealizeNode* op) { Region bounds = Internal::Mutate(this, op->bounds); Stmt body = this->VisitStmt(op->body); - Expr condition = this->VisitExpr(op->condition); + PrimExpr condition = this->VisitExpr(op->condition); if (bounds.same_as(op->bounds) && body.same_as(op->body) && condition.same_as(op->condition)) { @@ -549,8 +549,8 @@ Stmt StmtMutator::VisitSeqStmt_(const SeqStmtNode* op, } Stmt StmtMutator::VisitStmt_(const AssertStmtNode* op) { - Expr condition = this->VisitExpr(op->condition); - Expr message = this->VisitExpr(op->message); + PrimExpr condition = this->VisitExpr(op->condition); + PrimExpr message = this->VisitExpr(op->message); Stmt body = this->VisitStmt(op->body); if (condition.same_as(op->condition) && @@ -578,7 +578,7 @@ Stmt StmtMutator::VisitStmt_(const ProducerConsumerNode* op) { } Stmt StmtMutator::VisitStmt_(const EvaluateNode* op) { - Expr value = this->VisitExpr(op->value); + PrimExpr value = this->VisitExpr(op->value); if (value.same_as(op->value)) { return GetRef(op); } else { @@ -593,37 +593,37 @@ Stmt StmtMutator::VisitStmt_(const FreeNode* op) { } -Expr ExprMutator::VisitExpr_(const VarNode* op) { - return GetRef(op); +PrimExpr ExprMutator::VisitExpr_(const VarNode* op) { + return GetRef(op); } -Expr ExprMutator::VisitExpr_(const LoadNode* op) { - Expr index = this->VisitExpr(op->index); - Expr predicate = this->VisitExpr(op->predicate); +PrimExpr ExprMutator::VisitExpr_(const LoadNode* op) { + PrimExpr index = this->VisitExpr(op->index); + PrimExpr predicate = this->VisitExpr(op->predicate); if (index.same_as(op->index) && predicate.same_as(op->predicate)) { - return GetRef(op); + return GetRef(op); } else { return LoadNode::make(op->dtype, op->buffer_var, index, predicate); } } -Expr ExprMutator::VisitExpr_(const LetNode* op) { - Expr value = this->VisitExpr(op->value); - Expr body = this->VisitExpr(op->body); +PrimExpr ExprMutator::VisitExpr_(const LetNode* op) { + PrimExpr value = this->VisitExpr(op->value); + PrimExpr body = this->VisitExpr(op->body); if (value.same_as(op->value) && body.same_as(op->body)) { - return GetRef(op); + return GetRef(op); } else { return LetNode::make(op->var, value, body); } } -Expr ExprMutator::VisitExpr_(const CallNode* op) { - auto fmutate = [this](const Expr& e) { return this->VisitExpr(e); }; - Array args = MutateArray(op->args, fmutate); +PrimExpr ExprMutator::VisitExpr_(const CallNode* op) { + auto fmutate = [this](const PrimExpr& e) { return this->VisitExpr(e); }; + Array args = MutateArray(op->args, fmutate); if (args.same_as(op->args)) { - return GetRef(op); + return GetRef(op); } else { return CallNode::make(op->dtype, op->name, @@ -635,8 +635,8 @@ Expr ExprMutator::VisitExpr_(const CallNode* op) { } #define DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(OP) \ - Expr ExprMutator::VisitExpr_(const OP *op) { \ - return GetRef(op); \ + PrimExpr ExprMutator::VisitExpr_(const OP *op) { \ + return GetRef(op); \ } DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(IntImmNode) @@ -645,12 +645,12 @@ DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(FloatImmNode) DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(StringImmNode) #define DEFINE_BIOP_EXPR_MUTATE_(OP) \ - Expr ExprMutator::VisitExpr_(const OP* op) { \ - Expr a = this->VisitExpr(op->a); \ - Expr b = this->VisitExpr(op->b); \ + PrimExpr ExprMutator::VisitExpr_(const OP* op) { \ + PrimExpr a = this->VisitExpr(op->a); \ + PrimExpr b = this->VisitExpr(op->b); \ if (a.same_as(op->a) && \ b.same_as(op->b)) { \ - return GetRef(op); \ + return GetRef(op); \ } else { \ return OP::make(a, b); \ } \ @@ -674,11 +674,11 @@ DEFINE_BIOP_EXPR_MUTATE_(GENode); DEFINE_BIOP_EXPR_MUTATE_(AndNode); DEFINE_BIOP_EXPR_MUTATE_(OrNode); -Expr ExprMutator::VisitExpr_(const ReduceNode* op) { +PrimExpr ExprMutator::VisitExpr_(const ReduceNode* op) { auto fitervar = [this](const IterVar& v) { Range r = v->dom; - Expr min = this->VisitExpr(r->min); - Expr extent = this->VisitExpr(r->extent); + PrimExpr min = this->VisitExpr(r->min); + PrimExpr extent = this->VisitExpr(r->extent); if (min.same_as(r->min) && extent.same_as(r->extent)) { return v; @@ -690,77 +690,77 @@ Expr ExprMutator::VisitExpr_(const ReduceNode* op) { }; Array axis = MutateArray(op->axis, fitervar); - auto fexpr = [this](const Expr& e) { return this->VisitExpr(e); }; - Array source = MutateArray(op->source, fexpr); + auto fexpr = [this](const PrimExpr& e) { return this->VisitExpr(e); }; + Array source = MutateArray(op->source, fexpr); - Expr condition = this->VisitExpr(op->condition); + PrimExpr condition = this->VisitExpr(op->condition); if (axis.same_as(op->axis) && source.same_as(op->source) && condition.same_as(op->condition)) { - return GetRef(op); + return GetRef(op); } else { return ReduceNode::make( op->combiner, source, axis, condition, op->value_index); } } -Expr ExprMutator::VisitExpr_(const CastNode* op) { - Expr value = this->VisitExpr(op->value); +PrimExpr ExprMutator::VisitExpr_(const CastNode* op) { + PrimExpr value = this->VisitExpr(op->value); if (value.same_as(op->value)) { - return GetRef(op); + return GetRef(op); } else { return CastNode::make(op->dtype, value); } } -Expr ExprMutator::VisitExpr_(const NotNode* op) { - Expr a = this->VisitExpr(op->a); +PrimExpr ExprMutator::VisitExpr_(const NotNode* op) { + PrimExpr a = this->VisitExpr(op->a); if (a.same_as(op->a)) { - return GetRef(op); + return GetRef(op); } else { return NotNode::make(a); } } -Expr ExprMutator::VisitExpr_(const SelectNode* op) { - Expr condition = this->VisitExpr(op->condition); - Expr true_value = this->VisitExpr(op->true_value); - Expr false_value = this->VisitExpr(op->false_value); +PrimExpr ExprMutator::VisitExpr_(const SelectNode* op) { + PrimExpr condition = this->VisitExpr(op->condition); + PrimExpr true_value = this->VisitExpr(op->true_value); + PrimExpr false_value = this->VisitExpr(op->false_value); if (condition.same_as(op->condition) && true_value.same_as(op->true_value) && false_value.same_as(op->false_value)) { - return GetRef(op); + return GetRef(op); } else { return SelectNode::make(condition, true_value, false_value); } } -Expr ExprMutator::VisitExpr_(const RampNode* op) { - Expr base = this->VisitExpr(op->base); - Expr stride = this->VisitExpr(op->stride); +PrimExpr ExprMutator::VisitExpr_(const RampNode* op) { + PrimExpr base = this->VisitExpr(op->base); + PrimExpr stride = this->VisitExpr(op->stride); if (base.same_as(op->base) && stride.same_as(op->stride)) { - return GetRef(op); + return GetRef(op); } else { return RampNode::make(base, stride, op->lanes); } } -Expr ExprMutator::VisitExpr_(const BroadcastNode* op) { - Expr value = this->VisitExpr(op->value); +PrimExpr ExprMutator::VisitExpr_(const BroadcastNode* op) { + PrimExpr value = this->VisitExpr(op->value); if (value.same_as(op->value)) { - return GetRef(op); + return GetRef(op); } else { return BroadcastNode::make(value, op->lanes); } } -Expr ExprMutator::VisitExpr_(const ShuffleNode* op) { - auto fexpr = [this](const Expr& e) { return this->VisitExpr(e); }; +PrimExpr ExprMutator::VisitExpr_(const ShuffleNode* op) { + auto fexpr = [this](const PrimExpr& e) { return this->VisitExpr(e); }; auto vectors = MutateArray(op->vectors, fexpr); if (vectors.same_as(op->vectors)) { - return GetRef(op); + return GetRef(op); } else { return ShuffleNode::make(vectors, op->indices); } diff --git a/src/pass/ir_util.h b/src/pass/ir_util.h index 74d57813896f..f1a01953d5bd 100644 --- a/src/pass/ir_util.h +++ b/src/pass/ir_util.h @@ -81,10 +81,10 @@ inline Array UpdateArray(Array arr, F fupdate) { * \param kind The data kind. * \return the get expression. */ -inline Expr TVMStructGet( +inline PrimExpr TVMStructGet( DataType dtype, Var handle, int index, intrinsic::TVMStructFieldKind kind) { - Array args ={ + Array args ={ handle, make_const(DataType::Int(32), index), make_const(DataType::Int(32), static_cast(kind))}; @@ -97,7 +97,7 @@ inline Expr TVMStructGet( * \param dtype The data type. * \param offset the offset index. */ -inline Expr AddressOffset(Var handle, DataType dtype, int offset) { +inline PrimExpr AddressOffset(Var handle, DataType dtype, int offset) { return CallNode::make( DataType::Handle(), intrinsic::tvm_address_of, {LoadNode::make(dtype, handle, make_const(DataType::Int(32), offset * dtype.lanes()), @@ -111,7 +111,7 @@ inline Expr AddressOffset(Var handle, DataType dtype, int offset) { * \param dtype The data type. * \param offset the offset index. */ -inline Expr AddressOffset(Var handle, DataType dtype, Expr offset) { +inline PrimExpr AddressOffset(Var handle, DataType dtype, PrimExpr offset) { if (dtype.lanes() != 1) { offset = offset * make_const(offset.dtype(), dtype.lanes()); offset = RampNode::make(offset, make_const(offset.dtype(), 1), dtype.lanes()); @@ -133,8 +133,8 @@ inline Expr AddressOffset(Var handle, DataType dtype, Expr offset) { */ inline Stmt TVMStructSet( Var handle, int index, - intrinsic::TVMStructFieldKind kind, Expr value) { - Array args ={ + intrinsic::TVMStructFieldKind kind, PrimExpr value) { + Array args ={ handle, make_const(DataType::Int(32), index), make_const(DataType::Int(32), static_cast(kind)), @@ -182,7 +182,7 @@ inline int GetTempAllocaAlignment(DataType type, int32_t const_size) { * \param base The result base. * \return true if pattern match success and store the base to base. */ -inline bool GetRamp1Base(Expr index, int lanes, Expr *base) { +inline bool GetRamp1Base(PrimExpr index, int lanes, PrimExpr *base) { const RampNode* r = index.as(); if (!r) return false; if (!is_one(r->stride)) return false; diff --git a/src/pass/lift_attr_scope.cc b/src/pass/lift_attr_scope.cc index 9a97031f330f..7b760fa4a672 100644 --- a/src/pass/lift_attr_scope.cc +++ b/src/pass/lift_attr_scope.cc @@ -55,7 +55,7 @@ class AttrScopeLifter : public StmtMutator { attr_node_, attr_key_, attr_value_, op->body); // undefine them attr_node_ = ObjectRef(); - attr_value_ = Expr(); + attr_value_ = PrimExpr(); return AllocateNode::make( op->buffer_var, op->dtype, op->extents, op->condition, body, @@ -78,11 +78,11 @@ class AttrScopeLifter : public StmtMutator { Stmt VisitStmt_(const SeqStmtNode* op) final { // remember the decorations. std::vector attr_node; - std::vector attr_value; + std::vector attr_value; auto fmutate = [&](const Stmt& s) { attr_node_ = ObjectRef(); - attr_value_ = Expr(); + attr_value_ = PrimExpr(); Stmt ret = this->VisitStmt(s); attr_node.push_back(attr_node_); attr_value.push_back(attr_value_); @@ -123,7 +123,7 @@ class AttrScopeLifter : public StmtMutator { begin = end; } attr_node_ = ObjectRef(); - attr_value_ = Expr(); + attr_value_ = PrimExpr(); return SeqStmt::Flatten(reorg); } @@ -133,7 +133,7 @@ class AttrScopeLifter : public StmtMutator { } Stmt then_case = this->VisitStmt(op->then_case); ObjectRef first_node; - Expr first_value; + PrimExpr first_value; std::swap(first_node, attr_node_); std::swap(first_value, attr_value_); Stmt else_case = this->VisitStmt(op->else_case); @@ -159,7 +159,7 @@ class AttrScopeLifter : public StmtMutator { attr_node_, attr_key_, attr_value_, else_case); // undefine them attr_node_ = ObjectRef(); - attr_value_ = Expr(); + attr_value_ = PrimExpr(); } if (then_case.same_as(op->then_case) && else_case.same_as(op->else_case)) { @@ -172,7 +172,7 @@ class AttrScopeLifter : public StmtMutator { private: // value comparison that also compares content of int constant - static bool ValueSame(const Expr& a, const Expr& b) { + static bool ValueSame(const PrimExpr& a, const PrimExpr& b) { if (a.same_as(b)) return true; if (!a.defined() || !b.defined()) return false; if (a->type_index() != b->type_index()) return false; @@ -188,7 +188,7 @@ class AttrScopeLifter : public StmtMutator { std::string attr_key_; ObjectRef attr_node_; - Expr attr_value_; + PrimExpr attr_value_; }; Stmt LiftAttrScope(Stmt stmt, std::string attr_key) { diff --git a/src/pass/loop_partition.cc b/src/pass/loop_partition.cc index 7d9ce62c5b90..adcd5ecd0287 100644 --- a/src/pass/loop_partition.cc +++ b/src/pass/loop_partition.cc @@ -49,7 +49,7 @@ struct PartitionKeyHash { // condition cond is proven to have value cond_value (true or false) in interval. using Partition = std::unordered_map; -bool ExprUseVars(Expr expr, const std::unordered_set& vars) { +bool ExprUseVars(PrimExpr expr, const std::unordered_set& vars) { bool success = false; PostOrderVisit(expr, [&vars, &success](const ObjectRef& node) { if (const VarNode* v = node.as()) { @@ -152,7 +152,7 @@ class CandidateSelector final : public StmtExprVisitor { // (currently, "likely" conditions) has fixed true or false value class PartitionFinder : public StmtExprVisitor { public: - explicit PartitionFinder(VarExpr current_var, + explicit PartitionFinder(Var current_var, const std::unordered_map& hint_map, const std::unordered_map& relax_map) : current_var_(current_var), hint_map_(hint_map), relax_map_(relax_map) { @@ -194,7 +194,7 @@ class PartitionFinder : public StmtExprVisitor { void VisitExpr_(const CallNode* op) final { if (op->is_intrinsic(CallNode::likely)) { - Expr cond = op->args[0]; + PrimExpr cond = op->args[0]; if (ExprUseVars(cond, std::unordered_set({current_var_.get()}))) { // For cond, find out the interval, if exists, in which we can prove that cond is @@ -206,7 +206,7 @@ class PartitionFinder : public StmtExprVisitor { // cond is true within interval partitions[{cond.get(), true}] = interval; } - Expr inverse_cond = InverseCond(cond); + PrimExpr inverse_cond = InverseCond(cond); if (inverse_cond.defined()) { IntSet interval = DeduceBound(current_var_, inverse_cond, hint_map_, relax_map_); @@ -224,8 +224,8 @@ class PartitionFinder : public StmtExprVisitor { Partition partitions; private: - Expr InverseCond(const Expr& cond) { - Expr inverse_cond; + PrimExpr InverseCond(const PrimExpr& cond) { + PrimExpr inverse_cond; if (const LTNode* op = cond.as()) { // a < b -> a >= b inverse_cond = GENode::make(op->a, op->b); @@ -248,7 +248,7 @@ class PartitionFinder : public StmtExprVisitor { return inverse_cond; } - VarExpr current_var_; + Var current_var_; std::unordered_set out_vars_; std::unordered_map hint_map_; std::unordered_map relax_map_; @@ -260,7 +260,7 @@ class ConditionEliminator : public StmtExprMutator { explicit ConditionEliminator(const std::unordered_set& ps, bool cond_value = true) : ps_(ps), cond_value_(cond_value) {} - Expr VisitExpr(const Expr& e) final { + PrimExpr VisitExpr(const PrimExpr& e) final { if (ps_.find(e.get()) != ps_.end()) { return VisitExpr(cond_value_ ? const_true() : const_false()); } @@ -277,7 +277,7 @@ class ConditionEliminator : public StmtExprMutator { class ThreadPartitionInserter : public StmtMutator { public: explicit ThreadPartitionInserter(const std::unordered_set& ps, - Expr cond) : ps_(ps), cond_(cond), innermost_thread_scope_(false) {} + PrimExpr cond) : ps_(ps), cond_(cond), innermost_thread_scope_(false) {} Stmt VisitStmt_(const AttrStmtNode* op) final { if (op->attr_key == attr::thread_extent) { @@ -287,7 +287,7 @@ class ThreadPartitionInserter : public StmtMutator { if (innermost_thread_scope_) { Stmt simplified_body = ConditionEliminator(ps_)(op->body); Stmt body = IfThenElseNode::make(cond_, simplified_body, op->body); - Expr value = this->VisitExpr(op->value); + PrimExpr value = this->VisitExpr(op->value); stmt = AttrStmtNode::make(op->node, op->attr_key, value, body); } innermost_thread_scope_ = false; @@ -299,7 +299,7 @@ class ThreadPartitionInserter : public StmtMutator { private: const std::unordered_set& ps_; - Expr cond_; + PrimExpr cond_; bool innermost_thread_scope_; }; @@ -363,15 +363,15 @@ class LoopPartitioner : public StmtMutator { } private: - Stmt TryPartition(const Object* op, const Stmt& stmt, VarExpr var, - Expr min, Expr max, Stmt body, bool partition_thread_scope); + Stmt TryPartition(const Object* op, const Stmt& stmt, Var var, + PrimExpr min, PrimExpr max, Stmt body, bool partition_thread_scope); std::pair> GetIntervalAndCondset(const Partition &partitions, const arith::IntervalSet &for_interval, bool cond_value); - inline Stmt MakeFor(const Object* op, Expr extent, Stmt body); + inline Stmt MakeFor(const Object* op, PrimExpr extent, Stmt body); /* Candidate IRs that may be partitioned potentially */ std::unordered_map hint_map_; @@ -452,9 +452,9 @@ LoopPartitioner::GetIntervalAndCondset(const Partition &partitions, */ Stmt LoopPartitioner::TryPartition(const Object* node, const Stmt& stmt, - VarExpr var, - Expr min, - Expr max, + Var var, + PrimExpr min, + PrimExpr max, Stmt body, bool partition_thread_scope) { using namespace arith; @@ -496,13 +496,13 @@ Stmt LoopPartitioner::TryPartition(const Object* node, // Calculating pre-subrange and generating code for it. // pre-subrange = [min, body_begin) - Expr body_begin; + PrimExpr body_begin; Stmt pre_stmt; bool pre_stmt_recurse = true; if (middle_interval_i->HasLowerBound()) { body_begin = ir::Simplify(middle_interval.min()); if (!analyzer_.CanProve(body_begin == min)) { - Expr cond = (body_begin - min >= 0); + PrimExpr cond = (body_begin - min >= 0); if (!analyzer_.CanProve(cond)) { LOG(WARNING) << "Cannot prove: " << cond << ", when generating the pre doubt loop"; @@ -521,14 +521,14 @@ Stmt LoopPartitioner::TryPartition(const Object* node, // Calculating post-subrange and generating code for it. // post-subrange = [post_doubt_begin, max+1) - Expr post_doubt_begin; + PrimExpr post_doubt_begin; Stmt post_stmt; bool post_stmt_recurse = true; if (middle_interval_i->HasUpperBound()) { post_doubt_begin = ir::Simplify(middle_interval.max() + 1); if (!analyzer_.CanProve(middle_interval.max() == max)) { // require the extent to be non-negative - Expr cond = (max - post_doubt_begin + 1 >= 0); + PrimExpr cond = (max - post_doubt_begin + 1 >= 0); if (!analyzer_.CanProve(cond)) { LOG(WARNING) << "Cannot prove: " << cond << ", when generating the post doubt loop"; @@ -571,7 +571,7 @@ Stmt LoopPartitioner::TryPartition(const Object* node, } s = SeqStmt::Flatten(pre_stmt, mid_stmt, post_stmt); } else { - Expr cond = const_true(); + PrimExpr cond = const_true(); if (!analyzer_.CanProve(body_begin == min)) cond = cond && (var >= body_begin); if (!analyzer_.CanProve(post_doubt_begin == (max + 1))) cond = cond && (var < post_doubt_begin); s = ThreadPartitionInserter(cond_set, cond)(stmt); @@ -580,7 +580,7 @@ Stmt LoopPartitioner::TryPartition(const Object* node, return s; } -inline Stmt LoopPartitioner::MakeFor(const Object *node, Expr extent, Stmt body) { +inline Stmt LoopPartitioner::MakeFor(const Object *node, PrimExpr extent, Stmt body) { const ForNode *for_node = static_cast(node); CHECK(for_node); if (analyzer_.CanProve(extent == make_const(DataType::Int(32), 1))) { @@ -594,7 +594,7 @@ inline Stmt LoopPartitioner::MakeFor(const Object *node, Expr extent, Stmt body) class RemoveLikelyTags : public StmtExprMutator { public: - Expr VisitExpr_(const CallNode *op) final { + PrimExpr VisitExpr_(const CallNode *op) final { if (op->is_intrinsic(CallNode::likely)) { CHECK_EQ(op->args.size(), 1); return StmtExprMutator::VisitExpr(op->args[0]); diff --git a/src/pass/lower_custom_datatypes.cc b/src/pass/lower_custom_datatypes.cc index ded17d4ca9b2..98eaf8c58530 100644 --- a/src/pass/lower_custom_datatypes.cc +++ b/src/pass/lower_custom_datatypes.cc @@ -41,13 +41,13 @@ class CustomDatatypesLowerer : public StmtExprMutator { public: explicit CustomDatatypesLowerer(const std::string& target) : target_(target) {} - inline Expr VisitExpr_(const CastNode* op) final { + inline PrimExpr VisitExpr_(const CastNode* op) final { auto type_code = op->dtype.code(); auto src_type_code = op->value.dtype().code(); // If either datatype is a registered custom datatype, we must lower. bool toBeLowered = datatype::Registry::Global()->GetTypeRegistered(type_code) || datatype::Registry::Global()->GetTypeRegistered(src_type_code); - Expr expr = StmtExprMutator::VisitExpr_(op); + PrimExpr expr = StmtExprMutator::VisitExpr_(op); op = expr.as(); if (toBeLowered) { auto lower = datatype::GetCastLowerFunc(target_, type_code, src_type_code); @@ -59,9 +59,9 @@ class CustomDatatypesLowerer : public StmtExprMutator { return expr; } - inline Expr VisitExpr_(const FloatImmNode* imm) final { + inline PrimExpr VisitExpr_(const FloatImmNode* imm) final { auto type_code = imm->dtype.code(); - auto e = GetRef(imm); + auto e = GetRef(imm); if (datatype::Registry::Global()->GetTypeRegistered(type_code)) { auto lower = datatype::GetFloatImmLowerFunc(target_, type_code); CHECK(lower) << "FloatImm lowering function for target " << target_ << " type " @@ -85,9 +85,9 @@ class CustomDatatypesLowerer : public StmtExprMutator { return stmt; } - inline Expr VisitExpr_(const LoadNode* load) final { + inline PrimExpr VisitExpr_(const LoadNode* load) final { bool toBeLowered = datatype::Registry::Global()->GetTypeRegistered(load->dtype.code()); - Expr expr = StmtExprMutator::VisitExpr_(load); + PrimExpr expr = StmtExprMutator::VisitExpr_(load); load = expr.as(); if (toBeLowered) { auto new_load_type = DataType::UInt(load->dtype.bits()); @@ -97,10 +97,10 @@ class CustomDatatypesLowerer : public StmtExprMutator { } #define DEFINE_MUTATE__(OP, NodeName) \ - inline Expr VisitExpr_(const NodeName* op) final { \ + inline PrimExpr VisitExpr_(const NodeName* op) final { \ auto type_code = op->dtype.code(); \ bool toBeLowered = datatype::Registry::Global()->GetTypeRegistered(type_code); \ - Expr expr = StmtExprMutator::VisitExpr_(op); \ + PrimExpr expr = StmtExprMutator::VisitExpr_(op); \ op = expr.as(); \ if (toBeLowered) { \ auto lower = datatype::Get##OP##LowerFunc(target_, type_code); \ diff --git a/src/pass/lower_intrin.cc b/src/pass/lower_intrin.cc index b46bf1808865..ed8be8bb39fc 100644 --- a/src/pass/lower_intrin.cc +++ b/src/pass/lower_intrin.cc @@ -53,16 +53,16 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { } } - Expr VisitExpr_(const CallNode* op) final { + PrimExpr VisitExpr_(const CallNode* op) final { if (op->call_type == CallNode::Intrinsic || op->call_type == CallNode::PureIntrinsic) { - Expr r = ApplyPattern(op->name, GetRef(op)); + PrimExpr r = ApplyPattern(op->name, GetRef(op)); if (r.defined()) return r; } return IRMutatorWithAnalyzer::VisitExpr_(op); } - Expr VisitExpr_(const AddNode* op) final { + PrimExpr VisitExpr_(const AddNode* op) final { if (const MulNode* mb = op->b.as()) { return MakeFMA(mb->a, mb->b, op->a, op); } else if (const MulNode* ma = op->a.as()) { @@ -73,9 +73,9 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { // We use floordiv for integer analysis, // but will need to lower them to native truncdiv instructions - Expr VisitExpr_(const FloorDivNode* op) final { - auto e = GetRef(op); - Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op); + PrimExpr VisitExpr_(const FloorDivNode* op) final { + auto e = GetRef(op); + PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); if (op == nullptr) return ret; int shift; @@ -95,8 +95,8 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { return truncdiv(op->a, op->b); } else { DLOG(INFO) << "LowerFloorDiv: Cannot decide the sign of divident"; - Expr rdiv = truncdiv(op->a, op->b); - Expr rmod = truncmod(op->a, op->b); + PrimExpr rdiv = truncdiv(op->a, op->b); + PrimExpr rmod = truncmod(op->a, op->b); // condition on b >= 0. // truncmod(a, b) < 0 will implies ceildiv, // So we need to correct these cases. @@ -112,16 +112,16 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { DLOG(INFO) << "LowerFloorDiv: Cannot decide the sign of divisor"; // b >= 0 => (rmod >=0 ? rdiv : rdiv - 1) // b < 0 => (rmod <= 0 ? rdiv : rdiv - 1) - Expr rdiv = truncdiv(op->a, op->b); - Expr rmod = truncmod(op->a, op->b); + PrimExpr rdiv = truncdiv(op->a, op->b); + PrimExpr rmod = truncmod(op->a, op->b); return ir::SelectNode::make( (op->b >= 0 && rmod >= 0) || (op->b < 0 && rmod <= 0), rdiv, rdiv - make_const(dtype, 1)); } } - Expr VisitExpr_(const FloorModNode* op) final { - Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op); + PrimExpr VisitExpr_(const FloorModNode* op) final { + PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); if (op == nullptr) return ret; // Lower floordiv to native truncdiv. @@ -146,7 +146,7 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { // NOTE:condition on b >= 0. // mod(a, b) < 0 will imply we are doing ceildiv, // So we need to correct these cases. - Expr rmod = truncmod(op->a, op->b); + PrimExpr rmod = truncmod(op->a, op->b); if ((dtype == DataType::Int(32) || dtype == DataType::Int(64)) && support_bitwise_op_) { // (rmod >> shift) & b // -> (rmod >= 0 ? 0: -1) & b @@ -159,7 +159,7 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { } else { // uncommon case DLOG(INFO) << "LowerFloorMod: Cannot decide the sign of divsor and divident"; - Expr rmod = truncmod(op->a, op->b); + PrimExpr rmod = truncmod(op->a, op->b); // b > 0 && rmod >= 0 -> rmod // b > 0 && rmod < 0 -> rmod + b // b < 0 && rmod < 0 -> rmod @@ -170,11 +170,11 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { } } - Expr VisitExpr_(const MaxNode* op) final { + PrimExpr VisitExpr_(const MaxNode* op) final { using namespace arith; - PVar x, y; + PVar x, y; PVar c; - auto e = GetRef(op); + auto e = GetRef(op); if (max(floordiv(x, y), c).Match(e) && c.Eval()->value >= 0 && analyzer_->CanProveGreaterEqual(y.Eval(), 0)) { @@ -183,20 +183,20 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { return IRMutatorWithAnalyzer::VisitExpr_(op); } - Expr VisitExpr_(const EQNode* op) final { + PrimExpr VisitExpr_(const EQNode* op) final { using namespace arith; - PVar x, y; - auto e = GetRef(op); + PVar x, y; + auto e = GetRef(op); if ((floormod(x, y) == 0).Match(e)) { return VisitExpr((truncmod(x, y) == 0).Eval()); } return IRMutatorWithAnalyzer::VisitExpr_(op); } - Expr VisitExpr_(const NENode* op) final { + PrimExpr VisitExpr_(const NENode* op) final { using namespace arith; - PVar x, y; - auto e = GetRef(op); + PVar x, y; + auto e = GetRef(op); if ((floormod(x, y) != 0).Match(e)) { return VisitExpr((truncmod(x, y) != 0).Eval()); } @@ -204,7 +204,7 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { } private: - Expr SwapBroadcastCast(const Expr& e) { + PrimExpr SwapBroadcastCast(const PrimExpr& e) { // Try to change broadcast(cast(x)) to cast(broadcast(x)) // For some targets, LLVM will generate more efficient FMA // instruction with the latter. For example, vmla vs. vmlal @@ -228,7 +228,7 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { }; if (should_swap()) { - Expr new_bcast = BroadcastNode::make(cast->value, bcast->lanes); + PrimExpr new_bcast = BroadcastNode::make(cast->value, bcast->lanes); return CastNode::make(bcast->dtype, new_bcast); } } @@ -236,26 +236,26 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { return e; } - Expr MakeFMA(const Expr& a, const Expr& b, const Expr& c, + PrimExpr MakeFMA(const PrimExpr& a, const PrimExpr& b, const PrimExpr& c, const AddNode* op) { // emit fma instruction: a * b + c - Expr lhs = SwapBroadcastCast(a); - Expr rhs = SwapBroadcastCast(b); + PrimExpr lhs = SwapBroadcastCast(a); + PrimExpr rhs = SwapBroadcastCast(b); if (fma_ != nullptr && op->dtype.is_float()) { - Expr r = (*fma_)(CallNode::make( + PrimExpr r = (*fma_)(CallNode::make( op->dtype, "fma", {lhs, rhs, c}, CallNode::PureIntrinsic)); if (r.defined()) return this->VisitExpr(r); } else { if (!lhs.same_as(a) || !rhs.same_as(b)) { - Expr mul = this->VisitExpr(MulNode::make(lhs, rhs)); + PrimExpr mul = this->VisitExpr(MulNode::make(lhs, rhs)); return AddNode::make(mul, this->VisitExpr(c)); } } return IRMutatorWithAnalyzer::VisitExpr_(op); } - Expr ApplyPattern(const std::string& name, const Expr& e) { + PrimExpr ApplyPattern(const std::string& name, const PrimExpr& e) { for (size_t i = 0; i < patterns_.size(); ++i) { std::string& p = patterns_[i]; size_t psize = p.length(); @@ -265,14 +265,14 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { p.resize(psize); // if pattern exists. if (f != nullptr) { - Expr r = (*f)(e); + PrimExpr r = (*f)(e); CHECK(r.defined()) << "intrinsic rule must always return valid Expr"; if (!r.same_as(e)) { return this->VisitExpr(r); } } } - return Expr(); + return PrimExpr(); } // patterns diff --git a/src/pass/lower_thread_allreduce.cc b/src/pass/lower_thread_allreduce.cc index d38d1da76aab..a0b07c293b05 100644 --- a/src/pass/lower_thread_allreduce.cc +++ b/src/pass/lower_thread_allreduce.cc @@ -93,7 +93,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { return stmt; } } - Expr VisitExpr_(const LoadNode* op) final { + PrimExpr VisitExpr_(const LoadNode* op) final { auto it = load_remap_.find(op->buffer_var.get()); if (it != load_remap_.end()) { CHECK(is_zero(op->index)); @@ -123,10 +123,10 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { const UIntImmNode *size_of_args = call->args[0].as(); CHECK(size_of_args) << call->args[0]->GetTypeKey(); CHECK_EQ(size, size_of_args->value); - Array inits = combiner->identity_element; - std::vector values(size); + Array inits = combiner->identity_element; + std::vector values(size); std::vector types(size); - Expr cond = call->args[size+1]; + PrimExpr cond = call->args[size+1]; for (size_t idx = 0; idx < size; ++idx) { values[idx] = call->args[1+idx]; if (!is_one(cond)) { @@ -175,13 +175,13 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { // the size of each index. int reduce_extent, group_extent; int threadx_extent = 1; - Expr reduce_index = FlattenThread(vred, &reduce_extent); - Expr group_index = FlattenThread(vpar, &group_extent); + PrimExpr reduce_index = FlattenThread(vred, &reduce_extent); + PrimExpr group_index = FlattenThread(vpar, &group_extent); if (reduce_extent == 1) { // special case, no reduction is needed. std::vector stores(size); for (size_t i = 0; i < size; ++i) { - Expr pred = const_true(types[i].lanes()); + PrimExpr pred = const_true(types[i].lanes()); Var buffer_var = Downcast(call->args[2+size+i]); stores[i] = StoreNode::make(buffer_var, values[i], 0, pred); } @@ -198,7 +198,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { seq.emplace_back(SyncThread("shared")); for (size_t idx = 0; idx < size; ++idx) { shared_bufs[idx] = Var("red_buf"+std::to_string(idx), DataType::Handle()); - Expr pred = const_true(types[idx].lanes()); + PrimExpr pred = const_true(types[idx].lanes()); seq.emplace_back(StoreNode::make( shared_bufs[idx], values[idx], BufIndex(reduce_index, group_index, reduce_extent), pred)); @@ -209,13 +209,13 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { reduce_index, group_index, reduce_extent, threadx_extent)); for (size_t idx = 0; idx < size; ++idx) { CHECK(!load_remap_.count(buffers[idx])); - Expr pred = const_true(types[idx].lanes()); + PrimExpr pred = const_true(types[idx].lanes()); load_remap_[buffers[idx]] = LoadNode::make( types[idx], shared_bufs[idx], BufIndex(make_zero(reduce_index.dtype()), group_index, reduce_extent), pred); alloc_remap_[buffers[idx]] = AllocateNode::make( shared_bufs[idx], types[idx], - {Expr(group_extent), Expr(reduce_extent)}, + {PrimExpr(group_extent), PrimExpr(reduce_extent)}, pred, EvaluateNode::make(0)); } return SeqStmt::Flatten(seq); @@ -224,8 +224,8 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { Stmt MakeBufAllreduce(const CommReducerNode *combiner, const std::vector& types, const Array& shared_bufs, - Expr reduce_index, - Expr group_index, + PrimExpr reduce_index, + PrimExpr group_index, int reduce_extent, int threadx_extent) { // Get next power of two @@ -237,17 +237,17 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { std::vector seq; size_t size = shared_bufs.size(); - Expr buf_index = BufIndex(reduce_index, group_index, reduce_extent); + PrimExpr buf_index = BufIndex(reduce_index, group_index, reduce_extent); // make reduction auto freduce = [&](int offset) { - Array a, b; + Array a, b; for (size_t i = 0; i < size; ++i) { b.push_back(LoadNode::make(types[i], shared_bufs[i], BufIndex(reduce_index + offset, group_index, reduce_extent), const_true())); a.push_back(LoadNode::make(types[i], shared_bufs[i], buf_index, const_true())); } - Array ret = (*combiner)(a, b); + Array ret = (*combiner)(a, b); std::vector stores(size); for (size_t i = 0; i < size; ++i) { stores[i] = StoreNode::make(shared_bufs[i], ret[i], buf_index, const_true()); @@ -258,7 +258,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { if (reduce_align > reduce_extent) { // reduction with the boundary condition reduce_align = reduce_align >> 1; - Expr cond = reduce_index < (reduce_extent - reduce_align); + PrimExpr cond = reduce_index < (reduce_extent - reduce_align); seq.emplace_back(IfThenElseNode::make(cond, freduce(reduce_align))); seq.emplace_back(SyncThread("shared")); } @@ -267,13 +267,13 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { while (reduce_align > threadx_extent || reduce_align > warp_size_) { reduce_align = reduce_align >> 1; - Expr cond = reduce_index < reduce_align; + PrimExpr cond = reduce_index < reduce_align; seq.emplace_back(IfThenElseNode::make(cond, freduce(reduce_align))); seq.emplace_back(SyncThread("shared")); } // in warp synchronization. std::vector in_warp_seq; - Expr in_warp_cond = reduce_index < (reduce_align >> 1); + PrimExpr in_warp_cond = reduce_index < (reduce_align >> 1); while (reduce_align > 1) { reduce_align = reduce_align >> 1; in_warp_seq.emplace_back(freduce(reduce_align)); @@ -288,7 +288,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { } // Flatten the thread index. // Also return a warp number, - Expr FlattenThread(const std::vector& tvec, + PrimExpr FlattenThread(const std::vector& tvec, int* out_total_extent) { int& total_extent = *out_total_extent; total_extent = 1; @@ -296,7 +296,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { return make_zero(DataType::Int(32)); } - Expr ret; + PrimExpr ret; for (const ThreadEntry& e : tvec) { if (ret.defined()) { ret = ret + e.iv->var * total_extent; @@ -316,7 +316,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { CallNode::Intrinsic)); } // The local buffer index. - static Expr BufIndex(Expr reduce_index, Expr group_index, int reduce_extent) { + static PrimExpr BufIndex(PrimExpr reduce_index, PrimExpr group_index, int reduce_extent) { if (!is_zero(group_index)) { return ir::Simplify(group_index * reduce_extent + reduce_index); } else { @@ -330,7 +330,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { std::vector thread_extents_; std::vector reduce_combiner_; // The load remap - std::unordered_map load_remap_; + std::unordered_map load_remap_; // Allocate remap std::unordered_map alloc_remap_; }; diff --git a/src/pass/lower_tvm_builtin.cc b/src/pass/lower_tvm_builtin.cc index a9b401f82b8c..8e7f1d86da74 100644 --- a/src/pass/lower_tvm_builtin.cc +++ b/src/pass/lower_tvm_builtin.cc @@ -31,13 +31,13 @@ namespace tvm { namespace ir { -inline Expr ConstInt32(size_t index) { +inline PrimExpr ConstInt32(size_t index) { CHECK_LE(index, std::numeric_limits::max()); return make_const(DataType::Int(32), static_cast(index)); } -inline Expr StackAlloca(std::string type, size_t num) { - Array args = {StringImmNode::make(type), ConstInt32(num)}; +inline PrimExpr StackAlloca(std::string type, size_t num) { + Array args = {StringImmNode::make(type), ConstInt32(num)}; return CallNode::make( DataType::Handle(), intrinsic::tvm_stack_alloca, @@ -103,7 +103,7 @@ class BuiltinLower : public StmtExprMutator { } } } - Expr total_bytes = make_const(op->extents[0].dtype(), nbytes); + PrimExpr total_bytes = make_const(op->extents[0].dtype(), nbytes); for (size_t i = 0; i < op->extents.size(); ++i) { total_bytes = total_bytes * op->extents[i]; } @@ -134,7 +134,7 @@ class BuiltinLower : public StmtExprMutator { CallNode::Extern), body); - Expr free_op = CallNode::make(DataType::Int(32), + PrimExpr free_op = CallNode::make(DataType::Int(32), "TVMBackendFreeWorkspace", {cast(DataType::Int(32), device_type_), cast(DataType::Int(32), device_id_), @@ -163,7 +163,7 @@ class BuiltinLower : public StmtExprMutator { return StmtExprMutator::VisitStmt_(op); } } - Expr VisitExpr_(const CallNode* op) final { + PrimExpr VisitExpr_(const CallNode* op) final { if (op->is_intrinsic(intrinsic::tvm_call_packed)) { return MakeCallPacked(op); } else if (op->is_intrinsic(intrinsic::tvm_call_trace_packed)) { @@ -179,10 +179,10 @@ class BuiltinLower : public StmtExprMutator { } } // call shape - Expr MakeShape(const CallNode* op) { + PrimExpr MakeShape(const CallNode* op) { size_t stack_begin = run_shape_stack_; run_shape_stack_ += op->args.size(); - Expr expr = StmtExprMutator::VisitExpr_(op); + PrimExpr expr = StmtExprMutator::VisitExpr_(op); op = expr.as(); for (size_t i = 0; i < op->args.size(); ++i) { prep_seq_.emplace_back( @@ -192,16 +192,16 @@ class BuiltinLower : public StmtExprMutator { return AddressOffset(stack_shape_, DataType::Int(64), stack_begin); } // make array - Expr MakeArray(const CallNode* op) { + PrimExpr MakeArray(const CallNode* op) { size_t idx = run_array_stack_; run_array_stack_ += 1; - Expr expr = StmtExprMutator::VisitExpr_(op); + PrimExpr expr = StmtExprMutator::VisitExpr_(op); op = expr.as(); prep_seq_.emplace_back( TVMStructSet(stack_array_, idx, intrinsic::kArrData, op->args[0])); prep_seq_.emplace_back( TVMStructSet(stack_array_, idx, intrinsic::kArrShape, op->args[1])); - Expr strides = op->args[2]; + PrimExpr strides = op->args[2]; if (!strides.defined() || is_zero(strides)) { strides = make_zero(DataType::Handle()); } @@ -221,7 +221,7 @@ class BuiltinLower : public StmtExprMutator { make_const(DataType::UInt(16), dtype.lanes()))); // set byte offset int data_bytes = GetVectorBytes(dtype); - Expr byte_offset = op->args[5]; + PrimExpr byte_offset = op->args[5]; if (!is_zero(byte_offset)) { byte_offset = byte_offset * make_const(byte_offset.dtype(), data_bytes); } @@ -239,17 +239,17 @@ class BuiltinLower : public StmtExprMutator { return TVMStructGet(DataType::Handle(), stack_array_, idx, intrinsic::kArrAddr); } // call packed. - Expr MakeCallPacked(const CallNode* op) { + PrimExpr MakeCallPacked(const CallNode* op) { size_t restore_shape_stack = run_shape_stack_; size_t restore_array_stack = run_array_stack_; size_t arg_stack_begin = run_arg_stack_; run_arg_stack_ += op->args.size(); // Specially handle the buffer packed intrinsic - Expr expr = StmtExprMutator::VisitExpr_(op); + PrimExpr expr = StmtExprMutator::VisitExpr_(op); op = expr.as(); for (size_t i = 1; i < op->args.size(); ++i) { - Expr stack_index = ConstInt32(arg_stack_begin + i - 1); - Expr arg = op->args[i]; + PrimExpr stack_index = ConstInt32(arg_stack_begin + i - 1); + PrimExpr arg = op->args[i]; DataType t = arg.dtype(); DataType api_type = APIType(t); if (t != api_type) { @@ -275,7 +275,7 @@ class BuiltinLower : public StmtExprMutator { run_shape_stack_ = restore_shape_stack; run_array_stack_ = restore_array_stack; run_arg_stack_ = arg_stack_begin; - Array packed_args = { + Array packed_args = { op->args[0], stack_value_, stack_tcode_, @@ -287,18 +287,18 @@ class BuiltinLower : public StmtExprMutator { packed_args, CallNode::Intrinsic); } - Expr MakeCallTracePacked(const CallNode *op) { + PrimExpr MakeCallTracePacked(const CallNode *op) { size_t restore_shape_stack = run_shape_stack_; size_t restore_array_stack = run_array_stack_; size_t arg_stack_begin = run_arg_stack_; run_arg_stack_ += op->args.size(); size_t args_size = op->args.size(); CHECK_GT(args_size, 0); - Expr expr = StmtExprMutator::VisitExpr_(op); + PrimExpr expr = StmtExprMutator::VisitExpr_(op); op = expr.as(); for (size_t i = 1; i < op->args.size(); ++i) { - Expr stack_index = ConstInt32(arg_stack_begin + i - 1); - Expr arg = op->args[i]; + PrimExpr stack_index = ConstInt32(arg_stack_begin + i - 1); + PrimExpr arg = op->args[i]; DataType t = arg.dtype(); DataType api_type = APIType(t); if (t != api_type) { @@ -323,7 +323,7 @@ class BuiltinLower : public StmtExprMutator { // Update the top of the stack, so we can use more than one // packed function's arguments with the one stack. run_arg_stack_ = arg_stack_begin + args_size - 1; - Array packed_args = { + Array packed_args = { op->args[0], stack_value_, stack_tcode_, @@ -338,7 +338,7 @@ class BuiltinLower : public StmtExprMutator { } private: - bool IsArrayHandle(const Expr& arg) { + bool IsArrayHandle(const PrimExpr& arg) { // specially set array handle. if (const CallNode* buf = arg.as()) { if (buf->is_intrinsic(intrinsic::tvm_struct_get) && @@ -351,8 +351,8 @@ class BuiltinLower : public StmtExprMutator { // The prepration sequence to be emitted. std::vector prep_seq_; - Expr device_type_; - Expr device_id_; + PrimExpr device_type_; + PrimExpr device_id_; // Var handle for each stack. Var stack_shape_; Var stack_array_; diff --git a/src/pass/lower_warp_memory.cc b/src/pass/lower_warp_memory.cc index 75f128e85323..6a1c3c499b83 100644 --- a/src/pass/lower_warp_memory.cc +++ b/src/pass/lower_warp_memory.cc @@ -96,7 +96,7 @@ class WarpStoreCoeffFinder : private StmtVisitor { if (op->value.dtype().lanes() == 1) { UpdatePattern(op->index); } else { - Expr base; + PrimExpr base; CHECK(GetRamp1Base(op->index, op->value.dtype().lanes(), &base)) << "LowerWarpMemory failed due to store index=" << op->index << ", can only handle continuous store"; @@ -107,13 +107,13 @@ class WarpStoreCoeffFinder : private StmtVisitor { } } - void UpdatePattern(const Expr& index) { - Array m = + void UpdatePattern(const PrimExpr& index) { + Array m = arith::DetectLinearEquation(index, {warp_index_}); CHECK_EQ(m.size(), 2U) << "LowerWarpMemory failed due to store index=" << index; int coeff = 0; - Expr mcoeff = analyzer_->canonical_simplify(m[0]); + PrimExpr mcoeff = analyzer_->canonical_simplify(m[0]); CHECK(arith::GetConstInt(mcoeff, &coeff) && coeff > 0) << "LowerWarpMemory failed due to store index=" << index @@ -211,7 +211,7 @@ class WarpAccessRewriter : protected StmtExprMutator { } protected: - Expr Mutate_(const VarNode* op) { + PrimExpr Mutate_(const VarNode* op) { CHECK(op != buffer_) << "Cannot access address of warp memory directly"; return StmtExprMutator::VisitExpr_(op); @@ -219,7 +219,7 @@ class WarpAccessRewriter : protected StmtExprMutator { Stmt VisitStmt_(const StoreNode* op) { if (op->buffer_var.get() == buffer_) { - Expr local_index, group; + PrimExpr local_index, group; std::tie(local_index, group) = SplitIndexByGroup(op->index); return StoreNode::make(op->buffer_var, op->value, local_index, op->predicate); } else { @@ -227,15 +227,15 @@ class WarpAccessRewriter : protected StmtExprMutator { } } - Expr Mutate_(const LoadNode* op) { + PrimExpr Mutate_(const LoadNode* op) { if (op->buffer_var.get() == buffer_) { - Expr local_index, group; + PrimExpr local_index, group; std::tie(local_index, group) = SplitIndexByGroup(op->index); // invariance: local index must do not contain warp id CHECK(!ExprUseVar(local_index, {warp_index_.get()})) << "LowerWarpMemory failed to rewrite load to shuffle for index " << op->index << " local_index=" << local_index; - Expr load_value = LoadNode::make( + PrimExpr load_value = LoadNode::make( op->dtype, op->buffer_var, local_index, op->predicate); return CallNode::make(load_value.dtype(), intrinsic::tvm_warp_shuffle, @@ -250,27 +250,27 @@ class WarpAccessRewriter : protected StmtExprMutator { // local index is the index in the local // source index is the corresponding source index // in this access pattern. - std::pair SplitIndexByGroup(const Expr& index) { + std::pair SplitIndexByGroup(const PrimExpr& index) { if (index.dtype().lanes() != 1) { - Expr base, local_index, group; + PrimExpr base, local_index, group; CHECK(GetRamp1Base(index, index.dtype().lanes(), &base)); std::tie(local_index, group) = SplitIndexByGroup(base); local_index = RampNode::make(local_index, make_const(local_index.dtype(), 1), index.dtype().lanes()); return std::make_pair(local_index, group); } - Expr m = make_const(index.dtype(), warp_coeff_); + PrimExpr m = make_const(index.dtype(), warp_coeff_); // simple case, warp index is on the highest. if (warp_group_ == 1) { - Expr x = analyzer_->canonical_simplify(indexmod(index, m)); - Expr z = analyzer_->canonical_simplify(indexdiv(index, m)); + PrimExpr x = analyzer_->canonical_simplify(indexmod(index, m)); + PrimExpr z = analyzer_->canonical_simplify(indexdiv(index, m)); return std::make_pair(x, z); } else { - Expr x = analyzer_->canonical_simplify(indexmod(index, m)); - Expr y = index / make_const(index.dtype(), warp_coeff_ * warp_size_); + PrimExpr x = analyzer_->canonical_simplify(indexmod(index, m)); + PrimExpr y = index / make_const(index.dtype(), warp_coeff_ * warp_size_); y = y * m + x; - Expr z = indexdiv(indexmod(index, make_const(index.dtype(), warp_coeff_ * warp_size_)), + PrimExpr z = indexdiv(indexmod(index, make_const(index.dtype(), warp_coeff_ * warp_size_)), m); return std::make_pair(analyzer_->canonical_simplify(y), analyzer_->canonical_simplify(z)); diff --git a/src/pass/make_api.cc b/src/pass/make_api.cc index 56609bbfc712..d5c73a2e8a75 100644 --- a/src/pass/make_api.cc +++ b/src/pass/make_api.cc @@ -35,7 +35,7 @@ namespace tvm { namespace ir { -inline Stmt MakeAssertEQ(Expr lhs, Expr rhs, std::string msg) { +inline Stmt MakeAssertEQ(PrimExpr lhs, PrimExpr rhs, std::string msg) { return AssertStmtNode::make(lhs == rhs, msg, EvaluateNode::make(0)); } @@ -62,18 +62,18 @@ LoweredFunc MakeAPI(Stmt body, // seq_init gives sequence of initialization // seq_check gives sequence of later checks after init std::vector seq_init, seq_check; - std::unordered_map vmap; + std::unordered_map vmap; ArgBinder binder(&vmap); // --------------------------- // local function definitions // load i-th argument as type t auto f_arg_value = [&](DataType t, int i) { - Array call_args{v_packed_args, + Array call_args{v_packed_args, IntImmNode::make(DataType::Int(32), i), IntImmNode::make(DataType::Int(32), intrinsic::kTVMValueContent)}; // load 64 bit version DataType api_type = APIType(t); - Expr res = CallNode::make( + PrimExpr res = CallNode::make( api_type, intrinsic::tvm_struct_get, call_args, CallNode::PureIntrinsic); // cast to the target version. @@ -189,7 +189,7 @@ LoweredFunc MakeAPI(Stmt body, StringImmNode::make(name + "_compute_"), body); // Set device context if (vmap.count(device_id.get())) { - Expr node = StringImmNode::make("default"); + PrimExpr node = StringImmNode::make("default"); CHECK(vmap.count(device_type.get())); seq_check.push_back(AttrStmtNode::make( node, attr::device_context_id, device_id, nop)); @@ -226,7 +226,7 @@ class DeviceTypeBinder: public StmtExprMutator { if (op->attr_key == attr::device_context_type) { if (const VarNode* var = op->value.as()) { var_ = var; - Expr value = make_const(op->value.dtype(), device_type_); + PrimExpr value = make_const(op->value.dtype(), device_type_); Stmt body = StmtExprMutator::VisitStmt_(op); var_ = nullptr; std::ostringstream os; @@ -251,9 +251,9 @@ class DeviceTypeBinder: public StmtExprMutator { return res; } - Expr VisitExpr_(const NENode* op) final { + PrimExpr VisitExpr_(const NENode* op) final { // eager check NE for device check - Expr res = StmtExprMutator::VisitExpr_(op); + PrimExpr res = StmtExprMutator::VisitExpr_(op); op = res.as(); if (ir::Equal(op->a, op->b)) { return make_const(op->dtype, false); @@ -261,11 +261,11 @@ class DeviceTypeBinder: public StmtExprMutator { return res; } - Expr VisitExpr_(const VarNode* op) final { + PrimExpr VisitExpr_(const VarNode* op) final { if (op == var_) { return make_const(op->dtype, device_type_); } else { - return GetRef(op); + return GetRef(op); } } diff --git a/src/pass/remap_thread_axis.cc b/src/pass/remap_thread_axis.cc index 2a486b5b8bd8..4201e785064a 100644 --- a/src/pass/remap_thread_axis.cc +++ b/src/pass/remap_thread_axis.cc @@ -63,7 +63,7 @@ class ThreadAxisRewriter : private StmtExprMutator { return StmtExprMutator::VisitStmt_(op); } - Expr VisitExpr_(const VarNode* op) final { + PrimExpr VisitExpr_(const VarNode* op) final { auto it = vmap_.find(op); if (it != vmap_.end()) return it->second; return StmtExprMutator::VisitExpr_(op); @@ -75,7 +75,7 @@ class ThreadAxisRewriter : private StmtExprMutator { }; LoweredFunc -RemapThreadAxis(LoweredFunc f, Map thread_map) { +RemapThreadAxis(LoweredFunc f, Map thread_map) { std::unordered_map tmap; for (const auto& kv : thread_map) { const StringImmNode* str = kv.first.as(); diff --git a/src/pass/remove_no_op.cc b/src/pass/remove_no_op.cc index 3c9114dd0901..eecbe30828d1 100644 --- a/src/pass/remove_no_op.cc +++ b/src/pass/remove_no_op.cc @@ -126,16 +126,16 @@ class NoOpRemover : public StmtMutator { } private: - Stmt MakeEvaluate(Expr value) { + Stmt MakeEvaluate(PrimExpr value) { if (HasSideEffect(value)) { return EvaluateNode::make(value); } else { return EvaluateNode::make(0); } } - Stmt MakeEvaluate(const Array& values) { + Stmt MakeEvaluate(const Array& values) { Stmt stmt; - for (Expr e : values) { + for (PrimExpr e : values) { if (HasSideEffect(e)) { if (stmt.defined()) { stmt = SeqStmt({stmt, EvaluateNode::make(e)}); diff --git a/src/pass/rewrite_unsafe_select.cc b/src/pass/rewrite_unsafe_select.cc index c38fac14351e..224a81c12396 100644 --- a/src/pass/rewrite_unsafe_select.cc +++ b/src/pass/rewrite_unsafe_select.cc @@ -31,7 +31,7 @@ namespace ir { // For now, rewrite unsafe select expression to if_then_else // TODO(tqchen) pattern matching to support masked load -class UnsafeExprDetector : public ExprFunctor { +class UnsafeExprDetector : public ExprFunctor { public: // select itself is always considered safe if condition is safe // Because we will issue guard to make sure it is. @@ -45,7 +45,7 @@ class UnsafeExprDetector : public ExprFunctor { const LoadNode* l = op->args[0].as(); return this->VisitExpr(l->index); } else if (op->is_pure()) { - for (Expr e : op->args) { + for (PrimExpr e : op->args) { if (VisitExpr(e)) return true; } return false; @@ -90,7 +90,7 @@ class UnsafeExprDetector : public ExprFunctor { return VisitExpr(op->base) && VisitExpr(op->stride); } bool VisitExpr_(const ShuffleNode* op) final { - for (Expr e : op->vectors) { + for (PrimExpr e : op->vectors) { if (VisitExpr(e)) return true; } return false; @@ -110,8 +110,8 @@ class UnsafeExprDetector : public ExprFunctor { class UnsafeSelectRewriter : public StmtExprMutator { public: - Expr VisitExpr_(const SelectNode* op) { - Expr expr = StmtExprMutator::VisitExpr_(op); + PrimExpr VisitExpr_(const SelectNode* op) { + PrimExpr expr = StmtExprMutator::VisitExpr_(op); op = expr.as(); UnsafeExprDetector unsafe; bool cond_is_scalar_bool = op->condition.dtype().is_bool() && op->condition.dtype().is_scalar(); diff --git a/src/pass/simple_passes.cc b/src/pass/simple_passes.cc index 3233e503760f..9737f7047ab6 100644 --- a/src/pass/simple_passes.cc +++ b/src/pass/simple_passes.cc @@ -30,7 +30,7 @@ namespace ir { class IRSideEffect : public ExprVisitor { public: - void VisitExpr(const Expr& e) final { + void VisitExpr(const PrimExpr& e) final { if (has_side_effect_) return; ExprVisitor::VisitExpr(e); } @@ -46,7 +46,7 @@ class IRSideEffect : public ExprVisitor { bool has_side_effect_{false}; }; -bool HasSideEffect(const Expr& e) { +bool HasSideEffect(const PrimExpr& e) { IRSideEffect v; v(e); return v.has_side_effect_; @@ -55,45 +55,45 @@ bool HasSideEffect(const Expr& e) { class IRSubstitue : public StmtExprMutator { public: explicit IRSubstitue( - const std::unordered_map& smap) + const std::unordered_map& smap) : smap_(smap) { } - Expr VisitExpr_(const VarNode* op) final { + PrimExpr VisitExpr_(const VarNode* op) final { auto it = smap_.find(op); if (it != smap_.end()) { return it->second; } else { - return GetRef(op); + return GetRef(op); } } private: - const std::unordered_map& smap_; + const std::unordered_map& smap_; }; Stmt Substitute(Stmt stmt, - const std::unordered_map& value_map) { + const std::unordered_map& value_map) { if (value_map.size() == 0) return stmt; return IRSubstitue(value_map)(std::move(stmt)); } -Expr Substitute(Expr expr, - const std::unordered_map& value_map) { +PrimExpr Substitute(PrimExpr expr, + const std::unordered_map& value_map) { if (value_map.size() == 0) return expr; return IRSubstitue(value_map)(std::move(expr)); } -Stmt Substitute(Stmt stmt, const Map& value_map) { - std::unordered_map vmap; +Stmt Substitute(Stmt stmt, const Map& value_map) { + std::unordered_map vmap; for (const auto& kv : value_map) { vmap[kv.first.get()] = kv.second; } return Substitute(stmt, vmap); } -Expr Substitute(Expr expr, const Map& value_map) { - std::unordered_map vmap; +PrimExpr Substitute(PrimExpr expr, const Map& value_map) { + std::unordered_map vmap; for (const auto& kv : value_map) { vmap[kv.first.get()] = kv.second; } @@ -102,7 +102,7 @@ Expr Substitute(Expr expr, const Map& value_map) { class VarTouchVisitor : public ExprVisitor { public: - void VisitExpr(const Expr& e) final { + void VisitExpr(const PrimExpr& e) final { if (use_var_) return; ExprVisitor::VisitExpr(e); } @@ -146,13 +146,13 @@ class ExprUseVSetVisitor : public VarTouchVisitor { const std::unordered_set& vset_; }; -bool ExprUseVar(const Expr& e, const Var& v) { +bool ExprUseVar(const PrimExpr& e, const Var& v) { ExprUseVarVisitor visitor(v.get()); visitor(e); return visitor.use_var_; } -bool ExprUseVar(const Expr& e, +bool ExprUseVar(const PrimExpr& e, const std::unordered_set& vset) { ExprUseVSetVisitor visitor(vset); visitor(e); diff --git a/src/pass/split_host_device.cc b/src/pass/split_host_device.cc index f71f13b01d51..7309c724099b 100644 --- a/src/pass/split_host_device.cc +++ b/src/pass/split_host_device.cc @@ -46,7 +46,7 @@ class IRUseDefAnalysis : public StmtExprMutator { thread_extent_.push_back(op->value); } - Expr value = op->value; + PrimExpr value = op->value; if (visit_thread_extent_) { value = this->VisitExpr(value); } @@ -68,7 +68,7 @@ class IRUseDefAnalysis : public StmtExprMutator { !HasSideEffect(op->value)) { return body; } else { - Expr value = this->VisitExpr(op->value); + PrimExpr value = this->VisitExpr(op->value); if (body.same_as(op->body) && value.same_as(op->value)) { return GetRef(op); @@ -93,30 +93,30 @@ class IRUseDefAnalysis : public StmtExprMutator { return StmtExprMutator::VisitStmt_(op); } - Expr VisitExpr_(const LetNode* op) final { + PrimExpr VisitExpr_(const LetNode* op) final { this->HandleDef(op->var.get()); - Expr body = this->VisitExpr(op->body); + PrimExpr body = this->VisitExpr(op->body); // eliminate unreferenced let if (use_count_.at(op->var.get()) == 0 && !HasSideEffect(op->value)) { return body; } else { - Expr value = this->VisitExpr(op->value); + PrimExpr value = this->VisitExpr(op->value); if (body.same_as(op->body) && value.same_as(op->value)) { - return GetRef(op); + return GetRef(op); } else { return LetNode::make(op->var, value, body); } } } - Expr VisitExpr_(const VarNode* op) final { - this->HandleUse(GetRef(op)); + PrimExpr VisitExpr_(const VarNode* op) final { + this->HandleUse(GetRef(op)); return StmtExprMutator::VisitExpr_(op); } - Expr VisitExpr_(const LoadNode* op) final { + PrimExpr VisitExpr_(const LoadNode* op) final { this->HandleUse(op->buffer_var); return StmtExprMutator::VisitExpr_(op); } @@ -132,7 +132,7 @@ class IRUseDefAnalysis : public StmtExprMutator { def_count_[v] = 1; } - void HandleUse(const Expr& v) { + void HandleUse(const PrimExpr& v) { CHECK(v.as()); Var var = Downcast(v); auto it = use_count_.find(var.get()); @@ -151,7 +151,7 @@ class IRUseDefAnalysis : public StmtExprMutator { bool visit_thread_extent_{true}; Array undefined_; Array thread_axis_; - Array thread_extent_; + Array thread_extent_; std::unordered_map use_count_; std::unordered_map def_count_; }; @@ -218,12 +218,12 @@ class HostDeviceSplitter : public StmtMutator { } } LoweredFunc f_device(n); - Array call_args; + Array call_args; call_args.push_back(StringImmNode::make(f_device->name)); for (Var arg : n->args) { call_args.push_back(arg); } - for (Expr ext : m.thread_extent_) { + for (PrimExpr ext : m.thread_extent_) { call_args.push_back(ext); } device_funcs_.emplace_back(f_device); @@ -236,7 +236,7 @@ class HostDeviceSplitter : public StmtMutator { std::string name_; // the device functions std::vector device_funcs_; - std::unordered_map handle_data_type_; + std::unordered_map handle_data_type_; }; diff --git a/src/pass/ssa.cc b/src/pass/ssa.cc index 3dafb40e28c7..8375e806a006 100644 --- a/src/pass/ssa.cc +++ b/src/pass/ssa.cc @@ -37,7 +37,7 @@ class IRVerifySSA final : public StmtExprVisitor { public: bool is_ssa{true}; - void VisitExpr(const Expr& n) final { + void VisitExpr(const PrimExpr& n) final { if (!is_ssa) return; StmtExprVisitor::VisitExpr(n); } @@ -76,20 +76,20 @@ class IRVerifySSA final : public StmtExprVisitor { class IRConvertSSA final : public StmtExprMutator { public: - Expr VisitExpr_(const VarNode* op) final { + PrimExpr VisitExpr_(const VarNode* op) final { if (scope_.count(op)) { return scope_[op].back(); } else { - return GetRef(op); + return GetRef(op); } } - Expr VisitExpr_(const LetNode* op) final { - const VarExpr& v = op->var; + PrimExpr VisitExpr_(const LetNode* op) final { + const Var& v = op->var; if (defined_.count(v.get())) { - Expr value = this->VisitExpr(op->value); - VarExpr new_var = VarNode::make(v.dtype(), v->name_hint); + PrimExpr value = this->VisitExpr(op->value); + Var new_var = VarNode::make(v.dtype(), v->name_hint); scope_[v.get()].push_back(new_var); - Expr body = this->VisitExpr(op->body); + PrimExpr body = this->VisitExpr(op->body); scope_[v.get()].pop_back(); return LetNode::make(new_var, value, body); } else { @@ -97,8 +97,8 @@ class IRConvertSSA final : public StmtExprMutator { return StmtExprMutator::VisitExpr_(op); } } - Expr VisitExpr_(const LoadNode* op) final { - Expr expr = StmtExprMutator::VisitExpr_(op); + PrimExpr VisitExpr_(const LoadNode* op) final { + PrimExpr expr = StmtExprMutator::VisitExpr_(op); op = expr.as(); if (scope_.count(op->buffer_var.get())) { return LoadNode::make( @@ -120,10 +120,10 @@ class IRConvertSSA final : public StmtExprMutator { } } Stmt VisitStmt_(const LetStmtNode* op) final { - const VarExpr& v = op->var; + const Var& v = op->var; if (defined_.count(v.get())) { - Expr value = this->VisitExpr(op->value); - VarExpr new_var = VarNode::make(v.dtype(), v->name_hint); + PrimExpr value = this->VisitExpr(op->value); + Var new_var = VarNode::make(v.dtype(), v->name_hint); scope_[v.get()].push_back(new_var); Stmt body = this->VisitStmt(op->body); scope_[v.get()].pop_back(); @@ -134,9 +134,9 @@ class IRConvertSSA final : public StmtExprMutator { } } Stmt VisitStmt_(const ForNode* op) final { - const VarExpr& v = op->loop_var; + const Var& v = op->loop_var; if (defined_.count(v.get())) { - VarExpr new_var = VarNode::make(v.dtype(), v->name_hint); + Var new_var = VarNode::make(v.dtype(), v->name_hint); scope_[v.get()].push_back(new_var); Stmt stmt = StmtExprMutator::VisitStmt_(op); scope_[v.get()].pop_back(); @@ -149,9 +149,9 @@ class IRConvertSSA final : public StmtExprMutator { } } Stmt VisitStmt_(const AllocateNode* op) final { - const VarExpr& v = op->buffer_var; + const Var& v = op->buffer_var; if (defined_.count(v.get())) { - VarExpr new_var = VarNode::make(v.dtype(), v->name_hint); + Var new_var = VarNode::make(v.dtype(), v->name_hint); scope_[v.get()].push_back(new_var); Stmt stmt = StmtExprMutator::VisitStmt_(op); scope_[v.get()].pop_back(); @@ -191,7 +191,7 @@ class IRConvertSSA final : public StmtExprMutator { } private: - std::unordered_map > scope_; + std::unordered_map > scope_; std::unordered_set defined_; }; diff --git a/src/pass/storage_access.cc b/src/pass/storage_access.cc index cb779f983ff2..d98299f24160 100644 --- a/src/pass/storage_access.cc +++ b/src/pass/storage_access.cc @@ -187,8 +187,8 @@ void StorageAccessVisitor::VisitExpr_(const CallNode* op) { CHECK_EQ(op->args.size(), 5U); DataType dtype = op->args[0].dtype(); const VarNode* buffer = op->args[1].as(); - Expr offset = op->args[2]; - Expr extent = op->args[3]; + PrimExpr offset = op->args[2]; + PrimExpr extent = op->args[3]; const IntImmNode* flag = op->args[4].as(); StorageScope scope = GetScope(buffer); // The buffer scope. @@ -197,7 +197,7 @@ void StorageAccessVisitor::VisitExpr_(const CallNode* op) { AccessEntry e; e.threads = env_threads(); e.dtype = dtype; - e.buffer = Downcast(op->args[1]); + e.buffer = Downcast(op->args[1]); e.touched = arith::IntSet::range( Range::make_by_min_extent(offset, extent)); e.scope = scope; @@ -277,7 +277,7 @@ class StorageAccessInfoLower : public StmtExprMutator { } } - Expr VisitExpr_(const CallNode* op) final { + PrimExpr VisitExpr_(const CallNode* op) final { if (op->is_intrinsic(intrinsic::tvm_access_ptr)) { return MakeAccessPtr(op); } else { @@ -287,15 +287,15 @@ class StorageAccessInfoLower : public StmtExprMutator { private: // tvm_access_ptr - Expr MakeAccessPtr(const CallNode* op) { + PrimExpr MakeAccessPtr(const CallNode* op) { // Specially handle the buffer packed intrinsic - Expr expr = StmtExprMutator::VisitExpr_(op); + PrimExpr expr = StmtExprMutator::VisitExpr_(op); op = expr.as(); CHECK_EQ(op->args.size(), 5U); DataType dtype = op->args[0].dtype(); const VarNode* buffer = op->args[1].as(); Var buffer_var = Downcast(op->args[1]); - Expr offset = op->args[2]; + PrimExpr offset = op->args[2]; auto it = storage_info_.find(buffer); if (it != storage_info_.end() && it->second.info.defined()) { return MakeTaggedAccessPtr( @@ -307,10 +307,10 @@ class StorageAccessInfoLower : public StmtExprMutator { return AddressOffset(buffer_var, dtype, offset); } - Expr MakeTaggedAccessPtr(DataType ptr_type, + PrimExpr MakeTaggedAccessPtr(DataType ptr_type, Var buffer_var, DataType dtype, - Expr offset, + PrimExpr offset, const MemoryInfo& info) { if (ptr_type.is_handle()) { CHECK(info->head_address.defined()) diff --git a/src/pass/storage_flatten.cc b/src/pass/storage_flatten.cc index ea828ff490c8..08c61aafbc0c 100644 --- a/src/pass/storage_flatten.cc +++ b/src/pass/storage_flatten.cc @@ -70,7 +70,7 @@ class StorageFlattener : public StmtExprMutator { if (it != var_remap_.end() && !it->second.same_as(op->buffer_var)) { CHECK(it->second.as()); - VarExpr buf_var = Downcast(it->second); + Var buf_var = Downcast(it->second); return StoreNode::make(buf_var, op->value, op->index, op->predicate); } else { return stmt; @@ -167,7 +167,7 @@ class StorageFlattener : public StmtExprMutator { // create a buffer entry BufferEntry e; e.bounds = op->bounds; - Array shape; + Array shape; for (auto r : e.bounds) { shape.push_back(r->extent); } @@ -198,29 +198,29 @@ class StorageFlattener : public StmtExprMutator { << "Allocation exceed bound of memory tag " << skey.to_string(); } } - Array strides; + Array strides; if (dim_align_.count(key) != 0 && shape.size() != 0) { - std::vector rstrides; + std::vector rstrides; const std::vector& avec = dim_align_[key]; int first_dim = 0; - Expr stride = make_const(shape[first_dim].dtype(), 1); + PrimExpr stride = make_const(shape[first_dim].dtype(), 1); for (size_t i = shape.size(); i != 0; --i) { size_t dim = i - 1; if (dim < avec.size() && avec[dim].align_factor != 0) { - Expr factor = make_const(stride.dtype(), avec[dim].align_factor); - Expr offset = make_const(stride.dtype(), avec[dim].align_offset); + PrimExpr factor = make_const(stride.dtype(), avec[dim].align_factor); + PrimExpr offset = make_const(stride.dtype(), avec[dim].align_offset); stride = stride + indexmod(factor + offset - indexmod(stride, factor), factor); stride = ir::Simplify(stride); } rstrides.push_back(stride); stride = stride * shape[dim]; } - strides = Array(rstrides.rbegin(), rstrides.rend()); + strides = Array(rstrides.rbegin(), rstrides.rend()); } e.buffer = BufferNode::make( Var(key.GetName(), DataType::Handle()), - op->dtype, shape, strides, Expr(), + op->dtype, shape, strides, PrimExpr(), key.GetName(), skey.to_string(), align, 0, kDefault); @@ -262,31 +262,31 @@ class StorageFlattener : public StmtExprMutator { } } - Expr VisitExpr_(const LoadNode* op) final { - Expr expr = StmtExprMutator::VisitExpr_(op); + PrimExpr VisitExpr_(const LoadNode* op) final { + PrimExpr expr = StmtExprMutator::VisitExpr_(op); op = expr.as(); auto it = var_remap_.find(op->buffer_var.get()); if (it != var_remap_.end() && !it->second.same_as(op->buffer_var)) { CHECK(it->second.as()); - VarExpr buf_var = Downcast(it->second); + Var buf_var = Downcast(it->second); return LoadNode::make(op->dtype, buf_var, op->index, op->predicate); } else { return expr; } } - Expr VisitExpr_(const VarNode* op) final { + PrimExpr VisitExpr_(const VarNode* op) final { auto it = var_remap_.find(op); if (it != var_remap_.end()) { return it->second; } else { - return GetRef(op); + return GetRef(op); } } - Expr VisitExpr_(const CallNode* op) final { - Expr expr = StmtExprMutator::VisitExpr_(op); + PrimExpr VisitExpr_(const CallNode* op) final { + PrimExpr expr = StmtExprMutator::VisitExpr_(op); op = expr.as(); if (op != nullptr && op->call_type == CallNode::Halide) { TensorKey key{op->func, op->value_index}; @@ -332,20 +332,20 @@ class StorageFlattener : public StmtExprMutator { block_size *= shape; starts--; } - Expr stride(elem_cnt / block_size); + PrimExpr stride(elem_cnt / block_size); - Array args; - std::vector vars; + Array args; + std::vector vars; for (int i = op->bounds.size() - 1; i > starts; --i) { args.push_back(op->bounds[i]->min); } auto &func_name = op->func->func_name(); - vars.push_back(VarExpr( + vars.push_back(Var( "prefetch." + func_name + "." + std::to_string(starts), DataType::Int(32))); args.push_back(op->bounds[starts]->min + stride * vars.back()); for (int i = starts - 1; i >= 0; --i) { - vars.push_back(VarExpr( + vars.push_back(Var( "prefetch." + func_name + "." + std::to_string(i), DataType::Int(32))); args.push_back(vars.back() + op->bounds[i]->min); } @@ -354,13 +354,13 @@ class StorageFlattener : public StmtExprMutator { stmt = ForNode::make( vars[i], 0, op->bounds[i]->extent, ForType::Serial, DeviceAPI::None, stmt); } else { - Expr load = e.buffer.vload(e.RelIndex(args), e.buffer->dtype); - Expr address = CallNode::make( + PrimExpr load = e.buffer.vload(e.RelIndex(args), e.buffer->dtype); + PrimExpr address = CallNode::make( DataType::Handle(), tvm_address_of, {load}, CallNode::PureIntrinsic); - Expr prefetch = CallNode::make( + PrimExpr prefetch = CallNode::make( op->dtype, CallNode::prefetch, {address, 0, 3, 1}, CallNode::Intrinsic); stmt = EvaluateNode::make(prefetch); - Expr extent = (op->bounds[i]->extent - 1) / stride + 1; + PrimExpr extent = (op->bounds[i]->extent - 1) / stride + 1; stmt = ForNode::make(vars[i], 0, extent, ForType::Serial, DeviceAPI::None, stmt); } } @@ -416,7 +416,7 @@ class StorageFlattener : public StmtExprMutator { const BufferEntry& be = buf_map_.at(key); CHECK(!be.released); CHECK_EQ(tuple->args.size(), be.buffer->shape.size() * 2); - Array begins, extents; + Array begins, extents; if (be.bounds.size() != 0) { CHECK_EQ(tuple->args.size(), be.bounds.size() * 2); for (size_t i = 0; i < be.buffer->shape.size(); ++i) { @@ -467,9 +467,9 @@ class StorageFlattener : public StmtExprMutator { // Whether we are out of allocation bounds and buffer get released. bool released{false}; // relative index - inline Array RelIndex(Array args) const { + inline Array RelIndex(Array args) const { if (bounds.size() != 0) { - Array index; + Array index; CHECK_EQ(bounds.size(), args.size()); for (size_t i = 0; i < bounds.size(); ++i) { index.push_back(args[i] - bounds[i]->min); @@ -481,7 +481,7 @@ class StorageFlattener : public StmtExprMutator { } }; - bool ShapeIsValid(const Array &shape) { + bool ShapeIsValid(const Array &shape) { // Zero-dimensional tensor does not need boundary check. if (!shape.size()) return false; @@ -495,9 +495,9 @@ class StorageFlattener : public StmtExprMutator { return true; } - Expr MakeBound(const DataType &type, const Array &shape) { + PrimExpr MakeBound(const DataType &type, const Array &shape) { // We have already checked the shape size to be greater then 0. - Expr bound = MulNode::make(make_const(shape[0].dtype(), type.lanes()), shape[0]); + PrimExpr bound = MulNode::make(make_const(shape[0].dtype(), type.lanes()), shape[0]); for (size_t i = 1; i < shape.size(); ++i) { bound = MulNode::make( bound, MulNode::make(make_const(bound.dtype(), type.lanes()), shape[i])); @@ -507,7 +507,7 @@ class StorageFlattener : public StmtExprMutator { // The buffer assignment map // Variable remap - std::unordered_map var_remap_; + std::unordered_map var_remap_; // Buffer map std::unordered_map buf_map_; // Dimension alignment @@ -517,7 +517,7 @@ class StorageFlattener : public StmtExprMutator { // The current thread scope. std::vector curr_thread_scope_; // Collects shapes. - std::vector>> shape_collector_; + std::vector>> shape_collector_; // bounds populator. We really need the analyzer from it. // However IRVisitorWithAnalyzer* bounded_analyzer_; diff --git a/src/pass/storage_rewrite.cc b/src/pass/storage_rewrite.cc index 928be4b29561..7a4b13cb2cf5 100644 --- a/src/pass/storage_rewrite.cc +++ b/src/pass/storage_rewrite.cc @@ -261,7 +261,7 @@ class InplaceOpVerifier : public StmtExprVisitor { if (!result_) return; StmtExprVisitor::VisitStmt(n); } - void VisitExpr(const Expr& n) final { + void VisitExpr(const PrimExpr& n) final { if (!result_) return; StmtExprVisitor::VisitExpr(n); } @@ -376,8 +376,8 @@ class StoragePlanRewriter : public StmtExprMutator { RemapIndex(op->value.dtype(), op->index, it->second), op->predicate); } - Expr VisitExpr_(const LoadNode* op) final { - Expr expr = StmtExprMutator::VisitExpr_(op); + PrimExpr VisitExpr_(const LoadNode* op) final { + PrimExpr expr = StmtExprMutator::VisitExpr_(op); op = expr.as(); auto it = alloc_map_.find(op->buffer_var.get()); if (it == alloc_map_.end()) return expr; @@ -386,7 +386,7 @@ class StoragePlanRewriter : public StmtExprMutator { RemapIndex(op->dtype, op->index, it->second), op->predicate); } - Expr VisitExpr_(const VarNode* op) final { + PrimExpr VisitExpr_(const VarNode* op) final { auto it = alloc_map_.find(op); if (it != alloc_map_.end()) { if (it->second->bits_offset != 0) { @@ -394,10 +394,10 @@ class StoragePlanRewriter : public StmtExprMutator { } return it->second->alloc_var; } else { - return GetRef(op); + return GetRef(op); } } - Expr VisitExpr_(const CallNode* op) final { + PrimExpr VisitExpr_(const CallNode* op) final { if (op->is_intrinsic(intrinsic::tvm_access_ptr)) { CHECK_EQ(op->args.size(), 5U); DataType dtype = op->args[0].dtype(); @@ -407,8 +407,8 @@ class StoragePlanRewriter : public StmtExprMutator { return StmtExprMutator::VisitExpr_(op); } const StorageEntry* se = it->second; - Expr offset = this->VisitExpr(op->args[2]); - Expr extent = this->VisitExpr(op->args[3]); + PrimExpr offset = this->VisitExpr(op->args[2]); + PrimExpr extent = this->VisitExpr(op->args[3]); uint64_t elem_bits = dtype.bits() * dtype.lanes(); CHECK_EQ(se->bits_offset % elem_bits, 0U); if (se->bits_offset != 0) { @@ -488,7 +488,7 @@ class StoragePlanRewriter : public StmtExprMutator { // The replacement allocation, if any. Stmt new_alloc; // The var expr of new allocation. - VarExpr alloc_var; + Var alloc_var; // The allocation element type. DataType elem_type; // This is non-zero if this allocate is folded into another one @@ -529,7 +529,7 @@ class StoragePlanRewriter : public StmtExprMutator { return MergeNest(nest, body); } // Remap the index - Expr RemapIndex(DataType dtype, Expr index, StorageEntry* e) { + PrimExpr RemapIndex(DataType dtype, PrimExpr index, StorageEntry* e) { if (e->bits_offset == 0) return index; uint64_t elem_bits = dtype.bits() * dtype.lanes(); CHECK_EQ(e->bits_offset % elem_bits, 0U); @@ -577,7 +577,7 @@ class StoragePlanRewriter : public StmtExprMutator { } if (e->allocs.size() == 1) { // simply use the original allocation. - Expr sz = arith::ComputeReduce(e->allocs[0]->extents, + PrimExpr sz = arith::ComputeReduce(e->allocs[0]->extents, make_const(DataType::Int(32), 1)); e->new_alloc = AllocateNode::make( e->alloc_var, alloc_type, {sz}, @@ -590,9 +590,10 @@ class StoragePlanRewriter : public StmtExprMutator { } } else { // Build a merged allocation - Expr combo_size; + PrimExpr combo_size; for (const AllocateNode* op : e->allocs) { - Expr sz = arith::ComputeReduce(op->extents, make_const(DataType::Int(32), 1)); + PrimExpr sz = arith::ComputeReduce( + op->extents, make_const(DataType::Int(32), 1)); auto nbits = op->dtype.bits() * op->dtype.lanes(); if (const auto* imm = sz.as()) { if (imm->value > std::numeric_limits::max() / nbits) { @@ -663,7 +664,7 @@ class StoragePlanRewriter : public StmtExprMutator { } } uint64_t type_bits = e->elem_type.bits() * e->elem_type.lanes(); - Expr alloc_size = make_const(e->allocs[0]->extents[0].dtype(), + PrimExpr alloc_size = make_const(e->allocs[0]->extents[0].dtype(), (total_bits + type_bits - 1) / type_bits); e->new_alloc = AllocateNode::make( e->alloc_var, e->elem_type, {alloc_size}, const_true(), @@ -936,7 +937,7 @@ class StoragePlanRewriter : public StmtExprMutator { // if all its access is the same vector type. class VectorAllocRewriter : public StmtExprMutator { public: - Expr VisitExpr_(const LoadNode* op) final { + PrimExpr VisitExpr_(const LoadNode* op) final { UpdateTypeMap(op->buffer_var.get(), op->dtype); return StmtExprMutator::VisitExpr_(op); } @@ -945,7 +946,7 @@ class VectorAllocRewriter : public StmtExprMutator { UpdateTypeMap(op->buffer_var.get(), op->value.dtype()); return StmtExprMutator::VisitStmt_(op); } - Expr VisitExpr_(const CallNode* op) final { + PrimExpr VisitExpr_(const CallNode* op) final { if (op->is_intrinsic(intrinsic::tvm_access_ptr)) { DataType dtype = op->args[0].dtype(); const VarNode* buffer = op->args[1].as(); @@ -964,7 +965,7 @@ class VectorAllocRewriter : public StmtExprMutator { tvec[0].lanes() % op->dtype.lanes() == 0 && tvec[0].lanes() != op->dtype.lanes()) { int factor = tvec[0].lanes() / op->dtype.lanes(); - Array extents = op->extents; + Array extents = op->extents; arith::ModularSet me = analyzer_.modular_set(extents[extents.size() - 1]); if (me->base % factor == 0 && me->coeff % factor == 0) { extents.Set(extents.size() - 1, @@ -999,13 +1000,13 @@ LoweredFunc PointerValueTypeRewrite(LoweredFunc f) { if (arg.dtype().is_handle()) { const auto& tvec = rewriter.acc_map_[arg.get()]; if (tvec.size() == 1) { - Expr dtype = make_const(tvec[0], 0); + PrimExpr dtype = make_const(tvec[0], 0); n->handle_data_type.Set(arg, dtype); } else { // always set data type to be non vectorized so // load/store can still work via scalarization if (tvec.size() != 0 && !n->handle_data_type.count(arg)) { - Expr dtype = make_const(tvec[0].with_lanes(1), 0); + PrimExpr dtype = make_const(tvec[0].with_lanes(1), 0); n->handle_data_type.Set(arg, dtype); } } diff --git a/src/pass/storage_sync.cc b/src/pass/storage_sync.cc index 7edf98b8007a..2358ce999231 100644 --- a/src/pass/storage_sync.cc +++ b/src/pass/storage_sync.cc @@ -222,7 +222,7 @@ class ThreadSyncInserter : public StmtExprMutator { return StmtExprMutator::VisitStmt(stmt); } } - Expr VisitExpr_(const LoadNode* op) final { + PrimExpr VisitExpr_(const LoadNode* op) final { if (sync_scope_.rank == StorageRank::kGlobal && GetScope(op->buffer_var.get()).rank == StorageRank::kGlobal) { ++rw_stats_[op->buffer_var].read_count; @@ -247,8 +247,8 @@ class ThreadSyncInserter : public StmtExprMutator { // first thread scope. if (!in_thread_env_ && sync_scope_.rank == StorageRank::kGlobal) { ret = InitGlobalBarrier(ret.as()); - num_blocks_ = Expr(); - is_lead_ = Expr(); + num_blocks_ = PrimExpr(); + is_lead_ = PrimExpr(); } return ret; } else if (op->attr_key == attr::storage_scope) { @@ -261,9 +261,9 @@ class ThreadSyncInserter : public StmtExprMutator { } } - Expr VisitExpr_(const CallNode* op) final { + PrimExpr VisitExpr_(const CallNode* op) final { if (op->is_intrinsic(intrinsic::tvm_access_ptr)) { - Expr expr = StmtExprMutator::VisitExpr_(op); + PrimExpr expr = StmtExprMutator::VisitExpr_(op); op = expr.as(); CHECK_EQ(op->args.size(), 5U); const VarNode* buffer_var = op->args[1].as(); @@ -300,7 +300,7 @@ class ThreadSyncInserter : public StmtExprMutator { // private functions. Stmt InitGlobalBarrier(const AttrStmtNode* op) { CHECK(op != nullptr); - Array pargs = {StringImmNode::make(runtime::symbol::tvm_prepare_global_barrier)}; + Array pargs = {StringImmNode::make(runtime::symbol::tvm_prepare_global_barrier)}; Stmt prep = EvaluateNode::make( CallNode::make(DataType::Int(32), intrinsic::tvm_call_packed, pargs, CallNode::Intrinsic)); Stmt body = op->body; @@ -332,7 +332,7 @@ class ThreadSyncInserter : public StmtExprMutator { num_blocks_ = (num_blocks_.defined() ? attr->value * num_blocks_ : attr->value); } else if (s.rank == 1) { - Expr cond = iv->var == make_zero(iv->var.dtype()); + PrimExpr cond = iv->var == make_zero(iv->var.dtype()); is_lead_ = is_lead_.defined() ? (is_lead_ && cond) : cond; } } @@ -351,14 +351,14 @@ class ThreadSyncInserter : public StmtExprMutator { // The storage scope of each buffer std::unordered_map storage_scope_; // The read write statistics of storage - std::unordered_map rw_stats_; + std::unordered_map rw_stats_; // The statistics for global barrier bool in_thread_env_{false}; // memorized results std::vector thread_extents_; size_t num_work_dim_{0}; - Expr num_blocks_; - Expr is_lead_; + PrimExpr num_blocks_; + PrimExpr is_lead_; }; Stmt ThreadSync(Stmt stmt, std::string storage_scope) { diff --git a/src/pass/tensor_core.cc b/src/pass/tensor_core.cc index b2658d9d7a99..bb57fe8c37d3 100644 --- a/src/pass/tensor_core.cc +++ b/src/pass/tensor_core.cc @@ -59,14 +59,14 @@ std::string simplify_name(std::string input) { } } -Expr unpack_type_cast(const Expr &input, const DataType &target_type) { +PrimExpr unpack_type_cast(const PrimExpr &input, const DataType &target_type) { auto cast = input.as(); if (cast == nullptr) { return input; } else if (cast->dtype == target_type) { return cast->value; } - return Expr(); + return PrimExpr(); } // MMAMatcher matches C = Cast(A)*Cast(B)+C, @@ -217,14 +217,14 @@ class MMAMatcher: public StmtVisitor { buf_name_.insert(std::make_pair(load_a, buffer_a.name)); buf_name_.insert(std::make_pair(load_b, buffer_b.name)); mma_sync_.insert(std::make_pair(op, - Array{load_a_expr, load_b_expr, add->a})); + Array{load_a_expr, load_b_expr, add->a})); return true; } std::unordered_map buf_map_; std::unordered_map storage_scope_; - std::unordered_map> mma_sync_; + std::unordered_map> mma_sync_; std::unordered_map buf_name_; std::unordered_set frag_reg_; bool matched_{false}; @@ -243,7 +243,7 @@ class BodyVisitor : public StmtExprVisitor { if (comm_add == nullptr || op->combiner->result.size() > 1) { return; } - for (Expr source : op->source) { + for (PrimExpr source : op->source) { auto mul_0 = unpack_type_cast(source, DataType::Float(32)).as(); auto mul_1 = unpack_type_cast(source, DataType::Int(32)).as(); if (mul_0 == nullptr && mul_1 == nullptr) { @@ -263,7 +263,7 @@ class BodyVisitor : public StmtExprVisitor { friend class ScheduleAnalyser; private: - std::unordered_map> args_; + std::unordered_map> args_; bool tensorcore_candidate_{false}; }; @@ -294,7 +294,7 @@ class ScheduleAnalyser { reduce_axis_var = reduce_axis[0]->var.as(); BodyVisitor body_visitor; - for (Expr expr : compute->body) { + for (PrimExpr expr : compute->body) { body_visitor(expr); } if (!body_visitor.tensorcore_candidate_) { @@ -347,7 +347,7 @@ class ScheduleAnalyser { if (it0->second == "matrix_a" && it1->second == "matrix_b") { return true; } else if (it0->second == "matrix_b" && it1->second == "matrix_a") { - mma_sync.second = Array{operands[1], operands[0], operands[2]}; + mma_sync.second = Array{operands[1], operands[0], operands[2]}; } else { return false; } @@ -361,7 +361,7 @@ class ScheduleAnalyser { private: std::unordered_map matrix_abc_; std::unordered_map matrix_major_; - std::unordered_map> mma_sync_; + std::unordered_map> mma_sync_; std::unordered_map buf_name_; }; @@ -457,12 +457,12 @@ class BufferAnalyser : public StmtExprVisitor { } } - Array strides; + Array strides; if (bi.strides.size() > 0) { strides = bi.strides; } else { for (size_t i = 1; i < bi.shape.size(); ++i) { - Expr stride = IntImmNode::make(DataType::Int(32), 1); + PrimExpr stride = IntImmNode::make(DataType::Int(32), 1); for (size_t j = bi.shape.size() - 1; j >= i; --j) { stride = MulNode::make(stride, bi.shape[j]); } @@ -473,7 +473,7 @@ class BufferAnalyser : public StmtExprVisitor { strides_.insert(std::make_pair(key.GetName(), strides)); if (frag_reg_.count(bi.name)) { - Expr dst = CallNode::make(bi.dtype, + PrimExpr dst = CallNode::make(bi.dtype, bi.name, op->args, CallNode::Halide, @@ -535,7 +535,7 @@ class BufferAnalyser : public StmtExprVisitor { const CallNode* value = op->value.as(); if (value != nullptr && frag_reg_.count(value->name)) { - Expr dst = CallNode::make(bi.dtype, + PrimExpr dst = CallNode::make(bi.dtype, bi.name, op->args, CallNode::Halide, @@ -570,12 +570,12 @@ class BufferAnalyser : public StmtExprVisitor { } } - Array strides; + Array strides; if (bi.strides.size() > 0) { strides = bi.strides; } else { for (size_t i = 1; i < bi.shape.size(); ++i) { - Expr stride = IntImmNode::make(DataType::Int(32), 1); + PrimExpr stride = IntImmNode::make(DataType::Int(32), 1); for (size_t j = bi.shape.size() - 1; j >= i; --j) { stride = MulNode::make(stride, bi.shape[j]); } @@ -616,22 +616,22 @@ class BufferAnalyser : public StmtExprVisitor { BufferInfo bi; bi.bounds = op->bounds; - Array shape; + Array shape; for (auto r : bi.bounds) { shape.push_back(r->extent); } - Array strides; + Array strides; if (dim_align_.count(key) != 0 && shape.size() != 0) { - std::vector rstrides; + std::vector rstrides; const std::vector& avec = dim_align_[key]; int first_dim = 0; - Expr stride = make_const(shape[first_dim].dtype(), 1); + PrimExpr stride = make_const(shape[first_dim].dtype(), 1); for (size_t i = shape.size(); i != 0; --i) { size_t dim = i - 1; if (dim < avec.size() && avec[dim].align_factor != 0) { - Expr factor = make_const(stride.dtype(), avec[dim].align_factor); - Expr offset = make_const(stride.dtype(), avec[dim].align_offset); + PrimExpr factor = make_const(stride.dtype(), avec[dim].align_factor); + PrimExpr offset = make_const(stride.dtype(), avec[dim].align_offset); stride = stride + \ indexmod(factor + offset - indexmod(stride, factor), factor); stride = ir::Simplify(stride); @@ -639,7 +639,7 @@ class BufferAnalyser : public StmtExprVisitor { rstrides.push_back(stride); stride = stride * shape[dim]; } - strides = Array(rstrides.rbegin(), rstrides.rend()); + strides = Array(rstrides.rbegin(), rstrides.rend()); } bi.name = key.GetName(); @@ -689,14 +689,14 @@ class BufferAnalyser : public StmtExprVisitor { struct BufferInfo { std::string name; DataType dtype; - Array strides; - Array shape; + Array strides; + Array shape; Region bounds; bool external{false}; bool released{false}; - inline Array RelIndex(Array args) const { + inline Array RelIndex(Array args) const { if (bounds.size() != 0) { - Array index; + Array index; CHECK_EQ(bounds.size(), args.size()); for (size_t i = 0; i < bounds.size(); ++i) { index.push_back(args[i] - bounds[i]->min); @@ -744,9 +744,9 @@ class BufferAnalyser : public StmtExprVisitor { std::unordered_map matrix_abc_; std::unordered_map matrix_major_; std::unordered_set frag_reg_; - std::unordered_map> strides_; - std::unordered_map frag_load_; - std::unordered_map frag_store_; + std::unordered_map> strides_; + std::unordered_map frag_load_; + std::unordered_map frag_store_; std::unordered_map thread_extent_; IndexVisitor index_visitor; Tile warp_tile_; @@ -758,19 +758,19 @@ class BufferAnalyser : public StmtExprVisitor { // ThreadIdxMutator does the thread index unification inside a warp class ThreadIdxMutator : public StmtExprMutator { public: - explicit ThreadIdxMutator(Expr warp_y): warp_y_(warp_y) {} + explicit ThreadIdxMutator(PrimExpr warp_y): warp_y_(warp_y) {} - Expr VisitExpr_(const VarNode* op) final { - Expr expr = StmtExprMutator::VisitExpr_(op); + PrimExpr VisitExpr_(const VarNode* op) final { + PrimExpr expr = StmtExprMutator::VisitExpr_(op); op = expr.as(); if (op != nullptr) { if (op->name_hint == "threadIdx.x") { - Expr zero = IntImmNode::make(DataType::Int(32), 0); + PrimExpr zero = IntImmNode::make(DataType::Int(32), 0); return zero; } if (op->name_hint == "threadIdx.y") { - Expr div = DivNode::make(expr, warp_y_); - Expr mul = MulNode::make(div, warp_y_); + PrimExpr div = DivNode::make(expr, warp_y_); + PrimExpr mul = MulNode::make(div, warp_y_); return mul; } } @@ -778,7 +778,7 @@ class ThreadIdxMutator : public StmtExprMutator { } private: - Expr warp_y_; + PrimExpr warp_y_; }; // TensorCoreIRMutator mutates the AST for TensorCore CodeGen @@ -856,11 +856,11 @@ class TensorCoreIRMutator : public StmtExprMutator { auto it = mma_sync_.find(op); if (it != mma_sync_.end()) { const auto &operands = it->second; - Expr a = operands[0]; + PrimExpr a = operands[0]; auto ca = a.as(); - Expr b = operands[1]; + PrimExpr b = operands[1]; auto cb = b.as(); - Expr c = operands[2]; + PrimExpr c = operands[2]; auto cc = c.as(); ObjectPtr buffer_node_a = make_object(); @@ -900,7 +900,7 @@ class TensorCoreIRMutator : public StmtExprMutator { auto it2 = frag_load_.find(op); if (it2 != frag_load_.end()) { - Expr dst = it2->second; + PrimExpr dst = it2->second; if (op->value.as() != nullptr || op->value.as() != nullptr) { auto call = dst.as(); @@ -931,19 +931,19 @@ class TensorCoreIRMutator : public StmtExprMutator { << "Cannot find stride for " << value->name; auto strides = it->second; CHECK_GE(strides.size(), 2); - Expr stride = strides[strides.size()-2]; + PrimExpr stride = strides[strides.size()-2]; // thread index unification inside a warp - Expr warp_y = IntImmNode::make(DataType::Int(32), warp_threads_y_); + PrimExpr warp_y = IntImmNode::make(DataType::Int(32), warp_threads_y_); ThreadIdxMutator thread_idx_mutator(warp_y); - Expr mutated_value = thread_idx_mutator(op->value); - Expr src = CallNode::make(value->dtype, + PrimExpr mutated_value = thread_idx_mutator(op->value); + PrimExpr src = CallNode::make(value->dtype, "&", {mutated_value}, CallNode::Extern); auto call = dst.as(); - Expr matrix_major; + PrimExpr matrix_major; auto iter2 = matrix_major_.find(simplify_name(call->name)); CHECK(iter2 != matrix_major_.end()) << "Can not determine matrix major for " << call->name; @@ -980,11 +980,11 @@ class TensorCoreIRMutator : public StmtExprMutator { << "Cannot find stride for " << key.GetName(); auto strides = it->second; CHECK_GE(strides.size(), 2); - Expr stride = strides[strides.size()-2]; + PrimExpr stride = strides[strides.size()-2]; - Expr dst = it3->second; + PrimExpr dst = it3->second; // thread index unification inside a warp - Expr warp_y = IntImmNode::make(DataType::Int(32), warp_threads_y_); + PrimExpr warp_y = IntImmNode::make(DataType::Int(32), warp_threads_y_); ThreadIdxMutator thread_idx_mutator(warp_y); dst = thread_idx_mutator(dst); dst = CallNode::make(DataType::Handle(), @@ -1027,7 +1027,7 @@ class TensorCoreIRMutator : public StmtExprMutator { int ori_extent_value = ori_extent->value; scaled_extent_value = ori_extent_value / scale_factor; } - Expr scaled_extent = make_const(op->extent.dtype(), scaled_extent_value); + PrimExpr scaled_extent = make_const(op->extent.dtype(), scaled_extent_value); stmt = ForNode::make(op->loop_var, op->min, scaled_extent, op->for_type, op->device_api, op->body); } @@ -1036,13 +1036,13 @@ class TensorCoreIRMutator : public StmtExprMutator { } private: - Array get_tile_size_(const std::string &name) { + Array get_tile_size_(const std::string &name) { auto it = matrix_abc_.find(name); auto it2 = matrix_major_.find(name); CHECK(it != matrix_abc_.end() && it2 != matrix_major_.end()) << "Cannot find matrix info for " << name; - Expr size0 = make_const(DataType::Int(32), 16); - Expr size1 = make_const(DataType::Int(32), 16); + PrimExpr size0 = make_const(DataType::Int(32), 16); + PrimExpr size1 = make_const(DataType::Int(32), 16); if (it->second == "matrix_a" && it2->second == "col_major") { size0 = make_const(DataType::Int(32), warp_tile_.k); size1 = make_const(DataType::Int(32), warp_tile_.m); @@ -1063,7 +1063,7 @@ class TensorCoreIRMutator : public StmtExprMutator { size0 = make_const(DataType::Int(32), warp_tile_.n); size1 = make_const(DataType::Int(32), warp_tile_.m); } - Array tile_size = {size0, size1}; + Array tile_size = {size0, size1}; return tile_size; } @@ -1073,13 +1073,13 @@ class TensorCoreIRMutator : public StmtExprMutator { DataType datatype) { auto it = bounds_.find(key); CHECK(it != bounds_.end()); - Array min_bound; + Array min_bound; for (auto i : it->second) { min_bound.push_back(i->min); } CHECK_GE(it->second.size(), 2); - Array shape; + Array shape; for (size_t i = 0; i < it->second.size() - 2; ++i) { shape.push_back(it->second[i]->extent); } @@ -1087,9 +1087,9 @@ class TensorCoreIRMutator : public StmtExprMutator { shape.push_back(tile_size[0]); shape.push_back(tile_size[1]); - Array strides; + Array strides; for (size_t i = 1; i < shape.size(); ++i) { - Expr stride = IntImmNode::make(DataType::Int(32), 1); + PrimExpr stride = IntImmNode::make(DataType::Int(32), 1); for (size_t j = shape.size() - 1; j >= i; --j) { stride = MulNode::make(stride, shape[j]); } @@ -1097,7 +1097,7 @@ class TensorCoreIRMutator : public StmtExprMutator { } strides.push_back(make_const(DataType::Int(32), 1)); - Expr elem_offset = IntImmNode::make(DataType::Int(32), 0); + PrimExpr elem_offset = IntImmNode::make(DataType::Int(32), 0); CHECK_EQ(call->args.size(), min_bound.size()); for (size_t i = 0; i < min_bound.size(); i++) { elem_offset = AddNode::make( @@ -1126,7 +1126,7 @@ class TensorCoreIRMutator : public StmtExprMutator { tensor_node->dtype = datatype; Tensor tensor(tensor_node); - Array args; + Array args; for (size_t i = 0; i < call->args.size(); ++i) { args.push_back(call->args[i]); args.push_back(shape[i]); @@ -1144,12 +1144,12 @@ class TensorCoreIRMutator : public StmtExprMutator { std::unordered_map matrix_abc_; std::unordered_map matrix_major_; - std::unordered_map> mma_sync_; - std::unordered_map> strides_; + std::unordered_map> mma_sync_; + std::unordered_map> strides_; std::unordered_set frag_reg_; std::unordered_map loop_scaling_; - std::unordered_map frag_load_; - std::unordered_map frag_store_; + std::unordered_map frag_load_; + std::unordered_map frag_store_; std::unordered_map bounds_; Tile warp_tile_; int warp_threads_y_{-1}; diff --git a/src/pass/unroll_loop.cc b/src/pass/unroll_loop.cc index e2e7ad08c475..b2c50f7a8bd2 100644 --- a/src/pass/unroll_loop.cc +++ b/src/pass/unroll_loop.cc @@ -143,7 +143,7 @@ class LoopUnroller : public StmtExprMutator { CHECK_NE(value, -1) << "loop doesn't have a constant integer extent"; if (value == 0) return EvaluateNode::make(0); Stmt body = op->body; - Map vmap; + Map vmap; Array unrolled; for (int i = 0; i < value; ++i) { vmap.Set(op->loop_var, op->min + make_const(op->loop_var.dtype(), i)); @@ -157,7 +157,7 @@ class LoopUnroller : public StmtExprMutator { // returns the extent of the loop if it's a constant integer, otherwise return -1 int GetExtent(const ForNode* op) { // constant folding. - Expr extent = ir::Simplify(op->extent); + PrimExpr extent = ir::Simplify(op->extent); const IntImmNode *v1 = extent.as(); const UIntImmNode *v2 = extent.as(); int value = -1; diff --git a/src/pass/vectorize_loop.cc b/src/pass/vectorize_loop.cc index 450c6bab117c..c9f3441e6c47 100644 --- a/src/pass/vectorize_loop.cc +++ b/src/pass/vectorize_loop.cc @@ -33,7 +33,7 @@ namespace tvm { namespace ir { -inline Expr BroadcastTo(Expr e, int lanes) { +inline PrimExpr BroadcastTo(PrimExpr e, int lanes) { if (e.dtype().lanes() == lanes) return e; if (const BroadcastNode* op = e.as()) { if (lanes % op->lanes == 0) { @@ -59,8 +59,8 @@ class VecAllocAccess : public StmtExprMutator { VecAllocAccess(const VarNode* buf, Var var, int var_lanes) : buf_(buf), var_(var), var_lanes_(var_lanes) {} // Load - Expr VisitExpr_(const LoadNode* op) final { - Expr expr = StmtExprMutator::VisitExpr_(op); + PrimExpr VisitExpr_(const LoadNode* op) final { + PrimExpr expr = StmtExprMutator::VisitExpr_(op); op = expr.as(); if (op->buffer_var.get() == buf_) { return LoadNode::make(op->dtype, op->buffer_var, @@ -111,18 +111,18 @@ class Vectorizer : public StmtExprMutator { } } - Expr VisitExpr_(const AddNode* op) final { + PrimExpr VisitExpr_(const AddNode* op) final { return AddSubVec(op); } - Expr VisitExpr_(const SubNode* op) final { + PrimExpr VisitExpr_(const SubNode* op) final { return AddSubVec(op); } - Expr VisitExpr_(const MulNode* op) final { - Expr a = this->VisitExpr(op->a); - Expr b = this->VisitExpr(op->b); + PrimExpr VisitExpr_(const MulNode* op) final { + PrimExpr a = this->VisitExpr(op->a); + PrimExpr b = this->VisitExpr(op->b); if (a.same_as(op->a) && b.same_as(op->b)) { - return GetRef(op); + return GetRef(op); } else { int lanes = std::max(a.dtype().lanes(), b.dtype().lanes()); if (lanes != 1) { @@ -141,51 +141,51 @@ class Vectorizer : public StmtExprMutator { } return BinaryVec(op); } - Expr VisitExpr_(const DivNode* op) final { + PrimExpr VisitExpr_(const DivNode* op) final { return BinaryVec(op); } - Expr VisitExpr_(const ModNode* op) final { + PrimExpr VisitExpr_(const ModNode* op) final { return BinaryVec(op); } - Expr VisitExpr_(const FloorDivNode* op) final { + PrimExpr VisitExpr_(const FloorDivNode* op) final { return BinaryVec(op); } - Expr VisitExpr_(const FloorModNode* op) final { + PrimExpr VisitExpr_(const FloorModNode* op) final { return BinaryVec(op); } - Expr VisitExpr_(const MinNode* op) final { + PrimExpr VisitExpr_(const MinNode* op) final { return BinaryVec(op); } - Expr VisitExpr_(const MaxNode* op) final { + PrimExpr VisitExpr_(const MaxNode* op) final { return BinaryVec(op); } - Expr VisitExpr_(const EQNode* op) final { + PrimExpr VisitExpr_(const EQNode* op) final { return BinaryVec(op); } - Expr VisitExpr_(const NENode* op) final { + PrimExpr VisitExpr_(const NENode* op) final { return BinaryVec(op); } - Expr VisitExpr_(const LTNode* op) final { + PrimExpr VisitExpr_(const LTNode* op) final { return BinaryVec(op); } - Expr VisitExpr_(const LENode* op) final { + PrimExpr VisitExpr_(const LENode* op) final { return BinaryVec(op); } - Expr VisitExpr_(const GTNode* op) final { + PrimExpr VisitExpr_(const GTNode* op) final { return BinaryVec(op); } - Expr VisitExpr_(const GENode* op) final { + PrimExpr VisitExpr_(const GENode* op) final { return BinaryVec(op); } - Expr VisitExpr_(const AndNode* op) final { + PrimExpr VisitExpr_(const AndNode* op) final { return BinaryVec(op); } - Expr VisitExpr_(const OrNode* op) final { + PrimExpr VisitExpr_(const OrNode* op) final { return BinaryVec(op); } - Expr VisitExpr_(const RampNode* op) final { - Expr base = this->VisitExpr(op->base); - Expr stride = this->VisitExpr(op->stride); + PrimExpr VisitExpr_(const RampNode* op) final { + PrimExpr base = this->VisitExpr(op->base); + PrimExpr stride = this->VisitExpr(op->stride); if (base.dtype().lanes() > 1 && stride.dtype().lanes() == 1) { const RampNode* base_ramp = base.as(); if (analyzer_.CanProve(base_ramp->stride == stride * make_const(stride.dtype(), op->lanes))) { @@ -195,7 +195,7 @@ class Vectorizer : public StmtExprMutator { int lanes = std::max(base.dtype().lanes(), stride.dtype().lanes()); base = BroadcastTo(base, lanes); stride = BroadcastTo(stride, lanes); - Array elems; + Array elems; for (int i = 0; i < lanes; ++i) { elems.push_back( RampNode::make(ShuffleNode::make_extract_element(base, i), @@ -204,14 +204,14 @@ class Vectorizer : public StmtExprMutator { } return ShuffleNode::make_concat(elems); } - Expr VisitExpr_(const SelectNode *op) final { - Expr cond = this->VisitExpr(op->condition); - Expr t = this->VisitExpr(op->true_value); - Expr f = this->VisitExpr(op->false_value); + PrimExpr VisitExpr_(const SelectNode *op) final { + PrimExpr cond = this->VisitExpr(op->condition); + PrimExpr t = this->VisitExpr(op->true_value); + PrimExpr f = this->VisitExpr(op->false_value); if (cond.same_as(op->condition) && t.same_as(op->true_value) && f.same_as(op->false_value)) { - return GetRef(op); + return GetRef(op); } else { int lanes = std::max(std::max( cond.dtype().lanes(), @@ -219,37 +219,37 @@ class Vectorizer : public StmtExprMutator { return SelectNode::make(cond, BroadcastTo(t, lanes), BroadcastTo(f, lanes)); } } - Expr VisitExpr_(const CastNode *op) final { - Expr value = this->VisitExpr(op->value); + PrimExpr VisitExpr_(const CastNode *op) final { + PrimExpr value = this->VisitExpr(op->value); if (value.same_as(op->value)) { - return GetRef(op); + return GetRef(op); } else { return CastNode::make(op->dtype.with_lanes(value.dtype().lanes()), value); } } // Variable - Expr VisitExpr_(const VarNode* v) final { + PrimExpr VisitExpr_(const VarNode* v) final { if (v == var_.get()) { return ramp_; } else if (lets_.count(v)) { return lets_[v]; } else { - return GetRef(v); + return GetRef(v); } } // IfThenElse expr - Expr MutateIfThenElseExpr_(const CallNode *op) { - Expr cond = this->VisitExpr(op->args[0]); + PrimExpr MutateIfThenElseExpr_(const CallNode *op) { + PrimExpr cond = this->VisitExpr(op->args[0]); if (cond.dtype().is_vector()) { need_scalarize_ = true; - return GetRef(op); + return GetRef(op); } - Expr t = this->VisitExpr(op->args[1]); - Expr f = this->VisitExpr(op->args[2]); + PrimExpr t = this->VisitExpr(op->args[1]); + PrimExpr f = this->VisitExpr(op->args[2]); if (cond.same_as(op->args[0]) && t.same_as(op->args[1]) && f.same_as(op->args[2])) { - return GetRef(op); + return GetRef(op); } else { int lanes = std::max(t.dtype().lanes(), f.dtype().lanes()); t = BroadcastTo(t, lanes); @@ -260,33 +260,33 @@ class Vectorizer : public StmtExprMutator { } } // Call - Expr VisitExpr_(const CallNode* op) final { + PrimExpr VisitExpr_(const CallNode* op) final { if (op->name == intrinsic::tvm_if_then_else) { return MutateIfThenElseExpr_(op); } if (!op->is_vectorizable()) { // Cannot vectorize this op - Array new_args; + Array new_args; for (auto arg : op->args) { auto new_arg = this->VisitExpr(arg); if (new_arg.dtype().is_vector()) { need_scalarize_ = true; - return GetRef(op); + return GetRef(op); } new_args.push_back(new_arg); } if (op->args.same_as(new_args)) { - return GetRef(op); + return GetRef(op); } else { return CallNode::make( op->dtype, op->name, new_args, op->call_type, op->func, op->value_index); } } else { int lane = 0; - Array new_args = MutateArray(op->args, &lane); + Array new_args = MutateArray(op->args, &lane); // normal code path. if (op->args.same_as(new_args)) { - return GetRef(op); + return GetRef(op); } else { return CallNode::make( op->dtype.with_lanes(lane), op->name, new_args, @@ -295,11 +295,11 @@ class Vectorizer : public StmtExprMutator { } } // Load - Expr VisitExpr_(const LoadNode* op) final { - Expr index = this->VisitExpr(op->index); - Expr pred = this->VisitExpr(op->predicate); + PrimExpr VisitExpr_(const LoadNode* op) final { + PrimExpr index = this->VisitExpr(op->index); + PrimExpr pred = this->VisitExpr(op->predicate); if (index.same_as(op->index) && pred.same_as(op->predicate)) { - return GetRef(op); + return GetRef(op); } else { int lanes = std::max(index.dtype().lanes(), pred.dtype().lanes()); return LoadNode::make( @@ -310,18 +310,18 @@ class Vectorizer : public StmtExprMutator { } } // Let - Expr VisitExpr_(const LetNode* op) final { - Expr value = this->VisitExpr(op->value); + PrimExpr VisitExpr_(const LetNode* op) final { + PrimExpr value = this->VisitExpr(op->value); CHECK(!lets_.count(op->var.get())) << "not SSA"; if (value.dtype().lanes() != op->value.dtype().lanes()) { Var v(op->var->name_hint, value.dtype()); lets_[op->var.get()] = v; return LetNode::make(v, value, this->VisitExpr(op->body)); } else { - Expr body = this->VisitExpr(op->body); + PrimExpr body = this->VisitExpr(op->body); if (value.same_as(op->value) && body.same_as(op->body)) { - return GetRef(op); + return GetRef(op); } else { return LetNode::make(op->var, value, body); } @@ -329,9 +329,9 @@ class Vectorizer : public StmtExprMutator { } // Provide Stmt VisitStmt_(const ProvideNode* op) final { - Expr new_value = this->VisitExpr(op->value); + PrimExpr new_value = this->VisitExpr(op->value); int lane = new_value.dtype().lanes(); - Array new_args = MutateArray(op->args, &lane); + Array new_args = MutateArray(op->args, &lane); if (op->args.same_as(new_args) && op->value.same_as(new_value)) { return GetRef(op); } else { @@ -341,9 +341,9 @@ class Vectorizer : public StmtExprMutator { } // Store Stmt VisitStmt_(const StoreNode* op) final { - Expr value = this->VisitExpr(op->value); - Expr index = this->VisitExpr(op->index); - Expr pred = this->VisitExpr(op->predicate); + PrimExpr value = this->VisitExpr(op->value); + PrimExpr index = this->VisitExpr(op->index); + PrimExpr pred = this->VisitExpr(op->predicate); if (value.same_as(op->value) && index.same_as(op->index)) { return GetRef(op); } else { @@ -362,7 +362,7 @@ class Vectorizer : public StmtExprMutator { } CHECK(is_zero(op->min)); CHECK(!op->extent.dtype().is_vector()); - Expr extent = this->VisitExpr(op->extent); + PrimExpr extent = this->VisitExpr(op->extent); if (extent.dtype().is_vector()) { return Scalarize(GetRef(op)); } @@ -379,7 +379,7 @@ class Vectorizer : public StmtExprMutator { // IfThenElse Stmt VisitStmt_(const IfThenElseNode* op) final { CHECK(!op->condition.dtype().is_vector()); - Expr condition = this->VisitExpr(op->condition); + PrimExpr condition = this->VisitExpr(op->condition); if (condition.dtype().is_vector()) { return Scalarize(GetRef(op)); } @@ -407,14 +407,14 @@ class Vectorizer : public StmtExprMutator { LOG(WARNING) << "Cannot vectorize with new expr"; return Scalarize(GetRef(op)); } - Expr condition = this->VisitExpr(op->condition); + PrimExpr condition = this->VisitExpr(op->condition); if (condition.dtype().is_vector()) { LOG(WARNING) << "Cannot handle vector extent in alloc "; return Scalarize(GetRef(op)); } - Array extents; + Array extents; for (size_t i = 0; i < op->extents.size(); i++) { - Expr new_ext = this->VisitExpr(op->extents[i]); + PrimExpr new_ext = this->VisitExpr(op->extents[i]); if (new_ext.dtype().is_vector()) { LOG(WARNING) << "Cannot handle vector extent in alloc "; return Scalarize(GetRef(op)); @@ -435,7 +435,7 @@ class Vectorizer : public StmtExprMutator { // scalarize the statment Stmt Scalarize(Stmt stmt) { Var idx(var_->name_hint + ".s", var_->dtype); - Map values{{var_, idx}}; + Map values{{var_, idx}}; stmt = Substitute(stmt, values); return ForNode::make(idx, 0, var_lanes_, ForType::Serial, DeviceAPI::None, stmt); } @@ -448,21 +448,21 @@ class Vectorizer : public StmtExprMutator { // the lanes. int var_lanes_; // ramp representing the var. - Expr ramp_; + PrimExpr ramp_; // flag to mark requirment of scalarization. bool need_scalarize_{false}; // The lets - std::unordered_map lets_; + std::unordered_map lets_; // mutate array, with given lane requirement // when finished, p_lane updates the lane requirement. - Array MutateArray(Array arr, int* p_lanes) { + Array MutateArray(Array arr, int* p_lanes) { if (arr.size() == 0) return arr; int& lanes = *p_lanes; bool changed = false; - std::vector new_arr(arr.size()); + std::vector new_arr(arr.size()); for (size_t i = 0; i < arr.size(); i++) { - Expr old_elem = arr[i]; - Expr new_elem = this->VisitExpr(old_elem); + PrimExpr old_elem = arr[i]; + PrimExpr new_elem = this->VisitExpr(old_elem); if (!new_elem.same_as(old_elem)) changed = true; new_arr[i] = new_elem; lanes = std::max(lanes, new_elem.dtype().lanes()); @@ -475,27 +475,27 @@ class Vectorizer : public StmtExprMutator { } } if (!changed) return arr; - return Array(new_arr); + return Array(new_arr); } template - Expr BinaryVec(const T* op) { - Expr a = this->VisitExpr(op->a); - Expr b = this->VisitExpr(op->b); + PrimExpr BinaryVec(const T* op) { + PrimExpr a = this->VisitExpr(op->a); + PrimExpr b = this->VisitExpr(op->b); if (a.same_as(op->a) && b.same_as(op->b)) { - return GetRef(op); + return GetRef(op); } else { int lanes = std::max(a.dtype().lanes(), b.dtype().lanes()); return T::make(BroadcastTo(a, lanes), BroadcastTo(b, lanes)); } } template - Expr AddSubVec(const T* op) { - Expr a = this->VisitExpr(op->a); - Expr b = this->VisitExpr(op->b); + PrimExpr AddSubVec(const T* op) { + PrimExpr a = this->VisitExpr(op->a); + PrimExpr b = this->VisitExpr(op->b); if (a.same_as(op->a) && b.same_as(op->b)) { - return GetRef(op); + return GetRef(op); } else { int lanes = std::max(a.dtype().lanes(), b.dtype().lanes()); if (lanes != 1) { diff --git a/src/pass/verify_gpu_code.cc b/src/pass/verify_gpu_code.cc index 96f231e31cf4..24f3e19356d4 100644 --- a/src/pass/verify_gpu_code.cc +++ b/src/pass/verify_gpu_code.cc @@ -100,7 +100,7 @@ class GPUCodeVerifier : public StmtVisitor { visited_shared_buffers_.insert(op->node.as()); } } else if (op->attr_key == attr::thread_extent) { - VarExpr var = op->node.as()->var; + Var var = op->node.as()->var; const auto *extent = op->value.as(); CHECK(extent); @@ -169,7 +169,7 @@ class GPUCodeVerifier : public StmtVisitor { }; bool VerifyGPUCode(Stmt stmt, - Map constraints) { + Map constraints) { GPUCodeVerifier verifier; int64_t max_local_memory_per_block = INT64_MAX; diff --git a/src/pass/verify_memory.cc b/src/pass/verify_memory.cc index 25e7258813f8..899e9bc3c435 100644 --- a/src/pass/verify_memory.cc +++ b/src/pass/verify_memory.cc @@ -65,7 +65,7 @@ class MemoryAccessVerifier final : protected StmtExprVisitor { protected: /// Visitor implementation //@{ - void VisitExpr(const Expr &n) final { + void VisitExpr(const PrimExpr &n) final { if (Failed()) return; StmtExprVisitor::VisitExpr(n); } @@ -130,7 +130,7 @@ class MemoryAccessVerifier final : protected StmtExprVisitor { } /// Handle memory access to a Variable - void HandleLoadStoreToVariable(const VarExpr &var) { + void HandleLoadStoreToVariable(const Var &var) { // We skip the access within thread env. if (InThreadEnv()) return; @@ -181,7 +181,7 @@ class MemoryAccessVerifier final : protected StmtExprVisitor { //@} LoweredFunc func_{nullptr}; ///< Function to be verified. int dev_type_{kDLCPU}; ///< Device type - std::unordered_map defs_; ///< Variable definitions + std::unordered_map defs_; ///< Variable definitions }; } // namespace diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index 69731ea9a02a..1883558f50f5 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -83,7 +83,7 @@ struct GraphCodegen { std::unordered_map GetParams() { std::unordered_map ret; - auto names = CallFunc >("list_params_name", nullptr); + auto names = CallFunc >("list_params_name", nullptr); for (auto expr : names) { auto key = expr.as()->value; ret[key] = CallFunc("get_param_by_name", key); @@ -190,8 +190,8 @@ class RelayBuildModule : public runtime::ModuleNode { * * \return Array names of params */ - Array ListParamNames() { - Array ret; + Array ListParamNames() { + Array ret; for (const auto& kv : params_) { ret.push_back(ir::StringImmNode::make(kv.first)); } diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc index 6c511aeda73c..62de1c36fc45 100644 --- a/src/relay/backend/compile_engine.cc +++ b/src/relay/backend/compile_engine.cc @@ -203,7 +203,7 @@ class ScheduleGetter : return make_const(dtype, static_cast(data)[0]); } else { LOG(FATAL) << "not handled"; - return tvm::Expr(); + return tvm::PrimExpr(); } }, "compile_engine_const", topi::kBroadcast); scalars_.push_back(value->op); @@ -479,7 +479,7 @@ class MakeShapeFunc : public ExprFunctor(const Expr&)> { return make_const(dtype, static_cast(data)[0]); } else { LOG(FATAL) << "not handled"; - return tvm::Expr(); + return tvm::PrimExpr(); } }, "data_const", topi::kBroadcast); scalars_.push_back(value); diff --git a/src/relay/backend/graph_runtime_codegen.cc b/src/relay/backend/graph_runtime_codegen.cc index 618e13522902..3ff72b3cc086 100644 --- a/src/relay/backend/graph_runtime_codegen.cc +++ b/src/relay/backend/graph_runtime_codegen.cc @@ -641,9 +641,9 @@ class GraphRuntimeCodegenModule : public runtime::ModuleNode { }); } else if (name == "list_params_name") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - Array ret; + Array ret; for (const auto &kv : this->output_.params) { - tvm::Expr name = ir::StringImmNode::make(kv.first); + tvm::PrimExpr name = ir::StringImmNode::make(kv.first); ret.push_back(name); } *rv = ret; diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index 7946ee66007c..5d262a09a84d 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -50,7 +50,7 @@ namespace transform { Pass LambdaLift(); Pass InlinePrimitives(); -Pass RemoveUnusedFunctions(Array entry_functions); +Pass RemoveUnusedFunctions(Array entry_functions); Pass ManifestAlloc(Target target_host) { auto f = tvm::runtime::Registry::Get("relay.transform.ManifestAlloc"); @@ -875,7 +875,7 @@ void VMCompiler::Lower(Module mod, Module VMCompiler::OptimizeModule(const Module& mod, const TargetsMap& targets) { Array pass_seqs; - Array entry_functions{tvm::Expr{"main"}}; + Array entry_functions{tvm::PrimExpr{"main"}}; pass_seqs.push_back(transform::RemoveUnusedFunctions(entry_functions)); // Run all dialect legalization passes. pass_seqs.push_back(relay::qnn::transform::Legalize()); diff --git a/src/relay/backend/vm/removed_unused_funcs.cc b/src/relay/backend/vm/removed_unused_funcs.cc index cea611579f4d..5de2e9283f4a 100644 --- a/src/relay/backend/vm/removed_unused_funcs.cc +++ b/src/relay/backend/vm/removed_unused_funcs.cc @@ -100,7 +100,7 @@ struct CallTracer : ExprVisitor { * \return The module with dead functions removed. */ Module RemoveUnusedFunctions(const Module& module, - Array entry_funcs) { + Array entry_funcs) { std::unordered_set called_funcs{}; for (auto entry : entry_funcs) { auto* str_name = entry.as(); @@ -121,7 +121,7 @@ Module RemoveUnusedFunctions(const Module& module, namespace transform { -Pass RemoveUnusedFunctions(Array entry_functions) { +Pass RemoveUnusedFunctions(Array entry_functions) { runtime::TypedPackedFunc pass_func = [=](Module m, PassContext pc) { return relay::vm::RemoveUnusedFunctions(m, entry_functions); diff --git a/src/relay/ir/expr.cc b/src/relay/ir/expr.cc index a6f44cee6b16..f6ebadf477eb 100644 --- a/src/relay/ir/expr.cc +++ b/src/relay/ir/expr.cc @@ -51,7 +51,7 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) TensorType ConstantNode::tensor_type() const { auto dtype = DataType(data->dtype); - Array shape; + Array shape; for (int i = 0; i < data->ndim; i++) { CHECK_LE(data->shape[i], std::numeric_limits::max()); CHECK_GE(data->shape[i], std::numeric_limits::min()); diff --git a/src/relay/ir/hash.cc b/src/relay/ir/hash.cc index d179d7ebb849..3bc72fd22d50 100644 --- a/src/relay/ir/hash.cc +++ b/src/relay/ir/hash.cc @@ -127,7 +127,7 @@ class RelayHashHandler: using AttrsHashHandler::VisitAttr_; size_t VisitAttr_(const tvm::VarNode* var) final { size_t hash = std::hash()(VarNode::_type_key); - auto it = hash_map_.find(GetRef(var)); + auto it = hash_map_.find(GetRef(var)); if (it != hash_map_.end()) { return it->second; } diff --git a/src/relay/ir/op.cc b/src/relay/ir/op.cc index b83750f124d4..b888ecbd9241 100644 --- a/src/relay/ir/op.cc +++ b/src/relay/ir/op.cc @@ -137,10 +137,10 @@ void OpRegistry::UpdateAttr(const std::string& key, // Frontend APIs TVM_REGISTER_GLOBAL("relay.op._ListOpNames") .set_body_typed([]() { - Array ret; + Array ret; for (const std::string& name : dmlc::Registry::ListAllNames()) { - ret.push_back(tvm::Expr(name)); + ret.push_back(tvm::PrimExpr(name)); } return ret; }); diff --git a/src/relay/op/nn/bitserial.cc b/src/relay/op/nn/bitserial.cc index a8f2e8618a75..68aa77bd18bb 100644 --- a/src/relay/op/nn/bitserial.cc +++ b/src/relay/op/nn/bitserial.cc @@ -211,7 +211,7 @@ bool BinaryDenseRel(const Array& types, int num_inputs, const Attrs& attrs CHECK(static_cast(data->shape.size()) != 0); CHECK(param->units.defined()); - Array oshape = data->shape; + Array oshape = data->shape; oshape.Set((oshape.size() - 1), param->units); DataType out_dtype = param->out_dtype; diff --git a/src/relay/op/nn/nn.cc b/src/relay/op/nn/nn.cc index aeb40fdbe6da..35d6cba2747c 100644 --- a/src/relay/op/nn/nn.cc +++ b/src/relay/op/nn/nn.cc @@ -136,7 +136,7 @@ bool FIFOBufferRel(const Array& types, } reporter->Assert(input->shape[buffer_axis] < buffer->shape[buffer_axis]); - Array oshape = buffer->shape; + Array oshape = buffer->shape; reporter->Assign(types[2], TensorTypeNode::make(oshape, buffer->dtype)); return true; @@ -877,7 +877,7 @@ bool BatchMatmulRel(const Array& types, << " x shape=" << x->shape << ", y shape=" << y->shape; - Array oshape = x->shape; + Array oshape = x->shape; oshape.Set(2, y->shape[1]); // assign output type diff --git a/src/relay/op/nn/nn.h b/src/relay/op/nn/nn.h index 110e4353982e..1b27dea28825 100644 --- a/src/relay/op/nn/nn.h +++ b/src/relay/op/nn/nn.h @@ -42,9 +42,9 @@ bool DenseRel(const Array& types, int num_inputs, const Attrs& attrs, CHECK(static_cast(data->shape.size()) != 0); - Array oshape = data->shape; + Array oshape = data->shape; if (param->units.defined()) { - Array dshape = data->shape; + Array dshape = data->shape; // validate the weight shape is proper if defined // Assign weight type Array wshape({param->units, dshape[dshape.size() - 1]}); @@ -56,7 +56,7 @@ bool DenseRel(const Array& types, int num_inputs, const Attrs& attrs, oshape.Set((oshape.size() - 1), param->units); } else { if (weight == nullptr) return false; - Array wshape = weight->shape; + Array wshape = weight->shape; oshape.Set((oshape.size() - 1), wshape[0]); } diff --git a/src/relay/op/nn/pad.cc b/src/relay/op/nn/pad.cc index f9d753fca04a..72ea70f82eb8 100644 --- a/src/relay/op/nn/pad.cc +++ b/src/relay/op/nn/pad.cc @@ -52,7 +52,7 @@ Array > PadInferCorrectLayout( // split. // 1) Create a map from axis to param_width using old layout. - std::map> axis_pad_width; + std::map> axis_pad_width; int index_counter = 0; CHECK_EQ(new_in_layouts.size(), 1); CHECK_EQ(old_in_layouts.size(), 1); @@ -63,7 +63,7 @@ Array > PadInferCorrectLayout( } // 2) Create new pad width by walking over the new layout and using the map. - tvm::Array> new_pad_width; + tvm::Array> new_pad_width; for (auto iter_var : new_in_layouts[0]->axes) { const auto& new_layout_axis = LayoutAxis::Get(iter_var); auto axis_name = new_layout_axis.name(); diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index b8ee7e762acf..c9d824d70b83 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -1223,7 +1223,7 @@ inline Tensor DynamicArange(const tvm::Tensor& start, tvm::DataType dtype, std::string name = "tensor", std::string tag = topi::kInjective) { - tvm::Expr num_elem = tvm::Var("num_elem"); + tvm::PrimExpr num_elem = tvm::Var("num_elem"); return tvm::compute({num_elem}, [&](const Array& indices) { return tvm::cast(dtype, start[0] + step[0] * indices[0]); }, name, tag); @@ -1237,7 +1237,7 @@ Array ArangeCompute(const Attrs& attrs, Tensor start = inputs[0]; Tensor stop = inputs[1]; Tensor step = inputs[2]; - Array empty = {0}; + Array empty = {0}; return { DynamicArange(start, stop, step, param->dtype) }; } diff --git a/src/relay/pass/infer_layout_util.h b/src/relay/pass/infer_layout_util.h index c5202b597991..b2cef6c12f70 100644 --- a/src/relay/pass/infer_layout_util.h +++ b/src/relay/pass/infer_layout_util.h @@ -44,7 +44,7 @@ namespace relay { * \return The adjusted Layout. */ inline Layout AdjustSubordinateFactors(const Layout& src_layout, const Layout& old_layout, - const Array& old_shape) { + const Array& old_shape) { // For each subordinate axis // 1) Find the corresponding dual axis. // 2) Find the Index of this dual axis in old_layout. diff --git a/src/relay/pass/pass_manager.cc b/src/relay/pass/pass_manager.cc index fea463d24099..f5c65e57cddd 100644 --- a/src/relay/pass/pass_manager.cc +++ b/src/relay/pass/pass_manager.cc @@ -269,7 +269,7 @@ class SequentialNode : public PassNode { PassInfo PassInfoNode::make(int opt_level, std::string name, - tvm::Array required) { + tvm::Array required) { auto pass_info = make_object(); pass_info->opt_level = opt_level; pass_info->name = std::move(name); @@ -370,7 +370,7 @@ void SequentialNode::ResolveDependency(const Module& mod) { } // linearly scan the pass array to match pass_name -inline bool PassArrayContains(const Array& pass_array, +inline bool PassArrayContains(const Array& pass_array, const std::string& pass_name) { for (auto x : pass_array) { auto* str_name = x.as(); @@ -428,7 +428,7 @@ Pass CreateModulePass( const runtime::TypedPackedFunc& pass_func, int opt_level, const std::string& name, - const tvm::Array& required) { + const tvm::Array& required) { PassInfo pass_info = PassInfoNode::make(opt_level, name, required); return ModulePassNode::make(pass_func, pass_info); } @@ -437,7 +437,7 @@ Pass CreateFunctionPass( const runtime::TypedPackedFunc& pass_func, int opt_level, const std::string& name, - const tvm::Array& required) { + const tvm::Array& required) { PassInfo pass_info = PassInfoNode::make(opt_level, name, required); return FunctionPassNode::make(pass_func, pass_info); } @@ -507,7 +507,7 @@ TVM_REGISTER_GLOBAL("relay._transform.Sequential") tvm::Array passes = args[0]; int opt_level = args[1]; std::string name = args[2]; - tvm::Array required = args[3]; + tvm::Array required = args[3]; PassInfo pass_info = PassInfoNode::make(opt_level, name, required); *ret = Sequential(passes, pass_info); }); @@ -533,8 +533,8 @@ TVM_REGISTER_GLOBAL("relay._transform.PassContext") auto pctx = PassContext::Create(); int opt_level = args[0]; int fallback_device = args[1]; - tvm::Array required = args[2]; - tvm::Array disabled = args[3]; + tvm::Array required = args[2]; + tvm::Array disabled = args[3]; pctx->opt_level = opt_level; pctx->fallback_device = fallback_device; pctx->required_pass = std::move(required); diff --git a/src/relay/pass/type_solver.cc b/src/relay/pass/type_solver.cc index 1f47b20dbbe3..c62520a53e84 100644 --- a/src/relay/pass/type_solver.cc +++ b/src/relay/pass/type_solver.cc @@ -202,7 +202,7 @@ class TypeSolver::Unifier : public TypeFunctor { return ulhs; } - return tvm::Expr(); + return tvm::PrimExpr(); } Type VisitType_(const TensorTypeNode* op, const Type& tn) final { @@ -243,8 +243,8 @@ class TypeSolver::Unifier : public TypeFunctor { if (!dim.defined()) { // NB: We push an arbitrary dimension here so we can continue error propogation. shape.push_back(tt1->shape[i]); - tvm::Expr shape1 = tt1->shape[i]; - tvm::Expr shape2 = tt2->shape[i]; + tvm::PrimExpr shape1 = tt1->shape[i]; + tvm::PrimExpr shape2 = tt2->shape[i]; std::tuple tuple = std::make_tuple(i, shape1, shape2); mismatches.push_back(tuple); } else { diff --git a/src/relay/qnn/op/dequantize.cc b/src/relay/qnn/op/dequantize.cc index 6579c3d26aa4..e01a47dc5f1b 100644 --- a/src/relay/qnn/op/dequantize.cc +++ b/src/relay/qnn/op/dequantize.cc @@ -50,7 +50,7 @@ bool DequantizeRel(const Array& types, CHECK(IsScalarType(types[1], DataType::Float(32))); // input_scale CHECK(IsScalarType(types[2], DataType::Int(32))); // input_zero_point - const Array oshape = data->shape; + const Array oshape = data->shape; // assign output type, output will always be float 32. reporter->Assign(types[3], TensorTypeNode::make(oshape, DataType::Float(32))); return true; diff --git a/src/relay/qnn/op/quantize.cc b/src/relay/qnn/op/quantize.cc index f6133b8736de..f53d2c5ee438 100644 --- a/src/relay/qnn/op/quantize.cc +++ b/src/relay/qnn/op/quantize.cc @@ -57,7 +57,7 @@ bool QuantizeRel(const Array& types, AssignType(types[1], DataType::Float(32), data->shape[axis], reporter); // scale AssignType(types[2], DataType::Int(32), data->shape[axis], reporter); // zero point - const Array oshape = data->shape; + const Array oshape = data->shape; const DataType out_dtype = quantize_attrs->out_dtype; CHECK(out_dtype == DataType::Int(8) || out_dtype == DataType::UInt(8) || out_dtype == DataType::Int(32)) diff --git a/src/relay/qnn/op/requantize.cc b/src/relay/qnn/op/requantize.cc index a8a919589839..2686965e7b62 100644 --- a/src/relay/qnn/op/requantize.cc +++ b/src/relay/qnn/op/requantize.cc @@ -190,7 +190,7 @@ bool RequantizeRel(const Array& types, int num_inputs, const Attrs& attrs, CHECK(IsScalarType(types[3], DataType::Float(32))); // output_scale CHECK(IsScalarType(types[4], DataType::Int(32))); // output_zero_point - const Array oshape = data->shape; + const Array oshape = data->shape; // assign output type auto out_dtype = requantize_attrs->out_dtype; CHECK(out_dtype == DataType::Int(8) || diff --git a/src/relay/qnn/util.h b/src/relay/qnn/util.h index e4b5cf1134d7..378a5e3728f4 100644 --- a/src/relay/qnn/util.h +++ b/src/relay/qnn/util.h @@ -94,7 +94,7 @@ static inline Expr Requantize(const Expr& data, const Array& input_sh attrs.operator->(), input_shape, out_dtype); } -static inline int64_t get_const_int(const tvm::Expr& x) { +static inline int64_t get_const_int(const tvm::PrimExpr& x) { auto* value_ptr = as_const_int(x); CHECK(value_ptr) << "Expr is not a constant int"; return value_ptr[0]; diff --git a/src/schedule/auto_inline_elem_wise.cc b/src/schedule/auto_inline_elem_wise.cc index 6c8df8bd9080..3e3292303e9d 100644 --- a/src/schedule/auto_inline_elem_wise.cc +++ b/src/schedule/auto_inline_elem_wise.cc @@ -33,13 +33,13 @@ class ElemWiseDetector : public ir::ExprVisitor { public: explicit ElemWiseDetector(Array axis) : axis_(axis) {} - void VisitExpr(const Expr& e) final { + void VisitExpr(const PrimExpr& e) final { if (!is_elem_wise_) return; ExprVisitor::VisitExpr(e); } void VisitExpr_(const CallNode* op) final { - Array axis = op->args; + Array axis = op->args; if (axis_.size() != axis.size()) { is_elem_wise_ = false; return; diff --git a/src/schedule/graph.cc b/src/schedule/graph.cc index 82ee8ff68413..3cf30f4ea7e2 100644 --- a/src/schedule/graph.cc +++ b/src/schedule/graph.cc @@ -300,7 +300,7 @@ Array ScanGetBody(const Operation& scan_op) { return GetSubGraph(scan->update, inputs, false); } -Map ScanFixPointAnalysis(const Operation& scan_op) { +Map ScanFixPointAnalysis(const Operation& scan_op) { const ScanOpNode* scan = scan_op.as(); Array body = ScanGetBody(scan_op); @@ -377,7 +377,7 @@ Map ScanFixPointAnalysis(const Operation& scan_op) { } } ReachGraph reach; - Map ret; + Map ret; std::unordered_set place_holder_ref; for (size_t i = 0; i < scan->state_placeholder.size(); ++i) { for (size_t k = 0; k < scan->state_placeholder[i]->shape.size(); ++k) { diff --git a/src/schedule/graph.h b/src/schedule/graph.h index d596335bbbc1..99ba6e38580a 100644 --- a/src/schedule/graph.h +++ b/src/schedule/graph.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -123,7 +123,7 @@ Array ScanGetBody(const Operation& scan_op); * \param scan The scan node. * \return Map of spatial_axis -> IntImm */ -Map ScanFixPointAnalysis(const Operation& scan); +Map ScanFixPointAnalysis(const Operation& scan); } // namespace schedule } // namespace tvm diff --git a/src/schedule/message_passing.cc b/src/schedule/message_passing.cc index d08b4bec5849..869e3051d39f 100644 --- a/src/schedule/message_passing.cc +++ b/src/schedule/message_passing.cc @@ -55,7 +55,7 @@ void PassDownDomain(const Stage& stage, std::unordered_map* p_state, arith::Analyzer* actx, bool allow_missing) { - auto ceil_div = [actx](Expr a, Expr b) { + auto ceil_div = [actx](PrimExpr a, PrimExpr b) { if (actx->CanProve(indexmod(a, b) == 0)) { return actx->Simplify(indexdiv(a, b)); } @@ -118,7 +118,7 @@ void PassDownDomain(const Stage& stage, void PassUpIndex(const Stage& stage, const Map& dom_map, - std::unordered_map* p_state, + std::unordered_map* p_state, bool allow_missing) { auto& state = *p_state; for (size_t i = stage->relations.size(); i != 0; --i) { @@ -128,10 +128,10 @@ void PassUpIndex(const Stage& stage, CHECK(allow_missing); continue; } - Expr outer = state.at(s->outer); - Expr inner = state.at(s->inner); - Expr factor = dom_map.at(s->inner)->extent; - Expr parent_min = dom_map.at(s->parent)->min; + PrimExpr outer = state.at(s->outer); + PrimExpr inner = state.at(s->inner); + PrimExpr factor = dom_map.at(s->inner)->extent; + PrimExpr parent_min = dom_map.at(s->parent)->min; state[s->parent] = inner + outer * factor; // add min if they exist if (!is_zero(parent_min)) { @@ -142,10 +142,10 @@ void PassUpIndex(const Stage& stage, CHECK(allow_missing); continue; } - Expr value = state.at(s->fused); - Expr factor = dom_map.at(s->inner)->extent; - Expr outer_min = dom_map.at(s->outer)->min; - Expr inner_min = dom_map.at(s->inner)->min; + PrimExpr value = state.at(s->fused); + PrimExpr factor = dom_map.at(s->inner)->extent; + PrimExpr outer_min = dom_map.at(s->outer)->min; + PrimExpr inner_min = dom_map.at(s->inner)->min; state[s->outer] = indexdiv(value, factor); state[s->inner] = indexmod(value, factor); // add min if they exist @@ -160,8 +160,8 @@ void PassUpIndex(const Stage& stage, CHECK(allow_missing); continue; } - Expr value = state.at(s->rebased); - Expr parent_min = dom_map.at(s->parent)->min; + PrimExpr value = state.at(s->rebased); + PrimExpr parent_min = dom_map.at(s->parent)->min; // add min if they exist if (!is_zero(parent_min)) { state[s->parent] = value + parent_min; @@ -177,7 +177,7 @@ void PassUpIndex(const Stage& stage, void PassDownIndex(const Stage& stage, const Map& dom_map, - std::unordered_map* p_state, + std::unordered_map* p_state, bool allow_missing) { auto& state = *p_state; for (IterVarRelation rel : stage->relations) { @@ -188,8 +188,8 @@ void PassDownIndex(const Stage& stage, } Range r = dom_map.at(s->inner); CHECK(is_zero(r->min)); - Expr parent = state.at(s->parent); - Expr factor = r->extent; + PrimExpr parent = state.at(s->parent); + PrimExpr factor = r->extent; state[s->outer] = indexdiv(parent, factor); state[s->inner] = indexmod(parent, factor); } else if (const FuseNode* s = rel.as()) { @@ -197,11 +197,11 @@ void PassDownIndex(const Stage& stage, CHECK(allow_missing); continue; } - Expr factor = dom_map.at(s->inner)->extent; - Expr outer_min = dom_map.at(s->outer)->min; - Expr inner_min = dom_map.at(s->inner)->min; - Expr inner = state.at(s->inner); - Expr outer = state.at(s->outer); + PrimExpr factor = dom_map.at(s->inner)->extent; + PrimExpr outer_min = dom_map.at(s->outer)->min; + PrimExpr inner_min = dom_map.at(s->inner)->min; + PrimExpr inner = state.at(s->inner); + PrimExpr outer = state.at(s->outer); CHECK(is_zero(outer_min)); CHECK(is_zero(inner_min)); state[s->fused] = outer * factor + inner; @@ -210,8 +210,8 @@ void PassDownIndex(const Stage& stage, CHECK(allow_missing); continue; } - Expr value = state.at(s->parent); - Expr parent_min = dom_map.at(s->parent)->min; + PrimExpr value = state.at(s->parent); + PrimExpr parent_min = dom_map.at(s->parent)->min; CHECK(is_zero(parent_min)); state[s->rebased] = value; } else if (const SingletonNode* s = rel.as()) { @@ -236,8 +236,8 @@ void PassUpDomain(const SplitNode* s, *parent = IntSet::range(dom_map.at(s->parent)); return; } - Expr factor = dom_map.at(s->inner)->extent; - Expr parent_min = dom_map.at(s->parent)->min; + PrimExpr factor = dom_map.at(s->inner)->extent; + PrimExpr parent_min = dom_map.at(s->parent)->min; CHECK(outer.defined()); CHECK(inner.defined()); CHECK(factor.defined()); @@ -260,21 +260,21 @@ void PassUpDomain(const FuseNode* s, *inner = IntSet::range(dom_map.at(s->inner)); return; } - Expr outer_min = dom_map.at(s->outer)->min; - Expr inner_min = dom_map.at(s->inner)->min; + PrimExpr outer_min = dom_map.at(s->outer)->min; + PrimExpr inner_min = dom_map.at(s->inner)->min; if (fused.is_single_point()) { - Expr value = fused.point_value(); - Expr factor = dom_map.at(s->inner)->extent; - Expr v_outer = indexdiv(value, factor); - Expr v_inner = indexmod(value, factor); + PrimExpr value = fused.point_value(); + PrimExpr factor = dom_map.at(s->inner)->extent; + PrimExpr v_outer = indexdiv(value, factor); + PrimExpr v_inner = indexmod(value, factor); if (!is_zero(outer_min)) v_outer = v_outer + outer_min; if (!is_zero(inner_min)) v_inner = v_inner + inner_min; *outer = IntSet::single_point(v_outer); *inner = IntSet::single_point(v_inner); } else { - Expr fused_extent = (fused.max() - fused.min() + 1); - Expr inner_extent = dom_map.at(s->inner)->extent; + PrimExpr fused_extent = (fused.max() - fused.min() + 1); + PrimExpr inner_extent = dom_map.at(s->inner)->extent; *outer = IntSet::interval( outer_min + indexdiv(fused.min(), inner_extent), outer_min + indexdiv(fused.max(), inner_extent)); @@ -305,7 +305,7 @@ void PassUpDomain(const RebaseNode* s, *parent = IntSet::range(dom_map.at(s->parent)); return; } - Expr parent_min = dom_map.at(s->parent)->min; + PrimExpr parent_min = dom_map.at(s->parent)->min; *parent = arith::EvalSet(s->rebased->var + parent_min, {{s->rebased, rebased}}); } @@ -458,8 +458,8 @@ void PassUpBoundCheck(const Stage& s, bool inner = state.at(s->inner); if (dom_map.count(s->inner) && dom_map.count(s->outer)) { - Expr factor = dom_map.at(s->inner)->extent; - Expr step = dom_map.at(s->outer)->extent; + PrimExpr factor = dom_map.at(s->inner)->extent; + PrimExpr step = dom_map.at(s->outer)->extent; if (outer || inner) { state[s->parent] = true; } else { @@ -486,10 +486,10 @@ void PassUpBoundCheck(const Stage& s, } } -std::vector MakeBoundCheck( +std::vector MakeBoundCheck( const Stage& stage, const Map& dom_map, - const std::unordered_map& value_map, + const std::unordered_map& value_map, bool skip_ivar_domain, const std::unordered_set& skip_iter) { arith::Analyzer analyzer; @@ -500,7 +500,7 @@ std::vector MakeBoundCheck( } PassUpBoundCheck(stage, dom_map, &bound_state, &analyzer); - std::vector preds; + std::vector preds; std::unordered_map iset_dmap; // setup domain map for set analysis @@ -512,8 +512,8 @@ std::vector MakeBoundCheck( if (skip_iter.count(iv) || iv->iter_type == kOpaque) continue; if (bound_state.at(iv)) { Range dom = dom_map.at(iv); - Expr value = value_map.at(iv) - dom->min; - Expr vmax = EvalSet(value, iset_dmap).max(); + PrimExpr value = value_map.at(iv) - dom->min; + PrimExpr vmax = EvalSet(value, iset_dmap).max(); if (vmax.dtype() != value.dtype() || !analyzer.CanProve(vmax < dom->extent)) { preds.emplace_back(value < dom->extent); } @@ -524,10 +524,10 @@ std::vector MakeBoundCheck( Range dom = dom_map.at(iv); CHECK(iv->dom.defined()); if (!skip_ivar_domain && !iv->dom.same_as(dom)) { - Expr value = value_map.at(iv) - iv->dom->min; + PrimExpr value = value_map.at(iv) - iv->dom->min; IntSet s = EvalSet(value, iset_dmap); - Expr vmin = s.min(); - Expr vmax = s.max(); + PrimExpr vmin = s.min(); + PrimExpr vmax = s.max(); // The range of `value` resides in [vmin, vmax] if (vmin.dtype() != value.dtype() || !analyzer.CanProve(vmin >= 0)) { preds.emplace_back(value >= 0); diff --git a/src/schedule/message_passing.h b/src/schedule/message_passing.h index f46f305b0f56..66615890a50c 100644 --- a/src/schedule/message_passing.h +++ b/src/schedule/message_passing.h @@ -62,7 +62,7 @@ void PassDownDomain( */ void PassUpIndex(const Stage& stage, const Map& dom_map, - std::unordered_map* p_state, + std::unordered_map* p_state, bool allow_missing = false); /*! @@ -76,7 +76,7 @@ void PassUpIndex(const Stage& stage, */ void PassDownIndex(const Stage& stage, const Map& dom_map, - std::unordered_map* p_state, + std::unordered_map* p_state, bool allow_missing = false); /*! @@ -120,11 +120,11 @@ void PassDownBitMaskOr(const Stage& stage, * \param skip_iter The set of variables to skip bound condition. * \return List of predicates that we need to check. */ -std::vector +std::vector MakeBoundCheck( const Stage& stage, const Map& dom_map, - const std::unordered_map& value_map, + const std::unordered_map& value_map, bool skip_ivar_domain, const std::unordered_set& skip_iter); diff --git a/src/schedule/schedule_dataflow_rewrite.cc b/src/schedule/schedule_dataflow_rewrite.cc index a6500caa33b9..3bad33811d78 100644 --- a/src/schedule/schedule_dataflow_rewrite.cc +++ b/src/schedule/schedule_dataflow_rewrite.cc @@ -45,20 +45,20 @@ size_t FindNodeRef(ArrayNode* array_node, const T& v) { class VarReplacer : public ir::StmtExprMutator { public: explicit VarReplacer( - const std::unordered_map& vsub) + const std::unordered_map& vsub) : vsub_(vsub) {} - Expr VisitExpr_(const VarNode* op) final { + PrimExpr VisitExpr_(const VarNode* op) final { auto it = vsub_.find(op); if (it != vsub_.end()) return it->second; - return GetRef(op); + return GetRef(op); } ir::CommReducer MutateCommReducer(ir::CommReducer combiner) { // Replace free variables in combiner - auto new_identity = ir::UpdateArray(combiner->identity_element, [this] (const Expr& e) { + auto new_identity = ir::UpdateArray(combiner->identity_element, [this] (const PrimExpr& e) { return this->VisitExpr(e); }); - auto new_result = ir::UpdateArray(combiner->result, [this] (const Expr& e) { + auto new_result = ir::UpdateArray(combiner->result, [this] (const PrimExpr& e) { return this->VisitExpr(e); }); @@ -71,8 +71,8 @@ class VarReplacer : public ir::StmtExprMutator { } } - Expr VisitExpr_(const ir::ReduceNode* op) final { - Expr new_e = StmtExprMutator::VisitExpr_(op); + PrimExpr VisitExpr_(const ir::ReduceNode* op) final { + PrimExpr new_e = StmtExprMutator::VisitExpr_(op); const ir::ReduceNode* new_reduce = new_e.as(); ir::CommReducer new_combiner = MutateCommReducer(op->combiner); if (op->combiner.same_as(new_combiner)) { @@ -88,21 +88,21 @@ class VarReplacer : public ir::StmtExprMutator { } private: - const std::unordered_map& vsub_; + const std::unordered_map& vsub_; }; -Expr InjectPredicate(const Array& predicates, - Expr body) { +PrimExpr InjectPredicate(const Array& predicates, + PrimExpr body) { using ir::ReduceNode; using ir::SelectNode; if (predicates.size() == 0) return body; const ReduceNode* reduce = body.as(); if (reduce) { auto n = make_object(*reduce); - n->condition = n->condition && arith::ComputeReduce(predicates, Expr()); - return Expr(n); + n->condition = n->condition && arith::ComputeReduce(predicates, PrimExpr()); + return PrimExpr(n); } - return SelectNode::make(arith::ComputeReduce(predicates, Expr()), + return SelectNode::make(arith::ComputeReduce(predicates, PrimExpr()), body, make_zero(body.dtype())); } @@ -153,7 +153,7 @@ Tensor Schedule::cache_read(const Tensor& tensor, Stage s = operator[](tensor->op); Tensor sugar_tensor = s->op.output(tensor->value_index); Tensor cache = compute(sugar_tensor->shape, [&sugar_tensor](const Array& i) { - return sugar_tensor(Array(i.begin(), i.end())); + return sugar_tensor(Array(i.begin(), i.end())); }, os.str()); vsub[sugar_tensor] = cache; @@ -193,9 +193,9 @@ void PrepareAxisMapping(Stage orig_stage, std::unordered_set* p_red_axis, Array* p_new_axis, std::unordered_map* p_dom_map, - std::unordered_map* p_vsub, - std::unordered_map* p_vsub2newvar, - std::vector* p_predicates) { + std::unordered_map* p_vsub, + std::unordered_map* p_vsub2newvar, + std::vector* p_predicates) { auto& red_axis = *p_red_axis; auto& new_axis = *p_new_axis; auto& dom_map = *p_dom_map; @@ -214,7 +214,7 @@ void PrepareAxisMapping(Stage orig_stage, schedule::PassDownDomain(orig_stage, &dom_map, &analyzer, true); { // The source->cache - std::unordered_map value_map; + std::unordered_map value_map; for (IterVar iv : orig_stage->leaf_iter_vars) { if (red_axis.count(iv)) continue; CHECK_EQ(iv->iter_type, kDataPar) @@ -305,15 +305,15 @@ Array CacheWriteWithReLayout(Schedule sch, Array new_axis; std::unordered_map dom_map; - std::unordered_map vsub; - std::unordered_map vsub2newvar; - std::vector predicates; + std::unordered_map vsub; + std::unordered_map vsub2newvar; + std::vector predicates; PrepareAxisMapping(orig_stage, compute, &red_axis, &new_axis, &dom_map, &vsub, &vsub2newvar, &predicates); - Expr body; - Array body_list; + PrimExpr body; + Array body_list; const ir::ReduceNode* first_reduce = nullptr; for (auto cbody : compute->body) { body = VarReplacer(vsub)(cbody); @@ -340,10 +340,10 @@ Array CacheWriteWithReLayout(Schedule sch, body_list.push_back(body); } // The reader args - Array args; + Array args; { // cache->compute - std::unordered_map value_map; + std::unordered_map value_map; for (IterVar iv : compute->axis) { value_map[iv] = iv->var; } @@ -357,7 +357,7 @@ Array CacheWriteWithReLayout(Schedule sch, compute->name + "." + scope, compute->tag, compute->attrs, new_axis, body_list); - Array cache_expr_list; + Array cache_expr_list; for (size_t i = 0; i < tensor_size; i++) { Tensor cache_tensor = cache_op.output(i); cache_expr_list.push_back(cache_tensor(args)); @@ -386,9 +386,9 @@ Array CacheWriteWithReLayoutTensor(Schedule sch, Array new_axis; std::unordered_map dom_map; - std::unordered_map vsub; - std::unordered_map vsub2newvar; - std::vector predicates; + std::unordered_map vsub; + std::unordered_map vsub2newvar; + std::vector predicates; PrepareAxisMapping(orig_stage, tensor_op, &red_axis, &new_axis, &dom_map, &vsub, &vsub2newvar, &predicates); @@ -404,15 +404,15 @@ Array CacheWriteWithReLayoutTensor(Schedule sch, for (Region old_region : tensor_op->input_regions) { Region region; for (Range r : old_region) { - Expr min = VarReplacer(vsub2newvar)(r->min); - Expr extent = VarReplacer(vsub2newvar)(r->extent); + PrimExpr min = VarReplacer(vsub2newvar)(r->min); + PrimExpr extent = VarReplacer(vsub2newvar)(r->extent); region.push_back(Range::make_by_min_extent(min, extent)); } new_regions.push_back(region); } - Array new_scalar_inputs; - for (Expr old_input : tensor_op->scalar_inputs) { + Array new_scalar_inputs; + for (PrimExpr old_input : tensor_op->scalar_inputs) { new_scalar_inputs.push_back(VarReplacer(vsub2newvar)(old_input)); } @@ -430,10 +430,10 @@ Array CacheWriteWithReLayoutTensor(Schedule sch, } // The reader args - Array args; + Array args; { // cache->compute - std::unordered_map value_map; + std::unordered_map value_map; for (IterVar iv : compute_axis) { value_map[iv] = iv->var; } @@ -449,7 +449,7 @@ Array CacheWriteWithReLayoutTensor(Schedule sch, } } - Array cache_expr_list; + Array cache_expr_list; for (size_t i = 0; i < tensor_size; i++) { Tensor cache_tensor = cache_op.output(i); cache_expr_list.push_back(cache_tensor(args)); @@ -542,7 +542,7 @@ void RebaseNonZeroMinLoop(const Schedule& sch) { void InjectInline(ScheduleNode* sch) { sch->InvalidateCache(); - std::vector > new_body(sch->stages.size()); + std::vector > new_body(sch->stages.size()); std::vector changed(sch->stages.size(), false); std::vector new_hybrid_body(sch->stages.size()); std::vector hybrid_changed(sch->stages.size(), false); @@ -552,7 +552,7 @@ void InjectInline(ScheduleNode* sch) { if (stage->attach_type == kInline) { stage->attach_type = kInlinedAlready; Array args; - Expr body; + PrimExpr body; { // setup args const ComputeOpNode* compute = stage->op.as(); @@ -583,7 +583,7 @@ void InjectInline(ScheduleNode* sch) { << "The Reduce inputs of ComputeOp should " << "have the same attribute except value_index"; } - Expr new_value = ir::Inline(ir::EvaluateNode::make(new_body[j][0]), + PrimExpr new_value = ir::Inline(ir::EvaluateNode::make(new_body[j][0]), stage->op, args, body).as()->value; if (!new_value.same_as(new_body[j][0])) { changed[j] = true; @@ -594,12 +594,12 @@ void InjectInline(ScheduleNode* sch) { auto n = make_object(*r); n->value_index = static_cast(k); n->dtype = r->source[k].dtype(); - new_body[j].Set(k, Expr(n)); + new_body[j].Set(k, PrimExpr(n)); } } } else { for (size_t k = 0; k < new_body[j].size(); ++k) { - Expr new_value = ir::Inline(ir::EvaluateNode::make(new_body[j][k]), + PrimExpr new_value = ir::Inline(ir::EvaluateNode::make(new_body[j][k]), stage->op, args, body).as()->value; if (!new_value.same_as(new_body[j][k])) { new_body[j].Set(k, new_value); @@ -706,7 +706,7 @@ Array Schedule::rfactor(const Tensor& tensor, arith::Analyzer analyzer; // Get the replace index std::unordered_map dom_map; - std::unordered_map value_map; + std::unordered_map value_map; for (IterVar iv : compute_op->reduce_axis) { if (touch_map.count(iv)) { dom_map[iv] = iv->dom; @@ -727,7 +727,7 @@ Array Schedule::rfactor(const Tensor& tensor, } } schedule::PassUpIndex(reduce_stage, dom_map, &value_map, true); - std::vector predicates = schedule::MakeBoundCheck( + std::vector predicates = schedule::MakeBoundCheck( reduce_stage, dom_map, value_map, true, skip_bound_check); // Get the factored op node. @@ -761,16 +761,16 @@ Array Schedule::rfactor(const Tensor& tensor, const ReduceNode* reduce = compute_op->body[idx].as(); CHECK(reduce) << "Can only rfactor non-inline reductions"; predicates.push_back(reduce->condition); - Expr predicate = likely(arith::ComputeReduce(predicates, Expr())); + PrimExpr predicate = likely(arith::ComputeReduce(predicates, PrimExpr())); - std::unordered_map vsub; + std::unordered_map vsub; for (IterVar iv : compute_op->reduce_axis) { if (!touch_map.count(iv)) { n->reduce_axis.push_back(iv); } else { CHECK(value_map.count(iv)); - Expr index = value_map.at(iv); + PrimExpr index = value_map.at(iv); vsub[iv->var.get()] = index; } } @@ -785,12 +785,12 @@ Array Schedule::rfactor(const Tensor& tensor, } } VarReplacer replacer(vsub); - Array new_source = ir::UpdateArray(reduce->source, - [&replacer] (const Expr& e) { return replacer(e); }); + Array new_source = ir::UpdateArray(reduce->source, + [&replacer] (const PrimExpr& e) { return replacer(e); }); - Expr new_pred = replacer(predicate); + PrimExpr new_pred = replacer(predicate); - std::vector body; + std::vector body; for (size_t idx = 0; idx < reduce->source.size(); ++idx) { body.emplace_back(ReduceNode::make(reduce->combiner, new_source, @@ -798,7 +798,7 @@ Array Schedule::rfactor(const Tensor& tensor, new_pred, idx)); } - n->body = Array(body); + n->body = Array(body); // refresh relations, keep the un-touched relations. Array rels; for (IterVarRelation rel : reduce_stage->relations) { @@ -842,7 +842,7 @@ Array Schedule::rfactor(const Tensor& tensor, } Array repl_tensors = compute(old_tensors[0]->shape, [&](const Array& i) { - Array indices; + Array indices; const int idx_size = static_cast(i.size()); for (int idx = 0; idx < idx_size; ++idx) { if (factor_axis_pos == idx) { @@ -853,13 +853,13 @@ Array Schedule::rfactor(const Tensor& tensor, if (factor_axis_pos == idx_size) { indices.push_back(repl_red_axis->var); } - Array factor_exprs; + Array factor_exprs; for (int idx = 0; idx < size; ++idx) { factor_exprs.push_back(factor_tensors[idx](indices)); } - Array reductions; + Array reductions; Array axis = {repl_red_axis}; - Expr cond = const_true(); + PrimExpr cond = const_true(); for (int idx = 0; idx < size; ++idx) { reductions.push_back(ReduceNode::make(reduce->combiner, factor_exprs, axis, cond, idx)); diff --git a/src/schedule/schedule_lang.cc b/src/schedule/schedule_lang.cc index a53c1ae79c64..fe56b662dfcd 100644 --- a/src/schedule/schedule_lang.cc +++ b/src/schedule/schedule_lang.cc @@ -55,8 +55,8 @@ size_t FindLeafVar(ArrayNode* all_vars, ArrayNode* leaf_vars, const IterVar& v) void Split(StageNode* self, IterVar parent, - Expr factor, - Expr nparts, + PrimExpr factor, + PrimExpr nparts, IterVar* p_outer, IterVar* p_inner) { // Check if split is valid. @@ -217,21 +217,21 @@ Stage& Stage::env_threads(Array threads) { return *this; } -Stage& Stage::set_store_predicate(Expr predicate) { +Stage& Stage::set_store_predicate(PrimExpr predicate) { StageNode* self = operator->(); self->store_predicate = predicate; return *this; } Stage& Stage::split( - IterVar parent, Expr factor, IterVar* p_outer, IterVar* p_inner) { // NOLINT(*) - Split(operator->(), parent, factor, Expr(), p_outer, p_inner); + IterVar parent, PrimExpr factor, IterVar* p_outer, IterVar* p_inner) { // NOLINT(*) + Split(operator->(), parent, factor, PrimExpr(), p_outer, p_inner); return *this; } Stage& Stage::split_by_nparts( - IterVar parent, Expr nparts, IterVar* p_outer, IterVar* p_inner) { // NOLINT(*) - Split(operator->(), parent, Expr(), nparts, p_outer, p_inner); + IterVar parent, PrimExpr nparts, IterVar* p_outer, IterVar* p_inner) { // NOLINT(*) + Split(operator->(), parent, PrimExpr(), nparts, p_outer, p_inner); return *this; } @@ -332,7 +332,7 @@ Stage& Stage::reorder(const Array& order) { // NOLINT(*) } Stage& Stage::tile(IterVar x_parent, IterVar y_parent, - Expr x_factor, Expr y_factor, + PrimExpr x_factor, PrimExpr y_factor, IterVar* p_x_outer, IterVar* p_y_outer, IterVar* p_x_inner, IterVar* p_y_inner) { split(x_parent, x_factor, p_x_outer, p_x_inner); @@ -400,7 +400,7 @@ Stage& Stage::parallel(IterVar var) { // NOLINT(*) Stage& Stage::pragma(IterVar var, const std::string& pragma_type, - const Expr& pragma_value) { // NOLINT(*) + const PrimExpr& pragma_value) { // NOLINT(*) if (pragma_type == "unroll") { this->unroll(var); } else if (pragma_type == "vectorize") { @@ -415,7 +415,7 @@ Stage& Stage::pragma(IterVar var, return *this; } -Stage& Stage::prefetch(const Tensor &tensor, IterVar var, Expr offset) { +Stage& Stage::prefetch(const Tensor &tensor, IterVar var, PrimExpr offset) { StageNode *self = operator->(); ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite(); ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite(); @@ -756,8 +756,8 @@ Schedule ScheduleNode::make(Array ops) { IterVarRelation SplitNode::make(IterVar parent, IterVar outer, IterVar inner, - Expr factor, - Expr nparts) { + PrimExpr factor, + PrimExpr nparts) { auto n = make_object(); n->parent = parent; n->outer = outer; diff --git a/src/schedule/schedule_ops.cc b/src/schedule/schedule_ops.cc index 38174dfdf50d..1eb595c852c0 100644 --- a/src/schedule/schedule_ops.cc +++ b/src/schedule/schedule_ops.cc @@ -277,13 +277,13 @@ class SchedulePostProc : public StmtExprMutator { } } - Expr VisitExpr_(const CallNode* op) final { + PrimExpr VisitExpr_(const CallNode* op) final { if (op->call_type == CallNode::Halide) { TensorKey key{op->func, op->value_index}; auto it = replace_buffer_.find(key); if (it != replace_buffer_.end()) { const Tensor& dst = it->second; - Expr ret = CallNode::make( + PrimExpr ret = CallNode::make( op->dtype, dst->op->name, op->args, op->call_type, dst->op, dst->value_index); return this->VisitExpr(ret); @@ -292,12 +292,12 @@ class SchedulePostProc : public StmtExprMutator { return StmtExprMutator::VisitExpr_(op); } - Expr VisitExpr_(const VarNode* op) final { + PrimExpr VisitExpr_(const VarNode* op) final { auto it = var_value_.find(op); if (it != var_value_.end()) { return it->second; } else { - return GetRef(op); + return GetRef(op); } } @@ -343,9 +343,9 @@ class SchedulePostProc : public StmtExprMutator { replace_op_[src->op.get()] = repl_op; } // The thread extent scope. - std::unordered_map thread_extent_scope_; + std::unordered_map thread_extent_scope_; // The scan value - std::unordered_map var_value_; + std::unordered_map var_value_; // buffer replacement std::unordered_map replace_buffer_; // buffere realization to be replaced diff --git a/tests/cpp/attrs_test.cc b/tests/cpp/attrs_test.cc index a6010c357922..9a242578d456 100644 --- a/tests/cpp/attrs_test.cc +++ b/tests/cpp/attrs_test.cc @@ -28,7 +28,7 @@ namespace test { struct TestAttrs : public AttrsNode { int axis; std::string name; - Expr expr; + PrimExpr expr; double learning_rate; TVM_DECLARE_ATTRS(TestAttrs, "attrs.cpptest.TestAttrs") { @@ -70,10 +70,10 @@ TEST(Attrs, Basic) { LOG(FATAL) << "bad"; } catch (const tvm::AttrError& e) { std::string what = e.what(); - CHECK(what.find("expr : Expr, default=1") != std::string::npos); + CHECK(what.find("expr : PrimExpr, default=1") != std::string::npos); CHECK(what.find("axisx") != std::string::npos); } - n->InitBySeq("learning_rate", Expr(1), "expr", 128, "name", "xx"); + n->InitBySeq("learning_rate", PrimExpr(1), "expr", 128, "name", "xx"); CHECK_EQ(n->learning_rate, 1.0); n->InitBySeq("name", "xxx", "expr", 128); @@ -84,7 +84,7 @@ TEST(Attrs, Basic) { std::ostringstream os; n->PrintDocString(os); LOG(INFO) << "docstring\n"<< os.str(); - CHECK(os.str().find("expr : Expr, default=1") != std::string::npos); + CHECK(os.str().find("expr : PrimExpr, default=1") != std::string::npos); } diff --git a/tests/cpp/build_module_test.cc b/tests/cpp/build_module_test.cc index 699865754b13..15bdd474f803 100644 --- a/tests/cpp/build_module_test.cc +++ b/tests/cpp/build_module_test.cc @@ -31,13 +31,13 @@ TEST(BuildModule, Basic) { using namespace tvm; auto n = var("n"); - Array shape; + Array shape; shape.push_back(n); auto A = placeholder(shape, DataType::Float(32), "A"); auto B = placeholder(shape, DataType::Float(32), "B"); - auto C = compute(A->shape, [&A, &B](Expr i) { + auto C = compute(A->shape, [&A, &B](PrimExpr i) { return A[i] + B[i]; }, "C"); @@ -88,18 +88,18 @@ TEST(BuildModule, Heterogeneous) { // The shape of input tensors. const int n = 4; - Array shape{n}; + Array shape{n}; auto A = placeholder(shape, DataType::Float(32), "A"); auto B = placeholder(shape, DataType::Float(32), "B"); auto C = placeholder(shape, DataType::Float(32), "C"); - auto elemwise_add = compute(A->shape, [&A, &B](Expr i) { + auto elemwise_add = compute(A->shape, [&A, &B](PrimExpr i) { return A[i] + B[i]; }, "elemwise_add"); auto copy = placeholder(shape, DataType::Float(32), "__copy"); - auto elemwise_sub = compute(C->shape, [©, &C](Expr i) { + auto elemwise_sub = compute(C->shape, [©, &C](PrimExpr i) { return copy[i] - C[i]; }, "elemwise_sub"); diff --git a/tests/cpp/container_test.cc b/tests/cpp/container_test.cc index 3d7c35560161..d5d8aae9f687 100644 --- a/tests/cpp/container_test.cc +++ b/tests/cpp/container_test.cc @@ -144,11 +144,11 @@ TEST(InplaceArrayBase, ExceptionSafety) { ASSERT_EXIT(correct_init(), ::testing::ExitedWithCode(0), ""); } -TEST(Array, Expr) { +TEST(Array, PrimExpr) { using namespace tvm; Var x("x"); auto z = max(x + 1 + 2, 100); - Array list{x, z, z}; + Array list{x, z, z}; LOG(INFO) << list.size(); LOG(INFO) << list[0]; LOG(INFO) << list[1]; @@ -158,7 +158,7 @@ TEST(Array, Mutate) { using namespace tvm; Var x("x"); auto z = max(x + 1 + 2, 100); - Array list{x, z, z}; + Array list{x, z, z}; auto list2 = list; list.Set(1, x); CHECK(list[1].same_as(x)); @@ -167,8 +167,8 @@ TEST(Array, Mutate) { TEST(Array, Iterator) { using namespace tvm; - Array array{1, 2, 3}; - std::vector vector(array.begin(), array.end()); + Array array{1, 2, 3}; + std::vector vector(array.begin(), array.end()); CHECK(vector[1].as()->value == 2); } @@ -177,7 +177,7 @@ TEST(Map, Expr) { Var x("x"); auto z = max(x + 1 + 2, 100); auto zz = z + 1; - Map dict{{x, z}, {z, 2}}; + Map dict{{x, z}, {z, 2}}; CHECK(dict.size() == 2); CHECK(dict[x].same_as(z)); CHECK(dict.count(z)); @@ -188,7 +188,7 @@ TEST(StrMap, Expr) { using namespace tvm; Var x("x"); auto z = max(x + 1 + 2, 100); - Map dict{{"x", z}, {"z", 2}}; + Map dict{{"x", z}, {"z", 2}}; CHECK(dict.size() == 2); CHECK(dict["x"].same_as(z)); } @@ -197,7 +197,7 @@ TEST(Map, Mutate) { using namespace tvm; Var x("x"); auto z = max(x + 1 + 2, 100); - Map dict{{x, z}, {z, 2}}; + Map dict{{x, z}, {z, 2}}; auto zz = z + 1; CHECK(dict[x].same_as(z)); dict.Set(x, zz); @@ -218,9 +218,9 @@ TEST(Map, Mutate) { TEST(Map, Iterator) { using namespace tvm; - Expr a = 1, b = 2; - Map map1{{a, b}}; - std::unordered_map + PrimExpr a = 1, b = 2; + Map map1{{a, b}}; + std::unordered_map map2(map1.begin(), map1.end()); CHECK(map2[a].as()->value == 2); } diff --git a/tests/cpp/expr_test.cc b/tests/cpp/expr_test.cc index 4b6915f7de93..d9b620063f56 100644 --- a/tests/cpp/expr_test.cc +++ b/tests/cpp/expr_test.cc @@ -26,7 +26,7 @@ TEST(Expr, Basic) { Var x("x"); auto z = max(x + 1 + 2, 100); ObjectRef tmp = z; - Expr zz = Downcast(tmp); + PrimExpr zz = Downcast(tmp); std::ostringstream os; os << z; CHECK(zz.same_as(z)); @@ -37,7 +37,7 @@ TEST(Expr, Basic) { TEST(ExprNodeRef, Basic) { using namespace tvm; Var x("x"); - Expr z = max(x + 1 + 2, 100); + PrimExpr z = max(x + 1 + 2, 100); const ir::MaxNode* op = z.as(); CHECK(GetRef(op).same_as(z)); } diff --git a/tests/cpp/ir_functor_test.cc b/tests/cpp/ir_functor_test.cc index 23a81b9f30ca..178f582b94f1 100644 --- a/tests/cpp/ir_functor_test.cc +++ b/tests/cpp/ir_functor_test.cc @@ -61,7 +61,7 @@ TEST(IRF, ExprTransform) { auto z = x + 1; class MyExprFunctor - : public ir::ExprFunctor { + : public ir::ExprFunctor { public: int VisitExpr_(const VarNode* op, int b) final { return b; @@ -90,7 +90,7 @@ TEST(IRF, ExprVisit) { auto z = x + 1; class MyVisitor - : public ir::ExprFunctor, + : public ir::ExprFunctor, public ir::StmtFunctor { public: int count = 0; @@ -152,13 +152,13 @@ TEST(IRF, StmtMutator) { protected: // implementation - Expr VisitExpr_(const AddNode* op) final { + PrimExpr VisitExpr_(const AddNode* op) final { return op->a; } Stmt VisitStmt_(const SeqStmtNode* op) final { return StmtMutator::VisitSeqStmt_(op, true); } - Expr VisitExpr(const Expr& expr) final { + PrimExpr VisitExpr(const PrimExpr& expr) final { return ExprMutator::VisitExpr(expr); } }; diff --git a/tests/cpp/ir_ssa_test.cc b/tests/cpp/ir_ssa_test.cc index 47cd0000a5ee..d1316dec7121 100644 --- a/tests/cpp/ir_ssa_test.cc +++ b/tests/cpp/ir_ssa_test.cc @@ -26,7 +26,7 @@ TEST(IRSSA, Convert) { using namespace tvm; using namespace tvm::ir; Var x("x"), y; - Expr let = LetNode::make(x, 1, x + 1); + PrimExpr let = LetNode::make(x, 1, x + 1); auto z = EvaluateNode::make(let + let); CHECK(!ir::VerifySSA(z)); diff --git a/tests/cpp/packed_func_test.cc b/tests/cpp/packed_func_test.cc index 24ed6d8c3a18..b0e5f2450c4b 100644 --- a/tests/cpp/packed_func_test.cc +++ b/tests/cpp/packed_func_test.cc @@ -130,7 +130,7 @@ TEST(PackedFunc, Expr) { using namespace tvm::runtime; // automatic conversion of int to expr PackedFunc addone([](TVMArgs args, TVMRetValue* rv) { - Expr x = args[0]; + PrimExpr x = args[0]; *rv = x.as()->value + 1; }); int r0 = PackedFunc([](TVMArgs args, TVMRetValue* rv) { @@ -218,7 +218,7 @@ TEST(PackedFunc, ObjectConversion) { // Check convert back CHECK(rv.operator NDArray().same_as(x)); CHECK(rv.operator ObjectRef().same_as(x)); - CHECK(!rv.IsObjectRef()); + CHECK(!rv.IsObjectRef()); auto pf1 = PackedFunc([&](TVMArgs args, TVMRetValue* rv) { CHECK_EQ(args[0].type_code(), kNDArrayContainer); @@ -228,7 +228,7 @@ TEST(PackedFunc, ObjectConversion) { CHECK(args[1].operator NDArray().get() == nullptr); CHECK(args[1].operator Module().get() == nullptr); CHECK(args[1].operator Array().get() == nullptr); - CHECK(!args[0].IsObjectRef()); + CHECK(!args[0].IsObjectRef()); }); pf1(x, ObjectRef()); pf1(ObjectRef(x), NDArray()); @@ -254,7 +254,7 @@ TEST(PackedFunc, ObjectConversion) { CHECK(args[1].operator ObjectRef().get() == nullptr); CHECK(args[1].operator NDArray().get() == nullptr); CHECK(args[1].operator Module().get() == nullptr); - CHECK(!args[0].IsObjectRef()); + CHECK(!args[0].IsObjectRef()); }); pf2(m, ObjectRef()); pf2(ObjectRef(m), Module()); diff --git a/tests/cpp/pattern_match_test.cc b/tests/cpp/pattern_match_test.cc index 2b0345422da9..5392eaeac1e8 100644 --- a/tests/cpp/pattern_match_test.cc +++ b/tests/cpp/pattern_match_test.cc @@ -24,7 +24,7 @@ TEST(Pattern, Basic) { using namespace tvm; using namespace tvm::arith; Var x("x"), y("y"), z("z"); - arith::PVar px, py, pz; + arith::PVar px, py, pz; arith::PVar pt; arith::PVar planes; @@ -49,7 +49,7 @@ TEST(Pattern, Basic) { CHECK((px + min(py, px)).Match(z + min(y, z))); CHECK((px + truncdiv(py, px * py)).Match(x + truncdiv(2, x * 2))); CHECK((px - truncmod(py, px * pz)).Match(x - truncmod(2, x * 2))); - CHECK((px - floormod(py, px * PConst(2))).Match(x - floormod(2, x * 2))); + CHECK((px - floormod(py, px * PConst(2))).Match(x - floormod(2, x * 2))); // logicals CHECK((px == pz).Match(x == 1)); @@ -111,10 +111,10 @@ TEST(Pattern, Basic) { } // ramp pattern { - CHECK(ramp(px, PConst(1), planes).Match( + CHECK(ramp(px, PConst(1), planes).Match( ir::RampNode::make(x, 1, 10))); CHECK(planes.Eval() == 10); - CHECK(!ramp(px, PConst(1), planes).Match( + CHECK(!ramp(px, PConst(1), planes).Match( ir::RampNode::make(x, 2, 10))); } // broadcast pattern diff --git a/tests/cpp/simple_passes_test.cc b/tests/cpp/simple_passes_test.cc index 19c851b3ecda..6333f1549bec 100644 --- a/tests/cpp/simple_passes_test.cc +++ b/tests/cpp/simple_passes_test.cc @@ -25,7 +25,7 @@ TEST(SimplePasses, HasSideEffect) { using namespace tvm; auto n = var("n"); - Array shape; + Array shape; shape.push_back(n); auto A = placeholder(shape, DataType::Float(32), "A"); diff --git a/tests/python/unittest/test_lang_tensor_overload_op.py b/tests/python/unittest/test_lang_tensor_overload_op.py index 1273fb7a1990..16d67715da45 100644 --- a/tests/python/unittest/test_lang_tensor_overload_op.py +++ b/tests/python/unittest/test_lang_tensor_overload_op.py @@ -29,8 +29,8 @@ def test_operator_type_and_tags(): B1 = B[0] B2 = B[0,0] - assert isinstance(k + n, tvm.expr.Expr) - assert isinstance(n + n, tvm.expr.Expr) + assert isinstance(k + n, tvm.expr.PrimExpr) + assert isinstance(n + n, tvm.expr.PrimExpr) assert isinstance(k + A, tvm.tensor.Tensor) assert isinstance(A + k, tvm.tensor.Tensor) assert isinstance(n + A, tvm.tensor.Tensor) @@ -53,11 +53,11 @@ def test_operator_type_and_tags(): assert (B + A).op.tag == topi.tag.BROADCAST assert (B + B).op.tag == topi.tag.BROADCAST - assert isinstance(k + B2, tvm.expr.Expr) - assert isinstance(B2 + k, tvm.expr.Expr) - assert isinstance(n + B2, tvm.expr.Expr) - assert isinstance(B2 + n, tvm.expr.Expr) - assert isinstance(B2 + B2, tvm.expr.Expr) + assert isinstance(k + B2, tvm.expr.PrimExpr) + assert isinstance(B2 + k, tvm.expr.PrimExpr) + assert isinstance(n + B2, tvm.expr.PrimExpr) + assert isinstance(B2 + n, tvm.expr.PrimExpr) + assert isinstance(B2 + B2, tvm.expr.PrimExpr) assert isinstance(B2 + A, tvm.tensor.Tensor) assert isinstance(A + B2, tvm.tensor.Tensor) assert isinstance(B2 + B, tvm.tensor.Tensor) diff --git a/tests/python/unittest/test_pass_lower_intrin.py b/tests/python/unittest/test_pass_lower_intrin.py index d2d106df001e..02f8118d56fc 100644 --- a/tests/python/unittest/test_pass_lower_intrin.py +++ b/tests/python/unittest/test_pass_lower_intrin.py @@ -19,7 +19,7 @@ def lower_intrin(stmt): """wrapper to call transformation in stmt""" - lower_expr = isinstance(stmt, tvm.expr.Expr) + lower_expr = isinstance(stmt, tvm.expr.PrimExpr) stmt = tvm.stmt.Evaluate(stmt) if lower_expr else stmt stmt = tvm.ir_pass.CanonicalSimplify(stmt) stmt = tvm.ir_pass._LowerIntrinStmt(stmt, "llvm") diff --git a/topi/include/topi/broadcast.h b/topi/include/topi/broadcast.h index 542d43e6b869..ce16e23bf4fa 100644 --- a/topi/include/topi/broadcast.h +++ b/topi/include/topi/broadcast.h @@ -44,7 +44,7 @@ namespace topi { * \return A Tensor whose op member is a broadcast operation */ inline tvm::Tensor broadcast_to(const tvm::Tensor& t, - const tvm::Array& output_shape, + const tvm::Array& output_shape, std::string name = "T_broadcast_to", std::string tag = kBroadcast) { CHECK_GE(output_shape.size(), t->shape.size()) @@ -59,38 +59,38 @@ inline tvm::Tensor broadcast_to(const tvm::Tensor& t, return t(detail::InputIndexFromBroadcast(ovars, t, bh.vars2, bh.all_vars)); }; return tvm::compute( - tvm::Array(bh.common_shape.begin(), bh.common_shape.end()), + tvm::Array(bh.common_shape.begin(), bh.common_shape.end()), l, name, tag); } #define TOPI_DEFINE_BCAST_OP(Name, ComputeRule) \ - inline tvm::Expr Name(const tvm::Expr& a, \ - const tvm::Expr& b) { \ + inline tvm::PrimExpr Name(const tvm::PrimExpr& a, \ + const tvm::PrimExpr& b) { \ ComputeRule; \ } \ inline tvm::Tensor Name(const tvm::Tensor& A, \ const tvm::Tensor& B, \ std::string name = "T_" #Name, \ std::string tag = kBroadcast) { \ - auto l = [](tvm::Expr a, tvm::Expr b) { ComputeRule; }; \ + auto l = [](tvm::PrimExpr a, tvm::PrimExpr b) { ComputeRule; }; \ return detail::WithBroadcast(l, A, B, name, tag); \ } \ inline tvm::Tensor Name(const tvm::Tensor& A, \ - const tvm::Expr& B, \ + const tvm::PrimExpr& B, \ std::string name = "T_" #Name, \ std::string tag = kElementWise) { \ - auto l = [](tvm::Expr a, tvm::Expr b) { ComputeRule; }; \ + auto l = [](tvm::PrimExpr a, tvm::PrimExpr b) { ComputeRule; }; \ return compute(A->shape, [&](const ::tvm::Array<::tvm::Var>& i) { \ return l(A(i), B); \ }, name, tag); \ } \ - inline tvm::Tensor Name(const tvm::Expr& A, \ + inline tvm::Tensor Name(const tvm::PrimExpr& A, \ const tvm::Tensor& B, \ std::string name = "T_" #Name, \ std::string tag = kElementWise) { \ - auto l = [&](tvm::Expr a, tvm::Expr b) { ComputeRule; }; \ + auto l = [&](tvm::PrimExpr a, tvm::PrimExpr b) { ComputeRule; }; \ return compute(B->shape, [&](const ::tvm::Array<::tvm::Var>& i) { \ return l(A, B(i)); \ }, name, tag); \ @@ -102,12 +102,12 @@ inline tvm::Tensor broadcast_to(const tvm::Tensor& t, const tvm::Tensor& B) { \ return topi::OpName(A, B); \ } \ - inline tvm::Tensor Name(const tvm::Expr& A, \ + inline tvm::Tensor Name(const tvm::PrimExpr& A, \ const tvm::Tensor& B) { \ return topi::OpName(A, B); \ } \ inline tvm::Tensor Name(const tvm::Tensor& A, \ - const tvm::Expr& B) { \ + const tvm::PrimExpr& B) { \ return topi::OpName(A, B); \ } diff --git a/topi/include/topi/contrib/cublas.h b/topi/include/topi/contrib/cublas.h index 4dce9a0bc5d1..c64490330f42 100644 --- a/topi/include/topi/contrib/cublas.h +++ b/topi/include/topi/contrib/cublas.h @@ -52,7 +52,7 @@ inline Tensor cublas_matmul(const Tensor& lhs, { { n, m } }, { lhs->dtype }, { lhs, rhs }, [&](Array ins, Array outs) { return call_packed({ - Expr("tvm.contrib.cublas.matmul"), + PrimExpr("tvm.contrib.cublas.matmul"), pack_buffer(ins[0]), pack_buffer(ins[1]), pack_buffer(outs[0]), @@ -62,7 +62,7 @@ inline Tensor cublas_matmul(const Tensor& lhs, } /*! -* \brief Create an op that multiplies batch matrices +* \brief Create an op that multiplies batch matrices * lhs and rhs with cuBLAS * * \param lhs The left matrix operand @@ -84,7 +84,7 @@ inline Tensor cublas_batch_matmul(const Tensor& lhs, { { b, n, m } }, { lhs->dtype }, { lhs, rhs }, [&](Array ins, Array outs) { return call_packed({ - Expr("tvm.contrib.cublas.batch_matmul"), + PrimExpr("tvm.contrib.cublas.batch_matmul"), pack_buffer(ins[0]), pack_buffer(ins[1]), pack_buffer(outs[0]), diff --git a/topi/include/topi/contrib/rocblas.h b/topi/include/topi/contrib/rocblas.h index 6eb6dbc5034a..a61499dc488f 100644 --- a/topi/include/topi/contrib/rocblas.h +++ b/topi/include/topi/contrib/rocblas.h @@ -51,7 +51,7 @@ inline Tensor rocblas_matmul(const Tensor& lhs, { { n, m } }, { lhs->dtype }, { lhs, rhs }, [&](Array ins, Array outs) { return call_packed({ - Expr("tvm.contrib.rocblas.matmul"), + PrimExpr("tvm.contrib.rocblas.matmul"), pack_buffer(ins[0]), pack_buffer(ins[1]), pack_buffer(outs[0]), diff --git a/topi/include/topi/cuda/dense.h b/topi/include/topi/cuda/dense.h index 3bd3e4aba744..781258afa982 100644 --- a/topi/include/topi/cuda/dense.h +++ b/topi/include/topi/cuda/dense.h @@ -119,8 +119,8 @@ inline Schedule schedule_dense(const Target &target, const Array& outs) auto thread_x = tvm::thread_axis(Range(), "threadIdx.x"); s[dense].bind(tx, thread_x); s[dense_f].compute_at(s[dense], tx); - s[dense].set_store_predicate(static_cast(thread_x) == 0); - s[out].set_store_predicate(static_cast(thread_x) == 0); + s[dense].set_store_predicate(static_cast(thread_x) == 0); + s[out].set_store_predicate(static_cast(thread_x) == 0); }; std::function traverse; diff --git a/topi/include/topi/cuda/reduction.h b/topi/include/topi/cuda/reduction.h index 3166d0836247..a82b36306617 100644 --- a/topi/include/topi/cuda/reduction.h +++ b/topi/include/topi/cuda/reduction.h @@ -125,7 +125,7 @@ Schedule ScheduleReduce(const Target& target, } } - stage_real.set_store_predicate(static_cast(thread_x) == 0); + stage_real.set_store_predicate(static_cast(thread_x) == 0); return sch; } diff --git a/topi/include/topi/detail/broadcast.h b/topi/include/topi/detail/broadcast.h index 8c5068a2f35d..2d326e7e1e07 100644 --- a/topi/include/topi/detail/broadcast.h +++ b/topi/include/topi/detail/broadcast.h @@ -37,18 +37,18 @@ namespace topi { namespace detail { struct BroadcastHelper { - std::deque common_shape; + std::deque common_shape; std::deque all_vars; std::deque vars1; std::deque vars2; }; -inline BroadcastHelper BroadcastShape(const tvm::Array& shape1, - const tvm::Array& shape2) { +inline BroadcastHelper BroadcastShape(const tvm::Array& shape1, + const tvm::Array& shape2) { BroadcastHelper bh; int s1_size = shape1.size(); int s2_size = shape2.size(); - tvm::Expr one(1); + tvm::PrimExpr one(1); int i; for (i = 1; i <= std::min(s1_size, s2_size); ++i) { // TODO(@icemelon9): Need to revisit this part @@ -81,9 +81,9 @@ inline BroadcastHelper BroadcastShape(const tvm::Array& shape1, } else { CHECK(false) << "Incompatible broadcast dims: " << shape1[s1_size - i] << " and " << shape2[s2_size - i] << " in: " - << tvm::Array(shape1.begin(), shape1.end()) + << tvm::Array(shape1.begin(), shape1.end()) << " and " - << tvm::Array(shape2.begin(), shape2.end()); + << tvm::Array(shape2.begin(), shape2.end()); } } // Remaining dimensions whether on shape1 or shape2 can always be completed @@ -98,12 +98,12 @@ inline BroadcastHelper BroadcastShape(const tvm::Array& shape1, return bh; } -inline tvm::Array InputIndexFromBroadcast( +inline tvm::Array InputIndexFromBroadcast( const tvm::Array& ovars, const tvm::Tensor& T, const std::deque& my_vars, const std::deque& all_vars) { - tvm::Array ivars; + tvm::Array ivars; CHECK_EQ(ovars.size(), all_vars.size()); // N^2, could use a map but NBD. size_t expected_dims = T->shape.size(); @@ -138,7 +138,7 @@ inline tvm::Tensor WithBroadcast(FBinaryExpr op, B(InputIndexFromBroadcast(ovars, B, bh.vars2, bh.all_vars))); }; return tvm::compute( - tvm::Array(bh.common_shape.begin(), bh.common_shape.end()), + tvm::Array(bh.common_shape.begin(), bh.common_shape.end()), l, name, tag); diff --git a/topi/include/topi/detail/constant_utils.h b/topi/include/topi/detail/constant_utils.h index 00db1fc38705..43ac3a29cd7c 100644 --- a/topi/include/topi/detail/constant_utils.h +++ b/topi/include/topi/detail/constant_utils.h @@ -41,7 +41,7 @@ using namespace tvm; * * \return true if the given expr is a constant int or uint, false otherwise. */ -inline bool IsConstInt(Expr expr) { +inline bool IsConstInt(PrimExpr expr) { return expr->IsInstance() || expr->IsInstance(); @@ -55,7 +55,7 @@ inline bool IsConstInt(Expr expr) { * * \return The integer value. */ -inline int64_t GetConstInt(Expr expr) { +inline int64_t GetConstInt(PrimExpr expr) { if (expr->IsInstance()) { return expr.as()->value; } @@ -75,11 +75,13 @@ inline int64_t GetConstInt(Expr expr) { * * \return A vector of the integer values */ -inline std::vector GetConstIntValues(Array exprs, const std::string& var_name) { +inline std::vector GetConstIntValues( + Array exprs, const std::string& var_name) { std::vector result; if (!exprs.defined()) return result; for (auto expr : exprs) { - CHECK(IsConstInt(expr)) << "All elements of " << var_name << " must be constant integers"; + CHECK(IsConstInt(expr)) << "All elements of " + << var_name << " must be constant integers"; result.push_back(GetConstInt(expr)); } return result; @@ -94,7 +96,8 @@ inline std::vector GetConstIntValues(Array exprs, const std::string& * * \return A vector of the int64_t values */ -inline std::vector GetConstInt64Values(Array exprs, const std::string& var_name) { +inline std::vector GetConstInt64Values( + Array exprs, const std::string& var_name) { std::vector result; if (!exprs.defined()) return result; for (auto expr : exprs) { @@ -113,10 +116,10 @@ inline std::vector GetConstInt64Values(Array exprs, const std::st * * \return result True if both expressions are equal, else false */ -inline bool EqualCheck(Expr lhs, Expr rhs) { +inline bool EqualCheck(PrimExpr lhs, PrimExpr rhs) { bool result = tvm::ir::Equal(lhs, rhs); if (!result) { - Expr zero(0); + PrimExpr zero(0); result = tvm::ir::Equal(tvm::ir::CanonicalSimplify(lhs-rhs), zero); } return result; diff --git a/topi/include/topi/detail/extern.h b/topi/include/topi/detail/extern.h index 643b44bb62aa..8bdda802ad7a 100644 --- a/topi/include/topi/detail/extern.h +++ b/topi/include/topi/detail/extern.h @@ -42,12 +42,12 @@ using namespace tvm; * * \return The Buffer object */ -inline Buffer DeclExternBuffer(Array shape, +inline Buffer DeclExternBuffer(Array shape, DataType dtype, std::string name) { auto data = var(name, DataType::Handle()); - auto elem_offset = Expr(); - return BufferNode::make(data, dtype, shape, Array(), elem_offset, name, "", + auto elem_offset = PrimExpr(); + return BufferNode::make(data, dtype, shape, Array(), elem_offset, name, "", -1, 0, kDefault); } @@ -56,7 +56,7 @@ inline Buffer DeclExternBuffer(Array shape, * function. The function expects two arguments: an array of Buffers holding the input * tensor values, and a pre-allocated array of Buffers to be filled with the outputs. */ -using FExtern = std::function, Array)>; +using FExtern = std::function, Array)>; /*! * \brief Create tensors representing the result of invoking an external function. @@ -75,7 +75,7 @@ using FExtern = std::function, Array)>; * be one output Tensor for each element of out_shapes, with dtype equal to the corresponding * element of out_types. */ -inline Array make_extern(const Array< Array >& out_shapes, +inline Array make_extern(const Array< Array >& out_shapes, const std::vector& out_types, const Array& inputs, FExtern fextern, @@ -116,18 +116,18 @@ inline Array make_extern(const Array< Array >& out_shapes, * * \return An expression representing the pack operation */ -inline Expr pack_buffer(Buffer buf) { +inline PrimExpr pack_buffer(Buffer buf) { CHECK_GT(buf->shape.size(), 0) << "buf shape must have at least one element"; auto shape = tvm::ir::CallNode::make(DataType::Handle(), tvm::ir::intrinsic::tvm_stack_make_shape, buf->shape, tvm::ir::CallNode::CallType::Intrinsic); - Expr strides; + PrimExpr strides; if (buf->strides.size() > 0) { strides = tvm::ir::CallNode::make(DataType::Handle(), tvm::ir::intrinsic::tvm_stack_make_shape, buf->shape, tvm::ir::CallNode::CallType::Intrinsic); } else { strides = 0; } - Array pack_args{ + Array pack_args{ buf->data, shape, strides, @@ -148,7 +148,7 @@ inline Expr pack_buffer(Buffer buf) { * * \return An expression representing the invocation */ -inline Expr call_packed(Array args) { +inline PrimExpr call_packed(Array args) { return tvm::ir::CallNode::make(DataType::Int(32), tvm::ir::intrinsic::tvm_call_packed, args, tvm::ir::CallNode::CallType::Intrinsic); } diff --git a/topi/include/topi/detail/pad_utils.h b/topi/include/topi/detail/pad_utils.h index 50e0f9532d1f..ec757e9954fe 100644 --- a/topi/include/topi/detail/pad_utils.h +++ b/topi/include/topi/detail/pad_utils.h @@ -42,7 +42,7 @@ using namespace tvm; * \return An array of 4 elements, representing padding sizes for * each individual side. The array is in the order { top, left, bottom, right } */ -inline Array GetPadTuple(Expr pad_h, Expr pad_w) { +inline Array GetPadTuple(PrimExpr pad_h, PrimExpr pad_w) { pad_h *= 2; pad_w *= 2; diff --git a/topi/include/topi/detail/ravel_unravel.h b/topi/include/topi/detail/ravel_unravel.h index 6cd4707a2477..5526a7dac7c2 100644 --- a/topi/include/topi/detail/ravel_unravel.h +++ b/topi/include/topi/detail/ravel_unravel.h @@ -41,10 +41,10 @@ using namespace tvm; * * \return The index after flattening */ -inline Expr RavelIndex(Array indices, Array shape) { +inline PrimExpr RavelIndex(Array indices, Array shape) { CHECK_EQ(indices.size(), shape.size()) << "indices and shape must have equal size"; CHECK_GT(indices.size(), 0) << "indices must not be empty"; - Expr idx; + PrimExpr idx; for (size_t i = 0; i < indices.size(); ++i) { if (i == 0) { idx = indices[i]; @@ -63,8 +63,8 @@ inline Expr RavelIndex(Array indices, Array shape) { * * \return The coordinate corresponding to the 1D index */ -inline Array UnravelIndex(Expr idx, Array shape) { - std::vector indices; +inline Array UnravelIndex(PrimExpr idx, Array shape) { + std::vector indices; for (int i = static_cast(shape.size()) - 1; i >= 0; --i) { indices.push_back(indexmod(idx, shape[i])); diff --git a/topi/include/topi/detail/tensor_utils.h b/topi/include/topi/detail/tensor_utils.h index fe23836591e4..e52452e89d33 100644 --- a/topi/include/topi/detail/tensor_utils.h +++ b/topi/include/topi/detail/tensor_utils.h @@ -36,7 +36,7 @@ using namespace tvm; * * \return True if the input shape is empty. */ -inline bool is_empty_shape(const Array& x) { +inline bool is_empty_shape(const Array& x) { bool is_empty = false; for (const auto& dim : x) { if (auto int_dim = dim.as()) { diff --git a/topi/include/topi/elemwise.h b/topi/include/topi/elemwise.h index dec94f351d96..df7cff064383 100644 --- a/topi/include/topi/elemwise.h +++ b/topi/include/topi/elemwise.h @@ -191,9 +191,9 @@ inline Tensor sign(const Tensor& x, std::string name = "T_sign", std::string tag = kElementWise) { return compute(x->shape, [&](const Array& i) { - Expr zero = make_zero(x->dtype); - Expr one = make_const(x->dtype, 1); - Expr minus_one = make_const(x->dtype, -1); + PrimExpr zero = make_zero(x->dtype); + PrimExpr one = make_const(x->dtype, 1); + PrimExpr minus_one = make_const(x->dtype, -1); auto s1 = tvm::ir::SelectNode::make((x(i) < zero), minus_one, zero); auto s2 = tvm::ir::SelectNode::make((x(i) > zero), one, s1); return s2; @@ -213,7 +213,7 @@ inline Tensor rsqrt(const Tensor& x, std::string name = "tensor", std::string tag = kElementWise) { return compute(x->shape, [&](const Array& i) { - Expr one = make_const(x->dtype, 1); + PrimExpr one = make_const(x->dtype, 1); return one/tvm::sqrt(x(i)); }, name, tag); } @@ -231,8 +231,8 @@ inline Tensor rsqrt(const Tensor& x, * \return A Tensor whose op member is the clip operation */ inline Tensor clip(const Tensor& x, - const Expr& a_min, - const Expr& a_max, + const PrimExpr& a_min, + const PrimExpr& a_max, std::string name = "T_clip", std::string tag = kElementWise) { return compute(x->shape, [&](const Array& i) { @@ -325,12 +325,12 @@ inline Tensor elemwise_sum(const Array& xs, * * \return A Tensor whose op member is the full operation */ -inline Tensor full(const Array& shape, +inline Tensor full(const Array& shape, DataType dtype, - const Expr fill_value, + const PrimExpr fill_value, std::string name = "T_full", std::string tag = kElementWise) { - Expr ev = cast(dtype, fill_value); + PrimExpr ev = cast(dtype, fill_value); if (!ev.defined()) { LOG(ERROR) << "Can't cast fill_value to " << dtype; } @@ -351,10 +351,10 @@ inline Tensor full(const Array& shape, * \return A Tensor whose op memeber is the full_like operation */ inline Tensor full_like(const Tensor& x, - const Expr fill_value, + const PrimExpr fill_value, std::string name = "T_full_like", std::string tag = kElementWise) { - Expr ev = cast(x->dtype, fill_value); + PrimExpr ev = cast(x->dtype, fill_value); return compute(x->shape, [&](const Array& i) { return ev; }, name, tag); diff --git a/topi/include/topi/image/resize.h b/topi/include/topi/image/resize.h index 5f58562ddd81..f169ec946636 100644 --- a/topi/include/topi/image/resize.h +++ b/topi/include/topi/image/resize.h @@ -50,8 +50,8 @@ using namespace tvm; * * \return The interpolated value in the given index. */ -inline Expr bilinear_sample_nchw(const Tensor& input, const Array& indices, - const Expr max_y, const Expr max_x) { +inline PrimExpr bilinear_sample_nchw(const Tensor& input, const Array& indices, + const PrimExpr max_y, const PrimExpr max_x) { auto in_y = indices[2]; auto yf = tvm::floor(in_y); auto yc = tvm::cast(DataType::Int(32), tvm::ceil(in_y)); @@ -91,11 +91,11 @@ inline Expr bilinear_sample_nchw(const Tensor& input, const Array& indices * \return A Tensor resized to given shape */ inline Tensor resize_nearest_neighbor_nhwc(const Tensor& input, - const Array& shape, + const Array& shape, bool align_corners = false, std::string name = "tensor", std::string tag = kInjective) { - Array out_shape; + Array out_shape; out_shape.push_back(input->shape[0]); out_shape.push_back(cast(DataType::Int(32), shape[0])); out_shape.push_back(cast(DataType::Int(32), shape[1])); @@ -103,7 +103,7 @@ inline Tensor resize_nearest_neighbor_nhwc(const Tensor& input, return compute( out_shape, [&](const Array& indices) { - Array idx; + Array idx; idx.push_back(indices[0]); idx.push_back(indices[1] * input->shape[1] / shape[0]); idx.push_back(indices[2] * input->shape[2] / shape[1]); @@ -125,11 +125,11 @@ inline Tensor resize_nearest_neighbor_nhwc(const Tensor& input, * \return A Tensor resized to given shape */ inline Tensor resize_nearest_neighbor_nchw(const Tensor& input, - const Array& shape, + const Array& shape, bool align_corners = false, std::string name = "tensor", std::string tag = kInjective) { - Array out_shape; + Array out_shape; out_shape.push_back(input->shape[0]); out_shape.push_back(input->shape[1]); out_shape.push_back(cast(DataType::Int(32), shape[0])); @@ -137,7 +137,7 @@ inline Tensor resize_nearest_neighbor_nchw(const Tensor& input, return compute( out_shape, [&](const Array& indices) { - Array idx; + Array idx; idx.push_back(indices[0]); idx.push_back(indices[1]); idx.push_back(indices[2] * input->shape[2] / shape[0]); @@ -159,11 +159,11 @@ inline Tensor resize_nearest_neighbor_nchw(const Tensor& input, * \return A Tensor resized to given shape */ inline Tensor resize_nearest_neighbor_nchwc(const Tensor& input, - const Array& shape, + const Array& shape, bool align_corners = false, std::string name = "tensor", std::string tag = kInjective) { - Array out_shape; + Array out_shape; out_shape.push_back(input->shape[0]); out_shape.push_back(input->shape[1]); out_shape.push_back(cast(DataType::Int(32), shape[0])); @@ -172,7 +172,7 @@ inline Tensor resize_nearest_neighbor_nchwc(const Tensor& input, return compute( out_shape, [&](const Array& indices) { - Array idx; + Array idx; idx.push_back(indices[0]); idx.push_back(indices[1]); idx.push_back(indices[2] * input->shape[2] / shape[0]); @@ -196,7 +196,7 @@ inline Tensor resize_nearest_neighbor_nchwc(const Tensor& input, * \return A Tensor resized to given shape */ inline Tensor resize_nearest_neighbor(const Tensor& input, - const Array& shape, + const Array& shape, std::string layout = "NCHW", bool align_corners = false, std::string name = "tensor", @@ -227,25 +227,25 @@ inline Tensor resize_nearest_neighbor(const Tensor& input, * \return A Tensor resized to given shape */ inline Tensor resize_bilinear_nhwc(const Tensor& input, - const Array& shape, + const Array& shape, bool align_corners = false, std::string name = "tensor", std::string tag = kInjective) { - Array out_shape; + Array out_shape; out_shape.push_back(input->shape[0]); out_shape.push_back(cast(DataType::Int(32), shape[0])); out_shape.push_back(cast(DataType::Int(32), shape[1])); out_shape.push_back(input->shape[3]); - Expr cone = make_const(DataType::Int(32), 1); + PrimExpr cone = make_const(DataType::Int(32), 1); auto in_height = as_const_int(input->shape[1]); auto in_width = as_const_int(input->shape[2]); auto out_height = as_const_int(shape[0]); auto out_width = as_const_int(shape[1]); - Expr y_ratio; - Expr x_ratio; + PrimExpr y_ratio; + PrimExpr x_ratio; if (!align_corners) { y_ratio = make_const(DataType::Float(32), (static_cast(*in_height) / @@ -259,8 +259,8 @@ inline Tensor resize_bilinear_nhwc(const Tensor& input, static_cast(*out_width - 1))); } - Expr other_y = tvm::ir::Simplify(input->shape[1] - cone); - Expr other_x = tvm::ir::Simplify(input->shape[2] - cone); + PrimExpr other_y = tvm::ir::Simplify(input->shape[1] - cone); + PrimExpr other_x = tvm::ir::Simplify(input->shape[2] - cone); return compute( out_shape, [&](const Array& indices) { @@ -304,25 +304,25 @@ inline Tensor resize_bilinear_nhwc(const Tensor& input, * \return A Tensor resized to given shape */ inline Tensor resize_bilinear_nchw(const Tensor& input, - const Array& shape, + const Array& shape, bool align_corners = false, std::string name = "tensor", std::string tag = kInjective) { - Array out_shape; + Array out_shape; out_shape.push_back(input->shape[0]); out_shape.push_back(input->shape[1]); out_shape.push_back(cast(DataType::Int(32), shape[0])); out_shape.push_back(cast(DataType::Int(32), shape[1])); - Expr cone = make_const(DataType::Int(32), 1); + PrimExpr cone = make_const(DataType::Int(32), 1); auto in_height = as_const_int(input->shape[2]); auto in_width = as_const_int(input->shape[3]); auto out_height = as_const_int(shape[0]); auto out_width = as_const_int(shape[1]); - Expr y_ratio; - Expr x_ratio; + PrimExpr y_ratio; + PrimExpr x_ratio; if (!align_corners) { y_ratio = make_const(DataType::Float(32), (static_cast(*in_height) / @@ -336,8 +336,8 @@ inline Tensor resize_bilinear_nchw(const Tensor& input, static_cast(*out_width - 1))); } - Expr other_y = tvm::ir::Simplify(input->shape[2] - cone); - Expr other_x = tvm::ir::Simplify(input->shape[3] - cone); + PrimExpr other_y = tvm::ir::Simplify(input->shape[2] - cone); + PrimExpr other_x = tvm::ir::Simplify(input->shape[3] - cone); return compute( out_shape, [&](const Array& indices) { @@ -360,7 +360,7 @@ inline Tensor resize_bilinear_nchw(const Tensor& input, * \return A Tensor resized to given shape */ inline Tensor resize_bilinear(const Tensor& input, - const Array& shape, + const Array& shape, std::string layout = "NCHW", bool align_corners = false, std::string name = "tensor", @@ -390,7 +390,7 @@ inline Tensor resize_bilinear(const Tensor& input, * \return A Tensor resized to given shape */ inline Tensor resize(const Tensor& input, - const Array& shape, + const Array& shape, std::string layout = "NCHW", bool align_corners = false, std::string mode = "BILINEAR", diff --git a/topi/include/topi/nn.h b/topi/include/topi/nn.h index 5920c0b92061..3f65c75a02bb 100644 --- a/topi/include/topi/nn.h +++ b/topi/include/topi/nn.h @@ -39,9 +39,9 @@ using namespace tvm; namespace detail { template -tvm::Expr Map(const tvm::Array& exprs, T op) { +tvm::PrimExpr Map(const tvm::Array& exprs, T op) { CHECK_GE(exprs.size(), 1); - tvm::Expr res = exprs[0]; + tvm::PrimExpr res = exprs[0]; for (size_t i = 1; i < exprs.size(); ++i) { res = op(res, exprs[i]); } @@ -172,9 +172,9 @@ inline tvm::Tensor prelu(const tvm::Tensor &x, * */ inline tvm::Tensor pad(const tvm::Tensor& t, - const tvm::Array& pad_before, - tvm::Array pad_after = tvm::Array(), - Expr pad_value = Expr(), + const tvm::Array& pad_before, + tvm::Array pad_after = tvm::Array(), + PrimExpr pad_value = PrimExpr(), std::string name = "T_pad", std::string tag = kElementWise, std::string pad_mode = "constant") { @@ -185,9 +185,9 @@ inline tvm::Tensor pad(const tvm::Tensor& t, } CHECK_GE(pad_before.size(), 1); CHECK_EQ(pad_before.size(), pad_after.size()); - tvm::Array output_shape; - tvm::Array pad_before_int32; - tvm::Array pad_after_int32; + tvm::Array output_shape; + tvm::Array pad_before_int32; + tvm::Array pad_after_int32; for (const auto &ele : pad_before) { pad_before_int32.push_back(tvm::cast(tvm::DataType::Int(32), ele)); } @@ -207,9 +207,9 @@ inline tvm::Tensor pad(const tvm::Tensor& t, pad_value = tvm::make_const(t->dtype, 0); } auto l = [&](tvm::Array ovars) { - tvm::Array indices; - tvm::Array sel; - tvm::Array pad_idx; + tvm::Array indices; + tvm::Array sel; + tvm::Array pad_idx; for (size_t i = 0; i < t->shape.size(); ++i) { if (i >= pad_before_int32.size()) { indices.push_back(ovars[i]); @@ -286,7 +286,7 @@ inline tvm::Tensor conv2d_nchw(const tvm::Tensor& I, CHECK_EQ(4, W->shape.size()); auto pH = I->shape[2]; auto pW = I->shape[3]; - tvm::Array output_shape{ + tvm::Array output_shape{ I->shape[0], // B W->shape[0], // O indexdiv(I->shape[2] - W->shape[2] + 2 * pad_h, stride_h) + 1, // H @@ -297,7 +297,7 @@ inline tvm::Tensor conv2d_nchw(const tvm::Tensor& I, auto kw = tvm::reduce_axis(tvm::Range{0, W->shape[3]}, "kw"); auto T = (pad_h == 0 && pad_w == 0) ? I - : pad(I, {tvm::Expr(0), tvm::Expr(0), pad_h, pad_w}); + : pad(I, {tvm::PrimExpr(0), tvm::PrimExpr(0), pad_h, pad_w}); auto l = [&](tvm::Var b, tvm::Var o, tvm::Var h, tvm::Var w) { return tvm::sum( T(b, i, stride_h * h + kh, stride_w * w + kw) * W(o, i, kh, kw), @@ -337,7 +337,7 @@ inline tvm::Tensor conv2d_hwcn(const tvm::Tensor& I, CHECK_EQ(4, W->shape.size()); auto pH = I->shape[2]; auto pW = I->shape[3]; - tvm::Array output_shape{ + tvm::Array output_shape{ indexdiv(I->shape[2] - W->shape[2] + 2 * pad_h, stride_h) + 1, // H indexdiv(I->shape[3] - W->shape[3] + 2 * pad_w, stride_w) + 1, // W I->shape[2], // B @@ -389,7 +389,7 @@ inline tvm::Tensor depthwise_conv2d_nchw(const tvm::Tensor& I, auto pH = I->shape[2]; auto pW = I->shape[3]; auto pCM = W->shape[1]; // channel_multiplier - tvm::Array output_shape{ + tvm::Array output_shape{ I->shape[0], // B W->shape[1], // O indexdiv(I->shape[2] - W->shape[2] + 2 * pad_h, stride_h) + 1, // H @@ -400,7 +400,7 @@ inline tvm::Tensor depthwise_conv2d_nchw(const tvm::Tensor& I, auto kw = tvm::reduce_axis(tvm::Range{0, W->shape[3]}, "kw"); auto T = (pad_h == 0 && pad_w == 0) ? I - : pad(I, {tvm::Expr(0), tvm::Expr(0), pad_h, pad_w}); + : pad(I, {tvm::PrimExpr(0), tvm::PrimExpr(0), pad_h, pad_w}); auto l = [&](tvm::Var b, tvm::Var o, tvm::Var h, tvm::Var w) { return tvm::sum(T(b, indexdiv(i, pCM), stride_h * h + kh, stride_w * w + kw) * W(indexdiv(i, pCM), indexmod(o, pCM), kh, kw), @@ -422,7 +422,7 @@ inline tvm::Tensor depthwise_conv2d_nhwc(const tvm::Tensor& I, auto pH = I->shape[1]; auto pW = I->shape[2]; auto pCM = W->shape[1]; // channel_multiplier - tvm::Array output_shape{ + tvm::Array output_shape{ I->shape[0], // B indexdiv(I->shape[1] - W->shape[1] + 2 * pad_h, stride_h) + 1, // H indexdiv(I->shape[2] - W->shape[2] + 2 * pad_w, stride_w) + 1, // W @@ -433,7 +433,7 @@ inline tvm::Tensor depthwise_conv2d_nhwc(const tvm::Tensor& I, auto kw = tvm::reduce_axis(tvm::Range{0, W->shape[1]}, "kw"); auto T = (pad_h == 0 && pad_w == 0) ? I - : pad(I, {tvm::Expr(0), pad_h, pad_w, tvm::Expr(0)}); + : pad(I, {tvm::PrimExpr(0), pad_h, pad_w, tvm::PrimExpr(0)}); auto l = [&](tvm::Var b, tvm::Var h, tvm::Var w, tvm::Var o) { return tvm::sum(T(b, stride_h * h + kh, stride_w * w + kw, indexdiv(i, pCM)) * W(kh, kw, indexdiv(i, pCM), indexmod(o, pCM)), @@ -474,7 +474,7 @@ inline tvm::Tensor group_conv2d_ngchw(const tvm::Tensor& I, CHECK_EQ(5, W->shape.size()); auto pH = I->shape[2]; auto pW = I->shape[3]; - tvm::Array output_shape{ + tvm::Array output_shape{ I->shape[0], // B I->shape[1], // G W->shape[2], // O @@ -487,7 +487,7 @@ inline tvm::Tensor group_conv2d_ngchw(const tvm::Tensor& I, auto T = (pad_h == 0 && pad_w == 0) ? I - : pad(I, {tvm::Expr(0), tvm::Expr(0), tvm::Expr(0), pad_h, pad_w}); + : pad(I, {tvm::PrimExpr(0), tvm::PrimExpr(0), tvm::PrimExpr(0), pad_h, pad_w}); auto l = [&](tvm::Array args) { tvm::Var b = args[0]; tvm::Var g = args[1]; diff --git a/topi/include/topi/nn/bnn.h b/topi/include/topi/nn/bnn.h index 6b79db3b5c79..e2af3ae61518 100644 --- a/topi/include/topi/nn/bnn.h +++ b/topi/include/topi/nn/bnn.h @@ -55,7 +55,7 @@ inline tvm::Tensor binarize_pack(const tvm::Tensor& data, << "binarize_pack: axis size must be a multiple of 32"; auto n = ishape.size(); - Array oshape; + Array oshape; for (size_t i = 0; i < n; ++i) { oshape.push_back(i == static_cast(axis) ? tvm::ir::Simplify(indexdiv(ishape[i], 32)) : @@ -65,15 +65,15 @@ inline tvm::Tensor binarize_pack(const tvm::Tensor& data, return tvm::compute( oshape, [&](const Array& indices) { - Array start_idx; + Array start_idx; for (size_t i = 0; i < n; ++i) { start_idx.push_back(i == static_cast(axis) ? indices[i] * 32 : - static_cast(indices[i])); + static_cast(indices[i])); } auto packed = make_const(DataType::UInt(32), 0); for (size_t j = 0; j < 32; ++j) { - Array idx; + Array idx; for (size_t i = 0; i < n; ++i) { idx.push_back(i == static_cast(axis) ? start_idx[i] + static_cast(j) : diff --git a/topi/include/topi/nn/dilate.h b/topi/include/topi/nn/dilate.h index 1dc2c8d53948..334b17054c81 100644 --- a/topi/include/topi/nn/dilate.h +++ b/topi/include/topi/nn/dilate.h @@ -42,10 +42,10 @@ using namespace tvm; * * \return The logical conjunction expression */ -Expr all(Array args) { +PrimExpr all(Array args) { CHECK_GT(args.size(), 0) << "all requires at least one argument"; - Expr ret = args[0]; + PrimExpr ret = args[0]; for (size_t i = 1; i < args.size(); ++i) { ret = ret && args[i]; } @@ -65,7 +65,7 @@ Expr all(Array args) { * \return The output tensor. */ inline Tensor dilate(const Tensor& x, - Array strides, + Array strides, std::string name = "tensor", std::string tag = kInjective) { auto n = x->shape.size(); @@ -73,7 +73,7 @@ inline Tensor dilate(const Tensor& x, << "strides size (" << strides.size() << ") must match dimension of x (" << n << ")"; - Array out_shape; + Array out_shape; for (size_t i = 0; i < n; ++i) { out_shape.push_back(tvm::ir::Simplify( (x->shape[i] - 1) * cast(DataType::Int(32), strides[i] + 1))); @@ -82,8 +82,8 @@ inline Tensor dilate(const Tensor& x, return tvm::compute( out_shape, [&](const Array& indices) { - Array not_zero; - Array index_tuple; + Array not_zero; + Array index_tuple; for (size_t i = 0; i < n; ++i) { if (IsConstInt(strides[i]) && GetConstInt(strides[i]) == 1) { index_tuple.push_back(indices[i]); diff --git a/topi/include/topi/nn/flatten.h b/topi/include/topi/nn/flatten.h index e331b978c172..6b542f7c2afe 100644 --- a/topi/include/topi/nn/flatten.h +++ b/topi/include/topi/nn/flatten.h @@ -51,14 +51,14 @@ inline Tensor flatten(const Tensor& x, std::string name = "tensor", std::string tag = kInjective) { auto ishape = x->shape; - Expr dim = 1; + PrimExpr dim = 1; for (size_t i = 1; i < ishape.size(); ++i) { dim = dim * ishape[i]; } - Array oshape({ ishape[0], dim }); + Array oshape({ ishape[0], dim }); - std::vector extra_shape; + std::vector extra_shape; for (size_t i = 1; i < ishape.size(); ++i) { extra_shape.push_back(ishape[i]); } @@ -66,8 +66,8 @@ inline Tensor flatten(const Tensor& x, return tvm::compute( oshape, [&](Var i, Var j) { - Expr idx = j; - std::vector index; + PrimExpr idx = j; + std::vector index; for (auto s : extra_shape) { index.push_back(indexmod(idx, s)); idx = indexdiv(idx, s); diff --git a/topi/include/topi/nn/local_response_norm.h b/topi/include/topi/nn/local_response_norm.h index 2490b37dbe21..0cce997c2009 100644 --- a/topi/include/topi/nn/local_response_norm.h +++ b/topi/include/topi/nn/local_response_norm.h @@ -59,10 +59,10 @@ inline Tensor lrn(const Tensor& data, CHECK_EQ(size % 2, 1) << "size should be odd number"; CHECK(axis == 1 || axis == 3) << "axis should be 1 or 3 for NCHW and NHWC"; auto input_shape = data->shape; - Array pad_before{ 0, 0, 0, 0}; - Array pad_after{ 0, 0, 0, 0}; - pad_before.Set(axis, static_cast(size/2)); - pad_after.Set(axis, static_cast(size/2)); + Array pad_before{ 0, 0, 0, 0}; + Array pad_after{ 0, 0, 0, 0}; + pad_before.Set(axis, static_cast(size/2)); + pad_after.Set(axis, static_cast(size/2)); auto pad_data = pad(data, pad_before, pad_after, 0, "pad_data"); auto rxs = tvm::reduce_axis(Range(0, size), "rxs"); Tensor sqr_sum; diff --git a/topi/include/topi/nn/pooling.h b/topi/include/topi/nn/pooling.h index 35bcd53a65f1..a074ee1f6ef9 100644 --- a/topi/include/topi/nn/pooling.h +++ b/topi/include/topi/nn/pooling.h @@ -61,9 +61,9 @@ enum PoolType : int { * \return The output tensor in same layout order */ inline Tensor pool_impl(const Tensor& x, - const Array& kernel_size, - const Array& stride_size, - const Array& padding_size, + const Array& kernel_size, + const Array& stride_size, + const Array& padding_size, PoolType pool_type, bool ceil_mode, const size_t height_axis, @@ -94,11 +94,11 @@ inline Tensor pool_impl(const Tensor& x, pad_right += stride_width - 1; } - Array pad_before(std::vector(x->shape.size(), 0)); + Array pad_before(std::vector(x->shape.size(), 0)); pad_before.Set(height_axis, pad_top); pad_before.Set(width_axis, pad_left); - Array pad_after(std::vector(x->shape.size(), 0)); + Array pad_after(std::vector(x->shape.size(), 0)); pad_after.Set(height_axis, pad_bottom); pad_after.Set(width_axis, pad_right); @@ -110,7 +110,7 @@ inline Tensor pool_impl(const Tensor& x, auto dheight = tvm::reduce_axis(Range(0, kernel_height)); auto dwidth = tvm::reduce_axis(Range(0, kernel_width)); - Array out_shape = x->shape; + Array out_shape = x->shape; out_shape.Set(height_axis, out_height); out_shape.Set(width_axis, out_width); @@ -125,7 +125,7 @@ inline Tensor pool_impl(const Tensor& x, auto temp = do_pad ? pad( x, pad_before, pad_after, tvm::min_value(x->dtype), "pad_temp") : x; return tvm::compute(out_shape, [&](const Array& output) { - Array indices; + Array indices; for (const Var& var : output) indices.push_back(var); indices.Set(height_axis, output[height_axis] * stride_height + dheight); indices.Set(width_axis, output[width_axis] * stride_width + dwidth); @@ -138,7 +138,7 @@ inline Tensor pool_impl(const Tensor& x, // TVM compute for summing the pooling window. auto pool_sum = tvm::compute(out_shape, [&](const Array& output) { - Array indices; + Array indices; for (const Var& var : output) indices.push_back(var); indices.Set(height_axis, output[height_axis] * stride_height + dheight); indices.Set(width_axis, output[width_axis] * stride_width + dwidth); @@ -148,18 +148,18 @@ inline Tensor pool_impl(const Tensor& x, // TVM compute for dividing the reduced window sum by kernel size. return tvm::compute(out_shape, [&](const Array& output) { - Array indices; + Array indices; for (const Var& var : output) indices.push_back(var); if (count_include_pad) { return div(pool_sum(indices), (kernel_height * kernel_width)); } else { - Expr h_start = output[height_axis] * stride_height - pad_top; - Expr w_start = output[width_axis] * stride_width - pad_left; - Expr h_end = ir::MinNode::make(h_start + kernel_height, height); - Expr w_end = ir::MinNode::make(w_start + kernel_width, width); + PrimExpr h_start = output[height_axis] * stride_height - pad_top; + PrimExpr w_start = output[width_axis] * stride_width - pad_left; + PrimExpr h_end = ir::MinNode::make(h_start + kernel_height, height); + PrimExpr w_end = ir::MinNode::make(w_start + kernel_width, width); h_start = ir::MaxNode::make(h_start, make_const(DataType::DataType::Int(32), 0)); w_start = ir::MaxNode::make(w_start, make_const(DataType::DataType::Int(32), 0)); - Expr divide_factor = ir::MaxNode::make((h_end - h_start) * (w_end - w_start), + PrimExpr divide_factor = ir::MaxNode::make((h_end - h_start) * (w_end - w_start), make_const(DataType::DataType::Int(32), 1)); return div(pool_sum(indices), divide_factor); } @@ -170,9 +170,12 @@ inline Tensor pool_impl(const Tensor& x, } } -inline Tensor pool_grad_impl(const Tensor& out_grad, const Tensor& x, - const Array& kernel_size, const Array& stride_size, - const Array& padding_size, PoolType pool_type, bool ceil_mode, +inline Tensor pool_grad_impl(const Tensor& out_grad, + const Tensor& x, + const Array& kernel_size, + const Array& stride_size, + const Array& padding_size, + PoolType pool_type, bool ceil_mode, const size_t height_axis, const size_t width_axis, bool count_include_pad) { CHECK(out_grad->shape.size() >= 2) << "Pooling grad output must >= 2-D (H, W)"; @@ -201,11 +204,11 @@ inline Tensor pool_grad_impl(const Tensor& out_grad, const Tensor& x, pad_right += stride_width - 1; } - Array pad_before(std::vector(x->shape.size(), 0)); + Array pad_before(std::vector(x->shape.size(), 0)); pad_before.Set(height_axis, pad_top); pad_before.Set(width_axis, pad_left); - Array pad_after(std::vector(x->shape.size(), 0)); + Array pad_after(std::vector(x->shape.size(), 0)); pad_after.Set(height_axis, pad_bottom); pad_after.Set(width_axis, pad_right); @@ -217,7 +220,7 @@ inline Tensor pool_grad_impl(const Tensor& out_grad, const Tensor& x, auto dheight = tvm::reduce_axis(Range(0, kernel_height)); auto dwidth = tvm::reduce_axis(Range(0, kernel_width)); - Array out_shape = x->shape; + Array out_shape = x->shape; out_shape.Set(height_axis, out_height); out_shape.Set(width_axis, out_width); @@ -229,7 +232,7 @@ inline Tensor pool_grad_impl(const Tensor& out_grad, const Tensor& x, ((padding_h1 && *padding_h1) || (padding_w1 && *padding_w1)); if (pool_type == kMaxPool) { - Array ravel_shape{x->shape.begin(), x->shape.end()}; + Array ravel_shape{x->shape.begin(), x->shape.end()}; ravel_shape.Set(height_axis, ravel_shape[height_axis] + pad_top + pad_bottom); ravel_shape.Set(width_axis, ravel_shape[width_axis] + pad_left + pad_right); @@ -243,7 +246,7 @@ inline Tensor pool_grad_impl(const Tensor& out_grad, const Tensor& x, auto mp_argmax = tvm::compute(out_shape, [&](const Array& inds) { - Array window_inds{inds.begin(), inds.end()}; + Array window_inds{inds.begin(), inds.end()}; window_inds.Set(height_axis, inds[height_axis] * stride_height + dheight); window_inds.Set(width_axis, inds[width_axis] * stride_width + dwidth); auto idx = detail::RavelIndex(window_inds, ravel_shape); @@ -256,19 +259,19 @@ inline Tensor pool_grad_impl(const Tensor& out_grad, const Tensor& x, return tvm::compute( x->shape, [&](const Array& inds) { - Array pad_inds {inds.begin(), inds.end()}; + Array pad_inds {inds.begin(), inds.end()}; pad_inds.Set(height_axis, pad_inds[height_axis] + pad_top); pad_inds.Set(width_axis, pad_inds[width_axis] + pad_left); auto idx = detail::RavelIndex(pad_inds, ravel_shape); - Array out_idx {inds.begin(), inds.end()}; + Array out_idx {inds.begin(), inds.end()}; out_idx.Set(height_axis, (inds[height_axis] + pad_top) / stride_height - windowh); out_idx.Set(width_axis, (inds[width_axis] + pad_left) / stride_width - windoww); - Expr out_idx_lower_h = ir::SelectNode::make( + PrimExpr out_idx_lower_h = ir::SelectNode::make( pad_inds[height_axis] < kernel_height, make_const(DataType::DataType::Int(32), 0), (pad_inds[height_axis] - kernel_height) / stride_height + 1); - Expr out_idx_lower_w = ir::SelectNode::make( + PrimExpr out_idx_lower_w = ir::SelectNode::make( pad_inds[width_axis] < kernel_width, make_const(DataType::DataType::Int(32), 0), (pad_inds[width_axis] - kernel_width) / stride_width + 1); @@ -287,29 +290,29 @@ inline Tensor pool_grad_impl(const Tensor& out_grad, const Tensor& x, return tvm::compute( x->shape, [&](const Array& inds) { - Expr pad_h_idx = inds[height_axis] + pad_top; - Expr pad_w_idx = inds[width_axis] + pad_left; + PrimExpr pad_h_idx = inds[height_axis] + pad_top; + PrimExpr pad_w_idx = inds[width_axis] + pad_left; // output indices whose pooling windows cover current input element (can be out-of-bound) - Array out_idx{inds.begin(), inds.end()}; + Array out_idx{inds.begin(), inds.end()}; out_idx.Set(height_axis, (pad_h_idx / stride_height - windowh)); out_idx.Set(width_axis, (pad_w_idx / stride_width - windoww)); - Expr out_idx_lower_h = ir::SelectNode::make( + PrimExpr out_idx_lower_h = ir::SelectNode::make( pad_h_idx < kernel_height, make_const(DataType::Int(32), 0), (pad_h_idx - kernel_height) / stride_height + 1); - Expr out_idx_lower_w = ir::SelectNode::make( + PrimExpr out_idx_lower_w = ir::SelectNode::make( pad_w_idx < kernel_width, make_const(DataType::Int(32), 0), (pad_w_idx - kernel_width) / stride_width + 1); - Expr divide_factor; // number of pooled elements + PrimExpr divide_factor; // number of pooled elements if (count_include_pad) { divide_factor = kernel_height * kernel_width; } else { - Expr h_start = out_idx[height_axis] * stride_height - pad_top; - Expr w_start = out_idx[width_axis] * stride_width - pad_left; - Expr h_end = ir::MinNode::make(h_start + kernel_height, height); - Expr w_end = ir::MinNode::make(w_start + kernel_width, width); + PrimExpr h_start = out_idx[height_axis] * stride_height - pad_top; + PrimExpr w_start = out_idx[width_axis] * stride_width - pad_left; + PrimExpr h_end = ir::MinNode::make(h_start + kernel_height, height); + PrimExpr w_end = ir::MinNode::make(w_start + kernel_width, width); h_start = ir::MaxNode::make(h_start, make_const(DataType::Int(32), 0)); w_start = ir::MaxNode::make(w_start, make_const(DataType::Int(32), 0)); divide_factor = @@ -412,9 +415,9 @@ inline bool find_width(const std::string& layout, * \return The output tensor in the same layout */ inline Tensor pool(const Tensor& x, - const Array& kernel_size, - const Array& stride_size, - const Array& padding_size, + const Array& kernel_size, + const Array& stride_size, + const Array& padding_size, PoolType pool_type, bool ceil_mode, const std::string& layout = "NCHW", @@ -457,8 +460,8 @@ inline Tensor pool(const Tensor& x, * * \return The output tensor in the same layout */ -inline Tensor pool_grad(const Tensor& out_grad, const Tensor& x, const Array& kernel_size, - const Array& stride_size, const Array& padding_size, +inline Tensor pool_grad(const Tensor& out_grad, const Tensor& x, const Array& kernel_size, + const Array& stride_size, const Array& padding_size, PoolType pool_type, bool ceil_mode, const std::string& layout = "NCHW", bool count_include_pad = true) { int height_axis = -1, width_axis = -1; @@ -467,16 +470,16 @@ inline Tensor pool_grad(const Tensor& out_grad, const Tensor& x, const Array& output_size, + const Array& output_size, PoolType pool_type, const size_t height_axis, const size_t width_axis) { @@ -504,13 +507,13 @@ inline Tensor adaptive_pool_impl(const Tensor& x, auto out_height = cast(DataType::Int(32), output_size[0]); auto out_width = cast(DataType::Int(32), output_size[1]); - Array out_shape = x->shape; + Array out_shape = x->shape; out_shape.Set(height_axis, out_height); out_shape.Set(width_axis, out_width); if (pool_type == kMaxPool) { return tvm::compute(out_shape, [&](const Array& output) { - Array indices; + Array indices; for (const Var& var : output) indices.push_back(var); auto i_start_h = start_index(output[height_axis], out_height, height); auto i_end_h = end_index(output[height_axis], out_height, height); @@ -524,7 +527,7 @@ inline Tensor adaptive_pool_impl(const Tensor& x, }, "tensor", "adaptive_pool_max"); } else if (pool_type == kAvgPool) { auto pool_sum = tvm::compute(out_shape, [&](const Array& output) { - Array indices; + Array indices; for (const Var& var : output) indices.push_back(var); auto i_start_h = start_index(output[height_axis], out_height, height); auto i_end_h = end_index(output[height_axis], out_height, height); @@ -540,7 +543,7 @@ inline Tensor adaptive_pool_impl(const Tensor& x, }, "tensor", "adaptive_pool_sum"); return tvm::compute(out_shape, [&](const Array& output) { - Array indices; + Array indices; for (const Var& var : output) indices.push_back(var); auto i_start_h = start_index(output[height_axis], out_height, height); auto i_end_h = end_index(output[height_axis], out_height, height); @@ -583,7 +586,7 @@ inline Tensor adaptive_pool_impl(const Tensor& x, * \return The output tensor in same layout order */ inline Tensor adaptive_pool(const Tensor& x, - const Array& output_size, + const Array& output_size, PoolType pool_type, const std::string& layout = "NCHW") { int height_axis = -1, width_axis = -1; @@ -620,7 +623,7 @@ inline Tensor adaptive_pool(const Tensor& x, inline Tensor global_pool(const Tensor& x, PoolType pool_type, const std::string& layout = "NCHW") { - return adaptive_pool(x, Array{1, 1}, pool_type, layout); + return adaptive_pool(x, Array{1, 1}, pool_type, layout); } /*! @@ -639,9 +642,9 @@ inline Tensor global_pool(const Tensor& x, * \return The output tensor in same layout order */ inline Tensor pool_impl_nd(const Tensor& x, - const Array& kernel_size, - const Array& stride_size, - const Array& padding_size, + const Array& kernel_size, + const Array& stride_size, + const Array& padding_size, PoolType pool_type, bool ceil_mode, const std::vector& axis, @@ -654,13 +657,13 @@ inline Tensor pool_impl_nd(const Tensor& x, CHECK_EQ(axis.size(), k_size) << "axis must have same elements as kernel"; Array daxis; - std::vector kernel(k_size); - std::vector stride(k_size); - std::vector pad_head(k_size); - std::vector pad_tail(k_size); - Array pad_before(std::vector(x_size, 0)); - Array pad_after(std::vector(x_size, 0)); - Array out_shape = x->shape; + std::vector kernel(k_size); + std::vector stride(k_size); + std::vector pad_head(k_size); + std::vector pad_tail(k_size); + Array pad_before(std::vector(x_size, 0)); + Array pad_after(std::vector(x_size, 0)); + Array out_shape = x->shape; bool do_pad = false; for (int i = 0; i < k_size; i++) { @@ -694,7 +697,7 @@ inline Tensor pool_impl_nd(const Tensor& x, auto temp = do_pad ? pad( x, pad_before, pad_after, tvm::min_value(x->dtype), "pad_temp") : x; return tvm::compute(out_shape, [&](const Array& output) { - Array indices; + Array indices; for (const Var& var : output) indices.push_back(var); for (int i = 0; i < k_size; i++) { @@ -711,7 +714,7 @@ inline Tensor pool_impl_nd(const Tensor& x, // TVM compute for summing the pooling window. auto pool_sum = tvm::compute(out_shape, [&](const Array& output) { - Array indices; + Array indices; for (const Var& var : output) indices.push_back(var); for (int i = 0; i < k_size; i++) { @@ -724,7 +727,7 @@ inline Tensor pool_impl_nd(const Tensor& x, // TVM compute for dividing the reduced window sum by kernel size. return tvm::compute(out_shape, [&](const Array& output) { - Array indices; + Array indices; for (const Var& var : output) indices.push_back(var); if (count_include_pad) { auto kernel_size = make_const(DataType::Int(32), 1); @@ -733,8 +736,8 @@ inline Tensor pool_impl_nd(const Tensor& x, } return div(pool_sum(indices), kernel_size); } else { - std::vector start(k_size); - std::vector end(k_size); + std::vector start(k_size); + std::vector end(k_size); auto kernel_size = make_const(DataType::Int(32), 1); for (int i = 0; i < k_size; i++) { int ii = axis[i]; @@ -744,7 +747,7 @@ inline Tensor pool_impl_nd(const Tensor& x, kernel_size *= (end[i] - start[i]); } - Expr divide_factor = ir::MaxNode::make(kernel_size, make_const(DataType::Int(32), 1)); + PrimExpr divide_factor = ir::MaxNode::make(kernel_size, make_const(DataType::Int(32), 1)); return div(pool_sum(indices), divide_factor); } }, "tensor", kElementWise); @@ -784,9 +787,9 @@ inline Tensor pool_impl_nd(const Tensor& x, * \return The output tensor in the same layout */ inline Tensor pool1d(const Tensor& x, - const Array& kernel_size, - const Array& stride_size, - const Array& padding_size, + const Array& kernel_size, + const Array& stride_size, + const Array& padding_size, PoolType pool_type, bool ceil_mode, const std::string& layout = "NCW", @@ -830,9 +833,9 @@ inline Tensor pool1d(const Tensor& x, * \return The output tensor in the same layout */ inline Tensor pool3d(const Tensor& x, - const Array& kernel_size, - const Array& stride_size, - const Array& padding_size, + const Array& kernel_size, + const Array& stride_size, + const Array& padding_size, PoolType pool_type, bool ceil_mode, const std::string& layout = "NCDHW", diff --git a/topi/include/topi/nn/softmax.h b/topi/include/topi/nn/softmax.h index c3124bbe6f58..58ecc956964d 100644 --- a/topi/include/topi/nn/softmax.h +++ b/topi/include/topi/nn/softmax.h @@ -66,7 +66,7 @@ inline Tensor softmax(const Tensor &x, auto insert_reduce_index = [axis, ndim](const Array &indices, const IterVar &reduce_index) { - Array eval_range; + Array eval_range; int arg_counter = 0; for (size_t i = 0; i < ndim; ++i) { if (static_cast(i) == axis) @@ -78,7 +78,7 @@ inline Tensor softmax(const Tensor &x, }; auto get_non_reduce_indices = [axis, ndim](const Array &indices) { - Array non_reduce_indices; + Array non_reduce_indices; for (size_t i = 0; i < ndim; ++i) { if (static_cast(i) != axis) non_reduce_indices.push_back(indices[i]); @@ -135,8 +135,8 @@ inline Tensor log_softmax(const Tensor& x, std::string tag = "log_softmax_output") { CHECK_EQ(x->shape.size(), 2) << "Log softmax requires 2-D input"; - Expr m = x->shape[0]; - Expr n = x->shape[1]; + PrimExpr m = x->shape[0]; + PrimExpr n = x->shape[1]; auto k = tvm::reduce_axis(Range(0, n), "k"); auto max_elem = tvm::compute( diff --git a/topi/include/topi/nn/upsampling.h b/topi/include/topi/nn/upsampling.h index 3ca08549d62b..b6230c7c017c 100644 --- a/topi/include/topi/nn/upsampling.h +++ b/topi/include/topi/nn/upsampling.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -49,7 +49,7 @@ using namespace topi::image; * \return A Tensor upsampled to given shape */ inline Tensor upsampling(const Tensor& input, - const Array shape, + const Array shape, std::string layout = "NCHW", std::string mode = "NEAREST_NEIGHBOR", std::string name = "tensor", diff --git a/topi/include/topi/reduction.h b/topi/include/topi/reduction.h index 2d3d7d352fb7..ac843b1a7077 100644 --- a/topi/include/topi/reduction.h +++ b/topi/include/topi/reduction.h @@ -43,11 +43,11 @@ namespace topi { using namespace tvm; /*! \brief The operation to use for CommReduce */ -using FReduce = std::function& axis)>; +using FReduce = std::function& axis)>; /*! \brief The operation to use for CommReduceIdx */ using FCommReduce = std::function< - Array(Array exprs, const Array& axis, Expr* condition)>; + Array(Array exprs, const Array& axis, PrimExpr* condition)>; /*! * \brief Convert a reduction axis which could be empty or have negative @@ -97,12 +97,12 @@ inline Array MakeReduceAxes(const std::vector& real_axis, const Te } /*! \brief Calculate the target shape for a reduce op */ -inline Array MakeReduceTargetShape(const std::vector& real_axis, +inline Array MakeReduceTargetShape(const std::vector& real_axis, const Tensor& data, bool keepdims, bool atleast1d) { auto ndim = data->shape.size(); - Array target_shape; + Array target_shape; if (keepdims) { for (size_t i = 0; i < ndim; ++i) { if (std::find(real_axis.begin(), real_axis.end(), i) != real_axis.end()) { @@ -140,12 +140,12 @@ inline Array MakeReduceTargetShape(const std::vector& real_axis, */ inline Tensor DoCommReduce(const Tensor& data, FReduce func, - const Array& target_shape, + const Array& target_shape, const std::vector& reduce_axes, const std::vector& squeeze_axes) { auto r_axes = MakeReduceAxes(reduce_axes, data); auto compute = [&](const Array& indices) { - Array eval_range; + Array eval_range; Array eval_indices; int arg_counter = 0; int red_counter = 0; @@ -222,8 +222,8 @@ inline Tensor CommReduceIdx(const Tensor& data, auto compute = [ndim, keepdims, &real_axis, &reduce_axes, &func, &data] (const Array& indices) { - Array eval_range; - Array eval_indices; + Array eval_range; + Array eval_indices; int arg_counter = 0; int red_counter = 0; @@ -243,7 +243,7 @@ inline Tensor CommReduceIdx(const Tensor& data, } } - Array ravel_shape; + Array ravel_shape; for (auto i : real_axis) { ravel_shape.push_back(data->shape[i]); } @@ -263,10 +263,10 @@ inline Tensor CommReduceIdx(const Tensor& data, } /*! \brief A combiner function for a reduction */ -using FCombine = std::function(Array lhs, Array rhs)>; +using FCombine = std::function(Array lhs, Array rhs)>; /*! \brief An initializer function for a reduction */ -using FIdentity = std::function(std::vector types)>; +using FIdentity = std::function(std::vector types)>; /*! * \brief Create a commutative reducer for a reduction @@ -281,7 +281,7 @@ inline FCommReduce MakeCommReducer(FCombine fcombine, FIdentity fidentity, std::string name = "reduce") { return [fcombine, fidentity, name] - (Array exprs, const Array& axis, Expr* condition) { + (Array exprs, const Array& axis, PrimExpr* condition) { Array lhs, rhs; std::vector dtypes; @@ -297,7 +297,7 @@ inline FCommReduce MakeCommReducer(FCombine fcombine, auto cond = condition != nullptr ? *condition : tvm::const_true(); auto combiner = tvm::ir::CommReducerNode::make(lhs, rhs, result, id_elem); - Array outputs; + Array outputs; for (size_t i = 0; i < exprs.size(); ++i) { outputs.push_back( tvm::ir::ReduceNode::make(combiner, exprs, axis, cond, static_cast(i))); @@ -307,17 +307,17 @@ inline FCommReduce MakeCommReducer(FCombine fcombine, } /*! \brief Wrap tvm::min to ensure we get the correct overload */ -inline Expr MinOp(Expr source, Array axis) { +inline PrimExpr MinOp(PrimExpr source, Array axis) { return tvm::min(source, axis); } /*! \brief Wrap tvm::max to ensure we get the correct overload */ -inline Expr MaxOp(Expr source, Array axis) { +inline PrimExpr MaxOp(PrimExpr source, Array axis) { return tvm::max(source, axis); // NOLINT(*) } /*! \brief Wrap tvm::prod to ensure we get the correct overload */ -inline Expr ProdOp(Expr source, Array axis) { +inline PrimExpr ProdOp(PrimExpr source, Array axis) { return tvm::prod(source, axis); // NOLINT(*) } @@ -341,7 +341,7 @@ inline Tensor sum(const Tensor& data, return CommReduce(data, axis, tvm::sum, keepdims, atleast1d); } -inline Tensor collapse_sum(const Tensor& data, Array target_shape) { +inline Tensor collapse_sum(const Tensor& data, Array target_shape) { CHECK_GE(data->shape.size(), target_shape.size()); auto ishape = detail::GetConstIntValues(data->shape, "ishape"); auto oshape = detail::GetConstIntValues(target_shape, "oshape"); @@ -472,13 +472,13 @@ inline Tensor argmin(const Tensor& data, bool keepdims = false, bool atleast1d = false) { auto fcombine = [](Array lhs, Array rhs) { - Array result; + Array result; result.push_back(tvm::ir::SelectNode::make(lhs[1] <= rhs[1], lhs[0], rhs[0])); // idx result.push_back(tvm::ir::SelectNode::make(lhs[1] <= rhs[1], lhs[1], rhs[1])); // val return result; }; auto fidentity = [](std::vector types) { - Array result; + Array result; result.push_back(tvm::make_const(types[0], -1)); // idx result.push_back(tvm::max_value(types[1])); // val return result; @@ -489,13 +489,13 @@ inline Tensor argmin(const Tensor& data, inline FCommReduce MakeArgmaxReducer() { auto fcombine = [](Array lhs, Array rhs) { - Array result; + Array result; result.push_back(tvm::ir::SelectNode::make(lhs[1] >= rhs[1], lhs[0], rhs[0])); // idx result.push_back(tvm::ir::SelectNode::make(lhs[1] >= rhs[1], lhs[1], rhs[1])); // val return result; }; auto fidentity = [](std::vector types) { - Array result; + Array result; result.push_back(tvm::make_const(types[0], -1)); // idx result.push_back(tvm::min_value(types[1])); // val return result; diff --git a/topi/include/topi/transform.h b/topi/include/topi/transform.h index 00106c10af91..66e2773ded7e 100644 --- a/topi/include/topi/transform.h +++ b/topi/include/topi/transform.h @@ -72,7 +72,7 @@ inline Tensor expand_dims(const Tensor& x, // Calculate offset from last dimension axis = ndim + axis + 1; } - Array new_shape; + Array new_shape; for (size_t i = 0; i < static_cast(axis); ++i) { new_shape.push_back(x->shape[i]); } @@ -85,7 +85,7 @@ inline Tensor expand_dims(const Tensor& x, return compute( new_shape, [&](const Array& indices) { - Array idx; + Array idx; for (size_t i = 0; i < static_cast(axis); ++i) { idx.push_back(indices[i]); } @@ -118,7 +118,7 @@ inline Tensor transpose(const Tensor& x, } } - Array new_shape; + Array new_shape; for (size_t i = 0; i < axes.size(); ++i) { int axis = static_cast(axes[i]->value); int new_axis = axis; @@ -140,7 +140,7 @@ inline Tensor transpose(const Tensor& x, return compute( new_shape, [&](const Array& indices) { - std::vector idx; + std::vector idx; for (size_t i = 0; i < axes.size(); ++i) { idx.push_back(1); } @@ -181,7 +181,7 @@ inline Tensor flip(const Tensor& x, // Reverse the Input Tensor in the axis specified return compute( x->shape, [&](const Array& indices) { - Array real_indices; + Array real_indices; for (size_t i = 0; i < src_tensor_dim; ++i) { if (i == static_cast(axis)) { real_indices.push_back(x->shape[i] - indices[i] - 1); @@ -204,11 +204,11 @@ inline Tensor flip(const Tensor& x, * \return A Tensor whose op member is the reshape operation */ inline Tensor reshape(const Tensor& x, - Array newshape, + Array newshape, std::string name = "T_reshape", std::string tag = kInjective) { auto x_shape = x->shape; - Array target_shape; + Array target_shape; for (const auto &ele : newshape) { if (ele.as()) { @@ -226,7 +226,7 @@ inline Tensor reshape(const Tensor& x, return compute( target_shape, [&](const Array& indices) { return x(UnravelIndex( - RavelIndex(Array{indices.begin(), indices.end()}, target_shape), + RavelIndex(Array{indices.begin(), indices.end()}, target_shape), x_shape)); }, name, tag); } @@ -272,7 +272,7 @@ inline Tensor squeeze(const Tensor& x, std::unordered_set axis_set(axis_val.begin(), axis_val.end()); - Array out_shape; + Array out_shape; for (size_t i = 0; i < ndim; ++i) { if (axis_set.count(static_cast(i)) == 0) { out_shape.push_back(x->shape[i]); @@ -284,7 +284,7 @@ inline Tensor squeeze(const Tensor& x, return compute( out_shape, [&](const Array& indices) { - Array real_indices; + Array real_indices; int flag = 0; for (size_t i = 0; i < ndim; ++i) { if (axis_set.count(static_cast(i)) == 0) { @@ -323,17 +323,17 @@ inline Tensor concatenate(const Array& inputs, CHECK_LT(axis, inputs[0]->shape.size()) << "axis out of bounds"; - Array axis_sizes; + Array axis_sizes; for (auto t : inputs) { axis_sizes.push_back(t->shape[axis]); } - Expr join_size = axis_sizes[0]; + PrimExpr join_size = axis_sizes[0]; for (size_t i = 1; i < axis_sizes.size(); ++i) { join_size += axis_sizes[i]; } join_size = tvm::ir::Simplify(join_size); - Array out_shape; + Array out_shape; for (size_t i = 0; i < inputs[0]->shape.size(); ++i) { out_shape.push_back(i == static_cast(axis) ? join_size : inputs[0]->shape[i]); } @@ -345,7 +345,7 @@ inline Tensor concatenate(const Array& inputs, for (size_t i = 0; i < inputs.size() - 1; ++i) { ind -= axis_sizes[i]; - Array idx; + Array idx; for (size_t i = 0; i < static_cast(axis); ++i) { idx.push_back(indices[i]); } @@ -388,7 +388,7 @@ inline Tensor stack(const Array& inputs, "axis out of bounds"; const int stack_size = static_cast(inputs.size()); - Array out_shape; + Array out_shape; for (size_t i = 0; i < static_cast(axis); ++i) out_shape.push_back(inputs[0]->shape[i]); out_shape.push_back(stack_size); @@ -397,7 +397,7 @@ inline Tensor stack(const Array& inputs, return compute( out_shape, [&](const Array& indices) { - Array idx; + Array idx; for (size_t i = 0; i < indices.size(); ++i) if (i != static_cast(axis)) idx.push_back(indices[i]); @@ -445,7 +445,7 @@ inline Array split(const Tensor& x, begin_ids.push_back(val); } - Array< Array > out_shapes; + Array< Array > out_shapes; for (size_t i = 0; i < begin_ids.size(); ++i) { int out_axis_size; if (i == begin_ids.size() - 1) { @@ -454,7 +454,7 @@ inline Array split(const Tensor& x, out_axis_size = begin_ids[i + 1] - begin_ids[i]; } - Array shape; + Array shape; for (size_t i = 0; i < static_cast(axis); ++i) { shape.push_back(x->shape[i]); } @@ -472,7 +472,7 @@ inline Array split(const Tensor& x, compute( out_shapes[i], [&](const Array& indices) { auto begin = begin_ids[i]; - Array real_indices; + Array real_indices; for (size_t j = 0; j < static_cast(axis); ++j) { real_indices.push_back(indices[j]); } @@ -547,9 +547,9 @@ inline Tensor strided_slice(const Tensor& x, end_vec.push_back(stride_vec[i] < 0 ? 0 : max_range); } // Compute - Array out_shape; - Array begin_expr; - Array strides_expr; + Array out_shape; + Array begin_expr; + Array strides_expr; for (size_t i = 0; i < src_tensor_dim; ++i) { int64_t begin_range = stride_vec[i] < 0 ? -1 : 0; @@ -581,7 +581,7 @@ inline Tensor strided_slice(const Tensor& x, return compute( out_shape, [&](const Array& indices) { - Array real_indices; + Array real_indices; for (size_t i = 0; i < src_tensor_dim; ++i) { real_indices.push_back(indices[i] * strides_expr[i] + begin_expr[i]); } @@ -647,9 +647,9 @@ inline Tensor take(const Tensor& a, std::string mode = "clip", std::string name = "T_take", std::string tag = kInjective) { - Array a_shape = a->shape; - Array out_shape = indices->shape; - Expr a_size = 1; + Array a_shape = a->shape; + Array out_shape = indices->shape; + PrimExpr a_size = 1; for (size_t i = 0; i < a_shape.size(); ++i) { a_size = a_size * a_shape[i]; } @@ -699,15 +699,16 @@ inline Tensor sequence_mask(const Tensor& data, CHECK_EQ(valid_length->shape.size(), 1) << "valid_length must have ndim=1, i.e., (batch_size,)."; auto length_dim = data->shape[axis]; auto batch_dim = data->shape[1 - axis]; - Array out_shape = data->shape; + Array out_shape = data->shape; Tensor out = compute( out_shape, [&](const Array& out_index) { - Array len_index; + Array len_index; auto tid = out_index[axis]; auto bid = out_index[1 - axis]; len_index.push_back(bid); - Expr ret = tvm::if_then_else(tvm::cast(valid_length->dtype, tid) >= valid_length(len_index), - tvm::make_const(data->dtype, mask_value), data(out_index)); + PrimExpr ret = tvm::if_then_else( + tvm::cast(valid_length->dtype, tid) >= valid_length(len_index), + tvm::make_const(data->dtype, mask_value), data(out_index)); return ret; }, name, tag); return out; @@ -740,7 +741,7 @@ inline Tensor take(const Tensor& a, auto axis_dim = a->shape[axis]; int indices_len = static_cast(indices->shape.size()); - Array out_shape; + Array out_shape; for (size_t i = 0; i < a->shape.size(); ++i) { if (axis == static_cast(i)) { for (size_t j = 0; j < indices->shape.size(); ++j) { @@ -753,11 +754,11 @@ inline Tensor take(const Tensor& a, if (mode == "clip") { return compute( out_shape, [&](const Array& out_index) { - Array indices_position; + Array indices_position; for (size_t j = axis; j < static_cast(axis+indices_len); ++j) { indices_position.push_back(out_index[j]); } - Array real_indices; + Array real_indices; for (size_t j = 0; j < static_cast(axis); ++j) { real_indices.push_back(out_index[j]); } @@ -774,11 +775,11 @@ inline Tensor take(const Tensor& a, "Make sure input indices are in bound"; return compute( out_shape, [&](const Array& out_index) { - Array indices_position; + Array indices_position; for (size_t j = axis; j < static_cast(axis+indices_len); ++j) { indices_position.push_back(out_index[j]); } - Array real_indices; + Array real_indices; for (size_t j = 0; j < static_cast(axis); ++j) { real_indices.push_back(out_index[j]); } @@ -791,11 +792,11 @@ inline Tensor take(const Tensor& a, } else { // mode == "wrap" return compute( out_shape, [&](const Array& out_index) { - Array indices_position; + Array indices_position; for (size_t j = axis; j < static_cast(axis+indices_len); ++j) { indices_position.push_back(out_index[j]); } - Array real_indices; + Array real_indices; for (size_t j = 0; j < static_cast(axis); ++j) { real_indices.push_back(out_index[j]); } @@ -830,7 +831,7 @@ inline Tensor where(const Tensor& condition, << x->shape.size() << " vs " << y->shape.size(); CHECK_EQ(x->dtype, y->dtype) << "x and y must have the same dtype: " << x->dtype << " vs " << y->dtype; - Array oshape = x->shape; + Array oshape = x->shape; Tensor out; if (condition->shape.size() != 1) { @@ -848,7 +849,7 @@ inline Tensor where(const Tensor& condition, << condition->shape[0] << " vs " << x->shape[0]; out = compute( oshape, [&](const Array& indices) { - Array condition_idx{indices[0]}; + Array condition_idx{indices[0]}; return tvm::ir::SelectNode::make(condition(condition_idx) != 0, x(indices), y(indices)); }, name, tag); @@ -885,7 +886,7 @@ inline Tensor repeat(const Tensor& x, // Calculate offset from last dimension axis += ndim; } - Array new_shape; + Array new_shape; for (size_t i = 0; i < static_cast(axis); ++i) { new_shape.push_back(x->shape[i]); } @@ -896,7 +897,7 @@ inline Tensor repeat(const Tensor& x, return compute( new_shape, [&](const Array& indices) { - Array idx; + Array idx; for (size_t i = 0; i < static_cast(axis); ++i) { idx.push_back(indices[i]); } @@ -925,9 +926,9 @@ inline Tensor tile(const Tensor& x, size_t ndim = x->shape.size(); size_t rdim = reps.size(); size_t tdim = (ndim > rdim) ? ndim : rdim; - Array data_shape; - Array reps_shape; - Array new_shape; + Array data_shape; + Array reps_shape; + Array new_shape; if (ndim == rdim) { for (size_t i = 0; i < ndim; ++i) { data_shape.push_back(x->shape[i]); @@ -958,7 +959,7 @@ inline Tensor tile(const Tensor& x, } else { return compute( new_shape, [&](const Array& indices) { - Array idx; + Array idx; if (ndim >= rdim) { for (size_t i = 0; i < ndim; ++i) idx.push_back(indexmod(indices[i], x->shape[i])); @@ -991,7 +992,7 @@ inline Tensor gather_nd(const Tensor& data, size_t indices_dim0 = static_cast(GetConstInt(indices->shape[0])); CHECK_LE(indices_dim0, ndim_d) << "dim 0 of indices tensor must be no more " << "than dimensions of data tensor"; - Array out_shape; + Array out_shape; for (size_t i = 1; i < ndim_i; ++i) { out_shape.push_back(indices->shape[i]); } @@ -1003,12 +1004,12 @@ inline Tensor gather_nd(const Tensor& data, } return compute( out_shape, [&](const Array& out_index) { - Array indices_position; + Array indices_position; indices_position.push_back(0); for (size_t i = 0; i < ndim_i - 1; ++i) { indices_position.push_back(out_index[i]); } - Array real_indices; + Array real_indices; for (size_t i = 0; i < indices_dim0; ++i) { indices_position.Set(0, make_const(DataType::Int(32), i)); if (indices->dtype.is_int()) { @@ -1046,7 +1047,7 @@ inline tvm::Tensor matmul(const tvm::Tensor& A, bool trans_b = false, std::string name = "T_matmul", std::string tag = kMatMul) { - tvm::Array output_shape{A->shape[trans_a ? 1 : 0], + tvm::Array output_shape{A->shape[trans_a ? 1 : 0], B->shape[trans_b ? 0 : 1]}; auto k = tvm::reduce_axis(tvm::Range{0, A->shape[trans_a ? 0 : 1]}, "k"); auto l = [&](tvm::Var i, tvm::Var j) { @@ -1075,7 +1076,7 @@ inline Tensor tensordot(const Tensor& A, CHECK_GE(A->shape.size(), axes); CHECK_GE(B->shape.size(), axes); - Array output_shape(A->shape.begin(), A->shape.end() + (-axes)); + Array output_shape(A->shape.begin(), A->shape.end() + (-axes)); for (auto it = B->shape.begin() + axes; it != B->shape.end(); ++it) output_shape.push_back(*it); @@ -1086,13 +1087,13 @@ inline Tensor tensordot(const Tensor& A, auto func = [&A, &B, &iter_vars, axes] (const Array& input_indices) { - Array A_indices( + Array A_indices( input_indices.begin(), input_indices.begin() + (A->shape.size() - axes)); for (auto& v : iter_vars) A_indices.push_back(v); - Array B_indices; + Array B_indices; for (auto& v : iter_vars) B_indices.push_back(v); @@ -1124,8 +1125,8 @@ inline Tensor tensordot(const Tensor& A, */ inline Tensor tensordot(const Tensor& A, const tvm::Tensor& B, - Array A_axes, - Array B_axes, + Array A_axes, + Array B_axes, std::string name = "T_tensordot", std::string tag = kMatMul) { CHECK_EQ(A_axes.size(), B_axes.size()); @@ -1133,7 +1134,7 @@ inline Tensor tensordot(const Tensor& A, auto A_axes_val = GetConstIntValues(A_axes, "A_axes"); auto B_axes_val = GetConstIntValues(B_axes, "B_axes"); - Array output_shape; + Array output_shape; for (unsigned i = 0; i < A->shape.size(); ++i) if (std::find(A_axes_val.begin(), A_axes_val.end(), i) == A_axes_val.end()) output_shape.push_back(A->shape[i]); @@ -1149,7 +1150,7 @@ inline Tensor tensordot(const Tensor& A, [&A, &B, &iter_vars, A_axes_val, B_axes_val] (const Array& input_indices) { int idx_input = 0; - Array A_indices; + Array A_indices; for (unsigned i = 0; i < A->shape.size(); ++i) { auto axes_pos = std::find(A_axes_val.begin(), A_axes_val.end(), i); if (axes_pos == A_axes_val.end()) @@ -1158,7 +1159,7 @@ inline Tensor tensordot(const Tensor& A, A_indices.push_back(iter_vars[axes_pos - A_axes_val.begin()]); } - Array B_indices; + Array B_indices; for (unsigned i = 0; i < B->shape.size(); ++i) { auto axes_pos = std::find(B_axes_val.begin(), B_axes_val.end(), i); if (axes_pos == B_axes_val.end()) @@ -1171,15 +1172,15 @@ inline Tensor tensordot(const Tensor& A, return compute(output_shape, func, name, tag); } -inline Tensor arange(const Expr& start, - const Expr& stop, - const Expr& step, +inline Tensor arange(const PrimExpr& start, + const PrimExpr& stop, + const PrimExpr& step, DataType dtype, std::string name = "T_arange", std::string tag = kInjective) { - Expr num_elem = tvm::cast(tvm::DataType::Int(32), tvm::ceil( + PrimExpr num_elem = tvm::cast(tvm::DataType::Int(32), tvm::ceil( tvm::cast(tvm::DataType::Float(32), stop - start) / step)); - Array shape; + Array shape; return compute({num_elem}, [&](const Array& indices) { return tvm::cast(dtype, start + step * indices[0]); }, name, tag); @@ -1213,12 +1214,12 @@ inline Tensor layout_transform(const Tensor& src, CHECK(layout_converter.defined()) << "cannot convert from " << src_layout << " to " << dst_layout; - Array dst_shape = layout_converter.ForwardShape(src->shape); + Array dst_shape = layout_converter.ForwardShape(src->shape); return compute( dst_shape, [&](const Array& dst_indices) { - Array dst_indices_expr(dst_indices.begin(), dst_indices.end()); - Array src_indices = layout_converter.BackwardIndex(dst_indices_expr); + Array dst_indices_expr(dst_indices.begin(), dst_indices.end()); + Array src_indices = layout_converter.BackwardIndex(dst_indices_expr); return src(src_indices); }, name, tag); } @@ -1236,10 +1237,10 @@ inline Tensor shape(const Tensor& src, const std::string name = "T_shape", const std::string tag = kInjective) { int ndim = static_cast(src->shape.size()); - Array out_shape{ndim}; + Array out_shape{ndim}; return compute(out_shape, [&](const Array& indices) { auto idx = indices[0]; - Expr ret = 0; + PrimExpr ret = 0; for (int i = 0; i < ndim; ++i) { ret = tvm::if_then_else(idx == i, src->shape[i], ret); } @@ -1260,9 +1261,9 @@ inline Tensor ndarray_size(const Tensor& src, const std::string& name = "ndarray_size", const std::string& tag = kInjective) { int ndim = static_cast(src->shape.size()); - Array out_ndarray_size = {1}; + Array out_ndarray_size = {1}; return compute(out_ndarray_size, [&](const Array& indices) { - Expr ret = 1; + PrimExpr ret = 1; for (int i = 0; i < ndim; ++i) { ret *= src->shape[i]; } @@ -1284,14 +1285,14 @@ inline Tensor ndarray_size(const Tensor& src, * \return one-hot tensor. */ inline Tensor one_hot(const Tensor& indices, - const Expr on_value, - const Expr off_value, + const PrimExpr on_value, + const PrimExpr off_value, int depth, int axis, const DataType& dtype, const std::string name = "T_one_hot", const std::string tag = kInjective) { - Array oshape; + Array oshape; int ndim = indices->shape.size() + 1; int indices_index = 0; int true_axis = (axis == -1) ? indices->shape.size() : axis; @@ -1303,8 +1304,8 @@ inline Tensor one_hot(const Tensor& indices, } } - Expr on_value_cast = cast(dtype, on_value); - Expr off_value_cast = cast(dtype, off_value); + PrimExpr on_value_cast = cast(dtype, on_value); + PrimExpr off_value_cast = cast(dtype, off_value); return compute(oshape, [&](const Array& iter_vars) { Array indices_indices; for (size_t i = 0; i < iter_vars.size(); i++) { diff --git a/topi/include/topi/vision/reorg.h b/topi/include/topi/vision/reorg.h index 08dd13605b40..df3fadef7d75 100644 --- a/topi/include/topi/vision/reorg.h +++ b/topi/include/topi/vision/reorg.h @@ -74,7 +74,7 @@ inline Tensor reorg(const Tensor &data, int out_h = h_in / stride; int out_w = w_in / stride; - Array out_shape = {batch, out_c, out_h, out_w}; + Array out_shape = {batch, out_c, out_h, out_w}; return reshape(out, out_shape); } } // namespace vision diff --git a/topi/python/topi/nn/pad.py b/topi/python/topi/nn/pad.py index acffcb533d39..13f8e720288b 100644 --- a/topi/python/topi/nn/pad.py +++ b/topi/python/topi/nn/pad.py @@ -57,7 +57,7 @@ def pad(data, pad_before, pad_after=None, pad_value=0.0, name="PadInput"): out_shape = tuple( tvm.ir_pass.Simplify( (data.shape[i] + pad_before[i] + pad_after[i])) for i in range(n)) - pad_value = (pad_value if isinstance(pad_value, tvm.expr.Expr) + pad_value = (pad_value if isinstance(pad_value, tvm.expr.PrimExpr) else tvm.const(pad_value, data.dtype)) def _pad(*indices): not_zero = [] diff --git a/topi/python/topi/util.py b/topi/python/topi/util.py index 079dda5d0b0e..8f32a297d719 100644 --- a/topi/python/topi/util.py +++ b/topi/python/topi/util.py @@ -198,7 +198,7 @@ def simplify(expr): out : Expr or int The simplified output """ - return tvm.ir_pass.Simplify(expr) if isinstance(expr, tvm.expr.Expr) else expr + return tvm.ir_pass.Simplify(expr) if isinstance(expr, tvm.expr.PrimExpr) else expr def ravel_index(indices, shape): diff --git a/topi/src/topi.cc b/topi/src/topi.cc index eeb1249ee7f2..e9c9bc07b3cc 100644 --- a/topi/src/topi.cc +++ b/topi/src/topi.cc @@ -112,14 +112,14 @@ TVM_REGISTER_GLOBAL("topi.TEST_create_target") *rv = Op(args[0].operator tvm::Tensor(), \ args[1].operator tvm::Tensor()); \ } else if (!lhs_is_tensor && rhs_is_tensor) { \ - *rv = Op(args[0].operator tvm::Expr(), \ + *rv = Op(args[0].operator tvm::PrimExpr(), \ args[1].operator tvm::Tensor()); \ } else if (lhs_is_tensor && !rhs_is_tensor) { \ *rv = Op(args[0].operator tvm::Tensor(), \ - args[1].operator tvm::Expr()); \ + args[1].operator tvm::PrimExpr()); \ } else if (!lhs_is_tensor && !rhs_is_tensor) { \ - *rv = Op(args[0].operator tvm::Expr(), \ - args[1].operator tvm::Expr()); \ + *rv = Op(args[0].operator tvm::PrimExpr(), \ + args[1].operator tvm::PrimExpr()); \ } \ }); \ @@ -433,7 +433,7 @@ TVM_REGISTER_GLOBAL("topi.tensordot") } else if (args.size() == 3) { *rv = tensordot(args[0], args[1], args[2]); } else { - Array axes = args[3]; + Array axes = args[3]; *rv = tensordot(args[0], args[1], args[2], axes); } }); diff --git a/topi/tests/python/test_topi_broadcast.py b/topi/tests/python/test_topi_broadcast.py index 8cf5efe9457b..4361f8fb675c 100644 --- a/topi/tests/python/test_topi_broadcast.py +++ b/topi/tests/python/test_topi_broadcast.py @@ -58,8 +58,8 @@ def verify_broadcast_binary_ele(lhs_shape, rhs_shape, B = (tvm.var("B", dtype=dtype) if rhs_shape is None else tvm.placeholder(shape=rhs_shape, name="B", dtype=dtype)) C = ftopi(A, B) - if isinstance(A, tvm.expr.Expr) and isinstance(B, tvm.expr.Expr): - assert(isinstance(C, tvm.expr.Expr)) + if isinstance(A, tvm.expr.PrimExpr) and isinstance(B, tvm.expr.PrimExpr): + assert(isinstance(C, tvm.expr.PrimExpr)) return def gen_operand(shape, low, high, ctx): @@ -241,8 +241,8 @@ def test_apply( # Build the logic and compile the function A = tvm.placeholder(shape=indata.shape, name="A", dtype=dtype) B = func(A) - if isinstance(A, tvm.expr.Expr): - assert (isinstance(B, tvm.expr.Expr)) + if isinstance(A, tvm.expr.PrimExpr): + assert (isinstance(B, tvm.expr.PrimExpr)) return def check_device(device): @@ -283,8 +283,8 @@ def test_apply( A = (tvm.var("A", dtype=dtype)) B = (tvm.var("B", dtype=dtype)) C = func(A, B) - if isinstance(A, tvm.expr.Expr) and isinstance(B, tvm.expr.Expr): - assert (isinstance(C, tvm.expr.Expr)) + if isinstance(A, tvm.expr.PrimExpr) and isinstance(B, tvm.expr.PrimExpr): + assert (isinstance(C, tvm.expr.PrimExpr)) return def check_device(device):