Skip to content

Commit

Permalink
[Relay] Start porting pass to the pass manager (apache#3191)
Browse files Browse the repository at this point in the history
  • Loading branch information
MarisaKirisame authored and Wei Chen committed Jun 26, 2019
1 parent 51a2d64 commit f8c4cb9
Show file tree
Hide file tree
Showing 12 changed files with 328 additions and 69 deletions.
139 changes: 87 additions & 52 deletions include/tvm/relay/pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/type.h>
#include <tvm/relay/adt.h>
#include <tvm/relay/transform.h>
#include <tvm/runtime/vm.h>
#include <string>
#include <vector>
Expand Down Expand Up @@ -84,7 +85,8 @@ TVM_DLL Function InferType(const Function& f, const Module& mod,
*/
TVM_DLL Kind KindCheck(const Type& t, const Module& mod);

/*! \brief Compare two expressions for structural equivalence.
/*!
* \brief Compare two expressions for structural equivalence.
*
* This comparison operator respects scoping and compares
* expressions without regard to variable choice.
Expand All @@ -101,7 +103,8 @@ TVM_DLL Kind KindCheck(const Type& t, const Module& mod);
*/
TVM_DLL bool AlphaEqual(const Expr& e1, const Expr& e2);

/*! \brief Compare two types for structural equivalence.
/*!
* \brief Compare two types for structural equivalence.
*
* This comparison operator respects scoping and compares
* expressions without regard to variable choice.
Expand All @@ -119,7 +122,8 @@ TVM_DLL bool AlphaEqual(const Expr& e1, const Expr& e2);
*/
TVM_DLL bool AlphaEqual(const Type& t1, const Type& t2);

/*! \brief Add abstraction over a function
/*!
* \brief Add abstraction over a function
*
* For example: `square` is transformed to
* `fun x -> square x`.
Expand All @@ -135,7 +139,8 @@ TVM_DLL bool AlphaEqual(const Type& t1, const Type& t2);
*/
TVM_DLL Expr EtaExpand(const Expr& e, const Module& mod);

/*! \brief Check that each Var is only bound once.
/*!
* \brief Check that each Var is only bound once.
*
* For example, the expression `let x = 1 in let x = 2 in 3` bound x twice.
*
Expand All @@ -148,7 +153,8 @@ TVM_DLL Expr EtaExpand(const Expr& e, const Module& mod);
*/
TVM_DLL bool WellFormed(const Expr& expr);

/*! \brief Get all bound variables from expression expr.
/*!
* \brief Get all bound variables from expression expr.
*
* Bound variables are all variables that are declared in the expr.
* They only have meaning inside that expr, and can only be used in it.
Expand All @@ -159,7 +165,8 @@ TVM_DLL bool WellFormed(const Expr& expr);
*/
TVM_DLL tvm::Array<Var> BoundVars(const Expr& expr);

/*! \brief Get all bound variables from pattern pat.
/*!
* \brief Get all bound variables from pattern pat.
*
* Bound variables are all variables that got bound by the pat.
* They only have meaning inside that expr, and can only be used in it.
Expand All @@ -170,7 +177,8 @@ TVM_DLL tvm::Array<Var> BoundVars(const Expr& expr);
*/
TVM_DLL tvm::Array<Var> BoundVars(const Pattern& pat);

/*! \brief Get free type parameters from expression expr.
/*!
* \brief Get free type parameters from expression expr.
*
* Free variables are variables that are not bound by a
* let or a function parameter in the context.
Expand All @@ -181,15 +189,17 @@ TVM_DLL tvm::Array<Var> BoundVars(const Pattern& pat);
*/
TVM_DLL tvm::Array<Var> FreeVars(const Expr& expr);

/*! \brief Get all variables from expression expr.
/*!
* \brief Get all variables from expression expr.
*
* \param expr the expression.
*
* \return List of all vars, in the PostDFS order in the expression.
*/
TVM_DLL tvm::Array<Var> AllVars(const Expr& expr);

/*! \brief Get free TypeVars from expression expr.
/*!
* \brief Get free TypeVars from expression expr.
*
* Free type parameters are type parameters that are not bound by a function
* type in the context.
Expand All @@ -201,7 +211,8 @@ TVM_DLL tvm::Array<Var> AllVars(const Expr& expr);
*/
TVM_DLL tvm::Array<TypeVar> FreeTypeVars(const Expr& expr, const Module& mod);

/*! \brief Get free TypeVars from type t.
/*!
* \brief Get free TypeVars from type t.
*
* Free type parameters are type parameters that are not bound by a function
* type in the context.
Expand All @@ -213,7 +224,8 @@ TVM_DLL tvm::Array<TypeVar> FreeTypeVars(const Expr& expr, const Module& mod);
*/
TVM_DLL tvm::Array<TypeVar> FreeTypeVars(const Type& t, const Module& mod);

/*! \brief Get all bound type variables from expression expr.
/*!
* \brief Get all bound type variables from expression expr.
*
* Bound variables are all type variables that are declared in the expr.
* They only have meaning inside that expr, and can only be used in it.
Expand All @@ -225,7 +237,8 @@ TVM_DLL tvm::Array<TypeVar> FreeTypeVars(const Type& t, const Module& mod);
*/
TVM_DLL tvm::Array<TypeVar> BoundTypeVars(const Expr& expr, const Module& mod);

/*! \brief Get all bound type variables from type t.
/*!
* \brief Get all bound type variables from type t.
*
* Bound variables are all type variables that are declared in the type.
* They only have meaning inside that type, and can only be used in it.
Expand All @@ -237,7 +250,8 @@ TVM_DLL tvm::Array<TypeVar> BoundTypeVars(const Expr& expr, const Module& mod);
*/
TVM_DLL tvm::Array<TypeVar> BoundTypeVars(const Type& t, const Module& mod);

/*! \brief Get all type variables in expression expr.
/*!
* \brief Get all type variables in expression expr.
*
* \param expr the expression.
* \param mod the module.
Expand All @@ -246,7 +260,8 @@ TVM_DLL tvm::Array<TypeVar> BoundTypeVars(const Type& t, const Module& mod);
*/
TVM_DLL tvm::Array<TypeVar> AllTypeVars(const Expr& expr, const Module& mod);

/*! \brief Get all type variables in type t.
/*!
* \brief Get all type variables in type t.
*
* \param t the type.
* \param mod the module.
Expand All @@ -273,22 +288,27 @@ TVM_DLL Expr DeadCodeElimination(const Expr& e);

/*!
* \brief Fold constant expressions.
*
* \param expr the expression to be optimized.
*
* \return The optimized expression.
*/
TVM_DLL Expr FoldConstant(const Expr& expr);

/*!
* \brief Fuse operations into expr into seperate functions.
*
* \param expr The expression.
* \param fuse_opt_level Optimization level.
* \param mod the module.
*
* \return The optimized expression.
*/
TVM_DLL Expr FuseOps(const Expr& expr, int fuse_opt_level, const Module& mod);

/*!
* \brief Apply rewrite rules to rewrite the expr in post DFS order.
*
* \param expr The expression.
* \param rewrite_map_attr_name The Op's attr name which corresponds to the rewrite
* rule function.
Expand All @@ -298,84 +318,68 @@ TVM_DLL Expr FuseOps(const Expr& expr, int fuse_opt_level, const Module& mod);
* \return The rewritten expression.
*/
TVM_DLL Expr ForwardRewrite(const Expr& expr,
const std::string& rewrite_map_attr_name,
std::function<NodeRef(const Call&)> fcontext = nullptr,
std::function<Expr(const Expr&)> fmulti_ref_trigger = nullptr);
const std::string& rewrite_map_attr_name,
std::function<NodeRef(const Call&)> fcontext = nullptr,
std::function<Expr(const Expr&)> fmulti_ref_trigger = nullptr);

/*!
* \brief Apply rewrite rules to rewrite the expr in post DFS order.
*
* \param expr The expression.
* \param rewrite_func The rewrite func that will apply to all operators.
* \param fcontext Additional callback to provide context argument for each call node.
* \param fmulti_ref_trigger Transformation function to be called when
* an Expr consumed by multiple callers.
*
* \return The rewritten expression.
*/
TVM_DLL Expr ForwardRewrite(const Expr& expr,
const FForwardRewrite& rewrite_func,
std::function<NodeRef(const Call&)> fcontext = nullptr,
std::function<Expr(const Expr&)> fmulti_ref_trigger = nullptr);
const FForwardRewrite& rewrite_func,
std::function<NodeRef(const Call&)> fcontext = nullptr,
std::function<Expr(const Expr&)> fmulti_ref_trigger = nullptr);

