From b540851b71dde96bbf2b3a8c3474e4287a0be627 Mon Sep 17 00:00:00 2001 From: Zhi Chen Date: Wed, 22 May 2019 20:57:20 +0000 Subject: [PATCH 1/7] merge passcontext and buildconfig --- include/tvm/relay/transform.h | 117 +++++++++-- python/tvm/relay/__init__.py | 3 +- python/tvm/relay/build_module.py | 98 ++------- python/tvm/relay/quantize/quantize.py | 14 +- python/tvm/relay/transform.py | 122 ++++++++--- src/relay/pass/pass_manager.cc | 284 ++++++++++++++++++-------- 6 files changed, 422 insertions(+), 216 deletions(-) diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h index ba25483dfbb2..8cd51cc965d3 100644 --- a/include/tvm/relay/transform.h +++ b/include/tvm/relay/transform.h @@ -67,6 +67,40 @@ namespace tvm { namespace relay { namespace transform { +/*! + * \brief A data structure to map the names of specific optimizations to + * numeric optimization levels + */ +struct OptPassLevel { + static const std::unordered_map CreateMap() { + const std::unordered_map m = { + {"SimplifyInference", 0}, + {"OpFusion", 1}, + {"FoldConstant", 2}, + {"CombineParallelConv2D", 3}, + {"FoldScaleAxis", 3}, + {"AlterOpLayout", 3}, + {"CanonicalizeOps", 3}, + {"EliminateCommonSubexpr", 3} + }; + return m; + } + /*! + * \brief Get level for an optimization pass + * + * \param key pass name + * \return int level + */ + int operator[](const std::string& key) const { + const auto data = CreateMap(); + auto it = data.find(key); + if (it == data.end()) { + return -1; + } + return it->second; + } +}; + /* * \brief The context of pass. */ @@ -83,18 +117,81 @@ class PassContextNode : public RelayNode { */ ErrorReporter err_reporter; + /*! \brief The default optimization level. */ + int opt_level{2}; + + /*! \brief CPU is the default fallback device for heterogeneous execution. */ + int fallback_device{static_cast(kDLCPU)}; + + /*! \brief The list of required passes. */ + tvm::Array required_pass; + /*! \brief The list of disabled passes. */ + tvm::Array disabled_pass; + + /*! + * \brief A helper struct to get the optimization pass name to opt level + * mapping. + */ + OptPassLevel OPT_PASS_LEVEL; + + /*! + * \brief Convert a list of tvm StringImm to a `std::string` set. + * + * \param input. The input StringImm array. + * + * \return The coverted `std::strin`g set. + */ + std::unordered_set ToStringSet( + const tvm::Array& input) const; + + /*! + * \brief Check if a pass is enabled. + * + * \param pass_name The name of an optimization/analysis pass. + * + * \return true if the pass is enabled. Otherwise, false. + */ + bool pass_enabled(const std::string& pass_name) const; + PassContextNode() = default; void VisitAttrs(tvm::AttrVisitor* v) final { + v->Visit("opt_level", &opt_level); + v->Visit("fallback_device", &fallback_device); + v->Visit("required_pass", &required_pass); + v->Visit("disabled_pass", &disabled_pass); } - TVM_DLL static PassContext make(); - static constexpr const char* _type_key = "relay.PassContext"; TVM_DECLARE_NODE_TYPE_INFO(PassContextNode, RelayNode); }; -TVM_DEFINE_NODE_REF(PassContext, PassContextNode) +class PassContext : public NodeRef { + public: + PassContext() {} + explicit PassContext(tvm::NodePtr n) : NodeRef(n) {} + + TVM_DLL PassContext(int opt_level, int fallback_device, + tvm::Array required_pass, + tvm::Array disabled_pass); + + // The entry of a pass context scope. + TVM_DLL static void EnterWithScope(const PassContext& pass_ctx); + // The exit of a pass context scope. + TVM_DLL static void ExitWithScope(); + // Get the currently used pass context. + TVM_DLL static PassContext Current(); + + const PassContextNode* operator->() const; + + using ContainerType = PassContextNode; + class Internal; + + private: + // Classes to get the Python `with` like syntax. Enabled after #3231 is merged + // friend class Internal; + // friend class With; +}; /* * \brief The meta data of a pass. @@ -149,13 +246,6 @@ class PassNode : public RelayNode { * \brief Get the pass information/meta data. */ virtual PassInfo Info() const = 0; - /*! - * \brief Set the context information for a pass. - * - * \param pass_ctx The context information for a certain pass. - */ - virtual void SetContext(const PassContext& pass_ctx) = 0; - /*! * \brief Execute the optimization pass using a functor. * @@ -165,6 +255,9 @@ class PassNode : public RelayNode { */ virtual Module operator()(const Module& mod) const = 0; + virtual Module Apply(const Module& mod, + const PassContext& pass_ctx) const = 0; + void VisitAttrs(tvm::AttrVisitor* v) override {} static constexpr const char* _type_key = "relay.Pass"; @@ -191,11 +284,9 @@ class Sequential : public Pass { * \brief The constructor of `Sequential`. * \param passes The passes to apply. * \param pass_info The pass metadata. - * \param disabled The passes that will not be applied. */ TVM_DLL Sequential(tvm::Array passes, - PassInfo pass_info, - tvm::Array disabled); + PassInfo pass_info); Sequential() = default; explicit Sequential(tvm::NodePtr<::tvm::Node> n) : Pass(n) {} diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py index d832c8988795..1c8f5d6ceed3 100644 --- a/python/tvm/relay/__init__.py +++ b/python/tvm/relay/__init__.py @@ -26,7 +26,8 @@ from . import adt from . import ir_pass from . import transform -from .build_module import build, build_config, create_executor +from .build_module import build, create_executor +from .transform import build_config from . import prelude from . import parser from . import debug diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py index d0ad78fee67f..4b57f176a012 100644 --- a/python/tvm/relay/build_module.py +++ b/python/tvm/relay/build_module.py @@ -28,81 +28,10 @@ from . import ir_pass from . import ty as _ty from . import expr as _expr +from . import transform as _transform from .backend import interpreter as _interpreter from .backend.vm import VMExecutor -class BuildConfig(object): - """Configuration scope to set a build config option. - - Parameters - ---------- - kwargs - Keyword arguments of configurations to set. - """ - current = None - defaults = { - "opt_level": 2, - "add_pass": None, - "disable_pass": None, - "fallback_device": None, - } - - def __init__(self, **kwargs): - self._old_scope = None - for k, _ in kwargs.items(): - if k not in BuildConfig.defaults: - raise ValueError("invalid argument %s, candidates are %s" % - (k, BuildConfig.defaults.keys())) - self._attr = kwargs - - def __getattr__(self, name): - if name not in self._attr: - return BuildConfig.defaults[name] - return self._attr[name] - - def __enter__(self): - # pylint: disable=protected-access - self._old_scope = BuildConfig.current - attr = BuildConfig.current._attr.copy() - attr.update(self._attr) - self._attr = attr - BuildConfig.current = self - return self - - def __exit__(self, ptype, value, trace): - assert self._old_scope - BuildConfig.current = self._old_scope - - -BuildConfig.current = BuildConfig() - - -def build_config(**kwargs): - """Configure the build behavior by setting config variables. - - Parameters - ---------- - opt_level: int, default=2 - Optimization level. See OPT_PASS_LEVEL for level of each pass. - - add_pass: set of str - Optimization pass to be added regardless of optimization level. - - disable_pass: set of str - Optimization pass to be disabled during optimization. - - fallback_device : str or tvm.TVMContext - The fallback device. It is also used as the default device for - operators without specified device during heterogeneous execution. - - Returns - ------- - config: BuildConfig - The build configuration - """ - return BuildConfig(**kwargs) - - def _update_target(target): target = target if target else _target.current_target() if target is None: @@ -189,7 +118,7 @@ def build(self, func, target=None, target_host=None, params=None): return graph_json, mod, params def _setup_build_config(self, params): - cfg = BuildConfig.current + cfg = _transform.current_pass_context() # Set opt_level. self.set_opt_level(cfg.opt_level) @@ -199,24 +128,24 @@ def _setup_build_config(self, params): self.set_fallback_device(cfg.fallback_device) # Add required passes. - if cfg.add_pass: + if cfg.required_pass: passes = set() - if isinstance(cfg.add_pass, (list, tuple, set)): - passes = set(cfg.add_pass) + if isinstance(cfg.required_pass, (list, tuple, set)): + passes = set(cfg.required_pass) else: raise TypeError("add_pass must be list, tuple, or set, but " + - "got {}".format(type(cfg.add_pass))) + "got {}".format(type(cfg.required_pass))) for pass_name in passes: self.add_pass(pass_name) # Add disabled passes. - if cfg.disable_pass: + if cfg.disabled_pass: passes = set() - if isinstance(cfg.disable_pass, (list, tuple, set)): - passes = set(cfg.disable_pass) + if isinstance(cfg.disabled_pass, (list, tuple, set)): + passes = set(cfg.disabled_pass) else: raise TypeError("disable_pass must be list, tuple, or set, " + - "but got {}".format(type(cfg.disable_pass))) + "but got {}".format(type(cfg.disabled_pass))) for pass_name in passes: self.disable_pass(pass_name) @@ -287,12 +216,11 @@ def set_fallback_device(self, fallback_device): fallback_device : str or tvm.TVMContext The fallback device used for heterogeneous execution. """ - if isinstance(fallback_device, str): + if isinstance(fallback_device, (int, str)): fallback_device = _nd.context(fallback_device) if not isinstance(fallback_device, TVMContext): - raise TypeError("fallback_device is expected to be str " + - "TVMContext, or dict of device name to target, " + - "but received: {}".format(type(fallback_device))) + raise TypeError("fallback_device is expected to be str, int, or " + + "TVMContext but received: {}".format(type(fallback_device))) self._set_fallback_device(fallback_device.device_type) diff --git a/python/tvm/relay/quantize/quantize.py b/python/tvm/relay/quantize/quantize.py index 7fd0099e64a2..2423e76d308a 100644 --- a/python/tvm/relay/quantize/quantize.py +++ b/python/tvm/relay/quantize/quantize.py @@ -22,7 +22,7 @@ from . import _quantize from .. import expr as _expr from .. import ir_pass as _ir_pass -from .. import build_module as _build +from .. import transform as _transform from .. import op as _op from ... import make as _make from ..base import NodeBase, register_relay_node @@ -301,7 +301,7 @@ def optimize(func, params=None): "FoldConstant", "CanonicalizeOps"] - cfg = _build.build_config(add_pass=opt_passes) + cfg = _transform.build_config(required_pass=opt_passes) if params: name_dict = {} @@ -321,25 +321,25 @@ def optimize(func, params=None): bind_dict[arg] = _expr.const(v) func = _expr.bind(func, bind_dict) - if "SimplifyInference" in cfg.add_pass: + if "SimplifyInference" in cfg.required_pass: func = _ir_pass.infer_type(func) func = _ir_pass.simplify_inference(func) - if "FoldConstant" in cfg.add_pass: + if "FoldConstant" in cfg.required_pass: func = _ir_pass.fold_constant(func) - if "FoldScaleAxis" in cfg.add_pass: + if "FoldScaleAxis" in cfg.required_pass: func = _ir_pass.infer_type(func) func = _ir_pass.backward_fold_scale_axis(func) func = _ir_pass.infer_type(func) func = _ir_pass.forward_fold_scale_axis(func) func = _ir_pass.fold_constant(func) - if "CanonicalizeOps" in cfg.add_pass: + if "CanonicalizeOps" in cfg.required_pass: func = _ir_pass.infer_type(func) func = _ir_pass.canonicalize_ops(func) - if "FoldConstant" in cfg.add_pass: + if "FoldConstant" in cfg.required_pass: func = _ir_pass.fold_constant(func) return func diff --git a/python/tvm/relay/transform.py b/python/tvm/relay/transform.py index 877538afea34..c41bc503b3a9 100644 --- a/python/tvm/relay/transform.py +++ b/python/tvm/relay/transform.py @@ -23,8 +23,10 @@ """ import types +from tvm._ffi.runtime_ctypes import TVMContext from . import _transform from .base import RelayNode, register_relay_node +from .. import nd as _nd @register_relay_node @@ -57,10 +59,99 @@ class PassContext(RelayNode): Each pass context contains a number of auxiliary information that is used to help an optimization pass. Such information includes the error reporter to record the errors of during the optimization, etc. + + opt_level : Optional[int] + The optimization level of this pass. + + fallback_device : Optional[int] + The fallback device type. It is also used as the default device for + operators that are not annotated during heterogeneous execution. + + required_pass : Optional[List[str]] + The list of passes that are required by a certain pass. + + disabled_pass : Optional[List[str]] + The list of passes that are disabled. """ + defaults = { + "opt_level": 2, + "required_pass": None, + "disabled_pass": None, + "fallback_device": _nd.cpu(), + } + + def __init__(self, **kwargs): + for k, _ in kwargs.items(): + if k not in PassContext.defaults: + raise ValueError("invalid argument %s, candidates are %s" % + (k, PassContext.defaults.keys())) + + fallback_device = kwargs["fallback_device"] if "fallback_device" in \ + kwargs else PassContext.defaults["fallback_device"] + if isinstance(fallback_device, str): + fallback_device = _nd.context(fallback_device).device_type + elif isinstance(fallback_device, TVMContext): + fallback_device = fallback_device.device_type + if not isinstance(fallback_device, int): + raise TypeError("required_pass is expected to be the type of " + + "int/str/TVMContext.") + + required = kwargs["required_pass"] if "required_pass" in kwargs \ + else PassContext.defaults["required_pass"] + required = required if required else [] + if not isinstance(required, (list, tuple)): + raise TypeError("required_pass is expected to be the type of " + + "list/tuple/set.") + + disabled = kwargs["disabled_pass"] if "disabled_pass" in kwargs \ + else PassContext.defaults["disabled_pass"] + disabled = disabled if disabled else [] + if not isinstance(disabled, (list, tuple)): + raise TypeError("disabled_pass is expected to be the type of " + + "list/tuple/set.") + + opt_level = kwargs["opt_level"] if "opt_level" in kwargs \ + else PassContext.defaults["opt_level"] + + self.__init_handle_by_constructor__(_transform.PassContext, opt_level, + fallback_device, required, + disabled) + + def __enter__(self): + _transform.EnterPassContext(self) + return self + + def __exit__(self, ptype, value, trace): + _transform.ExitPassContext(self) + + +def current_pass_context(): + """Return the current pass context.""" + return _transform.GetCurrentPassContext() - def __init__(self): - self.__init_handle_by_constructor__(_transform.PassContext) + +def build_config(**kwargs): + """Configure the build behavior by setting config variables. + Parameters + ---------- + opt_level: int, default=2 + Optimization level. See OPT_PASS_LEVEL for level of each pass. + + required_pass: set of str + Optimization passes that are required regardless of optimization level. + + disabled_pass: set of str + Optimization passes to be disabled during optimization. + + fallback_device : int or tvm.TVMContext + The fallback device. It is also used as the default device for + operators without specified device during heterogeneous execution. + Returns + ------- + config: PassContext + The pass context for optimizations. + """ + return PassContext(**kwargs) @register_relay_node @@ -70,20 +161,6 @@ class Pass(RelayNode): conveniently interact with the base class. """ - def set_pass_context(self, pass_ctx): - """Setup the pass context for analysis and optimizations. This context - could be shared by different passes for sequential passes. - - Parameters - ---------- - pass_ctx : PassContext - The context that is used to help perform a certain pass or a series - of passes. - """ - if not isinstance(pass_ctx, PassContext): - raise TypeError("pass_ctx is expected to be the PassContext type") - _transform.SetContext(self, pass_ctx) - @property def info(self): """Get the pass meta.""" @@ -150,32 +227,23 @@ class Sequential(Pass): required : Optional[List[str]] The list of passes that the sequential pass is dependent on. - - disabled : Optional[List[str]] - A list of disabled passes. """ def __init__(self, passes=None, opt_level=2, name="sequential", - required=None, - disabled=None): + required=None): passes = passes if passes else [] if not isinstance(passes, (list, tuple)): raise TypeError("passes must be a list of Pass objects.") - disabled = disabled if disabled else [] - if not isinstance(disabled, (list, tuple)): - raise TypeError("disabled must be a list or tuple of pass names") - required = required if required else [] if not isinstance(required, (list, tuple)): raise TypeError("Required is expected to be the type of list/tuple.") self.__init_handle_by_constructor__(_transform.Sequential, - passes, opt_level, name, required, - disabled) + passes, opt_level, name, required) def module_pass(pass_func=None, opt_level=None, name=None, required=None): diff --git a/src/relay/pass/pass_manager.cc b/src/relay/pass/pass_manager.cc index a105b692aa9d..827bf900a915 100644 --- a/src/relay/pass/pass_manager.cc +++ b/src/relay/pass/pass_manager.cc @@ -22,15 +22,98 @@ * \file src/relay/pass/pass_manager.cc * \brief Relay pass manager implementation. */ +#include #include #include +#include +#include + namespace tvm { namespace relay { namespace transform { using tvm::IRPrinter; +PassContext::PassContext(int opt_level, int fallback_device, + tvm::Array required_pass, + tvm::Array disabled_pass) { + auto ctx = make_node(); + ctx->opt_level = opt_level; + ctx->fallback_device = fallback_device; + ctx->required_pass = std::move(required_pass); + ctx->disabled_pass = std::move(disabled_pass); + node_ = std::move(ctx); +} + +const PassContextNode* PassContext::operator->() const { + return static_cast(node_.get()); +} + +struct RelayPassContextThreadLocalEntry { + /*! \brief The default pass context. */ + PassContext default_context; + + /*! \brief The current pass context. */ + std::stack context_stack; + + RelayPassContextThreadLocalEntry() { + default_context = PassContext(make_node()); + } +}; + +/*! \brief Thread local store to hold the pass context. */ +typedef dmlc::ThreadLocalStore + RelayPassContextThreadLocalStore; + +void PassContext::EnterWithScope(const PassContext& pass_ctx) { + RelayPassContextThreadLocalEntry* entry = + RelayPassContextThreadLocalStore::Get(); + entry->context_stack.push(pass_ctx); +} + +void PassContext::ExitWithScope() { + RelayPassContextThreadLocalEntry* entry = + RelayPassContextThreadLocalStore::Get(); + CHECK(!entry->context_stack.empty()); + // CHECK(entry->context_stack.top().same_as(*this)); + entry->context_stack.pop(); +} + +PassContext PassContext::Current() { + RelayPassContextThreadLocalEntry* entry = + RelayPassContextThreadLocalStore::Get(); + if (!entry->context_stack.empty()) { + return entry->context_stack.top(); + } else { + return entry->default_context; + } +} + +std::unordered_set PassContextNode::ToStringSet( + const tvm::Array& input) const { + std::unordered_set ret; + for (const auto& it : input) { + const auto* strimm = it.as(); + CHECK(strimm); + ret.emplace(strimm->value); + } + return ret; +} + +bool PassContextNode::pass_enabled(const std::string& pass_name) const { + const auto required = ToStringSet(this->required_pass); + const auto disabled = ToStringSet(this->disabled_pass); + if (disabled.count(pass_name)) { + return false; + } + if (required.count(pass_name)) { + return true; + } + return opt_level >= OPT_PASS_LEVEL[pass_name]; +} + + class ModulePass; /*! @@ -67,16 +150,19 @@ class ModulePassNode : public PassNode { Module operator()(const Module& mod) const final; /*! - * \brief Get the pass information/meta data. + * \brief Apply a module pass on given pass context. + * + * \param mod The module that an optimization pass is applied on. + * \param mod The context that an optimization pass executes on. + * + * \return Return the updated module. */ - PassInfo Info() const { return pass_info; } + Module Apply(const Module& mod, const PassContext& pass_ctx) const final; /*! - * \brief Set the context information for a module pass. - * - * \param pass_ctx The context information for a module pass. + * \brief Get the pass information/meta data. */ - void SetContext(const PassContext& pass_ctx) final; + PassInfo Info() const { return pass_info; } TVM_DLL static ModulePass make( runtime::TypedPackedFunc pass_func, @@ -84,12 +170,6 @@ class ModulePassNode : public PassNode { static constexpr const char* _type_key = "relay.ModulePass"; TVM_DECLARE_NODE_TYPE_INFO(ModulePassNode, PassNode); - - private: - /*! - * \brief The context information that is used to help perform a module pass. - */ - PassContext pass_ctx_; }; RELAY_DEFINE_NODE_REF(ModulePass, ModulePassNode, Pass); @@ -133,16 +213,20 @@ class FunctionPassNode : public PassNode { Module operator()(const Module& mod) const final; /*! - * \brief Get the pass information/meta data. + * \brief Apply a function pass on given pass context. + * + * \param mod The module that an optimization pass is applied on. + * \param mod The context that an optimization pass executes on. + * + * \return Return the updated module. */ - PassInfo Info() const { return pass_info; } + Module Apply(const Module& mod, const PassContext& pass_ctx) const final; + /*! - * \brief Set the context information for a function-level pass. - * - * \param pass_ctx The context information for a function-level pass. + * \brief Get the pass information/meta data. */ - void SetContext(const PassContext& pass_ctx) final; + PassInfo Info() const { return pass_info; } TVM_DLL static FunctionPass make( runtime::TypedPackedFunc pass_func, @@ -160,11 +244,6 @@ class FunctionPassNode : public PassNode { * \return Return true if the function will be skipped, otherwise false. */ bool SkipFunction(const Function& func) const; - - /*! - * \brief The context information that is used to help perform a module pass. - */ - PassContext pass_ctx_; }; RELAY_DEFINE_NODE_REF(FunctionPass, FunctionPassNode, Pass); @@ -184,16 +263,9 @@ class SequentialNode : public PassNode { /*! \brief A list of passes that used to compose a sequential pass. */ tvm::Array passes; - /*! - * \brief A list of disabled passes that should be excluded when executing the - * sequential pass. - */ - tvm::Array disabled; - void VisitAttrs(tvm::AttrVisitor* v) final { v->Visit("pass_info", &pass_info); v->Visit("passes", &passes); - v->Visit("disabled", &disabled); } /*! @@ -224,7 +296,8 @@ class SequentialNode : public PassNode { */ void ResolveDependency(const Module& mod); - TVM_DLL std::vector DisabledPasses() const; + TVM_DLL std::unordered_set DisabledPasses( + const Array& disabled) const; /*! * \brief Perform optimizations on a series of passes. The aforementioned @@ -239,20 +312,17 @@ class SequentialNode : public PassNode { Module operator()(const Module& mod) const final; /*! - * \brief Set the context information for a sequential pass. + * \brief Apply a series of passes on given pass context. + * + * \param mod The module that these passes are applied on. + * \param mod The context that these passes execute on. * - * \param pass_ctx The context information for a sequential pass. + * \return Return the updated module. */ - void SetContext(const PassContext& pass_ctx) final; + Module Apply(const Module& mod, const PassContext& pass_ctx) const final; static constexpr const char* _type_key = "relay.Sequential"; TVM_DECLARE_NODE_TYPE_INFO(SequentialNode, PassNode); - - private: - /*! - * \brief The context information that is used to help perform a module pass. - */ - PassContext pass_ctx_; }; PassInfo PassInfoNode::make(int opt_level, std::string name, @@ -264,11 +334,6 @@ PassInfo PassInfoNode::make(int opt_level, std::string name, return PassInfo(pass_info); } -PassContext PassContextNode::make() { - auto ctx = make_node(); - return PassContext(ctx); -} - ModulePass ModulePassNode::make( runtime::TypedPackedFunc pass_func, PassInfo pass_info) { @@ -287,13 +352,22 @@ Module ModulePassNode::operator()(const Module& mod) const { LOG(INFO) << "Executing module pass : " << pass_info.operator->()->name << " with opt level: " << pass_info.operator->()->opt_level << "\n"; CHECK(mod.defined()); - auto updated_mod = pass_func(mod, pass_ctx_); + PassContext ctx = PassContext::Current(); + auto updated_mod = pass_func(mod, ctx); CHECK(updated_mod.defined()); return updated_mod; } -void ModulePassNode::SetContext(const PassContext& pass_ctx) { - pass_ctx_ = pass_ctx; +Module ModulePassNode::Apply(const Module& mod, + const PassContext& pass_ctx) const { + PassInfo pass_info = Info(); + LOG(INFO) << "Executing module pass : " << pass_info.operator->()->name + << " with opt level: " << pass_info.operator->()->opt_level << "\n"; + + CHECK(mod.defined()); + auto updated_mod = pass_func(mod, pass_ctx); + CHECK(updated_mod.defined()); + return updated_mod; } FunctionPass FunctionPassNode::make( @@ -312,26 +386,35 @@ Module FunctionPassNode::operator()(const Module& mod) const { LOG(INFO) << "Executing function pass : " << pass_info.operator->()->name << " with opt level: " << pass_info.operator->()->opt_level << "\n"; CHECK(mod.defined()); - std::vector> updated_funcs; - ModuleNode* mod_node = mod.operator->(); - for (const auto& it : mod_node->functions) { - if (!SkipFunction(it.second)) { - auto updated_func = pass_func(it.second, pass_ctx_); - CHECK(updated_func.defined()); - updated_funcs.push_back({std::move(it.first), std::move(updated_func)}); - } - } - - // Update the optimized functions. - for (const auto& it : updated_funcs) { - mod_node->Update(it.first, it.second); + Module new_mod = ModuleNode::make({}, mod->type_definitions); + PassContext ctx = PassContext::Current(); + + // Execute the pass function and return a new module. + for (const auto& it : mod->functions) { + auto updated_func = + SkipFunction(it.second) ? it.second : pass_func(it.second, ctx); + new_mod->Add(it.first, updated_func); } - return GetRef(mod_node); + return new_mod; } -void FunctionPassNode::SetContext(const PassContext& pass_ctx) { - pass_ctx_ = pass_ctx; +Module FunctionPassNode::Apply(const Module& mod, + const PassContext& pass_ctx) const { + PassInfo pass_info = Info(); + LOG(INFO) << "Executing function pass : " << pass_info.operator->()->name + << " with opt level: " << pass_info.operator->()->opt_level << "\n"; + CHECK(mod.defined()); + Module new_mod = ModuleNode::make({}, mod->type_definitions); + + // Execute the pass function and return a new module. + for (const auto& it : mod->functions) { + auto updated_func = + SkipFunction(it.second) ? it.second : pass_func(it.second, pass_ctx); + new_mod->Add(it.first, updated_func); + } + + return new_mod; } // TODO(zhiics) Create an enum attribute for FunctionNode @@ -342,13 +425,10 @@ bool FunctionPassNode::SkipFunction(const Function& func) const { return pval && pval->value != 0; } -Sequential::Sequential(tvm::Array passes, - PassInfo pass_info, - tvm::Array disabled) { +Sequential::Sequential(tvm::Array passes, PassInfo pass_info) { auto n = make_node(); n->passes = std::move(passes); n->pass_info = std::move(pass_info); - n->disabled = std::move(disabled); node_ = std::move(n); } @@ -378,18 +458,36 @@ void SequentialNode::ResolveDependency(const Module& mod) { << "\n"; } -std::vector SequentialNode::DisabledPasses() const { - std::vector ret; +std::unordered_set SequentialNode::DisabledPasses( + const Array& disabled) const { + std::unordered_set ret; for (const auto& it : disabled) { const auto* str = it.as(); CHECK(str) << "disabled passes must be string."; - ret.push_back(str->value); + ret.emplace(str->value); } return ret; } -void SequentialNode::SetContext(const PassContext& pass_ctx) { - pass_ctx_ = pass_ctx; +Module SequentialNode::Apply(const Module& module, + const PassContext& pass_ctx) const { + const auto* ctx_node = pass_ctx.operator->(); + int opt_level = ctx_node->opt_level; + auto disabled = DisabledPasses(ctx_node->disabled_pass); + Module mod = module; + for (const Pass& pass : passes) { + CHECK(pass.defined()) << "Found undefined pass for optimization."; + PassInfo info = pass->Info(); + const auto& pass_name = info.operator->()->name; + const auto& pass_opt_level = info.operator->()->opt_level; + // Skip the pass if its optimization level is higher that the one of in the + // pass context or if this pass is disabled. + if (pass_opt_level > opt_level || disabled.count(pass_name)) { + continue; + } + mod = pass->Apply(mod, pass_ctx); + } + return mod; } Pass CreateModulePass( @@ -481,9 +579,8 @@ TVM_REGISTER_API("relay._transform.Sequential") int opt_level = args[1]; std::string name = args[2]; tvm::Array required = args[3]; - tvm::Array disabled = args[4]; PassInfo pass_info = PassInfoNode::make(opt_level, name, required); - *ret = Sequential(passes, pass_info, disabled); + *ret = Sequential(passes, pass_info); }); TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) @@ -501,17 +598,17 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) p->stream << "]"; }); -TVM_REGISTER_API("relay._transform.SetContext") -.set_body([](TVMArgs args, TVMRetValue* ret) { - Pass pass = args[0]; - PassContext pass_ctx = args[1]; - pass->SetContext(pass_ctx); -}); - TVM_REGISTER_NODE_TYPE(PassContextNode); TVM_REGISTER_API("relay._transform.PassContext") -.set_body_typed(PassContextNode::make); +.set_body([](TVMArgs args, TVMRetValue* ret) { + int opt_level = args[0]; + int fallback_device = args[1]; + tvm::Array required = args[2]; + tvm::Array disabled = args[3]; + *ret = PassContext(opt_level, fallback_device, required, disabled); +}); +; TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) .set_dispatch([](const PassContextNode* node, @@ -521,6 +618,27 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) << "\n"; }); +// Enable after #3231 is merged. +// class PassContext::Internal { +// public: +// static void EnterScope(PassContext pass_ctx) { +// pass_ctx.EnterWithScope(); +// } +// +// static void ExitScope(PassContext pass_ctx) { +// pass_ctx.ExitWithScope(); +// } +// }; + +TVM_REGISTER_API("relay._transform.GetCurrentPassContext") +.set_body_typed(PassContext::Current); + +TVM_REGISTER_API("relay._transform.EnterPassContext") +.set_body_typed(PassContext::EnterWithScope); + +TVM_REGISTER_API("relay._transform.ExitPassContext") +.set_body_typed(PassContext::ExitWithScope); + } // namespace transform } // namespace relay } // namespace tvm From dda19211a64a67e3b4cd7d44e5d279ff87e8f3e1 Mon Sep 17 00:00:00 2001 From: Zhi Chen Date: Fri, 24 May 2019 00:51:41 +0000 Subject: [PATCH 2/7] more methods to sequential --- include/tvm/relay/transform.h | 63 ++++++----------- src/relay/pass/pass_manager.cc | 125 +++++++++++++++++++++------------ 2 files changed, 101 insertions(+), 87 deletions(-) diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h index 8cd51cc965d3..0475ff3b960d 100644 --- a/include/tvm/relay/transform.h +++ b/include/tvm/relay/transform.h @@ -71,20 +71,8 @@ namespace transform { * \brief A data structure to map the names of specific optimizations to * numeric optimization levels */ -struct OptPassLevel { - static const std::unordered_map CreateMap() { - const std::unordered_map m = { - {"SimplifyInference", 0}, - {"OpFusion", 1}, - {"FoldConstant", 2}, - {"CombineParallelConv2D", 3}, - {"FoldScaleAxis", 3}, - {"AlterOpLayout", 3}, - {"CanonicalizeOps", 3}, - {"EliminateCommonSubexpr", 3} - }; - return m; - } +class OptPassLevel { + public: /*! * \brief Get level for an optimization pass * @@ -99,6 +87,21 @@ struct OptPassLevel { } return it->second; } + + private: + static const std::unordered_map CreateMap() { + const std::unordered_map m = { + {"SimplifyInference", 0}, + {"OpFusion", 1}, + {"FoldConstant", 2}, + {"CombineParallelConv2D", 3}, + {"FoldScaleAxis", 3}, + {"AlterOpLayout", 3}, + {"CanonicalizeOps", 3}, + {"EliminateCommonSubexpr", 3} + }; + return m; + } }; /* @@ -128,31 +131,6 @@ class PassContextNode : public RelayNode { /*! \brief The list of disabled passes. */ tvm::Array disabled_pass; - /*! - * \brief A helper struct to get the optimization pass name to opt level - * mapping. - */ - OptPassLevel OPT_PASS_LEVEL; - - /*! - * \brief Convert a list of tvm StringImm to a `std::string` set. - * - * \param input. The input StringImm array. - * - * \return The coverted `std::strin`g set. - */ - std::unordered_set ToStringSet( - const tvm::Array& input) const; - - /*! - * \brief Check if a pass is enabled. - * - * \param pass_name The name of an optimization/analysis pass. - * - * \return true if the pass is enabled. Otherwise, false. - */ - bool pass_enabled(const std::string& pass_name) const; - PassContextNode() = default; void VisitAttrs(tvm::AttrVisitor* v) final { @@ -175,6 +153,7 @@ class PassContext : public NodeRef { tvm::Array required_pass, tvm::Array disabled_pass); + // Move exter/exit to private once #3231 is merged. // The entry of a pass context scope. TVM_DLL static void EnterWithScope(const PassContext& pass_ctx); // The exit of a pass context scope. @@ -185,7 +164,7 @@ class PassContext : public NodeRef { const PassContextNode* operator->() const; using ContainerType = PassContextNode; - class Internal; + // class Internal; private: // Classes to get the Python `with` like syntax. Enabled after #3231 is merged @@ -255,8 +234,8 @@ class PassNode : public RelayNode { */ virtual Module operator()(const Module& mod) const = 0; - virtual Module Apply(const Module& mod, - const PassContext& pass_ctx) const = 0; + virtual Module operator()(const Module& mod, + const PassContext& pass_ctx) const = 0; void VisitAttrs(tvm::AttrVisitor* v) override {} diff --git a/src/relay/pass/pass_manager.cc b/src/relay/pass/pass_manager.cc index 827bf900a915..04732f3599b7 100644 --- a/src/relay/pass/pass_manager.cc +++ b/src/relay/pass/pass_manager.cc @@ -25,7 +25,9 @@ #include #include #include +#include +#include #include #include @@ -90,30 +92,6 @@ PassContext PassContext::Current() { } } -std::unordered_set PassContextNode::ToStringSet( - const tvm::Array& input) const { - std::unordered_set ret; - for (const auto& it : input) { - const auto* strimm = it.as(); - CHECK(strimm); - ret.emplace(strimm->value); - } - return ret; -} - -bool PassContextNode::pass_enabled(const std::string& pass_name) const { - const auto required = ToStringSet(this->required_pass); - const auto disabled = ToStringSet(this->disabled_pass); - if (disabled.count(pass_name)) { - return false; - } - if (required.count(pass_name)) { - return true; - } - return opt_level >= OPT_PASS_LEVEL[pass_name]; -} - - class ModulePass; /*! @@ -150,14 +128,14 @@ class ModulePassNode : public PassNode { Module operator()(const Module& mod) const final; /*! - * \brief Apply a module pass on given pass context. + * \brief Run a module pass on given pass context. * * \param mod The module that an optimization pass is applied on. * \param mod The context that an optimization pass executes on. * * \return Return the updated module. */ - Module Apply(const Module& mod, const PassContext& pass_ctx) const final; + Module operator()(const Module& mod, const PassContext& pass_ctx) const final; /*! * \brief Get the pass information/meta data. @@ -213,15 +191,14 @@ class FunctionPassNode : public PassNode { Module operator()(const Module& mod) const final; /*! - * \brief Apply a function pass on given pass context. + * \brief Run a function pass on given pass context. * * \param mod The module that an optimization pass is applied on. * \param mod The context that an optimization pass executes on. * * \return Return the updated module. */ - Module Apply(const Module& mod, const PassContext& pass_ctx) const final; - + Module operator()(const Module& mod, const PassContext& pass_ctx) const final; /*! * \brief Get the pass information/meta data. @@ -261,6 +238,12 @@ class SequentialNode : public PassNode { /* \brief The pass meta data.*/ PassInfo pass_info; + /*! + * \brief A helper struct to get the optimization pass name to opt level + * mapping. + */ + OptPassLevel opt_pass_level; + /*! \brief A list of passes that used to compose a sequential pass. */ tvm::Array passes; void VisitAttrs(tvm::AttrVisitor* v) final { @@ -282,6 +265,15 @@ class SequentialNode : public PassNode { passes.push_back(pass); } + /*! + * \brief Check if a pass is enabled. + * + * \param pass_name The name of an optimization/analysis pass. + * + * \return true if the pass is enabled. Otherwise, false. + */ + bool pass_enabled(const std::string& pass_name) const; + /*! * \brief Resolve the pass dependency. It globs all required passes by * a given pass and executes them. @@ -296,9 +288,11 @@ class SequentialNode : public PassNode { */ void ResolveDependency(const Module& mod); - TVM_DLL std::unordered_set DisabledPasses( + std::unordered_set DisabledPasses( const Array& disabled) const; + std::unordered_set RequiredPasses( + const Array& disabled) const; /*! * \brief Perform optimizations on a series of passes. The aforementioned * typical pass manager jobs could be done by it. This function could @@ -312,14 +306,14 @@ class SequentialNode : public PassNode { Module operator()(const Module& mod) const final; /*! - * \brief Apply a series of passes on given pass context. + * \brief Run a series of passes on given pass context. * * \param mod The module that these passes are applied on. * \param mod The context that these passes execute on. * * \return Return the updated module. */ - Module Apply(const Module& mod, const PassContext& pass_ctx) const final; + Module operator()(const Module& mod, const PassContext& pass_ctx) const final; static constexpr const char* _type_key = "relay.Sequential"; TVM_DECLARE_NODE_TYPE_INFO(SequentialNode, PassNode); @@ -358,8 +352,8 @@ Module ModulePassNode::operator()(const Module& mod) const { return updated_mod; } -Module ModulePassNode::Apply(const Module& mod, - const PassContext& pass_ctx) const { +Module ModulePassNode::operator()(const Module& mod, + const PassContext& pass_ctx) const { PassInfo pass_info = Info(); LOG(INFO) << "Executing module pass : " << pass_info.operator->()->name << " with opt level: " << pass_info.operator->()->opt_level << "\n"; @@ -399,8 +393,8 @@ Module FunctionPassNode::operator()(const Module& mod) const { return new_mod; } -Module FunctionPassNode::Apply(const Module& mod, - const PassContext& pass_ctx) const { +Module FunctionPassNode::operator()(const Module& mod, + const PassContext& pass_ctx) const { PassInfo pass_info = Info(); LOG(INFO) << "Executing function pass : " << pass_info.operator->()->name << " with opt level: " << pass_info.operator->()->opt_level << "\n"; @@ -469,8 +463,36 @@ std::unordered_set SequentialNode::DisabledPasses( return ret; } -Module SequentialNode::Apply(const Module& module, - const PassContext& pass_ctx) const { +std::unordered_set SequentialNode::RequiredPasses( + const Array& required) const { + std::unordered_set ret; + for (const auto& it : required) { + const auto* str = it.as(); + CHECK(str) << "disabled passes must be string."; + ret.emplace(str->value); + } + return ret; +} + +bool SequentialNode::pass_enabled(const std::string& pass_name) const { + PassContext ctx = PassContext::Current(); + + const PassContextNode* ctx_node = ctx.operator->(); + auto required = RequiredPasses(ctx_node->required_pass); + auto disabled = DisabledPasses(ctx_node->required_pass); + + if (disabled.count(pass_name)) { + return false; + } + + if (required.count(pass_name)) { + return true; + } + return ctx_node->opt_level >= opt_pass_level[pass_name]; +} + +Module SequentialNode::operator()(const Module& module, + const PassContext& pass_ctx) const { const auto* ctx_node = pass_ctx.operator->(); int opt_level = ctx_node->opt_level; auto disabled = DisabledPasses(ctx_node->disabled_pass); @@ -485,7 +507,8 @@ Module SequentialNode::Apply(const Module& module, if (pass_opt_level > opt_level || disabled.count(pass_name)) { continue; } - mod = pass->Apply(mod, pass_ctx); + const auto* pn = pass.operator->(); + mod = (*pn)(mod, pass_ctx); } return mod; } @@ -608,14 +631,26 @@ TVM_REGISTER_API("relay._transform.PassContext") tvm::Array disabled = args[3]; *ret = PassContext(opt_level, fallback_device, required, disabled); }); -; TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) .set_dispatch([](const PassContextNode* node, - tvm::IRPrinter* p) { - p->stream << "TODO(zhiics): printing context"; - LOG(FATAL) << "PassContext printer has not been implemented yet." - << "\n"; + tvm::IRPrinter* p) { + p->stream << "Pass context information: " << "\n"; + p->stream << "\topt_level: " << node->opt_level << "\n"; + p->stream << "\tfallback device: " << runtime::DeviceName(node->opt_level) + << "\n"; + + p->stream << "\trequired passes: [" << node->opt_level; + for (const auto& it : node->required_pass) { + p->stream << it << " "; + } + p->stream << "]\n"; + + p->stream << "\tdisabled passes: [" << node->opt_level; + for (const auto& it : node->disabled_pass) { + p->stream << it << " "; + } + p->stream << "]"; }); // Enable after #3231 is merged. @@ -624,7 +659,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) // static void EnterScope(PassContext pass_ctx) { // pass_ctx.EnterWithScope(); // } -// +// // static void ExitScope(PassContext pass_ctx) { // pass_ctx.ExitWithScope(); // } From 6a48444d38e42236e95cfbc0391882c7f6da8f4b Mon Sep 17 00:00:00 2001 From: Zhi Chen Date: Fri, 24 May 2019 05:16:59 +0000 Subject: [PATCH 3/7] remove kwargs --- include/tvm/relay/transform.h | 1 + python/tvm/relay/transform.py | 49 ++++++++++++++--------------------- 2 files changed, 20 insertions(+), 30 deletions(-) diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h index 0475ff3b960d..586e723444e7 100644 --- a/include/tvm/relay/transform.h +++ b/include/tvm/relay/transform.h @@ -61,6 +61,7 @@ #include #include #include +#include #include namespace tvm { diff --git a/python/tvm/relay/transform.py b/python/tvm/relay/transform.py index c41bc503b3a9..64242242ad1d 100644 --- a/python/tvm/relay/transform.py +++ b/python/tvm/relay/transform.py @@ -73,21 +73,11 @@ class PassContext(RelayNode): disabled_pass : Optional[List[str]] The list of passes that are disabled. """ - defaults = { - "opt_level": 2, - "required_pass": None, - "disabled_pass": None, - "fallback_device": _nd.cpu(), - } - - def __init__(self, **kwargs): - for k, _ in kwargs.items(): - if k not in PassContext.defaults: - raise ValueError("invalid argument %s, candidates are %s" % - (k, PassContext.defaults.keys())) - - fallback_device = kwargs["fallback_device"] if "fallback_device" in \ - kwargs else PassContext.defaults["fallback_device"] + def __init__(self, + opt_level=2, + fallback_device=_nd.cpu(), + required_pass=None, + disabled_pass=None): if isinstance(fallback_device, str): fallback_device = _nd.context(fallback_device).device_type elif isinstance(fallback_device, TVMContext): @@ -96,23 +86,16 @@ def __init__(self, **kwargs): raise TypeError("required_pass is expected to be the type of " + "int/str/TVMContext.") - required = kwargs["required_pass"] if "required_pass" in kwargs \ - else PassContext.defaults["required_pass"] - required = required if required else [] + required = list(required_pass) if required_pass else [] if not isinstance(required, (list, tuple)): raise TypeError("required_pass is expected to be the type of " + "list/tuple/set.") - disabled = kwargs["disabled_pass"] if "disabled_pass" in kwargs \ - else PassContext.defaults["disabled_pass"] - disabled = disabled if disabled else [] + disabled = list(disabled_pass) if disabled_pass else [] if not isinstance(disabled, (list, tuple)): raise TypeError("disabled_pass is expected to be the type of " + "list/tuple/set.") - opt_level = kwargs["opt_level"] if "opt_level" in kwargs \ - else PassContext.defaults["opt_level"] - self.__init_handle_by_constructor__(_transform.PassContext, opt_level, fallback_device, required, disabled) @@ -130,12 +113,20 @@ def current_pass_context(): return _transform.GetCurrentPassContext() -def build_config(**kwargs): +def build_config(opt_level=2, + fallback_device=_nd.cpu(), + required_pass=None, + disabled_pass=None): """Configure the build behavior by setting config variables. Parameters ---------- opt_level: int, default=2 - Optimization level. See OPT_PASS_LEVEL for level of each pass. + Optimization level. See include/tvm/relay/transform.h for level of each + pass. + + fallback_device : int or tvm.TVMContext + The fallback device. It is also used as the default device for + operators without specified device during heterogeneous execution. required_pass: set of str Optimization passes that are required regardless of optimization level. @@ -143,15 +134,13 @@ def build_config(**kwargs): disabled_pass: set of str Optimization passes to be disabled during optimization. - fallback_device : int or tvm.TVMContext - The fallback device. It is also used as the default device for - operators without specified device during heterogeneous execution. Returns ------- config: PassContext The pass context for optimizations. """ - return PassContext(**kwargs) + return PassContext(opt_level, fallback_device, required_pass, + disabled_pass) @register_relay_node From 9a406afc3219e9ef501e55aec4d9d5100fc3d504 Mon Sep 17 00:00:00 2001 From: Zhi Chen Date: Fri, 24 May 2019 05:53:49 +0000 Subject: [PATCH 4/7] transform.build_config --- tutorials/frontend/from_tflite.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tutorials/frontend/from_tflite.py b/tutorials/frontend/from_tflite.py index f8686e9d20ab..01669818bb99 100644 --- a/tutorials/frontend/from_tflite.py +++ b/tutorials/frontend/from_tflite.py @@ -144,7 +144,7 @@ def extract(path): # target x86 CPU target = "llvm" -with relay.build_module.build_config(opt_level=3): +with relay.transform.build_config(opt_level=3): graph, lib, params = relay.build(func, target, params=params) ###################################################################### From 48332a5e96bc3b50e5c06ce91a2f809c2d6cc14b Mon Sep 17 00:00:00 2001 From: Zhi Chen Date: Fri, 24 May 2019 06:38:16 +0000 Subject: [PATCH 5/7] more fix --- docs/api/python/relay/build_module.rst | 8 -- docs/api/python/relay/transform.rst | 47 ++++++++ include/tvm/relay/transform.h | 41 +------ python/tvm/relay/transform.py | 30 +++-- src/relay/pass/pass_manager.cc | 116 +++++++------------ tests/python/frontend/coreml/test_forward.py | 4 +- tests/python/frontend/keras/test_forward.py | 2 +- 7 files changed, 118 insertions(+), 130 deletions(-) create mode 100644 docs/api/python/relay/transform.rst diff --git a/docs/api/python/relay/build_module.rst b/docs/api/python/relay/build_module.rst index 28dadea21e78..26164bf1ade9 100644 --- a/docs/api/python/relay/build_module.rst +++ b/docs/api/python/relay/build_module.rst @@ -22,17 +22,9 @@ tvm.relay.build_module .. autofunction:: tvm.relay.build_module.build -.. autofunction:: tvm.relay.build_module.build_config - .. autofunction:: tvm.relay.build_module.optimize .. autofunction:: tvm.relay.build_module.create_executor -.. autoclass:: tvm.relay.build_module.BuildConfig - :members: - -.. autofunction:: tvm.relay.build_module.build_config - :members: - .. autoclass:: tvm.relay.build_module.GraphExecutor :members: diff --git a/docs/api/python/relay/transform.rst b/docs/api/python/relay/transform.rst new file mode 100644 index 000000000000..c618628d121e --- /dev/null +++ b/docs/api/python/relay/transform.rst @@ -0,0 +1,47 @@ +.. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + +.. http://www.apache.org/licenses/LICENSE-2.0 + +.. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + +tvm.relay.transform +---------------------- + +.. automodule:: tvm.relay.transform + +.. autofunction:: tvm.relay.transform.build_config + +.. autofunction:: tvm.relay.transform.module_pass + +.. autofunction:: tvm.relay.transform.function_pass + +.. autofunction:: tvm.relay.transform.current_pass_context + +.. autoclass:: tvm.relay.transform.Pass + :members: + +.. autoclass:: tvm.relay.transform.PassInfo + :members: + +.. autoclass:: tvm.relay.transform.PassContext + :members: + +.. autoclass:: tvm.relay.transform.ModulePass + :members: + +.. autoclass:: tvm.relay.transform.FunctionPass + :members: + +.. autoclass:: tvm.relay.transform.Sequential + :members: diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h index 586e723444e7..61005a0835a1 100644 --- a/include/tvm/relay/transform.h +++ b/include/tvm/relay/transform.h @@ -68,43 +68,6 @@ namespace tvm { namespace relay { namespace transform { -/*! - * \brief A data structure to map the names of specific optimizations to - * numeric optimization levels - */ -class OptPassLevel { - public: - /*! - * \brief Get level for an optimization pass - * - * \param key pass name - * \return int level - */ - int operator[](const std::string& key) const { - const auto data = CreateMap(); - auto it = data.find(key); - if (it == data.end()) { - return -1; - } - return it->second; - } - - private: - static const std::unordered_map CreateMap() { - const std::unordered_map m = { - {"SimplifyInference", 0}, - {"OpFusion", 1}, - {"FoldConstant", 2}, - {"CombineParallelConv2D", 3}, - {"FoldScaleAxis", 3}, - {"AlterOpLayout", 3}, - {"CanonicalizeOps", 3}, - {"EliminateCommonSubexpr", 3} - }; - return m; - } -}; - /* * \brief The context of pass. */ @@ -233,7 +196,9 @@ class PassNode : public RelayNode { * * \return The updated module. */ - virtual Module operator()(const Module& mod) const = 0; + Module operator()(const Module& mod) const { + return this->operator()(mod, PassContext::Current()); + } virtual Module operator()(const Module& mod, const PassContext& pass_ctx) const = 0; diff --git a/python/tvm/relay/transform.py b/python/tvm/relay/transform.py index 64242242ad1d..5ddcd9a21e7d 100644 --- a/python/tvm/relay/transform.py +++ b/python/tvm/relay/transform.py @@ -118,25 +118,39 @@ def build_config(opt_level=2, required_pass=None, disabled_pass=None): """Configure the build behavior by setting config variables. + Parameters ---------- - opt_level: int, default=2 - Optimization level. See include/tvm/relay/transform.h for level of each - pass. - - fallback_device : int or tvm.TVMContext + opt_level: int, optional + Optimization level. The optimization pass name and level are as the + following: + + .. code-block:: python + + OPT_PASS_LEVEL = { + "SimplifyInference": 0, + "OpFusion": 1, + "FoldConstant": 2, + "CombineParallelConv2D": 3, + "FoldScaleAxis": 3, + "AlterOpLayout": 3, + "CanonicalizeOps": 3, + "EliminateCommonSubexpr": 3, + } + + fallback_device : int, str, or tvm.TVMContext, optional The fallback device. It is also used as the default device for operators without specified device during heterogeneous execution. - required_pass: set of str + required_pass: set of str, optional Optimization passes that are required regardless of optimization level. - disabled_pass: set of str + disabled_pass: set of str, optional Optimization passes to be disabled during optimization. Returns ------- - config: PassContext + pass_context: PassContext The pass context for optimizations. """ return PassContext(opt_level, fallback_device, required_pass, diff --git a/src/relay/pass/pass_manager.cc b/src/relay/pass/pass_manager.cc index 04732f3599b7..41bf119d6e6f 100644 --- a/src/relay/pass/pass_manager.cc +++ b/src/relay/pass/pass_manager.cc @@ -37,6 +37,43 @@ namespace transform { using tvm::IRPrinter; +/*! + * \brief A data structure to map the names of specific optimizations to + * numeric optimization levels + */ +class OptPassLevel { + public: + /*! + * \brief Get level for an optimization pass + * + * \param key pass name + * \return int level + */ + int operator[](const std::string& key) const { + const auto data = CreateMap(); + auto it = data.find(key); + if (it == data.end()) { + return -1; + } + return it->second; + } + + private: + static const std::unordered_map CreateMap() { + const std::unordered_map m = { + {"SimplifyInference", 0}, + {"OpFusion", 1}, + {"FoldConstant", 2}, + {"CombineParallelConv2D", 3}, + {"FoldScaleAxis", 3}, + {"AlterOpLayout", 3}, + {"CanonicalizeOps", 3}, + {"EliminateCommonSubexpr", 3} + }; + return m; + } +}; + PassContext::PassContext(int opt_level, int fallback_device, tvm::Array required_pass, tvm::Array disabled_pass) { @@ -118,15 +155,6 @@ class ModulePassNode : public PassNode { v->Visit("pass_info", &pass_info); } - /*! - * \brief Run a module pass on a certain module. - * - * \param mod The module that an optimization pass runs on. - * - * \return Return the updated module. - */ - Module operator()(const Module& mod) const final; - /*! * \brief Run a module pass on given pass context. * @@ -181,15 +209,6 @@ class FunctionPassNode : public PassNode { v->Visit("pass_info", &pass_info); } - /*! - * \brief Run a function pass on a certain module. - * - * \param mod The module that an optimization pass runs on. - * - * \return Return the updated module. - */ - Module operator()(const Module& mod) const final; - /*! * \brief Run a function pass on given pass context. * @@ -293,23 +312,15 @@ class SequentialNode : public PassNode { std::unordered_set RequiredPasses( const Array& disabled) const; + /*! * \brief Perform optimizations on a series of passes. The aforementioned * typical pass manager jobs could be done by it. This function could * be overloaded to focus on different metrics, i.e. performance, * memory footprint, etc. * - * \param mod The module that an optimization pass runs on. - * - * \return Return the updated module. - */ - Module operator()(const Module& mod) const final; - - /*! - * \brief Run a series of passes on given pass context. - * * \param mod The module that these passes are applied on. - * \param mod The context that these passes execute on. + * \param pass_ctx The context that these passes execute on. * * \return Return the updated module. */ @@ -338,20 +349,7 @@ ModulePass ModulePassNode::make( } // Module -> Module optimizations. -// TODO(zhiics) 1. Check and handle the required passes. -// 2. Probably use CoW for all places that use module instead of -// returning the updated one. -Module ModulePassNode::operator()(const Module& mod) const { - PassInfo pass_info = Info(); - LOG(INFO) << "Executing module pass : " << pass_info.operator->()->name - << " with opt level: " << pass_info.operator->()->opt_level << "\n"; - CHECK(mod.defined()); - PassContext ctx = PassContext::Current(); - auto updated_mod = pass_func(mod, ctx); - CHECK(updated_mod.defined()); - return updated_mod; -} - +// TODO(zhiics) Check and handle the required passes. Module ModulePassNode::operator()(const Module& mod, const PassContext& pass_ctx) const { PassInfo pass_info = Info(); @@ -375,24 +373,6 @@ FunctionPass FunctionPassNode::make( // Perform Module -> Module optimizations at the Function level. // TODO(zhiics) Check and handle the required passes. -Module FunctionPassNode::operator()(const Module& mod) const { - PassInfo pass_info = Info(); - LOG(INFO) << "Executing function pass : " << pass_info.operator->()->name - << " with opt level: " << pass_info.operator->()->opt_level << "\n"; - CHECK(mod.defined()); - Module new_mod = ModuleNode::make({}, mod->type_definitions); - PassContext ctx = PassContext::Current(); - - // Execute the pass function and return a new module. - for (const auto& it : mod->functions) { - auto updated_func = - SkipFunction(it.second) ? it.second : pass_func(it.second, ctx); - new_mod->Add(it.first, updated_func); - } - - return new_mod; -} - Module FunctionPassNode::operator()(const Module& mod, const PassContext& pass_ctx) const { PassInfo pass_info = Info(); @@ -430,19 +410,6 @@ const SequentialNode* Sequential::operator->() const { return static_cast(this->node_.get()); } -// TODO(jroesch, zhiics): we currenlty only sequentially execute each pass in -// a Sequential without the consideration of their orders. The phase -// ordering problem needed to be handled in the future. -Module SequentialNode::operator()(const Module& module) const { - Module mod = module; - for (const Pass& pass : passes) { - CHECK(pass.defined()) << "Found undefined pass for optimization."; - const auto* pn = pass.operator->(); - mod = (*pn)(mod); - } - return mod; -} - void SequentialNode::ResolveDependency(const Module& mod) { // TODO(zhiics) Implement it. // 1. Consider the required passes for each pass. @@ -491,6 +458,9 @@ bool SequentialNode::pass_enabled(const std::string& pass_name) const { return ctx_node->opt_level >= opt_pass_level[pass_name]; } +// TODO(zhiics): we currenlty only sequentially execute each pass in +// a Sequential without the consideration of their orders. The phase +// ordering problem needed to be handled in the future. Module SequentialNode::operator()(const Module& module, const PassContext& pass_ctx) const { const auto* ctx_node = pass_ctx.operator->(); diff --git a/tests/python/frontend/coreml/test_forward.py b/tests/python/frontend/coreml/test_forward.py index 0fed49079fd2..da78e960091d 100644 --- a/tests/python/frontend/coreml/test_forward.py +++ b/tests/python/frontend/coreml/test_forward.py @@ -31,7 +31,7 @@ def get_tvm_output(func, x, params, target, ctx, out_shape=(1, 1000), input_name='image', dtype='float32'): - with relay.build_module.build_config(opt_level=3): + with relay.transform.build_config(opt_level=3): graph, lib, params = relay.build(func, target, params=params) m = graph_runtime.create(graph, lib, ctx) # set inputs @@ -72,7 +72,7 @@ def run_tvm_graph(coreml_model, target, ctx, input_data, input_name, output_shap dtype_dict = {input_name: input_data.dtype} func, params = relay.frontend.from_coreml(coreml_model, shape_dict) - with relay.build_module.build_config(opt_level=3): + with relay.transform.build_config(opt_level=3): graph, lib, params = relay.build(func, target, params=params) from tvm.contrib import graph_runtime diff --git a/tests/python/frontend/keras/test_forward.py b/tests/python/frontend/keras/test_forward.py index 35a9229443cb..8817d4faaeaa 100644 --- a/tests/python/frontend/keras/test_forward.py +++ b/tests/python/frontend/keras/test_forward.py @@ -43,7 +43,7 @@ def get_keras_output(xs, dtype='float32'): def get_tvm_output(xs, target, ctx, dtype='float32'): shape_dict = {name: x.shape for (name, x) in zip(keras_model.input_names, xs)} func, params = relay.frontend.from_keras(keras_model, shape_dict) - with relay.build_module.build_config(opt_level=2): + with relay.transform.build_config(opt_level=2): graph, lib, params = relay.build(func, target, params=params) m = graph_runtime.create(graph, lib, ctx) for name, x in zip(keras_model.input_names, xs): From f148e30891ab0bf5c6d204c161663f0ef1dbf8c5 Mon Sep 17 00:00:00 2001 From: Zhi Chen Date: Fri, 24 May 2019 16:46:04 +0000 Subject: [PATCH 6/7] with --- include/tvm/relay/transform.h | 19 ++++++++++--------- src/relay/pass/pass_manager.cc | 31 +++++++++++++++---------------- 2 files changed, 25 insertions(+), 25 deletions(-) diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h index 61005a0835a1..fbc3cebe7e21 100644 --- a/include/tvm/relay/transform.h +++ b/include/tvm/relay/transform.h @@ -56,6 +56,7 @@ #ifndef TVM_RELAY_TRANSFORM_H_ #define TVM_RELAY_TRANSFORM_H_ +#include #include #include #include @@ -117,23 +118,23 @@ class PassContext : public NodeRef { tvm::Array required_pass, tvm::Array disabled_pass); - // Move exter/exit to private once #3231 is merged. - // The entry of a pass context scope. - TVM_DLL static void EnterWithScope(const PassContext& pass_ctx); - // The exit of a pass context scope. - TVM_DLL static void ExitWithScope(); // Get the currently used pass context. TVM_DLL static PassContext Current(); const PassContextNode* operator->() const; using ContainerType = PassContextNode; - // class Internal; + class Internal; private: - // Classes to get the Python `with` like syntax. Enabled after #3231 is merged - // friend class Internal; - // friend class With; + // The entry of a pass context scope. + TVM_DLL void EnterWithScope(); + // The exit of a pass context scope. + TVM_DLL void ExitWithScope(); + + // Classes to get the Python `with` like syntax. + friend class Internal; + friend class tvm::With; }; /* diff --git a/src/relay/pass/pass_manager.cc b/src/relay/pass/pass_manager.cc index 41bf119d6e6f..2f0eb9717853 100644 --- a/src/relay/pass/pass_manager.cc +++ b/src/relay/pass/pass_manager.cc @@ -105,17 +105,17 @@ struct RelayPassContextThreadLocalEntry { typedef dmlc::ThreadLocalStore RelayPassContextThreadLocalStore; -void PassContext::EnterWithScope(const PassContext& pass_ctx) { +void PassContext::EnterWithScope() { RelayPassContextThreadLocalEntry* entry = RelayPassContextThreadLocalStore::Get(); - entry->context_stack.push(pass_ctx); + entry->context_stack.push(*this); } void PassContext::ExitWithScope() { RelayPassContextThreadLocalEntry* entry = RelayPassContextThreadLocalStore::Get(); CHECK(!entry->context_stack.empty()); - // CHECK(entry->context_stack.top().same_as(*this)); + CHECK(entry->context_stack.top().same_as(*this)); entry->context_stack.pop(); } @@ -623,26 +623,25 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) p->stream << "]"; }); -// Enable after #3231 is merged. -// class PassContext::Internal { -// public: -// static void EnterScope(PassContext pass_ctx) { -// pass_ctx.EnterWithScope(); -// } -// -// static void ExitScope(PassContext pass_ctx) { -// pass_ctx.ExitWithScope(); -// } -// }; +class PassContext::Internal { + public: + static void EnterScope(PassContext pass_ctx) { + pass_ctx.EnterWithScope(); + } + + static void ExitScope(PassContext pass_ctx) { + pass_ctx.ExitWithScope(); + } +}; TVM_REGISTER_API("relay._transform.GetCurrentPassContext") .set_body_typed(PassContext::Current); TVM_REGISTER_API("relay._transform.EnterPassContext") -.set_body_typed(PassContext::EnterWithScope); +.set_body_typed(PassContext::Internal::EnterScope); TVM_REGISTER_API("relay._transform.ExitPassContext") -.set_body_typed(PassContext::ExitWithScope); +.set_body_typed(PassContext::Internal::ExitScope); } // namespace transform } // namespace relay From ef9f5c6f61baf4d87d895dba86207d73f141e4ae Mon Sep 17 00:00:00 2001 From: Zhi Chen Date: Fri, 24 May 2019 17:49:33 +0000 Subject: [PATCH 7/7] fix more docs --- docs/api/python/relay/transform.rst | 2 -- include/tvm/relay/transform.h | 37 +++++++++++++++++++++++++++-- python/tvm/relay/build_module.py | 2 +- python/tvm/relay/transform.py | 14 +++++------ src/relay/pass/pass_manager.cc | 8 +++++++ 5 files changed, 51 insertions(+), 12 deletions(-) diff --git a/docs/api/python/relay/transform.rst b/docs/api/python/relay/transform.rst index c618628d121e..4eb7f9d8fea7 100644 --- a/docs/api/python/relay/transform.rst +++ b/docs/api/python/relay/transform.rst @@ -26,8 +26,6 @@ tvm.relay.transform .. autofunction:: tvm.relay.transform.function_pass -.. autofunction:: tvm.relay.transform.current_pass_context - .. autoclass:: tvm.relay.transform.Pass :members: diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h index fbc3cebe7e21..5123f3a3dcf3 100644 --- a/include/tvm/relay/transform.h +++ b/include/tvm/relay/transform.h @@ -114,7 +114,19 @@ class PassContext : public NodeRef { PassContext() {} explicit PassContext(tvm::NodePtr n) : NodeRef(n) {} - TVM_DLL PassContext(int opt_level, int fallback_device, + /* + * \brief Constructor of a `PassContext` object. + * + * \param opt_level The optimization level that will be applied. + * \param fallback_device The fallback device used for heterogeneous + * execution. + * \param required_pass The passes that are required for a context to execute + * other passes. + * \param required_pass The passes that will be disabled during the + * optimization under a context. + */ + TVM_DLL PassContext(int opt_level, + int fallback_device, tvm::Array required_pass, tvm::Array disabled_pass); @@ -191,7 +203,8 @@ class PassNode : public RelayNode { virtual PassInfo Info() const = 0; /*! - * \brief Execute the optimization pass using a functor. + * \brief Execute the optimization pass using a functor. This functor + * internally uses a current pass context. * * \param mod The module that an optimization pass runs on. * @@ -201,6 +214,15 @@ class PassNode : public RelayNode { return this->operator()(mod, PassContext::Current()); } + /*! + * \brief Execute the optimization pass using a functor under a given pass context. + * + * \param mod The module that an optimization pass runs on. + * \param pass_ctx The pass context that will be used to help the execution of + * optimizations. + * + * \return The updated module. + */ virtual Module operator()(const Module& mod, const PassContext& pass_ctx) const = 0; @@ -228,11 +250,22 @@ class Sequential : public Pass { public: /*! * \brief The constructor of `Sequential`. + * * \param passes The passes to apply. * \param pass_info The pass metadata. */ TVM_DLL Sequential(tvm::Array passes, PassInfo pass_info); +/*! + * \brief The constructor of `Sequential`. + * + * \param passes The passes to apply. + * \param name The name of a sequential pass. It's defaulted to "sequential". + * This allows users to only provide a list of passes and execute them + * under a given context. + */ + TVM_DLL Sequential(tvm::Array passes, std::string name = "sequential"); + Sequential() = default; explicit Sequential(tvm::NodePtr<::tvm::Node> n) : Pass(n) {} diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py index 4b57f176a012..6cee393d5f91 100644 --- a/python/tvm/relay/build_module.py +++ b/python/tvm/relay/build_module.py @@ -118,7 +118,7 @@ def build(self, func, target=None, target_host=None, params=None): return graph_json, mod, params def _setup_build_config(self, params): - cfg = _transform.current_pass_context() + cfg = _transform.PassContext.current() # Set opt_level. self.set_opt_level(cfg.opt_level) diff --git a/python/tvm/relay/transform.py b/python/tvm/relay/transform.py index 5ddcd9a21e7d..a7887c630c76 100644 --- a/python/tvm/relay/transform.py +++ b/python/tvm/relay/transform.py @@ -63,14 +63,14 @@ class PassContext(RelayNode): opt_level : Optional[int] The optimization level of this pass. - fallback_device : Optional[int] + fallback_device : Optional[Union[int, str, TVMContext]] The fallback device type. It is also used as the default device for operators that are not annotated during heterogeneous execution. - required_pass : Optional[List[str]] + required_pass : Optional[Union[List[str], Set[str], Tuple[str]]] The list of passes that are required by a certain pass. - disabled_pass : Optional[List[str]] + disabled_pass : Optional[Union[List[str], Set[str], Tuple[str]]] The list of passes that are disabled. """ def __init__(self, @@ -107,10 +107,10 @@ def __enter__(self): def __exit__(self, ptype, value, trace): _transform.ExitPassContext(self) - -def current_pass_context(): - """Return the current pass context.""" - return _transform.GetCurrentPassContext() + @staticmethod + def current(): + """Return the current pass context.""" + return _transform.GetCurrentPassContext() def build_config(opt_level=2, diff --git a/src/relay/pass/pass_manager.cc b/src/relay/pass/pass_manager.cc index 2f0eb9717853..4bcc0bb39cc4 100644 --- a/src/relay/pass/pass_manager.cc +++ b/src/relay/pass/pass_manager.cc @@ -406,6 +406,14 @@ Sequential::Sequential(tvm::Array passes, PassInfo pass_info) { node_ = std::move(n); } +Sequential::Sequential(tvm::Array passes, std::string name) { + auto n = make_node(); + n->passes = std::move(passes); + PassInfo pass_info = PassInfoNode::make(2, std::move(name), {}); + n->pass_info = std::move(pass_info); + node_ = std::move(n); +} + const SequentialNode* Sequential::operator->() const { return static_cast(this->node_.get()); }