Skip to content

Commit

Permalink
[MetaSchedule] Refactor MultiLevelTiling state to allow subclassing (a…
Browse files Browse the repository at this point in the history
…pache#11931)

This PR made `State` in `MultiLevelTiling` inherit `Object`, to allow future subclassing of `State`. Making `State` an `Object` allows instances of `State` and its subclasses to be stored in `std::vector<State>`.
  • Loading branch information
vinx13 authored and blackkker committed Jul 7, 2022
1 parent df71540 commit f86b0c3
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 37 deletions.
70 changes: 39 additions & 31 deletions src/meta_schedule/schedule_rule/multi_level_tiling.cc
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,20 @@ using tir::IterVarType;
using tir::LoopRV;
using tir::Schedule;

State::State(tir::Schedule sch, tir::BlockRV block_rv, Array<Array<tir::LoopRV>> tiles) {
ObjectPtr<StateNode> node = make_object<StateNode>();
node->sch = std::move(sch);
node->block_rv = std::move(block_rv);
node->tiles = std::move(tiles);
data_ = std::move(node);
}

State StateNode::Copy() const {
ObjectPtr<StateNode> node = make_object<StateNode>(*this);
node->sch = sch->Copy();
return State(node);
}

// Do nothing; Inherited from ScheduleRuleNode
void MultiLevelTilingNode::InitializeWithTuneContext(const TuneContext& context) {
if (Optional<Integer> v = context->target.value()->GetAttr<Integer>("max_threads_per_block")) {
Expand All @@ -82,15 +96,15 @@ Array<Schedule> MultiLevelTilingNode::Apply(const Schedule& sch, const BlockRV&

Array<Schedule> results;
for (auto&& state : ApplySubRules({State(sch, block_rv)})) {
results.push_back(std::move(state.sch));
results.push_back(std::move(state->sch));
}
return results;
}

std::vector<State> MultiLevelTilingNode::ApplySubRules(std::vector<State> states) {
states = SubRule(std::move(states), [&](State state) { return TileLoopNest(state); });
states = SubRule(std::move(states), [&](State state) { return AddWriteReuse(state); });
states = SubRule(std::move(states), [&](State state) { return AddReadReuse(state); });
states = SubRule(std::move(states), [&](State state) { return TileLoopNest(std::move(state)); });
states = SubRule(std::move(states), [&](State state) { return AddWriteReuse(std::move(state)); });
states = SubRule(std::move(states), [&](State state) { return AddReadReuse(std::move(state)); });
return states;
}

Expand All @@ -102,53 +116,49 @@ std::vector<State> MultiLevelTilingNode::AddWriteReuse(State state) const {
std::vector<int> levels = config.levels;
ReuseType req = config.req;
if (Optional<Array<Integer>> ann = tir::GetAnn<Array<Integer>>(
state.sch->GetSRef(state.block_rv), "meta_schedule.write_cache_level")) {
state->sch->GetSRef(state->block_rv), "meta_schedule.write_cache_level")) {
req = ReuseType::kMustReuse;
levels = std::vector<int>(ann.value().begin(), ann.value().end());
}
std::vector<State> results;
if (req == ReuseType::kMayReuse) {
// Case 1. If the write cache is already there, we don't need to add another.
Array<BlockRV> consumer_rvs = state.sch->GetConsumers(state.block_rv);
if (consumer_rvs.size() == 1 && IsWriteCache(state.sch->GetSRef(consumer_rvs[0]))) {
Array<BlockRV> consumer_rvs = state->sch->GetConsumers(state->block_rv);
if (consumer_rvs.size() == 1 && IsWriteCache(state->sch->GetSRef(consumer_rvs[0]))) {
for (int level : levels) {
State new_state = state;
new_state.sch = state.sch->Copy();
new_state.sch->Seed(state.sch->ForkSeed());
const LoopRV& loop_rv = new_state.tiles[level - 1].back();
new_state.sch->ReverseComputeAt(consumer_rvs[0], loop_rv, true);
State new_state = state->Copy();
const LoopRV& loop_rv = new_state->tiles[level - 1].back();
new_state->sch->ReverseComputeAt(consumer_rvs[0], loop_rv, true);
results.push_back(std::move(new_state));
}
results.push_back(state);
return results;
} else {
// Case 2. No write cache is added
State new_state(/*sch=*/state.sch->Copy(), /*block_rv=*/state.block_rv);
new_state.sch->Seed(state.sch->ForkSeed());
State new_state = state->Copy();
results.emplace_back(std::move(new_state));
}
}

// Case 3. Add one write cache
BlockRV write_cache = state.sch->CacheWrite(/*block_rv=*/state.block_rv, /*read_buffer_index=*/0,
/*storage_scope=*/config.scope);
BlockRV write_cache =
state->sch->CacheWrite(/*block_rv=*/state->block_rv, /*read_buffer_index=*/0,
/*storage_scope=*/config.scope);
for (int level : levels) {
State new_state = state;
new_state.sch = state.sch->Copy();
new_state.sch->Seed(state.sch->ForkSeed());
const LoopRV& loop_rv = new_state.tiles[level - 1].back();
new_state.sch->ReverseComputeAt(write_cache, loop_rv, true);
State new_state = state->Copy();
const LoopRV& loop_rv = new_state->tiles[level - 1].back();
new_state->sch->ReverseComputeAt(write_cache, loop_rv, true);
results.push_back(std::move(new_state));
}
return results;
}

std::vector<State> MultiLevelTilingNode::TileLoopNest(State state) const {
Schedule& sch = state.sch;
const BlockRV& block_rv = state.block_rv;
Schedule& sch = state->sch;
const BlockRV& block_rv = state->block_rv;
// Step 1. Assuming trivial binding, pair the loops and their iter-var-types
Array<LoopRV> loops = sch->GetLoops(block_rv);
std::vector<IterVarType> iter_types = GetBlockVarTypes(sch->GetSRef(state.block_rv));
std::vector<IterVarType> iter_types = GetBlockVarTypes(sch->GetSRef(state->block_rv));
ICHECK_EQ(loops.size(), iter_types.size());
// Step 2. For each loop axis, tile it
int64_t spatial_loop_product = 1;
Expand Down Expand Up @@ -192,7 +202,7 @@ std::vector<State> MultiLevelTilingNode::TileLoopNest(State state) const {
sch->Bind(fused, tile_binds[i]);
tiles[i] = {fused};
}
state.tiles = Array<Array<LoopRV>>{tiles.begin(), tiles.end()};
state->tiles = Array<Array<LoopRV>>{tiles.begin(), tiles.end()};
if (this->thread_warp_size_ != -1) {
int64_t low_inclusive = 1;
int64_t high_inclusive = this->max_threads_per_block_;
Expand All @@ -213,13 +223,13 @@ std::vector<State> MultiLevelTilingNode::AddReadReuse(State state) const {
return {std::move(state)};
}
ICHECK(config.req != ReuseType::kMayReuse);
const BlockRV& block_rv = state.block_rv;
const BlockRV& block_rv = state->block_rv;
std::vector<State> results;
results.reserve(config.levels.size());
for (int level : config.levels) {
Schedule sch = state.sch->Copy();
sch->Seed(state.sch->ForkSeed());
const LoopRV& loop_rv = state.tiles[level - 1].back();
State new_state = state->Copy();
Schedule& sch = new_state->sch;
const LoopRV& loop_rv = state->tiles[level - 1].back();
// Enumerate all buffers that are read but not written
std::vector<int> read_buffer_ndims = tir::GetReadBufferNDims(sch->GetSRef(block_rv));
for (int i = 0, n_reads = read_buffer_ndims.size(); i < n_reads; ++i) {
Expand All @@ -246,8 +256,6 @@ std::vector<State> MultiLevelTilingNode::AddReadReuse(State state) const {
vector_load_len);
}
}
State new_state = state;
new_state.sch = sch;
results.push_back(std::move(new_state));
}
return results;
Expand Down
25 changes: 20 additions & 5 deletions src/meta_schedule/schedule_rule/multi_level_tiling.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,20 +81,35 @@ struct ReuseConfig {
}
};

// Forware declaration
class State;

/*! \brief The state of auto scheduling for the multi-level tiling rule */
struct State {
class StateNode : public Object {
public:
/*! \brief The schedule to date */
tir::Schedule sch;
/*! \brief The block to be tiled */
tir::BlockRV block_rv;
/*! \brief The loop tiles */
Array<Array<tir::LoopRV>> tiles;

/*!
* \brief Create a copy of the state. The underlying schedule is copied. Schedule rules that
* produce multiple states should use this method to create new states.
*/
virtual State Copy() const;

static constexpr const char* _type_key = "meta_schedule.State";
TVM_DECLARE_BASE_OBJECT_INFO(StateNode, Object);
};

/*! \brief Managed reference to StateNode */
class State : public ObjectRef {
public:
/*! \brief Default constructor */
explicit State(tir::Schedule sch, tir::BlockRV block_rv,
Optional<tir::BlockRV> write_cache = NullOpt, bool write_cache_is_added = false,
Array<Array<tir::LoopRV>> tiles = {})
: sch(sch), block_rv(block_rv), tiles(tiles) {}
explicit State(tir::Schedule sch, tir::BlockRV block_rv, Array<Array<tir::LoopRV>> tiles = {});
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(State, ObjectRef, StateNode);
};

/*!
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ class MultiLevelTilingWithIntrinNode : public MultiLevelTilingNode {
// tile the outerloops.
virtual std::vector<State> ApplySubRules(std::vector<State> states) {
states = SubRule(std::move(states), [&](State state) {
state.block_rv = TileForIntrin(state.sch, state.block_rv, intrin_name);
state->block_rv = TileForIntrin(state->sch, state->block_rv, intrin_name);
return std::vector<State>(1, state);
});
return MultiLevelTilingNode::ApplySubRules(states);
Expand Down

0 comments on commit f86b0c3

Please sign in to comment.