Skip to content

Commit

Permalink
[PASS] PrintGraphIR, SimplifyBatchNormInference (apache#19)
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen committed May 26, 2018
1 parent 147a13f commit a185e85
Show file tree
Hide file tree
Showing 19 changed files with 650 additions and 24 deletions.
27 changes: 27 additions & 0 deletions nnvm/include/nnvm/node.h
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,33 @@ class Node {
static NodePtr Create();
};

/*!
* \brief Quick utilities make node.
* \param op_name The name of operator
* \param node_name The name of the node
* \param inputs The input entries
* \param attrs The attributes
* \return The created node entry.
*/
inline NodeEntry MakeNode(
const char* op_name,
std::string node_name,
std::vector<NodeEntry> inputs,
std::unordered_map<std::string, std::string> attrs =
std::unordered_map<std::string, std::string>()) {
NodePtr p = Node::Create();
p->attrs.op = nnvm::Op::Get(op_name);
p->attrs.name = std::move(node_name);
if (attrs.size() != 0) {
p->attrs.dict = attrs;
if (p->attrs.op->attr_parser) {
p->attrs.op->attr_parser(&(p->attrs));
}
}
p->inputs = std::move(inputs);
return NodeEntry{p, 0, 0};
}

// implementation of functions.
inline const Op* Node::op() const {
return this->attrs.op;
Expand Down
5 changes: 2 additions & 3 deletions nnvm/python/nnvm/compiler/graph_attr.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,5 @@ def set_layout_inputs(g, layout):
g._set_json_attr("layout_inputs", list_shape, 'list_str')
return g


_move_out_module = tvm.get_global_func("nnvm.graph_attr._move_module")
_move_out_graph = tvm.get_global_func("nnvm.graph_attr._move_graph")
_move_out_module = tvm.get_global_func("nnvm.graph._move_module")
_move_out_graph = tvm.get_global_func("nnvm.graph._move_graph")
24 changes: 24 additions & 0 deletions nnvm/python/nnvm/compiler/graph_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
"""
from __future__ import absolute_import as _abs

import tvm
from . import graph_attr


Expand Down Expand Up @@ -60,3 +61,26 @@ def infer_dtype(graph, **dtype):
output_dtype = [graph_attr.TCODE_TO_DTYPE[dtype[index.entry_id(x)]]
for x in index.output_entries]
return input_dtype, output_dtype


_deep_compare = tvm.get_global_func("nnvm.graph.DeepCompare")

def check_graph_equal(grapha, graphb):
"""Check if two graphs have equal structure.
Parameters
----------
grapha : Graph
The first graph
graphb : Graph
The second graph
Raises
------
ValueError
ValueError is raised with error message when graph not equal
"""
err = _deep_compare(grapha, graphb)
if err:
raise ValueError("Graph compare error: " + err)
4 changes: 4 additions & 0 deletions nnvm/python/nnvm/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,10 @@ def index(self):
self._index = GraphIndex(self)
return self._index

def graphir(self):
"""Get text form of graph ir."""
return self.apply("PrintGraphIR").json_attr("graphir")

def apply(self, passes):
"""Apply passes to the graph
Expand Down
5 changes: 2 additions & 3 deletions nnvm/python/nnvm/top/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

# relu
@reg.register_compute("relu")
def compute_relu(attrs, inputs):
def compute_relu(_, inputs):
"""Compute definition of relu"""
return topi.nn.relu(inputs[0])

Expand Down Expand Up @@ -72,8 +72,7 @@ def schedule_conv2d(attrs, outs, target):
if target == "cuda":
if groups == 1:
return topi.cuda.schedule_conv2d_nchw(outs)
else:
return topi.cuda.schedule_depthwise_conv2d_nchw(outs)
return topi.cuda.schedule_depthwise_conv2d_nchw(outs)
# naive schedule
return tvm.create_schedule([x.op for x in outs])

Expand Down
113 changes: 113 additions & 0 deletions nnvm/src/compiler/graph_deep_compare.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
/*!
* Copyright (c) 2017 by Contributors
* \file graph_deep_compare.cc
* \brief Deep compare two graph structure
*/
#include <nnvm/graph.h>
#include <nnvm/op_attr_types.h>
#include <nnvm/compiler/packed_func_ext.h>
#include <tvm/runtime/packed_func.h>
#include "./node_attr.h"

namespace nnvm {
namespace compiler {

// deep compare the graph structure
// not considering the graph attributes
// return non-empty error message if the graph mismatch.
// the comparator won't match name of intermediate node.
std::string DeepCompare(Graph a, Graph b) {
const IndexedGraph& idxa = a.indexed_graph();
const IndexedGraph& idxb = b.indexed_graph();
std::ostringstream err;
if (idxa.num_nodes() != idxb.num_nodes()) {
err << "Number of nodes mismatch";
return err.str();
}
if (idxa.num_node_entries() != idxb.num_node_entries()) {
err << "Number of node entry mismatch";
return err.str();
}
if (idxa.outputs().size() != idxb.outputs().size()) {
err << "Number of outputs mismatch";
return err.str();
}
for (size_t i = 0; i < idxa.outputs().size(); ++i) {
if (idxa.outputs()[i].node_id != idxb.outputs()[i].node_id ||
idxa.outputs()[i].index != idxb.outputs()[i].index) {
err << "Output entry mismatch";
return err.str();
}
}
if (idxa.input_nodes().size() != idxb.input_nodes().size()) {
err << "Number of inputs mismatch";
return err.str();
}

for (uint32_t nid = 0; nid < idxa.num_nodes(); ++nid) {
const IndexedGraph::Node& anode = idxa[nid];
const IndexedGraph::Node& bnode = idxb[nid];
if (anode.source->op() != bnode.source->op()) {
err << "Node mismatch ";
return err.str();
}
AttrDict adict = GetAttrDict(anode.source->attrs);
AttrDict bdict = GetAttrDict(bnode.source->attrs);

auto fmatch = [&err, &anode](const AttrDict& adict, const AttrDict& bdict) {
for (const auto& kv : adict) {
auto it = bdict.find(kv.first);
if (it != bdict.end()) {
if (it->second != kv.second) {
err << "Node attr mismatch, op=" << anode.source->attrs.name
<< " attr_key=" << kv.first << " " << it->second
<< " v.s. " << kv.second;
return false;
}
} else {
err << "One attr_key=" << kv.first << " is missing in another "
<< "op=" << anode.source->attrs.name;
return false;
}
}
return true;
};
if (!fmatch(adict, bdict)) return err.str();
if (adict.size() != bdict.size()) {
CHECK(!fmatch(bdict, adict));
return err.str();
}
if (anode.inputs.size() != bnode.inputs.size()) {
err << "Node input mismatch, op=" << anode.source->attrs.name;
return err.str();
}
if (anode.control_deps.size() != bnode.control_deps.size()) {
err << "Node control_deps mistach, op=" << anode.source->attrs.name;
return err.str();
}
for (size_t i = 0; i < anode.inputs.size(); ++i) {
const IndexedGraph::NodeEntry& ae = anode.inputs[i];
const IndexedGraph::NodeEntry& be = bnode.inputs[i];
if (ae.node_id != be.node_id ||
ae.index != be.index ||
ae.version != be.version) {
err << "Node input mismatch on, op=" << anode.source->attrs.name;
return err.str();
}
}
for (size_t i = 0; i < anode.control_deps.size(); ++i) {
if (anode.control_deps[i] != bnode.control_deps[i]) {
err << "Node control_dep mismatch on, op=" << anode.source->attrs.name;
return err.str();
}
}
}
return "";
}

TVM_REGISTER_GLOBAL("nnvm.graph.DeepCompare")
.set_body([](tvm::runtime::TVMArgs args, tvm::runtime::TVMRetValue *rv) {
*rv = DeepCompare(args[0], args[1]);
});
} // namespace compiler
} // namespace nnvm
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
#include <tvm/runtime/packed_func.h>
#include <tvm/operation.h>
#include <tvm/lowered_func.h>
#include "../../runtime/graph_executor.h"
#include "../runtime/graph_executor.h"

namespace nnvm {
namespace compiler {
Expand Down
125 changes: 125 additions & 0 deletions nnvm/src/compiler/graph_transform.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
/*!
* Copyright (c) 2017 by Contributors
* \file graph_transform.h
* \brief A mutator class that does local pattern matching and mutates a node.
*/
#ifndef NNVM_COMPILER_GRAPH_TRANSFORM_H_
#define NNVM_COMPILER_GRAPH_TRANSFORM_H_

#include <nnvm/graph.h>
#include <vector>

namespace nnvm {
namespace compiler {

/*!
* \brief Transform the graph to build a new Graph, in post DFS order.
*
* Automatically copies node when some of its children or control_deps changed.
* This function won't be called in Variable.
*
* \param graph The original graph
*
* \param ftransform Function of (int nid, const Node* node, std::vector<NodeEntry>* out) -> bool
*
* If empty vector is returned, it means original entries should be kept.
*
* \tparam FTransform The transformation function.
*/
template<typename FTransform>
Graph GraphTransform(Graph graph, FTransform ftransform) {
const IndexedGraph& idx = graph.indexed_graph();
// new nodes
std::vector<NodeEntry> new_entry_map(idx.num_node_entries());
std::vector<bool> updated(idx.num_node_entries(), false);

// setup inputs and placeholder.
for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
const auto& inode = idx[nid];
if (inode.source->is_variable()) continue;
bool need_copy = false;
for (const IndexedGraph::NodeEntry& e : inode.inputs) {
if (updated[idx.entry_id(e)]) {
need_copy = true; break;
}
}
if (!need_copy) {
for (const uint32_t cid : inode.control_deps) {
const auto& cnode = idx[cid];
for (uint32_t i = 0 ; i < cnode.source->num_outputs(); ++i) {
if (updated[idx.entry_id(cid, i)]) {
need_copy = true;
}
}
if (need_copy) break;
}
}

if (!need_copy) {
std::vector<NodeEntry> ret;
if (ftransform(nid, inode.source, &ret)) {
CHECK_EQ(ret.size(), static_cast<size_t>(inode.source->num_outputs()));
for (uint32_t i = 0 ; i < inode.source->num_outputs(); ++i) {
updated[idx.entry_id(nid, i)] = true;
new_entry_map[idx.entry_id(nid, i)] = ret[i];
}
}
} else {
NodePtr node = Node::Create();
node->attrs = inode.source->attrs;
for (size_t i = 0; i < inode.inputs.size(); ++i) {
const IndexedGraph::NodeEntry& e = inode.inputs[i];
if (updated[idx.entry_id(e)]) {
node->inputs.push_back(new_entry_map[idx.entry_id(e)]);
} else {
node->inputs.push_back(inode.source->inputs[i]);
}
}
for (size_t i = 0; i < inode.control_deps.size(); ++i) {
const uint32_t cid = inode.control_deps[i];
const auto& cnode = idx[cid];
CHECK_NE(cnode.source->num_outputs(), 0U);
NodePtr selected_ptr;
for (uint32_t j = 0 ; j < cnode.source->num_outputs(); ++j) {
NodePtr cptr = updated[idx.entry_id(cid, j)] ?
new_entry_map[idx.entry_id(cid, j)].node : inode.source->control_deps[i];
if (selected_ptr == nullptr) {
selected_ptr = std::move(cptr);
} else {
CHECK(selected_ptr.get() == cptr.get())
<< "Control dependency node changed to more than one node";
}
}
node->control_deps.push_back(selected_ptr);
}
std::vector<NodeEntry> ret;
if (ftransform(nid, node.get(), &ret)) {
CHECK_EQ(ret.size(), static_cast<size_t>(inode.source->num_outputs()));
for (uint32_t i = 0 ; i < inode.source->num_outputs(); ++i) {
updated[idx.entry_id(nid, i)] = true;
new_entry_map[idx.entry_id(nid, i)] = ret[i];
}
} else {
for (uint32_t i = 0 ; i < inode.source->num_outputs(); ++i) {
updated[idx.entry_id(nid, i)] = true;
new_entry_map[idx.entry_id(nid, i)] = NodeEntry{node, i, 0};
}
}
}
}
Graph ret;
for (size_t i = 0; i < idx.outputs().size(); ++i) {
const IndexedGraph::NodeEntry& e = idx.outputs()[i];
if (updated[idx.entry_id(e)]) {
ret.outputs.push_back(new_entry_map[idx.entry_id(e)]);
} else {
ret.outputs.push_back(graph.outputs[i]);
}
}
return ret;
}

} // namespace compiler
} // namespace nnvm

#endif // NNVM_COMPILER_GRAPH_TRANSFORM_H_
File renamed without changes.
34 changes: 34 additions & 0 deletions nnvm/src/compiler/node_attr.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
/*!
* Copyright (c) 2017 by Contributors
* \file node_attr.h
* \brief utility to access node attributes
*/
#ifndef NNVM_COMPILER_NODE_ATTR_H_
#define NNVM_COMPILER_NODE_ATTR_H_

#include <nnvm/op.h>
#include <nnvm/compiler/op_attr_types.h>
#include <unordered_map>
#include <string>

namespace nnvm {
namespace compiler {

using AttrDict = std::unordered_map<std::string, std::string>;
/*!
* \brief Get canonicalized attr dict from node
* \param attrs The node attrs
* \return The attribute dict
*/
inline AttrDict GetAttrDict(const NodeAttrs& attrs) {
static auto& fgetdict = nnvm::Op::GetAttr<FGetAttrDict>("FGetAttrDict");
if (fgetdict.count(attrs.op)) {
return fgetdict[attrs.op](attrs);
} else {
return attrs.dict;
}
}

} // namespace compiler
} // namespace nnvm
#endif // NNVM_COMPILER_NODE_ATTR_H_
Loading

0 comments on commit a185e85

Please sign in to comment.