diff --git a/apps/ios_rpc/tests/ios_rpc_test.py b/apps/ios_rpc/tests/ios_rpc_test.py index 67b9cd22aeba..df850812e527 100644 --- a/apps/ios_rpc/tests/ios_rpc_test.py +++ b/apps/ios_rpc/tests/ios_rpc_test.py @@ -39,7 +39,7 @@ # override metal compiler to compile to iphone -@tvm.register_func("tvm_callback_metal_compile") +@tvm.register_global_func("tvm_callback_metal_compile") def compile_metal(src, target): return xcode.compile_metal(src, sdk=sdk) diff --git a/docs/arch/device_target_interactions.rst b/docs/arch/device_target_interactions.rst index 6015c4351076..6a80418be798 100644 --- a/docs/arch/device_target_interactions.rst +++ b/docs/arch/device_target_interactions.rst @@ -169,7 +169,7 @@ then be registered with the following steps. enum value to a string representation. This string representation should match the name given to ``GlobalDef().def``. -#. Add entries to the ``DEVICE_TYPE_TO_NAME`` and ``DEVICE_NAME_TO_TYPE`` dictionaries of +#. Add entries to the ``_DEVICE_TYPE_TO_NAME`` and ``_DEVICE_NAME_TO_TYPE`` dictionaries of :py:class:`tvm.runtime.Device` for the new enum value. diff --git a/docs/get_started/tutorials/quick_start.py b/docs/get_started/tutorials/quick_start.py index 753acbf0a475..8762564c02bd 100644 --- a/docs/get_started/tutorials/quick_start.py +++ b/docs/get_started/tutorials/quick_start.py @@ -164,9 +164,9 @@ def forward(self, x): # .. code-block:: Python # # # Convert PyTorch tensor to TVM Tensor -# x_tvm = tvm.runtime.from_dlpack(x_torch.to_dlpack()) +# x_tvm = tvm.runtime.from_dlpack(x_torch) # # Convert TVM Tensor to PyTorch tensor -# x_torch = torch.from_dlpack(x_tvm.to_dlpack()) +# x_torch = torch.from_dlpack(x_tvm) # # - TVM runtime works in non-python environments, so it works on settings such as mobile # diff --git a/ffi/docs/get_started/quick_start.md b/ffi/docs/get_started/quick_start.md index 7eb3b97727b1..c7cb007c7815 100644 --- a/ffi/docs/get_started/quick_start.md +++ b/ffi/docs/get_started/quick_start.md @@ -144,8 +144,9 @@ TVM_FFI_DLL_EXPORT_TYPED_FUNC(add_one_cuda, tvm_ffi_example::AddOneCUDA); ### Working with PyTorch Atfer build, we will create library such as `build/add_one_cuda.so`, that can be loaded by -with api `tvm_ffi.load_module`. Then the function will become available as property of the loaded module. -The tensor arguments in the ffi functions automatically consumes torch.Tensor. The following code shows how +with api {py:func}`tvm_ffi.load_module` that returns a {py:class}`tvm_ffi.Module` +Then the function will become available as property of the loaded module. +The tensor arguments in the ffi functions automatically consumes `torch.Tensor`. The following code shows how to use the function in torch. ```python diff --git a/ffi/docs/guides/packaging.md b/ffi/docs/guides/packaging.md index 544a45e52d60..1ae9bc673010 100644 --- a/ffi/docs/guides/packaging.md +++ b/ffi/docs/guides/packaging.md @@ -204,7 +204,7 @@ _LIB = _load_lib() Effectively, it leverages the `tvm_ffi.load_module` call to load the library extension DLL shipped along with the package. The `_ffi_api.py` contains a function -call to `tvm_ffi._init_api` that registers all global functions prefixed +call to `tvm_ffi.init_ffi_api` that registers all global functions prefixed with `my_ffi_extension` into the module. ```python @@ -214,7 +214,7 @@ from .base import _LIB # Register all global functions prefixed with 'my_ffi_extension.' 
# This makes functions registered via TVM_FFI_STATIC_INIT_BLOCK available -tvm_ffi._init_api("my_ffi_extension", __name__) +tvm_ffi.init_ffi_api("my_ffi_extension", __name__) ``` Then we can redirect the calls to the related functions. diff --git a/ffi/docs/guides/python_guide.md b/ffi/docs/guides/python_guide.md index 5ac7f318be25..b993c3c756b8 100644 --- a/ffi/docs/guides/python_guide.md +++ b/ffi/docs/guides/python_guide.md @@ -47,7 +47,7 @@ y = np.empty_like(x) mod.add_one_cpu(x, y) ``` -In this case, `tvm_ffi.load_module` will return a `tvm_ffi.Module` class that contains +In this case, {py:func}`tvm_ffi.load_module` will return a {py:class}`tvm_ffi.Module` class that contains the exported functions. You can access the functions by their names. ## Tensor @@ -67,12 +67,12 @@ np_result = np.from_dlpack(tvm_array) In most cases, however, you do not have to explicitly create Tensors. The Python interface can take in `torch.Tensor` and `numpy.ndarray` objects -and automatically convert them to `tvm_ffi.Tensor`. +and automatically convert them to {py:class}`tvm_ffi.Tensor`. ## Functions and Callbacks -`tvm_ffi.Function` provides the Python interface for `ffi::Function` in the C++. -You can retrieve globally registered functions via `tvm_ffi.get_global_func()`. +{py:class}`tvm_ffi.Function` provides the Python interface for `ffi::Function` in the C++. +You can retrieve globally registered functions via {py:func}`tvm_ffi.get_global_func`. ```python import tvm_ffi @@ -84,8 +84,8 @@ assert fecho(1) == 1 ``` You can pass a Python function as an argument to another FFI function as callbacks. -Under the hood, `tvm_ffi.convert` is called to convert the Python function into a -`tvm_ffi.Function`. +Under the hood, {py:func}`tvm_ffi.convert` is called to convert the Python function into a +{py:class}`tvm_ffi.Function`. ```python import tvm_ffi @@ -103,7 +103,7 @@ You can also register a Python callback as a global function. ```python import tvm_ffi -@tvm_ffi.register_func("example.add_one") +@tvm_ffi.register_global_func("example.add_one") def add_one(a): return a + 1 @@ -112,7 +112,7 @@ assert tvm_ffi.get_global_func("example.add_one")(1) == 2 ## Container Types -When an FFI function takes arguments from lists/tuples, they will be converted into `tvm_ffi.Array`. +When an FFI function takes arguments from lists/tuples, they will be converted into {py:class}`tvm_ffi.Array`. ```python import tvm_ffi @@ -124,7 +124,7 @@ assert len(arr) == 4 assert arr[0] == 1 ``` -Dictionaries will be converted to `tvm_ffi.Map` +Dictionaries will be converted to {py:class}`tvm_ffi.Map` ```python import tvm_ffi @@ -167,7 +167,7 @@ File "src/ffi/extra/testing.cc", line 60, in void tvm::ffi::TestRaiseError(tvm:: throw ffi::Error(kind, msg, TVMFFITraceback(__FILE__, __LINE__, TVM_FFI_FUNC_SIG, 0)); ``` -We register common error kinds. You can also register extra error dispatch via the `tvm_ffi.register_error` function. +We register common error kinds. You can also register extra error dispatch via the {py:func}`tvm_ffi.register_error` function. ## Advanced: Register Your Own Object @@ -239,5 +239,5 @@ assert test_int_pair.b == 2 Under the hood, we leverage the information registered through the reflection registry to generate efficient field accessors and methods for each class. -Importantly, when you have multiple inheritance, you need to call `tvm_ffi.register_object` +Importantly, when you have multiple inheritance, you need to call {py:func}`tvm_ffi.register_object` on both the base class and the child class. 
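The python_guide.md and packaging.md hunks above describe the renamed registration entry points. A minimal migration sketch, reusing the `example.add_one` name from the guide and treating the package prefix as illustrative, of how an existing call site moves to the new spellings:

```python
# Sketch of the user-facing renames in this patch (names are illustrative).
import tvm_ffi

# Old: @tvm_ffi.register_func("example.add_one")
@tvm_ffi.register_global_func("example.add_one")
def add_one(a):
    return a + 1

# Retrieval of the registered global function is unchanged.
assert tvm_ffi.get_global_func("example.add_one")(1) == 2

# In a package's _ffi_api.py:
# Old: tvm_ffi._init_api("my_ffi_extension", __name__)
# New: tvm_ffi.init_ffi_api("my_ffi_extension", __name__)
```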
diff --git a/ffi/docs/index.rst b/ffi/docs/index.rst index c3f0b3ea5128..0739f8c2eebd 100644 --- a/ffi/docs/index.rst +++ b/ffi/docs/index.rst @@ -39,3 +39,9 @@ Apache TVM FFI Documentation :caption: Concepts concepts/abi_overview.md + +.. toctree:: + :maxdepth: 1 + :caption: Reference + + reference/python/index.rst diff --git a/ffi/docs/reference/python/index.rst b/ffi/docs/reference/python/index.rst new file mode 100644 index 000000000000..13008089f3a9 --- /dev/null +++ b/ffi/docs/reference/python/index.rst @@ -0,0 +1,69 @@ +.. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + +.. http://www.apache.org/licenses/LICENSE-2.0 + +.. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + +Python API +========== + +.. automodule:: tvm_ffi + :no-members: + +.. currentmodule:: tvm_ffi + +Object +------ +.. autosummary:: + :toctree: generated/ + + Object + register_object + + +Function and Module +------------------- +.. autosummary:: + :toctree: generated/ + + + Function + Module + register_global_func + get_global_func + system_lib + load_module + init_ffi_api + register_error + convert + + +Tensor +------ +.. autosummary:: + :toctree: generated/ + + Shape + Tensor + Device + from_dlpack + + +Containers +---------- +.. autosummary:: + :toctree: generated/ + + Array + Map diff --git a/ffi/examples/packaging/python/my_ffi_extension/_ffi_api.py b/ffi/examples/packaging/python/my_ffi_extension/_ffi_api.py index 79c269ab0ac3..616b1ee8e80c 100644 --- a/ffi/examples/packaging/python/my_ffi_extension/_ffi_api.py +++ b/ffi/examples/packaging/python/my_ffi_extension/_ffi_api.py @@ -21,4 +21,4 @@ # this is a short cut to register all the global functions # prefixed by `my_ffi_extension.` to this module -tvm_ffi._init_api("my_ffi_extension", __name__) +tvm_ffi.init_ffi_api("my_ffi_extension", __name__) diff --git a/ffi/python/tvm_ffi/__init__.py b/ffi/python/tvm_ffi/__init__.py index 807dc56a9181..b0ff88c6c8e1 100644 --- a/ffi/python/tvm_ffi/__init__.py +++ b/ffi/python/tvm_ffi/__init__.py @@ -20,17 +20,21 @@ from . 
import libinfo # package init part -from .registry import register_object, register_func, get_global_func, _init_api -from .dtype import dtype, DataTypeCode -from .core import String, Bytes -from .core import Object, ObjectGeneric, Function -from .convert import convert +from .registry import ( + register_object, + register_global_func, + get_global_func, + remove_global_func, + init_ffi_api, +) +from ._dtype import dtype +from .core import Object, ObjectConvertible, Function +from ._convert import convert from .error import register_error -from .tensor import Device, device -from .tensor import cpu, cuda, rocm, opencl, metal, vpi, vulkan, ext_dev, hexagon, webgpu -from .tensor import from_dlpack, Tensor, Shape +from ._tensor import Device, device, DLDeviceType +from ._tensor import from_dlpack, Tensor, Shape from .container import Array, Map -from .module import Module, ModulePropertyMask, system_lib, load_module +from .module import Module, system_lib, load_module from . import serialization from . import access_path from . import testing @@ -38,32 +42,21 @@ __all__ = [ "dtype", - "DataTypeCode", "Device", "Object", "register_object", - "register_func", + "register_global_func", "get_global_func", - "_init_api", + "remove_global_func", + "init_ffi_api", "Object", - "ObjectGeneric", + "ObjectConvertible", "Function", "convert", - "String", - "Bytes", "register_error", "Device", "device", - "cpu", - "cuda", - "rocm", - "opencl", - "metal", - "vpi", - "vulkan", - "ext_dev", - "hexagon", - "webgpu", + "DLDeviceType", "from_dlpack", "Tensor", "Shape", @@ -73,7 +66,6 @@ "access_path", "serialization", "Module", - "ModulePropertyMask", "system_lib", "load_module", ] diff --git a/ffi/python/tvm_ffi/convert.py b/ffi/python/tvm_ffi/_convert.py similarity index 91% rename from ffi/python/tvm_ffi/convert.py rename to ffi/python/tvm_ffi/_convert.py index 94c82991101b..168dd15b531b 100644 --- a/ffi/python/tvm_ffi/convert.py +++ b/ffi/python/tvm_ffi/_convert.py @@ -33,6 +33,12 @@ def convert(value: Any) -> Any: ------- ffi_obj : Any The converted TVM FFI object. + + Note + ---- + Function arguments to ffi function calls are + automatically converted. So this function is mainly + only used in internal or testing scenarios. 
""" if isinstance(value, core.Object): return value @@ -48,7 +54,7 @@ def convert(value: Any) -> Any: return core.String(value) elif isinstance(value, (bytes, bytearray)): return core.Bytes(value) - elif isinstance(value, core.ObjectGeneric): + elif isinstance(value, core.ObjectConvertible): return value.asobject() elif callable(value): return core._convert_to_ffi_func(value) diff --git a/ffi/python/tvm_ffi/dtype.py b/ffi/python/tvm_ffi/_dtype.py similarity index 70% rename from ffi/python/tvm_ffi/dtype.py rename to ffi/python/tvm_ffi/_dtype.py index cd9561695503..30409e41d1cf 100644 --- a/ffi/python/tvm_ffi/dtype.py +++ b/ffi/python/tvm_ffi/_dtype.py @@ -22,7 +22,7 @@ class DataTypeCode(IntEnum): - """DataType code in DLTensor.""" + """DLDataTypeCode code in DLTensor.""" INT = 0 UINT = 1 @@ -57,7 +57,7 @@ class dtype(str): __slots__ = ["__tvm_ffi_dtype__"] - NUMPY_DTYPE_TO_STR = {} + _NUMPY_DTYPE_TO_STR = {} def __new__(cls, content): content = str(content) @@ -111,30 +111,30 @@ def lanes(self): # although almost in all cases we want numpy import numpy as np - dtype.NUMPY_DTYPE_TO_STR[np.dtype(np.bool_)] = "bool" - dtype.NUMPY_DTYPE_TO_STR[np.dtype(np.int8)] = "int8" - dtype.NUMPY_DTYPE_TO_STR[np.dtype(np.int16)] = "int16" - dtype.NUMPY_DTYPE_TO_STR[np.dtype(np.int32)] = "int32" - dtype.NUMPY_DTYPE_TO_STR[np.dtype(np.int64)] = "int64" - dtype.NUMPY_DTYPE_TO_STR[np.dtype(np.uint8)] = "uint8" - dtype.NUMPY_DTYPE_TO_STR[np.dtype(np.uint16)] = "uint16" - dtype.NUMPY_DTYPE_TO_STR[np.dtype(np.uint32)] = "uint32" - dtype.NUMPY_DTYPE_TO_STR[np.dtype(np.uint64)] = "uint64" - dtype.NUMPY_DTYPE_TO_STR[np.dtype(np.float16)] = "float16" - dtype.NUMPY_DTYPE_TO_STR[np.dtype(np.float32)] = "float32" - dtype.NUMPY_DTYPE_TO_STR[np.dtype(np.float64)] = "float64" + dtype._NUMPY_DTYPE_TO_STR[np.dtype(np.bool_)] = "bool" + dtype._NUMPY_DTYPE_TO_STR[np.dtype(np.int8)] = "int8" + dtype._NUMPY_DTYPE_TO_STR[np.dtype(np.int16)] = "int16" + dtype._NUMPY_DTYPE_TO_STR[np.dtype(np.int32)] = "int32" + dtype._NUMPY_DTYPE_TO_STR[np.dtype(np.int64)] = "int64" + dtype._NUMPY_DTYPE_TO_STR[np.dtype(np.uint8)] = "uint8" + dtype._NUMPY_DTYPE_TO_STR[np.dtype(np.uint16)] = "uint16" + dtype._NUMPY_DTYPE_TO_STR[np.dtype(np.uint32)] = "uint32" + dtype._NUMPY_DTYPE_TO_STR[np.dtype(np.uint64)] = "uint64" + dtype._NUMPY_DTYPE_TO_STR[np.dtype(np.float16)] = "float16" + dtype._NUMPY_DTYPE_TO_STR[np.dtype(np.float32)] = "float32" + dtype._NUMPY_DTYPE_TO_STR[np.dtype(np.float64)] = "float64" if hasattr(np, "float_"): - dtype.NUMPY_DTYPE_TO_STR[np.dtype(np.float_)] = "float64" + dtype._NUMPY_DTYPE_TO_STR[np.dtype(np.float_)] = "float64" except ImportError: pass try: import ml_dtypes - dtype.NUMPY_DTYPE_TO_STR[np.dtype(ml_dtypes.bfloat16)] = "bfloat16" - dtype.NUMPY_DTYPE_TO_STR[np.dtype(ml_dtypes.float8_e4m3fn)] = "float8_e4m3fn" - dtype.NUMPY_DTYPE_TO_STR[np.dtype(ml_dtypes.float8_e5m2)] = "float8_e5m2" - dtype.NUMPY_DTYPE_TO_STR[np.dtype(ml_dtypes.float4_e2m1fn)] = "float4_e2m1fn" + dtype._NUMPY_DTYPE_TO_STR[np.dtype(ml_dtypes.bfloat16)] = "bfloat16" + dtype._NUMPY_DTYPE_TO_STR[np.dtype(ml_dtypes.float8_e4m3fn)] = "float8_e4m3fn" + dtype._NUMPY_DTYPE_TO_STR[np.dtype(ml_dtypes.float8_e5m2)] = "float8_e5m2" + dtype._NUMPY_DTYPE_TO_STR[np.dtype(ml_dtypes.float4_e2m1fn)] = "float4_e2m1fn" except ImportError: pass diff --git a/ffi/python/tvm_ffi/_ffi_api.py b/ffi/python/tvm_ffi/_ffi_api.py index 60bd2463e9ac..1c2326c0fefd 100644 --- a/ffi/python/tvm_ffi/_ffi_api.py +++ b/ffi/python/tvm_ffi/_ffi_api.py @@ -15,7 +15,6 @@ # specific language 
governing permissions and limitations # under the License. """FFI API.""" -from .registry import _init_api +from . import registry - -_init_api("ffi", __name__) +registry.init_ffi_api("ffi", __name__) diff --git a/ffi/python/tvm_ffi/_tensor.py b/ffi/python/tvm_ffi/_tensor.py new file mode 100644 index 000000000000..c0c9a20731f4 --- /dev/null +++ b/ffi/python/tvm_ffi/_tensor.py @@ -0,0 +1,88 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Tensor related objects and functions.""" +# we name it as _tensor.py to avoid potential future case +# if we also want to expose a tensor function in the root namespace + +from numbers import Integral +from . import core +from .core import Device, DLDeviceType, Tensor, from_dlpack +from . import registry +from . import _ffi_api + + +@registry.register_object("ffi.Shape") +class Shape(tuple, core.PyNativeObject): + """Shape tuple that represents `ffi::Shape` returned by a ffi call. + + Note + ---- + This class subclasses `tuple` so it can be used in most places where + tuple is used in python array apis. + """ + + def __new__(cls, content): + if any(not isinstance(x, Integral) for x in content): + raise ValueError("Shape must be a tuple of integers") + val = tuple.__new__(cls, content) + val.__init_tvm_ffi_object_by_constructor__(_ffi_api.Shape, *content) + return val + + # pylint: disable=no-self-argument + def __from_tvm_ffi_object__(cls, obj): + """Construct from a given tvm object.""" + content = core._shape_obj_get_py_tuple(obj) + val = tuple.__new__(cls, content) + val.__tvm_ffi_object__ = obj + return val + + +def device(device_type, index=None): + """Construct a TVM FFI device with given device type and index + + Parameters + ---------- + device_type: str or int + The device type or name. + + index: int, optional + The device index. + + Returns + ------- + device: tvm_ffi.Device + + Examples + -------- + Device can be used to create reflection of device by + string representation of the device type. + + .. code-block:: python + + assert tvm_ffi.device("cuda:0") == tvm_ffi.device("cuda", 0) + assert tvm_ffi.device("cpu:0") == tvm_ffi.device("cpu", 0) + """ + return core._CLASS_DEVICE(device_type, index) + + +__all__ = [ + "from_dlpack", + "Tensor", + "device", + "Device", + "DLDeviceType", +] diff --git a/ffi/python/tvm_ffi/container.py b/ffi/python/tvm_ffi/container.py index 157840ba9d46..fedc0a281ba8 100644 --- a/ffi/python/tvm_ffi/container.py +++ b/ffi/python/tvm_ffi/container.py @@ -66,7 +66,29 @@ def getitem_helper(obj, elem_getter, length, idx): @register_object("ffi.Array") class Array(core.Object, collections.abc.Sequence): - """Array container""" + """Array container that represents a sequence of values in ffi. + + {py:func}`tvm_ffi.convert` will map python list/tuple to this class. 
+ + Parameters + ---------- + input_list : Sequence[Any] + The list of values to be stored in the array. + + See Also + -------- + {py:func}`tvm_ffi.convert` + + Examples + -------- + .. code-block:: python + + import tvm_ffi + + a = tvm_ffi.convert([1, 2, 3]) + assert isinstance(a, tvm_ffi.Array) + assert len(a) == 3 + """ def __init__(self, input_list: Sequence[Any]): self.__init_handle_by_constructor__(_ffi_api.Array, *input_list) @@ -150,7 +172,31 @@ def __iter__(self): @register_object("ffi.Map") class Map(core.Object, collections.abc.Mapping): - """Map container.""" + """Map container. + + {py:func}`tvm_ffi.convert` will map python dict to this class. + + Parameters + ---------- + input_dict : Mapping[Any, Any] + The dictionary of values to be stored in the map. + + See Also + -------- + {py:func}`tvm_ffi.convert` + + Examples + -------- + .. code-block:: python + + import tvm_ffi + + amap = tvm_ffi.convert({"a": 1, "b": 2}) + assert isinstance(amap, tvm_ffi.Map) + assert len(amap) == 2 + assert amap["a"] == 1 + assert amap["b"] == 2 + """ def __init__(self, input_dict: Mapping[Any, Any]): list_kvs = [] diff --git a/ffi/python/tvm_ffi/cython/device.pxi b/ffi/python/tvm_ffi/cython/device.pxi index 90d641c44ffa..85740a067a63 100644 --- a/ffi/python/tvm_ffi/cython/device.pxi +++ b/ffi/python/tvm_ffi/cython/device.pxi @@ -16,6 +16,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from enum import IntEnum _CLASS_DEVICE = None @@ -31,19 +32,8 @@ def _create_device_from_tuple(cls, device_type, device_id): return ret -cdef class Device: - """Device is a wrapper around DLDevice. - - Parameters - ---------- - device_type_or_name : Union[str, int] - The string representation of the device type - - device_id : int - The device id - """ - cdef DLDevice cdevice - +class DLDeviceType(IntEnum): + """The enum that maps to DLDeviceType.""" kDLCPU = 1 kDLCUDA = 2 kDLCUDAHost = 3 @@ -59,62 +49,88 @@ cdef class Device: kDLWebGPU = 15 kDLHexagon = 16 - DEVICE_TYPE_TO_NAME = { - kDLCPU: "cpu", - kDLCUDA: "cuda", - kDLCUDAHost: "cuda_host", - kDLCUDAManaged: "cuda_managed", - kDLOpenCL: "opencl", - kDLVulkan: "vulkan", - kDLMetal: "metal", - kDLVPI: "vpi", - kDLROCM: "rocm", - kDLROCMHost: "rocm_host", - kDLExtDev: "ext_dev", - kDLOneAPI: "oneapi", - kDLWebGPU: "webgpu", - kDLHexagon: "hexagon", + +cdef class Device: + """Device represents a device in the ffi system. + + Device is a thin wrapper around DLDevice in DLPack standard. + + Parameters + ---------- + device_type : Union[str, int] + The string representation of the device type + + index : int + The device id + + Examples + -------- + You can use `tvm_ffi.device` function to create a `Device`. + + .. 
code-block:: python + + assert tvm_ffi.device("cuda:0") == tvm_ffi.device("cuda", 0) + assert tvm_ffi.device("cpu:0") == tvm_ffi.device("cpu", 0) + """ + cdef DLDevice cdevice + + _DEVICE_TYPE_TO_NAME = { + DLDeviceType.kDLCPU: "cpu", + DLDeviceType.kDLCUDA: "cuda", + DLDeviceType.kDLCUDAHost: "cuda_host", + DLDeviceType.kDLCUDAManaged: "cuda_managed", + DLDeviceType.kDLOpenCL: "opencl", + DLDeviceType.kDLVulkan: "vulkan", + DLDeviceType.kDLMetal: "metal", + DLDeviceType.kDLVPI: "vpi", + DLDeviceType.kDLROCM: "rocm", + DLDeviceType.kDLROCMHost: "rocm_host", + DLDeviceType.kDLExtDev: "ext_dev", + DLDeviceType.kDLOneAPI: "oneapi", + DLDeviceType.kDLWebGPU: "webgpu", + DLDeviceType.kDLHexagon: "hexagon", } - DEVICE_NAME_TO_TYPE = { - "llvm": kDLCPU, - "cpu": kDLCPU, - "c": kDLCPU, - "test": kDLCPU, - "hybrid": kDLCPU, - "composite": kDLCPU, - "cuda": kDLCUDA, - "nvptx": kDLCUDA, - "cl": kDLOpenCL, - "opencl": kDLOpenCL, - "vulkan": kDLVulkan, - "metal": kDLMetal, - "vpi": kDLVPI, - "rocm": kDLROCM, - "ext_dev": kDLExtDev, - "hexagon": kDLHexagon, - "webgpu": kDLWebGPU, + _DEVICE_NAME_TO_TYPE = { + "llvm": DLDeviceType.kDLCPU, + "cpu": DLDeviceType.kDLCPU, + "c": DLDeviceType.kDLCPU, + "test": DLDeviceType.kDLCPU, + "cuda": DLDeviceType.kDLCUDA, + "nvptx": DLDeviceType.kDLCUDA, + "cl": DLDeviceType.kDLOpenCL, + "opencl": DLDeviceType.kDLOpenCL, + "vulkan": DLDeviceType.kDLVulkan, + "metal": DLDeviceType.kDLMetal, + "vpi": DLDeviceType.kDLVPI, + "rocm": DLDeviceType.kDLROCM, + "ext_dev": DLDeviceType.kDLExtDev, + "hexagon": DLDeviceType.kDLHexagon, + "webgpu": DLDeviceType.kDLWebGPU, } - def __init__(self, device_type_or_name, device_id = None): + def __init__(self, device_type, index = None): + device_type_or_name = device_type + index = index if index is not None else 0 if isinstance(device_type_or_name, str): + # skip suffix annotations + device_type_or_name = device_type_or_name.split(" ")[0] parts = device_type_or_name.split(":") if len(parts) < 1 or len(parts) > 2: raise ValueError(f"Invalid device: {device_type_or_name}") - if parts[0] not in self.DEVICE_NAME_TO_TYPE: + if parts[0] not in self._DEVICE_NAME_TO_TYPE: raise ValueError(f"Unknown device: {parts[0]}") - device_type = self.DEVICE_NAME_TO_TYPE[parts[0]] + device_type = self._DEVICE_NAME_TO_TYPE[parts[0]] if len(parts) == 2: try: - device_id = int(parts[1]) + index = int(parts[1]) except ValueError: - raise ValueError(f"Invalid device id: {parts[1]}") + raise ValueError(f"Invalid device index: {parts[1]}") else: device_type = device_type_or_name - device_id = device_id if device_id is not None else 0 - if not isinstance(device_id, int): - raise TypeError(f"Invalid device id: {device_id}") - self.cdevice = TVMFFIDLDeviceFromIntPair(device_type, device_id) + if not isinstance(index, int): + raise TypeError(f"Invalid device index: {index}") + self.cdevice = TVMFFIDLDeviceFromIntPair(device_type, index) def __reduce__(self): cls = type(self) @@ -131,9 +147,6 @@ cdef class Device: def __ne__(self, other): return not self.__eq__(other) - def __device_type_name__(self): - return self.DEVICE_TYPE_TO_NAME[self.cdevice.device_type] - def __str__(self): cdef int dev_type = self.cdevice.device_type name = self.__device_type_name__() @@ -149,14 +162,25 @@ cdef class Device: def __hash__(self): return hash((self.cdevice.device_type, self.cdevice.device_id)) + + def __device_type_name__(self): + return self._DEVICE_TYPE_TO_NAME[self.cdevice.device_type] + @property - def device_type(self): - return self.cdevice.device_type + def type(self): + 
"""String representation of the device type.""" + return self.__device_type_name__() @property - def device_id(self): + def index(self): + """The device index.""" return self.cdevice.device_id + def dlpack_device_type(self): + """The device type int code used in the DLPack specification. + """ + return self.cdevice.device_type + cdef inline object make_ret_device(TVMFFIAny result): ret = _CLASS_DEVICE.__new__(_CLASS_DEVICE) diff --git a/ffi/python/tvm_ffi/cython/function.pxi b/ffi/python/tvm_ffi/cython/function.pxi index ea10356077da..0161ec4292ab 100644 --- a/ffi/python/tvm_ffi/cython/function.pxi +++ b/ffi/python/tvm_ffi/cython/function.pxi @@ -167,7 +167,7 @@ cdef inline int make_args(tuple py_args, TVMFFIAny* out, list temp_args, out[i].v_int64 = 0 out[i].v_ptr = (arg).cptr() temp_args.append(arg) - elif isinstance(arg, (list, tuple, dict, ObjectGeneric)): + elif isinstance(arg, (list, tuple, dict, ObjectConvertible)): arg = _FUNC_CONVERT_TO_OBJECT(arg) out[i].type_index = TVMFFIObjectGetTypeIndex((arg).chandle) out[i].v_ptr = (arg).chandle @@ -277,11 +277,11 @@ cdef inline int ConstructorCall(void* constructor_handle, class Function(Object): - """The Function object used in TVM FFI. + """Python class that wraps a function with tvm-ffi ABI. See Also -------- - tvm_ffi.register_func: How to register global function. + tvm_ffi.register_global_func: How to register global function. tvm_ffi.get_global_func: How to get global function. """ def __call__(self, *args): diff --git a/ffi/python/tvm_ffi/cython/object.pxi b/ffi/python/tvm_ffi/cython/object.pxi index fda7f56b23be..2a306e01ee68 100644 --- a/ffi/python/tvm_ffi/cython/object.pxi +++ b/ffi/python/tvm_ffi/cython/object.pxi @@ -43,7 +43,7 @@ _OBJECT_FROM_JSON_GRAPH_STR = None _OBJECT_TO_JSON_GRAPH_STR = None -class ObjectGeneric: +class ObjectConvertible: """Base class for all classes that can be converted to object.""" def asobject(self): @@ -195,7 +195,13 @@ cdef class Object: cdef class OpaquePyObject(Object): - """Opaque PyObject container""" + """Opaque PyObject container + + This is a helper class to store opaque python objects + that will be passed to the ffi functions. + + Users do not need to directly create this class. + """ def pyobject(self): """Get the underlying python object""" cdef object obj diff --git a/ffi/python/tvm_ffi/cython/tensor.pxi b/ffi/python/tvm_ffi/cython/tensor.pxi index 5544359c9e02..b09ac42eb99c 100644 --- a/ffi/python/tvm_ffi/cython/tensor.pxi +++ b/ffi/python/tvm_ffi/cython/tensor.pxi @@ -101,6 +101,11 @@ def from_dlpack(ext_tensor, *, required_alignment=8, required_contiguous=True): required_contiguous : bool Whether to check for contiguous memory. + + Returns + ------- + tensor : :py:class:`tvm_ffi.Tensor` + The converted tensor. """ cdef TVMFFIObjectHandle chandle # as of most frameworks do not yet support v1.1 @@ -157,14 +162,10 @@ def _shape_obj_get_py_tuple(obj): cdef class Tensor(Object): - """N-dimensional array that is compatible with DLPack. + """Tensor object that represents a managed n-dimensional array. 
""" cdef DLTensor* cdltensor - @property - def is_view(self): - return self.cdltensor != NULL and self.chandle == NULL - @property def shape(self): """Shape of this array""" @@ -179,22 +180,12 @@ cdef class Tensor(Object): @property def device(self): - """Device of this array""" + """Device of this Tensor""" cdef TVMFFIAny device_any device_any.v_device = self.cdltensor.device return make_ret_device(device_any) - def to_dlpack(self): - """Produce an array from a DLPack Tensor without copying memory - - Returns - ------- - dlpack : DLPack tensor view of the array data - - Note - ---- - This is an old style legacy API, consider use new dlpack api instead. - """ + def _to_dlpack(self): cdef DLManagedTensor* dltensor cdef int c_api_ret_code @@ -248,7 +239,7 @@ cdef class Tensor(Object): # Keep and use the DLPack 0.X implementation # Note: from March 2025 onwards (but ideally as late as # possible), it's okay to raise BufferError here - return self.to_dlpack() + return self._to_dlpack() else: # We get to produce `DLManagedTensorVersioned` now. Note that # our_own_dlpack_version is the max version that the *producer* @@ -261,7 +252,7 @@ cdef class Tensor(Object): raise BufferError("copy not yet supported") return self._to_dlpack_versioned() elif max_version[0] < 1: - return self.to_dlpack() + return self.__ctypes_handle__to_dlpack() else: raise BufferError(f"Unsupported max_version {max_version}") diff --git a/ffi/python/tvm_ffi/module.py b/ffi/python/tvm_ffi/module.py index 684018416e62..56c2a9385517 100644 --- a/ffi/python/tvm_ffi/module.py +++ b/ffi/python/tvm_ffi/module.py @@ -36,7 +36,23 @@ class ModulePropertyMask(IntEnum): @register_object("ffi.Module") class Module(core.Object): - """Runtime Module.""" + """Module container for dynamically loaded Module. + + Example + ------- + .. code-block:: python + + import tvm_ffi + + # load the module from a tvm-ffi shared library + mod : tvm_ffi.Module = tvm_ffi.load_module("path/to/library.so") + # you can use mod.func_name to call the exported function + mod.func_name(*args) + + See Also + -------- + :py:func:`tvm_ffi.load_module` + """ # constant for entry function name entry_name = "main" @@ -242,7 +258,18 @@ def load_module(path): Returns ------- - module : ffi.Module + module : :py:class:`tvm_ffi.Module` The loaded module + + Examples + -------- + .. code-block:: python + + mod = tvm_ffi.load_module("path/to/module.so") + mod.func_name(*args) + + See Also + -------- + :py:class:`tvm_ffi.Module` """ return _ffi_api.ModuleLoadFromFile(path) diff --git a/ffi/python/tvm_ffi/registry.py b/ffi/python/tvm_ffi/registry.py index e2455c3d3384..b43e0dc6bb6b 100644 --- a/ffi/python/tvm_ffi/registry.py +++ b/ffi/python/tvm_ffi/registry.py @@ -60,7 +60,7 @@ def register(cls): return register(type_key) -def register_func(func_name, f=None, override=False): +def register_global_func(func_name, f=None, override=False): """Register global function Parameters @@ -78,6 +78,30 @@ def register_func(func_name, f=None, override=False): ------- fregister : function Register function if f is not specified. + + Examples + -------- + .. 
code-block:: python + + import tvm_ffi + + # we can use decorator to register a function + @tvm_ffi.register_global_func("mytest.echo") + def echo(x): + return x + # After registering, we can get the function by its name + f = tvm_ffi.get_global_func("mytest.echo") + assert f(1) == 1 + + # we can also directly register a function + tvm_ffi.register_global_func("mytest.add_one", lambda x: x + 1) + f = tvm_ffi.get_global_func("mytest.add_one") + assert f(1) == 2 + + See Also + -------- + :py:func:`tvm_ffi.get_global_func` + :py:func:`tvm_ffi.remove_global_func` """ if callable(func_name): f = func_name @@ -110,6 +134,10 @@ def get_global_func(name, allow_missing=False): ------- func : Function The function to be returned, None if function is missing. + + See Also + -------- + :py:func:`tvm_ffi.register_global_func` """ return core._get_global_func(name, allow_missing) @@ -138,14 +166,33 @@ def remove_global_func(name): get_global_func("ffi.FunctionRemoveGlobal")(name) -def _init_api(namespace, target_module_name=None): - """Initialize api for a given module name +def init_ffi_api(namespace, target_module_name=None): + """Initialize register ffi api functions into a given module + Parameters + ---------- namespace : str The namespace of the source registry target_module_name : str The target module name if different from namespace + + Examples + -------- + + A typical usage pattern is to create a _ffi_api.py file to register + the functions under a given module. The following + code populates all registered global functions + prefixed with ``mypackage.`` into the current module, + then we can call the function through ``_ffi_api.func_name(*args)`` + which will call into the registered global function "mypackage.func_name". + + .. code-block:: python + + # _ffi_api.py + import tvm_ffi + + tvm_ffi.init_ffi_api("mypackage", __name__) """ target_module_name = target_module_name if target_module_name else namespace @@ -171,9 +218,9 @@ def _init_api(namespace, target_module_name=None): __all__ = [ "register_object", - "register_func", + "register_global_func", "get_global_func", "list_global_func_names", "remove_global_func", - "_init_api", + "init_ffi_api", ] diff --git a/ffi/python/tvm_ffi/tensor.py b/ffi/python/tvm_ffi/tensor.py deleted file mode 100644 index 97240c6a499f..000000000000 --- a/ffi/python/tvm_ffi/tensor.py +++ /dev/null @@ -1,255 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""Tensor related objects and functions.""" - -from numbers import Integral -from . import core -from .core import Device, Tensor, from_dlpack -from . import registry -from . 
import _ffi_api - - -@registry.register_object("ffi.Shape") -class Shape(tuple, core.PyNativeObject): - """Shape object that is possibly returned by FFI call.""" - - def __new__(cls, content): - if any(not isinstance(x, Integral) for x in content): - raise ValueError("Shape must be a tuple of integers") - val = tuple.__new__(cls, content) - val.__init_tvm_ffi_object_by_constructor__(_ffi_api.Shape, *content) - return val - - # pylint: disable=no-self-argument - def __from_tvm_ffi_object__(cls, obj): - """Construct from a given tvm object.""" - content = core._shape_obj_get_py_tuple(obj) - val = tuple.__new__(cls, content) - val.__tvm_ffi_object__ = obj - return val - - -def device(dev_type, dev_id=0): - """Construct a TVM FFIdevice with given device type and id. - - Parameters - ---------- - dev_type: int or str - The device type mask or name of the device. - - dev_id : int, optional - The integer device id - - Returns - ------- - dev: tvm_ffi.Device - - Examples - -------- - Device can be used to create reflection of device by - string representation of the device type. - - .. code-block:: python - - assert tvm_ffi.device("cuda:0") == tvm_ffi.cuda(1) - assert tvm_ffi.device("cpu", 0) == tvm_ffi.cpu(0) - """ - if isinstance(dev_type, str): - dev_type = dev_type.split(" ")[0] - return core._CLASS_DEVICE(dev_type, dev_id) - - -def cpu(dev_id=0): - """Construct a CPU device - - Parameters - ---------- - dev_id : int, optional - The integer device id - - Returns - ------- - dev : Device - The created device - """ - return device(Device.kDLCPU, dev_id) - - -def cuda(dev_id=0): - """Construct a CUDA GPU device - - Parameters - ---------- - dev_id : int, optional - The integer device id - - Returns - ------- - dev : Device - The created device - """ - return device(Device.kDLCUDA, dev_id) - - -def rocm(dev_id=0): - """Construct a ROCM device - - Parameters - ---------- - dev_id : int, optional - The integer device id - - Returns - ------- - dev : Device - The created device - """ - return device(Device.kDLROCM, dev_id) - - -def opencl(dev_id=0): - """Construct a OpenCL device - - Parameters - ---------- - dev_id : int, optional - The integer device id - - Returns - ------- - dev : Device - The created device - """ - return device(Device.kDLOpenCL, dev_id) - - -def metal(dev_id=0): - """Construct a metal device - - Parameters - ---------- - dev_id : int, optional - The integer device id - - Returns - ------- - dev : Device - The created device - """ - return device(Device.kDLMetal, dev_id) - - -def vpi(dev_id=0): - """Construct a VPI simulated device - - Parameters - ---------- - dev_id : int, optional - The integer device id - - Returns - ------- - dev : Device - The created device - """ - return device(Device.kDLVPI, dev_id) - - -def vulkan(dev_id=0): - """Construct a Vulkan device - - Parameters - ---------- - dev_id : int, optional - The integer device id - - Returns - ------- - dev : Device - The created device - """ - return device(Device.kDLVulkan, dev_id) - - -def ext_dev(dev_id=0): - """Construct a extension device - - Parameters - ---------- - dev_id : int, optional - The integer device id - - Returns - ------- - dev : Device - The created device - - Note - ---- - This API is reserved for quick testing of new - device by plugin device API as ext_dev. 
- """ - return device(Device.kDLExtDev, dev_id) - - -def hexagon(dev_id=0): - """Construct a Hexagon device - - Parameters - ---------- - dev_id : int, optional - The integer device id - - Returns - ------- - dev : Device - The created device - """ - return device(Device.kDLHexagon, dev_id) - - -def webgpu(dev_id=0): - """Construct a webgpu device. - - Parameters - ---------- - dev_id : int, optional - The integer device id - - Returns - ------- - dev : Device - The created device - """ - return device(Device.kDLWebGPU, dev_id) - - -__all__ = [ - "from_dlpack", - "Tensor", - "device", - "cpu", - "cuda", - "rocm", - "opencl", - "metal", - "vpi", - "vulkan", - "ext_dev", - "hexagon", - "webgpu", -] diff --git a/ffi/tests/python/test_device.py b/ffi/tests/python/test_device.py index 645738710f30..849f45b8f97d 100644 --- a/ffi/tests/python/test_device.py +++ b/ffi/tests/python/test_device.py @@ -17,22 +17,22 @@ import pytest import pickle -from tvm_ffi import Device +from tvm_ffi import Device, DLDeviceType import tvm_ffi def test_device(): device = tvm_ffi.Device("cuda", 0) - assert device.device_type == tvm_ffi.Device.kDLCUDA - assert device.device_id == 0 + assert device.dlpack_device_type() == tvm_ffi.DLDeviceType.kDLCUDA + assert device.index == 0 assert str(device) == "cuda:0" assert device.__repr__() == "device(type='cuda', index=0)" def test_device_from_str(): device = tvm_ffi.device("ext_dev:0") - assert device.device_type == tvm_ffi.Device.kDLExtDev - assert device.device_id == 0 + assert device.dlpack_device_type() == tvm_ffi.DLDeviceType.kDLExtDev + assert device.index == 0 assert str(device) == "ext_dev:0" assert device.__repr__() == "device(type='ext_dev', index=0)" @@ -40,33 +40,33 @@ def test_device_from_str(): @pytest.mark.parametrize( "dev_str, expected_device_type, expect_device_id", [ - ("cpu", Device.kDLCPU, 0), - ("cuda", Device.kDLCUDA, 0), - ("cuda:0", Device.kDLCUDA, 0), - ("cuda:3", Device.kDLCUDA, 3), - ("metal:2", Device.kDLMetal, 2), + ("cpu", DLDeviceType.kDLCPU, 0), + ("cuda", DLDeviceType.kDLCUDA, 0), + ("cuda:0", DLDeviceType.kDLCUDA, 0), + ("cuda:3", DLDeviceType.kDLCUDA, 3), + ("metal:2", DLDeviceType.kDLMetal, 2), ], ) def test_device(dev_str, expected_device_type, expect_device_id): dev = tvm_ffi.device(dev_str) - assert dev.device_type == expected_device_type - assert dev.device_id == expect_device_id + assert dev.dlpack_device_type() == expected_device_type + assert dev.index == expect_device_id @pytest.mark.parametrize( "dev_type, dev_id, expected_device_type, expect_device_id", [ - ("cpu", 0, Device.kDLCPU, 0), - ("cuda", 0, Device.kDLCUDA, 0), - (Device.kDLCUDA, 0, Device.kDLCUDA, 0), - ("cuda", 3, Device.kDLCUDA, 3), - (Device.kDLMetal, 2, Device.kDLMetal, 2), + ("cpu", 0, DLDeviceType.kDLCPU, 0), + ("cuda", 0, DLDeviceType.kDLCUDA, 0), + (DLDeviceType.kDLCUDA, 0, DLDeviceType.kDLCUDA, 0), + ("cuda", 3, DLDeviceType.kDLCUDA, 3), + (DLDeviceType.kDLMetal, 2, DLDeviceType.kDLMetal, 2), ], ) def test_device_with_dev_id(dev_type, dev_id, expected_device_type, expect_device_id): - dev = tvm_ffi.device(dev_type=dev_type, dev_id=dev_id) - assert dev.device_type == expected_device_type - assert dev.device_id == expect_device_id + dev = tvm_ffi.device(dev_type, dev_id) + assert dev.dlpack_device_type() == expected_device_type + assert dev.index == expect_device_id @pytest.mark.parametrize( @@ -79,16 +79,16 @@ def test_device_with_dev_id(dev_type, dev_id, expected_device_type, expect_devic ) def test_deive_type_error(dev_type, dev_id): with 
pytest.raises(ValueError): - dev = tvm_ffi.device(dev_type=dev_type, dev_id=dev_id) + dev = tvm_ffi.device(dev_type, dev_id) def test_deive_id_error(): with pytest.raises(TypeError): - dev = tvm_ffi.device(dev_type="cpu", dev_id="?") + dev = tvm_ffi.device("cpu", "?") def test_device_pickle(): device = tvm_ffi.device("cuda", 0) device_pickled = pickle.loads(pickle.dumps(device)) - assert device_pickled.device_type == device.device_type - assert device_pickled.device_id == device.device_id + assert device_pickled.dlpack_device_type() == device.dlpack_device_type() + assert device_pickled.index == device.index diff --git a/ffi/tests/python/test_examples.py b/ffi/tests/python/test_examples.py new file mode 100644 index 000000000000..f8a94636a284 --- /dev/null +++ b/ffi/tests/python/test_examples.py @@ -0,0 +1,47 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# testcases appearing in example docstrings +import tvm_ffi + + +def test_register_global_func(): + # we can use decorator to register a function + @tvm_ffi.register_global_func("example.echo") + def echo(x): + return x + + # After registering, we can get the function by its name + f = tvm_ffi.get_global_func("example.echo") + assert f(1) == 1 + # we can also directly register a function + tvm_ffi.register_global_func("example.add_one", lambda x: x + 1) + f = tvm_ffi.get_global_func("example.add_one") + assert f(1) == 2 + + +def test_array(): + a = tvm_ffi.convert([1, 2, 3]) + assert isinstance(a, tvm_ffi.Array) + assert len(a) == 3 + + +def test_map(): + amap = tvm_ffi.convert({"a": 1, "b": 2}) + assert isinstance(amap, tvm_ffi.Map) + assert len(amap) == 2 + assert amap["a"] == 1 + assert amap["b"] == 2 diff --git a/ffi/tests/python/test_function.py b/ffi/tests/python/test_function.py index 0b45fe5583b3..dfe22a1bad80 100644 --- a/ffi/tests/python/test_function.py +++ b/ffi/tests/python/test_function.py @@ -58,8 +58,8 @@ def test_echo(): # test device device_result = fecho(tvm_ffi.device("cuda:1")) assert isinstance(device_result, tvm_ffi.Device) - assert device_result.device_type == tvm_ffi.Device.kDLCUDA - assert device_result.device_id == 1 + assert device_result.dlpack_device_type() == tvm_ffi.DLDeviceType.kDLCUDA + assert device_result.index == 1 assert str(device_result) == "cuda:1" assert device_result.__repr__() == "device(type='cuda', index=1)" @@ -85,8 +85,8 @@ def check_tensor(): assert isinstance(tensor_result, tvm_ffi.Tensor) assert tensor_result.shape == (10,) assert tensor_result.dtype == tvm_ffi.dtype("int32") - assert tensor_result.device.device_type == tvm_ffi.Device.kDLCPU - assert tensor_result.device.device_id == 0 + assert tensor_result.device.dlpack_device_type() == tvm_ffi.DLDeviceType.kDLCPU + assert tensor_result.device.index == 0 check_tensor() @@ -113,7 +113,7 @@ def 
fapply(f, *args): def test_global_func(): - @tvm_ffi.register_func("mytest.echo") + @tvm_ffi.register_global_func("mytest.echo") def echo(x): return x diff --git a/ffi/tests/python/test_string.py b/ffi/tests/python/test_string.py index f334bc4fadba..feaa9584d2fc 100644 --- a/ffi/tests/python/test_string.py +++ b/ffi/tests/python/test_string.py @@ -21,7 +21,7 @@ def test_string(): fecho = tvm_ffi.get_global_func("testing.echo") - s = tvm_ffi.String("hello") + s = tvm_ffi.core.String("hello") s2 = fecho(s) assert s2 == "hello" s3 = tvm_ffi.convert("hello") @@ -36,19 +36,19 @@ def test_string(): def test_bytes(): fecho = tvm_ffi.get_global_func("testing.echo") - b = tvm_ffi.Bytes(b"hello") - assert isinstance(b, tvm_ffi.Bytes) + b = tvm_ffi.core.Bytes(b"hello") + assert isinstance(b, tvm_ffi.core.Bytes) b2 = fecho(b) assert b2 == b"hello" b3 = tvm_ffi.convert(b"hello") - assert isinstance(b3, tvm_ffi.Bytes) + assert isinstance(b3, tvm_ffi.core.Bytes) assert isinstance(b3, bytes) b4 = tvm_ffi.convert(bytearray(b"hello")) - assert isinstance(b4, tvm_ffi.Bytes) + assert isinstance(b4, tvm_ffi.core.Bytes) assert isinstance(b4, bytes) b5 = pickle.loads(pickle.dumps(b)) assert b5 == b"hello" - assert isinstance(b5, tvm_ffi.Bytes) + assert isinstance(b5, tvm_ffi.core.Bytes) diff --git a/ffi/tests/python/test_tensor.py b/ffi/tests/python/test_tensor.py index 2e2a99940017..aa2482f88852 100644 --- a/ffi/tests/python/test_tensor.py +++ b/ffi/tests/python/test_tensor.py @@ -33,8 +33,8 @@ def test_tensor_attributes(): assert isinstance(x, tvm_ffi.Tensor) assert x.shape == (10, 8, 4, 2) assert x.dtype == tvm_ffi.dtype("int16") - assert x.device.device_type == tvm_ffi.Device.kDLCPU - assert x.device.device_id == 0 + assert x.device.dlpack_device_type() == tvm_ffi.DLDeviceType.kDLCPU + assert x.device.index == 0 x2 = np.from_dlpack(x) np.testing.assert_equal(x2, data) @@ -61,8 +61,8 @@ def check(x, y): assert isinstance(y, tvm_ffi.Tensor) assert y.shape == (128,) assert y.dtype == tvm_ffi.dtype("int64") - assert y.device.device_type == tvm_ffi.Device.kDLCPU - assert y.device.device_id == 0 + assert y.device.dlpack_device_type() == tvm_ffi.DLDeviceType.kDLCPU + assert y.device.index == 0 x2 = torch.from_dlpack(y) np.testing.assert_equal(x2.numpy(), x.numpy()) diff --git a/include/tvm/relax/attrs/op.h b/include/tvm/relax/attrs/op.h index 337f8dc4cbc2..8af3f77539fe 100644 --- a/include/tvm/relax/attrs/op.h +++ b/include/tvm/relax/attrs/op.h @@ -107,15 +107,15 @@ struct ToVDeviceAttrs : public AttrsNodeReflAdapter { /*! 
\brief Attributes used in hint_on_device */ struct HintOnDeviceAttrs : public AttrsNodeReflAdapter { - int32_t dev_type; - int32_t dev_id; + int32_t device_type; + int32_t index; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; refl::ObjectDef() - .def_ro("dev_type", &HintOnDeviceAttrs::dev_type, + .def_ro("device_type", &HintOnDeviceAttrs::device_type, "The device type where the data is supposed to be executed.") - .def_ro("dev_id", &HintOnDeviceAttrs::dev_id, "The device id."); + .def_ro("index", &HintOnDeviceAttrs::index, "The device id."); } static constexpr const char* _type_key = "relax.attrs.HintOnDeviceAttrs"; diff --git a/python/tvm/__init__.py b/python/tvm/__init__.py index c3c8c559c84f..55c78e43c07b 100644 --- a/python/tvm/__init__.py +++ b/python/tvm/__init__.py @@ -21,7 +21,7 @@ import os # ffi module must load first -from tvm_ffi import register_object, register_func, get_global_func +from tvm_ffi import register_object, register_global_func, get_global_func # top-level alias from .base import TVMError, __version__, _RUNTIME_ONLY diff --git a/python/tvm/arith/_ffi_api.py b/python/tvm/arith/_ffi_api.py index aa9883934995..519423aa4e1f 100644 --- a/python/tvm/arith/_ffi_api.py +++ b/python/tvm/arith/_ffi_api.py @@ -18,4 +18,4 @@ import tvm_ffi -tvm_ffi._init_api("arith", __name__) +tvm_ffi.init_ffi_api("arith", __name__) diff --git a/python/tvm/contrib/coreml_runtime.py b/python/tvm/contrib/coreml_runtime.py index 34e0681d3162..1d185059f0bd 100644 --- a/python/tvm/contrib/coreml_runtime.py +++ b/python/tvm/contrib/coreml_runtime.py @@ -35,7 +35,7 @@ def create(symbol, compiled_model_path, device): coreml_runtime : CoreMLModule Runtime coreml module that can be used to execute the coreml model. """ - device_type = device.device_type + device_type = device.dlpack_device_type() runtime_func = "tvm.coreml_runtime.create" if device_type >= rpc_base.RPC_SESS_MASK: diff --git a/python/tvm/contrib/cutlass/_ffi_api.py b/python/tvm/contrib/cutlass/_ffi_api.py index 25393a8f99f8..d57825835b6b 100644 --- a/python/tvm/contrib/cutlass/_ffi_api.py +++ b/python/tvm/contrib/cutlass/_ffi_api.py @@ -17,4 +17,4 @@ """FFI API for CUTLASS BYOC.""" import tvm_ffi -tvm_ffi._init_api("contrib.cutlass", __name__) +tvm_ffi.init_ffi_api("contrib.cutlass", __name__) diff --git a/python/tvm/contrib/cutlass/build.py b/python/tvm/contrib/cutlass/build.py index 294ab36b2088..4b2a50a5f1d8 100644 --- a/python/tvm/contrib/cutlass/build.py +++ b/python/tvm/contrib/cutlass/build.py @@ -23,7 +23,7 @@ import os from functools import reduce from typing import Optional, Sequence -from tvm_ffi import register_func +from tvm_ffi import register_global_func import tvm from tvm import relax, runtime @@ -821,7 +821,7 @@ def visit_span(self, span): return span -@register_func("contrib.cutlass.tune_relax_function") +@register_global_func("contrib.cutlass.tune_relax_function") def profile_relax_function(functions, options): """Tune and annotate CUTLASS composite functions with shape, dtype and generated templates.""" tmp_dir = options.get("tmp_dir", "./tmp") @@ -840,7 +840,7 @@ def profile_relax_function(functions, options): return annotated_functions -@register_func("contrib.cutlass.compile") +@register_global_func("contrib.cutlass.compile") def compile_cutlass_module(c_source_module, options): """Compile all CUTLASS kernels in the given C-source module. 
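The contrib and test updates in this patch follow the Device accessor rename introduced in device.pxi. A minimal sketch, mirroring the updated test_device.py assertions, of the old and new spellings:

```python
# Device accessor rename (sketch based on the updated tests in this patch).
import tvm_ffi

dev = tvm_ffi.device("cuda", 0)

# Old spellings:
#   dev.device_type == tvm_ffi.Device.kDLCUDA
#   dev.device_id == 0
# New spellings:
assert dev.dlpack_device_type() == tvm_ffi.DLDeviceType.kDLCUDA
assert dev.index == 0
assert dev.type == "cuda"      # string name of the device type
assert str(dev) == "cuda:0"
```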
diff --git a/python/tvm/contrib/cutlass/gen_tensor_op.py b/python/tvm/contrib/cutlass/gen_tensor_op.py index e10abf113ea2..3a875ce220d0 100644 --- a/python/tvm/contrib/cutlass/gen_tensor_op.py +++ b/python/tvm/contrib/cutlass/gen_tensor_op.py @@ -461,7 +461,7 @@ def _get_optional_int_annotation(annotations, key, default=None): return int(value) -@tvm_ffi.register_func("contrib.cutlass.instantiate_template") +@tvm_ffi.register_global_func("contrib.cutlass.instantiate_template") def instantiate_template(func_name, annotations, func_args): """Return CUTLASS host code based on a template and the provided annotations. diff --git a/python/tvm/contrib/hexagon/tools.py b/python/tvm/contrib/hexagon/tools.py index d84c18aaf73e..f010461df082 100644 --- a/python/tvm/contrib/hexagon/tools.py +++ b/python/tvm/contrib/hexagon/tools.py @@ -29,7 +29,7 @@ import tvm import tvm.contrib.cc as cc -from tvm_ffi import register_func +from tvm_ffi import register_global_func # Linking Hexagon shared libraries. @@ -67,10 +67,10 @@ def register_linker(f): """Register a function that will return the path to the Hexagon linker.""" - return register_func("tvm.contrib.hexagon.hexagon_link", f, True) + return register_global_func("tvm.contrib.hexagon.hexagon_link", f, True) -@register_func("tvm.contrib.hexagon.hexagon_link") +@register_global_func("tvm.contrib.hexagon.hexagon_link") def hexagon_link() -> str: """Return path to the Hexagon linker.""" return str(HEXAGON_LINK_MAIN) @@ -112,7 +112,7 @@ def toolchain_version(toolchain=None) -> List[int]: raise RuntimeError("Cannot establish toolchain version") -@register_func("tvm.contrib.hexagon.link_shared") +@register_global_func("tvm.contrib.hexagon.link_shared") def link_shared(so_name, objs, extra_args=None): """Link shared library on Hexagon using the registered Hexagon linker. @@ -248,10 +248,10 @@ def __create_shared_mac(so_name, objs, **kwargs): return link_shared_macos(so_name, objs, kwargs) create_shared = __create_shared_mac - register_func("tvm.contrib.hexagon.link_shared", f=link_shared_macos, override=True) + register_global_func("tvm.contrib.hexagon.link_shared", f=link_shared_macos, override=True) else: # Linux and Win32 create_shared = cc.create_shared - register_func("tvm.contrib.hexagon.link_shared", f=link_shared, override=True) + register_global_func("tvm.contrib.hexagon.link_shared", f=link_shared, override=True) def create_aot_shared(so_name: Union[str, pathlib.Path], files, hexagon_arch: str, options=None): diff --git a/python/tvm/contrib/mrvl.py b/python/tvm/contrib/mrvl.py index 2c67bcdaf55b..996f6f881882 100644 --- a/python/tvm/contrib/mrvl.py +++ b/python/tvm/contrib/mrvl.py @@ -26,7 +26,7 @@ import tvm_ffi -@tvm_ffi.register_func("tvm.mrvl.find_value_in_KV_pair") +@tvm_ffi.register_global_func("tvm.mrvl.find_value_in_KV_pair") def find_value_in_KV_pair(json_input: str, key_to_find: str) -> str: """This function takes the graph_json string and key to be searched in the json string, using json parser routine it loads the json string @@ -53,7 +53,7 @@ def find_value_in_KV_pair(json_input: str, key_to_find: str) -> str: return value -@tvm_ffi.register_func("tvm.mrvl.GetNodesJSONString") +@tvm_ffi.register_global_func("tvm.mrvl.GetNodesJSONString") def get_nodes_json_string(graph_json): """This takes the graph_json string from MrvlJSONSerializer and adds / modifies the json string to a form suitable for the Marvell Backend. 
@@ -205,7 +205,7 @@ def get_nodes_json_string(graph_json): return nodes_json_string -@tvm_ffi.register_func("tvm.mrvl.ModifyConstNames") +@tvm_ffi.register_global_func("tvm.mrvl.ModifyConstNames") def modify_const_names(nodes_json_str, consts_json_str): """This takes the graph module returned by build an generates nodes and constant meta data suitable for compilation by the back end. @@ -328,7 +328,7 @@ def get_working_dir(): return os.getcwd() -@tvm_ffi.register_func("tvm.mrvl.WriteJsonFile") +@tvm_ffi.register_global_func("tvm.mrvl.WriteJsonFile") def write_json_file(json_string, json_filename): """Generate json file under working directory""" working_dir = get_working_dir() @@ -350,7 +350,7 @@ def delete_temp_files(symbol_name): shutil.rmtree(bin_folder) -@tvm_ffi.register_func("tvm.mrvl.CompileModel") +@tvm_ffi.register_global_func("tvm.mrvl.CompileModel") def compile_model( symbol_name, nodes_json_string, @@ -413,7 +413,7 @@ def compile_model( raise RuntimeError(error_msg) -@tvm_ffi.register_func("tvm.mrvl.CleanUpSim") +@tvm_ffi.register_global_func("tvm.mrvl.CleanUpSim") def clean_up_sim(bin_file, input_json, input_bin, out_bin_prefix, num_outputs): os.remove(bin_file) os.remove(input_json) @@ -423,7 +423,7 @@ def clean_up_sim(bin_file, input_json, input_bin, out_bin_prefix, num_outputs): os.remove(out_bin) -@tvm_ffi.register_func("tvm.mrvl.SearchPath") +@tvm_ffi.register_global_func("tvm.mrvl.SearchPath") def search_path(file_name): path = shutil.which(file_name) if path is None: @@ -431,7 +431,7 @@ def search_path(file_name): return os.path.dirname(path) -@tvm_ffi.register_func("tvm.mrvl.JsonToBin") +@tvm_ffi.register_global_func("tvm.mrvl.JsonToBin") def convert_json_to_bin(json_file, input_bin_file): with open(json_file) as input_json: data = json.load(input_json) @@ -441,7 +441,7 @@ def convert_json_to_bin(json_file, input_bin_file): f.write(data_b) -@tvm_ffi.register_func("tvm.mrvl.RunSim") +@tvm_ffi.register_global_func("tvm.mrvl.RunSim") def run_simulation(run_command, sim_directory): cwd_path = get_working_dir() os.mkdir(sim_directory) @@ -451,6 +451,6 @@ def run_simulation(run_command, sim_directory): shutil.rmtree(sim_directory) -@tvm_ffi.register_func("tvm.mrvl.TempDir") +@tvm_ffi.register_global_func("tvm.mrvl.TempDir") def get_temp_dir(): return tempfile.gettempdir() diff --git a/python/tvm/contrib/msc/core/_ffi_api.py b/python/tvm/contrib/msc/core/_ffi_api.py index a8f36146397d..ff027a0dec8e 100644 --- a/python/tvm/contrib/msc/core/_ffi_api.py +++ b/python/tvm/contrib/msc/core/_ffi_api.py @@ -18,4 +18,4 @@ import tvm_ffi -tvm_ffi._init_api("msc.core", __name__) +tvm_ffi.init_ffi_api("msc.core", __name__) diff --git a/python/tvm/contrib/msc/core/tools/execute.py b/python/tvm/contrib/msc/core/tools/execute.py index 2a47d755619e..dce9b1f1316f 100644 --- a/python/tvm/contrib/msc/core/tools/execute.py +++ b/python/tvm/contrib/msc/core/tools/execute.py @@ -214,7 +214,7 @@ def process_tensor(tensor: Any, name: str, consumer: str, scope: str, tag: str = return tensor -@tvm.register_func("msc_tool.codegen_tensor") +@tvm.register_global_func("msc_tool.codegen_tensor") def codegen_tensor( tensor_ctx: Dict[str, str], name: str, consumer: str, scope: str, tag: str = "main" ) -> List[str]: @@ -356,7 +356,7 @@ def _execute_step_with_context( return step_ctx -@tvm.register_func("msc_tool.codegen_step") +@tvm.register_global_func("msc_tool.codegen_step") def codegen_step( step_ctx: Dict[str, str], step: str, graph_name: str, tag: str = "main" ) -> List[str]: @@ -384,7 +384,7 @@ def 
codegen_step( return step_ctx["processed"] -@tvm.register_func("msc_tool.callback_step") +@tvm.register_global_func("msc_tool.callback_step") def callback_step(step_ctx: Dict[str, Any], step: str, graph_name: str = "main", tag: str = "main"): """Execute tools for a step diff --git a/python/tvm/contrib/msc/core/utils/info.py b/python/tvm/contrib/msc/core/utils/info.py index b4301beeb53e..65ed51f80f4c 100644 --- a/python/tvm/contrib/msc/core/utils/info.py +++ b/python/tvm/contrib/msc/core/utils/info.py @@ -47,9 +47,9 @@ def _analysis(self, data: Any) -> Tuple[str, str, np.ndarray]: if isinstance(data, np.ndarray): return MSCFramework.MSC, "tensor", "cpu" if isinstance(data, tvm.runtime.Tensor): - device = tvm.runtime.Device.DEVICE_TYPE_TO_NAME[data.device.device_type] - if data.device.device_id: - device += ":{}".format(data.device.device_id) + device = tvm.runtime.Device._DEVICE_TYPE_TO_NAME[data.device.dlpack_device_type()] + if data.device.index: + device += ":{}".format(data.device.index) return MSCFramework.TVM, "tensor", device if isinstance(data, tvm.relax.Var): return MSCFramework.TVM, "var", "cpu" diff --git a/python/tvm/contrib/msc/core/utils/register.py b/python/tvm/contrib/msc/core/utils/register.py index be82e1d0907a..4f7dcc3688ef 100644 --- a/python/tvm/contrib/msc/core/utils/register.py +++ b/python/tvm/contrib/msc/core/utils/register.py @@ -58,7 +58,7 @@ def reset(cls): cls.REGISTERY = {} -def register_func(name: str, func: callable, framework: str = MSCFramework.MSC): +def register_global_func(name: str, func: callable, framework: str = MSCFramework.MSC): """Register a func for framework. Parameters diff --git a/python/tvm/contrib/msc/framework/tensorflow/_ffi_api.py b/python/tvm/contrib/msc/framework/tensorflow/_ffi_api.py index fef10823decb..f7cd2ea43e3e 100644 --- a/python/tvm/contrib/msc/framework/tensorflow/_ffi_api.py +++ b/python/tvm/contrib/msc/framework/tensorflow/_ffi_api.py @@ -18,4 +18,4 @@ import tvm_ffi -tvm_ffi._init_api("msc.framework.tensorflow", __name__) +tvm_ffi.init_ffi_api("msc.framework.tensorflow", __name__) diff --git a/python/tvm/contrib/msc/framework/tensorflow/runtime/runner.py b/python/tvm/contrib/msc/framework/tensorflow/runtime/runner.py index eeee4635ab4e..49e231b7a524 100644 --- a/python/tvm/contrib/msc/framework/tensorflow/runtime/runner.py +++ b/python/tvm/contrib/msc/framework/tensorflow/runtime/runner.py @@ -195,7 +195,7 @@ def load_native(cls, model: Any, config: dict) -> Tuple[tf_v1.GraphDef, str, boo "Load native model {} with type {} is not supported".format(model, type(model)) ) device_protos = device_lib.list_local_devices() - if any(dev.device_type == "GPU" for dev in device_protos): + if any(dev.dlpack_device_type() == "GPU" for dev in device_protos): device = "cuda" else: device = "cpu" @@ -301,5 +301,5 @@ def support_device(cls, device: str) -> bool: return True if device.startswith("cuda"): device_protos = device_lib.list_local_devices() - return any(dev.device_type == "GPU" for dev in device_protos) + return any(dev.dlpack_device_type() == "GPU" for dev in device_protos) return False diff --git a/python/tvm/contrib/msc/framework/tensorrt/_ffi_api.py b/python/tvm/contrib/msc/framework/tensorrt/_ffi_api.py index 4dc13bd24bb1..a09ab875fbed 100644 --- a/python/tvm/contrib/msc/framework/tensorrt/_ffi_api.py +++ b/python/tvm/contrib/msc/framework/tensorrt/_ffi_api.py @@ -18,4 +18,4 @@ import tvm_ffi -tvm_ffi._init_api("msc.framework.tensorrt", __name__) +tvm_ffi.init_ffi_api("msc.framework.tensorrt", __name__) diff --git 
a/python/tvm/contrib/msc/framework/torch/_ffi_api.py b/python/tvm/contrib/msc/framework/torch/_ffi_api.py index 9ea5136048ce..d1f27a53bdcf 100644 --- a/python/tvm/contrib/msc/framework/torch/_ffi_api.py +++ b/python/tvm/contrib/msc/framework/torch/_ffi_api.py @@ -18,4 +18,4 @@ import tvm_ffi -tvm_ffi._init_api("msc.framework.torch", __name__) +tvm_ffi.init_ffi_api("msc.framework.torch", __name__) diff --git a/python/tvm/contrib/msc/framework/tvm/_ffi_api.py b/python/tvm/contrib/msc/framework/tvm/_ffi_api.py index dc75eed41883..c9f63e21eaef 100644 --- a/python/tvm/contrib/msc/framework/tvm/_ffi_api.py +++ b/python/tvm/contrib/msc/framework/tvm/_ffi_api.py @@ -18,4 +18,4 @@ import tvm_ffi -tvm_ffi._init_api("msc.framework.tvm", __name__) +tvm_ffi.init_ffi_api("msc.framework.tvm", __name__) diff --git a/python/tvm/contrib/msc/plugin/_ffi_api.py b/python/tvm/contrib/msc/plugin/_ffi_api.py index 8bb42c8c029f..88f9204f3a02 100644 --- a/python/tvm/contrib/msc/plugin/_ffi_api.py +++ b/python/tvm/contrib/msc/plugin/_ffi_api.py @@ -18,4 +18,4 @@ import tvm_ffi -tvm_ffi._init_api("msc.plugin", __name__) +tvm_ffi.init_ffi_api("msc.plugin", __name__) diff --git a/python/tvm/contrib/msc/plugin/op/_ffi_api.py b/python/tvm/contrib/msc/plugin/op/_ffi_api.py index 68704bb1785f..8ca5071cdaf6 100644 --- a/python/tvm/contrib/msc/plugin/op/_ffi_api.py +++ b/python/tvm/contrib/msc/plugin/op/_ffi_api.py @@ -18,4 +18,4 @@ import tvm_ffi -tvm_ffi._init_api("msc.plugin.op", __name__) +tvm_ffi.init_ffi_api("msc.plugin.op", __name__) diff --git a/python/tvm/contrib/ndk.py b/python/tvm/contrib/ndk.py index 743f911b48c8..f3a23e55db0c 100644 --- a/python/tvm/contrib/ndk.py +++ b/python/tvm/contrib/ndk.py @@ -25,7 +25,7 @@ import tempfile from pathlib import Path -from tvm_ffi import register_func +from tvm_ffi import register_global_func from ..base import py_str from . 
import utils as _utils, tar as _tar, cc as _cc from .cc import get_target_by_dump_machine @@ -157,7 +157,7 @@ def get_global_symbol_section_map(path, *, nm=None) -> Dict[str, str]: return _cc.get_global_symbol_section_map(path, nm=nm) -@register_func("meta_schedule.builder.export_ndk") +@register_global_func("meta_schedule.builder.export_ndk") def _ndk_export(mod): tmp_dir = tempfile.mkdtemp() binary_name = "tmp_binary.so" diff --git a/python/tvm/contrib/nnpack.py b/python/tvm/contrib/nnpack.py index 1f1077bf41c1..a0aba75b019b 100644 --- a/python/tvm/contrib/nnpack.py +++ b/python/tvm/contrib/nnpack.py @@ -232,4 +232,4 @@ def convolution_inference_weight_transform( ) -tvm_ffi._init_api("tvm.contrib.nnpack") +tvm_ffi.init_ffi_api("tvm.contrib.nnpack") diff --git a/python/tvm/contrib/nvcc.py b/python/tvm/contrib/nvcc.py index cbc88f0ab4f1..e20eb37daed4 100644 --- a/python/tvm/contrib/nvcc.py +++ b/python/tvm/contrib/nvcc.py @@ -312,14 +312,14 @@ def find_nvshmem_paths() -> Tuple[str, str]: raise RuntimeError("\n".join(error_message)) -@tvm_ffi.register_func +@tvm_ffi.register_global_func def tvm_callback_cuda_compile(code, target): # pylint: disable=unused-argument """use nvcc to generate fatbin code for better optimization""" ptx = compile_cuda(code, target_format="fatbin") return ptx -@tvm_ffi.register_func("tvm_callback_libdevice_path") +@tvm_ffi.register_global_func("tvm_callback_libdevice_path") def find_libdevice_path(arch): """Utility function to find libdevice @@ -384,7 +384,7 @@ def callback_libdevice_path(arch): return "" -@tvm_ffi.register_func("tvm.contrib.nvcc.get_compute_version") +@tvm_ffi.register_global_func("tvm.contrib.nvcc.get_compute_version") def get_target_compute_version(target=None): """Utility function to get compute capability of compilation target. 
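After the `_init_api` to `init_ffi_api` rename shown in the `_ffi_api.py` hunks above, every such stub reduces to the same two statements. A minimal sketch with a hypothetical `my_pkg` namespace:

```python
# my_pkg/_ffi_api.py (hypothetical package; the namespace string is illustrative)
"""FFI APIs for the my_pkg namespace."""
import tvm_ffi

# Attach every global function registered under the "my_pkg." prefix
# (for example from C++ via GlobalDef().def) as an attribute of this module.
tvm_ffi.init_ffi_api("my_pkg", __name__)
```

Callers then do `from . import _ffi_api` and invoke the attached functions (e.g. a hypothetical `_ffi_api.SomeFunc(...)`), as the msc and contrib modules in this patch do.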
@@ -529,7 +529,7 @@ def have_cudagraph(): return False -@tvm_ffi.register_func("tvm.contrib.nvcc.supports_bf16") +@tvm_ffi.register_global_func("tvm.contrib.nvcc.supports_bf16") def have_bf16(compute_version): """Either bf16 support is provided in the compute capability or not @@ -545,7 +545,7 @@ def have_bf16(compute_version): return False -@tvm_ffi.register_func("tvm.contrib.nvcc.supports_fp8") +@tvm_ffi.register_global_func("tvm.contrib.nvcc.supports_fp8") def have_fp8(compute_version): """Whether fp8 support is provided in the specified compute capability or not @@ -563,7 +563,7 @@ def have_fp8(compute_version): return False -@tvm_ffi.register_func("tvm.contrib.nvcc.supports_fp4") +@tvm_ffi.register_global_func("tvm.contrib.nvcc.supports_fp4") def have_fp4(compute_version): """Whether fp4 support is provided in the specified compute capability or not diff --git a/python/tvm/contrib/random.py b/python/tvm/contrib/random.py index 48263992515d..681978ff7132 100644 --- a/python/tvm/contrib/random.py +++ b/python/tvm/contrib/random.py @@ -112,4 +112,4 @@ def normal(loc, scale, size): ) -tvm_ffi._init_api("tvm.contrib.random") +tvm_ffi.init_ffi_api("tvm.contrib.random") diff --git a/python/tvm/contrib/rocm.py b/python/tvm/contrib/rocm.py index ee9f9e9b79a4..38e74b660c51 100644 --- a/python/tvm/contrib/rocm.py +++ b/python/tvm/contrib/rocm.py @@ -99,7 +99,7 @@ def rocm_link(in_file, out_file, lld=None): raise RuntimeError(msg) -@tvm_ffi.register_func("tvm_callback_rocm_link") +@tvm_ffi.register_global_func("tvm_callback_rocm_link") def callback_rocm_link(obj_bin): """Links object file generated from LLVM to HSA Code Object @@ -123,7 +123,7 @@ def callback_rocm_link(obj_bin): return cobj_bin -@tvm_ffi.register_func("tvm_callback_rocm_bitcode_path") +@tvm_ffi.register_global_func("tvm_callback_rocm_bitcode_path") def callback_rocm_bitcode_path(rocdl_dir=None): """Utility function to find ROCm device library bitcodes @@ -227,7 +227,7 @@ def have_matrixcore(compute_version=None): return False -@tvm_ffi.register_func("tvm_callback_rocm_get_arch") +@tvm_ffi.register_global_func("tvm_callback_rocm_get_arch") def get_rocm_arch(rocm_path=None): """Utility function to get the AMD GPU architecture diff --git a/python/tvm/contrib/tflite_runtime.py b/python/tvm/contrib/tflite_runtime.py index 076946214678..f3f5bf4c21fa 100644 --- a/python/tvm/contrib/tflite_runtime.py +++ b/python/tvm/contrib/tflite_runtime.py @@ -35,7 +35,7 @@ def create(tflite_model_bytes, device, runtime_target="cpu"): tflite_runtime : TFLiteModule Runtime tflite module that can be used to execute the tflite model. 
""" - device_type = device.device_type + device_type = device.dlpack_device_type() if runtime_target == "edge_tpu": runtime_func = "tvm.edgetpu_runtime.create" diff --git a/python/tvm/contrib/tvmjs.py b/python/tvm/contrib/tvmjs.py index a72eafd2bf75..a40c0cfbb07e 100644 --- a/python/tvm/contrib/tvmjs.py +++ b/python/tvm/contrib/tvmjs.py @@ -279,8 +279,8 @@ def dump_tensor_cache( # prefer to preserve original dtype, especially if the format was bfloat16 dtype = origin_v.dtype if isinstance(origin_v, tvm.runtime.Tensor) else v.dtype - if dtype in DataType.NUMPY_DTYPE_TO_STR: - dtype = DataType.NUMPY_DTYPE_TO_STR[dtype] + if dtype in DataType._NUMPY_DTYPE_TO_STR: + dtype = DataType._NUMPY_DTYPE_TO_STR[dtype] else: dtype = str(dtype) diff --git a/python/tvm/dlight/benchmark/bench.py b/python/tvm/dlight/benchmark/bench.py index ea9f4299b24f..b600e7efb783 100644 --- a/python/tvm/dlight/benchmark/bench.py +++ b/python/tvm/dlight/benchmark/bench.py @@ -143,7 +143,7 @@ def benchmark( _, profile_result = rpc_run( rt_mod, - device_type=dev.DEVICE_TYPE_TO_NAME[dev.device_type], + device_type=dev._DEVICE_TYPE_TO_NAME[dev.dlpack_device_type()], args=[w.numpy() if isinstance(w, tvm.runtime.Tensor) else w for w in input_tensors], rpc_config=rpc_config, evaluator_config=evaluator_config, diff --git a/python/tvm/driver/_ffi_api.py b/python/tvm/driver/_ffi_api.py index b3853345f0a3..e56426fd5182 100644 --- a/python/tvm/driver/_ffi_api.py +++ b/python/tvm/driver/_ffi_api.py @@ -17,4 +17,4 @@ """FFI APIs for tvm.driver""" import tvm_ffi -tvm_ffi._init_api("driver", __name__) +tvm_ffi.init_ffi_api("driver", __name__) diff --git a/python/tvm/exec/disco_worker.py b/python/tvm/exec/disco_worker.py index 9c47627548ab..5b20480decd4 100644 --- a/python/tvm/exec/disco_worker.py +++ b/python/tvm/exec/disco_worker.py @@ -22,44 +22,44 @@ from typing import Callable import tvm -from tvm_ffi import get_global_func, register_func +from tvm_ffi import get_global_func, register_global_func from tvm.runtime import Tensor, ShapeTuple, String from tvm.runtime.tensor import tensor -@register_func("tests.disco.add_one", override=True) +@register_global_func("tests.disco.add_one", override=True) def _add_one(x: int) -> int: return x + 1 -@register_func("tests.disco.add_one_float", override=True) +@register_global_func("tests.disco.add_one_float", override=True) def _add_one_float(x: float): return x + 0.5 -@register_func("tests.disco.add_one_tensor", override=True) +@register_global_func("tests.disco.add_one_tensor", override=True) def _add_one_tensor(x: Tensor) -> Tensor: return tensor(x.numpy() + 1) -@register_func("tests.disco.str", override=True) +@register_global_func("tests.disco.str", override=True) def _str_func(x: str): return x + "_suffix" -@register_func("tests.disco.str_obj", override=True) +@register_global_func("tests.disco.str_obj", override=True) def _str_obj_func(x: str): assert isinstance(x, str) return String(x + "_suffix") -@register_func("tests.disco.shape_tuple", override=True) +@register_global_func("tests.disco.shape_tuple", override=True) def _shape_tuple_func(x: ShapeTuple): assert isinstance(x, ShapeTuple) return ShapeTuple(list(x) + [4, 5]) -@register_func("tests.disco.test_callback", override=True) +@register_global_func("tests.disco.test_callback", override=True) def _make_callback(device: tvm.runtime.Device) -> Callable[[str, int], Tensor]: """For use in tests/python/disco/test_callback.py diff --git a/python/tvm/ir/_ffi_analysis_api.py b/python/tvm/ir/_ffi_analysis_api.py index 
6ba65fe2649e..9d7c12332c18 100644 --- a/python/tvm/ir/_ffi_analysis_api.py +++ b/python/tvm/ir/_ffi_analysis_api.py @@ -19,4 +19,4 @@ import tvm_ffi -tvm_ffi._init_api("ir.analysis", __name__) +tvm_ffi.init_ffi_api("ir.analysis", __name__) diff --git a/python/tvm/ir/_ffi_api.py b/python/tvm/ir/_ffi_api.py index 6165d5ea0b18..798e69fca507 100644 --- a/python/tvm/ir/_ffi_api.py +++ b/python/tvm/ir/_ffi_api.py @@ -18,4 +18,4 @@ import tvm_ffi -tvm_ffi._init_api("ir", __name__) +tvm_ffi.init_ffi_api("ir", __name__) diff --git a/python/tvm/ir/_ffi_instrument_api.py b/python/tvm/ir/_ffi_instrument_api.py index af0a0ea3ebd5..18aea5cf8a2f 100644 --- a/python/tvm/ir/_ffi_instrument_api.py +++ b/python/tvm/ir/_ffi_instrument_api.py @@ -17,4 +17,4 @@ """FFI APIs for tvm.instrument""" import tvm_ffi -tvm_ffi._init_api("instrument", __name__) +tvm_ffi.init_ffi_api("instrument", __name__) diff --git a/python/tvm/ir/_ffi_transform_api.py b/python/tvm/ir/_ffi_transform_api.py index eda8d5354b23..8a2f517e2145 100644 --- a/python/tvm/ir/_ffi_transform_api.py +++ b/python/tvm/ir/_ffi_transform_api.py @@ -18,4 +18,4 @@ import tvm_ffi -tvm_ffi._init_api("transform", __name__) +tvm_ffi.init_ffi_api("transform", __name__) diff --git a/python/tvm/ir/diagnostics/__init__.py b/python/tvm/ir/diagnostics/__init__.py index 3a131b2a14c0..4a521dfa587e 100644 --- a/python/tvm/ir/diagnostics/__init__.py +++ b/python/tvm/ir/diagnostics/__init__.py @@ -24,7 +24,7 @@ import enum import tvm_ffi from . import _ffi_api -from ... import get_global_func, register_func, Object +from ... import get_global_func, register_global_func, Object def get_renderer(): @@ -38,7 +38,7 @@ def get_renderer(): return _ffi_api.GetRenderer() -@tvm_ffi.register_func("diagnostics.override_renderer") +@tvm_ffi.register_global_func("diagnostics.override_renderer") def override_renderer(render_func): """ Sets a custom renderer for diagnostics. @@ -54,7 +54,7 @@ def override_renderer(render_func): def _render_factory(): return DiagnosticRenderer(render_func) - register_func("diagnostics.OverrideRenderer", _render_factory, override=True) + register_global_func("diagnostics.OverrideRenderer", _render_factory, override=True) else: _ffi_api.ClearRenderer() diff --git a/python/tvm/ir/diagnostics/_ffi_api.py b/python/tvm/ir/diagnostics/_ffi_api.py index 0232cac91462..65fb2cc896f3 100644 --- a/python/tvm/ir/diagnostics/_ffi_api.py +++ b/python/tvm/ir/diagnostics/_ffi_api.py @@ -18,4 +18,4 @@ import tvm_ffi -tvm_ffi._init_api("diagnostics", __name__) +tvm_ffi.init_ffi_api("diagnostics", __name__) diff --git a/python/tvm/meta_schedule/_ffi_api.py b/python/tvm/meta_schedule/_ffi_api.py index bb07a225735c..1a06aef5a482 100644 --- a/python/tvm/meta_schedule/_ffi_api.py +++ b/python/tvm/meta_schedule/_ffi_api.py @@ -15,6 +15,6 @@ # specific language governing permissions and limitations # under the License. 
"""FFI APIs for tvm.meta_schedule""" -from tvm_ffi import _init_api +import tvm_ffi -_init_api("meta_schedule", __name__) # pylint: disable=protected-access +tvm_ffi.init_ffi_api("meta_schedule", __name__) # pylint: disable=protected-access diff --git a/python/tvm/meta_schedule/builder/local_builder.py b/python/tvm/meta_schedule/builder/local_builder.py index cda8d21838cb..6bd8f10ed810 100644 --- a/python/tvm/meta_schedule/builder/local_builder.py +++ b/python/tvm/meta_schedule/builder/local_builder.py @@ -19,7 +19,7 @@ import tempfile from typing import Callable, Dict, List, Optional, Union -from tvm_ffi import register_func +from tvm_ffi import register_global_func from tvm.ir import IRModule from tvm.runtime import Module, Tensor, load_param_dict, save_param_dict from tvm.target import Target @@ -234,7 +234,7 @@ def _worker_func( return artifact_path -@register_func("meta_schedule.builder.default_build") +@register_global_func("meta_schedule.builder.default_build") def default_build(mod: IRModule, target: Target, _params: Optional[Dict[str, Tensor]]) -> Module: """Default build function. @@ -261,7 +261,7 @@ def default_build(mod: IRModule, target: Target, _params: Optional[Dict[str, Ten return tvm_build(mod, target=target) -@register_func("meta_schedule.builder.default_export") +@register_global_func("meta_schedule.builder.default_export") def default_export(mod: Module) -> str: """Default export function. @@ -282,7 +282,7 @@ def default_export(mod: Module) -> str: return artifact_path -@register_func("meta_schedule.builder.get_local_builder") +@register_global_func("meta_schedule.builder.get_local_builder") def get_local_builder() -> LocalBuilder: """Get the local builder. diff --git a/python/tvm/meta_schedule/relax_integration.py b/python/tvm/meta_schedule/relax_integration.py index 92e0e24a4cc3..dc78d2400a74 100644 --- a/python/tvm/meta_schedule/relax_integration.py +++ b/python/tvm/meta_schedule/relax_integration.py @@ -23,7 +23,7 @@ # isort: on -from tvm_ffi import get_global_func, register_func +from tvm_ffi import get_global_func, register_global_func from tvm.ir import IRModule from tvm.ir.transform import PassContext from tvm.runtime import Tensor @@ -269,7 +269,7 @@ def tune_relax( ) -@register_func("tvm.meta_schedule.tune_relax") +@register_global_func("tvm.meta_schedule.tune_relax") def _tune_relax( mod: Union[IRModule, "relax.Function"], params: Dict[str, Tensor], diff --git a/python/tvm/meta_schedule/runner/local_runner.py b/python/tvm/meta_schedule/runner/local_runner.py index 7ff1065a191f..b35e47c94dda 100644 --- a/python/tvm/meta_schedule/runner/local_runner.py +++ b/python/tvm/meta_schedule/runner/local_runner.py @@ -148,7 +148,7 @@ def resource_handler(): rt_mod = tvm.runtime.load_module(artifact_path) # Step 2: Allocate input arguments with Profiler.timeit("LocalRunner/alloc_argument"): - device = tvm.runtime.device(dev_type=device_type, dev_id=0) + device = tvm.runtime.device(device_type, 0) repeated_args: List[T_ARGUMENT_LIST] = f_alloc_argument( device, args_info, @@ -392,7 +392,7 @@ def default_cleanup() -> None: pass # pylint: disable=unnecessary-pass -@tvm.register_func("meta_schedule.runner.get_local_runner") +@tvm.register_global_func("meta_schedule.runner.get_local_runner") def get_local_builder() -> LocalRunner: """Get the local Runner. 
diff --git a/python/tvm/meta_schedule/runner/rpc_runner.py b/python/tvm/meta_schedule/runner/rpc_runner.py index b249be7ded74..9d61a7b0b4d6 100644 --- a/python/tvm/meta_schedule/runner/rpc_runner.py +++ b/python/tvm/meta_schedule/runner/rpc_runner.py @@ -384,7 +384,7 @@ def resource_handler(): # Step 1. Create session with Profiler.timeit("RPCRunner/create_session"): session = f_create_session(rpc_config) - device = session.device(dev_type=device_type, dev_id=0) + device = session.device(device_type, 0) # Step 2. Upload the module with Profiler.timeit("RPCRunner/upload_module"): _, remote_path = osp.split(artifact_path) diff --git a/python/tvm/meta_schedule/schedule/cuda/layout_transform.py b/python/tvm/meta_schedule/schedule/cuda/layout_transform.py index 949ef915c9ff..58540839397d 100644 --- a/python/tvm/meta_schedule/schedule/cuda/layout_transform.py +++ b/python/tvm/meta_schedule/schedule/cuda/layout_transform.py @@ -501,7 +501,7 @@ def get_max_tile_size() -> int: return max_tile_size -@tvm.register_func("meta_schedule.cuda.layout_transform") +@tvm.register_global_func("meta_schedule.cuda.layout_transform") def cuda_layout_transform_schedule_rule( sch: tvm.tir.Schedule, block: BlockRV, testing_tile_sizes: Optional[List[int]] = None ) -> List[tvm.tir.Schedule]: diff --git a/python/tvm/meta_schedule/testing/custom_builder_runner.py b/python/tvm/meta_schedule/testing/custom_builder_runner.py index 2da672b40561..490929402dc7 100644 --- a/python/tvm/meta_schedule/testing/custom_builder_runner.py +++ b/python/tvm/meta_schedule/testing/custom_builder_runner.py @@ -48,6 +48,6 @@ def run_module_via_rpc( session.upload(filename) _, filename = os.path.split(filename) rt_mod = session.load_module(filename) - dev = session.device(dev_type=dev_type, dev_id=0) + dev = session.device(dev_type, 0) nd_args = {k: ndarray.array(v, dev) for k, v in args.items()} return continuation(rt_mod, dev, nd_args) diff --git a/python/tvm/meta_schedule/testing/validate_database.py b/python/tvm/meta_schedule/testing/validate_database.py index e356e6c75358..8b5a87f61932 100644 --- a/python/tvm/meta_schedule/testing/validate_database.py +++ b/python/tvm/meta_schedule/testing/validate_database.py @@ -22,7 +22,7 @@ from statistics import mean from typing import Callable, Tuple, Union, List, Any import numpy as np # type: ignore -from tvm_ffi import get_global_func, register_func +from tvm_ffi import get_global_func, register_global_func import tvm @@ -203,7 +203,7 @@ def __hash__(self) -> int: def initializer() -> None: """Initializer function to register the functions on PopenWorker.""" - @register_func("tvm.meta_schedule.testing.default_check_metric") + @register_global_func("tvm.meta_schedule.testing.default_check_metric") def default_check_metric( # pylint: disable=unused-variable,unreachable-code lhs: List[tvm.runtime.Tensor], rhs: List[tvm.runtime.Tensor] ) -> bool: @@ -229,7 +229,7 @@ def default_check_metric( # pylint: disable=unused-variable,unreachable-code return True -@register_func("tvm.meta_schedule.testing.default_input_generator") +@register_global_func("tvm.meta_schedule.testing.default_input_generator") def default_input_generator( # pylint: disable=unused-variable mod: IRModule, ) -> List[tvm.runtime.Tensor]: diff --git a/python/tvm/meta_schedule/tir_integration.py b/python/tvm/meta_schedule/tir_integration.py index 7a9ccb404016..69a71ba3d6d9 100644 --- a/python/tvm/meta_schedule/tir_integration.py +++ b/python/tvm/meta_schedule/tir_integration.py @@ -19,7 +19,7 @@ # isort: off from typing_extensions 
import Literal -from tvm_ffi import register_func +from tvm_ffi import register_global_func # isort: on from tvm import ir, tir @@ -161,7 +161,7 @@ def tune_tir( # pylint: disable=too-many-locals ) -@register_func("tvm.meta_schedule.tune_tir") +@register_global_func("tvm.meta_schedule.tune_tir") def _tune_tir( mod: Union[ir.IRModule, tir.PrimFunc], target: Union[str, Target], diff --git a/python/tvm/meta_schedule/tune_context.py b/python/tvm/meta_schedule/tune_context.py index 08faf86dc5c8..34527f409ec0 100644 --- a/python/tvm/meta_schedule/tune_context.py +++ b/python/tvm/meta_schedule/tune_context.py @@ -21,7 +21,7 @@ # isort: off from typing_extensions import Literal -from tvm_ffi import register_object, register_func +from tvm_ffi import register_object, register_global_func # isort: on @@ -42,7 +42,7 @@ from .space_generator import SpaceGenerator -@register_func("tvm.meta_schedule.normalize_mod") +@register_global_func("tvm.meta_schedule.normalize_mod") def _normalize_mod(mod: Union[PrimFunc, IRModule]) -> IRModule: """Normalize the input to an IRModule""" if isinstance(mod, PrimFunc): diff --git a/python/tvm/meta_schedule/utils.py b/python/tvm/meta_schedule/utils.py index 76bac88983f0..385ddc30f9ab 100644 --- a/python/tvm/meta_schedule/utils.py +++ b/python/tvm/meta_schedule/utils.py @@ -22,7 +22,7 @@ import numpy as np # type: ignore import psutil # type: ignore -from tvm_ffi import get_global_func, register_func +from tvm_ffi import get_global_func, register_global_func from tvm.error import TVMError from tvm.ir import Array, IRModule, Map from tvm.rpc import RPCSession @@ -163,7 +163,7 @@ def __setattr__(self, name, value): return TVMDerivedObject -@register_func("meta_schedule.cpu_count") +@register_global_func("meta_schedule.cpu_count") def _cpu_count_impl(logical: bool = True) -> int: """Return the number of logical or physical CPUs in the system @@ -219,7 +219,7 @@ def cpu_count(logical: bool = True) -> int: return _cpu_count_impl(logical) -@register_func("meta_schedule.using_ipython") +@register_global_func("meta_schedule.using_ipython") def _using_ipython() -> bool: """Return whether the current process is running in an IPython shell. @@ -234,7 +234,7 @@ def _using_ipython() -> bool: return False -@register_func("meta_schedule.print_interactive_table") +@register_global_func("meta_schedule.print_interactive_table") def print_interactive_table(data: str) -> None: """Print the dataframe interactive table in notebook. 
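The meta_schedule hooks above are registered from Python and resolved by name at runtime. A minimal sketch of that round trip under the renamed API; the `example.cpu_count` name is illustrative:

```python
import os

from tvm_ffi import get_global_func, register_global_func


@register_global_func("example.cpu_count")
def _cpu_count(logical: bool = True) -> int:
    # Same shape as the meta_schedule.cpu_count hook registered above.
    return os.cpu_count() or 1


# Resolve the hook by name, the way the C++ side (or another Python module) does.
f = get_global_func("example.cpu_count")
assert f(True) >= 1
```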
@@ -327,7 +327,7 @@ def get_global_func_on_rpc_session( return result -@register_func("meta_schedule.remove_build_dir") +@register_global_func("meta_schedule.remove_build_dir") def remove_build_dir(artifact_path: str) -> None: """Clean up the build directory""" shutil.rmtree(os.path.dirname(artifact_path)) diff --git a/python/tvm/relax/_ffi_api.py b/python/tvm/relax/_ffi_api.py index 947ddb089a3d..c5e98a22eaaf 100644 --- a/python/tvm/relax/_ffi_api.py +++ b/python/tvm/relax/_ffi_api.py @@ -17,4 +17,4 @@ """FFI API for Relax.""" import tvm_ffi -tvm_ffi._init_api("relax", __name__) +tvm_ffi.init_ffi_api("relax", __name__) diff --git a/python/tvm/relax/analysis/_ffi_api.py b/python/tvm/relax/analysis/_ffi_api.py index d6adf9580583..0a230fbd8bb6 100644 --- a/python/tvm/relax/analysis/_ffi_api.py +++ b/python/tvm/relax/analysis/_ffi_api.py @@ -16,4 +16,4 @@ """FFI APIs""" import tvm_ffi -tvm_ffi._init_api("relax.analysis", __name__) +tvm_ffi.init_ffi_api("relax.analysis", __name__) diff --git a/python/tvm/relax/backend/_ffi_api.py b/python/tvm/relax/backend/_ffi_api.py index fbab39429403..97a999788b93 100644 --- a/python/tvm/relax/backend/_ffi_api.py +++ b/python/tvm/relax/backend/_ffi_api.py @@ -18,4 +18,4 @@ import tvm_ffi -tvm_ffi._init_api("relax.backend", __name__) +tvm_ffi.init_ffi_api("relax.backend", __name__) diff --git a/python/tvm/relax/backend/metal/coreml.py b/python/tvm/relax/backend/metal/coreml.py index 56b0eb3a6ce9..dfc891dc1f31 100644 --- a/python/tvm/relax/backend/metal/coreml.py +++ b/python/tvm/relax/backend/metal/coreml.py @@ -463,7 +463,7 @@ def compile(self, out_dir): compile_coreml(model, self.model_name, out_dir) -@tvm_ffi.register_func("relax.ext.coreml") +@tvm_ffi.register_global_func("relax.ext.coreml") def coreml_compiler(funcs, options, constant_names): """ Create a CoreML runtime from a Relax module. 
diff --git a/python/tvm/relax/base_py_module.py b/python/tvm/relax/base_py_module.py index 688dc962f23f..796ab41a1470 100644 --- a/python/tvm/relax/base_py_module.py +++ b/python/tvm/relax/base_py_module.py @@ -333,8 +333,7 @@ def _convert_single_tvm_to_pytorch(self, tvm_array: Any) -> "torch.Tensor": if not isinstance(tvm_array, Tensor): return torch.tensor(tvm_array) try: - dlpack = tvm_array.to_dlpack() - return torch.from_dlpack(dlpack) + return torch.from_dlpack(tvm_array) # pylint: disable=broad-exception-caught except Exception as error: print(f"Warning: DLPack conversion from TVM failed ({error}), using numpy fallback") diff --git a/python/tvm/relax/distributed/_ffi_api.py b/python/tvm/relax/distributed/_ffi_api.py index 89a15a2bc33a..71185a1276da 100644 --- a/python/tvm/relax/distributed/_ffi_api.py +++ b/python/tvm/relax/distributed/_ffi_api.py @@ -17,4 +17,4 @@ """FFI APIs for tvm.relax.distributed""" import tvm_ffi -tvm_ffi._init_api("relax.distributed", __name__) +tvm_ffi.init_ffi_api("relax.distributed", __name__) diff --git a/python/tvm/relax/distributed/transform/_ffi_api.py b/python/tvm/relax/distributed/transform/_ffi_api.py index ffdb09715f68..35808cc2bc93 100644 --- a/python/tvm/relax/distributed/transform/_ffi_api.py +++ b/python/tvm/relax/distributed/transform/_ffi_api.py @@ -16,4 +16,4 @@ """FFI APIs for tvm.relax.distributed.transform""" import tvm_ffi -tvm_ffi._init_api("relax.distributed.transform", __name__) +tvm_ffi.init_ffi_api("relax.distributed.transform", __name__) diff --git a/python/tvm/relax/dpl/_ffi.py b/python/tvm/relax/dpl/_ffi.py index 7097ec8c5282..b03e5800e8fc 100644 --- a/python/tvm/relax/dpl/_ffi.py +++ b/python/tvm/relax/dpl/_ffi.py @@ -17,4 +17,4 @@ """DataFlow Pattern Language FFI bindings.""" import tvm_ffi -tvm_ffi._init_api("relax.dpl", __name__) +tvm_ffi.init_ffi_api("relax.dpl", __name__) diff --git a/python/tvm/relax/expr.py b/python/tvm/relax/expr.py index 2b78996f2974..1a7a5c224add 100644 --- a/python/tvm/relax/expr.py +++ b/python/tvm/relax/expr.py @@ -307,7 +307,7 @@ def elem_offset(self) -> "Expr": return tvm.relax.Call(op, [self]) -class _DLTensorDTypeProxy(tvm.runtime.ObjectGeneric): +class _DLTensorDTypeProxy(tvm.runtime.ObjectConvertible): """A proxy object for unpacking DLDatatype from DLTensor Exposes accessors for `DLDataType` fields `type_code`, `lanes`, @@ -387,7 +387,7 @@ def bits(self) -> Expr: return tvm.relax.Call(op, [self.tensor]) -class _DLTensorShapeProxy(tvm.runtime.ObjectGeneric): +class _DLTensorShapeProxy(tvm.runtime.ObjectConvertible): """A proxy object for unpacking the shape from DLTensor Exposes accessors for the `DLTensor::shape` field. Accessing @@ -457,7 +457,7 @@ def __getitem__(self, axis: Union[int, PrimExpr, Expr]) -> Expr: return tvm.relax.Call(op, [self.tensor, axis]) -class _DLTensorStrideProxy(tvm.runtime.ObjectGeneric): +class _DLTensorStrideProxy(tvm.runtime.ObjectConvertible): """A proxy object for unpacking the strides from DLTensor Exposes accessors for the `DLTensor::strides` field. 
Accessing diff --git a/python/tvm/relax/frontend/nn/core.py b/python/tvm/relax/frontend/nn/core.py index b2904fe2a9be..8529dda00686 100644 --- a/python/tvm/relax/frontend/nn/core.py +++ b/python/tvm/relax/frontend/nn/core.py @@ -639,7 +639,7 @@ def _from_dlpack(tensor) -> tvm.runtime.Tensor: return tvm.runtime.tensor( tensor.numpy(), device=Device( - Device.DEVICE_NAME_TO_TYPE[device_type], + Device._DEVICE_NAME_TO_TYPE[device_type], device_id, ), ) diff --git a/python/tvm/relax/frontend/nn/op.py b/python/tvm/relax/frontend/nn/op.py index 1e42c862fee6..714ae9478250 100644 --- a/python/tvm/relax/frontend/nn/op.py +++ b/python/tvm/relax/frontend/nn/op.py @@ -2087,7 +2087,7 @@ def extern( out: OutType, ) -> OutType: """Invoke an extern function during runtime. The extern function must be registered with the " - TVM runtime using `reflection::GlobalDef().def` (C++), or `tvm.register_func` (Python). + TVM runtime using `reflection::GlobalDef().def` (C++), or `tvm.register_global_func` (Python). Parameters ---------- @@ -2144,7 +2144,7 @@ def debug_func( .. code-block:: python - @tvm.register_func(name_of_debug_func) + @tvm.register_global_func(name_of_debug_func) def debug_func(lineno: str, arg_0, arg_1, ...) -> None: ... diff --git a/python/tvm/relax/op/_ffi_api.py b/python/tvm/relax/op/_ffi_api.py index 693c9564d59c..867c43e4d85b 100644 --- a/python/tvm/relax/op/_ffi_api.py +++ b/python/tvm/relax/op/_ffi_api.py @@ -16,4 +16,4 @@ """FFI APIs for tvm.relax.op""" import tvm_ffi -tvm_ffi._init_api("relax.op", __name__) +tvm_ffi.init_ffi_api("relax.op", __name__) diff --git a/python/tvm/relax/op/base.py b/python/tvm/relax/op/base.py index 4663e47020e0..e77920d8dea6 100644 --- a/python/tvm/relax/op/base.py +++ b/python/tvm/relax/op/base.py @@ -22,7 +22,7 @@ import tvm import tvm.runtime from tvm.runtime.object import Object -from tvm.runtime import ObjectGeneric +from tvm.runtime import ObjectConvertible from . import _ffi_api from ..expr import Expr, StringImm, ShapeExpr, Call, ExternFunc, GlobalVar, Var @@ -422,7 +422,7 @@ def render_object(val: tvm.Object) -> str: return str(val) -@tvm.register_func("relax.run.shape_to_tensor") +@tvm.register_global_func("relax.run.shape_to_tensor") def relax_shape_to_tensor(shape_tuple: tvm.runtime.ShapeTuple) -> tvm.runtime.Tensor: """ Takes a ShapeTuple and convert it to Tensor. @@ -435,7 +435,7 @@ def relax_shape_to_tensor(shape_tuple: tvm.runtime.ShapeTuple) -> tvm.runtime.Te return tvm.runtime.tensor([int(v) for v in shape_tuple]) -@tvm.register_func("relax.run.print") +@tvm.register_global_func("relax.run.print") def relax_print(format_str: str, *format_args: tvm.Object) -> None: """ Takes a list of values to print, formats with the given format string. @@ -483,7 +483,7 @@ def print(*values: List[Expr], format: Union[str, Expr] = "") -> Expr: return _ffi_api.print(values, format) # type: ignore # pylint: disable=no-member -@tvm.register_func("relax.run.assert_op") +@tvm.register_global_func("relax.run.assert_op") def relax_assert_op(condition: tvm.Object, format_str: str, *format_args: tvm.Object) -> None: """ A variadic function. 
The first value serves as the assertion condition: @@ -744,7 +744,7 @@ def call_pure_packed( sinfo() if callable(sinfo) else sinfo.asobject() - if isinstance(sinfo, ObjectGeneric) + if isinstance(sinfo, ObjectConvertible) else sinfo for sinfo in sinfo_args ] diff --git a/python/tvm/relax/op/builtin/_ffi_api.py b/python/tvm/relax/op/builtin/_ffi_api.py index 4ad011b447b1..0e5955f6e47d 100644 --- a/python/tvm/relax/op/builtin/_ffi_api.py +++ b/python/tvm/relax/op/builtin/_ffi_api.py @@ -16,4 +16,4 @@ """FFI APIs for tvm.relax.op.builtin""" import tvm_ffi -tvm_ffi._init_api("relax.op.builtin", __name__) +tvm_ffi.init_ffi_api("relax.op.builtin", __name__) diff --git a/python/tvm/relax/op/ccl/_ffi_api.py b/python/tvm/relax/op/ccl/_ffi_api.py index eab31a6463c5..f0796d3da318 100644 --- a/python/tvm/relax/op/ccl/_ffi_api.py +++ b/python/tvm/relax/op/ccl/_ffi_api.py @@ -17,4 +17,4 @@ """Operators serving for Collective Communications Library (CCL) operators""" import tvm_ffi -tvm_ffi._init_api("relax.op.ccl", __name__) +tvm_ffi.init_ffi_api("relax.op.ccl", __name__) diff --git a/python/tvm/relax/op/distributed/_ffi_api.py b/python/tvm/relax/op/distributed/_ffi_api.py index 03c4bcc988b3..fa1c163794b9 100644 --- a/python/tvm/relax/op/distributed/_ffi_api.py +++ b/python/tvm/relax/op/distributed/_ffi_api.py @@ -17,4 +17,4 @@ """FFI APIs for tvm.relax.op.distributed""" import tvm_ffi -tvm_ffi._init_api("relax.op.dist", __name__) +tvm_ffi.init_ffi_api("relax.op.dist", __name__) diff --git a/python/tvm/relax/op/grad/_ffi_api.py b/python/tvm/relax/op/grad/_ffi_api.py index d1f96a1d0299..1a8ebb09aa8d 100644 --- a/python/tvm/relax/op/grad/_ffi_api.py +++ b/python/tvm/relax/op/grad/_ffi_api.py @@ -17,4 +17,4 @@ """FFI APIs for tvm.relax.op.grad""" import tvm_ffi -tvm_ffi._init_api("relax.op.grad", __name__) +tvm_ffi.init_ffi_api("relax.op.grad", __name__) diff --git a/python/tvm/relax/op/image/_ffi_api.py b/python/tvm/relax/op/image/_ffi_api.py index b00b26744b7b..8147a155cb76 100644 --- a/python/tvm/relax/op/image/_ffi_api.py +++ b/python/tvm/relax/op/image/_ffi_api.py @@ -17,4 +17,4 @@ """Constructor APIs""" import tvm_ffi -tvm_ffi._init_api("relax.op.image", __name__) +tvm_ffi.init_ffi_api("relax.op.image", __name__) diff --git a/python/tvm/relax/op/memory/_ffi_api.py b/python/tvm/relax/op/memory/_ffi_api.py index f876c2c1e639..05dbf534c7f5 100644 --- a/python/tvm/relax/op/memory/_ffi_api.py +++ b/python/tvm/relax/op/memory/_ffi_api.py @@ -16,4 +16,4 @@ """FFI APIs for tvm.relax.op.memory""" import tvm_ffi -tvm_ffi._init_api("relax.op.memory", __name__) +tvm_ffi.init_ffi_api("relax.op.memory", __name__) diff --git a/python/tvm/relax/op/nn/_ffi_api.py b/python/tvm/relax/op/nn/_ffi_api.py index fa8bf8f6d8cb..d58fa186fc7c 100644 --- a/python/tvm/relax/op/nn/_ffi_api.py +++ b/python/tvm/relax/op/nn/_ffi_api.py @@ -17,4 +17,4 @@ """Constructor APIs""" import tvm_ffi -tvm_ffi._init_api("relax.op.nn", __name__) +tvm_ffi.init_ffi_api("relax.op.nn", __name__) diff --git a/python/tvm/relax/op/set.py b/python/tvm/relax/op/set.py index 4d0fd3dd420f..87fd067e5d1e 100644 --- a/python/tvm/relax/op/set.py +++ b/python/tvm/relax/op/set.py @@ -84,7 +84,7 @@ def unique( ) -@tvm.register_func("relax.run.unique") +@tvm.register_global_func("relax.run.unique") def numpy_unique( x: tvm.runtime.tensor, sorted: int, @@ -143,7 +143,7 @@ def nonzero(x: Expr) -> Expr: return _ffi_api.nonzero(x) # type: ignore -@tvm.register_func("relax.run.nonzero") +@tvm.register_global_func("relax.run.nonzero") def numpy_nonzero(x: 
tvm.runtime.tensor) -> tvm.runtime.tensor: np_result = np.atleast_1d(x.numpy()).nonzero() return tvm.runtime.tensor(np.stack(np_result, axis=0)) diff --git a/python/tvm/relax/op/vm/_ffi_api.py b/python/tvm/relax/op/vm/_ffi_api.py index bd543ad1c9bd..eed64e53f036 100644 --- a/python/tvm/relax/op/vm/_ffi_api.py +++ b/python/tvm/relax/op/vm/_ffi_api.py @@ -16,4 +16,4 @@ """FFI APIs for tvm.relax.op.vm""" import tvm_ffi -tvm_ffi._init_api("relax.op.vm", __name__) +tvm_ffi.init_ffi_api("relax.op.vm", __name__) diff --git a/python/tvm/relax/testing/vm.py b/python/tvm/relax/testing/vm.py index 737de13fc7f6..5516bac17cf7 100644 --- a/python/tvm/relax/testing/vm.py +++ b/python/tvm/relax/testing/vm.py @@ -24,53 +24,53 @@ from tvm.runtime.object import Object -@tvm.register_func("test.vm.move") +@tvm.register_global_func("test.vm.move") def move(src): return src -@tvm.register_func("test.vm.add") +@tvm.register_global_func("test.vm.add") def add(a, b): ret = a.numpy() + b.numpy() return tvm.runtime.tensor(ret) -@tvm.register_func("test.vm.mul") +@tvm.register_global_func("test.vm.mul") def mul(a, b): ret = a.numpy() * b.numpy() return tvm.runtime.tensor(ret) -@tvm.register_func("test.vm.equal_zero") +@tvm.register_global_func("test.vm.equal_zero") def equal_zero(a): ret = np.all((a.numpy() == 0)) return tvm.runtime.tensor(ret) -@tvm.register_func("test.vm.subtract_one") +@tvm.register_global_func("test.vm.subtract_one") def subtract_one(a): ret = np.subtract(a.numpy(), 1) return tvm.runtime.tensor(ret) -@tvm.register_func("test.vm.identity") +@tvm.register_global_func("test.vm.identity") def identity_packed(a, b): b[:] = tvm.runtime.tensor(a.numpy()) -@tvm.register_func("test.vm.tile") +@tvm.register_global_func("test.vm.tile") def tile_packed(a, b): b[:] = tvm.runtime.tensor(np.tile(a.numpy(), (1, 2))) -@tvm.register_func("test.vm.add_scalar") +@tvm.register_global_func("test.vm.add_scalar") def add_scalar(a, b): return a + b -@tvm.register_func("test.vm.get_device_id") +@tvm.register_global_func("test.vm.get_device_id") def get_device_id(device): - return device.device_id + return device.index def check_saved_func(vm: relax.VirtualMachine, func_name: str, *inputs: List[Any]) -> Object: @@ -85,6 +85,6 @@ def check_saved_func(vm: relax.VirtualMachine, func_name: str, *inputs: List[Any return res1 -@tvm.register_func("test.vm.check_if_defined") +@tvm.register_global_func("test.vm.check_if_defined") def check_if_defined(obj: tvm.Object) -> tvm.tir.IntImm: return tvm.runtime.convert(obj is not None) diff --git a/python/tvm/relax/training/_ffi_api.py b/python/tvm/relax/training/_ffi_api.py index 84c117f9cbb3..25f395830341 100644 --- a/python/tvm/relax/training/_ffi_api.py +++ b/python/tvm/relax/training/_ffi_api.py @@ -17,4 +17,4 @@ """FFI APIs for tvm.relax.training""" import tvm_ffi -tvm_ffi._init_api("relax.training", __name__) +tvm_ffi.init_ffi_api("relax.training", __name__) diff --git a/python/tvm/relax/training/utils.py b/python/tvm/relax/training/utils.py index b2300cf4706d..a3fc836fc0a4 100644 --- a/python/tvm/relax/training/utils.py +++ b/python/tvm/relax/training/utils.py @@ -18,7 +18,7 @@ """Utility functions for relax training.""" from typing import Optional, Callable -from tvm_ffi import register_func +from tvm_ffi import register_global_func import tvm from tvm import relax @@ -199,7 +199,7 @@ def handler( primfunc_name_hint=te_grad_name, ) - register_func(func_prefix + te_grad_name, handler) + register_global_func(func_prefix + te_grad_name, handler) return func return 
register(te_grad_func) if te_grad_func else register diff --git a/python/tvm/relax/transform/_ffi_api.py b/python/tvm/relax/transform/_ffi_api.py index 6ae33aef830a..25d6ecd75385 100644 --- a/python/tvm/relax/transform/_ffi_api.py +++ b/python/tvm/relax/transform/_ffi_api.py @@ -16,4 +16,4 @@ """FFI APIs for tvm.transform""" import tvm_ffi -tvm_ffi._init_api("relax.transform", __name__) +tvm_ffi.init_ffi_api("relax.transform", __name__) diff --git a/python/tvm/rpc/_ffi_api.py b/python/tvm/rpc/_ffi_api.py index b1bc8af974e5..80fd79e31348 100644 --- a/python/tvm/rpc/_ffi_api.py +++ b/python/tvm/rpc/_ffi_api.py @@ -18,4 +18,4 @@ import tvm_ffi -tvm_ffi._init_api("rpc", __name__) +tvm_ffi.init_ffi_api("rpc", __name__) diff --git a/python/tvm/rpc/client.py b/python/tvm/rpc/client.py index 37bc6b311745..73e9db3d5b60 100644 --- a/python/tvm/rpc/client.py +++ b/python/tvm/rpc/client.py @@ -23,10 +23,11 @@ import time import tvm_ffi +from tvm_ffi import DLDeviceType + import tvm.runtime from tvm.base import TVMError from tvm.contrib import utils -from tvm.runtime import Device from . import _ffi_api, base, server @@ -88,7 +89,7 @@ def device(self, dev_type, dev_id=0): """ dev = tvm.runtime.device(dev_type, dev_id) encode = (self._tbl_index + 1) * base.RPC_SESS_MASK - dev = tvm.runtime.device(dev.device_type + encode, dev.device_id) + dev = tvm.runtime.device(dev.dlpack_device_type() + encode, dev.index) dev._rpc_sess = self return dev @@ -216,39 +217,39 @@ def download_linked_module(self, path): def cpu(self, dev_id=0): """Construct CPU device.""" - return self.device(Device.kDLCPU, dev_id) + return self.device(DLDeviceType.kDLCPU, dev_id) def cuda(self, dev_id=0): """Construct CUDA GPU device.""" - return self.device(Device.kDLCUDA, dev_id) + return self.device(DLDeviceType.kDLCUDA, dev_id) def cl(self, dev_id=0): """Construct OpenCL device.""" - return self.device(Device.kDLOpenCL, dev_id) + return self.device(DLDeviceType.kDLOpenCL, dev_id) def vulkan(self, dev_id=0): """Construct Vulkan device.""" - return self.device(Device.kDLVulkan, dev_id) + return self.device(DLDeviceType.kDLVulkan, dev_id) def metal(self, dev_id=0): """Construct Metal device.""" - return self.device(Device.kDLMetal, dev_id) + return self.device(DLDeviceType.kDLMetal, dev_id) def rocm(self, dev_id=0): """Construct ROCm device.""" - return self.device(Device.kDLROCM, dev_id) + return self.device(DLDeviceType.kDLROCM, dev_id) def ext_dev(self, dev_id=0): """Construct extension device.""" - return self.device(Device.kDLExtDev, dev_id) + return self.device(DLDeviceType.kDLExtDev, dev_id) def hexagon(self, dev_id=0): """Construct Hexagon device.""" - return self.device(Device.kDLHexagon, dev_id) + return self.device(DLDeviceType.kDLHexagon, dev_id) def webgpu(self, dev_id=0): """Construct WebGPU device.""" - return self.device(Device.kDLWebGPU, dev_id) + return self.device(DLDeviceType.kDLWebGPU, dev_id) class LocalSession(RPCSession): @@ -263,7 +264,7 @@ def __init__(self): RPCSession.__init__(self, _ffi_api.LocalSession()) -@tvm_ffi.register_func("rpc.PopenSession") +@tvm_ffi.register_global_func("rpc.PopenSession") def _popen_session(binary): temp = utils.tempdir() diff --git a/python/tvm/rpc/server.py b/python/tvm/rpc/server.py index 17b3f3652ec6..3ed512e9dd04 100644 --- a/python/tvm/rpc/server.py +++ b/python/tvm/rpc/server.py @@ -70,11 +70,11 @@ def _server_env(load_library, work_path=None): temp = utils.tempdir() # pylint: disable=unused-variable - @tvm_ffi.register_func("tvm.rpc.server.workpath", override=True) + 
@tvm_ffi.register_global_func("tvm.rpc.server.workpath", override=True) def get_workpath(path): return temp.relpath(path) - @tvm_ffi.register_func("tvm.rpc.server.load_module", override=True) + @tvm_ffi.register_global_func("tvm.rpc.server.load_module", override=True) def load_module(file_name): """Load module from remote side.""" path = temp.relpath(file_name) @@ -82,7 +82,7 @@ def load_module(file_name): logger.info("load_module %s", path) return m - @tvm_ffi.register_func("tvm.rpc.server.download_linked_module", override=True) + @tvm_ffi.register_global_func("tvm.rpc.server.download_linked_module", override=True) def download_linked_module(file_name): """Load module from remote side.""" # pylint: disable=import-outside-toplevel @@ -488,7 +488,7 @@ def server_init_callback(): # must import mypackage here import mypackage - tvm.register_func("function", mypackage.func) + tvm.register_global_func("function", mypackage.func) server = rpc.Server(host, server_init_callback=server_init_callback) """ diff --git a/python/tvm/rpc/testing.py b/python/tvm/rpc/testing.py index d27485413814..e3f216563863 100644 --- a/python/tvm/rpc/testing.py +++ b/python/tvm/rpc/testing.py @@ -22,38 +22,38 @@ # RPC test functions to be registered for unit-tests purposes -@tvm.register_func("rpc.test.addone") +@tvm.register_global_func("rpc.test.addone") def _addone(x): return x + 1 -@tvm.register_func("rpc.test.strcat") +@tvm.register_global_func("rpc.test.strcat") def _strcat(name, x): return f"{name}:{x}" -@tvm.register_func("rpc.test.except") +@tvm.register_global_func("rpc.test.except") def _remotethrow(name): raise ValueError(f"{name}") -@tvm.register_func("rpc.test.runtime_str_concat") +@tvm.register_global_func("rpc.test.runtime_str_concat") def _strcat(x, y): return x + y -@tvm.register_func("rpc.test.remote_tensor_func") +@tvm.register_global_func("rpc.test.remote_tensor_func") def _remote_tensor_func(y): x = np.ones((3, 4)) np.testing.assert_equal(y.numpy(), x) -@tvm.register_func("rpc.test.add_to_lhs") +@tvm.register_global_func("rpc.test.add_to_lhs") def _add_to_lhs(x): return lambda y: x + y -@tvm.register_func("rpc.test.remote_return_nd") +@tvm.register_global_func("rpc.test.remote_return_nd") def _my_module(name): # Use closure to check the ref counter correctness nd = tvm.runtime.tensor(np.zeros(10).astype("float32")) diff --git a/python/tvm/runtime/__init__.py b/python/tvm/runtime/__init__.py index 57546dcff48b..4c61e2e06b3a 100644 --- a/python/tvm/runtime/__init__.py +++ b/python/tvm/runtime/__init__.py @@ -16,13 +16,14 @@ # under the License. """TVM runtime namespace.""" -from tvm_ffi import convert, dtype as DataType, DataTypeCode +from tvm_ffi import convert +from tvm_ffi._dtype import dtype as DataType, DataTypeCode # class exposures from .packed_func import PackedFunc from .object import Object from .script_printer import Scriptable -from .object_generic import ObjectGeneric +from .object_generic import ObjectConvertible from .device import Device from ._tensor import Tensor, tensor, empty from .module import Module diff --git a/python/tvm/runtime/_ffi_api.py b/python/tvm/runtime/_ffi_api.py index 0357b280bd46..c713b379c384 100644 --- a/python/tvm/runtime/_ffi_api.py +++ b/python/tvm/runtime/_ffi_api.py @@ -18,4 +18,4 @@ import tvm_ffi # Exports functions registered in runtime namespace. 
-tvm_ffi._init_api("runtime", __name__) +tvm_ffi.init_ffi_api("runtime", __name__) diff --git a/python/tvm/runtime/_ffi_node_api.py b/python/tvm/runtime/_ffi_node_api.py index 2e47f6aa32f9..a4f74864aa2d 100644 --- a/python/tvm/runtime/_ffi_node_api.py +++ b/python/tvm/runtime/_ffi_node_api.py @@ -23,7 +23,7 @@ # The implementations below are default ones when the corresponding # functions are not available in the runtime only mode. -# They will be overriden via _init_api to the ones registered +# They will be overriden via tvm_ffi.init_ffi_api to the ones registered def AsRepr(obj): return type(obj).__name__ + "(" + obj.__ctypes_handle__().value + ")" @@ -37,4 +37,4 @@ def LoadJSON(json_str): # Exports functions registered in node namespace. -tvm_ffi._init_api("node", __name__) +tvm_ffi.init_ffi_api("node", __name__) diff --git a/python/tvm/runtime/_tensor.py b/python/tvm/runtime/_tensor.py index 1d413272b2a3..fc176bf60097 100644 --- a/python/tvm/runtime/_tensor.py +++ b/python/tvm/runtime/_tensor.py @@ -28,19 +28,7 @@ ml_dtypes = None import tvm_ffi -from tvm_ffi import ( - device, - cpu, - cuda, - rocm, - opencl, - metal, - vpi, - vulkan, - ext_dev, - hexagon, - webgpu, -) +from tvm_ffi import device, DLDeviceType import tvm from tvm.runtime import Device @@ -134,7 +122,7 @@ def copyfrom(self, source_array): raise ValueError( f"array shape do not match the shape of Tensor {source_array.shape} vs {shape}" ) - numpy_str_map = tvm_ffi.dtype.NUMPY_DTYPE_TO_STR + numpy_str_map = tvm_ffi.dtype._NUMPY_DTYPE_TO_STR np_dtype_str = ( numpy_str_map[source_array.dtype] if source_array.dtype in numpy_str_map @@ -360,5 +348,170 @@ def tensor(arr, device=None, mem_scope=None): return empty(arr.shape, arr.dtype, device, mem_scope).copyfrom(arr) +def cpu(dev_id=0): + """Construct a CPU device + + Parameters + ---------- + dev_id : int, optional + The integer device id + + Returns + ------- + dev : Device + The created device + """ + return device(DLDeviceType.kDLCPU, dev_id) + + +def cuda(dev_id=0): + """Construct a CUDA GPU device + + Parameters + ---------- + dev_id : int, optional + The integer device id + + Returns + ------- + dev : Device + The created device + """ + return device(DLDeviceType.kDLCUDA, dev_id) + + +def rocm(dev_id=0): + """Construct a ROCM device + + Parameters + ---------- + dev_id : int, optional + The integer device id + + Returns + ------- + dev : Device + The created device + """ + return device(DLDeviceType.kDLROCM, dev_id) + + +def opencl(dev_id=0): + """Construct a OpenCL device + + Parameters + ---------- + dev_id : int, optional + The integer device id + + Returns + ------- + dev : Device + The created device + """ + return device(DLDeviceType.kDLOpenCL, dev_id) + + +def metal(dev_id=0): + """Construct a metal device + + Parameters + ---------- + dev_id : int, optional + The integer device id + + Returns + ------- + dev : Device + The created device + """ + return device(DLDeviceType.kDLMetal, dev_id) + + +def vpi(dev_id=0): + """Construct a VPI simulated device + + Parameters + ---------- + dev_id : int, optional + The integer device id + + Returns + ------- + dev : Device + The created device + """ + return device(DLDeviceType.kDLVPI, dev_id) + + +def vulkan(dev_id=0): + """Construct a Vulkan device + + Parameters + ---------- + dev_id : int, optional + The integer device id + + Returns + ------- + dev : Device + The created device + """ + return device(DLDeviceType.kDLVulkan, dev_id) + + +def ext_dev(dev_id=0): + """Construct a extension device + + Parameters + 
---------- + dev_id : int, optional + The integer device id + + Returns + ------- + dev : Device + The created device + + Note + ---- + This API is reserved for quick testing of new + device by plugin device API as ext_dev. + """ + return device(DLDeviceType.kDLExtDev, dev_id) + + +def hexagon(dev_id=0): + """Construct a Hexagon device + + Parameters + ---------- + dev_id : int, optional + The integer device id + + Returns + ------- + dev : Device + The created device + """ + return device(DLDeviceType.kDLHexagon, dev_id) + + +def webgpu(dev_id=0): + """Construct a webgpu device. + + Parameters + ---------- + dev_id : int, optional + The integer device id + + Returns + ------- + dev : Device + The created device + """ + return device(DLDeviceType.kDLWebGPU, dev_id) + + # Register back to FFI tvm_ffi.core._set_class_tensor(Tensor) diff --git a/python/tvm/runtime/container.py b/python/tvm/runtime/container.py index 37d0d2116c55..f9ddb5e51206 100644 --- a/python/tvm/runtime/container.py +++ b/python/tvm/runtime/container.py @@ -15,7 +15,8 @@ # specific language governing permissions and limitations # under the License. """Runtime container structures.""" -from tvm_ffi import String, Shape as ShapeTuple +from tvm_ffi.core import String +from tvm_ffi import Shape as ShapeTuple __all__ = ["ShapeTuple", "String"] diff --git a/python/tvm/runtime/device.py b/python/tvm/runtime/device.py index d86e30605faa..b8a3db15f30e 100644 --- a/python/tvm/runtime/device.py +++ b/python/tvm/runtime/device.py @@ -48,7 +48,7 @@ def exist(self): True if the device exists """ - return self._GetDeviceAttr(self.device_type, self.device_id, 0) != 0 + return self._GetDeviceAttr(self.dlpack_device_type(), self.index, 0) != 0 @property def max_threads_per_block(self): @@ -64,7 +64,7 @@ def max_threads_per_block(self): The number of threads on each block """ - return self._GetDeviceAttr(self.device_type, self.device_id, 1) + return self._GetDeviceAttr(self.dlpack_device_type(), self.index, 1) @property def warp_size(self): @@ -81,7 +81,7 @@ def warp_size(self): Number of threads that execute concurrently """ - return self._GetDeviceAttr(self.device_type, self.device_id, 2) + return self._GetDeviceAttr(self.dlpack_device_type(), self.index, 2) @property def max_shared_memory_per_block(self): @@ -97,7 +97,7 @@ def max_shared_memory_per_block(self): Total amount of shared memory per block in bytes """ - return self._GetDeviceAttr(self.device_type, self.device_id, 3) + return self._GetDeviceAttr(self.dlpack_device_type(), self.index, 3) @property def compute_version(self): @@ -116,7 +116,7 @@ def compute_version(self): The version string in `major.minor` format. """ - return self._GetDeviceAttr(self.device_type, self.device_id, 4) + return self._GetDeviceAttr(self.dlpack_device_type(), self.index, 4) @property def device_name(self): @@ -132,7 +132,7 @@ def device_name(self): The name of the device. 
""" - return self._GetDeviceAttr(self.device_type, self.device_id, 5) + return self._GetDeviceAttr(self.dlpack_device_type(), self.index, 5) @property def max_clock_rate(self): @@ -148,7 +148,7 @@ def max_clock_rate(self): The maximum clock frequency of the device (kHz) """ - return self._GetDeviceAttr(self.device_type, self.device_id, 6) + return self._GetDeviceAttr(self.dlpack_device_type(), self.index, 6) @property def multi_processor_count(self): @@ -164,7 +164,7 @@ def multi_processor_count(self): Thee number of compute units in the device """ - return self._GetDeviceAttr(self.device_type, self.device_id, 7) + return self._GetDeviceAttr(self.dlpack_device_type(), self.index, 7) @property def max_thread_dimensions(self): @@ -180,7 +180,7 @@ def max_thread_dimensions(self): The maximum length of threadIdx.x, threadIdx.y, threadIdx.z """ - return json.loads(self._GetDeviceAttr(self.device_type, self.device_id, 8)) + return json.loads(self._GetDeviceAttr(self.dlpack_device_type(), self.index, 8)) @property def api_version(self): @@ -199,7 +199,7 @@ def api_version(self): The version of the SDK """ - return self._GetDeviceAttr(self.device_type, self.device_id, 11) + return self._GetDeviceAttr(self.dlpack_device_type(), self.index, 11) @property def driver_version(self): @@ -218,7 +218,7 @@ def driver_version(self): The version string in `major.minor.patch` format. """ - return self._GetDeviceAttr(self.device_type, self.device_id, 12) + return self._GetDeviceAttr(self.dlpack_device_type(), self.index, 12) @property def l2_cache_size_bytes(self): @@ -236,7 +236,7 @@ def l2_cache_size_bytes(self): ---- The value returned by opencl's API is smaller than actual device L2 cache size. """ - return self._GetDeviceAttr(self.device_type, self.device_id, 13) + return self._GetDeviceAttr(self.dlpack_device_type(), self.index, 13) @property def total_global_memory(self): @@ -250,7 +250,7 @@ def total_global_memory(self): Return the total size of global memory on device in bytes. Return None if the device does not support this feature. """ - return self._GetDeviceAttr(self.device_type, self.device_id, 14) + return self._GetDeviceAttr(self.dlpack_device_type(), self.index, 14) @property def available_global_memory(self): @@ -264,7 +264,7 @@ def available_global_memory(self): Return the amount of unallocated global memory on device in bytes. Return None if the device does not support this feature. """ - return self._GetDeviceAttr(self.device_type, self.device_id, 15) + return self._GetDeviceAttr(self.dlpack_device_type(), self.index, 15) def texture_spatial_limit(self): """Returns limits for textures by spatial dimensions @@ -275,7 +275,7 @@ def texture_spatial_limit(self): Maximum size of the texture by spatial dimensions """ - return self._GetDeviceAttr(self.device_type, self.device_id, 12) + return self._GetDeviceAttr(self.dlpack_device_type(), self.index, 12) def create_raw_stream(self): """Create a new runtime stream at the context. 
@@ -319,19 +319,12 @@ def sync(self, stream=None): """ _ffi_api.Device_StreamSync(self, stream or 0) - def _device_type_name_(self): - if self.device_type >= RPC_SESS_MASK: - tbl_id = self.device_type / RPC_SESS_MASK - 1 - dev_type = self.device_type % RPC_SESS_MASK - return f"remote[{tbl_id}]:{Device.DEVICE_TYPE_TO_NAME[dev_type]}" - return Device.DEVICE_TYPE_TO_NAME[self.device_type] - def __device_type_name__(self): - if self.device_type >= RPC_SESS_MASK: - tbl_id = self.device_type / RPC_SESS_MASK - 1 - dev_type = self.device_type % RPC_SESS_MASK - return f"remote[{tbl_id}]:{Device.DEVICE_TYPE_TO_NAME[dev_type]}" - return Device.DEVICE_TYPE_TO_NAME[self.device_type] + if self.dlpack_device_type() >= RPC_SESS_MASK: + tbl_id = self.dlpack_device_type() / RPC_SESS_MASK - 1 + dev_type = self.dlpack_device_type() % RPC_SESS_MASK + return f"remote[{tbl_id}]:{Device._DEVICE_TYPE_TO_NAME[dev_type]}" + return Device._DEVICE_TYPE_TO_NAME[self.dlpack_device_type()] tvm_ffi.core._set_class_device(Device) diff --git a/python/tvm/runtime/disco/_ffi_api.py b/python/tvm/runtime/disco/_ffi_api.py index 63a53d8b8540..2caeef293ea5 100644 --- a/python/tvm/runtime/disco/_ffi_api.py +++ b/python/tvm/runtime/disco/_ffi_api.py @@ -15,6 +15,6 @@ # specific language governing permissions and limitations # under the License. """FFI APIs from C++""" -from tvm_ffi import _init_api +import tvm_ffi -_init_api("runtime.disco", __name__) +tvm_ffi.init_ffi_api("runtime.disco", __name__) diff --git a/python/tvm/runtime/disco/process_pool.py b/python/tvm/runtime/disco/process_pool.py index ba9b512f04a3..975c26fb922f 100644 --- a/python/tvm/runtime/disco/process_pool.py +++ b/python/tvm/runtime/disco/process_pool.py @@ -20,7 +20,7 @@ import subprocess import sys -from tvm_ffi import register_func +from tvm_ffi import register_global_func from tvm.runtime import ShapeTuple @@ -177,7 +177,7 @@ def _kill_child_processes(pid): pass -@register_func("runtime.disco.create_process_pool") +@register_global_func("runtime.disco.create_process_pool") def _create_process_pool(num_workers: int, num_groups: int, entrypoint: str): """Create a process pool where the workers' are [1, num_workers).""" pool = [DiscoPopenWorker(i, num_workers, num_groups, entrypoint) for i in range(1, num_workers)] diff --git a/python/tvm/runtime/disco/session.py b/python/tvm/runtime/disco/session.py index ed4ce06a3766..f2c2dfc791ab 100644 --- a/python/tvm/runtime/disco/session.py +++ b/python/tvm/runtime/disco/session.py @@ -25,7 +25,7 @@ import numpy as np -from tvm_ffi import get_global_func, register_func, register_object +from tvm_ffi import get_global_func, register_global_func, register_object from ..device import Device from ..container import ShapeTuple from .._tensor import Tensor @@ -583,7 +583,7 @@ def _configure_structlog(self) -> None: func(config, os.getpid()) -@register_func("runtime.disco.create_socket_session_local_workers") +@register_global_func("runtime.disco.create_socket_session_local_workers") def _create_socket_session_local_workers(num_workers) -> Session: """Create the local session for each distributed node over socket session.""" return ProcessSession(num_workers) @@ -611,7 +611,7 @@ def __init__( ) -@register_func("runtime.disco._configure_structlog") +@register_global_func("runtime.disco._configure_structlog") def _configure_structlog(pickled_config: bytes, parent_pid: int) -> None: """Configure structlog for all disco workers @@ -646,7 +646,7 @@ def _configure_structlog(pickled_config: bytes, parent_pid: int) -> None: 
structlog.configure(**structlog_config) -@register_func("runtime.disco._import_python_module") +@register_global_func("runtime.disco._import_python_module") def _import_python_module(module_name: str) -> None: __import__(module_name) diff --git a/python/tvm/runtime/module.py b/python/tvm/runtime/module.py index c725150c6e69..71b3bdd94b64 100644 --- a/python/tvm/runtime/module.py +++ b/python/tvm/runtime/module.py @@ -377,8 +377,8 @@ def time_evaluator( feval = _ffi_api.RPCTimeEvaluator( self, func_name, - dev.device_type, - dev.device_id, + dev.dlpack_device_type(), + dev.index, number, repeat, min_repeat_ms, diff --git a/python/tvm/runtime/object_generic.py b/python/tvm/runtime/object_generic.py index f5574e48023b..340df0fcea55 100644 --- a/python/tvm/runtime/object_generic.py +++ b/python/tvm/runtime/object_generic.py @@ -16,7 +16,7 @@ # under the License. """Common implementation of object generic related logic""" # pylint: disable=unused-import, invalid-name -from tvm_ffi import ObjectGeneric +from tvm_ffi import ObjectConvertible from . import _ffi_node_api diff --git a/python/tvm/runtime/profiling/__init__.py b/python/tvm/runtime/profiling/__init__.py index 45189a008495..3ca831ac4200 100644 --- a/python/tvm/runtime/profiling/__init__.py +++ b/python/tvm/runtime/profiling/__init__.py @@ -266,7 +266,7 @@ def profile_function(mod, dev, collectors, func_name=None, warmup_iters=10): if func_name is None: func_name = mod.entry_name return _ffi_api.ProfileFunction( - mod, func_name, dev.device_type, dev.device_id, warmup_iters, collectors + mod, func_name, dev.dlpack_device_type(), dev.index, warmup_iters, collectors ) diff --git a/python/tvm/runtime/profiling/_ffi_api.py b/python/tvm/runtime/profiling/_ffi_api.py index 104aac90a551..883e3ca6e778 100644 --- a/python/tvm/runtime/profiling/_ffi_api.py +++ b/python/tvm/runtime/profiling/_ffi_api.py @@ -15,6 +15,6 @@ # specific language governing permissions and limitations # under the License. 
"""FFI for profiling""" -from tvm_ffi import _init_api +import tvm_ffi -_init_api("runtime.profiling", __name__) +tvm_ffi.init_ffi_api("runtime.profiling", __name__) diff --git a/python/tvm/runtime/support.py b/python/tvm/runtime/support.py index 99856b8d3b9d..4a2e9ef50847 100644 --- a/python/tvm/runtime/support.py +++ b/python/tvm/runtime/support.py @@ -23,7 +23,7 @@ import tvm_ffi -@tvm_ffi.register_func("tvm.runtime.regex_match") +@tvm_ffi.register_global_func("tvm.runtime.regex_match") def _regex_match(regex_pattern: str, match_against: str) -> bool: """Check if a pattern matches a regular expression diff --git a/python/tvm/runtime/vm.py b/python/tvm/runtime/vm.py index 72fb13378896..b188c6ca70c7 100644 --- a/python/tvm/runtime/vm.py +++ b/python/tvm/runtime/vm.py @@ -23,7 +23,7 @@ import numpy as np # type: ignore import tvm -from tvm_ffi import register_func +from tvm_ffi import register_global_func from tvm.runtime import Device, Object, PackedFunc from tvm.runtime.profiling import Report @@ -99,7 +99,7 @@ def _setup_device(self, dev: Device, memory_cfg: Union[str, Dict[Device, str]]) devs = [dev] # CPU is required for executing shape functions - if devs[-1].device_type % RPC_SESS_MASK != tvm.cpu().device_type: + if devs[-1].dlpack_device_type() % RPC_SESS_MASK != tvm.cpu().dlpack_device_type(): devs.append(tvm.cpu()) default_alloc_type = VirtualMachine.POOLED_ALLOCATOR @@ -117,8 +117,8 @@ def _setup_device(self, dev: Device, memory_cfg: Union[str, Dict[Device, str]]) ) init_args = [] for device in devs: - init_args.append(device.device_type % RPC_SESS_MASK) - init_args.append(device.device_id) + init_args.append(device.dlpack_device_type() % RPC_SESS_MASK) + init_args.append(device.index) alloc_type = memory_cfg[device] if device in memory_cfg else default_alloc_type init_args.append(alloc_type) self.module["vm_initialization"](*init_args) @@ -499,6 +499,6 @@ def profile(self, func_name: str, *args): return Report.from_json(report_json) -@register_func("vm.builtin.debug_print") +@register_global_func("vm.builtin.debug_print") def _print(lineo: str, array) -> None: print(f"{lineo}: shape = {array.shape}, dtype = {array.dtype}, data =\n{array}") diff --git a/python/tvm/script/_ffi_api.py b/python/tvm/script/_ffi_api.py index 28dcec06bbdd..1354d3f2ec2c 100644 --- a/python/tvm/script/_ffi_api.py +++ b/python/tvm/script/_ffi_api.py @@ -17,4 +17,4 @@ import tvm_ffi -tvm_ffi._init_api("script", __name__) +tvm_ffi.init_ffi_api("script", __name__) diff --git a/python/tvm/script/ir_builder/_ffi_api.py b/python/tvm/script/ir_builder/_ffi_api.py index fdca5f75dce4..c8a9597d5292 100644 --- a/python/tvm/script/ir_builder/_ffi_api.py +++ b/python/tvm/script/ir_builder/_ffi_api.py @@ -17,4 +17,4 @@ """FFI APIs for tvm.script.ir_builder""" import tvm_ffi -tvm_ffi._init_api("script.ir_builder", __name__) # pylint: disable=protected-access +tvm_ffi.init_ffi_api("script.ir_builder", __name__) # pylint: disable=protected-access diff --git a/python/tvm/script/ir_builder/ir/_ffi_api.py b/python/tvm/script/ir_builder/ir/_ffi_api.py index 23b92904cba1..e319c3d4612e 100644 --- a/python/tvm/script/ir_builder/ir/_ffi_api.py +++ b/python/tvm/script/ir_builder/ir/_ffi_api.py @@ -17,4 +17,4 @@ """FFI APIs""" import tvm_ffi -tvm_ffi._init_api("script.ir_builder.ir", __name__) # pylint: disable=protected-access +tvm_ffi.init_ffi_api("script.ir_builder.ir", __name__) # pylint: disable=protected-access diff --git a/python/tvm/script/ir_builder/relax/_ffi_api.py b/python/tvm/script/ir_builder/relax/_ffi_api.py index 
251a24d4fa79..f6c53336ff4c 100644 --- a/python/tvm/script/ir_builder/relax/_ffi_api.py +++ b/python/tvm/script/ir_builder/relax/_ffi_api.py @@ -17,4 +17,4 @@ """FFI APIs for tvm.script.ir_builder.relax""" import tvm_ffi -tvm_ffi._init_api("script.ir_builder.relax", __name__) # pylint: disable=protected-access +tvm_ffi.init_ffi_api("script.ir_builder.relax", __name__) # pylint: disable=protected-access diff --git a/python/tvm/script/ir_builder/relax/distributed/_ffi_api.py b/python/tvm/script/ir_builder/relax/distributed/_ffi_api.py index a69d7f3e38d5..b82fa37e8f3f 100644 --- a/python/tvm/script/ir_builder/relax/distributed/_ffi_api.py +++ b/python/tvm/script/ir_builder/relax/distributed/_ffi_api.py @@ -17,6 +17,6 @@ """FFI APIs for tvm.script.ir_builder.relax.distributed""" import tvm_ffi -tvm_ffi._init_api( +tvm_ffi.init_ffi_api( "script.ir_builder.relax.distributed", __name__ ) # pylint: disable=protected-access diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py index f045508bfcec..d28ff3430aaa 100644 --- a/python/tvm/script/ir_builder/relax/ir.py +++ b/python/tvm/script/ir_builder/relax/ir.py @@ -191,7 +191,7 @@ from tvm.relax.struct_info import StructInfo from tvm.relax.utils import args_converter, gen_call_tir_inputs from tvm.runtime import Object as tvm_Object -from tvm.runtime import ObjectGeneric +from tvm.runtime import ObjectConvertible from tvm.runtime._tensor import ( cpu, cuda, @@ -431,7 +431,7 @@ def call_packed( sinfo() if callable(sinfo) else sinfo.asobject() - if isinstance(sinfo, ObjectGeneric) + if isinstance(sinfo, ObjectConvertible) else sinfo ) for sinfo in sinfo_args @@ -462,7 +462,7 @@ def _convert_tensor_type(args): return {_convert_tensor_type(k): _convert_tensor_type(v) for k, v in args.items()} if inspect.isfunction(args): args = args() - if isinstance(args, ObjectGeneric): + if isinstance(args, ObjectConvertible): args = args.asobject() return args diff --git a/python/tvm/script/ir_builder/tir/_ffi_api.py b/python/tvm/script/ir_builder/tir/_ffi_api.py index 42893a0047cc..4385b2ec13d0 100644 --- a/python/tvm/script/ir_builder/tir/_ffi_api.py +++ b/python/tvm/script/ir_builder/tir/_ffi_api.py @@ -17,4 +17,4 @@ """FFI APIs""" import tvm_ffi -tvm_ffi._init_api("script.ir_builder.tir", __name__) # pylint: disable=protected-access +tvm_ffi.init_ffi_api("script.ir_builder.tir", __name__) # pylint: disable=protected-access diff --git a/python/tvm/script/parser/relax/entry.py b/python/tvm/script/parser/relax/entry.py index 04a5f985643e..ec140e57ba60 100644 --- a/python/tvm/script/parser/relax/entry.py +++ b/python/tvm/script/parser/relax/entry.py @@ -35,7 +35,7 @@ TupleStructInfo, ) from tvm.relax.expr import Var -from tvm.runtime import ObjectGeneric +from tvm.runtime import ObjectConvertible from tvm.tir import PrimExpr from .._core import doc, parse, utils @@ -147,7 +147,7 @@ def wrapper(*args, **kwargs): ############################# Struct Info ############################## -class StructInfoProxy(ObjectGeneric): +class StructInfoProxy(ObjectConvertible): def as_struct_info(self, dict_globals: Optional[Dict[str, Any]] = None) -> StructInfo: raise NotImplementedError() diff --git a/python/tvm/script/printer/_ffi_api.py b/python/tvm/script/printer/_ffi_api.py index e219c9dbf845..967d0d824ba2 100644 --- a/python/tvm/script/printer/_ffi_api.py +++ b/python/tvm/script/printer/_ffi_api.py @@ -17,4 +17,4 @@ """FFI APIs for tvm.script.printer""" import tvm_ffi -tvm_ffi._init_api("script.printer", __name__) # pylint: 
disable=protected-access +tvm_ffi.init_ffi_api("script.printer", __name__) # pylint: disable=protected-access diff --git a/python/tvm/support.py b/python/tvm/support.py index 5266602fd168..d0b1540c0417 100644 --- a/python/tvm/support.py +++ b/python/tvm/support.py @@ -26,7 +26,7 @@ from .runtime.module import Module from . import get_global_func -tvm_ffi._init_api("support", __name__) +tvm_ffi.init_ffi_api("support", __name__) def libinfo(): diff --git a/python/tvm/target/_ffi_api.py b/python/tvm/target/_ffi_api.py index 7520482388ab..8b9f6c73bd4e 100644 --- a/python/tvm/target/_ffi_api.py +++ b/python/tvm/target/_ffi_api.py @@ -18,4 +18,4 @@ import tvm_ffi -tvm_ffi._init_api("target", __name__) +tvm_ffi.init_ffi_api("target", __name__) diff --git a/python/tvm/target/datatype.py b/python/tvm/target/datatype.py index bd6a72e8df8a..e597c8d147be 100644 --- a/python/tvm/target/datatype.py +++ b/python/tvm/target/datatype.py @@ -18,7 +18,7 @@ TODO(@gussmith23 @hypercubestart) link to BYODT docs when they exist""" from tvm_ffi import get_global_func -from tvm_ffi import register_func as _register_func +from tvm_ffi import register_global_func as _register_global_func import tvm from tvm.runtime import convert, DataType @@ -216,7 +216,7 @@ class name (e.g. Add, LE, Cast, Call). ) else: lower_func_name = "tvm.datatype.lower." + target + "." + op_name + "." + src_type_name - tvm_ffi.register_func(lower_func_name, lower_func) + tvm_ffi.register_global_func(lower_func_name, lower_func) def register_min_func(func, type_name): @@ -245,7 +245,7 @@ def register_min_func(func, type_name): type_name : str The name of the custom datatype, e.g. posites2 (but not custom[posites2]32). """ - _register_func("tvm.datatype.min." + type_name, func) + _register_global_func("tvm.datatype.min." + type_name, func) def create_min_lower_func(extern_func_map, type_name): diff --git a/python/tvm/target/detect_target.py b/python/tvm/target/detect_target.py index 808c63cef16a..5c61de62e4e1 100644 --- a/python/tvm/target/detect_target.py +++ b/python/tvm/target/detect_target.py @@ -123,7 +123,7 @@ def detect_target_from_device(dev: Union[str, Device]) -> Target: """ if isinstance(dev, str): dev = device(dev) - device_type = Device.DEVICE_TYPE_TO_NAME[dev.device_type] + device_type = Device._DEVICE_TYPE_TO_NAME[dev.dlpack_device_type()] if device_type not in SUPPORT_DEVICE: raise ValueError( f"Auto detection for device `{device_type}` is not supported. 
" diff --git a/python/tvm/target/target.py b/python/tvm/target/target.py index 64a7a893d808..a9191df773ec 100644 --- a/python/tvm/target/target.py +++ b/python/tvm/target/target.py @@ -21,7 +21,7 @@ from typing import Union import tvm_ffi -from tvm_ffi import register_func as _register_func +from tvm_ffi import register_global_func as _register_global_func from tvm.runtime import Device from tvm.runtime import Object, convert from tvm.runtime.container import String @@ -853,7 +853,7 @@ def create(target): return Target(target) -@_register_func("target._load_config_dict") +@_register_global_func("target._load_config_dict") def _load_config_dict(config_dict_str): try: config = json.loads(config_dict_str) diff --git a/python/tvm/target/virtual_device.py b/python/tvm/target/virtual_device.py index e73de85cd380..e509c5670750 100644 --- a/python/tvm/target/virtual_device.py +++ b/python/tvm/target/virtual_device.py @@ -34,6 +34,5 @@ def __init__(self, device=None, target=None, memory_scope="") -> None: _ffi_api.VirtualDevice_ForDeviceTargetAndMemoryScope, device, target, memory_scope ) - @property - def device_type(self) -> int: + def dlpack_device_type(self) -> int: return self.device_type_int diff --git a/python/tvm/target/x86.py b/python/tvm/target/x86.py index 874975383ee1..e00dbb437440 100644 --- a/python/tvm/target/x86.py +++ b/python/tvm/target/x86.py @@ -15,11 +15,11 @@ # specific language governing permissions and limitations # under the License. """Common x86 related utilities""" -from tvm_ffi import register_func +from tvm_ffi import register_global_func from .codegen import target_has_features -@register_func("tvm.topi.x86.utils.get_simd_32bit_lanes") +@register_global_func("tvm.topi.x86.utils.get_simd_32bit_lanes") def get_simd_32bit_lanes(): """X86 SIMD optimal vector length lookup. Parameters diff --git a/python/tvm/te/_ffi_api.py b/python/tvm/te/_ffi_api.py index 8df8d5ff4754..172fff01d7ff 100644 --- a/python/tvm/te/_ffi_api.py +++ b/python/tvm/te/_ffi_api.py @@ -18,4 +18,4 @@ import tvm_ffi -tvm_ffi._init_api("te", __name__) +tvm_ffi.init_ffi_api("te", __name__) diff --git a/python/tvm/te/tensor.py b/python/tvm/te/tensor.py index 61102085ef21..11084da0cc7f 100644 --- a/python/tvm/te/tensor.py +++ b/python/tvm/te/tensor.py @@ -18,13 +18,13 @@ # pylint: disable=invalid-name import tvm_ffi -from tvm.runtime import Object, ObjectGeneric +from tvm.runtime import Object, ObjectConvertible from tvm.tir import expr as _expr, DataProducer from . 
import _ffi_api -class TensorSlice(ObjectGeneric, _expr.ExprOp): +class TensorSlice(ObjectConvertible, _expr.ExprOp): """Auxiliary data structure for enable slicing syntax from tensor.""" def __init__(self, tensor, indices): diff --git a/python/tvm/testing/_ffi_api.py b/python/tvm/testing/_ffi_api.py index 4e57f4feafb7..6cb0b9bac495 100644 --- a/python/tvm/testing/_ffi_api.py +++ b/python/tvm/testing/_ffi_api.py @@ -18,4 +18,4 @@ import tvm_ffi -tvm_ffi._init_api("testing", __name__) +tvm_ffi.init_ffi_api("testing", __name__) diff --git a/python/tvm/testing/popen_pool.py b/python/tvm/testing/popen_pool.py index c74829202bc3..8ff260a62f9c 100644 --- a/python/tvm/testing/popen_pool.py +++ b/python/tvm/testing/popen_pool.py @@ -36,13 +36,13 @@ def after_initializer(): return TEST_GLOBAL_STATE_1, TEST_GLOBAL_STATE_2, TEST_GLOBAL_STATE_3 -@tvm_ffi.register_func("testing.identity_py") +@tvm_ffi.register_global_func("testing.identity_py") def identity_py(arg): return arg def register_ffi(): - @tvm_ffi.register_func("testing.nested_identity_py") + @tvm_ffi.register_global_func("testing.nested_identity_py") def _identity_py(arg): # pylint: disable=unused-variable return arg diff --git a/python/tvm/tir/_ffi_api.py b/python/tvm/tir/_ffi_api.py index 2a004c9a83eb..4140cda741dd 100644 --- a/python/tvm/tir/_ffi_api.py +++ b/python/tvm/tir/_ffi_api.py @@ -18,4 +18,4 @@ import tvm_ffi -tvm_ffi._init_api("tir", __name__) +tvm_ffi.init_ffi_api("tir", __name__) diff --git a/python/tvm/tir/analysis/_ffi_api.py b/python/tvm/tir/analysis/_ffi_api.py index f228e8b30cdd..9e5d094c1a82 100644 --- a/python/tvm/tir/analysis/_ffi_api.py +++ b/python/tvm/tir/analysis/_ffi_api.py @@ -18,4 +18,4 @@ import tvm_ffi -tvm_ffi._init_api("tir.analysis", __name__) +tvm_ffi.init_ffi_api("tir.analysis", __name__) diff --git a/python/tvm/tir/build.py b/python/tvm/tir/build.py index beccb65b6359..5df2663fc20b 100644 --- a/python/tvm/tir/build.py +++ b/python/tvm/tir/build.py @@ -205,7 +205,9 @@ def build( if target is not None: if target.host is not None: target_host = target.host - elif tvm.device(target.kind.name, 0).device_type == tvm.cpu(0).device_type: + elif ( + tvm.device(target.kind.name, 0).dlpack_device_type() == tvm.cpu(0).dlpack_device_type() + ): target_host = target target_host = Target.canon_target(target_host) target_to_bind = target_to_bind.with_host(target_host) @@ -237,4 +239,4 @@ def build( return tir_to_runtime(host_mod, device_mod_dict, target_host) -tvm.register_func("tir.build", build) +tvm.register_global_func("tir.build", build) diff --git a/python/tvm/tir/expr.py b/python/tvm/tir/expr.py index 4fdee96a93b5..f5476230c19b 100644 --- a/python/tvm/tir/expr.py +++ b/python/tvm/tir/expr.py @@ -34,7 +34,7 @@ from tvm import ir from tvm.ir import Op, PrimExpr from tvm.ir.base import Span -from tvm.runtime import Object, ObjectGeneric, Scriptable, DataType, DataTypeCode, const +from tvm.runtime import Object, ObjectConvertible, Scriptable, DataType, DataTypeCode, const from . import _ffi_api from . import generic as _generic @@ -227,7 +227,7 @@ def astype(self, dtype: str, span: Optional[Span] = None) -> PrimExpr: return _generic.cast(self, dtype, span) -class EqualOp(ObjectGeneric, ExprOp): +class EqualOp(ObjectConvertible, ExprOp): """Deferred equal operator. 
This is used to support sugar that a == b can either @@ -264,7 +264,7 @@ def asobject(self) -> PrimExpr: return _ffi_api._OpEQ(self.a, self.b, self.span) # type: ignore -class NotEqualOp(ObjectGeneric, ExprOp): +class NotEqualOp(ObjectConvertible, ExprOp): """Deferred NE operator. This is used to support sugar that a != b can either @@ -301,7 +301,7 @@ def asobject(self) -> PrimExpr: return _ffi_api._OpNE(self.a, self.b, self.span) # type: ignore -class IntImmEnum(ObjectGeneric): +class IntImmEnum(ObjectConvertible): """Lazily evaluate an IntImm in case the constructor is not available in runtime. diff --git a/python/tvm/tir/ir_builder.py b/python/tvm/tir/ir_builder.py index 7a9708848ab4..d6466b09224d 100644 --- a/python/tvm/tir/ir_builder.py +++ b/python/tvm/tir/ir_builder.py @@ -16,7 +16,7 @@ # under the License. """Developer API of IR node builder make function.""" import tvm -from tvm.runtime import ObjectGeneric, const +from tvm.runtime import ObjectConvertible, const from tvm.ir import container as _container from . import stmt as _stmt @@ -39,7 +39,7 @@ def __exit__(self, ptype, value, trace): self._exit_cb() -class BufferVar(ObjectGeneric): +class BufferVar(ObjectConvertible): """Buffer variable with content type, makes load store easily. Do not create it directly, create use IRBuilder. diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py index d706a1a15023..fcbc47961625 100644 --- a/python/tvm/tir/op.py +++ b/python/tvm/tir/op.py @@ -1953,7 +1953,7 @@ def all(*args, span=None): return val -@tvm_ffi.register_func("tvm.default_trace_action") +@tvm_ffi.register_global_func("tvm.default_trace_action") def _tvm_default_trace_action(*args): print(list(args)) diff --git a/python/tvm/tir/schedule/_ffi_api.py b/python/tvm/tir/schedule/_ffi_api.py index 99b831cdcda2..5087112b892a 100644 --- a/python/tvm/tir/schedule/_ffi_api.py +++ b/python/tvm/tir/schedule/_ffi_api.py @@ -17,4 +17,4 @@ """FFI APIs for tvm.tir.schedule""" import tvm_ffi -tvm_ffi._init_api("tir.schedule", __name__) # pylint: disable=protected-access +tvm_ffi.init_ffi_api("tir.schedule", __name__) # pylint: disable=protected-access diff --git a/python/tvm/tir/tensor_intrin/cuda.py b/python/tvm/tir/tensor_intrin/cuda.py index 104acf2f44c0..761654fc6906 100644 --- a/python/tvm/tir/tensor_intrin/cuda.py +++ b/python/tvm/tir/tensor_intrin/cuda.py @@ -18,7 +18,7 @@ """Intrinsics for tensorization on NVIDIA GPU.""" from typing import Dict, Literal, Optional, Tuple -from tvm_ffi import register_func +from tvm_ffi import register_global_func from tvm.runtime import convert from tvm.script import tir as T from tvm.tir import Cast, IntImm, TensorIntrin @@ -46,7 +46,7 @@ def ldmatrix_32x8_to_shared_16x16_layout(thread_id, local_id): return row, col -@register_func("tir.index_map.shared_16x16_to_ldmatrix_32x8_layout") +@register_global_func("tir.index_map.shared_16x16_to_ldmatrix_32x8_layout") def index_map_shared_16x16_to_ldmatrix_32x8_layout(ind): i, j = ind[0], ind[1] thread_id, local_id = shared_16x16_to_ldmatrix_32x8_layout(i, j) @@ -1746,7 +1746,7 @@ def mma_store_desc(a: T.handle, c: T.handle) -> None: ) -@register_func("tir.index_map_m16n8k8.matrixC") +@register_global_func("tir.index_map_m16n8k8.matrixC") def index_map_m16n8k8_matrixC(ind): i, j = ind[0], ind[1] return convert([(i // 8) // 2, j // 8, (i // 8) % 2, (j % 8) % 2]) diff --git a/python/tvm/tir/transform/_ffi_api.py b/python/tvm/tir/transform/_ffi_api.py index 6a059ff0cf96..67896ec05dda 100644 --- a/python/tvm/tir/transform/_ffi_api.py +++ 
b/python/tvm/tir/transform/_ffi_api.py @@ -18,4 +18,4 @@ import tvm_ffi -tvm_ffi._init_api("tir.transform", __name__) +tvm_ffi.init_ffi_api("tir.transform", __name__) diff --git a/python/tvm/topi/cpp/cuda.py b/python/tvm/topi/cpp/cuda.py index d7d413fcf5aa..21cf554add3b 100644 --- a/python/tvm/topi/cpp/cuda.py +++ b/python/tvm/topi/cpp/cuda.py @@ -17,4 +17,4 @@ """FFI for CUDA TOPI ops and schedules""" import tvm_ffi -tvm_ffi._init_api("topi.cuda", "tvm.topi.cpp.cuda") +tvm_ffi.init_ffi_api("topi.cuda", "tvm.topi.cpp.cuda") diff --git a/python/tvm/topi/cpp/generic.py b/python/tvm/topi/cpp/generic.py index cafcdbcada60..77dfcab58a0f 100644 --- a/python/tvm/topi/cpp/generic.py +++ b/python/tvm/topi/cpp/generic.py @@ -17,4 +17,4 @@ """FFI for generic TOPI ops and schedules""" import tvm_ffi -tvm_ffi._init_api("topi.generic", "tvm.topi.cpp.generic") +tvm_ffi.init_ffi_api("topi.generic", "tvm.topi.cpp.generic") diff --git a/python/tvm/topi/cpp/impl.py b/python/tvm/topi/cpp/impl.py index f906fc16d24c..c1783067951a 100644 --- a/python/tvm/topi/cpp/impl.py +++ b/python/tvm/topi/cpp/impl.py @@ -17,4 +17,4 @@ """Load Lib for C++ TOPI ops and schedules""" import tvm_ffi -tvm_ffi._init_api("topi", "tvm.topi.cpp") +tvm_ffi.init_ffi_api("topi", "tvm.topi.cpp") diff --git a/python/tvm/topi/cpp/nn.py b/python/tvm/topi/cpp/nn.py index b40bf834e001..32c24dc1ed98 100644 --- a/python/tvm/topi/cpp/nn.py +++ b/python/tvm/topi/cpp/nn.py @@ -17,4 +17,4 @@ """FFI for NN TOPI ops and schedules""" import tvm_ffi -tvm_ffi._init_api("topi.nn", "tvm.topi.cpp.nn") +tvm_ffi.init_ffi_api("topi.nn", "tvm.topi.cpp.nn") diff --git a/python/tvm/topi/cpp/rocm.py b/python/tvm/topi/cpp/rocm.py index eb14b0c7dc2e..3eb83fe689c3 100644 --- a/python/tvm/topi/cpp/rocm.py +++ b/python/tvm/topi/cpp/rocm.py @@ -17,4 +17,4 @@ """FFI for Rocm TOPI ops and schedules""" import tvm_ffi -tvm_ffi._init_api("topi.rocm", "tvm.topi.cpp.rocm") +tvm_ffi.init_ffi_api("topi.rocm", "tvm.topi.cpp.rocm") diff --git a/python/tvm/topi/cpp/utils.py b/python/tvm/topi/cpp/utils.py index 3e73ce7a9bdb..ecf341fabd5f 100644 --- a/python/tvm/topi/cpp/utils.py +++ b/python/tvm/topi/cpp/utils.py @@ -17,4 +17,4 @@ """FFI for TOPI utility functions""" import tvm_ffi -tvm_ffi._init_api("topi.utils", "tvm.topi.cpp.utils") +tvm_ffi.init_ffi_api("topi.utils", "tvm.topi.cpp.utils") diff --git a/python/tvm/topi/cpp/vision/__init__.py b/python/tvm/topi/cpp/vision/__init__.py index f47a21db7886..8acbb3861067 100644 --- a/python/tvm/topi/cpp/vision/__init__.py +++ b/python/tvm/topi/cpp/vision/__init__.py @@ -20,4 +20,4 @@ from . 
import yolo -tvm_ffi._init_api("topi.vision", "tvm.topi.cpp.vision") +tvm_ffi.init_ffi_api("topi.vision", "tvm.topi.cpp.vision") diff --git a/python/tvm/topi/cpp/vision/yolo.py b/python/tvm/topi/cpp/vision/yolo.py index a2eb47dadb47..f5aa6d2d0670 100644 --- a/python/tvm/topi/cpp/vision/yolo.py +++ b/python/tvm/topi/cpp/vision/yolo.py @@ -17,4 +17,4 @@ """FFI for Yolo TOPI ops and schedules""" import tvm_ffi -tvm_ffi._init_api("topi.vision.yolo", "tvm.topi.cpp.vision.yolo") +tvm_ffi.init_ffi_api("topi.vision.yolo", "tvm.topi.cpp.vision.yolo") diff --git a/python/tvm/topi/cpp/x86.py b/python/tvm/topi/cpp/x86.py index 343254607514..93cb6d96f6b8 100644 --- a/python/tvm/topi/cpp/x86.py +++ b/python/tvm/topi/cpp/x86.py @@ -17,4 +17,4 @@ """FFI for x86 TOPI ops and schedules""" import tvm_ffi -tvm_ffi._init_api("topi.x86", "tvm.topi.cpp.x86") +tvm_ffi.init_ffi_api("topi.x86", "tvm.topi.cpp.x86") diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc index 1e476eaf035a..49bf9ae3d93f 100644 --- a/src/relax/op/op.cc +++ b/src/relax/op/op.cc @@ -1444,8 +1444,8 @@ TVM_REGISTER_OP("relax.hint_on_device") Expr MakeHintOnDevice(Expr data, Device device) { static const Op& op = Op::Get("relax.hint_on_device"); ObjectPtr attrs = make_object(); - attrs->dev_type = static_cast(device.device_type); - attrs->dev_id = device.device_id; + attrs->device_type = static_cast(device.device_type); + attrs->index = device.device_id; return Call(op, {data}, Attrs(attrs), {}); } diff --git a/src/relax/transform/realize_vdevice.cc b/src/relax/transform/realize_vdevice.cc index 96885eb255ca..1034c2640f2a 100644 --- a/src/relax/transform/realize_vdevice.cc +++ b/src/relax/transform/realize_vdevice.cc @@ -54,8 +54,8 @@ class VDeviceLookup { VDevice operator()(Attrs hint_on_device_attrs) { auto attrs = hint_on_device_attrs.as(); ICHECK(attrs); - int32_t device_type = attrs->dev_type; - int32_t device_id = attrs->dev_id; + int32_t device_type = attrs->device_type; + int32_t device_id = attrs->index; CHECK(opt_vdevices_.defined()) << "ValueError: The target VDevice in the GlobalInfos was not found."; diff --git a/src/runtime/vm/builtin.cc b/src/runtime/vm/builtin.cc index 9427d6805db5..bb07cbe44255 100644 --- a/src/runtime/vm/builtin.cc +++ b/src/runtime/vm/builtin.cc @@ -511,7 +511,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ String debug_func_name = args[1].cast(); const auto debug_func = tvm::ffi::Function::GetGlobal(debug_func_name); CHECK(debug_func.has_value()) << "ValueError: " << debug_func_name << " is not found. 
" - << "Use the decorator `@tvm.register_func(\"" + << "Use the decorator `@tvm.register_global_func(\"" << debug_func_name << "\")` to register it."; String line_info = args[2].cast(); std::vector call_args(num_args + 1); diff --git a/src/target/target_kind.cc b/src/target/target_kind.cc index 4d61c035fbe5..e284a75fefc3 100644 --- a/src/target/target_kind.cc +++ b/src/target/target_kind.cc @@ -436,8 +436,6 @@ TVM_REGISTER_TARGET_KIND("hexagon", kDLHexagon) TVM_REGISTER_TARGET_KIND("ext_dev", kDLExtDev); -TVM_REGISTER_TARGET_KIND("hybrid", kDLCPU); - TVM_REGISTER_TARGET_KIND("composite", kDLCPU) // line break .add_attr_option>("devices"); diff --git a/tests/python/all-platform-minimal-test/test_runtime_packed_func.py b/tests/python/all-platform-minimal-test/test_runtime_packed_func.py index 404ca5d1d94d..7b868007a6b0 100644 --- a/tests/python/all-platform-minimal-test/test_runtime_packed_func.py +++ b/tests/python/all-platform-minimal-test/test_runtime_packed_func.py @@ -29,7 +29,7 @@ def test_get_global(): targs = (10, 10.0, "hello") # register into global function table - @tvm.register_func + @tvm.register_global_func def my_packed_func(*args): assert tuple(args) == targs return 10 @@ -50,7 +50,7 @@ def test(y): f2 = tvm.runtime.convert(test) # register into global function table - @tvm.register_func + @tvm.register_global_func def my_callback_with_node(y, f): assert y == x return f(y) @@ -112,7 +112,7 @@ def test_device_func(dev): x = test_device_func(tvm.cuda(7)) assert x == tvm.cpu(0) x = tvm.opencl(10) - x = tvm.testing.device_test(x, x.device_type, x.device_id) + x = tvm.testing.device_test(x, x.dlpack_device_type(), x.index) assert x == tvm.opencl(10) @@ -123,7 +123,6 @@ def test_numpy_scalar(): def test_tensor_args(): def check(arr): - assert not arr.is_view assert tvm.testing.object_use_count(arr) == 2 fcheck = tvm.runtime.convert(check) diff --git a/tests/python/codegen/test_gpu_codegen_allreduce.py b/tests/python/codegen/test_gpu_codegen_allreduce.py index fe6a9179f41c..09c9fa13386e 100644 --- a/tests/python/codegen/test_gpu_codegen_allreduce.py +++ b/tests/python/codegen/test_gpu_codegen_allreduce.py @@ -94,7 +94,7 @@ def optional_metal_compile_callback(define_metal_compile_callback): if define_metal_compile_callback: - @tvm.register_func(name, override=True) + @tvm.register_global_func(name, override=True) def compile_metal(src, target): return tvm.contrib.xcode.compile_metal(src, sdk="macosx") @@ -104,7 +104,7 @@ def compile_metal(src, target): if cached is None: tvm_ffi.registry.remove_global_func(name) else: - tvm.register_func(name, cached, override=True) + tvm.register_global_func(name, cached, override=True) @tvm.testing.requires_metal(support_required="compile-only") diff --git a/tests/python/codegen/test_target_codegen_extern.py b/tests/python/codegen/test_target_codegen_extern.py index f02a717747b4..06e0926005bf 100644 --- a/tests/python/codegen/test_target_codegen_extern.py +++ b/tests/python/codegen/test_target_codegen_extern.py @@ -97,7 +97,7 @@ def extern_generator(ins, outs): # Create IRModule directly mod = tvm.IRModule.from_expr(te.create_prim_func([A, C])) - @tvm.register_func + @tvm.register_global_func def my_extern_array_func1(aa, bb): aa.copyto(bb) @@ -143,7 +143,7 @@ def check_target(target): a = tvm.runtime.tensor(np.random.uniform(size=n).astype(A.dtype), dev) c = tvm.runtime.tensor(np.zeros(n, dtype=C.dtype), dev) - @tvm.register_func + @tvm.register_global_func def my_extern_array_func2(aa, bb): assert aa.shape == a.shape 
tvm.testing.assert_allclose(aa.numpy(), a.numpy() + 1) diff --git a/tests/python/codegen/test_target_codegen_metal.py b/tests/python/codegen/test_target_codegen_metal.py index 8f50ec829843..e938eb64d5a1 100644 --- a/tests/python/codegen/test_target_codegen_metal.py +++ b/tests/python/codegen/test_target_codegen_metal.py @@ -180,7 +180,7 @@ def func(A: T.Buffer((16), "float32"), B: T.Buffer((16), "float32"), x: T.float3 vi = T.axis.spatial(16, i) B[vi] = A[vi] + x - @tvm.register_func("tvm_callback_metal_compile") + @tvm.register_global_func("tvm_callback_metal_compile") def compile_metal(src, target): return xcode.compile_metal(src) diff --git a/tests/python/codegen/test_target_codegen_static_init.py b/tests/python/codegen/test_target_codegen_static_init.py index ad3863abd13d..30161913360a 100644 --- a/tests/python/codegen/test_target_codegen_static_init.py +++ b/tests/python/codegen/test_target_codegen_static_init.py @@ -51,7 +51,7 @@ def test_static_init(): handle = tvm.tir.call_intrin("handle", "tir.tvm_static_handle") ib.emit(tvm.tir.call_packed("test_static_callback", handle, Ab)) - @tvm.register_func("test_static_callback") + @tvm.register_global_func("test_static_callback") def test_cb(sh, A): assert isinstance(sh, ctypes.c_void_p) return sh diff --git a/tests/python/contrib/test_dlpack.py b/tests/python/contrib/test_dlpack.py index 20992048b208..f0632f3ac7db 100644 --- a/tests/python/contrib/test_dlpack.py +++ b/tests/python/contrib/test_dlpack.py @@ -24,7 +24,7 @@ def verify_torch_dlpack(): a = np.random.randn(1337) tvm_a = tvm.runtime.tensor(a) - np.testing.assert_equal(tvm.runtime.from_dlpack(tvm_a.to_dlpack()).numpy(), a) + np.testing.assert_equal(tvm.runtime.from_dlpack(tvm_a).numpy(), a) try: import torch @@ -35,9 +35,7 @@ def verify_torch_dlpack(): np.testing.assert_equal(x.numpy(), tvm_x.numpy()) y = tvm.runtime.from_dlpack(tvm_x) np.testing.assert_equal(y.numpy(), tvm_x.numpy()) - np.testing.assert_equal( - torch.utils.dlpack.from_dlpack(y.to_dlpack()).numpy(), tvm_x.numpy() - ) + np.testing.assert_equal(torch.utils.dlpack.from_dlpack(y).numpy(), tvm_x.numpy()) n = tvm.runtime.convert(137) xx = torch.rand(137, 137) diff --git a/tests/python/contrib/test_rpc_tracker.py b/tests/python/contrib/test_rpc_tracker.py index f6918db4e286..8dbc1c700412 100644 --- a/tests/python/contrib/test_rpc_tracker.py +++ b/tests/python/contrib/test_rpc_tracker.py @@ -31,7 +31,7 @@ def check_server_drop(): # pylint: disable=import-outside-toplevel from tvm.rpc.base import TrackerCode - @tvm.register_func("rpc.test2.addone") + @tvm.register_global_func("rpc.test2.addone") def addone(x): return x + 1 diff --git a/tests/python/disco/test_loader.py b/tests/python/disco/test_loader.py index b41ff526f083..a68f53917603 100644 --- a/tests/python/disco/test_loader.py +++ b/tests/python/disco/test_loader.py @@ -25,7 +25,7 @@ import tvm.testing from tvm import dlight as dl from tvm import relax as rx -from tvm_ffi import register_func +from tvm_ffi import register_global_func from tvm.contrib import tvmjs from tvm.runtime import ShapeTuple from tvm.runtime import disco as di @@ -35,19 +35,19 @@ from tvm.contrib import tvmjs -@register_func("tests.disco.shard_dim_0", override=True) +@register_global_func("tests.disco.shard_dim_0", override=True) def _shard_dim_0(src, num_shards, tgt): s_0, s_1 = src.shape tgt.copyfrom(src.numpy().reshape(num_shards, s_0 // num_shards, s_1)) -@register_func("tests.disco.shard_dim_1", override=True) +@register_global_func("tests.disco.shard_dim_1", override=True) def 
_shard_dim_1(src, num_shards, tgt): s_0, s_1 = src.shape tgt.copyfrom(src.numpy().reshape(s_0, num_shards, s_1 // num_shards).transpose(1, 0, 2)) -@register_func("tests.disco.shard_qkv_0", override=True) +@register_global_func("tests.disco.shard_qkv_0", override=True) def _shard_qkv_0(src, num_shards, q_heads, kv_heads, tgt): total_dim, hidden_size = src.shape head_dim = total_dim // (q_heads + kv_heads + kv_heads) @@ -75,7 +75,7 @@ def _shard_qkv_0(src, num_shards, q_heads, kv_heads, tgt): tgt.copyfrom(w_qkv) -@register_func("tests.disco.shard_qkv_1", override=True) +@register_global_func("tests.disco.shard_qkv_1", override=True) def _shard_qkv_1(src, tgt): s, _, _, h = src.shape # pylint: disable=invalid-name tgt.copyfrom(src.numpy().reshape(s, -1, h)) diff --git a/tests/python/ir/test_node_reflection.py b/tests/python/ir/test_node_reflection.py index 2db0359b6d3a..52b2a29f59c0 100644 --- a/tests/python/ir/test_node_reflection.py +++ b/tests/python/ir/test_node_reflection.py @@ -94,7 +94,7 @@ def test_make_sum(): def test_env_func(): - @tvm.register_func("test.env_func") + @tvm.register_global_func("test.env_func") def test(x): return x + 1 diff --git a/tests/python/meta_schedule/test_meta_schedule_builder.py b/tests/python/meta_schedule/test_meta_schedule_builder.py index a21d5a91959f..6da0a089180c 100644 --- a/tests/python/meta_schedule/test_meta_schedule_builder.py +++ b/tests/python/meta_schedule/test_meta_schedule_builder.py @@ -25,7 +25,7 @@ import tvm.testing from tvm import script -from tvm_ffi import register_func +from tvm_ffi import register_global_func from tvm.meta_schedule.builder import ( BuilderInput, BuilderResult, @@ -163,7 +163,7 @@ def test_meta_schedule_error_handle_build_func(): """Test the error handing during building""" def initializer(): - @register_func("meta_schedule.builder.test_build") + @register_global_func("meta_schedule.builder.test_build") def test_build(mod: Module, target: Target, _) -> None: # pylint: disable=unused-variable raise ValueError("Builder intended Test Error (build func).") @@ -182,7 +182,7 @@ def test_meta_schedule_error_handle_export_func(): """Test the error handing during building""" def initializer(): - @register_func("meta_schedule.builder.test_export") + @register_global_func("meta_schedule.builder.test_export") def test_build(mod: Module) -> str: # pylint: disable=unused-variable raise ValueError("Builder intended Test Error (export func).") @@ -201,7 +201,7 @@ def test_meta_schedule_error_handle_time_out(): """Test the error handing time out during building""" def initializer(): - @register_func("meta_schedule.builder.test_time_out") + @register_global_func("meta_schedule.builder.test_time_out") def timeout_build(mod, target, _): # pylint: disable=unused-argument, unused-variable time.sleep(2) diff --git a/tests/python/meta_schedule/test_meta_schedule_post_order_apply.py b/tests/python/meta_schedule/test_meta_schedule_post_order_apply.py index cbf2530eeffc..61888ed1a70e 100644 --- a/tests/python/meta_schedule/test_meta_schedule_post_order_apply.py +++ b/tests/python/meta_schedule/test_meta_schedule_post_order_apply.py @@ -25,7 +25,7 @@ import tvm.testing from tvm import te from tvm.ir.module import IRModule -from tvm_ffi import register_func +from tvm_ffi import register_global_func from tvm.error import TVMError from tvm.meta_schedule import TuneContext from tvm.meta_schedule.schedule_rule import PyScheduleRule diff --git a/tests/python/meta_schedule/test_meta_schedule_runner.py 
b/tests/python/meta_schedule/test_meta_schedule_runner.py index 0d6a1e1e7fe2..5b4f6944df91 100644 --- a/tests/python/meta_schedule/test_meta_schedule_runner.py +++ b/tests/python/meta_schedule/test_meta_schedule_runner.py @@ -25,7 +25,7 @@ import pytest import tvm import tvm.testing -from tvm_ffi import register_func +from tvm_ffi import register_global_func from tvm.meta_schedule.arg_info import TensorInfo from tvm.meta_schedule.builder import BuilderInput, LocalBuilder from tvm.meta_schedule.runner import ( @@ -454,7 +454,7 @@ def test_meta_schedule_local_runner_time_out(): ) def initializer(): - @register_func("meta_schedule.runner.test_time_out") + @register_global_func("meta_schedule.runner.test_time_out") def timeout_session_creator( # pylint: disable=unused-variable device: Device, # pylint: disable=unused-argument args_info: T_ARG_INFO_JSON_OBJ_LIST, # pylint: disable=unused-argument @@ -492,7 +492,7 @@ def test_meta_schedule_rpc_runner_exception(): """Test meta schedule RPC Runner exception""" def initializer(): - @register_func("meta_schedule.runner.test_exception") + @register_global_func("meta_schedule.runner.test_exception") def exception_session_creator( # pylint: disable=unused-variable rpc_config: RPCConfig, # pylint: disable=unused-argument ) -> RPCSession: @@ -556,7 +556,7 @@ def test_meta_schedule_local_runner_exception(): ) def initializer(): - @register_func("meta_schedule.runner.test_exception") + @register_global_func("meta_schedule.runner.test_exception") def timeout_session_creator( # pylint: disable=unused-variable device: Device, # pylint: disable=unused-argument args_info: T_ARG_INFO_JSON_OBJ_LIST, # pylint: disable=unused-argument diff --git a/tests/python/meta_schedule/test_meta_schedule_schedule_rule_apply_custom_rule.py b/tests/python/meta_schedule/test_meta_schedule_schedule_rule_apply_custom_rule.py index 7222c4d64972..332bebd79d31 100644 --- a/tests/python/meta_schedule/test_meta_schedule_schedule_rule_apply_custom_rule.py +++ b/tests/python/meta_schedule/test_meta_schedule_schedule_rule_apply_custom_rule.py @@ -42,7 +42,7 @@ def main(a: T.handle, b: T.handle, c: T.handle) -> None: C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] -@tvm.register_func("meta_schedule.cpu.test_apply_custom_rule") +@tvm.register_global_func("meta_schedule.cpu.test_apply_custom_rule") def sch_fn(sch: tvm.tir.Schedule, block: tvm.tir.Block) -> List[tvm.tir.Schedule]: raise ValueError("Intended for meta_schedule.cpu.test_apply_custom_rule") diff --git a/tests/python/relax/test_blockbuilder_core.py b/tests/python/relax/test_blockbuilder_core.py index be60524e8475..56372a63e576 100644 --- a/tests/python/relax/test_blockbuilder_core.py +++ b/tests/python/relax/test_blockbuilder_core.py @@ -31,7 +31,7 @@ @pytest.fixture(scope="module") def register_nop(): - @tvm.register_func("test.blockbuilder.nop") + @tvm.register_global_func("test.blockbuilder.nop") def nop(): pass diff --git a/tests/python/relax/test_frontend_nn_debug.py b/tests/python/relax/test_frontend_nn_debug.py index c1372adff10e..f3ead2e9c011 100644 --- a/tests/python/relax/test_frontend_nn_debug.py +++ b/tests/python/relax/test_frontend_nn_debug.py @@ -43,7 +43,7 @@ def forward(self, x: nn.Tensor): # pylint: disable=invalid-name def test_debug_func(): - @tvm.register_func("testing.relax.frontend.nn.test_debug_func") + @tvm.register_global_func("testing.relax.frontend.nn.test_debug_func") def _debug( # pylint: disable=too-many-arguments lineno: str, tensor: Tensor, diff --git a/tests/python/relax/test_frontend_nn_op.py 
b/tests/python/relax/test_frontend_nn_op.py index 9e0369318841..e827f643b33c 100644 --- a/tests/python/relax/test_frontend_nn_op.py +++ b/tests/python/relax/test_frontend_nn_op.py @@ -899,7 +899,7 @@ def test(q: R.Tensor((1, 1, 16, 8), dtype="float32"), k: R.Tensor((64, 16, 8), d def test_empty(): - @tvm.register_func("test_empty_assert", override=True) + @tvm.register_global_func("test_empty_assert", override=True) def test_empty_assert(_lineo, x): assert x.shape == (10, 10) assert x.dtype == "float32" diff --git a/tests/python/relax/test_op_misc.py b/tests/python/relax/test_op_misc.py index d424ab69decc..9d05690f38b1 100644 --- a/tests/python/relax/test_op_misc.py +++ b/tests/python/relax/test_op_misc.py @@ -21,7 +21,7 @@ from tvm.script import tir as T -@tvm.register_func("test.op.identity", override=True) +@tvm.register_global_func("test.op.identity", override=True) def identity_packed(a): return tvm.runtime.tensor(a.numpy()) diff --git a/tests/python/relax/test_relax_operators.py b/tests/python/relax/test_relax_operators.py index 221d7d1270a5..8558f6e911b8 100644 --- a/tests/python/relax/test_relax_operators.py +++ b/tests/python/relax/test_relax_operators.py @@ -338,7 +338,7 @@ def pure_copy(x: R.Tensor((3, 4), "float32")): ) return z - @tvm.register_func("test.inplace.add", override=True) + @tvm.register_global_func("test.inplace.add", override=True) def inplace_add(a, b): arr_a = a.numpy() arr_b = b.numpy() @@ -372,7 +372,7 @@ def inplace_add(x: R.Tensor((3, 4), "float32"), y: R.Tensor((3, 4), "float32")): assert result == tvm_arr_a assert (result.numpy() == sum).all() - @tvm.register_func("test.inplace.tuple_add", override=True) + @tvm.register_global_func("test.inplace.tuple_add", override=True) def inplace_tuple_add(a, b): arr_a = a.numpy() arr_b = b.numpy() diff --git a/tests/python/relax/test_runtime_builtin.py b/tests/python/relax/test_runtime_builtin.py index a3003459f89d..e243770ed6e1 100644 --- a/tests/python/relax/test_runtime_builtin.py +++ b/tests/python/relax/test_runtime_builtin.py @@ -179,7 +179,7 @@ def test_tensor_cache(): temp = utils.tempdir() tvmjs.dump_tensor_cache(param_dict, temp.path, encode_format="f32-to-bf16") - fload(str(temp.path), tvm.cpu().device_type, 0) + fload(str(temp.path), tvm.cpu().dlpack_device_type(), 0) res = fget_params("x", -1) for i, v in enumerate(res): v_np = param_dict[f"x_{i}"] @@ -204,7 +204,7 @@ def test_tensor_cache_update(): tvmjs.dump_tensor_cache( param_dict, temp.path, encode_format="f32-to-bf16", update_if_exists=True ) - fload(str(temp.path), tvm.cpu().device_type, 0) + fload(str(temp.path), tvm.cpu().dlpack_device_type(), 0) res = fget_params("x", -1) for i, v in enumerate(res): v_np = param_dict[f"x_{i}"] diff --git a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_mla_flashinfer.py b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_mla_flashinfer.py index 8253c379951a..e3de4944fef9 100644 --- a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_mla_flashinfer.py +++ b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_mla_flashinfer.py @@ -84,7 +84,7 @@ # Register a dumb function for testing purpose. 
-@tvm.register_func("test.dumb_function", override=True) +@tvm.register_global_func("test.dumb_function", override=True) def _dumb_function(): raise RuntimeError("Dumb function isn't supposed to be accessed.") diff --git a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_mla_tir.py b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_mla_tir.py index cc4ffb1d525b..efc0a5694ca6 100644 --- a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_mla_tir.py +++ b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_mla_tir.py @@ -79,7 +79,7 @@ # Register a dumb function for testing purpose. -@tvm.register_func("test.dumb_function", override=True) +@tvm.register_global_func("test.dumb_function", override=True) def _dumb_function(): raise RuntimeError("Dumb function isn't supposed to be accessed.") diff --git a/tests/python/relax/test_transform_lazy_transform_params.py b/tests/python/relax/test_transform_lazy_transform_params.py index 696499121072..ae0521a0e2f8 100644 --- a/tests/python/relax/test_transform_lazy_transform_params.py +++ b/tests/python/relax/test_transform_lazy_transform_params.py @@ -660,11 +660,11 @@ def transform_params( transformed = {} expected = [params[0].transpose(1, 0, 2, 3), params[1]] - @tvm.register_func("get_item", override=True) + @tvm.register_global_func("get_item", override=True) def get_item(i): return tvm.runtime.tensor(params[i], dev) - @tvm.register_func("set_item", override=True) + @tvm.register_global_func("set_item", override=True) def set_item(i, value): assert i not in transformed, f"Set item called multiple times for index {i}" transformed[i] = value.numpy() diff --git a/tests/python/relax/test_vm_codegen_only.py b/tests/python/relax/test_vm_codegen_only.py index 044ba97cbfe4..9633244c67fb 100644 --- a/tests/python/relax/test_vm_codegen_only.py +++ b/tests/python/relax/test_vm_codegen_only.py @@ -79,8 +79,8 @@ def foo(x: R.Tensor((3, 4), "float32")): tvm.testing.assert_allclose(res.numpy(), inp.numpy(), rtol=1e-7, atol=1e-7) # check the resulting tensor is on cpu:0 assert res.device == tvm.cpu(0) - assert res.device.device_type == 1 - assert res.device.device_id == 0 + assert res.device.dlpack_device_type() == 1 + assert res.device.index == 0 @pytest.mark.parametrize("exec_mode", EXEC_MODE) diff --git a/tests/python/relax/test_vm_cuda_graph.py b/tests/python/relax/test_vm_cuda_graph.py index 728eb584ec24..d04fd6bdab1b 100644 --- a/tests/python/relax/test_vm_cuda_graph.py +++ b/tests/python/relax/test_vm_cuda_graph.py @@ -129,7 +129,7 @@ def test_capture_error_is_recoverable(): target = tvm.target.Target("cuda") dev = tvm.cuda() - @tvm.register_func("test_vm_cuda_graph.invalid_impl_for_cudagraph", override=True) + @tvm.register_global_func("test_vm_cuda_graph.invalid_impl_for_cudagraph", override=True) def invalid_impl_for_cudagraph(arg_tensor): # Memory allocation/deallocation may not be performed while # capturing a cudaGraph. 
This passes the warm-up run
diff --git a/tests/python/runtime/test_runtime_measure.py b/tests/python/runtime/test_runtime_measure.py
index fe01e5d331a6..41271b1ba312 100644
--- a/tests/python/runtime/test_runtime_measure.py
+++ b/tests/python/runtime/test_runtime_measure.py
@@ -27,7 +27,7 @@ def test_min_repeat_ms():
     tmp = tempdir()
     filename = tmp.relpath("log")

-    @tvm.register_func
+    @tvm.register_global_func
     def my_debug(filename):
         """one call lasts for 100 ms and writes one character to a file"""
         time.sleep(0.1)
diff --git a/tests/python/runtime/test_runtime_rpc.py b/tests/python/runtime/test_runtime_rpc.py
index 796e886e7bce..627ebbb7d62c 100644
--- a/tests/python/runtime/test_runtime_rpc.py
+++ b/tests/python/runtime/test_runtime_rpc.py
@@ -53,7 +53,7 @@
     # Windows does not support fork so we can enable Windows for testing
     sys.platform.startswith("win") == False and multiprocessing.get_start_method() != "fork",
     reason=(
-        "pytest + multiprocessing spawn method causes tvm.register_func to "
+        "pytest + multiprocessing spawn method causes tvm.register_global_func to "
         "not work on the rpc.Server."
     ),
 )
diff --git a/tests/python/runtime/test_runtime_trace.py b/tests/python/runtime/test_runtime_trace.py
index 263652bb695c..146db5a06535 100644
--- a/tests/python/runtime/test_runtime_trace.py
+++ b/tests/python/runtime/test_runtime_trace.py
@@ -30,7 +30,7 @@ def test_trace_default_action():


 def test_trace_expr_assign():
-    @tvm.register_func("tvm.tir.trace_callback2")
+    @tvm.register_global_func("tvm.tir.trace_callback2")
     def trace_buffer(x):
         return

@@ -59,7 +59,7 @@ def check_assign(dtype):


 def test_trace_expr_sum_generated():
-    @tvm.register_func("tvm.tir.trace_callback3")
+    @tvm.register_global_func("tvm.tir.trace_callback3")
     def trace_buffer(x):
         return

@@ -84,7 +84,7 @@ def check_expr_sum(dtype):


 def test_trace_expr_sum_args():
-    @tvm.register_func("tvm.tir.trace_silent")
+    @tvm.register_global_func("tvm.tir.trace_silent")
     def silent(*args):
         return

@@ -118,7 +118,7 @@ def check_expr_sum(dtype):


 def test_trace_expr_sum_custom():
-    @tvm.register_func("tvm.tir.trace_callback4")
+    @tvm.register_global_func("tvm.tir.trace_callback4")
     def trace_buffer(x):
         return

@@ -145,11 +145,11 @@ def check_expr_sum_custom(dtype):


 def test_trace_can_change_traced_value_int():
-    @tvm.register_func("tvm.tir.trace_change_int_first")
+    @tvm.register_global_func("tvm.tir.trace_change_int_first")
     def trace_buffer(x):
         return 13

-    @tvm.register_func("tvm.tir.trace_change_int_second")
+    @tvm.register_global_func("tvm.tir.trace_change_int_second")
     def trace_buffer(x):
         return 14

@@ -174,11 +174,11 @@ def check_assign(dtype):


 def test_trace_can_change_traced_value_float():
-    @tvm.register_func("tvm.tir.trace_change_float_first")
+    @tvm.register_global_func("tvm.tir.trace_change_float_first")
     def trace_buffer(x):
         return 13.0

-    @tvm.register_func("tvm.tir.trace_change_float_second")
+    @tvm.register_global_func("tvm.tir.trace_change_float_second")
     def trace_buffer(x):
         return 14.0
diff --git a/tests/python/target/test_target_target.py b/tests/python/target/test_target_target.py
index 4906b219c359..8aa314bd6293 100644
--- a/tests/python/target/test_target_target.py
+++ b/tests/python/target/test_target_target.py
@@ -24,15 +24,19 @@ def test_all_targets_device_type_verify():
     """Consistency verification for all targets' device type"""
-    all_targets = [tvm.target.Target(t) for t in tvm.target.Target.list_kinds()]
+    target_kind_set = set(tvm.target.Target.list_kinds())
+    target_kind_set.remove("composite")
+    all_targets = [tvm.target.Target(t) for t in target_kind_set]

     for tgt in all_targets:
-        if tgt.kind.name not in tvm.runtime.Device.DEVICE_NAME_TO_TYPE:
+        if tgt.kind.name not in tvm.runtime.Device._DEVICE_NAME_TO_TYPE:
             raise KeyError(
-                "Cannot find target kind: %s in Device.DEVICE_NAME_TO_TYPE" % tgt.kind.name
+                "Cannot find target kind: %s in Device._DEVICE_NAME_TO_TYPE" % tgt.kind.name
             )

-        assert tgt.get_target_device_type() == tvm.runtime.Device.DEVICE_NAME_TO_TYPE[tgt.kind.name]
+        assert (
+            tgt.get_target_device_type() == tvm.runtime.Device._DEVICE_NAME_TO_TYPE[tgt.kind.name]
+        )


 def test_target_string_parse():
@@ -347,7 +351,7 @@ def test_canon_multi_target_and_host_5():

 def test_canon_multi_target_and_host_6():
     """Test `canon_target_and_host` by using TVM Objects"""
-    cuda_device_type = tvm.device("cuda").device_type
+    cuda_device_type = tvm.device("cuda").dlpack_device_type()
     target = {cuda_device_type: Target(target="cuda", host="llvm")}
     host = None
     raw_targets_1 = Target.canon_multi_target_and_host(target, host)
diff --git a/tests/python/target/test_virtual_device.py b/tests/python/target/test_virtual_device.py
index a6434480fa83..4441bab128b8 100644
--- a/tests/python/target/test_virtual_device.py
+++ b/tests/python/target/test_virtual_device.py
@@ -21,7 +21,7 @@ def test_make_virtual_device_for_device():
     virtual_device = tvm.target.VirtualDevice(tvm.device("cuda"))
-    assert virtual_device.device_type == 2
+    assert virtual_device.dlpack_device_type() == 2  # ie kDLCUDA
     assert virtual_device.virtual_device_id == 0
     assert virtual_device.target is None
@@ -31,7 +31,7 @@ def test_make_virtual_device_for_device():
 def test_make_virtual_device_for_device_and_target():
     target = tvm.target.Target("cuda")
     virtual_device = tvm.target.VirtualDevice(tvm.device("cuda"), target)
-    assert virtual_device.device_type == 2  # ie kDLCUDA
+    assert virtual_device.dlpack_device_type() == 2  # ie kDLCUDA
     assert virtual_device.target == target
     assert virtual_device.memory_scope == ""
@@ -40,7 +40,7 @@ def test_make_virtual_device_for_device_target_and_memory_scope():
     target = tvm.target.Target("cuda")
     scope = "local"
     virtual_device = tvm.target.VirtualDevice(tvm.device("cuda"), target, scope)
-    assert virtual_device.device_type == 2  # ie kDLCUDA
+    assert virtual_device.dlpack_device_type() == 2  # ie kDLCUDA
     assert virtual_device.target == target
     assert virtual_device.memory_scope == scope
diff --git a/tests/python/testing/test_tvm_testing_features.py b/tests/python/testing/test_tvm_testing_features.py
index 6d394ebeb649..9618113ae3a9 100644
--- a/tests/python/testing/test_tvm_testing_features.py
+++ b/tests/python/testing/test_tvm_testing_features.py
@@ -49,7 +49,7 @@ def test_all_targets_used(self):
         assert sorted(self.targets_used) == sorted(self.enabled_targets)

     def test_all_devices_used(self):
-        sort_key = lambda dev: (dev.device_type, dev.device_id)
+        sort_key = lambda dev: (dev.dlpack_device_type(), dev.index)
         assert sorted(self.devices_used, key=sort_key) == sorted(self.enabled_devices, key=sort_key)

     targets_with_explicit_list = []
diff --git a/tests/python/tir-base/test_tir_structural_equal_hash.py b/tests/python/tir-base/test_tir_structural_equal_hash.py
index 559d705b6267..01af60724cbb 100644
--- a/tests/python/tir-base/test_tir_structural_equal_hash.py
+++ b/tests/python/tir-base/test_tir_structural_equal_hash.py
@@ -182,7 +182,7 @@ def test_array():


 def test_env_func():
-    @tvm.register_func("test.sequal.env_func")
+    @tvm.register_global_func("test.sequal.env_func")
     def test(x):
         return x + 1
diff --git a/tests/python/tir-transform/test_tir_transform_inject_ptx_async_copy.py b/tests/python/tir-transform/test_tir_transform_inject_ptx_async_copy.py
index 840c83452ed5..67598b0ba04f 100644
--- a/tests/python/tir-transform/test_tir_transform_inject_ptx_async_copy.py
+++ b/tests/python/tir-transform/test_tir_transform_inject_ptx_async_copy.py
@@ -402,7 +402,7 @@ def get_original_code():
         nonlocal original_code
         return original_code

-    @tvm.register_func(func_name, override=True)
+    @tvm.register_global_func(func_name, override=True)
     def tvm_callback_cuda_postproc(code, _):
         nonlocal original_code
         original_code = code
@@ -424,7 +424,7 @@ def tvm_callback_cuda_postproc(code, _):
     if prev_postproc is None:
         tvm_ffi.registry.remove_global_func(func_name)
     else:
-        tvm.register_func(func_name, prev_postproc, override=True)
+        tvm.register_global_func(func_name, prev_postproc, override=True)


 @tvm.testing.requires_cuda
diff --git a/tests/python/tir-transform/test_tir_transform_lower_device_storage_access_info.py b/tests/python/tir-transform/test_tir_transform_lower_device_storage_access_info.py
index 5006efba50b2..e8fee40ec173 100644
--- a/tests/python/tir-transform/test_tir_transform_lower_device_storage_access_info.py
+++ b/tests/python/tir-transform/test_tir_transform_lower_device_storage_access_info.py
@@ -19,7 +19,7 @@
 from tvm.script import tir as T


-@tvm.register_func("tvm.info.mem.global.test_with_head_address")
+@tvm.register_global_func("tvm.info.mem.global.test_with_head_address")
 def mem_info_with_head_address():
     return tvm.ir.make_node(
         "target.MemoryInfo",
@@ -30,7 +30,7 @@ def mem_info_with_head_address():
     )


-@tvm.register_func("tvm.info.mem.global.test_without_head_address")
+@tvm.register_global_func("tvm.info.mem.global.test_without_head_address")
 def mem_info_without_head_address():
     return tvm.ir.make_node(
         "target.MemoryInfo",
diff --git a/tests/python/tir-transform/test_tir_transform_lower_tvm_builtin.py b/tests/python/tir-transform/test_tir_transform_lower_tvm_builtin.py
index 0f71b78f0ca1..180f76a67ecd 100644
--- a/tests/python/tir-transform/test_tir_transform_lower_tvm_builtin.py
+++ b/tests/python/tir-transform/test_tir_transform_lower_tvm_builtin.py
@@ -26,7 +26,7 @@
 from tvm.tir.schedule.testing import assert_structural_equal_ignore_global_symbol


-@tvm.register_func("tvm.test_matmul")
+@tvm.register_global_func("tvm.test_matmul")
 def my_matmul(a, b, c):
     c.copyfrom(np.dot(a.numpy(), b.numpy()))
diff --git a/tests/python/tir-transform/test_tir_transform_make_unpacked_api.py b/tests/python/tir-transform/test_tir_transform_make_unpacked_api.py
index 46fd4104544a..617d028c1332 100644
--- a/tests/python/tir-transform/test_tir_transform_make_unpacked_api.py
+++ b/tests/python/tir-transform/test_tir_transform_make_unpacked_api.py
@@ -68,7 +68,7 @@ def test_device_setup(mod, target, dev):
     assert f.body.value == 0
     assert f.body.body.node == "default"
     assert f.body.body.attr_key == "device_type"
-    assert f.body.body.value == dev.device_type
+    assert f.body.body.value == dev.dlpack_device_type()


 def test_no_buffers_no_device_setup():
diff --git a/tests/python/tir-transform/test_tir_transform_storage_rewrite.py b/tests/python/tir-transform/test_tir_transform_storage_rewrite.py
index e8d21a8dc4f9..36500c4d9885 100644
--- a/tests/python/tir-transform/test_tir_transform_storage_rewrite.py
+++ b/tests/python/tir-transform/test_tir_transform_storage_rewrite.py
@@ -26,7 +26,7 @@ def register_mem(scope_tb, max_bits):
     # Register mem
-    @tvm.register_func("tvm.info.mem.%s" % scope_tb)
+    @tvm.register_global_func("tvm.info.mem.%s" % scope_tb)
     def mem_info_inp_buffer():
         return tvm.ir.make_node(
             "target.MemoryInfo",