Skip to content

Commit

Permalink
[Meta Schedule] Add Winograd Test for Customizable Search Space (#24)
Browse files Browse the repository at this point in the history
* Finish relay tuning with ApplyHistoryBest.

* Fix tune relay.

* Add winograd conv test cpu.

Add cuda test.

Nit.

Nits.
  • Loading branch information
zxybazh authored Jan 26, 2022
1 parent 467abf5 commit 494101d
Show file tree
Hide file tree
Showing 2 changed files with 654 additions and 26 deletions.
65 changes: 49 additions & 16 deletions src/meta_schedule/space_generator/post_order_apply.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,18 +40,7 @@ class BlockCollector : public tir::StmtVisitor {
blocks_to_collect_.clear();
VisitStmt(func->body);
for (const String& block_name : blocks_to_collect_) {
tir::BlockRV block_rv = sch_->GetBlock(block_name, func_name_);
// pick out the blocks with annotation for customized search space
if (Optional<ObjectRef> custom_sch_rule_name_opt =
tir::GetAnn<String>(sch_->GetSRef(block_rv), "schedule_rule")) {
String custom_sch_rule_name = Downcast<String>(custom_sch_rule_name_opt.value());
if (custom_sch_rule_name != "None") {
const auto* custom_sch_rule_func = runtime::Registry::Get(custom_sch_rule_name);
(*custom_sch_rule_func)(sch_, block_rv);
}
} else {
results_.push_back(block_rv);
}
results_.push_back(sch_->GetBlock(block_name, func_name_));
}
}
}
Expand Down Expand Up @@ -109,17 +98,61 @@ class PostOrderApplyNode : public SpaceGeneratorNode {
tir::Schedule sch = tir::Schedule::Traced( //
/*mod=*/mod_, //
/*rand_state=*/ForkSeed(&this->rand_state_), //
/*debug_mode=*/0, //
/*debug_mode=*/0, // tir::kVerifySRefTree | tir::kVerifyCachedFlags
/*error_render_level=*/tir::ScheduleErrorRenderLevel::kDetail);

std::vector<ScheduleAndUnvisitedBlocks> stack;
Array<tir::Schedule> result{sch};
Array<tir::Schedule> result;
Array<tir::BlockRV> all_blocks = BlockCollector::Collect(sch), func_blocks, non_func_blocks;
for (const tir::BlockRV& block_rv : all_blocks) {
if (Optional<String> custom_sch_rule_name_opt =
tir::GetAnn<String>(sch->GetSRef(block_rv), "schedule_rule")) {
if (custom_sch_rule_name_opt.value() != "None") {
func_blocks.push_back(block_rv);
}
} else {
non_func_blocks.push_back(block_rv);
}
}

// only do this once for schedule rules on block annotations
stack.emplace_back(sch, func_blocks);
while (!stack.empty()) {
// get the stack.top()
tir::Schedule sch;
Array<tir::BlockRV> blocks;
std::tie(sch, blocks) = stack.back();
stack.pop_back();
// if all blocks are visited
if (blocks.empty()) {
result.push_back(sch);
continue;
}
// otherwise, get the last block that is not visited
tir::BlockRV block_rv = blocks.back();
blocks.pop_back();
if (sch->HasBlock(block_rv)) {
// pick out the blocks with annotation for customized search space
Optional<String> custom_sch_rule_name_opt =
tir::GetAnn<String>(sch->GetSRef(block_rv), "schedule_rule");
ICHECK(custom_sch_rule_name_opt.defined() && custom_sch_rule_name_opt.value() != "None");
String custom_sch_rule_name = custom_sch_rule_name_opt.value();
const auto* custom_sch_rule_func = runtime::Registry::Get(custom_sch_rule_name);
CHECK(custom_sch_rule_func) << "The given custom schedule function is not defined!";
Array<tir::Schedule> applied = (*custom_sch_rule_func)(sch, block_rv);
for (const tir::Schedule& sch : applied) {
stack.emplace_back(sch, blocks);
}
} else {
stack.emplace_back(sch, blocks);
}
}

// Enumerate the schedule rules first because you can
// always concat multiple schedule rules as one
Array<tir::BlockRV> all_blocks = BlockCollector::Collect(sch);
for (ScheduleRule sch_rule : sch_rules_) {
for (const tir::Schedule& sch : result) {
stack.emplace_back(sch, all_blocks);
stack.emplace_back(sch, non_func_blocks);
}
result.clear();

Expand Down
Loading

0 comments on commit 494101d

Please sign in to comment.