Skip to content

Commit

Permalink
[SYMBOLIC] Add symbolic API (#2)
Browse files Browse the repository at this point in the history
* [SYMBOLIC] Add symbolic API

* Update Testcase to nnvm
  • Loading branch information
tqchen committed May 26, 2018
1 parent 625ab2c commit aa36823
Show file tree
Hide file tree
Showing 12 changed files with 680 additions and 59 deletions.
27 changes: 16 additions & 11 deletions nnvm/include/nnvm/graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,13 @@
#include <unordered_set>
#include "./base.h"
#include "./node.h"
#include "./symbolic.h"

namespace nnvm {

/*!
* \brief Symbolic computation graph.
* This is the intermediate representation for optimization pass.
*/
class Graph {
public:
Expand All @@ -30,16 +32,18 @@ class Graph {
* and can be shared across multiple Instance of graph
*/
std::unordered_map<std::string, std::shared_ptr<const any> > attrs;
/*!
* \brief perform a Post Order DFS visit to each node in the graph.
* This order is deterministic and is also topoligical sorted.
* \param fvisit a function of type std::function<void(const std::shared_ptr<Node>&)>
* \tparam FVisit The function type to perform the visit.
*/
template<typename FVisit>
inline void DFSVisit(FVisit fvisit) const;
};

/*!
* \brief perform a Post Order DFS visit to each node in the graph.
* This order is deterministic and is also topoligical sorted.
* \param heads The heads in the graph.
* \param fvisit a function of type std::function<void(const std::shared_ptr<Node>&)>
* \tparam FVisit The function type to perform the visit.
*/
template<typename FVisit>
inline void DFSVisit(const std::vector<NodeEntry>& heads, FVisit fvisit);

// inline function implementations
template <typename GNode, typename HashType,
typename FVisit, typename HashFunc,
Expand Down Expand Up @@ -75,10 +79,11 @@ void PostOrderDFSVisit(const std::vector<GNode>& heads,
}

template<typename FVisit>
inline void Graph::DFSVisit(FVisit fvisit) const {
inline void DFSVisit(const std::vector<NodeEntry>& heads,
FVisit fvisit) {
typedef const std::shared_ptr<Node>* GNode;
std::vector<GNode> head_nodes(outputs.size());
std::transform(outputs.begin(), outputs.end(), head_nodes.begin(),
std::vector<GNode> head_nodes(heads.size());
std::transform(heads.begin(), heads.end(), head_nodes.begin(),
[](const NodeEntry& e)->GNode {
return &e.node;
});
Expand Down
17 changes: 14 additions & 3 deletions nnvm/include/nnvm/node.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ class Node {
inline bool is_variable() const;
/*! \return number of outputs from this node */
inline uint32_t num_outputs() const;
/*! \return number of inputs from this node */
inline uint32_t num_inputs() const;
/*!
* \brief create a new empty shared_ptr of Node.
* \return a created empty node.
Expand All @@ -86,10 +88,19 @@ inline bool Node::is_variable() const {

inline uint32_t Node::num_outputs() const {
if (is_variable()) return 1;
if (this->op->num_outputs >= 0) {
return static_cast<uint32_t>(this->op->num_outputs);
if (this->op->get_num_outputs == nullptr) {
return this->op->num_outputs;
} else {
return this->op->get_num_outputs(*this);
return this->op->get_num_outputs(this->attrs);
}
}

inline uint32_t Node::num_inputs() const {
if (is_variable()) return 1;
if (this->op->get_num_inputs == nullptr) {
return this->op->num_inputs;
} else {
return this->op->get_num_inputs(this->attrs);
}
}

Expand Down
85 changes: 66 additions & 19 deletions nnvm/include/nnvm/op.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include <vector>
#include <utility>
#include <typeinfo>
#include <limits>
#include <functional>
#include "./base.h"

Expand All @@ -22,8 +23,8 @@ template<typename ValueType>
class OpMap;
class OpRegistryEntry;

/*! \brief constant to indicate variable length inout and output */
static const int kVarg = -1;
/*! \brief constant to indicate it take any length of positional inputs */
static const uint32_t kVarg = std::numeric_limits<uint32_t>::max();

/*!
* \brief Operator structure.
Expand Down Expand Up @@ -79,23 +80,31 @@ class Op {
/*!
* \brief number of inputs to the operator,
* -1 means it is variable length
* When get_num_inputs is presented,
* the number will be decided by get_num_inputs instead.
* \sa get_num_inputs
*/
int num_inputs = 0;
uint32_t num_inputs = 1;
/*!
* \brief number of outputs of the operator
* -1 means it is variable length
* When get_num_outputs is presented.
* The number of outputs will be decided by
* get_num_outputs function
* \sa get_num_outputs
*/
int num_outputs = 1;
uint32_t num_outputs = 1;
/*!
* \brief get number of outputs given information about the node.
* This is only valid when num_outputs == -1.
* \param node The constructed node.
* \param attrs The attribute of the node
* \return number of outputs.
*/
int (*get_num_outputs)(const Node& node) = nullptr;
uint32_t (*get_num_outputs)(const NodeAttrs& attrs) = nullptr;
/*!
* \brief get number of inputs given information about the node.
* \param attrs The attribute of the node
* \return number of inputs
*/
uint32_t (*get_num_inputs)(const NodeAttrs& attrs) = nullptr;
/*!
* \brief Attribute parser to parse the NodeAttrs information.
*
Expand Down Expand Up @@ -143,19 +152,25 @@ class Op {
* \param n The number of inputs to be set.
* \return reference to self.
*/
inline Op& set_num_inputs(int n); // NOLINT(*)
inline Op& set_num_inputs(uint32_t n); // NOLINT(*)
/*!
* \brief Set the get_num_outputs function.
* \param fn The function to be set.
* \return reference to self.
*/
inline Op& set_num_inputs(uint32_t (*fn)(const NodeAttrs& attr)); // NOLINT(*)
/*!
* \brief Set the num_outputs
* \param n The number of outputs to be set.
* \return reference to self.
*/
inline Op& set_num_outputs(int n); // NOLINT(*)
inline Op& set_num_outputs(uint32_t n); // NOLINT(*)
/*!
* \brief Set the get_num_outputs function.
* \param fn The function to be set.
* \return reference to self.
*/
inline Op& set_num_outputs(int (*fn)(const Node& node)); // NOLINT(*)
inline Op& set_num_outputs(uint32_t (*fn)(const NodeAttrs& attr)); // NOLINT(*)
/*!
* \brief Set the attr_parser function.
* \param fn The number of outputs to be set.
Expand All @@ -180,6 +195,7 @@ class Op {
static const Op* Get(const std::string& op_name);
/*!
* \brief Get additional registered attribute about operators.
* If nothing has been registered, an empty OpMap will be returned.
* \param attr_name The name of the attribute.
* \return An OpMap of specified attr_name.
* \tparam ValueType The type of the attribute.
Expand All @@ -197,7 +213,7 @@ class Op {
// internal constructor
Op();
// get const reference to certain attribute
static const any& GetAttrMap(const std::string& key);
static const any* GetAttrMap(const std::string& key);
// update the attribute OpMap
static void UpdateAttrMap(const std::string& key,
std::function<void(any*)> updater);
Expand All @@ -217,6 +233,13 @@ class OpMap {
* \return the const reference to the content value.
*/
inline const ValueType& operator[](const Op* op) const;
/*!
* \brief get the corresponding value element at op with default value.
* \param op The key to the map
* \param def_value The default value when the key does not exist.
* \return the const reference to the content value.
*/
inline const ValueType& get(const Op* op, const ValueType& def_value) const;
/*!
* \brief Check if the map has op as key.
* \param op The key to the map
Expand Down Expand Up @@ -262,8 +285,18 @@ class OpMap {
// member function of Op
template<typename ValueType>
inline const OpMap<ValueType>& Op::GetAttr(const std::string& key) {
const any& ref = GetAttrMap(key);
return nnvm::get<OpMap<ValueType> >(ref);
const any* ref = GetAttrMap(key);
if (ref == nullptr) {
UpdateAttrMap(key, [key](any* pmap) {
if (pmap->empty()) {
OpMap<ValueType> pm;
pm.attr_name_ = key;
*pmap = std::move(pm);
}
});
ref = GetAttrMap(key);
}
return nnvm::get<OpMap<ValueType> >(*ref);
}

template<typename ValueType>
Expand All @@ -273,7 +306,7 @@ inline Op& Op::attr( // NOLINT(*)
if (pmap->empty()) {
OpMap<ValueType> pm;
pm.attr_name_ = attr_name;
*pmap = pm;
*pmap = std::move(pm);
}
CHECK_EQ(pmap->type(), typeid(OpMap<ValueType>))
<< "Attribute " << attr_name
Expand Down Expand Up @@ -301,18 +334,22 @@ inline Op& Op::describe(const std::string& descr) { // NOLINT(*)
return *this;
}

inline Op& Op::set_num_inputs(int n) { // NOLINT(*)
inline Op& Op::set_num_inputs(uint32_t n) { // NOLINT(*)
this->num_inputs = n;
return *this;
}

inline Op& Op::set_num_outputs(int n) { // NOLINT(*)
inline Op& Op::set_num_inputs(uint32_t (*fn)(const NodeAttrs&)) { // NOLINT(*)
this->get_num_inputs = fn;
return *this;
}

inline Op& Op::set_num_outputs(uint32_t n) { // NOLINT(*)
this->num_outputs = n;
return *this;
}

inline Op& Op::set_num_outputs(int (*fn)(const Node& node)) { // NOLINT(*)
this->num_outputs = kVarg;
inline Op& Op::set_num_outputs(uint32_t (*fn)(const NodeAttrs&)) { // NOLINT(*)
this->get_num_outputs = fn;
return *this;
}
Expand All @@ -338,6 +375,16 @@ inline const ValueType& OpMap<ValueType>::operator[](const Op* op) const {
return data_[idx].first;
}

template<typename ValueType>
inline const ValueType& OpMap<ValueType>::get(const Op* op, const ValueType& def_value) const {
const uint32_t idx = op->index_;
if (idx < data_.size() && data_[idx].second) {
return data_[idx].first;
} else {
return def_value;
}
}

} // namespace nnvm

#endif // NNVM_OP_H_
42 changes: 42 additions & 0 deletions nnvm/include/nnvm/op_attr_types.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
/*!
* Copyright (c) 2016 by Contributors
* \file op_attr_types.h
* \brief Data structures that can appear in operator attributes.
*/
#ifndef NNVM_OP_ATTR_TYPES_H_
#define NNVM_OP_ATTR_TYPES_H_

#include <vector>
#include <string>
#include <functional>

namespace nnvm {

// These types are optional attributes in each op
// Some of them are needed for certain pass.

/*!
* \brief Return list of input arguments names of each operator.
*
* \param attrs The attributes of the node.
* \return list of inputs
* \note Register under "FListInputNames", default return {"data"}.
*
* FListInputNames enables automatic variable creation for missing arguments.
*/
using FListInputNames = std::function<std::vector<std::string> (const NodeAttrs& attrs)>;

/*!
* \brief Return list of output arguments names of each operator.
*
* \param attrs The attributes of the node.
* \return list of inputs
* \note Register under "FListOutputNames", default return {"outputs"}.
*
* FListOutputNames customized naming for operator outputs.
*/
using FListOutputNames = std::function<std::vector<std::string> (const NodeAttrs& attrs)>;

} // namespace nnvm

#endif // NNVM_OP_ATTR_TYPES_H_
Loading

0 comments on commit aa36823

Please sign in to comment.