Skip to content

Commit

Permalink
[REFACTOR][IR] Add Node suffix to low-level IR nodes (apache#4649)
Browse files Browse the repository at this point in the history
* [REFACTOR][IR] Variable -> VarNode

* [REFACTOR][IR] Add/Sub/Mul/Div -> AddNode/SubNode etc.

* [REFACTOR][IR] Min/Max/FloorDiv/FloorMod -> MinNode/MaxNode etc.

* [REFACTOR][IR] EQ/NE/LT/LE/GT/GE/Select -> EQNode/NENode etc.

* [REFACTOR][IR] Add Node suffix to Select/Call/Load/Ramp/Shuffle/Let

* [REFACTOR][IR] Add node suffix to IntImm/UIntImm/FloatImm/StringImm

* [REFACTOR][IR] Add Node suffix to Any, AttrStmt, AssertStmt

* [REFACTOR][IR] Add Node suffix to Store/Provide/Allocate/Free

* [REFACTOR][IR] Add Node suffix to ProducerConsumer

* Fix lint

* style updates, test fixes
  • Loading branch information
tqchen authored and zhiics committed Mar 2, 2020
1 parent 17a985c commit fbd1dde
Show file tree
Hide file tree
Showing 194 changed files with 3,961 additions and 3,919 deletions.
12 changes: 6 additions & 6 deletions include/tvm/arithmetic.h
Original file line number Diff line number Diff line change
Expand Up @@ -564,7 +564,7 @@ IntSet EvalSet(Expr e,
* \return An integer set that can cover all the possible values of e.
*/
IntSet EvalSet(Expr e,
const std::unordered_map<const Variable*, IntSet>& dom_map);
const std::unordered_map<const VarNode*, IntSet>& dom_map);

/*!
* \brief Find an symbolic integer set that contains is union over
Expand All @@ -586,7 +586,7 @@ IntSet EvalSet(Range r,
* \return An integer set that can cover all the possible values.
*/
IntSet EvalSet(IntSet s,
const std::unordered_map<const Variable*, IntSet>& dom_map);
const std::unordered_map<const VarNode*, IntSet>& dom_map);
/*!
* \brief Same as EvalSet, but takes unordered_map
*
Expand All @@ -595,7 +595,7 @@ IntSet EvalSet(IntSet s,
* \return An integer set that can cover all the possible values of e.
*/
IntSet EvalSet(Range r,
const std::unordered_map<const Variable*, IntSet>& dom_map);
const std::unordered_map<const VarNode*, IntSet>& dom_map);

/*! \brief Map from Expr to IntSet */
using ExprIntSetMap = std::unordered_map<Expr, IntSet, ObjectHash, ObjectEqual>;
Expand All @@ -609,7 +609,7 @@ using ExprIntSetMap = std::unordered_map<Expr, IntSet, ObjectHash, ObjectEqual>;
*/
ExprIntSetMap EvalSetForEachSubExpr(
Expr e,
const std::unordered_map<const Variable*, IntSet>& dom_map);
const std::unordered_map<const VarNode*, IntSet>& dom_map);

/*!
* \brief Create an union set of all sets
Expand Down Expand Up @@ -654,8 +654,8 @@ IntSet DeduceBound(Expr v, Expr cond,
* \return An integer set that always satisfies the condition.
*/
IntSet DeduceBound(Expr v, Expr cond,
const std::unordered_map<const Variable*, IntSet>& hint_map,
const std::unordered_map<const Variable*, IntSet>& relax_map);
const std::unordered_map<const VarNode*, IntSet>& hint_map,
const std::unordered_map<const VarNode*, IntSet>& relax_map);

