Skip to content

Commit

Permalink
Make cython compatible with python3 (#12)
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen committed May 29, 2018
1 parent b639a63 commit ad0ab0a
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 14 deletions.
4 changes: 3 additions & 1 deletion nnvm/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ export CFLAGS = -std=c++11 -Wall -O3 -msse2 -Wno-unknown-pragmas -funroll-loop
-Iinclude -Idmlc-core/include -fPIC

# specify tensor path
.PHONY: clean all test lint doc cython cython3
.PHONY: clean all test lint doc cython cython3 cyclean

all: lib/libnnvm.so lib/libnnvm.a cli_test

Expand Down Expand Up @@ -37,6 +37,8 @@ cython:
cython3:
cd python; python3 setup.py build_ext --inplace

cyclean:
rm -rf python/nnvm/*/*.so python/nnvm/*/*.cpp

lint:
python2 dmlc-core/scripts/lint.py nnvm cpp include src
Expand Down
22 changes: 22 additions & 0 deletions nnvm/python/nnvm/cython/base.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,21 @@ cdef py_str(const char* x):
return x.decode("utf-8")


cdef c_str(pystr):
"""Create ctypes char * from a python string
Parameters
----------
string : string type
python string

Returns
-------
str : c_char_p
A char pointer that can be passed to C API
"""
return pystr.encode("utf-8")
cdef CALL(int ret):
if ret != 0:
raise NNVMError(NNGetLastError())
Expand All @@ -20,6 +35,13 @@ cdef const char** CBeginPtr(vector[const char*]& vec):
else:
return NULL
cdef vector[const char*] SVec2Ptr(vector[string]& vec):
cdef vector[const char*] svec
svec.resize(vec.size())
for i in range(vec.size()):
svec[i] = vec[i].c_str()
return svec
cdef BuildDoc(nn_uint num_args,
const char** arg_names,
Expand Down
39 changes: 26 additions & 13 deletions nnvm/python/nnvm/cython/symbol.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ from .._base import NNVMError
from ..name import NameManager
from ..attribute import AttrScope
from libcpp.vector cimport vector
from libcpp.string cimport string
from cpython.version cimport PY_MAJOR_VERSION

include "./base.pyi"
Expand Down Expand Up @@ -110,7 +111,7 @@ cdef class Symbol:
CALL(NNSymbolGetOutput(self.handle, c_index, &handle))
return NewSymbol(handle)

def attr(self, const char* key):
def attr(self, key):
"""Get attribute string from the symbol, this function only works for non-grouped symbol.
Parameters
Expand All @@ -125,6 +126,8 @@ cdef class Symbol:
"""
cdef const char* ret
cdef int success
key = c_str(key)

CALL(NNSymbolGetAttr(
self.handle, key, &ret, &success))
if success != 0:
Expand Down Expand Up @@ -203,16 +206,19 @@ cdef class Symbol:
def debug_str(self):
cdef const char* out_str
CALL(NNSymbolPrint(self.handle, &out_str))
return str(out_str)
return py_str(out_str)


cdef SymbolSetAttr(SymbolHandle handle, dict kwargs):
cdef vector[const char*] param_keys
cdef vector[const char*] param_vals
cdef vector[string] sparam_keys
cdef vector[string] sparam_vals
cdef nn_uint num_args
for k, v in kwargs.items():
param_keys.push_back(k)
param_vals.push_back(str(v))
sparam_keys.push_back(c_str(k))
sparam_vals.push_back(c_str(str(v)))
# keep strings in vector
cdef vector[const char*] param_keys = SVec2Ptr(sparam_keys)
cdef vector[const char*] param_vals = SVec2Ptr(sparam_vals)
num_args = param_keys.size()
CALL(NNSymbolSetAttrs(
handle, num_args, CBeginPtr(param_keys), CBeginPtr(param_vals)))
Expand All @@ -225,7 +231,7 @@ cdef NewSymbol(SymbolHandle handle):
return sym


def Variable(const char* name, **kwargs):
def Variable(name, **kwargs):
"""Create a symbolic variable with specified name.
Parameters
Expand All @@ -241,6 +247,7 @@ def Variable(const char* name, **kwargs):
The created variable symbol.
"""
cdef SymbolHandle handle
name = c_str(name)
CALL(NNSymbolCreateVariable(name, &handle))
return NewSymbol(handle)

Expand Down Expand Up @@ -274,10 +281,10 @@ cdef _make_atomic_symbol_function(AtomicSymbolCreator handle):
func_hint = func_name.lower()

def creator(*args, **kwargs):
cdef vector[const char*] param_keys
cdef vector[const char*] param_vals
cdef vector[string] sparam_keys
cdef vector[string] sparam_vals
cdef vector[SymbolHandle] symbol_args
cdef vector[const char*] symbol_keys
cdef vector[string] ssymbol_keys
cdef SymbolHandle ret_handle

name = kwargs.pop("name", None)
Expand All @@ -286,11 +293,11 @@ cdef _make_atomic_symbol_function(AtomicSymbolCreator handle):
if len(kwargs) != 0:
for k, v in kwargs.items():
if isinstance(v, Symbol):
symbol_keys.push_back(k)
ssymbol_keys.push_back(c_str(k))
symbol_args.push_back((<Symbol>v).handle)
else:
param_keys.push_back(k)
param_vals.push_back(str(v))
sparam_keys.push_back(c_str(k))
sparam_vals.push_back(c_str(str(v)))

if len(args) != 0:
if symbol_args.size() != 0:
Expand All @@ -301,6 +308,10 @@ cdef _make_atomic_symbol_function(AtomicSymbolCreator handle):
raise TypeError('Compose expect `Symbol` as arguments')
symbol_args.push_back((<Symbol>v).handle)

cdef vector[const char*] param_keys = SVec2Ptr(sparam_keys)
cdef vector[const char*] param_vals = SVec2Ptr(sparam_vals)
cdef vector[const char*] symbol_keys = SVec2Ptr(ssymbol_keys)

CALL(NNSymbolCreateAtomicSymbol(
handle,
<nn_uint>param_keys.size(),
Expand All @@ -315,7 +326,9 @@ cdef _make_atomic_symbol_function(AtomicSymbolCreator handle):
name = NameManager.current.get(name, func_hint)

cdef const char* c_name = NULL

if name:
name = c_str(name)
c_name = name

CALL(NNSymbolCompose(
Expand Down

0 comments on commit ad0ab0a

Please sign in to comment.