From db609d7b46920ab1fbc627e4094282b8b10e37e9 Mon Sep 17 00:00:00 2001 From: tqchen Date: Sat, 13 Sep 2025 20:38:43 -0400 Subject: [PATCH] [REFACTOR][FFI] Split tvm-ffi into a separate repo This PR updates the code so we split tvm-ffi into a separate repo --- .github/actions/setup/action.yml | 2 +- .gitmodules | 3 + 3rdparty/tvm-ffi | 1 + CMakeLists.txt | 2 +- apps/android_rpc/app/src/main/jni/Android.mk | 4 +- .../app/src/main/jni/tvm_runtime.h | 24 +- apps/ios_rpc/tvmrpc/TVMRuntime.mm | 2 +- docs/install/from_source.rst | 2 +- ffi/.clang-format | 8 - ffi/CMakeLists.txt | 262 --- ffi/README.md | 18 - ffi/cmake/Utils/AddGoogleTest.cmake | 56 - ffi/cmake/Utils/AddLibbacktrace.cmake | 68 - ffi/cmake/Utils/CxxWarning.cmake | 30 - ffi/cmake/Utils/Library.cmake | 88 - ffi/cmake/Utils/Sanitizer.cmake | 35 - ffi/cmake/tvm_ffi-config.cmake | 58 - ffi/docs/.gitignore | 2 - ffi/docs/Makefile | 41 - ffi/docs/README.md | 46 - ffi/docs/concepts/abi_overview.md | 430 ---- ffi/docs/conf.py | 228 --- ffi/docs/get_started/install.md | 83 - ffi/docs/get_started/quick_start.md | 213 -- ffi/docs/guides/cpp_guide.md | 584 ------ ffi/docs/guides/packaging.md | 282 --- ffi/docs/guides/python_guide.md | 242 --- ffi/docs/index.rst | 53 - ffi/docs/reference/cpp/index.rst | 107 - ffi/docs/reference/python/index.rst | 69 - ffi/docs/requirements.txt | 21 - ffi/examples/inline_module/main.py | 87 - ffi/examples/packaging/CMakeLists.txt | 73 - ffi/examples/packaging/README.md | 61 - ffi/examples/packaging/pyproject.toml | 58 - .../python/my_ffi_extension/__init__.py | 48 - .../python/my_ffi_extension/_ffi_api.py | 24 - .../packaging/python/my_ffi_extension/base.py | 37 - ffi/examples/packaging/run_example.py | 40 - ffi/examples/packaging/src/extension.cc | 89 - ffi/examples/quick_start/CMakeLists.txt | 65 - ffi/examples/quick_start/README.md | 58 - ffi/examples/quick_start/run_example.py | 82 - ffi/examples/quick_start/run_example.sh | 27 - ffi/examples/quick_start/src/add_one_cpu.cc | 41 - ffi/examples/quick_start/src/add_one_cuda.cu | 58 - ffi/examples/quick_start/src/run_example.cc | 53 - ffi/include/tvm/ffi/any.h | 692 ------- ffi/include/tvm/ffi/base_details.h | 297 --- ffi/include/tvm/ffi/c_api.h | 1097 ---------- ffi/include/tvm/ffi/cast.h | 79 - ffi/include/tvm/ffi/container/array.h | 1147 ----------- .../tvm/ffi/container/container_details.h | 356 ---- ffi/include/tvm/ffi/container/map.h | 1762 ----------------- ffi/include/tvm/ffi/container/shape.h | 247 --- ffi/include/tvm/ffi/container/tensor.h | 468 ----- ffi/include/tvm/ffi/container/tuple.h | 317 --- ffi/include/tvm/ffi/container/variant.h | 302 --- ffi/include/tvm/ffi/dtype.h | 192 -- ffi/include/tvm/ffi/endian.h | 89 - ffi/include/tvm/ffi/error.h | 335 ---- ffi/include/tvm/ffi/extra/base.h | 48 - ffi/include/tvm/ffi/extra/base64.h | 142 -- ffi/include/tvm/ffi/extra/c_env_api.h | 142 -- ffi/include/tvm/ffi/extra/json.h | 84 - ffi/include/tvm/ffi/extra/module.h | 262 --- ffi/include/tvm/ffi/extra/serialization.h | 72 - ffi/include/tvm/ffi/extra/structural_equal.h | 78 - ffi/include/tvm/ffi/extra/structural_hash.h | 57 - ffi/include/tvm/ffi/function.h | 880 -------- ffi/include/tvm/ffi/function_details.h | 210 -- ffi/include/tvm/ffi/memory.h | 229 --- ffi/include/tvm/ffi/object.h | 1142 ----------- ffi/include/tvm/ffi/optional.h | 419 ---- ffi/include/tvm/ffi/reflection/access_path.h | 440 ---- ffi/include/tvm/ffi/reflection/accessor.h | 260 --- ffi/include/tvm/ffi/reflection/creator.h | 120 -- ffi/include/tvm/ffi/reflection/registry.h | 564 ------ ffi/include/tvm/ffi/rvalue_ref.h | 155 -- ffi/include/tvm/ffi/string.h | 1014 ---------- ffi/include/tvm/ffi/type_traits.h | 781 -------- ffi/licenses/LICENSE.dlpack.txt | 201 -- ffi/licenses/LICENSE.libbacktrace.txt | 29 - ffi/licenses/LICENSE.pytorch.txt | 84 - ffi/licenses/NOTICE.pytorch.txt | 456 ----- ffi/pyproject.toml | 159 -- ffi/python/tvm_ffi/.gitignore | 2 - ffi/python/tvm_ffi/__init__.py | 73 - ffi/python/tvm_ffi/_convert.py | 65 - ffi/python/tvm_ffi/_dtype.py | 141 -- ffi/python/tvm_ffi/_ffi_api.py | 20 - .../tvm_ffi/_optional_torch_c_dlpack.py | 404 ---- ffi/python/tvm_ffi/_tensor.py | 88 - ffi/python/tvm_ffi/access_path.py | 181 -- ffi/python/tvm_ffi/base.py | 53 - ffi/python/tvm_ffi/config.py | 92 - ffi/python/tvm_ffi/container.py | 252 --- ffi/python/tvm_ffi/cpp/__init__.py | 18 - ffi/python/tvm_ffi/cpp/load_inline.py | 437 ---- ffi/python/tvm_ffi/cython/base.pxi | 393 ---- ffi/python/tvm_ffi/cython/core.pyx | 26 - ffi/python/tvm_ffi/cython/device.pxi | 191 -- ffi/python/tvm_ffi/cython/dtype.pxi | 116 -- ffi/python/tvm_ffi/cython/error.pxi | 134 -- ffi/python/tvm_ffi/cython/function.pxi | 853 -------- ffi/python/tvm_ffi/cython/object.pxi | 295 --- ffi/python/tvm_ffi/cython/string.pxi | 80 - ffi/python/tvm_ffi/cython/tensor.pxi | 362 ---- .../tvm_ffi/cython/tvm_ffi_python_helpers.h | 580 ------ ffi/python/tvm_ffi/error.py | 193 -- ffi/python/tvm_ffi/libinfo.py | 167 -- ffi/python/tvm_ffi/module.py | 275 --- ffi/python/tvm_ffi/registry.py | 226 --- ffi/python/tvm_ffi/serialization.py | 67 - ffi/python/tvm_ffi/testing.py | 63 - ffi/python/tvm_ffi/utils/__init__.py | 18 - ffi/python/tvm_ffi/utils/lockfile.py | 113 -- ffi/scripts/benchmark_dlpack.py | 448 ----- ffi/scripts/run_tests.sh | 27 - ffi/src/ffi/container.cc | 88 - ffi/src/ffi/dtype.cc | 328 --- ffi/src/ffi/error.cc | 81 - ffi/src/ffi/extra/buffer_stream.h | 127 -- ffi/src/ffi/extra/env_c_api.cc | 148 -- ffi/src/ffi/extra/env_context.cc | 120 -- ffi/src/ffi/extra/json_parser.cc | 731 ------- ffi/src/ffi/extra/json_writer.cc | 307 --- ffi/src/ffi/extra/library_module.cc | 199 -- .../ffi/extra/library_module_dynamic_lib.cc | 118 -- .../ffi/extra/library_module_system_lib.cc | 143 -- ffi/src/ffi/extra/module.cc | 157 -- ffi/src/ffi/extra/module_internal.h | 114 -- ffi/src/ffi/extra/reflection_extra.cc | 144 -- ffi/src/ffi/extra/serialization.cc | 430 ---- ffi/src/ffi/extra/structural_equal.cc | 439 ---- ffi/src/ffi/extra/structural_hash.cc | 317 --- ffi/src/ffi/extra/testing.cc | 133 -- ffi/src/ffi/function.cc | 229 --- ffi/src/ffi/object.cc | 513 ----- ffi/src/ffi/tensor.cc | 82 - ffi/src/ffi/traceback.cc | 188 -- ffi/src/ffi/traceback.h | 182 -- ffi/src/ffi/traceback_win.cc | 142 -- ffi/tests/cpp/CMakeLists.txt | 33 - ffi/tests/cpp/extra/test_json_parser.cc | 394 ---- ffi/tests/cpp/extra/test_json_writer.cc | 241 --- ffi/tests/cpp/extra/test_serialization.cc | 372 ---- .../cpp/extra/test_structural_equal_hash.cc | 178 -- ffi/tests/cpp/test_any.cc | 415 ---- ffi/tests/cpp/test_array.cc | 286 --- ffi/tests/cpp/test_c_ffi_abi.cc | 31 - ffi/tests/cpp/test_dtype.cc | 130 -- ffi/tests/cpp/test_error.cc | 70 - ffi/tests/cpp/test_example.cc | 288 --- ffi/tests/cpp/test_function.cc | 239 --- ffi/tests/cpp/test_map.cc | 366 ---- ffi/tests/cpp/test_object.cc | 258 --- ffi/tests/cpp/test_optional.cc | 202 -- ffi/tests/cpp/test_reflection.cc | 269 --- ffi/tests/cpp/test_rvalue_ref.cc | 97 - ffi/tests/cpp/test_shape.cc | 72 - ffi/tests/cpp/test_string.cc | 430 ---- ffi/tests/cpp/test_tensor.cc | 164 -- ffi/tests/cpp/test_tuple.cc | 168 -- ffi/tests/cpp/test_variant.cc | 164 -- ffi/tests/cpp/testing_object.h | 296 --- ffi/tests/python/test_access_path.py | 133 -- ffi/tests/python/test_container.py | 124 -- ffi/tests/python/test_device.py | 94 - ffi/tests/python/test_dtype.py | 85 - ffi/tests/python/test_error.py | 113 -- ffi/tests/python/test_examples.py | 47 - ffi/tests/python/test_function.py | 221 --- ffi/tests/python/test_load_inline.py | 324 --- ffi/tests/python/test_object.py | 91 - ffi/tests/python/test_string.py | 54 - ffi/tests/python/test_tensor.py | 68 - jvm/native/linux-x86_64/pom.xml | 2 +- jvm/native/osx-x86_64/pom.xml | 2 +- pyproject.toml | 2 +- python/tvm/libinfo.py | 7 +- python/tvm/relax/frontend/nn/extern.py | 18 +- tests/lint/cpplint.sh | 1 - tests/scripts/task_python_adreno.sh | 2 +- .../task_python_arm_compute_library.sh | 2 +- tests/scripts/task_python_docs.sh | 4 +- tests/scripts/task_python_hexagon.sh | 2 +- tests/scripts/task_python_integration.sh | 2 +- tests/scripts/task_python_nightly.sh | 2 +- tests/scripts/task_python_unittest.sh | 2 +- tests/scripts/task_web_wasm.sh | 2 +- tests/scripts/unity/task_python_relax.sh | 2 +- web/Makefile | 4 +- web/emcc/wasm_runtime.cc | 22 +- 194 files changed, 63 insertions(+), 37818 deletions(-) create mode 160000 3rdparty/tvm-ffi delete mode 100644 ffi/.clang-format delete mode 100644 ffi/CMakeLists.txt delete mode 100644 ffi/README.md delete mode 100644 ffi/cmake/Utils/AddGoogleTest.cmake delete mode 100644 ffi/cmake/Utils/AddLibbacktrace.cmake delete mode 100644 ffi/cmake/Utils/CxxWarning.cmake delete mode 100644 ffi/cmake/Utils/Library.cmake delete mode 100644 ffi/cmake/Utils/Sanitizer.cmake delete mode 100644 ffi/cmake/tvm_ffi-config.cmake delete mode 100644 ffi/docs/.gitignore delete mode 100644 ffi/docs/Makefile delete mode 100644 ffi/docs/README.md delete mode 100644 ffi/docs/concepts/abi_overview.md delete mode 100644 ffi/docs/conf.py delete mode 100644 ffi/docs/get_started/install.md delete mode 100644 ffi/docs/get_started/quick_start.md delete mode 100644 ffi/docs/guides/cpp_guide.md delete mode 100644 ffi/docs/guides/packaging.md delete mode 100644 ffi/docs/guides/python_guide.md delete mode 100644 ffi/docs/index.rst delete mode 100644 ffi/docs/reference/cpp/index.rst delete mode 100644 ffi/docs/reference/python/index.rst delete mode 100644 ffi/docs/requirements.txt delete mode 100644 ffi/examples/inline_module/main.py delete mode 100644 ffi/examples/packaging/CMakeLists.txt delete mode 100644 ffi/examples/packaging/README.md delete mode 100644 ffi/examples/packaging/pyproject.toml delete mode 100644 ffi/examples/packaging/python/my_ffi_extension/__init__.py delete mode 100644 ffi/examples/packaging/python/my_ffi_extension/_ffi_api.py delete mode 100644 ffi/examples/packaging/python/my_ffi_extension/base.py delete mode 100644 ffi/examples/packaging/run_example.py delete mode 100644 ffi/examples/packaging/src/extension.cc delete mode 100644 ffi/examples/quick_start/CMakeLists.txt delete mode 100644 ffi/examples/quick_start/README.md delete mode 100644 ffi/examples/quick_start/run_example.py delete mode 100755 ffi/examples/quick_start/run_example.sh delete mode 100644 ffi/examples/quick_start/src/add_one_cpu.cc delete mode 100644 ffi/examples/quick_start/src/add_one_cuda.cu delete mode 100644 ffi/examples/quick_start/src/run_example.cc delete mode 100644 ffi/include/tvm/ffi/any.h delete mode 100644 ffi/include/tvm/ffi/base_details.h delete mode 100644 ffi/include/tvm/ffi/c_api.h delete mode 100644 ffi/include/tvm/ffi/cast.h delete mode 100644 ffi/include/tvm/ffi/container/array.h delete mode 100644 ffi/include/tvm/ffi/container/container_details.h delete mode 100644 ffi/include/tvm/ffi/container/map.h delete mode 100644 ffi/include/tvm/ffi/container/shape.h delete mode 100644 ffi/include/tvm/ffi/container/tensor.h delete mode 100644 ffi/include/tvm/ffi/container/tuple.h delete mode 100644 ffi/include/tvm/ffi/container/variant.h delete mode 100644 ffi/include/tvm/ffi/dtype.h delete mode 100644 ffi/include/tvm/ffi/endian.h delete mode 100644 ffi/include/tvm/ffi/error.h delete mode 100644 ffi/include/tvm/ffi/extra/base.h delete mode 100644 ffi/include/tvm/ffi/extra/base64.h delete mode 100644 ffi/include/tvm/ffi/extra/c_env_api.h delete mode 100644 ffi/include/tvm/ffi/extra/json.h delete mode 100644 ffi/include/tvm/ffi/extra/module.h delete mode 100644 ffi/include/tvm/ffi/extra/serialization.h delete mode 100644 ffi/include/tvm/ffi/extra/structural_equal.h delete mode 100644 ffi/include/tvm/ffi/extra/structural_hash.h delete mode 100644 ffi/include/tvm/ffi/function.h delete mode 100644 ffi/include/tvm/ffi/function_details.h delete mode 100644 ffi/include/tvm/ffi/memory.h delete mode 100644 ffi/include/tvm/ffi/object.h delete mode 100644 ffi/include/tvm/ffi/optional.h delete mode 100644 ffi/include/tvm/ffi/reflection/access_path.h delete mode 100644 ffi/include/tvm/ffi/reflection/accessor.h delete mode 100644 ffi/include/tvm/ffi/reflection/creator.h delete mode 100644 ffi/include/tvm/ffi/reflection/registry.h delete mode 100644 ffi/include/tvm/ffi/rvalue_ref.h delete mode 100644 ffi/include/tvm/ffi/string.h delete mode 100644 ffi/include/tvm/ffi/type_traits.h delete mode 100644 ffi/licenses/LICENSE.dlpack.txt delete mode 100644 ffi/licenses/LICENSE.libbacktrace.txt delete mode 100644 ffi/licenses/LICENSE.pytorch.txt delete mode 100644 ffi/licenses/NOTICE.pytorch.txt delete mode 100644 ffi/pyproject.toml delete mode 100644 ffi/python/tvm_ffi/.gitignore delete mode 100644 ffi/python/tvm_ffi/__init__.py delete mode 100644 ffi/python/tvm_ffi/_convert.py delete mode 100644 ffi/python/tvm_ffi/_dtype.py delete mode 100644 ffi/python/tvm_ffi/_ffi_api.py delete mode 100644 ffi/python/tvm_ffi/_optional_torch_c_dlpack.py delete mode 100644 ffi/python/tvm_ffi/_tensor.py delete mode 100644 ffi/python/tvm_ffi/access_path.py delete mode 100644 ffi/python/tvm_ffi/base.py delete mode 100644 ffi/python/tvm_ffi/config.py delete mode 100644 ffi/python/tvm_ffi/container.py delete mode 100644 ffi/python/tvm_ffi/cpp/__init__.py delete mode 100644 ffi/python/tvm_ffi/cpp/load_inline.py delete mode 100644 ffi/python/tvm_ffi/cython/base.pxi delete mode 100644 ffi/python/tvm_ffi/cython/core.pyx delete mode 100644 ffi/python/tvm_ffi/cython/device.pxi delete mode 100644 ffi/python/tvm_ffi/cython/dtype.pxi delete mode 100644 ffi/python/tvm_ffi/cython/error.pxi delete mode 100644 ffi/python/tvm_ffi/cython/function.pxi delete mode 100644 ffi/python/tvm_ffi/cython/object.pxi delete mode 100644 ffi/python/tvm_ffi/cython/string.pxi delete mode 100644 ffi/python/tvm_ffi/cython/tensor.pxi delete mode 100644 ffi/python/tvm_ffi/cython/tvm_ffi_python_helpers.h delete mode 100644 ffi/python/tvm_ffi/error.py delete mode 100644 ffi/python/tvm_ffi/libinfo.py delete mode 100644 ffi/python/tvm_ffi/module.py delete mode 100644 ffi/python/tvm_ffi/registry.py delete mode 100644 ffi/python/tvm_ffi/serialization.py delete mode 100644 ffi/python/tvm_ffi/testing.py delete mode 100644 ffi/python/tvm_ffi/utils/__init__.py delete mode 100644 ffi/python/tvm_ffi/utils/lockfile.py delete mode 100644 ffi/scripts/benchmark_dlpack.py delete mode 100755 ffi/scripts/run_tests.sh delete mode 100644 ffi/src/ffi/container.cc delete mode 100644 ffi/src/ffi/dtype.cc delete mode 100644 ffi/src/ffi/error.cc delete mode 100644 ffi/src/ffi/extra/buffer_stream.h delete mode 100644 ffi/src/ffi/extra/env_c_api.cc delete mode 100644 ffi/src/ffi/extra/env_context.cc delete mode 100644 ffi/src/ffi/extra/json_parser.cc delete mode 100644 ffi/src/ffi/extra/json_writer.cc delete mode 100644 ffi/src/ffi/extra/library_module.cc delete mode 100644 ffi/src/ffi/extra/library_module_dynamic_lib.cc delete mode 100644 ffi/src/ffi/extra/library_module_system_lib.cc delete mode 100644 ffi/src/ffi/extra/module.cc delete mode 100644 ffi/src/ffi/extra/module_internal.h delete mode 100644 ffi/src/ffi/extra/reflection_extra.cc delete mode 100644 ffi/src/ffi/extra/serialization.cc delete mode 100644 ffi/src/ffi/extra/structural_equal.cc delete mode 100644 ffi/src/ffi/extra/structural_hash.cc delete mode 100644 ffi/src/ffi/extra/testing.cc delete mode 100644 ffi/src/ffi/function.cc delete mode 100644 ffi/src/ffi/object.cc delete mode 100644 ffi/src/ffi/tensor.cc delete mode 100644 ffi/src/ffi/traceback.cc delete mode 100644 ffi/src/ffi/traceback.h delete mode 100644 ffi/src/ffi/traceback_win.cc delete mode 100644 ffi/tests/cpp/CMakeLists.txt delete mode 100644 ffi/tests/cpp/extra/test_json_parser.cc delete mode 100644 ffi/tests/cpp/extra/test_json_writer.cc delete mode 100644 ffi/tests/cpp/extra/test_serialization.cc delete mode 100644 ffi/tests/cpp/extra/test_structural_equal_hash.cc delete mode 100644 ffi/tests/cpp/test_any.cc delete mode 100644 ffi/tests/cpp/test_array.cc delete mode 100644 ffi/tests/cpp/test_c_ffi_abi.cc delete mode 100644 ffi/tests/cpp/test_dtype.cc delete mode 100644 ffi/tests/cpp/test_error.cc delete mode 100644 ffi/tests/cpp/test_example.cc delete mode 100644 ffi/tests/cpp/test_function.cc delete mode 100644 ffi/tests/cpp/test_map.cc delete mode 100644 ffi/tests/cpp/test_object.cc delete mode 100644 ffi/tests/cpp/test_optional.cc delete mode 100644 ffi/tests/cpp/test_reflection.cc delete mode 100644 ffi/tests/cpp/test_rvalue_ref.cc delete mode 100644 ffi/tests/cpp/test_shape.cc delete mode 100644 ffi/tests/cpp/test_string.cc delete mode 100644 ffi/tests/cpp/test_tensor.cc delete mode 100644 ffi/tests/cpp/test_tuple.cc delete mode 100644 ffi/tests/cpp/test_variant.cc delete mode 100644 ffi/tests/cpp/testing_object.h delete mode 100644 ffi/tests/python/test_access_path.py delete mode 100644 ffi/tests/python/test_container.py delete mode 100644 ffi/tests/python/test_device.py delete mode 100644 ffi/tests/python/test_dtype.py delete mode 100644 ffi/tests/python/test_error.py delete mode 100644 ffi/tests/python/test_examples.py delete mode 100644 ffi/tests/python/test_function.py delete mode 100644 ffi/tests/python/test_load_inline.py delete mode 100644 ffi/tests/python/test_object.py delete mode 100644 ffi/tests/python/test_string.py delete mode 100644 ffi/tests/python/test_tensor.py diff --git a/.github/actions/setup/action.yml b/.github/actions/setup/action.yml index 88b388817913..77271319b252 100644 --- a/.github/actions/setup/action.yml +++ b/.github/actions/setup/action.yml @@ -39,4 +39,4 @@ runs: - name: Install tvm-ffi pip package shell: bash -l {0} run: | - pip install -v ./ffi + pip install -v ./3rdparty/tvm-ffi diff --git a/.gitmodules b/.gitmodules index 32a70d37ae21..6b14c3524f7e 100644 --- a/.gitmodules +++ b/.gitmodules @@ -28,3 +28,6 @@ [submodule "ffi/3rdparty/dlpack"] path = ffi/3rdparty/dlpack url = https://github.com/dmlc/dlpack.git +[submodule "3rdparty/tvm-ffi"] + path = 3rdparty/tvm-ffi + url = https://github.com/apache/tvm-ffi diff --git a/3rdparty/tvm-ffi b/3rdparty/tvm-ffi new file mode 160000 index 000000000000..3e07df45afbc --- /dev/null +++ b/3rdparty/tvm-ffi @@ -0,0 +1 @@ +Subproject commit 3e07df45afbc8ea968ef03c34d84dc348ba6dfb0 diff --git a/CMakeLists.txt b/CMakeLists.txt index b05e5e165765..5e5a61490d8d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -567,7 +567,7 @@ if(USE_IOS_RPC) add_subdirectory("apps/ios_rpc") endif() -add_subdirectory(ffi) +add_subdirectory(3rdparty/tvm-ffi) if(TVM_DEBUG_WITH_ABI_CHANGE) message(STATUS "Building with debug code that may cause ABI changes...") diff --git a/apps/android_rpc/app/src/main/jni/Android.mk b/apps/android_rpc/app/src/main/jni/Android.mk index 692a3390131d..d482f9429559 100644 --- a/apps/android_rpc/app/src/main/jni/Android.mk +++ b/apps/android_rpc/app/src/main/jni/Android.mk @@ -37,8 +37,8 @@ LOCAL_SRC_FILES := org_apache_tvm_native_c_api.cc LOCAL_LDFLAGS := -L$(SYSROOT)/usr/lib/ -llog LOCAL_C_INCLUDES := $(ROOT_PATH)/include \ - $(ROOT_PATH)/ffi/include \ - $(ROOT_PATH)/ffi/3rdparty/dlpack/include \ + $(ROOT_PATH)/3rdparty/tvm-ffi/include \ + $(ROOT_PATH)/3rdparty/tvm-ffi/3rdparty/dlpack/include \ $(ROOT_PATH)/3rdparty/dmlc-core/include \ $(ROOT_PATH)/3rdparty/OpenCL-Headers diff --git a/apps/android_rpc/app/src/main/jni/tvm_runtime.h b/apps/android_rpc/app/src/main/jni/tvm_runtime.h index b0cb033e8812..6bda78cef0db 100644 --- a/apps/android_rpc/app/src/main/jni/tvm_runtime.h +++ b/apps/android_rpc/app/src/main/jni/tvm_runtime.h @@ -34,18 +34,18 @@ #define TVM_LOG_CUSTOMIZE 1 #define TVM_FFI_USE_LIBBACKTRACE 0 -#include "../ffi/src/ffi/container.cc" -#include "../ffi/src/ffi/dtype.cc" -#include "../ffi/src/ffi/error.cc" -#include "../ffi/src/ffi/extra/library_module.cc" -#include "../ffi/src/ffi/extra/library_module_dynamic_lib.cc" -#include "../ffi/src/ffi/extra/library_module_system_lib.cc" -#include "../ffi/src/ffi/extra/module.cc" -#include "../ffi/src/ffi/extra/testing.cc" -#include "../ffi/src/ffi/function.cc" -#include "../ffi/src/ffi/object.cc" -#include "../ffi/src/ffi/tensor.cc" -#include "../ffi/src/ffi/traceback.cc" +#include "../3rdparty/tvm-ffi/src/ffi/container.cc" +#include "../3rdparty/tvm-ffi/src/ffi/dtype.cc" +#include "../3rdparty/tvm-ffi/src/ffi/error.cc" +#include "../3rdparty/tvm-ffi/src/ffi/extra/library_module.cc" +#include "../3rdparty/tvm-ffi/src/ffi/extra/library_module_dynamic_lib.cc" +#include "../3rdparty/tvm-ffi/src/ffi/extra/library_module_system_lib.cc" +#include "../3rdparty/tvm-ffi/src/ffi/extra/module.cc" +#include "../3rdparty/tvm-ffi/src/ffi/extra/testing.cc" +#include "../3rdparty/tvm-ffi/src/ffi/function.cc" +#include "../3rdparty/tvm-ffi/src/ffi/object.cc" +#include "../3rdparty/tvm-ffi/src/ffi/tensor.cc" +#include "../3rdparty/tvm-ffi/src/ffi/traceback.cc" #include "../src/runtime/cpu_device_api.cc" #include "../src/runtime/device_api.cc" #include "../src/runtime/file_utils.cc" diff --git a/apps/ios_rpc/tvmrpc/TVMRuntime.mm b/apps/ios_rpc/tvmrpc/TVMRuntime.mm index 8831210242bd..5dfff0cd86b4 100644 --- a/apps/ios_rpc/tvmrpc/TVMRuntime.mm +++ b/apps/ios_rpc/tvmrpc/TVMRuntime.mm @@ -33,7 +33,7 @@ #if defined(USE_CUSTOM_DSO_LOADER) && USE_CUSTOM_DSO_LOADER == 1 // internal TVM header to achieve Library class -#include <../../../ffi/src/ffi/extra/library_module.h> +#include <../../../3rdparty/tvm-ffi/src/ffi/extra/library_module.h> #include #endif diff --git a/docs/install/from_source.rst b/docs/install/from_source.rst index 2fc3a9e88b05..ee81f8477835 100644 --- a/docs/install/from_source.rst +++ b/docs/install/from_source.rst @@ -135,7 +135,7 @@ Therefore, after we finish the build, we need to install the tvm-ffi package. .. code-block:: bash - cd ffi; pip install .; cd .. + cd 3rdparty/tvm-ffi; pip install .; cd .. Leaving the build environment ``tvm-build-venv``, there are two ways to install the successful build into your environment: diff --git a/ffi/.clang-format b/ffi/.clang-format deleted file mode 100644 index 9d622b98ba06..000000000000 --- a/ffi/.clang-format +++ /dev/null @@ -1,8 +0,0 @@ -# Run the following command to reformat a file: -# clang-format -i -style=Google -# Or use clang-format-diff to only reformat the changed lines: -# https://clang.llvm.org/docs/ClangFormat.html -BasedOnStyle: Google -DerivePointerAlignment: false -ColumnLimit: 100 -PointerAlignment: Left diff --git a/ffi/CMakeLists.txt b/ffi/CMakeLists.txt deleted file mode 100644 index 2767669bce24..000000000000 --- a/ffi/CMakeLists.txt +++ /dev/null @@ -1,262 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -cmake_minimum_required(VERSION 3.18) - -project( - tvm_ffi - LANGUAGES CXX C -) - -option(TVM_FFI_USE_LIBBACKTRACE "Enable libbacktrace" ON) -option(TVM_FFI_USE_EXTRA_CXX_API "Enable extra CXX API in shared lib" ON) -option(TVM_FFI_BACKTRACE_ON_SEGFAULT "Set signal handler to print traceback on segfault" ON) - -if (TVM_FFI_USE_LIBBACKTRACE) - include(${CMAKE_CURRENT_LIST_DIR}/cmake/Utils/AddLibbacktrace.cmake) -endif() - -include(${CMAKE_CURRENT_LIST_DIR}/cmake/Utils/Library.cmake) - - -########## Target: `tvm_ffi_header` ########## - -# they can be used in cases where user do not want to link into the library -# in cases like deferred linking -add_library(tvm_ffi_header INTERFACE) -target_compile_features(tvm_ffi_header INTERFACE cxx_std_17) -target_include_directories( - tvm_ffi_header INTERFACE - $ - $ -) -target_include_directories( - tvm_ffi_header INTERFACE - $ - $ -) - -########## Target: `tvm_ffi_objs` ########## - -set(tvm_ffi_objs_sources - "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/traceback.cc" - "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/traceback_win.cc" - "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/object.cc" - "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/error.cc" - "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/function.cc" - "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/tensor.cc" - "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/dtype.cc" - "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/container.cc" -) - -set(tvm_ffi_extra_objs_sources - "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/extra/structural_equal.cc" - "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/extra/structural_hash.cc" - "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/extra/json_parser.cc" - "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/extra/json_writer.cc" - "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/extra/serialization.cc" - "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/extra/reflection_extra.cc" - "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/extra/module.cc" - "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/extra/library_module.cc" - "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/extra/library_module_system_lib.cc" - "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/extra/library_module_dynamic_lib.cc" - "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/extra/env_context.cc" - "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/extra/env_c_api.cc" - "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/extra/testing.cc" -) -if (TVM_FFI_USE_EXTRA_CXX_API) - list(APPEND tvm_ffi_objs_sources ${tvm_ffi_extra_objs_sources}) -endif() - -add_library(tvm_ffi_objs OBJECT ${tvm_ffi_objs_sources}) -target_compile_features(tvm_ffi_objs PRIVATE cxx_std_17) - -set_target_properties( - tvm_ffi_objs PROPERTIES - POSITION_INDEPENDENT_CODE ON - CXX_EXTENSIONS OFF - CXX_STANDARD_REQUIRED ON - CXX_VISIBILITY_PRESET hidden - VISIBILITY_INLINES_HIDDEN ON - PREFIX "lib" -) - -# add the include path as public so they are visible to downstreams -target_link_libraries(tvm_ffi_objs PUBLIC tvm_ffi_header) - -if (TVM_FFI_USE_LIBBACKTRACE) - message(STATUS "Setting C++ macro TVM_FFI_USE_LIBBACKTRACE - 1") - target_compile_definitions(tvm_ffi_objs PRIVATE TVM_FFI_USE_LIBBACKTRACE=1) -else() - message(STATUS "Setting C++ macro TVM_FFI_USE_LIBBACKTRACE - 0") - target_compile_definitions(tvm_ffi_objs PRIVATE TVM_FFI_USE_LIBBACKTRACE=0) -endif() - -if (TVM_FFI_BACKTRACE_ON_SEGFAULT) - message(STATUS "Setting C++ macro TVM_FFI_BACKTRACE_ON_SEGFAULT - 1") - target_compile_definitions(tvm_ffi_objs PRIVATE TVM_FFI_BACKTRACE_ON_SEGFAULT=1) -else() - message(STATUS "Setting C++ macro TVM_FFI_BACKTRACE_ON_SEGFAULT - 0") - target_compile_definitions(tvm_ffi_objs PRIVATE TVM_FFI_BACKTRACE_ON_SEGFAULT=0) -endif() - -tvm_ffi_add_msvc_flags(tvm_ffi_objs) -tvm_ffi_add_target_from_obj(tvm_ffi tvm_ffi_objs) - -if (TARGET libbacktrace) - target_link_libraries(tvm_ffi_objs PRIVATE libbacktrace) - target_link_libraries(tvm_ffi_shared PRIVATE libbacktrace) - target_link_libraries(tvm_ffi_static PRIVATE libbacktrace) -endif () - -if (MSVC) - target_link_libraries(tvm_ffi_objs PRIVATE DbgHelp.lib) - target_link_libraries(tvm_ffi_shared PRIVATE DbgHelp.lib) - target_link_libraries(tvm_ffi_static PRIVATE DbgHelp.lib) - # produce pdb file - target_link_options(tvm_ffi_shared PRIVATE /DEBUG) -endif () - -# expose the headers as public dependencies -target_link_libraries(tvm_ffi_objs PUBLIC tvm_ffi_header) -target_link_libraries(tvm_ffi_shared PUBLIC tvm_ffi_header) -target_link_libraries(tvm_ffi_static PUBLIC tvm_ffi_header) - -#---------------------------------------------------------------------------- -# The following code section only is triggered when the project is the root -# and will be skipped when the project is a subproject. -#---------------------------------------------------------------------------- -if (NOT ${PROJECT_NAME} STREQUAL ${CMAKE_PROJECT_NAME}) - return() -endif() - -option(TVM_FFI_ATTACH_DEBUG_SYMBOLS "Attach debug symbols even in release mode" OFF) -option(TVM_FFI_BUILD_TESTS "Adding test targets." OFF) - -if (TVM_FFI_ATTACH_DEBUG_SYMBOLS) - if(CMAKE_CXX_COMPILER_ID MATCHES "GNU|Clang") - target_compile_options(tvm_ffi_objs PRIVATE -g1) - endif() -endif() - -include(cmake/Utils/CxxWarning.cmake) -include(cmake/Utils/Sanitizer.cmake) - -# remap the file name to the source directory so we can see the -# exact file name in traceback relative to the project source root -tvm_ffi_add_prefix_map(tvm_ffi_objs ${CMAKE_SOURCE_DIR}) - -########## Adding cpp tests ########## - -# logics below are only executed when the project is the root project. -# but not when the project is a subproject. -if (TVM_FFI_BUILD_TESTS) - enable_testing() - message(STATUS "Enable Testing") - include(cmake/Utils/AddGoogleTest.cmake) - add_subdirectory(tests/cpp/) - tvm_ffi_add_cxx_warning(tvm_ffi_objs) -endif() - -########## Adding python module ########## -option(TVM_FFI_BUILD_PYTHON_MODULE "Adding python module." OFF) - -if (TVM_FFI_BUILD_PYTHON_MODULE) - # Helper function to build the cython module - message(STATUS "Building cython module..") - find_package( - Python COMPONENTS Interpreter Development.Module Development.SABIModule - REQUIRED) - set(core_cpp ${CMAKE_CURRENT_BINARY_DIR}/core.cpp) - set(core_pyx ${CMAKE_CURRENT_SOURCE_DIR}/python/tvm_ffi/cython/core.pyx) - set(cython_sources - ${CMAKE_CURRENT_SOURCE_DIR}/python/tvm_ffi/cython/core.pyx - ${CMAKE_CURRENT_SOURCE_DIR}/python/tvm_ffi/cython/base.pxi - ${CMAKE_CURRENT_SOURCE_DIR}/python/tvm_ffi/cython/device.pxi - ${CMAKE_CURRENT_SOURCE_DIR}/python/tvm_ffi/cython/dtype.pxi - ${CMAKE_CURRENT_SOURCE_DIR}/python/tvm_ffi/cython/error.pxi - ${CMAKE_CURRENT_SOURCE_DIR}/python/tvm_ffi/cython/function.pxi - ${CMAKE_CURRENT_SOURCE_DIR}/python/tvm_ffi/cython/tensor.pxi - ${CMAKE_CURRENT_SOURCE_DIR}/python/tvm_ffi/cython/object.pxi - ${CMAKE_CURRENT_SOURCE_DIR}/python/tvm_ffi/cython/string.pxi - ) - # set working directory to source so we can see the exact file name in traceback - # relatived to the project source root - add_custom_command( - OUTPUT ${core_cpp} - COMMAND ${Python_EXECUTABLE} -m cython --cplus ${core_pyx} -o ${core_cpp} - WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} - COMMENT "Transpiling ${core_pyx} to ${core_cpp}" - DEPENDS ${cython_sources} - VERBATIM - ) - if(Python_VERSION VERSION_GREATER_EQUAL "3.12") - # >= Python3.12, use Use_SABI version - Python_add_library(tvm_ffi_cython MODULE "${core_cpp}" USE_SABI 3.12) - set_target_properties(tvm_ffi_cython PROPERTIES OUTPUT_NAME "core") - if(NOT WIN32) - set_target_properties(tvm_ffi_cython PROPERTIES SUFFIX ".abi3.so") - endif() - else() - # before Python3.12, use WITH_SOABI version - Python_add_library(tvm_ffi_cython MODULE "${core_cpp}" WITH_SOABI) - set_target_properties(tvm_ffi_cython PROPERTIES OUTPUT_NAME "core") - endif() - target_include_directories(tvm_ffi_cython PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/python/tvm_ffi/cython) - target_compile_features(tvm_ffi_cython PRIVATE cxx_std_17) - target_link_libraries(tvm_ffi_cython PRIVATE tvm_ffi_header) - target_link_libraries(tvm_ffi_cython PRIVATE tvm_ffi_shared) - # Set RPATH for tvm_ffi_cython to find tvm_ffi_shared.so relatively - if(APPLE) - # macOS uses @loader_path - set_target_properties(tvm_ffi_cython PROPERTIES INSTALL_RPATH "@loader_path/lib") - elseif(LINUX) - # Linux uses $ORIGIN - set_target_properties(tvm_ffi_cython PROPERTIES INSTALL_RPATH "\$ORIGIN/lib") - endif() - install(TARGETS tvm_ffi_cython DESTINATION .) - - ########## Installing the source ########## - install( - DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/3rdparty/dlpack/include DESTINATION 3rdparty/dlpack/include - ) - install( - DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/3rdparty/libbacktrace DESTINATION 3rdparty/libbacktrace - PATTERN ".git" EXCLUDE - PATTERN ".git*" EXCLUDE - PATTERN "*.tmp" EXCLUDE - ) - install(DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/ DESTINATION src/ffi/) - install(DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/cmake/Utils/ DESTINATION cmake/Utils) - install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/CMakeLists.txt DESTINATION .) - install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/cmake/tvm_ffi-config.cmake DESTINATION cmake) -endif() - -########## Install the related for normal cmake library ########## - -install(DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/include/tvm/ffi/ DESTINATION include/tvm/ffi/) -install(DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/3rdparty/dlpack/include/ DESTINATION include/) -install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/python/tvm_ffi/cython/tvm_ffi_python_helpers.h DESTINATION include/) -install(TARGETS tvm_ffi_shared DESTINATION lib) -# ship additional dSYM files for debugging symbols on if available -if (APPLE) - install(DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/lib/ DESTINATION lib FILES_MATCHING PATTERN "*.dSYM") -endif() - -if (NOT TVM_FFI_BUILD_PYTHON_MODULE) - # when building wheel, we do not ship static as we already ships source and dll - install(TARGETS tvm_ffi_static DESTINATION lib) -endif() diff --git a/ffi/README.md b/ffi/README.md deleted file mode 100644 index 3b1b1199c209..000000000000 --- a/ffi/README.md +++ /dev/null @@ -1,18 +0,0 @@ - - - - - - - - - - - - - - - - - -# tvm ffi diff --git a/ffi/cmake/Utils/AddGoogleTest.cmake b/ffi/cmake/Utils/AddGoogleTest.cmake deleted file mode 100644 index af841752c677..000000000000 --- a/ffi/cmake/Utils/AddGoogleTest.cmake +++ /dev/null @@ -1,56 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -include(FetchContent) -set(gtest_force_shared_crt ON CACHE BOOL "Always use msvcrt.dll" FORCE) -set(BUILD_GMOCK ON CACHE BOOL "" FORCE) -set(BUILD_GTEST ON CACHE BOOL "" FORCE) -FetchContent_Declare( - googletest - GIT_REPOSITORY https://github.com/google/googletest.git - GIT_TAG v1.14.0 -) -FetchContent_GetProperties(googletest) -if (NOT googletest_POPULATED) - FetchContent_MakeAvailable(googletest) - include(GoogleTest) - set_target_properties(gtest PROPERTIES EXPORT_COMPILE_COMMANDS OFF EXCLUDE_FROM_ALL ON FOLDER 3rdparty) - set_target_properties(gtest_main PROPERTIES EXPORT_COMPILE_COMMANDS OFF EXCLUDE_FROM_ALL ON FOLDER 3rdparty) - set_target_properties(gmock PROPERTIES EXPORT_COMPILE_COMMANDS OFF EXCLUDE_FROM_ALL ON FOLDER 3rdparty) - set_target_properties(gmock_main PROPERTIES EXPORT_COMPILE_COMMANDS OFF EXCLUDE_FROM_ALL ON FOLDER 3rdparty) - mark_as_advanced( - BUILD_GMOCK BUILD_GTEST BUILD_SHARED_LIBS - gmock_build_tests gtest_build_samples gtest_build_tests - gtest_disable_pthreads gtest_force_shared_crt gtest_hide_internal_symbols - ) -endif() - -macro(tvm_ffi_add_googletest target_name) - add_test( - NAME ${target_name} - COMMAND ${target_name} - WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} - ) - target_link_libraries(${target_name} PRIVATE gtest_main) - gtest_discover_tests(${target_name} - WORKING_DIRECTORY ${PROJECT_SOURCE_DIR} - DISCOVERY_MODE PRE_TEST - PROPERTIES - VS_DEBUGGER_WORKING_DIRECTORY "${PROJECT_SOURCE_DIR}" - ) - set_target_properties(${target_name} PROPERTIES FOLDER tests) -endmacro() diff --git a/ffi/cmake/Utils/AddLibbacktrace.cmake b/ffi/cmake/Utils/AddLibbacktrace.cmake deleted file mode 100644 index e920a1f1991a..000000000000 --- a/ffi/cmake/Utils/AddLibbacktrace.cmake +++ /dev/null @@ -1,68 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -include(ExternalProject) - -function(_libbacktrace_compile) - set(_libbacktrace_source ${CMAKE_CURRENT_LIST_DIR}/../../3rdparty/libbacktrace) - set(_libbacktrace_prefix ${CMAKE_CURRENT_BINARY_DIR}/libbacktrace) - if(CMAKE_SYSTEM_NAME MATCHES "Darwin" AND (CMAKE_C_COMPILER MATCHES "^/Library" OR CMAKE_C_COMPILER MATCHES "^/Applications")) - set(_cmake_c_compiler "/usr/bin/cc") - else() - set(_cmake_c_compiler "${CMAKE_C_COMPILER}") - endif() - - message(STATUS CMAKC_C_COMPILER="${CMAKE_C_COMPILER}") - - file(MAKE_DIRECTORY ${_libbacktrace_prefix}/include) - file(MAKE_DIRECTORY ${_libbacktrace_prefix}/lib) - - ExternalProject_Add(project_libbacktrace - PREFIX libbacktrace - SOURCE_DIR ${_libbacktrace_source} - BINARY_DIR ${_libbacktrace_prefix} - CONFIGURE_COMMAND - "sh" - "${_libbacktrace_source}/configure" - "--prefix=${_libbacktrace_prefix}" - --with-pic - "CC=${_cmake_c_compiler}" - "CPP=${_cmake_c_compiler} -E" - "CFLAGS=${CMAKE_C_FLAGS}" - "LDFLAGS=${CMAKE_EXE_LINKER_FLAGS}" - "NM=${CMAKE_NM}" - "STRIP=${CMAKE_STRIP}" - "--host=${MACHINE_NAME}" - INSTALL_DIR ${_libbacktrace_prefix} - BUILD_COMMAND make - INSTALL_COMMAND make install - BUILD_BYPRODUCTS "${_libbacktrace_prefix}/lib/libbacktrace.a" - "${_libbacktrace_prefix}/include/backtrace.h" - ) - ExternalProject_Add_Step(project_libbacktrace checkout DEPENDERS configure DEPENDEES download) - set_target_properties(project_libbacktrace PROPERTIES EXCLUDE_FROM_ALL TRUE) - add_library(libbacktrace STATIC IMPORTED) - add_dependencies(libbacktrace project_libbacktrace) - set_target_properties(libbacktrace PROPERTIES - IMPORTED_LOCATION ${_libbacktrace_prefix}/lib/libbacktrace.a - INTERFACE_INCLUDE_DIRECTORIES ${_libbacktrace_prefix}/include - ) -endfunction() - -if(NOT MSVC) - _libbacktrace_compile() -endif() diff --git a/ffi/cmake/Utils/CxxWarning.cmake b/ffi/cmake/Utils/CxxWarning.cmake deleted file mode 100644 index a85e58825b9e..000000000000 --- a/ffi/cmake/Utils/CxxWarning.cmake +++ /dev/null @@ -1,30 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -function(tvm_ffi_add_cxx_warning target_name) - # GNU, Clang, or AppleClang - if(CMAKE_CXX_COMPILER_ID MATCHES "GNU|Clang|AppleClang") - target_compile_options(${target_name} PRIVATE "-Werror" "-Wall" "-Wextra" "-Wpedantic" "-Wno-unused-parameter") - return() - endif() - # MSVC - if(MSVC) - # target_compile_options(${target_name} PRIVATE "/W4" "/WX") - return() - endif() - message(FATAL_ERROR "Unsupported compiler: ${CMAKE_CXX_COMPILER_ID}") -endfunction() diff --git a/ffi/cmake/Utils/Library.cmake b/ffi/cmake/Utils/Library.cmake deleted file mode 100644 index 611f972dcecd..000000000000 --- a/ffi/cmake/Utils/Library.cmake +++ /dev/null @@ -1,88 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -function(tvm_ffi_add_prefix_map target_name prefix_path) - # Add prefix map so the path displayed becomes relative to prefix_path - if(CMAKE_CXX_COMPILER_ID MATCHES "GNU|Clang") - target_compile_options(${target_name} PRIVATE "-ffile-prefix-map=${prefix_path}/=") - endif() -endfunction() - -function(tvm_ffi_add_apple_dsymutil target_name) - # running dsymutil on macos to generate debugging symbols for backtraces - if(APPLE AND TVM_FFI_USE_LIBBACKTRACE) - find_program(DSYMUTIL dsymutil) - mark_as_advanced(DSYMUTIL) - add_custom_command(TARGET ${target_name} - POST_BUILD - COMMAND ${DSYMUTIL} ARGS $ - COMMENT "[COMMAND] dsymutil $" - VERBATIM - ) - endif() -endfunction() - -function(tvm_ffi_add_msvc_flags target_name) - # running if we are under msvc - if(MSVC) - target_compile_definitions(${target_name} PUBLIC -DWIN32_LEAN_AND_MEAN) - target_compile_definitions(${target_name} PUBLIC -D_CRT_SECURE_NO_WARNINGS) - target_compile_definitions(${target_name} PUBLIC -D_SCL_SECURE_NO_WARNINGS) - target_compile_definitions(${target_name} PUBLIC -D_ENABLE_EXTENDED_ALIGNED_STORAGE) - target_compile_definitions(${target_name} PUBLIC -DNOMINMAX) - target_compile_options(${target_name} PRIVATE "/Zi") - endif() -endfunction() - -function(tvm_ffi_add_target_from_obj target_name obj_target_name) - add_library(${target_name}_static STATIC $) - set_target_properties( - ${target_name}_static PROPERTIES - OUTPUT_NAME "${target_name}_static" - ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib" - LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib" - RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib" - ) - add_library(${target_name}_shared SHARED $) - set_target_properties( - ${target_name}_shared PROPERTIES - OUTPUT_NAME "${target_name}" - ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib" - LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib" - RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib" - ) - if (WIN32) - target_compile_definitions(${obj_target_name} PRIVATE TVM_FFI_EXPORTS) - # set the output directory for each config type so msbuild also get into lib - # without appending the config type to the output directory - # do both Release and RELEASE suffix, since while cmake docs suggest Release is ok. - # real runs on MSbuild suggest that we might need RELEASE instead - foreach(CONFIG_TYPE Release RELEASE) - set_target_properties(${target_name}_shared PROPERTIES - RUNTIME_OUTPUT_DIRECTORY_${CONFIG_TYPE} "${CMAKE_BINARY_DIR}/lib" - LIBRARY_OUTPUT_DIRECTORY_${CONFIG_TYPE} "${CMAKE_BINARY_DIR}/lib" - ARCHIVE_OUTPUT_DIRECTORY_${CONFIG_TYPE} "${CMAKE_BINARY_DIR}/lib" - ) - set_target_properties(${target_name}_static PROPERTIES - RUNTIME_OUTPUT_DIRECTORY_${CONFIG_TYPE} "${CMAKE_BINARY_DIR}/lib" - LIBRARY_OUTPUT_DIRECTORY_${CONFIG_TYPE} "${CMAKE_BINARY_DIR}/lib" - ARCHIVE_OUTPUT_DIRECTORY_${CONFIG_TYPE} "${CMAKE_BINARY_DIR}/lib" - ) - endforeach() - endif() - tvm_ffi_add_apple_dsymutil(${target_name}_shared) -endfunction() diff --git a/ffi/cmake/Utils/Sanitizer.cmake b/ffi/cmake/Utils/Sanitizer.cmake deleted file mode 100644 index a20eead0c869..000000000000 --- a/ffi/cmake/Utils/Sanitizer.cmake +++ /dev/null @@ -1,35 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -function(add_sanitizer_address target_name) - if(CMAKE_CXX_COMPILER_ID MATCHES "GNU|Clang|AppleClang") - include(CheckCXXCompilerFlag) - set (_saved_CRF ${CMAKE_REQUIRED_FLAGS}) - set(CMAKE_REQUIRED_FLAGS "-fsanitize=address") - check_cxx_source_compiles("int main() { return 0; }" COMPILER_SUPPORTS_ASAN) - set (CMAKE_REQUIRED_FLAGS ${_saved_CRF}) - get_target_property(_saved_type ${target_name} TYPE) - if (${_saved_type} STREQUAL "INTERFACE_LIBRARY") - set(_saved_type INTERFACE) - else() - set(_saved_type PRIVATE) - endif() - target_link_options(${target_name} ${_saved_type} "-fsanitize=address") - target_compile_options(${target_name} ${_saved_type} "-fsanitize=address" "-fno-omit-frame-pointer" "-g") - return() - endif() -endfunction() diff --git a/ffi/cmake/tvm_ffi-config.cmake b/ffi/cmake/tvm_ffi-config.cmake deleted file mode 100644 index 01f60ca10bff..000000000000 --- a/ffi/cmake/tvm_ffi-config.cmake +++ /dev/null @@ -1,58 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -find_package(Python COMPONENTS Interpreter REQUIRED) - -# call tvm_ffi.config to get the cmake directory and set it to tvm_ffi_ROOT -execute_process( - COMMAND "${Python_EXECUTABLE}" -m tvm_ffi.config --includedir - OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE tvm_ffi_INCLUDE_DIR) - -execute_process( - COMMAND "${Python_EXECUTABLE}" -m tvm_ffi.config --dlpack-includedir - OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE tvm_ffi_DLPACK_INCLUDE_DIR) - -execute_process( - COMMAND "${Python_EXECUTABLE}" -m tvm_ffi.config --libfiles - OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE tvm_ffi_LIB_FILES) - -message(STATUS "Finding libfiles ${tvm_ffi_LIB_FILES}") - -add_library(tvm_ffi_header INTERFACE) -target_compile_features(tvm_ffi_header INTERFACE cxx_std_17) -target_include_directories(tvm_ffi_header INTERFACE "${tvm_ffi_INCLUDE_DIR}") -target_include_directories(tvm_ffi_header INTERFACE "${tvm_ffi_DLPACK_INCLUDE_DIR}") - -add_library(tvm_ffi_shared SHARED IMPORTED) -target_compile_features(tvm_ffi_shared INTERFACE cxx_std_17) - -if(WIN32) - set_target_properties( - tvm_ffi_shared PROPERTIES IMPORTED_IMPLIB "${tvm_ffi_LIB_FILES}" - ) -else() - set_target_properties( - tvm_ffi_shared PROPERTIES IMPORTED_LOCATION "${tvm_ffi_LIB_FILES}" - ) -endif() - -set_target_properties( - tvm_ffi_shared PROPERTIES INTERFACE_INCLUDE_DIRECTORIES - "${tvm_ffi_INCLUDE_DIR};${tvm_ffi_DLPACK_INCLUDE_DIR}" -) -# extra cmake functions -include(${CMAKE_CURRENT_LIST_DIR}/Utils/Library.cmake) diff --git a/ffi/docs/.gitignore b/ffi/docs/.gitignore deleted file mode 100644 index d7ab85b91f9e..000000000000 --- a/ffi/docs/.gitignore +++ /dev/null @@ -1,2 +0,0 @@ -_build -**/generated/* diff --git a/ffi/docs/Makefile b/ffi/docs/Makefile deleted file mode 100644 index 51e4de21d31d..000000000000 --- a/ffi/docs/Makefile +++ /dev/null @@ -1,41 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# You can set these variables from the command line, and also -# from the environment for the first two. -SPHINXOPTS ?= -SPHINXBUILD ?= python3 -m sphinx -SOURCEDIR = . -BUILDDIR = _build - -# Put it first so that "make" without argument is like "make help". -help: - @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) - -.PHONY: help Makefile livehtml clean - -livehtml: - @sphinx-autobuild "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) --ignore reference/cpp/generated - -clean: - rm -rf $(BUILDDIR) - rm -rf reference/python/generated - rm -rf reference/cpp/generated - -# Catch-all target: route all unknown targets to Sphinx using the new -# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). -%: Makefile - @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) diff --git a/ffi/docs/README.md b/ffi/docs/README.md deleted file mode 100644 index 39fff194df4f..000000000000 --- a/ffi/docs/README.md +++ /dev/null @@ -1,46 +0,0 @@ - - - - - - - - - - - - - - - - -# TVM FFI Documentation - -To build locally - -First install the tvm-ffi package -```bash -pip install .. -``` - -Install all the requirements to build docs - -```bash -pip install -r requirements.txt -``` - -Then build the doc -```bash -make livehtml -``` - -## Build with C++ Docs - -To build with C++ docs, we need to first install Doxygen. Then -set the environment variable `BUILD_CPP_DOCS=1`, to turn on c++ docs. - -```bash -BUILD_CPP_DOCS=1 make livehtml -``` - -Building c++ docs can take longer, so it is not on by default. diff --git a/ffi/docs/concepts/abi_overview.md b/ffi/docs/concepts/abi_overview.md deleted file mode 100644 index 118257896424..000000000000 --- a/ffi/docs/concepts/abi_overview.md +++ /dev/null @@ -1,430 +0,0 @@ - - - - - - - - - - - - - - - - -# ABI Overview - -This section provides an overview of the ABI convention of TVM FFI. The ABI -is designed around the following key principles: - -- **Stable C ABI:** Core ABI is defined on top of a stable C ABI. -- **Minimal and efficient:** Keep things simple when possible and bring close-to-metal efficiency. -- **Focus on machine learning systems:** while also ensuring reasonable extensibility. - -To explain the concepts in the following sections, we will write in **low-level C/C++ code** when possible, -so the code itself illustrates the low-level semantics of how to work with the ABI convention. -These can serve as references for how to build language bindings and compiler codegen for the ABI. - -```{note} -The authoritative ABI specifications are defined in [tvm/ffi/c_api.h](https://github.com/apache/tvm/blob/main/ffi/include/tvm/ffi/c_api.h) for core ABI, -and [tvm/ffi/extra/c_env_api.h](https://github.com/apache/tvm/blob/main/ffi/include/tvm/ffi/extra/c_env_api.h) for extra support features -such as stream handling. This document provides explanations about design concepts and rationales. -``` - -## Simplified Example - -Before diving into details, it is helpful to review at a high level -what happens when a function is called in TVM FFI ABI. -One main design goal here is to represent all kinds of functions in a single -unified C signature. Please review the following -simplified code example that illustrates the key idea: - -```c++ -// simplified struct for TVMFFIAny -typedef struct TVMFFIAny { - int32_t type_index; - uint32_t zero_padding; - // union values - union { - int64_t v_int64; // integers - double v_float64; // floating-point numbers - const char* v_c_str; // raw C-string - }; -}; - -// This is the signature of TVM FFI function ABI -typedef int (*TVMFFISafeCallType)( - void* handle, const TVMFFIAny* args, int32_t num_args, TVMFFIAny* result -); - -// An example function signature -int MyFunc(const char* param0, int param1); - -// This is what MyFunc looks like when exposed through TVM FFI ABI -int MyFuncTVMFFISafeCall( - void* handle, const TVMFFIAny* args, int32_t num_args, TVMFFIAny* result -) { - assert(args[0].type_index == kTVMFFIRawStr); - assert(args[1].type_index == kTVMFFInt); - result->type_index = kTVMFFInt; - result->v_int64 = MyFunc(args[0].v_c_str, args[1].v_int64); - // return value indicates no error occurred - return 0; -} - -// This is how we call the MyFuncTVMFFISafeCall -// this can happen on the caller side in another language (e.g. python) -int CallTVMFFISafeCall(const char* param0, int param1) { - // arguments on stack - TVMFFIAny args[2], result; - args[0].type_index = kTVMFFIRawStr; - args[0].v_c_str = param0; - args[1].type_index = kTVMFFInt; - args[1].v_int64 = param1; - result.type_index = kTVMFFINone; - // In this case we do not need handle - // handle is used to hold closure pointers - void* handle = nullptr; - int num_args = 2; - MyFuncTVMFFISafeCall(handle, args, num_args, &result); - return result.v_int64; -} -``` - -At a high level, the `TVMFFISafeCallType` signature does the following things: -- Arguments and return values are stored in structured `TVMFFIAny` - - Each value comes with a `type_index` to indicate its type - - Values are stored in union fields, depending on the specific type. -- Caller can explicitly store the type index and value into - a stack of `TVMFFIAny`. -- Callee can load the parameters from args and check their type indices. - -In this way, the same `TVMFFISafeCallType` can be used to represent any function -that contains an arbitrary number of arguments and types that can be identified by `type_index`. -Of course, this is a simplified example and we did not touch on specific details -like Any value format and error handling. The following sections will provide a more systematic -treatment of each of these specific topics. -You can keep this example in mind as the overall picture and refine it as you read through -the following sections. - - -## TVMFFIAny Storage Format - -To start with, we need a mechanism to store the values that are passed across machine learning frameworks. -It achieves this using a core data structure called TVMFFIAny. - -```c++ -typedef struct TVMFFIAny { - int32_t type_index; - union { // 4 bytes - uint32_t zero_padding; - uint32_t small_str_len; - }; - // union values - union { - int64_t v_int64; // integers - double v_float64; // floating-point numbers - void* v_ptr; // typeless pointers - const char* v_c_str; // raw C-string - TVMFFIObject* v_obj; // ref counted objects - DLDataType v_dtype; // data type - DLDevice v_device; // device - char v_bytes[8]; // small string - ... - }; -} TVMFFIAny; -``` - -TVMFFIAny is a 16-byte C structure that follows the design principle of tagged-union: - -- `type_index` helps us identify the type being stored. -- The value union part is designed to store the value: - - Small POD values (like integers and floats) are stored directly as "on-stack" values. - - `v_obj` can also point to a managed heap-allocated object, which we will discuss next. -- The second field stores metadata for small strings. - - -### Storing a POD Value - -There are many values that are plain-old-data types. In such cases, we store them directly -on-stack in the value part of the TVMFFIAny. The following example shows how to store -an int. - -```c++ -void SetIntValue(TVMFFIAny* any, int value) { - // must zero the entire space first - any->type_index = kTVMFFIInt; - any->zero_padding = 0; - any->v_int64 = value; -} -``` - -:::{note} - -We **must zero the content that is not being used** by -the current value type. The following example shows a common place -where mistakes can be made when we forget to zero the value field -on 32-bit platforms (where pointers only fill the 32-bit part of the value). - -```c++ -void SetOpaquePtrValue(TVMFFIAny* any, void* opaque_ptr) { - any->type_index = kTVMFFIOpaquePtr; - // must zero the padding - any->zero_padding = 0; - // the zeroing is needed for 32-bit platforms! - any->v_uint64 = 0; - any->v_ptr = opaque_ptr; -} -``` - -**Rationale:** Such invariants allow us to directly compare -and hash TVMFFIAny in bytes for quick equality checks without going through -type index switching. -::: - - -## Object Storage Format - -When TVMFFIAny points to a heap-allocated object (such as n-dimensional arrays), -we adopt a unified object storage format, defined as follows: - -```c++ -typedef struct TVMFFIObject { - int32_t type_index; - uint32_t weak_ref_count; - uint64_t strong_ref_count; - union { - void (*deleter)(struct TVMFFIObject* self, int flags); - int64_t __ensure_align; - }; -} TVMFFIObject; -``` - -`TVMFFIObject` defines a common 24-byte intrusive header that all in-memory objects share: - -- `type_index` helps us identify the type being stored, which is consistent with `TVMFFIAny.type_index`. -- `weak_ref_count` stores the weak atomic reference counter of the object. -- `strong_ref_count` stores the strong atomic reference counter of the object. -- `deleter` should be called when either the strong or weak ref counter goes to zero. - - The flags are set to indicate the event of either weak or strong going to zero, or both. - - When `strong_ref_count` gets to zero, the deleter needs to call the destructor of the object. - - When `weak_ref_count` gets to zero, the deleter needs to free the memory allocated by self. - -**Rationales:** There are several considerations when designing the data structure: -- `type_index` enables runtime dynamic type checking and casting. -- We introduce weak/strong ref counters so we can be compatible with systems that need weak pointers. -- The weak ref counter is kept as 32-bit so we can pack the object header as 24 bytes. -- `deleter` ensures that objects allocated from one language/runtime can be safely deleted in another. - -The object format provides a unified way to manage object life-cycle and dynamic type casting -for heap-allocated objects, including Shape, Tensor, -Function, Array, Map and other custom objects. - - -### DLPack Compatible Tensor - -We provide first-class support for DLPack raw unmanaged pointer support as well as a managed Tensor object that -directly adopts the DLPack DLTensor layout. The overall layout of the Tensor object is as follows: - -```c++ -struct TensorObj: public ffi::Object, public DLTensor { -}; -``` - -That means we can read out the array buffer information from an `TVMFFIAny` -in the following way: - -```c++ -DLTensor* ReadDLTensorPtr(const TVMFFIAny *value) { - if (value->type_index == kTVMFFIDLTensorPtr) { - return static_cast(value->v_ptr); - } - assert(value->type_index == kTVMFFITensor); - return reinterpret_cast( - reinterpret_cast(value->v_obj) + sizeof(TVMFFIObject)); -} -``` -The above code can be used as a reference to implement compiler codegen for data. -Note that the C++ API automatically handles such conversion. - -### Advanced: Dynamic Type Index - -The `TVMFFITypeIndex` defines a set of type indices. Each built-in type has a corresponding statically -assigned type index that is defined in the enum. Static type indices should be sufficient for most -library use cases. -For advanced use cases we also support user-defined objects whose `type_index` are assigned at startup time -by calling `TVMFFITypeGetOrAllocIndex` with a unique -`type_key` string. This design allows us to enable decentralized extension of the objects as long as the `type_key` -values are unique by appending namespace prefix to the key. - -## AnyView and Managed Any - -An `TVMFFIAny` can either be treated as a strongly managed value (corresponding to `ffi::Any` in C++), -or an unmanaged value (corresponding to `ffi::AnyView` in C++). -- For POD types, there is no difference between the two -- For object types, copying of AnyView should not change reference counters, while copying and deletion - of managed Any should result in increase and decrease of strong reference counters. -- When we convert AnyView to Any, we will convert raw C string `const char*` and `const TVMFFIByteArray*` - into their managed counterparts (String and Bytes). -- C API function `TVMFFIAnyViewToOwnedAny` is provided to perform such conversion. - -Unless the user is writing a compiler backend that needs low-level C style access, we encourage use of the -C++ API to automatically manage conversion and casting between normal types and Any. The following code -shows some example usage of the C++ API. - -```c++ -#include - -void AnyExample() { - namespace ffi = tvm::ffi; - // Here is a managed any - ffi::Any value = "hello world"; - // explicit cast to a specific type - ffi::String str_value = value.cast(); - // copy int to value - value = 1; - // copy into a view - ffi::AnyView view = value; - // cast view back to int - std::cout << "Value is " << view.cast() << std::endl; -} -``` - -`ffi::Any` can serve as a container type to hold managed values that can be recognized by the TVM FFI system. -They can be composed with container structures such as `Map`, `Array` to represent various -broad patterns in APIs that may appear in ML systems. - -## Function Calling Convention - -As discussed in the overview, we need to consider foreign function calls as first-class citizens. We adopt a single standard C function as follows: - -```c++ -typedef int (*TVMFFISafeCallType)( - void* handle, const TVMFFIAny* args, int32_t num_args, TVMFFIAny* result -); -``` - -The handle contains the pointer to the function object itself, allowing us to support closures. args and num_args describe the input arguments and results store the return value. When args and results contain heap-managed objects, we expect the caller to own args and result. - -```{note} -Before calling the function, caller must set `result->type_index` to be kTVMFFINone, or any type index that do not corresponds -to an on-heap object. - -**Rationale:** Simplifies callee implementation as initial state of result can be viewed as managed Any. -``` - -We call this approach a packed function, as it provides a single signature to represent all functions in a "type-erased" way. It saves the need to declare and jit shim for each FFI function call while maintaining reasonable efficiency. This mechanism enables the following scenarios: -- Calling from Dynamic Languages (e.g., Python): we provide a tvm_ffi binding that prepares the args based on dynamically examining Python arguments passed in. -- Calling from Static Languages (e.g., C++): For static languages, we can leverage C++ templates to directly instantiate the arguments on the stack, saving the need for dynamic examination. -- Dynamic language Callbacks: the signature enables us to easily bring dynamic language (Python) callbacks as ffi::Function, as we can take each argument and convert to the dynamic values. -- Efficiency: In practice, we find this approach is sufficient for machine learning focused workloads. For example, we can get to microsecond level overhead for Python/C++ calls, which is generally similar to overhead for eager mode. When both sides of calls are static languages, the overhead will go down to tens of nanoseconds. As a side note, although we did not find it necessary, the signature still leaves room for link time optimization (LTO), when both sides are static languages with a known symbol and linked into a single binary when we inline the callee into caller side and the stack argument memory passing into register passing. - -We support first-class Function objects that allow us to also pass function/closures from different places around, enabling cool usages such as quick python callback for prototyping, and dynamic Functor creation for driver-based kernel launching. - - -## Error Handling - -Most TVM FFI C API calls, including `TVMFFISafeCallType` uses the return value to -indicate whether an error happens. When an error happens during a function call, -a non-zero value will be returned. The callee needs also to set the error through `TVMFFIErrorSetRaisedFromCStr` or `TVMFFIErrorSetRaised` API, which stores -the error on a thread-local storage. - -```c++ -// Example function that raises an error -int ErrorFunc(void* handle, const TVMFFIAny* args, int num_args, TVMFFIAny *result) { - const char* error_kind = "RuntimeError"; - const char* error_msg = "error message"; - // set the thread-local error state - TVMFFIErrorSetRaisedFromCStr(error_kind, error_msg); - return -1; -} -``` - -The caller can retrieve the error from thread-local error storage -using `TVMFFIErrorMoveFromRaised` function. -The ABI stores Error also as a specific Object, -the overall error object is stored as follows -```c++ -typedef struct { - /*! \brief The kind of the error. */ - TVMFFIByteArray kind; - /*! \brief The message of the error. */ - TVMFFIByteArray message; - /*! \brief The traceback of the error. */ - TVMFFIByteArray traceback; - /*! - * \brief Function handle to update the traceback of the error. - * \param self The self object handle. - * \param traceback The traceback to update. - */ - void (*update_traceback)(TVMFFIObjectHandle self, const TVMFFIByteArray* traceback); -} TVMFFIErrorCell; - -// error object -class ErrorObj : public ffi::Object, public TVMFFIErrorCell { -}; -``` - -The error object stores kind, message and traceback as string. When possible, -we store the traceback in the same format of python traceback (see an example as follows): -``` -File "src/extension.cc", line 45, in void my_ffi_extension::RaiseError(tvm::ffi::String) -``` - -We provide C++ object `ffi::Error` that can be throwed as exception in c++ environment. When we encounter -the C ABI boundary, we will catch the error and call `TVMFFIErrorSetRaised` to propagate the error -to the caller safely. -`TVMFFIErrorSetRaisedFromCStr` is a convenient method to set error directly from C string and can be useful in compiler backend construction to implement features such as assert. - -**Rationales:** The error object contains minimal but sufficient information to reconstruct structured -error in python side. We opt-for thread-local error state as it simplifies overall support. - -## String and Bytes - -The ABI supports strings and bytes as first-class citizens. A string can take multiple forms that are identified by -its `type_index`. - -- `kTVMFFIRawStr`: raw C string terminated by `\0`. -- `kTVMFFISmallStr`: small string, the length is stored in `small_str_len` and data is stored in `v_bytes`. -- `kTVMFFIStr`: on-heap string object for strings that are longer than 7 characters. - -The following code shows the layout of the on-heap string object. -```c++ -// span-like data structure to store header and length -typedef struct { - const char* data; - size_t size; -} TVMFFIByteArray; - -// showcase the layout of the on-heap string. -class StringObj : public ffi::Object, public TVMFFIByteArray { -}; -``` - -The following code shows how to read a string from `TVMFFIAny` -```c++ -TVMFFIByteArray ReadString(const TVMFFIAny *value) { - TVMFFIByteArray ret; - if (value->type_index == kTVMFFIRawStr) { - ret.data = value->v_c_str; - ret.size = strlen(ret.data); - } else if (value->type_index == kTVMFFISmallStr) { - ret.data = value->v_bytes; - ret.size = value->small_str_len; - } else { - assert(value->type_index == kTVMFFIStr); - ret = *reinterpret_cast( - reinterpret_cast(value->v_obj) + sizeof(TVMFFIObject)); - } - return ret; -} -``` - -Similarly, we have type indices to represent bytes. The C++ API provides classes -`ffi::String` and `ffi::Bytes` to enable the automatic conversion of these values with Any storage format. - -**Rationales:** Separate string and bytes enable clear mappings from the Python side. Small string allows us to -store short names on-stack. To favor 8-byte alignment (v_bytes) and keep things simple, we did not further -pack characters into the `small_len` field. diff --git a/ffi/docs/conf.py b/ffi/docs/conf.py deleted file mode 100644 index 139254fd97b4..000000000000 --- a/ffi/docs/conf.py +++ /dev/null @@ -1,228 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# -*- coding: utf-8 -*- -import os -import sys - -import tomli - - -os.environ["TVM_FFI_BUILD_DOCS"] = "1" - -build_exhale = os.environ.get("BUILD_CPP_DOCS", "0") == "1" - - -# -- General configuration ------------------------------------------------ - -# Load version from pyproject.toml -with open("../pyproject.toml", "rb") as f: - pyproject_data = tomli.load(f) -__version__ = pyproject_data["project"]["version"] - -project = "tvm-ffi" - -version = __version__ -release = __version__ - -# -- Extensions and extension configurations -------------------------------- - -extensions = [ - "breathe", - "myst_parser", - "nbsphinx", - "autodocsumm", - "sphinx.ext.autodoc", - "sphinx.ext.autosectionlabel", - "sphinx.ext.autosummary", - "sphinx.ext.intersphinx", - "sphinx.ext.mathjax", - "sphinx.ext.napoleon", - "sphinx.ext.viewcode", - "sphinx.ext.ifconfig", - "sphinx_copybutton", - "sphinx_reredirects", - "sphinx_tabs.tabs", - "sphinx_toolbox.collapse", - "sphinxcontrib.httpdomain", - "sphinxcontrib.mermaid", -] - -if build_exhale: - extensions.append("exhale") - -breathe_default_project = "tvm-ffi" - -breathe_projects = {"tvm-ffi": "./_build/doxygen/xml"} - -exhaleDoxygenStdin = """ -INPUT = ../include -PREDEFINED += TVM_FFI_DLL= TVM_FFI_INLINE= TVM_FFI_EXTRA_CXX_API= __cplusplus=201703 - -EXCLUDE_SYMBOLS += *details* *TypeTraits* std \ - *use_default_type_traits_v* *is_optional_type_v* *operator* \ - -EXCLUDE_PATTERNS += *details.h -ENABLE_PREPROCESSING = YES -MACRO_EXPANSION = YES -""" - -exhaleAfterTitleDescription = """ -This page contains the full API index for the C++ API. -""" - -# Setup the exhale extension -exhale_args = { - "containmentFolder": "reference/cpp/generated", - "rootFileName": "index.rst", - "doxygenStripFromPath": "../include", - "rootFileTitle": "Full API Index", - "createTreeView": True, - "exhaleExecutesDoxygen": True, - "exhaleDoxygenStdin": exhaleDoxygenStdin, - "afterTitleDescription": exhaleAfterTitleDescription, -} -nbsphinx_allow_errors = True -nbsphinx_execute = "never" - -autosectionlabel_prefix_document = True -nbsphinx_allow_directives = True - -myst_enable_extensions = [ - "dollarmath", - "amsmath", - "deflist", - "colon_fence", - "html_image", - "linkify", - "attrs_block", - "substitution", -] - -myst_heading_anchors = 3 -myst_ref_domains = ["std", "py"] -myst_all_links_external = False - -intersphinx_mapping = { - "python": ("https://docs.python.org/3.12", None), - "typing_extensions": ("https://typing-extensions.readthedocs.io/en/latest", None), - "pillow": ("https://pillow.readthedocs.io/en/stable", None), - "numpy": ("https://numpy.org/doc/stable", None), - "torch": ("https://pytorch.org/docs/stable", None), -} - -autodoc_mock_imports = ["torch"] -autodoc_default_options = { - "members": True, - "undoc-members": True, - "show-inheritance": True, - "inherited-members": False, - "member-order": "bysource", -} - -# -- Other Options -------------------------------------------------------- - -templates_path = [] - -redirects = {} - -source_suffix = {".rst": "restructuredtext", ".md": "markdown"} - -language = "en" - -exclude_patterns = ["_build", "Thumbs.db", ".DS_Store", "README.md"] - -# The name of the Pygments (syntax highlighting) style to use. -pygments_style = "sphinx" - -# A list of ignored prefixes for module index sorting. -# If true, `todo` and `todoList` produce output, else they produce nothing. -todo_include_todos = False - -# -- Options for HTML output ---------------------------------------------- - -html_theme = "sphinx_book_theme" -html_title = project -html_copy_source = True -html_last_updated_fmt = "" - -html_favicon = "https://tvm.apache.org/images/logo/tvm-logo-square.png" - - -footer_dropdown = { - "name": "ASF", - "items": [ - ("ASF Homepage", "https://apache.org/"), - ("License", "https://www.apache.org/licenses/"), - ("Sponsorship", "https://www.apache.org/foundation/sponsorship.html"), - ("Security", "https://tvm.apache.org/docs/reference/security.html"), - ("Thanks", "https://www.apache.org/foundation/thanks.html"), - ("Events", "https://www.apache.org/events/current-event"), - ], -} - - -footer_copyright = "Copyright © 2025, Apache Software Foundation" -footer_note = ( - "Apache TVM, Apache, the Apache feather, and the Apache TVM project " - + "logo are either trademarks or registered trademarks of the Apache Software Foundation." -) - - -def footer_html(): - # Create footer HTML with two-line layout - # Generate dropdown menu items - dropdown_items = "" - for item_name, item_url in footer_dropdown["items"]: - dropdown_items += f'
  • {item_name}
  • \n' - - footer_dropdown_html = f""" - - """ - return footer_dropdown_html - - -html_theme_options = { - "repository_url": "https://github.com/apache/tvm", - "use_repository_button": True, - "extra_footer": footer_html(), -} - -html_context = { - "display_github": True, - "github_user": "apache", - "github_version": "main", - "conf_py_path": "/ffi/docs/", -} diff --git a/ffi/docs/get_started/install.md b/ffi/docs/get_started/install.md deleted file mode 100644 index 87223d011497..000000000000 --- a/ffi/docs/get_started/install.md +++ /dev/null @@ -1,83 +0,0 @@ - - - - - - - - - - - - - - - - -# Installation - -TVM FFI is built and tested on Windows, macOS, and various -Linux distributions. You can install tvm-ffi using one of the -methods below - -## Quick Start - -The easiest way to try it out is to install from PyPI. - -```bash -pip install apache-tvm-ffi -``` - -After installation, you can run the following command to confirm that -the installation was successful - -```bash -tvm-ffi-config -h -``` - -This configuration tool is also useful in various ways to help you build -libraries with tvm-ffi. - - -## Install From Source - -You can also build and install tvm-ffi from source. - -### Dependencies - -- CMake (>= 3.24.0) -- Git -- A recent C++ compiler supporting C++17, at minimum: - - GCC 7.1 - - Clang 5.0 - - Apple Clang 9.3 - - Visual Studio 2019 (v16.7) -- Python (>= 3.9) - - -Developers can clone the source repository from GitHub. - -```bash -git clone --recursive https://github.com/apache/tvm tvm -``` - -```{note} -It's important to use the ``--recursive`` flag when cloning the repository, which will -automatically clone the submodules. If you forget to use this flag, you can manually clone the submodules -by running ``git submodule update --init --recursive`` in the root directory. -``` - -Then you can install directly in development mode - -```bash -cd tvm/ffi -pip install -ve . -``` - -The additional `-e` flag will install the Python files in `editable` mode, -which allows direct editing of the Python files to be immediately reflected in the package -and is useful for development. - -## What to Do Next - -Now that you have installed TVM FFI, we recommend reading the [Quick Start](./quick_start.md) tutorial. diff --git a/ffi/docs/get_started/quick_start.md b/ffi/docs/get_started/quick_start.md deleted file mode 100644 index 4861aa87b253..000000000000 --- a/ffi/docs/get_started/quick_start.md +++ /dev/null @@ -1,213 +0,0 @@ - - - - - - - - - - - - - - - - -# Quick Start - -This is a quick start guide explaining the basic features and usage of tvm-ffi. -The source code can be found at `examples/quick_start` in the project source. - -## Build and Run the Example - -Let us first get started by build and run the example. The example will show us: - -- How to expose c++ functions as tvm ffi ABI function -- How to load and run tvm-ffi based library from python -- How to load and run tvm-ffi based library from c++ - - -Before starting, ensure you have: - -- TVM FFI installed following [installation](./install.md) -- C++ compiler with C++17 support -- CMake 3.18 or later -- (Optional) CUDA toolkit for GPU examples -- (Optional) PyTorch for checking torch integrations - -Then obtain a copy of the tvm-ffi source code. - -```bash -git clone https://github.com/apache/tvm --recursive -cd tvm/ffi -``` - -The examples are now in the example folder, you can quickly build -the example using the following command. -```bash -cd examples/quick_start -cmake -B build -S . -cmake --build build -``` - -After the build finishes, you can run the python examples by -``` -python run_example.py -``` - -You can also run the c++ example - -``` -./build/example -``` - -## Walk through the Example - -Now we have quickly try things out. Let us now walk through the details of the example. -Specifically, in this example, we create a simple "add one" operation that adds 1 to each element of an input -tensor and expose that function as TVM FFI compatible function. The key file structures are as follows: - -``` -examples/quick_start/ -├── src/ -│ ├── add_one_cpu.cc # CPU implementation -│ ├── add_one_cuda.cu # CUDA implementation -│ └── run_example.cc # C++ usage example -├── run_example.py # Python usage example -├── run_example.sh # Build and run script -└── CMakeLists.txt # Build configuration -``` - -### CPU Implementation - -```cpp -#include -#include -#include - -namespace tvm_ffi_example { - -void AddOne(DLTensor* x, DLTensor* y) { - // Validate inputs - TVM_FFI_ICHECK(x->ndim == 1) << "x must be a 1D tensor"; - DLDataType f32_dtype{kDLFloat, 32, 1}; - TVM_FFI_ICHECK(x->dtype == f32_dtype) << "x must be a float tensor"; - TVM_FFI_ICHECK(y->ndim == 1) << "y must be a 1D tensor"; - TVM_FFI_ICHECK(y->dtype == f32_dtype) << "y must be a float tensor"; - TVM_FFI_ICHECK(x->shape[0] == y->shape[0]) << "x and y must have the same shape"; - - // Perform the computation - for (int i = 0; i < x->shape[0]; ++i) { - static_cast(y->data)[i] = static_cast(x->data)[i] + 1; - } -} - -// Expose the function through TVM FFI -TVM_FFI_DLL_EXPORT_TYPED_FUNC(add_one_cpu, tvm_ffi_example::AddOne); -} -``` - -**Key Points:** -- Functions take `DLTensor*` parameters for cross-language compatibility -- The `TVM_FFI_DLL_EXPORT_TYPED_FUNC` macro exposes the function with a given name - -### CUDA Implementation - -```cpp -void AddOneCUDA(DLTensor* x, DLTensor* y) { - // Validation (same as CPU version) - // ... - - int64_t n = x->shape[0]; - int64_t nthread_per_block = 256; - int64_t nblock = (n + nthread_per_block - 1) / nthread_per_block; - - // Get current CUDA stream from environment - cudaStream_t stream = static_cast( - TVMFFIEnvGetStream(x->device.device_type, x->device.device_id)); - - // Launch kernel - AddOneKernel<<>>( - static_cast(x->data), static_cast(y->data), n); -} - -TVM_FFI_DLL_EXPORT_TYPED_FUNC(add_one_cuda, tvm_ffi_example::AddOneCUDA); -``` - -**Key Points:** -- We use `TVMFFIEnvGetStream` to obtain the current stream from the environement -- When invoking ffi Function from python end with PyTorch tensor as argument, - the stream will be populated with torch's current stream. - - -### Working with PyTorch - -Atfer build, we will create library such as `build/add_one_cuda.so`, that can be loaded by -with api {py:func}`tvm_ffi.load_module` that returns a {py:class}`tvm_ffi.Module` -Then the function will become available as property of the loaded module. -The tensor arguments in the ffi functions automatically consumes `torch.Tensor`. The following code shows how -to use the function in torch. - -```python -import torch -import tvm_ffi - -if torch.cuda.is_available(): - mod = tvm_ffi.load_module("build/add_one_cuda.so") - - x = torch.tensor([1, 2, 3, 4, 5], dtype=torch.float32, device="cuda") - y = torch.empty_like(x) - - # TVM FFI automatically handles CUDA streams - stream = torch.cuda.Stream() - with torch.cuda.stream(stream): - mod.add_one_cuda(x, y) - stream.synchronize() -``` - -### Working with Python Data Arrays - -TVM FFI functions works automaticaly with python data arrays that are compatible with dlpack. -The following examples how to use the function with numpy. - -```python -import tvm_ffi -import numpy as np - -# Load the compiled module -mod = tvm_ffi.load_module("build/add_one_cpu.so") - -# Create input and output arrays -x = np.array([1, 2, 3, 4, 5], dtype=np.float32) -y = np.empty_like(x) - -# Call the function -mod.add_one_cpu(x, y) -print("Result:", y) # [2, 3, 4, 5, 6] -``` - -### Working with C++ - -One important design goal of tvm-ffi is to be universally portable. -As a result, the result libraries do not have explicit dependencies in python -and can be loaded in other language environments, such as c++. The following code -shows how to run the example exported function in C++. - -```cpp -#include -#include - -void CallAddOne(DLTensor* x, DLTensor *y) { - namespace ffi = tvm::ffi; - ffi::Module mod = ffi::Module::LoadFromFile("build/add_one_cpu.so"); - ffi::Function add_one_cpu = mod->GetFunction("add_one_cpu").value(); - add_one_cpu(x, y); -} -``` - -## Summary Key Concepts - -- **TVM_FFI_DLL_EXPORT_TYPED_FUNC** exposes a c++ function into tvm-ffi C ABI -- **DLTensor** is a universal tensor structure that enables zero-copy exchange of array data -- **Module loading** is provided by tvm ffi APIs in multiple languages. diff --git a/ffi/docs/guides/cpp_guide.md b/ffi/docs/guides/cpp_guide.md deleted file mode 100644 index a27fe2dac1e6..000000000000 --- a/ffi/docs/guides/cpp_guide.md +++ /dev/null @@ -1,584 +0,0 @@ - - - - - - - - - - - - - - - - -{#cpp-guide} - -# C++ Guide - -This guide introduces the tvm-ffi C++ API. -We provide C++ API on top of the stable C ABI to provide a type-safe and efficient way to work with the tvm-ffi. -The C++ API is designed to abstract away the complexity of the C ABI while maintaining full compatibility. -The C++ API builds around the following key concepts: - -- **Any and AnyView**: Type-erased containers that can hold values of any supported type in tvm-ffi. -- **Function**: A type-erased "packed" function that can be invoked like normal functions. -- **Objects and ObjectRefs**: Reference-counted objects to manage on-heap data types. - -Code examples in this guide use `EXPECT_EQ` for demonstration purposes, which is a testing framework macro. In actual applications, you would use standard C++ assertions or error handling. -You can find runnable code of the examples under tests/cpp/test_example.cc. - -## Any and AnyView - -`Any` and `AnyView` are the foundation of tvm-ffi, providing -ways to store values that are compatible with the ffi system. -The following example shows how we can interact with Any and AnyView. - -```cpp - -#include - -void ExampleAny() { - namespace ffi = tvm::ffi; - // Create an Any from various types - // EXPECT_EQ is used here for demonstration purposes (testing framework) - ffi::Any int_value = 42; - ffi::Any float_value = 3.14; - ffi::Any string_value = "hello world"; - - // AnyView provides a lightweight view without ownership - ffi::AnyView view = int_value; - // we can cast Any/AnyView to a specific type - int extracted = view.cast(); - EXPECT_EQ(extracted, 42); - - // If we are not sure about the type - // we can use as to get an optional value - std::optional maybe_int = view.as(); - if (maybe_int.has_value()) { - EXPECT_EQ(maybe_int.value(), 42); - } - // Try cast is another version that will try to run the type - // conversion even if the type does not exactly match - std::optional maybe_int_try = view.try_cast(); - if (maybe_int_try.has_value()) { - EXPECT_EQ(maybe_int_try.value(), 42); - } -} -``` - -At a high level, we can perform the following operations: - -- We can store a value into Any, under the hood, Any will record the type of the value by its type_index. -- We can fetch a value from Any or AnyView using the `cast` function. -- If we are unsure about the type in Any, we can use `as` or `try_cast` function to get an optional value. - -Under the hood, Any and AnyView store the value via the ABI convention and also manage the reference -counting correctly when the stored value is an on-heap object. - -## Object and ObjectRef - -The tvm-ffi object system provides the foundation for all managed, reference-counted objects -in the system. It enables type safety, cross-language compatibility, and efficient memory management. - -The object system is built around three key classes: Object, ObjectPtr, and ObjectRef. -The `Object` class is the base class of all heap-allocated objects. It contains a common header -that includes the `type_index`, reference counter and deleter for the object. -Users do not need to explicitly manage these fields as part of the C++ API. Instead, -they are automatically managed through a smart pointer `ObjectPtr` which points -to a heap-allocated object instance. -The following code shows an example object and the creation of an `ObjectPtr`: - -```cpp -#include -#include - -class MyIntPairObj : public tvm::ffi::Object { - public: - int64_t a; - int64_t b; - - MyIntPairObj() = default; - MyIntPairObj(int64_t a, int64_t b) : a(a), b(b) {} - - // Required: declare type information - // to register a dynamic type index through the system -TVM_FFI_DECLARE_OBJECT_INFO_FINAL("example.MyIntPair", MyIntPairObj, tvm::ffi::Object); -}; - -void ExampleObjectPtr() { - namespace ffi = tvm::ffi; - // make_object automatically sets up the deleter correctly - // This function creates a new ObjectPtr with proper memory management - // It handles allocation, initialization, and sets up the reference counting system - ffi::ObjectPtr obj = ffi::make_object(100, 200); - // EXPECT_EQ is used here for demonstration purposes (testing framework) - EXPECT_EQ(obj->a, 100); - EXPECT_EQ(obj->b, 200); -} -``` - -We typically provide a reference class that wraps the ObjectPtr. -The `ObjectRef` base class provides the interface and reference counting -functionality for these wrapper classes. -```cpp -#include -#include - -class MyIntPair : public tvm::ffi::ObjectRef { - public: - // Constructor - explicit MyIntPair(int64_t a, int64_t b) { - data_ = tvm::ffi::make_object(a, b); - } - - // Required: define object reference methods - // This macro provides the necessary methods for ObjectRef functionality - TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(MyIntPair, tvm::ffi::ObjectRef, MyIntPairObj); -}; - -void ExampleObjectRef() { - namespace ffi = tvm::ffi; - MyIntPair pair(100, 200); - // EXPECT_EQ is used here for demonstration purposes (testing framework) - EXPECT_EQ(pair->a, 100); - EXPECT_EQ(pair->b, 200); -} -``` - -**Note:** The ObjectRef provides a user-friendly interface while ObjectPtr handles the low-level memory management. -The ObjectRef acts as a smart pointer wrapper that automatically manages the ObjectPtr lifecycle. - -The overall implementation pattern is as follows: -- **Object Class**: Inherits from `ffi::Object`, stores data and implements the core functionality. -- **ObjectPtr**: Smart pointer that manages the Object lifecycle and reference counting. -- **Ref Class**: Inherits from `ffi::ObjectRef`, provides a user-friendly interface and automatic memory management. - -This design ensures efficient memory management while providing a clean API for users. Once we define an ObjectRef class, -we can integrate it with the Any, AnyView and Functions. - -```cpp -#include -#include - -void ExampleObjectRefAny() { - namespace ffi = tvm::ffi; - MyIntPair pair(100, 200); - ffi::Any any = pair; - MyIntPair pair2 = any.cast(); - // Note: EXPECT_EQ is used here for demonstration purposes (testing framework) - EXPECT_EQ(pair2->a, 100); - EXPECT_EQ(pair2->b, 200); -} - -``` - -Under the hood, ObjectPtr manages the lifecycle of the object through the same mechanism as shared pointers. We designed -the object to be intrusive, which means the reference counter and type index metadata are embedded at the header of each object. -This design allows us to allocate the control block and object memory together. As we will see in future sections, -all of our heap-allocated classes such as Function, on-heap String, Array and Map are managed using subclasses of Object, -and the user-facing classes such as Function are ObjectRefs. - - -We provide a collection of built-in object and reference types, which are sufficient for common cases. -Developers can also bring new object types as shown in the example of this section. We provide mechanisms -to expose these objects to other language bindings such as Python. - - -## Function - -The `Function` class provides a type-safe way to create and invoke callable objects -through tvm-ffi ABI convention. We can create a `ffi::Function` from an existing typed lambda function. - -```cpp -#include - -void ExampleFunctionFromTyped() { - namespace ffi = tvm::ffi; - // Create a function from a typed lambda - ffi::Function fadd1 = ffi::Function::FromTyped( - [](const int a) -> int { return a + 1; } - ); - int b = fadd1(1).cast(); - // EXPECT_EQ is used here for demonstration purposes (testing framework) - EXPECT_EQ(b, 2); -} -``` - -Under the hood, tvm-ffi leverages Any and AnyView to create a unified ABI for -all functions. The following example demonstrates the low-level way of defining -a "packed" function for the same `fadd1`. - -```cpp -void ExampleFunctionFromPacked() { - namespace ffi = tvm::ffi; - // Create a function from a typed lambda - ffi::Function fadd1 = ffi::Function::FromPacked( - [](const ffi::AnyView* args, int32_t num_args, ffi::Any* rv) { - // Check that we have exactly one argument - TVM_FFI_ICHECK_EQ(num_args, 1); - int a = args[0].cast(); - *rv = a + 1; - } - ); - int b = fadd1(1).cast(); - // EXPECT_EQ is used here for demonstration purposes (testing framework) - EXPECT_EQ(b, 2); -} -``` - -At a high level, `ffi::Function` implements function calling by the following convention: -- The arguments are passed through an on-stack array of `ffi::AnyView` -- Return values are passed through `ffi::Any` - -Because the return value is `ffi::Any`, we need to explicitly call `cast` to convert the return -value to the desirable type. Importantly, `ffi::Function` itself is a value type that is compatible -with tvm-ffi, which means we can pass it as an argument and return values. The following code shows -an example of passing a function as an argument and applying it inside. - -```cpp -void ExampleFunctionPassFunction() { - namespace ffi = tvm::ffi; - // Create a function from a typed lambda - ffi::Function fapply = ffi::Function::FromTyped( - [](const ffi::Function f, ffi::Any param) { return f(param.cast()); }); - ffi::Function fadd1 = ffi::Function::FromTyped( // - [](const int a) -> int { return a + 1; }); - int b = fapply(fadd1, 2).cast(); - // EXPECT_EQ is used here for demonstration purposes (testing framework) - EXPECT_EQ(b, 3); -} -``` - -This pattern is very powerful because we can construct `ffi::Function` not only from C++, -but from any languages that expose to the tvm-ffi ABI. For example, this means we can easily call functions -passed in or registered from Python for quick debugging or other purposes. - - -### Global Function Registry - -Besides creating functions locally, tvm-ffi provides a global function registry that allows -functions to be registered and called across different modules and languages. -The following code shows an example - -```cpp -#include -#include - -void ExampleGlobalFunctionRegistry() { - namespace ffi = tvm::ffi; - ffi::reflection::GlobalDef().def("xyz.add1", [](const int a) -> int { return a + 1; }); - ffi::Function fadd1 = ffi::Function::GetGlobalRequired("xyz.add1"); - int b = fadd1(1).cast(); - // EXPECT_EQ is used here for demonstration purposes (testing framework) - EXPECT_EQ(b, 2); -} -``` - -You can also access and register global functions from the Python API. - -### Exporting as Library Symbol - -Besides the API that allows registration of functions into the global table, -we also provide a macro to export static functions as `TVMFFISafeCallType` symbols in a dynamic library. - -```c++ -void AddOne(DLTensor* x, DLTensor* y) { - // ... implementation omitted ... -} - -TVM_FFI_DLL_EXPORT_TYPED_FUNC(add_one, my_ffi_extension::AddOne); -``` - -The new `add_one` takes the signature of `TVMFFISafeCallType` and can be wrapped as `ffi::Function` -through the C++ `ffi::Module` API. - -```cpp -ffi::Module mod = ffi::Module::LoadFromFile("path/to/export_lib.so"); -ffi::Function func = mod->GetFunction("add_one").value(); -``` - -## Error Handling - -We provide a specific `ffi::Error` type that is also made compatible with the ffi ABI. -We also provide a macro `TVM_FFI_THROW` to simplify the error throwing step. - -```cpp -// file: cpp/test_example.cc -#include - -void FuncThrowError() { - namespace ffi = tvm::ffi; - TVM_FFI_THROW(TypeError) << "test0"; -} - -void ExampleErrorHandling() { - namespace ffi = tvm::ffi; - try { - FuncThrowError(); - } catch (const ffi::Error& e) { - EXPECT_EQ(e.kind(), "TypeError"); - EXPECT_EQ(e.message(), "test0"); - std::cout << e.traceback() << std::endl; - } -} -``` -The structured error class records kind, message and traceback that can be mapped to -Pythonic style error types and tracebacks. The traceback follows the Python style, -tvm-ffi will try to preserve the traceback when possible. In the above example, -you can see the traceback output as -``` -... more lines omitted -File "cpp/test_example.cc", line 106, in ExampleErrorHandling -File "cpp/test_example.cc", line 100, in void FuncThrowError() -``` - -The ffi ABI provides minimal but sufficient mechanisms to propagate these errors across -language boundaries. -So when we call the function from Python, the Error will be translated into a corresponding -Error type. Similarly, when we call a Python callback from C++, the error will be translated -into the right error kind and message. - - -## Tensor - -For many use cases, we do not need to manage the nd-array/Tensor memory. -In such cases, `DLTensor*` can be used as the function arguments. -There can be cases for a managed container for multi-dimensional arrays. -`ffi::Tensor` is a minimal container to provide such support. -Notably, specific logic of device allocations and array operations are non-goals -of the FFI. Instead, we provide minimal generic API `ffi::Tensor::FromNDAlloc` -to enable flexible customization of Tensor allocation. - -```cpp -#include -#include - -struct CPUNDAlloc { - void AllocData(DLTensor* tensor) { - tensor->data = malloc(tvm::ffi::GetDataSize(*tensor)); - } - void FreeData(DLTensor* tensor) { free(tensor->data); } -}; - -void ExampleTensor() { - namespace ffi = tvm::ffi; - ffi::Shape shape = {1, 2, 3}; - DLDataType dtype = {kDLFloat, 32, 1}; - DLDevice device = {kDLCPU, 0}; - ffi::Tensor tensor = ffi::Tensor::FromNDAlloc(CPUNDAlloc(), shape, dtype, device); - // now tensor is a managed tensor -} -``` - -The above example shows how we define `CPUNDAlloc` that customizes `AllocData` -and `FreeData` behavior. The CPUNDAlloc struct will be kept alive with the Tensor object. -This pattern allows us to implement various Tensor allocations using the same API: - -- For CUDA allocation, we can change malloc to cudaMalloc -- For memory-pool based allocation, we can update `CPUNDAlloc` to keep a strong reference to the pool, - so we can keep memory-pool alive when the array is alive. - -**Working with Shapes** As you may have noticed in the example, we have a `ffi::Shape` container that is used -to represent the shapes in nd-array. This container allows us to have compact and efficient representation -of managed shapes and we provide quick conversions from standard vector types. - -### DLPack Conversion - -We provide first-class DLPack support to the `ffi::Tensor` that enables efficient exchange -through the DLPack Protocol. - -```cpp -#include - -void ExampleTensorDLPack() { - namespace ffi = tvm::ffi; - ffi::Shape shape = {1, 2, 3}; - DLDataType dtype = {kDLFloat, 32, 1}; - DLDevice device = {kDLCPU, 0}; - ffi::Tensor tensor = ffi::Tensor::FromNDAlloc(CPUNDAlloc(), shape, dtype, device); - // convert to DLManagedTensorVersioned - DLManagedTensorVersioned* dlpack = nd.ToDLPackVersioned(); - // load back from DLManagedTensorVersioned - ffi::Tensor tensor2 = ffi::Tensor::FromDLPackVersioned(dlpack); -} -``` - -These APIs are also available through the C APIs -`TVMFFITensorFromDLPackVersioned` and `TVMFFITensorToDLPackVersioned`. - -## String and Bytes - -The tvm-ffi provides first-class support for `String` and `Bytes` types that are efficient, -FFI-compatible, and interoperable with standard C++ string types. - -```cpp -#include - -void ExampleString() { - namespace ffi = tvm::ffi; - ffi::String str = "hello world"; - // EXPECT_EQ is used here for demonstration purposes (testing framework) - EXPECT_EQ(str.size(), 11); - std::string std_str = str; - EXPECT_EQ(std_str, "hello world"); -} -``` - -Alternatively, users can always directly use `std::string` in function arguments, conversion -will happen automatically. - -**Rationale:** We need to have separate Bytes and String so they map well to corresponding Python types. -`ffi::String` is backed by a possibly managed object that makes it more compatible with the Object system. - -## Container Types - -To enable effective passing and storing of collections of values that are compatible with tvm-ffi, -we provide several built-in container types. - -### Array - -`Array` provides an array data type that can be used as function arguments. -When we use `Array` as an argument of a Function, it will -perform runtime checks of the elements to ensure the values match the expected type. - -```cpp -#include - - -void ExampleArray() { - namespace ffi = tvm::ffi; - ffi::Array numbers = {1, 2, 3}; - // EXPECT_EQ is used here for demonstration purposes (testing framework) - EXPECT_EQ(numbers.size(), 3); - EXPECT_EQ(numbers[0], 1); - - ffi::Function head = ffi::Function::FromTyped([](const ffi::Array a) { - return a[0]; - }); - EXPECT_EQ(head(numbers).cast(), 1); - - try { - // throw an error because 2.2 is not int - head(ffi::Array({1, 2.2})); - } catch (const ffi::Error& e) { - EXPECT_EQ(e.kind(), "TypeError"); - } -} -``` - -Under the hood, Array is backed by a reference-counted Object `ArrayObj` that stores -a collection of Any values. Note that conversion from Any to `Array` will result in -runtime checks of elements because the type index only indicates `ArrayObj` as the backing storage. -If you want to defer such checks at the FFI function boundary, consider using `Array` instead. -When passing lists and tuples from Python, the values will be converted to `Array` before -being passed into the Function. - -**Performance note:** Repeatedly converting Any to `Array` can incur repeated -checking overhead at each element. Consider using `Array` to defer checking or only run conversion once. - -### Tuple - -`Tuple` provides type-safe fixed-size collections. - -```cpp -#include - -void ExampleTuple() { - namespace ffi = tvm::ffi; - ffi::Tuple tup(42, "hello", true); - - // EXPECT_EQ is used here for demonstration purposes (testing framework) - EXPECT_EQ(tup.get<0>(), 42); - EXPECT_EQ(tup.get<1>(), "hello"); - EXPECT_EQ(tup.get<2>(), true); -} -``` - -Under the hood, Tuple is backed by the same `ArrayObj` as the Array container. -This enables zero-cost exchange with input arguments. - -**Rationale:** This design unifies the conversion rules from Python list/tuple to -Array/Tuple. We always need a container representation for tuples -to be stored in Any. - -### Map - -`Map` provides a key-value based hashmap container that can accept dict-style parameters. - -```cpp -#include - -void ExampleMap() { - namespace ffi = tvm::ffi; - - ffi::Map map0 = {{"Alice", 100}, {"Bob", 95}}; - - // EXPECT_EQ is used here for demonstration purposes (testing framework) - EXPECT_EQ(map0.size(), 2); - EXPECT_EQ(map0.at("Alice"), 100); - EXPECT_EQ(map0.count("Alice"), 1); -} -``` - - -Under the hood, Map is backed by a reference-counted Object `MapObj` that stores -a collection of Any values. The implementation provides a SmallMap variant that stores -values as an array and another variant that is based on a hashmap. The Map preserves insertion -order like Python dictionaries. Conversion from Any to `Map` will result in -runtime checks of its elements because the type index only indicates `MapObj` as the backing storage. -If you want to defer such checks at the FFI function boundary, consider using `Map` instead. -When passing dictionaries from Python, the values will be converted to `Map` before -being passed into the Function. - -**Performance note:** Repeatedly converting Any to `Map` can incur repeated -checking overhead at each element. Consider using `Map` to defer checking or only run conversion once. - -### Optional - -`Optional` provides a safe way to handle values that may or may not exist. -We specialize Optional for `ffi::String` and Object types to be more compact, -using nullptr to indicate non-existence. - -```cpp -#include - -void ExampleOptional() { - namespace ffi = tvm::ffi; - ffi::Optional opt0 = 100; - // EXPECT_EQ is used here for demonstration purposes (testing framework) - EXPECT_EQ(opt0.has_value(), true); - EXPECT_EQ(opt0.value(), 100); - - ffi::Optional opt1; - EXPECT_EQ(opt1.has_value(), false); - EXPECT_EQ(opt1.value_or("default"), "default"); -} -``` - - -### Variant - -`Variant` provides a type-safe union of different types. - -```cpp -#include - -void ExampleVariant() { - namespace ffi = tvm::ffi; - ffi::Variant var0 = 100; - // EXPECT_EQ is used here for demonstration purposes (testing framework) - EXPECT_EQ(var0.get(), 100); - - var0 = ffi::String("hello"); - std::optional maybe_str = var0.as(); - EXPECT_EQ(maybe_str.value(), "hello"); - - std::optional maybe_int2 = var0.as(); - EXPECT_EQ(maybe_int2.has_value(), false); -} -``` - -Under the hood, Variant is a wrapper around Any that restricts the type to the specific types in the list. diff --git a/ffi/docs/guides/packaging.md b/ffi/docs/guides/packaging.md deleted file mode 100644 index c12fe4e30719..000000000000 --- a/ffi/docs/guides/packaging.md +++ /dev/null @@ -1,282 +0,0 @@ - - - - - - - - - - - - - - - - -# Packaging - -This guide explains how to package a tvm-ffi-based library into a Python ABI-agnostic wheel. -It demonstrates both source-level builds (for cross-compilation) and builds based on pre-shipped shared libraries. -At a high level, packaging with tvm-ffi offers several benefits: - -- **ABI-agnostic wheels**: Works across different Python versions with minimal dependency. -- **Universally deployable**: Build once with tvm-ffi and ship to different environments, including Python and non-Python environments. - -While this guide shows how to build a wheel package, the resulting `my_ffi_extension.so` is agnostic -to Python, comes with minimal dependencies, and can be used in other deployment scenarios. - -## Build and Run the Example - -Let's start by building and running the example. -First, obtain a copy of the tvm-ffi source code. - -```bash -git clone https://github.com/apache/tvm --recursive -cd tvm/ffi -``` - -The examples are now in the examples folder. You can quickly build -and install the example using the following command. -```bash -cd examples/packaging -pip install -v . -``` - -Then you can run examples that leverage the built wheel package. - -```bash -python run_example.py add_one -``` - -## Setup pyproject.toml - -A typical tvm-ffi-based project has the following structure: - -``` -├── CMakeLists.txt # CMake build configuration -├── pyproject.toml # Python packaging configuration -├── src/ -│ └── extension.cc # C++ source code -├── python/ -│ └── my_ffi_extension/ -│ ├── __init__.py # Python package initialization -│ ├── base.py # Library loading logic -│ └── _ffi_api.py # FFI API registration -└── README.md # Project documentation -``` - -The `pyproject.toml` file configures the build system and project metadata. - -```toml -[project] -name = "my-ffi-extension" -version = "0.1.0" -# ... more project metadata omitted ... - -[build-system] -requires = ["scikit-build-core>=0.10.0", "apache-tvm-ffi"] -build-backend = "scikit_build_core.build" - -[tool.scikit-build] -# ABI-agnostic wheel -wheel.py-api = "py3" -# ... more build configuration omitted ... -``` - -We use scikit-build-core for building the wheel. Make sure you add tvm-ffi as a build-system requirement. -Importantly, we should set `wheel.py-api` to `py3` to indicate it is ABI-generic. - -## Setup CMakeLists.txt - -The CMakeLists.txt handles the build and linking of the project. -There are two ways you can build with tvm-ffi: - -- Link the pre-built `libtvm_ffi` shipped from the pip package -- Build tvm-ffi from source - -For common cases, using the pre-built library and linking tvm_ffi_shared is sufficient. -To build with the pre-built library, you can do: - -```cmake -cmake_minimum_required(VERSION 3.18) -project(my_ffi_extension) - -find_package(Python COMPONENTS Interpreter REQUIRED) -execute_process( - COMMAND "${Python_EXECUTABLE}" -m tvm_ffi.config --cmakedir - OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE tvm_ffi_ROOT) -# find the prebuilt package -find_package(tvm_ffi CONFIG REQUIRED) - -# ... more cmake configuration omitted ... - -# linking the library -target_link_libraries(my_ffi_extension tvm_ffi_shared) -``` - -There are cases where one may want to cross-compile or bundle part of tvm_ffi objects directly -into the project. In such cases, you should build from source. - -```cmake -execute_process( - COMMAND "${Python_EXECUTABLE}" -m tvm_ffi.config --sourcedir - OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE tvm_ffi_ROOT) -# add the shipped source code as a cmake subdirectory -add_subdirectory(${tvm_ffi_ROOT} tvm_ffi) - -# ... more cmake configuration omitted ... - -# linking the library -target_link_libraries(my_ffi_extension tvm_ffi_shared) -``` -Note that it is always safe to build from source, and the extra cost of building tvm-ffi is small -because tvm-ffi is a lightweight library. If you are in doubt, -you can always choose to build tvm-ffi from source. -In Python or other cases when we dynamically load libtvm_ffi shipped with the dedicated pip package, -you do not need to ship libtvm_ffi.so in your package even if you build tvm-ffi from source. -The built objects are only used to supply the linking information. - -## Exposing C++ Functions - -The C++ implementation is defined in `src/extension.cc`. -There are two ways one can expose a function in C++ to the FFI library. -First, `TVM_FFI_DLL_EXPORT_TYPED_FUNC` can be used to expose the function directly as a C symbol that follows the tvm-ffi ABI, -which can later be accessed via `tvm_ffi.load_module`. - -Here's a basic example of the function implementation: - -```c++ -void AddOne(DLTensor* x, DLTensor* y) { - // ... implementation omitted ... -} - -TVM_FFI_DLL_EXPORT_TYPED_FUNC(add_one, my_ffi_extension::AddOne); -``` - -We can also register a function into the global function table with a given name: - -```c++ -void RaiseError(ffi::String msg) { - TVM_FFI_THROW(RuntimeError) << msg; -} - -TVM_FFI_STATIC_INIT_BLOCK() { - namespace refl = tvm::ffi::reflection; - refl::GlobalDef() - .def("my_ffi_extension.raise_error", RaiseError); -} -``` - -Make sure to have a unique name across all registered functions when registering a global function. -Always prefix with a package namespace name to avoid name collisions. -The function can then be found via `tvm_ffi.get_global_func(name)` -and is expected to stay throughout the lifetime of the program. - -We recommend using `TVM_FFI_DLL_EXPORT_TYPED_FUNC` for functions that are supposed to be dynamically -loaded (such as JIT scenarios) so they won't be exposed to the global function table. - -## Library Loading in Python - -The base module handles loading the compiled extension: - -```python -import tvm_ffi -import os -import sys - -def _load_lib(): - file_dir = os.path.dirname(os.path.realpath(__file__)) - - # Platform-specific library names - if sys.platform.startswith("win32"): - lib_name = "my_ffi_extension.dll" - elif sys.platform.startswith("darwin"): - lib_name = "my_ffi_extension.dylib" - else: - lib_name = "my_ffi_extension.so" - - lib_path = os.path.join(file_dir, lib_name) - return tvm_ffi.load_module(lib_path) - -_LIB = _load_lib() -``` - -Effectively, it leverages the `tvm_ffi.load_module` call to load the library -extension DLL shipped along with the package. The `_ffi_api.py` contains a function -call to `tvm_ffi.init_ffi_api` that registers all global functions prefixed -with `my_ffi_extension` into the module. - -```python -# _ffi_api.py -import tvm_ffi -from .base import _LIB - -# Register all global functions prefixed with 'my_ffi_extension.' -# This makes functions registered via TVM_FFI_STATIC_INIT_BLOCK available -tvm_ffi.init_ffi_api("my_ffi_extension", __name__) -``` - -Then we can redirect the calls to the related functions. - -```python -from .base import _LIB -from . import _ffi_api - -def add_one(x, y): - # ... docstring omitted ... - return _LIB.add_one(x, y) - -def raise_error(msg): - # ... docstring omitted ... - return _ffi_api.raise_error(msg) -``` - -## Build and Use the Package - -First, build the wheel: -```bash -pip wheel -v -w dist . -``` - -Then install the built wheel: -```bash -pip install dist/*.whl -``` - -Then you can try it out: - -```python -import torch -import my_ffi_extension - -# Create input and output tensors -x = torch.tensor([1, 2, 3, 4, 5], dtype=torch.float32) -y = torch.empty_like(x) - -# Call the function -my_ffi_extension.add_one(x, y) -print(y) # Output: tensor([2., 3., 4., 5., 6.]) -``` - -You can also run the following command to see how errors are raised and propagated -across language boundaries: - -```python -python run_example.py raise_error -``` - -When possible, tvm-ffi will try to preserve tracebacks across language boundaries. You will see tracebacks like: -``` -File "src/extension.cc", line 45, in void my_ffi_extension::RaiseError(tvm::ffi::String) -``` - -## Wheel Auditing - -When using `auditwheel`, exclude `libtvm_ffi` as it will be shipped with the `tvm_ffi` package. - -```bash -auditwheel repair --exclude libtvm_ffi.so dist/*.whl -``` - -As long as you import `tvm_ffi` first before loading the library, the symbols will be available. diff --git a/ffi/docs/guides/python_guide.md b/ffi/docs/guides/python_guide.md deleted file mode 100644 index 0ab56eb9c461..000000000000 --- a/ffi/docs/guides/python_guide.md +++ /dev/null @@ -1,242 +0,0 @@ - - - - - - - - - - - - - - - - -# Python Guide - -This guide introduces the `tvm_ffi` Python package. -At a high level, the `tvm_ffi` Python package provides first-class Python support for - -- Pythonic classes to represent values in TVM FFI Any ABI. -- Mechanisms to call into TVM FFI ABI compatible functions. -- Conversion between Python values and `tvm_ffi` values. - -In this guide, we will run examples that make use of pre-registered testing functions in `tvm_ffi`. -If so, we will also briefly copy snippets that show the corresponding C++ behavior. - -## Load and Run Module - -The most common use case of TVM FFI is to load a runnable module and run the corresponding function. -You can follow the [quick start guide](../get_started/quick_start.md) for details on building the -library `build/add_one_cpu.so`. Let's walk through the load and run example again for NumPy - -```python -import tvm_ffi -import numpy as np - -# Load the compiled module -mod = tvm_ffi.load_module("build/add_one_cpu.so") - -# Create input and output arrays -x = np.array([1, 2, 3, 4, 5], dtype=np.float32) -y = np.empty_like(x) - -# Call the function -mod.add_one_cpu(x, y) -``` - -In this case, {py:func}`tvm_ffi.load_module` will return a {py:class}`tvm_ffi.Module` class that contains -the exported functions. You can access the functions by their names. - -## Tensor - -`tvm_ffi` provides a managed DLPack-compatible Tensor. - -```python -import numpy as np -import tvm_ffi - -# Demonstrate DLPack conversion between NumPy and TVM FFI -np_data = np.array([1, 2, 3, 4], dtype=np.float32) -tvm_array = tvm_ffi.from_dlpack(np_data) -# Convert back to NumPy -np_result = np.from_dlpack(tvm_array) -``` - -In most cases, however, you do not have to explicitly create Tensors. -The Python interface can take in `torch.Tensor` and `numpy.ndarray` objects -and automatically convert them to {py:class}`tvm_ffi.Tensor`. - -## Functions and Callbacks - -{py:class}`tvm_ffi.Function` provides the Python interface for `ffi::Function` in the C++. -You can retrieve globally registered functions via {py:func}`tvm_ffi.get_global_func`. - -```python -import tvm_ffi - -# testing.echo is defined and registered in C++ -# [](ffi::Any x) { return x; } -fecho = tvm_ffi.get_global_func("testing.echo") -assert fecho(1) == 1 -``` - -You can pass a Python function as an argument to another FFI function as callbacks. -Under the hood, {py:func}`tvm_ffi.convert` is called to convert the Python function into a -{py:class}`tvm_ffi.Function`. - -```python -import tvm_ffi - -# testing.apply is registered in C++ -# [](ffi::Function f, ffi::Any val) { return f(x); } -fapply = tvm_ffi.get_global_func("testing.apply") -# invoke fapply with lambda callback as f -assert fapply(lambda x: x + 1, 1) == 2 -``` - -This is a very powerful pattern that allows us to inject Python callbacks into the C++ code. -You can also register a Python callback as a global function. - -```python -import tvm_ffi - -@tvm_ffi.register_global_func("example.add_one") -def add_one(a): - return a + 1 - -assert tvm_ffi.get_global_func("example.add_one")(1) == 2 -``` - -## Container Types - -When an FFI function takes arguments from lists/tuples, they will be converted into {py:class}`tvm_ffi.Array`. - -```python -import tvm_ffi - -# Lists become Arrays -arr = tvm_ffi.convert([1, 2, 3, 4]) -assert isinstance(arr, tvm_ffi.Array) -assert len(arr) == 4 -assert arr[0] == 1 -``` - -Dictionaries will be converted to {py:class}`tvm_ffi.Map` - -```python -import tvm_ffi - -map_obj = tvm_ffi.convert({"a": 1, "b": 2}) -assert isinstance(map_obj, tvm_ffi.Map) -assert len(map_obj) == 2 -assert map_obj["a"] == 1 -assert map_obj["b"] == 2 -``` - -When container values are returned from FFI functions, they are also stored in these -types respectively. - - -## Error Handling - -An FFI function may raise an error. In such cases, the Python package will automatically -translate the error to the corresponding error kind in Python - -```python -import tvm_ffi - -# defined in C++ -# [](String kind, String msg) { throw Error(kind, msg, traceback); } -test_raise_error = tvm_ffi.get_global_func("testing.test_raise_error") - -test_raise_error("ValueError", "message") -``` -The above code shows an example where an error is raised in C++, resulting in the following error trace -``` -Traceback (most recent call last): -File "example.py", line 7, in - test_raise_error("ValueError", "message") - ~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^ -File "python/tvm_ffi/cython/function.pxi", line 325, in core.Function.__call__ - raise move_from_last_error().py_error() - ^^^ -File "src/ffi/extra/testing.cc", line 60, in void tvm::ffi::TestRaiseError(tvm::ffi::String, tvm::ffi::String) - throw ffi::Error(kind, msg, TVMFFITraceback(__FILE__, __LINE__, TVM_FFI_FUNC_SIG, 0)); -``` - -We register common error kinds. You can also register extra error dispatch via the {py:func}`tvm_ffi.register_error` function. - -## Advanced: Register Your Own Object - -For advanced use cases, you may want to register your own objects. This can be achieved through the -reflection registry in the TVM-FFI API. First, let's review the C++ side of the code. For this -example, you do not need to change the C++ side as this code is pre-shipped with the testing module of the `tvm_ffi` package. - -```cpp -#include - -// Step 1: Define the object class (stores the actual data) -class TestIntPairObj : public tvm::ffi::Object { -public: - int64_t a; - int64_t b; - - TestIntPairObj() = default; - TestIntPairObj(int64_t a, int64_t b) : a(a), b(b) {} - - // Required: declare type information -TVM_FFI_DECLARE_OBJECT_INFO_FINAL("testing.TestIntPair", TestIntPairObj, tvm::ffi::Object); -}; - -// Step 2: Define the reference wrapper (user-facing interface) -class TestIntPair : public tvm::ffi::ObjectRef { -public: - // Constructor - explicit TestIntPair(int64_t a, int64_t b) { - data_ = tvm::ffi::make_object(a, b); - } - - // Required: define object reference methods - TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(TestIntPair, tvm::ffi::ObjectRef, TestIntPairObj); -}; - -TVM_FFI_STATIC_INIT_BLOCK() { - namespace refl = tvm::ffi::reflection; - // register the object into the system - // register field accessors and a global static function `__create__` as ffi::Function - refl::ObjectDef() - .def_ro("a", &TestIntPairObj::a) - .def_ro("b", &TestIntPairObj::b) - .def_static("__create__", [](int64_t a, int64_t b) -> TestIntPair { - return TestIntPair(a, b); - }); -} -``` - -You can then create wrapper classes for objects that are in the library as follows: - -```python -import tvm_ffi - -# Register the class -@tvm_ffi.register_object("testing.TestIntPair") -class TestIntPair(tvm_ffi.Object): - def __init__(self, a, b): - # This is a special method to call an FFI function whose return - # value exactly initializes the object handle of the object - self.__init_handle_by_constructor__(TestIntPair.__create__, a, b) - -test_int_pair = TestIntPair(1, 2) -# We can access the fields by name -# The properties are populated by the reflection mechanism -assert test_int_pair.a == 1 -assert test_int_pair.b == 2 -``` -Under the hood, we leverage the information registered through the reflection registry to -generate efficient field accessors and methods for each class. - -Importantly, when you have multiple inheritance, you need to call {py:func}`tvm_ffi.register_object` -on both the base class and the child class. diff --git a/ffi/docs/index.rst b/ffi/docs/index.rst deleted file mode 100644 index 643ee417913d..000000000000 --- a/ffi/docs/index.rst +++ /dev/null @@ -1,53 +0,0 @@ -.. Licensed to the Apache Software Foundation (ASF) under one - or more contributor license agreements. See the NOTICE file - distributed with this work for additional information - regarding copyright ownership. The ASF licenses this file - to you under the Apache License, Version 2.0 (the - "License"); you may not use this file except in compliance - with the License. You may obtain a copy of the License at - -.. http://www.apache.org/licenses/LICENSE-2.0 - -.. Unless required by applicable law or agreed to in writing, - software distributed under the License is distributed on an - "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - KIND, either express or implied. See the License for the - specific language governing permissions and limitations - under the License. - -Apache TVM FFI Documentation -============================ - -Welcome to the documentation for TVM FFI. You can get started by reading the get started section, -or reading through the guides and concepts sections. - - -.. toctree:: - :maxdepth: 1 - :caption: Get Started - - get_started/install.md - get_started/quick_start.md - -.. toctree:: - :maxdepth: 1 - :caption: Guides - - guides/packaging.md - guides/cpp_guide.md - guides/python_guide.md - - -.. toctree:: - :maxdepth: 1 - :caption: Concepts - - concepts/abi_overview.md - - -.. toctree:: - :maxdepth: 1 - :caption: Reference - - reference/python/index.rst - reference/cpp/index.rst diff --git a/ffi/docs/reference/cpp/index.rst b/ffi/docs/reference/cpp/index.rst deleted file mode 100644 index ac9b1d73f9d3..000000000000 --- a/ffi/docs/reference/cpp/index.rst +++ /dev/null @@ -1,107 +0,0 @@ -.. Licensed to the Apache Software Foundation (ASF) under one - or more contributor license agreements. See the NOTICE file - distributed with this work for additional information - regarding copyright ownership. The ASF licenses this file - to you under the Apache License, Version 2.0 (the - "License"); you may not use this file except in compliance - with the License. You may obtain a copy of the License at - -.. http://www.apache.org/licenses/LICENSE-2.0 - -.. Unless required by applicable law or agreed to in writing, - software distributed under the License is distributed on an - "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - KIND, either express or implied. See the License for the - specific language governing permissions and limitations - under the License. - -C++ API -======= - -This page contains the API reference for the C++ API. The full API index below -can be a bit dense, so we recommend the following tips first: - -- Please read the :ref:`C++ Guide` for a high-level overview of the C++ API. - - - The C++ Guide and examples will likely be sufficient to get started with most use cases. - -- The :ref:`cpp-key-classes` lists the key classes that are most commonly used. -- You can go to the Full API Index at the bottom of this page to access the full list of APIs. - - - We usually group the APIs by files. You can look at the file hierarchy in the - full API index and navigate to the specific file to find the APIs in that file. - -Header Organization -------------------- - -The C++ APIs are organized into the following folders: - -.. list-table:: - :header-rows: 1 - :widths: 30 70 - - * - Folder - - Description - * - ``tvm/ffi/`` - - Core functionalities that support Function, Any, Object, etc. - * - ``tvm/ffi/container/`` - - Additional container types such as Array, Map, Shape, Tensor, Variant ... - * - ``tvm/ffi/reflection/`` - - Reflection support for function and type information registration. - * - ``tvm/ffi/extra/`` - - Extra APIs that are built on top. - - -.. _cpp-key-classes: - -Key Classes ------------ - -.. list-table:: - :header-rows: 1 - :widths: 30 70 - - * - Class - - Description - * - :cpp:class:`tvm::ffi::Function` - - Type-erased function that implements the ABI. - * - :cpp:class:`tvm::ffi::Any` - - Type-erased container for any supported value. - * - :cpp:class:`tvm::ffi::AnyView` - - Lightweight view of Any without ownership. - * - :cpp:class:`tvm::ffi::Object` - - Base class for all heap-allocated FFI objects. - * - :cpp:class:`tvm::ffi::ObjectRef` - - Reference class for objects. - * - :cpp:class:`tvm::ffi::Tensor` - - Multi-dimensional tensor with DLPack support. - * - :cpp:class:`tvm::ffi::Shape` - - Tensor shape container. - * - :cpp:class:`tvm::ffi::Module` - - Dynamic library module that can load exported functions. - * - :cpp:class:`tvm::ffi::String` - - String type for FFI. - * - :cpp:class:`tvm::ffi::Bytes` - - Byte array type. - * - :cpp:class:`tvm::ffi::Array` - - Dynamic array container. - * - :cpp:class:`tvm::ffi::Tuple` - - Heterogeneous tuple container. - * - :cpp:class:`tvm::ffi::Map` - - Key-value map container. - * - :cpp:class:`tvm::ffi::Optional` - - Optional value wrapper. - * - :cpp:class:`tvm::ffi::Variant` - - Type-safe union container. - - - -.. _cpp-full-api-index: - -Full API Index --------------- - -.. toctree:: - :maxdepth: 2 - - generated/index.rst diff --git a/ffi/docs/reference/python/index.rst b/ffi/docs/reference/python/index.rst deleted file mode 100644 index 13008089f3a9..000000000000 --- a/ffi/docs/reference/python/index.rst +++ /dev/null @@ -1,69 +0,0 @@ -.. Licensed to the Apache Software Foundation (ASF) under one - or more contributor license agreements. See the NOTICE file - distributed with this work for additional information - regarding copyright ownership. The ASF licenses this file - to you under the Apache License, Version 2.0 (the - "License"); you may not use this file except in compliance - with the License. You may obtain a copy of the License at - -.. http://www.apache.org/licenses/LICENSE-2.0 - -.. Unless required by applicable law or agreed to in writing, - software distributed under the License is distributed on an - "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - KIND, either express or implied. See the License for the - specific language governing permissions and limitations - under the License. - -Python API -========== - -.. automodule:: tvm_ffi - :no-members: - -.. currentmodule:: tvm_ffi - -Object ------- -.. autosummary:: - :toctree: generated/ - - Object - register_object - - -Function and Module -------------------- -.. autosummary:: - :toctree: generated/ - - - Function - Module - register_global_func - get_global_func - system_lib - load_module - init_ffi_api - register_error - convert - - -Tensor ------- -.. autosummary:: - :toctree: generated/ - - Shape - Tensor - Device - from_dlpack - - -Containers ----------- -.. autosummary:: - :toctree: generated/ - - Array - Map diff --git a/ffi/docs/requirements.txt b/ffi/docs/requirements.txt deleted file mode 100644 index 74784b5153a6..000000000000 --- a/ffi/docs/requirements.txt +++ /dev/null @@ -1,21 +0,0 @@ -autodocsumm -exhale -breathe -linkify-it-py -matplotlib -myst-parser -nbconvert -nbsphinx -nbstripout -sphinx -sphinx-autobuild -sphinx-book-theme -sphinx-copybutton -sphinx-reredirects==0.1.2 -sphinx-tabs == 3.4.1 -sphinx-toolbox == 3.4.0 -sphinxcontrib-mermaid -sphinxcontrib-napoleon==0.7 -sphinxcontrib_httpdomain==1.8.1 -tomli -urllib3>=2.5.0 diff --git a/ffi/examples/inline_module/main.py b/ffi/examples/inline_module/main.py deleted file mode 100644 index 5cfcd41bec12..000000000000 --- a/ffi/examples/inline_module/main.py +++ /dev/null @@ -1,87 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -import torch -import tvm_ffi.cpp -from tvm_ffi.module import Module - - -def main(): - mod: Module = tvm_ffi.cpp.load_inline( - name="hello", - cpp_sources=r""" - void add_one_cpu(DLTensor* x, DLTensor* y) { - // implementation of a library function - TVM_FFI_ICHECK(x->ndim == 1) << "x must be a 1D tensor"; - DLDataType f32_dtype{kDLFloat, 32, 1}; - TVM_FFI_ICHECK(x->dtype == f32_dtype) << "x must be a float tensor"; - TVM_FFI_ICHECK(y->ndim == 1) << "y must be a 1D tensor"; - TVM_FFI_ICHECK(y->dtype == f32_dtype) << "y must be a float tensor"; - TVM_FFI_ICHECK(x->shape[0] == y->shape[0]) << "x and y must have the same shape"; - for (int i = 0; i < x->shape[0]; ++i) { - static_cast(y->data)[i] = static_cast(x->data)[i] + 1; - } - } - - void add_one_cuda(DLTensor* x, DLTensor* y); - """, - cuda_sources=r""" - __global__ void AddOneKernel(float* x, float* y, int n) { - int idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx < n) { - y[idx] = x[idx] + 1; - } - } - - void add_one_cuda(DLTensor* x, DLTensor* y) { - // implementation of a library function - TVM_FFI_ICHECK(x->ndim == 1) << "x must be a 1D tensor"; - DLDataType f32_dtype{kDLFloat, 32, 1}; - TVM_FFI_ICHECK(x->dtype == f32_dtype) << "x must be a float tensor"; - TVM_FFI_ICHECK(y->ndim == 1) << "y must be a 1D tensor"; - TVM_FFI_ICHECK(y->dtype == f32_dtype) << "y must be a float tensor"; - TVM_FFI_ICHECK(x->shape[0] == y->shape[0]) << "x and y must have the same shape"; - - int64_t n = x->shape[0]; - int64_t nthread_per_block = 256; - int64_t nblock = (n + nthread_per_block - 1) / nthread_per_block; - // Obtain the current stream from the environment - // it will be set to torch.cuda.current_stream() when calling the function - // with torch.Tensors - cudaStream_t stream = static_cast( - TVMFFIEnvGetStream(x->device.device_type, x->device.device_id)); - // launch the kernel - AddOneKernel<<>>(static_cast(x->data), - static_cast(y->data), n); - } - """, - functions=["add_one_cpu", "add_one_cuda"], - ) - - x = torch.tensor([1, 2, 3, 4, 5], dtype=torch.float32) - y = torch.empty_like(x) - mod.add_one_cpu(x, y) - torch.testing.assert_close(x + 1, y) - - x_cuda = x.cuda() - y_cuda = torch.empty_like(x_cuda) - mod.add_one_cuda(x_cuda, y_cuda) - torch.testing.assert_close(x_cuda + 1, y_cuda) - - -if __name__ == "__main__": - main() diff --git a/ffi/examples/packaging/CMakeLists.txt b/ffi/examples/packaging/CMakeLists.txt deleted file mode 100644 index ed55f7ca33df..000000000000 --- a/ffi/examples/packaging/CMakeLists.txt +++ /dev/null @@ -1,73 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -cmake_minimum_required(VERSION 3.18) -project(my_ffi_extension) - -option(TVM_FFI_EXT_FROM_SOURCE "Build tvm_ffi from source, useful for cross compilation." ON) -option(TVM_FFI_EXT_SHIP_DEBUG_SYMBOLS "Ship debug symbols" ON) - -# There are two ways to include tvm_ffi -# -# 1. Build tvm_ffi from source, which is reasonably cheap since tvm ffi is small -# 2. Use the pre-built tvm_ffi shipped from the pip -# -# This example shows both options, you only need to pick a specific one. -# -# - For common build cases, using pre-built and link tvm_ffi_shared is sufficient. -# - For cases where you may want to cross-compile or bundle part of tvm_ffi_objects directly -# into your project, opt for building tvm_ffi from source path. -# Note that it is always safe to build from source and extra cost of building tvm_ffi is small. -# So when in doubt, you can always choose to the building tvm_ffi from source route. -# -# In python or other cases when we dynamically load libtvm_ffi_shared. Even when you build -# from source, you do not need to ship libtvm_ffi.so built here as they are only -# used to supply the linking information. -# first find python related components -find_package(Python COMPONENTS Interpreter REQUIRED) -if (TVM_FFI_BUILD_FROM_SOURCE) - execute_process( - COMMAND "${Python_EXECUTABLE}" -m tvm_ffi.config --sourcedir - OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE tvm_ffi_ROOT) - message(STATUS "Building tvm_ffi from source: ${tvm_ffi_ROOT}") - add_subdirectory(${tvm_ffi_ROOT} tvm_ffi) -else() - # call tvm_ffi.config to get the cmake directory and set it to tvm_ffi_ROOT - execute_process( - COMMAND "${Python_EXECUTABLE}" -m tvm_ffi.config --cmakedir - OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE tvm_ffi_ROOT) - find_package(tvm_ffi CONFIG REQUIRED) -endif() - -# use the projects as usual -add_library(my_ffi_extension SHARED src/extension.cc) -target_link_libraries(my_ffi_extension tvm_ffi_header) -target_link_libraries(my_ffi_extension tvm_ffi_shared) - -# show as my_ffi_extension.so -set_target_properties( - my_ffi_extension PROPERTIES PREFIX "" -) - -if (TVM_FFI_EXT_SHIP_DEBUG_SYMBOLS) - # ship debugging symbols for backtrace on macos - tvm_ffi_add_prefix_map(my_ffi_extension ${CMAKE_CURRENT_SOURCE_DIR}) - tvm_ffi_add_apple_dsymutil(my_ffi_extension) - install(DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/ DESTINATION . FILES_MATCHING PATTERN "*.dSYM") -endif() - -install(TARGETS my_ffi_extension DESTINATION .) diff --git a/ffi/examples/packaging/README.md b/ffi/examples/packaging/README.md deleted file mode 100644 index 25bcc1ca3c0b..000000000000 --- a/ffi/examples/packaging/README.md +++ /dev/null @@ -1,61 +0,0 @@ - - - - - - - - - - - - - - - - - -# TVM FFI Packaging Example - -This is an example project that packages a tvm-ffi based library -into a Python ABI-agnostic wheel. - -This example can also serve as a guideline for general -packaging as well. - -- Source-level build for cross-compilation support in CMake -- Registration via global function table - -## Install the wheel - -```bash -pip install . -``` - -### Note on build and auditwheel - -Note: When running the auditwheel process, make sure to skip -`libtvm_ffi.so` as they are shipped via the tvm_ffi package. - -## Run the example - -After installing the `my_ffi_extension` example package, you can run the following example -that invokes the `add_one` function exposed. - -```bash -python run_example.py add_one -``` - -You can also run the following command to see how error is raised and propagated -across the language boundaries. - -```python -python run_example.py raise_error -``` - -When possible, tvm_ffi will try to preserve traceback across language boundary. You will see traceback like -``` -File "src/extension.cc", line 45, in void my_ffi_extension::RaiseError(tvm::ffi::String) -``` -If you are in an IDE like VSCode, you can click and jump to the C++ lines of error when -the debug symbols are preserved. diff --git a/ffi/examples/packaging/pyproject.toml b/ffi/examples/packaging/pyproject.toml deleted file mode 100644 index 7825ca81ce98..000000000000 --- a/ffi/examples/packaging/pyproject.toml +++ /dev/null @@ -1,58 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -[project] -name = "my-ffi-extension" -version = "0.1.0" - -readme = "README.md" -license = { text = "Apache 2.0" } -classifiers = [ - "License :: OSI Approved :: Apache Software License", - "Development Status :: 4 - Beta", - "Intended Audience :: Developers", - "Intended Audience :: Education", - "Intended Audience :: Science/Research", -] -keywords = ["machine learning", "inference"] -requires-python = ">=3.9" - -dependencies = ["apache-tvm-ffi"] - -[build-system] -requires = ["scikit-build-core>=0.10.0", "apache-tvm-ffi"] -build-backend = "scikit_build_core.build" - -[tool.scikit-build] -# the wheel is abi agnostic -wheel.py-api = "py3" -minimum-version = "build-system.requires" - -# Build configuration -build-dir = "build" -build.verbose = true - -# CMake configuration -cmake.version = "CMakeLists.txt" -cmake.build-type = "RelWithDebugInfo" - -# Logging -logging.level = "INFO" - -# Wheel configuration -wheel.packages = ["python/my_ffi_extension"] -wheel.install-dir = "my_ffi_extension" diff --git a/ffi/examples/packaging/python/my_ffi_extension/__init__.py b/ffi/examples/packaging/python/my_ffi_extension/__init__.py deleted file mode 100644 index 4cd4207df136..000000000000 --- a/ffi/examples/packaging/python/my_ffi_extension/__init__.py +++ /dev/null @@ -1,48 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations. -from .base import _LIB -from . import _ffi_api - - -def add_one(x, y): - """ - Adds one to the input tensor. - - Parameters - ---------- - x : Tensor - The input tensor. - y : Tensor - The output tensor. - """ - return _LIB.add_one(x, y) - - -def raise_error(msg): - """ - Raises an error with the given message. - - Parameters - ---------- - msg : str - The message to raise the error with. - - Raises - ------ - RuntimeError - The error raised by the function. - """ - return _ffi_api.raise_error(msg) diff --git a/ffi/examples/packaging/python/my_ffi_extension/_ffi_api.py b/ffi/examples/packaging/python/my_ffi_extension/_ffi_api.py deleted file mode 100644 index 616b1ee8e80c..000000000000 --- a/ffi/examples/packaging/python/my_ffi_extension/_ffi_api.py +++ /dev/null @@ -1,24 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations. - -import tvm_ffi - -# make sure lib is loaded first -from .base import _LIB - -# this is a short cut to register all the global functions -# prefixed by `my_ffi_extension.` to this module -tvm_ffi.init_ffi_api("my_ffi_extension", __name__) diff --git a/ffi/examples/packaging/python/my_ffi_extension/base.py b/ffi/examples/packaging/python/my_ffi_extension/base.py deleted file mode 100644 index d65264eb7124..000000000000 --- a/ffi/examples/packaging/python/my_ffi_extension/base.py +++ /dev/null @@ -1,37 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations. -# Base logic to load library for extension package -import tvm_ffi -import os -import sys - - -def _load_lib(): - # first look at the directory of the current file - file_dir = os.path.dirname(os.path.realpath(__file__)) - - if sys.platform.startswith("win32"): - lib_dll_name = "my_ffi_extension.dll" - elif sys.platform.startswith("darwin"): - lib_dll_name = "my_ffi_extension.dylib" - else: - lib_dll_name = "my_ffi_extension.so" - - lib_path = os.path.join(file_dir, lib_dll_name) - return tvm_ffi.load_module(lib_path) - - -_LIB = _load_lib() diff --git a/ffi/examples/packaging/run_example.py b/ffi/examples/packaging/run_example.py deleted file mode 100644 index 11642257e8bc..000000000000 --- a/ffi/examples/packaging/run_example.py +++ /dev/null @@ -1,40 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations. -# Base logic to load library for extension package -import torch -import sys -import my_ffi_extension - - -def run_add_one(): - x = torch.tensor([1, 2, 3, 4, 5], dtype=torch.float32) - y = torch.empty_like(x) - my_ffi_extension.add_one(x, y) - print(y) - - -def run_raise_error(): - my_ffi_extension.raise_error("This is an error") - - -if __name__ == "__main__": - if len(sys.argv) > 1: - if sys.argv[1] == "add_one": - run_add_one() - elif sys.argv[1] == "raise_error": - run_raise_error() - else: - print("Usage: python run_example.py ") diff --git a/ffi/examples/packaging/src/extension.cc b/ffi/examples/packaging/src/extension.cc deleted file mode 100644 index 6a7324f4108e..000000000000 --- a/ffi/examples/packaging/src/extension.cc +++ /dev/null @@ -1,89 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -/*! - * \file example.cc - * \brief Example of a tvm-ffi based library that registers various functions. - * - * It is a simple example that demonstrates how to package a tvm-ffi library into a python wheel. - * The library is written in C++ and can be compiled into a shared library. - * The shared library can then be loaded into python and used to call the functions. - */ -#include -#include -#include -#include -#include - -namespace my_ffi_extension { - -namespace ffi = tvm::ffi; - -/*! - * \brief Raises a runtime error - * - * This is an example function to show how to raise and propagate - * an error across the language boundary. - * - * \param msg The message to raise the error with - */ -void RaiseError(ffi::String msg) { TVM_FFI_THROW(RuntimeError) << msg; } - -void AddOne(ffi::Tensor x, ffi::Tensor y) { - // implementation of a library function - TVM_FFI_ICHECK(x->ndim == 1) << "x must be a 1D tensor"; - DLDataType f32_dtype{kDLFloat, 32, 1}; - TVM_FFI_ICHECK(x->dtype == f32_dtype) << "x must be a float tensor"; - TVM_FFI_ICHECK(y->ndim == 1) << "y must be a 1D tensor"; - TVM_FFI_ICHECK(y->dtype == f32_dtype) << "y must be a float tensor"; - TVM_FFI_ICHECK(x->shape[0] == y->shape[0]) << "x and y must have the same shape"; - for (int i = 0; i < x->shape[0]; ++i) { - static_cast(y->data)[i] = static_cast(x->data)[i] + 1; - } -} - -// expose global symbol add_one -TVM_FFI_DLL_EXPORT_TYPED_FUNC(add_one, my_ffi_extension::AddOne); - -// The static initialization block is -// called once when the library is loaded. -TVM_FFI_STATIC_INIT_BLOCK() { - namespace refl = tvm::ffi::reflection; - // In this particular example, we use the reflection mechanisms to - // register the functions directly into the global function table. - // - // This is an alternative approach to TVM_FFI_DLL_EXPORT_TYPED_FUNC - // that exports the function directly as C symbol that follows tvm-ffi abi. - // - // - For functions that are expected to be static part of tvm_ffi_example project, - // one can use reflection mechanisms to register the globa function. - // - For functions that are compiled and dynamically loaded at runtime, consider - // using the normal export mechanism so they won't be exposed to the global function table. - // - // Make sure to have a unique name across all registered functions, - // always prefix with a package namespace name to avoid name collision. - // - // The function can then be found via tvm_ffi.get_global_func(name) - // If the function is expected to stay throughout the lifetime of the program/ - // - // When registering via reflection mechanisms, the library do not need to be loaded via - // tvm::ffi::Module::LoadFromFile, instead, just load the dll or simply bundle into the - // final project - refl::GlobalDef().def("my_ffi_extension.raise_error", RaiseError); -} -} // namespace my_ffi_extension diff --git a/ffi/examples/quick_start/CMakeLists.txt b/ffi/examples/quick_start/CMakeLists.txt deleted file mode 100644 index 05530988000e..000000000000 --- a/ffi/examples/quick_start/CMakeLists.txt +++ /dev/null @@ -1,65 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -cmake_minimum_required(VERSION 3.18) -project(tvm_ffi_example) - - -# first find python related components -find_package(Python COMPONENTS Interpreter REQUIRED) - -# call tvm_ffi.config to get the cmake directory and set it to tvm_ffi_ROOT -execute_process( - COMMAND "${Python_EXECUTABLE}" -m tvm_ffi.config --cmakedir - OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE tvm_ffi_ROOT) -# find package will automatically include the related projects -find_package(tvm_ffi CONFIG REQUIRED) - -# use the projects as usual -add_library(add_one_cpu SHARED src/add_one_cpu.cc) -target_link_libraries(add_one_cpu tvm_ffi_header) -target_link_libraries(add_one_cpu tvm_ffi_shared) -# show as add_one_cpu.so -set_target_properties( - add_one_cpu PROPERTIES - PREFIX "" - SUFFIX ".so" -) - -# Check if CUDA is available -if(NOT WIN32) - find_package(CUDA QUIET) - if(CUDA_FOUND) - enable_language(CUDA) - add_library(add_one_cuda SHARED src/add_one_cuda.cu) - target_link_libraries(add_one_cuda tvm_ffi_shared) - - # show as add_one_cuda.so - set_target_properties( - add_one_cuda PROPERTIES - PREFIX "" - SUFFIX ".so" - ) - endif() -endif() - -add_executable(run_example src/run_example.cc) -set_target_properties( - run_example PROPERTIES - CXX_STANDARD 17 -) -target_link_libraries(run_example tvm_ffi_shared) diff --git a/ffi/examples/quick_start/README.md b/ffi/examples/quick_start/README.md deleted file mode 100644 index 002d4375a6dc..000000000000 --- a/ffi/examples/quick_start/README.md +++ /dev/null @@ -1,58 +0,0 @@ - - - - - - - - - - - - - - - - - -# Getting Started with TVM FFI - -This example demonstrates how to use tvm-ffi to expose a universal function -that can be loaded in different environments. - -The example implements a simple "add one" operation that adds 1 to each element -of an input tensor, showing how to create C++ functions callable from Python. - -You can run this quick start example by: - -```bash -# ensure you installed tvm-ffi first -pip install -e ../.. - -# Build and run the complete example -./run_example.sh -``` - -At a high level, the `TVM_FFI_DLL_EXPORT_TYPED_FUNC` macro helps to expose -a C++ function into the TVM FFI C ABI convention for functions. -Then the function can be accessed by different environments and languages -that interface with the TVM FFI. The current example shows how to do so -in Python and C++. - -## Key Files - -- `src/add_one_cpu.cc` - CPU implementation of the add_one function -- `src/add_one_cuda.cu` - CUDA implementation for GPU operations -- `run_example.py` - Python example showing how to call the functions -- `run_example.cc` - C++ example demonstrating the same functionality - -## Compile without CMake - -You can also compile the modules directly using -flags provided by the `tvm-ffi-config` tool. - -```bash -g++ -shared -fPIC `tvm-ffi-config --cxxflags` \ - src/add_one_cpu.cc -o build/add_one_cpu.so \ - `tvm-ffi-config --ldflags` `tvm-ffi-config --libs` -``` diff --git a/ffi/examples/quick_start/run_example.py b/ffi/examples/quick_start/run_example.py deleted file mode 100644 index a8f4fc00a600..000000000000 --- a/ffi/examples/quick_start/run_example.py +++ /dev/null @@ -1,82 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -import tvm_ffi - -try: - import torch -except ImportError: - torch = None - -import numpy -import ctypes - - -def run_add_one_cpu(): - """Load the add_one_cpu module and call the add_one_cpu function.""" - mod = tvm_ffi.load_module("build/add_one_cpu.so") - - x = numpy.array([1, 2, 3, 4, 5], dtype=numpy.float32) - y = numpy.empty_like(x) - # tvm-ffi automatically handles DLPack compatible tensors - # torch tensors can be viewed as ffi::Tensor or DLTensor* - # in the background - mod.add_one_cpu(x, y) - print("numpy.result after add_one(x, y)") - print(x) - - if torch is None: - return - - x = torch.tensor([1, 2, 3, 4, 5], dtype=torch.float32) - y = torch.empty_like(x) - # tvm-ffi automatically handles DLPack compatible tensors - # torch tensors can be viewed as ffi::Tensor or DLTensor* - # in the background - mod.add_one_cpu(x, y) - print("torch.result after add_one(x, y)") - print(y) - - -def run_add_one_cuda(): - """Load the add_one_cuda module and call the add_one_cuda function.""" - if torch is None or not torch.cuda.is_available(): - return - - mod = tvm_ffi.load_module("build/add_one_cuda.so") - x = torch.tensor([1, 2, 3, 4, 5], dtype=torch.float32, device="cuda") - y = torch.empty_like(x) - - stream = torch.cuda.Stream() - with torch.cuda.stream(stream): - # tvm-ffi automatically handles DLPack compatible tensors - # it also handles interactions with torch runtime - # torch.cuda.current_stream() will be set and available via TVMFFIEnvGetStream - # when calling the function - mod.add_one_cuda(x, y) - stream.synchronize() - print("torch.result after mod.add_one_cuda(x, y)") - print(y) - - -def main(): - """Main function to run the example.""" - run_add_one_cpu() - run_add_one_cuda() - - -if __name__ == "__main__": - main() diff --git a/ffi/examples/quick_start/run_example.sh b/ffi/examples/quick_start/run_example.sh deleted file mode 100755 index 0602b85f3718..000000000000 --- a/ffi/examples/quick_start/run_example.sh +++ /dev/null @@ -1,27 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -#!/bin/bash -set -ex - -cmake -B build -S . -cmake --build build - -# running python example -python run_example.py - -# running c++ example -./build/run_example diff --git a/ffi/examples/quick_start/src/add_one_cpu.cc b/ffi/examples/quick_start/src/add_one_cpu.cc deleted file mode 100644 index 76b9b3752c88..000000000000 --- a/ffi/examples/quick_start/src/add_one_cpu.cc +++ /dev/null @@ -1,41 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -#include -#include -#include -#include - -namespace tvm_ffi_example { - -void AddOne(tvm::ffi::Tensor x, tvm::ffi::Tensor y) { - // implementation of a library function - TVM_FFI_ICHECK(x->ndim == 1) << "x must be a 1D tensor"; - DLDataType f32_dtype{kDLFloat, 32, 1}; - TVM_FFI_ICHECK(x->dtype == f32_dtype) << "x must be a float tensor"; - TVM_FFI_ICHECK(y->ndim == 1) << "y must be a 1D tensor"; - TVM_FFI_ICHECK(y->dtype == f32_dtype) << "y must be a float tensor"; - TVM_FFI_ICHECK(x->shape[0] == y->shape[0]) << "x and y must have the same shape"; - for (int i = 0; i < x->shape[0]; ++i) { - static_cast(y->data)[i] = static_cast(x->data)[i] + 1; - } -} - -// Expose global symbol `add_one_cpu` that follows tvm-ffi abi -TVM_FFI_DLL_EXPORT_TYPED_FUNC(add_one_cpu, tvm_ffi_example::AddOne); -} // namespace tvm_ffi_example diff --git a/ffi/examples/quick_start/src/add_one_cuda.cu b/ffi/examples/quick_start/src/add_one_cuda.cu deleted file mode 100644 index 52f1e7482505..000000000000 --- a/ffi/examples/quick_start/src/add_one_cuda.cu +++ /dev/null @@ -1,58 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -#include -#include -#include -#include -#include - -namespace tvm_ffi_example { - -__global__ void AddOneKernel(float* x, float* y, int n) { - int idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx < n) { - y[idx] = x[idx] + 1; - } -} - -void AddOneCUDA(tvm::ffi::Tensor x, tvm::ffi::Tensor y) { - // implementation of a library function - TVM_FFI_ICHECK(x->ndim == 1) << "x must be a 1D tensor"; - DLDataType f32_dtype{kDLFloat, 32, 1}; - TVM_FFI_ICHECK(x->dtype == f32_dtype) << "x must be a float tensor"; - TVM_FFI_ICHECK(y->ndim == 1) << "y must be a 1D tensor"; - TVM_FFI_ICHECK(y->dtype == f32_dtype) << "y must be a float tensor"; - TVM_FFI_ICHECK(x->shape[0] == y->shape[0]) << "x and y must have the same shape"; - - int64_t n = x->shape[0]; - int64_t nthread_per_block = 256; - int64_t nblock = (n + nthread_per_block - 1) / nthread_per_block; - // Obtain the current stream from the environment - // it will be set to torch.cuda.current_stream() when calling the function - // with torch.Tensors - cudaStream_t stream = - static_cast(TVMFFIEnvGetStream(x->device.device_type, x->device.device_id)); - // launch the kernel - AddOneKernel<<>>(static_cast(x->data), - static_cast(y->data), n); -} - -// Expose global symbol `add_one_cpu` that follows tvm-ffi abi -TVM_FFI_DLL_EXPORT_TYPED_FUNC(add_one_cuda, tvm_ffi_example::AddOneCUDA); -} // namespace tvm_ffi_example diff --git a/ffi/examples/quick_start/src/run_example.cc b/ffi/examples/quick_start/src/run_example.cc deleted file mode 100644 index 90e61d170baa..000000000000 --- a/ffi/examples/quick_start/src/run_example.cc +++ /dev/null @@ -1,53 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -#include -#include - -// This file shows how to load the same compiled module and interact with it in C++ -namespace ffi = tvm::ffi; - -struct CPUNDAlloc { - void AllocData(DLTensor* tensor) { tensor->data = malloc(ffi::GetDataSize(*tensor)); } - void FreeData(DLTensor* tensor) { free(tensor->data); } -}; - -inline ffi::Tensor Empty(ffi::Shape shape, DLDataType dtype, DLDevice device) { - return ffi::Tensor::FromNDAlloc(CPUNDAlloc(), shape, dtype, device); -} - -int main() { - // load the module - ffi::Module mod = ffi::Module::LoadFromFile("build/add_one_cpu.so"); - - // create an Tensor, alternatively, one can directly pass in a DLTensor* - ffi::Tensor x = Empty({5}, DLDataType({kDLFloat, 32, 1}), DLDevice({kDLCPU, 0})); - for (int i = 0; i < 5; ++i) { - reinterpret_cast(x->data)[i] = static_cast(i); - } - - ffi::Function add_one_cpu = mod->GetFunction("add_one_cpu").value(); - add_one_cpu(x, x); - - std::cout << "x after add_one_cpu(x, x)" << std::endl; - for (int i = 0; i < 5; ++i) { - std::cout << reinterpret_cast(x->data)[i] << " "; - } - std::cout << std::endl; - return 0; -} diff --git a/ffi/include/tvm/ffi/any.h b/ffi/include/tvm/ffi/any.h deleted file mode 100644 index 738adc4f86ea..000000000000 --- a/ffi/include/tvm/ffi/any.h +++ /dev/null @@ -1,692 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -/*! - * \file tvm/ffi/any.h - * \brief Any value support. - */ -#ifndef TVM_FFI_ANY_H_ -#define TVM_FFI_ANY_H_ - -#include -#include -#include - -#include -#include - -namespace tvm { -namespace ffi { - -class Any; - -namespace details { -// Helper to perform -// unsafe operations related to object -struct AnyUnsafe; -} // namespace details - -/*! - * \brief AnyView allows us to take un-managed reference view of any value. - */ -class AnyView { - protected: - /*! \brief The underlying backing data of the any object */ - TVMFFIAny data_; - // Any can see AnyView - friend class Any; - - public: - // NOTE: the following functions use style - // since they are common functions appearing in FFI. - /*! - * \brief Reset any view to None - */ - void reset() { - data_.type_index = TypeIndex::kTVMFFINone; - // invariance: always set the union padding part to 0 - data_.zero_padding = 0; - data_.v_int64 = 0; - } - /*! - * \brief Swap this AnyView with another AnyView - * \param other The other AnyView - */ - TVM_FFI_INLINE void swap(AnyView& other) noexcept { std::swap(data_, other.data_); } - /*! \return the internal type index */ - TVM_FFI_INLINE int32_t type_index() const noexcept { return data_.type_index; } - /*! \brief Default constructor */ - AnyView() { - data_.type_index = TypeIndex::kTVMFFINone; - data_.zero_padding = 0; - data_.v_int64 = 0; - } - ~AnyView() = default; - // constructors from any view - /*! \brief Copy constructor */ - AnyView(const AnyView&) = default; - /*! \brief Copy assignment operator */ - AnyView& operator=(const AnyView&) = default; - /*! \brief Move constructor */ - AnyView(AnyView&& other) : data_(other.data_) { - other.data_.type_index = TypeIndex::kTVMFFINone; - other.data_.zero_padding = 0; - other.data_.v_int64 = 0; - } - TVM_FFI_INLINE AnyView& operator=(AnyView&& other) { - // copy-and-swap idiom - AnyView(std::move(other)).swap(*this); // NOLINT(*) - return *this; - } - /*! - * \brief Constructor from a general type. - * \tparam T The type to convert from. - * \param other The value to convert from. - */ - template ::convert_enabled>> - AnyView(const T& other) { // NOLINT(*) - TypeTraits::CopyToAnyView(other, &data_); - } - /*! - * \brief Assign from a general type. - * \tparam T The type to convert from. - * \param other The value to convert from. - */ - template ::convert_enabled>> - TVM_FFI_INLINE AnyView& operator=(const T& other) { // NOLINT(*) - // copy-and-swap idiom - AnyView(other).swap(*this); // NOLINT(*) - return *this; - } - - /*! - * \brief Try to see if we can reinterpret the AnyView to as T object. - * - * \tparam T The type to cast to. - * \return The casted value, or std::nullopt if the cast is not possible. - * \note This function won't try run type conversion (use try_cast for that purpose). - */ - template ::convert_enabled>> - TVM_FFI_INLINE std::optional as() const { - if (TypeTraits::CheckAnyStrict(&data_)) { - return TypeTraits::CopyFromAnyViewAfterCheck(&data_); - } else { - return std::optional(std::nullopt); - } - } - /*! - * \brief Shortcut of as Object to cast to a const pointer when T is an Object. - * - * \tparam T The object type. - * \return The requested pointer, returns nullptr if type mismatches. - */ - template >> - TVM_FFI_INLINE const T* as() const { - return this->as().value_or(nullptr); - } - - /*! - * \brief Cast to a type T. - * - * \tparam T The type to cast to. - * \return The casted value, or throws an exception if the cast is not possible. - */ - template ::convert_enabled>> - TVM_FFI_INLINE T cast() const { - std::optional opt = TypeTraits::TryCastFromAnyView(&data_); - if (!opt.has_value()) { - TVM_FFI_THROW(TypeError) << "Cannot convert from type `" - << TypeTraits::GetMismatchTypeInfo(&data_) << "` to `" - << TypeTraits::TypeStr() << "`"; - } - return *std::move(opt); - } - - /*! - * \brief Try to cast to a type T, return std::nullopt if the cast is not possible. - * - * \tparam T The type to cast to. - * \return The casted value, or std::nullopt if the cast is not possible. - */ - template ::convert_enabled>> - TVM_FFI_INLINE std::optional try_cast() const { - return TypeTraits::TryCastFromAnyView(&data_); - } - - // comparison with nullptr - TVM_FFI_INLINE bool operator==(std::nullptr_t) const noexcept { - return data_.type_index == TypeIndex::kTVMFFINone; - } - TVM_FFI_INLINE bool operator!=(std::nullptr_t) const noexcept { - return data_.type_index != TypeIndex::kTVMFFINone; - } - /*! - * \brief Get the type key of the Any - * \return The type key of the Any - */ - TVM_FFI_INLINE std::string GetTypeKey() const { return TypeIndexToTypeKey(data_.type_index); } - // The following functions are only used for testing purposes - /*! - * \return The underlying supporting data of any view - * \note This function is used only for testing purposes. - */ - TVM_FFI_INLINE TVMFFIAny CopyToTVMFFIAny() const { return data_; } - /*! - * \return Create an AnyView from TVMFFIAny - * \param data the underlying ffi data. - */ - TVM_FFI_INLINE static AnyView CopyFromTVMFFIAny(TVMFFIAny data) { - AnyView view; - view.data_ = data; - return view; - } -}; - -namespace details { -/*! - * \brief Helper function to inplace convert any view to any. - * \param data The pointer that represents the format as any view. - * \param extra_any_bytes Indicate that the data may contain extra bytes following - * the TVMFFIAny data structure. This is reserved for future possible optimizations - * of small-string and extended any object. - */ -TVM_FFI_INLINE void InplaceConvertAnyViewToAny(TVMFFIAny* data, - [[maybe_unused]] size_t extra_any_bytes = 0) { - if (data->type_index >= TVMFFITypeIndex::kTVMFFIStaticObjectBegin) { - details::ObjectUnsafe::IncRefObjectHandle(data->v_obj); - } else if (data->type_index >= TypeIndex::kTVMFFIRawStr) { - if (data->type_index == TypeIndex::kTVMFFIRawStr) { - // convert raw string to owned string object - String temp(data->v_c_str); - TypeTraits::MoveToAny(std::move(temp), data); - } else if (data->type_index == TypeIndex::kTVMFFIByteArrayPtr) { - // convert byte array to owned bytes object - Bytes temp(*static_cast(data->v_ptr)); - TypeTraits::MoveToAny(std::move(temp), data); - } else if (data->type_index == TypeIndex::kTVMFFIObjectRValueRef) { - // convert rvalue ref to owned object - Object** obj_addr = static_cast(data->v_ptr); - TVM_FFI_ICHECK(obj_addr[0] != nullptr) << "RValueRef already moved"; - ObjectRef temp(details::ObjectUnsafe::ObjectPtrFromOwned(obj_addr[0])); - // set the rvalue ref to nullptr to avoid double move - obj_addr[0] = nullptr; - TypeTraits::MoveToAny(std::move(temp), data); - } - } -} -} // namespace details - -/*! - * \brief Managed Any that takes strong reference to a value. - * - * \note Develooper invariance: the TVMFFIAny data_ - * in the Any can be safely used in AnyView. - */ -class Any { - protected: - /*! \brief The underlying backing data of the any object */ - TVMFFIAny data_; - - public: - /*! - * \brief Reset any to None - */ - TVM_FFI_INLINE void reset() { - if (data_.type_index >= TVMFFITypeIndex::kTVMFFIStaticObjectBegin) { - details::ObjectUnsafe::DecRefObjectHandle(data_.v_obj); - } - data_.type_index = TVMFFITypeIndex::kTVMFFINone; - data_.zero_padding = 0; - data_.v_int64 = 0; - } - /*! - * \brief Swap this Any with another Any - * \param other The other Any - */ - TVM_FFI_INLINE void swap(Any& other) noexcept { std::swap(data_, other.data_); } - /*! \return the internal type index */ - TVM_FFI_INLINE int32_t type_index() const noexcept { return data_.type_index; } - /*! - * \brief Default constructor - */ - Any() { - data_.type_index = TypeIndex::kTVMFFINone; - data_.zero_padding = 0; - data_.v_int64 = 0; - } - /*! - * \brief Destructor - */ - ~Any() { this->reset(); } - /*! - * \brief Constructor from another Any - * \param other The other Any - */ - Any(const Any& other) : data_(other.data_) { - if (data_.type_index >= TypeIndex::kTVMFFIStaticObjectBegin) { - details::ObjectUnsafe::IncRefObjectHandle(data_.v_obj); - } - } - /*! - * \brief Move constructor from another Any - * \param other The other Any - */ - Any(Any&& other) : data_(other.data_) { - other.data_.type_index = TypeIndex::kTVMFFINone; - other.data_.zero_padding = 0; - other.data_.v_int64 = 0; - } - /*! - * \brief Assign from another Any - * \param other The other Any - */ - TVM_FFI_INLINE Any& operator=(const Any& other) { - // copy-and-swap idiom - Any(other).swap(*this); // NOLINT(*) - return *this; - } - /*! - * \brief Move assign from another Any - * \param other The other Any - */ - TVM_FFI_INLINE Any& operator=(Any&& other) { - // copy-and-swap idiom - Any(std::move(other)).swap(*this); // NOLINT(*) - return *this; - } - /*! - * \brief Constructor from another AnyView - * \param other The other AnyView - */ - Any(const AnyView& other) : data_(other.data_) { // NOLINT(*) - details::InplaceConvertAnyViewToAny(&data_); - } - /*! - * \brief Assign from another AnyView - * \param other The other AnyView - */ - TVM_FFI_INLINE Any& operator=(const AnyView& other) { - // copy-and-swap idiom - Any(other).swap(*this); // NOLINT(*) - return *this; - } - /*! \brief Any can be converted to AnyView in zero cost. */ - operator AnyView() const { return AnyView::CopyFromTVMFFIAny(data_); } - /*! - * \brief Constructor from a general type - * \tparam T The value type of the other - */ - template ::convert_enabled>> - Any(T other) { // NOLINT(*) - TypeTraits::MoveToAny(std::move(other), &data_); - } - /*! - * \brief Assignment from a general type - * \tparam T The value type of the other - */ - template ::convert_enabled>> - TVM_FFI_INLINE Any& operator=(T other) { // NOLINT(*) - // copy-and-swap idiom - Any(std::move(other)).swap(*this); // NOLINT(*) - return *this; - } - - /** - * \brief Try to reinterpret the Any as a type T, return std::nullopt if it is not possible. - * - * \tparam T The type to cast to. - * \return The casted value, or std::nullopt if the cast is not possible. - * \note This function won't try to run type conversion (use try_cast for that purpose). - */ - template ::storage_enabled || std::is_same_v>> - TVM_FFI_INLINE std::optional as() && { - if constexpr (std::is_same_v) { - return std::move(*this); - } else { - if (TypeTraits::CheckAnyStrict(&data_)) { - return TypeTraits::MoveFromAnyAfterCheck(&data_); - } else { - return std::optional(std::nullopt); - } - } - } - - /** - * \brief Try to reinterpret the Any as a type T, return std::nullopt if it is not possible. - * - * \tparam T The type to cast to. - * \return The casted value, or std::nullopt if the cast is not possible. - * \note This function won't try to run type conversion (use try_cast for that purpose). - */ - template ::convert_enabled || std::is_same_v>> - TVM_FFI_INLINE std::optional as() const& { - if constexpr (std::is_same_v) { - return *this; - } else { - if (TypeTraits::CheckAnyStrict(&data_)) { - return TypeTraits::CopyFromAnyViewAfterCheck(&data_); - } else { - return std::optional(std::nullopt); - } - } - } - - /*! - * \brief Shortcut of as Object to cast to a const pointer when T is an Object. - * - * \tparam T The object type. - * \return The requested pointer, returns nullptr if type mismatches. - */ - template >> - TVM_FFI_INLINE const T* as() const& { - return this->as().value_or(nullptr); - } - - /** - * \brief Cast to a type T, throw an exception if the cast is not possible. - * - * \tparam T The type to cast to. - */ - template ::convert_enabled>> - TVM_FFI_INLINE T cast() const& { - std::optional opt = TypeTraits::TryCastFromAnyView(&data_); - if (!opt.has_value()) { - TVM_FFI_THROW(TypeError) << "Cannot convert from type `" - << TypeTraits::GetMismatchTypeInfo(&data_) << "` to `" - << TypeTraits::TypeStr() << "`"; - } - return *std::move(opt); - } - - /** - * \brief Cast to a type T, throw an exception if the cast is not possible. - * - * \tparam T The type to cast to. - */ - template ::storage_enabled>> - TVM_FFI_INLINE T cast() && { - if (TypeTraits::CheckAnyStrict(&data_)) { - return TypeTraits::MoveFromAnyAfterCheck(&data_); - } - // slow path, try to do fallback convert - std::optional opt = TypeTraits::TryCastFromAnyView(&data_); - if (!opt.has_value()) { - TVM_FFI_THROW(TypeError) << "Cannot convert from type `" - << TypeTraits::GetMismatchTypeInfo(&data_) << "` to `" - << TypeTraits::TypeStr() << "`"; - } - return *std::move(opt); - } - - /** - * \brief Try to cast to a type T. - * - * \tparam T The type to cast to. - * \return The casted value, or std::nullopt if the cast is not possible. - * \note use STL name since it to be more consistent with cast API. - */ - template ::convert_enabled || std::is_same_v>> - TVM_FFI_INLINE std::optional try_cast() const { - if constexpr (std::is_same_v) { - return *this; - } else { - return TypeTraits::TryCastFromAnyView(&data_); - } - } - /*! - * \brief Check if the two Any are same type and value in shallow comparison. - * \param other The other Any - * \return True if the two Any are same type and value, false otherwise. - */ - TVM_FFI_INLINE bool same_as(const Any& other) const noexcept { - return data_.type_index == other.data_.type_index && - data_.zero_padding == other.data_.zero_padding && data_.v_int64 == other.data_.v_int64; - } - - /*! - * \brief Check if any and ObjectRef are same type and value in shallow comparison. - * \param other The other ObjectRef - * \return True if the two Any are same type and value, false otherwise. - */ - TVM_FFI_INLINE bool same_as(const ObjectRef& other) const noexcept { - if (other.get() != nullptr) { - return (data_.type_index == other->type_index() && - reinterpret_cast(data_.v_obj) == other.get()); - } else { - return data_.type_index == TypeIndex::kTVMFFINone; - } - } - - TVM_FFI_INLINE bool operator==(std::nullptr_t) const noexcept { - return data_.type_index == TypeIndex::kTVMFFINone; - } - TVM_FFI_INLINE bool operator!=(std::nullptr_t) const noexcept { - return data_.type_index != TypeIndex::kTVMFFINone; - } - - /*! - * \brief Get the type key of the Any - * \return The type key of the Any - */ - TVM_FFI_INLINE std::string GetTypeKey() const { return TypeIndexToTypeKey(data_.type_index); } - - friend struct details::AnyUnsafe; - friend struct AnyHash; - friend struct AnyEqual; -}; - -// layout assert to ensure we can freely cast between the two types -static_assert(sizeof(AnyView) == sizeof(TVMFFIAny)); -static_assert(sizeof(Any) == sizeof(TVMFFIAny)); - -namespace details { - -template -struct Type2Str { - static std::string v() { return TypeTraitsNoCR::TypeStr(); } -}; - -template <> -struct Type2Str { - static std::string v() { return "Any"; } -}; - -template <> -struct Type2Str { - static std::string v() { return "Any"; } -}; - -template <> -struct Type2Str { - static std::string v() { return "AnyView"; } -}; - -template <> -struct Type2Str { - static std::string v() { return "AnyView"; } -}; - -template <> -struct Type2Str { - static std::string v() { return "void"; } -}; - -// Extra unsafe method to help any manipulation -struct AnyUnsafe : public ObjectUnsafe { - // FFI related operations - TVM_FFI_INLINE static TVMFFIAny MoveAnyToTVMFFIAny(Any&& any) { - TVMFFIAny result = any.data_; - any.data_.type_index = TypeIndex::kTVMFFINone; - any.data_.zero_padding = 0; - any.data_.v_int64 = 0; - return result; - } - - TVM_FFI_INLINE static Any MoveTVMFFIAnyToAny(TVMFFIAny&& data) { - Any any; - any.data_ = data; - data.type_index = TypeIndex::kTVMFFINone; - data.zero_padding = 0; - data.v_int64 = 0; - return any; - } - - template - TVM_FFI_INLINE static bool CheckAnyStrict(const Any& ref) { - return TypeTraits::CheckAnyStrict(&(ref.data_)); - } - - template - TVM_FFI_INLINE static T CopyFromAnyViewAfterCheck(const Any& ref) { - if constexpr (!std::is_same_v) { - return TypeTraits::CopyFromAnyViewAfterCheck(&(ref.data_)); - } else { - return ref; - } - } - - template - TVM_FFI_INLINE static T MoveFromAnyAfterCheck(Any&& ref) { - if constexpr (!std::is_same_v) { - return TypeTraits::MoveFromAnyAfterCheck(&(ref.data_)); - } else { - return std::move(ref); - } - } - - TVM_FFI_INLINE static Object* ObjectPtrFromAnyAfterCheck(const Any& ref) { - return reinterpret_cast(ref.data_.v_obj); - } - - TVM_FFI_INLINE static const TVMFFIAny* TVMFFIAnyPtrFromAny(const Any& ref) { - return &(ref.data_); - } - - template - TVM_FFI_INLINE static std::string GetMismatchTypeInfo(const Any& ref) { - return TypeTraits::GetMismatchTypeInfo(&(ref.data_)); - } -}; -} // namespace details - -/*! \brief String-aware Any equal functor */ -struct AnyHash { - /*! - * \brief Calculate the hash code of an Any - * \param a The given Any - * \return Hash code of a, string hash for strings and pointer address otherwise. - */ - uint64_t operator()(const Any& src) const { - if (src.data_.type_index == TypeIndex::kTVMFFISmallStr) { - // for small string, we use the same type key hash as normal string - // so heap allocated string and on stack string will have the same hash - return details::StableHashCombine(TypeIndex::kTVMFFIStr, - details::StableHashSmallStrBytes(&src.data_)); - } else if (src.data_.type_index == TypeIndex::kTVMFFISmallBytes) { - // use byte the same type key as bytes - return details::StableHashCombine(TypeIndex::kTVMFFIBytes, - details::StableHashSmallStrBytes(&src.data_)); - } else if (src.data_.type_index == TypeIndex::kTVMFFIStr || - src.data_.type_index == TypeIndex::kTVMFFIBytes) { - const details::BytesObjBase* src_str = - details::AnyUnsafe::CopyFromAnyViewAfterCheck(src); - return details::StableHashCombine(src.data_.type_index, - details::StableHashBytes(src_str->data, src_str->size)); - } else { - return details::StableHashCombine(src.data_.type_index, src.data_.v_uint64); - } - } -}; - -/*! \brief String-aware Any hash functor */ -struct AnyEqual { - /*! - * \brief Check if the two Any are equal - * \param lhs left operand. - * \param rhs right operand - * \return String equality if both are strings, pointer address equality otherwise. - */ - bool operator()(const Any& lhs, const Any& rhs) const { - // header with type index - const int64_t* lhs_as_int64 = reinterpret_cast(&lhs.data_); - const int64_t* rhs_as_int64 = reinterpret_cast(&rhs.data_); - static_assert(sizeof(TVMFFIAny) == 16); - // fast path, check byte equality - if (lhs_as_int64[0] == rhs_as_int64[0] && lhs_as_int64[1] == rhs_as_int64[1]) { - return true; - } - // common false case type index match, in this case we only need to pay attention to string - // equality - if (lhs.data_.type_index == rhs.data_.type_index) { - // specialy handle string hash - if (lhs.data_.type_index == TypeIndex::kTVMFFIStr || - lhs.data_.type_index == TypeIndex::kTVMFFIBytes) { - const details::BytesObjBase* lhs_str = - details::AnyUnsafe::CopyFromAnyViewAfterCheck(lhs); - const details::BytesObjBase* rhs_str = - details::AnyUnsafe::CopyFromAnyViewAfterCheck(rhs); - return Bytes::memequal(lhs_str->data, rhs_str->data, lhs_str->size, rhs_str->size); - } - return false; - } else { - // type_index mismatch, if index is not string, return false - if (lhs.data_.type_index != kTVMFFIStr && lhs.data_.type_index != kTVMFFISmallStr && - lhs.data_.type_index != kTVMFFISmallBytes && lhs.data_.type_index != kTVMFFIBytes) { - return false; - } - // small string and normal string comparison - if (lhs.data_.type_index == kTVMFFIStr && rhs.data_.type_index == kTVMFFISmallStr) { - const details::BytesObjBase* lhs_str = - details::AnyUnsafe::CopyFromAnyViewAfterCheck(lhs); - return Bytes::memequal(lhs_str->data, rhs.data_.v_bytes, lhs_str->size, - rhs.data_.small_str_len); - } - if (lhs.data_.type_index == kTVMFFISmallStr && rhs.data_.type_index == kTVMFFIStr) { - const details::BytesObjBase* rhs_str = - details::AnyUnsafe::CopyFromAnyViewAfterCheck(rhs); - return Bytes::memequal(lhs.data_.v_bytes, rhs_str->data, lhs.data_.small_str_len, - rhs_str->size); - } - if (lhs.data_.type_index == kTVMFFIBytes && rhs.data_.type_index == kTVMFFISmallBytes) { - const details::BytesObjBase* lhs_bytes = - details::AnyUnsafe::CopyFromAnyViewAfterCheck(lhs); - return Bytes::memequal(lhs_bytes->data, rhs.data_.v_bytes, lhs_bytes->size, - rhs.data_.small_str_len); - } - if (lhs.data_.type_index == kTVMFFISmallBytes && rhs.data_.type_index == kTVMFFIBytes) { - const details::BytesObjBase* rhs_bytes = - details::AnyUnsafe::CopyFromAnyViewAfterCheck(rhs); - return Bytes::memequal(lhs.data_.v_bytes, rhs_bytes->data, lhs.data_.small_str_len, - rhs_bytes->size); - } - return false; - } - } -}; -} // namespace ffi - -// Expose to the tvm namespace for usability -// Rationale: no ambiguity even in root -using tvm::ffi::Any; -using tvm::ffi::AnyView; - -} // namespace tvm -#endif // TVM_FFI_ANY_H_ diff --git a/ffi/include/tvm/ffi/base_details.h b/ffi/include/tvm/ffi/base_details.h deleted file mode 100644 index c20f0e5c05cf..000000000000 --- a/ffi/include/tvm/ffi/base_details.h +++ /dev/null @@ -1,297 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -/*! - * \file tvm/ffi/base_details.h - * \brief Internal detail utils that can be used by files in tvm/ffi. - * \note details headers are for internal use only - * and not to be directly used by user. - */ -#ifndef TVM_FFI_BASE_DETAILS_H_ -#define TVM_FFI_BASE_DETAILS_H_ - -#include -#include - -#include -#include - -#if defined(_MSC_VER) -#ifndef WIN32_LEAN_AND_MEAN -#define WIN32_LEAN_AND_MEAN -#endif - -#ifndef NOMINMAX -#define NOMINMAX -#endif - -#include - -#ifdef ERROR -#undef ERROR -#endif - -#endif -/// \cond Doxygen_Suppress - -#if defined(_MSC_VER) -#define TVM_FFI_INLINE [[msvc::forceinline]] inline -#else -#define TVM_FFI_INLINE [[gnu::always_inline]] inline -#endif - -/*! - * \brief Macro helper to force a function not to be inlined. - * It is only used in places that we know not inlining is good, - * e.g. some logging functions. - */ -#if defined(_MSC_VER) -#define TVM_FFI_NO_INLINE [[msvc::noinline]] -#else -#define TVM_FFI_NO_INLINE [[gnu::noinline]] -#endif - -#if defined(_MSC_VER) -#define TVM_FFI_UNREACHABLE() __assume(false) -#else -#define TVM_FFI_UNREACHABLE() __builtin_unreachable() -#endif - -#define TVM_FFI_STR_CONCAT_(__x, __y) __x##__y -#define TVM_FFI_STR_CONCAT(__x, __y) TVM_FFI_STR_CONCAT_(__x, __y) - -#if defined(__GNUC__) || defined(__clang__) -#define TVM_FFI_FUNC_SIG __PRETTY_FUNCTION__ -#elif defined(_MSC_VER) -#define TVM_FFI_FUNC_SIG __FUNCSIG__ -#else -#define TVM_FFI_FUNC_SIG __func__ -#endif - -#if defined(__GNUC__) -// gcc and clang and attribute constructor -/// \cond Doxygen_Suppress -#define TVM_FFI_STATIC_INIT_BLOCK_DEF_(FnName) __attribute__((constructor)) static void FnName() -/// \endcond -/* - * \brief Macro that defines a block that will be called during static initialization. - * - * \code - * TVM_FFI_STATIC_INIT_BLOCK() { - * RegisterFunctions(); - * } - * \endcode - */ -#define TVM_FFI_STATIC_INIT_BLOCK() \ - TVM_FFI_STATIC_INIT_BLOCK_DEF_(TVM_FFI_STR_CONCAT(__TVMFFIStaticInitFunc, __COUNTER__)) - -#else -/// \cond Doxygen_Suppress -// for other compilers, use the variable trick -#define TVM_FFI_STATIC_INIT_BLOCK_DEF_(FnName, RegVar) \ - static void FnName(); \ - [[maybe_unused]] static inline int RegVar = []() { \ - FnName(); \ - return 0; \ - }(); \ - static void FnName() - -#define TVM_FFI_STATIC_INIT_BLOCK() \ - TVM_FFI_STATIC_INIT_BLOCK_DEF_(TVM_FFI_STR_CONCAT(__TVMFFIStaticInitFunc, __COUNTER__), \ - TVM_FFI_STR_CONCAT(__TVMFFIStaticInitReg, __COUNTER__)) -/// \endcond -#endif - -/* - * \brief Define the default copy/move constructor and assign operator - * \param TypeName The class typename. - */ -#define TVM_FFI_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName) \ - TypeName(const TypeName& other) = default; \ - TypeName(TypeName&& other) = default; \ - TypeName& operator=(const TypeName& other) = default; \ - TypeName& operator=(TypeName&& other) = default; - -/** - * \brief marks the begining of a C call that logs exception - */ -#define TVM_FFI_LOG_EXCEPTION_CALL_BEGIN() \ - try { \ - (void)0 - -/*! - * \brief Marks the end of a C call that logs exception - */ -#define TVM_FFI_LOG_EXCEPTION_CALL_END(Name) \ - } \ - catch (const std::exception& err) { \ - std::cerr << "Exception caught during " << #Name << ":\n" << err.what() << std::endl; \ - exit(-1); \ - } - -/*! - * \brief Clear the padding parts so we can safely use v_int64 for hash - * and equality check even when the value stored is a pointer. - * - * This macro is used to clear the padding parts for hash and equality check - * in 32bit platform. - */ -#define TVM_FFI_CLEAR_PTR_PADDING_IN_FFI_ANY(result) \ - if constexpr (sizeof((result)->v_obj) != sizeof((result)->v_int64)) { \ - (result)->v_int64 = 0; \ - } - -namespace tvm { -namespace ffi { -namespace details { - -// for each iterator -struct for_each_dispatcher { - template - static void run(std::index_sequence, const F& f, Args&&... args) { // NOLINT(*) - (f(I, std::forward(args)), ...); - } -}; - -template -void for_each(const F& f, Args&&... args) { // NOLINT(*) - for_each_dispatcher::run(std::index_sequence_for{}, f, std::forward(args)...); -} - -/*! - * \brief hash an object and combines uint64_t key with previous keys - * - * This hash function is stable across platforms. - * - * \param key The left operand. - * \param value The right operand. - * \return the combined result. - */ -template ::value, bool> = true> -TVM_FFI_INLINE uint64_t StableHashCombine(uint64_t key, const T& value) { - // XXX: do not use std::hash in this function. This hash must be stable - // across different platforms and std::hash is implementation dependent. - return key ^ (uint64_t(value) + 0x9e3779b9 + (key << 6) + (key >> 2)); -} - -/*! - * \brief Hash the binary bytes - * \param data The data pointer - * \param size The size of the bytes. - * \return the hash value. - */ -TVM_FFI_INLINE uint64_t StableHashBytes(const void* data_ptr, size_t size) { - const char* data = reinterpret_cast(data_ptr); - const constexpr uint64_t kMultiplier = 1099511628211ULL; - const constexpr uint64_t kMod = 2147483647ULL; - union Union { - uint8_t a[8]; - uint64_t b; - } u; - static_assert(sizeof(Union) == sizeof(uint64_t), "sizeof(Union) != sizeof(uint64_t)"); - const char* it = data; - const char* end = it + size; - uint64_t result = 0; - if constexpr (TVM_FFI_IO_NO_ENDIAN_SWAP) { - // if alignment requirement is met, directly use load - if (reinterpret_cast(it) % 8 == 0) { - for (; it + 8 <= end; it += 8) { - u.b = *reinterpret_cast(it); - result = (result * kMultiplier + u.b) % kMod; - } - } else { - // unaligned version - for (; it + 8 <= end; it += 8) { - u.a[0] = it[0]; - u.a[1] = it[1]; - u.a[2] = it[2]; - u.a[3] = it[3]; - u.a[4] = it[4]; - u.a[5] = it[5]; - u.a[6] = it[6]; - u.a[7] = it[7]; - result = (result * kMultiplier + u.b) % kMod; - } - } - } else { - // need endian swap - for (; it + 8 <= end; it += 8) { - u.a[0] = it[7]; - u.a[1] = it[6]; - u.a[2] = it[5]; - u.a[3] = it[4]; - u.a[4] = it[3]; - u.a[5] = it[2]; - u.a[6] = it[1]; - u.a[7] = it[0]; - result = (result * kMultiplier + u.b) % kMod; - } - } - - if (it < end) { - u.b = 0; - uint8_t* a = u.a; - if (it + 4 <= end) { - a[0] = it[0]; - a[1] = it[1]; - a[2] = it[2]; - a[3] = it[3]; - it += 4; - a += 4; - } - if (it + 2 <= end) { - a[0] = it[0]; - a[1] = it[1]; - it += 2; - a += 2; - } - if (it + 1 <= end) { - a[0] = it[0]; - it += 1; - a += 1; - } - if constexpr (!TVM_FFI_IO_NO_ENDIAN_SWAP) { - std::swap(u.a[0], u.a[7]); - std::swap(u.a[1], u.a[6]); - std::swap(u.a[2], u.a[5]); - std::swap(u.a[3], u.a[4]); - } - result = (result * kMultiplier + u.b) % kMod; - } - return result; -} - -/*! - * \brief Same as StableHashBytes, but for small string data. - * \param data The data pointer - * \return the hash value. - */ -TVM_FFI_INLINE uint64_t StableHashSmallStrBytes(const TVMFFIAny* data) { - if constexpr (TVM_FFI_IO_NO_ENDIAN_SWAP) { - // fast path, no endian swap, simply hash as uint64_t - const constexpr uint64_t kMod = 2147483647ULL; - return data->v_uint64 % kMod; - } - return StableHashBytes(reinterpret_cast(data), sizeof(data->v_uint64)); -} - -} // namespace details -} // namespace ffi -} // namespace tvm -/// \endcond -#endif // TVM_FFI_BASE_DETAILS_H_ diff --git a/ffi/include/tvm/ffi/c_api.h b/ffi/include/tvm/ffi/c_api.h deleted file mode 100644 index f13f820b7fc9..000000000000 --- a/ffi/include/tvm/ffi/c_api.h +++ /dev/null @@ -1,1097 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/* - * \file tvm/ffi/c_api.h - * \brief This file defines the C convention of the FFI convention - */ -#ifndef TVM_FFI_C_API_H_ -#define TVM_FFI_C_API_H_ - -#include -#include - -/* - * \brief C-style Allocator that allocates memory for a DLPack tensor. - * \param prototype The prototype DLTensor to offer details about device and shape. - * \param out The output DLManagedTensorVersioned. - * \param error_ctx The context to set the error. - * \param SetError The function to set the error. - * \return 0 on success, -1 on failure. - * call SetError(error_ctx, kind, message) to set the error kind and message. - * \note Error propagation via SetError. - */ -typedef int (*DLPackTensorAllocator)( // - DLTensor* prototype, DLManagedTensorVersioned** out, void* error_ctx, // - void (*SetError)(void* error_ctx, const char* kind, const char* message) // -); - -// Macros to do weak linking -#ifdef _MSC_VER -#define TVM_FFI_WEAK __declspec(selectany) -#else -#define TVM_FFI_WEAK __attribute__((weak)) -#endif - -// Defines two macros -// TVM_FFI_DLL: marks the function as a DLL export/import -// depending on whether TVM_FFI_EXPORTS is defined -// TVM_FFI_DLL_EXPORT: always marks the function as a DLL export -#if !defined(TVM_FFI_DLL) && defined(__EMSCRIPTEN__) -#include -#define TVM_FFI_DLL EMSCRIPTEN_KEEPALIVE -#define TVM_FFI_DLL_EXPORT EMSCRIPTEN_KEEPALIVE -#endif -#if !defined(TVM_FFI_DLL) && defined(_MSC_VER) -#ifdef TVM_FFI_EXPORTS -#define TVM_FFI_DLL __declspec(dllexport) -#else -#define TVM_FFI_DLL __declspec(dllimport) -#endif -#define TVM_FFI_DLL_EXPORT __declspec(dllexport) -#endif -#ifndef TVM_FFI_DLL -#define TVM_FFI_DLL __attribute__((visibility("default"))) -#define TVM_FFI_DLL_EXPORT __attribute__((visibility("default"))) -#endif - -#ifdef __cplusplus -extern "C" { -#endif - -#ifdef __cplusplus -enum TVMFFITypeIndex : int32_t { -#else -typedef enum { -#endif - - /* - * \brief The root type of all FFI objects. - * - * We include it so TypeIndex captures all possible runtime values. - * `kTVMFFIAny` code will never appear in Any::type_index. - * However, it may appear in field annotations during reflection. - */ - kTVMFFIAny = -1, - // [Section] On-stack POD and special types: [0, kTVMFFIStaticObjectBegin) - // N.B. `kTVMFFIRawStr` is a string backed by a `\0`-terminated char array, - // which is not owned by TVMFFIAny. It is required that the following - // invariant holds: - // - `Any::type_index` is never `kTVMFFIRawStr` - // - `AnyView::type_index` can be `kTVMFFIRawStr` - // - /*! \brief None/nullptr value */ - kTVMFFINone = 0, - /*! \brief POD int value */ - kTVMFFIInt = 1, - /*! \brief POD bool value */ - kTVMFFIBool = 2, - /*! \brief POD float value */ - kTVMFFIFloat = 3, - /*! \brief Opaque pointer object */ - kTVMFFIOpaquePtr = 4, - /*! \brief DLDataType */ - kTVMFFIDataType = 5, - /*! \brief DLDevice */ - kTVMFFIDevice = 6, - /*! \brief DLTensor* */ - kTVMFFIDLTensorPtr = 7, - /*! \brief const char* */ - kTVMFFIRawStr = 8, - /*! \brief TVMFFIByteArray* */ - kTVMFFIByteArrayPtr = 9, - /*! \brief R-value reference to ObjectRef */ - kTVMFFIObjectRValueRef = 10, - /*! \brief Small string on stack */ - kTVMFFISmallStr = 11, - /*! \brief Small bytes on stack */ - kTVMFFISmallBytes = 12, - /*! \brief Start of statically defined objects. */ - kTVMFFIStaticObjectBegin = 64, - /*! - * \brief Object, all objects starts with TVMFFIObject as its header. - * \note We will also add other fields - */ - kTVMFFIObject = 64, - /*! - * \brief String object, layout = { TVMFFIObject, TVMFFIByteArray, ... } - */ - kTVMFFIStr = 65, - /*! - * \brief Bytes object, layout = { TVMFFIObject, TVMFFIByteArray, ... } - */ - kTVMFFIBytes = 66, - /*! \brief Error object. */ - kTVMFFIError = 67, - /*! \brief Function object. */ - kTVMFFIFunction = 68, - /*! - * \brief Shape object, layout = { TVMFFIObject, { const int64_t*, size_t }, ... } - */ - kTVMFFIShape = 69, - /*! - * \brief Tensor object, layout = { TVMFFIObject, DLTensor, ... } - */ - kTVMFFITensor = 70, - /*! \brief Array object. */ - kTVMFFIArray = 71, - //---------------------------------------------------------------- - // more complex objects - //---------------------------------------------------------------- - /*! \brief Map object. */ - kTVMFFIMap = 72, - /*! \brief Runtime dynamic loaded module object. */ - kTVMFFIModule = 73, - /*! - * \brief Opaque python object. - * - * This is a special type index to indicate we are storing an opaque PyObject. - * Such object may interact with callback functions that are registered to support - * python-related operations. - * - * We only translate the objects that we do not recognize into this type index. - * - * \sa TVMFFIObjectCreateOpaque - */ - kTVMFFIOpaquePyObject = 74, - kTVMFFIStaticObjectEnd, - // [Section] Dynamic Boxed: [kTVMFFIDynObjectBegin, +oo) - /*! \brief Start of type indices that are allocated at runtime. */ - kTVMFFIDynObjectBegin = 128 -#ifdef __cplusplus -}; -#else -} TVMFFITypeIndex; -#endif - -/*! \brief Handle to Object from C API's pov */ -typedef void* TVMFFIObjectHandle; - -/*! - * \brief bitmask of the object deleter flag. - */ -#ifdef __cplusplus -enum TVMFFIObjectDeleterFlagBitMask : int32_t { -#else -typedef enum { -#endif - /*! - * \brief deleter action when strong reference count becomes zero. - * Need to call destructor of the object but not free the memory block. - */ - kTVMFFIObjectDeleterFlagBitMaskStrong = 1 << 0, - /*! - * \brief deleter action when weak reference count becomes zero. - * Need to free the memory block. - */ - kTVMFFIObjectDeleterFlagBitMaskWeak = 1 << 1, - /*! - * \brief deleter action when both strong and weak reference counts become zero. - * \note This is the most common case. - */ - kTVMFFIObjectDeleterFlagBitMaskBoth = - (kTVMFFIObjectDeleterFlagBitMaskStrong | kTVMFFIObjectDeleterFlagBitMaskWeak), -#ifdef __cplusplus -}; -#else -} TVMFFIObjectDeleterFlagBitMask; -#endif - -/*! - * \brief C-based type of all FFI object header that allocates on heap. - * \note TVMFFIObject and TVMFFIAny share the common type_index header - */ -typedef struct { - /*! - * \brief type index of the object. - * \note The type index of Object and Any are shared in FFI. - */ - int32_t type_index; - /*! - * \brief Weak reference counter of the object, for compatiblity with weak_ptr design. - * \note Use u32 to ensure that overall object stays within 24-byte boundary, usually - * manipulation of weak counter is less common than strong counter. - */ - uint32_t weak_ref_count; - /*! \brief Strong reference counter of the object. */ - uint64_t strong_ref_count; - union { - /*! - * \brief Deleter to be invoked when strong reference counter goes to zero. - * \param self The self object handle. - * \param flags The flags to indicate deletion behavior. - * \sa TVMFFIObjectDeleterFlagBitMask - */ - void (*deleter)(void* self, int flags); - /*! - * \brief auxilary field to TVMFFIObject is always 8 bytes aligned. - * \note This helps us to ensure cross platform compatibility. - */ - int64_t __ensure_align; - }; -} TVMFFIObject; - -/*! - * \brief C-based type of all on stack Any value. - * - * Any value can hold on stack values like int, - * as well as reference counted pointers to object. - */ -typedef struct { - /*! - * \brief type index of the object. - * \note The type index of Object and Any are shared in FFI. - */ - int32_t type_index; - union { // 4 bytes - /*! \brief padding, must set to zero for values other than small string. */ - uint32_t zero_padding; - /*! - * \brief Length of small string, with a max value of 7. - * - * We keep small str to start at next 4 bytes to ensure alignment - * when accessing the small str content. - */ - uint32_t small_str_len; - }; - union { // 8 bytes - int64_t v_int64; // integers - double v_float64; // floating-point numbers - void* v_ptr; // typeless pointers - const char* v_c_str; // raw C-string - TVMFFIObject* v_obj; // ref counted objects - DLDataType v_dtype; // data type - DLDevice v_device; // device - char v_bytes[8]; // small string - char32_t v_char32[2]; // small UCS4 string and Unicode - uint64_t v_uint64; // uint64 repr mainly used for hashing - }; -} TVMFFIAny; - -/*! - * \brief Byte array data structure used by String and Bytes. - * - * String and Bytes object layout = { TVMFFIObject, TVMFFIByteArray, ... } - * - * \note This byte array data structure layout differs in 32/64 bit platforms. - * as size_t equals to the size of the pointer, use this convetion to - * be consistent with std::string and also avoid need to calculate padding - * for the size field on 32-bit platforms. - * The FFI binding should be careful when treating this ABI. - */ -typedef struct { - /*! \brief The data pointer. */ - const char* data; - /*! \brief The size of the data. */ - size_t size; -} TVMFFIByteArray; - -/*! - * \brief Shape cell used in shape object following header. - */ -typedef struct { - /*! \brief The data pointer. */ - const int64_t* data; - /*! \brief The size of the data. */ - size_t size; -} TVMFFIShapeCell; - -/*! - * \brief Error cell used in error object following header. - */ -typedef struct { - /*! \brief The kind of the error. */ - TVMFFIByteArray kind; - /*! \brief The message of the error. */ - TVMFFIByteArray message; - /*! - * \brief The traceback of the error. - */ - TVMFFIByteArray traceback; - /*! - * \brief Function handle to update the traceback of the error. - * \param self The self object handle. - * \param traceback The traceback to update. - */ - void (*update_traceback)(TVMFFIObjectHandle self, const TVMFFIByteArray* traceback); -} TVMFFIErrorCell; - -/*! - * \brief Type that defines C-style safe call convention - * - * Safe call explicitly catches exception on function boundary. - * - * \param handle The function handle - * \param num_args Number of input arguments - * \param args The input arguments to the call. - * \param result Store output result. - * - * IMPORTANT: caller must initialize result->type_index to be kTVMFFINone, - * or any other value smaller than kTVMFFIStaticObjectBegin. - * - * \return The call returns 0 if call is successful. - * It returns non-zero value if there is an error. - * - * Possible return error of the API functions: - * * 0: success - * * -1: error happens, can be retrieved by TVMFFIErrorMoveFromRaised - * * -2: a frontend error occurred and recorded in the frontend. - * - * \note We decided to leverage TVMFFIErrorMoveFromRaised and TVMFFIErrorSetRaised - * for C function error propagation. This design choice, while - * introducing a dependency for TLS runtime, simplifies error - * propgation in chains of calls in compiler codegen. - * As we do not need to propagate error through argument but simply - * set them in the runtime environment. - * - * \sa TVMFFIErrorMoveFromRaised - * \sa TVMFFIErrorSetRaised - * \sa TVMFFIErrorSetRaisedFromCStr - */ -typedef int (*TVMFFISafeCallType)(void* handle, const TVMFFIAny* args, int32_t num_args, - TVMFFIAny* result); - -/*! - * \brief Object cell for function object following header. - */ -typedef struct { - /*! \brief A C API compatible call with exception catching. */ - TVMFFISafeCallType safe_call; -} TVMFFIFunctionCell; - -/*! - * \brief Object cell for opaque object following header. - */ -typedef struct { - /*! \brief The handle of the opaque object, for python it is PyObject* */ - void* handle; -} TVMFFIOpaqueObjectCell; - -//------------------------------------------------------------ -// Section: Basic object API -//------------------------------------------------------------ -/*! - * \brief Increase the strong reference count of an object handle - * \param obj The object handle. - * \note Internally we increase the reference counter of the object. - * \return 0 when success, nonzero when failure happens - */ -TVM_FFI_DLL int TVMFFIObjectIncRef(TVMFFIObjectHandle obj); - -/*! - * \brief Free an object handle by decreasing strong reference - * \param obj The object handle. - * \return 0 when success, nonzero when failure happens - */ -TVM_FFI_DLL int TVMFFIObjectDecRef(TVMFFIObjectHandle obj); - -/*! - * \brief Create an Opaque object by passing in handle, type_index and deleter. - * - * The opaque object's lifetime is managed as an Object, so it can be retained - * and released like other objects. - * When the opaque object is kTVMFFIOpaquePyObject, it can be converted back to - * the python type when returned or passed as arguments to a python function. - * - * We can support ffi::Function that interacts with these objects, - * most likely callback registered from python. - * - * For language bindings, we only convert types that we do not recognize into this type. - * On the C++ side, the most common way to represent such OpaqueObject is to simply - * use ffi::ObjectRef or ffi::Any. - * - * \param handle The resource handle of the opaque object. - * \param type_index The type index of the object. - * \param deleter deleter to recycle - * \param out The output of the opaque object. - * \return 0 when success, nonzero when failure happens - * - * \note The caller must ensure the type_index is a valid opaque object type index. - * \sa kTVMFFIOpaquePyObject - */ -TVM_FFI_DLL int TVMFFIObjectCreateOpaque(void* handle, int32_t type_index, - void (*deleter)(void* handle), TVMFFIObjectHandle* out); - -/*! - * \brief Convert type key to type index. - * \param type_key The key of the type. - * \param out_tindex the corresponding type index. - * \return 0 when success, nonzero when failure happens - */ -TVM_FFI_DLL int TVMFFITypeKeyToIndex(const TVMFFIByteArray* type_key, int32_t* out_tindex); - -//----------------------------------------------------------------------- -// Section: Basic function calling API for function implementation -//----------------------------------------------------------------------- -/*! - * \brief Create a FFIFunc by passing in callbacks from a C callback. - * The registered function can then be retrieved by the backend using its name. - * \param self The resource handle of the C callback. - * \param safe_call The C callback implementation. - * \param deleter The deleter to recycle. - * \param out The output of the function. - * \return 0 on success, nonzero on failure. - */ -TVM_FFI_DLL int TVMFFIFunctionCreate(void* self, TVMFFISafeCallType safe_call, - void (*deleter)(void* self), TVMFFIObjectHandle* out); - -/*! - * \brief Get a global function registered in the system. - * \param name The name of the function. - * \param out The result function pointer, NULL if it does not exist. - * \return 0 on success, nonzero on failure. - */ -TVM_FFI_DLL int TVMFFIFunctionGetGlobal(const TVMFFIByteArray* name, TVMFFIObjectHandle* out); - -/*! - * \brief Convert an AnyView to an owned Any. - * \param any_view The AnyView to convert. - * \param out The output Any, must be an empty object. - * \return 0 on success, nonzero on failure. - */ -TVM_FFI_DLL int TVMFFIAnyViewToOwnedAny(const TVMFFIAny* any_view, TVMFFIAny* out); - -/*! - * \brief Call a FFIFunc by passing in arguments. - * \param func The resource handle of the C callback. - * \param args The input arguments to the call. - * \param num_args The number of input arguments. - * \param result The output result, caller must ensure result->type_index is set to kTVMFFINone. - * \return 0 on success, nonzero on failure. - */ -TVM_FFI_DLL int TVMFFIFunctionCall(TVMFFIObjectHandle func, TVMFFIAny* args, int32_t num_args, - TVMFFIAny* result); - -/*! - * \brief Move the last error from the environment to the result. - * \param result The result error. - * \note This function clears the error stored in the TLS. - */ -TVM_FFI_DLL void TVMFFIErrorMoveFromRaised(TVMFFIObjectHandle* result); - -/*! - * \brief Set a raised error in TLS, which can be fetched by TVMFFIErrorMoveFromRaised. - * \param error The error object handle - */ -TVM_FFI_DLL void TVMFFIErrorSetRaised(TVMFFIObjectHandle error); - -/*! - * \brief Set a raised error in TLS, which can be fetched by TVMFFIMoveFromRaised. - * \param kind The kind of the error. - * \param message The error message. - * \note This is a convenient method for the C API side to set an error directly from a string. - */ -TVM_FFI_DLL void TVMFFIErrorSetRaisedFromCStr(const char* kind, const char* message); - -/*! - * \brief Create an initial error object. - * \param kind The kind of the error. - * \param message The error message. - * \param traceback The traceback of the error. - * \return The created error object handle. - * \note This function is different from other functions as it is used in the error handling loop. - * So we do not follow normal error handling patterns via returning an error code. - */ -TVM_FFI_DLL TVMFFIObjectHandle TVMFFIErrorCreate(const TVMFFIByteArray* kind, - const TVMFFIByteArray* message, - const TVMFFIByteArray* traceback); - -//------------------------------------------------------------ -// Section: DLPack support APIs -//------------------------------------------------------------ -/*! - * \brief Produce a managed Tensor from a DLPack tensor. - * \param from The source DLPack tensor. - * \param require_alignment The minimum alignment required of the data + byte_offset. - * \param require_contiguous Boolean flag indicating if we need to check for contiguity. - * \param out The output Tensor handle. - * \return 0 on success, nonzero on failure. - */ -TVM_FFI_DLL int TVMFFITensorFromDLPack(DLManagedTensor* from, int32_t require_alignment, - int32_t require_contiguous, TVMFFIObjectHandle* out); - -/*! - * \brief Produce a DLManagedTensor from the array that shares data memory with the array. - * \param from The source array. - * \param out The DLManagedTensor handle. - * \return 0 on success, nonzero on failure. - */ -TVM_FFI_DLL int TVMFFITensorToDLPack(TVMFFIObjectHandle from, DLManagedTensor** out); - -/*! - * \brief Produce a managed Tensor from a DLPack tensor. - * \param from The source DLPack tensor. - * \param require_alignment The minimum alignment required of the data + byte_offset. - * \param require_contiguous Boolean flag indicating if we need to check for contiguity. - * \param out The output Tensor handle. - * \return 0 on success, nonzero on failure. - */ -TVM_FFI_DLL int TVMFFITensorFromDLPackVersioned(DLManagedTensorVersioned* from, - int32_t require_alignment, - int32_t require_contiguous, - TVMFFIObjectHandle* out); - -/*! - * \brief Produce a DLManagedTensor from the array that shares data memory with the array. - * \param from The source array. - * \param out The DLManagedTensor handle. - * \return 0 on success, nonzero on failure. - */ -TVM_FFI_DLL int TVMFFITensorToDLPackVersioned(TVMFFIObjectHandle from, - DLManagedTensorVersioned** out); -//--------------------------------------------------------------- -// Section: string/bytes support APIs. -// These APIs are used to simplify the string/bytes construction -//--------------------------------------------------------------- -/*! - * \brief Reinterpret the content of TVMFFIByteArray to String. - * \param input The TVMFFIByteArray to convert. - * \param out The output String owned by the caller, maybe a SmallStr or a Str object. - * \return 0 on success, nonzero on failure. - */ -TVM_FFI_DLL int TVMFFIStringFromByteArray(const TVMFFIByteArray* input, TVMFFIAny* out); - -/*! - * \brief Reinterpret the content of TVMFFIByteArray to Bytes. - * \param input The TVMFFIByteArray to convert. - * \param out The output Bytes owned by the caller, maybe a SmallBytes or a Bytes object. - * \return 0 on success, nonzero on failure. - */ -TVM_FFI_DLL int TVMFFIBytesFromByteArray(const TVMFFIByteArray* input, TVMFFIAny* out); - -//--------------------------------------------------------------- -// Section: dtype string support APIs. -// These APIs are used to simplify the dtype printings during FFI -//--------------------------------------------------------------- - -/*! - * \brief Convert a string to a DLDataType. - * \param str The string to convert. - * \param out The output DLDataType. - * \return 0 on success, nonzero on failure. - */ -TVM_FFI_DLL int TVMFFIDataTypeFromString(const TVMFFIByteArray* str, DLDataType* out); - -/*! -* \brief Convert a DLDataType to a string. -* \param dtype The DLDataType to convert. -* \param out The output string. -* \return 0 on success, nonzero on failure. -* \note out is a String object that needs to be freed by the caller via TVMFFIObjectDecRef. -The content of string can be accessed via TVMFFIObjectGetByteArrayPtr. - -* \note The input dtype is a pointer to the DLDataType to avoid ABI compatibility issues. -*/ -TVM_FFI_DLL int TVMFFIDataTypeToString(const DLDataType* dtype, TVMFFIAny* out); - -//------------------------------------------------------------ -// Section: Type reflection support APIs -// -// The reflec -//------------------------------------------------------------ -/*! - * \brief Getter that can take the address of a field and set the result. - * \param field The raw address of the field. - * \param result Stores the result. - * \return 0 on success, nonzero on failure. - */ -typedef int (*TVMFFIFieldGetter)(void* field, TVMFFIAny* result); - -/*! - * \brief Getter that can take the address of a field and set it to a value. - * \param field The raw address of the field. - * \param value The value to set. - * \return 0 on success, nonzero on failure. - */ -typedef int (*TVMFFIFieldSetter)(void* field, const TVMFFIAny* value); - -/*! - * \brief Function that creates a new instance of the type. - * \param result The new object handle - * \return 0 on success, nonzero on failure. - */ -typedef int (*TVMFFIObjectCreator)(TVMFFIObjectHandle* result); - -/*! - * \brief bitmask of the field. - */ -#ifdef __cplusplus -enum TVMFFIFieldFlagBitMask : int32_t { -#else -typedef enum { -#endif - /*! \brief The field is writable. */ - kTVMFFIFieldFlagBitMaskWritable = 1 << 0, - /*! \brief The field has default value. */ - kTVMFFIFieldFlagBitMaskHasDefault = 1 << 1, - /*! \brief The field is a static method. */ - kTVMFFIFieldFlagBitMaskIsStaticMethod = 1 << 2, - /*! - * \brief The field should be ignored when performing structural eq/hash - * - * This is an optional meta-data for structural eq/hash. - */ - kTVMFFIFieldFlagBitMaskSEqHashIgnore = 1 << 3, - /*! - * \brief The field enters a def region where var can be defined/matched. - * - * This is an optional meta-data for structural eq/hash. - */ - kTVMFFIFieldFlagBitMaskSEqHashDef = 1 << 4, -#ifdef __cplusplus -}; -#else -} TVMFFIFieldFlagBitMask; -#endif - -/*! - * \brief Optional meta-data for structural eq/hash. - * - * This meta-data is only useful when we want to leverage the information - * to perform richer semantics aware structural comparison and hash. - * It can be safely ignored if such information is not needed. - * - * The meta-data record comparison method in tree node and DAG node. - * - * \code - * x = VarNode() - * v0 = AddNode(x, 1) - * v1 = AddNode(x, 1) - * v2 = AddNode(v0, v0) - * v3 = AddNode(v1, v0) - * \endcode - * - * Consider the construct sequence of AddNode below, - * if AddNode is treated as a tree node, then v2 and v3 - * structural equals to each other, but if AddNode is - * treated as a DAG node, then v2 and v3 does not - * structural equals to each other. - */ -#ifdef __cplusplus -enum TVMFFISEqHashKind : int32_t { -#else -typedef enum { -#endif - /*! \brief Do not support structural eq/hash. */ - kTVMFFISEqHashKindUnsupported = 0, - /*! - * \brief The object be compared as a tree node. - */ - kTVMFFISEqHashKindTreeNode = 1, - /*! - * \brief The object is treated as a free variable that can be mapped - * to another free variable in the definition region. - */ - kTVMFFISEqHashKindFreeVar = 2, - /*! - * \brief The field should be compared as a DAG node. - */ - kTVMFFISEqHashKindDAGNode = 3, - /*! - * \brief The object is treated as a constant tree node. - * - * Same as tree node, but the object does not contain free var - * as any of its nested children. - * - * That means we can use pointer equality for equality. - */ - kTVMFFISEqHashKindConstTreeNode = 4, - /*! - * \brief One can simply use pointer equality for equality. - * - * This is useful for "singleton"-style object that can - * is only an unique copy of each value. - */ - kTVMFFISEqHashKindUniqueInstance = 5, -#ifdef __cplusplus -}; -#else -} TVMFFISEqHashKind; -#endif - -/*! - * \brief Information support for optional object reflection. - */ -typedef struct { - /*! \brief The name of the field. */ - TVMFFIByteArray name; - /*! \brief The docstring about the field. */ - TVMFFIByteArray doc; - /*! \brief The type schema of the field in JSON string. */ - TVMFFIByteArray type_schema; - /*! - * \brief bitmask flags of the field. - */ - int64_t flags; - /*! \brief The size of the field. */ - int64_t size; - /*! \brief The alignment of the field. */ - int64_t alignment; - /*! \brief The offset of the field. */ - int64_t offset; - /*! \brief The getter to access the field. */ - TVMFFIFieldGetter getter; - /*! - * \brief The setter to access the field. - * \note The setter is set even if the field is readonly for serialization. - */ - TVMFFIFieldSetter setter; - /*! - * \brief The default value of the field, this field hold AnyView, - * valid when flags set kTVMFFIFieldFlagBitMaskHasDefault - */ - TVMFFIAny default_value; - /*! - * \brief Records the static type kind of the field. - * - * Possible values: - * - * - TVMFFITypeIndex::kTVMFFIObject for general objects. - * The value is nullable when kTVMFFIObject is chosen. - * - Static object type kinds such as Map, Dict, String - * - POD type index, note it does not give information about storage size of the field. - * - TVMFFITypeIndex::kTVMFFIAny if we don't have specialized info - * about the field. - * - * When the value is a type index of Object type, the field is storaged as an ObjectRef. - * - * \note This information maybe helpful in designing serializer. - * As it helps to narrow down the field type so we don't have to - * print type_key for cases like POD types. - * It also helps to provide opportunities to enable short-cut getter to ObjectRef fields. - */ - int32_t field_static_type_index; -} TVMFFIFieldInfo; - -/*! - * \brief Method information that can appear in reflection table. - */ -typedef struct { - /*! \brief The name of the field. */ - TVMFFIByteArray name; - /*! \brief The docstring about the method. */ - TVMFFIByteArray doc; - /*! \brief Optional type schema of the method in JSON string. */ - TVMFFIByteArray type_schema; - /*! \brief bitmask flags of the method. */ - int64_t flags; - /*! - * \brief The method wrapped as ffi::Function, stored as AnyView. - * \note The first argument to the method is always the self for instance methods. - */ - TVMFFIAny method; -} TVMFFIMethodInfo; - -/*! - * \brief Extra information of object type that can be used for reflection. - * - * \note This information is optional and can be used to enable reflection based - * creation of the object. - */ -typedef struct { - /*! \brief The docstring about the object. */ - TVMFFIByteArray doc; - /*! - * \brief An optional function that can create a new empty instance of the type. - * - * When known_fixed_size is non-zero, creator can be called - * with nullptr passed to optional_bytes. - * - * \note Caller must call setter for each field to initialize the object for - * the final object to be in valid state. - * - * \note This field is optional to enable reflection based creation. - */ - TVMFFIObjectCreator creator; - /*! - * \brief Total size of the object struct, if it is fixed and known. - * - * This field is set optional and set to 0 if not registered. - */ - int32_t total_size; - /*! - * \brief Optional meta-data for structural eq/hash. - */ - TVMFFISEqHashKind structural_eq_hash_kind; -} TVMFFITypeMetadata; - -/*! - * \brief Column array that stores extra attributes about types - * - * The attributes stored in a column array that can be looked up by type index. - * Note that the TypeAttr behaves like type_traits so column[T] so not contain - * attributes from base classes. - * - * \note - * \sa TVMFFIRegisterTypeAttr - */ -typedef struct { - /*! \brief The data of the column. */ - const TVMFFIAny* data; - /*! \brief The size of the column. */ - size_t size; -} TVMFFITypeAttrColumn; - -/*! - * \brief Runtime type information for object type checking. - */ -#ifdef __cplusplus -struct TVMFFITypeInfo { -#else -typedef struct TVMFFITypeInfo { -#endif - /*! - *\brief The runtime type index, - * It can be allocated during runtime if the type is dynamic. - */ - int32_t type_index; - /*! \brief number of parent types in the type hierachy. */ - int32_t type_depth; - /*! \brief the unique type key to identify the type. */ - TVMFFIByteArray type_key; - /*! - * \brief type_acenstors[depth] stores the type_index of the acenstors at depth level - * \note To keep things simple, we do not allow multiple inheritance so the - * hieracy stays as a tree - */ - const struct TVMFFITypeInfo** type_acenstors; - // The following fields are used for reflection - /*! \brief Cached hash value of the type key, used for consistent structural hashing. */ - uint64_t type_key_hash; - /*! \brief number of reflection accessible fields. */ - int32_t num_fields; - /*! \brief number of reflection acccesible methods. */ - int32_t num_methods; - /*! \brief The reflection field information. */ - const TVMFFIFieldInfo* fields; - /*! \brief The reflection method. */ - const TVMFFIMethodInfo* methods; - /*! \brief The extra information of the type. */ - const TVMFFITypeMetadata* metadata; -#ifdef __cplusplus -}; -#else -} TVMFFITypeInfo; -#endif - -/*! - * \brief Register the function to runtime's global table. - * The registered function can then be retrieved by the backend using its name. - * \param name The name of the function. - * \param f The function to be registered. - * \param allow_override Whether to allow overriding an already registered function. - * \return 0 on success, nonzero on failure. - */ -TVM_FFI_DLL int TVMFFIFunctionSetGlobal(const TVMFFIByteArray* name, TVMFFIObjectHandle f, - int allow_override); - -/*! - * \brief Register the function to runtime's global table with method info. - * This is the same as TVMFFIFunctionSetGlobal but with method info that can provide extra - * metadata used in the runtime. - * \param method_info The method info to be registered. - * \param allow_override Whether to allow overriding an already registered function. - * \return 0 on success, nonzero on failure. - */ -TVM_FFI_DLL int TVMFFIFunctionSetGlobalFromMethodInfo(const TVMFFIMethodInfo* method_info, - int allow_override); - -/*! - * \brief Register type field information for runtime reflection. - * \return 0 on success, nonzero on failure. - */ -TVM_FFI_DLL int TVMFFITypeRegisterField(int32_t type_index, const TVMFFIFieldInfo* info); - -/*! - * \brief Register type method information for runtime reflection. - * \return 0 on success, nonzero on failure. - */ -TVM_FFI_DLL int TVMFFITypeRegisterMethod(int32_t type_index, const TVMFFIMethodInfo* info); - -/*! - * \brief Register type creator information for runtime reflection. - * \return 0 on success, nonzero on failure. - */ -TVM_FFI_DLL int TVMFFITypeRegisterMetadata(int32_t type_index, const TVMFFITypeMetadata* metadata); - -/*! - * \brief Register extra type attributes that can be looked up during runtime. - * \return 0 on success, nonzero on failure. - */ -TVM_FFI_DLL int TVMFFITypeRegisterAttr(int32_t type_index, const TVMFFIByteArray* attr_name, - const TVMFFIAny* attr_value); - -/*! - * \brief Get the type attribute column by name. - * \return The pointer to the type attribute column. - * \return NULL if the attribute was not registered in the system. - */ -TVM_FFI_DLL const TVMFFITypeAttrColumn* TVMFFIGetTypeAttrColumn(const TVMFFIByteArray* attr_name); - -//------------------------------------------------------------ -// Section: Backend noexcept functions for internal use -// -// These functions are used internally and do not throw error -// instead the error will be logged and abort the process -// These are function are being called in startup or exit time -// so exception handling do not apply -//------------------------------------------------------------ -/*! - * \brief Get stack traceback in a string. - * \param filename The current file name. - * \param lineno The current line number - * \param func The current function - * \param cross_ffi_boundary Whether the traceback is crossing the ffi boundary - * or we should stop at the ffi boundary when detected - * \return The traceback string - * - * \note filename/func can be nullptr, then this info is skipped, they are useful - * for cases when debug symbols are not available. - */ -TVM_FFI_DLL const TVMFFIByteArray* TVMFFITraceback(const char* filename, int lineno, - const char* func, int cross_ffi_boundary); - -/*! - * \brief Initialize the type info during runtime. - * - * When the function is first called for a type, - * it will register the type to the type table in the runtime. - * If the static_tindex is non-negative, the function will - * allocate a runtime type index. - * Otherwise, we will populate the type table and return the static index. - * - * \param type_key The type key. - * \param type_depth The type depth. - * \param static_type_index Static type index if any, can be -1, which means this is a dynamic index - * \param num_child_slots Number of slots reserved for its children. - * \param child_slots_can_overflow Whether to allow child to overflow the slots. - * \param parent_type_index Parent type index, pass in -1 if it is root. - * - * \return The allocated type index. - */ -TVM_FFI_DLL int32_t TVMFFITypeGetOrAllocIndex(const TVMFFIByteArray* type_key, - int32_t static_type_index, int32_t type_depth, - int32_t num_child_slots, - int32_t child_slots_can_overflow, - int32_t parent_type_index); - -/*! - * \brief Get dynamic type info by type index. - * \return The type info. - */ -TVM_FFI_DLL const TVMFFITypeInfo* TVMFFIGetTypeInfo(int32_t type_index); - -#ifdef __cplusplus -} // TVM_FFI_EXTERN_C -#endif - -//--------------------------------------------------------------- -// The following API defines static object attribute accessors -// for language bindings. -// -// They are defined in C++ inline functions for cleaner code. -// Note that they only have to do with address offset computation. -// So they can always be reimplemented in bindings when c++ is -// not available or when binding only wants to refer to the dll. -//---------------------------------------------------------------- -#ifdef __cplusplus -/*! - * \brief Get the type index of an object. - * \param obj The object handle. - * \return The type index. - */ -inline int32_t TVMFFIObjectGetTypeIndex(TVMFFIObjectHandle obj) { - return static_cast(obj)->type_index; -} - -/*! - * \brief Get the content of a small string in bytearray format. - * \param value The value to get the content of the small string in bytearray format. - * \return The content of the small string in bytearray format. - */ -inline TVMFFIByteArray TVMFFISmallBytesGetContentByteArray(const TVMFFIAny* value) { - return TVMFFIByteArray{value->v_bytes, static_cast(value->small_str_len)}; -} - -/*! - * \brief Get the data pointer of a bytearray from a string or bytes object. - * \param obj The object handle. - * \return The data pointer. - */ -inline TVMFFIByteArray* TVMFFIBytesGetByteArrayPtr(TVMFFIObjectHandle obj) { - return reinterpret_cast(reinterpret_cast(obj) + sizeof(TVMFFIObject)); -} - -/*! - * \brief Get the data pointer of a ErrorInfo from an Error object. - * \param obj The object handle. - * \return The cell pointer. - */ -inline TVMFFIErrorCell* TVMFFIErrorGetCellPtr(TVMFFIObjectHandle obj) { - return reinterpret_cast(reinterpret_cast(obj) + sizeof(TVMFFIObject)); -} - -/*! - * \brief Get the data pointer of a function cell from a function object. - * \param obj The object handle. - * \return The cell pointer. - */ -inline TVMFFIFunctionCell* TVMFFIFunctionGetCellPtr(TVMFFIObjectHandle obj) { - return reinterpret_cast(reinterpret_cast(obj) + sizeof(TVMFFIObject)); -} - -/*! - * \brief Get the data pointer of a opaque object cell from a opaque object. - * \param obj The object handle. - * \return The cell pointer. - */ -inline TVMFFIOpaqueObjectCell* TVMFFIOpaqueObjectGetCellPtr(TVMFFIObjectHandle obj) { - return reinterpret_cast(reinterpret_cast(obj) + - sizeof(TVMFFIObject)); -} - -/*! - * \brief Get the data pointer of a shape array from a shape object. - * \param obj The object handle. - * \return The cell pointer. - */ -inline TVMFFIShapeCell* TVMFFIShapeGetCellPtr(TVMFFIObjectHandle obj) { - return reinterpret_cast(reinterpret_cast(obj) + sizeof(TVMFFIObject)); -} - -/*! - * \brief Get the DLTensor pointer from an Tensor object. - * \param obj The object handle. - * \return The DLTensor pointer. - */ -inline DLTensor* TVMFFITensorGetDLTensorPtr(TVMFFIObjectHandle obj) { - return reinterpret_cast(reinterpret_cast(obj) + sizeof(TVMFFIObject)); -} - -/*! - * \brief Create a DLDevice from a device type and device id. - * \param device_type The device type. - * \param device_id The device id. - * \return The DLDevice. - */ -inline DLDevice TVMFFIDLDeviceFromIntPair(int32_t device_type, int32_t device_id) { - return DLDevice{static_cast(device_type), device_id}; -} -#endif // __cplusplus -#endif // TVM_FFI_C_API_H_ diff --git a/ffi/include/tvm/ffi/cast.h b/ffi/include/tvm/ffi/cast.h deleted file mode 100644 index 398953ad6508..000000000000 --- a/ffi/include/tvm/ffi/cast.h +++ /dev/null @@ -1,79 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -/*! - * \file tvm/ffi/cast.h - * \brief Extra value casting helpers - */ -#ifndef TVM_FFI_CAST_H_ -#define TVM_FFI_CAST_H_ - -#include -#include -#include - -namespace tvm { -namespace ffi { - -/*! - * \brief Get a reference type from a raw object ptr type - * - * It is always important to get a reference type - * if we want to return a value as reference or keep - * the object alive beyond the scope of the function. - * - * \param ptr The object pointer - * \tparam RefType The reference type - * \tparam ObjectType The object type - * \return The corresponding RefType - */ -template -inline RefType GetRef(const ObjectType* ptr) { - using ContainerType = typename RefType::ContainerType; - static_assert(std::is_base_of_v, - "Can only cast to the ref of same container type"); - - if constexpr (is_optional_type_v || RefType::_type_is_nullable) { - if (ptr == nullptr) { - return details::ObjectUnsafe::ObjectRefFromObjectPtr(nullptr); - } - } else { - TVM_FFI_ICHECK_NOTNULL(ptr); - } - return details::ObjectUnsafe::ObjectRefFromObjectPtr( - details::ObjectUnsafe::ObjectPtrFromUnowned( - const_cast(static_cast(ptr)))); -} - -/*! - * \brief Get an object ptr type from a raw object ptr. - * - * \param ptr The object pointer - * \tparam BaseType The reference type - * \tparam ObjectType The object type - * \return The corresponding RefType - */ -template -inline ObjectPtr GetObjectPtr(ObjectType* ptr) { - static_assert(std::is_base_of::value, - "Can only cast to the ref of same container type"); - return details::ObjectUnsafe::ObjectPtrFromUnowned(ptr); -} -} // namespace ffi -} // namespace tvm -#endif // TVM_FFI_CAST_H_ diff --git a/ffi/include/tvm/ffi/container/array.h b/ffi/include/tvm/ffi/container/array.h deleted file mode 100644 index db025c02d863..000000000000 --- a/ffi/include/tvm/ffi/container/array.h +++ /dev/null @@ -1,1147 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file tvm/ffi/container/array.h - * \brief Array type. - * - * tvm::ffi::Array is an erased type that contains a list of content - */ -#ifndef TVM_FFI_CONTAINER_ARRAY_H_ -#define TVM_FFI_CONTAINER_ARRAY_H_ - -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include - -namespace tvm { -namespace ffi { - -/*! \brief Array node content in array */ -class ArrayObj : public Object, public details::InplaceArrayBase { - public: - ~ArrayObj() { - Any* begin = MutableBegin(); - for (int64_t i = 0; i < size_; ++i) { - (begin + i)->Any::~Any(); - } - if (data_deleter_ != nullptr) { - data_deleter_(data_); - } - } - - /*! \return The size of the array */ - size_t size() const { return this->size_; } - - /*! - * \brief Read i-th element from array. - * \param i The index - * \return the i-th element. - */ - const Any& at(int64_t i) const { return this->operator[](i); } - - /*! - * \brief Read i-th element from array. - * \param i The index - * \return the i-th element. - */ - const Any& operator[](int64_t i) const { - if (i >= size_) { - TVM_FFI_THROW(IndexError) << "Index " << i << " out of bounds " << size_; - } - return static_cast(data_)[i]; - } - - /*! \return begin constant iterator */ - const Any* begin() const { return static_cast(data_); } - - /*! \return end constant iterator */ - const Any* end() const { return begin() + size_; } - - /*! \brief Release reference to all the elements */ - void clear() { ShrinkBy(size_); } - - /*! - * \brief Set i-th element of the array in-place - * \param i The index - * \param item The value to be set - */ - void SetItem(int64_t i, Any item) { - if (i >= size_) { - TVM_FFI_THROW(IndexError) << "Index " << i << " out of bounds " << size_; - } - static_cast(data_)[i] = std::move(item); - } - - /*! - * \brief Constructs a container and copy from another - * \param cap The capacity of the container - * \param from Source of the copy - * \return Ref-counted ArrayObj requested - */ - static ObjectPtr CopyFrom(int64_t cap, ArrayObj* from) { - int64_t size = from->size_; - if (size > cap) { - TVM_FFI_THROW(ValueError) << "Not enough capacity"; - } - ObjectPtr p = ArrayObj::Empty(cap); - Any* write = p->MutableBegin(); - Any* read = from->MutableBegin(); - // To ensure exception safety, size is only incremented after the initialization succeeds - for (int64_t& i = p->size_ = 0; i < size; ++i) { - new (write++) Any(*read++); - } - return p; - } - - /*! - * \brief Constructs a container and move from another - * \param cap The capacity of the container - * \param from Source of the move - * \return Ref-counted ArrayObj requested - */ - static ObjectPtr MoveFrom(int64_t cap, ArrayObj* from) { - int64_t size = from->size_; - if (size > cap) { - TVM_FFI_THROW(RuntimeError) << "Not enough capacity"; - } - ObjectPtr p = ArrayObj::Empty(cap); - Any* write = p->MutableBegin(); - Any* read = from->MutableBegin(); - // To ensure exception safety, size is only incremented after the initialization succeeds - for (int64_t& i = p->size_ = 0; i < size; ++i) { - new (write++) Any(std::move(*read++)); - } - from->size_ = 0; - return p; - } - - /*! - * \brief Constructs a container with n elements. Each element is a copy of val - * \param n The size of the container - * \param val The init value - * \return Ref-counted ArrayObj requested - */ - static ObjectPtr CreateRepeated(int64_t n, const Any& val) { - ObjectPtr p = ArrayObj::Empty(n); - Any* itr = p->MutableBegin(); - for (int64_t& i = p->size_ = 0; i < n; ++i) { - new (itr++) Any(val); - } - return p; - } - - /// \cond Doxygen_Suppress - static constexpr const int32_t _type_index = TypeIndex::kTVMFFIArray; - static const constexpr bool _type_final = true; - TVM_FFI_DECLARE_OBJECT_INFO_STATIC(StaticTypeKey::kTVMFFIArray, ArrayObj, Object); - /// \endcond - - private: - /*! \return Size of initialized memory, used by InplaceArrayBase. */ - size_t GetSize() const { return this->size_; } - - /*! \return begin mutable iterator */ - Any* MutableBegin() const { return static_cast(this->data_); } - - /*! \return end mutable iterator */ - Any* MutableEnd() const { return MutableBegin() + size_; } - - /*! - * \brief Emplace a new element at the back of the array - * \param idx The index of the element. - * \param args The arguments to construct the new element - */ - template - void EmplaceInit(size_t idx, Args&&... args) { - Any* itr = MutableBegin() + idx; - new (itr) Any(std::forward(args)...); - } - - /*! - * \brief Create an ArrayObj with the given capacity. - * \param n Required capacity - * \return Ref-counted ArrayObj requested - */ - static ObjectPtr Empty(int64_t n = kInitSize) { - ObjectPtr p = make_inplace_array_object(n); - p->capacity_ = n; - p->size_ = 0; - p->data_ = p->AddressOf(0); - return p; - } - - /*! - * \brief Inplace-initialize the elements starting idx from [first, last) - * \param idx The starting point - * \param first Begin of iterator - * \param last End of iterator - * \tparam IterType The type of iterator - * \return Self - */ - template - ArrayObj* InitRange(int64_t idx, IterType first, IterType last) { - Any* itr = MutableBegin() + idx; - for (; first != last; ++first) { - Any ref = *first; - new (itr++) Any(std::move(ref)); - } - return this; - } - - /*! - * \brief Move elements from right to left, requires src_begin > dst - * \param dst Destination - * \param src_begin The start point of copy (inclusive) - * \param src_end The end point of copy (exclusive) - * \return Self - */ - ArrayObj* MoveElementsLeft(int64_t dst, int64_t src_begin, int64_t src_end) { - Any* from = MutableBegin() + src_begin; - Any* to = MutableBegin() + dst; - while (src_begin++ != src_end) { - *to++ = std::move(*from++); - } - return this; - } - - /*! - * \brief Move elements from left to right, requires src_begin < dst - * \param dst Destination - * \param src_begin The start point of move (inclusive) - * \param src_end The end point of move (exclusive) - * \return Self - */ - ArrayObj* MoveElementsRight(int64_t dst, int64_t src_begin, int64_t src_end) { - Any* from = MutableBegin() + src_end; - Any* to = MutableBegin() + (src_end - src_begin + dst); - while (src_begin++ != src_end) { - *--to = std::move(*--from); - } - return this; - } - - /*! - * \brief Enlarges the size of the array - * \param delta Size enlarged, should be positive - * \param val Default value - * \return Self - */ - ArrayObj* EnlargeBy(int64_t delta, const Any& val = Any()) { - Any* itr = MutableEnd(); - while (delta-- > 0) { - new (itr++) Any(val); - ++size_; - } - return this; - } - - /*! - * \brief Shrinks the size of the array - * \param delta Size shrinked, should be positive - * \return Self - */ - ArrayObj* ShrinkBy(int64_t delta) { - Any* itr = MutableEnd(); - while (delta-- > 0) { - (--itr)->Any::~Any(); - --size_; - } - return this; - } - - /*! \brief Data pointer to the first element of the array */ - void* data_; - /*! \brief Number of elements used */ - int64_t size_; - /*! \brief Number of elements allocated */ - int64_t capacity_; - /*! - * \brief Optional data deleter when data is allocated separately - * and its deletion is not managed by ArrayObj::deleter_. - */ - void (*data_deleter_)(void*) = nullptr; - - /*! \brief Initial size of ArrayObj */ - static constexpr int64_t kInitSize = 4; - - /*! \brief Expansion factor of the Array */ - static constexpr int64_t kIncFactor = 2; - - // CRTP parent class - friend InplaceArrayBase; - - // Reference class - template - friend class Array; - - template - friend class Tuple; - - template - friend struct TypeTraits; - - // To specialize make_object - friend ObjectPtr make_object<>(); -}; - -/*! \brief Helper struct for type-checking - * - * is_valid_iterator::value will be true if IterType can - * be dereferenced into a type that can be stored in an Array, and - * false otherwise. - */ -template -struct is_valid_iterator - : std::bool_constant< - std::is_same_v< - T, std::remove_cv_t())>>> || - std::is_base_of_v< - T, std::remove_cv_t())>>>> { -}; - -template -struct is_valid_iterator, IterType> : is_valid_iterator {}; - -template -struct is_valid_iterator : std::true_type {}; - -/*! - * \brief Check whether IterType is valid iterator for T. - * \tparam T The type. - * \tparam IterType The type of iterator. - */ -template -inline constexpr bool is_valid_iterator_v = is_valid_iterator::value; - -/*! - * \brief Array, container representing a contiguous sequence of ObjectRefs. - * - * Array implements in-place copy-on-write semantics. - * - * As in typical copy-on-write, a method which would typically mutate the array - * instead opaquely copies the underlying container, and then acts on its copy. - * - * If the array has reference count equal to one, we directly update the - * container in place without copying. This is optimization is sound because - * when the reference count is equal to one this reference is guranteed to be - * the sole pointer to the container. - * - * - * operator[] only provides const access, use Set to mutate the content. - * \tparam T The content Value type, must be compatible with tvm::ffi::Any - */ -template >> -class Array : public ObjectRef { - public: - /*! \brief The value type of the array */ - using value_type = T; - // constructors - /*! - * \brief Construct an Array with UnsafeInit - */ - explicit Array(UnsafeInit tag) : ObjectRef(tag) {} - /*! - * \brief default constructor - */ - Array() { data_ = ArrayObj::Empty(); } - /*! - * \brief Move constructor - * \param other The other array - */ - Array(Array&& other) : ObjectRef(std::move(other.data_)) {} - /*! - * \brief Copy constructor - * \param other The other array - */ - Array(const Array& other) : ObjectRef(other.data_) {} - /*! - * \brief Constructor from another array - * \param other The other array - * \tparam U The value type of the other array - */ - template >> - Array(Array&& other) : ObjectRef(std::move(other.data_)) {} - /*! - * \brief Constructor from another array - * \param other The other array - * \tparam U The value type of the other array - */ - template >> - Array(const Array& other) : ObjectRef(other.data_) {} - - /*! - * \brief Move assignment from another array - * \param other The other array - */ - TVM_FFI_INLINE Array& operator=(Array&& other) { - data_ = std::move(other.data_); - return *this; - } - /*! - * \brief Assignment from another array - * \param other The other array - */ - TVM_FFI_INLINE Array& operator=(const Array& other) { - data_ = other.data_; - return *this; - } - /*! - * \brief Move assignment from another array - * \param other The other array - * \tparam U The value type of the other array - */ - template >> - TVM_FFI_INLINE Array& operator=(Array&& other) { - data_ = std::move(other.data_); - return *this; - } - /*! - * \brief Assignment from another array - * \param other The other array - * \tparam U The value type of the other array - */ - template >> - TVM_FFI_INLINE Array& operator=(const Array& other) { - data_ = other.data_; - return *this; - } - - /*! - * \brief Constructor from pointer - * \param n the container pointer - */ - explicit Array(ObjectPtr n) : ObjectRef(n) {} - - /*! - * \brief Constructor from iterator - * \param first begin of iterator - * \param last end of iterator - * \tparam IterType The type of iterator - */ - template - Array(IterType first, IterType last) { - static_assert(is_valid_iterator_v, - "IterType cannot be inserted into a tvm::Array"); - Assign(first, last); - } - - /*! - * \brief constructor from initializer list - * \param init The initializer list - */ - Array(std::initializer_list init) { // NOLINT(*) - Assign(init.begin(), init.end()); - } - - /*! - * \brief constructor from vector - * \param init The vector - */ - Array(const std::vector& init) { // NOLINT(*) - Assign(init.begin(), init.end()); - } - - /*! - * \brief Constructs a container with n elements. Each element is a copy of val - * \param n The size of the container - * \param val The init value - */ - explicit Array(const size_t n, const T& val) { data_ = ArrayObj::CreateRepeated(n, val); } - - public: - // iterators - /// \cond Doxygen_Suppress - struct ValueConverter { - using ResultType = T; - /*! - * \brief Convert any to T - * \param n The any value to convert - * \return The converted value - */ - static T convert(const Any& n) { return details::AnyUnsafe::CopyFromAnyViewAfterCheck(n); } - }; - /// \endcond - - /*! \brief The iterator type of the array */ - using iterator = details::IterAdapter; - /*! \brief The reverse iterator type of the array */ - using reverse_iterator = details::ReverseIterAdapter; - - /*! \return begin iterator */ - iterator begin() const { return iterator(GetArrayObj()->begin()); } - - /*! \return end iterator */ - iterator end() const { return iterator(GetArrayObj()->end()); } - - /*! \return rbegin iterator */ - reverse_iterator rbegin() const { - // ArrayObj::end() is never nullptr - return reverse_iterator(GetArrayObj()->end() - 1); - } - - /*! \return rend iterator */ - reverse_iterator rend() const { - // ArrayObj::begin() is never nullptr - return reverse_iterator(GetArrayObj()->begin() - 1); - } - - public: - // const methods in std::vector - /*! - * \brief Immutably read i-th element from array. - * \param i The index - * \return the i-th element. - */ - const T operator[](int64_t i) const { - ArrayObj* p = GetArrayObj(); - if (p == nullptr) { - TVM_FFI_THROW(IndexError) << "cannot index a null array"; - } - if (i < 0 || i >= p->size_) { - TVM_FFI_THROW(IndexError) << "indexing " << i << " on an array of size " << p->size_; - } - return details::AnyUnsafe::CopyFromAnyViewAfterCheck(*(p->begin() + i)); - } - - /*! \return The size of the array */ - size_t size() const { - ArrayObj* p = GetArrayObj(); - return p == nullptr ? 0 : GetArrayObj()->size_; - } - - /*! \return The capacity of the array */ - size_t capacity() const { - ArrayObj* p = GetArrayObj(); - return p == nullptr ? 0 : GetArrayObj()->capacity_; - } - - /*! \return Whether array is empty */ - bool empty() const { return size() == 0; } - - /*! \return The first element of the array */ - const T front() const { - ArrayObj* p = GetArrayObj(); - if (p == nullptr || p->size_ == 0) { - TVM_FFI_THROW(IndexError) << "cannot index a empty array"; - } - return details::AnyUnsafe::CopyFromAnyViewAfterCheck(*(p->begin())); - } - - /*! \return The last element of the array */ - const T back() const { - ArrayObj* p = GetArrayObj(); - if (p == nullptr || p->size_ == 0) { - TVM_FFI_THROW(IndexError) << "cannot index a empty array"; - } - return details::AnyUnsafe::CopyFromAnyViewAfterCheck(*(p->end() - 1)); - } - - public: - // mutation in std::vector, implements copy-on-write - /*! - * \brief push a new item to the back of the list - * \param item The item to be pushed. - */ - void push_back(const T& item) { - ArrayObj* p = CopyOnWrite(1); - p->EmplaceInit(p->size_++, item); - } - - /*! - * \brief Emplace a new element at the back of the array - * \param args The arguments to construct the new element - */ - template - void emplace_back(Args&&... args) { - ArrayObj* p = CopyOnWrite(1); - p->EmplaceInit(p->size_++, std::forward(args)...); - } - - /*! - * \brief Insert an element into the given position - * \param position An iterator pointing to the insertion point - * \param val The element to insert - */ - void insert(iterator position, const T& val) { - if (data_ == nullptr) { - TVM_FFI_THROW(RuntimeError) << "cannot insert a null array"; - } - int64_t idx = std::distance(begin(), position); - int64_t size = GetArrayObj()->size_; - auto addr = CopyOnWrite(1) // - ->EnlargeBy(1) // - ->MoveElementsRight(idx + 1, idx, size) // - ->MutableBegin(); - new (addr + idx) Any(val); - } - - /*! - * \brief Insert a range of elements into the given position - * \param position An iterator pointing to the insertion point - * \param first The begin iterator of the range - * \param last The end iterator of the range - */ - template - void insert(iterator position, IterType first, IterType last) { - static_assert(is_valid_iterator_v, - "IterType cannot be inserted into a tvm::Array"); - - if (first == last) { - return; - } - if (data_ == nullptr) { - TVM_FFI_THROW(RuntimeError) << "cannot insert a null array"; - } - int64_t idx = std::distance(begin(), position); - int64_t size = GetArrayObj()->size_; - int64_t numel = std::distance(first, last); - CopyOnWrite(numel) - ->EnlargeBy(numel) - ->MoveElementsRight(idx + numel, idx, size) - ->InitRange(idx, first, last); - } - - /*! \brief Remove the last item of the list */ - void pop_back() { - if (data_ == nullptr) { - TVM_FFI_THROW(RuntimeError) << "cannot pop_back a null array"; - } - int64_t size = GetArrayObj()->size_; - if (size == 0) { - TVM_FFI_THROW(RuntimeError) << "cannot pop_back an empty array"; - } - CopyOnWrite()->ShrinkBy(1); - } - - /*! - * \brief Erase an element on the given position - * \param position An iterator pointing to the element to be erased - */ - void erase(iterator position) { - if (data_ == nullptr) { - TVM_FFI_THROW(RuntimeError) << "cannot erase a null array"; - } - int64_t st = std::distance(begin(), position); - int64_t size = GetArrayObj()->size_; - if (st < 0 || st >= size) { - TVM_FFI_THROW(RuntimeError) << "cannot erase at index " << st << ", because Array size is " - << size; - } - CopyOnWrite() // - ->MoveElementsLeft(st, st + 1, size) // - ->ShrinkBy(1); - } - - /*! - * \brief Erase a given range of elements - * \param first The begin iterator of the range - * \param last The end iterator of the range - */ - void erase(iterator first, iterator last) { - if (first == last) { - return; - } - if (data_ == nullptr) { - TVM_FFI_THROW(RuntimeError) << "cannot erase a null array"; - } - int64_t size = GetArrayObj()->size_; - int64_t st = std::distance(begin(), first); - int64_t ed = std::distance(begin(), last); - if (st >= ed) { - TVM_FFI_THROW(IndexError) << "cannot erase array in range [" << st << ", " << ed << ")"; - } - if (st < 0 || st > size || ed < 0 || ed > size) { - TVM_FFI_THROW(IndexError) << "cannot erase array in range [" << st << ", " << ed << ")" - << ", because array size is " << size; - } - CopyOnWrite() // - ->MoveElementsLeft(st, ed, size) // - ->ShrinkBy(ed - st); - } - - /*! - * \brief Resize the array. - * \param n The new size. - */ - void resize(int64_t n) { - if (n < 0) { - TVM_FFI_THROW(ValueError) << "cannot resize an Array to negative size"; - } - if (data_ == nullptr) { - SwitchContainer(n); - return; - } - int64_t size = GetArrayObj()->size_; - if (size < n) { - CopyOnWrite(n - size)->EnlargeBy(n - size); - } else if (size > n) { - CopyOnWrite()->ShrinkBy(size - n); - } - } - - /*! - * \brief Make sure the list has the capacity of at least n - * \param n lower bound of the capacity - */ - void reserve(int64_t n) { - if (data_ == nullptr || n > GetArrayObj()->capacity_) { - SwitchContainer(n); - } - } - - /*! \brief Release reference to all the elements */ - void clear() { - if (data_ != nullptr) { - ArrayObj* p = CopyOnWrite(); - p->clear(); - } - } - /// \cond Doxygen_Suppress - template - static size_t CalcCapacityImpl() { - return 0; - } - - template - static size_t CalcCapacityImpl(Array value, Args... args) { - return value.size() + CalcCapacityImpl(args...); - } - - template - static size_t CalcCapacityImpl(T value, Args... args) { - return 1 + CalcCapacityImpl(args...); - } - - template - static void AgregateImpl(Array& dest) {} // NOLINT(*) - - template - static void AgregateImpl(Array& dest, Array value, Args... args) { // NOLINT(*) - dest.insert(dest.end(), value.begin(), value.end()); - AgregateImpl(dest, args...); - } - - template - static void AgregateImpl(Array& dest, T value, Args... args) { // NOLINT(*) - dest.push_back(value); - AgregateImpl(dest, args...); - } - /// \endcond - - public: - // Array's own methods - - /*! - * \brief set i-th element of the array. - * \param i The index - * \param value The value to be setted. - */ - void Set(int64_t i, T value) { - ArrayObj* p = this->CopyOnWrite(); - if (i < 0 || i >= p->size_) { - TVM_FFI_THROW(IndexError) << "indexing " << i << " on an array of size " << p->size_; - } - *(p->MutableBegin() + i) = std::move(value); - } - - /*! \return The underlying ArrayObj */ - ArrayObj* GetArrayObj() const { return static_cast(data_.get()); } - - /*! - * \brief Helper function to apply a map function onto the array. - * - * \param fmap The transformation function T -> U. - * - * \tparam F The type of the mutation function. - * - * \tparam U The type of the returned array, inferred from the - * return type of F. If overridden by the user, must be something - * that is convertible from the return type of F. - * - * \note This function performs copy on write optimization. If - * `fmap` returns an object of type `T`, and all elements of the - * array are mapped to themselves, then the returned array will be - * the same as the original, and reference counts of the elements in - * the array will not be incremented. - * - * \return The transformed array. - */ - template > - Array Map(F fmap) const { - return Array(MapHelper(data_, fmap)); - } - - /*! - * \brief Helper function to apply fmutate to mutate an array. - * \param fmutate The transformation function T -> T. - * \tparam F the type of the mutation function. - * \note This function performs copy on write optimization. - */ - template >>> - void MutateByApply(F fmutate) { - data_ = MapHelper(std::move(data_), fmutate); - } - - /*! - * \brief reset the array to content from iterator. - * \param first begin of iterator - * \param last end of iterator - * \tparam IterType The type of iterator - */ - template - void Assign(IterType first, IterType last) { - int64_t cap = std::distance(first, last); - if (cap < 0) { - TVM_FFI_THROW(ValueError) << "cannot construct an Array of negative size"; - } - ArrayObj* p = GetArrayObj(); - if (p != nullptr && data_.unique() && p->capacity_ >= cap) { - // do not have to make new space - p->clear(); - } else { - // create new space - data_ = ArrayObj::Empty(cap); - p = GetArrayObj(); - } - // To ensure exception safety, size is only incremented after the initialization succeeds - Any* itr = p->MutableBegin(); - for (int64_t& i = p->size_ = 0; i < cap; ++i, ++first, ++itr) { - new (itr) Any(*first); - } - } - - /*! - * \brief Copy on write semantics - * Do nothing if current handle is the unique copy of the array. - * Otherwise make a new copy of the array to ensure the current handle - * hold a unique copy. - * - * \return Handle to the internal node container(which ganrantees to be unique) - */ - ArrayObj* CopyOnWrite() { - if (data_ == nullptr) { - return SwitchContainer(ArrayObj::kInitSize); - } - if (!data_.unique()) { - return SwitchContainer(capacity()); - } - return static_cast(data_.get()); - } - - /*! \brief specify container node */ - using ContainerType = ArrayObj; - - /*! - * \brief Agregate arguments into a single Array - * \param args sequence of T or Array elements - * \return Agregated Array - */ - template - static Array Agregate(Args... args) { - Array result; - result.reserve(CalcCapacityImpl(args...)); - AgregateImpl(result, args...); - return result; - } - - private: - /*! - * \brief Implement copy-on-write semantics, and ensures capacity is enough for extra elements. - * \param reserve_extra Number of extra slots needed - * \return ArrayObj pointer to the unique copy - */ - ArrayObj* CopyOnWrite(int64_t reserve_extra) { - ArrayObj* p = GetArrayObj(); - if (p == nullptr) { - // necessary to get around the constexpr address issue before c++17 - const int64_t kInitSize = ArrayObj::kInitSize; - return SwitchContainer(std::max(kInitSize, reserve_extra)); - } - if (p->capacity_ >= p->size_ + reserve_extra) { - return CopyOnWrite(); - } - int64_t cap = p->capacity_ * ArrayObj::kIncFactor; - cap = std::max(cap, p->size_ + reserve_extra); - return SwitchContainer(cap); - } - - /*! - * \brief Move or copy the ArrayObj to new address with the given capacity - * \param capacity The capacity requirement of the new address - */ - ArrayObj* SwitchContainer(int64_t capacity) { - if (data_ == nullptr) { - data_ = ArrayObj::Empty(capacity); - } else if (data_.unique()) { - data_ = ArrayObj::MoveFrom(capacity, GetArrayObj()); - } else { - data_ = ArrayObj::CopyFrom(capacity, GetArrayObj()); - } - return static_cast(data_.get()); - } - - /*! \brief Helper method for mutate/map - * - * A helper function used internally by both `Array::Map` and - * `Array::MutateInPlace`. Given an array of data, apply the - * mapping function to each element, returning the collected array. - * Applies both mutate-in-place and copy-on-write optimizations, if - * possible. - * - * \param data A pointer to the ArrayObj containing input data. - * Passed by value to allow for mutate-in-place optimizations. - * - * \param fmap The mapping function - * - * \tparam F The type of the mutation function. - * - * \tparam U The output type of the mutation function. Inferred - * from the callable type given. Must inherit from ObjectRef. - * - * \return The mapped array. Depending on whether mutate-in-place - * or copy-on-write optimizations were applicable, may be the same - * underlying array as the `data` parameter. - */ - template > - static ObjectPtr MapHelper(ObjectPtr data, F fmap) { - if (data == nullptr) { - return nullptr; - } - - TVM_FFI_ICHECK(data->IsInstance()); - - constexpr bool is_same_output_type = std::is_same_v; - - if constexpr (is_same_output_type) { - if (data.unique()) { - // Mutate-in-place path. Only allowed if the output type U is - // the same as type T, we have a mutable this*, and there are - // no other shared copies of the array. - auto arr = static_cast(data.get()); - for (auto it = arr->MutableBegin(); it != arr->MutableEnd(); it++) { - T value = details::AnyUnsafe::CopyFromAnyViewAfterCheck(*it); - // reset the original value to nullptr, to ensure unique ownership - it->reset(); - T mapped = fmap(std::move(value)); - *it = std::move(mapped); - } - return data; - } - } - - constexpr bool compatible_types = is_valid_iterator_v || is_valid_iterator_v; - - ObjectPtr output = nullptr; - auto arr = static_cast(data.get()); - - auto it = arr->begin(); - if constexpr (compatible_types) { - // Copy-on-write path, if the output Array might be - // represented by the same underlying array as the existing - // Array. Typically, this is for functions that map `T` to - // `T`, but can also apply to functions that map `T` to - // `Optional`, or that map `T` to a subclass or superclass of - // `T`. - bool all_identical = true; - for (; it != arr->end(); it++) { - U mapped = fmap(details::AnyUnsafe::CopyFromAnyViewAfterCheck(*it)); - if (!(*it).same_as(mapped)) { - // At least one mapped element is different than the - // original. Therefore, prepare the output array, - // consisting of any previous elements that had mapped to - // themselves (if any), and the element that didn't map to - // itself. - // - // We cannot use `U()` as the default object, as `U` may be - // a non-nullable type. Since the default `Any()` - // will be overwritten before returning, all objects will be - // of type `U` for the calling scope. - all_identical = false; - output = ArrayObj::CreateRepeated(arr->size(), Any()); - output->InitRange(0, arr->begin(), it); - output->SetItem(it - arr->begin(), std::move(mapped)); - it++; - break; - } - } - if (all_identical) { - return data; - } - } else { - // Path for incompatible types. The constexpr check for - // compatible types isn't strictly necessary, as the first - // (*it).same_as(mapped) would return false, but we might as well - // avoid it altogether. - // - // We cannot use `U()` as the default object, as `U` may be a - // non-nullable type. Since the default `Any()` will be - // overwritten before returning, all objects will be of type `U` - // for the calling scope. - output = ArrayObj::CreateRepeated(arr->size(), Any()); - } - - // Normal path for incompatible types, or post-copy path for - // copy-on-write instances. - // - // If the types are incompatible, then at this point `output` is - // empty, and `it` points to the first element of the input. - // - // If the types were compatible, then at this point `output` - // contains zero or more elements that mapped to themselves - // followed by the first element that does not map to itself, and - // `it` points to the element just after the first element that - // does not map to itself. Because at least one element has been - // changed, we no longer have the opportunity to avoid a copy, so - // we don't need to check the result. - // - // In both cases, `it` points to the next element to be processed, - // so we can either start or resume the iteration from that point, - // with no further checks on the result. - for (; it != arr->end(); it++) { - U mapped = fmap(details::AnyUnsafe::CopyFromAnyViewAfterCheck(*it)); - output->SetItem(it - arr->begin(), std::move(mapped)); - } - - return output; - } - template - friend class Array; -}; - -/*! - * \brief Concat two Arrays. - * \param lhs first Array to be concatenated. - * \param rhs second Array to be concatenated. - * \return The concatenated Array. Original Arrays are kept unchanged. - */ -template || - TypeTraits::convert_enabled>> -inline Array Concat(Array lhs, const Array& rhs) { - for (const auto& x : rhs) { - lhs.push_back(x); - } - return std::move(lhs); -} - -/*! - * \brief Specialize make_object - * \return The empty array object. - */ -template <> -inline ObjectPtr make_object() { - return ArrayObj::Empty(); -} - -// Traits for Array -template -inline constexpr bool use_default_type_traits_v> = false; - -template -struct TypeTraits> : public ObjectRefTypeTraitsBase> { - static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIArray; - using ObjectRefTypeTraitsBase>::CopyFromAnyViewAfterCheck; - - TVM_FFI_INLINE static std::string GetMismatchTypeInfo(const TVMFFIAny* src) { - if (src->type_index != TypeIndex::kTVMFFIArray) { - return TypeTraitsBase::GetMismatchTypeInfo(src); - } - if constexpr (!std::is_same_v) { - const ArrayObj* n = reinterpret_cast(src->v_obj); - for (size_t i = 0; i < n->size(); i++) { - const Any& any_v = (*n)[i]; - // CheckAnyStrict is cheaper than try_cast - if (details::AnyUnsafe::CheckAnyStrict(any_v)) continue; - // try see if p is convertible to T - if (any_v.try_cast()) continue; - // now report the accurate mismatch information - return "Array[index " + std::to_string(i) + ": " + - details::AnyUnsafe::GetMismatchTypeInfo(any_v) + "]"; - } - } - TVM_FFI_THROW(InternalError) << "Cannot reach here"; - TVM_FFI_UNREACHABLE(); - } - - TVM_FFI_INLINE static bool CheckAnyStrict(const TVMFFIAny* src) { - if (src->type_index != TypeIndex::kTVMFFIArray) return false; - if constexpr (std::is_same_v) { - return true; - } else { - const ArrayObj* n = reinterpret_cast(src->v_obj); - for (size_t i = 0; i < n->size(); i++) { - const Any& any_v = (*n)[i]; - if (!details::AnyUnsafe::CheckAnyStrict(any_v)) return false; - } - return true; - } - } - - TVM_FFI_INLINE static std::optional> TryCastFromAnyView(const TVMFFIAny* src) { - // try to run conversion. - if (src->type_index != TypeIndex::kTVMFFIArray) return std::nullopt; - if constexpr (!std::is_same_v) { - const ArrayObj* n = reinterpret_cast(src->v_obj); - bool storage_check = [&]() { - for (size_t i = 0; i < n->size(); i++) { - const Any& any_v = (*n)[i]; - if (!details::AnyUnsafe::CheckAnyStrict(any_v)) return false; - } - return true; - }(); - // fast path, if storage check passes, we can return the array directly. - if (storage_check) { - return CopyFromAnyViewAfterCheck(src); - } - // slow path, try to run a conversion to Array - Array result; - result.reserve(n->size()); - for (size_t i = 0; i < n->size(); i++) { - const Any& any_v = (*n)[i]; - if (auto opt_v = any_v.try_cast()) { - result.push_back(*std::move(opt_v)); - } else { - return std::nullopt; - } - } - return result; - } else { - return CopyFromAnyViewAfterCheck(src); - } - } - - TVM_FFI_INLINE static std::string TypeStr() { return "Array<" + details::Type2Str::v() + ">"; } -}; - -namespace details { -template -inline constexpr bool type_contains_v, Array> = type_contains_v; -} // namespace details - -} // namespace ffi -} // namespace tvm -#endif // TVM_FFI_CONTAINER_ARRAY_H_ diff --git a/ffi/include/tvm/ffi/container/container_details.h b/ffi/include/tvm/ffi/container/container_details.h deleted file mode 100644 index bb29a14f7cb8..000000000000 --- a/ffi/include/tvm/ffi/container/container_details.h +++ /dev/null @@ -1,356 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file tvm/ffi/container/container_details.h - * \brief Common utilities for typed container types. - */ -#ifndef TVM_FFI_CONTAINER_CONTAINER_DETAILS_H_ -#define TVM_FFI_CONTAINER_CONTAINER_DETAILS_H_ - -#include -#include - -#include -#include -#include -#include - -namespace tvm { -namespace ffi { -namespace details { -/*! - * \brief Base template for classes with array like memory layout. - * - * It provides general methods to access the memory. The memory - * layout is ArrayType + [ElemType]. The alignment of ArrayType - * and ElemType is handled by the memory allocator. - * - * \tparam ArrayType The array header type, contains object specific metadata. - * \tparam ElemType The type of objects stored in the array right after - * ArrayType. - * - * \code - * // Example usage of the template to define a simple array wrapper - * class ArrayObj : public tvm::ffi::details::InplaceArrayBase { - * public: - * // Wrap EmplaceInit to initialize the elements - * template - * void Init(Iterator begin, Iterator end) { - * size_t num_elems = std::distance(begin, end); - * auto it = begin; - * this->size = 0; - * for (size_t i = 0; i < num_elems; ++i) { - * InplaceArrayBase::EmplaceInit(i, *it++); - * this->size++; - * } - * } - * } - * - * void test_function() { - * vector fields; - * auto ptr = make_inplace_array_object(fields.size()); - * ptr->Init(fields.begin(), fields.end()); - * - * // Access the 0th element in the array. - * assert(ptr->operator[](0) == fields[0]); - * } - * - * \endcode - */ -template -class InplaceArrayBase { - public: - /*! - * \brief Access element at index - * \param idx The index of the element. - * \return Const reference to ElemType at the index. - */ - const ElemType& operator[](size_t idx) const { - size_t size = Self()->GetSize(); - if (idx > size) { - TVM_FFI_THROW(IndexError) << "Index " << idx << " out of bounds " << size; - } - return *(reinterpret_cast(AddressOf(idx))); - } - - /*! - * \brief Access element at index - * \param idx The index of the element. - * \return Reference to ElemType at the index. - */ - ElemType& operator[](size_t idx) { - size_t size = Self()->GetSize(); - if (idx > size) { - TVM_FFI_THROW(IndexError) << "Index " << idx << " out of bounds " << size; - } - return *(reinterpret_cast(AddressOf(idx))); - } - - /*! - * \brief Destroy the Inplace Array Base object - */ - ~InplaceArrayBase() { - if constexpr (!(std::is_standard_layout::value && std::is_trivial::value)) { - size_t size = Self()->GetSize(); - for (size_t i = 0; i < size; ++i) { - ElemType* fp = reinterpret_cast(AddressOf(i)); - fp->ElemType::~ElemType(); - } - } - } - - protected: - /*! - * \brief Construct a value in place with the arguments. - * - * \tparam Args Type parameters of the arguments. - * \param idx Index of the element. - * \param args Arguments to construct the new value. - * - * \note Please make sure ArrayType::GetSize returns 0 before first call of - * EmplaceInit, and increment GetSize by 1 each time EmplaceInit succeeds. - */ - template - void EmplaceInit(size_t idx, Args&&... args) { - void* field_ptr = AddressOf(idx); - new (field_ptr) ElemType(std::forward(args)...); - } - - /*! - * \brief Return the self object for the array. - * - * \return Pointer to ArrayType. - */ - inline ArrayType* Self() const { - return static_cast(const_cast(this)); - } - - /*! - * \brief Return the raw pointer to the element at idx. - * - * \param idx The index of the element. - * \return Raw pointer to the element. - */ - void* AddressOf(size_t idx) const { - static_assert( - alignof(ArrayType) % alignof(ElemType) == 0 && sizeof(ArrayType) % alignof(ElemType) == 0, - "The size and alignment of ArrayType should respect " - "ElemType's alignment."); - - size_t kDataStart = sizeof(ArrayType); - ArrayType* self = Self(); - char* data_start = reinterpret_cast(self) + kDataStart; - return data_start + idx * sizeof(ElemType); - } -}; - -/*! - * \brief iterator adapter that adapts TIter to return another type. - * \tparam Converter a struct that contains converting function - * \tparam TIter the content iterator type. - */ -template -class IterAdapter { - public: - using difference_type = typename std::iterator_traits::difference_type; - using value_type = typename Converter::ResultType; - using pointer = typename Converter::ResultType*; - using reference = typename Converter::ResultType&; - using iterator_category = typename std::iterator_traits::iterator_category; - - explicit IterAdapter(TIter iter) : iter_(iter) {} - IterAdapter& operator++() { - ++iter_; - return *this; - } - IterAdapter& operator--() { - --iter_; - return *this; - } - IterAdapter operator++(int) { - IterAdapter copy = *this; - ++iter_; - return copy; - } - IterAdapter operator--(int) { - IterAdapter copy = *this; - --iter_; - return copy; - } - - IterAdapter operator+(difference_type offset) const { return IterAdapter(iter_ + offset); } - - IterAdapter operator-(difference_type offset) const { return IterAdapter(iter_ - offset); } - - IterAdapter& operator+=(difference_type offset) { - iter_ += offset; - return *this; - } - - IterAdapter& operator-=(difference_type offset) { - iter_ -= offset; - return *this; - } - - template - typename std::enable_if::value, - typename T::difference_type>::type inline - operator-(const IterAdapter& rhs) const { - return iter_ - rhs.iter_; - } - - bool operator==(IterAdapter other) const { return iter_ == other.iter_; } - bool operator!=(IterAdapter other) const { return !(*this == other); } - const value_type operator*() const { return Converter::convert(*iter_); } - - private: - TIter iter_; -}; - -/*! - * \brief iterator adapter that adapts TIter to return another type. - * \tparam Converter a struct that contains converting function - * \tparam TIter the content iterator type. - */ -template -class ReverseIterAdapter { - public: - using difference_type = typename std::iterator_traits::difference_type; - using value_type = typename Converter::ResultType; - using pointer = typename Converter::ResultType*; - using reference = typename Converter::ResultType&; // NOLINT(*) - using iterator_category = typename std::iterator_traits::iterator_category; - - explicit ReverseIterAdapter(TIter iter) : iter_(iter) {} - ReverseIterAdapter& operator++() { - --iter_; - return *this; - } - ReverseIterAdapter& operator--() { - ++iter_; - return *this; - } - ReverseIterAdapter operator++(int) { - ReverseIterAdapter copy = *this; - --iter_; - return copy; - } - ReverseIterAdapter operator--(int) { - ReverseIterAdapter copy = *this; - ++iter_; - return copy; - } - ReverseIterAdapter operator+(difference_type offset) const { - return ReverseIterAdapter(iter_ - offset); - } - - template - typename std::enable_if::value, - typename T::difference_type>::type inline - operator-(const ReverseIterAdapter& rhs) const { - return rhs.iter_ - iter_; - } - - bool operator==(ReverseIterAdapter other) const { return iter_ == other.iter_; } - bool operator!=(ReverseIterAdapter other) const { return !(*this == other); } - const value_type operator*() const { return Converter::convert(*iter_); } - - private: - TIter iter_; -}; - -/*! - * \brief Check if T is compatible with Any. - * - * \tparam T The type to check. - * \return True if T is compatible with Any, false otherwise. - */ -template -inline constexpr bool storage_enabled_v = std::is_same_v || TypeTraits::storage_enabled; - -/*! - * \brief Check if all T are compatible with Any. - * - * \tparam T The type to check. - * \return True if T is compatible with Any, false otherwise. - */ -template -inline constexpr bool all_storage_enabled_v = (storage_enabled_v && ...); - -/*! - * \brief Check if all T are compatible with Any. - * - * \tparam T The type to check. - * \return True if T is compatible with Any, false otherwise. - */ -template -inline constexpr bool all_object_ref_v = (std::is_base_of_v && ...); -/** - * \brief Check if Any storage of Derived can always be directly used as Base. - * - * \tparam Base The base type. - * \tparam Derived The derived type. - * \return True if Derived's storage can be used as Base's storage, false otherwise. - */ -template -inline constexpr bool type_contains_v = - std::is_base_of_v || std::is_same_v; -// special case for Any -template -inline constexpr bool type_contains_v = true; - -/*! - * \brief Create a string of the container type. - * \tparam V The types of the elements in the container. - * \param name The name of the container type. - * \return A string of the container type. - */ -template -std::string ContainerTypeStr(const char* name) { - std::stringstream ss; - // helper to construct concated string of TypeStr - class TypeStrHelper { - public: - TypeStrHelper(std::stringstream& stream) : stream_(stream) {} // NOLINT(*) - - TypeStrHelper& operator<<(const std::string& str) { - if (counter_ > 0) { - stream_ << ", "; - } - stream_ << str; - counter_++; - return *this; - } - - private: - std::stringstream& stream_; // NOLINT(*) - int counter_ = 0; - }; - TypeStrHelper helper(ss); - ss << name << '<'; - (helper << ... << Type2Str::v()); - ss << '>'; - return ss.str(); -} - -} // namespace details -} // namespace ffi -} // namespace tvm -#endif // TVM_FFI_CONTAINER_CONTAINER_DETAILS_H_ diff --git a/ffi/include/tvm/ffi/container/map.h b/ffi/include/tvm/ffi/container/map.h deleted file mode 100644 index 471904502cfb..000000000000 --- a/ffi/include/tvm/ffi/container/map.h +++ /dev/null @@ -1,1762 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file tvm/ffi/container/map.h - * \brief Runtime Map container types. - */ -#ifndef TVM_FFI_CONTAINER_MAP_H_ -#define TVM_FFI_CONTAINER_MAP_H_ - -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include - -namespace tvm { -namespace ffi { - -/// \cond Doxygen_Suppress -#if TVM_FFI_DEBUG_WITH_ABI_CHANGE -#define TVM_FFI_MAP_FAIL_IF_CHANGED() \ - TVM_FFI_ICHECK(state_marker == self->state_marker) << "Concurrent modification of the Map"; -#else -#define TVM_FFI_MAP_FAIL_IF_CHANGED() -#endif // TVM_FFI_DEBUG_WITH_ABI_CHANGE -/// \endcond - -/*! \brief Shared content of all specializations of hash map */ -class MapObj : public Object { - public: - /*! \brief Type of the keys in the hash map */ - using key_type = Any; - /*! \brief Type of the values in the hash map */ - using mapped_type = Any; - /*! \brief Type of value stored in the hash map */ - using KVType = std::pair; - /// \cond Doxygen_Suppress - /*! \brief Type of raw storage of the key-value pair in the hash map */ - struct KVRawStorageType { - TVMFFIAny first; - TVMFFIAny second; - }; - /// \endcond - /*! \brief Iterator class */ - class iterator; - - static_assert(std::is_standard_layout::value, "KVType is not standard layout"); - static_assert(sizeof(KVType) == 32, "sizeof(KVType) incorrect"); - - /// \cond Doxygen_Suppress - static constexpr const int32_t _type_index = TypeIndex::kTVMFFIMap; - static const constexpr bool _type_final = true; - TVM_FFI_DECLARE_OBJECT_INFO_STATIC(StaticTypeKey::kTVMFFIMap, MapObj, Object); - /// \endcond - - /*! - * \brief Number of elements in the MapObj - * \return The result - */ - size_t size() const { return size_; } - /*! - * \brief Count the number of times a key exists in the hash map - * \param key The indexing key - * \return The result, 0 or 1 - */ - size_t count(const key_type& key) const; - /*! - * \brief Index value associated with a key, throw exception if the key does not exist - * \param key The indexing key - * \return The const reference to the value - */ - const mapped_type& at(const key_type& key) const; - /*! - * \brief Index value associated with a key, throw exception if the key does not exist - * \param key The indexing key - * \return The mutable reference to the value - */ - mapped_type& at(const key_type& key); - /*! \return begin iterator */ - iterator begin() const; - /*! \return end iterator */ - iterator end() const; - /*! - * \brief Index value associated with a key - * \param key The indexing key - * \return The iterator of the entry associated with the key, end iterator if not exists - */ - iterator find(const key_type& key) const; - /*! - * \brief Erase the entry associated with the iterator - * \param position The iterator - */ - void erase(const iterator& position); - /*! - * \brief Erase the entry associated with the key, do nothing if not exists - * \param key The indexing key - */ - void erase(const key_type& key) { erase(find(key)); } - - /// \cond Doxygen_Suppress - class iterator { - public: - using iterator_category = std::forward_iterator_tag; - using difference_type = int64_t; - using value_type = KVType; - using pointer = KVType*; - using reference = KVType&; -/*! \brief Default constructor */ -#if TVM_FFI_DEBUG_WITH_ABI_CHANGE - iterator() : state_marker(0), index(0), self(nullptr) {} -#else - iterator() : index(0), self(nullptr) {} -#endif // TVM_FFI_DEBUG_WITH_ABI_CHANGE - /*! \brief Compare iterators */ - bool operator==(const iterator& other) const { - TVM_FFI_MAP_FAIL_IF_CHANGED() - return index == other.index && self == other.self; - } - /*! \brief Compare iterators */ - bool operator!=(const iterator& other) const { return !(*this == other); } - /*! \brief De-reference iterators */ - pointer operator->() const; - /*! \brief De-reference iterators */ - reference operator*() const { - TVM_FFI_MAP_FAIL_IF_CHANGED() - return *((*this).operator->()); - } - /*! \brief Prefix self increment, e.g. ++iter */ - iterator& operator++(); - /*! \brief Prefix self decrement, e.g. --iter */ - iterator& operator--(); - /*! \brief Suffix self increment */ - iterator operator++(int) { - TVM_FFI_MAP_FAIL_IF_CHANGED() - iterator copy = *this; - ++(*this); - return copy; - } - /*! \brief Suffix self decrement */ - iterator operator--(int) { - TVM_FFI_MAP_FAIL_IF_CHANGED() - iterator copy = *this; - --(*this); - return copy; - } - - protected: -#if TVM_FFI_DEBUG_WITH_ABI_CHANGE - uint64_t state_marker; - /*! \brief Construct by value */ - iterator(uint64_t index, const MapObj* self) - : state_marker(self->state_marker), index(index), self(self) {} - -#else - iterator(uint64_t index, const MapObj* self) : index(index), self(self) {} -#endif // TVM_FFI_DEBUG_WITH_ABI_CHANGE - /*! \brief The position on the array */ - uint64_t index; - /*! \brief The container it points to */ - const MapObj* self; - - friend class DenseMapObj; - friend class SmallMapObj; - }; - /// \endcond - /*! - * \brief Create an empty container - * \return The object created - */ - static inline ObjectPtr Empty(); - - protected: -#if TVM_FFI_DEBUG_WITH_ABI_CHANGE - uint64_t state_marker; -#endif // TVM_FFI_DEBUG_WITH_ABI_CHANGE - /*! - * \brief Create the map using contents from the given iterators. - * \param first Begin of iterator - * \param last End of iterator - * \tparam IterType The type of iterator - * \return ObjectPtr to the map created - */ - template - static inline ObjectPtr CreateFromRange(IterType first, IterType last); - /*! - * \brief InsertMaybeReHash an entry into the given hash map - * \param kv The entry to be inserted - * \param map The pointer to the map, can be changed if re-hashing happens - */ - static inline void InsertMaybeReHash(KVType&& kv, ObjectPtr* map); - /*! - * \brief Create an empty container with elements copying from another SmallMapObj - * \param from The source container - * \return The object created - */ - static inline ObjectPtr CopyFrom(MapObj* from); - /*! - * \brief data pointer to the data region of the map. - * \note For immutable inplace small map we do not need data_, - * but we keep it here for future compact with mutable container. - */ - void* data_; - /*! \brief number of entries in the container */ - uint64_t size_; - /*! \brief number of slots */ - uint64_t slots_; - /*! - * \brief Small layout tag mask - * \note The most significant bit is used to indicate the small map layout. - */ - static constexpr uint64_t kSmallTagMask = static_cast(1) << 63; - /*! - * \brief Check if the map is a small map - * \return True if the map is a small map - */ - bool IsSmallMap() const { return (slots_ & kSmallTagMask) != 0ull; } - /*! - * \brief Optional data deleter when data is allocated separately - * and its deletion is not managed by MapObj::deleter_. - */ - void (*data_deleter_)(void*) = nullptr; - // Reference class - template - friend class Map; -}; - -/*! \brief A specialization of small-sized hash map */ -class SmallMapObj : public MapObj, - public details::InplaceArrayBase { - private: - static constexpr uint64_t kInitSize = 2; - static constexpr uint64_t kMaxSize = 4; - - public: - using MapObj::iterator; - using MapObj::KVType; - - // Return the number of usable slots for Small layout (mask off tag). - /*! - * \brief Return the number of usable slots for Small layout (mask off tag). - * \return The number of usable slots - */ - uint64_t NumSlots() const { return slots_ & ~kSmallTagMask; } - - ~SmallMapObj() { - KVType* begin = static_cast(data_); - for (uint64_t index = 0; index < size_; ++index) { - // call destructor to destroy the item in `begin + index` - // Explicit call Any::~Any() to destroy the Any object - // Favor this over ~KVType as MSVC may not support ~KVType (need the original name) - (begin + index)->first.Any::~Any(); - (begin + index)->second.Any::~Any(); - } - if (data_deleter_ != nullptr) { - data_deleter_(data_); - } - } - /*! - * \brief Count the number of times a key exists in the SmallMapObj - * \param key The indexing key - * \return The result, 0 or 1 - */ - size_t count(const key_type& key) const { return find(key).index < size_; } - /*! - * \brief Index value associated with a key, throw exception if the key does not exist - * \param key The indexing key - * \return The const reference to the value - */ - const mapped_type& at(const key_type& key) const { - iterator itr = find(key); - if (itr.index >= size_) { - TVM_FFI_THROW(KeyError) << "key is not in Map"; - } - return itr->second; - } - /*! - * \brief Index value associated with a key, throw exception if the key does not exist - * \param key The indexing key - * \return The mutable reference to the value - */ - mapped_type& at(const key_type& key) { - iterator itr = find(key); - if (itr.index >= size_) { - TVM_FFI_THROW(KeyError) << "key is not in Map"; - } - return itr->second; - } - /*! \return begin iterator */ - iterator begin() const { return iterator(0, this); } - /*! \return end iterator */ - iterator end() const { return iterator(size_, this); } - /*! - * \brief Index value associated with a key - * \param key The indexing key - * \return The iterator of the entry associated with the key, end iterator if not exists - */ - iterator find(const key_type& key) const { - KVType* ptr = static_cast(data_); - for (uint64_t i = 0; i < size_; ++i, ++ptr) { - if (AnyEqual()(ptr->first, key)) { - return iterator(i, this); - } - } - return iterator(size_, this); - } - /*! - * \brief Erase the entry associated with the iterator - * \param position The iterator - */ - void erase(const iterator& position) { Erase(position.index); } - - private: - /*! - * \brief Set the number of slots and attach tags bit. - * \param n The number of slots - */ - void SetSlotsAndSmallLayoutTag(uint64_t n) { slots_ = (n & ~kSmallTagMask) | kSmallTagMask; } - /*! - * \brief Remove a position in SmallMapObj - * \param index The position to be removed - */ - void Erase(const uint64_t index) { - if (index >= size_) { - return; - } - KVType* begin = static_cast(data_); - // call destructor to destroy the item in `begin + index` - // Explicit call Any::~Any() to destroy the Any object - // Favor this over ~KVType as MSVC may not support ~KVType (need the original name) - (begin + index)->first.Any::~Any(); - (begin + index)->second.Any::~Any(); - // IMPORTANT: We do direct raw memmove to bring later items to the current position - // to preserve the order of insertion. - // This works because direct memory copy preserves the Any's move semantics. - if (index + 1 < size_) { - std::memmove(reinterpret_cast(begin + index), - reinterpret_cast(begin + index + 1), - (size_ - index - 1) * sizeof(KVType)); - } - size_ -= 1; - } - /*! - * \brief Create an empty container - * \param n Number of empty slots - * \return The object created - */ - static ObjectPtr Empty(uint64_t n = kInitSize) { - using ::tvm::ffi::make_inplace_array_object; - ObjectPtr p = make_inplace_array_object(n); - p->data_ = p->AddressOf(0); - p->size_ = 0; - p->SetSlotsAndSmallLayoutTag(n); - return p; - } - /*! - * \brief Create an empty container initialized with a given range - * \param n Number of empty slots - * \param first begin of iterator - * \param last end of iterator - * \tparam IterType The type of iterator - * \return The object created - */ - template - static ObjectPtr CreateFromRange(uint64_t n, IterType first, IterType last) { - ObjectPtr p = Empty(n); - KVType* ptr = static_cast(p->data_); - for (; first != last; ++first, ++p->size_) { - new (ptr++) KVType(*first); - } - return p; - } - /*! - * \brief Create an empty container with elements copying from another SmallMapObj - * \param from The source container - * \return The object created - */ - static ObjectPtr CopyFrom(SmallMapObj* from) { - KVType* first = static_cast(from->data_); - KVType* last = first + from->size_; - return CreateFromRange(from->size_, first, last); - } - /*! - * \brief InsertMaybeReHash an entry into the given hash map - * \param kv The entry to be inserted - * \param map The pointer to the map, can be changed if re-hashing happens - */ - static void InsertMaybeReHash(KVType&& kv, ObjectPtr* map) { - SmallMapObj* map_node = static_cast(map->get()); - iterator itr = map_node->find(kv.first); - if (itr.index < map_node->size_) { - itr->second = kv.second; - return; - } - if (map_node->size_ < map_node->NumSlots()) { - KVType* ptr = static_cast(map_node->data_) + map_node->size_; - new (ptr) KVType(std::move(kv)); - ++map_node->size_; - return; - } - uint64_t next_size = std::max(map_node->NumSlots() * 2, uint64_t(kInitSize)); - next_size = std::min(next_size, uint64_t(kMaxSize)); - TVM_FFI_ICHECK_GT(next_size, map_node->NumSlots()); - ObjectPtr new_map = CreateFromRange(next_size, map_node->begin(), map_node->end()); - InsertMaybeReHash(std::move(kv), &new_map); - *map = std::move(new_map); - } - /*! - * \brief Increment the pointer - * \param index The pointer to be incremented - * \return The increased pointer - */ - uint64_t IncItr(uint64_t index) const { return index + 1 < size_ ? index + 1 : size_; } - /*! - * \brief Decrement the pointer - * \param index The pointer to be decremented - * \return The decreased pointer - */ - uint64_t DecItr(uint64_t index) const { return index > 0 ? index - 1 : size_; } - /*! - * \brief De-reference the pointer - * \param index The pointer to be dereferenced - * \return The result - */ - KVType* DeRefItr(uint64_t index) const { return static_cast(data_) + index; } - /*! \brief A size function used by InplaceArrayBase */ - uint64_t GetSize() const { return size_; } - - protected: - friend class MapObj; - friend class DenseMapObj; - friend class details::InplaceArrayBase; -}; - -/*! \brief A specialization of hash map that implements the idea of array-based hash map. - * Another reference implementation can be found [1]. - * - * A. Overview - * - * DenseMapObj did several improvements over traditional separate chaining hash, - * in terms of cache locality, memory footprints and data organization. - * - * A1. Implicit linked list. For better cache locality, instead of using linked list - * explicitly for each bucket, we store list data into a single array that spans contiguously - * in memory, and then carefully design access patterns to make sure most of them fall into - * a single cache line. - * - * A2. 1-byte metadata. There is only 1 byte overhead for each slot in the array to indexing and - * traversal. This can be divided in 3 parts. - * 1) Reserved code: (0b11111111)_2 indicates a slot is empty; (0b11111110)_2 indicates protected, - * which means the slot is empty but not allowed to be written. - * 2) If not empty or protected, the highest bit is used to indicate whether data in the slot is - * head of a linked list. - * 3) The rest 7 bits are used as the "next pointer" (i.e. pointer to the next element). On 64-bit - * architecture, an ordinary pointer can take up to 8 bytes, which is not acceptable overhead when - * dealing with 16-byte ObjectRef pairs. Based on a commonly noticed fact that the lists are - * relatively short (length <= 3) in hash maps, we follow [1]'s idea that only allows the pointer to - * be one of the 126 possible values, i.e. if the next element of i-th slot is (i + x)-th element, - * then x must be one of the 126 pre-defined values. - * - * A3. Data blocking. We organize the array in the way that every 16 elements forms a data block. - * The 16-byte metadata of those 16 elements are stored together, followed by the real data, i.e. - * 16 key-value pairs. - * - * B. Implementation details - * - * B1. Power-of-2 table size and Fibonacci Hashing. We use power-of-two as table size to avoid - * modulo for more efficient arithmetics. To make the hash-to-slot mapping distribute more evenly, - * we use the Fibonacci Hashing [2] trick. - * - * B2. Traverse a linked list in the array. - * 1) List head. Assume Fibonacci Hashing maps a given key to slot i, if metadata at slot i - * indicates that it is list head, then we found the head; otherwise the list is empty. No probing - * is done in this procedure. 2) Next element. To find the next element of a non-empty slot i, we - * look at the last 7 bits of the metadata at slot i. If they are all zeros, then it is the end of - * list; otherwise, we know that the next element is (i + candidates[the-last-7-bits]). - * - * B3. InsertMaybeReHash an element. Following B2, we first traverse the linked list to see if this - * element is in the linked list, and if not, we put it at the end by probing the next empty - * position in one of the 126 candidate positions. If the linked list does not even exist, but the - * slot for list head has been occupied by another linked list, we should find this intruder another - * place. - * - * B4. Quadratic probing with triangle numbers. In open address hashing, it is provable that probing - * with triangle numbers can traverse power-of-2-sized table [3]. In our algorithm, we follow the - * suggestion in [1] that also use triangle numbers for "next pointer" as well as sparing for list - * head. - * - * [1] https://github.com/skarupke/flat_hash_map - * [2] https://programmingpraxis.com/2018/06/19/fibonacci-hash/ - * [3] https://fgiesen.wordpress.com/2015/02/22/triangular-numbers-mod-2n/ - */ -class DenseMapObj : public MapObj { - private: - /*! \brief The number of elements in a memory block */ - static constexpr int kBlockCap = 16; - /*! \brief Maximum load factor of the hash map */ - static constexpr double kMaxLoadFactor = 0.99; - /*! \brief Binary representation of the metadata of an empty slot */ - static constexpr uint8_t kEmptySlot = uint8_t(0b11111111); - /*! \brief Binary representation of the metadata of a protected slot */ - static constexpr uint8_t kProtectedSlot = uint8_t(0b11111110); - /*! \brief Number of probing choices available */ - static constexpr int kNumJumpDists = 126; - /*! \brief Index indicator to indicate an invalid index */ - static constexpr uint64_t kInvalidIndex = std::numeric_limits::max(); - /*! \brief Head of the implicit linked list */ - struct ListNode; - /*! \brief item type of the dense map, including a kv data and prev/next pointer */ - struct ItemType { - KVType data; - uint64_t prev = kInvalidIndex; - uint64_t next = kInvalidIndex; - - explicit ItemType(KVType&& data) : data(std::move(data)) {} - explicit ItemType(key_type key, mapped_type value) : data(key, value) {} - }; - /*! \brief POD type of a block of memory */ - struct Block { - uint8_t bytes[kBlockCap + kBlockCap * sizeof(ItemType)]; - }; - static_assert(sizeof(Block) == kBlockCap * (sizeof(ItemType) + 1), "sizeof(Block) incorrect"); - static_assert(std::is_standard_layout::value, "Block is not standard layout"); - - /*! - * \brief Deleter for the Block - * \param data The pointer to the Block - */ - static void BlockDeleter(void* data) { delete[] static_cast(data); } - - public: - using MapObj::iterator; - - /*! - * \brief Return the number of usable slots for Dense layout (MSB clear => identity). - * \return The number of usable slots - */ - uint64_t NumSlots() const { return slots_; } - - /*! - * \brief Destroy the DenseMapObj - */ - ~DenseMapObj() { this->Reset(); } - /*! \return The number of elements of the key */ - size_t count(const key_type& key) const { return !Search(key).IsNone(); } - /*! - * \brief Index value associated with a key, throw exception if the key does not exist - * \param key The indexing key - * \return The const reference to the value - */ - const mapped_type& at(const key_type& key) const { return At(key); } - /*! - * \brief Index value associated with a key, throw exception if the key does not exist - * \param key The indexing key - * \return The mutable reference to the value - */ - mapped_type& at(const key_type& key) { return At(key); } - /*! - * \brief Index value associated with a key - * \param key The indexing key - * \return The iterator of the entry associated with the key, end iterator if not exists - */ - iterator find(const key_type& key) const { - ListNode node = Search(key); - return node.IsNone() ? end() : iterator(node.index, this); - } - /*! - * \brief Erase the entry associated with the iterator - * \param position The iterator - */ - void erase(const iterator& position) { - uint64_t index = position.index; - if (position.self != nullptr && index <= this->NumSlots()) { - Erase(ListNode(index, this)); - } - } - /*! \return begin iterator */ - iterator begin() const { return iterator(iter_list_head_, this); } - /*! \return end iterator */ - iterator end() const { return iterator(kInvalidIndex, this); } - - private: - Block* GetBlock(size_t index) const { return static_cast(data_) + index; } - /*! - * \brief Unlink the entry from iterator list - * \param node The node to be unlinked - * \note This function is usually used before deletion, - * and it does not change data content of the node. - */ - void IterListUnlink(ListNode node) { - // update head and tail of iterator list if needed - if (node.Item().prev == kInvalidIndex) { - iter_list_head_ = node.Item().next; - } else { - ListNode prev_node(node.Item().prev, this); - prev_node.Item().next = node.Item().next; - } - if (node.Item().next == kInvalidIndex) { - iter_list_tail_ = node.Item().prev; - } else { - ListNode next_node(node.Item().next, this); - next_node.Item().prev = node.Item().prev; - } - } - /*! - * \brief Insert the entry into tail of iterator list - * \param node The node to be inserted - * \note this function does not change data content of the node. - */ - void IterListPushBack(ListNode node) { - node.Item().prev = iter_list_tail_; - node.Item().next = kInvalidIndex; - if (iter_list_tail_ != kInvalidIndex) { - ListNode prev_node(iter_list_tail_, this); - prev_node.Item().next = node.index; - } - if (iter_list_head_ == kInvalidIndex) { - iter_list_head_ = node.index; - } - iter_list_tail_ = node.index; - } - /*! - * \brief Replace node src by dst in the iter list - * \param src The source node - * \param dst The destination node, must be empty - * \note This function does not change data content of the nodes, - * which needs to be updated by the caller. - */ - void IterListReplaceNodeBy(ListNode src, ListNode dst) { - // set link correctly on the dst - dst.Item().prev = src.Item().prev; - dst.Item().next = src.Item().next; - // update prev and next of dst - if (dst.Item().prev == kInvalidIndex) { - iter_list_head_ = dst.index; - } else { - ListNode prev_node(dst.Item().prev, this); - prev_node.Item().next = dst.index; - } - if (dst.Item().next == kInvalidIndex) { - iter_list_tail_ = dst.index; - } else { - ListNode next_node(dst.Item().next, this); - next_node.Item().prev = dst.index; - } - } - /*! - * \brief Search for the given key - * \param key The key - * \return ListNode that associated with the key - */ - ListNode Search(const key_type& key) const { - if (this->size_ == 0) { - return ListNode(); - } - for (ListNode iter = GetListHead(AnyHash()(key)); !iter.IsNone(); iter.MoveToNext(this)) { - if (AnyEqual()(key, iter.Key())) { - return iter; - } - } - return ListNode(); - } - /*! - * \brief Search for the given key, throw exception if not exists - * \param key The key - * \return ListNode that associated with the key - */ - mapped_type& At(const key_type& key) const { - ListNode iter = Search(key); - if (iter.IsNone()) { - TVM_FFI_THROW(IndexError) << "key is not in Map"; - } - return iter.Val(); - } - /*! - * \brief Try to insert a key, or do nothing if already exists - * \param key The indexing key - * \param result The linked-list entry found or just constructed - * \return A boolean, indicating if actual insertion happens - */ - bool TryInsert(const key_type& key, ListNode* result) { - if (slots_ == 0) { - return false; - } - // required that `iter` to be the head of a linked list through which we can iterator - ListNode iter = IndexFromHash(AnyHash()(key)); - // `iter` can be: 1) empty; 2) body of an irrelevant list; 3) head of the relevant list - // Case 1: empty - if (iter.IsEmpty()) { - iter.NewHead(ItemType(key, Any(nullptr))); - this->size_ += 1; - *result = iter; - return true; - } - // Case 2: body of an irrelevant list - if (!iter.IsHead()) { - // we move the elements around and construct the single-element linked list - return IsFull() ? false : TrySpareListHead(iter, key, result); - } - // Case 3: head of the relevant list - // we iterate through the linked list until the end - // make sure `iter` is the previous element of `next` - ListNode next = iter; - do { - // find equal item, do not insert - if (AnyEqual()(key, next.Key())) { - // we plan to take next, so we need to unlink it from iterator list - IterListUnlink(next); - *result = next; - return true; - } - // make sure `iter` is the previous element of `next` - iter = next; - } while (next.MoveToNext(this)); - // `iter` is the tail of the linked list - // always check capacity before insertion - if (IsFull()) { - return false; - } - // find the next empty slot - uint8_t jump; - if (!iter.GetNextEmpty(this, &jump, result)) { - return false; - } - result->NewTail(ItemType(key, Any(nullptr))); - // link `iter` to `empty`, and move forward - iter.SetJump(jump); - this->size_ += 1; - return true; - } - /*! - * \brief Spare an entry to be the head of a linked list. - * As described in B3, during insertion, it is possible that the entire linked list does not - * exist, but the slot of its head has been occupied by other linked lists. In this case, we need - * to spare the slot by moving away the elements to another valid empty one to make insertion - * possible. - * \param target The given entry to be spared - * \param key The indexing key - * \param result The linked-list entry constructed as the head - * \return A boolean, if actual insertion happens - */ - bool TrySpareListHead(ListNode target, const key_type& key, ListNode* result) { - // `target` is not the head of the linked list - // move the original item of `target` (if any) - // and construct new item on the position `target` - // To make `target` empty, we - // 1) find `w` the previous element of `target` in the linked list - // 2) copy the linked list starting from `r = target` - // 3) paste them after `w` - // read from the linked list after `r` - ListNode r = target; - // write to the tail of `w` - ListNode w = target.FindPrev(this); - // after `target` is moved, we disallow writing to the slot - bool is_first = true; - uint8_t r_meta, jump; - ListNode empty; - do { - // `jump` describes how `w` is jumped to `empty` - // rehash if there is no empty space after `w` - if (!w.GetNextEmpty(this, &jump, &empty)) { - return false; - } - // move `r` to `empty` - // first move the data over - empty.NewTail(ItemType(std::move(r.Data()))); - // then move link list chain of r to empty - // this needs to happen after NewTail so empty's prev/next get updated - IterListReplaceNodeBy(r, empty); - // explicit call destructor to destroy the item in `r` - r.DestructData(); - // clear the metadata of `r` - r_meta = r.Meta(); - if (is_first) { - is_first = false; - r.SetProtected(); - } else { - r.SetEmpty(); - } - // link `w` to `empty`, and move forward - w.SetJump(jump); - w = empty; - // move `r` forward as well - } while (r.MoveToNext(this, r_meta)); - // finally we have done moving the linked list - // fill data_ into `target` - target.NewHead(ItemType(key, Any(nullptr))); - this->size_ += 1; - *result = target; - return true; - } - /*! - * \brief Remove a ListNode - * \param iter The node to be removed - */ - void Erase(const ListNode& iter) { - this->size_ -= 1; - if (!iter.HasNext()) { - // `iter` is the last - if (!iter.IsHead()) { - // cut the link if there is any - iter.FindPrev(this).SetJump(0); - } - // unlink the node from iterator list - IterListUnlink(iter); - // IMPORTANT: must explicit call destructor `iter` to avoid memory leak - // This is because we need to recycle iter's data - iter.DestructData(); - // set the meta data to be empty - iter.SetEmpty(); - } else { - ListNode last = iter, prev = iter; - for (last.MoveToNext(this); last.HasNext(); prev = last, last.MoveToNext(this)) { - } - // needs to first unlink iter from the list - IterListUnlink(iter); - // move data from last to iter - iter.Data() = std::move(last.Data()); - // Move link chain of iter to last as we stores last node to the new iter loc. - IterListReplaceNodeBy(last, iter); - // IMPORTANT: must explicit call destructor `last` to avoid memory leak - // likely we don't need this in this particular case because Any move behavior - // keep it here to be safe so code do not depend on specific move behavior of KVType - last.DestructData(); - // set the meta data to be empty - last.SetEmpty(); - prev.SetJump(0); - } - } - /*! \brief Clear the container to empty, release all entries and memory acquired */ - void Reset() { - uint64_t n_blocks = CalcNumBlocks(this->NumSlots()); - for (uint64_t bi = 0; bi < n_blocks; ++bi) { - uint8_t* meta_ptr = GetBlock(bi)->bytes; - ItemType* data_ptr = reinterpret_cast(GetBlock(bi)->bytes + kBlockCap); - for (int j = 0; j < kBlockCap; ++j, ++meta_ptr, ++data_ptr) { - uint8_t& meta = *meta_ptr; - if (meta != uint8_t(kProtectedSlot) && meta != uint8_t(kEmptySlot)) { - meta = uint8_t(kEmptySlot); - data_ptr->ItemType::~ItemType(); - } - } - } - ReleaseMemory(); - } - /*! \brief Release the memory acquired by the container without deleting its entries stored inside - */ - void ReleaseMemory() { - if (data_ != nullptr) { - TVM_FFI_ICHECK(data_deleter_ != nullptr); - data_deleter_(data_); - } - data_ = nullptr; - data_deleter_ = nullptr; - slots_ = 0; - size_ = 0; - fib_shift_ = 63; - } - /*! - * \brief Create an empty container - * \param fib_shift The fib shift provided - * \param n_slots Number of slots required, should be power-of-two - * \return The object created - */ - static ObjectPtr Empty(uint32_t fib_shift, uint64_t n_slots) { - TVM_FFI_ICHECK_GT(n_slots, uint64_t(SmallMapObj::kMaxSize)); - // Ensure even slot count (power-of-two expected by callers; this guard - // makes the method robust if a non-even value slips through). - ObjectPtr p = make_object(); - uint64_t n_blocks = CalcNumBlocks(n_slots); - Block* block = new Block[n_blocks]; - p->data_ = block; - // assign block deleter so even if we take re-alloc data - // in another shared-lib that may have different malloc/free behavior - // it will still be safe. - p->data_deleter_ = BlockDeleter; - p->SetSlotsAndDenseLayoutTag(n_slots); - p->size_ = 0; - p->fib_shift_ = fib_shift; - p->iter_list_head_ = kInvalidIndex; - p->iter_list_tail_ = kInvalidIndex; - for (uint64_t i = 0; i < n_blocks; ++i, ++block) { - std::fill(block->bytes, block->bytes + kBlockCap, uint8_t(kEmptySlot)); - } - return p; - } - /*! - * \brief Create an empty container with elements copying from another DenseMapObj - * \param from The source container - * \return The object created - */ - static ObjectPtr CopyFrom(DenseMapObj* from) { - ObjectPtr p = make_object(); - uint64_t n_blocks = CalcNumBlocks(from->NumSlots()); - p->data_ = new Block[n_blocks]; - // assign block deleter so even if we take re-alloc data - // in another shared-lib that may have different malloc/free behavior - // it will still be safe. - p->data_deleter_ = BlockDeleter; - p->SetSlotsAndDenseLayoutTag(from->NumSlots()); - p->size_ = from->size_; - p->fib_shift_ = from->fib_shift_; - p->iter_list_head_ = from->iter_list_head_; - p->iter_list_tail_ = from->iter_list_tail_; - for (uint64_t bi = 0; bi < n_blocks; ++bi) { - uint8_t* meta_ptr_from = from->GetBlock(bi)->bytes; - ItemType* data_ptr_from = reinterpret_cast(from->GetBlock(bi)->bytes + kBlockCap); - uint8_t* meta_ptr_to = p->GetBlock(bi)->bytes; - ItemType* data_ptr_to = reinterpret_cast(p->GetBlock(bi)->bytes + kBlockCap); - for (int j = 0; j < kBlockCap; - ++j, ++meta_ptr_from, ++data_ptr_from, ++meta_ptr_to, ++data_ptr_to) { - uint8_t& meta = *meta_ptr_to = *meta_ptr_from; - TVM_FFI_ICHECK(meta != kProtectedSlot); - if (meta != uint8_t(kEmptySlot)) { - new (data_ptr_to) ItemType(*data_ptr_from); - } - } - } - return p; - } - /*! - * \brief InsertMaybeReHash an entry into the given hash map - * \param kv The entry to be inserted - * \param map The pointer to the map, can be changed if re-hashing happens - */ - static void InsertMaybeReHash(KVType&& kv, ObjectPtr* map) { - DenseMapObj* map_node = static_cast(map->get()); - ListNode iter; - // Try to insert. If succeed, we simply return - if (map_node->TryInsert(kv.first, &iter)) { - iter.Val() = std::move(kv.second); - // update the iter list relation - map_node->IterListPushBack(iter); - return; - } - TVM_FFI_ICHECK(!map_node->IsSmallMap()); - // Otherwise, start rehash - ObjectPtr p = Empty(map_node->fib_shift_ - 1, map_node->NumSlots() * 2); - - // need to insert in the same order as the original map - for (uint64_t index = map_node->iter_list_head_; index != kInvalidIndex;) { - ListNode node(index, map_node); - // now try move src_data into the new map, note that src may still not - // be fully consumed into the call, but destructor will be called. - InsertMaybeReHash(std::move(node.Data()), &p); - // Important, needs to explicit call destructor in case move did remove - // node's internal item - index = node.Item().next; - // IMPORTANT: must explicit call destructor `node` to avoid memory leak - // We must call node.DestructData() here. - // This is because std::move() arguments in IterMaybeReHash may or may not - // explicitly move out the node.Data() - // Remove this call will cause memory leak very likely. - node.DestructData(); - } - InsertMaybeReHash(std::move(kv), &p); - map_node->ReleaseMemory(); - *map = p; - } - /*! - * \brief Check whether the hash table is full - * \return A boolean indicating whether hash table is full - */ - bool IsFull() const { return size_ + 1 > NumSlots() * kMaxLoadFactor; } - /*! - * \brief Increment the pointer - * \param index The pointer to be incremented - * \return The increased pointer - */ - uint64_t IncItr(uint64_t index) const { - // keep at the end of iterator - if (index == kInvalidIndex) { - return index; - } - ListNode node(index, this); - return node.Item().next; - } - /*! - * \brief Decrement the pointer - * \param index The pointer to be decremented - * \return The decreased pointer - */ - uint64_t DecItr(uint64_t index) const { - // this is the end iterator, we need to return tail. - if (index == kInvalidIndex) { - return iter_list_tail_; - } - // circle around the iterator list, which is OK - ListNode node(index, this); - return node.Item().prev; - } - /*! - * \brief De-reference the pointer - * \param index The pointer to be dereferenced - * \return The result - */ - KVType* DeRefItr(uint64_t index) const { return &ListNode(index, this).Data(); } - /*! \brief Construct from hash code */ - ListNode IndexFromHash(uint64_t hash_value) const { - return ListNode(FibHash(hash_value, fib_shift_), this); - } - /*! \brief Construct from hash code if the position is head of list */ - ListNode GetListHead(uint64_t hash_value) const { - ListNode node = IndexFromHash(hash_value); - return node.IsHead() ? node : ListNode(); - } - /*! \brief Construct the number of blocks in the hash table */ - static uint64_t CalcNumBlocks(uint64_t n_slots) { return (n_slots + kBlockCap - 1) / kBlockCap; } - /*! - * \brief Calculate the power-of-2 table size given the lower-bound of required capacity. - * \param cap The lower-bound of the required capacity - * \param fib_shift The result shift for Fibonacci Hashing - * \param n_slots The result number of slots - */ - static void CalcTableSize(uint64_t cap, uint32_t* fib_shift, uint64_t* n_slots) { - uint32_t shift = 64; - uint64_t slots = 1; - for (uint64_t c = cap; c; c >>= 1) { - shift -= 1; - slots <<= 1; - } - TVM_FFI_ICHECK_GT(slots, cap); - if (slots < cap * 2) { - *fib_shift = shift - 1; - *n_slots = slots << 1; - } else { - *fib_shift = shift; - *n_slots = slots; - } - } - /*! - * \brief Fibonacci Hashing, maps a hash code to an index in a power-of-2-sized table. - * See also: https://programmingpraxis.com/2018/06/19/fibonacci-hash/. - * \param hash_value The raw hash value - * \param fib_shift The shift in Fibonacci Hashing - * \return An index calculated using Fibonacci Hashing - */ - static uint64_t FibHash(uint64_t hash_value, uint32_t fib_shift) { - constexpr uint64_t coeff = 11400714819323198485ull; - return (coeff * hash_value) >> fib_shift; - } - /*! \brief The implicit in-place linked list used to index a chain */ - struct ListNode { - /*! \brief Construct None */ - ListNode() : index(0), block(nullptr) {} - /*! \brief Construct from position */ - ListNode(uint64_t index, const DenseMapObj* self) - : index(index), block(self->GetBlock(index / kBlockCap)) {} - /*! \brief Metadata on the entry */ - uint8_t& Meta() const { return *(block->bytes + index % kBlockCap); } - /*! \brief Data on the entry */ - ItemType& Item() const { - return *(reinterpret_cast(block->bytes + kBlockCap + - (index % kBlockCap) * sizeof(ItemType))); - } - /*! \brief Data on the entry */ - KVType& Data() const { return Item().data; } - /*! \brief Key on the entry */ - key_type& Key() const { return Data().first; } - /*! \brief Value on the entry */ - mapped_type& Val() const { return Data().second; } - /*! \brief If the entry is head of linked list */ - bool IsHead() const { return (Meta() & 0b10000000) == 0b00000000; } - /*! \brief If the entry is none */ - bool IsNone() const { return block == nullptr; } - /*! \brief If the entry is empty slot */ - bool IsEmpty() const { return Meta() == uint8_t(kEmptySlot); } - /*! \brief If the entry is protected slot */ - bool IsProtected() const { return Meta() == uint8_t(kProtectedSlot); } - /*! \brief Set the entry to be empty */ - void SetEmpty() const { Meta() = uint8_t(kEmptySlot); } - /*! \brief Destruct the item in the entry */ - void DestructData() const { - // explicit call destructor to destroy the item - // Favor this over ~KVType as MSVC may not support ~KVType (need the original name) - (&Data())->first.Any::~Any(); - (&Data())->second.Any::~Any(); - } - /*! \brief Set the entry to be protected */ - void SetProtected() const { Meta() = uint8_t(kProtectedSlot); } - /*! \brief Set the entry's jump to its next entry */ - void SetJump(uint8_t jump) const { (Meta() &= 0b10000000) |= jump; } - /*! \brief Construct a head of linked list in-place */ - void NewHead(ItemType v) const { - Meta() = 0b00000000; - new (&Item()) ItemType(std::move(v)); - } - /*! \brief Construct a tail of linked list in-place */ - void NewTail(ItemType v) const { - Meta() = 0b10000000; - new (&Item()) ItemType(std::move(v)); - } - - /*! \brief If the entry has next entry on the linked list */ - bool HasNext() const { return NextProbeLocation(Meta() & 0b01111111) != 0; } - /*! \brief Move the entry to the next entry on the linked list */ - bool MoveToNext(const DenseMapObj* self, uint8_t meta) { - uint64_t offset = NextProbeLocation(meta & 0b01111111); - if (offset == 0) { - index = 0; - block = nullptr; - return false; - } - // the probing will go to next position and round back to stay within the - // correct range of the slots - index = (index + offset) % self->NumSlots(); - block = self->GetBlock(index / kBlockCap); - return true; - } - /*! \brief Move the entry to the next entry on the linked list */ - bool MoveToNext(const DenseMapObj* self) { return MoveToNext(self, Meta()); } - /*! \brief Get the previous entry on the linked list */ - ListNode FindPrev(const DenseMapObj* self) const { - // start from the head of the linked list, which must exist - ListNode next = self->IndexFromHash(AnyHash()(Key())); - // `prev` is always the previous item of `next` - ListNode prev = next; - for (next.MoveToNext(self); index != next.index; prev = next, next.MoveToNext(self)) { - } - return prev; - } - /*! \brief Get the next empty jump */ - bool GetNextEmpty(const DenseMapObj* self, uint8_t* jump, ListNode* result) const { - for (uint8_t idx = 1; idx < kNumJumpDists; ++idx) { - // the probing will go to next position and round back to stay within the - // correct range of the slots - ListNode candidate((index + NextProbeLocation(idx)) % self->NumSlots(), self); - if (candidate.IsEmpty()) { - *jump = idx; - *result = candidate; - return true; - } - } - return false; - } - /*! \brief Index on the real array */ - uint64_t index; - /*! \brief Pointer to the actual block */ - Block* block; - }; - - protected: - /*! \brief fib shift in Fibonacci Hashing */ - uint32_t fib_shift_; - /*! \brief the head of iterator list */ - uint64_t iter_list_head_ = kInvalidIndex; - /*! \brief the tail of iterator list */ - uint64_t iter_list_tail_ = kInvalidIndex; - - static uint64_t NextProbeLocation(size_t index) { - /* clang-format off */ - /*! \brief Candidates of probing distance */ - static const uint64_t kNextProbeLocation[kNumJumpDists] { - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, - // Quadratic probing with triangle numbers. See also: - // 1) https://en.wikipedia.org/wiki/Quadratic_probing - // 2) https://fgiesen.wordpress.com/2015/02/22/triangular-numbers-mod-2n/ - // 3) https://github.com/skarupke/flat_hash_map - 21, 28, 36, 45, 55, 66, 78, 91, 105, 120, - 136, 153, 171, 190, 210, 231, 253, 276, 300, 325, - 351, 378, 406, 435, 465, 496, 528, 561, 595, 630, - 666, 703, 741, 780, 820, 861, 903, 946, 990, 1035, - 1081, 1128, 1176, 1225, 1275, 1326, 1378, 1431, 1485, 1540, - 1596, 1653, 1711, 1770, 1830, 1891, 1953, 2016, 2080, 2145, - 2211, 2278, 2346, 2415, 2485, 2556, 2628, - // larger triangle numbers - 8515, 19110, 42778, 96141, 216153, - 486591, 1092981, 2458653, 5532801, 12442566, - 27993903, 62983476, 141717030, 318844378, 717352503, - 1614057336, 3631522476, 8170957530, 18384510628, 41364789378, - 93070452520, 209408356380, 471168559170, 1060128894105, 2385289465695, - 5366898840628, 12075518705635, 27169915244790, 61132312065111, 137547689707000, - 309482283181501, 696335127828753, 1566753995631385, 3525196511162271, 7931691992677701, - 17846306936293605, 40154190677507445, 90346928918121501, 203280589587557251, - 457381325854679626, 1029107982097042876, 2315492959180353330, 5209859154120846435, - }; - /* clang-format on */ - return kNextProbeLocation[index]; - } - friend class MapObj; - - private: - /*! - * \brief Set the number of slots and attach tags bit. - * \param n The number of slots - */ - void SetSlotsAndDenseLayoutTag(uint64_t n) { - TVM_FFI_ICHECK(((n & kSmallTagMask) == 0ull)) << "DenseMap expects MSB clear"; - slots_ = n; - } -}; - -/// \cond -#define TVM_FFI_DISPATCH_MAP(base, var, body) \ - { \ - using TSmall = SmallMapObj*; \ - using TDense = DenseMapObj*; \ - if (base->IsSmallMap()) { \ - TSmall var = static_cast(base); \ - body; \ - } else { \ - TDense var = static_cast(base); \ - body; \ - } \ - } - -#define TVM_FFI_DISPATCH_MAP_CONST(base, var, body) \ - { \ - using TSmall = const SmallMapObj*; \ - using TDense = const DenseMapObj*; \ - if (base->IsSmallMap()) { \ - TSmall var = static_cast(base); \ - body; \ - } else { \ - TDense var = static_cast(base); \ - body; \ - } \ - } - -inline MapObj::iterator::pointer MapObj::iterator::operator->() const { - TVM_FFI_MAP_FAIL_IF_CHANGED() - TVM_FFI_DISPATCH_MAP_CONST(self, p, { return p->DeRefItr(index); }); -} - -inline MapObj::iterator& MapObj::iterator::operator++() { - TVM_FFI_MAP_FAIL_IF_CHANGED() - TVM_FFI_DISPATCH_MAP_CONST(self, p, { - index = p->IncItr(index); - return *this; - }); -} - -inline MapObj::iterator& MapObj::iterator::operator--() { - TVM_FFI_MAP_FAIL_IF_CHANGED() - TVM_FFI_DISPATCH_MAP_CONST(self, p, { - index = p->DecItr(index); - return *this; - }); -} - -inline size_t MapObj::count(const key_type& key) const { - TVM_FFI_DISPATCH_MAP_CONST(this, p, { return p->count(key); }); -} - -inline const MapObj::mapped_type& MapObj::at(const MapObj::key_type& key) const { - TVM_FFI_DISPATCH_MAP_CONST(this, p, { return p->at(key); }); -} - -inline MapObj::mapped_type& MapObj::at(const MapObj::key_type& key) { - TVM_FFI_DISPATCH_MAP(this, p, { return p->at(key); }); -} - -inline MapObj::iterator MapObj::begin() const { - TVM_FFI_DISPATCH_MAP_CONST(this, p, { return p->begin(); }); -} - -inline MapObj::iterator MapObj::end() const { - TVM_FFI_DISPATCH_MAP_CONST(this, p, { return p->end(); }); -} - -inline MapObj::iterator MapObj::find(const MapObj::key_type& key) const { - TVM_FFI_DISPATCH_MAP_CONST(this, p, { return p->find(key); }); -} - -inline void MapObj::erase(const MapObj::iterator& position) { - TVM_FFI_DISPATCH_MAP(this, p, { return p->erase(position); }); -} -/// \endcond - -#undef TVM_FFI_DISPATCH_MAP -#undef TVM_FFI_DISPATCH_MAP_CONST - -inline ObjectPtr MapObj::Empty() { return SmallMapObj::Empty(); } - -inline ObjectPtr MapObj::CopyFrom(MapObj* from) { - if (from->IsSmallMap()) { - return SmallMapObj::CopyFrom(static_cast(from)); - } else { - return DenseMapObj::CopyFrom(static_cast(from)); - } -} - -template -inline ObjectPtr MapObj::CreateFromRange(IterType first, IterType last) { - int64_t _cap = std::distance(first, last); - if (_cap < 0) { - return SmallMapObj::Empty(); - } - uint64_t cap = static_cast(_cap); - if (cap < SmallMapObj::kMaxSize) { - if (cap < 2) { - return SmallMapObj::CreateFromRange(cap, first, last); - } - // need to insert to avoid duplicate keys - ObjectPtr obj = SmallMapObj::Empty(cap); - for (; first != last; ++first) { - KVType kv(*first); - SmallMapObj::InsertMaybeReHash(std::move(kv), &obj); - } - return obj; - } else { - uint32_t fib_shift; - uint64_t n_slots; - DenseMapObj::CalcTableSize(cap, &fib_shift, &n_slots); - ObjectPtr obj = DenseMapObj::Empty(fib_shift, n_slots); - for (; first != last; ++first) { - KVType kv(*first); - DenseMapObj::InsertMaybeReHash(std::move(kv), &obj); - } - return obj; - } -} - -inline void MapObj::InsertMaybeReHash(KVType&& kv, ObjectPtr* map) { - MapObj* base = static_cast(map->get()); -#if TVM_FFI_DEBUG_WITH_ABI_CHANGE - base->state_marker++; -#endif // TVM_FFI_DEBUG_WITH_ABI_CHANGE - if (base->IsSmallMap()) { - SmallMapObj* sm = static_cast(base); - if (sm->NumSlots() < SmallMapObj::kMaxSize) { - SmallMapObj::InsertMaybeReHash(std::move(kv), map); - } else if (sm->NumSlots() == SmallMapObj::kMaxSize) { - if (base->size_ < sm->NumSlots()) { - SmallMapObj::InsertMaybeReHash(std::move(kv), map); - } else { - ObjectPtr new_map = MapObj::CreateFromRange(base->begin(), base->end()); - DenseMapObj::InsertMaybeReHash(std::move(kv), &new_map); - *map = std::move(new_map); - } - } - } else { - DenseMapObj::InsertMaybeReHash(std::move(kv), map); - } -} - -template <> -inline ObjectPtr make_object<>() = delete; - -/*! - * \brief Map container of NodeRef->NodeRef in DSL graph. - * Map implements copy on write semantics, which means map is mutable - * but copy will happen when array is referenced in more than two places. - * - * operator[] only provide const acces, use Set to mutate the content. - * \tparam K The key NodeRef type. - * \tparam V The value NodeRef type. - */ -template && - details::storage_enabled_v>> -class Map : public ObjectRef { - public: - /*! \brief The key type of the map */ - using key_type = K; - /*! \brief The mapped type of the map */ - using mapped_type = V; - /*! \brief The iterator type of the map */ - class iterator; - /*! - * \brief Construct an Map with UnsafeInit - */ - explicit Map(UnsafeInit tag) : ObjectRef(tag) {} - /*! - * \brief default constructor - */ - Map() { data_ = MapObj::Empty(); } - /*! - * \brief move constructor - * \param other source - */ - Map(Map&& other) : ObjectRef(std::move(other.data_)) {} - /*! - * \brief copy constructor - * \param other source - */ - Map(const Map& other) : ObjectRef(other.data_) {} - - /*! - * \brief Move constructor - * \param other The other map - * \tparam KU The key type of the other map - * \tparam VU The mapped type of the other map - */ - template && - details::type_contains_v>> - Map(Map&& other) : ObjectRef(std::move(other.data_)) {} - - /*! - * \brief Copy constructor - * \param other The other map - * \tparam KU The key type of the other map - * \tparam VU The mapped type of the other map - */ - template && - details::type_contains_v>> - Map(const Map& other) : ObjectRef(other.data_) {} - - /*! - * \brief Move assignment - * \param other The other map - */ - Map& operator=(Map&& other) { - data_ = std::move(other.data_); - return *this; - } - - /*! - * \brief Copy assignment - * \param other The other map - */ - Map& operator=(const Map& other) { - data_ = other.data_; - return *this; - } - - /*! - * \brief Move assignment - * \param other The other map - * \tparam KU The key type of the other map - * \tparam VU The mapped type of the other map - */ - template && - details::type_contains_v>> - Map& operator=(Map&& other) { - data_ = std::move(other.data_); - return *this; - } - - /*! - * \brief Copy assignment - * \param other The other map - * \tparam KU The key type of the other map - * \tparam VU The mapped type of the other map - */ - template && - details::type_contains_v>> - Map& operator=(const Map& other) { - data_ = other.data_; - return *this; - } - /*! - * \brief constructor from pointer - * \param n the container pointer - */ - explicit Map(ObjectPtr n) : ObjectRef(n) {} - /*! - * \brief constructor from iterator - * \param begin begin of iterator - * \param end end of iterator - * \tparam IterType The type of iterator - */ - template - Map(IterType begin, IterType end) { - data_ = MapObj::CreateFromRange(begin, end); - } - /*! - * \brief constructor from initializer list - * \param init The initalizer list - */ - Map(std::initializer_list> init) { - data_ = MapObj::CreateFromRange(init.begin(), init.end()); - } - /*! - * \brief constructor from unordered_map - * \param init The unordered_map - */ - template - Map(const std::unordered_map& init) { // NOLINT(*) - data_ = MapObj::CreateFromRange(init.begin(), init.end()); - } - /*! - * \brief Read element from map. - * \param key The key - * \return the corresonding element. - */ - const V at(const K& key) const { - return details::AnyUnsafe::CopyFromAnyViewAfterCheck(GetMapObj()->at(key)); - } - /*! - * \brief Read element from map. - * \param key The key - * \return the corresonding element. - */ - const V operator[](const K& key) const { return this->at(key); } - /*! \return The size of the array */ - size_t size() const { - MapObj* n = GetMapObj(); - return n == nullptr ? 0 : n->size(); - } - /*! \return The number of elements of the key */ - size_t count(const K& key) const { - MapObj* n = GetMapObj(); - return n == nullptr ? 0 : GetMapObj()->count(key); - } - /*! \return whether array is empty */ - bool empty() const { return size() == 0; } - /*! \brief Release reference to all the elements */ - void clear() { - MapObj* n = GetMapObj(); - if (n != nullptr) { - data_ = MapObj::Empty(); - } - } - /*! - * \brief set the Map. - * \param key The index key. - * \param value The value to be setted. - */ - void Set(const K& key, const V& value) { - CopyOnWrite(); - MapObj::InsertMaybeReHash(MapObj::KVType(key, value), &data_); - } - /*! \return begin iterator */ - iterator begin() const { return iterator(GetMapObj()->begin()); } - /*! \return end iterator */ - iterator end() const { return iterator(GetMapObj()->end()); } - /*! \return find the key and returns the associated iterator */ - iterator find(const K& key) const { return iterator(GetMapObj()->find(key)); } - /*! \return The value associated with the key, std::nullopt if not found */ - std::optional Get(const K& key) const { - MapObj::iterator iter = GetMapObj()->find(key); - if (iter == GetMapObj()->end()) { - return std::nullopt; - } - return details::AnyUnsafe::CopyFromAnyViewAfterCheck(iter->second); - } - - /*! - * \brief Erase the entry associated with the key - * \param key The key - */ - void erase(const K& key) { CopyOnWrite()->erase(key); } - - /*! - * \brief copy on write semantics - * Do nothing if current handle is the unique copy of the array. - * Otherwise make a new copy of the array to ensure the current handle - * hold a unique copy. - * - * \return Handle to the internal node container(which guarantees to be unique) - */ - MapObj* CopyOnWrite() { - if (data_.get() == nullptr) { - data_ = MapObj::Empty(); - } else if (!data_.unique()) { - data_ = MapObj::CopyFrom(GetMapObj()); - } - return GetMapObj(); - } - /*! \brief specify container node */ - using ContainerType = MapObj; - - /// \cond Doxygen_Suppress - /*! \brief Iterator of the hash map */ - class iterator { - public: - using iterator_category = std::bidirectional_iterator_tag; - using difference_type = int64_t; - using value_type = const std::pair; - using pointer = value_type*; - using reference = value_type; - - iterator() : itr() {} - - /*! \brief Compare iterators */ - bool operator==(const iterator& other) const { return itr == other.itr; } - /*! \brief Compare iterators */ - bool operator!=(const iterator& other) const { return itr != other.itr; } - /*! \brief De-reference iterators is not allowed */ - pointer operator->() const = delete; - /*! \brief De-reference iterators */ - reference operator*() const { - auto& kv = *itr; - return std::make_pair(details::AnyUnsafe::CopyFromAnyViewAfterCheck(kv.first), - details::AnyUnsafe::CopyFromAnyViewAfterCheck(kv.second)); - } - /*! \brief Prefix self increment, e.g. ++iter */ - iterator& operator++() { - ++itr; - return *this; - } - /*! \brief Suffix self increment */ - iterator operator++(int) { - iterator copy = *this; - ++(*this); - return copy; - } - - /*! \brief Prefix self decrement, e.g. --iter */ - iterator& operator--() { - --itr; - return *this; - } - /*! \brief Suffix self decrement */ - iterator operator--(int) { - iterator copy = *this; - --(*this); - return copy; - } - - private: - iterator(const MapObj::iterator& itr) // NOLINT(*) - : itr(itr) {} - - template - friend class Map; - - MapObj::iterator itr; - }; - /// \endcond - - private: - /*! \brief Return data_ as type of pointer of MapObj */ - MapObj* GetMapObj() const { return static_cast(data_.get()); } - - template - friend class Map; -}; - -/*! - * \brief Merge two Maps. - * \param lhs the first Map to merge. - * \param rhs the second Map to merge. - * @return The merged Array. Original Maps are kept unchanged. - */ -template && - details::storage_enabled_v>> -inline Map Merge(Map lhs, const Map& rhs) { - for (const auto& p : rhs) { - lhs.Set(p.first, p.second); - } - return std::move(lhs); -} - -// Traits for Map -template -inline constexpr bool use_default_type_traits_v> = false; - -template -struct TypeTraits> : public ObjectRefTypeTraitsBase> { - static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIMap; - using ObjectRefTypeTraitsBase>::CopyFromAnyViewAfterCheck; - - TVM_FFI_INLINE static std::string GetMismatchTypeInfo(const TVMFFIAny* src) { - if (src->type_index != TypeIndex::kTVMFFIMap) { - return TypeTraitsBase::GetMismatchTypeInfo(src); - } - if constexpr (!std::is_same_v || !std::is_same_v) { - const MapObj* n = reinterpret_cast(src->v_obj); - for (const auto& kv : *n) { - if constexpr (!std::is_same_v) { - if (!details::AnyUnsafe::CheckAnyStrict(kv.first) && - !kv.first.try_cast().has_value()) { - return "Map[some key is " + details::AnyUnsafe::GetMismatchTypeInfo(kv.first) + - ", V]"; - } - } - if constexpr (!std::is_same_v) { - if (!details::AnyUnsafe::CheckAnyStrict(kv.second) && - !kv.second.try_cast().has_value()) { - return "Map[K, some value is " + details::AnyUnsafe::GetMismatchTypeInfo(kv.second) + - "]"; - } - } - } - } - TVM_FFI_THROW(InternalError) << "Cannot reach here"; - TVM_FFI_UNREACHABLE(); - } - - TVM_FFI_INLINE static bool CheckAnyStrict(const TVMFFIAny* src) { - if (src->type_index != TypeIndex::kTVMFFIMap) return false; - if constexpr (std::is_same_v && std::is_same_v) { - return true; - } else { - const MapObj* n = reinterpret_cast(src->v_obj); - for (const auto& kv : *n) { - if constexpr (!std::is_same_v) { - if (!details::AnyUnsafe::CheckAnyStrict(kv.first)) return false; - } - if constexpr (!std::is_same_v) { - if (!details::AnyUnsafe::CheckAnyStrict(kv.second)) return false; - } - } - return true; - } - } - - TVM_FFI_INLINE static std::optional> TryCastFromAnyView(const TVMFFIAny* src) { - if (src->type_index != TypeIndex::kTVMFFIMap) return std::nullopt; - if constexpr (!std::is_same_v || !std::is_same_v) { - const MapObj* n = reinterpret_cast(src->v_obj); - bool storage_check = [&]() { - for (const auto& kv : *n) { - if constexpr (!std::is_same_v) { - if (!details::AnyUnsafe::CheckAnyStrict(kv.first)) return false; - } - if constexpr (!std::is_same_v) { - if (!details::AnyUnsafe::CheckAnyStrict(kv.second)) return false; - } - } - return true; - }(); - // fast path, if storage check passes, we can return the array directly. - if (storage_check) return CopyFromAnyViewAfterCheck(src); - // slow path, we need to create a new map and convert to the target type. - Map ret; - for (const auto& kv : *n) { - auto k = kv.first.try_cast(); - auto v = kv.second.try_cast(); - if (!k.has_value() || !v.has_value()) return std::nullopt; - ret.Set(*std::move(k), *std::move(v)); - } - return ret; - } else { - return CopyFromAnyViewAfterCheck(src); - } - } - - TVM_FFI_INLINE static std::string TypeStr() { - return "Map<" + details::Type2Str::v() + ", " + details::Type2Str::v() + ">"; - } -}; - -namespace details { -template -inline constexpr bool type_contains_v, Map> = - type_contains_v && type_contains_v; -} // namespace details - -} // namespace ffi -} // namespace tvm -#endif // TVM_FFI_CONTAINER_MAP_H_ diff --git a/ffi/include/tvm/ffi/container/shape.h b/ffi/include/tvm/ffi/container/shape.h deleted file mode 100644 index de24a44ded06..000000000000 --- a/ffi/include/tvm/ffi/container/shape.h +++ /dev/null @@ -1,247 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file tvm/ffi/container/shape.h - * \brief Container to store shape of an Tensor. - */ -#ifndef TVM_FFI_CONTAINER_SHAPE_H_ -#define TVM_FFI_CONTAINER_SHAPE_H_ - -#include -#include -#include - -#include -#include -#include -#include - -namespace tvm { -namespace ffi { - -/*! \brief An object representing a shape tuple. */ -class ShapeObj : public Object, public TVMFFIShapeCell { - public: - /*! \brief The type of shape index element. */ - using index_type = int64_t; - - /*! \brief Get "numel", meaning the number of elements of an array if the array has this shape */ - int64_t Product() const { - int64_t product = 1; - for (size_t i = 0; i < this->size; ++i) { - product *= this->data[i]; - } - return product; - } - - /// \cond Doxygen_Suppress - static constexpr const uint32_t _type_index = TypeIndex::kTVMFFIShape; - TVM_FFI_DECLARE_OBJECT_INFO_STATIC(StaticTypeKey::kTVMFFIShape, ShapeObj, Object); - /// \endcond -}; - -namespace details { - -class ShapeObjStdImpl : public ShapeObj { - public: - explicit ShapeObjStdImpl(std::vector other) : data_{other} { - this->data = data_.data(); - this->size = static_cast(data_.size()); - } - - private: - std::vector data_; -}; - -TVM_FFI_INLINE ObjectPtr MakeEmptyShape(size_t length, int64_t** mutable_data) { - ObjectPtr p = make_inplace_array_object(length); - static_assert(alignof(ShapeObj) % alignof(int64_t) == 0); - static_assert(sizeof(ShapeObj) % alignof(int64_t) == 0); - int64_t* data = reinterpret_cast(reinterpret_cast(p.get()) + sizeof(ShapeObj)); - if (mutable_data) { - *mutable_data = data; - } - p->data = data; - p->size = length; - return p; -} - -// inplace shape allocation -template -TVM_FFI_INLINE ObjectPtr MakeInplaceShape(IterType begin, IterType end) { - size_t length = std::distance(begin, end); - int64_t* mutable_data; - ObjectPtr p = MakeEmptyShape(length, &mutable_data); - std::copy(begin, end, mutable_data); - return p; -} - -TVM_FFI_INLINE ObjectPtr MakeStridesFromShape(const int64_t* data, int64_t ndim) { - int64_t* strides_data; - ObjectPtr strides = details::MakeEmptyShape(ndim, &strides_data); - int64_t stride = 1; - for (int i = ndim - 1; i >= 0; --i) { - strides_data[i] = stride; - stride *= data[i]; - } - return strides; -} - -} // namespace details - -/*! - * \brief Reference to shape object. - */ -class Shape : public ObjectRef { - public: - /*! \brief The type of shape index element. */ - using index_type = ShapeObj::index_type; - - /*! \brief Default constructor */ - Shape() : ObjectRef(details::MakeEmptyShape(0, nullptr)) {} - - /*! - * \brief Constructor from iterator - * \param begin begin of iterator - * \param end end of iterator - * \tparam IterType The type of iterator - */ - template - Shape(IterType begin, IterType end) : Shape(details::MakeInplaceShape(begin, end)) {} - - /** - * \brief Constructor from Array - * \param shape The Array - * - * \note This constructor will copy the data content. - */ - Shape(Array shape) // NOLINT(*) - : Shape(shape.begin(), shape.end()) {} - - /*! - * \brief constructor from initializer list - * \param shape The initializer list - */ - Shape(std::initializer_list shape) : Shape(shape.begin(), shape.end()) {} - - /*! - * \brief constructor from int64_t [N] - * - * \param other a int64_t array. - */ - Shape(std::vector other) // NOLINT(*) - : ObjectRef(make_object(std::move(other))) {} - - /*! - * \brief Create a strides from a shape. - * \param data The shape data. - * \param ndim The number of dimensions. - * \return The strides. - */ - static Shape StridesFromShape(const int64_t* data, int64_t ndim) { - return Shape(details::MakeStridesFromShape(data, ndim)); - } - - /*! - * \brief Return the data pointer - * - * \return const index_type* data pointer - */ - const int64_t* data() const { return get()->data; } - - /*! - * \brief Return the size of the shape tuple - * - * \return size_t shape tuple size - */ - size_t size() const { return get()->size; } - - /*! - * \brief Immutably read i-th element from the shape tuple. - * \param idx The index - * \return the i-th element. - */ - int64_t operator[](size_t idx) const { - if (idx >= this->size()) { - TVM_FFI_THROW(IndexError) << "indexing " << idx << " on a Shape of size " << this->size(); - } - return this->data()[idx]; - } - - /*! - * \brief Immutably read i-th element from the shape tuple. - * \param idx The index - * \return the i-th element. - */ - int64_t at(size_t idx) const { return this->operator[](idx); } - - /*! \return Whether shape tuple is empty */ - bool empty() const { return size() == 0; } - - /*! \return The first element of the shape tuple */ - int64_t front() const { return this->at(0); } - - /*! \return The last element of the shape tuple */ - int64_t back() const { return this->at(this->size() - 1); } - - /*! \return begin iterator */ - const int64_t* begin() const { return get()->data; } - - /*! \return end iterator */ - const int64_t* end() const { return (get()->data + size()); } - - /*! \return The product of the shape tuple */ - int64_t Product() const { return get()->Product(); } - - /// \cond Doxygen_Suppress - TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(Shape, ObjectRef, ShapeObj); - /// \endcond - - private: - explicit Shape(ObjectPtr ptr) : ObjectRef(ptr) {} -}; - -inline std::ostream& operator<<(std::ostream& os, const Shape& shape) { - os << '['; - for (size_t i = 0; i < shape.size(); ++i) { - if (i != 0) { - os << ", "; - } - os << shape[i]; - } - os << ']'; - return os; -} - -// Shape -template <> -inline constexpr bool use_default_type_traits_v = false; - -// Allow auto conversion from Array to Shape, but not from Shape to Array -template <> -struct TypeTraits : public ObjectRefWithFallbackTraitsBase> { - static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIShape; - TVM_FFI_INLINE static Shape ConvertFallbackValue(Array src) { return Shape(src); } -}; - -} // namespace ffi -} // namespace tvm - -#endif // TVM_FFI_CONTAINER_SHAPE_H_ diff --git a/ffi/include/tvm/ffi/container/tensor.h b/ffi/include/tvm/ffi/container/tensor.h deleted file mode 100644 index 59dc7739ea63..000000000000 --- a/ffi/include/tvm/ffi/container/tensor.h +++ /dev/null @@ -1,468 +0,0 @@ - -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file tvm/ffi/container/tensor.h - * \brief Container to store a Tensor. - */ -#ifndef TVM_FFI_CONTAINER_TENSOR_H_ -#define TVM_FFI_CONTAINER_TENSOR_H_ - -#include -#include -#include -#include - -#include -#include -#include -#include - -namespace tvm { -namespace ffi { - -/*! - * \brief Check if the device uses direct address, where address of data indicate alignment. - * \param device The input device. - * \return True if the device uses direct address, false otherwise. - */ -inline bool IsDirectAddressDevice(const DLDevice& device) { - return device.device_type <= kDLCUDAHost || device.device_type == kDLCUDAManaged || - device.device_type == kDLROCM || device.device_type == kDLROCMHost; -} - -/*! - * \brief check if a DLTensor is contiguous. - * \param arr The input DLTensor. - * \return The check result. - */ -inline bool IsContiguous(const DLTensor& arr) { - if (arr.strides == nullptr) return true; - int64_t expected_stride = 1; - for (int32_t i = arr.ndim; i != 0; --i) { - int32_t k = i - 1; - if (arr.shape[k] == 1) { - // Skip stride check if shape[k] is 1, where the dimension is contiguous - // regardless of the value of stride. - // - // For example, PyTorch will normalize stride to 1 if shape is 1 when exporting - // to DLPack. - // More context: https://github.com/pytorch/pytorch/pull/83158 - continue; - } - if (arr.strides[k] != expected_stride) return false; - expected_stride *= arr.shape[k]; - } - return true; -} - -/** - * \brief Check if the data in the DLTensor is aligned to the given alignment. - * \param arr The input DLTensor. - * \param alignment The alignment to check. - * \return True if the data is aligned to the given alignment, false otherwise. - */ -inline bool IsAligned(const DLTensor& arr, size_t alignment) { - if (IsDirectAddressDevice(arr.device)) { - return (reinterpret_cast(static_cast(arr.data) + arr.byte_offset) % alignment == - 0); - } else { - return arr.byte_offset % alignment == 0; - } -} - -/*! - * \brief return the total number of bytes needed to store packed data - * - * \param numel the number of elements in the array - * \param dtype the data type of the array - * \return the total number of bytes needed to store packed data - */ -inline size_t GetDataSize(int64_t numel, DLDataType dtype) { - // compatible handling sub-byte uint1(bool), which usually stored as uint8_t - // TODO(tqchen): revisit and switch to kDLBool - if (dtype.code == kDLUInt && dtype.bits == 1 && dtype.lanes == 1) { - return numel; - } - // for other sub-byte types, packing is preferred - return (numel * dtype.bits * dtype.lanes + 7) / 8; -} - -/*! - * \brief return the size of data the DLTensor holds, in terms of number of bytes - * - * \param arr the input DLTensor - * \return number of bytes of data in the DLTensor. - */ -inline size_t GetDataSize(const DLTensor& arr) { - size_t size = 1; - for (int i = 0; i < arr.ndim; ++i) { - size *= static_cast(arr.shape[i]); - } - return GetDataSize(size, arr.dtype); -} - -/*! \brief An object representing a Tensor. */ -class TensorObj : public Object, public DLTensor { - public: - /// \cond Doxygen_Suppress - static constexpr const uint32_t _type_index = TypeIndex::kTVMFFITensor; - TVM_FFI_DECLARE_OBJECT_INFO_STATIC(StaticTypeKey::kTVMFFITensor, TensorObj, Object); - /// \endcond - ~TensorObj() { - // deleting the cached dl managed tensor versioned - // need to acquire the value in case it is released by another thread - DLManagedTensorVersioned* cached = - cached_dl_managed_tensor_versioned_.load(std::memory_order_acquire); - if (cached != nullptr) { - delete cached; - } - } - /*! - * \brief Move a Tensor to a DLPack managed tensor. - * \return The converted DLPack managed tensor. - */ - DLManagedTensor* ToDLPack() const { - TensorObj* self = const_cast(this); - DLManagedTensor* ret = new DLManagedTensor(); - ret->dl_tensor = *static_cast(self); - ret->manager_ctx = self; - ret->deleter = DLManagedTensorDeleter; - details::ObjectUnsafe::IncRefObjectHandle(self); - return ret; - } - - /*! - * \brief Move a Tensor to a DLPack managed tensor. - * \return The converted DLPack managed tensor. - */ - DLManagedTensorVersioned* ToDLPackVersioned() const { - TensorObj* from = const_cast(this); - // if cache is set, directly return it - // we need to use acquire to ensure that write to DLManagedTensorVersioned - // from another thread is visible to this thread. - DLManagedTensorVersioned* cached = - cached_dl_managed_tensor_versioned_.load(std::memory_order_acquire); - // if cache is not set, create a new one - if (cached == nullptr) { - DLManagedTensorVersioned* ret = new DLManagedTensorVersioned(); - ret->version.major = DLPACK_MAJOR_VERSION; - ret->version.minor = DLPACK_MINOR_VERSION; - ret->dl_tensor = *static_cast(from); - ret->manager_ctx = from; - ret->deleter = EmbeddedDLManagedTensorVersionedDeleter; - ret->flags = 0; - DLManagedTensorVersioned* expected = nullptr; - // success set must release the new value to all other threads - // failure set must acquire, since the expected value is now coming - // from another thread that released this value - if (std::atomic_compare_exchange_strong_explicit(&cached_dl_managed_tensor_versioned_, - &expected, ret, std::memory_order_release, - std::memory_order_acquire)) { - // set is succes - cached = ret; - } else { - // delete the ret value as another thread raced to set this one first - delete ret; - cached = expected; - } - // at this point, cached is the value that officially set to the field - } - // inc the ref count of the from object - details::ObjectUnsafe::IncRefObjectHandle(from); - return cached; - } - - protected: - /*! \brief Internal data to back returning shape. */ - Optional shape_data_; - /*! \brief Internal data to back returning strides. */ - Optional strides_data_; - /*! \brief cached data to back returning DLManagedTensorVersioned. */ - mutable std::atomic cached_dl_managed_tensor_versioned_ = nullptr; - - /*! - * \brief Deleter for DLManagedTensor. - * \param tensor The DLManagedTensor to be deleted. - */ - static void DLManagedTensorDeleter(DLManagedTensor* tensor) { - TensorObj* obj = static_cast(tensor->manager_ctx); - details::ObjectUnsafe::DecRefObjectHandle(obj); - delete tensor; - } - - /*! - * \brief Deleter for DLManagedTensorVersioned. - * \param tensor The DLManagedTensorVersioned to be deleted. - */ - static void EmbeddedDLManagedTensorVersionedDeleter(DLManagedTensorVersioned* tensor) { - TensorObj* obj = static_cast(tensor->manager_ctx); - details::ObjectUnsafe::DecRefObjectHandle(obj); - } - - friend class Tensor; - /// \endcond -}; - -namespace details { -/*! - *\brief Helper class to create an TensorObj from an NDAllocator - * - * The underlying allocator needs to be implemented by user. - */ -template -class TensorObjFromNDAlloc : public TensorObj { - public: - template - TensorObjFromNDAlloc(TNDAlloc alloc, ffi::Shape shape, DLDataType dtype, DLDevice device, - ExtraArgs&&... extra_args) - : alloc_(alloc) { - this->device = device; - this->ndim = static_cast(shape.size()); - this->dtype = dtype; - this->shape = const_cast(shape.data()); - Shape strides = Shape::StridesFromShape(this->shape, this->ndim); - this->strides = const_cast(strides.data()); - this->byte_offset = 0; - this->shape_data_ = std::move(shape); - this->strides_data_ = std::move(strides); - alloc_.AllocData(static_cast(this), std::forward(extra_args)...); - } - - ~TensorObjFromNDAlloc() { alloc_.FreeData(static_cast(this)); } - - private: - TNDAlloc alloc_; -}; - -/*! \brief helper class to import from DLPack legacy DLManagedTensor */ -template -class TensorObjFromDLPack : public TensorObj { - public: - explicit TensorObjFromDLPack(TDLPackManagedTensor* tensor) : tensor_(tensor) { - *static_cast(this) = tensor_->dl_tensor; - if (tensor_->dl_tensor.strides == nullptr) { - Shape strides = Shape::StridesFromShape(tensor_->dl_tensor.shape, tensor_->dl_tensor.ndim); - this->strides = const_cast(strides.data()); - this->strides_data_ = std::move(strides); - } - } - - ~TensorObjFromDLPack() { - // run DLPack deleter if needed. - if (tensor_->deleter != nullptr) { - (*tensor_->deleter)(tensor_); - } - } - - private: - TDLPackManagedTensor* tensor_; -}; -} // namespace details - -/*! - * \brief Managed Tensor (n-dimensional array). - * The tensor is backed by reference counted blocks. - * - * \note This class can be subclassed to implement downstream customized - * Tensor types that are backed by the same TensorObj storage type. - */ -class Tensor : public ObjectRef { - public: - /*! - * \brief Get the shape of the Tensor. - * \return The shape of the Tensor. - */ - tvm::ffi::Shape shape() const { - TensorObj* obj = get_mutable(); - if (!obj->shape_data_.has_value()) { - obj->shape_data_ = tvm::ffi::Shape(obj->shape, obj->shape + obj->ndim); - } - return *(obj->shape_data_); - } - /*! - * \brief Get the strides of the Tensor. - * \return The strides of the Tensor. - */ - tvm::ffi::Shape strides() const { - TensorObj* obj = get_mutable(); - TVM_FFI_ICHECK(obj->strides != nullptr); - if (!obj->strides_data_.has_value()) { - obj->strides_data_ = tvm::ffi::Shape(obj->strides, obj->strides + obj->ndim); - } - return *(obj->strides_data_); - } - /*! - * \brief Get the data type of the Tensor. - * \return The data type of the Tensor. - */ - DLDataType dtype() const { return (*this)->dtype; } - /*! - * \brief Check if the Tensor is contiguous. - * \return True if the Tensor is contiguous, false otherwise. - */ - bool IsContiguous() const { return tvm::ffi::IsContiguous(*get()); } - /*! - * \brief Check if the Tensor data is aligned to the given alignment. - * \param alignment The alignment to check. - * \return True if the Tensor data is aligned to the given alignment, false otherwise. - */ - bool IsAligned(size_t alignment) const { return tvm::ffi::IsAligned(*get(), alignment); } - /*! - * \brief Create a Tensor from a NDAllocator. - * \param alloc The NDAllocator. - * \param shape The shape of the Tensor. - * \param dtype The data type of the Tensor. - * \param device The device of the Tensor. - * \param extra_args Extra arguments to be forwarded to TNDAlloc. - * \return The created Tensor. - * \tparam TNDAlloc The type of the NDAllocator, impelments Alloc and Free. - * \tparam ExtraArgs Extra arguments to be passed to Alloc. - */ - template - static Tensor FromNDAlloc(TNDAlloc alloc, ffi::Shape shape, DLDataType dtype, DLDevice device, - ExtraArgs&&... extra_args) { - return Tensor(make_object>( - alloc, shape, dtype, device, std::forward(extra_args)...)); - } - /*! - * \brief Create a Tensor from a DLPackTensorAllocator - * - * This function can be used together with TVMFFIEnvSetTensorAllocator - * in the extra/c_env_api.h to create Tensor from the thread-local - * environment allocator. - * - * \code - * - * ffi::Tensor tensor = ffi::Tensor::FromDLPackAlloc( - * TVMFFIEnvGetTensorAllocator(), shape, dtype, device - * ); - * \endcode - * - * \param allocator The DLPack allocator. - * \param shape The shape of the Tensor. - * \param dtype The data type of the Tensor. - * \param device The device of the Tensor. - * \return The created Tensor. - */ - static Tensor FromDLPackAlloc(DLPackTensorAllocator allocator, ffi::Shape shape, DLDataType dtype, - DLDevice device) { - if (allocator == nullptr) { - TVM_FFI_THROW(RuntimeError) - << "FromDLPackAlloc: allocator is nullptr, " - << "likely because TVMFFIEnvSetTensorAllocator has not been called."; - } - DLTensor prototype; - prototype.device = device; - prototype.dtype = dtype; - prototype.shape = const_cast(shape.data()); - prototype.ndim = static_cast(shape.size()); - prototype.strides = nullptr; - prototype.byte_offset = 0; - prototype.data = nullptr; - DLManagedTensorVersioned* tensor = nullptr; - // error context to be used to propagate error - struct ErrorContext { - std::string kind; - std::string message; - static void SetError(void* error_ctx, const char* kind, const char* message) { - ErrorContext* error_context = static_cast(error_ctx); - error_context->kind = kind; - error_context->message = message; - } - }; - ErrorContext error_context; - int ret = (*allocator)(&prototype, &tensor, &error_context, ErrorContext::SetError); - if (ret != 0) { - throw ffi::Error(error_context.kind, error_context.message, - TVMFFITraceback(__FILE__, __LINE__, __func__, 0)); - } - return Tensor(make_object>(tensor)); - } - /*! - * \brief Create a Tensor from a DLPack managed tensor, pre v1.0 API. - * \param tensor The input DLPack managed tensor. - * \param require_alignment The minimum alignment requored of the data + byte_offset. - * \param require_contiguous Boolean flag indicating if we need to check for contiguity. - * \note This function will not run any checks on flags. - * \return The created Tensor. - */ - static Tensor FromDLPack(DLManagedTensor* tensor, size_t require_alignment = 0, - bool require_contiguous = false) { - if (require_alignment != 0 && !ffi::IsAligned(tensor->dl_tensor, require_alignment)) { - TVM_FFI_THROW(RuntimeError) << "FromDLPack: Data is not aligned to " << require_alignment - << " bytes."; - } - if (require_contiguous && !ffi::IsContiguous(tensor->dl_tensor)) { - TVM_FFI_THROW(RuntimeError) << "FromDLPack: Tensor is not contiguous."; - } - return Tensor(make_object>(tensor)); - } - - /*! - * \brief Create a Tensor from a DLPack managed tensor, post v1.0 API. - * \param tensor The input DLPack managed tensor. - * \param require_alignment The minimum alignment requored of the data + byte_offset. - * \param require_contiguous Boolean flag indicating if we need to check for contiguity. - * \return The created Tensor. - */ - static Tensor FromDLPackVersioned(DLManagedTensorVersioned* tensor, size_t require_alignment = 0, - bool require_contiguous = false) { - if (require_alignment != 0 && !ffi::IsAligned(tensor->dl_tensor, require_alignment)) { - TVM_FFI_THROW(RuntimeError) << "FromDLPack: Data is not aligned to " << require_alignment - << " bytes."; - } - if (require_contiguous && !ffi::IsContiguous(tensor->dl_tensor)) { - TVM_FFI_THROW(RuntimeError) << "FromDLPack: Tensor is not contiguous."; - } - if (tensor->flags & DLPACK_FLAG_BITMASK_IS_SUBBYTE_TYPE_PADDED) { - TVM_FFI_THROW(RuntimeError) << "Subbyte type padded is not yet supported"; - } - return Tensor(make_object>(tensor)); - } - - /*! - * \brief Convert the Tensor to a DLPack managed tensor. - * \return The converted DLPack managed tensor. - */ - DLManagedTensor* ToDLPack() const { return get_mutable()->ToDLPack(); } - - /*! - * \brief Convert the Tensor to a DLPack managed tensor. - * \return The converted DLPack managed tensor. - */ - DLManagedTensorVersioned* ToDLPackVersioned() const { return get_mutable()->ToDLPackVersioned(); } - - /// \cond Doxygen_Suppress - TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Tensor, ObjectRef, TensorObj); - /// \endcond - - protected: - /*! - * \brief Get mutable internal container pointer. - * \return a mutable container pointer. - */ - TensorObj* get_mutable() const { return const_cast(get()); } -}; - -} // namespace ffi -} // namespace tvm - -#endif // TVM_FFI_CONTAINER_TENSOR_H_ diff --git a/ffi/include/tvm/ffi/container/tuple.h b/ffi/include/tvm/ffi/container/tuple.h deleted file mode 100644 index 75342409eabb..000000000000 --- a/ffi/include/tvm/ffi/container/tuple.h +++ /dev/null @@ -1,317 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file tvm/ffi/container/tuple.h - * \brief Typed tuple like std::tuple backed by ArrayObj container. - */ -#ifndef TVM_FFI_CONTAINER_TUPLE_H_ -#define TVM_FFI_CONTAINER_TUPLE_H_ - -#include - -#include -#include -#include - -namespace tvm { -namespace ffi { - -/*! - * \brief Typed tuple like std::tuple backed by ArrayObj container. - * - * Tuple implements in-place copy-on-write semantics. - * - * \tparam Types The types of the tuple elements - */ -template -class Tuple : public ObjectRef { - public: - static_assert(details::all_storage_enabled_v, - "All types used in Tuple<...> must be compatible with Any"); - /*! \brief Default constructor */ - Tuple() : ObjectRef(MakeDefaultTupleNode()) {} - /*! - * \brief Constructor with UnsafeInit - */ - explicit Tuple(UnsafeInit tag) : ObjectRef(tag) {} - /*! \brief Copy constructor */ - Tuple(const Tuple& other) : ObjectRef(other) {} - /*! \brief Move constructor */ - Tuple(Tuple&& other) : ObjectRef(std::move(other)) {} - /*! - * \brief Constructor from another tuple - * \param other The other tuple - * \tparam UTypes The types of the other tuple - * \tparam The enable_if_t type - */ - template && ...), int>> - Tuple(const Tuple& other) : ObjectRef(other) {} - - /*! - * \brief Constructor from another tuple - * \param other The other tuple - * \tparam UTypes The types of the other tuple - * \tparam The enable_if_t type - */ - template && ...), int>> - Tuple(Tuple&& other) : ObjectRef(std::move(other)) {} - - /*! - * \brief Constructor from arguments - * \param args The arguments - * \tparam UTypes The types of the other tuple - */ - template , Tuple> && ...))>> - explicit Tuple(UTypes&&... args) : ObjectRef(MakeTupleNode(std::forward(args)...)) {} - - /*! - * \brief Assignment from another tuple - * \param other The other tuple - * \tparam The enable_if_t type - */ - TVM_FFI_INLINE Tuple& operator=(const Tuple& other) { - data_ = other.data_; - return *this; - } - - /*! - * \brief Assignment from another tuple - * \param other The other tuple - * \tparam The enable_if_t type - */ - TVM_FFI_INLINE Tuple& operator=(Tuple&& other) { - data_ = std::move(other.data_); - return *this; - } - - /*! - * \brief Assignment from another tuple - * \param other The other tuple - * \tparam UTypes The types of the other tuple - * \tparam The enable_if_t type - */ - template && ...)>> - TVM_FFI_INLINE Tuple& operator=(const Tuple& other) { - data_ = other.data_; - return *this; - } - - /*! - * \brief Assignment from another tuple - * \param other The other tuple - * \tparam UTypes The types of the other tuple - * \tparam The enable_if_t type - */ - template && ...)>> - TVM_FFI_INLINE Tuple& operator=(Tuple&& other) { - data_ = std::move(other.data_); - return *this; - } - - /*! - * \brief Get I-th element of the tuple - * - * \tparam I The index of the element to get - * \return The I-th element of the tuple - * \note We use stl style since get usually is like a getter. - */ - template - auto get() const { - static_assert(I < sizeof...(Types), "Tuple index out of bounds"); - using ReturnType = std::tuple_element_t>; - const Any* ptr = GetArrayObj()->begin() + I; - return details::AnyUnsafe::CopyFromAnyViewAfterCheck(*ptr); - } - - /*! - * \brief Set I-th element of the tuple - * - * \param item The item to set - * \tparam I The index of the element to set - * \tparam U The type of the item - * - * \note This function will perform copy on write if underlying - * container is not uniquely owned. - * We use CamelCase since Set can cause copy on write - * and is more complicated than simple field setter. - */ - template - void Set(U&& item) { - static_assert(I < sizeof...(Types), "Tuple index out of bounds"); - using T = std::tuple_element_t>; - this->CopyIfNotUnique(); - Any* ptr = GetArrayObj()->MutableBegin() + I; - *ptr = T(std::forward(item)); - } - - /*! \brief specify container node */ - using ContainerType = ArrayObj; - - private: - static ObjectPtr MakeDefaultTupleNode() { - ObjectPtr p = ArrayObj::Empty(sizeof...(Types)); - Any* itr = p->MutableBegin(); - // increase size after each new to ensure exception safety - ((new (itr++) Any(Types()), p->size_++), ...); - return p; - } - - template - static ObjectPtr MakeTupleNode(UTypes&&... args) { - ObjectPtr p = ArrayObj::Empty(sizeof...(Types)); - Any* itr = p->MutableBegin(); - // increase size after each new to ensure exception safety - ((new (itr++) Any(Types(std::forward(args))), p->size_++), ...); - return p; - } - - /*! \brief Copy on write */ - void CopyIfNotUnique() { - if (!data_.unique()) { - ObjectPtr p = ArrayObj::Empty(sizeof...(Types)); - Any* itr = p->MutableBegin(); - const Any* read = GetArrayObj()->begin(); - // increase size after each new to ensure exception safety - for (size_t i = 0; i < sizeof...(Types); ++i) { - new (itr++) Any(*read++); - p->size_++; - } - data_ = std::move(p); - } - } - - /*! \return The underlying ArrayObj */ - ArrayObj* GetArrayObj() const { return static_cast(data_.get()); } - - template - friend class Tuple; -}; - -template -inline constexpr bool use_default_type_traits_v> = false; - -template -struct TypeTraits> : public ObjectRefTypeTraitsBase> { - using ObjectRefTypeTraitsBase>::CopyFromAnyViewAfterCheck; - - TVM_FFI_INLINE static std::string GetMismatchTypeInfo(const TVMFFIAny* src) { - if (src->type_index != TypeIndex::kTVMFFIArray) { - return TypeTraitsBase::GetMismatchTypeInfo(src); - } - const ArrayObj* n = reinterpret_cast(src->v_obj); - if (n->size() != sizeof...(Types)) { - return "Array[size=" + std::to_string(n->size()) + "]"; - } - return GetMismatchTypeInfoHelper<0, Types...>(n->begin()); - } - - template - TVM_FFI_INLINE static std::string GetMismatchTypeInfoHelper(const Any* arr) { - if constexpr (!std::is_same_v) { - const Any& any_v = arr[I]; - if (!details::AnyUnsafe::CheckAnyStrict(any_v) && !(any_v.try_cast().has_value())) { - // now report the accurate mismatch information - return "Array[index " + std::to_string(I) + ": " + - details::AnyUnsafe::GetMismatchTypeInfo(any_v) + "]"; - } - } - if constexpr (sizeof...(Rest) > 0) { - return GetMismatchTypeInfoHelper(arr); - } - TVM_FFI_THROW(InternalError) << "Cannot reach here"; - TVM_FFI_UNREACHABLE(); - } - - TVM_FFI_INLINE static bool CheckAnyStrict(const TVMFFIAny* src) { - if (src->type_index != TypeIndex::kTVMFFIArray) return false; - const ArrayObj* n = reinterpret_cast(src->v_obj); - if (n->size() != sizeof...(Types)) return false; - const TVMFFIAny* ffi_any_arr = reinterpret_cast(n->begin()); - return CheckAnyStrictHelper<0, Types...>(ffi_any_arr); - } - - template - TVM_FFI_INLINE static bool CheckAnyStrictHelper(const TVMFFIAny* src_arr) { - if constexpr (!std::is_same_v) { - if (!TypeTraits::CheckAnyStrict(src_arr + I)) { - return false; - } - } - if constexpr (sizeof...(Rest) > 0) { - return CheckAnyStrictHelper(src_arr); - } - return true; - } - - TVM_FFI_INLINE static std::optional> TryCastFromAnyView(const TVMFFIAny* src // - ) { - if (src->type_index != TypeIndex::kTVMFFIArray) return std::nullopt; - const ArrayObj* n = reinterpret_cast(src->v_obj); - if (n->size() != sizeof...(Types)) return std::nullopt; - // fast path, storage is already in the right type - if (CheckAnyStrict(src)) { - return CopyFromAnyViewAfterCheck(src); - } - // slow path, try to convert to each type to match the tuple storage need. - Array arr = TypeTraits>::CopyFromAnyViewAfterCheck(src); - Any* ptr = arr.CopyOnWrite()->MutableBegin(); - if (TryConvertElements<0, Types...>(ptr)) { - return details::ObjectUnsafe::ObjectRefFromObjectPtr>( - details::ObjectUnsafe::ObjectPtrFromObjectRef(arr)); - } - return std::nullopt; - } - - template - TVM_FFI_INLINE static bool TryConvertElements(Any* arr) { - if constexpr (!std::is_same_v) { - if (auto opt_convert = arr[I].try_cast()) { - arr[I] = *std::move(opt_convert); - } else { - return false; - } - } - if constexpr (sizeof...(Rest) > 0) { - return TryConvertElements(std::move(arr)); - } else { - return true; - } - } - - TVM_FFI_INLINE static std::string TypeStr() { - return details::ContainerTypeStr("Tuple"); - } -}; - -namespace details { -template -inline constexpr bool type_contains_v, Tuple> = (type_contains_v && ...); -} // namespace details - -} // namespace ffi -} // namespace tvm -#endif // TVM_FFI_CONTAINER_TUPLE_H_ diff --git a/ffi/include/tvm/ffi/container/variant.h b/ffi/include/tvm/ffi/container/variant.h deleted file mode 100644 index cae5a673b8ce..000000000000 --- a/ffi/include/tvm/ffi/container/variant.h +++ /dev/null @@ -1,302 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file tvm/ffi/container/variant.h - * \brief Runtime variant container types. - */ -#ifndef TVM_FFI_CONTAINER_VARIANT_H_ -#define TVM_FFI_CONTAINER_VARIANT_H_ - -#include -#include -#include - -#include -#include -#include - -namespace tvm { -namespace ffi { -namespace details { -/*! - * \brief Base class for Variant. - * - * \tparam all_storage_object Whether all types are derived from ObjectRef. - */ -template -class VariantBase { - public: - TVM_FFI_INLINE bool same_as(const VariantBase& other) const { - return data_.same_as(other.data_); - } - - protected: - template - explicit VariantBase(T other) : data_(std::move(other)) {} - - TVM_FFI_INLINE void SetData(Any other_data) { data_ = std::move(other_data); } - - TVM_FFI_INLINE Any MoveToAny() && { return std::move(data_); } - - TVM_FFI_INLINE AnyView ToAnyView() const { return data_.operator AnyView(); } - - Any data_; -}; - -// Specialization for all object ref case, backed by ObjectRef. -template <> -class VariantBase : public ObjectRef { - protected: - template - explicit VariantBase(const T& other) : ObjectRef(other) {} - template - explicit VariantBase(T&& other) : ObjectRef(std::move(other)) {} - explicit VariantBase(UnsafeInit tag) : ObjectRef(tag) {} - explicit VariantBase(Any other) - : ObjectRef(details::AnyUnsafe::MoveFromAnyAfterCheck(std::move(other))) {} - - TVM_FFI_INLINE void SetData(ObjectPtr other) { data_ = std::move(other); } - - TVM_FFI_INLINE Any MoveToAny() && { return Any(ObjectRef(std::move(data_))); } - - TVM_FFI_INLINE AnyView ToAnyView() const { - TVMFFIAny any_data; - if (data_ == nullptr) { - any_data.type_index = TypeIndex::kTVMFFINone; - any_data.zero_padding = 0; - any_data.v_int64 = 0; - } else { - TVM_FFI_CLEAR_PTR_PADDING_IN_FFI_ANY(&any_data); - any_data.type_index = data_->type_index(); - any_data.zero_padding = 0; - any_data.v_obj = details::ObjectUnsafe::TVMFFIObjectPtrFromObjectPtr(data_); - } - return AnyView::CopyFromTVMFFIAny(any_data); - } -}; -} // namespace details - -/*! - * \brief A typed variant container. - * - * When all values are ObjectRef, Variant is backed by ObjectRef, - * otherwise it is backed by Any. - */ -template -class Variant : public details::VariantBase> { - public: - /// \cond Doxygen_Suppress - using TParent = details::VariantBase>; - static_assert(details::all_storage_enabled_v, - "All types used in Variant<...> must be compatible with Any"); - /* - * \brief Helper utility to check if the type can be contained in the variant - */ - template - static constexpr bool variant_contains_v = (details::type_contains_v || ...); - /* \brief Helper utility for SFINAE if the type is part of the variant */ - template - using enable_if_variant_contains_t = std::enable_if_t>; - /// \endcond - /*! - * \brief Constructor from another variant - * \param other The other variant - */ - Variant(const Variant& other) : TParent(other.data_) {} - /*! - * \brief Constructor from another variant - * \param other The other variant - */ - Variant(Variant&& other) : TParent(std::move(other.data_)) {} - - /*! - * \brief Assignment from another variant - * \param other The other variant - */ - TVM_FFI_INLINE Variant& operator=(const Variant& other) { - this->SetData(other.data_); - return *this; - } - - /*! - * \brief Assignment from another variant - * \param other The other variant - */ - TVM_FFI_INLINE Variant& operator=(Variant&& other) { - this->SetData(std::move(other.data_)); - return *this; - } - - /*! - * \brief Constructor from another variant - * \param other The other variant - */ - template > - Variant(T other) : TParent(std::move(other)) {} // NOLINT(*) - - /*! - * \brief Assignment from another variant - * \param other The other variant - */ - template > - TVM_FFI_INLINE Variant& operator=(T other) { - return operator=(Variant(std::move(other))); - } - - /*! - * \brief Try to cast to a type T, return std::nullopt if the cast is not possible. - * \return The casted value, or std::nullopt if the cast is not possible. - * \tparam T The type to cast to. - */ - template > - TVM_FFI_INLINE std::optional as() const { - return this->TParent::ToAnyView().template as(); - } - - /*! - * \brief Shortcut of as Object to cast to a const pointer when T is an Object. - * - * \tparam T The object type. - * \return The requested pointer, returns nullptr if type mismatches. - */ - template >> - TVM_FFI_INLINE const T* as() const { - return this->TParent::ToAnyView().template as().value_or(nullptr); - } - - /*! - * \brief Get the value of the variant in type T, throws an exception if cast fails. - * \return The value of the variant - * \tparam T The type to get. - */ - template > - TVM_FFI_INLINE T get() const& { - return this->TParent::ToAnyView().template cast(); - } - - /*! - * \brief Get the value of the variant in type T, throws an exception if cast fails. - * \return The value of the variant - * \tparam T The type to get. - */ - template > - TVM_FFI_INLINE T get() && { - return std::move(*this).TParent::MoveToAny().template cast(); - } - - /*! - * \brief Get the type key of the variant - * \return The type key of the variant - */ - TVM_FFI_INLINE std::string GetTypeKey() const { return this->TParent::ToAnyView().GetTypeKey(); } - - private: - friend struct TypeTraits>; - friend struct ObjectPtrHash; - friend struct ObjectPtrEqual; - // constructor from any - explicit Variant(Any data) : TParent(std::move(data)) {} - /*! - * \brief Get the object pointer from the variant - * \note This function is only available if all types used in Variant<...> are derived from - * ObjectRef - */ - TVM_FFI_INLINE Object* GetObjectPtrForHashEqual() const { - constexpr bool all_object_v = (std::is_base_of_v && ...); - static_assert(all_object_v, - "All types used in Variant<...> must be derived from ObjectRef " - "to enable ObjectPtrHash/ObjectPtrEqual"); - return this->data_.get(); - } - // rexpose to friend class - using TParent::MoveToAny; - using TParent::ToAnyView; -}; - -template -inline constexpr bool use_default_type_traits_v> = false; - -template -struct TypeTraits> : public TypeTraitsBase { - TVM_FFI_INLINE static void CopyToAnyView(const Variant& src, TVMFFIAny* result) { - *result = src.ToAnyView().CopyToTVMFFIAny(); - } - - TVM_FFI_INLINE static void MoveToAny(Variant src, TVMFFIAny* result) { - *result = details::AnyUnsafe::MoveAnyToTVMFFIAny(std::move(src).MoveToAny()); - } - - TVM_FFI_INLINE static std::string GetMismatchTypeInfo(const TVMFFIAny* src) { - return TypeTraitsBase::GetMismatchTypeInfo(src); - } - - TVM_FFI_INLINE static bool CheckAnyStrict(const TVMFFIAny* src) { - return (TypeTraits::CheckAnyStrict(src) || ...); - } - - TVM_FFI_INLINE static Variant CopyFromAnyViewAfterCheck(const TVMFFIAny* src) { - return Variant(Any(AnyView::CopyFromTVMFFIAny(*src))); - } - - TVM_FFI_INLINE static Variant MoveFromAnyAfterCheck(TVMFFIAny* src) { - return Variant(details::AnyUnsafe::MoveTVMFFIAnyToAny(std::move(*src))); - } - - TVM_FFI_INLINE static std::optional> TryCastFromAnyView(const TVMFFIAny* src) { - // fast path, storage is already in the right type - if (CheckAnyStrict(src)) { - return CopyFromAnyViewAfterCheck(src); - } - // More expensive path, try to convert to each type, in order of declaration - return TryVariantTypes(src); - } - - template - TVM_FFI_INLINE static std::optional> TryVariantTypes(const TVMFFIAny* src) { - if (auto opt_convert = TypeTraits::TryCastFromAnyView(src)) { - return Variant(*std::move(opt_convert)); - } - if constexpr (sizeof...(Rest) > 0) { - return TryVariantTypes(src); - } - return std::nullopt; - } - - TVM_FFI_INLINE static std::string TypeStr() { return details::ContainerTypeStr("Variant"); } -}; - -template -TVM_FFI_INLINE size_t ObjectPtrHash::operator()(const Variant& a) const { - return std::hash()(a.GetObjectPtrForHashEqual()); -} - -template -TVM_FFI_INLINE bool ObjectPtrEqual::operator()(const Variant& a, - const Variant& b) const { - return a.GetObjectPtrForHashEqual() == b.GetObjectPtrForHashEqual(); -} - -namespace details { -template -inline constexpr bool type_contains_v, T> = (type_contains_v || ...); -} // namespace details -} // namespace ffi -} // namespace tvm -#endif // TVM_FFI_CONTAINER_VARIANT_H_ diff --git a/ffi/include/tvm/ffi/dtype.h b/ffi/include/tvm/ffi/dtype.h deleted file mode 100644 index a9e09d229372..000000000000 --- a/ffi/include/tvm/ffi/dtype.h +++ /dev/null @@ -1,192 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file tvm/ffi/dtype.h - * \brief Data type handling. - */ -#ifndef TVM_FFI_DTYPE_H_ -#define TVM_FFI_DTYPE_H_ - -#include -#include -#include -#include -#include - -#include - -namespace tvm { -namespace ffi { -/*! - * \brief Extension code beyond the DLDataType. - * - * This class is always consistent with the DLPack. - */ -enum DLExtDataTypeCode { kDLExtCustomBegin = 129 }; - -namespace details { - -/* - * \brief Convert a DLDataTypeCode to a string. - * \param os The output stream. - * \param type_code The DLDataTypeCode to convert. - */ -inline const char* DLDataTypeCodeAsCStr(DLDataTypeCode type_code) { // NOLINT(*) - switch (static_cast(type_code)) { - case kDLInt: { - return "int"; - } - case kDLUInt: { - return "uint"; - } - case kDLFloat: { - return "float"; - } - case kDLOpaqueHandle: { - return "handle"; - } - case kDLBfloat: { - return "bfloat"; - } - case kDLFloat8_e3m4: { - return "float8_e3m4"; - } - case kDLFloat8_e4m3: { - return "float8_e4m3"; - } - case kDLFloat8_e4m3b11fnuz: { - return "float8_e4m3b11fnuz"; - } - case kDLFloat8_e4m3fn: { - return "float8_e4m3fn"; - } - case kDLFloat8_e4m3fnuz: { - return "float8_e4m3fnuz"; - } - case kDLFloat8_e5m2: { - return "float8_e5m2"; - } - case kDLFloat8_e5m2fnuz: { - return "float8_e5m2fnuz"; - } - case kDLFloat8_e8m0fnu: { - return "float8_e8m0fnu"; - } - case kDLFloat6_e2m3fn: { - return "float6_e2m3fn"; - } - case kDLFloat6_e3m2fn: { - return "float6_e3m2fn"; - } - case kDLFloat4_e2m1fn: { - return "float4_e2m1fn"; - } - default: { - if (static_cast(type_code) >= static_cast(DLExtDataTypeCode::kDLExtCustomBegin)) { - return "custom"; - } else { - TVM_FFI_THROW(ValueError) << "DLDataType contains unknown type_code=" - << static_cast(type_code); - } - TVM_FFI_UNREACHABLE(); - } - } -} -} // namespace details - -/*! - * \brief Convert a string to a DLDataType. - * \param str The string to convert. - * \return The DLDataType. - */ -inline DLDataType StringToDLDataType(const String& str) { - DLDataType out; - TVMFFIByteArray data{str.data(), str.size()}; - TVM_FFI_CHECK_SAFE_CALL(TVMFFIDataTypeFromString(&data, &out)); - return out; -} - -/*! - * \brief Convert a DLDataType to a string. - * \param dtype The DLDataType to convert. - * \return The string. - */ -inline String DLDataTypeToString(DLDataType dtype) { - TVMFFIAny out; - TVM_FFI_CHECK_SAFE_CALL(TVMFFIDataTypeToString(&dtype, &out)); - return TypeTraits::MoveFromAnyAfterCheck(&out); -} - -// DLDataType -template <> -struct TypeTraits : public TypeTraitsBase { - static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIDataType; - - TVM_FFI_INLINE static void CopyToAnyView(const DLDataType& src, TVMFFIAny* result) { - // clear padding part to ensure the equality check can always check the v_uint64 part - result->v_uint64 = 0; - result->type_index = TypeIndex::kTVMFFIDataType; - result->zero_padding = 0; - result->v_dtype = src; - } - - TVM_FFI_INLINE static void MoveToAny(DLDataType src, TVMFFIAny* result) { - // clear padding part to ensure the equality check can always check the v_uint64 part - result->v_uint64 = 0; - result->type_index = TypeIndex::kTVMFFIDataType; - result->zero_padding = 0; - result->v_dtype = src; - } - - TVM_FFI_INLINE static bool CheckAnyStrict(const TVMFFIAny* src) { - return src->type_index == TypeIndex::kTVMFFIDataType; - } - - TVM_FFI_INLINE static DLDataType CopyFromAnyViewAfterCheck(const TVMFFIAny* src) { - return src->v_dtype; - } - - TVM_FFI_INLINE static std::optional TryCastFromAnyView(const TVMFFIAny* src) { - if (src->type_index == TypeIndex::kTVMFFIDataType) { - return src->v_dtype; - } - // enable string to dtype auto conversion - if (auto opt_str = TypeTraits::TryCastFromAnyView(src)) { - return StringToDLDataType(*opt_str); - } - return std::nullopt; - } - - TVM_FFI_INLINE static std::string TypeStr() { return ffi::StaticTypeKey::kTVMFFIDataType; } -}; -} // namespace ffi -} // namespace tvm - -// define DLDataType comparison and printing in root namespace -inline std::ostream& operator<<(std::ostream& os, DLDataType dtype) { // NOLINT(*) - return os << tvm::ffi::DLDataTypeToString(dtype); -} - -inline bool operator==(const DLDataType& lhs, const DLDataType& rhs) { - return lhs.code == rhs.code && lhs.bits == rhs.bits && lhs.lanes == rhs.lanes; -} - -inline bool operator!=(const DLDataType& lhs, const DLDataType& rhs) { return !(lhs == rhs); } -#endif // TVM_FFI_DTYPE_H_ diff --git a/ffi/include/tvm/ffi/endian.h b/ffi/include/tvm/ffi/endian.h deleted file mode 100644 index 4a73b82e6c30..000000000000 --- a/ffi/include/tvm/ffi/endian.h +++ /dev/null @@ -1,89 +0,0 @@ - -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/* - * \file tvm/ffi/endian.h - * \brief Endian detection and handling - */ -#ifndef TVM_FFI_ENDIAN_H_ -#define TVM_FFI_ENDIAN_H_ - -#include -#include - -#ifndef TVM_FFI_IO_USE_LITTLE_ENDIAN -#define TVM_FFI_IO_USE_LITTLE_ENDIAN 1 -#endif - -#ifdef TVM_FFI_CMAKE_LITTLE_ENDIAN -// If compiled with CMake, use CMake's endian detection logic -#define TVM_FFI_LITTLE_ENDIAN TVM_FFI_CMAKE_LITTLE_ENDIAN -#else -#if defined(__APPLE__) || defined(_WIN32) -#define TVM_FFI_LITTLE_ENDIAN 1 -#elif defined(__GLIBC__) || defined(__GNU_LIBRARY__) || defined(__ANDROID__) || defined(__RISCV__) -#include -#define TVM_FFI_LITTLE_ENDIAN (__BYTE_ORDER == __LITTLE_ENDIAN) -#elif defined(__FreeBSD__) || defined(__OpenBSD__) -#include -#define TVM_FFI_LITTLE_ENDIAN (_BYTE_ORDER == _LITTLE_ENDIAN) -#elif defined(__QNX__) -#include -#define TVM_FFI_LITTLE_ENDIAN (BYTE_ORDER == LITTLE_ENDIAN) -#elif defined(__EMSCRIPTEN__) || defined(__hexagon__) -#define TVM_FFI_LITTLE_ENDIAN 1 -#elif defined(__sun) || defined(sun) -#include -#if defined(_LITTLE_ENDIAN) -#define TVM_FFI_LITTLE_ENDIAN 1 -#else -#define TVM_FFI_LITTLE_ENDIAN 0 -#endif -#else -#error "Unable to determine endianness of your machine; use CMake to compile" -#endif -#endif - -/*! \brief whether serialize using little endian */ -#define TVM_FFI_IO_NO_ENDIAN_SWAP (TVM_FFI_LITTLE_ENDIAN == TVM_FFI_IO_USE_LITTLE_ENDIAN) - -namespace tvm { -namespace ffi { -/*! - * \brief A generic inplace byte swapping function. - * \param data The data pointer. - * \param elem_bytes The number of bytes of the data elements - * \param num_elems Number of elements in the data. - * \note Always try pass in constant elem_bytes to enable - * compiler optimization - */ -inline void ByteSwap(void* data, size_t elem_bytes, size_t num_elems) { - for (size_t i = 0; i < num_elems; ++i) { - uint8_t* bptr = reinterpret_cast(data) + elem_bytes * i; - for (size_t j = 0; j < elem_bytes / 2; ++j) { - uint8_t v = bptr[elem_bytes - 1 - j]; - bptr[elem_bytes - 1 - j] = bptr[j]; - bptr[j] = v; - } - } -} -} // namespace ffi -} // namespace tvm -#endif // TVM_FFI_ENDIAN_H_ diff --git a/ffi/include/tvm/ffi/error.h b/ffi/include/tvm/ffi/error.h deleted file mode 100644 index 261b69e71b5d..000000000000 --- a/ffi/include/tvm/ffi/error.h +++ /dev/null @@ -1,335 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/* - * \file tvm/ffi/error.h - * \brief Error handling component. - */ -#ifndef TVM_FFI_ERROR_H_ -#define TVM_FFI_ERROR_H_ - -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include - -/*! - * \brief Macro defines whether we enable libbacktrace - */ -#ifndef TVM_FFI_USE_LIBBACKTRACE -#define TVM_FFI_USE_LIBBACKTRACE 1 -#endif - -/*! - * \brief Macro defines whether to install signal handler - * and print backtrace during segfault - */ -#ifndef TVM_FFI_BACKTRACE_ON_SEGFAULT -#define TVM_FFI_BACKTRACE_ON_SEGFAULT 1 -#endif - -#ifndef TVM_FFI_ALWAYS_LOG_BEFORE_THROW -#define TVM_FFI_ALWAYS_LOG_BEFORE_THROW 0 -#endif - -namespace tvm { -namespace ffi { - -/*! - * \brief Error already set in frontend env. - * - * This error can be thrown by EnvCheckSignals to indicate - * that there is an error set in the frontend environment(e.g. - * python interpreter). The TVM FFI should catch this error - * and return a proper code to tell the frontend caller about - * this fact. - * - * \code - * - * void ExampleLongRunningFunction() { - * if (TVMFFIEnvCheckSignals() != 0) { - * throw ::tvm::ffi::EnvErrorAlreadySet(); - * } - * // do work here - * } - * - * \endcode - */ -struct EnvErrorAlreadySet : public std::exception {}; - -/*! - * \brief Error object class. - */ -class ErrorObj : public Object, public TVMFFIErrorCell { - public: - /// \cond Doxygen_Suppress - static constexpr const int32_t _type_index = TypeIndex::kTVMFFIError; - TVM_FFI_DECLARE_OBJECT_INFO_STATIC("ffi.Error", ErrorObj, Object); - /// \endcond -}; - -namespace details { -class ErrorObjFromStd : public ErrorObj { - public: - ErrorObjFromStd(std::string kind, std::string message, std::string traceback) - : kind_data_(kind), message_data_(message), traceback_data_(traceback) { - this->kind = TVMFFIByteArray{kind_data_.data(), kind_data_.length()}; - this->message = TVMFFIByteArray{message_data_.data(), message_data_.length()}; - this->traceback = TVMFFIByteArray{traceback_data_.data(), traceback_data_.length()}; - this->update_traceback = UpdateTraceback; - } - - private: - /*! - * \brief Update the traceback of the error object. - * \param traceback The traceback to update. - */ - static void UpdateTraceback(TVMFFIObjectHandle self, const TVMFFIByteArray* traceback_str) { - ErrorObjFromStd* obj = static_cast(self); - obj->traceback_data_ = std::string(traceback_str->data, traceback_str->size); - obj->traceback = TVMFFIByteArray{obj->traceback_data_.data(), obj->traceback_data_.length()}; - } - - std::string kind_data_; - std::string message_data_; - std::string traceback_data_; -}; -} // namespace details - -/*! - * \brief Managed reference to ErrorObj - * \sa Error Object - */ -class Error : public ObjectRef, public std::exception { - public: - /*! - * \brief Constructor - * \param kind The kind of the error. - * \param message The message of the error. - * \param traceback The traceback of the error. - */ - Error(std::string kind, std::string message, std::string traceback) { - data_ = make_object(kind, message, traceback); - } - - /*! - * \brief Constructor - * \param kind The kind of the error. - * \param message The message of the error. - * \param traceback The traceback of the error. - */ - Error(std::string kind, std::string message, const TVMFFIByteArray* traceback) - : Error(kind, message, std::string(traceback->data, traceback->size)) {} - - /*! - * \brief Get the kind of the error object. - * \return The kind of the error object. - */ - std::string kind() const { - ErrorObj* obj = static_cast(data_.get()); - return std::string(obj->kind.data, obj->kind.size); - } - - /*! - * \brief Get the message of the error object. - * \return The message of the error object. - */ - std::string message() const { - ErrorObj* obj = static_cast(data_.get()); - return std::string(obj->message.data, obj->message.size); - } - - /*! - * \brief Get the traceback of the error object. - * \return The traceback of the error object. - */ - std::string traceback() const { - ErrorObj* obj = static_cast(data_.get()); - return std::string(obj->traceback.data, obj->traceback.size); - } - - /*! - * \brief Update the traceback of the error object. - * \param traceback_str The traceback to update. - */ - void UpdateTraceback(const TVMFFIByteArray* traceback_str) { - ErrorObj* obj = static_cast(data_.get()); - obj->update_traceback(obj, traceback_str); - } - - /*! - * \brief Get the error message - * \return The error message - */ - const char* what() const noexcept(true) override { - thread_local std::string what_data; - ErrorObj* obj = static_cast(data_.get()); - what_data = (std::string("Traceback (most recent call last):\n") + - std::string(obj->traceback.data, obj->traceback.size) + - std::string(obj->kind.data, obj->kind.size) + std::string(": ") + - std::string(obj->message.data, obj->message.size) + '\n'); - return what_data.c_str(); - } - - /// \cond Doxygen_Suppress - TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(Error, ObjectRef, ErrorObj); - /// \endcond -}; - -namespace details { - -class ErrorBuilder { - public: - explicit ErrorBuilder(std::string kind, std::string traceback, bool log_before_throw) - : kind_(kind), traceback_(traceback), log_before_throw_(log_before_throw) {} - - explicit ErrorBuilder(std::string kind, const TVMFFIByteArray* traceback, bool log_before_throw) - : ErrorBuilder(kind, std::string(traceback->data, traceback->size), log_before_throw) {} - -// MSVC disable warning in error builder as it is exepected -#ifdef _MSC_VER -#pragma warning(push) -#pragma warning(disable : 4722) -#endif - // avoid inline to reduce binary size, error throw path do not need to be fast - [[noreturn]] ~ErrorBuilder() noexcept(false) { - ::tvm::ffi::Error error(std::move(kind_), stream_.str(), std::move(traceback_)); - if (log_before_throw_) { - std::cerr << error.what(); - } - throw error; - } -#ifdef _MSC_VER -#pragma warning(pop) -#endif - - std::ostringstream& stream() { return stream_; } - - protected: - std::string kind_; - std::ostringstream stream_; - std::string traceback_; - bool log_before_throw_; -}; - -} // namespace details - -/*! - * \brief Helper macro to throw an error with traceback and message - * - * \code - * - * void ThrowError() { - * TVM_FFI_THROW(RuntimeError) << "error message"; - * } - * - * \endcode - */ -#define TVM_FFI_THROW(ErrorKind) \ - ::tvm::ffi::details::ErrorBuilder(#ErrorKind, \ - TVMFFITraceback(__FILE__, __LINE__, TVM_FFI_FUNC_SIG, 0), \ - TVM_FFI_ALWAYS_LOG_BEFORE_THROW) \ - .stream() - -/*! - * \brief Explicitly log error in stderr and then throw the error. - * - * \note This is only necessary on startup functions where we know error - * cannot be caught, and it is better to have a clear log message. - * In most cases, we should use use TVM_FFI_THROW. - */ -#define TVM_FFI_LOG_AND_THROW(ErrorKind) \ - ::tvm::ffi::details::ErrorBuilder( \ - #ErrorKind, TVMFFITraceback(__FILE__, __LINE__, TVM_FFI_FUNC_SIG, 0), true) \ - .stream() - -// Glog style checks with TVM_FFI prefix -// NOTE: we explicitly avoid glog style generic macros (LOG/CHECK) in tvm ffi -// to avoid potential conflict of downstream users who might have their own GLOG style macros -namespace details { - -template -TVM_FFI_INLINE std::unique_ptr LogCheckFormat(const X& x, const Y& y) { - std::ostringstream os; - os << " (" << x << " vs. " << y << ") "; // CHECK_XX(x, y) requires x and y can be serialized to - // string. Use CHECK(x OP y) otherwise. - return std::make_unique(os.str()); -} - -#define TVM_FFI_CHECK_FUNC(name, op) \ - template \ - TVM_FFI_INLINE std::unique_ptr LogCheck##name(const X& x, const Y& y) { \ - if (x op y) return nullptr; \ - return LogCheckFormat(x, y); \ - } \ - TVM_FFI_INLINE std::unique_ptr LogCheck##name(int x, int y) { \ - return LogCheck##name(x, y); \ - } - -// Inline _Pragma in macros does not work reliably on old version of MSVC and -// GCC. We wrap all comparisons in a function so that we can use #pragma to -// silence bad comparison warnings. -#if defined(__GNUC__) || defined(__clang__) // GCC and Clang -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wsign-compare" -#elif defined(_MSC_VER) // MSVC -#pragma warning(push) -#pragma warning(disable : 4389) // '==' : signed/unsigned mismatch -#endif - -TVM_FFI_CHECK_FUNC(_LT, <) -TVM_FFI_CHECK_FUNC(_GT, >) -TVM_FFI_CHECK_FUNC(_LE, <=) -TVM_FFI_CHECK_FUNC(_GE, >=) -TVM_FFI_CHECK_FUNC(_EQ, ==) -TVM_FFI_CHECK_FUNC(_NE, !=) - -#if defined(__GNUC__) || defined(__clang__) // GCC and Clang -#pragma GCC diagnostic pop -#elif defined(_MSC_VER) // MSVC -#pragma warning(pop) -#endif -} // namespace details - -#define TVM_FFI_ICHECK_BINARY_OP(name, op, x, y) \ - if (auto __tvm__log__err = ::tvm::ffi::details::LogCheck##name(x, y)) \ - TVM_FFI_THROW(InternalError) << "Check failed: " << #x " " #op " " #y << *__tvm__log__err << ": " - -#define TVM_FFI_ICHECK(x) \ - if (!(x)) TVM_FFI_THROW(InternalError) << "Check failed: (" #x << ") is false: " - -#define TVM_FFI_ICHECK_LT(x, y) TVM_FFI_ICHECK_BINARY_OP(_LT, <, x, y) -#define TVM_FFI_ICHECK_GT(x, y) TVM_FFI_ICHECK_BINARY_OP(_GT, >, x, y) -#define TVM_FFI_ICHECK_LE(x, y) TVM_FFI_ICHECK_BINARY_OP(_LE, <=, x, y) -#define TVM_FFI_ICHECK_GE(x, y) TVM_FFI_ICHECK_BINARY_OP(_GE, >=, x, y) -#define TVM_FFI_ICHECK_EQ(x, y) TVM_FFI_ICHECK_BINARY_OP(_EQ, ==, x, y) -#define TVM_FFI_ICHECK_NE(x, y) TVM_FFI_ICHECK_BINARY_OP(_NE, !=, x, y) -#define TVM_FFI_ICHECK_NOTNULL(x) \ - ((x) == nullptr ? TVM_FFI_THROW(InternalError) << "Check not null: " #x << ' ', \ - (x) : (x)) // NOLINT(*) -} // namespace ffi -} // namespace tvm -#endif // TVM_FFI_ERROR_H_ diff --git a/ffi/include/tvm/ffi/extra/base.h b/ffi/include/tvm/ffi/extra/base.h deleted file mode 100644 index b09b3540a83e..000000000000 --- a/ffi/include/tvm/ffi/extra/base.h +++ /dev/null @@ -1,48 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -/*! - * \file tvm/ffi/extra/base.h - * \brief Base header for Extra API. - * - * The extra APIs contains a minmal set of extra APIs that are not - * required to support essential core functionality. - */ -#ifndef TVM_FFI_EXTRA_BASE_H_ -#define TVM_FFI_EXTRA_BASE_H_ - -#include - -/*! - * \brief Marks the API as extra c++ api that is defined in cc files. - * - * They are implemented in cc files to reduce compile-time overhead. - * The input/output only uses POD/Any/ObjectRef for ABI stability. - * However, these extra APIs may have an issue across MSVC/Itanium ABI, - * - * Related features are also available through reflection based function - * that is fully based on C API - * - * The project aims to minimize the number of extra C++ APIs to keep things - * lightweight and restrict the use to non-core functionalities. - */ -#ifndef TVM_FFI_EXTRA_CXX_API -#define TVM_FFI_EXTRA_CXX_API TVM_FFI_DLL -#endif - -#endif // TVM_FFI_EXTRA_BASE_H_ diff --git a/ffi/include/tvm/ffi/extra/base64.h b/ffi/include/tvm/ffi/extra/base64.h deleted file mode 100644 index da763cfe3a03..000000000000 --- a/ffi/include/tvm/ffi/extra/base64.h +++ /dev/null @@ -1,142 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -/*! - * - * \file tvm/ffi/extra/base64.h - * \brief Base64 encoding and decoding utilities - */ -#ifndef TVM_FFI_EXTRA_BASE64_H_ -#define TVM_FFI_EXTRA_BASE64_H_ - -#include - -#include - -namespace tvm { -namespace ffi { -/*! - * \brief Encode a byte array into a base64 string - * \param bytes The byte array to encode - * \return The base64 encoded string - */ -inline String Base64Encode(TVMFFIByteArray bytes) { - // encoding every 3 bytes into 4 characters - constexpr const char kEncodeTable[] = - "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; - std::string encoded; - encoded.reserve(4 * (bytes.size + 2) / 3); - - for (size_t i = 0; i < (bytes.size / 3) * 3; i += 3) { - int32_t buf[3]; - buf[0] = static_cast(bytes.data[i]); - buf[1] = static_cast(bytes.data[i + 1]); - buf[2] = static_cast(bytes.data[i + 2]); - encoded.push_back(kEncodeTable[buf[0] >> 2]); - encoded.push_back(kEncodeTable[((buf[0] << 4) | (buf[1] >> 4)) & 0x3F]); - encoded.push_back(kEncodeTable[((buf[1] << 2) | (buf[2] >> 6)) & 0x3F]); - encoded.push_back(kEncodeTable[buf[2] & 0x3F]); - } - if (bytes.size % 3 == 1) { - int32_t buf[1] = {static_cast(bytes.data[bytes.size - 1])}; - encoded.push_back(kEncodeTable[buf[0] >> 2]); - encoded.push_back(kEncodeTable[(buf[0] << 4) & 0x3F]); - encoded.push_back('='); - encoded.push_back('='); - } else if (bytes.size % 3 == 2) { - int32_t buf[2] = {static_cast(bytes.data[bytes.size - 2]), - static_cast(bytes.data[bytes.size - 1])}; - encoded.push_back(kEncodeTable[buf[0] >> 2]); - encoded.push_back(kEncodeTable[((buf[0] << 4) | (buf[1] >> 4)) & 0x3F]); - encoded.push_back(kEncodeTable[(buf[1] << 2) & 0x3F]); - encoded.push_back('='); - } - return String(encoded); -} - -/*! - * \brief Encode a bytes object into a base64 string - * \param data The bytes object to encode - * \return The base64 encoded string - */ -inline String Base64Encode(const Bytes& data) { - return Base64Encode(TVMFFIByteArray{data.data(), data.size()}); -} - -/*! - * \brief Decode a base64 string into a byte array - * \param bytes The bytes to be decoded - * \return The decoded byte array - */ -inline Bytes Base64Decode(TVMFFIByteArray bytes) { - constexpr const char kDecodeTable[] = { - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 62, // '+' - 0, 0, 0, - 63, // '/' - 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, // '0'-'9' - 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, - 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, // 'A'-'Z' - 0, 0, 0, 0, 0, 0, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, - 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, // 'a'-'z' - }; - std::string decoded; - decoded.reserve(bytes.size * 3 / 4); - if (bytes.size == 0) return Bytes(); - TVM_FFI_ICHECK(bytes.size % 4 == 0) << "invalid base64 encoding"; - // leverage this property to simplify decoding - static_assert('=' < sizeof(kDecodeTable) && kDecodeTable[static_cast('=')] == 0); - // base64 is always multiple of 4 bytes - for (size_t i = 0; i < bytes.size; i += 4) { - // decode every 4 characters into 24bits, each character contains 6 bits - // note that = is also decoded as 0, which is safe to skip - int32_t buf[4] = { - static_cast(bytes.data[i]), - static_cast(bytes.data[i + 1]), - static_cast(bytes.data[i + 2]), - static_cast(bytes.data[i + 3]), - }; - int32_t value_i24 = (static_cast(kDecodeTable[buf[0]]) << 18) | - (static_cast(kDecodeTable[buf[1]]) << 12) | - (static_cast(kDecodeTable[buf[2]]) << 6) | - static_cast(kDecodeTable[buf[3]]); - // unpack 24bits into 3 bytes, each contains 8 bits - decoded.push_back(static_cast((value_i24 >> 16) & 0xFF)); - if (buf[2] != '=') { - decoded.push_back(static_cast((value_i24 >> 8) & 0xFF)); - } - if (buf[3] != '=') { - decoded.push_back(static_cast(value_i24 & 0xFF)); - } - } - return Bytes(decoded); -} - -/*! - * \brief Decode a base64 string into a byte array - * \param data The base64 encoded string to decode - * \return The decoded byte array - */ -inline Bytes Base64Decode(const String& data) { - return Base64Decode(TVMFFIByteArray{data.data(), data.size()}); -} - -} // namespace ffi -} // namespace tvm -#endif // TVM_FFI_EXTRA_BASE64_H_ diff --git a/ffi/include/tvm/ffi/extra/c_env_api.h b/ffi/include/tvm/ffi/extra/c_env_api.h deleted file mode 100644 index 3c49d79d3071..000000000000 --- a/ffi/include/tvm/ffi/extra/c_env_api.h +++ /dev/null @@ -1,142 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -/*! - * \file tvm/ffi/extra/c_env_api.h - * \brief Extra environment API. - */ -#ifndef TVM_FFI_EXTRA_C_ENV_API_H_ -#define TVM_FFI_EXTRA_C_ENV_API_H_ - -#include - -#ifdef __cplusplus -extern "C" { -#endif - -// ---------------------------------------------------------------------------- -// Stream context -// Focusing on minimalistic thread-local context recording stream being used. -// We explicitly not handle allocation/de-allocation of stream here. -// ---------------------------------------------------------------------------- -/*! - * \brief The type of the stream handle. - */ -typedef void* TVMFFIStreamHandle; - -/*! - * \brief FFI function to set the current stream for a device - * - * \param device_type The type of the device. - * \param device_id The id of the device. - * \param stream The stream to set. - * \param opt_out_original_stream Output original stream if the address is not nullptr. - * \return 0 when success, nonzero when failure happens - */ -TVM_FFI_DLL int TVMFFIEnvSetStream(int32_t device_type, int32_t device_id, - TVMFFIStreamHandle stream, - TVMFFIStreamHandle* opt_out_original_stream); - -/*! - * \brief FFI function to get the current stream for a device - * - * \param device_type The type of the device. - * \param device_id The id of the device. - * \return The current stream of the device. - */ -TVM_FFI_DLL TVMFFIStreamHandle TVMFFIEnvGetStream(int32_t device_type, int32_t device_id); - -/*! - * \brief FFI function to set the current DLPack allocator in thread-local(TLS) context - * - * \param allocator The allocator to set. - * \param write_to_global_context Whether to also set the allocator to the global context. - * \param opt_out_original_allocator Output original TLS allocator if the address is not nullptr. - * \return 0 when success, nonzero when failure happens - */ -TVM_FFI_DLL int TVMFFIEnvSetTensorAllocator(DLPackTensorAllocator allocator, - int write_to_global_context, - DLPackTensorAllocator* opt_out_original_allocator); - -/*! - * \brief FFI function get the current DLPack allocator stored in context. - * - * This function first queries the global context, and if not found, - * queries the thread-local context. - * - * \return The current DLPack allocator. - */ -TVM_FFI_DLL DLPackTensorAllocator TVMFFIEnvGetTensorAllocator(); - -/*! - * \brief Check if there are any signals raised in the surrounding env. - * \return 0 when success, nonzero when failure happens - * \note Under python this function redirects to PyErr_CheckSignals - */ -TVM_FFI_DLL int TVMFFIEnvCheckSignals(); - -/*! - * \brief Register a symbol into the from the surrounding env such as python - * \param name The name of the symbol. - * \param symbol The symbol to register. - * \return 0 when success, nonzero when failure happens - */ -TVM_FFI_DLL int TVMFFIEnvRegisterCAPI(const char* name, void* symbol); - -// ---------------------------------------------------------------------------- -// Module symbol management in callee side -// ---------------------------------------------------------------------------- -/*! - * \brief FFI function to lookup a function from a module's imports. - * - * This is a helper function that is used by generated code. - * - * \param library_ctx The library context module handle. - * \param func_name The name of the function. - * \param out The result function. - * \note The returned function is a weak reference that is cached/owned by the module. - * \return 0 when no error is thrown, -1 when failure happens - */ -TVM_FFI_DLL int TVMFFIEnvModLookupFromImports(TVMFFIObjectHandle library_ctx, const char* func_name, - TVMFFIObjectHandle* out); - -/*! - * \brief Register a symbol value that will be initialized when a library with the symbol is loaded. - * - * This function can be used to make context functions to be available in the library - * module that wants to avoid an explicit link dependency - * - * \param name The name of the symbol. - * \param symbol The symbol to register. - * \return 0 when success, nonzero when failure happens - */ -TVM_FFI_DLL int TVMFFIEnvModRegisterContextSymbol(const char* name, void* symbol); - -/*! - * \brief Register a symbol that will be initialized when a system library is loaded. - * - * \param name The name of the symbol. - * \param symbol The symbol to register. - * \return 0 when success, nonzero when failure happens - */ -TVM_FFI_DLL int TVMFFIEnvModRegisterSystemLibSymbol(const char* name, void* symbol); - -#ifdef __cplusplus -} // extern "C" -#endif -#endif // TVM_FFI_EXTRA_C_ENV_API_H_ diff --git a/ffi/include/tvm/ffi/extra/json.h b/ffi/include/tvm/ffi/extra/json.h deleted file mode 100644 index 24ab2f0d8970..000000000000 --- a/ffi/include/tvm/ffi/extra/json.h +++ /dev/null @@ -1,84 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -/*! - * \file tvm/ffi/extra/json.h - * \brief Minimal lightweight JSON parsing and serialization utilities - */ -#ifndef TVM_FFI_EXTRA_JSON_H_ -#define TVM_FFI_EXTRA_JSON_H_ - -#include -#include -#include -#include - -namespace tvm { -namespace ffi { -namespace json { - -/*! - * \brief alias Any as json Value. - * - * To keep things lightweight, we simply reuse the ffi::Any system. - */ -using Value = Any; - -/*! - * \brief alias Map as json Object. - * \note We use Map instead of Map to avoid - * the overhead of key checking when doing as conversion, - * the check will be performed at runtime when we read each key - */ -using Object = ffi::Map; - -/*! \brief alias Array as json Array. */ -using Array = ffi::Array; - -/*! - * \brief Parse a JSON string into an Any value. - * - * Besides the standard JSON syntax, this function also supports: - * - Infinity/NaN as JavaScript syntax - * - int64 integer value - * - * If error_msg is not nullptr, the error message will be written to it - * and no exception will be thrown when parsing fails. - * - * \param json_str The JSON string to parse. - * \param error_msg The output error message, can be nullptr. - * - * \return The parsed Any value. - */ -TVM_FFI_EXTRA_CXX_API json::Value Parse(const String& json_str, String* error_msg = nullptr); - -/*! - * \brief Serialize an Any value into a JSON string. - * - * \param value The Any value to serialize. - * \param indent The number of spaces to indent the output. - * If not specified, the output will be compact. - * \return The output JSON string. - */ -TVM_FFI_EXTRA_CXX_API String Stringify(const json::Value& value, - Optional indent = std::nullopt); - -} // namespace json -} // namespace ffi -} // namespace tvm -#endif // TVM_FFI_EXTRA_JSON_H_ diff --git a/ffi/include/tvm/ffi/extra/module.h b/ffi/include/tvm/ffi/extra/module.h deleted file mode 100644 index fd6bf199f010..000000000000 --- a/ffi/include/tvm/ffi/extra/module.h +++ /dev/null @@ -1,262 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -/*! - * \file tvm/ffi/extra/module.h - * \brief A managed dynamic module in the TVM FFI. - */ -#ifndef TVM_FFI_EXTRA_MODULE_H_ -#define TVM_FFI_EXTRA_MODULE_H_ - -#include -#include -#include -#include - -namespace tvm { -namespace ffi { - -// forward declare Module -class Module; - -/*! - * \brief A module that can dynamically load ffi::Functions or exportable source code. - * \sa Module - */ -class TVM_FFI_EXTRA_CXX_API ModuleObj : public Object { - public: - /*! - * \return The per module type key. - * \note This key is used to for serializing custom modules. - */ - virtual const char* kind() const = 0; - /*! - * \brief Get the property mask of the module. - * \return The property mask of the module. - * - * \sa Module::ModulePropertyMask - */ - virtual int GetPropertyMask() const { return 0b000; } - /*! - * \brief Get a ffi::Function from the module. - * \param name The name of the function. - * \return The function. - */ - virtual Optional GetFunction(const String& name) = 0; - /*! - * \brief Returns true if this module has a definition for a function of \p name. - * - * Note that even if this function returns true the corresponding \p GetFunction result - * may be nullptr if the function is not yet callable without further compilation. - * - * The default implementation just checks if \p GetFunction is non-null. - * \param name The name of the function. - * \return True if the module implements the function, false otherwise. - */ - virtual bool ImplementsFunction(const String& name) { return GetFunction(name).defined(); } - /*! - * \brief Get the metadata of the function, if available. - * \param name The name of the function. - * \return The metadata stored in json string format. - */ - virtual Optional GetFunctionMetadata(const String& name) { return std::nullopt; } - /*! - * \brief Write the current module to file with given format (for further compilation). - * - * \param file_name The file to be saved to. - * \param format The format of the file. - * - * \note This function is mainly used by modules that - */ - virtual void WriteToFile(const String& file_name, const String& format) const { - TVM_FFI_THROW(RuntimeError) << "Module[" << kind() << "] does not support WriteToFile"; - } - /*! - * \brief Get the possible write formats of the module, when available. - * \return Possible write formats when available. - */ - virtual Array GetWriteFormats() const { return Array(); } - /*! - * \brief Serialize the the module to bytes. - * \return The serialized module. - */ - virtual Bytes SaveToBytes() const { - TVM_FFI_THROW(RuntimeError) << "Module[" << kind() << "] does not support SaveToBytes"; - TVM_FFI_UNREACHABLE(); - } - /*! - * \brief Get the source code of module, when available. - * \param format Format of the source code, can be empty by default. - * \return Possible source code when available, or empty string if not available. - */ - virtual String InspectSource(const String& format = "") const { return String(); } - /*! - * \brief Import another module. - * \param other The module to import. - */ - virtual void ImportModule(const Module& other); - /*! - * \brief Clear all imported modules. - */ - virtual void ClearImports(); - /*! - * \brief Overloaded function to optionally query from imports. - * \param name The name of the function. - * \param query_imports Whether to query imported modules. - * \return The function. - */ - Optional GetFunction(const String& name, bool query_imports); - /*! - * \brief Overloaded function to optionally query from imports. - * \param name The name of the function. - * \param query_imports Whether to query imported modules. - * \return True if the module implements the function, false otherwise. - */ - bool ImplementsFunction(const String& name, bool query_imports); - /*! - * \brief Get the function metadata of the function if available. - * \param name The name of the function. - * \param query_imports Whether to query imported modules. - * \return The function metadata of the function in json format. - */ - Optional GetFunctionMetadata(const String& name, bool query_imports); - /*! - * \brief Get the imports of the module. - * \return The imports of the module. - * \note Note the signature is not part of the public API. - */ - const Array& imports() const { return this->imports_; } - - struct InternalUnsafe; - - /// \cond Doxygen_Suppress - static constexpr const int32_t _type_index = TypeIndex::kTVMFFIModule; - static constexpr const bool _type_mutable = true; - static const constexpr bool _type_final = true; - TVM_FFI_DECLARE_OBJECT_INFO_STATIC(StaticTypeKey::kTVMFFIModule, ModuleObj, Object); - /// \endcond - - protected: - friend struct InternalUnsafe; - - /*! - * \brief The modules that this module depends on. - * \note Use ObjectRef to avoid circular dep on Module. - */ - Array imports_; - - private: - /*! - * \brief cache used by TVMFFIModuleLookupFromImports - */ - Map import_lookup_cache_; -}; - -/*! - * \brief Reference to module object. - * - * When invoking a function on a ModuleObj, such as GetFunction, - * use operator-> to get the ModuleObj pointer and invoke the member functions. - * - * \code - * ffi::Module mod = ffi::Module::LoadFromFile("path/to/module.so"); - * ffi::Function func = mod->GetFunction(name); - * \endcode - * - * \sa ModuleObj which contains most of the function implementations. - */ -class Module : public ObjectRef { - public: - /*! - * \brief Property of ffi::Module - */ - enum ModulePropertyMask : int { - /*! - * \brief The module can be serialized to bytes. - * - * This prooperty indicates that module implements SaveToBytes. - * The system also registers a GlobalDef function - * `ffi.Module.load_from_bytes.` with signature (Bytes) -> Module. - */ - kBinarySerializable = 0b001, - /*! - * \brief The module can directly get runnable functions. - * - * This property indicates that module implements GetFunction that returns - * runnable ffi::Functions. - */ - kRunnable = 0b010, - /*! - * \brief The module can be exported to a object file or source file that then be compiled. - * - * This property indicates that module implements WriteToFile with a given format - * that can be queried by GetLibExportFormat. - * - * Examples include modules that can be exported to .o, .cc, .cu files. - * - * Such modules can be exported, compiled and loaded back as a dynamic library module. - */ - kCompilationExportable = 0b100 - }; - /*! - * \brief Constructor from ObjectPtr. - * \param ptr The object pointer. - */ - explicit Module(ObjectPtr ptr) : ObjectRef(ptr) { TVM_FFI_ICHECK(ptr != nullptr); } - /*! - * \brief Load a module from file. - * \param file_name The name of the host function module. - * \note This function won't load the import relationship. - * Re-create import relationship by calling Import. - */ - TVM_FFI_EXTRA_CXX_API static Module LoadFromFile(const String& file_name); - /*! - * \brief Query context symbols that is registered via TVMEnvRegisterSymbols. - * \param callback The callback to be called with the symbol name and address. - * \note This helper can be used to implement custom Module that needs to access context symbols. - */ - TVM_FFI_EXTRA_CXX_API static void VisitContextSymbols( - const ffi::TypedFunction& callback); - - /// \cond Doxygen_Suppress - TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(Module, ObjectRef, ModuleObj); - /// \endcond -}; - -/* - * \brief Symbols for library module. - */ -namespace symbol { -/*!\ brief symbol prefix for tvm ffi related function symbols */ -constexpr const char* tvm_ffi_symbol_prefix = "__tvm_ffi_"; -// Special symbols have one extra _ prefix to avoid conflict with user symbols -/*! - * \brief Default entry function of a library module is tvm_ffi_symbol_prefix + "main" - */ -constexpr const char* tvm_ffi_main = "__tvm_ffi_main"; -/*! \brief Global variable to store context pointer for a library module. */ -constexpr const char* tvm_ffi_library_ctx = "__tvm_ffi__library_ctx"; -/*! \brief Global variable to store binary data alongside a library module. */ -constexpr const char* tvm_ffi_library_bin = "__tvm_ffi__library_bin"; -/*! \brief Optional metadata prefix of a symbol. */ -constexpr const char* tvm_ffi_metadata_prefix = "__tvm_ffi__metadata_"; -} // namespace symbol -} // namespace ffi -} // namespace tvm - -#endif // TVM_FFI_EXTRA_MODULE_H_ diff --git a/ffi/include/tvm/ffi/extra/serialization.h b/ffi/include/tvm/ffi/extra/serialization.h deleted file mode 100644 index b5aa2891ac40..000000000000 --- a/ffi/include/tvm/ffi/extra/serialization.h +++ /dev/null @@ -1,72 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -/*! - * \file tvm/ffi/extra/serialization.h - * \brief Reflection-based serialization utilities - */ -#ifndef TVM_FFI_EXTRA_SERIALIZATION_H_ -#define TVM_FFI_EXTRA_SERIALIZATION_H_ - -#include -#include - -namespace tvm { -namespace ffi { - -/** - * \brief Serialize ffi::Any to a JSON that stores the object graph. - * - * The JSON graph structure is stored as follows: - * - * ``` - * { - * "root_index": , // Index of root node in nodes array - * "nodes": [, ...], // Array of serialized nodes - * "metadata": // Optional metadata - * } - * ``` - * - * Each node has the format: `{"type": "", "data": }` - * For object types and strings, the data may contain indices to other nodes. - * For object fields whose static type is known as a primitive type, it is stored directly, - * otherwise, it is stored as a reference to the nodes array by an index. - * - * This function preserves the type and multiple references to the same object, - * which is useful for debugging and serialization. - * - * \param value The ffi::Any value to serialize. - * \param metadata Extra metadata attached to "metadata" field of the JSON object. - * \return The serialized JSON value. - */ -TVM_FFI_EXTRA_CXX_API json::Value ToJSONGraph(const Any& value, const Any& metadata = Any(nullptr)); - -/** - * \brief Deserialize a JSON that stores the object graph to an ffi::Any value. - * - * This function can be used to implement deserialization - * and debugging. - * - * \param value The JSON value to deserialize. - * \return The deserialized object graph. - */ -TVM_FFI_EXTRA_CXX_API Any FromJSONGraph(const json::Value& value); - -} // namespace ffi -} // namespace tvm -#endif // TVM_FFI_EXTRA_SERIALIZATION_H_ diff --git a/ffi/include/tvm/ffi/extra/structural_equal.h b/ffi/include/tvm/ffi/extra/structural_equal.h deleted file mode 100644 index ec960a85e611..000000000000 --- a/ffi/include/tvm/ffi/extra/structural_equal.h +++ /dev/null @@ -1,78 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -/*! - * \file tvm/ffi/extra/structural_equal.h - * \brief Structural equal implementation - */ -#ifndef TVM_FFI_EXTRA_STRUCTURAL_EQUAL_H_ -#define TVM_FFI_EXTRA_STRUCTURAL_EQUAL_H_ - -#include -#include -#include -#include - -namespace tvm { -namespace ffi { -/*! - * \brief Structural equality comparators - */ -class StructuralEqual { - public: - /** - * \brief Compare two Any values for structural equality. - * \param lhs The left hand side Any object. - * \param rhs The right hand side Any object. - * \param map_free_vars Whether to map free variables. - * \param skip_tensor_content Whether to skip comparingn darray data content, - * useful for cases where we don't care about parameters content - * \return True if the two Any values are structurally equal, false otherwise. - */ - TVM_FFI_EXTRA_CXX_API static bool Equal(const Any& lhs, const Any& rhs, - bool map_free_vars = false, - bool skip_tensor_content = false); - /** - * \brief Get the first mismatch AccessPath pair when running - * structural equal comparison between two Any values. - * - * \param lhs The left hand side Any object. - * \param rhs The right hand side Any object. - * \param map_free_vars Whether to map free variables. - * \param skip_tensor_content Whether to skip comparing tensor data content, - * useful for cases where we don't care about parameters content - * \return If comparison fails, return the first mismatch AccessPath pair, - * otherwise return std::nullopt. - */ - TVM_FFI_EXTRA_CXX_API static Optional GetFirstMismatch( - const Any& lhs, const Any& rhs, bool map_free_vars = false, bool skip_tensor_content = false); - - /* - * \brief Compare two Any values for structural equality. - * \param lhs The left hand side Any object. - * \param rhs The right hand side Any object. - * \return True if the two Any values are structurally equal, false otherwise. - */ - TVM_FFI_INLINE bool operator()(const Any& lhs, const Any& rhs) const { - return Equal(lhs, rhs, false, true); - } -}; - -} // namespace ffi -} // namespace tvm -#endif // TVM_FFI_EXTRA_STRUCTURAL_EQUAL_H_ diff --git a/ffi/include/tvm/ffi/extra/structural_hash.h b/ffi/include/tvm/ffi/extra/structural_hash.h deleted file mode 100644 index bfe023c382a7..000000000000 --- a/ffi/include/tvm/ffi/extra/structural_hash.h +++ /dev/null @@ -1,57 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -/*! - * \file tvm/ffi/extra/structural_hash.h - * \brief Structural hash - */ -#ifndef TVM_FFI_EXTRA_STRUCTURAL_HASH_H_ -#define TVM_FFI_EXTRA_STRUCTURAL_HASH_H_ - -#include -#include - -namespace tvm { -namespace ffi { - -/*! - * \brief Structural hash - */ -class StructuralHash { - public: - /*! - * \brief Hash an Any value. - * \param value The Any value to hash. - * \param map_free_vars Whether to map free variables. - * \param skip_tensor_content Whether to skip comparingn darray data content, - * useful for cases where we don't care about parameters content. - * \return The hash value. - */ - TVM_FFI_EXTRA_CXX_API static uint64_t Hash(const Any& value, bool map_free_vars = false, - bool skip_tensor_content = false); - /*! - * \brief Hash an Any value. - * \param value The Any value to hash. - * \return The hash value. - */ - TVM_FFI_INLINE uint64_t operator()(const Any& value) const { return Hash(value); } -}; - -} // namespace ffi -} // namespace tvm -#endif // TVM_FFI_EXTRA_STRUCTURAL_HASH_H_ diff --git a/ffi/include/tvm/ffi/function.h b/ffi/include/tvm/ffi/function.h deleted file mode 100644 index 0706fdc0eccc..000000000000 --- a/ffi/include/tvm/ffi/function.h +++ /dev/null @@ -1,880 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -/*! - * \file tvm/ffi/function.h - * \brief A managed function in the TVM FFI. - */ -#ifndef TVM_FFI_FUNCTION_H_ -#define TVM_FFI_FUNCTION_H_ - -#include -#include -#include -#include -#include - -#include -#include -#include -#include - -namespace tvm { -namespace ffi { - -/** - * Helper macro to construct a safe call - * - * \brief Marks the beginning of the safe call that catches exception explicitly - * \sa TVM_FFI_SAFE_CALL_END - * - * \code - * int TVMFFICStyleFunction() { - * TVM_FFI_SAFE_CALL_BEGIN(); - * // c++ code region here - * TVM_FFI_SAFE_CALL_END(); - * } - * \endcode - */ -#define TVM_FFI_SAFE_CALL_BEGIN() \ - try { \ - (void)0 - -/*! - * \brief Marks the end of safe call. - */ -#define TVM_FFI_SAFE_CALL_END() \ - return 0; \ - } \ - catch (const ::tvm::ffi::Error& err) { \ - ::tvm::ffi::details::SetSafeCallRaised(err); \ - return -1; \ - } \ - catch (const ::tvm::ffi::EnvErrorAlreadySet&) { \ - return -2; \ - } \ - catch (const std::exception& ex) { \ - ::tvm::ffi::details::SetSafeCallRaised(::tvm::ffi::Error("InternalError", ex.what(), "")); \ - return -1; \ - } \ - TVM_FFI_UNREACHABLE() - -/*! - * \brief Macro to check a call to TVMFFISafeCallType and raise exception if error happens. - * \param func The function to check. - * - * \code - * // calls TVMFFIFunctionCall and raises exception if error happens - * TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeKeyToIndex(&type_key_arr, &type_index)); - * \endcode - */ -#define TVM_FFI_CHECK_SAFE_CALL(func) \ - { \ - int ret_code = (func); \ - if (ret_code != 0) { \ - if (ret_code == -2) { \ - throw ::tvm::ffi::EnvErrorAlreadySet(); \ - } \ - throw ::tvm::ffi::details::MoveFromSafeCallRaised(); \ - } \ - } - -/*! - * \brief Object container class that backs ffi::Function - * \note Do not use this class directly, use ffi::Function - */ -class FunctionObj : public Object, public TVMFFIFunctionCell { - public: - /*! \brief Typedef for C++ style calling signature that comes with exception propagation */ - typedef void (*FCall)(const FunctionObj*, const AnyView*, int32_t, Any*); - using TVMFFIFunctionCell::safe_call; - /*! \brief A C++ style call implementation, with exception propagation in C++ style. */ - FCall call; - /*! - * \brief Call the function in packed format. - * \param args The arguments - * \param num_args The number of arguments - * \param result The return value. - */ - TVM_FFI_INLINE void CallPacked(const AnyView* args, int32_t num_args, Any* result) const { - this->call(this, args, num_args, result); - } - /// \cond Doxygen_Suppress - static constexpr const uint32_t _type_index = TypeIndex::kTVMFFIFunction; - TVM_FFI_DECLARE_OBJECT_INFO_STATIC(StaticTypeKey::kTVMFFIFunction, FunctionObj, Object); - /// \endcond - - protected: - /*! \brief Make default constructor protected. */ - FunctionObj() {} - /// \cond Doxygen_Suppress - // Implementing safe call style - static int SafeCall(void* func, const TVMFFIAny* args, int32_t num_args, TVMFFIAny* result) { - TVM_FFI_SAFE_CALL_BEGIN(); - TVM_FFI_ICHECK_LT(result->type_index, TypeIndex::kTVMFFIStaticObjectBegin); - FunctionObj* self = static_cast(func); - self->call(self, reinterpret_cast(args), num_args, - reinterpret_cast(result)); - TVM_FFI_SAFE_CALL_END(); - } - /// \endcond - friend class Function; -}; - -namespace details { -/*! - * \brief Derived object class for constructing FunctionObj backed by a TCallable - * - * This is a helper class that implements the function call interface. - */ -template -class FunctionObjImpl : public FunctionObj { - public: - using TStorage = typename std::remove_cv::type>::type; - /*! \brief The type of derived object class */ - using TSelf = FunctionObjImpl; - /*! - * \brief Derived object class for constructing ffi::FunctionObj. - * \param callable The type-erased callable object. - */ - explicit FunctionObjImpl(TCallable callable) : callable_(callable) { - this->safe_call = SafeCall; - this->call = Call; - } - - private: - // implementation of call - static void Call(const FunctionObj* func, const AnyView* args, int32_t num_args, Any* result) { - (static_cast(func))->callable_(args, num_args, result); - } - - /*! \brief Type-erased filed for storing callable object*/ - mutable TStorage callable_; -}; - -/*! - * \brief Base class to provide a common implementation to redirect call to safecall - * \tparam Derived The derived class in CRTP-idiom - */ -template -struct RedirectCallToSafeCall { - static void Call(const FunctionObj* func, const AnyView* args, int32_t num_args, Any* rv) { - Derived* self = static_cast(const_cast(func)); - TVM_FFI_CHECK_SAFE_CALL(self->RedirectSafeCall(reinterpret_cast(args), - num_args, reinterpret_cast(rv))); - } - - static int32_t SafeCall(void* func, const TVMFFIAny* args, int32_t num_args, TVMFFIAny* rv) { - Derived* self = reinterpret_cast(func); - return self->RedirectSafeCall(args, num_args, rv); - } -}; - -/*! - * \brief FunctionObj specialization that leverages C-style callback definitions. - */ -class ExternCFunctionObjImpl : public FunctionObj, - public RedirectCallToSafeCall { - public: - using RedirectCallToSafeCall::SafeCall; - - ExternCFunctionObjImpl(void* self, TVMFFISafeCallType safe_call, void (*deleter)(void* self)) - : self_(self), safe_call_(safe_call), deleter_(deleter) { - this->call = RedirectCallToSafeCall::Call; - this->safe_call = RedirectCallToSafeCall::SafeCall; - } - - ~ExternCFunctionObjImpl() { deleter_(self_); } - - TVM_FFI_INLINE int32_t RedirectSafeCall(const TVMFFIAny* args, int32_t num_args, - TVMFFIAny* rv) const { - return safe_call_(self_, args, num_args, rv); - } - - private: - void* self_; - TVMFFISafeCallType safe_call_; - void (*deleter_)(void* self); -}; - -/*! - * \brief FunctionObj specialization that wraps an external function. - */ -class ImportedFunctionObjImpl : public FunctionObj, - public RedirectCallToSafeCall { - public: - using RedirectCallToSafeCall::SafeCall; - - explicit ImportedFunctionObjImpl(ObjectPtr data) : data_(data) { - this->call = RedirectCallToSafeCall::Call; - this->safe_call = RedirectCallToSafeCall::SafeCall; - } - - TVM_FFI_INLINE int32_t RedirectSafeCall(const TVMFFIAny* args, int32_t num_args, - TVMFFIAny* rv) const { - FunctionObj* func = const_cast(static_cast(data_.get())); - return func->safe_call(func, args, num_args, rv); - } - - private: - ObjectPtr data_; -}; - -// Helper class to set packed arguments -class PackedArgsSetter { - public: - explicit PackedArgsSetter(AnyView* args) : args_(args) {} - - // NOTE: setter needs to be very carefully designed - // such that we do not have temp variable conversion(eg. convert from lvalue to rvalue) - // that is why we need T&& and std::forward here - template - TVM_FFI_INLINE void operator()(size_t i, T&& value) const { - args_[i].operator=(std::forward(value)); - } - - private: - AnyView* args_; -}; -} // namespace details - -/*! - * \brief Represents arguments packed in AnyView array - * \note This class represent packed arguments to ffi::Function - */ -class PackedArgs { - public: - /*! - * \brief Constructor - * \param data The arguments - * \param size The number of arguments - */ - PackedArgs(const AnyView* data, int32_t size) : data_(data), size_(size) {} - - /*! \return size of the arguments */ - int size() const { return size_; } - - /*! \return The arguments */ - const AnyView* data() const { return data_; } - - /*! - * \brief Slice the arguments - * \param begin The begin index - * \param end The end index - * \return The sliced arguments - */ - PackedArgs Slice(int begin, int end = -1) const { - if (end == -1) { - end = size_; - } - return PackedArgs(data_ + begin, end - begin); - } - - /*! - * \brief Get i-th argument - * \param i the index. - * \return the ith argument. - */ - AnyView operator[](int i) const { return data_[i]; } - - /*! - * \brief Fill the arguments into the AnyView array - * \param data The AnyView array to store the packed arguments - * \param args The arguments to be packed - * \note Caller must ensure all args are alive during lifetime of data. - * A common pitfall is to pass in local variables that are immediately - * destroyed after calling Fill. - */ - template - TVM_FFI_INLINE static void Fill(AnyView* data, Args&&... args) { - details::for_each(details::PackedArgsSetter(data), std::forward(args)...); - } - - private: - /*! \brief The arguments */ - const AnyView* data_; - /*! \brief The number of arguments */ - int32_t size_; -}; - -/*! - * \brief ffi::Function is a type-erased function. - * The arguments are passed by "packed format" via AnyView - */ -class Function : public ObjectRef { - public: - /*! \brief Constructor from null */ - Function(std::nullptr_t) : ObjectRef(nullptr) {} // NOLINT(*) - /*! - * \brief Constructing a packed function from a callable type - * whose signature is consistent with `ffi::Function` - * \param packed_call The packed function signature - * \note legacy purpose, should change to Function::FromPacked for mostfuture use. - */ - template - explicit Function(TCallable packed_call) { - *this = FromPacked(packed_call); - } - /*! - * \brief Constructing a packed function from a callable type - * whose signature is consistent with `ffi::Function` - * \param packed_call The packed function signature - */ - template - static Function FromPacked(TCallable packed_call) { - static_assert( - std::is_convertible_v> || - std::is_convertible_v>, - "tvm::ffi::Function::FromPacked requires input function signature to match packed func " - "format"); - if constexpr (std::is_convertible_v>) { - auto wrapped_call = [packed_call](const AnyView* args, int32_t num_args, - Any* rv) mutable -> void { - PackedArgs args_pack(args, num_args); - packed_call(args_pack, rv); - }; - return FromPackedInternal(wrapped_call); - } else { - return FromPackedInternal(packed_call); - } - } - /*! - * \brief Import a possibly externally defined function to this dll - * \param other Function defined in another dynamic library. - * - * \note This function will redirect the call to safe_call in other. - * It will try to detect if the function is already from the same DLL - * and directly return the original function if so. - * - * \return The imported function. - */ - static Function ImportFromExternDLL(Function other) { - const FunctionObj* other_func = static_cast(other.get()); - // the other function comes from the same dll, no action needed - if (other_func->safe_call == &(FunctionObj::SafeCall) || - other_func->safe_call == &(details::ImportedFunctionObjImpl::SafeCall) || - other_func->safe_call == &(details::ExternCFunctionObjImpl::SafeCall)) { - return other; - } - // the other function coems from a different library - Function func; - func.data_ = make_object(std::move(other.data_)); - return func; - } - /*! - * \brief Create ffi::Function from a C style callbacks. - * \param self Resource handle to the function - * \param safe_call The safe_call definition in C. - * \param deleter The deleter to release the resource of self. - * \return The created function. - */ - static Function FromExternC(void* self, TVMFFISafeCallType safe_call, - void (*deleter)(void* self)) { - // the other function coems from a different library - Function func; - func.data_ = make_object(self, safe_call, deleter); - return func; - } - /*! - * \brief Get global function by name - * \param name The function name - * \return The global function. - * \note This function will return std::nullopt if the function is not found. - */ - static std::optional GetGlobal(std::string_view name) { - TVMFFIObjectHandle handle; - TVMFFIByteArray name_arr{name.data(), name.size()}; - TVM_FFI_CHECK_SAFE_CALL(TVMFFIFunctionGetGlobal(&name_arr, &handle)); - if (handle != nullptr) { - return Function( - details::ObjectUnsafe::ObjectPtrFromOwned(static_cast(handle))); - } else { - return std::nullopt; - } - } - - /*! - * \brief Get global function by name - * \param name The name of the function - * \return The global function - * \note This function will return std::nullopt if the function is not found. - */ - static std::optional GetGlobal(const std::string& name) { - return GetGlobal(std::string_view(name.data(), name.length())); - } - - /*! - * \brief Get global function by name - * \param name The name of the function - * \return The global function - * \note This function will return std::nullopt if the function is not found. - */ - static std::optional GetGlobal(const String& name) { - return GetGlobal(std::string_view(name.data(), name.length())); - } - - /*! - * \brief Get global function by name - * \param name The name of the function - * \return The global function - * \note This function will return std::nullopt if the function is not found. - */ - static std::optional GetGlobal(const char* name) { - return GetGlobal(std::string_view(name)); - } - /*! - * \brief Get global function by name and throw an error if it is not found. - * \param name The name of the function - * \return The global function - * \note This function will throw an error if the function is not found. - */ - static Function GetGlobalRequired(std::string_view name) { - std::optional res = GetGlobal(name); - if (!res.has_value()) { - TVM_FFI_THROW(ValueError) << "Function " << name << " not found"; - } - return *res; - } - - /*! - * \brief Get global function by name - * \param name The name of the function - * \return The global function - * \note This function will throw an error if the function is not found. - */ - static Function GetGlobalRequired(const std::string& name) { - return GetGlobalRequired(std::string_view(name.data(), name.length())); - } - - /*! - * \brief Get global function by name - * \param name The name of the function - * \return The global function - * \note This function will throw an error if the function is not found. - */ - static Function GetGlobalRequired(const String& name) { - return GetGlobalRequired(std::string_view(name.data(), name.length())); - } - - /*! - * \brief Get global function by name - * \param name The name of the function - * \return The global function - * \note This function will throw an error if the function is not found. - */ - static Function GetGlobalRequired(const char* name) { - return GetGlobalRequired(std::string_view(name)); - } - /*! - * \brief Set global function by name - * \param name The name of the function - * \param func The function - * \param override Whether to override when there is duplication. - */ - static void SetGlobal(std::string_view name, Function func, bool override = false) { - TVMFFIByteArray name_arr{name.data(), name.size()}; - TVM_FFI_CHECK_SAFE_CALL( - TVMFFIFunctionSetGlobal(&name_arr, details::ObjectUnsafe::GetHeader(func.get()), override)); - } - /*! - * \brief List all global names - * \return A vector of all global names - * \note This function do not depend on Array so core do not have container dep. - */ - static std::vector ListGlobalNames() { - Function fname_functor = - GetGlobalRequired("ffi.FunctionListGlobalNamesFunctor")().cast(); - std::vector names; - int len = fname_functor(-1).cast(); - for (int i = 0; i < len; ++i) { - names.push_back(fname_functor(i).cast()); - } - return names; - } - /** - * \brief Remove a global function by name - * \param name The name of the function - */ - static void RemoveGlobal(const String& name) { - static Function fremove = GetGlobalRequired("ffi.FunctionRemoveGlobal"); - fremove(name); - } - /*! - * \brief Constructing a packed function from a normal function. - * - * \param callable the internal container of packed function. - */ - template - static Function FromTyped(TCallable callable) { - using FuncInfo = details::FunctionInfo; - auto call_packed = [callable](const AnyView* args, int32_t num_args, Any* rv) mutable -> void { - details::unpack_call( - std::make_index_sequence{}, nullptr, callable, args, num_args, rv); - }; - return FromPackedInternal(call_packed); - } - /*! - * \brief Constructing a packed function from a normal function. - * - * \param callable the internal container of packed function. - * \param name optional name attacked to the function. - */ - template - static Function FromTyped(TCallable callable, std::string name) { - using FuncInfo = details::FunctionInfo; - auto call_packed = [callable, name](const AnyView* args, int32_t num_args, - Any* rv) mutable -> void { - details::unpack_call( - std::make_index_sequence{}, &name, callable, args, num_args, rv); - }; - return FromPackedInternal(call_packed); - } - /*! - * \brief Call function by directly passing in unpacked arguments. - * - * \param args Arguments to be passed. - * \tparam Args arguments to be passed. - * - * \code - * // Example code on how to call packed function - * void CallFFIFunction(tvm::ffi::Function f) { - * // call like normal functions by pass in arguments - * // return value is automatically converted back - * int rvalue = f(1, 2.0); - * } - * \endcode - */ - template - TVM_FFI_INLINE Any operator()(Args&&... args) const { - const int kNumArgs = sizeof...(Args); - const int kArraySize = kNumArgs > 0 ? kNumArgs : 1; - AnyView args_pack[kArraySize]; - PackedArgs::Fill(args_pack, std::forward(args)...); - Any result; - static_cast(data_.get())->CallPacked(args_pack, kNumArgs, &result); - return result; - } - /*! - * \brief Call the function in packed format. - * \param args The arguments - * \param num_args The number of arguments - * \param result The return value. - */ - TVM_FFI_INLINE void CallPacked(const AnyView* args, int32_t num_args, Any* result) const { - static_cast(data_.get())->CallPacked(args, num_args, result); - } - /*! - * \brief Call the function in packed format. - * \param args The arguments - * \param result The return value. - */ - TVM_FFI_INLINE void CallPacked(PackedArgs args, Any* result) const { - static_cast(data_.get())->CallPacked(args.data(), args.size(), result); - } - - /*! \return Whether the packed function is nullptr */ - TVM_FFI_INLINE bool operator==(std::nullptr_t) const { return data_ == nullptr; } - /*! \return Whether the packed function is not nullptr */ - TVM_FFI_INLINE bool operator!=(std::nullptr_t) const { return data_ != nullptr; } - - /// \cond Doxygen_Suppress - TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Function, ObjectRef, FunctionObj); - /// \endcond - - class Registry; - - private: - /*! - * \brief Constructing a packed function from a callable type - * whose signature is consistent with `ffi::Function` - * \param packed_call The packed function signature - */ - template - static Function FromPackedInternal(TCallable packed_call) { - using ObjType = typename details::FunctionObjImpl; - Function func; - func.data_ = make_object(std::forward(packed_call)); - return func; - } -}; - -/*! - * \brief Please refer to \ref TypedFunctionAnchor "TypedFunction" - */ -template -class TypedFunction; - -/*! - * \anchor TypedFunctionAnchor - * \brief A ffi::Function wrapper to provide typed function signature. - * It is backed by a ffi::Function internally. - * - * TypedFunction enables compile time type checking. - * TypedFunction works with the runtime system: - * - It can be passed as an argument of ffi::Function. - * - It can be assigned to ffi::Any. - * - It can be directly converted to a type-erased ffi::Function. - * - * Developers should prefer TypedFunction over ffi::Function in C++ code - * as it enables compile time checking. - * We can construct a TypedFunction from a lambda function - * with the same signature. - * - * \code - * // user defined lambda function. - * auto addone = [](int x)->int { - * return x + 1; - * }; - * // We can directly convert - * // lambda function to TypedFunction - * TypedFunction ftyped(addone); - * // invoke the function. - * int y = ftyped(1); - * // Can be directly converted to ffi::Function - * ffi::Function packed = ftype; - * \endcode - * \tparam R The return value of the function. - * \tparam Args The argument signature of the function. - */ -template -class TypedFunction { - public: - /*! \brief short hand for this function type */ - using TSelf = TypedFunction; - /*! \brief default constructor */ - TypedFunction() {} - /*! \brief constructor from null */ - TypedFunction(std::nullptr_t null) {} // NOLINT(*) - /*! - * \brief constructor from a function - * \param packed The function - */ - TypedFunction(Function packed) : packed_(packed) {} // NOLINT(*) - /*! - * \brief construct from a lambda function with the same signature. - * - * Example usage: - * \code - * auto typed_lambda = [](int x)->int { return x + 1; } - * // construct from packed function - * TypedFunction ftyped(typed_lambda, "add_one"); - * // call the typed version. - * CHECK_EQ(ftyped(1), 2); - * \endcode - * - * \param typed_lambda typed lambda function. - * \param name the name of the lambda function. - * \tparam FLambda the type of the lambda function. - */ - template >::value>::type> - TypedFunction(FLambda typed_lambda, std::string name) { // NOLINT(*) - packed_ = Function::FromTyped(typed_lambda, name); - } - /*! - * \brief construct from a lambda function with the same signature. - * - * This version does not take a name. It is highly recommend you use the - * version that takes a name for the lambda. - * - * Example usage: - * \code - * auto typed_lambda = [](int x)->int { return x + 1; } - * // construct from packed function - * TypedFunction ftyped(typed_lambda); - * // call the typed version. - * CHECK_EQ(ftyped(1), 2); - * \endcode - * - * \param typed_lambda typed lambda function. - * \tparam FLambda the type of the lambda function. - */ - template >::value>::type> - TypedFunction(const FLambda& typed_lambda) { // NOLINT(*) - packed_ = Function::FromTyped(typed_lambda); - } - /*! - * \brief copy assignment operator from typed lambda - * - * Example usage: - * \code - * // construct from packed function - * TypedFunction ftyped; - * ftyped = [](int x) { return x + 1; } - * // call the typed version. - * CHECK_EQ(ftyped(1), 2); - * \endcode - * - * \param typed_lambda typed lambda function. - * \tparam FLambda the type of the lambda function. - * \returns reference to self. - */ - template >::value>::type> - TSelf& operator=(FLambda typed_lambda) { // NOLINT(*) - packed_ = Function::FromTyped(typed_lambda); - return *this; - } - /*! - * \brief copy assignment operator from ffi::Function. - * \param packed The packed function. - * \returns reference to self. - */ - TSelf& operator=(Function packed) { - packed_ = std::move(packed); - return *this; - } - /*! - * \brief Invoke the operator. - * \param args The arguments - * \returns The return value. - */ - TVM_FFI_INLINE R operator()(Args... args) const { - if constexpr (std::is_same_v) { - packed_(std::forward(args)...); - } else { - Any res = packed_(std::forward(args)...); - if constexpr (std::is_same_v) { - return res; - } else { - return std::move(res).cast(); - } - } - } - /*! - * \brief convert to ffi::Function - * \return the internal ffi::Function - */ - operator Function() const { return packed(); } - /*! - * \return reference the internal ffi::Function - */ - const Function& packed() const& { return packed_; } - /*! - * \return r-value reference the internal ffi::Function - */ - constexpr Function&& packed() && { return std::move(packed_); } - /*! \return Whether the packed function is nullptr */ - bool operator==(std::nullptr_t null) const { return packed_ == nullptr; } - /*! \return Whether the packed function is not nullptr */ - bool operator!=(std::nullptr_t null) const { return packed_ != nullptr; } - - private: - /*! \brief The internal packed function */ - Function packed_; -}; - -template -inline constexpr bool use_default_type_traits_v> = false; - -template -struct TypeTraits> : public TypeTraitsBase { - static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIFunction; - - TVM_FFI_INLINE static void CopyToAnyView(const TypedFunction& src, TVMFFIAny* result) { - TypeTraits::CopyToAnyView(src.packed(), result); - } - - TVM_FFI_INLINE static void MoveToAny(TypedFunction src, TVMFFIAny* result) { - TypeTraits::MoveToAny(std::move(src.packed()), result); - } - - TVM_FFI_INLINE static bool CheckAnyStrict(const TVMFFIAny* src) { - return src->type_index == TypeIndex::kTVMFFIFunction; - } - - TVM_FFI_INLINE static TypedFunction CopyFromAnyViewAfterCheck(const TVMFFIAny* src) { - return TypedFunction(TypeTraits::CopyFromAnyViewAfterCheck(src)); - } - - TVM_FFI_INLINE static std::optional> TryCastFromAnyView( - const TVMFFIAny* src) { - std::optional opt = TypeTraits::TryCastFromAnyView(src); - if (opt.has_value()) { - return TypedFunction(*std::move(opt)); - } else { - return std::nullopt; - } - } - - TVM_FFI_INLINE static std::string TypeStr() { return details::FunctionInfo::Sig(); } -}; - -/*! - * \brief helper function to get type index from key - */ -inline int32_t TypeKeyToIndex(std::string_view type_key) { - int32_t type_index; - TVMFFIByteArray type_key_array = {type_key.data(), type_key.size()}; - TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeKeyToIndex(&type_key_array, &type_index)); - return type_index; -} - -/*! - * \brief Export typed function as a SafeCallType symbol. - * - * \param ExportName The symbol name to be exported. - * \param Function The typed function. - * \note ExportName and Function must be different, - * see code examples below. - * - * \sa ffi::TypedFunction - * - * \code - * - * int AddOne_(int x) { - * return x + 1; - * } - * - * // Expose the function as "AddOne" - * TVM_FFI_DLL_EXPORT_TYPED_FUNC(AddOne, AddOne_); - * - * // Expose the function as "SubOne" - * TVM_FFI_DLL_EXPORT_TYPED_FUNC(SubOne, [](int x) { - * return x - 1; - * }); - * - * // The following code will cause compilation error. - * // Because the same Function and ExportName - * // TVM_FFI_DLL_EXPORT_TYPED_FUNC(AddOne_, AddOne_); - * - * // The following code is OK, assuming the macro - * // is in a different namespace from xyz - * // TVM_FFI_DLL_EXPORT_TYPED_FUNC(AddOne_, xyz::AddOne_); - * - * \endcode - */ -#define TVM_FFI_DLL_EXPORT_TYPED_FUNC(ExportName, Function) \ - extern "C" { \ - TVM_FFI_DLL_EXPORT int __tvm_ffi_##ExportName(void* self, TVMFFIAny* args, int32_t num_args, \ - TVMFFIAny* result) { \ - TVM_FFI_SAFE_CALL_BEGIN(); \ - using FuncInfo = ::tvm::ffi::details::FunctionInfo; \ - static std::string name = #ExportName; \ - ::tvm::ffi::details::unpack_call( \ - std::make_index_sequence{}, &name, Function, \ - reinterpret_cast(args), num_args, \ - reinterpret_cast<::tvm::ffi::Any*>(result)); \ - TVM_FFI_SAFE_CALL_END(); \ - } \ - } -} // namespace ffi -} // namespace tvm -#endif // TVM_FFI_FUNCTION_H_ diff --git a/ffi/include/tvm/ffi/function_details.h b/ffi/include/tvm/ffi/function_details.h deleted file mode 100644 index 20ca44cbcb72..000000000000 --- a/ffi/include/tvm/ffi/function_details.h +++ /dev/null @@ -1,210 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -/*! - * \file tvm/ffi/function_details.h - * \brief Implements the funciton signature reflection - */ -#ifndef TVM_FFI_FUNCTION_DETAILS_H_ -#define TVM_FFI_FUNCTION_DETAILS_H_ - -#include -#include -#include -#include - -#include -#include -#include - -namespace tvm { -namespace ffi { -namespace details { - -template -struct Arg2Str { - template - TVM_FFI_INLINE static void Apply(std::ostream& os) { - using Arg = std::tuple_element_t; - if constexpr (i != 0) { - os << ", "; - } - os << i << ": " << Type2Str::v(); - } - template - TVM_FFI_INLINE static void Run(std::ostream& os, std::index_sequence) { - using TExpander = int[]; - (void)TExpander{0, (Apply(os), 0)...}; - } -}; - -template -static constexpr bool ArgSupported = - (std::is_same_v>, Any> || - std::is_same_v>, AnyView> || - TypeTraitsNoCR::convert_enabled); - -// NOTE: return type can only support non-reference managed returns -template -static constexpr bool RetSupported = - (std::is_same_v || std::is_void_v || TypeTraits::convert_enabled); - -template -struct FuncFunctorImpl { - using FType = R(Args...); - using ArgType = std::tuple; - using RetType = R; - /*! \brief total number of arguments*/ - static constexpr size_t num_args = sizeof...(Args); - // MSVC is not that friendly to in-template nested bool evaluation -#ifndef _MSC_VER - /*! \brief Whether this function can be converted to ffi::Function via FromTyped */ - static constexpr bool unpacked_supported = (ArgSupported && ...) && (RetSupported); -#endif - - TVM_FFI_INLINE static std::string Sig() { - using IdxSeq = std::make_index_sequence; - std::ostringstream ss; - ss << "("; - Arg2Str>::Run(ss, IdxSeq{}); - ss << ") -> " << Type2Str::v(); - return ss.str(); - } -}; - -template -struct FunctionInfoHelper; - -template -struct FunctionInfoHelper : FuncFunctorImpl {}; -template -struct FunctionInfoHelper : FuncFunctorImpl {}; - -/*! - * \brief Template class to get function signature of a function or functor. - * \tparam T The function/functor type. - * \note We need a decltype redirection because this helps lambda types. - */ -template -struct FunctionInfo : FunctionInfoHelper {}; - -template -struct FunctionInfo : FuncFunctorImpl {}; -template -struct FunctionInfo : FuncFunctorImpl {}; - -/*! \brief Using static function to output typed function signature */ -typedef std::string (*FGetFuncSignature)(); - -/*! - * \brief Auxilary argument value with context for error reporting - */ -class ArgValueWithContext { - public: - /*! - * \brief move constructor from another return value. - * \param args The argument list - * \param arg_index In a function call, this argument is at index arg_index (0-indexed). - * \param optional_name Name of the function being called. Can be nullptr if the function is not. - * \param f_sig Pointer to static function outputting signature of the function being called. - * named. - */ - TVM_FFI_INLINE ArgValueWithContext(const AnyView* args, int32_t arg_index, - const std::string* optional_name, FGetFuncSignature f_sig) - : args_(args), arg_index_(arg_index), optional_name_(optional_name), f_sig_(f_sig) {} - - template - TVM_FFI_INLINE operator Type() { - using TypeWithoutCR = std::remove_const_t>; - - if constexpr (std::is_same_v) { - return args_[arg_index_]; - } else if constexpr (std::is_same_v) { - return Any(args_[arg_index_]); - } else { - std::optional opt = args_[arg_index_].try_cast(); - if (!opt.has_value()) { - TVMFFIAny any_data = args_[arg_index_].CopyToTVMFFIAny(); - TVM_FFI_THROW(TypeError) << "Mismatched type on argument #" << arg_index_ - << " when calling: `" - << (optional_name_ == nullptr ? "" : *optional_name_) - << (f_sig_ == nullptr ? "" : (*f_sig_)()) << "`. Expected `" - << Type2Str::v() << "` but got `" - << TypeTraits::GetMismatchTypeInfo(&any_data) - << '`'; - } - return *std::move(opt); - } - } - - private: - const AnyView* args_; - int32_t arg_index_; - const std::string* optional_name_; - FGetFuncSignature f_sig_; -}; - -template -TVM_FFI_INLINE void unpack_call(std::index_sequence, const std::string* optional_name, - const F& f, [[maybe_unused]] const AnyView* args, - [[maybe_unused]] int32_t num_args, [[maybe_unused]] Any* rv) { - using FuncInfo = FunctionInfo; - FGetFuncSignature f_sig = FuncInfo::Sig; - - // somehow MSVC does not support the static constexpr member in this case, function is fine -#ifndef _MSC_VER - static_assert(FuncInfo::unpacked_supported, "The function signature do not support unpacked"); -#endif - constexpr size_t nargs = sizeof...(Is); - if (nargs != num_args) { - TVM_FFI_THROW(TypeError) << "Mismatched number of arguments when calling: `" - << (optional_name == nullptr ? "" : *optional_name) - << (f_sig == nullptr ? "" : (*f_sig)()) << "`. Expected " << nargs - << " but got " << num_args << " arguments"; - } - // use index sequence to do recursive-less unpacking - if constexpr (std::is_same_v) { - f(ArgValueWithContext(args, Is, optional_name, f_sig)...); - } else { - *rv = R(f(ArgValueWithContext(args, Is, optional_name, f_sig)...)); - } -} - -/*! - * \brief Move the safe call raised error to the caller - * \return The error - */ -TVM_FFI_INLINE static Error MoveFromSafeCallRaised() { - TVMFFIObjectHandle handle; - TVMFFIErrorMoveFromRaised(&handle); - // handle is owned by caller - return details::ObjectUnsafe::ObjectRefFromObjectPtr( - details::ObjectUnsafe::ObjectPtrFromOwned(static_cast(handle))); -} - -/*! - * \brief Set the safe call raised error - * \param error The error - */ -TVM_FFI_INLINE static void SetSafeCallRaised(const Error& error) { - TVMFFIErrorSetRaised(details::ObjectUnsafe::TVMFFIObjectPtrFromObjectRef(error)); -} -} // namespace details -} // namespace ffi -} // namespace tvm -#endif // TVM_FFI_FUNCTION_DETAILS_H_ diff --git a/ffi/include/tvm/ffi/memory.h b/ffi/include/tvm/ffi/memory.h deleted file mode 100644 index 1fa9d6539079..000000000000 --- a/ffi/include/tvm/ffi/memory.h +++ /dev/null @@ -1,229 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -/*! - * \file tvm/ffi/memory.h - * \brief Runtime memory management to allocate on heap object. - */ -#ifndef TVM_FFI_MEMORY_H_ -#define TVM_FFI_MEMORY_H_ - -#include - -#include -#include -#include - -namespace tvm { -namespace ffi { - -/*! \brief Deleter function for obeject */ -typedef void (*FObjectDeleter)(void* obj, int flags); - -// Detail implementations after this -// -// The current design allows swapping the -// allocator pattern when necessary. -// -// Possible future allocator optimizations: -// - Arena allocator that gives ownership of memory to arena (deleter = nullptr) -// - Thread-local object pools: one pool per size and alignment requirement. -// - Can specialize by type of object to give the specific allocator to each object. -namespace details { -/*! - * \brief Base class of object allocators that implements make. - * Use curiously recurring template pattern. - * - * \tparam Derived The derived class. - */ -template -class ObjAllocatorBase { - public: - /*! - * \brief Make a new object using the allocator. - * \tparam T The type to be allocated. - * \tparam Args The constructor signature. - * \param args The arguments. - */ - template - ObjectPtr make_object(Args&&... args) { - using Handler = typename Derived::template Handler; - static_assert(std::is_base_of::value, "make can only be used to create Object"); - T* ptr = Handler::New(static_cast(this), std::forward(args)...); - TVMFFIObject* ffi_ptr = details::ObjectUnsafe::GetHeader(ptr); - ffi_ptr->strong_ref_count = 1; - ffi_ptr->weak_ref_count = 1; - ffi_ptr->type_index = T::RuntimeTypeIndex(); - ffi_ptr->deleter = Handler::Deleter(); - return details::ObjectUnsafe::ObjectPtrFromOwned(ptr); - } - - /*! - * \tparam ArrayType The type to be allocated. - * \tparam ElemType The type of array element. - * \tparam Args The constructor signature. - * \param num_elems The number of array elements. - * \param args The arguments. - */ - template - ObjectPtr make_inplace_array(size_t num_elems, Args&&... args) { - using Handler = typename Derived::template ArrayHandler; - static_assert(std::is_base_of::value, - "make_inplace_array can only be used to create Object"); - ArrayType* ptr = - Handler::New(static_cast(this), num_elems, std::forward(args)...); - TVMFFIObject* ffi_ptr = details::ObjectUnsafe::GetHeader(ptr); - ffi_ptr->strong_ref_count = 1; - ffi_ptr->weak_ref_count = 1; - ffi_ptr->type_index = ArrayType::RuntimeTypeIndex(); - ffi_ptr->deleter = Handler::Deleter(); - return details::ObjectUnsafe::ObjectPtrFromOwned(ptr); - } -}; - -// Simple allocator that uses new/delete. -class SimpleObjAllocator : public ObjAllocatorBase { - public: - template - class Handler { - public: - struct alignas(T) StorageType { - char data[sizeof(T)]; - }; - - template - static T* New(SimpleObjAllocator*, Args&&... args) { - // NOTE: the first argument is not needed for SimpleObjAllocator - // It is reserved for special allocators that needs to recycle - // the object to itself (e.g. in the case of object pool). - // - // In the case of an object pool, an allocator needs to create - // a special chunk memory that hides reference to the allocator - // and call allocator's release function in the deleter. - - // NOTE2: Use inplace new to allocate - // This is used to get rid of warning when deleting a virtual - // class with non-virtual destructor. - // We are fine here as we captured the right deleter during construction. - // This is also the right way to get storage type for an object pool. - StorageType* data = new StorageType(); - new (data) T(std::forward(args)...); - return reinterpret_cast(data); - } - - static FObjectDeleter Deleter() { return Deleter_; } - - private: - static void Deleter_(void* objptr, int flags) { - T* tptr = - details::ObjectUnsafe::RawObjectPtrFromUnowned(static_cast(objptr)); - if (flags & kTVMFFIObjectDeleterFlagBitMaskStrong) { - // It is important to do tptr->T::~T(), - // so that we explicitly call the specific destructor - // instead of tptr->~T(), which could mean the intention - // call a virtual destructor(which may not be available and is not required). - tptr->T::~T(); - } - if (flags & kTVMFFIObjectDeleterFlagBitMaskWeak) { - delete reinterpret_cast(tptr); - } - } - }; - - // Array handler that uses new/delete. - template - class ArrayHandler { - public: - using StorageType = typename std::aligned_storage::type; - // for now only support elements that aligns with array header. - static_assert(alignof(ArrayType) % alignof(ElemType) == 0 && - sizeof(ArrayType) % alignof(ElemType) == 0, - "element alignment constraint"); - - template - static ArrayType* New(SimpleObjAllocator*, size_t num_elems, Args&&... args) { - // NOTE: the first argument is not needed for ArrayObjAllocator - // It is reserved for special allocators that needs to recycle - // the object to itself (e.g. in the case of object pool). - // - // In the case of an object pool, an allocator needs to create - // a special chunk memory that hides reference to the allocator - // and call allocator's release function in the deleter. - // NOTE2: Use inplace new to allocate - // This is used to get rid of warning when deleting a virtual - // class with non-virtual destructor. - // We are fine here as we captured the right deleter during construction. - // This is also the right way to get storage type for an object pool. - size_t unit = sizeof(StorageType); - size_t requested_size = num_elems * sizeof(ElemType) + sizeof(ArrayType); - size_t num_storage_slots = (requested_size + unit - 1) / unit; - StorageType* data = new StorageType[num_storage_slots]; - new (data) ArrayType(std::forward(args)...); - return reinterpret_cast(data); - } - - static FObjectDeleter Deleter() { return Deleter_; } - - private: - static void Deleter_(void* objptr, int flags) { - ArrayType* tptr = details::ObjectUnsafe::RawObjectPtrFromUnowned( - static_cast(objptr)); - if (flags & kTVMFFIObjectDeleterFlagBitMaskStrong) { - // It is important to do tptr->ArrayType::~ArrayType(), - // so that we explicitly call the specific destructor - // instead of tptr->~ArrayType(), which could mean the intention - // call a virtual destructor(which may not be available and is not required). - tptr->ArrayType::~ArrayType(); - } - if (flags & kTVMFFIObjectDeleterFlagBitMaskWeak) { - StorageType* p = reinterpret_cast(tptr); - delete[] p; - } - } - }; -}; -} // namespace details - -/*! - * \brief Allocate an object - * \param args arguments to the constructor. - * \tparam T the node type. - * \return The ObjectPtr to the allocated object. - */ -template -inline ObjectPtr make_object(Args&&... args) { - return details::SimpleObjAllocator().make_object(std::forward(args)...); -} - -/*! - * \brief Allocate an Object with additional ElemType[num_elems] that are stored right after. - * \param num_elems The number of elements in the array. - * \param args arguments to the constructor. - * \tparam ArrayType the array type. - * \tparam ElemType the element type. - * \return The ObjectPtr to the allocated array. - */ -template -inline ObjectPtr make_inplace_array_object(size_t num_elems, Args&&... args) { - return details::SimpleObjAllocator().make_inplace_array( - num_elems, std::forward(args)...); -} - -} // namespace ffi -} // namespace tvm -#endif // TVM_FFI_MEMORY_H_ diff --git a/ffi/include/tvm/ffi/object.h b/ffi/include/tvm/ffi/object.h deleted file mode 100644 index 6dcc30e808da..000000000000 --- a/ffi/include/tvm/ffi/object.h +++ /dev/null @@ -1,1142 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -/*! - * \file tvm/ffi/object.h - * \brief A managed object in the TVM FFI. - */ -#ifndef TVM_FFI_OBJECT_H_ -#define TVM_FFI_OBJECT_H_ - -#include -#include - -#include -#include -#include -#include - -namespace tvm { -namespace ffi { - -/*! - * \brief TypeIndex enum, alias of TVMFFITypeIndex. - */ -using TypeIndex = TVMFFITypeIndex; - -/*! - * \brief TypeInfo, alias of TVMFFITypeInfo. - */ -using TypeInfo = TVMFFITypeInfo; - -/*! - * \brief Helper tag to explicitly request unsafe initialization. - * - * Constructing an ObjectRefType with UnsafeInit{} will set the data_ member to nullptr. - * - * When initializing Object fields, ObjectRef fields can be set to UnsafeInit. - * This enables the "construct with UnsafeInit then set all fields" pattern - * when the object does not have a default constructor. - * - * Used for initialization in controlled scenarios where such unsafe - * initialization is known to be safe. - * - * Each ObjectRefType should have a constructor that takes an UnsafeInit tag. - * - * \note As the name suggests, do not use it in normal code paths. - */ -struct UnsafeInit {}; - -/*! - * \brief Known type keys for pre-defined types. - */ -struct StaticTypeKey { - /*! \brief The type key for Any */ - static constexpr const char* kTVMFFIAny = "Any"; - /*! \brief The type key for None */ - static constexpr const char* kTVMFFINone = "None"; - /*! \brief The type key for bool */ - static constexpr const char* kTVMFFIBool = "bool"; - /*! \brief The type key for int */ - static constexpr const char* kTVMFFIInt = "int"; - /*! \brief The type key for float */ - static constexpr const char* kTVMFFIFloat = "float"; - /*! \brief The type key for void* */ - static constexpr const char* kTVMFFIOpaquePtr = "void*"; - /*! \brief The type key for DataType */ - static constexpr const char* kTVMFFIDataType = "DataType"; - /*! \brief The type key for Device */ - static constexpr const char* kTVMFFIDevice = "Device"; - /*! \brief The type key for const char* */ - static constexpr const char* kTVMFFIRawStr = "const char*"; - /*! \brief The type key for TVMFFIByteArray* */ - static constexpr const char* kTVMFFIByteArrayPtr = "TVMFFIByteArray*"; - /*! \brief The type key for ObjectRValueRef */ - static constexpr const char* kTVMFFIObjectRValueRef = "ObjectRValueRef"; - /*! \brief The type key for SmallStr */ - static constexpr const char* kTVMFFISmallStr = "ffi.SmallStr"; - /*! \brief The type key for SmallBytes */ - static constexpr const char* kTVMFFISmallBytes = "ffi.SmallBytes"; - /*! \brief The type key for Bytes */ - static constexpr const char* kTVMFFIBytes = "ffi.Bytes"; - /*! \brief The type key for String */ - static constexpr const char* kTVMFFIStr = "ffi.String"; - /*! \brief The type key for Shape */ - static constexpr const char* kTVMFFIShape = "ffi.Shape"; - /*! \brief The type key for Tensor */ - static constexpr const char* kTVMFFITensor = "ffi.Tensor"; - /*! \brief The type key for Object */ - static constexpr const char* kTVMFFIObject = "ffi.Object"; - /*! \brief The type key for Function */ - static constexpr const char* kTVMFFIFunction = "ffi.Function"; - /*! \brief The type key for Array */ - static constexpr const char* kTVMFFIArray = "ffi.Array"; - /*! \brief The type key for Map */ - static constexpr const char* kTVMFFIMap = "ffi.Map"; - /*! \brief The type key for Module */ - static constexpr const char* kTVMFFIModule = "ffi.Module"; -}; - -/*! - * \brief Get type key from type index - * \param type_index The input type index - * \return the type key - */ -inline std::string TypeIndexToTypeKey(int32_t type_index) { - const TypeInfo* type_info = TVMFFIGetTypeInfo(type_index); - return std::string(type_info->type_key.data, type_info->type_key.size); -} - -namespace details { -// Helper to perform -// unsafe operations related to object -struct ObjectUnsafe; - -/*! - * Check if the type_index is an instance of TargetObjectType. - * - * \tparam TargetType The target object type to be checked. - * - * \param object_type_index The type index to be checked, caller - * ensures that the index is already within the object index range. - * - * \return Whether the target type is true. - */ -template -TVM_FFI_INLINE bool IsObjectInstance(int32_t object_type_index); -} // namespace details - -/*! - * \brief Base class of all object containers. - * - * Sub-class of objects should declare the following static constexpr fields: - * - * - _type_index: - * Static type index of the object, if assigned to TypeIndex::kTVMFFIDynObject - * the type index will be assigned during runtime. - * Runtime type index can be accessed by ObjectType::TypeIndex(); - * - _type_key: - * The unique string identifier of the type. - * - _type_final: - * Whether the type is terminal type(there is no subclass of the type in the object system). - * This field is automatically set by macro TVM_FFI_DECLARE_OBJECT_INFO_FINAL - * It is still OK to sub-class a terminal object type T and construct it using make_object. - * But IsInstance check will only show that the object type is T(instead of the sub-class). - * - _type_mutable: - * Whether we would like to expose cast to non-constant pointer - * ObjectType* from Any/AnyView. By default, we set to false so it is not exposed. - * - * The following two fields are necessary for base classes that can be sub-classed. - * - * - _type_child_slots: - * Number of reserved type index slots for child classes. - * Used for runtime optimization for type checking in IsInstance. - * If an object's type_index is within range of [type_index, type_index + _type_child_slots] - * Then the object can be quickly decided as sub-class of the current object class. - * If not, a fallback mechanism is used to check the global type table. - * Recommendation: set to estimate number of children needed. - * - * - _type_child_slots_can_overflow: - * Whether we can add additional child classes even if the number of child classes - * exceeds the _type_child_slots. A fallback mechanism to check type table will be used. - * Recommendation: set to false for optimal runtime speed if we know exact number of children. - * - * Two macros are used to declare helper functions in the object: - * - Use TVM_FFI_DECLARE_OBJECT_INFO for object classes that can be sub-classed. - * - Use TVM_FFI_DECLARE_OBJECT_INFO_FINAL for object classes that cannot be sub-classed. - * - * New objects can be created using make_object function. - * Which will automatically populate the type_index and deleter of the object. - */ -class Object { - protected: - /*! \brief header field that is the common prefix of all objects */ - TVMFFIObject header_; - - public: - Object() { - header_.strong_ref_count = 0; - header_.weak_ref_count = 0; - header_.deleter = nullptr; - } - /*! - * Check if the object is an instance of TargetType. - * \tparam TargetType The target type to be checked. - * \return Whether the target type is true. - */ - template - bool IsInstance() const { - return details::IsObjectInstance(header_.type_index); - } - - /*! \return The internal runtime type index of the object. */ - int32_t type_index() const { return header_.type_index; } - - /*! - * \return the type key of the object. - * \note this operation is expensive, can be used for error reporting. - */ - std::string GetTypeKey() const { - // the function checks that the info exists - const TypeInfo* type_info = TVMFFIGetTypeInfo(header_.type_index); - return std::string(type_info->type_key.data, type_info->type_key.size); - } - - /*! - * \return A hash value of the return of GetTypeKey. - */ - uint64_t GetTypeKeyHash() const { - // the function checks that the info exists - const TypeInfo* type_info = TVMFFIGetTypeInfo(header_.type_index); - return type_info->type_key_hash; - } - - /*! - * \brief Get the type key of the corresponding index from runtime. - * \param tindex The type index. - * \return the result. - */ - static std::string TypeIndex2Key(int32_t tindex) { - const TypeInfo* type_info = TVMFFIGetTypeInfo(tindex); - return std::string(type_info->type_key.data, type_info->type_key.size); - } - - /*! - * \return Whether the object.use_count() == 1. - */ - bool unique() const { return use_count() == 1; } - - /*! - * \return The usage count of the cell. - * \note We use STL style naming to be consistent with known API in shared_ptr. - */ - int32_t use_count() const { - // only need relaxed load of counters -#ifdef _MSC_VER - return (reinterpret_cast(&header_.strong_ref_count))[0]; // NOLINT(*) -#else - return __atomic_load_n(&(header_.strong_ref_count), __ATOMIC_RELAXED); -#endif - } - - //---------------------------------------------------------------------------- - // The following fields are configuration flags for subclasses of object - //---------------------------------------------------------------------------- - /*! \brief The type key of the class */ - static constexpr const char* _type_key = StaticTypeKey::kTVMFFIObject; - /*! \brief Whether the class is final */ - static constexpr bool _type_final = false; - /*! \brief Whether allow mutable access to fields */ - static constexpr bool _type_mutable = false; - /*! \brief The number of child slots of the class to pre-allocate to this type */ - static constexpr uint32_t _type_child_slots = 0; - /*! - * \brief Whether allow additional children beyond pre-specified by _type_child_slots - */ - static constexpr bool _type_child_slots_can_overflow = true; - /*! \brief The static type index of the class */ - static constexpr int32_t _type_index = TypeIndex::kTVMFFIObject; - /*! \brief The static depth of the class in the object hierarchy */ - static constexpr int32_t _type_depth = 0; - /*! \brief The structural equality and hash kind of the type */ - static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindUnsupported; - // The following functions are provided by macro - // TVM_FFI_DECLARE_OBJECT_INFO and TVM_FFI_DECLARE_OBJECT_INFO_FINAL - /*! - * \brief Get the runtime allocated type index of the type - * \note Getting this information may need dynamic calls into a global table. - */ - static int32_t RuntimeTypeIndex() { return TypeIndex::kTVMFFIObject; } - /*! - * \brief Internal function to get or allocate a runtime index. - */ - static int32_t _GetOrAllocRuntimeTypeIndex() { return TypeIndex::kTVMFFIObject; } - - private: - /*! \brief increase strong reference count, the caller must already hold a strong reference */ - void IncRef() { -#ifdef _MSC_VER - _InterlockedIncrement64( - reinterpret_cast(&header_.strong_ref_count)); // NOLINT(*) -#else - __atomic_fetch_add(&(header_.strong_ref_count), 1, __ATOMIC_RELAXED); -#endif - } - /*! - * \brief Try to lock the object to increase the strong reference count, - * the caller must already hold a strong reference. - * \return whether the lock call is successful and object is still alive. - */ - bool TryPromoteWeakPtr() { -#ifdef _MSC_VER - uint64_t old_count = - (reinterpret_cast(&header_.strong_ref_count))[0]; // NOLINT(*) - while (old_count > 0) { - uint64_t new_count = old_count + 1; - uint64_t old_count_loaded = _InterlockedCompareExchange64( - reinterpret_cast(&header_.strong_ref_count), new_count, old_count); - if (old_count == old_count_loaded) { - return true; - } - old_count = old_count_loaded; - } - return false; -#else - uint64_t old_count = __atomic_load_n(&(header_.strong_ref_count), __ATOMIC_RELAXED); - while (old_count > 0) { - // must do CAS to ensure that we are the only one that increases the reference count - // avoid condition when two threads tries to promote weak to strong at same time - // or when strong deletion happens between the load and the CAS - uint64_t new_count = old_count + 1; - if (__atomic_compare_exchange_n(&(header_.strong_ref_count), &old_count, new_count, true, - __ATOMIC_ACQ_REL, __ATOMIC_RELAXED)) { - return true; - } - } - return false; -#endif - } - - /*! \brief increase weak reference count */ - void IncWeakRef() { -#ifdef _MSC_VER - _InterlockedIncrement(reinterpret_cast(&header_.weak_ref_count)); // NOLINT(*) -#else - __atomic_fetch_add(&(header_.weak_ref_count), 1, __ATOMIC_RELAXED); -#endif - } - - /*! \brief decrease strong reference count and delete the object */ - void DecRef() { -#ifdef _MSC_VER - // use simpler impl in windows to ensure correctness - if (_InterlockedDecrement64( // - reinterpret_cast(&header_.strong_ref_count)) == 0) { // NOLINT(*) - // full barrrier is implicit in InterlockedDecrement - if (header_.deleter != nullptr) { - header_.deleter(&(this->header_), kTVMFFIObjectDeleterFlagBitMaskStrong); - } - if (_InterlockedDecrement( // - reinterpret_cast(&header_.weak_ref_count)) == 0) { // NOLINT(*) - if (header_.deleter != nullptr) { - header_.deleter(&(this->header_), kTVMFFIObjectDeleterFlagBitMaskWeak); - } - } - } -#else - // first do a release, note we only need to acquire for deleter - if (__atomic_fetch_sub(&(header_.strong_ref_count), 1, __ATOMIC_RELEASE) == 1) { - if (__atomic_load_n(&(header_.weak_ref_count), __ATOMIC_RELAXED) == 1) { - // common case, we need to delete both the object and the memory block - // only acquire when we need to call deleter - __atomic_thread_fence(__ATOMIC_ACQUIRE); - if (header_.deleter != nullptr) { - // call deleter once - header_.deleter(&(this->header_), kTVMFFIObjectDeleterFlagBitMaskBoth); - } - } else { - // Slower path: there is still a weak reference left - __atomic_thread_fence(__ATOMIC_ACQUIRE); - // call destructor first, then decrease weak reference count - if (header_.deleter != nullptr) { - header_.deleter(&(this->header_), kTVMFFIObjectDeleterFlagBitMaskStrong); - } - // now decrease weak reference count - if (__atomic_fetch_sub(&(header_.weak_ref_count), 1, __ATOMIC_RELEASE) == 1) { - __atomic_thread_fence(__ATOMIC_ACQUIRE); - if (header_.deleter != nullptr) { - header_.deleter(&(this->header_), kTVMFFIObjectDeleterFlagBitMaskWeak); - } - } - } - } -#endif - } - - /*! \brief decrease weak reference count */ - void DecWeakRef() { -#ifdef _MSC_VER - if (_InterlockedDecrement( // - reinterpret_cast(&header_.weak_ref_count)) == 0) { // NOLINT(*) - if (header_.deleter != nullptr) { - header_.deleter(&(this->header_), kTVMFFIObjectDeleterFlagBitMaskWeak); - } - } -#else - // now decrease weak reference count - if (__atomic_fetch_sub(&(header_.weak_ref_count), 1, __ATOMIC_RELEASE) == 1) { - __atomic_thread_fence(__ATOMIC_ACQUIRE); - if (header_.deleter != nullptr) { - header_.deleter(&(this->header_), kTVMFFIObjectDeleterFlagBitMaskWeak); - } - } -#endif - } - - // friend classes - template - friend class ObjectPtr; - template - friend class WeakObjectPtr; - friend struct tvm::ffi::details::ObjectUnsafe; -}; - -/*! - * \brief A custom smart pointer for Object. - * \tparam T the content data type. - * \sa make_object - */ -template -class ObjectPtr { - public: - /*! \brief default constructor */ - ObjectPtr() {} - /*! \brief default constructor */ - ObjectPtr(std::nullptr_t) {} // NOLINT(*) - /*! - * \brief copy constructor - * \param other The value to be moved - */ - ObjectPtr(const ObjectPtr& other) // NOLINT(*) - : ObjectPtr(other.data_) {} - /*! - * \brief copy constructor - * \param other The value to be moved - */ - template - ObjectPtr(const ObjectPtr& other) // NOLINT(*) - : ObjectPtr(other.data_) { - static_assert(std::is_base_of::value, - "can only assign of child class ObjectPtr to parent"); - } - /*! - * \brief move constructor - * \param other The value to be moved - */ - ObjectPtr(ObjectPtr&& other) // NOLINT(*) - : data_(other.data_) { - other.data_ = nullptr; - } - /*! - * \brief move constructor - * \param other The value to be moved - */ - template - ObjectPtr(ObjectPtr&& other) // NOLINT(*) - : data_(other.data_) { - static_assert(std::is_base_of::value, - "can only assign of child class ObjectPtr to parent"); - other.data_ = nullptr; - } - /*! \brief destructor */ - ~ObjectPtr() { this->reset(); } - /*! - * \brief Swap this array with another Object - * \param other The other Object - */ - void swap(ObjectPtr& other) { // NOLINT(*) - std::swap(data_, other.data_); - } - /*! - * \return Get the content of the pointer - */ - T* get() const { return static_cast(data_); } - /*! - * \return The pointer - */ - T* operator->() const { return get(); } - /*! - * \return The reference - */ - T& operator*() const { // NOLINT(*) - return *get(); - } - /*! - * \brief copy assignment - * \param other The value to be assigned. - * \return reference to self. - */ - ObjectPtr& operator=(const ObjectPtr& other) { // NOLINT(*) - // takes in plane operator to enable copy elison. - // copy-and-swap idiom - ObjectPtr(other).swap(*this); // NOLINT(*) - return *this; - } - /*! - * \brief move assignment - * \param other The value to be assigned. - * \return reference to self. - */ - ObjectPtr& operator=(ObjectPtr&& other) { // NOLINT(*) - // copy-and-swap idiom - ObjectPtr(std::move(other)).swap(*this); // NOLINT(*) - return *this; - } - /*! - * \brief nullptr check - * \return result of comparison of internal pointer with nullptr. - */ - explicit operator bool() const { return get() != nullptr; } - /*! \brief reset the content of ptr to be nullptr */ - void reset() { - if (data_ != nullptr) { - data_->DecRef(); - data_ = nullptr; - } - } - /*! \return The use count of the ptr, for debug purposes */ - int use_count() const { return data_ != nullptr ? data_->use_count() : 0; } - /*! \return whether the reference is unique */ - bool unique() const { return data_ != nullptr && data_->use_count() == 1; } - /*! \return Whether two ObjectPtr do not equal each other */ - bool operator==(const ObjectPtr& other) const { return data_ == other.data_; } - /*! \return Whether two ObjectPtr equals each other */ - bool operator!=(const ObjectPtr& other) const { return data_ != other.data_; } - /*! \return Whether the pointer is nullptr */ - bool operator==(std::nullptr_t) const { return data_ == nullptr; } - /*! \return Whether the pointer is not nullptr */ - bool operator!=(std::nullptr_t) const { return data_ != nullptr; } - - private: - /*! \brief internal pointer field */ - Object* data_{nullptr}; - /*! - * \brief constructor from Object - * \param data The data pointer - */ - explicit ObjectPtr(Object* data) : data_(data) { - if (data_ != nullptr) { - data_->IncRef(); - } - } - // friend classes - friend class Object; - friend class ObjectRef; - friend struct ObjectPtrHash; - template - friend class ObjectPtr; - template - friend class WeakObjectPtr; - friend struct tvm::ffi::details::ObjectUnsafe; -}; - -/*! - * \brief A custom smart pointer for Object. - * \tparam T the content data type. - * \sa make_object - */ -template -class WeakObjectPtr { - public: - /*! \brief default constructor */ - WeakObjectPtr() {} - /*! \brief default constructor */ - WeakObjectPtr(std::nullptr_t) {} // NOLINT(*) - /*! - * \brief copy constructor - * \param other The value to be moved - */ - WeakObjectPtr(const WeakObjectPtr& other) // NOLINT(*) - : WeakObjectPtr(other.data_) {} - - /*! - * \brief copy constructor - * \param other The value to be moved - */ - WeakObjectPtr(const ObjectPtr& other) // NOLINT(*) - : WeakObjectPtr(other.get()) {} - /*! - * \brief copy constructor - * \param other The value to be moved - */ - template - WeakObjectPtr(const WeakObjectPtr& other) // NOLINT(*) - : WeakObjectPtr(other.data_) { - static_assert(std::is_base_of::value, - "can only assign of child class ObjectPtr to parent"); - } - /*! - * \brief copy constructor - * \param other The value to be moved - */ - template - WeakObjectPtr(const ObjectPtr& other) // NOLINT(*) - : WeakObjectPtr(other.data_) { - static_assert(std::is_base_of::value, - "can only assign of child class ObjectPtr to parent"); - } - /*! - * \brief move constructor - * \param other The value to be moved - */ - WeakObjectPtr(WeakObjectPtr&& other) // NOLINT(*) - : data_(other.data_) { - other.data_ = nullptr; - } - /*! - * \brief move constructor - * \param other The value to be moved - */ - template - WeakObjectPtr(WeakObjectPtr&& other) // NOLINT(*) - : data_(other.data_) { - static_assert(std::is_base_of::value, - "can only assign of child class ObjectPtr to parent"); - other.data_ = nullptr; - } - /*! \brief destructor */ - ~WeakObjectPtr() { this->reset(); } - /*! - * \brief Swap this array with another Object - * \param other The other Object - */ - void swap(WeakObjectPtr& other) { // NOLINT(*) - std::swap(data_, other.data_); - } - - /*! - * \brief copy assignment - * \param other The value to be assigned. - * \return reference to self. - */ - WeakObjectPtr& operator=(const WeakObjectPtr& other) { // NOLINT(*) - // takes in plane operator to enable copy elison. - // copy-and-swap idiom - WeakObjectPtr(other).swap(*this); // NOLINT(*) - return *this; - } - /*! - * \brief move assignment - * \param other The value to be assigned. - * \return reference to self. - */ - WeakObjectPtr& operator=(WeakObjectPtr&& other) { // NOLINT(*) - // copy-and-swap idiom - WeakObjectPtr(std::move(other)).swap(*this); // NOLINT(*) - return *this; - } - - /*! \return The internal object pointer if the object is still alive, otherwise nullptr */ - ObjectPtr lock() const { - if (data_ != nullptr && data_->TryPromoteWeakPtr()) { - ObjectPtr ret; - // we already increase the reference count, so we don't need to do it again - ret.data_ = data_; - return ret; - } - return nullptr; - } - - /*! \brief reset the content of ptr to be nullptr */ - void reset() { - if (data_ != nullptr) { - data_->DecWeakRef(); - data_ = nullptr; - } - } - - /*! \return The use count of the ptr, for debug purposes */ - int use_count() const { return data_ != nullptr ? data_->use_count() : 0; } - - /*! \return whether the pointer is nullptr */ - bool expired() const { return data_ == nullptr || data_->use_count() == 0; } - - private: - /*! \brief internal pointer field */ - Object* data_{nullptr}; - - /*! - * \brief constructor from Object - * \param data The data pointer - */ - explicit WeakObjectPtr(Object* data) : data_(data) { - if (data_ != nullptr) { - data_->IncWeakRef(); - } - } - - template - friend class WeakObjectPtr; - friend struct tvm::ffi::details::ObjectUnsafe; -}; - -/*! - * \brief Optional data type in FFI. - * \tparam T The underlying type of the optional. - * - * \note Compared to std::optional, Optional - * akes less storage as it used nullptr to represent nullopt. - */ -template -class Optional; - -/*! \brief Base class of all object reference */ -class ObjectRef { - public: - /*! \brief default constructor */ - ObjectRef() = default; - /*! \brief copy constructor */ - ObjectRef(const ObjectRef& other) = default; - /*! \brief move constructor */ - ObjectRef(ObjectRef&& other) = default; - /*! \brief copy assignment */ - ObjectRef& operator=(const ObjectRef& other) = default; - /*! \brief move assignment */ - ObjectRef& operator=(ObjectRef&& other) = default; - /*! \brief Constructor from existing object ptr */ - explicit ObjectRef(ObjectPtr data) : data_(data) {} - /*! \brief Constructor from UnsafeInit */ - explicit ObjectRef(UnsafeInit) : data_(nullptr) {} - /*! - * \brief Comparator - * \param other Another object ref. - * \return the compare result. - */ - bool same_as(const ObjectRef& other) const { return data_ == other.data_; } - /*! - * \brief Comparator - * \param other Another object ref. - * \return the compare result. - */ - bool operator==(const ObjectRef& other) const { return data_ == other.data_; } - /*! - * \brief Comparator - * \param other Another object ref. - * \return the compare result. - */ - bool operator!=(const ObjectRef& other) const { return data_ != other.data_; } - /*! - * \brief Comparator - * \param other Another object ref by address. - * \return the compare result. - */ - bool operator<(const ObjectRef& other) const { return data_.get() < other.data_.get(); } - /*! - * \return whether the object is defined. - */ - bool defined() const { return data_ != nullptr; } - /*! \return the internal object pointer */ - const Object* get() const { return data_.get(); } - /*! \return the internal object pointer */ - const Object* operator->() const { return get(); } - /*! \return whether the reference is unique */ - bool unique() const { return data_.unique(); } - /*! \return The use count of the ptr, for debug purposes */ - int use_count() const { return data_.use_count(); } - - /*! - * \brief Try to downcast the internal Object to a - * raw pointer of a corresponding type. - * - * The function will return a nullptr if the cast failed. - * - * if (const AddNode *ptr = node_ref.as()) { - * // This is an add node - * } - * - * \tparam ObjectType the target type, must be a subtype of Object - * \return The pointer to the requested type. - */ - template >> - const ObjectType* as() const { - if (data_ != nullptr && data_->IsInstance()) { - return static_cast(data_.get()); - } else { - return nullptr; - } - } - - /*! - * \brief Try to downcast the ObjectRef to Optional of the requested type. - * - * The function will return a std::nullopt if the cast or if the pointer is nullptr. - * - * \tparam ObjectRefType the target type, must be a subtype of ObjectRef' - * \return The optional value of the requested type. - */ - template >> - TVM_FFI_INLINE std::optional as() const { - if (data_ != nullptr) { - if (data_->IsInstance()) { - ObjectRefType ref(UnsafeInit{}); - ref.data_ = data_; - return ref; - } else { - return std::nullopt; - } - } else { - return std::nullopt; - } - } - - /*! - * \brief Get the type index of the ObjectRef - * \return The type index of the ObjectRef - */ - int32_t type_index() const { - return data_ != nullptr ? data_->type_index() : TypeIndex::kTVMFFINone; - } - - /*! - * \brief Get the type key of the ObjectRef - * \return The type key of the ObjectRef - */ - std::string GetTypeKey() const { - return data_ != nullptr ? data_->GetTypeKey() : StaticTypeKey::kTVMFFINone; - } - - /*! \brief type indicate the container type. */ - using ContainerType = Object; - /*! \brief Whether the reference can point to nullptr */ - static constexpr bool _type_is_nullable = true; - - protected: - /*! \brief Internal pointer that backs the reference. */ - ObjectPtr data_; - /*! \return return a mutable internal ptr, can be used by sub-classes. */ - Object* get_mutable() const { return data_.get(); } - // friend classes. - friend struct ObjectPtrHash; - friend struct tvm::ffi::details::ObjectUnsafe; -}; - -// forward delcare variant -template -class Variant; - -/*! \brief ObjectRef hash functor */ -struct ObjectPtrHash { - size_t operator()(const ObjectRef& a) const { return operator()(a.data_); } - - template - size_t operator()(const ObjectPtr& a) const { - return std::hash()(a.get()); - } - - template - TVM_FFI_INLINE size_t operator()(const Variant& a) const; -}; - -/*! \brief ObjectRef equal functor */ -struct ObjectPtrEqual { - bool operator()(const ObjectRef& a, const ObjectRef& b) const { return a.same_as(b); } - - template - bool operator()(const ObjectPtr& a, const ObjectPtr& b) const { - return a == b; - } - - template - TVM_FFI_INLINE bool operator()(const Variant& a, const Variant& b) const; -}; - -/// \cond Doxygen_Suppress -#define TVM_FFI_REGISTER_STATIC_TYPE_INFO(TypeName, ParentType) \ - static constexpr int32_t _type_depth = ParentType::_type_depth + 1; \ - static int32_t _GetOrAllocRuntimeTypeIndex() { \ - static_assert(!ParentType::_type_final, "ParentType marked as final"); \ - static_assert(TypeName::_type_child_slots == 0 || ParentType::_type_child_slots == 0 || \ - TypeName::_type_child_slots < ParentType::_type_child_slots, \ - "Need to set _type_child_slots when parent specifies it."); \ - TVMFFIByteArray type_key{TypeName::_type_key, \ - std::char_traits::length(TypeName::_type_key)}; \ - static int32_t tindex = TVMFFITypeGetOrAllocIndex( \ - &type_key, TypeName::_type_index, TypeName::_type_depth, TypeName::_type_child_slots, \ - TypeName::_type_child_slots_can_overflow, ParentType::_GetOrAllocRuntimeTypeIndex()); \ - return tindex; \ - } \ - static inline int32_t _register_type_index = _GetOrAllocRuntimeTypeIndex() -/// \endcond - -/*! - * \brief Helper macro to declare object information with static type index. - * - * \param TypeKey The type key of the current type. - * \param TypeName The name of the current type. - * \param ParentType The name of the ParentType - */ -#define TVM_FFI_DECLARE_OBJECT_INFO_STATIC(TypeKey, TypeName, ParentType) \ - static constexpr const char* _type_key = TypeKey; \ - static int32_t RuntimeTypeIndex() { return TypeName::_type_index; } \ - TVM_FFI_REGISTER_STATIC_TYPE_INFO(TypeName, ParentType) - -/*! - * \brief Helper macro to declare object information with type key already defined in class. - * - * \param TypeName The name of the current type. - * \param ParentType The name of the ParentType - */ -#define TVM_FFI_DECLARE_OBJECT_INFO_PREDEFINED_TYPE_KEY(TypeName, ParentType) \ - static constexpr int32_t _type_depth = ParentType::_type_depth + 1; \ - static int32_t _GetOrAllocRuntimeTypeIndex() { \ - static_assert(!ParentType::_type_final, "ParentType marked as final"); \ - static_assert(TypeName::_type_child_slots == 0 || ParentType::_type_child_slots == 0 || \ - TypeName::_type_child_slots < ParentType::_type_child_slots, \ - "Need to set _type_child_slots when parent specifies it."); \ - TVMFFIByteArray type_key{TypeName::_type_key, \ - std::char_traits::length(TypeName::_type_key)}; \ - static int32_t tindex = TVMFFITypeGetOrAllocIndex( \ - &type_key, -1, TypeName::_type_depth, TypeName::_type_child_slots, \ - TypeName::_type_child_slots_can_overflow, ParentType::_GetOrAllocRuntimeTypeIndex()); \ - return tindex; \ - } \ - static int32_t RuntimeTypeIndex() { return _GetOrAllocRuntimeTypeIndex(); } \ - static inline int32_t _type_index = _GetOrAllocRuntimeTypeIndex() - -/*! - * \brief Helper macro to declare object information with dynamic type index. - * - * \param TypeKey The type key of the current type. - * \param TypeName The name of the current type. - * \param ParentType The name of the ParentType - */ -#define TVM_FFI_DECLARE_OBJECT_INFO(TypeKey, TypeName, ParentType) \ - static constexpr const char* _type_key = TypeKey; \ - TVM_FFI_DECLARE_OBJECT_INFO_PREDEFINED_TYPE_KEY(TypeName, ParentType) - -/*! - * \brief Helper macro to declare object information with dynamic type index and is final. - * - * \param TypeKey The type key of the current type. - * \param TypeName The name of the current type. - * \param ParentType The name of the ParentType - */ -#define TVM_FFI_DECLARE_OBJECT_INFO_FINAL(TypeKey, TypeName, ParentType) \ - static const constexpr int _type_child_slots [[maybe_unused]] = 0; \ - static const constexpr bool _type_final [[maybe_unused]] = true; \ - TVM_FFI_DECLARE_OBJECT_INFO(TypeKey, TypeName, ParentType) - -/*! - * \brief Define object reference methods. - * - * \param TypeName The object type name - * \param ParentType The parent type of the objectref - * \param ObjectName The type name of the object. - * - * \note This macro also defines the default constructor that puts the ObjectRef - * in undefined state initially. - */ -#define TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(TypeName, ParentType, ObjectName) \ - TypeName() = default; \ - explicit TypeName(::tvm::ffi::ObjectPtr n) : ParentType(n) {} \ - explicit TypeName(::tvm::ffi::UnsafeInit tag) : ParentType(tag) {} \ - TVM_FFI_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName) \ - using __PtrType = std::conditional_t; \ - __PtrType operator->() const { return static_cast<__PtrType>(data_.get()); } \ - __PtrType get() const { return static_cast<__PtrType>(data_.get()); } \ - static constexpr bool _type_is_nullable = true; \ - using ContainerType = ObjectName - -/*! - * \brief Define object reference methods do not have undefined state. - * - * \param TypeName The object type name - * \param ParentType The parent type of the objectref - * \param ObjectName The type name of the object. - */ -#define TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(TypeName, ParentType, ObjectName) \ - explicit TypeName(::tvm::ffi::UnsafeInit tag) : ParentType(tag) {} \ - TVM_FFI_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName) \ - using __PtrType = std::conditional_t; \ - __PtrType operator->() const { return static_cast<__PtrType>(data_.get()); } \ - __PtrType get() const { return static_cast<__PtrType>(data_.get()); } \ - static constexpr bool _type_is_nullable = false; \ - using ContainerType = ObjectName - -namespace details { - -template -TVM_FFI_INLINE bool IsObjectInstance(int32_t object_type_index) { - static_assert(std::is_base_of_v); - // Everything is a subclass of object. - if constexpr (std::is_same::value) { - return true; - } else if constexpr (TargetType::_type_final) { - // if the target type is a final type - // then we only need to check the equivalence. - return object_type_index == TargetType::RuntimeTypeIndex(); - } else { - // Explicitly enclose in else to eliminate this branch early in compilation. - // if target type is a non-leaf type - // Check if type index falls into the range of reserved slots. - int32_t target_type_index = TargetType::RuntimeTypeIndex(); - int32_t begin = target_type_index; - // The condition will be optimized by constant-folding. - if constexpr (TargetType::_type_child_slots != 0) { - // total_slots = child_slots + 1 (including self) - int32_t end = begin + TargetType::_type_child_slots + 1; - if (object_type_index >= begin && object_type_index < end) return true; - } else { - if (object_type_index == begin) return true; - } - if constexpr (TargetType::_type_child_slots_can_overflow) { - // Invariance: parent index is always smaller than the child. - if (object_type_index < target_type_index) return false; - // Do a runtime lookup of type information - // the function checks that the info exists - const TypeInfo* type_info = TVMFFIGetTypeInfo(object_type_index); - return (type_info->type_depth > TargetType::_type_depth && - type_info->type_acenstors[TargetType::_type_depth]->type_index == target_type_index); - } else { - return false; - } - } -} - -/*! - * \brief Namespace to internally manipulate object class. - * \note These functions are only supposed to be used by internal - * implementations and not external users of the tvm::ffi - */ -struct ObjectUnsafe { - // NOTE: get ffi header from an object - TVM_FFI_INLINE static TVMFFIObject* GetHeader(const Object* src) { - return const_cast(&(src->header_)); - } - - template - TVM_FFI_INLINE static int64_t GetObjectOffsetToSubclass() { - return (reinterpret_cast(&(static_cast(nullptr)->header_)) - - reinterpret_cast(&(static_cast(nullptr)->header_))); - } - - template - TVM_FFI_INLINE static T ObjectRefFromObjectPtr(const ObjectPtr& ptr) { - T ref(UnsafeInit{}); - ref.data_ = ptr; - return ref; - } - - template - TVM_FFI_INLINE static T ObjectRefFromObjectPtr(ObjectPtr&& ptr) { - T ref(UnsafeInit{}); - ref.data_ = std::move(ptr); - return ref; - } - - template - TVM_FFI_INLINE static ObjectPtr ObjectPtrFromObjectRef(const ObjectRef& ref) { - if constexpr (std::is_same_v) { - return ref.data_; - } else { - return tvm::ffi::ObjectPtr(ref.data_.data_); - } - } - - template - TVM_FFI_INLINE static ObjectPtr ObjectPtrFromObjectRef(ObjectRef&& ref) { - if constexpr (std::is_same_v) { - return std::move(ref.data_); - } else { - ObjectPtr result; - result.data_ = std::move(ref.data_.data_); - ref.data_.data_ = nullptr; - return result; - } - } - - template - TVM_FFI_INLINE static ObjectPtr ObjectPtrFromOwned(Object* raw_ptr) { - tvm::ffi::ObjectPtr ptr; - ptr.data_ = raw_ptr; - return ptr; - } - - template - TVM_FFI_INLINE static ObjectPtr ObjectPtrFromOwned(TVMFFIObject* obj_ptr) { - return ObjectPtrFromOwned(reinterpret_cast(obj_ptr)); - } - - template - TVM_FFI_INLINE static T* RawObjectPtrFromUnowned(TVMFFIObject* obj_ptr) { - // NOTE: this is important to first cast to Object* - // then cast back to T* because objptr and tptr may not be the same - // depending on how sub-class allocates the space. - return static_cast(reinterpret_cast(obj_ptr)); - } - - // Create ObjectPtr from unowned ptr - template - TVM_FFI_INLINE static ObjectPtr ObjectPtrFromUnowned(Object* raw_ptr) { - return tvm::ffi::ObjectPtr(raw_ptr); - } - - template - TVM_FFI_INLINE static ObjectPtr ObjectPtrFromUnowned(TVMFFIObject* obj_ptr) { - return tvm::ffi::ObjectPtr(reinterpret_cast(obj_ptr)); - } - - TVM_FFI_INLINE static void DecRefObjectHandle(TVMFFIObjectHandle handle) { - reinterpret_cast(handle)->DecRef(); - } - - TVM_FFI_INLINE static void IncRefObjectHandle(TVMFFIObjectHandle handle) { - reinterpret_cast(handle)->IncRef(); - } - - TVM_FFI_INLINE static Object* RawObjectPtrFromObjectRef(const ObjectRef& src) { - return src.data_.data_; - } - - TVM_FFI_INLINE static TVMFFIObject* TVMFFIObjectPtrFromObjectRef(const ObjectRef& src) { - return GetHeader(src.data_.data_); - } - - template - TVM_FFI_INLINE static TVMFFIObject* TVMFFIObjectPtrFromObjectPtr(const ObjectPtr& src) { - return GetHeader(src.data_); - } - - template - TVM_FFI_INLINE static TVMFFIObject* MoveObjectPtrToTVMFFIObjectPtr(ObjectPtr&& src) { - Object* obj_ptr = src.data_; - src.data_ = nullptr; - return GetHeader(obj_ptr); - } - - TVM_FFI_INLINE static TVMFFIObject* MoveObjectRefToTVMFFIObjectPtr(ObjectRef&& src) { - Object* obj_ptr = src.data_.data_; - src.data_.data_ = nullptr; - return GetHeader(obj_ptr); - } -}; -} // namespace details -} // namespace ffi -} // namespace tvm -#endif // TVM_FFI_OBJECT_H_ diff --git a/ffi/include/tvm/ffi/optional.h b/ffi/include/tvm/ffi/optional.h deleted file mode 100644 index f370a178502e..000000000000 --- a/ffi/include/tvm/ffi/optional.h +++ /dev/null @@ -1,419 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file tvm/ffi/optional.h - * \brief Runtime Optional container types. - * \note Optional specializes for T is ObjectRef and used nullptr to indicate nullopt. - */ -#ifndef TVM_FFI_OPTIONAL_H_ -#define TVM_FFI_OPTIONAL_H_ - -#include -#include -#include - -#include -#include -#include - -namespace tvm { -namespace ffi { - -// Note: We place optional in tvm/ffi instead of tvm/ffi/container -// because optional itself is an inherent core component of the FFI system. -/// \cond Doxygen_Suppress -template -inline constexpr bool is_optional_type_v = false; - -template -inline constexpr bool is_optional_type_v> = true; - -// we can safely used ptr based optional for ObjectRef types -// that do not have additional data members and virtual functions. -template -inline constexpr bool use_ptr_based_optional_v = - (std::is_base_of_v && !is_optional_type_v); -/// \endcond - -// Specialization for non-ObjectRef types. -// simply fallback to std::optional -template -class Optional && !std::is_same_v && - !std::is_same_v>> { - public: - // default constructors. - Optional() = default; - Optional(const Optional& other) : data_(other.data_) {} - Optional(Optional&& other) : data_(std::move(other.data_)) {} - Optional(std::optional other) : data_(std::move(other)) {} // NOLINT(*) - Optional(std::nullopt_t) {} // NOLINT(*) - // normal value handling. - Optional(T other) // NOLINT(*) - : data_(std::move(other)) {} - - TVM_FFI_INLINE Optional& operator=(const Optional& other) { - data_ = other.data_; - return *this; - } - - TVM_FFI_INLINE Optional& operator=(Optional&& other) { - data_ = std::move(other.data_); - return *this; - } - - TVM_FFI_INLINE Optional& operator=(T other) { - data_ = std::move(other); - return *this; - } - - TVM_FFI_INLINE Optional& operator=(std::nullopt_t) { - data_ = std::nullopt; - return *this; - } - - TVM_FFI_INLINE const T& value() const& { - if (!data_.has_value()) { - TVM_FFI_THROW(RuntimeError) << "Back optional access"; - } - return *data_; - } - - TVM_FFI_INLINE T&& value() && { - if (!data_.has_value()) { - TVM_FFI_THROW(RuntimeError) << "Back optional access"; - } - return *std::move(data_); - } - - template > - TVM_FFI_INLINE T value_or(U&& default_value) const { - return data_.value_or(std::forward(default_value)); - } - - TVM_FFI_INLINE explicit operator bool() const noexcept { return data_.has_value(); } - - TVM_FFI_INLINE bool has_value() const noexcept { return data_.has_value(); } - - TVM_FFI_INLINE bool operator==(const Optional& other) const { return data_ == other.data_; } - - TVM_FFI_INLINE bool operator!=(const Optional& other) const { return data_ != other.data_; } - - template - TVM_FFI_INLINE bool operator==(const U& other) const { - return data_ == other; - } - template - TVM_FFI_INLINE bool operator!=(const U& other) const { - return data_ != other; - } - - /*! - * \brief Direct access to the value. - * \return the xvalue reference to the stored value. - * \note only use this function after checking has_value() - */ - TVM_FFI_INLINE T&& operator*() && noexcept { return *std::move(data_); } - /*! - * \brief Direct access to the value. - * \return the const reference to the stored value. - * \note only use this function after checking has_value() - */ - TVM_FFI_INLINE const T& operator*() const& noexcept { return *data_; } - - private: - std::optional data_; -}; - -// Specialization for String type, use nullptr to indicate nullopt -template -class Optional || std::is_same_v>> { - public: - // default constructors. - Optional() = default; - Optional(const Optional& other) : data_(other.data_) {} - Optional(Optional&& other) : data_(std::move(other.data_)) {} - Optional(std::nullopt_t) {} // NOLINT(*) - // normal value handling. - Optional(T other) // NOLINT(*) - : data_(std::move(other)) {} - - TVM_FFI_INLINE Optional& operator=(const Optional& other) { - data_ = other.data_; - return *this; - } - - TVM_FFI_INLINE Optional& operator=(Optional&& other) { - data_ = std::move(other.data_); - return *this; - } - - TVM_FFI_INLINE Optional& operator=(T other) { - data_ = std::move(other); - return *this; - } - - TVM_FFI_INLINE Optional& operator=(std::nullopt_t) { - T(details::BytesBaseCell(std::nullopt)).swap(data_); - return *this; - } - - TVM_FFI_INLINE const T& value() const& { - if (data_.data_ == std::nullopt) { - TVM_FFI_THROW(RuntimeError) << "Back optional access"; - } - return data_; - } - - TVM_FFI_INLINE String&& value() && { - if (data_.data_ == std::nullopt) { - TVM_FFI_THROW(RuntimeError) << "Back optional access"; - } - return std::move(data_); - } - - template - TVM_FFI_INLINE T value_or(U&& default_value) const { - if (data_.data_ == std::nullopt) { - return std::forward(default_value); - } - return data_; - } - - TVM_FFI_INLINE explicit operator bool() const noexcept { return data_.data_ != std::nullopt; } - - TVM_FFI_INLINE bool has_value() const noexcept { return data_.data_ != std::nullopt; } - - TVM_FFI_INLINE bool operator==(const Optional& other) const { - if (data_.data_ == std::nullopt) { - return other.data_.data_ == std::nullopt; - } - if (other.data_.data_ == std::nullopt) { - return false; - } - return data_ == other.data_; - } - - TVM_FFI_INLINE bool operator!=(const Optional& other) const { return !(*this == other); } - - template - TVM_FFI_INLINE bool operator==(const U& other) const { - if constexpr (std::is_same_v) { - return data_.data_ == std::nullopt; - } else { - if (data_.data_ == std::nullopt) { - return false; - } - return data_ == other; - } - } - template - TVM_FFI_INLINE bool operator!=(const U& other) const { - if constexpr (std::is_same_v) { - return data_.data_ != std::nullopt; - } else { - if (data_.data_ == std::nullopt) { - return true; - } - return data_ != other; - } - } - - /*! - * \brief Direct access to the value. - * \return the xvalue reference to the stored value. - * \note only use this function after checking has_value() - */ - TVM_FFI_INLINE T&& operator*() && noexcept { return std::move(data_); } - /*! - * \brief Direct access to the value. - * \return the const reference to the stored value. - * \note only use this function after checking has_value() - */ - TVM_FFI_INLINE const T& operator*() const& noexcept { return data_; } - - private: - // this is a private initializer - T data_{details::BytesBaseCell(std::nullopt)}; -}; - -// Specialization for ObjectRef types. -// nullptr is treated as std::nullopt. -template -class Optional>> : public ObjectRef { - public: - using ContainerType = typename T::ContainerType; - Optional() = default; - Optional(const Optional& other) : ObjectRef(other.data_) {} - Optional(Optional&& other) : ObjectRef(std::move(other.data_)) {} - explicit Optional(ffi::UnsafeInit tag) : ObjectRef(tag) {} - // nullopt hanlding - Optional(std::nullopt_t) {} // NOLINT(*) - - // handle conversion from std::optional - Optional(std::optional other) { // NOLINT(*) - if (other.has_value()) { - *this = *std::move(other); - } - } - // normal value handling. - Optional(T other) // NOLINT(*) - : ObjectRef(std::move(other)) {} - - TVM_FFI_INLINE Optional& operator=(T other) { - ObjectRef::operator=(std::move(other)); - return *this; - } - - TVM_FFI_INLINE Optional& operator=(const Optional& other) { - data_ = other.data_; - return *this; - } - - TVM_FFI_INLINE Optional& operator=(std::nullptr_t) { - data_ = nullptr; - return *this; - } - - TVM_FFI_INLINE Optional& operator=(Optional&& other) { - data_ = std::move(other.data_); - return *this; - } - - TVM_FFI_INLINE T value() const& { - if (data_ == nullptr) { - TVM_FFI_THROW(RuntimeError) << "Back optional access"; - } - return details::ObjectUnsafe::ObjectRefFromObjectPtr(data_); - } - - TVM_FFI_INLINE T value() && { - if (data_ == nullptr) { - TVM_FFI_THROW(RuntimeError) << "Back optional access"; - } - return details::ObjectUnsafe::ObjectRefFromObjectPtr(std::move(data_)); - } - - template > - TVM_FFI_INLINE T value_or(U&& default_value) const { - return data_ != nullptr ? details::ObjectUnsafe::ObjectRefFromObjectPtr(data_) - : T(std::forward(default_value)); - } - - TVM_FFI_INLINE explicit operator bool() const { return data_ != nullptr; } - - TVM_FFI_INLINE bool has_value() const { return data_ != nullptr; } - - /*! - * \brief Direct access to the value. - * \return the const reference to the stored value. - * \note only use this function after checking has_value() - */ - TVM_FFI_INLINE T operator*() const& noexcept { - return details::ObjectUnsafe::ObjectRefFromObjectPtr(data_); - } - - /*! - * \brief Direct access to the value. - * \return the const reference to the stored value. - * \note only use this function after checking has_value() - */ - TVM_FFI_INLINE T operator*() && noexcept { - return details::ObjectUnsafe::ObjectRefFromObjectPtr(std::move(data_)); - } - - TVM_FFI_INLINE bool operator==(std::nullptr_t) const noexcept { return !has_value(); } - TVM_FFI_INLINE bool operator!=(std::nullptr_t) const noexcept { return has_value(); } - - // operator overloadings - TVM_FFI_INLINE auto operator==(const Optional& other) const { - // support case where sub-class returns a symbolic ref type. - return EQToOptional(other); - } - TVM_FFI_INLINE auto operator!=(const Optional& other) const { return NEToOptional(other); } - - TVM_FFI_INLINE auto operator==(const std::optional& other) const { - // support case where sub-class returns a symbolic ref type. - return EQToOptional(other); - } - TVM_FFI_INLINE auto operator!=(const std::optional& other) const { - return NEToOptional(other); - } - - TVM_FFI_INLINE auto operator==(const T& other) const { - using RetType = decltype(value() == other); - if (same_as(other)) return RetType(true); - if (has_value()) return operator*() == other; - return RetType(false); - } - - TVM_FFI_INLINE auto operator!=(const T& other) const { return !(*this == other); } - - template - TVM_FFI_INLINE auto operator==(const U& other) const { - using RetType = decltype(value() == other); - if (!has_value()) return RetType(false); - return operator*() == other; - } - - template - TVM_FFI_INLINE auto operator!=(const U& other) const { - using RetType = decltype(value() != other); - if (!has_value()) return RetType(true); - return operator*() != other; - } - - /*! - * \return The internal object pointer with container type of T. - * \note This function do not perform not-null checking. - */ - TVM_FFI_INLINE const ContainerType* get() const { - return static_cast(data_.get()); - } - - private: - template - TVM_FFI_INLINE auto EQToOptional(const U& other) const { - // support case where sub-class returns a symbolic ref type. - using RetType = decltype(operator*() == *other); - if (same_as(other)) return RetType(true); - if (has_value() && other.has_value()) { - return operator*() == *other; - } else { - // one of them is nullptr. - return RetType(false); - } - } - - template - TVM_FFI_INLINE auto NEToOptional(const U& other) const { - // support case where sub-class returns a symbolic ref type. - using RetType = decltype(operator*() != *other); - if (same_as(other)) return RetType(false); - if (has_value() && other.has_value()) { - return operator*() != *other; - } else { - // one of them is nullptr. - return RetType(true); - } - } -}; -} // namespace ffi -} // namespace tvm -#endif // TVM_FFI_OPTIONAL_H_ diff --git a/ffi/include/tvm/ffi/reflection/access_path.h b/ffi/include/tvm/ffi/reflection/access_path.h deleted file mode 100644 index ea102e144ab3..000000000000 --- a/ffi/include/tvm/ffi/reflection/access_path.h +++ /dev/null @@ -1,440 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -/*! - * \file tvm/ffi/reflection/registry.h - * \brief Registry of reflection metadata. - */ -#ifndef TVM_FFI_REFLECTION_ACCESS_PATH_H_ -#define TVM_FFI_REFLECTION_ACCESS_PATH_H_ - -#include -#include -#include -#include -#include -#include -#include - -#include - -namespace tvm { -namespace ffi { -namespace reflection { - -/*! - * \brief The kind of the access pattern. - */ -enum class AccessKind : int32_t { - /*! \brief Object attribute access. */ - kAttr = 0, - /*! \brief Array item access. */ - kArrayItem = 1, - /*! \brief Map item access. */ - kMapItem = 2, - // the following two are used for error reporting when - // the supposed access field is not available - /*! \brief Object attribute missing access. */ - kAttrMissing = 3, - /*! \brief Array item missing access. */ - kArrayItemMissing = 4, - /*! \brief Map item missing access. */ - kMapItemMissing = 5, -}; - -class AccessStep; - -/*! - * \brief Represent a single step in object field, map key, array index access. - */ -class AccessStepObj : public Object { - public: - /*! - * \brief The kind of the access pattern. - */ - AccessKind kind; - /*! - * \brief The access key - * \note for array access, it will always be integer - * for field access, it will be string - */ - Any key; - - // default constructor to enable auto-serialization - AccessStepObj() = default; - /*! - * \brief Constructor - * \param kind The kind of the access step. - * \param key The key of the access step. - */ - AccessStepObj(AccessKind kind, Any key) : kind(kind), key(key) {} - - /*! - * \brief Deep check if two steps are equal. - * \param other The other step to compare with. - * \return True if the two steps are equal, false otherwise. - */ - inline bool StepEqual(const AccessStep& other) const; - - /// \cond Doxygen_Suppress - static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindConstTreeNode; - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("ffi.reflection.AccessStep", AccessStepObj, Object); - /// \endcond -}; - -/*! - * \brief ObjectRef class of AccessStepObj. - * - * \sa AccessStepObj - */ -class AccessStep : public ObjectRef { - public: - /*! - * \brief Constructor - * \param kind The kind of the access step. - * \param key The key of the access step. - * \return The access step. - */ - AccessStep(AccessKind kind, Any key) : ObjectRef(make_object(kind, key)) {} - - /*! - * \brief Create an access step for a object attribute access. - * \param field_name The name of the field to access. - * \return The access step. - */ - static AccessStep Attr(String field_name) { return AccessStep(AccessKind::kAttr, field_name); } - - /*! - * \brief Create an access step for a object attribute missing access. - * \param field_name The name of the field to access. - * \return The access step. - */ - static AccessStep AttrMissing(String field_name) { - return AccessStep(AccessKind::kAttrMissing, field_name); - } - - /*! - * \brief Create an access step for a array item access. - * \param index The index of the array item to access. - * \return The access step. - */ - static AccessStep ArrayItem(int64_t index) { return AccessStep(AccessKind::kArrayItem, index); } - - /*! - * \brief Create an access step for a array item missing access. - * \param index The index of the array item to access. - * \return The access step. - */ - static AccessStep ArrayItemMissing(int64_t index) { - return AccessStep(AccessKind::kArrayItemMissing, index); - } - - /*! - * \brief Create an access step for a map item access. - * \param key The key of the map item to access. - * \return The access step. - */ - static AccessStep MapItem(Any key) { return AccessStep(AccessKind::kMapItem, key); } - - /*! - * \brief Create an access step for a map item missing access. - * \param key The key of the map item to access. - * \return The access step. - */ - static AccessStep MapItemMissing(Any key = nullptr) { - return AccessStep(AccessKind::kMapItemMissing, key); - } - - /// \cond Doxygen_Suppress - TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(AccessStep, ObjectRef, AccessStepObj); - /// \endcond -}; - -inline bool AccessStepObj::StepEqual(const AccessStep& other) const { - return this->kind == other->kind && AnyEqual()(this->key, other->key); -} - -// forward declaration -class AccessPath; - -/*! - * \brief ObjectRef class of AccessPathObj. - * - * \sa AccessPathObj - */ -class AccessPathObj : public Object { - public: - /*! - * \brief The parent of the access path. - * - * This parent-pointing tree structure is more space efficient when - * representing multiple paths that share a common prefix. - * - * \note Empty for root. - */ - Optional parent; - /*! - * \brief The current of the access path. - * \note Empty for root. - */ - Optional step; - /*! - * \brief The current depth of the access path, 0 for root - */ - int32_t depth; - - // default constructor to enable auto-serialization - AccessPathObj() = default; - /*! - * \brief Constructor for the access path. - * \param parent The parent of the access path. - * \param step The current step of the access path. - * \param depth The current depth of the access path. - */ - AccessPathObj(Optional parent, Optional step, int32_t depth) - : parent(parent), step(step), depth(depth) {} - - /*! - * \brief Get the parent of the access path. - * \return The parent of the access path. - */ - inline Optional GetParent() const; - - /*! - * \brief Extend the access path with a new step. - * \param step The step to extend the access path with. - * \return The extended access path. - */ - inline AccessPath Extend(AccessStep step) const; - - /*! - * \brief Extend the access path with an object attribute access. - * \param field_name The name of the field to access. - * \return The extended access path. - */ - inline AccessPath Attr(String field_name) const; - - /*! - * \brief Extend the access path with an object attribute missing access. - * \param field_name The name of the field to access. - * \return The extended access path. - */ - inline AccessPath AttrMissing(String field_name) const; - - /*! - * \brief Extend the access path with an array item access. - * \param index The index of the array item to access. - * \return The extended access path. - */ - inline AccessPath ArrayItem(int64_t index) const; - - /*! - * \brief Extend the access path with an array item missing access. - * \param index The index of the array item to access. - * \return The extended access path. - */ - inline AccessPath ArrayItemMissing(int64_t index) const; - - /*! - * \brief Extend the access path with a map item access. - * \param key The key of the map item to access. - * \return The extended access path. - */ - inline AccessPath MapItem(Any key) const; - - /*! - * \brief Extend the access path with a map item missing access. - * \param key The key of the map item to access. - * \return The extended access path. - */ - inline AccessPath MapItemMissing(Any key) const; - - /*! - * \brief Get the array of steps that corresponds to the access path. - * \return The array of steps that corresponds to the access path. - */ - inline Array ToSteps() const; - - /*! - * \brief Check if two paths are equal by deep comparing the steps. - * \param other The other path to compare with. - * \return True if the two paths are equal, false otherwise. - */ - inline bool PathEqual(const AccessPath& other) const; - - /*! - * \brief Check if this path is a prefix of another path. - * \param other The other path to compare with. - * \return True if this path is a prefix of the other path, false otherwise. - */ - inline bool IsPrefixOf(const AccessPath& other) const; - - /// \cond Doxygen_Suppress - static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindConstTreeNode; - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("ffi.reflection.AccessPath", AccessPathObj, Object); - /// \endcond - - private: - static bool PathEqual(const AccessPathObj* lhs, const AccessPathObj* rhs) { - // fast path for same pointer - if (lhs == rhs) return true; - if (lhs->depth != rhs->depth) return false; - // do deep equality checks - while (lhs->parent.has_value()) { - TVM_FFI_ICHECK(rhs->parent.has_value()); - TVM_FFI_ICHECK(lhs->step.has_value()); - TVM_FFI_ICHECK(rhs->step.has_value()); - if (!(*lhs->step)->StepEqual(*(rhs->step))) { - return false; - } - lhs = static_cast(lhs->parent.get()); - rhs = static_cast(rhs->parent.get()); - // fast path for same pointer - if (lhs == rhs) return true; - TVM_FFI_ICHECK(lhs != nullptr); - TVM_FFI_ICHECK(rhs != nullptr); - } - return true; - } -}; - -/*! - * \brief ObjectRef class of AccessPath. - * - * \sa AccessPathObj - */ -class AccessPath : public ObjectRef { - public: - /*! - * \brief Create an access path from an iterator range of steps. - * \param begin The beginning of the iterator range. - * \param end The end of the iterator range. - * \return The access path. - */ - template - static AccessPath FromSteps(Iter begin, Iter end) { - AccessPath path = AccessPath::Root(); - for (Iter it = begin; it != end; ++it) { - path = path->Extend(*it); - } - return path; - } - /*! - * \brief Create an access path from an array of steps. - * \param steps The array of steps. - * \return The access path. - */ - static AccessPath FromSteps(Array steps) { - AccessPath path = AccessPath::Root(); - for (AccessStep step : steps) { - path = path->Extend(step); - } - return path; - } - - /*! - * \brief Create a root access path. - * \return The root access path. - */ - static AccessPath Root() { - return AccessPath(make_object(std::nullopt, std::nullopt, 0)); - } - - /// \cond Doxygen_Suppress - TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(AccessPath, ObjectRef, AccessPathObj); - /// \endcond - - private: - friend class AccessPathObj; - explicit AccessPath(ObjectPtr ptr) : ObjectRef(ptr) {} -}; - -/*! - * \brief The pair of access paths. - */ -using AccessPathPair = Tuple; - -inline Optional AccessPathObj::GetParent() const { - if (auto opt_parent = this->parent.as()) { - return opt_parent; - } - return std::nullopt; -} - -inline AccessPath AccessPathObj::Extend(AccessStep step) const { - return AccessPath(make_object(GetRef(this), step, this->depth + 1)); -} - -inline AccessPath AccessPathObj::Attr(String field_name) const { - return this->Extend(AccessStep::Attr(field_name)); -} - -inline AccessPath AccessPathObj::AttrMissing(String field_name) const { - return this->Extend(AccessStep::AttrMissing(field_name)); -} - -inline AccessPath AccessPathObj::ArrayItem(int64_t index) const { - return this->Extend(AccessStep::ArrayItem(index)); -} - -inline AccessPath AccessPathObj::ArrayItemMissing(int64_t index) const { - return this->Extend(AccessStep::ArrayItemMissing(index)); -} - -inline AccessPath AccessPathObj::MapItem(Any key) const { - return this->Extend(AccessStep::MapItem(key)); -} - -inline AccessPath AccessPathObj::MapItemMissing(Any key) const { - return this->Extend(AccessStep::MapItemMissing(key)); -} - -inline Array AccessPathObj::ToSteps() const { - std::vector reverse_steps; - reverse_steps.reserve(this->depth); - const AccessPathObj* current = this; - while (current->parent.has_value()) { - TVM_FFI_ICHECK(current->step.has_value()); - reverse_steps.push_back(*(current->step)); - current = static_cast(current->parent.get()); - TVM_FFI_ICHECK(current != nullptr); - } - return Array(reverse_steps.rbegin(), reverse_steps.rend()); -} - -inline bool AccessPathObj::PathEqual(const AccessPath& other) const { - return PathEqual(this, other.get()); -} - -inline bool AccessPathObj::IsPrefixOf(const AccessPath& other) const { - if (this->depth > other->depth) { - return false; - } - const AccessPathObj* rhs_path = other.get(); - while (rhs_path->depth > this->depth) { - TVM_FFI_ICHECK(rhs_path->parent.has_value()); - rhs_path = static_cast(rhs_path->parent.get()); - } - return PathEqual(this, rhs_path); -} - -} // namespace reflection -} // namespace ffi -} // namespace tvm - -#endif // TVM_FFI_REFLECTION_ACCESS_PATH_H_ diff --git a/ffi/include/tvm/ffi/reflection/accessor.h b/ffi/include/tvm/ffi/reflection/accessor.h deleted file mode 100644 index 5fadd0985daf..000000000000 --- a/ffi/include/tvm/ffi/reflection/accessor.h +++ /dev/null @@ -1,260 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -/*! - * \file tvm/ffi/reflection/accessor.h - * \brief Reflection-based accessor for object fields and methods. - */ -#ifndef TVM_FFI_REFLECTION_ACCESSOR_H_ -#define TVM_FFI_REFLECTION_ACCESSOR_H_ - -#include -#include -#include -#include - -#include -#include - -namespace tvm { -namespace ffi { -namespace reflection { - -/*! - * \brief helper function to get reflection field info by type key and field name - */ -inline const TVMFFIFieldInfo* GetFieldInfo(std::string_view type_key, const char* field_name) { - int32_t type_index; - TVMFFIByteArray type_key_array = {type_key.data(), type_key.size()}; - TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeKeyToIndex(&type_key_array, &type_index)); - const TypeInfo* info = TVMFFIGetTypeInfo(type_index); - for (int32_t i = 0; i < info->num_fields; ++i) { - if (std::strncmp(info->fields[i].name.data, field_name, info->fields[i].name.size) == 0) { - return &(info->fields[i]); - } - } - TVM_FFI_THROW(RuntimeError) << "Cannot find field `" << field_name << "` in " << type_key; - TVM_FFI_UNREACHABLE(); -} - -/*! - * \brief helper wrapper class to obtain a getter. - */ -class FieldGetter { - public: - /*! - * \brief Constructor - * \param field_info The field info. - */ - explicit FieldGetter(const TVMFFIFieldInfo* field_info) : field_info_(field_info) {} - - /*! - * \brief Constructor - * \param type_key The type key. - * \param field_name The name of the field. - */ - explicit FieldGetter(std::string_view type_key, const char* field_name) - : FieldGetter(GetFieldInfo(type_key, field_name)) {} - - /*! - * \brief Get the value of the field - * \param obj_ptr The object pointer. - * \return The value of the field. - */ - Any operator()(const Object* obj_ptr) const { - Any result; - const void* addr = reinterpret_cast(obj_ptr) + field_info_->offset; - TVM_FFI_CHECK_SAFE_CALL( - field_info_->getter(const_cast(addr), reinterpret_cast(&result))); - return result; - } - - Any operator()(const ObjectPtr& obj_ptr) const { return operator()(obj_ptr.get()); } - - Any operator()(const ObjectRef& obj) const { return operator()(obj.get()); } - - private: - const TVMFFIFieldInfo* field_info_; -}; - -/*! - * \brief helper wrapper class to obtain a setter. - */ -class FieldSetter { - public: - /*! - * \brief Constructor - * \param field_info The field info. - */ - explicit FieldSetter(const TVMFFIFieldInfo* field_info) : field_info_(field_info) {} - - /*! - * \brief Constructor - * \param type_key The type key. - * \param field_name The name of the field. - */ - explicit FieldSetter(std::string_view type_key, const char* field_name) - : FieldSetter(GetFieldInfo(type_key, field_name)) {} - - /*! - * \brief Set the value of the field - * \param obj_ptr The object pointer. - * \param value The value to be set. - */ - void operator()(const Object* obj_ptr, AnyView value) const { - const void* addr = reinterpret_cast(obj_ptr) + field_info_->offset; - TVM_FFI_CHECK_SAFE_CALL( - field_info_->setter(const_cast(addr), reinterpret_cast(&value))); - } - - void operator()(const ObjectPtr& obj_ptr, AnyView value) const { - operator()(obj_ptr.get(), value); - } - - void operator()(const ObjectRef& obj, AnyView value) const { operator()(obj.get(), value); } - - private: - const TVMFFIFieldInfo* field_info_; -}; - -/*! - * \brief Helper class to get type attribute column. - */ -class TypeAttrColumn { - public: - /*! - * \brief Constructor - * \param attr_name The name of the type attribute. - */ - explicit TypeAttrColumn(std::string_view attr_name) { - TVMFFIByteArray attr_name_array = {attr_name.data(), attr_name.size()}; - column_ = TVMFFIGetTypeAttrColumn(&attr_name_array); - if (column_ == nullptr) { - TVM_FFI_THROW(RuntimeError) << "Cannot find type attribute " << attr_name; - } - } - /*! - * \brief Get the type attribute column by type index. - * \param type_index The type index. - * \return The type attribute column. - */ - AnyView operator[](int32_t type_index) const { - size_t tindex = static_cast(type_index); - if (tindex >= column_->size) { - return AnyView(); - } - const AnyView* any_view_data = reinterpret_cast(column_->data); - return any_view_data[tindex]; - } - - private: - const TVMFFITypeAttrColumn* column_; -}; - -/*! - * \brief helper function to get reflection method info by type key and method name - * - * \param type_key The type key. - * \param method_name The name of the method. - * \return The method info. - */ -inline const TVMFFIMethodInfo* GetMethodInfo(std::string_view type_key, const char* method_name) { - int32_t type_index; - TVMFFIByteArray type_key_array = {type_key.data(), type_key.size()}; - TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeKeyToIndex(&type_key_array, &type_index)); - const TypeInfo* info = TVMFFIGetTypeInfo(type_index); - for (int32_t i = 0; i < info->num_methods; ++i) { - if (std::strncmp(info->methods[i].name.data, method_name, info->methods[i].name.size) == 0) { - return &(info->methods[i]); - } - } - TVM_FFI_THROW(RuntimeError) << "Cannot find method " << method_name << " in " << type_key; - TVM_FFI_UNREACHABLE(); -} - -/*! - * \brief helper function to get reflection method function by method info - * - * \param type_key The type key. - * \param method_name The name of the method. - * \return The method function. - */ -inline Function GetMethod(std::string_view type_key, const char* method_name) { - const TVMFFIMethodInfo* info = GetMethodInfo(type_key, method_name); - return AnyView::CopyFromTVMFFIAny(info->method).cast(); -} - -/*! - * \brief Visit each field info of the type info and run callback. - * - * \tparam Callback The callback function type. - * - * \param type_info The type info. - * \param callback The callback function. - * - * \note This function calls both the child and parent type info. - */ -template -inline void ForEachFieldInfo(const TypeInfo* type_info, Callback callback) { - using ResultType = decltype(callback(type_info->fields)); - static_assert(std::is_same_v, "Callback must return void"); - // iterate through acenstors in parent to child order - // skip the first one since it is always the root object - for (int i = 1; i < type_info->type_depth; ++i) { - const TVMFFITypeInfo* parent_info = type_info->type_acenstors[i]; - for (int j = 0; j < parent_info->num_fields; ++j) { - callback(parent_info->fields + j); - } - } - for (int i = 0; i < type_info->num_fields; ++i) { - callback(type_info->fields + i); - } -} - -/*! - * \brief Visit each field info of the type info and run callback which returns bool for early stop. - * - * \tparam Callback The callback function type, which returns bool for early stop. - * - * \param type_info The type info. - * \param callback_with_early_stop The callback function. - * \return true if any of early stop is triggered. - * - * \note This function calls both the child and parent type info and can be used for searching. - */ -template -inline bool ForEachFieldInfoWithEarlyStop(const TypeInfo* type_info, - Callback callback_with_early_stop) { - // iterate through acenstors in parent to child order - // skip the first one since it is always the root object - for (int i = 1; i < type_info->type_depth; ++i) { - const TVMFFITypeInfo* parent_info = type_info->type_acenstors[i]; - for (int j = 0; j < parent_info->num_fields; ++j) { - if (callback_with_early_stop(parent_info->fields + j)) return true; - } - } - for (int i = 0; i < type_info->num_fields; ++i) { - if (callback_with_early_stop(type_info->fields + i)) return true; - } - return false; -} - -} // namespace reflection -} // namespace ffi -} // namespace tvm -#endif // TVM_FFI_REFLECTION_ACCESSOR_H_ diff --git a/ffi/include/tvm/ffi/reflection/creator.h b/ffi/include/tvm/ffi/reflection/creator.h deleted file mode 100644 index 774eb8b0b4a9..000000000000 --- a/ffi/include/tvm/ffi/reflection/creator.h +++ /dev/null @@ -1,120 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -/*! - * \file tvm/ffi/reflection/creator.h - * \brief Reflection-based creator to create objects from type key and fields. - */ -#ifndef TVM_FFI_REFLECTION_CREATOR_H_ -#define TVM_FFI_REFLECTION_CREATOR_H_ - -#include -#include -#include -#include - -namespace tvm { -namespace ffi { -namespace reflection { -/*! - * \brief helper wrapper class of TVMFFITypeInfo to create object based on reflection. - */ -class ObjectCreator { - public: - /*! - * \brief Constructor - * \param type_key The type key. - */ - explicit ObjectCreator(std::string_view type_key) - : ObjectCreator(TVMFFIGetTypeInfo(TypeKeyToIndex(type_key))) {} - - /*! - * \brief Constructor - * \param type_info The type info. - */ - explicit ObjectCreator(const TVMFFITypeInfo* type_info) : type_info_(type_info) { - int32_t type_index = type_info->type_index; - if (type_info->metadata == nullptr) { - TVM_FFI_THROW(RuntimeError) << "Type `" << TypeIndexToTypeKey(type_index) - << "` does not have reflection registered"; - } - if (type_info->metadata->creator == nullptr) { - TVM_FFI_THROW(RuntimeError) << "Type `" << TypeIndexToTypeKey(type_index) - << "` does not support default constructor, " - << "as a result cannot be created via reflection"; - } - } - - /** - * \brief Create an object from a map of fields. - * \param fields The fields of the object. - * \return The created object. - */ - Any operator()(const Map& fields) const { - TVMFFIObjectHandle handle; - TVM_FFI_CHECK_SAFE_CALL(type_info_->metadata->creator(&handle)); - ObjectPtr ptr = - details::ObjectUnsafe::ObjectPtrFromOwned(static_cast(handle)); - size_t match_field_count = 0; - ForEachFieldInfo(type_info_, [&](const TVMFFIFieldInfo* field_info) { - String field_name(field_info->name); - void* field_addr = reinterpret_cast(ptr.get()) + field_info->offset; - if (fields.count(field_name) != 0) { - Any field_value = fields[field_name]; - field_info->setter(field_addr, reinterpret_cast(&field_value)); - ++match_field_count; - } else if (field_info->flags & kTVMFFIFieldFlagBitMaskHasDefault) { - field_info->setter(field_addr, &(field_info->default_value)); - } else { - TVM_FFI_THROW(TypeError) << "Required field `" - << String(field_info->name.data, field_info->name.size) - << "` not set in type `" - << String(type_info_->type_key.data, type_info_->type_key.size) - << "`"; - } - }); - if (match_field_count == fields.size()) return ObjectRef(ptr); - // report error that checks if contains extra fields that are not in the type - auto check_field_name = [&](const String& field_name) { - bool found = false; - ForEachFieldInfoWithEarlyStop(type_info_, [&](const TVMFFIFieldInfo* field_info) { - if (field_name.compare(field_info->name) == 0) { - found = true; - return true; - } - return false; - }); - return found; - }; - for (const auto& [field_name, _] : fields) { - if (!check_field_name(field_name)) { - TVM_FFI_THROW(TypeError) << "Type `" - << String(type_info_->type_key.data, type_info_->type_key.size) - << "` does not have field `" << field_name << "`"; - } - } - TVM_FFI_UNREACHABLE(); - } - - private: - const TVMFFITypeInfo* type_info_; -}; -} // namespace reflection -} // namespace ffi -} // namespace tvm -#endif // TVM_FFI_REFLECTION_CREATOR_H_ diff --git a/ffi/include/tvm/ffi/reflection/registry.h b/ffi/include/tvm/ffi/reflection/registry.h deleted file mode 100644 index 6a1a9b55d2b0..000000000000 --- a/ffi/include/tvm/ffi/reflection/registry.h +++ /dev/null @@ -1,564 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -/*! - * \file tvm/ffi/reflection/registry.h - * \brief Registry of reflection metadata. - */ -#ifndef TVM_FFI_REFLECTION_REGISTRY_H_ -#define TVM_FFI_REFLECTION_REGISTRY_H_ - -#include -#include -#include -#include - -#include -#include - -namespace tvm { -namespace ffi { -/*! \brief Reflection namespace */ -namespace reflection { - -/*! - * \brief Trait that can be used to set field info - * \sa DefaultValue, AttachFieldFlag - */ -struct FieldInfoTrait {}; - -/*! - * \brief Trait that can be used to set field default value - */ -class DefaultValue : public FieldInfoTrait { - public: - /*! - * \brief Constructor - * \param value The value to be set - */ - explicit DefaultValue(Any value) : value_(value) {} - - /*! - * \brief Apply the default value to the field info - * \param info The field info. - */ - TVM_FFI_INLINE void Apply(TVMFFIFieldInfo* info) const { - info->default_value = AnyView(value_).CopyToTVMFFIAny(); - info->flags |= kTVMFFIFieldFlagBitMaskHasDefault; - } - - private: - Any value_; -}; - -/*! - * \brief Trait that can be used to attach field flag - */ -class AttachFieldFlag : public FieldInfoTrait { - public: - /*! - * \brief Attach a field flag to the field - * - * \param flag The flag to be set - * - * \return The trait object. - */ - explicit AttachFieldFlag(int32_t flag) : flag_(flag) {} - - /*! - * \brief Attach kTVMFFIFieldFlagBitMaskSEqHashDef - */ - TVM_FFI_INLINE static AttachFieldFlag SEqHashDef() { - return AttachFieldFlag(kTVMFFIFieldFlagBitMaskSEqHashDef); - } - /*! - * \brief Attach kTVMFFIFieldFlagBitMaskSEqHashIgnore - */ - TVM_FFI_INLINE static AttachFieldFlag SEqHashIgnore() { - return AttachFieldFlag(kTVMFFIFieldFlagBitMaskSEqHashIgnore); - } - - /*! - * \brief Apply the field flag to the field info - * \param info The field info. - */ - TVM_FFI_INLINE void Apply(TVMFFIFieldInfo* info) const { info->flags |= flag_; } - - private: - int32_t flag_; -}; - -/*! - * \brief Get the byte offset of a class member field. - * - * \tparam The original class. - * \tparam T the field type. - * - * \param field_ptr A class member pointer - * \returns The byteoffset - */ -template -TVM_FFI_INLINE int64_t GetFieldByteOffsetToObject(T Class::*field_ptr) { - int64_t field_offset_to_class = - reinterpret_cast(&(static_cast(nullptr)->*field_ptr)); - return field_offset_to_class - details::ObjectUnsafe::GetObjectOffsetToSubclass(); -} - -/// \cond Doxygen_Suppress -class ReflectionDefBase { - protected: - template - static int FieldGetter(void* field, TVMFFIAny* result) { - TVM_FFI_SAFE_CALL_BEGIN(); - *result = details::AnyUnsafe::MoveAnyToTVMFFIAny(Any(*reinterpret_cast(field))); - TVM_FFI_SAFE_CALL_END(); - } - - template - static int FieldSetter(void* field, const TVMFFIAny* value) { - TVM_FFI_SAFE_CALL_BEGIN(); - if constexpr (std::is_same_v) { - *reinterpret_cast(field) = AnyView::CopyFromTVMFFIAny(*value); - } else { - *reinterpret_cast(field) = AnyView::CopyFromTVMFFIAny(*value).cast(); - } - TVM_FFI_SAFE_CALL_END(); - } - - template - static int ObjectCreatorDefault(TVMFFIObjectHandle* result) { - TVM_FFI_SAFE_CALL_BEGIN(); - ObjectPtr obj = make_object(); - *result = details::ObjectUnsafe::MoveObjectPtrToTVMFFIObjectPtr(std::move(obj)); - TVM_FFI_SAFE_CALL_END(); - } - - template - static int ObjectCreatorUnsafeInit(TVMFFIObjectHandle* result) { - TVM_FFI_SAFE_CALL_BEGIN(); - ObjectPtr obj = make_object(UnsafeInit{}); - *result = details::ObjectUnsafe::MoveObjectPtrToTVMFFIObjectPtr(std::move(obj)); - TVM_FFI_SAFE_CALL_END(); - } - - template - TVM_FFI_INLINE static void ApplyFieldInfoTrait(TVMFFIFieldInfo* info, const T& value) { - if constexpr (std::is_base_of_v>) { - value.Apply(info); - } - if constexpr (std::is_same_v, char*>) { - info->doc = TVMFFIByteArray{value, std::char_traits::length(value)}; - } - } - - template - TVM_FFI_INLINE static void ApplyMethodInfoTrait(TVMFFIMethodInfo* info, const T& value) { - if constexpr (std::is_same_v, char*>) { - info->doc = TVMFFIByteArray{value, std::char_traits::length(value)}; - } - } - - template - TVM_FFI_INLINE static void ApplyExtraInfoTrait(TVMFFITypeMetadata* info, const T& value) { - if constexpr (std::is_same_v, char*>) { - info->doc = TVMFFIByteArray{value, std::char_traits::length(value)}; - } - } - - template - TVM_FFI_INLINE static Function GetMethod(std::string name, R (Class::*func)(Args...)) { - static_assert(std::is_base_of_v || std::is_base_of_v, - "Class must be derived from ObjectRef or Object"); - if constexpr (std::is_base_of_v) { - auto fwrap = [func](Class target, Args... params) -> R { - // call method pointer - return (target.*func)(std::forward(params)...); - }; - return ffi::Function::FromTyped(fwrap, name); - } - - if constexpr (std::is_base_of_v) { - auto fwrap = [func](const Class* target, Args... params) -> R { - // call method pointer - return (const_cast(target)->*func)(std::forward(params)...); - }; - return ffi::Function::FromTyped(fwrap, name); - } - } - - template - TVM_FFI_INLINE static Function GetMethod(std::string name, R (Class::*func)(Args...) const) { - static_assert(std::is_base_of_v || std::is_base_of_v, - "Class must be derived from ObjectRef or Object"); - if constexpr (std::is_base_of_v) { - auto fwrap = [func](const Class target, Args... params) -> R { - // call method pointer - return (target.*func)(std::forward(params)...); - }; - return ffi::Function::FromTyped(fwrap, name); - } - - if constexpr (std::is_base_of_v) { - auto fwrap = [func](const Class* target, Args... params) -> R { - // call method pointer - return (target->*func)(std::forward(params)...); - }; - return ffi::Function::FromTyped(fwrap, name); - } - } - - template - TVM_FFI_INLINE static Function GetMethod(std::string name, Func&& func) { - return ffi::Function::FromTyped(std::forward(func), name); - } -}; -/// \endcond - -/*! - * \brief GlobalDef helper to register a global function. - * - * \code - * namespace refl = tvm::ffi::reflection; - * refl::GlobalDef().def("my_ffi_extension.my_function", MyFunction); - * \endcode - */ -class GlobalDef : public ReflectionDefBase { - public: - /*! - * \brief Define a global function. - * - * \tparam Func The function type. - * \tparam Extra The extra arguments. - * - * \param name The name of the function. - * \param func The function to be registered. - * \param extra The extra arguments that can be docstring or subclass of FieldInfoTrait. - * - * \return The reflection definition. - */ - template - GlobalDef& def(const char* name, Func&& func, Extra&&... extra) { - RegisterFunc(name, ffi::Function::FromTyped(std::forward(func), std::string(name)), - std::forward(extra)...); - return *this; - } - - /*! - * \brief Define a global function in ffi::PackedArgs format. - * - * \tparam Func The function type. - * \tparam Extra The extra arguments. - * - * \param name The name of the function. - * \param func The function to be registered. - * \param extra The extra arguments that can be docstring or subclass of FieldInfoTrait. - * - * \return The reflection definition. - */ - template - GlobalDef& def_packed(const char* name, Func func, Extra&&... extra) { - RegisterFunc(name, ffi::Function::FromPacked(func), std::forward(extra)...); - return *this; - } - - /*! - * \brief Expose a class method as a global function. - * - * An argument will be added to the first position if the function is not static. - * - * \tparam Class The class type. - * \tparam Func The function type. - * - * \param name The name of the method. - * \param func The function to be registered. - * \param extra The extra arguments that can be docstring. - * - * \return The reflection definition. - */ - template - GlobalDef& def_method(const char* name, Func&& func, Extra&&... extra) { - RegisterFunc(name, GetMethod(std::string(name), std::forward(func)), - std::forward(extra)...); - return *this; - } - - private: - template - void RegisterFunc(const char* name, ffi::Function func, Extra&&... extra) { - TVMFFIMethodInfo info; - info.name = TVMFFIByteArray{name, std::char_traits::length(name)}; - info.doc = TVMFFIByteArray{nullptr, 0}; - info.type_schema = TVMFFIByteArray{nullptr, 0}; - info.flags = 0; - // obtain the method function - info.method = AnyView(func).CopyToTVMFFIAny(); - // apply method info traits - ((ApplyMethodInfoTrait(&info, std::forward(extra)), ...)); - TVM_FFI_CHECK_SAFE_CALL(TVMFFIFunctionSetGlobalFromMethodInfo(&info, 0)); - } -}; - -/*! - * \brief Helper to register Object's reflection metadata. - * \tparam Class The class type. - * - * \code - * namespace refl = tvm::ffi::reflection; - * refl::ObjectDef().def_ro("my_field", &MyClass::my_field); - * \endcode - */ -template -class ObjectDef : public ReflectionDefBase { - public: - /*! - * \brief Constructor - * \tparam ExtraArgs The extra arguments. - * \param extra_args The extra arguments. - */ - template - explicit ObjectDef(ExtraArgs&&... extra_args) - : type_index_(Class::_GetOrAllocRuntimeTypeIndex()), type_key_(Class::_type_key) { - RegisterExtraInfo(std::forward(extra_args)...); - } - - /*! - * \brief Define a readonly field. - * - * \tparam Class The class type. - * \tparam T The field type. - * \tparam Extra The extra arguments. - * - * \param name The name of the field. - * \param field_ptr The pointer to the field. - * \param extra The extra arguments that can be docstring or default value. - * - * \return The reflection definition. - */ - template - TVM_FFI_INLINE ObjectDef& def_ro(const char* name, T BaseClass::*field_ptr, Extra&&... extra) { - RegisterField(name, field_ptr, false, std::forward(extra)...); - return *this; - } - - /*! - * \brief Define a read-write field. - * - * \tparam Class The class type. - * \tparam T The field type. - * \tparam Extra The extra arguments. - * - * \param name The name of the field. - * \param field_ptr The pointer to the field. - * \param extra The extra arguments that can be docstring or default value. - * - * \return The reflection definition. - */ - template - TVM_FFI_INLINE ObjectDef& def_rw(const char* name, T BaseClass::*field_ptr, Extra&&... extra) { - static_assert(Class::_type_mutable, "Only mutable classes are supported for writable fields"); - RegisterField(name, field_ptr, true, std::forward(extra)...); - return *this; - } - - /*! - * \brief Define a method. - * - * \tparam Func The function type. - * \tparam Extra The extra arguments. - * - * \param name The name of the method. - * \param func The function to be registered. - * \param extra The extra arguments that can be docstring. - * - * \return The reflection definition. - */ - template - TVM_FFI_INLINE ObjectDef& def(const char* name, Func&& func, Extra&&... extra) { - RegisterMethod(name, false, std::forward(func), std::forward(extra)...); - return *this; - } - - /*! - * \brief Define a static method. - * - * \tparam Func The function type. - * \tparam Extra The extra arguments. - * - * \param name The name of the method. - * \param func The function to be registered. - * \param extra The extra arguments that can be docstring. - * - * \return The reflection definition. - */ - template - TVM_FFI_INLINE ObjectDef& def_static(const char* name, Func&& func, Extra&&... extra) { - RegisterMethod(name, true, std::forward(func), std::forward(extra)...); - return *this; - } - - private: - template - void RegisterExtraInfo(ExtraArgs&&... extra_args) { - TVMFFITypeMetadata info; - info.total_size = sizeof(Class); - info.structural_eq_hash_kind = Class::_type_s_eq_hash_kind; - info.creator = nullptr; - info.doc = TVMFFIByteArray{nullptr, 0}; - if constexpr (std::is_default_constructible_v) { - info.creator = ObjectCreatorDefault; - } else if constexpr (std::is_constructible_v) { - info.creator = ObjectCreatorUnsafeInit; - } - // apply extra info traits - ((ApplyExtraInfoTrait(&info, std::forward(extra_args)), ...)); - TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeRegisterMetadata(type_index_, &info)); - } - - template - void RegisterField(const char* name, T BaseClass::*field_ptr, bool writable, - ExtraArgs&&... extra_args) { - static_assert(std::is_base_of_v, "BaseClass must be a base class of Class"); - TVMFFIFieldInfo info; - info.name = TVMFFIByteArray{name, std::char_traits::length(name)}; - info.field_static_type_index = TypeToFieldStaticTypeIndex::value; - // store byte offset and setter, getter - // so the same setter can be reused for all the same type - info.offset = GetFieldByteOffsetToObject(field_ptr); - info.size = sizeof(T); - info.alignment = alignof(T); - info.flags = 0; - if (writable) { - info.flags |= kTVMFFIFieldFlagBitMaskWritable; - } - info.getter = FieldGetter; - info.setter = FieldSetter; - // initialize default value to nullptr - info.default_value = AnyView(nullptr).CopyToTVMFFIAny(); - info.doc = TVMFFIByteArray{nullptr, 0}; - info.type_schema = TVMFFIByteArray{nullptr, 0}; - // apply field info traits - ((ApplyFieldInfoTrait(&info, std::forward(extra_args)), ...)); - // call register - TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeRegisterField(type_index_, &info)); - } - - // register a method - template - void RegisterMethod(const char* name, bool is_static, Func&& func, Extra&&... extra) { - TVMFFIMethodInfo info; - info.name = TVMFFIByteArray{name, std::char_traits::length(name)}; - info.doc = TVMFFIByteArray{nullptr, 0}; - info.type_schema = TVMFFIByteArray{nullptr, 0}; - info.flags = 0; - if (is_static) { - info.flags |= kTVMFFIFieldFlagBitMaskIsStaticMethod; - } - // obtain the method function - Function method = GetMethod(std::string(type_key_) + "." + name, std::forward(func)); - info.method = AnyView(method).CopyToTVMFFIAny(); - // apply method info traits - ((ApplyMethodInfoTrait(&info, std::forward(extra)), ...)); - TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeRegisterMethod(type_index_, &info)); - } - - int32_t type_index_; - const char* type_key_; -}; - -/*! - * \brief Helper to register type attribute. - * \tparam Class The class type. - * \tparam ExtraArgs The extra arguments. - * - * \code - * namespace refl = tvm::ffi::reflection; - * refl::TypeAttrDef().def("func_attr", MyFunc); - * \endcode - * - */ -template >> -class TypeAttrDef : public ReflectionDefBase { - public: - /*! - * \brief Constructor - * \tparam ExtraArgs The extra arguments. - * \param extra_args The extra arguments. - */ - template - explicit TypeAttrDef(ExtraArgs&&... extra_args) - : type_index_(Class::RuntimeTypeIndex()), type_key_(Class::_type_key) {} - - /*! - * \brief Define a function-valued type attribute. - * - * \tparam Func The function type. - * - * \param name The name of the function. - * \param func The function to be registered. - * - * \return The TypeAttrDef object. - */ - template - TypeAttrDef& def(const char* name, Func&& func) { - TVMFFIByteArray name_array = {name, std::char_traits::length(name)}; - ffi::Function ffi_func = - GetMethod(std::string(type_key_) + "." + name, std::forward(func)); - TVMFFIAny value_any = AnyView(ffi_func).CopyToTVMFFIAny(); - TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeRegisterAttr(type_index_, &name_array, &value_any)); - return *this; - } - - /*! - * \brief Define a constant-valued type attribute. - * - * \tparam T The type of the value. - * - * \param name The name of the attribute. - * \param value The value of the attribute. - * - * \return The TypeAttrDef object. - */ - template - TypeAttrDef& attr(const char* name, T value) { - TVMFFIByteArray name_array = {name, std::char_traits::length(name)}; - TVMFFIAny value_any = AnyView(value).CopyToTVMFFIAny(); - TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeRegisterAttr(type_index_, &name_array, &value_any)); - return *this; - } - - private: - int32_t type_index_; - const char* type_key_; -}; - -/*! - * \brief Ensure the type attribute column is presented in the system. - * - * \param name The name of the type attribute. - */ -inline void EnsureTypeAttrColumn(std::string_view name) { - TVMFFIByteArray name_array = {name.data(), name.size()}; - AnyView any_view(nullptr); - TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeRegisterAttr(kTVMFFINone, &name_array, - reinterpret_cast(&any_view))); -} - -} // namespace reflection -} // namespace ffi -} // namespace tvm -#endif // TVM_FFI_REFLECTION_REGISTRY_H_ diff --git a/ffi/include/tvm/ffi/rvalue_ref.h b/ffi/include/tvm/ffi/rvalue_ref.h deleted file mode 100644 index ebbec582e62a..000000000000 --- a/ffi/include/tvm/ffi/rvalue_ref.h +++ /dev/null @@ -1,155 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file tvm/ffi/rvalue_ref.h - * \brief Helper class to define rvalue reference type. - */ -#ifndef TVM_FFI_RVALUE_REF_H_ -#define TVM_FFI_RVALUE_REF_H_ - -#include -#include - -#include -#include - -namespace tvm { -namespace ffi { - -/*! - * \brief Helper class to define rvalue reference type. - * - * By default, FFI pass all values by lvalue reference. - * - * However, we do allow users to intentionally mark a function parameter - * as RValueRef. In such cases, the caller can choose to pass parameter - * wrapped by RValueRef to the function. In which case the parameter - * can be directly moved by the callee. The caller can also choose to pass - * a normal lvalue to the function, in such case a copy will be triggered. - * - * To keep FFI checking overhead minimal, we do not handle case when rvalue - * is passed, but the callee did not declare the parameter as RValueRef. - * - * This design allows us to still leverage move semantics for parameters that - * need copy on write scenarios (and requires an unique copy). - * - * \code - * - * void Example() { - * auto append = Function::FromTyped([](RValueRef> ref, int val) -> Array { - * Array arr = *std::move(ref); - * assert(arr.unique()); - * arr.push_back(val); - * return arr; - * }); - * Array a = Array({1, 2}); - * // as we use rvalue ref to move a into append - * // we keep a single copy of the Array without creating new copies during copy-on-write - * a = append(RvalueRef(std::move(a)), 3); - * assert(a.size() == 3); - * } - * - * \endcode - */ -template >> -class RValueRef { - public: - /*! \brief the container type of the rvalue ref */ - using ContainerType = typename TObjRef::ContainerType; - /*! \brief only allow move constructor from rvalue of T */ - explicit RValueRef(TObjRef&& data) - : data_(details::ObjectUnsafe::ObjectPtrFromObjectRef(std::move(data))) {} - - /*! \brief return the data as rvalue */ - TObjRef operator*() && { return TObjRef(std::move(data_)); } - - private: - mutable ObjectPtr data_; - - template - friend struct TypeTraits; -}; - -template -inline constexpr bool use_default_type_traits_v> = false; - -template -struct TypeTraits> : public TypeTraitsBase { - static constexpr bool storage_enabled = false; - - TVM_FFI_INLINE static void CopyToAnyView(const RValueRef& src, TVMFFIAny* result) { - result->type_index = TypeIndex::kTVMFFIObjectRValueRef; - result->zero_padding = 0; - // store the address of the ObjectPtr, which allows us to move the value - // and set the original ObjectPtr to nullptr - result->v_ptr = &(src.data_); - } - - TVM_FFI_INLINE static std::string GetMismatchTypeInfo(const TVMFFIAny* src) { - if (src->type_index == TypeIndex::kTVMFFIObjectRValueRef) { - ObjectPtr* rvalue_ref = reinterpret_cast*>(src->v_ptr); - // object type does not match up, we need to try to convert the object - // in this case we do not move the original rvalue ref since conversion creates a copy - TVMFFIAny tmp_any; - tmp_any.type_index = rvalue_ref->get()->type_index(); - tmp_any.zero_padding = 0; - tmp_any.v_obj = reinterpret_cast(rvalue_ref->get()); - return "RValueRef<" + TypeTraits::GetMismatchTypeInfo(&tmp_any) + ">"; - } else { - return TypeTraits::GetMismatchTypeInfo(src); - } - } - - TVM_FFI_INLINE static std::optional> TryCastFromAnyView(const TVMFFIAny* src) { - // first try rvalue conversion - if (src->type_index == TypeIndex::kTVMFFIObjectRValueRef) { - ObjectPtr* rvalue_ref = reinterpret_cast*>(src->v_ptr); - TVMFFIAny tmp_any; - tmp_any.type_index = rvalue_ref->get()->type_index(); - tmp_any.zero_padding = 0; - tmp_any.v_obj = reinterpret_cast(rvalue_ref->get()); - // fast path, storage type matches, direct move the rvalue ref - if (TypeTraits::CheckAnyStrict(&tmp_any)) { - return RValueRef( - details::ObjectUnsafe::ObjectRefFromObjectPtr(std::move(*rvalue_ref))); - } - if (std::optional opt = TypeTraits::TryCastFromAnyView(&tmp_any)) { - // object type does not match up, we need to try to convert the object - // in this case we do not move the original rvalue ref since conversion creates a copy - return RValueRef(*std::move(opt)); - } - return std::nullopt; - } - // try lvalue conversion - if (std::optional opt = TypeTraits::TryCastFromAnyView(src)) { - return RValueRef(*std::move(opt)); - } else { - return std::nullopt; - } - } - - TVM_FFI_INLINE static std::string TypeStr() { - return "RValueRef<" + TypeTraits::TypeStr() + ">"; - } -}; -} // namespace ffi -} // namespace tvm - -#endif // TVM_FFI_RVALUE_REF_H_ diff --git a/ffi/include/tvm/ffi/string.h b/ffi/include/tvm/ffi/string.h deleted file mode 100644 index a1529d749fca..000000000000 --- a/ffi/include/tvm/ffi/string.h +++ /dev/null @@ -1,1014 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file tvm/ffi/string.h - * \brief Runtime Bytes and String types. - */ -#ifndef TVM_FFI_STRING_H_ -#define TVM_FFI_STRING_H_ - -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include -#include - -// Note: We place string in tvm/ffi instead of tvm/ffi/container -// because string itself needs special handling and is an inherent -// core component for return string handling. -// The following dependency relation holds -// any -> string -> object - -namespace tvm { -namespace ffi { -namespace details { -/*! - * \brief Base class for bytes and string objects. - */ -class BytesObjBase : public Object, public TVMFFIByteArray {}; - -/*! - * \brief An object representing bytes. - * \note We use a separate object for bytes to follow Python convention - * and indicate passing of raw bytes. - * Bytes can be converted from/to string. - */ -class BytesObj : public BytesObjBase { - public: - static constexpr const uint32_t _type_index = TypeIndex::kTVMFFIBytes; - static const constexpr bool _type_final = true; - TVM_FFI_DECLARE_OBJECT_INFO_STATIC(StaticTypeKey::kTVMFFIBytes, BytesObj, Object); -}; - -/*! \brief An object representing string. This is a POD type. */ -class StringObj : public BytesObjBase { - public: - static constexpr const uint32_t _type_index = TypeIndex::kTVMFFIStr; - static const constexpr bool _type_final = true; - TVM_FFI_DECLARE_OBJECT_INFO_STATIC(StaticTypeKey::kTVMFFIStr, StringObj, Object); -}; - -// String moved from std::string -// without having to trigger a copy -template -class BytesObjStdImpl : public Base { - public: - explicit BytesObjStdImpl(std::string other) : data_{other} { - this->data = data_.data(); - this->size = data_.size(); - } - - private: - std::string data_; -}; - -/*! - * \brief Helper cell class that can be used to back small string - * \note Do not use directly, use String or Bytes instead - */ -class BytesBaseCell { - public: - BytesBaseCell() { - // initialize to none - data_.type_index = TypeIndex::kTVMFFINone; - data_.zero_padding = 0; - data_.v_int64 = 0; - } - - explicit BytesBaseCell(std::nullopt_t) { - data_.type_index = TypeIndex::kTVMFFINone; - data_.zero_padding = 0; - data_.v_int64 = 0; - } - - BytesBaseCell(const BytesBaseCell& other) : data_(other.data_) { // NOLINT(*) - if (data_.type_index >= TypeIndex::kTVMFFIStaticObjectBegin) { - details::ObjectUnsafe::IncRefObjectHandle(data_.v_obj); - } - } - - BytesBaseCell(BytesBaseCell&& other) : data_(other.data_) { // NOLINT(*) - other.data_.type_index = TypeIndex::kTVMFFINone; - } - - BytesBaseCell& operator=(const BytesBaseCell& other) { - BytesBaseCell(other).swap(*this); // NOLINT(*) - return *this; - } - - BytesBaseCell& operator=(BytesBaseCell&& other) { - BytesBaseCell(std::move(other)).swap(*this); // NOLINT(*) - return *this; - } - - ~BytesBaseCell() { - if (data_.type_index >= TypeIndex::kTVMFFIStaticObjectBegin) { - details::ObjectUnsafe::DecRefObjectHandle(data_.v_obj); - } - } - - /*! - * \brief Check if the cell is null - * \return true if the cell is null, false otherwise - */ - bool operator==(std::nullopt_t) const { return data_.type_index == TypeIndex::kTVMFFINone; } - - /*! - * \brief Check if the cell is not null - * \return true if the cell is not null, false otherwise - */ - bool operator!=(std::nullopt_t) const { return data_.type_index != TypeIndex::kTVMFFINone; } - - /*! - * \brief Swap this String with another string - * \param other The other string - */ - void swap(BytesBaseCell& other) { // NOLINT(*) - std::swap(data_, other.data_); - } - - const char* data() const noexcept { - if (data_.type_index < TypeIndex::kTVMFFIStaticObjectBegin) { - return data_.v_bytes; - } else { - return TVMFFIBytesGetByteArrayPtr(data_.v_obj)->data; - } - } - - size_t size() const noexcept { - if (data_.type_index < TypeIndex::kTVMFFIStaticObjectBegin) { - return data_.small_str_len; - } else { - return TVMFFIBytesGetByteArrayPtr(data_.v_obj)->size; - } - } - - template - void InitFromStd(std::string&& other, int32_t large_type_index) { - // needs to be reset to none first for exception safety - data_.type_index = TypeIndex::kTVMFFINone; - data_.zero_padding = 0; - TVM_FFI_CLEAR_PTR_PADDING_IN_FFI_ANY(&data_); - ObjectPtr ptr = make_object>(std::move(other)); - data_.v_obj = details::ObjectUnsafe::MoveObjectPtrToTVMFFIObjectPtr(std::move(ptr)); - data_.type_index = large_type_index; - } - - /*! - * \brief Create a new empty space for a string - * \param size The size of the string - * \param small_type_index The type index for the small string - * \param large_type_index The type index for the large string - * \note always reserve one byte for \0 compactibility - * \return A pointer to the empty space - */ - template - char* InitSpaceForSize(size_t size, int32_t small_type_index, int32_t large_type_index) { - size_t kMaxSmallBytesLen = sizeof(int64_t) - 1; - // first zero the content, this is important for exception safety - data_.type_index = small_type_index; - data_.zero_padding = 0; - if (size <= kMaxSmallBytesLen) { - // set up the size accordingly - data_.small_str_len = static_cast(size); - return data_.v_bytes; - } else { - // allocate from heap - ObjectPtr ptr = make_inplace_array_object(size + 1); - char* dest_data = reinterpret_cast(ptr.get()) + sizeof(LargeObj); - ptr->data = dest_data; - ptr->size = size; - TVM_FFI_CLEAR_PTR_PADDING_IN_FFI_ANY(&data_); - data_.v_obj = details::ObjectUnsafe::MoveObjectPtrToTVMFFIObjectPtr(std::move(ptr)); - // now reset the type index to str - data_.type_index = large_type_index; - return dest_data; - } - } - - void InitTypeIndex(int32_t type_index) { data_.type_index = type_index; } - - void MoveToAny(TVMFFIAny* result) { - *result = data_; - data_.type_index = TypeIndex::kTVMFFINone; - data_.zero_padding = 0; - data_.v_int64 = 0; - } - - TVMFFIAny CopyToTVMFFIAny() const { return data_; } - - static BytesBaseCell CopyFromAnyView(const TVMFFIAny* src) { - BytesBaseCell result(*src); - if (result.data_.type_index >= TypeIndex::kTVMFFIStaticObjectBegin) { - details::ObjectUnsafe::IncRefObjectHandle(result.data_.v_obj); - } - return result; - } - - static BytesBaseCell MoveFromAny(TVMFFIAny* src) { - BytesBaseCell result(*src); - src->type_index = TypeIndex::kTVMFFINone; - src->zero_padding = 0; - src->v_int64 = 0; - return result; - } - - private: - explicit BytesBaseCell(TVMFFIAny data) : data_(data) {} - /*! \brief internal backing data */ - TVMFFIAny data_; -}; -} // namespace details - -/*! - * \brief Managed reference of byte array. - */ -class Bytes { - public: - /*! \brief default constructor */ - Bytes() { data_.InitTypeIndex(TypeIndex::kTVMFFISmallBytes); } - /*! - * \brief constructor from size - * - * \param data The data pointer. - * \param size The size of the char array. - */ - Bytes(const char* data, size_t size) { this->InitData(data, size); } - /*! - * \brief constructor from TVMFFIByteArray - * - * \param bytes a char array. - */ - Bytes(TVMFFIByteArray bytes) { // NOLINT(*) - this->InitData(bytes.data, bytes.size); - } - /*! - * \brief constructor from std::string - * - * \param other a char array. - */ - Bytes(const std::string& other) { // NOLINT(*) - this->InitData(other.data(), other.size()); - } - /*! - * \brief constructor from std::string - * - * \param other a char array. - */ - Bytes(std::string&& other) { // NOLINT(*) - data_.InitFromStd(std::move(other), TypeIndex::kTVMFFIBytes); - } - /*! - * \brief Swap this String with another string - * \param other The other string - */ - void swap(Bytes& other) { // NOLINT(*) - std::swap(data_, other.data_); - } - - template - Bytes& operator=(T&& other) { - // copy-and-swap idiom - Bytes(std::forward(other)).swap(*this); // NOLINT(*) - return *this; - } - /*! - * \brief Return the length of the string - * - * \return size_t string length - */ - size_t size() const { return data_.size(); } - /*! - * \brief Return the data pointer - * - * \return const char* data pointer - */ - const char* data() const { return data_.data(); } - /*! - * \brief Convert String to an std::string object - * - * \return std::string - */ - operator std::string() const { return std::string{data(), size()}; } - - /*! - * \brief Compare two char sequence - * - * \param lhs Pointers to the char array to compare - * \param rhs Pointers to the char array to compare - * \param lhs_count Length of the char array to compare - * \param rhs_count Length of the char array to compare - * \return int zero if both char sequences compare equal. negative if this - * appear before other, positive otherwise. - */ - static int memncmp(const char* lhs, const char* rhs, size_t lhs_count, size_t rhs_count) { - if (lhs == rhs && lhs_count == rhs_count) return 0; - - for (size_t i = 0; i < lhs_count && i < rhs_count; ++i) { - if (lhs[i] < rhs[i]) return -1; - if (lhs[i] > rhs[i]) return 1; - } - if (lhs_count < rhs_count) { - return -1; - } else if (lhs_count > rhs_count) { - return 1; - } else { - return 0; - } - } - /*! - * \brief Compare two char sequence for equality - * - * \param lhs Pointers to the char array to compare - * \param rhs Pointers to the char array to compare - * \param lhs_count Length of the char array to compare - * \param rhs_count Length of the char array to compare - * - * \return true if the two char sequences are equal, false otherwise. - */ - static bool memequal(const void* lhs, const void* rhs, size_t lhs_count, size_t rhs_count) { - return lhs_count == rhs_count && (lhs == rhs || std::memcmp(lhs, rhs, lhs_count) == 0); - } - - private: - template - friend struct TypeTraits; - template - friend class Optional; - // internal backing cell - details::BytesBaseCell data_; - // create a new String from TVMFFIAny, must keep private - explicit Bytes(details::BytesBaseCell data) : data_(data) {} - char* InitSpaceForSize(size_t size) { - return data_.InitSpaceForSize(size, TypeIndex::kTVMFFISmallBytes, - TypeIndex::kTVMFFIBytes); - } - void InitData(const char* data, size_t size) { - char* dest_data = InitSpaceForSize(size); - std::memcpy(dest_data, data, size); - // mainly to be compat with string - dest_data[size] = '\0'; - } -}; - -/*! - * \brief String container class. - */ -class String { - public: - /*! - * \brief avoid misuse of nullptr - */ - String(std::nullptr_t) = delete; // NOLINT(*) - /*! - * \brief constructor - */ - String() { data_.InitTypeIndex(TypeIndex::kTVMFFISmallStr); } - // constructors from Any - /*! - * \brief Copy constructor - * \param other The other string - */ - String(const String& other) = default; // NOLINT(*) - /*! - * \brief Move constructor - * \param other The other string - */ - String(String&& other) = default; // NOLINT(*) - /*! - * \brief Copy assignment operator - * \param other The other string - */ - String& operator=(const String& other) = default; // NOLINT(*) - /*! - * \brief Move assignment operator - * \param other The other string - */ - String& operator=(String&& other) = default; // NOLINT(*) - - /*! - * \brief Swap this String with another string - * \param other The other string - */ - void swap(String& other) noexcept { // NOLINT(*) - std::swap(data_, other.data_); - } - - /*! - * \brief Copy assignment operator - * \param other The other string - */ - String& operator=(const std::string& other) { - String(other).swap(*this); // NOLINT(*) - return *this; - } - /*! - * \brief Move assignment operator - * \param other The other string - */ - String& operator=(std::string&& other) { - String(std::move(other)).swap(*this); // NOLINT(*) - return *this; - } - - /*! - * \brief Copy assignment operator - * \param other The other string - */ - String& operator=(const char* other) { - String(other).swap(*this); // NOLINT(*) - return *this; - } - - /*! - * \brief constructor from raw string - * - * \param data The data pointer. - * \param size The size of the char array. - */ - String(const char* data, size_t size) { this->InitData(data, size); } - - /*! - * \brief constructor from raw string - * - * \param other a char array. - * \note This constructor is marked as explicit to avoid implicit conversion - * of nullptr value here to string, which then was used in comparison - */ - String(const char* other) { // NOLINT(*) - this->InitData(other, std::char_traits::length(other)); - } - /*! - * \brief Construct a new string object - * \param other The std::string object to be copied - */ - String(const std::string& other) { // NOLINT(*) - this->InitData(other.data(), other.size()); - } - - /*! - * \brief Construct a new string object - * \param other The std::string object to be moved - */ - String(std::string&& other) { // NOLINT(*) - // exception safety, first set to none so if exception is thrown - // destructor works correctly - data_.InitFromStd(std::move(other), TypeIndex::kTVMFFIStr); - } - - /*! - * \brief constructor from TVMFFIByteArray - * - * \param other a TVMFFIByteArray. - */ - explicit String(TVMFFIByteArray other) { this->InitData(other.data, other.size); } - - /*! - * \brief Return the data pointer - * - * \return const char* data pointer - */ - const char* data() const noexcept { return data_.data(); } - - /*! - * \brief Returns a pointer to the char array in the string. - * - * \return const char* - */ - const char* c_str() const noexcept { return data(); } - - /*! - * \brief Return the length of the string - * - * \return size_t string length - */ - size_t size() const noexcept { return data_.size(); } - - /*! - * \brief Compares this String object to other - * - * \param other The String to compare with. - * - * \return zero if both char sequences compare equal. negative if this appear - * before other, positive otherwise. - */ - int compare(const String& other) const { - return Bytes::memncmp(data(), other.data(), size(), other.size()); - } - - /*! - * \brief Compares this String object to other - * - * \param other The string to compare with. - * - * \return zero if both char sequences compare equal. negative if this appear - * before other, positive otherwise. - */ - int compare(const std::string& other) const { - return Bytes::memncmp(data(), other.data(), size(), other.size()); - } - - /*! - * \brief Compares this to other - * - * \param other The character array to compare with. - * - * \return zero if both char sequences compare equal. negative if this appear - * before other, positive otherwise. - */ - int compare(const char* other) const { - const char* this_data = data(); - size_t this_size = size(); - for (size_t i = 0; i < this_size; ++i) { - // other is shorter than this - if (other[i] == '\0') return 1; - if (this_data[i] < other[i]) return -1; - if (this_data[i] > other[i]) return 1; - } - // other equals this - if (other[this_size] == '\0') return 0; - // other longer than this - return -1; - } - - /*! - * \brief Compares this to other - * - * \param other The TVMFFIByteArray to compare with. - * - * \return zero if both char sequences compare equal. negative if this appear - * before other, positive otherwise. - */ - int compare(const TVMFFIByteArray& other) const { - return Bytes::memncmp(data(), other.data, size(), other.size); - } - - /*! - * \brief Return the length of the string - * - * \return size_t string length - */ - size_t length() const { return size(); } - - /*! - * \brief Retun if the string is empty - * - * \return true if empty, false otherwise. - */ - bool empty() const { return size() == 0; } - - /*! - * \brief Read an element. - * \param pos The position at which to read the character. - * - * \return The char at position - */ - char at(size_t pos) const { - if (pos < size()) { - return data()[pos]; - } else { - throw std::out_of_range("tvm::String index out of bounds"); - } - } - - /*! - * \brief Convert String to an std::string object - * - * \return std::string - */ - operator std::string() const { return std::string{data(), size()}; } - - private: - template - friend struct TypeTraits; - template - friend class Optional; - // internal backing cell - details::BytesBaseCell data_; - // create a new String from TVMFFIAny, must keep private - explicit String(details::BytesBaseCell data) : data_(data) {} - /*! - * \brief Create a new empty space for a string - * \param size The size of the string - * \return A pointer to the empty space - */ - char* InitSpaceForSize(size_t size) { - return data_.InitSpaceForSize(size, TypeIndex::kTVMFFISmallStr, - TypeIndex::kTVMFFIStr); - } - void InitData(const char* data, size_t size) { - char* dest_data = InitSpaceForSize(size); - std::memcpy(dest_data, data, size); - dest_data[size] = '\0'; - } - /*! - * \brief Concatenate two char sequences - * - * \param lhs Pointers to the lhs char array - * \param lhs_size The size of the lhs char array - * \param rhs Pointers to the rhs char array - * \param rhs_size The size of the rhs char array - * - * \return The concatenated char sequence - */ - static String Concat(const char* lhs, size_t lhs_size, const char* rhs, size_t rhs_size) { - String ret; - // disable stringop-overflow and restrict warnings - // gcc may produce false positive when we enable dest_data returned from small string path - // Because compiler is not able to detect the condition that the path is only triggered via - // size < kMaxSmallStrLen and can report it as a overflow case. -#if (__GNUC__) && !(__clang__) -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wstringop-overflow" -#pragma GCC diagnostic ignored "-Wrestrict" -#endif - char* dest_data = ret.InitSpaceForSize(lhs_size + rhs_size); - std::memcpy(dest_data, lhs, lhs_size); - std::memcpy(dest_data + lhs_size, rhs, rhs_size); - dest_data[lhs_size + rhs_size] = '\0'; -#if (__GNUC__) && !(__clang__) -#pragma GCC diagnostic pop -#endif - return ret; - } - // Overload + operator - friend String operator+(const String& lhs, const String& rhs); - friend String operator+(const String& lhs, const std::string& rhs); - friend String operator+(const std::string& lhs, const String& rhs); - friend String operator+(const String& lhs, const char* rhs); - friend String operator+(const char* lhs, const String& rhs); -}; - -/*! \brief Convert TVMFFIByteArray to std::string_view */ -TVM_FFI_INLINE std::string_view ToStringView(TVMFFIByteArray str) { - return std::string_view(str.data, str.size); -} -/// \cond Doxygen_Suppress - -template <> -inline constexpr bool use_default_type_traits_v = false; - -// specialize to enable implicit conversion from TVMFFIByteArray* -template <> -struct TypeTraits : public TypeTraitsBase { - // bytes can be union type of small bytes and object, so keep it as any - static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIAny; - - TVM_FFI_INLINE static void CopyToAnyView(const Bytes& src, TVMFFIAny* result) { - *result = src.data_.CopyToTVMFFIAny(); - } - - TVM_FFI_INLINE static void MoveToAny(Bytes src, TVMFFIAny* result) { - src.data_.MoveToAny(result); - } - - TVM_FFI_INLINE static bool CheckAnyStrict(const TVMFFIAny* src) { - return src->type_index == TypeIndex::kTVMFFISmallBytes || - src->type_index == TypeIndex::kTVMFFIBytes; - } - - TVM_FFI_INLINE static Bytes CopyFromAnyViewAfterCheck(const TVMFFIAny* src) { - return Bytes(details::BytesBaseCell::CopyFromAnyView(src)); - } - - TVM_FFI_INLINE static Bytes MoveFromAnyAfterCheck(TVMFFIAny* src) { - return Bytes(details::BytesBaseCell::MoveFromAny(src)); - } - - TVM_FFI_INLINE static std::optional TryCastFromAnyView(const TVMFFIAny* src) { - if (src->type_index == TypeIndex::kTVMFFIByteArrayPtr) { - return Bytes(*static_cast(src->v_ptr)); - } - if (src->type_index == TypeIndex::kTVMFFISmallBytes || - src->type_index == TypeIndex::kTVMFFIBytes) { - return Bytes(details::BytesBaseCell::CopyFromAnyView(src)); - } - return std::nullopt; - } - - TVM_FFI_INLINE static std::string TypeStr() { return "bytes"; } -}; - -template <> -inline constexpr bool use_default_type_traits_v = false; - -// specialize to enable implicit conversion from const char* -template <> -struct TypeTraits : public TypeTraitsBase { - // string can be union type of small string and object, so keep it as any - static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIAny; - - TVM_FFI_INLINE static void CopyToAnyView(const String& src, TVMFFIAny* result) { - *result = src.data_.CopyToTVMFFIAny(); - } - - TVM_FFI_INLINE static void MoveToAny(String src, TVMFFIAny* result) { - src.data_.MoveToAny(result); - } - - TVM_FFI_INLINE static bool CheckAnyStrict(const TVMFFIAny* src) { - return src->type_index == TypeIndex::kTVMFFISmallStr || - src->type_index == TypeIndex::kTVMFFIStr; - } - - TVM_FFI_INLINE static String CopyFromAnyViewAfterCheck(const TVMFFIAny* src) { - return String(details::BytesBaseCell::CopyFromAnyView(src)); - } - - TVM_FFI_INLINE static String MoveFromAnyAfterCheck(TVMFFIAny* src) { - return String(details::BytesBaseCell::MoveFromAny(src)); - } - - TVM_FFI_INLINE static std::optional TryCastFromAnyView(const TVMFFIAny* src) { - if (src->type_index == TypeIndex::kTVMFFIRawStr) { - return String(src->v_c_str); - } - if (src->type_index == TypeIndex::kTVMFFISmallStr || src->type_index == TypeIndex::kTVMFFIStr) { - return String(details::BytesBaseCell::CopyFromAnyView(src)); - } - return std::nullopt; - } - - TVM_FFI_INLINE static std::string TypeStr() { return "str"; } -}; - -// const char*, requirement: not nullable, do not retain ownership -template -struct TypeTraits : public TypeTraitsBase { - // NOTE: only enable implicit conversion into AnyView - static constexpr bool storage_enabled = false; - - TVM_FFI_INLINE static void CopyToAnyView(const char src[N], TVMFFIAny* result) { - result->type_index = TypeIndex::kTVMFFIRawStr; - result->zero_padding = 0; - result->v_c_str = src; - } - - TVM_FFI_INLINE static void MoveToAny(const char src[N], TVMFFIAny* result) { - // when we need to move to any, convert to owned object first - TypeTraits::MoveToAny(String(src), result); - } -}; - -template <> -struct TypeTraits : public TypeTraitsBase { - static constexpr bool storage_enabled = false; - - TVM_FFI_INLINE static void CopyToAnyView(const char* src, TVMFFIAny* result) { - TVM_FFI_ICHECK_NOTNULL(src); - result->type_index = TypeIndex::kTVMFFIRawStr; - result->zero_padding = 0; - result->v_c_str = src; - } - - TVM_FFI_INLINE static void MoveToAny(const char* src, TVMFFIAny* result) { - // when we need to move to any, convert to owned object first - TypeTraits::MoveToAny(String(src), result); - } - // Do not allow const char* in a container, so we do not need CheckAnyStrict - TVM_FFI_INLINE static std::optional TryCastFromAnyView(const TVMFFIAny* src) { - if (src->type_index == TypeIndex::kTVMFFIRawStr) { - return static_cast(src->v_c_str); - } - return std::nullopt; - } - - TVM_FFI_INLINE static std::string TypeStr() { return "const char*"; } -}; - -// TVMFFIByteArray, requirement: not nullable, do not retain ownership -template <> -struct TypeTraits : public TypeTraitsBase { - static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIByteArrayPtr; - static constexpr bool storage_enabled = false; - - TVM_FFI_INLINE static void CopyToAnyView(TVMFFIByteArray* src, TVMFFIAny* result) { - TVM_FFI_ICHECK_NOTNULL(src); - result->type_index = TypeIndex::kTVMFFIByteArrayPtr; - result->zero_padding = 0; - TVM_FFI_CLEAR_PTR_PADDING_IN_FFI_ANY(result); - result->v_ptr = src; - } - - TVM_FFI_INLINE static void MoveToAny(TVMFFIByteArray* src, TVMFFIAny* result) { - TypeTraits::MoveToAny(Bytes(*src), result); - } - - TVM_FFI_INLINE static std::optional TryCastFromAnyView(const TVMFFIAny* src) { - if (src->type_index == TypeIndex::kTVMFFIByteArrayPtr) { - return static_cast(src->v_ptr); - } - return std::nullopt; - } - - TVM_FFI_INLINE static std::string TypeStr() { return StaticTypeKey::kTVMFFIByteArrayPtr; } -}; - -template <> -inline constexpr bool use_default_type_traits_v = false; - -template <> -struct TypeTraits - : public FallbackOnlyTraitsBase { - TVM_FFI_INLINE static void CopyToAnyView(const std::string& src, TVMFFIAny* result) { - result->type_index = TypeIndex::kTVMFFIRawStr; - result->zero_padding = 0; - result->v_c_str = src.c_str(); - } - - TVM_FFI_INLINE static void MoveToAny(std::string src, TVMFFIAny* result) { - // when we need to move to any, convert to owned object first - TypeTraits::MoveToAny(String(std::move(src)), result); - } - - TVM_FFI_INLINE static std::string TypeStr() { return "std::string"; } - - TVM_FFI_INLINE static std::string ConvertFallbackValue(const char* src) { - return std::string(src); - } - - TVM_FFI_INLINE static std::string ConvertFallbackValue(TVMFFIByteArray* src) { - return std::string(src->data, src->size); - } - - TVM_FFI_INLINE static std::string ConvertFallbackValue(Bytes src) { - return src.operator std::string(); - } - - TVM_FFI_INLINE static std::string ConvertFallbackValue(String src) { - return src.operator std::string(); - } -}; - -inline String operator+(const String& lhs, const String& rhs) { - size_t lhs_size = lhs.size(); - size_t rhs_size = rhs.size(); - return String::Concat(lhs.data(), lhs_size, rhs.data(), rhs_size); -} - -inline String operator+(const String& lhs, const std::string& rhs) { - size_t lhs_size = lhs.size(); - size_t rhs_size = rhs.size(); - return String::Concat(lhs.data(), lhs_size, rhs.data(), rhs_size); -} - -inline String operator+(const std::string& lhs, const String& rhs) { - size_t lhs_size = lhs.size(); - size_t rhs_size = rhs.size(); - return String::Concat(lhs.data(), lhs_size, rhs.data(), rhs_size); -} - -inline String operator+(const char* lhs, const String& rhs) { - size_t lhs_size = std::strlen(lhs); - size_t rhs_size = rhs.size(); - return String::Concat(lhs, lhs_size, rhs.data(), rhs_size); -} - -inline String operator+(const String& lhs, const char* rhs) { - size_t lhs_size = lhs.size(); - size_t rhs_size = std::strlen(rhs); - return String::Concat(lhs.data(), lhs_size, rhs, rhs_size); -} - -// Overload < operator -inline bool operator<(std::nullptr_t, const String& rhs) = delete; -inline bool operator<(const String& lhs, std::nullptr_t) = delete; - -inline bool operator<(const String& lhs, const std::string& rhs) { return lhs.compare(rhs) < 0; } - -inline bool operator<(const std::string& lhs, const String& rhs) { return rhs.compare(lhs) > 0; } - -inline bool operator<(const String& lhs, const String& rhs) { return lhs.compare(rhs) < 0; } - -inline bool operator<(const String& lhs, const char* rhs) { return lhs.compare(rhs) < 0; } - -inline bool operator<(const char* lhs, const String& rhs) { return rhs.compare(lhs) > 0; } - -// Overload > operator -inline bool operator>(std::nullptr_t, const String& rhs) = delete; -inline bool operator>(const String& lhs, std::nullptr_t) = delete; - -inline bool operator>(const String& lhs, const std::string& rhs) { return lhs.compare(rhs) > 0; } - -inline bool operator>(const std::string& lhs, const String& rhs) { return rhs.compare(lhs) < 0; } - -inline bool operator>(const String& lhs, const String& rhs) { return lhs.compare(rhs) > 0; } - -inline bool operator>(const String& lhs, const char* rhs) { return lhs.compare(rhs) > 0; } - -inline bool operator>(const char* lhs, const String& rhs) { return rhs.compare(lhs) < 0; } - -// Overload <= operator -inline bool operator<=(std::nullptr_t, const String& rhs) = delete; -inline bool operator<=(const String& lhs, std::nullptr_t) = delete; - -inline bool operator<=(const String& lhs, const std::string& rhs) { return lhs.compare(rhs) <= 0; } - -inline bool operator<=(const std::string& lhs, const String& rhs) { return rhs.compare(lhs) >= 0; } - -inline bool operator<=(const String& lhs, const String& rhs) { return lhs.compare(rhs) <= 0; } - -inline bool operator<=(const String& lhs, const char* rhs) { return lhs.compare(rhs) <= 0; } - -inline bool operator<=(const char* lhs, const String& rhs) { return rhs.compare(lhs) >= 0; } - -// Overload >= operator -inline bool operator>=(std::nullptr_t, const String& rhs) = delete; -inline bool operator>=(const String& lhs, std::nullptr_t) = delete; - -inline bool operator>=(const String& lhs, const std::string& rhs) { return lhs.compare(rhs) >= 0; } - -inline bool operator>=(const std::string& lhs, const String& rhs) { return rhs.compare(lhs) <= 0; } - -inline bool operator>=(const String& lhs, const String& rhs) { return lhs.compare(rhs) >= 0; } - -inline bool operator>=(const String& lhs, const char* rhs) { return lhs.compare(rhs) >= 0; } - -inline bool operator>=(const char* lhs, const String& rhs) { return rhs.compare(lhs) <= 0; } - -// delete Overload == operator for nullptr -inline bool operator==(const String& lhs, std::nullptr_t) = delete; -inline bool operator==(std::nullptr_t, const String& rhs) = delete; - -inline bool operator==(const String& lhs, const std::string& rhs) { - return Bytes::memequal(lhs.data(), rhs.data(), lhs.size(), rhs.size()); -} - -inline bool operator==(const std::string& lhs, const String& rhs) { - return Bytes::memequal(lhs.data(), rhs.data(), lhs.size(), rhs.size()); -} - -inline bool operator==(const String& lhs, const String& rhs) { - return Bytes::memequal(lhs.data(), rhs.data(), lhs.size(), rhs.size()); -} - -inline bool operator==(const String& lhs, const char* rhs) { return lhs.compare(rhs) == 0; } - -inline bool operator==(const char* lhs, const String& rhs) { return rhs.compare(lhs) == 0; } - -// Overload != operator -inline bool operator!=(const String& lhs, std::nullptr_t) = delete; -inline bool operator!=(std::nullptr_t, const String& rhs) = delete; - -inline bool operator!=(const String& lhs, const std::string& rhs) { return lhs.compare(rhs) != 0; } - -inline bool operator!=(const std::string& lhs, const String& rhs) { return rhs.compare(lhs) != 0; } - -inline bool operator!=(const String& lhs, const String& rhs) { return lhs.compare(rhs) != 0; } - -inline bool operator!=(const String& lhs, const char* rhs) { return lhs.compare(rhs) != 0; } - -inline bool operator!=(const char* lhs, const String& rhs) { return rhs.compare(lhs) != 0; } - -inline std::ostream& operator<<(std::ostream& out, const String& input) { - out.write(input.data(), input.size()); - return out; -} -/// \endcond -} // namespace ffi -} // namespace tvm - -/// \cond Doxygen_Suppress -namespace std { - -template <> -struct hash<::tvm::ffi::Bytes> { - std::size_t operator()(const ::tvm::ffi::Bytes& bytes) const { - return std::hash()(std::string_view(bytes.data(), bytes.size())); - } -}; - -template <> -struct hash<::tvm::ffi::String> { - std::size_t operator()(const ::tvm::ffi::String& str) const { - return std::hash()(std::string_view(str.data(), str.size())); - } -}; -} // namespace std -/// \endcond -#endif // TVM_FFI_STRING_H_ diff --git a/ffi/include/tvm/ffi/type_traits.h b/ffi/include/tvm/ffi/type_traits.h deleted file mode 100644 index 0f1971945a4b..000000000000 --- a/ffi/include/tvm/ffi/type_traits.h +++ /dev/null @@ -1,781 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -/*! - * \file tvm/ffi/object.h - * \brief A managed object in the TVM FFI. - */ -#ifndef TVM_FFI_TYPE_TRAITS_H_ -#define TVM_FFI_TYPE_TRAITS_H_ - -#include -#include -#include -#include - -#include -#include -#include - -namespace tvm { -namespace ffi { - -/*! - * \brief TypeTraits that specifies the conversion behavior from/to FFI Any. - * - * The function specifications of TypeTraits - * - * - CopyToAnyView: Convert a value T to AnyView - * - MoveToAny: Move a value to Any - * - CheckAnyStrict: Check if a Any stores a result of CopyToAnyView of current T. - * - CopyFromAnyViewAfterCheck: Copy a value T from Any view after we pass CheckAnyStrict. - * - MoveFromAnyAfterCheck: Move a value T from Any storage after we pass CheckAnyStrict. - * - TryCastFromAnyView: Convert a AnyView to a T, we may apply type conversion. - * - GetMismatchTypeInfo: Get the type key of a type when TryCastFromAnyView fails. - * - TypeStr: Get the type key of a type - * - * It is possible that CheckAnyStrict is false but TryCastFromAnyView still works. - * - * For example, when Any x stores int, TypeTraits::CheckAnyStrict(x) will be false, - * but TypeTraits::TryCastFromAnyView(x) will return a corresponding float value - * via type conversion. - * - * CheckAnyStrict is mainly used in recursive container such as Array to - * decide if a new Array needed to be created via recursive conversion, - * or we can use the current container as is when converting to Array. - * - * A container array: Array satisfies the following invariant: - * - `all(TypeTraits::CheckAnyStrict(x) for x in the array)`. - */ -template -struct TypeTraits { - /*! \brief Whether the type is enabled in FFI. */ - static constexpr bool convert_enabled = false; - /*! \brief Whether the type can appear as a storage type in Container */ - static constexpr bool storage_enabled = false; -}; - -/*! - * \brief TypeTraits that removes const and reference keywords. - * \tparam T the original type - */ -template -using TypeTraitsNoCR = TypeTraits>>; - -template -inline constexpr bool use_default_type_traits_v = true; - -struct TypeTraitsBase { - static constexpr bool convert_enabled = true; - static constexpr bool storage_enabled = true; - static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIAny; - // get mismatched type when result mismatches the trait. - // this function is called after TryCastFromAnyView fails - // to get more detailed type information in runtime - // especially when the error involves nested container type - TVM_FFI_INLINE static std::string GetMismatchTypeInfo(const TVMFFIAny* source) { - return TypeIndexToTypeKey(source->type_index); - } -}; - -/*! - * \brief Trait that maps a type to its field static type index - * \tparam T the type - * \return the field static type index - */ -template -struct TypeToFieldStaticTypeIndex { - /*! \brief The field static type index of the type */ - static constexpr int32_t value = TypeIndex::kTVMFFIAny; -}; - -template -struct TypeToFieldStaticTypeIndex::convert_enabled>> { - static constexpr int32_t value = TypeTraits::field_static_type_index; -}; - -/*! - * \brief Trait that maps a type to its runtime type index - * \tparam T the type - * \return the runtime type index - */ -template -struct TypeToRuntimeTypeIndex { - /*! - * \brief Get the runtime type index of the type - * \return the runtime type index - */ - static int32_t v() { return TypeToFieldStaticTypeIndex::value; } -}; - -template -struct TypeToRuntimeTypeIndex>> { - static int32_t v() { return T::ContainerType::RuntimeTypeIndex(); } -}; - -// None -template <> -struct TypeTraits : public TypeTraitsBase { - static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFINone; - - TVM_FFI_INLINE static void CopyToAnyView(const std::nullptr_t&, TVMFFIAny* result) { - result->type_index = TypeIndex::kTVMFFINone; - result->zero_padding = 0; - // invariant: the pointer field also equals nullptr - // this will simplify same_as comparisons and hash - result->v_int64 = 0; - } - - TVM_FFI_INLINE static void MoveToAny(std::nullptr_t, TVMFFIAny* result) { - result->type_index = TypeIndex::kTVMFFINone; - result->zero_padding = 0; - // invariant: the pointer field also equals nullptr - // this will simplify same_as comparisons and hash - result->v_int64 = 0; - } - - TVM_FFI_INLINE static bool CheckAnyStrict(const TVMFFIAny* src) { - return src->type_index == TypeIndex::kTVMFFINone; - } - - TVM_FFI_INLINE static std::nullptr_t CopyFromAnyViewAfterCheck(const TVMFFIAny*) { - return nullptr; - } - - TVM_FFI_INLINE static std::nullptr_t MoveFromAnyAfterCheck(TVMFFIAny*) { return nullptr; } - - TVM_FFI_INLINE static std::optional TryCastFromAnyView(const TVMFFIAny* src) { - if (src->type_index == TypeIndex::kTVMFFINone) { - return nullptr; - } - return std::nullopt; - } - - TVM_FFI_INLINE static std::string TypeStr() { return StaticTypeKey::kTVMFFINone; } -}; - -/** - * \brief A type that forbids implicit conversion from int to bool - * - * This type is used to prevent implicit conversion from int to bool. - */ -class StrictBool { - public: - /*! - * \brief Constructor - * \param value The value of the strict bool. - */ - StrictBool(bool value) : value_(value) {} // NOLINT(*) - /*! - *\brief Convert the strict bool to bool. - * \return The value of the strict bool. - */ - operator bool() const { return value_; } - - private: - bool value_; -}; - -template <> -struct TypeTraits : public TypeTraitsBase { - static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIBool; - - TVM_FFI_INLINE static void CopyToAnyView(const StrictBool& src, TVMFFIAny* result) { - result->type_index = TypeIndex::kTVMFFIBool; - result->zero_padding = 0; - result->v_int64 = static_cast(src); - } - - TVM_FFI_INLINE static void MoveToAny(StrictBool src, TVMFFIAny* result) { - CopyToAnyView(src, result); - } - - TVM_FFI_INLINE static bool CheckAnyStrict(const TVMFFIAny* src) { - return src->type_index == TypeIndex::kTVMFFIBool; - } - - TVM_FFI_INLINE static StrictBool CopyFromAnyViewAfterCheck(const TVMFFIAny* src) { - return static_cast(src->v_int64); - } - - TVM_FFI_INLINE static StrictBool MoveFromAnyAfterCheck(TVMFFIAny* src) { - // POD type, we can just copy the value - return CopyFromAnyViewAfterCheck(src); - } - - TVM_FFI_INLINE static std::optional TryCastFromAnyView(const TVMFFIAny* src) { - if (src->type_index == TypeIndex::kTVMFFIBool) { - return StrictBool(static_cast(src->v_int64)); - } - return std::nullopt; - } - - TVM_FFI_INLINE static std::string TypeStr() { return StaticTypeKey::kTVMFFIBool; } -}; - -// Bool type, allow implicit casting from int -template <> -struct TypeTraits : public TypeTraitsBase { - static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIBool; - - TVM_FFI_INLINE static void CopyToAnyView(const bool& src, TVMFFIAny* result) { - result->type_index = TypeIndex::kTVMFFIBool; - result->zero_padding = 0; - result->v_int64 = static_cast(src); - } - - TVM_FFI_INLINE static void MoveToAny(bool src, TVMFFIAny* result) { CopyToAnyView(src, result); } - - TVM_FFI_INLINE static bool CheckAnyStrict(const TVMFFIAny* src) { - return src->type_index == TypeIndex::kTVMFFIBool; - } - - TVM_FFI_INLINE static bool CopyFromAnyViewAfterCheck(const TVMFFIAny* src) { - return static_cast(src->v_int64); - } - - TVM_FFI_INLINE static bool MoveFromAnyAfterCheck(TVMFFIAny* src) { - // POD type, we can just copy the value - return CopyFromAnyViewAfterCheck(src); - } - - TVM_FFI_INLINE static std::optional TryCastFromAnyView(const TVMFFIAny* src) { - if (src->type_index == TypeIndex::kTVMFFIInt || src->type_index == TypeIndex::kTVMFFIBool) { - return static_cast(src->v_int64); - } - return std::nullopt; - } - - TVM_FFI_INLINE static std::string TypeStr() { return StaticTypeKey::kTVMFFIBool; } -}; - -// Integer POD values -template -struct TypeTraits>> : public TypeTraitsBase { - static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIInt; - - TVM_FFI_INLINE static void CopyToAnyView(const Int& src, TVMFFIAny* result) { - result->type_index = TypeIndex::kTVMFFIInt; - result->zero_padding = 0; - result->v_int64 = static_cast(src); - } - - TVM_FFI_INLINE static void MoveToAny(Int src, TVMFFIAny* result) { CopyToAnyView(src, result); } - - TVM_FFI_INLINE static bool CheckAnyStrict(const TVMFFIAny* src) { - // NOTE: CheckAnyStrict is always strict and should be consistent with MoveToAny - return src->type_index == TypeIndex::kTVMFFIInt; - } - - TVM_FFI_INLINE static Int CopyFromAnyViewAfterCheck(const TVMFFIAny* src) { - return static_cast(src->v_int64); - } - - TVM_FFI_INLINE static Int MoveFromAnyAfterCheck(TVMFFIAny* src) { - // POD type, we can just copy the value - return CopyFromAnyViewAfterCheck(src); - } - - TVM_FFI_INLINE static std::optional TryCastFromAnyView(const TVMFFIAny* src) { - if (src->type_index == TypeIndex::kTVMFFIInt || src->type_index == TypeIndex::kTVMFFIBool) { - return Int(src->v_int64); - } - return std::nullopt; - } - - TVM_FFI_INLINE static std::string TypeStr() { return StaticTypeKey::kTVMFFIInt; } -}; - -// Enum Integer POD values -template -struct TypeTraits && - std::is_integral_v>>> - : public TypeTraitsBase { - static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIInt; - - TVM_FFI_INLINE static void CopyToAnyView(const IntEnum& src, TVMFFIAny* result) { - result->type_index = TypeIndex::kTVMFFIInt; - result->zero_padding = 0; - result->v_int64 = static_cast(src); - } - - TVM_FFI_INLINE static void MoveToAny(IntEnum src, TVMFFIAny* result) { - CopyToAnyView(src, result); - } - - TVM_FFI_INLINE static bool CheckAnyStrict(const TVMFFIAny* src) { - // NOTE: CheckAnyStrict is always strict and should be consistent with MoveToAny - return src->type_index == TypeIndex::kTVMFFIInt; - } - - TVM_FFI_INLINE static IntEnum CopyFromAnyViewAfterCheck(const TVMFFIAny* src) { - return static_cast(src->v_int64); - } - - TVM_FFI_INLINE static IntEnum MoveFromAnyAfterCheck(TVMFFIAny* src) { - // POD type, we can just copy the value - return CopyFromAnyViewAfterCheck(src); - } - - TVM_FFI_INLINE static std::optional TryCastFromAnyView(const TVMFFIAny* src) { - if (src->type_index == TypeIndex::kTVMFFIInt || src->type_index == TypeIndex::kTVMFFIBool) { - return static_cast(src->v_int64); - } - return std::nullopt; - } - - TVM_FFI_INLINE static std::string TypeStr() { return StaticTypeKey::kTVMFFIInt; } -}; - -// Float POD values -template -struct TypeTraits>> - : public TypeTraitsBase { - static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIFloat; - - TVM_FFI_INLINE static void CopyToAnyView(const Float& src, TVMFFIAny* result) { - result->type_index = TypeIndex::kTVMFFIFloat; - result->zero_padding = 0; - result->v_float64 = static_cast(src); - } - - TVM_FFI_INLINE static void MoveToAny(Float src, TVMFFIAny* result) { CopyToAnyView(src, result); } - - TVM_FFI_INLINE static bool CheckAnyStrict(const TVMFFIAny* src) { - // NOTE: CheckAnyStrict is always strict and should be consistent with MoveToAny - return src->type_index == TypeIndex::kTVMFFIFloat; - } - - TVM_FFI_INLINE static Float CopyFromAnyViewAfterCheck(const TVMFFIAny* src) { - return static_cast(src->v_float64); - } - - TVM_FFI_INLINE static Float MoveFromAnyAfterCheck(TVMFFIAny* src) { - // POD type, we can just copy the value - return CopyFromAnyViewAfterCheck(src); - } - - TVM_FFI_INLINE static std::optional TryCastFromAnyView(const TVMFFIAny* src) { - if (src->type_index == TypeIndex::kTVMFFIFloat) { - return Float(src->v_float64); - } else if (src->type_index == TypeIndex::kTVMFFIInt || - src->type_index == TypeIndex::kTVMFFIBool) { - return Float(src->v_int64); - } - return std::nullopt; - } - - TVM_FFI_INLINE static std::string TypeStr() { return StaticTypeKey::kTVMFFIFloat; } -}; - -// void* -template <> -struct TypeTraits : public TypeTraitsBase { - static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIOpaquePtr; - - TVM_FFI_INLINE static void CopyToAnyView(void* src, TVMFFIAny* result) { - result->type_index = TypeIndex::kTVMFFIOpaquePtr; - result->zero_padding = 0; - TVM_FFI_CLEAR_PTR_PADDING_IN_FFI_ANY(result); - result->v_ptr = src; - } - - TVM_FFI_INLINE static void MoveToAny(void* src, TVMFFIAny* result) { CopyToAnyView(src, result); } - - TVM_FFI_INLINE static bool CheckAnyStrict(const TVMFFIAny* src) { - // NOTE: CheckAnyStrict is always strict and should be consistent with MoveToAny - return src->type_index == TypeIndex::kTVMFFIOpaquePtr; - } - - TVM_FFI_INLINE static void* CopyFromAnyViewAfterCheck(const TVMFFIAny* src) { return src->v_ptr; } - - TVM_FFI_INLINE static void* MoveFromAnyAfterCheck(TVMFFIAny* src) { - // POD type, we can just copy the value - return CopyFromAnyViewAfterCheck(src); - } - - TVM_FFI_INLINE static std::optional TryCastFromAnyView(const TVMFFIAny* src) { - if (src->type_index == TypeIndex::kTVMFFIOpaquePtr) { - return static_cast(src->v_ptr); - } - if (src->type_index == TypeIndex::kTVMFFINone) { - return static_cast(nullptr); - } - return std::nullopt; - } - - TVM_FFI_INLINE static std::string TypeStr() { return StaticTypeKey::kTVMFFIOpaquePtr; } -}; - -// Device -template <> -struct TypeTraits : public TypeTraitsBase { - static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIDevice; - - TVM_FFI_INLINE static void CopyToAnyView(const DLDevice& src, TVMFFIAny* result) { - result->type_index = TypeIndex::kTVMFFIDevice; - result->zero_padding = 0; - result->v_device = src; - } - - TVM_FFI_INLINE static void MoveToAny(DLDevice src, TVMFFIAny* result) { - result->type_index = TypeIndex::kTVMFFIDevice; - result->zero_padding = 0; - result->v_device = src; - } - - TVM_FFI_INLINE static bool CheckAnyStrict(const TVMFFIAny* src) { - return src->type_index == TypeIndex::kTVMFFIDevice; - } - - TVM_FFI_INLINE static DLDevice CopyFromAnyViewAfterCheck(const TVMFFIAny* src) { - return src->v_device; - } - - TVM_FFI_INLINE static DLDevice MoveFromAnyAfterCheck(TVMFFIAny* src) { - // POD type, we can just copy the value - return CopyFromAnyViewAfterCheck(src); - } - - TVM_FFI_INLINE static std::optional TryCastFromAnyView(const TVMFFIAny* src) { - if (src->type_index == TypeIndex::kTVMFFIDevice) { - return src->v_device; - } - return std::nullopt; - } - - TVM_FFI_INLINE static std::string TypeStr() { return StaticTypeKey::kTVMFFIDevice; } -}; - -// DLTensor*, requirement: not nullable, do not retain ownership -template <> -struct TypeTraits : public TypeTraitsBase { - static constexpr bool storage_enabled = false; - static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIDLTensorPtr; - - TVM_FFI_INLINE static void CopyToAnyView(DLTensor* src, TVMFFIAny* result) { - TVM_FFI_ICHECK_NOTNULL(src); - result->type_index = TypeIndex::kTVMFFIDLTensorPtr; - result->zero_padding = 0; - TVM_FFI_CLEAR_PTR_PADDING_IN_FFI_ANY(result); - result->v_ptr = src; - } - - TVM_FFI_INLINE static bool CheckAnyStrict(const TVMFFIAny* src) { - return src->type_index == TypeIndex::kTVMFFIDLTensorPtr; - } - - TVM_FFI_INLINE static DLTensor* CopyFromAnyViewAfterCheck(const TVMFFIAny* src) { - return static_cast(src->v_ptr); - } - - TVM_FFI_INLINE static void MoveToAny(DLTensor*, TVMFFIAny*) { - TVM_FFI_THROW(RuntimeError) - << "DLTensor* cannot be held in Any as it does not retain ownership, use Tensor instead"; - } - - TVM_FFI_INLINE static std::optional TryCastFromAnyView(const TVMFFIAny* src) { - if (src->type_index == TypeIndex::kTVMFFIDLTensorPtr) { - return static_cast(src->v_ptr); - } else if (src->type_index == TypeIndex::kTVMFFITensor) { - // Conversion from Tensor pointer to DLTensor - // based on the assumption that Tensor always follows the TVMFFIObject header - static_assert(sizeof(TVMFFIObject) == 24); - return reinterpret_cast(reinterpret_cast(src->v_obj) + - sizeof(TVMFFIObject)); - } - return std::nullopt; - } - - TVM_FFI_INLINE static std::string TypeStr() { return "DLTensor*"; } -}; - -// Traits for ObjectRef, None to ObjectRef will always fail. -// use std::optional instead for nullable references. -template -struct ObjectRefTypeTraitsBase : public TypeTraitsBase { - static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIObject; - using ContainerType = typename TObjRef::ContainerType; - - TVM_FFI_INLINE static void CopyToAnyView(const TObjRef& src, TVMFFIAny* result) { - if constexpr (TObjRef::_type_is_nullable) { - if (!src.defined()) { - TypeTraits::CopyToAnyView(nullptr, result); - return; - } - } - TVMFFIObject* obj_ptr = details::ObjectUnsafe::TVMFFIObjectPtrFromObjectRef(src); - result->type_index = obj_ptr->type_index; - result->zero_padding = 0; - TVM_FFI_CLEAR_PTR_PADDING_IN_FFI_ANY(result); - result->v_obj = obj_ptr; - } - - TVM_FFI_INLINE static void MoveToAny(TObjRef src, TVMFFIAny* result) { - if constexpr (TObjRef::_type_is_nullable) { - if (!src.defined()) { - TypeTraits::CopyToAnyView(nullptr, result); - return; - } - } - TVMFFIObject* obj_ptr = details::ObjectUnsafe::MoveObjectRefToTVMFFIObjectPtr(std::move(src)); - result->type_index = obj_ptr->type_index; - result->zero_padding = 0; - TVM_FFI_CLEAR_PTR_PADDING_IN_FFI_ANY(result); - result->v_obj = obj_ptr; - } - - TVM_FFI_INLINE static bool CheckAnyStrict(const TVMFFIAny* src) { - if constexpr (TObjRef::_type_is_nullable) { - if (src->type_index == TypeIndex::kTVMFFINone) return true; - } - return (src->type_index >= TypeIndex::kTVMFFIStaticObjectBegin && - details::IsObjectInstance(src->type_index)); - } - - TVM_FFI_INLINE static TObjRef CopyFromAnyViewAfterCheck(const TVMFFIAny* src) { - if constexpr (TObjRef::_type_is_nullable) { - if (src->type_index == TypeIndex::kTVMFFINone) { - return details::ObjectUnsafe::ObjectRefFromObjectPtr(nullptr); - } - } - return details::ObjectUnsafe::ObjectRefFromObjectPtr( - details::ObjectUnsafe::ObjectPtrFromUnowned(src->v_obj)); - } - - TVM_FFI_INLINE static TObjRef MoveFromAnyAfterCheck(TVMFFIAny* src) { - if constexpr (TObjRef::_type_is_nullable) { - if (src->type_index == TypeIndex::kTVMFFINone) { - return details::ObjectUnsafe::ObjectRefFromObjectPtr(nullptr); - } - } - // move out the object pointer - ObjectPtr obj_ptr = - details::ObjectUnsafe::ObjectPtrFromOwned(src->v_obj); - // reset the src to nullptr - TypeTraits::MoveToAny(nullptr, src); - return details::ObjectUnsafe::ObjectRefFromObjectPtr(std::move(obj_ptr)); - } - - TVM_FFI_INLINE static std::optional TryCastFromAnyView(const TVMFFIAny* src) { - if constexpr (TObjRef::_type_is_nullable) { - if (src->type_index == TypeIndex::kTVMFFINone) { - return details::ObjectUnsafe::ObjectRefFromObjectPtr(nullptr); - } - } - if (src->type_index >= TypeIndex::kTVMFFIStaticObjectBegin) { - if (details::IsObjectInstance(src->type_index)) { - return details::ObjectUnsafe::ObjectRefFromObjectPtr( - details::ObjectUnsafe::ObjectPtrFromUnowned(src->v_obj)); - } - } - return std::nullopt; - } - - TVM_FFI_INLINE static std::string TypeStr() { return ContainerType::_type_key; } -}; - -template -struct TypeTraits && - use_default_type_traits_v>> - : public ObjectRefTypeTraitsBase {}; - -/*! - * \brief Helper class that convert to T only via the FallbackTypes - * - * The conversion will go through the FallbackTypes in the order - * specified in the template parameter. - * \tparam T The type of the target value. - * \tparam FallbackTypes The type of the fallback value. - * \note TypeTraits must be derived from this class and define - * ConvertFallbackValue(FallbackType)->T for each FallbackType - */ -template -struct FallbackOnlyTraitsBase : public TypeTraitsBase { - // disable container for FallbackOnlyTraitsBase - /// \cond Doxygen_Suppress - static constexpr bool storage_enabled = false; - - TVM_FFI_INLINE static std::optional TryCastFromAnyView(const TVMFFIAny* src) { - return TryFallbackTypes(src); - } - - template - TVM_FFI_INLINE static std::optional TryFallbackTypes(const TVMFFIAny* src) { - static_assert(!std::is_same_v, - "Using bool as FallbackType can cause bug because int will be detected as bool, " - "use tvm::ffi::StrictBool instead"); - if (auto opt_fallback = TypeTraits::TryCastFromAnyView(src)) { - return TypeTraits::ConvertFallbackValue(*std::move(opt_fallback)); - } - if constexpr (sizeof...(Rest) > 0) { - return TryFallbackTypes(src); - } - return std::nullopt; - } - /// \endcond -}; - -/*! - * \brief Helper class to define ObjectRef that can be auto-converted from a - * fallback type, the Traits must be derived from it - * and define a static methods named ConvertFallbackValue for each - * FallbackType - * - * The conversion will go through the FallbackTypes in the order - * specified in the template parameter. - * \tparam ObjectRefType The type of the ObjectRef. - * \tparam FallbackTypes The type of the fallback value. - */ -template -struct ObjectRefWithFallbackTraitsBase : public ObjectRefTypeTraitsBase { - /// \cond Doxygen_Suppress - TVM_FFI_INLINE static std::optional TryCastFromAnyView(const TVMFFIAny* src) { - if (auto opt_obj = ObjectRefTypeTraitsBase::TryCastFromAnyView(src)) { - return *opt_obj; - } - // apply fallback types in TryCastFromAnyView - return TryFallbackTypes(src); - } - - template - TVM_FFI_INLINE static std::optional TryFallbackTypes(const TVMFFIAny* src) { - static_assert(!std::is_same_v, - "Using bool as FallbackType can cause bug because int will be detected as bool, " - "use tvm::ffi::StrictBool instead"); - if (auto opt_fallback = TypeTraits::TryCastFromAnyView(src)) { - return TypeTraits::ConvertFallbackValue(*std::move(opt_fallback)); - } - if constexpr (sizeof...(Rest) > 0) { - return TryFallbackTypes(src); - } - return std::nullopt; - } - /// \endcond -}; - -// Traits for weak pointer of object -// NOTE: we require the weak pointer cast from - -template -struct TypeTraits>> - : public TypeTraitsBase { - TVM_FFI_INLINE static void CopyToAnyView(TObject* src, TVMFFIAny* result) { - TVMFFIObject* obj_ptr = details::ObjectUnsafe::GetHeader(src); - result->type_index = obj_ptr->type_index; - result->zero_padding = 0; - TVM_FFI_CLEAR_PTR_PADDING_IN_FFI_ANY(result); - result->v_obj = obj_ptr; - } - - TVM_FFI_INLINE static void MoveToAny(TObject* src, TVMFFIAny* result) { - TVMFFIObject* obj_ptr = details::ObjectUnsafe::GetHeader(src); - result->type_index = obj_ptr->type_index; - result->zero_padding = 0; - TVM_FFI_CLEAR_PTR_PADDING_IN_FFI_ANY(result); - result->v_obj = obj_ptr; - // needs to increase ref because original weak ptr do not own the code - details::ObjectUnsafe::IncRefObjectHandle(result->v_obj); - } - - TVM_FFI_INLINE static bool CheckAnyStrict(const TVMFFIAny* src) { - return src->type_index >= TypeIndex::kTVMFFIStaticObjectBegin && - details::IsObjectInstance(src->type_index); - } - - TVM_FFI_INLINE static TObject* CopyFromAnyViewAfterCheck(const TVMFFIAny* src) { - if constexpr (!std::is_const_v) { - static_assert(TObject::_type_mutable, "TObject must be mutable to enable cast from Any"); - } - return details::ObjectUnsafe::RawObjectPtrFromUnowned(src->v_obj); - } - - TVM_FFI_INLINE static std::optional TryCastFromAnyView(const TVMFFIAny* src) { - if constexpr (!std::is_const_v) { - static_assert(TObject::_type_mutable, "TObject must be mutable to enable cast from Any"); - } - if (CheckAnyStrict(src)) return CopyFromAnyViewAfterCheck(src); - return std::nullopt; - } - - TVM_FFI_INLINE static std::string TypeStr() { return TObject::_type_key; } -}; - -template -inline constexpr bool use_default_type_traits_v> = false; - -template -struct TypeTraits> : public TypeTraitsBase { - TVM_FFI_INLINE static void CopyToAnyView(const Optional& src, TVMFFIAny* result) { - if (src.has_value()) { - TypeTraits::CopyToAnyView(*src, result); - } else { - TypeTraits::CopyToAnyView(nullptr, result); - } - } - - TVM_FFI_INLINE static void MoveToAny(Optional src, TVMFFIAny* result) { - if (src.has_value()) { - TypeTraits::MoveToAny(*std::move(src), result); - } else { - TypeTraits::CopyToAnyView(nullptr, result); - } - } - - TVM_FFI_INLINE static bool CheckAnyStrict(const TVMFFIAny* src) { - if (src->type_index == TypeIndex::kTVMFFINone) return true; - return TypeTraits::CheckAnyStrict(src); - } - - TVM_FFI_INLINE static Optional CopyFromAnyViewAfterCheck(const TVMFFIAny* src) { - if (src->type_index == TypeIndex::kTVMFFINone) { - return Optional(std::nullopt); - } - return TypeTraits::CopyFromAnyViewAfterCheck(src); - } - - TVM_FFI_INLINE static Optional MoveFromAnyAfterCheck(TVMFFIAny* src) { - if (src->type_index == TypeIndex::kTVMFFINone) { - return Optional(std::nullopt); - } - return TypeTraits::MoveFromAnyAfterCheck(src); - } - - TVM_FFI_INLINE static std::optional> TryCastFromAnyView(const TVMFFIAny* src) { - if (src->type_index == TypeIndex::kTVMFFINone) return Optional(std::nullopt); - if (std::optional opt = TypeTraits::TryCastFromAnyView(src)) { - return Optional(*std::move(opt)); - } else { - // important to be explicit here - // because nullopt can convert to std::optional(nullopt) which indicate success - // return std::optional>(std::nullopt) to indicate failure - return std::optional>(std::nullopt); - } - } - - TVM_FFI_INLINE static std::string GetMismatchTypeInfo(const TVMFFIAny* src) { - return TypeTraits::GetMismatchTypeInfo(src); - } - - TVM_FFI_INLINE static std::string TypeStr() { - return "Optional<" + TypeTraits::TypeStr() + ">"; - } -}; -} // namespace ffi -} // namespace tvm -#endif // TVM_FFI_TYPE_TRAITS_H_ diff --git a/ffi/licenses/LICENSE.dlpack.txt b/ffi/licenses/LICENSE.dlpack.txt deleted file mode 100644 index 20a9c8a7b4dc..000000000000 --- a/ffi/licenses/LICENSE.dlpack.txt +++ /dev/null @@ -1,201 +0,0 @@ - Apache License - Version 2.0, January 2004 - http://www.apache.org/licenses/ - - TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - - 1. Definitions. - - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. - - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. - - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. - - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). - - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - - 2. Grant of Copyright License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - - 3. Grant of Patent License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - - 4. Redistribution. You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. - - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - - 5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - - 6. Trademarks. This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - - 7. Disclaimer of Warranty. Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - - 8. Limitation of Liability. In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - - 9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. - - END OF TERMS AND CONDITIONS - - APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "{}" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. - - Copyright 2017 by Contributors - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. diff --git a/ffi/licenses/LICENSE.libbacktrace.txt b/ffi/licenses/LICENSE.libbacktrace.txt deleted file mode 100644 index e9e256244d69..000000000000 --- a/ffi/licenses/LICENSE.libbacktrace.txt +++ /dev/null @@ -1,29 +0,0 @@ -# Copyright (C) 2012-2016 Free Software Foundation, Inc. - -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are -# met: - -# (1) Redistributions of source code must retain the above copyright -# notice, this list of conditions and the following disclaimer. - -# (2) Redistributions in binary form must reproduce the above copyright -# notice, this list of conditions and the following disclaimer in -# the documentation and/or other materials provided with the -# distribution. - -# (3) The name of the author may not be used to -# endorse or promote products derived from this software without -# specific prior written permission. - -# THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR -# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED -# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, -# INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES -# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) -# HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, -# STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING -# IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE -# POSSIBILITY OF SUCH DAMAGE. diff --git a/ffi/licenses/LICENSE.pytorch.txt b/ffi/licenses/LICENSE.pytorch.txt deleted file mode 100644 index 966a609b61e5..000000000000 --- a/ffi/licenses/LICENSE.pytorch.txt +++ /dev/null @@ -1,84 +0,0 @@ -From PyTorch: - -Copyright (c) 2016- Facebook, Inc (Adam Paszke) -Copyright (c) 2014- Facebook, Inc (Soumith Chintala) -Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert) -Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu) -Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu) -Copyright (c) 2011-2013 NYU (Clement Farabet) -Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston) -Copyright (c) 2006 Idiap Research Institute (Samy Bengio) -Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz) - -From Caffe2: - -Copyright (c) 2016-present, Facebook Inc. All rights reserved. - -All contributions by Facebook: -Copyright (c) 2016 Facebook Inc. - -All contributions by Google: -Copyright (c) 2015 Google Inc. -All rights reserved. - -All contributions by Yangqing Jia: -Copyright (c) 2015 Yangqing Jia -All rights reserved. - -All contributions by Kakao Brain: -Copyright 2019-2020 Kakao Brain - -All contributions by Cruise LLC: -Copyright (c) 2022 Cruise LLC. -All rights reserved. - -All contributions by Tri Dao: -Copyright (c) 2024 Tri Dao. -All rights reserved. - -All contributions by Arm: -Copyright (c) 2021, 2023-2024 Arm Limited and/or its affiliates - -All contributions from Caffe: -Copyright(c) 2013, 2014, 2015, the respective contributors -All rights reserved. - -All other contributions: -Copyright(c) 2015, 2016 the respective contributors -All rights reserved. - -Caffe2 uses a copyright model similar to Caffe: each contributor holds -copyright over their contributions to Caffe2. The project versioning records -all such contribution and copyright details. If a contributor wants to further -mark their specific copyright on a particular contribution, they should -indicate their copyright solely in the commit message of the change when it is -committed. - -All rights reserved. - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are met: - -1. Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - -2. Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - -3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America - and IDIAP Research Institute nor the names of its contributors may be - used to endorse or promote products derived from this software without - specific prior written permission. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE -ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE -LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR -CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF -SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS -INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN -CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) -ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE -POSSIBILITY OF SUCH DAMAGE. diff --git a/ffi/licenses/NOTICE.pytorch.txt b/ffi/licenses/NOTICE.pytorch.txt deleted file mode 100644 index 6effb8b5d707..000000000000 --- a/ffi/licenses/NOTICE.pytorch.txt +++ /dev/null @@ -1,456 +0,0 @@ -======================================================================= -Software under third_party -======================================================================= -Software libraries under third_party are provided as github submodule -links, and their content is not part of the Caffe2 codebase. Their -licences can be found under the respective software repositories. - -======================================================================= -Earlier BSD License -======================================================================= -Early development of Caffe2 in 2015 and early 2016 is licensed under the -BSD license. The license is attached below: - -All contributions by Facebook: -Copyright (c) 2016 Facebook Inc. - -All contributions by Google: -Copyright (c) 2015 Google Inc. -All rights reserved. - -All contributions by Yangqing Jia: -Copyright (c) 2015 Yangqing Jia -All rights reserved. - -All contributions by Kakao Brain: -Copyright 2019-2020 Kakao Brain - -All other contributions: -Copyright(c) 2015, 2016 the respective contributors -All rights reserved. - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are met: - -1. Redistributions of source code must retain the above copyright notice, this - list of conditions and the following disclaimer. -2. Redistributions in binary form must reproduce the above copyright notice, - this list of conditions and the following disclaimer in the documentation - and/or other materials provided with the distribution. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND -ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED -WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR -ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES -(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; -LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND -ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS -SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - - -======================================================================= -Caffe's BSD License -======================================================================= -Some parts of the caffe2 code is derived from the original Caffe code, which is -created by Yangqing Jia and is now a BSD-licensed open-source project. The Caffe -license is as follows: - -COPYRIGHT - -All contributions by the University of California: -Copyright (c) 2014, The Regents of the University of California (Regents) -All rights reserved. - -All other contributions: -Copyright (c) 2014, the respective contributors -All rights reserved. - -Caffe uses a shared copyright model: each contributor holds copyright over -their contributions to Caffe. The project versioning records all such -contribution and copyright details. If a contributor wants to further mark -their specific copyright on a particular contribution, they should indicate -their copyright solely in the commit message of the change when it is -committed. - -LICENSE - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are met: - -1. Redistributions of source code must retain the above copyright notice, this - list of conditions and the following disclaimer. -2. Redistributions in binary form must reproduce the above copyright notice, - this list of conditions and the following disclaimer in the documentation - and/or other materials provided with the distribution. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND -ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED -WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR -ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES -(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; -LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND -ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS -SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -CONTRIBUTION AGREEMENT - -By contributing to the BVLC/caffe repository through pull-request, comment, -or otherwise, the contributor releases their content to the -license and copyright terms herein. - -======================================================================= -Caffe2's Apache License -======================================================================= - -This repo contains Caffe2 code, which was previously licensed under -Apache License Version 2.0: - - Apache License - Version 2.0, January 2004 - http://www.apache.org/licenses/ - - TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - - 1. Definitions. - - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. - - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. - - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. - - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). - - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - - 2. Grant of Copyright License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - - 3. Grant of Patent License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - - 4. Redistribution. You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. - - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - - 5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - - 6. Trademarks. This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - - 7. Disclaimer of Warranty. Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - - 8. Limitation of Liability. In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - - 9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. - -======================================================================= -Cephes's 3-Clause BSD License -======================================================================= - -Code derived from implementations in the Cephes Math Library should mention -its derivation and reference the following license: - - 3-Clause BSD License for the Cephes Math Library - Copyright (c) 2018, Steven Moshier - All rights reserved. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are met: - - * Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - * Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - * Neither the name of the nor the - names of its contributors may be used to endorse or promote products - derived from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND - ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED - WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - DISCLAIMED. IN NO EVENT SHALL Steven Moshier BE LIABLE FOR ANY - DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES - (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; - LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND - ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS - SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - - -======================================================================= -SciPy's 3-Clause BSD License -======================================================================= - -Code derived from implementations in SciPy should mention its derivation -and reference the following license: - - Copyright (c) 2001-2002 Enthought, Inc. 2003-2019, SciPy Developers. - All rights reserved. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions - are met: - - 1. Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - 2. Redistributions in binary form must reproduce the above - copyright notice, this list of conditions and the following - disclaimer in the documentation and/or other materials provided - with the distribution. - - 3. Neither the name of the copyright holder nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -======================================================================= -Boost's 1.0 Software License -======================================================================= - -Code derived from implementations in Boost 1.0 should mention its -derivation and reference the following license: - - Boost Software License - Version 1.0 - August 17th, 2003 - - Permission is hereby granted, free of charge, to any person or organization - obtaining a copy of the software and accompanying documentation covered by - this license (the "Software") to use, reproduce, display, distribute, - execute, and transmit the Software, and to prepare derivative works of the - Software, and to permit third-parties to whom the Software is furnished to - do so, all subject to the following: - - The copyright notices in the Software and this entire statement, including - the above license grant, this restriction and the following disclaimer, - must be included in all copies of the Software, in whole or in part, and - all derivative works of the Software, unless such copies or derivative - works are solely in the form of machine-executable object code generated by - a source language processor. - - THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - FITNESS FOR A PARTICULAR PURPOSE, TITLE AND NON-INFRINGEMENT. IN NO EVENT - SHALL THE COPYRIGHT HOLDERS OR ANYONE DISTRIBUTING THE SOFTWARE BE LIABLE - FOR ANY DAMAGES OR OTHER LIABILITY, WHETHER IN CONTRACT, TORT OR OTHERWISE, - ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER - DEALINGS IN THE SOFTWARE. - - END OF TERMS AND CONDITIONS - - APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "[]" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. - - Copyright [yyyy] [name of copyright owner] - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - -======================================================================= -PILLOW-SIMD Software License -======================================================================= - -Code derived from implementations in PILLOW-SIMD should mention its derivation -and reference the following license: - - The Python Imaging Library (PIL) is - - Copyright © 1997-2011 by Secret Labs AB - Copyright © 1995-2011 by Fredrik Lundh - - Pillow is the friendly PIL fork. It is - - Copyright © 2010-2022 by Alex Clark and contributors - - Like PIL, Pillow is licensed under the open source HPND License: - - By obtaining, using, and/or copying this software and/or its associated - documentation, you agree that you have read, understood, and will comply - with the following terms and conditions: - - Permission to use, copy, modify, and distribute this software and its - associated documentation for any purpose and without fee is hereby granted, - provided that the above copyright notice appears in all copies, and that - both that copyright notice and this permission notice appear in supporting - documentation, and that the name of Secret Labs AB or the author not be - used in advertising or publicity pertaining to distribution of the software - without specific, written prior permission. - - SECRET LABS AB AND THE AUTHOR DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS - SOFTWARE, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS. - IN NO EVENT SHALL SECRET LABS AB OR THE AUTHOR BE LIABLE FOR ANY SPECIAL, - INDIRECT OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM - LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE - OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR - PERFORMANCE OF THIS SOFTWARE. diff --git a/ffi/pyproject.toml b/ffi/pyproject.toml deleted file mode 100644 index cc2df03f0a6b..000000000000 --- a/ffi/pyproject.toml +++ /dev/null @@ -1,159 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -[project] -name = "apache-tvm-ffi" -version = "0.1.0a13" -description = "tvm ffi" - -authors = [{ name = "TVM FFI team" }] -readme = "README.md" -license = { text = "Apache 2.0" } -classifiers = [ - "License :: OSI Approved :: Apache Software License", - "Development Status :: 4 - Beta", - "Intended Audience :: Developers", - "Intended Audience :: Education", - "Intended Audience :: Science/Research", -] -keywords = ["machine learning", "inference"] -requires-python = ">=3.9" - -dependencies = [] - - -[project.urls] -Homepage = "https://github.com/apache/tvm/ffi" -GitHub = "https://github.com/apache/tvm/ffi" - -[project.optional-dependencies] -# setup tools is needed by torch jit for best perf -torch = ["torch", "setuptools", "ninja"] -cpp = ["ninja"] -test = ["pytest", "numpy", "torch", "ninja"] - -[project.scripts] -tvm-ffi-config = "tvm_ffi.config:__main__" - -[build-system] -requires = ["scikit-build-core>=0.10.0", "cython"] -build-backend = "scikit_build_core.build" - -[tool.scikit-build] -wheel.py-api = "cp312" -minimum-version = "build-system.requires" - -# Build configuration -build-dir = "build" -build.verbose = true - -# CMake configuration -cmake.version = "CMakeLists.txt" -cmake.build-type = "Release" -cmake.args = [ - "-DTVM_FFI_ATTACH_DEBUG_SYMBOLS=ON", - "-DTVM_FFI_BUILD_TESTS=OFF", - "-DTVM_FFI_BUILD_PYTHON_MODULE=ON" -] - -# Logging -logging.level = "INFO" - -# Wheel configuration -wheel.packages = ["python/tvm_ffi"] -wheel.install-dir = "tvm_ffi" - -# Source distribution configuration -sdist.include = [ - # Build files - "/CMakeLists.txt", - "/pyproject.toml", - "/cmake/**/*", - # Source code - "/src/**/*.cc", - "/include/**/*", - - # python and cython - "/python/tvm_ffi/**/*.py", - "/python/tvm_ffi/**/*.pyx", - "/python/tvm_ffi/**/*.pyi", - - # Third party files - "/3rdparty/libbacktrace/**/*", - "/3rdparty/dlpack/include/*/*", - - # Documentation and metadata - "/docs/**/*", - "/LICENSE", - "/README.md", - "/NOTICE", - - # Tests - "/tests/**/*", -] - -sdist.exclude = ["**/.git", "**/.github", "**/__pycache__", "**/*.pyc", "build", "dist"] - -[tool.pytest.ini_options] -testpaths = ["tests"] - -[tool.black] -exclude = "3rdparty/*" -line-length = 100 -skip-magic-trailing-comma = true - -[tool.isort] -profile = "black" -src_paths = ["python", "tests"] -extend_skip = ["3rdparty"] -line_length = 100 -skip_gitignore = true - -[tool.cibuildwheel] -build-verbosity = 1 - -# only build up to cp312, cp312 -# will be abi3 and can be used in future versions -build = [ - "cp39-*", - "cp310-*", - "cp311-*", - "cp312-*", -] -skip = [ - "*musllinux*" -] -# we only need to test on cp312 -test-skip = [ - "cp39-*", - "cp310-*", - "cp311-*", -] -# focus on testing abi3 wheel -build-frontend = "build[uv]" -test-command = "pytest {package}/tests/python -vvs" -test-extras = ["test"] - -[tool.cibuildwheel.linux] -archs = ["x86_64", "aarch64"] - -[tool.cibuildwheel.macos] -archs = ["x86_64", "arm64"] -environment = { MACOSX_DEPLOYMENT_TARGET = "10.14" } - -[tool.cibuildwheel.windows] -archs = ["AMD64"] diff --git a/ffi/python/tvm_ffi/.gitignore b/ffi/python/tvm_ffi/.gitignore deleted file mode 100644 index eeb15feab328..000000000000 --- a/ffi/python/tvm_ffi/.gitignore +++ /dev/null @@ -1,2 +0,0 @@ -core.cpp -core.cpython* diff --git a/ffi/python/tvm_ffi/__init__.py b/ffi/python/tvm_ffi/__init__.py deleted file mode 100644 index c23e8b59fee7..000000000000 --- a/ffi/python/tvm_ffi/__init__.py +++ /dev/null @@ -1,73 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""TVM FFI Python package.""" -# base always go first to load the libtvm_ffi -from . import base -from . import libinfo - -# package init part -from .registry import ( - register_object, - register_global_func, - get_global_func, - remove_global_func, - init_ffi_api, -) -from ._dtype import dtype -from .core import Object, ObjectConvertible, Function -from ._convert import convert -from .error import register_error -from ._tensor import Device, device, DLDeviceType -from ._tensor import from_dlpack, Tensor, Shape -from .container import Array, Map -from .module import Module, system_lib, load_module -from . import serialization -from . import access_path -from . import testing - -# optional module to speedup dlpack conversion -from . import _optional_torch_c_dlpack - -__all__ = [ - "dtype", - "Device", - "Object", - "register_object", - "register_global_func", - "get_global_func", - "remove_global_func", - "init_ffi_api", - "Object", - "ObjectConvertible", - "Function", - "convert", - "register_error", - "Device", - "device", - "DLDeviceType", - "from_dlpack", - "Tensor", - "Shape", - "Array", - "Map", - "testing", - "access_path", - "serialization", - "Module", - "system_lib", - "load_module", -] diff --git a/ffi/python/tvm_ffi/_convert.py b/ffi/python/tvm_ffi/_convert.py deleted file mode 100644 index a0b6c1b117e5..000000000000 --- a/ffi/python/tvm_ffi/_convert.py +++ /dev/null @@ -1,65 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""Conversion utilities to bring python objects into ffi values.""" -from numbers import Number -from typing import Any -from . import core -from . import container - - -def convert(value: Any) -> Any: - """Convert a python object to ffi values. - - Parameters - ---------- - value : Any - The python object to be converted. - - Returns - ------- - ffi_obj : Any - The converted TVM FFI object. - - Note - ---- - Function arguments to ffi function calls are - automatically converted. So this function is mainly - only used in internal or testing scenarios. - """ - if isinstance(value, (core.Object, core.PyNativeObject, bool, Number)): - return value - elif isinstance(value, (tuple, list)): - return container.Array(value) - elif isinstance(value, dict): - return container.Map(value) - elif isinstance(value, str): - return core.String(value) - elif isinstance(value, (bytes, bytearray)): - return core.Bytes(value) - elif isinstance(value, core.ObjectConvertible): - return value.asobject() - elif callable(value): - return core._convert_to_ffi_func(value) - elif value is None: - return None - elif hasattr(value, "__dlpack__"): - return core.from_dlpack(value) - elif isinstance(value, Exception): - return core._convert_to_ffi_error(value) - else: - # in this case, it is an opaque python object - return core._convert_to_opaque_object(value) diff --git a/ffi/python/tvm_ffi/_dtype.py b/ffi/python/tvm_ffi/_dtype.py deleted file mode 100644 index 30409e41d1cf..000000000000 --- a/ffi/python/tvm_ffi/_dtype.py +++ /dev/null @@ -1,141 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""dtype class.""" -# pylint: disable=invalid-name -from enum import IntEnum - -from . import core - - -class DataTypeCode(IntEnum): - """DLDataTypeCode code in DLTensor.""" - - INT = 0 - UINT = 1 - FLOAT = 2 - HANDLE = 3 - BFLOAT = 4 - Float8E3M4 = 7 - Float8E4M3 = 8 - Float8E4M3B11FNUZ = 9 - Float8E4M3FN = 10 - Float8E4M3FNUZ = 11 - Float8E5M2 = 12 - Float8E5M2FNUZ = 13 - Float8E8M0FNU = 14 - Float6E2M3FN = 15 - Float6E3M2FN = 16 - Float4E2M1FN = 17 - - -class dtype(str): - """TVM FFI dtype class. - - Parameters - ---------- - dtype_str : str - - Note - ---- - This class subclasses str so it can be directly passed - into other array api's dtype arguments. - """ - - __slots__ = ["__tvm_ffi_dtype__"] - - _NUMPY_DTYPE_TO_STR = {} - - def __new__(cls, content): - content = str(content) - val = str.__new__(cls, content) - val.__tvm_ffi_dtype__ = core.DataType(content) - return val - - def __repr__(self): - return f"dtype('{self}')" - - def with_lanes(self, lanes): - """ - Create a new dtype with the given number of lanes. - - Parameters - ---------- - lanes : int - The number of lanes. - - Returns - ------- - dtype - The new dtype with the given number of lanes. - """ - cdtype = core._create_dtype_from_tuple( - core.DataType, self.__tvm_ffi_dtype__.type_code, self.__tvm_ffi_dtype__.bits, lanes - ) - val = str.__new__(dtype, str(cdtype)) - val.__tvm_ffi_dtype__ = cdtype - return val - - @property - def itemsize(self): - return self.__tvm_ffi_dtype__.itemsize - - @property - def type_code(self): - return self.__tvm_ffi_dtype__.type_code - - @property - def bits(self): - return self.__tvm_ffi_dtype__.bits - - @property - def lanes(self): - return self.__tvm_ffi_dtype__.lanes - - -try: - # this helps to make numpy as optional - # although almost in all cases we want numpy - import numpy as np - - dtype._NUMPY_DTYPE_TO_STR[np.dtype(np.bool_)] = "bool" - dtype._NUMPY_DTYPE_TO_STR[np.dtype(np.int8)] = "int8" - dtype._NUMPY_DTYPE_TO_STR[np.dtype(np.int16)] = "int16" - dtype._NUMPY_DTYPE_TO_STR[np.dtype(np.int32)] = "int32" - dtype._NUMPY_DTYPE_TO_STR[np.dtype(np.int64)] = "int64" - dtype._NUMPY_DTYPE_TO_STR[np.dtype(np.uint8)] = "uint8" - dtype._NUMPY_DTYPE_TO_STR[np.dtype(np.uint16)] = "uint16" - dtype._NUMPY_DTYPE_TO_STR[np.dtype(np.uint32)] = "uint32" - dtype._NUMPY_DTYPE_TO_STR[np.dtype(np.uint64)] = "uint64" - dtype._NUMPY_DTYPE_TO_STR[np.dtype(np.float16)] = "float16" - dtype._NUMPY_DTYPE_TO_STR[np.dtype(np.float32)] = "float32" - dtype._NUMPY_DTYPE_TO_STR[np.dtype(np.float64)] = "float64" - if hasattr(np, "float_"): - dtype._NUMPY_DTYPE_TO_STR[np.dtype(np.float_)] = "float64" -except ImportError: - pass - -try: - import ml_dtypes - - dtype._NUMPY_DTYPE_TO_STR[np.dtype(ml_dtypes.bfloat16)] = "bfloat16" - dtype._NUMPY_DTYPE_TO_STR[np.dtype(ml_dtypes.float8_e4m3fn)] = "float8_e4m3fn" - dtype._NUMPY_DTYPE_TO_STR[np.dtype(ml_dtypes.float8_e5m2)] = "float8_e5m2" - dtype._NUMPY_DTYPE_TO_STR[np.dtype(ml_dtypes.float4_e2m1fn)] = "float4_e2m1fn" -except ImportError: - pass - -core._set_class_dtype(dtype) diff --git a/ffi/python/tvm_ffi/_ffi_api.py b/ffi/python/tvm_ffi/_ffi_api.py deleted file mode 100644 index 1c2326c0fefd..000000000000 --- a/ffi/python/tvm_ffi/_ffi_api.py +++ /dev/null @@ -1,20 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""FFI API.""" -from . import registry - -registry.init_ffi_api("ffi", __name__) diff --git a/ffi/python/tvm_ffi/_optional_torch_c_dlpack.py b/ffi/python/tvm_ffi/_optional_torch_c_dlpack.py deleted file mode 100644 index f44855247abe..000000000000 --- a/ffi/python/tvm_ffi/_optional_torch_c_dlpack.py +++ /dev/null @@ -1,404 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""Optional module to support faster DLPack conversion. - -This is an optional module to support faster DLPack conversion for torch. -Some of the changes are merged but not yet released, so it is used -as a stop gap to support faster DLPack conversion. - -This file contains source code from PyTorch: -License: licenses/LICENSE.pytorch.txt - -This module only serves as temp measure and will -likely be phased away and deleted after changes landed and released in pytorch. - -This module will load slowly at first time due to JITing, -subsequent calls will be much faster. -""" -import warnings -from . import libinfo - - -def load_torch_c_dlpack_extension(): - """Load the torch c dlpack extension.""" - cpp_source = """ -#include -#include -#include -#include - -using namespace std; -namespace at { -namespace { - -DLDataType getDLDataTypeForDLPackv1(const Tensor& t) { - DLDataType dtype; - dtype.lanes = 1; - dtype.bits = t.element_size() * 8; - switch (t.scalar_type()) { - case ScalarType::UInt1: - case ScalarType::UInt2: - case ScalarType::UInt3: - case ScalarType::UInt4: - case ScalarType::UInt5: - case ScalarType::UInt6: - case ScalarType::UInt7: - case ScalarType::Byte: - case ScalarType::UInt16: - case ScalarType::UInt32: - case ScalarType::UInt64: - dtype.code = DLDataTypeCode::kDLUInt; - break; - case ScalarType::Int1: - case ScalarType::Int2: - case ScalarType::Int3: - case ScalarType::Int4: - case ScalarType::Int5: - case ScalarType::Int6: - case ScalarType::Int7: - case ScalarType::Char: - dtype.code = DLDataTypeCode::kDLInt; - break; - case ScalarType::Double: - dtype.code = DLDataTypeCode::kDLFloat; - break; - case ScalarType::Float: - dtype.code = DLDataTypeCode::kDLFloat; - break; - case ScalarType::Int: - dtype.code = DLDataTypeCode::kDLInt; - break; - case ScalarType::Long: - dtype.code = DLDataTypeCode::kDLInt; - break; - case ScalarType::Short: - dtype.code = DLDataTypeCode::kDLInt; - break; - case ScalarType::Half: - dtype.code = DLDataTypeCode::kDLFloat; - break; - case ScalarType::Bool: - dtype.code = DLDataTypeCode::kDLBool; - break; - case ScalarType::ComplexHalf: - case ScalarType::ComplexFloat: - case ScalarType::ComplexDouble: - dtype.code = DLDataTypeCode::kDLComplex; - break; - case ScalarType::BFloat16: - dtype.code = DLDataTypeCode::kDLBfloat; - break; - case ScalarType::Float8_e5m2: - dtype.code = DLDataTypeCode::kDLFloat8_e5m2; - break; - case ScalarType::Float8_e5m2fnuz: - dtype.code = DLDataTypeCode::kDLFloat8_e5m2fnuz; - break; - case ScalarType::Float8_e4m3fn: - dtype.code = DLDataTypeCode::kDLFloat8_e4m3fn; - break; - case ScalarType::Float8_e4m3fnuz: - dtype.code = DLDataTypeCode::kDLFloat8_e4m3fnuz; - break; - case ScalarType::Float8_e8m0fnu: - dtype.code = DLDataTypeCode::kDLFloat8_e8m0fnu; - break; -#if TORCH_VERSION_MAJOR >= 2 && TORCH_VERSION_MINOR >= 8 - case ScalarType::Float4_e2m1fn_x2: - dtype.code = DLDataTypeCode::kDLFloat4_e2m1fn; - break; -#endif - default: - TORCH_CHECK(false, "Unsupported scalar type: "); - } - return dtype; -} - -DLDevice torchDeviceToDLDeviceForDLPackv1(at::Device device) { - DLDevice ctx; - - ctx.device_id = (device.is_cuda() || device.is_privateuseone()) - ? static_cast(static_cast(device.index())) - : 0; - - switch (device.type()) { - case DeviceType::CPU: - ctx.device_type = DLDeviceType::kDLCPU; - break; - case DeviceType::CUDA: -#ifdef USE_ROCM - ctx.device_type = DLDeviceType::kDLROCM; -#else - ctx.device_type = DLDeviceType::kDLCUDA; -#endif - break; - case DeviceType::OPENCL: - ctx.device_type = DLDeviceType::kDLOpenCL; - break; - case DeviceType::HIP: - ctx.device_type = DLDeviceType::kDLROCM; - break; - case DeviceType::XPU: - ctx.device_type = DLDeviceType::kDLOneAPI; - ctx.device_id = at::detail::getXPUHooks().getGlobalIdxFromDevice(device); - break; - case DeviceType::MAIA: - ctx.device_type = DLDeviceType::kDLMAIA; - break; - case DeviceType::PrivateUse1: - ctx.device_type = DLDeviceType::kDLExtDev; - break; - case DeviceType::MPS: - ctx.device_type = DLDeviceType::kDLMetal; - break; - default: - TORCH_CHECK(false, "Cannot pack tensors on " + device.str()); - } - - return ctx; -} - -template -struct ATenDLMTensor { - Tensor handle; - T tensor{}; -}; - -template -void deleter(T* arg) { - delete static_cast*>(arg->manager_ctx); -} - -// Adds version information for DLManagedTensorVersioned. -// This is a no-op for the other types. -template -void fillVersion(T* tensor) {} - -template <> -void fillVersion( - DLManagedTensorVersioned* tensor) { - tensor->flags = 0; - tensor->version.major = DLPACK_MAJOR_VERSION; - tensor->version.minor = DLPACK_MINOR_VERSION; -} - -// This function returns a shared_ptr to memory managed DLpack tensor -// constructed out of ATen tensor -template -T* toDLPackImpl(const Tensor& src) { - auto view = src; - - bool need_normalize_strides = false; - int64_t expected_stride = 1; - for (int i = src.dim() - 1; i >= 0; i--) { - // detect if we do not meet continuous pattern - // and the size is 1, so there is opportunity to normalize - if (src.stride(i) != expected_stride && src.size(i) == 1) { - need_normalize_strides = true; - break; - } - expected_stride *= src.size(i); - } - - // less common case, try normalizing the strides - if (need_normalize_strides) { - // create a new tensor with possibly normalized strides - // gh-83069 - auto shape = src.sizes(); - auto strides = src.strides().vec(); - for (int i = 0; i < src.dim(); i++) { - if (shape[i] < 2) { - strides[i] = 1; - } - } - view = src.as_strided(shape, strides, src.storage_offset()); - } - - ATenDLMTensor* atDLMTensor(new ATenDLMTensor); - atDLMTensor->handle = view; - atDLMTensor->tensor.manager_ctx = atDLMTensor; - atDLMTensor->tensor.deleter = &deleter; - atDLMTensor->tensor.dl_tensor.data = view.data_ptr(); - atDLMTensor->tensor.dl_tensor.device = torchDeviceToDLDeviceForDLPackv1(src.device()); - atDLMTensor->tensor.dl_tensor.ndim = static_cast(src.dim()); - atDLMTensor->tensor.dl_tensor.dtype = getDLDataTypeForDLPackv1(src); - atDLMTensor->tensor.dl_tensor.shape = const_cast(view.sizes().data()); - atDLMTensor->tensor.dl_tensor.strides = const_cast(view.strides().data()); - atDLMTensor->tensor.dl_tensor.byte_offset = 0; - fillVersion(&atDLMTensor->tensor); - return &(atDLMTensor->tensor); -} - -static Device getATenDeviceForDLPackv1(DLDeviceType type, c10::DeviceIndex index, void* data = nullptr) { - switch (type) { - case DLDeviceType::kDLCPU: - return at::Device(DeviceType::CPU); -#ifndef USE_ROCM - // if we are compiled under HIP, we cannot do cuda - case DLDeviceType::kDLCUDA: - return at::Device(DeviceType::CUDA, index); -#endif - case DLDeviceType::kDLOpenCL: - return at::Device(DeviceType::OPENCL, index); - case DLDeviceType::kDLROCM: -#ifdef USE_ROCM - // this looks funny, we need to return CUDA here to masquerade - return at::Device(DeviceType::CUDA, index); -#else - return at::Device(DeviceType::HIP, index); -#endif - case DLDeviceType::kDLOneAPI: - TORCH_CHECK(data != nullptr, "Can't get ATen device for XPU without XPU data."); - return at::detail::getXPUHooks().getDeviceFromPtr(data); - case DLDeviceType::kDLMAIA: - return at::Device(DeviceType::MAIA, index); - case DLDeviceType::kDLExtDev: - return at::Device(DeviceType::PrivateUse1, index); - case DLDeviceType::kDLMetal: - return at::Device(DeviceType::MPS, index); - default: - TORCH_CHECK( - false, "Unsupported device_type: ", std::to_string(type)); - } -} - - -// This function constructs a Tensor from a memory managed DLPack which -// may be represented as either: DLManagedTensor and DLManagedTensorVersioned. -template -at::Tensor fromDLPackImpl(T* src, std::function deleter) { - if (!deleter) { - deleter = [src](void* self [[maybe_unused]]) { - if (src->deleter) { - src->deleter(src); - } - }; - } - - DLTensor& dl_tensor = src->dl_tensor; - Device device = getATenDeviceForDLPackv1(dl_tensor.device.device_type, dl_tensor.device.device_id, dl_tensor.data); - ScalarType stype = toScalarType(dl_tensor.dtype); - - if (!dl_tensor.strides) { - return at::from_blob( - dl_tensor.data, - IntArrayRef(dl_tensor.shape, dl_tensor.ndim), - std::move(deleter), - at::device(device).dtype(stype), - {device}); - } - return at::from_blob( - dl_tensor.data, - IntArrayRef(dl_tensor.shape, dl_tensor.ndim), - IntArrayRef(dl_tensor.strides, dl_tensor.ndim), - deleter, - at::device(device).dtype(stype), - {device}); -} - -} // namespace -} // namespace at - -int TorchDLPackFromPyObject(void* py_obj, DLManagedTensorVersioned** out, void** env_stream) { - try { - py::handle handle(static_cast(py_obj)); - at::Tensor tensor = handle.cast(); - if (env_stream != nullptr && tensor.is_cuda()) { - *env_stream = at::cuda::getCurrentCUDAStream(tensor.device().index()).stream(); - } - *out = at::toDLPackImpl(tensor); - return 0; - } catch (const std::exception& e) { - PyErr_SetString(PyExc_RuntimeError, e.what()); - return -1; - } -} - -int TorchDLPackToPyObject(DLManagedTensorVersioned* src, void** py_obj_out) { - try { - at::Tensor tensor = at::fromDLPackImpl(src, nullptr); - *py_obj_out = THPVariable_Wrap(tensor); - return 0; - } catch (const std::exception& e) { - PyErr_SetString(PyExc_RuntimeError, e.what()); - return -1; - } -} - -int TorchDLPackTensorAllocator( - DLTensor* prototype, DLManagedTensorVersioned** out, void* error_ctx, - void (*SetError)(void* error_ctx, const char* kind, const char* message) -) { - try { - at::IntArrayRef shape(prototype->shape, prototype->shape + prototype->ndim); - at::TensorOptions options = at::TensorOptions() - .dtype(at::toScalarType(prototype->dtype)) - .device(at::getATenDeviceForDLPackv1(prototype->device.device_type, prototype->device.device_id)); - at::Tensor tensor = at::empty(shape, options); - *out = at::toDLPackImpl(tensor); - return 0; - } catch (const std::exception& e) { - SetError(error_ctx, "TorchDLPackTensorAllocator", e.what()); - return -1; - } -} - -int64_t TorchDLPackFromPyObjectPtr() { - return reinterpret_cast(TorchDLPackFromPyObject); -} - -int64_t TorchDLPackToPyObjectPtr() { - return reinterpret_cast(TorchDLPackToPyObject); -} - -int64_t TorchDLPackTensorAllocatorPtr() { - return reinterpret_cast(TorchDLPackTensorAllocator); -} - """ - try: - # optionally import torch - import torch - from torch.utils import cpp_extension - - mod = cpp_extension.load_inline( - name="to_dlpack", - cpp_sources=cpp_source, - functions=[ - "TorchDLPackFromPyObjectPtr", - "TorchDLPackToPyObjectPtr", - "TorchDLPackTensorAllocatorPtr", - ], - extra_cflags=["-O3"], - extra_include_paths=libinfo.include_paths() + cpp_extension.include_paths("cuda"), - ) - # set the dlpack related flags - torch.Tensor.__c_dlpack_from_pyobject__ = mod.TorchDLPackFromPyObjectPtr() - torch.Tensor.__c_dlpack_to_pyobject__ = mod.TorchDLPackToPyObjectPtr() - torch.Tensor.__c_dlpack_tensor_allocator__ = mod.TorchDLPackTensorAllocatorPtr() - return mod - except ImportError: - pass - except Exception as e: - warnings.warn( - f"Failed to load torch c dlpack extension: {e}," - "EnvTensorAllocator will not be enabled." - ) - return None - - -# keep alive -_mod = load_torch_c_dlpack_extension() diff --git a/ffi/python/tvm_ffi/_tensor.py b/ffi/python/tvm_ffi/_tensor.py deleted file mode 100644 index c0c9a20731f4..000000000000 --- a/ffi/python/tvm_ffi/_tensor.py +++ /dev/null @@ -1,88 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""Tensor related objects and functions.""" -# we name it as _tensor.py to avoid potential future case -# if we also want to expose a tensor function in the root namespace - -from numbers import Integral -from . import core -from .core import Device, DLDeviceType, Tensor, from_dlpack -from . import registry -from . import _ffi_api - - -@registry.register_object("ffi.Shape") -class Shape(tuple, core.PyNativeObject): - """Shape tuple that represents `ffi::Shape` returned by a ffi call. - - Note - ---- - This class subclasses `tuple` so it can be used in most places where - tuple is used in python array apis. - """ - - def __new__(cls, content): - if any(not isinstance(x, Integral) for x in content): - raise ValueError("Shape must be a tuple of integers") - val = tuple.__new__(cls, content) - val.__init_tvm_ffi_object_by_constructor__(_ffi_api.Shape, *content) - return val - - # pylint: disable=no-self-argument - def __from_tvm_ffi_object__(cls, obj): - """Construct from a given tvm object.""" - content = core._shape_obj_get_py_tuple(obj) - val = tuple.__new__(cls, content) - val.__tvm_ffi_object__ = obj - return val - - -def device(device_type, index=None): - """Construct a TVM FFI device with given device type and index - - Parameters - ---------- - device_type: str or int - The device type or name. - - index: int, optional - The device index. - - Returns - ------- - device: tvm_ffi.Device - - Examples - -------- - Device can be used to create reflection of device by - string representation of the device type. - - .. code-block:: python - - assert tvm_ffi.device("cuda:0") == tvm_ffi.device("cuda", 0) - assert tvm_ffi.device("cpu:0") == tvm_ffi.device("cpu", 0) - """ - return core._CLASS_DEVICE(device_type, index) - - -__all__ = [ - "from_dlpack", - "Tensor", - "device", - "Device", - "DLDeviceType", -] diff --git a/ffi/python/tvm_ffi/access_path.py b/ffi/python/tvm_ffi/access_path.py deleted file mode 100644 index fb8ab1b2edea..000000000000 --- a/ffi/python/tvm_ffi/access_path.py +++ /dev/null @@ -1,181 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=invalid-name -"""Access path classes.""" - -from enum import IntEnum -from typing import List, Any -from . import core -from .registry import register_object - - -class AccessKind(IntEnum): - ATTR = 0 - ARRAY_ITEM = 1 - MAP_ITEM = 2 - ATTR_MISSING = 3 - ARRAY_ITEM_MISSING = 4 - MAP_ITEM_MISSING = 5 - - -@register_object("ffi.reflection.AccessStep") -class AccessStep(core.Object): - """Access step container""" - - -@register_object("ffi.reflection.AccessPath") -class AccessPath(core.Object): - """Access path container""" - - def __init__(self) -> None: - super().__init__() - raise ValueError( - "AccessPath can't be initialized directly. " - "Use AccessPath.root() to create a path to the root object" - ) - - @staticmethod - def root() -> "AccessPath": - """Create a root access path""" - return AccessPath._root() - - def __eq__(self, other: Any) -> bool: - if not isinstance(other, AccessPath): - return False - return self._path_equal(other) - - def __ne__(self, other: Any) -> bool: - if not isinstance(other, AccessPath): - return True - return not self._path_equal(other) - - def is_prefix_of(self, other: "AccessPath") -> bool: - """Check if this access path is a prefix of another access path - - Parameters - ---------- - other : AccessPath - The access path to check if it is a prefix of this access path - - Returns - ------- - bool - True if this access path is a prefix of the other access path, False otherwise - """ - return self._is_prefix_of(other) - - def attr(self, attr_key: str) -> "AccessPath": - """Create an access path to the attribute of the current object - - Parameters - ---------- - attr_key : str - The key of the attribute to access - - Returns - ------- - AccessPath - The extended access path - """ - return self._attr(attr_key) - - def attr_missing(self, attr_key: str) -> "AccessPath": - """Create an access path that indicate an attribute is missing - - Parameters - ---------- - attr_key : str - The key of the attribute to access - - Returns - ------- - AccessPath - The extended access path - """ - return self._attr_missing(attr_key) - - def array_item(self, index: int) -> "AccessPath": - """Create an access path to the item of the current array - - Parameters - ---------- - index : int - The index of the item to access - - Returns - ------- - AccessPath - The extended access path - """ - return self._array_item(index) - - def array_item_missing(self, index: int) -> "AccessPath": - """Create an access path that indicate an array item is missing - - Parameters - ---------- - index : int - The index of the item to access - - Returns - ------- - AccessPath - The extended access path - """ - return self._array_item_missing(index) - - def map_item(self, key: Any) -> "AccessPath": - """Create an access path to the item of the current map - - Parameters - ---------- - key : Any - The key of the item to access - - Returns - ------- - AccessPath - The extended access path - """ - return self._map_item(key) - - def map_item_missing(self, key: Any) -> "AccessPath": - """Create an access path that indicate a map item is missing - - Parameters - ---------- - key : Any - The key of the item to access - - Returns - ------- - AccessPath - The extended access path - """ - return self._map_item_missing(key) - - def to_steps(self) -> List["AccessStep"]: - """Convert the access path to a list of access steps - - Returns - ------- - List[AccessStep] - The list of access steps - """ - return self._to_steps() - - __hash__ = core.Object.__hash__ diff --git a/ffi/python/tvm_ffi/base.py b/ffi/python/tvm_ffi/base.py deleted file mode 100644 index 2fcd70b54183..000000000000 --- a/ffi/python/tvm_ffi/base.py +++ /dev/null @@ -1,53 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# coding: utf-8 -"""Base library for TVM FFI.""" -import ctypes -import os -import sys -import subprocess -import logging -from . import libinfo - -logger = logging.getLogger(__name__) - -# ---------------------------- -# Python3 version. -# ---------------------------- -if not (sys.version_info[0] >= 3 and sys.version_info[1] >= 9): - PY3STATEMENT = "The minimal Python requirement is Python 3.9" - raise Exception(PY3STATEMENT) - -# ---------------------------- -# library loading -# ---------------------------- - - -def _load_lib(): - """Load libary by searching possible path.""" - lib_path = libinfo.find_libtvm_ffi() - # The dll search path need to be added explicitly in windows - if sys.platform.startswith("win32"): - for path in libinfo.get_dll_directories(): - os.add_dll_directory(path) - - lib = ctypes.CDLL(lib_path, ctypes.RTLD_GLOBAL) - return lib - - -# library instance -_LIB = _load_lib() diff --git a/ffi/python/tvm_ffi/config.py b/ffi/python/tvm_ffi/config.py deleted file mode 100644 index b81ecdec3dc2..000000000000 --- a/ffi/python/tvm_ffi/config.py +++ /dev/null @@ -1,92 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""Config utilities for finding paths to lib and headers""" - -import argparse -import sys -import os -from . import libinfo - - -def find_windows_implib(): - libdir = os.path.dirname(libinfo.find_libtvm_ffi()) - implib = os.path.join(libdir, "tvm_ffi.lib") - if not os.path.isfile(implib): - raise RuntimeError(f"Cannot find imp lib {implib}") - return implib - - -def __main__(): - """Main function""" - parser = argparse.ArgumentParser( - description="Get various configuration information needed to compile with tvm-ffi", - ) - - parser.add_argument("--includedir", action="store_true", help="Print include directory") - parser.add_argument( - "--dlpack-includedir", action="store_true", help="Print dlpack include directory" - ) - parser.add_argument("--cmakedir", action="store_true", help="Print library directory") - parser.add_argument("--sourcedir", action="store_true", help="Print source directory") - parser.add_argument("--libfiles", action="store_true", help="Fully qualified library filenames") - parser.add_argument("--libdir", action="store_true", help="Print library directory") - parser.add_argument("--libs", action="store_true", help="Libraries to be linked") - parser.add_argument("--cython-lib-path", action="store_true", help="Print cython path") - parser.add_argument("--cxxflags", action="store_true", help="Print cxx flags") - parser.add_argument("--ldflags", action="store_true", help="Print ld flags") - - args = parser.parse_args() - - # print help when no arguments are provided - if len(sys.argv) == 1: - parser.print_help() - return - - if args.includedir: - print(libinfo.find_include_path()) - if args.dlpack_includedir: - print(libinfo.find_dlpack_include_path()) - if args.cmakedir: - print(libinfo.find_cmake_path()) - if args.libdir: - print(os.path.dirname(libinfo.find_libtvm_ffi())) - if args.libfiles: - if sys.platform.startswith("win32"): - print(find_windows_implib()) - else: - print(libinfo.find_libtvm_ffi()) - if args.sourcedir: - print(libinfo.find_source_path()) - if args.cython_lib_path: - print(libinfo.find_cython_lib()) - if args.cxxflags: - include_dir = libinfo.find_include_path() - dlpack_include_dir = libinfo.find_dlpack_include_path() - print(f"-I{include_dir} -I{dlpack_include_dir} -std=c++17") - if args.libs: - if sys.platform.startswith("win32"): - print(find_windows_implib()) - else: - print("-ltvm_ffi") - - if args.ldflags: - if not sys.platform.startswith("win32"): - print(f"-L{os.path.dirname(libinfo.find_libtvm_ffi())}") - - -if __name__ == "__main__": - __main__() diff --git a/ffi/python/tvm_ffi/container.py b/ffi/python/tvm_ffi/container.py deleted file mode 100644 index fedc0a281ba8..000000000000 --- a/ffi/python/tvm_ffi/container.py +++ /dev/null @@ -1,252 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""Container classes.""" -import collections.abc - -from typing import Any, Mapping, Sequence -from . import core -from . import _ffi_api -from .registry import register_object - -__all__ = ["Array", "Map"] - - -def getitem_helper(obj, elem_getter, length, idx): - """Helper function to implement a pythonic getitem function. - - Parameters - ---------- - obj: object - The original object - - elem_getter : function - A simple function that takes index and return a single element. - - length : int - The size of the array - - idx : int or slice - The argument passed to getitem - - Returns - ------- - result : object - The result of getitem - """ - if isinstance(idx, slice): - start = idx.start if idx.start is not None else 0 - stop = idx.stop if idx.stop is not None else length - step = idx.step if idx.step is not None else 1 - if start < 0: - start += length - if stop < 0: - stop += length - return [elem_getter(obj, i) for i in range(start, stop, step)] - - if idx < -length or idx >= length: - raise IndexError(f"Index out of range. size: {length}, got index {idx}") - if idx < 0: - idx += length - return elem_getter(obj, idx) - - -@register_object("ffi.Array") -class Array(core.Object, collections.abc.Sequence): - """Array container that represents a sequence of values in ffi. - - {py:func}`tvm_ffi.convert` will map python list/tuple to this class. - - Parameters - ---------- - input_list : Sequence[Any] - The list of values to be stored in the array. - - See Also - -------- - {py:func}`tvm_ffi.convert` - - Examples - -------- - .. code-block:: python - - import tvm_ffi - - a = tvm_ffi.convert([1, 2, 3]) - assert isinstance(a, tvm_ffi.Array) - assert len(a) == 3 - """ - - def __init__(self, input_list: Sequence[Any]): - self.__init_handle_by_constructor__(_ffi_api.Array, *input_list) - - def __getitem__(self, idx): - return getitem_helper(self, _ffi_api.ArrayGetItem, len(self), idx) - - def __len__(self): - return _ffi_api.ArraySize(self) - - def __repr__(self): - # exception safety handling for chandle=None - if self.__chandle__() == 0: - return type(self).__name__ + "(chandle=None)" - return "[" + ", ".join([x.__repr__() for x in self]) + "]" - - -class KeysView(collections.abc.KeysView): - """Helper class to return keys view""" - - def __init__(self, backend_map): - self._backend_map = backend_map - - def __len__(self): - return len(self._backend_map) - - def __iter__(self): - if self.__len__() == 0: - return - functor = _ffi_api.MapForwardIterFunctor(self._backend_map) - while True: - k = functor(0) - yield k - if not functor(2): - break - - def __contains__(self, k): - return self._backend_map.__contains__(k) - - -class ValuesView(collections.abc.ValuesView): - """Helper class to return values view""" - - def __init__(self, backend_map): - self._backend_map = backend_map - - def __len__(self): - return len(self._backend_map) - - def __iter__(self): - if self.__len__() == 0: - return - functor = _ffi_api.MapForwardIterFunctor(self._backend_map) - while True: - v = functor(1) - yield v - if not functor(2): - break - - -class ItemsView(collections.abc.ItemsView): - """Helper class to return items view""" - - def __init__(self, backend_map): - self.backend_map = backend_map - - def __len__(self): - return len(self.backend_map) - - def __iter__(self): - if self.__len__() == 0: - return - functor = _ffi_api.MapForwardIterFunctor(self.backend_map) - while True: - k = functor(0) - v = functor(1) - yield (k, v) - if not functor(2): - break - - -@register_object("ffi.Map") -class Map(core.Object, collections.abc.Mapping): - """Map container. - - {py:func}`tvm_ffi.convert` will map python dict to this class. - - Parameters - ---------- - input_dict : Mapping[Any, Any] - The dictionary of values to be stored in the map. - - See Also - -------- - {py:func}`tvm_ffi.convert` - - Examples - -------- - .. code-block:: python - - import tvm_ffi - - amap = tvm_ffi.convert({"a": 1, "b": 2}) - assert isinstance(amap, tvm_ffi.Map) - assert len(amap) == 2 - assert amap["a"] == 1 - assert amap["b"] == 2 - """ - - def __init__(self, input_dict: Mapping[Any, Any]): - list_kvs = [] - for k, v in input_dict.items(): - list_kvs.append(k) - list_kvs.append(v) - self.__init_handle_by_constructor__(_ffi_api.Map, *list_kvs) - - def __getitem__(self, k): - return _ffi_api.MapGetItem(self, k) - - def __contains__(self, k): - return _ffi_api.MapCount(self, k) != 0 - - def keys(self): - return KeysView(self) - - def values(self): - return ValuesView(self) - - def items(self): - """Get the items from the map""" - return ItemsView(self) - - def __len__(self): - return _ffi_api.MapSize(self) - - def __iter__(self): - return iter(self.keys()) - - def get(self, key, default=None): - """Get an element with a default value. - - Parameters - ---------- - key : object - The attribute key. - - default : object - The default object. - - Returns - ------- - value: object - The result value. - """ - return self[key] if key in self else default - - def __repr__(self): - # exception safety handling for chandle=None - if self.__chandle__() == 0: - return type(self).__name__ + "(chandle=None)" - return "{" + ", ".join([f"{k.__repr__()}: {v.__repr__()}" for k, v in self.items()]) + "}" diff --git a/ffi/python/tvm_ffi/cpp/__init__.py b/ffi/python/tvm_ffi/cpp/__init__.py deleted file mode 100644 index 632698f4431a..000000000000 --- a/ffi/python/tvm_ffi/cpp/__init__.py +++ /dev/null @@ -1,18 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -from .load_inline import load_inline diff --git a/ffi/python/tvm_ffi/cpp/load_inline.py b/ffi/python/tvm_ffi/cpp/load_inline.py deleted file mode 100644 index 3bc0fc4cbc73..000000000000 --- a/ffi/python/tvm_ffi/cpp/load_inline.py +++ /dev/null @@ -1,437 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -from typing import Sequence, Optional, Mapping -import os -import sys -import glob -import hashlib -import shutil -import subprocess -import functools - -from tvm_ffi.module import Module, load_module -from tvm_ffi.utils import FileLock -from tvm_ffi.libinfo import find_include_path, find_dlpack_include_path, find_libtvm_ffi - -IS_WINDOWS = sys.platform == "win32" - - -def _hash_sources( - cpp_source: str, - cuda_source: str, - functions: Sequence[str] | Mapping[str, str], - extra_cflags: Sequence[str], - extra_cuda_cflags: Sequence[str], - extra_ldflags: Sequence[str], - extra_include_paths: Sequence[str], -) -> str: - """Generate a unique hash for the given sources and functions.""" - m = hashlib.sha256() - m.update(cpp_source.encode("utf-8")) - m.update(cuda_source.encode("utf-8")) - if isinstance(functions, Mapping): - for name in sorted(functions): - m.update(name.encode("utf-8")) - m.update(functions[name].encode("utf-8")) - else: - for name in sorted(functions): - m.update(name.encode("utf-8")) - for flag in extra_cflags: - m.update(flag.encode("utf-8")) - for flag in extra_cuda_cflags: - m.update(flag.encode("utf-8")) - for flag in extra_ldflags: - m.update(flag.encode("utf-8")) - for path in extra_include_paths: - m.update(path.encode("utf-8")) - return m.hexdigest()[:16] - - -def _maybe_write(path: str, content: str) -> None: - """Write content to path if it does not already exist with the same content.""" - if os.path.exists(path): - with open(path, "r") as f: - existing_content = f.read() - if existing_content == content: - return - with open(path, "w") as f: - f.write(content) - - -@functools.lru_cache -def _find_cuda_home() -> Optional[str]: - """Find the CUDA install path.""" - # Guess #1 - cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH") - if cuda_home is None: - # Guess #2 - nvcc_path = shutil.which("nvcc") - if nvcc_path is not None: - cuda_home = os.path.dirname(os.path.dirname(nvcc_path)) - else: - # Guess #3 - if IS_WINDOWS: - cuda_homes = glob.glob("C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v*.*") - if len(cuda_homes) == 0: - cuda_home = "" - else: - cuda_home = cuda_homes[0] - else: - cuda_home = "/usr/local/cuda" - if not os.path.exists(cuda_home): - raise RuntimeError( - "Could not find CUDA installation. " - "Please set CUDA_HOME environment variable." - ) - return cuda_home - - -def _get_cuda_target() -> str: - """Get the CUDA target architecture flag.""" - if "TVM_FFI_CUDA_ARCH_LIST" in os.environ: - arch_list = os.environ["TVM_FFI_CUDA_ARCH_LIST"].split() # e.g., "8.9 9.0a" - flags = [] - for arch in arch_list: - if len(arch.split(".")) != 2: - raise ValueError(f"Invalid CUDA architecture: {arch}") - major, minor = arch.split(".") - flags.append(f"-gencode=arch=compute_{major}{minor},code=sm_{major}{minor}") - return " ".join(flags) - else: - # - try: - status = subprocess.run( - args=["nvidia-smi", "--query-gpu=compute_cap", "--format=csv,noheader"], - capture_output=True, - check=True, - ) - compute_cap = status.stdout.decode("utf-8").strip().split("\n")[0] - major, minor = compute_cap.split(".") - return f"-gencode=arch=compute_{major}{minor},code=sm_{major}{minor}" - except Exception: - # fallback to a reasonable default - return "-gencode=arch=compute_70,code=sm_70" - - -def _generate_ninja_build( - name: str, - build_dir: str, - with_cuda: bool, - extra_cflags: Sequence[str], - extra_cuda_cflags: Sequence[str], - extra_ldflags: Sequence[str], - extra_include_paths: Sequence[str], -) -> str: - """Generate the content of build.ninja for building the module.""" - default_include_paths = [find_include_path(), find_dlpack_include_path()] - - tvm_ffi_lib = find_libtvm_ffi() - tvm_ffi_lib_path = os.path.dirname(tvm_ffi_lib) - tvm_ffi_lib_name = os.path.splitext(os.path.basename(tvm_ffi_lib))[0] - if IS_WINDOWS: - default_cflags = [ - "/std:c++17", - "/MD", - "/wd4819", - "/wd4251", - "/wd4244", - "/wd4267", - "/wd4275", - "/wd4018", - "/wd4190", - "/wd4624", - "/wd4067", - "/wd4068", - "/EHsc", - ] - default_cuda_cflags = ["-Xcompiler", "/std:c++17", "/O2"] - default_ldflags = ["/DLL", f"/LIBPATH:{tvm_ffi_lib_path}", f"{tvm_ffi_lib_name}.lib"] - else: - default_cflags = ["-std=c++17", "-fPIC", "-O2"] - default_cuda_cflags = ["-Xcompiler", "-fPIC", "-std=c++17", "-O2"] - default_ldflags = ["-shared", "-L{}".format(tvm_ffi_lib_path), "-ltvm_ffi"] - - if with_cuda: - # determine the compute capability of the current GPU - default_cuda_cflags += [_get_cuda_target()] - default_ldflags += ["-L{}".format(os.path.join(_find_cuda_home(), "lib64")), "-lcudart"] - - cflags = default_cflags + [flag.strip() for flag in extra_cflags] - cuda_cflags = default_cuda_cflags + [flag.strip() for flag in extra_cuda_cflags] - ldflags = default_ldflags + [flag.strip() for flag in extra_ldflags] - include_paths = default_include_paths + [os.path.abspath(path) for path in extra_include_paths] - - # append include paths - for path in include_paths: - cflags.append("-I{}".format(path.replace(":", "$:"))) - cuda_cflags.append("-I{}".format(path.replace(":", "$:"))) - - # flags - ninja = [] - ninja.append("ninja_required_version = 1.3") - ninja.append("cxx = {}".format(os.environ.get("CXX", "cl" if IS_WINDOWS else "c++"))) - ninja.append("cflags = {}".format(" ".join(cflags))) - if with_cuda: - ninja.append("nvcc = {}".format(os.path.join(_find_cuda_home(), "bin", "nvcc"))) - ninja.append("cuda_cflags = {}".format(" ".join(cuda_cflags))) - ninja.append("ldflags = {}".format(" ".join(ldflags))) - - # rules - ninja.append("") - ninja.append("rule compile") - if IS_WINDOWS: - ninja.append(" command = $cxx /showIncludes $cflags -c $in /Fo$out") - ninja.append(" deps = msvc") - else: - ninja.append(" depfile = $out.d") - ninja.append(" deps = gcc") - ninja.append(" command = $cxx -MMD -MF $out.d $cflags -c $in -o $out") - ninja.append("") - - if with_cuda: - ninja.append("rule compile_cuda") - ninja.append(" depfile = $out.d") - ninja.append(" deps = gcc") - ninja.append( - " command = $nvcc --generate-dependencies-with-compile --dependency-output $out.d $cuda_cflags -c $in -o $out" - ) - ninja.append("") - - ninja.append("rule link") - if IS_WINDOWS: - ninja.append(" command = $cxx $in /link $ldflags /out:$out") - else: - ninja.append(" command = $cxx $in $ldflags -o $out") - ninja.append("") - - # build targets - ninja.append( - "build main.o: compile {}".format( - os.path.abspath(os.path.join(build_dir, "main.cpp")).replace(":", "$:") - ) - ) - if with_cuda: - ninja.append( - "build cuda.o: compile_cuda {}".format( - os.path.abspath(os.path.join(build_dir, "cuda.cu")).replace(":", "$:") - ) - ) - # Use appropriate extension based on platform - ext = ".dll" if IS_WINDOWS else ".so" - ninja.append("build {}{}: link main.o{}".format(name, ext, " cuda.o" if with_cuda else "")) - ninja.append("") - - # default target - ninja.append("default {}{}".format(name, ext)) - ninja.append("") - return "\n".join(ninja) - - -def _build_ninja(build_dir: str) -> None: - """Build the module in the given build directory using ninja.""" - command = ["ninja", "-v"] - num_workers = os.environ.get("MAX_JOBS", None) - if num_workers is not None: - command += ["-j", num_workers] - status = subprocess.run(args=command, cwd=build_dir, capture_output=True) - if status.returncode != 0: - msg = ["ninja exited with status {}".format(status.returncode)] - encoding = "oem" if IS_WINDOWS else "utf-8" - if status.stdout: - msg.append("stdout:\n{}".format(status.stdout.decode(encoding))) - if status.stderr: - msg.append("stderr:\n{}".format(status.stderr.decode(encoding))) - - raise RuntimeError("\n".join(msg)) - - -def _decorate_with_tvm_ffi(source: str, functions: Mapping[str, str]) -> str: - """Decorate the given source code with TVM FFI export macros.""" - sources = [ - "#include ", - "#include ", - "#include ", - "#include ", - "", - source, - ] - - for func_name, func_doc in functions.items(): - sources.append(f"TVM_FFI_DLL_EXPORT_TYPED_FUNC({func_name}, {func_name});") - _ = func_doc # todo: add support to embed function docstring to the tvm ffi functions. - - sources.append("") - - return "\n".join(sources) - - -def load_inline( - name: str, - *, - cpp_sources: str | None = None, - cuda_sources: str | None = None, - functions: Sequence[str] | None = None, - extra_cflags: Sequence[str] | None = None, - extra_cuda_cflags: Sequence[str] | None = None, - extra_ldflags: Sequence[str] | None = None, - extra_include_paths: Sequence[str] | None = None, - build_directory: Optional[str] = None, -) -> Module: - """Compile and load a C++/CUDA tvm ffi module from inline source code. - - This function compiles the given C++ and/or CUDA source code into a shared library. Both cpp_sources and - cuda_sources are compiled to an object file, and then linked together into a shared library. It's possible to only - provide cpp_sources or cuda_sources. - - The `functions` parameter is used to specify which functions in the source code should be exported to the tvm ffi module. - It can be a mapping, a sequence, or a single string. When a mapping is given, the keys are the names of the exported - functions, and the values are docstrings for the functions. When a sequence or a single string is given, they are the - functions needed to be exported, and the docstrings are set to empty strings. A single function name can also be given - as a string, indicating that only one function is to be exported. - - Extra compiler and linker flags can be provided via the `extra_cflags`, `extra_cuda_cflags`, and `extra_ldflags` - parameters. The default flags are generally sufficient for most use cases, but you may need to provide additional - flags for your specific use case. - - The include dir of tvm ffi and dlpack are used by default for linker to find the headers. Thus, you can include - any header from tvm ffi and dlpack in your source code. You can also provide additional include paths via the - `extra_include_paths` parameter and include custom headers in your source code. - - The compiled shared library is cached in a cache directory to avoid recompilation. The `build_directory` parameter - is provided to specify the build directory. If not specified, a default tvm ffi cache directory will be used. - The default cache directory can be specified via the `TVM_FFI_CACHE_DIR` environment variable. If not specified, - the default cache directory is `~/.cache/tvm-ffi`. - - Parameters - ---------- - name: str - The name of the tvm ffi module. - cpp_sources: Sequence[str] | str, optional - The C++ source code. It can be a list of sources or a single source. - cuda_sources: Sequence[str] | str, optional - The CUDA source code. It can be a list of sources or a single source. - functions: Mapping[str, str] | Sequence[str] | str, optional - The functions in cpp_sources or cuda_source that will be exported to the tvm ffi module. When a mapping is - given, the keys are the names of the exported functions, and the values are docstrings for the functions. When - a sequence or a single string is given, they are the functions needed to be exported, and the docstrings are set - to empty strings. A single function name can also be given as a string. When cpp_sources is given, the functions - must be declared (not necessarily defined) in the cpp_sources. When cpp_sources is not given, the functions - must be defined in the cuda_sources. If not specified, no function will be exported. - extra_cflags: Sequence[str], optional - The extra compiler flags for C++ compilation. - The default flags are: - - On Linux/macOS: ['-std=c++17', '-fPIC', '-O2'] - - On Windows: ['/std:c++17'] - extra_cuda_cflags: - The extra compiler flags for CUDA compilation. - The default flags are: - - On Linux/macOS: ['-Xcompiler', '-fPIC', '-std=c++17', '-O2'] - - On Windows: ['-Xcompiler', '/std:c++17', '/O2'] - extra_ldflags: Sequence[str], optional - The extra linker flags. - The default flags are: - - On Linux/macOS: ['-shared'] - - On Windows: ['/DLL'] - extra_include_paths: Sequence[str], optional - The extra include paths. - The default include paths are: - - The include path of tvm ffi - build_directory: str, optional - The build directory. If not specified, a default tvm ffi cache directory will be used. By default, the - cache directory is `~/.cache/tvm-ffi`. You can also set the `TVM_FFI_CACHE_DIR` environment variable to - specify the cache directory. - - Returns - ------- - mod: Module - The loaded tvm ffi module. - """ - if cpp_sources is None: - cpp_sources = [] - elif isinstance(cpp_sources, str): - cpp_sources = [cpp_sources] - cpp_source = "\n".join(cpp_sources) - if cuda_sources is None: - cuda_sources = [] - elif isinstance(cuda_sources, str): - cuda_sources = [cuda_sources] - cuda_source = "\n".join(cuda_sources) - with_cpp = len(cpp_sources) > 0 - with_cuda = len(cuda_sources) > 0 - - extra_ldflags = extra_ldflags or [] - extra_cflags = extra_cflags or [] - extra_cuda_cflags = extra_cuda_cflags or [] - extra_include_paths = extra_include_paths or [] - - # add function registration code to sources - if isinstance(functions, str): - functions = {functions: ""} - elif isinstance(functions, Sequence): - functions = {name: "" for name in functions} - - if with_cpp: - cpp_source = _decorate_with_tvm_ffi(cpp_source, functions) - cuda_source = _decorate_with_tvm_ffi(cuda_source, {}) - else: - cpp_source = _decorate_with_tvm_ffi(cpp_source, {}) - cuda_source = _decorate_with_tvm_ffi(cuda_source, functions) - - # determine the cache dir for the built module - if build_directory is None: - build_directory = os.environ.get( - "TVM_FFI_CACHE_DIR", os.path.expanduser("~/.cache/tvm-ffi") - ) - source_hash: str = _hash_sources( - cpp_source, - cuda_source, - functions, - extra_cflags, - extra_cuda_cflags, - extra_ldflags, - extra_include_paths, - ) - build_dir: str = os.path.join(build_directory, "{}_{}".format(name, source_hash)) - else: - build_dir = os.path.abspath(build_directory) - os.makedirs(build_dir, exist_ok=True) - - # generate build.ninja - ninja_source = _generate_ninja_build( - name=name, - build_dir=build_dir, - with_cuda=with_cuda, - extra_cflags=extra_cflags, - extra_cuda_cflags=extra_cuda_cflags, - extra_ldflags=extra_ldflags, - extra_include_paths=extra_include_paths, - ) - - with FileLock(os.path.join(build_dir, "lock")): - # write source files and build.ninja if they do not already exist - _maybe_write(os.path.join(build_dir, "main.cpp"), cpp_source) - if with_cuda: - _maybe_write(os.path.join(build_dir, "cuda.cu"), cuda_source) - _maybe_write(os.path.join(build_dir, "build.ninja"), ninja_source) - - # build the module - _build_ninja(build_dir) - - # Use appropriate extension based on platform - ext = ".dll" if IS_WINDOWS else ".so" - return load_module(os.path.abspath(os.path.join(build_dir, "{}{}".format(name, ext)))) diff --git a/ffi/python/tvm_ffi/cython/base.pxi b/ffi/python/tvm_ffi/cython/base.pxi deleted file mode 100644 index ef583c752908..000000000000 --- a/ffi/python/tvm_ffi/cython/base.pxi +++ /dev/null @@ -1,393 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -import ctypes -from libc.stdint cimport int32_t, int64_t, uint64_t, uint32_t, uint8_t, int16_t -from libc.string cimport memcpy -from libcpp.vector cimport vector -from cpython.bytes cimport PyBytes_AsStringAndSize, PyBytes_FromStringAndSize, PyBytes_AsString -from cpython cimport Py_INCREF, Py_DECREF -from cpython cimport PyErr_CheckSignals, PyGILState_Ensure, PyGILState_Release, PyObject -from cpython cimport pycapsule, PyCapsule_Destructor -from cpython cimport PyErr_SetNone - -cdef extern from "dlpack/dlpack.h": - cdef enum: - kDLCPU = 1, - kDLCUDA = 2, - kDLCUDAHost = 3, - kDLOpenCL = 4, - kDLVulkan = 7, - kDLMetal = 8, - kDLVPI = 9, - kDLROCM = 10, - kDLROCMHost = 11, - kDLExtDev = 12, - kDLCUDAManaged = 13, - kDLOneAPI = 14, - kDLWebGPU = 15, - kDLHexagon = 16, - kDLMAIA = 17 - kDLTrn = 18 - - ctypedef struct DLDataType: - uint8_t code - uint8_t bits - int16_t lanes - - ctypedef struct DLDevice: - int device_type - int device_id - - ctypedef struct DLTensor: - void* data - DLDevice device - int ndim - DLDataType dtype - int64_t* shape - int64_t* strides - uint64_t byte_offset - - ctypedef struct DLPackVersion: - uint32_t major - uint32_t minor - - ctypedef struct DLManagedTensor: - DLTensor dl_tensor - void* manager_ctx - void (*deleter)(DLManagedTensor* self) - - ctypedef struct DLManagedTensorVersioned: - DLPackVersion version - DLTensor dl_tensor - void* manager_ctx - void (*deleter)(DLManagedTensorVersioned* self) - uint64_t flags - - -# Cython binding for TVM FFI C API -cdef extern from "tvm/ffi/c_api.h": - cdef enum TVMFFITypeIndex: - kTVMFFIAny = -1 - kTVMFFINone = 0 - kTVMFFIInt = 1 - kTVMFFIBool = 2 - kTVMFFIFloat = 3 - kTVMFFIOpaquePtr = 4 - kTVMFFIDataType = 5 - kTVMFFIDevice = 6 - kTVMFFIDLTensorPtr = 7 - kTVMFFIRawStr = 8 - kTVMFFIByteArrayPtr = 9 - kTVMFFIObjectRValueRef = 10 - kTVMFFISmallStr = 11 - kTVMFFISmallBytes = 12 - kTVMFFIStaticObjectBegin = 64 - kTVMFFIObject = 64 - kTVMFFIStr = 65 - kTVMFFIBytes = 66 - kTVMFFIError = 67 - kTVMFFIFunction = 68 - kTVMFFIShape = 69 - kTVMFFITensor = 70 - kTVMFFIArray = 71 - kTVMFFIMap = 72 - kTVMFFIModule = 73 - kTVMFFIOpaquePyObject = 74 - - - ctypedef void* TVMFFIObjectHandle - - ctypedef struct TVMFFIObject: - int32_t type_index - int32_t ref_counter - void (*deleter)(TVMFFIObject* self) - - ctypedef struct TVMFFIAny: - int32_t type_index - int32_t zero_padding - int64_t v_int64 - double v_float64 - void* v_ptr - TVMFFIObject* v_obj - const char* v_c_str - DLDataType v_dtype - DLDevice v_device - - ctypedef struct TVMFFIByteArray: - const char* data - size_t size - - ctypedef struct TVMFFIOpaqueObjectCell: - void* handle - - ctypedef struct TVMFFIShapeCell: - const int64_t* data - size_t size - - ctypedef struct TVMFFIErrorCell: - TVMFFIByteArray kind - TVMFFIByteArray message - TVMFFIByteArray traceback - void (*update_traceback)(TVMFFIObjectHandle self, const TVMFFIByteArray* traceback) - - ctypedef int (*TVMFFISafeCallType)( - void* handle, const TVMFFIAny* args, int32_t num_args, - TVMFFIAny* result) noexcept - - cdef enum TVMFFIFieldFlagBitMask: - kTVMFFIFieldFlagBitMaskWritable = 1 << 0 - kTVMFFIFieldFlagBitMaskHasDefault = 1 << 1 - kTVMFFIFieldFlagBitMaskIsStaticMethod = 1 << 2 - - ctypedef int (*TVMFFIFieldGetter)(void* field, TVMFFIAny* result) noexcept; - ctypedef int (*TVMFFIFieldSetter)(void* field, const TVMFFIAny* value) noexcept; - ctypedef int (*TVMFFIObjectCreator)(TVMFFIObjectHandle* result) noexcept; - - ctypedef struct TVMFFIFieldInfo: - TVMFFIByteArray name - TVMFFIByteArray doc - TVMFFIByteArray type_schema - int64_t flags - int64_t size - int64_t alignment - int64_t offset - TVMFFIFieldGetter getter - TVMFFIFieldSetter setter - TVMFFIAny default_value - int32_t field_static_type_index - - ctypedef struct TVMFFIMethodInfo: - TVMFFIByteArray name - TVMFFIByteArray doc - TVMFFIByteArray type_schema - int64_t flags - TVMFFIAny method - - ctypedef struct TVMFFITypeMetadata: - TVMFFIByteArray doc - TVMFFIObjectCreator creator - int64_t total_size - - ctypedef struct TVMFFITypeInfo: - int32_t type_index - int32_t type_depth - TVMFFIByteArray type_key - const int32_t* type_acenstors - uint64_t type_key_hash - int32_t num_fields - int32_t num_methods - const TVMFFIFieldInfo* fields - const TVMFFIMethodInfo* methods - const TVMFFITypeMetadata* metadata - - int TVMFFIObjectDecRef(TVMFFIObjectHandle obj) nogil - int TVMFFIObjectIncRef(TVMFFIObjectHandle obj) nogil - int TVMFFIObjectCreateOpaque(void* handle, int32_t type_index, - void (*deleter)(void*), TVMFFIObjectHandle* out) nogil - int TVMFFIObjectGetTypeIndex(TVMFFIObjectHandle obj) nogil - int TVMFFIFunctionCall(TVMFFIObjectHandle func, TVMFFIAny* args, int32_t num_args, - TVMFFIAny* result) nogil - int TVMFFIFunctionCreate(void* self, TVMFFISafeCallType safe_call, - void (*deleter)(void*), TVMFFIObjectHandle* out) nogil - int TVMFFIAnyViewToOwnedAny(const TVMFFIAny* any_view, TVMFFIAny* out) nogil - int TVMFFIFunctionSetGlobal(TVMFFIByteArray* name, TVMFFIObjectHandle f, int override) nogil - int TVMFFIFunctionGetGlobal(TVMFFIByteArray* name, TVMFFIObjectHandle* out) nogil - void TVMFFIErrorMoveFromRaised(TVMFFIObjectHandle* result) nogil - void TVMFFIErrorSetRaised(TVMFFIObjectHandle error) nogil - TVMFFIObjectHandle TVMFFIErrorCreate(TVMFFIByteArray* kind, TVMFFIByteArray* message, - TVMFFIByteArray* traceback) nogil - - int TVMFFITypeKeyToIndex(TVMFFIByteArray* type_key, int32_t* out_tindex) nogil - int TVMFFIStringFromByteArray(TVMFFIByteArray* input_, TVMFFIAny* out) nogil - int TVMFFIBytesFromByteArray(TVMFFIByteArray* input_, TVMFFIAny* out) nogil - int TVMFFIDataTypeFromString(TVMFFIByteArray* str, DLDataType* out) nogil - int TVMFFIDataTypeToString(const DLDataType* dtype, TVMFFIAny* out) nogil - const TVMFFIByteArray* TVMFFITraceback( - const char* filename, int lineno, const char* func, int cross_ffi_boundary) nogil; - int TVMFFITensorFromDLPack(DLManagedTensor* src, int32_t require_alignment, - int32_t require_contiguous, TVMFFIObjectHandle* out) nogil - int TVMFFITensorFromDLPackVersioned(DLManagedTensorVersioned* src, - int32_t require_alignment, - int32_t require_contiguous, - TVMFFIObjectHandle* out) nogil - int TVMFFITensorToDLPack(TVMFFIObjectHandle src, DLManagedTensor** out) nogil - int TVMFFITensorToDLPackVersioned(TVMFFIObjectHandle src, - DLManagedTensorVersioned** out) nogil - const TVMFFITypeInfo* TVMFFIGetTypeInfo(int32_t type_index) nogil - TVMFFIByteArray TVMFFISmallBytesGetContentByteArray(const TVMFFIAny* value) nogil - TVMFFIByteArray* TVMFFIBytesGetByteArrayPtr(TVMFFIObjectHandle obj) nogil - TVMFFIErrorCell* TVMFFIErrorGetCellPtr(TVMFFIObjectHandle obj) nogil - TVMFFIOpaqueObjectCell* TVMFFIOpaqueObjectGetCellPtr(TVMFFIObjectHandle obj) nogil - TVMFFIShapeCell* TVMFFIShapeGetCellPtr(TVMFFIObjectHandle obj) nogil - DLTensor* TVMFFITensorGetDLTensorPtr(TVMFFIObjectHandle obj) nogil - DLDevice TVMFFIDLDeviceFromIntPair(int32_t device_type, int32_t device_id) nogil - -cdef extern from "tvm/ffi/extra/c_env_api.h": - ctypedef void* TVMFFIStreamHandle - - int TVMFFIEnvRegisterCAPI(const char* name, void* ptr) nogil - void* TVMFFIEnvGetStream(int32_t device_type, int32_t device_id) nogil - int TVMFFIEnvSetStream(int32_t device_type, int32_t device_id, - TVMFFIStreamHandle stream, - TVMFFIStreamHandle* opt_out_original_stream) nogil - - -cdef extern from "tvm_ffi_python_helpers.h": - # no need to expose fields of the call context - # setter data structure - ctypedef int (*DLPackFromPyObject)( - void* py_obj, DLManagedTensorVersioned** out, TVMFFIStreamHandle* env_stream - ) except -1 - - ctypedef int (*DLPackToPyObject)( - DLManagedTensorVersioned* tensor, void** py_obj_out - ) except -1 - ctypedef int (*DLPackTensorAllocator)( - DLTensor* prototype, DLManagedTensorVersioned** out, void* error_ctx, - void (*SetError)(void* error_ctx, const char* kind, const char* message) - ) except -1 - - ctypedef struct TVMFFIPyCallContext: - int device_type - int device_id - TVMFFIStreamHandle stream - DLPackToPyObject c_dlpack_to_pyobject - DLPackTensorAllocator c_dlpack_tensor_allocator - - ctypedef struct TVMFFIPyArgSetter: - int (*func)(TVMFFIPyArgSetter* handle, TVMFFIPyCallContext* ctx, PyObject* py_arg, TVMFFIAny* out) except -1 - DLPackFromPyObject c_dlpack_from_pyobject - DLPackToPyObject c_dlpack_to_pyobject - DLPackTensorAllocator c_dlpack_tensor_allocator - - ctypedef int (*TVMFFIPyArgSetterFactory)(PyObject* value, TVMFFIPyArgSetter* out) except -1 - # The main call function - int TVMFFIPyFuncCall( - TVMFFIPyArgSetterFactory setter_factory, - void* chandle, - PyObject* py_arg_tuple, - TVMFFIAny* result, - int* c_api_ret_code, - int release_gil, - DLPackToPyObject* out_dlpack_importer - ) except -1 - - int TVMFFIPyConstructorCall( - TVMFFIPyArgSetterFactory setter_factory, - void* chandle, - PyObject* py_arg_tuple, - TVMFFIAny* result, - int* c_api_ret_code, - TVMFFIPyCallContext* parent_ctx - ) except -1 - - int TVMFFIPyCallFieldSetter( - TVMFFIPyArgSetterFactory setter_factory, - TVMFFIFieldSetter field_setter, - void* field_ptr, - PyObject* py_arg, - int* c_api_ret_code - ) except -1 - - int TVMFFIPyPyObjectToFFIAny( - TVMFFIPyArgSetterFactory setter_factory, - PyObject* py_arg, - TVMFFIAny* out, - int* c_api_ret_code - ) except -1 - - size_t TVMFFIPyGetDispatchMapSize() noexcept - - void TVMFFIPyPushTempFFIObject(TVMFFIPyCallContext* ctx, TVMFFIObjectHandle arg) noexcept - void TVMFFIPyPushTempPyObject(TVMFFIPyCallContext* ctx, PyObject* arg) noexcept - # the predefined setters for common POD types - int TVMFFIPyArgSetterFloat_(TVMFFIPyArgSetter*, TVMFFIPyCallContext*, PyObject* arg, TVMFFIAny* out) except -1 - int TVMFFIPyArgSetterInt_(TVMFFIPyArgSetter*, TVMFFIPyCallContext*, PyObject* arg, TVMFFIAny* out) except -1 - int TVMFFIPyArgSetterBool_(TVMFFIPyArgSetter*, TVMFFIPyCallContext*, PyObject* arg, TVMFFIAny* out) except -1 - int TVMFFIPyArgSetterNone_(TVMFFIPyArgSetter*, TVMFFIPyCallContext*, PyObject* arg, TVMFFIAny* out) except -1 - - -cdef class ByteArrayArg: - cdef TVMFFIByteArray cdata - cdef object py_data - - def __cinit__(self, py_data): - if isinstance(py_data, bytearray): - py_data = bytes(py_data) - cdef char* data - cdef Py_ssize_t size - self.py_data = py_data - PyBytes_AsStringAndSize(py_data, &data, &size) - self.cdata.data = data - self.cdata.size = size - - cdef inline TVMFFIByteArray* cptr(self): - return &self.cdata - - -cdef inline py_str(const char* x): - """Convert a c_char_p to a python string - - Parameters - ---------- - x : c_char_p - A char pointer that can be passed to C API - """ - return x.decode("utf-8") - - -cdef inline str bytearray_to_str(const TVMFFIByteArray* x): - return PyBytes_FromStringAndSize(x.data, x.size).decode("utf-8") - - -cdef inline c_str(pystr): - """Create ctypes char * from a python string - - Parameters - ---------- - string : string type - python string - - Returns - ------- - str : c_char_p - A char pointer that can be passed to C API - """ - return pystr.encode("utf-8") - - -cdef inline object ctypes_handle(void* chandle): - """Cast C handle to ctypes handle.""" - return ctypes.cast(chandle, ctypes.c_void_p) - - -cdef inline void* c_handle(object handle): - """Cast C types handle to c handle.""" - cdef unsigned long long v_ptr - v_ptr = handle.value - return (v_ptr) - - -cdef _init_env_api(): - # Initialize env api for signal handling - # Also registers the gil state release and ensure as PyErr_CheckSignals - # function is called with gil released and we need to regrab the gil - CHECK_CALL(TVMFFIEnvRegisterCAPI(c_str("PyErr_CheckSignals"), PyErr_CheckSignals)) - CHECK_CALL(TVMFFIEnvRegisterCAPI(c_str("PyGILState_Ensure"), PyGILState_Ensure)) - CHECK_CALL(TVMFFIEnvRegisterCAPI(c_str("PyGILState_Release"), PyGILState_Release)) - -_init_env_api() diff --git a/ffi/python/tvm_ffi/cython/core.pyx b/ffi/python/tvm_ffi/cython/core.pyx deleted file mode 100644 index b24a83da7c1d..000000000000 --- a/ffi/python/tvm_ffi/cython/core.pyx +++ /dev/null @@ -1,26 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - - -include "./base.pxi" -include "./dtype.pxi" -include "./device.pxi" -include "./object.pxi" -include "./error.pxi" -include "./string.pxi" -include "./tensor.pxi" -include "./function.pxi" diff --git a/ffi/python/tvm_ffi/cython/device.pxi b/ffi/python/tvm_ffi/cython/device.pxi deleted file mode 100644 index 85740a067a63..000000000000 --- a/ffi/python/tvm_ffi/cython/device.pxi +++ /dev/null @@ -1,191 +0,0 @@ - - -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -from enum import IntEnum - -_CLASS_DEVICE = None - -def _set_class_device(cls): - global _CLASS_DEVICE - _CLASS_DEVICE = cls - - -def _create_device_from_tuple(cls, device_type, device_id): - cdef DLDevice cdevice = TVMFFIDLDeviceFromIntPair(device_type, device_id) - ret = cls.__new__(cls) - (ret).cdevice = cdevice - return ret - - -class DLDeviceType(IntEnum): - """The enum that maps to DLDeviceType.""" - kDLCPU = 1 - kDLCUDA = 2 - kDLCUDAHost = 3 - kDLOpenCL = 4 - kDLVulkan = 7 - kDLMetal = 8 - kDLVPI = 9 - kDLROCM = 10 - kDLROCMHost = 11 - kDLExtDev = 12 - kDLCUDAManaged = 13 - kDLOneAPI = 14 - kDLWebGPU = 15 - kDLHexagon = 16 - - -cdef class Device: - """Device represents a device in the ffi system. - - Device is a thin wrapper around DLDevice in DLPack standard. - - Parameters - ---------- - device_type : Union[str, int] - The string representation of the device type - - index : int - The device id - - Examples - -------- - You can use `tvm_ffi.device` function to create a `Device`. - - .. code-block:: python - - assert tvm_ffi.device("cuda:0") == tvm_ffi.device("cuda", 0) - assert tvm_ffi.device("cpu:0") == tvm_ffi.device("cpu", 0) - """ - cdef DLDevice cdevice - - _DEVICE_TYPE_TO_NAME = { - DLDeviceType.kDLCPU: "cpu", - DLDeviceType.kDLCUDA: "cuda", - DLDeviceType.kDLCUDAHost: "cuda_host", - DLDeviceType.kDLCUDAManaged: "cuda_managed", - DLDeviceType.kDLOpenCL: "opencl", - DLDeviceType.kDLVulkan: "vulkan", - DLDeviceType.kDLMetal: "metal", - DLDeviceType.kDLVPI: "vpi", - DLDeviceType.kDLROCM: "rocm", - DLDeviceType.kDLROCMHost: "rocm_host", - DLDeviceType.kDLExtDev: "ext_dev", - DLDeviceType.kDLOneAPI: "oneapi", - DLDeviceType.kDLWebGPU: "webgpu", - DLDeviceType.kDLHexagon: "hexagon", - } - - _DEVICE_NAME_TO_TYPE = { - "llvm": DLDeviceType.kDLCPU, - "cpu": DLDeviceType.kDLCPU, - "c": DLDeviceType.kDLCPU, - "test": DLDeviceType.kDLCPU, - "cuda": DLDeviceType.kDLCUDA, - "nvptx": DLDeviceType.kDLCUDA, - "cl": DLDeviceType.kDLOpenCL, - "opencl": DLDeviceType.kDLOpenCL, - "vulkan": DLDeviceType.kDLVulkan, - "metal": DLDeviceType.kDLMetal, - "vpi": DLDeviceType.kDLVPI, - "rocm": DLDeviceType.kDLROCM, - "ext_dev": DLDeviceType.kDLExtDev, - "hexagon": DLDeviceType.kDLHexagon, - "webgpu": DLDeviceType.kDLWebGPU, - } - - def __init__(self, device_type, index = None): - device_type_or_name = device_type - index = index if index is not None else 0 - if isinstance(device_type_or_name, str): - # skip suffix annotations - device_type_or_name = device_type_or_name.split(" ")[0] - parts = device_type_or_name.split(":") - if len(parts) < 1 or len(parts) > 2: - raise ValueError(f"Invalid device: {device_type_or_name}") - if parts[0] not in self._DEVICE_NAME_TO_TYPE: - raise ValueError(f"Unknown device: {parts[0]}") - device_type = self._DEVICE_NAME_TO_TYPE[parts[0]] - if len(parts) == 2: - try: - index = int(parts[1]) - except ValueError: - raise ValueError(f"Invalid device index: {parts[1]}") - else: - device_type = device_type_or_name - if not isinstance(index, int): - raise TypeError(f"Invalid device index: {index}") - self.cdevice = TVMFFIDLDeviceFromIntPair(device_type, index) - - def __reduce__(self): - cls = type(self) - return (_create_device_from_tuple, (cls, self.cdevice.device_type, self.cdevice.device_id)) - - def __eq__(self, other): - if not isinstance(other, Device): - return False - return ( - self.cdevice.device_type == (other).cdevice.device_type - and self.cdevice.device_id == (other).cdevice.device_id - ) - - def __ne__(self, other): - return not self.__eq__(other) - - def __str__(self): - cdef int dev_type = self.cdevice.device_type - name = self.__device_type_name__() - index = self.cdevice.device_id - return f"{name}:{index}" - - def __repr__(self): - cdef int dev_type = self.cdevice.device_type - name = self.__device_type_name__() - index = self.cdevice.device_id - return f"device(type='{name}', index={index})" - - def __hash__(self): - return hash((self.cdevice.device_type, self.cdevice.device_id)) - - - def __device_type_name__(self): - return self._DEVICE_TYPE_TO_NAME[self.cdevice.device_type] - - @property - def type(self): - """String representation of the device type.""" - return self.__device_type_name__() - - @property - def index(self): - """The device index.""" - return self.cdevice.device_id - - def dlpack_device_type(self): - """The device type int code used in the DLPack specification. - """ - return self.cdevice.device_type - - -cdef inline object make_ret_device(TVMFFIAny result): - ret = _CLASS_DEVICE.__new__(_CLASS_DEVICE) - (ret).cdevice = result.v_device - return ret - - -_set_class_device(Device) diff --git a/ffi/python/tvm_ffi/cython/dtype.pxi b/ffi/python/tvm_ffi/cython/dtype.pxi deleted file mode 100644 index d9e20b77f3a8..000000000000 --- a/ffi/python/tvm_ffi/cython/dtype.pxi +++ /dev/null @@ -1,116 +0,0 @@ - -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -_CLASS_DTYPE = None - -def _set_class_dtype(cls): - global _CLASS_DTYPE - _CLASS_DTYPE = cls - - -def _create_dtype_from_tuple(cls, code, bits, lanes): - cdef DLDataType cdtype - cdtype.code = code - cdtype.bits = bits - cdtype.lanes = lanes - ret = cls.__new__(cls, str(cdtype)) - (ret).cdtype = cdtype - return ret - - -cdef class DataType: - """DataType is a wrapper around DLDataType. - - Parameters - ---------- - dtype_str : str - The string representation of the data type - """ - cdef DLDataType cdtype - - def __init__(self, dtype_str): - cdef ByteArrayArg dtype_str_arg = ByteArrayArg(c_str(dtype_str)) - CHECK_CALL(TVMFFIDataTypeFromString(dtype_str_arg.cptr(), &(self.cdtype))) - - def __reduce__(self): - cls = type(self) - return (_create_dtype_from_tuple, - (cls, self.cdtype.code, self.cdtype.bits, self.cdtype.lanes)) - - def __eq__(self, other): - if not isinstance(other, DataType): - return False - return ( - self.cdtype.code == other.cdtype.code - and self.cdtype.bits == other.cdtype.bits - and self.cdtype.lanes == other.cdtype.lanes - ) - - def __ne__(self, other): - return not self.__eq__(other) - - @property - def type_code(self): - return self.cdtype.code - - @property - def bits(self): - return self.cdtype.bits - - @property - def lanes(self): - return self.cdtype.lanes - - @property - def itemsize(self): - """Get the number of bytes of a single element of this data type. When the number of lanes - is greater than 1, the itemsize is the size of the vector type. - - Returns - ------- - itemsize : int - The number of bytes of a single element of this data type - """ - lanes_as_int = self.cdtype.lanes - if lanes_as_int < 0: - raise ValueError("Cannot determine itemsize for scalable vector types") - return (self.cdtype.bits * self.cdtype.lanes + 7) // 8 - - def __str__(self): - cdef TVMFFIAny temp_any - cdef TVMFFIByteArray* bytes_ptr - cdef TVMFFIByteArray bytes - - CHECK_CALL(TVMFFIDataTypeToString(&(self.cdtype), &temp_any)) - if temp_any.type_index == kTVMFFISmallStr: - bytes = TVMFFISmallBytesGetContentByteArray(&temp_any) - res = py_str(PyBytes_FromStringAndSize(bytes.data, bytes.size)) - return res - - bytes_ptr = TVMFFIBytesGetByteArrayPtr(temp_any.v_obj) - res = py_str(PyBytes_FromStringAndSize(bytes_ptr.data, bytes_ptr.size)) - CHECK_CALL(TVMFFIObjectDecRef(temp_any.v_obj)) - return res - - -cdef inline object make_ret_dtype(TVMFFIAny result): - cdtype = DataType.__new__(DataType) - (cdtype).cdtype = result.v_dtype - val = str.__new__(_CLASS_DTYPE, cdtype.__str__()) - val.__tvm_ffi_dtype__ = cdtype - return val diff --git a/ffi/python/tvm_ffi/cython/error.pxi b/ffi/python/tvm_ffi/cython/error.pxi deleted file mode 100644 index b7771000fd82..000000000000 --- a/ffi/python/tvm_ffi/cython/error.pxi +++ /dev/null @@ -1,134 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# error handling for FFI - -import types -import re - -ERROR_NAME_TO_TYPE = {} -ERROR_TYPE_TO_NAME = {} - -_WITH_APPEND_TRACEBACK = None -_TRACEBACK_TO_STR = None - - -cdef class Error(Object): - """Base class for all FFI errors, usually they are attached to errors - - Note - ---- - Do not directly raise this object, instead use the `py_error` method - to convert it to a python error then raise it. - """ - - def __init__(self, kind, message, traceback): - cdef ByteArrayArg kind_arg = ByteArrayArg(c_str(kind)) - cdef ByteArrayArg message_arg = ByteArrayArg(c_str(message)) - cdef ByteArrayArg traceback_arg = ByteArrayArg(c_str(traceback)) - (self).chandle = TVMFFIErrorCreate( - kind_arg.cptr(), message_arg.cptr(), traceback_arg.cptr() - ) - - def update_traceback(self, traceback): - """Update the traceback of the error - - Parameters - ---------- - traceback : str - The traceback to update. - """ - cdef ByteArrayArg traceback_arg = ByteArrayArg(c_str(traceback)) - TVMFFIErrorGetCellPtr(self.chandle).update_traceback(self.chandle, traceback_arg.cptr()) - - def py_error(self): - """ - Convert the FFI error to the python error - """ - error_cls = ERROR_NAME_TO_TYPE.get(self.kind, RuntimeError) - py_error = error_cls(self.message) - py_error = _WITH_APPEND_TRACEBACK(py_error, self.traceback) - py_error.__tvm_ffi_error__ = self - return py_error - - @property - def kind(self): - return bytearray_to_str(&(TVMFFIErrorGetCellPtr(self.chandle).kind)) - - @property - def message(self): - return bytearray_to_str(&(TVMFFIErrorGetCellPtr(self.chandle).message)) - - @property - def traceback(self): - return bytearray_to_str(&(TVMFFIErrorGetCellPtr(self.chandle).traceback)) - - -_register_object_by_index(kTVMFFIError, Error) - - -cdef inline Error move_from_last_error(): - # raise last error - error = Error.__new__(Error) - TVMFFIErrorMoveFromRaised(&(error).chandle) - return error - - -cdef inline int raise_existing_error() except -2: - return -2 - - -cdef inline int set_last_ffi_error(error) except -1: - """Set the last FFI error""" - cdef Error ffi_error - - kind = ERROR_TYPE_TO_NAME.get(type(error), "RuntimeError") - message = error.__str__() - py_traceback = _TRACEBACK_TO_STR(error.__traceback__) - c_traceback = bytearray_to_str(TVMFFITraceback(NULL, 0, NULL, 0)) - - # error comes from an exception thrown from C++ side - if hasattr(error, "__tvm_ffi_error__"): - # already have stack trace - ffi_error = error.__tvm_ffi_error__ - # attach the python traceback together with the C++ traceback to get full trace - ffi_error.update_traceback(c_traceback + py_traceback) - TVMFFIErrorSetRaised(ffi_error.chandle) - else: - ffi_error = Error(kind, message, c_traceback + py_traceback) - TVMFFIErrorSetRaised(ffi_error.chandle) - - -def _convert_to_ffi_error(error): - """Convert the python error to the FFI error""" - py_traceback = _TRACEBACK_TO_STR(error.__traceback__) - if hasattr(error, "__tvm_ffi_error__"): - error.__tvm_ffi_error__.update_traceback(py_traceback) - return error.__tvm_ffi_error__ - else: - kind = ERROR_TYPE_TO_NAME.get(type(error), "RuntimeError") - message = error.__str__() - return Error(kind, message, py_traceback) - - -cdef inline int CHECK_CALL(int ret) except -2: - """Check the return code of the C API function call""" - if ret == 0: - return 0 - # -2 brings exception - if ret == -2: - raise raise_existing_error() - raise move_from_last_error().py_error() diff --git a/ffi/python/tvm_ffi/cython/function.pxi b/ffi/python/tvm_ffi/cython/function.pxi deleted file mode 100644 index 71c9522ddba4..000000000000 --- a/ffi/python/tvm_ffi/cython/function.pxi +++ /dev/null @@ -1,853 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -import ctypes -import os -from numbers import Real, Integral - - -if os.environ.get("TVM_FFI_BUILD_DOCS", "0") == "0": - try: - # optionally import torch and setup torch related utils - import torch - except ImportError: - torch = None -else: - torch = None - - -cdef int _RELEASE_GIL_BY_DEFAULT = int( - os.environ.get("TVM_FFI_RELEASE_GIL_BY_DEFAULT", "1") -) - -cdef inline object make_ret_small_str(TVMFFIAny result): - """convert small string to return value.""" - cdef TVMFFIByteArray bytes - bytes = TVMFFISmallBytesGetContentByteArray(&result) - return py_str(PyBytes_FromStringAndSize(bytes.data, bytes.size)) - - -cdef inline object make_ret_small_bytes(TVMFFIAny result): - """convert small bytes to return value.""" - cdef TVMFFIByteArray bytes - bytes = TVMFFISmallBytesGetContentByteArray(&result) - return PyBytes_FromStringAndSize(bytes.data, bytes.size) - - -cdef inline object make_ret(TVMFFIAny result, DLPackToPyObject c_dlpack_to_pyobject = NULL): - """convert result to return value.""" - cdef int32_t type_index - type_index = result.type_index - if type_index == kTVMFFITensor: - # specially handle Tensor as it needs a special dltensor field - return make_tensor_from_any(result, c_dlpack_to_pyobject) - elif type_index == kTVMFFIOpaquePyObject: - return make_ret_opaque_object(result) - elif type_index >= kTVMFFIStaticObjectBegin: - return make_ret_object(result) - # the following code should be optimized to switch case - if type_index == kTVMFFINone: - return None - elif type_index == kTVMFFIBool: - return bool(result.v_int64) - elif type_index == kTVMFFIInt: - return result.v_int64 - elif type_index == kTVMFFIFloat: - return result.v_float64 - elif type_index == kTVMFFISmallStr: - return make_ret_small_str(result) - elif type_index == kTVMFFISmallBytes: - return make_ret_small_bytes(result) - elif type_index == kTVMFFIOpaquePtr: - return ctypes_handle(result.v_ptr) - elif type_index == kTVMFFIDataType: - return make_ret_dtype(result) - elif type_index == kTVMFFIDevice: - return make_ret_device(result) - elif type_index == kTVMFFIDLTensorPtr: - return make_ret_dltensor(result) - elif type_index == kTVMFFIObjectRValueRef: - raise ValueError("Return value cannot be ObjectRValueRef") - elif type_index == kTVMFFIByteArrayPtr: - raise ValueError("Return value cannot be ByteArrayPtr") - elif type_index == kTVMFFIRawStr: - raise ValueError("Return value cannot be RawStr") - raise ValueError("Unhandled type index %d" % type_index) - - -##---------------------------------------------------------------------------- -## Helper to simplify calling constructor -##---------------------------------------------------------------------------- -cdef inline int ConstructorCall(void* constructor_handle, - PyObject* py_arg_tuple, - void** handle, - TVMFFIPyCallContext* parent_ctx) except -1: - """Call contructor of a handle function""" - cdef TVMFFIAny result - cdef int c_api_ret_code - # IMPORTANT: caller need to initialize result->type_index to kTVMFFINone - result.type_index = kTVMFFINone - result.v_int64 = 0 - TVMFFIPyConstructorCall( - TVMFFIPyArgSetterFactory_, constructor_handle, py_arg_tuple, &result, &c_api_ret_code, - parent_ctx - ) - CHECK_CALL(c_api_ret_code) - handle[0] = result.v_ptr - return 0 - -##---------------------------------------------------------------------------- -## Implementation of setters using same naming style as TVMFFIPyArgSetterXXX_ -##---------------------------------------------------------------------------- -cdef int TVMFFIPyArgSetterTensor_( - TVMFFIPyArgSetter* handle, TVMFFIPyCallContext* ctx, - PyObject* arg, TVMFFIAny* out -) except -1: - if (arg).chandle != NULL: - out.type_index = kTVMFFITensor - out.v_ptr = (arg).chandle - else: - out.type_index = kTVMFFIDLTensorPtr - out.v_ptr = (arg).cdltensor - return 0 - - -cdef int TVMFFIPyArgSetterObject_( - TVMFFIPyArgSetter* handle, TVMFFIPyCallContext* ctx, - PyObject* arg, TVMFFIAny* out -) except -1: - out.type_index = TVMFFIObjectGetTypeIndex((arg).chandle) - out.v_ptr = (arg).chandle - return 0 - - -cdef int TVMFFIPyArgSetterDLPackCExporter_( - TVMFFIPyArgSetter* this, TVMFFIPyCallContext* ctx, - PyObject* arg, TVMFFIAny* out -) except -1: - cdef DLManagedTensorVersioned* temp_managed_tensor - cdef TVMFFIObjectHandle temp_chandle - cdef TVMFFIStreamHandle env_stream = NULL - - if this.c_dlpack_to_pyobject != NULL: - ctx.c_dlpack_to_pyobject = this.c_dlpack_to_pyobject - if this.c_dlpack_tensor_allocator != NULL: - ctx.c_dlpack_tensor_allocator = this.c_dlpack_tensor_allocator - - if ctx.device_id != -1: - # already queried device, do not do it again, pass NULL to stream - if (this.c_dlpack_from_pyobject)(arg, &temp_managed_tensor, NULL) != 0: - return -1 - else: - # query string on the envrionment stream - if (this.c_dlpack_from_pyobject)(arg, &temp_managed_tensor, &env_stream) != 0: - return -1 - # If device is not CPU, we should set the device type and id - if temp_managed_tensor.dl_tensor.device.device_type != kDLCPU: - ctx.stream = env_stream - ctx.device_type = temp_managed_tensor.dl_tensor.device.device_type - ctx.device_id = temp_managed_tensor.dl_tensor.device.device_id - # run conversion - if TVMFFITensorFromDLPackVersioned(temp_managed_tensor, 0, 0, &temp_chandle) != 0: - raise BufferError("Failed to convert DLManagedTensorVersioned to ffi.Tensor") - out.type_index = kTVMFFITensor - out.v_ptr = temp_chandle - TVMFFIPyPushTempFFIObject(ctx, temp_chandle) - return 0 - - -cdef int TorchDLPackToPyObjectFallback_( - DLManagedTensorVersioned* dltensor, void** py_obj_out -) except -1: - # a bit convoluted but ok as a fallback - cdef TVMFFIObjectHandle temp_chandle - TVMFFITensorFromDLPackVersioned(dltensor, 0, 0, &temp_chandle) - tensor = make_tensor_from_chandle(temp_chandle) - torch_tensor = torch.from_dlpack(tensor) - Py_INCREF(torch_tensor) - py_obj_out[0] = (torch_tensor) - return 0 - - -cdef int TVMFFIPyArgSetterTorchFallback_( - TVMFFIPyArgSetter* handle, TVMFFIPyCallContext* ctx, - PyObject* py_arg, TVMFFIAny* out -) except -1: - """Current setter for torch.Tensor, go through python and not as fast as c exporter""" - # TODO(tqchen): remove this once torch always support fast DLPack importer - cdef object arg = py_arg - is_cuda = arg.is_cuda - arg = from_dlpack(torch.utils.dlpack.to_dlpack(arg)) - out.type_index = kTVMFFITensor - out.v_ptr = (arg).chandle - temp_dltensor = TVMFFITensorGetDLTensorPtr((arg).chandle) - ctx.c_dlpack_to_pyobject = TorchDLPackToPyObjectFallback_ - # record the stream and device for torch context - if is_cuda and ctx.device_type != -1: - ctx.device_type = temp_dltensor.device.device_type - ctx.device_id = temp_dltensor.device.device_id - # This is an API that dynamo and other uses to get the raw stream from torch - temp_ptr = torch._C._cuda_getCurrentRawStream(temp_dltensor.device.device_id) - ctx.stream = temp_ptr - # push to temp and clear the handle - TVMFFIPyPushTempPyObject(ctx, arg) - return 0 - - -cdef int TVMFFIPyArgSetterDLPack_( - TVMFFIPyArgSetter* handle, TVMFFIPyCallContext* ctx, - PyObject* py_arg, TVMFFIAny* out -) except -1: - """Setter for __dlpack__ mechanism through python, not as fast as c exporter""" - cdef TVMFFIObjectHandle temp_chandle - cdef object arg = py_arg - _from_dlpack_universal(arg, 0, 0, &temp_chandle) - out.type_index = kTVMFFITensor - out.v_ptr = temp_chandle - # record the stream from the source framework context when possible - temp_dltensor = TVMFFITensorGetDLTensorPtr(temp_chandle) - if (temp_dltensor.device.device_type != kDLCPU and - ctx.device_type != -1): - # __tvm_ffi_env_stream__ returns the expected stream that should be set - # through TVMFFIEnvSetStream when calling a TVM FFI function - if hasattr(arg, "__tvm_ffi_env_stream__"): - # Ideally projects should directly setup their stream context API - # write through by also calling TVMFFIEnvSetStream - # so we do not need this protocol to do exchange - ctx.device_type = temp_dltensor.device.device_type - ctx.device_id = temp_dltensor.device.device_id - temp_ptr= arg.__tvm_ffi_env_stream__() - ctx.stream = temp_ptr - TVMFFIPyPushTempFFIObject(ctx, temp_chandle) - return 0 - - -cdef int TVMFFIPyArgSetterDType_( - TVMFFIPyArgSetter* handle, TVMFFIPyCallContext* ctx, - PyObject* py_arg, TVMFFIAny* out -) except -1: - """Setter for dtype""" - cdef object arg = py_arg - # dtype is a subclass of str, so this check occur before str - arg = arg.__tvm_ffi_dtype__ - out.type_index = kTVMFFIDataType - out.v_dtype = (arg).cdtype - return 0 - - -cdef int TVMFFIPyArgSetterDevice_( - TVMFFIPyArgSetter* handle, TVMFFIPyCallContext* ctx, - PyObject* py_arg, TVMFFIAny* out -) except -1: - """Setter for device""" - cdef object arg = py_arg - out.type_index = kTVMFFIDevice - out.v_device = (arg).cdevice - return 0 - - -cdef int TVMFFIPyArgSetterStr_( - TVMFFIPyArgSetter* handle, TVMFFIPyCallContext* ctx, - PyObject* py_arg, TVMFFIAny* out -) except -1: - """Setter for str""" - cdef object arg = py_arg - cdef bytes tstr = arg.encode("utf-8") - cdef char* data - cdef Py_ssize_t size - cdef TVMFFIByteArray cdata - - PyBytes_AsStringAndSize(tstr, &data, &size) - cdata.data = data - cdata.size = size - CHECK_CALL(TVMFFIStringFromByteArray(&cdata, out)) - if out.type_index >= kTVMFFIStaticObjectBegin: - TVMFFIPyPushTempFFIObject(ctx, out.v_ptr) - return 0 - - -cdef int TVMFFIPyArgSetterPyNativeObjectStr_( - TVMFFIPyArgSetter* handle, TVMFFIPyCallContext* ctx, - PyObject* py_arg, TVMFFIAny* out -) except -1: - """Specially handle String as its __tvm_ffi_object__ may be empty""" - cdef object arg = py_arg - # need to check if the arg is a large string returned from ffi - if arg.__tvm_ffi_object__ is not None: - arg = arg.__tvm_ffi_object__ - out.type_index = TVMFFIObjectGetTypeIndex((arg).chandle) - out.v_ptr = (arg).chandle - return 0 - return TVMFFIPyArgSetterStr_(handle, ctx, py_arg, out) - - -cdef int TVMFFIPyArgSetterBytes_( - TVMFFIPyArgSetter* handle, TVMFFIPyCallContext* ctx, - PyObject* py_arg, TVMFFIAny* out -) except -1: - """Setter for bytes""" - cdef object arg = py_arg - - if isinstance(arg, bytearray): - arg = bytes(arg) - - cdef char* data - cdef Py_ssize_t size - cdef TVMFFIByteArray cdata - - PyBytes_AsStringAndSize(arg, &data, &size) - cdata.data = data - cdata.size = size - CHECK_CALL(TVMFFIBytesFromByteArray(&cdata, out)) - - if out.type_index >= kTVMFFIStaticObjectBegin: - TVMFFIPyPushTempFFIObject(ctx, out.v_ptr) - return 0 - - -cdef int TVMFFIPyArgSetterPyNativeObjectBytes_( - TVMFFIPyArgSetter* handle, TVMFFIPyCallContext* ctx, - PyObject* py_arg, TVMFFIAny* out -) except -1: - """Specially handle Bytes as its __tvm_ffi_object__ may be empty""" - cdef object arg = py_arg - # need to check if the arg is a large bytes returned from ffi - if arg.__tvm_ffi_object__ is not None: - arg = arg.__tvm_ffi_object__ - out.type_index = TVMFFIObjectGetTypeIndex((arg).chandle) - out.v_ptr = (arg).chandle - return 0 - return TVMFFIPyArgSetterBytes_(handle, ctx, py_arg, out) - - -cdef int TVMFFIPyArgSetterPyNativeObjectGeneral_( - TVMFFIPyArgSetter* handle, TVMFFIPyCallContext* ctx, - PyObject* py_arg, TVMFFIAny* out -) except -1: - """Specially handle Bytes as its __tvm_ffi_object__ may be empty""" - cdef object arg = py_arg - if arg.__tvm_ffi_object__ is None: - raise ValueError(f"__tvm_ffi_object__ is None for {type(arg)}") - assert arg.__tvm_ffi_object__ is not None - arg = arg.__tvm_ffi_object__ - out.type_index = TVMFFIObjectGetTypeIndex((arg).chandle) - out.v_ptr = (arg).chandle - return 0 - - -cdef int TVMFFIPyArgSetterCtypesVoidPtr_( - TVMFFIPyArgSetter* handle, TVMFFIPyCallContext* ctx, - PyObject* py_arg, TVMFFIAny* out -) except -1: - """Setter for ctypes.c_void_p""" - out.type_index = kTVMFFIOpaquePtr - out.v_ptr = c_handle(py_arg) - return 0 - - -cdef int TVMFFIPyArgSetterObjectRValueRef_( - TVMFFIPyArgSetter* handle, TVMFFIPyCallContext* ctx, - PyObject* py_arg, TVMFFIAny* out -) except -1: - """Setter for ObjectRValueRef""" - cdef object arg = py_arg - out.type_index = kTVMFFIObjectRValueRef - out.v_ptr = &(((arg.obj)).chandle) - return 0 - - -cdef int TVMFFIPyArgSetterCallable_( - TVMFFIPyArgSetter* handle, TVMFFIPyCallContext* ctx, - PyObject* py_arg, TVMFFIAny* out -) except -1: - """Setter for Callable""" - cdef object arg = py_arg - cdef TVMFFIObjectHandle chandle - _convert_to_ffi_func_handle(arg, &chandle) - out.type_index = TVMFFIObjectGetTypeIndex(chandle) - out.v_ptr = chandle - TVMFFIPyPushTempFFIObject(ctx, chandle) - return 0 - - -cdef int TVMFFIPyArgSetterException_( - TVMFFIPyArgSetter* handle, TVMFFIPyCallContext* ctx, - PyObject* py_arg, TVMFFIAny* out -) except -1: - """Setter for Exception""" - cdef object arg = py_arg - arg = _convert_to_ffi_error(arg) - out.type_index = TVMFFIObjectGetTypeIndex((arg).chandle) - out.v_ptr = (arg).chandle - TVMFFIPyPushTempPyObject(ctx, arg) - return 0 - - -cdef int TVMFFIPyArgSetterTuple_( - TVMFFIPyArgSetter* handle, TVMFFIPyCallContext* ctx, - PyObject* py_arg, TVMFFIAny* out -) except -1: - """Setter for Tuple""" - # recursively construct a new tuple - cdef TVMFFIObjectHandle chandle - ConstructorCall(_CONSTRUCTOR_ARRAY.chandle, py_arg, &chandle, ctx) - out.type_index = TVMFFIObjectGetTypeIndex(chandle) - out.v_ptr = chandle - TVMFFIPyPushTempFFIObject(ctx, chandle) - return 0 - - -cdef int TVMFFIPyArgSetterTupleLike_( - TVMFFIPyArgSetter* handle, TVMFFIPyCallContext* ctx, - PyObject* py_arg, TVMFFIAny* out -) except -1: - """Setter for TupleLike""" - # recursively construct a new tuple - cdef tuple tuple_arg = tuple(py_arg) - cdef TVMFFIObjectHandle chandle - ConstructorCall(_CONSTRUCTOR_ARRAY.chandle, tuple_arg, &chandle, ctx) - out.type_index = TVMFFIObjectGetTypeIndex(chandle) - out.v_ptr = chandle - TVMFFIPyPushTempFFIObject(ctx, chandle) - return 0 - - -cdef int TVMFFIPyArgSetterMap_( - TVMFFIPyArgSetter* handle, TVMFFIPyCallContext* ctx, - PyObject* py_arg, TVMFFIAny* out -) except -1: - """Setter for Map""" - # recursively construct a new map - cdef dict dict_arg = py_arg - cdef list list_kvs = [] - for k, v in dict_arg.items(): - list_kvs.append(k) - list_kvs.append(v) - cdef tuple_arg_kvs = tuple(list_kvs) - cdef TVMFFIObjectHandle chandle - ConstructorCall(_CONSTRUCTOR_MAP.chandle, tuple_arg_kvs, &chandle, ctx) - out.type_index = TVMFFIObjectGetTypeIndex(chandle) - out.v_ptr = chandle - TVMFFIPyPushTempFFIObject(ctx, chandle) - return 0 - - -cdef int TVMFFIPyArgSetterObjectConvertible_( - TVMFFIPyArgSetter* handle, TVMFFIPyCallContext* ctx, - PyObject* py_arg, TVMFFIAny* out -) except -1: - """Setter for ObjectConvertible""" - # recursively construct a new map - cdef object arg = py_arg - arg = arg.asobject() - out.type_index = TVMFFIObjectGetTypeIndex((arg).chandle) - out.v_ptr = (arg).chandle - TVMFFIPyPushTempPyObject(ctx, arg) - - -cdef int TVMFFIPyArgSetterFallback_( - TVMFFIPyArgSetter* handle, TVMFFIPyCallContext* ctx, - PyObject* py_arg, TVMFFIAny* out -) except -1: - """Fallback setter for all other types""" - cdef object arg = py_arg - cdef TVMFFIObjectHandle chandle - _convert_to_opaque_object_handle(arg, &chandle) - out.type_index = kTVMFFIOpaquePyObject - out.v_ptr = chandle - TVMFFIPyPushTempFFIObject(ctx, chandle) - - -cdef int TVMFFIPyArgSetterFactory_(PyObject* value, TVMFFIPyArgSetter* out) except -1: - """ - Factory function that creates an argument setter for a given Python argument type. - """ - # NOTE: the order of checks matter here - # becase each argument may satisfy multiple checks - # priortize native types over external types - cdef object arg = value - cdef long long temp_ptr - if arg is None: - out.func = TVMFFIPyArgSetterNone_ - return 0 - if isinstance(arg, Tensor): - out.func = TVMFFIPyArgSetterTensor_ - return 0 - if isinstance(arg, Object): - out.func = TVMFFIPyArgSetterObject_ - return 0 - if isinstance(arg, ObjectRValueRef): - out.func = TVMFFIPyArgSetterObjectRValueRef_ - return 0 - if os.environ.get("TVM_FFI_SKIP_c_dlpack_from_pyobject", "0") != "1": - # external tensors - if hasattr(arg, "__c_dlpack_from_pyobject__"): - out.func = TVMFFIPyArgSetterDLPackCExporter_ - temp_ptr = arg.__c_dlpack_from_pyobject__ - out.c_dlpack_from_pyobject = temp_ptr - if hasattr(arg, "__c_dlpack_to_pyobject__"): - temp_ptr = arg.__c_dlpack_to_pyobject__ - out.c_dlpack_to_pyobject = temp_ptr - if hasattr(arg, "__c_dlpack_tensor_allocator__"): - temp_ptr = arg.__c_dlpack_tensor_allocator__ - out.c_dlpack_tensor_allocator = temp_ptr - return 0 - if torch is not None and isinstance(arg, torch.Tensor): - out.func = TVMFFIPyArgSetterTorchFallback_ - return 0 - if hasattr(arg, "__dlpack__"): - out.func = TVMFFIPyArgSetterDLPack_ - return 0 - if isinstance(arg, bool): - # A python `bool` is a subclass of `int`, so this check - # must occur before `Integral`. - out.func = TVMFFIPyArgSetterBool_ - return 0 - if isinstance(arg, Integral): - out.func = TVMFFIPyArgSetterInt_ - return 0 - if isinstance(arg, Real): - out.func = TVMFFIPyArgSetterFloat_ - return 0 - # dtype is a subclass of str, so this check must occur before str - if isinstance(arg, _CLASS_DTYPE): - out.func = TVMFFIPyArgSetterDType_ - return 0 - if isinstance(arg, _CLASS_DEVICE): - out.func = TVMFFIPyArgSetterDevice_ - return 0 - if isinstance(arg, PyNativeObject): - # check for PyNativeObject - # this check must happen before str/bytes/tuple - if isinstance(arg, str): - out.func = TVMFFIPyArgSetterPyNativeObjectStr_ - return 0 - if isinstance(arg, bytes): - out.func = TVMFFIPyArgSetterPyNativeObjectBytes_ - return 0 - out.func = TVMFFIPyArgSetterPyNativeObjectGeneral_ - return 0 - if isinstance(arg, str): - out.func = TVMFFIPyArgSetterStr_ - return 0 - if isinstance(arg, (bytes, bytearray)): - out.func = TVMFFIPyArgSetterBytes_ - return 0 - if isinstance(arg, tuple): - out.func = TVMFFIPyArgSetterTuple_ - return 0 - if isinstance(arg, list): - out.func = TVMFFIPyArgSetterTupleLike_ - return 0 - if isinstance(arg, dict): - out.func = TVMFFIPyArgSetterMap_ - return 0 - if isinstance(arg, ctypes.c_void_p): - out.func = TVMFFIPyArgSetterCtypesVoidPtr_ - return 0 - if callable(arg): - out.func = TVMFFIPyArgSetterCallable_ - return 0 - if isinstance(arg, Exception): - out.func = TVMFFIPyArgSetterException_ - return 0 - if isinstance(arg, ObjectConvertible): - out.func = TVMFFIPyArgSetterObjectConvertible_ - return 0 - # default to opaque object - out.func = TVMFFIPyArgSetterFallback_ - return 0 - -#--------------------------------------------------------------------------------------------- -## Implementation of function calling -#--------------------------------------------------------------------------------------------- -cdef class Function(Object): - """Python class that wraps a function with tvm-ffi ABI. - - See Also - -------- - tvm_ffi.register_global_func: How to register global function. - tvm_ffi.get_global_func: How to get global function. - """ - cdef int c_release_gil - cdef dict __dict__ - - def __cinit__(self): - self.c_release_gil = _RELEASE_GIL_BY_DEFAULT - - property release_gil: - def __get__(self): - return self.c_release_gil != 0 - def __set__(self, value): - self.c_release_gil = value - - def __call__(self, *args): - cdef TVMFFIAny result - cdef int c_api_ret_code - cdef DLPackToPyObject c_dlpack_to_pyobject = NULL - # IMPORTANT: caller need to initialize result->type_index to kTVMFFINone - result.type_index = kTVMFFINone - result.v_int64 = 0 - TVMFFIPyFuncCall( - TVMFFIPyArgSetterFactory_, - (self).chandle, args, - &result, - &c_api_ret_code, - self.release_gil, - &c_dlpack_to_pyobject - ) - # NOTE: logic is same as check_call - # directly inline here to simplify traceback - if c_api_ret_code == 0: - return make_ret(result, c_dlpack_to_pyobject) - elif c_api_ret_code == -2: - raise_existing_error() - raise move_from_last_error().py_error() - -_register_object_by_index(kTVMFFIFunction, Function) - - -cdef class FieldGetter: - cdef TVMFFIFieldGetter getter - cdef int64_t offset - - def __call__(self, Object obj): - cdef TVMFFIAny result - cdef int c_api_ret_code - cdef void* field_ptr = ((obj).chandle) + self.offset - result.type_index = kTVMFFINone - result.v_int64 = 0 - c_api_ret_code = self.getter(field_ptr, &result) - CHECK_CALL(c_api_ret_code) - return make_ret(result) - - -cdef class FieldSetter: - cdef TVMFFIFieldSetter setter - cdef int64_t offset - - def __call__(self, Object obj, value): - cdef int c_api_ret_code - cdef void* field_ptr = ((obj).chandle) + self.offset - TVMFFIPyCallFieldSetter( - TVMFFIPyArgSetterFactory_, - self.setter, - field_ptr, - value, - &c_api_ret_code - ) - # NOTE: logic is same as check_call - # directly inline here to simplify traceback - if c_api_ret_code == 0: - return - elif c_api_ret_code == -2: - raise_existing_error() - raise move_from_last_error().py_error() - - -cdef _get_method_from_method_info(const TVMFFIMethodInfo* method): - cdef TVMFFIAny result - CHECK_CALL(TVMFFIAnyViewToOwnedAny(&(method.method), &result)) - return make_ret(result) - - -def _member_method_wrapper(method_func): - def wrapper(self, *args): - return method_func(self, *args) - return wrapper - - -def _add_class_attrs_by_reflection(int type_index, object cls): - """Decorate the class attrs by reflection""" - cdef const TVMFFITypeInfo* info = TVMFFIGetTypeInfo(type_index) - cdef const TVMFFIFieldInfo* field - cdef const TVMFFIMethodInfo* method - cdef int num_fields = info.num_fields - cdef int num_methods = info.num_methods - - for i in range(num_fields): - # attach fields to the class - field = &(info.fields[i]) - getter = FieldGetter.__new__(FieldGetter) - (getter).getter = field.getter - (getter).offset = field.offset - setter = FieldSetter.__new__(FieldSetter) - (setter).setter = field.setter - (setter).offset = field.offset - if (field.flags & kTVMFFIFieldFlagBitMaskWritable) == 0: - setter = None - doc = ( - py_str(PyBytes_FromStringAndSize(field.doc.data, field.doc.size)) - if field.doc.size != 0 - else None - ) - name = py_str(PyBytes_FromStringAndSize(field.name.data, field.name.size)) - if hasattr(cls, name): - # skip already defined attributes - continue - setattr(cls, name, property(getter, setter, doc=doc)) - - for i in range(num_methods): - # attach methods to the class - method = &(info.methods[i]) - name = py_str(PyBytes_FromStringAndSize(method.name.data, method.name.size)) - doc = ( - py_str(PyBytes_FromStringAndSize(method.doc.data, method.doc.size)) - if method.doc.size != 0 - else None - ) - method_func = _get_method_from_method_info(method) - - if method.flags & kTVMFFIFieldFlagBitMaskIsStaticMethod: - method_pyfunc = staticmethod(method_func) - else: - # must call into another method instead of direct capture - # to avoid the same method_func variable being used - # across multiple loop iterations - method_pyfunc = _member_method_wrapper(method_func) - - if doc is not None: - method_pyfunc.__doc__ = doc - method_pyfunc.__name__ = name - - if hasattr(cls, name): - # skip already defined attributes - continue - setattr(cls, name, method_pyfunc) - - return cls - - -def _register_global_func(name, pyfunc, override): - cdef TVMFFIObjectHandle chandle - cdef int c_api_ret_code - cdef int ioverride = override - cdef ByteArrayArg name_arg = ByteArrayArg(c_str(name)) - - if not isinstance(pyfunc, Function): - pyfunc = _convert_to_ffi_func(pyfunc) - - CHECK_CALL(TVMFFIFunctionSetGlobal(name_arg.cptr(), (pyfunc).chandle, ioverride)) - return pyfunc - - -def _get_global_func(name, allow_missing): - cdef TVMFFIObjectHandle chandle - cdef ByteArrayArg name_arg = ByteArrayArg(c_str(name)) - - CHECK_CALL(TVMFFIFunctionGetGlobal(name_arg.cptr(), &chandle)) - if chandle != NULL: - ret = Function.__new__(Function) - (ret).chandle = chandle - return ret - - if allow_missing: - return None - - raise ValueError("Cannot find global function %s" % name) - - -# handle callbacks -cdef void tvm_ffi_pyobject_deleter(void* fhandle) noexcept with gil: - local_pyobject = (fhandle) - Py_DECREF(local_pyobject) - - -cdef int tvm_ffi_callback(void* context, - const TVMFFIAny* packed_args, - int32_t num_args, - TVMFFIAny* result) noexcept with gil: - cdef list pyargs - cdef TVMFFIAny temp_result - cdef int c_api_ret_code - local_pyfunc = (context) - pyargs = [] - for i in range(num_args): - CHECK_CALL(TVMFFIAnyViewToOwnedAny(&packed_args[i], &temp_result)) - pyargs.append(make_ret(temp_result)) - - try: - rv = local_pyfunc(*pyargs) - TVMFFIPyPyObjectToFFIAny( - TVMFFIPyArgSetterFactory_, - rv, - result, - &c_api_ret_code - ) - if c_api_ret_code == 0: - return 0 - elif c_api_ret_code == -2: - raise_existing_error() - return -1 - except Exception as err: - set_last_ffi_error(err) - return -1 - - -cdef inline int _convert_to_ffi_func_handle( - object pyfunc, TVMFFIObjectHandle* out_handle -) except -1: - """Convert a python function to TVM FFI function handle""" - Py_INCREF(pyfunc) - CHECK_CALL(TVMFFIFunctionCreate( - (pyfunc), - tvm_ffi_callback, - tvm_ffi_pyobject_deleter, - out_handle)) - return 0 - - -def _convert_to_ffi_func(object pyfunc): - """Convert a python function to TVM FFI function""" - cdef TVMFFIObjectHandle chandle - _convert_to_ffi_func_handle(pyfunc, &chandle) - ret = Function.__new__(Function) - (ret).chandle = chandle - return ret - - -cdef inline int _convert_to_opaque_object_handle( - object pyobject, TVMFFIObjectHandle* out_handle -) except -1: - """Convert a python object to TVM FFI opaque object handle""" - Py_INCREF(pyobject) - CHECK_CALL(TVMFFIObjectCreateOpaque( - (pyobject), - kTVMFFIOpaquePyObject, - tvm_ffi_pyobject_deleter, - out_handle)) - return 0 - - -def _convert_to_opaque_object(object pyobject): - """Convert a python object to TVM FFI opaque object""" - cdef TVMFFIObjectHandle chandle - _convert_to_opaque_object_handle(pyobject, &chandle) - ret = OpaquePyObject.__new__(OpaquePyObject) - (ret).chandle = chandle - return ret - - -def _print_debug_info(): - """Get the size of the dispatch map""" - cdef size_t size = TVMFFIPyGetDispatchMapSize() - print(f"TVMFFIPyGetDispatchMapSize: {size}") - - -cdef Function _OBJECT_FROM_JSON_GRAPH_STR = _get_global_func("ffi.FromJSONGraphString", True) -cdef Function _OBJECT_TO_JSON_GRAPH_STR = _get_global_func("ffi.ToJSONGraphString", True) -cdef Function _CONSTRUCTOR_ARRAY = _get_global_func("ffi.Array", True) -cdef Function _CONSTRUCTOR_MAP = _get_global_func("ffi.Map", True) diff --git a/ffi/python/tvm_ffi/cython/object.pxi b/ffi/python/tvm_ffi/cython/object.pxi deleted file mode 100644 index 1d026b250fb7..000000000000 --- a/ffi/python/tvm_ffi/cython/object.pxi +++ /dev/null @@ -1,295 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -import warnings - -_CLASS_OBJECT = None - - -def _set_class_object(cls): - global _CLASS_OBJECT - _CLASS_OBJECT = cls - - -def __object_repr__(obj): - """Object repr function that can be overridden by assigning to it""" - return type(obj).__name__ + "(" + str(obj.__ctypes_handle__().value) + ")" - - -def _new_object(cls): - """Helper function for pickle""" - return cls.__new__(cls) - - -class ObjectConvertible: - """Base class for all classes that can be converted to object.""" - - def asobject(self): - """Convert value to object""" - raise NotImplementedError() - - -class ObjectRValueRef: - """Represent an RValue ref to an object that can be moved. - - Parameters - ---------- - obj : tvm.runtime.Object - The object that this value refers to - """ - - __slots__ = ["obj"] - - def __init__(self, obj): - self.obj = obj - - -cdef class Object: - """Base class of all TVM FFI objects. - """ - cdef void* chandle - - def __cinit__(self): - # initialize chandle to NULL to avoid leak in - # case of error before chandle is set - self.chandle = NULL - - def __dealloc__(self): - if self.chandle != NULL: - CHECK_CALL(TVMFFIObjectDecRef(self.chandle)) - self.chandle = NULL - - def __ctypes_handle__(self): - return ctypes_handle(self.chandle) - - def __chandle__(self): - cdef uint64_t chandle = self.chandle - return chandle - - def __reduce__(self): - cls = type(self) - return (_new_object, (cls,), self.__getstate__()) - - def __getstate__(self): - if _OBJECT_TO_JSON_GRAPH_STR is None: - raise RuntimeError("ffi.ToJSONGraphString is not registered, make sure build project with extra API") - if not self.__chandle__() == 0: - # need to explicit convert to str in case String - # returned and triggered another infinite recursion in get state - return {"handle": str(_OBJECT_TO_JSON_GRAPH_STR(self, None))} - return {"handle": None} - - def __setstate__(self, state): - # pylint: disable=assigning-non-slot, assignment-from-no-return - if _OBJECT_FROM_JSON_GRAPH_STR is None: - raise RuntimeError("ffi.FromJSONGraphString is not registered, make sure build project with extra API") - handle = state["handle"] - if handle is not None: - self.__init_handle_by_constructor__(_OBJECT_FROM_JSON_GRAPH_STR, handle) - else: - self.chandle = NULL - - def __repr__(self): - # exception safety handling for chandle=None - if self.chandle == NULL: - return type(self).__name__ + "(chandle=None)" - return str(__object_repr__(self)) - - def __eq__(self, other): - return self.same_as(other) - - def __ne__(self, other): - return not self.__eq__(other) - - def __init_handle_by_constructor__(self, fconstructor, *args): - """Initialize the handle by calling constructor function. - - Parameters - ---------- - fconstructor : Function - Constructor function. - - args: list of objects - The arguments to the constructor - - Note - ---- - We have a special calling convention to call constructor functions. - So the return handle is directly set into the Node object - instead of creating a new Node. - """ - # avoid error raised during construction. - self.chandle = NULL - cdef void* chandle - ConstructorCall( - (fconstructor).chandle, args, &chandle, NULL) - self.chandle = chandle - - def same_as(self, other): - """Check object identity. - - Parameters - ---------- - other : object - The other object to compare against. - - Returns - ------- - result : bool - The comparison result. - """ - if not isinstance(other, Object): - return False - return self.chandle == (other).chandle - - def __hash__(self): - cdef uint64_t hash_value = self.chandle - return hash_value - - def _move(self): - """Create an RValue reference to the object and mark the object as moved. - - This is a advanced developer API that can be useful when passing an - unique reference to an Object that you no longer needed to a function. - - A unique reference can trigger copy on write optimization that avoids - copy when we transform an object. - - Note - ---- - All the reference of the object becomes invalid after it is moved. - Be very careful when using this feature. - - Returns - ------- - rvalue : The rvalue reference. - """ - return ObjectRValueRef(self) - - def __move_handle_from__(self, other): - """Move the handle from other to self""" - self.chandle = (other).chandle - (other).chandle = NULL - - -cdef class OpaquePyObject(Object): - """Opaque PyObject container - - This is a helper class to store opaque python objects - that will be passed to the ffi functions. - - Users do not need to directly create this class. - """ - def pyobject(self): - """Get the underlying python object""" - cdef object obj - cdef PyObject* py_handle - py_handle = (TVMFFIOpaqueObjectGetCellPtr(self.chandle).handle) - obj = py_handle - return obj - - -class PyNativeObject: - """Base class of all TVM objects that also subclass python's builtin types.""" - __slots__ = [] - - def __init_tvm_ffi_object_by_constructor__(self, fconstructor, *args): - """Initialize the internal tvm_ffi_object by calling constructor function. - - Parameters - ---------- - fconstructor : Function - Constructor function. - - args: list of objects - The arguments to the constructor - - Note - ---- - We have a special calling convention to call constructor functions. - So the return object is directly set into the object - """ - obj = _CLASS_OBJECT.__new__(_CLASS_OBJECT) - obj.__init_handle_by_constructor__(fconstructor, *args) - self.__tvm_ffi_object__ = obj - - -"""Maps object type index to its constructor""" -cdef list OBJECT_TYPE = [] -"""Maps object type to its type index""" -cdef dict OBJECT_INDEX = {} - - -def _register_object_by_index(int index, object cls): - """register object class""" - global OBJECT_TYPE - while len(OBJECT_TYPE) <= index: - OBJECT_TYPE.append(None) - OBJECT_TYPE[index] = cls - OBJECT_INDEX[cls] = index - - -def _object_type_key_to_index(str type_key): - """get the type index of object class""" - cdef int32_t tidx - type_key_arg = ByteArrayArg(c_str(type_key)) - if TVMFFITypeKeyToIndex(type_key_arg.cptr(), &tidx) == 0: - return tidx - return None - -cdef inline str _type_index_to_key(int32_t tindex): - """get the type key of object class""" - cdef const TVMFFITypeInfo* info = TVMFFIGetTypeInfo(tindex) - cdef const TVMFFIByteArray* type_key - if info == NULL: - return "" - type_key = &(info.type_key) - return py_str(PyBytes_FromStringAndSize(type_key.data, type_key.size)) - - -cdef inline object make_ret_opaque_object(TVMFFIAny result): - obj = OpaquePyObject.__new__(OpaquePyObject) - (obj).chandle = result.v_obj - return obj.pyobject() - - -cdef inline object make_ret_object(TVMFFIAny result): - global OBJECT_TYPE - cdef int32_t tindex - cdef object cls - tindex = result.type_index - - if tindex < len(OBJECT_TYPE): - cls = OBJECT_TYPE[tindex] - if cls is not None: - if issubclass(cls, PyNativeObject): - obj = Object.__new__(Object) - (obj).chandle = result.v_obj - return cls.__from_tvm_ffi_object__(cls, obj) - obj = cls.__new__(cls) - (obj).chandle = result.v_obj - return obj - - # object is not found in registered entry - # in this case we need to report an warning - type_key = _type_index_to_key(tindex) - warnings.warn(f"Returning type `{type_key}` which is not registered via register_object, fallback to Object") - obj = _CLASS_OBJECT.__new__(_CLASS_OBJECT) - (obj).chandle = result.v_obj - return obj - - -_set_class_object(Object) diff --git a/ffi/python/tvm_ffi/cython/string.pxi b/ffi/python/tvm_ffi/cython/string.pxi deleted file mode 100644 index 0737259f22e2..000000000000 --- a/ffi/python/tvm_ffi/cython/string.pxi +++ /dev/null @@ -1,80 +0,0 @@ - -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -# helper class for string/bytes handling - -cdef inline str _string_obj_get_py_str(obj): - cdef TVMFFIByteArray* bytes = TVMFFIBytesGetByteArrayPtr((obj).chandle) - return py_str(PyBytes_FromStringAndSize(bytes.data, bytes.size)) - - -cdef inline bytes _bytes_obj_get_py_bytes(obj): - cdef TVMFFIByteArray* bytes = TVMFFIBytesGetByteArrayPtr((obj).chandle) - return PyBytes_FromStringAndSize(bytes.data, bytes.size) - - - -class String(str, PyNativeObject): - __slots__ = ["__tvm_ffi_object__"] - """String object that is possibly returned by FFI call. - - Note - ---- - This class subclasses str so it can be directly treated as str. - There is no need to construct this object explicitly. - """ - def __new__(cls, value): - val = str.__new__(cls, value) - val.__tvm_ffi_object__ = None - return val - - # pylint: disable=no-self-argument - def __from_tvm_ffi_object__(cls, obj): - """Construct from a given tvm object.""" - content = _string_obj_get_py_str(obj) - val = str.__new__(cls, content) - val.__tvm_ffi_object__ = obj - return val - - -_register_object_by_index(kTVMFFIStr, String) - - -class Bytes(bytes, PyNativeObject): - """Bytes object that is possibly returned by FFI call. - - Note - ---- - This class subclasses bytes so it can be directly treated as bytes. - There is no need to construct this object explicitly. - """ - def __new__(cls, value): - val = bytes.__new__(cls, value) - val.__tvm_ffi_object__ = None - return val - - # pylint: disable=no-self-argument - def __from_tvm_ffi_object__(cls, obj): - """Construct from a given tvm object.""" - content = _bytes_obj_get_py_bytes(obj) - val = bytes.__new__(cls, content) - val.__tvm_ffi_object__ = obj - return val - - -_register_object_by_index(kTVMFFIBytes, Bytes) diff --git a/ffi/python/tvm_ffi/cython/tensor.pxi b/ffi/python/tvm_ffi/cython/tensor.pxi deleted file mode 100644 index 1255f0b0c3ff..000000000000 --- a/ffi/python/tvm_ffi/cython/tensor.pxi +++ /dev/null @@ -1,362 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -__dlpack_version__ = (1, 1) -_CLASS_TENSOR = None - - -def _set_class_tensor(cls): - global _CLASS_TENSOR - _CLASS_TENSOR = cls - - -cdef const char* _c_str_dltensor = "dltensor" -cdef const char* _c_str_used_dltensor = "used_dltensor" -cdef const char* _c_str_dltensor_versioned = "dltensor_versioned" -cdef const char* _c_str_used_dltensor_versioned = "used_dltensor_versioned" - -cdef void _c_dlpack_deleter(object pycaps): - cdef DLManagedTensor* dltensor - if pycapsule.PyCapsule_IsValid(pycaps, _c_str_dltensor): - dltensor = pycapsule.PyCapsule_GetPointer(pycaps, _c_str_dltensor) - dltensor.deleter(dltensor) - -cdef void _c_dlpack_versioned_deleter(object pycaps): - cdef DLManagedTensorVersioned* dltensor - if pycapsule.PyCapsule_IsValid(pycaps, _c_str_dltensor_versioned): - dltensor = pycapsule.PyCapsule_GetPointer( - pycaps, _c_str_dltensor_versioned) - dltensor.deleter(dltensor) - - -cdef inline object _from_dlpack_intptr( - void* dlpack -): - cdef TVMFFIObjectHandle chandle - cdef DLManagedTensor* ptr = dlpack - cdef int c_api_ret_code - cdef int c_req_alignment = 0 - cdef int c_req_contiguous = 0 - c_api_ret_code = TVMFFITensorFromDLPack( - ptr, c_req_alignment, c_req_contiguous, &chandle) - CHECK_CALL(c_api_ret_code) - return make_tensor_from_chandle(chandle) - - -cdef inline int _from_dlpack( - object dltensor, int require_alignment, - int require_contiguous, TVMFFIObjectHandle* out -) except -1: - cdef DLManagedTensor* ptr - cdef int c_api_ret_code - cdef int c_req_alignment = require_alignment - cdef int c_req_contiguous = require_contiguous - if pycapsule.PyCapsule_IsValid(dltensor, _c_str_dltensor): - ptr = pycapsule.PyCapsule_GetPointer(dltensor, _c_str_dltensor) - c_api_ret_code = TVMFFITensorFromDLPack( - ptr, c_req_alignment, c_req_contiguous, out) - CHECK_CALL(c_api_ret_code) - # set name and destructor to be empty - pycapsule.PyCapsule_SetDestructor(dltensor, NULL) - pycapsule.PyCapsule_SetName(dltensor, _c_str_used_dltensor) - return 0 - raise ValueError("Expect a dltensor field, PyCapsule can only be consumed once") - - -cdef inline int _from_dlpack_versioned( - object dltensor, int require_alignment, - int require_contiguous, TVMFFIObjectHandle* out -) except -1: - cdef DLManagedTensorVersioned* ptr - cdef int c_api_ret_code - cdef int c_req_alignment = require_alignment - cdef int c_req_contiguous = require_contiguous - if pycapsule.PyCapsule_IsValid(dltensor, _c_str_dltensor_versioned): - ptr = pycapsule.PyCapsule_GetPointer( - dltensor, _c_str_dltensor_versioned) - c_api_ret_code = TVMFFITensorFromDLPackVersioned( - ptr, c_req_alignment, c_req_contiguous, out) - CHECK_CALL(c_api_ret_code) - # set name and destructor to be empty - pycapsule.PyCapsule_SetDestructor(dltensor, NULL) - pycapsule.PyCapsule_SetName(dltensor, _c_str_used_dltensor_versioned) - return 0 - raise ValueError("Expect a dltensor_versioned field, PyCapsule can only be consumed once") - - -cdef inline int _from_dlpack_universal( - object ext_tensor, int require_alignment, - int require_contiguous, TVMFFIObjectHandle* out -) except -1: - # as of most frameworks do not yet support v1.1 - # move to false as most frameworks get upgraded. - cdef int favor_legacy_dlpack = True - - if hasattr(ext_tensor, '__dlpack__'): - if favor_legacy_dlpack: - _from_dlpack( - ext_tensor.__dlpack__(), - require_alignment, - require_contiguous, - out - ) - else: - try: - _from_dlpack_versioned( - ext_tensor.__dlpack__(max_version=__dlpack_version__), - require_alignment, - require_contiguous, - out - ) - except TypeError: - _from_dlpack( - ext_tensor.__dlpack__(), - require_alignment, - require_contiguous, - out - ) - else: - if pycapsule.PyCapsule_IsValid(ext_tensor, _c_str_dltensor_versioned): - _from_dlpack_versioned( - ext_tensor, - require_alignment, - require_contiguous, - out - ) - elif pycapsule.PyCapsule_IsValid(ext_tensor, _c_str_dltensor): - _from_dlpack( - ext_tensor, - require_alignment, - require_contiguous, - out - ) - else: - raise TypeError("Expect from_dlpack to take either a compatible tensor or PyCapsule") - - -def from_dlpack(ext_tensor, *, require_alignment=0, require_contiguous=False): - """ - Convert an external tensor to an Tensor. - - Parameters - ---------- - ext_tensor : object - The external tensor to convert. - - require_alignment : int - The minimum required alignment to check for the tensor. - - require_contiguous : bool - Whether to check for contiguous memory. - - Returns - ------- - tensor : :py:class:`tvm_ffi.Tensor` - The converted tensor. - """ - cdef TVMFFIObjectHandle chandle - _from_dlpack_universal(ext_tensor, require_alignment, require_contiguous, &chandle) - return make_tensor_from_chandle(chandle) - - -# helper class for shape handling -def _shape_obj_get_py_tuple(obj): - cdef TVMFFIShapeCell* shape = TVMFFIShapeGetCellPtr((obj).chandle) - return tuple(shape.data[i] for i in range(shape.size)) - - -cdef class Tensor(Object): - """Tensor object that represents a managed n-dimensional array. - """ - cdef DLTensor* cdltensor - - @property - def shape(self): - """Shape of this array""" - return tuple(self.cdltensor.shape[i] for i in range(self.cdltensor.ndim)) - - @property - def dtype(self): - """Data type of this array""" - cdef TVMFFIAny dtype_any - dtype_any.v_dtype = self.cdltensor.dtype - return make_ret_dtype(dtype_any) - - @property - def device(self): - """Device of this Tensor""" - cdef TVMFFIAny device_any - device_any.v_device = self.cdltensor.device - return make_ret_device(device_any) - - def _to_dlpack(self): - cdef DLManagedTensor* dltensor - cdef int c_api_ret_code - c_api_ret_code = TVMFFITensorToDLPack(self.chandle, &dltensor) - CHECK_CALL(c_api_ret_code) - return pycapsule.PyCapsule_New(dltensor, _c_str_dltensor, _c_dlpack_deleter) - - def _to_dlpack_versioned(self): - cdef DLManagedTensorVersioned* dltensor - cdef int c_api_ret_code - c_api_ret_code = TVMFFITensorToDLPackVersioned(self.chandle, &dltensor) - CHECK_CALL(c_api_ret_code) - return pycapsule.PyCapsule_New( - dltensor, _c_str_dltensor_versioned, _c_dlpack_versioned_deleter) - - def __dlpack_device__(self): - cdef int device_type = self.cdltensor.device.device_type - cdef int device_id = self.cdltensor.device.device_id - return (device_type, device_id) - - def __dlpack__(self, *, stream=None, max_version=None, dl_device=None, copy=None): - """Produce a DLPack tensor from this array - - Parameters - ---------- - stream : Optional[int] - The stream to use for the DLPack tensor - - max_version : int, optional - The maximum version of the DLPack tensor to produce - - dl_device : Optional[Tuple[int, int]] - The device to use for the DLPack tensor - - copy : Optional[bool] - Whether to copy the data to the new device - - Returns - ------- - dlpack : DLPack tensor - - Raises - ------ - BufferError - Export failed - """ - if max_version is None: - # Keep and use the DLPack 0.X implementation - # Note: from March 2025 onwards (but ideally as late as - # possible), it's okay to raise BufferError here - return self._to_dlpack() - else: - # We get to produce `DLManagedTensorVersioned` now. Note that - # our_own_dlpack_version is the max version that the *producer* - # supports and fills in the `DLManagedTensorVersioned::version` - # field - if max_version[0] >= __dlpack_version__[0]: - if dl_device is not None and dl_device != self.__dlpack_device__(): - raise BufferError("dl_device of different type not supported") - if copy is not None and copy: - raise BufferError("copy not yet supported") - return self._to_dlpack_versioned() - elif max_version[0] < 1: - return self.__ctypes_handle__to_dlpack() - else: - raise BufferError(f"Unsupported max_version {max_version}") - - -_set_class_tensor(Tensor) -_register_object_by_index(kTVMFFITensor, Tensor) - - -cdef int _dltensor_test_wrapper_c_dlpack_from_pyobject( - void* obj, DLManagedTensorVersioned** out, TVMFFIStreamHandle* env_stream -) except -1: - cdef PyObject* py_obj = obj - cdef DLTensorTestWrapper wrapper = py_obj - cdef TVMFFIStreamHandle current_stream - cdef DLManagedTensorVersioned* temp_managed_tensor - if env_stream != NULL: - env_stream[0] = TVMFFIEnvGetStream( - wrapper.tensor.cdltensor.device.device_type, - wrapper.tensor.cdltensor.device.device_id - ) - - return TVMFFITensorToDLPackVersioned(wrapper.tensor.chandle, out) - - -def _dltensor_test_wrapper_c_dlpack_from_pyobject_as_intptr(): - cdef DLPackFromPyObject converter_func = _dltensor_test_wrapper_c_dlpack_from_pyobject - cdef void* temp_ptr = converter_func - cdef long long temp_int_ptr = temp_ptr - return temp_int_ptr - - -cdef class DLTensorTestWrapper: - """Wrapper of a Tensor that exposes DLPack protocol, only for testing purpose. - """ - __c_dlpack_from_pyobject__ = _dltensor_test_wrapper_c_dlpack_from_pyobject_as_intptr() - - cdef Tensor tensor - cdef dict __dict__ - def __init__(self, tensor): - self.tensor = tensor - - def __tvm_ffi_env_stream__(self): - cdef TVMFFIStreamHandle stream - cdef long long stream_as_int - cdef int c_api_ret_code - stream = TVMFFIEnvGetStream( - self.tensor.cdltensor.device.device_type, self.tensor.cdltensor.device.device_id) - stream_as_int = stream - return stream_as_int - - def __dlpack_device__(self): - return self.tensor.__dlpack_device__() - - def __dlpack__(self, *, **kwargs): - return self.tensor.__dlpack__(**kwargs) - - -cdef inline object make_ret_dltensor(TVMFFIAny result): - cdef DLTensor* dltensor - dltensor = result.v_ptr - tensor = _CLASS_TENSOR.__new__(_CLASS_TENSOR) - (tensor).chandle = NULL - (tensor).cdltensor = dltensor - return tensor - - -cdef inline object make_tensor_from_chandle(TVMFFIObjectHandle chandle, DLPackToPyObject c_dlpack_to_pyobject = NULL): - # TODO: Implement - cdef Tensor tensor - cdef void* py_obj - cdef DLManagedTensorVersioned* dlpack - - if c_dlpack_to_pyobject != NULL: - # try convert and import into the environment array if possible - if TVMFFITensorToDLPackVersioned(chandle, &dlpack) == 0: - try: - # note that py_obj already holds an extra reference to the tensor - # so we need to decref it after the conversion - c_dlpack_to_pyobject(dlpack, &py_obj) - tensor = (py_obj) - Py_DECREF(tensor) - return tensor - except Exception: - pass - # default return the tensor - tensor = _CLASS_TENSOR.__new__(_CLASS_TENSOR) - (tensor).chandle = chandle - (tensor).cdltensor = TVMFFITensorGetDLTensorPtr(chandle) - return tensor - - -cdef inline object make_tensor_from_any(TVMFFIAny any, DLPackToPyObject c_dlpack_to_pyobject): - return make_tensor_from_chandle(any.v_ptr, c_dlpack_to_pyobject) diff --git a/ffi/python/tvm_ffi/cython/tvm_ffi_python_helpers.h b/ffi/python/tvm_ffi/cython/tvm_ffi_python_helpers.h deleted file mode 100644 index 325b878c4fc9..000000000000 --- a/ffi/python/tvm_ffi/cython/tvm_ffi_python_helpers.h +++ /dev/null @@ -1,580 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -/* - * \file tvm_ffi_python_helpers.h - * \brief C++ based helpers for the Python FFI call to optimize performance. - */ -#ifndef TVM_FFI_PYTHON_HELPERS_H_ -#define TVM_FFI_PYTHON_HELPERS_H_ - -#include -#include -#include - -#include -#include -#include -#include - -//---------------------------------------------------------- -// Extra support for DLPack -//---------------------------------------------------------- -/*! - * \brief C-style function pointer to speed convert a PyObject Tensor to a DLManagedTensorVersioned. - * \param py_obj The Python object to convert, this should be PyObject* - * \param out The output DLManagedTensorVersioned. - * \param env_stream Outputs the current context stream of the device provided by the tensor. - * \return 0 on success, -1 on failure. PyError should be set if -1 is returned. - * \note We use void* to avoid dependency on Python.h so this specific type is - * not dependent on Python.h and can be copied to dlpack.h - */ -typedef int (*DLPackFromPyObject)(void* py_obj, DLManagedTensorVersioned** out, void** env_stream); -/*! - * \brief C-style function pointer to speed convert a DLManagedTensorVersioned to a PyObject Tensor. - * \param tensor The DLManagedTensorVersioned to convert. - * \param py_obj_out The output Python object. - * \return 0 on success, -1 on failure. PyError should be set if -1 is returned. - * \note We use void* to avoid dependency on Python.h so this specific type is - * not dependent on Python.h and can be copied to dlpack.h - */ -typedef int (*DLPackToPyObject)(DLManagedTensorVersioned* tensor, void** py_obj_out); - -///-------------------------------------------------------------------------------- -/// We deliberately designed the data structure and function to be C-style -// prefixed with TVMFFIPy so they can be easily invoked through Cython. -///-------------------------------------------------------------------------------- - -/*! - * \brief Context for each ffi call to track the stream, device and temporary arguments. - */ -struct TVMFFIPyCallContext { - /*! \brief The workspace for the packed arguments */ - TVMFFIAny* packed_args = nullptr; - /*! \brief Detected device type, if any */ - int device_type = -1; - /*! \brief Detected device id, if any */ - int device_id = 0; - /*! \brief Detected stream, if any */ - void* stream = nullptr; - /*! \brief the temporary arguments to be recycled */ - void** temp_ffi_objects = nullptr; - /*! \brief the number of temporary arguments */ - int num_temp_ffi_objects = 0; - /*! \brief the temporary arguments to be recycled */ - void** temp_py_objects = nullptr; - /*! \brief the number of temporary arguments */ - int num_temp_py_objects = 0; - /*! \brief the DLPack exporter, if any */ - DLPackToPyObject c_dlpack_to_pyobject{nullptr}; - /*! \brief the DLPack allocator, if any */ - DLPackTensorAllocator c_dlpack_tensor_allocator{nullptr}; -}; - -/*! \brief Argument setter for a given python argument. */ -struct TVMFFIPyArgSetter { - /*! - * \brief Function pointer to invoke the setter. - * \param self Pointer to this, this should be TVMFFIPyArgSetter* - * \param call_ctx The call context. - * \param arg The python argument to be set - * \param out The output argument. - * \return 0 on success, -1 on failure. PyError should be set if -1 is returned. - */ - int (*func)(TVMFFIPyArgSetter* self, TVMFFIPyCallContext* call_ctx, PyObject* arg, - TVMFFIAny* out); - /*! - * \brief Optional DLPack exporter for for setters that leverages DLPack protocol. - */ - DLPackFromPyObject c_dlpack_from_pyobject{nullptr}; - /*! - * \brief Optional DLPack importer for for setters that leverages DLPack protocol. - */ - DLPackToPyObject c_dlpack_to_pyobject{nullptr}; - /*! - * \brief Optional DLPack allocator for for setters that leverages DLPack protocol. - */ - DLPackTensorAllocator c_dlpack_tensor_allocator{nullptr}; - /*! - * \brief Invoke the setter. - * \param call_ctx The call context. - * \param arg The python argument to be set - * \param out The output argument. - * \return 0 on success, -1 on failure. PyError should be set if -1 is returned. - */ - int operator()(TVMFFIPyCallContext* call_ctx, PyObject* arg, TVMFFIAny* out) const { - return (*func)(const_cast(this), call_ctx, arg, out); - } -}; - -//--------------------------------------------------------------------------------------------- -// The following section contains predefined setters for common POD types -// They ar not meant to be used directly, but instead being registered to TVMFFIPyCallManager -//--------------------------------------------------------------------------------------------- -int TVMFFIPyArgSetterFloat_(TVMFFIPyArgSetter*, TVMFFIPyCallContext*, PyObject* arg, - TVMFFIAny* out) noexcept { - out->type_index = kTVMFFIFloat; - // this function getsdispatched when type is already float, so no need to worry about error - out->v_float64 = PyFloat_AsDouble(arg); - return 0; -} - -int TVMFFIPyArgSetterInt_(TVMFFIPyArgSetter*, TVMFFIPyCallContext*, PyObject* arg, - TVMFFIAny* out) noexcept { - int overflow = 0; - out->type_index = kTVMFFIInt; - out->v_int64 = PyLong_AsLongLongAndOverflow(arg, &overflow); - - if (overflow != 0) { - PyErr_SetString(PyExc_OverflowError, "Python int too large to convert to int64_t"); - return -1; - } - return 0; -} - -int TVMFFIPyArgSetterBool_(TVMFFIPyArgSetter*, TVMFFIPyCallContext*, PyObject* arg, - TVMFFIAny* out) noexcept { - out->type_index = kTVMFFIBool; - // this function getsdispatched when type is already bool, so no need to worry about error - out->v_int64 = PyLong_AsLong(arg); - return 0; -} - -int TVMFFIPyArgSetterNone_(TVMFFIPyArgSetter*, TVMFFIPyCallContext*, PyObject* arg, - TVMFFIAny* out) noexcept { - out->type_index = kTVMFFINone; - out->v_int64 = 0; - return 0; -} - -//--------------------------------------------------------------------------------------------- -// The following section contains the dispatcher logic for function calling -//--------------------------------------------------------------------------------------------- -/*! - * \brief Factory function that creates an argument setter for a given Python argument type. - * - * This factory function analyzes a Python argument and creates an appropriate setter - * that can convert Python objects of the same type to C arguments for TVM FFI calls. - * The setter will be cached for future use for setting argument of the same type. - * - * \param arg The Python argument value used as a type example. - * \param out Output parameter that receives the created argument setter. - * \return 0 on success, -1 on failure. PyError should be set if -1 is returned. - * - * \note This is a callback function supplied by the caller. The factory must satisfy - * the invariance that the same setter can be used for other arguments with - * the same type as the provided example argument. - */ -typedef int (*TVMFFIPyArgSetterFactory)(PyObject* arg, TVMFFIPyArgSetter* out); - -/*! - * \brief A manager class that handles python ffi calls. - */ -class TVMFFIPyCallManager { - public: - /*! - * \brief Get the thread local call manager. - * \return The thread local call manager. - */ - static TVMFFIPyCallManager* ThreadLocal() { - static thread_local TVMFFIPyCallManager inst; - return &inst; - } - /*! - * \brief auxiliary class that manages the call stack in RAII manner. - * - * In most cases, it will try to allocate from temp_stack, - * then allocate from heap if the request goes beyond the stack size. - */ - class CallStack : public TVMFFIPyCallContext { - public: - CallStack(TVMFFIPyCallManager* manager, int64_t num_args) : manager_ptr_(manager) { - static_assert(sizeof(TVMFFIAny) >= (sizeof(void*) * 2)); - static_assert(alignof(TVMFFIAny) % alignof(void*) == 0); - old_stack_top_ = manager->stack_top_; - int64_t requested_count = num_args * 2; - TVMFFIAny* stack_head = manager->temp_stack_.data() + manager->stack_top_; - if (manager->stack_top_ + requested_count > - static_cast(manager->temp_stack_.size())) { - // allocate from heap - heap_ptr_ = new TVMFFIAny[requested_count]; - stack_head = heap_ptr_; - } else { - manager->stack_top_ += requested_count; - } - this->packed_args = stack_head; - this->temp_ffi_objects = reinterpret_cast(stack_head + num_args); - this->temp_py_objects = this->temp_ffi_objects + num_args; - } - - ~CallStack() { - try { - // recycle the temporary arguments if any - for (int i = 0; i < this->num_temp_ffi_objects; ++i) { - TVMFFIObjectDecRef(this->temp_ffi_objects[i]); - } - for (int i = 0; i < this->num_temp_py_objects; ++i) { - Py_DecRef(static_cast(this->temp_py_objects[i])); - } - } catch (const std::exception& ex) { - // very rare, catch c++ exception and set python error - PyErr_SetString(PyExc_RuntimeError, ex.what()); - } - // now recycle the memory of the call stack - if (heap_ptr_ == nullptr) { - manager_ptr_->stack_top_ = old_stack_top_; - } else { - delete[] heap_ptr_; - } - } - - private: - /*! - *\brief The manager of the call stack - * If stored on stack, must set it to point to parent. - */ - TVMFFIPyCallManager* manager_ptr_ = nullptr; - /*! \brief The heap of the call stack */ - TVMFFIAny* heap_ptr_ = nullptr; - /*! \brief The old stack size */ - int64_t old_stack_top_ = 0; - }; - - /*! - * \brief Call a function with a variable number of arguments - * \param setter_factory The factory function to create the setter - * \param func_handle The handle of the function to call - * \param py_arg_tuple The arguments to the function - * \param result The result of the function - * \param c_api_ret_code The return code of the C-call - * \param release_gil Whether to release the GIL - * \param optional_out_dlpack_importer The DLPack importer to be used for the result - * \return 0 on when there is no python error, -1 on python error - * \note When an error happens on FFI side, we should return 0 and set c_api_ret_code - */ - int FuncCall(TVMFFIPyArgSetterFactory setter_factory, void* func_handle, PyObject* py_arg_tuple, - TVMFFIAny* result, int* c_api_ret_code, bool release_gil, - DLPackToPyObject* optional_out_dlpack_importer) { - int64_t num_args = PyTuple_Size(py_arg_tuple); - if (num_args == -1) return -1; - try { - // allocate a call stack - CallStack ctx(this, num_args); - // Iterate over the arguments and set them - for (int64_t i = 0; i < num_args; ++i) { - PyObject* py_arg = PyTuple_GetItem(py_arg_tuple, i); - TVMFFIAny* c_arg = ctx.packed_args + i; - if (SetArgument(setter_factory, &ctx, py_arg, c_arg) != 0) return -1; - } - TVMFFIStreamHandle prev_stream = nullptr; - DLPackTensorAllocator prev_tensor_allocator = nullptr; - // setup stream context if needed - if (ctx.device_type != -1) { - c_api_ret_code[0] = - TVMFFIEnvSetStream(ctx.device_type, ctx.device_id, ctx.stream, &prev_stream); - // setting failed, directly return - if (c_api_ret_code[0] != 0) return 0; - } - if (ctx.c_dlpack_tensor_allocator != nullptr) { - c_api_ret_code[0] = - TVMFFIEnvSetTensorAllocator(ctx.c_dlpack_tensor_allocator, 0, &prev_tensor_allocator); - if (c_api_ret_code[0] != 0) return 0; - } - // call the function - if (release_gil) { - // release the GIL - Py_BEGIN_ALLOW_THREADS; - c_api_ret_code[0] = TVMFFIFunctionCall(func_handle, ctx.packed_args, num_args, result); - Py_END_ALLOW_THREADS; - } else { - c_api_ret_code[0] = TVMFFIFunctionCall(func_handle, ctx.packed_args, num_args, result); - } - // restore the original stream - if (ctx.device_type != -1 && prev_stream != ctx.stream) { - // always try recover first, even if error happens - if (TVMFFIEnvSetStream(ctx.device_type, ctx.device_id, prev_stream, nullptr) != 0) { - // recover failed, set python error - PyErr_SetString(PyExc_RuntimeError, "Failed to recover stream"); - return -1; - } - } - if (prev_tensor_allocator != ctx.c_dlpack_tensor_allocator) { - c_api_ret_code[0] = TVMFFIEnvSetTensorAllocator(prev_tensor_allocator, 0, nullptr); - if (c_api_ret_code[0] != 0) return 0; - } - if (optional_out_dlpack_importer != nullptr && ctx.c_dlpack_to_pyobject != nullptr) { - *optional_out_dlpack_importer = ctx.c_dlpack_to_pyobject; - } - return 0; - } catch (const std::exception& ex) { - // very rare, catch c++ exception and set python error - PyErr_SetString(PyExc_RuntimeError, ex.what()); - return -1; - } - } - - /* - * \brief Call a constructor with a variable number of arguments - * - * This function is similar to FuncCall, but it will not set the - * stream and tensor allocator, instead, it will synchronize the TVMFFIPyCallContext - * with the parent context. This behavior is needed for nested conversion of arguments - * where detected argument setting needs to be synchronized with final call. - * - * This function will also not release the GIL since constructor call is usually cheap. - * - * \param setter_factory The factory function to create the setter - * \param func_handle The handle of the constructor to call - * \param py_arg_tuple The arguments to the constructor - * \param result The result of the constructor - * \param c_api_ret_code The return code of the constructor - * \param parent_ctx The parent call context to - * \return 0 on success, -1 on failure - */ - int ConstructorCall(TVMFFIPyArgSetterFactory setter_factory, void* func_handle, - PyObject* py_arg_tuple, TVMFFIAny* result, int* c_api_ret_code, - TVMFFIPyCallContext* parent_ctx) { - int64_t num_args = PyTuple_Size(py_arg_tuple); - if (num_args == -1) return -1; - try { - // allocate a call stack - CallStack ctx(this, num_args); - // Iterate over the arguments and set them - for (int64_t i = 0; i < num_args; ++i) { - PyObject* py_arg = PyTuple_GetItem(py_arg_tuple, i); - TVMFFIAny* c_arg = ctx.packed_args + i; - if (SetArgument(setter_factory, &ctx, py_arg, c_arg) != 0) return -1; - } - c_api_ret_code[0] = TVMFFIFunctionCall(func_handle, ctx.packed_args, num_args, result); - // propagate the call context to the parent context - if (parent_ctx != nullptr) { - // stream and current device information - if (parent_ctx->device_type == -1) { - parent_ctx->device_type = ctx.device_type; - parent_ctx->device_id = ctx.device_id; - parent_ctx->stream = ctx.stream; - } - // DLPack allocator - if (parent_ctx->c_dlpack_tensor_allocator == nullptr) { - parent_ctx->c_dlpack_tensor_allocator = ctx.c_dlpack_tensor_allocator; - } - // DLPack importer - if (parent_ctx->c_dlpack_to_pyobject == nullptr) { - parent_ctx->c_dlpack_to_pyobject = ctx.c_dlpack_to_pyobject; - } - } - return 0; - } catch (const std::exception& ex) { - // very rare, catch c++ exception and set python error - PyErr_SetString(PyExc_RuntimeError, ex.what()); - return -1; - } - } - - int SetField(TVMFFIPyArgSetterFactory setter_factory, TVMFFIFieldSetter field_setter, - void* field_ptr, PyObject* py_arg, int* c_api_ret_code) { - try { - CallStack ctx(this, 1); - TVMFFIAny* c_arg = ctx.packed_args; - if (SetArgument(setter_factory, &ctx, py_arg, c_arg) != 0) return -1; - c_api_ret_code[0] = (*field_setter)(field_ptr, c_arg); - return 0; - } catch (const std::exception& ex) { - // very rare, catch c++ exception and set python error - PyErr_SetString(PyExc_RuntimeError, ex.what()); - return -1; - } - } - - int PyObjectToFFIAny(TVMFFIPyArgSetterFactory setter_factory, PyObject* py_arg, TVMFFIAny* out, - int* c_api_ret_code) { - try { - CallStack ctx(this, 1); - TVMFFIAny* c_arg = ctx.packed_args; - if (SetArgument(setter_factory, &ctx, py_arg, c_arg) != 0) return -1; - c_api_ret_code[0] = TVMFFIAnyViewToOwnedAny(c_arg, out); - return 0; - } catch (const std::exception& ex) { - // very rare, catch c++ exception and set python error - PyErr_SetString(PyExc_RuntimeError, ex.what()); - return -1; - } - } - /*! - * \brief Get the size of the dispatch map - * \return The size of the dispatch map - */ - size_t GetDispatchMapSize() const { return dispatch_map_.size(); } - - private: - TVMFFIPyCallManager() { - static constexpr size_t kDefaultDispatchCapacity = 32; - static constexpr size_t kDefaultStackSize = 32; - dispatch_map_.reserve(kDefaultDispatchCapacity); - temp_stack_.resize(kDefaultStackSize * 2); - } - /*! - * \brief Set an py_arg to out. - * \param setter_factory The factory function to create the setter - * \param ctx The call context - * \param py_arg The python argument to be set - * \param out The output argument - * \return 0 on success, -1 on failure - */ - int SetArgument(TVMFFIPyArgSetterFactory setter_factory, TVMFFIPyCallContext* ctx, - PyObject* py_arg, TVMFFIAny* out) { - PyTypeObject* py_type = Py_TYPE(py_arg); - // pre-zero the output argument, modulo the type index - out->type_index = kTVMFFINone; - out->zero_padding = 0; - out->v_int64 = 0; - // find the pre-cached setter - // This class is thread-local, so we don't need to worry about race condition - auto it = dispatch_map_.find(py_type); - if (it != dispatch_map_.end()) { - TVMFFIPyArgSetter setter = it->second; - // if error happens, propagate it back - if (setter(ctx, py_arg, out) != 0) return -1; - } else { - // no dispatch found, query and create a new one. - TVMFFIPyArgSetter setter; - // propagate python error back - if (setter_factory(py_arg, &setter) != 0) { - return -1; - } - // update dispatch table - dispatch_map_.emplace(py_type, setter); - if (setter(ctx, py_arg, out) != 0) return -1; - } - return 0; - } - // internal dispacher - std::unordered_map dispatch_map_; - // temp call stack - std::vector temp_stack_; - int64_t stack_top_ = 0; -}; - -/*! - * \brief Call a function with a variable number of arguments - * \param setter_factory The factory function to create the setter - * \param func_handle The handle of the function to call - * \param py_arg_tuple The arguments to the function - * \param result The result of the function - * \param c_api_ret_code The return code of the function - * \param release_gil Whether to release the GIL - * \param out_dlpack_exporter The DLPack exporter to be used for the result - * \return 0 on success, nonzero on failure - */ -inline int TVMFFIPyFuncCall(TVMFFIPyArgSetterFactory setter_factory, void* func_handle, - PyObject* py_arg_tuple, TVMFFIAny* result, int* c_api_ret_code, - bool release_gil = true, - DLPackToPyObject* out_dlpack_importer = nullptr) { - return TVMFFIPyCallManager::ThreadLocal()->FuncCall(setter_factory, func_handle, py_arg_tuple, - result, c_api_ret_code, release_gil, - out_dlpack_importer); -} - -/*! - * \brief Call a constructor function with a variable number of arguments - * - * This function is similar to TVMFFIPyFuncCall, but it will not set the - * stream and tensor allocator. Instead, it will synchronize the TVMFFIPyCallContext - * with the parent context. This behavior is needed for nested conversion of arguments - * where detected argument settings need to be synchronized with the final call. - * - * This function will also not release the GIL since constructor call is usually cheap. - * - * \param setter_factory The factory function to create the setter - * \param func_handle The handle of the function to call - * \param py_arg_tuple The arguments to the constructor - * \param result The result of the constructor - * \param c_api_ret_code The return code of the constructor - * \param parent_ctx The parent call context - * \param release_gil Whether to release the GIL - * \param out_dlpack_exporter The DLPack exporter to be used for the result - * \return 0 on success, nonzero on failure - */ -inline int TVMFFIPyConstructorCall(TVMFFIPyArgSetterFactory setter_factory, void* func_handle, - PyObject* py_arg_tuple, TVMFFIAny* result, int* c_api_ret_code, - TVMFFIPyCallContext* parent_ctx) { - return TVMFFIPyCallManager::ThreadLocal()->ConstructorCall( - setter_factory, func_handle, py_arg_tuple, result, c_api_ret_code, parent_ctx); -} - -/*! - * \brief Set a field of a FFI object - * \param setter_factory The factory function to create the setter - * \param field_setter The field setter function - * \param field_ptr The pointer to the field - * \param py_arg The python argument to be set - * \param c_api_ret_code The return code of the function - * \return 0 on success, nonzero on failure - */ -inline int TVMFFIPyCallFieldSetter(TVMFFIPyArgSetterFactory setter_factory, - TVMFFIFieldSetter field_setter, void* field_ptr, - PyObject* py_arg, int* c_api_ret_code) { - return TVMFFIPyCallManager::ThreadLocal()->SetField(setter_factory, field_setter, field_ptr, - py_arg, c_api_ret_code); -} - -/*! - * \brief Convert a Python object to a FFI Any - * \param setter_factory The factory function to create the setter - * \param py_arg The python argument to be set - * \param out The output argument - * \param c_api_ret_code The return code of the function - * \return 0 on success, nonzero on failure - */ -inline int TVMFFIPyPyObjectToFFIAny(TVMFFIPyArgSetterFactory setter_factory, PyObject* py_arg, - TVMFFIAny* out, int* c_api_ret_code) { - return TVMFFIPyCallManager::ThreadLocal()->PyObjectToFFIAny(setter_factory, py_arg, out, - c_api_ret_code); -} - -/*! - * \brief Get the size of the dispatch map - * \return The size of the dispatch map - */ -inline size_t TVMFFIPyGetDispatchMapSize() { - return TVMFFIPyCallManager::ThreadLocal()->GetDispatchMapSize(); -} - -/*! - * \brief Push a temporary FFI object to the call context that will be recycled after the call - * \param ctx The call context - * \param arg The FFI object to push - */ -inline void TVMFFIPyPushTempFFIObject(TVMFFIPyCallContext* ctx, TVMFFIObjectHandle arg) noexcept { - // invariance: each ArgSetter can have at most one temporary Python object - // so it ensures that we won't overflow the temporary Python object stack - ctx->temp_ffi_objects[ctx->num_temp_ffi_objects++] = arg; -} - -/*! - * \brief Push a temporary Python object to the call context that will be recycled after the call - * \param ctx The call context - * \param arg The Python object to push - */ -inline void TVMFFIPyPushTempPyObject(TVMFFIPyCallContext* ctx, PyObject* arg) noexcept { - // invariance: each ArgSetter can have at most one temporary Python object - // so it ensures that we won't overflow the temporary Python object stack - Py_IncRef(arg); - ctx->temp_py_objects[ctx->num_temp_py_objects++] = arg; -} -#endif // TVM_FFI_PYTHON_HELPERS_H_ diff --git a/ffi/python/tvm_ffi/error.py b/ffi/python/tvm_ffi/error.py deleted file mode 100644 index a7714cb58ffd..000000000000 --- a/ffi/python/tvm_ffi/error.py +++ /dev/null @@ -1,193 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=invalid-name -"""Error handling.""" -import re -import types -import sys -import ast -from . import core - - -def _parse_traceback(traceback): - """Parse the traceback string into a list of (filename, lineno, func) - - Parameters - ---------- - traceback : str - The traceback string. - - Returns - ------- - result : List[Tuple[str, int, str]] - The list of (filename, lineno, func) - """ - pattern = r'File "(.+?)", line (\d+), in (.+)' - result = [] - for line in traceback.split("\n"): - match = re.match(pattern, line.strip()) - if match: - try: - filename = match.group(1) - lineno = int(match.group(2)) - func = match.group(3) - result.append((filename, lineno, func)) - except ValueError: - pass - return result - - -class TracebackManager: - """ - Helper to manage traceback generation - """ - - def __init__(self): - self._code_cache = {} - - def _get_cached_code_object(self, filename, lineno, func): - # Hack to create a code object that points to the correct - # line number and function name - key = (filename, lineno, func) - # cache the code object to avoid re-creating it - if key in self._code_cache: - return self._code_cache[key] - # Parse to AST and zero out column info - # since column info are not accurate in original trace - tree = ast.parse("_getframe()", filename=filename, mode="eval") - for node in ast.walk(tree): - if hasattr(node, "col_offset"): - node.col_offset = 0 - if hasattr(node, "end_col_offset"): - node.end_col_offset = 0 - # call into get frame, bt changes the context - code_object = compile(tree, filename, "eval") - # replace the function name and line number - code_object = code_object.replace(co_name=func, co_firstlineno=lineno) - self._code_cache[key] = code_object - return code_object - - def _create_frame(self, filename, lineno, func): - """Create a frame object from the filename, lineno, and func""" - code_object = self._get_cached_code_object(filename, lineno, func) - # call into get frame, but changes the context so the code - # points to the correct frame - context = {"_getframe": sys._getframe} - # pylint: disable=eval-used - return eval(code_object, context, context) - - def append_traceback(self, tb, filename, lineno, func): - """Append a traceback to the given traceback - - Parameters - ---------- - tb : types.TracebackType - The traceback to append to. - filename : str - The filename of the traceback - lineno : int - The line number of the traceback - func : str - The function name of the traceback - - Returns - ------- - new_tb : types.TracebackType - The new traceback with the appended frame. - """ - frame = self._create_frame(filename, lineno, func) - return types.TracebackType(tb, frame, frame.f_lasti, lineno) - - -_TRACEBACK_MANAGER = TracebackManager() - - -def _with_append_traceback(py_error, traceback): - """Append the traceback to the py_error and return it""" - tb = py_error.__traceback__ - for filename, lineno, func in reversed(_parse_traceback(traceback)): - tb = _TRACEBACK_MANAGER.append_traceback(tb, filename, lineno, func) - return py_error.with_traceback(tb) - - -def _traceback_to_str(tb): - """Convert the traceback to a string""" - lines = [] - while tb is not None: - frame = tb.tb_frame - lineno = tb.tb_lineno - filename = frame.f_code.co_filename - funcname = frame.f_code.co_name - lines.append(f' File "{filename}", line {lineno}, in {funcname}\n') - tb = tb.tb_next - return "".join(lines) - - -core._WITH_APPEND_TRACEBACK = _with_append_traceback -core._TRACEBACK_TO_STR = _traceback_to_str - - -def register_error(name_or_cls=None, cls=None): - """Register an error class so it can be recognized by the ffi error handler. - - Parameters - ---------- - name_or_cls : str or class - The name of the error class. - - cls : class - The class to register. - - Returns - ------- - fregister : function - Register function if f is not specified. - - Examples - -------- - .. code-block:: python - - @tvm.error.register_error - class MyError(RuntimeError): - pass - - err_inst = tvm.error.create_ffi_error("MyError: xyz") - assert isinstance(err_inst, MyError) - """ - if callable(name_or_cls): - cls = name_or_cls - name_or_cls = cls.__name__ - - def register(mycls): - """internal register function""" - err_name = name_or_cls if isinstance(name_or_cls, str) else mycls.__name__ - core.ERROR_NAME_TO_TYPE[err_name] = mycls - core.ERROR_TYPE_TO_NAME[mycls] = err_name - return mycls - - if cls is None: - return register - return register(cls) - - -register_error("RuntimeError", RuntimeError) -register_error("ValueError", ValueError) -register_error("TypeError", TypeError) -register_error("AttributeError", AttributeError) -register_error("KeyError", KeyError) -register_error("IndexError", IndexError) -register_error("AssertionError", AssertionError) diff --git a/ffi/python/tvm_ffi/libinfo.py b/ffi/python/tvm_ffi/libinfo.py deleted file mode 100644 index b02897f27917..000000000000 --- a/ffi/python/tvm_ffi/libinfo.py +++ /dev/null @@ -1,167 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -import sys -import os -import glob - - -def split_env_var(env_var, split): - """Splits environment variable string. - - Parameters - ---------- - env_var : str - Name of environment variable. - - split : str - String to split env_var on. - - Returns - ------- - splits : list(string) - If env_var exists, split env_var. Otherwise, empty list. - """ - if os.environ.get(env_var, None): - return [p.strip() for p in os.environ[env_var].split(split)] - return [] - - -def get_dll_directories(): - """Get the possible dll directories""" - ffi_dir = os.path.dirname(os.path.realpath(os.path.expanduser(__file__))) - dll_path = [os.path.join(ffi_dir, "lib")] - dll_path += [os.path.join(ffi_dir, "..", "..", "build", "lib")] - # in source build from parent if needed - dll_path += [os.path.join(ffi_dir, "..", "..", "..", "build", "lib")] - - if sys.platform.startswith("linux") or sys.platform.startswith("freebsd"): - dll_path.extend(split_env_var("LD_LIBRARY_PATH", ":")) - dll_path.extend(split_env_var("PATH", ":")) - elif sys.platform.startswith("darwin"): - dll_path.extend(split_env_var("DYLD_LIBRARY_PATH", ":")) - dll_path.extend(split_env_var("PATH", ":")) - elif sys.platform.startswith("win32"): - dll_path.extend(split_env_var("PATH", ";")) - return [os.path.abspath(x) for x in dll_path if os.path.isdir(x)] - - -def find_libtvm_ffi(): - """Find libtvm_ffi.""" - dll_path = get_dll_directories() - if sys.platform.startswith("win32"): - lib_dll_names = ["tvm_ffi.dll"] - elif sys.platform.startswith("darwin"): - lib_dll_names = ["libtvm_ffi.dylib", "libtvm_ffi.so"] - else: - lib_dll_names = ["libtvm_ffi.so"] - - name = lib_dll_names - lib_dll_path = [os.path.join(p, name) for name in lib_dll_names for p in dll_path] - lib_found = [p for p in lib_dll_path if os.path.exists(p) and os.path.isfile(p)] - - if not lib_found: - raise RuntimeError(f"Cannot find library: {name}\nList of candidates:\n{lib_dll_path}") - - return lib_found[0] - - -def find_source_path(): - """Find packaged source home path.""" - candidates = [ - os.path.join(os.path.dirname(os.path.realpath(__file__))), - os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", ".."), - ] - for candidate in candidates: - if os.path.isdir(os.path.join(candidate, "cmake")): - return candidate - raise RuntimeError("Cannot find home path.") - - -def find_cmake_path(): - """Find the preferred cmake path.""" - candidates = [ - os.path.join(os.path.dirname(os.path.realpath(__file__)), "cmake"), - os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "..", "cmake"), - ] - for candidate in candidates: - if os.path.isdir(candidate): - return candidate - raise RuntimeError("Cannot find cmake path.") - - -def find_include_path(): - """Find header files for C compilation.""" - candidates = [ - os.path.join(os.path.dirname(os.path.realpath(__file__)), "include"), - os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "..", "include"), - ] - for candidate in candidates: - if os.path.isdir(candidate): - return candidate - raise RuntimeError("Cannot find include path.") - - -def find_python_helper_include_path(): - """Find header files for C compilation.""" - candidates = [ - os.path.join(os.path.dirname(os.path.realpath(__file__)), "include"), - os.path.join(os.path.dirname(os.path.realpath(__file__)), "cython"), - ] - for candidate in candidates: - if os.path.isfile(os.path.join(candidate, "tvm_ffi_python_helpers.h")): - return candidate - raise RuntimeError("Cannot find python helper include path.") - - -def find_dlpack_include_path(): - """Find dlpack header files for C compilation.""" - install_include_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "include") - if os.path.isdir(os.path.join(install_include_path, "dlpack")): - return install_include_path - - source_include_path = os.path.join( - os.path.dirname(os.path.realpath(__file__)), "..", "..", "3rdparty", "dlpack", "include" - ) - if os.path.isdir(source_include_path): - return source_include_path - - raise RuntimeError("Cannot find include path.") - - -def find_cython_lib(): - """Find the path to tvm cython.""" - path_candidates = [ - os.path.dirname(os.path.realpath(__file__)), - os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "..", "build"), - ] - suffixes = "pyd" if sys.platform.startswith("win32") else "so" - for candidate in path_candidates: - for path in glob.glob(os.path.join(candidate, f"core*.{suffixes}")): - return os.path.abspath(path) - raise RuntimeError("Cannot find tvm cython path.") - - -def include_paths(): - """Find all include paths needed for FFI related compilation.""" - include_path = find_include_path() - python_helper_include_path = find_python_helper_include_path() - dlpack_include_path = find_dlpack_include_path() - result = [include_path, dlpack_include_path] - if python_helper_include_path != include_path: - result.append(python_helper_include_path) - return result diff --git a/ffi/python/tvm_ffi/module.py b/ffi/python/tvm_ffi/module.py deleted file mode 100644 index 56c2a9385517..000000000000 --- a/ffi/python/tvm_ffi/module.py +++ /dev/null @@ -1,275 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""Module related objects and functions.""" -# pylint: disable=invalid-name - -from enum import IntEnum -from . import _ffi_api - -from . import core -from .registry import register_object - -__all__ = ["Module", "ModulePropertyMask", "system_lib", "load_module"] - - -class ModulePropertyMask(IntEnum): - """Runtime Module Property Mask.""" - - BINARY_SERIALIZABLE = 0b001 - RUNNABLE = 0b010 - COMPILATION_EXPORTABLE = 0b100 - - -@register_object("ffi.Module") -class Module(core.Object): - """Module container for dynamically loaded Module. - - Example - ------- - .. code-block:: python - - import tvm_ffi - - # load the module from a tvm-ffi shared library - mod : tvm_ffi.Module = tvm_ffi.load_module("path/to/library.so") - # you can use mod.func_name to call the exported function - mod.func_name(*args) - - See Also - -------- - :py:func:`tvm_ffi.load_module` - """ - - # constant for entry function name - entry_name = "main" - - @property - def kind(self): - """Get type key of the module.""" - return _ffi_api.ModuleGetKind(self) - - @property - def imports(self): - """Get imported modules - - Returns - ---------- - modules : list of Module - The module - """ - return self.imports_ - - def implements_function(self, name, query_imports=False): - """Returns True if the module has a definition for the global function with name. Note - that has_function(name) does not imply get_function(name) is non-null since the module - may be, eg, a CSourceModule which cannot supply a packed-func implementation of the function - without further compilation. However, get_function(name) non null should always imply - has_function(name). - - Parameters - ---------- - name : str - The name of the function - - query_imports : bool - Whether to also query modules imported by this module. - - Returns - ------- - b : Bool - True if module (or one of its imports) has a definition for name. - """ - return _ffi_api.ModuleImplementsFunction(self, name, query_imports) - - def __getattr__(self, name): - """Accessor to allow getting functions as attributes.""" - try: - func = self.get_function(name) - self.__dict__[name] = func - return func - except AttributeError: - raise AttributeError(f"Module has no function '{name}'") - - def get_function(self, name, query_imports=False): - """Get function from the module. - - Parameters - ---------- - name : str - The name of the function - - query_imports : bool - Whether also query modules imported by this module. - - Returns - ------- - f : tvm_ffi.Function - The result function. - """ - func = _ffi_api.ModuleGetFunction(self, name, query_imports) - if func is None: - raise AttributeError(f"Module has no function '{name}'") - return func - - def import_module(self, module): - """Add module to the import list of current one. - - Parameters - ---------- - module : tvm.runtime.Module - The other module. - """ - _ffi_api.ModuleImportModule(self, module) - - def __getitem__(self, name): - if not isinstance(name, str): - raise ValueError("Can only take string as function name") - return self.get_function(name) - - def __call__(self, *args): - # pylint: disable=not-callable - return self.main(*args) - - def inspect_source(self, fmt=""): - """Get source code from module, if available. - - Parameters - ---------- - fmt : str, optional - The specified format. - - Returns - ------- - source : str - The result source code. - """ - return _ffi_api.ModuleInspectSource(self, fmt) - - def get_write_formats(self): - """Get the format of the module.""" - return _ffi_api.ModuleGetWriteFormats(self) - - def get_property_mask(self): - """Get the runtime module property mask. The mapping is stated in ModulePropertyMask. - - Returns - ------- - mask : int - Bitmask of runtime module property - """ - return _ffi_api.ModuleGetPropertyMask(self) - - def is_binary_serializable(self): - """Module 'binary serializable', save_to_bytes is supported. - - Returns - ------- - b : Bool - True if the module is binary serializable. - """ - return (self.get_property_mask() & ModulePropertyMask.BINARY_SERIALIZABLE) != 0 - - def is_runnable(self): - """Module 'runnable', get_function is supported. - - Returns - ------- - b : Bool - True if the module is runnable. - """ - return (self.get_property_mask() & ModulePropertyMask.RUNNABLE) != 0 - - def is_compilation_exportable(self): - """Module 'compilation exportable', write_to_file is supported for object or source. - - Returns - ------- - b : Bool - True if the module is compilation exportable. - """ - return (self.get_property_mask() & ModulePropertyMask.COMPILATION_EXPORTABLE) != 0 - - def clear_imports(self): - """Remove all imports of the module.""" - _ffi_api.ModuleClearImports(self) - - def write_to_file(self, file_name, fmt=""): - """Write the current module to file. - - Parameters - ---------- - file_name : str - The name of the file. - fmt : str - The format of the file. - - See Also - -------- - runtime.Module.export_library : export the module to shared library. - """ - _ffi_api.ModuleWriteToFile(self, file_name, fmt) - - -def system_lib(symbol_prefix=""): - """Get system-wide library module singleton. - - System lib is a global module that contains self register functions in startup. - Unlike normal dso modules which need to be loaded explicitly. - It is useful in environments where dynamic loading api like dlopen is banned. - - The system lib is intended to be linked and loaded during the entire life-cyle of the program. - If you want dynamic loading features, use dso modules instead. - - Parameters - ---------- - symbol_prefix: Optional[str] - Optional symbol prefix that can be used for search. When we lookup a symbol - symbol_prefix + name will first be searched, then the name without symbol_prefix. - - Returns - ------- - module : runtime.Module - The system-wide library module. - """ - return _ffi_api.SystemLib(symbol_prefix) - - -def load_module(path): - """Load module from file. - - Parameters - ---------- - path : str - The path to the module file. - - Returns - ------- - module : :py:class:`tvm_ffi.Module` - The loaded module - - Examples - -------- - .. code-block:: python - - mod = tvm_ffi.load_module("path/to/module.so") - mod.func_name(*args) - - See Also - -------- - :py:class:`tvm_ffi.Module` - """ - return _ffi_api.ModuleLoadFromFile(path) diff --git a/ffi/python/tvm_ffi/registry.py b/ffi/python/tvm_ffi/registry.py deleted file mode 100644 index b43e0dc6bb6b..000000000000 --- a/ffi/python/tvm_ffi/registry.py +++ /dev/null @@ -1,226 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""FFI registry to register function and objects.""" -import sys -from . import core - -# whether we simplify skip unknown objects regtistration -_SKIP_UNKNOWN_OBJECTS = False - - -def register_object(type_key=None): - """register object type. - - Parameters - ---------- - type_key : str or cls - The type key of the node - - Examples - -------- - The following code registers MyObject - using type key "test.MyObject" - - .. code-block:: python - - @tvm_ffi.register_object("test.MyObject") - class MyObject(Object): - pass - """ - object_name = type_key if isinstance(type_key, str) else type_key.__name__ - - def register(cls): - """internal register function""" - type_index = core._object_type_key_to_index(object_name) - if type_index is None: - if _SKIP_UNKNOWN_OBJECTS: - return cls - raise ValueError("Cannot find object type index for %s" % object_name) - core._add_class_attrs_by_reflection(type_index, cls) - core._register_object_by_index(type_index, cls) - return cls - - if isinstance(type_key, str): - return register - - return register(type_key) - - -def register_global_func(func_name, f=None, override=False): - """Register global function - - Parameters - ---------- - func_name : str or function - The function name - - f : function, optional - The function to be registered. - - override: boolean optional - Whether override existing entry. - - Returns - ------- - fregister : function - Register function if f is not specified. - - Examples - -------- - .. code-block:: python - - import tvm_ffi - - # we can use decorator to register a function - @tvm_ffi.register_global_func("mytest.echo") - def echo(x): - return x - # After registering, we can get the function by its name - f = tvm_ffi.get_global_func("mytest.echo") - assert f(1) == 1 - - # we can also directly register a function - tvm_ffi.register_global_func("mytest.add_one", lambda x: x + 1) - f = tvm_ffi.get_global_func("mytest.add_one") - assert f(1) == 2 - - See Also - -------- - :py:func:`tvm_ffi.get_global_func` - :py:func:`tvm_ffi.remove_global_func` - """ - if callable(func_name): - f = func_name - func_name = f.__name__ - - if not isinstance(func_name, str): - raise ValueError("expect string function name") - - def register(myf): - """internal register function""" - return core._register_global_func(func_name, myf, override) - - if f: - return register(f) - return register - - -def get_global_func(name, allow_missing=False): - """Get a global function by name - - Parameters - ---------- - name : str - The name of the global function - - allow_missing : bool - Whether allow missing function or raise an error. - - Returns - ------- - func : Function - The function to be returned, None if function is missing. - - See Also - -------- - :py:func:`tvm_ffi.register_global_func` - """ - return core._get_global_func(name, allow_missing) - - -def list_global_func_names(): - """Get list of global functions registered. - - Returns - ------- - names : list - List of global functions names. - """ - name_functor = get_global_func("ffi.FunctionListGlobalNamesFunctor")() - num_names = name_functor(-1) - return [name_functor(i) for i in range(num_names)] - - -def remove_global_func(name): - """Remove a global function by name - - Parameters - ---------- - name : str - The name of the global function - """ - get_global_func("ffi.FunctionRemoveGlobal")(name) - - -def init_ffi_api(namespace, target_module_name=None): - """Initialize register ffi api functions into a given module - - Parameters - ---------- - namespace : str - The namespace of the source registry - - target_module_name : str - The target module name if different from namespace - - Examples - -------- - - A typical usage pattern is to create a _ffi_api.py file to register - the functions under a given module. The following - code populates all registered global functions - prefixed with ``mypackage.`` into the current module, - then we can call the function through ``_ffi_api.func_name(*args)`` - which will call into the registered global function "mypackage.func_name". - - .. code-block:: python - - # _ffi_api.py - import tvm_ffi - - tvm_ffi.init_ffi_api("mypackage", __name__) - """ - target_module_name = target_module_name if target_module_name else namespace - - if namespace.startswith("tvm."): - prefix = namespace[4:] - else: - prefix = namespace - - target_module = sys.modules[target_module_name] - - for name in list_global_func_names(): - if not name.startswith(prefix): - continue - - fname = name[len(prefix) + 1 :] - if fname.find(".") != -1: - continue - - f = get_global_func(name) - f.__name__ = fname - setattr(target_module, f.__name__, f) - - -__all__ = [ - "register_object", - "register_global_func", - "get_global_func", - "list_global_func_names", - "remove_global_func", - "init_ffi_api", -] diff --git a/ffi/python/tvm_ffi/serialization.py b/ffi/python/tvm_ffi/serialization.py deleted file mode 100644 index 25d9bcefb828..000000000000 --- a/ffi/python/tvm_ffi/serialization.py +++ /dev/null @@ -1,67 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""Serialization related utilities to enable some object can be pickled""" - -from typing import Optional, Any -from . import _ffi_api - - -def to_json_graph_str(obj: Any, metadata: Optional[dict] = None): - """ - Dump an object to a JSON graph string. - - The JSON graph string is a string representation of of the object - graph includes the reference information of same objects, which can - be used for serialization and debugging. - - Parameters - ---------- - obj : Any - The object to save. - - metadata : Optional[dict], optional - Extra metadata to save into the json graph string. - - Returns - ------- - json_str : str - The JSON graph string. - """ - return _ffi_api.ToJSONGraphString(obj, metadata) - - -def from_json_graph_str(json_str: str): - """ - Load an object from a JSON graph string. - - The JSON graph string is a string representation of of the object - graph that also includes the reference information. - - Parameters - ---------- - json_str : str - The JSON graph string to load. - - Returns - ------- - obj : Any - The loaded object. - """ - return _ffi_api.FromJSONGraphString(json_str) - - -__all__ = ["from_json_graph_str", "to_json_graph_str"] diff --git a/ffi/python/tvm_ffi/testing.py b/ffi/python/tvm_ffi/testing.py deleted file mode 100644 index 843a10c896a8..000000000000 --- a/ffi/python/tvm_ffi/testing.py +++ /dev/null @@ -1,63 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""Testing utilities.""" - -from . import _ffi_api -from .core import Object -from .registry import register_object - - -@register_object("testing.TestObjectBase") -class TestObjectBase(Object): - """ - Test object base class. - """ - - -@register_object("testing.TestObjectDerived") -class TestObjectDerived(TestObjectBase): - """ - Test object derived class. - """ - - -def create_object(type_key: str, **kwargs) -> Object: - """ - Make an object by reflection. - - Parameters - ---------- - type_key : str - The type key of the object. - kwargs : dict - The keyword arguments to the object. - - Returns - ------- - obj : object - The created object. - - Note - ---- - This function is only used for testing purposes and should - not be used in other cases. - """ - args = [type_key] - for k, v in kwargs.items(): - args.append(k) - args.append(v) - return _ffi_api.MakeObjectFromPackedArgs(*args) diff --git a/ffi/python/tvm_ffi/utils/__init__.py b/ffi/python/tvm_ffi/utils/__init__.py deleted file mode 100644 index 543bd0f84100..000000000000 --- a/ffi/python/tvm_ffi/utils/__init__.py +++ /dev/null @@ -1,18 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -from .lockfile import FileLock diff --git a/ffi/python/tvm_ffi/utils/lockfile.py b/ffi/python/tvm_ffi/utils/lockfile.py deleted file mode 100644 index 3b3197e2d8e0..000000000000 --- a/ffi/python/tvm_ffi/utils/lockfile.py +++ /dev/null @@ -1,113 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -import os -import sys -import time - -# Platform-specific imports for file locking -if sys.platform == "win32": - import msvcrt -else: - import fcntl - - -class FileLock: - """ - A cross-platform file locking mechanism using Python's standard library. - This class implements an advisory lock, which must be respected by all - cooperating processes. - """ - - def __init__(self, lock_file_path): - self.lock_file_path = lock_file_path - self._file_descriptor = None - - def __enter__(self): - """ - Context manager protocol: acquire the lock upon entering the 'with' block. - This method will block indefinitely until the lock is acquired. - """ - self.blocking_acquire() - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - """ - Context manager protocol: release the lock upon exiting the 'with' block. - """ - self.release() - return False # Propagate exceptions, if any - - def acquire(self): - """ - Acquires an exclusive, non-blocking lock on the file. - Returns True if the lock was acquired, False otherwise. - """ - try: - if sys.platform == "win32": - self._file_descriptor = os.open( - self.lock_file_path, os.O_RDWR | os.O_CREAT | os.O_BINARY - ) - msvcrt.locking(self._file_descriptor, msvcrt.LK_NBLCK, 1) - else: # Unix-like systems - self._file_descriptor = os.open(self.lock_file_path, os.O_WRONLY | os.O_CREAT) - fcntl.flock(self._file_descriptor, fcntl.LOCK_EX | fcntl.LOCK_NB) - return True - except (IOError, BlockingIOError): - if self._file_descriptor is not None: - os.close(self._file_descriptor) - self._file_descriptor = None - return False - except Exception as e: - if self._file_descriptor is not None: - os.close(self._file_descriptor) - self._file_descriptor = None - raise RuntimeError(f"An unexpected error occurred: {e}") - - def blocking_acquire(self, timeout=None, poll_interval=0.1): - """ - Waits until an exclusive lock can be acquired, with an optional timeout. - - Args: - timeout (float): The maximum time to wait for the lock in seconds. - A value of None means wait indefinitely. - poll_interval (float): The time to wait between lock attempts in seconds. - """ - start_time = time.time() - while True: - if self.acquire(): - return True - - # Check for timeout - if timeout is not None and (time.time() - start_time) > timeout: - raise TimeoutError( - f"Failed to acquire lock on '{self.lock_file_path}' after {timeout} seconds." - ) - - time.sleep(poll_interval) - - def release(self): - """ - Releases the lock and closes the file descriptor. - """ - if self._file_descriptor is not None: - if sys.platform == "win32": - msvcrt.locking(self._file_descriptor, msvcrt.LK_UNLCK, 1) - else: - fcntl.flock(self._file_descriptor, fcntl.LOCK_UN) - os.close(self._file_descriptor) - self._file_descriptor = None diff --git a/ffi/scripts/benchmark_dlpack.py b/ffi/scripts/benchmark_dlpack.py deleted file mode 100644 index 2ab85bf03559..000000000000 --- a/ffi/scripts/benchmark_dlpack.py +++ /dev/null @@ -1,448 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -""" -This script is used to benchmark the API overhead of different -python FFI API calling overhead, through DLPack API. - -Specifically, we would like to understand the overall overhead -python/C++ API calls. The general goal is to understand the overall -space and get a sense of what are the possible operations. - -We pick function f(x, y, z) where x, y, z are length 1 tensors. -The benchmark is running in eager mode so we can see what is possible. -It is orthogonal to other optimizations. For example cudagraph can -eliminate these overheads completely. So the goal is to get a sense -of what is possible under eager mode. - -Summary of some takeaways: -- numpy.add roughly takes 0.36 us per call, which gives roughly what can - be done in python env. -- torch.add on gpu takes about 3.7us per call, giving us an idea of what - roughly we need to get to in eager mode. -- - -""" -import os -import torch -import numpy as np -import tvm_ffi -import time - - -def print_speed(name, speed): - print(f"{name:<60} {speed} sec/call") - - -def print_error(name, error): - print(f"{name:<60} {error}") - - -def baseline_torch_add(repeat): - """Run torch.add with one element""" - - def run_bench(device): - x = torch.arange(1, device=device) - y = torch.arange(1, device=device) - z = torch.arange(1, device=device) - - torch.add(x, y, out=z) - if device == "cuda": - torch.cuda.synchronize() - start = time.time() - for i in range(repeat): - torch.add(x, y, out=z) - # note we deliberately do not use torch.cuda.synchronize() - # because we want to see the overhead of the FFI call. - end = time.time() - print_speed(f"torch.add[{device}]", (end - start) / repeat) - - # rough take away: add on cuda roughly takes 3e-6 sec/call - run_bench("cpu") - run_bench("cuda") - - -def baseline_numpy_add(repeat): - """Run numpy.add with one element""" - x = np.arange(1) - y = np.arange(1) - z = np.arange(1) - - np.add(x, y, out=z) - start = time.time() - for i in range(repeat): - np.add(x, y, out=z) - end = time.time() - speed = (end - start) / repeat - print_speed("numpy.add", speed) - - -def baseline_cupy_add(repeat): - """Run cupy.add with one element""" - try: - import cupy - except ImportError: - # skip if cupy is not installed - return - x = cupy.arange(1) - y = cupy.arange(1) - z = cupy.arange(1) - - cupy.add(x, y, out=z) - start = time.time() - for i in range(repeat): - cupy.add(x, y, out=z) - end = time.time() - speed = (end - start) / repeat - print_speed("cupy.add", speed) - - -def tvm_ffi_nop(repeat): - """Overhead of tvm FFI python call via calling a NOP. - - testing.nop is defined in c++ and do nothing. - """ - nop = tvm_ffi.get_global_func("testing.nop") - x = tvm_ffi.from_dlpack(torch.arange(1)) - y = tvm_ffi.from_dlpack(torch.arange(1)) - z = tvm_ffi.from_dlpack(torch.arange(1)) - nop(x, y, z) - start = time.time() - for i in range(repeat): - nop(x, y, z) - end = time.time() - print_speed("tvm_ffi.nop", (end - start) / repeat) - - -def bench_ffi_nop_from_dlpack(name, x, y, z, repeat): - """run dlpack conversion + tvm_ffi.nop - - Measures overhead of running dlpack for each args then invoke - """ - nop = tvm_ffi.get_global_func("testing.nop") - tx = tvm_ffi.from_dlpack(x) - ty = tvm_ffi.from_dlpack(y) - tz = tvm_ffi.from_dlpack(z) - nop(tx, ty, tz) - - start = time.time() - for i in range(repeat): - tx = tvm_ffi.from_dlpack(x) - ty = tvm_ffi.from_dlpack(y) - tz = tvm_ffi.from_dlpack(z) - nop(tx, ty, tz) - end = time.time() - print_speed(name, (end - start) / repeat) - - -def tvm_ffi_nop_from_torch_dlpack(repeat): - """run dlpack conversion + tvm_ffi.nop - - Measures overhead of running dlpack for each args then invoke - """ - x = torch.arange(1) - y = torch.arange(1) - z = torch.arange(1) - bench_ffi_nop_from_dlpack("tvm_ffi.nop+from_dlpack(torch)", x, y, z, repeat) - - -def tvm_ffi_nop_from_numpy_dlpack(repeat): - """run dlpack conversion + tvm_ffi.nop - - Measures overhead of running dlpack for each args then invoke - """ - x = np.arange(1) - y = np.arange(1) - z = np.arange(1) - bench_ffi_nop_from_dlpack("tvm_ffi.nop+from_dlpack(numpy)", x, y, z, repeat) - - -def tvm_ffi_self_dlpack_nop(repeat): - """run dlpack conversion + tvm_ffi.nop - - Measures overhead of running dlpack for each args then invoke - """ - x = tvm_ffi.from_dlpack(torch.arange(1)) - y = tvm_ffi.from_dlpack(torch.arange(1)) - z = tvm_ffi.from_dlpack(torch.arange(1)) - bench_ffi_nop_from_dlpack("tvm_ffi.nop+from_dlpack(tvm)", x, y, z, repeat) - - -def bench_ffi_nop_from_dlpack(name, x, y, z, repeat): - """run dlpack conversion + tvm_ffi.nop - - Measures overhead of running dlpack for each args then invoke - """ - nop = tvm_ffi.get_global_func("testing.nop") - tx = tvm_ffi.from_dlpack(x) - ty = tvm_ffi.from_dlpack(y) - tz = tvm_ffi.from_dlpack(z) - nop(tx, ty, tz) - - start = time.time() - for i in range(repeat): - tx = tvm_ffi.from_dlpack(x) - ty = tvm_ffi.from_dlpack(y) - tz = tvm_ffi.from_dlpack(z) - nop(tx, ty, tz) - end = time.time() - print_speed(name, (end - start) / repeat) - - -def tvm_ffi_nop_from_torch_utils_to_dlpack(repeat): - """ - Measures overhead of running dlpack for each args then invoke - but uses the legacy torch.utils.dlpack.to_dlpack API - - This helps to measure possible implementation overhead of torch. - """ - nop = tvm_ffi.get_global_func("testing.nop") - x = torch.arange(1) - y = torch.arange(1) - z = torch.arange(1) - - tx = tvm_ffi.from_dlpack(torch.utils.dlpack.to_dlpack(x)) - ty = tvm_ffi.from_dlpack(torch.utils.dlpack.to_dlpack(y)) - tz = tvm_ffi.from_dlpack(torch.utils.dlpack.to_dlpack(z)) - nop(tx, ty, tz) - - start = time.time() - for i in range(repeat): - tx = tvm_ffi.from_dlpack(torch.utils.dlpack.to_dlpack(x)) - ty = tvm_ffi.from_dlpack(torch.utils.dlpack.to_dlpack(y)) - tz = tvm_ffi.from_dlpack(torch.utils.dlpack.to_dlpack(z)) - nop(tx, ty, tz) - end = time.time() - speed = (end - start) / repeat - print_speed("tvm_ffi.nop+from_dlpack(torch.utils)", speed) - - -def bench_tvm_ffi_nop_autodlpack(name, x, y, z, repeat): - """ - Measures overhead of running dlpack via auto convert by directly - take torch.Tensor as inputs. - """ - nop = tvm_ffi.get_global_func("testing.nop") - nop(x, y, z) - eps = 1e-6 - start = time.time() - for i in range(repeat): - nop(x, y, z) - end = time.time() - speed = (end - start) / repeat - print_speed(name, speed) - - -def tvm_ffi_nop_autodlpack_from_torch(repeat, device="cpu", stream=False): - """ - Measures overhead of running dlpack via auto convert by directly - take torch.Tensor as inputs. - """ - # use larger to ensure alignment req is met - x = torch.arange(1, device=device) - y = torch.arange(1, device=device) - z = torch.arange(1, device=device) - if stream: - with torch.cuda.stream(torch.cuda.Stream()): - bench_tvm_ffi_nop_autodlpack( - f"tvm_ffi.nop.autodlpack(torch[{device}][stream])", x, y, z, repeat - ) - else: - bench_tvm_ffi_nop_autodlpack(f"tvm_ffi.nop.autodlpack(torch[{device}])", x, y, z, repeat) - - -def tvm_ffi_nop_autodlpack_from_numpy(repeat): - """ - Measures overhead of running dlpack via auto convert by directly - take numpy.ndarray as inputs. - """ - # use larger to ensure alignment req is met - x = np.arange(256) - y = np.arange(256) - z = np.arange(256) - bench_tvm_ffi_nop_autodlpack("tvm_ffi.nop.autodlpack(numpy)", x, y, z, repeat) - - -def tvm_ffi_nop_autodlpack_from_dltensor_test_wrapper(repeat, device): - """ - Measures overhead of running dlpack via auto convert by directly - take test wrapper as inputs. This effectively measure DLPack exchange in tvm ffi. - """ - x = tvm_ffi.from_dlpack(torch.arange(1, device=device)) - y = tvm_ffi.from_dlpack(torch.arange(1, device=device)) - z = tvm_ffi.from_dlpack(torch.arange(1, device=device)) - x = tvm_ffi.core.DLTensorTestWrapper(x) - y = tvm_ffi.core.DLTensorTestWrapper(y) - z = tvm_ffi.core.DLTensorTestWrapper(z) - bench_tvm_ffi_nop_autodlpack( - f"tvm_ffi.nop.autodlpack(DLTensorTestWrapper[{device}])", x, y, z, repeat - ) - - -def bench_to_dlpack(x, name, repeat): - x.__dlpack__() - start = time.time() - for i in range(repeat): - x.__dlpack__() - end = time.time() - speed = (end - start) / repeat - print_speed(name, speed) - - -def bench_to_dlpack_versioned(x, name, repeat, max_version=(1, 1)): - """ - Measures overhead of running dlpack with latest 1.1. - """ - try: - x.__dlpack__(max_version=max_version) - start = time.time() - for i in range(repeat): - x.__dlpack__(max_version=max_version) - end = time.time() - speed = (end - start) / repeat - print_speed(name, speed) - except Exception as e: - print_error(name, e) - - -def bench_torch_utils_to_dlpack(repeat): - """ - Measures overhead of running torch.utils.dlpack.to_dlpack - """ - x = torch.arange(1) - torch.utils.dlpack.to_dlpack(x) - start = time.time() - for i in range(repeat): - torch.utils.dlpack.to_dlpack(x) - end = time.time() - speed = (end - start) / repeat - print_speed("torch.utils.dlpack.to_dlpack", speed) - - -def torch_get_cuda_stream_native(device_id): - return torch.cuda.current_stream(device_id).cuda_stream - - -def load_torch_get_current_cuda_stream(): - """Create a faster get_current_cuda_stream for torch through cpp extension.""" - from torch.utils import cpp_extension - - source = """ - #include - - int64_t get_current_cuda_stream(int device_id) { - at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream(device_id); - // fast invariant, default stream is always 0 - if (stream.id() == 0) return 0; - // convert to cudaStream_t - return reinterpret_cast(static_cast(stream)); - } - """ - result = cpp_extension.load_inline( - name="get_current_cuda_stream", - cpp_sources=[source], - cuda_sources=[], - extra_cflags=["-O3"], - extra_include_paths=cpp_extension.include_paths("cuda"), - functions=["get_current_cuda_stream"], - ) - return result.get_current_cuda_stream - - -def bench_torch_get_current_stream(repeat, name, func): - """ - Measures overhead of running torch.cuda.current_stream - """ - x = torch.arange(1, device="cuda") - func(0) - start = time.time() - for i in range(repeat): - func(0) - end = time.time() - speed = (end - start) / repeat - print_speed(f"torch.cuda.current_stream[{name}]", speed) - - -def populate_object_table(num_classes): - nop = tvm_ffi.get_global_func("testing.nop") - dummy_instances = [type(f"DummyClass{i}", (object,), {})() for i in range(num_classes)] - for instance in dummy_instances: - nop(instance) - - -def main(): - repeat = 10000 - # measures impact of object dispatch table size - # takeaway so far is that there is no impact on the performance - num_classes = 0 - populate_object_table(num_classes) - print("-----------------------------") - print("Benchmark f(x, y, z) overhead") - print("-----------------------------") - baseline_numpy_add(repeat) - baseline_torch_add(repeat) - baseline_cupy_add(repeat) - tvm_ffi_nop_from_torch_dlpack(repeat) - tvm_ffi_nop_from_numpy_dlpack(repeat) - tvm_ffi_self_dlpack_nop(repeat) - tvm_ffi_nop_from_torch_utils_to_dlpack(repeat) - tvm_ffi_nop_autodlpack_from_torch(repeat, "cpu") - tvm_ffi_nop_autodlpack_from_torch(repeat, "cuda") - tvm_ffi_nop_autodlpack_from_torch(repeat, "cuda", stream=True) - - tvm_ffi_nop_autodlpack_from_numpy(repeat) - tvm_ffi_nop_autodlpack_from_dltensor_test_wrapper(repeat, "cpu") - tvm_ffi_nop_autodlpack_from_dltensor_test_wrapper(repeat, "cuda") - tvm_ffi_nop(repeat) - print("-------------------------------") - print("Benchmark x.__dlpack__ overhead") - print("-------------------------------") - bench_torch_utils_to_dlpack(repeat) - bench_to_dlpack(torch.arange(1), "torch.__dlpack__", repeat) - bench_to_dlpack(np.arange(1), "numpy.__dlpack__", repeat) - bench_to_dlpack(tvm_ffi.from_dlpack(torch.arange(1)), "tvm.__dlpack__", repeat) - print("---------------------------------------------------") - print("Benchmark x.__dlpack__(max_version=(1,1)) overhead") - print("---------------------------------------------------") - bench_to_dlpack_versioned(torch.arange(1), "torch.__dlpack__(max_version=(1,1))", repeat) - bench_to_dlpack_versioned(np.arange(1), "numpy.__dlpack__(max_version=(1,1))", repeat) - bench_to_dlpack_versioned( - tvm_ffi.from_dlpack(torch.arange(1)), "tvm.__dlpack__(max_version=(1,1))", repeat - ) - print("---------------------------------------------------") - print("Benchmark torch.get_cuda_stream[default stream]") - print("---------------------------------------------------") - bench_torch_get_current_stream(repeat, "cpp-extension", load_torch_get_current_cuda_stream()) - bench_torch_get_current_stream(repeat, "python", torch_get_cuda_stream_native) - print("---------------------------------------------------") - print("Benchmark torch.get_cuda_stream[non-default stream]") - print("---------------------------------------------------") - with torch.cuda.stream(torch.cuda.Stream()): - bench_torch_get_current_stream( - repeat, "cpp-extension", load_torch_get_current_cuda_stream() - ) - bench_torch_get_current_stream(repeat, "python", torch_get_cuda_stream_native) - print("---------------------------------------------------") - print("Debug information") - print("---------------------------------------------------") - tvm_ffi.core._print_debug_info() - release_gil = tvm_ffi.get_global_func("testing.nop").release_gil - print(f"TVM_FFI_RELEASE_GIL_BY_DEFAULT={int(release_gil)}") - print("---------------------------------------------------") - - -if __name__ == "__main__": - main() diff --git a/ffi/scripts/run_tests.sh b/ffi/scripts/run_tests.sh deleted file mode 100755 index 27795cc74512..000000000000 --- a/ffi/scripts/run_tests.sh +++ /dev/null @@ -1,27 +0,0 @@ -#!/bin/bash -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -set -euxo pipefail - -BUILD_TYPE=RelWithDebugInfo - -rm -rf build/CMakeCache.txt - -cmake -G Ninja -S . -B build -DTVM_FFI_BUILD_TESTS=ON -DCMAKE_BUILD_TYPE=${BUILD_TYPE} \ - -DCMAKE_EXPORT_COMPILE_COMMANDS=ON -DCMAKE_CXX_COMPILER_LAUNCHER=ccache -cmake --build build --clean-first --config ${BUILD_TYPE} --target tvm_ffi_tests -GTEST_COLOR=1 ctest -V -C ${BUILD_TYPE} --test-dir build --output-on-failure diff --git a/ffi/src/ffi/container.cc b/ffi/src/ffi/container.cc deleted file mode 100644 index 5cf692ac2a18..000000000000 --- a/ffi/src/ffi/container.cc +++ /dev/null @@ -1,88 +0,0 @@ - -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -/* - * \file src/ffi/container.cc - */ -#include -#include -#include -#include -#include - -namespace tvm { -namespace ffi { - -// Favor struct outside function scope as MSVC may have bug for in fn scope struct. -class MapForwardIterFunctor { - public: - MapForwardIterFunctor(ffi::MapObj::iterator iter, ffi::MapObj::iterator end) - : iter_(iter), end_(end) {} - // 0 get current key - // 1 get current value - // 2 move to next: return true if success, false if end - Any operator()(int command) const { - if (command == 0) { - return (*iter_).first; - } else if (command == 1) { - return (*iter_).second; - } else { - ++iter_; - if (iter_ == end_) { - return false; - } - return true; - } - } - - private: - mutable ffi::MapObj::iterator iter_; - ffi::MapObj::iterator end_; -}; - -TVM_FFI_STATIC_INIT_BLOCK() { - namespace refl = tvm::ffi::reflection; - refl::GlobalDef() - .def_packed("ffi.Array", - [](ffi::PackedArgs args, Any* ret) { - *ret = Array(args.data(), args.data() + args.size()); - }) - .def("ffi.ArrayGetItem", [](const ffi::ArrayObj* n, int64_t i) -> Any { return n->at(i); }) - .def("ffi.ArraySize", - [](const ffi::ArrayObj* n) -> int64_t { return static_cast(n->size()); }) - .def_packed("ffi.Map", - [](ffi::PackedArgs args, Any* ret) { - TVM_FFI_ICHECK_EQ(args.size() % 2, 0); - Map data; - for (int i = 0; i < args.size(); i += 2) { - data.Set(args[i], args[i + 1]); - } - *ret = data; - }) - .def("ffi.MapSize", - [](const ffi::MapObj* n) -> int64_t { return static_cast(n->size()); }) - .def("ffi.MapGetItem", [](const ffi::MapObj* n, const Any& k) -> Any { return n->at(k); }) - .def("ffi.MapCount", - [](const ffi::MapObj* n, const Any& k) -> int64_t { return n->count(k); }) - .def("ffi.MapForwardIterFunctor", [](const ffi::MapObj* n) -> ffi::Function { - return ffi::Function::FromTyped(MapForwardIterFunctor(n->begin(), n->end())); - }); -} -} // namespace ffi -} // namespace tvm diff --git a/ffi/src/ffi/dtype.cc b/ffi/src/ffi/dtype.cc deleted file mode 100644 index e119f7733044..000000000000 --- a/ffi/src/ffi/dtype.cc +++ /dev/null @@ -1,328 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -#include -#include - -#include - -namespace tvm { -namespace ffi { -namespace details { -/*! - * \brief Get the custom type name for a given type code. - */ -inline String DLDataTypeCodeGetCustomTypeName(DLDataTypeCode type_code) { - static Function fget_custom_type_name = Function::GetGlobalRequired("dtype.get_custom_type_name"); - return fget_custom_type_name(static_cast(type_code)).cast(); -} - -/*! - * \brief Get the custom type name for a given type code. - * \param str The string to parse. - * \param scan The scan pointer. - * \return The custom type name. - */ -inline int ParseCustomDataTypeCode(const std::string_view& str, const char** scan) { - TVM_FFI_ICHECK(str.substr(0, 6) == "custom") << "Not a valid custom datatype string"; - auto tmp = str.data(); - TVM_FFI_ICHECK(str.data() == tmp); - *scan = str.data() + 6; - TVM_FFI_ICHECK(str.data() == tmp); - if (**scan != '[') - TVM_FFI_THROW(ValueError) << "expected opening brace after 'custom' type in" << str; - TVM_FFI_ICHECK(str.data() == tmp); - *scan += 1; - TVM_FFI_ICHECK(str.data() == tmp); - size_t custom_name_len = 0; - TVM_FFI_ICHECK(str.data() == tmp); - while (*scan + custom_name_len <= str.data() + str.length() && - *(*scan + custom_name_len) != ']') { - ++custom_name_len; - } - TVM_FFI_ICHECK(str.data() == tmp); - if (*(*scan + custom_name_len) != ']') { - TVM_FFI_THROW(ValueError) << "expected closing brace after 'custom' type in" << str; - } - TVM_FFI_ICHECK(str.data() == tmp); - *scan += custom_name_len + 1; - TVM_FFI_ICHECK(str.data() == tmp); - auto type_name = str.substr(7, custom_name_len); - TVM_FFI_ICHECK(str.data() == tmp); - static Function fget_custom_type_code = Function::GetGlobalRequired("dtype.get_custom_type_code"); - return fget_custom_type_code(std::string(type_name)).cast(); -} - -/* - * \brief Convert a DLDataTypeCode to a string. - * \param os The output stream. - * \param type_code The DLDataTypeCode to convert. - */ -inline void PrintDLDataTypeCodeAsStr(std::ostream& os, DLDataTypeCode type_code) { // NOLINT(*) - switch (static_cast(type_code)) { - case kDLInt: { - os << "int"; - break; - } - case kDLUInt: { - os << "uint"; - break; - } - case kDLFloat: { - os << "float"; - break; - } - case kDLOpaqueHandle: { - os << "handle"; - break; - } - case kDLBfloat: { - os << "bfloat"; - break; - } - case kDLFloat8_e3m4: { - os << "float8_e3m4"; - break; - } - case kDLFloat8_e4m3: { - os << "float8_e4m3"; - break; - } - case kDLFloat8_e4m3b11fnuz: { - os << "float8_e4m3b11fnuz"; - break; - } - case kDLFloat8_e4m3fn: { - os << "float8_e4m3fn"; - break; - } - case kDLFloat8_e4m3fnuz: { - os << "float8_e4m3fnuz"; - break; - } - case kDLFloat8_e5m2: { - os << "float8_e5m2"; - break; - } - case kDLFloat8_e5m2fnuz: { - os << "float8_e5m2fnuz"; - break; - } - case kDLFloat8_e8m0fnu: { - os << "float8_e8m0fnu"; - break; - } - case kDLFloat6_e2m3fn: { - os << "float6_e2m3fn"; - break; - } - case kDLFloat6_e3m2fn: { - os << "float6_e3m2fn"; - break; - } - case kDLFloat4_e2m1fn: { - os << "float4_e2m1fn"; - break; - } - default: { - if (static_cast(type_code) >= static_cast(DLExtDataTypeCode::kDLExtCustomBegin)) { - os << "custom[" << details::DLDataTypeCodeGetCustomTypeName(type_code) << "]"; - } else { - TVM_FFI_THROW(ValueError) << "DLDataType contains unknown type_code=" - << static_cast(type_code); - } - TVM_FFI_UNREACHABLE(); - } - } -} -} // namespace details - -/*! - * \brief Printer function for DLDataType. - * \param os The output stream. - * \param dtype The DLDataType to print. - * \return The output stream. - */ -inline std::string DLDataTypeToString_(DLDataType dtype) { // NOLINT(*) - if (dtype.bits == 1 && dtype.lanes == 1 && dtype.code == kDLUInt) { - return "bool"; - } - // specially handle void - if (dtype.code == kDLOpaqueHandle && dtype.lanes == 0 && dtype.bits == 0) { - return ""; - } - - std::ostringstream os; - if (dtype.code >= kDLExtCustomBegin) { - os << "custom[" - << details::DLDataTypeCodeGetCustomTypeName(static_cast(dtype.code)) << "]"; - } else { - os << details::DLDataTypeCodeAsCStr(static_cast(dtype.code)); - } - if (dtype.code == kDLOpaqueHandle) return os.str(); - int16_t lanes = static_cast(dtype.lanes); - if (dtype.code < kDLFloat8_e3m4) { - os << static_cast(dtype.bits); - } - if (lanes > 1) { - os << 'x' << lanes; - } else if (lanes < -1) { - os << "xvscalex" << -lanes; - } - return os.str(); -} - -/*! - * \brief Parse a string to a DLDataType. - * \param str The string to convert. - * \return The corresponding DLDataType. - */ -inline DLDataType StringViewToDLDataType_(std::string_view str) { - DLDataType dtype; - // handle void type - if (str.length() == 0 || str == "void") { - dtype.code = kDLOpaqueHandle; - dtype.bits = 0; - dtype.lanes = 0; - return dtype; - } - // set the default values; - dtype.bits = 32; - dtype.lanes = 1; - const char* scan; - - auto parse_float = [&](const std::string_view& str, int offset, int code, int bits) { - dtype.code = static_cast(code); - dtype.bits = static_cast(bits); - scan = str.data() + offset; - char* endpt = nullptr; - if (*scan == 'x') { - dtype.lanes = static_cast(strtoul(scan + 1, &endpt, 10)); - scan = endpt; - } - if (scan != str.data() + str.length()) { - TVM_FFI_THROW(ValueError) << "unknown dtype `" << str << '`'; - } - return dtype; - }; - - if (str.compare(0, 3, "int") == 0) { - dtype.code = kDLInt; - scan = str.data() + 3; - } else if (str.compare(0, 4, "uint") == 0) { - dtype.code = kDLUInt; - scan = str.data() + 4; - } else if (str.compare(0, 5, "float") == 0) { - if (str.compare(5, 2, "8_") == 0) { - if (str.compare(7, 4, "e3m4") == 0) { - return parse_float(str, 11, kDLFloat8_e3m4, 8); - } else if (str.compare(7, 4, "e4m3") == 0) { - if (str.compare(11, 7, "b11fnuz") == 0) { - return parse_float(str, 18, kDLFloat8_e4m3b11fnuz, 8); - } else if (str.compare(11, 2, "fn") == 0) { - if (str.compare(13, 2, "uz") == 0) { - return parse_float(str, 15, kDLFloat8_e4m3fnuz, 8); - } else { - return parse_float(str, 13, kDLFloat8_e4m3fn, 8); - } - } else { - return parse_float(str, 11, kDLFloat8_e4m3, 8); - } - } else if (str.compare(7, 8, "e5m2fnuz") == 0) { - return parse_float(str, 15, kDLFloat8_e5m2fnuz, 8); - } else if (str.compare(7, 4, "e5m2") == 0) { - return parse_float(str, 11, kDLFloat8_e5m2, 8); - } else if (str.compare(7, 7, "e8m0fnu") == 0) { - return parse_float(str, 14, kDLFloat8_e8m0fnu, 8); - } else { - TVM_FFI_THROW(ValueError) << "unknown float8 type `" << str << '`'; - TVM_FFI_UNREACHABLE(); - } - } else if (str.compare(5, 2, "6_") == 0) { - if (str.compare(7, 6, "e2m3fn") == 0) { - return parse_float(str, 13, kDLFloat6_e2m3fn, 6); - } else if (str.compare(7, 6, "e3m2fn") == 0) { - return parse_float(str, 13, kDLFloat6_e3m2fn, 6); - } else { - TVM_FFI_THROW(ValueError) << "unknown float6 type `" << str << '`'; - TVM_FFI_UNREACHABLE(); - } - } else if (str.compare(5, 2, "4_") == 0) { - // kFloat4_e2m1fn - if (str.compare(7, 6, "e2m1fn") == 0) { - return parse_float(str, 13, kDLFloat4_e2m1fn, 4); - } else { - TVM_FFI_THROW(ValueError) << "unknown float4 type `" << str << '`'; - TVM_FFI_UNREACHABLE(); - } - } else { - dtype.code = kDLFloat; - scan = str.data() + 5; - } - } else if (str.compare(0, 6, "handle") == 0) { - dtype.code = kDLOpaqueHandle; - dtype.bits = 64; // handle uses 64 bit by default. - scan = str.data() + 6; - } else if (str == "bool") { - dtype.code = kDLUInt; - dtype.bits = 1; - dtype.lanes = 1; - return dtype; - } else if (str.compare(0, 6, "bfloat") == 0) { - dtype.code = kDLBfloat; - dtype.bits = 16; - scan = str.data() + 6; - } else if (str.compare(0, 6, "custom") == 0) { - dtype.code = static_cast(details::ParseCustomDataTypeCode(str, &scan)); - } else { - scan = str.data(); - TVM_FFI_THROW(ValueError) << "unknown dtype `" << str << '`'; - } - char* xdelim; // emulate sscanf("%ux%u", bits, lanes) - uint8_t bits = static_cast(strtoul(scan, &xdelim, 10)); - if (bits != 0) dtype.bits = bits; - int scalable_multiplier = 1; - if (strncmp(xdelim, "xvscale", 7) == 0) { - scalable_multiplier = -1; - xdelim += 7; - } - char* endpt = xdelim; - if (*xdelim == 'x') { - dtype.lanes = static_cast(scalable_multiplier * strtoul(xdelim + 1, &endpt, 10)); - } - if (endpt != str.data() + str.length()) { - TVM_FFI_THROW(ValueError) << "unknown dtype `" << str << '`'; - } - return dtype; -} - -} // namespace ffi -} // namespace tvm - -int TVMFFIDataTypeFromString(const TVMFFIByteArray* str, DLDataType* out) { - TVM_FFI_SAFE_CALL_BEGIN(); - *out = tvm::ffi::StringViewToDLDataType_(std::string_view(str->data, str->size)); - TVM_FFI_SAFE_CALL_END(); -} - -int TVMFFIDataTypeToString(const DLDataType* dtype, TVMFFIAny* out) { - TVM_FFI_SAFE_CALL_BEGIN(); - tvm::ffi::String out_str(tvm::ffi::DLDataTypeToString_(*dtype)); - tvm::ffi::TypeTraits::MoveToAny(std::move(out_str), out); - TVM_FFI_SAFE_CALL_END(); -} diff --git a/ffi/src/ffi/error.cc b/ffi/src/ffi/error.cc deleted file mode 100644 index ba8dbbfb5828..000000000000 --- a/ffi/src/ffi/error.cc +++ /dev/null @@ -1,81 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -/* - * \file src/ffi/error.cc - * \brief Error handling implementation - */ -#include -#include - -namespace tvm { -namespace ffi { - -class SafeCallContext { - public: - void SetRaised(TVMFFIObjectHandle error) { - last_error_ = - details::ObjectUnsafe::ObjectPtrFromUnowned(static_cast(error)); - } - - void SetRaisedByCstr(const char* kind, const char* message, const TVMFFIByteArray* traceback) { - Error error(kind, message, traceback); - last_error_ = details::ObjectUnsafe::ObjectPtrFromObjectRef(std::move(error)); - } - - void MoveFromRaised(TVMFFIObjectHandle* result) { - result[0] = details::ObjectUnsafe::MoveObjectPtrToTVMFFIObjectPtr(std::move(last_error_)); - } - - static SafeCallContext* ThreadLocal() { - static thread_local SafeCallContext ctx; - return &ctx; - } - - private: - ObjectPtr last_error_; -}; - -} // namespace ffi -} // namespace tvm - -void TVMFFIErrorSetRaisedFromCStr(const char* kind, const char* message) { - // NOTE: run traceback here to simplify the depth of tracekback - tvm::ffi::SafeCallContext::ThreadLocal()->SetRaisedByCstr( - kind, message, TVMFFITraceback(nullptr, 0, nullptr, 0)); -} - -void TVMFFIErrorSetRaised(TVMFFIObjectHandle error) { - tvm::ffi::SafeCallContext::ThreadLocal()->SetRaised(error); -} - -void TVMFFIErrorMoveFromRaised(TVMFFIObjectHandle* result) { - tvm::ffi::SafeCallContext::ThreadLocal()->MoveFromRaised(result); -} - -TVMFFIObjectHandle TVMFFIErrorCreate(const TVMFFIByteArray* kind, const TVMFFIByteArray* message, - const TVMFFIByteArray* traceback) { - TVM_FFI_LOG_EXCEPTION_CALL_BEGIN(); - tvm::ffi::Error error(std::string(kind->data, kind->size), - std::string(message->data, message->size), - std::string(traceback->data, traceback->size)); - TVMFFIObjectHandle out = - tvm::ffi::details::ObjectUnsafe::MoveObjectRefToTVMFFIObjectPtr(std::move(error)); - return out; - TVM_FFI_LOG_EXCEPTION_CALL_END(TVMFFIErrorCreate); -} diff --git a/ffi/src/ffi/extra/buffer_stream.h b/ffi/src/ffi/extra/buffer_stream.h deleted file mode 100644 index f6f162676607..000000000000 --- a/ffi/src/ffi/extra/buffer_stream.h +++ /dev/null @@ -1,127 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file buffer_stream.h - * \brief Internal minimal stream helper to read from a buffer. - */ -#ifndef TVM_FFI_EXTRA_BUFFER_STREAM_H_ -#define TVM_FFI_EXTRA_BUFFER_STREAM_H_ - -#include -#include -#include -#include -#include - -namespace tvm { -namespace ffi { - -/*! - * \brief Lightweight stream helper to read from a buffer. - */ -class BufferInStream { - public: - /*! - * \brief constructor - * \param p_buffer the head pointer of the memory region. - * \param buffer_size the size of the memorybuffer - */ - BufferInStream(const void* data, size_t size) - : data_(reinterpret_cast(data)), size_(size) {} - /*! - * \brief Reads raw from stream. - * \param ptr pointer to the data to be read - * \param size the size of the data to be read - * \return the number of bytes read - */ - size_t Read(void* ptr, size_t size) { - size_t nread = std::min(size_ - curr_ptr_, size); - if (nread != 0) std::memcpy(ptr, data_ + curr_ptr_, nread); - curr_ptr_ += nread; - return nread; - } - /*! - * \brief Reads arithmetic data from stream in endian-aware manner. - * \param data data to be read - * \tparam T the data type to be read - * \return whether the read was successful - */ - template >> - bool Read(T* data) { - bool ret = Read(static_cast(data), sizeof(T)) == sizeof(T); // NOLINT(*) - if (!TVM_FFI_IO_NO_ENDIAN_SWAP) { - ByteSwap(&data, sizeof(T), 1); - } - return ret; - } - /*! - * \brief Reads an array of data from stream in endian-aware manner. - * \param data data to be read - * \param size the size of the data to be read - * \return whether the read was successful - */ - template >> - bool ReadArray(T* data, size_t size) { - bool ret = - this->Read(static_cast(data), sizeof(T) * size) == sizeof(T) * size; // NOLINT(*) - if (!TVM_FFI_IO_NO_ENDIAN_SWAP) { - ByteSwap(data, sizeof(T), size); - } - return ret; - } - /*! - * \brief Reads a string from stream. - * \param data data to be read - * \return whether the read was successful - */ - bool Read(std::string* data) { - // use uint64_t to ensure platform independent size - uint64_t size = 0; - if (!this->Read(&size)) return false; - data->resize(size); - if (!this->Read(data->data(), size)) return false; - return true; - } - /*! - * \brief Reads a vector of data from stream in endian-aware manner. - * \param data data to be read - * \return whether the read was successful - */ - template >> - bool Read(std::vector* data) { - uint64_t size = 0; - if (!this->Read(&size)) return false; - data->resize(size); - return this->ReadArray(data->data(), size); - } - - private: - /*! \brief in memory buffer */ - const char* data_; - /*! \brief size of the buffer */ - size_t size_; - /*! \brief current pointer */ - size_t curr_ptr_{0}; -}; // class BytesInStream - -} // namespace ffi -} // namespace tvm - -#endif // TVM_FFI_EXTRA_BUFFER_STREAM_H_ diff --git a/ffi/src/ffi/extra/env_c_api.cc b/ffi/src/ffi/extra/env_c_api.cc deleted file mode 100644 index 121cc9a3ccde..000000000000 --- a/ffi/src/ffi/extra/env_c_api.cc +++ /dev/null @@ -1,148 +0,0 @@ - -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -/* - * \file src/ffi/extra/env_c_api.cc - * \brief Environment C API implementation. - */ -#include -#include - -namespace tvm { -namespace ffi { -/*! - * \brief Execution environment specific API registry. - * - * This registry stores C API function pointers about - * execution environment(e.g. python) specific API function that - * we need for specific low-level handling(e.g. signal checking). - * - * We only stores the C API function when absolutely necessary (e.g. when signal handler - * cannot trap back into python). Always consider use the Function FFI when possible - * in other cases. - */ -class EnvCAPIRegistry { - public: - /*! - * \brief Callback to check if signals have been sent to the process and - * if so invoke the registered signal handler in the frontend environment. - * - * When running FFI in another language (Python), the signal handler - * may not be immediately executed, but instead the signal is marked - * in the interpreter state (to ensure non-blocking of the signal handler). - * - * \return 0 if no error happens, -1 if error happens. - */ - typedef int (*F_PyErr_CheckSignals)(); - - /*! \brief Callback to increment/decrement the python ref count */ - typedef void (*F_Py_IncDefRef)(void*); - - /*! - * \brief PyErr_CheckSignal function - */ - F_PyErr_CheckSignals pyerr_check_signals = nullptr; - - /*! - \brief PyGILState_Ensure function - */ - void* (*py_gil_state_ensure)() = nullptr; - - /*! - \brief PyGILState_Release function - */ - void (*py_gil_state_release)(void*) = nullptr; - - static EnvCAPIRegistry* Global() { - static EnvCAPIRegistry* inst = new EnvCAPIRegistry(); - return inst; - } - - // register environment(e.g. python) specific api functions - void Register(const String& symbol_name, void* fptr) { - if (symbol_name == "PyErr_CheckSignals") { - Update(symbol_name, &pyerr_check_signals, fptr); - } else if (symbol_name == "PyGILState_Ensure") { - Update(symbol_name, &py_gil_state_ensure, fptr); - } else if (symbol_name == "PyGILState_Release") { - Update(symbol_name, &py_gil_state_release, fptr); - } else { - TVM_FFI_THROW(ValueError) << "Unknown env API " + symbol_name; - } - } - - int EnvCheckSignals() { - // check python signal to see if there are exception raised - if (pyerr_check_signals != nullptr) { - // The C++ env comes without gil, so we need to grab gil here - WithGIL context(this); - if ((*pyerr_check_signals)() != 0) { - // The error will let FFI know that the frontend environment - // already set an error. - return -1; - } - } - return 0; - } - - private: - // update the internal API table - template - void Update(const String& symbol_name, FType* target, void* ptr) { - FType ptr_casted = reinterpret_cast(ptr); - target[0] = ptr_casted; - } - - struct WithGIL { - explicit WithGIL(EnvCAPIRegistry* self) : self(self) { - TVM_FFI_ICHECK(self->py_gil_state_ensure); - TVM_FFI_ICHECK(self->py_gil_state_release); - gil_state = self->py_gil_state_ensure(); - } - ~WithGIL() { - if (self && gil_state) { - self->py_gil_state_release(gil_state); - } - } - WithGIL(const WithGIL&) = delete; - WithGIL(WithGIL&&) = delete; - WithGIL& operator=(const WithGIL&) = delete; - WithGIL& operator=(WithGIL&&) = delete; - - EnvCAPIRegistry* self = nullptr; - void* gil_state = nullptr; - }; -}; -} // namespace ffi -} // namespace tvm - -int TVMFFIEnvCheckSignals() { return tvm::ffi::EnvCAPIRegistry::Global()->EnvCheckSignals(); } - -/*! - * \brief Register a symbol into the from the surrounding env. - * \param name The name of the symbol. - * \param symbol The symbol to register. - * \return 0 when success, nonzero when failure happens - */ -int TVMFFIEnvRegisterCAPI(const char* name, void* symbol) { - TVM_FFI_SAFE_CALL_BEGIN(); - tvm::ffi::String s_name(name); - tvm::ffi::EnvCAPIRegistry::Global()->Register(s_name, symbol); - TVM_FFI_SAFE_CALL_END(); -} diff --git a/ffi/src/ffi/extra/env_context.cc b/ffi/src/ffi/extra/env_context.cc deleted file mode 100644 index 30f9270dabc7..000000000000 --- a/ffi/src/ffi/extra/env_context.cc +++ /dev/null @@ -1,120 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -/* - * \file src/ffi/extra/env_context.cc - * - * \brief A minimalistic env context based on ffi values. - */ - -#include -#include - -#include - -namespace tvm { -namespace ffi { - -class EnvContext { - public: - void SetStream(int32_t device_type, int32_t device_id, TVMFFIStreamHandle stream, - TVMFFIStreamHandle* out_original_stream) { - if (static_cast(device_type) >= stream_table_.size()) { - stream_table_.resize(device_type + 1); - } - if (static_cast(device_id) >= stream_table_[device_type].size()) { - stream_table_[device_type].resize(device_id + 1, nullptr); - } - if (out_original_stream != nullptr) { - *out_original_stream = stream_table_[device_type][device_id]; - } - stream_table_[device_type][device_id] = stream; - } - - TVMFFIStreamHandle GetStream(int32_t device_type, int32_t device_id) { - if (static_cast(device_type) < stream_table_.size() && - static_cast(device_id) < stream_table_[device_type].size()) { - return stream_table_[device_type][device_id]; - } - return nullptr; - } - - DLPackTensorAllocator GetDLPackTensorAllocator() { - if (dlpack_allocator_ != nullptr) { - return dlpack_allocator_; - } - return GlobalTensorAllocator(); - } - - void SetDLPackTensorAllocator(DLPackTensorAllocator allocator, int write_to_global_context, - DLPackTensorAllocator* opt_out_original_allocator) { - dlpack_allocator_ = allocator; - if (write_to_global_context != 0) { - GlobalTensorAllocator() = allocator; - } - if (opt_out_original_allocator != nullptr) { - *opt_out_original_allocator = dlpack_allocator_; - } - dlpack_allocator_ = allocator; - } - - static EnvContext* ThreadLocal() { - static thread_local EnvContext inst; - return &inst; - } - - private: - // use static function to avoid static initialization order issue - static DLPackTensorAllocator& GlobalTensorAllocator() { // NOLINT(*) - static DLPackTensorAllocator allocator = nullptr; - return allocator; - } - std::vector> stream_table_; - DLPackTensorAllocator dlpack_allocator_ = nullptr; -}; - -} // namespace ffi -} // namespace tvm - -int TVMFFIEnvSetStream(int32_t device_type, int32_t device_id, TVMFFIStreamHandle stream, - TVMFFIStreamHandle* out_original_stream) { - TVM_FFI_SAFE_CALL_BEGIN(); - tvm::ffi::EnvContext::ThreadLocal()->SetStream(device_type, device_id, stream, - out_original_stream); - TVM_FFI_SAFE_CALL_END(); -} - -TVMFFIStreamHandle TVMFFIEnvGetStream(int32_t device_type, int32_t device_id) { - TVM_FFI_LOG_EXCEPTION_CALL_BEGIN(); - return tvm::ffi::EnvContext::ThreadLocal()->GetStream(device_type, device_id); - TVM_FFI_LOG_EXCEPTION_CALL_END(TVMFFIEnvGetStream); -} - -int TVMFFIEnvSetTensorAllocator(DLPackTensorAllocator allocator, int write_to_global_context, - DLPackTensorAllocator* opt_out_original_allocator) { - TVM_FFI_SAFE_CALL_BEGIN(); - tvm::ffi::EnvContext::ThreadLocal()->SetDLPackTensorAllocator(allocator, write_to_global_context, - opt_out_original_allocator); - TVM_FFI_SAFE_CALL_END(); -} - -DLPackTensorAllocator TVMFFIEnvGetTensorAllocator() { - TVM_FFI_LOG_EXCEPTION_CALL_BEGIN(); - return tvm::ffi::EnvContext::ThreadLocal()->GetDLPackTensorAllocator(); - TVM_FFI_LOG_EXCEPTION_CALL_END(TVMFFIEnvGetTensorAllocator); -} diff --git a/ffi/src/ffi/extra/json_parser.cc b/ffi/src/ffi/extra/json_parser.cc deleted file mode 100644 index dddb782d448e..000000000000 --- a/ffi/src/ffi/extra/json_parser.cc +++ /dev/null @@ -1,731 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -/* - * \file src/ffi/json/parser.cc - * - * \brief A minimalistic JSON parser based on ffi values. - */ -#include -#include -#include -#include -#include -#include -#include - -#include -#include - -namespace tvm { -namespace ffi { -namespace json { - -/*! - * \brief Helper class to parse a JSON string. - * - * Keep leaf level string/number parse also in context. - */ -class JSONParserContext { - public: - JSONParserContext(const char* begin, const char* end) : begin_(begin), cur_(begin), end_(end) { - last_line_begin_ = cur_; - } - - /*! - * \brief Peek the current character. - * \return The current character, or -1 if the end of the string is reached. - */ - int Peek() const { - return (cur_ != end_ ? static_cast(*reinterpret_cast(cur_)) : -1); - } - - /*! - * \brief Skip the next char that we know is not a space - * - * \note Caller must explicitly call SkipSpaces first or use - * Peek already that confirms char is not any space char. - */ - void SkipNextAssumeNoSpace() { ++cur_; } - - /*! - * \brief Get the current position. - * \return The current position. - */ - const char* GetCurrentPos() const { return cur_; } - - /*! - * \brief Set the current position for better error message - * \param pos The new position. - * \note implementation can do it as no-op if needed - */ - void SetCurrentPosForBetterErrorMsg(const char* pos) { cur_ = pos; } - - /*! - * \brief Skip the space characters. - * \note This function does not check if the end of the string is reached. - */ - void SkipSpaces() { - while (cur_ != end_) { - if (!(*cur_ == ' ' || *cur_ == '\t' || *cur_ == '\n' || *cur_ == '\r')) { - break; - } - if (*cur_ == '\n') { - ++line_counter_; - last_line_begin_ = cur_ + 1; - } - ++cur_; - } - } - - /*! - * \brief Check if the next characters match the given string. - * \param str The string to match. - * \param len The length of the string. - * \return True if the next characters match the given string, false otherwise. - */ - bool MatchLiteral(const char* pattern, int len) { - const char* pend = pattern + len; - const char* ptr = pattern; - for (; ptr != pend && cur_ != end_; ++ptr, ++cur_) { - if (*ptr != *cur_) { - return false; - } - } - // we get to the end of the pattern and match is successful - return ptr == pend; - } - - /* - * \brief Parse the next strin starting with a double quote. - * \param out The output string. - * \return Whether the next string parsing is successful. - */ - bool NextString(json::Value* out) { - // NOTE: we keep string parsing logic here to allow some special - // optimizations for simple string that do not e - const char* start_pos = cur_; - TVM_FFI_ICHECK(*cur_ == '\"'); - // skip first double quote - ++cur_; - // the loop focuses on simple string without escape characters - for (; cur_ != end_; ++cur_) { - if (*cur_ == '\"') { - *out = String(start_pos + 1, cur_ - start_pos - 1); - ++cur_; - return true; - } - if (*cur_ < ' ' || *cur_ == '\\') { - // fallback to full string handling - return this->NextStringWithFullHandling(out, start_pos); - } - } - this->SetCurrentPosForBetterErrorMsg(start_pos); - this->SetErrorUnterminatedString(); - return false; - } - - /*! - * \brief Parse the next number. - * \param out The output number. - * \return Whether the next number parsing is successful. - */ - bool NextNumber(json::Value* out) { - const char* start_pos = cur_; - if (cur_ == end_) { - this->SetErrorExpectingValue(); - return false; - } - // JSON number grammar: - // - // number = [ minus ] int [ frac ] [ exp ] - // decimal-point = %x2E ; . - // digit1-9 = %x31-39 ; 1-9 - // e = %x65 / %x45 ; e E - // exp = e [ minus / plus ] 1*DIGIT - // frac = decimal-point 1*DIGIT - std::string temp_buffer; - bool maybe_int = true; - // parse [minus], cross check for Infinity/NaN/-Infinity - if (*cur_ == '-') { - temp_buffer.push_back('-'); - ++cur_; - if (cur_ != end_ && *cur_ == 'I') { - if (this->MatchLiteral("Infinity", 8)) { - *out = FastMathSafeNegInf(); - return true; - } else { - this->SetCurrentPosForBetterErrorMsg(start_pos); - this->SetErrorExpectingValue(); - return false; - } - } - } else if (*cur_ == 'I') { - if (this->MatchLiteral("Infinity", 8)) { - *out = FastMathSafePosInf(); - return true; - } else { - this->SetCurrentPosForBetterErrorMsg(start_pos); - this->SetErrorExpectingValue(); - return false; - } - } else if (*cur_ == 'N') { - if (this->MatchLiteral("NaN", 3)) { - *out = FastMathSafeNaN(); - return true; - } else { - this->SetCurrentPosForBetterErrorMsg(start_pos); - this->SetErrorExpectingValue(); - return false; - } - } - // read in all parts that are possibly part of a number - while (cur_ != end_) { - char next_char = *cur_; - if ((next_char >= '0' && next_char <= '9') || next_char == 'e' || next_char == 'E' || - next_char == '+' || next_char == '-' || next_char == '.') { - temp_buffer.push_back(next_char); - if (next_char == '.' || next_char == 'e' || next_char == 'E') { - maybe_int = false; - } - ++cur_; - } else { - break; - } - } - if (temp_buffer.empty()) { - this->SetErrorExpectingValue(); - return false; - } - // parse from temp_buffer_ - if (maybe_int) { - // now try to parse the number as int64 - char* end_ptr; - errno = 0; - intmax_t int_val = strtoimax(temp_buffer.data(), &end_ptr, 10); - if (errno == 0 && int_val >= std::numeric_limits::min() && - int_val <= std::numeric_limits::max() && - end_ptr == temp_buffer.data() + temp_buffer.size()) { - *out = static_cast(int_val); - return true; - } - } - { - // now try to parse number as double - char* end_ptr; - errno = 0; - double double_val = strtod(temp_buffer.data(), &end_ptr); - if (errno == 0 && end_ptr == temp_buffer.data() + temp_buffer.size()) { - *out = double_val; - return true; - } else { - this->SetCurrentPosForBetterErrorMsg(start_pos); - this->SetErrorExpectingValue(); - return false; - } - } - } - - /*! - * \brief Get the current line context. - * \return The current line context. - */ - String GetSyntaxErrorContext(std::string err_prefix) const { - int64_t column = static_cast(cur_ - last_line_begin_) + 1; - int64_t char_pos = static_cast(cur_ - begin_); - if (err_prefix.empty()) { - err_prefix = "Syntax error"; - } - err_prefix += ": line " + std::to_string(line_counter_) + " column " + std::to_string(column) + - " (char " + std::to_string(char_pos) + ")"; - return String(err_prefix); - } - - std::string FinalizeErrorMsg() { - if (error_msg_.empty()) { - SetErrorDefault(); - } - return std::string(error_msg_); - } - - void SetErrorDefault() { error_msg_ = GetSyntaxErrorContext("Syntax error near"); } - - void SetErrorExpectingValue() { error_msg_ = GetSyntaxErrorContext("Expecting value"); } - - void SetErrorInvalidControlCharacter() { - error_msg_ = GetSyntaxErrorContext("Invalid control character at"); - } - - void SetErrorUnterminatedString() { - error_msg_ = GetSyntaxErrorContext("Unterminated string starting at"); - } - - void SetErrorInvalidUnicodeEscape() { - error_msg_ = GetSyntaxErrorContext("Invalid \\uXXXX escape"); - } - - void SetErrorInvalidSurrogatePair() { - error_msg_ = GetSyntaxErrorContext("Invalid surrogate pair of \\uXXXX escapes"); - } - - void SetErrorInvalidEscape() { error_msg_ = GetSyntaxErrorContext("Invalid \\escape"); } - - void SetErrorExtraData() { error_msg_ = GetSyntaxErrorContext("Extra data"); } - - void SetErrorExpectingPropertyName() { - error_msg_ = GetSyntaxErrorContext("Expecting property name enclosed in double quotes"); - } - - void SetErrorExpectingColon() { error_msg_ = GetSyntaxErrorContext("Expecting \':\' delimiter"); } - - void SetErrorExpectingComma() { error_msg_ = GetSyntaxErrorContext("Expecting \',\' delimiter"); } - - private: - static double FastMathSafePosInf() { -#ifdef __FAST_MATH__ - union { - uint64_t from; - double to; - } u; - u.from = 0x7FF0000000000000ULL; // write "from", read "to" - return u.to; -#else - return std::numeric_limits::infinity(); -#endif - } - - static double FastMathSafeNegInf() { -#ifdef __FAST_MATH__ - union { - uint64_t from; - double to; - } u; - u.from = 0xFFF0000000000000ULL; // write "from", read "to" - return u.to; -#else - return -std::numeric_limits::infinity(); -#endif - } - - static double FastMathSafeNaN() { -#ifdef __FAST_MATH__ - union { - uint64_t from; - double to; - } u; - u.from = 0x7FF8000000000000ULL; // write "from", read "to" - return u.to; -#else - return std::numeric_limits::quiet_NaN(); -#endif - } - - // Full string parsing with escape and unicode handling - bool NextStringWithFullHandling(Any* out, const char* start_pos) { - // copy over the prefix that was already parsed - std::string out_str(start_pos + 1, cur_ - start_pos - 1); - while (cur_ != end_) { - if (*cur_ < ' ') { - this->SetErrorInvalidControlCharacter(); - return false; - } - if (*cur_ == '\"') { - *out = String(std::move(out_str)); - ++cur_; - return true; - } - if (*cur_ == '\\') { - ++cur_; - switch (*cur_) { - // handle escape characters per JSON spec(RFC 8259) -#define HANDLE_ESCAPE_CHAR(pattern, val) \ - case pattern: \ - ++cur_; \ - out_str.push_back(val); \ - break - HANDLE_ESCAPE_CHAR('\"', '\"'); - HANDLE_ESCAPE_CHAR('\\', '\\'); - HANDLE_ESCAPE_CHAR('/', '/'); - HANDLE_ESCAPE_CHAR('b', '\b'); - HANDLE_ESCAPE_CHAR('f', '\f'); - HANDLE_ESCAPE_CHAR('n', '\n'); - HANDLE_ESCAPE_CHAR('r', '\r'); - HANDLE_ESCAPE_CHAR('t', '\t'); -#undef HANDLE_ESCAPE_CHAR - case 'u': { - const char* escape_pos = cur_; - // handle unicode code point - ++cur_; - int32_t first_i16, code_point = 0; - if (!Parse4Hex(&first_i16)) { - this->SetCurrentPosForBetterErrorMsg(escape_pos); - this->SetErrorInvalidUnicodeEscape(); - return false; - } - // Check if the first i16 is a UTF-16 surrogate pair - // - // Surrogate pair encoding rule: - // U' = yyyyyyyyyyxxxxxxxxxx // U - 0x10000 - // W1 = 110110yyyyyyyyyy // 0xD800 + yyyyyyyyyy - // W2 = 110111xxxxxxxxxx // 0xDC00 + xxxxxxxxxx - // - // Range of W1 and W2: - // 0xD800 - 0xDBFF for W1 - // 0xDC00 - 0xDFFF for W2 - // both W1 and W2 fit into 0xD800 - 0xDFFF - // Detect if the first i16 fit into range of W1/W2 - if (first_i16 >= 0xD800 && first_i16 <= 0xDFFF) { - // we are in the surrogate pair range - if (first_i16 >= 0xDC00) { - this->SetCurrentPosForBetterErrorMsg(escape_pos); - this->SetErrorInvalidSurrogatePair(); - // we need to return false instead because this range is for W2 - return false; - } - if (!this->MatchLiteral("\\u", 2)) { - this->SetCurrentPosForBetterErrorMsg(escape_pos); - this->SetErrorInvalidSurrogatePair(); - return false; - } - escape_pos = cur_; - // get the value of the W2 (second i16) - int32_t second_i16; - if (!Parse4Hex(&second_i16)) { - this->SetCurrentPosForBetterErrorMsg(escape_pos); - this->SetErrorInvalidUnicodeEscape(); - return false; - } - if (!(second_i16 >= 0xDC00 && second_i16 <= 0xDFFF)) { - this->SetCurrentPosForBetterErrorMsg(escape_pos); - this->SetErrorInvalidSurrogatePair(); - return false; - } - // recover the code point - code_point = ((first_i16 - 0xD800) << 10) + (second_i16 - 0xDC00) + 0x10000; - } else { - // not a surrogate case, just assign as code point - code_point = first_i16; - } - // now need to push back the string based on UTF-8 encoding - // UTF-8 encoding rule: four cases - // ------------------------------------------------------------ - // Pattern | code point range - // ------------------------------------------------------------ - // 0xxxxxxx | 0x0 - 0x7F - // 110xxxxx 10xxxxxx | 0x80 - 0x7FF - // 1110xxxx 10xxxxxx 10xxxxxx | 0x800 - 0xFFFF - // 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx | 0x10000 - end - // ------------------------------------------------------------ - if (code_point < 0x80) { - out_str.push_back(code_point); - } else if (code_point < 0x800) { - // first byte: 110xxxxx (5 effective bits) - // second byte: 10xxxxxx (6 effecive bits) - // shift by 6 bits to get the first bytes - out_str.push_back(0xC0 | (code_point >> 6)); - // mask by 6 effective bits - out_str.push_back(0x80 | (code_point & 0x3F)); - } else if (code_point < 0x10000) { - // first byte: 1110xxxx (4 effective bits) - // second byte: 10xxxxxx (6 effecive bits) - // third byte: 10xxxxxx (6 effecive bits) - // shift by 12 bits to get the first bytes - out_str.push_back(0xE0 | (code_point >> 12)); - // shift by 6 bits to get the second bytes, mask by 6 effective bits - out_str.push_back(0x80 | ((code_point >> 6) & 0x3F)); - // mask by 6 effective bits - out_str.push_back(0x80 | (code_point & 0x3F)); - } else { - // first byte: 11110xxx (3 effective bits) - // second byte: 10xxxxxx (6 effecive bits) - // third byte: 10xxxxxx (6 effecive bits) - // fourth byte: 10xxxxxx (6 effecive bits) - // shift by 18 bits to get the first bytes - out_str.push_back(0xF0 | (code_point >> 18)); - // shift by 12 bits to get the second bytes, mask by 6 effective bits - out_str.push_back(0x80 | ((code_point >> 12) & 0x3F)); - // shift by 6 bits to get the third bytes, mask by 6 effective bits - out_str.push_back(0x80 | ((code_point >> 6) & 0x3F)); - // mask by 6 effective bits - out_str.push_back(0x80 | (code_point & 0x3F)); - } - break; - } - default: { - this->SetErrorInvalidEscape(); - return false; - } - } - } else { - out_str.push_back(*cur_); - ++cur_; - } - } - this->SetCurrentPosForBetterErrorMsg(start_pos); - this->SetErrorUnterminatedString(); - return false; - } - /*! - * \brief Parse the four hex digits of a unicode code point per json spec. - * \param out_i16 The output i16 number - * \return True if four hex digits are parsed successfully, false otherwise. - */ - bool Parse4Hex(int32_t* out_i16) { - int32_t result = 0; - for (int i = 0; i < 4; ++i, ++cur_) { - int hex_val = *reinterpret_cast(cur_); - if (hex_val >= '0' && hex_val <= '9') { - hex_val -= '0'; - } else if (hex_val >= 'a' && hex_val <= 'f') { - hex_val -= 'a' - 0xa; - } else if (hex_val >= 'A' && hex_val <= 'F') { - hex_val -= 'A' - 0xa; - } else { - return false; - } - result = result * 16 + hex_val; - } - *out_i16 = result; - return true; - } - - /*! \brief The beginning of the string */ - const char* begin_; - /*! \brief The current pointer */ - const char* cur_; - /*! \brief End of the string */ - const char* end_; - /*! \brief The beginning of the last line */ - const char* last_line_begin_; - /*! \brief The error message */ - std::string error_msg_; - /*! \brief The line counter */ - int64_t line_counter_{1}; -}; - -class JSONParser { - public: - static json::Value Parse(const String& json_str, String* error_msg) { - JSONParser parser(json_str); - json::Value result; - if (parser.ParseValue(&result) && parser.ParseTail()) { - if (error_msg != nullptr) { - *error_msg = String(""); - } - return result; - } - if (error_msg != nullptr) { - *error_msg = parser.ctx_.FinalizeErrorMsg(); - TVM_FFI_ICHECK(!error_msg->empty()); - } else { - TVM_FFI_THROW(ValueError) << parser.ctx_.FinalizeErrorMsg(); - } - // note that when we don't throw, error msg is set to indicate - // an error happens - return nullptr; - } - - private: - explicit JSONParser(String json_str) : ctx_(json_str.data(), json_str.data() + json_str.size()) {} - - bool ParseTail() { - ctx_.SkipSpaces(); - // there are extra data in the tail - if (ctx_.Peek() != -1) { - ctx_.SetErrorExtraData(); - return false; - } - return true; - } - - bool ParseValue(json::Value* out) { - ctx_.SkipSpaces(); - // record start pos for cases where we might need to reset - // current position for better error message - auto start_pos = ctx_.GetCurrentPos(); - // check if the end of the string is reached - switch (ctx_.Peek()) { - case -1: { - ctx_.SetErrorExpectingValue(); - return false; - } - case '{': { - return ParseObject(out); - } - case '[': { - return ParseArray(out); - } - case '\"': { - return ctx_.NextString(out); - } - case 't': { - ctx_.SkipNextAssumeNoSpace(); - if (ctx_.MatchLiteral("rue", 3)) { - *out = true; - return true; - } else { - ctx_.SetCurrentPosForBetterErrorMsg(start_pos); - ctx_.SetErrorExpectingValue(); - return false; - } - } - case 'f': { - ctx_.SkipNextAssumeNoSpace(); - if (ctx_.MatchLiteral("alse", 4)) { - *out = false; - return true; - } else { - ctx_.SetCurrentPosForBetterErrorMsg(start_pos); - ctx_.SetErrorExpectingValue(); - return false; - } - } - case 'n': { - ctx_.SkipNextAssumeNoSpace(); - if (ctx_.MatchLiteral("ull", 3)) { - *out = nullptr; - return true; - } else { - ctx_.SetCurrentPosForBetterErrorMsg(start_pos); - ctx_.SetErrorExpectingValue(); - return false; - } - } - default: { - return ctx_.NextNumber(out); - } - } - return false; - } - - bool ParseObject(json::Value* out) { - size_t stack_top = object_temp_stack_.size(); - json::Object result; - ctx_.SkipNextAssumeNoSpace(); - ctx_.SkipSpaces(); - int next_char = ctx_.Peek(); - if (next_char == -1) { - ctx_.SetErrorExpectingPropertyName(); - return false; - } - // empty object - if (next_char == '}') { - ctx_.SkipNextAssumeNoSpace(); - *out = json::Object(); - return true; - } - // non-empty object - while ((next_char = ctx_.Peek()) != -1) { - if (next_char != '\"') { - ctx_.SetErrorExpectingPropertyName(); - return false; - } - json::Value key; - if (!ctx_.NextString(&key)) return false; - ctx_.SkipSpaces(); - if (ctx_.Peek() != ':') { - ctx_.SetErrorExpectingColon(); - return false; - } - ctx_.SkipNextAssumeNoSpace(); - json::Value value; - if (!ParseValue(&value)) return false; - object_temp_stack_.emplace_back(key, value); - // result.Set(key, value); - ctx_.SkipSpaces(); - if (ctx_.Peek() == '}') { - ctx_.SkipNextAssumeNoSpace(); - *out = json::Object(object_temp_stack_.begin() + stack_top, object_temp_stack_.end()); - // recover the stack to original state - object_temp_stack_.resize(stack_top); - return true; - } else if (ctx_.Peek() == ',') { - ctx_.SkipNextAssumeNoSpace(); - // must skip space so next iteration do not have to do so - ctx_.SkipSpaces(); - } else { - ctx_.SetErrorExpectingComma(); - return false; - } - } - return false; - } - - bool ParseArray(json::Value* out) { - size_t stack_top = array_temp_stack_.size(); - ctx_.SkipNextAssumeNoSpace(); - ctx_.SkipSpaces(); - int next_char = ctx_.Peek(); - if (next_char == -1) { - ctx_.SetErrorExpectingValue(); - return false; - } - // empty array - if (next_char == ']') { - ctx_.SkipNextAssumeNoSpace(); - *out = json::Array(); - return true; - } - // non-empty array - while ((next_char = ctx_.Peek()) != -1) { - json::Value value; - // no need to skip space here because we already skipped space - // at the beginning or in previous iteration - if (!ParseValue(&value)) return false; - array_temp_stack_.emplace_back(std::move(value)); - ctx_.SkipSpaces(); - next_char = ctx_.Peek(); - if (next_char == ',') { - ctx_.SkipNextAssumeNoSpace(); - // must skip space so next iteration do not have to do so - ctx_.SkipSpaces(); - } else if (next_char == ']') { - ctx_.SkipNextAssumeNoSpace(); - *out = json::Array(array_temp_stack_.begin() + stack_top, array_temp_stack_.end()); - // recover the stack - array_temp_stack_.resize(stack_top); - return true; - } else { - ctx_.SetErrorExpectingComma(); - return false; - } - } - return false; - } - - JSONParserContext ctx_; - // Temp stack for intermediate values - // we first create a persistent stack to store the parsed values - // then create the final array/object object with the precise size - std::vector array_temp_stack_; - std::vector> object_temp_stack_; -}; - -json::Value Parse(const String& json_str, String* error_msg) { - return JSONParser::Parse(json_str, error_msg); -} - -TVM_FFI_STATIC_INIT_BLOCK() { - namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("ffi.json.Parse", - [](const String& json_str) { return json::Parse(json_str); }); -} - -} // namespace json -} // namespace ffi -} // namespace tvm diff --git a/ffi/src/ffi/extra/json_writer.cc b/ffi/src/ffi/extra/json_writer.cc deleted file mode 100644 index 1a4636d2ecd3..000000000000 --- a/ffi/src/ffi/extra/json_writer.cc +++ /dev/null @@ -1,307 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -/* - * \file src/ffi/json/writer.cc - * - * \brief A minimalistic JSON writer based on ffi values. - */ -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include - -#ifdef _MSC_VER -#define TVM_FFI_SNPRINTF _snprintf_s -#pragma warning(push) -#pragma warning(disable : 4244) -#pragma warning(disable : 4127) -#pragma warning(disable : 4702) -#else -#define TVM_FFI_SNPRINTF snprintf -#endif - -namespace tvm { -namespace ffi { -namespace json { - -class JSONWriter { - public: - static String Stringify(const json::Value& value, Optional indent) { - JSONWriter writer(indent.value_or(0)); - writer.WriteValue(value); - return String(std::move(writer.result_)); - } - - private: - explicit JSONWriter(int indent) : indent_(indent), out_iter_(result_) {} - - static bool FastMathSafeIsNaN(double x) { -#ifdef __FAST_MATH__ - // Bit-level NaN detection (IEEE 754 double) - // IEEE 754 standard: https://en.wikipedia.org/wiki/IEEE_754 - // NaN is encoded as all 1s in the exponent and non-zero in the mantissa - static_assert(sizeof(double) == sizeof(uint64_t), "Unexpected double size"); - union { - double from; - uint64_t to; - } u; - u.from = x; // write "from", read "to" - uint64_t bits = u.to; - uint64_t exponent = (bits >> 52) & 0x7FF; - uint64_t mantissa = bits & 0xFFFFFFFFFFFFFull; - return (exponent == 0x7FF) && (mantissa != 0); -#else - // Safe to use std::isnan when fast-math is off - return std::isnan(x); -#endif - } - - static bool FastMathSafeIsInf(double x) { -#ifdef __FAST_MATH__ - // IEEE 754 standard: https://en.wikipedia.org/wiki/IEEE_754 - // Inf is encoded as all 1s in the exponent and zero in the mantissa - static_assert(sizeof(double) == sizeof(uint64_t), "Unexpected double size"); - union { - double from; - uint64_t to; - } u; - u.from = x; // write "from", read "to" - uint64_t bits = u.to; - uint64_t exponent = (bits >> 52) & 0x7FF; - uint64_t mantissa = bits & 0xFFFFFFFFFFFFFull; - // inf is encoded as all 1s in the exponent and zero in the mantissa - return (exponent == 0x7FF) && (mantissa == 0); -#else - return std::isinf(x); -#endif - } - - void WriteValue(const json::Value& value) { - switch (value.type_index()) { - case TypeIndex::kTVMFFINone: { - WriteLiteral("null", 4); - break; - } - case TypeIndex::kTVMFFIBool: { - bool bool_value = details::AnyUnsafe::CopyFromAnyViewAfterCheck(value); - if (bool_value) { - WriteLiteral("true", 4); - } else { - WriteLiteral("false", 5); - } - break; - } - case TypeIndex::kTVMFFIInt: { - WriteInt(details::AnyUnsafe::CopyFromAnyViewAfterCheck(value)); - break; - } - case TypeIndex::kTVMFFIFloat: { - WriteFloat(details::AnyUnsafe::CopyFromAnyViewAfterCheck(value)); - break; - } - case TypeIndex::kTVMFFISmallStr: - case TypeIndex::kTVMFFIStr: { - WriteString(details::AnyUnsafe::CopyFromAnyViewAfterCheck(value)); - break; - } - case TypeIndex::kTVMFFIArray: { - WriteArray(details::AnyUnsafe::CopyFromAnyViewAfterCheck(value)); - break; - } - case TypeIndex::kTVMFFIMap: { - WriteObject(details::AnyUnsafe::CopyFromAnyViewAfterCheck(value)); - break; - } - default: { - TVM_FFI_THROW(ValueError) << "Unsupported type: `" << value.GetTypeKey() << "`"; - TVM_FFI_UNREACHABLE(); - } - } - } - - void WriteLiteral(const char* literal, int size) { - for (int i = 0; i < size; ++i) { - *out_iter_++ = literal[i]; - } - } - - void WriteInt(int64_t value) { - // the biggest possible string representation of -INT64_MIN - char buffer[sizeof("-9223372036854775808") + 1]; - int size = TVM_FFI_SNPRINTF(buffer, sizeof(buffer), "%" PRId64, value); - WriteLiteral(buffer, size); - } - - void WriteFloat(double value) { - // largest possible string representation of a double is around 24 chars plus - // one null terminator keep 32 to be safe - char buffer[32]; - if (FastMathSafeIsNaN(value)) { - WriteLiteral("NaN", 3); - } else if (FastMathSafeIsInf(value)) { - if (value < 0) { - WriteLiteral("-Infinity", 9); - } else { - WriteLiteral("Infinity", 8); - } - } else { - double int_part; - // if the value can be represented as integer - if (std::fabs(value) < (1ULL << 53) && std::modf(value, &int_part) == 0) { - // always print an extra .0 for integer so integer numbers are printed as floats - // this helps us to distinguish between integer and float, which is not necessary - // but helps to ensure roundtrip property of the parser/printer in terms of int/float types - int size = TVM_FFI_SNPRINTF(buffer, sizeof(buffer), "%.1f", int_part); - WriteLiteral(buffer, size); - } else { - // Save 17 decimal digits to avoid loss during loading JSON - // this is the maximum precision that can be represented in a double - int size = TVM_FFI_SNPRINTF(buffer, sizeof(buffer), "%.17g", value); - WriteLiteral(buffer, size); - } - } - } - - void WriteString(const String& value) { - *out_iter_++ = '"'; - const char* data = value.data(); - const size_t size = value.size(); - for (size_t i = 0; i < size; ++i) { - switch (data[i]) { -// handle escape characters per JSON spec(RFC 8259) -#define HANDLE_ESCAPE_CHAR(pattern, val) \ - case pattern: \ - WriteLiteral(val, std::char_traits::length(val)); \ - break - HANDLE_ESCAPE_CHAR('\"', "\\\""); - HANDLE_ESCAPE_CHAR('\\', "\\\\"); - HANDLE_ESCAPE_CHAR('/', "\\/"); - HANDLE_ESCAPE_CHAR('\b', "\\b"); - HANDLE_ESCAPE_CHAR('\f', "\\f"); - HANDLE_ESCAPE_CHAR('\n', "\\n"); - HANDLE_ESCAPE_CHAR('\r', "\\r"); - HANDLE_ESCAPE_CHAR('\t', "\\t"); -#undef HANDLE_ESCAPE_CHAR - default: { - uint8_t u8_val = static_cast(data[i]); - // this is a control character, print as \uXXXX - if (u8_val < 0x20 || u8_val == 0x7f) { - char buffer[8]; - int size = TVM_FFI_SNPRINTF(buffer, sizeof(buffer), "\\u%04x", - static_cast(data[i]) & 0xff); - WriteLiteral(buffer, size); - } else { - *out_iter_++ = data[i]; - } - break; - } - } - } - *out_iter_++ = '"'; - } - - void WriteArray(const json::Array& value) { - *out_iter_++ = '['; - if (indent_ != 0) { - total_indent_ += indent_; - } - for (size_t i = 0; i < value.size(); ++i) { - if (i != 0) { - *out_iter_++ = ','; - } - if (indent_ != 0) { - WriteIndent(); - } - WriteValue(value[i]); - } - if (indent_ != 0) { - total_indent_ -= indent_; - WriteIndent(); - } - *out_iter_++ = ']'; - } - - void WriteObject(const json::Object& value) { - *out_iter_++ = '{'; - if (indent_ != 0) { - total_indent_ += indent_; - } - int counter = 0; - for (const auto& [key, value] : value) { - if (counter++ != 0) { - *out_iter_++ = ','; - } - if (indent_ != 0) { - WriteIndent(); - } - auto opt_key = key.as(); - if (!opt_key.has_value()) { - TVM_FFI_THROW(ValueError) << "Expect key to be string, got `" << key.GetTypeKey() << "`"; - } - WriteString(*opt_key); - *out_iter_++ = ':'; - if (indent_ != 0) { - *out_iter_++ = ' '; - } - WriteValue(value); - } - if (indent_ != 0) { - total_indent_ -= indent_; - WriteIndent(); - } - *out_iter_++ = '}'; - } - - // Write a newline and indent the current level - void WriteIndent() { - *out_iter_++ = '\n'; - for (int i = 0; i < total_indent_; ++i) { - *out_iter_++ = ' '; - } - } - - int indent_ = 0; - int total_indent_ = 0; - std::string result_; - std::back_insert_iterator out_iter_; -}; - -String Stringify(const json::Value& value, Optional indent) { - return JSONWriter::Stringify(value, indent); -} - -TVM_FFI_STATIC_INIT_BLOCK() { - namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("ffi.json.Stringify", Stringify); -} - -} // namespace json -} // namespace ffi -} // namespace tvm - -#undef TVM_FFI_SNPRINTF diff --git a/ffi/src/ffi/extra/library_module.cc b/ffi/src/ffi/extra/library_module.cc deleted file mode 100644 index 2864cdb5904a..000000000000 --- a/ffi/src/ffi/extra/library_module.cc +++ /dev/null @@ -1,199 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -/* - * \file src/ffi/extra/library_module.cc - * - * \brief Library module implementation. - */ -#include -#include -#include - -#include "buffer_stream.h" -#include "module_internal.h" - -namespace tvm { -namespace ffi { - -class LibraryModuleObj final : public ModuleObj { - public: - explicit LibraryModuleObj(ObjectPtr lib) : lib_(lib) {} - - const char* kind() const final { return "library"; } - - /*! \brief Get the property of the runtime module .*/ - int GetPropertyMask() const final { return Module::kBinarySerializable | Module::kRunnable; }; - - Optional GetFunction(const String& name) final { - TVMFFISafeCallType faddr; - faddr = reinterpret_cast(lib_->GetSymbolWithSymbolPrefix(name)); - // ensure the function keeps the Library Module alive - Module self_strong_ref = GetRef(this); - if (faddr != nullptr) { - return ffi::Function::FromPacked([faddr, self_strong_ref](ffi::PackedArgs args, - ffi::Any* rv) { - TVM_FFI_ICHECK_LT(rv->type_index(), ffi::TypeIndex::kTVMFFIStaticObjectBegin); - TVM_FFI_CHECK_SAFE_CALL((*faddr)(nullptr, reinterpret_cast(args.data()), - args.size(), reinterpret_cast(rv))); - }); - } - return std::nullopt; - } - - private: - ObjectPtr lib_; -}; - -Module LoadModuleFromBytes(const std::string& kind, const Bytes& bytes) { - std::string loader_key = "ffi.Module.load_from_bytes." + kind; - const auto floader = tvm::ffi::Function::GetGlobal(loader_key); - if (!floader.has_value()) { - TVM_FFI_THROW(RuntimeError) << "Library binary was created using {" << kind - << "} but a loader of that name is not registered. " - << "Make sure to have runtime that registers " << loader_key; - } - return (*floader)(bytes).cast(); -} - -/*! - * \brief Process libary binary to recover binary-serialized modules - * \param library_bin The binary embedded in the library. - * \param opt_lib The library, can be nullptr in which case we expect to deserialize - * all binary-serialized modules - * \param library_ctx_addr the pointer to library module as ctx addr - * \return the root module - * - */ -Module ProcessLibraryBin(const char* library_bin, ObjectPtr opt_lib, - void** library_ctx_addr = nullptr) { - // Layout of the library binary: - // ... - // key can be: "_lib", or a module kind - // - "_lib" indicate this location places the library module - // - other keys are module kinds - // Import tree structure (CSR structure of child indices): - // = > > - TVM_FFI_ICHECK(library_bin != nullptr); - uint64_t nbytes = 0; - for (size_t i = 0; i < sizeof(nbytes); ++i) { - uint64_t c = library_bin[i]; - nbytes |= (c & 0xffUL) << (i * 8); - } - - BufferInStream stream(library_bin + sizeof(nbytes), static_cast(nbytes)); - std::vector import_tree_indptr; - std::vector import_tree_child_indices; - TVM_FFI_ICHECK(stream.Read(&import_tree_indptr)); - TVM_FFI_ICHECK(stream.Read(&import_tree_child_indices)); - size_t num_modules = import_tree_indptr.size() - 1; - std::vector modules; - modules.reserve(num_modules); - - for (uint64_t i = 0; i < num_modules; ++i) { - std::string kind; - TVM_FFI_ICHECK(stream.Read(&kind)); - // "_lib" serves as a placeholder in the module import tree to indicate where - // to place the DSOModule - if (kind == "_lib") { - TVM_FFI_ICHECK(opt_lib != nullptr) << "_lib is not allowed during module serialization"; - auto lib_mod_ptr = make_object(opt_lib); - if (library_ctx_addr) { - *library_ctx_addr = lib_mod_ptr.get(); - } - modules.emplace_back(Module(lib_mod_ptr)); - } else { - std::string module_bytes; - TVM_FFI_ICHECK(stream.Read(&module_bytes)); - Module m = LoadModuleFromBytes(kind, Bytes(module_bytes)); - modules.emplace_back(m); - } - } - for (size_t i = 0; i < modules.size(); ++i) { - for (size_t j = import_tree_indptr[i]; j < import_tree_indptr[i + 1]; ++j) { - Array* module_imports = ModuleObj::InternalUnsafe::GetImports(modules[i].operator->()); - auto child_index = import_tree_child_indices[j]; - TVM_FFI_ICHECK(child_index < modules.size()); - module_imports->emplace_back(modules[child_index]); - } - } - return modules[0]; -} - -// registry to store context symbols -class ContextSymbolRegistry { - public: - void InitContextSymbols(ObjectPtr lib) { - for (const auto& [name, symbol] : context_symbols_) { - if (void** symbol_addr = reinterpret_cast(lib->GetSymbol(name))) { - *symbol_addr = symbol; - } - } - } - - void VisitContextSymbols(const ffi::TypedFunction& callback) { - for (const auto& [name, symbol] : context_symbols_) { - callback(name, symbol); - } - } - - void Register(String name, void* symbol) { context_symbols_.emplace_back(name, symbol); } - - static ContextSymbolRegistry* Global() { - static ContextSymbolRegistry* inst = new ContextSymbolRegistry(); - return inst; - } - - private: - std::vector> context_symbols_; -}; - -void Module::VisitContextSymbols(const ffi::TypedFunction& callback) { - ContextSymbolRegistry::Global()->VisitContextSymbols(callback); -} - -Module CreateLibraryModule(ObjectPtr lib) { - const char* library_bin = - reinterpret_cast(lib->GetSymbol(ffi::symbol::tvm_ffi_library_bin)); - void** library_ctx_addr = - reinterpret_cast(lib->GetSymbol(ffi::symbol::tvm_ffi_library_ctx)); - - ContextSymbolRegistry::Global()->InitContextSymbols(lib); - if (library_bin != nullptr) { - // we have embedded binaries that needs to be deserialized - return ProcessLibraryBin(library_bin, lib, library_ctx_addr); - } else { - // Only have one single DSO Module - auto lib_mod_ptr = make_object(lib); - Module root_mod = Module(lib_mod_ptr); - if (library_ctx_addr) { - *library_ctx_addr = root_mod.operator->(); - } - return root_mod; - } -} - -} // namespace ffi -} // namespace tvm - -int TVMFFIEnvModRegisterContextSymbol(const char* name, void* symbol) { - TVM_FFI_SAFE_CALL_BEGIN(); - tvm::ffi::String s_name(name); - tvm::ffi::ContextSymbolRegistry::Global()->Register(s_name, symbol); - TVM_FFI_SAFE_CALL_END(); -} diff --git a/ffi/src/ffi/extra/library_module_dynamic_lib.cc b/ffi/src/ffi/extra/library_module_dynamic_lib.cc deleted file mode 100644 index 34072aad5a8e..000000000000 --- a/ffi/src/ffi/extra/library_module_dynamic_lib.cc +++ /dev/null @@ -1,118 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file library_module_dynamic_lib.cc - * \brief Create library module to load from dynamic shared library. - */ -#include -#include -#include - -#include "module_internal.h" - -#if defined(_WIN32) -#include -#else -#include -#endif - -#if defined(__hexagon__) -extern "C" { -#include -} -#endif - -namespace tvm { -namespace ffi { - -class DSOLibrary final : public Library { - public: - explicit DSOLibrary(const String& name) { Load(name); } - ~DSOLibrary() { - if (lib_handle_) Unload(); - } - - void* GetSymbol(const String& name) final { return GetSymbol_(name.c_str()); } - - private: - // private system dependent implementation - void* GetSymbol_(const char* name); - void Load(const String& name); - void Unload(); - -#if defined(_WIN32) - //! \brief Windows library handle - HMODULE lib_handle_{nullptr}; -#else - // \brief Linux library handle - void* lib_handle_{nullptr}; -#endif -}; - -#if defined(_WIN32) - -void* DSOLibrary::GetSymbol_(const char* name) { - return reinterpret_cast(GetProcAddress(lib_handle_, (LPCSTR)name)); // NOLINT(*) -} - -void DSOLibrary::Load(const String& name) { - // use wstring version that is needed by LLVM. - std::wstring wname(name.data(), name.data() + name.size()); - lib_handle_ = LoadLibraryW(wname.c_str()); - TVM_FFI_ICHECK(lib_handle_ != nullptr) << "Failed to load dynamic shared library " << name; -} - -void DSOLibrary::Unload() { - FreeLibrary(lib_handle_); - lib_handle_ = nullptr; -} - -#else - -void DSOLibrary::Load(const String& name) { - lib_handle_ = dlopen(name.c_str(), RTLD_LAZY | RTLD_LOCAL); - TVM_FFI_ICHECK(lib_handle_ != nullptr) - << "Failed to load dynamic shared library " << name << " " << dlerror(); -#if defined(__hexagon__) - int p; - int rc = dlinfo(lib_handle_, RTLD_DI_LOAD_ADDR, &p); - if (rc) - FARF(ERROR, "error getting model .so start address : %u", rc); - else - FARF(ALWAYS, "Model .so Start Address : %x", p); -#endif -} - -void* DSOLibrary::GetSymbol_(const char* name) { return dlsym(lib_handle_, name); } - -void DSOLibrary::Unload() { - dlclose(lib_handle_); - lib_handle_ = nullptr; -} -#endif - -TVM_FFI_STATIC_INIT_BLOCK() { - namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("ffi.Module.load_from_file.so", [](String library_path, String) { - return CreateLibraryModule(make_object(library_path)); - }); -} -} // namespace ffi -} // namespace tvm diff --git a/ffi/src/ffi/extra/library_module_system_lib.cc b/ffi/src/ffi/extra/library_module_system_lib.cc deleted file mode 100644 index 3a614738a04f..000000000000 --- a/ffi/src/ffi/extra/library_module_system_lib.cc +++ /dev/null @@ -1,143 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file system_library.cc - * \brief Create library module that directly get symbol from the system lib. - */ -#include -#include -#include -#include -#include - -#include - -#include "module_internal.h" - -namespace tvm { -namespace ffi { - -class SystemLibSymbolRegistry { - public: - void RegisterSymbol(const std::string& name, void* ptr) { - auto it = symbol_table_.find(name); - if (it != symbol_table_.end() && ptr != (*it).second) { - std::cerr << "Warning:SystemLib symbol " << name << " get overriden to a different address " - << ptr << "->" << (*it).second << std::endl; - } - symbol_table_.Set(name, ptr); - } - - void* GetSymbol(const String& name) { - auto it = symbol_table_.find(name); - if (it != symbol_table_.end()) { - return (*it).second; - } else { - return nullptr; - } - } - - static SystemLibSymbolRegistry* Global() { - static SystemLibSymbolRegistry* inst = new SystemLibSymbolRegistry(); - return inst; - } - - private: - // Internal symbol table - Map symbol_table_; -}; - -class SystemLibrary final : public Library { - public: - explicit SystemLibrary(const String& symbol_prefix) : symbol_prefix_(symbol_prefix) {} - - void* GetSymbol(const String& name) final { - // The `name` might or might not already contain the symbol prefix. - // Therefore, we check both with and without the prefix. - String name_with_prefix = symbol_prefix_ + name; - void* symbol = reg_->GetSymbol(name_with_prefix); - if (symbol != nullptr) { - return symbol; - } - return reg_->GetSymbol(name); - } - - void* GetSymbolWithSymbolPrefix(const String& name) final { - // The `name` might or might not already contain the symbol prefix. - // Therefore, we check both with and without the prefix. - String name_with_prefix = symbol::tvm_ffi_symbol_prefix + symbol_prefix_ + name; - void* symbol = reg_->GetSymbol(name_with_prefix); - if (symbol != nullptr) { - return symbol; - } - name_with_prefix = symbol::tvm_ffi_symbol_prefix + name; - return reg_->GetSymbol(name_with_prefix); - } - - private: - SystemLibSymbolRegistry* reg_ = SystemLibSymbolRegistry::Global(); - String symbol_prefix_; -}; - -class SystemLibModuleRegistry { - public: - Module GetOrCreateModule(String symbol_prefix) { - std::lock_guard lock(mutex_); - auto it = lib_map_.find(symbol_prefix); - if (it != lib_map_.end()) { - return (*it).second; - } else { - Module mod = CreateLibraryModule(make_object(symbol_prefix)); - lib_map_.Set(symbol_prefix, mod); - return mod; - } - } - - static SystemLibModuleRegistry* Global() { - static SystemLibModuleRegistry* inst = new SystemLibModuleRegistry(); - return inst; - } - - private: - // Internal mutex - std::mutex mutex_; - // maps prefix to the library module - // we need to make sure each lib map have an unique - // copy through out the entire lifetime of the process - Map lib_map_; -}; - -TVM_FFI_STATIC_INIT_BLOCK() { - namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def_packed("ffi.SystemLib", [](ffi::PackedArgs args, ffi::Any* rv) { - String symbol_prefix = ""; - if (args.size() != 0) { - symbol_prefix = args[0].cast(); - } - *rv = SystemLibModuleRegistry::Global()->GetOrCreateModule(symbol_prefix); - }); -} -} // namespace ffi -} // namespace tvm - -int TVMFFIEnvModRegisterSystemLibSymbol(const char* name, void* ptr) { - tvm::ffi::SystemLibSymbolRegistry::Global()->RegisterSymbol(name, ptr); - return 0; -} diff --git a/ffi/src/ffi/extra/module.cc b/ffi/src/ffi/extra/module.cc deleted file mode 100644 index d2ebcd121dfc..000000000000 --- a/ffi/src/ffi/extra/module.cc +++ /dev/null @@ -1,157 +0,0 @@ - -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -#include -#include -#include -#include - -#include -#include - -#include "module_internal.h" - -namespace tvm { -namespace ffi { - -Optional ModuleObj::GetFunction(const String& name, bool query_imports) { - if (auto opt_func = this->GetFunction(name)) { - return opt_func; - } - if (query_imports) { - for (const Any& import : imports_) { - if (auto opt_func = import.cast()->GetFunction(name, query_imports)) { - return *opt_func; - } - } - } - return std::nullopt; -} - -Optional ModuleObj::GetFunctionMetadata(const String& name, bool query_imports) { - if (auto opt_metadata = this->GetFunctionMetadata(name)) { - return opt_metadata; - } - if (query_imports) { - for (const Any& import : imports_) { - if (auto opt_metadata = import.cast()->GetFunctionMetadata(name, query_imports)) { - return *opt_metadata; - } - } - } - return std::nullopt; -} - -void ModuleObj::ImportModule(const Module& other) { - std::unordered_set visited{other.operator->()}; - std::vector stack{other.operator->()}; - while (!stack.empty()) { - const ModuleObj* n = stack.back(); - stack.pop_back(); - for (const Any& m : n->imports_) { - const ModuleObj* next = m.cast(); - if (visited.count(next)) continue; - visited.insert(next); - stack.push_back(next); - } - } - if (visited.count(this)) { - TVM_FFI_THROW(RuntimeError) << "Cyclic dependency detected during import"; - } - imports_.push_back(other); -} - -void ModuleObj::ClearImports() { imports_.clear(); } - -bool ModuleObj::ImplementsFunction(const String& name, bool query_imports) { - if (this->ImplementsFunction(name)) { - return true; - } - if (query_imports) { - for (const Any& import : imports_) { - if (import.cast()->ImplementsFunction(name, query_imports)) { - return true; - } - } - } - return false; -} - -Module Module::LoadFromFile(const String& file_name) { - String format = [&file_name]() -> String { - const char* data = file_name.data(); - for (size_t i = file_name.size(); i > 0; i--) { - if (data[i - 1] == '.') { - return String(data + i, file_name.size() - i); - } - } - TVM_FFI_THROW(RuntimeError) << "Failed to get file format from " << file_name; - TVM_FFI_UNREACHABLE(); - }(); - - if (format == "dll" || format == "dylib" || format == "dso") { - format = "so"; - } - String loader_name = "ffi.Module.load_from_file." + format; - const auto floader = tvm::ffi::Function::GetGlobal(loader_name); - if (!floader.has_value()) { - TVM_FFI_THROW(RuntimeError) << "Loader for `." << format << "` files is not registered," - << " resolved to (" << loader_name << ") in the global registry." - << "Ensure that you have loaded the correct runtime code, and" - << "that you are on the correct hardware architecture."; - } - return (*floader)(file_name, format).cast(); -} - -TVM_FFI_STATIC_INIT_BLOCK() { - namespace refl = tvm::ffi::reflection; - ModuleObj::InternalUnsafe::RegisterReflection(); - - refl::GlobalDef() - .def("ffi.ModuleLoadFromFile", &Module::LoadFromFile) - .def_method("ffi.ModuleImplementsFunction", - [](Module mod, String name, bool query_imports) { - return mod->ImplementsFunction(name, query_imports); - }) - .def_method("ffi.ModuleGetFunctionMetadata", - [](Module mod, String name, bool query_imports) { - return mod->GetFunctionMetadata(name, query_imports); - }) - .def_method("ffi.ModuleGetFunction", - [](Module mod, String name, bool query_imports) { - return mod->GetFunction(name, query_imports); - }) - .def_method("ffi.ModuleGetPropertyMask", &ModuleObj::GetPropertyMask) - .def_method("ffi.ModuleInspectSource", &ModuleObj::InspectSource) - .def_method("ffi.ModuleGetKind", [](const Module& mod) -> String { return mod->kind(); }) - .def_method("ffi.ModuleGetWriteFormats", &ModuleObj::GetWriteFormats) - .def_method("ffi.ModuleWriteToFile", &ModuleObj::WriteToFile) - .def_method("ffi.ModuleImportModule", &ModuleObj::ImportModule) - .def_method("ffi.ModuleClearImports", &ModuleObj::ClearImports); -} -} // namespace ffi -} // namespace tvm - -int TVMFFIEnvModLookupFromImports(TVMFFIObjectHandle library_ctx, const char* func_name, - TVMFFIObjectHandle* out) { - TVM_FFI_SAFE_CALL_BEGIN(); - *out = tvm::ffi::ModuleObj::InternalUnsafe::GetFunctionFromImports( - reinterpret_cast(library_ctx), func_name); - TVM_FFI_SAFE_CALL_END(); -} diff --git a/ffi/src/ffi/extra/module_internal.h b/ffi/src/ffi/extra/module_internal.h deleted file mode 100644 index 86cb6b66c1f6..000000000000 --- a/ffi/src/ffi/extra/module_internal.h +++ /dev/null @@ -1,114 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file library_module.h - * \brief Module that builds from a libary of symbols. - */ -#ifndef TVM_FFI_EXTRA_MODULE_INTERNAL_H_ -#define TVM_FFI_EXTRA_MODULE_INTERNAL_H_ - -#include -#include - -#include - -namespace tvm { -namespace ffi { - -/*! - * \brief Library is the common interface - * for storing data in the form of shared libaries. - * - * \sa src/ffi/extra/dso_library.cc - * \sa src/ffi/extra/system_library.cc - */ -class Library : public Object { - public: - // destructor. - virtual ~Library() {} - /*! - * \brief Get the symbol address for a given name. - * \param name The name of the symbol. - * \return The symbol. - */ - virtual void* GetSymbol(const String& name) = 0; - /*! - * \brief Get the symbol address for a given name with the tvm ffi symbol prefix. - * \param name The name of the symbol. - * \return The symbol. - * \note This function will be overloaded by systemlib implementation. - */ - virtual void* GetSymbolWithSymbolPrefix(const String& name) { - String name_with_prefix = symbol::tvm_ffi_symbol_prefix + name; - return GetSymbol(name_with_prefix); - } - // NOTE: we do not explicitly create an type index and type_key here for libary. - // This is because we do not need dynamic type downcasting and only need to use the refcounting -}; - -struct ModuleObj::InternalUnsafe { - static Array* GetImports(ModuleObj* module) { return &(module->imports_); } - - static void* GetFunctionFromImports(ModuleObj* module, const char* name) { - // backend implementation for TVMFFIEnvModLookupFromImports - static std::mutex mutex_; - std::lock_guard lock(mutex_); - String s_name(name); - auto it = module->import_lookup_cache_.find(s_name); - if (it != module->import_lookup_cache_.end()) { - return const_cast((*it).second.operator->()); - } - - auto opt_func = [&]() -> std::optional { - for (const Any& import : module->imports_) { - if (auto opt_func = import.cast()->GetFunction(s_name, true)) { - return *opt_func; - } - } - // try global at last - return tvm::ffi::Function::GetGlobal(s_name); - }(); - if (!opt_func.has_value()) { - TVM_FFI_THROW(RuntimeError) << "Cannot find function " << name - << " in the imported modules or global registry."; - } - module->import_lookup_cache_.Set(s_name, *opt_func); - return const_cast((*opt_func).operator->()); - } - - static void RegisterReflection() { - namespace refl = tvm::ffi::reflection; - refl::ObjectDef().def_ro("imports_", &ModuleObj::imports_); - } -}; - -/*! - * \brief Create a library module from a given library. - * - * \param lib The library. - * - * \return The corresponding loaded module. - */ -Module CreateLibraryModule(ObjectPtr lib); - -} // namespace ffi -} // namespace tvm - -#endif // TVM_FFI_EXTRA_MODULE_INTERNAL_H_ diff --git a/ffi/src/ffi/extra/reflection_extra.cc b/ffi/src/ffi/extra/reflection_extra.cc deleted file mode 100644 index f92364370f17..000000000000 --- a/ffi/src/ffi/extra/reflection_extra.cc +++ /dev/null @@ -1,144 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -/* - * \file src/ffi/extra/reflection_extra.cc - * - * \brief Extra reflection registrations. * - */ -#include -#include - -namespace tvm { -namespace ffi { -namespace reflection { - -void MakeObjectFromPackedArgs(ffi::PackedArgs args, Any* ret) { - int32_t type_index; - if (auto opt_type_index = args[0].try_cast()) { - type_index = *opt_type_index; - } else { - String type_key = args[0].cast(); - TVMFFIByteArray type_key_array = TVMFFIByteArray{type_key.data(), type_key.size()}; - TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeKeyToIndex(&type_key_array, &type_index)); - } - - TVM_FFI_ICHECK(args.size() % 2 == 1); - const TVMFFITypeInfo* type_info = TVMFFIGetTypeInfo(type_index); - - if (type_info->metadata == nullptr || type_info->metadata->creator == nullptr) { - TVM_FFI_THROW(RuntimeError) << "Type `" << TypeIndexToTypeKey(type_index) - << "` does not support reflection creation"; - } - TVMFFIObjectHandle handle; - TVM_FFI_CHECK_SAFE_CALL(type_info->metadata->creator(&handle)); - ObjectPtr ptr = - details::ObjectUnsafe::ObjectPtrFromOwned(static_cast(handle)); - - std::vector keys; - std::vector keys_found; - - for (int i = 1; i < args.size(); i += 2) { - keys.push_back(args[i].cast()); - } - keys_found.resize(keys.size(), false); - - auto search_field = [&](const TVMFFIByteArray& field_name) { - for (size_t i = 0; i < keys.size(); ++i) { - if (keys_found[i]) continue; - if (keys[i].compare(field_name) == 0) { - return i; - } - } - return keys.size(); - }; - - auto update_fields = [&](const TVMFFITypeInfo* tinfo) { - for (int i = 0; i < tinfo->num_fields; ++i) { - const TVMFFIFieldInfo* field_info = tinfo->fields + i; - size_t arg_index = search_field(field_info->name); - void* field_addr = reinterpret_cast(ptr.get()) + field_info->offset; - if (arg_index < keys.size()) { - AnyView field_value = args[arg_index * 2 + 2]; - field_info->setter(field_addr, reinterpret_cast(&field_value)); - keys_found[arg_index] = true; - } else if (field_info->flags & kTVMFFIFieldFlagBitMaskHasDefault) { - field_info->setter(field_addr, &(field_info->default_value)); - } else { - TVM_FFI_THROW(TypeError) << "Required field `" - << String(field_info->name.data, field_info->name.size) - << "` not set in type `" << TypeIndexToTypeKey(type_index) << "`"; - } - } - }; - - // iterate through acenstors in parent to child order - // skip the first one since it is always the root object - for (int i = 1; i < type_info->type_depth; ++i) { - update_fields(type_info->type_acenstors[i]); - } - update_fields(type_info); - - for (size_t i = 0; i < keys.size(); ++i) { - if (!keys_found[i]) { - TVM_FFI_THROW(TypeError) << "Type `" << TypeIndexToTypeKey(type_index) - << "` does not have field `" << keys[i] << "`"; - } - } - *ret = ObjectRef(ptr); -} - -inline void AccessStepRegisterReflection() { - // register access step reflection here since it is only needed for bindings - namespace refl = tvm::ffi::reflection; - refl::ObjectDef() - .def_ro("kind", &AccessStepObj::kind) - .def_ro("key", &AccessStepObj::key); -} - -inline void AccessPathRegisterReflection() { - // register access path reflection here since it is only needed for bindings - namespace refl = tvm::ffi::reflection; - refl::ObjectDef() - .def_ro("parent", &AccessPathObj::parent) - .def_ro("step", &AccessPathObj::step) - .def_ro("depth", &AccessPathObj::depth) - .def_static("_root", &AccessPath::Root) - .def("_extend", &AccessPathObj::Extend) - .def("_attr", &AccessPathObj::Attr) - .def("_array_item", &AccessPathObj::ArrayItem) - .def("_map_item", &AccessPathObj::MapItem) - .def("_attr_missing", &AccessPathObj::AttrMissing) - .def("_array_item_missing", &AccessPathObj::ArrayItemMissing) - .def("_map_item_missing", &AccessPathObj::MapItemMissing) - .def("_is_prefix_of", &AccessPathObj::IsPrefixOf) - .def("_to_steps", &AccessPathObj::ToSteps) - .def("_path_equal", - [](const AccessPath& self, const AccessPath& other) { return self->PathEqual(other); }); -} - -TVM_FFI_STATIC_INIT_BLOCK() { - namespace refl = tvm::ffi::reflection; - AccessStepRegisterReflection(); - AccessPathRegisterReflection(); - refl::GlobalDef().def_packed("ffi.MakeObjectFromPackedArgs", MakeObjectFromPackedArgs); -} - -} // namespace reflection -} // namespace ffi -} // namespace tvm diff --git a/ffi/src/ffi/extra/serialization.cc b/ffi/src/ffi/extra/serialization.cc deleted file mode 100644 index 14c784428ed5..000000000000 --- a/ffi/src/ffi/extra/serialization.cc +++ /dev/null @@ -1,430 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -/* - * \file src/ffi/extra/serialization.cc - * - * \brief Reflection-based serialization utilities. - */ -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -namespace tvm { -namespace ffi { - -class ObjectGraphSerializer { - public: - static json::Value Serialize(const Any& value, Any metadata) { - ObjectGraphSerializer serializer; - json::Object result; - result.Set("root_index", serializer.GetOrCreateNodeIndex(value)); - result.Set("nodes", std::move(serializer.nodes_)); - if (metadata != nullptr) { - result.Set("metadata", metadata); - } - return result; - } - - private: - ObjectGraphSerializer() = default; - - int64_t GetOrCreateNodeIndex(const Any& value) { - // already mapped value, return the index - auto it = node_index_map_.find(value); - if (it != node_index_map_.end()) { - return (*it).second; - } - json::Object node; - switch (value.type_index()) { - case TypeIndex::kTVMFFINone: { - node.Set("type", ffi::StaticTypeKey::kTVMFFINone); - break; - } - case TypeIndex::kTVMFFIBool: { - node.Set("type", ffi::StaticTypeKey::kTVMFFIBool); - node.Set("data", details::AnyUnsafe::CopyFromAnyViewAfterCheck(value)); - break; - } - case TypeIndex::kTVMFFIInt: { - node.Set("type", ffi::StaticTypeKey::kTVMFFIInt); - node.Set("data", details::AnyUnsafe::CopyFromAnyViewAfterCheck(value)); - break; - } - case TypeIndex::kTVMFFIFloat: { - node.Set("type", ffi::StaticTypeKey::kTVMFFIFloat); - node.Set("data", details::AnyUnsafe::CopyFromAnyViewAfterCheck(value)); - break; - } - case TypeIndex::kTVMFFIDataType: { - DLDataType dtype = details::AnyUnsafe::CopyFromAnyViewAfterCheck(value); - node.Set("type", ffi::StaticTypeKey::kTVMFFIDataType); - node.Set("data", DLDataTypeToString(dtype)); - break; - } - case TypeIndex::kTVMFFIDevice: { - DLDevice device = details::AnyUnsafe::CopyFromAnyViewAfterCheck(value); - node.Set("type", ffi::StaticTypeKey::kTVMFFIDevice); - node.Set("data", json::Array{ - static_cast(device.device_type), - static_cast(device.device_id), - }); - break; - } - case TypeIndex::kTVMFFISmallStr: - case TypeIndex::kTVMFFIStr: { - String str = details::AnyUnsafe::CopyFromAnyViewAfterCheck(value); - node.Set("type", ffi::StaticTypeKey::kTVMFFIStr); - node.Set("data", str); - break; - } - case TypeIndex::kTVMFFISmallBytes: - case TypeIndex::kTVMFFIBytes: { - Bytes bytes = details::AnyUnsafe::CopyFromAnyViewAfterCheck(value); - node.Set("type", ffi::StaticTypeKey::kTVMFFIBytes); - node.Set("data", Base64Encode(bytes)); - break; - } - case TypeIndex::kTVMFFIArray: { - Array array = details::AnyUnsafe::CopyFromAnyViewAfterCheck>(value); - node.Set("type", ffi::StaticTypeKey::kTVMFFIArray); - node.Set("data", CreateArrayData(array)); - break; - } - case TypeIndex::kTVMFFIMap: { - Map map = details::AnyUnsafe::CopyFromAnyViewAfterCheck>(value); - node.Set("type", ffi::StaticTypeKey::kTVMFFIMap); - node.Set("data", CreateMapData(map)); - break; - } - case TypeIndex::kTVMFFIShape: { - ffi::Shape shape = details::AnyUnsafe::CopyFromAnyViewAfterCheck(value); - node.Set("type", ffi::StaticTypeKey::kTVMFFIShape); - node.Set("data", Array(shape->data, shape->data + shape->size)); - break; - } - default: { - if (value.type_index() >= TypeIndex::kTVMFFIStaticObjectBegin) { - // serialize type key since type index is runtime dependent - node.Set("type", value.GetTypeKey()); - node.Set("data", CreateObjectData(value)); - } else { - TVM_FFI_THROW(RuntimeError) << "Cannot serialize type `" << value.GetTypeKey() << "`"; - TVM_FFI_UNREACHABLE(); - } - } - } - int64_t node_index = nodes_.size(); - nodes_.push_back(node); - node_index_map_.Set(value, node_index); - return node_index; - } - - json::Array CreateArrayData(const Array& value) { - json::Array data; - data.reserve(value.size()); - for (const Any& item : value) { - data.push_back(GetOrCreateNodeIndex(item)); - } - return data; - } - - json::Array CreateMapData(const Map& value) { - json::Array data; - data.reserve(value.size() * 2); - for (const auto& [key, value] : value) { - data.push_back(GetOrCreateNodeIndex(key)); - data.push_back(GetOrCreateNodeIndex(value)); - } - return data; - } - - // create the data for the object, if the type has a custom data to json function, - // use it. otherwise, we go over the fields and create the data. - json::Value CreateObjectData(const Any& value) { - static reflection::TypeAttrColumn data_to_json = reflection::TypeAttrColumn("__data_to_json__"); - if (data_to_json[value.type_index()] != nullptr) { - return data_to_json[value.type_index()].cast()(value); - } - // NOTE: invariant: lhs and rhs are already the same type - const TVMFFITypeInfo* type_info = TVMFFIGetTypeInfo(value.type_index()); - if (type_info->metadata == nullptr) { - TVM_FFI_THROW(TypeError) << "Type metadata is not set for type `" - << String(type_info->type_key) - << "`, so ToJSONGraph is not supported for this type"; - } - const Object* obj = value.cast(); - json::Object data; - // go over the content and hash the fields - reflection::ForEachFieldInfo(type_info, [&](const TVMFFIFieldInfo* field_info) { - // get the field value from both side - reflection::FieldGetter getter(field_info); - Any field_value = getter(obj); - int field_static_type_index = field_info->field_static_type_index; - String field_name(field_info->name); - // for static field index that are known, we can directly set the field value. - switch (field_static_type_index) { - case TypeIndex::kTVMFFINone: { - data.Set(field_name, nullptr); - break; - } - case TypeIndex::kTVMFFIBool: { - data.Set(field_name, details::AnyUnsafe::CopyFromAnyViewAfterCheck(field_value)); - break; - } - case TypeIndex::kTVMFFIInt: { - data.Set(field_name, details::AnyUnsafe::CopyFromAnyViewAfterCheck(field_value)); - break; - } - case TypeIndex::kTVMFFIFloat: { - data.Set(field_name, details::AnyUnsafe::CopyFromAnyViewAfterCheck(field_value)); - break; - } - case TypeIndex::kTVMFFIDataType: { - DLDataType dtype = details::AnyUnsafe::CopyFromAnyViewAfterCheck(field_value); - data.Set(field_name, DLDataTypeToString(dtype)); - break; - } - default: { - // for dynamic field index, we need need to put them onto nodes - int64_t node_index = GetOrCreateNodeIndex(field_value); - data.Set(field_name, node_index); - break; - } - } - }); - return data; - } - - // maps the original value to the index of the node in the nodes_ array - Map node_index_map_; - // records nodes that are serialized - json::Array nodes_; -}; - -json::Value ToJSONGraph(const Any& value, const Any& metadata) { - return ObjectGraphSerializer::Serialize(value, metadata); -} - -class ObjectGraphDeserializer { - public: - static Any Deserialize(const json::Value& value) { - ObjectGraphDeserializer deserializer(value); - return deserializer.GetOrDecodeNode(deserializer.root_index_); - } - - Any GetOrDecodeNode(int64_t node_index) { - // already decoded null index - if (node_index == decoded_null_index_) { - return Any(nullptr); - } - // already decoded - if (decoded_nodes_[node_index] != nullptr) { - return decoded_nodes_[node_index]; - } - // now decode the node - Any value = DecodeNode(nodes_[node_index].cast()); - decoded_nodes_[node_index] = value; - if (value == nullptr) { - decoded_null_index_ = node_index; - } - return value; - } - - private: - Any DecodeNode(const json::Object& node) { - String type_key = node["type"].cast(); - TVMFFIByteArray type_key_arr{type_key.data(), type_key.length()}; - int32_t type_index; - TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeKeyToIndex(&type_key_arr, &type_index)); - - switch (type_index) { - case TypeIndex::kTVMFFINone: { - return nullptr; - } - case TypeIndex::kTVMFFIBool: { - return node["data"].cast(); - } - case TypeIndex::kTVMFFIInt: { - return node["data"].cast(); - } - case TypeIndex::kTVMFFIFloat: { - return node["data"].cast(); - } - case TypeIndex::kTVMFFIDataType: { - return StringToDLDataType(node["data"].cast()); - } - case TypeIndex::kTVMFFIDevice: { - Array data = node["data"].cast>(); - return DLDevice{static_cast(data[0]), data[1]}; - } - case TypeIndex::kTVMFFIStr: { - return node["data"].cast(); - } - case TypeIndex::kTVMFFIBytes: { - return Base64Decode(node["data"].cast()); - } - case TypeIndex::kTVMFFIMap: { - return DecodeMapData(node["data"].cast()); - } - case TypeIndex::kTVMFFIArray: { - return DecodeArrayData(node["data"].cast()); - } - case TypeIndex::kTVMFFIShape: { - Array data = node["data"].cast>(); - return ffi::Shape(data); - } - default: { - return DecodeObjectData(type_index, node["data"]); - } - } - } - - Array DecodeArrayData(const json::Array& data) { - Array array; - array.reserve(data.size()); - for (size_t i = 0; i < data.size(); i++) { - array.push_back(GetOrDecodeNode(data[i].cast())); - } - return array; - } - - Map DecodeMapData(const json::Array& data) { - Map map; - for (size_t i = 0; i < data.size(); i += 2) { - int64_t key_index = data[i].cast(); - int64_t value_index = data[i + 1].cast(); - map.Set(GetOrDecodeNode(key_index), GetOrDecodeNode(value_index)); - } - return map; - } - - Any DecodeObjectData(int32_t type_index, const json::Value& data) { - static reflection::TypeAttrColumn data_from_json = - reflection::TypeAttrColumn("__data_from_json__"); - if (data_from_json[type_index] != nullptr) { - return data_from_json[type_index].cast()(data); - } - // otherwise, we go over the fields and create the data. - const TVMFFITypeInfo* type_info = TVMFFIGetTypeInfo(type_index); - if (type_info->metadata == nullptr || type_info->metadata->creator == nullptr) { - TVM_FFI_THROW(RuntimeError) << "Type `" << TypeIndexToTypeKey(type_index) - << "` does not support default constructor" - << ", so ToJSONGraph is not supported for this type"; - } - TVMFFIObjectHandle handle; - TVM_FFI_CHECK_SAFE_CALL(type_info->metadata->creator(&handle)); - ObjectPtr ptr = - details::ObjectUnsafe::ObjectPtrFromOwned(static_cast(handle)); - - auto decode_field_value = [&](const TVMFFIFieldInfo* field_info, json::Value data) -> Any { - switch (field_info->field_static_type_index) { - case TypeIndex::kTVMFFINone: { - return nullptr; - } - case TypeIndex::kTVMFFIBool: { - return data.cast(); - } - case TypeIndex::kTVMFFIInt: { - return data.cast(); - } - case TypeIndex::kTVMFFIFloat: { - return data.cast(); - } - case TypeIndex::kTVMFFIDataType: { - return StringToDLDataType(data.cast()); - } - default: { - return GetOrDecodeNode(data.cast()); - } - } - }; - - json::Object data_object = data.cast(); - reflection::ForEachFieldInfo(type_info, [&](const TVMFFIFieldInfo* field_info) { - String field_name(field_info->name); - void* field_addr = reinterpret_cast(ptr.get()) + field_info->offset; - if (data_object.count(field_name) != 0) { - Any field_value = decode_field_value(field_info, data_object[field_name]); - field_info->setter(field_addr, reinterpret_cast(&field_value)); - } else if (field_info->flags & kTVMFFIFieldFlagBitMaskHasDefault) { - field_info->setter(field_addr, &(field_info->default_value)); - } else { - TVM_FFI_THROW(TypeError) << "Required field `" - << String(field_info->name.data, field_info->name.size) - << "` not set in type `" << TypeIndexToTypeKey(type_index) << "`"; - } - }); - return ObjectRef(ptr); - } - - explicit ObjectGraphDeserializer(json::Value serialized) { - if (!serialized.as()) { - TVM_FFI_THROW(ValueError) << "Invalid JSON Object Graph, expected an object"; - } - json::Object encoded_object = serialized.cast(); - if (encoded_object.count("root_index") == 0 || !encoded_object["root_index"].as()) { - TVM_FFI_THROW(ValueError) << "Invalid JSON Object Graph, expected `root_index` integer field"; - } - if (encoded_object.count("nodes") == 0 || !encoded_object["nodes"].as()) { - TVM_FFI_THROW(ValueError) << "Invalid JSON Object Graph, expected `nodes` array field"; - } - root_index_ = encoded_object["root_index"].cast(); - nodes_ = encoded_object["nodes"].cast(); - decoded_nodes_.resize(nodes_.size(), Any(nullptr)); - } - // nodes - json::Array nodes_; - // root index - int64_t root_index_; - // null index if already created - int64_t decoded_null_index_{-1}; - // decoded nodes - std::vector decoded_nodes_; -}; - -Any FromJSONGraph(const json::Value& value) { return ObjectGraphDeserializer::Deserialize(value); } - -// string version of the api -Any FromJSONGraphString(const String& value) { return FromJSONGraph(json::Parse(value)); } - -String ToJSONGraphString(const Any& value, const Any& metadata) { - return json::Stringify(ToJSONGraph(value, metadata)); -} - -TVM_FFI_STATIC_INIT_BLOCK() { - namespace refl = tvm::ffi::reflection; - refl::GlobalDef() - .def("ffi.ToJSONGraph", ToJSONGraph) - .def("ffi.ToJSONGraphString", ToJSONGraphString) - .def("ffi.FromJSONGraph", FromJSONGraph) - .def("ffi.FromJSONGraphString", FromJSONGraphString); - refl::EnsureTypeAttrColumn("__data_to_json__"); - refl::EnsureTypeAttrColumn("__data_from_json__"); -} - -} // namespace ffi -} // namespace tvm diff --git a/ffi/src/ffi/extra/structural_equal.cc b/ffi/src/ffi/extra/structural_equal.cc deleted file mode 100644 index ccedfcb7a8b1..000000000000 --- a/ffi/src/ffi/extra/structural_equal.cc +++ /dev/null @@ -1,439 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -/* - * \file src/ffi/reflection/structural_equal.cc - * - * \brief Structural equal implementation. - */ -#include -#include -#include -#include -#include -#include -#include - -#include -#include - -namespace tvm { -namespace ffi { - -/** - * \brief Internal Handler class for structural equal comparison. - */ -class StructEqualHandler { - public: - StructEqualHandler() = default; - - bool CompareAny(ffi::Any lhs, ffi::Any rhs) { - using ffi::details::AnyUnsafe; - const TVMFFIAny* lhs_data = AnyUnsafe::TVMFFIAnyPtrFromAny(lhs); - const TVMFFIAny* rhs_data = AnyUnsafe::TVMFFIAnyPtrFromAny(rhs); - if (lhs_data->type_index != rhs_data->type_index) { - // type_index mismatch, if index is not string, return false - if (lhs_data->type_index != kTVMFFIStr && lhs_data->type_index != kTVMFFISmallStr && - lhs_data->type_index != kTVMFFISmallBytes && lhs_data->type_index != kTVMFFIBytes) { - return false; - } - // small string and normal string comparison - if (lhs_data->type_index == kTVMFFIStr && rhs_data->type_index == kTVMFFISmallStr) { - const details::BytesObjBase* lhs_str = - details::AnyUnsafe::CopyFromAnyViewAfterCheck(lhs); - return Bytes::memequal(lhs_str->data, rhs_data->v_bytes, lhs_str->size, - rhs_data->small_str_len); - } - if (lhs_data->type_index == kTVMFFISmallStr && rhs_data->type_index == kTVMFFIStr) { - const details::BytesObjBase* rhs_str = - details::AnyUnsafe::CopyFromAnyViewAfterCheck(rhs); - return Bytes::memequal(lhs_data->v_bytes, rhs_str->data, lhs_data->small_str_len, - rhs_str->size); - } - if (lhs_data->type_index == kTVMFFIBytes && rhs_data->type_index == kTVMFFISmallBytes) { - const details::BytesObjBase* lhs_bytes = - details::AnyUnsafe::CopyFromAnyViewAfterCheck(lhs); - return Bytes::memequal(lhs_bytes->data, rhs_data->v_bytes, lhs_bytes->size, - rhs_data->small_str_len); - } - if (lhs_data->type_index == kTVMFFISmallBytes && rhs_data->type_index == kTVMFFIBytes) { - const details::BytesObjBase* rhs_bytes = - details::AnyUnsafe::CopyFromAnyViewAfterCheck(rhs); - return Bytes::memequal(lhs_data->v_bytes, rhs_bytes->data, lhs_data->small_str_len, - rhs_bytes->size); - } - return false; - } - - if (lhs_data->type_index < TypeIndex::kTVMFFIStaticObjectBegin) { - // specially handle nan for float, as there can be multiple representations of nan - if (lhs_data->type_index == TypeIndex::kTVMFFIFloat && std::isnan(lhs_data->v_float64)) { - return std::isnan(rhs_data->v_float64); - } - // this is POD data, we can just compare the value - return lhs_data->zero_padding == rhs_data->zero_padding && - lhs_data->v_int64 == rhs_data->v_int64; - } - switch (lhs_data->type_index) { - case TypeIndex::kTVMFFIStr: - case TypeIndex::kTVMFFIBytes: { - // compare bytes - const details::BytesObjBase* lhs_str = - AnyUnsafe::CopyFromAnyViewAfterCheck(lhs); - const details::BytesObjBase* rhs_str = - AnyUnsafe::CopyFromAnyViewAfterCheck(rhs); - return Bytes::memequal(lhs_str->data, rhs_str->data, lhs_str->size, rhs_str->size); - } - case TypeIndex::kTVMFFIArray: { - return CompareArray(AnyUnsafe::MoveFromAnyAfterCheck>(std::move(lhs)), - AnyUnsafe::MoveFromAnyAfterCheck>(std::move(rhs))); - } - case TypeIndex::kTVMFFIMap: { - return CompareMap(AnyUnsafe::MoveFromAnyAfterCheck>(std::move(lhs)), - AnyUnsafe::MoveFromAnyAfterCheck>(std::move(rhs))); - } - case TypeIndex::kTVMFFIShape: { - return CompareShape(AnyUnsafe::MoveFromAnyAfterCheck(std::move(lhs)), - AnyUnsafe::MoveFromAnyAfterCheck(std::move(rhs))); - } - case TypeIndex::kTVMFFITensor: { - return CompareTensor(AnyUnsafe::MoveFromAnyAfterCheck(std::move(lhs)), - AnyUnsafe::MoveFromAnyAfterCheck(std::move(rhs))); - } - default: { - return CompareObject(AnyUnsafe::MoveFromAnyAfterCheck(std::move(lhs)), - AnyUnsafe::MoveFromAnyAfterCheck(std::move(rhs))); - } - } - } - - bool CompareObject(ObjectRef lhs, ObjectRef rhs) { - // NOTE: invariant: lhs and rhs are already the same type - const TVMFFITypeInfo* type_info = TVMFFIGetTypeInfo(lhs->type_index()); - if (type_info->metadata == nullptr) { - TVM_FFI_THROW(TypeError) << "Type metadata is not set for type `" - << String(type_info->type_key) - << "`, so StructuralHash is not supported for this type"; - } - if (type_info->metadata->structural_eq_hash_kind == kTVMFFISEqHashKindUnsupported) { - TVM_FFI_THROW(TypeError) << "_type_s_eq_hash_kind is not set for type `" - << String(type_info->type_key) - << "`, so StructuralHash is not supported for this type"; - } - - auto structural_eq_hash_kind = type_info->metadata->structural_eq_hash_kind; - if (structural_eq_hash_kind == kTVMFFISEqHashKindUniqueInstance) { - // use pointer comparison - return lhs.same_as(rhs); - } - if (structural_eq_hash_kind == kTVMFFISEqHashKindConstTreeNode) { - // fast path: constant tree node, pointer equality indicate equality and avoid content - // comparison if false, we should still run content comparison - if (lhs.same_as(rhs)) return true; - } - // check recorded mapping for DAG and fre var - if (structural_eq_hash_kind == kTVMFFISEqHashKindDAGNode || - structural_eq_hash_kind == kTVMFFISEqHashKindFreeVar) { - // if there is pre-recorded mapping, need to cross check the pointer equality after mapping - auto it = equal_map_lhs_.find(lhs); - if (it != equal_map_lhs_.end()) { - return it->second.same_as(rhs); - } - // if rhs is mapped but lhs is not, it means lhs is a free var, return false - if (equal_map_rhs_.count(rhs)) { - return false; - } - } - - static reflection::TypeAttrColumn custom_s_equal = reflection::TypeAttrColumn("__s_equal__"); - - bool success = true; - if (custom_s_equal[type_info->type_index] == nullptr) { - // We recursively compare the fields the object - reflection::ForEachFieldInfoWithEarlyStop(type_info, [&](const TVMFFIFieldInfo* field_info) { - // skip fields that are marked as structural eq hash ignore - if (field_info->flags & kTVMFFIFieldFlagBitMaskSEqHashIgnore) return false; - // get the field value from both side - reflection::FieldGetter getter(field_info); - Any lhs_value = getter(lhs); - Any rhs_value = getter(rhs); - // field is in def region, enable free var mapping - if (field_info->flags & kTVMFFIFieldFlagBitMaskSEqHashDef) { - bool allow_free_var = true; - std::swap(allow_free_var, map_free_vars_); - success = CompareAny(lhs_value, rhs_value); - std::swap(allow_free_var, map_free_vars_); - } else { - success = CompareAny(lhs_value, rhs_value); - } - if (!success) { - // record the first mismatching field if we sub-rountine compare failed - if (mismatch_lhs_reverse_path_ != nullptr) { - mismatch_lhs_reverse_path_->emplace_back( - reflection::AccessStep::Attr(String(field_info->name))); - mismatch_rhs_reverse_path_->emplace_back( - reflection::AccessStep::Attr(String(field_info->name))); - } - // return true to indicate early stop - return true; - } else { - // return false to continue checking other fields - return false; - } - }); - } else { - // run custom equal function defined via __s_equal__ type attribute - if (s_equal_callback_ == nullptr) { - s_equal_callback_ = ffi::Function::FromTyped( - [this](AnyView lhs, AnyView rhs, bool def_region, AnyView field_name) { - // NOTE: we explicitly make field_name as AnyView to avoid copy overhead initially - // and only cast to string if mismatch happens - bool success = true; - if (def_region) { - bool allow_free_var = true; - std::swap(allow_free_var, map_free_vars_); - success = CompareAny(lhs, rhs); - std::swap(allow_free_var, map_free_vars_); - } else { - success = CompareAny(lhs, rhs); - } - if (!success) { - if (mismatch_lhs_reverse_path_ != nullptr) { - String field_name_str = field_name.cast(); - mismatch_lhs_reverse_path_->emplace_back( - reflection::AccessStep::Attr(field_name_str)); - mismatch_rhs_reverse_path_->emplace_back( - reflection::AccessStep::Attr(field_name_str)); - } - } - return success; - }); - } - success = custom_s_equal[type_info->type_index] - .cast()(lhs, rhs, s_equal_callback_) - .cast(); - } - - if (success) { - if (structural_eq_hash_kind == kTVMFFISEqHashKindFreeVar) { - // we are in a free var case that is not yet mapped. - // in this case, either map_free_vars_ should be set to true, or map_free_vars_ should be - // set - if (lhs.same_as(rhs) || map_free_vars_) { - // record the equality - equal_map_lhs_[lhs] = rhs; - equal_map_rhs_[rhs] = lhs; - return true; - } else { - return false; - } - } - // if we have a success mapping and in graph/var mode, record the equality mapping - if (structural_eq_hash_kind == kTVMFFISEqHashKindDAGNode) { - // record the equality - equal_map_lhs_[lhs] = rhs; - equal_map_rhs_[rhs] = lhs; - } - return true; - } else { - return false; - } - } - - bool CompareMap(Map lhs, Map rhs) { - if (lhs.size() != rhs.size()) { - // size mismatch, and there is no path tracing - // return false since we don't need informative error message - if (mismatch_lhs_reverse_path_ == nullptr) return false; - } - // compare key and value pair by pair - for (auto kv : lhs) { - Any rhs_key = this->MapLhsToRhs(kv.first); - auto it = rhs.find(rhs_key); - if (it == rhs.end()) { - if (mismatch_lhs_reverse_path_ != nullptr) { - mismatch_lhs_reverse_path_->emplace_back(reflection::AccessStep::MapItem(kv.first)); - mismatch_rhs_reverse_path_->emplace_back(reflection::AccessStep::MapItemMissing(rhs_key)); - } - return false; - } - // now recursively compare value - if (!CompareAny(kv.second, (*it).second)) { - if (mismatch_lhs_reverse_path_ != nullptr) { - mismatch_lhs_reverse_path_->emplace_back(reflection::AccessStep::MapItem(kv.first)); - mismatch_rhs_reverse_path_->emplace_back(reflection::AccessStep::MapItem(rhs_key)); - } - return false; - } - } - // fast path, all contents equals to each other - if (lhs.size() == rhs.size()) return true; - // slow path, cross check every key from rhs in lhs to find the missing - // key for better error reporting - for (auto kv : rhs) { - Any lhs_key = this->MapRhsToLhs(kv.first); - auto it = lhs.find(lhs_key); - if (it == lhs.end()) { - if (mismatch_lhs_reverse_path_ != nullptr) { - mismatch_lhs_reverse_path_->emplace_back(reflection::AccessStep::MapItemMissing(lhs_key)); - mismatch_rhs_reverse_path_->emplace_back(reflection::AccessStep::MapItem(kv.first)); - } - return false; - } - } - return false; - } - - bool CompareArray(ffi::Array lhs, ffi::Array rhs) { - if (lhs.size() != rhs.size()) { - // fast path, size mismatch, and there is no path tracing - // return false since we don't need informative error message - if (mismatch_lhs_reverse_path_ == nullptr) return false; - } - for (size_t i = 0; i < std::min(lhs.size(), rhs.size()); ++i) { - if (!CompareAny(lhs[i], rhs[i])) { - if (mismatch_lhs_reverse_path_ != nullptr) { - mismatch_lhs_reverse_path_->emplace_back(reflection::AccessStep::ArrayItem(i)); - mismatch_rhs_reverse_path_->emplace_back(reflection::AccessStep::ArrayItem(i)); - } - return false; - } - } - if (lhs.size() == rhs.size()) return true; - if (mismatch_lhs_reverse_path_ != nullptr) { - if (lhs.size() > rhs.size()) { - mismatch_lhs_reverse_path_->emplace_back(reflection::AccessStep::ArrayItem(rhs.size())); - mismatch_rhs_reverse_path_->emplace_back( - reflection::AccessStep::ArrayItemMissing(rhs.size())); - } else { - mismatch_lhs_reverse_path_->emplace_back( - reflection::AccessStep::ArrayItemMissing(lhs.size())); - mismatch_rhs_reverse_path_->emplace_back(reflection::AccessStep::ArrayItem(lhs.size())); - } - } - return false; - } - - bool CompareShape(Shape lhs, Shape rhs) { - if (lhs.size() != rhs.size()) { - return false; - } - for (size_t i = 0; i < lhs.size(); ++i) { - if (lhs[i] != rhs[i]) { - return false; - } - } - return true; - } - - bool CompareTensor(Tensor lhs, Tensor rhs) { - if (lhs.same_as(rhs)) return true; - if (lhs->ndim != rhs->ndim) return false; - for (int i = 0; i < lhs->ndim; ++i) { - if (lhs->shape[i] != rhs->shape[i]) return false; - } - if (lhs->dtype != rhs->dtype) return false; - if (!skip_tensor_content_) { - TVM_FFI_ICHECK_EQ(lhs->device.device_type, kDLCPU) << "can only compare CPU tensor"; - TVM_FFI_ICHECK_EQ(rhs->device.device_type, kDLCPU) << "can only compare CPU tensor"; - TVM_FFI_ICHECK(lhs.IsContiguous()) << "Can only compare contiguous tensor"; - TVM_FFI_ICHECK(rhs.IsContiguous()) << "Can only compare contiguous tensor"; - size_t data_size = GetDataSize(*(lhs.operator->())); - return std::memcmp(lhs->data, rhs->data, data_size) == 0; - } else { - return true; - } - } - - Any MapLhsToRhs(Any lhs) const { - if (lhs.type_index() < TypeIndex::kTVMFFIStaticObjectBegin) { - return lhs; - } - ObjectRef lhs_obj = ffi::details::AnyUnsafe::MoveFromAnyAfterCheck(std::move(lhs)); - auto it = equal_map_lhs_.find(lhs_obj); - if (it != equal_map_lhs_.end()) { - return it->second; - } - return lhs_obj; - } - - Any MapRhsToLhs(Any rhs) const { - if (rhs.type_index() < TypeIndex::kTVMFFIStaticObjectBegin) { - return rhs; - } - ObjectRef rhs_obj = ffi::details::AnyUnsafe::MoveFromAnyAfterCheck(std::move(rhs)); - auto it = equal_map_rhs_.find(rhs_obj); - if (it != equal_map_rhs_.end()) { - return it->second; - } - return rhs_obj; - } - // whether we map free variables that are not defined - bool map_free_vars_{false}; - // whether we compare tensor data - bool skip_tensor_content_{false}; - // the root lhs for result printing - std::vector* mismatch_lhs_reverse_path_ = nullptr; - std::vector* mismatch_rhs_reverse_path_ = nullptr; - // lazily initialize custom equal function - ffi::Function s_equal_callback_ = nullptr; - // map from lhs to rhs - std::unordered_map equal_map_lhs_; - // map from rhs to lhs - std::unordered_map equal_map_rhs_; -}; - -bool StructuralEqual::Equal(const Any& lhs, const Any& rhs, bool map_free_vars, - bool skip_tensor_content) { - StructEqualHandler handler; - handler.map_free_vars_ = map_free_vars; - handler.skip_tensor_content_ = skip_tensor_content; - return handler.CompareAny(lhs, rhs); -} - -Optional StructuralEqual::GetFirstMismatch(const Any& lhs, - const Any& rhs, - bool map_free_vars, - bool skip_tensor_content) { - StructEqualHandler handler; - handler.map_free_vars_ = map_free_vars; - handler.skip_tensor_content_ = skip_tensor_content; - std::vector lhs_reverse_path; - std::vector rhs_reverse_path; - handler.mismatch_lhs_reverse_path_ = &lhs_reverse_path; - handler.mismatch_rhs_reverse_path_ = &rhs_reverse_path; - if (handler.CompareAny(lhs, rhs)) { - return std::nullopt; - } - using reflection::AccessPath; - reflection::AccessPath lhs_path = - AccessPath::FromSteps(lhs_reverse_path.rbegin(), lhs_reverse_path.rend()); - reflection::AccessPath rhs_path = - AccessPath::FromSteps(rhs_reverse_path.rbegin(), rhs_reverse_path.rend()); - return reflection::AccessPathPair(lhs_path, rhs_path); -} - -TVM_FFI_STATIC_INIT_BLOCK() { - namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("ffi.GetFirstStructuralMismatch", StructuralEqual::GetFirstMismatch); - // ensure the type attribute column is presented in the system even if it is empty. - refl::EnsureTypeAttrColumn("__s_equal__"); -} - -} // namespace ffi -} // namespace tvm diff --git a/ffi/src/ffi/extra/structural_hash.cc b/ffi/src/ffi/extra/structural_hash.cc deleted file mode 100644 index f6463afa9cff..000000000000 --- a/ffi/src/ffi/extra/structural_hash.cc +++ /dev/null @@ -1,317 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -/* - * \file src/ffi/reflection/structural_equal.cc - * - * \brief Structural equal implementation. - */ -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include - -namespace tvm { -namespace ffi { -/** - * \brief Internal Handler class for structural hash. - */ -class StructuralHashHandler { - public: - StructuralHashHandler() = default; - - uint64_t HashAny(ffi::Any src) { - using ffi::details::AnyUnsafe; - const TVMFFIAny* src_data = AnyUnsafe::TVMFFIAnyPtrFromAny(src); - - if (src_data->type_index < TypeIndex::kTVMFFIStaticObjectBegin) { - // specially handle nan for float, as there can be multiple representations of nan - // make sure they map to the same hash value - if (src_data->type_index == TypeIndex::kTVMFFIFloat && std::isnan(src_data->v_float64)) { - TVMFFIAny temp = *src_data; - temp.v_float64 = std::numeric_limits::quiet_NaN(); - return details::StableHashCombine(temp.type_index, temp.v_uint64); - } - if (src_data->type_index == TypeIndex::kTVMFFISmallStr) { - // for small string, we use the same type key hash as normal string - // so heap allocated string and on stack string will have the same hash - return details::StableHashCombine(TypeIndex::kTVMFFIStr, - details::StableHashSmallStrBytes(src_data)); - } - // this is POD data, we can just hash the value - return details::StableHashCombine(src_data->type_index, src_data->v_uint64); - } - - switch (src_data->type_index) { - case TypeIndex::kTVMFFIStr: - case TypeIndex::kTVMFFIBytes: { - // return same hash as AnyHash - const details::BytesObjBase* src_str = - AnyUnsafe::CopyFromAnyViewAfterCheck(src); - return details::StableHashCombine(src_data->type_index, - details::StableHashBytes(src_str->data, src_str->size)); - } - case TypeIndex::kTVMFFIArray: { - return HashArray(AnyUnsafe::MoveFromAnyAfterCheck>(std::move(src))); - } - case TypeIndex::kTVMFFIMap: { - return HashMap(AnyUnsafe::MoveFromAnyAfterCheck>(std::move(src))); - } - case TypeIndex::kTVMFFIShape: { - return HashShape(AnyUnsafe::MoveFromAnyAfterCheck(std::move(src))); - } - case TypeIndex::kTVMFFITensor: { - return HashTensor(AnyUnsafe::MoveFromAnyAfterCheck(std::move(src))); - } - default: { - return HashObject(AnyUnsafe::MoveFromAnyAfterCheck(std::move(src))); - } - } - } - - uint64_t HashObject(ObjectRef obj) { - // NOTE: invariant: lhs and rhs are already the same type - const TVMFFITypeInfo* type_info = TVMFFIGetTypeInfo(obj->type_index()); - if (type_info->metadata == nullptr) { - TVM_FFI_THROW(TypeError) << "Type metadata is not set for type `" - << String(type_info->type_key) - << "`, so StructuralHash is not supported for this type"; - } - if (type_info->metadata->structural_eq_hash_kind == kTVMFFISEqHashKindUnsupported) { - TVM_FFI_THROW(TypeError) << "_type_s_eq_hash_kind is not set for type `" - << String(type_info->type_key) - << "`, so StructuralHash is not supported for this type"; - } - - auto structural_eq_hash_kind = type_info->metadata->structural_eq_hash_kind; - if (structural_eq_hash_kind == kTVMFFISEqHashKindUnsupported) { - // Fallback to pointer hash - return std::hash()(obj.get()); - } - // return recored hash value if it is already computed - auto it = hash_memo_.find(obj); - if (it != hash_memo_.end()) { - return it->second; - } - - static reflection::TypeAttrColumn custom_s_hash = reflection::TypeAttrColumn("__s_hash__"); - - // compute the hash value - uint64_t hash_value = obj->GetTypeKeyHash(); - if (custom_s_hash[type_info->type_index] == nullptr) { - // go over the content and hash the fields - reflection::ForEachFieldInfo(type_info, [&](const TVMFFIFieldInfo* field_info) { - // skip fields that are marked as structural eq hash ignore - if (!(field_info->flags & kTVMFFIFieldFlagBitMaskSEqHashIgnore)) { - // get the field value from both side - reflection::FieldGetter getter(field_info); - Any field_value = getter(obj); - // field is in def region, enable free var mapping - if (field_info->flags & kTVMFFIFieldFlagBitMaskSEqHashDef) { - bool allow_free_var = true; - std::swap(allow_free_var, map_free_vars_); - hash_value = details::StableHashCombine(hash_value, HashAny(field_value)); - std::swap(allow_free_var, map_free_vars_); - } else { - hash_value = details::StableHashCombine(hash_value, HashAny(field_value)); - } - } - }); - } else { - if (s_hash_callback_ == nullptr) { - s_hash_callback_ = - ffi::Function::FromTyped([this](AnyView val, uint64_t init_hash, bool def_region) { - if (def_region) { - bool allow_free_var = true; - std::swap(allow_free_var, map_free_vars_); - uint64_t hash_value = HashAny(val); - std::swap(allow_free_var, map_free_vars_); - return details::StableHashCombine(init_hash, hash_value); - } else { - return details::StableHashCombine(init_hash, HashAny(val)); - } - }); - } - hash_value = custom_s_hash[type_info->type_index] - .cast()(obj, hash_value, s_hash_callback_) - .cast(); - } - - if (structural_eq_hash_kind == kTVMFFISEqHashKindFreeVar) { - if (map_free_vars_) { - // use lexical order of free var and its type - hash_value = details::StableHashCombine(hash_value, free_var_counter_++); - } else { - // Fallback to pointer hash, we are not mapping free var. - hash_value = std::hash()(obj.get()); - } - } - // if it is a DAG node, also record the lexical order of graph counter - // this helps to distinguish DAG from trees. - if (structural_eq_hash_kind == kTVMFFISEqHashKindDAGNode) { - hash_value = details::StableHashCombine(hash_value, graph_node_counter_++); - } - // record the hash value for this object - hash_memo_[obj] = hash_value; - return hash_value; - } - - uint64_t HashArray(Array arr) { - uint64_t hash_value = details::StableHashCombine(arr->GetTypeKeyHash(), arr.size()); - for (size_t i = 0; i < arr.size(); ++i) { - hash_value = details::StableHashCombine(hash_value, HashAny(arr[i])); - } - return hash_value; - } - - // Find an order independent hash value for a given Any. - // Order independent hash value means the hash value will remain stable independent - // of the order we hash the content at the current context. - // This property is needed to support stable hash for map. - std::optional FindOrderIndependentHash(Any src) { - using ffi::details::AnyUnsafe; - const TVMFFIAny* src_data = AnyUnsafe::TVMFFIAnyPtrFromAny(src); - - if (src_data->type_index < TypeIndex::kTVMFFIStaticObjectBegin) { - if (src_data->type_index == TypeIndex::kTVMFFISmallStr) { - // for small string, we use the same type key hash as normal string - // so heap allocated string and on stack string will have the same hash - return details::StableHashCombine( - TypeIndex::kTVMFFIStr, - details::StableHashBytes(src_data->v_bytes, src_data->small_str_len)); - } - // this is POD data, we can just hash the value - return details::StableHashCombine(src_data->type_index, src_data->v_uint64); - } else { - if (src_data->type_index == TypeIndex::kTVMFFIStr || - src_data->type_index == TypeIndex::kTVMFFIBytes) { - const details::BytesObjBase* src_str = - AnyUnsafe::CopyFromAnyViewAfterCheck(src); - // return same hash as AnyHash - return details::StableHashCombine(src_data->type_index, - details::StableHashBytes(src_str->data, src_str->size)); - } else { - // if the hash of the object is already computed, return it - auto it = hash_memo_.find(src.cast()); - if (it != hash_memo_.end()) { - return it->second; - } - return std::nullopt; - } - } - } - - uint64_t HashMap(Map map) { - // Compute a deterministic hash value for the map. - uint64_t hash_value = details::StableHashCombine(map->GetTypeKeyHash(), map.size()); - std::vector> items; - for (auto [key, value] : map) { - // if we cannot find order independent hash, we skip the key - if (auto hash_key = FindOrderIndependentHash(key)) { - items.emplace_back(*hash_key, value); - } - } - // sort the items by the hash key, so the hash value is deterministic - // and independent of the order of insertion - std::sort(items.begin(), items.end(), - [](const auto& a, const auto& b) { return a.first < b.first; }); - - for (size_t i = 0; i < items.size();) { - size_t k = i + 1; - for (; k < items.size() && items[k].first == items[i].first; ++k) { - } - // detect ties, which are rare, but we need to skip value hash during ties - // to make sure that the hash value is deterministic. - if (k == i + 1) { - // no ties, we just hash the key and value - hash_value = details::StableHashCombine(hash_value, items[i].first); - hash_value = details::StableHashCombine(hash_value, HashAny(items[i].second)); - } else { - // ties occur, we skip the value hash to make sure that the hash value is deterministic. - hash_value = details::StableHashCombine(hash_value, items[i].first); - } - i = k; - } - return hash_value; - } - - uint64_t HashShape(Shape shape) { - uint64_t hash_value = details::StableHashCombine(shape->GetTypeKeyHash(), shape.size()); - for (size_t i = 0; i < shape.size(); ++i) { - hash_value = details::StableHashCombine(hash_value, shape[i]); - } - return hash_value; - } - - uint64_t HashTensor(Tensor tensor) { - uint64_t hash_value = details::StableHashCombine(tensor->GetTypeKeyHash(), tensor->ndim); - for (int i = 0; i < tensor->ndim; ++i) { - hash_value = details::StableHashCombine(hash_value, tensor->shape[i]); - } - TVMFFIAny temp; - temp.v_uint64 = 0; - temp.v_dtype = tensor->dtype; - hash_value = details::StableHashCombine(hash_value, temp.v_int64); - - if (!skip_tensor_content_) { - TVM_FFI_ICHECK_EQ(tensor->device.device_type, kDLCPU) << "can only hash CPU tensor"; - TVM_FFI_ICHECK(tensor.IsContiguous()) << "Can only hash contiguous tensor"; - size_t data_size = GetDataSize(*(tensor.operator->())); - uint64_t data_hash = - details::StableHashBytes(static_cast(tensor->data), data_size); - hash_value = details::StableHashCombine(hash_value, data_hash); - } - return hash_value; - } - - bool map_free_vars_{false}; - bool skip_tensor_content_{false}; - // free var counter. - uint32_t free_var_counter_{0}; - // graph node counter. - uint32_t graph_node_counter_{0}; - // lazily initialize custom hash function - ffi::Function s_hash_callback_ = nullptr; - // map from lhs to rhs - std::unordered_map hash_memo_; -}; - -uint64_t StructuralHash::Hash(const Any& value, bool map_free_vars, bool skip_tensor_content) { - StructuralHashHandler handler; - handler.map_free_vars_ = map_free_vars; - handler.skip_tensor_content_ = skip_tensor_content; - return handler.HashAny(value); -} - -TVM_FFI_STATIC_INIT_BLOCK() { - namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("ffi.StructuralHash", StructuralHash::Hash); - refl::EnsureTypeAttrColumn("__s_hash__"); -} - -} // namespace ffi -} // namespace tvm diff --git a/ffi/src/ffi/extra/testing.cc b/ffi/src/ffi/extra/testing.cc deleted file mode 100644 index 3d9501d8c460..000000000000 --- a/ffi/src/ffi/extra/testing.cc +++ /dev/null @@ -1,133 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -// This file is used for testing the FFI API. -#include -#include -#include -#include -#include - -#include -#include -#include - -namespace tvm { -namespace ffi { - -// Step 1: Define the object class (stores the actual data) -class TestIntPairObj : public tvm::ffi::Object { - public: - int64_t a; - int64_t b; - - TestIntPairObj() = default; - TestIntPairObj(int64_t a, int64_t b) : a(a), b(b) {} - - // Required: declare type information - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("testing.TestIntPair", TestIntPairObj, tvm::ffi::Object); -}; - -// Step 2: Define the reference wrapper (user-facing interface) -class TestIntPair : public tvm::ffi::ObjectRef { - public: - // Constructor - explicit TestIntPair(int64_t a, int64_t b) { - data_ = tvm::ffi::make_object(a, b); - } - - // Required: define object reference methods - TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(TestIntPair, tvm::ffi::ObjectRef, TestIntPairObj); -}; - -TVM_FFI_STATIC_INIT_BLOCK() { - namespace refl = tvm::ffi::reflection; - refl::ObjectDef() - .def_ro("a", &TestIntPairObj::a) - .def_ro("b", &TestIntPairObj::b) - .def_static("__create__", - [](int64_t a, int64_t b) -> TestIntPair { return TestIntPair(a, b); }); -} - -class TestObjectBase : public Object { - public: - int64_t v_i64; - double v_f64; - String v_str; - - int64_t AddI64(int64_t other) const { return v_i64 + other; } - - // declare as one slot, with float as overflow - static constexpr bool _type_mutable = true; - static constexpr uint32_t _type_child_slots = 1; - TVM_FFI_DECLARE_OBJECT_INFO("testing.TestObjectBase", TestObjectBase, Object); -}; - -class TestObjectDerived : public TestObjectBase { - public: - Map v_map; - Array v_array; - - // declare as one slot, with float as overflow - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("testing.TestObjectDerived", TestObjectDerived, TestObjectBase); -}; - -TVM_FFI_NO_INLINE void TestRaiseError(String kind, String msg) { - // keep name and no liner for testing traceback - throw ffi::Error(kind, msg, TVMFFITraceback(__FILE__, __LINE__, TVM_FFI_FUNC_SIG, 0)); -} - -TVM_FFI_NO_INLINE void TestApply(PackedArgs args, Any* ret) { - // keep name and no liner for testing traceback - auto f = args[0].cast(); - f.CallPacked(args.Slice(1), ret); -} - -TVM_FFI_STATIC_INIT_BLOCK() { - namespace refl = tvm::ffi::reflection; - - refl::ObjectDef() - .def_rw("v_i64", &TestObjectBase::v_i64, refl::DefaultValue(10), "i64 field") - .def_ro("v_f64", &TestObjectBase::v_f64, refl::DefaultValue(10.0)) - .def_rw("v_str", &TestObjectBase::v_str, refl::DefaultValue("hello")) - .def("add_i64", &TestObjectBase::AddI64, "add_i64 method"); - - refl::ObjectDef() - .def_ro("v_map", &TestObjectDerived::v_map) - .def_ro("v_array", &TestObjectDerived::v_array); - - refl::GlobalDef() - .def("testing.test_raise_error", TestRaiseError) - .def_packed("testing.nop", [](PackedArgs args, Any* ret) {}) - .def_packed("testing.echo", [](PackedArgs args, Any* ret) { *ret = args[0]; }) - .def_packed("testing.apply", TestApply) - .def("testing.run_check_signal", - [](int nsec) { - for (int i = 0; i < nsec; ++i) { - if (TVMFFIEnvCheckSignals() != 0) { - throw ffi::EnvErrorAlreadySet(); - } - std::this_thread::sleep_for(std::chrono::seconds(1)); - } - std::cout << "Function finished without catching signal" << std::endl; - }) - .def("testing.object_use_count", [](const Object* obj) { return obj->use_count(); }); -} - -} // namespace ffi -} // namespace tvm diff --git a/ffi/src/ffi/function.cc b/ffi/src/ffi/function.cc deleted file mode 100644 index b1bee7ee506c..000000000000 --- a/ffi/src/ffi/function.cc +++ /dev/null @@ -1,229 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -/* - * \file src/ffi/function.cc - * \brief Function call registry and safecall context - */ -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -namespace tvm { -namespace ffi { - -/*! - * \brief Global function table. - * - - * \note We do not use mutex to guard updating of GlobalFunctionTable - * - * The assumption is that updating of GlobalFunctionTable will be done - * in the main thread during initialization or loading, or - * explicitly locked from the caller. - * - * Then the followup code will leverage the information - */ -class GlobalFunctionTable { - public: - // Note: this class is hidden from the public API, so we just - // use it as a private class as ObjectRef - class Entry : public Object, public TVMFFIMethodInfo { - public: - String name_data; - String doc_data; - String type_schema_data; - ffi::Function func_data; - - explicit Entry(const TVMFFIMethodInfo* method_info) { - // make copy of the metadata - name_data = String(method_info->name.data, method_info->name.size); - doc_data = String(method_info->doc.data, method_info->doc.size); - type_schema_data = String(method_info->type_schema.data, method_info->type_schema.size); - func_data = AnyView::CopyFromTVMFFIAny(method_info->method).cast(); - this->SyncMethodInfo(method_info->flags); - // no need to update method pointer as it would remain the same as func and we retained - } - explicit Entry(String name, ffi::Function func) : name_data(name), func_data(func) { - this->SyncMethodInfo(kTVMFFIFieldFlagBitMaskIsStaticMethod); - } - - private: - void SyncMethodInfo(int64_t flags) { - this->flags = flags; - this->name = TVMFFIByteArray{name_data.data(), name_data.size()}; - this->doc = TVMFFIByteArray{doc_data.data(), doc_data.size()}; - this->type_schema = TVMFFIByteArray{type_schema_data.data(), type_schema_data.size()}; - } - }; - - void Update(const String& name, Function func, bool can_override) { - if (table_.count(name)) { - if (!can_override) { - TVM_FFI_THROW(RuntimeError) << "Global Function `" << name << "` is already registered"; - } - } - table_.Set(name, ObjectRef(make_object(name, func))); - } - - void Update(const TVMFFIMethodInfo* method_info, bool can_override) { - String name(method_info->name.data, method_info->name.size); - if (table_.count(name)) { - if (!can_override) { - TVM_FFI_LOG_AND_THROW(RuntimeError) - << "Global Function `" << name << "` is already registered, possible causes:\n" - << "- Two GlobalDef().def registrations for the same function \n" - << "Please remove the duplicate registration."; - } - } - table_.Set(name, ObjectRef(make_object(method_info))); - } - - bool Remove(const String& name) { - auto it = table_.find(name); - if (it == table_.end()) return false; - table_.erase(name); - return true; - } - - const Entry* Get(const String& name) { - auto it = table_.find(name); - if (it == table_.end()) return nullptr; - const Object* obj = (*it).second.cast(); - return static_cast(obj); - } - - Array ListNames() const { - Array names; - names.reserve(table_.size()); - for (const auto& kv : table_) { - names.push_back(kv.first); - } - return names; - } - - static GlobalFunctionTable* Global() { - // We deliberately create a new instance via raw new - // This is because GlobalFunctionTable can contain callbacks into - // the host language (Python) and the resource can become invalid - // indeterministic order of destruction and forking. - // The resources will only be recycled during program exit. - static GlobalFunctionTable* inst = new GlobalFunctionTable(); - return inst; - } - - private: - Map table_; -}; -} // namespace ffi -} // namespace tvm - -int TVMFFIFunctionCreate(void* self, TVMFFISafeCallType safe_call, void (*deleter)(void* self), - TVMFFIObjectHandle* out) { - TVM_FFI_SAFE_CALL_BEGIN(); - tvm::ffi::Function func = tvm::ffi::Function::FromExternC(self, safe_call, deleter); - *out = tvm::ffi::details::ObjectUnsafe::MoveObjectRefToTVMFFIObjectPtr(std::move(func)); - TVM_FFI_SAFE_CALL_END(); -} - -int TVMFFIAnyViewToOwnedAny(const TVMFFIAny* any_view, TVMFFIAny* out) { - TVM_FFI_SAFE_CALL_BEGIN(); - tvm::ffi::Any result(*reinterpret_cast(any_view)); - *out = tvm::ffi::details::AnyUnsafe::MoveAnyToTVMFFIAny(std::move(result)); - TVM_FFI_SAFE_CALL_END(); -} - -int TVMFFIFunctionSetGlobal(const TVMFFIByteArray* name, TVMFFIObjectHandle f, int override) { - using namespace tvm::ffi; - TVM_FFI_SAFE_CALL_BEGIN(); - String name_str(name->data, name->size); - GlobalFunctionTable::Global()->Update(name_str, GetRef(static_cast(f)), - override != 0); - TVM_FFI_SAFE_CALL_END(); -} - -int TVMFFIFunctionSetGlobalFromMethodInfo(const TVMFFIMethodInfo* method_info, int override) { - using namespace tvm::ffi; - TVM_FFI_SAFE_CALL_BEGIN(); - GlobalFunctionTable::Global()->Update(method_info, override != 0); - TVM_FFI_SAFE_CALL_END(); -} - -int TVMFFIFunctionGetGlobal(const TVMFFIByteArray* name, TVMFFIObjectHandle* out) { - using namespace tvm::ffi; - TVM_FFI_SAFE_CALL_BEGIN(); - String name_str(name->data, name->size); - const GlobalFunctionTable::Entry* fp = GlobalFunctionTable::Global()->Get(name_str); - if (fp != nullptr) { - tvm::ffi::Function func(fp->func_data); - *out = tvm::ffi::details::ObjectUnsafe::MoveObjectRefToTVMFFIObjectPtr(std::move(func)); - } else { - *out = nullptr; - } - TVM_FFI_SAFE_CALL_END(); -} - -int TVMFFIFunctionCall(TVMFFIObjectHandle func, TVMFFIAny* args, int32_t num_args, - TVMFFIAny* result) { - using namespace tvm::ffi; -#ifdef _MSC_VER - // Avoid tail call optimization - // in MSVC many cases python symbols are hidden, so we need this function symbol - // to be in the call frame to reliably detect the ffi boundary - volatile int ret = reinterpret_cast(func)->safe_call(func, args, num_args, result); - return ret; -#else - // NOTE: this is a tail call - return reinterpret_cast(func)->safe_call(func, args, num_args, result); -#endif -} - -TVM_FFI_STATIC_INIT_BLOCK() { - namespace refl = tvm::ffi::reflection; - refl::GlobalDef() - .def("ffi.FunctionRemoveGlobal", - [](const tvm::ffi::String& name) -> bool { - return tvm::ffi::GlobalFunctionTable::Global()->Remove(name); - }) - .def("ffi.FunctionListGlobalNamesFunctor", - []() { - // NOTE: we return functor instead of array - // so list global function names do not need to depend on array - // this is because list global function names usually is a core api that happens - // before array ffi functions are available. - tvm::ffi::Array names = - tvm::ffi::GlobalFunctionTable::Global()->ListNames(); - auto return_functor = [names](int64_t i) -> tvm::ffi::Any { - if (i < 0) { - return names.size(); - } else { - return names[i]; - } - }; - return tvm::ffi::Function::FromTyped(return_functor); - }) - .def("ffi.String", [](tvm::ffi::String val) -> tvm::ffi::String { return val; }) - .def("ffi.Bytes", [](tvm::ffi::Bytes val) -> tvm::ffi::Bytes { return val; }); -} diff --git a/ffi/src/ffi/object.cc b/ffi/src/ffi/object.cc deleted file mode 100644 index 292c8e913f1d..000000000000 --- a/ffi/src/ffi/object.cc +++ /dev/null @@ -1,513 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -/* - * \file src/ffi/object.cc - * \brief Registry to record dynamic types - */ -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include - -namespace tvm { -namespace ffi { - -/*! - * \brief Global registry that manages - * - * \note We do not use mutex to guard updating of TypeTable - * - * The assumption is that updating of TypeTable will be done - * in the main thread during initialization or loading, or - * explicitly locked from the caller. - * - * Then the followup code will leverage the information - */ -class TypeTable { - public: - /*! \brief Type information */ - struct Entry : public TypeInfo { - /*! \brief stored type key */ - String type_key_data; - /*! \brief acenstor information */ - std::vector type_acenstors_data; - /*! \brief type fields informaton */ - std::vector type_fields_data; - /*! \brief type methods informaton */ - std::vector type_methods_data; - /*! \brief extra information */ - TVMFFITypeMetadata metadata_data; - // NOTE: the indices in [index, index + num_reserved_slots) are - // reserved for the child-class of this type. - /*! \brief Total number of slots reserved for the type and its children. */ - int32_t num_slots; - /*! \brief number of allocated child slots. */ - int32_t allocated_slots; - /*! \brief Whether child can overflow. */ - bool child_slots_can_overflow{true}; - - Entry(int32_t type_index, int32_t type_depth, String type_key, int32_t num_slots, - bool child_slots_can_overflow, const Entry* parent) { - // setup fields in the class - this->type_key_data = std::move(type_key); - this->num_slots = num_slots; - this->allocated_slots = 1; - this->child_slots_can_overflow = child_slots_can_overflow; - // set up type acenstors information - if (type_depth != 0) { - TVM_FFI_ICHECK_NOTNULL(parent); - TVM_FFI_ICHECK_EQ(type_depth, parent->type_depth + 1); - type_acenstors_data.resize(type_depth); - // copy over parent's type information - for (int32_t i = 0; i < parent->type_depth; ++i) { - type_acenstors_data[i] = parent->type_acenstors[i]; - } - // set last type information to be parent - type_acenstors_data[parent->type_depth] = parent; - } - // initialize type info: no change to type_key and type_acenstors fields - // after this line - this->type_index = type_index; - this->type_depth = type_depth; - this->type_key = TVMFFIByteArray{this->type_key_data.data(), this->type_key_data.length()}; - this->type_key_hash = std::hash()(this->type_key_data); - this->type_acenstors = type_acenstors_data.data(); - // initialize the reflection information - this->num_fields = 0; - this->num_methods = 0; - this->fields = nullptr; - this->methods = nullptr; - this->metadata = nullptr; - } - }; - - struct TypeAttrColumnData : public TVMFFITypeAttrColumn { - std::vector data_; - }; - - int32_t GetOrAllocTypeIndex(String type_key, int32_t static_type_index, int32_t type_depth, - int32_t num_child_slots, bool child_slots_can_overflow, - int32_t parent_type_index) { - auto it = type_key2index_.find(type_key); - if (it != type_key2index_.end()) { - return type_table_[(*it).second]->type_index; - } - - // get parent's entry - Entry* parent = [&]() -> Entry* { - if (parent_type_index < 0) return nullptr; - // try to allocate from parent's type table. - TVM_FFI_ICHECK_LT(parent_type_index, type_table_.size()) - << " type_key=" << type_key << ", static_index=" << static_type_index; - return type_table_[parent_type_index].get(); - }(); - - // get allocated index - int32_t allocated_tindex = [&]() { - // Step 0: static allocation - if (static_type_index >= 0) { - TVM_FFI_ICHECK_LT(static_type_index, type_table_.size()); - TVM_FFI_ICHECK(type_table_[static_type_index] == nullptr) - << "Conflicting static index " << static_type_index << " between " - << ToStringView(type_table_[static_type_index]->type_key) << " and " << type_key; - return static_type_index; - } - TVM_FFI_ICHECK_NOTNULL(parent); - int num_slots = num_child_slots + 1; - if (parent->allocated_slots + num_slots <= parent->num_slots) { - // allocate the slot from parent's reserved pool - int32_t allocated_tindex = parent->type_index + parent->allocated_slots; - // update parent's state - parent->allocated_slots += num_slots; - return allocated_tindex; - } - // Step 2: allocate from overflow - TVM_FFI_ICHECK(parent->child_slots_can_overflow) - << "Reach maximum number of sub-classes for " << ToStringView(parent->type_key); - // allocate new entries. - int32_t allocated_tindex = type_counter_; - type_counter_ += num_slots; - TVM_FFI_ICHECK_LE(type_table_.size(), type_counter_); - type_table_.reserve(type_counter_); - // resize type table - while (static_cast(type_table_.size()) < type_counter_) { - type_table_.emplace_back(nullptr); - } - return allocated_tindex; - }(); - - // if parent cannot overflow, then this class cannot. - if (parent != nullptr && !(parent->child_slots_can_overflow)) { - child_slots_can_overflow = false; - } - // total number of slots include the type itself. - - if (parent != nullptr) { - TVM_FFI_ICHECK_GT(allocated_tindex, parent->type_index); - } - - type_table_[allocated_tindex] = - std::make_unique(allocated_tindex, type_depth, type_key, num_child_slots + 1, - child_slots_can_overflow, parent); - // update the key2index mapping. - type_key2index_.Set(type_key, allocated_tindex); - return allocated_tindex; - } - - int32_t TypeKeyToIndex(const TVMFFIByteArray* type_key) { - String type_key_str(type_key->data, type_key->size); - auto it = type_key2index_.find(type_key_str); - TVM_FFI_ICHECK(it != type_key2index_.end()) << "Cannot find type `" << type_key_str << "`"; - return static_cast((*it).second); - } - - Entry* GetTypeEntry(int32_t type_index) { - Entry* entry = nullptr; - if (type_index >= 0 && static_cast(type_index) < type_table_.size()) { - entry = type_table_[type_index].get(); - } - TVM_FFI_ICHECK(entry != nullptr) << "Cannot find type info for type_index=" << type_index; - return entry; - } - - void RegisterTypeField(int32_t type_index, const TVMFFIFieldInfo* info) { - Entry* entry = GetTypeEntry(type_index); - TVMFFIFieldInfo field_data = *info; - field_data.name = this->CopyString(info->name); - field_data.doc = this->CopyString(info->doc); - field_data.type_schema = this->CopyString(info->type_schema); - if (info->flags & kTVMFFIFieldFlagBitMaskHasDefault) { - field_data.default_value = - this->CopyAny(AnyView::CopyFromTVMFFIAny(info->default_value)).CopyToTVMFFIAny(); - } else { - field_data.default_value = AnyView(nullptr).CopyToTVMFFIAny(); - } - entry->type_fields_data.push_back(field_data); - // refresh ptr as the data can change - entry->fields = entry->type_fields_data.data(); - entry->num_fields = static_cast(entry->type_fields_data.size()); - } - - void RegisterTypeMethod(int32_t type_index, const TVMFFIMethodInfo* info) { - Entry* entry = GetTypeEntry(type_index); - TVMFFIMethodInfo method_data = *info; - method_data.name = this->CopyString(info->name); - method_data.doc = this->CopyString(info->doc); - method_data.type_schema = this->CopyString(info->type_schema); - method_data.method = this->CopyAny(AnyView::CopyFromTVMFFIAny(info->method)).CopyToTVMFFIAny(); - entry->type_methods_data.push_back(method_data); - entry->methods = entry->type_methods_data.data(); - entry->num_methods = static_cast(entry->type_methods_data.size()); - } - - void RegisterTypeMetadata(int32_t type_index, const TVMFFITypeMetadata* metadata) { - Entry* entry = GetTypeEntry(type_index); - if (entry->metadata != nullptr) { - TVM_FFI_LOG_AND_THROW(RuntimeError) - << "Overriding " << ToStringView(entry->type_key) << ", possible causes:\n" - << "- two ObjectDef() calls for the same T \n" - << "- when we forget to assign _type_key to ObjectRef that inherits from T\n" - << "- another type with the same key is already registered\n" - << "Cross check the reflection registration."; - } - entry->metadata_data = *metadata; - entry->metadata_data.doc = this->CopyString(metadata->doc); - entry->metadata = &(entry->metadata_data); - } - - void RegisterTypeAttr(int32_t type_index, const TVMFFIByteArray* name, const TVMFFIAny* value) { - AnyView value_view = AnyView::CopyFromTVMFFIAny(*value); - String name_str(*name); - size_t column_index = 0; - auto it = type_attr_name_to_column_index_.find(name_str); - if (it == type_attr_name_to_column_index_.end()) { - column_index = type_attr_columns_.size(); - type_attr_columns_.emplace_back(std::make_unique()); - type_attr_name_to_column_index_.Set(name_str, column_index); - } else { - column_index = (*it).second; - } - TypeAttrColumnData* column = type_attr_columns_[column_index].get(); - if (column->data_.size() < static_cast(type_index + 1)) { - column->data_.resize(type_index + 1, Any(nullptr)); - column->data = reinterpret_cast(column->data_.data()); - column->size = column->data_.size(); - } - if (type_index == kTVMFFINone) return; - if (column->data_[type_index] != nullptr) { - TVM_FFI_THROW(RuntimeError) << "Type attribute `" << name_str << "` is already set for type `" - << TypeIndexToTypeKey(type_index) << "`"; - } - column->data_[type_index] = value_view; - } - const TVMFFITypeAttrColumn* GetTypeAttrColumn(const TVMFFIByteArray* name) { - String name_str(*name); - auto it = type_attr_name_to_column_index_.find(name_str); - if (it == type_attr_name_to_column_index_.end()) return nullptr; - return type_attr_columns_[(*it).second].get(); - } - - void Dump(int min_children_count) { - std::vector num_children(type_table_.size(), 0); - // expected child slots compute the expected slots - // based on the current child slot setting - std::vector expected_child_slots(type_table_.size(), 0); - // reverse accumulation so we can get total counts in a bottom-up manner. - for (auto it = type_table_.rbegin(); it != type_table_.rend(); ++it) { - const Entry* ptr = it->get(); - if (ptr != nullptr && ptr->type_depth != 0) { - int parent_index = ptr->type_acenstors[ptr->type_depth - 1]->type_index; - num_children[parent_index] += num_children[ptr->type_index] + 1; - if (expected_child_slots[ptr->type_index] + 1 < ptr->num_slots) { - expected_child_slots[ptr->type_index] = ptr->num_slots - 1; - } - expected_child_slots[parent_index] += expected_child_slots[ptr->type_index] + 1; - } - } - - for (const auto& ptr : type_table_) { - if (ptr != nullptr && num_children[ptr->type_index] >= min_children_count) { - std::cerr << '[' << ptr->type_index << "]\t" << ToStringView(ptr->type_key); - if (ptr->type_depth != 0) { - int32_t parent_index = ptr->type_acenstors[ptr->type_depth - 1]->type_index; - std::cerr << "\tparent=" << ToStringView(type_table_[parent_index]->type_key); - } else { - std::cerr << "\tparent=root"; - } - std::cerr << "\tnum_child_slots=" << ptr->num_slots - 1 - << "\tnum_children=" << num_children[ptr->type_index] - << "\texpected_child_slots=" << expected_child_slots[ptr->type_index] - << std::endl; - } - } - } - - static TypeTable* Global() { - static TypeTable inst; - return &inst; - } - - private: - TypeTable() { - type_table_.reserve(TypeIndex::kTVMFFIDynObjectBegin); - for (int32_t i = 0; i < TypeIndex::kTVMFFIDynObjectBegin; ++i) { - type_table_.emplace_back(nullptr); - } - // initialize the entry for object - this->GetOrAllocTypeIndex(String(Object::_type_key), Object::_type_index, Object::_type_depth, - Object::_type_child_slots, Object::_type_child_slots_can_overflow, - -1); - TVMFFITypeMetadata info; - info.total_size = sizeof(Object); - info.creator = nullptr; - info.doc = TVMFFIByteArray{nullptr, 0}; - RegisterTypeMetadata(Object::_type_index, &info); - // reserve the static types - ReserveBuiltinTypeIndex(StaticTypeKey::kTVMFFINone, TypeIndex::kTVMFFINone); - ReserveBuiltinTypeIndex(StaticTypeKey::kTVMFFIInt, TypeIndex::kTVMFFIInt); - ReserveBuiltinTypeIndex(StaticTypeKey::kTVMFFIFloat, TypeIndex::kTVMFFIFloat); - ReserveBuiltinTypeIndex(StaticTypeKey::kTVMFFIBool, TypeIndex::kTVMFFIBool); - ReserveBuiltinTypeIndex(StaticTypeKey::kTVMFFIRawStr, TypeIndex::kTVMFFIRawStr); - ReserveBuiltinTypeIndex(StaticTypeKey::kTVMFFIOpaquePtr, TypeIndex::kTVMFFIOpaquePtr); - ReserveBuiltinTypeIndex(StaticTypeKey::kTVMFFIDataType, TypeIndex::kTVMFFIDataType); - ReserveBuiltinTypeIndex(StaticTypeKey::kTVMFFIDevice, TypeIndex::kTVMFFIDevice); - ReserveBuiltinTypeIndex(StaticTypeKey::kTVMFFIByteArrayPtr, TypeIndex::kTVMFFIByteArrayPtr); - ReserveBuiltinTypeIndex(StaticTypeKey::kTVMFFIObjectRValueRef, - TypeIndex::kTVMFFIObjectRValueRef); - ReserveBuiltinTypeIndex(StaticTypeKey::kTVMFFISmallStr, TypeIndex::kTVMFFISmallStr); - ReserveBuiltinTypeIndex(StaticTypeKey::kTVMFFISmallBytes, TypeIndex::kTVMFFISmallBytes); - // no need to reserve for object types as they will be registered - } - - void ReserveBuiltinTypeIndex(const char* type_key, int32_t static_type_index) { - this->GetOrAllocTypeIndex(String(type_key), static_type_index, 0, 0, false, -1); - } - - static ObjectPtr MakeInplaceString(const char* data, size_t length) { - ObjectPtr p = - make_inplace_array_object(length + 1); - static_assert(alignof(details::StringObj) % alignof(char) == 0); - static_assert(sizeof(details::StringObj) % alignof(char) == 0); - char* dest_data = reinterpret_cast(p.get()) + sizeof(details::StringObj); - p->data = dest_data; - p->size = length; - std::memcpy(dest_data, data, length); - dest_data[length] = '\0'; - return p; - } - - TVMFFIByteArray CopyString(TVMFFIByteArray str) { - if (str.size == 0) { - return TVMFFIByteArray{nullptr, 0}; - } - // use explicit object creation to ensure the space pointer to not move - auto str_obj = MakeInplaceString(str.data, str.size); - TVMFFIByteArray c_val{str_obj->data, str_obj->size}; - any_pool_.emplace_back(ObjectRef(std::move(str_obj))); - return c_val; - } - - AnyView CopyAny(Any val) { - AnyView view = AnyView(val); - any_pool_.emplace_back(std::move(val)); - return view; - } - - int64_t type_counter_{TypeIndex::kTVMFFIDynObjectBegin}; - std::vector> type_table_; - Map type_key2index_; - std::vector any_pool_; - // type attribute columns - std::vector> type_attr_columns_; - Map type_attr_name_to_column_index_; -}; - -/** - * \brief Opaque implementation - */ -class OpaqueObjectImpl : public Object, public TVMFFIOpaqueObjectCell { - public: - OpaqueObjectImpl(void* handle, void (*deleter)(void* handle)) : deleter_(deleter) { - this->handle = handle; - } - - void SetTypeIndex(int32_t type_index) { - details::ObjectUnsafe::GetHeader(this)->type_index = type_index; - } - - ~OpaqueObjectImpl() { - if (deleter_ != nullptr) { - deleter_(handle); - } - } - - private: - void (*deleter_)(void* handle); -}; - -} // namespace ffi -} // namespace tvm - -int TVMFFIObjectDecRef(TVMFFIObjectHandle handle) { - TVM_FFI_SAFE_CALL_BEGIN(); - tvm::ffi::details::ObjectUnsafe::DecRefObjectHandle(handle); - TVM_FFI_SAFE_CALL_END(); -} - -int TVMFFIObjectIncRef(TVMFFIObjectHandle handle) { - TVM_FFI_SAFE_CALL_BEGIN(); - tvm::ffi::details::ObjectUnsafe::IncRefObjectHandle(handle); - TVM_FFI_SAFE_CALL_END(); -} - -int TVMFFIObjectCreateOpaque(void* handle, int32_t type_index, void (*deleter)(void* handle), - TVMFFIObjectHandle* out) { - TVM_FFI_SAFE_CALL_BEGIN(); - if (type_index != kTVMFFIOpaquePyObject) { - TVM_FFI_THROW(RuntimeError) << "Only kTVMFFIOpaquePyObject is supported for now"; - } - // create initial opaque object - tvm::ffi::ObjectPtr p = - tvm::ffi::make_object(handle, deleter); - // need to set the type index after creation, because the set to RuntimeTypeIndex() - // happens after the constructor is called - p->SetTypeIndex(type_index); - *out = tvm::ffi::details::ObjectUnsafe::MoveObjectPtrToTVMFFIObjectPtr(std::move(p)); - TVM_FFI_SAFE_CALL_END(); -} - -int TVMFFITypeKeyToIndex(const TVMFFIByteArray* type_key, int32_t* out_tindex) { - TVM_FFI_SAFE_CALL_BEGIN(); - out_tindex[0] = tvm::ffi::TypeTable::Global()->TypeKeyToIndex(type_key); - TVM_FFI_SAFE_CALL_END(); -} - -int TVMFFITypeRegisterField(int32_t type_index, const TVMFFIFieldInfo* info) { - TVM_FFI_SAFE_CALL_BEGIN(); - tvm::ffi::TypeTable::Global()->RegisterTypeField(type_index, info); - TVM_FFI_SAFE_CALL_END(); -} - -int TVMFFITypeRegisterMethod(int32_t type_index, const TVMFFIMethodInfo* info) { - TVM_FFI_SAFE_CALL_BEGIN(); - tvm::ffi::TypeTable::Global()->RegisterTypeMethod(type_index, info); - TVM_FFI_SAFE_CALL_END(); -} - -int TVMFFITypeRegisterMetadata(int32_t type_index, const TVMFFITypeMetadata* metadata) { - TVM_FFI_SAFE_CALL_BEGIN(); - tvm::ffi::TypeTable::Global()->RegisterTypeMetadata(type_index, metadata); - TVM_FFI_SAFE_CALL_END(); -} - -int TVMFFITypeRegisterAttr(int32_t type_index, const TVMFFIByteArray* name, - const TVMFFIAny* value) { - TVM_FFI_SAFE_CALL_BEGIN(); - tvm::ffi::TypeTable::Global()->RegisterTypeAttr(type_index, name, value); - TVM_FFI_SAFE_CALL_END(); -} - -const TVMFFITypeAttrColumn* TVMFFIGetTypeAttrColumn(const TVMFFIByteArray* name) { - TVM_FFI_LOG_EXCEPTION_CALL_BEGIN(); - return tvm::ffi::TypeTable::Global()->GetTypeAttrColumn(name); - TVM_FFI_LOG_EXCEPTION_CALL_END(TVMFFIGetTypeAttrColumn); -} - -int32_t TVMFFITypeGetOrAllocIndex(const TVMFFIByteArray* type_key, int32_t static_type_index, - int32_t type_depth, int32_t num_child_slots, - int32_t child_slots_can_overflow, int32_t parent_type_index) { - TVM_FFI_LOG_EXCEPTION_CALL_BEGIN(); - tvm::ffi::String s_type_key(type_key->data, type_key->size); - return tvm::ffi::TypeTable::Global()->GetOrAllocTypeIndex( - s_type_key, static_type_index, type_depth, num_child_slots, child_slots_can_overflow, - parent_type_index); - TVM_FFI_LOG_EXCEPTION_CALL_END(TVMFFITypeGetOrAllocIndex); -} - -const TVMFFITypeInfo* TVMFFIGetTypeInfo(int32_t type_index) { - TVM_FFI_LOG_EXCEPTION_CALL_BEGIN(); - return tvm::ffi::TypeTable::Global()->GetTypeEntry(type_index); - TVM_FFI_LOG_EXCEPTION_CALL_END(TVMFFIGetTypeInfo); -} - -// string APIs, we blend into object.cc to keep things simple -int TVMFFIStringFromByteArray(const TVMFFIByteArray* input, TVMFFIAny* out) { - TVM_FFI_SAFE_CALL_BEGIN(); - // must set to none first - out->type_index = kTVMFFINone; - tvm::ffi::TypeTraits::MoveToAny(tvm::ffi::String(input->data, input->size), - out); - TVM_FFI_SAFE_CALL_END(); -} - -int TVMFFIBytesFromByteArray(const TVMFFIByteArray* input, TVMFFIAny* out) { - TVM_FFI_SAFE_CALL_BEGIN(); - // must set to none first - out->type_index = kTVMFFINone; - tvm::ffi::TypeTraits::MoveToAny(tvm::ffi::Bytes(input->data, input->size), out); - TVM_FFI_SAFE_CALL_END(); -} diff --git a/ffi/src/ffi/tensor.cc b/ffi/src/ffi/tensor.cc deleted file mode 100644 index d40828012fb1..000000000000 --- a/ffi/src/ffi/tensor.cc +++ /dev/null @@ -1,82 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -/* - * \file src/ffi/tensor.cc - * \brief Tensor C API implementation - */ -#include -#include -#include -#include - -namespace tvm { -namespace ffi { - -TVM_FFI_STATIC_INIT_BLOCK() { - namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def_packed("ffi.Shape", [](ffi::PackedArgs args, Any* ret) { - int64_t* mutable_data; - ObjectPtr shape = details::MakeEmptyShape(args.size(), &mutable_data); - for (int i = 0; i < args.size(); ++i) { - if (auto opt_int = args[i].try_cast()) { - mutable_data[i] = *opt_int; - } else { - TVM_FFI_THROW(ValueError) << "Expect shape to take list of int arguments"; - } - } - *ret = details::ObjectUnsafe::ObjectRefFromObjectPtr(shape); - }); -} - -} // namespace ffi -} // namespace tvm - -int TVMFFITensorFromDLPack(DLManagedTensor* from, int32_t min_alignment, int32_t require_contiguous, - TVMFFIObjectHandle* out) { - TVM_FFI_SAFE_CALL_BEGIN(); - tvm::ffi::Tensor tensor = - tvm::ffi::Tensor::FromDLPack(from, static_cast(min_alignment), require_contiguous); - *out = tvm::ffi::details::ObjectUnsafe::MoveObjectRefToTVMFFIObjectPtr(std::move(tensor)); - TVM_FFI_SAFE_CALL_END(); -} - -int TVMFFITensorFromDLPackVersioned(DLManagedTensorVersioned* from, int32_t min_alignment, - int32_t require_contiguous, TVMFFIObjectHandle* out) { - TVM_FFI_SAFE_CALL_BEGIN(); - tvm::ffi::Tensor tensor = tvm::ffi::Tensor::FromDLPackVersioned( - from, static_cast(min_alignment), require_contiguous); - *out = tvm::ffi::details::ObjectUnsafe::MoveObjectRefToTVMFFIObjectPtr(std::move(tensor)); - TVM_FFI_SAFE_CALL_END(); -} - -int TVMFFITensorToDLPack(TVMFFIObjectHandle from, DLManagedTensor** out) { - TVM_FFI_SAFE_CALL_BEGIN(); - *out = tvm::ffi::details::ObjectUnsafe::RawObjectPtrFromUnowned( - static_cast(from)) - ->ToDLPack(); - TVM_FFI_SAFE_CALL_END(); -} - -int TVMFFITensorToDLPackVersioned(TVMFFIObjectHandle from, DLManagedTensorVersioned** out) { - TVM_FFI_SAFE_CALL_BEGIN(); - *out = tvm::ffi::details::ObjectUnsafe::RawObjectPtrFromUnowned( - static_cast(from)) - ->ToDLPackVersioned(); - TVM_FFI_SAFE_CALL_END(); -} diff --git a/ffi/src/ffi/traceback.cc b/ffi/src/ffi/traceback.cc deleted file mode 100644 index 57638d704e3b..000000000000 --- a/ffi/src/ffi/traceback.cc +++ /dev/null @@ -1,188 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -/*! - * \file traceback.cc - * \brief Traceback implementation on non-windows platforms - * \note We use the term "traceback" to be consistent with python naming convention. - */ -#ifndef _MSC_VER - -#include "./traceback.h" - -#include -#include - -#if TVM_FFI_USE_LIBBACKTRACE - -#include -#include - -#include -#include -#include -#include - -#if TVM_FFI_BACKTRACE_ON_SEGFAULT -#include -#endif - -namespace tvm { -namespace ffi { -namespace { -void BacktraceCreateErrorCallback(void*, const char* msg, int) { - std::cerr << "Could not initialize backtrace state: " << msg << std::endl; -} - -backtrace_state* BacktraceCreate() { - return backtrace_create_state(nullptr, 1, BacktraceCreateErrorCallback, nullptr); -} - -static backtrace_state* _bt_state = BacktraceCreate(); - -std::string DemangleName(std::string name) { - int status = 0; - size_t length = name.size(); - char* demangled_name = abi::__cxa_demangle(name.c_str(), nullptr, &length, &status); - if (demangled_name && status == 0 && length > 0) { - name = demangled_name; - } - if (demangled_name) { - std::free(demangled_name); - } - return name; -} - -void BacktraceErrorCallback(void*, const char*, int) { - // do nothing -} - -void BacktraceSyminfoCallback(void* data, uintptr_t pc, const char* symname, uintptr_t, uintptr_t) { - auto str = reinterpret_cast(data); - - if (symname != nullptr) { - *str = DemangleName(symname); - } else { - std::ostringstream s; - s << "0x" << std::setfill('0') << std::setw(sizeof(uintptr_t) * 2) << std::hex << pc; - *str = s.str(); - } -} - -int BacktraceFullCallback(void* data, uintptr_t pc, const char* filename, int lineno, - const char* symbol) { - auto stack_trace = reinterpret_cast(data); - std::string symbol_str = ""; - if (symbol) { - symbol_str = DemangleName(symbol); - } else { - // see if syminfo gives anything - backtrace_syminfo(_bt_state, pc, BacktraceSyminfoCallback, BacktraceErrorCallback, &symbol_str); - } - symbol = symbol_str.data(); - if (stack_trace->ExceedTracebackLimit()) { - return 1; - } - if (stack_trace->stop_at_boundary && DetectFFIBoundary(filename, symbol)) { - return 1; - } - // skip extra frames - if (stack_trace->skip_frame_count > 0) { - stack_trace->skip_frame_count--; - return 0; - } - if (ShouldExcludeFrame(filename, symbol)) { - return 0; - } - stack_trace->Append(filename, symbol, lineno); - return 0; -} -} // namespace -} // namespace ffi -} // namespace tvm - -const TVMFFIByteArray* TVMFFITraceback(const char* filename, int lineno, const char* func, - int cross_ffi_boundary) { - // We collapse the traceback into a single function - // to simplify the traceback detection handling (since we need to detect TVMFFITraceback) - static thread_local std::string traceback_str; - static thread_local TVMFFIByteArray traceback_array; - // pass in current line as here so last line of traceback is always accurate - tvm::ffi::TracebackStorage traceback; - traceback.stop_at_boundary = cross_ffi_boundary == 0; - if (filename != nullptr && func != nullptr) { - // need to skip TVMFFITraceback and the caller function - // which is already included in filename and func - traceback.skip_frame_count = 2; - if (!tvm::ffi::ShouldExcludeFrame(filename, func)) { - traceback.Append(filename, func, lineno); - } - } - // libbacktrace eats memory if run on multiple threads at the same time, so we guard against it - if (tvm::ffi::_bt_state != nullptr) { - static std::mutex m; - std::lock_guard lock(m); - backtrace_full(tvm::ffi::_bt_state, 0, tvm::ffi::BacktraceFullCallback, - tvm::ffi::BacktraceErrorCallback, &traceback); - } - traceback_str = traceback.GetTraceback(); - traceback_array.data = traceback_str.data(); - traceback_array.size = traceback_str.size(); - return &traceback_array; -} - -#if TVM_FFI_BACKTRACE_ON_SEGFAULT -void TVMFFISegFaultHandler(int sig) { - // Technically we shouldn't do any allocation in a signal handler, but - // Backtrace may allocate. What's the worst it could do? We're already - // crashing. - const TVMFFIByteArray* traceback = TVMFFITraceback(nullptr, 0, nullptr, 1); - std::cerr << "!!!!!!! Segfault encountered !!!!!!!\n" - << std::string(traceback->data, traceback->size) << std::endl; - // Re-raise signal with default handler - struct sigaction act; - std::memset(&act, 0, sizeof(struct sigaction)); - act.sa_flags = SA_RESETHAND; - act.sa_handler = SIG_DFL; - sigaction(sig, &act, nullptr); - raise(sig); -} - -__attribute__((constructor)) void TVMFFIInstallSignalHandler(void) { - // this may override already installed signal handlers - std::signal(SIGSEGV, TVMFFISegFaultHandler); -} -#endif // TVM_FFI_BACKTRACE_ON_SEGFAULT -#else -// fallback implementation simply print out the last trace -const TVMFFIByteArray* TVMFFITraceback(const char* filename, int lineno, const char* func, - int cross_ffi_boundary) { - static thread_local std::string traceback_str; - static thread_local TVMFFIByteArray traceback_array; - std::ostringstream traceback_stream; - if (filename != nullptr && func != nullptr) { - // python style backtrace - traceback_stream << " File \"" << filename << "\", line " << lineno << ", in " << func << '\n'; - } - traceback_str = traceback_stream.str(); - traceback_array.data = traceback_str.data(); - traceback_array.size = traceback_str.size(); - return &traceback_array; -} -#endif // TVM_FFI_USE_LIBBACKTRACE -#endif // _MSC_VER diff --git a/ffi/src/ffi/traceback.h b/ffi/src/ffi/traceback.h deleted file mode 100644 index 710414490367..000000000000 --- a/ffi/src/ffi/traceback.h +++ /dev/null @@ -1,182 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -/*! - * \file traceback.h - * \brief Common headers for traceback. - * \note We use the term "traceback" to be consistent with python naming convention. - */ -#ifndef TVM_FFI_TRACEBACK_H_ -#define TVM_FFI_TRACEBACK_H_ - -#include - -#include -#include -#include -#include - -namespace tvm { -namespace ffi { - -#ifdef _MSC_VER -#pragma warning(push) -#pragma warning(disable : 4996) // std::getenv is unsafe -#endif - -inline int32_t GetTracebackLimit() { - if (const char* env = std::getenv("TVM_TRACEBACK_LIMIT")) { - return std::stoi(env); - } - return 512; -} - -#ifdef _MSC_VER -#pragma warning(pop) -#endif - -/*! - * \brief List frame patterns that should be excluded as they contain less information - */ -inline bool ShouldExcludeFrame(const char* filename, const char* symbol) { - if (symbol != nullptr) { - if (strncmp(symbol, "tvm::ffi::Function", 18) == 0) { - return true; - } - if (strncmp(symbol, "tvm::ffi::details::", 19) == 0) { - return true; - } - if (strncmp(symbol, "TVMFFITraceback", 15) == 0) { - return true; - } - if (strncmp(symbol, "TVMFFIErrorSetRaisedFromCStr", 28) == 0) { - return true; - } - // C++ stdlib frames - if (strncmp(symbol, "__libc_", 7) == 0) { - return true; - } - // libffi.so stack frames. These may also show up as numeric - // addresses with no symbol name. This could be improved in the - // future by using dladdr() to check whether an address is contained - // in libffi.so - if (strncmp(symbol, "ffi_call_", 9) == 0) { - return true; - } - } - if (filename) { - // Stack frames for TVM FFI - if (strstr(filename, "include/tvm/ffi/error.h") != nullptr) { - return true; - } - if (strstr(filename, "include/tvm/ffi/function_details.h") != nullptr) { - return true; - } - if (strstr(filename, "include/tvm/ffi/function.h") != nullptr) { - return true; - } - if (strstr(filename, "include/tvm/ffi/any.h") != nullptr) { - return true; - } - // C++ stdlib frames - if (strstr(filename, "include/c++/") != nullptr) { - return true; - } - } - return false; -} - -/** - * \brief List frames that should stop the traceback. - * \param filename The filename of the frame. - * \param symbol The symbol name of the frame. - * \return true if the frame should stop the traceback. - * \note We stop traceback at the FFI boundary. - */ -inline bool DetectFFIBoundary(const char* filename, const char* symbol) { - if (symbol != nullptr) { - if (strncmp(symbol, "TVMFFIFunctionCall", 18) == 0) { - return true; - } - // python ABI functions - if (strncmp(symbol, "slot_tp_call", 12) == 0) { - return true; - } - if (strncmp(symbol, "object_is_not_callable", 11) == 0) { - return true; - } - // Python interpreter stack frames - // we stop traceback at the Python interpreter stack frames - // since these frame will be handled from by the python side. - if (strncmp(symbol, "_Py", 3) == 0 || strncmp(symbol, "PyObject", 8) == 0) { - return true; - } - } - return false; -} - -/*! - * \brief storage to store traceback - */ -struct TracebackStorage { - std::vector lines; - /*! \brief Maximum size of the traceback. */ - size_t max_frame_size = GetTracebackLimit(); - /*! \brief Number of frames to skip. */ - size_t skip_frame_count = 0; - /*! \brief Whether to stop at the ffi boundary. */ - bool stop_at_boundary = true; - - void Append(const char* filename, const char* func, int lineno) { - // skip frames with empty filename - if (filename == nullptr) { - if (func != nullptr) { - if (strncmp(func, "0x0", 3) == 0) { - return; - } - if (strncmp(func, "", 9) == 0) { - return; - } - filename = ""; - } else { - return; - } - } - std::ostringstream trackeback_stream; - trackeback_stream << " File \"" << filename << "\""; - trackeback_stream << ", line " << lineno; - trackeback_stream << ", in " << func << '\n'; - lines.push_back(trackeback_stream.str()); - } - - bool ExceedTracebackLimit() const { return lines.size() >= max_frame_size; } - - // get traceback in the order of most recent call last - std::string GetTraceback() const { - std::string traceback; - for (auto it = lines.rbegin(); it != lines.rend(); ++it) { - traceback.insert(traceback.end(), it->begin(), it->end()); - } - return traceback; - } -}; - -} // namespace ffi -} // namespace tvm - -#endif // TVM_FFI_TRACEBACK_H_ diff --git a/ffi/src/ffi/traceback_win.cc b/ffi/src/ffi/traceback_win.cc deleted file mode 100644 index ae7d85dc6720..000000000000 --- a/ffi/src/ffi/traceback_win.cc +++ /dev/null @@ -1,142 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -/*! - * \file traceback_win.cc - * \brief Traceback implementation on windows platform - * \note We use the term "traceback" to be consistent with python naming convention. - */ -#ifdef _MSC_VER - -// clang-format off -#include -#include // NOLINT(*) -// clang-format on - -#include -#include - -#include -#include - -#include "./traceback.h" - -const TVMFFIByteArray* TVMFFITraceback(const char* filename, int lineno, const char* func, - int cross_ffi_boundary) { - static thread_local std::string traceback_str; - static thread_local TVMFFIByteArray traceback_array; - - // pass in current line as here so last line of traceback is always accurate - tvm::ffi::TracebackStorage traceback; - traceback.stop_at_boundary = cross_ffi_boundary == 0; - if (filename != nullptr && func != nullptr) { - // need to skip TVMFFITraceback and the caller function - // which is already included in filename and func - traceback.skip_frame_count = 2; - traceback.Append(filename, func, lineno); - } - - HANDLE process = GetCurrentProcess(); - HANDLE thread = GetCurrentThread(); - - SymSetOptions(SYMOPT_LOAD_LINES | SYMOPT_UNDNAME); - SymInitialize(process, NULL, TRUE); - CONTEXT context = {}; - RtlCaptureContext(&context); - - STACKFRAME64 stack = {}; - DWORD machine_type; - -#if defined(_M_X64) - machine_type = IMAGE_FILE_MACHINE_AMD64; - stack.AddrPC.Offset = context.Rip; - stack.AddrFrame.Offset = context.Rbp; - stack.AddrStack.Offset = context.Rsp; -#elif defined(_M_IX86) - machine_type = IMAGE_FILE_MACHINE_I386; - stack.AddrPC.Offset = context.Eip; - stack.AddrFrame.Offset = context.Ebp; - stack.AddrStack.Offset = context.Esp; -#else -#error "Platform not supported!" -#endif - - stack.AddrPC.Mode = AddrModeFlat; - stack.AddrFrame.Mode = AddrModeFlat; - stack.AddrStack.Mode = AddrModeFlat; - - while (!traceback.ExceedTracebackLimit()) { - if (!StackWalk64(machine_type, process, thread, &stack, &context, nullptr, - SymFunctionTableAccess64, SymGetModuleBase64, nullptr)) { - break; - } - - if (stack.AddrPC.Offset == 0) { - break; - } - const char* filename = nullptr; - const char* symbol = ""; - int lineno = 0; - // Get file and line number - IMAGEHLP_LINE64 line_info; - ZeroMemory(&line_info, sizeof(IMAGEHLP_LINE64)); - line_info.SizeOfStruct = sizeof(IMAGEHLP_LINE64); - DWORD displacement32 = 0; - - if (SymGetLineFromAddr64(process, stack.AddrPC.Offset, &displacement32, &line_info)) { - filename = line_info.FileName; - lineno = line_info.LineNumber; - } - // allocate symbol info that aligns to the SYMBOL_INFO - // we use u64 here to be safe - size_t total_symbol_bytes = sizeof(SYMBOL_INFO) + MAX_SYM_NAME * sizeof(TCHAR); - size_t total_u64_words = (total_symbol_bytes + 7) / 8; - static_assert(8 % alignof(SYMBOL_INFO) == 0); - std::vector symbol_buffer(total_u64_words, 0); - if (filename != nullptr) { - // only run symbol translation if we have the file name - // this is because SymFromAddr can return wrong symbol which becomes even more - // confusing when pdb file do not exist - PSYMBOL_INFO symbol_info = reinterpret_cast(symbol_buffer.data()); - symbol_info->SizeOfStruct = sizeof(SYMBOL_INFO); - symbol_info->MaxNameLen = MAX_SYM_NAME; - DWORD64 displacement = 0; - if (SymFromAddr(process, stack.AddrPC.Offset, &displacement, symbol_info)) { - symbol = symbol_info->Name; - } - } - if (traceback.stop_at_boundary && tvm::ffi::DetectFFIBoundary(filename, symbol)) { - break; - } - // skip extra frames - if (traceback.skip_frame_count > 0) { - traceback.skip_frame_count--; - continue; - } - if (tvm::ffi::ShouldExcludeFrame(filename, symbol)) { - continue; - } - traceback.Append(filename, symbol, lineno); - } - SymCleanup(process); - traceback_str = traceback.GetTraceback(); - traceback_array.data = traceback_str.data(); - traceback_array.size = traceback_str.size(); - return &traceback_array; -} -#endif // _MSC_VER diff --git a/ffi/tests/cpp/CMakeLists.txt b/ffi/tests/cpp/CMakeLists.txt deleted file mode 100644 index c807fad21674..000000000000 --- a/ffi/tests/cpp/CMakeLists.txt +++ /dev/null @@ -1,33 +0,0 @@ -file(GLOB _test_sources "${CMAKE_CURRENT_SOURCE_DIR}/test*.cc") -file(GLOB _test_extra_sources "${CMAKE_CURRENT_SOURCE_DIR}/extra/test*.cc") - -if (TVM_FFI_USE_EXTRA_CXX_API) - list(APPEND _test_sources ${_test_extra_sources}) -endif() - -add_executable( - tvm_ffi_tests - EXCLUDE_FROM_ALL - ${_test_sources} -) - -set_target_properties( - tvm_ffi_tests PROPERTIES - CXX_STANDARD 17 - CXX_STANDARD_REQUIRED ON - CXX_EXTENSIONS OFF - ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib" - LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib" - RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib" -) - -tvm_ffi_add_cxx_warning(tvm_ffi_tests) -add_sanitizer_address(tvm_ffi_tests) -tvm_ffi_add_apple_dsymutil(tvm_ffi_tests) -tvm_ffi_add_msvc_flags(tvm_ffi_tests) -target_link_libraries(tvm_ffi_tests PRIVATE tvm_ffi_shared) -tvm_ffi_add_googletest(tvm_ffi_tests) - -if (MSVC) - target_link_options(tvm_ffi_tests PRIVATE /DEBUG) -endif() diff --git a/ffi/tests/cpp/extra/test_json_parser.cc b/ffi/tests/cpp/extra/test_json_parser.cc deleted file mode 100644 index a1cc2800094f..000000000000 --- a/ffi/tests/cpp/extra/test_json_parser.cc +++ /dev/null @@ -1,394 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -#include -#include -#include -#include -#include - -#include - -namespace { - -using namespace tvm::ffi; - -inline bool FastMathSafeIsNaN(double x) { -#ifdef __FAST_MATH__ - // Bit-level NaN detection (IEEE 754 double) - // IEEE 754 standard: https://en.wikipedia.org/wiki/IEEE_754 - // NaN is encoded as all 1s in the exponent and non-zero in the mantissa - static_assert(sizeof(double) == sizeof(uint64_t), "Unexpected double size"); - uint64_t bits = *reinterpret_cast(&x); - uint64_t exponent = (bits >> 52) & 0x7FF; - uint64_t mantissa = bits & 0xFFFFFFFFFFFFFull; - return (exponent == 0x7FF) && (mantissa != 0); -#else - // Safe to use std::isnan when fast-math is off - return std::isnan(x); -#endif -} - -inline bool FastMathSafeIsInf(double x) { -#ifdef __FAST_MATH__ - // IEEE 754 standard: https://en.wikipedia.org/wiki/IEEE_754 - // Inf is encoded as all 1s in the exponent and zero in the mantissa - static_assert(sizeof(double) == sizeof(uint64_t), "Unexpected double size"); - uint64_t bits = *reinterpret_cast(&x); - uint64_t exponent = (bits >> 52) & 0x7FF; - uint64_t mantissa = bits & 0xFFFFFFFFFFFFFull; - // inf is encoded as all 1s in the exponent and zero in the mantissa - return (exponent == 0x7FF) && (mantissa == 0); -#else - return std::isinf(x); -#endif -} - -TEST(JSONParser, BoolNull) { - // boolean value - EXPECT_EQ(json::Parse("true").cast(), true); - EXPECT_EQ(json::Parse("false").cast(), false); - EXPECT_EQ(json::Parse("null"), nullptr); -} - -TEST(JSONParser, WrongBoolNull) { - String error_msg; - EXPECT_EQ(json::Parse("nul", &error_msg), nullptr); - EXPECT_EQ(error_msg, "Expecting value: line 1 column 1 (char 0)"); - EXPECT_EQ(json::Parse("fals", &error_msg), nullptr); - EXPECT_EQ(error_msg, "Expecting value: line 1 column 1 (char 0)"); - EXPECT_EQ(json::Parse("\n\nfx", &error_msg), nullptr); - EXPECT_EQ(error_msg, "Expecting value: line 3 column 1 (char 2)"); - EXPECT_EQ(json::Parse("fx", &error_msg), nullptr); - EXPECT_EQ(error_msg, "Expecting value: line 1 column 1 (char 0)"); - EXPECT_EQ(json::Parse("n1", &error_msg), nullptr); - EXPECT_EQ(error_msg, "Expecting value: line 1 column 1 (char 0)"); - EXPECT_EQ(json::Parse("t1", &error_msg), nullptr); - EXPECT_EQ(error_msg, "Expecting value: line 1 column 1 (char 0)"); - EXPECT_EQ(json::Parse("f1", &error_msg), nullptr); - EXPECT_EQ(error_msg, "Expecting value: line 1 column 1 (char 0)"); -} - -TEST(JSONParser, Number) { - // number - EXPECT_EQ(json::Parse("123").cast(), 123); - EXPECT_EQ(json::Parse("-124").cast(), -124); - EXPECT_EQ(json::Parse("123.456").cast(), 123.456); - // parsing scientific notation - EXPECT_EQ(json::Parse("1.456e12").cast(), 1.456e12); - // NaN - EXPECT_EQ(FastMathSafeIsNaN(json::Parse("NaN").cast()), true); - // Infinity - EXPECT_EQ(FastMathSafeIsInf(json::Parse("Infinity").cast()), true); - // -Infinity - EXPECT_EQ(FastMathSafeIsInf(-json::Parse("-Infinity").cast()), true); - - // Test zero variants - EXPECT_EQ(json::Parse("0").cast(), 0); - EXPECT_EQ(json::Parse("-0").cast(), -0.0); - EXPECT_EQ(json::Parse("0.0").cast(), 0.0); - - // Test very large numbers - EXPECT_EQ(json::Parse("9223372036854775807").cast(), - std::numeric_limits::max()); - EXPECT_EQ(json::Parse("-9223372036854775808").cast(), - std::numeric_limits::min()); - - // Test very small decimals - EXPECT_EQ(json::Parse("1e-10").cast(), 1e-10); - EXPECT_EQ(json::Parse("-1e-10").cast(), -1e-10); - - // Test scientific notation edge cases - EXPECT_EQ(json::Parse("1E+10").cast(), 1E+10); - EXPECT_EQ(json::Parse("1e+10").cast(), 1e+10); - EXPECT_EQ(json::Parse("1E-10").cast(), 1E-10); - EXPECT_EQ(json::Parse("123.456E+10").cast(), 123.456E+10); -} - -TEST(JSONParser, WrongNumber) { - String error_msg; - EXPECT_EQ(json::Parse("123.456.789", &error_msg), nullptr); - EXPECT_EQ(error_msg, "Expecting value: line 1 column 1 (char 0)"); - - // Test invalid number formats - EXPECT_EQ(json::Parse("123e", &error_msg), nullptr); - EXPECT_EQ(error_msg, "Expecting value: line 1 column 1 (char 0)"); - EXPECT_EQ(json::Parse("123e+", &error_msg), nullptr); - EXPECT_EQ(error_msg, "Expecting value: line 1 column 1 (char 0)"); - EXPECT_EQ(json::Parse("123E-", &error_msg), nullptr); - EXPECT_EQ(error_msg, "Expecting value: line 1 column 1 (char 0)"); -} - -TEST(JSONParser, String) { - EXPECT_EQ(json::Parse("\"hello\"").cast(), "hello"); - EXPECT_EQ(json::Parse("\n\t \"hello\"\n\r").cast(), "hello"); - EXPECT_EQ(json::Parse("\"hello\\nworld\"").cast(), "hello\nworld"); - EXPECT_EQ(json::Parse("\"\"").cast(), ""); - // test escape characters - EXPECT_EQ(json::Parse("\"\\ta\\n\\/\\f\\\"\\\\\"").cast(), "\ta\n/\f\"\\"); - // test unicode code point - EXPECT_EQ(json::Parse("\"\\u0041\"").cast(), "A"); - // test unicode surrogate pair - EXPECT_EQ(json::Parse("\"\\uD83D\\uDE04hello\"").cast(), u8"\U0001F604hello"); -} - -TEST(JSONParser, WrongString) { - String error_msg; - EXPECT_EQ(json::Parse("\"hello", &error_msg), nullptr); - EXPECT_EQ(error_msg, "Unterminated string starting at: line 1 column 1 (char 0)"); - - EXPECT_EQ(json::Parse("\"hello\x01\"", &error_msg), nullptr); - EXPECT_EQ(error_msg, "Invalid control character at: line 1 column 7 (char 6)"); - - EXPECT_EQ(json::Parse("\"hello\\uxx\"", &error_msg), nullptr); - EXPECT_EQ(error_msg, "Invalid \\uXXXX escape: line 1 column 8 (char 7)"); - - EXPECT_EQ(json::Parse("\"hello\\uDC00\\uDE04\"", &error_msg), nullptr); - EXPECT_EQ(error_msg, "Invalid surrogate pair of \\uXXXX escapes: line 1 column 8 (char 7)"); - - EXPECT_EQ(json::Parse("\"hello\\uD800\"", &error_msg), nullptr); - EXPECT_EQ(error_msg, "Invalid surrogate pair of \\uXXXX escapes: line 1 column 8 (char 7)"); - - EXPECT_EQ(json::Parse("\"hello\\uD800\\uxx\"", &error_msg), nullptr); - EXPECT_EQ(error_msg, "Invalid \\uXXXX escape: line 1 column 15 (char 14)"); - - EXPECT_EQ(json::Parse("\"hello\\a\"", &error_msg), nullptr); - EXPECT_EQ(error_msg, "Invalid \\escape: line 1 column 8 (char 7)"); -} - -TEST(JSONParser, Array) { - EXPECT_TRUE(StructuralEqual()(json::Parse("[]"), json::Array{})); - - EXPECT_TRUE(StructuralEqual()(json::Parse("[1, 2,\n\t\"a\"]"), json::Array{1, 2, "a"})); -} - -TEST(JSONParser, WrongArray) { - String error_msg; - - EXPECT_EQ(json::Parse("]", &error_msg), nullptr); - EXPECT_EQ(error_msg, "Expecting value: line 1 column 1 (char 0)"); - - EXPECT_EQ(json::Parse("[1,]", &error_msg), nullptr); - EXPECT_EQ(error_msg, "Expecting value: line 1 column 4 (char 3)"); - - EXPECT_EQ(json::Parse("[", &error_msg), nullptr); - EXPECT_EQ(error_msg, "Expecting value: line 1 column 2 (char 1)"); - - EXPECT_EQ(json::Parse("[1a", &error_msg), nullptr); - EXPECT_EQ(error_msg, "Expecting ',' delimiter: line 1 column 3 (char 2)"); - - EXPECT_EQ(json::Parse("[1,2,3", &error_msg), nullptr); - EXPECT_EQ(error_msg, "Expecting ',' delimiter: line 1 column 7 (char 6)"); - - EXPECT_EQ(json::Parse("[1] a", &error_msg), nullptr); - EXPECT_EQ(error_msg, "Extra data: line 1 column 6 (char 5)"); -} - -TEST(JSONParser, Object) { - EXPECT_TRUE(StructuralEqual()(json::Parse("{}"), json::Object{})); - - EXPECT_TRUE(StructuralEqual()(json::Parse("{\"a\": 1, \n\"b\": \t\"c\"} "), - json::Object{{"a", 1}, {"b", "c"}})); -} - -TEST(JSONParser, ObjectOrderPreserving) { - auto obj = json::Parse("{\"c\": 1, \"a\": 2, \"b\": 3} "); - json::Array keys; - for (auto& [key, value] : obj.cast()) { - keys.push_back(key); - } - EXPECT_TRUE(StructuralEqual()(keys, json::Array{"c", "a", "b"})); -} - -TEST(JSONParser, WrongObject) { - String error_msg; - EXPECT_EQ(json::Parse("{\"a\":", &error_msg), nullptr); - EXPECT_EQ(error_msg, "Expecting value: line 1 column 6 (char 5)"); - - EXPECT_EQ(json::Parse("{", &error_msg), nullptr); - EXPECT_EQ(error_msg, - "Expecting property name enclosed in double quotes: line 1 column 2 (char 1)"); - - // Test incomplete structures - EXPECT_EQ(json::Parse("{\"incomplete\"", &error_msg), nullptr); - EXPECT_EQ(error_msg, "Expecting ':' delimiter: line 1 column 14 (char 13)"); -} - -TEST(JSONParser, NestedObject) { - EXPECT_TRUE( - StructuralEqual()(json::Parse("{\"a\": \t{\"b\": 1}, \n\"c\": [1, 2, 3]}"), - json::Object{{"a", json::Object{{"b", 1}}}, {"c", json::Array{1, 2, 3}}})); - - EXPECT_TRUE(StructuralEqual()( - json::Parse("{\"a\": \t{\"b\": 1}, \n\"c\": [1, null, Infinity]}"), - json::Object{{"a", json::Object{{"b", 1}}}, - {"c", json::Array{1, nullptr, std::numeric_limits::infinity()}}})); - - EXPECT_TRUE(StructuralEqual()( - json::Parse("[{}, {\"a\": [1.1, 1000000]}]"), - json::Array{json::Object{}, json::Object{{"a", json::Array{1.1, 1000000}}}})); -} - -TEST(JSONParser, WrongNestedObject) { - String error_msg; - EXPECT_EQ(json::Parse("{\"a\":\n\n[1]", &error_msg), nullptr); - EXPECT_EQ(error_msg, "Expecting ',' delimiter: line 3 column 4 (char 10)"); - - EXPECT_EQ(json::Parse("{\"a\":\n\n[abc]}", &error_msg), nullptr); - EXPECT_EQ(error_msg, "Expecting value: line 3 column 2 (char 8)"); -} - -// edge cases -TEST(JSONParser, WhitespaceHandling) { - // Test various whitespace characters - EXPECT_EQ(json::Parse(" \t\n\r true \t\n\r ").cast(), true); - EXPECT_EQ(json::Parse("\n\n\n123\n\n\n").cast(), 123); - EXPECT_EQ(json::Parse(" \"hello world\" ").cast(), "hello world"); - - // Test whitespace in arrays and objects - EXPECT_TRUE(StructuralEqual()(json::Parse(" [ 1 , 2 , 3 ] "), json::Array{1, 2, 3})); - - EXPECT_TRUE(StructuralEqual()(json::Parse(" { \"a\" : 1 , \"b\" : 2 } "), - json::Object{{"a", 1}, {"b", 2}})); -} - -TEST(JSONParser, WrongEmptyAndMinimalInputs) { - String error_msg; - // Test empty string - EXPECT_EQ(json::Parse("", &error_msg), nullptr); - EXPECT_EQ(error_msg, "Expecting value: line 1 column 1 (char 0)"); - - // Test only whitespace - EXPECT_EQ(json::Parse(" \t\n ", &error_msg), nullptr); - EXPECT_EQ(error_msg, "Expecting value: line 2 column 5 (char 9)"); -} - -TEST(JSONParser, UnicodeEdgeCases) { - // Test various unicode characters - EXPECT_EQ(json::Parse("\"\\u0000\"").cast(), std::string("\0", 1)); - // replace using \U to avoid encoding issues - EXPECT_EQ(json::Parse("\"\\u00FF\"").cast(), u8"\U000000FF"); - EXPECT_EQ(json::Parse("\"\\u4E2D\\u6587\"").cast(), u8"\U00004E2D\U00006587"); - - // Test multiple surrogate pairs - EXPECT_EQ(json::Parse("\"\\uD83D\\uDE00\\uD83D\\uDE01\"").cast(), - u8"\U0001F600\U0001F601"); -} - -TEST(JSONParser, LargeInputs) { - // Test large array - std::string large_array = "["; - for (int i = 0; i < 1000; ++i) { - if (i > 0) large_array += ","; - large_array += std::to_string(i); - } - large_array += "]"; - - auto result = json::Parse(large_array); - EXPECT_TRUE(result != nullptr); - EXPECT_EQ(result.cast().size(), 1000); - - // Test large object - std::string large_object = "{"; - for (int i = 0; i < 500; ++i) { - if (i > 0) large_object += ","; - large_object += "\"key" + std::to_string(i) + "\":" + std::to_string(i); - } - large_object += "}"; - - result = json::Parse(large_object); - EXPECT_TRUE(result != nullptr); - EXPECT_EQ(result.cast().size(), 500); -} - -TEST(JSONParser, MixedDataTypes) { - // Test complex nested structure with all data types - std::string complex_json = R"({ - "null_value": null, - "boolean_true": true, - "boolean_false": false, - "integer": 42, - "negative_integer": -42, - "float": 3.14159, - "scientific": 1.23e-4, - "string": "hello world", - "unicode_string": "Hello \u4e16\u754c \ud83c\udf0d", - "empty_string": "", - "empty_array": [], - "empty_object": {}, - "number_array": [1, 2, 3, 4, 5], - "mixed_array": [1, "two", true, null, 3.14], - "nested_object": { - "level1": { - "level2": { - "data": [1, 2, {"nested_array": [true, false]}] - } - } - } - })"; - - auto result = json::Parse(complex_json); - - // Create expected structure for comparison - json::Object expected{ - {"null_value", nullptr}, - {"boolean_true", true}, - {"boolean_false", false}, - {"integer", 42}, - {"negative_integer", -42}, - {"float", 3.14159}, - {"scientific", 1.23e-4}, - {"string", "hello world"}, - {"unicode_string", u8"Hello \U00004E16\U0000754C \U0001F30D"}, - {"empty_string", ""}, - {"empty_array", json::Array{}}, - {"empty_object", json::Object{}}, - {"number_array", json::Array{1, 2, 3, 4, 5}}, - {"mixed_array", json::Array{1, "two", true, nullptr, 3.14}}, - {"nested_object", - json::Object{ - {"level1", - json::Object{ - {"level2", - json::Object{ - {"data", - json::Array{1, 2, - json::Object{{"nested_array", json::Array{true, false}}}}}}}}}}}}; - - EXPECT_TRUE(StructuralEqual()(result, expected)); -} - -TEST(JSONParser, WrongExtraData) { - String error_msg; - - EXPECT_EQ(json::Parse("truee", &error_msg), nullptr); - EXPECT_EQ(error_msg, "Extra data: line 1 column 5 (char 4)"); - - EXPECT_EQ(json::Parse("true false", &error_msg), nullptr); - EXPECT_EQ(error_msg, "Extra data: line 1 column 6 (char 5)"); - - EXPECT_EQ(json::Parse("123 456", &error_msg), nullptr); - EXPECT_EQ(error_msg, "Extra data: line 1 column 5 (char 4)"); - - EXPECT_EQ(json::Parse("\"hello\" \"world\"", &error_msg), nullptr); - EXPECT_EQ(error_msg, "Extra data: line 1 column 9 (char 8)"); - - EXPECT_EQ(json::Parse("{} []", &error_msg), nullptr); - EXPECT_EQ(error_msg, "Extra data: line 1 column 4 (char 3)"); -} -} // namespace diff --git a/ffi/tests/cpp/extra/test_json_writer.cc b/ffi/tests/cpp/extra/test_json_writer.cc deleted file mode 100644 index ae6172c2e53b..000000000000 --- a/ffi/tests/cpp/extra/test_json_writer.cc +++ /dev/null @@ -1,241 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -#include -#include -#include -#include - -#include - -namespace { - -using namespace tvm::ffi; - -TEST(JSONWriter, BoolNull) { - // boolean value - EXPECT_EQ(json::Stringify(json::Value(true)), "true"); - EXPECT_EQ(json::Stringify(json::Value(false)), "false"); - EXPECT_EQ(json::Stringify(json::Value(nullptr)), "null"); -} - -TEST(JSONWriter, Integer) { - // positive integer - EXPECT_EQ(json::Stringify(json::Value(42)), "42"); - // negative integer - EXPECT_EQ(json::Stringify(json::Value(-123)), "-123"); - // zero - EXPECT_EQ(json::Stringify(json::Value(0)), "0"); - // large positive integer - EXPECT_EQ(json::Stringify(json::Value(std::numeric_limits::max())), - "9223372036854775807"); - // large negative integer - EXPECT_EQ(json::Stringify(json::Value(std::numeric_limits::min())), - "-9223372036854775808"); -} - -TEST(JSONWriter, Float) { - // regular float - EXPECT_EQ(json::Stringify(json::Value(2.5)), "2.5"); - // integer-like float (should have .0 suffix) - EXPECT_EQ(json::Stringify(json::Value(5.0)), "5.0"); - EXPECT_EQ(json::Stringify(json::Value(-10.0)), "-10.0"); - // zero float - EXPECT_EQ(json::Stringify(json::Value(0.0)), "0.0"); - // scientific notation for very small numbers - EXPECT_EQ(json::Stringify(json::Value(-7.89e-15)), "-7.89e-15"); - // short scientific notation (shorter than fixed-point) - EXPECT_EQ(json::Stringify(json::Value(2e-8)), "2e-08"); - // NaN - EXPECT_EQ(json::Stringify(json::Value(std::numeric_limits::quiet_NaN())), "NaN"); - // positive infinity - EXPECT_EQ(json::Stringify(json::Value(std::numeric_limits::infinity())), "Infinity"); - // negative infinity - EXPECT_EQ(json::Stringify(json::Value(-std::numeric_limits::infinity())), "-Infinity"); -} - -TEST(JSONWriter, String) { - // simple string - EXPECT_EQ(json::Stringify(json::Value(String("hello"))), "\"hello\""); - // empty string - EXPECT_EQ(json::Stringify(json::Value(String(""))), "\"\""); - // string with escaped characters - EXPECT_EQ(json::Stringify(json::Value(String("\"quoted\""))), "\"\\\"quoted\\\"\""); - EXPECT_EQ(json::Stringify(json::Value(String("backslash\\"))), "\"backslash\\\\\""); - EXPECT_EQ(json::Stringify(json::Value(String("forward/slash"))), "\"forward\\/slash\""); - EXPECT_EQ(json::Stringify(json::Value(String("line\nbreak"))), "\"line\\nbreak\""); - EXPECT_EQ(json::Stringify(json::Value(String("tab\there"))), "\"tab\\there\""); - EXPECT_EQ(json::Stringify(json::Value(String("carriage\rreturn"))), "\"carriage\\rreturn\""); - // string with control character - EXPECT_EQ(json::Stringify(json::Value(String(std::string("\x01", 1) + "control"))), - "\"\\u0001control\""); -} - -TEST(JSONWriter, Array) { - // empty array - json::Array empty_array; - EXPECT_EQ(json::Stringify(empty_array), "[]"); - - // single element array - json::Array single_array{42}; - EXPECT_EQ(json::Stringify(single_array), "[42]"); - - // multiple elements array - json::Array multi_array{1, "hello", true}; - EXPECT_EQ(json::Stringify(multi_array), "[1,\"hello\",true]"); - - // nested array - json::Array nested_array{json::Array{1, 2}, 3}; - EXPECT_EQ(json::Stringify(nested_array), "[[1,2],3]"); -} - -TEST(JSONWriter, Object) { - // empty object - json::Object empty_object; - EXPECT_EQ(json::Stringify(empty_object), "{}"); - - // single key-value pair - json::Object single_object{{String("key"), String("value")}}; - EXPECT_EQ(json::Stringify(single_object), "{\"key\":\"value\"}"); - - // multiple key-value pairs - insertion order preservation - json::Object multi_object{{"name", "Alice"}, {"age", 30}, {"active", true}, {"score", 95.5}}; - EXPECT_EQ(json::Stringify(multi_object), - "{\"name\":\"Alice\",\"age\":30,\"active\":true,\"score\":95.5}"); -} - -TEST(JSONWriter, InsertionOrderPreservation) { - // test that objects preserve insertion order - json::Object ordered_object{ - {"zebra", "last"}, {"alpha", "first"}, {"beta", "middle"}, {"gamma", 123}, {"delta", true}}; - EXPECT_EQ( - json::Stringify(ordered_object), - "{\"zebra\":\"last\",\"alpha\":\"first\",\"beta\":\"middle\",\"gamma\":123,\"delta\":true}"); - - // test with indentation to verify order is preserved - std::string ordered_indented = json::Stringify(ordered_object, 2); - EXPECT_EQ(ordered_indented, String(R"({ - "zebra": "last", - "alpha": "first", - "beta": "middle", - "gamma": 123, - "delta": true -})")); - - // test nested objects also preserve order - json::Object nested_ordered{ - {"outer1", - json::Object{{"inner_z", "z_value"}, {"inner_a", "a_value"}, {"inner_m", "m_value"}}}, - {"outer2", json::Object{{"third", 3}, {"first", 1}, {"second", 2}}}}; - std::string nested_ordered_indented = json::Stringify(nested_ordered, 2); - EXPECT_EQ(nested_ordered_indented, String(R"({ - "outer1": { - "inner_z": "z_value", - "inner_a": "a_value", - "inner_m": "m_value" - }, - "outer2": { - "third": 3, - "first": 1, - "second": 2 - } -})")); -} - -TEST(JSONWriter, NestedStructures) { - // object containing array - json::Object obj_with_array{{String("numbers"), json::Array{1, 2, 3}}}; - EXPECT_EQ(json::Stringify(obj_with_array), "{\"numbers\":[1,2,3]}"); - - // array containing object - json::Array arr_with_obj{json::Object{{String("key"), String("value")}}}; - EXPECT_EQ(json::Stringify(arr_with_obj), "[{\"key\":\"value\"}]"); - - // deeply nested structure - json::Object nested_obj{ - {String("nested"), json::Array{json::Object{{String("deep"), String("value")}}}}}; - EXPECT_EQ(json::Stringify(nested_obj), "{\"nested\":[{\"deep\":\"value\"}]}"); -} - -TEST(JSONWriter, Indentation) { - // test with indentation - json::Array arr{1, 2}; - std::string indented = json::Stringify(arr, 2); - EXPECT_EQ(indented, String(R"([ - 1, - 2 -])")); - - // object with indentation - json::Object obj{{"key", "value"}}; - std::string indented_obj = json::Stringify(obj, 2); - EXPECT_EQ(indented_obj, String(R"({ - "key": "value" -})")); - - // complex nested structure with multiple data types - // keep double as .5 so output is deterministic as they exactly rounds to power of 2 - json::Object complex_nested{ - {"name", "test"}, - {"count", 42}, - {"price", 3.5}, - {"active", true}, - {"metadata", nullptr}, - {"numbers", json::Array{1, 2, 3}}, - {"config", json::Object{{"enabled", false}, - {"timeout", 30.5}, - {"tags", json::Array{"production", "critical", nullptr}}}}, - {"matrix", json::Array{json::Array{1, 2}, json::Array{3.5, 4.5}, json::Array{"a", "b"}}}}; - std::string complex_indented = json::Stringify(complex_nested, 2); - EXPECT_EQ(complex_indented, String(R"({ - "name": "test", - "count": 42, - "price": 3.5, - "active": true, - "metadata": null, - "numbers": [ - 1, - 2, - 3 - ], - "config": { - "enabled": false, - "timeout": 30.5, - "tags": [ - "production", - "critical", - null - ] - }, - "matrix": [ - [ - 1, - 2 - ], - [ - 3.5, - 4.5 - ], - [ - "a", - "b" - ] - ] -})")); -} -} // namespace diff --git a/ffi/tests/cpp/extra/test_serialization.cc b/ffi/tests/cpp/extra/test_serialization.cc deleted file mode 100644 index 9d18e6a03e2d..000000000000 --- a/ffi/tests/cpp/extra/test_serialization.cc +++ /dev/null @@ -1,372 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -#include -#include -#include -#include -#include -#include -#include -#include - -#include "../testing_object.h" - -namespace { - -using namespace tvm::ffi; -using namespace tvm::ffi::testing; - -TEST(Serialization, BoolNull) { - json::Object expected_null = - json::Object{{"root_index", 0}, {"nodes", json::Array{json::Object{{"type", "None"}}}}}; - EXPECT_TRUE(StructuralEqual()(ToJSONGraph(nullptr), expected_null)); - EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected_null), nullptr)); - - json::Object expected_true = json::Object{ - {"root_index", 0}, {"nodes", json::Array{json::Object{{"type", "bool"}, {"data", true}}}}}; - EXPECT_TRUE(StructuralEqual()(ToJSONGraph(true), expected_true)); - EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected_true), true)); - - json::Object expected_false = json::Object{ - {"root_index", 0}, {"nodes", json::Array{json::Object{{"type", "bool"}, {"data", false}}}}}; - EXPECT_TRUE(StructuralEqual()(ToJSONGraph(false), expected_false)); - EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected_false), false)); -} - -TEST(Serialization, IntegerTypes) { - // Test positive integer - json::Object expected_int = json::Object{ - {"root_index", 0}, {"nodes", json::Array{json::Object{{"type", "int"}, {"data", 42}}}}}; - EXPECT_TRUE(StructuralEqual()(ToJSONGraph(static_cast(42)), expected_int)); - EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected_int), static_cast(42))); -} - -TEST(Serialization, FloatTypes) { - // Test positive float - json::Object expected_float = - json::Object{{"root_index", 0}, - {"nodes", json::Array{json::Object{{"type", "float"}, {"data", 3.14159}}}}}; - EXPECT_TRUE(StructuralEqual()(ToJSONGraph(3.14159), expected_float)); - EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected_float), 3.14159)); -} - -TEST(Serialization, StringTypes) { - // Test short string - json::Object expected_short = json::Object{ - {"root_index", 0}, - {"nodes", json::Array{json::Object{{"type", "ffi.String"}, {"data", String("hello")}}}}}; - EXPECT_TRUE(StructuralEqual()(ToJSONGraph(String("hello")), expected_short)); - EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected_short), String("hello"))); - - // Test long string - std::string long_str(1000, 'x'); - json::Object expected_long = json::Object{ - {"root_index", 0}, - {"nodes", json::Array{json::Object{{"type", "ffi.String"}, {"data", String(long_str)}}}}}; - EXPECT_TRUE(StructuralEqual()(ToJSONGraph(String(long_str)), expected_long)); - EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected_long), String(long_str))); - - // Test string with special characters - json::Object expected_special = json::Object{ - {"root_index", 0}, - {"nodes", json::Array{json::Object{{"type", "ffi.String"}, - {"data", String("hello\nworld\t\"quotes\"")}}}}}; - EXPECT_TRUE(StructuralEqual()(ToJSONGraph(String("hello\nworld\t\"quotes\"")), expected_special)); - EXPECT_TRUE( - StructuralEqual()(FromJSONGraph(expected_special), String("hello\nworld\t\"quotes\""))); -} - -TEST(Serialization, Bytes) { - // Test empty bytes - Bytes empty_bytes; - json::Object expected_empty = json::Object{ - {"root_index", 0}, {"nodes", json::Array{json::Object{{"type", "ffi.Bytes"}, {"data", ""}}}}}; - EXPECT_TRUE(StructuralEqual()(ToJSONGraph(empty_bytes), expected_empty)); - EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected_empty), empty_bytes)); - - // Test bytes with that encoded as base64 - Bytes bytes_content = Bytes("abcd"); - json::Object expected_encoded = json::Object{ - {"root_index", 0}, - {"nodes", json::Array{json::Object{{"type", "ffi.Bytes"}, {"data", "YWJjZA=="}}}}}; - EXPECT_TRUE(StructuralEqual()(ToJSONGraph(bytes_content), expected_encoded)); - EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected_encoded), bytes_content)); - - // Test bytes with that encoded as base64, that contains control characters via utf-8 - char bytes_v2_content[] = {0x01, 0x02, 0x03, 0x04, 0x01, 0x0b}; - Bytes bytes_v2 = Bytes(bytes_v2_content, sizeof(bytes_v2_content)); - json::Object expected_encoded_v2 = json::Object{ - {"root_index", 0}, - {"nodes", json::Array{json::Object{{"type", "ffi.Bytes"}, {"data", "AQIDBAEL"}}}}}; - EXPECT_TRUE(StructuralEqual()(ToJSONGraph(bytes_v2), expected_encoded_v2)); - EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected_encoded_v2), bytes_v2)); -} - -TEST(Serialization, DataTypes) { - // Test int32 dtype - DLDataType int32_dtype; - int32_dtype.code = kDLInt; - int32_dtype.bits = 32; - int32_dtype.lanes = 1; - - json::Object expected_int32 = json::Object{ - {"root_index", 0}, - {"nodes", json::Array{json::Object{{"type", "DataType"}, {"data", String("int32")}}}}}; - EXPECT_TRUE(StructuralEqual()(ToJSONGraph(int32_dtype), expected_int32)); - EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected_int32), int32_dtype)); - - // Test float64 dtype - DLDataType float64_dtype; - float64_dtype.code = kDLFloat; - float64_dtype.bits = 64; - float64_dtype.lanes = 1; - - json::Object expected_float64 = json::Object{ - {"root_index", 0}, - {"nodes", json::Array{json::Object{{"type", "DataType"}, {"data", String("float64")}}}}}; - EXPECT_TRUE(StructuralEqual()(ToJSONGraph(float64_dtype), expected_float64)); - EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected_float64), float64_dtype)); - - // Test vector dtype - DLDataType vector_dtype; - vector_dtype.code = kDLFloat; - vector_dtype.bits = 32; - vector_dtype.lanes = 4; - - json::Object expected_vector = json::Object{ - {"root_index", 0}, - {"nodes", json::Array{json::Object{{"type", "DataType"}, {"data", String("float32x4")}}}}}; - EXPECT_TRUE(StructuralEqual()(ToJSONGraph(vector_dtype), expected_vector)); - EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected_vector), vector_dtype)); -} - -TEST(Serialization, DeviceTypes) { - // Test CPU device - DLDevice cpu_device; - cpu_device.device_type = kDLCPU; - cpu_device.device_id = 0; - - json::Object expected_cpu = json::Object{ - {"root_index", 0}, - {"nodes", json::Array{json::Object{{"type", "Device"}, - {"data", json::Array{static_cast(kDLCPU), - static_cast(0)}}}}}}; - EXPECT_TRUE(StructuralEqual()(ToJSONGraph(cpu_device), expected_cpu)); - EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected_cpu), cpu_device)); - - // Test GPU device - DLDevice gpu_device; - gpu_device.device_type = kDLCUDA; - gpu_device.device_id = 1; - - json::Object expected_gpu = json::Object{ - {"root_index", 0}, - {"nodes", json::Array{json::Object{ - {"type", "Device"}, {"data", json::Array{static_cast(kDLCUDA), 1}}}}}}; - EXPECT_TRUE(StructuralEqual()(ToJSONGraph(gpu_device), expected_gpu)); - EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected_gpu), gpu_device)); -} - -TEST(Serialization, Arrays) { - // Test empty array - Array empty_array; - json::Object expected_empty = json::Object{ - {"root_index", 0}, - {"nodes", json::Array{json::Object{{"type", "ffi.Array"}, {"data", json::Array{}}}}}}; - EXPECT_TRUE(StructuralEqual()(ToJSONGraph(empty_array), expected_empty)); - EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected_empty), empty_array)); - - // Test single element array - Array single_array; - single_array.push_back(Any(42)); - json::Object expected_single = - json::Object{{"root_index", 1}, - {"nodes", json::Array{ - json::Object{{"type", "int"}, {"data", static_cast(42)}}, - json::Object{{"type", "ffi.Array"}, {"data", json::Array{0}}}, - }}}; - EXPECT_TRUE(StructuralEqual()(ToJSONGraph(single_array), expected_single)); - EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected_single), single_array)); - - // Test duplicated element array - Array duplicated_array; - duplicated_array.push_back(42); - duplicated_array.push_back(42); - json::Object expected_duplicated = - json::Object{{"root_index", 1}, - {"nodes", json::Array{ - json::Object{{"type", "int"}, {"data", 42}}, - json::Object{{"type", "ffi.Array"}, {"data", json::Array{0, 0}}}, - }}}; - EXPECT_TRUE(StructuralEqual()(ToJSONGraph(duplicated_array), expected_duplicated)); - EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected_duplicated), duplicated_array)); - // Test mixed element array, note that 42 and "hello" are duplicated and will - // be indexed as 0 and 1 - Array mixed_array; - mixed_array.push_back(42); - mixed_array.push_back(String("hello")); - mixed_array.push_back(true); - mixed_array.push_back(nullptr); - mixed_array.push_back(42); - mixed_array.push_back(String("hello")); - json::Object expected_mixed = json::Object{ - {"root_index", 4}, - {"nodes", json::Array{ - json::Object{{"type", "int"}, {"data", 42}}, - json::Object{{"type", "ffi.String"}, {"data", String("hello")}}, - json::Object{{"type", "bool"}, {"data", true}}, - json::Object{{"type", "None"}}, - json::Object{{"type", "ffi.Array"}, {"data", json::Array{0, 1, 2, 3, 0, 1}}}, - }}}; - EXPECT_TRUE(StructuralEqual()(ToJSONGraph(mixed_array), expected_mixed)); - EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected_mixed), mixed_array)); -} - -TEST(Serialization, Maps) { - // Test empty map - Map empty_map; - json::Object expected_empty = json::Object{ - {"root_index", 0}, - {"nodes", json::Array{json::Object{{"type", "ffi.Map"}, {"data", json::Array{}}}}}}; - EXPECT_TRUE(StructuralEqual()(ToJSONGraph(empty_map), expected_empty)); - EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected_empty), empty_map)); - - // Test single element map - Map single_map{{"key", 42}}; - json::Object expected_single = json::Object{ - {"root_index", 2}, - {"nodes", json::Array{json::Object{{"type", "ffi.String"}, {"data", String("key")}}, - json::Object{{"type", "int"}, {"data", 42}}, - json::Object{{"type", "ffi.Map"}, {"data", json::Array{0, 1}}}}}}; - EXPECT_TRUE(StructuralEqual()(ToJSONGraph(single_map), expected_single)); - EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected_single), single_map)); - - // Test duplicated element map - Map duplicated_map{{"b", 42}, {"a", 42}}; - json::Object expected_duplicated = json::Object{ - {"root_index", 3}, - {"nodes", json::Array{ - json::Object{{"type", "ffi.String"}, {"data", "b"}}, - json::Object{{"type", "int"}, {"data", 42}}, - json::Object{{"type", "ffi.String"}, {"data", "a"}}, - json::Object{{"type", "ffi.Map"}, {"data", json::Array{0, 1, 2, 1}}}, - - }}}; - EXPECT_TRUE(StructuralEqual()(ToJSONGraph(duplicated_map), expected_duplicated)); - EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected_duplicated), duplicated_map)); -} - -TEST(Serialization, Shapes) { - Shape empty_shape; - - json::Object expected_empty_shape = json::Object{ - {"root_index", 0}, - {"nodes", json::Array{json::Object{{"type", "ffi.Shape"}, {"data", json::Array{}}}}}}; - EXPECT_TRUE(StructuralEqual()(ToJSONGraph(empty_shape), expected_empty_shape)); - EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected_empty_shape), empty_shape)); - - Shape shape({1, 2, 3}); - json::Object expected_shape = json::Object{ - {"root_index", 0}, - {"nodes", json::Array{json::Object{{"type", "ffi.Shape"}, {"data", json::Array{1, 2, 3}}}}}}; - EXPECT_TRUE(StructuralEqual()(ToJSONGraph(shape), expected_shape)); - EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected_shape), shape)); -} - -TEST(Serialization, TestObjectVar) { - TVar x = TVar("x"); - json::Object expected_x = json::Object{ - {"root_index", 1}, - {"nodes", - json::Array{json::Object{{"type", "ffi.String"}, {"data", "x"}}, - json::Object{{"type", "test.Var"}, {"data", json::Object{{"name", 0}}}}}}}; - EXPECT_TRUE(StructuralEqual()(ToJSONGraph(x), expected_x)); - EXPECT_TRUE(StructuralEqual::Equal(FromJSONGraph(expected_x), x, /*map_free_vars=*/true)); -} - -TEST(Serialization, TestObjectIntCustomToJSON) { - TInt value = TInt(42); - json::Object expected_i = json::Object{ - {"root_index", 0}, - {"nodes", - json::Array{json::Object{{"type", "test.Int"}, {"data", json::Object{{"value", 42}}}}}}}; - EXPECT_TRUE(StructuralEqual()(ToJSONGraph(value), expected_i)); - EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected_i), value)); -} - -TEST(Serialization, TestObjectFunc) { - TVar x = TVar("x"); - // comment fields are ignored - TFunc fa = TFunc({x}, {x, x}, String("comment a")); - - json::Object expected_fa = json::Object{ - {"root_index", 5}, - {"nodes", - json::Array{ - json::Object{{"type", "ffi.String"}, {"data", "x"}}, // string "x" - json::Object{{"type", "test.Var"}, {"data", json::Object{{"name", 0}}}}, // var x - json::Object{{"type", "ffi.Array"}, {"data", json::Array{1}}}, // array [x] - json::Object{{"type", "ffi.Array"}, {"data", json::Array{1, 1}}}, // array [x, x] - json::Object{{"type", "ffi.String"}, {"data", "comment a"}}, // "comment a" - json::Object{{"type", "test.Func"}, - {"data", json::Object{{"params", 2}, {"body", 3}, {"comment", 4}}}}}}}; - EXPECT_TRUE(StructuralEqual()(ToJSONGraph(fa), expected_fa)); - EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected_fa), fa)); - - TFunc fb = TFunc({}, {}, std::nullopt); - json::Object expected_fb = json::Object{ - {"root_index", 3}, - {"nodes", - json::Array{ - json::Object{{"type", "ffi.Array"}, {"data", json::Array{}}}, - json::Object{{"type", "ffi.Array"}, {"data", json::Array{}}}, - json::Object{{"type", "None"}}, - json::Object{{"type", "test.Func"}, - {"data", json::Object{{"params", 0}, {"body", 1}, {"comment", 2}}}}}}}; - EXPECT_TRUE(StructuralEqual()(ToJSONGraph(fb), expected_fb)); - EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected_fb), fb)); -} - -TEST(Serialization, AttachMetadata) { - bool value = true; - json::Object metadata{{"version", "1.0"}}; - json::Object expected = - json::Object{{"root_index", 0}, - {"nodes", json::Array{json::Object{{"type", "bool"}, {"data", true}}}}, - {"metadata", metadata}}; - EXPECT_TRUE(StructuralEqual()(ToJSONGraph(value, metadata), expected)); - EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected), value)); -} - -TEST(Serialization, ShuffleNodeOrder) { - // the FromJSONGraph is agnostic to the node order - // so we can shuffle the node order as it reads nodes lazily - Map duplicated_map{{"b", 42}, {"a", 42}}; - json::Object expected_shuffled = json::Object{ - {"root_index", 0}, - {"nodes", json::Array{ - json::Object{{"type", "ffi.Map"}, {"data", json::Array{2, 3, 1, 3}}}, - json::Object{{"type", "ffi.String"}, {"data", "a"}}, - json::Object{{"type", "ffi.String"}, {"data", "b"}}, - json::Object{{"type", "int"}, {"data", 42}}, - }}}; - EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected_shuffled), duplicated_map)); -} - -} // namespace diff --git a/ffi/tests/cpp/extra/test_structural_equal_hash.cc b/ffi/tests/cpp/extra/test_structural_equal_hash.cc deleted file mode 100644 index a05c50cc2617..000000000000 --- a/ffi/tests/cpp/extra/test_structural_equal_hash.cc +++ /dev/null @@ -1,178 +0,0 @@ - -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -#include -#include -#include -#include -#include -#include -#include -#include - -#include "../testing_object.h" - -namespace { - -using namespace tvm::ffi; -using namespace tvm::ffi::testing; -namespace refl = tvm::ffi::reflection; - -TEST(StructuralEqualHash, Array) { - Array a = {1, 2, 3}; - Array b = {1, 2, 3}; - EXPECT_TRUE(StructuralEqual()(a, b)); - EXPECT_EQ(StructuralHash()(a), StructuralHash()(b)); - - Array c = {1, 3}; - EXPECT_FALSE(StructuralEqual()(a, c)); - EXPECT_NE(StructuralHash()(a), StructuralHash()(c)); - auto diff_a_c = StructuralEqual::GetFirstMismatch(a, c); - - // first directly interepret diff, - EXPECT_TRUE(diff_a_c.has_value()); - auto lhs_steps = (*diff_a_c).get<0>()->ToSteps(); - auto rhs_steps = (*diff_a_c).get<1>()->ToSteps(); - EXPECT_EQ(lhs_steps[0]->kind, refl::AccessKind::kArrayItem); - EXPECT_EQ(rhs_steps[0]->kind, refl::AccessKind::kArrayItem); - EXPECT_EQ(lhs_steps[0]->key.cast(), 1); - EXPECT_EQ(rhs_steps[0]->key.cast(), 1); - EXPECT_EQ(lhs_steps.size(), 1); - EXPECT_EQ(rhs_steps.size(), 1); - - // use structural equal for checking in future parts - // given we have done some basic checks above by directly interepret diff, - Array d = {1, 2}; - auto diff_a_d = StructuralEqual::GetFirstMismatch(a, d); - auto expected_diff_a_d = refl::AccessPathPair(refl::AccessPath::FromSteps({ - refl::AccessStep::ArrayItem(2), - }), - refl::AccessPath::FromSteps({ - refl::AccessStep::ArrayItemMissing(2), - })); - // then use structural equal to check it - EXPECT_TRUE(StructuralEqual()(diff_a_d, expected_diff_a_d)); -} - -TEST(StructuralEqualHash, Map) { - // same map but different insertion order - Map a = {{"a", 1}, {"b", 2}, {"c", 3}}; - Map b = {{"b", 2}, {"c", 3}, {"a", 1}}; - EXPECT_TRUE(StructuralEqual()(a, b)); - EXPECT_EQ(StructuralHash()(a), StructuralHash()(b)); - - Map c = {{"a", 1}, {"b", 2}, {"c", 4}}; - EXPECT_FALSE(StructuralEqual()(a, c)); - EXPECT_NE(StructuralHash()(a), StructuralHash()(c)); - - auto diff_a_c = StructuralEqual::GetFirstMismatch(a, c); - auto expected_diff_a_c = refl::AccessPathPair(refl::AccessPath::Root()->MapItem("c"), - refl::AccessPath::Root()->MapItem("c")); - EXPECT_TRUE(diff_a_c.has_value()); - EXPECT_TRUE(StructuralEqual()(diff_a_c, expected_diff_a_c)); -} - -TEST(StructuralEqualHash, NestedMapArray) { - Map> a = {{"a", {1, 2, 3}}, {"b", {4, "hello", 6}}}; - Map> b = {{"a", {1, 2, 3}}, {"b", {4, "hello", 6}}}; - EXPECT_TRUE(StructuralEqual()(a, b)); - EXPECT_EQ(StructuralHash()(a), StructuralHash()(b)); - - Map> c = {{"a", {1, 2, 3}}, {"b", {4, "world", 6}}}; - EXPECT_FALSE(StructuralEqual()(a, c)); - EXPECT_NE(StructuralHash()(a), StructuralHash()(c)); - - auto diff_a_c = StructuralEqual::GetFirstMismatch(a, c); - auto expected_diff_a_c = - refl::AccessPathPair(refl::AccessPath::Root()->MapItem("b")->ArrayItem(1), - refl::AccessPath::Root()->MapItem("b")->ArrayItem(1)); - EXPECT_TRUE(diff_a_c.has_value()); - EXPECT_TRUE(StructuralEqual()(diff_a_c, expected_diff_a_c)); - - Map> d = {{"a", {1, 2, 3}}}; - auto diff_a_d = StructuralEqual::GetFirstMismatch(a, d); - auto expected_diff_a_d = refl::AccessPathPair(refl::AccessPath::Root()->MapItem("b"), - refl::AccessPath::Root()->MapItemMissing("b")); - EXPECT_TRUE(diff_a_d.has_value()); - EXPECT_TRUE(StructuralEqual()(diff_a_d, expected_diff_a_d)); - - auto diff_d_a = StructuralEqual::GetFirstMismatch(d, a); - auto expected_diff_d_a = refl::AccessPathPair(refl::AccessPath::Root()->MapItemMissing("b"), - refl::AccessPath::Root()->MapItem("b")); -} - -TEST(StructuralEqualHash, FreeVar) { - TVar a = TVar("a"); - TVar b = TVar("b"); - EXPECT_TRUE(StructuralEqual::Equal(a, b, /*map_free_vars=*/true)); - EXPECT_FALSE(StructuralEqual::Equal(a, b)); - - EXPECT_NE(StructuralHash()(a), StructuralHash()(b)); - EXPECT_EQ(StructuralHash::Hash(a, /*map_free_vars=*/true), - StructuralHash::Hash(b, /*map_free_vars=*/true)); -} - -TEST(StructuralEqualHash, FuncDefAndIgnoreField) { - TVar x = TVar("x"); - TVar y = TVar("y"); - // comment fields are ignored - TFunc fa = TFunc({x}, {TInt(1), x}, String("comment a")); - TFunc fb = TFunc({y}, {TInt(1), y}, String("comment b")); - - TFunc fc = TFunc({x}, {TInt(1), TInt(2)}, String("comment c")); - - EXPECT_TRUE(StructuralEqual()(fa, fb)); - EXPECT_EQ(StructuralHash()(fa), StructuralHash()(fb)); - - EXPECT_FALSE(StructuralEqual()(fa, fc)); - auto diff_fa_fc = StructuralEqual::GetFirstMismatch(fa, fc); - auto expected_diff_fa_fc = refl::AccessPathPair(refl::AccessPath::FromSteps({ - refl::AccessStep::Attr("body"), - refl::AccessStep::ArrayItem(1), - }), - refl::AccessPath::FromSteps({ - refl::AccessStep::Attr("body"), - refl::AccessStep::ArrayItem(1), - })); - EXPECT_TRUE(diff_fa_fc.has_value()); - EXPECT_TRUE(StructuralEqual()(diff_fa_fc, expected_diff_fa_fc)); -} - -TEST(StructuralEqualHash, CustomTreeNode) { - TVar x = TVar("x"); - TVar y = TVar("y"); - // comment fields are ignored - TCustomFunc fa = TCustomFunc({x}, {TInt(1), x}, "comment a"); - TCustomFunc fb = TCustomFunc({y}, {TInt(1), y}, "comment b"); - - TCustomFunc fc = TCustomFunc({x}, {TInt(1), TInt(2)}, "comment c"); - - EXPECT_TRUE(StructuralEqual()(fa, fb)); - EXPECT_EQ(StructuralHash()(fa), StructuralHash()(fb)); - - EXPECT_FALSE(StructuralEqual()(fa, fc)); - auto diff_fa_fc = StructuralEqual::GetFirstMismatch(fa, fc); - auto expected_diff_fa_fc = - refl::AccessPathPair(refl::AccessPath::Root()->Attr("body")->ArrayItem(1), - refl::AccessPath::Root()->Attr("body")->ArrayItem(1)); - EXPECT_TRUE(diff_fa_fc.has_value()); - EXPECT_TRUE(StructuralEqual()(diff_fa_fc, expected_diff_fa_fc)); -} - -} // namespace diff --git a/ffi/tests/cpp/test_any.cc b/ffi/tests/cpp/test_any.cc deleted file mode 100644 index d1f56e1a93d9..000000000000 --- a/ffi/tests/cpp/test_any.cc +++ /dev/null @@ -1,415 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -#include -#include -#include - -#include "./testing_object.h" - -namespace { - -using namespace tvm::ffi; -using namespace tvm::ffi::testing; - -TEST(Any, Int) { - AnyView view0; - EXPECT_EQ(view0.CopyToTVMFFIAny().type_index, TypeIndex::kTVMFFINone); - - Optional opt_v0 = view0.as(); - EXPECT_TRUE(!opt_v0.has_value()); - - EXPECT_THROW( - { - try { - [[maybe_unused]] auto v0 = view0.cast(); - } catch (const Error& error) { - EXPECT_EQ(error.kind(), "TypeError"); - std::string what = error.what(); - EXPECT_NE(what.find("Cannot convert from type `None` to `int`"), std::string::npos); - throw; - } - }, - ::tvm::ffi::Error); - - AnyView view1 = 1; - EXPECT_EQ(view1.CopyToTVMFFIAny().type_index, TypeIndex::kTVMFFIInt); - EXPECT_EQ(view1.CopyToTVMFFIAny().v_int64, 1); - - auto int_v1 = view1.cast(); - EXPECT_EQ(int_v1, 1); - - int64_t v1 = 2; - view0 = v1; - EXPECT_EQ(view0.CopyToTVMFFIAny().type_index, TypeIndex::kTVMFFIInt); - EXPECT_EQ(view0.CopyToTVMFFIAny().v_int64, 2); -} - -TEST(Any, Enum) { - enum class ENum : int { - A = 1, - B = 2, - }; - - AnyView view0; - Optional opt_v0 = view0.as(); - EXPECT_TRUE(!opt_v0.has_value()); - - AnyView view1 = ENum::A; - EXPECT_EQ(view1.CopyToTVMFFIAny().type_index, TypeIndex::kTVMFFIInt); - EXPECT_EQ(view1.CopyToTVMFFIAny().v_int64, 1); - - ENum v1 = view1.cast(); - EXPECT_EQ(v1, ENum::A); -} - -TEST(Any, bool) { - AnyView view0; - Optional opt_v0 = view0.as(); - EXPECT_TRUE(!opt_v0.has_value()); - - EXPECT_THROW( - { - try { - [[maybe_unused]] auto v0 = view0.cast(); - } catch (const Error& error) { - EXPECT_EQ(error.kind(), "TypeError"); - std::string what = error.what(); - EXPECT_NE(what.find("Cannot convert from type `None` to `bool`"), std::string::npos); - throw; - } - }, - ::tvm::ffi::Error); - - AnyView view1 = true; - EXPECT_EQ(view1.CopyToTVMFFIAny().type_index, TypeIndex::kTVMFFIBool); - EXPECT_EQ(view1.CopyToTVMFFIAny().v_int64, 1); - - auto int_v1 = view1.cast(); - EXPECT_EQ(int_v1, 1); - - bool v1 = false; - view0 = v1; - EXPECT_EQ(view0.CopyToTVMFFIAny().type_index, TypeIndex::kTVMFFIBool); - EXPECT_EQ(view0.CopyToTVMFFIAny().v_int64, 0); -} - -TEST(Any, nullptrcmp) { - AnyView view0; - EXPECT_EQ(view0.CopyToTVMFFIAny().type_index, TypeIndex::kTVMFFINone); - EXPECT_TRUE(view0 == nullptr); - EXPECT_FALSE(view0 != nullptr); - - view0 = 1; - EXPECT_TRUE(view0 != nullptr); - EXPECT_FALSE(view0 == nullptr); - - Any any0 = view0; - EXPECT_TRUE(any0 != nullptr); - EXPECT_FALSE(any0 == nullptr); - - any0 = nullptr; - EXPECT_TRUE(any0 == nullptr); - EXPECT_FALSE(any0 != nullptr); -} - -TEST(Any, Float) { - AnyView view0; - EXPECT_EQ(view0.CopyToTVMFFIAny().type_index, TypeIndex::kTVMFFINone); - - Optional opt_v0 = view0.as(); - EXPECT_TRUE(!opt_v0.has_value()); - - EXPECT_THROW( - { - try { - [[maybe_unused]] auto v0 = view0.cast(); - } catch (const Error& error) { - EXPECT_EQ(error.kind(), "TypeError"); - std::string what = error.what(); - EXPECT_NE(what.find("Cannot convert from type `None` to `float`"), std::string::npos); - throw; - } - }, - ::tvm::ffi::Error); - - AnyView view1_int = 1; - auto float_v1 = view1_int.cast(); - EXPECT_EQ(float_v1, 1); - - AnyView view2 = 2.2; - EXPECT_EQ(view2.CopyToTVMFFIAny().type_index, TypeIndex::kTVMFFIFloat); - EXPECT_EQ(view2.CopyToTVMFFIAny().v_float64, 2.2); - - float v1 = 2; - view0 = v1; - EXPECT_EQ(view0.CopyToTVMFFIAny().type_index, TypeIndex::kTVMFFIFloat); - EXPECT_EQ(view0.CopyToTVMFFIAny().v_float64, 2); -} - -TEST(Any, Device) { - AnyView view0; - EXPECT_EQ(view0.CopyToTVMFFIAny().type_index, TypeIndex::kTVMFFINone); - - Optional opt_v0 = view0.as(); - EXPECT_TRUE(!opt_v0.has_value()); - - EXPECT_THROW( - { - try { - [[maybe_unused]] auto v0 = view0.cast(); - } catch (const Error& error) { - EXPECT_EQ(error.kind(), "TypeError"); - std::string what = error.what(); - EXPECT_NE(what.find("Cannot convert from type `None` to `Device`"), std::string::npos); - throw; - } - }, - ::tvm::ffi::Error); - - DLDevice device{kDLCUDA, 1}; - - AnyView view1_device = device; - auto dtype_v1 = view1_device.cast(); - EXPECT_EQ(dtype_v1.device_type, kDLCUDA); - EXPECT_EQ(dtype_v1.device_id, 1); - - Any any2 = DLDevice{kDLCPU, 0}; - TVMFFIAny ffi_v2 = details::AnyUnsafe::MoveAnyToTVMFFIAny(std::move(any2)); - EXPECT_EQ(ffi_v2.type_index, TypeIndex::kTVMFFIDevice); - EXPECT_EQ(ffi_v2.v_device.device_type, kDLCPU); - EXPECT_EQ(ffi_v2.v_device.device_id, 0); -} - -TEST(Any, DLTensor) { - AnyView view0; - - Optional opt_v0 = view0.as(); - EXPECT_TRUE(!opt_v0.has_value()); - - EXPECT_THROW( - { - try { - [[maybe_unused]] auto v0 = view0.cast(); - } catch (const Error& error) { - EXPECT_EQ(error.kind(), "TypeError"); - std::string what = error.what(); - EXPECT_NE(what.find("Cannot convert from type `None` to `DLTensor*`"), std::string::npos); - throw; - } - }, - ::tvm::ffi::Error); - - DLTensor dltensor; - - AnyView view1_dl = &dltensor; - auto dl_v1 = view1_dl.cast(); - EXPECT_EQ(dl_v1, &dltensor); -} - -TEST(Any, Object) { - AnyView view0; - EXPECT_EQ(view0.CopyToTVMFFIAny().type_index, TypeIndex::kTVMFFINone); - - // int object is not nullable - Optional opt_v0 = view0.as(); - EXPECT_TRUE(!opt_v0.has_value()); - - TInt v1(11); - EXPECT_EQ(v1.use_count(), 1); - // view won't increase refcount - AnyView view1 = v1; - EXPECT_EQ(v1.use_count(), 1); - // any will trigger ref count increase - Any any1 = v1; - EXPECT_EQ(v1.use_count(), 2); - // copy to another view - AnyView view2 = any1; - EXPECT_EQ(v1.use_count(), 2); - - // convert to weak raw object ptr - const TIntObj* v1_ptr = view2.cast(); - EXPECT_EQ(v1.use_count(), 2); - EXPECT_EQ(v1_ptr->value, 11); - Any any2 = v1_ptr; - EXPECT_EQ(v1.use_count(), 3); - EXPECT_TRUE(any2.as().has_value()); - - any2 = const_cast(v1_ptr); - EXPECT_TRUE(any2.as().has_value()); - - // convert to raw opaque ptr - void* raw_v1_ptr = const_cast(v1_ptr); - any2 = raw_v1_ptr; - EXPECT_TRUE(any2.as().value() == v1_ptr); - - // convert to ObjectRef - { - auto v1_obj_ref = view2.cast(); - EXPECT_EQ(v1.use_count(), 3); - any2 = v1_obj_ref; - EXPECT_EQ(v1.use_count(), 4); - EXPECT_TRUE(any2.as().has_value()); - any2.reset(); - } - - // convert that triggers error - EXPECT_THROW( - { - try { - [[maybe_unused]] auto v0 = view1.cast(); - } catch (const Error& error) { - EXPECT_EQ(error.kind(), "TypeError"); - std::string what = error.what(); - std::cout << what; - EXPECT_NE(what.find("Cannot convert from type `test.Int` to `test.Float`"), - std::string::npos); - throw; - } - }, - ::tvm::ffi::Error); - // Try to convert to number - auto number0 = any1.cast(); - EXPECT_EQ(v1.use_count(), 3); - EXPECT_TRUE(number0.as()); - EXPECT_EQ(number0.as()->value, 11); - EXPECT_TRUE(!any1.as().has_value()); - - auto int1 = view2.cast(); - EXPECT_EQ(v1.use_count(), 4); - any1.reset(); - EXPECT_EQ(v1.use_count(), 3); -} - -TEST(Any, ObjectRefWithFallbackTraits) { - // Test case for TPrimExpr fallback from Any - Any any1 = TPrimExpr("float32", 3.14); - auto v0 = any1.cast(); - EXPECT_EQ(v0->value, 3.14); - EXPECT_EQ(v0->dtype, "float32"); - - any1 = true; - auto v1 = any1.cast(); - EXPECT_EQ(v1->value, 1); - EXPECT_EQ(v1->dtype, "bool"); - - any1 = int64_t(42); - auto v2 = any1.cast(); - EXPECT_EQ(v2->value, 42); - EXPECT_EQ(v2->dtype, "int64"); - - any1 = 2.718; - auto v3 = any1.cast(); - EXPECT_EQ(v3->value, 2.718); - EXPECT_EQ(v3->dtype, "float32"); - - // Test case for TPrimExpr fallback from AnyView - TPrimExpr texpr1("float32", 3.14); - AnyView view1 = texpr1; - auto v4 = view1.cast(); - EXPECT_EQ(v4->value, 3.14); - EXPECT_EQ(v4->dtype, "float32"); - - view1 = true; - auto v5 = view1.cast(); - EXPECT_EQ(v5->value, 1); - EXPECT_EQ(v5->dtype, "bool"); - - view1 = int64_t(42); - auto v6 = view1.cast(); - EXPECT_EQ(v6->value, 42); - EXPECT_EQ(v6->dtype, "int64"); - - view1 = 2.718; - auto v7 = view1.cast(); - EXPECT_EQ(v7->value, 2.718); - EXPECT_EQ(v7->dtype, "float32"); - - // Test case for TPrimExpr fallback from Any with String - any1 = std::string("test_string"); - auto v8 = any1.cast(); - EXPECT_EQ(v8->dtype, "test_string"); - EXPECT_EQ(v8->value, 0); - - // Test case for TPrimExpr fallback from AnyView with String - view1 = "test_string"; - auto v9 = view1.cast(); - EXPECT_EQ(v9->dtype, "test_string"); - EXPECT_EQ(v9->value, 0); -} - -TEST(Any, CastVsAs) { - AnyView view0 = 1; - // as only runs strict check - auto opt_v0 = view0.as(); - EXPECT_TRUE(opt_v0.has_value()); - EXPECT_EQ(opt_v0.value(), 1); - - auto opt_v1 = view0.as(); - EXPECT_TRUE(!opt_v1.has_value()); - auto opt_v2 = view0.as(); - EXPECT_TRUE(!opt_v2.has_value()); - - // try_cast will try run the conversion. - auto opt_v3 = view0.try_cast(); - EXPECT_TRUE(opt_v3.has_value()); - EXPECT_EQ(opt_v3.value(), 1); - auto opt_v4 = view0.try_cast(); - EXPECT_TRUE(opt_v4.has_value()); - EXPECT_EQ(opt_v4.value(), 1); - - Any any1 = true; - auto opt_v5 = any1.as(); - EXPECT_TRUE(opt_v5.has_value()); - EXPECT_EQ(opt_v5.value(), 1); - - auto opt_v6 = any1.try_cast(); - EXPECT_TRUE(opt_v6.has_value()); - EXPECT_EQ(opt_v6.value(), 1); - - auto opt_v7 = any1.try_cast(); - EXPECT_TRUE(opt_v7.has_value()); -} - -TEST(Any, ObjectMove) { - Any any1 = TPrimExpr("float32", 3.14); - auto v0 = std::move(any1).cast(); - EXPECT_EQ(v0->value, 3.14); - EXPECT_EQ(v0.use_count(), 1); - EXPECT_TRUE(any1 == nullptr); -} - -TEST(Any, AnyEqualHash) { - // small string - Any a = "a1"; - // on heap allocated string - Any b = String(std::string("a1")); - EXPECT_EQ(a.type_index(), TypeIndex::kTVMFFISmallStr); - EXPECT_EQ(b.type_index(), TypeIndex::kTVMFFIStr); - EXPECT_TRUE(AnyEqual()(a, b)); - EXPECT_EQ(AnyHash()(a), AnyHash()(b)); - - Any c = Bytes("a1", 2); - Any d = Bytes(std::string("a1")); - EXPECT_EQ(c.type_index(), TypeIndex::kTVMFFISmallBytes); - EXPECT_EQ(d.type_index(), TypeIndex::kTVMFFIBytes); - EXPECT_TRUE(AnyEqual()(c, d)); - EXPECT_EQ(AnyHash()(c), AnyHash()(d)); -} - -} // namespace diff --git a/ffi/tests/cpp/test_array.cc b/ffi/tests/cpp/test_array.cc deleted file mode 100644 index 321af7ae16ac..000000000000 --- a/ffi/tests/cpp/test_array.cc +++ /dev/null @@ -1,286 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -#include -#include -#include - -#include "./testing_object.h" - -namespace { - -using namespace tvm::ffi; -using namespace tvm::ffi::testing; - -TEST(Array, Basic) { - Array arr = {TInt(11), TInt(12)}; - TInt v1 = arr[0]; - EXPECT_EQ(v1->value, 11); - EXPECT_EQ(v1.use_count(), 2); - EXPECT_EQ(arr[1]->value, 12); -} - -TEST(Array, COWSet) { - Array arr = {TInt(11), TInt(12)}; - Array arr2 = arr; - EXPECT_EQ(arr.use_count(), 2); - arr.Set(1, TInt(13)); - EXPECT_EQ(arr.use_count(), 1); - EXPECT_EQ(arr[1]->value, 13); - EXPECT_EQ(arr2[1]->value, 12); -} - -TEST(Array, MutateInPlaceForUniqueReference) { - TInt x(1); - Array arr{x, x}; - EXPECT_TRUE(arr.unique()); - auto* before = arr.get(); - - arr.MutateByApply([](TInt) { return TInt(2); }); - auto* after = arr.get(); - EXPECT_EQ(before, after); -} - -TEST(Array, CopyWhenMutatingNonUniqueReference) { - TInt x(1); - Array arr{x, x}; - Array arr2 = arr; - - EXPECT_TRUE(!arr.unique()); - auto* before = arr.get(); - - arr.MutateByApply([](TInt) { return TInt(2); }); - auto* after = arr.get(); - EXPECT_NE(before, after); -} - -TEST(Array, Map) { - // Basic functionality - TInt x(1), y(1); - Array var_arr{x, y}; - Array expr_arr = - var_arr.Map([](TInt var) -> TNumber { return TFloat(static_cast(var->value + 1)); }); - - EXPECT_NE(var_arr.get(), expr_arr.get()); - EXPECT_TRUE(expr_arr[0]->IsInstance()); - EXPECT_TRUE(expr_arr[1]->IsInstance()); -} - -TEST(Array, Iterator) { - Array array{1, 2, 3}; - std::vector vector(array.begin(), array.end()); - EXPECT_EQ(vector[1], 2); -} - -TEST(Array, PushPop) { - Array a; - std::vector b; - for (int i = 0; i < 10; ++i) { - a.push_back(i); - b.push_back(i); - ASSERT_EQ(a.front(), b.front()); - ASSERT_EQ(a.back(), b.back()); - ASSERT_EQ(a.size(), b.size()); - int n = static_cast(a.size()); - for (int j = 0; j < n; ++j) { - ASSERT_EQ(a[j], b[j]); - } - } - for (int i = 9; i >= 0; --i) { - ASSERT_EQ(a.front(), b.front()); - ASSERT_EQ(a.back(), b.back()); - ASSERT_EQ(a.size(), b.size()); - a.pop_back(); - b.pop_back(); - int n = static_cast(a.size()); - for (int j = 0; j < n; ++j) { - ASSERT_EQ(a[j], b[j]); - } - } - ASSERT_EQ(a.empty(), true); -} - -TEST(Array, ResizeReserveClear) { - for (size_t n = 0; n < 10; ++n) { - Array a; - Array b; - a.resize(n); - b.reserve(n); - ASSERT_EQ(a.size(), n); - ASSERT_GE(a.capacity(), n); - a.clear(); - b.clear(); - ASSERT_EQ(a.size(), 0); - ASSERT_EQ(b.size(), 0); - } -} - -TEST(Array, InsertErase) { - Array a; - std::vector b; - for (int n = 1; n <= 10; ++n) { - a.insert(a.end(), n); - b.insert(b.end(), n); - for (int pos = 0; pos <= n; ++pos) { - a.insert(a.begin() + pos, pos); - b.insert(b.begin() + pos, pos); - ASSERT_EQ(a.front(), b.front()); - ASSERT_EQ(a.back(), b.back()); - ASSERT_EQ(a.size(), n + 1); - ASSERT_EQ(b.size(), n + 1); - for (int k = 0; k <= n; ++k) { - ASSERT_EQ(a[k], b[k]); - } - a.erase(a.begin() + pos); - b.erase(b.begin() + pos); - } - ASSERT_EQ(a.front(), b.front()); - ASSERT_EQ(a.back(), b.back()); - ASSERT_EQ(a.size(), n); - } -} - -TEST(Array, InsertEraseRange) { - Array range_a{-1, -2, -3, -4}; - std::vector range_b{-1, -2, -3, -4}; - Array a; - std::vector b; - - static_assert(std::is_same_v); - for (size_t n = 1; n <= 10; ++n) { - a.insert(a.end(), static_cast(n)); - b.insert(b.end(), static_cast(n)); - for (size_t pos = 0; pos <= n; ++pos) { - a.insert(a.begin() + pos, range_a.begin(), range_a.end()); - b.insert(b.begin() + pos, range_b.begin(), range_b.end()); - ASSERT_EQ(a.front(), b.front()); - ASSERT_EQ(a.back(), b.back()); - ASSERT_EQ(a.size(), n + range_a.size()); - ASSERT_EQ(b.size(), n + range_b.size()); - size_t m = n + range_a.size(); - for (size_t k = 0; k < m; ++k) { - ASSERT_EQ(a[k], b[k]); - } - a.erase(a.begin() + pos, a.begin() + pos + range_a.size()); - b.erase(b.begin() + pos, b.begin() + pos + range_b.size()); - } - ASSERT_EQ(a.front(), b.front()); - ASSERT_EQ(a.back(), b.back()); - ASSERT_EQ(a.size(), n); - } -} - -TEST(Array, FuncArrayAnyArg) { - Function fadd_one = Function::FromTyped([](Array a) -> Any { return a[0].cast() + 1; }); - EXPECT_EQ(fadd_one(Array{1}).cast(), 2); -} - -TEST(Array, MapUniquePropogation) { - // Basic functionality - Array var_arr{TInt(1), TInt(2)}; - var_arr.MutateByApply([](TInt x) -> TInt { - EXPECT_TRUE(x.unique()); - return x; - }); -} - -TEST(Array, AnyImplicitConversion) { - Array arr0_mixed = {11.1, 1}; - EXPECT_EQ(arr0_mixed[1].cast(), 1); - - AnyView view0 = arr0_mixed; - auto arr0_float = view0.cast>(); - // they are not the same because arr_mixed - // stores arr_mixed[1] as int but we need to convert to float - EXPECT_TRUE(!arr0_float.same_as(arr0_mixed)); - EXPECT_EQ(arr0_float[1], 1.0); - - Any any1 = arr0_float; - // if storage check passes, the same array get returned - auto arr1_float = any1.cast>(); - EXPECT_TRUE(arr1_float.same_as(arr0_float)); - // total count equals 3 include any1 - EXPECT_EQ(arr1_float.use_count(), 3); - - // convert to Array do not need any conversion - auto arr1_mixed = any1.cast>(); - EXPECT_TRUE(arr1_mixed.same_as(arr1_float)); - EXPECT_EQ(arr1_float.use_count(), 4); -} - -TEST(Array, AnyConvertCheck) { - Array arr = {11.1, 1}; - EXPECT_EQ(arr[1].cast(), 1); - - AnyView view0 = arr; - auto arr1 = view0.cast>(); - EXPECT_EQ(arr1[0], 11.1); - EXPECT_EQ(arr1[1], 1.0); - - Any any1 = arr; - - EXPECT_THROW( - { - try { - [[maybe_unused]] auto arr2 = any1.cast>(); - } catch (const Error& error) { - EXPECT_EQ(error.kind(), "TypeError"); - std::string what = error.what(); - EXPECT_NE(what.find("Cannot convert from type `Array[index 0: float]` to `Array`"), - std::string::npos); - throw; - } - }, - ::tvm::ffi::Error); - - Array> arr_nested = {{}, {TInt(1), TFloat(2)}}; - any1 = arr_nested; - auto arr1_nested = any1.cast>>(); - EXPECT_EQ(arr1_nested.use_count(), 3); - - EXPECT_THROW( - { - try { - [[maybe_unused]] auto arr2 = any1.cast>>(); - } catch (const Error& error) { - EXPECT_EQ(error.kind(), "TypeError"); - std::string what = error.what(); - EXPECT_NE(what.find("`Array[index 1: Array[index 0: test.Int]]` to `Array>`"), - std::string::npos); - throw; - } - }, - ::tvm::ffi::Error); -} - -TEST(Array, Upcast) { - Array a0 = {1, 2, 3}; - Array a1 = a0; - EXPECT_EQ(a1[0].cast(), 1); - EXPECT_EQ(a1[1].cast(), 2); - EXPECT_EQ(a1[2].cast(), 3); - - Array> a2 = {a0}; - Array> a3 = a2; - Array> a4 = a2; - - static_assert(details::type_contains_v, Array>); - static_assert(details::type_contains_v>); -} - -} // namespace diff --git a/ffi/tests/cpp/test_c_ffi_abi.cc b/ffi/tests/cpp/test_c_ffi_abi.cc deleted file mode 100644 index e6c6116edd8c..000000000000 --- a/ffi/tests/cpp/test_c_ffi_abi.cc +++ /dev/null @@ -1,31 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -#include -#include - -namespace { - -TEST(ABIHeaderAlignment, Default) { - TVMFFIObject value; - value.type_index = 10; - EXPECT_EQ(reinterpret_cast(&value)->type_index, 10); - static_assert(sizeof(TVMFFIObject) == 24); -} - -} // namespace diff --git a/ffi/tests/cpp/test_dtype.cc b/ffi/tests/cpp/test_dtype.cc deleted file mode 100644 index 79fc9d7c2da1..000000000000 --- a/ffi/tests/cpp/test_dtype.cc +++ /dev/null @@ -1,130 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -#include -#include -#include -#include -#include - -namespace { - -using namespace tvm::ffi; - -TEST(DType, StringConversion) { - DLDataType dtype = DLDataType{kDLFloat, 32, 1}; - EXPECT_EQ(DLDataTypeToString(dtype), "float32"); - EXPECT_EQ(StringToDLDataType("float32"), dtype); - - dtype = DLDataType{kDLInt, 16, 2}; - EXPECT_EQ(DLDataTypeToString(dtype), "int16x2"); - EXPECT_EQ(StringToDLDataType("int16x2"), dtype); - - dtype = DLDataType{kDLOpaqueHandle, 0, 0}; - EXPECT_EQ(DLDataTypeToString(dtype), ""); - EXPECT_EQ(StringToDLDataType("void"), dtype); - - // test bfloat with lanes - dtype = DLDataType{kDLBfloat, 16, 2}; - EXPECT_EQ(DLDataTypeToString(dtype), "bfloat16x2"); - EXPECT_EQ(StringToDLDataType("bfloat16x2"), dtype); - - // test float8 - dtype = DLDataType{kDLFloat8_e4m3fn, 8, 2}; - EXPECT_EQ(DLDataTypeToString(dtype), "float8_e4m3fnx2"); - EXPECT_EQ(StringToDLDataType("float8_e4m3fnx2"), dtype); -} - -TEST(DType, StringConversionAllDLPackTypes) { - std::vector> test_cases = { - {DLDataType{kDLFloat, 32, 1}, "float32"}, - {DLDataType{kDLInt, 16, 1}, "int16"}, - {DLDataType{kDLUInt, 16, 1}, "uint16"}, - {DLDataType{kDLBfloat, 16, 1}, "bfloat16"}, - {DLDataType{kDLFloat8_e3m4, 8, 1}, "float8_e3m4"}, - {DLDataType{kDLFloat8_e4m3, 8, 1}, "float8_e4m3"}, - {DLDataType{kDLFloat8_e4m3b11fnuz, 8, 1}, "float8_e4m3b11fnuz"}, - {DLDataType{kDLFloat8_e4m3fn, 8, 1}, "float8_e4m3fn"}, - {DLDataType{kDLFloat8_e4m3fnuz, 8, 1}, "float8_e4m3fnuz"}, - {DLDataType{kDLFloat8_e5m2, 8, 1}, "float8_e5m2"}, - {DLDataType{kDLFloat8_e5m2fnuz, 8, 1}, "float8_e5m2fnuz"}, - {DLDataType{kDLFloat8_e8m0fnu, 8, 1}, "float8_e8m0fnu"}, - {DLDataType{kDLFloat6_e2m3fn, 6, 1}, "float6_e2m3fn"}, - {DLDataType{kDLFloat6_e3m2fn, 6, 1}, "float6_e3m2fn"}, - {DLDataType{kDLFloat4_e2m1fn, 4, 1}, "float4_e2m1fn"}, - }; - - for (const auto& [dtype, str] : test_cases) { - EXPECT_EQ(DLDataTypeToString(dtype), str); - EXPECT_EQ(StringToDLDataType(str), dtype); - } -} - -TEST(DataType, AnyConversion) { - AnyView view0; - EXPECT_EQ(view0.CopyToTVMFFIAny().type_index, TypeIndex::kTVMFFINone); - - Optional opt_v0 = view0.as(); - EXPECT_TRUE(!opt_v0.has_value()); - - EXPECT_THROW( - { - try { - [[maybe_unused]] auto v0 = view0.cast(); - } catch (const Error& error) { - EXPECT_EQ(error.kind(), "TypeError"); - std::string what = error.what(); - EXPECT_NE(what.find("Cannot convert from type `None` to `DataType`"), std::string::npos); - throw; - } - }, - ::tvm::ffi::Error); - - DLDataType dtype{kDLFloat, 32, 1}; - - AnyView view1_dtype = dtype; - auto dtype_v1 = view1_dtype.cast(); - EXPECT_EQ(dtype_v1.code, kDLFloat); - EXPECT_EQ(dtype_v1.bits, 32); - EXPECT_EQ(dtype_v1.lanes, 1); - - Any any2 = DLDataType{kDLInt, 16, 2}; - TVMFFIAny ffi_v2 = details::AnyUnsafe::MoveAnyToTVMFFIAny(std::move(any2)); - EXPECT_EQ(ffi_v2.type_index, TypeIndex::kTVMFFIDataType); - EXPECT_EQ(ffi_v2.v_dtype.code, kDLInt); - EXPECT_EQ(ffi_v2.v_dtype.bits, 16); - EXPECT_EQ(ffi_v2.v_dtype.lanes, 2); -} - -// String can be automatically converted to DLDataType -TEST(DataType, AnyConversionWithString) { - AnyView view0 = "float32"; - - Optional opt_v0 = view0.try_cast(); - DLDataType dtype_v0 = opt_v0.value(); - EXPECT_EQ(dtype_v0.code, kDLFloat); - EXPECT_EQ(dtype_v0.bits, 32); - EXPECT_EQ(dtype_v0.lanes, 1); - - Any any = String("bfloat16x2"); - Optional opt_v1 = any.try_cast(); - EXPECT_EQ(opt_v1.value().code, kDLBfloat); - EXPECT_EQ(opt_v1.value().bits, 16); - EXPECT_EQ(opt_v1.value().lanes, 2); -} -} // namespace diff --git a/ffi/tests/cpp/test_error.cc b/ffi/tests/cpp/test_error.cc deleted file mode 100644 index 9938603a47ba..000000000000 --- a/ffi/tests/cpp/test_error.cc +++ /dev/null @@ -1,70 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -#include -#include -#include -#include - -namespace { - -using namespace tvm::ffi; - -void ThrowRuntimeError() { TVM_FFI_THROW(RuntimeError) << "test0"; } - -TEST(Error, Traceback) { - EXPECT_THROW( - { - try { - ThrowRuntimeError(); - } catch (const Error& error) { - EXPECT_EQ(error.message(), "test0"); - EXPECT_EQ(error.kind(), "RuntimeError"); - std::string what = error.what(); - EXPECT_NE(what.find("line"), std::string::npos); - EXPECT_NE(what.find("ThrowRuntimeError"), std::string::npos); - EXPECT_NE(what.find("RuntimeError: test0"), std::string::npos); - throw; - } - }, - ::tvm::ffi::Error); -} - -TEST(CheckError, Traceback) { - EXPECT_THROW( - { - try { - TVM_FFI_ICHECK_GT(2, 3); - } catch (const Error& error) { - EXPECT_EQ(error.kind(), "InternalError"); - std::string what = error.what(); - EXPECT_NE(what.find("line"), std::string::npos); - EXPECT_NE(what.find("2 > 3"), std::string::npos); - throw; - } - }, - ::tvm::ffi::Error); -} - -TEST(Error, AnyConvert) { - Any any = Error("TypeError", "here", "test0"); - Optional opt_err = any.as(); - EXPECT_EQ(opt_err.value().kind(), "TypeError"); - EXPECT_EQ(opt_err.value().message(), "here"); -} -} // namespace diff --git a/ffi/tests/cpp/test_example.cc b/ffi/tests/cpp/test_example.cc deleted file mode 100644 index ee450bcf4063..000000000000 --- a/ffi/tests/cpp/test_example.cc +++ /dev/null @@ -1,288 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -#include -#include -#include -#include -#include -#include -#include -#include -#include - -// test-cases used in example code -namespace { - -void ExampleAny() { - namespace ffi = tvm::ffi; - // Create an Any from various types - ffi::Any int_value = 42; - ffi::Any float_value = 3.14; - ffi::Any string_value = "hello world"; - - // AnyView provides a lightweight view without ownership - ffi::AnyView view = int_value; - // we can cast Any/AnyView to a specific type - int extracted = view.cast(); - EXPECT_EQ(extracted, 42); - - // If we are not sure about the type - // we can use as to get an optional value - std::optional maybe_int = view.as(); - if (maybe_int.has_value()) { - EXPECT_EQ(maybe_int.value(), 42); - } - // Try cast is another version that will try to run the type - // conversion even if the type does not exactly match - std::optional maybe_int_try = view.try_cast(); - if (maybe_int_try.has_value()) { - EXPECT_EQ(maybe_int_try.value(), 42); - } -} - -TEST(Example, Any) { ExampleAny(); } - -void ExampleFunctionFromPacked() { - namespace ffi = tvm::ffi; - // Create a function from a typed lambda - ffi::Function fadd1 = - ffi::Function::FromPacked([](const ffi::AnyView* args, int32_t num_args, ffi::Any* rv) { - TVM_FFI_ICHECK_EQ(num_args, 1); - int a = args[0].cast(); - *rv = a + 1; - }); - int b = fadd1(1).cast(); - EXPECT_EQ(b, 2); -} - -void ExampleFunctionFromTyped() { - namespace ffi = tvm::ffi; - // Create a function from a typed lambda - ffi::Function fadd1 = ffi::Function::FromTyped([](const int a) -> int { return a + 1; }); - int b = fadd1(1).cast(); - EXPECT_EQ(b, 2); -} - -void ExampleFunctionPassFunction() { - namespace ffi = tvm::ffi; - // Create a function from a typed lambda - ffi::Function fapply = ffi::Function::FromTyped( - [](const ffi::Function f, ffi::Any param) { return f(param.cast()); }); - ffi::Function fadd1 = ffi::Function::FromTyped( // - [](const int a) -> int { return a + 1; }); - int b = fapply(fadd1, 2).cast(); - EXPECT_EQ(b, 3); -} - -void ExamplegGlobalFunctionRegistry() { - namespace ffi = tvm::ffi; - ffi::reflection::GlobalDef().def("xyz.add1", [](const int a) -> int { return a + 1; }); - ffi::Function fadd1 = ffi::Function::GetGlobalRequired("xyz.add1"); - int b = fadd1(1).cast(); - EXPECT_EQ(b, 2); -} - -void FuncThrowError() { - namespace ffi = tvm::ffi; - TVM_FFI_THROW(TypeError) << "test0"; -} - -void ExampleErrorHandling() { - namespace ffi = tvm::ffi; - try { - FuncThrowError(); - } catch (const ffi::Error& e) { - EXPECT_EQ(e.kind(), "TypeError"); - EXPECT_EQ(e.message(), "test0"); - std::cout << e.traceback() << std::endl; - } -} - -TEST(Example, Function) { - ExampleFunctionFromPacked(); - ExampleFunctionFromTyped(); - ExampleFunctionPassFunction(); - ExamplegGlobalFunctionRegistry(); - ExampleErrorHandling(); -} - -struct CPUNDAlloc { - void AllocData(DLTensor* tensor) { tensor->data = malloc(tvm::ffi::GetDataSize(*tensor)); } - void FreeData(DLTensor* tensor) { free(tensor->data); } -}; - -void ExampleTensor() { - namespace ffi = tvm::ffi; - ffi::Shape shape = {1, 2, 3}; - DLDataType dtype = {kDLFloat, 32, 1}; - DLDevice device = {kDLCPU, 0}; - ffi::Tensor tensor = ffi::Tensor::FromNDAlloc(CPUNDAlloc(), shape, dtype, device); -} - -void ExampleTensorDLPack() { - namespace ffi = tvm::ffi; - ffi::Shape shape = {1, 2, 3}; - DLDataType dtype = {kDLFloat, 32, 1}; - DLDevice device = {kDLCPU, 0}; - ffi::Tensor tensor = ffi::Tensor::FromNDAlloc(CPUNDAlloc(), shape, dtype, device); - // convert to DLManagedTensorVersioned - DLManagedTensorVersioned* dlpack = tensor.ToDLPackVersioned(); - // load back from DLManagedTensorVersioned - ffi::Tensor tensor2 = ffi::Tensor::FromDLPackVersioned(dlpack); -} - -TEST(Example, Tensor) { - ExampleTensor(); - ExampleTensorDLPack(); -} - -void ExampleString() { - namespace ffi = tvm::ffi; - ffi::String str = "hello world"; - EXPECT_EQ(str.size(), 11); - std::string std_str = str; - EXPECT_EQ(std_str, "hello world"); -} - -TEST(Example, String) { ExampleString(); } - -void ExampleArray() { - namespace ffi = tvm::ffi; - ffi::Array numbers = {1, 2, 3}; - EXPECT_EQ(numbers.size(), 3); - EXPECT_EQ(numbers[0], 1); - - ffi::Function head = ffi::Function::FromTyped([](const ffi::Array a) { return a[0]; }); - EXPECT_EQ(head(numbers).cast(), 1); - - try { - // throw an error because 2.2 is not int - head(ffi::Array({1, 2.2})); - } catch (const ffi::Error& e) { - EXPECT_EQ(e.kind(), "TypeError"); - } -} - -void ExampleTuple() { - namespace ffi = tvm::ffi; - ffi::Tuple tup(42, "hello", true); - - EXPECT_EQ(tup.get<0>(), 42); - EXPECT_EQ(tup.get<1>(), "hello"); - EXPECT_EQ(tup.get<2>(), true); -} - -TEST(Example, Array) { - ExampleArray(); - ExampleTuple(); -} - -void ExampleMap() { - namespace ffi = tvm::ffi; - - ffi::Map map0 = {{"Alice", 100}, {"Bob", 95}}; - - EXPECT_EQ(map0.size(), 2); - EXPECT_EQ(map0.at("Alice"), 100); - EXPECT_EQ(map0.count("Alice"), 1); -} - -TEST(Example, Map) { ExampleMap(); } - -void ExampleOptional() { - namespace ffi = tvm::ffi; - ffi::Optional opt0 = 100; - EXPECT_EQ(opt0.has_value(), true); - EXPECT_EQ(opt0.value(), 100); - - ffi::Optional opt1; - EXPECT_EQ(opt1.has_value(), false); - EXPECT_EQ(opt1.value_or("default"), "default"); -} - -TEST(Example, Optional) { ExampleOptional(); } - -void ExampleVariant() { - namespace ffi = tvm::ffi; - ffi::Variant var0 = 100; - EXPECT_EQ(var0.get(), 100); - - var0 = ffi::String("hello"); - std::optional maybe_str = var0.as(); - EXPECT_EQ(maybe_str.value(), "hello"); - - std::optional maybe_int2 = var0.as(); - EXPECT_EQ(maybe_int2.has_value(), false); -} - -TEST(Example, Variant) { ExampleVariant(); } - -// Step 1: Define the object class (stores the actual data) -class MyIntPairObj : public tvm::ffi::Object { - public: - int64_t a; - int64_t b; - - MyIntPairObj() = default; - MyIntPairObj(int64_t a, int64_t b) : a(a), b(b) {} - - // Required: declare type information - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("example.MyIntPair", MyIntPairObj, tvm::ffi::Object); -}; - -// Step 2: Define the reference wrapper (user-facing interface) -class MyIntPair : public tvm::ffi::ObjectRef { - public: - // Constructor - explicit MyIntPair(int64_t a, int64_t b) { data_ = tvm::ffi::make_object(a, b); } - - // Required: define object reference methods - TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(MyIntPair, tvm::ffi::ObjectRef, MyIntPairObj); -}; - -void ExampleObjectPtr() { - namespace ffi = tvm::ffi; - ffi::ObjectPtr obj = ffi::make_object(100, 200); - EXPECT_EQ(obj->a, 100); - EXPECT_EQ(obj->b, 200); -} - -void ExampleObjectRef() { - namespace ffi = tvm::ffi; - MyIntPair pair(100, 200); - EXPECT_EQ(pair->a, 100); - EXPECT_EQ(pair->b, 200); -} - -void ExampleObjectRefAny() { - namespace ffi = tvm::ffi; - MyIntPair pair(100, 200); - ffi::Any any = pair; - MyIntPair pair2 = any.cast(); - EXPECT_EQ(pair2->a, 100); - EXPECT_EQ(pair2->b, 200); -} - -TEST(Example, ObjectPtr) { - ExampleObjectPtr(); - ExampleObjectRef(); - ExampleObjectRefAny(); -} - -} // namespace diff --git a/ffi/tests/cpp/test_function.cc b/ffi/tests/cpp/test_function.cc deleted file mode 100644 index c3c484f33317..000000000000 --- a/ffi/tests/cpp/test_function.cc +++ /dev/null @@ -1,239 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -#include -#include -#include -#include -#include - -#include "./testing_object.h" - -namespace { - -using namespace tvm::ffi; -using namespace tvm::ffi::testing; - -TEST(Func, FromPacked) { - Function fadd1 = Function::FromPacked([](const AnyView* args, int32_t num_args, Any* rv) { - EXPECT_EQ(num_args, 1); - int32_t a = args[0].cast(); - *rv = a + 1; - }); - int b = fadd1(1).cast(); - EXPECT_EQ(b, 2); - - Function fadd2 = Function::FromPacked([](const AnyView* args, int32_t num_args, Any* rv) { - EXPECT_EQ(num_args, 1); - auto a = args[0].cast(); - EXPECT_EQ(a.use_count(), 2); - *rv = a->value + 1; - }); - EXPECT_EQ(fadd2(TInt(12)).cast(), 13); -} - -TEST(Func, PackedArgs) { - Function fadd1 = Function::FromPacked([](PackedArgs args, Any* rv) { - EXPECT_EQ(args.size(), 1); - int32_t a = args[0].cast(); - *rv = a + 1; - }); - int b = fadd1(1).cast(); - EXPECT_EQ(b, 2); - - Function fadd2 = Function::FromPacked([](PackedArgs args, Any* rv) { - EXPECT_EQ(args.size(), 1); - TInt a = args[0].cast(); - EXPECT_EQ(a.use_count(), 2); - *rv = a->value + 1; - }); - EXPECT_EQ(fadd2(TInt(12)).cast(), 13); - - TInt v(12); - AnyView data[3]; - PackedArgs::Fill(data, 3, 1, v); - EXPECT_EQ(data[0].cast(), 3); - EXPECT_EQ(data[1].cast(), 1); - EXPECT_EQ(data[2].cast()->value, 12); -} - -TEST(Func, FromTyped) { - // try decution - Function fadd1 = Function::FromTyped([](const int32_t& a) -> int { return a + 1; }); - int b = fadd1(1).cast(); - EXPECT_EQ(b, 2); - - // convert that triggers error - EXPECT_THROW( - { - try { - fadd1(1.1); - } catch (const Error& error) { - EXPECT_EQ(error.kind(), "TypeError"); - EXPECT_EQ(error.message(), - "Mismatched type on argument #0 when calling: `(0: int) -> int`. " - "Expected `int` but got `float`"); - throw; - } - }, - ::tvm::ffi::Error); - - // convert that triggers error - EXPECT_THROW( - { - try { - fadd1(); - } catch (const Error& error) { - EXPECT_EQ(error.kind(), "TypeError"); - EXPECT_EQ(error.message(), - "Mismatched number of arguments when calling: `(0: int) -> int`. " - "Expected 1 but got 0 arguments"); - throw; - } - }, - ::tvm::ffi::Error); - - // try decution - Function fpass_and_return = Function::FromTyped( - [](TInt x, int value, AnyView z) -> Function { - EXPECT_EQ(x.use_count(), 2); - EXPECT_EQ(x->value, value); - if (auto opt = z.as()) { - EXPECT_EQ(value, *opt); - } - return Function::FromTyped([value](int x) -> int { return x + value; }); - }, - "fpass_and_return"); - TInt a(11); - auto fret = fpass_and_return(std::move(a), 11, 11).cast(); - EXPECT_EQ(fret(12).cast(), 23); - - EXPECT_THROW( - { - try { - fpass_and_return(); - } catch (const Error& error) { - EXPECT_EQ(error.kind(), "TypeError"); - EXPECT_EQ(error.message(), - "Mismatched number of arguments when calling: " - "`fpass_and_return(0: test.Int, 1: int, 2: AnyView) -> ffi.Function`. " - "Expected 3 but got 0 arguments"); - throw; - } - }, - ::tvm::ffi::Error); - - Function fconcact = - Function::FromTyped([](const String& a, const String& b) -> String { return a + b; }); - EXPECT_EQ(fconcact("abc", "def").cast(), "abcdef"); -} - -TEST(Func, PassReturnAny) { - Function fadd_one = Function::FromTyped([](Any a) -> Any { return a.cast() + 1; }); - EXPECT_EQ(fadd_one(1).cast(), 2); -} - -TEST(Func, Global) { - Function::SetGlobal("testing.add1", - Function::FromTyped([](const int32_t& a) -> int { return a + 1; })); - auto fadd1 = Function::GetGlobalRequired("testing.add1"); - int b = fadd1(1).cast(); - EXPECT_EQ(b, 2); - auto fnot_exist = Function::GetGlobal("testing.not_existing_func"); - EXPECT_TRUE(!fnot_exist); - - auto fname_functor = - Function::GetGlobal("ffi.FunctionListGlobalNamesFunctor").value()().cast(); - Array names; - int len = fname_functor(-1).cast(); - for (int i = 0; i < len; ++i) { - names.push_back(fname_functor(i).cast()); - } - EXPECT_TRUE(std::find(names.begin(), names.end(), "testing.add1") != names.end()); -} - -TEST(Func, TypedFunction) { - TypedFunction fadd1 = [](int a) -> int { return a + 1; }; - EXPECT_EQ(fadd1(1), 2); - - TypedFunction fadd2([](int a) -> int { return a + 2; }); - EXPECT_EQ(fadd2(1), 3); - EXPECT_EQ(fadd2.packed()(1).cast(), 3); - - TypedFunction fcheck_int; - EXPECT_TRUE(fcheck_int == nullptr); - fcheck_int = [](int a) -> void { EXPECT_EQ(a, 1); }; - fcheck_int(1); -} - -TEST(Func, TypedFunctionAsAny) { - TypedFunction fadd1 = [](int a) -> int { return a + 1; }; - Any fany(std::move(fadd1)); - EXPECT_TRUE(fadd1 == nullptr); - auto fadd1_dup = fany.cast>(); - EXPECT_EQ(fadd1_dup(1), 2); -} - -TEST(Func, TypedFunctionAsAnyView) { - TypedFunction fadd2 = [](int a) -> int { return a + 2; }; - AnyView fview(fadd2); - auto fadd2_dup = fview.cast>(); - EXPECT_EQ(fadd2_dup(1), 3); -} - -TEST(Func, ObjectRefWithFallbackTraits) { - // test cases to test automatic type conversion via ObjectRefWithFallbackTraits - // through TPrimExpr - Function freturn_primexpr = Function::FromTyped([](TPrimExpr a) -> TPrimExpr { return a; }); - - auto result_int = freturn_primexpr(1).cast(); - EXPECT_EQ(result_int->dtype, "int64"); - EXPECT_EQ(result_int->value, 1); - - // Test case for float - auto result_float = freturn_primexpr(2.5).cast(); - EXPECT_EQ(result_float->dtype, "float32"); - EXPECT_EQ(result_float->value, 2.5); - - // Test case for bool - auto result_bool = freturn_primexpr(true).cast(); - EXPECT_EQ(result_bool->dtype, "bool"); - EXPECT_EQ(result_bool->value, 1); - - // Test case for string - auto result_string = freturn_primexpr("test_string").cast(); - EXPECT_EQ(result_string->dtype, "test_string"); - EXPECT_EQ(result_string->value, 0); - - EXPECT_THROW( - { - try { - freturn_primexpr(TInt(1)); - } catch (const Error& error) { - EXPECT_EQ(error.kind(), "TypeError"); - EXPECT_EQ( - error.message(), - "Mismatched type on argument #0 when calling: `(0: test.PrimExpr) -> test.PrimExpr`. " - "Expected `test.PrimExpr` but got `test.Int`"); - throw; - } - }, - ::tvm::ffi::Error); -} - -} // namespace diff --git a/ffi/tests/cpp/test_map.cc b/ffi/tests/cpp/test_map.cc deleted file mode 100644 index 98d8427c23a1..000000000000 --- a/ffi/tests/cpp/test_map.cc +++ /dev/null @@ -1,366 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -#include -#include -#include -#include - -#include "./testing_object.h" - -namespace { - -using namespace tvm::ffi; -using namespace tvm::ffi::testing; - -TEST(Map, Basic) { - Map map0; - TInt k0(0); - map0.Set(k0, 1); - - EXPECT_EQ(map0.size(), 1); - - map0.Set(k0, 2); - EXPECT_EQ(map0.size(), 1); - - auto it = map0.find(k0); - EXPECT_TRUE(it != map0.end()); - EXPECT_EQ((*it).second, 2); -} - -TEST(Map, PODKey) { - Map map0; - - // int as key - map0.Set(1, 2); - // float key is different - map0.Set(1.1, 3); - EXPECT_EQ(map0.size(), 2); - - auto it = map0.find(1.1); - EXPECT_TRUE(it != map0.end()); - EXPECT_EQ((*it).second.cast(), 3); -} - -TEST(Map, Object) { - TInt x(1); - TInt z(100); - TInt zz(1000); - Map dict{{x, z}, {z, zz}}; - EXPECT_EQ(dict.size(), 2); - EXPECT_TRUE(dict[x].same_as(z)); - EXPECT_TRUE(dict.count(z)); - EXPECT_TRUE(!dict.count(zz)); -} - -TEST(Map, Str) { - TInt x(1); - TInt z(100); - Map dict{{"x", z}, {"z", z}}; - EXPECT_EQ(dict.size(), 2); - EXPECT_TRUE(dict["x"].same_as(z)); -} - -TEST(Map, Mutate) { - TInt x(1); - TInt z(100); - TInt zz(1000); - Map dict{{x, z}, {z, zz}}; - - EXPECT_TRUE(dict[x].same_as(z)); - dict.Set(x, zz); - auto dict2 = dict; - EXPECT_EQ(dict2.count(z), 1); - dict.Set(zz, x); - EXPECT_EQ(dict2.count(zz), 0); - EXPECT_EQ(dict.count(zz), 1); - - auto it = dict.find(zz); - EXPECT_TRUE(it != dict.end() && (*it).second.same_as(x)); - - it = dict2.find(zz); - EXPECT_TRUE(it == dict2.end()); -} - -TEST(Map, Clear) { - TInt x(1); - TInt z(100); - Map dict{{x, z}, {z, z}}; - EXPECT_EQ(dict.size(), 2); - dict.clear(); - EXPECT_EQ(dict.size(), 0); -} - -TEST(Map, Insert) { - auto check = [](const Map& result, - std::unordered_map expected) { - EXPECT_EQ(result.size(), expected.size()); - for (const auto& kv : result) { - EXPECT_TRUE(expected.count(kv.first)); - EXPECT_EQ(expected[kv.first], kv.second); - expected.erase(kv.first); - } - }; - Map result; - std::unordered_map expected; - char key = 'a'; - int64_t val = 1; - for (int i = 0; i < 26; ++i, ++key, ++val) { - std::string s(1, key); - result.Set(s, val); - expected[s] = val; - check(result, expected); - } -} - -TEST(Map, Erase) { - auto check = [](const Map& result, - std::unordered_map expected) { - EXPECT_EQ(result.size(), expected.size()); - for (const auto& kv : result) { - EXPECT_TRUE(expected.count(kv.first)); - EXPECT_EQ(expected[kv.first], kv.second); - expected.erase(kv.first); - } - }; - Map map{{"a", 1}, {"b", 2}, {"c", 3}, {"d", 4}, {"e", 5}}; - std::unordered_map stl; - std::transform(map.begin(), map.end(), std::inserter(stl, stl.begin()), - [](auto&& p) { return std::make_pair(p.first, p.second); }); - for (char c = 'a'; c <= 'e'; ++c) { - Map result = map; - std::unordered_map expected(stl); - std::string key(1, c); - result.erase(key); - expected.erase(key); - check(result, expected); - } -} - -TEST(Map, AnyImplicitConversion) { - Map map0; - map0.Set(1, 2); - map0.Set(2, 3.1); - EXPECT_EQ(map0.size(), 2); - - // check will trigger copy - AnyView view0 = map0; - auto map1 = view0.cast>(); - EXPECT_TRUE(!map1.same_as(map0)); - EXPECT_EQ(map1[1], 2); - EXPECT_EQ(map1[2], 3.1); - EXPECT_EQ(map1.use_count(), 1); - - auto map2 = view0.cast>(); - EXPECT_TRUE(map2.same_as(map0)); - EXPECT_EQ(map2.use_count(), 2); - - auto map3 = view0.cast>(); - EXPECT_TRUE(!map3.same_as(map0)); - EXPECT_EQ(map3.use_count(), 1); - - Map map4{{"yes", 1.1}, {"no", 2.2}}; - Any any1 = map4; - - auto map5 = any1.cast>(); - EXPECT_TRUE(map5.same_as(map4)); - EXPECT_EQ(map5.use_count(), 3); - - auto map6 = any1.cast>(); - EXPECT_TRUE(map6.same_as(map4)); - EXPECT_EQ(map6.use_count(), 4); - - EXPECT_EQ(map6["yes"].cast(), 1.1); - EXPECT_EQ(map6["no"].cast(), 2.2); - - auto map7 = any1.cast>(); - EXPECT_TRUE(map7.same_as(map4)); - EXPECT_EQ(map7.use_count(), 5); - - auto map8 = any1.cast>(); - EXPECT_TRUE(!map8.same_as(map4)); - EXPECT_EQ(map8.use_count(), 1); - EXPECT_EQ(map8["yes"]->value, 1.1); - EXPECT_EQ(map8["no"]->value, 2.2); -} - -TEST(Map, AnyConvertCheck) { - Map map = {{11, 1.1}}; - EXPECT_EQ(map[11].cast(), 1.1); - - AnyView view0 = map; - auto arr1 = view0.cast>(); - EXPECT_EQ(arr1[11], 1.1); - - Any any1 = map; - using WrongMap = Map; - - EXPECT_THROW( - { - try { - [[maybe_unused]] auto arr2 = any1.cast(); - } catch (const Error& error) { - EXPECT_EQ(error.kind(), "TypeError"); - std::string what = error.what(); - EXPECT_NE( - what.find( - "Cannot convert from type `Map[K, some value is float]` to `Map`"), - std::string::npos); - throw; - } - }, - ::tvm::ffi::Error); - - using WrongMap2 = Map; - EXPECT_THROW( - { - try { - [[maybe_unused]] auto arr2 = any1.cast(); - } catch (const Error& error) { - EXPECT_EQ(error.kind(), "TypeError"); - std::string what = error.what(); - EXPECT_NE(what.find("Cannot convert from type `Map[some key is int, V]` to " - "`Map`"), - std::string::npos); - throw; - } - }, - ::tvm::ffi::Error); -} - -TEST(Map, FunctionGetItem) { - Function f = Function::FromTyped([](const MapObj* n, const Any& k) -> Any { return n->at(k); }, - "map_get_item"); - Map map{{"x", 1}, {"y", 2}}; - Any k("x"); - Any v = f(map, k); - EXPECT_EQ(v.cast(), 1); -} - -TEST(Map, Upcast) { - Map m0 = {{1, 2}, {3, 4}}; - Map m1 = m0; - EXPECT_EQ(m1[1].cast(), 2); - EXPECT_EQ(m1[3].cast(), 4); - static_assert(details::type_contains_v, Map>); - - Map> m2 = {{"x", {1}}, {"y", {2}}}; - Map> m3 = m2; -} - -template -void PrintMap(const Map& m0) { - std::cout << "{"; - for (auto it = m0.begin(); it != m0.end(); ++it) { - if (it != m0.begin()) { - std::cout << ", "; - } - std::cout << (*it).first << ": " << (*it).second; - } - std::cout << "}" << std::endl; -} - -TEST(Map, MapInsertOrder) { - // test that map preserves the insertion order - auto get_reverse_order = [](size_t size) { - std::vector reverse_order; - for (int i = static_cast(size); i != 0; --i) { - reverse_order.push_back(i - 1); - } - return reverse_order; - }; - - auto check_map = [&](Map m0, size_t size, const std::vector& order) { - auto lhs = m0.begin(); - auto rhs = order.begin(); - while (lhs != m0.end()) { - TVM_FFI_ICHECK_EQ((*lhs).first, "hello" + std::to_string(*rhs)); - TVM_FFI_ICHECK_EQ((*lhs).second, *rhs); - ++lhs; - ++rhs; - } - lhs = m0.end(); - rhs = order.begin() + size; - do { - --lhs; - --rhs; - TVM_FFI_ICHECK_EQ((*lhs).first, "hello" + std::to_string(*rhs)); - TVM_FFI_ICHECK_EQ((*lhs).second, *rhs); - } while (lhs != m0.begin()); - }; - - auto check_order = [&](std::vector order) { - Map m0; - for (size_t i = 0; i < order.size(); ++i) { - m0.Set("hello" + std::to_string(order[i]), order[i]); - check_map(m0, i + 1, order); - } - check_map(m0, order.size(), order); - // erase a few items - m0.erase("hello" + std::to_string(order[0])); - auto item0 = order[0]; - order.erase(order.begin()); - check_map(m0, order.size(), order); - // erase the middle part - if (order.size() > 1) { - m0.erase("hello" + std::to_string(order[1])); - order.erase(order.begin() + 1); - check_map(m0, order.size(), order); - } - // erase the end - m0.erase("hello" + std::to_string(order.back())); - auto item2 = order.back(); - order.erase(order.end() - 1); - check_map(m0, order.size(), order); - EXPECT_NE(m0.size(), 0); - // put back some items - order.push_back(item2); - m0.Set("hello" + std::to_string(item2), item2); - check_map(m0, order.size(), order); - order.push_back(item0); - m0.Set("hello" + std::to_string(item0), item0); - check_map(m0, order.size(), order); - }; - // test with 17 items: DenseMapObj - check_order(get_reverse_order(17)); - // test with 4 items: SmallMapObj - check_order(get_reverse_order(4)); -} - -TEST(Map, EmptyIter) { - Map m0; - EXPECT_EQ(m0.begin(), m0.end()); - // create a big map and then erase to keep a dense map empty - for (int i = 0; i < 10; ++i) { - m0.Set("hello" + std::to_string(i), i); - } - for (int i = 0; i < 10; ++i) { - m0.erase("hello" + std::to_string(i)); - } - EXPECT_EQ(m0.size(), 0); - // now m0 is dense map with all empty slots - EXPECT_EQ(m0.begin(), m0.end()); -} - -TEST(Map, DuplicatedKeysInit) { - std::vector> data = {{"a", 1}, {"a", 2}, {"a", 3}}; - Map map(data.begin(), data.end()); - EXPECT_EQ(map.size(), 1); - EXPECT_EQ(map["a"], 3); -} -} // namespace diff --git a/ffi/tests/cpp/test_object.cc b/ffi/tests/cpp/test_object.cc deleted file mode 100644 index ec5c54c4d77a..000000000000 --- a/ffi/tests/cpp/test_object.cc +++ /dev/null @@ -1,258 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -#include -#include -#include - -#include "./testing_object.h" - -namespace { - -using namespace tvm::ffi; -using namespace tvm::ffi::testing; - -TEST(Object, RefCounter) { - ObjectPtr a = make_object(11); - ObjectPtr b = a; - - EXPECT_EQ(a->value, 11); - - EXPECT_EQ(a.use_count(), 2); - ObjectPtr aa = make_object(*a); - EXPECT_EQ(aa.use_count(), 1); - EXPECT_EQ(aa->value, 11); - - b.reset(); - EXPECT_EQ(a.use_count(), 1); - EXPECT_TRUE(b == nullptr); - EXPECT_EQ(b.use_count(), 0); - - ObjectPtr c = std::move(a); - EXPECT_EQ(c.use_count(), 1); - EXPECT_TRUE(a == nullptr); - - EXPECT_EQ(c->value, 11); -} - -TEST(Object, TypeInfo) { - const TypeInfo* info = TVMFFIGetTypeInfo(TIntObj::RuntimeTypeIndex()); - EXPECT_TRUE(info != nullptr); - EXPECT_EQ(info->type_index, TIntObj::RuntimeTypeIndex()); - EXPECT_EQ(info->type_depth, 2); - EXPECT_EQ(info->type_acenstors[0]->type_index, Object::_type_index); - EXPECT_EQ(info->type_acenstors[1]->type_index, TNumberObj::_type_index); - EXPECT_GE(info->type_index, TypeIndex::kTVMFFIDynObjectBegin); -} - -TEST(Object, InstanceCheck) { - ObjectPtr a = make_object(11); - ObjectPtr b = make_object(11); - - EXPECT_TRUE(a->IsInstance()); - EXPECT_TRUE(a->IsInstance()); - EXPECT_TRUE(a->IsInstance()); - EXPECT_TRUE(!a->IsInstance()); - - EXPECT_TRUE(a->IsInstance()); - EXPECT_TRUE(b->IsInstance()); - EXPECT_TRUE(!b->IsInstance()); - EXPECT_TRUE(b->IsInstance()); -} - -TEST(ObjectRef, as) { - ObjectRef a = TInt(10); - ObjectRef b = TFloat(20); - // nullable object - ObjectRef c(nullptr); - - EXPECT_TRUE(a.as() != nullptr); - EXPECT_TRUE(a.as() == nullptr); - EXPECT_TRUE(a.as() != nullptr); - - EXPECT_TRUE(b.as() == nullptr); - EXPECT_TRUE(b.as() != nullptr); - EXPECT_TRUE(b.as() != nullptr); - - EXPECT_TRUE(c.as() == nullptr); - EXPECT_TRUE(c.as() == nullptr); - EXPECT_TRUE(c.as() == nullptr); - - EXPECT_EQ(a.as()->value, 10); - EXPECT_EQ(b.as()->value, 20); -} - -TEST(ObjectRef, UnsafeInit) { - ObjectRef a(UnsafeInit{}); - EXPECT_TRUE(a.get() == nullptr); - - TInt b(UnsafeInit{}); - EXPECT_TRUE(b.get() == nullptr); -} - -TEST(Object, CAPIAccessor) { - ObjectRef a = TInt(10); - TVMFFIObjectHandle obj = details::ObjectUnsafe::RawObjectPtrFromObjectRef(a); - int32_t type_index = TVMFFIObjectGetTypeIndex(obj); - EXPECT_EQ(type_index, TIntObj::RuntimeTypeIndex()); -} - -TEST(Object, WeakObjectPtr) { - // Test basic construction from ObjectPtr - ObjectPtr strong_ptr = make_object(42); - WeakObjectPtr weak_ptr(strong_ptr); - - EXPECT_EQ(strong_ptr.use_count(), 1); - EXPECT_FALSE(weak_ptr.expired()); - EXPECT_EQ(weak_ptr.use_count(), 1); - - // Test lock() when object is still alive - ObjectPtr locked_ptr = weak_ptr.lock(); - EXPECT_TRUE(locked_ptr != nullptr); - EXPECT_EQ(locked_ptr->value, 42); - EXPECT_EQ(strong_ptr.use_count(), 2); - EXPECT_EQ(weak_ptr.use_count(), 2); - - // Test lock() when object is expired - strong_ptr.reset(); - locked_ptr.reset(); - EXPECT_TRUE(weak_ptr.expired()); - EXPECT_EQ(weak_ptr.use_count(), 0); - - ObjectPtr expired_lock = weak_ptr.lock(); - EXPECT_TRUE(expired_lock == nullptr); -} - -TEST(Object, WeakObjectPtrAssignment) { - // Test copy construction - ObjectPtr new_strong = make_object(100); - WeakObjectPtr weak1(new_strong); - WeakObjectPtr weak2(weak1); - - EXPECT_EQ(new_strong.use_count(), 1); - EXPECT_FALSE(weak1.expired()); - EXPECT_FALSE(weak2.expired()); - EXPECT_EQ(weak1.use_count(), 1); - EXPECT_EQ(weak2.use_count(), 1); - - // Test move construction - WeakObjectPtr weak3(std::move(weak1)); - EXPECT_TRUE(weak1.expired()); // weak1 should be moved from - EXPECT_FALSE(weak3.expired()); - EXPECT_EQ(weak3.use_count(), 1); - - // Test assignment - WeakObjectPtr weak4; - weak4 = weak2; - EXPECT_FALSE(weak2.expired()); - EXPECT_FALSE(weak4.expired()); - EXPECT_EQ(weak2.use_count(), 1); - EXPECT_EQ(weak4.use_count(), 1); - - // Test move assignment - WeakObjectPtr weak5; - weak5 = std::move(weak2); - EXPECT_TRUE(weak2.expired()); // weak2 should be moved from - EXPECT_FALSE(weak5.expired()); - EXPECT_EQ(weak5.use_count(), 1); - - // Test reset() - weak3.reset(); - EXPECT_TRUE(weak3.expired()); - EXPECT_EQ(weak3.use_count(), 0); - - // Test swap() - ObjectPtr strong_a = make_object(200); - ObjectPtr strong_b = make_object(300); - WeakObjectPtr weak_a(strong_a); - WeakObjectPtr weak_b(strong_b); - - weak_a.swap(weak_b); - EXPECT_EQ(weak_a.lock()->value, 300); - EXPECT_EQ(weak_b.lock()->value, 200); - - // Test construction from nullptr - WeakObjectPtr null_weak(nullptr); - EXPECT_TRUE(null_weak.expired()); - EXPECT_EQ(null_weak.use_count(), 0); - EXPECT_TRUE(null_weak.lock() == nullptr); - - // Test inheritance compatibility - ObjectPtr number_ptr = make_object(500); - WeakObjectPtr number_weak(number_ptr); - - EXPECT_FALSE(number_weak.expired()); - EXPECT_EQ(number_weak.use_count(), 1); - - // Test that weak references don't prevent object deletion - ObjectPtr temp_strong = make_object(999); - WeakObjectPtr temp_weak(temp_strong); - - EXPECT_FALSE(temp_weak.expired()); - temp_strong.reset(); - EXPECT_TRUE(temp_weak.expired()); - EXPECT_TRUE(temp_weak.lock() == nullptr); - - // Test multiple weak references - ObjectPtr multi_strong = make_object(777); - WeakObjectPtr multi_weak1(multi_strong); - WeakObjectPtr multi_weak2(multi_strong); - WeakObjectPtr multi_weak3(multi_strong); - - EXPECT_EQ(multi_strong.use_count(), 1); - EXPECT_FALSE(multi_weak1.expired()); - EXPECT_FALSE(multi_weak2.expired()); - EXPECT_FALSE(multi_weak3.expired()); - - // All weak references should be able to lock - ObjectPtr lock1 = multi_weak1.lock(); - ObjectPtr lock2 = multi_weak2.lock(); - ObjectPtr lock3 = multi_weak3.lock(); - - EXPECT_EQ(multi_strong.use_count(), 4); - EXPECT_EQ(lock1->value, 777); - EXPECT_EQ(lock2->value, 777); - EXPECT_EQ(lock3->value, 777); -} - -TEST(Object, OpaqueObject) { - thread_local int deleter_trigger_counter = 0; - struct DummyOpaqueObject { - int value; - DummyOpaqueObject(int value) : value(value) {} - - static void Deleter(void* handle) { - deleter_trigger_counter++; - delete static_cast(handle); - } - }; - TVMFFIObjectHandle handle = nullptr; - TVM_FFI_CHECK_SAFE_CALL(TVMFFIObjectCreateOpaque(new DummyOpaqueObject(10), kTVMFFIOpaquePyObject, - DummyOpaqueObject::Deleter, &handle)); - ObjectPtr a = - details::ObjectUnsafe::ObjectPtrFromOwned(static_cast(handle)); - EXPECT_EQ(a->type_index(), kTVMFFIOpaquePyObject); - EXPECT_EQ(static_cast(TVMFFIOpaqueObjectGetCellPtr(a.get())->handle)->value, - 10); - EXPECT_EQ(a.use_count(), 1); - EXPECT_EQ(deleter_trigger_counter, 0); - a.reset(); - EXPECT_EQ(deleter_trigger_counter, 1); -} - -} // namespace diff --git a/ffi/tests/cpp/test_optional.cc b/ffi/tests/cpp/test_optional.cc deleted file mode 100644 index eb114df8a3fa..000000000000 --- a/ffi/tests/cpp/test_optional.cc +++ /dev/null @@ -1,202 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -#include -#include -#include -#include -#include - -#include "./testing_object.h" - -namespace { - -using namespace tvm::ffi; -using namespace tvm::ffi::testing; - -TEST(Optional, TInt) { - Optional x; - Optional y = TInt(11); - static_assert(sizeof(Optional) == sizeof(ObjectRef)); - - EXPECT_TRUE(!x.has_value()); - EXPECT_EQ(x.value_or(TInt(12))->value, 12); - - EXPECT_TRUE(y.has_value()); - EXPECT_EQ(y.value_or(TInt(12))->value, 11); - - Any z_any = std::move(y); - EXPECT_TRUE(z_any != nullptr); - EXPECT_EQ((z_any.cast())->value, 11); - EXPECT_TRUE(!y.has_value()); - - // move from any to optional - auto y2 = std::move(z_any).cast>(); - EXPECT_EQ(y2.use_count(), 1); - EXPECT_TRUE(y2.has_value()); - EXPECT_EQ(y2.value_or(TInt(12))->value, 11); -} - -TEST(Optional, double) { - Optional x; - Optional y = 11.0; - static_assert(sizeof(Optional) > sizeof(ObjectRef)); - - EXPECT_TRUE(!x.has_value()); - EXPECT_EQ(x.value_or(12), 12); - EXPECT_TRUE(x != 12); - - EXPECT_TRUE(y.has_value()); - EXPECT_EQ(y.value_or(12), 11); - EXPECT_TRUE(y == 11); - EXPECT_TRUE(y != 12); -} - -TEST(Optional, AnyConvert_int) { - Optional opt_v0 = 1; - EXPECT_EQ(opt_v0.value(), 1); - EXPECT_TRUE(opt_v0.has_value()); - - AnyView view0 = opt_v0; - EXPECT_EQ(view0.cast(), 1); - - Any any1; - auto opt_v1 = std::move(any1).cast>(); - EXPECT_TRUE(!opt_v1.has_value()); - Optional opt_v2 = 11; - Any any2 = std::move(opt_v2); - EXPECT_EQ(any2.cast(), 11); -} - -TEST(Optional, AnyConvert_Array) { - AnyView view0; - Array> arr_nested = {{}, {TInt(1), TFloat(2)}}; - view0 = arr_nested; - - auto opt_arr = view0.cast>>>(); - EXPECT_EQ(arr_nested.use_count(), 2); - - auto arr1 = view0.cast>>>(); - EXPECT_EQ(arr_nested.use_count(), 3); - EXPECT_EQ(arr1.value()[1][1].as()->value, 2); - - Any any1; - auto arr2 = any1.cast>>>(); - EXPECT_TRUE(!arr2.has_value()); - - EXPECT_THROW( - { - try { - [[maybe_unused]] auto arr2 = view0.cast>>>(); - } catch (const Error& error) { - EXPECT_EQ(error.kind(), "TypeError"); - std::string what = error.what(); - std::cout << what << std::endl; - EXPECT_NE(what.find("to `Optional>>`"), std::string::npos); - throw; - } - }, - ::tvm::ffi::Error); -} - -TEST(Optional, OptionalOfOptional) { - // testcase of optional - Optional> opt_opt_int; - EXPECT_TRUE(!opt_opt_int.has_value()); - - Optional> opt_opt_int2 = Optional(std::nullopt); - EXPECT_TRUE(opt_opt_int2.has_value()); - EXPECT_TRUE(!opt_opt_int2.value().has_value()); - - // Optional> - Optional> opt_opt_tint; - EXPECT_TRUE(!opt_opt_tint.has_value()); - - Optional> opt_opt_tint2 = Optional(std::nullopt); - EXPECT_TRUE(opt_opt_tint2.has_value()); - EXPECT_TRUE(!opt_opt_tint2.value().has_value()); - opt_opt_tint2 = std::nullopt; - EXPECT_TRUE(!opt_opt_tint2.has_value()); - - Optional> opt_opt_tint3 = Optional(TInt(42)); - EXPECT_TRUE(opt_opt_tint3.has_value()); - EXPECT_TRUE(opt_opt_tint3.value().has_value()); - EXPECT_EQ(opt_opt_tint3.value().value()->value, 42); -} - -TEST(Optional, ValueMove) { - Optional y = TInt(11); - TInt x = std::move(y).value(); - EXPECT_TRUE(!y.has_value()); - EXPECT_EQ(x->value, 11); - - Optional opt_tint = TInt(21); - EXPECT_TRUE(opt_tint.has_value()); - EXPECT_EQ((*opt_tint)->value, 21); - - TInt moved_tint = *std::move(opt_tint); - EXPECT_EQ(moved_tint->value, 21); - EXPECT_TRUE(!opt_tint.has_value()); -} - -TEST(Optional, OptionalInArray) { - // This pattern plus iteration may cause memory leak - // this is because arr[0] returns a temporary object - // and further call arr[0].value() may return a reference to - // the temporary object - Array>> arr = {Array({TInt(0), TInt(1)})}; - int counter = 0; - - for (const auto& x : arr[0].value()) { - EXPECT_EQ(x->value, counter++); - } - - Any any = arr; - auto opt_arr = any.cast>>>(); - EXPECT_EQ(opt_arr[0].value()[0]->value, 0); -} - -TEST(Optional, String) { - Optional opt_str; - EXPECT_TRUE(!opt_str.has_value()); - EXPECT_EQ(opt_str.value_or("default"), "default"); - EXPECT_TRUE(opt_str != "default"); - EXPECT_TRUE(opt_str != String("default")); - EXPECT_TRUE(opt_str == std::nullopt); - - opt_str = "hello"; - EXPECT_TRUE(opt_str.has_value()); - EXPECT_EQ(opt_str.value(), "hello"); - EXPECT_TRUE(opt_str == "hello"); - EXPECT_TRUE(opt_str == String("hello")); - EXPECT_TRUE(opt_str != std::nullopt); - static_assert(sizeof(Optional) == sizeof(String)); -} - -TEST(Optional, Bytes) { - Optional opt_bytes; - EXPECT_TRUE(!opt_bytes.has_value()); - EXPECT_EQ(opt_bytes.value_or(std::string("default")), "default"); - - opt_bytes = std::string("hello"); - EXPECT_TRUE(opt_bytes.has_value()); - EXPECT_EQ(opt_bytes.value().operator std::string(), "hello"); - EXPECT_TRUE(opt_bytes != std::nullopt); - static_assert(sizeof(Optional) == sizeof(Bytes)); -} -} // namespace diff --git a/ffi/tests/cpp/test_reflection.cc b/ffi/tests/cpp/test_reflection.cc deleted file mode 100644 index c9aa500aeb41..000000000000 --- a/ffi/tests/cpp/test_reflection.cc +++ /dev/null @@ -1,269 +0,0 @@ - -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -#include -#include -#include -#include -#include -#include -#include -#include - -#include "./testing_object.h" - -namespace { - -using namespace tvm::ffi; -using namespace tvm::ffi::testing; - -struct TestObjA : public Object { - int64_t x; - int64_t y; - - static constexpr bool _type_mutable = true; - TVM_FFI_DECLARE_OBJECT_INFO("test.TestObjA", TestObjA, Object); -}; - -struct TestObjADerived : public TestObjA { - int64_t z; - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("test.TestObjADerived", TestObjADerived, TestObjA); -}; - -TVM_FFI_STATIC_INIT_BLOCK() { - namespace refl = tvm::ffi::reflection; - - TIntObj::RegisterReflection(); - TFloatObj::RegisterReflection(); - TPrimExprObj::RegisterReflection(); - TVarObj::RegisterReflection(); - TFuncObj::RegisterReflection(); - TCustomFuncObj::RegisterReflection(); - - refl::ObjectDef().def_ro("x", &TestObjA::x).def_rw("y", &TestObjA::y); - refl::ObjectDef().def_ro("z", &TestObjADerived::z); -} - -TEST(Reflection, GetFieldByteOffset) { - EXPECT_EQ(reflection::GetFieldByteOffsetToObject(&TestObjA::x), sizeof(TVMFFIObject)); - EXPECT_EQ(reflection::GetFieldByteOffsetToObject(&TestObjA::y), 8 + sizeof(TVMFFIObject)); - EXPECT_EQ(reflection::GetFieldByteOffsetToObject(&TIntObj::value), sizeof(TVMFFIObject)); -} - -TEST(Reflection, FieldGetter) { - ObjectRef a = TInt(10); - reflection::FieldGetter getter("test.Int", "value"); - EXPECT_EQ(getter(a).cast(), 10); - - ObjectRef b = TFloat(10.0); - reflection::FieldGetter getter_float("test.Float", "value"); - EXPECT_EQ(getter_float(b).cast(), 10.0); -} - -TEST(Reflection, FieldSetter) { - ObjectRef a = TFloat(10.0); - reflection::FieldSetter setter("test.Float", "value"); - setter(a, 20.0); - EXPECT_EQ(a.as()->value, 20.0); -} - -TEST(Reflection, FieldInfo) { - const TVMFFIFieldInfo* info_int = reflection::GetFieldInfo("test.Int", "value"); - EXPECT_FALSE(info_int->flags & kTVMFFIFieldFlagBitMaskHasDefault); - EXPECT_FALSE(info_int->flags & kTVMFFIFieldFlagBitMaskWritable); - EXPECT_EQ(Bytes(info_int->doc).operator std::string(), ""); - - const TVMFFIFieldInfo* info_float = reflection::GetFieldInfo("test.Float", "value"); - EXPECT_EQ(info_float->default_value.v_float64, 10.0); - EXPECT_TRUE(info_float->flags & kTVMFFIFieldFlagBitMaskHasDefault); - EXPECT_FALSE(info_float->flags & kTVMFFIFieldFlagBitMaskWritable); - EXPECT_EQ(Bytes(info_float->doc).operator std::string(), "float value field"); - - const TVMFFIFieldInfo* info_prim_expr_dtype = reflection::GetFieldInfo("test.PrimExpr", "dtype"); - AnyView default_value = AnyView::CopyFromTVMFFIAny(info_prim_expr_dtype->default_value); - EXPECT_EQ(default_value.cast(), "float"); - EXPECT_TRUE(info_prim_expr_dtype->flags & kTVMFFIFieldFlagBitMaskHasDefault); - EXPECT_TRUE(info_prim_expr_dtype->flags & kTVMFFIFieldFlagBitMaskWritable); - EXPECT_EQ(Bytes(info_prim_expr_dtype->doc).operator std::string(), "dtype field"); -} - -TEST(Reflection, MethodInfo) { - const TVMFFIMethodInfo* info_int_static_add = reflection::GetMethodInfo("test.Int", "static_add"); - EXPECT_TRUE(info_int_static_add->flags & kTVMFFIFieldFlagBitMaskIsStaticMethod); - EXPECT_EQ(Bytes(info_int_static_add->doc).operator std::string(), "static add method"); - - const TVMFFIMethodInfo* info_float_add = reflection::GetMethodInfo("test.Float", "add"); - EXPECT_FALSE(info_float_add->flags & kTVMFFIFieldFlagBitMaskIsStaticMethod); - EXPECT_EQ(Bytes(info_float_add->doc).operator std::string(), "add method"); - - const TVMFFIMethodInfo* info_float_sub = reflection::GetMethodInfo("test.Float", "sub"); - EXPECT_FALSE(info_float_sub->flags & kTVMFFIFieldFlagBitMaskIsStaticMethod); - EXPECT_EQ(Bytes(info_float_sub->doc).operator std::string(), ""); -} - -TEST(Reflection, CallMethod) { - Function static_int_add = reflection::GetMethod("test.Int", "static_add"); - EXPECT_EQ(static_int_add(TInt(1), TInt(2)).cast()->value, 3); - - Function float_add = reflection::GetMethod("test.Float", "add"); - EXPECT_EQ(float_add(TFloat(1), 2.0).cast(), 3.0); - - Function float_sub = reflection::GetMethod("test.Float", "sub"); - EXPECT_EQ(float_sub(TFloat(1), 2.0).cast(), -1.0); - - Function prim_expr_sub = reflection::GetMethod("test.PrimExpr", "sub"); - EXPECT_EQ(prim_expr_sub(TPrimExpr("float", 1), 2.0).cast(), -1.0); -} - -TEST(Reflection, ForEachFieldInfo) { - const TypeInfo* info = TVMFFIGetTypeInfo(TestObjADerived::RuntimeTypeIndex()); - Map field_name_to_offset; - reflection::ForEachFieldInfo(info, [&](const TVMFFIFieldInfo* field_info) { - field_name_to_offset.Set(String(field_info->name), field_info->offset); - }); - EXPECT_EQ(field_name_to_offset["x"], sizeof(TVMFFIObject)); - EXPECT_EQ(field_name_to_offset["y"], 8 + sizeof(TVMFFIObject)); - EXPECT_EQ(field_name_to_offset["z"], 16 + sizeof(TVMFFIObject)); -} - -TEST(Reflection, TypeAttrColumn) { - reflection::TypeAttrColumn size_attr("test.size"); - EXPECT_EQ(size_attr[TIntObj::_type_index].cast(), sizeof(TIntObj)); -} - -TVM_FFI_STATIC_INIT_BLOCK() { - namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def_method("testing.Int_GetValue", &TIntObj::GetValue); -} - -TEST(Reflection, FuncRegister) { - Function fget_value = Function::GetGlobalRequired("testing.Int_GetValue"); - TInt a(12); - EXPECT_EQ(fget_value(a).cast(), 12); -} - -TEST(Reflection, ObjectCreator) { - namespace refl = tvm::ffi::reflection; - refl::ObjectCreator creator("test.Int"); - EXPECT_EQ(creator(Map({{"value", 1}})).cast()->value, 1); -} - -TEST(Reflection, AccessPath) { - namespace refl = tvm::ffi::reflection; - - // Test basic path construction and ToSteps() - refl::AccessPath path = refl::AccessPath::Root()->Attr("body")->ArrayItem(1); - auto steps = path->ToSteps(); - EXPECT_EQ(steps.size(), 2); - EXPECT_EQ(steps[0]->kind, refl::AccessKind::kAttr); - EXPECT_EQ(steps[1]->kind, refl::AccessKind::kArrayItem); - EXPECT_EQ(steps[0]->key.cast(), "body"); - EXPECT_EQ(steps[1]->key.cast(), 1); - - // Test PathEqual with identical paths - refl::AccessPath path2 = refl::AccessPath::Root()->Attr("body")->ArrayItem(1); - EXPECT_TRUE(path->PathEqual(path2)); - EXPECT_TRUE(path->IsPrefixOf(path2)); - - // Test PathEqual with different paths - refl::AccessPath path3 = refl::AccessPath::Root()->Attr("body")->ArrayItem(2); - EXPECT_FALSE(path->PathEqual(path3)); - EXPECT_FALSE(path->IsPrefixOf(path3)); - - // Test prefix relationship - path4 extends path, so path should be prefix of path4 - refl::AccessPath path4 = refl::AccessPath::Root()->Attr("body")->ArrayItem(1)->Attr("body"); - EXPECT_FALSE(path->PathEqual(path4)); // Not equal (different lengths) - EXPECT_TRUE(path->IsPrefixOf(path4)); // But path is a prefix of path4 - - // Test completely different paths - refl::AccessPath path5 = refl::AccessPath::Root()->ArrayItem(0)->ArrayItem(1)->Attr("body"); - EXPECT_FALSE(path->PathEqual(path5)); - EXPECT_FALSE(path->IsPrefixOf(path5)); - - // Test Root path - refl::AccessPath root = refl::AccessPath::Root(); - auto root_steps = root->ToSteps(); - EXPECT_EQ(root_steps.size(), 0); - EXPECT_EQ(root->depth, 0); - EXPECT_TRUE(root->IsPrefixOf(path)); - EXPECT_TRUE(root->IsPrefixOf(root)); - EXPECT_TRUE(root->PathEqual(refl::AccessPath::Root())); - - // Test depth calculations - EXPECT_EQ(path->depth, 2); - EXPECT_EQ(path4->depth, 3); - EXPECT_EQ(root->depth, 0); - - // Test MapItem access - refl::AccessPath map_path = refl::AccessPath::Root()->Attr("data")->MapItem("key1"); - auto map_steps = map_path->ToSteps(); - EXPECT_EQ(map_steps.size(), 2); - EXPECT_EQ(map_steps[0]->kind, refl::AccessKind::kAttr); - EXPECT_EQ(map_steps[1]->kind, refl::AccessKind::kMapItem); - EXPECT_EQ(map_steps[0]->key.cast(), "data"); - EXPECT_EQ(map_steps[1]->key.cast(), "key1"); - - // Test MapItemMissing access - refl::AccessPath map_missing_path = refl::AccessPath::Root()->MapItemMissing(42); - auto map_missing_steps = map_missing_path->ToSteps(); - EXPECT_EQ(map_missing_steps.size(), 1); - EXPECT_EQ(map_missing_steps[0]->kind, refl::AccessKind::kMapItemMissing); - EXPECT_EQ(map_missing_steps[0]->key.cast(), 42); - - // Test ArrayItemMissing access - refl::AccessPath array_missing_path = refl::AccessPath::Root()->ArrayItemMissing(5); - auto array_missing_steps = array_missing_path->ToSteps(); - EXPECT_EQ(array_missing_steps.size(), 1); - EXPECT_EQ(array_missing_steps[0]->kind, refl::AccessKind::kArrayItemMissing); - EXPECT_EQ(array_missing_steps[0]->key.cast(), 5); - - // Test FromSteps static method - round trip conversion - auto original_steps = path->ToSteps(); - refl::AccessPath reconstructed = refl::AccessPath::FromSteps(original_steps); - EXPECT_TRUE(path->PathEqual(reconstructed)); - EXPECT_EQ(path->depth, reconstructed->depth); - - // Test complex prefix relationships - refl::AccessPath short_path = refl::AccessPath::Root()->Attr("x"); - refl::AccessPath medium_path = refl::AccessPath::Root()->Attr("x")->ArrayItem(0); - refl::AccessPath long_path = refl::AccessPath::Root()->Attr("x")->ArrayItem(0)->MapItem("z"); - - EXPECT_TRUE(short_path->IsPrefixOf(medium_path)); - EXPECT_TRUE(short_path->IsPrefixOf(long_path)); - EXPECT_TRUE(medium_path->IsPrefixOf(long_path)); - EXPECT_FALSE(medium_path->IsPrefixOf(short_path)); - EXPECT_FALSE(long_path->IsPrefixOf(medium_path)); - EXPECT_FALSE(long_path->IsPrefixOf(short_path)); - - // Test non-prefix relationships - refl::AccessPath branch1 = refl::AccessPath::Root()->Attr("x")->ArrayItem(0); - refl::AccessPath branch2 = refl::AccessPath::Root()->Attr("x")->ArrayItem(1); - EXPECT_FALSE(branch1->IsPrefixOf(branch2)); - EXPECT_FALSE(branch2->IsPrefixOf(branch1)); - EXPECT_FALSE(branch1->PathEqual(branch2)); - - // Test GetParent functionality - auto parent = path4->GetParent(); - EXPECT_TRUE(parent.has_value()); - EXPECT_TRUE(parent.value()->PathEqual(path)); - - auto root_parent = root->GetParent(); - EXPECT_FALSE(root_parent.has_value()); -} -} // namespace diff --git a/ffi/tests/cpp/test_rvalue_ref.cc b/ffi/tests/cpp/test_rvalue_ref.cc deleted file mode 100644 index dd211a34dc60..000000000000 --- a/ffi/tests/cpp/test_rvalue_ref.cc +++ /dev/null @@ -1,97 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -#include -#include -#include -#include -#include - -#include "./testing_object.h" - -namespace { - -using namespace tvm::ffi; -using namespace tvm::ffi::testing; - -TEST(RValueRef, Basic) { - auto append = - Function::FromTyped([](RValueRef> ref, int val, bool is_unique) -> Array { - Array arr = *std::move(ref); - EXPECT_EQ(arr.unique(), is_unique); - arr.push_back(val); - return arr; - }); - auto a = append(RValueRef(Array({1, 2})), 3, true).cast>(); - EXPECT_EQ(a.size(), 3); - a = append(RValueRef(std::move(a)), 4, true).cast>(); - EXPECT_EQ(a.size(), 4); - // pass in lvalue instead, the append still will succeed but array will not be unique - a = append(a, 5, false).cast>(); - EXPECT_EQ(a.size(), 5); -} - -TEST(RValueRef, ParamChecking) { - // try decution - Function fadd1 = Function::FromTyped([](TInt a) -> int64_t { return a->value + 1; }); - - // convert that triggers error - EXPECT_THROW( - { - try { - fadd1(RValueRef(TInt(1))); - } catch (const Error& error) { - EXPECT_EQ(error.kind(), "TypeError"); - EXPECT_EQ(error.message(), - "Mismatched type on argument #0 when calling: `(0: test.Int) -> int`. " - "Expected `test.Int` but got `ObjectRValueRef`"); - throw; - } - }, - ::tvm::ffi::Error); - - Function fadd2 = Function::FromTyped([](RValueRef> a) -> int { - Array arr = *std::move(a); - return arr[0] + 1; - }); - - // convert that triggers error - EXPECT_THROW( - { - try { - fadd2(RValueRef(Array({1, 2.2}))); - } catch (const Error& error) { - EXPECT_EQ(error.kind(), "TypeError"); - EXPECT_EQ( - error.message(), - "Mismatched type on argument #0 when calling: `(0: RValueRef>) -> int`. " - "Expected `RValueRef>` but got `RValueRef`"); - throw; - } - }, - ::tvm::ffi::Error); - // triggered a rvalue based conversion - Function func3 = Function::FromTyped([](RValueRef a) -> String { - TPrimExpr expr = *std::move(a); - return expr->dtype; - }); - // EXPECT_EQ(func3(RValueRef(String("int32"))).cast(), "int32"); - // triggered a lvalue based conversion - // EXPECT_EQ(func3(String("int32")).cast(), "int32"); -} -} // namespace diff --git a/ffi/tests/cpp/test_shape.cc b/ffi/tests/cpp/test_shape.cc deleted file mode 100644 index 0ccba7820ad7..000000000000 --- a/ffi/tests/cpp/test_shape.cc +++ /dev/null @@ -1,72 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -#include -#include -#include - -namespace { - -using namespace tvm::ffi; - -TEST(Shape, Basic) { - Shape shape = Shape({1, 2, 3}); - EXPECT_EQ(shape.size(), 3); - EXPECT_EQ(shape[0], 1); - EXPECT_EQ(shape[1], 2); - EXPECT_EQ(shape[2], 3); - - Shape shape2 = Shape(Array({4, 5, 6, 7})); - EXPECT_EQ(shape2.size(), 4); - EXPECT_EQ(shape2[0], 4); - EXPECT_EQ(shape2[1], 5); - EXPECT_EQ(shape2[2], 6); - EXPECT_EQ(shape2[3], 7); - - std::vector vec = {8, 9, 10}; - Shape shape3 = Shape(std::move(vec)); - EXPECT_EQ(shape3.size(), 3); - EXPECT_EQ(shape3[0], 8); - EXPECT_EQ(shape3[1], 9); - EXPECT_EQ(shape3[2], 10); - EXPECT_EQ(shape3.Product(), 8 * 9 * 10); - - Shape shape4 = Shape(); - EXPECT_EQ(shape4.size(), 0); - EXPECT_EQ(shape4.Product(), 1); -} - -TEST(Shape, AnyConvert) { - Shape shape0 = Shape({1, 2, 3}); - Any any0 = shape0; - - auto shape1 = any0.cast(); - EXPECT_EQ(shape1.size(), 3); - EXPECT_EQ(shape1[0], 1); - EXPECT_EQ(shape1[1], 2); - EXPECT_EQ(shape1[2], 3); - - Array arr({1, 2}); - AnyView any_view0 = arr; - auto shape2 = any_view0.cast(); - EXPECT_EQ(shape2.size(), 2); - EXPECT_EQ(shape2[0], 1); - EXPECT_EQ(shape2[1], 2); -} - -} // namespace diff --git a/ffi/tests/cpp/test_string.cc b/ffi/tests/cpp/test_string.cc deleted file mode 100644 index 8522aa93a3b9..000000000000 --- a/ffi/tests/cpp/test_string.cc +++ /dev/null @@ -1,430 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -#include -#include -#include - -namespace { - -using namespace tvm::ffi; - -TEST(String, MoveFromStd) { - using namespace std; - string source = "this is a string"; - string expect = source; - String s(std::move(source)); - string copy = (string)s; - EXPECT_EQ(copy, expect); - EXPECT_EQ(source.size(), 0); -} - -TEST(String, CopyFromStd) { - using namespace std; - string source = "this is a string"; - string expect = source; - String s{source}; - string copy = (string)s; - EXPECT_EQ(copy, expect); - EXPECT_EQ(source.size(), expect.size()); -} - -TEST(String, Assignment) { - using namespace std; - String s{string{"hello"}}; - s = string{"world"}; - EXPECT_EQ(s == "world", true); - string s2{"world2"}; - s = std::move(s2); - EXPECT_EQ(s == "world2", true); - - Any r; - r = String("hello"); - EXPECT_EQ(r != nullptr, true); -} - -TEST(String, empty) { - using namespace std; - String s{"hello"}; - EXPECT_EQ(s.empty(), false); - s = std::string(""); - EXPECT_EQ(s.empty(), true); -} - -TEST(String, Comparisons) { - using namespace std; - string source = "a string"; - string mismatch = "a string but longer"; - String s{"a string"}; - String m{mismatch}; - - EXPECT_EQ("a str" >= s, false); - EXPECT_EQ(s == source, true); - EXPECT_EQ(s == mismatch, false); - EXPECT_EQ(s == source.data(), true); - EXPECT_EQ(s == mismatch.data(), false); - - EXPECT_EQ(s < m, source < mismatch); - EXPECT_EQ(s > m, source > mismatch); - EXPECT_EQ(s <= m, source <= mismatch); - EXPECT_EQ(s >= m, source >= mismatch); - EXPECT_EQ(s == m, source == mismatch); - EXPECT_EQ(s != m, source != mismatch); - - EXPECT_EQ(m < s, mismatch < source); - EXPECT_EQ(m > s, mismatch > source); - EXPECT_EQ(m <= s, mismatch <= source); - EXPECT_EQ(m >= s, mismatch >= source); - EXPECT_EQ(m == s, mismatch == source); - EXPECT_EQ(m != s, mismatch != source); -} - -TEST(String, Compare) { - // string compare const char* - String s{"hello"}; - EXPECT_EQ(s.compare("hello"), 0); - EXPECT_EQ(s.compare(String("hello")), 0); - - EXPECT_EQ(s.compare("hallo"), 1); - EXPECT_EQ(s.compare(String("hallo")), 1); - EXPECT_EQ(s.compare("hfllo"), -1); - EXPECT_EQ(s.compare(String("hfllo")), -1); - // s is longer - EXPECT_EQ(s.compare("hell"), 1); - EXPECT_EQ(s.compare(String("hell")), 1); - // s is shorter - EXPECT_EQ(s.compare("hello world"), -1); - EXPECT_EQ(s.compare(String("helloworld")), -1); -} - -// Check '\0' handling -TEST(String, null_byte_handling) { - using namespace std; - // Ensure string still compares equal if it contains '\0'. - string v1 = "hello world"; - size_t v1_size = v1.size(); - v1[5] = '\0'; - EXPECT_EQ(v1[5], '\0'); - EXPECT_EQ(v1.size(), v1_size); - String str_v1{v1}; - EXPECT_EQ(str_v1.compare(v1), 0); - EXPECT_EQ(str_v1.size(), v1_size); - - // Ensure bytes after '\0' are taken into account for mismatches. - string v2 = "aaa one"; - string v3 = "aaa two"; - v2[3] = '\0'; - v3[3] = '\0'; - String str_v2{v2}; - String str_v3{v3}; - EXPECT_EQ(str_v2.compare(str_v3), -1); - EXPECT_EQ(str_v2.size(), 7); - // strcmp won't be able to detect the mismatch - EXPECT_EQ(strcmp(v2.data(), v3.data()), 0); - // string::compare can handle \0 since it knows size - EXPECT_LT(v2.compare(v3), 0); - - // If there is mismatch before '\0', should still handle it. - string v4 = "acc one"; - string v5 = "abb two"; - v4[3] = '\0'; - v5[3] = '\0'; - String str_v4{v4}; - String str_v5{v5}; - EXPECT_GT(str_v4.compare(str_v5), 0); - EXPECT_EQ(str_v4.size(), 7); - // strcmp is able to detect the mismatch - EXPECT_GT(strcmp(v4.data(), v5.data()), 0); - // string::compare can handle \0 since it knows size - EXPECT_GT(v4.compare(v5), 0); -} - -TEST(String, compare_same_memory_region_different_size) { - using namespace std; - string source = "a string"; - String str_source{source}; - char* memory = const_cast(str_source.data()); - EXPECT_EQ(str_source.compare(memory), 0); - // This changes the string size - memory[2] = '\0'; - // memory is logically shorter now - EXPECT_GT(str_source.compare(memory), 0); -} - -TEST(String, compare) { - using namespace std; - constexpr auto mismatch1_cstr = "a string but longer"; - string source = "a string"; - string mismatch1 = mismatch1_cstr; - string mismatch2 = "a strin"; - string mismatch3 = "a b"; - string mismatch4 = "a t"; - String str_source{source}; - String str_mismatch1{mismatch1_cstr}; - String str_mismatch2{mismatch2}; - String str_mismatch3{mismatch3}; - String str_mismatch4{mismatch4}; - - // compare with string - EXPECT_EQ(str_source.compare(source), 0); - EXPECT_TRUE(str_source == source); - EXPECT_TRUE(source == str_source); - EXPECT_TRUE(str_source <= source); - EXPECT_TRUE(source <= str_source); - EXPECT_TRUE(str_source >= source); - EXPECT_TRUE(source >= str_source); - EXPECT_LT(str_source.compare(mismatch1), 0); - EXPECT_TRUE(str_source < mismatch1); - EXPECT_TRUE(mismatch1 != str_source); - EXPECT_GT(str_source.compare(mismatch2), 0); - EXPECT_TRUE(str_source > mismatch2); - EXPECT_TRUE(mismatch2 < str_source); - EXPECT_GT(str_source.compare(mismatch3), 0); - EXPECT_TRUE(str_source > mismatch3); - EXPECT_LT(str_source.compare(mismatch4), 0); - EXPECT_TRUE(str_source < mismatch4); - EXPECT_TRUE(mismatch4 > str_source); - - // compare with char* - EXPECT_EQ(str_source.compare(source.data()), 0); - EXPECT_TRUE(str_source == source.data()); - EXPECT_TRUE(source.data() == str_source); - EXPECT_TRUE(str_source <= source.data()); - EXPECT_TRUE(source <= str_source.data()); - EXPECT_TRUE(str_source >= source.data()); - EXPECT_TRUE(source >= str_source.data()); - EXPECT_LT(str_source.compare(mismatch1.data()), 0); - EXPECT_TRUE(str_source < mismatch1.data()); - EXPECT_TRUE(str_source != mismatch1.data()); - EXPECT_TRUE(mismatch1.data() != str_source); - EXPECT_GT(str_source.compare(mismatch2.data()), 0); - EXPECT_TRUE(str_source > mismatch2.data()); - EXPECT_TRUE(mismatch2.data() < str_source); - EXPECT_GT(str_source.compare(mismatch3.data()), 0); - EXPECT_TRUE(str_source > mismatch3.data()); - EXPECT_LT(str_source.compare(mismatch4.data()), 0); - EXPECT_TRUE(str_source < mismatch4.data()); - EXPECT_TRUE(mismatch4.data() > str_source); - - // compare with String - EXPECT_LT(str_source.compare(str_mismatch1), 0); - EXPECT_TRUE(str_source < str_mismatch1); - EXPECT_GT(str_source.compare(str_mismatch2), 0); - EXPECT_TRUE(str_source > str_mismatch2); - EXPECT_GT(str_source.compare(str_mismatch3), 0); - EXPECT_TRUE(str_source > str_mismatch3); - EXPECT_LT(str_source.compare(str_mismatch4), 0); - EXPECT_TRUE(str_source < str_mismatch4); -} - -TEST(String, c_str) { - using namespace std; - string source = "this is a string"; - string mismatch = "mismatch"; - String s{source}; - - EXPECT_EQ(std::strcmp(s.c_str(), source.data()), 0); - EXPECT_NE(std::strcmp(s.c_str(), mismatch.data()), 0); -} - -TEST(String, hash) { - using namespace std; - string source = "this is a string"; - String s{source}; - std::hash()(s); - - std::unordered_map map; - String k1{string{"k1"}}; - string v1{"v1"}; - String k2{string{"k2"}}; - string v2{"v2"}; - map[k1] = v1; - map[k2] = v2; - - EXPECT_EQ(map[k1], v1); - EXPECT_EQ(map[k2], v2); -} - -TEST(String, Cast) { - using namespace std; - string source = "this is a string"; - String s{source}; - Any r = s; - String s2 = r.cast(); -} - -TEST(String, Concat) { - String s1("hello"); - String s2("world"); - std::string s3("world"); - String res1 = s1 + s2; - String res2 = s1 + s3; - String res3 = s3 + s1; - String res4 = s1 + "world"; - String res5 = "world" + s1; - - EXPECT_EQ(res1.compare("helloworld"), 0); - EXPECT_EQ(res2.compare("helloworld"), 0); - EXPECT_EQ(res3.compare("worldhello"), 0); - EXPECT_EQ(res4.compare("helloworld"), 0); - EXPECT_EQ(res5.compare("worldhello"), 0); - - String storage_scope; - String res = "The input storage scope \"" + storage_scope + "\" is invalid."; - EXPECT_EQ(res.compare("The input storage scope \"\" is invalid."), 0); -} - -TEST(String, Any) { - // test anyview promotion to any - AnyView view = "hello"; - EXPECT_EQ(view.type_index(), TypeIndex::kTVMFFIRawStr); - - Any b = view; - EXPECT_EQ(b.type_index(), TypeIndex::kTVMFFISmallStr); - EXPECT_EQ(b.as().value(), "hello"); - EXPECT_TRUE(b.as().has_value()); - EXPECT_EQ(b.try_cast().value(), "hello"); - - std::string s_world = "world"; - view = s_world; - EXPECT_EQ(view.try_cast().value(), "world"); - - String s{"hello"}; - Any a = s; - EXPECT_EQ(a.type_index(), TypeIndex::kTVMFFISmallStr); - EXPECT_EQ(a.as().value(), "hello"); - EXPECT_EQ(a.try_cast().value(), "hello"); - - Any c = "long string very long"; - EXPECT_EQ(c.type_index(), TypeIndex::kTVMFFIStr); - EXPECT_EQ(c.as().value(), "long string very long"); - EXPECT_EQ(c.try_cast().value(), "long string very long"); -} - -TEST(String, Bytes) { - Bytes b0; - EXPECT_EQ(b0.size(), 0); - EXPECT_EQ(b0.operator std::string(), ""); - - // explicitly test zero element - std::string s = {'\0', 'a', 'b', 'c'}; - Bytes b = s; - EXPECT_EQ(b.size(), 4); - EXPECT_EQ(b.operator std::string(), s); - - TVMFFIByteArray arr{s.data(), static_cast(s.size())}; - Bytes b2 = arr; - EXPECT_EQ(b2.size(), 4); - EXPECT_EQ(b2.operator std::string(), s); -} - -TEST(String, BytesAny) { - std::string s = {'\0', 'a', 'b', 'c'}; - TVMFFIByteArray arr{s.data(), static_cast(s.size())}; - - AnyView view = &arr; - EXPECT_EQ(view.type_index(), TypeIndex::kTVMFFIByteArrayPtr); - EXPECT_EQ(view.try_cast().value().operator std::string(), s); - - Any b = view; - EXPECT_EQ(b.type_index(), TypeIndex::kTVMFFISmallBytes); - - EXPECT_EQ(b.try_cast().value().operator std::string(), s); - EXPECT_EQ(b.cast(), s); - - std::string s2 = "hello long long long string"; - s2[0] = '\0'; - Any b2 = Bytes(s2); - EXPECT_EQ(b2.type_index(), TypeIndex::kTVMFFIBytes); - EXPECT_EQ(b2.try_cast().value(), s2); - EXPECT_EQ(b2.cast(), s2); -} - -TEST(String, StdString) { - std::string s1 = "test_string"; - AnyView view1 = s1; - EXPECT_EQ(view1.type_index(), TypeIndex::kTVMFFIRawStr); - EXPECT_EQ(view1.try_cast().value(), s1); - - TVMFFIByteArray arr1{s1.data(), static_cast(s1.size())}; - AnyView view2 = &arr1; - EXPECT_EQ(view2.type_index(), TypeIndex::kTVMFFIByteArrayPtr); - EXPECT_EQ(view2.try_cast().value(), s1); - - Bytes bytes1 = s1; - AnyView view3 = bytes1; - EXPECT_EQ(view3.type_index(), TypeIndex::kTVMFFIBytes); - EXPECT_EQ(view3.try_cast().value(), s1); - - String string1 = s1; - AnyView view4 = string1; - EXPECT_EQ(view4.type_index(), TypeIndex::kTVMFFIStr); - EXPECT_EQ(view4.try_cast().value(), s1); - - // Test with Any - Any any1 = s1; - EXPECT_EQ(any1.type_index(), TypeIndex::kTVMFFIStr); - EXPECT_EQ(any1.try_cast().value(), s1); - - Any any2 = &arr1; - EXPECT_EQ(any2.type_index(), TypeIndex::kTVMFFIBytes); - EXPECT_EQ(any2.try_cast().value(), s1); - - Any any3 = bytes1; - EXPECT_EQ(any3.type_index(), TypeIndex::kTVMFFIBytes); - EXPECT_EQ(any3.try_cast().value(), s1); - - Any any4 = string1; - EXPECT_EQ(any4.type_index(), TypeIndex::kTVMFFIStr); - EXPECT_EQ(any4.try_cast().value(), s1); -} - -TEST(String, CAPIAccessor) { - using namespace std; - String s{"hello"}; - TVMFFIByteArray arr{s.data(), s.size()}; - EXPECT_EQ(arr.size, 5); - EXPECT_EQ(std::string(arr.data, arr.size), "hello"); -} - -TEST(String, BytesHash) { - std::vector data1(10); - std::vector data2(11); - for (size_t i = 0; i < data1.size(); ++i) { - data1[i] = i; - } - char* data1_ptr = reinterpret_cast(data1.data()); - char* data2_ptr = reinterpret_cast(data2.data()) + 1; - std::memcpy(data2_ptr, data1.data(), data1.size() * sizeof(int64_t)); - // has of aligned and unaligned data should be the same - uint64_t hash1 = details::StableHashBytes(data1_ptr, data1.size() * sizeof(int64_t)); - uint64_t hash2 = details::StableHashBytes(data2_ptr, data1.size() * sizeof(int64_t)); - EXPECT_EQ(hash1, hash2); -} - -TEST(String, StdHash) { - String s1 = "a"; - String s2(std::string("a")); - EXPECT_EQ(std::hash()(s1), std::hash()(s2)); - - Bytes s3("a", 1); - Bytes s4(std::string("a")); - EXPECT_EQ(std::hash()(s3), std::hash()(s4)); -} - -} // namespace diff --git a/ffi/tests/cpp/test_tensor.cc b/ffi/tests/cpp/test_tensor.cc deleted file mode 100644 index 7c696a3429c1..000000000000 --- a/ffi/tests/cpp/test_tensor.cc +++ /dev/null @@ -1,164 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -#include -#include - -namespace { - -using namespace tvm::ffi; - -struct CPUNDAlloc { - void AllocData(DLTensor* tensor) { tensor->data = malloc(GetDataSize(*tensor)); } - void FreeData(DLTensor* tensor) { free(tensor->data); } -}; - -inline Tensor Empty(Shape shape, DLDataType dtype, DLDevice device) { - return Tensor::FromNDAlloc(CPUNDAlloc(), shape, dtype, device); -} - -int TestDLPackTensorAllocator(DLTensor* prototype, DLManagedTensorVersioned** out, void* error_ctx, - void (*SetError)(void* error_ctx, const char* kind, - const char* message)) { - Shape shape(prototype->shape, prototype->shape + prototype->ndim); - Tensor nd = Empty(shape, prototype->dtype, prototype->device); - *out = nd.ToDLPackVersioned(); - return 0; -} - -int TestDLPackTensorAllocatorError(DLTensor* prototype, DLManagedTensorVersioned** out, - void* error_ctx, - void (*SetError)(void* error_ctx, const char* kind, - const char* message)) { - SetError(error_ctx, "RuntimeError", "TestDLPackTensorAllocatorError"); - return -1; -} - -TEST(Tensor, Basic) { - Tensor nd = Empty(Shape({1, 2, 3}), DLDataType({kDLFloat, 32, 1}), DLDevice({kDLCPU, 0})); - Shape shape = nd.shape(); - Shape strides = nd.strides(); - EXPECT_EQ(shape.size(), 3); - EXPECT_EQ(shape[0], 1); - EXPECT_EQ(shape[1], 2); - EXPECT_EQ(shape[2], 3); - EXPECT_EQ(strides.size(), 3); - EXPECT_EQ(strides[0], 6); - EXPECT_EQ(strides[1], 3); - EXPECT_EQ(strides[2], 1); - EXPECT_EQ(nd.dtype(), DLDataType({kDLFloat, 32, 1})); - for (int64_t i = 0; i < shape.Product(); ++i) { - reinterpret_cast(nd->data)[i] = static_cast(i); - } - - Any any0 = nd; - Tensor nd2 = any0.as().value(); - EXPECT_EQ(nd2.shape(), shape); - EXPECT_EQ(nd2.strides(), strides); - EXPECT_EQ(nd2.dtype(), DLDataType({kDLFloat, 32, 1})); - for (int64_t i = 0; i < shape.Product(); ++i) { - EXPECT_EQ(reinterpret_cast(nd2->data)[i], i); - } - - EXPECT_EQ(nd.IsContiguous(), true); - EXPECT_EQ(nd2.use_count(), 3); -} - -TEST(Tensor, DLPack) { - Tensor tensor = Empty({1, 2, 3}, DLDataType({kDLInt, 16, 1}), DLDevice({kDLCPU, 0})); - DLManagedTensor* dlpack = tensor.ToDLPack(); - EXPECT_EQ(dlpack->dl_tensor.ndim, 3); - EXPECT_EQ(dlpack->dl_tensor.shape[0], 1); - EXPECT_EQ(dlpack->dl_tensor.shape[1], 2); - EXPECT_EQ(dlpack->dl_tensor.shape[2], 3); - EXPECT_EQ(dlpack->dl_tensor.dtype.code, kDLInt); - EXPECT_EQ(dlpack->dl_tensor.dtype.bits, 16); - EXPECT_EQ(dlpack->dl_tensor.dtype.lanes, 1); - EXPECT_EQ(dlpack->dl_tensor.device.device_type, kDLCPU); - EXPECT_EQ(dlpack->dl_tensor.device.device_id, 0); - EXPECT_EQ(dlpack->dl_tensor.byte_offset, 0); - EXPECT_EQ(dlpack->dl_tensor.strides[0], 6); - EXPECT_EQ(dlpack->dl_tensor.strides[1], 3); - EXPECT_EQ(dlpack->dl_tensor.strides[2], 1); - EXPECT_EQ(tensor.use_count(), 2); - { - Tensor tensor2 = Tensor::FromDLPack(dlpack); - EXPECT_EQ(tensor2.use_count(), 1); - EXPECT_EQ(tensor2->data, tensor->data); - EXPECT_EQ(tensor.use_count(), 2); - EXPECT_EQ(tensor2.use_count(), 1); - } - EXPECT_EQ(tensor.use_count(), 1); -} - -TEST(Tensor, DLPackVersioned) { - DLDataType dtype = DLDataType({kDLFloat4_e2m1fn, 4, 1}); - EXPECT_EQ(GetDataSize(2, dtype), 2 * 4 / 8); - Tensor tensor = Empty({2}, dtype, DLDevice({kDLCPU, 0})); - DLManagedTensorVersioned* dlpack = tensor.ToDLPackVersioned(); - EXPECT_EQ(dlpack->version.major, DLPACK_MAJOR_VERSION); - EXPECT_EQ(dlpack->version.minor, DLPACK_MINOR_VERSION); - EXPECT_EQ(dlpack->dl_tensor.ndim, 1); - EXPECT_EQ(dlpack->dl_tensor.shape[0], 2); - EXPECT_EQ(dlpack->dl_tensor.dtype.code, kDLFloat4_e2m1fn); - EXPECT_EQ(dlpack->dl_tensor.dtype.bits, 4); - EXPECT_EQ(dlpack->dl_tensor.dtype.lanes, 1); - EXPECT_EQ(dlpack->dl_tensor.device.device_type, kDLCPU); - EXPECT_EQ(dlpack->dl_tensor.device.device_id, 0); - EXPECT_EQ(dlpack->dl_tensor.byte_offset, 0); - EXPECT_EQ(dlpack->dl_tensor.strides[0], 1); - - EXPECT_EQ(tensor.use_count(), 2); - { - Tensor tensor2 = Tensor::FromDLPackVersioned(dlpack); - EXPECT_EQ(tensor2.use_count(), 1); - EXPECT_EQ(tensor2->data, tensor->data); - EXPECT_EQ(tensor.use_count(), 2); - EXPECT_EQ(tensor2.use_count(), 1); - } - EXPECT_EQ(tensor.use_count(), 1); -} - -TEST(Tensor, DLPackAlloc) { - // Test successful allocation - Tensor tensor = Tensor::FromDLPackAlloc(TestDLPackTensorAllocator, {1, 2, 3}, - DLDataType({kDLFloat, 32, 1}), DLDevice({kDLCPU, 0})); - EXPECT_EQ(tensor.use_count(), 1); - EXPECT_EQ(tensor.shape().size(), 3); - EXPECT_EQ(tensor.shape()[0], 1); - EXPECT_EQ(tensor.shape()[1], 2); - EXPECT_EQ(tensor.shape()[2], 3); - EXPECT_EQ(tensor.dtype().code, kDLFloat); - EXPECT_EQ(tensor.dtype().bits, 32); - EXPECT_EQ(tensor.dtype().lanes, 1); - EXPECT_EQ(tensor->device.device_type, kDLCPU); - EXPECT_EQ(tensor->device.device_id, 0); - EXPECT_NE(tensor->data, nullptr); -} - -TEST(Tensor, DLPackAllocError) { - // Test error handling in DLPackAlloc - EXPECT_THROW( - { - Tensor::FromDLPackAlloc(TestDLPackTensorAllocatorError, {1, 2, 3}, - DLDataType({kDLFloat, 32, 1}), DLDevice({kDLCPU, 0})); - }, - tvm::ffi::Error); -} - -} // namespace diff --git a/ffi/tests/cpp/test_tuple.cc b/ffi/tests/cpp/test_tuple.cc deleted file mode 100644 index 5735e86eca4d..000000000000 --- a/ffi/tests/cpp/test_tuple.cc +++ /dev/null @@ -1,168 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -#include -#include -#include - -#include "./testing_object.h" - -namespace { - -using namespace tvm::ffi; -using namespace tvm::ffi::testing; - -TEST(Tuple, Basic) { - Tuple tuple0(1, 2.0f); - EXPECT_EQ(tuple0.get<0>(), 1); - EXPECT_EQ(tuple0.get<1>(), 2.0f); - - Tuple tuple1 = tuple0; - EXPECT_EQ(tuple0.use_count(), 2); - - // test copy on write - tuple1.Set<0>(3); - EXPECT_EQ(tuple0.get<0>(), 1); - EXPECT_EQ(tuple1.get<0>(), 3); - - EXPECT_EQ(tuple0.use_count(), 1); - EXPECT_EQ(tuple1.use_count(), 1); - - // copy on write not triggered because - // tuple1 is unique. - tuple1.Set<1>(4); - EXPECT_EQ(tuple1.get<1>(), 4.0f); - EXPECT_EQ(tuple1.use_count(), 1); - - // default state - Tuple tuple2; - EXPECT_EQ(tuple2.use_count(), 1); - tuple2.Set<0>(1); - tuple2.Set<1>(2.0f); - EXPECT_EQ(tuple2.get<0>(), 1); - EXPECT_EQ(tuple2.get<1>(), 2.0f); - - // tuple of object and primitive - Tuple tuple3(1, 2); - EXPECT_EQ(tuple3.get<0>()->value, 1); - EXPECT_EQ(tuple3.get<1>(), 2); - tuple3.Set<0>(4); - EXPECT_EQ(tuple3.get<0>()->value, 4); -} - -TEST(Tuple, AnyConvert) { - Tuple tuple0(1, 2); - AnyView view0 = tuple0; - Array arr0 = view0.as>().value(); - EXPECT_EQ(arr0.size(), 2); - EXPECT_EQ(arr0[0].as().value(), 1); - EXPECT_EQ(arr0[1].as().value()->value, 2); - - // directly reuse the underlying storage. - auto tuple1 = view0.cast>(); - EXPECT_TRUE(tuple0.same_as(tuple1)); - - Any any0 = view0; - // trigger a copy due to implict conversion - auto tuple2 = any0.cast>(); - EXPECT_TRUE(!tuple0.same_as(tuple2)); - EXPECT_EQ(tuple2.get<0>()->value, 1); - EXPECT_EQ(tuple2.get<1>()->value, 2); -} - -TEST(Tuple, FromTyped) { - // try decution - Function fadd1 = Function::FromTyped([](const Tuple& a) -> int { - return a.get<0>() + static_cast(a.get<1>()->value); - }); - int b = fadd1(Tuple(1, 2)).cast(); - EXPECT_EQ(b, 3); - - int c = fadd1(Array({1, 2})).cast(); - EXPECT_EQ(c, 3); - - // convert that triggers error - EXPECT_THROW( - { - try { - fadd1(Array({1.1, 2})); - } catch (const Error& error) { - EXPECT_EQ(error.kind(), "TypeError"); - EXPECT_EQ(error.message(), - "Mismatched type on argument #0 when calling: `(0: Tuple) -> int`. " - "Expected `Tuple` but got `Array[index 0: float]`"); - throw; - } - }, - ::tvm::ffi::Error); - - EXPECT_THROW( - { - try { - fadd1(Array({1.1})); - } catch (const Error& error) { - EXPECT_EQ(error.kind(), "TypeError"); - EXPECT_EQ(error.message(), - "Mismatched type on argument #0 when calling: `(0: Tuple) -> int`. " - "Expected `Tuple` but got `Array[size=1]`"); - throw; - } - }, - ::tvm::ffi::Error); -} - -TEST(Tuple, Upcast) { - Tuple t0(1, 2.0f); - Tuple t1 = t0; - EXPECT_EQ(t1.get<0>().cast(), 1); - EXPECT_EQ(t1.get<1>().cast(), 2.0f); - static_assert(details::type_contains_v, Tuple>); - static_assert(details::type_contains_v, Tuple>); - static_assert(details::type_contains_v, Tuple>); -} - -TEST(Tuple, ArrayIterForwarding) { - Tuple t0(1, 2); - Tuple t1(3, 4); - Array> arr0 = {t0, t1}; - std::vector> vec0 = {t0}; - vec0.insert(vec0.end(), arr0.begin(), arr0.end()); - EXPECT_EQ(vec0.size(), 3); - EXPECT_EQ(vec0[0].get<0>()->value, 1); - EXPECT_EQ(vec0[0].get<1>()->value, 2); - EXPECT_EQ(vec0[1].get<0>()->value, 1); - EXPECT_EQ(vec0[1].get<1>()->value, 2); - EXPECT_EQ(vec0[2].get<0>()->value, 3); - EXPECT_EQ(vec0[2].get<1>()->value, 4); -} - -TEST(Tuple, ArrayIterForwardSingleElem) { - Tuple t0(1); - Tuple t1(2); - Array> arr0 = {t0, t1}; - std::vector> vec0 = {t0}; - vec0.insert(vec0.end(), arr0.begin(), arr0.end()); - EXPECT_EQ(vec0.size(), 3); - EXPECT_EQ(vec0[0].get<0>()->value, 1); - EXPECT_EQ(vec0[1].get<0>()->value, 1); - EXPECT_EQ(vec0[2].get<0>()->value, 2); -} - -} // namespace diff --git a/ffi/tests/cpp/test_variant.cc b/ffi/tests/cpp/test_variant.cc deleted file mode 100644 index 639e6ee671dd..000000000000 --- a/ffi/tests/cpp/test_variant.cc +++ /dev/null @@ -1,164 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -#include -#include -#include -#include -#include -#include - -#include "./testing_object.h" - -namespace { - -using namespace tvm::ffi; -using namespace tvm::ffi::testing; - -TEST(Variant, Basic) { - Variant v1 = 1; - EXPECT_EQ(v1.get(), 1); - - Variant v2 = 2.0f; - EXPECT_EQ(v2.get(), 2.0f); - v2 = v1; - EXPECT_EQ(v2.get(), 1); -} - -TEST(Variant, AnyConvert) { - Variant v = 1; - AnyView view0 = v; - EXPECT_EQ(view0.as().value(), 1); - - // implicit convert to variant - Any any0 = 1; - auto v1 = any0.cast>>(); - EXPECT_EQ(v1.get()->value, 1); - - // move from any to variant - Variant v2 = TInt(1); - Any any1 = std::move(v2); - auto v3 = std::move(any1).cast>(); - auto v4 = std::move(v3).get(); - EXPECT_EQ(v4->value, 1); - EXPECT_EQ(v4.use_count(), 1); -} - -TEST(Variant, ObjectPtrHashEqual) { - TInt x = TInt(1); - TFloat y = TFloat(1.0f); - - Variant v0 = x; - Variant v1 = y; - Variant v2 = v1; - - EXPECT_EQ(ObjectPtrHash()(v0), ObjectPtrHash()(x)); - EXPECT_TRUE(!ObjectPtrEqual()(v0, v1)); - EXPECT_TRUE(!ObjectPtrEqual()(v0, v2)); -} - -TEST(Variant, FromTyped) { - // try decution - Function fadd1 = Function::FromTyped([](const Variant& a) -> int64_t { - if (auto opt_int = a.as()) { - return opt_int.value() + 1; - } else { - return a.get()->value + 1; - } - }); - int b = fadd1(1).cast(); - EXPECT_EQ(b, 2); - - // convert that triggers error - EXPECT_THROW( - { - try { - fadd1(1.1); - } catch (const Error& error) { - EXPECT_EQ(error.kind(), "TypeError"); - EXPECT_EQ( - error.message(), - "Mismatched type on argument #0 when calling: `(0: Variant) -> int`. " - "Expected `Variant` but got `float`"); - throw; - } - }, - ::tvm::ffi::Error); - - Function fadd2 = Function::FromTyped([](const Array>& a) -> int64_t { - if (auto opt_int = a[0].as()) { - return opt_int.value() + 1; - } else { - return a[0].get()->value + 1; - } - }); - int c = fadd2(Array({1, 2})).cast(); - EXPECT_EQ(c, 2); - - // convert that triggers error - EXPECT_THROW( - { - try { - fadd2(Array({1, 1.1})); - } catch (const Error& error) { - EXPECT_EQ(error.kind(), "TypeError"); - EXPECT_EQ(error.message(), - "Mismatched type on argument #0 when calling: `(0: Array>) -> int`. " - "Expected `Array>` but got `Array[index 1: float]`"); - throw; - } - }, - ::tvm::ffi::Error); -} - -TEST(Variant, Upcast) { - Array a0 = {1, 2, 3}; - static_assert(details::type_contains_v>, Array>); - Array> a1 = a0; - EXPECT_EQ(a1[0].get(), 1); -} - -TEST(Variant, AllObjectRef) { - Variant> v0 = TInt(1); - EXPECT_EQ(v0.get()->value, 1); - static_assert(std::is_base_of_v); - Any any0 = v0; - EXPECT_EQ(any0.cast()->value, 1); - auto v2 = any0.cast>>(); - EXPECT_TRUE(v0.same_as(v2)); - // assignment operator - v0 = Array({TInt(2), TInt(3)}); - EXPECT_EQ(v0.get>().size(), 2); - EXPECT_EQ(v0.get>()[0]->value, 2); - EXPECT_EQ(v0.get>()[1]->value, 3); - EXPECT_EQ(sizeof(v0), sizeof(ObjectRef)); -} - -TEST(Variant, PODSameAs) { - Variant v0 = 1; - Variant v1 = 1; - EXPECT_TRUE(v0.same_as(v1)); - String s = String("hello long str"); - v0 = s; - v1 = s; - EXPECT_TRUE(v0.same_as(v1)); - v1 = String("hello long str"); - EXPECT_TRUE(!v0.same_as(v1)); -} -} // namespace diff --git a/ffi/tests/cpp/testing_object.h b/ffi/tests/cpp/testing_object.h deleted file mode 100644 index 933ba996b0ae..000000000000 --- a/ffi/tests/cpp/testing_object.h +++ /dev/null @@ -1,296 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -#ifndef TVM_FFI_TESTING_OBJECT_H_ -#define TVM_FFI_TESTING_OBJECT_H_ - -#include -#include -#include -#include -#include -#include - -namespace tvm { -namespace ffi { -namespace testing { - -// We deliberately pad extra -// in the header to test cases -// where the object subclass address -// do not align with the base object address -// not handling properly will cause buffer overflow -class BasePad { - public: - int64_t extra[4]; -}; - -class TNumberObj : public BasePad, public Object { - public: - // declare as one slot, with float as overflow - static constexpr uint32_t _type_child_slots = 1; - static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; - TVM_FFI_DECLARE_OBJECT_INFO("test.Number", TNumberObj, Object); -}; - -class TNumber : public ObjectRef { - public: - TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(TNumber, ObjectRef, TNumberObj); -}; - -class TIntObj : public TNumberObj { - public: - int64_t value; - - TIntObj(int64_t value) : value(value) {} - explicit TIntObj(UnsafeInit) {} - - int64_t GetValue() const { return value; } - - inline static void RegisterReflection(); - - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("test.Int", TIntObj, TNumberObj); -}; - -class TInt : public TNumber { - public: - explicit TInt(int64_t value) { data_ = make_object(value); } - - static TInt StaticAdd(TInt lhs, TInt rhs) { return TInt(lhs->value + rhs->value); } - - TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(TInt, TNumber, TIntObj); -}; - -inline void TIntObj::RegisterReflection() { - namespace refl = tvm::ffi::reflection; - refl::ObjectDef() - .def_ro("value", &TIntObj::value) - .def_static("static_add", &TInt::StaticAdd, "static add method"); - // define extra type attributes - refl::TypeAttrDef() - .def("test.GetValue", &TIntObj::GetValue) - .attr("test.size", sizeof(TIntObj)); - // custom json serialization - refl::TypeAttrDef() - .def("__data_to_json__", - [](const TIntObj* self) -> Map { - return Map{{"value", self->value}}; - }) - .def("__data_from_json__", [](Map json_obj) -> TInt { - return TInt(json_obj["value"].cast()); - }); -} - -class TFloatObj : public TNumberObj { - public: - double value; - - TFloatObj(double value) : value(value) {} - - double Add(double other) const { return value + other; } - - static void RegisterReflection() { - namespace refl = tvm::ffi::reflection; - refl::ObjectDef() - .def_ro("value", &TFloatObj::value, "float value field", refl::DefaultValue(10.0)) - .def("sub", - [](const TFloatObj* self, double other) -> double { return self->value - other; }) - .def("add", &TFloatObj::Add, "add method"); - } - - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("test.Float", TFloatObj, TNumberObj); -}; - -class TFloat : public TNumber { - public: - explicit TFloat(double value) { data_ = make_object(value); } - - TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(TFloat, TNumber, TFloatObj); -}; - -class TPrimExprObj : public Object { - public: - std::string dtype; - double value; - - TPrimExprObj(std::string dtype, double value) : dtype(dtype), value(value) {} - - static void RegisterReflection() { - namespace refl = tvm::ffi::reflection; - refl::ObjectDef() - .def_rw("dtype", &TPrimExprObj::dtype, "dtype field", refl::DefaultValue("float")) - .def_ro("value", &TPrimExprObj::value, "value field", refl::DefaultValue(0)) - .def("sub", [](TPrimExprObj* self, double other) -> double { - // this is ok because TPrimExprObj is declared asmutable - return self->value - other; - }); - } - - static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; - static constexpr bool _type_mutable = true; - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("test.PrimExpr", TPrimExprObj, Object); -}; - -class TPrimExpr : public ObjectRef { - public: - explicit TPrimExpr(std::string dtype, double value) { - data_ = make_object(dtype, value); - } - - TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(TPrimExpr, ObjectRef, TPrimExprObj); -}; - -class TVarObj : public Object { - public: - std::string name; - - TVarObj(std::string name) : name(name) {} - // need unsafe init constructor for json serialization - explicit TVarObj(UnsafeInit) {} - - static void RegisterReflection() { - namespace refl = tvm::ffi::reflection; - refl::ObjectDef().def_ro("name", &TVarObj::name, - refl::AttachFieldFlag::SEqHashIgnore()); - } - - static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindFreeVar; - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("test.Var", TVarObj, Object); -}; - -class TVar : public ObjectRef { - public: - explicit TVar(std::string name) { data_ = make_object(name); } - - TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(TVar, ObjectRef, TVarObj); -}; - -class TFuncObj : public Object { - public: - Array params; - Array body; - Optional comment; - - // need unsafe init constructor or default constructor for json serialization - explicit TFuncObj(UnsafeInit) {} - TFuncObj(Array params, Array body, Optional comment) - : params(params), body(body), comment(comment) {} - - static void RegisterReflection() { - namespace refl = tvm::ffi::reflection; - refl::ObjectDef() - .def_ro("params", &TFuncObj::params, refl::AttachFieldFlag::SEqHashDef()) - .def_ro("body", &TFuncObj::body) - .def_ro("comment", &TFuncObj::comment, refl::AttachFieldFlag::SEqHashIgnore()); - } - - static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("test.Func", TFuncObj, Object); -}; - -class TFunc : public ObjectRef { - public: - explicit TFunc(Array params, Array body, Optional comment) { - data_ = make_object(params, body, comment); - } - - TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(TFunc, ObjectRef, TFuncObj); -}; - -class TCustomFuncObj : public Object { - public: - Array params; - Array body; - String comment; - - TCustomFuncObj(Array params, Array body, String comment) - : params(params), body(body), comment(comment) {} - - bool SEqual(const TCustomFuncObj* other, - ffi::TypedFunction cmp) const { - if (!cmp(params, other->params, true, "params")) { - return false; - } - if (!cmp(body, other->body, false, "body")) { - return false; - } - return true; - } - - uint64_t SHash(uint64_t init_hash, - ffi::TypedFunction hash) const { - uint64_t hash_value = init_hash; - hash_value = hash(params, hash_value, true); - hash_value = hash(body, hash_value, false); - return hash_value; - } - - static void RegisterReflection() { - namespace refl = tvm::ffi::reflection; - refl::ObjectDef() - .def_ro("params", &TCustomFuncObj::params) - .def_ro("body", &TCustomFuncObj::body) - .def_ro("comment", &TCustomFuncObj::comment); - refl::TypeAttrDef() - .def("__s_equal__", &TCustomFuncObj::SEqual) - .def("__s_hash__", &TCustomFuncObj::SHash); - } - - static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("test.CustomFunc", TCustomFuncObj, Object); -}; - -class TCustomFunc : public ObjectRef { - public: - explicit TCustomFunc(Array params, Array body, String comment) { - data_ = make_object(params, body, comment); - } - - TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(TCustomFunc, ObjectRef, TCustomFuncObj); -}; - -} // namespace testing - -template <> -inline constexpr bool use_default_type_traits_v = true; - -template <> -struct TypeTraits - : public ObjectRefWithFallbackTraitsBase { - TVM_FFI_INLINE static testing::TPrimExpr ConvertFallbackValue(StrictBool value) { - return testing::TPrimExpr("bool", static_cast(value)); - } - - TVM_FFI_INLINE static testing::TPrimExpr ConvertFallbackValue(int64_t value) { - return testing::TPrimExpr("int64", static_cast(value)); - } - - TVM_FFI_INLINE static testing::TPrimExpr ConvertFallbackValue(double value) { - return testing::TPrimExpr("float32", static_cast(value)); - } - // hack into the dtype to store string - TVM_FFI_INLINE static testing::TPrimExpr ConvertFallbackValue(String value) { - return testing::TPrimExpr(value, 0); - } -}; - -} // namespace ffi -} // namespace tvm -#endif // TVM_FFI_TESTING_OBJECT_H_ diff --git a/ffi/tests/python/test_access_path.py b/ffi/tests/python/test_access_path.py deleted file mode 100644 index 7d9e7af55f5f..000000000000 --- a/ffi/tests/python/test_access_path.py +++ /dev/null @@ -1,133 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -import pytest -from tvm_ffi.access_path import AccessPath, AccessKind - - -def test_root_path(): - root = AccessPath.root() - assert isinstance(root, AccessPath) - steps = root.to_steps() - assert len(steps) == 0 - assert root == AccessPath.root() - - -def test_path_attr(): - path = AccessPath.root().attr("foo") - assert isinstance(path, AccessPath) - steps = path.to_steps() - assert len(steps) == 1 - assert steps[0].kind == AccessKind.ATTR - assert steps[0].key == "foo" - assert path.parent == AccessPath.root() - - -def test_path_array_item(): - path = AccessPath.root().array_item(2) - assert isinstance(path, AccessPath) - steps = path.to_steps() - assert len(steps) == 1 - assert steps[0].kind == AccessKind.ARRAY_ITEM - assert steps[0].key == 2 - assert path.parent == AccessPath.root() - - -def test_path_missing_array_element(): - path = AccessPath.root().array_item_missing(2) - assert isinstance(path, AccessPath) - steps = path.to_steps() - assert len(steps) == 1 - assert steps[0].kind == AccessKind.ARRAY_ITEM_MISSING - assert steps[0].key == 2 - assert path.parent == AccessPath.root() - - -def test_path_map_item(): - path = AccessPath.root().map_item("foo") - assert isinstance(path, AccessPath) - steps = path.to_steps() - assert len(steps) == 1 - assert steps[0].kind == AccessKind.MAP_ITEM - assert steps[0].key == "foo" - assert path.parent == AccessPath.root() - - -def test_path_missing_map_item(): - path = AccessPath.root().map_item_missing("foo") - assert isinstance(path, AccessPath) - steps = path.to_steps() - assert len(steps) == 1 - assert steps[0].kind == AccessKind.MAP_ITEM_MISSING - assert steps[0].key == "foo" - assert path.parent == AccessPath.root() - - -def test_path_is_prefix_of(): - # Root is prefix of root - assert AccessPath.root().is_prefix_of(AccessPath.root()) - - # Root is prefix of any path - assert AccessPath.root().is_prefix_of(AccessPath.root().attr("foo")) - - # Non-root is not prefix of root - assert not AccessPath.root().attr("foo").is_prefix_of(AccessPath.root()) - - # Path is prefix of itself - assert AccessPath.root().attr("foo").is_prefix_of(AccessPath.root().attr("foo")) - - # Different attrs are not prefixes of each other - assert not AccessPath.root().attr("bar").is_prefix_of(AccessPath.root().attr("foo")) - - # Shorter path is prefix of longer path with same start - assert AccessPath.root().attr("foo").is_prefix_of(AccessPath.root().attr("foo").array_item(2)) - - # Longer path is not prefix of shorter path - assert ( - not AccessPath.root().attr("foo").array_item(2).is_prefix_of(AccessPath.root().attr("foo")) - ) - - # Different paths are not prefixes - assert ( - not AccessPath.root().attr("foo").is_prefix_of(AccessPath.root().attr("bar").array_item(2)) - ) - - -def test_path_equal(): - # Root equals root - assert AccessPath.root() == AccessPath.root() - - # Root does not equal non-root paths - assert not (AccessPath.root() == AccessPath.root().attr("foo")) - - # Non-root does not equal root - assert not (AccessPath.root().attr("foo") == AccessPath.root()) - - # Path equals itself - assert AccessPath.root().attr("foo") == AccessPath.root().attr("foo") - - # Different attrs are not equal - assert not (AccessPath.root().attr("bar") == AccessPath.root().attr("foo")) - - # Shorter path does not equal longer path - assert not (AccessPath.root().attr("foo") == AccessPath.root().attr("foo").array_item(2)) - - # Longer path does not equal shorter path - assert not (AccessPath.root().attr("foo").array_item(2) == AccessPath.root().attr("foo")) - - # Different paths are not equal - assert not (AccessPath.root().attr("foo") == AccessPath.root().attr("bar").array_item(2)) diff --git a/ffi/tests/python/test_container.py b/ffi/tests/python/test_container.py deleted file mode 100644 index 9f2fb09df216..000000000000 --- a/ffi/tests/python/test_container.py +++ /dev/null @@ -1,124 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -import pytest -import pickle -import tvm_ffi - - -def test_array(): - a = tvm_ffi.convert([1, 2, 3]) - assert isinstance(a, tvm_ffi.Array) - assert len(a) == 3 - assert a[-1] == 3 - a_slice = a[-3:-1] - assert (a_slice[0], a_slice[1]) == (1, 2) - - -def test_bad_constructor_init_state(): - """Test when error is raised before __init_handle_by_constructor - - This case we need the FFI binding to gracefully handle both repr - and dealloc by ensuring the chandle is initialized and there is - proper repr code - """ - with pytest.raises(TypeError): - tvm_ffi.Array(1) - - with pytest.raises(AttributeError): - tvm_ffi.Map(1) - - -def test_array_of_array_map(): - a = tvm_ffi.convert([[1, 2, 3], {"A": 5, "B": 6}]) - assert isinstance(a, tvm_ffi.Array) - assert len(a) == 2 - assert isinstance(a[0], tvm_ffi.Array) - assert isinstance(a[1], tvm_ffi.Map) - assert tuple(a[0]) == (1, 2, 3) - assert a[1]["A"] == 5 - assert a[1]["B"] == 6 - - -def test_int_map(): - amap = tvm_ffi.convert({3: 2, 4: 3}) - assert 3 in amap - assert len(amap) == 2 - dd = dict(amap.items()) - assert 3 in dd - assert 4 in dd - assert 5 not in amap - assert tuple(amap.items()) == ((3, 2), (4, 3)) - assert tuple(amap.keys()) == (3, 4) - assert tuple(amap.values()) == (2, 3) - - -def test_array_map_of_opaque_object(): - class MyObject: - def __init__(self, value): - self.value = value - - a = tvm_ffi.convert([MyObject("hello"), MyObject(1)]) - assert isinstance(a, tvm_ffi.Array) - assert len(a) == 2 - assert isinstance(a[0], MyObject) - assert a[0].value == "hello" - assert isinstance(a[1], MyObject) - assert a[1].value == 1 - - y = tvm_ffi.convert({"a": MyObject(1), "b": MyObject("hello")}) - assert isinstance(y, tvm_ffi.Map) - assert len(y) == 2 - assert isinstance(y["a"], MyObject) - assert y["a"].value == 1 - assert isinstance(y["b"], MyObject) - assert y["b"].value == "hello" - - -def test_str_map(): - data = [] - for i in reversed(range(10)): - data.append((f"a{i}", i)) - amap = tvm_ffi.convert({k: v for k, v in data}) - assert tuple(amap.items()) == tuple(data) - for k, v in data: - assert k in amap - assert amap[k] == v - assert amap.get(k) == v - - assert tuple(k for k in amap) == tuple(k for k, _ in data) - - -def test_key_not_found(): - amap = tvm_ffi.convert({3: 2, 4: 3}) - with pytest.raises(KeyError): - amap[5] - - -def test_repr(): - a = tvm_ffi.convert([1, 2, 3]) - assert str(a) == "[1, 2, 3]" - amap = tvm_ffi.convert({3: 2, 4: 3}) - assert str(amap) == "{3: 2, 4: 3}" - - smap = tvm_ffi.convert({"a": 1, "b": 2}) - assert str(smap) == "{'a': 1, 'b': 2}" - - -def test_serialization(): - a = tvm_ffi.convert([1, 2, 3]) - b = pickle.loads(pickle.dumps(a)) - assert str(b) == "[1, 2, 3]" diff --git a/ffi/tests/python/test_device.py b/ffi/tests/python/test_device.py deleted file mode 100644 index 849f45b8f97d..000000000000 --- a/ffi/tests/python/test_device.py +++ /dev/null @@ -1,94 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -import pytest -import pickle -from tvm_ffi import Device, DLDeviceType -import tvm_ffi - - -def test_device(): - device = tvm_ffi.Device("cuda", 0) - assert device.dlpack_device_type() == tvm_ffi.DLDeviceType.kDLCUDA - assert device.index == 0 - assert str(device) == "cuda:0" - assert device.__repr__() == "device(type='cuda', index=0)" - - -def test_device_from_str(): - device = tvm_ffi.device("ext_dev:0") - assert device.dlpack_device_type() == tvm_ffi.DLDeviceType.kDLExtDev - assert device.index == 0 - assert str(device) == "ext_dev:0" - assert device.__repr__() == "device(type='ext_dev', index=0)" - - -@pytest.mark.parametrize( - "dev_str, expected_device_type, expect_device_id", - [ - ("cpu", DLDeviceType.kDLCPU, 0), - ("cuda", DLDeviceType.kDLCUDA, 0), - ("cuda:0", DLDeviceType.kDLCUDA, 0), - ("cuda:3", DLDeviceType.kDLCUDA, 3), - ("metal:2", DLDeviceType.kDLMetal, 2), - ], -) -def test_device(dev_str, expected_device_type, expect_device_id): - dev = tvm_ffi.device(dev_str) - assert dev.dlpack_device_type() == expected_device_type - assert dev.index == expect_device_id - - -@pytest.mark.parametrize( - "dev_type, dev_id, expected_device_type, expect_device_id", - [ - ("cpu", 0, DLDeviceType.kDLCPU, 0), - ("cuda", 0, DLDeviceType.kDLCUDA, 0), - (DLDeviceType.kDLCUDA, 0, DLDeviceType.kDLCUDA, 0), - ("cuda", 3, DLDeviceType.kDLCUDA, 3), - (DLDeviceType.kDLMetal, 2, DLDeviceType.kDLMetal, 2), - ], -) -def test_device_with_dev_id(dev_type, dev_id, expected_device_type, expect_device_id): - dev = tvm_ffi.device(dev_type, dev_id) - assert dev.dlpack_device_type() == expected_device_type - assert dev.index == expect_device_id - - -@pytest.mark.parametrize( - "dev_type, dev_id", - [ - ("cpu:0:0", None), - ("cpu:?", None), - ("cpu:", None), - ], -) -def test_deive_type_error(dev_type, dev_id): - with pytest.raises(ValueError): - dev = tvm_ffi.device(dev_type, dev_id) - - -def test_deive_id_error(): - with pytest.raises(TypeError): - dev = tvm_ffi.device("cpu", "?") - - -def test_device_pickle(): - device = tvm_ffi.device("cuda", 0) - device_pickled = pickle.loads(pickle.dumps(device)) - assert device_pickled.dlpack_device_type() == device.dlpack_device_type() - assert device_pickled.index == device.index diff --git a/ffi/tests/python/test_dtype.py b/ffi/tests/python/test_dtype.py deleted file mode 100644 index 7d09d3def98c..000000000000 --- a/ffi/tests/python/test_dtype.py +++ /dev/null @@ -1,85 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -import pytest -import pickle -import numpy as np -import tvm_ffi - - -def test_dtype(): - float32 = tvm_ffi.dtype("float32") - assert float32.__repr__() == "dtype('float32')" - assert type(float32) == tvm_ffi.dtype - x = np.array([1, 2, 3], dtype=float32) - assert x.dtype == float32 - - -@pytest.mark.parametrize( - "dtype_str, expected_size", - [ - ("float32", 4), - ("float32x4", 16), - ("float8_e5m2x4", 4), - ("float6_e2m3fnx4", 3), - ("float4_e2m1fnx4", 2), - ("uint8", 1), - ("bool", 1), - ], -) -def test_dtype_itemsize(dtype_str, expected_size): - dtype = tvm_ffi.dtype(dtype_str) - assert dtype.itemsize == expected_size - - -@pytest.mark.parametrize("dtype_str", ["int32xvscalex4"]) -def test_dtype_itemmize_error(dtype_str): - with pytest.raises(ValueError): - tvm_ffi.dtype(dtype_str).itemsize - - -@pytest.mark.parametrize( - "dtype_str", - [ - "float32", - "float32x4", - "float8_e5m2x4", - "float6_e2m3fnx4", - "float4_e2m1fnx4", - "uint8", - "bool", - ], -) -def test_dtype_pickle(dtype_str): - dtype = tvm_ffi.dtype(dtype_str) - dtype_pickled = pickle.loads(pickle.dumps(dtype)) - assert dtype_pickled.type_code == dtype.type_code - assert dtype_pickled.bits == dtype.bits - assert dtype_pickled.lanes == dtype.lanes - - -@pytest.mark.parametrize("dtype_str", ["float32", "bool"]) -def test_dtype_with_lanes(dtype_str): - dtype = tvm_ffi.dtype(dtype_str) - dtype_with_lanes = dtype.with_lanes(4) - assert dtype_with_lanes.type_code == dtype.type_code - assert dtype_with_lanes.bits == dtype.bits - assert dtype_with_lanes.lanes == 4 - - -if __name__ == "__main__": - tvm.testing.main() diff --git a/ffi/tests/python/test_error.py b/ffi/tests/python/test_error.py deleted file mode 100644 index ad6da64c0f19..000000000000 --- a/ffi/tests/python/test_error.py +++ /dev/null @@ -1,113 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -import pytest -import platform -import tvm_ffi - - -def test_parse_traceback(): - traceback = """ - File "test.py", line 1, in - File "test.py", line 3, in run_test - """ - parsed = tvm_ffi.error._parse_traceback(traceback) - assert len(parsed) == 2 - assert parsed[0] == ("test.py", 1, "") - assert parsed[1] == ("test.py", 3, "run_test") - - -def test_error_from_cxx(): - test_raise_error = tvm_ffi.get_global_func("testing.test_raise_error") - - try: - test_raise_error("ValueError", "error XYZ") - except ValueError as e: - assert e.__tvm_ffi_error__.kind == "ValueError" - assert e.__tvm_ffi_error__.message == "error XYZ" - assert e.__tvm_ffi_error__.traceback.find("TestRaiseError") != -1 - - fapply = tvm_ffi.convert(lambda f, *args: f(*args)) - - with pytest.raises(TypeError): - fapply(test_raise_error, "TypeError", "error XYZ") - - # wrong number of arguments - with pytest.raises(TypeError): - tvm_ffi.convert(lambda x: x)() - - -def test_error_from_nested_pyfunc(): - fapply = tvm_ffi.convert(lambda f, *args: f(*args)) - cxx_test_raise_error = tvm_ffi.get_global_func("testing.test_raise_error") - cxx_test_apply = tvm_ffi.get_global_func("testing.apply") - - record_object = [] - - def raise_error(): - try: - fapply(cxx_test_raise_error, "ValueError", "error XYZ") - except ValueError as e: - assert e.__tvm_ffi_error__.kind == "ValueError" - assert e.__tvm_ffi_error__.message == "error XYZ" - assert e.__tvm_ffi_error__.traceback.find("TestRaiseError") != -1 - record_object.append(e.__tvm_ffi_error__) - raise e - - try: - cxx_test_apply(raise_error) - except ValueError as e: - traceback = e.__tvm_ffi_error__.traceback - assert e.__tvm_ffi_error__.same_as(record_object[0]) - assert traceback.count("TestRaiseError") == 1 - # The following lines may fail if debug symbols are missing - try: - assert traceback.count("TestApply") == 1 - assert traceback.count("") == 1 - pos_cxx_raise = traceback.find("TestRaiseError") - pos_cxx_apply = traceback.find("TestApply") - pos_lambda = traceback.find("") - assert pos_cxx_raise > pos_lambda - assert pos_lambda > pos_cxx_apply - except Exception as e: - pytest.xfail("May fail if debug symbols are missing") - - -def test_error_traceback_update(): - fecho = tvm_ffi.get_global_func("testing.echo") - - def raise_error(): - raise ValueError("error XYZ") - - try: - raise_error() - except ValueError as e: - ffi_error = tvm_ffi.convert(e) - assert ffi_error.traceback.find("raise_error") != -1 - - def raise_cxx_error(): - cxx_test_raise_error = tvm_ffi.get_global_func("testing.test_raise_error") - cxx_test_raise_error("ValueError", "error XYZ") - - try: - raise_cxx_error() - except ValueError as e: - assert e.__tvm_ffi_error__.traceback.find("raise_cxx_error") == -1 - ffi_error1 = tvm_ffi.convert(e) - ffi_error2 = fecho(e) - assert ffi_error1.traceback.find("raise_cxx_error") != -1 - assert ffi_error2.traceback.find("raise_cxx_error") != -1 diff --git a/ffi/tests/python/test_examples.py b/ffi/tests/python/test_examples.py deleted file mode 100644 index f8a94636a284..000000000000 --- a/ffi/tests/python/test_examples.py +++ /dev/null @@ -1,47 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# testcases appearing in example docstrings -import tvm_ffi - - -def test_register_global_func(): - # we can use decorator to register a function - @tvm_ffi.register_global_func("example.echo") - def echo(x): - return x - - # After registering, we can get the function by its name - f = tvm_ffi.get_global_func("example.echo") - assert f(1) == 1 - # we can also directly register a function - tvm_ffi.register_global_func("example.add_one", lambda x: x + 1) - f = tvm_ffi.get_global_func("example.add_one") - assert f(1) == 2 - - -def test_array(): - a = tvm_ffi.convert([1, 2, 3]) - assert isinstance(a, tvm_ffi.Array) - assert len(a) == 3 - - -def test_map(): - amap = tvm_ffi.convert({"a": 1, "b": 2}) - assert isinstance(amap, tvm_ffi.Map) - assert len(amap) == 2 - assert amap["a"] == 1 - assert amap["b"] == 2 diff --git a/ffi/tests/python/test_function.py b/ffi/tests/python/test_function.py deleted file mode 100644 index b5a1da4f7d1d..000000000000 --- a/ffi/tests/python/test_function.py +++ /dev/null @@ -1,221 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -import gc -import ctypes -import sys -import numpy as np -import tvm_ffi - - -def test_echo(): - fecho = tvm_ffi.get_global_func("testing.echo") - assert isinstance(fecho, tvm_ffi.Function) - # test each type - assert fecho(None) is None - - # test bool - bool_result = fecho(True) - assert isinstance(bool_result, bool) - assert bool_result is True - bool_result = fecho(False) - assert isinstance(bool_result, bool) - assert bool_result is False - - # test int/float - assert fecho(1) == 1 - assert fecho(1.2) == 1.2 - - # test str - str_result = fecho("hello") - assert isinstance(str_result, str) - assert str_result == "hello" - - # test bytes - bytes_result = fecho(b"abc") - assert isinstance(bytes_result, bytes) - assert bytes_result == b"abc" - - # test dtype - dtype_result = fecho(tvm_ffi.dtype("float32")) - assert isinstance(dtype_result, tvm_ffi.dtype) - assert dtype_result == tvm_ffi.dtype("float32") - - # test device - device_result = fecho(tvm_ffi.device("cuda:1")) - assert isinstance(device_result, tvm_ffi.Device) - assert device_result.dlpack_device_type() == tvm_ffi.DLDeviceType.kDLCUDA - assert device_result.index == 1 - assert str(device_result) == "cuda:1" - assert device_result.__repr__() == "device(type='cuda', index=1)" - - # test c_void_p - c_void_p_result = fecho(ctypes.c_void_p(0x12345678)) - assert isinstance(c_void_p_result, ctypes.c_void_p) - assert c_void_p_result.value == 0x12345678 - - # test function: aka object - fadd = tvm_ffi.convert(lambda a, b: a + b) - fadd1 = fecho(fadd) - assert fadd1(1, 2) == 3 - assert fadd1.same_as(fadd) - - def check_tensor(): - np_data = np.arange(10, dtype="int32") - if not hasattr(np_data, "__dlpack__"): - return - # test Tensor - x = tvm_ffi.from_dlpack(np_data) - assert isinstance(x, tvm_ffi.Tensor) - tensor_result = fecho(x) - assert isinstance(tensor_result, tvm_ffi.Tensor) - assert tensor_result.shape == (10,) - assert tensor_result.dtype == tvm_ffi.dtype("int32") - assert tensor_result.device.dlpack_device_type() == tvm_ffi.DLDeviceType.kDLCPU - assert tensor_result.device.index == 0 - - check_tensor() - - -def test_return_raw_str_bytes(): - assert tvm_ffi.convert(lambda: "hello")() == "hello" - assert tvm_ffi.convert(lambda: b"hello")() == b"hello" - assert tvm_ffi.convert(lambda: bytearray(b"hello"))() == b"hello" - - -def test_string_bytes_passing(): - fecho = tvm_ffi.get_global_func("testing.echo") - use_count = tvm_ffi.get_global_func("testing.object_use_count") - # small string - assert fecho("hello") == "hello" - # large string - x = "hello" * 100 - y = fecho(x) - assert y == x - assert y.__tvm_ffi_object__ is not None - use_count(y) == 1 - # small bytes - assert fecho(b"hello") == b"hello" - # large bytes - x = b"hello" * 100 - y = fecho(x) - assert y == x - assert y.__tvm_ffi_object__ is not None - fecho(y) == 1 - - -def test_nested_container_passing(): - # test and make sure our ref counting is correct - fecho = tvm_ffi.get_global_func("testing.echo") - use_count = tvm_ffi.get_global_func("testing.object_use_count") - obj = tvm_ffi.convert((1, 2, 3)) - assert use_count(obj) == 1 - y = fecho([obj, {"a": 1, "b": obj}]) - assert use_count(y) == 1 - assert use_count(obj) == 3 - assert use_count(y[1]) == 2 - - -def test_pyfunc_convert(): - def add(a, b): - return a + b - - fadd = tvm_ffi.convert(add) - assert isinstance(fadd, tvm_ffi.Function) - assert fadd(1, 2) == 3 - - def fapply(f, *args): - return f(*args) - - fapply = tvm_ffi.convert(fapply) - assert fapply(add, 1, 3.3) == 4.3 - - -def test_global_func(): - @tvm_ffi.register_global_func("mytest.echo") - def echo(x): - return x - - f = tvm_ffi.get_global_func("mytest.echo") - assert f.same_as(echo) - assert f(1) == 1 - - assert "mytest.echo" in tvm_ffi.registry.list_global_func_names() - - tvm_ffi.registry.remove_global_func("mytest.echo") - assert "mytest.echo" not in tvm_ffi.registry.list_global_func_names() - assert tvm_ffi.get_global_func("mytest.echo", allow_missing=True) is None - - -def test_rvalue_ref(): - use_count = tvm_ffi.get_global_func("testing.object_use_count") - - def callback(x, expected_count): - # The use count of TVM FFI objects is decremented as part of - # `ObjectRef.__del__`, which runs when the Python object is - # destructed. However, Python object destruction is not - # deterministic, and even CPython's reference-counting is - # considered an implementation detail. Therefore, to ensure - # correct results from this test, `gc.collect()` must be - # explicitly called. - gc.collect() - assert expected_count == use_count(x) - return x._move() - - f = tvm_ffi.convert(callback) - - def check0(): - x = tvm_ffi.convert([1, 2]) - assert use_count(x) == 1 - f(x, 2) - y = f(x._move(), 1) - assert x.__ctypes_handle__().value == None - - def check1(): - x = tvm_ffi.convert([1, 2]) - assert use_count(x) == 1 - y = f(x, 2) - z = f(x._move(), 2) - assert x.__ctypes_handle__().value == None - assert y.__ctypes_handle__().value is not None - - check0() - check1() - - -def test_echo_with_opaque_object(): - class MyObject: - def __init__(self, value): - self.value = value - - fecho = tvm_ffi.get_global_func("testing.echo") - x = MyObject("hello") - assert sys.getrefcount(x) == 2 - y = fecho(x) - assert isinstance(y, MyObject) - assert y is x - assert sys.getrefcount(x) == 3 - - def py_callback(z): - """python callback with opaque object""" - assert z is x - return z - - fcallback = tvm_ffi.convert(py_callback) - z = fcallback(x) - assert z is x - assert sys.getrefcount(x) == 4 diff --git a/ffi/tests/python/test_load_inline.py b/ffi/tests/python/test_load_inline.py deleted file mode 100644 index 0277803730dc..000000000000 --- a/ffi/tests/python/test_load_inline.py +++ /dev/null @@ -1,324 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -import pytest -import numpy -import sys - -try: - import torch -except ImportError: - torch = None - -import tvm_ffi.cpp -from tvm_ffi.module import Module - - -@pytest.mark.xfail(sys.platform.startswith("win"), reason="needs to robustify windows support") -def test_load_inline_cpp(): - mod: Module = tvm_ffi.cpp.load_inline( - name="hello", - cpp_sources=r""" - void add_one_cpu(DLTensor* x, DLTensor* y) { - // implementation of a library function - TVM_FFI_ICHECK(x->ndim == 1) << "x must be a 1D tensor"; - DLDataType f32_dtype{kDLFloat, 32, 1}; - TVM_FFI_ICHECK(x->dtype == f32_dtype) << "x must be a float tensor"; - TVM_FFI_ICHECK(y->ndim == 1) << "y must be a 1D tensor"; - TVM_FFI_ICHECK(y->dtype == f32_dtype) << "y must be a float tensor"; - TVM_FFI_ICHECK(x->shape[0] == y->shape[0]) << "x and y must have the same shape"; - for (int i = 0; i < x->shape[0]; ++i) { - static_cast(y->data)[i] = static_cast(x->data)[i] + 1; - } - } - """, - functions=["add_one_cpu"], - ) - - x = numpy.array([1, 2, 3, 4, 5], dtype=numpy.float32) - y = numpy.empty_like(x) - mod.add_one_cpu(x, y) - numpy.testing.assert_equal(x + 1, y) - - -@pytest.mark.xfail(sys.platform.startswith("win"), reason="needs to robustify windows support") -def test_load_inline_cpp_with_docstrings(): - mod: Module = tvm_ffi.cpp.load_inline( - name="hello", - cpp_sources=r""" - void add_one_cpu(DLTensor* x, DLTensor* y) { - // implementation of a library function - TVM_FFI_ICHECK(x->ndim == 1) << "x must be a 1D tensor"; - DLDataType f32_dtype{kDLFloat, 32, 1}; - TVM_FFI_ICHECK(x->dtype == f32_dtype) << "x must be a float tensor"; - TVM_FFI_ICHECK(y->ndim == 1) << "y must be a 1D tensor"; - TVM_FFI_ICHECK(y->dtype == f32_dtype) << "y must be a float tensor"; - TVM_FFI_ICHECK(x->shape[0] == y->shape[0]) << "x and y must have the same shape"; - for (int i = 0; i < x->shape[0]; ++i) { - static_cast(y->data)[i] = static_cast(x->data)[i] + 1; - } - } - """, - functions={"add_one_cpu": "add two float32 1D tensors element-wise"}, - ) - - x = numpy.array([1, 2, 3, 4, 5], dtype=numpy.float32) - y = numpy.empty_like(x) - mod.add_one_cpu(x, y) - numpy.testing.assert_equal(x + 1, y) - - -@pytest.mark.xfail(sys.platform.startswith("win"), reason="needs to robustify windows support") -def test_load_inline_cpp_multiple_sources(): - mod: Module = tvm_ffi.cpp.load_inline( - name="hello", - cpp_sources=[ - r""" - void add_one_cpu(DLTensor* x, DLTensor* y) { - // implementation of a library function - TVM_FFI_ICHECK(x->ndim == 1) << "x must be a 1D tensor"; - DLDataType f32_dtype{kDLFloat, 32, 1}; - TVM_FFI_ICHECK(x->dtype == f32_dtype) << "x must be a float tensor"; - TVM_FFI_ICHECK(y->ndim == 1) << "y must be a 1D tensor"; - TVM_FFI_ICHECK(y->dtype == f32_dtype) << "y must be a float tensor"; - TVM_FFI_ICHECK(x->shape[0] == y->shape[0]) << "x and y must have the same shape"; - for (int i = 0; i < x->shape[0]; ++i) { - static_cast(y->data)[i] = static_cast(x->data)[i] + 1; - } - } - """, - r""" - void add_two_cpu(DLTensor* x, DLTensor* y) { - // implementation of a library function - TVM_FFI_ICHECK(x->ndim == 1) << "x must be a 1D tensor"; - DLDataType f32_dtype{kDLFloat, 32, 1}; - TVM_FFI_ICHECK(x->dtype == f32_dtype) << "x must be a float tensor"; - TVM_FFI_ICHECK(y->ndim == 1) << "y must be a 1D tensor"; - TVM_FFI_ICHECK(y->dtype == f32_dtype) << "y must be a float tensor"; - TVM_FFI_ICHECK(x->shape[0] == y->shape[0]) << "x and y must have the same shape"; - for (int i = 0; i < x->shape[0]; ++i) { - static_cast(y->data)[i] = static_cast(x->data)[i] + 2; - } - } - """, - ], - functions=["add_one_cpu", "add_two_cpu"], - ) - - x = numpy.array([1, 2, 3, 4, 5], dtype=numpy.float32) - y = numpy.empty_like(x) - mod.add_one_cpu(x, y) - numpy.testing.assert_equal(x + 1, y) - - -@pytest.mark.xfail(sys.platform.startswith("win"), reason="needs to robustify windows support") -def test_load_inline_cpp_build_dir(): - mod: Module = tvm_ffi.cpp.load_inline( - name="hello", - cpp_sources=r""" - void add_one_cpu(DLTensor* x, DLTensor* y) { - // implementation of a library function - TVM_FFI_ICHECK(x->ndim == 1) << "x must be a 1D tensor"; - DLDataType f32_dtype{kDLFloat, 32, 1}; - TVM_FFI_ICHECK(x->dtype == f32_dtype) << "x must be a float tensor"; - TVM_FFI_ICHECK(y->ndim == 1) << "y must be a 1D tensor"; - TVM_FFI_ICHECK(y->dtype == f32_dtype) << "y must be a float tensor"; - TVM_FFI_ICHECK(x->shape[0] == y->shape[0]) << "x and y must have the same shape"; - for (int i = 0; i < x->shape[0]; ++i) { - static_cast(y->data)[i] = static_cast(x->data)[i] + 1; - } - } - """, - functions=["add_one_cpu"], - build_directory="./build_add_one", - ) - - x = numpy.array([1, 2, 3, 4, 5], dtype=numpy.float32) - y = numpy.empty_like(x) - mod.add_one_cpu(x, y) - numpy.testing.assert_equal(x + 1, y) - - -@pytest.mark.skipif( - torch is None or not torch.cuda.is_available(), reason="Requires torch and CUDA" -) -def test_load_inline_cuda(): - mod: Module = tvm_ffi.cpp.load_inline( - name="hello", - cuda_sources=r""" - __global__ void AddOneKernel(float* x, float* y, int n) { - int idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx < n) { - y[idx] = x[idx] + 1; - } - } - - void add_one_cuda(DLTensor* x, DLTensor* y) { - // implementation of a library function - TVM_FFI_ICHECK(x->ndim == 1) << "x must be a 1D tensor"; - DLDataType f32_dtype{kDLFloat, 32, 1}; - TVM_FFI_ICHECK(x->dtype == f32_dtype) << "x must be a float tensor"; - TVM_FFI_ICHECK(y->ndim == 1) << "y must be a 1D tensor"; - TVM_FFI_ICHECK(y->dtype == f32_dtype) << "y must be a float tensor"; - TVM_FFI_ICHECK(x->shape[0] == y->shape[0]) << "x and y must have the same shape"; - - int64_t n = x->shape[0]; - int64_t nthread_per_block = 256; - int64_t nblock = (n + nthread_per_block - 1) / nthread_per_block; - // Obtain the current stream from the environment - // it will be set to torch.cuda.current_stream() when calling the function - // with torch.Tensors - cudaStream_t stream = static_cast( - TVMFFIEnvGetStream(x->device.device_type, x->device.device_id)); - // launch the kernel - AddOneKernel<<>>(static_cast(x->data), - static_cast(y->data), n); - } - """, - functions=["add_one_cuda"], - ) - - if torch is not None: - x_cuda = torch.asarray([1, 2, 3, 4, 5], dtype=torch.float32, device="cuda") - y_cuda = torch.empty_like(x_cuda) - mod.add_one_cuda(x_cuda, y_cuda) - torch.testing.assert_close(x_cuda + 1, y_cuda) - - -@pytest.mark.skipif( - torch is None or not torch.cuda.is_available(), reason="Requires torch and CUDA" -) -def test_load_inline_cuda_with_env_tensor_allocator(): - if not hasattr(torch.Tensor, "__c_dlpack_tensor_allocator__"): - pytest.skip("Torch does not support __c_dlpack_tensor_allocator__") - mod: Module = tvm_ffi.cpp.load_inline( - name="hello", - cuda_sources=r""" - #include - #include - #include - - __global__ void AddOneKernel(float* x, float* y, int n) { - int idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx < n) { - y[idx] = x[idx] + 1; - } - } - namespace ffi = tvm::ffi; - - ffi::Tensor return_add_one(ffi::Map> kwargs) { - ffi::Tensor x = kwargs["x"].get<0>(); - // implementation of a library function - TVM_FFI_ICHECK(x->ndim == 1) << "x must be a 1D tensor"; - DLDataType f32_dtype{kDLFloat, 32, 1}; - TVM_FFI_ICHECK(x->dtype == f32_dtype) << "x must be a float tensor"; - // allocate a new tensor with the env tensor allocator - // it will be redirected to torch.empty when calling the function - ffi::Tensor y = ffi::Tensor::FromDLPackAlloc( - TVMFFIEnvGetTensorAllocator(), ffi::Shape({x->shape[0]}), f32_dtype, x->device); - int64_t n = x->shape[0]; - int64_t nthread_per_block = 256; - int64_t nblock = (n + nthread_per_block - 1) / nthread_per_block; - // Obtain the current stream from the environment - // it will be set to torch.cuda.current_stream() when calling the function - // with torch.Tensors - cudaStream_t stream = static_cast( - TVMFFIEnvGetStream(x->device.device_type, x->device.device_id)); - // launch the kernel - AddOneKernel<<>>(static_cast(x->data), - static_cast(y->data), n); - return y; - } - """, - functions=["return_add_one"], - ) - - if torch is not None: - x_cuda = torch.asarray([1, 2, 3, 4, 5], dtype=torch.float32, device="cuda") - # test support for nested container passing - y_cuda = mod.return_add_one({"x": [x_cuda]}) - assert isinstance(y_cuda, torch.Tensor) - assert y_cuda.shape == (5,) - assert y_cuda.dtype == torch.float32 - torch.testing.assert_close(x_cuda + 1, y_cuda) - assert y_cuda.is_cuda - - -@pytest.mark.skipif( - torch is None or not torch.cuda.is_available(), reason="Requires torch and CUDA" -) -def test_load_inline_both(): - mod: Module = tvm_ffi.cpp.load_inline( - name="hello", - cpp_sources=r""" - void add_one_cpu(DLTensor* x, DLTensor* y) { - // implementation of a library function - TVM_FFI_ICHECK(x->ndim == 1) << "x must be a 1D tensor"; - DLDataType f32_dtype{kDLFloat, 32, 1}; - TVM_FFI_ICHECK(x->dtype == f32_dtype) << "x must be a float tensor"; - TVM_FFI_ICHECK(y->ndim == 1) << "y must be a 1D tensor"; - TVM_FFI_ICHECK(y->dtype == f32_dtype) << "y must be a float tensor"; - TVM_FFI_ICHECK(x->shape[0] == y->shape[0]) << "x and y must have the same shape"; - for (int i = 0; i < x->shape[0]; ++i) { - static_cast(y->data)[i] = static_cast(x->data)[i] + 1; - } - } - - void add_one_cuda(DLTensor* x, DLTensor* y); - """, - cuda_sources=r""" - __global__ void AddOneKernel(float* x, float* y, int n) { - int idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx < n) { - y[idx] = x[idx] + 1; - } - } - - void add_one_cuda(DLTensor* x, DLTensor* y) { - // implementation of a library function - TVM_FFI_ICHECK(x->ndim == 1) << "x must be a 1D tensor"; - DLDataType f32_dtype{kDLFloat, 32, 1}; - TVM_FFI_ICHECK(x->dtype == f32_dtype) << "x must be a float tensor"; - TVM_FFI_ICHECK(y->ndim == 1) << "y must be a 1D tensor"; - TVM_FFI_ICHECK(y->dtype == f32_dtype) << "y must be a float tensor"; - TVM_FFI_ICHECK(x->shape[0] == y->shape[0]) << "x and y must have the same shape"; - - int64_t n = x->shape[0]; - int64_t nthread_per_block = 256; - int64_t nblock = (n + nthread_per_block - 1) / nthread_per_block; - // Obtain the current stream from the environment - // it will be set to torch.cuda.current_stream() when calling the function - // with torch.Tensors - cudaStream_t stream = static_cast( - TVMFFIEnvGetStream(x->device.device_type, x->device.device_id)); - // launch the kernel - AddOneKernel<<>>(static_cast(x->data), - static_cast(y->data), n); - } - """, - functions=["add_one_cpu", "add_one_cuda"], - ) - - x = numpy.array([1, 2, 3, 4, 5], dtype=numpy.float32) - y = numpy.empty_like(x) - mod.add_one_cpu(x, y) - numpy.testing.assert_equal(x + 1, y) - - x_cuda = torch.asarray([1, 2, 3, 4, 5], dtype=torch.float32, device="cuda") - y_cuda = torch.empty_like(x_cuda) - mod.add_one_cuda(x_cuda, y_cuda) - torch.testing.assert_close(x_cuda + 1, y_cuda) diff --git a/ffi/tests/python/test_object.py b/ffi/tests/python/test_object.py deleted file mode 100644 index 1b07de8e9d69..000000000000 --- a/ffi/tests/python/test_object.py +++ /dev/null @@ -1,91 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -import pytest -import sys - -import tvm_ffi - - -def test_make_object(): - # with default values - obj0 = tvm_ffi.testing.create_object("testing.TestObjectBase") - assert obj0.v_i64 == 10 - assert obj0.v_f64 == 10.0 - assert obj0.v_str == "hello" - - -def test_method(): - obj0 = tvm_ffi.testing.create_object("testing.TestObjectBase", v_i64=12) - assert obj0.add_i64(1) == 13 - assert type(obj0).add_i64.__doc__ == "add_i64 method" - assert type(obj0).v_i64.__doc__ == "i64 field" - - -def test_setter(): - # test setter - obj0 = tvm_ffi.testing.create_object("testing.TestObjectBase", v_i64=10, v_str="hello") - assert obj0.v_i64 == 10 - obj0.v_i64 = 11 - assert obj0.v_i64 == 11 - obj0.v_str = "world" - assert obj0.v_str == "world" - - with pytest.raises(TypeError): - obj0.v_str = 1 - - with pytest.raises(TypeError): - obj0.v_i64 = "hello" - - -def test_derived_object(): - with pytest.raises(TypeError): - obj0 = tvm_ffi.testing.create_object("testing.TestObjectDerived") - - v_map = tvm_ffi.convert({"a": 1}) - v_array = tvm_ffi.convert([1, 2, 3]) - - obj0 = tvm_ffi.testing.create_object( - "testing.TestObjectDerived", v_i64=20, v_map=v_map, v_array=v_array - ) - assert obj0.v_map.same_as(v_map) - assert obj0.v_array.same_as(v_array) - assert obj0.v_i64 == 20 - assert obj0.v_f64 == 10.0 - assert obj0.v_str == "hello" - - obj0.v_i64 = 21 - assert obj0.v_i64 == 21 - - -class MyObject: - def __init__(self, value): - self.value = value - - -def test_opaque_object(): - obj0 = MyObject("hello") - assert sys.getrefcount(obj0) == 2 - obj0_converted = tvm_ffi.convert(obj0) - assert sys.getrefcount(obj0) == 3 - assert isinstance(obj0_converted, tvm_ffi.core.OpaquePyObject) - obj0_cpy = obj0_converted.pyobject() - assert obj0_cpy is obj0 - assert sys.getrefcount(obj0) == 4 - obj0_converted = None - assert sys.getrefcount(obj0) == 3 - obj0_cpy = None - assert sys.getrefcount(obj0) == 2 diff --git a/ffi/tests/python/test_string.py b/ffi/tests/python/test_string.py deleted file mode 100644 index feaa9584d2fc..000000000000 --- a/ffi/tests/python/test_string.py +++ /dev/null @@ -1,54 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -import pickle -import tvm_ffi - - -def test_string(): - fecho = tvm_ffi.get_global_func("testing.echo") - s = tvm_ffi.core.String("hello") - s2 = fecho(s) - assert s2 == "hello" - s3 = tvm_ffi.convert("hello") - assert isinstance(s3, str) - - x = "hello long string" - assert fecho(x) == x - - s4 = pickle.loads(pickle.dumps(s)) - assert s4 == "hello" - - -def test_bytes(): - fecho = tvm_ffi.get_global_func("testing.echo") - b = tvm_ffi.core.Bytes(b"hello") - assert isinstance(b, tvm_ffi.core.Bytes) - b2 = fecho(b) - assert b2 == b"hello" - - b3 = tvm_ffi.convert(b"hello") - assert isinstance(b3, tvm_ffi.core.Bytes) - assert isinstance(b3, bytes) - - b4 = tvm_ffi.convert(bytearray(b"hello")) - assert isinstance(b4, tvm_ffi.core.Bytes) - assert isinstance(b4, bytes) - - b5 = pickle.loads(pickle.dumps(b)) - assert b5 == b"hello" - assert isinstance(b5, tvm_ffi.core.Bytes) diff --git a/ffi/tests/python/test_tensor.py b/ffi/tests/python/test_tensor.py deleted file mode 100644 index 5c7051279815..000000000000 --- a/ffi/tests/python/test_tensor.py +++ /dev/null @@ -1,68 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -import pytest - -try: - import torch -except ImportError: - torch = None - -import tvm_ffi -import numpy as np - - -def test_tensor_attributes(): - data = np.zeros((10, 8, 4, 2), dtype="int16") - if not hasattr(data, "__dlpack__"): - return - x = tvm_ffi.from_dlpack(data) - assert isinstance(x, tvm_ffi.Tensor) - assert x.shape == (10, 8, 4, 2) - assert x.dtype == tvm_ffi.dtype("int16") - assert x.device.dlpack_device_type() == tvm_ffi.DLDeviceType.kDLCPU - assert x.device.index == 0 - x2 = np.from_dlpack(x) - np.testing.assert_equal(x2, data) - - -def test_shape_object(): - shape = tvm_ffi.Shape((10, 8, 4, 2)) - assert isinstance(shape, tvm_ffi.Shape) - assert shape == (10, 8, 4, 2) - - fecho = tvm_ffi.convert(lambda x: x) - shape2 = fecho(shape) - assert shape2.__tvm_ffi_object__.same_as(shape.__tvm_ffi_object__) - assert isinstance(shape2, tvm_ffi.Shape) - assert isinstance(shape2, tuple) - - shape3 = tvm_ffi.convert(shape) - assert shape3.__tvm_ffi_object__.same_as(shape.__tvm_ffi_object__) - assert isinstance(shape3, tvm_ffi.Shape) - - -@pytest.mark.skipif(torch is None, reason="Fast torch dlpack importer is not enabled") -def test_tensor_auto_dlpack(): - x = torch.arange(128) - fecho = tvm_ffi.get_global_func("testing.echo") - y = fecho(x) - assert isinstance(y, torch.Tensor) - assert y.data_ptr() == x.data_ptr() - assert y.dtype == x.dtype - assert y.shape == x.shape - assert y.device == x.device - np.testing.assert_equal(y.numpy(), x.numpy()) diff --git a/jvm/native/linux-x86_64/pom.xml b/jvm/native/linux-x86_64/pom.xml index c21a3d2ae5af..0bf5d88b76fe 100644 --- a/jvm/native/linux-x86_64/pom.xml +++ b/jvm/native/linux-x86_64/pom.xml @@ -118,7 +118,7 @@ under the License. -I../../../include - -I../../../ffi/include + -I../../../3rdparty/tvm-ffi/include -I${JAVA_HOME}/include -I${JAVA_HOME}/include/linux ${cflags} diff --git a/jvm/native/osx-x86_64/pom.xml b/jvm/native/osx-x86_64/pom.xml index e2bd0fd7ae9d..de468519b828 100644 --- a/jvm/native/osx-x86_64/pom.xml +++ b/jvm/native/osx-x86_64/pom.xml @@ -119,7 +119,7 @@ under the License. -I../../../include - -I../../../ffi/include + -I../../../3rdparty/tvm-ffi/include -I${JAVA_HOME}/include -I${JAVA_HOME}/include/darwin ${cflags} diff --git a/pyproject.toml b/pyproject.toml index 43be53b8cb6e..475e183ffcba 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -142,7 +142,7 @@ sdist.include = [ "/CMakeLists.txt", "/pyproject.toml", "/cmake/**/*", - "/3rdparty/**/*", + "/ */*", # Source code "/src/**/*.cc", diff --git a/python/tvm/libinfo.py b/python/tvm/libinfo.py index b8fa9ec91aff..4dbae65ebbf3 100644 --- a/python/tvm/libinfo.py +++ b/python/tvm/libinfo.py @@ -231,9 +231,12 @@ def find_include_path(name=None, search_path=None, optional=False): dmlc_include_path = [] else: tvm_include_path = [os.path.join(p, "include") for p in header_path] - tvm_ffi_include_path = [os.path.join(p, "ffi", "include") for p in header_path] + tvm_ffi_include_path = [ + os.path.join(p, "3rdparty", "tvm-ffi", "include") for p in header_path + ] dlpack_include_path = [ - os.path.join(p, "ffi", "3rdparty", "dlpack", "include") for p in header_path + os.path.join(p, "3rdparty", "tvm-ffi", "3rdparty", "dlpack", "include") + for p in header_path ] dmlc_include_path = [ os.path.join(p, "3rdparty", "dmlc-core", "include") for p in header_path diff --git a/python/tvm/relax/frontend/nn/extern.py b/python/tvm/relax/frontend/nn/extern.py index e7248b0f4b27..b35f6e0d220c 100644 --- a/python/tvm/relax/frontend/nn/extern.py +++ b/python/tvm/relax/frontend/nn/extern.py @@ -310,8 +310,8 @@ def get_includes(tvm_pkg: Optional[List[str]] = None) -> List[Path]: results = [ tvm_home / "include", tvm_home / "3rdparty/dmlc-core/include", - tvm_home / "ffi/include", - tvm_home / "ffi/3rdparty/dlpack/include", + tvm_home / "3rdparty/tvm-ffi/include", + tvm_home / "3rdparty/tvm-ffi/3rdparty/dlpack/include", ] if tvm_pkg: for relative in tvm_pkg: @@ -387,12 +387,14 @@ def compile(self, output_path: Path) -> None: options=self.compile_options, cc=self.compiler, cwd=temp_dir, - ccache_env={ - "CCACHE_COMPILERCHECK": "content", - "CCACHE_NOHASHDIR": "1", - } - if shutil.which("ccache") - else None, + ccache_env=( + { + "CCACHE_COMPILERCHECK": "content", + "CCACHE_NOHASHDIR": "1", + } + if shutil.which("ccache") + else None + ), ) shutil.move(str(object_path), str(output_path)) diff --git a/tests/lint/cpplint.sh b/tests/lint/cpplint.sh index e49c6801ade7..84065e17b01d 100755 --- a/tests/lint/cpplint.sh +++ b/tests/lint/cpplint.sh @@ -19,7 +19,6 @@ set -e echo "Running 2 cpplints..." -python3 3rdparty/dmlc-core/scripts/lint.py --quiet tvm cpp ffi/include ffi/src python3 3rdparty/dmlc-core/scripts/lint.py --quiet tvm cpp \ include src \ examples/extension/src examples/graph_executor/src \ diff --git a/tests/scripts/task_python_adreno.sh b/tests/scripts/task_python_adreno.sh index acf585c0acba..1714a3c06358 100755 --- a/tests/scripts/task_python_adreno.sh +++ b/tests/scripts/task_python_adreno.sh @@ -58,7 +58,7 @@ trap "{ kill ${TRACKER_PID}; kill ${DEVICE_PID}; cleanup; }" 0 # cleanup pycache find . -type f -path "*.pyc" | xargs rm -f # setup tvm-ffi into python folder -python3 -m pip install --target=python -v ./ffi +python3 -m pip install --target=python -v ./3rdparty/tvm-ffi/ exit 0 diff --git a/tests/scripts/task_python_arm_compute_library.sh b/tests/scripts/task_python_arm_compute_library.sh index 7593e0134416..b67724308fce 100755 --- a/tests/scripts/task_python_arm_compute_library.sh +++ b/tests/scripts/task_python_arm_compute_library.sh @@ -24,4 +24,4 @@ source tests/scripts/setup-pytest-env.sh find . -type f -path "*.pyc" | xargs rm -f # setup tvm-ffi into python folder -python3 -m pip install -v --target=python ./ffi +python3 -m pip install -v --target=python ./3rdparty/tvm-ffi/ diff --git a/tests/scripts/task_python_docs.sh b/tests/scripts/task_python_docs.sh index df4e12504320..bb1fd2d95b8d 100755 --- a/tests/scripts/task_python_docs.sh +++ b/tests/scripts/task_python_docs.sh @@ -48,7 +48,7 @@ sphinx_precheck() { echo "PreCheck sphinx doc generation WARNINGS.." # setup tvm-ffi into python folder - python3 -m pip install -v --target=python ./ffi + python3 -m pip install -v --target=python ./3rdparty/tvm-ffi/ pushd docs make clean @@ -127,7 +127,7 @@ find . -type f -path "*.log" | xargs rm -f find . -type f -path "*.pyc" | xargs rm -f # setup tvm-ffi into python folder -python3 -m pip install -v --target=python ./ffi +python3 -m pip install -v --target=python ./3rdparty/tvm-ffi/ cd docs diff --git a/tests/scripts/task_python_hexagon.sh b/tests/scripts/task_python_hexagon.sh index edef1016b061..6d91759805b7 100755 --- a/tests/scripts/task_python_hexagon.sh +++ b/tests/scripts/task_python_hexagon.sh @@ -28,7 +28,7 @@ fi source tests/scripts/setup-pytest-env.sh # setup tvm-ffi into python folder -python3 -m pip install -v --target=python ./ffi +python3 -m pip install -v --target=python ./3rdparty/tvm-ffi/ # disable hexagon tests for now exit 0 diff --git a/tests/scripts/task_python_integration.sh b/tests/scripts/task_python_integration.sh index b8a14d81e7f1..a1a0068ac972 100755 --- a/tests/scripts/task_python_integration.sh +++ b/tests/scripts/task_python_integration.sh @@ -34,4 +34,4 @@ fi find . -type f -path "*.pyc" | xargs rm -f # setup tvm-ffi into python folder -python3 -m pip install -v --target=python ./ffi +python3 -m pip install -v --target=python ./3rdparty/tvm-ffi/ diff --git a/tests/scripts/task_python_nightly.sh b/tests/scripts/task_python_nightly.sh index 4ad12baed77c..af1b6ec3d212 100755 --- a/tests/scripts/task_python_nightly.sh +++ b/tests/scripts/task_python_nightly.sh @@ -21,7 +21,7 @@ set -euxo pipefail source tests/scripts/setup-pytest-env.sh # setup tvm-ffi into python folder -python3 -m pip install -v --target=python ./ffi +python3 -m pip install -v --target=python ./3rdparty/tvm-ffi/ # cleanup pycache find . -type f -path "*.pyc" | xargs rm -f diff --git a/tests/scripts/task_python_unittest.sh b/tests/scripts/task_python_unittest.sh index 60cb7269f5dc..569ad9b2de4b 100755 --- a/tests/scripts/task_python_unittest.sh +++ b/tests/scripts/task_python_unittest.sh @@ -24,7 +24,7 @@ source tests/scripts/setup-pytest-env.sh find . -type f -path "*.pyc" | xargs rm -f # setup tvm-ffi into python folder -python3 -m pip install -v --target=python ./ffi +python3 -m pip install -v --target=python ./3rdparty/tvm-ffi/ # NOTE: also set by task_python_unittest_gpuonly.sh. if [ -z "${TVM_UNITTEST_TESTSUITE_NAME:-}" ]; then diff --git a/tests/scripts/task_web_wasm.sh b/tests/scripts/task_web_wasm.sh index 46c8eaa8b221..c43215549788 100755 --- a/tests/scripts/task_web_wasm.sh +++ b/tests/scripts/task_web_wasm.sh @@ -21,7 +21,7 @@ set -euxo pipefail export PYTHONPATH=`pwd`/python # setup tvm-ffi into python folder -python3 -m pip install -v --target=python ./ffi +python3 -m pip install -v --target=python ./3rdparty/tvm-ffi/ rm -rf .emscripten_cache cd web diff --git a/tests/scripts/unity/task_python_relax.sh b/tests/scripts/unity/task_python_relax.sh index 99ef50fb5ccb..c25cc6ec6597 100755 --- a/tests/scripts/unity/task_python_relax.sh +++ b/tests/scripts/unity/task_python_relax.sh @@ -26,7 +26,7 @@ export TVM_BIND_THREADS=0 export TVM_NUM_THREADS=2 # setup tvm-ffi into python folder -python3 -m pip install -v --target=python ./ffi +python3 -m pip install -v --target=python ./3rdparty/tvm-ffi/ # Run Relax tests TVM_TEST_TARGETS="${TVM_RELAY_TEST_TARGETS:-llvm}" pytest tests/python/relax diff --git a/web/Makefile b/web/Makefile index e9d1375fc76c..9f8a7e94b42f 100644 --- a/web/Makefile +++ b/web/Makefile @@ -18,8 +18,8 @@ TVM_ROOT=$(realpath $(shell dirname $(firstword $(MAKEFILE_LIST))))/../ INCLUDE_FLAGS = -I$(TVM_ROOT) -I$(TVM_ROOT)/include\ - -I$(TVM_ROOT)/ffi/include\ - -I$(TVM_ROOT)/ffi/3rdparty/dlpack/include -I$(TVM_ROOT)/3rdparty/dmlc-core/include\ + -I$(TVM_ROOT)/3rdparty/tvm-ffi/include\ + -I$(TVM_ROOT)/3rdparty/tvm-ffi/3rdparty/dlpack/include -I$(TVM_ROOT)/3rdparty/dmlc-core/include\ -I$(TVM_ROOT)/3rdparty/compiler-rt -I$(TVM_ROOT)/3rdparty/picojson .PHONY: clean all rmtypedep preparetest diff --git a/web/emcc/wasm_runtime.cc b/web/emcc/wasm_runtime.cc index b7a1bd83e9eb..35f3a4dc4d1e 100644 --- a/web/emcc/wasm_runtime.cc +++ b/web/emcc/wasm_runtime.cc @@ -48,17 +48,17 @@ #include "src/runtime/tensor.cc" #include "src/runtime/workspace_pool.cc" // relax setup -#include "ffi/src/ffi/container.cc" -#include "ffi/src/ffi/dtype.cc" -#include "ffi/src/ffi/error.cc" -#include "ffi/src/ffi/extra/library_module.cc" -#include "ffi/src/ffi/extra/library_module_system_lib.cc" -#include "ffi/src/ffi/extra/module.cc" -#include "ffi/src/ffi/extra/testing.cc" -#include "ffi/src/ffi/function.cc" -#include "ffi/src/ffi/object.cc" -#include "ffi/src/ffi/tensor.cc" -#include "ffi/src/ffi/traceback.cc" +#include "3rdparty/tvm-ffi/src/ffi/container.cc" +#include "3rdparty/tvm-ffi/src/ffi/dtype.cc" +#include "3rdparty/tvm-ffi/src/ffi/error.cc" +#include "3rdparty/tvm-ffi/src/ffi/extra/library_module.cc" +#include "3rdparty/tvm-ffi/src/ffi/extra/library_module_system_lib.cc" +#include "3rdparty/tvm-ffi/src/ffi/extra/module.cc" +#include "3rdparty/tvm-ffi/src/ffi/extra/testing.cc" +#include "3rdparty/tvm-ffi/src/ffi/function.cc" +#include "3rdparty/tvm-ffi/src/ffi/object.cc" +#include "3rdparty/tvm-ffi/src/ffi/tensor.cc" +#include "3rdparty/tvm-ffi/src/ffi/traceback.cc" #include "src/runtime/memory/memory_manager.cc" #include "src/runtime/nvtx.cc" #include "src/runtime/vm/attn_backend.cc"