diff --git a/include/tvm/auto_scheduler/feature.h b/include/tvm/auto_scheduler/feature.h
new file mode 100644
index 000000000000..504c2b8ca5a0
--- /dev/null
+++ b/include/tvm/auto_scheduler/feature.h
@@ -0,0 +1,122 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file auto_scheduler/feature.h
+ * \brief Feature extraction for the cost model.
+ * We extract one feature vector per BufferStoreNode statement in a TIR Stmt,
+ * so we call this the "per-store" feature.
+ * The cost model also makes a prediction for each BufferStoreNode statement and aggregates
+ * the predictions as the score of the whole TVM IR (Stmt).
+ *
+ * The feature specification is defined by `src/auto_scheduler/feature.cc::FeatureSet`
+ */
+
+#ifndef TVM_AUTO_SCHEDULER_FEATURE_H_
+#define TVM_AUTO_SCHEDULER_FEATURE_H_
+
+#include <tvm/auto_scheduler/compute_dag.h>
+#include <tvm/auto_scheduler/measure.h>
+
+#include <string>
+#include <vector>
+
+namespace tvm {
+namespace auto_scheduler {
+
+/*!
+ * \brief Get per-store features from a TIR Stmt
+ * \param stmt The input lowered TIR statement
+ * \param cache_line_size The size of a cache line in bytes
+ * \param max_n_bufs The maximum number of extracted buffers for one statement
+ * \param ret The returned feature vector
+ */
+void GetPerStoreFeature(const Stmt& stmt, int cache_line_size, int max_n_bufs,
+                        std::vector<float>* ret);
+
+/*!
+ * \brief Get the names of the elements in the feature vector. Use this for debugging and
+ * inspection.
+ * \param max_n_bufs The maximum number of extracted buffers for one statement
+ * \param ret The returned names.
+ */
+void GetPerStoreFeatureName(int max_n_bufs, std::vector<std::string>* ret);
+
+/*!
+ * \brief Get per-store features from states of the same task
+ * \param states The input states
+ * \param task The search task shared by all states
+ * \param skip_first_n_feature_extraction Skip feature extraction for the first n states
+ * \param max_n_bufs The maximum number of extracted buffers for one statement
+ * \param features The returned feature vector. The innermost vector contains the
+ * feature vectors for all BufferStoreNode statements
+ */
+void GetPerStoreFeaturesFromStates(const Array<State>& states, const SearchTask& task,
+                                   int skip_first_n_feature_extraction, int max_n_bufs,
+                                   std::vector<std::vector<float>>* features);
+
+/*!
+ * \brief Get per-store features from states of different tasks
+ * \param states The input states
+ * \param tasks The search tasks corresponding to the input states
+ * \param skip_first_n_feature_extraction Skip feature extraction for the first n states
+ * \param max_n_bufs The maximum number of extracted buffers for one statement
+ * \param features The returned feature vector. The innermost vector contains the
+ * feature vectors for all BufferStoreNode statements
+ */
+void GetPerStoreFeaturesFromStates(const Array<State>& states, const std::vector<SearchTask>& tasks,
+                                   int skip_first_n_feature_extraction, int max_n_bufs,
+                                   std::vector<std::vector<float>>* features);
+
+/*!
+ * \brief Get per-store features from a log file
+ * \param filename The name of the log file
+ * \param max_lines Only read the first n lines of the file
+ * \param max_n_bufs The maximum number of extracted buffers for one statement
+ * \param features The returned feature vector. The innermost vector contains the
+ * feature vectors for all BufferStoreNode statements
+ * \param normalized_throughputs The normalized throughputs for all states
+ * \param task_ids The task ids for all states
+ */
+void GetPerStoreFeaturesFromFile(const std::string& filename, int max_lines, int max_n_bufs,
+                                 std::vector<std::vector<float>>* features,
+                                 std::vector<float>* normalized_throughputs,
+                                 std::vector<int>* task_ids);
+
+/*!
+ * \brief Get per-store features from measurement input/result pairs
+ * \param inputs The measurement inputs
+ * \param results The measurement results
+ * \param skip_first_n_feature_extraction Skip feature extraction for the first n measurement pairs
+ * \param max_n_bufs The maximum number of extracted buffers for one statement
+ * \param features The returned feature vector. The innermost vector contains the
+ * feature vectors for all BufferStoreNode statements
+ * \param normalized_throughputs The normalized throughputs for all states
+ * \param task_ids The task ids for all states
+ */
+void GetPerStoreFeaturesFromMeasurePairs(const Array<MeasureInput>& inputs,
+                                         const Array<MeasureResult>& results,
+                                         int skip_first_n_feature_extraction, int max_n_bufs,
+                                         std::vector<std::vector<float>>* features,
+                                         std::vector<float>* normalized_throughputs,
+                                         std::vector<int>* task_ids);
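+
+// A worked micro-example of the aggregation described in the file header
+// (illustrative numbers only): if a lowered matmul Stmt contains two
+// BufferStoreNode statements (the output init and the update), the functions
+// above return two feature rows for that Stmt; a cost model that predicts
+// per-store scores {0.7, 1.3} then scores the whole Stmt as 0.7 + 1.3 = 2.0.
+// The aggregation itself is done by the cost model, not by this module.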
+
+}  // namespace auto_scheduler
+}  // namespace tvm
+
+#endif  // TVM_AUTO_SCHEDULER_FEATURE_H_
diff --git a/python/tvm/auto_scheduler/__init__.py b/python/tvm/auto_scheduler/__init__.py
index 32ac4f5a3e3a..bf32cec675f4 100644
--- a/python/tvm/auto_scheduler/__init__.py
+++ b/python/tvm/auto_scheduler/__init__.py
@@ -23,6 +23,7 @@
 from . import loop_state
 from . import utils
 from . import workload_registry
+from . import feature
 
 # Shortcut
 from .auto_schedule import SearchTask, TuningOptions, HardwareParams, \
diff --git a/python/tvm/auto_scheduler/feature.py b/python/tvm/auto_scheduler/feature.py
new file mode 100644
index 000000000000..3ed87f3c2b9a
--- /dev/null
+++ b/python/tvm/auto_scheduler/feature.py
@@ -0,0 +1,242 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+Python API for feature extraction. The extracted feature vectors are used by cost models.
+
+We extract one feature vector per BufferStoreNode statement in a TIR Stmt,
+so we call this the "per-store" feature.
+The cost model also makes a prediction for each BufferStoreNode statement and aggregates
+the predicted scores of all BufferStoreNodes as the score of a TIR Stmt.
+
+The feature specification is defined by `src/auto_scheduler/feature.cc::FeatureSet`
+"""
+
+from typing import List, Tuple, Union, Optional
+import struct
+
+import numpy as np
+
+from .loop_state import State, StateObject
+from .measure import MeasureInput, MeasureResult
+from . import _ffi_api
+
+# The maximum number of extracted buffers for one statement
+DEFAULT_MAX_N_BUFS = 5
+
+# The length of the feature vector
+DEFAULT_FEATURE_VEC_LEN = 164
+
+# The size of int and float in bytes
+SIZE_OF_INT32 = 4
+SIZE_OF_FLOAT32 = 4
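+
+# With the defaults above, the 164 entries decompose as follows (see
+# GetPerStoreFeatureName in src/auto_scheduler/feature.cc):
+#   57                              computation features
+#   + DEFAULT_MAX_N_BUFS * 18 = 90  buffer access features
+#   + 10                            arithmetic intensity curve samples
+#   + 4                             allocation features
+#   + 3                             outer scope features
+#   = 164 = DEFAULT_FEATURE_VEC_LEN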
+
+def unpack_feature(byte_arr: bytearray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
+    """Unpack the flattened feature (in byte array format) from c++
+
+    Parameters
+    ----------
+    byte_arr: bytearray
+        The two-dimensional feature vector in serialized byte array format
+
+    Returns
+    -------
+    features: np.ndarray
+        Feature vectors
+    normalized_throughputs: np.ndarray
+        Normalized throughputs
+    task_ids: np.ndarray
+        Task ids
+    """
+
+    # The format for n records is:
+    # {
+    #   int n;
+    #   int[n+2] sizes
+
+    #   float[sizes[0]]    feature for record 1
+    #   float[sizes[1]]    feature for record 2
+    #   ...                feature for record i...
+    #   float[sizes[n-1]]  feature for record n
+
+    #   float[sizes[n]]    normalized throughput for n records
+    #   int[sizes[n+1]]    task id for n records
+    # }
+
+    vec_len = DEFAULT_FEATURE_VEC_LEN
+
+    # unpack sizes
+    offset = 0
+    n = struct.unpack_from("1i", byte_arr, offset=offset)[0]
+    offset += SIZE_OF_INT32
+
+    sizes = struct.unpack_from("%di" % (n+2), byte_arr, offset=offset)
+    offset += SIZE_OF_INT32 * (n+2)
+
+    # unpack features
+    features = []
+    for size in sizes[:-2]:
+        row = []
+
+        # Now, we need to unpack the feature for multiple statements.
+        # The format is:
+        # {
+        #     int n_stmts
+        #     float[n_stmt][vec_len] feature_vecs
+        # }
+        # where vec_len can be calculated by `(size - 1) / n_stmts`
+
+        if size == 0:
+            # failed during lowering
+            features.append(np.zeros((1, vec_len)))
+        else:
+            n_stmts = struct.unpack_from("f", byte_arr, offset=offset)
+            offset += SIZE_OF_FLOAT32
+
+            n_stmts = int(n_stmts[0] + 0.5)
+            tmp_vec_len = (size - 1) // n_stmts
+            assert tmp_vec_len == vec_len, "The length of the feature vector is wrong. " \
+                                           "Expected %d but got %d." % (vec_len, tmp_vec_len)
+            assert tmp_vec_len * n_stmts == size - 1
+            for _ in range(n_stmts):
+                x = struct.unpack_from("%df" % vec_len, byte_arr, offset=offset)
+                offset += vec_len * SIZE_OF_FLOAT32
+                row.append(x)
+
+            features.append(np.array(row))
+
+    # unpack normalized_throughputs
+    m = sizes[-2]
+    normalized_throughputs = struct.unpack_from("%df" % m, byte_arr, offset=offset)
+    offset += m * SIZE_OF_FLOAT32
+
+    # unpack task_ids
+    m = sizes[-1]
+    task_ids = struct.unpack_from("%di" % m, byte_arr, offset=offset)
+    offset += m * SIZE_OF_INT32
+
+    assert offset == len(byte_arr), "%d vs %d" % (offset, len(byte_arr))
+    return np.array(features, dtype=object), np.array(normalized_throughputs), np.array(task_ids)
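+
+# A hedged round-trip example of the format above (illustration only):
+#
+#   n, vec_len = 1, DEFAULT_FEATURE_VEC_LEN
+#   sizes = (1 + vec_len, 1, 1)  # one record, one throughput, one task id
+#   buf = struct.pack("i", n) + struct.pack("3i", *sizes)
+#   buf += struct.pack("f", 1.0)  # n_stmts == 1
+#   buf += struct.pack("%df" % vec_len, *([0.0] * vec_len))
+#   buf += struct.pack("f", 0.5)  # normalized throughput
+#   buf += struct.pack("i", 0)    # task id
+#   features, throughputs, task_ids = unpack_feature(bytearray(buf))
+#   assert features[0].shape == (1, vec_len)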
+
+
+def get_per_store_features_from_file(filename: str,
+                                     max_lines: int,
+                                     max_n_bufs: Optional[int] = None) \
+        -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
+    """Get per-store features from a log file
+
+    Parameters
+    ----------
+    filename: str
+        The input filename
+    max_lines: int
+        Only extract the first n lines of the file
+    max_n_bufs: Optional[int]
+        The maximum number of extracted buffers for one statement
+
+    Returns
+    -------
+    features: np.ndarray
+        Feature vectors
+    normalized_throughputs: np.ndarray
+        Normalized throughputs
+    task_ids: np.ndarray
+        Task ids
+    """
+    byte_arr = _ffi_api.GetPerStoreFeaturesFromFile(
+        filename, max_lines, max_n_bufs or DEFAULT_MAX_N_BUFS)
+    return unpack_feature(byte_arr)
+
+
+def get_per_store_features_from_measure_pairs(inputs: List[MeasureInput],
+                                              results: List[MeasureResult],
+                                              skip_first_n_feature_extraction: int = 0,
+                                              max_n_bufs: Optional[int] = None) \
+        -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
+    """Get per-store features from measurement input/result pairs
+
+    Parameters
+    ----------
+    inputs: List[MeasureInput]
+        The measure inputs
+    results: List[MeasureResult]
+        The measure results
+    skip_first_n_feature_extraction: int
+        Skip feature extraction for the first n measurement pairs
+    max_n_bufs: Optional[int]
+        The maximum number of extracted buffers for one statement
+
+    Returns
+    -------
+    features: np.ndarray
+        Feature vectors
+    normalized_throughputs: np.ndarray
+        Normalized throughputs
+    task_ids: np.ndarray
+        Task ids
+    """
+    byte_arr = _ffi_api.GetPerStoreFeaturesFromMeasurePairs(
+        inputs, results, skip_first_n_feature_extraction, max_n_bufs or DEFAULT_MAX_N_BUFS)
+    return unpack_feature(byte_arr)
+
+
+def get_per_store_features_from_states(states: List[Union[State, StateObject]],
+                                       task: "SearchTask",
+                                       max_n_bufs: Optional[int] = None) -> List[np.ndarray]:
+    """Get per-store features from states of the same search task
+
+    Parameters
+    ----------
+    states: List[Union[State, StateObject]]
+        The input states
+    task: SearchTask
+        The search task of the input states
+    max_n_bufs: Optional[int]
+        The maximum number of extracted buffers for one statement
+
+    Returns
+    -------
+    features: List[np.ndarray]
+        Feature vectors for all states
+    """
+    if isinstance(states[0], State):
+        state_objects = [s.state_object for s in states]
+    elif isinstance(states[0], StateObject):
+        state_objects = states
+    byte_arr = _ffi_api.GetPerStoreFeaturesFromStates(
+        state_objects, task, max_n_bufs or DEFAULT_MAX_N_BUFS)
+    return unpack_feature(byte_arr)[0]
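+
+# Hedged usage sketch (ComputeDAG / SearchTask come from the sibling modules;
+# construction details are illustrative and may differ across revisions). For
+# a matmul whose lowered TIR contains two BufferStoreNode statements (the init
+# and the update of the output buffer):
+#
+#   dag = ComputeDAG(workload_key)  # rebuilt from a registered workload
+#   task = SearchTask(dag, workload_key, target)
+#   feas = get_per_store_features_from_states([dag.get_init_state()], task)
+#   assert feas[0].shape == (2, DEFAULT_FEATURE_VEC_LEN)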
+
+
+def get_per_store_feature_names(max_n_bufs: Optional[int] = None) -> List[str]:
+    """Get the name of every element in the feature vector. Use this for debugging and inspection.
+
+    Parameters
+    ----------
+    max_n_bufs: Optional[int]
+        The maximum number of extracted buffers for one statement
+
+    Returns
+    -------
+    names: List[str]
+        The names of the elements in the flattened feature vector
+    """
+    return _ffi_api.GetPerStoreFeatureNames(max_n_bufs or DEFAULT_MAX_N_BUFS)
diff --git a/src/auto_scheduler/compute_dag.cc b/src/auto_scheduler/compute_dag.cc
index 87b162aca6d5..f5f08840de86 100644
--- a/src/auto_scheduler/compute_dag.cc
+++ b/src/auto_scheduler/compute_dag.cc
@@ -547,17 +547,25 @@ class FlopEstimator : public ExprFunctor<double(const PrimExpr& n)> {
     double ret = 0;
     for (const auto& op : ops) {
       if (auto pop = op.as<te::ComputeOpNode>()) {
-        double num_element = AxisLengthProd(pop->axis);
-        if (num_element == -1) {
-          fail_ = true;
-          break;
-        }
-        cur_type_code_ = pop->output_dtype(0).code();
-        double op_per_element = 0;
-        for (const auto& x : pop->body) {
-          op_per_element += VisitExpr(x);
+        if (pop->attrs.count("FLOP")) {
+          // Use user-provided FLOP
+          auto pint = pop->attrs["FLOP"].as<IntImmNode>();
+          CHECK(pint != nullptr);
+          ret += pint->value;
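+          // For example, a user may annotate an op whose body is hard to
+          // analyze (a hedged sketch; `n` and `A` are hypothetical):
+          //   te.compute((n,), lambda i: te.exp(A[i]),
+          //              name="C", attrs={"FLOP": 16 * n})
+          // and the estimator will use 16 * n instead of parsing the body.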
+        } else {
+          // Estimate by parsing the compute body
+          double num_element = AxisLengthProd(pop->axis);
+          if (num_element == -1) {
+            fail_ = true;
+            break;
+          }
+          cur_type_code_ = pop->output_dtype(0).code();
+          double op_per_element = 0;
+          for (const auto& x : pop->body) {
+            op_per_element += VisitExpr(x);
+          }
+          ret += num_element * op_per_element;
         }
-        ret += num_element * op_per_element;
       } else if (op->IsInstance<te::PlaceholderOpNode>()) {
         {}  // do nothing
       } else {
diff --git a/src/auto_scheduler/feature.cc b/src/auto_scheduler/feature.cc
new file mode 100644
index 000000000000..2f89750919ba
--- /dev/null
+++ b/src/auto_scheduler/feature.cc
@@ -0,0 +1,1639 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file auto_scheduler/feature.cc
+ * \brief Feature extraction for the cost model
+ */
+
+#include <tvm/arith/analyzer.h>
+#include <tvm/auto_scheduler/feature.h>
+#include <tvm/auto_scheduler/measure.h>
+#include <tvm/auto_scheduler/measure_record.h>
+#include <tvm/driver/driver_api.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/te/operation.h>
+#include <tvm/te/schedule_pass.h>
+#include <tvm/tir/op_attr_types.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>
+
+#include <algorithm>
+#include <cmath>
+#include <numeric>
+#include <unordered_map>
+#include <vector>
+
+#include "utils.h"
+
+namespace tvm {
+// import the function from driver_api.cc
+void GetBinds(const Array<te::Tensor>& args, bool compact,
+              const std::unordered_map<te::Tensor, tir::Buffer>& binds,
+              Map<te::Tensor, tir::Buffer>* out_binds, Array<ObjectRef>* out_arg_list);
+}  // namespace tvm
+
+namespace tvm {
+namespace auto_scheduler {
+
+using namespace tvm::tir;
+using arith::Analyzer;
+using arith::ConstIntBound;
+
+template <class T>
+using BufferMap = std::unordered_map<Buffer, T, ObjectHash, ObjectEqual>;
+
+// The number of samples to extract for arithmetic intensity curves
+static const int ARITH_INTENSITY_CURVE_SAMPLE_N = 10;
+
+// Annotation position encoding
+enum class AnnotationPosType : int {
+  kPosNone = 0,           // Does not have this kind of annotation
+  kPosInnerSpatial = 1,   // The annotated iterator is the innermost spatial iterator
+  kPosMiddleSpatial = 2,  // The annotated iterator is a middle spatial iterator
+  kPosOuterSpatial = 3,   // The annotated iterator is the outermost spatial iterator
+  kPosInnerReduce = 4,    // The annotated iterator is the innermost reduce iterator
+  kPosMiddleReduce = 5,   // The annotated iterator is a middle reduce iterator
+  kPosOuterReduce = 6,    // The annotated iterator is the outermost reduce iterator
+  kPosMixed = 7           // The annotated iterator is a mixed space and reduce iterator
+};
+
+// Buffer access type
+enum class BufferAccessType : int { kRead = 0, kWrite = 1, kReadWrite = 2, kUnknownRW = 3 };
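+
+// Both enums above (and ReuseType below) are one-hot encoded when the feature
+// set is serialized; e.g. for the first buffer slot, the vector elements named
+// "B0.acc_type.kRead" / "B0.acc_type.kWrite" / "B0.acc_type.kReadWrite"
+// (see GetPerStoreFeatureName) form the one-hot encoding of BufferAccessType.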
+
+// Accesses to a buffer
+struct BufferAccess {
+  // The type of the access
+  BufferAccessType acc_type{BufferAccessType::kUnknownRW};
+  // Use a two-dimensional array to store multiple multi-dimensional accesses.
+  // The innermost vector stores the multi-dimensional indices of one access.
+  std::vector<std::vector<PrimExpr>> indices;
+};
+
+// Data reuse type
+enum class ReuseType : int { kLoopMultipleRead = 0, kSerialMultipleReadWrite = 1, kNoReuse = 2 };
+
+// Feature for an access of a buffer
+struct BufferAccessFeature {
+  std::string buffer_name;        // The name of the buffer
+  BufferAccessType acc_type;      // The type of the access
+  float bytes;                    // The touched memory in bytes
+  float unique_bytes;             // The touched unique memory in bytes
+  float lines;                    // The number of touched cache lines
+  float unique_lines;             // The number of touched unique cache lines
+  ReuseType reuse_type;           // The type of data reuse
+  float reuse_dis_iter;           // The reuse distance in iterator number
+  float reuse_dis_bytes;          // The reuse distance in total touched bytes
+  float reuse_ct;                 // The reuse ratio
+  float bytes_d_reuse_ct;         // bytes / reuse_ct
+  float unique_bytes_d_reuse_ct;  // unique_bytes / reuse_ct
+  float lines_d_reuse_ct;         // lines / reuse_ct
+  float unique_lines_d_reuse_ct;  // unique_lines / reuse_ct
+  float stride;                   // The stride in access
+};
+
+// Feature set of a BufferStore statement
+struct FeatureSet {
+  // Group 1: Computation related features
+  float float_mad;                  // The number of float MAD (Multiply-add) ops
+  float float_addsub;               // The number of float add and sub ops
+  float float_mul;                  // The number of float multiply ops
+  float float_divmod;               // The number of float div and mod ops
+  float float_cmp;                  // The number of float comparison ops
+  float float_math_func;            // The number of float math func calls
+  float float_other_func;           // The number of other float func calls
+  float int_mad;                    // The number of integer MAD (Multiply-add) ops
+  float int_addsub;                 // The number of integer add and sub ops
+  float int_mul;                    // The number of integer multiply ops
+  float int_divmod;                 // The number of integer div and mod ops
+  float int_cmp;                    // The number of integer comparison ops
+  float int_math_func;              // The number of integer math func calls
+  float int_other_func;             // The number of other integer func calls
+  float bool_op;                    // The number of bool ops
+  float select_op;                  // The number of select ops
+  float vec_num;                    // The number of vectorized iterators
+  float vec_prod;                   // The product of the lengths of vectorized iterators
+  float vec_len;                    // The length of the innermost vectorized iterator
+  AnnotationPosType vec_type;       // The type of vectorization position
+  float unroll_num;                 // The number of unrolled iterators
+  float unroll_prod;                // The product of the lengths of unrolled iterators
+  float unroll_len;                 // The length of the innermost unrolled iterator
+  AnnotationPosType unroll_type;    // The type of unroll position
+  float parallel_num;               // The number of parallelized iterators
+  float parallel_prod;              // The product of the lengths of parallelized iterators
+  float parallel_len;               // The length of the innermost parallelized iterator
+  AnnotationPosType parallel_type;  // The type of parallel position
+  float is_gpu;                     // Whether it is a GPU task
+  float blockIdx_x_len;             // The length of blockIdx.x
+  float blockIdx_y_len;             // The length of blockIdx.y
+  float blockIdx_z_len;             // The length of blockIdx.z
+  float threadIdx_x_len;            // The length of threadIdx.x
+  float threadIdx_y_len;            // The length of threadIdx.y
+  float threadIdx_z_len;            // The length of threadIdx.z
+  float vthread_len;                // The length of the virtual thread
+
+  // Group 2: Buffer access related features (per buffer)
+  std::vector<BufferAccessFeature> access_feas;
+
+  // Group 3: Arithmetic intensity related features
+  // points sampled from the arithmetic intensity curve
+  float arith_intensity_curve[ARITH_INTENSITY_CURVE_SAMPLE_N];
+
+  // Group 4: Allocation related features
+  float alloc_size;        // The size of the allocated buffer in bytes
+  float alloc_outer_prod;  // The product of lengths of loops outside the scope of the allocation
+  float alloc_inner_prod;  // The product of lengths of loops inside the scope of the allocation
+  float alloc_prod;        // alloc_outer_prod * alloc_inner_prod
+
+  // Group 5: Outer scope related features
+  float outer_prod;            // The product of lengths of outer loops
+  float num_loops;             // The number of outer loops
+  float auto_unroll_max_step;  // The value of pragma "auto_unroll_max_step"
+};
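+
+// A worked example (illustrative): for the loop nest
+//   parallel i in range(8):
+//     vectorize j in range(4):
+//       C[i, j] = A[i, j] + B[i, j]
+// the store to C gets float_addsub = 8 * 4 * 1 = 32 (the product of the outer
+// loop extents times one add), vec_num = 1, vec_len = vec_prod = 4,
+// parallel_num = 1, parallel_len = parallel_prod = 8, and num_loops = 2.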
+
+// Return whether a var is in an expr
+bool VarInExpr(const Var& var, const PrimExpr& expr) {
+  bool find = false;
+
+  PostOrderVisit(expr, [&find, &var](const ObjectRef& node) {
+    if (find) {
+      return;
+    }
+
+    if (const VarNode* op = node.as<VarNode>()) {
+      if (op == var.get()) {
+        find = true;
+      }
+    }
+  });
+
+  return find;
+}
+
+// Get the position encoding for an annotation
+AnnotationPosType GetAnnotationPosEncoding(const Var& var, const Array<PrimExpr>& spatial_args,
+                                           const Array<IterVar>& axis,
+                                           const Array<IterVar>& reduce_axis) {
+  // Try to match the spatial args first
+  size_t find_i = 0;
+  size_t find_ct = 0;
+  for (size_t i = 0; i < spatial_args.size(); ++i) {
+    if (VarInExpr(var, spatial_args[i])) {
+      find_i = i;
+      find_ct += 1;
+    }
+  }
+
+  if (find_ct == 0) {
+    // If it is not found in the spatial args, then it is a reduce iterator.
+    // Use the name to match
+    const std::string& var_name = var->name_hint;
+    for (size_t i = 0; i < reduce_axis.size(); ++i) {
+      if (var_name.find(reduce_axis[i]->var->name_hint) != std::string::npos) {
+        find_i = i;
+        find_ct++;
+      }
+    }
+    if (find_ct >= 1) {
+      if (find_i == 0) {
+        return AnnotationPosType::kPosInnerReduce;
+      } else if (find_i == reduce_axis.size() - 1) {
+        return AnnotationPosType::kPosOuterReduce;
+      } else {
+        return AnnotationPosType::kPosMiddleReduce;
+      }
+    } else {
+      // If the axis is found in neither the spatial args nor the reduce axis,
+      // then this stage must compute_at somewhere under this axis and this axis is simplified out.
+      // We assume it is an outer spatial axis.
+      return AnnotationPosType::kPosOuterSpatial;
+    }
+  } else if (find_ct == 1) {
+    if (find_i == spatial_args.size() - 1) {
+      return AnnotationPosType::kPosInnerSpatial;
+    } else if (find_i == 0) {
+      return AnnotationPosType::kPosOuterSpatial;
+    } else {
+      return AnnotationPosType::kPosMiddleSpatial;
+    }
+  } else {
+    return AnnotationPosType::kPosMixed;
+  }
+}
+
+// Return the extent of a for loop
+int64_t GetLoopExtent(const ForNode* node) {
+  auto pint = node->extent.as<IntImmNode>();
+  if (pint != nullptr) {
+    return pint->value;
+  } else {
+    return 1;
+  }
+}
+
+// Count math ops in an expr
+class MathOpCounter : public StmtExprVisitor {
+ public:
+#define VisitBinary(Type, float_ct, int_ct) \
+  void VisitExpr_(const Type* op) final {   \
+    if (op->a.dtype().is_float()) {         \
+      float_ct++;                           \
+    } else {                                \
+      int_ct++;                             \
+    }                                       \
+    StmtExprVisitor::VisitExpr_(op);        \
+  }
+
+  VisitBinary(AddNode, float_addsub, int_addsub);
+  VisitBinary(SubNode, float_addsub, int_addsub);
+  VisitBinary(MulNode, float_mul, int_mul);
+  VisitBinary(DivNode, float_divmod, int_divmod);
+  VisitBinary(ModNode, float_divmod, int_divmod);
+  VisitBinary(FloorDivNode, float_divmod, int_divmod);
+  VisitBinary(FloorModNode, float_divmod, int_divmod);
+  VisitBinary(MaxNode, float_cmp, int_cmp);
+  VisitBinary(MinNode, float_cmp, int_cmp);
+  VisitBinary(EQNode, float_cmp, int_cmp);
+  VisitBinary(NENode, float_cmp, int_cmp);
+  VisitBinary(LTNode, float_cmp, int_cmp);
+  VisitBinary(LENode, float_cmp, int_cmp);
+  VisitBinary(GTNode, float_cmp, int_cmp);
+  VisitBinary(GENode, float_cmp, int_cmp);
+
+#undef VisitBinary
+
+  void VisitExpr_(const AndNode* op) final {
+    bool_op++;
+    StmtExprVisitor::VisitExpr_(op);
+  }
+  void VisitExpr_(const OrNode* op) final {
+    bool_op++;
+    StmtExprVisitor::VisitExpr_(op);
+  }
+  void VisitExpr_(const NotNode* op) final {
+    bool_op++;
+    StmtExprVisitor::VisitExpr_(op);
+  }
+  void VisitExpr_(const SelectNode* op) final {
+    select_op++;
+    StmtExprVisitor::VisitExpr_(op);
+  }
+
+  void VisitExpr_(const CallNode* op) final {
+    auto* pop = op->op.as<OpNode>();
+    CHECK(pop != nullptr);
+    auto effect_kind = op_call_effect_[GetRef<Op>(pop)];
+    bool is_pure =
+        effect_kind == CallEffectKind::kPure || effect_kind == CallEffectKind::kExprAnnotation;
+
+    if (is_pure) {
+      if (op->dtype.is_float()) {
+        float_math_func++;
+      } else {
+        int_math_func++;
+      }
+    } else {
+      if (op->dtype.is_float()) {
+        float_other_func++;
+      } else {
+        int_other_func++;
+      }
+    }
+    StmtExprVisitor::VisitExpr_(op);
+  }
+
+  // todo(merrymercy): Detect MAD (Multiply-add)
+  size_t float_mad{0};         // The number of float MAD (Multiply-add) ops
+  size_t float_addsub{0};      // The number of float add and sub ops
+  size_t float_mul{0};         // The number of float multiply ops
+  size_t float_divmod{0};      // The number of float div and mod ops
+  size_t float_cmp{0};         // The number of float comparison ops
+  size_t float_math_func{0};   // The number of float math func calls
+  size_t float_other_func{0};  // The number of other float func calls
+  size_t int_mad{0};           // The number of integer MAD (Multiply-add) ops
+  size_t int_addsub{0};        // The number of integer add and sub ops
+  size_t int_mul{0};           // The number of integer multiply ops
+  size_t int_divmod{0};        // The number of integer div and mod ops
+  size_t int_cmp{0};           // The number of integer comparison ops
+  size_t int_math_func{0};     // The number of integer math func calls
+  size_t int_other_func{0};    // The number of other integer func calls
+  size_t bool_op{0};           // The number of bool ops
+  size_t select_op{0};         // The number of select ops
+
+  OpAttrMap<TCallEffectKind> op_call_effect_ = Op::GetAttrMap<TCallEffectKind>("TCallEffectKind");
+};
+
+// Extract all buffer accesses in an expr
+class BufferAccessExtractor : public StmtExprVisitor {
+ public:
+  void ExtractReads(const PrimExpr& expr) { this->VisitExpr(expr); }
+
+  void InsertAccess(const Buffer& buf, BufferAccessType acc_type, const Array<PrimExpr>& indices) {
+    BufferAccess& acc = buf_accesses[buf];
+    acc.acc_type = acc_type;
+    acc.indices.push_back(std::vector<PrimExpr>(indices.begin(), indices.end()));
+  }
+
+  void VisitExpr_(const BufferLoadNode* op) final {
+    BufferAccess& acc = buf_accesses[op->buffer];
+    switch (acc.acc_type) {
+      case BufferAccessType::kRead:
+        break;
+      case BufferAccessType::kWrite:
+        acc.acc_type = BufferAccessType::kReadWrite;
+        break;
+      case BufferAccessType::kReadWrite:
+        break;
+      case BufferAccessType::kUnknownRW:
+      default:
+        acc.acc_type = BufferAccessType::kRead;
+        break;
+    }
+
+    if (acc.acc_type != BufferAccessType::kReadWrite) {
+      // If a buffer is both read and written, in the TVM DSL it must be an update,
+      // so the indices should be the same. Then we can skip appending indices for it.
+      // Otherwise we do the following.
+      buf_accesses[op->buffer].indices.push_back(
+          std::vector<PrimExpr>(op->indices.begin(), op->indices.end()));
+    }
+    StmtExprVisitor::VisitExpr_(op);
+  }
+
+  BufferMap<BufferAccess> buf_accesses;
+};
+
+// Compute the coefficient for a loop iterator in an expression
+// Note: we use an approximation strategy to find the coefficient.
+// Hopefully, it is faster than DetectLinearEquation and can handle more cases (non-linear) +class CoefficientExtractor : public StmtExprVisitor { + public: + void VisitExpr_(const MulNode* node) final { + StmtExprVisitor::VisitExpr_(node); + if (visited_var) { + if (!visited_add) { + if (auto a = node->a.as()) { + visited_mul = true; + stride = a->value; + } else if (auto b = node->b.as()) { + visited_mul = true; + stride = b->value; + } + } + } + } + + void VisitExpr_(const AddNode* node) final { + StmtExprVisitor::VisitExpr_(node); + if (visited_var) { + if (!visited_mul) { + visited_add = true; + stride = 1; + } + } + } + + void VisitExpr_(const VarNode* node) final { + if (node == var_) { + visited_var = true; + // This is a magic default stride in case our approximation strategy fails + stride = 2; + } + } + + int ExtractCoefficient(const PrimExpr& expr, const VarNode* var) { + visited_var = visited_mul = visited_add = false; + var_ = var; + + this->VisitExpr(expr); + + if (visited_var && !visited_mul && !visited_add) { + return 1; + } else { + return stride; + } + } + + bool visited_var{false}; + bool visited_mul{false}; + bool visited_add{false}; + int stride{0}; + + private: + const VarNode* var_{nullptr}; +}; + +// Compute stride for the accesses to a buffer +int64_t ComputeStride(const std::vector>& indices, + const std::vector& shape, const VarNode* stride_var) { + int64_t min_stride = std::numeric_limits::max(); + bool find = false; + CoefficientExtractor extractor; + + for (const auto& index : indices) { + int64_t shape_stride = 1; + for (int i = static_cast(index.size()) - 1; i >= 0; i--) { + int coefficient = extractor.ExtractCoefficient(index[i], stride_var); + if (extractor.visited_var) { + find = true; + min_stride = std::min(min_stride, std::abs(coefficient) * shape_stride); + break; + } + shape_stride *= shape[i]; + } + } + + return find ? min_stride : 0; +} + +// Compute touched bytes and cache lines for accesses to a buffer +void ComputeRegion(const std::vector>& indices, arith::Analyzer* ana, + std::vector* region) { + region->clear(); + + if (indices.empty()) { + return; + } + + region->reserve(indices[0].size()); + + if (indices.size() == 1) { + for (const auto& index : indices[0]) { + ConstIntBound bound = ana->const_int_bound(index); + region->push_back(bound->max_value - bound->min_value + 1); + } + } else { + // future(lmzheng): implement a more accurate IntSet? 
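+    // e.g. accesses {i, j} and {i, j + 1} under 0 <= i < 4, 0 <= j < 8 give
+    // per-dimension bounds [0, 3] and [0, 8], so the region is {4, 9}
+    // (bounds are unioned per dimension, which may over-approximate in general).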
+ for (size_t i = 0; i < indices[0].size(); ++i) { + int64_t minimum = ConstIntBound::kPosInf, maximum = ConstIntBound::kNegInf; + for (size_t j = 0; j < indices.size(); ++j) { + ConstIntBound bound = ana->const_int_bound(indices[j][i]); + + minimum = std::min(minimum, bound->min_value); + maximum = std::max(maximum, bound->max_value); + } + region->push_back(maximum - minimum + 1); + } + } +} + +// Compute reuse distance and reuse ratio for accesses to a buffer +// return values: reuse_type, reuse_dis_iter, reuse_dis_bytes, reuse_ct +std::tuple ComputeReuse( + const Buffer& buf, const std::vector>& indices, + const std::vector& for_loop_stack, + const std::unordered_map>>>& + for_touch_regions) { + float reuse_dis_iter = 1.0f; + float reuse_dis_bytes = -1.0f; + + for (int i = static_cast(for_loop_stack.size()) - 1; i >= 0; --i) { + const ForNode* cur_for = for_loop_stack[i]; + bool find = false; + + for (size_t j = 0; j < indices.size(); j++) { + for (size_t k = 0; k < indices[j].size(); k++) { + if (VarInExpr(cur_for->loop_var, indices[j][k])) { + find = true; + break; + } + } + if (find) { + break; + } + } + + int64_t extent = GetLoopExtent(for_loop_stack[i]); + if (find) { + // accumulate/update reuse distance + reuse_dis_iter *= extent; + reuse_dis_bytes = 0.0f; + for (const auto& iter : for_touch_regions.at(cur_for)) { + for (const auto& access : iter.second) { + reuse_dis_bytes += std::get<1>(access) * std::get<2>(access); + } + } + } else { + // Have LoopMultipleRead reuse + if (reuse_dis_bytes < 0) { + // For the reuse in the innermost axis, the above code won't be executed. + // So we compute bytes here + reuse_dis_bytes = 0.0f; + for (const auto& iter : for_touch_regions.at(cur_for)) { + for (const auto& access : iter.second) { + reuse_dis_bytes += 1 * std::get<2>(access); + } + } + } + return std::make_tuple(ReuseType::kLoopMultipleRead, reuse_dis_iter, reuse_dis_bytes, extent); + } + + const BufferMap>>& buffer_map = + for_touch_regions.at(cur_for); + + int serial_reuse = static_cast(buffer_map.at(buf).size()) - 1; + if (serial_reuse > 0) { + int64_t extent = GetLoopExtent(cur_for); + + // Have SerialMultipleReadWrite reuse + reuse_dis_iter = std::numeric_limits::max(); + for (const auto& acc_info : buffer_map.at(buf)) { + reuse_dis_iter = std::min(reuse_dis_iter, static_cast(std::get<1>(acc_info))); + } + + reuse_dis_bytes = 0.0f; + for (const auto& iter : for_touch_regions.at(cur_for)) { + for (const auto& access : iter.second) { + reuse_dis_bytes += std::get<1>(access) * std::get<2>(access); + } + } + + return std::make_tuple(ReuseType::kSerialMultipleReadWrite, reuse_dis_iter / extent, + reuse_dis_bytes / extent, serial_reuse); + } + } + + return std::make_tuple(ReuseType::kNoReuse, 0, 0, 0); +} + +// Extract features for every BufferStore statement +class PerStoreFeatureExtractor : public StmtExprVisitor { + public: + explicit PerStoreFeatureExtractor(int cache_line_size) : cache_line_size_(cache_line_size) {} + + void VisitStmt_(const AttrStmtNode* node) final { + if (node->attr_key == tir::attr::thread_extent || node->attr_key == tir::attr::virtual_thread) { + const Var& var = node->node.as()->var; + int extent = GetIntImm(node->value); + + int* plen = nullptr; + + const std::string& name = var.get()->name_hint; + if (node->attr_key == tir::attr::thread_extent) { + if (name == "blockIdx.x") { + plen = &blockIdx_x_len_; + } else if (name == "blockIdx.y") { + plen = &block_idx_y_len_; + } else if (name == "blockIdx.z") { + plen = &block_idx_z_len_; + } else if (name == 
"threadIdx.x") { + plen = &threadIdx_x_len_; + } else if (name == "threadIdx.y") { + plen = &thread_idx_y_len_; + } else if (name == "threadIdx.z") { + plen = &thread_idx_z_len_; + } else { + LOG(FATAL) << "invalid thread itervar " + name; + } + } else { + plen = &vthread_len_; + } + + int extent_before = *plen; + if (node->attr_key == tir::attr::thread_extent) { + *plen = extent; + } else { + *plen *= extent; + } + + is_gpu_ = true; + + // make a fake for node for blockIdx.x or threadIdx.x + Stmt fake_for_node = For(var, 0, extent, ForType::Parallel, DeviceAPI::None, node->body); + + outer_loop_prod_ *= extent; + for_loop_stack_.push_back(fake_for_node.as()); + StmtExprVisitor::VisitStmt_(node); + for_loop_stack_.pop_back(); + outer_loop_prod_ /= extent; + + *plen = extent_before; + } else if (node->attr_key == "pragma_auto_unroll_max_step") { + int value = GetIntImm(node->value); + + int16_t old_value = cur_auto_unroll_max_step_; + cur_auto_unroll_max_step_ = value; + StmtExprVisitor::VisitStmt_(node); + cur_auto_unroll_max_step_ = old_value; + } else { + StmtExprVisitor::VisitStmt_(node); + } + } + + void VisitStmt_(const ForNode* node) final { + int64_t loop_extent = GetLoopExtent(node); + + if (node->for_type == ForType::Vectorized) { + vec_for_stack_.push_back(node); + } else if (node->for_type == ForType::Unrolled) { + unroll_for_stack_.push_back(node); + } else if (node->for_type == ForType::Parallel) { + parallel_for_stack_.push_back(node); + } + + outer_loop_prod_ *= loop_extent; + for_loop_stack_.push_back(node); + StmtExprVisitor::VisitStmt_(node); + for_loop_stack_.pop_back(); + outer_loop_prod_ /= loop_extent; + + if (node->for_type == ForType::Vectorized) { + vec_for_stack_.pop_back(); + } else if (node->for_type == ForType::Unrolled) { + unroll_for_stack_.pop_back(); + } else if (node->for_type == ForType::Parallel) { + parallel_for_stack_.pop_back(); + } + } + + void VisitStmt_(const BufferStoreNode* node) final { + MathOpCounter math_op_counter; + math_op_counter(node->value); + std::vector mem_bytes_list; + std::vector compute_ops_list; + int cur_compute_ops; + + // Group 1: Computation related features + ExtractComputationFeature(node, math_op_counter); + + // Group 2: Buffer access related features (per buffer) + ExtractBufferAccessFeature(node, math_op_counter, &cur_compute_ops, &compute_ops_list, + &mem_bytes_list); + + // Group 3: Arithmetic intensity related features + ExtractArithmeticIntensityFeature(node, cur_compute_ops, compute_ops_list, mem_bytes_list); + + // Group 4: Allocation related features + ExtractOuterScopeFeature(node); + } + + void VisitStmt_(const BufferRealizeNode* node) final { + StmtExprVisitor::VisitStmt_(node); + + // Group 5: Outer scope related features + ExtractAllocationFeature(node); + } + + // Extract computation related features (group 1) + void ExtractComputationFeature(const BufferStoreNode* node, + const MathOpCounter& math_op_counter) { + FeatureSet& fea = buffer_features[node->buffer]; + + // Computation related features + fea.float_mad = outer_loop_prod_ * math_op_counter.float_mad; + fea.float_addsub = outer_loop_prod_ * math_op_counter.float_addsub; + fea.float_mul = outer_loop_prod_ * math_op_counter.float_mul; + fea.float_divmod = outer_loop_prod_ * math_op_counter.float_divmod; + fea.float_cmp = outer_loop_prod_ * math_op_counter.float_cmp; + fea.float_math_func = outer_loop_prod_ * math_op_counter.float_math_func; + fea.float_other_func = outer_loop_prod_ * math_op_counter.float_other_func; + fea.int_mad = outer_loop_prod_ 
* math_op_counter.int_mad; + fea.int_addsub = outer_loop_prod_ * math_op_counter.int_addsub; + fea.int_mul = outer_loop_prod_ * math_op_counter.int_mul; + fea.int_divmod = outer_loop_prod_ * math_op_counter.int_divmod; + fea.int_math_func = outer_loop_prod_ * math_op_counter.int_math_func; + fea.int_cmp = outer_loop_prod_ * math_op_counter.int_cmp; + fea.int_other_func = outer_loop_prod_ * math_op_counter.int_other_func; + fea.bool_op = outer_loop_prod_ * math_op_counter.bool_op; + fea.select_op = outer_loop_prod_ * math_op_counter.select_op; + + fea.vec_len = fea.unroll_len = fea.parallel_len = 0.0f; + fea.vec_type = fea.unroll_type = fea.parallel_type = AnnotationPosType::kPosNone; + + fea.vec_num = vec_for_stack_.size(); + if (!vec_for_stack_.empty()) { + fea.vec_len = GetLoopExtent(vec_for_stack_.back()); + fea.vec_prod = 1.0; + for (const ForNode* pfor : vec_for_stack_) { + fea.vec_prod *= GetLoopExtent(pfor); + } + fea.vec_type = AnnotationPosType::kPosMixed; + // todo(merrymercy): this feature requires operation (tvm.compute) information + // GetAnnotationPosEncoding(vec_for_stack_.back()->loop_var, + // node->args, pcompute->axis, pcompute->reduce_axis); + } + + fea.unroll_num = unroll_for_stack_.size(); + if (!unroll_for_stack_.empty()) { + fea.unroll_len = GetLoopExtent(unroll_for_stack_.back()); + fea.unroll_prod = 1.0; + for (const ForNode* pfor : unroll_for_stack_) { + fea.unroll_prod *= GetLoopExtent(pfor); + } + fea.unroll_type = AnnotationPosType::kPosMixed; + // GetAnnotationPosEncoding(unroll_for_stack_.back()->loop_var, + // node->args, pcompute->axis, pcompute->reduce_axis); + } + + fea.parallel_num = parallel_for_stack_.size(); + if (!parallel_for_stack_.empty()) { + fea.parallel_len = GetLoopExtent(parallel_for_stack_.back()); + fea.parallel_prod = 1.0; + for (const ForNode* pfor : parallel_for_stack_) { + fea.parallel_prod *= GetLoopExtent(pfor); + } + fea.parallel_type = AnnotationPosType::kPosMixed; + // GetAnnotationPosEncoding(parallel_for_stack_.back()->loop_var, + // node->args, pcompute->axis, pcompute->reduce_axis); + } + + // GPU threads + fea.is_gpu = is_gpu_; + fea.blockIdx_x_len = blockIdx_x_len_; + fea.blockIdx_y_len = block_idx_y_len_; + fea.blockIdx_z_len = block_idx_z_len_; + fea.threadIdx_x_len = threadIdx_x_len_; + fea.threadIdx_y_len = thread_idx_y_len_; + fea.threadIdx_z_len = thread_idx_z_len_; + fea.vthread_len = vthread_len_; + } + + // Extract buffer access related features (group 2) + void ExtractBufferAccessFeature(const BufferStoreNode* node, const MathOpCounter& math_op_counter, + int* cur_compute_ops, std::vector* compute_ops_list, + std::vector* mem_bytes_list) { + FeatureSet& fea = buffer_features[node->buffer]; + + // Extract all buffer accesses + std::vector acc_feas; + BufferAccessExtractor buf_extractor; + buf_extractor.InsertAccess(node->buffer, BufferAccessType::kWrite, node->indices); + buf_extractor.ExtractReads(node->value); + + // Compute touched region for all outer loops + for (auto x : for_loop_stack_) { + ana_.Bind(x->loop_var, Range::FromMinExtent(x->min, 1), true); + } + + mem_bytes_list->reserve(for_loop_stack_.size()); + compute_ops_list->reserve(for_loop_stack_.size()); + + *cur_compute_ops = math_op_counter.float_mad + math_op_counter.float_addsub + + math_op_counter.float_mul + math_op_counter.float_divmod + + math_op_counter.float_cmp + math_op_counter.float_math_func + + math_op_counter.float_other_func; + + std::vector tmp_region; + for (int i = static_cast(for_loop_stack_.size()) - 1; i >= 0; i--) { + const 
ForNode* p_for = for_loop_stack_[i]; + + ana_.Bind(p_for->loop_var, + Range::FromMinExtent(for_loop_stack_[i]->min, for_loop_stack_[i]->extent), true); + + // Note, here we do overwrite. + // So if there are multiple BufferStoreNode, the last one will overwrite the first few. + // e.g. The update part in gemm will overwrite the init part. + BufferMap>>& buffer_regions_map = + for_touch_regions_[p_for]; + + int64_t mem_bytes = 0; + for (const auto& x : buf_extractor.buf_accesses) { + const Buffer& t = x.first; + const BufferAccess& acc = x.second; + + ComputeRegion(acc.indices, &ana_, &tmp_region); + int64_t touched_size = ElementProduct(tmp_region); + buffer_regions_map[t].push_back( + std::make_tuple(acc.acc_type, touched_size, t->dtype.bytes())); + mem_bytes += touched_size * t->dtype.bytes(); + } + + mem_bytes_list->push_back(std::log2(mem_bytes)); + *cur_compute_ops *= GetLoopExtent(for_loop_stack_[i]); + compute_ops_list->push_back(std::log2(*cur_compute_ops)); + } + + // Buffer access related features (per buffer) + for (const auto& x : buf_extractor.buf_accesses) { + const Buffer& t = x.first; + const BufferAccess& acc = x.second; + + std::vector int_shape; + for (const auto& dim : t->shape) { + int_shape.push_back(GetIntImm(dim)); + } + + size_t ele_bytes = t->dtype.bytes(); + + // calculate bytes + float bytes = outer_loop_prod_ * ele_bytes; + float unique_bytes; + + // calculate cache lines + int64_t stride; + float lines; + float unique_lines; + + if (for_loop_stack_.empty()) { + unique_bytes = ele_bytes; + stride = 0; + lines = 1.0f; + unique_lines = 1.0f; + } else { + unique_bytes = + std::get<1>(for_touch_regions_[for_loop_stack_.front()][t].front()) * ele_bytes; + + stride = 0; + int64_t reduce_ratio = 1; + + int i; + for (i = static_cast(for_loop_stack_.size()) - 1; i >= 0; i--) { + stride = ComputeStride(acc.indices, int_shape, for_loop_stack_[i]->loop_var.get()); + if (stride != 0) { + break; + } + reduce_ratio *= GetLoopExtent(for_loop_stack_.back()); + } + + lines = outer_loop_prod_ / reduce_ratio * + std::min(1.0f, 1.0f * stride * ele_bytes / cache_line_size_); + lines = std::max(lines, 1.0f); + + // convert `stride` back to the stride of the innermost iterator + stride = (i == static_cast(for_loop_stack_.size()) - 1 ? 
stride : 0); + + float n_continuous = ele_bytes; + for (int i = static_cast(tmp_region.size()) - 1; i >= 0; i--) { + if (tmp_region[i] == int_shape[i]) { + n_continuous *= tmp_region[i]; + break; + } + } + unique_lines = unique_bytes / std::min(n_continuous, static_cast(cache_line_size_)); + unique_lines = std::max(unique_lines, 1.0f); + } + + ReuseType reuse_type; + float reuse_dis_iter, reuse_dis_bytes, reuse_ct; + std::tie(reuse_type, reuse_dis_iter, reuse_dis_bytes, reuse_ct) = + ComputeReuse(t, acc.indices, for_loop_stack_, for_touch_regions_); + + acc_feas.emplace_back(); + BufferAccessFeature& acc_fea = acc_feas.back(); + + acc_fea.buffer_name = t->name; + acc_fea.acc_type = acc.acc_type; + acc_fea.stride = stride; + acc_fea.bytes = bytes; + acc_fea.unique_bytes = unique_bytes; + acc_fea.lines = lines; + acc_fea.unique_lines = unique_lines; + acc_fea.reuse_type = reuse_type; + acc_fea.reuse_dis_iter = reuse_dis_iter; + acc_fea.reuse_dis_bytes = reuse_dis_bytes; + acc_fea.reuse_ct = reuse_ct; + if (acc_fea.reuse_ct > 0.5) { + acc_fea.bytes_d_reuse_ct = bytes / reuse_ct; + acc_fea.unique_bytes_d_reuse_ct = unique_bytes / reuse_ct; + acc_fea.lines_d_reuse_ct = lines / reuse_ct; + acc_fea.unique_lines_d_reuse_ct = unique_lines / reuse_ct; + } else { + // no reuse, multiply by a magic number '2' + acc_fea.bytes_d_reuse_ct = bytes * 2; + acc_fea.unique_bytes_d_reuse_ct = unique_bytes * 2; + acc_fea.lines_d_reuse_ct = lines * 2; + acc_fea.unique_lines_d_reuse_ct = unique_lines * 2; + } + } + + fea.access_feas = acc_feas; + } + + // Extract arithmetic intensity related feature (group 3) + void ExtractArithmeticIntensityFeature(const BufferStoreNode* node, int cur_compute_ops, + const std::vector& compute_ops_list, + const std::vector& mem_bytes_list) { + FeatureSet& fea = buffer_features[node->buffer]; + + // Compute arithmetic intensity curve (y axis : arithmetic intensity, x axis : flops). + // We use piecewise linear interpolation to fit this curve. 
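+    // e.g. with compute_ops_list = {10, 20} and mem_bytes_list = {8, 10} (all
+    // log2 values), the sample at x = 15 lies between the two points, so its
+    // intensity interpolates from 10/8 = 1.25 toward 20/10 = 2.0:
+    // 1.25 + (2.0 - 1.25) / (20 - 10) * (15 - 10) = 1.625.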
+ int pt = 0; + if (cur_compute_ops <= 0 || compute_ops_list.empty()) { + std::fill(fea.arith_intensity_curve, + fea.arith_intensity_curve + ARITH_INTENSITY_CURVE_SAMPLE_N, 0.0); + } else { + for (size_t i = 0; i < ARITH_INTENSITY_CURVE_SAMPLE_N; ++i) { + float cur_compute_ops = compute_ops_list.back() * (i + 1) / ARITH_INTENSITY_CURVE_SAMPLE_N; + while (compute_ops_list[pt] < cur_compute_ops - 1e-4) { + pt++; + } + CHECK_LT(pt, compute_ops_list.size()); + + float value; + if (pt == 0) { + value = compute_ops_list[pt] / mem_bytes_list[pt]; + } else { + float base = compute_ops_list[pt - 1] / mem_bytes_list[pt - 1]; + float slope = (compute_ops_list[pt] / mem_bytes_list[pt] - + compute_ops_list[pt - 1] / mem_bytes_list[pt - 1]) / + (compute_ops_list[pt] - compute_ops_list[pt - 1]); + value = base + slope * (cur_compute_ops - compute_ops_list[pt - 1]); + } + fea.arith_intensity_curve[i] = value; + } + } + } + + // Extract allocation related features (group 4) + void ExtractAllocationFeature(const BufferRealizeNode* node) { + FeatureSet& fea = buffer_features[node->buffer]; + + float allocation_size = 1.0f; + for (const auto& x : node->bounds) { + allocation_size *= GetIntImm(x->extent); + } + // allocation feature + fea.alloc_size = allocation_size * node->buffer->dtype.bytes(); + fea.alloc_prod = allocation_size * outer_loop_prod_; + fea.alloc_outer_prod = outer_loop_prod_; + fea.alloc_inner_prod = fea.outer_prod / outer_loop_prod_; + } + + // Extract outer scope related features (group 5) + void ExtractOuterScopeFeature(const BufferStoreNode* node) { + FeatureSet& fea = buffer_features[node->buffer]; + + fea.outer_prod = outer_loop_prod_; + fea.num_loops = for_loop_stack_.size(); + fea.auto_unroll_max_step = cur_auto_unroll_max_step_; + } + + // Stores FeatureSet for every buffer + BufferMap buffer_features; + + private: + // The shared arithmetic analyzer + Analyzer ana_; + + // The product of outer loop + float outer_loop_prod_ = 1.0f; + + // The stacks to store parent loops during DFS + std::vector for_loop_stack_; + std::vector parallel_for_stack_; + std::vector vec_for_stack_; + std::vector unroll_for_stack_; + + // GPU-related features + bool is_gpu_{false}; + int blockIdx_x_len_{1}; + int block_idx_y_len_{1}; + int block_idx_z_len_{1}; + int threadIdx_x_len_{1}; + int thread_idx_y_len_{1}; + int thread_idx_z_len_{1}; + int vthread_len_{1}; + int16_t cur_auto_unroll_max_step_{0}; + + // Store touch region information for all for loops. The format of this nested map: + // For a loop, for all its touched buffers, for all different accesses to the buffers, + // its (access type, number of touched elements, number of bytes of single element) + std::unordered_map>>> + for_touch_regions_; + + // The default cache line size in bytes + const int cache_line_size_ = 64; +}; + +// shifted log to incorporate the property that slog(0) = 0 +inline float slog(float x) { return x < 0 ? 
-std::log2(-x + 1) : std::log2(x + 1); } + +void GetPerStoreFeature(const Stmt& stmt, int cache_line_size, int max_n_bufs, + std::vector* ret) { + PerStoreFeatureExtractor extractor(cache_line_size); + extractor(stmt); + + ret->push_back(extractor.buffer_features.size()); + + for (const auto& x : extractor.buffer_features) { + const FeatureSet& fea_set = x.second; + + /***** Group 1: Computation related features *****/ + ret->push_back(slog(fea_set.float_mad)); + ret->push_back(slog(fea_set.float_addsub)); + ret->push_back(slog(fea_set.float_mul)); + ret->push_back(slog(fea_set.float_divmod)); + ret->push_back(slog(fea_set.float_cmp)); + ret->push_back(slog(fea_set.float_math_func)); + ret->push_back(slog(fea_set.float_other_func)); + ret->push_back(slog(fea_set.int_mad)); + ret->push_back(slog(fea_set.int_addsub)); + ret->push_back(slog(fea_set.int_mul)); + ret->push_back(slog(fea_set.int_divmod)); + ret->push_back(slog(fea_set.int_cmp)); + ret->push_back(slog(fea_set.int_math_func)); + ret->push_back(slog(fea_set.int_other_func)); + ret->push_back(slog(fea_set.bool_op)); + ret->push_back(slog(fea_set.select_op)); + + ret->push_back(slog(fea_set.vec_num)); + ret->push_back(slog(fea_set.vec_prod)); + ret->push_back(slog(fea_set.vec_len)); + for (int i = 0; i <= static_cast(AnnotationPosType::kPosMixed); i++) { + ret->push_back(i == static_cast(fea_set.vec_type)); + } + + ret->push_back(slog(fea_set.unroll_num)); + ret->push_back(slog(fea_set.unroll_prod)); + ret->push_back(slog(fea_set.unroll_len)); + for (int i = 0; i <= static_cast(AnnotationPosType::kPosMixed); i++) { + ret->push_back(i == static_cast(fea_set.unroll_type)); + } + + ret->push_back(slog(fea_set.parallel_num)); + ret->push_back(slog(fea_set.parallel_prod)); + ret->push_back(slog(fea_set.parallel_len)); + for (int i = 0; i <= static_cast(AnnotationPosType::kPosMixed); i++) { + ret->push_back(i == static_cast(fea_set.parallel_type)); + } + + ret->push_back(fea_set.is_gpu); + ret->push_back(slog(fea_set.blockIdx_x_len)); + ret->push_back(slog(fea_set.blockIdx_y_len)); + ret->push_back(slog(fea_set.blockIdx_z_len)); + ret->push_back(slog(fea_set.threadIdx_x_len)); + ret->push_back(slog(fea_set.threadIdx_y_len)); + ret->push_back(slog(fea_set.threadIdx_z_len)); + ret->push_back(slog(fea_set.vthread_len)); + + /***** Group 2: Buffer access related features *****/ + // sort according to pair (lines, bytes) + std::vector> buf_order_key; + for (const auto& acc_fea : fea_set.access_feas) { + buf_order_key.emplace_back(acc_fea.lines, acc_fea.bytes); + } + std::vector buf_order(buf_order_key.size()); + std::iota(buf_order.begin(), buf_order.end(), 0); + + auto cmp = [&buf_order_key](int l, int r) { + return buf_order_key[l].first > buf_order_key[r].first || + (buf_order_key[l].first == buf_order_key[r].first && + buf_order_key[l].second > buf_order_key[r].second); + }; + std::sort(buf_order.begin(), buf_order.end(), cmp); + int n_bufs = std::min(max_n_bufs, static_cast(buf_order.size())); + buf_order.resize(n_bufs); + + for (int idx : buf_order) { + const auto& acc_fea = fea_set.access_feas[idx]; + for (int j = 0; j <= static_cast(BufferAccessType::kReadWrite); ++j) { + ret->push_back(j == static_cast(acc_fea.acc_type)); + } + ret->push_back(slog(acc_fea.bytes)); + ret->push_back(slog(acc_fea.unique_bytes)); + ret->push_back(slog(acc_fea.lines)); + ret->push_back(slog(acc_fea.unique_lines)); + for (int j = 0; j <= static_cast(ReuseType::kNoReuse); ++j) { + ret->push_back(j == static_cast(acc_fea.reuse_type)); + } + 
ret->push_back(slog(acc_fea.reuse_dis_iter)); + ret->push_back(slog(acc_fea.reuse_dis_bytes)); + ret->push_back(slog(acc_fea.reuse_ct)); + ret->push_back(slog(acc_fea.bytes_d_reuse_ct)); + ret->push_back(slog(acc_fea.unique_bytes_d_reuse_ct)); + ret->push_back(slog(acc_fea.lines_d_reuse_ct)); + ret->push_back(slog(acc_fea.unique_lines_d_reuse_ct)); + ret->push_back(slog(acc_fea.stride)); + } + // - fill padding + for (int i = 0; i < max_n_bufs - n_bufs; ++i) { + for (int j = 0; j <= static_cast(BufferAccessType::kReadWrite); ++j) { // 3 + ret->push_back(0.0f); + } + ret->push_back(0.0f); + ret->push_back(0.0f); + ret->push_back(0.0f); + ret->push_back(0.0f); + for (int j = 0; j <= static_cast(ReuseType::kNoReuse); ++j) { // 3 + ret->push_back(0.0f); + } + ret->push_back(0.0f); + ret->push_back(0.0f); + ret->push_back(0.0f); + ret->push_back(0.0f); + ret->push_back(0.0f); + ret->push_back(0.0f); + ret->push_back(0.0f); + ret->push_back(0.0f); + } + + /***** Group 3: Arithmetic intensity related features *****/ + for (size_t i = 0; i < ARITH_INTENSITY_CURVE_SAMPLE_N; ++i) { + ret->push_back(fea_set.arith_intensity_curve[i]); + } + + /***** Group 4: Allocation related features *****/ + ret->push_back(slog(fea_set.alloc_size)); + ret->push_back(slog(fea_set.alloc_prod)); + ret->push_back(slog(fea_set.alloc_outer_prod)); + ret->push_back(slog(fea_set.alloc_inner_prod)); + + /***** Group 5: Outer scope related features *****/ + ret->push_back(slog(fea_set.outer_prod)); + ret->push_back(slog(fea_set.num_loops)); + ret->push_back(slog(fea_set.auto_unroll_max_step)); + } +} + +void GetPerStoreFeatureName(int max_n_bufs, std::vector* ret) { + /***** Group 1: Computation related features *****/ + ret->push_back(("float_mad")); + ret->push_back(("float_addsub")); + ret->push_back(("float_mul")); + ret->push_back(("float_divmod")); + ret->push_back(("float_cmp")); + ret->push_back(("float_mathfunc")); + ret->push_back(("float_otherfunc")); + ret->push_back(("int_mad")); + ret->push_back(("int_addsub")); + ret->push_back(("int_mul")); + ret->push_back(("int_divmod")); + ret->push_back(("int_cmp")); + ret->push_back(("int_mathfunc")); + ret->push_back(("int_otherfunc")); + ret->push_back(("bool_op")); + ret->push_back(("select_op")); + ret->push_back(("vec_num")); + ret->push_back(("vec_prod")); + ret->push_back(("vec_len")); + ret->push_back(("vec_type.kPosNone")); + ret->push_back(("vec_type.kPosInnerSpatial")); + ret->push_back(("vec_type.kPosMiddleSpatial")); + ret->push_back(("vec_type.kPosOuterSpatial")); + ret->push_back(("vec_type.kPosInnerReduce")); + ret->push_back(("vec_type.kPosMiddleReduce")); + ret->push_back(("vec_type.kPosOuterReduce")); + ret->push_back(("vec_type.kPosMixed")); + ret->push_back(("unroll_num")); + ret->push_back(("unroll_prod")); + ret->push_back(("unroll_len")); + ret->push_back(("unroll_type.kPosNone")); + ret->push_back(("unroll_type.kPosInnerSpatial")); + ret->push_back(("unroll_type.kPosMiddleSpatial")); + ret->push_back(("unroll_type.kPosOuterSpatial")); + ret->push_back(("unroll_type.kPosInnerReduce")); + ret->push_back(("unroll_type.kPosMiddleReduce")); + ret->push_back(("unroll_type.kPosOuterReduce")); + ret->push_back(("unroll_type.kPosMixed")); + ret->push_back(("parallel_num")); + ret->push_back(("parallel_prod")); + ret->push_back(("parallel_len")); + ret->push_back(("parallel_type.kPosNone")); + ret->push_back(("parallel_type.kPosInnerSpatial")); + ret->push_back(("parallel_type.kPosMiddleSpatial")); + ret->push_back(("parallel_type.kPosOuterSpatial")); + 
ret->push_back(("parallel_type.kPosInnerReduce")); + ret->push_back(("parallel_type.kPosMiddleReduce")); + ret->push_back(("parallel_type.kPosOuterReduce")); + ret->push_back(("parallel_type.kPosMixed")); + ret->push_back(("is_gpu")); + ret->push_back(("blockIdx_x_len")); + ret->push_back(("blockIdx_y_len")); + ret->push_back(("blockIdx_z_len")); + ret->push_back(("threadIdx_x_len")); + ret->push_back(("threadIdx_y_len")); + ret->push_back(("threadIdx_z_len")); + ret->push_back(("vthread_len")); + // section total: 57 + + /***** Group 2: Buffer access related features *****/ + for (size_t i = 0; i < static_cast(max_n_bufs); ++i) { + std::string prefix = "B" + std::to_string(i) + "."; + ret->push_back((prefix + "acc_type.kRead")); + ret->push_back((prefix + "acc_type.kWrite")); + ret->push_back((prefix + "acc_type.kReadWrite")); + ret->push_back((prefix + "bytes")); + ret->push_back((prefix + "unique_bytes")); + ret->push_back((prefix + "lines")); + ret->push_back((prefix + "unique_lines")); + ret->push_back((prefix + "reuse_type.kLoopMultipleRead")); + ret->push_back((prefix + "reuse_type.kSerialMultipleReadWrite")); + ret->push_back((prefix + "reuse_type.kNoReuse")); + ret->push_back((prefix + "reuse_dis_iter")); + ret->push_back((prefix + "reuse_dis_bytes")); + ret->push_back((prefix + "reuse_ct")); + ret->push_back((prefix + "bytes_d_reuse_ct")); + ret->push_back((prefix + "unique_bytes_d_reuse_ct")); + ret->push_back((prefix + "lines_d_reuse_ct")); + ret->push_back((prefix + "unique_lines_d_reuse_ct")); + ret->push_back((prefix + "stride")); + } + // section total : max_n_bufs * 18 + + /***** Group 3: Arithmetic intensity related features *****/ + for (size_t i = 0; i < ARITH_INTENSITY_CURVE_SAMPLE_N; ++i) { + ret->push_back(("arith_intensity_curve_" + std::to_string(i))); + } + // section total: ARITH_INTENSITY_CURVE_SAMPLE_N = 10 + + /***** Group 4: Allocation related features *****/ + ret->push_back(("alloc_size")); + ret->push_back(("alloc_prod")); + ret->push_back(("alloc_outer_prod")); + ret->push_back(("alloc_inner_prod")); + // section total : 4 + + /***** Group 5: Outer scope related features *****/ + ret->push_back(("outer_prod")); + ret->push_back(("num_loops")); + ret->push_back(("auto_unroll_max_step")); + // section total : 3 +} + +void GetPerStoreFeaturesWorkerFunc(const SearchTask& task, const State& state, int max_n_bufs, + std::vector* feature, std::atomic* error_ct) { + te::Schedule sch; + Array tensors; + + std::tie(sch, tensors) = task->compute_dag.ApplySteps(state->transform_steps); + sch = sch.normalize(); + auto bounds = te::InferBound(sch); + + try { + auto stmt = te::ScheduleOps(sch, bounds, false); + Map out_binds; + Array out_arg_list; + bool compact = te::VerifyCompactBuffer(stmt); + const std::string& name = "main"; + GlobalVar global_var(name); + + // Copied from driver_api.cc::lower + auto pass_ctx = tvm::transform::PassContext::Current(); + GetBinds(tensors, compact, std::unordered_map(), &out_binds, + &out_arg_list); + tir::PrimFunc f = te::SchedulePostProcToPrimFunc(out_arg_list, std::move(stmt), out_binds); + f = WithAttr(std::move(f), "global_symbol", runtime::String(name)); + + bool noalias = pass_ctx->GetConfig("tir.noalias", Bool(true)).value(); + bool disable_vectorize = + pass_ctx->GetConfig("tir.disable_vectorize", Bool(false)).value(); + bool instrument_bound_checkers = + pass_ctx->GetConfig("tir.instrument_bound_checkers", Bool(false)).value(); + + if (noalias) { + f = WithAttr(std::move(f), "tir.noalias", Bool(true)); + } + auto mod = 
+
+void GetPerStoreFeaturesWorkerFunc(const SearchTask& task, const State& state, int max_n_bufs,
+                                   std::vector<float>* feature, std::atomic<int>* error_ct) {
+  te::Schedule sch;
+  Array<te::Tensor> tensors;
+
+  std::tie(sch, tensors) = task->compute_dag.ApplySteps(state->transform_steps);
+  sch = sch.normalize();
+  auto bounds = te::InferBound(sch);
+
+  try {
+    auto stmt = te::ScheduleOps(sch, bounds, false);
+    Map<te::Tensor, te::Buffer> out_binds;
+    Array<ObjectRef> out_arg_list;
+    bool compact = te::VerifyCompactBuffer(stmt);
+    const std::string& name = "main";
+    GlobalVar global_var(name);
+
+    // Copied from driver_api.cc::lower
+    auto pass_ctx = tvm::transform::PassContext::Current();
+    GetBinds(tensors, compact, std::unordered_map<te::Tensor, te::Buffer>(), &out_binds,
+             &out_arg_list);
+    tir::PrimFunc f = te::SchedulePostProcToPrimFunc(out_arg_list, std::move(stmt), out_binds);
+    f = WithAttr(std::move(f), "global_symbol", runtime::String(name));
+
+    bool noalias = pass_ctx->GetConfig<Bool>("tir.noalias", Bool(true)).value();
+    bool disable_vectorize =
+        pass_ctx->GetConfig<Bool>("tir.disable_vectorize", Bool(false)).value();
+    bool instrument_bound_checkers =
+        pass_ctx->GetConfig<Bool>("tir.instrument_bound_checkers", Bool(false)).value();
+
+    if (noalias) {
+      f = WithAttr(std::move(f), "tir.noalias", Bool(true));
+    }
+    auto mod = IRModule(Map<GlobalVar, BaseFunc>({{global_var, f}}));
+
+    if (task->target->kind->device_type == kDLGPU) {
+      auto pass_list = Array<tvm::transform::Pass>();
+      // Phase 0
+      pass_list.push_back(tir::transform::InjectPrefetch());
+      pass_list.push_back(tir::transform::StorageFlatten(64, instrument_bound_checkers));
+      // Phase 1
+      pass_list.push_back(tir::transform::NarrowDataType(32));
+      pass_list.push_back(tir::transform::Simplify());
+      pass_list.push_back(tir::transform::VectorizeLoop(!disable_vectorize));
+      pass_list.push_back(tir::transform::InjectVirtualThread());
+      pass_list.push_back(tir::transform::StorageRewrite());
+      pass_list.push_back(tir::transform::Simplify());
+      tvm::Map<String, PrimExpr> gpu_params{
+          {"max_shared_memory_per_block", task->hardware_params->max_shared_memory_per_block},
+          {"max_local_memory_per_block", task->hardware_params->max_registers_per_block},
+          {"max_threads_per_block", task->hardware_params->max_threads_per_block},
+          {"max_vector_bytes", task->hardware_params->vector_unit_bytes},
+          {"max_vthread", task->hardware_params->max_vthread_extent},
+      };
+      pass_list.push_back(tir::transform::VerifyGPUCode(gpu_params));
+      const auto& optimize = tir::transform::Sequential(pass_list);
+      // The result is discarded; this pipeline is run only so VerifyGPUCode can
+      // throw on invalid schedules, which is counted in error_ct below.
+      optimize(mod);
+    }
+    const auto& optimize =
+        tir::transform::Sequential(Array<tvm::transform::Pass>{tir::transform::Simplify()});
+    mod = optimize(std::move(mod));
+    const auto& it = mod->functions.find(global_var);
+    CHECK(it != mod->functions.end());
+    const auto& prim_func = (*it).second.as<PrimFuncNode>();
+    GetPerStoreFeature(prim_func->body, task->hardware_params->cache_line_bytes, max_n_bufs,
+                       feature);
+  } catch (dmlc::Error& e) {
+    (*error_ct)++;
+  }
+}
+
+void GetPerStoreFeaturesFromStates(const Array<State>& states, const SearchTask& task,
+                                   int skip_first_n_feature_extraction, int max_n_bufs,
+                                   std::vector<std::vector<float>>* features) {
+  // extract features
+  features->assign(states.size(), std::vector<float>());
+
+  std::atomic<int> error_ct(0);
+
+  for (size_t i = skip_first_n_feature_extraction; i < states.size(); ++i) {
+    GetPerStoreFeaturesWorkerFunc(task, states[i], max_n_bufs, &(*features)[i], &error_ct);
+  }
+
+  if (error_ct > 0) {
+    std::cerr << "Encountered " << error_ct
+              << " errors during feature extraction, which are safely ignored." << std::endl;
+  }
+}
+
+void GetPerStoreFeaturesFromStates(const Array<State>& states, const std::vector<SearchTask>& tasks,
+                                   int skip_first_n_feature_extraction, int max_n_bufs,
+                                   std::vector<std::vector<float>>* features) {
+  // extract features
+  features->assign(states.size(), std::vector<float>());
+
+  std::atomic<int> error_ct(0);
+
+  for (size_t i = skip_first_n_feature_extraction; i < states.size(); ++i) {
+    GetPerStoreFeaturesWorkerFunc(tasks[i], states[i], max_n_bufs, &(*features)[i], &error_ct);
+  }
+
+  if (error_ct > 0) {
+    std::cerr << "Encountered " << error_ct
+              << " errors during feature extraction, which are safely ignored." << std::endl;
+  }
+}
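From Python, the path above is reached through auto_scheduler.feature.get_per_store_features_from_states, as the new unit tests do. A hedged usage sketch; matmul_auto_scheduler_test comes from the existing test helpers, and a state whose extraction failed simply comes back as an empty row:

    import tvm
    from tvm import auto_scheduler
    from test_auto_scheduler_common import matmul_auto_scheduler_test

    dag = auto_scheduler.ComputeDAG(matmul_auto_scheduler_test(512, 512, 512))
    task = auto_scheduler.SearchTask(dag, "test", tvm.target.create('llvm'))

    # one entry per state; each entry holds one feature row per BufferStoreNode
    fea = auto_scheduler.feature.get_per_store_features_from_states(
        [dag.get_init_state()], task)[0]
    names = auto_scheduler.feature.get_per_store_feature_names()
    for row in fea:
        assert len(row) == len(names)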
+
+void GetPerStoreFeaturesFromFile(const std::string& filename, int max_lines, int max_n_bufs,
+                                 std::vector<std::vector<float>>* features,
+                                 std::vector<float>* normalized_throughputs,
+                                 std::vector<int>* task_ids) {
+  Array<State> states;
+  std::vector<SearchTask> tasks;
+
+  normalized_throughputs->clear();
+  task_ids->clear();
+
+  // (workload_key, target) -> (search_task, task_id)
+  std::unordered_map<std::pair<std::string, std::string>, std::pair<SearchTask, size_t>>
+      task_cache;
+  // task_id -> min_cost
+  std::vector<float> min_costs;
+
+  const auto* workload_key_to_tensors =
+      tvm::runtime::Registry::Get("auto_scheduler.workload_key_to_tensors");
+  CHECK(workload_key_to_tensors != nullptr);
+
+  // read from file
+  RecordReader reader(filename);
+  auto cur_inp = make_object<MeasureInputNode>();
+  auto cur_res = make_object<MeasureResultNode>();
+  while (reader->ReadNext(cur_inp.get(), cur_res.get())) {
+    float cost = static_cast<float>(FloatArrayMean(cur_res->costs));
+    const std::string& workload_key = cur_inp->task->workload_key;
+
+    SearchTask task;
+    size_t task_id;
+    std::pair<std::string, std::string> key(workload_key, cur_inp->task->target->str());
+    auto find_res = task_cache.find(key);
+    if (find_res == task_cache.end()) {
+      // rebuild task
+      Array<te::Tensor> tensors = (*workload_key_to_tensors)(workload_key);
+      task = SearchTask(ComputeDAG(tensors), workload_key, cur_inp->task->target,
+                        cur_inp->task->target_host, cur_inp->task->hardware_params);
+      task_id = task_cache.size();
+
+      // compute min cost for each task
+      task_cache.insert(std::make_pair(key, std::make_pair(task, task_id)));
+      min_costs.push_back(cost);
+    } else {
+      std::tie(task, task_id) = find_res->second;
+      min_costs[task_id] = std::min(min_costs[task_id], cost);
+    }
+
+    tasks.push_back(std::move(task));
+    task_ids->push_back(task_id);
+    states.push_back(cur_inp->state);
+    normalized_throughputs->push_back(cost);
+
+    if (max_lines > 0 && static_cast<int>(states.size()) >= max_lines) {
+      break;
+    }
+  }
+
+  for (size_t i = 0; i < normalized_throughputs->size(); ++i) {
+    (*normalized_throughputs)[i] = min_costs[(*task_ids)[i]] / (*normalized_throughputs)[i];
+  }
+
+  GetPerStoreFeaturesFromStates(states, tasks, 0, max_n_bufs, features);
+}
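Both loaders normalize costs the same way: each record's label becomes min_cost(task) / cost, so the fastest record of every task scores 1.0 and slower records decay toward 0. The arithmetic on hypothetical costs:

    costs = [0.020, 0.010, 0.040]  # hypothetical measured costs (seconds) of one task
    min_cost = min(costs)
    normalized = [min_cost / c for c in costs]
    assert normalized == [0.5, 1.0, 0.25]  # the best record maps to 1.0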
+
+void GetPerStoreFeaturesFromMeasurePairs(const Array<MeasureInput>& inputs,
+                                         const Array<MeasureResult>& results,
+                                         int skip_first_n_feature_extraction, int max_n_bufs,
+                                         std::vector<std::vector<float>>* features,
+                                         std::vector<float>* normalized_throughputs,
+                                         std::vector<int>* task_ids) {
+  Array<State> states;
+  std::vector<SearchTask> tasks;
+
+  normalized_throughputs->clear();
+  task_ids->clear();
+
+  // (workload_key, target) -> (search_task, task_id)
+  std::unordered_map<std::pair<std::string, std::string>, std::pair<SearchTask, size_t>>
+      task_cache;
+  // task_id -> min_cost
+  std::vector<float> min_costs;
+
+  const auto* workload_key_to_tensors =
+      tvm::runtime::Registry::Get("auto_scheduler.workload_key_to_tensors");
+  CHECK(workload_key_to_tensors != nullptr);
+
+  tasks.reserve(inputs.size());
+  normalized_throughputs->reserve(inputs.size());
+  task_ids->reserve(inputs.size());
+  for (size_t i = 0; i < inputs.size(); ++i) {
+    float cost = static_cast<float>(FloatArrayMean(results[i]->costs));
+    const std::string& workload_key = inputs[i]->task->workload_key;
+    SearchTask task;
+
+    size_t task_id;
+    std::pair<std::string, std::string> key(workload_key, inputs[i]->task->target->str());
+    auto find_res = task_cache.find(key);
+    if (find_res == task_cache.end()) {
+      if (inputs[i]->task->compute_dag.defined()) {  // the measure input is complete
+        task = inputs[i]->task;
+      } else {  // the measure input is incomplete
+        // rebuild task for incomplete measure pairs read from file
+        Array<te::Tensor> tensors = (*workload_key_to_tensors)(workload_key);
+        task = SearchTask(ComputeDAG(tensors), workload_key, inputs[i]->task->target,
+                          inputs[i]->task->target_host, inputs[i]->task->hardware_params);
+      }
+      task_id = task_cache.size();
+
+      // compute min cost for each task
+      task_cache.insert(std::make_pair(key, std::make_pair(task, task_id)));
+      min_costs.push_back(cost);
+    } else {
+      std::tie(task, task_id) = find_res->second;
+      min_costs[task_id] = std::min(min_costs[task_id], cost);
+    }
+
+    tasks.push_back(std::move(task));
+    task_ids->push_back(task_id);
+    states.push_back(inputs[i]->state);
+    normalized_throughputs->push_back(cost);
+  }
+
+  for (size_t i = 0; i < normalized_throughputs->size(); ++i) {
+    (*normalized_throughputs)[i] = min_costs[(*task_ids)[i]] / (*normalized_throughputs)[i];
+  }
+
+  GetPerStoreFeaturesFromStates(states, tasks, skip_first_n_feature_extraction, max_n_bufs,
+                                features);
+}
+
+/*
+ * \brief Serialize a two-dimensional variable-size feature vector with normalized throughputs
+ * and task ids to a one-dimensional flattened byte array.
+ * We serialize the data so it can be transferred to Python quickly; the flattened
+ * array is deserialized on the Python side.
+ *
+ * serialization format for n records:
+ *
+ * int n;
+ * int sizes[n+2];
+ *
+ * float[sizes[0]]    feature for record 1
+ * float[sizes[1]]    feature for record 2
+ * ...                feature for record i...
+ * float[sizes[n-1]]  feature for record n
+ *
+ * float[sizes[n]]    normalized throughputs for n records
+ * int[sizes[n+1]]    task ids for n records
+ */
+TVMByteArray SerializeFeatures(std::vector<std::vector<float>>&& features,
+                               std::vector<float>&& normalized_throughputs,
+                               std::vector<int>&& task_ids, std::vector<char>* out_data) {
+  size_t total_bytes = 0;
+  std::vector<int> size_vector;
+
+  int n = static_cast<int>(features.size());
+
+  // serialize sizes
+  size_t size_vector_size = 1 + n + 2;
+  total_bytes += size_vector_size * sizeof(int);
+
+  size_vector.reserve(size_vector_size);
+  size_vector.push_back(n);
+  for (const auto& x : features) {
+    size_vector.push_back(static_cast<int>(x.size()));
+    total_bytes += sizeof(float) * x.size();
+  }
+  size_vector.push_back(static_cast<int>(normalized_throughputs.size()));
+  total_bytes += sizeof(float) * normalized_throughputs.size();
+  size_vector.push_back(static_cast<int>(task_ids.size()));
+  total_bytes += sizeof(int) * task_ids.size();
+
+  CHECK_EQ(size_vector.size(), size_vector_size);
+
+  // allocate memory (resize, not reserve: we write through data())
+  out_data->resize(total_bytes);
+  char* ptr = out_data->data();
+
+  // serialize size_vector
+  memmove(ptr, reinterpret_cast<char*>(size_vector.data()), size_vector.size() * sizeof(int));
+  ptr += size_vector.size() * sizeof(int);
+
+  // serialize features
+  for (auto& x : features) {
+    memmove(ptr, x.data(), sizeof(float) * x.size());
+    ptr += sizeof(float) * x.size();
+    x.clear();
+  }
+
+  // serialize normalized_throughputs
+  memmove(ptr, reinterpret_cast<char*>(normalized_throughputs.data()),
+          normalized_throughputs.size() * sizeof(float));
+  ptr += normalized_throughputs.size() * sizeof(float);
+
+  // serialize task_ids
+  memmove(ptr, reinterpret_cast<char*>(task_ids.data()), task_ids.size() * sizeof(int));
+  ptr += task_ids.size() * sizeof(int);
+
+  CHECK_EQ(ptr - out_data->data(), total_bytes);
+
+  return TVMByteArray{out_data->data(), total_bytes};
+}
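The layout documented above can be unpacked on the Python side roughly as follows. This is a sketch assuming little-endian 4-byte ints and floats (which is what the sizeof arithmetic in SerializeFeatures produces on common platforms), not the actual deserializer shipped in feature.py:

    import struct
    import numpy as np

    def unpack_features(byte_arr):
        # header: record count n, then n + 2 size entries
        (n,) = struct.unpack_from("<i", byte_arr, 0)
        sizes = np.frombuffer(byte_arr, dtype="<i4", count=n + 2, offset=4)
        offset = 4 * (1 + n + 2)

        # variable-length feature vector per record
        features = []
        for size in sizes[:n]:
            features.append(np.frombuffer(byte_arr, dtype="<f4", count=int(size), offset=offset))
            offset += 4 * int(size)

        throughputs = np.frombuffer(byte_arr, dtype="<f4", count=int(sizes[n]), offset=offset)
        offset += 4 * int(sizes[n])
        task_ids = np.frombuffer(byte_arr, dtype="<i4", count=int(sizes[n + 1]), offset=offset)
        return features, throughputs, task_ids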
+
+TVM_REGISTER_GLOBAL("auto_scheduler.GetPerStoreFeaturesFromFile")
+    .set_body([](TVMArgs args, TVMRetValue* ret) {
+      std::string filename = args[0];
+      int max_lines = args[1];
+      int max_n_bufs = args[2];
+
+      std::vector<std::vector<float>> features;
+      std::vector<float> normalized_throughputs;
+      std::vector<int> task_ids;
+
+      GetPerStoreFeaturesFromFile(filename, max_lines, max_n_bufs, &features,
+                                  &normalized_throughputs, &task_ids);
+
+      std::vector<char> byte_data;
+      *ret = SerializeFeatures(std::move(features), std::move(normalized_throughputs),
+                               std::move(task_ids), &byte_data);
+    });
+
+TVM_REGISTER_GLOBAL("auto_scheduler.GetPerStoreFeaturesFromMeasurePairs")
+    .set_body([](TVMArgs args, TVMRetValue* ret) {
+      Array<MeasureInput> inputs = args[0];
+      Array<MeasureResult> results = args[1];
+      int skip_first_n_feature_extraction = args[2];
+      int max_n_bufs = args[3];
+
+      std::vector<std::vector<float>> features;
+      std::vector<float> normalized_throughputs;
+      std::vector<int> task_ids;
+
+      GetPerStoreFeaturesFromMeasurePairs(inputs, results, skip_first_n_feature_extraction,
+                                          max_n_bufs, &features, &normalized_throughputs,
+                                          &task_ids);
+
+      std::vector<char> byte_data;
+      *ret = SerializeFeatures(std::move(features), std::move(normalized_throughputs),
+                               std::move(task_ids), &byte_data);
+    });
+
+TVM_REGISTER_GLOBAL("auto_scheduler.GetPerStoreFeaturesFromStates")
+    .set_body([](TVMArgs args, TVMRetValue* ret) {
+      Array<State> states = args[0];
+      SearchTask task = args[1];
+      int max_n_bufs = args[2];
+
+      std::vector<std::vector<float>> features;
+      std::vector<float> normalized_throughputs;
+      std::vector<int> task_ids;
+
+      GetPerStoreFeaturesFromStates(states, task, 0, max_n_bufs, &features);
+
+      std::vector<char> byte_data;
+      *ret = SerializeFeatures(std::move(features), std::move(normalized_throughputs),
+                               std::move(task_ids), &byte_data);
+    });
+
+TVM_REGISTER_GLOBAL("auto_scheduler.GetPerStoreFeatureNames")
+    .set_body([](TVMArgs args, TVMRetValue* ret) {
+      int max_n_bufs = args[0];
+      std::vector<std::string> names;
+
+      GetPerStoreFeatureName(max_n_bufs, &names);
+
+      Array<String> arr;
+      for (const auto& x : names) {
+        arr.push_back(x);
+      }
+      *ret = arr;
+    });
+
+}  // namespace auto_scheduler
+}  // namespace tvm
diff --git a/tests/python/unittest/test_auto_scheduler_compute_dag.py b/tests/python/unittest/test_auto_scheduler_compute_dag.py
index 2530d554e8ee..6b76dc607da9 100644
--- a/tests/python/unittest/test_auto_scheduler_compute_dag.py
+++ b/tests/python/unittest/test_auto_scheduler_compute_dag.py
@@ -44,13 +44,17 @@ def test_estimate_flop():
     D = topi.nn.relu(C)
     dag = auto_scheduler.ComputeDAG([A, B, D])
-    assert abs(dag.flop_ct - 2 * N ** 3 - N * N) < 0.5
+    assert abs(dag.flop_ct - (2 * N ** 3 + N * N)) < 0.5
 
     # should not count the comparison operations in padding
-    D = topi.nn.pad(C, [1, 1])
-    dag = auto_scheduler.ComputeDAG([A, B, D])
+    E = topi.nn.pad(C, [1, 1])
+    dag = auto_scheduler.ComputeDAG([A, B, E])
     assert abs(dag.flop_ct - 2 * N ** 3) < 0.5
 
+    F = te.compute((N, N), lambda i, j: E[i, j], name='F', attrs={"FLOP": 1234})
+    dag = auto_scheduler.ComputeDAG([A, B, F])
+    assert abs(dag.flop_ct - (2 * N ** 3 + 1234)) < 0.5
+
 
 if __name__ == "__main__":
     test_apply_steps()
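The new F assertion exercises the manual override path: an op can pin its flop contribution via a "FLOP" attr instead of having it estimated from its body. A minimal standalone sketch of the same mechanism, with an arbitrary value:

    from tvm import te, auto_scheduler

    N = 8
    A = te.placeholder((N, N), name='A')
    # "FLOP" overrides the estimate for this op; 42 is an arbitrary value
    B = te.compute((N, N), lambda i, j: A[i, j], name='B', attrs={"FLOP": 42})
    dag = auto_scheduler.ComputeDAG([A, B])
    assert abs(dag.flop_ct - 42) < 0.5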
diff --git a/tests/python/unittest/test_auto_scheduler_feature.py b/tests/python/unittest/test_auto_scheduler_feature.py
new file mode 100644
index 000000000000..05f1cbb8641c
--- /dev/null
+++ b/tests/python/unittest/test_auto_scheduler_feature.py
@@ -0,0 +1,197 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Test feature extraction"""
+
+import math
+import tempfile
+
+import tvm
+from tvm import te, auto_scheduler
+
+from test_auto_scheduler_common import matmul_auto_scheduler_test
+
+
+def fequal(a, b):
+    return math.fabs(a - b) < 1e-6
+
+
+def test_cpu_matmul():
+    dag = auto_scheduler.ComputeDAG(matmul_auto_scheduler_test(512, 512, 512))
+    s = dag.get_init_state()
+    C = s.stage_ops[2]
+
+    i, j, k = s[C].iters
+    io, ii = s.split(C, i, [16])
+    jo, ji = s.split(C, j, [8])
+    s.reorder(C, [io, jo, k, ji, ii])
+    s.vectorize(C, ji)
+    s.parallel(C, io)
+    s.parallel(C, jo)
+    s.unroll(C, k)
+
+    target = tvm.target.create('llvm')
+    task = auto_scheduler.SearchTask(dag, "test", target)
+    names = auto_scheduler.feature.get_per_store_feature_names()
+    fea = auto_scheduler.feature.get_per_store_features_from_states([s], task)[0]
+
+    stage_0 = fea[0]
+    assert len(stage_0) == len(names), "%d vs %d" % (len(stage_0), len(names))
+    fea_dict = {}
+    for name, value in zip(names, stage_0):
+        fea_dict[name] = value
+
+    for name in ["B0", "B1", "B2"]:
+        if fequal(fea_dict[name + ".acc_type.kReadWrite"], 1.0):
+            c_name = name
+        if fequal(fea_dict[name + ".acc_type.kRead"], 1.0):
+            if fequal(fea_dict[name + ".stride"], 0.0):
+                b_name = name
+            else:
+                a_name = name
+
+    """
+    lowered IR:
+
+    Placeholder: A, B
+    parallel i.0 (0,32)
+      parallel j.0 (0,64)
+        unroll k (0,512)
+          vectorize j.1 (0,8)
+            for i.1 (0,16)
+              C[...] = A[...] * B[...]
+    """
+
+    # check touched memory in bytes, touched unique memory in bytes, reuse distance, etc.
+    assert fequal(fea_dict[c_name + ".bytes"], math.log2(512 ** 3 * 4 + 1))
+    assert fequal(fea_dict[b_name + ".unique_bytes"], math.log2(512 ** 2 * 4 + 1))
+    assert fequal(fea_dict[c_name + ".reuse_dis_iter"], math.log2(8 * 16 + 1))
+    assert fequal(fea_dict[c_name + ".reuse_dis_bytes"], math.log2((8 * 16 + 8 + 16) * 4 + 1))
+    assert fequal(fea_dict[c_name + ".reuse_ct"], math.log2(512 + 1))
+
+    # check annotations
+    assert fequal(fea_dict["unroll_num"], math.log2(1 + 1))
+    # assert fequal(fea_dict["unroll_type.kPosInnerReduce"], 1.0)
+    assert fequal(fea_dict["vec_num"], math.log2(1 + 1))
+    assert fequal(fea_dict["parallel_num"], math.log2(2 + 1))
+    assert fequal(fea_dict["parallel_prod"], math.log2((512 * 512 / 16 / 8) + 1))
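The expected constants in test_cpu_matmul fall straight out of the loop nest in the docstring. Re-deriving them in slog space (a worked check, not part of the test):

    import math

    # C is stored on every (i, j, k) iteration: 512**3 stores of 4-byte floats
    print(math.log2(512 ** 3 * 4 + 1))           # c_name + ".bytes"
    # between two stores of the same C element, the 8 x 16 (j.1 x i.1) tile passes
    print(math.log2(8 * 16 + 1))                 # c_name + ".reuse_dis_iter"
    # that tile of C plus an 8-wide strip of B and a 16-long strip of A, in bytes
    print(math.log2((8 * 16 + 8 + 16) * 4 + 1))  # c_name + ".reuse_dis_bytes"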
+ """ + + # check reuse distance and reuse type after fusion + found = False + for stage_fea in fea: + for i, (name, value) in enumerate(zip(names, stage_fea)): + if 'reuse_type.kSerialMultipleReadWrite' in name and value > 0.5: + assert fequal(stage_fea[i + 2], 1.0) # reuse distance in #iter + assert fequal(stage_fea[i + 3], math.log2(16 + 1)) # reuse distance in bytes + found = True + assert found + + +def test_gpu_feature(): + # Use records to build a complicated GPU program + json_records = "\n".join(( + """{"i": [["[\\"matmul_auto_scheduler_test\\", 512, 512, 512]", "cuda"], [[], [["CHW", 2, "local"], ["SP", 2, 0, 512, [1, 16, 32, 1], 1], ["SP", 2, 5, 512, [4, 1, 1, 16], 1], ["SP", 2, 10, 512, [1, 2], 1], ["RE", 2, [0, 5, 1, 6, 2, 7, 10, 11, 3, 8, 12, 4, 9]], ["FSP", 3, 0, 1, 3], ["FSP", 3, 4, 2, 3], ["RE", 3, [0, 4, 1, 5, 2, 6, 3, 7]], ["FU", 2, [0, 1]], ["FU", 3, [0, 1]], ["FU", 2, [1, 2]], ["FU", 3, [1, 2]], ["FU", 2, [2, 3]], ["FU", 3, [2, 3]], ["CA", 2, 3, 2], ["CHR", 1, "shared", [2]], ["CA", 2, 3, 3], ["FU", 2, [0, 1]], ["FFSP", 2, 0, [1, 2], 1, 1], ["AN", 2, 1, 6], ["CHR", 0, "shared", [3]], ["CA", 1, 4, 3], ["FU", 1, [0, 1]], ["FFSP", 1, 0, [1, 2], 1, 1], ["AN", 1, 1, 6], ["AN", 5, 0, 5], ["AN", 5, 1, 4], ["AN", 5, 2, 6], ["PR", 4, 0, "auto_unroll_max_step$1024"]]]], "r": [[0.00536798], 0, 2.49277, 1585564852], "v": "v0.1"}""", + )) + + # load states + with tempfile.NamedTemporaryFile(mode='w') as f: + f.write(json_records) + f.flush() + inputs, results = auto_scheduler.RecordReader(f.name).read_lines() + + inp = inputs[0] + dag = auto_scheduler.ComputeDAG(inp.task.workload_key) + task = auto_scheduler.SearchTask(dag, inp.task.workload_key, inp.task.target, None, auto_scheduler.HardwareParams(100000, 16, 64)) + + state = dag.infer_bound_from_state(inputs[0].state) + fea = auto_scheduler.feature.get_per_store_features_from_states([state], task)[0] + names = auto_scheduler.feature.get_per_store_feature_names() + + # build feature dict + fea_dicts = [] + for i in range(len(fea)): + tmp_dict = {} + for j in range(len(names)): + tmp_dict[names[j]] = fea[i][j] + fea_dicts.append(tmp_dict) + + """ + lowered IR: + + Placeholder: A, B + blockIdx.x i.0@j.0@ (0,8) + vthread i.1@j.1@ (0,4) + threadIdx.x i.2@j.2@ (0,16) + C.local auto_unroll: 1024 + for k.0 (0,256) + for ax0@ax1@.0 (0,8) + threadIdx.x ax0@ax1@.1 (0,16) + B.shared = ... + for ax0@ax1@.0 (0,64) + threadIdx.x ax0@ax1@.1 (0,16) + A.shared = ... + for i_c.3 (0,32) + for k.2 (0,2) + for j_c.4 (0,16) + C.local = ... + for i.3 (0,32) + for j.3 (0,16) + C = ... + """ + + # check gpu-related features + assert fequal(fea_dicts[0]['blockIdx_x_len'], math.log2(8 + 1)) + assert fequal(fea_dicts[0]['vthread_len'], math.log2(4 + 1)) + assert fequal(fea_dicts[1]['threadIdx_x_len'], math.log2(16 + 1)) + assert fequal(fea_dicts[0]['threadIdx_y_len'], math.log2(1 + 1)) + assert fequal(fea_dicts[2]['blockIdx_z_len'], math.log2(1 + 1)) + assert fequal(fea_dicts[0]['is_gpu'], 1.0) + + +if __name__ == "__main__": + test_cpu_matmul() + test_cpu_fusion() + test_gpu_feature()