Skip to content

Commit

Permalink
[REFACTOR][IR] tvm::Expr -> PrimExpr(Primitive Expr) (apache#4669)
Browse files Browse the repository at this point in the history
* [REFACTOR][IR] tvm::Expr -> PrimExpr(Primitive Expr)

As part of unified IR, we will need to unify relay::Expr
and the current tvm::Expr under the same base type.

From the techinical point of view. tvm::Expr is a "primitive"
expression that only contains POD types and handles and does
not do life-cycle management.

This PR renames Expr->PrimExpr to clarify that.
We will send a subsequent PR to introduce the base expr class.

* Remove legacy VarExpr and ExprHash/Equal
  • Loading branch information
tqchen authored and alexwong committed Feb 28, 2020
1 parent 444ee5e commit bb4554f
Show file tree
Hide file tree
Showing 207 changed files with 2,670 additions and 2,650 deletions.
62 changes: 31 additions & 31 deletions include/tvm/arithmetic.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ class ConstIntBoundAnalyzer {
* \param expr The expression of interest.
* \return the result of the analysis.
*/
ConstIntBound operator()(const Expr& expr);
ConstIntBound operator()(const PrimExpr& expr);

/*!
* \brief Update constant int bound information of var.
Expand Down Expand Up @@ -136,7 +136,7 @@ class ConstIntBoundAnalyzer {
*
* \return an exit function that must be called to cleanup the constraint can be nullptr.
*/
std::function<void()> EnterConstraint(const Expr& constraint);
std::function<void()> EnterConstraint(const PrimExpr& constraint);
struct Entry;
class Impl;
/*! \brief Internal impl */
Expand Down Expand Up @@ -192,7 +192,7 @@ class ModularSetAnalyzer {
* \param expr The expression of interest.
* \return the result of the analysis.
*/
ModularSet operator()(const Expr& expr);
ModularSet operator()(const PrimExpr& expr);
/*!
* \brief Update constant int bound information of var.
*
Expand All @@ -215,7 +215,7 @@ class ModularSetAnalyzer {
*
* \return an exit function that must be called to cleanup the constraint can be nullptr.
*/
std::function<void()> EnterConstraint(const Expr& constraint);
std::function<void()> EnterConstraint(const PrimExpr& constraint);
struct Entry;
class Impl;
/*! \brief Internal impl */
Expand All @@ -232,7 +232,7 @@ class RewriteSimplifier {
* \param expr The expression of interest.
* \return the result of the analysis.
*/
Expr operator()(const Expr& expr);
PrimExpr operator()(const PrimExpr& expr);

/*!
* \brief Update binding of var to a new expression.
Expand All @@ -242,10 +242,10 @@ class RewriteSimplifier {
* \param override Whether do we allow override of existing information.
*/
void Update(const Var& var,
const Expr& new_expr,
const PrimExpr& new_expr,
bool override = false);

std::function<void()> EnterConstraint(const Expr& constraint);
std::function<void()> EnterConstraint(const PrimExpr& constraint);

private:
friend class Analyzer;
Expand All @@ -268,7 +268,7 @@ class CanonicalSimplifier {
* \param expr The expression of interest.
* \return the result of the analysis.
*/
Expr operator()(const Expr& expr);
PrimExpr operator()(const PrimExpr& expr);

/*!
* \brief Update binding of var to a new expression.
Expand All @@ -278,7 +278,7 @@ class CanonicalSimplifier {
* \param override Whether do we allow override of existing information.
*/
void Update(const Var& var,
const Expr& new_expr,
const PrimExpr& new_expr,
bool override = false);

private:
Expand Down Expand Up @@ -316,7 +316,7 @@ class ConstraintContext {
* \param analyzer The analyzer.
* \param constraint The constraint to be applied.
*/
ConstraintContext(Analyzer* analyzer, Expr constraint)
ConstraintContext(Analyzer* analyzer, PrimExpr constraint)
: analyzer_(analyzer), constraint_(constraint) {}
// enter the scope.
void EnterWithScope();
Expand All @@ -325,7 +325,7 @@ class ConstraintContext {
/*! \brief The analyzer */
Analyzer* analyzer_;
/*! \brief The constraint */
Expr constraint_;
PrimExpr constraint_;
/*! \brief function to be called in recovery */
std::function<void()> exit_;
};
Expand Down Expand Up @@ -375,9 +375,9 @@ class IntSet : public ObjectRef {
*/
Range cover_range(Range max_range) const;
/*! \return Lower bound of the set */
Expr min() const;
PrimExpr min() const;
/*! \return upper bound of the set */
Expr max() const;
PrimExpr max() const;
/*! \return Whether the set represent nothing */
bool is_nothing() const;
/*! \return Whether the set represent everything */
Expand All @@ -398,7 +398,7 @@ class IntSet : public ObjectRef {
* \brief The single point value, call only if is_single_point is true
* \return The point value.
*/
Expr point_value() const;
PrimExpr point_value() const;
/*!
* \brief Try to match IntSet with range r.
*
Expand All @@ -415,13 +415,13 @@ class IntSet : public ObjectRef {
* \param point The point in the set.
* \return construct a single point set
*/
static IntSet single_point(Expr point);
static IntSet single_point(PrimExpr point);
/*!
* \brief construct a integer set from vector expression.
* \param vec The vector expression, can also be single point.
* \return The result set containing the indices in the vector.
*/
static IntSet vector(Expr vec);
static IntSet vector(PrimExpr vec);
/*!
* \brief Construct a set representing a range.
* \param r The range
Expand All @@ -434,7 +434,7 @@ class IntSet : public ObjectRef {
* \param max The maximum value of the interval.
* \return constructed set.
*/
static IntSet interval(Expr min, Expr max);
static IntSet interval(PrimExpr min, PrimExpr max);
};

/*!
Expand All @@ -450,7 +450,7 @@ class IntSetAnalyzer {
* \param dom_map The domain map to indicate which variable to relax.
* \return the result of the analysis.
*/
IntSet operator()(const Expr& expr, const Map<Var, IntSet>& dom_map);
IntSet operator()(const PrimExpr& expr, const Map<Var, IntSet>& dom_map);

private:
friend class Analyzer;
Expand Down Expand Up @@ -499,7 +499,7 @@ class Analyzer {
* \param var The variable.
* \param expr The expression we bind to.
*/
void Bind(const VarExpr& var, const Expr& expr);
void Bind(const Var& var, const PrimExpr& expr);
/*!
* \brief Notify all the sub-analyzers that var
* is created and binded to a range.
Expand All @@ -509,7 +509,7 @@ class Analyzer {
* \param var The variable.
* \param range The range we bind to.
*/
void Bind(const VarExpr& var, const Range& range);
void Bind(const Var& var, const Range& range);
/*!
* \brief Whether can we prove expr >= val.
Expand All @@ -522,7 +522,7 @@ class Analyzer {
*
* \note Analyzer will call into sub-analyzers to get the result.
*/
bool CanProveGreaterEqual(const Expr& expr, int64_t lower_bound);
bool CanProveGreaterEqual(const PrimExpr& expr, int64_t lower_bound);
/*!
* \brief Whether can we prove condition.
*
Expand All @@ -531,7 +531,7 @@ class Analyzer {
*
* \note Analyzer will call into sub-analyzers to get the result.
*/
bool CanProve(const Expr& cond);
bool CanProve(const PrimExpr& cond);
/*!
* \brief Simplify expr.
*
Expand All @@ -540,7 +540,7 @@ class Analyzer {
*
* \note Analyzer will call into sub-analyzers to get the result.
*/
Expr Simplify(const Expr& expr);
PrimExpr Simplify(const PrimExpr& expr);
};

//-----------------------------------------------
Expand All @@ -554,7 +554,7 @@ class Analyzer {
* \param dom_map The domain of each variable.
* \return An integer set that can cover all the possible values of e.
*/
IntSet EvalSet(Expr e,
IntSet EvalSet(PrimExpr e,
const Map<IterVar, IntSet>& dom_map);
/*!
* \brief Same as EvalSet, but takes unordered_map
Expand All @@ -563,7 +563,7 @@ IntSet EvalSet(Expr e,
* \param dom_map The domain of each variable.
* \return An integer set that can cover all the possible values of e.
*/
IntSet EvalSet(Expr e,
IntSet EvalSet(PrimExpr e,
const std::unordered_map<const VarNode*, IntSet>& dom_map);

/*!
Expand Down Expand Up @@ -598,7 +598,7 @@ IntSet EvalSet(Range r,
const std::unordered_map<const VarNode*, IntSet>& dom_map);

/*! \brief Map from Expr to IntSet */
using ExprIntSetMap = std::unordered_map<Expr, IntSet, ObjectHash, ObjectEqual>;
using ExprIntSetMap = std::unordered_map<PrimExpr, IntSet, ObjectHash, ObjectEqual>;
/*!
* \brief Find the integer set of every sub-expression, given the
* domain of each iteration variables.
Expand All @@ -608,7 +608,7 @@ using ExprIntSetMap = std::unordered_map<Expr, IntSet, ObjectHash, ObjectEqual>;
* \return the map from the expression to its possible value.
*/
ExprIntSetMap EvalSetForEachSubExpr(
Expr e,
PrimExpr e,
const std::unordered_map<const VarNode*, IntSet>& dom_map);

/*!
Expand Down Expand Up @@ -640,7 +640,7 @@ IntSet Intersect(const Array<IntSet>& sets);
* The deduce bound must implies e for all value in relax_map
* \return An integer set that always satisfies the condition.
*/
IntSet DeduceBound(Expr v, Expr cond,
IntSet DeduceBound(PrimExpr v, PrimExpr cond,
const Map<Var, IntSet>& hint_map,
const Map<Var, IntSet>& relax_map);
/*!
Expand All @@ -653,7 +653,7 @@ IntSet DeduceBound(Expr v, Expr cond,
* The deduce bound mush implies e for all value in relax_map
* \return An integer set that always satisfies the condition.
*/
IntSet DeduceBound(Expr v, Expr cond,
IntSet DeduceBound(PrimExpr v, PrimExpr cond,
const std::unordered_map<const VarNode*, IntSet>& hint_map,
const std::unordered_map<const VarNode*, IntSet>& relax_map);

Expand All @@ -676,7 +676,7 @@ Domain DomainTouched(Stmt body, const Tensor &tensor, bool consider_calls, bool
* \param vars List of variables to be used in detection.
* \return [coeff[i]] if it is possible, empty array if it is not.
*/
Array<Expr> DetectLinearEquation(const Expr& e,
Array<PrimExpr> DetectLinearEquation(const PrimExpr& e,
const Array<Var>& vars);

/*!
Expand All @@ -687,7 +687,7 @@ Array<Expr> DetectLinearEquation(const Expr& e,
* \return concat([min_value[i], max_value[i]]), None is returned if there is no min or max value
* return empty if the e does not match the pattern.
*/
Array<Expr> DetectClipBound(const Expr& e,
Array<PrimExpr> DetectClipBound(const PrimExpr& e,
const Array<Var>& vars);

// implementation
Expand Down
6 changes: 3 additions & 3 deletions include/tvm/attrs.h
Original file line number Diff line number Diff line change
Expand Up @@ -486,7 +486,7 @@ inline void SetIntValue(T* ptr, const TVMArgValue& val) {
if (val.type_code() == kDLInt) {
*ptr = static_cast<T>(val.value().v_int64);
} else {
Expr expr = val;
PrimExpr expr = val;
CHECK(expr.defined());
if (const ir::IntImmNode* op = expr.as<ir::IntImmNode>()) {
*ptr = static_cast<T>(op->value);
Expand All @@ -502,7 +502,7 @@ inline void SetValue<std::string>(std::string* ptr, const TVMArgValue& val) {
if (val.type_code() == kStr) {
*ptr = val.operator std::string();
} else {
Expr expr = val;
PrimExpr expr = val;
const ir::StringImmNode* op = expr.as<ir::StringImmNode>();
CHECK(op != nullptr);
*ptr = op->value;
Expand All @@ -517,7 +517,7 @@ inline void SetValue<double>(double* ptr, const TVMArgValue& val) {
if (val.type_code() == kDLFloat || val.type_code() == kDLInt) {
*ptr = val.operator double();
} else {
Expr expr = val;
PrimExpr expr = val;
CHECK(expr.defined());
if (const ir::IntImmNode* op = expr.as<ir::IntImmNode>()) {
*ptr = static_cast<double>(op->value);
Expand Down
24 changes: 12 additions & 12 deletions include/tvm/buffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,30 +66,30 @@ class Buffer : public ObjectRef {
* If stride is not needed in the slice, it won't be presented
* \return the result buffer.
*/
TVM_DLL Buffer MakeSlice(Array<Expr> begins, Array<Expr> extents) const;
TVM_DLL Buffer MakeSlice(Array<PrimExpr> begins, Array<PrimExpr> extents) const;
/*!
* \brief Get access ptr to the entire buffer.
* \param access_mask The access mask
* \param ptr_type The type of the pointer.
* \param content_lanes The number of lanes for the (data) type.
* \param offset The offset of ptr.
*/
TVM_DLL Expr access_ptr(int access_mask,
TVM_DLL PrimExpr access_ptr(int access_mask,
DataType ptr_type = DataType::Handle(),
int content_lanes = 1,
Expr offset = make_const(DataType::Int(32), 0)) const;
PrimExpr offset = make_const(DataType::Int(32), 0)) const;
/*!
* \brief Create an Expr that does a vector load at begin index.
* \param begin The beginning index
* \param dtype The data type to be loaded.
*/
TVM_DLL Expr vload(Array<Expr> begin, DataType dtype) const;
TVM_DLL PrimExpr vload(Array<PrimExpr> begin, DataType dtype) const;
/*!
* \brief Create a Stmt that does a vector store at begin index.
* \param begin The beginning index
* \param value The value to be stored.
*/
TVM_DLL Stmt vstore(Array<Expr> begin, Expr value) const;
TVM_DLL Stmt vstore(Array<PrimExpr> begin, PrimExpr value) const;
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
Expand All @@ -112,14 +112,14 @@ class BufferNode : public Object {
/*! \brief data type in the content of the tensor */
DataType dtype;
/*! \brief The shape of the buffer */
Array<Expr> shape;
Array<PrimExpr> shape;
/*!
* \brief The strides of each dimension
* This can be an empty array, indicating array is contiguous
*/
Array<Expr> strides;
Array<PrimExpr> strides;
/*! \brief The offset in terms of number of dtype elements (including lanes) */
Expr elem_offset;
PrimExpr elem_offset;
// Meta data
/*! \brief optional name of the buffer */
std::string name;
Expand Down Expand Up @@ -159,9 +159,9 @@ class BufferNode : public Object {
// A default value will be picked.
TVM_DLL static Buffer make(Var ptr,
DataType dtype,
Array<Expr> shape,
Array<Expr> strides,
Expr elem_offset,
Array<PrimExpr> shape,
Array<PrimExpr> strides,
PrimExpr elem_offset,
std::string name,
std::string scope,
int data_alignment,
Expand All @@ -184,7 +184,7 @@ inline const BufferNode* Buffer::operator->() const {
* \return The created buffer.
* \sa BufferNode::make for complete constructor.
*/
TVM_DLL Buffer decl_buffer(Array<Expr> shape,
TVM_DLL Buffer decl_buffer(Array<PrimExpr> shape,
DataType dtype = DataType::Float(32),
std::string name = "buffer");
} // namespace tvm
Expand Down
6 changes: 3 additions & 3 deletions include/tvm/build_module.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,11 @@ class TargetNode : public Object {
/*! \brief The warp size that should be used by the LowerThreadAllreduce pass */
int thread_warp_size = 1;
/*! \brief Keys for this target */
Array<Expr> keys_array;
Array<PrimExpr> keys_array;
/*! \brief Options for this target */
Array<Expr> options_array;
Array<PrimExpr> options_array;
/*! \brief Collection of imported libs */
Array<Expr> libs_array;
Array<PrimExpr> libs_array;

/*! \return the full device string to pass to codegen::Build */
TVM_DLL const std::string& str() const;
Expand Down
Loading

0 comments on commit bb4554f

Please sign in to comment.