/*!
* \brief Rewrite the annotated program.
*
* \param expr The expression.
* \param fallback_device The fallback device which is the default device for
* operators without annotation.
*
* \return The updated program.
*/
TVM_DLL Expr RewriteAnnotatedOps(const Expr& expr, int fallback_device);

/*!
* \brief Collect the device mapping information of each expression.
*
* \param expr The expression.
*
* \return The device mapping.
*/
TVM_DLL Map<Expr, Integer> CollectDeviceInfo(const Expr& expr);

/*! \brief A hashing structure in the style of std::hash. */
struct StructuralHash {
/*! \brief Hash a Relay type.
*
* Implements structural hashing of a Relay type.
*
* \param type the type to hash.
*
* \return the hash value.
*/
size_t operator()(const Type& type) const;

/*! \brief Hash a Relay expression.
*
* Implements structural hashing of a Relay expression.
*
* \param expr the expression to hash.
*
* \return the hash value.
*/
size_t operator()(const Expr& expr) const;
};

/*! \brief turn a dataflow graph into Administrative Normal Form, or A-Normal Form (ANF).
/*!
* \brief turn a dataflow graph into Administrative Normal Form, or A-Normal Form (ANF).
*
* It will turn an expression that is in a graph form (with sharing implicit),
* to an expression with explicit sharing (A-Normal Form).
*
* The scope of the root expression is the global scope.
*
* The scope of any non root expression is the least common ancestor of all it's scope.
*
* Values are ordered by post-DFS order in each scope.
*
* \param e the expression to observably share
*
* \param e the expression to observably share.
* \param mod The module used for referencing global functions, can be
* None.
*
* \return expression in A-Normal Form
* \return expression in A-Normal Form.
*/
TVM_DLL Expr ToANormalForm(const Expr& e, const Module& mod);

