Skip to content

Commit

Permalink
[Pass] Finish infershape testcase (#16)
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen committed May 29, 2018
1 parent 9135bc0 commit 83c0745
Show file tree
Hide file tree
Showing 9 changed files with 161 additions and 14 deletions.
4 changes: 2 additions & 2 deletions nnvm/include/nnvm/pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,15 @@ namespace nnvm {
* \param src The graph to be transformed.
* \return The generated graph.
*/
typedef std::function<Graph (const Graph& src)> PassFunction;
typedef std::function<Graph (Graph src)> PassFunction;

/*!
* \brief Apply a series of pass transformations on g.
* \param src The graph to be transformed.
* \param pass The name of pass to be applied.
* \return The transformed graph
*/
Graph ApplyPass(const Graph& src,
Graph ApplyPass(Graph src,
const std::vector<std::string>& pass);

/*!
Expand Down
76 changes: 76 additions & 0 deletions nnvm/include/nnvm/pass_functions.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
/*!
* Copyright (c) 2016 by Contributors
* \file pass_functions.h
* \brief Pass functions that simply redirect the calls to ApplyPass
*
* This file serves as documentation on how to use functions implemented in "src/pass".
* It is totally optional to add these functions when you add a new pass, since
* ApplyPass can be directly called.
*/
#ifndef NNVM_PASS_FUNCTIONS_H_
#define NNVM_PASS_FUNCTIONS_H_

#include <string>
#include <memory>
#include "./base.h"
#include "./pass.h"
#include "./graph_attr_types.h"

namespace nnvm {
namespace pass {

/*!
* \brief Load a graph from JSON string, redirects to "LoadJSON" pass.
* \param json_str The json string.
* \return Loaded graph.
*/
inline Graph LoadJSON(const std::string& json_str) {
Graph ret;
ret.attrs["json"] = std::make_shared<any>(json_str);
return ApplyPass(ret, {"LoadJSON"});
}

/*!
* \brief Save a graph to json, redirects to "SaveJSON" pass.
* \param graph The to be saved.
* \return The json string.
*/
inline std::string SaveJSON(Graph graph) {
Graph ret = ApplyPass(std::move(graph), {"SaveJSON"});
return ret.GetAttr<std::string>("json");
}

/*!
* \brief Add control flow dependencies between nodes
* To correctly order mutation and read to resolve
* write after read problem and read after write problems.
* \param src source graph
* \return A graph that added control flow dependencies.
*/
inline Graph OrderMutation(Graph src) {
return ApplyPass(std::move(src), {"OrderMutation"});
}

/*!
* \brief Infer shapes in the graph given the information.
* \param graph source graph
* \param shape_args The shapes of aruguments to the graph.
* \param shape_attr_key The key to the node attribute that can indicate shape.
* \return A graph with new attribute "shape" containing inferred shape of each NodeEntry.
* The index of ShapeVector is given by graph.indexed_graph().entry_id
*/
inline Graph InferShape(Graph graph,
ShapeVector shape_args = {},
std::string shape_attr_key = "") {
if (shape_args.size() != 0) {
graph.attrs["shape_args"] = std::make_shared<any>(std::move(shape_args));
}
if (shape_attr_key.length() != 0) {
graph.attrs["shape_attr_key"] = std::make_shared<any>(std::move(shape_attr_key));
}
return ApplyPass(std::move(graph), {"InferShape"});
}

} // namespace pass
} // namespace nnvm
#endif // NNVM_PASS_FUNCTIONS_H_
9 changes: 3 additions & 6 deletions nnvm/src/core/pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ const PassFunctionReg* FindPassDep(const std::string&attr_name) {
return nullptr;
}

Graph ApplyPass(const Graph& src,
Graph ApplyPass(Graph g,
const std::vector<std::string>& pass) {
std::vector<const PassFunctionReg*> fpass;
for (auto& name : pass) {
Expand All @@ -32,11 +32,9 @@ Graph ApplyPass(const Graph& src,
fpass.push_back(reg);
}

Graph g;
const Graph* s = &src;
for (auto r : fpass) {
for (auto& dep : r->graph_attr_dependency) {
if (s->attrs.count(dep) == 0) {
if (g.attrs.count(dep) == 0) {
auto* pass_dep = FindPassDep(dep);
std::string msg;
if (pass_dep != nullptr) {
Expand All @@ -48,8 +46,7 @@ Graph ApplyPass(const Graph& src,
<< msg;
}
}
g = r->body(*s);
s = &g;
g = r->body(std::move(g));
}
return g;
}
Expand Down
26 changes: 26 additions & 0 deletions nnvm/src/example/operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include <nnvm/base.h>
#include <nnvm/op.h>
#include <nnvm/op_attr_types.h>
#include <nnvm/node.h>
#include <nnvm/graph_attr_types.h>
#include <utility>

Expand All @@ -30,6 +31,31 @@ inline bool SameShape(const NodeAttrs& attrs,
return true;
}

// simple demonstration of reshape.
NNVM_REGISTER_OP(reshape)
.describe("reshape source to target shape")
.set_num_inputs(1)
.set_attr_parser(
[](NodeAttrs* attrs) {
// parse attr parser to get target attribute
TShape target;
std::istringstream is(attrs->dict.at("target"));
CHECK(is >> target);
attrs->parsed = std::move(target);
})
.attr<FInferShape>(
"FInferShape", [] (const NodeAttrs& attrs,
array_view<TShape*> ishape,
array_view<TShape*> oshape) {
// get parsed attribute
const TShape& target = nnvm::get<TShape>(attrs.parsed);
*oshape[0] = target;
if (ishape[0]->ndim() == 0) return false;
CHECK_EQ(ishape[0]->Size(), target.Size())
<< "Reshape op: source target shape mismatch";
return true;
});

NNVM_REGISTER_OP(add)
.describe("add two data together")
.set_num_inputs(2)
Expand Down
37 changes: 34 additions & 3 deletions nnvm/src/pass/infer_shape.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,42 @@
namespace nnvm {
namespace pass {

Graph InferShape(const Graph& src) {
Graph ret = src;
Graph InferShape(Graph ret) {
const IndexedGraph& idx = ret.indexed_graph();
static auto& finfer_shape = Op::GetAttr<FInferShape>("FInferShape");
// reshape shape vector
ShapeVector rshape(idx.num_node_entries());

if (ret.attrs.count("shape_args") != 0) {
const ShapeVector& shape_args = ret.GetAttr<ShapeVector>("shape_args");
CHECK_LE(shape_args.size(), idx.arg_nodes().size())
<< "shape args is more than number of arguments";
for (size_t i = 0; i < shape_args.size(); ++i) {
rshape[idx.entry_id(idx.arg_nodes()[i], 0)] = shape_args[i];
}
}
std::string shape_attr_key;
if (ret.attrs.count("shape_attr_key") != 0) {
shape_attr_key = ret.GetAttr<std::string>("shape_attr_key");
}

// temp space for shape inference.
std::vector<TShape*> ishape, oshape;
// number of completed nodes
size_t num_known = 0;
for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
const auto& inode = idx[nid];
if (inode.source->is_variable()) continue;
if (inode.source->is_variable()) {
if (shape_attr_key.length() != 0) {
auto it = inode.source->attrs.dict.find(shape_attr_key);
if (it != inode.source->attrs.dict.end()) {
CHECK_EQ(inode.source->num_outputs(), 1);
std::istringstream is(it->second);
CHECK(is >> rshape[idx.entry_id(nid, 0)]) << "Invalid shape attribute";
}
}
continue;
}
ishape.resize(inode.inputs.size());
for (uint32_t i = 0; i < ishape.size(); ++i) {
ishape[i] = &rshape[idx.entry_id(inode.inputs[i])];
Expand All @@ -43,5 +66,13 @@ Graph InferShape(const Graph& src) {
return ret;
}

NNVM_REGISTER_PASS(InferShape)
.describe("Infer the shape of each node entries.")
.set_body(InferShape)
.set_change_graph(false)
.provide_graph_attr("shape");

DMLC_JSON_ENABLE_ANY(ShapeVector, list_shape);

} // namespace pass
} // namespace nnvm
2 changes: 1 addition & 1 deletion nnvm/src/pass/order_mutation.cc
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/*!
* Copyright (c) 2016 by Contributors
* \file saveload_json.cc
* \file order_mutation.cc
* \brief Add control flow dependencies between nodes
* To correctly order mutation and read to resolve
* write after read problem and read after write problems.
Expand Down
4 changes: 2 additions & 2 deletions nnvm/src/pass/saveload_json.cc
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ struct JSONGraph {
};

// Load a graph from JSON file.
Graph LoadJSON(const Graph& src) {
Graph LoadJSON(Graph src) {
CHECK_NE(src.attrs.count("json"), 0)
<< "Load JSON require json to be presented.";
const std::string &json_str =
Expand Down Expand Up @@ -188,7 +188,7 @@ Graph LoadJSON(const Graph& src) {
}

// save a graph to json
Graph SaveJSON(const Graph& src) {
Graph SaveJSON(Graph src) {
JSONGraph jgraph;
std::unordered_map<Node*, uint32_t> node2index;
jgraph.node_row_ptr.push_back(0);
Expand Down
1 change: 1 addition & 0 deletions nnvm/src/test_main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include <nnvm/tuple.h>
#include <nnvm/c_api.h>
#include <nnvm/graph_attr_types.h>
#include <nnvm/pass_functions.h>
#include <dmlc/timer.h>
#include <string>

Expand Down
16 changes: 16 additions & 0 deletions nnvm/tests/python/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,23 @@ def test_order_mutation_pass():
assert nindex['add1'] in jnodes[nindex['assign']]['control_deps']
assert jnodes[nindex['assign']]['inputs'][0][2] == 1

def test_infer_shape():
x = sym.Variable('x', shape=(4, 2))
y = sym.add(x, x, name='add1')
y = sym.reshape(y, target=(2, 4), name="reshape1")
g = graph.create(y)
g._set_json_attr("shape_attr_key", "shape")
g = g.apply('InferShape')
jgraph = json.loads(g.apply('SaveJSON').json_attr('json'))
jnodes = jgraph['nodes']
jnode_row_ptr = jgraph['node_row_ptr']
nindex = {n['name']: i for i, n in enumerate(jnodes)}
assert g.json_attr('shape')[jnode_row_ptr[nindex["reshape1"]]] == [2, 4]
assert g.json_attr('shape')[jnode_row_ptr[nindex["add1"]]] == [4, 2]


if __name__ == "__main__":
test_order_mutation_pass()
test_graph_json_attr()
test_json_pass()
test_infer_shape()

0 comments on commit 83c0745

Please sign in to comment.