diff --git a/src/meta_schedule/measure_callback/update_cost_model.cc b/src/meta_schedule/measure_callback/update_cost_model.cc index 0563699ba6b9..8a8a43658409 100644 --- a/src/meta_schedule/measure_callback/update_cost_model.cc +++ b/src/meta_schedule/measure_callback/update_cost_model.cc @@ -42,7 +42,8 @@ class UpdateCostModelNode : public MeasureCallbackNode { pruned_candidate.reserve(n); pruned_runner_result.reserve(n); for (int i = 0; i < n; i++) { - if (!builder_results[i]->error_msg.defined()) { + if (!builder_results[i]->error_msg.defined() && + Sum(runner_results[i]->run_secs.value()) > 0) { pruned_candidate.push_back(measure_candidates[i]); pruned_runner_result.push_back(runner_results[i]); } diff --git a/src/meta_schedule/utils.h b/src/meta_schedule/utils.h index 80264516c4ce..969aa630df39 100644 --- a/src/meta_schedule/utils.h +++ b/src/meta_schedule/utils.h @@ -540,6 +540,19 @@ inline ScheduleRule GetDefaultAutoInline(const std::string& target_name) { throw; } +/*! + * \brief Summarize the run time of the given FloatImm array. + * \param arr The array of FloatImm. + * \return The summary of the values in the given array. + */ +inline double Sum(const Array& arr) { + double sum = 0; + for (const FloatImm& f : arr) { + sum += f->value; + } + return sum; +} + } // namespace meta_schedule } // namespace tvm diff --git a/tests/python/unittest/test_meta_schedule_measure_callback.py b/tests/python/unittest/test_meta_schedule_measure_callback.py index 20596e8e8c4d..c3fbbbe97231 100644 --- a/tests/python/unittest/test_meta_schedule_measure_callback.py +++ b/tests/python/unittest/test_meta_schedule_measure_callback.py @@ -16,12 +16,12 @@ # under the License. # pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring import re +import tempfile from typing import List import pytest import tvm from tvm import meta_schedule as ms -from tvm.meta_schedule.testing.dummy_object import DummyBuilder, DummyRunner from tvm.script import tir as T from tvm.tir.schedule import Schedule @@ -123,7 +123,33 @@ def apply( assert pattern.match(str(measure_callback)) +def test_meta_schedule_measure_callback_update_cost_model_with_zero(): + @ms.derived_object + class AllZeroRunnerFuture(ms.runner.PyRunnerFuture): + def done(self) -> bool: + return True + + def result(self) -> ms.runner.RunnerResult: + return ms.runner.RunnerResult([0.0, 0.0], None) + + @ms.derived_object + class AllZeroRunner(ms.runner.PyRunner): + def run(self, runner_inputs: List[ms.runner.RunnerInput]) -> List[ms.runner.RunnerResult]: + return [AllZeroRunnerFuture() for _ in runner_inputs] + + with tempfile.TemporaryDirectory() as work_dir: + ms.tune_tir( + mod=Matmul, + target="llvm -num-cores=1", + work_dir=work_dir, + max_trials_global=10, + runner=AllZeroRunner(), + measure_callbacks=[ms.measure_callback.UpdateCostModel()], + ) + + if __name__ == "__main__": test_meta_schedule_measure_callback() test_meta_schedule_measure_callback_fail() test_meta_schedule_measure_callback_as_string() + test_meta_schedule_measure_callback_update_cost_model_with_zero()