Skip to content

Commit

Permalink
[Unity][Pass][TuningAPI] Introduce TuningAPI and MetaSchedule pass (#…
Browse files Browse the repository at this point in the history
…14014)

Add TuningAPI and MetaSchedule tuning pass
  • Loading branch information
sunggg authored Feb 17, 2023
1 parent ba47501 commit 9e36bb1
Show file tree
Hide file tree
Showing 29 changed files with 3,987 additions and 47 deletions.
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,7 @@ tvm_file_glob(GLOB_RECURSE COMPILER_SRCS
src/relax/analysis/*.cc
src/relax/transform/*.cc
src/relax/backend/vm/*.cc
src/relax/backend/task_extraction.cc
src/relax/utils.cc
)

Expand Down
54 changes: 44 additions & 10 deletions include/tvm/ir/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,18 +32,18 @@
* - Reducing the effort required to implement new passes for compiler
* developers, etc.
*
* Similar to LLVM's pass manager, we designed the Relay pass manager to work
* Similar to LLVM's pass manager, we designed the Relay/Relax pass manager to work
* different granularity, i.e. module level, function level, and even sequential
* passe that contains a host of passes.
*
* However, we also extend the functionality of the traditional pass manager
* with the consideration of requirements/convention from deep learning
* frameworks, such as Pytorch and Gluon, etc. Each pass in the Relay pass
* frameworks, such as Pytorch and Gluon, etc. Each pass in the Relay/Relax pass
* manager performs the IRModule -> IRModule transformation. All
* different types of passes, including the sequential-level pass object, are
* essentially pass objects. This design, therefore, effectively provides users
* a consistent and convenient interface, i.e. Pass, to play with. It offers a
* means to ease the development and testing of Relay passes. For example, with
* means to ease the development and testing of Relay/Relax passes. For example, with
* the pass manager, external users will be able to have custom passes correctly
* scheduled without having to modify a single handcrafted pass order.
*
Expand Down Expand Up @@ -90,7 +90,16 @@ class PassContextNode : public Object {

/*! \brief A list of pass instrument implementations. */
Array<instrument::PassInstrument> instruments;

// TODO(@sunggg): Fix dependency issue in the header file and correct the types
// e.g., relax::trace, relax::database in tvm/relax/tuning_api.h
/*! \brief Trace stack for relax pass infra. */
mutable Array<ObjectRef> trace_stack;
/*! \brief List of passes to be traced. If not defined, make every pass traceable. */
Optional<Map<String, Bool>> make_traceable;
/*! \brief Number of evaluations conducted in the pass pipeline. */
mutable int num_evals{0};
/*! \brief Database for tuning API. */
Optional<ObjectRef> tuning_api_database;
PassContextNode() = default;

/*!
Expand Down Expand Up @@ -130,7 +139,27 @@ class PassContextNode : public Object {
v->Visit("instruments", &instruments);
v->Visit("config", &config);
v->Visit("diag_ctx", &diag_ctx);
v->Visit("trace_stack", &trace_stack);
v->Visit("make_traceable", &make_traceable);
v->Visit("num_evals", &num_evals);
v->Visit("tuning_api_daatabase", &tuning_api_database);
}

Array<ObjectRef> GetTraceStack() { return trace_stack; }
void PushTrace(ObjectRef new_trace) { trace_stack.push_back(new_trace); }
void PopTrace() {
ICHECK(GetTraceStackSize()) << "Trace stack is currently empty. Please double check.";
trace_stack.pop_back();
}
int GetTraceStackSize() { return trace_stack.size(); }
ObjectRef GetCurrentTrace() {
ICHECK(GetTraceStackSize()) << "Trace stack is currently empty. Please double check.";
return trace_stack.back();
}
void SetNumEvals(int _num_evals) { num_evals = _num_evals; }
void IncNumEvals(int _num_evals) { num_evals += _num_evals; }

Optional<ObjectRef> GetTuningAPIDatabase() { return tuning_api_database; }

static constexpr const char* _type_key = "transform.PassContext";
static constexpr bool _type_has_method_sequal_reduce = false;
Expand Down Expand Up @@ -287,6 +316,9 @@ class PassInfoNode : public Object {
/*! \brief The name of an optimization/analysis pass. */
String name;

/*! \brief Boolean that tells whether this pass will be traced or not. */
bool traceable;

/*! \brief The passes that are required to perform the current pass. */
Array<String> required;

Expand All @@ -296,6 +328,7 @@ class PassInfoNode : public Object {
v->Visit("opt_level", &opt_level);
v->Visit("name", &name);
v->Visit("required", &required);
v->Visit("traceable", &traceable);
}

static constexpr const char* _type_key = "transform.PassInfo";
Expand All @@ -314,16 +347,17 @@ class PassInfo : public ObjectRef {
* \param opt_level The optimization level
* \param name Name of the pass.
* \param required The passes that are required to perform the current pass.
* \param traceable Boolean that tells whether the pass is traceable.
*/
TVM_DLL PassInfo(int opt_level, String name, Array<runtime::String> required);
TVM_DLL PassInfo(int opt_level, String name, Array<runtime::String> required, bool traceable);

TVM_DEFINE_OBJECT_REF_METHODS(PassInfo, ObjectRef, PassInfoNode);
};

/*!
* \brief PassNode is the base type of differnt types of optimization passes.
* It is designed as a pure class and implemented by different pass subclasses
* at different granularity of Relay nodes.
* at different granularity of Relay/Relax nodes.
*/
class PassNode : public Object {
public:
Expand Down Expand Up @@ -396,7 +430,7 @@ class Pass : public ObjectRef {
};

/*!
* \brief The SequentialNode contains a set of passes that transform Relay
* \brief The SequentialNode contains a set of passes that transform Relay/Relax
* programs from one AST to another semantically equivalent one.
*
* One example of this level of pass is that the pass manager needs to correctly
Expand Down Expand Up @@ -489,9 +523,9 @@ class Sequential : public Pass {
*
* \return The created module pass.
*/
TVM_DLL Pass
CreateModulePass(const runtime::TypedPackedFunc<IRModule(IRModule, PassContext)>& pass_func,
int opt_level, String name, Array<runtime::String> required);
TVM_DLL Pass CreateModulePass(
const runtime::TypedPackedFunc<IRModule(IRModule, PassContext)>& pass_func, int opt_level,
String name, Array<runtime::String> required, bool traceable = false);

/*!
* \brief A special trace pass that prints the header and IR to LOG(INFO).
Expand Down
22 changes: 20 additions & 2 deletions include/tvm/relax/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,13 @@ using DataflowBlock = tvm::relax::DataflowBlock;
* \param opt_level The optimization level of the function pass.
* \param name The name of the function pass.
* \param required The list of the passes that the function pass is dependent on.
* \param traceable Boolean variable whether the dataflowblock pass is traceable.
*
* \return The created function pass.
*/
TVM_DLL Pass CreateFunctionPass(
const runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)>& pass_func,
int opt_level, String name, tvm::Array<String> required);
int opt_level, String name, tvm::Array<String> required, bool traceable = false);