/*! \brief Remove let binding and directly share via pointer instead.
/*!
* \brief Remove let binding and directly share via pointer instead.
*
* It will remove all let binding,
* and turn all of the variable bound by let into direct pointer reference.
Expand All @@ -386,18 +390,49 @@ TVM_DLL Expr ToANormalForm(const Expr& e, const Module& mod);
*/
TVM_DLL Expr ToGraphNormalForm(const Expr& e);

/*! \brief Aggressive constant propagation/constant folding/inlining.
/*!
* \brief Aggressive constant propagation/constant folding/inlining.
*
* It will do as much computation in compile time as possible.
* It has two benefit: remove runtime overhead, and allow more optimization (typically fusion).
* As a side effect, code size will explode.
*
* \param e the expression,
*
* \return the optimized expression.
*/
Expr PartialEval(const Expr& e);
TVM_DLL Expr PartialEval(const Expr& e);

/*! \brief A hashing structure in the style of std::hash. */
struct StructuralHash {
/*! \brief Hash a Relay type.
*
* Implements structural hashing of a Relay type.
*
* \param type the type to hash.
*
* \return the hash value.
*/
size_t operator()(const Type& type) const;

/*! \brief Hash a Relay expression.
*
* Implements structural hashing of a Relay expression.
*
* \param expr the expression to hash.
*
* \return the hash value.
*/
size_t operator()(const Expr& expr) const;
};

namespace vm {

/*! \brief Compile a module, and construct the virtual machine.
/*!
* \brief Compile a module, and construct the virtual machine.
*
* \param mod The module to compile.
*
* \return The constructed virtual machine.
*/
runtime::vm::VirtualMachine CompileModule(const Module& mod);
Expand Down
Loading

0 comments on commit f8c4cb9

Please sign in to comment.