Skip to content

Commit

Permalink
[MetaSchedule] Fix anchor-block flow with empty design space generator (
Browse files Browse the repository at this point in the history
  • Loading branch information
Icemist authored and yongwww committed Feb 27, 2023
1 parent 8a4edb6 commit 5c73039
Show file tree
Hide file tree
Showing 8 changed files with 50 additions and 25 deletions.
7 changes: 6 additions & 1 deletion include/tvm/meta_schedule/database.h
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,11 @@ class TuningRecordNode : public runtime::Object {
* argument information.
*/
ObjectRef AsJSON() const;
/*!
* \brief Check if this tuning record has valid trace instructions and successful run results.
* \return The check result.
*/
bool IsValid() const;
};

/*!
Expand Down Expand Up @@ -210,7 +215,7 @@ class DatabaseNode : public runtime::Object {
*/
virtual void CommitTuningRecord(const TuningRecord& record) = 0;
/*!
* \brief Get the top K tuning records of given workload from the database.
* \brief Get the top K valid tuning records of given workload from the database.
* \param workload The workload to be searched for.
* \param top_k The number of top records to be returned.
* \return An array of top K tuning records for the given workload.
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/meta_schedule/database/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ def commit_tuning_record(self, record: TuningRecord) -> None:
_ffi_api.DatabaseCommitTuningRecord(self, record) # type: ignore # pylint: disable=no-member

def get_top_k(self, workload: Workload, top_k: int) -> List[TuningRecord]:
"""Get the top K tuning records of given workload from the database.
"""Get the top K valid tuning records of given workload from the database.
Parameters
----------
Expand Down
15 changes: 15 additions & 0 deletions src/meta_schedule/database/database.cc
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,21 @@ ObjectRef TuningRecordNode::AsJSON() const {
json_args_info};
}

bool TuningRecordNode::IsValid() const {
if (!GetNumValidInstructions(trace->insts, /*remove_postproc*/ true)) {
return false;
}
if (run_secs.defined()) {
for (const auto& run_sec : run_secs.value()) {
// kMaxMeanTime(1e10) is used as a stub for undefined measurement times.
if (run_sec.defined() && run_sec->value != SortTuningRecordByMeanRunSecs::kMaxMeanTime) {
return true;
}
}
}
return false;
}

