diff --git a/nnvm/include/nnvm/graph.h b/nnvm/include/nnvm/graph.h index da1c952ac9b37..b09d9c61ab2e9 100644 --- a/nnvm/include/nnvm/graph.h +++ b/nnvm/include/nnvm/graph.h @@ -18,6 +18,8 @@ namespace nnvm { +class IndexedGraph; + /*! * \brief Symbolic computation graph. * This is the intermediate representation for optimization pass. @@ -32,6 +34,145 @@ class Graph { * and can be shared across multiple Instance of graph */ std::unordered_map > attrs; + /*! + * \brief Get the attribute from attrs. + * \param attr_name the name of the attribute + * \return the reference to corresponding attribute + * \tparam T the type of the attribute. + */ + template + inline const T& GetAttr(const std::string& attr_name); + /*! + * \brief get a indexed graph of current graph, if not exist, create it on demand + * \return The indexed graph. + * \sa IndexedGraph + */ + const IndexedGraph& indexed_graph(); + + private: + // internal structure of indexed graph + std::shared_ptr indexed_graph_; +}; + +/*! + * \brief Auxililary data structure to index a graph. + * It maps Nodes in the graph to consecutive integers node_id. + * It also maps IndexedGraph::NodeEntry to consecutive integer entry_id. + * This allows storing properties of Node and NodeEntry into + * compact vector and quickly access them without resorting to hashmap. + * + * The node_id and entry_rptr are the same as the JSON graph produced by SaveJSON Pass. + */ +class IndexedGraph { + public: + /*! \brief represents a data in the graph */ + struct NodeEntry { + /*! \brief the source node id in the computation graph */ + uint32_t node_id; + /*! \brief index of output from the source. */ + uint32_t index; + /*! + * \brief compare equality + * \param other the other entry to compare + * \return whether two entries equals to each other + */ + inline bool operator==(const NodeEntry& other) const { + return node_id == other.node_id && index == other.index; + } + }; + /*! \brief Node data structure in IndexedGraph */ + struct Node { + /*! \brief pointer to the source node */ + const nnvm::Node* source; + /*! \brief inputs to the node */ + array_view inputs; + /*! \brief control flow dependencies to the node */ + array_view control_deps; + }; + /*! \return number of nodes in the graph */ + inline size_t num_nodes() const { + return nodes_.size(); + } + /*! \return total number of NodeEntry in the graph */ + inline size_t num_node_entries() const { + return entry_rptr_.back(); + } + /*! + * \brief Get a unique entry id between 0 to num_node_entries() + * for a given IndexedGraph::NodeEntry + * \param node_id The node index + * \param index the output index + * \return the unique index. + */ + inline uint32_t entry_id(uint32_t node_id, uint32_t index) const { + return entry_rptr_[node_id] + index; + } + /*! + * \brief Get a unique entry id between 0 to num_node_entries() + * for a given IndexedGraph::NodeEntry + * \param e The entry to query for index. + * \return the unique index. + */ + inline uint32_t entry_id(const NodeEntry& e) const { + return entry_rptr_[e.node_id] + e.index; + } + /*! + * \brief Get a unique entry id between 0 to num_node_entries() + * for a given NodeEntry. + * \param e The entry to query for index. + * \return the unique index. + */ + inline uint32_t entry_id(const nnvm::NodeEntry& e) const { + return entry_rptr_[node_id(e.node.get())] + e.index; + } + /*! + * \brief Get the corresponding node id for a given Node in the IndexedGraph. + * \param node The Node to query for index. + * \return the node index. + */ + inline uint32_t node_id(const nnvm::Node* node) const { + return node2index_.at(node); + } + /*! + * \brief Get the corresponding Node structure for a given node_id. + * \param node_id The node id + * \return const reference to the corresponding IndexedGraph::Node + */ + inline const Node& operator[](uint32_t node_id) const { + return nodes_[node_id]; + } + /*! + * \brief Get the corresponding Node structure + * \param node The pointer to the Node structure + * \return const reference to the corresponding IndexedGraph::Node + */ + inline const Node& operator[](const nnvm::Node* node) const { + return nodes_[node_id(node)]; + } + /*! \return list of argument nodes */ + inline const std::vector& arg_nodes() const { + return arg_nodes_; + } + + private: + friend class Graph; + /*! + * \brief Constructor an IndexedGraph from normal Graph + * \param other The source graph. + */ + explicit IndexedGraph(const Graph& other); + // node pointers in CSR structure. + std::vector nodes_; + // index to argument nodes + std::vector arg_nodes_; + // mapping from node to index. + std::unordered_map node2index_; + // CSR pointer of node entries + std::vector entry_rptr_; + // space to store input entries of each + std::vector input_entries_; + // control flow dependencies + std::vector control_deps_; }; /*! @@ -45,6 +186,14 @@ template inline void DFSVisit(const std::vector& heads, FVisit fvisit); // inline function implementations +template +inline const T& Graph::GetAttr(const std::string& attr_name) { + auto it = attrs.find(attr_name); + CHECK(it != attrs.end()) + << "Cannot find attribute " << attr_name << " in the graph"; + return nnvm::get(*it->second); +} + template diff --git a/nnvm/include/nnvm/graph_attr_types.h b/nnvm/include/nnvm/graph_attr_types.h index ab14185724f1c..4af9fabaf62da 100644 --- a/nnvm/include/nnvm/graph_attr_types.h +++ b/nnvm/include/nnvm/graph_attr_types.h @@ -7,120 +7,37 @@ #define NNVM_GRAPH_ATTR_TYPES_H_ #include -#include -#include "./graph.h" +#include +#include "./tuple.h" namespace nnvm { /*! - * \brief Auxililary data structure to index a graph. - * It maps Nodes in the graph to consecutive integers node_id. - * It also maps IndexedGraph::NodeEntry to consecutive integer entry_id. - * This allows storing properties of Node and NodeEntry into - * compact vector and quickly access them without resorting to hashmap. + * \brief The result holder of JSON serializer + * + * \note Stored under ret.attrs["json"], provided by Pass "SaveJSON" + + * \code + * Graph ret = ApplyPass(src_graph, {"SaveJSON"}); + * const JSONString& json = ret.GetAttr("shape"); + * \endcode */ -struct IndexedGraph { - public: - /*! \brief represents a data in the graph */ - struct NodeEntry { - /*! \brief the source node id in the computation graph */ - uint32_t node_id; - /*! \brief index of output from the source. */ - uint32_t index; - /*! - * \brief compare equality - * \param other the other entry to compare - * \return whether two entries equals to each other - */ - inline bool operator==(const NodeEntry& other) const { - return node_id == other.node_id && index == other.index; - } - }; - /*! \brief Node data structure in IndexedGraph */ - struct Node { - /*! \brief pointer to the source node */ - const nnvm::Node* source; - /*! \brief inputs to the node */ - array_view inputs; - /*! \brief control flow dependencies to the node */ - array_view control_deps; - }; - /*! \return number of nodes in the graph */ - inline size_t num_nodes() const { - return nodes_.size(); - } - /*! \return total number of NodeEntry in the graph */ - inline size_t num_node_entries() const { - return entry_rptr_.back(); - } - /*! - * \brief Get a unique entry id between 0 to num_node_entries() - * for a given IndexedGraph::NodeEntry - * \param e The entry to query for index. - * \return the unique index. - */ - inline uint32_t entry_id(const NodeEntry& e) const { - return entry_rptr_[e.node_id] + e.index; - } - /*! - * \brief Get a unique entry id between 0 to num_node_entries() - * for a given NodeEntry. - * \param e The entry to query for index. - * \return the unique index. - */ - inline uint32_t entry_id(const nnvm::NodeEntry& e) const { - return entry_rptr_[node_id(e.node.get())] + e.index; - } - /*! - * \brief Get the corresponding node id for a given Node in the IndexedGraph. - * \param node The Node to query for index. - * \return the node index. - */ - inline uint32_t node_id(const nnvm::Node* node) const { - return node2index_.at(node); - } - /*! - * \brief Get the corresponding Node structure for a given node_id. - * \param node_id The node id - * \return const reference to the corresponding IndexedGraph::Node - */ - inline const Node& operator[](uint32_t node_id) const { - return nodes_[node_id]; - } - /*! - * \brief Get the corresponding Node structure - * \param node The pointer to the Node structure - * \return const reference to the corresponding IndexedGraph::Node - */ - inline const Node& operator[](const nnvm::Node* node) const { - return nodes_[node_id(node)]; - } - /*! \return list of argument nodes */ - inline const std::vector& arg_nodes() const { - return arg_nodes_; - } - /*! - * \brief Constructor an IndexedGraph from normal Graph - * \param other The source graph. - */ - explicit IndexedGraph(const Graph& other); - // disallow copy assign - IndexedGraph(const IndexedGraph& other) = delete; +using JSONString = std::string; - private: - // node pointers in CSR structure. - std::vector nodes_; - // index to argument nodes - std::vector arg_nodes_; - // mapping from node to index. - std::unordered_map node2index_; - // CSR pointer of node entries - std::vector entry_rptr_; - // space to store input entries of each - std::vector input_entries_; - // control flow dependencies - std::vector control_deps_; -}; +/*! + * \brief The result holder of shape of each NodeEntry in the graph. + * \note Stored under graph.attrs["shape"], provided by Pass "InferShape" + * + * \code + * Graph g = ApplyPass(src_graph, {"InferShape"}); + * const ShapeVector& shapes = g.GetAttr("shape"); + * // get shape by entry id + * TShape entry_shape = shapes[g.indexed_graph().entry_id(my_entry)]; + * \endcode + * + * \sa FInferShape + */ +using ShapeVector = std::vector; } // namespace nnvm diff --git a/nnvm/include/nnvm/op_attr_types.h b/nnvm/include/nnvm/op_attr_types.h index 615fca24e07cb..9049f39632e7a 100644 --- a/nnvm/include/nnvm/op_attr_types.h +++ b/nnvm/include/nnvm/op_attr_types.h @@ -9,6 +9,8 @@ #include #include #include +#include "./base.h" +#include "./tuple.h" namespace nnvm { @@ -39,6 +41,7 @@ using FListOutputNames = std::function (const NodeAttrs /*! * \brief Check whether operator will mutate k-th input. + * \param attrs The attributes of the node. * \param index The input index * \return Whether this operator will mutate index-th input. * @@ -47,6 +50,26 @@ using FListOutputNames = std::function (const NodeAttrs */ using FMutateInput = std::function; +/*! + * \brief Shape inference function. + * Update the shapes given the input shape information. + * TShape.ndim() == 0 means the shape is still unknown. + * + * \param attrs The attributes of the node. + * \param in_shapes Array of shapes from the inputs. + * \param out_shapes Array of shapes from the outputs. + * + * \return Whether all the shapes are known. + * + * \note Register under "FInferShape", + * by default do not update any shapes. + * + * FInferShape is needed by shape inference + */ +using FInferShape = std::function in_shapes, + array_view out_shapes)>; + } // namespace nnvm #endif // NNVM_OP_ATTR_TYPES_H_ diff --git a/nnvm/include/nnvm/tuple.h b/nnvm/include/nnvm/tuple.h index 8898fc054ca74..755f2720fa780 100644 --- a/nnvm/include/nnvm/tuple.h +++ b/nnvm/include/nnvm/tuple.h @@ -10,6 +10,7 @@ #include #include #include +#include "./base.h" namespace nnvm { @@ -179,7 +180,23 @@ class Tuple { inline const ValueType& operator[](index_t i) const { return begin()[i]; } - + /*! + * \brief Save Tuple to JSON. + * \param writer JSONWriter + */ + inline void Save(dmlc::JSONWriter* writer) const { + std::vector tmp(begin(), end()); + writer->Write(tmp); + } + /*! + * \brief Load Tuple from JSON. + * \param reader JSONReader + */ + inline void Load(dmlc::JSONReader* reader) { + std::vector tmp; + reader->Read(&tmp); + this->assign(tmp.begin(), tmp.end()); + } /*! * \brief allow output string of tuple to ostream * \param os the output stream @@ -287,6 +304,8 @@ class TShape : public Tuple { public: // inheritate other constructors from Tuple using Tuple::Tuple; + /*! \brief default constructor */ + TShape() = default; /*! * \brief copy constructor of TShape * \param s source shape. diff --git a/nnvm/src/core/graph_attr_types.cc b/nnvm/src/core/graph.cc similarity index 91% rename from nnvm/src/core/graph_attr_types.cc rename to nnvm/src/core/graph.cc index 745a909bdf909..601f84d0a5768 100644 --- a/nnvm/src/core/graph_attr_types.cc +++ b/nnvm/src/core/graph.cc @@ -3,11 +3,18 @@ * \file graph_attr_types.cc * \brief Graph node data structure. */ -#include +#include #include namespace nnvm { +const IndexedGraph& Graph::indexed_graph() { + if (indexed_graph_ == nullptr) { + indexed_graph_.reset(new IndexedGraph(*this)); + } + return *indexed_graph_; +} + // implement constructor from graph IndexedGraph::IndexedGraph(const Graph &g) { entry_rptr_.push_back(0); diff --git a/nnvm/src/example/operator.cc b/nnvm/src/example/operator.cc index 9078c314b119e..1b2fb1e10510d 100644 --- a/nnvm/src/example/operator.cc +++ b/nnvm/src/example/operator.cc @@ -1,17 +1,39 @@ // Copyright (c) 2016 by Contributors // This is an example on how we can register operator information to NNVM +#include #include #include +#include #include +namespace myproject { + using nnvm::FListInputNames; using nnvm::FMutateInput; +using nnvm::FInferShape; using nnvm::NodeAttrs; +using nnvm::TShape; +using nnvm::array_view; + +// simply return the shape as same +inline bool SameShape(const NodeAttrs& attrs, + array_view ishape, + array_view oshape) { + if (ishape.size() == 0 || ishape[0]->ndim() == 0) return false; + for (TShape* pshape : oshape) { + *pshape = *ishape[0]; + } + for (TShape* pshape : ishape) { + *pshape = *ishape[0]; + } + return true; +} NNVM_REGISTER_OP(add) .describe("add two data together") -.set_num_inputs(2); +.set_num_inputs(2) +.attr("FInferShape", SameShape); NNVM_REGISTER_OP(__add_symbol__) .describe("Alias of add") @@ -20,7 +42,8 @@ NNVM_REGISTER_OP(__add_symbol__) NNVM_REGISTER_OP(exp) .describe("take exponmential") .set_num_inputs(1) -.attr("inplace_pair", std::make_pair(0, 0)); +.attr("inplace_pair", std::make_pair(0, 0)) +.attr("FInferShape", SameShape); NNVM_REGISTER_OP(conv2d) @@ -39,3 +62,5 @@ NNVM_REGISTER_OP(assign) .attr("FMutateInput", [](const NodeAttrs& attrs, uint32_t index) { return index == 0; }); + +} // namespace myproject diff --git a/nnvm/src/pass/infer_shape.cc b/nnvm/src/pass/infer_shape.cc new file mode 100644 index 0000000000000..5788cebfa89df --- /dev/null +++ b/nnvm/src/pass/infer_shape.cc @@ -0,0 +1,47 @@ +/*! + * Copyright (c) 2016 by Contributors + * \file infer_shape.cc + * \brief Inference the shapes given + */ +#include +#include +#include + +namespace nnvm { +namespace pass { + +Graph InferShape(const Graph& src) { + Graph ret = src; + const IndexedGraph& idx = ret.indexed_graph(); + static auto& finfer_shape = Op::GetAttr("FInferShape"); + // reshape shape vector + ShapeVector rshape(idx.num_node_entries()); + // temp space for shape inference. + std::vector ishape, oshape; + // number of completed nodes + size_t num_known = 0; + for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) { + const auto& inode = idx[nid]; + if (inode.source->is_variable()) continue; + ishape.resize(inode.inputs.size()); + for (uint32_t i = 0; i < ishape.size(); ++i) { + ishape[i] = &rshape[idx.entry_id(inode.inputs[i])]; + } + oshape.resize(inode.source->num_outputs()); + for (uint32_t i = 0; i < oshape.size(); ++i) { + oshape[i] = &rshape[idx.entry_id(nid, i)]; + } + if (finfer_shape.count(inode.source->op)) { + num_known += + finfer_shape[inode.source->op](inode.source->attrs, ishape, oshape); + } + } + // set the shapes + ret.attrs["shape"] = std::make_shared(std::move(rshape)); + // number of nodes who knows the shape. + ret.attrs["shape_num_known_nodes"] = std::make_shared(num_known); + return ret; +} + +} // namespace pass +} // namespace nnvm diff --git a/nnvm/src/pass/order_mutation.cc b/nnvm/src/pass/order_mutation.cc index ecaeadeaadb43..a498660cfd1f6 100644 --- a/nnvm/src/pass/order_mutation.cc +++ b/nnvm/src/pass/order_mutation.cc @@ -9,6 +9,7 @@ #include namespace nnvm { +namespace pass { template inline T get_with_default(const std::unordered_map &map, @@ -139,4 +140,5 @@ NNVM_REGISTER_PASS(OrderMutation) .set_body(OrderMutation) .set_change_graph(true); +} // namespace pass } // namespace nnvm diff --git a/nnvm/src/pass/saveload_json.cc b/nnvm/src/pass/saveload_json.cc index 6ba7ac23f50f0..0fe1d1896db6d 100644 --- a/nnvm/src/pass/saveload_json.cc +++ b/nnvm/src/pass/saveload_json.cc @@ -120,6 +120,7 @@ struct JSONNode { struct JSONGraph { std::vector nodes; std::vector arg_nodes; + std::vector node_row_ptr; std::vector heads; std::unordered_map > attrs; @@ -127,6 +128,7 @@ struct JSONGraph { writer->BeginObject(); writer->WriteObjectKeyValue("nodes", nodes); writer->WriteObjectKeyValue("arg_nodes", arg_nodes); + writer->WriteObjectKeyValue("node_row_ptr", node_row_ptr); writer->WriteObjectKeyValue("heads", heads); if (attrs.size() != 0) { writer->WriteObjectKeyValue("attrs", attrs); @@ -140,6 +142,7 @@ struct JSONGraph { helper.DeclareField("nodes", &nodes); helper.DeclareField("arg_nodes", &arg_nodes); helper.DeclareField("heads", &heads); + helper.DeclareOptionalField("node_row_ptr", &node_row_ptr); helper.DeclareOptionalField("attrs", &attrs); helper.ReadAllFields(reader); } @@ -188,6 +191,7 @@ Graph LoadJSON(const Graph& src) { Graph SaveJSON(const Graph& src) { JSONGraph jgraph; std::unordered_map node2index; + jgraph.node_row_ptr.push_back(0); DFSVisit(src.outputs, [&node2index, &jgraph](const NodePtr& n) { uint32_t nid = static_cast(jgraph.nodes.size()); node2index[n.get()] = nid; @@ -204,6 +208,8 @@ Graph SaveJSON(const Graph& src) { for (const NodePtr& c : n->control_deps) { jnode.control_deps.push_back(node2index.at(c.get())); } + jgraph.node_row_ptr.push_back( + jgraph.node_row_ptr.back() + n->num_outputs()); jgraph.nodes.emplace_back(std::move(jnode)); });