diff --git a/docs/api/python/relay/build_module.rst b/docs/api/python/relay/build_module.rst index 28dadea21e78..26164bf1ade9 100644 --- a/docs/api/python/relay/build_module.rst +++ b/docs/api/python/relay/build_module.rst @@ -22,17 +22,9 @@ tvm.relay.build_module .. autofunction:: tvm.relay.build_module.build -.. autofunction:: tvm.relay.build_module.build_config - .. autofunction:: tvm.relay.build_module.optimize .. autofunction:: tvm.relay.build_module.create_executor -.. autoclass:: tvm.relay.build_module.BuildConfig - :members: - -.. autofunction:: tvm.relay.build_module.build_config - :members: - .. autoclass:: tvm.relay.build_module.GraphExecutor :members: diff --git a/docs/api/python/relay/transform.rst b/docs/api/python/relay/transform.rst new file mode 100644 index 000000000000..4eb7f9d8fea7 --- /dev/null +++ b/docs/api/python/relay/transform.rst @@ -0,0 +1,45 @@ +.. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + +.. http://www.apache.org/licenses/LICENSE-2.0 + +.. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + +tvm.relay.transform +---------------------- + +.. automodule:: tvm.relay.transform + +.. autofunction:: tvm.relay.transform.build_config + +.. autofunction:: tvm.relay.transform.module_pass + +.. autofunction:: tvm.relay.transform.function_pass + +.. autoclass:: tvm.relay.transform.Pass + :members: + +.. autoclass:: tvm.relay.transform.PassInfo + :members: + +.. autoclass:: tvm.relay.transform.PassContext + :members: + +.. autoclass:: tvm.relay.transform.ModulePass + :members: + +.. autoclass:: tvm.relay.transform.FunctionPass + :members: + +.. autoclass:: tvm.relay.transform.Sequential + :members: diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h index ba25483dfbb2..5123f3a3dcf3 100644 --- a/include/tvm/relay/transform.h +++ b/include/tvm/relay/transform.h @@ -56,11 +56,13 @@ #ifndef TVM_RELAY_TRANSFORM_H_ #define TVM_RELAY_TRANSFORM_H_ +#include #include #include #include #include #include +#include #include namespace tvm { @@ -83,18 +85,69 @@ class PassContextNode : public RelayNode { */ ErrorReporter err_reporter; + /*! \brief The default optimization level. */ + int opt_level{2}; + + /*! \brief CPU is the default fallback device for heterogeneous execution. */ + int fallback_device{static_cast(kDLCPU)}; + + /*! \brief The list of required passes. */ + tvm::Array required_pass; + /*! \brief The list of disabled passes. */ + tvm::Array disabled_pass; + PassContextNode() = default; void VisitAttrs(tvm::AttrVisitor* v) final { + v->Visit("opt_level", &opt_level); + v->Visit("fallback_device", &fallback_device); + v->Visit("required_pass", &required_pass); + v->Visit("disabled_pass", &disabled_pass); } - TVM_DLL static PassContext make(); - static constexpr const char* _type_key = "relay.PassContext"; TVM_DECLARE_NODE_TYPE_INFO(PassContextNode, RelayNode); }; -TVM_DEFINE_NODE_REF(PassContext, PassContextNode) +class PassContext : public NodeRef { + public: + PassContext() {} + explicit PassContext(tvm::NodePtr n) : NodeRef(n) {} + + /* + * \brief Constructor of a `PassContext` object. + * + * \param opt_level The optimization level that will be applied. + * \param fallback_device The fallback device used for heterogeneous + * execution. + * \param required_pass The passes that are required for a context to execute + * other passes. + * \param required_pass The passes that will be disabled during the + * optimization under a context. + */ + TVM_DLL PassContext(int opt_level, + int fallback_device, + tvm::Array required_pass, + tvm::Array disabled_pass); + + // Get the currently used pass context. + TVM_DLL static PassContext Current(); + + const PassContextNode* operator->() const; + + using ContainerType = PassContextNode; + class Internal; + + private: + // The entry of a pass context scope. + TVM_DLL void EnterWithScope(); + // The exit of a pass context scope. + TVM_DLL void ExitWithScope(); + + // Classes to get the Python `with` like syntax. + friend class Internal; + friend class tvm::With; +}; /* * \brief The meta data of a pass. @@ -150,20 +203,28 @@ class PassNode : public RelayNode { virtual PassInfo Info() const = 0; /*! - * \brief Set the context information for a pass. + * \brief Execute the optimization pass using a functor. This functor + * internally uses a current pass context. + * + * \param mod The module that an optimization pass runs on. * - * \param pass_ctx The context information for a certain pass. + * \return The updated module. */ - virtual void SetContext(const PassContext& pass_ctx) = 0; + Module operator()(const Module& mod) const { + return this->operator()(mod, PassContext::Current()); + } /*! - * \brief Execute the optimization pass using a functor. + * \brief Execute the optimization pass using a functor under a given pass context. * * \param mod The module that an optimization pass runs on. + * \param pass_ctx The pass context that will be used to help the execution of + * optimizations. * * \return The updated module. */ - virtual Module operator()(const Module& mod) const = 0; + virtual Module operator()(const Module& mod, + const PassContext& pass_ctx) const = 0; void VisitAttrs(tvm::AttrVisitor* v) override {} @@ -189,13 +250,22 @@ class Sequential : public Pass { public: /*! * \brief The constructor of `Sequential`. + * * \param passes The passes to apply. * \param pass_info The pass metadata. - * \param disabled The passes that will not be applied. */ TVM_DLL Sequential(tvm::Array passes, - PassInfo pass_info, - tvm::Array disabled); + PassInfo pass_info); +/*! + * \brief The constructor of `Sequential`. + * + * \param passes The passes to apply. + * \param name The name of a sequential pass. It's defaulted to "sequential". + * This allows users to only provide a list of passes and execute them + * under a given context. + */ + TVM_DLL Sequential(tvm::Array passes, std::string name = "sequential"); + Sequential() = default; explicit Sequential(tvm::NodePtr<::tvm::Node> n) : Pass(n) {} diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py index d832c8988795..1c8f5d6ceed3 100644 --- a/python/tvm/relay/__init__.py +++ b/python/tvm/relay/__init__.py @@ -26,7 +26,8 @@ from . import adt from . import ir_pass from . import transform -from .build_module import build, build_config, create_executor +from .build_module import build, create_executor +from .transform import build_config from . import prelude from . import parser from . import debug diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py index d0ad78fee67f..6cee393d5f91 100644 --- a/python/tvm/relay/build_module.py +++ b/python/tvm/relay/build_module.py @@ -28,81 +28,10 @@ from . import ir_pass from . import ty as _ty from . import expr as _expr +from . import transform as _transform from .backend import interpreter as _interpreter from .backend.vm import VMExecutor -class BuildConfig(object): - """Configuration scope to set a build config option. - - Parameters - ---------- - kwargs - Keyword arguments of configurations to set. - """ - current = None - defaults = { - "opt_level": 2, - "add_pass": None, - "disable_pass": None, - "fallback_device": None, - } - - def __init__(self, **kwargs): - self._old_scope = None - for k, _ in kwargs.items(): - if k not in BuildConfig.defaults: - raise ValueError("invalid argument %s, candidates are %s" % - (k, BuildConfig.defaults.keys())) - self._attr = kwargs - - def __getattr__(self, name): - if name not in self._attr: - return BuildConfig.defaults[name] - return self._attr[name] - - def __enter__(self): - # pylint: disable=protected-access - self._old_scope = BuildConfig.current - attr = BuildConfig.current._attr.copy() - attr.update(self._attr) - self._attr = attr - BuildConfig.current = self - return self - - def __exit__(self, ptype, value, trace): - assert self._old_scope - BuildConfig.current = self._old_scope - - -BuildConfig.current = BuildConfig() - - -def build_config(**kwargs): - """Configure the build behavior by setting config variables. - - Parameters - ---------- - opt_level: int, default=2 - Optimization level. See OPT_PASS_LEVEL for level of each pass. - - add_pass: set of str - Optimization pass to be added regardless of optimization level. - - disable_pass: set of str - Optimization pass to be disabled during optimization. - - fallback_device : str or tvm.TVMContext - The fallback device. It is also used as the default device for - operators without specified device during heterogeneous execution. - - Returns - ------- - config: BuildConfig - The build configuration - """ - return BuildConfig(**kwargs) - - def _update_target(target): target = target if target else _target.current_target() if target is None: @@ -189,7 +118,7 @@ def build(self, func, target=None, target_host=None, params=None): return graph_json, mod, params def _setup_build_config(self, params): - cfg = BuildConfig.current + cfg = _transform.PassContext.current() # Set opt_level. self.set_opt_level(cfg.opt_level) @@ -199,24 +128,24 @@ def _setup_build_config(self, params): self.set_fallback_device(cfg.fallback_device) # Add required passes. - if cfg.add_pass: + if cfg.required_pass: passes = set() - if isinstance(cfg.add_pass, (list, tuple, set)): - passes = set(cfg.add_pass) + if isinstance(cfg.required_pass, (list, tuple, set)): + passes = set(cfg.required_pass) else: raise TypeError("add_pass must be list, tuple, or set, but " + - "got {}".format(type(cfg.add_pass))) + "got {}".format(type(cfg.required_pass))) for pass_name in passes: self.add_pass(pass_name) # Add disabled passes. - if cfg.disable_pass: + if cfg.disabled_pass: passes = set() - if isinstance(cfg.disable_pass, (list, tuple, set)): - passes = set(cfg.disable_pass) + if isinstance(cfg.disabled_pass, (list, tuple, set)): + passes = set(cfg.disabled_pass) else: raise TypeError("disable_pass must be list, tuple, or set, " + - "but got {}".format(type(cfg.disable_pass))) + "but got {}".format(type(cfg.disabled_pass))) for pass_name in passes: self.disable_pass(pass_name) @@ -287,12 +216,11 @@ def set_fallback_device(self, fallback_device): fallback_device : str or tvm.TVMContext The fallback device used for heterogeneous execution. """ - if isinstance(fallback_device, str): + if isinstance(fallback_device, (int, str)): fallback_device = _nd.context(fallback_device) if not isinstance(fallback_device, TVMContext): - raise TypeError("fallback_device is expected to be str " + - "TVMContext, or dict of device name to target, " + - "but received: {}".format(type(fallback_device))) + raise TypeError("fallback_device is expected to be str, int, or " + + "TVMContext but received: {}".format(type(fallback_device))) self._set_fallback_device(fallback_device.device_type) diff --git a/python/tvm/relay/quantize/quantize.py b/python/tvm/relay/quantize/quantize.py index 7fd0099e64a2..2423e76d308a 100644 --- a/python/tvm/relay/quantize/quantize.py +++ b/python/tvm/relay/quantize/quantize.py @@ -22,7 +22,7 @@ from . import _quantize from .. import expr as _expr from .. import ir_pass as _ir_pass -from .. import build_module as _build +from .. import transform as _transform from .. import op as _op from ... import make as _make from ..base import NodeBase, register_relay_node @@ -301,7 +301,7 @@ def optimize(func, params=None): "FoldConstant", "CanonicalizeOps"] - cfg = _build.build_config(add_pass=opt_passes) + cfg = _transform.build_config(required_pass=opt_passes) if params: name_dict = {} @@ -321,25 +321,25 @@ def optimize(func, params=None): bind_dict[arg] = _expr.const(v) func = _expr.bind(func, bind_dict) - if "SimplifyInference" in cfg.add_pass: + if "SimplifyInference" in cfg.required_pass: func = _ir_pass.infer_type(func) func = _ir_pass.simplify_inference(func) - if "FoldConstant" in cfg.add_pass: + if "FoldConstant" in cfg.required_pass: func = _ir_pass.fold_constant(func) - if "FoldScaleAxis" in cfg.add_pass: + if "FoldScaleAxis" in cfg.required_pass: func = _ir_pass.infer_type(func) func = _ir_pass.backward_fold_scale_axis(func) func = _ir_pass.infer_type(func) func = _ir_pass.forward_fold_scale_axis(func) func = _ir_pass.fold_constant(func) - if "CanonicalizeOps" in cfg.add_pass: + if "CanonicalizeOps" in cfg.required_pass: func = _ir_pass.infer_type(func) func = _ir_pass.canonicalize_ops(func) - if "FoldConstant" in cfg.add_pass: + if "FoldConstant" in cfg.required_pass: func = _ir_pass.fold_constant(func) return func diff --git a/python/tvm/relay/transform.py b/python/tvm/relay/transform.py index 877538afea34..a7887c630c76 100644 --- a/python/tvm/relay/transform.py +++ b/python/tvm/relay/transform.py @@ -23,8 +23,10 @@ """ import types +from tvm._ffi.runtime_ctypes import TVMContext from . import _transform from .base import RelayNode, register_relay_node +from .. import nd as _nd @register_relay_node @@ -57,10 +59,102 @@ class PassContext(RelayNode): Each pass context contains a number of auxiliary information that is used to help an optimization pass. Such information includes the error reporter to record the errors of during the optimization, etc. + + opt_level : Optional[int] + The optimization level of this pass. + + fallback_device : Optional[Union[int, str, TVMContext]] + The fallback device type. It is also used as the default device for + operators that are not annotated during heterogeneous execution. + + required_pass : Optional[Union[List[str], Set[str], Tuple[str]]] + The list of passes that are required by a certain pass. + + disabled_pass : Optional[Union[List[str], Set[str], Tuple[str]]] + The list of passes that are disabled. """ + def __init__(self, + opt_level=2, + fallback_device=_nd.cpu(), + required_pass=None, + disabled_pass=None): + if isinstance(fallback_device, str): + fallback_device = _nd.context(fallback_device).device_type + elif isinstance(fallback_device, TVMContext): + fallback_device = fallback_device.device_type + if not isinstance(fallback_device, int): + raise TypeError("required_pass is expected to be the type of " + + "int/str/TVMContext.") + + required = list(required_pass) if required_pass else [] + if not isinstance(required, (list, tuple)): + raise TypeError("required_pass is expected to be the type of " + + "list/tuple/set.") - def __init__(self): - self.__init_handle_by_constructor__(_transform.PassContext) + disabled = list(disabled_pass) if disabled_pass else [] + if not isinstance(disabled, (list, tuple)): + raise TypeError("disabled_pass is expected to be the type of " + + "list/tuple/set.") + + self.__init_handle_by_constructor__(_transform.PassContext, opt_level, + fallback_device, required, + disabled) + + def __enter__(self): + _transform.EnterPassContext(self) + return self + + def __exit__(self, ptype, value, trace): + _transform.ExitPassContext(self) + + @staticmethod + def current(): + """Return the current pass context.""" + return _transform.GetCurrentPassContext() + + +def build_config(opt_level=2, + fallback_device=_nd.cpu(), + required_pass=None, + disabled_pass=None): + """Configure the build behavior by setting config variables. + + Parameters + ---------- + opt_level: int, optional + Optimization level. The optimization pass name and level are as the + following: + + .. code-block:: python + + OPT_PASS_LEVEL = { + "SimplifyInference": 0, + "OpFusion": 1, + "FoldConstant": 2, + "CombineParallelConv2D": 3, + "FoldScaleAxis": 3, + "AlterOpLayout": 3, + "CanonicalizeOps": 3, + "EliminateCommonSubexpr": 3, + } + + fallback_device : int, str, or tvm.TVMContext, optional + The fallback device. It is also used as the default device for + operators without specified device during heterogeneous execution. + + required_pass: set of str, optional + Optimization passes that are required regardless of optimization level. + + disabled_pass: set of str, optional + Optimization passes to be disabled during optimization. + + Returns + ------- + pass_context: PassContext + The pass context for optimizations. + """ + return PassContext(opt_level, fallback_device, required_pass, + disabled_pass) @register_relay_node @@ -70,20 +164,6 @@ class Pass(RelayNode): conveniently interact with the base class. """ - def set_pass_context(self, pass_ctx): - """Setup the pass context for analysis and optimizations. This context - could be shared by different passes for sequential passes. - - Parameters - ---------- - pass_ctx : PassContext - The context that is used to help perform a certain pass or a series - of passes. - """ - if not isinstance(pass_ctx, PassContext): - raise TypeError("pass_ctx is expected to be the PassContext type") - _transform.SetContext(self, pass_ctx) - @property def info(self): """Get the pass meta.""" @@ -150,32 +230,23 @@ class Sequential(Pass): required : Optional[List[str]] The list of passes that the sequential pass is dependent on. - - disabled : Optional[List[str]] - A list of disabled passes. """ def __init__(self, passes=None, opt_level=2, name="sequential", - required=None, - disabled=None): + required=None): passes = passes if passes else [] if not isinstance(passes, (list, tuple)): raise TypeError("passes must be a list of Pass objects.") - disabled = disabled if disabled else [] - if not isinstance(disabled, (list, tuple)): - raise TypeError("disabled must be a list or tuple of pass names") - required = required if required else [] if not isinstance(required, (list, tuple)): raise TypeError("Required is expected to be the type of list/tuple.") self.__init_handle_by_constructor__(_transform.Sequential, - passes, opt_level, name, required, - disabled) + passes, opt_level, name, required) def module_pass(pass_func=None, opt_level=None, name=None, required=None): diff --git a/src/relay/pass/pass_manager.cc b/src/relay/pass/pass_manager.cc index a105b692aa9d..4bcc0bb39cc4 100644 --- a/src/relay/pass/pass_manager.cc +++ b/src/relay/pass/pass_manager.cc @@ -22,8 +22,14 @@ * \file src/relay/pass/pass_manager.cc * \brief Relay pass manager implementation. */ +#include #include #include +#include + +#include +#include +#include namespace tvm { namespace relay { @@ -31,6 +37,98 @@ namespace transform { using tvm::IRPrinter; +/*! + * \brief A data structure to map the names of specific optimizations to + * numeric optimization levels + */ +class OptPassLevel { + public: + /*! + * \brief Get level for an optimization pass + * + * \param key pass name + * \return int level + */ + int operator[](const std::string& key) const { + const auto data = CreateMap(); + auto it = data.find(key); + if (it == data.end()) { + return -1; + } + return it->second; + } + + private: + static const std::unordered_map CreateMap() { + const std::unordered_map m = { + {"SimplifyInference", 0}, + {"OpFusion", 1}, + {"FoldConstant", 2}, + {"CombineParallelConv2D", 3}, + {"FoldScaleAxis", 3}, + {"AlterOpLayout", 3}, + {"CanonicalizeOps", 3}, + {"EliminateCommonSubexpr", 3} + }; + return m; + } +}; + +PassContext::PassContext(int opt_level, int fallback_device, + tvm::Array required_pass, + tvm::Array disabled_pass) { + auto ctx = make_node(); + ctx->opt_level = opt_level; + ctx->fallback_device = fallback_device; + ctx->required_pass = std::move(required_pass); + ctx->disabled_pass = std::move(disabled_pass); + node_ = std::move(ctx); +} + +const PassContextNode* PassContext::operator->() const { + return static_cast(node_.get()); +} + +struct RelayPassContextThreadLocalEntry { + /*! \brief The default pass context. */ + PassContext default_context; + + /*! \brief The current pass context. */ + std::stack context_stack; + + RelayPassContextThreadLocalEntry() { + default_context = PassContext(make_node()); + } +}; + +/*! \brief Thread local store to hold the pass context. */ +typedef dmlc::ThreadLocalStore + RelayPassContextThreadLocalStore; + +void PassContext::EnterWithScope() { + RelayPassContextThreadLocalEntry* entry = + RelayPassContextThreadLocalStore::Get(); + entry->context_stack.push(*this); +} + +void PassContext::ExitWithScope() { + RelayPassContextThreadLocalEntry* entry = + RelayPassContextThreadLocalStore::Get(); + CHECK(!entry->context_stack.empty()); + CHECK(entry->context_stack.top().same_as(*this)); + entry->context_stack.pop(); +} + +PassContext PassContext::Current() { + RelayPassContextThreadLocalEntry* entry = + RelayPassContextThreadLocalStore::Get(); + if (!entry->context_stack.empty()) { + return entry->context_stack.top(); + } else { + return entry->default_context; + } +} + class ModulePass; /*! @@ -58,38 +156,26 @@ class ModulePassNode : public PassNode { } /*! - * \brief Run a module pass on a certain module. + * \brief Run a module pass on given pass context. * - * \param mod The module that an optimization pass runs on. + * \param mod The module that an optimization pass is applied on. + * \param mod The context that an optimization pass executes on. * * \return Return the updated module. */ - Module operator()(const Module& mod) const final; + Module operator()(const Module& mod, const PassContext& pass_ctx) const final; /*! * \brief Get the pass information/meta data. */ PassInfo Info() const { return pass_info; } - /*! - * \brief Set the context information for a module pass. - * - * \param pass_ctx The context information for a module pass. - */ - void SetContext(const PassContext& pass_ctx) final; - TVM_DLL static ModulePass make( runtime::TypedPackedFunc pass_func, PassInfo pass_info); static constexpr const char* _type_key = "relay.ModulePass"; TVM_DECLARE_NODE_TYPE_INFO(ModulePassNode, PassNode); - - private: - /*! - * \brief The context information that is used to help perform a module pass. - */ - PassContext pass_ctx_; }; RELAY_DEFINE_NODE_REF(ModulePass, ModulePassNode, Pass); @@ -124,26 +210,20 @@ class FunctionPassNode : public PassNode { } /*! - * \brief Run a function pass on a certain module. + * \brief Run a function pass on given pass context. * - * \param mod The module that an optimization pass runs on. + * \param mod The module that an optimization pass is applied on. + * \param mod The context that an optimization pass executes on. * * \return Return the updated module. */ - Module operator()(const Module& mod) const final; + Module operator()(const Module& mod, const PassContext& pass_ctx) const final; /*! * \brief Get the pass information/meta data. */ PassInfo Info() const { return pass_info; } - /*! - * \brief Set the context information for a function-level pass. - * - * \param pass_ctx The context information for a function-level pass. - */ - void SetContext(const PassContext& pass_ctx) final; - TVM_DLL static FunctionPass make( runtime::TypedPackedFunc pass_func, PassInfo pass_info); @@ -160,11 +240,6 @@ class FunctionPassNode : public PassNode { * \return Return true if the function will be skipped, otherwise false. */ bool SkipFunction(const Function& func) const; - - /*! - * \brief The context information that is used to help perform a module pass. - */ - PassContext pass_ctx_; }; RELAY_DEFINE_NODE_REF(FunctionPass, FunctionPassNode, Pass); @@ -182,18 +257,17 @@ class SequentialNode : public PassNode { /* \brief The pass meta data.*/ PassInfo pass_info; - /*! \brief A list of passes that used to compose a sequential pass. */ - tvm::Array passes; /*! - * \brief A list of disabled passes that should be excluded when executing the - * sequential pass. + * \brief A helper struct to get the optimization pass name to opt level + * mapping. */ - tvm::Array disabled; + OptPassLevel opt_pass_level; + /*! \brief A list of passes that used to compose a sequential pass. */ + tvm::Array passes; void VisitAttrs(tvm::AttrVisitor* v) final { v->Visit("pass_info", &pass_info); v->Visit("passes", &passes); - v->Visit("disabled", &disabled); } /*! @@ -210,6 +284,15 @@ class SequentialNode : public PassNode { passes.push_back(pass); } + /*! + * \brief Check if a pass is enabled. + * + * \param pass_name The name of an optimization/analysis pass. + * + * \return true if the pass is enabled. Otherwise, false. + */ + bool pass_enabled(const std::string& pass_name) const; + /*! * \brief Resolve the pass dependency. It globs all required passes by * a given pass and executes them. @@ -224,7 +307,11 @@ class SequentialNode : public PassNode { */ void ResolveDependency(const Module& mod); - TVM_DLL std::vector DisabledPasses() const; + std::unordered_set DisabledPasses( + const Array& disabled) const; + + std::unordered_set RequiredPasses( + const Array& disabled) const; /*! * \brief Perform optimizations on a series of passes. The aforementioned @@ -232,27 +319,15 @@ class SequentialNode : public PassNode { * be overloaded to focus on different metrics, i.e. performance, * memory footprint, etc. * - * \param mod The module that an optimization pass runs on. + * \param mod The module that these passes are applied on. + * \param pass_ctx The context that these passes execute on. * * \return Return the updated module. */ - Module operator()(const Module& mod) const final; - - /*! - * \brief Set the context information for a sequential pass. - * - * \param pass_ctx The context information for a sequential pass. - */ - void SetContext(const PassContext& pass_ctx) final; + Module operator()(const Module& mod, const PassContext& pass_ctx) const final; static constexpr const char* _type_key = "relay.Sequential"; TVM_DECLARE_NODE_TYPE_INFO(SequentialNode, PassNode); - - private: - /*! - * \brief The context information that is used to help perform a module pass. - */ - PassContext pass_ctx_; }; PassInfo PassInfoNode::make(int opt_level, std::string name, @@ -264,11 +339,6 @@ PassInfo PassInfoNode::make(int opt_level, std::string name, return PassInfo(pass_info); } -PassContext PassContextNode::make() { - auto ctx = make_node(); - return PassContext(ctx); -} - ModulePass ModulePassNode::make( runtime::TypedPackedFunc pass_func, PassInfo pass_info) { @@ -279,23 +349,19 @@ ModulePass ModulePassNode::make( } // Module -> Module optimizations. -// TODO(zhiics) 1. Check and handle the required passes. -// 2. Probably use CoW for all places that use module instead of -// returning the updated one. -Module ModulePassNode::operator()(const Module& mod) const { +// TODO(zhiics) Check and handle the required passes. +Module ModulePassNode::operator()(const Module& mod, + const PassContext& pass_ctx) const { PassInfo pass_info = Info(); LOG(INFO) << "Executing module pass : " << pass_info.operator->()->name << " with opt level: " << pass_info.operator->()->opt_level << "\n"; + CHECK(mod.defined()); - auto updated_mod = pass_func(mod, pass_ctx_); + auto updated_mod = pass_func(mod, pass_ctx); CHECK(updated_mod.defined()); return updated_mod; } -void ModulePassNode::SetContext(const PassContext& pass_ctx) { - pass_ctx_ = pass_ctx; -} - FunctionPass FunctionPassNode::make( runtime::TypedPackedFunc pass_func, PassInfo pass_info) { @@ -307,31 +373,22 @@ FunctionPass FunctionPassNode::make( // Perform Module -> Module optimizations at the Function level. // TODO(zhiics) Check and handle the required passes. -Module FunctionPassNode::operator()(const Module& mod) const { +Module FunctionPassNode::operator()(const Module& mod, + const PassContext& pass_ctx) const { PassInfo pass_info = Info(); LOG(INFO) << "Executing function pass : " << pass_info.operator->()->name << " with opt level: " << pass_info.operator->()->opt_level << "\n"; CHECK(mod.defined()); - std::vector> updated_funcs; - ModuleNode* mod_node = mod.operator->(); - for (const auto& it : mod_node->functions) { - if (!SkipFunction(it.second)) { - auto updated_func = pass_func(it.second, pass_ctx_); - CHECK(updated_func.defined()); - updated_funcs.push_back({std::move(it.first), std::move(updated_func)}); - } - } + Module new_mod = ModuleNode::make({}, mod->type_definitions); - // Update the optimized functions. - for (const auto& it : updated_funcs) { - mod_node->Update(it.first, it.second); + // Execute the pass function and return a new module. + for (const auto& it : mod->functions) { + auto updated_func = + SkipFunction(it.second) ? it.second : pass_func(it.second, pass_ctx); + new_mod->Add(it.first, updated_func); } - return GetRef(mod_node); -} - -void FunctionPassNode::SetContext(const PassContext& pass_ctx) { - pass_ctx_ = pass_ctx; + return new_mod; } // TODO(zhiics) Create an enum attribute for FunctionNode @@ -342,31 +399,23 @@ bool FunctionPassNode::SkipFunction(const Function& func) const { return pval && pval->value != 0; } -Sequential::Sequential(tvm::Array passes, - PassInfo pass_info, - tvm::Array disabled) { +Sequential::Sequential(tvm::Array passes, PassInfo pass_info) { auto n = make_node(); n->passes = std::move(passes); n->pass_info = std::move(pass_info); - n->disabled = std::move(disabled); node_ = std::move(n); } -const SequentialNode* Sequential::operator->() const { - return static_cast(this->node_.get()); +Sequential::Sequential(tvm::Array passes, std::string name) { + auto n = make_node(); + n->passes = std::move(passes); + PassInfo pass_info = PassInfoNode::make(2, std::move(name), {}); + n->pass_info = std::move(pass_info); + node_ = std::move(n); } -// TODO(jroesch, zhiics): we currenlty only sequentially execute each pass in -// a Sequential without the consideration of their orders. The phase -// ordering problem needed to be handled in the future. -Module SequentialNode::operator()(const Module& module) const { - Module mod = module; - for (const Pass& pass : passes) { - CHECK(pass.defined()) << "Found undefined pass for optimization."; - const auto* pn = pass.operator->(); - mod = (*pn)(mod); - } - return mod; +const SequentialNode* Sequential::operator->() const { + return static_cast(this->node_.get()); } void SequentialNode::ResolveDependency(const Module& mod) { @@ -378,18 +427,68 @@ void SequentialNode::ResolveDependency(const Module& mod) { << "\n"; } -std::vector SequentialNode::DisabledPasses() const { - std::vector ret; +std::unordered_set SequentialNode::DisabledPasses( + const Array& disabled) const { + std::unordered_set ret; for (const auto& it : disabled) { const auto* str = it.as(); CHECK(str) << "disabled passes must be string."; - ret.push_back(str->value); + ret.emplace(str->value); } return ret; } -void SequentialNode::SetContext(const PassContext& pass_ctx) { - pass_ctx_ = pass_ctx; +std::unordered_set SequentialNode::RequiredPasses( + const Array& required) const { + std::unordered_set ret; + for (const auto& it : required) { + const auto* str = it.as(); + CHECK(str) << "disabled passes must be string."; + ret.emplace(str->value); + } + return ret; +} + +bool SequentialNode::pass_enabled(const std::string& pass_name) const { + PassContext ctx = PassContext::Current(); + + const PassContextNode* ctx_node = ctx.operator->(); + auto required = RequiredPasses(ctx_node->required_pass); + auto disabled = DisabledPasses(ctx_node->required_pass); + + if (disabled.count(pass_name)) { + return false; + } + + if (required.count(pass_name)) { + return true; + } + return ctx_node->opt_level >= opt_pass_level[pass_name]; +} + +// TODO(zhiics): we currenlty only sequentially execute each pass in +// a Sequential without the consideration of their orders. The phase +// ordering problem needed to be handled in the future. +Module SequentialNode::operator()(const Module& module, + const PassContext& pass_ctx) const { + const auto* ctx_node = pass_ctx.operator->(); + int opt_level = ctx_node->opt_level; + auto disabled = DisabledPasses(ctx_node->disabled_pass); + Module mod = module; + for (const Pass& pass : passes) { + CHECK(pass.defined()) << "Found undefined pass for optimization."; + PassInfo info = pass->Info(); + const auto& pass_name = info.operator->()->name; + const auto& pass_opt_level = info.operator->()->opt_level; + // Skip the pass if its optimization level is higher that the one of in the + // pass context or if this pass is disabled. + if (pass_opt_level > opt_level || disabled.count(pass_name)) { + continue; + } + const auto* pn = pass.operator->(); + mod = (*pn)(mod, pass_ctx); + } + return mod; } Pass CreateModulePass( @@ -481,9 +580,8 @@ TVM_REGISTER_API("relay._transform.Sequential") int opt_level = args[1]; std::string name = args[2]; tvm::Array required = args[3]; - tvm::Array disabled = args[4]; PassInfo pass_info = PassInfoNode::make(opt_level, name, required); - *ret = Sequential(passes, pass_info, disabled); + *ret = Sequential(passes, pass_info); }); TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) @@ -501,26 +599,58 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) p->stream << "]"; }); -TVM_REGISTER_API("relay._transform.SetContext") -.set_body([](TVMArgs args, TVMRetValue* ret) { - Pass pass = args[0]; - PassContext pass_ctx = args[1]; - pass->SetContext(pass_ctx); -}); - TVM_REGISTER_NODE_TYPE(PassContextNode); TVM_REGISTER_API("relay._transform.PassContext") -.set_body_typed(PassContextNode::make); +.set_body([](TVMArgs args, TVMRetValue* ret) { + int opt_level = args[0]; + int fallback_device = args[1]; + tvm::Array required = args[2]; + tvm::Array disabled = args[3]; + *ret = PassContext(opt_level, fallback_device, required, disabled); +}); TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) .set_dispatch([](const PassContextNode* node, - tvm::IRPrinter* p) { - p->stream << "TODO(zhiics): printing context"; - LOG(FATAL) << "PassContext printer has not been implemented yet." - << "\n"; + tvm::IRPrinter* p) { + p->stream << "Pass context information: " << "\n"; + p->stream << "\topt_level: " << node->opt_level << "\n"; + p->stream << "\tfallback device: " << runtime::DeviceName(node->opt_level) + << "\n"; + + p->stream << "\trequired passes: [" << node->opt_level; + for (const auto& it : node->required_pass) { + p->stream << it << " "; + } + p->stream << "]\n"; + + p->stream << "\tdisabled passes: [" << node->opt_level; + for (const auto& it : node->disabled_pass) { + p->stream << it << " "; + } + p->stream << "]"; }); +class PassContext::Internal { + public: + static void EnterScope(PassContext pass_ctx) { + pass_ctx.EnterWithScope(); + } + + static void ExitScope(PassContext pass_ctx) { + pass_ctx.ExitWithScope(); + } +}; + +TVM_REGISTER_API("relay._transform.GetCurrentPassContext") +.set_body_typed(PassContext::Current); + +TVM_REGISTER_API("relay._transform.EnterPassContext") +.set_body_typed(PassContext::Internal::EnterScope); + +TVM_REGISTER_API("relay._transform.ExitPassContext") +.set_body_typed(PassContext::Internal::ExitScope); + } // namespace transform } // namespace relay } // namespace tvm diff --git a/tests/python/frontend/coreml/test_forward.py b/tests/python/frontend/coreml/test_forward.py index 0fed49079fd2..da78e960091d 100644 --- a/tests/python/frontend/coreml/test_forward.py +++ b/tests/python/frontend/coreml/test_forward.py @@ -31,7 +31,7 @@ def get_tvm_output(func, x, params, target, ctx, out_shape=(1, 1000), input_name='image', dtype='float32'): - with relay.build_module.build_config(opt_level=3): + with relay.transform.build_config(opt_level=3): graph, lib, params = relay.build(func, target, params=params) m = graph_runtime.create(graph, lib, ctx) # set inputs @@ -72,7 +72,7 @@ def run_tvm_graph(coreml_model, target, ctx, input_data, input_name, output_shap dtype_dict = {input_name: input_data.dtype} func, params = relay.frontend.from_coreml(coreml_model, shape_dict) - with relay.build_module.build_config(opt_level=3): + with relay.transform.build_config(opt_level=3): graph, lib, params = relay.build(func, target, params=params) from tvm.contrib import graph_runtime diff --git a/tests/python/frontend/keras/test_forward.py b/tests/python/frontend/keras/test_forward.py index 35a9229443cb..8817d4faaeaa 100644 --- a/tests/python/frontend/keras/test_forward.py +++ b/tests/python/frontend/keras/test_forward.py @@ -43,7 +43,7 @@ def get_keras_output(xs, dtype='float32'): def get_tvm_output(xs, target, ctx, dtype='float32'): shape_dict = {name: x.shape for (name, x) in zip(keras_model.input_names, xs)} func, params = relay.frontend.from_keras(keras_model, shape_dict) - with relay.build_module.build_config(opt_level=2): + with relay.transform.build_config(opt_level=2): graph, lib, params = relay.build(func, target, params=params) m = graph_runtime.create(graph, lib, ctx) for name, x in zip(keras_model.input_names, xs): diff --git a/tutorials/frontend/from_tflite.py b/tutorials/frontend/from_tflite.py index f8686e9d20ab..01669818bb99 100644 --- a/tutorials/frontend/from_tflite.py +++ b/tutorials/frontend/from_tflite.py @@ -144,7 +144,7 @@ def extract(path): # target x86 CPU target = "llvm" -with relay.build_module.build_config(opt_level=3): +with relay.transform.build_config(opt_level=3): graph, lib, params = relay.build(func, target, params=params) ######################################################################