From c1e48e1a5b018a31cb3aa03b42375de4d8b10d0b Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Sun, 17 Jul 2016 11:12:38 -0700 Subject: [PATCH] [CYTHON] Make speedup component minimum (#13) --- nnvm/python/nnvm/_symbol_internal.py | 1 + nnvm/python/nnvm/ctypes/symbol.py | 201 +++----------------------- nnvm/python/nnvm/cython/symbol.pyx | 189 +++--------------------- nnvm/python/nnvm/symbol.py | 209 ++++++++++++++++++++++++++- 4 files changed, 241 insertions(+), 359 deletions(-) create mode 100644 nnvm/python/nnvm/_symbol_internal.py diff --git a/nnvm/python/nnvm/_symbol_internal.py b/nnvm/python/nnvm/_symbol_internal.py new file mode 100644 index 000000000000..6fadaf89c9d9 --- /dev/null +++ b/nnvm/python/nnvm/_symbol_internal.py @@ -0,0 +1 @@ +"""Module space to register internal functions. Leave empty""" diff --git a/nnvm/python/nnvm/ctypes/symbol.py b/nnvm/python/nnvm/ctypes/symbol.py index 503fc09bde74..3bd5e65d4e25 100644 --- a/nnvm/python/nnvm/ctypes/symbol.py +++ b/nnvm/python/nnvm/ctypes/symbol.py @@ -13,11 +13,9 @@ from ..name import NameManager from ..attribute import AttrScope -__all__ = ["Symbol", "Variable"] - -class Symbol(object): +class SymbolBase(object): """Symbol is symbolic graph.""" - + __slots__ = ["handle"] # pylint: disable=no-member def __init__(self, handle): """Initialize the function with handle @@ -32,15 +30,6 @@ def __init__(self, handle): def __del__(self): check_call(_LIB.NNSymbolFree(self.handle)) - def __copy__(self): - return copy.deepcopy(self) - - def __deepcopy__(self, _): - handle = SymbolHandle() - check_call(_LIB.NNSymbolCopy(self.handle, - ctypes.byref(handle))) - return Symbol(handle) - def __call__(self, *args, **kwargs): """Invoke symbol as function on inputs. @@ -85,10 +74,10 @@ def _compose(self, *args, **kwargs): either as positional or keyword arguments, not both') for arg in args: - if not isinstance(arg, Symbol): + if not isinstance(arg, SymbolBase): raise TypeError('Compose expect `Symbol` as arguments') for val in kwargs.values(): - if not isinstance(val, Symbol): + if not isinstance(val, SymbolBase): raise TypeError('Compose expect `Symbol` as arguments') num_args = len(args) + len(kwargs) @@ -101,65 +90,6 @@ def _compose(self, *args, **kwargs): check_call(_LIB.NNSymbolCompose( self.handle, name, num_args, keys, args)) - def __getitem__(self, index): - if isinstance(index, string_types): - idx = None - for i, name in enumerate(self.list_outputs()): - if name == index: - if idx is not None: - raise ValueError('There are multiple outputs with name \"%s\"' % index) - idx = i - if idx is None: - raise ValueError('Cannot find output that matches name \"%s\"' % index) - index = idx - if not isinstance(index, int): - raise TypeError('Symbol only support integer index to fetch i-th output') - handle = SymbolHandle() - check_call(_LIB.NNSymbolGetOutput( - self.handle, nn_uint(index), ctypes.byref(handle))) - return Symbol(handle=handle) - - def attr(self, key): - """Get attribute string from the symbol, this function only works for non-grouped symbol. - - Parameters - ---------- - key : str - The key to get attribute from. - - Returns - ------- - value : str - The attribute value of the key, returns None if attribute do not exist. - """ - ret = ctypes.c_char_p() - success = ctypes.c_int() - check_call(_LIB.NNSymbolGetAttr( - self.handle, c_str(key), ctypes.byref(ret), ctypes.byref(success))) - if success.value != 0: - return py_str(ret.value) - else: - return None - - def list_attr(self, recursive=False): - """Get all attributes from the symbol. - - Parameters - ---------- - recursive : bool - Default `False`. When `recursive` is `True`, list recursively all the - attributes in the descendents. The attribute names are pre-pended with - the symbol names to avoid conflicts. If `False`, then only attributes - that belongs to this symbol is returned, and the attribute names will - **not** be pre-pended with the symbol name. - """ - size = nn_uint() - pairs = ctypes.POINTER(ctypes.c_char_p)() - option = ctypes.c_int(0) if recursive else ctypes.c_int(1) - check_call(_LIB.NNSymbolListAttrs( - self.handle, option, ctypes.byref(size), ctypes.byref(pairs))) - return {py_str(pairs[i*2]): py_str(pairs[i*2+1]) for i in range(size.value)} - def _set_attr(self, **kwargs): """Set the attribute of the symbol. @@ -168,116 +98,20 @@ def _set_attr(self, **kwargs): **kwargs The attributes to set """ - keys = c_array(ctypes.c_char_p, [c_str(key) for key in kwargs.keys()]) - vals = c_array(ctypes.c_char_p, [c_str(str(val)) for val in kwargs.values()]) - num_args = nn_uint(len(kwargs)) - check_call(_LIB.NNSymbolSetAttrs( + keys = _base.c_array(_ctypes.c_char_p, + [_base.c_str(key) for key in kwargs.keys()]) + vals = _base.c_array(_ctypes.c_char_p, + [_base.c_str(str(val)) for val in kwargs.values()]) + num_args = _base.nn_uint(len(kwargs)) + _check_call(_LIB.NNSymbolSetAttrs( self.handle, num_args, keys, vals)) - def get_internals(self): - """Get a new grouped symbol whose output contains all the internal outputs of this symbol. - Returns - ------- - sgroup : Symbol - The internal of the symbol. - """ - handle = SymbolHandle() - check_call(_LIB.NNSymbolGetInternals( - self.handle, ctypes.byref(handle))) - return Symbol(handle=handle) - - def list_arguments(self): - """List all the arguments in the symbol. - - Returns - ------- - args : list of string - List of all the arguments. - """ - size = ctypes.c_uint() - sarr = ctypes.POINTER(ctypes.c_char_p)() - check_call(_LIB.NNSymbolListArguments( - self.handle, ctypes.byref(size), ctypes.byref(sarr))) - return [py_str(sarr[i]) for i in range(size.value)] - - def list_outputs(self): - """List all outputs in the symbol. +_symbol_cls = SymbolBase - Returns - ------- - returns : list of string - List of all the outputs. - """ - size = ctypes.c_uint() - sarr = ctypes.POINTER(ctypes.c_char_p)() - check_call(_LIB.NNSymbolListOutputs( - self.handle, ctypes.byref(size), ctypes.byref(sarr))) - return [py_str(sarr[i]) for i in range(size.value)] - - def debug_str(self): - """Get a debug string. - - Returns - ------- - debug_str : string - Debug string of the symbol. - """ - debug_str = ctypes.c_char_p() - check_call(_LIB.NNSymbolPrint( - self.handle, ctypes.byref(debug_str))) - return py_str(debug_str.value) - - -def Variable(name, **kwargs): - """Create a symbolic variable with specified name. - - Parameters - ---------- - name : str - Name of the variable. - kwargs : dict of string -> string - Additional attributes to set on the variable. - - Returns - ------- - variable : Symbol - The created variable symbol. - """ - if not isinstance(name, string_types): - raise TypeError('Expect a string for variable `name`') - handle = SymbolHandle() - check_call(_LIB.NNSymbolCreateVariable(c_str(name), ctypes.byref(handle))) - ret = Symbol(handle) - attr = AttrScope.current.get(kwargs) - if attr: - ret._set_attr(**attr) - return ret - - -def Group(symbols): - """Create a symbol that groups symbols together. - - Parameters - ---------- - symbols : list - List of symbols to be grouped. - - Returns - ------- - sym : Symbol - The created group symbol. - """ - ihandles = [] - for sym in symbols: - if not isinstance(sym, Symbol): - raise TypeError('Expect Symbols in the list input') - ihandles.append(sym.handle) - handle = SymbolHandle() - check_call(_LIB.NNSymbolCreateGroup( - nn_uint(len(ihandles)), - c_array(SymbolHandle, ihandles), ctypes.byref(handle))) - return Symbol(handle) +def _set_symbol_class(cls): + global _symbol_cls + _symbol_cls = cls def _make_atomic_symbol_function(handle): @@ -332,7 +166,7 @@ def creator(*args, **kwargs): attr = kwargs.pop('attr', None) for k, v in kwargs.items(): - if isinstance(v, Symbol): + if isinstance(v, SymbolBase): symbol_kwargs[k] = v else: param_keys.append(c_str(k)) @@ -351,7 +185,7 @@ def creator(*args, **kwargs): raise TypeError( '%s can only accept input' 'Symbols either as positional or keyword arguments, not both' % func_name) - s = Symbol(sym_handle) + s = _symbol_cls(sym_handle) attr = AttrScope.current.get(attr) if attr: s._set_attr(**attr) @@ -373,11 +207,12 @@ def _init_symbol_module(): check_call(_LIB.NNSymbolListAtomicSymbolCreators(ctypes.byref(size), ctypes.byref(plist))) module_obj = sys.modules["nnvm.symbol"] + module_internal = sys.modules["nnvm._symbol_internal"] for i in range(size.value): hdl = SymbolHandle(plist[i]) function = _make_atomic_symbol_function(hdl) if function.__name__.startswith('_'): - setattr(Symbol, function.__name__, staticmethod(function)) + setattr(module_internal, function.__name__, function) else: setattr(module_obj, function.__name__, function) diff --git a/nnvm/python/nnvm/cython/symbol.pyx b/nnvm/python/nnvm/cython/symbol.pyx index eeec2c430e89..7b1435381e99 100644 --- a/nnvm/python/nnvm/cython/symbol.pyx +++ b/nnvm/python/nnvm/cython/symbol.pyx @@ -2,6 +2,7 @@ from __future__ import absolute_import as _abs import sys as _sys import ctypes as _ctypes +from numbers import Number as _Number from .._base import NNVMError from ..name import NameManager from ..attribute import AttrScope @@ -64,8 +65,7 @@ cdef extern from "nnvm/c_api.h": const char** keys, SymbolHandle* args); - -cdef class Symbol: +cdef class SymbolBase: """Symbol is symbolic graph.""" # handle for symbolic operator. cdef SymbolHandle handle @@ -85,76 +85,6 @@ cdef class Symbol: def handle(self): return _ctypes.cast(self.handle, _ctypes.c_void_p) - def __copy__(self): - return self.__deepcopy__() - - def __deepcopy__(self, _ = None): - cdef SymbolHandle handle - CALL(NNSymbolCopy(self.handle, &handle)) - return NewSymbol(handle) - - def __getitem__(self, index): - if isinstance(index, str): - idx = None - for i, name in enumerate(self.list_outputs()): - if name == index: - if idx is not None: - raise ValueError('There are multiple outputs with name \"%s\"' % index) - idx = i - if idx is None: - raise ValueError('Cannot find output that matches name \"%s\"' % index) - index = idx - if not isinstance(index, int): - raise TypeError('Symbol only support integer index to fetch i-th output') - cdef SymbolHandle handle - cdef nn_uint c_index = index - CALL(NNSymbolGetOutput(self.handle, c_index, &handle)) - return NewSymbol(handle) - - def attr(self, key): - """Get attribute string from the symbol, this function only works for non-grouped symbol. - - Parameters - ---------- - key : str - The key to get attribute from. - - Returns - ------- - value : str - The attribute value of the key, returns None if attribute do not exist. - """ - cdef const char* ret - cdef int success - key = c_str(key) - - CALL(NNSymbolGetAttr( - self.handle, key, &ret, &success)) - if success != 0: - return py_str(ret.value) - else: - return None - - def list_attr(self, recursive=False): - """Get all attributes from the symbol. - - Parameters - ---------- - recursive : bool - Default `False`. When `recursive` is `True`, list recursively all the - attributes in the descendents. The attribute names are pre-pended with - the symbol names to avoid conflicts. If `False`, then only attributes - that belongs to this symbol is returned, and the attribute names will - **not** be pre-pended with the symbol name. - """ - cdef nn_uint size - cdef const char** pairs - cdef int option - option = 0 if recursive else 1 - CALL(NNSymbolListAttrs( - self.handle, option, &size, &pairs)) - return {py_str(pairs[i*2]): py_str(pairs[i*2+1]) for i in range(size)} - def _set_attr(self, **kwargs): """Set the attribute of the symbol. @@ -165,49 +95,6 @@ cdef class Symbol: """ SymbolSetAttr(self.handle, kwargs) - def get_internals(self): - """Get a new grouped symbol whose output contains all the internal outputs of this symbol. - - Returns - ------- - sgroup : Symbol - The internal of the symbol. - """ - cdef SymbolHandle handle - CALL(NNSymbolGetInternals(self.handle, &handle)) - return NewSymbol(handle) - - def list_arguments(self): - """List all the arguments in the symbol. - - Returns - ------- - args : list of string - List of all the arguments. - """ - cdef nn_uint size - cdef const char ** sarr - CALL(NNSymbolListArguments(self.handle, &size, &sarr)) - return [py_str(sarr[i]) for i in range(size)] - - def list_outputs(self): - """List all outputs in the symbol. - - Returns - ------- - returns : list of string - List of all the outputs. - """ - cdef nn_uint size - cdef const char ** sarr - CALL(NNSymbolListOutputs(self.handle, &size, &sarr)) - return [py_str(sarr[i]) for i in range(size)] - - def debug_str(self): - cdef const char* out_str - CALL(NNSymbolPrint(self.handle, &out_str)) - return py_str(out_str) - cdef SymbolSetAttr(SymbolHandle handle, dict kwargs): cdef vector[string] sparam_keys @@ -224,34 +111,18 @@ cdef SymbolSetAttr(SymbolHandle handle, dict kwargs): handle, num_args, CBeginPtr(param_keys), CBeginPtr(param_vals))) +_symbol_cls = SymbolBase + +def _set_symbol_class(cls): + global _symbol_cls + _symbol_cls = cls + cdef NewSymbol(SymbolHandle handle): """Create a new symbol given handle""" - sym = Symbol(None) - sym.handle = handle + sym = _symbol_cls(None) + (sym).handle = handle return sym - -def Variable(name, **kwargs): - """Create a symbolic variable with specified name. - - Parameters - ---------- - name : str - Name of the variable. - kwargs : dict of string -> string - Additional attributes to set on the variable. - - Returns - ------- - variable : Symbol - The created variable symbol. - """ - cdef SymbolHandle handle - name = c_str(name) - CALL(NNSymbolCreateVariable(name, &handle)) - return NewSymbol(handle) - - cdef _make_atomic_symbol_function(AtomicSymbolCreator handle): """Create an atomic symbol function by handle and funciton name.""" cdef const char *name @@ -292,9 +163,9 @@ cdef _make_atomic_symbol_function(AtomicSymbolCreator handle): if len(kwargs) != 0: for k, v in kwargs.items(): - if isinstance(v, Symbol): + if isinstance(v, SymbolBase): ssymbol_keys.push_back(c_str(k)) - symbol_args.push_back((v).handle) + symbol_args.push_back((v).handle) else: sparam_keys.push_back(c_str(k)) sparam_vals.push_back(c_str(str(v))) @@ -304,9 +175,9 @@ cdef _make_atomic_symbol_function(AtomicSymbolCreator handle): raise TypeError("compose only accept input Symbols\ either as positional or keyword arguments, not both") for v in args: - if not isinstance(v, Symbol): + if not isinstance(v, SymbolBase): raise TypeError('Compose expect `Symbol` as arguments') - symbol_args.push_back((v).handle) + symbol_args.push_back((v).handle) cdef vector[const char*] param_keys = SVec2Ptr(sparam_keys) cdef vector[const char*] param_vals = SVec2Ptr(sparam_vals) @@ -344,46 +215,20 @@ cdef _make_atomic_symbol_function(AtomicSymbolCreator handle): return creator -def Group(symbols): - """Create a symbol that groups symbols together. - - Parameters - ---------- - symbols : list - List of symbols to be grouped. - - Returns - ------- - sym : Symbol - The created group symbol. - """ - cdef vector[SymbolHandle] ihandles - cdef SymbolHandle handle - - for sym in symbols: - if not isinstance(sym, Symbol): - raise TypeError("Expect Symbols in the list input") - ihandles.push_back((sym).handle) - if ihandles.size() == 0: - raise ValueError("expect at least one element in the input") - CALL(NNSymbolCreateGroup(ihandles.size(), - &ihandles[0], &handle)) - return NewSymbol(handle) - - def _init_symbol_module(): """List and add all the atomic symbol functions to current module.""" cdef AtomicSymbolCreator* plist cdef nn_uint size CALL(NNSymbolListAtomicSymbolCreators(&size, &plist)) module_obj = _sys.modules["nnvm.symbol"] + module_internal = _sys.modules["nnvm._symbol_internal"] for i in range(size): function = _make_atomic_symbol_function(plist[i]) + if function.__name__.startswith('_'): - setattr(Symbol, function.__name__, staticmethod(function)) + setattr(module_internal, function.__name__, function) else: setattr(module_obj, function.__name__, function) - # Initialize the atomic symbol in startups _init_symbol_module() diff --git a/nnvm/python/nnvm/symbol.py b/nnvm/python/nnvm/symbol.py index ac1dee39a08b..7eab9f6045f7 100644 --- a/nnvm/python/nnvm/symbol.py +++ b/nnvm/python/nnvm/symbol.py @@ -2,13 +2,214 @@ from __future__ import absolute_import as _abs import sys as _sys import os as _os +import ctypes as _ctypes +from numbers import Number as _Number +from . import _base +from ._base import _LIB, check_call as _check_call +from . import _symbol_internal as _internal +from .attribute import AttrScope + +# Use different verison of SymbolBase +# When possible, use cython to speedup part of computation. try: if int(_os.environ.get("NNVM_ENABLE_CYTHON", True)) == 0: - from .ctypes.symbol import Symbol, Variable + from .ctypes.symbol import SymbolBase, _set_symbol_class elif _sys.version_info >= (3, 0): - from ._cy3.symbol import Symbol, Variable, Group + from ._cy3.symbol import SymbolBase, _set_symbol_class else: - from ._cy2.symbol import Symbol, Variable, Group + from ._cy2.symbol import SymbolBase, _set_symbol_class except: - from .ctypes.symbol import Symbol, Variable, Group + from .ctypes.symbol import SymbolBase, _set_symbol_class + + +class Symbol(SymbolBase): + """Symbol is basic operation unit for symbolic graph compostion.""" + # disable dictionary storage, also do not have parent type. + __slots__ = [] + + def __add__(self, other): + if isinstance(other, Symbol): + return _internal.__add__symbol__(self, other) + elif isinstance(other, _Number): + return _internal.__add__scalar__(self, scalar=other) + else: + raise TypeError("type %s not supported" % str(type(other))) + + def __copy__(self): + return self.__deepcopy__() + + def __deepcopy__(self, _=None): + handle = _base.SymbolHandle() + _base.check_call(_LIB.NNSymbolCopy(self.handle, + _ctypes.byref(handle))) + return Symbol(handle) + + def __getitem__(self, index): + if isinstance(index, _base.string_types): + idx = None + for i, name in enumerate(self.list_outputs()): + if name == index: + if idx is not None: + raise ValueError('There are multiple outputs with name \"%s\"' % index) + idx = i + if idx is None: + raise ValueError('Cannot find output that matches name \"%s\"' % index) + index = idx + if not isinstance(index, int): + raise TypeError('Symbol only support integer index to fetch i-th output') + handle = _base.SymbolHandle() + _check_call(_LIB.NNSymbolGetOutput( + self.handle, _base.nn_uint(index), _ctypes.byref(handle))) + return Symbol(handle=handle) + + def attr(self, key): + """Get attribute string from the symbol, this function only works for non-grouped symbol. + + Parameters + ---------- + key : str + The key to get attribute from. + + Returns + ------- + value : str + The attribute value of the key, returns None if attribute do not exist. + """ + ret = _ctypes.c_char_p() + success = _ctypes.c_int() + _check_call(_LIB.NNSymbolGetAttr( + self.handle, c_str(key), _ctypes.byref(ret), _ctypes.byref(success))) + if success.value != 0: + return _base.py_str(ret.value) + else: + return None + + def list_attr(self, recursive=False): + """Get all attributes from the symbol. + + Parameters + ---------- + recursive : bool + Default `False`. When `recursive` is `True`, list recursively all the + attributes in the descendents. The attribute names are pre-pended with + the symbol names to avoid conflicts. If `False`, then only attributes + that belongs to this symbol is returned, and the attribute names will + **not** be pre-pended with the symbol name. + """ + size = _base.nn_uint() + pairs = _ctypes.POINTER(_ctypes.c_char_p)() + option = _ctypes.c_int(0) if recursive else _ctypes.c_int(1) + _check_call(_LIB.NNSymbolListAttrs( + self.handle, option, _ctypes.byref(size), _ctypes.byref(pairs))) + return {_base.py_str(pairs[i*2]): _base.py_str(pairs[i*2+1]) for i in range(size.value)} + + def get_internals(self): + """Get a new grouped symbol whose output contains all the internal outputs of this symbol. + + Returns + ------- + sgroup : Symbol + The internal of the symbol. + """ + handle = _base.SymbolHandle() + _check_call(_LIB.NNSymbolGetInternals( + self.handle, _ctypes.byref(handle))) + return Symbol(handle=handle) + + def list_arguments(self): + """List all the arguments in the symbol. + + Returns + ------- + args : list of string + List of all the arguments. + """ + size = _ctypes.c_uint() + sarr = _ctypes.POINTER(_ctypes.c_char_p)() + _check_call(_LIB.NNSymbolListArguments( + self.handle, _ctypes.byref(size), _ctypes.byref(sarr))) + return [_base.py_str(sarr[i]) for i in range(size.value)] + + def list_outputs(self): + """List all outputs in the symbol. + + Returns + ------- + returns : list of string + List of all the outputs. + """ + size = _ctypes.c_uint() + sarr = _ctypes.POINTER(_ctypes.c_char_p)() + _check_call(_LIB.NNSymbolListOutputs( + self.handle, _ctypes.byref(size), _ctypes.byref(sarr))) + return [_base.py_str(sarr[i]) for i in range(size.value)] + + def debug_str(self): + """Get a debug string. + + Returns + ------- + debug_str : string + Debug string of the symbol. + """ + debug_str = _ctypes.c_char_p() + _check_call(_LIB.NNSymbolPrint( + self.handle, _ctypes.byref(debug_str))) + return _base.py_str(debug_str.value) + + +def Variable(name, **kwargs): + """Create a symbolic variable with specified name. + + Parameters + ---------- + name : str + Name of the variable. + kwargs : dict of string -> string + Additional attributes to set on the variable. + + Returns + ------- + variable : Symbol + The created variable symbol. + """ + if not isinstance(name, _base.string_types): + raise TypeError('Expect a string for variable `name`') + handle = _base.SymbolHandle() + _base.check_call(_LIB.NNSymbolCreateVariable( + _base.c_str(name), _ctypes.byref(handle))) + ret = Symbol(handle) + attr = AttrScope.current.get(kwargs) + if attr: + ret._set_attr(**attr) + return ret + + +def Group(symbols): + """Create a symbol that groups symbols together. + + Parameters + ---------- + symbols : list + List of symbols to be grouped. + + Returns + ------- + sym : Symbol + The created group symbol. + """ + ihandles = [] + for sym in symbols: + if not isinstance(sym, Symbol): + raise TypeError('Expect Symbols in the list input') + ihandles.append(sym.handle) + handle = _base.SymbolHandle() + _check_call(_LIB.NNSymbolCreateGroup( + _base.nn_uint(len(ihandles)), + _base.c_array(_base.SymbolHandle, ihandles), + _ctypes.byref(handle))) + return Symbol(handle) + +# Set the real symbol class to Symbol +_set_symbol_class(Symbol)