diff --git a/python/tvm/__init__.py b/python/tvm/__init__.py index b2a4ca3ccf13..ed425992c8f9 100644 --- a/python/tvm/__init__.py +++ b/python/tvm/__init__.py @@ -16,13 +16,17 @@ # under the License. # pylint: disable=redefined-builtin, wildcard-import """TVM: Low level DSL/IR stack for tensor computation.""" -from __future__ import absolute_import as _abs - import multiprocessing import sys import traceback -from . import _pyversion +# import ffi related features +from ._ffi.base import TVMError, __version__ +from ._ffi.runtime_ctypes import TypeCode, TVMType +from ._ffi.ndarray import TVMContext +from ._ffi.packed_func import PackedFunc as Function +from ._ffi.registry import register_object, register_func, register_extension +from ._ffi.object import Object from . import tensor from . import arith @@ -34,7 +38,6 @@ from . import container from . import schedule from . import module -from . import object from . import attrs from . import ir_builder from . import target @@ -48,15 +51,9 @@ from .ndarray import context, cpu, gpu, opencl, cl, vulkan, metal, mtl from .ndarray import vpi, rocm, opengl, ext_dev, micro_dev -from ._ffi.runtime_ctypes import TypeCode, TVMType -from ._ffi.ndarray import TVMContext -from ._ffi.function import Function -from ._ffi.base import TVMError, __version__ from .api import * from .intrin import * from .tensor_intrin import decl_tensor_intrin -from .object import register_object -from .ndarray import register_extension from .schedule import create_schedule from .build_module import build, lower, build_config from .tag import tag_scope diff --git a/python/tvm/_ffi/__init__.py b/python/tvm/_ffi/__init__.py index f19851c2407a..1b2fc58d2927 100644 --- a/python/tvm/_ffi/__init__.py +++ b/python/tvm/_ffi/__init__.py @@ -24,3 +24,7 @@ Some performance critical functions are implemented by cython and have a ctypes fallback implementation. """ +from . import _pyversion +from .base import register_error +from .registry import register_object, register_func, register_extension +from .registry import _init_api, get_global_func diff --git a/python/tvm/_ffi/_ctypes/ndarray.py b/python/tvm/_ffi/_ctypes/ndarray.py index c572947c8d19..949cc8b1a987 100644 --- a/python/tvm/_ffi/_ctypes/ndarray.py +++ b/python/tvm/_ffi/_ctypes/ndarray.py @@ -16,8 +16,6 @@ # under the License. # pylint: disable=invalid-name """Runtime NDArray api""" -from __future__ import absolute_import - import ctypes from ..base import _LIB, check_call, c_str from ..runtime_ctypes import TVMArrayHandle diff --git a/python/tvm/_ffi/_ctypes/object.py b/python/tvm/_ffi/_ctypes/object.py index 8a2fb1b5363e..907b7ddef616 100644 --- a/python/tvm/_ffi/_ctypes/object.py +++ b/python/tvm/_ffi/_ctypes/object.py @@ -16,8 +16,6 @@ # under the License. # pylint: disable=invalid-name """Runtime Object api""" -from __future__ import absolute_import - import ctypes from ..base import _LIB, check_call from .types import TypeCode, RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func diff --git a/python/tvm/_ffi/_ctypes/function.py b/python/tvm/_ffi/_ctypes/packed_func.py similarity index 85% rename from python/tvm/_ffi/_ctypes/function.py rename to python/tvm/_ffi/_ctypes/packed_func.py index ee3deada7ce5..5eaa73836b0c 100644 --- a/python/tvm/_ffi/_ctypes/function.py +++ b/python/tvm/_ffi/_ctypes/packed_func.py @@ -17,15 +17,12 @@ # coding: utf-8 # pylint: disable=invalid-name, protected-access, too-many-branches, global-statement, unused-import """Function configuration API.""" -from __future__ import absolute_import - import ctypes import traceback from numbers import Number, Integral -from ..base import _LIB, get_last_ffi_error, py2cerror +from ..base import _LIB, get_last_ffi_error, py2cerror, check_call from ..base import c_str, string_types -from ..object_generic import convert_to_object, ObjectGeneric from ..runtime_ctypes import TVMType, TVMByteArray, TVMContext from . import ndarray as _nd from .ndarray import NDArrayBase, _make_array @@ -35,7 +32,7 @@ from .object import ObjectBase, _set_class_object from . import object as _object -FunctionHandle = ctypes.c_void_p +PackedFuncHandle = ctypes.c_void_p ModuleHandle = ctypes.c_void_p ObjectHandle = ctypes.c_void_p TVMRetValueHandle = ctypes.c_void_p @@ -49,6 +46,15 @@ def _ctypes_free_resource(rhandle): TVM_FREE_PYOBJ = TVMCFuncFinalizer(_ctypes_free_resource) ctypes.pythonapi.Py_IncRef(ctypes.py_object(TVM_FREE_PYOBJ)) + +def _make_packed_func(handle, is_global): + """Make a packed function class""" + obj = _CLASS_PACKED_FUNC.__new__(_CLASS_PACKED_FUNC) + obj.is_global = is_global + obj.handle = handle + return obj + + def convert_to_tvm_func(pyfunc): """Convert a python function to TVM function @@ -89,7 +95,7 @@ def cfun(args, type_codes, num_args, ret, _): _ = rv return 0 - handle = FunctionHandle() + handle = PackedFuncHandle() f = TVMPackedCFunc(cfun) # NOTE: We will need to use python-api to increase ref count of the f # TVM_FREE_PYOBJ will be called after it is no longer needed. @@ -98,7 +104,7 @@ def cfun(args, type_codes, num_args, ret, _): if _LIB.TVMFuncCreateFromCFunc( f, pyobj, TVM_FREE_PYOBJ, ctypes.byref(handle)) != 0: raise get_last_ffi_error() - return _CLASS_FUNCTION(handle, False) + return _make_packed_func(handle, False) def _make_tvm_args(args, temp_args): @@ -144,15 +150,15 @@ def _make_tvm_args(args, temp_args): elif isinstance(arg, string_types): values[i].v_str = c_str(arg) type_codes[i] = TypeCode.STR - elif isinstance(arg, (list, tuple, dict, ObjectGeneric)): - arg = convert_to_object(arg) + elif isinstance(arg, (list, tuple, dict, _CLASS_OBJECT_GENERIC)): + arg = _FUNC_CONVERT_TO_OBJECT(arg) values[i].v_handle = arg.handle type_codes[i] = TypeCode.OBJECT_HANDLE temp_args.append(arg) elif isinstance(arg, _CLASS_MODULE): values[i].v_handle = arg.handle type_codes[i] = TypeCode.MODULE_HANDLE - elif isinstance(arg, FunctionBase): + elif isinstance(arg, PackedFuncBase): values[i].v_handle = arg.handle type_codes[i] = TypeCode.PACKED_FUNC_HANDLE elif isinstance(arg, ctypes.c_void_p): @@ -168,7 +174,7 @@ def _make_tvm_args(args, temp_args): return values, type_codes, num_args -class FunctionBase(object): +class PackedFuncBase(object): """Function base.""" __slots__ = ["handle", "is_global"] # pylint: disable=no-member @@ -177,7 +183,7 @@ def __init__(self, handle, is_global): Parameters ---------- - handle : FunctionHandle + handle : PackedFuncHandle the handle to the underlying function. is_global : bool @@ -238,9 +244,22 @@ def _return_module(x): def _handle_return_func(x): """Return function""" handle = x.v_handle - if not isinstance(handle, FunctionHandle): - handle = FunctionHandle(handle) - return _CLASS_FUNCTION(handle, False) + if not isinstance(handle, PackedFuncHandle): + handle = PackedFuncHandle(handle) + return _CLASS_PACKED_FUNC(handle, False) + + +def _get_global_func(name, allow_missing=False): + handle = PackedFuncHandle() + check_call(_LIB.TVMFuncGetGlobal(c_str(name), ctypes.byref(handle))) + + if handle.value: + return _make_packed_func(handle, False) + + if allow_missing: + return None + + raise ValueError("Cannot find global function %s" % name) # setup return handle for function type _object.__init_by_constructor__ = __init_handle_by_constructor__ @@ -255,13 +274,22 @@ def _handle_return_func(x): C_TO_PY_ARG_SWITCH[TypeCode.NDARRAY_HANDLE] = lambda x: _make_array(x.v_handle, False, True) _CLASS_MODULE = None -_CLASS_FUNCTION = None +_CLASS_PACKED_FUNC = None +_CLASS_OBJECT_GENERIC = None +_FUNC_CONVERT_TO_OBJECT = None + def _set_class_module(module_class): """Initialize the module.""" global _CLASS_MODULE _CLASS_MODULE = module_class -def _set_class_function(func_class): - global _CLASS_FUNCTION - _CLASS_FUNCTION = func_class +def _set_class_packed_func(packed_func_class): + global _CLASS_PACKED_FUNC + _CLASS_PACKED_FUNC = packed_func_class + +def _set_class_object_generic(object_generic_class, func_convert_to_object): + global _CLASS_OBJECT_GENERIC + global _FUNC_CONVERT_TO_OBJECT + _CLASS_OBJECT_GENERIC = object_generic_class + _FUNC_CONVERT_TO_OBJECT = func_convert_to_object diff --git a/python/tvm/_ffi/_ctypes/types.py b/python/tvm/_ffi/_ctypes/types.py index 31c4786b858f..f45748fdd4de 100644 --- a/python/tvm/_ffi/_ctypes/types.py +++ b/python/tvm/_ffi/_ctypes/types.py @@ -16,8 +16,6 @@ # under the License. """The C Types used in API.""" # pylint: disable=invalid-name -from __future__ import absolute_import as _abs - import ctypes import struct from ..base import py_str, check_call, _LIB diff --git a/python/tvm/_ffi/_cython/base.pxi b/python/tvm/_ffi/_cython/base.pxi index 420ec6221ad9..ad281d7512d9 100644 --- a/python/tvm/_ffi/_cython/base.pxi +++ b/python/tvm/_ffi/_cython/base.pxi @@ -75,7 +75,7 @@ ctypedef int64_t tvm_index_t ctypedef DLTensor* DLTensorHandle ctypedef void* TVMStreamHandle ctypedef void* TVMRetValueHandle -ctypedef void* TVMFunctionHandle +ctypedef void* TVMPackedFuncHandle ctypedef void* ObjectHandle ctypedef struct TVMObject: @@ -96,13 +96,15 @@ ctypedef void (*TVMPackedCFuncFinalizer)(void* resource_handle) cdef extern from "tvm/runtime/c_runtime_api.h": void TVMAPISetLastError(const char* msg) const char *TVMGetLastError() - int TVMFuncCall(TVMFunctionHandle func, + int TVMFuncGetGlobal(const char* name, + TVMPackedFuncHandle* out); + int TVMFuncCall(TVMPackedFuncHandle func, TVMValue* arg_values, int* type_codes, int num_args, TVMValue* ret_val, int* ret_type_code) - int TVMFuncFree(TVMFunctionHandle func) + int TVMFuncFree(TVMPackedFuncHandle func) int TVMCFuncSetReturn(TVMRetValueHandle ret, TVMValue* value, int* type_code, @@ -110,7 +112,7 @@ cdef extern from "tvm/runtime/c_runtime_api.h": int TVMFuncCreateFromCFunc(TVMPackedCFunc func, void* resource_handle, TVMPackedCFuncFinalizer fin, - TVMFunctionHandle *out) + TVMPackedFuncHandle *out) int TVMCbArgToReturn(TVMValue* value, int code) int TVMArrayAlloc(tvm_index_t* shape, tvm_index_t ndim, diff --git a/python/tvm/_ffi/_cython/core.pyx b/python/tvm/_ffi/_cython/core.pyx index cbf9d5859046..730f8fc13345 100644 --- a/python/tvm/_ffi/_cython/core.pyx +++ b/python/tvm/_ffi/_cython/core.pyx @@ -17,7 +17,5 @@ include "./base.pxi" include "./object.pxi" -# include "./node.pxi" -include "./function.pxi" +include "./packed_func.pxi" include "./ndarray.pxi" - diff --git a/python/tvm/_ffi/_cython/object.pxi b/python/tvm/_ffi/_cython/object.pxi index 494c3ff47c8e..25a9c3fb70cf 100644 --- a/python/tvm/_ffi/_cython/object.pxi +++ b/python/tvm/_ffi/_cython/object.pxi @@ -96,6 +96,6 @@ cdef class ObjectBase: self.chandle = NULL cdef void* chandle ConstructorCall( - (fconstructor).chandle, + (fconstructor).chandle, kTVMObjectHandle, args, &chandle) self.chandle = chandle diff --git a/python/tvm/_ffi/_cython/function.pxi b/python/tvm/_ffi/_cython/packed_func.pxi similarity index 87% rename from python/tvm/_ffi/_cython/function.pxi rename to python/tvm/_ffi/_cython/packed_func.pxi index bde672f02168..5630d72e9ed9 100644 --- a/python/tvm/_ffi/_cython/function.pxi +++ b/python/tvm/_ffi/_cython/packed_func.pxi @@ -20,7 +20,6 @@ import traceback from cpython cimport Py_INCREF, Py_DECREF from numbers import Number, Integral from ..base import string_types, py2cerror -from ..object_generic import convert_to_object, ObjectGeneric from ..runtime_ctypes import TVMType, TVMContext, TVMByteArray @@ -67,6 +66,13 @@ cdef int tvm_callback(TVMValue* args, return 0 +cdef object make_packed_func(TVMPackedFuncHandle chandle, int is_global): + obj = _CLASS_PACKED_FUNC.__new__(_CLASS_PACKED_FUNC) + (obj).chandle = chandle + (obj).is_global = is_global + return obj + + def convert_to_tvm_func(object pyfunc): """Convert a python function to TVM function @@ -80,15 +86,13 @@ def convert_to_tvm_func(object pyfunc): tvmfunc: tvm.Function The converted tvm function. """ - cdef TVMFunctionHandle chandle + cdef TVMPackedFuncHandle chandle Py_INCREF(pyfunc) CALL(TVMFuncCreateFromCFunc(tvm_callback, (pyfunc), tvm_callback_finalize, &chandle)) - ret = _CLASS_FUNCTION(None, False) - (ret).chandle = chandle - return ret + return make_packed_func(chandle, False) cdef inline int make_arg(object arg, @@ -149,29 +153,30 @@ cdef inline int make_arg(object arg, value[0].v_str = tstr tcode[0] = kTVMStr temp_args.append(tstr) - elif isinstance(arg, (list, tuple, dict, ObjectGeneric)): - arg = convert_to_object(arg) + elif isinstance(arg, (list, tuple, dict, _CLASS_OBJECT_GENERIC)): + arg = _FUNC_CONVERT_TO_OBJECT(arg) value[0].v_handle = (arg).chandle tcode[0] = kTVMObjectHandle temp_args.append(arg) elif isinstance(arg, _CLASS_MODULE): value[0].v_handle = c_handle(arg.handle) tcode[0] = kTVMModuleHandle - elif isinstance(arg, FunctionBase): - value[0].v_handle = (arg).chandle + elif isinstance(arg, PackedFuncBase): + value[0].v_handle = (arg).chandle tcode[0] = kTVMPackedFuncHandle elif isinstance(arg, ctypes.c_void_p): value[0].v_handle = c_handle(arg) tcode[0] = kTVMOpaqueHandle elif callable(arg): arg = convert_to_tvm_func(arg) - value[0].v_handle = (arg).chandle + value[0].v_handle = (arg).chandle tcode[0] = kTVMPackedFuncHandle temp_args.append(arg) else: raise TypeError("Don't know how to handle type %s" % type(arg)) return 0 + cdef inline bytearray make_ret_bytes(void* chandle): handle = ctypes_handle(chandle) arr = ctypes.cast(handle, ctypes.POINTER(TVMByteArray))[0] @@ -182,6 +187,7 @@ cdef inline bytearray make_ret_bytes(void* chandle): raise RuntimeError('memmove failed') return res + cdef inline object make_ret(TVMValue value, int tcode): """convert result to return value.""" if tcode == kTVMObjectHandle: @@ -205,9 +211,7 @@ cdef inline object make_ret(TVMValue value, int tcode): elif tcode == kTVMModuleHandle: return _CLASS_MODULE(ctypes_handle(value.v_handle)) elif tcode == kTVMPackedFuncHandle: - fobj = _CLASS_FUNCTION(None, False) - (fobj).chandle = value.v_handle - return fobj + return make_packed_func(value.v_handle, False) elif tcode in _TVM_EXT_RET: return _TVM_EXT_RET[tcode](ctypes_handle(value.v_handle)) @@ -264,8 +268,8 @@ cdef inline int ConstructorCall(void* constructor_handle, return 0 -cdef class FunctionBase: - cdef TVMFunctionHandle chandle +cdef class PackedFuncBase: + cdef TVMPackedFuncHandle chandle cdef int is_global cdef inline _set_handle(self, handle): @@ -305,19 +309,39 @@ cdef class FunctionBase: return make_ret(ret_val, ret_tcode) -_CLASS_FUNCTION = None +def _get_global_func(name, allow_missing): + cdef TVMPackedFuncHandle chandle + CALL(TVMFuncGetGlobal(c_str(name), &chandle)) + if chandle != NULL: + return make_packed_func(chandle, True) + + if allow_missing: + return None + + raise ValueError("Cannot find global function %s" % name) + + +_CLASS_PACKED_FUNC = None _CLASS_MODULE = None _CLASS_OBJECT = None +_CLASS_OBJECT_GENERIC = None +_FUNC_CONVERT_TO_OBJECT = None def _set_class_module(module_class): """Initialize the module.""" global _CLASS_MODULE _CLASS_MODULE = module_class -def _set_class_function(func_class): - global _CLASS_FUNCTION - _CLASS_FUNCTION = func_class +def _set_class_packed_func(func_class): + global _CLASS_PACKED_FUNC + _CLASS_PACKED_FUNC = func_class def _set_class_object(obj_class): global _CLASS_OBJECT _CLASS_OBJECT = obj_class + +def _set_class_object_generic(object_generic_class, func_convert_to_object): + global _CLASS_OBJECT_GENERIC + global _FUNC_CONVERT_TO_OBJECT + _CLASS_OBJECT_GENERIC = object_generic_class + _FUNC_CONVERT_TO_OBJECT = func_convert_to_object diff --git a/python/tvm/_pyversion.py b/python/tvm/_ffi/_pyversion.py similarity index 93% rename from python/tvm/_pyversion.py rename to python/tvm/_ffi/_pyversion.py index a46b22028387..67591b357ca8 100644 --- a/python/tvm/_pyversion.py +++ b/python/tvm/_ffi/_pyversion.py @@ -18,6 +18,9 @@ """ import sys +#---------------------------- +# Python3 version. +#---------------------------- if not (sys.version_info[0] >= 3 and sys.version_info[1] >= 5): PY3STATEMENT = """TVM project proudly dropped support of Python2. The minimal Python requirement is Python 3.5 diff --git a/python/tvm/_ffi/base.py b/python/tvm/_ffi/base.py index 36effa3800a8..ddc942ada40d 100644 --- a/python/tvm/_ffi/base.py +++ b/python/tvm/_ffi/base.py @@ -17,8 +17,6 @@ # coding: utf-8 # pylint: disable=invalid-name """Base library for TVM FFI.""" -from __future__ import absolute_import - import sys import os import ctypes @@ -28,27 +26,22 @@ #---------------------------- # library loading #---------------------------- -if sys.version_info[0] == 3: - string_types = (str,) - integer_types = (int, np.int32) - numeric_types = integer_types + (float, np.float32) - # this function is needed for python3 - # to convert ctypes.char_p .value back to python str - if sys.platform == "win32": - def _py_str(x): - try: - return x.decode('utf-8') - except UnicodeDecodeError: - encoding = 'cp' + str(ctypes.cdll.kernel32.GetACP()) - return x.decode(encoding) - py_str = _py_str - else: - py_str = lambda x: x.decode('utf-8') +string_types = (str,) +integer_types = (int, np.int32) +numeric_types = integer_types + (float, np.float32) + +# this function is needed for python3 +# to convert ctypes.char_p .value back to python str +if sys.platform == "win32": + def _py_str(x): + try: + return x.decode('utf-8') + except UnicodeDecodeError: + encoding = 'cp' + str(ctypes.cdll.kernel32.GetACP()) + return x.decode(encoding) + py_str = _py_str else: - string_types = (basestring,) - integer_types = (int, long, np.int32) - numeric_types = integer_types + (float, np.float32) - py_str = lambda x: x + py_str = lambda x: x.decode('utf-8') def _load_lib(): diff --git a/python/tvm/_ffi/libinfo.py b/python/tvm/_ffi/libinfo.py index e87a3368268e..c026a7afffe9 100644 --- a/python/tvm/_ffi/libinfo.py +++ b/python/tvm/_ffi/libinfo.py @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. """Library information.""" -from __future__ import absolute_import import sys import os @@ -39,6 +38,7 @@ def split_env_var(env_var, split): return [p.strip() for p in os.environ[env_var].split(split)] return [] + def find_lib_path(name=None, search_path=None, optional=False): """Find dynamic library files. diff --git a/python/tvm/_ffi/module.py b/python/tvm/_ffi/module.py new file mode 100644 index 000000000000..d6c81b3cfe33 --- /dev/null +++ b/python/tvm/_ffi/module.py @@ -0,0 +1,98 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# pylint: disable=invalid-name, unused-import +"""Runtime Module namespace.""" +import ctypes +from .base import _LIB, check_call, c_str, string_types +from .packed_func import PackedFunc, PackedFuncHandle, _set_class_module + +class ModuleBase(object): + """Base class for module""" + __slots__ = ["handle", "_entry", "entry_name"] + + def __init__(self, handle): + self.handle = handle + self._entry = None + self.entry_name = "__tvm_main__" + + def __del__(self): + check_call(_LIB.TVMModFree(self.handle)) + + def __hash__(self): + return ctypes.cast(self.handle, ctypes.c_void_p).value + + @property + def entry_func(self): + """Get the entry function + + Returns + ------- + f : Function + The entry function if exist + """ + if self._entry: + return self._entry + self._entry = self.get_function(self.entry_name) + return self._entry + + def get_function(self, name, query_imports=False): + """Get function from the module. + + Parameters + ---------- + name : str + The name of the function + + query_imports : bool + Whether also query modules imported by this module. + + Returns + ------- + f : Function + The result function. + """ + ret_handle = PackedFuncHandle() + check_call(_LIB.TVMModGetFunction( + self.handle, c_str(name), + ctypes.c_int(query_imports), + ctypes.byref(ret_handle))) + if not ret_handle.value: + raise AttributeError( + "Module has no function '%s'" % name) + return PackedFunc(ret_handle, False) + + def import_module(self, module): + """Add module to the import list of current one. + + Parameters + ---------- + module : Module + The other module. + """ + check_call(_LIB.TVMModImport(self.handle, module.handle)) + + def __getitem__(self, name): + if not isinstance(name, string_types): + raise ValueError("Can only take string as function name") + return self.get_function(name) + + def __call__(self, *args): + if self._entry: + return self._entry(*args) + f = self.entry_func + return f(*args) diff --git a/python/tvm/_ffi/ndarray.py b/python/tvm/_ffi/ndarray.py index 650f01dd5409..f526195a6306 100644 --- a/python/tvm/_ffi/ndarray.py +++ b/python/tvm/_ffi/ndarray.py @@ -16,35 +16,22 @@ # under the License. # pylint: disable=invalid-name, unused-import """Runtime NDArray api""" -from __future__ import absolute_import - -import sys import ctypes import numpy as np from .base import _LIB, check_call, c_array, string_types, _FFI_MODE, c_str from .runtime_ctypes import TVMType, TVMContext, TVMArray, TVMArrayHandle from .runtime_ctypes import TypeCode, tvm_shape_index_t - -IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError - try: # pylint: disable=wrong-import-position if _FFI_MODE == "ctypes": raise ImportError() - if sys.version_info >= (3, 0): - from ._cy3.core import _set_class_ndarray, _make_array, _from_dlpack - from ._cy3.core import NDArrayBase as _NDArrayBase - from ._cy3.core import _reg_extension - else: - from ._cy2.core import _set_class_ndarray, _make_array, _from_dlpack - from ._cy2.core import NDArrayBase as _NDArrayBase - from ._cy2.core import _reg_extension -except IMPORT_EXCEPT: + from ._cy3.core import _set_class_ndarray, _make_array, _from_dlpack + from ._cy3.core import NDArrayBase as _NDArrayBase +except (RuntimeError, ImportError): # pylint: disable=wrong-import-position from ._ctypes.ndarray import _set_class_ndarray, _make_array, _from_dlpack from ._ctypes.ndarray import NDArrayBase as _NDArrayBase - from ._ctypes.ndarray import _reg_extension def context(dev_type, dev_id=0): @@ -297,59 +284,3 @@ def copyto(self, target): res = empty(self.shape, self.dtype, target) return self._copyto(res) raise ValueError("Unsupported target type %s" % str(type(target))) - - -def register_extension(cls, fcreate=None): - """Register a extension class to TVM. - - After the class is registered, the class will be able - to directly pass as Function argument generated by TVM. - - Parameters - ---------- - cls : class - The class object to be registered as extension. - - fcreate : function, optional - The creation function to create a class object given handle value. - - Note - ---- - The registered class is requires one property: _tvm_handle. - - If the registered class is a subclass of NDArray, - it is required to have a class attribute _array_type_code. - Otherwise, it is required to have a class attribute _tvm_tcode. - - - ```_tvm_handle``` returns integer represents the address of the handle. - - ```_tvm_tcode``` or ```_array_type_code``` gives integer represents type - code of the class. - - Returns - ------- - cls : class - The class being registered. - - Example - ------- - The following code registers user defined class - MyTensor to be DLTensor compatible. - - .. code-block:: python - - @tvm.register_extension - class MyTensor(object): - _tvm_tcode = tvm.TypeCode.ARRAY_HANDLE - - def __init__(self): - self.handle = _LIB.NewDLTensor() - - @property - def _tvm_handle(self): - return self.handle.value - """ - assert hasattr(cls, "_tvm_tcode") - if fcreate and cls._tvm_tcode < TypeCode.EXT_BEGIN: - raise ValueError("Cannot register create when extension tcode is same as buildin") - _reg_extension(cls, fcreate) - return cls diff --git a/python/tvm/_ffi/object.py b/python/tvm/_ffi/object.py index 83d4129a7140..a80858058cb6 100644 --- a/python/tvm/_ffi/object.py +++ b/python/tvm/_ffi/object.py @@ -16,33 +16,20 @@ # under the License. # pylint: disable=invalid-name, unused-import """Runtime Object API""" -from __future__ import absolute_import - -import sys import ctypes from .. import _api_internal from .base import _FFI_MODE, _RUNTIME_ONLY, check_call, _LIB, c_str -from .object_generic import ObjectGeneric, convert_to_object, const - -IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError try: # pylint: disable=wrong-import-position,unused-import if _FFI_MODE == "ctypes": raise ImportError() - if sys.version_info >= (3, 0): - from ._cy3.core import _set_class_object - from ._cy3.core import ObjectBase as _ObjectBase - from ._cy3.core import _register_object - else: - from ._cy2.core import _set_class_object - from ._cy2.core import ObjectBase as _ObjectBase - from ._cy2.core import _register_object -except IMPORT_EXCEPT: + from ._cy3.core import _set_class_object, _set_class_object_generic + from ._cy3.core import ObjectBase +except (RuntimeError, ImportError): # pylint: disable=wrong-import-position,unused-import - from ._ctypes.function import _set_class_object - from ._ctypes.object import ObjectBase as _ObjectBase - from ._ctypes.object import _register_object + from ._ctypes.packed_func import _set_class_object, _set_class_object_generic + from ._ctypes.object import ObjectBase def _new_object(cls): @@ -50,7 +37,7 @@ def _new_object(cls): return cls.__new__(cls) -class Object(_ObjectBase): +class Object(ObjectBase): """Base class for all tvm's runtime objects.""" def __repr__(self): return _api_internal._format_str(self) @@ -104,52 +91,6 @@ def same_as(self, other): return self.__hash__() == other.__hash__() -def register_object(type_key=None): - """register object type. - - Parameters - ---------- - type_key : str or cls - The type key of the node - - Examples - -------- - The following code registers MyObject - using type key "test.MyObject" - - .. code-block:: python - - @tvm.register_object("test.MyObject") - class MyObject(Object): - pass - """ - object_name = type_key if isinstance(type_key, str) else type_key.__name__ - - def register(cls): - """internal register function""" - if hasattr(cls, "_type_index"): - tindex = cls._type_index - else: - tidx = ctypes.c_uint() - if not _RUNTIME_ONLY: - check_call(_LIB.TVMObjectTypeKey2Index( - c_str(object_name), ctypes.byref(tidx))) - else: - # directly skip unknown objects during runtime. - ret = _LIB.TVMObjectTypeKey2Index( - c_str(object_name), ctypes.byref(tidx)) - if ret != 0: - return cls - tindex = tidx.value - _register_object(tindex, cls) - return cls - - if isinstance(type_key, str): - return register - - return register(type_key) - - def getitem_helper(obj, elem_getter, length, idx): """Helper function to implement a pythonic getitem function. diff --git a/python/tvm/_ffi/object_generic.py b/python/tvm/_ffi/object_generic.py index 92e73ad79e88..cbbca4dd34a6 100644 --- a/python/tvm/_ffi/object_generic.py +++ b/python/tvm/_ffi/object_generic.py @@ -16,35 +16,14 @@ # under the License. """Common implementation of object generic related logic""" # pylint: disable=unused-import -from __future__ import absolute_import - from numbers import Number, Integral from .. import _api_internal -from .base import string_types - -# Object base class -_CLASS_OBJECTS = None - -def _set_class_objects(cls): - global _CLASS_OBJECTS - _CLASS_OBJECTS = cls - -def _scalar_type_inference(value): - if hasattr(value, 'dtype'): - dtype = str(value.dtype) - elif isinstance(value, bool): - dtype = 'bool' - elif isinstance(value, float): - # We intentionally convert the float to float32 since it's more common in DL. - dtype = 'float32' - elif isinstance(value, int): - # We intentionally convert the python int to int32 since it's more common in DL. - dtype = 'int32' - else: - raise NotImplementedError('Cannot automatically inference the type.' - ' value={}'.format(value)) - return dtype +from .base import string_types +from .object import ObjectBase, _set_class_object_generic +from .ndarray import NDArrayBase +from .packed_func import PackedFuncBase, convert_to_tvm_func +from .module import ModuleBase class ObjectGeneric(object): @@ -54,6 +33,9 @@ def asobject(self): raise NotImplementedError() +_CLASS_OBJECTS = (ObjectBase, NDArrayBase, ModuleBase) + + def convert_to_object(value): """Convert a python value to corresponding object type. @@ -95,22 +77,65 @@ def convert_to_object(value): raise ValueError("don't know how to convert type %s to object" % type(value)) +def convert(value): + """Convert value to TVM object or function. + + Parameters + ---------- + value : python value + + Returns + ------- + tvm_val : Object or Function + Converted value in TVM + """ + if isinstance(value, (PackedFuncBase, ObjectBase)): + return value + + if callable(value): + return convert_to_tvm_func(value) + + return convert_to_object(value) + + +def _scalar_type_inference(value): + if hasattr(value, 'dtype'): + dtype = str(value.dtype) + elif isinstance(value, bool): + dtype = 'bool' + elif isinstance(value, float): + # We intentionally convert the float to float32 since it's more common in DL. + dtype = 'float32' + elif isinstance(value, int): + # We intentionally convert the python int to int32 since it's more common in DL. + dtype = 'int32' + else: + raise NotImplementedError('Cannot automatically inference the type.' + ' value={}'.format(value)) + return dtype + def const(value, dtype=None): - """Construct a constant value for a given type. + """construct a constant Parameters ---------- - value : int or float - The input value + value : number + The content of the constant number. dtype : str or None, optional The data type. Returns ------- - expr : Expr - Constant expression corresponds to the value. + const_val: tvm.Expr + The result expression. """ if dtype is None: dtype = _scalar_type_inference(value) + if dtype == "uint64" and value >= (1 << 63): + return _api_internal._LargeUIntImm( + dtype, value & ((1 << 32) - 1), value >> 32) return _api_internal._const(value, dtype) + + +_set_class_object_generic(ObjectGeneric, convert_to_object) diff --git a/python/tvm/_ffi/packed_func.py b/python/tvm/_ffi/packed_func.py new file mode 100644 index 000000000000..d0917a8d2965 --- /dev/null +++ b/python/tvm/_ffi/packed_func.py @@ -0,0 +1,63 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# pylint: disable=invalid-name, unused-import +"""Packed Function namespace.""" +import ctypes +from .base import _LIB, check_call, c_str, string_types, _FFI_MODE + +try: + # pylint: disable=wrong-import-position + if _FFI_MODE == "ctypes": + raise ImportError() + from ._cy3.core import _set_class_packed_func, _set_class_module + from ._cy3.core import PackedFuncBase + from ._cy3.core import convert_to_tvm_func +except (RuntimeError, ImportError): + # pylint: disable=wrong-import-position + from ._ctypes.packed_func import _set_class_packed_func, _set_class_module + from ._ctypes.packed_func import PackedFuncBase + from ._ctypes.packed_func import convert_to_tvm_func + + +PackedFuncHandle = ctypes.c_void_p + +class PackedFunc(PackedFuncBase): + """The PackedFunc object used in TVM. + + Function plays an key role to bridge front and backend in TVM. + Function provide a type-erased interface, you can call function with positional arguments. + + The compiled module returns Function. + TVM backend also registers and exposes its API as Functions. + For example, the developer function exposed in tvm.ir_pass are actually + C++ functions that are registered as PackedFunc + + The following are list of common usage scenario of tvm.Function. + + - Automatic exposure of C++ API into python + - To call PackedFunc from python side + - To call python callbacks to inspect results in generated code + - Bring python hook into C++ backend + + See Also + -------- + tvm.register_func: How to register global function. + tvm.get_global_func: How to get global function. + """ + +_set_class_packed_func(PackedFunc) diff --git a/python/tvm/_ffi/function.py b/python/tvm/_ffi/registry.py similarity index 53% rename from python/tvm/_ffi/function.py rename to python/tvm/_ffi/registry.py index 22e03563976b..be1578550a3b 100644 --- a/python/tvm/_ffi/function.py +++ b/python/tvm/_ffi/registry.py @@ -16,142 +16,127 @@ # under the License. # pylint: disable=invalid-name, unused-import -"""Function namespace.""" -from __future__ import absolute_import - +"""FFI registry to register function and objects.""" import sys import ctypes -from .base import _LIB, check_call, py_str, c_str, string_types, _FFI_MODE -from .object_generic import _set_class_objects +from .. import _api_internal -IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError +from .base import _LIB, check_call, py_str, c_str, string_types, _FFI_MODE, _RUNTIME_ONLY try: - # pylint: disable=wrong-import-position + # pylint: disable=wrong-import-position,unused-import if _FFI_MODE == "ctypes": raise ImportError() - if sys.version_info >= (3, 0): - from ._cy3.core import _set_class_function, _set_class_module - from ._cy3.core import FunctionBase as _FunctionBase - from ._cy3.core import NDArrayBase as _NDArrayBase - from ._cy3.core import ObjectBase as _ObjectBase - from ._cy3.core import convert_to_tvm_func - else: - from ._cy2.core import _set_class_function, _set_class_module - from ._cy2.core import FunctionBase as _FunctionBase - from ._cy2.core import NDArrayBase as _NDArrayBase - from ._cy2.core import ObjectBase as _ObjectBase - from ._cy2.core import convert_to_tvm_func -except IMPORT_EXCEPT: - # pylint: disable=wrong-import-position - from ._ctypes.function import _set_class_function, _set_class_module - from ._ctypes.function import FunctionBase as _FunctionBase - from ._ctypes.ndarray import NDArrayBase as _NDArrayBase - from ._ctypes.object import ObjectBase as _ObjectBase - from ._ctypes.function import convert_to_tvm_func - -FunctionHandle = ctypes.c_void_p - -class Function(_FunctionBase): - """The PackedFunc object used in TVM. - - Function plays an key role to bridge front and backend in TVM. - Function provide a type-erased interface, you can call function with positional arguments. - - The compiled module returns Function. - TVM backend also registers and exposes its API as Functions. - For example, the developer function exposed in tvm.ir_pass are actually - C++ functions that are registered as PackedFunc - - The following are list of common usage scenario of tvm.Function. - - - Automatic exposure of C++ API into python - - To call PackedFunc from python side - - To call python callbacks to inspect results in generated code - - Bring python hook into C++ backend - - See Also + from ._cy3.core import _register_object + from ._cy3.core import _reg_extension + from ._cy3.core import convert_to_tvm_func, _get_global_func, PackedFuncBase +except (RuntimeError, ImportError): + # pylint: disable=wrong-import-position,unused-import + from ._ctypes.object import _register_object + from ._ctypes.ndarray import _reg_extension + from ._ctypes.packed_func import convert_to_tvm_func, _get_global_func, PackedFuncBase + + +def register_object(type_key=None): + """register object type. + + Parameters + ---------- + type_key : str or cls + The type key of the node + + Examples -------- - tvm.register_func: How to register global function. - tvm.get_global_func: How to get global function. + The following code registers MyObject + using type key "test.MyObject" + + .. code-block:: python + + @tvm.register_object("test.MyObject") + class MyObject(Object): + pass """ + object_name = type_key if isinstance(type_key, str) else type_key.__name__ + + def register(cls): + """internal register function""" + if hasattr(cls, "_type_index"): + tindex = cls._type_index + else: + tidx = ctypes.c_uint() + if not _RUNTIME_ONLY: + check_call(_LIB.TVMObjectTypeKey2Index( + c_str(object_name), ctypes.byref(tidx))) + else: + # directly skip unknown objects during runtime. + ret = _LIB.TVMObjectTypeKey2Index( + c_str(object_name), ctypes.byref(tidx)) + if ret != 0: + return cls + tindex = tidx.value + _register_object(tindex, cls) + return cls + + if isinstance(type_key, str): + return register + + return register(type_key) + + +def register_extension(cls, fcreate=None): + """Register a extension class to TVM. + + After the class is registered, the class will be able + to directly pass as Function argument generated by TVM. + Parameters + ---------- + cls : class + The class object to be registered as extension. + + fcreate : function, optional + The creation function to create a class object given handle value. -class ModuleBase(object): - """Base class for module""" - __slots__ = ["handle", "_entry", "entry_name"] - - def __init__(self, handle): - self.handle = handle - self._entry = None - self.entry_name = "__tvm_main__" - - def __del__(self): - check_call(_LIB.TVMModFree(self.handle)) - - def __hash__(self): - return ctypes.cast(self.handle, ctypes.c_void_p).value - - @property - def entry_func(self): - """Get the entry function - - Returns - ------- - f : Function - The entry function if exist - """ - if self._entry: - return self._entry - self._entry = self.get_function(self.entry_name) - return self._entry - - def get_function(self, name, query_imports=False): - """Get function from the module. - - Parameters - ---------- - name : str - The name of the function - - query_imports : bool - Whether also query modules imported by this module. - - Returns - ------- - f : Function - The result function. - """ - ret_handle = FunctionHandle() - check_call(_LIB.TVMModGetFunction( - self.handle, c_str(name), - ctypes.c_int(query_imports), - ctypes.byref(ret_handle))) - if not ret_handle.value: - raise AttributeError( - "Module has no function '%s'" % name) - return Function(ret_handle, False) - - def import_module(self, module): - """Add module to the import list of current one. - - Parameters - ---------- - module : Module - The other module. - """ - check_call(_LIB.TVMModImport(self.handle, module.handle)) - - def __getitem__(self, name): - if not isinstance(name, string_types): - raise ValueError("Can only take string as function name") - return self.get_function(name) - - def __call__(self, *args): - if self._entry: - return self._entry(*args) - f = self.entry_func - return f(*args) + Note + ---- + The registered class is requires one property: _tvm_handle. + + If the registered class is a subclass of NDArray, + it is required to have a class attribute _array_type_code. + Otherwise, it is required to have a class attribute _tvm_tcode. + + - ```_tvm_handle``` returns integer represents the address of the handle. + - ```_tvm_tcode``` or ```_array_type_code``` gives integer represents type + code of the class. + + Returns + ------- + cls : class + The class being registered. + + Example + ------- + The following code registers user defined class + MyTensor to be DLTensor compatible. + + .. code-block:: python + + @tvm.register_extension + class MyTensor(object): + _tvm_tcode = tvm.TypeCode.ARRAY_HANDLE + + def __init__(self): + self.handle = _LIB.NewDLTensor() + + @property + def _tvm_handle(self): + return self.handle.value + """ + assert hasattr(cls, "_tvm_tcode") + if fcreate and cls._tvm_tcode < TypeCode.EXT_BEGIN: + raise ValueError("Cannot register create when extension tcode is same as buildin") + _reg_extension(cls, fcreate) + return cls def register_func(func_name, f=None, override=False): @@ -189,7 +174,7 @@ def my_packed_func(*args): return 10 # Get it out from global function table f = tvm.get_global_func("my_packed_func") - assert isinstance(f, tvm.nd.Function) + assert isinstance(f, tvm.PackedFunc) y = f(*targs) assert y == 10 """ @@ -203,7 +188,7 @@ def my_packed_func(*args): ioverride = ctypes.c_int(override) def register(myf): """internal register function""" - if not isinstance(myf, Function): + if not isinstance(myf, PackedFuncBase): myf = convert_to_tvm_func(myf) check_call(_LIB.TVMFuncRegisterGlobal( c_str(func_name), myf.handle, ioverride)) @@ -226,19 +211,10 @@ def get_global_func(name, allow_missing=False): Returns ------- - func : tvm.Function + func : PackedFunc The function to be returned, None if function is missing. """ - handle = FunctionHandle() - check_call(_LIB.TVMFuncGetGlobal(c_str(name), ctypes.byref(handle))) - if handle.value: - return Function(handle, False) - - if allow_missing: - return None - - raise ValueError("Cannot find global function %s" % name) - + return _get_global_func(name, allow_missing) def list_global_func_names(): @@ -290,6 +266,7 @@ def _get_api(f): flocal.is_global = True return flocal + def _init_api(namespace, target_module_name=None): """Initialize api for a given module name @@ -330,6 +307,3 @@ def _init_api_prefix(module_name, prefix): ff.__name__ = fname ff.__doc__ = ("TVM PackedFunc %s. " % fname) setattr(target_module, ff.__name__, ff) - -_set_class_function(Function) -_set_class_objects((_ObjectBase, _NDArrayBase, ModuleBase)) diff --git a/python/tvm/_ffi/runtime_ctypes.py b/python/tvm/_ffi/runtime_ctypes.py index 93c260a4f505..d6d9b3aac586 100644 --- a/python/tvm/_ffi/runtime_ctypes.py +++ b/python/tvm/_ffi/runtime_ctypes.py @@ -16,8 +16,6 @@ # under the License. """Common runtime ctypes.""" # pylint: disable=invalid-name -from __future__ import absolute_import - import ctypes import json import numpy as np diff --git a/python/tvm/api.py b/python/tvm/api.py index 46faae361a56..573732ece2f5 100644 --- a/python/tvm/api.py +++ b/python/tvm/api.py @@ -16,17 +16,13 @@ # under the License. """Functions defined in TVM.""" # pylint: disable=invalid-name,unused-import,redefined-builtin -from __future__ import absolute_import as _abs - from numbers import Integral as _Integral +import tvm._ffi + from ._ffi.base import string_types, TVMError -from ._ffi.object import register_object, Object -from ._ffi.object import convert_to_object as _convert_to_object -from ._ffi.object_generic import _scalar_type_inference -from ._ffi.function import Function -from ._ffi.function import _init_api, register_func, get_global_func, extract_ext_funcs -from ._ffi.function import convert_to_tvm_func as _convert_tvm_func +from ._ffi.object_generic import convert, const +from ._ffi.registry import register_func, get_global_func, extract_ext_funcs from ._ffi.runtime_ctypes import TVMType from . import _api_internal from . import make as _make @@ -75,30 +71,6 @@ def max_value(dtype): return _api_internal._max_value(dtype) -def const(value, dtype=None): - """construct a constant - - Parameters - ---------- - value : number - The content of the constant number. - - dtype : str or None, optional - The data type. - - Returns - ------- - const_val: tvm.Expr - The result expression. - """ - if dtype is None: - dtype = _scalar_type_inference(value) - if dtype == "uint64" and value >= (1 << 63): - return _api_internal._LargeUIntImm( - dtype, value & ((1 << 32) - 1), value >> 32) - return _api_internal._const(value, dtype) - - def get_env_func(name): """Get an EnvFunc by a global name. @@ -121,27 +93,6 @@ def get_env_func(name): return _api_internal._EnvFuncGet(name) -def convert(value): - """Convert value to TVM node or function. - - Parameters - ---------- - value : python value - - Returns - ------- - tvm_val : Object or Function - Converted value in TVM - """ - if isinstance(value, (Function, Object)): - return value - - if callable(value): - return _convert_tvm_func(value) - - return _convert_to_object(value) - - def load_json(json_str): """Load tvm object from json_str. @@ -1073,10 +1024,9 @@ def floormod(a, b): """ return _make._OpFloorMod(a, b) - -_init_api("tvm.api") - #pylint: disable=unnecessary-lambda sum = comm_reducer(lambda x, y: x+y, lambda t: const(0, dtype=t), name="sum") min = comm_reducer(lambda x, y: _make._OpMin(x, y), max_value, name='min') max = comm_reducer(lambda x, y: _make._OpMax(x, y), min_value, name='max') + +tvm._ffi._init_api("tvm.api") diff --git a/python/tvm/arith.py b/python/tvm/arith.py index 81f478c66b92..7434aee9c5f9 100644 --- a/python/tvm/arith.py +++ b/python/tvm/arith.py @@ -16,9 +16,9 @@ # under the License. """Arithmetic data structure and utility""" from __future__ import absolute_import as _abs +import tvm._ffi -from ._ffi.object import Object, register_object -from ._ffi.function import _init_api +from ._ffi.object import Object from . import _api_internal class IntSet(Object): @@ -32,7 +32,7 @@ def is_everything(self): return _api_internal._IntSetIsEverything(self) -@register_object("arith.IntervalSet") +@tvm._ffi.register_object("arith.IntervalSet") class IntervalSet(IntSet): """Represent set of continuous interval [min_value, max_value] @@ -49,7 +49,7 @@ def __init__(self, min_value, max_value): _make_IntervalSet, min_value, max_value) -@register_object("arith.ModularSet") +@tvm._ffi.register_object("arith.ModularSet") class ModularSet(Object): """Represent range of (coeff * x + base) for x in Z """ def __init__(self, coeff, base): @@ -57,7 +57,7 @@ def __init__(self, coeff, base): _make_ModularSet, coeff, base) -@register_object("arith.ConstIntBound") +@tvm._ffi.register_object("arith.ConstIntBound") class ConstIntBound(Object): """Represent constant integer bound @@ -258,4 +258,4 @@ def update(self, var, info, override=False): "Do not know how to handle type {}".format(type(info))) -_init_api("tvm.arith") +tvm._ffi._init_api("tvm.arith") diff --git a/python/tvm/attrs.py b/python/tvm/attrs.py index 2963a0e21734..78d5b186b2e7 100644 --- a/python/tvm/attrs.py +++ b/python/tvm/attrs.py @@ -15,12 +15,13 @@ # specific language governing permissions and limitations # under the License. """ TVM Attribute module, which is mainly used for defining attributes of operators""" -from ._ffi.object import Object, register_object -from ._ffi.function import _init_api +import tvm._ffi + +from ._ffi.object import Object from . import _api_internal -@register_object +@tvm._ffi.register_object class Attrs(Object): """Attribute node, which is mainly use for defining attributes of relay operators. @@ -92,4 +93,4 @@ def __getitem__(self, item): return self.__getattr__(item) -_init_api("tvm.attrs") +tvm._ffi._init_api("tvm.attrs") diff --git a/python/tvm/build_module.py b/python/tvm/build_module.py index 85d2b8514779..c5097b23533a 100644 --- a/python/tvm/build_module.py +++ b/python/tvm/build_module.py @@ -19,11 +19,10 @@ This module provides the functions to transform schedule to LoweredFunc and compiled Module. """ -from __future__ import absolute_import as _abs import warnings +import tvm._ffi -from ._ffi.function import Function -from ._ffi.object import Object, register_object +from ._ffi.object import Object from . import api from . import _api_internal from . import tensor @@ -115,7 +114,7 @@ def exit(self): DumpIR.scope_level -= 1 -@register_object +@tvm._ffi.register_object class BuildConfig(Object): """Configuration scope to set a build config option. diff --git a/python/tvm/codegen.py b/python/tvm/codegen.py index 61ee1f78f139..7dc7bea90076 100644 --- a/python/tvm/codegen.py +++ b/python/tvm/codegen.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. """Code generation related functions.""" -from ._ffi.function import _init_api +import tvm._ffi def build_module(lowered_func, target): """Build lowered_func into Module. @@ -35,4 +35,4 @@ def build_module(lowered_func, target): """ return _Build(lowered_func, target) -_init_api("tvm.codegen") +tvm._ffi._init_api("tvm.codegen") diff --git a/python/tvm/container.py b/python/tvm/container.py index 673afb428987..b74cc04c75de 100644 --- a/python/tvm/container.py +++ b/python/tvm/container.py @@ -15,13 +15,14 @@ # specific language governing permissions and limitations # under the License. """Container data structures used in TVM DSL.""" -from __future__ import absolute_import as _abs +import tvm._ffi + from tvm import ndarray as _nd from . import _api_internal -from ._ffi.object import Object, register_object, getitem_helper -from ._ffi.function import _init_api +from ._ffi.object import Object, getitem_helper + -@register_object +@tvm._ffi.register_object class Array(Object): """Array container of TVM. @@ -52,7 +53,7 @@ def __len__(self): return _api_internal._ArraySize(self) -@register_object +@tvm._ffi.register_object class EnvFunc(Object): """Environment function. @@ -66,7 +67,7 @@ def func(self): return _api_internal._EnvFuncGetPackedFunc(self) -@register_object +@tvm._ffi.register_object class Map(Object): """Map container of TVM. @@ -89,7 +90,7 @@ def __len__(self): return _api_internal._MapSize(self) -@register_object +@tvm._ffi.register_object class StrMap(Map): """A special map container that has str as key. @@ -101,7 +102,7 @@ def items(self): return [(akvs[i].value, akvs[i+1]) for i in range(0, len(akvs), 2)] -@register_object +@tvm._ffi.register_object class Range(Object): """Represent a range in TVM. @@ -110,7 +111,7 @@ class Range(Object): """ -@register_object +@tvm._ffi.register_object class LoweredFunc(Object): """Represent a LoweredFunc in TVM.""" MixedFunc = 0 @@ -118,7 +119,7 @@ class LoweredFunc(Object): DeviceFunc = 2 -@register_object("vm.ADT") +@tvm._ffi.register_object("vm.ADT") class ADT(Object): """Algebatic data type(ADT) object. @@ -168,4 +169,4 @@ def tuple_object(fields=None): return _Tuple(*fields) -_init_api("tvm.container") +tvm._ffi._init_api("tvm.container") diff --git a/python/tvm/contrib/debugger/debug_runtime.py b/python/tvm/contrib/debugger/debug_runtime.py index 7d150c7c3d34..a5f6e3045491 100644 --- a/python/tvm/contrib/debugger/debug_runtime.py +++ b/python/tvm/contrib/debugger/debug_runtime.py @@ -19,8 +19,9 @@ import os import tempfile import shutil +import tvm._ffi + from tvm._ffi.base import string_types -from tvm._ffi.function import get_global_func from tvm.contrib import graph_runtime from tvm.ndarray import array from . import debug_result @@ -64,7 +65,7 @@ def create(graph_json_str, libmod, ctx, dump_root=None): fcreate = ctx[0]._rpc_sess.get_function( "tvm.graph_runtime_debug.create") else: - fcreate = get_global_func("tvm.graph_runtime_debug.create") + fcreate = tvm._ffi.get_global_func("tvm.graph_runtime_debug.create") except ValueError: raise ValueError( "Please set '(USE_GRAPH_RUNTIME_DEBUG ON)' in " diff --git a/python/tvm/contrib/graph_runtime.py b/python/tvm/contrib/graph_runtime.py index 2c945d2fca95..6b7c099ff705 100644 --- a/python/tvm/contrib/graph_runtime.py +++ b/python/tvm/contrib/graph_runtime.py @@ -16,9 +16,9 @@ # under the License. """Minimum graph runtime that executes graph containing TVM PackedFunc.""" import numpy as np +import tvm._ffi from .._ffi.base import string_types -from .._ffi.function import get_global_func from .._ffi.runtime_ctypes import TVMContext from ..rpc import base as rpc_base @@ -54,7 +54,7 @@ def create(graph_json_str, libmod, ctx): if num_rpc_ctx == len(ctx): fcreate = ctx[0]._rpc_sess.get_function("tvm.graph_runtime.create") else: - fcreate = get_global_func("tvm.graph_runtime.create") + fcreate = tvm._ffi.get_global_func("tvm.graph_runtime.create") return GraphModule(fcreate(graph_json_str, libmod, *device_type_id)) diff --git a/python/tvm/contrib/nnpack.py b/python/tvm/contrib/nnpack.py index aceab6dbfc89..3e2132eb5067 100644 --- a/python/tvm/contrib/nnpack.py +++ b/python/tvm/contrib/nnpack.py @@ -15,11 +15,11 @@ # specific language governing permissions and limitations # under the License. """External function interface to NNPACK libraries.""" -from __future__ import absolute_import as _abs +import tvm._ffi from .. import api as _api from .. import intrin as _intrin -from .._ffi.function import _init_api + def is_available(): """Check whether NNPACK is available, that is, `nnp_initialize()` @@ -202,4 +202,4 @@ def convolution_inference_weight_transform( "tvm.contrib.nnpack.convolution_inference_weight_transform", ins[0], outs[0], nthreads, algorithm), name="transform_kernel", dtype=dtype) -_init_api("tvm.contrib.nnpack") +tvm._ffi._init_api("tvm.contrib.nnpack") diff --git a/python/tvm/contrib/random.py b/python/tvm/contrib/random.py index a57fac0cad68..059bf2344e6b 100644 --- a/python/tvm/contrib/random.py +++ b/python/tvm/contrib/random.py @@ -15,11 +15,10 @@ # specific language governing permissions and limitations # under the License. """External function interface to random library.""" -from __future__ import absolute_import as _abs +import tvm._ffi from .. import api as _api from .. import intrin as _intrin -from .._ffi.function import _init_api def randint(low, high, size, dtype='int32'): @@ -96,4 +95,4 @@ def normal(loc, scale, size): "tvm.contrib.random.normal", float(loc), float(scale), outs[0]), dtype='float32') -_init_api("tvm.contrib.random") +tvm._ffi._init_api("tvm.contrib.random") diff --git a/python/tvm/contrib/tflite_runtime.py b/python/tvm/contrib/tflite_runtime.py index 5ff30a121ff2..985c74787ed1 100644 --- a/python/tvm/contrib/tflite_runtime.py +++ b/python/tvm/contrib/tflite_runtime.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. """TFLite runtime that load and run tflite models.""" -from .._ffi.function import get_global_func +import tvm._ffi from ..rpc import base as rpc_base def create(tflite_model_bytes, ctx, runtime_target='cpu'): @@ -44,7 +44,7 @@ def create(tflite_model_bytes, ctx, runtime_target='cpu'): if device_type >= rpc_base.RPC_SESS_MASK: fcreate = ctx._rpc_sess.get_function(runtime_func) else: - fcreate = get_global_func(runtime_func) + fcreate = tvm._ffi.get_global_func(runtime_func) return TFLiteModule(fcreate(bytearray(tflite_model_bytes), ctx)) diff --git a/python/tvm/datatype.py b/python/tvm/datatype.py index df3e3a62a510..809e43516adc 100644 --- a/python/tvm/datatype.py +++ b/python/tvm/datatype.py @@ -15,9 +15,8 @@ # specific language governing permissions and limitations # under the License. """Custom datatype functionality""" -from __future__ import absolute_import as _abs +import tvm._ffi -from ._ffi.function import register_func as _register_func from . import make as _make from .api import convert from .expr import Call as _Call, Cast as _Cast, FloatImm as _FloatImm @@ -111,7 +110,7 @@ def register_op(lower_func, op_name, target, type_name, src_type_name=None): else: lower_func_name = "tvm.datatype.lower." + target + "." + op_name + "." \ + type_name - _register_func(lower_func_name, lower_func) + tvm._ffi.register_func(lower_func_name, lower_func) def create_lower_func(extern_func_name): diff --git a/python/tvm/expr.py b/python/tvm/expr.py index 20d9d89cd2ac..46a7eac40a65 100644 --- a/python/tvm/expr.py +++ b/python/tvm/expr.py @@ -32,7 +32,10 @@ """ # pylint: disable=missing-docstring from __future__ import absolute_import as _abs -from ._ffi.object import Object, register_object, ObjectGeneric +import tvm._ffi + +from ._ffi.object import Object +from ._ffi.object_generic import ObjectGeneric from ._ffi.runtime_ctypes import TVMType, TypeCode from . import make as _make from . import generic as _generic @@ -261,7 +264,7 @@ class CmpExpr(PrimExpr): class LogicalExpr(PrimExpr): pass -@register_object("Variable") +@tvm._ffi.register_object("Variable") class Var(PrimExpr): """Symbolic variable. @@ -278,7 +281,7 @@ def __init__(self, name, dtype): _api_internal._Var, name, dtype) -@register_object +@tvm._ffi.register_object class SizeVar(Var): """Symbolic variable to represent a tensor index size which is greater or equal to zero @@ -297,7 +300,7 @@ def __init__(self, name, dtype): _api_internal._SizeVar, name, dtype) -@register_object +@tvm._ffi.register_object class Reduce(PrimExpr): """Reduce node. @@ -324,7 +327,7 @@ def __init__(self, combiner, src, rdom, condition, value_index): condition, value_index) -@register_object +@tvm._ffi.register_object class FloatImm(ConstExpr): """Float constant. @@ -340,7 +343,7 @@ def __init__(self, dtype, value): self.__init_handle_by_constructor__( _make.FloatImm, dtype, value) -@register_object +@tvm._ffi.register_object class IntImm(ConstExpr): """Int constant. @@ -360,7 +363,7 @@ def __int__(self): return self.value -@register_object +@tvm._ffi.register_object class StringImm(ConstExpr): """String constant. @@ -384,7 +387,7 @@ def __ne__(self, other): return self.value != other -@register_object +@tvm._ffi.register_object class Cast(PrimExpr): """Cast expression. @@ -401,7 +404,7 @@ def __init__(self, dtype, value): _make.Cast, dtype, value) -@register_object +@tvm._ffi.register_object class Add(BinaryOpExpr): """Add node. @@ -418,7 +421,7 @@ def __init__(self, a, b): _make.Add, a, b) -@register_object +@tvm._ffi.register_object class Sub(BinaryOpExpr): """Sub node. @@ -435,7 +438,7 @@ def __init__(self, a, b): _make.Sub, a, b) -@register_object +@tvm._ffi.register_object class Mul(BinaryOpExpr): """Mul node. @@ -452,7 +455,7 @@ def __init__(self, a, b): _make.Mul, a, b) -@register_object +@tvm._ffi.register_object class Div(BinaryOpExpr): """Div node. @@ -469,7 +472,7 @@ def __init__(self, a, b): _make.Div, a, b) -@register_object +@tvm._ffi.register_object class Mod(BinaryOpExpr): """Mod node. @@ -486,7 +489,7 @@ def __init__(self, a, b): _make.Mod, a, b) -@register_object +@tvm._ffi.register_object class FloorDiv(BinaryOpExpr): """FloorDiv node. @@ -503,7 +506,7 @@ def __init__(self, a, b): _make.FloorDiv, a, b) -@register_object +@tvm._ffi.register_object class FloorMod(BinaryOpExpr): """FloorMod node. @@ -520,7 +523,7 @@ def __init__(self, a, b): _make.FloorMod, a, b) -@register_object +@tvm._ffi.register_object class Min(BinaryOpExpr): """Min node. @@ -537,7 +540,7 @@ def __init__(self, a, b): _make.Min, a, b) -@register_object +@tvm._ffi.register_object class Max(BinaryOpExpr): """Max node. @@ -554,7 +557,7 @@ def __init__(self, a, b): _make.Max, a, b) -@register_object +@tvm._ffi.register_object class EQ(CmpExpr): """EQ node. @@ -571,7 +574,7 @@ def __init__(self, a, b): _make.EQ, a, b) -@register_object +@tvm._ffi.register_object class NE(CmpExpr): """NE node. @@ -588,7 +591,7 @@ def __init__(self, a, b): _make.NE, a, b) -@register_object +@tvm._ffi.register_object class LT(CmpExpr): """LT node. @@ -605,7 +608,7 @@ def __init__(self, a, b): _make.LT, a, b) -@register_object +@tvm._ffi.register_object class LE(CmpExpr): """LE node. @@ -622,7 +625,7 @@ def __init__(self, a, b): _make.LE, a, b) -@register_object +@tvm._ffi.register_object class GT(CmpExpr): """GT node. @@ -639,7 +642,7 @@ def __init__(self, a, b): _make.GT, a, b) -@register_object +@tvm._ffi.register_object class GE(CmpExpr): """GE node. @@ -656,7 +659,7 @@ def __init__(self, a, b): _make.GE, a, b) -@register_object +@tvm._ffi.register_object class And(LogicalExpr): """And node. @@ -673,7 +676,7 @@ def __init__(self, a, b): _make.And, a, b) -@register_object +@tvm._ffi.register_object class Or(LogicalExpr): """Or node. @@ -690,7 +693,7 @@ def __init__(self, a, b): _make.Or, a, b) -@register_object +@tvm._ffi.register_object class Not(LogicalExpr): """Not node. @@ -704,7 +707,7 @@ def __init__(self, a): _make.Not, a) -@register_object +@tvm._ffi.register_object class Select(PrimExpr): """Select node. @@ -732,7 +735,7 @@ def __init__(self, condition, true_value, false_value): _make.Select, condition, true_value, false_value) -@register_object +@tvm._ffi.register_object class Load(PrimExpr): """Load node. @@ -755,7 +758,7 @@ def __init__(self, dtype, buffer_var, index, predicate): _make.Load, dtype, buffer_var, index, predicate) -@register_object +@tvm._ffi.register_object class Ramp(PrimExpr): """Ramp node. @@ -775,7 +778,7 @@ def __init__(self, base, stride, lanes): _make.Ramp, base, stride, lanes) -@register_object +@tvm._ffi.register_object class Broadcast(PrimExpr): """Broadcast node. @@ -792,7 +795,7 @@ def __init__(self, value, lanes): _make.Broadcast, value, lanes) -@register_object +@tvm._ffi.register_object class Shuffle(PrimExpr): """Shuffle node. @@ -809,7 +812,7 @@ def __init__(self, vectors, indices): _make.Shuffle, vectors, indices) -@register_object +@tvm._ffi.register_object class Call(PrimExpr): """Call node. @@ -844,7 +847,7 @@ def __init__(self, dtype, name, args, call_type, func, value_index): _make.Call, dtype, name, args, call_type, func, value_index) -@register_object +@tvm._ffi.register_object class Let(PrimExpr): """Let node. diff --git a/python/tvm/hybrid/__init__.py b/python/tvm/hybrid/__init__.py index 11ecbc8f7b60..55c33e5e317f 100644 --- a/python/tvm/hybrid/__init__.py +++ b/python/tvm/hybrid/__init__.py @@ -28,13 +28,10 @@ # TODO(@were): Make this module more complete. # 1. Support HalideIR dumping to Hybrid Script # 2. Support multi-level HalideIR - -from __future__ import absolute_import as _abs - import inspect +import tvm._ffi from .._ffi.base import decorate -from .._ffi.function import _init_api from ..build_module import form_body from .module import HybridModule @@ -97,4 +94,4 @@ def build(sch, inputs, outputs, name="hybrid_func"): return HybridModule(src, name) -_init_api("tvm.hybrid") +tvm._ffi._init_api("tvm.hybrid") diff --git a/python/tvm/intrin.py b/python/tvm/intrin.py index fd7131e5c92f..6146a7189318 100644 --- a/python/tvm/intrin.py +++ b/python/tvm/intrin.py @@ -16,9 +16,9 @@ # under the License. """Expression Intrinsics and math functions in TVM.""" # pylint: disable=redefined-builtin -from __future__ import absolute_import as _abs +import tvm._ffi +import tvm.codegen -from ._ffi.function import register_func as _register_func from . import make as _make from .api import convert, const from .expr import Call as _Call @@ -189,7 +189,6 @@ def call_llvm_intrin(dtype, name, *args): call : Expr The call expression. """ - import tvm llvm_id = tvm.codegen.llvm_lookup_intrinsic_id(name) assert llvm_id != 0, "%s is not an LLVM intrinsic" % name return call_pure_intrin(dtype, 'llvm_intrin', tvm.const(llvm_id, 'uint32'), *args) @@ -596,7 +595,7 @@ def register_intrin_rule(target, intrin, f=None, override=False): register_intrin_rule("opencl", "exp", my_exp_rule, override=True) """ - return _register_func("tvm.intrin.rule.%s.%s" % (target, intrin), f, override) + return tvm._ffi.register_func("tvm.intrin.rule.%s.%s" % (target, intrin), f, override) def _rule_float_suffix(op): @@ -650,7 +649,7 @@ def _rule_float_direct(op): return call_pure_extern(op.dtype, op.name, *op.args) return None -@_register_func("tvm.default_trace_action") +@tvm._ffi.register_func("tvm.default_trace_action") def _tvm_default_trace_action(*args): print(list(args)) diff --git a/python/tvm/ir_builder.py b/python/tvm/ir_builder.py index ede17a154285..8bd58923623b 100644 --- a/python/tvm/ir_builder.py +++ b/python/tvm/ir_builder.py @@ -24,7 +24,7 @@ from . import ir_pass as _pass from . import container as _container from ._ffi.base import string_types -from ._ffi.object import ObjectGeneric +from ._ffi.object_generic import ObjectGeneric from ._ffi.runtime_ctypes import TVMType from .expr import Call as _Call diff --git a/python/tvm/ir_pass.py b/python/tvm/ir_pass.py index 59354e2eb890..9d7f340310e7 100644 --- a/python/tvm/ir_pass.py +++ b/python/tvm/ir_pass.py @@ -23,6 +23,6 @@ You can read "include/tvm/tir/ir_pass.h" for the function signature and "src/api/api_pass.cc" for the PackedFunc's body of these functions. """ -from ._ffi.function import _init_api +import tvm._ffi -_init_api("tvm.ir_pass") +tvm._ffi._init_api("tvm.ir_pass") diff --git a/python/tvm/make.py b/python/tvm/make.py index 241edd6b0948..7f94d1031d9a 100644 --- a/python/tvm/make.py +++ b/python/tvm/make.py @@ -22,8 +22,7 @@ Each api is a PackedFunc that can be called in a positional argument manner. You can use make function to build the IR node. """ -from __future__ import absolute_import as _abs -from ._ffi.function import _init_api +import tvm._ffi def range_by_min_extent(min_value, extent): @@ -85,4 +84,4 @@ def node(type_key, **kwargs): return _Node(*args) -_init_api("tvm.make") +tvm._ffi._init_api("tvm.make") diff --git a/python/tvm/micro/base.py b/python/tvm/micro/base.py index e2e1329cb36c..a46d1bb99619 100644 --- a/python/tvm/micro/base.py +++ b/python/tvm/micro/base.py @@ -23,9 +23,11 @@ from enum import Enum import tvm +import tvm._ffi + from tvm.contrib import util as _util from tvm.contrib import cc as _cc -from .._ffi.function import _init_api + class LibType(Enum): """Enumeration of library types that can be compiled and loaded onto a device""" @@ -222,4 +224,4 @@ def get_micro_device_dir(): return micro_device_dir -_init_api("tvm.micro", "tvm.micro.base") +tvm._ffi._init_api("tvm.micro", "tvm.micro.base") diff --git a/python/tvm/module.py b/python/tvm/module.py index 9e6d8b14b7e7..9b98a0ff3cae 100644 --- a/python/tvm/module.py +++ b/python/tvm/module.py @@ -21,9 +21,9 @@ import logging import os from collections import namedtuple +import tvm._ffi -from ._ffi.function import ModuleBase, _set_class_module -from ._ffi.function import _init_api +from ._ffi.module import ModuleBase, _set_class_module from ._ffi.libinfo import find_include_path from .contrib import cc as _cc, tar as _tar, util as _util @@ -361,5 +361,5 @@ def enabled(target): return _Enabled(target) -_init_api("tvm.module") +tvm._ffi._init_api("tvm.module") _set_class_module(Module) diff --git a/python/tvm/ndarray.py b/python/tvm/ndarray.py index b19db6627ac6..096227e382fd 100644 --- a/python/tvm/ndarray.py +++ b/python/tvm/ndarray.py @@ -20,18 +20,16 @@ the correctness of the program. """ # pylint: disable=invalid-name,unused-import -from __future__ import absolute_import as _abs +import tvm._ffi import numpy as _np from ._ffi.function import register_func from ._ffi.ndarray import TVMContext, TVMType, NDArrayBase from ._ffi.ndarray import context, empty, from_dlpack from ._ffi.ndarray import _set_class_ndarray -from ._ffi.ndarray import register_extension -from ._ffi.object import register_object -@register_object +@tvm._ffi.register_object class NDArray(NDArrayBase): """Lightweight NDArray class of TVM runtime. diff --git a/python/tvm/object.py b/python/tvm/object.py deleted file mode 100644 index 9659d3c89067..000000000000 --- a/python/tvm/object.py +++ /dev/null @@ -1,23 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""Node is the base class of all TVM AST. - -Normally user do not need to touch this api. -""" -# pylint: disable=unused-import -from __future__ import absolute_import as _abs -from ._ffi.object import Object, register_object diff --git a/python/tvm/relay/_analysis.py b/python/tvm/relay/_analysis.py index 32a7324ae29f..050fcce2fb17 100644 --- a/python/tvm/relay/_analysis.py +++ b/python/tvm/relay/_analysis.py @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. """FFI exposing the passes for Relay program analysis.""" +import tvm._ffi -from tvm._ffi.function import _init_api - -_init_api("relay._analysis", __name__) +tvm._ffi._init_api("relay._analysis", __name__) diff --git a/python/tvm/relay/_base.py b/python/tvm/relay/_base.py index d7ecaa84b45c..f86aa70353dc 100644 --- a/python/tvm/relay/_base.py +++ b/python/tvm/relay/_base.py @@ -16,6 +16,6 @@ # under the License. # pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable """The interface of expr function exposed from C++.""" -from tvm._ffi.function import _init_api +import tvm._ffi -_init_api("relay._base", __name__) +tvm._ffi._init_api("relay._base", __name__) diff --git a/python/tvm/relay/_build_module.py b/python/tvm/relay/_build_module.py index bdbcbefff523..9ee92e0035fa 100644 --- a/python/tvm/relay/_build_module.py +++ b/python/tvm/relay/_build_module.py @@ -16,6 +16,6 @@ # under the License. # pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable """The interface for building Relay functions exposed from C++.""" -from tvm._ffi.function import _init_api +import tvm._ffi -_init_api("relay.build_module", __name__) +tvm._ffi._init_api("relay.build_module", __name__) diff --git a/python/tvm/relay/_expr.py b/python/tvm/relay/_expr.py index 07ef7e0588d4..70c13ce4eaa8 100644 --- a/python/tvm/relay/_expr.py +++ b/python/tvm/relay/_expr.py @@ -16,6 +16,6 @@ # under the License. # pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable """The interface of expr function exposed from C++.""" -from tvm._ffi.function import _init_api +import tvm._ffi -_init_api("relay._expr", __name__) +tvm._ffi._init_api("relay._expr", __name__) diff --git a/python/tvm/relay/_make.py b/python/tvm/relay/_make.py index 6081b2664ca8..351f7c6575ce 100644 --- a/python/tvm/relay/_make.py +++ b/python/tvm/relay/_make.py @@ -20,6 +20,6 @@ This module includes MyPy type signatures for all of the exposed modules. """ -from .._ffi.function import _init_api +import tvm._ffi -_init_api("relay._make", __name__) +tvm._ffi._init_api("relay._make", __name__) diff --git a/python/tvm/relay/_module.py b/python/tvm/relay/_module.py index 365c82736eec..aedb74a05486 100644 --- a/python/tvm/relay/_module.py +++ b/python/tvm/relay/_module.py @@ -16,6 +16,6 @@ # under the License. # pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable """The interface to the Module exposed from C++.""" -from tvm._ffi.function import _init_api +import tvm._ffi -_init_api("relay._module", __name__) +tvm._ffi._init_api("relay._module", __name__) diff --git a/python/tvm/relay/_transform.py b/python/tvm/relay/_transform.py index 273d97e0962a..a4168dfb5c0c 100644 --- a/python/tvm/relay/_transform.py +++ b/python/tvm/relay/_transform.py @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. """FFI exposing the Relay type inference and checking.""" +import tvm._ffi -from tvm._ffi.function import _init_api - -_init_api("relay._transform", __name__) +tvm._ffi._init_api("relay._transform", __name__) diff --git a/python/tvm/relay/backend/_backend.py b/python/tvm/relay/backend/_backend.py index 860788a4e5d0..1db70c3e2dc2 100644 --- a/python/tvm/relay/backend/_backend.py +++ b/python/tvm/relay/backend/_backend.py @@ -15,14 +15,13 @@ # specific language governing permissions and limitations # under the License. """The interface of expr function exposed from C++.""" -from __future__ import absolute_import +import tvm._ffi from ... import build_module as _build from ... import container as _container -from ..._ffi.function import _init_api, register_func -@register_func("relay.backend.lower") +@tvm._ffi.register_func("relay.backend.lower") def lower(sch, inputs, func_name, source_func): """Backend function for lowering. @@ -61,7 +60,7 @@ def lower(sch, inputs, func_name, source_func): f, (_container.Array, tuple, list)) else [f] -@register_func("relay.backend.build") +@tvm._ffi.register_func("relay.backend.build") def build(funcs, target, target_host=None): """Backend build function. @@ -88,14 +87,14 @@ def build(funcs, target, target_host=None): return _build.build(funcs, target=target, target_host=target_host) -@register_func("relay._tensor_value_repr") +@tvm._ffi.register_func("relay._tensor_value_repr") def _tensor_value_repr(tvalue): return str(tvalue.data.asnumpy()) -@register_func("relay._constant_repr") +@tvm._ffi.register_func("relay._constant_repr") def _tensor_constant_repr(tvalue): return str(tvalue.data.asnumpy()) -_init_api("relay.backend", __name__) +tvm._ffi._init_api("relay.backend", __name__) diff --git a/python/tvm/relay/backend/_vm.py b/python/tvm/relay/backend/_vm.py index e88f02a5a7c8..cffbbdccde5a 100644 --- a/python/tvm/relay/backend/_vm.py +++ b/python/tvm/relay/backend/_vm.py @@ -16,6 +16,6 @@ # under the License. """The Relay virtual machine FFI namespace. """ -from tvm._ffi.function import _init_api +import tvm._ffi -_init_api("relay._vm", __name__) +tvm._ffi._init_api("relay._vm", __name__) diff --git a/python/tvm/relay/backend/vm.py b/python/tvm/relay/backend/vm.py index f1cdefc3ed3a..606e0ccf5f3b 100644 --- a/python/tvm/relay/backend/vm.py +++ b/python/tvm/relay/backend/vm.py @@ -25,7 +25,7 @@ import tvm import tvm.ndarray as _nd from tvm import autotvm, container -from tvm.object import Object +from tvm._ffi.object import Object from tvm.relay import expr as _expr from tvm._ffi.runtime_ctypes import TVMByteArray from tvm._ffi import base as _base diff --git a/python/tvm/relay/base.py b/python/tvm/relay/base.py index d389803bfeea..a723eda3a9db 100644 --- a/python/tvm/relay/base.py +++ b/python/tvm/relay/base.py @@ -16,8 +16,8 @@ # under the License. # pylint: disable=no-else-return, unidiomatic-typecheck """The base node types for the Relay language.""" -from __future__ import absolute_import as _abs -from .._ffi.object import register_object as _register_tvm_node +import tvm._ffi + from .._ffi.object import Object from . import _make from . import _expr @@ -34,9 +34,9 @@ def register_relay_node(type_key=None): The type key of the node. """ if not isinstance(type_key, str): - return _register_tvm_node( + return tvm._ffi.register_object( "relay." + type_key.__name__)(type_key) - return _register_tvm_node(type_key) + return tvm._ffi.register_object(type_key) def register_relay_attr_node(type_key=None): @@ -48,9 +48,9 @@ def register_relay_attr_node(type_key=None): The type key of the node. """ if not isinstance(type_key, str): - return _register_tvm_node( + return tvm._ffi.register_object( "relay.attrs." + type_key.__name__)(type_key) - return _register_tvm_node(type_key) + return tvm._ffi.register_object(type_key) class RelayNode(Object): diff --git a/python/tvm/relay/op/_make.py b/python/tvm/relay/op/_make.py index d51fee717804..85c2368fad4a 100644 --- a/python/tvm/relay/op/_make.py +++ b/python/tvm/relay/op/_make.py @@ -15,6 +15,6 @@ # specific language governing permissions and limitations # under the License. """Constructor APIs""" -from ..._ffi.function import _init_api +import tvm._ffi -_init_api("relay.op._make", __name__) +tvm._ffi._init_api("relay.op._make", __name__) diff --git a/python/tvm/relay/op/annotation/_make.py b/python/tvm/relay/op/annotation/_make.py index ae909eb8af3c..12ece522c854 100644 --- a/python/tvm/relay/op/annotation/_make.py +++ b/python/tvm/relay/op/annotation/_make.py @@ -15,6 +15,6 @@ # specific language governing permissions and limitations # under the License. """Constructor APIs""" -from ...._ffi.function import _init_api +import tvm._ffi -_init_api("relay.op.annotation._make", __name__) +tvm._ffi._init_api("relay.op.annotation._make", __name__) diff --git a/python/tvm/relay/op/contrib/_make.py b/python/tvm/relay/op/contrib/_make.py index 42d71755abb8..9d3369ebe7b2 100644 --- a/python/tvm/relay/op/contrib/_make.py +++ b/python/tvm/relay/op/contrib/_make.py @@ -15,6 +15,6 @@ # specific language governing permissions and limitations # under the License. """Constructor APIs""" -from ...._ffi.function import _init_api +import tvm._ffi -_init_api("relay.op.contrib._make", __name__) +tvm._ffi._init_api("relay.op.contrib._make", __name__) diff --git a/python/tvm/relay/op/image/_make.py b/python/tvm/relay/op/image/_make.py index 747684b63ed4..1d5e02848a46 100644 --- a/python/tvm/relay/op/image/_make.py +++ b/python/tvm/relay/op/image/_make.py @@ -15,6 +15,6 @@ # specific language governing permissions and limitations # under the License. """Constructor APIs""" -from ...._ffi.function import _init_api +import tvm._ffi -_init_api("relay.op.image._make", __name__) +tvm._ffi._init_api("relay.op.image._make", __name__) diff --git a/python/tvm/relay/op/memory/_make.py b/python/tvm/relay/op/memory/_make.py index cdf2dcc2cd0b..52a3777a3785 100644 --- a/python/tvm/relay/op/memory/_make.py +++ b/python/tvm/relay/op/memory/_make.py @@ -15,6 +15,6 @@ # specific language governing permissions and limitations # under the License. """Constructor APIs""" -from ...._ffi.function import _init_api +import tvm._ffi -_init_api("relay.op.memory._make", __name__) +tvm._ffi._init_api("relay.op.memory._make", __name__) diff --git a/python/tvm/relay/op/nn/_make.py b/python/tvm/relay/op/nn/_make.py index 72496859d918..15ae43b35cb0 100644 --- a/python/tvm/relay/op/nn/_make.py +++ b/python/tvm/relay/op/nn/_make.py @@ -15,6 +15,6 @@ # specific language governing permissions and limitations # under the License. """Constructor APIs""" -from ...._ffi.function import _init_api +import tvm._ffi -_init_api("relay.op.nn._make", __name__) +tvm._ffi._init_api("relay.op.nn._make", __name__) diff --git a/python/tvm/relay/op/op.py b/python/tvm/relay/op/op.py index 382f667b86a9..f9bc853282bb 100644 --- a/python/tvm/relay/op/op.py +++ b/python/tvm/relay/op/op.py @@ -17,8 +17,7 @@ #pylint: disable=unused-argument """The base node types for the Relay language.""" import topi - -from ..._ffi.function import _init_api +import tvm._ffi from ..base import register_relay_node from ..expr import Expr @@ -283,8 +282,6 @@ def register_shape_func(op_name, data_dependant, shape_func=None, level=10): get(op_name).set_attr("TShapeDataDependant", data_dependant, level) return register(op_name, "FShapeFunc", shape_func, level) -_init_api("relay.op", __name__) - @register_func("relay.op.compiler._lower") def _lower(name, schedule, inputs, outputs): return lower(schedule, list(inputs) + list(outputs), name=name) @@ -320,3 +317,5 @@ def debug(expr, debug_func=None): name = '' return _make.debug(expr, name) + +tvm._ffi._init_api("relay.op", __name__) diff --git a/python/tvm/relay/op/vision/_make.py b/python/tvm/relay/op/vision/_make.py index f0e31709194d..eddca15c19b5 100644 --- a/python/tvm/relay/op/vision/_make.py +++ b/python/tvm/relay/op/vision/_make.py @@ -15,6 +15,6 @@ # specific language governing permissions and limitations # under the License. """Constructor APIs""" -from ...._ffi.function import _init_api +import tvm._ffi -_init_api("relay.op.vision._make", __name__) +tvm._ffi._init_api("relay.op.vision._make", __name__) diff --git a/python/tvm/relay/qnn/op/_make.py b/python/tvm/relay/qnn/op/_make.py index 07b3dd154760..4472bc77c6cd 100644 --- a/python/tvm/relay/qnn/op/_make.py +++ b/python/tvm/relay/qnn/op/_make.py @@ -15,6 +15,6 @@ # specific language governing permissions and limitations # under the License. """Constructor APIs""" -from ...._ffi.function import _init_api +import tvm._ffi -_init_api("relay.qnn.op._make", __name__) +tvm._ffi._init_api("relay.qnn.op._make", __name__) diff --git a/python/tvm/relay/quantize/_annotate.py b/python/tvm/relay/quantize/_annotate.py index ab98f3c369ab..ba100d8d03e4 100644 --- a/python/tvm/relay/quantize/_annotate.py +++ b/python/tvm/relay/quantize/_annotate.py @@ -16,11 +16,10 @@ # under the License. #pylint: disable=unused-argument,inconsistent-return-statements """Internal module for registering attribute for annotation.""" -from __future__ import absolute_import import warnings - import topi -from ..._ffi.function import register_func +import tvm._ffi + from .. import expr as _expr from .. import analysis as _analysis from .. import op as _op @@ -144,7 +143,8 @@ def attach_simulated_quantize(data, kind, sign=True, rounding="round"): qctx.qnode_map[key] = qnode return qnode -register_func("relay.quantize.attach_simulated_quantize", attach_simulated_quantize) +tvm._ffi.register_func( + "relay.quantize.attach_simulated_quantize", attach_simulated_quantize) @register_annotate_function("nn.contrib_conv2d_NCHWc") diff --git a/python/tvm/relay/quantize/_quantize.py b/python/tvm/relay/quantize/_quantize.py index 6f5c75f7b418..7b27b7a55208 100644 --- a/python/tvm/relay/quantize/_quantize.py +++ b/python/tvm/relay/quantize/_quantize.py @@ -16,7 +16,6 @@ # under the License. #pylint: disable=unused-argument """Internal module for quantization.""" -from __future__ import absolute_import -from tvm._ffi.function import _init_api +import tvm._ffi -_init_api("relay._quantize", __name__) +tvm._ffi._init_api("relay._quantize", __name__) diff --git a/python/tvm/rpc/base.py b/python/tvm/rpc/base.py index a1e837cd0c1f..bc81534a12d9 100644 --- a/python/tvm/rpc/base.py +++ b/python/tvm/rpc/base.py @@ -26,8 +26,8 @@ import struct import random import logging +import tvm._ffi -from .._ffi.function import _init_api from .._ffi.base import py_str # Magic header for RPC data plane @@ -179,4 +179,4 @@ def connect_with_retry(addr, timeout=60, retry_period=5): # Still use tvm.rpc for the foreign functions -_init_api("tvm.rpc", "tvm.rpc.base") +tvm._ffi._init_api("tvm.rpc", "tvm.rpc.base") diff --git a/python/tvm/rpc/client.py b/python/tvm/rpc/client.py index 9c0dea5b0863..314f1ab6e8f9 100644 --- a/python/tvm/rpc/client.py +++ b/python/tvm/rpc/client.py @@ -21,11 +21,11 @@ import socket import struct import time +import tvm._ffi from . import base from ..contrib import util from .._ffi.base import TVMError -from .._ffi import function from .._ffi import ndarray as nd from ..module import load as _load_module @@ -185,7 +185,7 @@ class LocalSession(RPCSession): def __init__(self): # pylint: disable=super-init-not-called self.context = nd.context - self.get_function = function.get_global_func + self.get_function = tvm._ffi.get_global_func self._temp = util.tempdir() def upload(self, data, target=None): diff --git a/python/tvm/rpc/server.py b/python/tvm/rpc/server.py index 9e03097e89a7..efebe8b395ed 100644 --- a/python/tvm/rpc/server.py +++ b/python/tvm/rpc/server.py @@ -25,9 +25,6 @@ - {server|client}:device-type[:random-key] [-timeout=timeout] """ # pylint: disable=invalid-name - -from __future__ import absolute_import - import os import ctypes import socket @@ -39,8 +36,8 @@ import time import sys import signal +import tvm._ffi -from .._ffi.function import register_func from .._ffi.base import py_str from .._ffi.libinfo import find_lib_path from ..module import load as _load_module @@ -58,11 +55,11 @@ def _server_env(load_library, work_path=None): temp = util.tempdir() # pylint: disable=unused-variable - @register_func("tvm.rpc.server.workpath") + @tvm._ffi.register_func("tvm.rpc.server.workpath") def get_workpath(path): return temp.relpath(path) - @register_func("tvm.rpc.server.load_module", override=True) + @tvm._ffi.register_func("tvm.rpc.server.load_module", override=True) def load_module(file_name): """Load module from remote side.""" path = temp.relpath(file_name) diff --git a/python/tvm/schedule.py b/python/tvm/schedule.py index c8fcd7cbd52d..bf4e75f14966 100644 --- a/python/tvm/schedule.py +++ b/python/tvm/schedule.py @@ -15,38 +15,19 @@ # specific language governing permissions and limitations # under the License. """The computation schedule api of TVM.""" -from __future__ import absolute_import as _abs +import tvm._ffi + from ._ffi.base import string_types -from ._ffi.object import Object, register_object -from ._ffi.object import convert_to_object as _convert_to_object -from ._ffi.function import _init_api, Function -from ._ffi.function import convert_to_tvm_func as _convert_tvm_func +from ._ffi.object import Object +from ._ffi.object_generic import convert + from . import _api_internal from . import tensor as _tensor from . import expr as _expr from . import container as _container -def convert(value): - """Convert value to TVM object or function. - - Parameters - ---------- - value : python value - - Returns - ------- - tvm_val : Object or Function - Converted value in TVM - """ - if isinstance(value, (Function, Object)): - return value - - if callable(value): - return _convert_tvm_func(value) - - return _convert_to_object(value) -@register_object +@tvm._ffi.register_object class Buffer(Object): """Symbolic data buffer in TVM. @@ -156,22 +137,22 @@ def vstore(self, begin, value): return _api_internal._BufferVStore(self, begin, value) -@register_object +@tvm._ffi.register_object class Split(Object): """Split operation on axis.""" -@register_object +@tvm._ffi.register_object class Fuse(Object): """Fuse operation on axis.""" -@register_object +@tvm._ffi.register_object class Singleton(Object): """Singleton axis.""" -@register_object +@tvm._ffi.register_object class IterVar(Object, _expr.ExprOp): """Represent iteration variable. @@ -214,7 +195,7 @@ def create_schedule(ops): return _api_internal._CreateSchedule(ops) -@register_object +@tvm._ffi.register_object class Schedule(Object): """Schedule for all the stages.""" def __getitem__(self, k): @@ -348,7 +329,7 @@ def rfactor(self, tensor, axis, factor_axis=0): return factored[0] if len(factored) == 1 else factored -@register_object +@tvm._ffi.register_object class Stage(Object): """A Stage represents schedule for one operation.""" def split(self, parent, factor=None, nparts=None): @@ -670,4 +651,4 @@ def opengl(self): """ _api_internal._StageOpenGL(self) -_init_api("tvm.schedule") +tvm._ffi._init_api("tvm.schedule") diff --git a/python/tvm/stmt.py b/python/tvm/stmt.py index 6b87fcb1b885..59089347edb3 100644 --- a/python/tvm/stmt.py +++ b/python/tvm/stmt.py @@ -29,15 +29,15 @@ assert isinstance(st, tvm.stmt.Store) assert(st.buffer_var == a) """ -from __future__ import absolute_import as _abs -from ._ffi.object import Object, register_object +import tvm._ffi +from ._ffi.object import Object from . import make as _make class Stmt(Object): pass -@register_object +@tvm._ffi.register_object class LetStmt(Stmt): """LetStmt node. @@ -57,7 +57,7 @@ def __init__(self, var, value, body): _make.LetStmt, var, value, body) -@register_object +@tvm._ffi.register_object class AssertStmt(Stmt): """AssertStmt node. @@ -77,7 +77,7 @@ def __init__(self, condition, message, body): _make.AssertStmt, condition, message, body) -@register_object +@tvm._ffi.register_object class ProducerConsumer(Stmt): """ProducerConsumer node. @@ -97,7 +97,7 @@ def __init__(self, func, is_producer, body): _make.ProducerConsumer, func, is_producer, body) -@register_object +@tvm._ffi.register_object class For(Stmt): """For node. @@ -137,7 +137,7 @@ def __init__(self, for_type, device_api, body) -@register_object +@tvm._ffi.register_object class Store(Stmt): """Store node. @@ -160,7 +160,7 @@ def __init__(self, buffer_var, value, index, predicate): _make.Store, buffer_var, value, index, predicate) -@register_object +@tvm._ffi.register_object class Provide(Stmt): """Provide node. @@ -183,7 +183,7 @@ def __init__(self, func, value_index, value, args): _make.Provide, func, value_index, value, args) -@register_object +@tvm._ffi.register_object class Allocate(Stmt): """Allocate node. @@ -215,7 +215,7 @@ def __init__(self, extents, condition, body) -@register_object +@tvm._ffi.register_object class AttrStmt(Stmt): """AttrStmt node. @@ -238,7 +238,7 @@ def __init__(self, node, attr_key, value, body): _make.AttrStmt, node, attr_key, value, body) -@register_object +@tvm._ffi.register_object class Free(Stmt): """Free node. @@ -252,7 +252,7 @@ def __init__(self, buffer_var): _make.Free, buffer_var) -@register_object +@tvm._ffi.register_object class Realize(Stmt): """Realize node. @@ -288,7 +288,7 @@ def __init__(self, bounds, condition, body) -@register_object +@tvm._ffi.register_object class SeqStmt(Stmt): """Sequence of statements. @@ -308,7 +308,7 @@ def __len__(self): return len(self.seq) -@register_object +@tvm._ffi.register_object class IfThenElse(Stmt): """IfThenElse node. @@ -328,7 +328,7 @@ def __init__(self, condition, then_case, else_case): _make.IfThenElse, condition, then_case, else_case) -@register_object +@tvm._ffi.register_object class Evaluate(Stmt): """Evaluate node. @@ -342,7 +342,7 @@ def __init__(self, value): _make.Evaluate, value) -@register_object +@tvm._ffi.register_object class Prefetch(Stmt): """Prefetch node. diff --git a/python/tvm/target.py b/python/tvm/target.py index c2d37529040b..45dbf5fcdfca 100644 --- a/python/tvm/target.py +++ b/python/tvm/target.py @@ -54,12 +54,11 @@ We can use :any:`tvm.target.create` to create a tvm.target.Target from the target string. We can also use other specific function in this module to create specific targets. """ -from __future__ import absolute_import - import warnings +import tvm._ffi from ._ffi.base import _LIB_NAME -from ._ffi.object import Object, register_object +from ._ffi.object import Object from . import _api_internal try: @@ -80,7 +79,7 @@ def _merge_opts(opts, new_opts): return opts -@register_object +@tvm._ffi.register_object class Target(Object): """Target device information, use through TVM API. @@ -146,7 +145,7 @@ def __exit__(self, ptype, value, trace): _api_internal._ExitTargetScope(self) -@register_object +@tvm._ffi.register_object class GenericFunc(Object): """GenericFunc node reference. This represents a generic function that may be specialized for different targets. When this object is diff --git a/python/tvm/tensor.py b/python/tvm/tensor.py index e4c36c11120b..522e901c3208 100644 --- a/python/tvm/tensor.py +++ b/python/tvm/tensor.py @@ -16,9 +16,11 @@ # under the License. """Tensor and Operation class for computation declaration.""" # pylint: disable=invalid-name -from __future__ import absolute_import as _abs -from ._ffi.object import Object, register_object, ObjectGeneric, \ - convert_to_object +import tvm._ffi + +from ._ffi.object import Object +from ._ffi.object_generic import ObjectGeneric, convert_to_object + from . import _api_internal from . import make as _make from . import expr as _expr @@ -47,7 +49,7 @@ def dtype(self): """Data content of the tensor.""" return self.tensor.dtype -@register_object +@tvm._ffi.register_object class TensorIntrinCall(Object): """Intermediate structure for calling a tensor intrinsic.""" @@ -55,7 +57,7 @@ class TensorIntrinCall(Object): itervar_cls = None -@register_object +@tvm._ffi.register_object class Tensor(Object, _expr.ExprOp): """Tensor object, to construct, see function.Tensor""" @@ -157,12 +159,12 @@ def input_tensors(self): return _api_internal._OpInputTensors(self) -@register_object +@tvm._ffi.register_object class PlaceholderOp(Operation): """Placeholder operation.""" -@register_object +@tvm._ffi.register_object class BaseComputeOp(Operation): """Compute operation.""" @property @@ -176,18 +178,18 @@ def reduce_axis(self): return self.__getattr__("reduce_axis") -@register_object +@tvm._ffi.register_object class ComputeOp(BaseComputeOp): """Scalar operation.""" pass -@register_object +@tvm._ffi.register_object class TensorComputeOp(BaseComputeOp): """Tensor operation.""" -@register_object +@tvm._ffi.register_object class ScanOp(Operation): """Scan operation.""" @property @@ -196,12 +198,12 @@ def scan_axis(self): return self.__getattr__("scan_axis") -@register_object +@tvm._ffi.register_object class ExternOp(Operation): """External operation.""" -@register_object +@tvm._ffi.register_object class HybridOp(Operation): """Hybrid operation.""" @property @@ -210,7 +212,7 @@ def axis(self): return self.__getattr__("axis") -@register_object +@tvm._ffi.register_object class Layout(Object): """Layout is composed of upper cases, lower cases and numbers, where upper case indicates a primal axis and @@ -270,7 +272,7 @@ def factor_of(self, axis): return _api_internal._LayoutFactorOf(self, axis) -@register_object +@tvm._ffi.register_object class BijectiveLayout(Object): """Bijective mapping for two layouts (src-layout and dst-layout). It provides shape and index conversion between each other. diff --git a/python/tvm/tensor_intrin.py b/python/tvm/tensor_intrin.py index 4665ccfd6204..0b88af1c92f3 100644 --- a/python/tvm/tensor_intrin.py +++ b/python/tvm/tensor_intrin.py @@ -15,7 +15,8 @@ # specific language governing permissions and limitations # under the License. """Tensor intrinsics""" -from __future__ import absolute_import as _abs +import tvm._ffi + from . import _api_internal from . import api as _api from . import expr as _expr @@ -24,7 +25,7 @@ from . import tensor as _tensor from . import schedule as _schedule from .build_module import current_build_config -from ._ffi.object import Object, register_object +from ._ffi.object import Object def _get_region(tslice): @@ -41,7 +42,7 @@ def _get_region(tslice): region.append(_make.range_by_min_extent(begin, 1)) return region -@register_object +@tvm._ffi.register_object class TensorIntrin(Object): """Tensor intrinsic functions for certain computation. diff --git a/topi/python/topi/cpp/cuda.py b/topi/python/topi/cpp/cuda.py index 920b2717437d..efc31e82e519 100644 --- a/topi/python/topi/cpp/cuda.py +++ b/topi/python/topi/cpp/cuda.py @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. """FFI for CUDA TOPI ops and schedules""" +import tvm._ffi -from tvm._ffi.function import _init_api_prefix - -_init_api_prefix("topi.cpp.cuda", "topi.cuda") +tvm._ffi._init_api("topi.cuda", "topi.cpp.cuda") diff --git a/topi/python/topi/cpp/generic.py b/topi/python/topi/cpp/generic.py index a8a71656c1aa..e6bf250cb85c 100644 --- a/topi/python/topi/cpp/generic.py +++ b/topi/python/topi/cpp/generic.py @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. """FFI for generic TOPI ops and schedules""" +import tvm._ffi -from tvm._ffi.function import _init_api_prefix - -_init_api_prefix("topi.cpp.generic", "topi.generic") +tvm._ffi._init_api("topi.generic", "topi.cpp.generic") diff --git a/topi/python/topi/cpp/impl.py b/topi/python/topi/cpp/impl.py index 9ae407d4bd12..1081baa716b7 100644 --- a/topi/python/topi/cpp/impl.py +++ b/topi/python/topi/cpp/impl.py @@ -18,8 +18,8 @@ import sys import os import ctypes +import tvm._ffi -from tvm._ffi.function import _init_api_prefix from tvm._ffi import libinfo def _get_lib_names(): @@ -41,4 +41,4 @@ def _load_lib(): _LIB, _LIB_NAME = _load_lib() -_init_api_prefix("topi.cpp", "topi") +tvm._ffi._init_api("topi", "topi.cpp") diff --git a/topi/python/topi/cpp/nn.py b/topi/python/topi/cpp/nn.py index 59bf1477501d..d11aa27b2c84 100644 --- a/topi/python/topi/cpp/nn.py +++ b/topi/python/topi/cpp/nn.py @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. """FFI for NN TOPI ops and schedules""" +import tvm._ffi -from tvm._ffi.function import _init_api_prefix - -_init_api_prefix("topi.cpp.nn", "topi.nn") +tvm._ffi._init_api("topi.nn", "topi.cpp.nn") diff --git a/topi/python/topi/cpp/rocm.py b/topi/python/topi/cpp/rocm.py index d57ce3e3cae1..c001a61d1ea5 100644 --- a/topi/python/topi/cpp/rocm.py +++ b/topi/python/topi/cpp/rocm.py @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. """FFI for Rocm TOPI ops and schedules""" +import tvm._ffi -from tvm._ffi.function import _init_api_prefix - -_init_api_prefix("topi.cpp.rocm", "topi.rocm") +tvm._ffi._init_api("topi.rocm", "topi.cpp.rocm") diff --git a/topi/python/topi/cpp/util.py b/topi/python/topi/cpp/util.py index 90264bc89170..cc76dd9339c6 100644 --- a/topi/python/topi/cpp/util.py +++ b/topi/python/topi/cpp/util.py @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. """FFI for TOPI utility functions""" +import tvm._ffi -from tvm._ffi.function import _init_api_prefix - -_init_api_prefix("topi.cpp.util", "topi.util") +tvm._ffi._init_api("topi.util", "topi.cpp.util") diff --git a/topi/python/topi/cpp/vision/__init__.py b/topi/python/topi/cpp/vision/__init__.py index bcdfc8c186b7..6034e271bc0e 100644 --- a/topi/python/topi/cpp/vision/__init__.py +++ b/topi/python/topi/cpp/vision/__init__.py @@ -16,9 +16,8 @@ # under the License. """FFI for vision TOPI ops and schedules""" - -from tvm._ffi.function import _init_api_prefix +import tvm._ffi from . import yolo -_init_api_prefix("topi.cpp.vision", "topi.vision") +tvm._ffi._init_api("topi.vision", "topi.cpp.vision") diff --git a/topi/python/topi/cpp/vision/yolo.py b/topi/python/topi/cpp/vision/yolo.py index 072ab29ff524..ff12498057d9 100644 --- a/topi/python/topi/cpp/vision/yolo.py +++ b/topi/python/topi/cpp/vision/yolo.py @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. """FFI for Yolo TOPI ops and schedules""" +import tvm._ffi -from tvm._ffi.function import _init_api_prefix - -_init_api_prefix("topi.cpp.vision.yolo", "topi.vision.yolo") +tvm._ffi._init_api("topi.vision.yolo", "topi.cpp.vision.yolo") diff --git a/topi/python/topi/cpp/x86.py b/topi/python/topi/cpp/x86.py index a6db26e336bb..0681ffed2ff5 100644 --- a/topi/python/topi/cpp/x86.py +++ b/topi/python/topi/cpp/x86.py @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. """FFI for x86 TOPI ops and schedules""" +import tvm._ffi -from tvm._ffi.function import _init_api_prefix - -_init_api_prefix("topi.cpp.x86", "topi.x86") +tvm._ffi._init_api("topi.x86", "topi.cpp.x86")