Skip to content

Commit

Permalink
[Ansor] Parallel the InitPopulation (#6529)
Browse files Browse the repository at this point in the history
  • Loading branch information
jcf94 authored Sep 23, 2020
1 parent cc96117 commit 56b18ec
Show file tree
Hide file tree
Showing 6 changed files with 151 additions and 64 deletions.
53 changes: 35 additions & 18 deletions src/auto_scheduler/search_policy/sketch_policy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include "sketch_policy.h"

#include <tvm/runtime/registry.h>
#include <tvm/support/parallel_for.h>

#include <algorithm>
#include <iomanip>
Expand Down Expand Up @@ -334,28 +335,44 @@ Array<State> SketchPolicyNode::GenerateSketches() {
Array<State> SketchPolicyNode::SampleInitPopulation(const Array<State>& sketches, int out_size) {
int fail_ct = 0;
Array<State> out_states;
std::vector<std::mt19937> rand_gens;
rand_gens.reserve(out_size);
for (int i = 0; i < out_size; i++) {
rand_gens.push_back(std::mt19937(rand_gen()));
}
auto tic_begin = std::chrono::high_resolution_clock::now();

while (static_cast<int>(out_states.size()) < out_size && fail_ct < out_size) {
// Randomly choose a starting sketch
// TODO(jcf94, merrymercy): Maybe choose sketches with different probabilities, since they may
// have different potential for generating states with better performance
State tmp_s = sketches[(rand_gen)() % sketches.size()];

// Derivation rule based enumeration
bool valid = true;
for (const auto& rule : init_rules) {
if (rule->Apply(this, &tmp_s) == PopulationGenerationRule::ResultKind::kInvalid) {
valid = false;
break;
std::vector<State> temp_states(out_size);

support::parallel_for(0, out_size - out_states.size(),
[this, &temp_states, &sketches, &rand_gens](int index) {
// Randomly choose a starting sketch
// TODO(jcf94, merrymercy): Maybe choose sketches with different
// probabilities, since they may have different potential for generating
// states with better performance
State tmp_s = sketches[(rand_gens[index])() % sketches.size()];
// Derivation rule based enumeration
bool valid = true;
for (const auto& rule : init_rules) {
if (rule->Apply(this, &tmp_s, &rand_gens[index]) ==
PopulationGenerationRule::ResultKind::kInvalid) {
valid = false;
break;
}
}
if (valid) {
temp_states[index] = std::move(tmp_s);
}
});

for (int i = 0; i < out_size; i++) {
if (temp_states[i].defined()) {
out_states.push_back(std::move(temp_states[i]));
} else {
fail_ct++;
}
}

if (valid) {
out_states.push_back(std::move(tmp_s));
} else {
fail_ct++;
}
}

double duration = std::chrono::duration_cast<std::chrono::duration<double>>(
Expand Down Expand Up @@ -461,7 +478,7 @@ Array<State> SketchPolicyNode::EvolutionarySearch(const Array<State>& init_popul

if (dis(rand_gen) < mutation_prob) {
const auto& rule = mutation_rules[RandomChoose(rule_selection_probs, &rand_gen)];
if (rule->Apply(this, &tmp_s) == PopulationGenerationRule::ResultKind::kValid) {
if (rule->Apply(this, &tmp_s, &rand_gen) == PopulationGenerationRule::ResultKind::kValid) {
pnext->push_back(std::move(tmp_s));
mutation_success_ct++;
} else {
Expand Down
66 changes: 34 additions & 32 deletions src/auto_scheduler/search_policy/sketch_policy_rules.cc
Original file line number Diff line number Diff line change
Expand Up @@ -439,8 +439,8 @@ std::vector<std::pair<State, int>> RuleSpecialComputeLocationGPU::Apply(

/********** Init Population **********/

PopulationGenerationRule::ResultKind InitFillTileSize::Apply(SketchPolicyNode* policy,
State* state) const {
PopulationGenerationRule::ResultKind InitFillTileSize::Apply(SketchPolicyNode* policy, State* state,
std::mt19937* rand_gen) const {
StateNode* pstate = state->CopyOnWrite();
// Scan the transformation history and randomly fill tile sizes for all SplitSteps
for (size_t step_id = 0; step_id < (*state)->transform_steps.size(); ++step_id) {
Expand All @@ -461,7 +461,7 @@ PopulationGenerationRule::ResultKind InitFillTileSize::Apply(SketchPolicyNode* p
const auto& candidate_lens = policy->split_memo.GetFactorizationSchemes(
extent, ps->lengths.size(),
GetIntParam(policy->params, SketchParamKey::max_innermost_split_factor));
const auto& candidate_lengths = candidate_lens[(policy->rand_gen)() % candidate_lens.size()];
const auto& candidate_lengths = candidate_lens[(*rand_gen)() % candidate_lens.size()];

pstate->transform_steps.Set(
step_id,
Expand All @@ -475,8 +475,8 @@ PopulationGenerationRule::ResultKind InitFillTileSize::Apply(SketchPolicyNode* p
return ResultKind::kValid;
}

PopulationGenerationRule::ResultKind InitChangeComputeLocation::Apply(SketchPolicyNode* policy,
State* state) const {
PopulationGenerationRule::ResultKind InitChangeComputeLocation::Apply(
SketchPolicyNode* policy, State* state, std::mt19937* rand_gen) const {
if (GetIntParam(policy->params, SketchParamKey::disable_change_compute_location)) {
return PopulationGenerationRule::ResultKind::kValid;
}
Expand All @@ -495,7 +495,7 @@ PopulationGenerationRule::ResultKind InitChangeComputeLocation::Apply(SketchPoli
std::vector<std::pair<int, int>> candidates =
GetComputeLocationCandidates(policy->search_task, *state, stage_id);

int choice = (policy->rand_gen)() % (candidates.size() + 2);
int choice = (*rand_gen)() % (candidates.size() + 2);

if (choice == 0) {
if (!HasReduceIter(stage)) {
Expand All @@ -518,8 +518,8 @@ PopulationGenerationRule::ResultKind InitChangeComputeLocation::Apply(SketchPoli
return PopulationGenerationRule::ResultKind::kValid;
}

PopulationGenerationRule::ResultKind InitParallel::Apply(SketchPolicyNode* policy,
State* state) const {
PopulationGenerationRule::ResultKind InitParallel::Apply(SketchPolicyNode* policy, State* state,
std::mt19937* rand_gen) const {
std::function<void(const SketchPolicyNode&, State*, int stage_id, int iter_offset)>
annotate_parallel;
annotate_parallel = [&annotate_parallel](const SketchPolicyNode& policy, State* state,
Expand Down Expand Up @@ -583,8 +583,8 @@ PopulationGenerationRule::ResultKind InitParallel::Apply(SketchPolicyNode* polic
return ResultKind::kValid;
}

PopulationGenerationRule::ResultKind InitUnroll::Apply(SketchPolicyNode* policy,
State* state) const {
PopulationGenerationRule::ResultKind InitUnroll::Apply(SketchPolicyNode* policy, State* state,
std::mt19937* rand_gen) const {
std::vector<int>& auto_unroll_configs =
IsGPUTask(policy->search_task) ? auto_unroll_configs_gpu : auto_unroll_configs_cpu;
for (size_t stage_id = 0; stage_id < (*state)->stages.size(); ++stage_id) {
Expand Down Expand Up @@ -625,7 +625,7 @@ PopulationGenerationRule::ResultKind InitUnroll::Apply(SketchPolicyNode* policy,

if (HasReduceIter(stage)) {
// Use auto unroll for multi level tiled stage
int value = auto_unroll_configs[(policy->rand_gen)() % auto_unroll_configs.size()];
int value = auto_unroll_configs[(*rand_gen)() % auto_unroll_configs.size()];
state->pragma(stage_id, (*state)->stages[stage_id]->iters[0],
std::string("auto_unroll_max_step") + "$" + std::to_string(value));
}
Expand All @@ -635,7 +635,8 @@ PopulationGenerationRule::ResultKind InitUnroll::Apply(SketchPolicyNode* policy,
}

PopulationGenerationRule::ResultKind InitVectorization::Apply(SketchPolicyNode* policy,
State* state) const {
State* state,
std::mt19937* rand_gen) const {
for (size_t stage_id = 0; stage_id < (*state)->stages.size(); ++stage_id) {
const Stage& stage = (*state)->stages[stage_id];
// Skip the inlined stage and placeholder stage
Expand Down Expand Up @@ -679,7 +680,7 @@ PopulationGenerationRule::ResultKind InitVectorization::Apply(SketchPolicyNode*

if (num_fusible > 1) {
// Select a random range to fuse
num_fusible = 1 + (policy->rand_gen)() % (num_fusible - 1);
num_fusible = 1 + (*rand_gen)() % (num_fusible - 1);
}

if (num_fusible == 1) {
Expand All @@ -693,8 +694,8 @@ PopulationGenerationRule::ResultKind InitVectorization::Apply(SketchPolicyNode*
return ResultKind::kValid;
}

PopulationGenerationRule::ResultKind InitThreadBind::Apply(SketchPolicyNode* policy,
State* state) const {
PopulationGenerationRule::ResultKind InitThreadBind::Apply(SketchPolicyNode* policy, State* state,
std::mt19937* rand_gen) const {
std::set<int> multi_level_tiling_root_set;
for (size_t stage_id = 0; stage_id < (*state)->stages.size(); ++stage_id) {
if (NeedsMultilevelTiling(policy->search_task, *state, stage_id)) {
Expand Down Expand Up @@ -847,8 +848,8 @@ PopulationGenerationRule::ResultKind InitThreadBind::Apply(SketchPolicyNode* pol
return ResultKind::kValid;
}

PopulationGenerationRule::ResultKind MutateTileSize::Apply(SketchPolicyNode* policy,
State* state) const {
PopulationGenerationRule::ResultKind MutateTileSize::Apply(SketchPolicyNode* policy, State* state,
std::mt19937* rand_gen) const {
int max_innermost_split_factor =
GetIntParam(policy->params, SketchParamKey::max_innermost_split_factor);

Expand Down Expand Up @@ -877,7 +878,7 @@ PopulationGenerationRule::ResultKind MutateTileSize::Apply(SketchPolicyNode* pol
const SplitStepNode* ps;

do {
step_id = split_step_ids[(policy->rand_gen)() % split_step_ids.size()];
step_id = split_step_ids[(*rand_gen)() % split_step_ids.size()];
ps = (*state)->transform_steps[step_id].as<SplitStepNode>();
CHECK(ps != nullptr);
extent = GetIntImm(ps->extent.value());
Expand All @@ -898,7 +899,7 @@ PopulationGenerationRule::ResultKind MutateTileSize::Apply(SketchPolicyNode* pol

// Randomly permute the tile size order.
std::vector<int> random_perm;
RandomPermutation(lengths.size(), &random_perm, &(policy->rand_gen));
RandomPermutation(lengths.size(), &random_perm, rand_gen);

// Try to divide a factor out of one tile size and multiply it into another.
for (size_t i = 0; i < random_perm.size(); ++i) {
Expand Down Expand Up @@ -926,9 +927,9 @@ PopulationGenerationRule::ResultKind MutateTileSize::Apply(SketchPolicyNode* pol
// Failed on this dst_idx, try next one.
continue;
}
divide_factor = factors[1 + (policy->rand_gen)() % (max_factor_index)];
divide_factor = factors[1 + (*rand_gen)() % (max_factor_index)];
} else {
divide_factor = factors[1 + (policy->rand_gen)() % (factors.size() - 1)];
divide_factor = factors[1 + (*rand_gen)() % (factors.size() - 1)];
}

// Divide one factor from lengths[src_idx] and multiply it to lengths[dst_idx].
Expand All @@ -955,8 +956,8 @@ PopulationGenerationRule::ResultKind MutateTileSize::Apply(SketchPolicyNode* pol
return ResultKind::kInvalid;
}

PopulationGenerationRule::ResultKind MutateAutoUnroll::Apply(SketchPolicyNode* policy,
State* state) const {
PopulationGenerationRule::ResultKind MutateAutoUnroll::Apply(SketchPolicyNode* policy, State* state,
std::mt19937* rand_gen) const {
// Extract all auto_unroll_max_step pragma steps.
std::vector<int> pragma_steps;
for (size_t i = 0; i < (*state)->transform_steps.size(); ++i) {
Expand All @@ -974,20 +975,21 @@ PopulationGenerationRule::ResultKind MutateAutoUnroll::Apply(SketchPolicyNode* p
IsGPUTask(policy->search_task) ? auto_unroll_configs_gpu : auto_unroll_configs_cpu;

// Randomly pick an auto unroll pragma step
auto step_id = pragma_steps[(policy->rand_gen)() % pragma_steps.size()];
auto step_id = pragma_steps[(*rand_gen)() % pragma_steps.size()];
auto ps = (*state)->transform_steps[step_id].as<PragmaStepNode>();
CHECK(ps);

// Mutate its value to a random candidate
auto val = std::to_string(auto_unroll_configs[(policy->rand_gen)() % auto_unroll_configs.size()]);
auto val = std::to_string(auto_unroll_configs[(*rand_gen)() % auto_unroll_configs.size()]);
StateNode* pstate = state->CopyOnWrite();
pstate->transform_steps.Set(step_id, PragmaStep(ps->stage_id, ps->iter_id,
std::string("auto_unroll_max_step") + "$" + val));
return ResultKind::kValid;
}

PopulationGenerationRule::ResultKind MutateComputeLocation::Apply(SketchPolicyNode* policy,
State* state) const {
State* state,
std::mt19937* rand_gen) const {
if (GetIntParam(policy->params, SketchParamKey::disable_change_compute_location)) {
return PopulationGenerationRule::ResultKind::kInvalid;
}
Expand All @@ -1013,7 +1015,7 @@ PopulationGenerationRule::ResultKind MutateComputeLocation::Apply(SketchPolicyNo
}

// Randomly pick one step
size_t step_id = compute_at_steps[(policy->rand_gen)() % compute_at_steps.size()];
size_t step_id = compute_at_steps[(*rand_gen)() % compute_at_steps.size()];
auto ps = (*state)->transform_steps[step_id].as<ComputeAtStepNode>();
int stage_inc = GetTargetStageIDInState(*state, step_id) - ps->stage_id;
CHECK(ps != nullptr);
Expand All @@ -1025,7 +1027,7 @@ PopulationGenerationRule::ResultKind MutateComputeLocation::Apply(SketchPolicyNo
return PopulationGenerationRule::ResultKind::kInvalid;
}

int choice = (policy->rand_gen)() % (candidates.size());
int choice = (*rand_gen)() % (candidates.size());
int new_compute_at_stage_id = candidates[choice].first;
int new_compute_at_iter_id = candidates[choice].second;

Expand All @@ -1049,8 +1051,8 @@ PopulationGenerationRule::ResultKind MutateComputeLocation::Apply(SketchPolicyNo
return PopulationGenerationRule::ResultKind::kValid;
}

PopulationGenerationRule::ResultKind MutateParallel::Apply(SketchPolicyNode* policy,
State* state) const {
PopulationGenerationRule::ResultKind MutateParallel::Apply(SketchPolicyNode* policy, State* state,
std::mt19937* rand_gen) const {
// This mutation rule only focuses on a case that parallel was added to
// the outermost loop and the loop is generated by fusing other loops.
// In short, we mutate the fusion step before the parallel step.
Expand All @@ -1074,7 +1076,7 @@ PopulationGenerationRule::ResultKind MutateParallel::Apply(SketchPolicyNode* pol
}

// Randomly pick one parallel step.
size_t step_id = parallel_steps[(policy->rand_gen)() % parallel_steps.size()];
size_t step_id = parallel_steps[(*rand_gen)() % parallel_steps.size()];
auto ps = (*state)->transform_steps[step_id].as<AnnotationStepNode>();
CHECK(ps);
size_t stage_id = ps->stage_id;
Expand Down Expand Up @@ -1113,7 +1115,7 @@ PopulationGenerationRule::ResultKind MutateParallel::Apply(SketchPolicyNode* pol

// Mutate the fusion iters and replay the mutated fused/annotation steps.
int iter_offset = 0;
if (RandomChoose(fuse_dir, &(policy->rand_gen)) == 0) {
if (RandomChoose(fuse_dir, rand_gen) == 0) {
fused_ids.pop_back();
iter_offset = 1;
} else {
Expand Down
21 changes: 11 additions & 10 deletions src/auto_scheduler/search_policy/sketch_policy_rules.h
Original file line number Diff line number Diff line change
Expand Up @@ -137,17 +137,18 @@ class PopulationGenerationRule {
* \param state The state to apply this rule to; it is updated in place.
* \param rand_gen The random number generator used by this rule.
* \return The result of this rule, indicating whether a valid state was generated.
*/
virtual ResultKind Apply(SketchPolicyNode* policy, State* state) const = 0;
virtual ResultKind Apply(SketchPolicyNode* policy, State* state,
std::mt19937* rand_gen) const = 0;

/*! \brief The destructor */
virtual ~PopulationGenerationRule() = default;
};

// A helper to define population initialization rules
#define DEFINE_INIT_POPULATION_RULE(rule_name) \
class rule_name : public PopulationGenerationRule { \
public: \
ResultKind Apply(SketchPolicyNode* policy, State* state) const final; \
#define DEFINE_INIT_POPULATION_RULE(rule_name) \
class rule_name : public PopulationGenerationRule { \
public: \
ResultKind Apply(SketchPolicyNode* policy, State* state, std::mt19937* rand_gen) const final; \
};

/*! \brief The rule that fills the incomplete SplitSteps. */
Expand Down Expand Up @@ -185,11 +186,11 @@ class PopulationMutationRule : public PopulationGenerationRule {
};

// A helper to define mutation rules used in the evolutionary search
#define DEFINE_MUTATE_POPULATION_RULE(rule_name) \
class rule_name : public PopulationMutationRule { \
public: \
explicit rule_name(double weight) : PopulationMutationRule(weight) {} \
ResultKind Apply(SketchPolicyNode* policy, State* state) const final; \
#define DEFINE_MUTATE_POPULATION_RULE(rule_name) \
class rule_name : public PopulationMutationRule { \
public: \
explicit rule_name(double weight) : PopulationMutationRule(weight) {} \
ResultKind Apply(SketchPolicyNode* policy, State* state, std::mt19937* rand_gen) const final; \
};

/*! \brief The rule that mutates tile size by randomly dividing a tile size by a factor
Expand Down
Loading

0 comments on commit 56b18ec

Please sign in to comment.