Skip to content

Commit

Permalink
more fix
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiics committed May 24, 2019
1 parent 1e24c9c commit 02e8dfb
Show file tree
Hide file tree
Showing 7 changed files with 120 additions and 130 deletions.
8 changes: 0 additions & 8 deletions docs/api/python/relay/build_module.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,17 +22,9 @@ tvm.relay.build_module

.. autofunction:: tvm.relay.build_module.build

.. autofunction:: tvm.relay.build_module.build_config

.. autofunction:: tvm.relay.build_module.optimize

.. autofunction:: tvm.relay.build_module.create_executor

.. autoclass:: tvm.relay.build_module.BuildConfig
:members:

.. autofunction:: tvm.relay.build_module.build_config
:members:

.. autoclass:: tvm.relay.build_module.GraphExecutor
:members:
47 changes: 47 additions & 0 deletions docs/api/python/relay/transform.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
.. Licensed to the Apache Software Foundation (ASF) under one
or more contributor license agreements. See the NOTICE file
distributed with this work for additional information
regarding copyright ownership. The ASF licenses this file
to you under the Apache License, Version 2.0 (the
"License"); you may not use this file except in compliance
with the License. You may obtain a copy of the License at
.. http://www.apache.org/licenses/LICENSE-2.0
.. Unless required by applicable law or agreed to in writing,
software distributed under the License is distributed on an
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
KIND, either express or implied. See the License for the
specific language governing permissions and limitations
under the License.
tvm.relay.transform
----------------------

.. automodule:: tvm.relay.transform

.. autofunction:: tvm.relay.transform.build_config

.. autofunction:: tvm.relay.transform.module_pass

.. autofunction:: tvm.relay.transform.function_pass

.. autofunction:: tvm.relay.transform.current_pass_context

.. autoclass:: tvm.relay.transform.Pass
:members:

.. autoclass:: tvm.relay.transform.PassInfo
:members:

.. autoclass:: tvm.relay.transform.PassContext
:members:

.. autoclass:: tvm.relay.transform.ModulePass
:members:

.. autoclass:: tvm.relay.transform.FunctionPass
:members:

.. autoclass:: tvm.relay.transform.Sequential
:members:
41 changes: 3 additions & 38 deletions include/tvm/relay/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,43 +68,6 @@ namespace tvm {
namespace relay {
namespace transform {

/*!
* \brief A data structure to map the names of specific optimizations to
* numeric optimization levels
*/
class OptPassLevel {
public:
/*!
* \brief Get level for an optimization pass
*
* \param key pass name
* \return int level
*/
int operator[](const std::string& key) const {
const auto data = CreateMap();
auto it = data.find(key);
if (it == data.end()) {
return -1;
}
return it->second;
}

private:
static const std::unordered_map<std::string, int> CreateMap() {
const std::unordered_map<std::string, int> m = {
{"SimplifyInference", 0},
{"OpFusion", 1},
{"FoldConstant", 2},
{"CombineParallelConv2D", 3},
{"FoldScaleAxis", 3},
{"AlterOpLayout", 3},
{"CanonicalizeOps", 3},
{"EliminateCommonSubexpr", 3}
};
return m;
}
};

/*
* \brief The context of pass.
*/
Expand Down Expand Up @@ -233,7 +196,9 @@ class PassNode : public RelayNode {
*
* \return The updated module.
*/
virtual Module operator()(const Module& mod) const = 0;
Module operator()(const Module& mod) const {
return this->operator()(mod, PassContext::Current());
}

virtual Module operator()(const Module& mod,
const PassContext& pass_ctx) const = 0;
Expand Down
30 changes: 22 additions & 8 deletions python/tvm/relay/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,25 +118,39 @@ def build_config(opt_level=2,
required_pass=None,
disabled_pass=None):
"""Configure the build behavior by setting config variables.
Parameters
----------
opt_level: int, default=2
Optimization level. See include/tvm/relay/transform.h for level of each
pass.
fallback_device : int or tvm.TVMContext
opt_level: int, optional
Optimization level. The optimization pass name and level are as the
following:
.. code-block:: python
OPT_PASS_LEVEL = {
"SimplifyInference": 0,
"OpFusion": 1,
"FoldConstant": 2,
"CombineParallelConv2D": 3,
"FoldScaleAxis": 3,
"AlterOpLayout": 3,
"CanonicalizeOps": 3,
"EliminateCommonSubexpr": 3,
}
fallback_device : int, str, or tvm.TVMContext, optional
The fallback device. It is also used as the default device for
operators without specified device during heterogeneous execution.
required_pass: set of str
required_pass: set of str, optional
Optimization passes that are required regardless of optimization level.
disabled_pass: set of str
disabled_pass: set of str, optional
Optimization passes to be disabled during optimization.
Returns
-------
config: PassContext
pass_context: PassContext
The pass context for optimizations.
"""
return PassContext(opt_level, fallback_device, required_pass,
Expand Down
118 changes: 45 additions & 73 deletions src/relay/pass/pass_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,45 @@ namespace transform {

using tvm::IRPrinter;

/*!
* \brief A data structure to map the names of specific optimizations to
* numeric optimization levels
*/
class OptPassLevel {
public:
/*!
* \brief Get level for an optimization pass
*
* \param key pass name
* \return int level
*/
int operator[](const std::string& key) const {
const auto data = CreateMap();
auto it = data.find(key);
if (it == data.end()) {
return -1;
}
return it->second;
}

private:
static const std::unordered_map<std::string, int> CreateMap() {
const std::unordered_map<std::string, int> m = {
{"SimplifyInference", 0},
{"OpFusion", 1},
{"FoldConstant", 2},
{"CombineParallelConv2D", 3},
{"FoldScaleAxis", 3},
{"AlterOpLayout", 3},
{"CanonicalizeOps", 3},
{"EliminateCommonSubexpr", 3}
};
return m;
}
};



PassContext::PassContext(int opt_level, int fallback_device,
tvm::Array<tvm::Expr> required_pass,
tvm::Array<tvm::Expr> disabled_pass) {
Expand Down Expand Up @@ -118,15 +157,6 @@ class ModulePassNode : public PassNode {
v->Visit("pass_info", &pass_info);
}

/*!
* \brief Run a module pass on a certain module.
*
* \param mod The module that an optimization pass runs on.
*
* \return Return the updated module.
*/
Module operator()(const Module& mod) const final;

/*!
* \brief Run a module pass on given pass context.
*
Expand Down Expand Up @@ -181,15 +211,6 @@ class FunctionPassNode : public PassNode {
v->Visit("pass_info", &pass_info);
}

/*!
* \brief Run a function pass on a certain module.
*
* \param mod The module that an optimization pass runs on.
*
* \return Return the updated module.
*/
Module operator()(const Module& mod) const final;

/*!
* \brief Run a function pass on given pass context.
*
Expand Down Expand Up @@ -293,23 +314,15 @@ class SequentialNode : public PassNode {

std::unordered_set<std::string> RequiredPasses(
const Array<tvm::Expr>& disabled) const;

/*!
* \brief Perform optimizations on a series of passes. The aforementioned
* typical pass manager jobs could be done by it. This function could
* be overloaded to focus on different metrics, i.e. performance,
* memory footprint, etc.
*
* \param mod The module that an optimization pass runs on.
*
* \return Return the updated module.
*/
Module operator()(const Module& mod) const final;

/*!
* \brief Run a series of passes on given pass context.
*
* \param mod The module that these passes are applied on.
* \param mod The context that these passes execute on.
* \param pass_ctx The context that these passes execute on.
*
* \return Return the updated module.
*/
Expand Down Expand Up @@ -338,20 +351,7 @@ ModulePass ModulePassNode::make(
}

// Module -> Module optimizations.
// TODO(zhiics) 1. Check and handle the required passes.
// 2. Probably use CoW for all places that use module instead of
// returning the updated one.
Module ModulePassNode::operator()(const Module& mod) const {
PassInfo pass_info = Info();
LOG(INFO) << "Executing module pass : " << pass_info.operator->()->name
<< " with opt level: " << pass_info.operator->()->opt_level << "\n";
CHECK(mod.defined());
PassContext ctx = PassContext::Current();
auto updated_mod = pass_func(mod, ctx);
CHECK(updated_mod.defined());
return updated_mod;
}

// TODO(zhiics) Check and handle the required passes.
Module ModulePassNode::operator()(const Module& mod,
const PassContext& pass_ctx) const {
PassInfo pass_info = Info();
Expand All @@ -375,24 +375,6 @@ FunctionPass FunctionPassNode::make(

// Perform Module -> Module optimizations at the Function level.
// TODO(zhiics) Check and handle the required passes.
Module FunctionPassNode::operator()(const Module& mod) const {
PassInfo pass_info = Info();
LOG(INFO) << "Executing function pass : " << pass_info.operator->()->name
<< " with opt level: " << pass_info.operator->()->opt_level << "\n";
CHECK(mod.defined());
Module new_mod = ModuleNode::make({}, mod->type_definitions);
PassContext ctx = PassContext::Current();

// Execute the pass function and return a new module.
for (const auto& it : mod->functions) {
auto updated_func =
SkipFunction(it.second) ? it.second : pass_func(it.second, ctx);
new_mod->Add(it.first, updated_func);
}

return new_mod;
}

Module FunctionPassNode::operator()(const Module& mod,
const PassContext& pass_ctx) const {
PassInfo pass_info = Info();
Expand Down Expand Up @@ -430,19 +412,6 @@ const SequentialNode* Sequential::operator->() const {
return static_cast<const SequentialNode*>(this->node_.get());
}

// TODO(jroesch, zhiics): we currenlty only sequentially execute each pass in
// a Sequential without the consideration of their orders. The phase
// ordering problem needed to be handled in the future.
Module SequentialNode::operator()(const Module& module) const {
Module mod = module;
for (const Pass& pass : passes) {
CHECK(pass.defined()) << "Found undefined pass for optimization.";
const auto* pn = pass.operator->();
mod = (*pn)(mod);
}
return mod;
}

void SequentialNode::ResolveDependency(const Module& mod) {
// TODO(zhiics) Implement it.
// 1. Consider the required passes for each pass.
Expand Down Expand Up @@ -491,6 +460,9 @@ bool SequentialNode::pass_enabled(const std::string& pass_name) const {
return ctx_node->opt_level >= opt_pass_level[pass_name];
}

// TODO(zhiics): we currenlty only sequentially execute each pass in
// a Sequential without the consideration of their orders. The phase
// ordering problem needed to be handled in the future.
Module SequentialNode::operator()(const Module& module,
const PassContext& pass_ctx) const {
const auto* ctx_node = pass_ctx.operator->();
Expand Down
4 changes: 2 additions & 2 deletions tests/python/frontend/coreml/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@

def get_tvm_output(func, x, params, target, ctx,
out_shape=(1, 1000), input_name='image', dtype='float32'):
with relay.build_module.build_config(opt_level=3):
with relay.transform.build_config(opt_level=3):
graph, lib, params = relay.build(func, target, params=params)
m = graph_runtime.create(graph, lib, ctx)
# set inputs
Expand Down Expand Up @@ -72,7 +72,7 @@ def run_tvm_graph(coreml_model, target, ctx, input_data, input_name, output_shap
dtype_dict = {input_name: input_data.dtype}

func, params = relay.frontend.from_coreml(coreml_model, shape_dict)
with relay.build_module.build_config(opt_level=3):
with relay.transform.build_config(opt_level=3):
graph, lib, params = relay.build(func, target, params=params)

from tvm.contrib import graph_runtime
Expand Down
2 changes: 1 addition & 1 deletion tests/python/frontend/keras/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def get_keras_output(xs, dtype='float32'):
def get_tvm_output(xs, target, ctx, dtype='float32'):
shape_dict = {name: x.shape for (name, x) in zip(keras_model.input_names, xs)}
func, params = relay.frontend.from_keras(keras_model, shape_dict)
with relay.build_module.build_config(opt_level=2):
with relay.transform.build_config(opt_level=2):
graph, lib, params = relay.build(func, target, params=params)
m = graph_runtime.create(graph, lib, ctx)
for name, x in zip(keras_model.input_names, xs):
Expand Down

0 comments on commit 02e8dfb

Please sign in to comment.