diff --git a/include/tvm/meta_schedule/database.h b/include/tvm/meta_schedule/database.h
index a1dd4a412eec..bea9cc3e371b 100644
--- a/include/tvm/meta_schedule/database.h
+++ b/include/tvm/meta_schedule/database.h
@@ -144,6 +144,11 @@ class TuningRecordNode : public runtime::Object {
    * argument information.
    */
   ObjectRef AsJSON() const;
+  /*!
+   * \brief Check if this tuning record has valid trace instructions and successful run results.
+   * \return Whether the record is valid.
+   */
+  bool IsValid() const;
 };
 
 /*!
@@ -210,7 +215,7 @@ class DatabaseNode : public runtime::Object {
    */
   virtual void CommitTuningRecord(const TuningRecord& record) = 0;
   /*!
-   * \brief Get the top K tuning records of given workload from the database.
+   * \brief Get the top K valid tuning records of given workload from the database.
    * \param workload The workload to be searched for.
    * \param top_k The number of top records to be returned.
    * \return An array of top K tuning records for the given workload.
diff --git a/python/tvm/meta_schedule/database/database.py b/python/tvm/meta_schedule/database/database.py
index b95cb1ddd7db..6621db9633d9 100644
--- a/python/tvm/meta_schedule/database/database.py
+++ b/python/tvm/meta_schedule/database/database.py
@@ -205,7 +205,7 @@ def commit_tuning_record(self, record: TuningRecord) -> None:
         _ffi_api.DatabaseCommitTuningRecord(self, record)  # type: ignore # pylint: disable=no-member
 
     def get_top_k(self, workload: Workload, top_k: int) -> List[TuningRecord]:
-        """Get the top K tuning records of given workload from the database.
+        """Get the top K valid tuning records of given workload from the database.
 
         Parameters
         ----------
diff --git a/src/meta_schedule/database/database.cc b/src/meta_schedule/database/database.cc
index da1d1db8f1cc..649429f9bc13 100644
--- a/src/meta_schedule/database/database.cc
+++ b/src/meta_schedule/database/database.cc
@@ -113,6 +113,21 @@ ObjectRef TuningRecordNode::AsJSON() const {
                           json_args_info};
 }
 
+bool TuningRecordNode::IsValid() const {
+  if (!GetNumValidInstructions(trace->insts, /*remove_postproc*/ true)) {
+    return false;
+  }
+  if (run_secs.defined()) {
+    for (const auto& run_sec : run_secs.value()) {
+      // kMaxMeanTime(1e10) is used as a stub for undefined measurement times.
+      if (run_sec.defined() && run_sec->value != SortTuningRecordByMeanRunSecs::kMaxMeanTime) {
+        return true;
+      }
+    }
+  }
+  return false;
+}
+
 TuningRecord TuningRecord::FromJSON(const ObjectRef& json_obj, const Workload& workload) {
   tir::Trace trace{nullptr};
   Optional<Array<FloatImm>> run_secs{nullptr};
diff --git a/src/meta_schedule/database/json_database.cc b/src/meta_schedule/database/json_database.cc
index 0e51e262df0f..10ff89a7ce18 100644
--- a/src/meta_schedule/database/json_database.cc
+++ b/src/meta_schedule/database/json_database.cc
@@ -128,13 +128,7 @@ class JSONDatabaseNode : public DatabaseNode {
     results.reserve(top_k);
     for (const TuningRecord& record : this->tuning_records_) {
       auto run_secs = record->run_secs;
-      if (!run_secs.defined() || run_secs.value().empty() ||
-          std::all_of(run_secs.value().begin(), run_secs.value().end(),
-                      // kMaxMeanTime(1e10) is used as a stub for undefined measurement times.
-                      [](tvm::FloatImm v) {
-                        return v.defined() &&
-                               v->value == SortTuningRecordByMeanRunSecs::kMaxMeanTime;
-                      })) {
+      if (!record->IsValid()) {
         continue;
       }
       if (record->workload.same_as(workload) ||
@@ -146,8 +140,8 @@ class JSONDatabaseNode : public DatabaseNode {
       }
     }
     if (results.size() < static_cast<size_t>(top_k)) {
-      LOG(WARNING) << "The size of the GetTopK result is smaller than requested. There are not "
-                      "enough valid records in the database for this workload.";
+      LOG(WARNING) << "Returned fewer tuning records than requested (" << results.size() << " of "
+                   << top_k << ").";
     }
     return results;
   }
diff --git a/src/meta_schedule/database/memory_database.cc b/src/meta_schedule/database/memory_database.cc
index 533a86acacfd..b003606c9cc0 100644
--- a/src/meta_schedule/database/memory_database.cc
+++ b/src/meta_schedule/database/memory_database.cc
@@ -68,14 +68,7 @@ class MemoryDatabaseNode : public DatabaseNode {
     std::vector<TuningRecord> results;
     results.reserve(records.size());
     for (const TuningRecord& record : records) {
-      auto run_secs = record->run_secs;
-      if (!run_secs.defined() || run_secs.value().empty() ||
-          std::all_of(run_secs.value().begin(), run_secs.value().end(),
-                      // kMaxMeanTime(1e10) is used as a stub for undefined measurement times.
-                      [](tvm::FloatImm v) {
-                        return v.defined() &&
-                               v->value == SortTuningRecordByMeanRunSecs::kMaxMeanTime;
-                      })) {
+      if (!record->IsValid()) {
         continue;
       }
       if (record->workload.same_as(workload) ||
@@ -88,8 +81,8 @@ class MemoryDatabaseNode : public DatabaseNode {
       return {results.begin(), results.begin() + top_k};
     } else {
       if (results.size() < static_cast<size_t>(top_k)) {
-        LOG(WARNING) << "The size of the GetTopK result is smaller than requested. There are not "
-                        "enough valid records in the database for this workload.";
+        LOG(WARNING) << "Returned fewer tuning records than requested (" << results.size()
+                     << " of " << top_k << ").";
       }
       return results;
     }
diff --git a/src/meta_schedule/trace_apply.cc b/src/meta_schedule/trace_apply.cc
index 9213d414e1b5..e60fdf5b9d2b 100644
--- a/src/meta_schedule/trace_apply.cc
+++ b/src/meta_schedule/trace_apply.cc
@@ -75,7 +75,9 @@ void InlinePostBlocks(Schedule sch, Trace anchor_trace, Target target) {
     // Spatial blocks which are not referenced in the anchor trace will be inlined here.
     auto block_sref = sch->GetSRef(block);
     if (IsSpatial(block_sref) && !get_block_names.count(name)) {
-      if (IsOutputBlock(sch->state(), block_sref, GetScopeRoot(sch->state(), block_sref, false))) {
+      StmtSRef scope_root =
+          (name != "root") ? GetScopeRoot(sch->state(), block_sref, false) : block_sref;
+      if (IsOutputBlock(sch->state(), block_sref, scope_root)) {
         last_block_idx = inline_todos.size();
       }
       inline_todos.push_back(name);
diff --git a/src/tir/schedule/utils.h b/src/tir/schedule/utils.h
index a6aced4632fb..df92c7f80738 100644
--- a/src/tir/schedule/utils.h
+++ b/src/tir/schedule/utils.h
@@ -490,6 +490,14 @@ Array<ObjectRef> TranslateInputRVs(const Array<ObjectRef>& inputs,
 void TranslateAddOutputRVs(const Array<ObjectRef>& old_outputs, const Array<ObjectRef>& new_outputs,
                            std::unordered_map<const Object*, const Object*>* rv_map);
 
+/*!
+ * \brief Count the number of valid instructions in a trace.
+ * \param insts The instructions of the trace.
+ * \param remove_postproc Whether to exclude postprocessing instructions from the count.
+ * \return The number of valid instructions.
+ */
+int GetNumValidInstructions(const Array<Instruction>& insts, bool remove_postproc);
+
 }  // namespace tir
 }  // namespace tvm
diff --git a/tests/python/unittest/test_meta_schedule_relay_integration.py b/tests/python/unittest/test_meta_schedule_relay_integration.py
index 90be1ec0a1e9..f1d74348db17 100644
--- a/tests/python/unittest/test_meta_schedule_relay_integration.py
+++ b/tests/python/unittest/test_meta_schedule_relay_integration.py
@@ -724,7 +724,7 @@ def test_module_equality_ignore_ndarray():
     np.testing.assert_allclose(ref, out, rtol=1e-4, atol=1e-4)
 
 
-def _test_anchor_tuning(target):
+def _test_anchor_tuning(target, space):
     data_shape = (128, 128)
     weight_shape1 = (128, 128)
     weight_shape2 = (128, 128)
@@ -756,6 +756,7 @@ def _test_anchor_tuning(target):
             target=target,
             params=params,
             work_dir=work_dir,
+            space=space,
             max_trials_global=4,
             strategy="replay-trace",
             module_equality=module_equality,
@@ -779,8 +780,15 @@ def _test_anchor_tuning(target):
     np.testing.assert_allclose(ref, out, atol=1e-3)
 
 
-def test_anchor_tuning_cpu():
-    _test_anchor_tuning("llvm --num-cores=4")
+@pytest.mark.parametrize(
+    "space",
+    [
+        ms.space_generator.PostOrderApply(),
+        ms.space_generator.PostOrderApply(sch_rules=[], postprocs=[], mutator_probs={}),
+    ],
+)
+def test_anchor_tuning_cpu(space):
+    _test_anchor_tuning("llvm --num-cores=4", space)
 
 
 def test_anchor_tuning_cpu_link_params():
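Usage sketch (not part of the patch): a minimal example of how the stricter get_top_k filtering could be exercised through the Python Database API. The prim_func add_one, the helper make_trace, and the run_secs values below are illustrative assumptions, not code from this change.

# Two records are committed: one with a real measurement and one carrying the
# 1e10 "not measured" stub. With this patch, only the measured record survives
# get_top_k, and a shortfall warning is logged.
import tvm
from tvm import meta_schedule as ms
from tvm.script import tir as T


@T.prim_func
def add_one(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, (64,), "float32")
    B = T.match_buffer(b, (64,), "float32")
    for i in range(64):
        with T.block("B"):
            vi = T.axis.spatial(64, i)
            T.reads(A[vi])
            T.writes(B[vi])
            B[vi] = A[vi] + T.float32(1)


def make_trace(mod):
    # Any trace with at least one real (non-postproc) instruction passes the
    # new validity check on instructions; a single split keeps the sketch short.
    sch = tvm.tir.Schedule(mod)
    (i,) = sch.get_loops(sch.get_block("B"))
    sch.split(i, factors=[None, 8])
    return sch.trace


mod = tvm.IRModule({"main": add_one})
database = ms.database.MemoryDatabase()
workload = database.commit_workload(mod)

database.commit_tuning_record(ms.database.TuningRecord(make_trace(mod), workload, run_secs=[1e-3]))
database.commit_tuning_record(ms.database.TuningRecord(make_trace(mod), workload, run_secs=[1e10]))

# Only the measured record is returned; asking for 2 triggers the new warning.
print(len(database.get_top_k(workload, top_k=2)))  # expected: 1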