Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Relay] Start porting pass to the pass manager #3191

Merged
merged 2 commits into from
May 24, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
139 changes: 87 additions & 52 deletions include/tvm/relay/pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/type.h>
#include <tvm/relay/adt.h>
#include <tvm/relay/transform.h>
#include <tvm/runtime/vm.h>
#include <string>
#include <vector>
Expand Down Expand Up @@ -84,7 +85,8 @@ TVM_DLL Function InferType(const Function& f, const Module& mod,
*/
TVM_DLL Kind KindCheck(const Type& t, const Module& mod);

/*! \brief Compare two expressions for structural equivalence.
/*!
* \brief Compare two expressions for structural equivalence.
*
* This comparison operator respects scoping and compares
* expressions without regard to variable choice.
Expand All @@ -101,7 +103,8 @@ TVM_DLL Kind KindCheck(const Type& t, const Module& mod);
*/
TVM_DLL bool AlphaEqual(const Expr& e1, const Expr& e2);

/*! \brief Compare two types for structural equivalence.
/*!
* \brief Compare two types for structural equivalence.
*
* This comparison operator respects scoping and compares
* expressions without regard to variable choice.
Expand All @@ -119,7 +122,8 @@ TVM_DLL bool AlphaEqual(const Expr& e1, const Expr& e2);
*/
TVM_DLL bool AlphaEqual(const Type& t1, const Type& t2);

/*! \brief Add abstraction over a function
/*!
* \brief Add abstraction over a function
*
* For example: `square` is transformed to
* `fun x -> square x`.
Expand All @@ -135,7 +139,8 @@ TVM_DLL bool AlphaEqual(const Type& t1, const Type& t2);
*/
TVM_DLL Expr EtaExpand(const Expr& e, const Module& mod);

/*! \brief Check that each Var is only bound once.
/*!
* \brief Check that each Var is only bound once.
*
* For example, the expression `let x = 1 in let x = 2 in 3` bound x twice.
*
Expand All @@ -148,7 +153,8 @@ TVM_DLL Expr EtaExpand(const Expr& e, const Module& mod);
*/
TVM_DLL bool WellFormed(const Expr& expr);

/*! \brief Get all bound variables from expression expr.
/*!
* \brief Get all bound variables from expression expr.
*
* Bound variables are all variables that are declared in the expr.
* They only have meaning inside that expr, and can only be used in it.
Expand All @@ -159,7 +165,8 @@ TVM_DLL bool WellFormed(const Expr& expr);
*/
TVM_DLL tvm::Array<Var> BoundVars(const Expr& expr);

/*! \brief Get all bound variables from pattern pat.
/*!
* \brief Get all bound variables from pattern pat.
*
* Bound variables are all variables that got bound by the pat.
* They only have meaning inside that expr, and can only be used in it.
Expand All @@ -170,7 +177,8 @@ TVM_DLL tvm::Array<Var> BoundVars(const Expr& expr);
*/
TVM_DLL tvm::Array<Var> BoundVars(const Pattern& pat);

/*! \brief Get free type parameters from expression expr.
/*!
* \brief Get free type parameters from expression expr.
*
* Free variables are variables that are not bound by a
* let or a function parameter in the context.
Expand All @@ -181,15 +189,17 @@ TVM_DLL tvm::Array<Var> BoundVars(const Pattern& pat);
*/
TVM_DLL tvm::Array<Var> FreeVars(const Expr& expr);

/*! \brief Get all variables from expression expr.
/*!
* \brief Get all variables from expression expr.
*
* \param expr the expression.
*
* \return List of all vars, in the PostDFS order in the expression.
*/
TVM_DLL tvm::Array<Var> AllVars(const Expr& expr);

/*! \brief Get free TypeVars from expression expr.
/*!
* \brief Get free TypeVars from expression expr.
*
* Free type parameters are type parameters that are not bound by a function
* type in the context.
Expand All @@ -201,7 +211,8 @@ TVM_DLL tvm::Array<Var> AllVars(const Expr& expr);
*/
TVM_DLL tvm::Array<TypeVar> FreeTypeVars(const Expr& expr, const Module& mod);

/*! \brief Get free TypeVars from type t.
/*!
* \brief Get free TypeVars from type t.
*
* Free type parameters are type parameters that are not bound by a function
* type in the context.
Expand All @@ -213,7 +224,8 @@ TVM_DLL tvm::Array<TypeVar> FreeTypeVars(const Expr& expr, const Module& mod);
*/
TVM_DLL tvm::Array<TypeVar> FreeTypeVars(const Type& t, const Module& mod);

/*! \brief Get all bound type variables from expression expr.
/*!
* \brief Get all bound type variables from expression expr.
*
* Bound variables are all type variables that are declared in the expr.
* They only have meaning inside that expr, and can only be used in it.
Expand All @@ -225,7 +237,8 @@ TVM_DLL tvm::Array<TypeVar> FreeTypeVars(const Type& t, const Module& mod);
*/
TVM_DLL tvm::Array<TypeVar> BoundTypeVars(const Expr& expr, const Module& mod);

/*! \brief Get all bound type variables from type t.
/*!
* \brief Get all bound type variables from type t.
*
* Bound variables are all type variables that are declared in the type.
* They only have meaning inside that type, and can only be used in it.
Expand All @@ -237,7 +250,8 @@ TVM_DLL tvm::Array<TypeVar> BoundTypeVars(const Expr& expr, const Module& mod);
*/
TVM_DLL tvm::Array<TypeVar> BoundTypeVars(const Type& t, const Module& mod);

/*! \brief Get all type variables in expression expr.
/*!
* \brief Get all type variables in expression expr.
*
* \param expr the expression.
* \param mod the module.
Expand All @@ -246,7 +260,8 @@ TVM_DLL tvm::Array<TypeVar> BoundTypeVars(const Type& t, const Module& mod);
*/
TVM_DLL tvm::Array<TypeVar> AllTypeVars(const Expr& expr, const Module& mod);

/*! \brief Get all type variables in type t.
/*!
* \brief Get all type variables in type t.
*
* \param t the type.
* \param mod the module.
Expand All @@ -273,22 +288,27 @@ TVM_DLL Expr DeadCodeElimination(const Expr& e);

/*!
* \brief Fold constant expressions.
*
* \param expr the expression to be optimized.
*
* \return The optimized expression.
*/
TVM_DLL Expr FoldConstant(const Expr& expr);

/*!
* \brief Fuse operations into expr into seperate functions.
*
* \param expr The expression.
* \param fuse_opt_level Optimization level.
* \param mod the module.
*
* \return The optimized expression.
*/
TVM_DLL Expr FuseOps(const Expr& expr, int fuse_opt_level, const Module& mod);

/*!
* \brief Apply rewrite rules to rewrite the expr in post DFS order.
*
* \param expr The expression.
* \param rewrite_map_attr_name The Op's attr name which corresponds to the rewrite
* rule function.
Expand All @@ -298,84 +318,68 @@ TVM_DLL Expr FuseOps(const Expr& expr, int fuse_opt_level, const Module& mod);
* \return The rewritten expression.
*/
TVM_DLL Expr ForwardRewrite(const Expr& expr,
const std::string& rewrite_map_attr_name,
std::function<NodeRef(const Call&)> fcontext = nullptr,
std::function<Expr(const Expr&)> fmulti_ref_trigger = nullptr);
const std::string& rewrite_map_attr_name,
std::function<NodeRef(const Call&)> fcontext = nullptr,
std::function<Expr(const Expr&)> fmulti_ref_trigger = nullptr);

