diff --git a/nnvm/include/nnvm/node.h b/nnvm/include/nnvm/node.h
index 0e46e23f25e9..603ff3d05fb2 100644
--- a/nnvm/include/nnvm/node.h
+++ b/nnvm/include/nnvm/node.h
@@ -128,6 +128,33 @@ class Node {
   static NodePtr Create();
 };
 
+/*!
+ * \brief Quick utility to make a node.
+ * \param op_name The name of the operator
+ * \param node_name The name of the node
+ * \param inputs The input entries
+ * \param attrs The attributes
+ * \return The created node entry.
+ */
+inline NodeEntry MakeNode(
+    const char* op_name,
+    std::string node_name,
+    std::vector<NodeEntry> inputs,
+    std::unordered_map<std::string, std::string> attrs =
+    std::unordered_map<std::string, std::string>()) {
+  NodePtr p = Node::Create();
+  p->attrs.op = nnvm::Op::Get(op_name);
+  p->attrs.name = std::move(node_name);
+  if (attrs.size() != 0) {
+    p->attrs.dict = attrs;
+    if (p->attrs.op->attr_parser) {
+      p->attrs.op->attr_parser(&(p->attrs));
+    }
+  }
+  p->inputs = std::move(inputs);
+  return NodeEntry{p, 0, 0};
+}
+
 // implementation of functions.
 inline const Op* Node::op() const {
   return this->attrs.op;
diff --git a/nnvm/python/nnvm/compiler/graph_attr.py b/nnvm/python/nnvm/compiler/graph_attr.py
index 0b5d19898221..3787eca68707 100644
--- a/nnvm/python/nnvm/compiler/graph_attr.py
+++ b/nnvm/python/nnvm/compiler/graph_attr.py
@@ -83,6 +83,5 @@ def set_layout_inputs(g, layout):
     g._set_json_attr("layout_inputs", list_shape, 'list_str')
     return g
 
-
-_move_out_module = tvm.get_global_func("nnvm.graph_attr._move_module")
-_move_out_graph = tvm.get_global_func("nnvm.graph_attr._move_graph")
+_move_out_module = tvm.get_global_func("nnvm.graph._move_module")
+_move_out_graph = tvm.get_global_func("nnvm.graph._move_graph")
diff --git a/nnvm/python/nnvm/compiler/graph_pass.py b/nnvm/python/nnvm/compiler/graph_pass.py
index 7d25a055777d..3e98615d8ff2 100644
--- a/nnvm/python/nnvm/compiler/graph_pass.py
+++ b/nnvm/python/nnvm/compiler/graph_pass.py
@@ -7,6 +7,7 @@
 """
 from __future__ import absolute_import as _abs
 
+import tvm
 
 from . import graph_attr
 
@@ -60,3 +61,26 @@ def infer_dtype(graph, **dtype):
     output_dtype = [graph_attr.TCODE_TO_DTYPE[dtype[index.entry_id(x)]]
                     for x in index.output_entries]
     return input_dtype, output_dtype
+
+
+_deep_compare = tvm.get_global_func("nnvm.graph.DeepCompare")
+
+def check_graph_equal(grapha, graphb):
+    """Check if two graphs have equal structure.
+
+    Parameters
+    ----------
+    grapha : Graph
+        The first graph
+
+    graphb : Graph
+        The second graph
+
+    Raises
+    ------
+    ValueError
+        ValueError is raised with an error message when the graphs are not equal.
+    """
+    err = _deep_compare(grapha, graphb)
+    if err:
+        raise ValueError("Graph compare error: " + err)
diff --git a/nnvm/python/nnvm/graph.py b/nnvm/python/nnvm/graph.py
index 309603d06264..b8cb655c9b65 100644
--- a/nnvm/python/nnvm/graph.py
+++ b/nnvm/python/nnvm/graph.py
@@ -177,6 +177,10 @@ def index(self):
             self._index = GraphIndex(self)
         return self._index
 
+    def graphir(self):
+        """Get the graph IR in text form."""
+        return self.apply("PrintGraphIR").json_attr("graphir")
+
     def apply(self, passes):
         """Apply passes to the graph
diff --git a/nnvm/python/nnvm/top/nn.py b/nnvm/python/nnvm/top/nn.py
index 69ca49e6603e..b4c0c44f0edc 100644
--- a/nnvm/python/nnvm/top/nn.py
+++ b/nnvm/python/nnvm/top/nn.py
@@ -10,7 +10,7 @@
 
 # relu
 @reg.register_compute("relu")
-def compute_relu(attrs, inputs):
+def compute_relu(_, inputs):
     """Compute definition of relu"""
     return topi.nn.relu(inputs[0])
 
@@ -72,8 +72,7 @@ def schedule_conv2d(attrs, outs, target):
     if target == "cuda":
         if groups == 1:
             return topi.cuda.schedule_conv2d_nchw(outs)
-        else:
-            return topi.cuda.schedule_depthwise_conv2d_nchw(outs)
+        return topi.cuda.schedule_depthwise_conv2d_nchw(outs)
     # naive schedule
     return tvm.create_schedule([x.op for x in outs])
diff --git a/nnvm/src/compiler/graph_deep_compare.cc b/nnvm/src/compiler/graph_deep_compare.cc
new file mode 100644
index 000000000000..dd64f0e3b062
--- /dev/null
+++ b/nnvm/src/compiler/graph_deep_compare.cc
@@ -0,0 +1,113 @@
+/*!
+ * Copyright (c) 2017 by Contributors
+ * \file graph_deep_compare.cc
+ * \brief Deep compare two graph structures.
+ */
+#include <nnvm/graph.h>
+#include <nnvm/op_attr_types.h>
+#include <nnvm/compiler/packed_func_ext.h>
+#include <tvm/runtime/registry.h>
+#include "./node_attr.h"
+
+namespace nnvm {
+namespace compiler {
+
+// Deep compare the graph structure, not considering the graph attributes.
+// Returns a non-empty error message if the graphs mismatch.
+// The comparator does not compare names of intermediate nodes.
+std::string DeepCompare(Graph a, Graph b) {
+  const IndexedGraph& idxa = a.indexed_graph();
+  const IndexedGraph& idxb = b.indexed_graph();
+  std::ostringstream err;
+  if (idxa.num_nodes() != idxb.num_nodes()) {
+    err << "Number of nodes mismatch";
+    return err.str();
+  }
+  if (idxa.num_node_entries() != idxb.num_node_entries()) {
+    err << "Number of node entries mismatch";
+    return err.str();
+  }
+  if (idxa.outputs().size() != idxb.outputs().size()) {
+    err << "Number of outputs mismatch";
+    return err.str();
+  }
+  for (size_t i = 0; i < idxa.outputs().size(); ++i) {
+    if (idxa.outputs()[i].node_id != idxb.outputs()[i].node_id ||
+        idxa.outputs()[i].index != idxb.outputs()[i].index) {
+      err << "Output entry mismatch";
+      return err.str();
+    }
+  }
+  if (idxa.input_nodes().size() != idxb.input_nodes().size()) {
+    err << "Number of inputs mismatch";
+    return err.str();
+  }
+
+  for (uint32_t nid = 0; nid < idxa.num_nodes(); ++nid) {
+    const IndexedGraph::Node& anode = idxa[nid];
+    const IndexedGraph::Node& bnode = idxb[nid];
+    if (anode.source->op() != bnode.source->op()) {
+      err << "Node mismatch ";
+      return err.str();
+    }
+    AttrDict adict = GetAttrDict(anode.source->attrs);
+    AttrDict bdict = GetAttrDict(bnode.source->attrs);
+
+    auto fmatch = [&err, &anode](const AttrDict& adict, const AttrDict& bdict) {
+      for (const auto& kv : adict) {
+        auto it = bdict.find(kv.first);
+        if (it != bdict.end()) {
+          if (it->second != kv.second) {
+            err << "Node attr mismatch, op=" << anode.source->attrs.name
+                << " attr_key=" << kv.first << " " << it->second
+                << " v.s. " << kv.second;
+            return false;
+          }
+        } else {
+          err << "One attr_key=" << kv.first << " is missing in another "
+              << "op=" << anode.source->attrs.name;
+          return false;
+        }
+      }
+      return true;
+    };
+    if (!fmatch(adict, bdict)) return err.str();
+    if (adict.size() != bdict.size()) {
+      CHECK(!fmatch(bdict, adict));
+      return err.str();
+    }
+    if (anode.inputs.size() != bnode.inputs.size()) {
+      err << "Node input mismatch, op=" << anode.source->attrs.name;
+      return err.str();
+    }
+    if (anode.control_deps.size() != bnode.control_deps.size()) {
+      err << "Node control_deps mismatch, op=" << anode.source->attrs.name;
+      return err.str();
+    }
+    for (size_t i = 0; i < anode.inputs.size(); ++i) {
+      const IndexedGraph::NodeEntry& ae = anode.inputs[i];
+      const IndexedGraph::NodeEntry& be = bnode.inputs[i];
+      if (ae.node_id != be.node_id ||
+          ae.index != be.index ||
+          ae.version != be.version) {
+        err << "Node input mismatch on op=" << anode.source->attrs.name;
+        return err.str();
+      }
+    }
+    for (size_t i = 0; i < anode.control_deps.size(); ++i) {
+      if (anode.control_deps[i] != bnode.control_deps[i]) {
+        err << "Node control_dep mismatch on op=" << anode.source->attrs.name;
+        return err.str();
+      }
+    }
+  }
+  return "";
+}
+
+TVM_REGISTER_GLOBAL("nnvm.graph.DeepCompare")
+.set_body([](tvm::runtime::TVMArgs args, tvm::runtime::TVMRetValue *rv) {
+    *rv = DeepCompare(args[0], args[1]);
+  });
+
+}  // namespace compiler
+}  // namespace nnvm
diff --git a/nnvm/src/compiler/pass/graph_fuse.cc b/nnvm/src/compiler/graph_fuse.cc
similarity index 99%
rename from nnvm/src/compiler/pass/graph_fuse.cc
rename to nnvm/src/compiler/graph_fuse.cc
index 0880a90b9dad..9496e110ba64 100644
--- a/nnvm/src/compiler/pass/graph_fuse.cc
+++ b/nnvm/src/compiler/graph_fuse.cc
@@ -13,7 +13,7 @@
 #include
 #include
 #include
-#include "../../runtime/graph_executor.h"
+#include "../runtime/graph_executor.h"
 
 namespace nnvm {
 namespace compiler {
diff --git a/nnvm/src/compiler/graph_transform.h b/nnvm/src/compiler/graph_transform.h
new file mode 100644
index 000000000000..2099809115aa
--- /dev/null
+++ b/nnvm/src/compiler/graph_transform.h
@@ -0,0 +1,125 @@
+/*!
+ * Copyright (c) 2017 by Contributors
+ * \file graph_transform.h
+ * \brief A graph transform utility that does local pattern matching and mutates nodes.
+*/
+#ifndef NNVM_COMPILER_GRAPH_TRANSFORM_H_
+#define NNVM_COMPILER_GRAPH_TRANSFORM_H_
+
+#include <nnvm/graph.h>
+#include <vector>
+
+namespace nnvm {
+namespace compiler {
+
+/*!
+ * \brief Transform the graph to build a new Graph, in post-DFS order.
+ *
+ *  Automatically copies a node when some of its children or control_deps have changed.
+ *  The transformation is not applied to Variable nodes.
+ *
+ * \param graph The original graph
+ *
+ * \param ftransform Function of (uint32_t nid, const Node* node, std::vector<NodeEntry>* out) -> bool
+ *
+ *  If an empty vector is returned, it means the original entries should be kept.
+ *
+ * \tparam FTransform The transformation function.
+ */
+template<typename FTransform>
+Graph GraphTransform(Graph graph, FTransform ftransform) {
+  const IndexedGraph& idx = graph.indexed_graph();
+  // new nodes
+  std::vector<NodeEntry> new_entry_map(idx.num_node_entries());
+  std::vector<bool> updated(idx.num_node_entries(), false);
+
+  // setup inputs and placeholder.
+  for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
+    const auto& inode = idx[nid];
+    if (inode.source->is_variable()) continue;
+    bool need_copy = false;
+    for (const IndexedGraph::NodeEntry& e : inode.inputs) {
+      if (updated[idx.entry_id(e)]) {
+        need_copy = true; break;
+      }
+    }
+    if (!need_copy) {
+      for (const uint32_t cid : inode.control_deps) {
+        const auto& cnode = idx[cid];
+        for (uint32_t i = 0 ; i < cnode.source->num_outputs(); ++i) {
+          if (updated[idx.entry_id(cid, i)]) {
+            need_copy = true;
+          }
+        }
+        if (need_copy) break;
+      }
+    }
+
+    if (!need_copy) {
+      std::vector<NodeEntry> ret;
+      if (ftransform(nid, inode.source, &ret)) {
+        CHECK_EQ(ret.size(), static_cast<size_t>(inode.source->num_outputs()));
+        for (uint32_t i = 0 ; i < inode.source->num_outputs(); ++i) {
+          updated[idx.entry_id(nid, i)] = true;
+          new_entry_map[idx.entry_id(nid, i)] = ret[i];
+        }
+      }
+    } else {
+      NodePtr node = Node::Create();
+      node->attrs = inode.source->attrs;
+      for (size_t i = 0; i < inode.inputs.size(); ++i) {
+        const IndexedGraph::NodeEntry& e = inode.inputs[i];
+        if (updated[idx.entry_id(e)]) {
+          node->inputs.push_back(new_entry_map[idx.entry_id(e)]);
+        } else {
+          node->inputs.push_back(inode.source->inputs[i]);
+        }
+      }
+      for (size_t i = 0; i < inode.control_deps.size(); ++i) {
+        const uint32_t cid = inode.control_deps[i];
+        const auto& cnode = idx[cid];
+        CHECK_NE(cnode.source->num_outputs(), 0U);
+        NodePtr selected_ptr;
+        for (uint32_t j = 0 ; j < cnode.source->num_outputs(); ++j) {
+          NodePtr cptr = updated[idx.entry_id(cid, j)] ?
+              new_entry_map[idx.entry_id(cid, j)].node : inode.source->control_deps[i];
+          if (selected_ptr == nullptr) {
+            selected_ptr = std::move(cptr);
+          } else {
+            CHECK(selected_ptr.get() == cptr.get())
+                << "Control dependency node changed to more than one node";
+          }
+        }
+        node->control_deps.push_back(selected_ptr);
+      }
+      std::vector<NodeEntry> ret;
+      if (ftransform(nid, node.get(), &ret)) {
+        CHECK_EQ(ret.size(), static_cast<size_t>(inode.source->num_outputs()));
+        for (uint32_t i = 0 ; i < inode.source->num_outputs(); ++i) {
+          updated[idx.entry_id(nid, i)] = true;
+          new_entry_map[idx.entry_id(nid, i)] = ret[i];
+        }
+      } else {
+        for (uint32_t i = 0 ; i < inode.source->num_outputs(); ++i) {
+          updated[idx.entry_id(nid, i)] = true;
+          new_entry_map[idx.entry_id(nid, i)] = NodeEntry{node, i, 0};
+        }
+      }
+    }
+  }
+  Graph ret;
+  for (size_t i = 0; i < idx.outputs().size(); ++i) {
+    const IndexedGraph::NodeEntry& e = idx.outputs()[i];
+    if (updated[idx.entry_id(e)]) {
+      ret.outputs.push_back(new_entry_map[idx.entry_id(e)]);
+    } else {
+      ret.outputs.push_back(graph.outputs[i]);
+    }
+  }
+  return ret;
+}
+
+}  // namespace compiler
+}  // namespace nnvm
+
+#endif  // NNVM_COMPILER_GRAPH_TRANSFORM_H_
diff --git a/nnvm/src/compiler/pass/layout_transform.cc b/nnvm/src/compiler/layout_transform.cc
similarity index 100%
rename from nnvm/src/compiler/pass/layout_transform.cc
rename to nnvm/src/compiler/layout_transform.cc
diff --git a/nnvm/src/compiler/node_attr.h b/nnvm/src/compiler/node_attr.h
new file mode 100644
index 000000000000..c4395ad98b69
--- /dev/null
+++ b/nnvm/src/compiler/node_attr.h
@@ -0,0 +1,34 @@
+/*!
+ * Copyright (c) 2017 by Contributors
+ * \file node_attr.h
+ * \brief Utility to access node attributes.
+*/
+#ifndef NNVM_COMPILER_NODE_ATTR_H_
+#define NNVM_COMPILER_NODE_ATTR_H_
+
+#include <nnvm/op.h>
+#include <nnvm/op_attr_types.h>
+#include <unordered_map>
+#include <string>
+
+namespace nnvm {
+namespace compiler {
+
+using AttrDict = std::unordered_map<std::string, std::string>;
+/*!
+ * \brief Get canonicalized attr dict from a node.
+ * \param attrs The node attrs
+ * \return The attribute dict
+ */
+inline AttrDict GetAttrDict(const NodeAttrs& attrs) {
+  static auto& fgetdict = nnvm::Op::GetAttr<FGetAttrDict>("FGetAttrDict");
+  if (fgetdict.count(attrs.op)) {
+    return fgetdict[attrs.op](attrs);
+  } else {
+    return attrs.dict;
+  }
+}
+
+}  // namespace compiler
+}  // namespace nnvm
+#endif  // NNVM_COMPILER_NODE_ATTR_H_
diff --git a/nnvm/src/compiler/packed_func_ext.cc b/nnvm/src/compiler/packed_func_ext.cc
index 54d0f627c7a5..5444ece9c110 100644
--- a/nnvm/src/compiler/packed_func_ext.cc
+++ b/nnvm/src/compiler/packed_func_ext.cc
@@ -8,6 +8,7 @@
 #include
 #include
 #include
+#include "./node_attr.h"
 
 namespace tvm {
 namespace runtime {
@@ -19,7 +20,6 @@ TVM_REGISTER_EXT_TYPE(nnvm::compiler::AttrDict);
 
 }  // namespace runtime
 }  // namespace tvm
-
 namespace nnvm {
 namespace compiler {
 
@@ -58,17 +58,6 @@ TVM_REGISTER_GLOBAL("nnvm.compiler._dict_keys")
   });
 
 // custom version of TVM compute
-inline std::unordered_map<std::string, std::string>
-GetAttrDict(const NodeAttrs& attrs) {
-  static auto& fgetdict = nnvm::Op::GetAttr<FGetAttrDict>("FGetAttrDict");
-  if (fgetdict.count(attrs.op)) {
-    return fgetdict[attrs.op](attrs);
-  } else {
-    return attrs.dict;
-  }
-}
-
-
 TVM_REGISTER_GLOBAL("nnvm._register_compute")
 .set_body([](TVMArgs args, TVMRetValue *rv) {
     // Intentionally copy and not de-allocate it, to avoid free pyobject during shutdown
@@ -105,14 +94,14 @@ TVM_REGISTER_GLOBAL("nnvm._register_pattern")
     op.set_attr<TOpPattern>("TOpPattern", args[1].operator int(), args[2]);
   });
 
-TVM_REGISTER_GLOBAL("nnvm.graph_attr._move_module")
+TVM_REGISTER_GLOBAL("nnvm.graph._move_module")
 .set_body([](TVMArgs args, TVMRetValue *rv) {
     const nnvm::Graph& g = args[0].AsExtension<nnvm::Graph>();
    *rv = const_cast<nnvm::Graph*>(&g)->
        MoveCopyAttr<tvm::runtime::Module>(args[1]);
  });
 
-TVM_REGISTER_GLOBAL("nnvm.graph_attr._move_graph")
+TVM_REGISTER_GLOBAL("nnvm.graph._move_graph")
 .set_body([](TVMArgs args, TVMRetValue *rv) {
     const nnvm::Graph& g = args[0].AsExtension<nnvm::Graph>();
    *rv = const_cast<nnvm::Graph*>(&g)->
diff --git a/nnvm/src/compiler/pass/precompute_prune.cc b/nnvm/src/compiler/precompute_prune.cc
similarity index 100%
rename from nnvm/src/compiler/pass/precompute_prune.cc
rename to nnvm/src/compiler/precompute_prune.cc
diff --git a/nnvm/src/compiler/simplify_batch_norm.cc b/nnvm/src/compiler/simplify_batch_norm.cc
new file mode 100644
index 000000000000..16d7557f29a5
--- /dev/null
+++ b/nnvm/src/compiler/simplify_batch_norm.cc
@@ -0,0 +1,114 @@
+/*!
+ * Copyright (c) 2017 by Contributors
+ * \file simplify_batch_norm.cc
+ * \author Ziheng Jiang
+*/
+#include <nnvm/graph.h>
+#include <nnvm/op.h>
+#include <nnvm/graph_attr_types.h>
+#include <nnvm/pass.h>
+#include <nnvm/top/nn.h>
+#include <string>
+#include "./graph_transform.h"
+
+namespace nnvm {
+namespace compiler {
+
+std::vector<NodeEntry>
+BatchNormToInferUnpack(const nnvm::NodeAttrs& attrs,
+                       nnvm::NodeEntry data,
+                       nnvm::NodeEntry gamma,
+                       nnvm::NodeEntry beta,
+                       nnvm::NodeEntry moving_mean,
+                       nnvm::NodeEntry moving_var,
+                       int data_dim) {
+  CHECK(attrs.op);
+  static const Op* bn_op = Op::Get("batch_norm");
+  CHECK(attrs.op == bn_op);
+  const auto& param = nnvm::get<top::BatchNormParam>(attrs.parsed);
+  std::string bn_name = attrs.name;
+
+  // transform batch_norm(data) to scale * data + shift
+  NodeEntry var_add_eps = MakeNode(
+      "__add_scalar__", bn_name + "_add_eps",
+      {moving_var}, {{"scalar", std::to_string(param.epsilon)}});
+
+  NodeEntry sqrt = MakeNode(
+      "sqrt", bn_name + "_sqrt", {var_add_eps});
+
+  NodeEntry scale = MakeNode(
+      "__rdiv_scalar__", bn_name + "_div",
+      {sqrt}, {{"scalar", "1"}});
+
+  if (param.scale) {
+    scale = MakeNode(
+        "elemwise_mul", bn_name + "_gamma_mul_div",
+        {scale, gamma});
+  }
+
+  NodeEntry neg_mean = MakeNode(
+      "negative", bn_name + "_neg_mean", {moving_mean});
+
+  NodeEntry shift = MakeNode(
+      "elemwise_mul", bn_name + "_neg_mean_mul_a",
+      {neg_mean, scale});
+
+  if (param.center) {
+    shift = MakeNode(
+        "elemwise_add", bn_name + "_add_beta", {shift, beta});
+  }
+  // reshape scale and shift so they broadcast along the channel axis
+  std::ostringstream oshape;
+  oshape << "(";
+  for (int i = 0; i < data_dim; ++i) {
+    if (i != 0) oshape << ", ";
+    if (i == param.axis) {
+      oshape << "-1";
+    } else {
+      oshape << "1";
+    }
+  }
+  oshape << ")";
+
+  scale = MakeNode("reshape", bn_name + "_sc_reshape",
+                   {scale}, {{"shape", oshape.str()}});
+  shift = MakeNode("reshape", bn_name + "_sh_reshape",
+                   {shift}, {{"shape", oshape.str()}});
+  NodeEntry out = MakeNode("broadcast_mul", bn_name + "_a_mul_data",
+                           {data, scale});
+  out = MakeNode("broadcast_add", bn_name + "_out",
+                 {out, shift});
+  // It is invalid to reference the other outputs of batch_norm after the inference transform.
+  NodeEntry undef = MakeNode("__undef__", "undef", {});
+  return {out, undef, undef};
+}
+
+Graph SimplifyBatchNormInference(nnvm::Graph src) {
+  // Get attributes from the graph
+  const IndexedGraph& idx = src.indexed_graph();
+  const ShapeVector& shape_vec = src.GetAttr<ShapeVector>("shape");
+  auto transform = [&](uint32_t nid, const Node* n, std::vector<NodeEntry>* ret) {
+    if (n->is_variable()) return false;
+    static const Op* bn_op = Op::Get("batch_norm");
+    if (n->op() == bn_op) {
+      *ret = BatchNormToInferUnpack(
+          n->attrs,
+          n->inputs[0],
+          n->inputs[1],
+          n->inputs[2],
+          n->inputs[3],
+          n->inputs[4],
+          shape_vec[idx.entry_id(nid, 0)].ndim());
+      return true;
+    } else {
+      return false;
+    }
+  };
+  return GraphTransform(src, transform);
+}
+
+NNVM_REGISTER_PASS(SimplifyBatchNormInference)
+.set_body(SimplifyBatchNormInference);
+
+}  // namespace compiler
+}  // namespace nnvm
diff --git a/nnvm/src/pass/print_graph_ir.cc b/nnvm/src/pass/print_graph_ir.cc
new file mode 100644
index 000000000000..a29ee922b644
--- /dev/null
+++ b/nnvm/src/pass/print_graph_ir.cc
@@ -0,0 +1,127 @@
+/*!
+ * Copyright (c) 2017 by Contributors
+ * \file print_graph_ir.cc
+ * \brief Print the graph IR in an LLVM-style human-readable format.
+ */
+#include <nnvm/graph.h>
+#include <nnvm/pass.h>
+#include <sstream>
+
+namespace nnvm {
+namespace pass {
+
+// print the graph ir in readable format
+void PrintGraphIR_(Graph src, std::ostream& os) {  // NOLINT(*)
+  const IndexedGraph& idx = src.indexed_graph();
+  os << "Graph(";
+  if (idx.input_nodes().size() < 4) {
+    for (size_t i = 0; i < idx.input_nodes().size(); ++i) {
+      uint32_t nid = idx.input_nodes()[i];
+      if (i != 0) {
+        os << ", ";
+      }
+      os << '%' << idx[nid].source->attrs.name;
+    }
+  } else {
+    for (size_t i = 0; i < idx.input_nodes().size(); ++i) {
+      uint32_t nid = idx.input_nodes()[i];
+      if (i != 0) {
+        os << ",\n      ";
+      }
+      os << '%' << idx[nid].source->attrs.name;
+    }
+  }
+  os << ") {\n";
+
+  auto print_entry = [&](const IndexedGraph::NodeEntry& e) {
+    if (idx[e.node_id].source->is_variable()) {
+      os << '%' << idx[e.node_id].source->attrs.name;
+    } else if (idx[e.node_id].source->num_outputs() == 1) {
+      os << '%' << e.node_id;
+    } else {
+      os << '%' << e.node_id << "." << e.index;
+    }
+  };
+
+  for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
+    const auto& inode = idx[nid];
+    if (inode.source->is_variable()) continue;
+    os << "  " << "%" << nid << " = "
+       << inode.source->op()->name << "(";
+    bool first = true;
+    for (const IndexedGraph::NodeEntry& e : inode.inputs) {
+      if (first) {
+        first = false;
+      } else {
+        os << ", ";
+      }
+      print_entry(e);
+    }
+    for (const auto& kv : inode.source->attrs.dict) {
+      if (first) {
+        first = false;
+      } else {
+        os << ", ";
+      }
+      os << kv.first << "=\'" << kv.second << "\'";
+    }
+    os << ")";
+    if (inode.control_deps.size() != 0) {
+      os << ", control_deps=[";
+      for (size_t i = 0; i < inode.control_deps.size(); ++i) {
+        if (i != 0) os << ", ";
+        uint32_t cid = inode.control_deps[i];
+        if (idx[cid].source->is_variable()) {
+          os << '%' << idx[cid].source->attrs.name;
+        } else {
+          os << '%' << cid;
+        }
+      }
+      os << "]";
+    }
+    os << "\n";
+  }
+  os << "  ret ";
+  {
+    bool first = true;
+    for (const IndexedGraph::NodeEntry& e : idx.outputs()) {
+      if (first) {
+        first = false;
+      } else {
+        os << ", ";
+      }
+      print_entry(e);
+    }
+  }
+  os << "\n}";
+  if (src.attrs.size() != 0) {
+    os << "\ngraph_attr_keys = [";
+    bool first = true;
+    for (const auto& kv : src.attrs) {
+      if (first) {
+        first = false;
+      } else {
+        os << ", ";
+      }
+      os << kv.first;
+    }
+    os << "]\n";
+  }
+}
+
+// Render the graph to its text IR form, stored in the returned graph's attrs.
+Graph PrintGraphIR(Graph src) {
+  std::ostringstream os;
+  PrintGraphIR_(src, os);
+  Graph ret;
+  ret.attrs["graphir"] = std::make_shared<any>(os.str());
+  return ret;
+}
+
+// register pass
+NNVM_REGISTER_PASS(PrintGraphIR)
+.describe("Return an empty Graph, save the text IR to ret.attrs[\"graphir\"]")
+.set_body(PrintGraphIR);
+
+}  // namespace pass
+}  // namespace nnvm
diff --git a/nnvm/src/top/nn/nn_common.h b/nnvm/src/top/nn/nn_common.h
index a4ce6b1b78d0..a75077b87ca4 100644
--- a/nnvm/src/top/nn/nn_common.h
+++ b/nnvm/src/top/nn/nn_common.h
@@ -11,6 +11,7 @@
 #include
 #include
 #include
+#include
 #include
 
 namespace nnvm {
diff --git a/nnvm/src/top/tensor/elemwise.cc b/nnvm/src/top/tensor/elemwise.cc
index 825b01b018e2..17e457e56b3c 100644
--- a/nnvm/src/top/tensor/elemwise.cc
+++ b/nnvm/src/top/tensor/elemwise.cc
@@ -12,6 +12,16 @@ namespace nnvm {
 namespace top {
+// undefined op
+NNVM_REGISTER_ELEMWISE_UNARY_OP(__undef__)
+.describe(R"code(undefined op.
+
+Used to produce an invalid node during optimization.
+
+)code" NNVM_ADD_FILELINE)
+.set_num_outputs(1)
+.set_num_inputs(0);
+
 // sigmoid
 NNVM_REGISTER_ELEMWISE_UNARY_OP(sigmoid)
 .describe(R"code(Computes sigmoid.
@@ -52,6 +62,16 @@ NNVM_REGISTER_ELEMWISE_UNARY_OP(log)
 
 )code" NNVM_ADD_FILELINE)
 .set_support_level(1);
 
+// sqrt
+NNVM_REGISTER_ELEMWISE_UNARY_OP(sqrt)
+.describe(R"code(Returns the square root of the input array, computed element-wise.
+
+.. math::
+   \sqrt{x}
+
+)code" NNVM_ADD_FILELINE)
+.set_support_level(1);
+
 // binary ops
 
 NNVM_REGISTER_ELEMWISE_BINARY_OP(elemwise_add)
diff --git a/nnvm/tests/python/compiler/test_graph_pass.py b/nnvm/tests/python/compiler/test_graph_pass.py
index 110184c8f07c..9f2c36a1d3f9 100644
--- a/nnvm/tests/python/compiler/test_graph_pass.py
+++ b/nnvm/tests/python/compiler/test_graph_pass.py
@@ -1,10 +1,11 @@
 """Unittest cases for graph pass"""
 import nnvm
 import nnvm.compiler
-from nnvm.compiler import graph_pass
+from nnvm import symbol as sym
+from nnvm.compiler import graph_pass, graph_attr
 
 def test_infer_attr():
-    x = nnvm.symbol.Variable("x")
+    x = sym.Variable("x")
     y = x * 2
     g = nnvm.graph.create(y)
     ishape, oshape = graph_pass.infer_shape(g, x=(10,20))
@@ -13,6 +14,5 @@ def test_infer_attr():
     itype, otype = graph_pass.infer_dtype(g, x="float32")
     assert otype[0] == "float32"
 
-
 if __name__ == "__main__":
     test_infer_attr()
diff --git a/nnvm/tests/python/compiler/test_simplify_batchnorm.py b/nnvm/tests/python/compiler/test_simplify_batchnorm.py
new file mode 100644
index 000000000000..54cb8bc0bc57
--- /dev/null
+++ b/nnvm/tests/python/compiler/test_simplify_batchnorm.py
@@ -0,0 +1,49 @@
+"""Unittest cases for simplify batch_norm"""
+import nnvm
+from nnvm import symbol as sym
+from nnvm.compiler import graph_pass, graph_attr
+
+def test_simplify_batchnorm():
+    def simple_bn(x, gamma, beta, moving_mean, moving_var,
+                  axis=1, epsilon=1e-5, dim=2):
+        # expect = (x - moving_mean) / sym.sqrt(moving_var + epsilon) * gamma + beta
+        scale = sym.elemwise_mul(1 / sym.sqrt(moving_var + epsilon), gamma)
+        shift = sym.elemwise_add(
+            sym.elemwise_mul(sym.negative(moving_mean), scale), beta)
+        # reshape scale and shift to broadcast along the given axis
+        shape = tuple(1 if i != axis else -1 for i in range(dim))
+        scale = sym.reshape(scale, shape=shape)
+        shift = sym.reshape(shift, shape=shape)
+        return x * scale + shift
+
+
+    # Before simplify
+    def check(dim, axis, nstep):
+        eps = 0.01
+        x = sym.Variable("x") + 1
+        beta = sym.Variable("beta")
+        gamma = sym.Variable("gamma")
+        moving_var = sym.Variable("moving_var")
+        moving_mean = sym.Variable("moving_mean")
+        y1, y2 = x, x
+
+        for i in range(nstep):
+            y1 = sym.batch_norm(
+                y1 + 1, gamma, beta, moving_mean, moving_var, epsilon=eps, axis=axis)
+            y2 = simple_bn(y2 + 1, gamma, beta, moving_mean, moving_var,
+                           epsilon=eps, axis=axis, dim=dim)
+        g = nnvm.graph.create(y1)
+        g2 = nnvm.graph.create(y2)
+        ishape = {"x": tuple(10 for i in range(dim))}
+        graph_attr.set_shape_inputs(g, ishape)
+        g1 = g.apply("InferShape").apply("SimplifyBatchNormInference")
+        # Uncomment for debugging
+        # print(g1.graphir())
+        # assert the graphs are structurally equal, as expected
+        graph_pass.check_graph_equal(g1, g2)
+
+    check(2, 1, 1)
+    check(4, 0, 3)
+
+if __name__ == "__main__":
+    test_simplify_batchnorm()
diff --git a/nnvm/tests/python/unittest/test_top_level1.py b/nnvm/tests/python/unittest/test_top_level1.py
index 4c5c28005f88..6b75cf1d83f1 100644
--- a/nnvm/tests/python/unittest/test_top_level1.py
+++ b/nnvm/tests/python/unittest/test_top_level1.py
@@ -1,4 +1,5 @@
 import nnvm.symbol as sym
+import nnvm.graph as graph
 
 def test_dense():
     x = sym.Variable('x')
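
Usage sketch (reviewer note, not part of the patch): the pieces above combine roughly as in tests/python/compiler/test_simplify_batchnorm.py. Build a graph containing batch_norm, run shape inference plus the new SimplifyBatchNormInference pass, inspect the result with the new graphir() helper, and compare against a reference graph with check_graph_equal(), which raises ValueError on mismatch. All names come from this diff; the input shape and epsilon below are arbitrary example values.

import nnvm
from nnvm import symbol as sym
from nnvm.compiler import graph_attr, graph_pass

# A single batch_norm over a 2-D input, normalizing axis 1.
x = sym.Variable("x")
gamma = sym.Variable("gamma")
beta = sym.Variable("beta")
moving_mean = sym.Variable("moving_mean")
moving_var = sym.Variable("moving_var")
y = sym.batch_norm(x, gamma, beta, moving_mean, moving_var, epsilon=0.01, axis=1)

g = nnvm.graph.create(y)
graph_attr.set_shape_inputs(g, {"x": (10, 10)})     # shapes are needed by the pass
g = g.apply("InferShape").apply("SimplifyBatchNormInference")
print(g.graphir())                                  # text IR of the simplified graph
# graph_pass.check_graph_equal(g, reference_graph)  # raises ValueError if structures differ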