diff --git a/.travis.yml b/.travis.yml index 31d6e49f3dd1..e7110ecbacac 100644 --- a/.travis.yml +++ b/.travis.yml @@ -4,7 +4,7 @@ language: cpp os: - linux - # - osx + - osx env: # code analysis diff --git a/python/tvm/_ctypes/_api.py b/python/tvm/_ctypes/_api.py index b7bc41cabbfd..41014e867da4 100644 --- a/python/tvm/_ctypes/_api.py +++ b/python/tvm/_ctypes/_api.py @@ -25,8 +25,13 @@ class ArgVariant(ctypes.Union): kStr = 3 kNodeHandle = 4 +NODE_TYPE = { +} -def _type_key(handle): +def _return_node(x): + handle = x.v_handle + if not isinstance(handle, ctypes.c_void_p): + handle = ctypes.c_void_p(handle) ret_val = ArgVariant() ret_typeid = ctypes.c_int() ret_success = ctypes.c_int() @@ -35,17 +40,14 @@ def _type_key(handle): ctypes.byref(ret_val), ctypes.byref(ret_typeid), ctypes.byref(ret_success))) - return py_str(ret_val.v_str) - -NODE_TYPE = { -} + return NODE_TYPE.get(py_str(ret_val.v_str), NodeBase)(handle) RET_SWITCH = { kNull: lambda x: None, kLong: lambda x: x.v_long, kDouble: lambda x: x.v_double, kStr: lambda x: py_str(x.v_str), - kNodeHandle: lambda x: NODE_TYPE.get(_type_key(x), NodeBase)(x.v_handle) + kNodeHandle: lambda x: _return_node(x) } class SliceBase(object):