diff --git a/docs/api/python/auto_scheduler.rst b/docs/api/python/auto_scheduler.rst index 85ff22f58b37..a7c190ab0ffa 100644 --- a/docs/api/python/auto_scheduler.rst +++ b/docs/api/python/auto_scheduler.rst @@ -31,5 +31,20 @@ tvm.auto_scheduler.auto_schedule .. autofunction:: tvm.auto_scheduler.auto_schedule.auto_schedule +tvm.auto_scheduler.workload_registry +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. autofunction:: tvm.auto_scheduler.workload_registry.register_workload + +tvm.auto_scheduler.measure +~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. automodule:: tvm.auto_scheduler.measure + +.. autoclass:: tvm.auto_scheduler.measure.LocalRPCMeasureContext + +.. autoclass:: tvm.auto_scheduler.measure.LocalRunner + +.. autoclass:: tvm.auto_scheduler.measure.LocalBuilder + +.. autoclass:: tvm.auto_scheduler.measure.RPCRunner diff --git a/include/tvm/auto_scheduler/search_policy.h b/include/tvm/auto_scheduler/search_policy.h index 176b10c1d7ea..ddb0dd284875 100755 --- a/include/tvm/auto_scheduler/search_policy.h +++ b/include/tvm/auto_scheduler/search_policy.h @@ -65,6 +65,7 @@ #include #include +#include #include #include @@ -191,7 +192,7 @@ class SearchPolicyNode : public Object { * We store the string format of a state for redundancy check. This is used to make sure a * measured state will never be measured again. */ - std::unordered_set measured_states_set_; + std::unordered_set measured_states_set_; /*! \brief The array of already measured states. * The good states can be used as the initial population in evolutionary search. */ std::vector measured_states_vector_; diff --git a/python/tvm/auto_scheduler/measure.py b/python/tvm/auto_scheduler/measure.py index c57b39bb1d10..eebccf4b9d6e 100644 --- a/python/tvm/auto_scheduler/measure.py +++ b/python/tvm/auto_scheduler/measure.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -""" +r""" Distributed measurement infrastructure to measure the runtime costs of tensor programs. These functions are responsible for building the tvm module, uploading it to @@ -25,8 +25,8 @@ A builder builds the executable binary files and a runner runs the binary files to get the measurement results. The flow of data structures is - `ProgramBuilder` `ProgramRunner` -`MeasureInput` -----------------> `BuildResult` ----------------> `MeasureResult` + . `ProgramBuilder` `ProgramRunner` + `MeasureInput` -----------------> `BuildResult` ----------------> `MeasureResult` We implement these in python to utilize python's multiprocessing and error handling. """ @@ -222,7 +222,7 @@ class LocalRunner(ProgramRunner): where the first "1" is warm up and will be discarded. The returned result contains `repeat` costs, each of which is an average of `number` costs. - min_repeat_ms : int = 0 + min_repeat_ms : int = 100 The minimum duration of one `repeat` in milliseconds. By default, one `repeat` contains `number` runs. If this parameter is set, the parameters `number` will be dynamically adjusted to meet the @@ -244,7 +244,7 @@ def __init__( timeout=10, number=3, repeat=1, - min_repeat_ms=0, + min_repeat_ms=100, cooldown_interval=0.0, enable_cpu_cache_flush=False, ): @@ -289,7 +289,7 @@ class RPCRunner(ProgramRunner): where the first "1" is warm up and will be discarded. The returned result contains `repeat` costs, each of which is an average of `number` costs. - min_repeat_ms : int = 0 + min_repeat_ms : int = 100 The minimum duration of one `repeat` in milliseconds. By default, one `repeat` contains `number` runs. If this parameter is set, the parameters `number` will be dynamically adjusted to meet the @@ -316,7 +316,7 @@ def __init__( timeout=10, number=3, repeat=1, - min_repeat_ms=0, + min_repeat_ms=100, cooldown_interval=0.0, enable_cpu_cache_flush=False, ): diff --git a/python/tvm/auto_scheduler/search_policy.py b/python/tvm/auto_scheduler/search_policy.py index a9d323622277..6e278aed021a 100644 --- a/python/tvm/auto_scheduler/search_policy.py +++ b/python/tvm/auto_scheduler/search_policy.py @@ -91,7 +91,7 @@ class SketchPolicy(SearchPolicy): ---------- task : SearchTask The SearchTask for the computation declaration. - schedule_cost_model : CostModel = RandomModel() + program_cost_model : CostModel = RandomModel() The cost model to estimate the complete schedules. params : Optional[Dict[str, Any]] Parameters of the search policy. @@ -129,7 +129,7 @@ class SketchPolicy(SearchPolicy): def __init__( self, task, - schedule_cost_model=RandomModel(), + program_cost_model=RandomModel(), params=None, seed=None, verbose=1, @@ -145,7 +145,7 @@ def __init__( self.__init_handle_by_constructor__( _ffi_api.SketchPolicy, task, - schedule_cost_model, + program_cost_model, params, seed or random.randint(1, 1 << 30), verbose, diff --git a/python/tvm/auto_scheduler/workload_registry.py b/python/tvm/auto_scheduler/workload_registry.py index f0c839800d70..1d9ee6da4f7a 100644 --- a/python/tvm/auto_scheduler/workload_registry.py +++ b/python/tvm/auto_scheduler/workload_registry.py @@ -55,13 +55,15 @@ def register_workload(func_name, f=None, override=False): Examples -------- - @auto_scheduler.register_workload - def matmul(N, M, K): - A = te.placeholder((N, K), name='A') - B = te.placeholder((K, M), name='B') - k = te.reduce_axis((0, K), name='k') - C = te.compute((N, M), lambda i, j: tvm.sum(A[i][k] * B[k][j], axis=[k]), name='C') - return [A, B, C] + .. code-block:: python + + @auto_scheduler.register_workload + def matmul(N, M, K): + A = te.placeholder((N, K), name='A') + B = te.placeholder((K, M), name='B') + k = te.reduce_axis((0, K), name='k') + C = te.compute((N, M), lambda i, j: tvm.sum(A[i][k] * B[k][j], axis=[k]), name='C') + return [A, B, C] """ global WORKLOAD_FUNC_REGISTRY diff --git a/python/tvm/micro/session.py b/python/tvm/micro/session.py index 084f46716e15..3565040e1d76 100644 --- a/python/tvm/micro/session.py +++ b/python/tvm/micro/session.py @@ -22,10 +22,14 @@ from .._ffi import get_global_func from ..contrib import graph_runtime -from .base import _rpc_connect from ..rpc import RPCSession from .transport import TransportLogger +try: + from .base import _rpc_connect +except ImportError: + raise ImportError("micro tvm is not enabled. Set USE_MICRO to ON in config.cmake") + class Session: """MicroTVM Device Session diff --git a/src/auto_scheduler/search_policy/sketch_policy.cc b/src/auto_scheduler/search_policy/sketch_policy.cc index ffc00941143c..6b4b6ae120bd 100644 --- a/src/auto_scheduler/search_policy/sketch_policy.cc +++ b/src/auto_scheduler/search_policy/sketch_policy.cc @@ -31,6 +31,7 @@ #include #include #include +#include #include #include #include @@ -45,7 +46,6 @@ namespace tvm { namespace auto_scheduler { /********** Sketch generation rules **********/ - static RuleSkipStage rule_skip_stage; static RuleAlwaysInline rule_always_inline; static RuleMultiLevelTiling rule_multi_level_tiling; @@ -58,7 +58,6 @@ static RuleSimplifyComputeWithConstTensor rule_simplify_compute_with_const_tenso static RuleSpecialComputeLocationGPU rule_special_compute_location_gpu; /********** Init population rules **********/ - static InitFillTileSize init_fill_tile_size; static InitChangeComputeLocation init_change_compute_location; static InitParallel init_parallel; @@ -66,23 +65,15 @@ static InitUnroll init_unroll; static InitVectorization init_vectorization; static InitThreadBind init_thread_bind; -/********** Mutation rules **********/ - -static MutateTileSize mutate_tile_size; -static MutateMaxUnrollFactor mutate_max_unroll_factor; -static MutateComputeLocation mutate_compute_location; -static MutateParallel mutate_parallel; - /********** Sketch policy **********/ - TVM_REGISTER_NODE_TYPE(SketchPolicyNode); -SketchPolicy::SketchPolicy(SearchTask task, CostModel schedule_cost_model, +SketchPolicy::SketchPolicy(SearchTask task, CostModel program_cost_model, Map params, int seed, int verbose, Optional> init_search_callbacks) { auto node = make_object(); node->search_task = std::move(task); - node->schedule_cost_model = std::move(schedule_cost_model); + node->program_cost_model = std::move(program_cost_model); node->rand_gen = std::mt19937(seed); node->params = std::move(params); node->verbose = verbose; @@ -97,18 +88,32 @@ SketchPolicy::SketchPolicy(SearchTask task, CostModel schedule_cost_model, node->RunCallbacks(init_search_callbacks.value()); } - // Notice: Some rules require us to skip all the rest rules after they are applied. - // So the rules below should be ordered carefully. + // NOTE: There are strong dependency among the rules below, + // so the order to push them into the vector should be considered carefully. if (IsCPUTask(node->search_task)) { - // The default sketch rules for CPU policy + // Sketch Generation Rules node->sketch_rules.push_back(&rule_always_inline); node->sketch_rules.push_back(&rule_simplify_compute_with_const_tensor); node->sketch_rules.push_back(&rule_add_rfactor); node->sketch_rules.push_back(&rule_add_cache_write_stage); node->sketch_rules.push_back(&rule_multi_level_tiling_with_fusion); node->sketch_rules.push_back(&rule_multi_level_tiling); - } else if (IsCUDATask(node->search_task)) { - // The default sketch rules for CUDA policy + node->sketch_rules.push_back(&rule_skip_stage); + + // Initial Population Generation Rules + node->init_rules.push_back(&init_fill_tile_size); + node->init_rules.push_back(&init_change_compute_location); + node->init_rules.push_back(&init_parallel); + node->init_rules.push_back(&init_unroll); + node->init_rules.push_back(&init_vectorization); + + // Mutation Rules for Evolutionary Search + node->mutation_rules.push_back(std::make_shared(0.90)); + node->mutation_rules.push_back(std::make_shared(0.04)); + node->mutation_rules.push_back(std::make_shared(0.05)); + node->mutation_rules.push_back(std::make_shared(0.01)); + } else if (IsGPUTask(node->search_task)) { + // Sketch Generation Rules node->sketch_rules.push_back(&rule_add_cache_read_stage); node->sketch_rules.push_back(&rule_always_inline); node->sketch_rules.push_back(&rule_special_compute_location_gpu); @@ -117,32 +122,20 @@ SketchPolicy::SketchPolicy(SearchTask task, CostModel schedule_cost_model, node->sketch_rules.push_back(&rule_add_cache_write_stage); node->sketch_rules.push_back(&rule_multi_level_tiling_with_fusion); node->sketch_rules.push_back(&rule_multi_level_tiling); - } else { - LOG(FATAL) << "No default sketch rules for target: " << task->target; - } - node->sketch_rules.push_back(&rule_skip_stage); // This should always be the last rule + node->sketch_rules.push_back(&rule_skip_stage); - node->init_rules.push_back(&init_fill_tile_size); // This should always be the first rule - if (IsCPUTask(node->search_task)) { - // The default init population rules for CPU policy - node->init_rules.push_back(&init_change_compute_location); - node->init_rules.push_back(&init_parallel); - node->init_rules.push_back(&init_unroll); - node->init_rules.push_back(&init_vectorization); - } else if (IsCUDATask(node->search_task)) { - // The default init population rules for CUDA policy + // Initial Population Generation Rules + node->init_rules.push_back(&init_fill_tile_size); node->init_rules.push_back(&init_thread_bind); node->init_rules.push_back(&init_unroll); + + // Mutation Rules for Evolutionary Search + node->mutation_rules.push_back(std::make_shared(0.90)); + node->mutation_rules.push_back(std::make_shared(0.10)); } else { - LOG(FATAL) << "No default init rules for target: " << task->target; + LOG(FATAL) << "No default sketch rules for target: " << task->target; } - // The default mutation rules. - node->mutation_rules.push_back(&mutate_tile_size); - node->mutation_rules.push_back(&mutate_max_unroll_factor); - node->mutation_rules.push_back(&mutate_compute_location); - node->mutation_rules.push_back(&mutate_parallel); - data_ = std::move(node); } @@ -169,7 +162,7 @@ State SketchPolicyNode::Search(int n_trials, int early_stopping, int num_measure if (!inputs.empty()) { // Retrain cost models before the next search round PrintTitle("Train cost model", verbose); - schedule_cost_model->Update(inputs, results); + program_cost_model->Update(inputs, results); } // Search one round to get promising states @@ -179,9 +172,7 @@ State SketchPolicyNode::Search(int n_trials, int early_stopping, int num_measure // Infer bound. This is necessary for computing the correct ToStr() for redundancy check best_states = search_task->compute_dag.InferBound(best_states); - PruneInvalidState(search_task, &best_states); random_states = search_task->compute_dag.InferBound(random_states); - PruneInvalidState(search_task, &random_states); // Pick `num_measure_per_iter` states to measure, check hash to remove already measured state // Also pick some random states to do eps-greedy @@ -242,14 +233,16 @@ Array SketchPolicyNode::SearchOneRound(int num_random_states, Array( GetDoubleParam(params, SketchParamKey::EvolutionarySearch::use_measured_ratio) * population)); - bool is_cost_model_reasonable = !schedule_cost_model->IsInstance(); + bool is_cost_model_reasonable = !program_cost_model->IsInstance(); // 1. Generate sketches - const Array& sketches = GenerateSketches(); + if (sketch_cache_.empty()) { + sketch_cache_ = GenerateSketches(); + } // 2. Sample the init population Array init_population = SampleInitPopulation( - sketches, is_cost_model_reasonable ? population - num_use_measured : population); + sketch_cache_, is_cost_model_reasonable ? population - num_use_measured : population); // 3. If the cost model is useless (i.e. RandomCostModel), just random pick some generated // states, else perform evolutionary search @@ -260,7 +253,7 @@ Array SketchPolicyNode::SearchOneRound(int num_random_states, Array SketchPolicyNode::GenerateSketches() { // A map that maps state to its current working position (stage_id) std::unordered_map cur_stage_id_map; - cur_stage_id_map[init_state] = static_cast(init_state->stages.size() - 1); + cur_stage_id_map[init_state] = static_cast(init_state->stages.size()) - 1; // Derivation rule based enumeration Array out_states; @@ -379,7 +372,7 @@ Array SketchPolicyNode::EvolutionarySearch(const Array& init_popul Array best_states; auto tic_begin = std::chrono::high_resolution_clock::now(); - size_t population = init_population.size(); + size_t population = GetIntParam(params, SketchParamKey::EvolutionarySearch::population); int num_iters = GetIntParam(params, SketchParamKey::EvolutionarySearch::num_iters); double mutation_prob = GetDoubleParam(params, SketchParamKey::EvolutionarySearch::mutation_prob); @@ -390,135 +383,102 @@ Array SketchPolicyNode::EvolutionarySearch(const Array& init_popul Array* pnow = &states_buf1; Array* pnext = &states_buf2; - // The set of explored states to avoid redundancy. - std::unordered_set explored_set; - - // The heap to maintain the so far best states. + // A heap to keep the best states during evolution using StateHeapItem = std::pair; auto cmp = [](const StateHeapItem& left, const StateHeapItem& right) { return left.second > right.second; }; - using StateHeap = std::priority_queue, decltype(cmp)>; - StateHeap heap(cmp); - auto update_heap = [&heap, &explored_set](const Array& states, - const std::vector& scores, const int out_size) { - float max_score = 0.0; - for (size_t i = 0; i < states.size(); ++i) { - const State& state = states[i]; + std::vector heap; + std::unordered_set in_heap(measured_states_set_); + heap.reserve(out_size); + + // auxiliary global variables + std::vector pop_scores; + std::vector pop_selection_probs; + float max_score = 0.0; + pop_scores.reserve(population); + pop_selection_probs.reserve(population); + std::uniform_real_distribution<> dis(0.0, 1.0); + + // mutation rules + int mutation_success_ct, mutation_fail_ct; + mutation_success_ct = mutation_fail_ct = 0; + std::vector rule_weights; + std::vector rule_selection_probs; + for (const auto& rule : mutation_rules) { + rule_weights.push_back(rule->weight); + } + ComputePrefixSumProb(rule_weights, &rule_selection_probs); + + // Genetic Algorithm + for (int k = 0; k < num_iters + 1; ++k) { + // Maintain the heap + *pnow = search_task->compute_dag.InferBound(*pnow); + PruneInvalidState(search_task, pnow); + program_cost_model->Predict(search_task, *pnow, &pop_scores); + + for (size_t i = 0; i < pnow->size(); ++i) { + const State& state = (*pnow)[i]; std::string state_str = state.ToStr(); - // Skip redundant states. - if (explored_set.count(state_str) > 0) { - continue; - } - explored_set.insert(state_str); - - if (static_cast(heap.size()) < out_size) { - // Directly push item if the heap is not full yet. - heap.push({state, scores[i]}); - } else if (scores[i] > heap.top().second) { - // Replace the worst state in the heap with the new state. - heap.pop(); - heap.push({state, scores[i]}); + if (in_heap.count(state_str) == 0) { + if (static_cast(heap.size()) < out_size) { + heap.emplace_back((*pnow)[i], pop_scores[i]); + std::push_heap(heap.begin(), heap.end(), cmp); + in_heap.insert(state_str); + } else if (pop_scores[i] > heap.front().second) { + std::string old_state_str = heap.front().first.ToStr(); + in_heap.erase(old_state_str); + in_heap.insert(state_str); + + std::pop_heap(heap.begin(), heap.end(), cmp); + heap.back() = StateHeapItem(state, pop_scores[i]); + std::push_heap(heap.begin(), heap.end(), cmp); + } + if (pop_scores[i] > max_score) { + max_score = pop_scores[i]; + } } - max_score = (scores[i] > max_score) ? scores[i] : max_score; } - return max_score; - }; - // Cost model predicted scores. - std::vector scores; - scores.reserve(population); - - // The function to generate prefix sum probabilities based on the given scores. - auto assign_prob = [](const std::vector& scores, std::vector* prefix_sum_probs) { - // Compute selection probabilities. - double sum = 0.0; - prefix_sum_probs->resize(scores.size()); - for (size_t i = 0; i < scores.size(); ++i) { - sum += std::max(scores[i], 0.0f); - (*prefix_sum_probs)[i] = sum; + // Print statistical information + if (k % 5 == 0 || k == num_iters) { + StdCout(verbose) << "GA Iter: " << k << std::fixed << std::setprecision(4) + << "\tMax score: " << max_score << "\tMin score: " << heap.front().second + << "\t#Pop: " << pnow->size() << "\t#M+: " << mutation_success_ct / (k + 1) + << "\t#M-: " << mutation_fail_ct / (k + 1) << std::endl; } - for (size_t i = 0; i < scores.size(); ++i) { - (*prefix_sum_probs)[i] /= sum; + if (k == num_iters) { + break; } - }; - // State selection probabilities. - std::uniform_real_distribution<> uniform_dist(0.0, 1.0); - std::vector state_select_probs; - state_select_probs.reserve(population); + // Compute selection probability + ComputePrefixSumProb(pop_scores, &pop_selection_probs); - // Mutation rule selection probabilities. - std::vector rule_select_probs; - rule_select_probs.reserve(mutation_rules.size()); - std::vector rule_levels; - for (const auto& rule : mutation_rules) { - rule_levels.push_back(rule->GetLevel(search_task)); - } - assign_prob(rule_levels, &rule_select_probs); - - // Evaluate the init populations. - *pnow = search_task->compute_dag.InferBound(*pnow); - PruneInvalidState(search_task, pnow); - CHECK_GT(pnow->size(), 0) << "All initial populations are invalid"; - schedule_cost_model->Predict(search_task, *pnow, &scores); - - // Maintain the best states in the heap. - float max_score = update_heap(*pnow, scores, out_size); - - // Genetic algorithm. - for (auto iter_idx = 1; iter_idx <= num_iters; ++iter_idx) { - // Assign the selection probability to each state based on the cost model scores. - assign_prob(scores, &state_select_probs); - - // TODO(@comaniac): Perform cross over. - - // Perform mutations. - size_t fail_ct = 0; - while (pnext->size() < population && fail_ct < population * 2) { - // Select a state to be mutated. - State tmp_s = (*pnow)[RandomChoose(state_select_probs, &rand_gen)]; - if (uniform_dist(rand_gen) < mutation_prob) { - // Select a rule and mutate the state. - const auto& rule = mutation_rules[RandomChoose(rule_select_probs, &rand_gen)]; + // Do mutation + while (pnext->size() < population) { + State tmp_s = (*pnow)[RandomChoose(pop_selection_probs, &rand_gen)]; + + if (dis(rand_gen) < mutation_prob) { + const auto& rule = mutation_rules[RandomChoose(rule_selection_probs, &rand_gen)]; if (rule->Apply(this, &tmp_s) == PopulationGenerationRule::ResultKind::kValid) { pnext->push_back(std::move(tmp_s)); + mutation_success_ct++; } else { - fail_ct++; + mutation_fail_ct++; } } else { - // Do not mutate this state in this round. pnext->push_back(std::move(tmp_s)); } } - // Evaluate the new populations. - *pnext = search_task->compute_dag.InferBound(*pnext); - PruneInvalidState(search_task, pnext); - - // Throw away all states generated in this iterations if all new states are invalid. - if (pnext->size() > 0) { - std::swap(pnext, pnow); - schedule_cost_model->Predict(search_task, *pnow, &scores); - - // Maintain the best states in the heap. - float iter_max_score = update_heap(*pnow, scores, out_size); - max_score = (iter_max_score > max_score) ? iter_max_score : max_score; - } + std::swap(pnext, pnow); pnext->clear(); - - if (iter_idx % 5 == 0 || iter_idx == num_iters) { - StdCout(verbose) << "GA Iter: " << iter_idx << std::fixed << std::setprecision(4) - << "\tMax Score: " << max_score << "\tPop Size: " << pnow->size() - << std::endl; - } } - // Copy best states in the heap to the output. - while (!heap.empty()) { - auto item = heap.top(); - heap.pop(); + // Copy best states in the heap to out_states + std::sort(heap.begin(), heap.end(), cmp); + for (auto& item : heap) { best_states.push_back(std::move(item.first)); } @@ -580,10 +540,10 @@ Array SketchPolicyNode::PickStatesWithEpsGreedy(const Array } TVM_REGISTER_GLOBAL("auto_scheduler.SketchPolicy") - .set_body_typed([](SearchTask task, CostModel schedule_cost_model, - Map params, int seed, int verbose, + .set_body_typed([](SearchTask task, CostModel program_cost_model, Map params, + int seed, int verbose, Optional> init_search_callbacks) { - return SketchPolicy(task, schedule_cost_model, params, seed, verbose, init_search_callbacks); + return SketchPolicy(task, program_cost_model, params, seed, verbose, init_search_callbacks); }); TVM_REGISTER_GLOBAL("auto_scheduler.SketchPolicyGenerateSketches") diff --git a/src/auto_scheduler/search_policy/sketch_policy.h b/src/auto_scheduler/search_policy/sketch_policy.h index 2d93d8775c86..21aaa6ef7b90 100644 --- a/src/auto_scheduler/search_policy/sketch_policy.h +++ b/src/auto_scheduler/search_policy/sketch_policy.h @@ -34,6 +34,7 @@ #include #include +#include #include #include #include @@ -88,15 +89,15 @@ struct SketchParamKey { class SketchPolicyNode : public SearchPolicyNode { public: /*! \brief The cost model to estimate the complete schedules. */ - CostModel schedule_cost_model; + CostModel program_cost_model; /*! \brief The parameters map for this search policy. */ Map params; /*! \brief The rules to generate sketches. */ std::vector sketch_rules; - /*! \brief The rules to generate initial states. */ + /*! \brief The rules to generate initial population. */ std::vector init_rules; - /*! \brief The rules to mutate states. */ - std::vector mutation_rules; + /*! \brief The rules to mutate states in the evolutionary search. */ + std::vector> mutation_rules; /*! \brief Random generator. */ std::mt19937 rand_gen; /*! \brief Memorize split space for Split. */ @@ -154,6 +155,9 @@ class SketchPolicyNode : public SearchPolicyNode { /*! \brief The number of states to measure per iteration. */ int num_measure_per_iter_; + + /*! \brief The cached sketches */ + Array sketch_cache_; }; /*! @@ -165,14 +169,14 @@ class SketchPolicy : public SearchPolicy { /*! * \brief The constructor. * \param task The SearchTask for the computation declaration. - * \param schedule_cost_model The cost model for complete programs. + * \param program_cost_model The cost model for complete programs. * \param params The parameters map for this search process. * \param seed The random seed of this search process. * \param verbose Verbose level. 0 for silent, 1 to output information during schedule * search. * \param init_search_callbacks SearchCallback to be called before schedule search. */ - SketchPolicy(SearchTask task, CostModel schedule_cost_model, Map params, + SketchPolicy(SearchTask task, CostModel program_cost_model, Map params, int seed, int verbose, Optional> init_search_callbacks); TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(SketchPolicy, SearchPolicy, SketchPolicyNode); diff --git a/src/auto_scheduler/search_policy/sketch_policy_rules.cc b/src/auto_scheduler/search_policy/sketch_policy_rules.cc index dab6e4d65f20..228dda461beb 100644 --- a/src/auto_scheduler/search_policy/sketch_policy_rules.cc +++ b/src/auto_scheduler/search_policy/sketch_policy_rules.cc @@ -34,6 +34,9 @@ namespace tvm { namespace auto_scheduler { +static std::vector auto_unroll_configs_cpu = {0, 16, 64, 512}; +static std::vector auto_unroll_configs_gpu = {0, 16, 64, 512, 1024}; + /********** Sketch Generation Rule **********/ /********** RuleSkipStage **********/ @@ -472,9 +475,8 @@ PopulationGenerationRule::ResultKind InitFillTileSize::Apply(SketchPolicyNode* p return ResultKind::kValid; } -PopulationGenerationRule::ResultKind MutateComputeLocationCommon(SketchPolicyNode* policy, - State* state, - bool infer_bound = true) { +PopulationGenerationRule::ResultKind InitChangeComputeLocation::Apply(SketchPolicyNode* policy, + State* state) const { if (GetIntParam(policy->params, SketchParamKey::disable_change_compute_location)) { return PopulationGenerationRule::ResultKind::kValid; } @@ -490,81 +492,8 @@ PopulationGenerationRule::ResultKind MutateComputeLocationCommon(SketchPolicyNod continue; } - int target_stage_id = GetSingleConsumerId(policy->search_task, *state, stage_id); - if (target_stage_id < 0) { - continue; - } - const Stage& target_stage = (*state)->stages[target_stage_id]; - - std::vector> candidates; - bool target_compute_at_other = target_stage->compute_at == ComputeAtKind::kIter; - bool target_is_tiled = IsTiled(target_stage); - - bool visited_reduce = false; - // enumerate compute_at location at target_stage - // TODO(merrymercy): More analysis here to make smarter choices - for (size_t i = 0; i < target_stage->iters.size(); ++i) { - const Iterator& target_iter = target_stage->iters[i]; - if (target_iter->iter_kind == IteratorKind::kReduction) { - visited_reduce = true; - if (!target_is_tiled) { // Do not go into reduce iter - break; - } - } else if (target_iter->iter_kind == IteratorKind::kSpatial) { - if (visited_reduce) { // Do not go into inner tile - break; - } - } - - if (target_iter->annotation == IteratorAnnotation::kUnroll) { - // Do not go into the unroll region of const tensor indices - break; - } - - if (GetExtent(target_iter) == 1) { - // Skip iterators with length of 1 - continue; - } - if (target_compute_at_other && target_iter->iter_kind == IteratorKind::kSpatial && - StrEndsWith(target_iter->name, ".0")) { - // Skip the first level iterators if target stage compute_at another stage - // In this case, the lengths of first level iterators are always one - continue; - } - candidates.emplace_back(target_stage_id, i); - - if ((*state)->attach_map->iter_to_attached_stages.count(std::make_pair(target_stage_id, i))) { - break; - } - } - - // if the target_stage is already compute_at another stage X, try also compute_at X - // We call stage X as `target_target_stage` - if (target_compute_at_other) { - int target_target_stage_id; - target_target_stage_id = (*state)->attach_map->stage_to_attach_iter.at(target_stage_id).first; - const Stage& target_target_stage = (*state)->stages[target_target_stage_id]; - - for (size_t i = 0; i < target_target_stage->iters.size(); ++i) { - const Iterator& target_target_iter = target_target_stage->iters[i]; - if (target_target_iter->iter_kind == IteratorKind::kReduction || - (*state)->attach_map->iter_to_attached_stages.count( - std::make_pair(target_target_stage_id, i))) { - break; - } - - if (target_target_iter->annotation == IteratorAnnotation::kUnroll) { - // Do not go into the unroll region of const tensor indices - break; - } - - if (GetExtent(target_target_iter) == 1) { // skip iterators with length of 1 - continue; - } - - candidates.emplace_back(target_target_stage_id, i); - } - } + std::vector> candidates = + GetComputeLocationCandidates(policy->search_task, *state, stage_id); int choice = (policy->rand_gen)() % (candidates.size() + 2); @@ -585,17 +514,10 @@ PopulationGenerationRule::ResultKind MutateComputeLocationCommon(SketchPolicyNod } } - if (infer_bound) { - *state = policy->search_task->compute_dag.InferBound(*state); - } + *state = policy->search_task->compute_dag.InferBound(*state); return PopulationGenerationRule::ResultKind::kValid; } -PopulationGenerationRule::ResultKind InitChangeComputeLocation::Apply(SketchPolicyNode* policy, - State* state) const { - return MutateComputeLocationCommon(policy, state, true); -} - PopulationGenerationRule::ResultKind InitParallel::Apply(SketchPolicyNode* policy, State* state) const { std::function @@ -663,9 +585,8 @@ PopulationGenerationRule::ResultKind InitParallel::Apply(SketchPolicyNode* polic PopulationGenerationRule::ResultKind InitUnroll::Apply(SketchPolicyNode* policy, State* state) const { - std::vector auto_unroll_configs = IsGPUTask(policy->search_task) - ? std::vector({0, 16, 64, 512, 1024}) - : std::vector({0, 16, 64, 512}); + std::vector& auto_unroll_configs = + IsGPUTask(policy->search_task) ? auto_unroll_configs_gpu : auto_unroll_configs_cpu; for (size_t stage_id = 0; stage_id < (*state)->stages.size(); ++stage_id) { const Stage& stage = (*state)->stages[stage_id]; // Skip the inlined stage and placeholder stage @@ -801,6 +722,10 @@ PopulationGenerationRule::ResultKind InitThreadBind::Apply(SketchPolicyNode* pol // Deal with the cross-thread reduction generated by RuleCrossThreadReduction if (HasCrossThreadReduction(*state, stage_id)) { + if (stage->compute_at != ComputeAtKind::kRoot) { + continue; + } + Iterator fused_it; *state = std::move(FuseAllOuterSpaceIterators(*state, stage_id, &fused_it)); state->bind(stage_id, fused_it, IteratorAnnotation::kBlockX); @@ -983,6 +908,7 @@ PopulationGenerationRule::ResultKind MutateTileSize::Apply(SketchPolicyNode* pol continue; } + // Divide one factor from lengths[src_idx] and multiply it to lengths[dst_idx] size_t dst_idx = random_perm[(i + 1) % random_perm.size()]; const std::vector& factors = policy->split_memo.GetFactors(length); CHECK_GE(factors.size(), 1); @@ -1017,6 +943,8 @@ PopulationGenerationRule::ResultKind MutateTileSize::Apply(SketchPolicyNode* pol } } + CHECK_LE(GetIntImm(new_lengths.back()), max_innermost_split_factor); + StateNode* pstate = state->CopyOnWrite(); pstate->transform_steps.Set( step_id, SplitStep(ps->stage_id, ps->iter_id, ps->extent, @@ -1027,39 +955,98 @@ PopulationGenerationRule::ResultKind MutateTileSize::Apply(SketchPolicyNode* pol return ResultKind::kInvalid; } -PopulationGenerationRule::ResultKind MutateMaxUnrollFactor::Apply(SketchPolicyNode* policy, - State* state) const { +PopulationGenerationRule::ResultKind MutateAutoUnroll::Apply(SketchPolicyNode* policy, + State* state) const { // Extract all auto_unroll_max_step pragma steps. - std::vector annotate_steps; + std::vector pragma_steps; for (size_t i = 0; i < (*state)->transform_steps.size(); ++i) { if (auto ps = (*state)->transform_steps[i].as()) { if (StrStartsWith(ps->pragma_type, "auto_unroll_max_step")) { - annotate_steps.push_back(i); + pragma_steps.push_back(i); } } } - if (annotate_steps.empty()) { + if (pragma_steps.empty()) { return ResultKind::kInvalid; } - // Random pick up one unroll factor candidate. - auto cands = (IsGPUTask(policy->search_task)) ? &gpu_unroll_cands_ : &cpu_unroll_cands_; - auto new_factor = std::to_string((*cands)[(policy->rand_gen)() % cands->size()]); + std::vector& auto_unroll_configs = + IsGPUTask(policy->search_task) ? auto_unroll_configs_gpu : auto_unroll_configs_cpu; - // Random pick up and mutate an unroll step. - auto step_id = annotate_steps[(policy->rand_gen)() % annotate_steps.size()]; + // Randomly pick up an auto unroll pragma step + auto step_id = pragma_steps[(policy->rand_gen)() % pragma_steps.size()]; auto ps = (*state)->transform_steps[step_id].as(); CHECK(ps); + + // Mutate its value to a random candidates + auto val = std::to_string(auto_unroll_configs[(policy->rand_gen)() % auto_unroll_configs.size()]); StateNode* pstate = state->CopyOnWrite(); - pstate->transform_steps.Set(step_id, - PragmaStep(ps->stage_id, ps->iter_id, - std::string("auto_unroll_max_step") + "$" + new_factor)); + pstate->transform_steps.Set(step_id, PragmaStep(ps->stage_id, ps->iter_id, + std::string("auto_unroll_max_step") + "$" + val)); return ResultKind::kValid; } PopulationGenerationRule::ResultKind MutateComputeLocation::Apply(SketchPolicyNode* policy, State* state) const { - return MutateComputeLocationCommon(policy, state, false); + if (GetIntParam(policy->params, SketchParamKey::disable_change_compute_location)) { + return PopulationGenerationRule::ResultKind::kInvalid; + } + + // Extract all compute_at steps. + std::vector compute_at_steps; + for (size_t s = 0; s < (*state)->transform_steps.size(); ++s) { + if (auto ps = (*state)->transform_steps[s].as()) { + int stage_inc = GetTargetStageIDInState(*state, s) - ps->stage_id; + + if (IsTiled((*state)->stages[ps->stage_id + stage_inc])) { + continue; + } + + if (NeedsMultilevelTiling(policy->search_task, *state, ps->stage_id + stage_inc)) { + continue; + } + compute_at_steps.push_back(s); + } + } + if (compute_at_steps.empty()) { + return PopulationGenerationRule::ResultKind::kInvalid; + } + + // Randomly pick one step + size_t step_id = compute_at_steps[(policy->rand_gen)() % compute_at_steps.size()]; + auto ps = (*state)->transform_steps[step_id].as(); + int stage_inc = GetTargetStageIDInState(*state, step_id) - ps->stage_id; + CHECK(ps != nullptr); + + std::vector> candidates = + GetComputeLocationCandidates(policy->search_task, *state, ps->stage_id + stage_inc); + + if (candidates.empty()) { + return PopulationGenerationRule::ResultKind::kInvalid; + } + + int choice = (policy->rand_gen)() % (candidates.size()); + int new_compute_at_stage_id = candidates[choice].first; + int new_compute_at_iter_id = candidates[choice].second; + + // Replay a new state. + State tmp_s = policy->search_task->compute_dag->init_state; + for (size_t s = 0; s < (*state)->transform_steps.size(); ++s) { + if (s == step_id) { + tmp_s.CopyOnWrite()->transform_steps.push_back( + ComputeAtStep(ps->stage_id, new_compute_at_stage_id - stage_inc, new_compute_at_iter_id)); + } else { + tmp_s.CopyOnWrite()->transform_steps.push_back((*state)->transform_steps[s]); + } + try { + StepApplyToState(tmp_s->transform_steps.back(), &tmp_s, policy->search_task->compute_dag); + } catch (dmlc::Error& e) { + return PopulationGenerationRule::ResultKind::kInvalid; + } + } + + *state = tmp_s; + return PopulationGenerationRule::ResultKind::kValid; } PopulationGenerationRule::ResultKind MutateParallel::Apply(SketchPolicyNode* policy, diff --git a/src/auto_scheduler/search_policy/sketch_policy_rules.h b/src/auto_scheduler/search_policy/sketch_policy_rules.h index 418fbda6a030..4098df23a604 100644 --- a/src/auto_scheduler/search_policy/sketch_policy_rules.h +++ b/src/auto_scheduler/search_policy/sketch_policy_rules.h @@ -124,7 +124,7 @@ DEFINE_SKETCH_GENERATION_RULE(RuleSpecialComputeLocationGPU); /********** Init Population **********/ -/*! \brief The base class for derivation rules used in the initial population. */ +/*! \brief The base class for rules used to annotate the sketches to get the initial population. */ class PopulationGenerationRule { public: /*! \brief Result enumeration of the apply function. */ @@ -138,8 +138,12 @@ class PopulationGenerationRule { * \return The result of this rule, indicate if there's any valid state generated. */ virtual ResultKind Apply(SketchPolicyNode* policy, State* state) const = 0; + + /*! \brief The deconstructor */ + virtual ~PopulationGenerationRule() = default; }; +// A helper to define population initialization rules #define DEFINE_INIT_POPULATION_RULE(rule_name) \ class rule_name : public PopulationGenerationRule { \ public: \ @@ -149,7 +153,7 @@ class PopulationGenerationRule { /*! \brief The rule that fills the incomplete SplitSteps. */ DEFINE_INIT_POPULATION_RULE(InitFillTileSize); -/*! \brief The rule that randomly changes the computation location for some stages, which do not +/*! \brief The rule that randomly changes the computation location for some stages that do not * need tiling and are not strictly inlineable(e.g. data padding). */ DEFINE_INIT_POPULATION_RULE(InitChangeComputeLocation); @@ -170,50 +174,37 @@ DEFINE_INIT_POPULATION_RULE(InitThreadBind); /*! \brief The base class for mutation rules used in the evolutionary search. */ class PopulationMutationRule : public PopulationGenerationRule { public: - /*! - * \brief Get the priority level of this mutation rule. - * \return The priority level of this mutation rule. Higher the better. + /* \brief The constructor + * \param selection_weight the probabiliy of applying this rule is + * proportional to this weight */ - virtual int GetLevel(const SearchTask& task) const = 0; + explicit PopulationMutationRule(double selection_weight) : weight(selection_weight) {} + + /* \brief The weight of this rule */ + double weight; }; -// A helper to define mutation rules with a constant rule level. -#define DEFINE_MUTATE_POPULATION_RULE(rule_name, rule_level) \ - class rule_name : public PopulationMutationRule { \ - public: \ - ResultKind Apply(SketchPolicyNode* policy, State* state) const final; \ - int GetLevel(const SearchTask& task) const final { return rule_level; } \ +// A helper to define mutation rules used in the evolutionary search +#define DEFINE_MUTATE_POPULATION_RULE(rule_name) \ + class rule_name : public PopulationMutationRule { \ + public: \ + explicit rule_name(double weight) : PopulationMutationRule(weight) {} \ + ResultKind Apply(SketchPolicyNode* policy, State* state) const final; \ }; /*! \brief The rule that mutates tile size by randomly dividing a tile size by a factor and multipling it to another tile size. */ -DEFINE_MUTATE_POPULATION_RULE(MutateTileSize, 100); - -/*! \brief The rule that mutates the fusion iterators annotated by parallel. */ -DEFINE_MUTATE_POPULATION_RULE(MutateParallel, 50); +DEFINE_MUTATE_POPULATION_RULE(MutateTileSize); -/*! \brief The rule that mutates the factor of a randomly selected auto max unroll step. */ -class MutateMaxUnrollFactor : public PopulationMutationRule { - public: - ResultKind Apply(SketchPolicyNode* policy, State* state) const final; - int GetLevel(const SearchTask& task) const final { return 10; } - - const std::vector cpu_unroll_cands_ = {0, 16, 64, 512, 1024}; - const std::vector gpu_unroll_cands_ = {0, 16, 64, 512}; -}; +/*! \brief The rule that mutates the number of fused outer iterators annotated by parallel. */ +DEFINE_MUTATE_POPULATION_RULE(MutateParallel); -/*! \brief The rule that randomly changes the computation location for some stages, which do not +/*! \brief The rule that randomly changes the computation location for some stages that do not * need tiling and are not strictly inlineable(e.g. data padding). */ -class MutateComputeLocation : public PopulationMutationRule { - public: - ResultKind Apply(SketchPolicyNode* policy, State* state) const final; - int GetLevel(const SearchTask& task) const final { - if (IsGPUTask(task)) { - return 0; - } - return 5; - } -}; +DEFINE_MUTATE_POPULATION_RULE(MutateComputeLocation); + +/*! \brief The rule that mutates the value of a randomly selected auto unroll pragma step. */ +DEFINE_MUTATE_POPULATION_RULE(MutateAutoUnroll); } // namespace auto_scheduler } // namespace tvm diff --git a/src/auto_scheduler/search_policy/utils.cc b/src/auto_scheduler/search_policy/utils.cc index 62ffce4dc875..744573a44124 100644 --- a/src/auto_scheduler/search_policy/utils.cc +++ b/src/auto_scheduler/search_policy/utils.cc @@ -67,6 +67,87 @@ Array GetSpatialSplitStepIds(const State& s, int stage_id) { return spatial_split_step_ids; } +std::vector> GetComputeLocationCandidates(const SearchTask& task, + const State& state, int stage_id) { + int target_stage_id = GetSingleConsumerId(task, state, stage_id); + if (target_stage_id < 0) { + return {}; + } + const Stage& target_stage = state->stages[target_stage_id]; + + std::vector> candidates; + bool target_compute_at_other = target_stage->compute_at == ComputeAtKind::kIter; + bool target_is_tiled = IsTiled(target_stage); + + bool visited_reduce = false; + // Enumerate compute_at location at target_stage + // TODO(merrymercy): More analysis here to make smarter choices + for (size_t i = 0; i < target_stage->iters.size(); ++i) { + const Iterator& target_iter = target_stage->iters[i]; + if (target_iter->iter_kind == IteratorKind::kReduction) { + visited_reduce = true; + if (!target_is_tiled) { // Do not go into reduce iter + break; + } + } else if (target_iter->iter_kind == IteratorKind::kSpatial) { + if (visited_reduce) { // Do not go into inner tile + break; + } + } + + if (target_iter->annotation == IteratorAnnotation::kUnroll) { + // Do not go into the unroll region of const tensor indices + break; + } + + if (GetExtent(target_iter) == 1) { + // Skip iterators with length of 1 + continue; + } + if (target_compute_at_other && target_iter->iter_kind == IteratorKind::kSpatial && + StrEndsWith(target_iter->name, ".0")) { + // Skip the first level iterators if target stage compute_at another stage + // In this case, the lengths of first level iterators are always one + continue; + } + candidates.emplace_back(target_stage_id, i); + + if (state->attach_map->iter_to_attached_stages.count(std::make_pair(target_stage_id, i))) { + break; + } + } + + // if the target_stage is already compute_at another stage X, try also compute_at X + // We call stage X as `target_target_stage` + if (target_compute_at_other) { + int target_target_stage_id; + target_target_stage_id = state->attach_map->stage_to_attach_iter.at(target_stage_id).first; + const Stage& target_target_stage = state->stages[target_target_stage_id]; + + for (size_t i = 0; i < target_target_stage->iters.size(); ++i) { + const Iterator& target_target_iter = target_target_stage->iters[i]; + if (target_target_iter->iter_kind == IteratorKind::kReduction || + state->attach_map->iter_to_attached_stages.count( + std::make_pair(target_target_stage_id, i))) { + break; + } + + if (target_target_iter->annotation == IteratorAnnotation::kUnroll) { + // Do not go into the unroll region of const tensor indices + break; + } + + if (GetExtent(target_target_iter) == 1) { // skip iterators with length of 1 + continue; + } + + candidates.emplace_back(target_target_stage_id, i); + } + } + + return candidates; +} + State DoMultiLevelTiling(const State& state, int stage_id, const std::string& format, std::vector* spatial_split_step_ids) { // Temporal object to be used if the input pointer is nullptr @@ -327,7 +408,7 @@ void PruneInvalidState(const SearchTask& task, Array* states) { } if (pt == 0) { - LOG(INFO) << "All states are invalid."; + LOG(FATAL) << "Internal error: All states are invalid."; } else { states->resize(pt); } diff --git a/src/auto_scheduler/search_policy/utils.h b/src/auto_scheduler/search_policy/utils.h index d2ba1289a5b5..75bf0d048c11 100644 --- a/src/auto_scheduler/search_policy/utils.h +++ b/src/auto_scheduler/search_policy/utils.h @@ -372,7 +372,8 @@ inline bool HasSingleElementwiseMatchedConsumer(const SearchTask& task, const St *target_stage_id = *consumers.begin(); if (ElementwiseMatch(task, state, stage_id, *target_stage_id) && (!(HasReduceIter(state->stages[stage_id]) && - HasReduceIter(state->stages[*target_stage_id])))) { + HasReduceIter(state->stages[*target_stage_id]))) && + (!StrEndsWith(state->stages[*target_stage_id]->op->name, ".shared"))) { return true; } } @@ -535,6 +536,22 @@ inline Iterator GetLastReduceIteratorInOutermostReduceTile(const Stage& stage) { return stage->iters[0]; } +/*! \brief Get the target stage id of a history step in the new state. + * We need this because the stage_id in the history may be stale due to later steps */ +inline int GetTargetStageIDInState(const State& s, int step_id) { + int stage_inc = 0; + + for (size_t i = step_id + 1; i < s->transform_steps.size(); ++i) { + if (s->transform_steps[i]->IsInstance() || + s->transform_steps[i]->IsInstance() || + s->transform_steps[i]->IsInstance()) { + if (s->transform_steps[i]->stage_id <= s->transform_steps[step_id]->stage_id + stage_inc) + stage_inc++; + } + } + return s->transform_steps[step_id]->stage_id + stage_inc; +} + /*! \brief Get all split steps for one stage. */ inline void GetSplitStepIds(const State& s, int stage_id, std::vector* split_step_ids) { for (int i = static_cast(s->transform_steps.size()) - 1; i >= 0; --i) { @@ -615,6 +632,32 @@ inline Array RandomSampleStates(const Array& in_states, std::mt199 return out_states; } +/*! \brief Compute prefix-sum probabiilty based on the given weights */ +inline void ComputePrefixSumProb(const std::vector& weights, + std::vector* prefix_sum_probs) { + // Compute selection probabilities. + float sum = 0.0; + prefix_sum_probs->resize(weights.size()); + for (size_t i = 0; i < weights.size(); ++i) { + sum += std::max(weights[i], 0.0f); + (*prefix_sum_probs)[i] = sum; + } + for (size_t i = 0; i < weights.size(); ++i) { + (*prefix_sum_probs)[i] /= sum; + } +} + +/*! \brief Random choose an index according to a prefix sum probability. */ +inline int RandomChoose(const std::vector& prefix_sum_probs, std::mt19937* random_gen) { + std::uniform_real_distribution<> dis(0.0, 1.0); + double x = dis(*random_gen); + + CHECK(!prefix_sum_probs.empty()); + + return std::lower_bound(prefix_sum_probs.begin(), prefix_sum_probs.end(), x) - + prefix_sum_probs.begin(); +} + /*! \brief Print a title */ inline void PrintTitle(const std::string& title, int verbose) { StdCout(verbose) << Chars('-', 60) << "\n" @@ -648,6 +691,10 @@ class SplitFactorizationMemo { /*! \brief Get the indexes of SplitStep that processes on spatial iterator. */ Array GetSpatialSplitStepIds(const State& s, int stage_id); +/*! \brief Get the possible compute locations for a stage. */ +std::vector> GetComputeLocationCandidates(const SearchTask& task, + const State& state, int stage_id); + // Apply multi-level tiling structure according to a string format, // where "S" stands a space level, "R" stands for a reduction level. // For example, if the format is "SSRSRS", then we will @@ -662,17 +709,6 @@ State DoMultiLevelTiling(const State& state, int stage_id, const std::string& fo State FollowTiling(const State& state, int stage_id, const std::vector& split_step_ids, int n_split); -// Random choose an index according to a prefix sum probability. -inline int RandomChoose(const std::vector& prefix_sum_probs, std::mt19937* random_gen) { - std::uniform_real_distribution<> dis(0.0, 1.0); - double x = dis(*random_gen); - - CHECK(!prefix_sum_probs.empty()); - - return std::lower_bound(prefix_sum_probs.begin(), prefix_sum_probs.end(), x) - - prefix_sum_probs.begin(); -} - // Prune invalid states and return the results in-place. void PruneInvalidState(const SearchTask& task, Array* states); diff --git a/src/auto_scheduler/transform_step.cc b/src/auto_scheduler/transform_step.cc index cec83bb93515..2a9349739752 100755 --- a/src/auto_scheduler/transform_step.cc +++ b/src/auto_scheduler/transform_step.cc @@ -780,7 +780,9 @@ Array ApplySplitToState(State* state, int stage_id, int iter_id, res = Iterator(name, Range(), it->iter_kind, IteratorAnnotation::kNone); tosplit_min = NullOpt; tosplit_extent = NullOpt; - concrete = false; + if (!l.defined()) { + concrete = false; + } } outs.push_back(std::move(res)); } diff --git a/tests/python/unittest/test_auto_scheduler_common.py b/tests/python/unittest/test_auto_scheduler_common.py index 33e498ecdfd0..eaf328c6303a 100644 --- a/tests/python/unittest/test_auto_scheduler_common.py +++ b/tests/python/unittest/test_auto_scheduler_common.py @@ -40,6 +40,19 @@ def matmul_auto_scheduler_test(N, M, K): return [A, B, C] +@auto_scheduler.register_workload +def double_matmul_auto_scheduler_test(N): + A = te.placeholder((N, N), name="A", dtype="float32") + B = te.placeholder((N, N), name="B", dtype="float32") + C = te.placeholder((N, N), name="C", dtype="float32") + k = te.reduce_axis((0, N), name="k") + D = te.compute((N, N), lambda i, j: te.sum(A[i][k] * B[k][j], axis=[k]), name="D") + k = te.reduce_axis((0, N), name="k") + E = te.compute((N, N), lambda i, j: te.sum(D[i][k] * C[k][j], axis=[k]), name="E") + + return [A, B, C, E] + + # Test for register_workload with different name @auto_scheduler.register_workload("matmul_auto_scheduler_test_rename_1") def matmul_auto_scheduler_test_rename_0(N, M, K): diff --git a/tests/python/unittest/test_auto_scheduler_evolutionary_search.py b/tests/python/unittest/test_auto_scheduler_evolutionary_search.py index eb706b7e6976..bf6efd0bf11d 100644 --- a/tests/python/unittest/test_auto_scheduler_evolutionary_search.py +++ b/tests/python/unittest/test_auto_scheduler_evolutionary_search.py @@ -47,7 +47,7 @@ def test_evo_search(): workload_key = auto_scheduler.make_workload_key(matmul_auto_scheduler_test, (10, 10, 4)) dag = auto_scheduler.ComputeDAG(workload_key) task = auto_scheduler.SearchTask(dag, workload_key, tvm.target.Target("llvm")) - policy = auto_scheduler.SketchPolicy(task, schedule_cost_model=MockCostModel(), verbose=0) + policy = auto_scheduler.SketchPolicy(task, program_cost_model=MockCostModel(), verbose=0) states = policy.sample_initial_population(50) pruned_states = [] for state in states: diff --git a/tests/python/unittest/test_auto_scheduler_search_policy.py b/tests/python/unittest/test_auto_scheduler_search_policy.py index 6ec96a6f544a..04b54b2858cf 100644 --- a/tests/python/unittest/test_auto_scheduler_search_policy.py +++ b/tests/python/unittest/test_auto_scheduler_search_policy.py @@ -57,7 +57,7 @@ def search_common( search_policy = auto_scheduler.EmptyPolicy(task) elif search_policy == "sketch": search_policy = auto_scheduler.SketchPolicy( - task, schedule_cost_model=cost_model, init_search_callbacks=init_search_callbacks + task, program_cost_model=cost_model, init_search_callbacks=init_search_callbacks ) tuning_options = auto_scheduler.TuningOptions( diff --git a/tests/python/unittest/test_auto_scheduler_sketch_generation.py b/tests/python/unittest/test_auto_scheduler_sketch_generation.py index fa67756833bf..5a687daf686a 100644 --- a/tests/python/unittest/test_auto_scheduler_sketch_generation.py +++ b/tests/python/unittest/test_auto_scheduler_sketch_generation.py @@ -25,6 +25,7 @@ from test_auto_scheduler_common import ( matmul_auto_scheduler_test, + double_matmul_auto_scheduler_test, conv2d_nchw_bn_relu_auto_scheduler_test, max_pool2d_auto_scheduler_test, min_nm_auto_scheduler_test, @@ -73,9 +74,9 @@ def assert_has_cross_thread_reduction(state, stage_id): def test_cpu_matmul_sketch(): sketches = generate_sketches(matmul_auto_scheduler_test, (512, 512, 512), "llvm") """ 3 multi-level tiling sketches - 0 - Multi-level tiling - 1 - Multi-level tiling with cache write on position 0 - 2 - Multi-level tiling with cache write on position 1 + No.0 : Multi-level tiling + No.1 : Multi-level tiling with cache write on position 0 + No.2 : Multi-level tiling with cache write on position 1 """ assert len(sketches) == 3 # Sketch 0 @@ -92,11 +93,11 @@ def test_cpu_matmul_sketch(): sketches = generate_sketches(matmul_auto_scheduler_test, (8, 8, 512), "llvm") """ 2 rfactor sketches + 3 multi-level tiling sketches - 0 - Rfactor with factor position 0 - 1 - Rfactor with factor position 1 - 2 - Multi-level tiling - 3 - Multi-level tiling with cache write on position 0 - 4 - Multi-level tiling with cache write on position 1 + No.0 : Rfactor with factor position 0 + No.1 : Rfactor with factor position 1 + No.2 : Multi-level tiling + No.3 : Multi-level tiling with cache write on position 0 + No.4 : Multi-level tiling with cache write on position 1 """ assert len(sketches) == 5 # Sketch 0 @@ -116,15 +117,20 @@ def test_cpu_matmul_sketch(): assert_compute_at_condition(sketches[4].stages[2], "iter") assert sketches[3] != sketches[4] + sketches = generate_sketches(double_matmul_auto_scheduler_test, (512,), "llvm") + """ 3 multi-level tiling sketches for one matmul, so 3 * 3 = 9 sketches in total """ + assert len(sketches) == 9 + assert_is_tiled(sketches[8].stages[5]) + def test_cpu_conv2d_bn_relu_sketch(): sketches = generate_sketches( conv2d_nchw_bn_relu_auto_scheduler_test, (1, 56, 56, 512, 512, 3, 1, 1), "llvm" ) """ 3 multi-level tiling sketches - 0 - Conv2d multi-level tiling with fusion on position 0 - 1 - Conv2d multi-level tiling with fusion on position 1 - 2 - Conv2d multi-level tiling without fusion + No.0 : Conv2d multi-level tiling with fusion on position 0 + No.1 : Conv2d multi-level tiling with fusion on position 1 + No.2 : Conv2d multi-level tiling without fusion """ assert len(sketches) == 3 # Sketch 0 @@ -164,9 +170,9 @@ def test_cpu_max_pool2d_sketch(): def test_cpu_min_sketch(): sketches = generate_sketches(min_nm_auto_scheduler_test, (10, 1024), "llvm") """ 2 rfactor sketches + 1 default sketch - 0 - Rfactor with factor position 0 - 1 - Rfactor with factor position 1 - 2 - Default sketch + No.0 : Rfactor with factor position 0 + No.1 : Rfactor with factor position 1 + No.2 : Default sketch """ assert len(sketches) == 3 # Sketch 0 @@ -209,9 +215,9 @@ def test_cpu_conv2d_winograd_sketch(): conv2d_winograd_nhwc_auto_scheduler_test, (1, 28, 28, 128, 128, 3, 1, 1), "llvm" ) """ 3 multi-level tiling sketches - 0 - Bgemm multi-level tiling - 1 - Bgemm multi-level tiling with cache write on position 0 - 2 - Bgemm multi-level tiling with cache write on position 1 + No.0 : Bgemm multi-level tiling + No.1 : Bgemm multi-level tiling with cache write on position 0 + No.2 : Bgemm multi-level tiling with cache write on position 1 """ assert len(sketches) == 3 # Sketch 0 @@ -277,6 +283,12 @@ def test_cuda_matmul_sketch(): assert_compute_at_condition(sketches[1].stages[4], "iter") assert_is_tiled(sketches[1].stages[5]) + sketches = generate_sketches(double_matmul_auto_scheduler_test, (512,), "cuda") + """ 1 multi-level tiling sketch for one matmul, so 1 x 1 = 1 sketch in total """ + assert len(sketches) == 1 + assert_compute_at_condition(sketches[0].stages[5], "root") + assert_compute_at_condition(sketches[0].stages[6], "iter") + @tvm.testing.requires_cuda def test_cuda_conv2d_bn_relu_sketch(): diff --git a/tutorials/auto_scheduler/tune_conv2d_layer_cuda.py b/tutorials/auto_scheduler/tune_conv2d_layer_cuda.py new file mode 100644 index 000000000000..98e66bbb7baf --- /dev/null +++ b/tutorials/auto_scheduler/tune_conv2d_layer_cuda.py @@ -0,0 +1,191 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +.. _auto-scheduler-conv-gpu: + +Auto-scheduling a convolution layer for GPU +=========================================== +**Author**: `Lianmin Zheng `_, \ + `Chengfan Jia `_ + + +Different from the existing :ref:`autotvm ` which relies on +manual templates to define the search space, the auto-scheduler does not require any templates. +The auto-scheduler is template-free, so users only need to write the computation declaration without +any schedule commands or templates. +The auto-scheduler can automatically generate a large +search space and find a good schedule in the space. + +We use a convolution layer as an example in this tutorial. +""" + +import numpy as np +import tvm +from tvm import te, testing, auto_scheduler, topi +from tvm.topi.testing import conv2d_nchw_python + +###################################################################### +# Define the computation +# ^^^^^^^^^^^^^^^^^^^^^^ +# To begin with, let us define the computation of a convolution layer. +# The function should return the list of input/output tensors. +# From these tensors, the auto-scheduler can get the whole computational graph. + + +@auto_scheduler.register_workload +def conv2d_layer(N, H, W, CO, CI, KH, KW, stride, padding): + data = te.placeholder((N, CI, H, W), name="data") + kernel = te.placeholder((CO, CI, KH, KW), name="kernel") + bias = te.placeholder((1, CO, 1, 1), name="bias") + conv = topi.nn.conv2d_nchw(data, kernel, stride, padding, dilation=1, out_dtype="float32") + out = topi.nn.relu(conv + bias) + return [data, kernel, bias, out] + + +###################################################################### +# Create the search task +# ^^^^^^^^^^^^^^^^^^^^^^ +# We then create a search task for the last convolution layer in the resnet. + +target = tvm.target.Target("cuda") + +# the last layer in resnet +N, H, W, CO, CI, KH, KW, strides, padding = 1, 7, 7, 512, 512, 3, 3, (1, 1), (1, 1) +task = auto_scheduler.create_task(conv2d_layer, (N, H, W, CO, CI, KH, KW, strides, padding), target) + +# Inspect the computational graph +print(task.compute_dag) + +###################################################################### +# Next, we set parameters for the auto-scheduler. These parameters +# mainly specify how we do the measurement during the search and auto-tuning. +# +# * `measure_ctx` launches a different process for measurement. This +# provides an isolation. It can protect the master process from GPU crashes +# happended during measurement and avoid other runtime conflicts. +# * `min_repeat_ms` defines the minimum duration of one "repeat" in every measurement. +# This can warmup the GPU, which is necessary to get accurate measurement results. +# Typically, we recommend a value > 300 ms. +# * `num_measure_trials` is the number of measurement trials we can use during the search. +# We only make 10 trials in this tutorial for a fast demonstration. In practice, 1000 is a +# good value for the search to converge. You can do more trials according to your time budget. +# * In addition, we use `RecordToFile` to dump measurement records into a file `conv2d.json`. +# The measurement records can be used to query the history best, resume the search, +# and do more analyses later. +# * see :any:`auto_scheduler.auto_schedule.TuningOptions`:, +# :any:`auto_scheduler.measure.LocalRPCMeasureContext` for more parameters. + +measure_ctx = auto_scheduler.LocalRPCMeasureContext(min_repeat_ms=300) +tune_option = auto_scheduler.TuningOptions( + num_measure_trials=10, + runner=measure_ctx.runner, + measure_callbacks=[auto_scheduler.RecordToFile("conv2d.json")], +) + +###################################################################### +# Run the search +# ^^^^^^^^^^^^^^ +# Now we get all inputs ready. Pretty simple, isn't it? +# We can kick off the search and let the auto-scheduler do its magic. +# After some measurement trials, it will return the best schedule it found. + +sch, args = auto_scheduler.auto_schedule(task, tuning_options=tune_option) + +###################################################################### +# We can lower the schedule to see the IR after auto-scheduling. +# The auto-scheduler correctly performs optimizations including multi-level tiling, +# cooperative fetching, unrolling and operator fusion. + +print(tvm.lower(sch, args, simple_mode=True)) + +###################################################################### +# Check correctness and evaluate performance +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +# We build the binary and check its correctness and performance. + +func = tvm.build(sch, args, target) + +# check correctness +data_np = np.random.uniform(size=(N, CI, H, W)).astype(np.float32) +weight_np = np.random.uniform(size=(CO, CI, KH, KW)).astype(np.float32) +bias_np = np.random.uniform(size=(1, CO, 1, 1)).astype(np.float32) +conv_np = conv2d_nchw_python(data_np, weight_np, strides, padding) +out_np = np.maximum(conv_np + bias_np, 0.0) + +ctx = tvm.gpu() +data_tvm = tvm.nd.array(data_np, ctx=ctx) +weight_tvm = tvm.nd.array(weight_np, ctx=ctx) +bias_tvm = tvm.nd.array(bias_np, ctx=ctx) +out_tvm = tvm.nd.empty(out_np.shape, ctx=ctx) +func(data_tvm, weight_tvm, bias_tvm, out_tvm) + +# Check results +tvm.testing.assert_allclose(out_np, out_tvm.asnumpy(), rtol=1e-3) + +# Evaluate execution time +evaluator = func.time_evaluator(func.entry_name, ctx, min_repeat_ms=500) +print( + "Execution time of this operator: %.3f ms" + % (np.median(evaluator(data_tvm, weight_tvm, bias_tvm, out_tvm).results) * 1000) +) + +###################################################################### +# Using the record file +# ^^^^^^^^^^^^^^^^^^^^^ +# During the search, all measuremnt records are dumpped into the record +# file "conv2d.json". The measurement records can be used to re-apply search results, +# resume the search, and perform other analyses. + +###################################################################### +# Here is an example where we load the best schedule from a file, +# print the equivalent python schedule API, and build the binary again. + +# Load the measuremnt record for the best schedule +inp, res = auto_scheduler.load_best("conv2d.json", task.workload_key) + +# Print equivalent python schedule API. This can be used for debugging and +# learning the behavior of the auto-scheduler. +print("Equivalent python schedule:") +print(task.compute_dag.print_python_code_from_state(inp.state)) + +# Rebuild the binary. This shows how you can apply the best schedule from a +# log file without reruning the search again. +sch, args = task.compute_dag.apply_steps_from_state(inp.state) +func = tvm.build(sch, args, target) + +###################################################################### +# A more complicated example is to resume the search. +# In this case, we need to create the search policy and cost model by ourselves +# and resume the status of search policy and cost model with the log file. +# In the example below we resume the status and do more 5 trials. + + +log_file = "conv2d.json" +cost_model = auto_scheduler.XGBModel() +cost_model.update_from_file(log_file) +search_policy = auto_scheduler.SketchPolicy( + task, cost_model, init_search_callbacks=[auto_scheduler.PreloadMeasuredStates(log_file)] +) +tune_option = auto_scheduler.TuningOptions( + num_measure_trials=5, + runner=measure_ctx.runner, + measure_callbacks=[auto_scheduler.RecordToFile(log_file)], +) +sch, args = auto_scheduler.auto_schedule(task, search_policy, tuning_options=tune_option) + +# kill the measurement process +del measure_ctx diff --git a/tutorials/auto_scheduler/tune_matmul_x86.py b/tutorials/auto_scheduler/tune_matmul_x86.py index 1a9af42510eb..918030d21e54 100644 --- a/tutorials/auto_scheduler/tune_matmul_x86.py +++ b/tutorials/auto_scheduler/tune_matmul_x86.py @@ -37,7 +37,7 @@ ###################################################################### # Define the computation # ^^^^^^^^^^^^^^^^^^^^^^ -# To begin with, we define the computation of a matmul with bias add. +# To begin with, let us define the computation of a matmul with bias add. # The function should return the list of input/output tensors. # From these tensors, the auto-scheduler can get the whole computational graph. @@ -59,6 +59,9 @@ def matmul_add(N, L, M, dtype): # Create the search task # ^^^^^^^^^^^^^^^^^^^^^^ # We then create a search task with N=L=M=128 and dtype="float32" +# If your machine supports avx instructions, you can +# - replace "llvm" below with "llvm -mcpu=core-avx2" to enable AVX2 +# - replace "llvm" below with "llvm -mcpu=skylake-avx512" to enable AVX-512 target = tvm.target.Target("llvm") task = auto_scheduler.create_task(matmul_add, (128, 128, 128, "float32"), target) @@ -75,7 +78,7 @@ def matmul_add(N, L, M, dtype): # * In addition, we use `RecordToFile` to dump measurement records into a file `matmul.json`. # The measurement records can be used to query the history best, resume the search, # and do more analyses later. -# * see :any:`auto_schedule.TuningOptions`: for more parameters +# * see :any:`auto_scheduler.auto_schedule.TuningOptions`: for more parameters tune_option = auto_scheduler.TuningOptions( num_measure_trials=10, measure_callbacks=[auto_scheduler.RecordToFile("matmul.json")] @@ -93,25 +96,38 @@ def matmul_add(N, L, M, dtype): ###################################################################### # We can lower the schedule to see the IR after auto-scheduling. # The auto-scheduler correctly performs optimizations including multi-level tiling, -# parallelization, vectorization, unrolling and fusion. +# parallelization, vectorization, unrolling and operator fusion. print(tvm.lower(sch, args, simple_mode=True)) ###################################################################### -# Check correctness -# ^^^^^^^^^^^^^^^^^ -# We build the binary and check its correctness +# Check correctness and evaluate performance +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +# We build the binary and check its correctness and performance. func = tvm.build(sch, args) a_np = np.random.uniform(size=(128, 128)).astype(np.float32) b_np = np.random.uniform(size=(128, 128)).astype(np.float32) c_np = np.random.uniform(size=(128, 128)).astype(np.float32) -d_np = a_np.dot(b_np) + c_np - -d_tvm = tvm.nd.empty(d_np.shape) -func(tvm.nd.array(a_np), tvm.nd.array(b_np), tvm.nd.array(c_np), d_tvm) +out_np = a_np.dot(b_np) + c_np + +ctx = tvm.cpu() +a_tvm = tvm.nd.array(a_np, ctx=ctx) +b_tvm = tvm.nd.array(b_np, ctx=ctx) +c_tvm = tvm.nd.array(c_np, ctx=ctx) +out_tvm = tvm.nd.empty(out_np.shape, ctx=ctx) +func(a_tvm, b_tvm, c_tvm, out_tvm) + +# Check results +tvm.testing.assert_allclose(out_np, out_tvm.asnumpy(), rtol=1e-3) + +# Evaluate execution time. +evaluator = func.time_evaluator(func.entry_name, ctx, min_repeat_ms=500) +print( + "Execution time of this operator: %.3f ms" + % (np.median(evaluator(a_tvm, b_tvm, c_tvm, out_tvm).results) * 1000) +) -tvm.testing.assert_allclose(d_np, d_tvm.asnumpy(), rtol=1e-3) ###################################################################### # Using the record file @@ -129,6 +145,7 @@ def matmul_add(N, L, M, dtype): # Print equivalent python schedule API. This can be used for debugging and # learning the behavior of the auto-scheduler. +print("Equivalent python schedule:") print(task.compute_dag.print_python_code_from_state(inp.state)) # Rebuild the binary. This shows how you can apply the best schedule from a @@ -161,13 +178,16 @@ def resume_search(task, log_file): # .. note:: # We cannot run the line above because of the conflict between # python's multiprocessing and tvm's thread pool. -# After running a tvm generated binary (L112), the python's multiprocessing -# library will hang forever. -# You have to make sure that you don't run any tvm generated binaries before -# calling ansor's search. To run the L156 above, you should comment out L112-114. +# After running a tvm generated binary the python's multiprocessing library +# will hang forever. You have to make sure that you don't run any tvm +# generated binaries before calling auot-scheduler's search. +# To run the function above, you should comment out all code in +# "Check correctness and evaluate performance" section. # # You should be careful about this problem in your applications. # There are other workarounds for this problem. # For example, you can start a new thread/process (with the builtin python library # threading or multiprocessing) and run the tvm binaries in the new thread/process. # This provides an isolation and avoids the conflict in the main thread/process. +# You can also use :any:`auto_scheduler.measure.LocalRPCMeasureContext` for auto-scheduler, +# as shown in the GPU tutorial (:ref:`auto-scheduler-conv-gpu`).