Skip to content

Commit

Permalink
[MetaSchedule][Minor] Allow Zero Run Time In Benchmarking Result (#13354
Browse files Browse the repository at this point in the history
)

This PR introduces a check to prevent records with run time of zero into the training data of cost model. This is because when working on microTVM there're cases where the run time of certain successful runs is very tiny, such that it got recorded as zero. In such cases, the runtime of 0 would break XGBoost model because it introduces infinite running speed in GFLOPs. A regression test was also added.
  • Loading branch information
zxybazh authored Nov 11, 2022
1 parent 93fdf83 commit f950b11
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 2 deletions.
3 changes: 2 additions & 1 deletion src/meta_schedule/measure_callback/update_cost_model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ class UpdateCostModelNode : public MeasureCallbackNode {
pruned_candidate.reserve(n);
pruned_runner_result.reserve(n);
for (int i = 0; i < n; i++) {
if (!builder_results[i]->error_msg.defined()) {
if (!builder_results[i]->error_msg.defined() &&
Sum(runner_results[i]->run_secs.value()) > 0) {
pruned_candidate.push_back(measure_candidates[i]);
pruned_runner_result.push_back(runner_results[i]);
}
Expand Down
13 changes: 13 additions & 0 deletions src/meta_schedule/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -540,6 +540,19 @@ inline ScheduleRule GetDefaultAutoInline(const std::string& target_name) {
throw;
}

/*!
* \brief Summarize the run time of the given FloatImm array.
* \param arr The array of FloatImm.
* \return The summary of the values in the given array.
*/
inline double Sum(const Array<FloatImm>& arr) {
double sum = 0;
for (const FloatImm& f : arr) {
sum += f->value;
}
return sum;
}

} // namespace meta_schedule
} // namespace tvm

Expand Down
28 changes: 27 additions & 1 deletion tests/python/unittest/test_meta_schedule_measure_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@
# under the License.
# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring
import re
import tempfile
from typing import List

import pytest
import tvm
from tvm import meta_schedule as ms
from tvm.meta_schedule.testing.dummy_object import DummyBuilder, DummyRunner
from tvm.script import tir as T
from tvm.tir.schedule import Schedule

Expand Down Expand Up @@ -123,7 +123,33 @@ def apply(
assert pattern.match(str(measure_callback))


def test_meta_schedule_measure_callback_update_cost_model_with_zero():
@ms.derived_object
class AllZeroRunnerFuture(ms.runner.PyRunnerFuture):
def done(self) -> bool:
return True

def result(self) -> ms.runner.RunnerResult:
return ms.runner.RunnerResult([0.0, 0.0], None)

@ms.derived_object
class AllZeroRunner(ms.runner.PyRunner):
def run(self, runner_inputs: List[ms.runner.RunnerInput]) -> List[ms.runner.RunnerResult]:
return [AllZeroRunnerFuture() for _ in runner_inputs]

with tempfile.TemporaryDirectory() as work_dir:
ms.tune_tir(
mod=Matmul,
target="llvm -num-cores=1",
work_dir=work_dir,
max_trials_global=10,
runner=AllZeroRunner(),
measure_callbacks=[ms.measure_callback.UpdateCostModel()],
)


if __name__ == "__main__":
test_meta_schedule_measure_callback()
test_meta_schedule_measure_callback_fail()
test_meta_schedule_measure_callback_as_string()
test_meta_schedule_measure_callback_update_cost_model_with_zero()

0 comments on commit f950b11

Please sign in to comment.