Skip to content

Commit

Permalink
[Relay][Transform] merge PassContext and BuildConfig (apache#3234)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiics authored and Wei Chen committed Jun 26, 2019
1 parent 50ba25a commit 5e25c76
Show file tree
Hide file tree
Showing 11 changed files with 501 additions and 264 deletions.
8 changes: 0 additions & 8 deletions docs/api/python/relay/build_module.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,17 +22,9 @@ tvm.relay.build_module

.. autofunction:: tvm.relay.build_module.build

.. autofunction:: tvm.relay.build_module.build_config

.. autofunction:: tvm.relay.build_module.optimize

.. autofunction:: tvm.relay.build_module.create_executor

.. autoclass:: tvm.relay.build_module.BuildConfig
:members:

.. autofunction:: tvm.relay.build_module.build_config
:members:

.. autoclass:: tvm.relay.build_module.GraphExecutor
:members:
45 changes: 45 additions & 0 deletions docs/api/python/relay/transform.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
.. Licensed to the Apache Software Foundation (ASF) under one
or more contributor license agreements. See the NOTICE file
distributed with this work for additional information
regarding copyright ownership. The ASF licenses this file
to you under the Apache License, Version 2.0 (the
"License"); you may not use this file except in compliance
with the License. You may obtain a copy of the License at
.. http://www.apache.org/licenses/LICENSE-2.0
.. Unless required by applicable law or agreed to in writing,
software distributed under the License is distributed on an
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
KIND, either express or implied. See the License for the
specific language governing permissions and limitations
under the License.
tvm.relay.transform
----------------------

.. automodule:: tvm.relay.transform

.. autofunction:: tvm.relay.transform.build_config

.. autofunction:: tvm.relay.transform.module_pass

.. autofunction:: tvm.relay.transform.function_pass

.. autoclass:: tvm.relay.transform.Pass
:members:

.. autoclass:: tvm.relay.transform.PassInfo
:members:

.. autoclass:: tvm.relay.transform.PassContext
:members:

.. autoclass:: tvm.relay.transform.ModulePass
:members:

.. autoclass:: tvm.relay.transform.FunctionPass
:members:

.. autoclass:: tvm.relay.transform.Sequential
:members:
92 changes: 81 additions & 11 deletions include/tvm/relay/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,13 @@
#ifndef TVM_RELAY_TRANSFORM_H_
#define TVM_RELAY_TRANSFORM_H_

#include <tvm/base.h>
#include <tvm/packed_func_ext.h>
#include <tvm/relay/error.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/module.h>
#include <string>
#include <unordered_map>
#include <vector>

namespace tvm {
Expand All @@ -83,18 +85,69 @@ class PassContextNode : public RelayNode {
*/
ErrorReporter err_reporter;

/*! \brief The default optimization level. */
int opt_level{2};

/*! \brief CPU is the default fallback device for heterogeneous execution. */
int fallback_device{static_cast<int>(kDLCPU)};

/*! \brief The list of required passes. */
tvm::Array<tvm::Expr> required_pass;
/*! \brief The list of disabled passes. */
tvm::Array<tvm::Expr> disabled_pass;

PassContextNode() = default;

void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("opt_level", &opt_level);
v->Visit("fallback_device", &fallback_device);
v->Visit("required_pass", &required_pass);
v->Visit("disabled_pass", &disabled_pass);
}

TVM_DLL static PassContext make();

static constexpr const char* _type_key = "relay.PassContext";
TVM_DECLARE_NODE_TYPE_INFO(PassContextNode, RelayNode);
};

TVM_DEFINE_NODE_REF(PassContext, PassContextNode)
class PassContext : public NodeRef {
public:
PassContext() {}
explicit PassContext(tvm::NodePtr<Node> n) : NodeRef(n) {}

/*
* \brief Constructor of a `PassContext` object.
*
* \param opt_level The optimization level that will be applied.
* \param fallback_device The fallback device used for heterogeneous
* execution.
* \param required_pass The passes that are required for a context to execute
* other passes.
* \param required_pass The passes that will be disabled during the
* optimization under a context.
*/
TVM_DLL PassContext(int opt_level,
int fallback_device,
tvm::Array<tvm::Expr> required_pass,
tvm::Array<tvm::Expr> disabled_pass);

// Get the currently used pass context.
TVM_DLL static PassContext Current();

const PassContextNode* operator->() const;

using ContainerType = PassContextNode;
class Internal;

private:
// The entry of a pass context scope.
TVM_DLL void EnterWithScope();
// The exit of a pass context scope.
TVM_DLL void ExitWithScope();

// Classes to get the Python `with` like syntax.
friend class Internal;
friend class tvm::With<PassContext>;
};

/*
* \brief The meta data of a pass.
Expand Down Expand Up @@ -150,20 +203,28 @@ class PassNode : public RelayNode {
virtual PassInfo Info() const = 0;

/*!
* \brief Set the context information for a pass.
* \brief Execute the optimization pass using a functor. This functor
* internally uses a current pass context.
*
* \param mod The module that an optimization pass runs on.
*
* \param pass_ctx The context information for a certain pass.
* \return The updated module.
*/
virtual void SetContext(const PassContext& pass_ctx) = 0;
Module operator()(const Module& mod) const {
return this->operator()(mod, PassContext::Current());
}

/*!
* \brief Execute the optimization pass using a functor.
* \brief Execute the optimization pass using a functor under a given pass context.
*
* \param mod The module that an optimization pass runs on.
* \param pass_ctx The pass context that will be used to help the execution of
* optimizations.
*
* \return The updated module.
*/
virtual Module operator()(const Module& mod) const = 0;
virtual Module operator()(const Module& mod,
const PassContext& pass_ctx) const = 0;

void VisitAttrs(tvm::AttrVisitor* v) override {}

Expand All @@ -189,13 +250,22 @@ class Sequential : public Pass {
public:
/*!
* \brief The constructor of `Sequential`.
*
* \param passes The passes to apply.
* \param pass_info The pass metadata.
* \param disabled The passes that will not be applied.
*/
TVM_DLL Sequential(tvm::Array<Pass> passes,
PassInfo pass_info,
tvm::Array<tvm::Expr> disabled);
PassInfo pass_info);
/*!
* \brief The constructor of `Sequential`.
*
* \param passes The passes to apply.
* \param name The name of a sequential pass. It's defaulted to "sequential".
* This allows users to only provide a list of passes and execute them
* under a given context.
*/
TVM_DLL Sequential(tvm::Array<Pass> passes, std::string name = "sequential");

Sequential() = default;
explicit Sequential(tvm::NodePtr<::tvm::Node> n) : Pass(n) {}

Expand Down
3 changes: 2 additions & 1 deletion python/tvm/relay/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@
from . import adt
from . import ir_pass
from . import transform
from .build_module import build, build_config, create_executor
from .build_module import build, create_executor
from .transform import build_config
from . import prelude
from . import parser
from . import debug
Expand Down
98 changes: 13 additions & 85 deletions python/tvm/relay/build_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,81 +28,10 @@
from . import ir_pass
from . import ty as _ty
from . import expr as _expr
from . import transform as _transform
from .backend import interpreter as _interpreter
from .backend.vm import VMExecutor

class BuildConfig(object):
"""Configuration scope to set a build config option.
Parameters
----------
kwargs
Keyword arguments of configurations to set.
"""
current = None
defaults = {
"opt_level": 2,
"add_pass": None,
"disable_pass": None,
"fallback_device": None,
}

def __init__(self, **kwargs):
self._old_scope = None
for k, _ in kwargs.items():
if k not in BuildConfig.defaults:
raise ValueError("invalid argument %s, candidates are %s" %
(k, BuildConfig.defaults.keys()))
self._attr = kwargs

def __getattr__(self, name):
if name not in self._attr:
return BuildConfig.defaults[name]
return self._attr[name]

def __enter__(self):
# pylint: disable=protected-access
self._old_scope = BuildConfig.current
attr = BuildConfig.current._attr.copy()
attr.update(self._attr)
self._attr = attr
BuildConfig.current = self
return self

def __exit__(self, ptype, value, trace):
assert self._old_scope
BuildConfig.current = self._old_scope


BuildConfig.current = BuildConfig()


def build_config(**kwargs):
"""Configure the build behavior by setting config variables.
Parameters
----------
opt_level: int, default=2
Optimization level. See OPT_PASS_LEVEL for level of each pass.
add_pass: set of str
Optimization pass to be added regardless of optimization level.
disable_pass: set of str
Optimization pass to be disabled during optimization.
fallback_device : str or tvm.TVMContext
The fallback device. It is also used as the default device for
operators without specified device during heterogeneous execution.
Returns
-------
config: BuildConfig
The build configuration
"""
return BuildConfig(**kwargs)


def _update_target(target):
target = target if target else _target.current_target()
if target is None:
Expand Down Expand Up @@ -189,7 +118,7 @@ def build(self, func, target=None, target_host=None, params=None):
return graph_json, mod, params

def _setup_build_config(self, params):
cfg = BuildConfig.current
cfg = _transform.PassContext.current()

# Set opt_level.
self.set_opt_level(cfg.opt_level)
Expand All @@ -199,24 +128,24 @@ def _setup_build_config(self, params):
self.set_fallback_device(cfg.fallback_device)

# Add required passes.
if cfg.add_pass:
if cfg.required_pass:
passes = set()
if isinstance(cfg.add_pass, (list, tuple, set)):
passes = set(cfg.add_pass)
if isinstance(cfg.required_pass, (list, tuple, set)):
passes = set(cfg.required_pass)
else:
raise TypeError("add_pass must be list, tuple, or set, but " +
"got {}".format(type(cfg.add_pass)))
"got {}".format(type(cfg.required_pass)))
for pass_name in passes:
self.add_pass(pass_name)

