Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Relay] Pass manager #2546

Merged
merged 32 commits into from
Mar 12, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
ebf4c29
initial commit
Jan 22, 2019
4630afc
add python frontend and module tests
zhiics Jan 30, 2019
c894f04
add unit tests for function pass and optimize interface
zhiics Jan 31, 2019
0880027
add ExprPass
zhiics Jan 31, 2019
4122079
remove PassState and pass context for run
zhiics Feb 14, 2019
49ae421
add required_passes
zhiics Feb 14, 2019
b1adacc
return module
zhiics Feb 15, 2019
42a3227
remove move
zhiics Feb 15, 2019
c1c6d07
fix minor reviews
zhiics Feb 19, 2019
fd22d34
remove optimizer, optimizer->pass_manager, make pass a the base class…
zhiics Feb 20, 2019
4cf5843
remove deleted files
zhiics Feb 20, 2019
04ef13e
move resolvedependency to sequential pass, use ir_pass namespace
zhiics Feb 21, 2019
4cd4bd1
add todo
zhiics Feb 21, 2019
d98af5a
add disabled passes in sequetialpass
zhiics Feb 21, 2019
ccb5197
fix minor
zhiics Feb 21, 2019
e5da540
fix currying doc
zhiics Feb 21, 2019
c021126
remove pass_kind from passnode
zhiics Feb 25, 2019
42c5619
remove pass kind from test
zhiics Feb 25, 2019
7cfd1b6
fix doc
zhiics Feb 27, 2019
8c4d548
fix per @tqchen's comments
zhiics Mar 2, 2019
15e1d6c
remove pass_manager.py create separate classes
zhiics Mar 2, 2019
02d7de7
simplify pass_func
zhiics Mar 3, 2019
550ad63
inline using passfunc
zhiics Mar 3, 2019
699e5b3
update doc
zhiics Mar 3, 2019
dc2b30c
disable test_quantize_pass for now
zhiics Mar 5, 2019
49df272
create PassInfo class to contain the meta data
zhiics Mar 8, 2019
810cbac
flatten passinfo for interface
zhiics Mar 8, 2019
ee359fb
retrigger ci
zhiics Mar 8, 2019
a0863b4
remove required method
zhiics Mar 9, 2019
c8fb6b5
make Pass python class lighter
zhiics Mar 10, 2019
6df5c7d
create pass -> decorator
zhiics Mar 11, 2019
e78f4e2
make the api consistent for all classes
zhiics Mar 11, 2019
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
207 changes: 207 additions & 0 deletions include/tvm/relay/pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,225 @@
* Copyright (c) 2018 by Contributors
* \file tvm/relay/pass.h
* \brief The set of Relay passes written in C++.
*
* This file also implements a pass manager. The pass manager manages a sequence
* of Relay-to-Relay transformation passes over a particlar unit of AST. The
* design is largely inspired from LLVM's pass manager and modern deep learning
* frameworks that perform tensor->tensor transformations.
*
* The responsibilities of a traditional compiler pass manager usually involves:
* - Organizing the execution order of optimization passes though not
* necessarily in the optimal sequence.
* - Collecting required analysis information and keep them up-to-date.
* - Reducing the effort required to implement new passes for compiler
* developers, etc.
*
* Similar to LLVM's pass manager, we designed the Relay pass manager to work
* different granularity, i.e. module level, function level, and even sequential
* passe that contains a host of passes.
*
* However, we also extend the functionality of the traditional pass manager
* with the consideration of requirements/convention from deep learning
* frameworks, such as Pytorch and Gluon, etc. Each pass in the Relay pass
* manager performs the Relay.Module -> Relay.Module transformation. All
* different types of passes, including the sequential-level pass object, are
* essentially pass objects. This design, therefore, effectively provides users
* a consistent and convenient interface, i.e. Pass, to play with. It offers a
* means to ease the development and testing of Relay passes. For example, with
* the pass manager, external users will be able to have custom passes correctly
* scheduled without having to modify a single handcrafted pass order.
*
* In the future we need to describe constraints between passes. For example,
* we may want to preserve dependencies between different passes and validate
* them on the completion of a certain pass.
*
* We also need to store side information and import the error reporting system.
*/
#ifndef TVM_RELAY_PASS_H_
#define TVM_RELAY_PASS_H_

#include <tvm/ir.h>
#include <tvm/packed_func_ext.h>
#include <tvm/relay/error.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/module.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/type.h>

#include <string>
#include <vector>

