From 061a49036f85b3179db5859f8ff40d1a1de2b921 Mon Sep 17 00:00:00 2001 From: Zhi Chen Date: Sat, 2 Mar 2019 23:06:38 +0000 Subject: [PATCH] remove pass_manager.py create separate classes --- include/tvm/relay/pass.h | 55 +++-- python/tvm/relay/__init__.py | 15 +- python/tvm/relay/ir_pass.py | 277 +++++++++++++++++++++++- python/tvm/relay/pass_manager.py | 248 --------------------- src/relay/pass/pass_manager.cc | 23 +- tests/python/relay/test_pass_manager.py | 51 ++--- 6 files changed, 345 insertions(+), 324 deletions(-) delete mode 100644 python/tvm/relay/pass_manager.py diff --git a/include/tvm/relay/pass.h b/include/tvm/relay/pass.h index 8464ecc87ffab..b6f8e5166277c 100644 --- a/include/tvm/relay/pass.h +++ b/include/tvm/relay/pass.h @@ -5,17 +5,30 @@ * * This file also implements a pass manager. The pass manager manages a sequence * of Relay-to-Relay transformation passes over a particlar unit of AST. The - * design is largely inspired from LLVM's pass manager. + * design is largely inspired from LLVM's pass manager and modern deep learning + * frameworks that perform tensor->tensor transformations. * - * The responsibilities of a pass manager usually involves: + * The responsibilities of a traditional compiler pass manager usually involves: * - Organizing the execution order of optimization passes though not * necessarily in the optimal sequence. * - Collecting required analysis information and keep them up-to-date. * - Reducing the effort required to implement new passes for compiler * developers, etc. * - * TODO(jroesch, zhiics): We are currently using a very simple design for the - * pass manager, i.e. it executes a specific pass or sequence of passes. + * Similar to LLVM's pass manager, we designed the Relay pass manager to work + * different granularity, i.e. module level, function level, and even sequential + * passe that contains a host of passes. + * + * However, we also extend the functionality of the traditional pass manager + * with the consideration of requirements from deep learning frameworks. Each + * pass in the Relay pass manager performs the Relay.Module -> Relay.Module + * transformation. All different type of passes including the sequential-level + * pass object are essentially a pass object. This design, therefore, + * effectively provides users a consistent and convenient interface, i.e. pass, + * to play with. It offers a means to eases the development and testing Relay + * passes. For example, with the pass manager, external users will be able to + * have custom passes correctly scheduled without having to modify a single + * handcrafted pass order. * * In the future we need to describe constraints between passes. For example, * we may want to preserve dependencies between different passes and validate @@ -26,7 +39,6 @@ #ifndef TVM_RELAY_PASS_H_ #define TVM_RELAY_PASS_H_ -#include #include #include #include @@ -43,11 +55,6 @@ namespace relay { namespace pass { -// Forward declaration -class ModulePass; -class FunctionPass; -class SequentialPass; - // Define pass context. class PassContext; @@ -60,7 +67,7 @@ class PassContextNode : public RelayNode { /*! * \brief The error reporter used to notify users why an optimization fails. */ - ErrorReporter err_reporter_; + ErrorReporter err_reporter; PassContextNode() = default; @@ -73,17 +80,7 @@ class PassContextNode : public RelayNode { TVM_DECLARE_NODE_TYPE_INFO(PassContextNode, RelayNode); }; -class PassContext : public NodeRef { - public: - PassContext() = default; - explicit PassContext(NodePtr p) : NodeRef(p) {} - - const PassContextNode* operator->() const { - return static_cast(this->node_.get()); - } - - using ContainerType = PassContextNode; -}; +TVM_DEFINE_NODE_REF(PassContext, PassContextNode) // We use currying here. It runs on a Relay node type NodeT and yields a new // node with the same type. The Relay module is captured for optimizations as @@ -167,8 +164,8 @@ class Pass : public NodeRef { * * \return The created module pass. */ -ModulePass CreateModulePass(const std::string& name, int opt_level, - const PassFunc& pass_func); +Pass CreateModulePass(const std::string& name, int opt_level, + const PassFunc& pass_func); /* * \brief Create a function pass. @@ -179,8 +176,8 @@ ModulePass CreateModulePass(const std::string& name, int opt_level, * * \return The created function pass. */ -FunctionPass CreateFunctionPass(const std::string& name, int opt_level, - const PassFunc& pass_func); +Pass CreateFunctionPass(const std::string& name, int opt_level, + const PassFunc& pass_func); /* * \brief Create a sequential pass. * @@ -192,9 +189,9 @@ FunctionPass CreateFunctionPass(const std::string& name, int opt_level, * * \return The created sequential pass. */ -SequentialPass CreateSequentialPass(const std::string& name, int opt_level, - const tvm::Array& passes, - const tvm::Array& disabled); +Pass CreateSequentialPass(const std::string& name, int opt_level, + const tvm::Array& passes, + const tvm::Array& disabled); } // namespace pass diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py index f5cbc407b1dc8..8f805e2f117b2 100644 --- a/python/tvm/relay/__init__.py +++ b/python/tvm/relay/__init__.py @@ -11,7 +11,6 @@ from . import ir_pass from .build_module import build, build_config, create_executor, optimize from . import prelude -from . import pass_manager from . import parser from . import debug from . import param_dict @@ -80,7 +79,9 @@ var = expr.var const = expr.const bind = expr.bind -create_pass = pass_manager.create_pass +create_module_pass = ir_pass.create_module_pass +create_function_pass = ir_pass.create_function_pass +create_sequential_pass = ir_pass.create_sequential_pass # ExprFunctor ExprFunctor = expr_functor.ExprFunctor @@ -94,8 +95,8 @@ load_param_dict = param_dict.load_param_dict # Pass manager -PassContext = pass_manager.PassContext -Pass = pass_manager.Pass -ModulePass = pass_manager.ModulePass -FunctionPass = pass_manager.FunctionPass -SequentialPass = pass_manager.SequentialPass +PassContext = ir_pass.PassContext +Pass = ir_pass.Pass +ModulePass = ir_pass.ModulePass +FunctionPass = ir_pass.FunctionPass +SequentialPass = ir_pass.SequentialPass diff --git a/python/tvm/relay/ir_pass.py b/python/tvm/relay/ir_pass.py index 02a6e8b5906e1..336099601038a 100644 --- a/python/tvm/relay/ir_pass.py +++ b/python/tvm/relay/ir_pass.py @@ -1,16 +1,287 @@ # pylint: disable=no-else-return # pylint: disable=unidiomatic-typecheck -"""The set of passes for Relay. +""" +This file contains: +1. The set of passes for Relay, which exposes an interface for configuring the + passes and scripting them in Python. -Exposes an interface for configuring the passes and -scripting them in Python. +2. The pass manager for Relay which exposes different granularity of interfaces + for users to implement and use passes more conveniently. """ +import types + from . import _ir_pass from . import _make from .expr import Expr from .ty import Type +from .base import RelayNode, register_relay_node from .module import Module + +@register_relay_node +class PassContext(RelayNode): + """The basis where a Relay optimization/analysis runs on. + Each pass context contains a number of auxiliary information that is used + to help an optimization pass. Such information includes the error reporter + to record the errors of during the performing the optimization, etc. + """ + + def __init__(self): + self.__init_handle_by_constructor__(_ir_pass.PassContext) + + +@register_relay_node +class Pass(RelayNode): + """The base class of all passes. This class is designed as a pure virtual + class that will be implemented by the subclasses. + + Parameters + ---------- + name : str + The pass name. + + opt_level : int + The optimization level of this pass. + """ + + def set_pass_context(self, pass_ctx): + """Setup the pass context for analysis and optimizations. This context + could be shared by different passes for sequential passes. + + Parameters + ---------- + pass_ctx : PassContext + The context that is used to help perform a certain pass or a series + of passes. + + Returns + ------- + pass : Pass + The updated pass. + """ + if not isinstance(pass_ctx, PassContext): + raise TypeError("pass_ctx is expected to be the PassContext type") + return _ir_pass.SetContext(self, pass_ctx) + + def __call__(self, mod): + """Execute the pass. It is an abstract function that will be + implemented by subclasses. + + Parameters + ---------- + mod : tvm.relay.Module + The module that a certain optimization is performed on. + + Returns + ------- + mod : tvm.relay.Module + The updated module after applying this pass. + """ + raise NotImplementedError("Pure virtual function is not implemented.") + + +@register_relay_node +class ModulePass(Pass): + """A pass that works on tvm.relay.Module. + + Parameters + ---------- + name : str + The pass name. + + opt_level : int + The optimization level of this pass. + + pass_func : Callable[PassContext: tvm.relay.Module -> tvm.relay.Module] + The curried callback that sketches a certain optimization. + """ + + def __init__(self, name, opt_level, pass_func): + self.__init_handle_by_constructor__(_ir_pass.CreateModulePass, name, + opt_level, pass_func) + + def __call__(self, mod): + """Execute a module pass. + + Parameters + ---------- + mod : tvm.relay.Module + The module that the module pass is executed on. + + Returns + ------- + ret : tvm.relay.Module + The updated module. + """ + return _ir_pass.RunModulePass(self, mod) + + +@register_relay_node +class FunctionPass(Pass): + """A pass that works on each tvm.relay.Function in a module. + + Parameters + ---------- + name : str + The pass name. + + opt_level : int + The optimization level of this pass. + + pass_func : Callable[PassContext: tvm.relay.Function -> tvm.relay.Function] + The curried callback that sketches a certain optimization. + """ + + def __init__(self, name, opt_level, pass_func): + self.__init_handle_by_constructor__(_ir_pass.CreateFunctionPass, name, + opt_level, pass_func) + + def __call__(self, mod): + """Execute a function pass. + + Parameters + ---------- + mod : tvm.relay.Module + The module that the function pass is executed on. + + Returns + ------- + ret : tvm.relay.Module + The updated module. + """ + return _ir_pass.RunFunctionPass(self, mod) + + +@register_relay_node +class SequentialPass(Pass): + """A pass that works on each tvm.relay.Function in a module. + + Parameters + ---------- + name : str + The pass name. + + opt_level : int + The optimization level of this pass. + + passes : List[Pass] + The pass candidates to be executed. + + disabled : Optional[List[str]] + The list of passes that are disabled. + """ + + def __init__(self, name, opt_level, passes, disabled=None): + disabled = disabled if disabled else [] + if not isinstance(disabled, (list, tuple)): + raise TypeError("disabled must be a list or tuple of pass names") + self.__init_handle_by_constructor__(_ir_pass.CreateSequentialPass, + name, opt_level, passes, disabled) + + def __call__(self, mod): + """Execute a sequence of passes. + + Parameters + ---------- + mod : tvm.relay.Module + The module that the function pass is executed on. + + Returns + ------- + ret : tvm.relay.Module + The updated module. + """ + return _ir_pass.RunSequentialPass(self, mod) + + +def create_module_pass(pass_name, opt_level, pass_func): + """Create a module pass using a defined optimization function from Python. + + Parameters + ---------- + pass_name : str + The name of the pass. + + opt_level : int + The optimization level of this pass. + + pass_func : Optional[Callable[PassContext: Module/Function -> + Module/Function]] + The implemented optimization pass. + + Returns + ------- + ret : Pass + A module level pass built through pass_func. + """ + if not isinstance(pass_func, (types.FunctionType, types.LambdaType)): + raise TypeError("pass_func must be a callable for Module pass") + + return _ir_pass.CreateModulePass(pass_name, opt_level, pass_func) + + +def create_function_pass(pass_name, opt_level, pass_func): + """Create a function pass using a defined optimization function from + Python. + + Parameters + ---------- + pass_name : str + The name of the pass. + + opt_level : int + The optimization level of this pass. + + pass_func : Optional[Callable[PassContext: Module/Function -> + Module/Function]] + The implemented optimization pass. + + Returns + ------- + ret : Pass + A function level pass built through pass_func. + """ + if not isinstance(pass_func, (types.FunctionType, types.LambdaType)): + raise TypeError("pass_func must be a callable for Module pass") + + return _ir_pass.CreateFunctionPass(pass_name, opt_level, pass_func) + + +def create_sequential_pass(pass_name, opt_level, sequential_passes, + disabled=None): + """Create a sequential pass using a defined optimization function from + Python. + + Parameters + ---------- + pass_name : str + The name of the pass. + + opt_level : int + The optimization level of this pass. + + sequential_passes : Optional[List[Pass]] + A sequence of passes candidate for optimization. + + disabled : Optional[List[str]] + A list of disabled passes. + + Returns + ------- + ret : Pass + A sequential pass built through pass_func. + """ + if not isinstance(sequential_passes, (list, tuple)): + raise TypeError("sequential_passes must be a list of Pass objects.") + + disabled = disabled if disabled else [] + if not isinstance(disabled, (list, tuple)): + raise TypeError("disabled must be a list or tuple of pass names") + + return _ir_pass.CreateSequentialPass(pass_name, opt_level, + sequential_passes, disabled) + + def post_order_visit(expr, fvisit): """Recursively visit the ir in post DFS order node, apply fvisit. Each node is guaranteed to be visited diff --git a/python/tvm/relay/pass_manager.py b/python/tvm/relay/pass_manager.py deleted file mode 100644 index d0fa3e3937ff7..0000000000000 --- a/python/tvm/relay/pass_manager.py +++ /dev/null @@ -1,248 +0,0 @@ -# pylint: disable=no-else-return -# pylint: disable=unidiomatic-typecheck -"""The pass manager for Relay. - -This file exposes differen granularity of interfaces for users to implement and -use passes more conveniently. -""" -from enum import IntEnum - -from . import _ir_pass -from .base import RelayNode, register_relay_node - - -class PassKind(IntEnum): - """The different granularity of passes for optimization/analysis.""" - ModuleKind = 1 - FunctionKind = 2 - SequentialKind = 3 - - -@register_relay_node -class PassContext(RelayNode): - """The basis where a Relay optimization/analysis runs on. - Each pass context contains a number of auxiliary information that is used - to help an optimization pass. Such information includes the error reporter - to record the errors of during the performing the optimization, etc. - """ - - def __init__(self): - self.__init_handle_by_constructor__(_ir_pass.PassContext) - - -@register_relay_node -class Pass(RelayNode): - """The base class of all passes. This class is designed as a pure virtual - class that will be implemented by the subclasses. - - Parameters - ---------- - name : str - The pass name. - - opt_level : int - The optimization level of this pass. - """ - - def set_pass_context(self, pass_ctx): - """Setup the pass context for analysis and optimizations. This context - could be shared by different passes for sequential passes. - - Parameters - ---------- - pass_ctx : PassContext - The context that is used to help perform a certain pass or a series - of passes. - - Returns - ------- - pass : Pass - The updated pass. - """ - if not isinstance(pass_ctx, PassContext): - raise TypeError("pass_ctx is expected to be the PassContext type") - return _ir_pass.SetContext(self, pass_ctx) - - def __call__(self, mod): - """Execute the pass. It is an abstract function that will be - implemented by subclasses. - - Parameters - ---------- - mod : tvm.relay.Module - The module that a certain optimization is performed on. - - Returns - ------- - mod : tvm.relay.Module - The updated module after applying this pass. - """ - raise NotImplementedError("Pure virtual function is not implemented.") - - -@register_relay_node -class ModulePass(Pass): - """A pass that works on tvm.relay.Module. - - Parameters - ---------- - name : str - The pass name. - - opt_level : int - The optimization level of this pass. - - pass_func : Callable[PassContext: tvm.relay.Module -> tvm.relay.Module] - The curried callback that sketches a certain optimization. - """ - - def __init__(self, name, opt_level, pass_func): - self.__init_handle_by_constructor__(_ir_pass.CreateModulePass, name, - opt_level, pass_func) - - def __call__(self, mod): - """Execute a module pass. - - Parameters - ---------- - mod : tvm.relay.Module - The module that the module pass is executed on. - - Returns - ------- - ret : tvm.relay.Module - The updated module. - """ - return _ir_pass.RunModulePass(self, mod) - - -@register_relay_node -class FunctionPass(Pass): - """A pass that works on each tvm.relay.Function in a module. - - Parameters - ---------- - name : str - The pass name. - - opt_level : int - The optimization level of this pass. - - pass_func : Callable[PassContext: tvm.relay.Function -> tvm.relay.Function] - The curried callback that sketches a certain optimization. - """ - - def __init__(self, name, opt_level, pass_func): - self.__init_handle_by_constructor__(_ir_pass.CreateFunctionPass, name, - opt_level, pass_func) - - def __call__(self, mod): - """Execute a function pass. - - Parameters - ---------- - mod : tvm.relay.Module - The module that the function pass is executed on. - - Returns - ------- - ret : tvm.relay.Module - The updated module. - """ - return _ir_pass.RunFunctionPass(self, mod) - - -@register_relay_node -class SequentialPass(Pass): - """A pass that works on each tvm.relay.Function in a module. - - Parameters - ---------- - name : str - The pass name. - - opt_level : int - The optimization level of this pass. - - passes : List[Pass] - The pass candidates to be executed. - - disabled : Optional[List[str]] - The list of passes that are disabled. - """ - - def __init__(self, name, opt_level, passes, disabled=None): - disabled = disabled if disabled else [] - if not isinstance(disabled, (list, tuple)): - raise TypeError("disabled must be a list or tuple of pass names") - self.__init_handle_by_constructor__(_ir_pass.CreateSequentialPass, - name, opt_level, passes, disabled) - - def __call__(self, mod): - """Execute a sequence of passes. - - Parameters - ---------- - mod : tvm.relay.Module - The module that the function pass is executed on. - - Returns - ------- - ret : tvm.relay.Module - The updated module. - """ - return _ir_pass.RunSequentialPass(self, mod) - - -def create_pass(pass_name, opt_level, - pass_kind=PassKind.FunctionKind, - pass_func=None, sequential_passes=None, disabled=None): - """Create a pass using a defined optimization function from Python. - - Parameters - ---------- - pass_name : str - The name of the pass. - - opt_level : int - The optimization level of this pass. - - pass_kind : Optional[PassKind] - The type of pass for optimization/analysis. - - pass_func : Optional[Callable[PassContext: Module/Function/Expr -> - Module/Function/Expr]] - The implemented optimization pass. - - sequential_passes : Optional[List[Pass]] - A sequence of passes candidate for optimization. - - disabled : Optional[List[str]] - A list of disabled passes. - - Returns - ------- - ret : Pass - The pass built through pass_func. - """ - if not isinstance(pass_kind, PassKind): - raise TypeError("pass_kind is expected to be the type of PassKind.") - - if pass_kind == PassKind.ModuleKind: - if not pass_func: - raise TypeError("pass_func must be defined for Module pass") - return _ir_pass.CreateModulePass(pass_name, opt_level, pass_func) - elif pass_kind == PassKind.FunctionKind: - if not pass_func: - raise TypeError("pass_func must be defined for Function pass") - return _ir_pass.CreateFunctionPass(pass_name, opt_level, pass_func) - else: - if not isinstance(sequential_passes, (list, tuple)): - raise TypeError( - "sequential_passes must be a list of Pass objects.") - - disabled = disabled if disabled else [] - if not isinstance(disabled, (list, tuple)): - raise TypeError("disabled must be a list or tuple of pass names") - return _ir_pass.CreateSequentialPass(pass_name, opt_level, - sequential_passes, disabled) diff --git a/src/relay/pass/pass_manager.cc b/src/relay/pass/pass_manager.cc index 4a49c5b01bd02..bc0e26ff67cf3 100644 --- a/src/relay/pass/pass_manager.cc +++ b/src/relay/pass/pass_manager.cc @@ -12,6 +12,8 @@ namespace pass { using tvm::IRPrinter; +class ModulePass; + /*! * \brief Module-level passes are designed to implement global * analysis/optimizations, i.e. interprocedural optimizations (IPO), etc. Passes @@ -68,6 +70,8 @@ class ModulePassNode : public PassNode { RELAY_DEFINE_NODE_REF(ModulePass, ModulePassNode, Pass); +class FunctionPass; + /*! * \brief Function-level passes are used to implement various global * optimizations for a given Relay module. It fetches one function at a time @@ -136,6 +140,8 @@ class FunctionPassNode : public PassNode { RELAY_DEFINE_NODE_REF(FunctionPass, FunctionPassNode, Pass); +class SequentialPass; + /*! * \brief The SequentialPassNode contains a set of passes that transform Relay * programs from one AST to another semantically equivalent one. @@ -324,6 +330,9 @@ SequentialPass SequentialPassNode::make(std::string name, int opt_level, return SequentialPass(n); } +// TODO(jroesch, zhiics): we currenlty only sequentially execute each pass in +// a SequentialPass without the consideration of their orders. The phase +// ordering problem needed to be handled in the future. Module SequentialPassNode::operator()(const Module& module) const { Module mod = module; for (const Pass& pass : passes) { @@ -361,19 +370,19 @@ void SequentialPassNode::SetContext(const PassContext& pass_ctx) { pass_ctx_ = pass_ctx; } -ModulePass CreateModulePass(const std::string& name, int opt_level, - const PassFunc& pass_func) { +Pass CreateModulePass(const std::string& name, int opt_level, + const PassFunc& pass_func) { return ModulePassNode::make(name, opt_level, pass_func); } -FunctionPass CreateFunctionPass(const std::string& name, int opt_level, - const PassFunc& pass_func) { +Pass CreateFunctionPass(const std::string& name, int opt_level, + const PassFunc& pass_func) { return FunctionPassNode::make(name, opt_level, pass_func); } -SequentialPass CreateSequentialPass(const std::string& name, int opt_level, - const tvm::Array& passes, - const tvm::Array& disabled) { +Pass CreateSequentialPass(const std::string& name, int opt_level, + const tvm::Array& passes, + const tvm::Array& disabled) { return SequentialPassNode::make(name, opt_level, passes, disabled); } diff --git a/tests/python/relay/test_pass_manager.py b/tests/python/relay/test_pass_manager.py index 706a19e1002c5..eade9901cbdda 100644 --- a/tests/python/relay/test_pass_manager.py +++ b/tests/python/relay/test_pass_manager.py @@ -5,8 +5,7 @@ from tvm import relay from tvm.relay import ExprFunctor from tvm.relay import Function, Call -from tvm.relay.ir_pass import infer_type, graph_equal -from tvm.relay import pass_manager +from tvm.relay import ir_pass from tvm.relay.testing import ctx_list @@ -103,9 +102,9 @@ def _transform(m): def check_func(func, ref_func): - func = infer_type(func) - ref_func = infer_type(ref_func) - assert graph_equal(func, ref_func) + func = ir_pass.infer_type(func) + ref_func = ir_pass.infer_type(ref_func) + assert ir_pass.graph_equal(func, ref_func) def test_module_pass(): @@ -120,18 +119,16 @@ def test_module_pass(): pass_name = "module_pass_test" opt_level = 0 - pass_kind = pass_manager.PassKind.ModuleKind pass_func = pass_function def test_pass_registration(): - mod_pass = pass_manager.create_pass(pass_name, opt_level, pass_kind, - pass_func=pass_func) - assert isinstance(mod_pass, pass_manager.ModulePass) + mod_pass = ir_pass.create_module_pass(pass_name, opt_level, pass_func) + assert isinstance(mod_pass, ir_pass.ModulePass) assert mod_pass.name == pass_name assert mod_pass.opt_level == opt_level def test_pass_run(): - module_pass = pass_manager.ModulePass(pass_name, opt_level, pass_func) + module_pass = ir_pass.ModulePass(pass_name, opt_level, pass_func) assert pass_name in module_pass.astext() updated_mod = module_pass(mod) @@ -182,7 +179,6 @@ def test_function_pass(): pass_name = "function_pass_test" opt_level = 1 - pass_kind = pass_manager.PassKind.FunctionKind pass_func = pass_function def get_ref_log(): @@ -190,14 +186,14 @@ def get_ref_log(): return ref_log def test_pass_registration(): - function_pass = pass_manager.create_pass(pass_name, opt_level, - pass_func=pass_func) - assert isinstance(function_pass, pass_manager.FunctionPass) + function_pass = ir_pass.create_function_pass(pass_name, opt_level, + pass_func) + assert isinstance(function_pass, ir_pass.FunctionPass) assert function_pass.name == pass_name assert function_pass.opt_level == opt_level def test_pass_run(): - function_pass = pass_manager.FunctionPass(pass_name, opt_level, pass_func) + function_pass = ir_pass.FunctionPass(pass_name, opt_level, pass_func) assert pass_name in function_pass.astext() updated_mod = function_pass(mod) @@ -261,37 +257,33 @@ def get_ref_abs(): # Register a module pass. module_pass_func = pass_function - module_pass = pass_manager.ModulePass("module_pass", 1, module_pass_func) + module_pass = ir_pass.ModulePass("module_pass", 1, module_pass_func) # Register a function pass. function_pass_func = pass_function - function_pass = pass_manager.FunctionPass("function_pass", 2, - function_pass_func) + function_pass = ir_pass.FunctionPass("function_pass", 2, + function_pass_func) def test_pass_registration(): - pass_kind = pass_manager.PassKind.SequentialKind passes = [module_pass, function_pass] pass_name = "sequential_pass" opt_level = 2 - sequential_pass = pass_manager.create_pass(pass_name, opt_level, - pass_kind, - sequential_passes=passes) - assert isinstance(sequential_pass, pass_manager.SequentialPass) + sequential_pass = ir_pass.create_sequential_pass(pass_name, opt_level, + passes) + assert isinstance(sequential_pass, ir_pass.SequentialPass) assert sequential_pass.name == pass_name assert sequential_pass.opt_level == opt_level def test_no_pass(): passes = [] - sequential_pass = pass_manager.SequentialPass("sequential_pass", 1, - passes) + sequential_pass = ir_pass.SequentialPass("sequential_pass", 1, passes) ret_mod = sequential_pass(mod) mod_func = ret_mod[v_sub] check_func(sub, mod_func) def test_only_module_pass(): passes = [module_pass] - sequential_pass = pass_manager.SequentialPass("sequential_pass", 1, - passes) + sequential_pass = ir_pass.SequentialPass("sequential_pass", 1, passes) ret_mod = sequential_pass(mod) # Check the subtract function. sub_var, new_sub = extract_var_func(ret_mod, v_sub.name_hint) @@ -305,8 +297,7 @@ def test_only_module_pass(): def test_only_function_pass(): # Check the subtract function. passes = [function_pass] - sequential_pass = pass_manager.SequentialPass("sequential_pass", 2, - passes) + sequential_pass = ir_pass.SequentialPass("sequential_pass", 2, passes) ret_mod = sequential_pass(mod) _, new_sub = extract_var_func(ret_mod, v_sub.name_hint) check_func(new_sub, get_ref_sub()) @@ -320,7 +311,7 @@ def test_multiple_passes(): # function pass. mod = relay.Module({v_sub: sub, v_log: log}) passes = [module_pass, function_pass] - sequential_pass = pass_manager.SequentialPass("sequential_pass", 2, passes) + sequential_pass = ir_pass.SequentialPass("sequential_pass", 2, passes) ret_mod = sequential_pass(mod) # Check the abs function is added.