Skip to content

Commit

Permalink
[Ansor][AutoTVM v2.0] Phase 1: Add cache_read/cache_write steps (apac…
Browse files Browse the repository at this point in the history
…he#6107)

* Add cache_read/cache_write step

* Update

* Update

* Update

* Update state->current_compute_dag to Optional

* Update

* Update doc

* Update

* Update

* Doc update

* Update
  • Loading branch information
jcf94 authored and Trevor Morris committed Sep 2, 2020
1 parent 87c874e commit 02a14a0
Show file tree
Hide file tree
Showing 10 changed files with 1,043 additions and 113 deletions.
11 changes: 11 additions & 0 deletions include/tvm/auto_scheduler/compute_dag.h
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,17 @@ class ComputeDAG : public ObjectRef {
*/
State InferBound(const State& state) const;

/*!
* \brief Since some steps may change the ComputeDAG (e.g. CacheRead/CacheWrite), the initial
* ComputeDAG may not be up-to-date. This function replays the given transform steps from the
* initial state and returns an up-to-date ComputeDAG.
* \param steps The steps to be replaied. Usually we'll filter out the unused steps to speed up
* the replay process, since we only intend to get a ComputeDAG with the up-to-date op stage
* structure.
* \return The up-to-date ComputeDAG.
*/
ComputeDAG ReplayAndGetDAG(const Array<Step>& steps) const;

TVM_DEFINE_OBJECT_REF_METHODS(ComputeDAG, ObjectRef, ComputeDAGNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(ComputeDAGNode);
};
Expand Down
82 changes: 58 additions & 24 deletions include/tvm/auto_scheduler/loop_state.h
Original file line number Diff line number Diff line change
Expand Up @@ -182,16 +182,18 @@ class AttachMap : public ObjectRef {
public:
/*!
* \brief Process the stage/iterator mapping after compute at.
* \param stage_id The index of the stage to be compute at.
* \param stage_id The index of the stage to be computed at.
* \param target_stage_id The index of stage that this step will compute at to.
* \param target_iter_id The index of iterator in target stage that this step will compute at to.
*/
void SetComputeAtIter(int stage_id, int target_stage_id, int target_iter_id);

/*!
* \brief This is a public wrapper of `DeleteStageEntry`. To delete the entry of a specific stage.
* \param stage_id The index of the stage to be compute at.
* \param stage_id The index of the stage to be computed at.
*/
void DeleteStage(int stage_id);

/*!
* \brief Find the relations of original iterators in AttachMap, and update them with the new
* iterators. Both `stage_to_attach_iter` and `iter_to_attached_stages` will be updated.
Expand All @@ -201,6 +203,17 @@ class AttachMap : public ObjectRef {
void UpdateIters(const std::vector<IterKey>& original_iters,
const std::vector<IterKey>& new_iters);

/*!
* \brief Traverse through `stage_to_attach_iter` and `iter_to_attached_stages` map, add offset
* to stage indexes that are larger than the start_id. Used for steps that insert new stages to
* ComputeDAG(e.g. CacheRead/CacheWrite step).
* \param start_id The index threshold, stage indexes in AttachMap which are larger than this
* will be applied the extra offset.
* \param offset The index offset to be added to the stage index.
* \return The updated AttachMap after applying stage index offset.
*/
AttachMap ApplyStageIdOffset(int start_id, int offset = 1) const;

TVM_DEFINE_OBJECT_REF_METHODS(AttachMap, ObjectRef, AttachMapNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(AttachMapNode);

Expand Down Expand Up @@ -231,6 +244,12 @@ class StateNode : public Object {
* operation.
*/
AttachMap attach_map;
/*! \brief The up-to-date ComputeDAG of this state. The default value is an empty NullOpt, means
* no modification to the original ComputeDAG.
* Otherwise, it means some steps (e.g., CacheReadStep/CacheWriteStep) have modified the
* ComputeDAG, the stored value is the up-to-date ComputeDAG for this state.
*/
Optional<ObjectRef> current_compute_dag;
/*!
* \brief Indicate whether this state has unfilled tile sizes. A concrete state means that all
* tile sizes of the state is filled. Only concrete state can be apply to TVM schedule.
Expand All @@ -245,15 +264,6 @@ class StateNode : public Object {

static constexpr const char* _type_key = "auto_scheduler.State";
TVM_DECLARE_FINAL_OBJECT_INFO(StateNode, Object);

private:
/*!
* \brief The up-to-date ComputeDAG of this state, used for some steps that may change the
* stage structure of the ComputeDAG (e.g. CacheReadStep/CacheWriteStep which Will be added
* later).
* The default value is an empty ObjectRef. (means no modification to the original DAG)
*/
ObjectRef current_compute_dag;
};

/*!
Expand Down Expand Up @@ -290,7 +300,7 @@ class State : public ObjectRef {
/********** Step APIs working on single stage **********/

/*!
* \brief Schedule primitive corresponds to te.bind.
* \brief Schedule primitive corresponds to `te::Stage::bind`.
* \param stage_id The index of the stage to be binded.
* \param it The iterator to be binded.
* \param thread_type The thread type to be binded. We dirctly use the IteratorAnnotation as
Expand All @@ -299,14 +309,14 @@ class State : public ObjectRef {
*/
TVM_DLL Iterator bind(int stage_id, const Iterator& it, IteratorAnnotation thread_type);
/*!
* \brief Schedule primitive corresponds to te.parallel.
* \brief Schedule primitive corresponds to `te::Stage::parallel`.
* \param stage_id The index of the stage to be paralleled.
* \param it The iterator to be paralleled.
* \return The iterator result after parallel.
*/
TVM_DLL Iterator parallel(int stage_id, const Iterator& it);
/*!
* \brief Schedule primitive corresponds to te.unroll.
* \brief Schedule primitive corresponds to `te::Stage::unroll`.
* \param stage_id The index of the stage to be unrolled.
* \param it The iterator to be unrolled.
* \param max_unroll The max unroll limit. Iterator with extent larger than this limit will be
Expand All @@ -315,14 +325,14 @@ class State : public ObjectRef {
*/
TVM_DLL Iterator unroll(int stage_id, const Iterator& it, int max_unroll = -1);
/*!
* \brief Schedule primitive corresponds to te.vectorize.
* \brief Schedule primitive corresponds to `te::Stage::vectorize`.
* \param stage_id The index of the stage to be vectorized.
* \param it The iterator to be vectorized.
* \return The iterator result after vectorize.
*/
TVM_DLL Iterator vectorize(int stage_id, const Iterator& it);
/*!
* \brief Schedule primitive corresponds to te.fuse.
* \brief Schedule primitive corresponds to `te::Stage::fuse`.
* \param stage_id The index of the stage to be fused.
* \param iters The iterators to be fused.
* \return The iterator result after fuse.
Expand All @@ -331,13 +341,13 @@ class State : public ObjectRef {
*/
TVM_DLL Iterator fuse(int stage_id, const Array<Iterator>& iters);
/*!
* \brief Schedule primitive corresponds to te.reorder.
* \brief Schedule primitive corresponds to `te::Stage::reorder`.
* \param stage_id The index of the stage to be reordered.
* \param order The expected iterator order.
*/
TVM_DLL void reorder(int stage_id, const Array<Iterator>& order);
/*!
* \brief Schedule primitive corresponds to te.split.
* \brief Schedule primitive corresponds to `te::Stage::split`.
* \param stage_id The index of the stage to be split.
* \param it The iterator to be split.
* \param lengths The multiple split factors. Can be None to be filled by search policy.
Expand All @@ -353,8 +363,8 @@ class State : public ObjectRef {
/********** Step APIs working on multiple stages **********/

/*!
* \brief Schedule primitive corresponds to te.compute_at.
* \param stage_id The index of the stage to be reordered.
* \brief Schedule primitive corresponds to `te::Stage::compute_at`.
* \param stage_id The index of the stage to be computed at.
* \param target_stage_id The index of stage that this step will compute at to.
* \param target_iter The iterator in target stage that this step will compute at to.
* \note After compute_at, we need careful dependency analysis to compute the accurate bound
Expand All @@ -364,20 +374,44 @@ class State : public ObjectRef {
*/
TVM_DLL void compute_at(int stage_id, int target_stage_id, const Iterator& target_iter);
/*!
* \brief Schedule primitive corresponds to te.compute_inline.
* \param stage_id The index of the stage to be reordered.
* \brief Schedule primitive corresponds to `te::Stage::compute_inline`.
* \param stage_id The index of the stage to be marked compute inlined.
*/
TVM_DLL void compute_inline(int stage_id);
/*!
* \brief Schedule primitive corresponds to te.compute_root.
* \param stage_id The index of the stage to be reordered.
* \brief Schedule primitive corresponds to `te::Stage::compute_root`.
* \param stage_id The index of the stage to be marked compute at root.
* \note After compute_root, we need careful dependency analysis to compute the accurate bound
* information. However, it is relatively expensive and complicated, so we just fill "None" as
* bound for the newly created iterators.
* Call ComputeDAG::InferBound on the updated state to get the complete bound information.
*/
TVM_DLL void compute_root(int stage_id);

/********** Step APIs adding new stages **********/

/*!
* \brief Schedule primitive corresponds to `te::Schedule::cache_read`.
* \param stage_id The index of the stage to be cache read.
* \param scope_name The scope name of the newly added read stage.
* \param reader_stage_ids The indices of read stages.
* \param dag The original ComputeDAG of this state.
* \note Cache read step will add an extra stage to the original ComputeDAG (at the back of the
* target stage), a up-to-date ComputeDAG is stored in State's `current_compute_dag`.
*/
int cache_read(int stage_id, const String& scope_name, const Array<Integer>& reader_stage_ids,
const ComputeDAG& dag);
/*!
* \brief Schedule primitive corresponds to `te::Schedule::cache_write`.
* \param stage_id The index of the stage to be cache write.
* \param scope_name The scope name of the newly added compute stage.
* \param dag The original ComputeDAG of this state.
* \note Cache write step will add an extra stage to the original ComputeDAG (in the front of the
* target stage), a up-to-date ComputeDAG is stored in State's `current_compute_dag`.
* This step will cache write all output tensors of the target stage.
*/
int cache_write(int stage_id, const String& scope_name, const ComputeDAG& dag);

TVM_DEFINE_OBJECT_REF_METHODS(State, ObjectRef, StateNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(StateNode);
};
Expand Down
Loading

0 comments on commit 02a14a0

Please sign in to comment.