From daaa153445c8341d088d7a73fb5504ef9d7181b3 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Wed, 17 Apr 2024 09:05:27 +0800 Subject: [PATCH 001/382] [FFI] Initial Code Commit --- .gitmodules | 3 + ffi/3rdparty/dlpack | 1 + ffi/CMakeLists.txt | 78 ++++ ffi/cmake/Utils/AddGoogleTest.cmake | 42 ++ ffi/cmake/Utils/CxxWarning.cmake | 13 + ffi/cmake/Utils/Sanitizer.cmake | 18 + ffi/include | 1 + ffi/scripts/run_tests.sh | 10 + ffi/src | 1 + ffi/tests/cpp/CMakeLists.txt | 25 ++ ffi/tests/cpp/test_any.cc | 635 ++++++++++++++++++++++++++++ ffi/tests/cpp/test_any_view.cc | 630 +++++++++++++++++++++++++++ ffi/tests/cpp/test_dict.cc | 183 ++++++++ ffi/tests/cpp/test_func.cc | 215 ++++++++++ ffi/tests/cpp/test_list.cc | 461 ++++++++++++++++++++ ffi/tests/cpp/test_ref.cc | 293 +++++++++++++ ffi/tests/cpp/test_str.cc | 47 ++ ffi/tests/cpp/test_type_dyn.cc | 71 ++++ ffi/tests/cpp/test_type_static.cc | 119 ++++++ include/tvm/ffi/core/c_ffi_abi.h | 166 ++++++++ include/tvm/ffi/core/core.h | 421 ++++++++++++++++++ include/tvm/ffi/core/traits.h | 613 +++++++++++++++++++++++++++ include/tvm/ffi/core/utils.h | 281 ++++++++++++ include/tvm/ffi/ext/dict.h | 515 ++++++++++++++++++++++ include/tvm/ffi/ext/error.h | 89 ++++ include/tvm/ffi/ext/func.h | 390 +++++++++++++++++ include/tvm/ffi/ext/list.h | 457 ++++++++++++++++++++ include/tvm/ffi/ext/str.h | 245 +++++++++++ include/tvm/ffi/ffi.hpp | 9 + src/ffi/registry.cc | 144 +++++++ 30 files changed, 6176 insertions(+) create mode 160000 ffi/3rdparty/dlpack create mode 100644 ffi/CMakeLists.txt create mode 100644 ffi/cmake/Utils/AddGoogleTest.cmake create mode 100644 ffi/cmake/Utils/CxxWarning.cmake create mode 100644 ffi/cmake/Utils/Sanitizer.cmake create mode 120000 ffi/include create mode 100755 ffi/scripts/run_tests.sh create mode 120000 ffi/src create mode 100644 ffi/tests/cpp/CMakeLists.txt create mode 100644 ffi/tests/cpp/test_any.cc create mode 100644 ffi/tests/cpp/test_any_view.cc create mode 100644 ffi/tests/cpp/test_dict.cc create mode 100644 ffi/tests/cpp/test_func.cc create mode 100644 ffi/tests/cpp/test_list.cc create mode 100644 ffi/tests/cpp/test_ref.cc create mode 100644 ffi/tests/cpp/test_str.cc create mode 100644 ffi/tests/cpp/test_type_dyn.cc create mode 100644 ffi/tests/cpp/test_type_static.cc create mode 100644 include/tvm/ffi/core/c_ffi_abi.h create mode 100644 include/tvm/ffi/core/core.h create mode 100644 include/tvm/ffi/core/traits.h create mode 100644 include/tvm/ffi/core/utils.h create mode 100644 include/tvm/ffi/ext/dict.h create mode 100644 include/tvm/ffi/ext/error.h create mode 100644 include/tvm/ffi/ext/func.h create mode 100644 include/tvm/ffi/ext/list.h create mode 100644 include/tvm/ffi/ext/str.h create mode 100644 include/tvm/ffi/ffi.hpp create mode 100644 src/ffi/registry.cc diff --git a/.gitmodules b/.gitmodules index a1187967f77f..e8a48d99c2a2 100644 --- a/.gitmodules +++ b/.gitmodules @@ -31,3 +31,6 @@ [submodule "3rdparty/zlib"] path = 3rdparty/zlib url = https://github.com/madler/zlib.git +[submodule "ffi/3rdparty/dlpack"] + path = ffi/3rdparty/dlpack + url = https://github.com/dmlc/dlpack.git diff --git a/ffi/3rdparty/dlpack b/ffi/3rdparty/dlpack new file mode 160000 index 000000000000..bbd2f4d32427 --- /dev/null +++ b/ffi/3rdparty/dlpack @@ -0,0 +1 @@ +Subproject commit bbd2f4d32427e548797929af08cfe2a9cbb3cf12 diff --git a/ffi/CMakeLists.txt b/ffi/CMakeLists.txt new file mode 100644 index 000000000000..8710a0c377fd --- /dev/null +++ b/ffi/CMakeLists.txt @@ -0,0 +1,78 @@ +cmake_minimum_required(VERSION 3.14) + +project( + tvm_ffi + VERSION 1.0 + DESCRIPTION "TVM's FFI system" + LANGUAGES CXX +) + +option(TVM_FFI_ALLOW_DYN_TYPE "Support for objects with non-static type indices. When turned on, targets linked against `tvm_ffi` will allow objects that comes with non-pre-defined type indices, so that the object hierarchy could expand without limitation. This will require the downstream targets to link against target `tvm_ffi_registry` to be effective." OFF) +option(TVM_FFI_BUILD_TESTS "Adding test targets." OFF) + +include(cmake/Utils/CxxWarning.cmake) +include(cmake/Utils/Sanitizer.cmake) + +########## Target: `dlpack_header` ########## + +add_library(dlpack_header INTERFACE) +target_include_directories(dlpack_header INTERFACE "${CMAKE_CURRENT_SOURCE_DIR}/3rdparty/dlpack/include") + +########## Target: `tvm_ffi_registry_{objs|static|shared}` ########## + +add_library( + tvm_ffi_registry_objs + EXCLUDE_FROM_ALL + OBJECT "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/registry.cc" +) +set_target_properties( + tvm_ffi_registry_objs PROPERTIES + POSITION_INDEPENDENT_CODE ON + CXX_STANDARD 17 + CXX_EXTENSIONS OFF + CXX_VISIBILITY_PRESET hidden + VISIBILITY_INLINES_HIDDEN ON +) +add_cxx_warning(tvm_ffi_registry_objs) +target_link_libraries(tvm_ffi_registry_objs PRIVATE dlpack_header) +target_include_directories(tvm_ffi_registry_objs PRIVATE "${CMAKE_CURRENT_SOURCE_DIR}/include") +target_compile_definitions(tvm_ffi_registry_objs PRIVATE TVM_FFI_ALLOW_DYN_TYPE=1) +if (MSVC) + target_compile_definitions(tvm_ffi_registry_objs PRIVATE TVM_FFI_EXPORTS) +endif() + +add_library(tvm_ffi_registry_static EXCLUDE_FROM_ALL STATIC $) +add_library(tvm_ffi_registry_shared EXCLUDE_FROM_ALL SHARED $) +set_target_properties(tvm_ffi_registry_shared tvm_ffi_registry_static tvm_ffi_registry_objs + PROPERTIES + MSVC_RUNTIME_LIBRARY "MultiThreaded$<$:Debug>" + ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib" + LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib" + RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/bin" +) + +########## Target: `tvm_ffi` ########## + +add_library(tvm_ffi INTERFACE) +target_link_libraries(tvm_ffi INTERFACE dlpack_header) +target_compile_features(tvm_ffi INTERFACE cxx_std_17) +target_include_directories(tvm_ffi INTERFACE "${CMAKE_CURRENT_SOURCE_DIR}/include") + +if (TVM_FFI_ALLOW_DYN_TYPE) + message(STATUS "Setting C++ macro TVM_FFI_ALLOW_DYN_TYPE - 1") + target_compile_definitions(tvm_ffi INTERFACE TVM_FFI_ALLOW_DYN_TYPE=1) +else() + message(STATUS "Setting C++ macro TVM_FFI_ALLOW_DYN_TYPE - 0") + target_compile_definitions(tvm_ffi INTERFACE TVM_FFI_ALLOW_DYN_TYPE=0) +endif() + +########## Adding tests ########## + +if (${PROJECT_NAME} STREQUAL ${CMAKE_PROJECT_NAME}) + if (TVM_FFI_BUILD_TESTS) + enable_testing() + message(STATUS "Enable Testing") + include(cmake/Utils/AddGoogleTest.cmake) + add_subdirectory(tests/cpp/) + endif() +endif () diff --git a/ffi/cmake/Utils/AddGoogleTest.cmake b/ffi/cmake/Utils/AddGoogleTest.cmake new file mode 100644 index 000000000000..aec5c82b031e --- /dev/null +++ b/ffi/cmake/Utils/AddGoogleTest.cmake @@ -0,0 +1,42 @@ +include(FetchContent) +set(gtest_force_shared_crt ON CACHE BOOL "Always use msvcrt.dll" FORCE) +set(BUILD_GMOCK ON CACHE BOOL "" FORCE) +set(BUILD_GTEST ON CACHE BOOL "" FORCE) +FetchContent_Declare( + googletest + GIT_REPOSITORY https://github.com/google/googletest.git + GIT_TAG v1.14.0 +) +FetchContent_GetProperties(googletest) +if (NOT googletest_POPULATED) + FetchContent_Populate(googletest) + message(STATUS "Found googletest_SOURCE_DIR - ${googletest_SOURCE_DIR}") + message(STATUS "Found googletest_BINARY_DIR - ${googletest_BINARY_DIR}") + add_subdirectory(${googletest_SOURCE_DIR} ${googletest_BINARY_DIR}) + include(GoogleTest) + set_target_properties(gtest PROPERTIES EXPORT_COMPILE_COMMANDS OFF EXCLUDE_FROM_ALL ON FOLDER 3rdparty) + set_target_properties(gtest_main PROPERTIES EXPORT_COMPILE_COMMANDS OFF EXCLUDE_FROM_ALL ON FOLDER 3rdparty) + set_target_properties(gmock PROPERTIES EXPORT_COMPILE_COMMANDS OFF EXCLUDE_FROM_ALL ON FOLDER 3rdparty) + set_target_properties(gmock_main PROPERTIES EXPORT_COMPILE_COMMANDS OFF EXCLUDE_FROM_ALL ON FOLDER 3rdparty) + mark_as_advanced( + BUILD_GMOCK BUILD_GTEST BUILD_SHARED_LIBS + gmock_build_tests gtest_build_samples gtest_build_tests + gtest_disable_pthreads gtest_force_shared_crt gtest_hide_internal_symbols + ) +endif() + +macro(add_googletest target_name) + add_test( + NAME ${target_name} + COMMAND ${target_name} + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} + ) + target_link_libraries(${target_name} PRIVATE gtest_main) + gtest_discover_tests(${target_name} + WORKING_DIRECTORY ${PROJECT_SOURCE_DIR} + DISCOVERY_MODE PRE_TEST + PROPERTIES + VS_DEBUGGER_WORKING_DIRECTORY "${PROJECT_SOURCE_DIR}" + ) + set_target_properties(${target_name} PROPERTIES FOLDER tests) +endmacro() diff --git a/ffi/cmake/Utils/CxxWarning.cmake b/ffi/cmake/Utils/CxxWarning.cmake new file mode 100644 index 000000000000..50ee5b616da1 --- /dev/null +++ b/ffi/cmake/Utils/CxxWarning.cmake @@ -0,0 +1,13 @@ +function(add_cxx_warning target_name) + # GNU, Clang, or AppleClang + if(CMAKE_CXX_COMPILER_ID MATCHES "GNU|Clang|AppleClang") + target_compile_options(${target_name} PRIVATE "-Werror" "-Wall" "-Wextra" "-Wpedantic") + return() + endif() + # MSVC + if(MSVC) + target_compile_options(${target_name} PRIVATE "/W4" "/WX") + return() + endif() + message(FATAL_ERROR "Unsupported compiler: ${CMAKE_CXX_COMPILER_ID}") +endfunction() diff --git a/ffi/cmake/Utils/Sanitizer.cmake b/ffi/cmake/Utils/Sanitizer.cmake new file mode 100644 index 000000000000..c1facc1999a5 --- /dev/null +++ b/ffi/cmake/Utils/Sanitizer.cmake @@ -0,0 +1,18 @@ +function(add_sanitizer_address target_name) + if(CMAKE_CXX_COMPILER_ID MATCHES "GNU|Clang|AppleClang") + include(CheckCXXCompilerFlag) + set (_saved_CRF ${CMAKE_REQUIRED_FLAGS}) + set(CMAKE_REQUIRED_FLAGS "-fsanitize=address") + check_cxx_source_compiles("int main() { return 0; }" COMPILER_SUPPORTS_ASAN) + set (CMAKE_REQUIRED_FLAGS ${_saved_CRF}) + get_target_property(_saved_type ${target_name} TYPE) + if (${_saved_type} STREQUAL "INTERFACE_LIBRARY") + set(_saved_type INTERFACE) + else() + set(_saved_type PRIVATE) + endif() + target_link_options(${target_name} ${_saved_type} "-fsanitize=address") + target_compile_options(${target_name} ${_saved_type} "-fsanitize=address" "-fno-omit-frame-pointer" "-g") + return() + endif() +endfunction() diff --git a/ffi/include b/ffi/include new file mode 120000 index 000000000000..f5030fe88998 --- /dev/null +++ b/ffi/include @@ -0,0 +1 @@ +../include \ No newline at end of file diff --git a/ffi/scripts/run_tests.sh b/ffi/scripts/run_tests.sh new file mode 100755 index 000000000000..704abaeab9a2 --- /dev/null +++ b/ffi/scripts/run_tests.sh @@ -0,0 +1,10 @@ +#!/bin/bash +set -euxo pipefail + +HEADER_ONLY=OFF +BUILD_TYPE=RelWithDebInfo + +rm -rf build/CMakeFiles build/CMakeCache.txt +cmake -G Ninja -S . -B build -DTVM_FFI_ALLOW_DYN_TYPE=${HEADER_ONLY} -DTVM_FFI_BUILD_TESTS=ON -DCMAKE_BUILD_TYPE=${BUILD_TYPE} -DCMAKE_EXPORT_COMPILE_COMMANDS=ON -DCMAKE_CXX_COMPILER_LAUNCHER=ccache +cmake --build build --parallel 16 --clean-first --config ${BUILD_TYPE} --target tvm_ffi_tests +GTEST_COLOR=1 ctest -V -C ${BUILD_TYPE} --test-dir build --output-on-failure diff --git a/ffi/src b/ffi/src new file mode 120000 index 000000000000..5cd551cf2693 --- /dev/null +++ b/ffi/src @@ -0,0 +1 @@ +../src \ No newline at end of file diff --git a/ffi/tests/cpp/CMakeLists.txt b/ffi/tests/cpp/CMakeLists.txt new file mode 100644 index 000000000000..539115f24d0f --- /dev/null +++ b/ffi/tests/cpp/CMakeLists.txt @@ -0,0 +1,25 @@ +file(GLOB _test_sources "${CMAKE_CURRENT_SOURCE_DIR}/test*.cc") +add_executable( + tvm_ffi_tests + EXCLUDE_FROM_ALL + ${_test_sources} +) +set_target_properties( + tvm_ffi_tests PROPERTIES + CXX_STANDARD 17 + CXX_STANDARD_REQUIRED ON + CXX_EXTENSIONS OFF + MSVC_RUNTIME_LIBRARY "MultiThreaded$<$:Debug>" + ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib" + LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib" + RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/bin" +) +add_cxx_warning(tvm_ffi_tests) +add_sanitizer_address(tvm_ffi_tests) +target_link_libraries(tvm_ffi_tests PRIVATE tvm_ffi) +if (TVM_FFI_ALLOW_DYN_TYPE) + add_sanitizer_address(tvm_ffi_registry_shared) + target_link_libraries(tvm_ffi_tests PRIVATE tvm_ffi_registry_shared) +endif() + +add_googletest(tvm_ffi_tests) diff --git a/ffi/tests/cpp/test_any.cc b/ffi/tests/cpp/test_any.cc new file mode 100644 index 000000000000..a57829bd0a97 --- /dev/null +++ b/ffi/tests/cpp/test_any.cc @@ -0,0 +1,635 @@ +#include +#include + +namespace { +using namespace tvm::ffi; + +template +void TestAnyConstructor(Checker check, TVMFFITypeIndex expected_type_index, const SrcType &source) { + Any v(source); + EXPECT_EQ(v.type_index, static_cast(expected_type_index)); + EXPECT_EQ(v.ref_cnt, 0); + check(&v, source); +}; + +int64_t FuncCall(int64_t x) { return x + 1; } + +std::vector AnyArrayFactory() { + return std::vector{ + Any(nullptr), + Any(1), + Any(2.5), + Any(reinterpret_cast(FuncCall)), + Any(DLDevice{kDLCPU, 0}), + Any(DLDataType{kDLInt, 32, 1}), + Any("Hello (raw str)"), + Any(Ref::New()), + Any(Ref::New(FuncCall)), + Any(std::string("World (std::string)")), + Any(Ref::New("Hello World (Ref)")), + }; +} + +template +void TestAnyStringify(const SrcType &source, TVMFFITypeIndex expected_type_index, + const std::string &expected) { + Any v(source); + EXPECT_EQ(v.type_index, static_cast(expected_type_index)); + EXPECT_EQ(v.str()->c_str(), expected); +} + +template +void TestAnyStringifyChecker(const SrcType &source, TVMFFITypeIndex expected_type_index, + Checker check) { + Any v(source); + EXPECT_EQ(v.type_index, static_cast(expected_type_index)); + check(v); +} + +void CheckAnyRefCnt(const TVMFFIAny *v) { + if (v->type_index >= static_cast(TVMFFITypeIndex::kTVMFFIStaticObjectBegin)) { + EXPECT_EQ(v->v_obj->ref_cnt, 1); + } +} + +TEST(Any_Constructor_0_Default, Default) { + Any v; + EXPECT_EQ(v.type_index, 0); + EXPECT_EQ(v.ref_cnt, 0); + EXPECT_EQ(v.v_int64, 0); +} + +TEST(Any_Constructor_1_Any, Copy) { + Any v1(1); + Any v2(v1); + EXPECT_EQ(v1.type_index, static_cast(TVMFFITypeIndex::kTVMFFIInt)); + EXPECT_EQ(v1.v_int64, 1); + EXPECT_EQ(v2.type_index, static_cast(TVMFFITypeIndex::kTVMFFIInt)); + EXPECT_EQ(v2.v_int64, 1); +} + +TEST(Any_Constructor_1_Any, Move) { + Any v1(1); + Any v2(std::move(v1)); + EXPECT_EQ(v1.type_index, static_cast(TVMFFITypeIndex::kTVMFFINone)); + EXPECT_EQ(v1.v_int64, 0); + EXPECT_EQ(v2.type_index, static_cast(TVMFFITypeIndex::kTVMFFIInt)); + EXPECT_EQ(v2.v_int64, 1); +} + +TEST(Any_Constructor_2_AnyView, Copy) { + AnyView v1(1); + Any v2(v1); + EXPECT_EQ(v1.type_index, static_cast(TVMFFITypeIndex::kTVMFFIInt)); + EXPECT_EQ(v1.v_int64, 1); + EXPECT_EQ(v2.type_index, static_cast(TVMFFITypeIndex::kTVMFFIInt)); + EXPECT_EQ(v2.v_int64, 1); +} + +TEST(Any_Constructor_2_AnyView, Move) { + AnyView v1(1); + Any v2(std::move(v1)); + EXPECT_EQ(v1.type_index, static_cast(TVMFFITypeIndex::kTVMFFINone)); + EXPECT_EQ(v1.v_int64, 0); + EXPECT_EQ(v2.type_index, static_cast(TVMFFITypeIndex::kTVMFFIInt)); + EXPECT_EQ(v2.v_int64, 1); +} + +TEST(Any_Constructor_3_Ref, Copy) { + Ref obj = Ref::New(); + Any v(obj); + const TVMFFIAny *v_obj = v.v_obj; + EXPECT_EQ(v.type_index, static_cast(TVMFFITypeIndex::kTVMFFIObject)); + EXPECT_EQ(v.ref_cnt, 0); + EXPECT_EQ(v_obj, static_cast(obj.get())); + EXPECT_EQ(v_obj->ref_cnt, 2); +} + +TEST(Any_Constructor_3_Ref, Move) { + Ref obj = Ref::New(); + Any v(std::move(obj)); + const TVMFFIAny *v_obj = v.v_obj; + EXPECT_EQ(v.type_index, static_cast(TVMFFITypeIndex::kTVMFFIObject)); + EXPECT_EQ(v.ref_cnt, 0); + EXPECT_EQ(v_obj->ref_cnt, 1); + EXPECT_EQ(obj.get(), nullptr); +} + +TEST(Any_Constructor_4_TypeTraits, Integer) { + auto check = [](TVMFFIAny *v, int64_t source) -> void { EXPECT_EQ(v->v_int64, source); }; + TestAnyConstructor(check, TVMFFITypeIndex::kTVMFFIInt, static_cast(1)); + TestAnyConstructor(check, TVMFFITypeIndex::kTVMFFIInt, static_cast(2)); + TestAnyConstructor(check, TVMFFITypeIndex::kTVMFFIInt, static_cast(3)); + TestAnyConstructor(check, TVMFFITypeIndex::kTVMFFIInt, static_cast(4)); + TestAnyConstructor(check, TVMFFITypeIndex::kTVMFFIInt, static_cast(1)); + TestAnyConstructor(check, TVMFFITypeIndex::kTVMFFIInt, static_cast(2)); + TestAnyConstructor(check, TVMFFITypeIndex::kTVMFFIInt, static_cast(3)); + TestAnyConstructor(check, TVMFFITypeIndex::kTVMFFIInt, static_cast(4)); +} + +TEST(Any_Constructor_4_TypeTraits, Float) { + auto check = [](TVMFFIAny *v, double source) -> void { EXPECT_EQ(v->v_float64, source); }; + TestAnyConstructor(check, TVMFFITypeIndex::kTVMFFIFloat, static_cast(3)); + TestAnyConstructor(check, TVMFFITypeIndex::kTVMFFIFloat, static_cast(4)); +} + +TEST(Any_Constructor_4_TypeTraits, Ptr) { + int p = 4; + auto check = [](TVMFFIAny *v, void *source) -> void { EXPECT_EQ(v->v_ptr, source); }; + TestAnyConstructor(check, TVMFFITypeIndex::kTVMFFINone, static_cast(nullptr)); + TestAnyConstructor(check, TVMFFITypeIndex::kTVMFFIPtr, static_cast(&p)); +} + +TEST(Any_Constructor_4_TypeTraits, Device) { + auto check = [](TVMFFIAny *v, const DLDevice &source) -> void { + EXPECT_EQ(v->v_device.device_type, source.device_type); + EXPECT_EQ(v->v_device.device_id, source.device_id); + }; + TestAnyConstructor(check, TVMFFITypeIndex::kTVMFFIDevice, DLDevice{kDLCPU, 0}); + TestAnyConstructor(check, TVMFFITypeIndex::kTVMFFIDevice, DLDevice{kDLCUDA, 1}); +} + +TEST(Any_Constructor_4_TypeTraits, DataType) { + auto check = [](TVMFFIAny *v, const DLDataType &source) -> void { + EXPECT_EQ(v->v_dtype.code, source.code); + EXPECT_EQ(v->v_dtype.bits, source.bits); + EXPECT_EQ(v->v_dtype.lanes, source.lanes); + }; + TestAnyConstructor(check, TVMFFITypeIndex::kTVMFFIDataType, DLDataType{kDLInt, 32, 1}); + TestAnyConstructor(check, TVMFFITypeIndex::kTVMFFIDataType, DLDataType{kDLUInt, 0, 0}); +} + +TEST(Any_Constructor_4_TypeTraits, RawStr) { + auto check = [](TVMFFIAny *v, const char *source) -> void { + Str *str = static_cast(v->v_ptr); + EXPECT_STREQ(str->c_str(), source); + }; + const char *empty = ""; + const char *hello = "hello"; + TestAnyConstructor(check, TVMFFITypeIndex::kTVMFFIStr, empty); + TestAnyConstructor(check, TVMFFITypeIndex::kTVMFFIStr, hello); +} + +TEST(Any_Constructor_4_TypeTraits, CharArray) { + auto check = [](TVMFFIAny *v, const char *source) -> void { + Str *str = static_cast(v->v_ptr); + EXPECT_STREQ(str->c_str(), source); + }; + const char empty[] = ""; + const char hello[] = "hello"; + TestAnyConstructor(check, TVMFFITypeIndex::kTVMFFIStr, empty); + TestAnyConstructor(check, TVMFFITypeIndex::kTVMFFIStr, hello); +} + +TEST(Any_Constructor_4_TypeTraits, StdString) { + auto check = [](TVMFFIAny *v, const std::string &source) -> void { + Str *str = static_cast(v->v_ptr); + EXPECT_EQ(str->c_str(), source); + }; + std::string empty = ""; + std::string hello = "hello"; + TestAnyConstructor(check, TVMFFITypeIndex::kTVMFFIStr, hello); + TestAnyConstructor(check, TVMFFITypeIndex::kTVMFFIStr, empty); +} + +TEST(Any_Constructor_5_Object_Ptr, Object) { + Ref obj = Ref::New(); + TVMFFIAny *ptr = reinterpret_cast(obj.get()); + { + Any v(obj.get()); + EXPECT_EQ(v.type_index, static_cast(TVMFFITypeIndex::kTVMFFIObject)); + EXPECT_EQ(v.v_obj->ref_cnt, 2); + EXPECT_EQ(v.v_obj, ptr); + } + EXPECT_EQ(ptr->ref_cnt, 1); +} + +TEST(Any_Constructor_5_Object_Ptr, Func) { + Ref func = Ref::New(FuncCall); + TVMFFIAny *ptr = reinterpret_cast(func.get()); + { + Any v(func.get()); + EXPECT_EQ(v.type_index, static_cast(TVMFFITypeIndex::kTVMFFIFunc)); + EXPECT_EQ(v.v_obj->ref_cnt, 2); + EXPECT_EQ(v.v_obj, ptr); + } + EXPECT_EQ(ptr->ref_cnt, 1); +} + +TEST(Any_Constructor_5_Object_Ptr, Str) { + Ref str = Ref::New("hello"); + TVMFFIAny *ptr = reinterpret_cast(str.get()); + { + Any v(str.get()); + EXPECT_EQ(v.type_index, static_cast(TVMFFITypeIndex::kTVMFFIStr)); + EXPECT_EQ(v.v_obj->ref_cnt, 2); + EXPECT_EQ(v.v_obj, ptr); + } + EXPECT_EQ(ptr->ref_cnt, 1); +} + +TEST(Any_Converter_0_TypeTraits, Integer) { + std::vector vs = AnyArrayFactory(); + for (const Any &v : vs) { + auto convert = [&]() -> int64_t { return v; }; + if (v.type_index == static_cast(TVMFFITypeIndex::kTVMFFIInt)) { + EXPECT_EQ(convert(), 1); + } else { + try { + convert(); + FAIL() << "No exception thrown"; + } catch (TVMError &ex) { + std::ostringstream os; + os << "Cannot convert from type `" << TypeIndex2TypeKey(v.type_index) << "` to `int`"; + EXPECT_EQ(ex.what(), os.str()); + } + } + CheckAnyRefCnt(&v); + } +} + +TEST(Any_Converter_0_TypeTraits, Float) { + std::vector vs = AnyArrayFactory(); + for (const Any &v : vs) { + auto convert = [&]() -> double { return v; }; + if (v.type_index == static_cast(TVMFFITypeIndex::kTVMFFIInt)) { + EXPECT_EQ(convert(), 1.0); + } else if (v.type_index == static_cast(TVMFFITypeIndex::kTVMFFIFloat)) { + EXPECT_EQ(convert(), 2.5); + } else { + try { + convert(); + FAIL() << "No exception thrown"; + } catch (TVMError &ex) { + std::ostringstream os; + os << "Cannot convert from type `" << TypeIndex2TypeKey(v.type_index) << "` to `float`"; + EXPECT_EQ(ex.what(), os.str()); + } + } + CheckAnyRefCnt(&v); + } +} + +TEST(Any_Converter_0_TypeTraits, Ptr) { + std::vector vs = AnyArrayFactory(); + for (const Any &v : vs) { + auto convert = [&]() -> void * { return v; }; + if (v.type_index == static_cast(TVMFFITypeIndex::kTVMFFINone)) { + EXPECT_EQ(convert(), nullptr); + } else if (v.type_index == static_cast(TVMFFITypeIndex::kTVMFFIPtr)) { + EXPECT_EQ(convert(), reinterpret_cast(&FuncCall)); + } else if (v.type_index == static_cast(TVMFFITypeIndex::kTVMFFIRawStr)) { + EXPECT_EQ(convert(), v.v_str); + } else { + try { + convert(); + FAIL() << "No exception thrown"; + } catch (TVMError &ex) { + std::ostringstream os; + os << "Cannot convert from type `" << TypeIndex2TypeKey(v.type_index) << "` to `Ptr`"; + EXPECT_EQ(ex.what(), os.str()); + } + } + CheckAnyRefCnt(&v); + } +} + +TEST(Any_Converter_0_TypeTraits, Device) { + std::vector vs = AnyArrayFactory(); + for (const Any &v : vs) { + auto convert = [&]() -> DLDevice { return v; }; + if (v.type_index == static_cast(TVMFFITypeIndex::kTVMFFIDevice)) { + EXPECT_EQ(convert().device_type, kDLCPU); + EXPECT_EQ(convert().device_id, 0); + } else { + try { + convert(); + FAIL() << "No exception thrown"; + } catch (TVMError &ex) { + std::ostringstream os; + os << "Cannot convert from type `" << TypeIndex2TypeKey(v.type_index) << "` to `Device`"; + EXPECT_EQ(ex.what(), os.str()); + } + } + CheckAnyRefCnt(&v); + } +} + +TEST(Any_Converter_0_TypeTraits, DataType) { + std::vector vs = AnyArrayFactory(); + for (const Any &v : vs) { + auto convert = [&]() -> DLDataType { return v; }; + if (v.type_index == static_cast(TVMFFITypeIndex::kTVMFFIDataType)) { + EXPECT_EQ(convert().code, kDLInt); + EXPECT_EQ(convert().bits, 32); + EXPECT_EQ(convert().lanes, 1); + } else { + try { + convert(); + FAIL() << "No exception thrown"; + } catch (TVMError &ex) { + std::ostringstream os; + os << "Cannot convert from type `" << TypeIndex2TypeKey(v.type_index) << "` to `dtype`"; + EXPECT_EQ(ex.what(), os.str()); + } + } + CheckAnyRefCnt(&v); + } +} + +TEST(Any_Converter_0_TypeTraits, RawStr) { + std::vector vs = AnyArrayFactory(); + int counter = 0; + for (const Any &v : vs) { + auto convert = [&]() -> const char * { return v; }; + if (v.type_index == static_cast(TVMFFITypeIndex::kTVMFFIStr)) { + ++counter; + EXPECT_LE(counter, 3); + if (counter == 1) { + EXPECT_STREQ(convert(), "Hello (raw str)"); + } else if (counter == 2) { + EXPECT_STREQ(convert(), "World (std::string)"); + } else if (counter == 3) { + EXPECT_STREQ(convert(), "Hello World (Ref)"); + } + } else { + try { + convert(); + FAIL() << "No exception thrown"; + } catch (TVMError &ex) { + std::ostringstream os; + os << "Cannot convert from type `" << TypeIndex2TypeKey(v.type_index) + << "` to `const char *`"; + EXPECT_EQ(ex.what(), os.str()); + } + } + CheckAnyRefCnt(&v); + } +} + +TEST(Any_Converter_1_AnyView, Any) { + std::vector vs = AnyArrayFactory(); + for (const Any &v : vs) { + auto convert = [&]() -> AnyView { return v; }; + { + AnyView ret = convert(); + EXPECT_EQ(ret.type_index, v.type_index); + EXPECT_EQ(ret.ref_cnt, 0); + EXPECT_EQ(v.ref_cnt, 0); + EXPECT_EQ(ret.v_obj, v.v_obj); + } + CheckAnyRefCnt(&v); + } +} + +TEST(Any_Converter_2_Ref, Object) { + std::vector vs = AnyArrayFactory(); + for (const Any &v : vs) { + auto convert = [&]() -> Ref { return v; }; + if (v.type_index == static_cast(TVMFFITypeIndex::kTVMFFINone)) { + Ref ret = convert(); + EXPECT_EQ(ret.get(), nullptr); + } else if (v.type_index >= static_cast(TVMFFITypeIndex::kTVMFFIStaticObjectBegin)) { + Ref ret = convert(); + EXPECT_EQ(ret.get(), static_cast(v.v_obj)); + EXPECT_EQ(v.v_obj->ref_cnt, 2); + } else { + try { + convert(); + FAIL() << "No exception thrown"; + } catch (TVMError &ex) { + std::ostringstream os; + os << "Cannot convert from type `" << TypeIndex2TypeKey(v.type_index) + << "` to `object.Object`"; + EXPECT_EQ(ex.what(), os.str()); + } + } + CheckAnyRefCnt(&v); + } +} + +TEST(Any_Converter_2_Ref, Func) { + std::vector views = AnyArrayFactory(); + for (const Any &v : views) { + auto convert = [&]() -> Ref { return v; }; + if (v.type_index == static_cast(TVMFFITypeIndex::kTVMFFINone)) { + Ref ret = convert(); + EXPECT_EQ(ret.get(), nullptr); + } else if (v.type_index == static_cast(TVMFFITypeIndex::kTVMFFIFunc)) { + Ref ret = convert(); + EXPECT_EQ(ret.get(), static_cast(v.v_obj)); + } else { + try { + convert(); + FAIL() << "No exception thrown"; + } catch (TVMError &ex) { + std::ostringstream os; + os << "Cannot convert from type `" << TypeIndex2TypeKey(v.type_index) + << "` to `object.Func`"; + EXPECT_EQ(ex.what(), os.str()); + } + } + CheckAnyRefCnt(&v); + } +} + +TEST(Any_Converter_2_Ref, Str) { + int counter = 0; + std::vector vs = AnyArrayFactory(); + for (const Any &v : vs) { + auto convert = [&]() -> Ref { return v; }; + if (v.type_index == static_cast(TVMFFITypeIndex::kTVMFFINone)) { + Ref ret = convert(); + EXPECT_EQ(ret.get(), nullptr); + } else if (v.type_index == static_cast(TVMFFITypeIndex::kTVMFFIStr)) { + Ref ret = convert(); + ++counter; + EXPECT_LE(counter, 3); + if (counter == 1) { + EXPECT_STREQ(ret->c_str(), "Hello (raw str)"); + } else if (counter == 2) { + EXPECT_STREQ(ret->c_str(), "World (std::string)"); + } else { + EXPECT_STREQ(ret->c_str(), "Hello World (Ref)"); + } + } else { + try { + convert(); + FAIL() << "No exception thrown"; + } catch (TVMError &ex) { + std::ostringstream os; + os << "Cannot convert from type `" << TypeIndex2TypeKey(v.type_index) + << "` to `object.Str`"; + EXPECT_EQ(ex.what(), os.str()); + } + } + CheckAnyRefCnt(&v); + } +} + +TEST(Any_Converter_3_Object_Ptr, Object) { + std::vector vs = AnyArrayFactory(); + for (const Any &v : vs) { + auto convert = [&]() -> Object * { return v; }; + if (v.type_index == static_cast(TVMFFITypeIndex::kTVMFFINone)) { + Object *ret = convert(); + EXPECT_EQ(ret, nullptr); + } else if (v.type_index >= static_cast(TVMFFITypeIndex::kTVMFFIStaticObjectBegin)) { + Object *ret = convert(); + EXPECT_EQ(ret, static_cast(v.v_obj)); + EXPECT_EQ(v.v_obj->ref_cnt, 1); + } else { + try { + convert(); + FAIL() << "No exception thrown"; + } catch (TVMError &ex) { + std::ostringstream os; + os << "Cannot convert from type `" << TypeIndex2TypeKey(v.type_index) + << "` to `object.Object *`"; + EXPECT_EQ(ex.what(), os.str()); + } + } + CheckAnyRefCnt(&v); + } +} + +TEST(Any_Converter_3_Object_Ptr, Func) { + std::vector views = AnyArrayFactory(); + for (const Any &v : views) { + auto convert = [&]() -> Func * { return v; }; + if (v.type_index == static_cast(TVMFFITypeIndex::kTVMFFINone)) { + Func *ret = convert(); + EXPECT_EQ(ret, nullptr); + } else if (v.type_index == static_cast(TVMFFITypeIndex::kTVMFFIFunc)) { + Func *ret = convert(); + EXPECT_EQ(ret, reinterpret_cast(v.v_ptr)); + } else { + try { + convert(); + FAIL() << "No exception thrown"; + } catch (TVMError &ex) { + std::ostringstream os; + os << "Cannot convert from type `" << TypeIndex2TypeKey(v.type_index) + << "` to `object.Func *`"; + EXPECT_EQ(ex.what(), os.str()); + } + } + CheckAnyRefCnt(&v); + } +} + +TEST(Any_Converter_3_Object_Ptr, Str) { + std::vector vs = AnyArrayFactory(); + int counter = 0; + for (const Any &v : vs) { + auto convert = [&]() -> Str * { return v; }; + if (v.type_index == static_cast(TVMFFITypeIndex::kTVMFFINone)) { + Str *ret = convert(); + EXPECT_EQ(ret, nullptr); + } else if (v.type_index == static_cast(TVMFFITypeIndex::kTVMFFIStr)) { + Str *ret = convert(); + ++counter; + EXPECT_LE(counter, 3); + if (counter == 1) { + EXPECT_STREQ(ret->c_str(), "Hello (raw str)"); + } else if (counter == 2) { + EXPECT_STREQ(ret->c_str(), "World (std::string)"); + } else { + EXPECT_STREQ(ret->c_str(), "Hello World (Ref)"); + } + } else { + try { + convert(); + FAIL() << "No exception thrown"; + } catch (TVMError &ex) { + std::ostringstream os; + os << "Cannot convert from type `" << TypeIndex2TypeKey(v.type_index) + << "` to `object.Str *`"; + EXPECT_EQ(ex.what(), os.str()); + } + } + CheckAnyRefCnt(&v); + } +} + +TEST(Any_Stringify, Integer) { + TestAnyStringify(-13, TVMFFITypeIndex::kTVMFFIInt, "-13"); + TestAnyStringify(-5, TVMFFITypeIndex::kTVMFFIInt, "-5"); + TestAnyStringify(0, TVMFFITypeIndex::kTVMFFIInt, "0"); + TestAnyStringify(1, TVMFFITypeIndex::kTVMFFIInt, "1"); +} + +TEST(Any_Stringify, Float) { + auto check = [](const Any &v) -> void { + std::string str = v.str()->c_str(); + double f_str = std::stod(str); + double f_src = v.v_float64; + EXPECT_NEAR(f_src, f_str, 1e-5); + }; + TestAnyStringifyChecker(float(-3.14), TVMFFITypeIndex::kTVMFFIFloat, check); + TestAnyStringifyChecker(0.0, TVMFFITypeIndex::kTVMFFIFloat, check); +} + +TEST(Any_Stringify, Ptr) { + auto check = [](const Any &v) -> void { + std::string str = v.str()->c_str(); + EXPECT_GT(str.size(), 2); + }; + TestAnyStringify(nullptr, TVMFFITypeIndex::kTVMFFINone, "None"); + TestAnyStringifyChecker(reinterpret_cast(FuncCall), TVMFFITypeIndex::kTVMFFIPtr, + check); +} + +TEST(Any_Stringify, Device) { + TestAnyStringify(DLDevice{kDLCPU, 0}, TVMFFITypeIndex::kTVMFFIDevice, "cpu:0"); + TestAnyStringify(DLDevice{kDLCUDA, 1}, TVMFFITypeIndex::kTVMFFIDevice, "cuda:1"); +} + +TEST(Any_Stringify, DataType) { + TestAnyStringify(DLDataType{kDLInt, 32, 1}, TVMFFITypeIndex::kTVMFFIDataType, + "int32"); + TestAnyStringify(DLDataType{kDLUInt, 1, 1}, TVMFFITypeIndex::kTVMFFIDataType, "bool"); + TestAnyStringify(DLDataType{kDLOpaqueHandle, 0, 0}, TVMFFITypeIndex::kTVMFFIDataType, + "void"); + TestAnyStringify(DLDataType{kDLFloat, 8, 4}, TVMFFITypeIndex::kTVMFFIDataType, + "float8x4"); +} + +TEST(Any_Stringify, RawStr) { + TestAnyStringify("Hello", TVMFFITypeIndex::kTVMFFIStr, "\"Hello\""); + TestAnyStringify("Hello", TVMFFITypeIndex::kTVMFFIStr, "\"Hello\""); + TestAnyStringify("Hello", TVMFFITypeIndex::kTVMFFIStr, "\"Hello\""); +} + +TEST(Any_Stringify, Object) { + auto check = [](const Any &v) -> void { + std::string expected_prefix = "object.Object@0"; + int n = static_cast(expected_prefix.size()); + std::string str = v.str()->c_str(); + EXPECT_GT(str.size(), n); + EXPECT_EQ(str.substr(0, n), expected_prefix); + }; + TestAnyStringifyChecker>(Ref::New(), TVMFFITypeIndex::kTVMFFIObject, check); +} + +TEST(Any_Stringify, Func) { + auto check = [](const Any &v) -> void { + std::string expected_prefix = "object.Func@0"; + int n = static_cast(expected_prefix.size()); + std::string str = v.str()->c_str(); + EXPECT_GT(str.size(), n); + EXPECT_EQ(str.substr(0, n), expected_prefix); + }; + TestAnyStringifyChecker>(Ref::New(FuncCall), TVMFFITypeIndex::kTVMFFIFunc, check); +} + +TEST(Any_Stringify, Str) { + auto check = [](const Any &v) -> void { + std::string str = v.str()->c_str(); + EXPECT_EQ(str, "\"Hello World\""); + }; + TestAnyStringifyChecker>(Ref::New("Hello World"), TVMFFITypeIndex::kTVMFFIStr, + check); +} + +} // namespace diff --git a/ffi/tests/cpp/test_any_view.cc b/ffi/tests/cpp/test_any_view.cc new file mode 100644 index 000000000000..67aaa9de3aed --- /dev/null +++ b/ffi/tests/cpp/test_any_view.cc @@ -0,0 +1,630 @@ +#include +#include + +namespace { +using namespace tvm::ffi; + +template +void TestAnyViewConstructor(Checker check, TVMFFITypeIndex expected_type_index, + const SrcType &source) { + AnyView v(source); + EXPECT_EQ(v.type_index, static_cast(expected_type_index)); + EXPECT_EQ(v.ref_cnt, 0); + check(&v, source); +}; + +int64_t FuncCall(int64_t x) { return x + 1; } + +std::vector AnyViewArrayFactory() { + static const char *raw_str = "Hello (raw str)"; + static std::string std_str = "World (std::string)"; + static std::string ref_str = "Hello World (Ref)"; + static Ref obj = Ref::New(); + static Ref func = Ref::New(FuncCall); + static Ref str = Ref::New(ref_str); + return std::vector{ + AnyView(nullptr), + AnyView(1), + AnyView(2.5), + AnyView(reinterpret_cast(FuncCall)), + AnyView(DLDevice{kDLCPU, 0}), + AnyView(DLDataType{kDLInt, 32, 1}), + AnyView(raw_str), + AnyView(obj), + AnyView(func), + AnyView(std_str), // TODO: disable AnyView(std::string&&) + AnyView(str), + }; +} + +template +void TestAnyViewStringify(const SrcType &source, TVMFFITypeIndex expected_type_index, + const std::string &expected) { + AnyView v(source); + EXPECT_EQ(v.type_index, static_cast(expected_type_index)); + EXPECT_EQ(v.str()->c_str(), expected); +} + +template +void TestAnyViewStringifyChecker(const SrcType &source, TVMFFITypeIndex expected_type_index, + Checker check) { + AnyView v(source); + EXPECT_EQ(v.type_index, static_cast(expected_type_index)); + check(v); +} + +void CheckAnyViewRefCnt(const TVMFFIAny *v) { + if (v->type_index >= static_cast(TVMFFITypeIndex::kTVMFFIStaticObjectBegin)) { + EXPECT_EQ(v->v_obj->ref_cnt, 1); + } +} + +TEST(AnyView_Constructor_0_Default, Default) { + AnyView v; + EXPECT_EQ(v.type_index, 0); + EXPECT_EQ(v.ref_cnt, 0); + EXPECT_EQ(v.v_int64, 0); +} + +TEST(AnyView_Constructor_1_AnyView, Copy) { + AnyView v1(1); + AnyView v2(v1); + EXPECT_EQ(v1.type_index, static_cast(TVMFFITypeIndex::kTVMFFIInt)); + EXPECT_EQ(v1.v_int64, 1); + EXPECT_EQ(v2.type_index, static_cast(TVMFFITypeIndex::kTVMFFIInt)); + EXPECT_EQ(v2.v_int64, 1); +} + +TEST(AnyView_Constructor_1_AnyView, Move) { + AnyView v1(1); + AnyView v2(std::move(v1)); + EXPECT_EQ(v1.type_index, static_cast(TVMFFITypeIndex::kTVMFFINone)); + EXPECT_EQ(v1.v_int64, 0); + EXPECT_EQ(v2.type_index, static_cast(TVMFFITypeIndex::kTVMFFIInt)); + EXPECT_EQ(v2.v_int64, 1); +} + +TEST(AnyView_Constructor_2_Any, Copy) { + Any v1(1); + AnyView v2(v1); + EXPECT_EQ(v1.type_index, static_cast(TVMFFITypeIndex::kTVMFFIInt)); + EXPECT_EQ(v1.v_int64, 1); + EXPECT_EQ(v2.type_index, static_cast(TVMFFITypeIndex::kTVMFFIInt)); + EXPECT_EQ(v2.v_int64, 1); +} + +TEST(AnyView_Constructor_2_Any, Move) { + // ---- The following behavior is disallowed ---- + // Any v1(1); + // AnyView v2(std::move(v1)); +} + +TEST(AnyView_Constructor_3_Ref, Copy) { + Ref obj = Ref::New(); + AnyView v(obj); + const TVMFFIAny *v_obj = v.v_obj; + EXPECT_EQ(v.type_index, static_cast(TVMFFITypeIndex::kTVMFFIObject)); + EXPECT_EQ(v.ref_cnt, 0); + EXPECT_EQ(v_obj, static_cast(obj.get())); + EXPECT_EQ(v_obj->ref_cnt, 1); +} + +TEST(AnyView_Constructor_3_Ref, Move) { + // The following behavior is disallowed + // Ref obj = Ref::New(); + // AnyView v(std::move(obj)); +} + +TEST(AnyView_Constructor_4_TypeTraits, Integer) { + auto check = [](TVMFFIAny *v, int64_t source) -> void { EXPECT_EQ(v->v_int64, source); }; + TestAnyViewConstructor(check, TVMFFITypeIndex::kTVMFFIInt, static_cast(1)); + TestAnyViewConstructor(check, TVMFFITypeIndex::kTVMFFIInt, static_cast(2)); + TestAnyViewConstructor(check, TVMFFITypeIndex::kTVMFFIInt, static_cast(3)); + TestAnyViewConstructor(check, TVMFFITypeIndex::kTVMFFIInt, static_cast(4)); + TestAnyViewConstructor(check, TVMFFITypeIndex::kTVMFFIInt, static_cast(1)); + TestAnyViewConstructor(check, TVMFFITypeIndex::kTVMFFIInt, static_cast(2)); + TestAnyViewConstructor(check, TVMFFITypeIndex::kTVMFFIInt, static_cast(3)); + TestAnyViewConstructor(check, TVMFFITypeIndex::kTVMFFIInt, static_cast(4)); +} + +TEST(AnyView_Constructor_4_TypeTraits, Float) { + auto check = [](TVMFFIAny *v, double source) -> void { EXPECT_EQ(v->v_float64, source); }; + TestAnyViewConstructor(check, TVMFFITypeIndex::kTVMFFIFloat, static_cast(3)); + TestAnyViewConstructor(check, TVMFFITypeIndex::kTVMFFIFloat, static_cast(4)); +} + +TEST(AnyView_Constructor_4_TypeTraits, Ptr) { + int p = 4; + auto check = [](TVMFFIAny *v, void *source) -> void { EXPECT_EQ(v->v_ptr, source); }; + TestAnyViewConstructor(check, TVMFFITypeIndex::kTVMFFINone, static_cast(nullptr)); + TestAnyViewConstructor(check, TVMFFITypeIndex::kTVMFFIPtr, static_cast(&p)); +} + +TEST(AnyView_Constructor_4_TypeTraits, Device) { + auto check = [](TVMFFIAny *v, const DLDevice &source) -> void { + EXPECT_EQ(v->v_device.device_type, source.device_type); + EXPECT_EQ(v->v_device.device_id, source.device_id); + }; + TestAnyViewConstructor(check, TVMFFITypeIndex::kTVMFFIDevice, DLDevice{kDLCPU, 0}); + TestAnyViewConstructor(check, TVMFFITypeIndex::kTVMFFIDevice, DLDevice{kDLCUDA, 1}); +} + +TEST(AnyView_Constructor_4_TypeTraits, DataType) { + auto check = [](TVMFFIAny *v, const DLDataType &source) -> void { + EXPECT_EQ(v->v_dtype.code, source.code); + EXPECT_EQ(v->v_dtype.bits, source.bits); + EXPECT_EQ(v->v_dtype.lanes, source.lanes); + }; + TestAnyViewConstructor(check, TVMFFITypeIndex::kTVMFFIDataType, DLDataType{kDLInt, 32, 1}); + TestAnyViewConstructor(check, TVMFFITypeIndex::kTVMFFIDataType, DLDataType{kDLUInt, 0, 0}); +} + +TEST(AnyView_Constructor_4_TypeTraits, RawStr) { + auto check = [](TVMFFIAny *v, const char *source) -> void { EXPECT_EQ(v->v_str, source); }; + const char *empty = ""; + const char *hello = "hello"; + TestAnyViewConstructor(check, TVMFFITypeIndex::kTVMFFIRawStr, empty); + TestAnyViewConstructor(check, TVMFFITypeIndex::kTVMFFIRawStr, hello); +} + +TEST(AnyView_Constructor_4_TypeTraits, CharArray) { + auto check = [](TVMFFIAny *v, const char *source) -> void { EXPECT_EQ(v->v_str, source); }; + const char empty[] = ""; + const char hello[] = "hello"; + TestAnyViewConstructor(check, TVMFFITypeIndex::kTVMFFIRawStr, empty); + TestAnyViewConstructor(check, TVMFFITypeIndex::kTVMFFIRawStr, hello); +} + +TEST(AnyView_Constructor_4_TypeTraits, StdString) { + auto check = [](TVMFFIAny *v, const std::string &source) -> void { + EXPECT_EQ(v->v_str, source.data()); + }; + std::string empty = ""; + std::string hello = "hello"; + TestAnyViewConstructor(check, TVMFFITypeIndex::kTVMFFIRawStr, hello); + TestAnyViewConstructor(check, TVMFFITypeIndex::kTVMFFIRawStr, empty); +} + +TEST(AnyView_Constructor_5_Object_Ptr, Object) { + Ref obj = Ref::New(); + AnyView v(obj.get()); + EXPECT_EQ(v.type_index, static_cast(TVMFFITypeIndex::kTVMFFIObject)); + EXPECT_EQ(v.v_obj->ref_cnt, 1); + EXPECT_EQ(v.v_obj, static_cast(obj.get())); +} + +TEST(AnyView_Constructor_5_Object_Ptr, Func) { + Ref func = Ref::New(FuncCall); + AnyView v(func.get()); + EXPECT_EQ(v.type_index, static_cast(TVMFFITypeIndex::kTVMFFIFunc)); + EXPECT_EQ(v.v_obj->ref_cnt, 1); + EXPECT_EQ(v.v_ptr, static_cast(func.get())); +} + +TEST(AnyView_Constructor_5_Object_Ptr, Str) { + Ref str = Ref::New("hello"); + AnyView v(str.get()); + EXPECT_EQ(v.type_index, static_cast(TVMFFITypeIndex::kTVMFFIStr)); + EXPECT_EQ(v.v_obj->ref_cnt, 1); + EXPECT_EQ(v.v_obj, static_cast(str.get())); +} + +TEST(AnyView_Converter_0_TypeTraits, Integer) { + std::vector views = AnyViewArrayFactory(); + for (const AnyView &v : views) { + auto convert = [&]() -> int64_t { return v; }; + if (v.type_index == static_cast(TVMFFITypeIndex::kTVMFFIInt)) { + EXPECT_EQ(convert(), 1); + } else { + try { + convert(); + FAIL() << "No exception thrown"; + } catch (TVMError &ex) { + std::ostringstream os; + os << "Cannot convert from type `" << TypeIndex2TypeKey(v.type_index) << "` to `int`"; + EXPECT_EQ(ex.what(), os.str()); + } + } + CheckAnyViewRefCnt(&v); + } +} + +TEST(AnyView_Converter_0_TypeTraits, Float) { + std::vector views = AnyViewArrayFactory(); + for (const AnyView &v : views) { + auto convert = [&]() -> double { return v; }; + if (v.type_index == static_cast(TVMFFITypeIndex::kTVMFFIInt)) { + EXPECT_EQ(convert(), 1.0); + } else if (v.type_index == static_cast(TVMFFITypeIndex::kTVMFFIFloat)) { + EXPECT_EQ(convert(), 2.5); + } else { + try { + convert(); + FAIL() << "No exception thrown"; + } catch (TVMError &ex) { + std::ostringstream os; + os << "Cannot convert from type `" << TypeIndex2TypeKey(v.type_index) << "` to `float`"; + EXPECT_EQ(ex.what(), os.str()); + } + } + CheckAnyViewRefCnt(&v); + } +} + +TEST(AnyView_Converter_0_TypeTraits, Ptr) { + std::vector views = AnyViewArrayFactory(); + for (const AnyView &v : views) { + auto convert = [&]() -> void * { return v; }; + if (v.type_index == static_cast(TVMFFITypeIndex::kTVMFFINone)) { + EXPECT_EQ(convert(), nullptr); + } else if (v.type_index == static_cast(TVMFFITypeIndex::kTVMFFIPtr)) { + EXPECT_EQ(convert(), reinterpret_cast(&FuncCall)); + } else if (v.type_index == static_cast(TVMFFITypeIndex::kTVMFFIRawStr)) { + EXPECT_EQ(convert(), v.v_str); + } else { + try { + convert(); + FAIL() << "No exception thrown"; + } catch (TVMError &ex) { + std::ostringstream os; + os << "Cannot convert from type `" << TypeIndex2TypeKey(v.type_index) << "` to `Ptr`"; + EXPECT_EQ(ex.what(), os.str()); + } + } + CheckAnyViewRefCnt(&v); + } +} + +TEST(AnyView_Converter_0_TypeTraits, Device) { + std::vector views = AnyViewArrayFactory(); + for (const AnyView &v : views) { + auto convert = [&]() -> DLDevice { return v; }; + if (v.type_index == static_cast(TVMFFITypeIndex::kTVMFFIDevice)) { + EXPECT_EQ(convert().device_type, kDLCPU); + EXPECT_EQ(convert().device_id, 0); + } else { + try { + convert(); + FAIL() << "No exception thrown"; + } catch (TVMError &ex) { + std::ostringstream os; + os << "Cannot convert from type `" << TypeIndex2TypeKey(v.type_index) << "` to `Device`"; + EXPECT_EQ(ex.what(), os.str()); + } + } + CheckAnyViewRefCnt(&v); + } +} + +TEST(AnyView_Converter_0_TypeTraits, DataType) { + std::vector views = AnyViewArrayFactory(); + for (const AnyView &v : views) { + auto convert = [&]() -> DLDataType { return v; }; + if (v.type_index == static_cast(TVMFFITypeIndex::kTVMFFIDataType)) { + EXPECT_EQ(convert().code, kDLInt); + EXPECT_EQ(convert().bits, 32); + EXPECT_EQ(convert().lanes, 1); + } else { + try { + convert(); + FAIL() << "No exception thrown"; + } catch (TVMError &ex) { + std::ostringstream os; + os << "Cannot convert from type `" << TypeIndex2TypeKey(v.type_index) << "` to `dtype`"; + EXPECT_EQ(ex.what(), os.str()); + } + } + CheckAnyViewRefCnt(&v); + } +} + +TEST(AnyView_Converter_0_TypeTraits, RawStr) { + std::vector views = AnyViewArrayFactory(); + int counter = 0; + for (const AnyView &v : views) { + auto convert = [&]() -> const char * { return v; }; + if (v.type_index == static_cast(TVMFFITypeIndex::kTVMFFIRawStr)) { + counter += 1; + EXPECT_LT(counter, 3); + if (counter == 1) { + EXPECT_STREQ(convert(), "Hello (raw str)"); + } else if (counter == 2) { + EXPECT_STREQ(convert(), "World (std::string)"); + } + } else if (v.type_index == static_cast(TVMFFITypeIndex::kTVMFFIStr)) { + EXPECT_STREQ(convert(), "Hello World (Ref)"); + } else { + try { + convert(); + FAIL() << "No exception thrown"; + } catch (TVMError &ex) { + std::ostringstream os; + os << "Cannot convert from type `" << TypeIndex2TypeKey(v.type_index) + << "` to `const char *`"; + EXPECT_EQ(ex.what(), os.str()); + } + } + CheckAnyViewRefCnt(&v); + } +} + +TEST(AnyView_Converter_0_TypeTraits, RawStrToStrStar_Fail) { + AnyView v = "Hello"; + try { + Str *v_str = v; + (void)v_str; + FAIL() << "No exception thrown"; + } catch (TVMError &ex) { + EXPECT_STREQ(ex.what(), "Cannot convert from type `const char *` to `object.Str *`"); + } +} + +TEST(AnyView_Converter_0_TypeTraits, RawStrToStrStar_WrithStorage) { + Any storage; + AnyView v = "Hello"; + Str *v_str = v.CastWithStorage(&storage); + EXPECT_STREQ(v_str->c_str(), "Hello"); +} + +TEST(AnyView_Converter_1_Any, Any) { + std::vector views = AnyViewArrayFactory(); + for (const AnyView &view : views) { + auto convert = [&]() -> Any { return view; }; + { + Any ret = convert(); + EXPECT_EQ(view.ref_cnt, 0); + if (view.type_index == static_cast(TVMFFITypeIndex::kTVMFFIRawStr)) { + Str *str = ret; + EXPECT_EQ(ret.type_index, static_cast(TVMFFITypeIndex::kTVMFFIStr)); + EXPECT_STREQ(str->c_str(), view.v_str); + EXPECT_EQ(ret.ref_cnt, 0); + } else { + EXPECT_EQ(ret.type_index, view.type_index); + EXPECT_EQ(ret.ref_cnt, 0); + EXPECT_EQ(ret.v_obj, view.v_obj); + } + } + CheckAnyViewRefCnt(&view); + } +} + +TEST(AnyView_Converter_2_Ref, Object) { + std::vector views = AnyViewArrayFactory(); + for (const AnyView &v : views) { + auto convert = [&]() -> Ref { return v; }; + if (v.type_index == static_cast(TVMFFITypeIndex::kTVMFFINone)) { + Ref ret = convert(); + EXPECT_EQ(ret.get(), nullptr); + } else if (v.type_index >= static_cast(TVMFFITypeIndex::kTVMFFIStaticObjectBegin)) { + Ref ret = convert(); + EXPECT_EQ(ret.get(), static_cast(v.v_obj)); + EXPECT_EQ(v.v_obj->ref_cnt, 2); + } else { + try { + convert(); + FAIL() << "No exception thrown"; + } catch (TVMError &ex) { + std::ostringstream os; + os << "Cannot convert from type `" << TypeIndex2TypeKey(v.type_index) + << "` to `object.Object`"; + EXPECT_EQ(ex.what(), os.str()); + } + } + CheckAnyViewRefCnt(&v); + } +} + +TEST(AnyView_Converter_2_Ref, Func) { + std::vector views = AnyViewArrayFactory(); + for (const AnyView &v : views) { + auto convert = [&]() -> Ref { return v; }; + if (v.type_index == static_cast(TVMFFITypeIndex::kTVMFFINone)) { + Ref ret = convert(); + EXPECT_EQ(ret.get(), nullptr); + } else if (v.type_index == static_cast(TVMFFITypeIndex::kTVMFFIFunc)) { + Ref ret = convert(); + EXPECT_EQ(ret.get(), static_cast(v.v_obj)); + } else { + try { + convert(); + FAIL() << "No exception thrown"; + } catch (TVMError &ex) { + std::ostringstream os; + os << "Cannot convert from type `" << TypeIndex2TypeKey(v.type_index) + << "` to `object.Func`"; + EXPECT_EQ(ex.what(), os.str()); + } + } + CheckAnyViewRefCnt(&v); + } +} + +TEST(AnyView_Converter_2_Ref, Str) { + std::vector views = AnyViewArrayFactory(); + for (const AnyView &v : views) { + auto convert = [&]() -> Ref { return v; }; + if (v.type_index == static_cast(TVMFFITypeIndex::kTVMFFINone)) { + Ref ret = convert(); + EXPECT_EQ(ret.get(), nullptr); + } else if (v.type_index == static_cast(TVMFFITypeIndex::kTVMFFIStr)) { + Ref ret = convert(); + EXPECT_STREQ(ret->c_str(), "Hello World (Ref)"); + } else if (v.type_index == static_cast(TVMFFITypeIndex::kTVMFFIRawStr)) { + Ref ret = convert(); + EXPECT_EQ(reinterpret_cast(ret.get())->ref_cnt, 1); + EXPECT_STREQ(ret->c_str(), v.v_str); + } else { + try { + convert(); + FAIL() << "No exception thrown"; + } catch (TVMError &ex) { + std::ostringstream os; + os << "Cannot convert from type `" << TypeIndex2TypeKey(v.type_index) + << "` to `object.Str`"; + EXPECT_EQ(ex.what(), os.str()); + } + } + CheckAnyViewRefCnt(&v); + } +} + +TEST(AnyView_Converter_3_Object_Ptr, Object) { + std::vector views = AnyViewArrayFactory(); + for (const AnyView &v : views) { + auto convert = [&]() -> Object * { return v; }; + if (v.type_index == static_cast(TVMFFITypeIndex::kTVMFFINone)) { + Object *ret = convert(); + EXPECT_EQ(ret, nullptr); + } else if (v.type_index >= static_cast(TVMFFITypeIndex::kTVMFFIStaticObjectBegin)) { + Object *ret = convert(); + EXPECT_EQ(ret, static_cast(v.v_obj)); + EXPECT_EQ(v.v_obj->ref_cnt, 1); + } else { + try { + convert(); + FAIL() << "No exception thrown"; + } catch (TVMError &ex) { + std::ostringstream os; + os << "Cannot convert from type `" << TypeIndex2TypeKey(v.type_index) + << "` to `object.Object *`"; + EXPECT_EQ(ex.what(), os.str()); + } + } + CheckAnyViewRefCnt(&v); + } +} + +TEST(AnyView_Converter_3_Object_Ptr, Func) { + std::vector views = AnyViewArrayFactory(); + for (const AnyView &v : views) { + auto convert = [&]() -> Func * { return v; }; + if (v.type_index == static_cast(TVMFFITypeIndex::kTVMFFINone)) { + Func *ret = convert(); + EXPECT_EQ(ret, nullptr); + } else if (v.type_index == static_cast(TVMFFITypeIndex::kTVMFFIFunc)) { + Func *ret = convert(); + EXPECT_EQ(ret, reinterpret_cast(v.v_ptr)); + } else { + try { + convert(); + FAIL() << "No exception thrown"; + } catch (TVMError &ex) { + std::ostringstream os; + os << "Cannot convert from type `" << TypeIndex2TypeKey(v.type_index) + << "` to `object.Func *`"; + EXPECT_EQ(ex.what(), os.str()); + } + } + CheckAnyViewRefCnt(&v); + } +} + +TEST(AnyView_Converter_3_Object_Ptr, Str) { + std::vector views = AnyViewArrayFactory(); + for (const AnyView &v : views) { + auto convert = [&]() -> Str * { return v; }; + if (v.type_index == static_cast(TVMFFITypeIndex::kTVMFFINone)) { + Str *ret = convert(); + EXPECT_EQ(ret, nullptr); + } else if (v.type_index == static_cast(TVMFFITypeIndex::kTVMFFIStr)) { + Str *ret = convert(); + EXPECT_STREQ(ret->c_str(), "Hello World (Ref)"); + } else { + try { + convert(); + FAIL() << "No exception thrown"; + } catch (TVMError &ex) { + std::ostringstream os; + os << "Cannot convert from type `" << TypeIndex2TypeKey(v.type_index) + << "` to `object.Str *`"; + EXPECT_EQ(ex.what(), os.str()); + } + } + CheckAnyViewRefCnt(&v); + } +} + +TEST(AnyView_Stringify, Integer) { + TestAnyViewStringify(-13, TVMFFITypeIndex::kTVMFFIInt, "-13"); + TestAnyViewStringify(-5, TVMFFITypeIndex::kTVMFFIInt, "-5"); + TestAnyViewStringify(0, TVMFFITypeIndex::kTVMFFIInt, "0"); + TestAnyViewStringify(1, TVMFFITypeIndex::kTVMFFIInt, "1"); +} + +TEST(AnyView_Stringify, Float) { + auto check = [](const AnyView &v) -> void { + std::string str = v.str()->c_str(); + double f_str = std::stod(str); + double f_src = v.v_float64; + EXPECT_NEAR(f_src, f_str, 1e-5); + }; + TestAnyViewStringifyChecker(float(-3.14), TVMFFITypeIndex::kTVMFFIFloat, check); + TestAnyViewStringifyChecker(0.0, TVMFFITypeIndex::kTVMFFIFloat, check); +} + +TEST(AnyView_Stringify, Ptr) { + auto check = [](const AnyView &v) -> void { + std::string str = v.str()->c_str(); + EXPECT_GT(str.size(), 2); + }; + TestAnyViewStringify(nullptr, TVMFFITypeIndex::kTVMFFINone, "None"); + TestAnyViewStringifyChecker(reinterpret_cast(FuncCall), + TVMFFITypeIndex::kTVMFFIPtr, check); +} + +TEST(AnyView_Stringify, Device) { + TestAnyViewStringify(DLDevice{kDLCPU, 0}, TVMFFITypeIndex::kTVMFFIDevice, "cpu:0"); + TestAnyViewStringify(DLDevice{kDLCUDA, 1}, TVMFFITypeIndex::kTVMFFIDevice, "cuda:1"); +} + +TEST(AnyView_Stringify, DataType) { + TestAnyViewStringify(DLDataType{kDLInt, 32, 1}, TVMFFITypeIndex::kTVMFFIDataType, + "int32"); + TestAnyViewStringify(DLDataType{kDLUInt, 1, 1}, TVMFFITypeIndex::kTVMFFIDataType, + "bool"); + TestAnyViewStringify(DLDataType{kDLOpaqueHandle, 0, 0}, + TVMFFITypeIndex::kTVMFFIDataType, "void"); + TestAnyViewStringify(DLDataType{kDLFloat, 8, 4}, TVMFFITypeIndex::kTVMFFIDataType, + "float8x4"); +} + +TEST(AnyView_Stringify, RawStr) { + TestAnyViewStringify("Hello", TVMFFITypeIndex::kTVMFFIRawStr, "\"Hello\""); + TestAnyViewStringify("Hello", TVMFFITypeIndex::kTVMFFIRawStr, "\"Hello\""); + TestAnyViewStringify("Hello", TVMFFITypeIndex::kTVMFFIRawStr, "\"Hello\""); +} + +TEST(AnyView_Stringify, Object) { + auto check = [](const AnyView &v) -> void { + std::string expected_prefix = "object.Object@0"; + int n = static_cast(expected_prefix.size()); + std::string str = v.str()->c_str(); + EXPECT_GT(str.size(), n); + EXPECT_EQ(str.substr(0, n), expected_prefix); + }; + TestAnyViewStringifyChecker>(Ref::New(), TVMFFITypeIndex::kTVMFFIObject, + check); +} + +TEST(AnyView_Stringify, Func) { + auto check = [](const AnyView &v) -> void { + std::string expected_prefix = "object.Func@0"; + int n = static_cast(expected_prefix.size()); + std::string str = v.str()->c_str(); + EXPECT_GT(str.size(), n); + EXPECT_EQ(str.substr(0, n), expected_prefix); + }; + TestAnyViewStringifyChecker>(Ref::New(FuncCall), TVMFFITypeIndex::kTVMFFIFunc, + check); +} + +TEST(AnyView_Stringify, Str) { + auto check = [](const AnyView &v) -> void { + std::string str = v.str()->c_str(); + EXPECT_EQ(str, "\"Hello World\""); + }; + TestAnyViewStringifyChecker>(Ref::New("Hello World"), TVMFFITypeIndex::kTVMFFIStr, + check); +} + +} // namespace diff --git a/ffi/tests/cpp/test_dict.cc b/ffi/tests/cpp/test_dict.cc new file mode 100644 index 000000000000..04d5a2854ff9 --- /dev/null +++ b/ffi/tests/cpp/test_dict.cc @@ -0,0 +1,183 @@ +#include +#include +#include + +namespace { + +using namespace tvm::ffi; + +bool DTypeEqual(DLDataType a, DLDataType b) { + return a.code == b.code && a.bits == b.bits && a.lanes == b.lanes; +} +// bool DeviceEqual(DLDevice a, DLDevice b) { +// return a.device_type == b.device_type && a.device_id == b.device_id; +// } + +TEST(Dict_Construtor, Default) { + Ref dict; + ASSERT_EQ(dict.size(), 0); + TVMFFIDict *dict_ptr = reinterpret_cast(dict.get()); + EXPECT_EQ(dict_ptr->type_index, static_cast(TVMFFITypeIndex::kTVMFFIDict)); + EXPECT_EQ(dict_ptr->ref_cnt, 1); + EXPECT_NE(dict_ptr->deleter, nullptr); + EXPECT_EQ(dict_ptr->size, 0); + EXPECT_EQ(dict_ptr->capacity, 0); +} + +TEST(Dict_Construtor, InitializerList) { + Ref dict{{"key1", 1}, {"key2", "value2"}, {3, 4}}; + EXPECT_EQ(dict.size(), 3); + EXPECT_EQ(int(dict["key1"]), 1); + EXPECT_EQ(std::string(dict["key2"]), "value2"); + EXPECT_EQ(int(dict[3]), 4); + + bool found[3] = {false, false, false}; + for (const auto &kv : dict) { + if (AnyEqual()(kv.first, Any("key1"))) { + found[0] = true; + EXPECT_EQ(int(kv.second), 1); + } else if (AnyEqual()(kv.first, Any("key2"))) { + found[1] = true; + EXPECT_EQ(std::string(kv.second), "value2"); + } else if (AnyEqual()(kv.first, Any(3))) { + found[2] = true; + EXPECT_EQ(int(kv.second), 4); + } else { + FAIL() << "Unexpected key: " << kv.first; + } + } + EXPECT_TRUE(found[0]); + EXPECT_TRUE(found[1]); + EXPECT_TRUE(found[2]); +} + +TEST(Dict_Insert, New) { + int64_t integer = 100; + double fp = 1.0; + std::string str = "Hi"; + DLDataType dtype{kDLInt, 32, 1}; + DLDevice device{kDLCPU, 0}; + Ref obj = Ref::New(); + Ref null_obj{nullptr}; + Ref dict{{integer, fp}, {str, dtype}, {null_obj, 0}}; + dict[device] = null_obj; + EXPECT_EQ(dict.size(), 4); + EXPECT_DOUBLE_EQ(double(dict[integer]), fp); + EXPECT_PRED2(DTypeEqual, DLDataType(dict[str]), dtype); + EXPECT_EQ(int(dict[null_obj]), 0); + EXPECT_EQ((Object *)(dict[device]), nullptr); +} + +TEST(Dict_Insert, Override) { + Ref dict{{"key1", 1}, {"key2", "value2"}, {3, 4}}; + EXPECT_EQ(dict.size(), 3); + dict["key1"] = 2; + dict["key2"] = "new_value"; + dict[3] = 5; + EXPECT_EQ(dict.size(), 3); + EXPECT_EQ(int(dict["key1"]), 2); + EXPECT_EQ(std::string(dict["key2"]), "new_value"); + EXPECT_EQ(int(dict[3]), 5); +} + +TEST(Dict_At, Found) { + int64_t integer = 100; + double fp = 1.0; + std::string str = "Hi"; + DLDataType dtype{kDLInt, 32, 1}; + Ref obj = Ref::New(); + Ref null_obj{nullptr}; + Ref dict{{integer, fp}, {str, dtype}, {null_obj, 0}}; + EXPECT_DOUBLE_EQ(double(dict.at(integer)), fp); + EXPECT_PRED2(DTypeEqual, DLDataType(dict.at(str)), dtype); + EXPECT_EQ(int(dict.at(null_obj)), 0); +} + +TEST(Dict_At, NotFound) { + Ref dict{{"key1", 1}, {"key2", "value2"}, {3, 4}}; + try { + dict.at("key3"); + FAIL() << "Expected TVMError"; + } catch (const TVMError &e) { + } +} + +TEST(Dict_ReHash, POD) { + Ref dict; + for (int j = 0; j < 1000; ++j) { + dict[j] = j; + } + EXPECT_EQ(dict.size(), 1000); + std::unordered_set keys; + for (auto &kv : dict) { + int64_t key = kv.first; + int64_t value = kv.second; + EXPECT_EQ(key, value); + EXPECT_FALSE(keys.count(key)); + EXPECT_EQ(key, value); + EXPECT_TRUE(0 <= key && key < 1000); + } + EXPECT_EQ(dict.size(), 1000); +} + +TEST(Dict_ReHash, Object) { + std::vector> objs; + std::unordered_map obj_map; + for (int j = 0; j < 1000; ++j) { + objs.push_back(Ref::New()); + obj_map[objs[j].get()] = j; + } + Ref dict; + for (int j = 0; j < 1000; ++j) { + dict[objs[j]] = j; + } + EXPECT_EQ(dict.size(), 1000); + std::unordered_set keys; + for (auto &kv : dict) { + Ref key = kv.first; + int64_t value = kv.second; + keys.insert(key.get()); + EXPECT_EQ(value, obj_map[key.get()]); + } + EXPECT_EQ(dict.size(), 1000); +} + +TEST(Dict_Erase, POD) { + Ref dict; + for (int j = 0; j < 1000; ++j) { + dict[j] = j; + } + EXPECT_EQ(dict.size(), 1000); + for (int j = 0; j < 1000; ++j) { + dict.erase(j); + EXPECT_EQ(dict.size(), 1000 - j - 1); + } + for (int j = 0; j < 1000; ++j) { + dict[j] = j; + EXPECT_EQ(dict.size(), j + 1); + } +} + +TEST(Dict_Erase, Object) { + std::vector> objs; + std::unordered_map obj_map; + for (int j = 0; j < 1000; ++j) { + objs.push_back(Ref::New()); + obj_map[objs[j].get()] = j; + } + Ref dict; + for (int j = 0; j < 1000; ++j) { + dict[objs[j]] = j; + } + EXPECT_EQ(dict.size(), 1000); + for (int j = 0; j < 1000; ++j) { + dict.erase(objs[j]); + EXPECT_EQ(dict.size(), 1000 - j - 1); + } + for (int j = 0; j < 1000; ++j) { + dict[objs[j]] = j; + EXPECT_EQ(dict.size(), j + 1); + } +} + +} // namespace diff --git a/ffi/tests/cpp/test_func.cc b/ffi/tests/cpp/test_func.cc new file mode 100644 index 000000000000..32512d3dcf12 --- /dev/null +++ b/ffi/tests/cpp/test_func.cc @@ -0,0 +1,215 @@ +#include +#include + +namespace { +using namespace tvm::ffi; + +const char *c_str_raw = "Hello"; + +double func_unpacked_0(int64_t a, double b, const char *c, const double &d) { + EXPECT_STREQ(c, c_str_raw); + return a + b + d; +} + +void func_unpacked_1(DLDataType dtype, DLDevice device, std::string str) { + (void)dtype; + (void)device; + (void)str; +} + +void func_packed_0(int num_args, const AnyView *, Any *ret) { *ret = num_args; } + +template +void func_unpacked_anyview_arg(AnyView a) { + EXPECT_EQ(a.type_index, static_cast(type_index)); +} +template +void func_unpacked_any_arg(Any a) { + EXPECT_EQ(a.type_index, static_cast(type_index)); +} +AnyView func_unpacked_anyview_ret() { return AnyView(1); } +Any func_unpacked_any_ret() { return Any(1); } + +std::string func_unpacked_str_obj(Str *str, const char *str_2) { + EXPECT_EQ(reinterpret_cast(str)->ref_cnt, 1); + EXPECT_STREQ(str->c_str(), str_2); + return str->c_str(); +} + +TEST(Func_Signature, 0) { + EXPECT_EQ(details::FuncFunctor::Sig(), + "(0: int, 1: float, 2: const char *, 3: float) -> float"); +} + +TEST(Func_Signature, 1) { + EXPECT_EQ(details::FuncFunctor::Sig(), + "(0: dtype, 1: Device, 2: str) -> void"); +} + +TEST(Func_Signature, AnyView_Arg) { + EXPECT_EQ(details::FuncFunctor)>::Sig(), + "(0: AnyView) -> void"); +} + +TEST(Func_Signature, AnyView_Ret) { + EXPECT_EQ(details::FuncFunctor::Sig(), + "() -> AnyView"); +} + +TEST(Func_Signature, Any_Arg) { + EXPECT_EQ( + details::FuncFunctor< + decltype(func_unpacked_any_arg)>::Sig(), + "(0: Any) -> void"); +} + +TEST(Func_Signature, Any_Ret) { + EXPECT_EQ(details::FuncFunctor::Sig(), + "() -> Any"); +} + +TEST(Func_Unpacked_Invoke, Func0_RawStr) { + Ref func = Ref::New(func_unpacked_0); + double ret = func(1, 2, c_str_raw, 4); + EXPECT_DOUBLE_EQ(ret, 7); + double ret2 = func(1, 2, std::string("Hello"), 4); + EXPECT_DOUBLE_EQ(ret2, 7); +} + +TEST(Func_Unpacked_Invoke, Func0_StdString_Move) { + std::string str = c_str_raw; + Ref func = Ref::New(func_unpacked_0); + double ret = func(1, 2, std::move(str), 4); + EXPECT_DOUBLE_EQ(ret, 7); +} + +TEST(Func_Unpacked_Invoke, Func0_StdString_Copy) { + std::string str = c_str_raw; + Ref func = Ref::New(func_unpacked_0); + double ret = func(1, 2, str, 4); + EXPECT_DOUBLE_EQ(ret, 7); + EXPECT_EQ(str, c_str_raw); +} + +TEST(Func_Unpacked_Invoke, Func1) { + Ref func = Ref::New(func_unpacked_1); + func(DLDataType{kDLInt, 32, 1}, DLDevice{kDLCPU, 0}, "Hello"); +} + +TEST(Func_Unpacked_Invoke, AnyView_Arg) { + Ref func = + Ref::New(func_unpacked_anyview_arg); + func(1); +} + +TEST(Func_Unpacked_Invoke, AnyView_Ret) { + Ref func = Ref::New(func_unpacked_anyview_ret); + int ret = func(); + EXPECT_EQ(ret, 1); +} + +TEST(Func_Unpacked_Invoke, Any_Arg) { + Ref func = + Ref::New(func_unpacked_any_arg); + func(1); +} + +TEST(Func_Unpacked_Invoke, Any_Ret) { + Ref func = Ref::New(func_unpacked_any_ret); + int ret = func(); + EXPECT_EQ(ret, 1); +} + +TEST(Func_Unpacked_Invoke, StrObj) { + Ref func = Ref::New(func_unpacked_str_obj); + std::string ret = func("Hello", "Hello"); + EXPECT_EQ(ret, "Hello"); +} + +TEST(Func_Packed_Invoke, 0) { + Ref func = Ref::New(func_packed_0); + int ret = func(); + EXPECT_EQ(ret, 0); +} + +TEST(Func_Packed_Invoke, 1) { + Ref func = Ref::New(func_packed_0); + int ret = func(1.0); + EXPECT_EQ(ret, 1); +} + +TEST(Func_Packed_Invoke, 2) { + Ref func = Ref::New(func_packed_0); + int ret = func(1.0, "test"); + EXPECT_EQ(ret, 2); +} + +TEST(Func_Packed_Invoke, 4) { + Ref func = Ref::New(func_packed_0); + int ret = func(1.0, "test", DLDataType{kDLInt, 32, 1}, DLDevice{kDLCPU, 0}); + EXPECT_EQ(ret, 4); +} + +TEST(Func_Unpacked_Invoke_TypeError, TypeMismatch_0) { + Ref func = Ref::New(func_unpacked_0); + try { + func(1.0, 2, c_str_raw, 4); + FAIL() << "No execption thrown"; + } catch (TVMError &ex) { + EXPECT_STREQ(ex.what(), + "Mismatched type on argument #0 when calling: " + "`(0: int, 1: float, 2: const char *, 3: float) -> float`. " + "Expected `int` but got `float`"); + } +} + +TEST(Func_Unpacked_Invoke_TypeError, TypeMismatch_1) { + Ref func = Ref::New(func_unpacked_1); + try { + func(DLDataType{kDLInt, 32, 1}, DLDevice{kDLCPU, 0}, 1); + FAIL() << "No execption thrown"; + } catch (TVMError &ex) { + EXPECT_STREQ(ex.what(), "Mismatched type on argument #2 when calling: " + "`(0: dtype, 1: Device, 2: str) -> void`. " + "Expected `str` but got `int`"); + } +} + +TEST(Func_Unpacked_Invoke_TypeError, ArgCountMismatch_0) { + Ref func = Ref::New(func_unpacked_0); + try { + func(1, 2, c_str_raw); + FAIL() << "No execption thrown"; + } catch (TVMError &ex) { + EXPECT_STREQ(ex.what(), + "Mismatched number of arguments when calling: " + "`(0: int, 1: float, 2: const char *, 3: float) -> float`. " + "Expected 4 but got 3 arguments"); + } +} + +TEST(Func_Unpacked_Invoke_TypeError, ArgCountMismatch_1) { + Ref func = Ref::New(func_unpacked_1); + try { + func(DLDataType{kDLInt, 32, 1}, DLDevice{kDLCPU, 0}); + FAIL() << "No execption thrown"; + } catch (TVMError &ex) { + EXPECT_STREQ(ex.what(), "Mismatched number of arguments when calling: " + "`(0: dtype, 1: Device, 2: str) -> void`. " + "Expected 3 but got 2 arguments"); + } +} + +TEST(Func_Unpacked_Invoke_TypeError, ReturnTypeMismatch_0) { + Ref func = Ref::New(func_unpacked_1); + try { + int ret = func(DLDataType{kDLInt, 32, 1}, DLDevice{kDLCPU, 0}, "Hello"); + (void)ret; + FAIL() << "No execption thrown"; + } catch (TVMError &ex) { + EXPECT_STREQ(ex.what(), "Cannot convert from type `None` to `int`"); + } +} + +} // namespace diff --git a/ffi/tests/cpp/test_list.cc b/ffi/tests/cpp/test_list.cc new file mode 100644 index 000000000000..53fd727af6fb --- /dev/null +++ b/ffi/tests/cpp/test_list.cc @@ -0,0 +1,461 @@ +#include +#include + +namespace { + +using namespace tvm::ffi; + +bool DTypeEqual(DLDataType a, DLDataType b) { + return a.code == b.code && a.bits == b.bits && a.lanes == b.lanes; +} +bool DeviceEqual(DLDevice a, DLDevice b) { + return a.device_type == b.device_type && a.device_id == b.device_id; +} + +void TestSizeCapacityClear(Ref *list, int64_t size, int64_t capacity) { + EXPECT_EQ(list->size(), size); + EXPECT_EQ(list->capacity(), capacity); + EXPECT_EQ(list->empty(), size == 0); + list->clear(); + EXPECT_EQ(list->size(), 0); + EXPECT_EQ(list->capacity(), capacity); + EXPECT_EQ(list->empty(), true); +} + +TEST(List_Constructor, Default) { + Ref list = Ref::New(); + TVMFFIList *list_ptr = reinterpret_cast(list.get()); + ASSERT_NE(list_ptr, nullptr); + EXPECT_EQ(list_ptr->type_index, static_cast(TVMFFITypeIndex::kTVMFFIList)); + EXPECT_EQ(list_ptr->ref_cnt, 1); + EXPECT_NE(list_ptr->deleter, nullptr); + EXPECT_EQ(list_ptr->list_capacity, 0); + EXPECT_EQ(list_ptr->list_length, 0); + EXPECT_EQ(list_ptr->pool_capacity, 0); + EXPECT_EQ(list_ptr->pool_length, 0); + TestSizeCapacityClear(&list, 0, 0); +} + +TEST(List_Constructor, InitializerList) { + Ref list1{ + 100, 1.0f, "Hi", DLDataType{kDLInt, 32, 1}, DLDevice{kDLCPU, 0}, Ref::New(), + Ref()}; + Ref list2 = { + 100, 1.0f, "Hi", DLDataType{kDLInt, 32, 1}, DLDevice{kDLCPU, 0}, Ref::New(), + Ref()}; + + auto test = [](Ref *src) { + auto *list_ptr = reinterpret_cast(src->get()); + ASSERT_NE(list_ptr, nullptr); + EXPECT_EQ(list_ptr->type_index, static_cast(TVMFFITypeIndex::kTVMFFIList)); + EXPECT_EQ(list_ptr->ref_cnt, 1); + EXPECT_NE(list_ptr->deleter, nullptr); + EXPECT_EQ(list_ptr->list_capacity, 7); + EXPECT_EQ(list_ptr->list_length, 7); + EXPECT_EQ(list_ptr->pool_capacity, 7); + EXPECT_EQ(list_ptr->pool_length, 4); // string is not in the POD pool + EXPECT_EQ(src->size(), 7); + EXPECT_EQ(src->capacity(), 7); + EXPECT_EQ(src->empty(), false); + TestSizeCapacityClear(src, 7, 7); + }; + test(&list1); + test(&list2); +} + +TEST(List_PushBack, POD) { + Ref list; + ASSERT_NE(list.get(), nullptr); + list.push_back(100); + list.push_back(1.0f); + TVMFFIList *list_ptr = reinterpret_cast(list.get()); + ASSERT_NE(list_ptr, nullptr); + EXPECT_EQ(list_ptr->type_index, static_cast(TVMFFITypeIndex::kTVMFFIList)); + EXPECT_EQ(list_ptr->ref_cnt, 1); + EXPECT_NE(list_ptr->deleter, nullptr); + EXPECT_EQ(list_ptr->list_capacity, List::kMinCapacity); + EXPECT_EQ(list_ptr->list_length, 2); + EXPECT_EQ(list_ptr->pool_capacity, List::kMinCapacity); + EXPECT_EQ(list_ptr->pool_length, 2); + EXPECT_EQ(int32_t(list[0]), 100); + EXPECT_DOUBLE_EQ(double(list[1]), 1.0); + TestSizeCapacityClear(&list, 2, List::kMinCapacity); +} + +TEST(List_PushBack, Obj) { + Ref list; + Ref obj1 = Ref::New(); + Ref obj2 = Ref::New(); + list.push_back(obj1); + list.push_back(obj2); + TVMFFIList *list_ptr = reinterpret_cast(list.get()); + ASSERT_NE(list_ptr, nullptr); + EXPECT_EQ(list_ptr->type_index, static_cast(TVMFFITypeIndex::kTVMFFIList)); + EXPECT_EQ(list_ptr->ref_cnt, 1); + EXPECT_NE(list_ptr->deleter, nullptr); + EXPECT_EQ(list_ptr->list_capacity, List::kMinCapacity); + EXPECT_EQ(list_ptr->list_length, 2); + EXPECT_EQ(list_ptr->pool_capacity, 0); + EXPECT_EQ(list_ptr->pool_length, 0); + EXPECT_EQ((Object *)(list[0]), obj1.get()); + EXPECT_EQ((Object *)(list[1]), obj2.get()); + TestSizeCapacityClear(&list, 2, List::kMinCapacity); +} + +TEST(List_PushBack, Heterogeneous) { + constexpr int n = 128; + constexpr int k = 8; + constexpr int expected_size = n * k; + constexpr int expected_capacity = 1024; + constexpr int expected_pool_capacity = 1024; + constexpr int expected_pool_length = 512; + int64_t integer = 100; + double fp = 1.0; + std::string str = "Hi"; + DLDataType dtype{kDLInt, 32, 1}; + DLDevice device{kDLCPU, 0}; + Ref obj = Ref::New(); + Ref null_obj{nullptr}; + std::string long_str(1024, 'a'); + + Ref list = Ref::New(); + { + std::string long_str_copy(1024, 'a'); + for (int i = 0; i < n; ++i) { + list.push_back(integer); + list.push_back(fp); + list.push_back(str); + list.push_back(dtype); + list.push_back(device); + list.push_back(obj); + list.push_back(null_obj); + list.push_back(long_str_copy); + } + } + for (int i = 0; i < n; ++i) { + int64_t i_0 = list[i * k]; + double i_1 = list[i * k + 1]; + std::string i_2 = list[i * k + 2]; + DLDataType i_3 = list[i * k + 3]; + DLDevice i_4 = list[i * k + 4]; + Object *i_5 = list[i * k + 5]; + Object *i_6 = list[i * k + 6]; + const char *i_7 = list[i * k + 7]; + EXPECT_EQ(i_0, integer); + EXPECT_DOUBLE_EQ(i_1, fp); + EXPECT_EQ(i_2, str); + EXPECT_PRED2(DTypeEqual, i_3, dtype); + EXPECT_PRED2(DeviceEqual, i_4, device); + EXPECT_EQ(i_5, obj.get()); + EXPECT_EQ(i_6, nullptr); + EXPECT_STREQ(i_7, long_str.c_str()); + } + auto *list_ptr = reinterpret_cast(list.get()); + EXPECT_EQ(list_ptr->list_capacity, expected_capacity); + EXPECT_EQ(list_ptr->list_length, expected_size); + EXPECT_EQ(list_ptr->pool_capacity, expected_pool_capacity); + EXPECT_EQ(list_ptr->pool_length, expected_pool_length); +} + +TEST(List_Insert, Once) { + Ref values = {100, + 1.0, + "Hi", // + DLDataType{kDLInt, 32, 1}, + DLDevice{kDLCPU, 0}, + Ref::New(), + Ref(), + std::string(1024, 'a')}; + int n = values.size(); + for (int pos = 0; pos <= n; ++pos) { + for (AnyView data : values) { + // Test: insert at `pos` with value `data` + Ref list(values.begin(), values.end()); + list.insert(pos, data); + auto test = [](AnyView expected, AnyView actual) { + EXPECT_EQ(expected.type_index, actual.type_index); + EXPECT_EQ(expected.v_int64, actual.v_int64); + }; + for (int i = 0; i < pos; ++i) { + test(values[i], list[i]); + } + for (int i = pos; i < n; ++i) { + test(values[i], list[i + 1]); + } + test(data, list[pos]); + } + } +} + +TEST(List_Insert, Error_0) { + Ref list = {100, 1.0, "Hi"}; + try { + list.insert(-1, 1.0); + FAIL() << "No exception thrown"; + } catch (TVMError &ex) { + EXPECT_STREQ(ex.what(), "Indexing `-1` of a list of size 3"); + } +} + +TEST(List_Insert, Error_1) { + Ref list = {100, 1.0, "Hi"}; + try { + list.insert(4, 1.0); + FAIL() << "No exception thrown"; + } catch (TVMError &ex) { + EXPECT_STREQ(ex.what(), "Indexing `4` of a list of size 3"); + } +} + +TEST(List_Resize, Shrink) { + Ref list = {100, 1.0, "Hi"}; + list.resize(2); + EXPECT_EQ(list.size(), 2); + EXPECT_EQ(list.capacity(), 3); + EXPECT_EQ(int32_t(list[0]), 100); + EXPECT_DOUBLE_EQ(double(list[1]), 1.0); +} + +TEST(List_Resize, Expand) { + Ref list = {100, 1.0, "Hi"}; + list.resize(4); + EXPECT_EQ(list.size(), 4); + EXPECT_EQ(list.capacity(), 6); + EXPECT_EQ(int32_t(list[0]), 100); + EXPECT_DOUBLE_EQ(double(list[1]), 1.0); + EXPECT_STREQ(list[2], "Hi"); + EXPECT_EQ(list[3].operator void *(), nullptr); +} + +TEST(List_Reserve, Shrink) { + Ref list = {100, 1.0, "Hi"}; + list.reserve(2); + EXPECT_EQ(list.size(), 3); + EXPECT_EQ(list.capacity(), 3); + EXPECT_EQ(int32_t(list[0]), 100); + EXPECT_DOUBLE_EQ(double(list[1]), 1.0); + EXPECT_STREQ(list[2], "Hi"); +} + +TEST(List_Reserve, Expand) { + Ref list = {100, 1.0, "Hi"}; + list.reserve(4); + EXPECT_EQ(list.size(), 3); + EXPECT_EQ(list.capacity(), 6); + EXPECT_EQ(int32_t(list[0]), 100); + EXPECT_DOUBLE_EQ(double(list[1]), 1.0); + EXPECT_STREQ(list[2], "Hi"); +} + +TEST(List_SetItem, PodToPod) { + Ref list = {100, 1.0, "Hi"}; + for (int i = 0; i < 16; ++i) { + list[1] = i; + EXPECT_EQ(list.size(), 3); + EXPECT_EQ(list.capacity(), 3); + EXPECT_EQ(int32_t(list[0]), 100); + EXPECT_EQ(int32_t(list[1]), i); + EXPECT_STREQ(list[2], "Hi"); + } + EXPECT_EQ(list.size(), 3); + EXPECT_EQ(list.capacity(), 3); + TVMFFIList *list_ptr = reinterpret_cast(list.get()); + EXPECT_EQ(list_ptr->list_capacity, 3); + EXPECT_EQ(list_ptr->list_length, 3); + EXPECT_EQ(list_ptr->pool_capacity, 24); + EXPECT_EQ(list_ptr->pool_length, 3); +} + +TEST(List_SetItem, ObjToPod) { + Ref list = {100, 1.0, "Hi"}; // + for (int i = 0; i < 16; ++i) { + list[2] = i; + EXPECT_EQ(list.size(), 3); + EXPECT_EQ(list.capacity(), 3); + EXPECT_EQ(int32_t(list[0]), 100); + EXPECT_DOUBLE_EQ(double(list[1]), 1.0); + EXPECT_EQ(int32_t(list[2]), i); + } + EXPECT_EQ(list.size(), 3); + EXPECT_EQ(list.capacity(), 3); + TVMFFIList *list_ptr = reinterpret_cast(list.get()); + EXPECT_EQ(list_ptr->list_capacity, 3); + EXPECT_EQ(list_ptr->list_length, 3); + EXPECT_EQ(list_ptr->pool_capacity, 24); + EXPECT_EQ(list_ptr->pool_length, 6); +} + +TEST(List_SetItem, PodToObj) { + Ref list = {100, 1.0, "Hi"}; + for (int i = 0; i < 1; ++i) { + Ref obj = Ref::New(); + list[0] = obj; + EXPECT_EQ(list.size(), 3); + EXPECT_EQ(list.capacity(), 3); + EXPECT_EQ((Object *)(list[0]), obj.get()); + EXPECT_DOUBLE_EQ(double(list[1]), 1.0); + EXPECT_STREQ(list[2], "Hi"); + } + EXPECT_EQ(list.size(), 3); + EXPECT_EQ(list.capacity(), 3); + TVMFFIList *list_ptr = reinterpret_cast(list.get()); + EXPECT_EQ(list_ptr->list_capacity, 3); + EXPECT_EQ(list_ptr->list_length, 3); + EXPECT_EQ(list_ptr->pool_capacity, 3); + EXPECT_EQ(list_ptr->pool_length, 2); +} + +TEST(List_SetItem, ObjToObj) { + Ref list = {100, 1.0, "Hi"}; + for (int i = 0; i < 1; ++i) { + Ref obj = Ref::New(); + list[2] = obj; + EXPECT_EQ(list.size(), 3); + EXPECT_EQ(list.capacity(), 3); + EXPECT_EQ(int32_t(list[0]), 100); + EXPECT_DOUBLE_EQ(double(list[1]), 1.0); + EXPECT_EQ((Object *)(list[2]), obj.get()); + } + EXPECT_EQ(list.size(), 3); + EXPECT_EQ(list.capacity(), 3); + TVMFFIList *list_ptr = reinterpret_cast(list.get()); + EXPECT_EQ(list_ptr->list_capacity, 3); + EXPECT_EQ(list_ptr->list_length, 3); + EXPECT_EQ(list_ptr->pool_capacity, 3); + EXPECT_EQ(list_ptr->pool_length, 2); +} + +TEST(List_PopBack, Heterogeneous) { + int64_t integer = 100; + double fp = 1.0; + std::string str = "Hi"; + DLDataType dtype{kDLInt, 32, 1}; + DLDevice device{kDLCPU, 0}; + Ref obj = Ref::New(); + Ref null_obj{nullptr}; + Ref list{integer, fp, str, dtype, device, obj, null_obj}; + int n = static_cast(list.size()); + for (int i = 0; i < n; ++i) { + list.pop_back(); + EXPECT_EQ(list.size(), n - 1 - i); + EXPECT_EQ(list.capacity(), n); + int m = static_cast(list.size()); + if (m > 0) { + EXPECT_EQ(int32_t(list[0]), integer); + } + if (m > 1) { + EXPECT_DOUBLE_EQ(double(list[1]), fp); + } + if (m > 2) { + EXPECT_STREQ(list[2], str.c_str()); + } + if (m > 3) { + EXPECT_PRED2(DTypeEqual, DLDataType(list[3]), dtype); + } + if (m > 4) { + EXPECT_PRED2(DeviceEqual, DLDevice(list[4]), device); + } + if (m > 5) { + EXPECT_EQ((Object *)(list[5]), obj.get()); + } + if (m > 6) { + EXPECT_EQ((Object *)(list[6]), nullptr); + } + } + EXPECT_EQ(list.size(), 0); + EXPECT_EQ(list.capacity(), n); + EXPECT_EQ(list.empty(), true); + EXPECT_EQ(list.begin(), list.end()); + try { + list.pop_back(); + FAIL() << "No exception thrown"; + } catch (TVMError &ex) { + EXPECT_STREQ(ex.what(), "Indexing `-1` of a list of size 0"); + } +} + +TEST(List_Erase, Front) { + int64_t integer = 100; + double fp = 1.0; + std::string str = "Hi"; + DLDataType dtype{kDLInt, 32, 1}; + DLDevice device{kDLCPU, 0}; + Ref obj = Ref::New(); + Ref null_obj{nullptr}; + Ref list{integer, fp, str, dtype, device, obj, null_obj}; + list.erase(0); + EXPECT_EQ(list.size(), 6); + EXPECT_EQ(list.capacity(), 7); + EXPECT_EQ(double(list[0]), 1.0); + EXPECT_STREQ(list[1], "Hi"); + EXPECT_PRED2(DTypeEqual, DLDataType(list[2]), dtype); + EXPECT_PRED2(DeviceEqual, DLDevice(list[3]), device); + EXPECT_EQ((Object *)(list[4]), obj.get()); + EXPECT_EQ((Object *)(list[5]), nullptr); +} + +TEST(List_Erase, Back) { + int64_t integer = 100; + double fp = 1.0; + std::string str = "Hi"; + DLDataType dtype{kDLInt, 32, 1}; + DLDevice device{kDLCPU, 0}; + Ref obj = Ref::New(); + Ref null_obj{nullptr}; + Ref list{integer, fp, str, dtype, device, obj, null_obj}; + list.erase(0); + EXPECT_EQ(list.size(), 6); + EXPECT_EQ(list.capacity(), 7); + EXPECT_EQ(double(list[0]), 1.0); + EXPECT_STREQ(list[1], "Hi"); + EXPECT_PRED2(DTypeEqual, DLDataType(list[2]), dtype); + EXPECT_PRED2(DeviceEqual, DLDevice(list[3]), device); + EXPECT_EQ((Object *)(list[4]), obj.get()); + EXPECT_EQ((Object *)(list[5]), nullptr); +} + +TEST(List_Erase, Mid) { + int64_t integer = 100; + double fp = 1.0; + std::string str = "Hi"; + DLDataType dtype{kDLInt, 32, 1}; + DLDevice device{kDLCPU, 0}; + Ref obj = Ref::New(); + Ref null_obj{nullptr}; + Ref list{integer, fp, str, dtype, device, obj, null_obj}; + list.erase(3); + EXPECT_EQ(list.size(), 6); + EXPECT_EQ(list.capacity(), 7); + EXPECT_EQ(int32_t(list[0]), 100); + EXPECT_DOUBLE_EQ(double(list[1]), 1.0); + EXPECT_STREQ(list[2], "Hi"); + EXPECT_PRED2(DeviceEqual, DLDevice(list[3]), device); + EXPECT_EQ((Object *)(list[4]), obj.get()); + EXPECT_EQ((Object *)(list[5]), nullptr); +} + +TEST(List_Iter, Test) { + Ref list; + for (int i = 0; i < 16; ++i) { + list.push_back(i * i); + } + int i = 0; + for (int item : list) { + EXPECT_EQ(i * i, item); + ++i; + } +} + +TEST(List_RevIter, Test) { + Ref list; + for (int i = 0; i < 16; ++i) { + list.push_back(i * i); + } + int i = list.size() - 1; + for (auto it = list.rbegin(); it != list.rend(); ++it) { + EXPECT_EQ(i * i, int32_t(*it)); + --i; + } +} + +} // namespace diff --git a/ffi/tests/cpp/test_ref.cc b/ffi/tests/cpp/test_ref.cc new file mode 100644 index 000000000000..3c9dc0561d9b --- /dev/null +++ b/ffi/tests/cpp/test_ref.cc @@ -0,0 +1,293 @@ +#include +#include +#include + +namespace { +using namespace tvm::ffi; +using tvm::ffi::details::StrPad; +using tvm::ffi::details::StrStd; + +using ObjDeleter = void (*)(void *); + +struct AllocRecorder { + std::unordered_map deleters; + + void Alloc(void *ptr) { + deleters[ptr] = reinterpret_cast(ptr)->deleter; + reinterpret_cast(ptr)->deleter = AllocRecorder::Deleter; + } + + void Delete(void *ptr) { + ASSERT_EQ(deleters.count(ptr), 1); + ObjDeleter d = this->deleters[ptr]; + d(ptr); + deleters.erase(ptr); + } + + bool IsDeletedImpl(void *ptr) { return deleters.count(ptr) == 0; } + + static void Deleter(void *ptr) { AllocRecorder::Global()->Delete(ptr); } + + static bool IsDeleted(void *ptr) { return AllocRecorder::Global()->IsDeletedImpl(ptr); } + + static AllocRecorder *Global() { + static AllocRecorder inst; + return &inst; + } +}; + +template +struct TestAllocator { + using Allocator = typename ::tvm::ffi::GetAllocator::Type; + + template + TVM_FFI_INLINE static ObjectType *New(Args &&...args) { + ObjectType *ret = Allocator::New(std::forward(args)...); + AllocRecorder::Global()->Alloc(ret); + return ret; + } + + template + TVM_FFI_INLINE static ObjectType *NewWithPad(size_t pad_size, Args &&...args) { + ObjectType *ret = + Allocator::template NewWithPad(pad_size, std::forward(args)...); + AllocRecorder::Global()->Alloc(ret); + return ret; + } +}; + +int64_t FuncCall(int64_t x) { return x + 1; } + +int32_t GetRefCount(void *obj) { return reinterpret_cast(obj)->ref_cnt; } + +int32_t GetTypeIndex(void *obj) { return reinterpret_cast(obj)->type_index; } + +ObjDeleter GetDeleter(void *obj) { return reinterpret_cast(obj)->deleter; } + +TEST(Ref_Constructor_0_Default, Default) { + Ref ref; + EXPECT_EQ(ref.get(), nullptr); +} + +TEST(Ref_Constructor_1_Ptr, SameType) { + Object *obj = TestAllocator::New(); + { + Ref ref(obj); + EXPECT_EQ(ref.get(), obj); + EXPECT_EQ(GetRefCount(obj), 1); + } + EXPECT_TRUE(AllocRecorder::IsDeleted(obj)); +} + +TEST(Ref_Constructor_1_Ptr, SubType) { + Str *obj = TestAllocator::New("Hello world"); + { + Ref ref(obj); + EXPECT_EQ(ref.get(), static_cast(obj)); + EXPECT_EQ(GetRefCount(obj), 1); + } + EXPECT_TRUE(AllocRecorder::IsDeleted(obj)); +} + +TEST(Ref_Constructor_2_Ref, SameType_Copy) { + Object *obj = TestAllocator::New(); + { + Ref ref1(obj); + EXPECT_EQ(GetRefCount(obj), 1); + { + Ref ref2(ref1); + EXPECT_EQ(GetRefCount(obj), 2); + } + EXPECT_EQ(GetRefCount(ref1.get()), 1); + } + EXPECT_TRUE(AllocRecorder::IsDeleted(obj)); +} + +TEST(Ref_Constructor_2_Ref, SameType_Move) { + Object *obj = TestAllocator::New(); + { + Ref ref1(obj); + EXPECT_EQ(GetRefCount(obj), 1); + { + Ref ref2(std::move(ref1)); + EXPECT_EQ(GetRefCount(obj), 1); + EXPECT_EQ(ref1.get(), nullptr); + } + } + EXPECT_TRUE(AllocRecorder::IsDeleted(obj)); +} + +TEST(Ref_Constructor_2_Ref, SubType_Copy) { + Str *obj = TestAllocator::New("Hello world"); + { + Ref ref1(obj); + EXPECT_EQ(GetRefCount(obj), 1); + { + Ref ref2(ref1); + EXPECT_EQ(GetRefCount(obj), 2); + } + EXPECT_EQ(GetRefCount(obj), 1); + } + EXPECT_TRUE(AllocRecorder::IsDeleted(obj)); +} + +TEST(Ref_Constructor_2_Ref, SubType_Move) { + Str *obj = TestAllocator::New("Hello world"); + { + Ref ref1(obj); + EXPECT_EQ(GetRefCount(obj), 1); + { + Ref ref2(std::move(ref1)); + EXPECT_EQ(GetRefCount(obj), 1); + EXPECT_EQ(ref1.get(), nullptr); + } + } + EXPECT_TRUE(AllocRecorder::IsDeleted(obj)); +} + +TEST(Ref_Constructor_3_AnyView, Copy) { + Object *obj = TestAllocator::New(); + { + Ref ref1(obj); + EXPECT_EQ(GetRefCount(obj), 1); + { + AnyView view(ref1); + Ref ref2(view); + EXPECT_EQ(GetRefCount(obj), 2); + } + EXPECT_EQ(GetRefCount(obj), 1); + } + EXPECT_TRUE(AllocRecorder::IsDeleted(obj)); +} + +TEST(Ref_Constructor_3_AnyView, Move) { + Object *obj = TestAllocator::New(); + { + Ref ref1(obj); + EXPECT_EQ(GetRefCount(obj), 1); + { + AnyView view(ref1); + EXPECT_EQ(GetRefCount(obj), 1); + { + Ref ref2(std::move(view)); + EXPECT_EQ(GetRefCount(obj), 2); + EXPECT_EQ(view.type_index, static_cast(TVMFFITypeIndex::kTVMFFINone)); + EXPECT_EQ(view.ref_cnt, 0); + EXPECT_EQ(view.v_int64, 0); + } + EXPECT_EQ(GetRefCount(obj), 1); + } + } + EXPECT_TRUE(AllocRecorder::IsDeleted(obj)); +} + +TEST(Ref_Constructor_4_Any, Copy) { + Object *obj = TestAllocator::New(); + { + Ref ref1(obj); + EXPECT_EQ(GetRefCount(obj), 1); + { + Any any(ref1); + EXPECT_EQ(GetRefCount(obj), 2); + Ref ref2(any); + EXPECT_EQ(GetRefCount(obj), 3); + } + EXPECT_EQ(GetRefCount(obj), 1); + } + EXPECT_TRUE(AllocRecorder::IsDeleted(obj)); +} + +TEST(Ref_Constructor_4_Any, Move) { + Object *obj = TestAllocator::New(); + { + Ref ref1(obj); + EXPECT_EQ(GetRefCount(obj), 1); + { + Any any(ref1); + EXPECT_EQ(GetRefCount(obj), 2); + { + Ref ref2(std::move(any)); + EXPECT_EQ(GetRefCount(obj), 2); + EXPECT_EQ(any.type_index, static_cast(TVMFFITypeIndex::kTVMFFINone)); + EXPECT_EQ(any.ref_cnt, 0); + EXPECT_EQ(any.v_int64, 0); + } + EXPECT_EQ(GetRefCount(obj), 1); + } + } + EXPECT_TRUE(AllocRecorder::IsDeleted(obj)); +} + +TEST(Ref_New, Object) { + Ref ref = Ref::New(); + EXPECT_EQ(GetTypeIndex(ref.get()), static_cast(TVMFFITypeIndex::kTVMFFIObject)); + EXPECT_EQ(GetRefCount(ref.get()), 1); +} + +TEST(Ref_New, Func) { + Ref ref = Ref::New(FuncCall); + EXPECT_EQ(GetTypeIndex(ref.get()), static_cast(TVMFFITypeIndex::kTVMFFIFunc)); + EXPECT_EQ(GetRefCount(ref.get()), 1); +} + +TEST(Ref_New, RawStr) { + const char *str = "Hello world"; + Ref ref = Ref::New(str); + EXPECT_EQ(GetTypeIndex(ref.get()), static_cast(TVMFFITypeIndex::kTVMFFIStr)); + EXPECT_EQ(GetRefCount(ref.get()), 1); + EXPECT_EQ(GetDeleter(ref.get()), DefaultObjectAllocator::Deleter); +} + +TEST(Ref_New, CharArray) { + const char str[18] = "Hello world"; + Ref ref = Ref::New(str); + EXPECT_EQ(GetTypeIndex(ref.get()), static_cast(TVMFFITypeIndex::kTVMFFIStr)); + EXPECT_EQ(GetRefCount(ref.get()), 1); + EXPECT_EQ(GetDeleter(ref.get()), DefaultObjectAllocator::Deleter); + EXPECT_EQ(ref->size(), 17); +} + +TEST(Ref_New, StdString_Copy) { + std::string str = "Hello world"; + Ref ref = Ref::New(str); + EXPECT_EQ(GetTypeIndex(ref.get()), static_cast(TVMFFITypeIndex::kTVMFFIStr)); + EXPECT_EQ(GetRefCount(ref.get()), 1); + EXPECT_EQ(GetDeleter(ref.get()), DefaultObjectAllocator::Deleter); +} + +TEST(Ref_New, StdString_Move) { + std::string str = "Hello world"; + Ref ref = Ref::New(std::move(str)); + EXPECT_EQ(GetTypeIndex(ref.get()), static_cast(TVMFFITypeIndex::kTVMFFIStr)); + EXPECT_EQ(GetRefCount(ref.get()), 1); + EXPECT_EQ(GetDeleter(ref.get()), DefaultObjectAllocator::Deleter); +} + +TEST(Ref_Stringify, Object) { + std::string str = Ref::New().str()->c_str(); + std::string expected_prefix = "object.Object@0"; + EXPECT_GT(str.size(), expected_prefix.size()); + EXPECT_EQ(str.substr(0, expected_prefix.size()), expected_prefix); +} + +TEST(Ref_Stringify, Func) { + std::string str = Ref::New(FuncCall).str()->c_str(); + std::string expected_prefix = "object.Func@0"; + EXPECT_GT(str.size(), expected_prefix.size()); + EXPECT_EQ(str.substr(0, expected_prefix.size()), expected_prefix); +} + +TEST(Ref_Stringify, Str) { + std::string str = Ref::New("Hello world").str()->c_str(); + EXPECT_EQ(str, "\"Hello world\""); +} + +TEST(Ref_Misc, MoveToRaw) { + TVMFFIAny *str = reinterpret_cast(Ref::New("Hello world").MoveToRawObjPtr()); + EXPECT_EQ(GetTypeIndex(str), static_cast(TVMFFITypeIndex::kTVMFFIStr)); + EXPECT_EQ(GetRefCount(str), 1); + EXPECT_EQ(GetDeleter(str), DefaultObjectAllocator::Deleter); + str->deleter(str); +} + +} // namespace diff --git a/ffi/tests/cpp/test_str.cc b/ffi/tests/cpp/test_str.cc new file mode 100644 index 000000000000..bd91f241f610 --- /dev/null +++ b/ffi/tests/cpp/test_str.cc @@ -0,0 +1,47 @@ +#include +#include + +namespace { +using namespace tvm::ffi; + +const char c_str_long[] = "Hello, World! This is an extremely long string to " + "avoid any on-stack optimization."; + +TEST(Str, CopyFromStdString) { + std::string std_str = "Hello, World!"; + Ref str = Ref::New(std_str); + EXPECT_EQ(str->size(), std_str.size()); + EXPECT_STREQ(str->c_str(), std_str.c_str()); + EXPECT_STREQ(str->data(), std_str.data()); +} + +TEST(Str, MoveFromStdString_0) { + std::string std_str = c_str_long; + const void *data = std_str.data(); + Ref str = Ref::New(std::move(std_str)); + EXPECT_EQ(static_cast(str->data()), data); + EXPECT_STREQ(str->c_str(), c_str_long); +} + +TEST(Str, MoveFromStdString_1) { + Ref str = Ref::New(std::string(c_str_long)); + EXPECT_STREQ(str->c_str(), c_str_long); + EXPECT_EQ(str->size(), sizeof(c_str_long) - 1); +} + +TEST(Str, CopyFromCStr) { + Ref str = Ref::New(c_str_long); + EXPECT_STREQ(str->c_str(), c_str_long); + EXPECT_EQ(str->size(), sizeof(c_str_long) - 1); +} + +TEST(Str, CopyFromCharArray) { + const char c_str[18] = "Hello, World!"; + Ref str = Ref::New(c_str); + EXPECT_EQ(sizeof(c_str), 18); + EXPECT_EQ(strlen(c_str), 13); + EXPECT_STREQ(str->c_str(), c_str); + EXPECT_EQ(str->size(), 17); +} + +} // namespace diff --git a/ffi/tests/cpp/test_type_dyn.cc b/ffi/tests/cpp/test_type_dyn.cc new file mode 100644 index 000000000000..45910528b3f9 --- /dev/null +++ b/ffi/tests/cpp/test_type_dyn.cc @@ -0,0 +1,71 @@ +#include +#include + +#if TVM_FFI_ALLOW_DYN_TYPE == 1 + +namespace { +using namespace tvm::ffi; + +struct TestObj : public Object { + int x; + explicit TestObj(int x) : x(x) {} + TVM_FFI_DEF_DYN_TYPE(TestObj, Object, "test.TestObj"); +}; + +struct SubTestObj : public TestObj { + int y; + explicit SubTestObj(int x, int y) : TestObj(x), y(y) {} + TVM_FFI_DEF_DYN_TYPE(SubTestObj, TestObj, "test.SubTestObj"); +}; + +void CheckAncestor(int32_t num, const int32_t *ancestors, + std::vector expected) { + EXPECT_EQ(num, expected.size()); + for (int i = 0; i < num; ++i) { + EXPECT_EQ(ancestors[i], expected[i]); + } +} + +TEST(DynTypeInfo, TestObj) { + EXPECT_GE(TestObj::_type_index, + static_cast(TVMFFITypeIndex::kTVMFFIDynObjectBegin)); + EXPECT_STRCASEEQ(TestObj::_type_key, "test.TestObj"); + EXPECT_EQ(TestObj::_type_depth, 1); + CheckAncestor(TestObj::_type_depth, TestObj::_type_ancestors.data(), + {static_cast(TVMFFITypeIndex::kTVMFFIObject)}); +} + +TEST(DynTypeInfo, SubTestObj) { + EXPECT_GE(SubTestObj::_type_index, + static_cast(TVMFFITypeIndex::kTVMFFIDynObjectBegin)); + EXPECT_NE(SubTestObj::_type_index, TestObj::_type_index); + EXPECT_STRCASEEQ(SubTestObj::_type_key, "test.SubTestObj"); + EXPECT_EQ(SubTestObj::_type_depth, 2); + CheckAncestor(SubTestObj::_type_depth, SubTestObj::_type_ancestors.data(), + {static_cast(TVMFFITypeIndex::kTVMFFIObject), + TestObj::_type_index}); +} + +TEST(DynTypeInheritance, TestObj) { + Ref obj = Ref::New(10); + EXPECT_EQ(obj->x, 10); + EXPECT_TRUE(obj->IsInstance()); + EXPECT_TRUE(obj->IsInstance()); + EXPECT_FALSE(obj->IsInstance()); + EXPECT_FALSE(obj->IsInstance()); +} + +TEST(DynTypeInheritance, SubTestObj) { + Ref obj = Ref::New(10, 20); + EXPECT_EQ(obj->x, 10); + EXPECT_EQ(obj->y, 20); + EXPECT_TRUE(obj->IsInstance()); + EXPECT_TRUE(obj->IsInstance()); + EXPECT_TRUE(obj->IsInstance()); + EXPECT_FALSE(obj->IsInstance()); + EXPECT_FALSE(obj->IsInstance()); +} + +} // namespace + +#endif diff --git a/ffi/tests/cpp/test_type_static.cc b/ffi/tests/cpp/test_type_static.cc new file mode 100644 index 000000000000..aebeba7c3e7f --- /dev/null +++ b/ffi/tests/cpp/test_type_static.cc @@ -0,0 +1,119 @@ +#include +#include + +namespace { +using namespace tvm::ffi; + +struct SubType : public Object { + int data; + explicit SubType(int data) : data(data) { + if (data == 1) { + throw std::runtime_error("New Error"); + } + } +}; + +int64_t FuncCall(int64_t x) { return x + 1; } + +void CheckAncestor(int32_t num, const int32_t *ancestors, + std::vector expected) { + EXPECT_EQ(num, expected.size()); + for (int i = 0; i < num; ++i) { + EXPECT_EQ(ancestors[i], expected[i]); + } +} + +static_assert(IsObject, "IsObject == true"); +static_assert(IsObject, "IsObject == true"); +static_assert(IsObject, "IsObject == true"); + +TEST(StaticTypeInfo, Object) { + EXPECT_EQ(Object::_type_index, + static_cast(TVMFFITypeIndex::kTVMFFIObject)); + EXPECT_STRCASEEQ(Object::_type_key, "object.Object"); + EXPECT_EQ(Object::_type_depth, 0); + CheckAncestor(Object::_type_depth, Object::_type_ancestors.data(), {}); +} + +TEST(StaticTypeInfo, Func) { + EXPECT_EQ(Func::_type_index, + static_cast(TVMFFITypeIndex::kTVMFFIFunc)); + EXPECT_STRCASEEQ(Func::_type_key, "object.Func"); + EXPECT_EQ(Func::_type_depth, 1); + CheckAncestor(Func::_type_depth, Func::_type_ancestors.data(), + {static_cast(TVMFFITypeIndex::kTVMFFIObject)}); +} + +TEST(StaticTypeInfo, Str) { + EXPECT_EQ(Str::_type_index, + static_cast(TVMFFITypeIndex::kTVMFFIStr)); + EXPECT_STRCASEEQ(Str::_type_key, "object.Str"); + EXPECT_EQ(Str::_type_depth, 1); + CheckAncestor(Str::_type_depth, Str::_type_ancestors.data(), + {static_cast(TVMFFITypeIndex::kTVMFFIObject)}); +} + +TEST(StaticTypeInheritance, None) { + Ref obj; + // FIXME: The lines below are going to segfault + // EXPECT_STREQ(obj->GetTypeKey(), "None"); + // EXPECT_FALSE(obj->IsInstance()); + // EXPECT_FALSE(obj->IsInstance()); + // EXPECT_FALSE(obj->IsInstance()); +} + +TEST(StaticTypeInheritance, Object) { + Ref obj = Ref::New(); + EXPECT_STREQ(obj->GetTypeKey(), "object.Object"); + EXPECT_TRUE(obj->IsInstance()); + EXPECT_FALSE(obj->IsInstance()); + EXPECT_FALSE(obj->IsInstance()); +} + +TEST(StaticTypeInheritance, Func_0) { + Ref obj = Ref::New(FuncCall); + EXPECT_STREQ(obj->GetTypeKey(), "object.Func"); + EXPECT_TRUE(obj->IsInstance()); + EXPECT_TRUE(obj->IsInstance()); + EXPECT_FALSE(obj->IsInstance()); +} + +TEST(StaticTypeInheritance, Func_1) { + Ref obj = Ref::New(FuncCall); + EXPECT_STREQ(obj->GetTypeKey(), "object.Func"); + EXPECT_TRUE(obj->IsInstance()); + EXPECT_TRUE(obj->IsInstance()); + EXPECT_FALSE(obj->IsInstance()); +} + +TEST(StaticTypeInheritance, Str_0) { + Ref obj = Ref::New("Hello, World!"); + EXPECT_STREQ(obj->GetTypeKey(), "object.Str"); + EXPECT_TRUE(obj->IsInstance()); + EXPECT_FALSE(obj->IsInstance()); + EXPECT_TRUE(obj->IsInstance()); +} + +TEST(StaticTypeInheritance, Str_1) { + Ref obj = Ref::New("Hello, World!"); + EXPECT_STREQ(obj->GetTypeKey(), "object.Str"); + EXPECT_TRUE(obj->IsInstance()); + EXPECT_FALSE(obj->IsInstance()); + EXPECT_TRUE(obj->IsInstance()); +} + +TEST(StaticTypeSubclass, NoException) { + Ref obj = Ref::New(0); + EXPECT_EQ(obj->data, 0); +} + +TEST(StaticTypeSubclass, Exception) { + try { + Ref::New(1); + FAIL() << "No exception thrown"; + } catch (std::runtime_error &ex) { + EXPECT_STREQ(ex.what(), "New Error"); + } +} + +} // namespace diff --git a/include/tvm/ffi/core/c_ffi_abi.h b/include/tvm/ffi/core/c_ffi_abi.h new file mode 100644 index 000000000000..fadbda220eb7 --- /dev/null +++ b/include/tvm/ffi/core/c_ffi_abi.h @@ -0,0 +1,166 @@ +#ifndef TVM_FFI_C_FFI_ABI_H_ +#define TVM_FFI_C_FFI_ABI_H_ + +#include +#include + +#if !defined(TVM_FFI_API) && defined(__EMSCRIPTEN__) +#include +#define TVM_FFI_API EMSCRIPTEN_KEEPALIVE +#endif +#if !defined(TVM_FFI_API) && defined(_MSC_VER) +#ifdef TVM_FFI_EXPORTS +#define TVM_FFI_API __declspec(dllexport) +#else +#define TVM_FFI_API __declspec(dllimport) +#endif +#endif +#ifndef TVM_FFI_API +#define TVM_FFI_API __attribute__((visibility("default"))) +#endif + +#ifndef TVM_FFI_ALLOW_DYN_TYPE +#define TVM_FFI_ALLOW_DYN_TYPE 0 +#endif + +#ifdef __cplusplus +extern "C" { +#endif + +#ifdef __cplusplus +enum class TVMFFITypeIndex : int32_t { +#else +typedef enum { +#endif + // [Section] On-stack POD Types: [0, kTVMFFIStaticObjectBegin) + // N.B. `kTVMFFIRawStr` is a string backed by a `\0`-terminated char array, + // which is not owned by TVMFFIAny. It is required that the following + // invariant holds: + // - `Any::type_index` is never `kTVMFFIRawStr` + // - `AnyView::type_index` can be `kTVMFFIRawStr` + kTVMFFINone = 0, + kTVMFFIInt = 1, + kTVMFFIFloat = 2, + kTVMFFIPtr = 3, + kTVMFFIDataType = 4, + kTVMFFIDevice = 5, + kTVMFFIRawStr = 6, + // [Section] Static Boxed: [kTVMFFIStaticObjectBegin, kTVMFFIDynObjectBegin) + kTVMFFIStaticObjectBegin = 64, + kTVMFFIObject = 64, + kTVMFFIList = 65, + kTVMFFIDict = 66, + kTVMFFIError = 67, + kTVMFFIFunc = 68, + kTVMFFIStr = 69, + // [Section] Dynamic Boxed: [kTVMFFIDynObjectBegin, +oo) + kTVMFFIDynObjectBegin = 128, +#ifdef __cplusplus +}; +#else +} TypeIndex; +#endif + +struct TVMFFIAny; +typedef TVMFFIAny TVMFFIObject; +typedef TVMFFIObject *TVMFFIObjectHandle; +typedef TVMFFIAny *TVMFFIAnyHandle; + +typedef struct TVMFFIAny { + int32_t type_index; + union { // 4 bytes + int32_t ref_cnt; // reference counter for heap object + int32_t small_len; // length for on-stack object + }; + union { // 8 bytes + int64_t v_int64; // integers + double v_float64; // floating-point numbers + DLDataType v_dtype; // data type + DLDevice v_device; // device + void *v_ptr; // typeless pointers + const char *v_str; // raw string + TVMFFIObjectHandle v_obj; // ref counted objects + void (*deleter)(void *); // Deleter of the object + char v_bytes[8]; // small string + char32_t v_char32[2]; // UCS4 string and Unicode + }; +} TVMFFIAny; + +typedef struct { + int32_t type_index; + const char *type_key; + int32_t type_depth; + int32_t *type_ancestors; +} TVMFFITypeInfo; + +typedef TVMFFITypeInfo *TVMFFITypeInfoHandle; + +typedef struct { + const char *filename; + const char *func; + int32_t lineno; + void (*deleter)(void *); +} TVMFFIStackFrame; + +typedef struct { + int32_t type_index; + int32_t ref_cnt; + void (*deleter)(void *); + const char *kind; + int32_t num_frames; + const char **linenos; + const char *message; +} TVMFFIError; + +typedef struct { + int32_t type_index; + int32_t ref_cnt; + void (*deleter)(void *); + int64_t length; + char *data; +} TVMFFIStr; + +typedef struct { + int32_t type_index; + int32_t ref_cnt; + void (*deleter)(void *); + void (*call)(const void *self, int32_t num_args, const TVMFFIAny *args, TVMFFIAny *ret); + int32_t (*safe_call)(const void *self, int32_t num_args, const TVMFFIAny *args, TVMFFIAny *ret); +} TVMFFIFunc; + +typedef struct { + int32_t type_index; + int32_t ref_cnt; + void (*deleter)(void *); + int64_t list_capacity; + int64_t list_length; + int64_t pool_capacity; + int64_t pool_length; +} TVMFFIList; + +typedef struct { + int32_t type_index; + int32_t ref_cnt; + void (*deleter)(void *); + int64_t capacity; + int64_t size; +} TVMFFIDict; + +#if TVM_FFI_ALLOW_DYN_TYPE +typedef void *TVMFFITypeTableHandle; +TVM_FFI_API void TVMFFIDynTypeIndex2Info(TVMFFITypeTableHandle self, int32_t type_index, + TVMFFITypeInfoHandle *out_type_info); +TVM_FFI_API void TVMFFIDynTypeDef(TVMFFITypeTableHandle self, const char *type_key, + int32_t type_depth, const int32_t *type_ancestors, + int32_t *out_type_index); +TVM_FFI_API void TVMFFIDynTypeSetAttr(TVMFFITypeTableHandle self, int32_t type_index, + const char *key, TVMFFIAnyHandle value); +TVM_FFI_API void TVMFFIDynTypeGetAttr(TVMFFITypeTableHandle self, int32_t type_index, + const char *key, TVMFFIAnyHandle *out_type_attr); +#endif // TVM_FFI_ALLOW_DYN_TYPE + +#ifdef __cplusplus +} // TVM_FFI_EXTERN_C +#endif + +#endif // TVM_FFI_C_FFI_ABI_H_ diff --git a/include/tvm/ffi/core/core.h b/include/tvm/ffi/core/core.h new file mode 100644 index 000000000000..608e3e4c1b89 --- /dev/null +++ b/include/tvm/ffi/core/core.h @@ -0,0 +1,421 @@ +#ifndef TVM_FFI_CORE_H_ +#define TVM_FFI_CORE_H_ + +#include "./c_ffi_abi.h" +#include "./traits.h" +#include "./utils.h" +#include +#include + +namespace tvm { +namespace ffi { + +#define TVM_FFI_DEF_TYPE_FRIENDS() \ + template \ + friend struct ::tvm::ffi::IsAnyWithExtra; \ + template \ + friend struct ::tvm::ffi::Ref; \ + template \ + friend struct ::tvm::ffi::RefBase; \ + template \ + friend struct ::tvm::ffi::TypeTraits; \ + template \ + friend struct ::tvm::ffi::TypeTraitsDefaultForObject; \ + template \ + friend struct ::tvm::ffi::DefaultObjectAllocator + +/********** Section 1. Object *********/ + +#define TVM_FFI_DEF_STATIC_TYPE(SelfType, ParentType, TypeIndex) \ +public: \ + TVM_FFI_DEF_TYPE_FRIENDS(); \ + friend struct ::tvm::ffi::Any; \ + friend struct ::tvm::ffi::AnyView; \ + [[maybe_unused]] static constexpr int32_t _type_index = static_cast(TypeIndex); \ + using _type_parent [[maybe_unused]] = ParentType; \ + [[maybe_unused]] static constexpr int32_t _type_depth = ParentType::_type_depth + 1; \ + [[maybe_unused]] static inline constexpr std::array _type_ancestors = \ + ::tvm::ffi::details::ObjectAncestorsConstExpr(); \ + template \ + TVM_FFI_INLINE bool IsInstance() const { \ + return ::tvm::ffi::details::IsInstanceOf(this->type_index); \ + } \ + TVM_FFI_INLINE const char *GetTypeKey() const { return TypeIndex2TypeKey(this->type_index); } \ + [[maybe_unused]] static constexpr const char *_type_key = TypeIndexTraits::type_key + +#define TVM_FFI_DEF_DYN_TYPE(SelfType, ParentType, TypeKey) \ +public: \ + TVM_FFI_DEF_TYPE_FRIENDS(); \ + friend struct ::tvm::ffi::Any; \ + friend struct ::tvm::ffi::AnyView; \ + using _type_parent [[maybe_unused]] = ParentType; \ + [[maybe_unused]] static constexpr int32_t _type_depth = ParentType::_type_depth + 1; \ + [[maybe_unused]] static inline const std::array _type_ancestors = \ + ::tvm::ffi::details::ObjectAncestors(); \ + [[maybe_unused]] static inline int32_t _type_index = []() -> int32_t { \ + int32_t ret; \ + TVMFFIDynTypeDef(nullptr, TypeKey, _type_depth, _type_ancestors.data(), &ret); \ + return ret; \ + }(); \ + template \ + TVM_FFI_INLINE bool IsInstance() const { \ + return ::tvm::ffi::details::IsInstanceOf(this->type_index); \ + } \ + TVM_FFI_INLINE const char *GetTypeKey() const { return TypeIndex2TypeKey(this->type_index); } \ + [[maybe_unused]] static constexpr const char *_type_key = TypeKey + +struct Object : protected TVMFFIAny { + TVM_FFI_DEF_STATIC_TYPE(Object, details::DummyRoot, TVMFFITypeIndex::kTVMFFIObject); + + TVM_FFI_INLINE Object() : TVMFFIAny() {} + TVM_FFI_INLINE Object(const Object &) : TVMFFIAny() {} + TVM_FFI_INLINE Object(Object &&) {} + TVM_FFI_INLINE Object &operator=(const Object &) { return *this; } + TVM_FFI_INLINE Object &operator=(Object &&) { return *this; } + Ref str() const; + TVM_FFI_INLINE friend std::ostream &operator<<(std::ostream &os, const Object &src) { + TVMFFIAny v{}; + v.type_index = src.type_index; + v.v_obj = const_cast(&src); + details::AnyView2Str(os, &v); + return os; + } +}; + +/********** Section 2. Ref *********/ + +#define TVM_FFI_REF_DEF_DELEGATE_CONSTRUCTORS(SelfType, BaseType) \ + template > \ + TVM_FFI_INLINE Ref(Other &&src) : BaseType(std::forward(src)) {} \ + template > \ + TVM_FFI_DEF_ASSIGN(SelfType, const Other &) + +template +struct RefBase { +private: + using TSelf = RefBase; + using TSub = Ref; + TVM_FFI_DEF_TYPE_FRIENDS(); + friend struct Any; + friend struct AnyView; + template + using EnableAnyOrViewOrRef = std::enable_if_t || details::IsRef>; + template + using EnableDerivedObj = + typename std::enable_if_t || + (std::is_same_v && IsObject)>; + +public: + /***** Factory: the `new` operator *****/ + template + TVM_FFI_INLINE static TSub New(Args &&...args) { + return TSub(Allocator::New(std::forward(args)...)); + } + /***** Accessors *****/ + TVM_FFI_INLINE const Type *get() const { return reinterpret_cast(this->data_); } + TVM_FFI_INLINE Type *get() { return reinterpret_cast(data_); } + TVM_FFI_INLINE const Type *operator->() const { return get(); } + TVM_FFI_INLINE const Type &operator*() const { return *get(); } + TVM_FFI_INLINE Type *operator->() { return get(); } + TVM_FFI_INLINE Type &operator*() { return *get(); } + /***** Misc *****/ + Ref str() const; + TVM_FFI_INLINE friend std::ostream &operator<<(std::ostream &os, const TSelf &src) { + TVMFFIAny v = src.AsTVMFFIAny(); + details::AnyView2Str(os, &v); + return os; + } + template + BaseType *GetRawObjPtr() const { + static_assert(std::is_same_v || std::is_base_of_v, + "Only downcasting is allowed"); + return reinterpret_cast(this->data_); + } + template + BaseType *MoveToRawObjPtr() { + static_assert(std::is_same_v || std::is_base_of_v, + "Only downcasting is allowed"); + BaseType *ret = reinterpret_cast(this->data_); + this->data_ = nullptr; + return ret; + } + +protected: + using Allocator = typename GetAllocator::Type; + /***** Destructor *****/ + TVM_FFI_INLINE ~RefBase() { this->DecRef(); } + /***** Constructor 0: default *****/ + TVM_FFI_INLINE RefBase() : data_(nullptr) {} + /***** Constructor 1: From raw pointers *****/ + TVM_FFI_INLINE RefBase(TVMFFIAny *data) : data_(data) { this->IncRef(); } + template > + TVM_FFI_INLINE RefBase(Derived *data) : TSelf(reinterpret_cast(data)) {} + TVM_FFI_DEF_ASSIGN(TSelf, TVMFFIAny *) + template > + TVM_FFI_DEF_ASSIGN(TSelf, Derived *) + /***** Constructor 2: from RefBase *****/ + TVM_FFI_INLINE RefBase(const TSelf &other) : data_(other.data_) { + this->IncRef(); + } + TVM_FFI_INLINE RefBase(TSelf &&other) : data_(other.data_) { other.data_ = nullptr; } + TVM_FFI_DEF_ASSIGN(TSelf, const TSelf &) + TVM_FFI_DEF_ASSIGN(TSelf, TSelf &&) + /***** Constructor 3: from Ref, Any and AnyView *****/ + template > + TVM_FFI_DEF_ASSIGN(TSub, Other &&) + template > + TVM_FFI_DEF_ASSIGN(TSub, const Other &) + template > + TVM_FFI_INLINE RefBase(const Other &src) : RefBase(src.template GetRawObjPtr()) {} + template > + TVM_FFI_INLINE RefBase(Other &&src) + : data_(reinterpret_cast(src.template MoveToRawObjPtr())) { + if constexpr (std::is_same_v) { + this->IncRef(); + } + } + TVM_FFI_INLINE void IncRef() { details::IncRef(this->data_); } + TVM_FFI_INLINE void DecRef() { details::DecRef(this->data_); } + TVM_FFI_INLINE void Swap(TSelf &other) { std::swap(this->data_, other.data_); } + TVM_FFI_INLINE TVMFFIAny AsTVMFFIAny() const { + if (data_ == nullptr) { + return TVMFFIAny(); + } + TVMFFIAny ret{}; + ret.type_index = data_->type_index; + ret.v_obj = data_; + return ret; + } + + TVMFFIAny *data_; +}; + +template +struct Ref : public RefBase { + TVM_FFI_DEF_TYPE_FRIENDS(); + TVM_FFI_INLINE Ref() : RefBase() {} + TVM_FFI_REF_DEF_DELEGATE_CONSTRUCTORS(Ref, RefBase) +}; + +/********** Section 3. AnyView *********/ + +struct AnyView : public TVMFFIAny { + TVM_FFI_DEF_TYPE_FRIENDS(); + friend struct Any; + /***** Destructor *****/ + TVM_FFI_INLINE ~AnyView() = default; + TVM_FFI_INLINE void Reset() { *(static_cast(this)) = TVMFFIAny(); } + /***** Constructor 0: default *****/ + TVM_FFI_INLINE AnyView() : TVMFFIAny() {} + /***** Constructor 1: from AnyView *****/ + TVM_FFI_INLINE AnyView(const AnyView &src) = default; + TVM_FFI_INLINE AnyView &operator=(const AnyView &src) = default; + TVM_FFI_INLINE AnyView(AnyView &&src) : TVMFFIAny(*&src) { src.Reset(); } + TVM_FFI_DEF_ASSIGN(AnyView, AnyView &&) + /***** Constructor 2: from Any *****/ + TVM_FFI_INLINE AnyView(const Any &src); + TVM_FFI_DEF_ASSIGN(AnyView, const Any &) + AnyView(Any &&src) = delete; + AnyView &operator=(Any &&src) = delete; + /***** Constructor 3: from Ref *****/ + template + TVM_FFI_INLINE AnyView(const Ref &src) : TVMFFIAny(src.AsTVMFFIAny()) {} + template + TVM_FFI_DEF_ASSIGN(AnyView, const Ref &) + template + AnyView(Ref &&src) = delete; + template + AnyView &operator=(Ref &&src) = delete; + /***** Constructors 4: use TypeTraits *****/ + template > + TVM_FFI_INLINE AnyView(const Type &src) : TVMFFIAny() { + TypeTraitsNoCR::CopyFromTypeToTVMFFIAny(src, this); + } + template > + TVM_FFI_DEF_ASSIGN(AnyView, const Type &) + /*** Converter 0: use TypeTraits ***/ + template > + operator Type() const; + template > + Type CastWithStorage(Any *storage) const; + /***** Misc *****/ + Ref str() const; + TVM_FFI_INLINE friend std::ostream &operator<<(std::ostream &os, const AnyView &src) { + details::AnyView2Str(os, &src); + return os; + } + template + TVM_FFI_INLINE Type *GetRawObjPtr() const; + template + TVM_FFI_INLINE Type *MoveToRawObjPtr(); + +protected: + TVM_FFI_INLINE void Swap(TVMFFIAny &src) { + TVMFFIAny tmp = *this; + *static_cast(this) = src; + src = tmp; + } +}; + +/********** Section 4. Any *********/ + +struct Any : public TVMFFIAny { + TVM_FFI_DEF_TYPE_FRIENDS(); + friend struct AnyView; + /***** Destructor *****/ + TVM_FFI_INLINE ~Any() { this->Reset(); } + TVM_FFI_INLINE void Reset() { + this->DecRef(); + *(static_cast(this)) = TVMFFIAny(); + } + /***** Constructor 0: default *****/ + TVM_FFI_INLINE Any() : TVMFFIAny() {} + /***** Constructor 1: from AnyView *****/ + TVM_FFI_INLINE Any(const AnyView &src) : TVMFFIAny(*static_cast(&src)) { + if (this->type_index == static_cast(TVMFFITypeIndex::kTVMFFIRawStr)) { + // Special case: handle the case where `Any` needs to own a raw string. + this->type_index = static_cast(TVMFFITypeIndex::kTVMFFIStr); + this->v_obj = details::StrCopyFromCharArray(this->v_str, std::strlen(this->v_str)); + } + this->IncRef(); + } + TVM_FFI_INLINE Any(AnyView &&src) : Any(static_cast(src)) { // TODO: add test + src.Reset(); + } + TVM_FFI_DEF_ASSIGN(Any, const AnyView &) + TVM_FFI_DEF_ASSIGN(Any, AnyView &&) + /***** Constructor 2: from Any *****/ + TVM_FFI_INLINE Any(const Any &src) : TVMFFIAny(*static_cast(&src)) { + this->IncRef(); + } + TVM_FFI_INLINE Any(Any &&src) : TVMFFIAny(*static_cast(&src)) { + *static_cast(&src) = TVMFFIAny(); + } + TVM_FFI_DEF_ASSIGN(Any, const Any &) + TVM_FFI_DEF_ASSIGN(Any, Any &&) + /***** Constructor 3: from Ref *****/ + template + TVM_FFI_INLINE Any(const Ref &src) : TVMFFIAny(src.AsTVMFFIAny()) { + this->IncRef(); + } + template + TVM_FFI_INLINE Any(Ref &&src) : TVMFFIAny(src.AsTVMFFIAny()) { + src.data_ = nullptr; + } + template + TVM_FFI_DEF_ASSIGN(Any, const Ref &) + template + TVM_FFI_DEF_ASSIGN(Any, Ref &&) + /***** Constructors 4: use TypeTraits *****/ + template > + TVM_FFI_INLINE Any(const Type &src) : Any(AnyView(src)) {} + template > + TVM_FFI_DEF_ASSIGN(Any, const Type &) + /***** Constructors 5: Special handling for strings *****/ + TVM_FFI_INLINE Any(const std::string &s) : Any(AnyView(s)) {} + TVM_FFI_INLINE Any(const char *s) : Any(AnyView(s)) {} + TVM_FFI_INLINE Any(std::string &&s) : TVMFFIAny() { + this->type_index = static_cast(TVMFFITypeIndex::kTVMFFIStr); + this->v_obj = details::StrMoveFromStdString(std::move(s)); + this->IncRef(); + } + template + TVM_FFI_INLINE Any(const CharArray &s) : TVMFFIAny() { + this->type_index = static_cast(TVMFFITypeIndex::kTVMFFIStr); + this->v_obj = details::StrCopyFromCharArray(s, N - 1); + this->IncRef(); + } + TVM_FFI_DEF_ASSIGN(Any, const std::string &) + TVM_FFI_DEF_ASSIGN(Any, const char *&) + TVM_FFI_DEF_ASSIGN(Any, std::string &&) + template + TVM_FFI_DEF_ASSIGN(Any, const CharArray &) + /*** Converter 0: use TypeTraits ***/ + template > + operator Type() const; + /***** Misc *****/ + Ref str() const; + TVM_FFI_INLINE friend std::ostream &operator<<(std::ostream &os, const Any &src) { + details::AnyView2Str(os, &src); + return os; + } + template + TVM_FFI_INLINE Type *GetRawObjPtr() const { + return reinterpret_cast(this)->GetRawObjPtr(); + } + template + TVM_FFI_INLINE Type *MoveToRawObjPtr() { + return reinterpret_cast(this)->MoveToRawObjPtr(); + } + +protected: + TVM_FFI_INLINE void Swap(TVMFFIAny &src) { + TVMFFIAny tmp = *this; + *static_cast(this) = src; + src = tmp; + } + TVM_FFI_INLINE void IncRef() { + if (!details::IsTypeIndexPOD(this->type_index)) { + details::IncRef(this->v_obj); + } + } + TVM_FFI_INLINE void DecRef() { + if (!details::IsTypeIndexPOD(this->type_index)) { + details::DecRef(this->v_obj); + } + } +}; + +/********** Section 5. Type Conversion and Type Table *********/ + +TVM_FFI_INLINE AnyView::AnyView(const Any &src) : TVMFFIAny(*&src) {} + +#define TVM_FFI_TRY_CONVERT(Expr, TypeStr) \ + try { \ + return Expr; \ + } catch (const TemporaryTypeError &) { \ + TVM_FFI_THROW(TypeError) << "Cannot convert from type `" \ + << TypeIndex2TypeKey(this->type_index) << "` to `" << TypeStr << "`"; \ + } \ + TVM_FFI_UNREACHABLE(); +template +inline AnyView::operator Type() const { + TVM_FFI_TRY_CONVERT(TypeTraitsNoCR::CopyFromTVMFFIAnyToType(this), + TypeTraitsNoCR::Type2Str()); +} +template +inline Type AnyView::CastWithStorage(Any *storage) const { + TVM_FFI_TRY_CONVERT(TypeTraitsNoCR::CopyFromTVMFFIAnyToTypeWithStorage(this, storage), + TypeTraitsNoCR::Type2Str()); +} +template +inline Any::operator Type() const { + TVM_FFI_TRY_CONVERT(TypeTraitsNoCR::CopyFromTVMFFIAnyToType(this), + TypeTraitsNoCR::Type2Str()); +} +template +inline Type *AnyView::GetRawObjPtr() const { + TVM_FFI_TRY_CONVERT(TypeTraitsNoCR::CopyFromTVMFFIAnyToRef(this), Type::_type_key); +} +template +inline Type *AnyView::MoveToRawObjPtr() { + TVM_FFI_TRY_CONVERT(TypeTraitsNoCR::MoveFromTVMFFIAnyToRef(this), Type::_type_key); +} +#undef TVM_FFI_TRY_CONVERT + +#if TVM_FFI_ALLOW_DYN_TYPE +TVM_FFI_INLINE void TypeSetAttr(int32_t type_index, const char *attr_key, AnyView attr_value) { + TVMFFIDynTypeSetAttr(nullptr, type_index, attr_key, &attr_value); +} + +TVM_FFI_INLINE AnyView TypeGetAttr(int32_t type_index, const char *attr_key) { + TVMFFIAnyHandle attr; + TVMFFIDynTypeGetAttr(nullptr, type_index, attr_key, &attr); + return AnyView(*(static_cast(attr))); +} +#endif + +} // namespace ffi +} // namespace tvm + +#endif // TVM_FFI_CORE_H_ diff --git a/include/tvm/ffi/core/traits.h b/include/tvm/ffi/core/traits.h new file mode 100644 index 000000000000..8f80b77f49be --- /dev/null +++ b/include/tvm/ffi/core/traits.h @@ -0,0 +1,613 @@ +#ifndef TVM_FFI_TRAITS_H_ +#define TVM_FFI_TRAITS_H_ + +#include "./utils.h" +#include +#include + +namespace tvm { +namespace ffi { +#ifdef _MSC_VER +#pragma warning(push) +#pragma warning(disable : 4702) +#endif + +/*! + * \brief TypeTraits is a template class that provides a set of static + * methods that are associated with a specific type `T` for compile-time + * dispatching. + * + * [Trait 0] Type2Str: () -> std::string + * Returns the string representation of the type `T` + * + * [Trait 1] CopyFromTVMFFIAnyToType: (const TVMFFIAny *) -> T + * Converts an `AnyView` or `Any` to `T`. It could incur copy when + * inevitable, for example, copy the content of a Str obj if it's converted + * to `std::string`. + * + * It is used in the following cases in the codebase: + * 1) `AnyView::operator Type()` + * 2) `Any::operator Type()` + * + * [Trait 2] CopyFromTypeToTVMFFIAny: (T, TVMFFIAny *) -> void + * Converts a value of type T to `AnyView`. + * + * It is used in the following case in the codebase: + * 1) `AnyView::AnyView(const Type& src)` + * + * [Trait 3] CopyFromTVMFFIAnyToRef: (const TVMFFIAny *) -> T* + * Converts an `AnyView` or `Any` to `TVMFFIAny*`, which is subsequently used + * to initialize `Ref::data_`. Note that this method does not increment + * the reference counter, which is supposed to be later handled automatically by + * `Ref`. + * + * It is used in the following cases in the codebase: + * 1) `Ref::Ref(const AnyView& src)` + * 2) `Ref::Ref(const Any& src)` + * + * [Trait 4] MoveFromTVMFFIAnyToRef: (TVMFFIAny *) -> T* + * Moves an `AnyView` or `Any` to `TVMFFIAny*`, which is subsequently used + * to initialize `Ref::data_`. Note that the reference counter will not + * increment throughout the process because it corresponds to move semantics. + * + * It is used in the following cases in the codebase: + * 1) `Ref::Ref(AnyView&& src)` + * 2) `Ref::Ref(Any&& src)` + * + * [Trait 5] (Optional) CopyFromTVMFFIAnyToTypeWithStorage: + * (const TVMFFIAny *, Any *) -> T + * It does similar thing as [Trait 1] that converts an `AnyView` or `Any` to + * `T`, but is onyl used by TVM FFI's calling convention, where an additional + * storage `Any*` is provided to retain ownership when unpacking `AnyView[]` + * before calling into a C++ function. + * + * Example. When converting an `AnyView (kTVMFFIRawStr)` to `Str*`, + * which consists of two intermediate steps `AnyView (kTVMFFIRawStr)` to + * `Ref`, and `Ref` to `Str*`, this method will store the by-product + * `Ref` into the given storage to lifespan expiration. + * + * It is used in the following case(s) in the codebase: + * 1) `AnyView::CastWithStorage(Any *)` + */ +template +struct TypeTraits { + constexpr static bool enabled = false; +}; +template +using TypeTraitsNoCR = TypeTraits>; +template +using HasTypeTraits = std::enable_if_t::enabled>; + +template +struct TypeIndexTraits; + +#define TVM_FFI_DEF_TYPE_INDEX_TRAITS(TypeIndex_, TypeKey) \ + template <> \ + struct TypeIndexTraits { \ + static constexpr const char *type_key = TypeKey; \ + } +TVM_FFI_DEF_TYPE_INDEX_TRAITS(TVMFFITypeIndex::kTVMFFINone, "None"); +TVM_FFI_DEF_TYPE_INDEX_TRAITS(TVMFFITypeIndex::kTVMFFIInt, "int"); +TVM_FFI_DEF_TYPE_INDEX_TRAITS(TVMFFITypeIndex::kTVMFFIFloat, "float"); +TVM_FFI_DEF_TYPE_INDEX_TRAITS(TVMFFITypeIndex::kTVMFFIPtr, "Ptr"); +TVM_FFI_DEF_TYPE_INDEX_TRAITS(TVMFFITypeIndex::kTVMFFIDataType, "dtype"); +TVM_FFI_DEF_TYPE_INDEX_TRAITS(TVMFFITypeIndex::kTVMFFIDevice, "Device"); +TVM_FFI_DEF_TYPE_INDEX_TRAITS(TVMFFITypeIndex::kTVMFFIRawStr, "const char *"); +TVM_FFI_DEF_TYPE_INDEX_TRAITS(TVMFFITypeIndex::kTVMFFIObject, "object.Object"); +TVM_FFI_DEF_TYPE_INDEX_TRAITS(TVMFFITypeIndex::kTVMFFIList, "object.List"); +TVM_FFI_DEF_TYPE_INDEX_TRAITS(TVMFFITypeIndex::kTVMFFIDict, "object.Dict"); +TVM_FFI_DEF_TYPE_INDEX_TRAITS(TVMFFITypeIndex::kTVMFFIError, "object.Error"); +TVM_FFI_DEF_TYPE_INDEX_TRAITS(TVMFFITypeIndex::kTVMFFIFunc, "object.Func"); +TVM_FFI_DEF_TYPE_INDEX_TRAITS(TVMFFITypeIndex::kTVMFFIStr, "object.Str"); +#undef TVM_FFI_DEF_TYPE_INDEX_TRAITS + +TVM_FFI_INLINE const char *TypeIndex2TypeKey(int32_t type_index) { +#define TVM_FFI_TYPE_INDEX_SWITCH_CASE(TypeIndex_) \ + case TypeIndex_: \ + return TypeIndexTraits::type_key; + switch (static_cast(type_index)) { + TVM_FFI_TYPE_INDEX_SWITCH_CASE(TVMFFITypeIndex::kTVMFFINone); + TVM_FFI_TYPE_INDEX_SWITCH_CASE(TVMFFITypeIndex::kTVMFFIInt); + TVM_FFI_TYPE_INDEX_SWITCH_CASE(TVMFFITypeIndex::kTVMFFIFloat); + TVM_FFI_TYPE_INDEX_SWITCH_CASE(TVMFFITypeIndex::kTVMFFIPtr); + TVM_FFI_TYPE_INDEX_SWITCH_CASE(TVMFFITypeIndex::kTVMFFIDataType); + TVM_FFI_TYPE_INDEX_SWITCH_CASE(TVMFFITypeIndex::kTVMFFIDevice); + TVM_FFI_TYPE_INDEX_SWITCH_CASE(TVMFFITypeIndex::kTVMFFIRawStr); + TVM_FFI_TYPE_INDEX_SWITCH_CASE(TVMFFITypeIndex::kTVMFFIObject); + TVM_FFI_TYPE_INDEX_SWITCH_CASE(TVMFFITypeIndex::kTVMFFIList); + TVM_FFI_TYPE_INDEX_SWITCH_CASE(TVMFFITypeIndex::kTVMFFIDict); + TVM_FFI_TYPE_INDEX_SWITCH_CASE(TVMFFITypeIndex::kTVMFFIError); + TVM_FFI_TYPE_INDEX_SWITCH_CASE(TVMFFITypeIndex::kTVMFFIFunc); + TVM_FFI_TYPE_INDEX_SWITCH_CASE(TVMFFITypeIndex::kTVMFFIStr); + default: +#if TVM_FFI_ALLOW_DYN_TYPE + { + TVMFFITypeInfoHandle type_info; + TVMFFIDynTypeIndex2Info(nullptr, type_index, &type_info); + return type_info ? type_info->type_key : "(undefined)"; + } +#else + return "Unknown"; +#endif + } +#undef TVM_FFI_TYPE_INDEX_SWITCH_CASE + TVM_FFI_UNREACHABLE(); +} + +namespace details { +struct DummyRoot { + static constexpr bool _type_is_static = true; + static constexpr int32_t _type_depth = -1; + static constexpr int32_t _type_index = -1; + static constexpr std::array _type_ancestors = {}; +}; +template +TVM_FFI_INLINE void InitTypeTable(F f, TypeTableType *self) { + { + f.template RegisterType(self); + f.template RegisterType(self); + f.template RegisterType(self); + f.template RegisterType(self); + f.template RegisterType(self); + f.template RegisterType(self); + f.template RegisterType(self); + f.template RegisterType(self); + f.template RegisterType(self); + f.template RegisterType(self); + f.template RegisterType(self); + f.template RegisterType(self); + f.template RegisterType(self); + } + { + f.template RegisterStr(self); + f.template RegisterStr(self); + f.template RegisterStr(self); + f.template RegisterStr(self); + f.template RegisterStr(self); + f.template RegisterStr(self); + f.template RegisterStr(self); + f.template RegisterStr(self); + } +} + +template +using Identity = T; + +template