diff --git a/nnvm/include/nnvm/pass.h b/nnvm/include/nnvm/pass.h index c1936941b331..438226f5c93f 100644 --- a/nnvm/include/nnvm/pass.h +++ b/nnvm/include/nnvm/pass.h @@ -23,7 +23,7 @@ namespace nnvm { * \param src The graph to be transformed. * \return The generated graph. */ -typedef std::function PassFunction; +typedef std::function PassFunction; /*! * \brief Apply a series of pass transformations on g. @@ -31,7 +31,7 @@ typedef std::function PassFunction; * \param pass The name of pass to be applied. * \return The transformed graph */ -Graph ApplyPass(const Graph& src, +Graph ApplyPass(Graph src, const std::vector& pass); /*! diff --git a/nnvm/include/nnvm/pass_functions.h b/nnvm/include/nnvm/pass_functions.h new file mode 100644 index 000000000000..e87a74fff27e --- /dev/null +++ b/nnvm/include/nnvm/pass_functions.h @@ -0,0 +1,76 @@ +/*! + * Copyright (c) 2016 by Contributors + * \file pass_functions.h + * \brief Pass functions that simply redirect the calls to ApplyPass + * + * This file serves as documentation on how to use functions implemented in "src/pass". + * It is totally optional to add these functions when you add a new pass, since + * ApplyPass can be directly called. + */ +#ifndef NNVM_PASS_FUNCTIONS_H_ +#define NNVM_PASS_FUNCTIONS_H_ + +#include +#include +#include "./base.h" +#include "./pass.h" +#include "./graph_attr_types.h" + +namespace nnvm { +namespace pass { + +/*! + * \brief Load a graph from JSON string, redirects to "LoadJSON" pass. + * \param json_str The json string. + * \return Loaded graph. + */ +inline Graph LoadJSON(const std::string& json_str) { + Graph ret; + ret.attrs["json"] = std::make_shared(json_str); + return ApplyPass(ret, {"LoadJSON"}); +} + +/*! + * \brief Save a graph to json, redirects to "SaveJSON" pass. + * \param graph The to be saved. + * \return The json string. + */ +inline std::string SaveJSON(Graph graph) { + Graph ret = ApplyPass(std::move(graph), {"SaveJSON"}); + return ret.GetAttr("json"); +} + +/*! + * \brief Add control flow dependencies between nodes + * To correctly order mutation and read to resolve + * write after read problem and read after write problems. + * \param src source graph + * \return A graph that added control flow dependencies. + */ +inline Graph OrderMutation(Graph src) { + return ApplyPass(std::move(src), {"OrderMutation"}); +} + +/*! + * \brief Infer shapes in the graph given the information. + * \param graph source graph + * \param shape_args The shapes of aruguments to the graph. + * \param shape_attr_key The key to the node attribute that can indicate shape. + * \return A graph with new attribute "shape" containing inferred shape of each NodeEntry. + * The index of ShapeVector is given by graph.indexed_graph().entry_id + */ +inline Graph InferShape(Graph graph, + ShapeVector shape_args = {}, + std::string shape_attr_key = "") { + if (shape_args.size() != 0) { + graph.attrs["shape_args"] = std::make_shared(std::move(shape_args)); + } + if (shape_attr_key.length() != 0) { + graph.attrs["shape_attr_key"] = std::make_shared(std::move(shape_attr_key)); + } + return ApplyPass(std::move(graph), {"InferShape"}); +} + +} // namespace pass +} // namespace nnvm +#endif // NNVM_PASS_FUNCTIONS_H_ diff --git a/nnvm/src/core/pass.cc b/nnvm/src/core/pass.cc index f94acf50c8d3..f58c8039989b 100644 --- a/nnvm/src/core/pass.cc +++ b/nnvm/src/core/pass.cc @@ -22,7 +22,7 @@ const PassFunctionReg* FindPassDep(const std::string&attr_name) { return nullptr; } -Graph ApplyPass(const Graph& src, +Graph ApplyPass(Graph g, const std::vector& pass) { std::vector fpass; for (auto& name : pass) { @@ -32,11 +32,9 @@ Graph ApplyPass(const Graph& src, fpass.push_back(reg); } - Graph g; - const Graph* s = &src; for (auto r : fpass) { for (auto& dep : r->graph_attr_dependency) { - if (s->attrs.count(dep) == 0) { + if (g.attrs.count(dep) == 0) { auto* pass_dep = FindPassDep(dep); std::string msg; if (pass_dep != nullptr) { @@ -48,8 +46,7 @@ Graph ApplyPass(const Graph& src, << msg; } } - g = r->body(*s); - s = &g; + g = r->body(std::move(g)); } return g; } diff --git a/nnvm/src/example/operator.cc b/nnvm/src/example/operator.cc index 1b2fb1e10510..f286716f51c4 100644 --- a/nnvm/src/example/operator.cc +++ b/nnvm/src/example/operator.cc @@ -4,6 +4,7 @@ #include #include #include +#include #include #include @@ -30,6 +31,31 @@ inline bool SameShape(const NodeAttrs& attrs, return true; } +// simple demonstration of reshape. +NNVM_REGISTER_OP(reshape) +.describe("reshape source to target shape") +.set_num_inputs(1) +.set_attr_parser( + [](NodeAttrs* attrs) { + // parse attr parser to get target attribute + TShape target; + std::istringstream is(attrs->dict.at("target")); + CHECK(is >> target); + attrs->parsed = std::move(target); + }) +.attr( + "FInferShape", [] (const NodeAttrs& attrs, + array_view ishape, + array_view oshape) { + // get parsed attribute + const TShape& target = nnvm::get(attrs.parsed); + *oshape[0] = target; + if (ishape[0]->ndim() == 0) return false; + CHECK_EQ(ishape[0]->Size(), target.Size()) + << "Reshape op: source target shape mismatch"; + return true; + }); + NNVM_REGISTER_OP(add) .describe("add two data together") .set_num_inputs(2) diff --git a/nnvm/src/pass/infer_shape.cc b/nnvm/src/pass/infer_shape.cc index 5788cebfa89d..a2d9abf38291 100644 --- a/nnvm/src/pass/infer_shape.cc +++ b/nnvm/src/pass/infer_shape.cc @@ -10,19 +10,42 @@ namespace nnvm { namespace pass { -Graph InferShape(const Graph& src) { - Graph ret = src; +Graph InferShape(Graph ret) { const IndexedGraph& idx = ret.indexed_graph(); static auto& finfer_shape = Op::GetAttr("FInferShape"); // reshape shape vector ShapeVector rshape(idx.num_node_entries()); + + if (ret.attrs.count("shape_args") != 0) { + const ShapeVector& shape_args = ret.GetAttr("shape_args"); + CHECK_LE(shape_args.size(), idx.arg_nodes().size()) + << "shape args is more than number of arguments"; + for (size_t i = 0; i < shape_args.size(); ++i) { + rshape[idx.entry_id(idx.arg_nodes()[i], 0)] = shape_args[i]; + } + } + std::string shape_attr_key; + if (ret.attrs.count("shape_attr_key") != 0) { + shape_attr_key = ret.GetAttr("shape_attr_key"); + } + // temp space for shape inference. std::vector ishape, oshape; // number of completed nodes size_t num_known = 0; for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) { const auto& inode = idx[nid]; - if (inode.source->is_variable()) continue; + if (inode.source->is_variable()) { + if (shape_attr_key.length() != 0) { + auto it = inode.source->attrs.dict.find(shape_attr_key); + if (it != inode.source->attrs.dict.end()) { + CHECK_EQ(inode.source->num_outputs(), 1); + std::istringstream is(it->second); + CHECK(is >> rshape[idx.entry_id(nid, 0)]) << "Invalid shape attribute"; + } + } + continue; + } ishape.resize(inode.inputs.size()); for (uint32_t i = 0; i < ishape.size(); ++i) { ishape[i] = &rshape[idx.entry_id(inode.inputs[i])]; @@ -43,5 +66,13 @@ Graph InferShape(const Graph& src) { return ret; } +NNVM_REGISTER_PASS(InferShape) +.describe("Infer the shape of each node entries.") +.set_body(InferShape) +.set_change_graph(false) +.provide_graph_attr("shape"); + +DMLC_JSON_ENABLE_ANY(ShapeVector, list_shape); + } // namespace pass } // namespace nnvm diff --git a/nnvm/src/pass/order_mutation.cc b/nnvm/src/pass/order_mutation.cc index a498660cfd1f..ce615fccaad4 100644 --- a/nnvm/src/pass/order_mutation.cc +++ b/nnvm/src/pass/order_mutation.cc @@ -1,6 +1,6 @@ /*! * Copyright (c) 2016 by Contributors - * \file saveload_json.cc + * \file order_mutation.cc * \brief Add control flow dependencies between nodes * To correctly order mutation and read to resolve * write after read problem and read after write problems. diff --git a/nnvm/src/pass/saveload_json.cc b/nnvm/src/pass/saveload_json.cc index 0fe1d1896db6..c2a22a46b611 100644 --- a/nnvm/src/pass/saveload_json.cc +++ b/nnvm/src/pass/saveload_json.cc @@ -149,7 +149,7 @@ struct JSONGraph { }; // Load a graph from JSON file. -Graph LoadJSON(const Graph& src) { +Graph LoadJSON(Graph src) { CHECK_NE(src.attrs.count("json"), 0) << "Load JSON require json to be presented."; const std::string &json_str = @@ -188,7 +188,7 @@ Graph LoadJSON(const Graph& src) { } // save a graph to json -Graph SaveJSON(const Graph& src) { +Graph SaveJSON(Graph src) { JSONGraph jgraph; std::unordered_map node2index; jgraph.node_row_ptr.push_back(0); diff --git a/nnvm/src/test_main.cc b/nnvm/src/test_main.cc index 4d14abced5ab..139fde8f4f0d 100644 --- a/nnvm/src/test_main.cc +++ b/nnvm/src/test_main.cc @@ -4,6 +4,7 @@ #include #include #include +#include #include #include diff --git a/nnvm/tests/python/test_graph.py b/nnvm/tests/python/test_graph.py index 0af641932a5d..88859250956e 100644 --- a/nnvm/tests/python/test_graph.py +++ b/nnvm/tests/python/test_graph.py @@ -35,7 +35,23 @@ def test_order_mutation_pass(): assert nindex['add1'] in jnodes[nindex['assign']]['control_deps'] assert jnodes[nindex['assign']]['inputs'][0][2] == 1 +def test_infer_shape(): + x = sym.Variable('x', shape=(4, 2)) + y = sym.add(x, x, name='add1') + y = sym.reshape(y, target=(2, 4), name="reshape1") + g = graph.create(y) + g._set_json_attr("shape_attr_key", "shape") + g = g.apply('InferShape') + jgraph = json.loads(g.apply('SaveJSON').json_attr('json')) + jnodes = jgraph['nodes'] + jnode_row_ptr = jgraph['node_row_ptr'] + nindex = {n['name']: i for i, n in enumerate(jnodes)} + assert g.json_attr('shape')[jnode_row_ptr[nindex["reshape1"]]] == [2, 4] + assert g.json_attr('shape')[jnode_row_ptr[nindex["add1"]]] == [4, 2] + + if __name__ == "__main__": test_order_mutation_pass() test_graph_json_attr() test_json_pass() + test_infer_shape()