/*!
* \brief Infer a regular domain that covers all the calls or provides within the given statement.
Expand Down
12 changes: 6 additions & 6 deletions include/tvm/attrs.h
Original file line number Diff line number Diff line change
Expand Up @@ -488,9 +488,9 @@ inline void SetIntValue(T* ptr, const TVMArgValue& val) {
} else {
Expr expr = val;
CHECK(expr.defined());
if (const ir::IntImm* op = expr.as<ir::IntImm>()) {
if (const ir::IntImmNode* op = expr.as<ir::IntImmNode>()) {
*ptr = static_cast<T>(op->value);
} else if (const ir::UIntImm* op = expr.as<ir::UIntImm>()) {
} else if (const ir::UIntImmNode* op = expr.as<ir::UIntImmNode>()) {
*ptr = static_cast<T>(op->value);
} else {
LOG(FATAL) << "Expect int value, but get " << expr->GetTypeKey();
Expand All @@ -503,7 +503,7 @@ inline void SetValue<std::string>(std::string* ptr, const TVMArgValue& val) {
*ptr = val.operator std::string();
} else {
Expr expr = val;
const ir::StringImm* op = expr.as<ir::StringImm>();
const ir::StringImmNode* op = expr.as<ir::StringImmNode>();
CHECK(op != nullptr);
*ptr = op->value;
}
Expand All @@ -519,11 +519,11 @@ inline void SetValue<double>(double* ptr, const TVMArgValue& val) {
} else {
Expr expr = val;
CHECK(expr.defined());
if (const ir::IntImm* op = expr.as<ir::IntImm>()) {
if (const ir::IntImmNode* op = expr.as<ir::IntImmNode>()) {
*ptr = static_cast<double>(op->value);
} else if (const ir::IntImm* op = expr.as<ir::IntImm>()) {
} else if (const ir::IntImmNode* op = expr.as<ir::IntImmNode>()) {
*ptr = static_cast<double>(op->value);
} else if (const ir::UIntImm* op = expr.as<ir::UIntImm>()) {
} else if (const ir::UIntImmNode* op = expr.as<ir::UIntImmNode>()) {
*ptr = static_cast<double>(op->value);
} else {
LOG(FATAL) << "Expect float value, but get " << expr->GetTypeKey();
Expand Down
22 changes: 11 additions & 11 deletions include/tvm/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ class Var;
* - Let
* - LetStmt
*/
class Variable : public ExprNode {
class VarNode : public ExprNode {
public:
/*!
* \brief The hint to the variable name.
Expand All @@ -118,7 +118,7 @@ class Variable : public ExprNode {
}

static constexpr const char* _type_key = "Variable";
TVM_DECLARE_FINAL_OBJECT_INFO(Variable, ExprNode);
TVM_DECLARE_FINAL_OBJECT_INFO(VarNode, ExprNode);
};

/*! \brief a named variable in TVM */
Expand All @@ -139,18 +139,18 @@ class Var : public Expr {
* \brief Get pointer to the internal value.
* \return the corresponding Variable.
*/
const Variable* operator->() const {
const VarNode* operator->() const {
return get();
}
/*!
* \brief Get pointer to the internal value.
* \return the corresponding Variable.
*/
const Variable* get() const {
return static_cast<const Variable*>(data_.get());
const VarNode* get() const {
return static_cast<const VarNode*>(data_.get());
}
/*! \brief type indicate the container type */
using ContainerType = Variable;
using ContainerType = VarNode;
};

// Backward compatibility, will be removed later.
Expand All @@ -161,7 +161,7 @@ using ExprEqual = ObjectEqual;

class Integer;
/*! \brief ExprNode: constant integer. */
class IntImm : public ExprNode {
class IntImmNode : public ExprNode {
public:
/*! \brief the Internal value. */
int64_t value;
Expand All @@ -174,7 +174,7 @@ class IntImm : public ExprNode {
TVM_DLL static Integer make(DataType t, int64_t value);

static constexpr const char* _type_key = "IntImm";
TVM_DECLARE_FINAL_OBJECT_INFO(IntImm, ExprNode);
TVM_DECLARE_FINAL_OBJECT_INFO(IntImmNode, ExprNode);
};

/*!
Expand Down Expand Up @@ -206,8 +206,8 @@ class Integer : public Expr {
* \brief Get pointer to the internal value.
* \return the content of the integer.
*/
const IntImm* operator->() const {
return static_cast<const IntImm*>(get());
const IntImmNode* operator->() const {
return static_cast<const IntImmNode*>(get());
}
/*!
* \brief convert to int64_t
Expand All @@ -218,7 +218,7 @@ class Integer : public Expr {
return (*this)->value;
}
/*! \brief type indicate the container type */
using ContainerType = IntImm;
using ContainerType = IntImmNode;
};

/*! \brief range over one dimension */
Expand Down
40 changes: 20 additions & 20 deletions include/tvm/expr_operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ inline Expr const_false(int lanes = 1) {
*/
inline const int64_t* as_const_int(const Expr& x) {
if (!x.defined()) return nullptr;
if (const ir::IntImm* op = x.as<ir::IntImm>()) {
if (const ir::IntImmNode* op = x.as<ir::IntImmNode>()) {
return &(op->value);
} else {
return nullptr;
Expand All @@ -90,7 +90,7 @@ inline const int64_t* as_const_int(const Expr& x) {
*/
inline const uint64_t* as_const_uint(const Expr& x) {
if (!x.defined()) return nullptr;
if (const ir::UIntImm* op = x.as<ir::UIntImm>()) {
if (const ir::UIntImmNode* op = x.as<ir::UIntImmNode>()) {
return &(op->value);
} else {
return nullptr;
Expand Down Expand Up @@ -600,7 +600,7 @@ TVM_DLL Expr trunc(Expr x);
// Intrinsic operators
#define TVM_DECLARE_INTRIN_UNARY(OpName) \
inline Expr OpName(Expr x) { \
return ir::Call::make(x.dtype(), #OpName, {x}, ir::Call::PureIntrinsic); \
return ir::CallNode::make(x.dtype(), #OpName, {x}, ir::CallNode::PureIntrinsic); \
} \

TVM_DECLARE_INTRIN_UNARY(exp);
Expand All @@ -617,45 +617,45 @@ TVM_DECLARE_INTRIN_UNARY(atan);

// Implementation details after this
inline bool is_const(const Expr& x) {
if (x.as<ir::IntImm>() || x.as<ir::UIntImm>()) {
if (x.as<ir::IntImmNode>() || x.as<ir::UIntImmNode>()) {
return true;
} else if (const auto* op = x.as<ir::Broadcast>()) {
} else if (const auto* op = x.as<ir::BroadcastNode>()) {
const Expr& val = op->value;
if (val.as<ir::IntImm>() || val.as<ir::UIntImm>()) {
if (val.as<ir::IntImmNode>() || val.as<ir::UIntImmNode>()) {
return true;
}
}
return false;
}

inline bool is_positive_const(const Expr& a) {
if (const ir::IntImm* op = a.as<ir::IntImm>()) {
if (const ir::IntImmNode* op = a.as<ir::IntImmNode>()) {
return op->value > 0;
} else if (const ir::UIntImm* op = a.as<ir::UIntImm>()) {
} else if (const ir::UIntImmNode* op = a.as<ir::UIntImmNode>()) {
return op->value > 0;
} else {
return false;
}
}

inline bool is_negative_const(const Expr& a) {
if (const ir::IntImm* op = a.as<ir::IntImm>()) {
if (const ir::IntImmNode* op = a.as<ir::IntImmNode>()) {
return op->value < 0;
} else {
return false;
}
}

inline bool is_const_int(const Expr& x, int64_t value) {
if (const auto* op = x.as<ir::IntImm>()) {
if (const auto* op = x.as<ir::IntImmNode>()) {
return op->value == value;
} else if (const auto* op = x.as<ir::UIntImm>()) {
} else if (const auto* op = x.as<ir::UIntImmNode>()) {
return op->value == static_cast<uint64_t>(value);
} else if (const auto* op = x.as<ir::Broadcast>()) {
} else if (const auto* op = x.as<ir::BroadcastNode>()) {
const Expr& val = op->value;
if (const auto* opv = val.as<ir::IntImm>()) {
if (const auto* opv = val.as<ir::IntImmNode>()) {
return opv->value == value;
} else if (const auto* opv = val.as<ir::UIntImm>()) {
} else if (const auto* opv = val.as<ir::UIntImmNode>()) {
return opv->value == static_cast<uint64_t>(value);
}
}
Expand All @@ -664,7 +664,7 @@ inline bool is_const_int(const Expr& x, int64_t value) {

inline bool is_no_op(const Stmt& stmt) {
if (!stmt.defined()) return true;
if (const auto* op = stmt.as<ir::Evaluate>()) {
if (const auto* op = stmt.as<ir::EvaluateNode>()) {
return is_const(op->value);
}
if (const auto* op = stmt.as<ir::SeqStmtNode>()) {
Expand All @@ -675,15 +675,15 @@ inline bool is_no_op(const Stmt& stmt) {

template<typename ValueType>
inline Expr MakeConstScalar(DataType t, ValueType value) {
if (t.is_int()) return ir::IntImm::make(t, static_cast<int64_t>(value));
if (t.is_uint()) return ir::UIntImm::make(t, static_cast<uint64_t>(value));
if (t.is_float()) return ir::FloatImm::make(t, static_cast<double>(value));
if (t.is_int()) return ir::IntImmNode::make(t, static_cast<int64_t>(value));
if (t.is_uint()) return ir::UIntImmNode::make(t, static_cast<uint64_t>(value));
if (t.is_float()) return ir::FloatImmNode::make(t, static_cast<double>(value));
// For now, we store const scalar values of custom datatypes within doubles; later, during the
// datatypes lowering pass, we will lower the value to its true representation in the format
// specified by the datatype.
// TODO(gus) when do we need to start worrying about doubles not being precise enough?
if (static_cast<uint8_t>(t.code()) >= static_cast<uint8_t>(kCustomBegin))
return ir::FloatImm::make(t, static_cast<double>(value));
return ir::FloatImmNode::make(t, static_cast<double>(value));
LOG(FATAL) << "cannot make const for type " << t;
return Expr();
}
Expand All @@ -693,7 +693,7 @@ inline Expr make_const(DataType t, ValueType value) {
if (t.lanes() == 1) {
return MakeConstScalar(t, value);
} else {
return ir::Broadcast::make(
return ir::BroadcastNode::make(
MakeConstScalar(t.element_of(), value), t.lanes());
}
}
Expand Down
Loading

0 comments on commit fbd1dde

Please sign in to comment.