diff --git a/python/tvm/auto_scheduler/loop_state.py b/python/tvm/auto_scheduler/loop_state.py index 693a668a158e..ab041cf4a43d 100644 --- a/python/tvm/auto_scheduler/loop_state.py +++ b/python/tvm/auto_scheduler/loop_state.py @@ -83,6 +83,24 @@ class State: ----- This is a wrapper class of StateObject to deal with copy-on-write property """ + + # Static translation table for thread binding + # This is used to transform an annotation name to the corresponding C++ enum value + ANNOTATION_TRANS_TABLE = { + "none": 0, + "unroll": 1, + "vectorize": 2, + "parallel": 3, + "vthread": 4, + "blockIdx.x": 5, + "threadIdx.x": 6, + "blockIdx.y": 7, + "threadIdx.y": 8, + "blockIdx.z": 9, + "threadIdx.z": 10, + "tensorize": 11 + } + def __init__(self, state_object, dag): self.state_object = state_object self.compute_dag = dag @@ -108,20 +126,140 @@ def stage_ops(self): """ return [stage.op for stage in self.stages] + def bind(self, stage, iterator, thread_name): + """ Schedule primitive corresponds to te.bind. + + Parameters + ---------- + stage : Union[int, Operation, Tensor] + The Stage to be bound, which can be specified by the integer index, Operation, + or output tensor of the stage. + iterator : Iterator + The iterator to be bound. + thread_name : str + The thread type to bind. Candidates: + - vthread + - blockIdx.x + - threadIdx.x + - blockIdx.y + - threadIdx.y + - blockIdx.z + - threadIdx.z + + Returns + ------- + res_it : Iterator + The bound Iterator. + """ + if thread_name not in State.ANNOTATION_TRANS_TABLE: + raise ValueError("Invalid thread_name: " + thread_name) + + self.state_object, res = _ffi_api.StateBind(self.state_object, + self._resolve_stage_id(stage), iterator, + State.ANNOTATION_TRANS_TABLE[thread_name]) + return res + + def parallel(self, stage, iterator): + """ Schedule primitive corresponds to te.parallel. + + Parameters + ---------- + stage : Union[int, Operation, Tensor] + The Stage to be parallelized, which can be specified by the integer index, Operation, + or output tensor of the stage. + iterator : Iterator + The iterator to be parallelized. + + Returns + ------- + res_it : Iterator + The parallelized Iterator. + """ + self.state_object, res = _ffi_api.StateParallel(self.state_object, + self._resolve_stage_id(stage), iterator) + return res + + def unroll(self, stage, iterator, max_unroll=None): + """ Schedule primitive corresponds to te.unroll. + + Parameters + ---------- + stage : Union[int, Operation, Tensor] + The Stage to be unrolled, which can be specified by the integer index, Operation, + or output tensor of the stage. + iterator : Iterator + The iterator to be unrolled. + max_unroll : Optional[int] + The max unroll limit. Iterators with an extent larger than this limit will be skipped. + + Returns + ------- + res_it : Iterator + The unrolled Iterator. + """ + self.state_object, res = _ffi_api.StateUnroll(self.state_object, + self._resolve_stage_id(stage), iterator, + max_unroll if max_unroll is not None else -1) + return res + + def vectorize(self, stage, iterator): + """ Schedule primitive corresponds to te.vectorize. + + Parameters + ---------- + stage : Union[int, Operation, Tensor] + The Stage to be vectorized, which can be specified by the integer index, Operation, + or output tensor of the stage. + iterator : Iterator + The iterator to be vectorized. + + Returns + ------- + res_it : Iterator + The vectorized Iterator.
+ """ + self.state_object, res = _ffi_api.StateVectorize(self.state_object, + self._resolve_stage_id(stage), iterator) + return res + + def fuse(self, stage, iters): + """ Schedule primitive corresponds to te.fuse. + + Parameters + ---------- + stage : Union[int, Operation, Tensor] + The Stage to be fused, which can be specified by the integer index, Operation, + or output tensor of the stage. + iters : List[Iterator] + The iterators to be fused. + + Returns + ------- + res_it : Iterator + The fused Iterator. + + Notes + ----- + If the iterators to be fused have stages attached at them(by compute_at), the fused + result will become the new attach point. + """ + self.state_object, res = _ffi_api.StateFuse(self.state_object, + self._resolve_stage_id(stage), iters) + return res + def reorder(self, stage, order): """ Schedule primitive corresponds to te.reorder. Parameters ---------- stage : Union[int, Operation, Tensor] - The Stage to be reordered, can be a Stage order index, Stage operation or stage - output tensor. + The Stage to be reordered, which can be specified by the integer index, Operation, + or output tensor of the stage. order : List[Iterator] Iterators in the expected order. """ - stage_id = self._resolve_stage_id(stage) - - self.state_object = _ffi_api.StateReorder(self.state_object, stage_id, order) + self.state_object = _ffi_api.StateReorder(self.state_object, self._resolve_stage_id(stage), + order) def split(self, stage, iterator, lengths, inner_to_outer=True): """ Schedule primitive corresponds to te.split. @@ -132,8 +270,8 @@ def split(self, stage, iterator, lengths, inner_to_outer=True): Parameters ---------- stage : Union[int, Operation, Tensor] - The Stage to be split, can be a Stage order index, Stage operation or stage - output tensor. + The Stage to be split, which can be specified by the integer index, Operation, + or output tensor of the stage. iterator : Iterator The iterator to be split. lengths: List[int] @@ -144,34 +282,74 @@ def split(self, stage, iterator, lengths, inner_to_outer=True): Returns ------- res_its : List[Iterator] - The splitted new Iterators - """ - stage_id = self._resolve_stage_id(stage) + The splitted new Iterators. - self.state_object, res = _ffi_api.StateSplit(self.state_object, stage_id, iterator, lengths, - inner_to_outer) + Notes + ----- + If we do split on an iterator which has stages attached at it(by compute_at), the inner + most iterator of split results will become the new attach point. + """ + self.state_object, res = _ffi_api.StateSplit(self.state_object, + self._resolve_stage_id(stage), + iterator, lengths, inner_to_outer) return res - def fuse(self, stage, iters): - """ Schedule primitive corresponds to te.fuse. + def compute_at(self, stage, target_stage, target_iter): + """ Schedule primitive corresponds to te.compute_at. Parameters ---------- stage : Union[int, Operation, Tensor] - The Stage to be fused, can be a Stage order index, Stage operation or stage - output tensor. - iters : List[Iterator] - The iterators to be fused + The Stage to be compute at, which can be specified by the integer index, Operation, + or output tensor of the stage. + target_stage : Union[int, Operation, Tensor] + The target stage of compute_at, which can be specified by the integer index, Operation, + or output tensor of the stage. + target_iter : Iterator + The target Iterator of compute_at. + + Notes + ----- + After compute_at, we need careful dependency analysis to compute the accurate bound + information. 
However, it is relatively expensive and complicated, so we just fill "None" + as bound for the newly created iterators. + Call ComputeDAG::InferBound on the returned state to get the complete bound information. + """ + self.state_object = _ffi_api.StateComputeAt(self.state_object, + self._resolve_stage_id(stage), + self._resolve_stage_id(target_stage), + target_iter) - Returns - ------- - res_it : Iterator - The fused Iterator + def compute_inline(self, stage): + """ Schedule primitive corresponds to te.compute_inline. + + Parameters + ---------- + stage : Union[int, Operation, Tensor] + The Stage to be compute inlined, which can be specified by the integer index, Operation, + or output tensor of the stage. """ - stage_id = self._resolve_stage_id(stage) + self.state_object = _ffi_api.StateComputeInline(self.state_object, + self._resolve_stage_id(stage)) - self.state_object, res = _ffi_api.StateFuse(self.state_object, stage_id, iters) - return res + def compute_root(self, stage): + """ Schedule primitive corresponds to te.compute_root. + + Parameters + ---------- + stage : Union[int, Operation, Tensor] + The Stage to be compute root, which can be specified by the integer index, Operation, + or output tensor of the stage. + + Notes + ----- + After compute_root, we need careful dependency analysis to compute the accurate bound + information. However, it is relatively expensive and complicated, so we just fill "None" + as bound for the newly created iterators. + Call ComputeDAG::InferBound on the returned state to get the complete bound information. + """ + self.state_object = _ffi_api.StateComputeRoot(self.state_object, + self._resolve_stage_id(stage)) def copy(self): """ Do deep copy of this State. """ diff --git a/src/auto_scheduler/compute_dag.cc b/src/auto_scheduler/compute_dag.cc index a7abcb8a7ebf..d81dff66d402 100644 --- a/src/auto_scheduler/compute_dag.cc +++ b/src/auto_scheduler/compute_dag.cc @@ -270,19 +270,9 @@ std::pair> ComputeDAG::ApplySteps( } // Apply the history steps to TVM schedule + // Call each step's ApplyToSchedule method for (const auto& step : transform_steps) { - // Call each step's ApplyToSchedule method - // Note: some steps have extra parameters that must be passed and they may need different - // return value, so the ApplyToSchedule is not able to be merged to single interface - if (auto ps = step.as()) { - ps->ApplyToSchedule(stages, stage_to_axes); - } else if (auto ps = step.as()) { - ps->ApplyToSchedule(stages, stage_to_axes); - } else if (auto ps = step.as()) { - ps->ApplyToSchedule(stages, stage_to_axes); - } else { - LOG(FATAL) << "Invalid Step"; - } + StepApplyToSchedule(step, stages, stage_to_axes); } return std::make_pair(schedule, operator->()->tensors); @@ -326,15 +316,7 @@ String ComputeDAG::PrintStepsAsPython(const Array& transform_steps) const } // Call each step's PrintAsPythonAPI method for (const auto& step : transform_steps) { - if (auto ps = step.as()) { - ss << ps->PrintAsPythonAPI(&stages, &stage_to_axes); - } else if (auto ps = step.as()) { - ss << ps->PrintAsPythonAPI(&stages, &stage_to_axes); - } else if (auto ps = step.as()) { - ss << ps->PrintAsPythonAPI(&stages, &stage_to_axes); - } else { - LOG(FATAL) << "Invalid Step"; - } + ss << StepPrintAsPythonAPI(step, &stages, &stage_to_axes); } return ss.str(); @@ -352,7 +334,7 @@ State ComputeDAG::InferBound(const State& state) const { ret_state = operator->()->init_state; pstate = ret_state.CopyOnWrite(); pstate->transform_steps = state->transform_steps; - ret_state.DoSteps(*this); + 
ret_state.ApplySteps(*this); } else { ret_state = state; pstate = ret_state.CopyOnWrite(); diff --git a/src/auto_scheduler/loop_state.cc b/src/auto_scheduler/loop_state.cc index 1bfcb9ebc58a..bfe547864ed1 100644 --- a/src/auto_scheduler/loop_state.cc +++ b/src/auto_scheduler/loop_state.cc @@ -90,36 +90,122 @@ Stage::Stage(te::Operation op, StageKind op_type, const Array<Iterator>& iters, data_ = std::move(node); } +/********** AttachMap **********/ +void AttachMap::SetComputeAtIter(int stage_id, int target_stage_id, int target_iter_id) { + AttachMapNode* pnode = CopyOnWrite(); + + // Delete the current entry of this stage + DeleteStageEntry(pnode, stage_id); + + // Store the new stage/iterator relations in the map + IterKey iter_key(target_stage_id, target_iter_id); + pnode->stage_to_attach_iter[stage_id] = iter_key; + pnode->iter_to_attached_stages[iter_key].push_back(stage_id); +} + +void AttachMap::DeleteStage(int stage_id) { + AttachMapNode* pnode = CopyOnWrite(); + // Delete the original stage entry + DeleteStageEntry(pnode, stage_id); +} + +void AttachMap::UpdateIters(const std::vector<IterKey>& original_iters, + const std::vector<IterKey>& new_iters) { + CHECK_EQ(original_iters.size(), new_iters.size()); + AttachMapNode* pnode = CopyOnWrite(); + for (size_t i = 0; i < original_iters.size(); ++i) { + auto entry = pnode->iter_to_attached_stages.find(original_iters[i]); + // We get <IterKey, std::vector<StageKey>> from this map + if (entry == pnode->iter_to_attached_stages.end()) { + // Skip if this iterator does not have any attach relations + continue; + } + + // Update the attaching target of a stage to the new iter in `stage_to_attach_iter` + for (const auto& s : entry->second) { + pnode->stage_to_attach_iter[s] = new_iters[i]; + } + + // Remove the original iterator relation from `iter_to_attached_stages` and add the new + // iterator to it + std::vector<StageKey> attached_stages = std::move(entry->second); + pnode->iter_to_attached_stages.erase(entry); + pnode->iter_to_attached_stages[new_iters[i]] = std::move(attached_stages); + } +} + +void AttachMap::DeleteStageEntry(AttachMapNode* pnode, int stage_id) { + auto old_entry = pnode->stage_to_attach_iter.find(stage_id); + // We get <StageKey, IterKey> from this map + if (old_entry != pnode->stage_to_attach_iter.end()) { + // Delete the stage in `iter_to_attached_stages`; if the corresponding iterator does not have + // any attached stage, delete this item too + auto entry2 = pnode->iter_to_attached_stages.find(old_entry->second); + // We get <IterKey, std::vector<StageKey>> from this map + FindAndDeleteItem(&entry2->second, stage_id); + if (entry2->second.size() == 0) { + pnode->iter_to_attached_stages.erase(entry2); + } + // Delete the stage in `stage_to_attach_iter` + pnode->stage_to_attach_iter.erase(old_entry); + } +} + /********** State **********/ State::State(const Array<te::Operation>& ops) { auto node = make_object<StateNode>(); for (const auto& op : ops) { node->stages.push_back(Stage(op)); } + node->attach_map = AttachMap(make_object<AttachMapNode>()); node->concrete = true; data_ = std::move(node); } /********** Schedule primitives apis for state **********/ -void State::reorder(int stage_id, const Array<Iterator>& order) { +Iterator State::bind(int stage_id, const Iterator& it, IteratorAnnotation thread_type) { const Stage& stage = operator->()->stages[stage_id]; - CHECK_EQ(order.size(), stage->iters.size()) << "The order of all iterators " - << "should be specified"; - Array<Integer> after_ids; - GetIndices(stage->iters, order, &after_ids); - ReorderStep step = ReorderStep(stage_id, after_ids); + if (thread_type < IteratorAnnotation::kVThread || thread_type > IteratorAnnotation::kThreadZ) { +
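// Reject annotations outside [kVThread, kThreadZ]: only thread and vthread axes are valid bind targets. +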
LOG(FATAL) << "thread_type error, valid: kVThread, kBlockX, kBlockY, " + << "kThreadX, kThreadY, kBlockZ, kThreadZ"; + } + AnnotationStep step = AnnotationStep(stage_id, GetIndex(stage->iters, it), thread_type); CopyOnWrite()->transform_steps.push_back(step); - DoReorderStep(step); + return step->ApplyToState(this); } -Array State::split(int stage_id, const Iterator& it, - const Array>& lengths, bool inner_to_outer) { +Iterator State::parallel(int stage_id, const Iterator& it) { const Stage& stage = operator->()->stages[stage_id]; - SplitStep step = - SplitStep(stage_id, GetIndex(stage->iters, it), - it->range.defined() ? it->range->extent : PrimExpr(), lengths, inner_to_outer); + AnnotationStep step = + AnnotationStep(stage_id, GetIndex(stage->iters, it), IteratorAnnotation::kParallel); + CopyOnWrite()->transform_steps.push_back(step); + return step->ApplyToState(this); +} + +Iterator State::unroll(int stage_id, const Iterator& it, int max_unroll) { + const Stage& stage = operator->()->stages[stage_id]; + + // Don't unroll if the extent is larger than max_unroll + if (max_unroll != -1 && it->range.defined()) { + if (auto imm = it->range->extent.as()) { + if (imm->value > max_unroll) { + return it; + } + } + } + + AnnotationStep step = + AnnotationStep(stage_id, GetIndex(stage->iters, it), IteratorAnnotation::kUnroll); + CopyOnWrite()->transform_steps.push_back(step); + return step->ApplyToState(this); +} + +Iterator State::vectorize(int stage_id, const Iterator& it) { + const Stage& stage = operator->()->stages[stage_id]; + AnnotationStep step = + AnnotationStep(stage_id, GetIndex(stage->iters, it), IteratorAnnotation::kVectorize); CopyOnWrite()->transform_steps.push_back(step); - return DoSplitStep(step); + return step->ApplyToState(this); } Iterator State::fuse(int stage_id, const Array& iters) { @@ -128,174 +214,59 @@ Iterator State::fuse(int stage_id, const Array& iters) { GetIndices(stage->iters, iters, &indices); FuseStep step = FuseStep(stage_id, indices); CopyOnWrite()->transform_steps.push_back(step); - return DoFuseStep(step); + return step->ApplyToState(this); } -/********** Step implementations for state **********/ -void State::DoReorderStep(const ReorderStep& step) { - const Stage& stage = operator->()->stages[step->stage_id]; - Array iters; - for (auto x : step->after_ids) { - iters.push_back(stage->iters[x]); - } - StateNode* pstate = CopyOnWrite(); - pstate->stages.Set(step->stage_id, - Stage(stage->op, stage->op_type, iters, stage->compute_at, stage->attrs)); +void State::reorder(int stage_id, const Array& order) { + const Stage& stage = operator->()->stages[stage_id]; + CHECK_EQ(order.size(), stage->iters.size()) << "The order of all iterators " + << "should be specified"; + Array after_ids; + GetIndices(stage->iters, order, &after_ids); + ReorderStep step = ReorderStep(stage_id, after_ids); + CopyOnWrite()->transform_steps.push_back(step); + step->ApplyToState(this); } -// common part for DoSplitStep, DoFollowSplitStep, and DoFollowFusedSplitStep -Array State::DoSplitStepCommon(int stage_id, int iter_id, - const Array>& lengths, - bool inner_to_outer) { +Array State::split(int stage_id, const Iterator& it, + const Array>& lengths, bool inner_to_outer) { const Stage& stage = operator->()->stages[stage_id]; - const Iterator& it = stage->iters[iter_id]; - bool concrete = true; - - Optional tosplit_min, tosplit_extent; - if (it->range.defined()) { - tosplit_min = it->range->min; - tosplit_extent = it->range->extent; - } else { - tosplit_min = NullOpt; - tosplit_extent = 
NullOpt; - } - - Array outs; - for (size_t i = 0; i < lengths.size(); ++i) { - Optional l; - String name; - if (inner_to_outer) { - l = lengths[lengths.size() - i - 1]; - name = it->name + "." + std::to_string(lengths.size() - i); - } else { - l = lengths[i]; - name = it->name + "." + std::to_string(i); - } - Iterator res; - if (l && tosplit_min && tosplit_extent) { - res = Iterator(name, Range::FromMinExtent(tosplit_min.value(), l.value()), it->iter_kind, - IteratorAnnotation::kNone); - tosplit_min = Integer(0); - tosplit_extent = indexdiv(tosplit_extent.value() + l.value() - 1, l.value()); - } else { - res = Iterator(name, Range(), it->iter_kind, IteratorAnnotation::kNone); - tosplit_min = NullOpt; - tosplit_extent = NullOpt; - concrete = false; - } - outs.push_back(std::move(res)); - } - - Range range; - if (tosplit_min && tosplit_extent) { - range = Range::FromMinExtent(tosplit_min.value(), tosplit_extent.value()); - } - if (inner_to_outer) { - outs.push_back(Iterator(it->name + ".0", range, it->iter_kind, IteratorAnnotation::kNone)); - // Reverse the Iterator array - Array temp(outs.rbegin(), outs.rend()); - outs = std::move(temp); - } else { - outs.push_back(Iterator(it->name + "." + std::to_string(lengths.size()), range, it->iter_kind, - IteratorAnnotation::kNone)); - } - - Array new_iters; - new_iters.insert(new_iters.end(), stage->iters.begin(), stage->iters.begin() + iter_id); - new_iters.insert(new_iters.end(), outs.begin(), outs.end()); - new_iters.insert(new_iters.end(), stage->iters.begin() + iter_id + 1, stage->iters.end()); - - StateNode* pstate = CopyOnWrite(); - pstate->stages.Set(stage_id, - Stage(stage->op, stage->op_type, new_iters, stage->compute_at, stage->attrs)); - pstate->concrete &= concrete; - - return outs; + SplitStep step = + SplitStep(stage_id, GetIndex(stage->iters, it), + it->range.defined() ? 
it->range->extent : PrimExpr(), lengths, inner_to_outer); + CopyOnWrite()->transform_steps.push_back(step); + return step->ApplyToState(this); } -Array State::DoSplitStep(const SplitStep& step) { - return DoSplitStepCommon(step->stage_id, step->iter_id, step->lengths, step->inner_to_outer); +void State::compute_at(int stage_id, int target_stage_id, const Iterator& target_iter) { + const Stage& target_stage = operator->()->stages[target_stage_id]; + ComputeAtStep step = + ComputeAtStep(stage_id, target_stage_id, GetIndex(target_stage->iters, target_iter)); + CopyOnWrite()->transform_steps.push_back(step); + step->ApplyToState(this); } -Iterator State::DoFuseStep(const FuseStep& step) { - int stage_id = step->stage_id; - const Stage& stage = operator->()->stages[stage_id]; - - String new_name; - PrimExpr new_extent = 1; - IteratorKind new_iter_kind = IteratorKind::kSpecial; - - for (size_t i = 0; i < step->fused_ids.size(); ++i) { - if (i > 0) { - CHECK_EQ(step->fused_ids[i]->value, step->fused_ids[i - 1]->value + 1); - } - - const Iterator& it = stage->iters[step->fused_ids[i]]; - new_name = new_name + it->name + "@"; - - if (it->range.defined() && new_extent.defined()) { - new_extent = new_extent * it->range->extent; - } else { - new_extent = PrimExpr(); - } - - if (i == 0) { - new_iter_kind = it->iter_kind; - } else { - if (new_iter_kind != it->iter_kind) { - new_iter_kind = IteratorKind::kMixed; - } - } - } +void State::compute_inline(int stage_id) { + ComputeInlineStep step = ComputeInlineStep(stage_id); + CopyOnWrite()->transform_steps.push_back(step); + step->ApplyToState(this); +} - Range range; - if (new_extent.defined()) { - range = Range::FromMinExtent(0, new_extent); - } - Iterator new_it = Iterator(new_name, range, new_iter_kind, IteratorAnnotation::kNone); - Array new_iters; - new_iters.insert(new_iters.end(), stage->iters.begin(), - stage->iters.begin() + step->fused_ids.front()); - new_iters.push_back(new_it); - new_iters.insert(new_iters.end(), stage->iters.begin() + step->fused_ids.back() + 1, - stage->iters.end()); - - StateNode* pstate = CopyOnWrite(); - pstate->stages.Set(stage_id, - Stage(stage->op, stage->op_type, new_iters, stage->compute_at, stage->attrs)); - - return new_it; +void State::compute_root(int stage_id) { + ComputeRootStep step = ComputeRootStep(stage_id); + CopyOnWrite()->transform_steps.push_back(step); + step->ApplyToState(this); } -void State::DoSteps(const ComputeDAG& dag) { +void State::ApplySteps(const ComputeDAG& dag) { CHECK(operator->()->stages.size()) << "Invalid State with empty operation stages."; + // Call each step's ApplyToState method for (const auto& step : operator->()->transform_steps) { - if (auto ps = step.as()) { - DoReorderStep(GetRef(ps)); - } else if (auto ps = step.as()) { - DoSplitStep(GetRef(ps)); - } else if (auto ps = step.as()) { - DoFuseStep(GetRef(ps)); - } else { - LOG(FATAL) << "Invalid step: " << step; - } + StepApplyToState(step, this, dag); } } -static const char* IteratorAnnotationString[] = { - "for", // kNone = 0 - "unroll", // kUnroll = 1 - "vectorize", // kVectorize = 2 - "parallel", // kParallel = 3 - "vthread", // kVThread = 4 - "gpu.blockIdx.x", // kBlockX = 5 - "gpu.threadIdx.x", // kThreadX = 6 - "gpu.blockIdx.y", // kBlockY = 7 - "gpu.threadIdx.y", // kThreadY = 8 - "tensorize" // kTensorized = 9 -}; - // Print stage to ostream void PrintStage(std::ostream* os, int stage_id, const State& state, size_t base_indent, bool delete_trivial_loop) { @@ -332,6 +303,17 @@ void PrintStage(std::ostream* os, int 
stage_id, const State& state, size_t base_ indent += 2; } + + if (state.defined()) { + IterKey iter_key(stage_id, i); + auto pair = state->attach_map->iter_to_attached_stages.find(iter_key); + if (pair != state->attach_map->iter_to_attached_stages.end()) { + // Print the attached stage + for (const auto& attach_stage_id : pair->second) { + PrintStage(os, attach_stage_id, state, base_indent + indent, delete_trivial_loop); + } + } + } } for (size_t j = 0; j < base_indent + indent; ++j) { @@ -386,6 +368,36 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) }); /********** State interface API for ffi **********/ +TVM_REGISTER_GLOBAL("auto_scheduler.StateBind") + .set_body_typed([](State state, int stage_id, const Iterator& it, int thread_type) { + const auto& res = state.bind(stage_id, it, IteratorAnnotation(thread_type)); + return Array{state, res}; + }); + +TVM_REGISTER_GLOBAL("auto_scheduler.StateParallel") + .set_body_typed([](State state, int stage_id, const Iterator& it) { + const auto& res = state.parallel(stage_id, it); + return Array{state, res}; + }); + +TVM_REGISTER_GLOBAL("auto_scheduler.StateUnroll") + .set_body_typed([](State state, int stage_id, const Iterator& it, int max_unroll) { + const auto& res = state.unroll(stage_id, it, max_unroll); + return Array{state, res}; + }); + +TVM_REGISTER_GLOBAL("auto_scheduler.StateVectorize") + .set_body_typed([](State state, int stage_id, const Iterator& it) { + const auto& res = state.vectorize(stage_id, it); + return Array{state, res}; + }); + +TVM_REGISTER_GLOBAL("auto_scheduler.StateFuse") + .set_body_typed([](State state, int stage_id, const Array& iters) { + const auto& res = state.fuse(stage_id, iters); + return Array{state, res}; + }); + TVM_REGISTER_GLOBAL("auto_scheduler.StateReorder") .set_body_typed([](State state, int stage_id, const Array& order) { state.reorder(stage_id, order); @@ -399,10 +411,23 @@ TVM_REGISTER_GLOBAL("auto_scheduler.StateSplit") return Array{state, res}; }); -TVM_REGISTER_GLOBAL("auto_scheduler.StateFuse") - .set_body_typed([](State state, int stage_id, const Array& iters) { - const auto& res = state.fuse(stage_id, iters); - return Array{state, res}; +TVM_REGISTER_GLOBAL("auto_scheduler.StateComputeAt") + .set_body_typed([](State state, int stage_id, int target_stage_id, + const Iterator& target_iter) { + state.compute_at(stage_id, target_stage_id, target_iter); + return state; + }); + +TVM_REGISTER_GLOBAL("auto_scheduler.StateComputeInline") + .set_body_typed([](State state, int stage_id) { + state.compute_inline(stage_id); + return state; + }); + +TVM_REGISTER_GLOBAL("auto_scheduler.StateComputeRoot") + .set_body_typed([](State state, int stage_id) { + state.compute_root(stage_id); + return state; }); TVM_REGISTER_GLOBAL("auto_scheduler.StateEqual").set_body_typed([](State state1, State state2) { diff --git a/src/auto_scheduler/loop_state.h b/src/auto_scheduler/loop_state.h index 04e5304b6943..4d6477b92b0f 100644 --- a/src/auto_scheduler/loop_state.h +++ b/src/auto_scheduler/loop_state.h @@ -51,6 +51,9 @@ #include #include +#include +#include +#include #include "transform_step.h" @@ -79,84 +82,6 @@ enum class ComputeAtKind : int { kIter = 2, }; -/*! \brief The type of an iterator. */ -enum class IteratorKind : int { - /*! \brief Spatial iterator. */ - kSpatial = 0, - /*! \brief Reduction iterator. */ - kReduction = 1, - /*! \brief Fused spatial and reduction iterator. */ - kMixed = 2, - /*! \brief Special iterator. (e.g. virtual root iterator) */ - kSpecial = 3 -}; - -/*! 
\brief The type of an iterator's annotation. */ -enum class IteratorAnnotation : int { - /*! \brief This iterator has no annotation. */ - kNone = 0, - /*! \brief This iterator has been unrolled. */ - kUnroll = 1, - /*! \brief This iterator has been vectorized. */ - kVectorize = 2, - /*! \brief This iterator has been paralleld. */ - kParallel = 3, - /*! \brief This iterator has been bind to vthread. */ - kVThread = 4, - /*! \brief This iterator has been bind to blockIdx.x. */ - kBlockX = 5, - /*! \brief This iterator has been bind to threadIdx.x. */ - kThreadX = 6, - /*! \brief This iterator has been bind to blockIdx.y. */ - kBlockY = 7, - /*! \brief This iterator has been bind to threadIdx.y. */ - kThreadY = 8, - /*! \brief This iterator has been mapped with a tensorize intrinsic. */ - kTensorized = 9 -}; - -/*! - * \brief A for loop iterator - * Similar to tvm::IterVar in `include/tvm/tir/expr.h` - */ -class IteratorNode : public Object { - public: - /*! \brief The name of this iterator. */ - String name; - /*! \brief The range of this iterator. */ - Range range; - /*! \brief The iterator type of this iterator. */ - IteratorKind iter_kind; - /*! \brief The annotation type of this iterator. */ - IteratorAnnotation annotation; - - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("name", &name); - v->Visit("range", &range); - } - - static constexpr const char* _type_key = "auto_scheduler.Iterator"; - TVM_DECLARE_FINAL_OBJECT_INFO(IteratorNode, Object); -}; - -/*! - * \brief Managed reference to IteratorNode. - * \sa IteratorNode - */ -class Iterator : public ObjectRef { - public: - /*! - * \brief The constructor. - * \param name The name of this iterator. - * \param range The range of this iterator. - * \param iter_kind The iterator type of this iterator. - * \param annotation The annotation type of this iterator. - */ - Iterator(String name, Range range, IteratorKind iter_kind, IteratorAnnotation annotation); - - TVM_DEFINE_OBJECT_REF_METHODS(Iterator, ObjectRef, IteratorNode); -}; - /*! \brief Stage-level attributes. */ struct StageAttributes { /*! \brief The maximum steps for the pragma `auto_unroll_max_step`. */ @@ -167,16 +92,16 @@ struct StageAttributes { /*! * \brief A op stage in the compute declaration. - * Similar to te::Stage in `include/schedule.h`. + * Similar to te::Stage in `include/tvm/te/schedule.h`. */ class StageNode : public Object { public: /*! \brief The operator of this stage */ te::Operation op; - /*! \brief The type of this stage. */ - StageKind op_type; /*! \brief The iterators in this stage. */ Array iters; + /*! \brief The type of this stage. */ + StageKind op_type; /*! \brief The compute location of this stage. */ ComputeAtKind compute_at; /*! \brief Other stage-level attributes. */ @@ -185,6 +110,8 @@ class StageNode : public Object { void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("op", &op); v->Visit("iters", &iters); + v->Visit("op_type", &op_type); + v->Visit("compute_at", &compute_at); } static constexpr const char* _type_key = "auto_scheduler.Stage"; @@ -217,6 +144,70 @@ class Stage : public ObjectRef { TVM_DEFINE_OBJECT_REF_COW_METHOD(StageNode); }; +/*! \brief Use stage_id to represent a stage. */ +using StageKey = int; +/*! \brief Use stage_id and iter_id to represent a iterator. */ +using IterKey = std::pair; + +/*! + * \brief stores the compute_at relation between stages + * This stores a bi-directional mapping from stages and iter: + * 1. Stage to its attached iterator + * 2. 
Iterator to the stage attached to it + * You can use AttachMapNode::stage_to_attach_iter and AttachMapNode::iter_to_attached_stages + * to query the relations + */ +class AttachMapNode : public Object { + public: + /*! \brief A Map to store the mapping of a stage to its attached iterator. */ + std::unordered_map<StageKey, IterKey> stage_to_attach_iter; + /*! \brief A Map to store the mapping of an iterator to the stages attached to it. */ + std::unordered_map<IterKey, std::vector<StageKey>> iter_to_attached_stages; + + static constexpr const char* _type_key = "auto_scheduler.AttachMap"; + TVM_DECLARE_FINAL_OBJECT_INFO(AttachMapNode, Object); +}; + +/*! + * \brief Managed reference to AttachMapNode. + * \sa AttachMapNode + */ +class AttachMap : public ObjectRef { + public: + /*! + * \brief Process the stage/iterator mapping after compute_at. + * \param stage_id The index of the source stage of compute_at. + * \param target_stage_id The index of the target stage of compute_at. + * \param target_iter_id The index of the target iterator in the target stage. + */ + void SetComputeAtIter(int stage_id, int target_stage_id, int target_iter_id); + /*! + * \brief A public wrapper of `DeleteStageEntry`, used to delete the entry of a specific stage. + * \param stage_id The index of the stage whose entry will be deleted. + */ + void DeleteStage(int stage_id); + /*! + * \brief Find the relations of original iterators in AttachMap, and update them with the new + * iterators. Both `stage_to_attach_iter` and `iter_to_attached_stages` will be updated. + * \param original_iters The original IterKeys. + * \param new_iters The new IterKeys to update to. + */ + void UpdateIters(const std::vector<IterKey>& original_iters, + const std::vector<IterKey>& new_iters); + + TVM_DEFINE_OBJECT_REF_METHODS(AttachMap, ObjectRef, AttachMapNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(AttachMapNode); + + private: + /*! + * \brief Delete the entry of a specific stage. This will remove the items related to this + * stage in both the `stage_to_attach_iter` and `iter_to_attached_stages` maps. + * \param pnode A mutable pointer to AttachMapNode. + * \param stage_id The index of the stage that will be removed from the map. + */ + static void DeleteStageEntry(AttachMapNode* pnode, int stage_id); +}; + /*! * \brief A state in the search process. * It consists of the current loop structure and a list of transformation steps used to construct @@ -229,6 +220,11 @@ class StateNode : public Object { Array<Stage> stages; /*! \brief History transformation steps. */ Array<Step> transform_steps; + /*! + * \brief The attach relations of stages and iterators. This is used to track the compute_at + * operation. + */ + AttachMap attach_map; /*! * \brief Indicate whether this state has unfilled tile sizes. A concrete state means that all * tile sizes of the state is filled. Only concrete state can be apply to TVM schedule. */ @@ -275,16 +271,59 @@ class State : public ObjectRef { String ToStr(bool delete_trivial_loop = true) const; /*! - * \brief General do step functions with a runtime dynamic dispatcher. This will re-apply all the - * transform steps with the initial state. + * \brief Call each step's ApplyToState with a runtime dynamic dispatcher. This will re-apply + * all the transform steps from the initial state. * \param dag The original ComputeDAG of this state. - * \note This is different from the class member `current_compute_dag`, for some transform step - * may change the op stage structure of the ComputeDAG. + * \note The input `dag` is different from the class member `current_compute_dag`.
+ * This function takes the initial ComputeDAG as input to replay all the history. While the + * `current_compute_dag` is used to track the current stage status, for some transform step may + * change the op stage structure. */ - void DoSteps(const ComputeDAG& dag); + void ApplySteps(const ComputeDAG& dag); - /* Step APIs for State. */ + /********** Step APIs working on single stage **********/ + /*! + * \brief Schedule primitive corresponds to te.bind. + * \param stage_id The index of the stage to be binded. + * \param it The iterator to be binded. + * \param thread_type The thread type to be binded. We dirctly use the IteratorAnnotation as + * this input. + * \return The iterator result after binded. + */ + Iterator bind(int stage_id, const Iterator& it, IteratorAnnotation thread_type); + /*! + * \brief Schedule primitive corresponds to te.parallel. + * \param stage_id The index of the stage to be paralleled. + * \param it The iterator to be paralleled. + * \return The iterator result after parallel. + */ + Iterator parallel(int stage_id, const Iterator& it); + /*! + * \brief Schedule primitive corresponds to te.unroll. + * \param stage_id The index of the stage to be unrolled. + * \param it The iterator to be unrolled. + * \param max_unroll The max unroll limit. Iterator with extent larger than this limit will be + * skipped. + * \return The iterator result after unrolled. + */ + Iterator unroll(int stage_id, const Iterator& it, int max_unroll = -1); + /*! + * \brief Schedule primitive corresponds to te.vectorize. + * \param stage_id The index of the stage to be vectorized. + * \param it The iterator to be vectorized. + * \return The iterator result after vectorize. + */ + Iterator vectorize(int stage_id, const Iterator& it); + /*! + * \brief Schedule primitive corresponds to te.fuse. + * \param stage_id The index of the stage to be fused. + * \param iters The iterators to be fused. + * \return The iterator result after fuse. + * \note If the iterators to be fused have stages attached at them(by compute_at), the fused + * result will become the new attach point. + */ + Iterator fuse(int stage_id, const Array& iters); /*! * \brief Schedule primitive corresponds to te.reorder. * \param stage_id The index of the stage to be reordered. @@ -294,57 +333,46 @@ class State : public ObjectRef { /*! * \brief Schedule primitive corresponds to te.split. * \param stage_id The index of the stage to be split. - * \param it The iterator the be split. + * \param it The iterator to be split. * \param lengths The multiple split factors. Can be None to be filled by search policy. * \param inner_to_outer Whether the factor go from inner to outer, or from outer to inner. * \return The iterator results after split. + * \note If we do split on an iterator which has stages attached at it(by compute_at), the inner + * most iterator of split results will become the new attach point. */ Array split(int stage_id, const Iterator& it, const Array>& lengths, bool inner_to_outer = true); - /*! - * \brief Schedule primitive corresponds to te.fuse. - * \param stage_id The index of the stage to be fused. - * \param iters The iterators to be fused. - * \return The iterator result after fuse. - */ - Iterator fuse(int stage_id, const Array& iters); - TVM_DEFINE_OBJECT_REF_METHODS(State, ObjectRef, StateNode); - TVM_DEFINE_OBJECT_REF_COW_METHOD(StateNode); - - private: - /* Do transform steps - * Note: The following functions only change loop state but do not change transform_history. 
- * We separate these functions out, so you can call them for replay easily given history steps */ + /********** Step APIs working on multiple stages **********/ /*! - * \brief Apply reorder step to current state. - * \param step A ReorderStep. + * \brief Schedule primitive corresponds to te.compute_at. + * \param stage_id The index of the stage to be reordered. + * \param target_stage_id The index of stage that this step will compute at to. + * \param target_iter The iterator in target stage that this step will compute at to. + * \note After compute_at, we need careful dependency analysis to compute the accurate bound + * information. However, it is relatively expensive and complicated, so we just fill "None" as + * bound for the newly created iterators. + * Call ComputeDAG::InferBound on the updated state to get the complete bound information. */ - void DoReorderStep(const ReorderStep& step); + void compute_at(int stage_id, int target_stage_id, const Iterator& target_iter); /*! - * \brief Apply split step to current state. - * \param step A SplitStep. - * \return The iterator results after split. + * \brief Schedule primitive corresponds to te.compute_inline. + * \param stage_id The index of the stage to be reordered. */ - Array DoSplitStep(const SplitStep& step); + void compute_inline(int stage_id); /*! - * \brief Apply fuse step to current state. - * \param step A FuseStep. - * \return The iterator result after fuse. + * \brief Schedule primitive corresponds to te.compute_root. + * \param stage_id The index of the stage to be reordered. + * \note After compute_root, we need careful dependency analysis to compute the accurate bound + * information. However, it is relatively expensive and complicated, so we just fill "None" as + * bound for the newly created iterators. + * Call ComputeDAG::InferBound on the updated state to get the complete bound information. */ - Iterator DoFuseStep(const FuseStep& step); + void compute_root(int stage_id); - /*! - * \brief Common function for DoSplitStep and DoFollowSplitStep(Will be added later). - * \param stage_id The index of the stage to be split. - * \param iter_id The index of the iterator to be split. - * \param lengths The multiple split factors. - * \param inner_to_outer The split direction. - * \return The iterator results after split. 
- */ - Array DoSplitStepCommon(int stage_id, int iter_id, - const Array>& lengths, bool inner_to_outer); + TVM_DEFINE_OBJECT_REF_METHODS(State, ObjectRef, StateNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(StateNode); }; } // namespace auto_scheduler diff --git a/src/auto_scheduler/measure_record.cc b/src/auto_scheduler/measure_record.cc index f6f882edb5a2..39f9ad86c958 100644 --- a/src/auto_scheduler/measure_record.cc +++ b/src/auto_scheduler/measure_record.cc @@ -42,25 +42,6 @@ namespace dmlc { namespace json { -inline std::vector IntArrayToVector(const ::tvm::Array<::tvm::Integer>& data) { - std::vector out; - for (const auto& x : data) { - CHECK(x.defined()); - out.push_back(x); - } - return out; -} - -inline std::vector IntArrayToVector( - const ::tvm::Array<::tvm::Optional<::tvm::Integer>>& data) { - std::vector out; - for (const auto& x : data) { - CHECK(x); - out.push_back(x.value()); - } - return out; -} - template <> struct Handler<::tvm::Array<::tvm::auto_scheduler::Stage>> { inline static void Write(dmlc::JSONWriter* writer, @@ -82,28 +63,10 @@ struct Handler<::tvm::Array<::tvm::auto_scheduler::Step>> { inline static void Write(dmlc::JSONWriter* writer, const ::tvm::Array<::tvm::auto_scheduler::Step>& data) { writer->BeginArray(false); - for (size_t i = 0; i < data.size(); ++i) { + for (const auto& step : data) { writer->WriteArraySeperator(); writer->BeginArray(false); - if (auto ps = data[i].as<::tvm::auto_scheduler::ReorderStepNode>()) { - writer->WriteArrayItem(std::string("RE")); - writer->WriteArrayItem(ps->stage_id); - writer->WriteArrayItem(IntArrayToVector(ps->after_ids)); - } else if (auto ps = data[i].as<::tvm::auto_scheduler::SplitStepNode>()) { - writer->WriteArrayItem(std::string("SP")); - writer->WriteArrayItem(ps->stage_id); - writer->WriteArrayItem(ps->iter_id); - writer->WriteArrayItem(ps->extent ? 
::tvm::auto_scheduler::GetIntImm(ps->extent.value()) - : 0); - writer->WriteArrayItem(IntArrayToVector(ps->lengths)); - writer->WriteArrayItem(static_cast(ps->inner_to_outer)); - } else if (auto ps = data[i].as<::tvm::auto_scheduler::FuseStepNode>()) { - writer->WriteArrayItem(std::string("FU")); - writer->WriteArrayItem(ps->stage_id); - writer->WriteArrayItem(IntArrayToVector(ps->fused_ids)); - } else { - LOG(FATAL) << "Invalid step: " << data[i]; - } + step->WriteToRecord(writer); writer->EndArray(); } writer->EndArray(); @@ -111,67 +74,12 @@ struct Handler<::tvm::Array<::tvm::auto_scheduler::Step>> { inline static void Read(dmlc::JSONReader* reader, ::tvm::Array<::tvm::auto_scheduler::Step>* data) { - std::vector int_list; - bool s, inner_to_outer; - std::string name, scope_name, pragma_type, ti_func_name; - int stage_id, iter_id, extent; - + bool s; reader->BeginArray(); data->clear(); while (reader->NextArrayItem()) { reader->BeginArray(); - s = reader->NextArrayItem(); - CHECK(s); - reader->Read(&name); - if (name == "RE") { - s = reader->NextArrayItem(); - CHECK(s); - reader->Read(&stage_id); - s = reader->NextArrayItem(); - CHECK(s); - reader->Read(&int_list); - ::tvm::Array<::tvm::Integer> after_ids; - for (const auto& i : int_list) { - after_ids.push_back(i); - } - data->push_back(::tvm::auto_scheduler::ReorderStep(stage_id, after_ids)); - } else if (name == "SP") { - s = reader->NextArrayItem(); - CHECK(s); - reader->Read(&stage_id); - s = reader->NextArrayItem(); - CHECK(s); - reader->Read(&iter_id); - s = reader->NextArrayItem(); - CHECK(s); - reader->Read(&extent); - s = reader->NextArrayItem(); - CHECK(s); - reader->Read(&int_list); - s = reader->NextArrayItem(); - CHECK(s); - reader->Read(&inner_to_outer); - ::tvm::Array<::tvm::Optional<::tvm::Integer>> lengths; - for (const auto& i : int_list) { - lengths.push_back(::tvm::Integer(i)); - } - data->push_back(::tvm::auto_scheduler::SplitStep( - stage_id, iter_id, extent == 0 ? 
::tvm::PrimExpr() : extent, lengths, inner_to_outer)); - } else if (name == "FU") { - s = reader->NextArrayItem(); - CHECK(s); - reader->Read(&stage_id); - s = reader->NextArrayItem(); - CHECK(s); - reader->Read(&int_list); - ::tvm::Array<::tvm::Integer> fused_ids; - for (const auto& i : int_list) { - fused_ids.push_back(i); - } - data->push_back(::tvm::auto_scheduler::FuseStep(stage_id, fused_ids)); - } else { - LOG(FATAL) << "Invalid step format"; - } + data->push_back(::tvm::auto_scheduler::StepReadFromRecord(reader)); s = reader->NextArrayItem(); CHECK(!s); } @@ -187,8 +95,8 @@ struct Handler<::tvm::auto_scheduler::StateNode> { writer->EndArray(); } inline static void Read(dmlc::JSONReader* reader, ::tvm::auto_scheduler::StateNode* data) { - reader->BeginArray(); bool s; + reader->BeginArray(); s = reader->NextArrayItem(); CHECK(s); reader->Read(&data->stages); @@ -210,18 +118,17 @@ struct Handler<::tvm::auto_scheduler::SearchTaskNode> { writer->EndArray(); } inline static void Read(dmlc::JSONReader* reader, ::tvm::auto_scheduler::SearchTaskNode* data) { - std::string target_str; bool s; - + std::string str_value; reader->BeginArray(); s = reader->NextArrayItem(); CHECK(s); - reader->Read(&target_str); - data->workload_key = std::move(target_str); + reader->Read(&str_value); + data->workload_key = std::move(str_value); s = reader->NextArrayItem(); CHECK(s); - reader->Read(&target_str); - data->target = ::tvm::Target::Create(target_str); + reader->Read(&str_value); + data->target = ::tvm::Target::Create(str_value); s = reader->NextArrayItem(); CHECK(!s); } @@ -237,11 +144,11 @@ struct Handler<::tvm::auto_scheduler::MeasureInputNode> { writer->EndArray(); } inline static void Read(dmlc::JSONReader* reader, ::tvm::auto_scheduler::MeasureInputNode* data) { - bool s; auto task_node = ::tvm::make_object<::tvm::auto_scheduler::SearchTaskNode>(); auto state_node = ::tvm::make_object<::tvm::auto_scheduler::StateNode>(); state_node->concrete = true; + bool s; reader->BeginArray(); s = reader->NextArrayItem(); CHECK(s); @@ -277,15 +184,14 @@ struct Handler<::tvm::auto_scheduler::MeasureResultNode> { } inline static void Read(dmlc::JSONReader* reader, ::tvm::auto_scheduler::MeasureResultNode* data) { + std::vector double_list; bool s; - std::vector tmp; - reader->BeginArray(); s = reader->NextArrayItem(); CHECK(s); - reader->Read(&tmp); + reader->Read(&double_list); data->costs.clear(); - for (const auto& i : tmp) { + for (const auto& i : double_list) { data->costs.push_back(::tvm::FloatImm(::tvm::DataType::Float(64), i)); } s = reader->NextArrayItem(); diff --git a/src/auto_scheduler/transform_step.cc b/src/auto_scheduler/transform_step.cc index 90b4db838fef..6c672a5215f2 100644 --- a/src/auto_scheduler/transform_step.cc +++ b/src/auto_scheduler/transform_step.cc @@ -28,7 +28,9 @@ #include #include +#include #include +#include #include "loop_state.h" #include "utils.h" @@ -36,6 +38,404 @@ namespace tvm { namespace auto_scheduler { +const char* IteratorAnnotationString[] = { + "for", // kNone = 0 + "unroll", // kUnroll = 1 + "vectorize", // kVectorize = 2 + "parallel", // kParallel = 3 + "vthread", // kVThread = 4 + "blockIdx.x", // kBlockX = 5 + "threadIdx.x", // kThreadX = 6 + "blockIdx.y", // kBlockY = 7 + "threadIdx.y", // kThreadY = 8 + "blockIdx.z", // kBlockZ = 9 + "threadIdx.z", // kThreadZ = 10 + "tensorize" // kTensorized = 11 +}; + +Step StepReadFromRecord(dmlc::JSONReader* reader) { + std::string name; + bool s; + s = reader->NextArrayItem(); + CHECK(s); + reader->Read(&name); + if 
(name == AnnotationStepNode::record_prefix_str) { + return AnnotationStep(reader); + } else if (name == FuseStepNode::record_prefix_str) { + return FuseStep(reader); + } else if (name == ReorderStepNode::record_prefix_str) { + return ReorderStep(reader); + } else if (name == SplitStepNode::record_prefix_str) { + return SplitStep(reader); + } else if (name == ComputeAtStepNode::record_prefix_str) { + return ComputeAtStep(reader); + } else if (name == ComputeInlineStepNode::record_prefix_str) { + return ComputeInlineStep(reader); + } else if (name == ComputeRootStepNode::record_prefix_str) { + return ComputeRootStep(reader); + } else { + LOG(FATAL) << "Invalid step format: " << name; + } + return Step(); +} + +void StepApplyToState(const Step& step, State* state, const ComputeDAG& dag) { + if (auto ps = step.as()) { + ps->ApplyToState(state); + } else if (auto ps = step.as()) { + ps->ApplyToState(state); + } else if (auto ps = step.as()) { + ps->ApplyToState(state); + } else if (auto ps = step.as()) { + ps->ApplyToState(state); + } else if (auto ps = step.as()) { + ps->ApplyToState(state); + } else if (auto ps = step.as()) { + ps->ApplyToState(state); + } else if (auto ps = step.as()) { + ps->ApplyToState(state); + } else { + LOG(FATAL) << "Invalid step: " << step; + } +} + +void StepApplyToSchedule(const Step& step, Array* stages, + StageToAxesMap* stage_to_axes) { + if (auto ps = step.as()) { + ps->ApplyToSchedule(stages, stage_to_axes); + } else if (auto ps = step.as()) { + ps->ApplyToSchedule(stages, stage_to_axes); + } else if (auto ps = step.as()) { + ps->ApplyToSchedule(stages, stage_to_axes); + } else if (auto ps = step.as()) { + ps->ApplyToSchedule(stages, stage_to_axes); + } else if (auto ps = step.as()) { + ps->ApplyToSchedule(stages, stage_to_axes); + } else if (auto ps = step.as()) { + ps->ApplyToSchedule(stages, stage_to_axes); + } else if (auto ps = step.as()) { + ps->ApplyToSchedule(stages, stage_to_axes); + } else { + LOG(FATAL) << "Invalid Step: " << step; + } +} + +String StepPrintAsPythonAPI(const Step& step, Array* stages, + StageToAxesMap* stage_to_axes) { + if (auto ps = step.as()) { + return ps->PrintAsPythonAPI(stages, stage_to_axes); + } else if (auto ps = step.as()) { + return ps->PrintAsPythonAPI(stages, stage_to_axes); + } else if (auto ps = step.as()) { + return ps->PrintAsPythonAPI(stages, stage_to_axes); + } else if (auto ps = step.as()) { + return ps->PrintAsPythonAPI(stages, stage_to_axes); + } else if (auto ps = step.as()) { + return ps->PrintAsPythonAPI(stages, stage_to_axes); + } else if (auto ps = step.as()) { + return ps->PrintAsPythonAPI(stages, stage_to_axes); + } else if (auto ps = step.as()) { + return ps->PrintAsPythonAPI(stages, stage_to_axes); + } else { + LOG(FATAL) << "Invalid Step: " << step; + } + return ""; +} + +/********** Primitives working on single stage **********/ + +/********** Annotation **********/ +AnnotationStep::AnnotationStep(int stage_id, int iter_id, IteratorAnnotation ann) { + auto node = make_object(); + node->stage_id = stage_id; + node->iter_id = iter_id; + node->annotation = ann; + data_ = std::move(node); +} + +AnnotationStep::AnnotationStep(dmlc::JSONReader* reader) { + auto node = make_object(); + bool s; + s = reader->NextArrayItem(); + CHECK(s); + reader->Read(&node->stage_id); + s = reader->NextArrayItem(); + CHECK(s); + reader->Read(&node->iter_id); + s = reader->NextArrayItem(); + CHECK(s); + int int_val; + reader->Read(&int_val); + node->annotation = IteratorAnnotation(int_val); + data_ = std::move(node); +} + 
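+// Record round-trip sketch: WriteToRecord below serializes this step as ["AN", stage_id, iter_id, annotation], e.g. ["AN", 2, 0, 3] for a parallel annotation on iterator 0 of stage 2 (assuming record_prefix_str is "AN" here, consistent with the existing "RE"/"SP"/"FU" prefixes); StepReadFromRecord dispatches on that prefix back into the JSON-reader constructor above. +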
+void AnnotationStepNode::WriteToRecord(dmlc::JSONWriter* writer) const { + writer->WriteArraySeperator(); + writer->WriteString(record_prefix_str); + writer->WriteArrayItem(stage_id); + writer->WriteArrayItem(iter_id); + writer->WriteArrayItem(static_cast(annotation)); +} + +Iterator AnnotationStepNode::ApplyToState(State* state) const { + const Stage& stage = (*state)->stages[stage_id]; + Iterator it = stage->iters[iter_id]; + + CHECK(it->annotation == IteratorAnnotation::kNone); + Iterator new_it = Iterator(it->name, it->range, it->iter_kind, annotation); + Stage new_stage = stage; + new_stage.CopyOnWrite()->iters.Set(iter_id, new_it); + state->CopyOnWrite()->stages.Set(stage_id, std::move(new_stage)); + return new_it; +} + +void AnnotationStepNode::ApplyToSchedule(Array* stages, + StageToAxesMap* stage_to_axes) const { + te::Stage stage = (*stages)[stage_id]; + const Array& axes = (*stage_to_axes)[stage]; + + switch (annotation) { + case IteratorAnnotation::kUnroll: + stage.unroll(axes[iter_id]); + break; + case IteratorAnnotation::kVectorize: + stage.vectorize(axes[iter_id]); + break; + case IteratorAnnotation::kParallel: + stage.parallel(axes[iter_id]); + break; + case IteratorAnnotation::kVThread: + case IteratorAnnotation::kBlockX: + case IteratorAnnotation::kBlockY: + case IteratorAnnotation::kBlockZ: + case IteratorAnnotation::kThreadX: + case IteratorAnnotation::kThreadY: + case IteratorAnnotation::kThreadZ: + stage.bind(axes[iter_id], + te::thread_axis(Range(), IteratorAnnotationString[static_cast(annotation)])); + break; + case IteratorAnnotation::kNone: + break; + default: + LOG(FATAL) << "Invalid Annotation " << static_cast(annotation); + break; + } + + stages->Set(stage_id, std::move(stage)); +} + +String AnnotationStepNode::PrintAsPythonAPI(Array* stages, + StageToAxesMap* stage_to_axes) const { + std::stringstream ss; + const auto& stage = (*stages)[stage_id]; + const auto& iter = (*stage_to_axes)[stage][iter_id]; + + ss << "s[" << CleanName(stage->op->name) << "]."; + switch (annotation) { + case IteratorAnnotation::kUnroll: + ss << "unroll("; + break; + case IteratorAnnotation::kVectorize: + ss << "vectorize("; + break; + case IteratorAnnotation::kParallel: + ss << "parallel("; + break; + case IteratorAnnotation::kVThread: + case IteratorAnnotation::kBlockX: + case IteratorAnnotation::kBlockY: + case IteratorAnnotation::kBlockZ: + case IteratorAnnotation::kThreadX: + case IteratorAnnotation::kThreadY: + case IteratorAnnotation::kThreadZ: + ss << "bind("; + break; + case IteratorAnnotation::kNone: + break; + default: + LOG(FATAL) << "Invalid annotation " << static_cast(annotation); + break; + } + ss << CleanName(iter->var->name_hint); + switch (annotation) { + case IteratorAnnotation::kVThread: + case IteratorAnnotation::kBlockX: + case IteratorAnnotation::kBlockY: + case IteratorAnnotation::kBlockZ: + case IteratorAnnotation::kThreadX: + case IteratorAnnotation::kThreadY: + case IteratorAnnotation::kThreadZ: + ss << ", tvm.thread_axis(\"" << IteratorAnnotationString[static_cast(annotation)] + << "\")"; + break; + default: + break; + } + ss << ")\n"; + + ApplyToSchedule(stages, stage_to_axes); + return ss.str(); +} + +/********** Fuse **********/ +FuseStep::FuseStep(int stage_id, const Array& fused_ids) { + auto node = make_object(); + node->stage_id = stage_id; + for (const auto& x : fused_ids) { + CHECK(x->IsInstance()); + } + node->fused_ids = fused_ids; + data_ = std::move(node); +} + +FuseStep::FuseStep(dmlc::JSONReader* reader) { + auto node = make_object(); + bool 
s; + s = reader->NextArrayItem(); + CHECK(s); + reader->Read(&node->stage_id); + std::vector int_list; + s = reader->NextArrayItem(); + CHECK(s); + reader->Read(&int_list); + ::tvm::Array<::tvm::Integer> fused_ids; + for (const auto& i : int_list) { + fused_ids.push_back(i); + } + node->fused_ids = fused_ids; + data_ = std::move(node); +} + +void FuseStepNode::WriteToRecord(dmlc::JSONWriter* writer) const { + writer->WriteArraySeperator(); + writer->WriteString(record_prefix_str); + writer->WriteArrayItem(stage_id); + writer->WriteArrayItem(IntArrayToVector(fused_ids)); +} + +Iterator FuseStepNode::ApplyToState(State* state) const { + const Stage& stage = (*state)->stages[stage_id]; + size_t old_iter_size = static_cast(stage->iters.size()); + + String new_name; + PrimExpr new_extent = 1; + IteratorKind new_iter_kind = IteratorKind::kSpecial; + + for (size_t i = 0; i < fused_ids.size(); ++i) { + if (i > 0) { + CHECK_EQ(fused_ids[i]->value, fused_ids[i - 1]->value + 1); + } + + if (i != fused_ids.size() - 1) { + const auto& iter_to_attached_stage = (*state)->attach_map->iter_to_attached_stages; + if (iter_to_attached_stage.find(std::make_pair(stage_id, fused_ids[i])) != + iter_to_attached_stage.end()) { + LOG(FATAL) << "Invalid Fuse. Trying to fuse iterators that have been attached by some " + << "stages. State before fusion:\n" + << (*state); + } + } + + const Iterator& it = stage->iters[fused_ids[i]]; + new_name = new_name + it->name + "@"; + + if (it->range.defined() && new_extent.defined()) { + new_extent = new_extent * it->range->extent; + } else { + new_extent = PrimExpr(); + } + + if (i == 0) { + new_iter_kind = it->iter_kind; + } else { + if (new_iter_kind != it->iter_kind) { + new_iter_kind = IteratorKind::kMixed; + } + } + } + + Range range; + if (new_extent.defined()) { + range = Range::FromMinExtent(0, new_extent); + } + Iterator new_it = Iterator(new_name, range, new_iter_kind, IteratorAnnotation::kNone); + Array new_iters; + new_iters.insert(new_iters.end(), stage->iters.begin(), stage->iters.begin() + fused_ids.front()); + new_iters.push_back(new_it); + new_iters.insert(new_iters.end(), stage->iters.begin() + fused_ids.back() + 1, + stage->iters.end()); + + StateNode* pstate = state->CopyOnWrite(); + pstate->stages.Set(stage_id, + Stage(stage->op, stage->op_type, new_iters, stage->compute_at, stage->attrs)); + + // Two vectors are used to represent the iterator relation before and after fuse + // The original iterators in AttachMap will be updated with the new iterators + std::vector from_iters; + std::vector to_iters; + const size_t begin_id = fused_ids.front(), end_id = fused_ids.back(); + for (size_t i = 0; i < old_iter_size; ++i) { + if (i <= begin_id) { + continue; + } else if (i > end_id) { + // move forward + from_iters.emplace_back(stage_id, i); + to_iters.emplace_back(stage_id, i - end_id + begin_id); + } else { + // move to the fused id + from_iters.emplace_back(stage_id, i); + to_iters.emplace_back(stage_id, begin_id); + } + } + pstate->attach_map.UpdateIters(from_iters, to_iters); + + return new_it; +} + +IterVar FuseStepNode::ApplyToSchedule(Array* stages, + StageToAxesMap* stage_to_axes) const { + auto stage = (*stages)[stage_id]; + const Array& axes = stage_to_axes->at(stage); + + Array to_fuse; + for (const auto& i : fused_ids) { + to_fuse.push_back(axes[i]); + } + IterVar fused_axis; + stage.fuse(to_fuse, &fused_axis); + + Array new_axes; + new_axes.insert(new_axes.end(), axes.begin(), axes.begin() + fused_ids.front()); + new_axes.push_back(fused_axis); + 
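// Keep the axes that follow the fused group; they now come right after the single fused axis +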
+
+IterVar FuseStepNode::ApplyToSchedule(Array<te::Stage>* stages,
+                                      StageToAxesMap* stage_to_axes) const {
+  auto stage = (*stages)[stage_id];
+  const Array<IterVar>& axes = stage_to_axes->at(stage);
+
+  Array<IterVar> to_fuse;
+  for (const auto& i : fused_ids) {
+    to_fuse.push_back(axes[i]);
+  }
+  IterVar fused_axis;
+  stage.fuse(to_fuse, &fused_axis);
+
+  Array<IterVar> new_axes;
+  new_axes.insert(new_axes.end(), axes.begin(), axes.begin() + fused_ids.front());
+  new_axes.push_back(fused_axis);
+  new_axes.insert(new_axes.end(), axes.begin() + fused_ids.back() + 1, axes.end());
+
+  stage_to_axes->Set(stage, std::move(new_axes));
+  stages->Set(stage_id, std::move(stage));
+  return fused_axis;
+}
+
+String FuseStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
+                                      StageToAxesMap* stage_to_axes) const {
+  const auto& stage = (*stages)[stage_id];
+  std::stringstream to_fuse;
+
+  for (size_t i = 0; i < fused_ids.size(); ++i) {
+    to_fuse << CleanName(stage_to_axes->at(stage)[fused_ids[i]]->var->name_hint);
+    if (i != fused_ids.size() - 1) {
+      to_fuse << ", ";
+    }
+  }
+
+  std::stringstream ss;
+  const auto& fused = ApplyToSchedule(stages, stage_to_axes);
+
+  ss << CleanName(fused->var->name_hint) << " = s[" << CleanName(stage->op->name) << "].fuse("
+     << to_fuse.str() << ")\n";
+
+  return ss.str();
+}
+
 /********** Reorder **********/
 ReorderStep::ReorderStep(int stage_id, const Array<Integer>& after_ids) {
   auto node = make_object<ReorderStepNode>();
@@ -47,6 +447,41 @@ ReorderStep::ReorderStep(int stage_id, const Array<Integer>& after_ids) {
   data_ = std::move(node);
 }
 
+ReorderStep::ReorderStep(dmlc::JSONReader* reader) {
+  auto node = make_object<ReorderStepNode>();
+  bool s;
+  s = reader->NextArrayItem();
+  CHECK(s);
+  reader->Read(&node->stage_id);
+  s = reader->NextArrayItem();
+  CHECK(s);
+  std::vector<int> int_list;
+  reader->Read(&int_list);
+  ::tvm::Array<::tvm::Integer> after_ids;
+  for (const auto& i : int_list) {
+    after_ids.push_back(i);
+  }
+  node->after_ids = after_ids;
+  data_ = std::move(node);
+}
+
+void ReorderStepNode::WriteToRecord(dmlc::JSONWriter* writer) const {
+  writer->WriteArraySeperator();
+  writer->WriteString(record_prefix_str);
+  writer->WriteArrayItem(stage_id);
+  writer->WriteArrayItem(IntArrayToVector(after_ids));
+}
+
+void ReorderStepNode::ApplyToState(State* state) const {
+  const Stage& stage = (*state)->stages[stage_id];
+  Array<Iterator> iters;
+  for (auto x : after_ids) {
+    iters.push_back(stage->iters[x]);
+  }
+  state->CopyOnWrite()->stages.Set(
+      stage_id, Stage(stage->op, stage->op_type, iters, stage->compute_at, stage->attrs));
+}
+
 void ReorderStepNode::ApplyToSchedule(Array<te::Stage>* stages,
                                       StageToAxesMap* stage_to_axes) const {
   auto stage = (*stages)[stage_id];
@@ -83,6 +518,86 @@ String ReorderStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
 }
 
 /********** Split **********/
+// common part for SplitStep, FollowSplitStep, and FollowFusedSplitStep
+Array<Iterator> ApplySplitToState(State* state, int stage_id, int iter_id,
                                  const Array<Optional<Integer>>& lengths, bool inner_to_outer) {
+  const Stage& stage = (*state)->stages[stage_id];
+  const Iterator& it = stage->iters[iter_id];
+  size_t old_iter_size = stage->iters.size();
+  bool concrete = true;
+
+  Optional<PrimExpr> tosplit_min, tosplit_extent;
+  if (it->range.defined()) {
+    tosplit_min = it->range->min;
+    tosplit_extent = it->range->extent;
+  } else {
+    tosplit_min = NullOpt;
+    tosplit_extent = NullOpt;
+  }
+
+  Array<Iterator> outs;
+  for (size_t i = 0; i < lengths.size(); ++i) {
+    Optional<Integer> l;
+    String name;
+    if (inner_to_outer) {
+      l = lengths[lengths.size() - i - 1];
+      name = it->name + "." + std::to_string(lengths.size() - i);
+    } else {
+      l = lengths[i];
+      name = it->name + "." + std::to_string(i);
+    }
+    Iterator res;
+    if (l && tosplit_min && tosplit_extent) {
+      res = Iterator(name, Range::FromMinExtent(tosplit_min.value(), l.value()), it->iter_kind,
+                     IteratorAnnotation::kNone);
+      tosplit_min = Integer(0);
+      tosplit_extent = indexdiv(tosplit_extent.value() + l.value() - 1, l.value());
+    } else {
+      res = Iterator(name, Range(), it->iter_kind, IteratorAnnotation::kNone);
+      tosplit_min = NullOpt;
+      tosplit_extent = NullOpt;
+      concrete = false;
+    }
+    outs.push_back(std::move(res));
+  }
+
+  Range range;
+  if (tosplit_min && tosplit_extent) {
+    range = Range::FromMinExtent(tosplit_min.value(), tosplit_extent.value());
+  }
+  if (inner_to_outer) {
+    outs.push_back(Iterator(it->name + ".0", range, it->iter_kind, IteratorAnnotation::kNone));
+    // Reverse the Iterator array
+    Array<Iterator> temp(outs.rbegin(), outs.rend());
+    outs = std::move(temp);
+  } else {
+    outs.push_back(Iterator(it->name + "." + std::to_string(lengths.size()), range, it->iter_kind,
+                            IteratorAnnotation::kNone));
+  }
+
+  Array<Iterator> new_iters;
+  new_iters.insert(new_iters.end(), stage->iters.begin(), stage->iters.begin() + iter_id);
+  new_iters.insert(new_iters.end(), outs.begin(), outs.end());
+  new_iters.insert(new_iters.end(), stage->iters.begin() + iter_id + 1, stage->iters.end());
+
+  StateNode* pstate = state->CopyOnWrite();
+  pstate->stages.Set(stage_id,
+                     Stage(stage->op, stage->op_type, new_iters, stage->compute_at, stage->attrs));
+  pstate->concrete &= concrete;
+
+  // Two vectors are used to represent the iterator relation before and after split
+  // The original iterators in AttachMap will be updated with the new iterators
+  std::vector<std::pair<int, int>> from_iters;
+  std::vector<std::pair<int, int>> to_iters;
+  for (size_t i = iter_id; i < old_iter_size; ++i) {
+    from_iters.emplace_back(stage_id, i);
+    to_iters.emplace_back(stage_id, i + lengths.size());
+  }
+  pstate->attach_map.UpdateIters(from_iters, to_iters);
+
+  return outs;
+}
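
A worked example of the extent arithmetic above, using the same factors as the
split exercised in the unit tests further down (the starting extent is hypothetical):

    // it = i with Range(0, 512), lengths = [4, 8, 8], inner_to_outer = true.
    // The inner iterators are created innermost-first with extents 8, 8, 4,
    // and the leftover outer extent is ceil(512 / (4 * 8 * 8)) = 2, so the
    // returned iterators are [i.0(2), i.1(4), i.2(8), i.3(8)].
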
+
 Array<IterVar> ApplySplitToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
                                     int stage_id, int iter_id,
                                     const Array<Optional<Integer>>& lengths,
                                     bool inner_to_outer) {
@@ -171,6 +686,51 @@ SplitStep::SplitStep(int stage_id, int iter_id, Optional<PrimExpr> extent,
   data_ = std::move(node);
 }
 
+SplitStep::SplitStep(dmlc::JSONReader* reader) {
+  auto node = make_object<SplitStepNode>();
+  bool s;
+  s = reader->NextArrayItem();
+  CHECK(s);
+  reader->Read(&node->stage_id);
+  s = reader->NextArrayItem();
+  CHECK(s);
+  reader->Read(&node->iter_id);
+  int int_val;
+  s = reader->NextArrayItem();
+  CHECK(s);
+  reader->Read(&int_val);
+  if (int_val) {
+    node->extent = Integer(int_val);
+  }
+  s = reader->NextArrayItem();
+  CHECK(s);
+  std::vector<int> int_list;
+  reader->Read(&int_list);
+  ::tvm::Array<::tvm::Optional<::tvm::Integer>> lengths;
+  for (const auto& i : int_list) {
+    lengths.push_back(::tvm::Integer(i));
+  }
+  node->lengths = lengths;
+  s = reader->NextArrayItem();
+  CHECK(s);
+  reader->Read(&node->inner_to_outer);
+  data_ = std::move(node);
+}
+
+void SplitStepNode::WriteToRecord(dmlc::JSONWriter* writer) const {
+  writer->WriteArraySeperator();
+  writer->WriteString(record_prefix_str);
+  writer->WriteArrayItem(stage_id);
+  writer->WriteArrayItem(iter_id);
+  writer->WriteArrayItem(extent ? GetIntImm(extent.value()) : 0);
+  writer->WriteArrayItem(IntArrayToVector(lengths));
+  writer->WriteArrayItem(static_cast<int>(inner_to_outer));
+}
+
+Array<Iterator> SplitStepNode::ApplyToState(State* state) const {
+  return ApplySplitToState(state, stage_id, iter_id, lengths, inner_to_outer);
+}
+
 Array<IterVar> SplitStepNode::ApplyToSchedule(Array<te::Stage>* stages,
                                               StageToAxesMap* stage_to_axes) const {
   return ApplySplitToSchedule(stages, stage_to_axes, stage_id, iter_id, lengths, inner_to_outer);
@@ -181,57 +741,185 @@ String SplitStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
   return PrintSplitAsPythonAPI(stages, stage_to_axes, stage_id, iter_id, lengths, inner_to_outer);
 }
 
-/********** Fuse **********/
-FuseStep::FuseStep(int stage_id, const Array<Integer>& fused_ids) {
-  auto node = make_object<FuseStepNode>();
+/********** Primitives working on multiple stages **********/
+
+/********** Compute At **********/
+ComputeAtStep::ComputeAtStep(int stage_id, int target_stage_id, int target_iter_id) {
+  auto node = make_object<ComputeAtStepNode>();
   node->stage_id = stage_id;
-  for (const auto& x : fused_ids) {
-    CHECK(x->IsInstance<IntImmNode>());
-  }
-  node->fused_ids = fused_ids;
+  node->target_stage_id = target_stage_id;
+  node->target_iter_id = target_iter_id;
   data_ = std::move(node);
 }
 
-IterVar FuseStepNode::ApplyToSchedule(Array<te::Stage>* stages,
-                                      StageToAxesMap* stage_to_axes) const {
-  auto stage = (*stages)[stage_id];
-  const Array<IterVar>& axes = stage_to_axes->at(stage);
+ComputeAtStep::ComputeAtStep(dmlc::JSONReader* reader) {
+  auto node = make_object<ComputeAtStepNode>();
+  bool s;
+  s = reader->NextArrayItem();
+  CHECK(s);
+  reader->Read(&node->stage_id);
+  s = reader->NextArrayItem();
+  CHECK(s);
+  reader->Read(&node->target_stage_id);
+  s = reader->NextArrayItem();
+  CHECK(s);
+  reader->Read(&node->target_iter_id);
+  data_ = std::move(node);
+}
 
-  Array<IterVar> to_fuse;
-  for (const auto& i : fused_ids) {
-    to_fuse.push_back(axes[i]);
+void ComputeAtStepNode::WriteToRecord(dmlc::JSONWriter* writer) const {
+  writer->WriteArraySeperator();
+  writer->WriteString(record_prefix_str);
+  writer->WriteArrayItem(stage_id);
+  writer->WriteArrayItem(target_stage_id);
+  writer->WriteArrayItem(target_iter_id);
+}
+
+void ComputeAtStepNode::ApplyToState(State* state) const {
+  const Stage& stage = (*state)->stages[stage_id];
+
+  // Remove the bound information of each iterator since it may no longer be accurate
+  // after compute_at
+  Array<Iterator> new_iters;
+  for (const Iterator& it : stage->iters) {
+    new_iters.push_back(Iterator(it->name, Range(), it->iter_kind, it->annotation));
   }
-  IterVar fused_axis;
-  stage.fuse(to_fuse, &fused_axis);
 
-  Array<IterVar> new_axes;
-  new_axes.insert(new_axes.end(), axes.begin(), axes.begin() + fused_ids.front());
-  new_axes.push_back(fused_axis);
-  new_axes.insert(new_axes.end(), axes.begin() + fused_ids.back() + 1, axes.end());
+  StateNode* pstate = state->CopyOnWrite();
+  pstate->stages.Set(stage_id, Stage(stage->op, stage->op_type, std::move(new_iters),
+                                     ComputeAtKind::kIter, stage->attrs));
+  // Update attach map
+  pstate->attach_map.SetComputeAtIter(stage_id, target_stage_id, target_iter_id);
+}
+
+void ComputeAtStepNode::ApplyToSchedule(Array<te::Stage>* stages,
+                                        StageToAxesMap* stage_to_axes) const {
+  te::Stage stage = (*stages)[stage_id];
+  const auto& target_stage = (*stages)[target_stage_id];
+  const auto& target_axis = (*stage_to_axes)[target_stage][target_iter_id];
+  stage.compute_at(target_stage, target_axis);
 
-  stage_to_axes->Set(stage, std::move(new_axes));
   stages->Set(stage_id, std::move(stage));
-  return fused_axis;
 }
 
-String FuseStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
-                                      StageToAxesMap* stage_to_axes) const {
+String ComputeAtStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
+                                           StageToAxesMap* stage_to_axes) const {
+  std::stringstream ss;
   const auto& stage = (*stages)[stage_id];
-  std::stringstream to_fuse;
+  const auto& target_stage = (*stages)[target_stage_id];
+  ss << "s[" << CleanName(stage->op->name) << "].compute_at(s["
+     << CleanName(target_stage->op->name) << "], "
+     << CleanName((*stage_to_axes)[target_stage][target_iter_id]->var->name_hint) << ")\n";
+  ApplyToSchedule(stages, stage_to_axes);
+  return ss.str();
+}
 
-  for (size_t i = 0; i < fused_ids.size(); ++i) {
-    to_fuse << CleanName(stage_to_axes->at(stage)[fused_ids[i]]->var->name_hint);
-    if (i != fused_ids.size() - 1) {
-      to_fuse << ", ";
-    }
+/********** Compute Inline **********/
+ComputeInlineStep::ComputeInlineStep(int stage_id) {
+  auto node = make_object<ComputeInlineStepNode>();
+  node->stage_id = stage_id;
+  data_ = std::move(node);
+}
+
+ComputeInlineStep::ComputeInlineStep(dmlc::JSONReader* reader) {
+  auto node = make_object<ComputeInlineStepNode>();
+  bool s;
+  s = reader->NextArrayItem();
+  CHECK(s);
+  reader->Read(&node->stage_id);
+  data_ = std::move(node);
+}
+
+void ComputeInlineStepNode::WriteToRecord(dmlc::JSONWriter* writer) const {
+  writer->WriteArraySeperator();
+  writer->WriteString(record_prefix_str);
+  writer->WriteArrayItem(stage_id);
+}
+
+void ComputeInlineStepNode::ApplyToState(State* state) const {
+  const Stage& stage = (*state)->stages[stage_id];
+
+  // Check the validity of compute_inline
+  for (size_t i = 0; i < stage->iters.size(); ++i) {
+    CHECK_EQ((*state)->attach_map->iter_to_attached_stages.count(std::make_pair(stage_id, i)), 0)
+        << "Invalid compute_inline: There are some other stages that are attached to the "
+        << "target stage";
   }
+  StateNode* pstate = state->CopyOnWrite();
+  auto new_stage = pstate->stages[stage_id];
+  new_stage.CopyOnWrite()->compute_at = ComputeAtKind::kInlined;
+  pstate->stages.Set(stage_id, std::move(new_stage));
+  // Update attach map
+  pstate->attach_map.DeleteStage(stage_id);
+}
+
+void ComputeInlineStepNode::ApplyToSchedule(Array<te::Stage>* stages,
+                                            StageToAxesMap* stage_to_axes) const {
+  auto stage = (*stages)[stage_id];
+  stage.compute_inline();
+  stages->Set(stage_id, std::move(stage));
+}
+
+String ComputeInlineStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
+                                               StageToAxesMap* stage_to_axes) const {
   std::stringstream ss;
-  const auto& fused = ApplyToSchedule(stages, stage_to_axes);
+  const auto& stage = (*stages)[stage_id];
+  ss << "s[" << CleanName(stage->op->name) << "].compute_inline()\n";
+  ApplyToSchedule(stages, stage_to_axes);
+  return ss.str();
+}
 
-  ss << CleanName(fused->var->name_hint) << " = s[" << CleanName(stage->op->name) << "].fuse("
-     << to_fuse.str() << ")\n";
+/********** Compute Root **********/
+ComputeRootStep::ComputeRootStep(int stage_id) {
+  auto node = make_object<ComputeRootStepNode>();
+  node->stage_id = stage_id;
+  data_ = std::move(node);
+}
 
+ComputeRootStep::ComputeRootStep(dmlc::JSONReader* reader) {
+  auto node = make_object<ComputeRootStepNode>();
+  bool s;
+  s = reader->NextArrayItem();
+  CHECK(s);
+  reader->Read(&node->stage_id);
+  data_ = std::move(node);
+}
+
+void ComputeRootStepNode::WriteToRecord(dmlc::JSONWriter* writer) const {
+  writer->WriteArraySeperator();
+  writer->WriteString(record_prefix_str);
+  writer->WriteArrayItem(stage_id);
+}
+
+void ComputeRootStepNode::ApplyToState(State* state) const {
+  const Stage& stage = (*state)->stages[stage_id];
+
+  // Remove the bound information of each iterator since it may no longer be accurate
+  // after compute_root
+  Array<Iterator> new_iters;
+  for (const Iterator& it : stage->iters) {
+    new_iters.push_back(Iterator(it->name, Range(), it->iter_kind, it->annotation));
+  }
+
+  StateNode* pstate = state->CopyOnWrite();
+  pstate->stages.Set(stage_id, Stage(stage->op, stage->op_type, std::move(new_iters),
+                                     ComputeAtKind::kRoot, stage->attrs));
+  // Update attach map
+  pstate->attach_map.DeleteStage(stage_id);
+}
+
+void ComputeRootStepNode::ApplyToSchedule(Array<te::Stage>* stages,
+                                          StageToAxesMap* stage_to_axes) const {
+  auto stage = (*stages)[stage_id];
+  stage.compute_root();
+  stages->Set(stage_id, std::move(stage));
+}
+
+String ComputeRootStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
+                                             StageToAxesMap* stage_to_axes) const {
+  std::stringstream ss;
+  const auto& stage = (*stages)[stage_id];
+  ss << "s[" << CleanName(stage->op->name) << "].compute_root()\n";
+  ApplyToSchedule(stages, stage_to_axes);
   return ss.str();
 }
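
For context on the WriteToRecord implementations above: a log line strings such
step records together in one JSON array, so a hypothetical split followed by a
compute_at would contribute fragments like the following (ids and sizes are
invented for illustration):

    // ["SP", 2, 0, 512, [4, 8, 8], 1]  stage_id, iter_id, extent, lengths,
    //                                  inner_to_outer
    // ["CA", 1, 2, 0]                  stage_id, target_stage_id, target_iter_id
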
diff --git a/src/auto_scheduler/transform_step.h b/src/auto_scheduler/transform_step.h
index d840cc009e2d..ce3ca50ffae6 100644
--- a/src/auto_scheduler/transform_step.h
+++ b/src/auto_scheduler/transform_step.h
@@ -20,29 +20,34 @@
 /*!
  * \file auto_scheduler/transform_step.h
  * \brief Transformation steps. For each schedule primitive, there is a corresponding transform
- * step. The implementation of each step consists of 2 parts:
- * - transform_step.cc: How each step interacts with TE and TE's schedule primitives
- * - loop_state.cc: How each step updates LoopState
+ * step.
  *
  * \note To add a new transform step:
  * Take fuse step for example:
- * 1. Define class `FuseStepNode`, `FuseStep` in `transform_steps.h`, and implement its construction
- *    function `FuseStep::FuseStep(...)` in `transform_steps.cc`
- * 2. Implement `FuseStepNode::ApplyToSchedule` and `FuseStepNode::PrintAsPythonAPI`.
+ * 1. Define class `FuseStepNode`, `FuseStep` in `transform_steps.h`, and implement its first
+ *    construction function `FuseStep::FuseStep()` in `transform_steps.cc`.
+ * 2. Implement `FuseStepNode::ApplyToSchedule()` and `FuseStepNode::PrintAsPythonAPI()`.
  *    - In these two functions you need to lower this step with tvm's te schedule API
- * 3. Implement `State::fuse` and `State::DoFuseStep`.
+ * 3. Implement `FuseStepNode::ApplyToState` and the state API `State::fuse`.
  *    - In these two functions you need to incrementally update all data structures in State with
- *      CopyOnWrite style
- * 4. Add you step to `ComputeDAG::ApplySteps` and make sure it works.
- * 5. Add log record serialization support in `struct Handler<Array<::tvm::auto_scheduler::Step>>`
- *    in `record.cc`.
- * 6. Add its corresponding Python API to `loop_state.py` and necessary unit test.
+ *      CopyOnWrite style.
+ * 4. Add your step implementation to `StepApplyToState`, `StepApplyToSchedule` and
+ *    `StepPrintAsPythonAPI`, and make sure it works.
+ * 5. Log record serialization support:
+ *    - Add `FuseStepNode::WriteToRecord`, which takes a mutable JSONWriter pointer as input and
+ *      outputs the record to it.
+ *    - Add another construction function that takes a mutable JSONReader as input; this reads a
+ *      step record from the reader and creates the step.
+ *    - Add the step implementation to `StepReadFromRecord`.
+ * 6. Add its corresponding Python API to `loop_state.py` and a necessary unit test; the test
+ *    should consist of at least two parts: a functional test and a record serialization test.
  */
 
 #ifndef TVM_AUTO_SCHEDULER_TRANSFORM_STEP_H_
 #define TVM_AUTO_SCHEDULER_TRANSFORM_STEP_H_
 
 #include <dmlc/common.h>
+#include <dmlc/json.h>
 #include <tvm/node/node.h>
 #include <tvm/te/schedule.h>
 
@@ -53,6 +58,92 @@ namespace auto_scheduler {
 
 typedef Map<te::Stage, Array<tir::IterVar>, ObjectHash, ObjectEqual> StageToAxesMap;
 
+/*! \brief The type of an iterator. */
+enum class IteratorKind : int {
+  /*! \brief Spatial iterator. */
+  kSpatial = 0,
+  /*! \brief Reduction iterator. */
+  kReduction = 1,
+  /*! \brief Fused spatial and reduction iterator. */
+  kMixed = 2,
+  /*! \brief Special iterator. (e.g. virtual root iterator) */
+  kSpecial = 3
+};
+
+/*! \brief The type of an iterator's annotation. */
+enum class IteratorAnnotation : int {
+  /*! \brief This iterator has no annotation. */
+  kNone = 0,
+  /*! \brief This iterator has been unrolled. */
+  kUnroll = 1,
+  /*! \brief This iterator has been vectorized. */
+  kVectorize = 2,
+  /*! \brief This iterator has been parallelized. */
+  kParallel = 3,
+  /*! \brief This iterator has been bound to vthread. */
+  kVThread = 4,
+  /*! \brief This iterator has been bound to blockIdx.x. */
+  kBlockX = 5,
+  /*! \brief This iterator has been bound to threadIdx.x. */
+  kThreadX = 6,
+  /*! \brief This iterator has been bound to blockIdx.y. */
+  kBlockY = 7,
+  /*! \brief This iterator has been bound to threadIdx.y. */
+  kThreadY = 8,
+  /*! \brief This iterator has been bound to blockIdx.z. */
+  kBlockZ = 9,
+  /*! \brief This iterator has been bound to threadIdx.z. */
+  kThreadZ = 10,
+  /*! \brief This iterator has been mapped with a tensorize intrinsic. */
+  kTensorize = 11
+};
+
+extern const char* IteratorAnnotationString[];
+
+/*!
+ * \brief A for-loop iterator.
+ * Similar to tvm::IterVar in `include/tvm/tir/expr.h`
+ */
+class IteratorNode : public Object {
+ public:
+  /*! \brief The name of this iterator. */
+  String name;
+  /*! \brief The range of this iterator. */
+  Range range;
+  /*! \brief The iterator type of this iterator. */
+  IteratorKind iter_kind;
+  /*! \brief The annotation type of this iterator. */
+  IteratorAnnotation annotation;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("name", &name);
+    v->Visit("range", &range);
+    v->Visit("iter_kind", &iter_kind);
+    v->Visit("annotation", &annotation);
+  }
+
+  static constexpr const char* _type_key = "auto_scheduler.Iterator";
+  TVM_DECLARE_FINAL_OBJECT_INFO(IteratorNode, Object);
+};
+
+/*!
+ * \brief Managed reference to IteratorNode.
+ * \sa IteratorNode
+ */
+class Iterator : public ObjectRef {
+ public:
+  /*!
+   * \brief The constructor.
+   * \param name The name of this iterator.
+   * \param range The range of this iterator.
+   * \param iter_kind The iterator type of this iterator.
+   * \param annotation The annotation type of this iterator.
+   */
+  Iterator(String name, Range range, IteratorKind iter_kind, IteratorAnnotation annotation);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(Iterator, ObjectRef, IteratorNode);
+};
+
 /*!
  * \brief The base class of transformation steps. Each step has its corresponding tvm.te
  * schedule primitives.
@@ -62,6 +153,12 @@ class StepNode : public Object {
   /*! \brief The index of the stage. */
   int stage_id;
 
+  /*!
+   * \brief Serialize the current step record to JSONWriter.
+   * \param writer The output JSONWriter.
+   */
+  virtual void WriteToRecord(dmlc::JSONWriter* writer) const = 0;
+
   static constexpr const char* _type_key = "auto_scheduler.Step";
   TVM_DECLARE_BASE_OBJECT_INFO(StepNode, Object);
 };
@@ -75,6 +172,172 @@ class Step : public ObjectRef {
   TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Step, ObjectRef, StepNode);
 };
 
+// Forward declaration
+class State;
+class ComputeDAG;
+
+/*!
+ * \brief Read a step record from JSONReader and create the corresponding step.
+ * \param reader The input JSONReader.
+ */
+Step StepReadFromRecord(dmlc::JSONReader* reader);
+
+/*!
+ * \brief Apply the step to State.
+ * \param step The step to be applied to State.
+ * \param state A mutable pointer to State.
+ * \param dag The original ComputeDAG of this state.
+ */
+void StepApplyToState(const Step& step, State* state, const ComputeDAG& dag);
+
+/*!
+ * \brief Apply the step to tvm.schedule.
+ * \param step The step to be applied to tvm.schedule.
+ * \param stages A pointer to a `te::Stage` Array.
+ * \param stage_to_axes A pointer to a StageToAxesMap.
+ */
+void StepApplyToSchedule(const Step& step, Array<te::Stage>* stages,
+                         StageToAxesMap* stage_to_axes);
+
+/*!
+ * \brief Print the step as equivalent python schedule API.
+ * \param step The step to be printed as python API.
+ * \param stages A pointer to a `te::Stage` Array.
+ * \param stage_to_axes A pointer to a StageToAxesMap.
+ * \return Python schedule code.
+ */
+String StepPrintAsPythonAPI(const Step& step, Array<te::Stage>* stages,
+                            StageToAxesMap* stage_to_axes);
+
+/********** Primitives working on single stage **********/
+
+/*!
+ * \brief Annotation step that corresponds to vectorize, parallel, unroll and thread binding.
+ * (i.e. te::Stage::vectorize, te::Stage::parallel, te::Stage::unroll, te::Stage::bind)
+ */
+class AnnotationStepNode : public StepNode {
+ public:
+  /*! \brief The index of the iterator to add annotation. */
+  int iter_id;
+  /*! \brief The annotation type of this step. */
+  IteratorAnnotation annotation;
+
+  void WriteToRecord(dmlc::JSONWriter* writer) const final;
+
+  /*!
+   * \brief Apply the current step to State.
+   * \param state A mutable pointer to State.
+   * \return The iterator result after annotation.
+   */
+  Iterator ApplyToState(State* state) const;
+
+  /*!
+   * \brief Apply the current step to tvm.schedule.
+   * \param stages A pointer to a `te::Stage` Array.
+   * \param stage_to_axes A pointer to a StageToAxesMap.
+   */
+  void ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
+
+  /*!
+   * \brief Print the current step as equivalent python schedule API.
+   * \param stages A pointer to a `te::Stage` Array.
+   * \param stage_to_axes A pointer to a StageToAxesMap.
+   * \return Python schedule code.
+   */
+  String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
+
+  static constexpr const char* record_prefix_str = "AN";
+
+  static constexpr const char* _type_key = "auto_scheduler.AnnotationStep";
+  TVM_DECLARE_FINAL_OBJECT_INFO(AnnotationStepNode, Object);
+};
+
+/*!
+ * \brief Managed reference to AnnotationStepNode.
+ * \sa AnnotationStepNode
+ */
+class AnnotationStep : public Step {
+ public:
+  /*!
+   * \brief The constructor.
+   * \param stage_id The index of the stage to add annotation.
+   * \param iter_id The index of the iterator to add annotation.
+   * \param ann The annotation type of this step.
+   */
+  AnnotationStep(int stage_id, int iter_id, IteratorAnnotation ann);
+
+  /*!
+   * \brief The constructor used to read a step record from JSONReader and create the
+   * corresponding step.
+   * \param reader The input JSONReader.
+   */
+  explicit AnnotationStep(dmlc::JSONReader* reader);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(AnnotationStep, Step, AnnotationStepNode);
+};
+
+/*! \brief Fuse step that corresponds to te::Stage::fuse */
+class FuseStepNode : public StepNode {
+ public:
+  /*! \brief The ids of iterators to fuse. */
+  Array<Integer> fused_ids;
+
+  void WriteToRecord(dmlc::JSONWriter* writer) const final;
+
+  /*!
+   * \brief Apply the current step to State.
+   * \param state A mutable pointer to State.
+   * \return The iterator result after fuse.
+   * \note If the iterators to be fused have stages attached to them (by compute_at), the fused
+   * result will become the new attach point.
+   */
+  Iterator ApplyToState(State* state) const;
+
+  /*!
+   * \brief Apply the current step to tvm.schedule.
+   * \param stages A pointer to a `te::Stage` Array.
+   * \param stage_to_axes A pointer to a StageToAxesMap.
+   * \return The iterator result after fuse.
+   */
+  tir::IterVar ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
+
+  /*!
+   * \brief Print the current step as equivalent python schedule API.
+   * \param stages A pointer to a `te::Stage` Array.
+   * \param stage_to_axes A pointer to a StageToAxesMap.
+   * \return Python schedule code.
+   */
+  String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
+
+  static constexpr const char* record_prefix_str = "FU";
+
+  static constexpr const char* _type_key = "auto_scheduler.FuseStep";
+  TVM_DECLARE_FINAL_OBJECT_INFO(FuseStepNode, Object);
+};
+
+/*!
+ * \brief Managed reference to FuseStepNode.
+ * \sa FuseStepNode
+ */
+class FuseStep : public Step {
+ public:
+  /*!
+   * \brief The constructor.
+   * \param stage_id The index of the stage to be fused.
+   * \param fused_ids The indices of the iterators to be fused.
+   */
+  FuseStep(int stage_id, const Array<Integer>& fused_ids);
+
+  /*!
+   * \brief The constructor used to read a step record from JSONReader and create the
+   * corresponding step.
+   * \param reader The input JSONReader.
+   */
+  explicit FuseStep(dmlc::JSONReader* reader);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(FuseStep, Step, FuseStepNode);
+};
+
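
A sketch of how the two constructors declared above are meant to cooperate during
deserialization (the record content is hypothetical, and the exact dispatch logic
lives in `StepReadFromRecord`):

    // Writing a FuseStep(3, [0, 1, 2]) produces the record ["FU", 3, [0, 1, 2]].
    // When reading it back, StepReadFromRecord consumes the "FU" prefix and then
    // delegates to the FuseStep(dmlc::JSONReader*) constructor, which reads
    // stage_id and fused_ids from the remaining array items.
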
 /*! \brief Reorder step that corresponds to te::Stage::reorder */
 class ReorderStepNode : public StepNode {
  public:
@@ -84,21 +347,31 @@ class ReorderStepNode : public StepNode {
    */
   Array<Integer> after_ids;
 
+  void WriteToRecord(dmlc::JSONWriter* writer) const final;
+
+  /*!
+   * \brief Apply the current step to State.
+   * \param state A mutable pointer to State.
+   */
+  void ApplyToState(State* state) const;
+
   /*!
-   * \brief Apply the current state to tvm.schedule
+   * \brief Apply the current step to tvm.schedule.
    * \param stages A pointer to a `te::Stage` Array.
    * \param stage_to_axes A pointer to a StageToAxesMap.
    */
   void ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
 
   /*!
-   * \brief Print step as equivalent python schedule API.
+   * \brief Print the current step as equivalent python schedule API.
    * \param stages A pointer to a `te::Stage` Array.
    * \param stage_to_axes A pointer to a StageToAxesMap.
    * \return Python schedule code.
    */
   String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
 
+  static constexpr const char* record_prefix_str = "RE";
+
   static constexpr const char* _type_key = "auto_scheduler.ReorderStep";
   TVM_DECLARE_FINAL_OBJECT_INFO(ReorderStepNode, Object);
 };
@@ -116,6 +389,13 @@ class ReorderStep : public Step {
    */
   ReorderStep(int stage_id, const Array<Integer>& after_ids);
 
+  /*!
+   * \brief The constructor used to read a step record from JSONReader and create the
+   * corresponding step.
+   * \param reader The input JSONReader.
+   */
+  explicit ReorderStep(dmlc::JSONReader* reader);
+
   TVM_DEFINE_OBJECT_REF_METHODS(ReorderStep, Step, ReorderStepNode);
 };
 
@@ -137,8 +417,19 @@ class SplitStepNode : public StepNode {
    */
   bool inner_to_outer;
 
+  void WriteToRecord(dmlc::JSONWriter* writer) const final;
+
   /*!
-   * \brief Apply the current state to tvm.schedule
+   * \brief Apply the current step to State.
+   * \param state A mutable pointer to State.
+   * \return The iterator results after split.
+   * \note If the iterator to be split has stages attached to it (by compute_at), the innermost
+   * iterator of the split results will become the new attach point.
+   */
+  Array<Iterator> ApplyToState(State* state) const;
+
+  /*!
+   * \brief Apply the current step to tvm.schedule.
    * \param stages A pointer to a `te::Stage` Array.
    * \param stage_to_axes A pointer to a StageToAxesMap.
    * \return The iterator results after split.
@@ -147,13 +438,15 @@ class SplitStepNode : public StepNode {
                                  StageToAxesMap* stage_to_axes) const;
 
   /*!
-   * \brief Print step as equivalent python schedule API.
+   * \brief Print the current step as equivalent python schedule API.
    * \param stages A pointer to a `te::Stage` Array.
    * \param stage_to_axes A pointer to a StageToAxesMap.
    * \return Python schedule code.
    */
   String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
 
+  static constexpr const char* record_prefix_str = "SP";
+
   static constexpr const char* _type_key = "auto_scheduler.SplitStep";
   TVM_DECLARE_FINAL_OBJECT_INFO(SplitStepNode, Object);
 };
@@ -175,49 +468,195 @@ class SplitStep : public Step {
   SplitStep(int stage_id, int iter_id, Optional<PrimExpr> extent,
             const Array<Optional<Integer>>& lengths, bool inner_to_outer);
 
+  /*!
+   * \brief The constructor used to read a step record from JSONReader and create the
+   * corresponding step.
+   * \param reader The input JSONReader.
+   */
+  explicit SplitStep(dmlc::JSONReader* reader);
+
   TVM_DEFINE_OBJECT_REF_METHODS(SplitStep, Step, SplitStepNode);
 };
 
-/*! \brief Fuse step that corresponds to te::Stage::fuse */
-class FuseStepNode : public StepNode {
+/********** Primitives working on multiple stages **********/
+
+/*! \brief Compute at step that corresponds to te::Stage::compute_at */
+class ComputeAtStepNode : public StepNode {
  public:
-  /*! \brief The ids of iterators to fuse. */
-  Array<Integer> fused_ids;
+  /*! \brief The index of the stage that this step will compute at to. */
+  int target_stage_id;
+  /*! \brief The index of the iterator in the target stage that this step will compute at to. */
+  int target_iter_id;
+
+  void WriteToRecord(dmlc::JSONWriter* writer) const final;
+
+  /*!
+   * \brief Apply the current step to State.
+   * \param state A mutable pointer to State.
+   * \note After compute_at, we need careful dependency analysis to compute the accurate bound
+   * information. However, it is relatively expensive and complicated, so we just fill "None" as
+   * the bound for the newly created iterators.
+   * Call ComputeDAG::InferBound on the updated state to get the complete bound information.
+   */
+  void ApplyToState(State* state) const;
+
+  /*!
+   * \brief Apply the current step to tvm.schedule.
+   * \param stages A pointer to a `te::Stage` Array.
+   * \param stage_to_axes A pointer to a StageToAxesMap.
+   */
+  void ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
+
+  /*!
+   * \brief Print the current step as equivalent python schedule API.
+   * \param stages A pointer to a `te::Stage` Array.
+   * \param stage_to_axes A pointer to a StageToAxesMap.
+   * \return Python schedule code.
+   */
+  String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
+
+  static constexpr const char* record_prefix_str = "CA";
+
+  static constexpr const char* _type_key = "auto_scheduler.ComputeAtStep";
+  TVM_DECLARE_FINAL_OBJECT_INFO(ComputeAtStepNode, Object);
+};
+
+/*!
+ * \brief Managed reference to ComputeAtStepNode.
+ * \sa ComputeAtStepNode
+ */
+class ComputeAtStep : public Step {
+ public:
+  /*!
+   * \brief The constructor.
+   * \param stage_id The index of the stage to be computed at the target.
+   * \param target_stage_id The index of the stage that this step will compute at to.
+   * \param target_iter_id The index of the iterator in the target stage that this step will
+   * compute at to.
+   */
+  ComputeAtStep(int stage_id, int target_stage_id, int target_iter_id);
+
+  /*!
+   * \brief The constructor used to read a step record from JSONReader and create the
+   * corresponding step.
+   * \param reader The input JSONReader.
+   */
+  explicit ComputeAtStep(dmlc::JSONReader* reader);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(ComputeAtStep, Step, ComputeAtStepNode);
+};
+
+/*! \brief Compute inline step that corresponds to te::Stage::compute_inline */
+class ComputeInlineStepNode : public StepNode {
+ public:
+  void WriteToRecord(dmlc::JSONWriter* writer) const final;
 
   /*!
-   * \brief Apply the current state to tvm.schedule
+   * \brief Apply the current step to State.
+   * \param state A mutable pointer to State.
+   */
+  void ApplyToState(State* state) const;
+
+  /*!
+   * \brief Apply the current step to tvm.schedule.
    * \param stages A pointer to a `te::Stage` Array.
    * \param stage_to_axes A pointer to a StageToAxesMap.
-   * \return The iterator result after fuse.
    */
-  tir::IterVar ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
+  void ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
 
   /*!
-   * \brief Print step as equivalent python schedule API.
+   * \brief Print the current step as equivalent python schedule API.
    * \param stages A pointer to a `te::Stage` Array.
    * \param stage_to_axes A pointer to a StageToAxesMap.
    * \return Python schedule code.
    */
   String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
 
-  static constexpr const char* _type_key = "auto_scheduler.FuseStep";
-  TVM_DECLARE_FINAL_OBJECT_INFO(FuseStepNode, Object);
+  static constexpr const char* record_prefix_str = "CI";
+
+  static constexpr const char* _type_key = "auto_scheduler.ComputeInlineStep";
+  TVM_DECLARE_FINAL_OBJECT_INFO(ComputeInlineStepNode, Object);
 };
 
 /*!
- * \brief Managed reference to FuseStepNode.
- * \sa FuseStepNode
+ * \brief Managed reference to ComputeInlineStepNode.
+ * \sa ComputeInlineStepNode
  */
-class FuseStep : public Step {
+class ComputeInlineStep : public Step {
  public:
   /*!
    * \brief The constructor.
-   * \param stage_id The index of the stage to be fused.
-   * \param fused_ids The index of the iterators to be fused.
+   * \param stage_id The index of the stage to be inlined.
    */
-  FuseStep(int stage_id, const Array<Integer>& fused_ids);
+  explicit ComputeInlineStep(int stage_id);
 
-  TVM_DEFINE_OBJECT_REF_METHODS(FuseStep, Step, FuseStepNode);
+  /*!
+   * \brief The constructor used to read a step record from JSONReader and create the
+   * corresponding step.
+   * \param reader The input JSONReader.
+   */
+  explicit ComputeInlineStep(dmlc::JSONReader* reader);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(ComputeInlineStep, Step, ComputeInlineStepNode);
+};
+
+/*! \brief Compute root step that corresponds to te::Stage::compute_root */
+class ComputeRootStepNode : public StepNode {
+ public:
+  void WriteToRecord(dmlc::JSONWriter* writer) const final;
+
+  /*!
+   * \brief Apply the current step to State.
+   * \param state A mutable pointer to State.
+   * \note After compute_root, we need careful dependency analysis to compute the accurate bound
+   * information. However, it is relatively expensive and complicated, so we just fill "None" as
+   * the bound for the newly created iterators.
+   * Call ComputeDAG::InferBound on the updated state to get the complete bound information.
+   */
+  void ApplyToState(State* state) const;
+
+  /*!
+   * \brief Apply the current step to tvm.schedule.
+   * \param stages A pointer to a `te::Stage` Array.
+   * \param stage_to_axes A pointer to a StageToAxesMap.
+   */
+  void ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
+
+  /*!
+   * \brief Print the current step as equivalent python schedule API.
+   * \param stages A pointer to a `te::Stage` Array.
+   * \param stage_to_axes A pointer to a StageToAxesMap.
+   * \return Python schedule code.
+   */
+  String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
+
+  static constexpr const char* record_prefix_str = "CR";
+
+  static constexpr const char* _type_key = "auto_scheduler.ComputeRootStep";
+  TVM_DECLARE_FINAL_OBJECT_INFO(ComputeRootStepNode, Object);
+};
+
+/*!
+ * \brief Managed reference to ComputeRootStepNode.
+ * \sa ComputeRootStepNode
+ */
+class ComputeRootStep : public Step {
+ public:
+  /*!
+   * \brief The constructor.
+   * \param stage_id The index of the stage to be computed at root.
+   */
+  explicit ComputeRootStep(int stage_id);
+
+  /*!
+   * \brief The constructor used to read a step record from JSONReader and create the
+   * corresponding step.
+   * \param reader The input JSONReader.
+   */
+  explicit ComputeRootStep(dmlc::JSONReader* reader);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(ComputeRootStep, Step, ComputeRootStepNode);
 };
 
 }  // namespace auto_scheduler
diff --git a/src/auto_scheduler/utils.h b/src/auto_scheduler/utils.h
index 5637780e3991..de800da13b64 100644
--- a/src/auto_scheduler/utils.h
+++ b/src/auto_scheduler/utils.h
@@ -63,7 +63,7 @@ struct hash<std::pair<T1, T2>> {
 namespace tvm {
 namespace auto_scheduler {
 
-/********** Utilities for Array, std::string **********/
+/********** Utilities for Array, std::vector, std::string **********/
 /*! \brief Get the first appearance index of elements in an Array */
 template <typename T>
 inline void GetIndices(const Array<T>& array, const Array<T>& to_locate, Array<Integer>* indices) {
@@ -89,6 +89,15 @@ inline int GetIndex(const Array<T>& array, const T& to_locate) {
   return -1;
 }
 
+/*! \brief Delete an item from a std::vector if it exists. */
+template <typename T>
+inline void FindAndDeleteItem(std::vector<T>* array, const T& to_delete) {
+  auto iter = std::find(array->begin(), array->end(), to_delete);
+  if (iter != array->end()) {
+    array->erase(iter);
+  }
+}
+
 /*! \brief Replace a sub-string to another sub-string in a string */
 inline void StrReplace(std::string* base, const std::string& from, const std::string& to) {
   auto pos = base->find(from);
@@ -98,6 +107,27 @@ inline void StrReplace(std::string* base, const std::string& from, const std::string& to) {
   }
 }
 
+/*! \brief Convert an Array<Integer> to std::vector<int>. */
+inline std::vector<int> IntArrayToVector(const ::tvm::Array<::tvm::Integer>& data) {
+  std::vector<int> out;
+  for (const auto& x : data) {
+    CHECK(x.defined());
+    out.push_back(x);
+  }
+  return out;
+}
+
+/*! \brief Convert an Array<Optional<Integer>> to std::vector<int>. */
+inline std::vector<int> IntArrayToVector(
+    const ::tvm::Array<::tvm::Optional<::tvm::Integer>>& data) {
+  std::vector<int> out;
+  for (const auto& x : data) {
+    CHECK(x);
+    out.push_back(x.value());
+  }
+  return out;
+}
+
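
A short note on why two IntArrayToVector overloads exist: fuse and reorder store
their ids as Array<Integer>, while split lengths are Array<Optional<Integer>> so
that an unknown factor is representable; both must be flattened to std::vector<int>
before being written into a record (the values below are illustrative):

    // Array<Integer>{0, 1, 2}            -> std::vector<int>{0, 1, 2}
    // Array<Optional<Integer>>{4, 8, 8}  -> std::vector<int>{4, 8, 8}
    // An undefined element fails the CHECK: only concrete values are recorded.
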
 /********** Utilities for TVM Containers / ByteArray **********/
 /*! \brief Compute mean of a FloatImm array */
 inline double FloatArrayMean(const Array<PrimExpr>& float_array) {
diff --git a/tests/python/unittest/test_auto_scheduler_common.py b/tests/python/unittest/test_auto_scheduler_common.py
index 078e1ae8e854..fa22fdc5597c 100644
--- a/tests/python/unittest/test_auto_scheduler_common.py
+++ b/tests/python/unittest/test_auto_scheduler_common.py
@@ -40,7 +40,7 @@ def matmul_auto_scheduler_test_rename_0(N, M, K):
     C = te.compute((N, M), lambda i, j: te.sum(A[i][k] * B[k][j], axis=[k]), name='C')
     return [A, B, C]
 
-
+@auto_scheduler.register_workload
 def conv2d_nchw_bn_relu(N, H, W, CI, CO, kernel_size, strides, padding, dilation=1):
     data = te.placeholder((N, CI, H, W), name='Data')
     kernel = te.placeholder((CO, CI, kernel_size, kernel_size), name='Kernel')
diff --git a/tests/python/unittest/test_auto_scheduler_loop_state.py b/tests/python/unittest/test_auto_scheduler_loop_state.py
index 0801d9200275..32ea8faa84d0 100644
--- a/tests/python/unittest/test_auto_scheduler_loop_state.py
+++ b/tests/python/unittest/test_auto_scheduler_loop_state.py
@@ -26,8 +26,8 @@
 
 from test_auto_scheduler_common import matmul_auto_scheduler_test, conv2d_nchw_bn_relu
 
-def test_split_fuse_reorder():
-    A, B, C = matmul_auto_scheduler_test(512, 512, 512)
+def test_split_fuse_reorder_annotation():
+    A, B, C = matmul_auto_scheduler_test(N=512, M=512, K=512)
     dag = auto_scheduler.ComputeDAG([A, B, C])
     s0 = dag.get_init_state()
     i, j, k = s0[C].iters
@@ -61,5 +61,88 @@ def test_split_fuse_reorder():
     assert s1[C].iters[4].range.extent == 8
     assert s1[C].iters[5].range.extent == 2
 
+    res = s1.bind(C, i1, "blockIdx.x")
+    assert res == s1[C].iters[0]
+    assert res.annotation == auto_scheduler.loop_state.State.ANNOTATION_TRANS_TABLE["blockIdx.x"]
+
+    res = s1.bind(C, i2, "vthread")
+    assert res == s1[C].iters[1]
+    assert res.annotation == auto_scheduler.loop_state.State.ANNOTATION_TRANS_TABLE["vthread"]
+
+    res = s1.bind(C, i3, "threadIdx.y")
+    assert res == s1[C].iters[2]
+    assert res.annotation == auto_scheduler.loop_state.State.ANNOTATION_TRANS_TABLE["threadIdx.y"]
+
+    res = s1.parallel(C, j1)
+    assert res == s1[C].iters[3]
+    assert res.annotation == auto_scheduler.loop_state.State.ANNOTATION_TRANS_TABLE["parallel"]
+
+    res = s1.unroll(C, j2)
+    assert res == s1[C].iters[4]
+    assert res.annotation == auto_scheduler.loop_state.State.ANNOTATION_TRANS_TABLE["unroll"]
+
+    res = s1.vectorize(C, j3)
+    assert res == s1[C].iters[5]
+    assert res.annotation == auto_scheduler.loop_state.State.ANNOTATION_TRANS_TABLE["vectorize"]
+
+
+def test_compute_at_root_inline():
+    dag = auto_scheduler.ComputeDAG(conv2d_nchw_bn_relu(N=1, H=224, W=224, CI=3, CO=64,
+                                                        kernel_size=7, strides=2, padding=3))
+    s0 = dag.get_init_state()
+
+    # data, padding, kernel = 0, 1, 2
+    conv = s0.stage_ops[3]
+    # bias = 4
+    bias_add = s0.stage_ops[5]
+    # bn_scale = 6
+    bn_mul = s0.stage_ops[7]
+    # bn_offset = 8
+    bn_add = s0.stage_ops[9]
+    relu = s0.stage_ops[10]
+
+    s0.compute_inline(bn_add)
+    assert s0[bn_add].compute_at == 1
+
+    s0.compute_inline(bn_mul)
+    assert s0[bn_mul].compute_at == 1
+
+    s0.compute_inline(bias_add)
+    assert s0[bias_add].compute_at == 1
+
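The integer compute_at codes asserted in this test can be read off the test itself:
compute_root yields 0, compute_inline yields 1, and compute_at yields 2, mirroring
the order of the C++ ComputeAtKind enum (kRoot, kInlined, kIter).
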
+    assert s0[conv].iters[0].range.extent == 1
+    assert s0[conv].iters[1].range.extent == 64
+    assert s0[conv].iters[2].range.extent == 112
+    assert s0[conv].iters[3].range.extent == 112
+    assert s0[conv].iters[4].range.extent == 3
+    assert s0[conv].iters[5].range.extent == 7
+    assert s0[conv].iters[6].range.extent == 7
+    s0.compute_at(conv, relu, s0[relu].iters[2])
+    assert s0[conv].compute_at == 2
+    s0 = dag.infer_bound_from_state(s0)
+    assert s0[conv].iters[0].range.extent == 1
+    assert s0[conv].iters[1].range.extent == 1
+    assert s0[conv].iters[2].range.extent == 1
+    assert s0[conv].iters[3].range.extent == 112
+    assert s0[conv].iters[4].range.extent == 3
+    assert s0[conv].iters[5].range.extent == 7
+    assert s0[conv].iters[6].range.extent == 7
+
+    s0.compute_root(bn_mul)
+    assert s0[bn_mul].compute_at == 0
+
+    s0.compute_root(conv)
+    assert s0[conv].compute_at == 0
+    s0 = dag.infer_bound_from_state(s0)
+    assert s0[conv].iters[0].range.extent == 1
+    assert s0[conv].iters[1].range.extent == 64
+    assert s0[conv].iters[2].range.extent == 112
+    assert s0[conv].iters[3].range.extent == 112
+    assert s0[conv].iters[4].range.extent == 3
+    assert s0[conv].iters[5].range.extent == 7
+    assert s0[conv].iters[6].range.extent == 7
+
+
 if __name__ == "__main__":
-    test_split_fuse_reorder()
+    test_split_fuse_reorder_annotation()
+    test_compute_at_root_inline()
diff --git a/tests/python/unittest/test_auto_scheduler_measure.py b/tests/python/unittest/test_auto_scheduler_measure.py
index d6e6c51a28ba..333d20e4ce9a 100644
--- a/tests/python/unittest/test_auto_scheduler_measure.py
+++ b/tests/python/unittest/test_auto_scheduler_measure.py
@@ -18,7 +18,8 @@
 """ Test measurement and log serialization. """
 
 import tvm
-from tvm import auto_scheduler
+import topi
+from tvm import te, auto_scheduler
 import tempfile
 
 from test_auto_scheduler_common import get_tiled_matmul
@@ -28,7 +29,44 @@ def test_record():
     if not tvm.runtime.enabled("llvm"):
         return
 
-    dag, s = get_tiled_matmul()
+    A = te.placeholder((512, 512), name='A')
+    B = te.placeholder((512, 512), name='B')
+    k = te.reduce_axis((0, 512), name='k')
+    C = te.compute((512, 512), lambda i, j: te.sum(A[i][k] * B[k][j], axis=[k]), name='C')
+    D = topi.nn.relu(C)
+    k = te.reduce_axis((0, 512), name='k')
+    E = te.compute((512, 512), lambda i, j: te.sum(A[i][k] * D[k][j], axis=[k]), name='E')
+    F = topi.nn.relu(E)
+
+    dag = auto_scheduler.ComputeDAG([A, B, F])
+    s = dag.get_init_state()
+
+    # Split
+    its0 = s.split(C, s[C].iters[0], [4, 8, 8])
+    its1 = s.split(C, s[C].iters[4], [8, 4, 4])
+    # Reorder
+    s.reorder(C, [its0[0], its1[0], its0[1], its1[1], its0[2], its1[2], its0[3], s[C].iters[8],
+                  its1[3]])
+    # Fuse
+    s.fuse(C, [s[C].iters[0], s[C].iters[1], s[C].iters[2]])
+    # Compute at
+    s.split(F, s[F].iters[0], [2])
+    s.compute_at(E, F, s[F].iters[0])
+    # Compute inline
+    s.compute_inline(D)
+    # Compute root
+    s.compute_root(D)
+    # Parallel
+    s.parallel(C, s[C].iters[0])
+    # Thread bind (blockIdx & threadIdx are meant for GPU; they are used here only to test
+    # record serialization)
+    s.bind(C, s[C].iters[1], "blockIdx.x")
+    s.bind(C, s[C].iters[2], "threadIdx.z")
+    s.bind(C, s[C].iters[3], "vthread")
+    # Unroll
+    s.unroll(C, s[C].iters[4])
+    # Vectorize
+    s.vectorize(C, s[C].iters[6])
+
     target = tvm.target.create("llvm")
     task = auto_scheduler.SearchTask(dag, "test", target)