/*!
* \brief Apply rewrite rules to rewrite the expr in post DFS order.
*
* \param expr The expression.
* \param rewrite_func The rewrite func that will apply to all operators.
* \param fcontext Additional callback to provide context argument for each call node.
* \param fmulti_ref_trigger Transformation function to be called when
* an Expr consumed by multiple callers.
*
* \return The rewritten expression.
*/
TVM_DLL Expr ForwardRewrite(const Expr& expr,
const FForwardRewrite& rewrite_func,
std::function<NodeRef(const Call&)> fcontext = nullptr,
std::function<Expr(const Expr&)> fmulti_ref_trigger = nullptr);
const FForwardRewrite& rewrite_func,
std::function<NodeRef(const Call&)> fcontext = nullptr,
std::function<Expr(const Expr&)> fmulti_ref_trigger = nullptr);

/*!
* \brief Rewrite the annotated program.
*
* \param expr The expression.
* \param fallback_device The fallback device which is the default device for
* operators without annotation.
*
* \return The updated program.
*/
TVM_DLL Expr RewriteAnnotatedOps(const Expr& expr, int fallback_device);

/*!
* \brief Collect the device mapping information of each expression.
*
* \param expr The expression.
*
* \return The device mapping.
*/
TVM_DLL Map<Expr, Integer> CollectDeviceInfo(const Expr& expr);

