Skip to content
This repository has been archived by the owner on Nov 25, 2022. It is now read-only.

Commit

Permalink
[MetaSchedule] Fix Task Hanging in EvolutionarySearch (apache#13246)
Browse files Browse the repository at this point in the history
This PR introduces a new argument for EvolutionarySearch that limits the number of failures (defined as rounds in which no new candidate is generated) in the `SampleInitPopulation` stage. This prevents the task from hanging forever in special cases, e.g., when some postproc always fails. This should fix apache#12330.
  • Loading branch information
zxybazh authored and xinetzone committed Nov 25, 2022
1 parent 17bd1ae commit 58c19bc
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 1 deletion.
2 changes: 2 additions & 0 deletions include/tvm/meta_schedule/search_strategy.h
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,7 @@ class SearchStrategy : public runtime::ObjectRef {
* \param population_size The initial sample population.
* \param init_measured_ratio The ratio of measured samples in the initial population.
* \param init_min_unmeasured The minimal size of unmeasured population in the initial sampling.
* \param max_fail_count The maximum number of failures during initial sampling.
* \param genetic_num_iters The iterations to run the genetic algorithm.
* \param genetic_mutate_prob The probability of mutation.
* \param genetic_max_fail_count The maximum number to try evolving the given trace.
Expand All @@ -208,6 +209,7 @@ class SearchStrategy : public runtime::ObjectRef {
TVM_DLL static SearchStrategy EvolutionarySearch(int population_size, //
double init_measured_ratio, //
int init_min_unmeasured, //
int max_fail_count, //
int genetic_num_iters, //
double genetic_mutate_prob, //
int genetic_max_fail_count, //
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ class EvolutionarySearch(SearchStrategy):
The ratio of measured samples in the initial population.
init_min_unmeasured : int
The minimal size of unmeasured population in the initial sampling.
max_fail_count : int
The maximum number of failures during initial sampling.
genetic_num_iters : int
The number of iterations for genetic algorithm.
genetic_mutate_prob : float
Expand All @@ -59,6 +61,7 @@ def __init__(
population_size: int = 2048,
init_measured_ratio: float = 0.2,
init_min_unmeasured: int = 50,
max_fail_count: int = 5,
genetic_num_iters: int = 4,
genetic_mutate_prob: float = 0.85,
genetic_max_fail_count: int = 10,
Expand All @@ -70,6 +73,7 @@ def __init__(
population_size,
init_measured_ratio,
init_min_unmeasured,
max_fail_count,
genetic_num_iters,
genetic_mutate_prob,
genetic_max_fail_count,
Expand Down
18 changes: 17 additions & 1 deletion src/meta_schedule/search_strategy/evolutionary_search.cc
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,8 @@ class EvolutionarySearchNode : public SearchStrategyNode {
double init_measured_ratio;
/*! \brief The minimal size of unmeasured population in the initial sampling.*/
int init_min_unmeasured;
/*! \brief The maximum number of failures during initial sampling. */
int max_fail_count;
/*** Configuration: evolution ***/
/*! \brief The number of iterations performed by the genetic algorithm. */
int genetic_num_iters;
Expand All @@ -387,6 +389,7 @@ class EvolutionarySearchNode : public SearchStrategyNode {
/*** Configuration: the initial population ***/
v->Visit("init_measured_ratio", &init_measured_ratio);
v->Visit("init_min_unmeasured", &init_min_unmeasured);
v->Visit("max_fail_count", &max_fail_count);
/*** Configuration: evolution ***/
v->Visit("genetic_num_iters", &genetic_num_iters);
v->Visit("genetic_mutate_prob", &genetic_mutate_prob);
Expand Down Expand Up @@ -456,6 +459,7 @@ class EvolutionarySearchNode : public SearchStrategyNode {
n->num_empty_iters_before_early_stop = this->num_empty_iters_before_early_stop;
n->init_measured_ratio = this->init_measured_ratio;
n->init_min_unmeasured = this->init_min_unmeasured;
n->max_fail_count = this->max_fail_count;
n->genetic_num_iters = this->genetic_num_iters;
n->genetic_mutate_prob = this->genetic_mutate_prob;
n->genetic_max_fail_count = this->genetic_max_fail_count;
Expand Down Expand Up @@ -501,7 +505,9 @@ std::vector<Schedule> EvolutionarySearchNode::State::SampleInitPopulation(int nu
auto _ = Profiler::TimedScope("EvoSearch/SampleInitPopulation");
ThreadedTraceApply pp(self->postprocs_);
std::vector<Schedule> out_schs;
while (static_cast<int>(out_schs.size()) < self->init_min_unmeasured) {
int fail_count = 0;
while (static_cast<int>(out_schs.size()) < self->init_min_unmeasured &&
fail_count < self->max_fail_count) {
std::vector<Schedule> results(num, Schedule{nullptr});
auto f_proc_unmeasured = [this, &results, &pp](int thread_id, int trace_id) -> void {
PerThreadData& data = this->per_thread_data_.at(thread_id);
Expand All @@ -516,11 +522,14 @@ std::vector<Schedule> EvolutionarySearchNode::State::SampleInitPopulation(int nu
}
};
support::parallel_for_dynamic(0, num, self->ctx_->num_threads, f_proc_unmeasured);
bool found_new = false;
for (int i = 0; i < num; i++) {
if (results[i].defined()) {
found_new = true;
out_schs.push_back(results[i]);
}
}
fail_count += !found_new;
TVM_PY_LOG(INFO, self->ctx_->logger) << "Sample-Init-Population summary:\n"
<< pp.SummarizeFailures();
}
Expand Down Expand Up @@ -706,6 +715,11 @@ Optional<Array<MeasureCandidate>> EvolutionarySearchNode::State::GenerateMeasure
TVM_PY_LOG(INFO, self->ctx_->logger)
<< "Picked top " << measured.size() << " candidate(s) from database";
std::vector<Schedule> unmeasured = SampleInitPopulation(pop - measured.size());
if (static_cast<int>(unmeasured.size()) < self->init_min_unmeasured) {
TVM_PY_LOG(WARNING, self->ctx_->logger)
<< "Cannot sample enough initial population, evolutionary search failed.";
return NullOpt;
}
TVM_PY_LOG(INFO, self->ctx_->logger) << "Sampled " << unmeasured.size() << " candidate(s)";
inits.insert(inits.end(), measured.begin(), measured.end());
inits.insert(inits.end(), unmeasured.begin(), unmeasured.end());
Expand Down Expand Up @@ -737,6 +751,7 @@ size_t EvolutionarySearchNode::State::ModuleHash(const IRModule& mod) const {
SearchStrategy SearchStrategy::EvolutionarySearch(int population_size, //
double init_measured_ratio, //
int init_min_unmeasured, //
int max_fail_count, //
int genetic_num_iters, //
double genetic_mutate_prob, //
int genetic_max_fail_count, //
Expand All @@ -749,6 +764,7 @@ SearchStrategy SearchStrategy::EvolutionarySearch(int population_size, /
n->num_empty_iters_before_early_stop = 5;
n->init_measured_ratio = init_measured_ratio;
n->init_min_unmeasured = init_min_unmeasured;
n->max_fail_count = max_fail_count;
n->genetic_num_iters = genetic_num_iters;
n->genetic_max_fail_count = genetic_max_fail_count;
n->genetic_mutate_prob = genetic_mutate_prob;
Expand Down
56 changes: 56 additions & 0 deletions tests/python/unittest/test_meta_schedule_search_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import tvm
import tvm.testing
from tvm import meta_schedule as ms
from tvm.meta_schedule.utils import derived_object
from tvm.meta_schedule.testing.dummy_object import DummyMutator
from tvm.script import tir as T
from tvm.tir.schedule import Schedule, Trace
Expand Down Expand Up @@ -251,8 +252,63 @@ def _schedule_matmul_empty(sch: Schedule):
assert num_trials_each_iter == [1, 0, 0, 0, 0]


def test_meta_schedule_evolutionary_search_fail_init_population():  # pylint: disable = invalid-name
    """Check that evolutionary search yields no candidates when every sample is rejected.

    The design space is configured with a single postprocessor whose ``apply``
    always returns ``False``, so no sampled schedule can pass postprocessing.
    The strategy is then asked for measure candidates once, and the test
    expects ``None`` back.
    """

    @derived_object
    class AlwaysFailPostproc(ms.postproc.PyPostproc):
        """Postprocessor that rejects every schedule it is given."""

        def _initialize_with_tune_context(self, context: ms.TuneContext) -> None:
            pass

        def apply(self, sch: Schedule) -> bool:
            return False

        def clone(self) -> "AlwaysFailPostproc":
            return AlwaysFailPostproc()

        def __str__(self) -> str:
            return "AlwaysFailPostproc"

    # Design space whose only postprocessor never accepts a schedule.
    failing_space = ms.space_generator.ScheduleFn(
        sch_fn=_schedule_matmul,
        sch_rules=[],
        postprocs=[AlwaysFailPostproc()],
        mutator_probs={DummyMutator(): 1.0},
    )
    # `max_fail_count` is intentionally left at the constructor's default.
    evo_search = ms.search_strategy.EvolutionarySearch(
        population_size=5,
        init_measured_ratio=0.1,
        init_min_unmeasured=50,
        genetic_num_iters=3,
        genetic_mutate_prob=0.5,
        genetic_max_fail_count=10,
        eps_greedy=0.9,
    )
    context = ms.TuneContext(
        mod=Matmul,
        space_generator=failing_space,
        search_strategy=evo_search,
        target=tvm.target.Target("llvm"),
        # Single-threaded because the mutator is implemented on the Python side.
        num_threads=1,
    )

    strategy = context.search_strategy
    design_spaces = context.space_generator.generate_design_space(context.mod)
    strategy.pre_tuning(
        max_trials=2000,
        num_trials_per_iter=10,
        design_spaces=design_spaces,
        database=ms.database.MemoryDatabase(),
        cost_model=ms.cost_model.RandomModel(),
    )

    # With every sample rejected, the strategy is expected to report no candidates.
    assert strategy.generate_measure_candidates() is None


if __name__ == "__main__":
    # Run each test directly, in the same order as they are listed in this file.
    for replay_strategy in (ms.search_strategy.ReplayFunc, ms.search_strategy.ReplayTrace):
        test_meta_schedule_replay_func(replay_strategy)
    test_meta_schedule_evolutionary_search()
    test_meta_schedule_evolutionary_search_early_stop()
    test_meta_schedule_evolutionary_search_fail_init_population()

0 comments on commit 58c19bc

Please sign in to comment.