diff --git a/include/tvm/relay/pass.h b/include/tvm/relay/pass.h index 31067925fa63..c84e3f952de4 100644 --- a/include/tvm/relay/pass.h +++ b/include/tvm/relay/pass.h @@ -20,46 +20,12 @@ /*! * \file tvm/relay/pass.h * \brief The set of Relay passes written in C++. - * - * This file also implements a pass manager. The pass manager manages a sequence - * of Relay-to-Relay transformation passes over a particlar unit of AST. The - * design is largely inspired from LLVM's pass manager and modern deep learning - * frameworks that perform tensor->tensor transformations. - * - * The responsibilities of a traditional compiler pass manager usually involves: - * - Organizing the execution order of optimization passes though not - * necessarily in the optimal sequence. - * - Collecting required analysis information and keep them up-to-date. - * - Reducing the effort required to implement new passes for compiler - * developers, etc. - * - * Similar to LLVM's pass manager, we designed the Relay pass manager to work - * different granularity, i.e. module level, function level, and even sequential - * passe that contains a host of passes. - * - * However, we also extend the functionality of the traditional pass manager - * with the consideration of requirements/convention from deep learning - * frameworks, such as Pytorch and Gluon, etc. Each pass in the Relay pass - * manager performs the Relay.Module -> Relay.Module transformation. All - * different types of passes, including the sequential-level pass object, are - * essentially pass objects. This design, therefore, effectively provides users - * a consistent and convenient interface, i.e. Pass, to play with. It offers a - * means to ease the development and testing of Relay passes. For example, with - * the pass manager, external users will be able to have custom passes correctly - * scheduled without having to modify a single handcrafted pass order. - * - * In the future we need to describe constraints between passes. For example, - * we may want to preserve dependencies between different passes and validate - * them on the completion of a certain pass. - * - * We also need to store side information and import the error reporting system. - */ + */ #ifndef TVM_RELAY_PASS_H_ #define TVM_RELAY_PASS_H_ #include #include -#include #include #include #include @@ -72,174 +38,6 @@ namespace tvm { namespace relay { -namespace pass { - -/* - * \brief The context of pass. - */ -class PassContext; - -/*! - * \brief PassContextNode contains the information that a pass can rely on, such as - * analysis results. - */ -class PassContextNode : public RelayNode { - public: - /*! - * \brief The error reporter used to notify users why an optimization fails. - */ - ErrorReporter err_reporter; - - PassContextNode() = default; - - void VisitAttrs(tvm::AttrVisitor* v) final { - } - - TVM_DLL static PassContext make(); - - static constexpr const char* _type_key = "relay.PassContext"; - TVM_DECLARE_NODE_TYPE_INFO(PassContextNode, RelayNode); -}; - -TVM_DEFINE_NODE_REF(PassContext, PassContextNode) - -/* - * \brief The meta data of a pass. - * - * PassInfo can be extended conveniently in the future if more meta information - * is needed. - */ -class PassInfo; - -/*! - * \brief PassInfoNode contains meta data that will be used to help optimization - * and analysis. - */ -class PassInfoNode : public RelayNode { - public: - /*! \brief The minimal optimization level that this pass will be enabled. */ - int opt_level; - - /*! \brief The name of an optimization/analysis pass. */ - std::string name; - - /*! \brief The passes that are required to perform the current pass. */ - tvm::Array required; - - PassInfoNode() = default; - - void VisitAttrs(tvm::AttrVisitor* v) final { - v->Visit("opt_level", &opt_level); - v->Visit("name", &name); - v->Visit("required", &required); - } - - TVM_DLL static PassInfo make(int opt_level, std::string name, - tvm::Array required); - - static constexpr const char* _type_key = "relay.PassInfo"; - TVM_DECLARE_NODE_TYPE_INFO(PassInfoNode, RelayNode); -}; - -TVM_DEFINE_NODE_REF(PassInfo, PassInfoNode) - -class Pass; - -/*! - * \brief PassNode is the base type of differnt types of optimization passes. - * It is designed as a pure class and implemented by different pass subclasses - * at different granularity of Relay nodes. - */ -class PassNode : public RelayNode { - public: - /* - * \brief Get the pass information/meta data. */ - virtual PassInfo Info() const = 0; - - /*! - * \brief Set the context information for a pass. - * - * \param pass_ctx The context information for a certain pass. - */ - virtual void SetContext(const PassContext& pass_ctx) = 0; - - /*! - * \brief Execute the optimization pass using a functor. - * - * \param mod The module that an optimization pass runs on. - * - * \return The updated module. - */ - virtual Module operator()(const Module& mod) const = 0; - - void VisitAttrs(tvm::AttrVisitor* v) override {} - - static constexpr const char* _type_key = "relay.Pass"; - TVM_DECLARE_BASE_NODE_INFO(PassNode, RelayNode); -}; - -class Pass : public NodeRef { - public: - Pass() = default; - explicit Pass(NodePtr p) : NodeRef(p) {} - - PassNode* operator->() const { - return static_cast(this->node_.get()); - } - - using ContainerType = PassNode; -}; - -/* - * \brief Create a module pass. - * - * \param pass_func The packed function that contains the optimization. - * \param opt_level The optimization level of the module pass. - * \param name The name of the module pass. - * \param required The list of the passes that the module pass is dependent on. - * - * \return The created module pass. - */ -Pass CreateModulePass( - const runtime::TypedPackedFunc& pass_func, - int opt_level, - const std::string& name, - const tvm::Array& required); - -/* - * \brief Create a function pass. - * - * \param pass_func The packed function that contains the optimization. - * \param opt_level The optimization level of the function pass. - * \param name The name of the function pass. - * \param required The list of the passes that the function pass is dependent on. - * - * \return The created function pass. - */ -Pass CreateFunctionPass( - const runtime::TypedPackedFunc& pass_func, - int opt_level, - const std::string& name, - const tvm::Array& required); -/* - * \brief Create a sequential pass. - * - * \param passes The optimization passes will be performed. - * \param opt_level The optimization level of the sequential pass. - * \param name The name of the sequential pass. - * \param required The list of the passes that the sequential pass is dependent on. - * \param disabled The disabled passes. - * - * \return The created sequential pass. - */ -Pass CreateSequentialPass(const tvm::Array& passes, - int opt_level, - const std::string& name, - const tvm::Array& required, - const tvm::Array& disabled); - -} // namespace pass - /*! * \brief Infer the type of an expression. * diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h new file mode 100644 index 000000000000..ba25483dfbb2 --- /dev/null +++ b/include/tvm/relay/transform.h @@ -0,0 +1,243 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relay/transform.h + * + * This file implements a pass manager. The pass manager manages a sequence + * of Relay-to-Relay transformation passes over a particlar unit of AST. The + * design is largely inspired from LLVM's pass manager and modern deep learning + * frameworks that perform tensor->tensor transformations. + * + * The responsibilities of a traditional compiler pass manager usually involves: + * - Organizing the execution order of optimization passes though not + * necessarily in the optimal sequence. + * - Collecting required analysis information and keep them up-to-date. + * - Reducing the effort required to implement new passes for compiler + * developers, etc. + * + * Similar to LLVM's pass manager, we designed the Relay pass manager to work + * different granularity, i.e. module level, function level, and even sequential + * passe that contains a host of passes. + * + * However, we also extend the functionality of the traditional pass manager + * with the consideration of requirements/convention from deep learning + * frameworks, such as Pytorch and Gluon, etc. Each pass in the Relay pass + * manager performs the Relay.Module -> Relay.Module transformation. All + * different types of passes, including the sequential-level pass object, are + * essentially pass objects. This design, therefore, effectively provides users + * a consistent and convenient interface, i.e. Pass, to play with. It offers a + * means to ease the development and testing of Relay passes. For example, with + * the pass manager, external users will be able to have custom passes correctly + * scheduled without having to modify a single handcrafted pass order. + * + * In the future we need to describe constraints between passes. For example, + * we may want to preserve dependencies between different passes and validate + * them on the completion of a certain pass. + * + * We also need to store side information and import the error reporting system. + */ +#ifndef TVM_RELAY_TRANSFORM_H_ +#define TVM_RELAY_TRANSFORM_H_ + +#include +#include +#include +#include +#include +#include + +namespace tvm { +namespace relay { +namespace transform { + +/* + * \brief The context of pass. + */ +class PassContext; + +/*! + * \brief PassContextNode contains the information that a pass can rely on, such as + * analysis results. + */ +class PassContextNode : public RelayNode { + public: + /*! + * \brief The error reporter used to notify users why an optimization fails. + */ + ErrorReporter err_reporter; + + PassContextNode() = default; + + void VisitAttrs(tvm::AttrVisitor* v) final { + } + + TVM_DLL static PassContext make(); + + static constexpr const char* _type_key = "relay.PassContext"; + TVM_DECLARE_NODE_TYPE_INFO(PassContextNode, RelayNode); +}; + +TVM_DEFINE_NODE_REF(PassContext, PassContextNode) + +/* + * \brief The meta data of a pass. + * + * PassInfo can be extended conveniently in the future if more meta information + * is needed. + */ +class PassInfo; + +/*! + * \brief PassInfoNode contains meta data that will be used to help optimization + * and analysis. + */ +class PassInfoNode : public RelayNode { + public: + /*! \brief The minimal optimization level that this pass will be enabled. */ + int opt_level; + + /*! \brief The name of an optimization/analysis pass. */ + std::string name; + + /*! \brief The passes that are required to perform the current pass. */ + tvm::Array required; + + PassInfoNode() = default; + + void VisitAttrs(tvm::AttrVisitor* v) final { + v->Visit("opt_level", &opt_level); + v->Visit("name", &name); + v->Visit("required", &required); + } + + TVM_DLL static PassInfo make(int opt_level, std::string name, + tvm::Array required); + + static constexpr const char* _type_key = "relay.PassInfo"; + TVM_DECLARE_NODE_TYPE_INFO(PassInfoNode, RelayNode); +}; + +TVM_DEFINE_NODE_REF(PassInfo, PassInfoNode) + +class Pass; + +/*! + * \brief PassNode is the base type of differnt types of optimization passes. + * It is designed as a pure class and implemented by different pass subclasses + * at different granularity of Relay nodes. + */ +class PassNode : public RelayNode { + public: + /* + * \brief Get the pass information/meta data. */ + virtual PassInfo Info() const = 0; + + /*! + * \brief Set the context information for a pass. + * + * \param pass_ctx The context information for a certain pass. + */ + virtual void SetContext(const PassContext& pass_ctx) = 0; + + /*! + * \brief Execute the optimization pass using a functor. + * + * \param mod The module that an optimization pass runs on. + * + * \return The updated module. + */ + virtual Module operator()(const Module& mod) const = 0; + + void VisitAttrs(tvm::AttrVisitor* v) override {} + + static constexpr const char* _type_key = "relay.Pass"; + TVM_DECLARE_BASE_NODE_INFO(PassNode, RelayNode); +}; + +class Pass : public NodeRef { + public: + Pass() = default; + explicit Pass(NodePtr p) : NodeRef(p) {} + + PassNode* operator->() const { + return static_cast(this->node_.get()); + } + + using ContainerType = PassNode; +}; + +class SequentialNode; + +class Sequential : public Pass { + public: + /*! + * \brief The constructor of `Sequential`. + * \param passes The passes to apply. + * \param pass_info The pass metadata. + * \param disabled The passes that will not be applied. + */ + TVM_DLL Sequential(tvm::Array passes, + PassInfo pass_info, + tvm::Array disabled); + Sequential() = default; + explicit Sequential(tvm::NodePtr<::tvm::Node> n) : Pass(n) {} + + const SequentialNode* operator->() const; + using ContainerType = Sequential; +}; + + +/* + * \brief Create a module pass. + * + * \param pass_func The packed function that contains the optimization. + * \param opt_level The optimization level of the module pass. + * \param name The name of the module pass. + * \param required The list of the passes that the module pass is dependent on. + * + * \return The created module pass. + */ +Pass CreateModulePass( + const runtime::TypedPackedFunc& pass_func, + int opt_level, + const std::string& name, + const tvm::Array& required); + +/* + * \brief Create a function pass. + * + * \param pass_func The packed function that contains the optimization. + * \param opt_level The optimization level of the function pass. + * \param name The name of the function pass. + * \param required The list of the passes that the function pass is dependent on. + * + * \return The created function pass. + */ +Pass CreateFunctionPass( + const runtime::TypedPackedFunc& pass_func, + int opt_level, + const std::string& name, + const tvm::Array& required); + +} // namespace transform +} // namespace relay +} // namespace tvm + +#endif // TVM_RELAY_TRANSFORM_H_ diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py index 1f1e4a683ead..d832c8988795 100644 --- a/python/tvm/relay/__init__.py +++ b/python/tvm/relay/__init__.py @@ -25,6 +25,7 @@ from . import module from . import adt from . import ir_pass +from . import transform from .build_module import build, build_config, create_executor from . import prelude from . import parser @@ -97,9 +98,8 @@ var = expr.var const = expr.const bind = expr.bind -module_pass = ir_pass.module_pass -function_pass = ir_pass.function_pass -sequential_pass = ir_pass.sequential_pass +module_pass = transform.module_pass +function_pass = transform.function_pass # ExprFunctor ExprFunctor = expr_functor.ExprFunctor @@ -114,9 +114,9 @@ load_param_dict = param_dict.load_param_dict # Pass manager -PassInfo = ir_pass.PassInfo -PassContext = ir_pass.PassContext -Pass = ir_pass.Pass -ModulePass = ir_pass.ModulePass -FunctionPass = ir_pass.FunctionPass -SequentialPass = ir_pass.SequentialPass +PassInfo = transform.PassInfo +PassContext = transform.PassContext +Pass = transform.Pass +ModulePass = transform.ModulePass +FunctionPass = transform.FunctionPass +Sequential = transform.Sequential diff --git a/python/tvm/relay/_ir_pass.pyi b/python/tvm/relay/_ir_pass.pyi index 6aedb5248657..13035bb36f71 100644 --- a/python/tvm/relay/_ir_pass.pyi +++ b/python/tvm/relay/_ir_pass.pyi @@ -17,62 +17,8 @@ import tvm from . import ir -from .base import NodeBase from .env import Module - -class PassContext(NodeBase): - def __init__(self): - ... - -class PassInfo(NodeBase): - name = ... # type: str - opt_level = ... # type: int - required = ... # type: list - - def __init__(self, name, opt_level, required) - # type: (str, int, list) -> None - - -class Pass(NodeBase): - def __init__(self): - ... - - -class ModulePass(Pass): - name = ... # type: str - opt_level = ... # type: int - pass_func = ... # type: Callable - required = ... # type: list - - def __init__(self, name, opt_level, pass_func, required): - # type: (str, int, Callable, list) -> None - ... - - -class FunctionPass(Pass): - name = ... # type: str - opt_level = ... # type: int - pass_func = ... # type: Callable - required = ... # type: list - - def __init__(self, name, opt_level, pass_func, required): - # type: (str, int, Callable, list) -> None - ... - - -class SequentialPass(Pass): - name = ... # type: str - opt_level = ... # type: int - passes = ... # type: list - required = ... # type: list - disabled = ... # type: list - - def __init__(self, name, opt_level, passes, required, disabled): - # type: (str, int, list, list, list) -> None - ... - - def check_expr(env: Module, expr: ir.Expr) -> ir.Type: ... def generalize(env: Module, expr: ir.Expr) -> ir.Expr: ... def _get_checked_type(expr: ir.Expr) -> ir.Type: ... diff --git a/python/tvm/relay/_transform.py b/python/tvm/relay/_transform.py new file mode 100644 index 000000000000..273d97e0962a --- /dev/null +++ b/python/tvm/relay/_transform.py @@ -0,0 +1,21 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""FFI exposing the Relay type inference and checking.""" + +from tvm._ffi.function import _init_api + +_init_api("relay._transform", __name__) diff --git a/python/tvm/relay/ir_pass.py b/python/tvm/relay/ir_pass.py index 5f23e14d5559..ea34c6b1958b 100644 --- a/python/tvm/relay/ir_pass.py +++ b/python/tvm/relay/ir_pass.py @@ -17,324 +17,16 @@ # pylint: disable=no-else-return # pylint: disable=unidiomatic-typecheck """ -This file contains: -1. The set of passes for Relay, which exposes an interface for configuring the - passes and scripting them in Python. - -2. The pass manager for Relay which exposes different granularity of interfaces - for users to implement and use passes more conveniently. +This file contains the set of passes for Relay, which exposes an interface for +configuring the passes and scripting them in Python. """ -import types - from . import _ir_pass from . import _make from .expr import Expr from .ty import Type -from .base import RelayNode, register_relay_node from .module import Module -@register_relay_node -class PassInfo(RelayNode): - """The class that contains the meta data required by a pass. It is the - container of information needed by running an optimization or analysis. - This class can be extended by adding new members when more meta data is - needed. - - Parameters - ---------- - name : str - The pass name. - - opt_level : int - The optimization level of this pass. - - required : List[str] - The list of passes that are required by a certain pass. - """ - - def __init__(self, name, opt_level, required=None): - self.__init_handle_by_constructor__(_ir_pass.PassInfo, name, opt_level, - required) - - -@register_relay_node -class PassContext(RelayNode): - """The basis where a Relay optimization/analysis runs on. - Each pass context contains a number of auxiliary information that is used - to help an optimization pass. Such information includes the error reporter - to record the errors of during the optimization, etc. - """ - - def __init__(self): - self.__init_handle_by_constructor__(_ir_pass.PassContext) - - -@register_relay_node -class Pass(RelayNode): - """The base class of all passes. All methods here are just simple wrappers - that are implemented in the backend. They are defined for users to - conveniently interact with the base class. - """ - - def set_pass_context(self, pass_ctx): - """Setup the pass context for analysis and optimizations. This context - could be shared by different passes for sequential passes. - - Parameters - ---------- - pass_ctx : PassContext - The context that is used to help perform a certain pass or a series - of passes. - """ - if not isinstance(pass_ctx, PassContext): - raise TypeError("pass_ctx is expected to be the PassContext type") - _ir_pass.SetContext(self, pass_ctx) - - @property - def info(self): - """Get the pass meta.""" - return _ir_pass.Info(self) - - def __call__(self, mod): - """Execute the pass. Note that for sequential pass, the dependency among - different passes will be resolved in the backend. - - Parameters - ---------- - mod : tvm.relay.Module - The module that a certain optimization is performed on. - - Returns - ------- - mod : tvm.relay.Module - The updated module after applying this pass. - """ - return _ir_pass.RunPass(self, mod) - - -@register_relay_node -class ModulePass(Pass): - """A pass that works on tvm.relay.Module. Users don't need to interact with - this class directly. Instead, a module pass should be created through - `module_pass`, because the design of the `module_pass` API is flexible - enough to handle the creation of a module pass in different manners. In - addition, all members of a module pass can be accessed from the base class. - The same rule applies to FunctionPass and SequentialPass as well. - """ - - -@register_relay_node -class FunctionPass(Pass): - """A pass that works on each tvm.relay.Function in a module. A function - pass class should be created through `function_pass`. - """ - - -@register_relay_node -class SequentialPass(Pass): - """A pass that works on a sequence of pass objects. A sequential pass class - should be created through `sequential_pass`. - """ - - -def module_pass(pass_func=None, opt_level=None, name=None, required=None): - """Create a module pass. This function returns a callback when pass_func - is provided. Otherwise, it returns the created module level pass using the - given optimization function. - - Parameters - ---------- - pass_func : Optional[Callable[(Module/Function, PassContext) -> - Module/Function]] - The implemented optimization pass. - - opt_level : int - The optimization level of this module pass. - - name : Optional[str] - The name of the module pass. The name could be empty. In this case, the - name of the optimization function will be used as the pass name. - - required : Optional[List[str]] - The list of passes that the module pass is dependent on. - - Returns - ------- - create_module_pass : Union[Callable, ModulePass] - The callable that will create a module pass is returned when - pass_func is not passed in. Otherwise, a ModulePass object will be - directly created. - - Examples - -------- - The following code creates a module level pass and adds an abs function to - the module. - - .. code-block:: python - - @relay.ir_pass.module_pass(opt_level=2) - def transform(mod, ctx): - tp = relay.TensorType((10,), "float32") - x = relay.var("x", tp) - gv = relay.GlobalVar("var") - func = relay.Function([x], relay.abs(x)) - new_mod = relay.Module({gv: func}) - new_mod.update(mod) - return new_mod - - module_pass = transform - assert isinstance(module_pass, ir_pass.ModulePass) - assert module_pass.info.opt_level == 2 - - # Given a module m, the optimization could be invoked as the follwoing: - updated_mod = module_pass(m) - # Now a function abs should be added to the module m. - """ - - if opt_level is None: - raise ValueError("Please provide opt_level for the module pass.") - - required = required if required else [] - if not isinstance(required, (list, tuple)): - raise TypeError("Required is expected to be the type of " + - "list/tuple.") - - def create_module_pass(pass_func): - """Internal function that creates a module pass""" - if not isinstance(pass_func, (types.FunctionType, types.LambdaType)): - raise TypeError("pass_func must be a callable for Module pass") - - return _ir_pass.CreateModulePass(pass_func, opt_level, - name if name else pass_func.__name__, - required) - - if pass_func: - return create_module_pass(pass_func) - return create_module_pass - - -def function_pass(pass_func=None, opt_level=None, name=None, required=None): - """Create a function pass. This function returns a callback when pass_func - is provided. Otherwise, it returns the created function pass using the - given optimization function. - - Parameters - ---------- - pass_func : Optional[Callable[(Module/Function, PassContext) -> - Module/Function]] - The implemented optimization pass. - - opt_level : int - The optimization level of this module pass. - - name : Optional[str] - The name of the function pass. The name could be empty. In this case, the - name of the optimization function will be used as the pass name. - - required : Optional[List[str]] - The list of passes that the module pass is dependent on. - - Returns - ------- - create_function_pass : Union[Callable, FunctionPass] - The callable that will create a function pass is returned when - pass_func is not passed in. Otherwise, a FunctionPass object will be - created. - - Examples - -------- - The following code creates a function level pass that performs constant - folding. - - .. code-block:: python - - @relay.ir_pass.function_pass(opt_level=2) - def transform(func, ctx): - return ir_pass.fold_constant(func) - - function_pass = transform - assert isinstance(function_pass, ir_pass.FunctionPass) - assert function_pass.info.opt_level == 2 - - # Given a module m, the optimization could be invoked as the follwoing: - updated_mod = function_pass(m) - # Now constant folding should have been applied to every function in - # the provided module m. And the updated module will be returned. - """ - - if opt_level is None: - raise ValueError("Please provide opt_level for the funtion pass.") - - required = required if required else [] - if not isinstance(required, (list, tuple)): - raise TypeError("Required is expected to be the type of " + - "list/tuple.") - - def create_function_pass(pass_func): - """Internal function that creates a function pass""" - if not isinstance(pass_func, (types.FunctionType, types.LambdaType)): - raise TypeError("pass_func must be a callable for Module pass") - - return _ir_pass.CreateFunctionPass(pass_func, opt_level, - name if name else pass_func.__name__, - required) - - if pass_func: - return create_function_pass(pass_func) - return create_function_pass - - -def sequential_pass(passes=None, opt_level=2, name="sequential_pass", - required=None, disabled=None): - """Create a sequential pass using a defined optimization function from - Python. Some typical usage of the sequential pass are: - 1. Users provide a list of passes for optimization. - 2. Only an optimization level is provided so that the backend system has - to glob all passes at this level and below to perform the optimizations. - Note that users can also provide a series of passes that they don't want to - apply when running a sequential pass. Pass dependency will be resolved in - the backend as well. - - Parameters - ---------- - passes : Optional[List[Pass]] - A sequence of passes candidate for optimization. - - opt_level : Optional[int] - The optimization level of this sequential pass. - - name : Optional[str] - The name of the sequential pass. - - required : Optional[List[str]] - The list of passes that the sequential pass is dependent on. - - disabled : Optional[List[str]] - A list of disabled passes. - - Returns - ------- - ret : Pass - A sequential pass built through pass_func. - """ - - passes = passes if passes else [] - if not isinstance(passes, (list, tuple)): - raise TypeError("passes must be a list of Pass objects.") - - disabled = disabled if disabled else [] - if not isinstance(disabled, (list, tuple)): - raise TypeError("disabled must be a list or tuple of pass names") - - required = required if required else [] - if not isinstance(required, (list, tuple)): - raise TypeError("Required is expected to be the type of list/tuple.") - - return _ir_pass.CreateSequentialPass(passes, opt_level, name, required, - disabled) - - def post_order_visit(expr, fvisit): """Recursively visit the ir in post DFS order node, apply fvisit. Each node is guaranteed to be visited diff --git a/python/tvm/relay/transform.py b/python/tvm/relay/transform.py new file mode 100644 index 000000000000..877538afea34 --- /dev/null +++ b/python/tvm/relay/transform.py @@ -0,0 +1,325 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=no-else-return +# pylint: disable=unidiomatic-typecheck +""" +This file contains the pass manager for Relay which exposes different +granularity of interfaces for users to implement and use passes more +conveniently. +""" +import types + +from . import _transform +from .base import RelayNode, register_relay_node + + +@register_relay_node +class PassInfo(RelayNode): + """The class that contains the meta data required by a pass. It is the + container of information needed by running an optimization or analysis. + This class can be extended by adding new members when more meta data is + needed. + + Parameters + ---------- + name : str + The pass name. + + opt_level : int + The optimization level of this pass. + + required : List[str] + The list of passes that are required by a certain pass. + """ + + def __init__(self, name, opt_level, required=None): + self.__init_handle_by_constructor__(_transform.PassInfo, name, opt_level, + required) + + +@register_relay_node +class PassContext(RelayNode): + """The basis where a Relay optimization/analysis runs on. + Each pass context contains a number of auxiliary information that is used + to help an optimization pass. Such information includes the error reporter + to record the errors of during the optimization, etc. + """ + + def __init__(self): + self.__init_handle_by_constructor__(_transform.PassContext) + + +@register_relay_node +class Pass(RelayNode): + """The base class of all passes. All methods here are just simple wrappers + that are implemented in the backend. They are defined for users to + conveniently interact with the base class. + """ + + def set_pass_context(self, pass_ctx): + """Setup the pass context for analysis and optimizations. This context + could be shared by different passes for sequential passes. + + Parameters + ---------- + pass_ctx : PassContext + The context that is used to help perform a certain pass or a series + of passes. + """ + if not isinstance(pass_ctx, PassContext): + raise TypeError("pass_ctx is expected to be the PassContext type") + _transform.SetContext(self, pass_ctx) + + @property + def info(self): + """Get the pass meta.""" + return _transform.Info(self) + + def __call__(self, mod): + """Execute the pass. Note that for sequential pass, the dependency among + different passes will be resolved in the backend. + + Parameters + ---------- + mod : tvm.relay.Module + The module that a certain optimization is performed on. + + Returns + ------- + mod : tvm.relay.Module + The updated module after applying this pass. + """ + return _transform.RunPass(self, mod) + + +@register_relay_node +class ModulePass(Pass): + """A pass that works on tvm.relay.Module. Users don't need to interact with + this class directly. Instead, a module pass should be created through + `module_pass`, because the design of the `module_pass` API is flexible + enough to handle the creation of a module pass in different manners. In + addition, all members of a module pass can be accessed from the base class. + The same rule applies to FunctionPass and Sequential as well. + """ + + +@register_relay_node +class FunctionPass(Pass): + """A pass that works on each tvm.relay.Function in a module. A function + pass class should be created through `function_pass`. + """ + + +@register_relay_node +class Sequential(Pass): + """A pass that works on a sequence of pass objects. Multiple passes can be + executed sequentially using this class. + + Some typical usage of the sequential pass are: + 1. Users provide a list of passes for optimization. + 2. Only an optimization level is provided so that the backend system has + to glob all passes at this level and below to perform the optimizations. + Note that users can also provide a series of passes that they don't want to + apply when running a sequential pass. Pass dependency will be resolved in + the backend as well. + + Parameters + ---------- + passes : Optional[List[Pass]] + A sequence of passes candidate for optimization. + + opt_level : Optional[int] + The optimization level of this sequential pass. + + name : Optional[str] + The name of the sequential pass. + + required : Optional[List[str]] + The list of passes that the sequential pass is dependent on. + + disabled : Optional[List[str]] + A list of disabled passes. + """ + + def __init__(self, + passes=None, + opt_level=2, + name="sequential", + required=None, + disabled=None): + passes = passes if passes else [] + if not isinstance(passes, (list, tuple)): + raise TypeError("passes must be a list of Pass objects.") + + disabled = disabled if disabled else [] + if not isinstance(disabled, (list, tuple)): + raise TypeError("disabled must be a list or tuple of pass names") + + required = required if required else [] + if not isinstance(required, (list, tuple)): + raise TypeError("Required is expected to be the type of list/tuple.") + + self.__init_handle_by_constructor__(_transform.Sequential, + passes, opt_level, name, required, + disabled) + + +def module_pass(pass_func=None, opt_level=None, name=None, required=None): + """Create a module pass. This function returns a callback when pass_func + is provided. Otherwise, it returns the created module level pass using the + given optimization function. + + Parameters + ---------- + pass_func : Optional[Callable[(Module/Function, PassContext) -> + Module/Function]] + The implemented optimization pass. + + opt_level : int + The optimization level of this module pass. + + name : Optional[str] + The name of the module pass. The name could be empty. In this case, the + name of the optimization function will be used as the pass name. + + required : Optional[List[str]] + The list of passes that the module pass is dependent on. + + Returns + ------- + create_module_pass : Union[Callable, ModulePass] + The callable that will create a module pass is returned when + pass_func is not passed in. Otherwise, a ModulePass object will be + directly created. + + Examples + -------- + The following code creates a module level pass and adds an abs function to + the module. + + .. code-block:: python + + @relay.transform.module_pass(opt_level=2) + def transform(mod, ctx): + tp = relay.TensorType((10,), "float32") + x = relay.var("x", tp) + gv = relay.GlobalVar("var") + func = relay.Function([x], relay.abs(x)) + new_mod = relay.Module({gv: func}) + new_mod.update(mod) + return new_mod + + module_pass = transform + assert isinstance(module_pass, transform.ModulePass) + assert module_pass.info.opt_level == 2 + + # Given a module m, the optimization could be invoked as the follwoing: + updated_mod = module_pass(m) + # Now a function abs should be added to the module m. + """ + + if opt_level is None: + raise ValueError("Please provide opt_level for the module pass.") + + required = required if required else [] + if not isinstance(required, (list, tuple)): + raise TypeError("Required is expected to be the type of " + + "list/tuple.") + + def create_module_pass(pass_func): + """Internal function that creates a module pass""" + if not isinstance(pass_func, (types.FunctionType, types.LambdaType)): + raise TypeError("pass_func must be a callable for Module pass") + + return _transform.CreateModulePass( + pass_func, opt_level, name if name else pass_func.__name__, + required) + + if pass_func: + return create_module_pass(pass_func) + return create_module_pass + + +def function_pass(pass_func=None, opt_level=None, name=None, required=None): + """Create a function pass. This function returns a callback when pass_func + is provided. Otherwise, it returns the created function pass using the + given optimization function. + + Parameters + ---------- + pass_func : Optional[Callable[(Module/Function, PassContext) -> + Module/Function]] + The implemented optimization pass. + + opt_level : int + The optimization level of this module pass. + + name : Optional[str] + The name of the function pass. The name could be empty. In this case, the + name of the optimization function will be used as the pass name. + + required : Optional[List[str]] + The list of passes that the module pass is dependent on. + + Returns + ------- + create_function_pass : Union[Callable, FunctionPass] + The callable that will create a function pass is returned when + pass_func is not passed in. Otherwise, a FunctionPass object will be + created. + + Examples + -------- + The following code creates a function level pass that performs constant + folding. + + .. code-block:: python + + @relay.transform.function_pass(opt_level=2) + def transform(func, ctx): + return ir_pass.fold_constant(func) + + function_pass = transform + assert isinstance(function_pass, transform.FunctionPass) + assert function_pass.info.opt_level == 2 + + # Given a module m, the optimization could be invoked as the follwoing: + updated_mod = function_pass(m) + # Now constant folding should have been applied to every function in + # the provided module m. And the updated module will be returned. + """ + + if opt_level is None: + raise ValueError("Please provide opt_level for the funtion pass.") + + required = required if required else [] + if not isinstance(required, (list, tuple)): + raise TypeError("Required is expected to be the type of " + + "list/tuple.") + + def create_function_pass(pass_func): + """Internal function that creates a function pass""" + if not isinstance(pass_func, (types.FunctionType, types.LambdaType)): + raise TypeError("pass_func must be a callable for Module pass") + + return _transform.CreateFunctionPass( + pass_func, opt_level, name if name else pass_func.__name__, + required) + + if pass_func: + return create_function_pass(pass_func) + return create_function_pass diff --git a/python/tvm/relay/transform.pyi b/python/tvm/relay/transform.pyi new file mode 100644 index 000000000000..343e89976b09 --- /dev/null +++ b/python/tvm/relay/transform.pyi @@ -0,0 +1,71 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import tvm +from .base import NodeBase + + +class PassContext(NodeBase): + def __init__(self): + ... + +class PassInfo(NodeBase): + name = ... # type: str + opt_level = ... # type: int + required = ... # type: list + + def __init__(self, name, opt_level, required) + # type: (str, int, list) -> None + + +class Pass(NodeBase): + def __init__(self): + ... + + +class ModulePass(Pass): + name = ... # type: str + opt_level = ... # type: int + pass_func = ... # type: Callable + required = ... # type: list + + def __init__(self, name, opt_level, pass_func, required): + # type: (str, int, Callable, list) -> None + ... + + +class FunctionPass(Pass): + name = ... # type: str + opt_level = ... # type: int + pass_func = ... # type: Callable + required = ... # type: list + + def __init__(self, name, opt_level, pass_func, required): + # type: (str, int, Callable, list) -> None + ... + + +class Sequential(Pass): + name = ... # type: str + opt_level = ... # type: int + passes = ... # type: list + required = ... # type: list + disabled = ... # type: list + + def __init__(self, name, opt_level, passes, required, disabled): + # type: (str, int, list, list, list) -> None + ... diff --git a/src/relay/pass/pass_manager.cc b/src/relay/pass/pass_manager.cc index d607247b3bc8..a105b692aa9d 100644 --- a/src/relay/pass/pass_manager.cc +++ b/src/relay/pass/pass_manager.cc @@ -23,11 +23,11 @@ * \brief Relay pass manager implementation. */ #include -#include +#include namespace tvm { namespace relay { -namespace pass { +namespace transform { using tvm::IRPrinter; @@ -169,17 +169,15 @@ class FunctionPassNode : public PassNode { RELAY_DEFINE_NODE_REF(FunctionPass, FunctionPassNode, Pass); -class SequentialPass; - /*! - * \brief The SequentialPassNode contains a set of passes that transform Relay + * \brief The SequentialNode contains a set of passes that transform Relay * programs from one AST to another semantically equivalent one. * * One example of this level of pass is that the pass manager needs to correctly * perform a host of optimizations with a given optimization level and disabled * passes. */ -class SequentialPassNode : public PassNode { +class SequentialNode : public PassNode { public: /* \brief The pass meta data.*/ PassInfo pass_info; @@ -212,10 +210,6 @@ class SequentialPassNode : public PassNode { passes.push_back(pass); } - TVM_DLL static SequentialPass make(tvm::Array passes, - PassInfo pass_info, - tvm::Array disabled); - /*! * \brief Resolve the pass dependency. It globs all required passes by * a given pass and executes them. @@ -251,8 +245,8 @@ class SequentialPassNode : public PassNode { */ void SetContext(const PassContext& pass_ctx) final; - static constexpr const char* _type_key = "relay.SequentialPass"; - TVM_DECLARE_NODE_TYPE_INFO(SequentialPassNode, PassNode); + static constexpr const char* _type_key = "relay.Sequential"; + TVM_DECLARE_NODE_TYPE_INFO(SequentialNode, PassNode); private: /*! @@ -261,8 +255,6 @@ class SequentialPassNode : public PassNode { PassContext pass_ctx_; }; -RELAY_DEFINE_NODE_REF(SequentialPass, SequentialPassNode, Pass); - PassInfo PassInfoNode::make(int opt_level, std::string name, tvm::Array required) { auto pass_info = make_node(); @@ -350,20 +342,24 @@ bool FunctionPassNode::SkipFunction(const Function& func) const { return pval && pval->value != 0; } -SequentialPass SequentialPassNode::make(tvm::Array passes, - PassInfo pass_info, - tvm::Array disabled) { - auto n = make_node(); +Sequential::Sequential(tvm::Array passes, + PassInfo pass_info, + tvm::Array disabled) { + auto n = make_node(); n->passes = std::move(passes); n->pass_info = std::move(pass_info); n->disabled = std::move(disabled); - return SequentialPass(n); + node_ = std::move(n); +} + +const SequentialNode* Sequential::operator->() const { + return static_cast(this->node_.get()); } // TODO(jroesch, zhiics): we currenlty only sequentially execute each pass in -// a SequentialPass without the consideration of their orders. The phase +// a Sequential without the consideration of their orders. The phase // ordering problem needed to be handled in the future. -Module SequentialPassNode::operator()(const Module& module) const { +Module SequentialNode::operator()(const Module& module) const { Module mod = module; for (const Pass& pass : passes) { CHECK(pass.defined()) << "Found undefined pass for optimization."; @@ -373,7 +369,7 @@ Module SequentialPassNode::operator()(const Module& module) const { return mod; } -void SequentialPassNode::ResolveDependency(const Module& mod) { +void SequentialNode::ResolveDependency(const Module& mod) { // TODO(zhiics) Implement it. // 1. Consider the required passes for each pass. // 2. Only resolve the enabled passes. @@ -382,7 +378,7 @@ void SequentialPassNode::ResolveDependency(const Module& mod) { << "\n"; } -std::vector SequentialPassNode::DisabledPasses() const { +std::vector SequentialNode::DisabledPasses() const { std::vector ret; for (const auto& it : disabled) { const auto* str = it.as(); @@ -392,7 +388,7 @@ std::vector SequentialPassNode::DisabledPasses() const { return ret; } -void SequentialPassNode::SetContext(const PassContext& pass_ctx) { +void SequentialNode::SetContext(const PassContext& pass_ctx) { pass_ctx_ = pass_ctx; } @@ -414,21 +410,12 @@ Pass CreateFunctionPass( return FunctionPassNode::make(pass_func, pass_info); } -Pass CreateSequentialPass(const tvm::Array& passes, - int opt_level, - const std::string& name, - const tvm::Array& required, - const tvm::Array& disabled) { - PassInfo pass_info = PassInfoNode::make(opt_level, name, required); - return SequentialPassNode::make(passes, pass_info, disabled); -} - TVM_REGISTER_NODE_TYPE(PassInfoNode); -TVM_REGISTER_API("relay._ir_pass.PassInfo") +TVM_REGISTER_API("relay._transform.PassInfo") .set_body_typed(PassInfoNode::make); -TVM_REGISTER_API("relay._ir_pass.Info") +TVM_REGISTER_API("relay._transform.Info") .set_body([](TVMArgs args, TVMRetValue* ret) { Pass pass = args[0]; *ret = pass->Info(); @@ -450,10 +437,10 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_REGISTER_NODE_TYPE(ModulePassNode); -TVM_REGISTER_API("relay._ir_pass.CreateModulePass") +TVM_REGISTER_API("relay._transform.CreateModulePass") .set_body_typed(CreateModulePass); -TVM_REGISTER_API("relay._ir_pass.RunPass") +TVM_REGISTER_API("relay._transform.RunPass") .set_body([](TVMArgs args, TVMRetValue* ret) { Pass pass = args[0]; Module mod = args[1]; @@ -475,7 +462,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_REGISTER_NODE_TYPE(FunctionPassNode); -TVM_REGISTER_API("relay._ir_pass.CreateFunctionPass") +TVM_REGISTER_API("relay._transform.CreateFunctionPass") .set_body_typed(CreateFunctionPass); TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) @@ -486,9 +473,9 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) << " at the optimization level " << pn->opt_level; }); -TVM_REGISTER_NODE_TYPE(SequentialPassNode); +TVM_REGISTER_NODE_TYPE(SequentialNode); -TVM_REGISTER_API("relay._ir_pass.CreateSequentialPass") +TVM_REGISTER_API("relay._transform.Sequential") .set_body([](TVMArgs args, TVMRetValue* ret) { tvm::Array passes = args[0]; int opt_level = args[1]; @@ -496,14 +483,14 @@ TVM_REGISTER_API("relay._ir_pass.CreateSequentialPass") tvm::Array required = args[3]; tvm::Array disabled = args[4]; PassInfo pass_info = PassInfoNode::make(opt_level, name, required); - *ret = SequentialPassNode::make(passes, pass_info, disabled); + *ret = Sequential(passes, pass_info, disabled); }); TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) -.set_dispatch([](const SequentialPassNode* node, - tvm::IRPrinter* p) { +.set_dispatch([](const SequentialNode* node, + tvm::IRPrinter* p) { const PassInfoNode* seq_pn = node->Info().operator->(); - p->stream << "Run SequentialPass pass: " << seq_pn->name + p->stream << "Run Sequential pass: " << seq_pn->name << " at the optimization level. " << seq_pn->opt_level; p->stream << "The passes will be executed are: ["; for (const auto& it : node->passes) { @@ -514,7 +501,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) p->stream << "]"; }); -TVM_REGISTER_API("relay._ir_pass.SetContext") +TVM_REGISTER_API("relay._transform.SetContext") .set_body([](TVMArgs args, TVMRetValue* ret) { Pass pass = args[0]; PassContext pass_ctx = args[1]; @@ -523,7 +510,7 @@ TVM_REGISTER_API("relay._ir_pass.SetContext") TVM_REGISTER_NODE_TYPE(PassContextNode); -TVM_REGISTER_API("relay._ir_pass.PassContext") +TVM_REGISTER_API("relay._transform.PassContext") .set_body_typed(PassContextNode::make); TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) @@ -534,6 +521,6 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) << "\n"; }); -} // namespace pass +} // namespace transform } // namespace relay } // namespace tvm diff --git a/tests/python/relay/test_pass_manager.py b/tests/python/relay/test_pass_manager.py index b8216775ee1c..db346e7f712f 100644 --- a/tests/python/relay/test_pass_manager.py +++ b/tests/python/relay/test_pass_manager.py @@ -22,6 +22,7 @@ from tvm.relay import ExprFunctor from tvm.relay import Function, Call from tvm.relay import ir_pass +from tvm.relay import transform as _transform from tvm.relay.testing import ctx_list @@ -126,13 +127,13 @@ def test_module_pass(): opt_tester = OptTester(mod) pass_ctx = None - @ir_pass.module_pass(opt_level=opt_level, name=pass_name) + @_transform.module_pass(opt_level=opt_level, name=pass_name) def transform(expr, ctx): return opt_tester.transform(expr, ctx) def test_pass_registration(): mod_pass = transform - assert isinstance(mod_pass, ir_pass.ModulePass) + assert isinstance(mod_pass, _transform.ModulePass) pass_info = mod_pass.info assert pass_info.name == pass_name assert pass_info.opt_level == opt_level @@ -140,8 +141,8 @@ def test_pass_registration(): def test_pass_registration_no_decorator(): def direct_transform(expr, ctx): return opt_tester.transform(expr, ctx) - mod_pass = ir_pass.module_pass(direct_transform, opt_level=3) - assert isinstance(mod_pass, ir_pass.ModulePass) + mod_pass = _transform.module_pass(direct_transform, opt_level=3) + assert isinstance(mod_pass, _transform.ModulePass) pass_info = mod_pass.info assert pass_info.name == "direct_transform" assert pass_info.opt_level == 3 @@ -202,7 +203,7 @@ def test_function_pass(): opt_tester = OptTester(mod) pass_ctx = None - @ir_pass.function_pass(opt_level=opt_level, name=pass_name) + @_transform.function_pass(opt_level=opt_level, name=pass_name) def transform(expr, ctx): return opt_tester.transform(expr, ctx) @@ -212,7 +213,7 @@ def get_ref_log(): def test_pass_registration(): function_pass = transform - assert isinstance(function_pass, ir_pass.FunctionPass) + assert isinstance(function_pass, _transform.FunctionPass) pass_info = function_pass.info assert pass_info.name == pass_name assert pass_info.opt_level == opt_level @@ -220,8 +221,8 @@ def test_pass_registration(): def test_pass_registration_no_decorator(): def direct_transform(expr, ctx): return opt_tester.transform(expr, ctx) - mod_pass = ir_pass.function_pass(direct_transform, opt_level=0) - assert isinstance(mod_pass, ir_pass.FunctionPass) + mod_pass = _transform.function_pass(direct_transform, opt_level=0) + assert isinstance(mod_pass, _transform.FunctionPass) pass_info = mod_pass.info assert pass_info.name == "direct_transform" assert pass_info.opt_level == 0 @@ -294,14 +295,14 @@ def get_ref_abs(): opt_tester = OptTester(mod) pass_ctx = None - @ir_pass.module_pass(opt_level=1) + @_transform.module_pass(opt_level=1) def mod_transform(expr, ctx): return opt_tester.transform(expr, ctx) module_pass = mod_transform # Register a function pass. - @ir_pass.function_pass(opt_level=1) + @_transform.function_pass(opt_level=1) def func_transform(expr, ctx): return opt_tester.transform(expr, ctx) @@ -310,25 +311,23 @@ def func_transform(expr, ctx): def test_pass_registration(): passes = [module_pass, function_pass] opt_level = 2 - pass_name = "sequential_pass" - sequential_pass = ir_pass.sequential_pass(passes=passes, - opt_level=opt_level) - assert isinstance(sequential_pass, ir_pass.SequentialPass) - pass_info = sequential_pass.info + pass_name = "sequential" + sequential = _transform.Sequential(passes=passes, opt_level=opt_level) + pass_info = sequential.info assert pass_info.name == pass_name assert pass_info.opt_level == opt_level def test_no_pass(): passes = [] - sequential_pass = ir_pass.sequential_pass(opt_level=1, passes=passes) - ret_mod = sequential_pass(mod) + sequential = _transform.Sequential(opt_level=1, passes=passes) + ret_mod = sequential(mod) mod_func = ret_mod[v_sub] check_func(sub, mod_func) def test_only_module_pass(): passes = [module_pass] - sequential_pass = ir_pass.sequential_pass(opt_level=1, passes=passes) - ret_mod = sequential_pass(mod) + sequential = _transform.Sequential(opt_level=1, passes=passes) + ret_mod = sequential(mod) # Check the subtract function. sub_var, new_sub = extract_var_func(ret_mod, v_sub.name_hint) check_func(new_sub, sub) @@ -341,8 +340,8 @@ def test_only_module_pass(): def test_only_function_pass(): # Check the subtract function. passes = [function_pass] - sequential_pass = ir_pass.sequential_pass(opt_level=1, passes=passes) - ret_mod = sequential_pass(mod) + sequential = _transform.Sequential(opt_level=1, passes=passes) + ret_mod = sequential(mod) _, new_sub = extract_var_func(ret_mod, v_sub.name_hint) check_func(new_sub, get_ref_sub()) @@ -355,8 +354,8 @@ def test_multiple_passes(): # function pass. mod = relay.Module({v_sub: sub, v_log: log}) passes = [module_pass, function_pass] - sequential_pass = ir_pass.sequential_pass(opt_level=1, passes=passes) - ret_mod = sequential_pass(mod) + sequential = _transform.Sequential(opt_level=1, passes=passes) + ret_mod = sequential(mod) # Check the abs function is added. abs_var, abs_func = get_var_func()