Skip to content

Commit

Permalink
Enable Alias, refactor C API to reflect Op Semantics (#41)
Browse files Browse the repository at this point in the history
* Enable Alias, refactor C API to reflect Op Semantics

* add alias example
  • Loading branch information
tqchen committed May 29, 2018
1 parent a58feb2 commit 038945f
Show file tree
Hide file tree
Showing 14 changed files with 184 additions and 95 deletions.
1 change: 1 addition & 0 deletions nnvm/example/src/operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ NNVM_REGISTER_OP(identity)
NNVM_REGISTER_OP(add)
.describe("add two data together")
.set_num_inputs(2)
.add_alias("__add_symbol__")
.attr<FInferShape>("FInferShape", SameShape)
.attr<FInplaceOption>("FInplaceOption", InplaceIn0Out0)
.attr<FGradient>(
Expand Down
41 changes: 34 additions & 7 deletions nnvm/include/dmlc/registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,19 @@ namespace dmlc {
template<typename EntryType>
class Registry {
public:
/*! \return list of functions in the registry */
inline static const std::vector<const EntryType*> &List() {
return Get()->entry_list_;
/*! \return list of entries in the registry(excluding alias) */
inline static const std::vector<const EntryType*>& List() {
return Get()->const_list_;
}
/*! \return list all names registered in the registry, including alias */
inline static std::vector<std::string> ListAllNames() {
const std::map<std::string, EntryType*> &fmap = Get()->fmap_;
typename std::map<std::string, EntryType*>::const_iterator p;
std::vector<std::string> names;
for (p = fmap.begin(); p !=fmap.end(); ++p) {
names.push_back(p->first);
}
return names;
}
/*!
* \brief Find the entry with corresponding name.
Expand All @@ -44,6 +54,21 @@ class Registry {
return NULL;
}
}
/*!
* \brief Add alias to the key_name
* \param key_name The original entry key
* \param alias The alias key.
*/
inline void AddAlias(const std::string& key_name,
const std::string& alias) {
EntryType* e = fmap_.at(key_name);
if (fmap_.count(alias)) {
CHECK_EQ(e, fmap_.at(alias))
<< "Entry " << e->name << " already registered under different entry";
} else {
fmap_[alias] = e;
}
}
/*!
* \brief Internal function to register a name function under name.
* \param name name of the function
Expand All @@ -55,6 +80,7 @@ class Registry {
EntryType *e = new EntryType();
e->name = name;
fmap_[name] = e;
const_list_.push_back(e);
entry_list_.push_back(e);
return *e;
}
Expand All @@ -79,16 +105,17 @@ class Registry {

private:
/*! \brief list of entry types */
std::vector<const EntryType*> entry_list_;
std::vector<EntryType*> entry_list_;
/*! \brief list of entry types */
std::vector<const EntryType*> const_list_;
/*! \brief map of name->function */
std::map<std::string, EntryType*> fmap_;
/*! \brief constructor */
Registry() {}
/*! \brief destructor */
~Registry() {
for (typename std::map<std::string, EntryType*>::iterator p = fmap_.begin();
p != fmap_.end(); ++p) {
delete p->second;
for (size_t i = 0; i < entry_list_.size(); ++i) {
delete entry_list_[i];
}
}
};
Expand Down
54 changes: 38 additions & 16 deletions nnvm/include/nnvm/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
typedef unsigned int nn_uint;

/*! \brief handle to a function that takes param and creates symbol */
typedef void *AtomicSymbolCreator;
typedef void *OpHandle;
/*! \brief handle to a symbol that can be bind as operator */
typedef void *SymbolHandle;
/*! \brief handle to Graph */
Expand All @@ -53,17 +53,39 @@ NNVM_DLL void NNAPISetLastError(const char* msg);
NNVM_DLL const char *NNGetLastError(void);

/*!
* \brief list all the available AtomicSymbolEntry
* \brief list all the available operator names, include entries.
* \param out_size the size of returned array
* \param out_array the output operator name array.
* \return 0 when success, -1 when failure happens
*/
NNVM_DLL int NNListAllOpNames(nn_uint *out_size,
const char*** out_array);

/*!
* \brief Get operator handle given name.
* \param op_name The name of the operator.
* \param op_out The returnning op handle.
*/
NNVM_DLL int NNGetOpHandle(const char* op_name,
OpHandle* op_out);

/*!
* \brief list all the available operators.
* This won't include the alias, use ListAllNames
* instead to get all alias names.
*
* \param out_size the size of returned array
* \param out_array the output AtomicSymbolCreator array
* \return 0 when success, -1 when failure happens
*/
NNVM_DLL int NNSymbolListAtomicSymbolCreators(nn_uint *out_size,
AtomicSymbolCreator **out_array);
NNVM_DLL int NNListUniqueOps(nn_uint *out_size,
OpHandle **out_array);

/*!
* \brief Get the detailed information about atomic symbol.
* \param creator the AtomicSymbolCreator.
* \param name The returned name of the creator.
* \param op The operator handle.
* \param real_name The returned name of the creator.
* This name is not the alias name of the atomic symbol.
* \param description The returned description of the symbol.
* \param num_doc_args Number of arguments that contain documents.
* \param arg_names Name of the arguments of doc args
Expand All @@ -72,24 +94,24 @@ NNVM_DLL int NNSymbolListAtomicSymbolCreators(nn_uint *out_size,
* \param return_type Return type of the function, if any.
* \return 0 when success, -1 when failure happens
*/
NNVM_DLL int NNSymbolGetAtomicSymbolInfo(AtomicSymbolCreator creator,
const char **name,
const char **description,
nn_uint *num_doc_args,
const char ***arg_names,
const char ***arg_type_infos,
const char ***arg_descriptions,
const char **return_type);
NNVM_DLL int NNGetOpInfo(OpHandle op,
const char **real_name,
const char **description,
nn_uint *num_doc_args,
const char ***arg_names,
const char ***arg_type_infos,
const char ***arg_descriptions,
const char **return_type);
/*!
* \brief Create an AtomicSymbol functor.
* \param creator the AtomicSymbolCreator
* \param op The operator handle
* \param num_param the number of parameters
* \param keys the keys to the params
* \param vals the vals of the params
* \param out pointer to the created symbol handle
* \return 0 when success, -1 when failure happens
*/
NNVM_DLL int NNSymbolCreateAtomicSymbol(AtomicSymbolCreator creator,
NNVM_DLL int NNSymbolCreateAtomicSymbol(OpHandle op,
nn_uint num_param,
const char **keys,
const char **vals,
Expand Down
7 changes: 7 additions & 0 deletions nnvm/include/nnvm/op.h
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,13 @@ class Op {
* \return reference to self.
*/
inline Op& set_attr_parser(std::function<void (NodeAttrs* attrs)> fn); // NOLINT(*)
/*!
* \brief Add another alias to this operator.
* The same Op can be queried with Op::Get(alias)
* \param alias The alias of the operator.
* \return reference to self.
*/
Op& add_alias(const std::string& alias); // NOLINT(*)
/*!
* \brief Register additional attributes to operator.
* \param attr_name The name of the attribute.
Expand Down
2 changes: 1 addition & 1 deletion nnvm/python/nnvm/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def _load_lib():

# type definitions
nn_uint = ctypes.c_uint
SymbolCreatorHandle = ctypes.c_void_p
OpHandle = ctypes.c_void_p
SymbolHandle = ctypes.c_void_p
GraphHandle = ctypes.c_void_p

Expand Down
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import sys
from .._base import _LIB
from .._base import c_array, c_str, nn_uint, py_str, string_types
from .._base import SymbolHandle
from .._base import SymbolHandle, OpHandle
from .._base import check_call, ctypes2docstring
from ..name import NameManager
from ..attribute import AttrScope
Expand Down Expand Up @@ -114,25 +114,25 @@ def _set_symbol_class(cls):
_symbol_cls = cls


def _make_atomic_symbol_function(handle):
def _make_atomic_symbol_function(handle, name):
"""Create an atomic symbol function by handle and funciton name."""
name = ctypes.c_char_p()
real_name = ctypes.c_char_p()
desc = ctypes.c_char_p()
num_args = nn_uint()
arg_names = ctypes.POINTER(ctypes.c_char_p)()
arg_types = ctypes.POINTER(ctypes.c_char_p)()
arg_descs = ctypes.POINTER(ctypes.c_char_p)()
ret_type = ctypes.c_char_p()

check_call(_LIB.NNSymbolGetAtomicSymbolInfo(
handle, ctypes.byref(name), ctypes.byref(desc),
check_call(_LIB.NNGetOpInfo(
handle, ctypes.byref(real_name), ctypes.byref(desc),
ctypes.byref(num_args),
ctypes.byref(arg_names),
ctypes.byref(arg_types),
ctypes.byref(arg_descs),
ctypes.byref(ret_type)))
param_str = ctypes2docstring(num_args, arg_names, arg_types, arg_descs)
func_name = py_str(name.value)
func_name = name
desc = py_str(desc.value)

doc_str = ('%s\n\n' +
Expand Down Expand Up @@ -199,22 +199,25 @@ def creator(*args, **kwargs):
return creator


def _init_symbol_module():
def _init_symbol_module(symbol_class, root_namespace):
"""List and add all the atomic symbol functions to current module."""
plist = ctypes.POINTER(ctypes.c_void_p)()
_set_symbol_class(symbol_class)
plist = ctypes.POINTER(ctypes.c_char_p)()
size = ctypes.c_uint()

check_call(_LIB.NNSymbolListAtomicSymbolCreators(ctypes.byref(size),
ctypes.byref(plist)))
module_obj = sys.modules["nnvm.symbol"]
module_internal = sys.modules["nnvm._symbol_internal"]
check_call(_LIB.NNListAllOpNames(ctypes.byref(size),
ctypes.byref(plist)))
op_names = []
for i in range(size.value):
hdl = SymbolHandle(plist[i])
function = _make_atomic_symbol_function(hdl)
op_names.append(py_str(plist[i]))

module_obj = sys.modules["%s.symbol" % root_namespace]
module_internal = sys.modules["%s._symbol_internal" % root_namespace]
for name in op_names:
hdl = OpHandle()
check_call(_LIB.NNGetOpHandle(c_str(name), ctypes.byref(hdl)))
function = _make_atomic_symbol_function(hdl, name)
if function.__name__.startswith('_'):
setattr(module_internal, function.__name__, function)
else:
setattr(module_obj, function.__name__, function)

# Initialize the atomic symbol in startups
_init_symbol_module()
2 changes: 1 addition & 1 deletion nnvm/python/nnvm/cython/base.pyi
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
ctypedef void* SymbolHandle
ctypedef void* AtomicSymbolCreator
ctypedef void* OpHandle
ctypedef unsigned nn_uint

cdef py_str(const char* x):
Expand Down
6 changes: 0 additions & 6 deletions nnvm/python/nnvm/cython/symbol.pyd

This file was deleted.

62 changes: 35 additions & 27 deletions nnvm/python/nnvm/cython/symbol.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -14,21 +14,25 @@ include "./base.pyi"

cdef extern from "nnvm/c_api.h":
const char* NNGetLastError();
int NNSymbolListAtomicSymbolCreators(nn_uint *out_size,
AtomicSymbolCreator **out_array);
int NNSymbolCreateAtomicSymbol(AtomicSymbolCreator creator,
int NNListAllOpNames(nn_uint *out_size,
const char ***out_array);
int NNGetOpHandle(const char *op_name,
OpHandle *handle);
int NNGetOpInfo(OpHandle op,
const char **name,
const char **description,
nn_uint *num_doc_args,
const char ***arg_names,
const char ***arg_type_infos,
const char ***arg_descriptions,
const char **return_type);
int NNListOpNames(nn_uint *out_size,
const char ***out_array);
int NNSymbolCreateAtomicSymbol(OpHandle op,
nn_uint num_param,
const char **keys,
const char **vals,
SymbolHandle *out);
int NNSymbolGetAtomicSymbolInfo(AtomicSymbolCreator creator,
const char **name,
const char **description,
nn_uint *num_doc_args,
const char ***arg_names,
const char ***arg_type_infos,
const char ***arg_descriptions,
const char **return_type);
int NNSymbolFree(SymbolHandle symbol);
int NNSymbolSetAttrs(SymbolHandle symbol,
nn_uint num_param,
Expand Down Expand Up @@ -88,7 +92,7 @@ cdef SymbolSetAttr(SymbolHandle handle, dict kwargs):

_symbol_cls = SymbolBase

def _set_symbol_class(cls):
cdef _set_symbol_class(cls):
global _symbol_cls
_symbol_cls = cls

Expand All @@ -98,23 +102,24 @@ cdef NewSymbol(SymbolHandle handle):
(<SymbolBase>sym).handle = handle
return sym

cdef _make_atomic_symbol_function(AtomicSymbolCreator handle):
cdef _make_atomic_symbol_function(OpHandle handle, string name):
"""Create an atomic symbol function by handle and funciton name."""
cdef const char *name
cdef const char *real_name
cdef const char *desc
cdef nn_uint num_args
cdef const char** arg_names
cdef const char** arg_types
cdef const char** arg_descs
cdef const char* return_type

CALL(NNSymbolGetAtomicSymbolInfo(
handle, &name, &desc,
CALL(NNGetOpInfo(
handle, &real_name, &desc,
&num_args, &arg_names,
&arg_types, &arg_descs,
&return_type))

param_str = BuildDoc(num_args, arg_names, arg_types, arg_descs)
func_name = py_str(name)
func_name = py_str(name.c_str())
doc_str = ('%s\n\n' +
'%s\n' +
'name : string, optional.\n' +
Expand Down Expand Up @@ -190,20 +195,23 @@ cdef _make_atomic_symbol_function(AtomicSymbolCreator handle):
return creator


def _init_symbol_module():
def _init_symbol_module(symbol_class, root_namespace):
"""List and add all the atomic symbol functions to current module."""
cdef AtomicSymbolCreator* plist
cdef const char** op_name_ptrs
cdef nn_uint size
CALL(NNSymbolListAtomicSymbolCreators(&size, &plist))
module_obj = _sys.modules["nnvm.symbol"]
module_internal = _sys.modules["nnvm._symbol_internal"]
for i in range(size):
function = _make_atomic_symbol_function(plist[i])
cdef vector[string] op_names
cdef OpHandle handle

_set_symbol_class(symbol_class)
CALL(NNListAllOpNames(&size, &op_name_ptrs))
for i in range(size):
op_names.push_back(string(op_name_ptrs[i]));
module_obj = _sys.modules["%s.symbol" % root_namespace]
module_internal = _sys.modules["%s._symbol_internal" % root_namespace]
for i in range(op_names.size()):
CALL(NNGetOpHandle(op_names[i].c_str(), &handle))
function = _make_atomic_symbol_function(handle, op_names[i])
if function.__name__.startswith('_'):
setattr(module_internal, function.__name__, function)
else:
setattr(module_obj, function.__name__, function)

# Initialize the atomic symbol in startups
_init_symbol_module()
Loading

0 comments on commit 038945f

Please sign in to comment.