/*!
* \brief Create a dataflowblock pass.
Expand All @@ -58,12 +59,13 @@ TVM_DLL Pass CreateFunctionPass(
* \param opt_level The optimization level of the dataflowblock pass.
* \param name The name of the dataflowblock pass.
* \param required The list of the passes that the dataflowblock pass is dependent on.
* \param traceable Boolean variable whether the dataflowblock pass is traceable.
*
* \return The created dataflowblock pass.
*/
TVM_DLL Pass CreateDataflowBlockPass(
const runtime::TypedPackedFunc<DataflowBlock(DataflowBlock, IRModule, PassContext)>& pass_func,
int opt_level, String name, tvm::Array<String> required);
int opt_level, String name, tvm::Array<String> required, bool traceable = false);

/*!
* \brief Transform all dataflow structure to non-dataflow version.
Expand Down Expand Up @@ -93,6 +95,22 @@ TVM_DLL Pass CallTIRRewrite();
*/
TVM_DLL Pass RewriteDataflowReshape();

/*!
* \brief Bind params of function of the module to constant tensors.
*
* \param func_name The name of the function to bind parameters.
* \param params The parameters to bind.
*
* \return The Pass.
*/
TVM_DLL Pass BindParams(String func_name, Map<String, runtime::NDArray> params);

/*!
* \brief Fold constant expressions.
*
* \return The Pass.
*/
TVM_DLL Pass FoldConstant();
/*!
* \brief Attach global_symbol to Relax functions and TIR Primfuncs for codegen.
*
Expand Down
Loading

0 comments on commit 9e36bb1

Please sign in to comment.