Skip to content

Commit

Permalink
[CYTHON] Make speedup component minimum (apache#13)
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen committed May 29, 2018
1 parent 31f9fc0 commit c1e48e1
Show file tree
Hide file tree
Showing 4 changed files with 241 additions and 359 deletions.
1 change: 1 addition & 0 deletions nnvm/python/nnvm/_symbol_internal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Module space to register internal functions. Leave empty"""
201 changes: 18 additions & 183 deletions nnvm/python/nnvm/ctypes/symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,9 @@
from ..name import NameManager
from ..attribute import AttrScope

__all__ = ["Symbol", "Variable"]

class Symbol(object):
class SymbolBase(object):
"""Symbol is symbolic graph."""

__slots__ = ["handle"]
# pylint: disable=no-member
def __init__(self, handle):
"""Initialize the function with handle
Expand All @@ -32,15 +30,6 @@ def __init__(self, handle):
def __del__(self):
check_call(_LIB.NNSymbolFree(self.handle))

def __copy__(self):
return copy.deepcopy(self)

def __deepcopy__(self, _):
handle = SymbolHandle()
check_call(_LIB.NNSymbolCopy(self.handle,
ctypes.byref(handle)))
return Symbol(handle)

def __call__(self, *args, **kwargs):
"""Invoke symbol as function on inputs.
Expand Down Expand Up @@ -85,10 +74,10 @@ def _compose(self, *args, **kwargs):
either as positional or keyword arguments, not both')

for arg in args:
if not isinstance(arg, Symbol):
if not isinstance(arg, SymbolBase):
raise TypeError('Compose expect `Symbol` as arguments')
for val in kwargs.values():
if not isinstance(val, Symbol):
if not isinstance(val, SymbolBase):
raise TypeError('Compose expect `Symbol` as arguments')

num_args = len(args) + len(kwargs)
Expand All @@ -101,65 +90,6 @@ def _compose(self, *args, **kwargs):
check_call(_LIB.NNSymbolCompose(
self.handle, name, num_args, keys, args))

def __getitem__(self, index):
if isinstance(index, string_types):
idx = None
for i, name in enumerate(self.list_outputs()):
if name == index:
if idx is not None:
raise ValueError('There are multiple outputs with name \"%s\"' % index)
idx = i
if idx is None:
raise ValueError('Cannot find output that matches name \"%s\"' % index)
index = idx
if not isinstance(index, int):
raise TypeError('Symbol only support integer index to fetch i-th output')
handle = SymbolHandle()
check_call(_LIB.NNSymbolGetOutput(
self.handle, nn_uint(index), ctypes.byref(handle)))
return Symbol(handle=handle)

def attr(self, key):
"""Get attribute string from the symbol, this function only works for non-grouped symbol.
Parameters
----------
key : str
The key to get attribute from.
Returns
-------
value : str
The attribute value of the key, returns None if attribute do not exist.
"""
ret = ctypes.c_char_p()
success = ctypes.c_int()
check_call(_LIB.NNSymbolGetAttr(
self.handle, c_str(key), ctypes.byref(ret), ctypes.byref(success)))
if success.value != 0:
return py_str(ret.value)
else:
return None

def list_attr(self, recursive=False):
"""Get all attributes from the symbol.
Parameters
----------
recursive : bool
Default `False`. When `recursive` is `True`, list recursively all the
attributes in the descendents. The attribute names are pre-pended with
the symbol names to avoid conflicts. If `False`, then only attributes
that belongs to this symbol is returned, and the attribute names will
**not** be pre-pended with the symbol name.
"""
size = nn_uint()
pairs = ctypes.POINTER(ctypes.c_char_p)()
option = ctypes.c_int(0) if recursive else ctypes.c_int(1)
check_call(_LIB.NNSymbolListAttrs(
self.handle, option, ctypes.byref(size), ctypes.byref(pairs)))
return {py_str(pairs[i*2]): py_str(pairs[i*2+1]) for i in range(size.value)}

def _set_attr(self, **kwargs):
"""Set the attribute of the symbol.
Expand All @@ -168,116 +98,20 @@ def _set_attr(self, **kwargs):
**kwargs
The attributes to set
"""
keys = c_array(ctypes.c_char_p, [c_str(key) for key in kwargs.keys()])
vals = c_array(ctypes.c_char_p, [c_str(str(val)) for val in kwargs.values()])
num_args = nn_uint(len(kwargs))
check_call(_LIB.NNSymbolSetAttrs(
keys = _base.c_array(_ctypes.c_char_p,
[_base.c_str(key) for key in kwargs.keys()])
vals = _base.c_array(_ctypes.c_char_p,
[_base.c_str(str(val)) for val in kwargs.values()])
num_args = _base.nn_uint(len(kwargs))
_check_call(_LIB.NNSymbolSetAttrs(
self.handle, num_args, keys, vals))

def get_internals(self):
"""Get a new grouped symbol whose output contains all the internal outputs of this symbol.

Returns
-------
sgroup : Symbol
The internal of the symbol.
"""
handle = SymbolHandle()
check_call(_LIB.NNSymbolGetInternals(
self.handle, ctypes.byref(handle)))
return Symbol(handle=handle)

def list_arguments(self):
"""List all the arguments in the symbol.
Returns
-------
args : list of string
List of all the arguments.
"""
size = ctypes.c_uint()
sarr = ctypes.POINTER(ctypes.c_char_p)()
check_call(_LIB.NNSymbolListArguments(
self.handle, ctypes.byref(size), ctypes.byref(sarr)))
return [py_str(sarr[i]) for i in range(size.value)]

def list_outputs(self):
"""List all outputs in the symbol.
_symbol_cls = SymbolBase

Returns
-------
returns : list of string
List of all the outputs.
"""
size = ctypes.c_uint()
sarr = ctypes.POINTER(ctypes.c_char_p)()
check_call(_LIB.NNSymbolListOutputs(
self.handle, ctypes.byref(size), ctypes.byref(sarr)))
return [py_str(sarr[i]) for i in range(size.value)]

def debug_str(self):
"""Get a debug string.
Returns
-------
debug_str : string
Debug string of the symbol.
"""
debug_str = ctypes.c_char_p()
check_call(_LIB.NNSymbolPrint(
self.handle, ctypes.byref(debug_str)))
return py_str(debug_str.value)


def Variable(name, **kwargs):
"""Create a symbolic variable with specified name.
Parameters
----------
name : str
Name of the variable.
kwargs : dict of string -> string
Additional attributes to set on the variable.
Returns
-------
variable : Symbol
The created variable symbol.
"""
if not isinstance(name, string_types):
raise TypeError('Expect a string for variable `name`')
handle = SymbolHandle()
check_call(_LIB.NNSymbolCreateVariable(c_str(name), ctypes.byref(handle)))
ret = Symbol(handle)
attr = AttrScope.current.get(kwargs)
if attr:
ret._set_attr(**attr)
return ret


def Group(symbols):
"""Create a symbol that groups symbols together.
Parameters
----------
symbols : list
List of symbols to be grouped.
Returns
-------
sym : Symbol
The created group symbol.
"""
ihandles = []
for sym in symbols:
if not isinstance(sym, Symbol):
raise TypeError('Expect Symbols in the list input')
ihandles.append(sym.handle)
handle = SymbolHandle()
check_call(_LIB.NNSymbolCreateGroup(
nn_uint(len(ihandles)),
c_array(SymbolHandle, ihandles), ctypes.byref(handle)))
return Symbol(handle)
def _set_symbol_class(cls):
global _symbol_cls
_symbol_cls = cls


def _make_atomic_symbol_function(handle):
Expand Down Expand Up @@ -332,7 +166,7 @@ def creator(*args, **kwargs):
attr = kwargs.pop('attr', None)

for k, v in kwargs.items():
if isinstance(v, Symbol):
if isinstance(v, SymbolBase):
symbol_kwargs[k] = v
else:
param_keys.append(c_str(k))
Expand All @@ -351,7 +185,7 @@ def creator(*args, **kwargs):
raise TypeError(
'%s can only accept input'
'Symbols either as positional or keyword arguments, not both' % func_name)
s = Symbol(sym_handle)
s = _symbol_cls(sym_handle)
attr = AttrScope.current.get(attr)
if attr:
s._set_attr(**attr)
Expand All @@ -373,11 +207,12 @@ def _init_symbol_module():
check_call(_LIB.NNSymbolListAtomicSymbolCreators(ctypes.byref(size),
ctypes.byref(plist)))
module_obj = sys.modules["nnvm.symbol"]
module_internal = sys.modules["nnvm._symbol_internal"]
for i in range(size.value):
hdl = SymbolHandle(plist[i])
function = _make_atomic_symbol_function(hdl)
if function.__name__.startswith('_'):
setattr(Symbol, function.__name__, staticmethod(function))
setattr(module_internal, function.__name__, function)
else:
setattr(module_obj, function.__name__, function)

Expand Down
Loading

0 comments on commit c1e48e1

Please sign in to comment.