Skip to content

Commit

Permalink
remove pass_manager.py create separate classes
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiics committed Mar 2, 2019
1 parent 498de9b commit 061a490
Show file tree
Hide file tree
Showing 6 changed files with 345 additions and 324 deletions.
55 changes: 26 additions & 29 deletions include/tvm/relay/pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,30 @@
*
* This file also implements a pass manager. The pass manager manages a sequence
* of Relay-to-Relay transformation passes over a particlar unit of AST. The
* design is largely inspired from LLVM's pass manager.
* design is largely inspired from LLVM's pass manager and modern deep learning
* frameworks that perform tensor->tensor transformations.
*
* The responsibilities of a pass manager usually involves:
* The responsibilities of a traditional compiler pass manager usually involves:
* - Organizing the execution order of optimization passes though not
* necessarily in the optimal sequence.
* - Collecting required analysis information and keep them up-to-date.
* - Reducing the effort required to implement new passes for compiler
* developers, etc.
*
* TODO(jroesch, zhiics): We are currently using a very simple design for the
* pass manager, i.e. it executes a specific pass or sequence of passes.
* Similar to LLVM's pass manager, we designed the Relay pass manager to work
* different granularity, i.e. module level, function level, and even sequential
* passe that contains a host of passes.
*
* However, we also extend the functionality of the traditional pass manager
* with the consideration of requirements from deep learning frameworks. Each
* pass in the Relay pass manager performs the Relay.Module -> Relay.Module
* transformation. All different type of passes including the sequential-level
* pass object are essentially a pass object. This design, therefore,
* effectively provides users a consistent and convenient interface, i.e. pass,
* to play with. It offers a means to eases the development and testing Relay
* passes. For example, with the pass manager, external users will be able to
* have custom passes correctly scheduled without having to modify a single
* handcrafted pass order.
*
* In the future we need to describe constraints between passes. For example,
* we may want to preserve dependencies between different passes and validate
Expand All @@ -26,7 +39,6 @@
#ifndef TVM_RELAY_PASS_H_
#define TVM_RELAY_PASS_H_

#include <tvm/attrs.h>
#include <tvm/ir.h>
#include <tvm/packed_func_ext.h>
#include <tvm/relay/error.h>
Expand All @@ -43,11 +55,6 @@ namespace relay {

namespace pass {

// Forward declaration
class ModulePass;
class FunctionPass;
class SequentialPass;

// Define pass context.
class PassContext;

Expand All @@ -60,7 +67,7 @@ class PassContextNode : public RelayNode {
/*!
* \brief The error reporter used to notify users why an optimization fails.
*/
ErrorReporter err_reporter_;
ErrorReporter err_reporter;

PassContextNode() = default;

Expand All @@ -73,17 +80,7 @@ class PassContextNode : public RelayNode {
TVM_DECLARE_NODE_TYPE_INFO(PassContextNode, RelayNode);
};

class PassContext : public NodeRef {
public:
PassContext() = default;
explicit PassContext(NodePtr<tvm::Node> p) : NodeRef(p) {}

const PassContextNode* operator->() const {
return static_cast<PassContextNode*>(this->node_.get());
}

using ContainerType = PassContextNode;
};
TVM_DEFINE_NODE_REF(PassContext, PassContextNode)

// We use currying here. It runs on a Relay node type NodeT and yields a new
// node with the same type. The Relay module is captured for optimizations as
Expand Down Expand Up @@ -167,8 +164,8 @@ class Pass : public NodeRef {
*
* \return The created module pass.
*/
ModulePass CreateModulePass(const std::string& name, int opt_level,
const PassFunc<Module>& pass_func);
Pass CreateModulePass(const std::string& name, int opt_level,
const PassFunc<Module>& pass_func);

/*
* \brief Create a function pass.
Expand All @@ -179,8 +176,8 @@ ModulePass CreateModulePass(const std::string& name, int opt_level,
*
* \return The created function pass.
*/
FunctionPass CreateFunctionPass(const std::string& name, int opt_level,
const PassFunc<Function>& pass_func);
Pass CreateFunctionPass(const std::string& name, int opt_level,
const PassFunc<Function>& pass_func);
/*
* \brief Create a sequential pass.
*
Expand All @@ -192,9 +189,9 @@ FunctionPass CreateFunctionPass(const std::string& name, int opt_level,
*
* \return The created sequential pass.
*/
SequentialPass CreateSequentialPass(const std::string& name, int opt_level,
const tvm::Array<Pass>& passes,
const tvm::Array<tvm::Expr>& disabled);
Pass CreateSequentialPass(const std::string& name, int opt_level,
const tvm::Array<Pass>& passes,
const tvm::Array<tvm::Expr>& disabled);

} // namespace pass

Expand Down
15 changes: 8 additions & 7 deletions python/tvm/relay/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from . import ir_pass
from .build_module import build, build_config, create_executor, optimize
from . import prelude
from . import pass_manager
from . import parser
from . import debug
from . import param_dict
Expand Down Expand Up @@ -80,7 +79,9 @@
var = expr.var
const = expr.const
bind = expr.bind
create_pass = pass_manager.create_pass
create_module_pass = ir_pass.create_module_pass
create_function_pass = ir_pass.create_function_pass
create_sequential_pass = ir_pass.create_sequential_pass

# ExprFunctor
ExprFunctor = expr_functor.ExprFunctor
Expand All @@ -94,8 +95,8 @@
load_param_dict = param_dict.load_param_dict

# Pass manager
PassContext = pass_manager.PassContext
Pass = pass_manager.Pass
ModulePass = pass_manager.ModulePass
FunctionPass = pass_manager.FunctionPass
SequentialPass = pass_manager.SequentialPass
PassContext = ir_pass.PassContext
Pass = ir_pass.Pass
ModulePass = ir_pass.ModulePass
FunctionPass = ir_pass.FunctionPass
SequentialPass = ir_pass.SequentialPass
Loading

0 comments on commit 061a490

Please sign in to comment.