TuningRecord TuningRecord::FromJSON(const ObjectRef& json_obj, const Workload& workload) {
tir::Trace trace{nullptr};
Optional<Array<FloatImm>> run_secs{nullptr};
Expand Down
12 changes: 3 additions & 9 deletions src/meta_schedule/database/json_database.cc
Original file line number Diff line number Diff line change
Expand Up @@ -128,13 +128,7 @@ class JSONDatabaseNode : public DatabaseNode {
results.reserve(top_k);
for (const TuningRecord& record : this->tuning_records_) {
auto run_secs = record->run_secs;
if (!run_secs.defined() || run_secs.value().empty() ||
std::all_of(run_secs.value().begin(), run_secs.value().end(),
// kMaxMeanTime(1e10) is used as a stub for undefined measurement times.
[](tvm::FloatImm v) {
return v.defined() &&
v->value == SortTuningRecordByMeanRunSecs::kMaxMeanTime;
})) {
if (!record->IsValid()) {
continue;
}
if (record->workload.same_as(workload) ||
Expand All @@ -146,8 +140,8 @@ class JSONDatabaseNode : public DatabaseNode {
}
}
if (results.size() < static_cast<size_t>(top_k)) {
LOG(WARNING) << "The size of the GetTopK result is smaller than requested. There are not "
"enough valid records in the database for this workload.";
LOG(WARNING) << "Returned tuning records less than requested(" << results.size() << " of "
<< top_k << " asked).";
}
return results;
}
Expand Down
13 changes: 3 additions & 10 deletions src/meta_schedule/database/memory_database.cc
Original file line number Diff line number Diff line change
Expand Up @@ -68,14 +68,7 @@ class MemoryDatabaseNode : public DatabaseNode {
std::vector<TuningRecord> results;
results.reserve(records.size());
for (const TuningRecord& record : records) {
auto run_secs = record->run_secs;
if (!run_secs.defined() || run_secs.value().empty() ||
std::all_of(run_secs.value().begin(), run_secs.value().end(),
// kMaxMeanTime(1e10) is used as a stub for undefined measurement times.
[](tvm::FloatImm v) {
return v.defined() &&
v->value == SortTuningRecordByMeanRunSecs::kMaxMeanTime;
})) {
if (!record->IsValid()) {
continue;
}
if (record->workload.same_as(workload) ||
Expand All @@ -88,8 +81,8 @@ class MemoryDatabaseNode : public DatabaseNode {
return {results.begin(), results.begin() + top_k};
} else {
if (results.size() < static_cast<size_t>(top_k)) {
LOG(WARNING) << "The size of the GetTopK result is smaller than requested. There are not "
"enough valid records in the database for this workload.";
LOG(WARNING) << "Returned tuning records less than requested(" << results.size() << " of "
<< top_k << " asked).";
}
return results;
}
Expand Down
4 changes: 3 additions & 1 deletion src/meta_schedule/trace_apply.cc
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,9 @@ void InlinePostBlocks(Schedule sch, Trace anchor_trace, Target target) {
// Spatial blocks which are not referenced in the anchor trace will be inlined here.
auto block_sref = sch->GetSRef(block);
if (IsSpatial(block_sref) && !get_block_names.count(name)) {
if (IsOutputBlock(sch->state(), block_sref, GetScopeRoot(sch->state(), block_sref, false))) {
StmtSRef scopeRoot =
(name != "root") ? GetScopeRoot(sch->state(), block_sref, false) : block_sref;
if (IsOutputBlock(sch->state(), block_sref, scopeRoot)) {
last_block_idx = inline_todos.size();
}
inline_todos.push_back(name);
Expand Down
8 changes: 8 additions & 0 deletions src/tir/schedule/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -490,6 +490,14 @@ Array<ObjectRef> TranslateInputRVs(const Array<ObjectRef>& inputs,
void TranslateAddOutputRVs(const Array<ObjectRef>& old_outputs, const Array<ObjectRef>& new_outputs,
std::unordered_map<const Object*, const Object*>* rv_map);

/*!
* \brief Counts the number of trace instructions.
* \param insts The instructions representing a trace.
* \param remove_postproc If postprocessing instructions are removed.
* \return Number of instructions.
*/
int GetNumValidInstructions(const Array<Instruction>& insts, bool remove_postproc);

} // namespace tir
} // namespace tvm

Expand Down
14 changes: 11 additions & 3 deletions tests/python/unittest/test_meta_schedule_relay_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -724,7 +724,7 @@ def test_module_equality_ignore_ndarray():
np.testing.assert_allclose(ref, out, rtol=1e-4, atol=1e-4)


def _test_anchor_tuning(target):
def _test_anchor_tuning(target, space):
data_shape = (128, 128)
weight_shape1 = (128, 128)
weight_shape2 = (128, 128)
Expand Down Expand Up @@ -756,6 +756,7 @@ def _test_anchor_tuning(target):
target=target,
params=params,
work_dir=work_dir,
space=space,
max_trials_global=4,
strategy="replay-trace",
module_equality=module_equality,
Expand All @@ -779,8 +780,15 @@ def _test_anchor_tuning(target):
np.testing.assert_allclose(ref, out, atol=1e-3)


def test_anchor_tuning_cpu():
_test_anchor_tuning("llvm --num-cores=4")
@pytest.mark.parametrize(
"space",
[
ms.space_generator.PostOrderApply(),
ms.space_generator.PostOrderApply(sch_rules=[], postprocs=[], mutator_probs={}),
],
)
def test_anchor_tuning_cpu(space):
_test_anchor_tuning("llvm --num-cores=4", space)


def test_anchor_tuning_cpu_link_params():
Expand Down

0 comments on commit 5c73039

Please sign in to comment.