/*! \brief A hashing structure in the style of std::hash. */
struct StructuralHash {
/*! \brief Hash a Relay type.
*
* Implements structural hashing of a Relay type.
*
* \param type the type to hash.
*
* \return the hash value.
*/
size_t operator()(const Type& type) const;

/*! \brief Hash a Relay expression.
*
* Implements structural hashing of a Relay expression.
*
* \param expr the expression to hash.
*
* \return the hash value.
*/
size_t operator()(const Expr& expr) const;
};

/*! \brief turn a dataflow graph into Administrative Normal Form, or A-Normal Form (ANF).
/*!
* \brief turn a dataflow graph into Administrative Normal Form, or A-Normal Form (ANF).
*
* It will turn an expression that is in a graph form (with sharing implicit),
* to an expression with explicit sharing (A-Normal Form).
*
* The scope of the root expression is the global scope.

*
* The scope of any non root expression is the least common ancestor of all it's scope.
*
* Values are ordered by post-DFS order in each scope.
*
* \param e the expression to observably share
*
* \param e the expression to observably share.
* \param mod The module used for referencing global functions, can be
* None.
*
* \return expression in A-Normal Form
* \return expression in A-Normal Form.
*/
TVM_DLL Expr ToANormalForm(const Expr& e, const Module& mod);

/*! \brief Remove let binding and directly share via pointer instead.
/*!
* \brief Remove let binding and directly share via pointer instead.
*
* It will remove all let binding,
* and turn all of the variable bound by let into direct pointer reference.
Expand All @@ -386,18 +390,49 @@ TVM_DLL Expr ToANormalForm(const Expr& e, const Module& mod);
*/
TVM_DLL Expr ToGraphNormalForm(const Expr& e);

/*! \brief Aggressive constant propagation/constant folding/inlining.
/*!
* \brief Aggressive constant propagation/constant folding/inlining.
*
* It will do as much computation in compile time as possible.
* It has two benefit: remove runtime overhead, and allow more optimization (typically fusion).
* As a side effect, code size will explode.
*
* \param e the expression,
*
* \return the optimized expression.
*/
Expr PartialEval(const Expr& e);
TVM_DLL Expr PartialEval(const Expr& e);

/*! \brief A hashing structure in the style of std::hash. */
struct StructuralHash {
/*! \brief Hash a Relay type.
*
* Implements structural hashing of a Relay type.
*
* \param type the type to hash.
*
* \return the hash value.
*/
size_t operator()(const Type& type) const;

/*! \brief Hash a Relay expression.
*
* Implements structural hashing of a Relay expression.
*
* \param expr the expression to hash.
*
* \return the hash value.
*/
size_t operator()(const Expr& expr) const;
};

namespace vm {

/*! \brief Compile a module, and construct the virtual machine.
/*!
* \brief Compile a module, and construct the virtual machine.
*
* \param mod The module to compile.
*
* \return The constructed virtual machine.
*/
runtime::vm::VirtualMachine CompileModule(const Module& mod);
Expand Down
Loading