[TIR][REFACTOR] Remove te::Tensor dependencies from TIR passes.
This is needed to migrate passes before StorageFlatten to the PassManager.
te::Tensor is a useful object for tensor expressions.
However, making it part of the TIR makes the abstraction less clean.

This PR is a first step to remove this dependency. In particular,
we use Buffer in all the places where te::Tensor was being used.
We can now use an IRModule of PrimFuncs to represent the TIR at any
point of the optimizations.

- The TIR passes no longer need to consider Provide/HalideCall and Realize.
- Instead, we use BufferLoad/BufferStore and BufferRealize.

Buffer serves as the single abstraction in the TIR data model for
representing intermediate storage among computations.
We still keep Realize/HalideCall and Provide as TIR nodes for now to keep the change minimal.
Right after ScheduleOps, we call SchedulePostProcToPrimFunc to canonicalize the temporary IR
generated by TE (which contains these nodes) into TIR.

The TIR optimizations are now mostly migrated to the pass manager.
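
For illustration, a minimal sketch of the new lowering flow in Python, mirroring the ana_lower change in python/tvm/autotvm/feature.py below; the workload and the names A and B are hypothetical:

```python
import tvm
from tvm import te
from tvm.te import schedule

# Hypothetical elementwise workload.
A = te.placeholder((128,), name="A")
B = te.compute((128,), lambda i: A[i] + 1.0, name="B")
s = te.create_schedule(B.op)

# TE schedule -> temporary IR, which may still contain Provide/Realize nodes.
bounds = schedule.InferBound(s)
stmt = schedule.ScheduleOps(s, bounds)

# Canonicalize to TIR: te::Tensor/te::Operation references become tir::Buffer.
func = schedule.SchedulePostProcToPrimFunc([A, B], stmt, None)

# From here on, all optimizations run as IRModule passes.
mod = tvm.IRModule.from_expr(func)
mod = tvm.tir.transform.StorageFlatten(64)(mod)
mod = tvm.tir.transform.Simplify()(mod)
```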
tqchen committed Apr 18, 2020
1 parent 3264895 commit c6d7c3c
Showing 40 changed files with 932 additions and 421 deletions.
14 changes: 7 additions & 7 deletions include/tvm/arith/bound.h
@@ -78,15 +78,15 @@ IntSet DeduceBound(PrimExpr v, PrimExpr cond,
/*!
 * \brief Infer a regular domain that covers all the buffer loads or stores within the given statement.
* \param body The given statement.
* \param tensor The name of the calls or provides.
* \param consider_calls If calls (read) are considered.
* \param consider_provides If provides (write) are considered.
* \param buffer The buffer to check the access info.
* \param consider_loads If loads are considered.
* \param consider_stores If stores are considered.
 * \return The domain that covers all the loads or stores within the given statement.
*/
Domain DomainTouched(Stmt body,
const te::Tensor &tensor,
bool consider_calls,
bool consider_provides);
Domain DomainTouched(const Stmt& body,
const tir::Buffer& buffer,
bool consider_loads,
bool consider_stores);

} // namespace arith
} // namespace tvm
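
As a usage sketch of the new signature, assuming the C++ function is exposed through the FFI under the global name "arith.DomainTouched" (the registration name and the For constructor shape are assumptions here):

```python
import tvm
from tvm import te, tir

# A tiny loop that both reads and writes buffer A: for i in [0, 16): A[i] = A[i] + 1.0
A = tir.decl_buffer((16,), "float32", name="A")
i = te.var("i")
body = tir.BufferStore(A, tir.BufferLoad(A, [i]) + 1.0, [i])
loop = tir.For(i, 0, 16, tir.For.Serial, 0, body)

# DomainTouched(body, buffer, consider_loads, consider_stores)
domain_touched = tvm.get_global_func("arith.DomainTouched")
print(domain_touched(loop, A, True, True))
```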
2 changes: 1 addition & 1 deletion include/tvm/runtime/memory.h
@@ -70,7 +70,7 @@ class ObjAllocatorBase {
static_assert(std::is_base_of<Object, T>::value,
"make can only be used to create Object");
T* ptr = Handler::New(static_cast<Derived*>(this),
std::forward<Args>(args)...);
std::forward<Args>(args)...);
ptr->type_index_ = T::RuntimeTypeIndex();
ptr->deleter_ = Handler::Deleter();
return ObjectPtr<T>(ptr);
21 changes: 21 additions & 0 deletions include/tvm/te/schedule_pass.h
@@ -29,6 +29,7 @@
#define TVM_TE_SCHEDULE_PASS_H_

#include <tvm/te/schedule.h>
#include <tvm/tir/function.h>

namespace tvm {
namespace te {
@@ -54,6 +55,26 @@ Map<IterVar, Range> InferBound(const Schedule& sch);
*/
Stmt ScheduleOps(Schedule s, Map<IterVar, Range> dom_map, bool debug_keep_trivial_loop);

/*!
* \brief Postprocessing the Stmt generated by ScheduleOps to create
* a PrimFunc that can then be used for further TIR optimizations.
*
* Perform this translation before running any TIR optimizations.
*
* List of actions taken by the function:
 * - Remove occurrences of te::Tensor and te::Operation in the IR
 *   and replace them with the corresponding IR nodes via tir::Buffer.
* - Add annotation of extern buffers using the buffer_map field
* in the PrimFunc type.
*
* \param arg_list Array of Tensor/Var/Buffer arguments to the function.
* \param body The body of the function.
 * \param bindings Potential Tensor-to-Buffer bindings for the Tensors in the body.
 * \return The created PrimFunc.
*/
PrimFunc SchedulePostProcToPrimFunc(Array<ObjectRef> arg_list,
Stmt body,
Optional<Map<Tensor, Buffer>> bindings);

/*!
* \brief To automatically inline the element-wise operations.
*
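
Continuing the sketch from the commit message above: the PrimFunc returned by SchedulePostProcToPrimFunc carries the extern buffer annotation, which can be inspected through its buffer_map field:

```python
# func is the PrimFunc produced by SchedulePostProcToPrimFunc above.
for param in func.params:
    if param in func.buffer_map:
        print(param, "->", func.buffer_map[param])
```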
15 changes: 12 additions & 3 deletions include/tvm/tir/expr.h
@@ -694,7 +694,10 @@ class CallNode : public PrimExprNode {
ExternCPlusPlus = 1,
/*! \brief Extern "C" without side-effect. */
PureExtern = 2,
/*! \brief Halide-style call, evaluates func(args). */
/*!
* \brief Halide-style call, evaluates func(args).
* \note Deprecated, move to BufferLoad in the future.
*/
Halide = 3,
/*! \brief Intrinsic functions. */
Intrinsic = 4,
@@ -707,9 +710,15 @@
Array<PrimExpr> args;
/*! \brief Type of calls. */
CallType call_type;
/*! \brief The function to be called. */
/*!
* \brief The function to be called.
* \note Deprecated, move to BufferLoad in the future.
*/
FunctionRef func;
/*! \brief The output value index if func's value is a tuple. */
/*!
* \brief The output value index if func's value is a tuple.
* \note Deprecated, move to BufferLoad in the future.
*/
int value_index{0};

void VisitAttrs(AttrVisitor* v) {
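
To illustrate the deprecation direction, a sketch of what replaces a Halide-convention CallNode on the Python side (the buffer and index names are hypothetical):

```python
from tvm import tir

B = tir.decl_buffer((16, 16), "float32", name="B")
i = tir.Var("i", "int32")
j = tir.Var("j", "int32")

# Instead of Call(..., call_type=Halide, func=..., value_index=...),
# a read is expressed directly against the buffer:
load = tir.BufferLoad(B, [i, j])
```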
23 changes: 0 additions & 23 deletions include/tvm/tir/ir_pass.h
@@ -164,22 +164,6 @@ Stmt Inline(Stmt stmt,
Array<Var> args,
PrimExpr body);

/*!
* \brief Flatten the multi-dimensional read/write
* to single dimensional Load/Store
*
 * \param stmt The stmt to be transformed.
* \param extern_buffer Map specifies external
* buffer assignment of input and outputs.
* \param cache_line_size The size of CPU cache line.
* \param create_bound_attribute Whether to create bound attributes.
* \return Transformed stmt.
*/
Stmt StorageFlatten(Stmt stmt,
Map<te::Tensor, Buffer> extern_buffer,
int cache_line_size,
bool create_bound_attribute = false);

/*!
* \brief Try to modify the AST to support TensorCore
*
@@ -202,13 +186,6 @@ Stmt RewriteForTensorCore(Stmt stmt,
*/
bool VerifyCompactBuffer(Stmt stmt);

/*!
* \brief Inject prefetch instructions into stmt.
* \param stmt The statement to be transformed.
* \return Transformed stmt.
*/
Stmt InjectPrefetch(Stmt stmt);

/*!
* \brief Decorate the stmt with a device scope, this is helpful for
* hardware accelerator without thread blocks.
119 changes: 98 additions & 21 deletions include/tvm/tir/stmt.h
@@ -248,7 +248,6 @@ class StoreNode : public StmtNode {
* \endcode
* \sa BufferLoad
*/
class BufferStore;
class BufferStoreNode : public StmtNode {
public:
/*! \brief The buffer variable. */
@@ -281,6 +280,10 @@ class BufferStoreNode : public StmtNode {
TVM_DECLARE_FINAL_OBJECT_INFO(BufferStoreNode, StmtNode);
};

/*!
* \brief Managed reference to BufferStoreNode.
* \sa BufferStoreNode
*/
class BufferStore : public Stmt {
public:
TVM_DLL explicit BufferStore(Buffer buffer,
@@ -289,8 +292,80 @@ class BufferStore : public Stmt {
TVM_DEFINE_OBJECT_REF_METHODS(BufferStore, Stmt, BufferStoreNode);
};

/*!
 * \brief Annotate the region where the buffer needs to
 * be read and written in the body.
 * We only need to allocate the space for the corresponding region.
*
* \note There should be at most one BufferRealize for each buffer.
* BufferRealize is not necessary for external buffers,
* since they are assumed to be fully allocated.
*
* \sa BufferLoad, BufferStore
*/
class BufferRealizeNode : public StmtNode {
public:
/*! \brief The buffer variable. */
Buffer buffer;
/*! \brief Bounds to be realized */
Array<Range> bounds;
/*! \brief Only realize if condition holds. */
PrimExpr condition;
/*! \brief The body of realization. */
Stmt body;

void VisitAttrs(AttrVisitor* v) {
v->Visit("buffer", &buffer);
v->Visit("bounds", &bounds);
v->Visit("condition", &condition);
v->Visit("body", &body);
}

bool SEqualReduce(const BufferRealizeNode* other, SEqualReducer equal) const {
return
equal(buffer, other->buffer) &&
equal(bounds, other->bounds) &&
equal(condition, other->condition) &&
equal(body, other->body);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(buffer);
hash_reduce(bounds);
hash_reduce(condition);
hash_reduce(body);
}

BufferRealizeNode() = default;
BufferRealizeNode(Buffer buffer,
Array<Range> bounds,
PrimExpr condition,
Stmt body)
: buffer(buffer), bounds(bounds),
condition(condition), body(body) {}

static constexpr const char* _type_key = "BufferRealize";
TVM_DECLARE_FINAL_OBJECT_INFO(BufferRealizeNode, StmtNode);
};

/*!
* \brief Managed reference to BufferRealizeNode.
* \sa BufferRealizeNode
*/
class BufferRealize : public Stmt {
public:
TVM_DLL explicit BufferRealize(Buffer buffer,
Array<Range> bounds,
PrimExpr condition,
Stmt body);

TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(BufferRealize, Stmt, BufferRealizeNode);
};

/*!
 * \brief Store value into a multi-dimensional array defined by func.
*
* \note Deprecated, move to BufferStore in the future.
*/
class ProvideNode : public StmtNode {
public:
@@ -430,6 +505,8 @@
/*!
 * \brief Annotate the bounds where func needs to be written and read in the body.
* We will need to allocate space for the corresponding regions.
*
* \note Deprecated, move to BufferRealize in the future.
*/
class RealizeNode : public StmtNode {
public:
@@ -747,50 +824,50 @@ class ForNode : public StmtNode {
};

/*!
* \brief A prefetch hint of func.
 * \brief A prefetch hint for a buffer.
*/
class PrefetchNode : public StmtNode {
public:
/*! \brief The function to be prefetched. */
FunctionRef func;
/*! \brief The output value index if func's value is a tuple. */
int value_index;
/*! \brief The data type of the array. */
DataType dtype;
Buffer buffer;
/*! \brief Bounds to be prefetched. */
Region bounds;
Array<Range> bounds;

void VisitAttrs(AttrVisitor* v) {
v->Visit("func", &func);
v->Visit("value_index", &value_index);
v->Visit("dtype", &dtype);
v->Visit("buffer", &buffer);
v->Visit("bounds", &bounds);
}

bool SEqualReduce(const PrefetchNode* other, SEqualReducer equal) const {
return
equal(func, other->func) &&
equal(value_index, other->value_index) &&
equal(dtype, other->dtype) &&
equal(buffer, other->buffer) &&
equal(bounds, other->bounds);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(func);
hash_reduce(value_index);
hash_reduce(dtype);
hash_reduce(buffer);
hash_reduce(bounds);
}

TVM_DLL static Stmt make(FunctionRef func,
int value_index,
DataType dtype,
Region bounds);
PrefetchNode() = default;
PrefetchNode(Buffer buffer, Array<Range> bounds)
: buffer(buffer), bounds(bounds) {}

static constexpr const char* _type_key = "Prefetch";
TVM_DECLARE_FINAL_OBJECT_INFO(PrefetchNode, StmtNode);
};

/*!
* \brief Managed reference to PrefetchNode.
* \sa PrefetchNode
*/
class Prefetch : public Stmt {
public:
TVM_DLL explicit Prefetch(Buffer buffer, Array<Range> bounds);

TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Prefetch, Stmt, PrefetchNode);
};

/*!
* \brief Auxiliary data structure used in IR Pass to indicate a tensor.
*/
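
A small construction sketch for the new statement nodes from Python; the Range and boolean-condition constructor details are assumptions based on the fields listed above:

```python
import tvm
from tvm import tir

A = tir.decl_buffer((16,), "float32", name="A")
i = tir.Var("i", "int32")

store = tir.BufferStore(A, tir.FloatImm("float32", 1.0), [i])
# BufferRealize(buffer, bounds, condition, body), matching the fields above.
realize = tir.BufferRealize(A, [tvm.ir.Range(0, 16)],
                            tir.IntImm("bool", 1), store)
prefetch = tir.Prefetch(A, [tvm.ir.Range(0, 16)])
```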
5 changes: 5 additions & 0 deletions include/tvm/tir/stmt_functor.h
@@ -92,6 +92,7 @@ class StmtFunctor<R(const Stmt& n, Args... args)> {
virtual R VisitStmt_(const AllocateNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const StoreNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const BufferStoreNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const BufferRealizeNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const FreeNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const AssertStmtNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const ProvideNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
@@ -121,6 +122,8 @@ class StmtFunctor<R(const Stmt& n, Args... args)> {
IR_STMT_FUNCTOR_DISPATCH(PrefetchNode);
IR_STMT_FUNCTOR_DISPATCH(SeqStmtNode);
IR_STMT_FUNCTOR_DISPATCH(EvaluateNode);
IR_STMT_FUNCTOR_DISPATCH(BufferStoreNode);
IR_STMT_FUNCTOR_DISPATCH(BufferRealizeNode);
return vtable;
}
};
@@ -154,6 +157,7 @@ class TVM_DLL StmtVisitor :
void VisitStmt_(const AllocateNode* op) override;
void VisitStmt_(const StoreNode* op) override;
void VisitStmt_(const BufferStoreNode* op) override;
void VisitStmt_(const BufferRealizeNode* op) override;
void VisitStmt_(const FreeNode* op) override;
void VisitStmt_(const AssertStmtNode* op) override;
void VisitStmt_(const ProvideNode* op) override;
@@ -248,6 +252,7 @@ class TVM_DLL StmtMutator :
Stmt VisitStmt_(const AllocateNode* op) override;
Stmt VisitStmt_(const StoreNode* op) override;
Stmt VisitStmt_(const BufferStoreNode* op) override;
Stmt VisitStmt_(const BufferRealizeNode* op) override;
Stmt VisitStmt_(const FreeNode* op) override;
Stmt VisitStmt_(const AssertStmtNode* op) override;
Stmt VisitStmt_(const ProvideNode* op) override;
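
As a sketch of how the new hooks surface on the Python side (assuming post_order_visit is the exposed helper in tvm.tir.stmt_functor, and reusing func from the earlier sketch):

```python
from tvm import tir
from tvm.tir import stmt_functor

found = []

def _collect(node):
    # Collect the node types introduced by this refactor.
    if isinstance(node, (tir.BufferStore, tir.BufferRealize)):
        found.append(node)

stmt_functor.post_order_visit(func.body, _collect)
print(len(found), "BufferStore/BufferRealize nodes")
```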
21 changes: 21 additions & 0 deletions include/tvm/tir/transform.h
@@ -58,6 +58,27 @@ TVM_DLL Pass CreatePrimFuncPass(const runtime::TypedPackedFunc<
const std::string& name,
const tvm::Array<runtime::String>& required);


/*!
* \brief Inject prefetch instructions into stmt.
*
* \return The pass.
*/
TVM_DLL Pass InjectPrefetch();

// TODO(tvm-team): consolidate configs to the PassContext
/*!
* \brief Flatten the multi-dimensional read/write
* to single dimensional Load/Store
*
* \param cache_line_size The size of CPU cache line.
* \param create_bound_attribute Whether to create bound attributes.
*
* \return The Pass
*/
TVM_DLL Pass StorageFlatten(int cache_line_size,
bool create_bound_attribute = false);

/*!
* \brief Inject copy intrinsics with optional pad.
*
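
A sketch of composing the migrated passes, assuming the generic tvm.transform.Sequential utility applies here; note that StorageFlatten now depends on the buffer_map annotations produced by SchedulePostProcToPrimFunc:

```python
import tvm

# mod is an IRModule holding the PrimFunc to be lowered.
seq = tvm.transform.Sequential([
    tvm.tir.transform.InjectPrefetch(),
    tvm.tir.transform.StorageFlatten(64),
])
mod = seq(mod)
```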
8 changes: 5 additions & 3 deletions python/tvm/autotvm/feature.py
@@ -46,10 +46,12 @@ def ana_lower(sch, args,
# Phase 0
bounds = schedule.InferBound(sch)
stmt = schedule.ScheduleOps(sch, bounds, True)
stmt = ir_pass.StorageFlatten(stmt, binds, 64)
stmt = ir_pass.CanonicalSimplify(stmt)
func = schedule.SchedulePostProcToPrimFunc(args, stmt, None)
mod = tvm.IRModule.from_expr(func._move())
mod = tvm.tir.transform.StorageFlatten(64)(mod._move())
mod = tvm.tir.transform.Simplify()(mod._move())
assert simple_mode
return stmt
return mod["main"].body

try:
_get_buffer_curve_sample_flatten = tvm._ffi.get_global_func(