namespace tvm {
namespace relay {

namespace pass {

/*
* \brief The context of pass.
*/
class PassContext;

/*!
* \brief PassContextNode contains the information that a pass can rely on, such as
* analysis results.
*/
class PassContextNode : public RelayNode {
public:
/*!
* \brief The error reporter used to notify users why an optimization fails.
*/
ErrorReporter err_reporter;

PassContextNode() = default;

void VisitAttrs(tvm::AttrVisitor* v) final {
}

TVM_DLL static PassContext make();

static constexpr const char* _type_key = "relay.PassContext";
TVM_DECLARE_NODE_TYPE_INFO(PassContextNode, RelayNode);
};

TVM_DEFINE_NODE_REF(PassContext, PassContextNode)

/*
* \brief The meta data of a pass.
*
* PassInfo can be extended conveniently in the future if more meta information
* is needed.
*/
class PassInfo;

/*!
* \brief PassInfoNode contains meta data that will be used to help optimization
* and analysis.
*/
class PassInfoNode : public RelayNode {
public:
/*! \brief The minimal optimization level that this pass will be enabled. */
int opt_level;

/*! \brief The name of an optimization/analysis pass. */
std::string name;

/*! \brief The passes that are required to perform the current pass. */
tvm::Array<tvm::Expr> required;

PassInfoNode() = default;

void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("opt_level", &opt_level);
v->Visit("name", &name);
v->Visit("required", &required);
}

TVM_DLL static PassInfo make(int opt_level, std::string name,
tvm::Array<tvm::Expr> required);

static constexpr const char* _type_key = "relay.PassInfo";
TVM_DECLARE_NODE_TYPE_INFO(PassInfoNode, RelayNode);
};

TVM_DEFINE_NODE_REF(PassInfo, PassInfoNode)

class Pass;

/*!
* \brief PassNode is the base type of differnt types of optimization passes.
* It is designed as a pure class and implemented by different pass subclasses
* at different granularity of Relay nodes.
*/
class PassNode : public RelayNode {
public:
/*
* \brief Get the pass information/meta data. */
virtual PassInfo Info() const = 0;

/*!
* \brief Set the context information for a pass.
*
* \param pass_ctx The context information for a certain pass.
*/
virtual void SetContext(const PassContext& pass_ctx) = 0;

/*!
* \brief Execute the optimization pass using a functor.
*
* \param mod The module that an optimization pass runs on.
*
* \return The updated module.
*/
virtual Module operator()(const Module& mod) const = 0;

void VisitAttrs(tvm::AttrVisitor* v) override {}

static constexpr const char* _type_key = "relay.Pass";
TVM_DECLARE_BASE_NODE_INFO(PassNode, RelayNode);
};

class Pass : public NodeRef {
tqchen marked this conversation as resolved.
Show resolved Hide resolved
public:
Pass() = default;
explicit Pass(NodePtr<tvm::Node> p) : NodeRef(p) {}

PassNode* operator->() const {
return static_cast<PassNode*>(this->node_.get());
}

using ContainerType = PassNode;
};

/*
* \brief Create a module pass.
*
* \param pass_func The packed function that contains the optimization.
* \param opt_level The optimization level of the module pass.
* \param name The name of the module pass.
* \param required The list of the passes that the module pass is dependent on.
*
* \return The created module pass.
*/
Pass CreateModulePass(
const runtime::TypedPackedFunc<Module(Module, PassContext)>& pass_func,
int opt_level,
const std::string& name,
const tvm::Array<tvm::Expr>& required);

/*
* \brief Create a function pass.
*
* \param pass_func The packed function that contains the optimization.
* \param opt_level The optimization level of the function pass.
* \param name The name of the function pass.
* \param required The list of the passes that the function pass is dependent on.
*
* \return The created function pass.
*/
Pass CreateFunctionPass(
const runtime::TypedPackedFunc<Function(Function, PassContext)>& pass_func,
int opt_level,
const std::string& name,
const tvm::Array<tvm::Expr>& required);
/*
* \brief Create a sequential pass.
*
* \param passes The optimization passes will be performed.
* \param opt_level The optimization level of the sequential pass.
* \param name The name of the sequential pass.
* \param required The list of the passes that the sequential pass is dependent on.
* \param disabled The disabled passes.
*
* \return The created sequential pass.
*/
Pass CreateSequentialPass(const tvm::Array<Pass>& passes,
int opt_level,
const std::string& name,
const tvm::Array<tvm::Expr>& required,
const tvm::Array<tvm::Expr>& disabled);

} // namespace pass

/*!
* \brief Infer the type of an expression.
*
Expand Down
11 changes: 11 additions & 0 deletions python/tvm/relay/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,9 @@
var = expr.var
const = expr.const
bind = expr.bind
module_pass = ir_pass.module_pass
function_pass = ir_pass.function_pass
sequential_pass = ir_pass.sequential_pass

# ExprFunctor
ExprFunctor = expr_functor.ExprFunctor
Expand All @@ -90,3 +93,11 @@
# Param Serialization
save_param_dict = param_dict.save_param_dict
load_param_dict = param_dict.load_param_dict

# Pass manager
PassInfo = ir_pass.PassInfo
PassContext = ir_pass.PassContext
Pass = ir_pass.Pass
ModulePass = ir_pass.ModulePass
FunctionPass = ir_pass.FunctionPass
SequentialPass = ir_pass.SequentialPass
57 changes: 56 additions & 1 deletion python/tvm/relay/_ir_pass.pyi
Original file line number Diff line number Diff line change
@@ -1,5 +1,60 @@
from .env import Module
import tvm
from . import ir
from .base import NodeBase
from .env import Module


class PassContext(NodeBase):
def __init__(self):
...

class PassInfo(NodeBase):
name = ... # type: str
opt_level = ... # type: int
required = ... # type: list

def __init__(self, name, opt_level, required)
# type: (str, int, list) -> None


class Pass(NodeBase):
def __init__(self):
...


class ModulePass(Pass):
name = ... # type: str
opt_level = ... # type: int
pass_func = ... # type: Callable
required = ... # type: list

def __init__(self, name, opt_level, pass_func, required):
# type: (str, int, Callable, list) -> None
...


class FunctionPass(Pass):
name = ... # type: str
opt_level = ... # type: int
pass_func = ... # type: Callable
required = ... # type: list

def __init__(self, name, opt_level, pass_func, required):
# type: (str, int, Callable, list) -> None
...


class SequentialPass(Pass):
name = ... # type: str
opt_level = ... # type: int
passes = ... # type: list
required = ... # type: list
disabled = ... # type: list

def __init__(self, name, opt_level, passes, required, disabled):
# type: (str, int, list, list, list) -> None
...


def check_expr(env: Module, expr: ir.Expr) -> ir.Type: ...
def generalize(env: Module, expr: ir.Expr) -> ir.Expr: ...
Expand Down
Loading