# Add disabled passes.
if cfg.disable_pass:
if cfg.disabled_pass:
passes = set()
if isinstance(cfg.disable_pass, (list, tuple, set)):
passes = set(cfg.disable_pass)
if isinstance(cfg.disabled_pass, (list, tuple, set)):
passes = set(cfg.disabled_pass)
else:
raise TypeError("disable_pass must be list, tuple, or set, " +
"but got {}".format(type(cfg.disable_pass)))
"but got {}".format(type(cfg.disabled_pass)))
for pass_name in passes:
self.disable_pass(pass_name)

Expand Down Expand Up @@ -287,12 +216,11 @@ def set_fallback_device(self, fallback_device):
fallback_device : str or tvm.TVMContext
The fallback device used for heterogeneous execution.
"""
if isinstance(fallback_device, str):
if isinstance(fallback_device, (int, str)):
fallback_device = _nd.context(fallback_device)
if not isinstance(fallback_device, TVMContext):
raise TypeError("fallback_device is expected to be str " +
"TVMContext, or dict of device name to target, " +
"but received: {}".format(type(fallback_device)))
raise TypeError("fallback_device is expected to be str, int, or " +
"TVMContext but received: {}".format(type(fallback_device)))

self._set_fallback_device(fallback_device.device_type)

Expand Down
Loading

0 comments on commit 5e25c76

Please sign in to comment.