From 02919c75a03cd6f112b33864aa5d88b776ed17e3 Mon Sep 17 00:00:00 2001 From: tqchen Date: Sat, 9 Aug 2025 15:15:20 -0400 Subject: [PATCH] [FFI] Formalize ffi.Module This PR formalizes original runtime::Module into ffi as ffi.Module and cleans the APIs around it. The goal is to stablize the Module API as extra API that can benefit the overall ffi interactions. We also refactors the c++ code that depends on the Module. --- .../app/src/main/jni/tvm_runtime.h | 8 +- apps/cpp_rpc/rpc_env.cc | 2 +- apps/hexagon_launcher/launcher_core.cc | 4 +- apps/ios_rpc/tvmrpc/TVMRuntime.mm | 8 +- ffi/CMakeLists.txt | 4 + ffi/include/tvm/ffi/extra/c_env_api.h | 70 +++++ ffi/include/tvm/ffi/extra/module.h | 224 ++++++++++++++ ffi/include/tvm/ffi/object.h | 13 +- ffi/src/ffi/extra/buffer_stream.h | 127 ++++++++ ffi/src/ffi/extra/library_module.cc | 199 +++++++++++++ .../ffi/extra/library_module_dynamic_lib.cc | 84 ++---- .../ffi/extra/library_module_system_lib.cc | 57 ++-- ffi/src/ffi/extra/module.cc | 139 +++++++++ ffi/src/ffi/extra/module_internal.h | 104 +++++++ include/tvm/runtime/c_backend_api.h | 2 +- include/tvm/runtime/disco/builtin.h | 2 +- include/tvm/runtime/module.h | 279 +----------------- include/tvm/runtime/profiling.h | 4 +- include/tvm/runtime/vm/executable.h | 28 +- include/tvm/runtime/vm/vm.h | 4 +- include/tvm/target/codegen.h | 14 +- .../src/main/java/org/apache/tvm/Module.java | 12 +- python/tvm/contrib/hexagon/tools.py | 2 +- python/tvm/relax/op/nn/nn.py | 12 +- python/tvm/relax/vm_build.py | 2 +- python/tvm/rpc/client.py | 16 +- python/tvm/runtime/disco/session.py | 24 +- python/tvm/runtime/executable.py | 2 +- python/tvm/runtime/module.py | 139 ++++----- python/tvm/testing/usmp.py | 39 --- src/contrib/msc/framework/tensorrt/codegen.cc | 12 +- src/node/structural_hash.cc | 9 +- src/relax/backend/contrib/clml/codegen.cc | 10 +- src/relax/backend/contrib/cublas/codegen.cc | 8 +- src/relax/backend/contrib/cudnn/codegen.cc | 8 +- src/relax/backend/contrib/cutlass/codegen.cc | 13 +- src/relax/backend/contrib/dnnl/codegen.cc | 8 +- src/relax/backend/contrib/hipblas/codegen.cc | 8 +- src/relax/backend/contrib/nnapi/codegen.cc | 8 +- src/relax/backend/contrib/tensorrt/codegen.cc | 10 +- src/relax/backend/vm/codegen_vm.cc | 38 +-- src/relax/backend/vm/exec_builder.cc | 2 +- src/relax/transform/fold_constant.cc | 4 +- src/relax/transform/run_codegen.cc | 10 +- src/runtime/const_loader_module.cc | 61 ++-- src/runtime/const_loader_module.h | 2 +- .../contrib/arm_compute_lib/acl_runtime.cc | 12 +- src/runtime/contrib/bnns/bnns_json_runtime.cc | 10 +- src/runtime/contrib/clml/clml_runtime.cc | 10 +- src/runtime/contrib/coreml/coreml_runtime.h | 11 +- src/runtime/contrib/coreml/coreml_runtime.mm | 24 +- .../contrib/cublas/cublas_json_runtime.cc | 17 +- .../contrib/cudnn/cudnn_json_runtime.cc | 12 +- src/runtime/contrib/dnnl/dnnl_json_runtime.cc | 15 +- .../contrib/edgetpu/edgetpu_runtime.cc | 4 +- .../contrib/hipblas/hipblas_json_runtime.cc | 17 +- src/runtime/contrib/json/json_runtime.h | 30 +- src/runtime/contrib/mrvl/mrvl_hw_runtime.cc | 38 ++- src/runtime/contrib/mrvl/mrvl_runtime.cc | 37 ++- src/runtime/contrib/msc/tensorrt_runtime.cc | 16 +- src/runtime/contrib/nnapi/nnapi_runtime.cc | 8 +- .../contrib/tensorrt/tensorrt_runtime.cc | 12 +- src/runtime/contrib/tflite/tflite_runtime.cc | 11 +- src/runtime/contrib/tflite/tflite_runtime.h | 10 +- src/runtime/cuda/cuda_module.cc | 46 +-- src/runtime/cuda/cuda_module.h | 6 +- src/runtime/device_api.cc | 6 +- src/runtime/disco/builtin.cc | 39 +-- src/runtime/disco/loader.cc | 17 +- src/runtime/hexagon/hexagon_common.cc | 7 +- src/runtime/hexagon/hexagon_module.cc | 24 +- src/runtime/hexagon/hexagon_module.h | 27 +- src/runtime/hexagon/rpc/hexagon/rpc_server.cc | 7 +- .../hexagon/rpc/simulator/rpc_server.cc | 7 +- src/runtime/library_module.cc | 201 ------------- src/runtime/library_module.h | 125 -------- src/runtime/metal/metal_module.h | 8 +- src/runtime/metal/metal_module.mm | 42 +-- src/runtime/module.cc | 143 ++------- src/runtime/opencl/opencl_common.h | 16 +- src/runtime/opencl/opencl_module.cc | 44 +-- src/runtime/opencl/opencl_module.h | 11 +- src/runtime/opencl/opencl_module_spirv.cc | 22 +- src/runtime/profiling.cc | 87 +++--- src/runtime/rocm/rocm_module.cc | 48 +-- src/runtime/rocm/rocm_module.h | 8 +- src/runtime/rpc/rpc_endpoint.cc | 6 +- src/runtime/rpc/rpc_module.cc | 83 +++--- src/runtime/rpc/rpc_pipe_impl.cc | 2 +- src/runtime/rpc/rpc_session.h | 6 +- src/runtime/rpc/rpc_socket_impl.cc | 4 +- src/runtime/static_library.cc | 39 +-- src/runtime/static_library.h | 2 +- src/runtime/vm/executable.cc | 49 ++- src/runtime/vm/ndarray_cache_support.cc | 14 +- src/runtime/vm/vm.cc | 42 +-- src/runtime/vulkan/vulkan_module.cc | 18 +- src/runtime/vulkan/vulkan_module.h | 7 +- src/runtime/vulkan/vulkan_wrapped_func.cc | 16 +- src/runtime/vulkan/vulkan_wrapped_func.h | 14 +- src/support/ffi_testing.cc | 20 +- src/target/codegen.cc | 151 ++++------ src/target/llvm/codegen_amdgpu.cc | 2 +- src/target/llvm/codegen_blob.cc | 6 +- src/target/llvm/codegen_cpu.cc | 11 +- src/target/llvm/codegen_hexagon.cc | 2 +- src/target/llvm/codegen_nvptx.cc | 2 +- src/target/llvm/llvm_module.cc | 87 +++--- src/target/llvm/llvm_module.h | 2 +- src/target/opt/build_cuda_off.cc | 8 +- src/target/opt/build_cuda_on.cc | 2 +- src/target/opt/build_hexagon_off.cc | 7 +- src/target/opt/build_metal_off.cc | 6 +- src/target/opt/build_opencl_off.cc | 13 +- src/target/opt/build_rocm_off.cc | 6 +- src/target/source/codegen_c_host.cc | 10 +- src/target/source/codegen_metal.cc | 2 +- src/target/source/codegen_opencl.cc | 2 +- src/target/source/codegen_source_base.h | 16 +- src/target/source/codegen_webgpu.cc | 22 +- src/target/source/source_module.cc | 90 +++--- src/target/spirv/build_vulkan.cc | 2 +- .../opencl/opencl_compile_to_bin.cc | 8 +- .../codegen/test_target_codegen_aarch64.py | 30 +- .../python/codegen/test_target_codegen_arm.py | 6 +- .../codegen/test_target_codegen_c_host.py | 2 +- .../codegen/test_target_codegen_cross_llvm.py | 6 +- .../codegen/test_target_codegen_cuda.py | 10 +- .../codegen/test_target_codegen_cuda_fp8.py | 6 +- .../codegen/test_target_codegen_hexagon.py | 4 +- .../codegen/test_target_codegen_llvm.py | 16 +- .../codegen/test_target_codegen_llvm_vla.py | 10 +- .../codegen/test_target_codegen_metal.py | 2 +- .../codegen/test_target_codegen_opencl.py | 6 +- .../codegen/test_target_codegen_riscv.py | 2 +- .../codegen/test_target_codegen_vulkan.py | 2 +- .../python/codegen/test_target_codegen_x86.py | 2 +- .../test_benchmark_elemwise_add.py | 2 +- .../test_hexagon/test_benchmark_maxpool2d.py | 2 +- .../contrib/test_hexagon/test_sigmoid.py | 6 +- .../python/contrib/test_hexagon/test_vtcm.py | 2 +- .../ir/test_roundtrip_runtime_module.py | 8 +- .../relax/backend/clml/test_clml_codegen.py | 2 +- tests/python/relax/test_vm_instrument.py | 2 +- .../runtime/test_runtime_module_export.py | 12 +- .../runtime/test_runtime_module_load.py | 10 +- .../runtime/test_runtime_module_property.py | 18 +- tests/python/runtime/test_runtime_rpc.py | 2 +- ...est_tir_transform_inject_ptx_async_copy.py | 2 +- web/emcc/tvmjs_support.cc | 16 +- web/emcc/wasm_runtime.cc | 6 +- web/emcc/webgpu_runtime.cc | 24 +- web/src/runtime.ts | 8 +- 153 files changed, 2131 insertions(+), 1944 deletions(-) create mode 100644 ffi/include/tvm/ffi/extra/c_env_api.h create mode 100644 ffi/include/tvm/ffi/extra/module.h create mode 100644 ffi/src/ffi/extra/buffer_stream.h create mode 100644 ffi/src/ffi/extra/library_module.cc rename src/runtime/dso_library.cc => ffi/src/ffi/extra/library_module_dynamic_lib.cc (51%) rename src/runtime/system_library.cc => ffi/src/ffi/extra/library_module_system_lib.cc (63%) create mode 100644 ffi/src/ffi/extra/module.cc create mode 100644 ffi/src/ffi/extra/module_internal.h delete mode 100644 python/tvm/testing/usmp.py delete mode 100644 src/runtime/library_module.cc delete mode 100644 src/runtime/library_module.h diff --git a/apps/android_rpc/app/src/main/jni/tvm_runtime.h b/apps/android_rpc/app/src/main/jni/tvm_runtime.h index 5255d3f4b10a..feed44498917 100644 --- a/apps/android_rpc/app/src/main/jni/tvm_runtime.h +++ b/apps/android_rpc/app/src/main/jni/tvm_runtime.h @@ -37,6 +37,10 @@ #include "../ffi/src/ffi/container.cc" #include "../ffi/src/ffi/dtype.cc" #include "../ffi/src/ffi/error.cc" +#include "../ffi/src/ffi/extra/library_module.cc" +#include "../ffi/src/ffi/extra/library_module_dynamic_lib.cc" +#include "../ffi/src/ffi/extra/library_module_system_lib.cc" +#include "../ffi/src/ffi/extra/module.cc" #include "../ffi/src/ffi/function.cc" #include "../ffi/src/ffi/ndarray.cc" #include "../ffi/src/ffi/object.cc" @@ -44,13 +48,10 @@ #include "../ffi/src/ffi/traceback.cc" #include "../src/runtime/cpu_device_api.cc" #include "../src/runtime/device_api.cc" -#include "../src/runtime/dso_library.cc" #include "../src/runtime/file_utils.cc" -#include "../src/runtime/library_module.cc" #include "../src/runtime/logging.cc" #include "../src/runtime/memory/memory_manager.cc" #include "../src/runtime/minrpc/minrpc_logger.cc" -#include "../src/runtime/module.cc" #include "../src/runtime/ndarray.cc" #include "../src/runtime/profiling.cc" #include "../src/runtime/registry.cc" @@ -62,7 +63,6 @@ #include "../src/runtime/rpc/rpc_server_env.cc" #include "../src/runtime/rpc/rpc_session.cc" #include "../src/runtime/rpc/rpc_socket_impl.cc" -#include "../src/runtime/system_library.cc" #include "../src/runtime/thread_pool.cc" #include "../src/runtime/threading_backend.cc" #include "../src/runtime/workspace_pool.cc" diff --git a/apps/cpp_rpc/rpc_env.cc b/apps/cpp_rpc/rpc_env.cc index e5a5154acbf2..c4a43dc9f39f 100644 --- a/apps/cpp_rpc/rpc_env.cc +++ b/apps/cpp_rpc/rpc_env.cc @@ -147,7 +147,7 @@ RPCEnv::RPCEnv(const std::string& wd) { std::string file_name = this->GetPath(path); file_name = BuildSharedLibrary(file_name); LOG(INFO) << "Load module from " << file_name << " ..."; - return Module::LoadFromFile(file_name, ""); + return ffi::Module::LoadFromFile(file_name); })); ffi::Function::SetGlobal("tvm.rpc.server.download_linked_module", diff --git a/apps/hexagon_launcher/launcher_core.cc b/apps/hexagon_launcher/launcher_core.cc index 56242082cca3..fa2c3d8e3300 100644 --- a/apps/hexagon_launcher/launcher_core.cc +++ b/apps/hexagon_launcher/launcher_core.cc @@ -144,7 +144,7 @@ const tvm::ffi::Function get_runtime_func(const std::string& name) { } const tvm::ffi::Function get_module_func(tvm::runtime::Module module, const std::string& name) { - return module.GetFunction(name, false); + return module->GetFunction(name, false).value_or(tvm::ffi::Function()); } void reset_device_api() { @@ -153,7 +153,7 @@ void reset_device_api() { } tvm::runtime::Module load_module(const std::string& file_name) { - static const tvm::ffi::Function loader = get_runtime_func("runtime.module.loadfile_hexagon"); + static const tvm::ffi::Function loader = get_runtime_func("ffi.Module.load_from_file.hexagon"); tvm::ffi::Any rv = loader(file_name); if (rv.type_code() == kTVMModuleHandle) { ICHECK_EQ(rv.type_code(), kTVMModuleHandle) diff --git a/apps/ios_rpc/tvmrpc/TVMRuntime.mm b/apps/ios_rpc/tvmrpc/TVMRuntime.mm index c6f62515736c..09ee55390959 100644 --- a/apps/ios_rpc/tvmrpc/TVMRuntime.mm +++ b/apps/ios_rpc/tvmrpc/TVMRuntime.mm @@ -33,7 +33,7 @@ #if defined(USE_CUSTOM_DSO_LOADER) && USE_CUSTOM_DSO_LOADER == 1 // internal TVM header to achieve Library class -#include <../../../src/runtime/library_module.h> +#include <../../../ffi/src/ffi/extra/library_module.h> #include #endif @@ -70,7 +70,7 @@ void LogMessageImpl(const std::string& file, int lineno, int level, const std::s NSBundle* bundle = [NSBundle mainBundle]; base = [[bundle privateFrameworksPath] stringByAppendingPathComponent:@"tvm"]; - if (tvm::ffi::Function::GetGlobal("runtime.module.loadfile_dylib_custom")) { + if (tvm::ffi::Function::GetGlobal("ffi.Module.load_from_file.dylib_custom")) { // Custom dso loader is present. Will use it. base = NSTemporaryDirectory(); fmt = "dylib_custom"; @@ -114,11 +114,11 @@ void Init(const std::string& name) { // Add UnsignedDSOLoader plugin in global registry TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def_packed("runtime.module.loadfile_dylib_custom", + refl::GlobalDef().def_packed("ffi.Module.load_from_file.dylib_custom", [](ffi::PackedArgs args, ffi::Any* rv) { auto n = make_object(); n->Init(args[0]); - *rv = CreateModuleFromLibrary(n); + *rv = tvm::ffi::CreateLibraryModule(n); }); }); diff --git a/ffi/CMakeLists.txt b/ffi/CMakeLists.txt index af9943476e3d..ce4f4d4e208a 100644 --- a/ffi/CMakeLists.txt +++ b/ffi/CMakeLists.txt @@ -69,6 +69,10 @@ if (TVM_FFI_USE_EXTRA_CXX_API) "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/extra/json_writer.cc" "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/extra/serialization.cc" "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/extra/reflection_extra.cc" + "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/extra/module.cc" + "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/extra/library_module.cc" + "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/extra/library_module_system_lib.cc" + "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/extra/library_module_dynamic_lib.cc" ) endif() diff --git a/ffi/include/tvm/ffi/extra/c_env_api.h b/ffi/include/tvm/ffi/extra/c_env_api.h new file mode 100644 index 000000000000..5d5d908f78ba --- /dev/null +++ b/ffi/include/tvm/ffi/extra/c_env_api.h @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file tvm/ffi/extra/c_env_api.h + * \brief Extra environment API. + */ +#ifndef TVM_FFI_EXTRA_C_ENV_API_H_ +#define TVM_FFI_EXTRA_C_ENV_API_H_ + +#include + +#ifdef __cplusplus +extern "C" { +#endif + +/*! + * \brief FFI function to lookup a function from a module's imports. + * + * This is a helper function that is used by generated code. + * + * \param library_ctx The library context module handle. + * \param func_name The name of the function. + * \param out The result function. + * \note The returned function is a weak reference that is cached/owned by the module. + * \return 0 when no error is thrown, -1 when failure happens + */ +TVM_FFI_DLL int TVMFFIEnvLookupFromImports(TVMFFIObjectHandle library_ctx, const char* func_name, + TVMFFIObjectHandle* out); + +/* + * \brief Register a symbol value that will be initialized when a library with the symbol is loaded. + * + * This function can be used to make context functions to be available in the library + * module that wants to avoid an explicit link dependency + * + * \param name The name of the symbol. + * \param symbol The symbol to register. + * \return 0 when success, nonzero when failure happens + */ +TVM_FFI_DLL int TVMFFIEnvRegisterContextSymbol(const char* name, void* symbol); + +/*! + * \brief Register a symbol that will be initialized when a system library is loaded. + * + * \param name The name of the symbol. + * \param symbol The symbol to register. + * \return 0 when success, nonzero when failure happens + */ +TVM_FFI_DLL int TVMFFIEnvRegisterSystemLibSymbol(const char* name, void* symbol); + +#ifdef __cplusplus +} // extern "C" +#endif +#endif // TVM_FFI_EXTRA_C_ENV_API_H_ diff --git a/ffi/include/tvm/ffi/extra/module.h b/ffi/include/tvm/ffi/extra/module.h new file mode 100644 index 000000000000..f220c582a91f --- /dev/null +++ b/ffi/include/tvm/ffi/extra/module.h @@ -0,0 +1,224 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file tvm/ffi/module.h + * \brief A managed dynamic module in the TVM FFI. + */ +#ifndef TVM_FFI_EXTRA_MODULE_H_ +#define TVM_FFI_EXTRA_MODULE_H_ + +#include +#include +#include +#include + +namespace tvm { +namespace ffi { + +// forward declare Module +class Module; + +/*! + * \brief A module that can dynamically load ffi::Functions or exportable source code. + */ +class TVM_FFI_EXTRA_CXX_API ModuleObj : public Object { + public: + /*! + * \return The per module type key. + * \note This key is used to for serializing custom modules. + */ + virtual const char* kind() const = 0; + /*! + * \brief Get the property mask of the module. + * \return The property mask of the module. + * + * \sa Module::ModulePropertyMask + */ + virtual int GetPropertyMask() const { return 0b000; } + /*! + * \brief Get a ffi::Function from the module. + * \param name The name of the function. + * \return The function. + */ + virtual Optional GetFunction(const String& name) = 0; + /*! + * \brief Returns true if this module has a definition for a function of \p name. + * + * Note that even if this function returns true the corresponding \p GetFunction result + * may be nullptr if the function is not yet callable without further compilation. + * + * The default implementation just checks if \p GetFunction is non-null. + * \param name The name of the function. + * \return True if the module implements the function, false otherwise. + */ + virtual bool ImplementsFunction(const String& name) { return GetFunction(name).defined(); } + /*! + * \brief Write the current module to file with given format (for further compilation). + * + * \param file_name The file to be saved to. + * \param format The format of the file. + * + * \note This function is mainly used by modules that + */ + virtual void WriteToFile(const String& file_name, const String& format) const { + TVM_FFI_THROW(RuntimeError) << "Module[" << kind() << "] does not support WriteToFile"; + } + /*! + * \brief Get the possible write formats of the module, when available. + * \return Possible write formats when available. + */ + virtual Array GetWriteFormats() const { return Array(); } + /*! + * \brief Serialize the the module to bytes. + * \return The serialized module. + */ + virtual Bytes SaveToBytes() const { + TVM_FFI_THROW(RuntimeError) << "Module[" << kind() << "] does not support SaveToBytes"; + TVM_FFI_UNREACHABLE(); + } + /*! + * \brief Get the source code of module, when available. + * \param format Format of the source code, can be empty by default. + * \return Possible source code when available, or empty string if not available. + */ + virtual String InspectSource(const String& format = "") const { return String(); } + /*! + * \brief Import another module. + * \param other The module to import. + */ + virtual void ImportModule(const Module& other); + /*! + * \brief Clear all imported modules. + */ + virtual void ClearImports(); + /*! + * \brief Overloaded function to optionally query from imports. + * \param name The name of the function. + * \param query_imports Whether to query imported modules. + * \return The function. + */ + Optional GetFunction(const String& name, bool query_imports); + /*! + * \brief Overloaded function to optionally query from imports. + * \param name The name of the function. + * \param query_imports Whether to query imported modules. + * \return True if the module implements the function, false otherwise. + */ + bool ImplementsFunction(const String& name, bool query_imports); + /*! + * \brief Get the imports of the module. + * \return The imports of the module. + * \note Note the signature is not part of the public API. + */ + const Array& imports() const { return this->imports_; } + + struct InternalUnsafe; + + static constexpr const int32_t _type_index = TypeIndex::kTVMFFIModule; + static constexpr const char* _type_key = StaticTypeKey::kTVMFFIModule; + static const constexpr bool _type_final = true; + TVM_FFI_DECLARE_STATIC_OBJECT_INFO(ModuleObj, Object); + + protected: + friend struct InternalUnsafe; + + /*! + * \brief The modules that this module depends on. + * \note Use ObjectRef to avoid circular dep on Module. + */ + Array imports_; + + private: + /*! + * \brief cache used by TVMFFIModuleLookupFromImports + */ + Map import_lookup_cache_; +}; + +/*! + * \brief Reference to module object. + */ +class Module : public ObjectRef { + public: + /*! + * \brief Property of ffi::Module + */ + enum ModulePropertyMask : int { + /*! + * \brief The module can be serialized to bytes. + * + * This prooperty indicates that module implements SaveToBytes. + * The system also registers a GlobalDef function + * `ffi.Module.load_from_bytes.` with signature (Bytes) -> Module. + */ + kBinarySerializable = 0b001, + /*! + * \brief The module can directly get runnable functions. + * + * This property indicates that module implements GetFunction that returns + * runnable ffi::Functions. + */ + kRunnable = 0b010, + /*! + * \brief The module can be exported to a object file or source file that then be compiled. + * + * This property indicates that module implements WriteToFile with a given format + * that can be queried by GetLibExportFormat. + * + * Examples include modules that can be exported to .o, .cc, .cu files. + * + * Such modules can be exported, compiled and loaded back as a dynamic library module. + */ + kCompilationExportable = 0b100 + }; + + /*! + * \brief Load a module from file. + * \param file_name The name of the host function module. + * \param format The format of the file. + * \note This function won't load the import relationship. + * Re-create import relationship by calling Import. + */ + TVM_FFI_EXTRA_CXX_API static Module LoadFromFile(const String& file_name); + /* + * \brief Query context symbols that is registered via TVMEnvRegisterSymbols. + * \param callback The callback to be called with the symbol name and address. + * \note This helper can be used to implement custom Module that needs to access context symbols. + */ + TVM_FFI_EXTRA_CXX_API static void VisitContextSymbols( + const ffi::TypedFunction& callback); + + TVM_FFI_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(Module, ObjectRef, ModuleObj); +}; + +/* + * \brief Symbols for library module. + */ +namespace symbol { +/*! \brief Global variable to store context pointer for a library module. */ +constexpr const char* tvm_ffi_library_ctx = "__tvm_ffi_library_ctx"; +/*! \brief Global variable to store binary data alongside a library module. */ +constexpr const char* tvm_ffi_library_bin = "__tvm_ffi_library_bin"; +/*! \brief Default entry function of a library module. */ +constexpr const char* tvm_ffi_main = "__tvm_ffi_main__"; +} // namespace symbol +} // namespace ffi +} // namespace tvm + +#endif // TVM_FFI_EXTRA_MODULE_H_ diff --git a/ffi/include/tvm/ffi/object.h b/ffi/include/tvm/ffi/object.h index 4b7b56209af5..abf7f489038b 100644 --- a/ffi/include/tvm/ffi/object.h +++ b/ffi/include/tvm/ffi/object.h @@ -52,6 +52,8 @@ struct StaticTypeKey { static constexpr const char* kTVMFFIRawStr = "const char*"; static constexpr const char* kTVMFFIByteArrayPtr = "TVMFFIByteArray*"; static constexpr const char* kTVMFFIObjectRValueRef = "ObjectRValueRef"; + static constexpr const char* kTVMFFISmallStr = "ffi.SmallStr"; + static constexpr const char* kTVMFFISmallBytes = "ffi.SmallBytes"; static constexpr const char* kTVMFFIBytes = "ffi.Bytes"; static constexpr const char* kTVMFFIStr = "ffi.String"; static constexpr const char* kTVMFFIShape = "ffi.Shape"; @@ -60,8 +62,7 @@ struct StaticTypeKey { static constexpr const char* kTVMFFIFunction = "ffi.Function"; static constexpr const char* kTVMFFIArray = "ffi.Array"; static constexpr const char* kTVMFFIMap = "ffi.Map"; - static constexpr const char* kTVMFFISmallStr = "ffi.SmallStr"; - static constexpr const char* kTVMFFISmallBytes = "ffi.SmallBytes"; + static constexpr const char* kTVMFFIModule = "ffi.Module"; }; /*! @@ -671,10 +672,10 @@ struct ObjectPtrEqual { */ #define TVM_FFI_DEFINE_MUTABLE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName) \ TypeName() = default; \ - TVM_FFI_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName); \ + TVM_FFI_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName) \ explicit TypeName(::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n) : ParentType(n) {} \ ObjectName* operator->() const { return static_cast(data_.get()); } \ - using ContainerType = ObjectName; + using ContainerType = ObjectName /* * \brief Define object reference methods that is both not nullable and mutable. @@ -685,11 +686,11 @@ struct ObjectPtrEqual { */ #define TVM_FFI_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName) \ explicit TypeName(::tvm::ffi::ObjectPtr<::tvm::ffi::Object> n) : ParentType(n) {} \ - TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName); \ + TVM_FFI_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName) \ ObjectName* operator->() const { return static_cast(data_.get()); } \ ObjectName* get() const { return operator->(); } \ static constexpr bool _type_is_nullable = false; \ - using ContainerType = ObjectName; + using ContainerType = ObjectName namespace details { template diff --git a/ffi/src/ffi/extra/buffer_stream.h b/ffi/src/ffi/extra/buffer_stream.h new file mode 100644 index 000000000000..f6f162676607 --- /dev/null +++ b/ffi/src/ffi/extra/buffer_stream.h @@ -0,0 +1,127 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file buffer_stream.h + * \brief Internal minimal stream helper to read from a buffer. + */ +#ifndef TVM_FFI_EXTRA_BUFFER_STREAM_H_ +#define TVM_FFI_EXTRA_BUFFER_STREAM_H_ + +#include +#include +#include +#include +#include + +namespace tvm { +namespace ffi { + +/*! + * \brief Lightweight stream helper to read from a buffer. + */ +class BufferInStream { + public: + /*! + * \brief constructor + * \param p_buffer the head pointer of the memory region. + * \param buffer_size the size of the memorybuffer + */ + BufferInStream(const void* data, size_t size) + : data_(reinterpret_cast(data)), size_(size) {} + /*! + * \brief Reads raw from stream. + * \param ptr pointer to the data to be read + * \param size the size of the data to be read + * \return the number of bytes read + */ + size_t Read(void* ptr, size_t size) { + size_t nread = std::min(size_ - curr_ptr_, size); + if (nread != 0) std::memcpy(ptr, data_ + curr_ptr_, nread); + curr_ptr_ += nread; + return nread; + } + /*! + * \brief Reads arithmetic data from stream in endian-aware manner. + * \param data data to be read + * \tparam T the data type to be read + * \return whether the read was successful + */ + template >> + bool Read(T* data) { + bool ret = Read(static_cast(data), sizeof(T)) == sizeof(T); // NOLINT(*) + if (!TVM_FFI_IO_NO_ENDIAN_SWAP) { + ByteSwap(&data, sizeof(T), 1); + } + return ret; + } + /*! + * \brief Reads an array of data from stream in endian-aware manner. + * \param data data to be read + * \param size the size of the data to be read + * \return whether the read was successful + */ + template >> + bool ReadArray(T* data, size_t size) { + bool ret = + this->Read(static_cast(data), sizeof(T) * size) == sizeof(T) * size; // NOLINT(*) + if (!TVM_FFI_IO_NO_ENDIAN_SWAP) { + ByteSwap(data, sizeof(T), size); + } + return ret; + } + /*! + * \brief Reads a string from stream. + * \param data data to be read + * \return whether the read was successful + */ + bool Read(std::string* data) { + // use uint64_t to ensure platform independent size + uint64_t size = 0; + if (!this->Read(&size)) return false; + data->resize(size); + if (!this->Read(data->data(), size)) return false; + return true; + } + /*! + * \brief Reads a vector of data from stream in endian-aware manner. + * \param data data to be read + * \return whether the read was successful + */ + template >> + bool Read(std::vector* data) { + uint64_t size = 0; + if (!this->Read(&size)) return false; + data->resize(size); + return this->ReadArray(data->data(), size); + } + + private: + /*! \brief in memory buffer */ + const char* data_; + /*! \brief size of the buffer */ + size_t size_; + /*! \brief current pointer */ + size_t curr_ptr_{0}; +}; // class BytesInStream + +} // namespace ffi +} // namespace tvm + +#endif // TVM_FFI_EXTRA_BUFFER_STREAM_H_ diff --git a/ffi/src/ffi/extra/library_module.cc b/ffi/src/ffi/extra/library_module.cc new file mode 100644 index 000000000000..34286d6d0eb2 --- /dev/null +++ b/ffi/src/ffi/extra/library_module.cc @@ -0,0 +1,199 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/* + * \file src/ffi/extra/library_module.cc + * + * \brief Library module implementation. + */ +#include +#include +#include + +#include "buffer_stream.h" +#include "module_internal.h" + +namespace tvm { +namespace ffi { + +class LibraryModuleObj final : public ModuleObj { + public: + explicit LibraryModuleObj(ObjectPtr lib) : lib_(lib) {} + + const char* kind() const final { return "library"; } + + /*! \brief Get the property of the runtime module .*/ + int GetPropertyMask() const final { return Module::kBinarySerializable | Module::kRunnable; }; + + Optional GetFunction(const String& name) final { + TVMFFISafeCallType faddr; + faddr = reinterpret_cast(lib_->GetSymbol(name.c_str())); + // ensure the function keeps the Library Module alive + Module self_strong_ref = GetRef(this); + if (faddr != nullptr) { + return ffi::Function::FromPacked([faddr, self_strong_ref](ffi::PackedArgs args, + ffi::Any* rv) { + TVM_FFI_ICHECK_LT(rv->type_index(), ffi::TypeIndex::kTVMFFIStaticObjectBegin); + TVM_FFI_CHECK_SAFE_CALL((*faddr)(nullptr, reinterpret_cast(args.data()), + args.size(), reinterpret_cast(rv))); + }); + } + return std::nullopt; + } + + private: + ObjectPtr lib_; +}; + +Module LoadModuleFromBytes(const std::string& kind, const Bytes& bytes) { + std::string loader_key = "ffi.Module.load_from_bytes." + kind; + const auto floader = tvm::ffi::Function::GetGlobal(loader_key); + if (!floader.has_value()) { + TVM_FFI_THROW(RuntimeError) << "Library binary was created using {" << kind + << "} but a loader of that name is not registered. " + << "Make sure to have runtime that registers " << loader_key; + } + return (*floader)(bytes).cast(); +} + +/*! + * \brief Process libary binary to recover binary-serialized modules + * \param library_bin The binary embedded in the library. + * \param opt_lib The library, can be nullptr in which case we expect to deserialize + * all binary-serialized modules + * \param library_ctx_addr the pointer to library module as ctx addr + * \return the root module + * + */ +Module ProcessLibraryBin(const char* library_bin, ObjectPtr opt_lib, + void** library_ctx_addr = nullptr) { + // Layout of the library binary: + // ... + // key can be: "_lib", or a module kind + // - "_lib" indicate this location places the library module + // - other keys are module kinds + // Import tree structure (CSR structure of child indices): + // = > > + TVM_FFI_ICHECK(library_bin != nullptr); + uint64_t nbytes = 0; + for (size_t i = 0; i < sizeof(nbytes); ++i) { + uint64_t c = library_bin[i]; + nbytes |= (c & 0xffUL) << (i * 8); + } + + BufferInStream stream(library_bin + sizeof(nbytes), static_cast(nbytes)); + std::vector import_tree_indptr; + std::vector import_tree_child_indices; + TVM_FFI_ICHECK(stream.Read(&import_tree_indptr)); + TVM_FFI_ICHECK(stream.Read(&import_tree_child_indices)); + size_t num_modules = import_tree_indptr.size() - 1; + std::vector modules; + modules.reserve(num_modules); + + for (uint64_t i = 0; i < num_modules; ++i) { + std::string kind; + TVM_FFI_ICHECK(stream.Read(&kind)); + // "_lib" serves as a placeholder in the module import tree to indicate where + // to place the DSOModule + if (kind == "_lib") { + TVM_FFI_ICHECK(opt_lib != nullptr) << "_lib is not allowed during module serialization"; + auto lib_mod_ptr = make_object(opt_lib); + if (library_ctx_addr) { + *library_ctx_addr = lib_mod_ptr.get(); + } + modules.emplace_back(Module(lib_mod_ptr)); + } else { + std::string module_bytes; + TVM_FFI_ICHECK(stream.Read(&module_bytes)); + Module m = LoadModuleFromBytes(kind, Bytes(module_bytes)); + modules.emplace_back(m); + } + } + for (size_t i = 0; i < modules.size(); ++i) { + for (size_t j = import_tree_indptr[i]; j < import_tree_indptr[i + 1]; ++j) { + Array* module_imports = ModuleObj::InternalUnsafe::GetImports(modules[i].operator->()); + auto child_index = import_tree_child_indices[j]; + TVM_FFI_ICHECK(child_index < modules.size()); + module_imports->emplace_back(modules[child_index]); + } + } + return modules[0]; +} + +// registry to store context symbols +class ContextSymbolRegistry { + public: + void InitContextSymbols(ObjectPtr lib) { + for (const auto& [name, symbol] : context_symbols_) { + if (void** symbol_addr = reinterpret_cast(lib->GetSymbol(name.c_str()))) { + *symbol_addr = symbol; + } + } + } + + void VisitContextSymbols(const ffi::TypedFunction& callback) { + for (const auto& [name, symbol] : context_symbols_) { + callback(name, symbol); + } + } + + void Register(String name, void* symbol) { context_symbols_.emplace_back(name, symbol); } + + static ContextSymbolRegistry* Global() { + static ContextSymbolRegistry* inst = new ContextSymbolRegistry(); + return inst; + } + + private: + std::vector> context_symbols_; +}; + +void Module::VisitContextSymbols(const ffi::TypedFunction& callback) { + ContextSymbolRegistry::Global()->VisitContextSymbols(callback); +} + +Module CreateLibraryModule(ObjectPtr lib) { + const char* library_bin = + reinterpret_cast(lib->GetSymbol(ffi::symbol::tvm_ffi_library_bin)); + void** library_ctx_addr = + reinterpret_cast(lib->GetSymbol(ffi::symbol::tvm_ffi_library_ctx)); + + ContextSymbolRegistry::Global()->InitContextSymbols(lib); + if (library_bin != nullptr) { + // we have embedded binaries that needs to be deserialized + return ProcessLibraryBin(library_bin, lib, library_ctx_addr); + } else { + // Only have one single DSO Module + auto lib_mod_ptr = make_object(lib); + Module root_mod = Module(lib_mod_ptr); + if (library_ctx_addr) { + *library_ctx_addr = root_mod.operator->(); + } + return root_mod; + } +} + +} // namespace ffi +} // namespace tvm + +int TVMFFIEnvRegisterContextSymbol(const char* name, void* symbol) { + TVM_FFI_SAFE_CALL_BEGIN(); + tvm::ffi::String s_name(name); + tvm::ffi::ContextSymbolRegistry::Global()->Register(s_name, symbol); + TVM_FFI_SAFE_CALL_END(); +} diff --git a/src/runtime/dso_library.cc b/ffi/src/ffi/extra/library_module_dynamic_lib.cc similarity index 51% rename from src/runtime/dso_library.cc rename to ffi/src/ffi/extra/library_module_dynamic_lib.cc index 0d74fc87a0fe..25463a7e5f92 100644 --- a/src/runtime/dso_library.cc +++ b/ffi/src/ffi/extra/library_module_dynamic_lib.cc @@ -18,15 +18,14 @@ */ /*! - * \file dso_libary.cc + * \file library_module_dynamic_lib.cc * \brief Create library module to load from dynamic shared library. */ -#include #include #include -#include +#include -#include "library_module.h" +#include "module_internal.h" #if defined(_WIN32) #include @@ -41,46 +40,21 @@ extern "C" { #endif namespace tvm { -namespace runtime { +namespace ffi { -/*! - * \brief Dynamic shared library object used to load - * and retrieve symbols by name. This is the default - * module TVM uses for host-side AOT compilation. - */ class DSOLibrary final : public Library { public: - ~DSOLibrary(); - /*! - * \brief Initialize by loading and storing - * a handle to the underlying shared library. - * \param name The string name/path to the - * shared library over which to initialize. - */ - void Init(const std::string& name); - /*! - * \brief Returns the symbol address within - * the shared library for a given symbol name. - * \param name The name of the symbol. - * \return The symbol. - */ - void* GetSymbol(const char* name) final; + explicit DSOLibrary(const String& name) { Load(name); } + ~DSOLibrary() { + if (lib_handle_) Unload(); + } + + void* GetSymbol(const char* name) final { return GetSymbol_(name); } private: - /*! \brief Private implementation of symbol lookup. - * Implementation is operating system dependent. - * \param The name of the symbol. - * \return The symbol. - */ + // private system dependent implementation void* GetSymbol_(const char* name); - /*! \brief Implementation of shared library load. - * Implementation is operating system dependent. - * \param The name/path of the shared library. - */ - void Load(const std::string& name); - /*! \brief Implementation of shared library unload. - * Implementation is operating system dependent. - */ + void Load(const String& name); void Unload(); #if defined(_WIN32) @@ -92,25 +66,17 @@ class DSOLibrary final : public Library { #endif }; -DSOLibrary::~DSOLibrary() { - if (lib_handle_) Unload(); -} - -void DSOLibrary::Init(const std::string& name) { Load(name); } - -void* DSOLibrary::GetSymbol(const char* name) { return GetSymbol_(name); } - #if defined(_WIN32) void* DSOLibrary::GetSymbol_(const char* name) { return reinterpret_cast(GetProcAddress(lib_handle_, (LPCSTR)name)); // NOLINT(*) } -void DSOLibrary::Load(const std::string& name) { +void DSOLibrary::Load(const String& name) { // use wstring version that is needed by LLVM. - std::wstring wname(name.begin(), name.end()); + std::wstring wname(name.data(), name.data() + name.size()); lib_handle_ = LoadLibraryW(wname.c_str()); - ICHECK(lib_handle_ != nullptr) << "Failed to load dynamic shared library " << name; + TVM_FFI_ICHECK(lib_handle_ != nullptr) << "Failed to load dynamic shared library " << name; } void DSOLibrary::Unload() { @@ -120,10 +86,10 @@ void DSOLibrary::Unload() { #else -void DSOLibrary::Load(const std::string& name) { +void DSOLibrary::Load(const String& name) { lib_handle_ = dlopen(name.c_str(), RTLD_LAZY | RTLD_LOCAL); - ICHECK(lib_handle_ != nullptr) << "Failed to load dynamic shared library " << name << " " - << dlerror(); + TVM_FFI_ICHECK(lib_handle_ != nullptr) + << "Failed to load dynamic shared library " << name << " " << dlerror(); #if defined(__hexagon__) int p; int rc = dlinfo(lib_handle_, RTLD_DI_LOAD_ADDR, &p); @@ -140,21 +106,13 @@ void DSOLibrary::Unload() { dlclose(lib_handle_); lib_handle_ = nullptr; } - #endif -ObjectPtr CreateDSOLibraryObject(std::string library_path) { - auto n = make_object(); - n->Init(library_path); - return n; -} - TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("runtime.module.loadfile_so", [](std::string library_path, std::string) { - ObjectPtr n = CreateDSOLibraryObject(library_path); - return CreateModuleFromLibrary(n); + refl::GlobalDef().def("ffi.Module.load_from_file.so", [](String library_path, String) { + return CreateLibraryModule(make_object(library_path)); }); }); -} // namespace runtime +} // namespace ffi } // namespace tvm diff --git a/src/runtime/system_library.cc b/ffi/src/ffi/extra/library_module_system_lib.cc similarity index 63% rename from src/runtime/system_library.cc rename to ffi/src/ffi/extra/library_module_system_lib.cc index 65df96f96375..64b95a122d56 100644 --- a/src/runtime/system_library.cc +++ b/ffi/src/ffi/extra/library_module_system_lib.cc @@ -21,35 +21,34 @@ * \file system_library.cc * \brief Create library module that directly get symbol from the system lib. */ -#include +#include +#include #include #include -#include +#include #include -#include "library_module.h" +#include "module_internal.h" namespace tvm { -namespace runtime { +namespace ffi { class SystemLibSymbolRegistry { public: void RegisterSymbol(const std::string& name, void* ptr) { - std::lock_guard lock(mutex_); auto it = symbol_table_.find(name); - if (it != symbol_table_.end() && ptr != it->second) { - LOG(WARNING) << "SystemLib symbol " << name << " get overriden to a different address " << ptr - << "->" << it->second; + if (it != symbol_table_.end() && ptr != (*it).second) { + std::cerr << "Warning:SystemLib symbol " << name << " get overriden to a different address " + << ptr << "->" << (*it).second << std::endl; } - symbol_table_[name] = ptr; + symbol_table_.Set(name, ptr); } void* GetSymbol(const char* name) { - std::lock_guard lock(mutex_); auto it = symbol_table_.find(name); if (it != symbol_table_.end()) { - return it->second; + return (*it).second; } else { return nullptr; } @@ -61,19 +60,17 @@ class SystemLibSymbolRegistry { } private: - // Internal mutex - std::mutex mutex_; // Internal symbol table - std::unordered_map symbol_table_; + Map symbol_table_; }; -class SystemLibrary : public Library { +class SystemLibrary final : public Library { public: - explicit SystemLibrary(const std::string& symbol_prefix) : symbol_prefix_(symbol_prefix) {} + explicit SystemLibrary(const String& symbol_prefix) : symbol_prefix_(symbol_prefix) {} void* GetSymbol(const char* name) { if (symbol_prefix_.length() != 0) { - std::string name_with_prefix = symbol_prefix_ + name; + String name_with_prefix = symbol_prefix_ + name; void* symbol = reg_->GetSymbol(name_with_prefix.c_str()); if (symbol != nullptr) return symbol; } @@ -82,19 +79,19 @@ class SystemLibrary : public Library { private: SystemLibSymbolRegistry* reg_ = SystemLibSymbolRegistry::Global(); - std::string symbol_prefix_; + String symbol_prefix_; }; class SystemLibModuleRegistry { public: - runtime::Module GetOrCreateModule(std::string symbol_prefix) { + Module GetOrCreateModule(String symbol_prefix) { std::lock_guard lock(mutex_); auto it = lib_map_.find(symbol_prefix); if (it != lib_map_.end()) { - return it->second; + return (*it).second; } else { - auto mod = CreateModuleFromLibrary(make_object(symbol_prefix)); - lib_map_[symbol_prefix] = mod; + Module mod = CreateLibraryModule(make_object(symbol_prefix)); + lib_map_.Set(symbol_prefix, mod); return mod; } } @@ -107,26 +104,26 @@ class SystemLibModuleRegistry { private: // Internal mutex std::mutex mutex_; + // maps prefix to the library module // we need to make sure each lib map have an unique // copy through out the entire lifetime of the process - // so the cached ffi::Function in the system do not get out dated. - std::unordered_map lib_map_; + Map lib_map_; }; TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def_packed("runtime.SystemLib", [](ffi::PackedArgs args, ffi::Any* rv) { - std::string symbol_prefix = ""; + refl::GlobalDef().def_packed("ffi.SystemLib", [](ffi::PackedArgs args, ffi::Any* rv) { + String symbol_prefix = ""; if (args.size() != 0) { - symbol_prefix = args[0].cast(); + symbol_prefix = args[0].cast(); } *rv = SystemLibModuleRegistry::Global()->GetOrCreateModule(symbol_prefix); }); }); -} // namespace runtime +} // namespace ffi } // namespace tvm -int TVMBackendRegisterSystemLibSymbol(const char* name, void* ptr) { - tvm::runtime::SystemLibSymbolRegistry::Global()->RegisterSymbol(name, ptr); +int TVMFFIEnvRegisterSystemLibSymbol(const char* name, void* ptr) { + tvm::ffi::SystemLibSymbolRegistry::Global()->RegisterSymbol(name, ptr); return 0; } diff --git a/ffi/src/ffi/extra/module.cc b/ffi/src/ffi/extra/module.cc new file mode 100644 index 000000000000..a7f6d4460079 --- /dev/null +++ b/ffi/src/ffi/extra/module.cc @@ -0,0 +1,139 @@ + +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include +#include +#include +#include + +#include +#include + +#include "module_internal.h" + +namespace tvm { +namespace ffi { + +Optional ModuleObj::GetFunction(const String& name, bool query_imports) { + if (auto opt_func = this->GetFunction(name)) { + return opt_func; + } + if (query_imports) { + for (const Any& import : imports_) { + if (auto opt_func = import.cast()->GetFunction(name, query_imports)) { + return *opt_func; + } + } + } + return std::nullopt; +} + +void ModuleObj::ImportModule(const Module& other) { + std::unordered_set visited{other.operator->()}; + std::vector stack{other.operator->()}; + while (!stack.empty()) { + const ModuleObj* n = stack.back(); + stack.pop_back(); + for (const Any& m : n->imports_) { + const ModuleObj* next = m.cast(); + if (visited.count(next)) continue; + visited.insert(next); + stack.push_back(next); + } + } + if (visited.count(this)) { + TVM_FFI_THROW(RuntimeError) << "Cyclic dependency detected during import"; + } + imports_.push_back(other); +} + +void ModuleObj::ClearImports() { imports_.clear(); } + +bool ModuleObj::ImplementsFunction(const String& name, bool query_imports) { + if (this->ImplementsFunction(name)) { + return true; + } + if (query_imports) { + for (const Any& import : imports_) { + if (import.cast()->ImplementsFunction(name, query_imports)) { + return true; + } + } + } + return false; +} + +Module Module::LoadFromFile(const String& file_name) { + String format = [&file_name]() -> String { + const char* data = file_name.data(); + for (size_t i = file_name.size(); i > 0; i--) { + if (data[i - 1] == '.') { + return String(data + i, file_name.size() - i); + } + } + TVM_FFI_THROW(RuntimeError) << "Failed to get file format from " << file_name; + TVM_FFI_UNREACHABLE(); + }(); + + if (format == "dll" || format == "dylib" || format == "dso") { + format = "so"; + } + String loader_name = "ffi.Module.load_from_file." + format; + const auto floader = tvm::ffi::Function::GetGlobal(loader_name); + if (!floader.has_value()) { + TVM_FFI_THROW(RuntimeError) << "Loader for `." << format << "` files is not registered," + << " resolved to (" << loader_name << ") in the global registry." + << "Ensure that you have loaded the correct runtime code, and" + << "that you are on the correct hardware architecture."; + } + return (*floader)(file_name, format).cast(); +} + +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + ModuleObj::InternalUnsafe::RegisterReflection(); + + refl::GlobalDef() + .def("ffi.ModuleLoadFromFile", &Module::LoadFromFile) + .def_method("ffi.ModuleImplementsFunction", + [](Module mod, String name, bool query_imports) { + return mod->ImplementsFunction(name, query_imports); + }) + .def_method("ffi.ModuleGetFunction", + [](Module mod, String name, bool query_imports) { + return mod->GetFunction(name, query_imports); + }) + .def_method("ffi.ModuleGetPropertyMask", &ModuleObj::GetPropertyMask) + .def_method("ffi.ModuleInspectSource", &ModuleObj::InspectSource) + .def_method("ffi.ModuleGetKind", [](const Module& mod) -> String { return mod->kind(); }) + .def_method("ffi.ModuleGetWriteFormats", &ModuleObj::GetWriteFormats) + .def_method("ffi.ModuleWriteToFile", &ModuleObj::WriteToFile) + .def_method("ffi.ModuleImportModule", &ModuleObj::ImportModule) + .def_method("ffi.ModuleClearImports", &ModuleObj::ClearImports); +}); +} // namespace ffi +} // namespace tvm + +int TVMFFIEnvLookupFromImports(TVMFFIObjectHandle library_ctx, const char* func_name, + TVMFFIObjectHandle* out) { + TVM_FFI_SAFE_CALL_BEGIN(); + *out = tvm::ffi::ModuleObj::InternalUnsafe::GetFunctionFromImports( + reinterpret_cast(library_ctx), func_name); + TVM_FFI_SAFE_CALL_END(); +} diff --git a/ffi/src/ffi/extra/module_internal.h b/ffi/src/ffi/extra/module_internal.h new file mode 100644 index 000000000000..f43d3a3d2c42 --- /dev/null +++ b/ffi/src/ffi/extra/module_internal.h @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file library_module.h + * \brief Module that builds from a libary of symbols. + */ +#ifndef TVM_FFI_EXTRA_MODULE_INTERNAL_H_ +#define TVM_FFI_EXTRA_MODULE_INTERNAL_H_ + +#include +#include + +#include + +namespace tvm { +namespace ffi { + +/*! + * \brief Library is the common interface + * for storing data in the form of shared libaries. + * + * \sa src/ffi/extra/dso_library.cc + * \sa src/ffi/extra/system_library.cc + */ +class Library : public Object { + public: + // destructor. + virtual ~Library() {} + /*! + * \brief Get the symbol address for a given name. + * \param name The name of the symbol. + * \return The symbol. + */ + virtual void* GetSymbol(const char* name) = 0; + // NOTE: we do not explicitly create an type index and type_key here for libary. + // This is because we do not need dynamic type downcasting and only need to use the refcounting +}; + +struct ModuleObj::InternalUnsafe { + static Array* GetImports(ModuleObj* module) { return &(module->imports_); } + + static void* GetFunctionFromImports(ModuleObj* module, const char* name) { + // backend implementation for TVMFFIEnvLookupFromImports + static std::mutex mutex_; + std::lock_guard lock(mutex_); + String s_name(name); + auto it = module->import_lookup_cache_.find(s_name); + if (it != module->import_lookup_cache_.end()) { + return const_cast((*it).second.operator->()); + } + + auto opt_func = [&]() -> std::optional { + for (const Any& import : module->imports_) { + if (auto opt_func = import.cast()->GetFunction(s_name, true)) { + return *opt_func; + } + } + // try global at last + return tvm::ffi::Function::GetGlobal(s_name); + }(); + if (!opt_func.has_value()) { + TVM_FFI_THROW(RuntimeError) << "Cannot find function " << name + << " in the imported modules or global registry."; + } + module->import_lookup_cache_.Set(s_name, *opt_func); + return const_cast((*opt_func).operator->()); + } + + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef().def_ro("imports_", &ModuleObj::imports_); + } +}; + +/*! + * \brief Create a library module from a given library. + * + * \param lib The library. + * + * \return The corresponding loaded module. + */ +Module CreateLibraryModule(ObjectPtr lib); + +} // namespace ffi +} // namespace tvm + +#endif // TVM_FFI_EXTRA_MODULE_INTERNAL_H_ diff --git a/include/tvm/runtime/c_backend_api.h b/include/tvm/runtime/c_backend_api.h index 0d84b55fe318..e44fe465bc96 100644 --- a/include/tvm/runtime/c_backend_api.h +++ b/include/tvm/runtime/c_backend_api.h @@ -54,7 +54,7 @@ TVM_DLL int TVMBackendGetFuncFromEnv(void* mod_node, const char* func_name, * \param ptr The symbol address. * \return 0 when no error is thrown, -1 when failure happens */ -TVM_DLL int TVMBackendRegisterSystemLibSymbol(const char* name, void* ptr); +TVM_DLL int TVMFFIEnvRegisterSystemLibSymbol(const char* name, void* ptr); /*! * \brief Backend function to allocate temporal workspace. diff --git a/include/tvm/runtime/disco/builtin.h b/include/tvm/runtime/disco/builtin.h index d594de7247c8..bc0faf2413e5 100644 --- a/include/tvm/runtime/disco/builtin.h +++ b/include/tvm/runtime/disco/builtin.h @@ -62,7 +62,7 @@ inline std::string ReduceKind2String(ReduceKind kind) { * \param device The default device used to initialize the RelaxVM * \return The RelaxVM as a runtime Module */ -TVM_DLL Module LoadVMModule(std::string path, Optional device); +TVM_DLL ffi::Module LoadVMModule(std::string path, Optional device); /*! * \brief Create an uninitialized empty NDArray * \param shape The shape of the NDArray diff --git a/include/tvm/runtime/module.h b/include/tvm/runtime/module.h index 80c03ea75132..f805ec988d37 100644 --- a/include/tvm/runtime/module.h +++ b/include/tvm/runtime/module.h @@ -27,251 +27,19 @@ #define TVM_RUNTIME_MODULE_H_ #include +#include +#include #include #include #include #include #include -#include -#include -#include -#include #include -#include namespace tvm { namespace runtime { -/*! - * \brief Property of runtime module - * We classify the property of runtime module into the following categories. - */ -enum ModulePropertyMask : int { - /*! \brief kBinarySerializable - * we can serialize the module to the stream of bytes. CUDA/OpenCL/JSON - * runtime are representative examples. A binary exportable module can be integrated into final - * runtime artifact by being serialized as data into the artifact, then deserialized at runtime. - * This class of modules must implement SaveToBinary, and have a matching deserializer registered - * as 'runtime.module.loadbinary_'. - */ - kBinarySerializable = 0b001, - /*! \brief kRunnable - * we can run the module directly. LLVM/CUDA/JSON runtime, executors (e.g, - * virtual machine) runtimes are runnable. Non-runnable modules, such as CSourceModule, requires a - * few extra steps (e.g,. compilation, link) to make it runnable. - */ - kRunnable = 0b010, - /*! \brief kDSOExportable - * we can export the module as DSO. A DSO exportable module (e.g., a - * CSourceModuleNode of type_key 'c') can be incorporated into the final runtime artifact (ie - * shared library) by compilation and/or linking using the external compiler (llvm, nvcc, etc). - * DSO exportable modules must implement SaveToFile. In general, DSO exportable modules are not - * runnable unless there is a special support like JIT for `LLVMModule`. - */ - kDSOExportable = 0b100 -}; - -class ModuleNode; - -/*! - * \brief Module container of TVM. - */ -class Module : public ObjectRef { - public: - Module() {} - // constructor from container. - explicit Module(ObjectPtr n) : ObjectRef(n) {} - /*! - * \brief Get packed function from current module by name. - * - * \param name The name of the function. - * \param query_imports Whether also query dependency modules. - * \return The result function. - * This function will return ffi::Function(nullptr) if function do not exist. - * \note Implemented in packed_func.cc - */ - inline ffi::Function GetFunction(const String& name, bool query_imports = false); - // The following functions requires link with runtime. - /*! - * \brief Import another module into this module. - * \param other The module to be imported. - * - * \note Cyclic dependency is not allowed among modules, - * An error will be thrown when cyclic dependency is detected. - */ - inline void Import(Module other); - /*! \return internal container */ - inline ModuleNode* operator->(); - /*! \return internal container */ - inline const ModuleNode* operator->() const; - /*! - * \brief Load a module from file. - * \param file_name The name of the host function module. - * \param format The format of the file. - * \note This function won't load the import relationship. - * Re-create import relationship by calling Import. - */ - TVM_DLL static Module LoadFromFile(const String& file_name, const String& format = ""); - // refer to the corresponding container. - using ContainerType = ModuleNode; - friend class ModuleNode; -}; - -/*! - * \brief Base container of module. - * - * Please subclass ModuleNode to create a specific runtime module. - * - * \code - * - * class MyModuleNode : public ModuleNode { - * public: - * // implement the interface - * }; - * - * // use make_object to create a specific - * // instace of MyModuleNode. - * Module CreateMyModule() { - * ObjectPtr n = - * tvm::ffi::make_object(); - * return Module(n); - * } - * - * \endcode - */ -class TVM_DLL ModuleNode : public Object { - public: - /*! \brief virtual destructor */ - virtual ~ModuleNode() = default; - /*! - * \return The per module type key. - * \note This key is used to for serializing custom modules. - */ - virtual const char* type_key() const = 0; - /*! - * \brief Get a ffi::Function from module. - * - * The ffi::Function may not be fully initialized, - * there might still be first time running overhead when - * executing the function on certain devices. - * For benchmarking, use prepare to eliminate - * - * \param name the name of the function. - * \param sptr_to_self The ObjectPtr that points to this module node. - * - * \return ffi::Function(nullptr) when it is not available. - * - * \note The function will always remain valid. - * If the function need resource from the module(e.g. late linking), - * it should capture sptr_to_self. - */ - virtual ffi::Function GetFunction(const String& name, const ObjectPtr& sptr_to_self) = 0; - /*! - * \brief Save the module to file. - * \param file_name The file to be saved to. - * \param format The format of the file. - */ - virtual void SaveToFile(const String& file_name, const String& format); - /*! - * \brief Save the module to binary stream. - * \param stream The binary stream to save to. - * \note It is recommended to implement this for device modules, - * but not necessarily host modules. - * We can use this to do AOT loading of bundled device functions. - */ - virtual void SaveToBinary(dmlc::Stream* stream); - /*! - * \brief Get the source code of module, when available. - * \param format Format of the source code, can be empty by default. - * \return Possible source code when available. - */ - virtual String GetSource(const String& format = ""); - /*! - * \brief Get the format of the module, when available. - * \return Possible format when available. - */ - virtual String GetFormat(); - /*! - * \brief Get packed function from current module by name. - * - * \param name The name of the function. - * \param query_imports Whether also query dependency modules. - * \return The result function. - * This function will return ffi::Function(nullptr) if function do not exist. - * \note Implemented in packed_func.cc - */ - ffi::Function GetFunction(const String& name, bool query_imports = false); - /*! - * \brief Import another module into this module. - * \param other The module to be imported. - * - * \note Cyclic dependency is not allowed among modules, - * An error will be thrown when cyclic dependency is detected. - */ - void Import(Module other); - /*! - * \brief Get a function from current environment - * The environment includes all the imports as well as Global functions. - * - * \param name name of the function. - * \return The corresponding function. - */ - const ffi::Function* GetFuncFromEnv(const String& name); - - /*! \brief Clear all imports of the module. */ - void ClearImports() { imports_.clear(); } - - /*! \return The module it imports from */ - const std::vector& imports() const { return imports_; } - - /*! - * \brief Returns bitmap of property. - * By default, none of the property is set. Derived class can override this function and set its - * own property. - */ - virtual int GetPropertyMask() const { return 0b000; } - - /*! \brief Returns true if this module is 'DSO exportable'. */ - bool IsDSOExportable() const { - return (GetPropertyMask() & ModulePropertyMask::kDSOExportable) != 0; - } - - /*! \brief Returns true if this module is 'Binary Serializable'. */ - bool IsBinarySerializable() const { - return (GetPropertyMask() & ModulePropertyMask::kBinarySerializable) != 0; - } - - /*! - * \brief Returns true if this module has a definition for a function of \p name. If - * \p query_imports is true, also search in any imported modules. - * - * Note that even if this function returns true the corresponding \p GetFunction result may be - * nullptr if the function is not yet callable without further compilation. - * - * The default implementation just checkis if \p GetFunction is non-null. - */ - virtual bool ImplementsFunction(const String& name, bool query_imports = false); - - // integration with the existing components. - static constexpr const uint32_t _type_index = ffi::TypeIndex::kTVMFFIModule; - static constexpr const char* _type_key = "runtime.Module"; - // NOTE: ModuleNode can still be sub-classed - // - TVM_FFI_DECLARE_STATIC_OBJECT_INFO(ModuleNode, Object); - - protected: - friend class Module; - friend class ModuleInternal; - /*! \brief The modules this module depend on */ - std::vector imports_; - - private: - /*! \brief Cache used by GetImport */ - std::unordered_map> import_cache_; - std::mutex mutex_; -}; - /*! * \brief Check if runtime module is enabled for target. * \param target The target module name. @@ -279,19 +47,8 @@ class TVM_DLL ModuleNode : public Object { */ TVM_DLL bool RuntimeEnabled(const String& target); -// implementation of Module::GetFunction -inline ffi::Function Module::GetFunction(const String& name, bool query_imports) { - return (*this)->GetFunction(name, query_imports); -} - /*! \brief namespace for constant symbols */ namespace symbol { -/*! \brief Global variable to store context pointer for a library module. */ -constexpr const char* tvm_ffi_library_ctx = "__tvm_ffi_library_ctx"; -/*! \brief Global variable to store binary data alongside a library module. */ -constexpr const char* tvm_ffi_library_bin = "__tvm_ffi_library_bin"; -/*! \brief Placeholder for the module's entry function. */ -constexpr const char* tvm_ffi_main = "__tvm_ffi_main__"; /*! \brief global function to set device */ constexpr const char* tvm_set_device = "__tvm_set_device"; /*! \brief Auxiliary counter to global barrier. */ @@ -300,24 +57,6 @@ constexpr const char* tvm_global_barrier_state = "__tvm_global_barrier_state"; constexpr const char* tvm_prepare_global_barrier = "__tvm_prepare_global_barrier"; } // namespace symbol -// implementations of inline functions. - -inline void Module::Import(Module other) { return (*this)->Import(other); } - -inline ModuleNode* Module::operator->() { return static_cast(get_mutable()); } - -inline const ModuleNode* Module::operator->() const { - return static_cast(get()); -} - -inline std::ostream& operator<<(std::ostream& out, const Module& module) { - out << "Module(type_key= "; - out << module->type_key(); - out << ")"; - - return out; -} - namespace details { template @@ -366,12 +105,14 @@ struct ModuleVTableEntryHelper { } // namespace runtime } // namespace tvm -#define TVM_MODULE_VTABLE_BEGIN(TypeKey) \ - const char* type_key() const final { return TypeKey; } \ - ffi::Function GetFunction(const String& _name, const ObjectPtr& _self) override { \ - using SelfPtr = std::remove_cv_t; -#define TVM_MODULE_VTABLE_END() \ - return ffi::Function(nullptr); \ +#define TVM_MODULE_VTABLE_BEGIN(TypeKey) \ + const char* kind() const final { return TypeKey; } \ + ::tvm::ffi::Optional<::tvm::ffi::Function> GetFunction(const String& _name) override { \ + using SelfPtr = std::remove_cv_t; \ + ::tvm::ffi::ObjectPtr<::tvm::ffi::Object> _self = \ + ::tvm::ffi::GetObjectPtr<::tvm::ffi::Object>(this); +#define TVM_MODULE_VTABLE_END() \ + return std::nullopt; \ } #define TVM_MODULE_VTABLE_END_WITH_DEFAULT(MemFunc) \ { \ diff --git a/include/tvm/runtime/profiling.h b/include/tvm/runtime/profiling.h index e512710ea396..9f25b6775c13 100644 --- a/include/tvm/runtime/profiling.h +++ b/include/tvm/runtime/profiling.h @@ -539,8 +539,8 @@ String ShapeString(const std::vector& shape, DLDataType dtype); * and returns performance metrics as a `Map` where * values can be `CountNode`, `DurationNode`, `PercentNode`. */ -ffi::Function ProfileFunction(Module mod, std::string func_name, int device_type, int device_id, - int warmup_iters, Array collectors); +ffi::Function ProfileFunction(ffi::Module mod, std::string func_name, int device_type, + int device_id, int warmup_iters, Array collectors); /*! * \brief Wrap a timer function to measure the time cost of a given packed function. diff --git a/include/tvm/runtime/vm/executable.h b/include/tvm/runtime/vm/executable.h index a84c902b6711..6dfc2b0c50be 100644 --- a/include/tvm/runtime/vm/executable.h +++ b/include/tvm/runtime/vm/executable.h @@ -85,10 +85,10 @@ struct VMFuncInfo { * The executable contains information (e.g. data in different memory regions) * to run in a virtual machine. */ -class VMExecutable : public runtime::ModuleNode { +class VMExecutable : public ffi::ModuleObj { public: /*! \brief Get the property of the runtime module .*/ - int GetPropertyMask() const final { return ModulePropertyMask::kBinarySerializable; }; + int GetPropertyMask() const final { return ffi::Module::kBinarySerializable; }; /*! * \brief Print the detailed statistics of the given code, i.e. number of @@ -121,25 +121,25 @@ class VMExecutable : public runtime::ModuleNode { String AsPython() const; /*! * \brief Write the VMExecutable to the binary stream in serialized form. - * \param stream The binary stream to save the executable to. + * \return The binary bytes that save the executable to. */ - void SaveToBinary(dmlc::Stream* stream) final; + ffi::Bytes SaveToBytes() const final; /*! * \brief Load VMExecutable from the binary stream in serialized form. - * \param stream The binary stream that load the executable from. + * \param bytes The binary bytes that load the executable from. * \return The loaded executable, in the form of a `runtime::Module`. */ - static Module LoadFromBinary(void* stream); + static ffi::Module LoadFromBytes(const ffi::Bytes& bytes); /*! * \brief Write the VMExecutable to the provided path as a file containing its serialized content. * \param file_name The name of the file to write the serialized data to. * \param format The target format of the saved file. */ - void SaveToFile(const String& file_name, const String& format) final; + void WriteToFile(const String& file_name, const String& format) const final; /*! \brief Create a Relax virtual machine and load `this` as the executable. */ - Module VMLoadExecutable() const; + ffi::Module VMLoadExecutable() const; /*! \brief Create a Relax virtual machine with profiler and load `this` as the executable. */ - Module VMProfilerLoadExecutable() const; + ffi::Module VMProfilerLoadExecutable() const; /*! \brief Check if the VMExecutable contains a specific function. */ bool HasFunction(const String& name) const; /*! @@ -147,7 +147,7 @@ class VMExecutable : public runtime::ModuleNode { * \param file_name The path of the file that load the executable from. * \return The loaded executable, in the form of a `runtime::Module`. */ - static Module LoadFromFile(const String& file_name); + static ffi::Module LoadFromFile(const String& file_name); /*! \brief The virtual machine's function table. */ std::vector func_table; @@ -176,22 +176,22 @@ class VMExecutable : public runtime::ModuleNode { * \brief Save the globals. * \param strm The input stream. */ - void SaveGlobalSection(dmlc::Stream* strm); + void SaveGlobalSection(dmlc::Stream* strm) const; /*! * \brief Save the constant pool. * \param strm The input stream. */ - void SaveConstantSection(dmlc::Stream* strm); + void SaveConstantSection(dmlc::Stream* strm) const; /*! * \brief Save the instructions. * \param strm The input stream. */ - void SaveCodeSection(dmlc::Stream* strm); + void SaveCodeSection(dmlc::Stream* strm) const; /*! * \brief Save the packed functions. * \param strm The input stream. */ - void SavePackedFuncNames(dmlc::Stream* strm); + void SavePackedFuncNames(dmlc::Stream* strm) const; /*! * \brief Load the globals. * \param strm The input stream. diff --git a/include/tvm/runtime/vm/vm.h b/include/tvm/runtime/vm/vm.h index ed74ba7b7b2a..3a0b7418b946 100644 --- a/include/tvm/runtime/vm/vm.h +++ b/include/tvm/runtime/vm/vm.h @@ -27,6 +27,8 @@ #define TVM_VM_ENABLE_PROFILER 1 #endif +#include + #include #include #include @@ -128,7 +130,7 @@ class VMExtension : public ObjectRef { * multiple threads, or serialize them to disk or over the * wire. */ -class VirtualMachine : public runtime::ModuleNode { +class VirtualMachine : public ffi::ModuleObj { public: /*! * \brief Initialize the virtual machine for a set of devices. diff --git a/include/tvm/target/codegen.h b/include/tvm/target/codegen.h index 54f09a081b93..d92ef674f12e 100644 --- a/include/tvm/target/codegen.h +++ b/include/tvm/target/codegen.h @@ -45,7 +45,7 @@ using ffi::PackedArgs; * \param target The target to be built. * \return The result runtime::Module. */ -runtime::Module Build(IRModule mod, Target target); +ffi::Module Build(IRModule mod, Target target); /*! * \brief Serialize runtime module including its submodules @@ -53,14 +53,14 @@ runtime::Module Build(IRModule mod, Target target); * \param export_dso By default, include the info of DSOExportable modules. If disabled, an error * will be raised when encountering DSO modules. */ -std::string SerializeModuleToBytes(const runtime::Module& mod, bool export_dso = true); +std::string SerializeModuleToBytes(const ffi::Module& mod, bool export_dso = true); /*! * \brief Deserialize runtime module including its submodules * \param blob byte stream, which are generated by `SerializeModuleToBytes`. * \return runtime::Module runtime module constructed from the given stream */ -runtime::Module DeserializeModuleFromBytes(std::string blob); +ffi::Module DeserializeModuleFromBytes(std::string blob); /*! * \brief Pack imported device library to a C file. @@ -73,7 +73,7 @@ runtime::Module DeserializeModuleFromBytes(std::string blob); * \param c_symbol_prefix Optional symbol prefix of the blob symbol. * \return cstr The C string representation of the file. */ -std::string PackImportsToC(const runtime::Module& m, bool system_lib, +std::string PackImportsToC(const ffi::Module& m, bool system_lib, const std::string& c_symbol_prefix = ""); /*! @@ -89,9 +89,9 @@ std::string PackImportsToC(const runtime::Module& m, bool system_lib, * * \return runtime::Module The generated LLVM module. */ -runtime::Module PackImportsToLLVM(const runtime::Module& m, bool system_lib, - const std::string& target_triple, - const std::string& c_symbol_prefix = ""); +ffi::Module PackImportsToLLVM(const ffi::Module& m, bool system_lib, + const std::string& target_triple, + const std::string& c_symbol_prefix = ""); } // namespace codegen } // namespace tvm diff --git a/jvm/core/src/main/java/org/apache/tvm/Module.java b/jvm/core/src/main/java/org/apache/tvm/Module.java index 9fa65054f91f..46a74346760e 100644 --- a/jvm/core/src/main/java/org/apache/tvm/Module.java +++ b/jvm/core/src/main/java/org/apache/tvm/Module.java @@ -35,7 +35,7 @@ protected Map initialValue() { private static Function getApi(String name) { Function func = apiFuncs.get().get(name); if (func == null) { - func = Function.getFunction("runtime." + name); + func = Function.getFunction(name); apiFuncs.get().put(name, func); } return func; @@ -75,7 +75,7 @@ public Function entryFunc() { * @return The result function. */ public Function getFunction(String name, boolean queryImports) { - TVMValue ret = getApi("ModuleGetFunction") + TVMValue ret = getApi("ffi.ModuleGetFunction") .pushArg(this).pushArg(name).pushArg(queryImports ? 1 : 0).invoke(); return ret.asFunction(); } @@ -89,7 +89,7 @@ public Function getFunction(String name) { * @param module The other module. */ public void importModule(Module module) { - getApi("ModuleImport") + getApi("ffi.ModuleImportModule") .pushArg(this).pushArg(module).invoke(); } @@ -98,7 +98,7 @@ public void importModule(Module module) { * @return type key of the module. */ public String typeKey() { - return getApi("ModuleGetTypeKey").pushArg(this).invoke().asString(); + return getApi("ffi.ModuleGetTypeKind").pushArg(this).invoke().asString(); } /** @@ -109,7 +109,7 @@ public String typeKey() { * @return The loaded module */ public static Module load(String path, String fmt) { - TVMValue ret = getApi("ModuleLoadFromFile").pushArg(path).pushArg(fmt).invoke(); + TVMValue ret = getApi("ffi.ModuleLoadFromFile").pushArg(path).pushArg(fmt).invoke(); return ret.asModule(); } @@ -125,7 +125,7 @@ public static Module load(String path) { * @return Whether runtime is enabled. */ public static boolean enabled(String target) { - TVMValue ret = getApi("RuntimeEnabled").pushArg(target).invoke(); + TVMValue ret = getApi("runtime.RuntimeEnabled").pushArg(target).invoke(); return ret.asLong() != 0; } } diff --git a/python/tvm/contrib/hexagon/tools.py b/python/tvm/contrib/hexagon/tools.py index 5ee89713d9a5..f7f22db721ce 100644 --- a/python/tvm/contrib/hexagon/tools.py +++ b/python/tvm/contrib/hexagon/tools.py @@ -404,7 +404,7 @@ def pack_imports( def export_module(module, out_dir, binary_name="test_binary.so"): """Export Hexagon shared object to a file.""" binary_path = pathlib.Path(out_dir) / binary_name - module.save(str(binary_path)) + module.write_to_file(str(binary_path)) return binary_path diff --git a/python/tvm/relax/op/nn/nn.py b/python/tvm/relax/op/nn/nn.py index a38b31c9bb00..c7be2a7ba6f6 100644 --- a/python/tvm/relax/op/nn/nn.py +++ b/python/tvm/relax/op/nn/nn.py @@ -1621,7 +1621,17 @@ def batch_norm( The computed result. """ return _ffi_api.batch_norm( # type: ignore - data, gamma, beta, moving_mean, moving_var, axis, epsilon, center, scale, momentum, training + data, + gamma, + beta, + moving_mean, + moving_var, + axis, + epsilon, + center, + scale, + momentum, + training, ) diff --git a/python/tvm/relax/vm_build.py b/python/tvm/relax/vm_build.py index f31927e2f1f9..f6db61af61d2 100644 --- a/python/tvm/relax/vm_build.py +++ b/python/tvm/relax/vm_build.py @@ -153,7 +153,7 @@ def _vmlink( tir_mod = _auto_attach_system_lib_prefix(tir_mod, target, system_lib) lib = tvm.tir.build(tir_mod, target=target, pipeline=tir_pipeline) for ext_mod in ext_libs: - if ext_mod.is_device_module: + if ext_mod.is_device_module(): tir_ext_libs.append(ext_mod) else: relax_ext_libs.append(ext_mod) diff --git a/python/tvm/rpc/client.py b/python/tvm/rpc/client.py index ea78b0d7d418..0bb4e8cb7d29 100644 --- a/python/tvm/rpc/client.py +++ b/python/tvm/rpc/client.py @@ -55,7 +55,7 @@ def system_lib(self): -------- tvm.runtime.system_lib """ - return self.get_function("runtime.SystemLib")() + return self.get_function("ffi.SystemLib")() def get_function(self, name): """Get function from the session. @@ -380,7 +380,12 @@ def text_summary(self): return res def request( - self, key, priority=1, session_timeout=0, max_retry=5, session_constructor_args=None + self, + key, + priority=1, + session_timeout=0, + max_retry=5, + session_constructor_args=None, ): """Request a new connection from the tracker. @@ -474,7 +479,12 @@ def request_and_run(self, key, func, priority=1, session_timeout=0, max_retry=2) def connect( - url, port, key="", session_timeout=0, session_constructor_args=None, enable_logging=False + url, + port, + key="", + session_timeout=0, + session_constructor_args=None, + enable_logging=False, ): """Connect to RPC Server diff --git a/python/tvm/runtime/disco/session.py b/python/tvm/runtime/disco/session.py index bd0d3d8ed869..49449a451a12 100644 --- a/python/tvm/runtime/disco/session.py +++ b/python/tvm/runtime/disco/session.py @@ -94,7 +94,7 @@ def __init__(self, dref: DRef, session: "Session") -> None: self.session = session def __getitem__(self, name: str) -> DPackedFunc: - func = self.session._get_cached_method("runtime.ModuleGetFunction") + func = self.session._get_cached_method("ffi.ModuleGetFunction") return DPackedFunc(func(self, name, False), self.session) @@ -328,7 +328,10 @@ def init_ccl(self, ccl: str, *device_ids): self._clear_ipc_memory_pool() def broadcast( - self, src: Union[np.ndarray, NDArray], dst: Optional[DRef] = None, in_group: bool = True + self, + src: Union[np.ndarray, NDArray], + dst: Optional[DRef] = None, + in_group: bool = True, ) -> DRef: """Broadcast an array to all workers @@ -383,7 +386,10 @@ def broadcast_from_worker0(self, src: DRef, dst: DRef, in_group: bool = True) -> func(src, in_group, dst) def scatter( - self, src: Union[np.ndarray, NDArray], dst: Optional[DRef] = None, in_group: bool = True + self, + src: Union[np.ndarray, NDArray], + dst: Optional[DRef] = None, + in_group: bool = True, ) -> DRef: """Scatter an array across all workers @@ -540,7 +546,10 @@ class ProcessSession(Session): """A Disco session backed by pipe-based multi-processing.""" def __init__( - self, num_workers: int, num_groups: int = 1, entrypoint: str = "tvm.exec.disco_worker" + self, + num_workers: int, + num_groups: int = 1, + entrypoint: str = "tvm.exec.disco_worker", ) -> None: self.__init_handle_by_constructor__( _ffi_api.SessionProcess, # type: ignore # pylint: disable=no-member @@ -585,7 +594,12 @@ class SocketSession(Session): """A Disco session backed by socket-based multi-node communication.""" def __init__( - self, num_nodes: int, num_workers_per_node: int, num_groups: int, host: str, port: int + self, + num_nodes: int, + num_workers_per_node: int, + num_groups: int, + host: str, + port: int, ) -> None: self.__init_handle_by_constructor__( _ffi_api.SocketSession, # type: ignore # pylint: disable=no-member diff --git a/python/tvm/runtime/executable.py b/python/tvm/runtime/executable.py index b6e13a65a9f2..51f0a772e403 100644 --- a/python/tvm/runtime/executable.py +++ b/python/tvm/runtime/executable.py @@ -91,7 +91,7 @@ def jit( # TODO(tvm-team): Update runtime.Module interface # to query these properties as bitmask. def _not_runnable(x): - return x.type_key in ("c", "static_library") + return x.kind in ("c", "static_library") # pylint:disable = protected-access not_runnable_list = self.mod._collect_from_import_tree(_not_runnable) diff --git a/python/tvm/runtime/module.py b/python/tvm/runtime/module.py index e645d3a2b6ce..3925c24365d5 100644 --- a/python/tvm/runtime/module.py +++ b/python/tvm/runtime/module.py @@ -28,6 +28,7 @@ from tvm.libinfo import find_include_path from . import _ffi_api +from ..ffi import _ffi_api as _mod_ffi_api class BenchmarkResult: @@ -94,10 +95,10 @@ class ModulePropertyMask(object): BINARY_SERIALIZABLE = 0b001 RUNNABLE = 0b010 - DSO_EXPORTABLE = 0b100 + COMPILATION_EXPORTABLE = 0b100 -@tvm.ffi.register_object("runtime.Module") +@tvm.ffi.register_object("ffi.Module") class Module(tvm.ffi.Object): """Runtime Module.""" @@ -121,6 +122,22 @@ def entry_func(self): self._entry = self.get_function("__tvm_ffi_main__") return self._entry + @property + def kind(self): + """Get type key of the module.""" + return _mod_ffi_api.ModuleGetKind(self) + + @property + def imports(self): + """Get imported modules + + Returns + ---------- + modules : list of Module + The module + """ + return self.imports_ + def implements_function(self, name, query_imports=False): """Returns True if the module has a definition for the global function with name. Note that has_function(name) does not imply get_function(name) is non-null since the module @@ -141,7 +158,7 @@ def implements_function(self, name, query_imports=False): b : Bool True if module (or one of its imports) has a definition for name. """ - return _ffi_api.ModuleImplementsFunction(self, name, query_imports) + return _mod_ffi_api.ModuleImplementsFunction(self, name, query_imports) def get_function(self, name, query_imports=False): """Get function from the module. @@ -159,7 +176,7 @@ def get_function(self, name, query_imports=False): f : tvm.runtime.PackedFunc The result function. """ - func = _ffi_api.ModuleGetFunction(self, name, query_imports) + func = _mod_ffi_api.ModuleGetFunction(self, name, query_imports) if func is None: raise AttributeError(f"Module has no function '{name}'") return func @@ -172,7 +189,7 @@ def import_module(self, module): module : tvm.runtime.Module The other module. """ - _ffi_api.ModuleImport(self, module) + _mod_ffi_api.ModuleImportModule(self, module) def __getitem__(self, name): if not isinstance(name, str): @@ -185,17 +202,7 @@ def __call__(self, *args): # pylint: disable=not-callable return self.entry_func(*args) - @property - def type_key(self): - """Get type key of the module.""" - return _ffi_api.ModuleGetTypeKey(self) - - @property - def format(self): - """Get the format of the module.""" - return _ffi_api.ModuleGetFormat(self) - - def get_source(self, fmt=""): + def inspect_source(self, fmt=""): """Get source code from module, if available. Parameters @@ -208,19 +215,11 @@ def get_source(self, fmt=""): source : str The result source code. """ - return _ffi_api.ModuleGetSource(self, fmt) + return _mod_ffi_api.ModuleInspectSource(self, fmt) - @property - def imported_modules(self): - """Get imported modules - - Returns - ---------- - modules : list of Module - The module - """ - nmod = _ffi_api.ModuleImportsSize(self) - return [_ffi_api.ModuleGetImport(self, i) for i in range(nmod)] + def get_write_formats(self): + """Get the format of the module.""" + return _mod_ffi_api.ModuleGetWriteFormats(self) def get_property_mask(self): """Get the runtime module property mask. The mapping is stated in ModulePropertyMask. @@ -230,9 +229,8 @@ def get_property_mask(self): mask : int Bitmask of runtime module property """ - return _ffi_api.ModuleGetPropertyMask(self) + return _mod_ffi_api.ModuleGetPropertyMask(self) - @property def is_binary_serializable(self): """Returns true if module is 'binary serializable', ie can be serialzed into binary stream and loaded back to the runtime module. @@ -244,7 +242,6 @@ def is_binary_serializable(self): """ return (self.get_property_mask() & ModulePropertyMask.BINARY_SERIALIZABLE) != 0 - @property def is_runnable(self): """Returns true if module is 'runnable'. ie can be executed without any extra compilation/linking steps. @@ -256,31 +253,26 @@ def is_runnable(self): """ return (self.get_property_mask() & ModulePropertyMask.RUNNABLE) != 0 - @property def is_device_module(self): - return self.type_key in ["cuda", "opencl", "metal", "hip", "vulkan", "webgpu"] + return self.kind in ["cuda", "opencl", "metal", "hip", "vulkan", "webgpu"] - @property - def is_dso_exportable(self): - """Returns true if module is 'DSO exportable', ie can be included in result of + def is_compilation_exportable(self): + """Returns true if module is 'compilation exportable', ie can be included in result of export_library by the external compiler directly. Returns ------- b : Bool - True if the module is DSO exportable. + True if the module is compilation exportable. """ - return (self.get_property_mask() & ModulePropertyMask.DSO_EXPORTABLE) != 0 + return (self.get_property_mask() & ModulePropertyMask.COMPILATION_EXPORTABLE) != 0 def clear_imports(self): """Remove all imports of the module.""" - _ffi_api.ModuleClearImports(self) + _mod_ffi_api.ModuleClearImports(self) - def save(self, file_name, fmt=""): - """Save the module to file. - - This do not save the dependent device modules. - See also export_shared + def write_to_file(self, file_name, fmt=""): + """Write the current module to file. Parameters ---------- @@ -293,7 +285,7 @@ def save(self, file_name, fmt=""): -------- runtime.Module.export_library : export the module to shared library. """ - _ffi_api.ModuleSaveToFile(self, file_name, fmt) + _mod_ffi_api.ModuleWriteToFile(self, file_name, fmt) def time_evaluator( self, @@ -414,19 +406,19 @@ def _collect_from_import_tree(self, filter_func): while stack: module = stack.pop() assert ( - module.is_dso_exportable or module.is_binary_serializable - ), f"Module {module.type_key} should be either dso exportable or binary serializable." + module.is_compilation_exportable() or module.is_binary_serializable() + ), f"Module {module.kind} should be either dso exportable or binary serializable." if filter_func(module): dso_modules.append(module) - for m in module.imported_modules: + for m in module.imports: if m not in visited: visited.add(m) stack.append(m) return dso_modules def _collect_dso_modules(self): - return self._collect_from_import_tree(lambda m: m.is_dso_exportable) + return self._collect_from_import_tree(lambda m: m.is_compilation_exportable()) def export_library( self, @@ -509,29 +501,24 @@ def export_library( system_lib_prefix = None llvm_target_string = None global_object_format = "o" + + def get_source_format_from_module(module): + for fmt in module.get_write_formats(): + if fmt in ["c", "cc", "cpp", "cu"]: + return fmt + raise ValueError(f"Module {module.kind} does not exporting to c, cc, cpp or cu.") + for index, module in enumerate(modules): if fcompile is not None and hasattr(fcompile, "object_format"): - if module.type_key == "c": - assert module.format in [ - "c", - "cc", - "cpp", - "cu", - ], "The module.format needs to be either c, cc, cpp or cu." - object_format = module.format + if module.kind == "c": + object_format = get_source_format_from_module(module) has_c_module = True else: global_object_format = object_format = fcompile.object_format else: - if module.type_key == "c": - if len(module.format) > 0: - assert module.format in [ - "c", - "cc", - "cpp", - "cu", - ], "The module.format needs to be either c, cc, cpp, or cu." - object_format = module.format + if module.kind == "c": + if len(module.get_write_formats()) > 0: + object_format = get_source_format_from_module(module) else: object_format = "c" if "cc" in kwargs: @@ -539,13 +526,13 @@ def export_library( object_format = "cu" has_c_module = True else: - assert module.is_dso_exportable + assert module.is_compilation_exportable() global_object_format = object_format = "o" path_obj = os.path.join(workspace_dir, f"lib{index}.{object_format}") - module.save(path_obj) + module.write_to_file(path_obj) files.append(path_obj) - if module.type_key == "llvm": + if module.kind == "llvm": is_system_lib = module.get_function("__tvm_is_system_module")() llvm_target_string = module.get_function("_get_target_string")() system_lib_prefix = module.get_function("__tvm_get_system_lib_prefix")() @@ -566,7 +553,7 @@ def export_library( if getattr(fcompile, "need_system_lib", False) and not is_system_lib: raise ValueError(f"{str(fcompile)} need --system-lib option") - if self.imported_modules: + if self.imports: pack_lib_prefix = system_lib_prefix if system_lib_prefix else "" if fpack_imports is not None: @@ -579,7 +566,7 @@ def export_library( m = _ffi_api.ModulePackImportsToLLVM( self, is_system_lib, llvm_target_string, pack_lib_prefix ) - m.save(path_obj) + m.write_to_file(path_obj) files.append(path_obj) else: path_cc = os.path.join(workspace_dir, f"{pack_lib_prefix}devc.c") @@ -625,10 +612,10 @@ def system_lib(symbol_prefix=""): module : runtime.Module The system-wide library module. """ - return _ffi_api.SystemLib(symbol_prefix) + return _mod_ffi_api.SystemLib(symbol_prefix) -def load_module(path, fmt=""): +def load_module(path): """Load module from file. Parameters @@ -636,10 +623,6 @@ def load_module(path, fmt=""): path : str The path to the module file. - fmt : str, optional - The format of the file, if not specified - it will be inferred from suffix of the file. - Returns ------- module : runtime.Module @@ -673,7 +656,7 @@ def load_module(path, fmt=""): _cc.create_shared(path + ".so", files) path += ".so" # Redirect to the load API - return _ffi_api.ModuleLoadFromFile(path, fmt) + return _mod_ffi_api.ModuleLoadFromFile(path) def load_static_library(path, func_names): diff --git a/python/tvm/testing/usmp.py b/python/tvm/testing/usmp.py deleted file mode 100644 index c35ac255c3b1..000000000000 --- a/python/tvm/testing/usmp.py +++ /dev/null @@ -1,39 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -""" This file contains USMP tests harnesses.""" - -import tvm - - -def is_tvm_backendallocworkspace_calls(mod: tvm.runtime.module) -> bool: - """TVMBackendAllocWorkspace call check. - - This checker checks whether any c-source produced has TVMBackendAllocWorkspace calls. - If USMP is invoked, none of them should have TVMBAW calls - """ - dso_modules = mod._collect_dso_modules() - for dso_mod in dso_modules: - if dso_mod.type_key not in ["c", "llvm"]: - assert ( - False - ), 'Current AoT codegen flow should only produce type "c" or "llvm" runtime modules' - - source = dso_mod.get_source() - if source.count("TVMBackendAllocWorkspace") != 0: - return True - - return False diff --git a/src/contrib/msc/framework/tensorrt/codegen.cc b/src/contrib/msc/framework/tensorrt/codegen.cc index 684abbe38c17..7acd0f215502 100644 --- a/src/contrib/msc/framework/tensorrt/codegen.cc +++ b/src/contrib/msc/framework/tensorrt/codegen.cc @@ -599,10 +599,10 @@ TVM_FFI_STATIC_INIT_BLOCK({ * \param functions The extern functions to be compiled via TensorRT * \return Runtime modules. */ -Array MSCTensorRTCompiler(Array functions, - Map target_option, - Map constant_names) { - Array compiled_functions; +Array MSCTensorRTCompiler(Array functions, + Map target_option, + Map constant_names) { + Array compiled_functions; for (const auto& func : functions) { VLOG(1) << "MSC.TensorRT partition:" << std::endl << func; const auto& name_opt = func->GetAttr(msc_attr::kUnique); @@ -615,9 +615,9 @@ Array MSCTensorRTCompiler(Array functions, serializer.serialize(func); std::string graph_json = serializer.GetJSON(); const auto pf = tvm::ffi::Function::GetGlobalRequired("runtime.msc_tensorrt_runtime_create"); - VLOG(1) << "Creating msc_tensorrt runtime::Module for '" << func_name << "'"; + VLOG(1) << "Creating msc_tensorrt ffi::Module for '" << func_name << "'"; compiled_functions.push_back( - pf(func_name, graph_json, serializer.GetConstantNames()).cast()); + pf(func_name, graph_json, serializer.GetConstantNames()).cast()); } return compiled_functions; } diff --git a/src/node/structural_hash.cc b/src/node/structural_hash.cc index 7b0846051609..41a22e4d39d8 100644 --- a/src/node/structural_hash.cc +++ b/src/node/structural_hash.cc @@ -21,6 +21,7 @@ */ #include #include +#include #include #include #include @@ -46,16 +47,16 @@ TVM_FFI_STATIC_INIT_BLOCK({ [](const Any& object, bool map_free_vars) -> int64_t { return ffi::StructuralHash::Hash(object, map_free_vars); }); - refl::TypeAttrDef() + refl::TypeAttrDef() .def("__data_to_json__", - [](const runtime::ModuleNode* node) { - std::string bytes = codegen::SerializeModuleToBytes(GetRef(node), + [](const ffi::ModuleObj* node) { + std::string bytes = codegen::SerializeModuleToBytes(GetRef(node), /*export_dso*/ false); return ffi::Base64Encode(ffi::Bytes(bytes)); }) .def("__data_from_json__", [](const String& base64_bytes) { Bytes bytes = ffi::Base64Decode(base64_bytes); - runtime::Module rtmod = codegen::DeserializeModuleFromBytes(bytes.operator std::string()); + ffi::Module rtmod = codegen::DeserializeModuleFromBytes(bytes.operator std::string()); return rtmod; }); diff --git a/src/relax/backend/contrib/clml/codegen.cc b/src/relax/backend/contrib/clml/codegen.cc index 146f7b932f9c..b25bfbdb22a7 100644 --- a/src/relax/backend/contrib/clml/codegen.cc +++ b/src/relax/backend/contrib/clml/codegen.cc @@ -311,9 +311,9 @@ void CollectCLMLFromCompositeFunctionBody::VisitExpr_(const CallNode* call_node) * \param functions The extern functions to be compiled via OpenCLML * \return Runtime modules. */ -Array OpenCLMLCompiler(Array functions, Map /*unused*/, - Map constant_names) { - Array compiled_functions; +Array OpenCLMLCompiler(Array functions, Map /*unused*/, + Map constant_names) { + Array compiled_functions; for (const auto& func : functions) { VLOG(1) << "OpenCLML partition:" << std::endl << func; OpenCLMLJSONSerializer serializer(constant_names, AnalyzeVar2Value(func)); @@ -322,8 +322,8 @@ Array OpenCLMLCompiler(Array functions, Map()); + VLOG(1) << "Creating clml ffi::Module for '" << func_name << "'"; + compiled_functions.push_back(pf(func_name, graph_json, constant_names).cast()); } return compiled_functions; } diff --git a/src/relax/backend/contrib/cublas/codegen.cc b/src/relax/backend/contrib/cublas/codegen.cc index 3f132b024a1b..0cd0150970e6 100644 --- a/src/relax/backend/contrib/cublas/codegen.cc +++ b/src/relax/backend/contrib/cublas/codegen.cc @@ -109,9 +109,9 @@ class CublasJSONSerializer : public JSONSerializer { Map bindings_; }; -Array CublasCompiler(Array functions, Map /*unused*/, - Map constant_names) { - Array compiled_functions; +Array CublasCompiler(Array functions, Map /*unused*/, + Map constant_names) { + Array compiled_functions; for (const auto& func : functions) { CublasJSONSerializer serializer(constant_names, AnalyzeVar2Value(func)); @@ -120,7 +120,7 @@ Array CublasCompiler(Array functions, Map()); + compiled_functions.push_back(pf(func_name, graph_json, constant_names).cast()); } return compiled_functions; diff --git a/src/relax/backend/contrib/cudnn/codegen.cc b/src/relax/backend/contrib/cudnn/codegen.cc index b529c6f79692..a0201ccfda77 100644 --- a/src/relax/backend/contrib/cudnn/codegen.cc +++ b/src/relax/backend/contrib/cudnn/codegen.cc @@ -133,9 +133,9 @@ class cuDNNJSONSerializer : public JSONSerializer { Map bindings_; }; -Array cuDNNCompiler(Array functions, Map /*unused*/, - Map constant_names) { - Array compiled_functions; +Array cuDNNCompiler(Array functions, Map /*unused*/, + Map constant_names) { + Array compiled_functions; for (const auto& func : functions) { cuDNNJSONSerializer serializer(constant_names, AnalyzeVar2Value(func)); @@ -144,7 +144,7 @@ Array cuDNNCompiler(Array functions, Map()); + compiled_functions.push_back(pf(func_name, graph_json, constant_names).cast()); } return compiled_functions; diff --git a/src/relax/backend/contrib/cutlass/codegen.cc b/src/relax/backend/contrib/cutlass/codegen.cc index b6307af0237b..29ad2de412d8 100644 --- a/src/relax/backend/contrib/cutlass/codegen.cc +++ b/src/relax/backend/contrib/cutlass/codegen.cc @@ -55,7 +55,7 @@ std::string EmitSignature(const std::vector& out, const std::string& fun return code_stream_.str(); } -runtime::Module Finalize(const std::string& code, const Array& func_names) { +ffi::Module Finalize(const std::string& code, const Array& func_names) { ICHECK(!func_names.empty()) << "Should only create CUTLASS CSourceModule if there is at least one CUTLASS partition"; @@ -72,7 +72,7 @@ runtime::Module Finalize(const std::string& code, const Array& func_name VLOG(1) << "Generated CUTLASS code:" << std::endl << code; return pf(default_headers.str() + code, "cu", func_names, /*const_vars=*/Array()) - .cast(); + .cast(); } class CodegenResultNode : public Object { @@ -337,8 +337,7 @@ class CodegenCutlass : public relax::MemoizedExprTranslator, class CutlassModuleCodegen { public: - runtime::Module CreateCSourceModule(Array functions, - const Map& options) { + ffi::Module CreateCSourceModule(Array functions, const Map& options) { std::string headers = ""; std::string code = ""; for (const auto& f : functions) { @@ -373,8 +372,8 @@ class CutlassModuleCodegen { Array func_names_; }; -Array CUTLASSCompiler(Array functions, Map options, - Map /*unused*/) { +Array CUTLASSCompiler(Array functions, Map options, + Map /*unused*/) { const auto tune_func = tvm::ffi::Function::GetGlobal("contrib.cutlass.tune_relax_function"); ICHECK(tune_func.has_value()) << "The packed function contrib.cutlass.tune_relax_function not found, " @@ -386,7 +385,7 @@ Array CUTLASSCompiler(Array functions, Map(); + ffi::Module cutlass_mod = (*pf)(source_mod, options).cast(); return {cutlass_mod}; } diff --git a/src/relax/backend/contrib/dnnl/codegen.cc b/src/relax/backend/contrib/dnnl/codegen.cc index 83cbdd8e2bbc..efa4e1b685c7 100644 --- a/src/relax/backend/contrib/dnnl/codegen.cc +++ b/src/relax/backend/contrib/dnnl/codegen.cc @@ -81,9 +81,9 @@ class DNNLJSONSerializer : public JSONSerializer { Map bindings_; }; -Array DNNLCompiler(Array functions, Map /*unused*/, - Map constant_names) { - Array compiled_functions; +Array DNNLCompiler(Array functions, Map /*unused*/, + Map constant_names) { + Array compiled_functions; for (const auto& func : functions) { DNNLJSONSerializer serializer(constant_names, AnalyzeVar2Value(func)); @@ -92,7 +92,7 @@ Array DNNLCompiler(Array functions, Map()); + compiled_functions.push_back(pf(func_name, graph_json, constant_names).cast()); } return compiled_functions; diff --git a/src/relax/backend/contrib/hipblas/codegen.cc b/src/relax/backend/contrib/hipblas/codegen.cc index 761221c88bac..e1104ac3d6c7 100644 --- a/src/relax/backend/contrib/hipblas/codegen.cc +++ b/src/relax/backend/contrib/hipblas/codegen.cc @@ -86,9 +86,9 @@ class HipblasJSONSerializer : public JSONSerializer { Map bindings_; }; -Array HipblasCompiler(Array functions, Map /*unused*/, - Map constant_names) { - Array compiled_functions; +Array HipblasCompiler(Array functions, Map /*unused*/, + Map constant_names) { + Array compiled_functions; for (const auto& func : functions) { HipblasJSONSerializer serializer(constant_names, AnalyzeVar2Value(func)); @@ -97,7 +97,7 @@ Array HipblasCompiler(Array functions, Map()); + compiled_functions.push_back(pf(func_name, graph_json, constant_names).cast()); } return compiled_functions; diff --git a/src/relax/backend/contrib/nnapi/codegen.cc b/src/relax/backend/contrib/nnapi/codegen.cc index c62523f5392d..f045e5b9c2c0 100644 --- a/src/relax/backend/contrib/nnapi/codegen.cc +++ b/src/relax/backend/contrib/nnapi/codegen.cc @@ -247,11 +247,11 @@ void CollectFromCompositeFunctionBody::VisitExpr_(const CallNode* call_node) { ExprVisitor::VisitExpr_(call_node); } -Array NNAPICompiler(Array functions, Map /*unused*/, - Map constant_names) { +Array NNAPICompiler(Array functions, Map /*unused*/, + Map constant_names) { VLOG(1) << "NNAPI Compiler"; - Array compiled_functions; + Array compiled_functions; for (const auto& func : functions) { NNAPIJSONSerializer serializer(constant_names, AnalyzeVar2Value(func)); serializer.serialize(func); @@ -260,7 +260,7 @@ Array NNAPICompiler(Array functions, Map(); + tvm::ffi::Module mod = result.cast(); compiled_functions.push_back(mod); } diff --git a/src/relax/backend/contrib/tensorrt/codegen.cc b/src/relax/backend/contrib/tensorrt/codegen.cc index 02483abdc3dc..6dd8216469c2 100644 --- a/src/relax/backend/contrib/tensorrt/codegen.cc +++ b/src/relax/backend/contrib/tensorrt/codegen.cc @@ -225,9 +225,9 @@ void CollectFromCompositeFunctionBody::VisitExpr_(const CallNode* call_node) { * \param functions The extern functions to be compiled via TensorRT * \return Runtime modules. */ -Array TensorRTCompiler(Array functions, Map /*unused*/, - Map constant_names) { - Array compiled_functions; +Array TensorRTCompiler(Array functions, Map /*unused*/, + Map constant_names) { + Array compiled_functions; for (const auto& func : functions) { VLOG(1) << "TensorRT partition:" << std::endl << func; TensorRTJSONSerializer serializer(constant_names, AnalyzeVar2Value(func)); @@ -237,8 +237,8 @@ Array TensorRTCompiler(Array functions, Map()); + VLOG(1) << "Creating tensorrt ffi::Module for '" << func_name << "'"; + compiled_functions.push_back(pf(func_name, graph_json, constant_names).cast()); } return compiled_functions; } diff --git a/src/relax/backend/vm/codegen_vm.cc b/src/relax/backend/vm/codegen_vm.cc index 27165db34350..1f9e8c0378a7 100644 --- a/src/relax/backend/vm/codegen_vm.cc +++ b/src/relax/backend/vm/codegen_vm.cc @@ -310,10 +310,10 @@ class CodeGenVM : public ExprFunctor { String sym = op->global_symbol; String fmt = op->attrs.GetAttr(kCSourceFmt).value_or("c"); String code = opt_code.value(); - Module c_source_module = + ffi::Module c_source_module = codegen::CSourceModuleCreate(/*code=*/code, /*fmt=*/fmt, /*func_names=*/{sym}, /*const_vars=*/{}); - builder_->exec()->Import(c_source_module); + builder_->exec()->ImportModule(c_source_module); } builder_->DeclareFunction(op->global_symbol, VMFuncInfo::FuncKind::kPackedFunc); return builder_->GetFunction(op->global_symbol); @@ -441,17 +441,17 @@ TVM_FFI_STATIC_INIT_BLOCK({ * \return The created module. */ void LinkModules(ObjectPtr exec, const Map& params, - const tvm::runtime::Module& lib, const Array& ext_libs) { + const tvm::ffi::Module& lib, const Array& ext_libs) { // query if we need const loader for ext_modules // Wrap all submodules in the initialization wrapper. std::unordered_map> const_vars_by_symbol; - for (tvm::runtime::Module mod : ext_libs) { - auto pf_sym = mod.GetFunction("get_symbol"); - auto pf_var = mod.GetFunction("get_const_vars"); + for (tvm::ffi::Module mod : ext_libs) { + auto pf_sym = mod->GetFunction("get_symbol"); + auto pf_var = mod->GetFunction("get_const_vars"); std::vector symbol_const_vars; - if (pf_sym != nullptr && pf_var != nullptr) { - String symbol = pf_sym().cast(); - Array variables = pf_var().cast>(); + if (pf_sym.has_value() && pf_var.has_value()) { + String symbol = (*pf_sym)().cast(); + Array variables = (*pf_var)().cast>(); for (size_t i = 0; i < variables.size(); i++) { symbol_const_vars.push_back(variables[i].operator std::string()); } @@ -465,18 +465,18 @@ void LinkModules(ObjectPtr exec, const MapImportModule(lib); for (const auto& it : ext_libs) { - const_loader_mod.Import(it); + const_loader_mod->ImportModule(it); } - exec->Import(const_loader_mod); + exec->ImportModule(const_loader_mod); } else { // directly import the ext_modules as we don't need const loader - exec->Import(lib); + exec->ImportModule(lib); for (const auto& it : ext_libs) { - exec->Import(it); + exec->ImportModule(it); } } } @@ -484,14 +484,14 @@ void LinkModules(ObjectPtr exec, const Map lib, Array ext_libs, - Map params) { +ffi::Module VMLink(ExecBuilder builder, Target target, Optional lib, + Array ext_libs, Map params) { ObjectPtr executable = builder->Get(); if (!lib.defined()) { - lib = codegen::CSourceModuleCreate(";", "", Array{}); + lib = codegen::CSourceModuleCreate(";", "c", Array{}); } LinkModules(executable, params, lib.value(), ext_libs); - return Module(executable); + return ffi::Module(executable); } TVM_FFI_STATIC_INIT_BLOCK({ diff --git a/src/relax/backend/vm/exec_builder.cc b/src/relax/backend/vm/exec_builder.cc index bb43b4ef033d..8e229c4fe641 100644 --- a/src/relax/backend/vm/exec_builder.cc +++ b/src/relax/backend/vm/exec_builder.cc @@ -374,7 +374,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ [](ExecBuilder builder, String value) { return builder->GetFunction(value).data(); }) .def("relax.ExecBuilderGet", [](ExecBuilder builder) { ObjectPtr p_exec = builder->Get(); - return runtime::Module(p_exec); + return ffi::Module(p_exec); }); }); diff --git a/src/relax/transform/fold_constant.cc b/src/relax/transform/fold_constant.cc index fc7c6b26df10..c1aee73cc258 100644 --- a/src/relax/transform/fold_constant.cc +++ b/src/relax/transform/fold_constant.cc @@ -119,8 +119,8 @@ class ConstantFolder : public ExprMutator { // TODO(Hongyi): further check and narrow the scope of foldable function const auto pf = tvm::ffi::Function::GetGlobalRequired("tir.build"); func = WithAttr(func, tvm::attr::kGlobalSymbol, String("tir_function")); - runtime::Module rt_module = pf(func, eval_cpu_target).cast(); - build_func = rt_module.GetFunction("tir_function"); + ffi::Module rt_module = pf(func, eval_cpu_target).cast(); + build_func = rt_module->GetFunction("tir_function"); } catch (const tvm::Error& err) { // build failure may happen in which case we skip DLOG(WARNING) << "Build failure for function " << func << ", Error message: " << err.what(); diff --git a/src/relax/transform/run_codegen.cc b/src/relax/transform/run_codegen.cc index a4a109cb0e22..0cc0a070aac5 100644 --- a/src/relax/transform/run_codegen.cc +++ b/src/relax/transform/run_codegen.cc @@ -80,7 +80,7 @@ class CodeGenRunner : ExprMutator { auto out_mod = builder_->GetContextIRModule(); if (ext_mods.size()) { - if (auto opt_old_ext_mods = mod->GetAttr>(tvm::attr::kExternalMods)) { + if (auto opt_old_ext_mods = mod->GetAttr>(tvm::attr::kExternalMods)) { auto old_ext_mods = opt_old_ext_mods.value(); ext_mods.insert(ext_mods.begin(), old_ext_mods.begin(), old_ext_mods.end()); } @@ -168,7 +168,7 @@ class CodeGenRunner : ExprMutator { } private: - Array InvokeCodegen(IRModule mod, Map target_options) { + Array InvokeCodegen(IRModule mod, Map target_options) { std::unordered_map> target_functions; for (const auto& entry : mod->functions) { @@ -186,7 +186,7 @@ class CodeGenRunner : ExprMutator { }); } - Array ext_mods; + Array ext_mods; for (const auto& [target, functions] : target_functions) { OptionMap options = target_options.Get(target).value_or(OptionMap()); @@ -196,8 +196,8 @@ class CodeGenRunner : ExprMutator { const auto codegen = tvm::ffi::Function::GetGlobal(codegen_name); ICHECK(codegen.has_value()) << "Codegen is not found: " << codegen_name << "\n"; - Array compiled_functions = - (*codegen)(functions, options, constant_names).cast>(); + Array compiled_functions = + (*codegen)(functions, options, constant_names).cast>(); ext_mods.insert(ext_mods.end(), compiled_functions.begin(), compiled_functions.end()); } diff --git a/src/runtime/const_loader_module.cc b/src/runtime/const_loader_module.cc index 27daf7cc3e01..2c02fb556c73 100644 --- a/src/runtime/const_loader_module.cc +++ b/src/runtime/const_loader_module.cc @@ -27,12 +27,13 @@ * code and constants significantly reduces the efforts for handling external * codegen and runtimes. */ +#include #include #include +#include #include #include #include -#include #include #include @@ -44,9 +45,9 @@ namespace runtime { * \brief The const-loader module is designed to manage initialization of the * imported submodules for the C++ runtime. */ -class ConstLoaderModuleNode : public ModuleNode { +class ConstLoaderModuleObj : public ffi::ModuleObj { public: - ConstLoaderModuleNode( + ConstLoaderModuleObj( const std::unordered_map& const_var_ndarray, const std::unordered_map>& const_vars_by_symbol) : const_var_ndarray_(const_var_ndarray), const_vars_by_symbol_(const_vars_by_symbol) { @@ -66,7 +67,7 @@ class ConstLoaderModuleNode : public ModuleNode { } } - ffi::Function GetFunction(const String& name, const ObjectPtr& sptr_to_self) final { + ffi::Optional GetFunction(const String& name) final { VLOG(1) << "ConstLoaderModuleNode::GetFunction(" << name << ")"; // Initialize and memoize the module. // Usually, we have some warmup runs. The module initialization should be @@ -75,9 +76,10 @@ class ConstLoaderModuleNode : public ModuleNode { this->InitSubModule(name); initialized_[name] = true; } + ObjectRef _self = ffi::GetRef(this); if (name == "get_const_var_ndarray") { - return ffi::Function([sptr_to_self, this](ffi::PackedArgs args, ffi::Any* rv) { + return ffi::Function([_self, this](ffi::PackedArgs args, ffi::Any* rv) { Map ret_map; for (const auto& kv : const_var_ndarray_) { ret_map.Set(kv.first, kv.second); @@ -89,18 +91,18 @@ class ConstLoaderModuleNode : public ModuleNode { // Run the module. // Normally we would only have a limited number of submodules. The runtime // symobl lookup overhead should be minimal. - ICHECK(!this->imports().empty()); - for (Module it : this->imports()) { - ffi::Function pf = it.GetFunction(name); - if (pf != nullptr) return pf; + ICHECK(!this->imports_.empty()); + for (const Any& it : this->imports_) { + ffi::Optional pf = it.cast()->GetFunction(name); + if (pf.has_value()) return pf.value(); } - return ffi::Function(nullptr); + return std::nullopt; } - const char* type_key() const final { return "const_loader"; } + const char* kind() const final { return "const_loader"; } /*! \brief Get the property of the runtime module .*/ - int GetPropertyMask() const final { return ModulePropertyMask::kBinarySerializable; }; + int GetPropertyMask() const final { return ffi::Module::kBinarySerializable; }; /*! * \brief Get the list of constants that is required by the given module. @@ -134,15 +136,14 @@ class ConstLoaderModuleNode : public ModuleNode { * found module accordingly by passing the needed constants into it. */ void InitSubModule(const std::string& symbol) { - ffi::Function init(nullptr); - for (Module it : this->imports()) { + for (const Any& it : this->imports_) { // Get the initialization function from the imported modules. std::string init_name = "__init_" + symbol; - init = it.GetFunction(init_name, false); - if (init != nullptr) { + Optional init = it.cast()->GetFunction(init_name, false); + if (init.has_value()) { auto md = GetRequiredConstants(symbol); // Initialize the module with constants. - int ret = init(md).cast(); + int ret = (*init)(md).cast(); // Report the error if initialization is failed. ICHECK_EQ(ret, 0); break; @@ -150,7 +151,11 @@ class ConstLoaderModuleNode : public ModuleNode { } } - void SaveToBinary(dmlc::Stream* stream) final { + ffi::Bytes SaveToBytes() const final { + std::string bytes_buffer; + dmlc::MemoryStringStream ms(&bytes_buffer); + dmlc::Stream* stream = &ms; + std::vector variables; std::vector const_var_ndarray; for (const auto& it : const_var_ndarray_) { @@ -182,10 +187,12 @@ class ConstLoaderModuleNode : public ModuleNode { for (uint64_t i = 0; i < sz; i++) { stream->Write(const_vars[i]); } + return ffi::Bytes(bytes_buffer); } - static Module LoadFromBinary(void* strm) { - dmlc::Stream* stream = static_cast(strm); + static ffi::Module LoadFromBytes(const ffi::Bytes& bytes) { + dmlc::MemoryFixedSizeStream ms(const_cast(bytes.data()), bytes.size()); + dmlc::Stream* stream = &ms; // Load the variables. std::vector variables; @@ -225,8 +232,8 @@ class ConstLoaderModuleNode : public ModuleNode { const_vars_by_symbol[symbols[i]] = const_vars[i]; } - auto n = make_object(const_var_ndarray, const_vars_by_symbol); - return Module(n); + auto n = make_object(const_var_ndarray, const_vars_by_symbol); + return ffi::Module(n); } private: @@ -241,17 +248,17 @@ class ConstLoaderModuleNode : public ModuleNode { std::unordered_map> const_vars_by_symbol_; }; -Module ConstLoaderModuleCreate( +ffi::Module ConstLoaderModuleCreate( const std::unordered_map& const_var_ndarray, const std::unordered_map>& const_vars_by_symbol) { - auto n = make_object(const_var_ndarray, const_vars_by_symbol); - return Module(n); + auto n = make_object(const_var_ndarray, const_vars_by_symbol); + return ffi::Module(n); } TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("runtime.module.loadbinary_const_loader", - ConstLoaderModuleNode::LoadFromBinary); + refl::GlobalDef().def("ffi.Module.load_from_bytes.const_loader", + ConstLoaderModuleObj::LoadFromBytes); }); } // namespace runtime diff --git a/src/runtime/const_loader_module.h b/src/runtime/const_loader_module.h index eb548dfcf370..c093818763d8 100644 --- a/src/runtime/const_loader_module.h +++ b/src/runtime/const_loader_module.h @@ -43,7 +43,7 @@ namespace runtime { * * \return The created ConstLoaderModule. */ -Module ConstLoaderModuleCreate( +ffi::Module ConstLoaderModuleCreate( const std::unordered_map& const_var_ndarray, const std::unordered_map>& const_vars_by_symbol); diff --git a/src/runtime/contrib/arm_compute_lib/acl_runtime.cc b/src/runtime/contrib/arm_compute_lib/acl_runtime.cc index 183d4c0a5b27..3de9e85a57c5 100644 --- a/src/runtime/contrib/arm_compute_lib/acl_runtime.cc +++ b/src/runtime/contrib/arm_compute_lib/acl_runtime.cc @@ -69,7 +69,7 @@ class ACLRuntime : public JSONRuntimeBase { * * \return module type key. */ - const char* type_key() const override { return "arm_compute_lib"; } + const char* kind() const override { return "arm_compute_lib"; } /*! * \brief Initialize runtime. Create ACL layer from JSON @@ -588,18 +588,18 @@ class ACLRuntime : public JSONRuntimeBase { } #endif }; -runtime::Module ACLRuntimeCreate(const String& symbol_name, const String& graph_json, - const Array& const_names) { +ffi::Module ACLRuntimeCreate(const String& symbol_name, const String& graph_json, + const Array& const_names) { auto n = make_object(symbol_name, graph_json, const_names); - return runtime::Module(n); + return ffi::Module(n); } TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("runtime.arm_compute_lib_runtime_create", ACLRuntimeCreate) - .def("runtime.module.loadbinary_arm_compute_lib", - JSONRuntimeBase::LoadFromBinary); + .def("ffi.Module.load_from_bytes.arm_compute_lib", + JSONRuntimeBase::LoadFromBytes); }); } // namespace contrib } // namespace runtime diff --git a/src/runtime/contrib/bnns/bnns_json_runtime.cc b/src/runtime/contrib/bnns/bnns_json_runtime.cc index 0bd961524e0c..9080eeb9bb34 100644 --- a/src/runtime/contrib/bnns/bnns_json_runtime.cc +++ b/src/runtime/contrib/bnns/bnns_json_runtime.cc @@ -91,7 +91,7 @@ class BNNSJSONRuntime : public JSONRuntimeBase { const Array const_names) : JSONRuntimeBase(symbol_name, graph_json, const_names) {} - const char* type_key() const override { return "bnns_json"; } + const char* kind() const override { return "bnns_json"; } void Init(const Array& consts) override { ICHECK_EQ(consts.size(), const_idx_.size()) @@ -557,17 +557,17 @@ class BNNSJSONRuntime : public JSONRuntimeBase { std::vector tensors_eid_; }; -runtime::Module BNNSJSONRuntimeCreate(String symbol_name, String graph_json, - const Array& const_names) { +ffi::Module BNNSJSONRuntimeCreate(String symbol_name, String graph_json, + const Array& const_names) { auto n = make_object(symbol_name, graph_json, const_names); - return runtime::Module(n); + return ffi::Module(n); } TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("runtime.BNNSJSONRuntimeCreate", BNNSJSONRuntimeCreate) - .def("runtime.module.loadbinary_bnns_json", BNNSJSONRuntime::LoadFromBinary); + .def("ffi.Module.load_from_bytes.bnns_json", JSONRuntimeBase::LoadFromBytes); }); } // namespace contrib diff --git a/src/runtime/contrib/clml/clml_runtime.cc b/src/runtime/contrib/clml/clml_runtime.cc index 6b96cbb41bec..9d13e427b24a 100644 --- a/src/runtime/contrib/clml/clml_runtime.cc +++ b/src/runtime/contrib/clml/clml_runtime.cc @@ -193,7 +193,7 @@ class CLMLRuntime : public JSONRuntimeBase { * * \return module type key. */ - const char* type_key() const override { return "clml"; } + const char* kind() const override { return "clml"; } /*! * \brief Initialize runtime. Create CLML layer from JSON @@ -1826,17 +1826,17 @@ class CLMLRuntime : public JSONRuntimeBase { std::string clml_symbol; }; -runtime::Module CLMLRuntimeCreate(const String& symbol_name, const String& graph_json, - const Array& const_names) { +ffi::Module CLMLRuntimeCreate(const String& symbol_name, const String& graph_json, + const Array& const_names) { auto n = make_object(symbol_name, graph_json, const_names); - return runtime::Module(n); + return ffi::Module(n); } TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("runtime.clml_runtime_create", CLMLRuntimeCreate) - .def("runtime.module.loadbinary_clml", JSONRuntimeBase::LoadFromBinary); + .def("ffi.Module.load_from_bytes.clml", JSONRuntimeBase::LoadFromBytes); }); } // namespace contrib } // namespace runtime diff --git a/src/runtime/contrib/coreml/coreml_runtime.h b/src/runtime/contrib/coreml/coreml_runtime.h index 5f5eec1d03ca..257b624bbf2b 100644 --- a/src/runtime/contrib/coreml/coreml_runtime.h +++ b/src/runtime/contrib/coreml/coreml_runtime.h @@ -29,6 +29,7 @@ #import #include +#include #include #include @@ -95,7 +96,7 @@ class CoreMLModel { * This runtime can be accessed in various language via * TVM runtime ffi::Function API. */ -class CoreMLRuntime : public ModuleNode { +class CoreMLRuntime : public ffi::ModuleObj { public: /*! * \brief Get member function to front-end. @@ -103,11 +104,11 @@ class CoreMLRuntime : public ModuleNode { * \param sptr_to_self The pointer to the module node. * \return The corresponding member function. */ - virtual ffi::Function GetFunction(const String& name, const ObjectPtr& sptr_to_self); + virtual Optional GetFunction(const String& name); /*! \brief Get the property of the runtime module .*/ int GetPropertyMask() const final { - return ModulePropertyMask::kBinarySerializable | ModulePropertyMask::kRunnable; + return ffi::Module::kBinarySerializable | ffi::Module::kRunnable; } /*! @@ -115,12 +116,12 @@ class CoreMLRuntime : public ModuleNode { * binary stream. * \param stream The binary stream to save to. */ - void SaveToBinary(dmlc::Stream* stream) final; + ffi::Bytes SaveToBytes() const final; /*! * \return The type key of the executor. */ - const char* type_key() const { return "coreml"; } + const char* kind() const { return "coreml"; } /*! * \brief Initialize the coreml runtime with coreml model and context. diff --git a/src/runtime/contrib/coreml/coreml_runtime.mm b/src/runtime/contrib/coreml/coreml_runtime.mm index 6dfd7a67e5b4..8e0b2542b443 100644 --- a/src/runtime/contrib/coreml/coreml_runtime.mm +++ b/src/runtime/contrib/coreml/coreml_runtime.mm @@ -129,8 +129,7 @@ model_ = std::unique_ptr(new CoreMLModel(url)); } -ffi::Function CoreMLRuntime::GetFunction(const String& name, - const ObjectPtr& sptr_to_self) { +Optional CoreMLRuntime::GetFunction(const String& name) { // Return member functions during query. if (name == "invoke" || name == "run") { return ffi::Function([this](ffi::PackedArgs args, ffi::Any* rv) { model_->Invoke(); }); @@ -183,14 +182,14 @@ *rv = out; }); } else { - return ffi::Function(); + return std::nullopt; } } -Module CoreMLRuntimeCreate(const std::string& symbol, const std::string& model_path) { +ffi::Module CoreMLRuntimeCreate(const std::string& symbol, const std::string& model_path) { auto exec = make_object(); exec->Init(symbol, model_path); - return Module(exec); + return ffi::Module(exec); } TVM_FFI_STATIC_INIT_BLOCK({ @@ -200,7 +199,10 @@ Module CoreMLRuntimeCreate(const std::string& symbol, const std::string& model_p }); }); -void CoreMLRuntime::SaveToBinary(dmlc::Stream* stream) { +ffi::Bytes CoreMLRuntime::SaveToBytes() const { + std::string buffer; + dmlc::MemoryStringStream ms(&buffer); + dmlc::Stream* stream = &ms; NSURL* url = model_->url_; NSFileWrapper* dirWrapper = [[[NSFileWrapper alloc] initWithURL:url options:0 error:nil] autorelease]; @@ -209,6 +211,7 @@ Module CoreMLRuntimeCreate(const std::string& symbol, const std::string& model_p stream->Write((uint64_t)[dirData length]); stream->Write([dirData bytes], [dirData length]); DLOG(INFO) << "Save " << symbol_ << " (" << [dirData length] << " bytes)"; + return ffi::Bytes(buffer); } /*! @@ -218,8 +221,9 @@ Module CoreMLRuntimeCreate(const std::string& symbol, const std::string& model_p * * \return The created CoreML module. */ -Module CoreMLRuntimeLoadFromBinary(void* strm) { - dmlc::Stream* stream = static_cast(strm); +ffi::Module CoreMLRuntimeLoadFromBytes(const ffi::Bytes& bytes) { + dmlc::MemoryFixedSizeStream ms(const_cast(bytes.data()), bytes.size()); + dmlc::Stream* stream = &ms; NSString* tempBaseDir = NSTemporaryDirectory(); if (tempBaseDir == nil) tempBaseDir = @"/tmp"; @@ -249,12 +253,12 @@ Module CoreMLRuntimeLoadFromBinary(void* strm) { auto exec = make_object(); exec->Init(symbol, [model_path UTF8String]); - return Module(exec); + return ffi::Module(exec); } TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("runtime.module.loadbinary_coreml", CoreMLRuntimeLoadFromBinary); + refl::GlobalDef().def("ffi.Module.load_from_bytes.coreml", CoreMLRuntimeLoadFromBytes); }); } // namespace runtime diff --git a/src/runtime/contrib/cublas/cublas_json_runtime.cc b/src/runtime/contrib/cublas/cublas_json_runtime.cc index 9f2cfaa50698..11fa3b0c4d49 100644 --- a/src/runtime/contrib/cublas/cublas_json_runtime.cc +++ b/src/runtime/contrib/cublas/cublas_json_runtime.cc @@ -49,21 +49,22 @@ class CublasJSONRuntime : public JSONRuntimeBase { void Init(const Array& consts) override {} - ffi::Function GetFunction(const String& name, const ObjectPtr& sptr_to_self) override { + ffi::Optional GetFunction(const String& name) override { // JSONRuntimeBase::SetInputOutputBuffers(...) is not thread safe. Since CublasJSONRuntime // can be used by multiple GPUs running on different threads, we avoid using that function // and directly call cuBLAS on the inputs from ffi::PackedArgs. + ObjectPtr sptr_to_self = ffi::GetObjectPtr(this); if (this->symbol_name_ == name) { return ffi::Function([sptr_to_self, this](ffi::PackedArgs args, ffi::Any* rv) { ICHECK(this->initialized_) << "The module has not been initialized"; this->Run(args); }); } else { - return JSONRuntimeBase::GetFunction(name, sptr_to_self); + return JSONRuntimeBase::GetFunction(name); } } - const char* type_key() const override { return "cublas_json"; } // May be overridden + const char* kind() const override { return "cublas_json"; } // May be overridden void Run(ffi::PackedArgs args) { auto* entry_ptr = tvm::contrib::CuBlasLtThreadEntry::ThreadLocal(); @@ -148,18 +149,18 @@ class CublasJSONRuntime : public JSONRuntimeBase { void Run() override { LOG(FATAL) << "Unreachable"; } }; -runtime::Module CublasJSONRuntimeCreate(String symbol_name, String graph_json, - const Array& const_names) { +ffi::Module CublasJSONRuntimeCreate(String symbol_name, String graph_json, + const Array& const_names) { auto n = make_object(symbol_name, graph_json, const_names); - return runtime::Module(n); + return ffi::Module(n); } TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("runtime.CublasJSONRuntimeCreate", CublasJSONRuntimeCreate) - .def("runtime.module.loadbinary_cublas_json", - JSONRuntimeBase::LoadFromBinary); + .def("ffi.Module.load_from_bytes.cublas_json", + JSONRuntimeBase::LoadFromBytes); }); } // namespace contrib diff --git a/src/runtime/contrib/cudnn/cudnn_json_runtime.cc b/src/runtime/contrib/cudnn/cudnn_json_runtime.cc index d54fad9d99ab..fd4fa68c783c 100644 --- a/src/runtime/contrib/cudnn/cudnn_json_runtime.cc +++ b/src/runtime/contrib/cudnn/cudnn_json_runtime.cc @@ -69,7 +69,7 @@ class cuDNNJSONRuntime : public JSONRuntimeBase { } } - const char* type_key() const override { return "cudnn_json"; } // May be overridden + const char* kind() const override { return "cudnn_json"; } // May be overridden void Run() override { for (const auto& f : op_execs_) { @@ -232,18 +232,18 @@ class cuDNNJSONRuntime : public JSONRuntimeBase { std::vector> op_execs_; }; -runtime::Module cuDNNJSONRuntimeCreate(String symbol_name, String graph_json, - const Array& const_names) { +ffi::Module cuDNNJSONRuntimeCreate(String symbol_name, String graph_json, + const Array& const_names) { auto n = make_object(symbol_name, graph_json, const_names); - return runtime::Module(n); + return ffi::Module(n); } TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("runtime.cuDNNJSONRuntimeCreate", cuDNNJSONRuntimeCreate) - .def("runtime.module.loadbinary_cudnn_json", - JSONRuntimeBase::LoadFromBinary); + .def("ffi.Module.load_from_bytes.cudnn_json", + JSONRuntimeBase::LoadFromBytes); }); } // namespace contrib diff --git a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc index 138f41cb7751..686a8048c7b5 100644 --- a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc +++ b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc @@ -58,7 +58,7 @@ class DNNLJSONRuntime : public JSONRuntimeBase { for (const auto e : outputs_) run_arg_eid_.push_back(EntryID(e)); } - const char* type_key() const override { return "dnnl_json"; } + const char* kind() const override { return "dnnl_json"; } void Init(const Array& consts) override { ICHECK_EQ(consts.size(), const_idx_.size()) @@ -100,7 +100,8 @@ class DNNLJSONRuntime : public JSONRuntimeBase { } /* Override GetFunction to reimplement Run method */ - ffi::Function GetFunction(const String& name, const ObjectPtr& sptr_to_self) override { + ffi::Optional GetFunction(const String& name) override { + ObjectPtr sptr_to_self = ffi::GetObjectPtr(this); if (this->symbol_name_ == name) { return ffi::Function([sptr_to_self, this](ffi::PackedArgs args, ffi::Any* rv) { ICHECK(this->initialized_) << "The module has not been initialized"; @@ -111,7 +112,7 @@ class DNNLJSONRuntime : public JSONRuntimeBase { Run(args); }); } else { - return JSONRuntimeBase::GetFunction(name, sptr_to_self); + return JSONRuntimeBase::GetFunction(name); } } @@ -922,17 +923,17 @@ class DNNLJSONRuntime : public JSONRuntimeBase { std::vector run_arg_eid_; }; -runtime::Module DNNLJSONRuntimeCreate(String symbol_name, String graph_json, - const Array& const_names) { +ffi::Module DNNLJSONRuntimeCreate(String symbol_name, String graph_json, + const Array& const_names) { auto n = make_object(symbol_name, graph_json, const_names); - return runtime::Module(n); + return ffi::Module(n); } TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("runtime.DNNLJSONRuntimeCreate", DNNLJSONRuntimeCreate) - .def("runtime.module.loadbinary_dnnl_json", JSONRuntimeBase::LoadFromBinary); + .def("ffi.Module.load_from_bytes.dnnl_json", JSONRuntimeBase::LoadFromBytes); }); } // namespace contrib diff --git a/src/runtime/contrib/edgetpu/edgetpu_runtime.cc b/src/runtime/contrib/edgetpu/edgetpu_runtime.cc index 515eac9489b6..a52da2318b71 100644 --- a/src/runtime/contrib/edgetpu/edgetpu_runtime.cc +++ b/src/runtime/contrib/edgetpu/edgetpu_runtime.cc @@ -63,10 +63,10 @@ void EdgeTPURuntime::Init(const std::string& tflite_model_bytes, Device dev) { device_ = dev; } -Module EdgeTPURuntimeCreate(const std::string& tflite_model_bytes, Device dev) { +ffi::Module EdgeTPURuntimeCreate(const std::string& tflite_model_bytes, Device dev) { auto exec = make_object(); exec->Init(tflite_model_bytes, dev); - return Module(exec); + return ffi::Module(exec); } TVM_FFI_STATIC_INIT_BLOCK({ diff --git a/src/runtime/contrib/hipblas/hipblas_json_runtime.cc b/src/runtime/contrib/hipblas/hipblas_json_runtime.cc index fbac6f12fea9..5750b91ab4ca 100644 --- a/src/runtime/contrib/hipblas/hipblas_json_runtime.cc +++ b/src/runtime/contrib/hipblas/hipblas_json_runtime.cc @@ -47,21 +47,22 @@ class HipblasJSONRuntime : public JSONRuntimeBase { void Init(const Array& consts) override {} - ffi::Function GetFunction(const String& name, const ObjectPtr& sptr_to_self) override { + ffi::Optional GetFunction(const String& name) override { // JSONRuntimeBase::SetInputOutputBuffers(...) is not thread safe. Since HipblasJSONRuntime // can be used by multiple GPUs running on different threads, we avoid using that function // and directly call hipBLAS on the inputs from ffi::PackedArgs. + ObjectPtr sptr_to_self = ffi::GetObjectPtr(this); if (this->symbol_name_ == name) { return ffi::Function([sptr_to_self, this](ffi::PackedArgs args, ffi::Any* rv) { ICHECK(this->initialized_) << "The module has not been initialized"; this->Run(args); }); } else { - return JSONRuntimeBase::GetFunction(name, sptr_to_self); + return JSONRuntimeBase::GetFunction(name); } } - const char* type_key() const override { return "hipblas_json"; } // May be overridden + const char* kind() const override { return "hipblas_json"; } // May be overridden void Run(ffi::PackedArgs args) { auto* entry_ptr = tvm::contrib::HipBlasLtThreadEntry::ThreadLocal(); @@ -134,18 +135,18 @@ class HipblasJSONRuntime : public JSONRuntimeBase { void Run() override { LOG(FATAL) << "Unreachable"; } }; -runtime::Module HipblasJSONRuntimeCreate(String symbol_name, String graph_json, - const Array& const_names) { +ffi::Module HipblasJSONRuntimeCreate(String symbol_name, String graph_json, + const Array& const_names) { auto n = make_object(symbol_name, graph_json, const_names); - return runtime::Module(n); + return ffi::Module(n); } TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("runtime.HipblasJSONRuntimeCreate", HipblasJSONRuntimeCreate) - .def("runtime.module.loadbinary_hipblas_json", - JSONRuntimeBase::LoadFromBinary); + .def("ffi.Module.load_from_bytes.hipblas_json", + JSONRuntimeBase::LoadFromBytes); }); } // namespace contrib diff --git a/src/runtime/contrib/json/json_runtime.h b/src/runtime/contrib/json/json_runtime.h index 025e85263ebc..d9e5af60f299 100644 --- a/src/runtime/contrib/json/json_runtime.h +++ b/src/runtime/contrib/json/json_runtime.h @@ -25,7 +25,7 @@ #ifndef TVM_RUNTIME_CONTRIB_JSON_JSON_RUNTIME_H_ #define TVM_RUNTIME_CONTRIB_JSON_JSON_RUNTIME_H_ -#include +#include #include #include @@ -47,7 +47,7 @@ namespace json { * \brief A json runtime that executes the serialized JSON format. This runtime * can be extended by user defined runtime for execution. */ -class JSONRuntimeBase : public ModuleNode { +class JSONRuntimeBase : public ffi::ModuleObj { public: JSONRuntimeBase(const std::string& symbol_name, const std::string& graph_json, const Array const_names) @@ -55,13 +55,11 @@ class JSONRuntimeBase : public ModuleNode { LoadGraph(graph_json_); } - ~JSONRuntimeBase() override = default; - - const char* type_key() const override { return "json"; } // May be overridden + const char* kind() const override { return "json"; } // May be overridden /*! \brief Get the property of the runtime module .*/ int GetPropertyMask() const override { - return ModulePropertyMask::kBinarySerializable | ModulePropertyMask::kRunnable; + return ffi::Module::kBinarySerializable | ffi::Module::kRunnable; } /*! \brief Initialize a specific json runtime. */ @@ -95,7 +93,8 @@ class JSONRuntimeBase : public ModuleNode { * \param sptr_to_self The pointer to the module node. * \return The packed function. */ - ffi::Function GetFunction(const String& name, const ObjectPtr& sptr_to_self) override { + Optional GetFunction(const String& name) override { + ObjectPtr sptr_to_self = ffi::GetObjectPtr(this); if (name == "get_symbol") { return ffi::Function( [sptr_to_self, this](ffi::PackedArgs args, ffi::Any* rv) { *rv = this->symbol_name_; }); @@ -148,11 +147,14 @@ class JSONRuntimeBase : public ModuleNode { *rv = 0; }); } else { - return ffi::Function(nullptr); + return std::nullopt; } } - void SaveToBinary(dmlc::Stream* stream) override { + ffi::Bytes SaveToBytes() const override { + std::string buffer; + dmlc::MemoryStringStream ms(&buffer); + dmlc::Stream* stream = &ms; // Save the symbol stream->Write(symbol_name_); // Save the graph @@ -163,12 +165,14 @@ class JSONRuntimeBase : public ModuleNode { consts.push_back(it); } stream->Write(consts); + return ffi::Bytes(buffer); } template ::value>::type> - static Module LoadFromBinary(void* strm) { - dmlc::Stream* stream = static_cast(strm); + static ffi::Module LoadFromBytes(const ffi::Bytes& bytes) { + dmlc::MemoryFixedSizeStream ms(const_cast(bytes.data()), bytes.size()); + dmlc::Stream* stream = &ms; std::string symbol; std::string graph_json; std::vector consts; @@ -181,7 +185,7 @@ class JSONRuntimeBase : public ModuleNode { const_names.push_back(it); } auto n = make_object(symbol, graph_json, const_names); - return Module(n); + return ffi::Module(n); } /*! @@ -190,7 +194,7 @@ class JSONRuntimeBase : public ModuleNode { * \param format the format to return. * \return A string of JSON. */ - String GetSource(const String& format = "json") override { return graph_json_; } + String InspectSource(const String& format) const override { return graph_json_; } protected: /*! diff --git a/src/runtime/contrib/mrvl/mrvl_hw_runtime.cc b/src/runtime/contrib/mrvl/mrvl_hw_runtime.cc index d531963ec822..bc1eb77ea18c 100644 --- a/src/runtime/contrib/mrvl/mrvl_hw_runtime.cc +++ b/src/runtime/contrib/mrvl/mrvl_hw_runtime.cc @@ -23,9 +23,9 @@ */ #include +#include #include #include -#include #include #include @@ -155,7 +155,7 @@ hardware and then runs the generated binary on the target hardware. * */ -class MarvellHardwareModuleNode : public ModuleNode { +class MarvellHardwareModuleNode : public ffi::ModuleObj { public: MarvellHardwareModuleNode(const std::string& symbol_name, const std::string& nodes_json, const std::string& bin_code, const int input_count, @@ -200,10 +200,10 @@ class MarvellHardwareModuleNode : public ModuleNode { } } - const char* type_key() const { return "mrvl_hw"; } + const char* kind() const { return "mrvl_hw"; } int GetPropertyMask() const final { - return ModulePropertyMask::kBinarySerializable | ModulePropertyMask::kRunnable; + return ffi::Module::kBinarySerializable | ffi::Module::kRunnable; } /*! @@ -212,7 +212,8 @@ class MarvellHardwareModuleNode : public ModuleNode { * \param sptr_to_self The pointer to the module node. * \return The packed function. */ - virtual ffi::Function GetFunction(const String& name, const ObjectPtr& sptr_to_self) { + virtual Optional GetFunction(const String& name) { + ObjectPtr sptr_to_self = ffi::GetObjectPtr(this); if (name == "get_symbol") { return ffi::Function( [sptr_to_self, this](ffi::PackedArgs args, ffi::Any* rv) { *rv = this->symbol_name_; }); @@ -240,10 +241,13 @@ class MarvellHardwareModuleNode : public ModuleNode { *rv = 0; }); } - return ffi::Function(nullptr); + return std::nullopt; } - virtual void SaveToBinary(dmlc::Stream* stream) { + virtual ffi::Bytes SaveToBytes() const { + std::string buffer; + dmlc::MemoryStringStream ms(&buffer); + dmlc::Stream* stream = &ms; // Save the symbol name and other data and serialize them to // binary format. stream->Write(symbol_name_); @@ -252,10 +256,12 @@ class MarvellHardwareModuleNode : public ModuleNode { stream->Write(num_inputs_); stream->Write(num_outputs_); stream->Write(run_arg.num_batches); + return ffi::Bytes(buffer); } - static Module LoadFromBinary(void* strm) { - dmlc::Stream* stream = static_cast(strm); + static ffi::Module LoadFromBytes(const ffi::Bytes& bytes) { + dmlc::MemoryFixedSizeStream ms(const_cast(bytes.data()), bytes.size()); + dmlc::Stream* stream = &ms; std::string symbol_name; std::string nodes_json; std::string bin_code; @@ -270,7 +276,7 @@ class MarvellHardwareModuleNode : public ModuleNode { ICHECK(stream->Read(&batch_size)) << "Loading batch_size failed"; auto n = make_object(symbol_name, nodes_json, bin_code, num_inputs, num_outputs, batch_size); - return Module(n); + return ffi::Module(n); } /*! @@ -279,7 +285,7 @@ class MarvellHardwareModuleNode : public ModuleNode { * \param format the format to return. * \return A string of JSON. */ - String GetSource(const String& format = "json") override { return nodes_json_; } + String InspectSource(const String& format) const override { return nodes_json_; } protected: std::string symbol_name_; @@ -463,12 +469,12 @@ class MarvellHardwareModuleNode : public ModuleNode { } }; -runtime::Module MarvellHardwareModuleRuntimeCreate(const String& symbol_name, - const String& nodes_json, const String& bin_code, - int num_input, int num_output, int batch_size) { +ffi::Module MarvellHardwareModuleRuntimeCreate(const String& symbol_name, const String& nodes_json, + const String& bin_code, int num_input, + int num_output, int batch_size) { auto n = make_object(symbol_name, nodes_json, bin_code, num_input, num_output, batch_size); - return runtime::Module(n); + return ffi::Module(n); } bool MarvellHardwareModuleNode::initialized_model = false; @@ -481,7 +487,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("runtime.mrvl_hw_runtime_create", MarvellHardwareModuleRuntimeCreate) - .def("runtime.module.loadbinary_mrvl_hw", MarvellHardwareModuleNode::LoadFromBinary); + .def("ffi.Module.load_from_bytes.mrvl_hw", MarvellHardwareModuleNode::LoadFromBytes); }); } // namespace contrib } // namespace runtime diff --git a/src/runtime/contrib/mrvl/mrvl_runtime.cc b/src/runtime/contrib/mrvl/mrvl_runtime.cc index b9f9bc960c04..974ca4a69a1f 100644 --- a/src/runtime/contrib/mrvl/mrvl_runtime.cc +++ b/src/runtime/contrib/mrvl/mrvl_runtime.cc @@ -24,9 +24,9 @@ #include #include +#include #include #include -#include #include #include @@ -49,7 +49,7 @@ hardware and then runs the generated binary using the Marvell software simulator * \param bin_code The binary code generated by the Marvell compiler for the subgraph */ -class MarvellSimulatorModuleNode : public ModuleNode { +class MarvellSimulatorModuleNode : public ffi::ModuleObj { public: MarvellSimulatorModuleNode(const std::string& symbol_name, const std::string& nodes_json, const std::string& bin_code) @@ -57,11 +57,11 @@ class MarvellSimulatorModuleNode : public ModuleNode { set_num_inputs_outputs(); } - const char* type_key() const { return "mrvl_sim"; } + const char* kind() const { return "mrvl_sim"; } /*! \brief Get the property of the runtime module .*/ int GetPropertyMask() const final { - return ModulePropertyMask::kBinarySerializable | ModulePropertyMask::kRunnable; + return ffi::Module::kBinarySerializable | ffi::Module::kRunnable; } /*! @@ -70,7 +70,8 @@ class MarvellSimulatorModuleNode : public ModuleNode { * \param sptr_to_self The pointer to the module node. * \return The packed function. */ - virtual ffi::Function GetFunction(const String& name, const ObjectPtr& sptr_to_self) { + virtual Optional GetFunction(const String& name) { + ObjectPtr sptr_to_self = ffi::GetObjectPtr(this); if (name == "get_symbol") { return ffi::Function( [sptr_to_self, this](ffi::PackedArgs args, ffi::Any* rv) { *rv = this->symbol_name_; }); @@ -83,19 +84,24 @@ class MarvellSimulatorModuleNode : public ModuleNode { *rv = 0; }); } - return ffi::Function(nullptr); + return std::nullopt; } - virtual void SaveToBinary(dmlc::Stream* stream) { + virtual ffi::Bytes SaveToBytes() const { + std::string buffer; + dmlc::MemoryStringStream ms(&buffer); + dmlc::Stream* stream = &ms; // Save the symbol name and other data and serialize them to // binary format. stream->Write(symbol_name_); stream->Write(nodes_json_); stream->Write(bin_code_); + return ffi::Bytes(buffer); } - static Module LoadFromBinary(void* strm) { - dmlc::Stream* stream = static_cast(strm); + static ffi::Module LoadFromBytes(const ffi::Bytes& bytes) { + dmlc::MemoryFixedSizeStream ms(const_cast(bytes.data()), bytes.size()); + dmlc::Stream* stream = &ms; std::string symbol_name; std::string nodes_json; std::string bin_code; @@ -106,7 +112,7 @@ class MarvellSimulatorModuleNode : public ModuleNode { << "Marvell-Compiler-ERROR-Internal::Loading nodes json failed"; ICHECK(stream->Read(&bin_code)) << "Marvell-Compiler-ERROR-Internal::Loading bin code failed"; auto n = make_object(symbol_name, nodes_json, bin_code); - return Module(n); + return ffi::Module(n); } /*! @@ -115,7 +121,7 @@ class MarvellSimulatorModuleNode : public ModuleNode { * \param format the format to return. * \return A string of JSON. */ - String GetSource(const String& format = "json") override { return nodes_json_; } + String InspectSource(const String& format) const override { return nodes_json_; } protected: std::string symbol_name_; @@ -143,18 +149,17 @@ class MarvellSimulatorModuleNode : public ModuleNode { } }; -runtime::Module MarvellSimulatorModuleRuntimeCreate(const String& symbol_name, - const String& nodes_json, - const String& bin_code) { +ffi::Module MarvellSimulatorModuleRuntimeCreate(const String& symbol_name, const String& nodes_json, + const String& bin_code) { auto n = make_object(symbol_name, nodes_json, bin_code); - return runtime::Module(n); + return ffi::Module(n); } TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("runtime.mrvl_runtime_create", MarvellSimulatorModuleRuntimeCreate) - .def("runtime.module.loadbinary_mrvl_sim", MarvellSimulatorModuleNode::LoadFromBinary); + .def("ffi.Module.load_from_bytes.mrvl_sim", MarvellSimulatorModuleNode::LoadFromBytes); }); } // namespace contrib } // namespace runtime diff --git a/src/runtime/contrib/msc/tensorrt_runtime.cc b/src/runtime/contrib/msc/tensorrt_runtime.cc index fd65175d6f8e..e19c03d4fda5 100644 --- a/src/runtime/contrib/msc/tensorrt_runtime.cc +++ b/src/runtime/contrib/msc/tensorrt_runtime.cc @@ -65,7 +65,7 @@ class MSCTensorRTRuntime : public JSONRuntimeBase { const Array& const_names) : JSONRuntimeBase(symbol_name, graph_json, const_names) {} - ~MSCTensorRTRuntime() override { + ~MSCTensorRTRuntime() { VLOG(1) << "Destroying MSC TensorRT runtime"; DestroyEngine(); } @@ -75,11 +75,11 @@ class MSCTensorRTRuntime : public JSONRuntimeBase { * * \return module type key. */ - const char* type_key() const final { return "msc_tensorrt"; } + const char* kind() const final { return "msc_tensorrt"; } /*! \brief Get the property of the runtime module .*/ int GetPropertyMask() const final { - return ModulePropertyMask::kBinarySerializable | ModulePropertyMask::kRunnable; + return ffi::Module::kBinarySerializable | ffi::Module::kRunnable; } /*! @@ -343,18 +343,18 @@ class MSCTensorRTRuntime : public JSONRuntimeBase { #endif }; -runtime::Module MSCTensorRTRuntimeCreate(const String& symbol_name, const String& graph_json, - const Array& const_names) { +ffi::Module MSCTensorRTRuntimeCreate(const String& symbol_name, const String& graph_json, + const Array& const_names) { auto n = make_object(symbol_name, graph_json, const_names); - return runtime::Module(n); + return ffi::Module(n); } TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("runtime.msc_tensorrt_runtime_create", MSCTensorRTRuntimeCreate) - .def("runtime.module.loadbinary_msc_tensorrt", - JSONRuntimeBase::LoadFromBinary); + .def("ffi.Module.load_from_bytes.msc_tensorrt", + JSONRuntimeBase::LoadFromBytes); }); } // namespace contrib diff --git a/src/runtime/contrib/nnapi/nnapi_runtime.cc b/src/runtime/contrib/nnapi/nnapi_runtime.cc index 52b4a4711837..71335f3ee287 100644 --- a/src/runtime/contrib/nnapi/nnapi_runtime.cc +++ b/src/runtime/contrib/nnapi/nnapi_runtime.cc @@ -235,17 +235,17 @@ class NNAPIRuntime : public JSONRuntimeBase { #endif // ifdef TVM_GRAPH_EXECUTOR_NNAPI }; -runtime::Module NNAPIRuntimeCreate(const String& symbol_name, const String& graph_json, - const Array& const_names) { +ffi::Module NNAPIRuntimeCreate(const String& symbol_name, const String& graph_json, + const Array& const_names) { auto n = make_object(symbol_name, graph_json, const_names); - return runtime::Module(n); + return ffi::Module(n); } TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("runtime.nnapi_runtime_create", NNAPIRuntimeCreate) - .def("runtime.module.loadbinary_nnapi", JSONRuntimeBase::LoadFromBinary); + .def("ffi.Module.load_from_bytes.nnapi", JSONRuntimeBase::LoadFromBytes); }); } // namespace contrib diff --git a/src/runtime/contrib/tensorrt/tensorrt_runtime.cc b/src/runtime/contrib/tensorrt/tensorrt_runtime.cc index ba9725d9bb10..ff565444e2b5 100644 --- a/src/runtime/contrib/tensorrt/tensorrt_runtime.cc +++ b/src/runtime/contrib/tensorrt/tensorrt_runtime.cc @@ -96,11 +96,11 @@ class TensorRTRuntime : public JSONRuntimeBase { * * \return module type key. */ - const char* type_key() const final { return "tensorrt"; } + const char* kind() const final { return "tensorrt"; } /*! \brief Get the property of the runtime module .*/ int GetPropertyMask() const final { - return ModulePropertyMask::kBinarySerializable | ModulePropertyMask::kRunnable; + return ffi::Module::kBinarySerializable | ffi::Module::kRunnable; } /*! @@ -519,17 +519,17 @@ class TensorRTRuntime : public JSONRuntimeBase { bool use_fp16_; }; -runtime::Module TensorRTRuntimeCreate(const String& symbol_name, const String& graph_json, - const Array& const_names) { +ffi::Module TensorRTRuntimeCreate(const String& symbol_name, const String& graph_json, + const Array& const_names) { auto n = make_object(symbol_name, graph_json, const_names); - return runtime::Module(n); + return ffi::Module(n); } TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("runtime.tensorrt_runtime_create", TensorRTRuntimeCreate) - .def("runtime.module.loadbinary_tensorrt", JSONRuntimeBase::LoadFromBinary); + .def("ffi.Module.load_from_bytes.tensorrt", JSONRuntimeBase::LoadFromBytes); }); } // namespace contrib diff --git a/src/runtime/contrib/tflite/tflite_runtime.cc b/src/runtime/contrib/tflite/tflite_runtime.cc index 0fa3bc2fe64c..c35af35eae13 100644 --- a/src/runtime/contrib/tflite/tflite_runtime.cc +++ b/src/runtime/contrib/tflite/tflite_runtime.cc @@ -87,6 +87,7 @@ DataType TfLiteDType2TVMDType(TfLiteType dtype) { return DataType::Float(16); default: LOG(FATAL) << "tflite data type not support yet: " << dtype; + TVM_FFI_UNREACHABLE(); } } @@ -151,8 +152,8 @@ NDArray TFLiteRuntime::GetOutput(int index) const { return ret; } -ffi::Function TFLiteRuntime::GetFunction(const String& name, - const ObjectPtr& sptr_to_self) { +ffi::Optional TFLiteRuntime::GetFunction(const String& name) { + ObjectPtr sptr_to_self = ffi::GetObjectPtr(this); // Return member functions during query. if (name == "set_input") { return ffi::Function([sptr_to_self, this](ffi::PackedArgs args, ffi::Any* rv) { @@ -174,14 +175,14 @@ ffi::Function TFLiteRuntime::GetFunction(const String& name, this->SetNumThreads(num_threads); }); } else { - return ffi::Function(); + return std::nullopt; } } -Module TFLiteRuntimeCreate(const std::string& tflite_model_bytes, Device dev) { +ffi::Module TFLiteRuntimeCreate(const std::string& tflite_model_bytes, Device dev) { auto exec = make_object(); exec->Init(tflite_model_bytes, dev); - return Module(exec); + return ffi::Module(exec); } TVM_FFI_STATIC_INIT_BLOCK({ diff --git a/src/runtime/contrib/tflite/tflite_runtime.h b/src/runtime/contrib/tflite/tflite_runtime.h index 5e8751a01281..396bd01104d5 100644 --- a/src/runtime/contrib/tflite/tflite_runtime.h +++ b/src/runtime/contrib/tflite/tflite_runtime.h @@ -27,8 +27,8 @@ #include #include +#include #include -#include #include #include @@ -46,7 +46,7 @@ namespace runtime { * This runtime can be accessed in various language via * TVM runtime ffi::Function API. */ -class TFLiteRuntime : public ModuleNode { +class TFLiteRuntime : public ffi::ModuleObj { public: /*! * \brief Get member function to front-end. @@ -54,15 +54,15 @@ class TFLiteRuntime : public ModuleNode { * \param sptr_to_self The pointer to the module node. * \return The corresponding member function. */ - virtual ffi::Function GetFunction(const String& name, const ObjectPtr& sptr_to_self); + virtual Optional GetFunction(const String& name); /*! * \return The type key of the executor. */ - const char* type_key() const { return "TFLiteRuntime"; } + const char* kind() const { return "TFLiteRuntime"; } /*! \brief Get the property of the runtime module .*/ - int GetPropertyMask() const final { return ModulePropertyMask::kRunnable; }; + int GetPropertyMask() const final { return ffi::Module::kRunnable; }; /*! * \brief Invoke the internal tflite interpreter and run the whole model in diff --git a/src/runtime/cuda/cuda_module.cc b/src/runtime/cuda/cuda_module.cc index 6b71df928d23..5a4e682da8da 100644 --- a/src/runtime/cuda/cuda_module.cc +++ b/src/runtime/cuda/cuda_module.cc @@ -24,6 +24,7 @@ #include #include +#include #include #include @@ -46,7 +47,7 @@ namespace runtime { // cuModule is a per-GPU module // The runtime will contain a per-device module table // The modules will be lazily loaded -class CUDAModuleNode : public runtime::ModuleNode { +class CUDAModuleNode : public ffi::ModuleObj { public: explicit CUDAModuleNode(std::string data, std::string fmt, std::unordered_map fmap, @@ -64,16 +65,16 @@ class CUDAModuleNode : public runtime::ModuleNode { } } - const char* type_key() const final { return "cuda"; } + const char* kind() const final { return "cuda"; } /*! \brief Get the property of the runtime module .*/ int GetPropertyMask() const final { - return ModulePropertyMask::kBinarySerializable | ModulePropertyMask::kRunnable; + return ffi::Module::kBinarySerializable | ffi::Module::kRunnable; } - ffi::Function GetFunction(const String& name, const ObjectPtr& sptr_to_self) final; + Optional GetFunction(const String& name) final; - void SaveToFile(const String& file_name, const String& format) final { + void WriteToFile(const String& file_name, const String& format) const final { std::string fmt = GetFileFormat(file_name, format); std::string meta_file = GetMetaFilePath(file_name); if (fmt == "cu") { @@ -87,13 +88,17 @@ class CUDAModuleNode : public runtime::ModuleNode { } } - void SaveToBinary(dmlc::Stream* stream) final { + ffi::Bytes SaveToBytes() const final { + std::string buffer; + dmlc::MemoryStringStream ms(&buffer); + dmlc::Stream* stream = &ms; stream->Write(fmt_); stream->Write(fmap_); stream->Write(data_); + return ffi::Bytes(buffer); } - String GetSource(const String& format) final { + String InspectSource(const String& format) const final { if (format == fmt_) return data_; if (cuda_source_.length() != 0) { return cuda_source_; @@ -205,7 +210,7 @@ class CUDAWrappedFunc { << " grid=(" << wl.grid_dim(0) << "," << wl.grid_dim(1) << "," << wl.grid_dim(2) << "), " << " block=(" << wl.block_dim(0) << "," << wl.block_dim(1) << "," << wl.block_dim(2) << ")\n"; - std::string cuda = m_->GetSource(""); + std::string cuda = m_->InspectSource(""); if (cuda.length() != 0) { os << "// func_name=" << func_name_ << "\n" << "// CUDA Source\n" @@ -255,8 +260,8 @@ class CUDAPrepGlobalBarrier { mutable std::array pcache_; }; -ffi::Function CUDAModuleNode::GetFunction(const String& name, - const ObjectPtr& sptr_to_self) { +Optional CUDAModuleNode::GetFunction(const String& name) { + ObjectPtr sptr_to_self = ffi::GetObjectPtr(this); ICHECK_EQ(sptr_to_self.get(), this); if (name == symbol::tvm_prepare_global_barrier) { return ffi::Function(CUDAPrepGlobalBarrier(this, sptr_to_self)); @@ -269,15 +274,15 @@ ffi::Function CUDAModuleNode::GetFunction(const String& name, return PackFuncVoidAddr(f, info.arg_types, info.arg_extra_tags); } -Module CUDAModuleCreate(std::string data, std::string fmt, - std::unordered_map fmap, - std::string cuda_source) { +ffi::Module CUDAModuleCreate(std::string data, std::string fmt, + std::unordered_map fmap, + std::string cuda_source) { auto n = make_object(data, fmt, fmap, cuda_source); - return Module(n); + return ffi::Module(n); } // Load module from module. -Module CUDAModuleLoadFile(const std::string& file_name, const String& format) { +ffi::Module CUDAModuleLoadFile(const std::string& file_name, const String& format) { std::string data; std::unordered_map fmap; std::string fmt = GetFileFormat(file_name, format); @@ -287,8 +292,9 @@ Module CUDAModuleLoadFile(const std::string& file_name, const String& format) { return CUDAModuleCreate(data, fmt, fmap, std::string()); } -Module CUDAModuleLoadBinary(void* strm) { - dmlc::Stream* stream = static_cast(strm); +ffi::Module CUDAModuleLoadFromBytes(const ffi::Bytes& bytes) { + dmlc::MemoryFixedSizeStream ms(const_cast(bytes.data()), bytes.size()); + dmlc::Stream* stream = &ms; std::string data; std::unordered_map fmap; std::string fmt; @@ -301,9 +307,9 @@ Module CUDAModuleLoadBinary(void* strm) { TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef() - .def("runtime.module.loadfile_cubin", CUDAModuleLoadFile) - .def("runtime.module.loadfile_ptx", CUDAModuleLoadFile) - .def("runtime.module.loadbinary_cuda", CUDAModuleLoadBinary); + .def("ffi.Module.load_from_file.cuda", CUDAModuleLoadFile) + .def("ffi.Module.load_from_file.ptx", CUDAModuleLoadFile) + .def("ffi.Module.load_from_bytes.cuda", CUDAModuleLoadFromBytes); }); } // namespace runtime } // namespace tvm diff --git a/src/runtime/cuda/cuda_module.h b/src/runtime/cuda/cuda_module.h index e65c5fe60811..b92dbe1cc683 100644 --- a/src/runtime/cuda/cuda_module.h +++ b/src/runtime/cuda/cuda_module.h @@ -47,9 +47,9 @@ static constexpr const int kMaxNumGPUs = 32; * \param fmap The map function information map of each function. * \param cuda_source Optional, cuda source file */ -Module CUDAModuleCreate(std::string data, std::string fmt, - std::unordered_map fmap, - std::string cuda_source); +ffi::Module CUDAModuleCreate(std::string data, std::string fmt, + std::unordered_map fmap, + std::string cuda_source); } // namespace runtime } // namespace tvm #endif // TVM_RUNTIME_CUDA_CUDA_MODULE_H_ diff --git a/src/runtime/device_api.cc b/src/runtime/device_api.cc index 947d8884a59c..ae85f9ce5384 100644 --- a/src/runtime/device_api.cc +++ b/src/runtime/device_api.cc @@ -22,6 +22,7 @@ * \brief Device specific implementations */ #include +#include #include #include #include @@ -235,10 +236,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ using namespace tvm::runtime; int TVMBackendGetFuncFromEnv(void* mod_node, const char* func_name, TVMFFIObjectHandle* func) { - TVM_FFI_SAFE_CALL_BEGIN(); - *func = const_cast( - static_cast(mod_node)->GetFuncFromEnv(func_name)->get()); - TVM_FFI_SAFE_CALL_END(); + return TVMFFIEnvLookupFromImports(mod_node, func_name, func); } void* TVMBackendAllocWorkspace(int device_type, int device_id, uint64_t size, int dtype_code_hint, diff --git a/src/runtime/disco/builtin.cc b/src/runtime/disco/builtin.cc index 126a593e5173..b650b143e401 100644 --- a/src/runtime/disco/builtin.cc +++ b/src/runtime/disco/builtin.cc @@ -34,36 +34,39 @@ namespace runtime { class DSOLibraryCache { public: - Module Open(const std::string& library_path) { + ffi::Module Open(const std::string& library_path) { std::lock_guard lock(mutex_); - Module& lib = cache_[library_path]; - if (!lib.defined()) { - lib = Module::LoadFromFile(library_path, ""); + auto it = cache_.find(library_path); + if (it == cache_.end()) { + ffi::Module lib = ffi::Module::LoadFromFile(library_path); + cache_.emplace(library_path, lib); + return lib; } - return lib; + return it->second; } - std::unordered_map cache_; + std::unordered_map cache_; std::mutex mutex_; }; -Module LoadVMModule(std::string path, Optional device) { +ffi::Module LoadVMModule(std::string path, Optional device) { static DSOLibraryCache cache; - Module dso_mod = cache.Open(path); + ffi::Module dso_mod = cache.Open(path); Device dev = UseDefaultDeviceIfNone(device); - ffi::Function vm_load_executable = dso_mod.GetFunction("vm_load_executable"); - if (vm_load_executable == nullptr) { + Optional vm_load_executable = dso_mod->GetFunction("vm_load_executable"); + if (!vm_load_executable.has_value()) { // not built by RelaxVM, return the dso_mod directly return dso_mod; } - auto mod = vm_load_executable().cast(); - ffi::Function vm_initialization = mod.GetFunction("vm_initialization"); - CHECK(vm_initialization != nullptr) - << "ValueError: File `" << path - << "` is not built by RelaxVM, because `vm_initialization` does not exist"; - vm_initialization(static_cast(dev.device_type), static_cast(dev.device_id), - static_cast(AllocatorType::kPooled), static_cast(kDLCPU), 0, - static_cast(AllocatorType::kPooled)); + auto mod = (*vm_load_executable)().cast(); + Optional vm_initialization = mod->GetFunction("vm_initialization"); + if (!vm_initialization.has_value()) { + LOG(FATAL) << "ValueError: File `" << path + << "` is not built by RelaxVM, because `vm_initialization` does not exist"; + } + (*vm_initialization)(static_cast(dev.device_type), static_cast(dev.device_id), + static_cast(AllocatorType::kPooled), static_cast(kDLCPU), 0, + static_cast(AllocatorType::kPooled)); return mod; } diff --git a/src/runtime/disco/loader.cc b/src/runtime/disco/loader.cc index ec302661bd0e..97af8bc9d3de 100644 --- a/src/runtime/disco/loader.cc +++ b/src/runtime/disco/loader.cc @@ -117,7 +117,7 @@ class ShardLoaderObj : public Object { public: /*! \brief Create a shard loader. */ static ObjectRef Create(const std::string& path_to_metadata, const std::string& metadata, - std::string shard_info, Module mod); + std::string shard_info, Optional mod); /*! \brief Load the i-th parameter */ NDArray Load(int weight_index) const; @@ -175,11 +175,10 @@ class ShardLoaderObj : public Object { }; ObjectRef ShardLoaderObj::Create(const std::string& path_to_metadata, const std::string& metadata, - std::string shard_info, Module mod) { - if (shard_info.empty() && mod.defined()) { - if (ffi::Function get_shard_info = mod->GetFunction("get_shard_info"); - get_shard_info != nullptr) { - shard_info = get_shard_info().cast(); + std::string shard_info, Optional mod) { + if (shard_info.empty() && mod.has_value()) { + if (auto get_shard_info = (*mod)->GetFunction("get_shard_info")) { + shard_info = (*get_shard_info)().cast(); } } ObjectPtr n = make_object(); @@ -195,9 +194,9 @@ ObjectRef ShardLoaderObj::Create(const std::string& path_to_metadata, const std: ShardInfo& shard_info = shards[name]; for (const ShardInfo::ShardFunc& shard_func : shard_info.funcs) { const std::string& name = shard_func.name; - if (ffi::Function f = mod.defined() ? mod->GetFunction(name, true) : nullptr; - f != nullptr) { - n->shard_funcs_[name] = f; + if (Optional f = + mod.has_value() ? (*mod)->GetFunction(name, true) : std::nullopt) { + n->shard_funcs_[name] = *f; } else if (const auto f = tvm::ffi::Function::GetGlobal(name)) { n->shard_funcs_[name] = *f; } else { diff --git a/src/runtime/hexagon/hexagon_common.cc b/src/runtime/hexagon/hexagon_common.cc index c5e62f39ac5e..491ded5730e6 100644 --- a/src/runtime/hexagon/hexagon_common.cc +++ b/src/runtime/hexagon/hexagon_common.cc @@ -32,7 +32,6 @@ #include #include -#include "../library_module.h" #include "HAP_debug.h" #include "HAP_perf.h" #include "hexagon_buffer.h" @@ -93,9 +92,9 @@ void LogMessageImpl(const std::string& file, int lineno, int level, const std::s TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef().def_packed( - "runtime.module.loadfile_hexagon", [](ffi::PackedArgs args, ffi::Any* rv) { - ObjectPtr n = CreateDSOLibraryObject(args[0].cast()); - *rv = CreateModuleFromLibrary(n); + "ffi.Module.load_from_file.hexagon", [](ffi::PackedArgs args, ffi::Any* rv) { + auto floader = tvm::ffi::Function::GetGlobalRequired("ffi.Module.load_from_file.so"); + *rv = floader(args[0].cast(), "so"); }); }); diff --git a/src/runtime/hexagon/hexagon_module.cc b/src/runtime/hexagon/hexagon_module.cc index a5a8de45357a..9db6a6680b06 100644 --- a/src/runtime/hexagon/hexagon_module.cc +++ b/src/runtime/hexagon/hexagon_module.cc @@ -24,8 +24,8 @@ #include "hexagon_module.h" #include +#include #include -#include #include #include @@ -42,12 +42,11 @@ HexagonModuleNode::HexagonModuleNode(std::string data, std::string fmt, std::string bc_str) : data_(data), fmt_(fmt), fmap_(fmap), asm_(asm_str), obj_(obj_str), ir_(ir_str), bc_(bc_str) {} -ffi::Function HexagonModuleNode::GetFunction(const String& name, - const ObjectPtr& sptr_to_self) { +Optional HexagonModuleNode::GetFunction(const String& name) { LOG(FATAL) << "HexagonModuleNode::GetFunction is not implemented."; } -String HexagonModuleNode::GetSource(const String& format) { +String HexagonModuleNode::InspectSource(const String& format) const { if (format == "s" || format == "asm") { return asm_; } @@ -57,7 +56,7 @@ String HexagonModuleNode::GetSource(const String& format) { return ""; } -void HexagonModuleNode::SaveToFile(const String& file_name, const String& format) { +void HexagonModuleNode::WriteToFile(const String& file_name, const String& format) const { std::string fmt = runtime::GetFileFormat(file_name, format); if (fmt == "so" || fmt == "dll" || fmt == "hexagon") { std::string meta_file = GetMetaFilePath(file_name); @@ -80,17 +79,22 @@ void HexagonModuleNode::SaveToFile(const String& file_name, const String& format } } -void HexagonModuleNode::SaveToBinary(dmlc::Stream* stream) { +ffi::Bytes HexagonModuleNode::SaveToBytes() const { + std::string buffer; + dmlc::MemoryStringStream ms(&buffer); + dmlc::Stream* stream = &ms; stream->Write(fmt_); stream->Write(fmap_); stream->Write(data_); + return ffi::Bytes(buffer); } -Module HexagonModuleCreate(std::string data, std::string fmt, - std::unordered_map fmap, std::string asm_str, - std::string obj_str, std::string ir_str, std::string bc_str) { +ffi::Module HexagonModuleCreate(std::string data, std::string fmt, + std::unordered_map fmap, + std::string asm_str, std::string obj_str, std::string ir_str, + std::string bc_str) { auto n = make_object(data, fmt, fmap, asm_str, obj_str, ir_str, bc_str); - return Module(n); + return ffi::Module(n); } } // namespace runtime diff --git a/src/runtime/hexagon/hexagon_module.h b/src/runtime/hexagon/hexagon_module.h index b8a830bc7c29..ae7174236622 100644 --- a/src/runtime/hexagon/hexagon_module.h +++ b/src/runtime/hexagon/hexagon_module.h @@ -20,8 +20,8 @@ #ifndef TVM_RUNTIME_HEXAGON_HEXAGON_MODULE_H_ #define TVM_RUNTIME_HEXAGON_HEXAGON_MODULE_H_ +#include #include -#include #include #include @@ -44,9 +44,10 @@ namespace runtime { * \param ir_str String with the disassembled LLVM IR source. * \param bc_str String with the bitcode LLVM IR. */ -Module HexagonModuleCreate(std::string data, std::string fmt, - std::unordered_map fmap, std::string asm_str, - std::string obj_str, std::string ir_str, std::string bc_str); +ffi::Module HexagonModuleCreate(std::string data, std::string fmt, + std::unordered_map fmap, + std::string asm_str, std::string obj_str, std::string ir_str, + std::string bc_str); /*! \brief Module implementation for compiled Hexagon binaries. It is suitable @@ -54,21 +55,21 @@ Module HexagonModuleCreate(std::string data, std::string fmt, See docstring for HexagonModuleCreate for construction parameter details. */ -class HexagonModuleNode : public runtime::ModuleNode { +class HexagonModuleNode : public ffi::ModuleObj { public: HexagonModuleNode(std::string data, std::string fmt, std::unordered_map fmap, std::string asm_str, std::string obj_str, std::string ir_str, std::string bc_str); - ffi::Function GetFunction(const String& name, const ObjectPtr& sptr_to_self) override; - String GetSource(const String& format) override; - const char* type_key() const final { return "hexagon"; } + Optional GetFunction(const String& name) final; + String InspectSource(const String& format) const final; + const char* kind() const final { return "hexagon"; } /*! \brief Get the property of the runtime module .*/ - int GetPropertyMask() const override { - return ModulePropertyMask::kBinarySerializable | ModulePropertyMask::kDSOExportable | - ModulePropertyMask::kRunnable; + int GetPropertyMask() const final { + return ffi::Module::kBinarySerializable | ffi::Module::kCompilationExportable | + ffi::Module::kRunnable; } - void SaveToFile(const String& file_name, const String& format) override; - void SaveToBinary(dmlc::Stream* stream) override; + void WriteToFile(const String& file_name, const String& format) const final; + ffi::Bytes SaveToBytes() const final; protected: std::string data_; diff --git a/src/runtime/hexagon/rpc/hexagon/rpc_server.cc b/src/runtime/hexagon/rpc/hexagon/rpc_server.cc index fd5fc9ee2bc1..96c45bfdf0d1 100644 --- a/src/runtime/hexagon/rpc/hexagon/rpc_server.cc +++ b/src/runtime/hexagon/rpc/hexagon/rpc_server.cc @@ -36,7 +36,6 @@ extern "C" { #include #include -#include "../../../library_module.h" #include "../../../minrpc/minrpc_server.h" #include "../../hexagon/hexagon_common.h" #include "../../hexagon/hexagon_device_api.h" @@ -335,9 +334,9 @@ TVM_FFI_STATIC_INIT_BLOCK({ .def_packed("tvm.hexagon.load_module", [](tvm::ffi::PackedArgs args, tvm::ffi::Any* rv) { auto soname = args[0].cast(); - tvm::ObjectPtr n = - tvm::runtime::CreateDSOLibraryObject(soname); - *rv = CreateModuleFromLibrary(n); + auto floader = + tvm::ffi::Function::GetGlobalRequired("ffi.Module.load_from_file.so"); + *rv = floader(soname, "so"); }) .def_packed( "tvm.hexagon.get_profile_output", [](tvm::ffi::PackedArgs args, tvm::ffi::Any* rv) { diff --git a/src/runtime/hexagon/rpc/simulator/rpc_server.cc b/src/runtime/hexagon/rpc/simulator/rpc_server.cc index 448cd0db9442..d511b0038f21 100644 --- a/src/runtime/hexagon/rpc/simulator/rpc_server.cc +++ b/src/runtime/hexagon/rpc/simulator/rpc_server.cc @@ -28,7 +28,6 @@ #include #include -#include "../../../library_module.h" #include "../../../minrpc/minrpc_server.h" #include "../../hexagon_common.h" #include "../../profiler/prof_utils.h" @@ -339,9 +338,9 @@ TVM_FFI_STATIC_INIT_BLOCK({ .def_packed("tvm.hexagon.load_module", [](tvm::ffi::PackedArgs args, tvm::ffi::Any* rv) { auto soname = args[0].cast(); - tvm::ObjectPtr n = - tvm::runtime::CreateDSOLibraryObject(soname); - *rv = CreateModuleFromLibrary(n); + auto floader = + tvm::ffi::Function::GetGlobalRequired("ffi.Module.load_from_file.so"); + *rv = floader(soname, "so"); }) .def_packed( "tvm.hexagon.get_profile_output", [](tvm::ffi::PackedArgs args, tvm::ffi::Any* rv) { diff --git a/src/runtime/library_module.cc b/src/runtime/library_module.cc deleted file mode 100644 index 24fc7518d6ad..000000000000 --- a/src/runtime/library_module.cc +++ /dev/null @@ -1,201 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file module_util.cc - * \brief Utilities for module. - */ -#include "library_module.h" - -#include -#include -#include -#include - -#include -#include -#include - -namespace tvm { -namespace runtime { - -// Library module that exposes symbols from a library. -class LibraryModuleNode final : public ModuleNode { - public: - explicit LibraryModuleNode(ObjectPtr lib, FFIFunctionWrapper wrapper) - : lib_(lib), packed_func_wrapper_(wrapper) {} - - const char* type_key() const final { return "library"; } - - /*! \brief Get the property of the runtime module .*/ - int GetPropertyMask() const final { - return ModulePropertyMask::kBinarySerializable | ModulePropertyMask::kRunnable; - }; - - ffi::Function GetFunction(const String& name, const ObjectPtr& sptr_to_self) final { - TVMFFISafeCallType faddr; - faddr = reinterpret_cast(lib_->GetSymbol(name.c_str())); - if (faddr == nullptr) return ffi::Function(); - return packed_func_wrapper_(faddr, sptr_to_self); - } - - private: - ObjectPtr lib_; - FFIFunctionWrapper packed_func_wrapper_; -}; - -ffi::Function WrapFFIFunction(TVMFFISafeCallType faddr, const ObjectPtr& sptr_to_self) { - return ffi::Function::FromPacked([faddr, sptr_to_self](ffi::PackedArgs args, ffi::Any* rv) { - ICHECK_LT(rv->type_index(), ffi::TypeIndex::kTVMFFIStaticObjectBegin); - TVM_FFI_CHECK_SAFE_CALL((*faddr)(nullptr, reinterpret_cast(args.data()), - args.size(), reinterpret_cast(rv))); - }); -} - -void InitContextFunctions(std::function fgetsymbol) { -#define TVM_INIT_CONTEXT_FUNC(FuncName) \ - if (auto* fp = reinterpret_cast(fgetsymbol("__" #FuncName))) { \ - *fp = FuncName; \ - } - // Initialize the functions - TVM_INIT_CONTEXT_FUNC(TVMFFIFunctionCall); - TVM_INIT_CONTEXT_FUNC(TVMFFIErrorSetRaisedFromCStr); - TVM_INIT_CONTEXT_FUNC(TVMBackendGetFuncFromEnv); - TVM_INIT_CONTEXT_FUNC(TVMBackendAllocWorkspace); - TVM_INIT_CONTEXT_FUNC(TVMBackendFreeWorkspace); - TVM_INIT_CONTEXT_FUNC(TVMBackendParallelLaunch); - TVM_INIT_CONTEXT_FUNC(TVMBackendParallelBarrier); - -#undef TVM_INIT_CONTEXT_FUNC -} - -Module LoadModuleFromBinary(const std::string& type_key, dmlc::Stream* stream) { - std::string loadkey = "runtime.module.loadbinary_"; - std::string fkey = loadkey + type_key; - const auto f = tvm::ffi::Function::GetGlobal(fkey); - if (!f.has_value()) { - LOG(FATAL) << "Binary was created using {" << type_key - << "} but a loader of that name is not registered." - << "Perhaps you need to recompile with this runtime enabled."; - } - - return (*f)(static_cast(stream)).cast(); -} - -/*! - * \brief Load and append module blob to module list - * \param mblob The module blob. - * \param lib The library. - * \param root_module the output root module - * \param dso_ctx_addr the output dso module - */ -void ProcessLibraryBin(const char* mblob, ObjectPtr lib, - FFIFunctionWrapper packed_func_wrapper, runtime::Module* root_module, - runtime::ModuleNode** dso_ctx_addr = nullptr) { - ICHECK(mblob != nullptr); - uint64_t nbytes = 0; - for (size_t i = 0; i < sizeof(nbytes); ++i) { - uint64_t c = mblob[i]; - nbytes |= (c & 0xffUL) << (i * 8); - } - dmlc::MemoryFixedSizeStream fs(const_cast(mblob + sizeof(nbytes)), - static_cast(nbytes)); - dmlc::Stream* stream = &fs; - uint64_t size; - ICHECK(stream->Read(&size)); - std::vector modules; - std::vector import_tree_row_ptr; - std::vector import_tree_child_indices; - int num_dso_module = 0; - - for (uint64_t i = 0; i < size; ++i) { - std::string tkey; - ICHECK(stream->Read(&tkey)); - // "_lib" serves as a placeholder in the module import tree to indicate where - // to place the DSOModule - if (tkey == "_lib") { - auto dso_module = Module(make_object(lib, packed_func_wrapper)); - *dso_ctx_addr = dso_module.operator->(); - ++num_dso_module; - modules.emplace_back(dso_module); - ICHECK_EQ(num_dso_module, 1U) << "Multiple dso module detected, please upgrade tvm " - << " to the latest before exporting the module"; - } else if (tkey == "_import_tree") { - ICHECK(stream->Read(&import_tree_row_ptr)); - ICHECK(stream->Read(&import_tree_child_indices)); - } else { - auto m = LoadModuleFromBinary(tkey, stream); - modules.emplace_back(m); - } - } - - // if we are using old dll, we don't have import tree - // so that we can't reconstruct module relationship using import tree - if (import_tree_row_ptr.empty()) { - auto n = make_object(lib, packed_func_wrapper); - auto module_import_addr = ModuleInternal::GetImportsAddr(n.operator->()); - for (const auto& m : modules) { - module_import_addr->emplace_back(m); - } - *dso_ctx_addr = n.get(); - *root_module = Module(n); - } else { - for (size_t i = 0; i < modules.size(); ++i) { - for (size_t j = import_tree_row_ptr[i]; j < import_tree_row_ptr[i + 1]; ++j) { - auto module_import_addr = ModuleInternal::GetImportsAddr(modules[i].operator->()); - auto child_index = import_tree_child_indices[j]; - ICHECK(child_index < modules.size()); - module_import_addr->emplace_back(modules[child_index]); - } - } - - ICHECK(!modules.empty()) << "modules cannot be empty when import tree is present"; - // invariance: root module is always at location 0. - // The module order is collected via DFS - *root_module = modules[0]; - } -} - -Module CreateModuleFromLibrary(ObjectPtr lib, FFIFunctionWrapper packed_func_wrapper) { - InitContextFunctions([lib](const char* fname) { return lib->GetSymbol(fname); }); - auto n = make_object(lib, packed_func_wrapper); - // Load the imported modules - const char* library_bin = - reinterpret_cast(lib->GetSymbol(runtime::symbol::tvm_ffi_library_bin)); - - Module root_mod; - runtime::ModuleNode* dso_ctx_addr = nullptr; - if (library_bin != nullptr) { - ProcessLibraryBin(library_bin, lib, packed_func_wrapper, &root_mod, &dso_ctx_addr); - } else { - // Only have one single DSO Module - root_mod = Module(n); - dso_ctx_addr = root_mod.operator->(); - } - - // allow lookup of symbol from root (so all symbols are visible). - if (auto* ctx_addr = - reinterpret_cast(lib->GetSymbol(runtime::symbol::tvm_ffi_library_ctx))) { - *ctx_addr = dso_ctx_addr; - } - - return root_mod; -} -} // namespace runtime -} // namespace tvm diff --git a/src/runtime/library_module.h b/src/runtime/library_module.h deleted file mode 100644 index 60ce95e2369b..000000000000 --- a/src/runtime/library_module.h +++ /dev/null @@ -1,125 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file library_module.h - * \brief Module that builds from a libary of symbols. - */ -#ifndef TVM_RUNTIME_LIBRARY_MODULE_H_ -#define TVM_RUNTIME_LIBRARY_MODULE_H_ - -#include -#include -#include - -#include -#include -#include - -namespace tvm { -namespace runtime { - -/*! \brief Load a module with the given type key directly from the stream. - * This function wraps the registry mechanism used to store type based deserializers - * for each runtime::Module sub-class. - * - * \param type_key The type key of the serialized module. - * \param stream A pointer to the stream containing the serialized module. - * \return module The deserialized module. - */ -Module LoadModuleFromBinary(const std::string& type_key, dmlc::Stream* stream); - -/*! - * \brief Library is the common interface - * for storing data in the form of shared libaries. - * - * \sa dso_library.cc - * \sa system_library.cc - */ -class Library : public Object { - public: - // destructor. - virtual ~Library() {} - /*! - * \brief Get the symbol address for a given name. - * \param name The name of the symbol. - * \return The symbol. - */ - virtual void* GetSymbol(const char* name) = 0; - // NOTE: we do not explicitly create an type index and type_key here for libary. - // This is because we do not need dynamic type downcasting. -}; - -/*! - * \brief Wrap a TVMFFISafeCallType to packed function. - * \param faddr The function address - * \param mptr The module pointer node. - */ -ffi::Function WrapFFIFunction(TVMFFISafeCallType faddr, const ObjectPtr& mptr); - -/*! - * \brief Utility to initialize conext function symbols during startup - * \param fgetsymbol A symbol lookup function. - */ -void InitContextFunctions(std::function fgetsymbol); - -/*! - * \brief Helper classes to get into internal of a module. - */ -class ModuleInternal { - public: - // Get mutable reference of imports. - static std::vector* GetImportsAddr(ModuleNode* node) { return &(node->imports_); } -}; - -/*! - * \brief Type alias for function to wrap a TVMFFISafeCallType. - * \param The function address imported from a module. - * \param mptr The module pointer node. - * \return Packed function that wraps the invocation of the function at faddr. - */ -using FFIFunctionWrapper = - std::function& mptr)>; - -/*! \brief Return a library object interface over dynamic shared - * libraries in Windows and Linux providing support for - * loading/unloading and symbol lookup. - * \param Full path to shared library. - * \return Returns pointer to the Library providing symbol lookup. - */ -ObjectPtr CreateDSOLibraryObject(std::string library_path); - -/*! - * \brief Create a module from a library. - * - * \param lib The library. - * \param wrapper Optional function used to wrap a TVMBackendPackedCFunc, - * by default WrapFFIFunction is used. - * \param symbol_prefix Optional symbol prefix that can be used to search alternative symbols. - * - * \return The corresponding loaded module. - * - * \note This function can create multiple linked modules - * by parsing the binary blob section of the library. - */ -Module CreateModuleFromLibrary(ObjectPtr lib, - FFIFunctionWrapper wrapper = WrapFFIFunction); -} // namespace runtime -} // namespace tvm -#endif // TVM_RUNTIME_LIBRARY_MODULE_H_ diff --git a/src/runtime/metal/metal_module.h b/src/runtime/metal/metal_module.h index e2705a7a806b..213b6580b4e4 100644 --- a/src/runtime/metal/metal_module.h +++ b/src/runtime/metal/metal_module.h @@ -24,7 +24,7 @@ #ifndef TVM_RUNTIME_METAL_METAL_MODULE_H_ #define TVM_RUNTIME_METAL_METAL_MODULE_H_ -#include +#include #include #include @@ -46,9 +46,9 @@ static constexpr const int kMetalMaxNumDevice = 32; * \param fmt The format of the source, can be "metal" or "metallib" * \param source Optional, source file, concatenaed for debug dump */ -Module MetalModuleCreate(std::unordered_map smap, - std::unordered_map fmap, std::string fmt, - std::string source); +ffi::Module MetalModuleCreate(std::unordered_map smap, + std::unordered_map fmap, std::string fmt, + std::string source); } // namespace runtime } // namespace tvm #endif // TVM_RUNTIME_METAL_METAL_MODULE_H_ diff --git a/src/runtime/metal/metal_module.mm b/src/runtime/metal/metal_module.mm index 33bb1705c8e4..71c46504c4d4 100644 --- a/src/runtime/metal/metal_module.mm +++ b/src/runtime/metal/metal_module.mm @@ -22,9 +22,9 @@ */ #include "metal_module.h" #include +#include #include #include -#include #include #include #include @@ -45,33 +45,37 @@ // Module to support thread-safe multi-GPU execution. // The runtime will contain a per-device module table // The modules will be lazily loaded -class MetalModuleNode final : public runtime::ModuleNode { +class MetalModuleNode final : public ffi::ModuleObj { public: explicit MetalModuleNode(std::unordered_map smap, std::unordered_map fmap, std::string fmt, std::string source) : smap_(smap), fmap_(fmap), fmt_(fmt), source_(source) {} - const char* type_key() const final { return "metal"; } + const char* kind() const final { return "metal"; } /*! \brief Get the property of the runtime module. */ int GetPropertyMask() const final { - return ModulePropertyMask::kBinarySerializable | ModulePropertyMask::kRunnable; + return ffi::Module::kBinarySerializable | ffi::Module::kRunnable; } - ffi::Function GetFunction(const String& name, const ObjectPtr& sptr_to_self) final; + Optional GetFunction(const String& name) final; - void SaveToFile(const String& file_name, const String& format) final { + void WriteToFile(const String& file_name, const String& format) const final { LOG(FATAL) << "Do not support save to file, use save to binary and export instead"; } - void SaveToBinary(dmlc::Stream* stream) final { + ffi::Bytes SaveToBytes() const final { + std::string buffer; + dmlc::MemoryStringStream ms(&buffer); + dmlc::Stream* stream = &ms; std::string version = kMetalModuleVersion; stream->Write(version); stream->Write(smap_); stream->Write(fmap_); stream->Write(fmt_); + return ffi::Bytes(buffer); } - String GetSource(const String& format) final { + String InspectSource(const String& format) const final { // return text source if available. return source_; } @@ -259,15 +263,14 @@ void operator()(ffi::PackedArgs args, ffi::Any* rv, const ArgUnion64* pack_args) LaunchParamConfig launch_param_config_; }; -ffi::Function MetalModuleNode::GetFunction(const String& name, - const ObjectPtr& sptr_to_self) { +Optional MetalModuleNode::GetFunction(const String& name) { ffi::Function ret; AUTORELEASEPOOL { + ObjectPtr sptr_to_self = ffi::GetObjectPtr(this); ICHECK_EQ(sptr_to_self.get(), this); auto it = fmap_.find(name); if (it == fmap_.end()) { - ret = ffi::Function(); - return; + return std::nullopt; } const FunctionInfo& info = it->second; MetalWrappedFunc f; @@ -279,12 +282,12 @@ void operator()(ffi::PackedArgs args, ffi::Any* rv, const ArgUnion64* pack_args) return ret; } -Module MetalModuleCreate(std::unordered_map smap, - std::unordered_map fmap, std::string fmt, - std::string source) { +ffi::Module MetalModuleCreate(std::unordered_map smap, + std::unordered_map fmap, std::string fmt, + std::string source) { ObjectPtr n; AUTORELEASEPOOL { n = make_object(smap, fmap, fmt, source); }; - return Module(n); + return ffi::Module(n); } TVM_FFI_STATIC_INIT_BLOCK({ @@ -303,8 +306,9 @@ Module MetalModuleCreate(std::unordered_map smap, }); }); -Module MetalModuleLoadBinary(void* strm) { - dmlc::Stream* stream = static_cast(strm); +ffi::Module MetalModuleLoadFromBytes(const ffi::Bytes& bytes) { + dmlc::MemoryFixedSizeStream ms(const_cast(bytes.data()), bytes.size()); + dmlc::Stream* stream = &ms; // version is reserved for future changes and // is discarded for now std::string ver; @@ -322,7 +326,7 @@ Module MetalModuleLoadBinary(void* strm) { TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("runtime.module.loadbinary_metal", MetalModuleLoadBinary); + refl::GlobalDef().def("ffi.Module.load_from_bytes.metal", MetalModuleLoadFromBytes); }); } // namespace runtime } // namespace tvm diff --git a/src/runtime/module.cc b/src/runtime/module.cc index cf19ff147f0c..12b58da2df2a 100644 --- a/src/runtime/module.cc +++ b/src/runtime/module.cc @@ -21,8 +21,10 @@ * \file module.cc * \brief TVM module system */ +#include #include #include +#include #include #include @@ -33,106 +35,6 @@ namespace tvm { namespace runtime { -void ModuleNode::Import(Module other) { - // specially handle rpc - if (!std::strcmp(this->type_key(), "rpc")) { - static auto fimport_ = tvm::ffi::Function::GetGlobalRequired("rpc.ImportRemoteModule"); - fimport_(GetRef(this), other); - return; - } - // cyclic detection. - std::unordered_set visited{other.operator->()}; - std::vector stack{other.operator->()}; - while (!stack.empty()) { - const ModuleNode* n = stack.back(); - stack.pop_back(); - for (const Module& m : n->imports_) { - const ModuleNode* next = m.operator->(); - if (visited.count(next)) continue; - visited.insert(next); - stack.push_back(next); - } - } - ICHECK(!visited.count(this)) << "Cyclic dependency detected during import"; - this->imports_.emplace_back(std::move(other)); -} - -ffi::Function ModuleNode::GetFunction(const String& name, bool query_imports) { - ModuleNode* self = this; - ffi::Function pf = self->GetFunction(name, GetObjectPtr(this)); - if (pf != nullptr) return pf; - if (query_imports) { - for (Module& m : self->imports_) { - pf = m.operator->()->GetFunction(name, query_imports); - if (pf != nullptr) { - return pf; - } - } - } - return pf; -} - -Module Module::LoadFromFile(const String& file_name, const String& format) { - std::string fmt = GetFileFormat(file_name, format); - ICHECK(fmt.length() != 0) << "Cannot deduce format of file " << file_name; - if (fmt == "dll" || fmt == "dylib" || fmt == "dso") { - fmt = "so"; - } - std::string load_f_name = "runtime.module.loadfile_" + fmt; - VLOG(1) << "Loading module from '" << file_name << "' of format '" << fmt << "'"; - const auto f = tvm::ffi::Function::GetGlobal(load_f_name); - ICHECK(f.has_value()) << "Loader for `." << format << "` files is not registered," - << " resolved to (" << load_f_name << ") in the global registry." - << "Ensure that you have loaded the correct runtime code, and" - << "that you are on the correct hardware architecture."; - Module m = (*f)(file_name, format).cast(); - return m; -} - -void ModuleNode::SaveToFile(const String& file_name, const String& format) { - LOG(FATAL) << "Module[" << type_key() << "] does not support SaveToFile"; -} - -void ModuleNode::SaveToBinary(dmlc::Stream* stream) { - LOG(FATAL) << "Module[" << type_key() << "] does not support SaveToBinary"; -} - -String ModuleNode::GetSource(const String& format) { - LOG(FATAL) << "Module[" << type_key() << "] does not support GetSource"; -} - -const ffi::Function* ModuleNode::GetFuncFromEnv(const String& name) { - std::lock_guard lock(mutex_); - auto it = import_cache_.find(name); - if (it != import_cache_.end()) return it->second.get(); - ffi::Function pf; - for (Module& m : this->imports_) { - pf = m.GetFunction(name, true); - if (pf != nullptr) break; - } - if (pf == nullptr) { - const auto f = tvm::ffi::Function::GetGlobal(name); - ICHECK(f.has_value()) << "Cannot find function " << name - << " in the imported modules or global registry." - << " If this involves ops from a contrib library like" - << " cuDNN, ensure TVM was built with the relevant" - << " library."; - import_cache_.insert(std::make_pair(name, std::make_shared(*f))); - return import_cache_.at(name).get(); - } else { - import_cache_.insert(std::make_pair(name, std::make_shared(pf))); - return import_cache_.at(name).get(); - } -} - -String ModuleNode::GetFormat() { - LOG(FATAL) << "Module[" << type_key() << "] does not support GetFormat"; -} - -bool ModuleNode::ImplementsFunction(const String& name, bool query_imports) { - return GetFunction(name, query_imports) != nullptr; -} - bool RuntimeEnabled(const String& target_str) { std::string target = target_str; std::string f_name; @@ -166,33 +68,26 @@ bool RuntimeEnabled(const String& target_str) { return tvm::ffi::Function::GetGlobal(f_name).has_value(); } +#define TVM_INIT_CONTEXT_FUNC(FuncName) \ + TVM_FFI_CHECK_SAFE_CALL( \ + TVMFFIEnvRegisterContextSymbol("__" #FuncName, reinterpret_cast(FuncName))) + TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; - refl::GlobalDef() - .def("runtime.RuntimeEnabled", RuntimeEnabled) - .def("runtime.ModuleGetSource", - [](Module mod, std::string fmt) { return mod->GetSource(fmt); }) - .def("runtime.ModuleImportsSize", - [](Module mod) { return static_cast(mod->imports().size()); }) - .def("runtime.ModuleGetImport", - [](Module mod, int index) { return mod->imports().at(index); }) - .def("runtime.ModuleClearImports", [](Module mod) { mod->ClearImports(); }) - .def("runtime.ModuleGetTypeKey", [](Module mod) { return std::string(mod->type_key()); }) - .def("runtime.ModuleGetFormat", [](Module mod) { return mod->GetFormat(); }) - .def("runtime.ModuleLoadFromFile", Module::LoadFromFile) - .def("runtime.ModuleSaveToFile", - [](Module mod, String name, String fmt) { mod->SaveToFile(name, fmt); }) - .def("runtime.ModuleGetPropertyMask", [](Module mod) { return mod->GetPropertyMask(); }) - .def("runtime.ModuleImplementsFunction", - [](Module mod, String name, bool query_imports) { - return mod->ImplementsFunction(std::move(name), query_imports); - }) - .def("runtime.ModuleGetFunction", - [](Module mod, String name, bool query_imports) { - return mod->GetFunction(name, query_imports); - }) - .def("runtime.ModuleImport", [](Module mod, Module other) { mod->Import(other); }); + + // Initialize the functions + TVM_INIT_CONTEXT_FUNC(TVMFFIFunctionCall); + TVM_INIT_CONTEXT_FUNC(TVMFFIErrorSetRaisedFromCStr); + TVM_INIT_CONTEXT_FUNC(TVMBackendGetFuncFromEnv); + TVM_INIT_CONTEXT_FUNC(TVMBackendAllocWorkspace); + TVM_INIT_CONTEXT_FUNC(TVMBackendFreeWorkspace); + TVM_INIT_CONTEXT_FUNC(TVMBackendParallelLaunch); + TVM_INIT_CONTEXT_FUNC(TVMBackendParallelBarrier); + + refl::GlobalDef().def("runtime.RuntimeEnabled", RuntimeEnabled); }); +#undef TVM_INIT_CONTEXT_FUNC + } // namespace runtime } // namespace tvm diff --git a/src/runtime/opencl/opencl_common.h b/src/runtime/opencl/opencl_common.h index 3fefae597f21..3e0981146afc 100644 --- a/src/runtime/opencl/opencl_common.h +++ b/src/runtime/opencl/opencl_common.h @@ -456,7 +456,7 @@ struct BufferDescriptor { // To make the call thread-safe, we create a thread-local kernel table // and lazily install new kernels into the kernel table when the kernel is called. // The kernels are recycled when the module get destructed. -class OpenCLModuleNodeBase : public ModuleNode { +class OpenCLModuleNodeBase : public ffi::ModuleObj { public: // Kernel table reference entry. struct KTRefEntry { @@ -472,14 +472,14 @@ class OpenCLModuleNodeBase : public ModuleNode { */ virtual cl::OpenCLWorkspace* GetGlobalWorkspace(); - const char* type_key() const final { return workspace_->type_key.c_str(); } + const char* kind() const final { return workspace_->type_key.c_str(); } /*! \brief Get the property of the runtime module .*/ int GetPropertyMask() const final { - return ModulePropertyMask::kBinarySerializable | ModulePropertyMask::kRunnable; + return ffi::Module::kBinarySerializable | ffi::Module::kRunnable; } - ffi::Function GetFunction(const String& name, const ObjectPtr& sptr_to_self) override; + Optional GetFunction(const String& name) override; // Initialize the programs virtual void Init() = 0; @@ -509,14 +509,14 @@ class OpenCLModuleNode : public OpenCLModuleNodeBase { std::unordered_map fmap, std::string source) : OpenCLModuleNodeBase(fmap), data_(data), fmt_(fmt), source_(source) {} - ffi::Function GetFunction(const String& name, const ObjectPtr& sptr_to_self) final; + Optional GetFunction(const String& name) final; // Return true if OpenCL program for the requested function and device was created bool IsProgramCreated(const std::string& func_name, int device_id); - void SaveToFile(const String& file_name, const String& format) final; - void SaveToBinary(dmlc::Stream* stream) final; + void WriteToFile(const String& file_name, const String& format) const final; + ffi::Bytes SaveToBytes() const final; void SetPreCompiledPrograms(const std::string& bytes); std::string GetPreCompiledPrograms(); - String GetSource(const String& format) final; + String InspectSource(const String& format) const final; // Initialize the programs void Init() override; diff --git a/src/runtime/opencl/opencl_module.cc b/src/runtime/opencl/opencl_module.cc index 1c61eeb59635..a8e3b6fc20b6 100644 --- a/src/runtime/opencl/opencl_module.cc +++ b/src/runtime/opencl/opencl_module.cc @@ -135,11 +135,11 @@ cl::OpenCLWorkspace* OpenCLModuleNodeBase::GetGlobalWorkspace() { return cl::OpenCLWorkspace::Global(); } -ffi::Function OpenCLModuleNodeBase::GetFunction(const String& name, - const ObjectPtr& sptr_to_self) { +Optional OpenCLModuleNodeBase::GetFunction(const String& name) { + ObjectPtr sptr_to_self = ffi::GetObjectPtr(this); ICHECK_EQ(sptr_to_self.get(), this); auto it = fmap_.find(name); - if (it == fmap_.end()) return ffi::Function(); + if (it == fmap_.end()) return std::nullopt; const FunctionInfo& info = it->second; OpenCLWrappedFunc f; std::vector arg_size(info.arg_types.size()); @@ -160,7 +160,7 @@ ffi::Function OpenCLModuleNodeBase::GetFunction(const String& name, return PackFuncVoidAddr(f, info.arg_types); } -void OpenCLModuleNode::SaveToFile(const String& file_name, const String& format) { +void OpenCLModuleNode::WriteToFile(const String& file_name, const String& format) const { std::string fmt = GetFileFormat(file_name, format); ICHECK_EQ(fmt, fmt_) << "Can only save to format=" << fmt_; std::string meta_file = GetMetaFilePath(file_name); @@ -168,13 +168,17 @@ void OpenCLModuleNode::SaveToFile(const String& file_name, const String& format) SaveBinaryToFile(file_name, data_); } -void OpenCLModuleNode::SaveToBinary(dmlc::Stream* stream) { +ffi::Bytes OpenCLModuleNode::SaveToBytes() const { + std::string buffer; + dmlc::MemoryStringStream ms(&buffer); + dmlc::Stream* stream = &ms; stream->Write(fmt_); stream->Write(fmap_); stream->Write(data_); + return ffi::Bytes(buffer); } -String OpenCLModuleNode::GetSource(const String& format) { +String OpenCLModuleNode::InspectSource(const String& format) const { if (format == fmt_) return data_; if (fmt_ == "cl") { return data_; @@ -201,7 +205,7 @@ void OpenCLModuleNode::Init() { } // split into source artifacts for each kernel - parsed_kernels_ = SplitKernels(GetSource("cl")); + parsed_kernels_ = SplitKernels(InspectSource("cl")); ICHECK(!parsed_kernels_.empty()) << "The OpenCL module expects a kernel delimited " << "source from code generation, but no kernel " << "delimiter was found."; @@ -345,8 +349,8 @@ std::string OpenCLModuleNode::GetPreCompiledPrograms() { return data; } -ffi::Function OpenCLModuleNode::GetFunction(const String& name, - const ObjectPtr& sptr_to_self) { +Optional OpenCLModuleNode::GetFunction(const String& name) { + ObjectPtr sptr_to_self = ffi::GetObjectPtr(this); ICHECK_EQ(sptr_to_self.get(), this); if (name == "opencl.GetPreCompiledPrograms") { return ffi::Function([sptr_to_self, this](ffi::PackedArgs args, ffi::Any* rv) { @@ -357,18 +361,19 @@ ffi::Function OpenCLModuleNode::GetFunction(const String& name, this->SetPreCompiledPrograms(args[0].cast()); }); } - return OpenCLModuleNodeBase::GetFunction(name, sptr_to_self); + return OpenCLModuleNodeBase::GetFunction(name); } -Module OpenCLModuleCreate(std::string data, std::string fmt, - std::unordered_map fmap, std::string source) { +ffi::Module OpenCLModuleCreate(std::string data, std::string fmt, + std::unordered_map fmap, + std::string source) { auto n = make_object(data, fmt, fmap, source); n->Init(); - return Module(n); + return ffi::Module(n); } // Load module from module. -Module OpenCLModuleLoadFile(const std::string& file_name, const String& format) { +ffi::Module OpenCLModuleLoadFile(const std::string& file_name, const String& format) { std::string data; std::unordered_map fmap; std::string fmt = GetFileFormat(file_name, format); @@ -378,8 +383,9 @@ Module OpenCLModuleLoadFile(const std::string& file_name, const String& format) return OpenCLModuleCreate(data, fmt, fmap, std::string()); } -Module OpenCLModuleLoadBinary(void* strm) { - dmlc::Stream* stream = static_cast(strm); +ffi::Module OpenCLModuleLoadFromBytes(const ffi::Bytes& bytes) { + dmlc::MemoryFixedSizeStream ms(const_cast(bytes.data()), bytes.size()); + dmlc::Stream* stream = &ms; std::string data; std::unordered_map fmap; std::string fmt; @@ -392,9 +398,9 @@ Module OpenCLModuleLoadBinary(void* strm) { TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef() - .def("runtime.module.loadfile_cl", OpenCLModuleLoadFile) - .def("runtime.module.loadfile_clbin", OpenCLModuleLoadFile) - .def("runtime.module.loadbinary_opencl", OpenCLModuleLoadBinary); + .def("ffi.Module.load_from_file.cl", OpenCLModuleLoadFile) + .def("ffi.Module.load_from_file.clbin", OpenCLModuleLoadFile) + .def("ffi.Module.load_from_bytes.opencl", OpenCLModuleLoadFromBytes); }); } // namespace runtime } // namespace tvm diff --git a/src/runtime/opencl/opencl_module.h b/src/runtime/opencl/opencl_module.h index 198adc6cb216..18afad56a0c8 100644 --- a/src/runtime/opencl/opencl_module.h +++ b/src/runtime/opencl/opencl_module.h @@ -44,8 +44,9 @@ namespace runtime { * \param fmap The map function information map of each function. * \param source Generated OpenCL kernels. */ -Module OpenCLModuleCreate(std::string data, std::string fmt, - std::unordered_map fmap, std::string source); +ffi::Module OpenCLModuleCreate(std::string data, std::string fmt, + std::unordered_map fmap, + std::string source); /*! * \brief Create a opencl module from SPIRV. @@ -54,9 +55,9 @@ Module OpenCLModuleCreate(std::string data, std::string fmt, * \param spirv_text The concatenated text representation of SPIRV modules. * \param fmap The map function information map of each function. */ -Module OpenCLModuleCreate(const std::unordered_map& shaders, - const std::string& spirv_text, - std::unordered_map fmap); +ffi::Module OpenCLModuleCreate(const std::unordered_map& shaders, + const std::string& spirv_text, + std::unordered_map fmap); } // namespace runtime } // namespace tvm #endif // TVM_RUNTIME_OPENCL_OPENCL_MODULE_H_ diff --git a/src/runtime/opencl/opencl_module_spirv.cc b/src/runtime/opencl/opencl_module_spirv.cc index 7d281694decb..5b90e0b566c7 100644 --- a/src/runtime/opencl/opencl_module_spirv.cc +++ b/src/runtime/opencl/opencl_module_spirv.cc @@ -39,9 +39,9 @@ class OpenCLSPIRVModuleNode : public OpenCLModuleNodeBase { std::unordered_map fmap) : OpenCLModuleNodeBase(fmap), shaders_(shaders), spirv_text_(spirv_text) {} - void SaveToFile(const String& file_name, const String& format) final; - void SaveToBinary(dmlc::Stream* stream) final; - String GetSource(const String&) final { return spirv_text_; } + void WriteToFile(const String& file_name, const String& format) const final; + ffi::Bytes SaveToBytes() const final; + String InspectSource(const String& format) const final { return spirv_text_; } void Init() override; cl_kernel InstallKernel(cl::OpenCLWorkspace* w, cl::OpenCLThreadEntry* t, @@ -52,14 +52,18 @@ class OpenCLSPIRVModuleNode : public OpenCLModuleNodeBase { std::string spirv_text_; }; -void OpenCLSPIRVModuleNode::SaveToFile(const String& file_name, const String& format) { +void OpenCLSPIRVModuleNode::WriteToFile(const String& file_name, const String& format) const { // TODO(masahi): How SPIRV binaries should be save to a file? LOG(FATAL) << "Not implemented."; } -void OpenCLSPIRVModuleNode::SaveToBinary(dmlc::Stream* stream) { +ffi::Bytes OpenCLSPIRVModuleNode::SaveToBytes() const { + std::string buffer; + dmlc::MemoryStringStream ms(&buffer); + dmlc::Stream* stream = &ms; stream->Write(fmap_); stream->Write(shaders_); + return ffi::Bytes(buffer); } void OpenCLSPIRVModuleNode::Init() { @@ -125,12 +129,12 @@ cl_kernel OpenCLSPIRVModuleNode::InstallKernel(cl::OpenCLWorkspace* w, cl::OpenC return kernel; } -Module OpenCLModuleCreate(const std::unordered_map& shaders, - const std::string& spirv_text, - std::unordered_map fmap) { +ffi::Module OpenCLModuleCreate(const std::unordered_map& shaders, + const std::string& spirv_text, + std::unordered_map fmap) { auto n = make_object(shaders, spirv_text, fmap); n->Init(); - return Module(n); + return ffi::Module(n); } } // namespace runtime diff --git a/src/runtime/profiling.cc b/src/runtime/profiling.cc index 4cce0d40d168..9d4c01d62366 100644 --- a/src/runtime/profiling.cc +++ b/src/runtime/profiling.cc @@ -792,59 +792,60 @@ TVM_FFI_STATIC_INIT_BLOCK({ .def("runtime.profiling.DeviceWrapper", [](Device dev) { return DeviceWrapper(dev); }); }); -ffi::Function ProfileFunction(Module mod, std::string func_name, int device_type, int device_id, - int warmup_iters, Array collectors) { +ffi::Function ProfileFunction(ffi::Module mod, std::string func_name, int device_type, + int device_id, int warmup_iters, Array collectors) { // Module::GetFunction is not const, so this lambda has to be mutable - return ffi::Function::FromPacked( - [=](const ffi::AnyView* args, int32_t num_args, ffi::Any* ret) mutable { - ffi::Function f = mod.GetFunction(func_name); - CHECK(f.defined()) << "There is no function called \"" << func_name << "\" in the module"; - Device dev{static_cast(device_type), device_id}; - - // warmup - for (int i = 0; i < warmup_iters; i++) { - f.CallPacked(args, num_args, ret); - } - - for (auto& collector : collectors) { - collector->Init({DeviceWrapper(dev)}); - } - std::vector> results; - results.reserve(collectors.size()); - std::vector> collector_data; - collector_data.reserve(collectors.size()); - for (auto& collector : collectors) { - ObjectRef o = collector->Start(dev); - // If not defined, then the collector cannot time this device. - if (o.defined()) { - collector_data.push_back({collector, o}); - } - } + return ffi::Function::FromPacked([=](const ffi::AnyView* args, int32_t num_args, + ffi::Any* ret) mutable { + auto optf = mod->GetFunction(func_name); + CHECK(optf.has_value()) << "There is no function called \"" << func_name << "\" in the module"; + auto f = *optf; + Device dev{static_cast(device_type), device_id}; + + // warmup + for (int i = 0; i < warmup_iters; i++) { + f.CallPacked(args, num_args, ret); + } + + for (auto& collector : collectors) { + collector->Init({DeviceWrapper(dev)}); + } + std::vector> results; + results.reserve(collectors.size()); + std::vector> collector_data; + collector_data.reserve(collectors.size()); + for (auto& collector : collectors) { + ObjectRef o = collector->Start(dev); + // If not defined, then the collector cannot time this device. + if (o.defined()) { + collector_data.push_back({collector, o}); + } + } - // TODO(tkonolige): repeated calls if the runtime is small? - f.CallPacked(args, num_args, ret); + // TODO(tkonolige): repeated calls if the runtime is small? + f.CallPacked(args, num_args, ret); - for (auto& kv : collector_data) { - results.push_back(kv.first->Stop(kv.second)); - } - Map combined_results; - for (auto m : results) { - for (auto p : m) { - // assume that there is no shared metric name between collectors - combined_results.Set(p.first, p.second); - } - } - *ret = combined_results; - }); + for (auto& kv : collector_data) { + results.push_back(kv.first->Stop(kv.second)); + } + Map combined_results; + for (auto m : results) { + for (auto p : m) { + // assume that there is no shared metric name between collectors + combined_results.Set(p.first, p.second); + } + } + *ret = combined_results; + }); } TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef().def( "runtime.profiling.ProfileFunction", - [](Module mod, String func_name, int device_type, int device_id, int warmup_iters, + [](ffi::Module mod, String func_name, int device_type, int device_id, int warmup_iters, Array collectors) { - if (mod->type_key() == std::string("rpc")) { + if (mod->kind() == std::string("rpc")) { LOG(FATAL) << "Profiling a module over RPC is not yet supported"; // because we can't send // MetricCollectors over rpc. diff --git a/src/runtime/rocm/rocm_module.cc b/src/runtime/rocm/rocm_module.cc index a871a41f0f86..13b14e13e0e7 100644 --- a/src/runtime/rocm/rocm_module.cc +++ b/src/runtime/rocm/rocm_module.cc @@ -22,6 +22,7 @@ */ #include "rocm_module.h" +#include #include #include #include @@ -45,7 +46,7 @@ namespace runtime { // hipModule_t is a per-GPU module // The runtime will contain a per-device module table // The modules will be lazily loaded -class ROCMModuleNode : public runtime::ModuleNode { +class ROCMModuleNode : public ffi::ModuleObj { public: explicit ROCMModuleNode(std::string data, std::string fmt, std::unordered_map fmap, @@ -63,13 +64,13 @@ class ROCMModuleNode : public runtime::ModuleNode { } } - const char* type_key() const final { return "hip"; } + const char* kind() const final { return "hip"; } int GetPropertyMask() const final { - return ModulePropertyMask::kBinarySerializable | ModulePropertyMask::kRunnable; + return ffi::Module::kBinarySerializable | ffi::Module::kRunnable; } - ffi::Function GetFunction(const String& name, const ObjectPtr& sptr_to_self) final; + Optional GetFunction(const String& name) final; - void SaveToFile(const String& file_name, const String& format) final { + void WriteToFile(const String& file_name, const String& format) const final { std::string fmt = GetFileFormat(file_name, format); std::string meta_file = GetMetaFilePath(file_name); // note: llvm and asm formats are not laodable, so we don't save them @@ -78,13 +79,17 @@ class ROCMModuleNode : public runtime::ModuleNode { SaveBinaryToFile(file_name, data_); } - void SaveToBinary(dmlc::Stream* stream) final { + ffi::Bytes SaveToBytes() const final { + std::string buffer; + dmlc::MemoryStringStream ms(&buffer); + dmlc::Stream* stream = &ms; stream->Write(fmt_); stream->Write(fmap_); stream->Write(data_); + return ffi::Bytes(buffer); } - String GetSource(const String& format) final { + String InspectSource(const String& format) const final { if (format == fmt_) { return data_; } @@ -192,25 +197,25 @@ class ROCMWrappedFunc { LaunchParamConfig launch_param_config_; }; -ffi::Function ROCMModuleNode::GetFunction(const String& name, - const ObjectPtr& sptr_to_self) { +Optional ROCMModuleNode::GetFunction(const String& name) { + ObjectPtr sptr_to_self = ffi::GetObjectPtr(this); ICHECK_EQ(sptr_to_self.get(), this); auto it = fmap_.find(name); - if (it == fmap_.end()) return ffi::Function(); + if (it == fmap_.end()) return std::nullopt; const FunctionInfo& info = it->second; ROCMWrappedFunc f; f.Init(this, sptr_to_self, name, info.arg_types.size(), info.launch_param_tags); return PackFuncPackedArgAligned(f, info.arg_types); } -Module ROCMModuleCreate(std::string data, std::string fmt, - std::unordered_map fmap, std::string hip_source, - std::string assembly) { +ffi::Module ROCMModuleCreate(std::string data, std::string fmt, + std::unordered_map fmap, + std::string hip_source, std::string assembly) { auto n = make_object(data, fmt, fmap, hip_source, assembly); - return Module(n); + return ffi::Module(n); } -Module ROCMModuleLoadFile(const std::string& file_name, const std::string& format) { +ffi::Module ROCMModuleLoadFile(const std::string& file_name, const std::string& format) { std::string data; std::unordered_map fmap; std::string fmt = GetFileFormat(file_name, format); @@ -220,8 +225,9 @@ Module ROCMModuleLoadFile(const std::string& file_name, const std::string& forma return ROCMModuleCreate(data, fmt, fmap, std::string(), std::string()); } -Module ROCMModuleLoadBinary(void* strm) { - dmlc::Stream* stream = static_cast(strm); +ffi::Module ROCMModuleLoadFromBytes(const ffi::Bytes& bytes) { + dmlc::MemoryFixedSizeStream ms(const_cast(bytes.data()), bytes.size()); + dmlc::Stream* stream = &ms; std::string data; std::unordered_map fmap; std::string fmt; @@ -234,10 +240,10 @@ Module ROCMModuleLoadBinary(void* strm) { TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef() - .def("runtime.module.loadbinary_hsaco", ROCMModuleLoadBinary) - .def("runtime.module.loadbinary_hip", ROCMModuleLoadBinary) - .def("runtime.module.loadfile_hsaco", ROCMModuleLoadFile) - .def("runtime.module.loadfile_hip", ROCMModuleLoadFile); + .def("ffi.Module.load_from_bytes.hsaco", ROCMModuleLoadFromBytes) + .def("ffi.Module.load_from_bytes.hip", ROCMModuleLoadFromBytes) + .def("ffi.Module.load_from_file.hsaco", ROCMModuleLoadFile) + .def("ffi.Module.load_from_file.hip", ROCMModuleLoadFile); }); } // namespace runtime } // namespace tvm diff --git a/src/runtime/rocm/rocm_module.h b/src/runtime/rocm/rocm_module.h index c17e123c1a12..ee6f29f43edb 100644 --- a/src/runtime/rocm/rocm_module.h +++ b/src/runtime/rocm/rocm_module.h @@ -24,7 +24,7 @@ #ifndef TVM_RUNTIME_ROCM_ROCM_MODULE_H_ #define TVM_RUNTIME_ROCM_ROCM_MODULE_H_ -#include +#include #include #include @@ -47,9 +47,9 @@ static constexpr const int kMaxNumGPUs = 32; * \param fmap The map function information map of each function. * \param rocm_source Optional, rocm source file */ -Module ROCMModuleCreate(std::string data, std::string fmt, - std::unordered_map fmap, std::string rocm_source, - std::string assembly); +ffi::Module ROCMModuleCreate(std::string data, std::string fmt, + std::unordered_map fmap, + std::string rocm_source, std::string assembly); } // namespace runtime } // namespace tvm #endif // TVM_RUNTIME_ROCM_ROCM_MODULE_H_ diff --git a/src/runtime/rpc/rpc_endpoint.cc b/src/runtime/rpc/rpc_endpoint.cc index 3dea9dc82239..e1282c17878a 100644 --- a/src/runtime/rpc/rpc_endpoint.cc +++ b/src/runtime/rpc/rpc_endpoint.cc @@ -593,13 +593,13 @@ class RPCEndpoint::EventHandler : public dmlc::Stream { << " Error caught from session constructor " << constructor_name << ":\n" << e.what(); } - auto opt_con_ret = con_ret.as(); + auto opt_con_ret = con_ret.as(); // Legacy ABI translation ICHECK(opt_con_ret.has_value()) << "Server[" << name_ << "]:" << " Constructor " << constructor_name << " need to return an RPCModule"; - Module mod = opt_con_ret.value(); - std::string tkey = mod->type_key(); + ffi::Module mod = opt_con_ret.value(); + std::string tkey = mod->kind(); ICHECK_EQ(tkey, "rpc") << "Constructor " << constructor_name << " to return an RPCModule"; serving_session_ = RPCModuleGetSession(mod); this->ReturnVoid(); diff --git a/src/runtime/rpc/rpc_module.cc b/src/runtime/rpc/rpc_module.cc index 3094c9ca13a2..bcf661960f06 100644 --- a/src/runtime/rpc/rpc_module.cc +++ b/src/runtime/rpc/rpc_module.cc @@ -170,7 +170,7 @@ class RPCWrappedFunc : public Object { }; // RPC that represents a remote module session. -class RPCModuleNode final : public ModuleNode { +class RPCModuleNode final : public ffi::ModuleObj { public: RPCModuleNode(void* module_handle, std::shared_ptr sess) : module_handle_(module_handle), sess_(sess) {} @@ -186,11 +186,11 @@ class RPCModuleNode final : public ModuleNode { } } - const char* type_key() const final { return "rpc"; } + const char* kind() const final { return "rpc"; } /*! \brief Get the property of the runtime module .*/ - int GetPropertyMask() const final { return ModulePropertyMask::kRunnable; } + int GetPropertyMask() const final { return ffi::Module::ModulePropertyMask::kRunnable; } - ffi::Function GetFunction(const String& name, const ObjectPtr& sptr_to_self) final { + Optional GetFunction(const String& name) final { if (name == "CloseRPCConnection") { return ffi::Function([this](ffi::PackedArgs, ffi::Any*) { sess_->Shutdown(); }); } @@ -199,15 +199,10 @@ class RPCModuleNode final : public ModuleNode { return WrapRemoteFunc(sess_->GetFunction(name)); } else { InitRemoteFunc(&remote_mod_get_function_, "tvm.rpc.server.ModuleGetFunction"); - return remote_mod_get_function_(GetRef(this), name, true); + return remote_mod_get_function_(GetRef(this), name, true); } } - String GetSource(const String& format) final { - LOG(FATAL) << "GetSource for rpc Module is not supported"; - throw; - } - ffi::Function GetTimeEvaluator(const std::string& name, Device dev, int number, int repeat, int min_repeat_ms, int limit_zero_time_iterations, int cooldown_interval_ms, int repeats_to_cooldown, @@ -220,25 +215,25 @@ class RPCModuleNode final : public ModuleNode { if (module_handle_ != nullptr) { return remote_get_time_evaluator_( - GetRef(this), name, static_cast(dev.device_type), dev.device_id, number, + GetRef(this), name, static_cast(dev.device_type), dev.device_id, number, repeat, min_repeat_ms, limit_zero_time_iterations, cooldown_interval_ms, repeats_to_cooldown, cache_flush_bytes, f_preproc_name); } else { return remote_get_time_evaluator_( - Optional(std::nullopt), name, static_cast(dev.device_type), dev.device_id, - number, repeat, min_repeat_ms, limit_zero_time_iterations, cooldown_interval_ms, - repeats_to_cooldown, cache_flush_bytes, f_preproc_name); + Optional(std::nullopt), name, static_cast(dev.device_type), + dev.device_id, number, repeat, min_repeat_ms, limit_zero_time_iterations, + cooldown_interval_ms, repeats_to_cooldown, cache_flush_bytes, f_preproc_name); } } - Module LoadModule(std::string name) { + ffi::Module LoadModule(std::string name) { InitRemoteFunc(&remote_load_module_, "tvm.rpc.server.load_module"); return remote_load_module_(name); } - void ImportModule(Module other) { + void ImportModule(ffi::Module other) { InitRemoteFunc(&remote_import_module_, "tvm.rpc.server.ImportModule"); - remote_import_module_(GetRef(this), other); + remote_import_module_(GetRef(this), other); } const std::shared_ptr& sess() { return sess_; } @@ -266,22 +261,22 @@ class RPCModuleNode final : public ModuleNode { // The local channel std::shared_ptr sess_; // remote function to get time evaluator - ffi::TypedFunction, std::string, int, int, int, int, int, int, int, - int, int, std::string)> + ffi::TypedFunction, std::string, int, int, int, int, int, int, + int, int, int, std::string)> remote_get_time_evaluator_; // remote function getter for modules. - ffi::TypedFunction remote_mod_get_function_; + ffi::TypedFunction remote_mod_get_function_; // remote function getter for load module - ffi::TypedFunction remote_load_module_; + ffi::TypedFunction remote_load_module_; // remote function getter for load module - ffi::TypedFunction remote_import_module_; + ffi::TypedFunction remote_import_module_; }; void* RPCWrappedFunc::UnwrapRemoteValueToHandle(const AnyView& arg) const { // TODO(tqchen): only support Module unwrapping for now. if (arg.type_index() == ffi::TypeIndex::kTVMFFIModule) { - Module mod = arg.cast(); - std::string tkey = mod->type_key(); + ffi::Module mod = arg.cast(); + std::string tkey = mod->kind(); ICHECK_EQ(tkey, "rpc") << "ValueError: Cannot pass a non-RPC module to remote"; auto* rmod = static_cast(mod.operator->()); ICHECK(rmod->sess() == sess_) @@ -309,7 +304,7 @@ void RPCWrappedFunc::WrapRemoteReturnToValue(ffi::PackedArgs args, ffi::Any* rv) ICHECK_EQ(args.size(), 2); void* handle = args[1].cast(); auto n = make_object(handle, sess_); - *rv = Module(n); + *rv = ffi::Module(n); } else if (type_index == ffi::TypeIndex::kTVMFFINDArray || type_index == ffi::TypeIndex::kTVMFFIDLTensorPtr) { ICHECK_EQ(args.size(), 3); @@ -335,14 +330,14 @@ void RPCWrappedFunc::WrapRemoteReturnToValue(ffi::PackedArgs args, ffi::Any* rv) } } -Module CreateRPCSessionModule(std::shared_ptr sess) { +ffi::Module CreateRPCSessionModule(std::shared_ptr sess) { auto n = make_object(nullptr, sess); RPCSession::InsertToSessionTable(sess); - return Module(n); + return ffi::Module(n); } -std::shared_ptr RPCModuleGetSession(Module mod) { - std::string tkey = mod->type_key(); +std::shared_ptr RPCModuleGetSession(ffi::Module mod) { + std::string tkey = mod->kind(); ICHECK_EQ(tkey, "rpc") << "ValueError: Cannot pass a non-RPC module to remote"; auto* rmod = static_cast(mod.operator->()); return rmod->sess(); @@ -402,7 +397,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("runtime.RPCTimeEvaluator", - [](Optional opt_mod, std::string name, int device_type, int device_id, + [](Optional opt_mod, std::string name, int device_type, int device_id, int number, int repeat, int min_repeat_ms, int limit_zero_time_iterations, int cooldown_interval_ms, int repeats_to_cooldown, int cache_flush_bytes, std::string f_preproc_name) { @@ -410,8 +405,8 @@ TVM_FFI_STATIC_INIT_BLOCK({ dev.device_type = static_cast(device_type); dev.device_id = device_id; if (opt_mod.defined()) { - Module m = opt_mod.value(); - std::string tkey = m->type_key(); + ffi::Module m = opt_mod.value(); + std::string tkey = m->kind(); if (tkey == "rpc") { return static_cast(m.operator->()) ->GetTimeEvaluator(name, dev, number, repeat, min_repeat_ms, @@ -425,10 +420,10 @@ TVM_FFI_STATIC_INIT_BLOCK({ << "Cannot find " << f_preproc_name << " in the global function"; f_preproc = *pf_preproc; } - ffi::Function pf = m.GetFunction(name, true); - CHECK(pf != nullptr) << "Cannot find " << name << "` in the global registry"; + Optional pf = m->GetFunction(name); + CHECK(pf.has_value()) << "Cannot find " << name << "` in the global registry"; return profiling::WrapTimeEvaluator( - pf, dev, number, repeat, min_repeat_ms, limit_zero_time_iterations, + *pf, dev, number, repeat, min_repeat_ms, limit_zero_time_iterations, cooldown_interval_ms, repeats_to_cooldown, cache_flush_bytes, f_preproc); } } else { @@ -455,9 +450,9 @@ TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("tvm.rpc.server.ImportModule", - [](Module parent, Module child) { parent->Import(child); }) + [](ffi::Module parent, ffi::Module child) { parent->ImportModule(child); }) .def("tvm.rpc.server.ModuleGetFunction", - [](Module parent, std::string name, bool query_imports) { + [](ffi::Module parent, std::string name, bool query_imports) { return parent->GetFunction(name, query_imports); }); }); @@ -467,26 +462,26 @@ TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("rpc.LoadRemoteModule", - [](Module sess, std::string name) { - std::string tkey = sess->type_key(); + [](ffi::Module sess, std::string name) { + std::string tkey = sess->kind(); ICHECK_EQ(tkey, "rpc"); return static_cast(sess.operator->())->LoadModule(name); }) .def("rpc.ImportRemoteModule", - [](Module parent, Module child) { - std::string tkey = parent->type_key(); + [](ffi::Module parent, ffi::Module child) { + std::string tkey = parent->kind(); ICHECK_EQ(tkey, "rpc"); static_cast(parent.operator->())->ImportModule(child); }) .def_packed("rpc.SessTableIndex", [](ffi::PackedArgs args, ffi::Any* rv) { - Module m = args[0].cast(); - std::string tkey = m->type_key(); + ffi::Module m = args[0].cast(); + std::string tkey = m->kind(); ICHECK_EQ(tkey, "rpc"); *rv = static_cast(m.operator->())->sess()->table_index(); }) .def("tvm.rpc.NDArrayFromRemoteOpaqueHandle", - [](Module mod, void* remote_array, DLTensor* template_tensor, Device dev, + [](ffi::Module mod, void* remote_array, DLTensor* template_tensor, Device dev, void* ndarray_handle) -> NDArray { return NDArrayFromRemoteOpaqueHandle(RPCModuleGetSession(mod), remote_array, template_tensor, dev, ndarray_handle); diff --git a/src/runtime/rpc/rpc_pipe_impl.cc b/src/runtime/rpc/rpc_pipe_impl.cc index e50c6a456eaf..22619289d053 100644 --- a/src/runtime/rpc/rpc_pipe_impl.cc +++ b/src/runtime/rpc/rpc_pipe_impl.cc @@ -76,7 +76,7 @@ class PipeChannel final : public RPCChannel { pid_t child_pid_; }; -Module CreatePipeClient(std::vector cmd) { +ffi::Module CreatePipeClient(std::vector cmd) { int parent2child[2]; int child2parent[2]; ICHECK_EQ(pipe(parent2child), 0); diff --git a/src/runtime/rpc/rpc_session.h b/src/runtime/rpc/rpc_session.h index c0ec2067eb5f..c0e09ec004ba 100644 --- a/src/runtime/rpc/rpc_session.h +++ b/src/runtime/rpc/rpc_session.h @@ -267,7 +267,7 @@ class RPCSession { /*! \brief Insert the current session to the session table.*/ static void InsertToSessionTable(std::shared_ptr sess); // friend declaration - friend Module CreateRPCSessionModule(std::shared_ptr sess); + friend ffi::Module CreateRPCSessionModule(std::shared_ptr sess); }; /*! @@ -341,14 +341,14 @@ class RPCObjectRef : public ObjectRef { * \param sess The RPC session of the global module. * \return The created module. */ -Module CreateRPCSessionModule(std::shared_ptr sess); +ffi::Module CreateRPCSessionModule(std::shared_ptr sess); /*! * \brief Get the session module from a RPC session Module. * \param mod The input module(must be an RPCModule). * \return The internal RPCSession. */ -std::shared_ptr RPCModuleGetSession(Module mod); +std::shared_ptr RPCModuleGetSession(ffi::Module mod); } // namespace runtime } // namespace tvm diff --git a/src/runtime/rpc/rpc_socket_impl.cc b/src/runtime/rpc/rpc_socket_impl.cc index 5ed34051cf55..d2f141ee21e0 100644 --- a/src/runtime/rpc/rpc_socket_impl.cc +++ b/src/runtime/rpc/rpc_socket_impl.cc @@ -105,8 +105,8 @@ std::shared_ptr RPCConnect(std::string url, int port, std::string k return endpt; } -Module RPCClientConnect(std::string url, int port, std::string key, bool enable_logging, - ffi::PackedArgs init_seq) { +ffi::Module RPCClientConnect(std::string url, int port, std::string key, bool enable_logging, + ffi::PackedArgs init_seq) { auto endpt = RPCConnect(url, port, "client:" + key, enable_logging, init_seq); return CreateRPCSessionModule(CreateClientSession(endpt)); } diff --git a/src/runtime/static_library.cc b/src/runtime/static_library.cc index 5ad331d27d0a..b816fb600e1e 100644 --- a/src/runtime/static_library.cc +++ b/src/runtime/static_library.cc @@ -24,6 +24,7 @@ */ #include "./static_library.h" +#include #include #include #include @@ -42,30 +43,34 @@ namespace { * \brief A '.o' library which can be linked into the final output library by export_library. * Can be used by external codegen tools which can produce a ready-to-link artifact. */ -class StaticLibraryNode final : public runtime::ModuleNode { +class StaticLibraryNode final : public ffi::ModuleObj { public: - ~StaticLibraryNode() override = default; + const char* kind() const final { return "static_library"; } - const char* type_key() const final { return "static_library"; } - - ffi::Function GetFunction(const String& name, const ObjectPtr& sptr_to_self) final { + Optional GetFunction(const String& name) final { + const ObjectPtr& sptr_to_self = ffi::GetObjectPtr(this); if (name == "get_func_names") { return ffi::Function( [sptr_to_self, this](ffi::PackedArgs args, ffi::Any* rv) { *rv = func_names_; }); } else { - return {}; + return std::nullopt; } } - void SaveToBinary(dmlc::Stream* stream) final { + ffi::Bytes SaveToBytes() const final { + std::string buffer; + dmlc::MemoryStringStream ms(&buffer); + dmlc::Stream* stream = &ms; stream->Write(data_); std::vector func_names; for (const auto func_name : func_names_) func_names.push_back(func_name); stream->Write(func_names); + return Bytes(buffer); } - static Module LoadFromBinary(void* strm) { - dmlc::Stream* stream = static_cast(strm); + static ffi::Module LoadFromBytes(ffi::Bytes bytes) { + dmlc::MemoryFixedSizeStream ms(const_cast(bytes.data()), bytes.size()); + dmlc::Stream* stream = &ms; auto n = make_object(); // load data std::string data; @@ -77,10 +82,10 @@ class StaticLibraryNode final : public runtime::ModuleNode { ICHECK(stream->Read(&func_names)) << "Loading func names failed"; for (auto func_name : func_names) n->func_names_.push_back(String(func_name)); - return Module(n); + return ffi::Module(n); } - void SaveToFile(const String& file_name, const String& format) final { + void WriteToFile(const String& file_name, const String& format) const final { VLOG(0) << "Saving static library of " << data_.size() << " bytes implementing " << FuncNames() << " to '" << file_name << "'"; SaveBinaryToFile(file_name, data_); @@ -88,14 +93,14 @@ class StaticLibraryNode final : public runtime::ModuleNode { /*! \brief Get the property of the runtime module .*/ int GetPropertyMask() const override { - return runtime::ModulePropertyMask::kBinarySerializable | ModulePropertyMask::kDSOExportable; + return ffi::Module::kBinarySerializable | ffi::Module::kCompilationExportable; } - bool ImplementsFunction(const String& name, bool query_imports) final { + bool ImplementsFunction(const String& name) final { return std::find(func_names_.begin(), func_names_.end(), name) != func_names_.end(); } - std::string FuncNames() { + std::string FuncNames() const { std::ostringstream os; os << "["; bool first = true; @@ -119,19 +124,19 @@ class StaticLibraryNode final : public runtime::ModuleNode { } // namespace -Module LoadStaticLibrary(const std::string& filename, Array func_names) { +ffi::Module LoadStaticLibrary(const std::string& filename, Array func_names) { auto node = make_object(); LoadBinaryFromFile(filename, &node->data_); node->func_names_ = std::move(func_names); VLOG(0) << "Loaded static library from '" << filename << "' implementing " << node->FuncNames(); - return Module(node); + return ffi::Module(node); } TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("runtime.ModuleLoadStaticLibrary", LoadStaticLibrary) - .def("runtime.module.loadbinary_static_library", StaticLibraryNode::LoadFromBinary); + .def("ffi.Module.load_from_bytes.static_library", StaticLibraryNode::LoadFromBytes); }); } // namespace runtime diff --git a/src/runtime/static_library.h b/src/runtime/static_library.h index 196d2448b93f..8a5600fc0588 100644 --- a/src/runtime/static_library.h +++ b/src/runtime/static_library.h @@ -43,7 +43,7 @@ namespace runtime { * \brief Returns a static library with the contents loaded from filename which exports * func_names with the usual packed-func calling convention. */ -Module LoadStaticLibrary(const std::string& filename, Array func_names); +ffi::Module LoadStaticLibrary(const std::string& filename, Array func_names); } // namespace runtime } // namespace tvm diff --git a/src/runtime/vm/executable.cc b/src/runtime/vm/executable.cc index 1bec6f5016eb..ef6fbe6373af 100644 --- a/src/runtime/vm/executable.cc +++ b/src/runtime/vm/executable.cc @@ -161,7 +161,7 @@ void LoadHeader(dmlc::Stream* strm) { STREAM_CHECK(version == VM_VERSION, "version"); } -void VMExecutable::SaveToBinary(dmlc::Stream* stream) { +ffi::Bytes VMExecutable::SaveToBytes() const { std::string code; // Initialize the stream object. dmlc::MemoryStringStream strm(&code); @@ -178,21 +178,16 @@ void VMExecutable::SaveToBinary(dmlc::Stream* stream) { // Code section. SaveCodeSection(&strm); - stream->Write(code); + return ffi::Bytes(code); } -void VMExecutable::SaveToFile(const String& file_name, const String& format) { - std::string data; - dmlc::MemoryStringStream writer(&data); - dmlc::SeekStream* strm = &writer; - VMExecutable::SaveToBinary(strm); - runtime::SaveBinaryToFile(file_name, data); +void VMExecutable::WriteToFile(const String& file_name, const String& format) const { + runtime::SaveBinaryToFile(file_name, VMExecutable::SaveToBytes()); } -Module VMExecutable::LoadFromBinary(void* stream) { +ffi::Module VMExecutable::LoadFromBytes(const ffi::Bytes& bytes) { std::string code; - static_cast(stream)->Read(&code); - dmlc::MemoryStringStream strm(&code); + dmlc::MemoryFixedSizeStream strm(const_cast(bytes.data()), bytes.size()); ObjectPtr exec = make_object(); @@ -208,26 +203,20 @@ Module VMExecutable::LoadFromBinary(void* stream) { // Code section. exec->LoadCodeSection(&strm); - return Module(exec); + return ffi::Module(exec); } -TVM_FFI_STATIC_INIT_BLOCK({ - namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("runtime.module.loadbinary_relax.VMExecutable", - VMExecutable::LoadFromBinary); -}); - -Module VMExecutable::LoadFromFile(const String& file_name) { +ffi::Module VMExecutable::LoadFromFile(const String& file_name) { std::string data; runtime::LoadBinaryFromFile(file_name, &data); - dmlc::MemoryStringStream reader(&data); - dmlc::Stream* strm = &reader; - return VMExecutable::LoadFromBinary(reinterpret_cast(strm)); + return VMExecutable::LoadFromBytes(ffi::Bytes(data)); } TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("runtime.module.loadfile_relax.VMExecutable", VMExecutable::LoadFromFile); + refl::GlobalDef() + .def("ffi.Module.load_from_file.relax.VMExecutable", VMExecutable::LoadFromFile) + .def("ffi.Module.load_from_bytes.relax.VMExecutable", VMExecutable::LoadFromBytes); }); void VMFuncInfo::Save(dmlc::Stream* strm) const { @@ -254,9 +243,9 @@ bool VMFuncInfo::Load(dmlc::Stream* strm) { return true; } -void VMExecutable::SaveGlobalSection(dmlc::Stream* strm) { strm->Write(func_table); } +void VMExecutable::SaveGlobalSection(dmlc::Stream* strm) const { strm->Write(func_table); } -void VMExecutable::SaveConstantSection(dmlc::Stream* strm) { +void VMExecutable::SaveConstantSection(dmlc::Stream* strm) const { strm->Write(static_cast(this->constants.size())); for (const auto& it : this->constants) { if (auto opt_nd = it.as()) { @@ -291,7 +280,7 @@ void VMExecutable::SaveConstantSection(dmlc::Stream* strm) { } } -void VMExecutable::SaveCodeSection(dmlc::Stream* strm) { +void VMExecutable::SaveCodeSection(dmlc::Stream* strm) const { strm->Write(instr_offset); strm->Write(instr_data); } @@ -394,16 +383,16 @@ std::string RegNameToStr(RegName reg) { return "%" + std::to_string(reg); } -Module VMExecutable::VMLoadExecutable() const { +ffi::Module VMExecutable::VMLoadExecutable() const { ObjectPtr vm = VirtualMachine::Create(); vm->LoadExecutable(GetObjectPtr(const_cast(this))); - return Module(vm); + return ffi::Module(vm); } -Module VMExecutable::VMProfilerLoadExecutable() const { +ffi::Module VMExecutable::VMProfilerLoadExecutable() const { ObjectPtr vm = VirtualMachine::CreateProfiler(); vm->LoadExecutable(GetObjectPtr(const_cast(this))); - return Module(vm); + return ffi::Module(vm); } bool VMExecutable::HasFunction(const String& name) const { return func_map.count(name); } diff --git a/src/runtime/vm/ndarray_cache_support.cc b/src/runtime/vm/ndarray_cache_support.cc index d91669016d78..cfd979cc6f24 100644 --- a/src/runtime/vm/ndarray_cache_support.cc +++ b/src/runtime/vm/ndarray_cache_support.cc @@ -302,11 +302,11 @@ TVM_FFI_STATIC_INIT_BLOCK({ // This param module node can be useful to get param dict in RPC mode // when the remote already have loaded parameters from file. -class ParamModuleNode : public runtime::ModuleNode { +class ParamModuleNode : public ffi::ModuleObj { public: - const char* type_key() const final { return "param_module"; } + const char* kind() const final { return "param_module"; } - ffi::Function GetFunction(const String& name, const ObjectPtr& sptr_to_self) final { + Optional GetFunction(const String& name) final { if (name == "get_params") { auto params = params_; return ffi::Function([params](ffi::PackedArgs args, ffi::Any* rv) { *rv = params; }); @@ -343,16 +343,16 @@ class ParamModuleNode : public runtime::ModuleNode { return result; } - static Module Create(const std::string& prefix, int num_params) { + static ffi::Module Create(const std::string& prefix, int num_params) { auto n = make_object(); n->params_ = GetParams(prefix, num_params); - return Module(n); + return ffi::Module(n); } - static Module CreateByName(const Array& names) { + static ffi::Module CreateByName(const Array& names) { auto n = make_object(); n->params_ = GetParamByName(names); - return Module(n); + return ffi::Module(n); } private: diff --git a/src/runtime/vm/vm.cc b/src/runtime/vm/vm.cc index c28e30084fc1..c4fdedd815a9 100644 --- a/src/runtime/vm/vm.cc +++ b/src/runtime/vm/vm.cc @@ -300,12 +300,13 @@ class VirtualMachineImpl : public VirtualMachine { * \param name The name of the function. * \return The result function, can return ffi::Function(nullptr) if nothing is found. */ - ffi::Function GetFuncFromImports(const String& name) { + Optional GetFuncFromImports(const String& name) { for (auto& lib : this->imports_) { - ffi::Function func = lib->GetFunction(name, true); - if (func.defined()) return func; + if (auto opt_func = lib.cast()->GetFunction(name, true)) { + return *opt_func; + } } - return ffi::Function(nullptr); + return std::nullopt; } /*! * \brief Initialize function pool. @@ -452,7 +453,7 @@ class VirtualMachineImpl : public VirtualMachine { void VirtualMachineImpl::LoadExecutable(ObjectPtr exec) { this->exec_ = exec; - this->imports_ = exec_->imports(); + this->imports_ = exec->imports(); } void VirtualMachineImpl::Init(const std::vector& devices, @@ -508,7 +509,7 @@ void VirtualMachineImpl::SetInput(std::string func_name, bool with_param_module, for (int i = 0; i < args.size(); ++i) { if (with_param_module && i == args.size() - 1) { // call param func to get the arguments(usually corresponds to param pack.) - func_args[i] = (args[i].cast()).GetFunction("get_params")(); + func_args[i] = (args[i].cast())->GetFunction("get_params").value()(); } else { func_args[i] = ConvertArgToDevice(args[i], devices[0], allocators[0]); } @@ -620,9 +621,9 @@ Optional VirtualMachineImpl::GetClosureInternal(const String& func_na } else { ICHECK(finfo.kind == VMFuncInfo::FuncKind::kVMTIRFunc) << "Cannot support closure with function kind " << static_cast(finfo.kind); - ffi::Function tir_func = GetFuncFromImports("__vmtir__" + finfo.name); - ICHECK(tir_func != nullptr) << "Cannot find underlying compiled tir function of VMTIRFunc " - << finfo.name; + Optional tir_func = GetFuncFromImports("__vmtir__" + finfo.name); + ICHECK(tir_func.has_value()) << "Cannot find underlying compiled tir function of VMTIRFunc " + << finfo.name; auto impl = ffi::Function([this, finfo, tir_func](ffi::PackedArgs args, ffi::Any* rv) { // Per convention, ctx ptr is a VirtualMachine* VirtualMachine* ctx_ptr = static_cast(args[0].cast()); @@ -637,8 +638,8 @@ Optional VirtualMachineImpl::GetClosureInternal(const String& func_na void* reg_anylist_handle = reg_file.data(); void* const_anylist_handle = this->const_pool_.data(); void* func_anylist_handle = this->func_pool_.data(); - tir_func(static_cast(ctx_ptr), reg_anylist_handle, const_anylist_handle, - func_anylist_handle); + (*tir_func)(static_cast(ctx_ptr), reg_anylist_handle, const_anylist_handle, + func_anylist_handle); // Return value always stored after inputs. *rv = reg_file[finfo.num_args]; }); @@ -696,16 +697,16 @@ void VirtualMachineImpl::InitFuncPool() { const VMFuncInfo& info = exec_->func_table[func_index]; if (info.kind == VMFuncInfo::FuncKind::kPackedFunc) { // only look through imports first - ffi::Function func = GetFuncFromImports(info.name); - if (!func.defined()) { + Optional func = GetFuncFromImports(info.name); + if (!func.has_value()) { const auto p_func = tvm::ffi::Function::GetGlobal(info.name); - if (p_func.has_value()) func = *(p_func); + if (p_func.has_value()) func = *p_func; } - ICHECK(func.defined()) + ICHECK(func.has_value()) << "Error: Cannot find ffi::Function " << info.name << " in either Relax VM kernel library, or in TVM runtime ffi::Function registry, or in " "global Relax functions of the VM executable"; - func_pool_[func_index] = func; + func_pool_[func_index] = *func; } else { ICHECK(info.kind == VMFuncInfo::FuncKind::kVMFunc || @@ -951,8 +952,8 @@ std::string VirtualMachineImpl::_GetFunctionParamName(std::string func_name, int ffi::Function VirtualMachineImpl::_LookupFunction(const String& name) { if (Optional opt = this->GetClosureInternal(name, true)) { - return ffi::Function([clo = opt.value(), _self = GetRef(this)](ffi::PackedArgs args, - ffi::Any* rv) -> void { + return ffi::Function([clo = opt.value(), _self = GetRef(this)]( + ffi::PackedArgs args, ffi::Any* rv) -> void { auto* self = const_cast(_self.as()); ICHECK(self); self->InvokeClosurePacked(clo, args, rv); @@ -972,7 +973,8 @@ ffi::Function VirtualMachineImpl::_LookupFunction(const String& name) { */ class VirtualMachineProfiler : public VirtualMachineImpl { public: - ffi::Function GetFunction(const String& name, const ObjectPtr& sptr_to_self) override { + Optional GetFunction(const String& name) override { + ObjectPtr sptr_to_self = ffi::GetObjectPtr(this); if (name == "profile") { return ffi::Function([sptr_to_self, this](ffi::PackedArgs args, ffi::Any* rv) { std::string f_name = args[0].cast(); @@ -1017,7 +1019,7 @@ class VirtualMachineProfiler : public VirtualMachineImpl { } }); } else { - return VirtualMachineImpl::GetFunction(name, sptr_to_self); + return VirtualMachineImpl::GetFunction(name); } } diff --git a/src/runtime/vulkan/vulkan_module.cc b/src/runtime/vulkan/vulkan_module.cc index 81136c28cd3c..a5fb6c2293fa 100644 --- a/src/runtime/vulkan/vulkan_module.cc +++ b/src/runtime/vulkan/vulkan_module.cc @@ -30,13 +30,14 @@ namespace tvm { namespace runtime { namespace vulkan { -Module VulkanModuleCreate(std::unordered_map smap, - std::unordered_map fmap, std::string source) { +ffi::Module VulkanModuleCreate(std::unordered_map smap, + std::unordered_map fmap, + std::string source) { auto n = make_object(smap, fmap, source); - return Module(n); + return ffi::Module(n); } -Module VulkanModuleLoadFile(const std::string& file_name, const String& format) { +ffi::Module VulkanModuleLoadFile(const std::string& file_name, const String& format) { std::string data; std::unordered_map smap; std::unordered_map fmap; @@ -53,8 +54,9 @@ Module VulkanModuleLoadFile(const std::string& file_name, const String& format) return VulkanModuleCreate(smap, fmap, ""); } -Module VulkanModuleLoadBinary(void* strm) { - dmlc::Stream* stream = static_cast(strm); +ffi::Module VulkanModuleLoadFromBytes(const ffi::Bytes& bytes) { + dmlc::MemoryFixedSizeStream ms(const_cast(bytes.data()), bytes.size()); + dmlc::Stream* stream = &ms; std::unordered_map smap; std::unordered_map fmap; @@ -68,8 +70,8 @@ Module VulkanModuleLoadBinary(void* strm) { TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef() - .def("runtime.module.loadfile_vulkan", VulkanModuleLoadFile) - .def("runtime.module.loadbinary_vulkan", VulkanModuleLoadBinary); + .def("ffi.Module.load_from_file.vulkan", VulkanModuleLoadFile) + .def("ffi.Module.load_from_bytes.vulkan", VulkanModuleLoadFromBytes); }); } // namespace vulkan diff --git a/src/runtime/vulkan/vulkan_module.h b/src/runtime/vulkan/vulkan_module.h index 878e096f5ac1..ea853721bfa2 100644 --- a/src/runtime/vulkan/vulkan_module.h +++ b/src/runtime/vulkan/vulkan_module.h @@ -20,6 +20,8 @@ #ifndef TVM_RUNTIME_VULKAN_VULKAN_MODULE_H_ #define TVM_RUNTIME_VULKAN_VULKAN_MODULE_H_ +#include + #include #include @@ -29,8 +31,9 @@ namespace tvm { namespace runtime { namespace vulkan { -Module VulkanModuleCreate(std::unordered_map smap, - std::unordered_map fmap, std::string source); +ffi::Module VulkanModuleCreate(std::unordered_map smap, + std::unordered_map fmap, + std::string source); } // namespace vulkan diff --git a/src/runtime/vulkan/vulkan_wrapped_func.cc b/src/runtime/vulkan/vulkan_wrapped_func.cc index db81c959dccd..2f50a0154658 100644 --- a/src/runtime/vulkan/vulkan_wrapped_func.cc +++ b/src/runtime/vulkan/vulkan_wrapped_func.cc @@ -205,11 +205,11 @@ VulkanModuleNode::~VulkanModuleNode() { } } -ffi::Function VulkanModuleNode::GetFunction(const String& name, - const ObjectPtr& sptr_to_self) { +Optional VulkanModuleNode::GetFunction(const String& name) { + ObjectPtr sptr_to_self = ffi::GetObjectPtr(this); ICHECK_EQ(sptr_to_self.get(), this); auto it = fmap_.find(name); - if (it == fmap_.end()) return ffi::Function(); + if (it == fmap_.end()) return std::nullopt; const FunctionInfo& info = it->second; VulkanWrappedFunc f; size_t num_buffer_args = NumBufferArgs(info.arg_types); @@ -403,7 +403,7 @@ std::shared_ptr VulkanModuleNode::GetPipeline(size_t device_id, return pe; } -void VulkanModuleNode::SaveToFile(const String& file_name, const String& format) { +void VulkanModuleNode::WriteToFile(const String& file_name, const String& format) const { std::string fmt = GetFileFormat(file_name, format); ICHECK_EQ(fmt, fmt_) << "Can only save to customized format vulkan"; std::string meta_file = GetMetaFilePath(file_name); @@ -417,13 +417,17 @@ void VulkanModuleNode::SaveToFile(const String& file_name, const String& format) SaveBinaryToFile(file_name, data_bin); } -void VulkanModuleNode::SaveToBinary(dmlc::Stream* stream) { +ffi::Bytes VulkanModuleNode::SaveToBytes() const { + std::string buffer; + dmlc::MemoryStringStream ms(&buffer); + dmlc::Stream* stream = &ms; stream->Write(fmt_); stream->Write(fmap_); stream->Write(smap_); + return ffi::Bytes(buffer); } -String VulkanModuleNode::GetSource(const String& format) { +String VulkanModuleNode::InspectSource(const String& format) const { // can only return disassembly code. return source_; } diff --git a/src/runtime/vulkan/vulkan_wrapped_func.h b/src/runtime/vulkan/vulkan_wrapped_func.h index 9b6f3703f34f..2ff90568de9d 100644 --- a/src/runtime/vulkan/vulkan_wrapped_func.h +++ b/src/runtime/vulkan/vulkan_wrapped_func.h @@ -80,29 +80,29 @@ class VulkanWrappedFunc { mutable std::array, kVulkanMaxNumDevice> scache_; }; -class VulkanModuleNode final : public runtime::ModuleNode { +class VulkanModuleNode final : public ffi::ModuleObj { public: explicit VulkanModuleNode(std::unordered_map smap, std::unordered_map fmap, std::string source) : smap_(smap), fmap_(fmap), source_(source) {} ~VulkanModuleNode(); - const char* type_key() const final { return "vulkan"; } + const char* kind() const final { return "vulkan"; } /*! \brief Get the property of the runtime module. */ int GetPropertyMask() const final { - return ModulePropertyMask::kBinarySerializable | ModulePropertyMask::kRunnable; + return ffi::Module::kBinarySerializable | ffi::Module::kRunnable; } - ffi::Function GetFunction(const String& name, const ObjectPtr& sptr_to_self) final; + Optional GetFunction(const String& name) final; std::shared_ptr GetPipeline(size_t device_id, const std::string& func_name, size_t num_pack_args); - void SaveToFile(const String& file_name, const String& format) final; + void WriteToFile(const String& file_name, const String& format) const final; - void SaveToBinary(dmlc::Stream* stream) final; - String GetSource(const String& format) final; + ffi::Bytes SaveToBytes() const final; + String InspectSource(const String& format) const final; private: // function information table. diff --git a/src/support/ffi_testing.cc b/src/support/ffi_testing.cc index e9d35c4496e7..70c23c546bbb 100644 --- a/src/support/ffi_testing.cc +++ b/src/support/ffi_testing.cc @@ -123,13 +123,13 @@ TVM_FFI_STATIC_INIT_BLOCK({ refl::GlobalDef().def("testing.ErrorTest", ErrorTest); }); -class FrontendTestModuleNode : public runtime::ModuleNode { +class FrontendTestModuleNode : public ffi::ModuleObj { public: - const char* type_key() const final { return "frontend_test"; } + const char* kind() const final { return "frontend_test"; } static constexpr const char* kAddFunctionName = "__add_function"; - virtual ffi::Function GetFunction(const String& name, const ObjectPtr& sptr_to_self); + virtual ffi::Optional GetFunction(const String& name); private: std::unordered_map functions_; @@ -137,11 +137,11 @@ class FrontendTestModuleNode : public runtime::ModuleNode { constexpr const char* FrontendTestModuleNode::kAddFunctionName; -ffi::Function FrontendTestModuleNode::GetFunction(const String& name, - const ObjectPtr& sptr_to_self) { +ffi::Optional FrontendTestModuleNode::GetFunction(const String& name) { + ffi::Module self_strong_ref = GetRef(this); if (name == kAddFunctionName) { - return ffi::TypedFunction( - [this, sptr_to_self](std::string func_name, ffi::Function pf) { + return ffi::Function::FromTyped( + [this, self_strong_ref](std::string func_name, ffi::Function pf) { CHECK_NE(func_name, kAddFunctionName) << "func_name: cannot be special function " << kAddFunctionName; functions_[func_name] = pf; @@ -150,15 +150,15 @@ ffi::Function FrontendTestModuleNode::GetFunction(const String& name, auto it = functions_.find(name); if (it == functions_.end()) { - return ffi::Function(); + return std::nullopt; } return it->second; } -runtime::Module NewFrontendTestModule() { +ffi::Module NewFrontendTestModule() { auto n = make_object(); - return runtime::Module(n); + return ffi::Module(n); } TVM_FFI_STATIC_INIT_BLOCK({ diff --git a/src/target/codegen.cc b/src/target/codegen.cc index 9b650c9aaa43..96075450183c 100644 --- a/src/target/codegen.cc +++ b/src/target/codegen.cc @@ -38,12 +38,10 @@ #include #include -#include "../runtime/library_module.h" - namespace tvm { namespace codegen { -runtime::Module Build(IRModule mod, Target target) { +ffi::Module Build(IRModule mod, Target target) { if (transform::PassContext::Current() ->GetConfig("tir.disable_assert", Bool(false)) .value()) { @@ -54,66 +52,42 @@ runtime::Module Build(IRModule mod, Target target) { std::string build_f_name = "target.build." + target->kind->name; const auto bf = tvm::ffi::Function::GetGlobal(build_f_name); ICHECK(bf.has_value()) << build_f_name << " is not enabled"; - return (*bf)(mod, target).cast(); + return (*bf)(mod, target).cast(); } /*! \brief Helper class to serialize module */ class ModuleSerializer { public: - explicit ModuleSerializer(runtime::Module mod) : mod_(mod) { Init(); } + explicit ModuleSerializer(ffi::Module mod) : mod_(mod) { Init(); } void SerializeModuleToBytes(dmlc::Stream* stream, bool export_dso) { - // Only have one DSO module and it is in the root, then - // we will not produce import_tree_. - bool has_import_tree = true; - - if (export_dso) { - has_import_tree = !mod_->imports().empty(); - } - - uint64_t sz = 0; - if (has_import_tree) { - // we will append one key for _import_tree - // The layout is the same as before: binary_size, key, logic, key, logic... - sz = mod_group_vec_.size() + 1; - } else { - // Keep the old behaviour - sz = mod_->imports().size(); - } - stream->Write(sz); - + // Always _import_tree + stream->Write(import_tree_row_ptr_); + stream->Write(import_tree_child_indices_); for (const auto& group : mod_group_vec_) { ICHECK_NE(group.size(), 0) << "Every allocated group must have at least one module"; // we prioritize export dso when a module is both serializable and exportable if (export_dso) { - if (group[0]->IsDSOExportable()) { - if (has_import_tree) { - std::string mod_type_key = "_lib"; - stream->Write(mod_type_key); - } - } else if (group[0]->IsBinarySerializable()) { + if (group[0]->GetPropertyMask() & ffi::Module::kCompilationExportable) { + std::string mod_type_key = "_lib"; + stream->Write(mod_type_key); + } else if (group[0]->GetPropertyMask() & ffi::Module::kBinarySerializable) { ICHECK_EQ(group.size(), 1U) << "Non DSO module is never merged"; - std::string mod_type_key = group[0]->type_key(); + std::string mod_type_key = group[0]->kind(); stream->Write(mod_type_key); - group[0]->SaveToBinary(stream); + std::string bytes = group[0]->SaveToBytes(); + stream->Write(bytes); } } else { - ICHECK(group[0]->IsBinarySerializable()) - << group[0]->type_key() << " is not binary serializable."; + ICHECK(group[0]->GetPropertyMask() & ffi::Module::kBinarySerializable) + << group[0]->kind() << " is not binary serializable."; ICHECK_EQ(group.size(), 1U) << "Non DSO module is never merged"; - std::string mod_type_key = group[0]->type_key(); + std::string mod_type_key = group[0]->kind(); stream->Write(mod_type_key); - group[0]->SaveToBinary(stream); + std::string bytes = group[0]->SaveToBytes(); + stream->Write(bytes); } } - - // Write _import_tree key if we have - if (has_import_tree) { - std::string import_key = "_import_tree"; - stream->Write(import_key); - stream->Write(import_tree_row_ptr_); - stream->Write(import_tree_child_indices_); - } } private: @@ -127,13 +101,13 @@ class ModuleSerializer { // This function merges all the DSO exportable module into // a single one as this is also what happens in the final hierachy void CreateModuleIndex() { - std::unordered_set visited{mod_.operator->()}; - std::vector stack{mod_.operator->()}; + std::unordered_set visited{mod_.operator->()}; + std::vector stack{mod_.operator->()}; uint64_t module_index = 0; - auto fpush_imports_to_stack = [&](runtime::ModuleNode* node) { - for (runtime::Module m : node->imports()) { - runtime::ModuleNode* next = m.operator->(); + auto fpush_imports_to_stack = [&](ffi::ModuleObj* node) { + for (Any m : node->imports()) { + ffi::ModuleObj* next = m.cast().operator->(); if (visited.count(next) == 0) { visited.insert(next); stack.push_back(next); @@ -141,7 +115,7 @@ class ModuleSerializer { } }; - std::vector dso_exportable_boundary; + std::vector dso_exportable_boundary; // Create module index that merges all dso module into a single group. // @@ -154,16 +128,16 @@ class ModuleSerializer { // Phase 0: only expand non-dso-module and record the boundary. while (!stack.empty()) { - runtime::ModuleNode* n = stack.back(); + ffi::ModuleObj* n = stack.back(); stack.pop_back(); - if (n->IsDSOExportable()) { + if (n->GetPropertyMask() & ffi::Module::kCompilationExportable) { // do not recursively expand dso modules // we will expand in phase 1 dso_exportable_boundary.emplace_back(n); } else { // expand the non-dso modules mod2index_[n] = module_index++; - mod_group_vec_.emplace_back(std::vector({n})); + mod_group_vec_.emplace_back(std::vector({n})); fpush_imports_to_stack(n); } } @@ -173,22 +147,22 @@ class ModuleSerializer { // This index is chosen so that all the DSO's parents are // allocated before this index, and children will be allocated after uint64_t dso_module_index = module_index++; - mod_group_vec_.emplace_back(std::vector()); + mod_group_vec_.emplace_back(std::vector()); // restart visiting the stack using elements in dso exportable boundary stack = std::move(dso_exportable_boundary); // Phase 1: expand the children of dso modules. while (!stack.empty()) { - runtime::ModuleNode* n = stack.back(); + ffi::ModuleObj* n = stack.back(); stack.pop_back(); - if (n->IsDSOExportable()) { + if (n->GetPropertyMask() & ffi::Module::kCompilationExportable) { mod_group_vec_[dso_module_index].emplace_back(n); mod2index_[n] = dso_module_index; } else { mod2index_[n] = module_index++; - mod_group_vec_.emplace_back(std::vector({n})); + mod_group_vec_.emplace_back(std::vector({n})); } fpush_imports_to_stack(n); } @@ -200,8 +174,8 @@ class ModuleSerializer { for (size_t parent_index = 0; parent_index < mod_group_vec_.size(); ++parent_index) { child_indices.clear(); for (const auto* m : mod_group_vec_[parent_index]) { - for (runtime::Module im : m->imports()) { - uint64_t mod_index = mod2index_.at(im.operator->()); + for (Any im : m->imports()) { + uint64_t mod_index = mod2index_.at(im.cast().operator->()); // skip cycle when dso modules are merged together if (mod_index != parent_index) { child_indices.emplace_back(mod_index); @@ -218,8 +192,8 @@ class ModuleSerializer { CHECK_LT(parent_index, child_indices[0]) << "RuntimeError: Cannot export due to multiple dso-exportables " << "that cannot be merged without creating a cycle in the import tree. " - << "Related module keys: parent=" << mod_group_vec_[parent_index][0]->type_key() - << ", child=" << mod_group_vec_[child_indices[0]][0]->type_key(); + << "Related module keys: parent=" << mod_group_vec_[parent_index][0]->kind() + << ", child=" << mod_group_vec_[child_indices[0]][0]->kind(); } // insert the child indices import_tree_child_indices_.insert(import_tree_child_indices_.end(), child_indices.begin(), @@ -228,16 +202,16 @@ class ModuleSerializer { } } - runtime::Module mod_; + ffi::Module mod_; // construct module to index - std::unordered_map mod2index_; + std::unordered_map mod2index_; // index -> module group - std::vector> mod_group_vec_; + std::vector> mod_group_vec_; std::vector import_tree_row_ptr_{0}; std::vector import_tree_child_indices_; }; -std::string SerializeModuleToBytes(const runtime::Module& mod, bool export_dso) { +std::string SerializeModuleToBytes(const ffi::Module& mod, bool export_dso) { std::string bin; dmlc::MemoryStringStream ms(&bin); dmlc::Stream* stream = &ms; @@ -247,16 +221,18 @@ std::string SerializeModuleToBytes(const runtime::Module& mod, bool export_dso) return bin; } -runtime::Module DeserializeModuleFromBytes(std::string blob) { +ffi::Module DeserializeModuleFromBytes(std::string blob) { dmlc::MemoryStringStream ms(&blob); dmlc::Stream* stream = &ms; - uint64_t size; - ICHECK(stream->Read(&size)); - std::vector modules; + std::vector modules; std::vector import_tree_row_ptr; std::vector import_tree_child_indices; + stream->Read(&import_tree_row_ptr); + stream->Read(&import_tree_child_indices); + + uint64_t size = import_tree_row_ptr.size() - 1; for (uint64_t i = 0; i < size; ++i) { std::string tkey; ICHECK(stream->Read(&tkey)); @@ -267,29 +243,32 @@ runtime::Module DeserializeModuleFromBytes(std::string blob) { ICHECK(stream->Read(&import_tree_row_ptr)); ICHECK(stream->Read(&import_tree_child_indices)); } else { - auto m = runtime::LoadModuleFromBinary(tkey, stream); + std::string bytes; + ICHECK(stream->Read(&bytes)); + auto loader = ffi::Function::GetGlobal("ffi.Module.load_from_bytes." + tkey); + ICHECK(loader.has_value()) << "ffi.Module.load_from_bytes." << tkey << " is not enabled"; + auto m = (*loader)(ffi::Bytes(bytes)).cast(); modules.emplace_back(m); } } for (size_t i = 0; i < modules.size(); ++i) { for (size_t j = import_tree_row_ptr[i]; j < import_tree_row_ptr[i + 1]; ++j) { - auto module_import_addr = runtime::ModuleInternal::GetImportsAddr(modules[i].operator->()); auto child_index = import_tree_child_indices[j]; ICHECK(child_index < modules.size()); - module_import_addr->emplace_back(modules[child_index]); + modules[i]->ImportModule(modules[child_index]); } } ICHECK(!modules.empty()) << "modules cannot be empty when import tree is present"; // invariance: root module is always at location 0. // The module order is collected via DFS - runtime::Module root_mod = modules[0]; + ffi::Module root_mod = modules[0]; return root_mod; } -std::string PackImportsToBytes(const runtime::Module& mod) { - std::string bin = SerializeModuleToBytes(mod); +std::string PackImportsToBytes(const ffi::Module& mod) { + std::string bin = SerializeModuleToBytes(mod, /*export_dso*/ true); uint64_t nbytes = bin.length(); std::string header; @@ -299,14 +278,14 @@ std::string PackImportsToBytes(const runtime::Module& mod) { return header + bin; } -std::string PackImportsToC(const runtime::Module& mod, bool system_lib, +std::string PackImportsToC(const ffi::Module& mod, bool system_lib, const std::string& c_symbol_prefix) { if (c_symbol_prefix.length() != 0) { CHECK(system_lib) << "c_symbol_prefix advanced option should be used in conjuction with system-lib"; } - std::string mdev_blob_name = c_symbol_prefix + runtime::symbol::tvm_ffi_library_bin; + std::string mdev_blob_name = c_symbol_prefix + ffi::symbol::tvm_ffi_library_bin; std::string blob = PackImportsToBytes(mod); // translate to C program @@ -332,10 +311,10 @@ std::string PackImportsToC(const runtime::Module& mod, bool system_lib, } os << "\n};\n"; if (system_lib) { - os << "extern int TVMBackendRegisterSystemLibSymbol(const char*, void*);\n"; + os << "extern int TVMFFIEnvRegisterSystemLibSymbol(const char*, void*);\n"; os << "static int " << mdev_blob_name << "_reg_ = " - << "TVMBackendRegisterSystemLibSymbol(\"" << mdev_blob_name << "\", (void*)" - << mdev_blob_name << ");\n"; + << "TVMFFIEnvRegisterSystemLibSymbol(\"" << mdev_blob_name << "\", (void*)" << mdev_blob_name + << ");\n"; } os << "#ifdef __cplusplus\n" << "}\n" @@ -343,9 +322,9 @@ std::string PackImportsToC(const runtime::Module& mod, bool system_lib, return os.str(); } -runtime::Module PackImportsToLLVM(const runtime::Module& mod, bool system_lib, - const std::string& llvm_target_string, - const std::string& c_symbol_prefix) { +ffi::Module PackImportsToLLVM(const ffi::Module& mod, bool system_lib, + const std::string& llvm_target_string, + const std::string& c_symbol_prefix) { if (c_symbol_prefix.length() != 0) { CHECK(system_lib) << "c_symbol_prefix advanced option should be used in conjuction with system-lib"; @@ -359,7 +338,7 @@ runtime::Module PackImportsToLLVM(const runtime::Module& mod, bool system_lib, const auto codegen_f = tvm::ffi::Function::GetGlobal(codegen_f_name); ICHECK(codegen_f.has_value()) << "codegen.codegen_blob is not presented."; return (*codegen_f)(ffi::Bytes(blob), system_lib, llvm_target_string, c_symbol_prefix) - .cast(); + .cast(); } TVM_FFI_STATIC_INIT_BLOCK({ @@ -372,9 +351,9 @@ TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("runtime.ModuleImportsBlobName", - []() -> std::string { return runtime::symbol::tvm_ffi_library_bin; }) + []() -> std::string { return ffi::symbol::tvm_ffi_library_bin; }) .def("runtime.ModulePackImportsToNDArray", - [](const runtime::Module& mod) { + [](const ffi::Module& mod) { std::string buffer = PackImportsToBytes(mod); ffi::Shape::index_type size = buffer.size(); DLDataType uchar; diff --git a/src/target/llvm/codegen_amdgpu.cc b/src/target/llvm/codegen_amdgpu.cc index 0cf218054320..9439af440b82 100644 --- a/src/target/llvm/codegen_amdgpu.cc +++ b/src/target/llvm/codegen_amdgpu.cc @@ -263,7 +263,7 @@ class CodeGenAMDGPU : public CodeGenLLVM { } }; -runtime::Module BuildAMDGPU(IRModule mod, Target target) { +ffi::Module BuildAMDGPU(IRModule mod, Target target) { LLVMInstance llvm_instance; With llvm_target(llvm_instance, target); diff --git a/src/target/llvm/codegen_blob.cc b/src/target/llvm/codegen_blob.cc index 3d48f57513d0..fc2acfddfb81 100644 --- a/src/target/llvm/codegen_blob.cc +++ b/src/target/llvm/codegen_blob.cc @@ -77,7 +77,7 @@ std::unique_ptr CodeGenBlob(const std::string& data, bool system_l llvm_target->SetTargetMetadata(module.get()); module->setDataLayout(tm->createDataLayout()); auto* blob_value = llvm::ConstantDataArray::getString(*ctx, data, false); - std::string mdev_blob_name = c_symbol_prefix + runtime::symbol::tvm_ffi_library_bin; + std::string mdev_blob_name = c_symbol_prefix + ffi::symbol::tvm_ffi_library_bin; auto* tvm_ffi_library_bin = new llvm::GlobalVariable( *module, blob_value->getType(), true, llvm::GlobalValue::ExternalLinkage, blob_value, @@ -151,11 +151,11 @@ std::unique_ptr CodeGenBlob(const std::string& data, bool system_l llvm::FunctionType::get(void_ty, false), llvm::GlobalValue::InternalLinkage, llvm::Twine("__cxx_global_var_init"), module.get()); - // Create TVMBackendRegisterSystemLibSymbol function + // Create TVMFFIEnvRegisterSystemLibSymbol function llvm::Function* tvm_backend_fn = llvm::Function::Create(llvm::FunctionType::get(int32_ty, {int8_ptr_ty, int8_ptr_ty}, false), llvm::GlobalValue::ExternalLinkage, - llvm::Twine("TVMBackendRegisterSystemLibSymbol"), module.get()); + llvm::Twine("TVMFFIEnvRegisterSystemLibSymbol"), module.get()); // Set necessary fn sections auto get_static_init_section_specifier = [&triple]() -> std::string { diff --git a/src/target/llvm/codegen_cpu.cc b/src/target/llvm/codegen_cpu.cc index 6271d4edbe30..eebbd5b64fd4 100644 --- a/src/target/llvm/codegen_cpu.cc +++ b/src/target/llvm/codegen_cpu.cc @@ -146,10 +146,10 @@ void CodeGenCPU::Init(const std::string& module_name, LLVMTarget* llvm_target, if (system_lib_prefix_.has_value() && !target_c_runtime) { // We will need this in environment for backward registration. // Defined in include/tvm/runtime/c_backend_api.h: - // int TVMBackendRegisterSystemLibSymbol(const char* name, void* ptr); + // int TVMFFIEnvRegisterSystemLibSymbol(const char* name, void* ptr); f_tvm_register_system_symbol_ = llvm::Function::Create( llvm::FunctionType::get(t_int_, {llvmGetPointerTo(t_char_, 0), t_void_p_}, false), - llvm::Function::ExternalLinkage, "TVMBackendRegisterSystemLibSymbol", module_.get()); + llvm::Function::ExternalLinkage, "TVMFFIEnvRegisterSystemLibSymbol", module_.get()); } else { f_tvm_register_system_symbol_ = nullptr; } @@ -236,11 +236,11 @@ void CodeGenCPU::AddMainFunction(const std::string& entry_func_name) { // Create wrapper function llvm::Function* wrapper_func = llvm::Function::Create(target_func->getFunctionType(), llvm::Function::WeakAnyLinkage, - runtime::symbol::tvm_ffi_main, module_.get()); + ffi::symbol::tvm_ffi_main, module_.get()); // Set attributes (Windows comdat, DLL export, etc.) if (llvm_target_->GetOrCreateTargetMachine()->getTargetTriple().isOSWindows()) { - llvm::Comdat* comdat = module_->getOrInsertComdat(runtime::symbol::tvm_ffi_main); + llvm::Comdat* comdat = module_->getOrInsertComdat(ffi::symbol::tvm_ffi_main); comdat->setSelectionKind(llvm::Comdat::Any); wrapper_func->setComdat(comdat); } @@ -454,8 +454,7 @@ llvm::Value* CodeGenCPU::GetContextPtr(llvm::GlobalVariable* gv) { } void CodeGenCPU::InitGlobalContext(bool dynamic_lookup) { - std::string ctx_symbol = - system_lib_prefix_.value_or("") + tvm::runtime::symbol::tvm_ffi_library_ctx; + std::string ctx_symbol = system_lib_prefix_.value_or("") + ffi::symbol::tvm_ffi_library_ctx; // Module context gv_mod_ctx_ = InitContextPtr(t_void_p_, ctx_symbol); // Register back the locations. diff --git a/src/target/llvm/codegen_hexagon.cc b/src/target/llvm/codegen_hexagon.cc index 6f90da3d8aea..67fccd8b073a 100644 --- a/src/target/llvm/codegen_hexagon.cc +++ b/src/target/llvm/codegen_hexagon.cc @@ -440,7 +440,7 @@ void ProcessLLVMOptions(const std::vector& llvm_vec) { } } // namespace -runtime::Module BuildHexagon(IRModule mod, Target target) { +ffi::Module BuildHexagon(IRModule mod, Target target) { LLVMInstance llvm_instance; With llvm_target(llvm_instance, target); diff --git a/src/target/llvm/codegen_nvptx.cc b/src/target/llvm/codegen_nvptx.cc index a6b70ad39a32..a1c967e644cb 100644 --- a/src/target/llvm/codegen_nvptx.cc +++ b/src/target/llvm/codegen_nvptx.cc @@ -322,7 +322,7 @@ int GetCUDAComputeVersion(const Target& target) { return std::stoi(sm_version.substr(3)); } -runtime::Module BuildNVPTX(IRModule mod, Target target) { +ffi::Module BuildNVPTX(IRModule mod, Target target) { LLVMInstance llvm_instance; With llvm_target(llvm_instance, target); diff --git a/src/target/llvm/llvm_module.cc b/src/target/llvm/llvm_module.cc index dd9622999bd2..f90729a45f06 100644 --- a/src/target/llvm/llvm_module.cc +++ b/src/target/llvm/llvm_module.cc @@ -57,11 +57,11 @@ #include #include #include +#include #include #include #include #include -#include #include #include #include @@ -77,7 +77,6 @@ #include #include "../../runtime/file_utils.h" -#include "../../runtime/library_module.h" #include "codegen_blob.h" #include "codegen_cpu.h" #include "codegen_llvm.h" @@ -90,29 +89,29 @@ using ffi::Any; using ffi::Function; using ffi::PackedArgs; -class LLVMModuleNode final : public runtime::ModuleNode { +class LLVMModuleNode final : public ffi::ModuleObj { public: ~LLVMModuleNode(); - const char* type_key() const final { return "llvm"; } + const char* kind() const final { return "llvm"; } - ffi::Function GetFunction(const String& name, const ObjectPtr& sptr_to_self) final; + Optional GetFunction(const String& name) final; /*! \brief Get the property of the runtime module .*/ // TODO(tvm-team): Make it serializable int GetPropertyMask() const override { - return runtime::ModulePropertyMask::kRunnable | runtime::ModulePropertyMask::kDSOExportable; + return ffi::Module::kRunnable | ffi::Module::kCompilationExportable; } - void SaveToFile(const String& file_name, const String& format) final; - void SaveToBinary(dmlc::Stream* stream) final; - String GetSource(const String& format) final; + void WriteToFile(const String& file_name, const String& format) const final; + ffi::Bytes SaveToBytes() const final; + String InspectSource(const String& format) const final; void Init(const IRModule& mod, const Target& target); void Init(std::unique_ptr module, std::unique_ptr llvm_instance); void LoadIR(const std::string& file_name); - bool ImplementsFunction(const String& name, bool query_imports) final; + bool ImplementsFunction(const String& name) final; void SetJITEngine(const std::string& jit_engine) { jit_engine_ = jit_engine; } @@ -156,8 +155,8 @@ LLVMModuleNode::~LLVMModuleNode() { module_owning_ptr_.reset(); } -ffi::Function LLVMModuleNode::GetFunction(const String& name, - const ObjectPtr& sptr_to_self) { +Optional LLVMModuleNode::GetFunction(const String& name) { + ObjectPtr sptr_to_self = ffi::GetObjectPtr(this); if (name == "__tvm_is_system_module") { bool flag = (module_->getFunction("__tvm_module_startup") != nullptr); return ffi::Function([flag](ffi::PackedArgs args, ffi::Any* rv) { *rv = flag; }); @@ -174,9 +173,9 @@ ffi::Function LLVMModuleNode::GetFunction(const String& name, return ffi::Function( [sptr_to_self, this](ffi::PackedArgs args, ffi::Any* rv) { *rv = this->function_names_; }); } else if (name == "get_symbol") { - return ffi::Function(nullptr); + return std::nullopt; } else if (name == "get_const_vars") { - return ffi::Function(nullptr); + return std::nullopt; } else if (name == "_get_target_string") { std::string target_string = LLVMTarget::GetTargetMetadata(*module_); return ffi::Function( @@ -191,8 +190,13 @@ ffi::Function LLVMModuleNode::GetFunction(const String& name, TVMFFISafeCallType faddr; With llvm_target(*llvm_instance_, LLVMTarget::GetTargetMetadata(*module_)); faddr = reinterpret_cast(GetFunctionAddr(name, *llvm_target)); - if (faddr == nullptr) return ffi::Function(); - return tvm::runtime::WrapFFIFunction(faddr, sptr_to_self); + if (faddr == nullptr) return std::nullopt; + ffi::Module self_strong_ref = GetRef(this); + return ffi::Function::FromPacked([faddr, self_strong_ref](ffi::PackedArgs args, ffi::Any* rv) { + TVM_FFI_ICHECK_LT(rv->type_index(), ffi::TypeIndex::kTVMFFIStaticObjectBegin); + TVM_FFI_CHECK_SAFE_CALL((*faddr)(nullptr, reinterpret_cast(args.data()), + args.size(), reinterpret_cast(rv))); + }); } namespace { @@ -231,7 +235,7 @@ bool LLVMAddPassesToEmitFile(llvm::TargetMachine* tm, llvm::legacy::PassManager* } // namespace -void LLVMModuleNode::SaveToFile(const String& file_name_str, const String& format) { +void LLVMModuleNode::WriteToFile(const String& file_name_str, const String& format) const { // CHECK(imports_.empty()) << "SaveToFile does not handle imported modules"; std::string file_name = file_name_str; std::string fmt = runtime::GetFileFormat(file_name, format); @@ -266,11 +270,11 @@ void LLVMModuleNode::SaveToFile(const String& file_name_str, const String& forma dest.close(); } -void LLVMModuleNode::SaveToBinary(dmlc::Stream* stream) { - LOG(FATAL) << "LLVMModule: SaveToBinary not supported"; +ffi::Bytes LLVMModuleNode::SaveToBytes() const { + LOG(FATAL) << "LLVMModule: SaveToBytes not supported"; } -String LLVMModuleNode::GetSource(const String& format) { +String LLVMModuleNode::InspectSource(const String& format) const { std::string fmt = runtime::GetFileFormat("", format); std::string type_str; llvm::SmallString<256> str; @@ -381,7 +385,7 @@ void LLVMModuleNode::LoadIR(const std::string& file_name) { Init(std::move(module), std::move(llvm_instance)); } -bool LLVMModuleNode::ImplementsFunction(const String& name, bool query_imports) { +bool LLVMModuleNode::ImplementsFunction(const String& name) { return std::find(function_names_.begin(), function_names_.end(), name) != function_names_.end(); } @@ -434,12 +438,16 @@ void LLVMModuleNode::InitMCJIT() { // run ctors mcjit_ee_->runStaticConstructorsDestructors(false); - if (void** ctx_addr = reinterpret_cast( - GetGlobalAddr(runtime::symbol::tvm_ffi_library_ctx, *llvm_target))) { + if (void** ctx_addr = + reinterpret_cast(GetGlobalAddr(ffi::symbol::tvm_ffi_library_ctx, *llvm_target))) { *ctx_addr = this; } - runtime::InitContextFunctions( - [this, &llvm_target](const char* name) { return GetGlobalAddr(name, *llvm_target); }); + + ffi::Module::VisitContextSymbols([this, &llvm_target](const String& name, void* symbol) { + if (void** ctx_addr = reinterpret_cast(GetGlobalAddr(name, *llvm_target))) { + *ctx_addr = symbol; + } + }); // There is a problem when a JITed function contains a call to a runtime function. // The runtime function (e.g. __truncsfhf2) may not be resolved, and calling it will // lead to a runtime crash. @@ -575,12 +583,15 @@ void LLVMModuleNode::InitORCJIT() { err = ctorRunner.run(); ICHECK(!err) << llvm::toString(std::move(err)); - if (void** ctx_addr = reinterpret_cast( - GetGlobalAddr(runtime::symbol::tvm_ffi_library_ctx, *llvm_target))) { + if (void** ctx_addr = + reinterpret_cast(GetGlobalAddr(ffi::symbol::tvm_ffi_library_ctx, *llvm_target))) { *ctx_addr = this; } - runtime::InitContextFunctions( - [this, &llvm_target](const char* name) { return GetGlobalAddr(name, *llvm_target); }); + ffi::Module::VisitContextSymbols([this, &llvm_target](const String& name, void* symbol) { + if (void** ctx_addr = reinterpret_cast(GetGlobalAddr(name, *llvm_target))) { + *ctx_addr = symbol; + } + }); } bool LLVMModuleNode::IsCompatibleWithHost(const llvm::TargetMachine* tm) const { @@ -638,13 +649,13 @@ static void LLVMReflectionRegister() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("target.build.llvm", - [](IRModule mod, Target target) -> runtime::Module { + [](IRModule mod, Target target) -> ffi::Module { auto n = make_object(); n->Init(mod, target); - return runtime::Module(n); + return ffi::Module(n); }) .def("codegen.LLVMModuleCreate", - [](std::string target_str, std::string module_name) -> runtime::Module { + [](std::string target_str, std::string module_name) -> ffi::Module { auto llvm_instance = std::make_unique(); With llvm_target(*llvm_instance, target_str); auto n = make_object(); @@ -659,7 +670,7 @@ static void LLVMReflectionRegister() { module->setDataLayout(llvm_target->GetOrCreateTargetMachine()->createDataLayout()); n->Init(std::move(module), std::move(llvm_instance)); n->SetJITEngine(llvm_target->GetJITEngine()); - return runtime::Module(n); + return ffi::Module(n); }) .def("target.llvm_lookup_intrinsic_id", [](std::string name) -> int64_t { @@ -765,12 +776,12 @@ static void LLVMReflectionRegister() { return llvm_target.TargetHasCPUFeature(feature); }) .def("target.llvm_version_major", []() -> int { return TVM_LLVM_VERSION / 10; }) - .def("runtime.module.loadfile_ll", - [](std::string filename, std::string fmt) -> runtime::Module { + .def("ffi.Module.load_from_file.ll", + [](std::string filename, std::string fmt) -> ffi::Module { auto n = make_object(); n->SetJITEngine("orcjit"); n->LoadIR(filename); - return runtime::Module(n); + return ffi::Module(n); }) .def("codegen.llvm_target_enabled", [](std::string target_str) -> bool { @@ -781,7 +792,7 @@ static void LLVMReflectionRegister() { }) .def("codegen.codegen_blob", [](std::string data, bool system_lib, std::string llvm_target_string, - std::string c_symbol_prefix) -> runtime::Module { + std::string c_symbol_prefix) -> ffi::Module { auto n = make_object(); auto llvm_instance = std::make_unique(); With llvm_target(*llvm_instance, llvm_target_string); @@ -789,7 +800,7 @@ static void LLVMReflectionRegister() { CodeGenBlob(data, system_lib, llvm_target.get(), c_symbol_prefix); n->Init(std::move(blob), std::move(llvm_instance)); n->SetJITEngine(llvm_target->GetJITEngine()); - return runtime::Module(n); + return ffi::Module(n); }); } diff --git a/src/target/llvm/llvm_module.h b/src/target/llvm/llvm_module.h index 2070d7da3e0c..75897e539f85 100644 --- a/src/target/llvm/llvm_module.h +++ b/src/target/llvm/llvm_module.h @@ -35,7 +35,7 @@ namespace tvm { namespace codegen { -runtime::Module CreateLLVMCppMetadataModule(runtime::metadata::Metadata metadata, Target target); +ffi::Module CreateLLVMCppMetadataModule(runtime::metadata::Metadata metadata, Target target); } // namespace codegen } // namespace tvm diff --git a/src/target/opt/build_cuda_off.cc b/src/target/opt/build_cuda_off.cc index 893eb67a268f..c0b494ff619c 100644 --- a/src/target/opt/build_cuda_off.cc +++ b/src/target/opt/build_cuda_off.cc @@ -24,11 +24,11 @@ namespace tvm { namespace runtime { -Module CUDAModuleCreate(std::string data, std::string fmt, - std::unordered_map fmap, - std::string cuda_source) { +ffi::Module CUDAModuleCreate(std::string data, std::string fmt, + std::unordered_map fmap, + std::string cuda_source) { LOG(FATAL) << "CUDA is not enabled"; - return Module(); + TVM_FFI_UNREACHABLE(); } } // namespace runtime } // namespace tvm diff --git a/src/target/opt/build_cuda_on.cc b/src/target/opt/build_cuda_on.cc index da8896bf4826..6072a483877c 100644 --- a/src/target/opt/build_cuda_on.cc +++ b/src/target/opt/build_cuda_on.cc @@ -126,7 +126,7 @@ std::string NVRTCCompile(const std::string& code, bool include_path = false) { return ptx; } -runtime::Module BuildCUDA(IRModule mod, Target target) { +ffi::Module BuildCUDA(IRModule mod, Target target) { bool output_ssa = false; CodeGenCUDA cg; cg.Init(output_ssa); diff --git a/src/target/opt/build_hexagon_off.cc b/src/target/opt/build_hexagon_off.cc index 2ce5cdb51f5d..696ca6399560 100644 --- a/src/target/opt/build_hexagon_off.cc +++ b/src/target/opt/build_hexagon_off.cc @@ -22,9 +22,10 @@ namespace tvm { namespace runtime { -Module HexagonModuleCreate(std::string data, std::string fmt, - std::unordered_map fmap, std::string asm_str, - std::string obj_str, std::string ir_str, std::string bc_str) { +ffi::Module HexagonModuleCreate(std::string data, std::string fmt, + std::unordered_map fmap, + std::string asm_str, std::string obj_str, std::string ir_str, + std::string bc_str) { LOG(WARNING) << "Hexagon runtime is not enabled, return a source module..."; return codegen::DeviceSourceModuleCreate(data, fmt, fmap, "hex"); } diff --git a/src/target/opt/build_metal_off.cc b/src/target/opt/build_metal_off.cc index 555aa5002f98..4200a35cbb58 100644 --- a/src/target/opt/build_metal_off.cc +++ b/src/target/opt/build_metal_off.cc @@ -26,9 +26,9 @@ namespace tvm { namespace runtime { -Module MetalModuleCreate(std::unordered_map smap, - std::unordered_map fmap, std::string fmt, - std::string source) { +ffi::Module MetalModuleCreate(std::unordered_map smap, + std::unordered_map fmap, std::string fmt, + std::string source) { LOG(WARNING) << "Metal runtime not enabled, return a source module..."; return codegen::DeviceSourceModuleCreate(source, fmt, fmap, "metal"); } diff --git a/src/target/opt/build_opencl_off.cc b/src/target/opt/build_opencl_off.cc index 9e368d5599cf..797aa3ef8d38 100644 --- a/src/target/opt/build_opencl_off.cc +++ b/src/target/opt/build_opencl_off.cc @@ -26,16 +26,17 @@ namespace tvm { namespace runtime { -Module OpenCLModuleCreate(std::string data, std::string fmt, - std::unordered_map fmap, std::string source) { +ffi::Module OpenCLModuleCreate(std::string data, std::string fmt, + std::unordered_map fmap, + std::string source) { return codegen::DeviceSourceModuleCreate(data, fmt, fmap, "opencl"); } -Module OpenCLModuleCreate(const std::unordered_map& shaders, - const std::string& spirv_text, - std::unordered_map fmap) { +ffi::Module OpenCLModuleCreate(const std::unordered_map& shaders, + const std::string& spirv_text, + std::unordered_map fmap) { LOG(FATAL) << "OpenCLModuleCreate is called but OpenCL is not enabled."; - return Module(); + TVM_FFI_UNREACHABLE(); } } // namespace runtime diff --git a/src/target/opt/build_rocm_off.cc b/src/target/opt/build_rocm_off.cc index 476e5a88fc6f..f161faa9f648 100644 --- a/src/target/opt/build_rocm_off.cc +++ b/src/target/opt/build_rocm_off.cc @@ -26,9 +26,9 @@ namespace tvm { namespace runtime { -Module ROCMModuleCreate(std::string data, std::string fmt, - std::unordered_map fmap, std::string rocm_source, - std::string assembly) { +ffi::Module ROCMModuleCreate(std::string data, std::string fmt, + std::unordered_map fmap, + std::string rocm_source, std::string assembly) { LOG(WARNING) << "ROCM runtime is not enabled, return a source module..."; auto fget_source = [rocm_source, assembly](const std::string& format) { if (format.length() == 0) return assembly; diff --git a/src/target/source/codegen_c_host.cc b/src/target/source/codegen_c_host.cc index 020054b3e1fc..e18ba0128d6b 100644 --- a/src/target/source/codegen_c_host.cc +++ b/src/target/source/codegen_c_host.cc @@ -22,8 +22,8 @@ */ #include "codegen_c_host.h" +#include #include -#include #include #include @@ -54,7 +54,7 @@ void CodeGenCHost::Init(bool output_ssa, bool emit_asserts, bool emit_fwd_func_d } void CodeGenCHost::InitGlobalContext() { - decl_stream << "void* " << tvm::runtime::symbol::tvm_ffi_library_ctx << " = NULL;\n"; + decl_stream << "void* " << ffi::symbol::tvm_ffi_library_ctx << " = NULL;\n"; } void CodeGenCHost::DefineModuleName() { decl_stream << "void* " << module_name_ << " = NULL;\n"; } @@ -77,11 +77,11 @@ void CodeGenCHost::AddFunction(const GlobalVar& gvar, const PrimFunc& func, << "CodeGenCHost: The entry func must have the global_symbol attribute, " << "but function " << gvar << " only has attributes " << func->attrs; - function_names_.push_back(runtime::symbol::tvm_ffi_main); + function_names_.push_back(ffi::symbol::tvm_ffi_main); stream << "// CodegenC: NOTE: Auto-generated entry function\n"; PrintFuncPrefix(stream); PrintType(func->ret_type, stream); - stream << " " << tvm::runtime::symbol::tvm_ffi_main + stream << " " << ffi::symbol::tvm_ffi_main << "(void* self, void* args,int num_args, void* result) {\n"; stream << " return " << global_symbol.value() << "(self, args, num_args, result);\n"; stream << "}\n"; @@ -355,7 +355,7 @@ inline void CodeGenCHost::PrintTernaryCondExpr(const T* op, const char* compare, << "? (" << a_id << ") : (" << b_id << "))"; } -runtime::Module BuildCHost(IRModule mod, Target target) { +ffi::Module BuildCHost(IRModule mod, Target target) { bool output_ssa = false; bool emit_asserts = false; bool emit_fwd_func_decl = true; diff --git a/src/target/source/codegen_metal.cc b/src/target/source/codegen_metal.cc index ffb1737a7063..dc019c28a7a0 100644 --- a/src/target/source/codegen_metal.cc +++ b/src/target/source/codegen_metal.cc @@ -431,7 +431,7 @@ void CodeGenMetal::VisitExpr_(const FloatImmNode* op, std::ostream& os) { // NO os << temp.str(); } -runtime::Module BuildMetal(IRModule mod, Target target) { +ffi::Module BuildMetal(IRModule mod, Target target) { bool output_ssa = false; mod = tir::transform::PointerValueTypeRewrite()(std::move(mod)); diff --git a/src/target/source/codegen_opencl.cc b/src/target/source/codegen_opencl.cc index 2645423affe3..1342464665f3 100644 --- a/src/target/source/codegen_opencl.cc +++ b/src/target/source/codegen_opencl.cc @@ -633,7 +633,7 @@ void CodeGenOpenCL::SetTextureScope( } } -runtime::Module BuildOpenCL(IRModule mod, Target target) { +ffi::Module BuildOpenCL(IRModule mod, Target target) { #if TVM_ENABLE_SPIRV Optional device = target->GetAttr("device"); if (device && device.value() == "spirv") { diff --git a/src/target/source/codegen_source_base.h b/src/target/source/codegen_source_base.h index a416e3fcae31..f077f8c3a83b 100644 --- a/src/target/source/codegen_source_base.h +++ b/src/target/source/codegen_source_base.h @@ -140,7 +140,7 @@ class CodeGenSourceBase { * \param code The code to be viewed. * \param fmt The code. format. */ -runtime::Module SourceModuleCreate(std::string code, std::string fmt); +ffi::Module SourceModuleCreate(std::string code, std::string fmt); /*! * \brief Create a C source module for viewing and compiling GCC code. @@ -150,9 +150,9 @@ runtime::Module SourceModuleCreate(std::string code, std::string fmt); * \param const_vars. The constant variables that the c source module needs. * \return The created module. */ -runtime::Module CSourceModuleCreate(const String& code, const String& fmt, - const Array& func_names, - const Array& const_vars = {}); +ffi::Module CSourceModuleCreate(const String& code, const String& fmt, + const Array& func_names, + const Array& const_vars = {}); /*! * \brief Wrap the submodules in a metadata module. @@ -163,9 +163,9 @@ runtime::Module CSourceModuleCreate(const String& code, const String& fmt, * \param target The target that all the modules are compiled for * \return The wrapped module. */ -runtime::Module CreateMetadataModule( - const std::unordered_map& params, runtime::Module target_module, - const Array& ext_modules, Target target); +ffi::Module CreateMetadataModule(const std::unordered_map& params, + ffi::Module target_module, const Array& ext_modules, + Target target); /*! * \brief Create a source module for viewing and limited saving for device. @@ -175,7 +175,7 @@ runtime::Module CreateMetadataModule( * \param type_key The type_key of the runtime module of this source code * \param fget_source a closure to replace default get source behavior. */ -runtime::Module DeviceSourceModuleCreate( +ffi::Module DeviceSourceModuleCreate( std::string data, std::string fmt, std::unordered_map fmap, std::string type_key, std::function fget_source = nullptr); diff --git a/src/target/source/codegen_webgpu.cc b/src/target/source/codegen_webgpu.cc index f5bfd80fee25..28d158c3c21e 100644 --- a/src/target/source/codegen_webgpu.cc +++ b/src/target/source/codegen_webgpu.cc @@ -22,6 +22,7 @@ */ #include "codegen_webgpu.h" +#include #include #include #include @@ -705,27 +706,30 @@ void CodeGenWebGPU::VisitStmt_(const WhileNode* op) { //------------------------------------------------- // WebGPUSourceModule to enable export //------------------------------------------------- -class WebGPUSourceModuleNode final : public runtime::ModuleNode { +class WebGPUSourceModuleNode final : public ffi::ModuleObj { public: explicit WebGPUSourceModuleNode(std::unordered_map smap, std::unordered_map fmap) : smap_(smap), fmap_(fmap) {} - const char* type_key() const final { return "webgpu"; } + const char* kind() const final { return "webgpu"; } /*! \brief Get the property of the runtime module .*/ - int GetPropertyMask() const final { return runtime::ModulePropertyMask::kBinarySerializable; } + int GetPropertyMask() const final { return ffi::Module::kBinarySerializable; } - ffi::Function GetFunction(const String& name, const ObjectPtr& sptr_to_self) final { + Optional GetFunction(const String& name) final { LOG(FATAL) << "WebGPUSourceModule is not directly runnable, export and run through tvmjs"; - return ffi::Function(nullptr); } - void SaveToBinary(dmlc::Stream* stream) final { + ffi::Bytes SaveToBytes() const final { + std::string buffer; + dmlc::MemoryStringStream ms(&buffer); + dmlc::Stream* stream = &ms; stream->Write(fmap_); stream->Write(smap_); + return ffi::Bytes(buffer); } - String GetSource(const String& format) final { + String InspectSource(const String& format) const final { if (format == "func_info") { std::ostringstream stream; dmlc::JSONWriter(&stream).Write(fmap_); @@ -749,7 +753,7 @@ class WebGPUSourceModuleNode final : public runtime::ModuleNode { //------------------------------------------------- // Build logic. //------------------------------------------------- -runtime::Module BuildWebGPU(IRModule mod, Target target) { +ffi::Module BuildWebGPU(IRModule mod, Target target) { mod = tir::transform::PointerValueTypeRewrite()(std::move(mod)); bool output_ssa = false; bool skip_readonly_decl = false; @@ -777,7 +781,7 @@ runtime::Module BuildWebGPU(IRModule mod, Target target) { } auto n = make_object(smap, fmap); - return runtime::Module(n); + return ffi::Module(n); } TVM_FFI_STATIC_INIT_BLOCK({ diff --git a/src/target/source/source_module.cc b/src/target/source/source_module.cc index b3c5ff311a3c..1350357d866c 100644 --- a/src/target/source/source_module.cc +++ b/src/target/source/source_module.cc @@ -23,9 +23,9 @@ */ #include +#include #include #include -#include #include #include @@ -51,40 +51,43 @@ using runtime::GetMetaFilePath; using runtime::SaveBinaryToFile; // Simulator function -class SourceModuleNode : public runtime::ModuleNode { +class SourceModuleNode : public ffi::ModuleObj { public: SourceModuleNode(std::string code, std::string fmt) : code_(code), fmt_(fmt) {} - const char* type_key() const final { return "source"; } + const char* kind() const final { return "source"; } - ffi::Function GetFunction(const String& name, const ObjectPtr& sptr_to_self) final { + Optional GetFunction(const String& name) final { LOG(FATAL) << "Source module cannot execute, to get executable module" << " build TVM with \'" << fmt_ << "\' runtime support"; - return ffi::Function(); } - String GetSource(const String& format) final { return code_; } + String InspectSource(const String& format) const final { return code_; } - String GetFormat() override { return fmt_; } + Array GetWriteFormats() const override { return {fmt_}; } protected: std::string code_; std::string fmt_; }; -runtime::Module SourceModuleCreate(std::string code, std::string fmt) { +ffi::Module SourceModuleCreate(std::string code, std::string fmt) { auto n = make_object(code, fmt); - return runtime::Module(n); + return ffi::Module(n); } // Simulator function -class CSourceModuleNode : public runtime::ModuleNode { +class CSourceModuleNode : public ffi::ModuleObj { public: CSourceModuleNode(const std::string& code, const std::string& fmt, const Array& func_names, const Array& const_vars) - : code_(code), fmt_(fmt), const_vars_(const_vars), func_names_(func_names) {} - const char* type_key() const final { return "c"; } + : code_(code), fmt_(fmt), const_vars_(const_vars), func_names_(func_names) { + if (fmt_.empty()) fmt_ = "c"; + } + + const char* kind() const final { return "c"; } - ffi::Function GetFunction(const String& name, const ObjectPtr& sptr_to_self) final { + Optional GetFunction(const String& name) final { + ObjectPtr sptr_to_self = ffi::GetObjectPtr(this); // Currently c-source module is used as demonstration purposes with binary metadata module // that expects get_symbol interface. When c-source module is used as external module, it // will only contain one function. However, when its used as an internal module (e.g., target @@ -103,11 +106,14 @@ class CSourceModuleNode : public runtime::ModuleNode { } } - String GetSource(const String& format) final { return code_; } + String InspectSource(const String& format) const final { return code_; } - String GetFormat() override { return fmt_; } + Array GetWriteFormats() const override { return {fmt_}; } - void SaveToBinary(dmlc::Stream* stream) final { + ffi::Bytes SaveToBytes() const final { + std::string buffer; + dmlc::MemoryStringStream ms(&buffer); + dmlc::Stream* stream = &ms; stream->Write(code_); stream->Write(fmt_); @@ -117,10 +123,12 @@ class CSourceModuleNode : public runtime::ModuleNode { for (auto const_var : const_vars_) const_vars.push_back(const_var); stream->Write(func_names); stream->Write(const_vars); + return ffi::Bytes(buffer); } - static runtime::Module LoadFromBinary(void* strm) { - dmlc::Stream* stream = static_cast(strm); + static ffi::Module LoadFromBytes(const ffi::Bytes& bytes) { + dmlc::MemoryFixedSizeStream ms(const_cast(bytes.data()), bytes.size()); + dmlc::Stream* stream = &ms; std::string code, fmt; ICHECK(stream->Read(&code)) << "Loading code failed"; @@ -137,10 +145,10 @@ class CSourceModuleNode : public runtime::ModuleNode { for (auto const_var : tmp_const_vars) const_vars.push_back(String(const_var)); auto n = make_object(code, fmt, func_names, const_vars); - return runtime::Module(n); + return ffi::Module(n); } - void SaveToFile(const String& file_name, const String& format) final { + void WriteToFile(const String& file_name, const String& format) const final { std::string fmt = GetFileFormat(file_name, format); std::string meta_file = GetMetaFilePath(file_name); if (fmt == "c" || fmt == "cc" || fmt == "cpp" || fmt == "cu") { @@ -152,11 +160,10 @@ class CSourceModuleNode : public runtime::ModuleNode { } int GetPropertyMask() const override { - return runtime::ModulePropertyMask::kBinarySerializable | - runtime::ModulePropertyMask::kDSOExportable; + return ffi::Module::kBinarySerializable | ffi::Module::kCompilationExportable; } - bool ImplementsFunction(const String& name, bool query_imports) final { + bool ImplementsFunction(const String& name) final { return std::find(func_names_.begin(), func_names_.end(), name) != func_names_.end(); } @@ -167,17 +174,16 @@ class CSourceModuleNode : public runtime::ModuleNode { Array func_names_; }; -runtime::Module CSourceModuleCreate(const String& code, const String& fmt, - const Array& func_names, - const Array& const_vars) { +ffi::Module CSourceModuleCreate(const String& code, const String& fmt, + const Array& func_names, const Array& const_vars) { auto n = make_object(code.operator std::string(), fmt.operator std::string(), func_names, const_vars); - return runtime::Module(n); + return ffi::Module(n); } TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("runtime.module.loadbinary_c", CSourceModuleNode::LoadFromBinary); + refl::GlobalDef().def("ffi.Module.load_from_bytes.c", CSourceModuleNode::LoadFromBytes); }); /*! @@ -197,20 +203,19 @@ class ConcreteCodegenSourceBase : public CodeGenSourceBase { }; // supports limited save without cross compile -class DeviceSourceModuleNode final : public runtime::ModuleNode { +class DeviceSourceModuleNode final : public ffi::ModuleObj { public: DeviceSourceModuleNode(std::string data, std::string fmt, std::unordered_map fmap, std::string type_key, std::function fget_source) : data_(data), fmt_(fmt), fmap_(fmap), type_key_(type_key), fget_source_(fget_source) {} - ffi::Function GetFunction(const String& name, const ObjectPtr& sptr_to_self) final { + Optional GetFunction(const String& name) final { LOG(FATAL) << "Source module cannot execute, to get executable module" << " build TVM with \'" << fmt_ << "\' runtime support"; - return ffi::Function(); } - String GetSource(const String& format) final { + String InspectSource(const String& format) const final { if (fget_source_ != nullptr) { return fget_source_(format); } else { @@ -218,11 +223,11 @@ class DeviceSourceModuleNode final : public runtime::ModuleNode { } } - const char* type_key() const final { return type_key_.c_str(); } + const char* kind() const final { return type_key_.c_str(); } /*! \brief Get the property of the runtime module .*/ - int GetPropertyMask() const final { return runtime::ModulePropertyMask::kBinarySerializable; } + int GetPropertyMask() const final { return ffi::Module::kBinarySerializable; } - void SaveToFile(const String& file_name, const String& format) final { + void WriteToFile(const String& file_name, const String& format) const final { std::string fmt = GetFileFormat(file_name, format); ICHECK_EQ(fmt, fmt_) << "Can only save to format=" << fmt_; std::string meta_file = GetMetaFilePath(file_name); @@ -230,10 +235,14 @@ class DeviceSourceModuleNode final : public runtime::ModuleNode { SaveBinaryToFile(file_name, data_); } - void SaveToBinary(dmlc::Stream* stream) final { + ffi::Bytes SaveToBytes() const final { + std::string buffer; + dmlc::MemoryStringStream ms(&buffer); + dmlc::Stream* stream = &ms; stream->Write(fmt_); stream->Write(fmap_); stream->Write(data_); + return ffi::Bytes(buffer); } private: @@ -244,11 +253,12 @@ class DeviceSourceModuleNode final : public runtime::ModuleNode { std::function fget_source_; }; -runtime::Module DeviceSourceModuleCreate( - std::string data, std::string fmt, std::unordered_map fmap, - std::string type_key, std::function fget_source) { +ffi::Module DeviceSourceModuleCreate(std::string data, std::string fmt, + std::unordered_map fmap, + std::string type_key, + std::function fget_source) { auto n = make_object(data, fmt, fmap, type_key, fget_source); - return runtime::Module(n); + return ffi::Module(n); } TVM_FFI_STATIC_INIT_BLOCK({ diff --git a/src/target/spirv/build_vulkan.cc b/src/target/spirv/build_vulkan.cc index 26ecffcc6bd3..bd44607a98eb 100644 --- a/src/target/spirv/build_vulkan.cc +++ b/src/target/spirv/build_vulkan.cc @@ -32,7 +32,7 @@ namespace tvm { namespace codegen { -runtime::Module BuildSPIRV(IRModule mod, Target target) { +ffi::Module BuildSPIRV(IRModule mod, Target target) { auto [smap, spirv_text] = LowerToSPIRV(mod, target); return runtime::VulkanModuleCreate(smap, ExtractFuncInfo(mod), spirv_text); } diff --git a/tests/cpp-runtime/opencl/opencl_compile_to_bin.cc b/tests/cpp-runtime/opencl/opencl_compile_to_bin.cc index b8e5b90ece7c..1097a21128e1 100644 --- a/tests/cpp-runtime/opencl/opencl_compile_to_bin.cc +++ b/tests/cpp-runtime/opencl/opencl_compile_to_bin.cc @@ -184,9 +184,8 @@ TEST_F(OpenCLCompileBin, SourceVsBinaryCompilationPerf) { module.InstallKernel(m_workspace, m_workspace->GetThreadEntry(), m_kernelNames[i], e); } Timestamp comp_end = std::chrono::high_resolution_clock::now(); - auto get_pre_compiled_f = module.GetFunction("opencl.GetPreCompiledPrograms", - tvm::ffi::GetObjectPtr(&module)); - bytes = get_pre_compiled_f().cast(); + auto get_pre_compiled_f = module.GetFunction("opencl.GetPreCompiledPrograms").value(); + bytes = get_pre_compiled_f().cast(); std::chrono::duration duration = std::chrono::duration_cast(comp_end - comp_start); compileFromSourceTimeMS = duration.count() * 1e-6; @@ -195,8 +194,7 @@ TEST_F(OpenCLCompileBin, SourceVsBinaryCompilationPerf) { { OpenCLModuleNode module(m_dataSrc, "cl", m_fmap, std::string()); module.Init(); - module.GetFunction("opencl.SetPreCompiledPrograms", - GetObjectPtr(&module))(tvm::String(bytes)); + module.GetFunction("opencl.SetPreCompiledPrograms").value()(tvm::String(bytes)); Timestamp comp_start = std::chrono::high_resolution_clock::now(); for (size_t i = 0; i < m_kernelNames.size(); ++i) { OpenCLModuleNode::KTRefEntry e = {i, 1}; diff --git a/tests/python/codegen/test_target_codegen_aarch64.py b/tests/python/codegen/test_target_codegen_aarch64.py index 2c8f185d8ecd..90ad1d65c7aa 100644 --- a/tests/python/codegen/test_target_codegen_aarch64.py +++ b/tests/python/codegen/test_target_codegen_aarch64.py @@ -48,7 +48,7 @@ def check_correct_assembly(type): f = tvm.tir.build(te.create_prim_func([A, B, C])) # Verify we see SVE load instructions and mul instructions using z registers - assembly = f.get_source("asm") + assembly = f.inspect_source("asm") loads = re.findall("ld1[whdb] { z", assembly) matches = re.findall( r"mul\tz[0-9].[shdb],( p[0-9]/[m],)? z[0-9].[shdb], z[0-9].[shdb]", assembly @@ -80,7 +80,7 @@ def check_correct_assembly(type): f = tvm.tir.build(te.create_prim_func([A, B, C])) # Verify we see SVE load instructions and add instructions using z registers - assembly = f.get_source("asm") + assembly = f.inspect_source("asm") loads = re.findall("ld1[whdb] { z", assembly) matches = re.findall( r"add\tz[0-9].[shdb],( p[0-9]/[m],)? z[0-9].[shdb], z[0-9].[shdb]", assembly @@ -112,7 +112,7 @@ def check_correct_assembly(type): f = tvm.tir.build(te.create_prim_func([A, B, C])) # Verify we see SVE load instructions and sub instructions using z registers - assembly = f.get_source("asm") + assembly = f.inspect_source("asm") loads = re.findall("ld1[whdb] { z", assembly) matches = re.findall( r"sub\tz[0-9].[shdb],( p[0-9]/[m],)? z[0-9].[shdb], z[0-9].[shdb]", assembly @@ -145,7 +145,7 @@ def check_correct_assembly(type): f = tvm.tir.build(te.create_prim_func([A, B, C, D])) # Verify we see SVE load instructions and either mad or mla instructions using z registers - assembly = f.get_source("asm") + assembly = f.inspect_source("asm") loads = re.findall("ld1[whdb] { z", assembly) matches = re.findall( r"mad|mla\tz[0-9].[shdb],( p[0-9]/[m],)? z[0-9].[shdb], z[0-9].[shdb]", assembly @@ -177,7 +177,7 @@ def check_correct_assembly(type): f = tvm.tir.build(te.create_prim_func([A, B, C])) # Verify we see SVE load instructions and cmgt + sel instructions or a max instruction, all using z registers - assembly = f.get_source("asm") + assembly = f.inspect_source("asm") loads = re.findall("ld1[whdb] { z", assembly) compare = re.findall( r"cmgt\tp[0-9].[shdb],( p[0-9]/[zm],)? z[0-9].[shdb], z[0-9].[shdb]", assembly @@ -213,7 +213,7 @@ def check_correct_assembly(type): f = tvm.tir.build(te.create_prim_func([A, B, C])) # Verify we see SVE load instructions and cmgt + sel instructions or a min instruction, all using z registers - assembly = f.get_source("asm") + assembly = f.inspect_source("asm") loads = re.findall("ld1[whdb] { z", assembly) compare = re.findall( r"cmgt\tp[0-9].[shdb],( p[0-9]/[zm],)? z[0-9].[shdb], z[0-9].[shdb]", assembly @@ -249,7 +249,7 @@ def check_correct_assembly(type): f = tvm.tir.build(te.create_prim_func([A, B, C])) # Verify we see SVE load instructions and div instructions using z registers - assembly = f.get_source("asm") + assembly = f.inspect_source("asm") loads = re.findall("ld1[whdb] { z", assembly) matches = re.findall( r"div\tz[0-9].[shdb],( p[0-9]/[m],)? z[0-9].[shdb], z[0-9].[shdb]", assembly @@ -280,7 +280,7 @@ def check_correct_assembly(type): f = tvm.tir.build(te.create_prim_func([A, B, C])) # Verify we see SVE load instructions and mls instructions using z registers - assembly = f.get_source("asm") + assembly = f.inspect_source("asm") loads = re.findall("ld1[whdb] { z", assembly) matches = re.findall( r"mls\tz[0-9].[shdb],( p[0-9]/[m],)? z[0-9].[shdb], z[0-9].[shdb]", assembly @@ -312,7 +312,7 @@ def check_correct_assembly(type): f = tvm.tir.build(te.create_prim_func([A, B, C])) # Verify we see SVE load instructions and cmpeq or cmeq instructions using z registers - assembly = f.get_source("asm") + assembly = f.inspect_source("asm") loads = re.findall("ld1[whdb] { z", assembly) matches = re.findall( r"cm(p)?eq\tp[0-9].[shdb],( p[0-9]/[zm],)? z[0-9].[shdb], z[0-9].[shdb]", assembly @@ -344,7 +344,7 @@ def check_correct_assembly(type): f = tvm.tir.build(te.create_prim_func([A, B, C])) # Verify we see SVE load instructions and cmpgt, cmgt, cmpne or cmne instructions, all using z registers - assembly = f.get_source("asm") + assembly = f.inspect_source("asm") loads = re.findall("ld1[whdb] { z", assembly) matches = re.findall( r"cm(p)?(gt|ne)\tp[0-9].[shdb],( p[0-9]/[zm],)? z[0-9].[shdb], z[0-9].[shdb]", assembly @@ -375,7 +375,7 @@ def check_correct_assembly(type): f = tvm.tir.build(te.create_prim_func([A, B, C])) # Verify we see SVE load instructions and orr instructions using z registers - assembly = f.get_source("asm") + assembly = f.inspect_source("asm") loads = re.findall("ld1[whdb] { z", assembly) matches = re.findall( r"orr\tz[0-9].[shdb],( p[0-9]/[zm],)? z[0-9].[shdb], z[0-9].[shdb]", assembly @@ -406,7 +406,7 @@ def check_correct_assembly(type): f = tvm.tir.build(te.create_prim_func([A, B, C])) # Verify we see SVE load instructions and and instructions using z registers - assembly = f.get_source("asm") + assembly = f.inspect_source("asm") loads = re.findall("ld1[whdb] { z", assembly) matches = re.findall( r"and\tz[0-9].[shdb],( p[0-9]/[zm],)? z[0-9].[shdb], z[0-9].[shdb]", assembly @@ -436,7 +436,7 @@ def check_correct_assembly(type): f = tvm.tir.build(te.create_prim_func([A, C])) # Verify we see SVE load instructions and eor instructions using z registers - assembly = f.get_source("asm") + assembly = f.inspect_source("asm") loads = re.findall("ld1[whdb] { z", assembly) matches = re.findall( r"eor\tz[0-9].[shdb],( p[0-9]/[zm],)? z[0-9].[shdb], z[0-9].[shdb]", assembly @@ -471,7 +471,7 @@ def check_correct_assembly(type): f = tvm.tir.build(te.create_prim_func([A, B, C])) # Verify we see gather instructions in the assembly - assembly = f.get_source("asm") + assembly = f.inspect_source("asm") loads = re.findall("ld1[whdb] { z", assembly) assert len(loads) > 0 @@ -503,7 +503,7 @@ def test_vscale_range_function_attribute(mattr, expect_attr): f = tvm.tir.build(te.create_prim_func([A, C])) # Check if the vscale_range() attribute exists - ll = f.get_source("ll") + ll = f.inspect_source("ll") attr = re.findall(rf".*vscale_range\(\d+,\d+\)*.", ll) if expect_attr: diff --git a/tests/python/codegen/test_target_codegen_arm.py b/tests/python/codegen/test_target_codegen_arm.py index d22e528770b3..e6d0c70f8734 100644 --- a/tests/python/codegen/test_target_codegen_arm.py +++ b/tests/python/codegen/test_target_codegen_arm.py @@ -30,7 +30,7 @@ def check_correct_assembly(type, elements, counts): sch.vectorize(sch.get_loops("B")[0]) f = tvm.tir.build(sch.mod, target=target) # Verify we see the correct number of vpaddl and vcnt instructions in the assembly - assembly = f.get_source("asm") + assembly = f.inspect_source("asm") matches = re.findall("vpaddl", assembly) assert len(matches) == counts matches = re.findall("vcnt", assembly) @@ -61,7 +61,7 @@ def check_correct_assembly(N): f = tvm.tir.build(sch.mod, target=target) # Verify we see the correct number of vmlal.s16 instructions - assembly = f.get_source("asm") + assembly = f.inspect_source("asm") matches = re.findall("vmlal.s16", assembly) assert len(matches) == N // 4 @@ -85,7 +85,7 @@ def check_broadcast_correct_assembly(N): f = tvm.tir.build(sch.mod, target=target) # Verify we see the correct number of vmlal.s16 instructions - assembly = f.get_source("asm") + assembly = f.inspect_source("asm") matches = re.findall("vmlal.s16", assembly) assert len(matches) == N // 4 diff --git a/tests/python/codegen/test_target_codegen_c_host.py b/tests/python/codegen/test_target_codegen_c_host.py index af94cae71f1c..3c80cfbeb0b4 100644 --- a/tests/python/codegen/test_target_codegen_c_host.py +++ b/tests/python/codegen/test_target_codegen_c_host.py @@ -192,7 +192,7 @@ def subroutine(A_data: T.handle("float32")): "subroutine" not in func_names ), "Internal function should not be listed in available functions." - source = built.get_source() + source = built.inspect_source() assert ( source.count("main(void*") == 2 ), "Expected two occurrences, for forward-declaration and definition" diff --git a/tests/python/codegen/test_target_codegen_cross_llvm.py b/tests/python/codegen/test_target_codegen_cross_llvm.py index c126e531090e..9ae516c7de30 100644 --- a/tests/python/codegen/test_target_codegen_cross_llvm.py +++ b/tests/python/codegen/test_target_codegen_cross_llvm.py @@ -51,7 +51,7 @@ def build_i386(): target = "llvm -mtriple=i386-pc-linux-gnu" f = tvm.tir.build(sch.mod, target=target) path = temp.relpath("myadd.o") - f.save(path) + f.write_to_file(path) verify_elf(path, 0x03) def build_arm(): @@ -62,10 +62,10 @@ def build_arm(): temp = utils.tempdir() f = tvm.tir.build(sch.mod, target=target) path = temp.relpath("myadd.o") - f.save(path) + f.write_to_file(path) verify_elf(path, 0x28) asm_path = temp.relpath("myadd.asm") - f.save(asm_path) + f.write_to_file(asm_path) # Do a RPC verification, launch kernel on Arm Board if available. host = os.environ.get("TVM_RPC_ARM_HOST", None) remote = None diff --git a/tests/python/codegen/test_target_codegen_cuda.py b/tests/python/codegen/test_target_codegen_cuda.py index a304cb1e41c7..fb9c47410fea 100644 --- a/tests/python/codegen/test_target_codegen_cuda.py +++ b/tests/python/codegen/test_target_codegen_cuda.py @@ -663,7 +663,7 @@ def build(A, C, N, C_N): f = tvm.tir.build(sch.mod, target="cuda") - kernel_source = f.imported_modules[0].get_source() + kernel_source = f.imports[0].inspect_source() dev = tvm.cuda() a_data = np.arange(0, N).astype(A.dtype) a = tvm.nd.array(a_data, dev) @@ -774,7 +774,7 @@ def main(A_ptr: T.handle): A[0] = ((float)(*(double *)(&(A_map)))); } }""".strip() - in mod.mod.imported_modules[0].get_source() + in mod.mod.imports[0].inspect_source() ) @@ -797,7 +797,7 @@ def main( C[bx, tx] = Module.add(A[bx, tx], B[bx, tx]) lib = tvm.compile(Module, target="cuda") - cuda_code = lib.mod.imported_modules[0].get_source() + cuda_code = lib.mod.imports[0].inspect_source() assert 'extern "C" __device__ float add(float a, float b) {\n return (a + b);\n}' in cuda_code @@ -827,7 +827,7 @@ def main( # in order to avoid checking a function is host or device based on the "cpu" substring. target = tvm.target.Target({"kind": "cuda", "mcpu": "dummy_mcpu"}, host="c") lib = tvm.compile(Module, target=target) - cuda_code = lib.mod.imported_modules[0].get_source() + cuda_code = lib.mod.imports[0].inspect_source() assert 'extern "C" __device__ int add(int a, int b) {\n return (a + b);\n}' in cuda_code # Run a simple test @@ -854,7 +854,7 @@ def main(A: T.Buffer((16, 16), "float32"), B: T.Buffer((16, 16), "float32")): B[bx, tx] = A[bx, tx] lib = tvm.compile(Module, target="cuda") - cuda_code = lib.mod.imported_modules[0].get_source() + cuda_code = lib.mod.imports[0].inspect_source() assert "return;" in cuda_code diff --git a/tests/python/codegen/test_target_codegen_cuda_fp8.py b/tests/python/codegen/test_target_codegen_cuda_fp8.py index aa9080a48882..c0b6130bcb80 100644 --- a/tests/python/codegen/test_target_codegen_cuda_fp8.py +++ b/tests/python/codegen/test_target_codegen_cuda_fp8.py @@ -71,7 +71,7 @@ def add( target = "cuda" fadd = tvm.tir.build(sch.mod, target=target) - cuda_src = fadd.imported_modules[0].get_source() + cuda_src = fadd.imports[0].inspect_source() assert nv_dtype in cuda_src, f"{nv_dtype} datatype not found in generated CUDA" dev = tvm.device(target, 0) @@ -190,7 +190,7 @@ def add( target = "cuda" fadd = tvm.tir.build(sch.mod, target=target) - cuda_src = fadd.imported_modules[0].get_source() + cuda_src = fadd.imports[0].inspect_source() dev = tvm.device(target, 0) if "x" in native_dtype: @@ -710,7 +710,7 @@ def print_cuda(target, mod, name=None): if name: mod = mod[name] f = tvm.tir.build(mod, target=target) - cuda_src = f.imported_modules[0].get_source() + cuda_src = f.imports[0].inspect_source() print(cuda_src) print_cuda(target, dequant_mod, name="dequant") diff --git a/tests/python/codegen/test_target_codegen_hexagon.py b/tests/python/codegen/test_target_codegen_hexagon.py index c0665ce316ad..f14005ad9d0b 100644 --- a/tests/python/codegen/test_target_codegen_hexagon.py +++ b/tests/python/codegen/test_target_codegen_hexagon.py @@ -46,7 +46,7 @@ def check_add(): C = tvm.te.compute((128,), lambda i: A[i] + B[i], name="C") mod = tvm.IRModule.from_expr(te.create_prim_func([C, A, B])) hexm = tvm.compile(mod, target=tvm.target.Target(target, target)) - asm = hexm.get_source("s") + asm = hexm.inspect_source("s") vadds = re.findall(r"v[0-9]+.b = vadd\(v[0-9]+.b,v[0-9]+.b\)", asm) assert vadds # Check that it's non-empty @@ -61,7 +61,7 @@ def test_llvm_target_features(): C = tvm.te.compute((128,), lambda i: A[i] + 1, name="C") mod = tvm.IRModule.from_expr(te.create_prim_func([C, A]).with_attr("global_symbol", "add_one")) m = tvm.compile(mod, target=tvm.target.Target(target, target)) - llvm_ir = m.get_source("ll") + llvm_ir = m.inspect_source("ll") # Make sure we find +hvx-length128b in "attributes". fs = re.findall(r"attributes.*\+hvx-length128b", llvm_ir) assert fs # Check that it's non-empty diff --git a/tests/python/codegen/test_target_codegen_llvm.py b/tests/python/codegen/test_target_codegen_llvm.py index 15c030aeacf2..953adf78b342 100644 --- a/tests/python/codegen/test_target_codegen_llvm.py +++ b/tests/python/codegen/test_target_codegen_llvm.py @@ -454,7 +454,7 @@ def test_alignment(): # Build with name f = tvm.tir.build(sch.mod, target="llvm") - lines = f.get_source().split("\n") + lines = f.inspect_source().split("\n") # Check alignment on load/store. for l in lines: @@ -702,7 +702,7 @@ def check_llvm_object(): m = tvm.compile(mod, target="llvm") temp = utils.tempdir() o_path = temp.relpath("temp.o") - m.save(o_path) + m.write_to_file(o_path) import shutil import subprocess import sys @@ -738,7 +738,7 @@ def check_llvm_ir(): } ) m = tvm.tir.build(mod, target="llvm -mtriple=aarch64-linux-gnu") - ll = m.get_source("ll") + ll = m.inspect_source("ll") # On non-Darwin OS, don't explicitly specify DWARF version. import re @@ -748,7 +748,7 @@ def check_llvm_ir(): # Try Darwin, require DWARF-2 m = tvm.tir.build(mod, target="llvm -mtriple=x86_64-apple-darwin-macho") - ll = m.get_source("ll") + ll = m.inspect_source("ll") assert re.search(r"""i32 4, !"Dwarf Version", i32 2""", ll) assert re.search(r"""llvm.dbg.value""", ll) @@ -802,9 +802,9 @@ def test_llvm_crt_static_lib(): mod.with_attr("system_lib_prefix", ""), target=tvm.target.Target("llvm"), ) - module.get_source() + module.inspect_source() with utils.tempdir() as temp: - module.save(temp.relpath("test.o")) + module.write_to_file(temp.relpath("test.o")) @tvm.testing.requires_llvm @@ -829,7 +829,7 @@ def make_call_extern(caller, callee): "Kirby": make_call_extern("Kirby", "Fred"), } mod = tvm.IRModule(functions=functions) - ir_text = tvm.tir.build(mod, target="llvm").get_source("ll") + ir_text = tvm.tir.build(mod, target="llvm").inspect_source("ll") # Skip functions whose names start with _. matches = re.findall(r"^define[^@]*@([a-zA-Z][a-zA-Z0-9_]*)", ir_text, re.MULTILINE) assert matches == sorted(matches) @@ -930,7 +930,7 @@ def test_llvm_target_attributes(): target = tvm.target.Target(target_llvm, host=target_llvm) module = tvm.tir.build(sch.mod, target=target) - llvm_ir = module.get_source() + llvm_ir = module.inspect_source() llvm_ir_lines = llvm_ir.split("\n") attribute_definitions = dict() diff --git a/tests/python/codegen/test_target_codegen_llvm_vla.py b/tests/python/codegen/test_target_codegen_llvm_vla.py index 7ca3083dd5e3..8930159481cb 100644 --- a/tests/python/codegen/test_target_codegen_llvm_vla.py +++ b/tests/python/codegen/test_target_codegen_llvm_vla.py @@ -46,7 +46,7 @@ def main(A: T.Buffer((5,), "int32")): with tvm.target.Target(target): build_mod = tvm.tir.build(main) - llvm = build_mod.get_source() + llvm = build_mod.inspect_source() assert re.findall(r"llvm.vscale.i32", llvm), "No vscale in generated LLVM." @@ -68,7 +68,7 @@ def my_func(a: T.handle, b: T.handle): with tvm.target.Target(target): mod = tvm.tir.build(my_func) - llvm = mod.get_source("ll") + llvm = mod.inspect_source("ll") assert re.findall(r"load ", llvm), "No scalable load in generated LLVM." assert re.findall(r" store ", llvm), "No scalable store in generated LLVM." @@ -90,7 +90,7 @@ def my_func(a: T.handle): with tvm.target.Target(target): mod = tvm.tir.build(my_func) - llvm = mod.get_source("ll") + llvm = mod.inspect_source("ll") assert re.findall( r"shufflevector \( insertelement \(", llvm ), "No scalable broadcast in generated LLVM." @@ -114,7 +114,7 @@ def before(a: T.handle): with tvm.target.Target(target): out = tvm.tir.build(before) - ll = out.get_source("ll") + ll = out.inspect_source("ll") assert "get.active.lane.mask" in ll @@ -139,7 +139,7 @@ def before(a: T.handle, b: T.handle): with tvm.target.Target(target): out = tvm.tir.build(before) - ll = out.get_source("ll") + ll = out.inspect_source("ll") assert "get.active.lane.mask" in ll assert "llvm.masked.load" in ll assert "llvm.masked.store" in ll diff --git a/tests/python/codegen/test_target_codegen_metal.py b/tests/python/codegen/test_target_codegen_metal.py index 2d669081e347..6b413d532371 100644 --- a/tests/python/codegen/test_target_codegen_metal.py +++ b/tests/python/codegen/test_target_codegen_metal.py @@ -187,7 +187,7 @@ def compile_metal(src, target): mod = tvm.IRModule({"main": func}) f = tvm.compile(mod, target="metal") - src: str = f.imported_modules[0].get_source() + src: str = f.imports[0].inspect_source() occurrences = src.count("struct func_kernel_args_t") assert occurrences == 1, occurrences diff --git a/tests/python/codegen/test_target_codegen_opencl.py b/tests/python/codegen/test_target_codegen_opencl.py index cbdb60477b06..4eb96747bcee 100644 --- a/tests/python/codegen/test_target_codegen_opencl.py +++ b/tests/python/codegen/test_target_codegen_opencl.py @@ -140,7 +140,7 @@ def check_erf(dev, n, dtype): sch.bind(x, "threadIdx.x") fun = tvm.tir.build(sch.mod, target=target) - source_str = fun.imported_modules[0].get_source() + source_str = fun.imports[0].inspect_source() matches = re.findall("erf", source_str) error_matches = re.findall("erff", source_str) assert len(matches) == 1 and len(error_matches) == 0 @@ -180,7 +180,7 @@ def check_type_casting(ctx, n, dtype): fun = tvm.tir.build(sch.mod, target=target) c = tvm.nd.empty((n,), dtype, ctx) - assembly = fun.imported_modules[0].get_source() + assembly = fun.imports[0].inspect_source() lcond = "convert_int4(((convert_uint4(((uint4)(((convert_int(get_local_id(0))) == 3), ((convert_int(get_local_id(0))) == 3), ((convert_int(get_local_id(0))) == 3), ((convert_int(get_local_id(0))) == 3)))))" rcond = "(convert_uint4(((((int4)(((convert_int(get_local_id(0))))+(1*0), ((convert_int(get_local_id(0))))+(1*1), ((convert_int(get_local_id(0))))+(1*2), ((convert_int(get_local_id(0))))+(1*3))) % ((int4)(3, 3, 3, 3))) == ((int4)(1, 1, 1, 1))))))))" pattern_cond = "({} && {})".format(lcond, rcond) @@ -211,7 +211,7 @@ def _check(target, n, dtype): sch.bind(x, "threadIdx.x") fun = tvm.tir.build(sch.mod, target=target) - assembly = fun.imported_modules[0].get_source() + assembly = fun.imports[0].inspect_source() if "adreno" in target: pattern = "convert_float" else: diff --git a/tests/python/codegen/test_target_codegen_riscv.py b/tests/python/codegen/test_target_codegen_riscv.py index b06aeb4ced06..1a30ab203f04 100644 --- a/tests/python/codegen/test_target_codegen_riscv.py +++ b/tests/python/codegen/test_target_codegen_riscv.py @@ -36,7 +36,7 @@ def load_vec(A: T.Buffer((N,), "int8")): f = tvm.tir.build(load_vec, target) # Check RVV `vsetvli` prensence - assembly = f.get_source("asm") + assembly = f.inspect_source("asm") if target_has_features("v"): assert "vsetvli" in assembly else: diff --git a/tests/python/codegen/test_target_codegen_vulkan.py b/tests/python/codegen/test_target_codegen_vulkan.py index 89acf598d6e3..a523ae037794 100644 --- a/tests/python/codegen/test_target_codegen_vulkan.py +++ b/tests/python/codegen/test_target_codegen_vulkan.py @@ -86,7 +86,7 @@ def test_vector_comparison(target, dev, dtype): # Verify we generate the boolx4 type declaration and the OpSelect # v4{float,half,int} instruction - assembly = f.imported_modules[0].get_source() + assembly = f.imports[0].inspect_source() matches = re.findall("%v4bool = OpTypeVector %bool 4", assembly) assert len(matches) == 1 matches = re.findall("OpSelect %v4.*", assembly) diff --git a/tests/python/codegen/test_target_codegen_x86.py b/tests/python/codegen/test_target_codegen_x86.py index 51d648f2c4a9..8664d5ceb732 100644 --- a/tests/python/codegen/test_target_codegen_x86.py +++ b/tests/python/codegen/test_target_codegen_x86.py @@ -41,7 +41,7 @@ def fp16_to_fp32(target, width, match=None, not_match=None): sch.vectorize(sch.get_loops("B")[1]) f = tvm.tir.build(sch.mod, target=target) - assembly = f.get_source("asm").splitlines() + assembly = f.inspect_source("asm").splitlines() if match: matches = [l for l in assembly if re.search(match, l)] assert matches diff --git a/tests/python/contrib/test_hexagon/test_benchmark_elemwise_add.py b/tests/python/contrib/test_hexagon/test_benchmark_elemwise_add.py index 1b75dd5bc915..d3adbc12c922 100644 --- a/tests/python/contrib/test_hexagon/test_benchmark_elemwise_add.py +++ b/tests/python/contrib/test_hexagon/test_benchmark_elemwise_add.py @@ -218,7 +218,7 @@ def _benchmark_hexagon_elementwise_add_kernel( # Create an actual Hexagon-native shared object file, initially stored on the # host's file system... host_dso_binary_path = os.path.join(host_files_dir_path, "test_binary.so") - built_module.save(host_dso_binary_path) + built_module.write_to_file(host_dso_binary_path) print(f"SAVED BINARY TO HOST PATH: {host_dso_binary_path}") # Upload the .so to the Android device's file system (or wherever is appropriate diff --git a/tests/python/contrib/test_hexagon/test_benchmark_maxpool2d.py b/tests/python/contrib/test_hexagon/test_benchmark_maxpool2d.py index b6b1f8fa73d6..7d556e8bae73 100644 --- a/tests/python/contrib/test_hexagon/test_benchmark_maxpool2d.py +++ b/tests/python/contrib/test_hexagon/test_benchmark_maxpool2d.py @@ -261,7 +261,7 @@ def test_maxpool2d_nhwc( # Save a local copy of the Hexagon object code (in the form of a .so file) # to allow post-mortem inspection. host_dso_binary_path = os.path.join(host_files_dir_path, "test_binary.so") - built_module.save(host_dso_binary_path) + built_module.write_to_file(host_dso_binary_path) print(f"SAVED BINARY TO HOST PATH: {host_dso_binary_path}") hexagon_mod = hexagon_session.load_module(built_module) diff --git a/tests/python/contrib/test_hexagon/test_sigmoid.py b/tests/python/contrib/test_hexagon/test_sigmoid.py index c6196ce42517..b873f606e619 100644 --- a/tests/python/contrib/test_hexagon/test_sigmoid.py +++ b/tests/python/contrib/test_hexagon/test_sigmoid.py @@ -94,9 +94,9 @@ def test_sigmoid( with tvm.transform.PassContext(opt_level=3): runtime_module = tvm.compile(tir_s.mod, target=get_hexagon_target("v69")) - assert "hvx_sigmoid" in runtime_module.get_source("asm") - assert "vmin" in runtime_module.get_source("asm") - assert "vmax" in runtime_module.get_source("asm") + assert "hvx_sigmoid" in runtime_module.inspect_source("asm") + assert "vmin" in runtime_module.inspect_source("asm") + assert "vmax" in runtime_module.inspect_source("asm") mod = hexagon_session.load_module(runtime_module) mod(input_data, output_data) diff --git a/tests/python/contrib/test_hexagon/test_vtcm.py b/tests/python/contrib/test_hexagon/test_vtcm.py index 2795f5630163..eec48a972ea2 100644 --- a/tests/python/contrib/test_hexagon/test_vtcm.py +++ b/tests/python/contrib/test_hexagon/test_vtcm.py @@ -50,7 +50,7 @@ def test_vtcm_building(): sch = get_scale_by_two_schedule() target = get_hexagon_target("v68") built = tvm.compile(sch.mod, target=target) - assert "global.vtcm" in built.get_source("asm") + assert "global.vtcm" in built.inspect_source("asm") @tvm.testing.requires_hexagon diff --git a/tests/python/ir/test_roundtrip_runtime_module.py b/tests/python/ir/test_roundtrip_runtime_module.py index 3723cc6c112c..e6fca273a025 100644 --- a/tests/python/ir/test_roundtrip_runtime_module.py +++ b/tests/python/ir/test_roundtrip_runtime_module.py @@ -25,11 +25,11 @@ def test_csource_module(): mod = tvm.runtime._ffi_api.CSourceModuleCreate("", "cc", [], []) - assert mod.type_key == "c" - assert mod.is_binary_serializable + assert mod.kind == "c" + assert mod.is_binary_serializable() new_mod = tvm.ir.load_json(tvm.ir.save_json(mod)) - assert new_mod.type_key == "c" - assert new_mod.is_binary_serializable + assert new_mod.kind == "c" + assert new_mod.is_binary_serializable() if __name__ == "__main__": diff --git a/tests/python/relax/backend/clml/test_clml_codegen.py b/tests/python/relax/backend/clml/test_clml_codegen.py index b03d6afa1c9b..29448774d69b 100644 --- a/tests/python/relax/backend/clml/test_clml_codegen.py +++ b/tests/python/relax/backend/clml/test_clml_codegen.py @@ -52,7 +52,7 @@ def compare_codegen(clml_mod, clml_codegen): - source = clml_mod.attrs["external_mods"][0].get_source() + source = clml_mod.attrs["external_mods"][0].inspect_source() codegen = json.loads(source)["nodes"] for node in range(len(codegen)): if codegen[node]["op"] == "input" or codegen[node]["op"] == "const": diff --git a/tests/python/relax/test_vm_instrument.py b/tests/python/relax/test_vm_instrument.py index c5f293114f3c..8c4d728da18b 100644 --- a/tests/python/relax/test_vm_instrument.py +++ b/tests/python/relax/test_vm_instrument.py @@ -93,7 +93,7 @@ def test_lib_comparator(): ex = get_exec_int32(data_np.shape) vm = relax.VirtualMachine(ex, tvm.cpu()) # compare against library module - cmp = LibCompareVMInstrument(vm.module.imported_modules[0], tvm.cpu(), verbose=False) + cmp = LibCompareVMInstrument(vm.module.imports[0], tvm.cpu(), verbose=False) vm.set_instrument(cmp) vm["main"](tvm.nd.array(data_np)) diff --git a/tests/python/runtime/test_runtime_module_export.py b/tests/python/runtime/test_runtime_module_export.py index 8897837a26af..0db1fa93dc2a 100644 --- a/tests/python/runtime/test_runtime_module_export.py +++ b/tests/python/runtime/test_runtime_module_export.py @@ -40,16 +40,16 @@ def test_import_static_library(): assert mod0.implements_function("myadd0") assert mod1.implements_function("myadd1") - assert mod1.is_dso_exportable + assert mod1.is_compilation_exportable() # mod1 is currently an 'llvm' module. # Save and reload it as a vanilla 'static_library'. temp = utils.tempdir() mod1_o_path = temp.relpath("mod1.o") - mod1.save(mod1_o_path) + mod1.write_to_file(mod1_o_path) mod1_o = tvm.runtime.load_static_library(mod1_o_path, ["myadd1"]) assert mod1_o.implements_function("myadd1") - assert mod1_o.is_dso_exportable + assert mod1_o.is_compilation_exportable() # Import mod1 as a static library into mod0 and compile to its own DSO. mod0.import_module(mod1_o) @@ -58,13 +58,13 @@ def test_import_static_library(): # The imported mod1 is statically linked into mod0. loaded_lib = tvm.runtime.load_module(mod0_dso_path) - assert loaded_lib.type_key == "library" - assert len(loaded_lib.imported_modules) == 0 + assert loaded_lib.kind == "library" + assert len(loaded_lib.imports) == 0 assert loaded_lib.implements_function("myadd0") assert loaded_lib.get_function("myadd0") assert loaded_lib.implements_function("myadd1") assert loaded_lib.get_function("myadd1") - assert not loaded_lib.is_dso_exportable + assert not loaded_lib.is_compilation_exportable() if __name__ == "__main__": diff --git a/tests/python/runtime/test_runtime_module_load.py b/tests/python/runtime/test_runtime_module_load.py index 79b95256f9fa..d22d40f6f2b1 100644 --- a/tests/python/runtime/test_runtime_module_load.py +++ b/tests/python/runtime/test_runtime_module_load.py @@ -64,7 +64,7 @@ def save_object(names): ) m = tvm.tir.build(mod, target=target) for name in names: - m.save(name) + m.write_to_file(name) path_obj = temp.relpath("test.o") path_ll = temp.relpath("test.ll") @@ -169,8 +169,8 @@ def check_llvm(): path1 = temp.relpath("myadd1.o") path2 = temp.relpath("myadd2.o") path_dso = temp.relpath("mylib.so") - fadd1.save(path1) - fadd2.save(path2) + fadd1.write_to_file(path1) + fadd2.write_to_file(path2) # create shared library with multiple functions cc.create_shared(path_dso, [path1, path2]) m = tvm.runtime.load_module(path_dso) @@ -195,8 +195,8 @@ def check_system_lib(): path1 = temp.relpath("myadd1.o") path2 = temp.relpath("myadd2.o") path_dso = temp.relpath("mylib.so") - fadd1.save(path1) - fadd2.save(path2) + fadd1.write_to_file(path1) + fadd2.write_to_file(path2) cc.create_shared(path_dso, [path1, path2]) def popen_check(): diff --git a/tests/python/runtime/test_runtime_module_property.py b/tests/python/runtime/test_runtime_module_property.py index 83e535e1ac83..a071f9774323 100644 --- a/tests/python/runtime/test_runtime_module_property.py +++ b/tests/python/runtime/test_runtime_module_property.py @@ -21,9 +21,9 @@ def checker(mod, expected): - assert mod.is_binary_serializable == expected["is_binary_serializable"] - assert mod.is_runnable == expected["is_runnable"] - assert mod.is_dso_exportable == expected["is_dso_exportable"] + assert mod.is_binary_serializable() == expected["is_binary_serializable()"] + assert mod.is_runnable() == expected["is_runnable"] + assert mod.is_compilation_exportable() == expected["is_compilation_exportable()"] def create_csource_module(): @@ -39,12 +39,20 @@ def create_llvm_module(): def test_property(): checker( create_csource_module(), - expected={"is_binary_serializable": True, "is_runnable": False, "is_dso_exportable": True}, + expected={ + "is_binary_serializable()": True, + "is_runnable": False, + "is_compilation_exportable()": True, + }, ) checker( create_llvm_module(), - expected={"is_binary_serializable": False, "is_runnable": True, "is_dso_exportable": True}, + expected={ + "is_binary_serializable()": False, + "is_runnable": True, + "is_compilation_exportable()": True, + }, ) diff --git a/tests/python/runtime/test_runtime_rpc.py b/tests/python/runtime/test_runtime_rpc.py index e696cbcf086c..ac8653012ace 100644 --- a/tests/python/runtime/test_runtime_rpc.py +++ b/tests/python/runtime/test_runtime_rpc.py @@ -80,7 +80,7 @@ def verify_rpc(remote, target, shape, dtype): b = tvm.nd.array(np.zeros(shape).astype(A.dtype), device=dev) temp = utils.tempdir() path_dso = temp.relpath("dev_lib.o") - f.save(path_dso) + f.write_to_file(path_dso) remote.upload(path_dso) f = remote.load_module("dev_lib.o") f(a, b) diff --git a/tests/python/tir-transform/test_tir_transform_inject_ptx_async_copy.py b/tests/python/tir-transform/test_tir_transform_inject_ptx_async_copy.py index 1858c00e8662..d4c93bb24ae9 100644 --- a/tests/python/tir-transform/test_tir_transform_inject_ptx_async_copy.py +++ b/tests/python/tir-transform/test_tir_transform_inject_ptx_async_copy.py @@ -393,7 +393,7 @@ def postproc_if_missing_async_support(): # way, even though the generated code doesn't compile on platforms # that do not support async, the comparison against an expected # output can still be performed. We cannot use - # `mod.get_source()`, as that contains the source after all + # `mod.inspect_source()`, as that contains the source after all # post-processing. original_code = None diff --git a/web/emcc/tvmjs_support.cc b/web/emcc/tvmjs_support.cc index 4cf45075edcb..b33724c722d7 100644 --- a/web/emcc/tvmjs_support.cc +++ b/web/emcc/tvmjs_support.cc @@ -169,9 +169,9 @@ class AsyncLocalSession : public LocalSession { // special handle time evaluator. try { ffi::Function retfunc = this->GetTimeEvaluator( - args[0].cast>(), args[1].cast(), args[2].cast(), - args[3].cast(), args[4].cast(), args[5].cast(), args[6].cast(), - args[7].cast(), args[8].cast(), args[9].cast()); + args[0].cast>(), args[1].cast(), + args[2].cast(), args[3].cast(), args[4].cast(), args[5].cast(), + args[6].cast(), args[7].cast(), args[8].cast(), args[9].cast()); ffi::Any rv; rv = retfunc; this->EncodeReturn(std::move(rv), [&](ffi::PackedArgs encoded_args) { @@ -252,7 +252,7 @@ class AsyncLocalSession : public LocalSession { std::optional async_wait_; // time evaluator - ffi::Function GetTimeEvaluator(Optional opt_mod, std::string name, int device_type, + ffi::Function GetTimeEvaluator(Optional opt_mod, std::string name, int device_type, int device_id, int number, int repeat, int min_repeat_ms, int limit_zero_time_iterations, int cooldown_interval_ms, int repeats_to_cooldown) { @@ -261,10 +261,10 @@ class AsyncLocalSession : public LocalSession { dev.device_id = device_id; if (opt_mod.defined()) { - Module m = opt_mod.value(); - std::string tkey = m->type_key(); - return WrapWasmTimeEvaluator(m.GetFunction(name, false), dev, number, repeat, min_repeat_ms, - limit_zero_time_iterations, cooldown_interval_ms, + ffi::Module m = opt_mod.value(); + std::string tkey = m->kind(); + return WrapWasmTimeEvaluator(m->GetFunction(name, false).value(), dev, number, repeat, + min_repeat_ms, limit_zero_time_iterations, cooldown_interval_ms, repeats_to_cooldown); } else { auto pf = tvm::ffi::Function::GetGlobal(name); diff --git a/web/emcc/wasm_runtime.cc b/web/emcc/wasm_runtime.cc index ae8bea5524f6..6e2664a93bff 100644 --- a/web/emcc/wasm_runtime.cc +++ b/web/emcc/wasm_runtime.cc @@ -37,9 +37,7 @@ #include "src/runtime/cpu_device_api.cc" #include "src/runtime/device_api.cc" #include "src/runtime/file_utils.cc" -#include "src/runtime/library_module.cc" #include "src/runtime/logging.cc" -#include "src/runtime/module.cc" #include "src/runtime/ndarray.cc" #include "src/runtime/profiling.cc" #include "src/runtime/rpc/rpc_channel.cc" @@ -48,12 +46,14 @@ #include "src/runtime/rpc/rpc_local_session.cc" #include "src/runtime/rpc/rpc_module.cc" #include "src/runtime/rpc/rpc_session.cc" -#include "src/runtime/system_library.cc" #include "src/runtime/workspace_pool.cc" // relax setup #include "ffi/src/ffi/container.cc" #include "ffi/src/ffi/dtype.cc" #include "ffi/src/ffi/error.cc" +#include "ffi/src/ffi/extra/library_module.cc" +#include "ffi/src/ffi/extra/library_module_system_lib.cc" +#include "ffi/src/ffi/extra/module.cc" #include "ffi/src/ffi/function.cc" #include "ffi/src/ffi/ndarray.cc" #include "ffi/src/ffi/object.cc" diff --git a/web/emcc/webgpu_runtime.cc b/web/emcc/webgpu_runtime.cc index 7e9f7c0f45ab..cd50bc067983 100644 --- a/web/emcc/webgpu_runtime.cc +++ b/web/emcc/webgpu_runtime.cc @@ -29,6 +29,7 @@ #define TVM_FFI_ALWAYS_LOG_BEFORE_THROW 1 #define DMLC_USE_LOGGING_LIBRARY +#include #include #include #include @@ -156,7 +157,7 @@ WebGPUThreadEntry* WebGPUThreadEntry::ThreadLocal() { return &inst; } -class WebGPUModuleNode final : public runtime::ModuleNode { +class WebGPUModuleNode final : public ffi::ModuleObj { public: explicit WebGPUModuleNode(std::unordered_map smap, std::unordered_map fmap) @@ -166,9 +167,9 @@ class WebGPUModuleNode final : public runtime::ModuleNode { create_shader_ = *fp; } - const char* type_key() const final { return "webgpu"; } + const char* kind() const final { return "webgpu"; } - ffi::Function GetFunction(const String& name, const ObjectPtr& sptr_to_self) final { + Optional GetFunction(const String& name) final { // special function if (name == "webgpu.get_fmap") { return ffi::Function([this](ffi::PackedArgs args, ffi::Any* rv) { @@ -206,15 +207,15 @@ class WebGPUModuleNode final : public runtime::ModuleNode { info.Save(&writer); return create_shader_(os.str(), it->second); } else { - return ffi::Function(nullptr); + return std::nullopt; } } - int GetPropertyMask() const final { return ModulePropertyMask::kBinarySerializable; }; + int GetPropertyMask() const final { return ffi::Module::kBinarySerializable; }; - void SaveToBinary(dmlc::Stream* stream) final { LOG(FATAL) << "Not implemented"; } + ffi::Bytes SaveToBytes() const final { LOG(FATAL) << "Not implemented"; } - String GetSource(const String& format) final { + String InspectSource(const String& format) const final { // can only return source code. return source_; } @@ -232,21 +233,22 @@ class WebGPUModuleNode final : public runtime::ModuleNode { ffi::TypedFunction create_shader_; }; -Module WebGPUModuleLoadBinary(void* strm) { - dmlc::Stream* stream = static_cast(strm); +ffi::Module WebGPUModuleLoadFromBytes(const ffi::Bytes& bytes) { + dmlc::MemoryFixedSizeStream ms(const_cast(bytes.data()), bytes.size()); + dmlc::Stream* stream = &ms; std::unordered_map smap; std::unordered_map fmap; stream->Read(&fmap); stream->Read(&smap); - return Module(make_object(smap, fmap)); + return ffi::Module(make_object(smap, fmap)); } // for now webgpu is hosted via a vulkan module. TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef() - .def("runtime.module.loadbinary_webgpu", WebGPUModuleLoadBinary) + .def("ffi.Module.load_from_bytes.webgpu", WebGPUModuleLoadFromBytes) .def_packed("device_api.webgpu", [](ffi::PackedArgs args, ffi::Any* rv) { DeviceAPI* ptr = WebGPUDeviceAPI::Global(); *rv = static_cast(ptr); diff --git a/web/src/runtime.ts b/web/src/runtime.ts index 75f4de855581..071b2eed68e4 100644 --- a/web/src/runtime.ts +++ b/web/src/runtime.ts @@ -189,8 +189,8 @@ class RuntimeContext implements Disposable { this.functionListGlobalNamesFunctor = getGlobalFunc( "ffi.FunctionListGlobalNamesFunctor" ); - this.moduleGetFunction = getGlobalFunc("runtime.ModuleGetFunction"); - this.moduleImport = getGlobalFunc("runtime.ModuleImport"); + this.moduleGetFunction = getGlobalFunc("ffi.ModuleGetFunction"); + this.moduleImport = getGlobalFunc("ffi.ModuleImportModule"); this.ndarrayEmpty = getGlobalFunc("runtime.TVMArrayAllocWithScope"); this.ndarrayCopyFromTo = getGlobalFunc("runtime.TVMArrayCopyFromTo"); this.ndarrayCopyFromJSBytes = getGlobalFunc("tvmjs.runtime.NDArrayCopyFromBytes"); @@ -199,7 +199,7 @@ class RuntimeContext implements Disposable { this.arrayGetSize = getGlobalFunc("ffi.ArraySize"); this.arrayMake = getGlobalFunc("ffi.Array"); this.arrayConcat = getGlobalFunc("tvmjs.runtime.ArrayConcat"); - this.getSysLib = getGlobalFunc("runtime.SystemLib"); + this.getSysLib = getGlobalFunc("ffi.SystemLib"); this.arrayCacheGet = getGlobalFunc("vm.builtin.ndarray_cache.get"); this.arrayCacheRemove = getGlobalFunc("vm.builtin.ndarray_cache.remove"); this.arrayCacheUpdate = getGlobalFunc("vm.builtin.ndarray_cache.update"); @@ -1900,7 +1900,7 @@ export class Instance implements Disposable { (handle: number, lib: FFILibrary, ctx: RuntimeContext) => { return new TVMArray(handle, lib, ctx); }); - this.registerObjectConstructor("runtime.Module", + this.registerObjectConstructor("ffi.Module", (handle: number, lib: FFILibrary, ctx: RuntimeContext) => { return new Module(handle, lib, ctx); });