From bdbbdf2590783d5e6e54b6c13cb5a725e0b5c0db Mon Sep 17 00:00:00 2001
From: Tianqi Chen
Date: Mon, 8 Aug 2016 16:44:10 -0700
Subject: [PATCH] Update Symbol and C API (#22)

* Update tuple to be compatible with mshadow

* Move set error message to C API

* simplify with using

* updates to shape inference

* Add unnamed namespace to the implementations

* [SYMBOL] Enable inference of Auxiliary data, rename list_arguments to list_inputs
---
 nnvm/README.md                     |  12 +-
 nnvm/include/nnvm/base.h           |  34 +----
 nnvm/include/nnvm/c_api.h          |  27 ++--
 nnvm/include/nnvm/op.h             |   4 +
 nnvm/include/nnvm/pass_functions.h |   4 +-
 nnvm/include/nnvm/symbolic.h       |  20 ++-
 nnvm/include/nnvm/tuple.h          | 207 ++++++++++++++++++++++++++---
 nnvm/python/nnvm/symbol.py         |  25 +++-
 nnvm/src/c_api/c_api_common.h      |   5 -
 nnvm/src/c_api/c_api_symbolic.cc   |  21 +--
 nnvm/src/core/pass.cc              |   1 +
 nnvm/src/core/symbolic.cc          |  38 +++++-
 nnvm/src/pass/infer_shape_type.cc  |  20 +-
 nnvm/src/pass/order_mutation.cc    |   2 +
 nnvm/src/pass/place_device.cc      |   2 +
 nnvm/src/pass/plan_memory.cc       |   4 +-
 nnvm/src/pass/saveload_json.cc     |  10 +-
 nnvm/tests/python/test_graph.py    |  11 ++
 nnvm/tests/python/test_symbol.py   |   4 +-
 19 files changed, 346 insertions(+), 105 deletions(-)

diff --git a/nnvm/README.md b/nnvm/README.md
index 7acbf88093a8..1a9757b40153 100644
--- a/nnvm/README.md
+++ b/nnvm/README.md
@@ -1,7 +1,7 @@
 # NNVM: Build deep learning system by parts
 
-NNVM is not a deep learning library. It is a modular, decentralized and lightweight library to
-help build deep learning libraries efficiently.
+NNVM is not a deep learning library. It is a modular, decentralized and lightweight part to
+help build deep learning libraries.
 
 ## What is it
 
@@ -9,14 +9,14 @@ While most deep learning systems offer end to end solutions,
 it is interesting to ask if we can actually assemble a deep learning system by parts.
 The goal is to enable hackers to customize optimizations, target platforms and set of operators they care about.
 We believe that the decentralized modular system is an interesting direction.
+
 The hope is that effective parts can be assembled together just like you assemble your own desktops.
 So the customized deep learning solution can be minimax, minimum in terms of dependencies,
 while maximizing the users' need.
 
-NNVM offers one such part, it provides a generic to do generic
-computation graph optimization such as memory reduction, device allocation,
-operator fusion while being agnostic to the operator
-interface defintion and how operators are executed.
+NNVM offers one such part: it provides a generic way to do
+computation graph optimization such as memory reduction, device allocation and more,
+while being agnostic to the operator interface definition and how operators are executed.
 
 NNVM is inspired by LLVM, aiming to be an intermediate representation library
 for neural nets and computation graphs generation and optimizations.

diff --git a/nnvm/include/nnvm/base.h b/nnvm/include/nnvm/base.h
index 31e53dc2fa2f..94fa35ce3048 100644
--- a/nnvm/include/nnvm/base.h
+++ b/nnvm/include/nnvm/base.h
@@ -16,37 +16,13 @@
 namespace nnvm {
 
 /*! \brief any type */
-using any = dmlc::any;
+using dmlc::any;
 
-/*!
- * \brief array_veiw type
- * \tparam ValueType The value content of array view.
- */
-template<typename ValueType>
-using array_view = dmlc::array_view<ValueType>;
-
-/*!
- * \brief get reference of type T stored in src.
- * \param src The source container
- * \return the reference to the type.
- * \tparam T The type to be fetched.
- */
-template<typename T>
-inline T& get(any& src) {  // NOLINT(*)
-  return dmlc::get<T>(src);
-}
-
-/*!
- * \brief get const reference of type T stored in src.
- * \param src The source container
- * \return the reference to the type.
- * \tparam T The type to be fetched.
- */
+/*! \brief array_view type */
+using dmlc::array_view;
 
-template<typename T>
-inline const T& get(const any& src) {
-  return dmlc::get<T>(src);
-}
+/*! \brief getter function of any type */
+using dmlc::get;
 
 }  // namespace nnvm

diff --git a/nnvm/include/nnvm/c_api.h b/nnvm/include/nnvm/c_api.h
index 3b352e3d8df6..1c6943f9c681 100644
--- a/nnvm/include/nnvm/c_api.h
+++ b/nnvm/include/nnvm/c_api.h
@@ -35,6 +35,12 @@ typedef void *SymbolHandle;
 /*! \brief handle to Graph */
 typedef void *GraphHandle;
 
+/*!
+ * \brief Set the last error message needed by C API
+ * \param msg The error message to set.
+ */
+NNVM_DLL void NNAPISetLastError(const char* msg);
+
 /*!
  * \brief return str message of the last error
  *  all function in this file will return 0 when success
@@ -171,25 +177,30 @@ NNVM_DLL int NNSymbolListAttrs(SymbolHandle symbol,
                                nn_uint *out_size,
                                const char*** out);
 /*!
- * \brief List arguments in the symbol.
+ * \brief List inputs in the symbol.
  * \param symbol the symbol
+ * \param option The option to list the inputs
+ *   option=0 means list all arguments.
+ *   option=1 means list arguments that are only read by the graph.
+ *   option=2 means list arguments that are mutated by the graph.
  * \param out_size output size
  * \param out_str_array pointer to hold the output string array
  * \return 0 when success, -1 when failure happens
  */
-NNVM_DLL int NNSymbolListArguments(SymbolHandle symbol,
-                                   nn_uint *out_size,
-                                   const char ***out_str_array);
+NNVM_DLL int NNSymbolListInputNames(SymbolHandle symbol,
+                                    int option,
+                                    nn_uint *out_size,
+                                    const char ***out_str_array);
 /*!
- * \brief List returns in the symbol.
+ * \brief List output names in the symbol.
  * \param symbol the symbol
 * \param out_size output size
 * \param out_str_array pointer to hold the output string array
 * \return 0 when success, -1 when failure happens
 */
-NNVM_DLL int NNSymbolListOutputs(SymbolHandle symbol,
-                                 nn_uint *out_size,
-                                 const char ***out_str_array);
+NNVM_DLL int NNSymbolListOutputNames(SymbolHandle symbol,
+                                     nn_uint *out_size,
+                                     const char ***out_str_array);
 /*!
  * \brief Get a symbol that contains all the internals.
  * \param symbol The symbol
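To make the renamed entry points concrete, here is a minimal sketch of driving the C API from client code. It is illustrative and not part of the patch: print_inputs is a hypothetical helper, and the symbol handle is assumed to come from the usual creator/compose flow; NNGetLastError is the companion getter declared alongside NNAPISetLastError.

    #include <cstdio>
    #include <nnvm/c_api.h>

    // Hypothetical helper, not part of the patch.
    // option: 0 = all inputs, 1 = read-only arguments, 2 = mutated aux states.
    void print_inputs(SymbolHandle sym, int option) {
      nn_uint size = 0;
      const char **names = nullptr;
      if (NNSymbolListInputNames(sym, option, &size, &names) != 0) {
        // every call in this header returns 0 on success; fetch the message on failure
        fprintf(stderr, "%s\n", NNGetLastError());
        return;
      }
      for (nn_uint i = 0; i < size; ++i) {
        printf("%s\n", names[i]);
      }
    }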
diff --git a/nnvm/include/nnvm/op.h b/nnvm/include/nnvm/op.h
index 5f499b7377e1..e32d56642a73 100644
--- a/nnvm/include/nnvm/op.h
+++ b/nnvm/include/nnvm/op.h
@@ -289,7 +289,9 @@ template<typename ValueType>
 inline const OpMap<ValueType>& Op::GetAttr(const std::string& key) {
   const any* ref = GetAttrMap(key);
   if (ref == nullptr) {
+    // update the attribute map of the key by creating new empty OpMap
     UpdateAttrMap(key, [key](any* pmap) {
+        // use callback so it is within the lock scope
       if (pmap->empty()) {
         OpMap<ValueType> pm;
         pm.attr_name_ = key;
@@ -304,7 +306,9 @@ inline const OpMap<ValueType>& Op::GetAttr(const std::string& key) {
 template<typename ValueType>
 inline Op& Op::attr(  // NOLINT(*)
     const std::string& attr_name, const ValueType& value) {
+  // update the attribute map of the key by creating a new empty one if needed.
   UpdateAttrMap(attr_name, [this, attr_name, value](any* pmap) {
+      // the callback is in lock scope so it is thread-safe.
       if (pmap->empty()) {
         OpMap<ValueType> pm;
         pm.attr_name_ = attr_name;
diff --git a/nnvm/include/nnvm/pass_functions.h b/nnvm/include/nnvm/pass_functions.h
index d622ce47dd55..742c95d1a1cf 100644
--- a/nnvm/include/nnvm/pass_functions.h
+++ b/nnvm/include/nnvm/pass_functions.h
@@ -83,10 +83,10 @@ inline Graph InferType(Graph graph,
                        DTypeVector type_args = {},
                        std::string type_attr_key = "") {
   if (type_args.size() != 0) {
-    graph.attrs["type_args"] = std::make_shared<any>(std::move(type_args));
+    graph.attrs["dtype_args"] = std::make_shared<any>(std::move(type_args));
   }
   if (type_attr_key.length() != 0) {
-    graph.attrs["type_attr_key"] = std::make_shared<any>(std::move(type_attr_key));
+    graph.attrs["dtype_attr_key"] = std::make_shared<any>(std::move(type_attr_key));
   }
   return ApplyPass(std::move(graph), {"InferType"});
 }
diff --git a/nnvm/include/nnvm/symbolic.h b/nnvm/include/nnvm/symbolic.h
index c14534f3c524..8bca4cb3103c 100644
--- a/nnvm/include/nnvm/symbolic.h
+++ b/nnvm/include/nnvm/symbolic.h
@@ -30,6 +30,18 @@ class Symbol {
     /*! \brief only list attributes in current node */
     kShallow = 1
   };
+  /*! \brief option passed to ListInputNames */
+  enum ListInputOption {
+    /*! \brief list all the arguments */
+    kAll = 0,
+    /*! \brief list only read only arguments */
+    kReadOnlyArgs = 1,
+    /*!
+     * \brief List auxiliary states that can be mutated by the graph.
+     *  This excludes the read only arguments.
+     */
+    kAuxiliaryStates = 2
+  };
 
   /*! \brief output entries contained in the symbol */
   std::vector<NodeEntry> outputs;
@@ -51,18 +63,20 @@
    */
   Symbol operator[] (size_t index) const;
   /*!
-   * \brief List the arguments names.
+   * \brief List the input names.
+   * \param option The options to list the arguments.
    *
    * The position of the returned list also corresponds to calling position in operator()
    * \return the arguments list of this symbol, they can be either named or unnamed (empty string).
+   * \sa ListInputOption
    */
-  std::vector<std::string> ListArguments() const;
+  std::vector<std::string> ListInputNames(ListInputOption option) const;
   /*!
   * \brief List the names of outputs for this symbol.
   *  For normal operators, it is usually symbol node name + "_output"
   * \return get the descriptions of outputs for this symbol.
   */
-  std::vector<std::string> ListOutputs() const;
+  std::vector<std::string> ListOutputNames() const;
   /*!
    * \brief Compose the symbol with arguments, this changes the current symbol.
    * The kwargs passed in can be in-complete,
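At the C++ level the same split is exposed through Symbol::ListInputNames. A minimal sketch, assuming a symbol already composed from registered operators (the helper below is illustrative, not part of the patch):

    #include <nnvm/symbolic.h>
    #include <iostream>
    #include <string>

    // Illustrative helper: print how a composed symbol's inputs split into
    // read-only arguments and mutated auxiliary states.
    void DumpInputs(const nnvm::Symbol& s) {
      using nnvm::Symbol;
      for (const std::string& n : s.ListInputNames(Symbol::kAll))
        std::cout << "input:     " << n << '\n';
      for (const std::string& n : s.ListInputNames(Symbol::kReadOnlyArgs))
        std::cout << "read-only: " << n << '\n';
      for (const std::string& n : s.ListInputNames(Symbol::kAuxiliaryStates))
        std::cout << "aux state: " << n << '\n';
    }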
diff --git a/nnvm/include/nnvm/tuple.h b/nnvm/include/nnvm/tuple.h
index dbae458773ac..fefb7ce5739d 100644
--- a/nnvm/include/nnvm/tuple.h
+++ b/nnvm/include/nnvm/tuple.h
@@ -58,17 +58,9 @@ class Tuple {
    * \brief move constructor from Tuple
    * \param src the source shape
    */
-  inline Tuple(Tuple<ValueType>&& src) {
-    this->swap(src);
-  }
-  /*!
-   * \param ndim the number of dimension of the Tuple
-   * \param v The value to fill.
-   */
-  inline Tuple(index_t ndim, ValueType v) {
-    this->SetDim(ndim);
-    std::fill_n(begin(), ndim, v);
+  inline Tuple(Tuple<ValueType>&& src) {  // NOLINT(*)
+    this->swap(src);
   }
   /*!
    * \brief construct the Tuple from content of iterator
@@ -97,7 +89,7 @@ class Tuple {
    * \brief Swap current object with other
    * \param other another object to be swapped.
    */
   inline void swap(Tuple<ValueType>& other) {  // NOLINT(*)
     std::swap(ndim_, other.ndim_);
     std::swap(num_heap_allocated_, other.num_heap_allocated_);
     std::swap(data_stack_, other.data_stack_);
@@ -275,7 +267,7 @@ class Tuple {
     return is;
   }
 
- private:
+ protected:
   // stack cache size
   static const uint32_t kStackCache = 4;
   /*! \brief number of dimension of the tuple */
@@ -303,16 +295,30 @@
  */
 class TShape : public Tuple<index_t> {
  public:
-  // inheritate other constructors from Tuple
-  using Tuple<index_t>::Tuple;
   /*! \brief default constructor */
   TShape() = default;
+  /*!
+   * \brief constructor to construct a shape with all dimensions set to 1.
+   * \param ndim the number of dimension
+   */
+  inline TShape(index_t ndim) {  // NOLINT(*)
+    this->SetDim(ndim);
+    std::fill_n(begin(), ndim, 1);
+  }
   /*!
    * \brief copy constructor of TShape
    * \param s source shape.
    */
-  inline TShape(const Tuple<index_t>& s)  // NOLINT(*)
-      : Tuple<index_t>(s) {}
+  inline TShape(const Tuple<index_t>& s) {  // NOLINT(*)
+    this->assign(s.begin(), s.end());
+  }
+  /*!
+   * \brief constructor from initializer list
+   * \param init the initializer_list
+   */
+  inline TShape(std::initializer_list<index_t> init) {
+    this->assign(init.begin(), init.end());
+  }
   /*!
    * \brief move constructor.
    * \param s source shape.
@@ -320,6 +326,17 @@
   inline TShape(Tuple<index_t>&& s) {  // NOLINT(*)
     this->swap(s);
   }
+  /*!
+   * \brief construct the Tuple from content of iterator
+   * \param begin the beginning of iterator
+   * \param end the end of the iterator
+   * \tparam RandomAccessIterator iterator type
+   */
+  template<typename RandomAccessIterator>
+  inline TShape(RandomAccessIterator begin,
+                RandomAccessIterator end) {
+    this->assign(begin, end);
+  }
   /*!
    * \brief assignment function from tshape
    * \param src source shape.
@@ -347,6 +364,164 @@
     }
     return size;
   }
+  /*!
+   * \return product shape in [dimstart,dimend)
+   * \param dimstart start dimension
+   * \param dimend end dimension
+   */
+  inline index_t ProdShape(int dimstart, int dimend) const {
+    index_t num = 1;
+    const index_t *d = this->data();
+    for (int i = dimstart; i < dimend; ++i) {
+      num *= d[i];
+    }
+    return num;
+  }
+  /*! \return the begin data pointer to content of the tuple */
+  inline const index_t *data() const {
+    return begin();
+  }
+  /*! \return the begin data pointer to content of the tuple */
+  inline index_t *data() {
+    return begin();
+  }
+#ifdef MSHADOW_XINLINE
+  template<int dim>
+  inline TShape(mshadow::Shape<dim> &&s) {  // NOLINT(*)
+    this->assign(s.shape_, s.shape_ + dim);
+  }
+  /*!
+   * \brief assignment from shape
+   * \param shape source shape
+   * \tparam dim shape dimension
+   * \return reference of self
+   */
+  template<int dim>
+  inline TShape &operator=(const mshadow::Shape<dim> &shape) {
+    this->assign(shape.shape_, shape.shape_ + dim);
+    return *this;
+  }
+  /*!
+   * \brief get the shape of tensor specifying dim
+   * \return the shape requested
+   * \tparam dim dimension of the tensor
+   */
+  template<int dim>
+  inline mshadow::Shape<dim> get() const {
+    CHECK_EQ(dim, ndim())
+        << "dimension does not match target dimension " << dim << " vs " << ndim();
+    const index_t *d = this->data();
+    mshadow::Shape<dim> s;
+    for (int i = 0; i < dim; ++i) {
+      s[i] = d[i];
+    }
+    return s;
+  }
+  /*!
+   * flatten the higher dimension to second dimension, return a 2D shape
+   * \return the flat 2d shape
+   */
+  inline mshadow::Shape<2> FlatTo2D(void) const {
+    mshadow::Shape<2> s;
+    if (ndim() == 0) return mshadow::Shape2(0, 0);
+    const index_t *d = this->data();
+    s.shape_[1] = d[ndim() - 1];
+    index_t ymax = 1;
+    for (index_t i = 1; i < ndim(); ++i) {
+      ymax *= d[i - 1];
+    }
+    s.shape_[0] = ymax;
+    return s;
+  }
+  /*!
+   * flatten the shape into three parts: [0, axis_begin), [axis_begin, axis_end], (axis_end, ndim)
+   * \param axis_begin The beginning axis specified.
+   * \param axis_end The ending axis specified.
+   * \return the flat 3d shape
+   */
+  inline mshadow::Shape<3> FlatTo3D(index_t axis_begin, index_t axis_end) const {
+    CHECK(axis_end >= axis_begin);
+    mshadow::Shape<3> s;
+    if (ndim() == 0) return mshadow::Shape3(0, 0, 0);
+    const index_t *d = this->data();
+    s.shape_[0] = 1;
+    s.shape_[1] = 1;
+    s.shape_[2] = 1;
+
+    for (index_t i = 0; i < axis_begin; ++i) {
+      s.shape_[0] *= d[i];
+    }
+    for (index_t i = axis_begin; i <= axis_end; ++i) {
+      s.shape_[1] *= d[i];
+    }
+    for (index_t i = axis_end + 1; i < ndim(); ++i) {
+      s.shape_[2] *= d[i];
+    }
+    return s;
+  }
+  /*!
+   * flatten the axis before and after the specified axis, so it becomes 3D tensor
+   * \param axis The axis specified.
+   * \return the flat 3d shape
+   */
+  inline mshadow::Shape<3> FlatTo3D(index_t axis) const {
+    return FlatTo3D(axis, axis);
+  }
+  inline bool operator==(const TShape &s) const {
+    if (ndim() != s.ndim()) return false;
+    return std::equal(begin(), end(), s.begin());
+  }
+  inline bool operator!=(const TShape &s) const {
+    return !(*this == s);
+  }
+  /*!
+   * \return whether two shape equals
+   * \param s the shape to compare against
+   * \tparam dim dimension of the shape
+   */
+  template<int dim>
+  inline bool operator==(const mshadow::Shape<dim> &s) const {
+    if (ndim_ != dim) return false;
+    const index_t *d = dim <= kStackCache ? data_stack_ : data_heap_;
+    for (index_t i = 0; i < dim; ++i) {
+      if (d[i] != s.shape_[i]) return false;
+    }
+    return true;
+  }
+  /*!
+   * \return whether two shape not equals
+   * \param s the shape to compare against
+   * \tparam dim dimension of the shape
+   */
+  template<int dim>
+  inline bool operator!=(const mshadow::Shape<dim> &s) const {
+    return !(*this == s);
+  }
+  /*!
+   * \brief save the content into binary stream
+   * \param strm the output stream
+   * \tparam TStream any stream type that has write
+   */
+  template<typename TStream>
+  inline void Save(TStream *strm) const {
+    strm->Write(&ndim_, sizeof(ndim_));
+    strm->Write(data(), sizeof(index_t) * ndim_);
+  }
+  /*!
+   * \brief load the content from binary stream
+   * \param strm the input stream
+   * \tparam TStream any stream type that has read
+   * \return whether the load is successful
+   */
+  template<typename TStream>
+  inline bool Load(TStream *strm) {
+    if (strm->Read(&ndim_, sizeof(ndim_)) != sizeof(ndim_)) return false;
+    this->SetDim(ndim_);
+    size_t nread = sizeof(index_t) * ndim_;
+    if (strm->Read(data(), nread) != nread) return false;
+    return true;
+  }
+#endif
 };
 
 }  // namespace nnvm
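The expanded TShape surface can be exercised as below; a sketch assuming nnvm is built alongside mshadow so the MSHADOW_XINLINE section is compiled in (the values are illustrative):

    #include <nnvm/tuple.h>
    #include <iostream>

    int main() {
      nnvm::TShape s = {2, 3, 4};             // initializer_list constructor
      std::cout << s.Size() << '\n';          // 24: product of all dimensions
      std::cout << s.ProdShape(1, 3) << '\n'; // 12: product over [1, 3)

      nnvm::TShape ones(3);                   // (1, 1, 1): the all-one constructor
      // with mshadow available, shapes convert both ways:
      // mshadow::Shape<2> flat = s.FlatTo2D();  // (6, 4)
      // s = mshadow::Shape3(2, 3, 4);
      return 0;
    }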
diff --git a/nnvm/python/nnvm/symbol.py b/nnvm/python/nnvm/symbol.py
index 31f1660fc899..e753299c7433 100644
--- a/nnvm/python/nnvm/symbol.py
+++ b/nnvm/python/nnvm/symbol.py
@@ -176,9 +176,16 @@ def get_internals(self):
             self.handle, _ctypes.byref(handle)))
         return Symbol(handle=handle)
 
-    def list_arguments(self):
-        """List all the arguments in the symbol.
+    def list_inputs(self, option='all'):
+        """List all the inputs in the symbol.
 
+        Parameters
+        ----------
+        option : {'all', 'read_only', 'aux_state'}, optional
+           The listing option
+           - 'all' will list all the arguments.
+           - 'read_only' lists arguments that are only read by the graph.
+           - 'aux_state' lists arguments that are mutated by the graph as state.
         Returns
         -------
         args : list of string
         """
         size = _ctypes.c_uint()
         sarr = _ctypes.POINTER(_ctypes.c_char_p)()
-        _check_call(_LIB.NNSymbolListArguments(
-            self.handle, _ctypes.byref(size), _ctypes.byref(sarr)))
+        if option == 'all':
+            copt = _ctypes.c_int(0)
+        elif option == 'read_only':
+            copt = _ctypes.c_int(1)
+        elif option == 'aux_state':
+            copt = _ctypes.c_int(2)
+        else:
+            raise ValueError("option needs to be in {'all', 'read_only', 'aux_state'}")
+        _check_call(_LIB.NNSymbolListInputNames(
+            self.handle, copt, _ctypes.byref(size), _ctypes.byref(sarr)))
         return [_base.py_str(sarr[i]) for i in range(size.value)]
 
     def list_outputs(self):
@@ -200,7 +215,7 @@ def list_outputs(self):
         """
         size = _ctypes.c_uint()
         sarr = _ctypes.POINTER(_ctypes.c_char_p)()
-        _check_call(_LIB.NNSymbolListOutputs(
+        _check_call(_LIB.NNSymbolListOutputNames(
             self.handle, _ctypes.byref(size), _ctypes.byref(sarr)))
         return [_base.py_str(sarr[i]) for i in range(size.value)]
diff --git a/nnvm/src/c_api/c_api_common.h b/nnvm/src/c_api/c_api_common.h
index 40c91d8ffc11..d1d6c4316f79 100644
--- a/nnvm/src/c_api/c_api_common.h
+++ b/nnvm/src/c_api/c_api_common.h
@@ -44,11 +44,6 @@ struct NNAPIThreadLocalEntry {
 /*! \brief Thread local store that can be used to hold return values. */
 typedef dmlc::ThreadLocalStore<NNAPIThreadLocalEntry> NNAPIThreadLocalStore;
 
-/*!
- * \brief Set the last error message needed by C API
- * \param msg The error message to set.
- */
-void NNAPISetLastError(const char* msg);
 /*!
  * \brief handle exception throwed out
  * \param e the exception
diff --git a/nnvm/src/c_api/c_api_symbolic.cc b/nnvm/src/c_api/c_api_symbolic.cc
index bc6eed5c742a..aabfca4795d2 100644
--- a/nnvm/src/c_api/c_api_symbolic.cc
+++ b/nnvm/src/c_api/c_api_symbolic.cc
@@ -19,7 +19,6 @@ int NNSymbolListAtomicSymbolCreators(nn_uint *out_size,
   API_END();
 }
 
-
 int NNSymbolGetAtomicSymbolInfo(AtomicSymbolCreator creator,
                                 const char **name,
                                 const char **description,
@@ -37,7 +36,6 @@ int NNSymbolGetAtomicSymbolInfo(AtomicSymbolCreator creator,
   API_END();
 }
 
-
 int NNSymbolCreateAtomicSymbol(AtomicSymbolCreator creator,
                                nn_uint num_param,
                                const char **keys,
@@ -179,13 +177,15 @@ int NNSymbolListAttrs(SymbolHandle symbol,
   API_END();
 }
 
-int NNSymbolListArguments(SymbolHandle symbol,
-                          nn_uint *out_size,
-                          const char ***out_str_array) {
+int NNSymbolListInputNames(SymbolHandle symbol,
+                           int option,
+                           nn_uint *out_size,
+                           const char ***out_str_array) {
   Symbol *s = static_cast<Symbol*>(symbol);
   NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get();
   API_BEGIN();
-  ret->ret_vec_str = std::move(s->ListArguments());
+  ret->ret_vec_str = std::move(
+      s->ListInputNames(Symbol::ListInputOption(option)));
   ret->ret_vec_charp.clear();
   for (size_t i = 0; i < ret->ret_vec_str.size(); ++i) {
     ret->ret_vec_charp.push_back(ret->ret_vec_str[i].c_str());
@@ -195,13 +195,13 @@ int NNSymbolListArguments(SymbolHandle symbol,
   API_END();
 }
 
-int NNSymbolListOutputs(SymbolHandle symbol,
-                        nn_uint *out_size,
-                        const char ***out_str_array) {
+int NNSymbolListOutputNames(SymbolHandle symbol,
+                            nn_uint *out_size,
+                            const char ***out_str_array) {
   Symbol *s = static_cast<Symbol*>(symbol);
   NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get();
   API_BEGIN();
-  ret->ret_vec_str = std::move(s->ListOutputs());
+  ret->ret_vec_str = std::move(s->ListOutputNames());
   ret->ret_vec_charp.clear();
   for (size_t i = 0; i < ret->ret_vec_str.size(); ++i) {
     ret->ret_vec_charp.push_back(ret->ret_vec_str[i].c_str());
@@ -221,6 +221,7 @@ int NNSymbolCompose(SymbolHandle sym,
   std::string& s_name = ret->ret_str;
   std::unordered_map<std::string, const Symbol*>& kwargs = ret->kwarg_symbol;
+  kwargs.clear();
   if (name != nullptr) {
     s_name = name;
   } else {
diff --git a/nnvm/src/core/pass.cc b/nnvm/src/core/pass.cc
index f58c8039989b..5c4aeb2e0232 100644
--- a/nnvm/src/core/pass.cc
+++ b/nnvm/src/core/pass.cc
@@ -48,6 +48,7 @@ Graph ApplyPass(Graph g,
     }
     g = r->body(std::move(g));
   }
+  return g;
 }
diff --git a/nnvm/src/core/symbolic.cc b/nnvm/src/core/symbolic.cc
index 8c90ccfd6bf7..1a04bf40b05a 100644
--- a/nnvm/src/core/symbolic.cc
+++ b/nnvm/src/core/symbolic.cc
@@ -181,17 +181,41 @@ Symbol Symbol::operator[] (size_t index) const {
   }
 }
 
-std::vector<std::string> Symbol::ListArguments() const {
+std::vector<std::string> Symbol::ListInputNames(ListInputOption option) const {
   std::vector<std::string> ret;
-  DFSVisit(this->outputs, [&ret](const NodePtr &node) {
-    if (node->is_variable()) {
+  if (option == kAll) {
+    DFSVisit(this->outputs, [&ret](const NodePtr &node) {
+      if (node->is_variable()) {
+        ret.push_back(node->attrs.name);
+      }
+    });
+  } else {
+    std::unordered_set<Node*> mutable_set;
+    std::vector<Node*> vlist;
+    static auto& fmutate_inputs = Op::GetAttr<FMutateInput>("FMutateInput");
+    DFSVisit(this->outputs, [&ret, &mutable_set, &vlist](const NodePtr &node) {
+      if (node->is_variable()) {
+        vlist.push_back(node.get());
+      } else if (fmutate_inputs.count(node->op)) {
+        FMutateInput fmutate = fmutate_inputs[node->op];
+        for (uint32_t i = 0; i < node->inputs.size(); ++i) {
+          if (fmutate(node->attrs, i)) {
+            mutable_set.insert(node->inputs[i].node.get());
+          }
+        }
+      }
+    });
+    for (Node* node : vlist) {
+      if ((option == kReadOnlyArgs && mutable_set.count(node) == 0) ||
+          (option == kAuxiliaryStates && mutable_set.count(node) != 0)) {
         ret.push_back(node->attrs.name);
       }
-  });
+    }
+  }
   return ret;
 }
 
-std::vector<std::string> Symbol::ListOutputs() const {
+std::vector<std::string> Symbol::ListOutputNames() const {
   static auto& flist_ouputs = Op::GetAttr<FListOutputNames>("FListOutputNames");
   std::vector<std::string> ret;
   for (auto &head : outputs) {
@@ -345,10 +369,10 @@ void Symbol::Compose(const array_view<const Symbol*>& args,
       }
     } else {
       std::vector<std::string> keys = GetKeys(kwargs);
-      std::vector<std::string> arg_names = ListArguments();
+      std::vector<std::string> arg_names = ListInputNames(kAll);
       array_view<std::string> view(dmlc::BeginPtr(arg_names) + arg_counter,
                                    dmlc::BeginPtr(arg_names) + arg_names.size());
-      KeywordArgumentMismatch("Symbol.Compose", keys, ListArguments());
+      KeywordArgumentMismatch("Symbol.Compose", keys, arg_names);
     }
   }
 }
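The read-only/auxiliary split in ListInputNames is driven entirely by the FMutateInput operator attribute, assumed here to be the std::function<bool(const nnvm::NodeAttrs&, uint32_t)> used in the implementation above. A hypothetical registration sketch (the assign op and its semantics are placeholders, not part of the patch):

    #include <nnvm/op.h>
    #include <nnvm/op_attr_types.h>

    // Hypothetical op that mutates its first input. Variables fed into
    // input 0 are then reported under kAuxiliaryStates, the rest under
    // kReadOnlyArgs.
    NNVM_REGISTER_OP(assign)
    .set_num_inputs(2)
    .set_num_outputs(1)
    .attr<nnvm::FMutateInput>("FMutateInput",
        [](const nnvm::NodeAttrs& attrs, uint32_t index) {
          return index == 0;  // only the assigned-to variable is mutated
        });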
diff --git a/nnvm/src/pass/infer_shape_type.cc b/nnvm/src/pass/infer_shape_type.cc
index 9eb6c9bca704..25e62d563a3f 100644
--- a/nnvm/src/pass/infer_shape_type.cc
+++ b/nnvm/src/pass/infer_shape_type.cc
@@ -9,6 +9,7 @@
 namespace nnvm {
 namespace pass {
+namespace {
 
 template<typename AttrType, typename IsNone>
 Graph InferAttr(Graph &&ret,
                 const AttrType def_value,
                 const char* infer_name,
                 const char* arg_name,
                 const char* attr_key_name,
                 const char* attr_name,
-                const char* known_name,
+                const char* unknown_name,
                 IsNone fis_none) {
   using AttrVector = std::vector<AttrType>;
   const IndexedGraph& idx = ret.indexed_graph();
@@ -48,11 +49,11 @@ Graph InferAttr(Graph &&ret,
   // temp space for shape inference.
   std::vector<AttrType*> ishape, oshape;
   // number of completed nodes
-  size_t num_known = 0;
+  size_t num_unknown = 0;
   for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
     const auto& inode = idx[nid];
     if (inode.source->is_variable()) {
-      if (shape_attr_key.length() != 0) {
+      if (shape_attr_key.length() != 0 && fis_none(rshape[idx.entry_id(nid, 0)])) {
         auto it = inode.source->attrs.dict.find(shape_attr_key);
         if (it != inode.source->attrs.dict.end()) {
           CHECK_EQ(inode.source->num_outputs(), 1);
@@ -71,8 +72,8 @@ Graph InferAttr(Graph &&ret,
         oshape[i] = &rshape[idx.entry_id(nid, i)];
       }
       if (finfer_shape.count(inode.source->op)) {
-        num_known +=
-            finfer_shape[inode.source->op](inode.source->attrs, ishape, oshape);
+        num_unknown +=
+            !(finfer_shape[inode.source->op](inode.source->attrs, ishape, oshape));
       } else if (is_backward.get(inode.source->op, false)) {
         // backward operator inference.
         CHECK_GE(inode.control_deps.size(), 1)
@@ -85,13 +86,13 @@ Graph InferAttr(Graph &&ret,
           *oshape[i] = rshape[idx.entry_id(fnode.inputs[i])];
           if (fis_none(*oshape[i])) known = false;
         }
-        num_known += known;
+        num_unknown += !known;
       }
   }
   // set the shapes
   ret.attrs[attr_name] = std::make_shared<any>(std::move(rshape));
   // number of nodes who knows the shape.
-  ret.attrs[known_name] = std::make_shared<any>(num_known);
+  ret.attrs[unknown_name] = std::make_shared<any>(num_unknown);
   return ret;
 }
 
@@ -101,7 +102,7 @@ NNVM_REGISTER_PASS(InferShape)
     return InferAttr<TShape>(
         std::move(ret), TShape(),
         "FInferShape", "shape_args", "shape_attr_key",
-        "shape", "shape_num_known_nodes",
+        "shape", "shape_num_unknown_nodes",
         [](const TShape& s) { return s.ndim() == 0; });
   })
 .set_change_graph(false)
 .provide_graph_attr("shape");
@@ -113,7 +114,7 @@ NNVM_REGISTER_PASS(InferType)
     return InferAttr<int>(
         std::move(ret), 0,
         "FInferType", "dtype_args", "dtype_attr_key",
-        "dtype", "dtype_num_known_nodes",
+        "dtype", "dtype_num_unknown_nodes",
         [](const int t) { return t == -1; });
   })
 .set_change_graph(false)
@@ -123,5 +124,6 @@ DMLC_JSON_ENABLE_ANY(ShapeVector, list_shape);
 DMLC_JSON_ENABLE_ANY(DTypeVector, list_int);
 DMLC_JSON_ENABLE_ANY(size_t, size_t);
 
+}  // namespace
 }  // namespace pass
 }  // namespace nnvm
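With the counter inverted, pass consumers test for zero remaining unknowns instead of comparing a known-count against the node count. A sketch, assuming the InferShape helper from pass_functions.h (not part of the patch):

    #include <nnvm/graph.h>
    #include <nnvm/pass_functions.h>

    // Sketch: run shape inference and test whether it completed.
    bool ShapesFullyKnown(nnvm::Graph g, nnvm::ShapeVector in_shapes) {
      g = nnvm::pass::InferShape(std::move(g), std::move(in_shapes));
      // 0 means every entry of the "shape" attribute was inferred
      return g.GetAttr<size_t>("shape_num_unknown_nodes") == 0;
    }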
diff --git a/nnvm/src/pass/order_mutation.cc b/nnvm/src/pass/order_mutation.cc
index ce615fccaad4..775621982794 100644
--- a/nnvm/src/pass/order_mutation.cc
+++ b/nnvm/src/pass/order_mutation.cc
@@ -10,6 +10,7 @@
 namespace nnvm {
 namespace pass {
+namespace {
 
 template<typename T>
 inline T get_with_default(const std::unordered_map<Node*, T> &map,
@@ -140,5 +141,6 @@ NNVM_REGISTER_PASS(OrderMutation)
 .set_body(OrderMutation)
 .set_change_graph(true);
 
+}  // namespace
 }  // namespace pass
 }  // namespace nnvm
diff --git a/nnvm/src/pass/place_device.cc b/nnvm/src/pass/place_device.cc
index 787f47b80ca7..6a6e877a9f87 100644
--- a/nnvm/src/pass/place_device.cc
+++ b/nnvm/src/pass/place_device.cc
@@ -10,6 +10,7 @@
 namespace nnvm {
 namespace pass {
+namespace {
 
 // simply logic to place device according to device_group hint
 // insert copy node when there is
@@ -176,5 +177,6 @@ NNVM_REGISTER_PASS(PlaceDevice)
 
 DMLC_JSON_ENABLE_ANY(DeviceAssignMap, dict_str_int);
 
+}  // namespace
 }  // namespace pass
 }  // namespace nnvm
diff --git a/nnvm/src/pass/plan_memory.cc b/nnvm/src/pass/plan_memory.cc
index 3ba4aee16821..2d57b5c78f6b 100644
--- a/nnvm/src/pass/plan_memory.cc
+++ b/nnvm/src/pass/plan_memory.cc
@@ -12,6 +12,7 @@
 namespace nnvm {
 namespace pass {
+namespace {
 
 // simple graph based allocator.
 class GraphAllocator {
@@ -91,7 +92,7 @@ class GraphAllocator {
       if ((*idx_)[nid].source->is_variable()) continue;
       importance[nid] = 1;
     }
-    num_match_color_ = ColorNodeGroup(
+    num_match_color_ = pass::ColorNodeGroup(
         *idx_, importance, num_match_color_, &node_color_);
   }
 }
@@ -223,5 +224,6 @@ NNVM_REGISTER_PASS(PlanMemory)
 .depend_graph_attr("shape")
 .provide_graph_attr("storage_id");
 
+}  // namespace
 }  // namespace pass
 }  // namespace nnvm
diff --git a/nnvm/src/pass/saveload_json.cc b/nnvm/src/pass/saveload_json.cc
index 4a3c97e25b77..bd22d807e1fe 100644
--- a/nnvm/src/pass/saveload_json.cc
+++ b/nnvm/src/pass/saveload_json.cc
@@ -4,6 +4,7 @@
  * \brief Save and load graph to/from JSON file.
  */
 #include
+#include
 #include
 #include
@@ -26,6 +27,7 @@ struct Handler > {
 
 namespace nnvm {
 namespace pass {
+namespace {
 
 // auxiliary node structure for serialization.
 struct JSONNode {
@@ -35,7 +37,7 @@ struct JSONNode {
     uint32_t index;
     uint32_t version;
     void Save(dmlc::JSONWriter *writer) const {
-      writer->BeginArray();
+      writer->BeginArray(false);
       writer->WriteArrayItem(node_id);
      writer->WriteArrayItem(index);
       writer->WriteArrayItem(version);
@@ -74,7 +76,10 @@ struct JSONNode {
     }
     writer->WriteObjectKeyValue("name", node->attrs.name);
     if (node->attrs.dict.size() != 0) {
-      writer->WriteObjectKeyValue("attr", node->attrs.dict);
+      // write attributes in order
+      std::map<std::string, std::string> dict(
+          node->attrs.dict.begin(), node->attrs.dict.end());
+      writer->WriteObjectKeyValue("attr", dict);
     }
     writer->WriteObjectKeyValue("inputs", inputs);
     if (control_deps.size() != 0) {
@@ -247,5 +252,6 @@ NNVM_REGISTER_PASS(SaveJSON)
 
 DMLC_JSON_ENABLE_ANY(std::string, str);
 DMLC_JSON_ENABLE_ANY(std::vector<int>, list_int);
 
+}  // namespace
 }  // namespace pass
 }  // namespace nnvm
diff --git a/nnvm/tests/python/test_graph.py b/nnvm/tests/python/test_graph.py
index 53714b1ae018..b958cee10be0 100644
--- a/nnvm/tests/python/test_graph.py
+++ b/nnvm/tests/python/test_graph.py
@@ -35,6 +35,16 @@ def test_order_mutation_pass():
     assert nindex['add1'] in jnodes[nindex['assign']]['control_deps']
     assert jnodes[nindex['assign']]['inputs'][0][2] == 1
 
+def test_list_args():
+    x = sym.Variable('x')
+    z = sym.Variable('z')
+    y = sym.conv2d(data=x, name='conv', dev='gpu')
+    y = sym.add(y, z, name='add1')
+    # write after read
+    z = sym.assign(x, y, name='assign')
+    assert z.list_inputs('read_only') == ['conv_weight', 'z']
+    assert z.list_inputs('aux_state') == ['x']
+
 def test_infer_shape():
     x = sym.Variable('x', shape=(4, 2))
     y = sym.add(x, x, name='add1')
@@ -109,3 +119,4 @@ def test_plan_memory():
     test_infer_type()
     test_place_device()
     test_plan_memory()
+    test_list_args()
diff --git a/nnvm/tests/python/test_symbol.py b/nnvm/tests/python/test_symbol.py
index adc9099adc13..a754f0f60fde 100644
--- a/nnvm/tests/python/test_symbol.py
+++ b/nnvm/tests/python/test_symbol.py
@@ -7,7 +7,7 @@ def test_compose():
     y = sym.exp(sym.add(x, x, name='add', gpu=2),
                 name='exp', gpu=1, attr={"kk": "1"})
 
-    assert y.list_arguments() == ['x']
+    assert y.list_inputs() == ['x']
     assert y.list_outputs() == ["exp_output"]
     assert y.list_attr()['gpu'] == '1'
     z = y.get_internals()
@@ -17,7 +17,7 @@ def test_compose():
 def test_default_input():
     x = sym.Variable('x')
     y = sym.conv2d(data=x, name='conv')
-    assert y.list_arguments() == ['x', 'conv_weight']
+    assert y.list_inputs() == ['x', 'conv_weight']
     try:
         z = sym.add(x)
         assert False
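A closing note on the saveload_json change: copying node attributes into a std::map before writing makes the JSON output deterministic, so saving the same graph twice yields identical text. A round-trip sketch, assuming the SaveJSON/LoadJSON helpers from pass_functions.h (not part of the patch):

    #include <nnvm/graph.h>
    #include <nnvm/pass_functions.h>
    #include <string>

    // Sketch: serialize a graph and parse it back.
    std::string RoundTrip(nnvm::Graph g) {
      std::string json = nnvm::pass::SaveJSON(g);       // attrs in sorted key order
      nnvm::Graph loaded = nnvm::pass::LoadJSON(json);  // reconstruct the graph
      return nnvm::pass::SaveJSON(loaded);              // identical to `json`
    }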