diff --git a/docs/how_to/tutorials/optimize_llm.py b/docs/how_to/tutorials/optimize_llm.py index 49855910fc45..8cc674920da1 100644 --- a/docs/how_to/tutorials/optimize_llm.py +++ b/docs/how_to/tutorials/optimize_llm.py @@ -426,7 +426,7 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I with target: - ex = relax.build(mod, target, pipeline=relax.get_pipeline("opt_llm")) + ex = tvm.compile(mod, target, relax_pipeline=relax.get_pipeline("opt_llm")) vm = relax.VirtualMachine(ex, dev) diff --git a/include/tvm/relax/exec_builder.h b/include/tvm/relax/exec_builder.h index 03e58392c269..8940408f8048 100644 --- a/include/tvm/relax/exec_builder.h +++ b/include/tvm/relax/exec_builder.h @@ -125,12 +125,12 @@ class ExecBuilderNode : public Object { /*! * \brief Raw access to underlying executable build in progress. */ - vm::Executable* exec() const; + vm::VMExecutable* exec() const; /*! * \brief Finalize the build, run formalize and get the final result. * \note This function should not be called during construction. */ - ObjectPtr Get(); + ObjectPtr Get(); /*! * \brief Create an ExecBuilder. * \return The ExecBuilder. @@ -165,7 +165,7 @@ class ExecBuilderNode : public Object { void Formalize(); /*! \brief The mutable internal executable. */ - ObjectPtr exec_; // mutable + ObjectPtr exec_; // mutable /*! \brief internal dedup map when creating index for a new constant */ std::unordered_map const_dedup_map_; }; diff --git a/include/tvm/runtime/relax_vm/executable.h b/include/tvm/runtime/relax_vm/executable.h index 845842f22ac4..43d69c7c5c28 100644 --- a/include/tvm/runtime/relax_vm/executable.h +++ b/include/tvm/runtime/relax_vm/executable.h @@ -81,12 +81,12 @@ struct VMFuncInfo { }; /*! - * \brief The executable emitted by the VM compiler. + * \brief The virtual machine executable emitted by the VM compiler. * * The executable contains information (e.g. data in different memory regions) * to run in a virtual machine. */ -class Executable : public runtime::ModuleNode { +class VMExecutable : public runtime::ModuleNode { public: /*! \brief Get the property of the runtime module .*/ int GetPropertyMask() const final { return ModulePropertyMask::kBinarySerializable; }; @@ -121,18 +121,18 @@ class Executable : public runtime::ModuleNode { */ String AsPython() const; /*! - * \brief Write the Executable to the binary stream in serialized form. + * \brief Write the VMExecutable to the binary stream in serialized form. * \param stream The binary stream to save the executable to. */ void SaveToBinary(dmlc::Stream* stream) final; /*! - * \brief Load Executable from the binary stream in serialized form. + * \brief Load VMExecutable from the binary stream in serialized form. * \param stream The binary stream that load the executable from. * \return The loaded executable, in the form of a `runtime::Module`. */ static Module LoadFromBinary(void* stream); /*! - * \brief Write the Executable to the provided path as a file containing its serialized content. + * \brief Write the VMExecutable to the provided path as a file containing its serialized content. * \param file_name The name of the file to write the serialized data to. * \param format The target format of the saved file. */ @@ -141,10 +141,10 @@ class Executable : public runtime::ModuleNode { Module VMLoadExecutable() const; /*! \brief Create a Relax virtual machine with profiler and load `this` as the executable. */ Module VMProfilerLoadExecutable() const; - /*! \brief Check if the Executable contains a specific function. 
*/ + /*! \brief Check if the VMExecutable contains a specific function. */ bool HasFunction(const String& name) const; /*! - * \brief Load Executable from the file. + * \brief Load VMExecutable from the file. * \param file_name The path of the file that load the executable from. * \return The loaded executable, in the form of a `runtime::Module`. */ @@ -161,15 +161,15 @@ class Executable : public runtime::ModuleNode { /*! \brief The byte data of instruction. */ std::vector instr_data; - virtual ~Executable() {} + virtual ~VMExecutable() {} - TVM_MODULE_VTABLE_BEGIN("relax.Executable"); - TVM_MODULE_VTABLE_ENTRY("stats", &Executable::Stats); - TVM_MODULE_VTABLE_ENTRY("as_text", &Executable::AsText); - TVM_MODULE_VTABLE_ENTRY("as_python", &Executable::AsPython); - TVM_MODULE_VTABLE_ENTRY("vm_load_executable", &Executable::VMLoadExecutable); - TVM_MODULE_VTABLE_ENTRY("vm_profiler_load_executable", &Executable::VMProfilerLoadExecutable); - TVM_MODULE_VTABLE_ENTRY("has_function", &Executable::HasFunction); + TVM_MODULE_VTABLE_BEGIN("relax.VMExecutable"); + TVM_MODULE_VTABLE_ENTRY("stats", &VMExecutable::Stats); + TVM_MODULE_VTABLE_ENTRY("as_text", &VMExecutable::AsText); + TVM_MODULE_VTABLE_ENTRY("as_python", &VMExecutable::AsPython); + TVM_MODULE_VTABLE_ENTRY("vm_load_executable", &VMExecutable::VMLoadExecutable); + TVM_MODULE_VTABLE_ENTRY("vm_profiler_load_executable", &VMExecutable::VMProfilerLoadExecutable); + TVM_MODULE_VTABLE_ENTRY("has_function", &VMExecutable::HasFunction); TVM_MODULE_VTABLE_END(); private: diff --git a/include/tvm/runtime/relax_vm/vm.h b/include/tvm/runtime/relax_vm/vm.h index da833d5d6c5f..6c10aaa88107 100644 --- a/include/tvm/runtime/relax_vm/vm.h +++ b/include/tvm/runtime/relax_vm/vm.h @@ -143,7 +143,7 @@ class VirtualMachine : public runtime::ModuleNode { * \brief Load the executable for the virtual machine. * \param exec The executable. */ - virtual void LoadExecutable(ObjectPtr exec) = 0; + virtual void LoadExecutable(ObjectPtr exec) = 0; /*! * \brief Get global function in the VM. * \param func_name The name of the function. diff --git a/python/tvm/__init__.py b/python/tvm/__init__.py index f4519f834d74..b853c4fa616e 100644 --- a/python/tvm/__init__.py +++ b/python/tvm/__init__.py @@ -55,7 +55,7 @@ from . import te # tvm.driver -from .driver import build +from .driver import build, compile # others from . import arith diff --git a/python/tvm/contrib/hexagon/session.py b/python/tvm/contrib/hexagon/session.py index 2456aa244ee9..4f017abcf6ba 100644 --- a/python/tvm/contrib/hexagon/session.py +++ b/python/tvm/contrib/hexagon/session.py @@ -24,11 +24,12 @@ from typing import Union import tvm -from tvm import relax +import tvm.contrib.hexagon as hexagon from tvm import rpc as _rpc +from tvm import runtime from tvm.contrib import utils -import tvm.contrib.hexagon as hexagon -from .tools import export_module, HEXAGON_SIMULATOR_NAME + +from .tools import HEXAGON_SIMULATOR_NAME, export_module class Session: @@ -202,26 +203,26 @@ def load_module(self, module: Union[str, pathlib.Path, tvm.runtime.Module]): return self._rpc.get_function("tvm.hexagon.load_module")(str(remote_file_path)) def get_executor_from_factory( - self, module: Union[ExecutorFactoryModule, relax.Executable, str], hexagon_arch: str = "v68" + self, module: Union[runtime.executable, str], hexagon_arch: str = "v68" ): """Create a local GraphModule which consumes a remote libmod. 
Parameters ---------- - module : Union[relax.Executable] + module : Union[runtime.Executable, str] The module to upload to the remote session and load. hexagon_arch : str The hexagon arch to be used """ - if isinstance(module, (relax.Executable, str)): + if isinstance(module, (runtime.Executable, str)): return self._relax_vm_executable_executor(module, hexagon_arch=hexagon_arch) raise TypeError(f"Unsupported executor type: {type(module)}") - def _set_device_type(self, module: Union[str, pathlib.Path, GraphExecutorFactoryModule]): + def _set_device_type(self, module: Union[str, pathlib.Path]): """Set session device type(hexagon, cpu) based on target in module. Parameters @@ -244,18 +245,19 @@ def _set_device_type(self, module: Union[str, pathlib.Path, GraphExecutorFactory self._requires_cpu_device = False def _relax_vm_executable_executor( - self, vm_exec: Union[relax.Executable, str], hexagon_arch: str + self, executable: Union[runtime.Executable, str], hexagon_arch: str ): """Create a local TVM module which consumes a remote vm executable. - Paramters - --------- + Parameters + ---------- - vm_exec : relax.Executable - The Relax VM Executable to upload to the remote and load. This will typically be the - output of `relax.build` or the path to an already built and exported shared library + executable : runtime.Executable + The Executable to upload to the remote and load. This will typically be the + output of `tvm.compile` or the path to an already built and exported shared library hexagon_arch : str The hexagon arch to be used + Returns ------- TVMModule : @@ -263,21 +265,21 @@ def _relax_vm_executable_executor( """ assert self._rpc is not None, "Hexagon session must be started using __enter__ prior to use" - if isinstance(vm_exec, relax.Executable): + if isinstance(executable, runtime.Executable): temp_dir = utils.tempdir() path_exec = temp_dir.relpath("exec.so") - vm_exec.mod.export_library( + executable.export_library( path_exec, fcompile=hexagon.create_aot_shared, hexagon_arch=hexagon_arch, ) path = self.upload(path_exec, "exec.so") - elif isinstance(vm_exec, str): - path_exec = vm_exec + elif isinstance(executable, str): + path_exec = executable else: - raise TypeError(f"Unsupported executor type: {type(vm_exec)}") + raise TypeError(f"Unsupported executor type: {type(executable)}") path = self.upload(path_exec, "exec.so") return self._rpc.get_function("tvm.hexagon.load_module")(str(path)) diff --git a/python/tvm/driver/__init__.py b/python/tvm/driver/__init__.py index b97375c3a364..6a4a2ba9f95f 100644 --- a/python/tvm/driver/__init__.py +++ b/python/tvm/driver/__init__.py @@ -14,5 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
+# pylint: disable=redefined-builtin + """Namespace for driver APIs""" -from .build_module import build +from .build_module import build, compile diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index 8d6a2a534389..ea923aae9afc 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -17,16 +17,95 @@ # pylint: disable=invalid-name """The build utils in python.""" -from typing import Union, Optional +import warnings +from typing import Callable, Optional, Union + import tvm -from tvm.tir import PrimFunc from tvm.ir.module import IRModule +from tvm.runtime import Executable from tvm.target import Target +from tvm.tir import PrimFunc def build( mod: Union[PrimFunc, IRModule], target: Optional[Union[str, Target]] = None, - pipeline: Optional[Union[str, tvm.transform.Pass]] = "default_tir", + pipeline: Optional[Union[str, tvm.transform.Pass]] = "default", ): + """ + Build a function with a signature, generating code for devices + coupled with target information. + + This function is deprecated. Use `tvm.compile` or `tvm.tir.build` instead. + + Parameters + ---------- + mod : Union[PrimFunc, IRModule] + The input to be built. + target : Optional[Union[str, Target]] + The target for compilation. + pipeline : Optional[Union[str, tvm.transform.Pass]] + The pipeline to use for compilation. + + Returns + ------- + tvm.runtime.Module + A module combining both host and device code. + """ + warnings.warn( + "build is deprecated. Use `tvm.compile` or `tvm.tir.build` instead.", + DeprecationWarning, + ) return tvm.tir.build(mod, target, pipeline) + + +def _contains_relax(mod: Union[PrimFunc, IRModule]) -> bool: + if isinstance(mod, PrimFunc): + return False + if isinstance(mod, IRModule): + return any(isinstance(func, tvm.relax.Function) for _, func in mod.functions_items()) + + raise ValueError(f"Function input must be a PrimFunc or IRModule, but got {type(mod)}") + + +def compile( # pylint: disable=redefined-builtin + mod: Union[PrimFunc, IRModule], + target: Optional[Target] = None, + *, + relax_pipeline: Optional[Union[tvm.transform.Pass, Callable, str]] = "default", + tir_pipeline: Optional[Union[tvm.transform.Pass, Callable, str]] = "default", +) -> Executable: + """ + Compile an IRModule to a runtime executable. + + This function serves as a unified entry point for compiling both TIR and Relax modules. + It automatically detects the module type and routes to the appropriate build function. + + Parameters + ---------- + mod : Union[PrimFunc, IRModule] + The input module to be compiled. Can be a PrimFunc or an IRModule containing + TIR or Relax functions. + target : Optional[Target] + The target platform to compile for. + relax_pipeline : Optional[Union[tvm.transform.Pass, Callable, str]] + The compilation pipeline to use for Relax functions. + Only used if the module contains Relax functions. + tir_pipeline : Optional[Union[tvm.transform.Pass, Callable, str]] + The compilation pipeline to use for TIR functions. + + Returns + ------- + Executable + A runtime executable that can be loaded and executed. 
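A minimal usage sketch of the unified entry point on the TIR path, condensed from the new tests/python/driver/test_compile.py added later in this patch (the element count of 10 and the implicit llvm target are illustrative):

    import numpy as np
    import tvm
    from tvm import te

    n = te.var("n")
    A = te.placeholder((n,), name="A")
    B = te.placeholder((n,), name="B")
    C = te.compute(A.shape, lambda i: A[i] + B[i], name="C")
    ex = tvm.compile(te.create_prim_func([A, B, C]))  # returns a tvm.runtime.Executable

    dev = tvm.cpu(0)
    a = tvm.nd.array(np.random.uniform(size=10).astype("float32"), dev)
    b = tvm.nd.array(np.random.uniform(size=10).astype("float32"), dev)
    c = tvm.nd.array(np.zeros(10, dtype="float32"), dev)
    ex(a, b, c)  # the Executable is directly callable for the TIR entry point
    np.testing.assert_allclose(c.numpy(), a.numpy() + b.numpy())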
+ """ + # TODO(tvm-team): combine two path into unified one + if _contains_relax(mod): + return tvm.relax.build( + mod, + target, + relax_pipeline=relax_pipeline, + tir_pipeline=tir_pipeline, + ) + lib = tvm.tir.build(mod, target, pipeline=tir_pipeline) + return Executable(lib) diff --git a/python/tvm/meta_schedule/relax_integration.py b/python/tvm/meta_schedule/relax_integration.py index c3c24aa631d6..9c293a1654ab 100644 --- a/python/tvm/meta_schedule/relax_integration.py +++ b/python/tvm/meta_schedule/relax_integration.py @@ -382,7 +382,7 @@ def compile_relax( target: Union[Target, str], params: Optional[Dict[str, NDArray]], enable_warning: bool = False, -) -> "relax.Executable": +) -> "relax.VMExecutable": """Compile a relax program with a MetaSchedule database. Parameters @@ -401,8 +401,8 @@ def compile_relax( Returns ------- - lib : relax.Executable - The built runtime module or vm Executable for the given relax workload. + lib : relax.VMExecutable + The built runtime module or vm VMExecutable for the given relax workload. """ # pylint: disable=import-outside-toplevel from tvm.relax.transform import BindParams, MetaScheduleApplyDatabase diff --git a/python/tvm/meta_schedule/testing/custom_builder_runner.py b/python/tvm/meta_schedule/testing/custom_builder_runner.py index 7e7a3a1d9d9d..2da672b40561 100644 --- a/python/tvm/meta_schedule/testing/custom_builder_runner.py +++ b/python/tvm/meta_schedule/testing/custom_builder_runner.py @@ -17,21 +17,18 @@ """Customized builder and runner methods""" # pylint: disable=import-outside-toplevel -from typing import TYPE_CHECKING, Dict, Union, Callable +from typing import Dict, Union, Callable -if TYPE_CHECKING: - import numpy as np # type: ignore - from tvm.ir import IRModule - from tvm.meta_schedule.runner import EvaluatorConfig, RPCConfig - from tvm.runtime import Device, Module, NDArray - from tvm.target import Target +import numpy as np # type: ignore +from tvm.meta_schedule.runner import RPCConfig +from tvm.runtime import Module, Executable def run_module_via_rpc( - rpc_config: "RPCConfig", - lib: Union["Module", "Executable"], + rpc_config: RPCConfig, + lib: Union[Module, Executable], dev_type: str, - args: Union[Dict[int, "np.ndarray"], Dict[str, "np.ndarray"]], + args: Union[Dict[int, np.ndarray], Dict[str, np.ndarray]], continuation: Callable, ): """Execute a tvm.runtime.Module on RPC remote""" diff --git a/python/tvm/meta_schedule/testing/tune_utils.py b/python/tvm/meta_schedule/testing/tune_utils.py index cb97b221b281..08618a289d52 100644 --- a/python/tvm/meta_schedule/testing/tune_utils.py +++ b/python/tvm/meta_schedule/testing/tune_utils.py @@ -87,8 +87,8 @@ def f_calculator( Parameters ---------- - rt_mod : Union[tvm.runtime.Module, tvm.runtime.vm.Executable] - The runtime module or vm executable. + rt_mod : tvm.runtime.Module + The runtime module. dev : tvm.device The device type to run workload. input_data : Dict[str, np.ndarray] diff --git a/python/tvm/relax/__init__.py b/python/tvm/relax/__init__.py index 8494bd8e5838..da288942edf0 100644 --- a/python/tvm/relax/__init__.py +++ b/python/tvm/relax/__init__.py @@ -114,6 +114,6 @@ from . 
import utils # VM -from .vm_build import build, Executable +from .vm_build import build, VMExecutable from .binding_rewrite import DataflowBlockRewrite diff --git a/python/tvm/relax/exec_builder.py b/python/tvm/relax/exec_builder.py index 140c497eb967..699860786072 100644 --- a/python/tvm/relax/exec_builder.py +++ b/python/tvm/relax/exec_builder.py @@ -21,7 +21,7 @@ import tvm from tvm.runtime import Object from tvm.runtime.container import ShapeTuple -from .vm_build import Executable +from .vm_build import VMExecutable from . import _ffi_api @@ -142,6 +142,6 @@ def emit_if(self, cond, false_offset): self._check_scope() _ffi_api.ExecBuilderEmitIf(self, cond, false_offset) # type: ignore - def get(self) -> Executable: + def get(self) -> VMExecutable: """return the executable""" - return Executable(_ffi_api.ExecBuilderGet(self)) # type: ignore + return VMExecutable(_ffi_api.ExecBuilderGet(self)) # type: ignore diff --git a/python/tvm/relax/frontend/nn/core.py b/python/tvm/relax/frontend/nn/core.py index 21118b1cb8af..c25b6838447f 100644 --- a/python/tvm/relax/frontend/nn/core.py +++ b/python/tvm/relax/frontend/nn/core.py @@ -522,7 +522,7 @@ def _compile(spec, device, pipeline, debug): relax_build( mod, target=Target.from_device(device), - pipeline=pipeline, + relax_pipeline=pipeline, ), device, ) diff --git a/python/tvm/relax/pipeline.py b/python/tvm/relax/pipeline.py index ffb38cdd9370..ddf88e8bd08f 100644 --- a/python/tvm/relax/pipeline.py +++ b/python/tvm/relax/pipeline.py @@ -194,6 +194,7 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I # global map of pre-built pipelines PIPELINE_MAP = { "zero": zero_pipeline, + "default": default_build_pipeline, "default_build": default_build_pipeline, "static_shape_tuning": static_shape_tuning_pipeline, } diff --git a/python/tvm/relax/transform/tuning_api/default_functions.py b/python/tvm/relax/transform/tuning_api/default_functions.py index 7cdb211bd32f..efdc9b13e3fe 100644 --- a/python/tvm/relax/transform/tuning_api/default_functions.py +++ b/python/tvm/relax/transform/tuning_api/default_functions.py @@ -185,7 +185,7 @@ def relax_build( if runner is None: def relax_eval_func(rt_mod, device, evaluator_config, repeated_args): - relax_exec = tvm.relax.Executable(rt_mod) + relax_exec = tvm.relax.VMExecutable(rt_mod) relax_vm = tvm.relax.VirtualMachine(relax_exec, device=device) evaluator = relax_vm.module.time_evaluator( diff --git a/python/tvm/relax/vm_build.py b/python/tvm/relax/vm_build.py index ac4d9698a072..f44fcb9c226c 100644 --- a/python/tvm/relax/vm_build.py +++ b/python/tvm/relax/vm_build.py @@ -16,22 +16,22 @@ # under the License. # pylint: disable=invalid-name, no-member """VM build logics""" -from typing import Any, Dict, List, Optional, Union +from typing import Dict, List, Optional, Union import tvm from tvm import relax -from tvm.contrib import utils as _utils from tvm.ir.module import IRModule from tvm.tir.function import PrimFunc +from tvm.runtime import Executable from . 
import _ffi_api -class Executable: - """The executable object emitted by the VM compiler or the ExecBuilder.""" +class VMExecutable(Executable): + """The virtual machine executable object emitted by the VM compiler or the ExecBuilder.""" def __init__(self, mod: tvm.runtime.Module): - self.mod = mod + super().__init__(mod) self._stats = self.mod["stats"] self._as_text = self.mod["as_text"] self._as_python = self.mod["as_python"] @@ -48,105 +48,6 @@ def as_python(self) -> str: """print the instructions as python program.""" return self._as_python() - def jit(self, fcompile=None, addons=None, **kwargs) -> tvm.runtime.Module: - """Just-in-time compile and link the modules. - - The Executable returned by relax.build may not be directly - runnable as they may contain cuda source files and objects that - are yet to be compiled and linked. - This function helps to create a runtime.Module for these cases. - - Parameters - ---------- - fcompile : function(target, file_list, kwargs), optional - The compilation function to use create the final library object during - - kwargs : dict, optional - Additional arguments passed to fcompile - - Returns - ------- - rt_mod: tvm.runtime.Module - A runnable runtime module that can be passed to VirtualMachine. - - Examples - -------- - .. code:: python - - ex = relax.build(mod, target) - # build a runnable module using nvcc to link everything - rt_mod = ex.jit() - vm = tvm.relax.VirtualMachine(rt_mod, tvm.cuda()) - """ - - # TODO(tvm-team): Update runtime.Module interface - # to query these properties as bitmask. - def _not_runnable(x): - return x.type_key in ("c", "static_library") - - # pylint:disable = protected-access - not_runnable_list = self.mod._collect_from_import_tree(_not_runnable) - - # everything is runnable, directly return mod. - if len(not_runnable_list) == 0: - return self.mod - - # found source module, or other not runnable modules - # need to be export and load - # TODO(tvm-team): Support runnable but not exportable module. - # by collecting the link and allow export_library skip those modules. - workspace_dir = _utils.tempdir() - dso_path = workspace_dir.relpath("exported.so") - self.mod.export_library(dso_path, fcompile=fcompile, addons=addons, **kwargs) - return tvm.runtime.load_module(dso_path) - - def export_library( - self, - file_name: str, - fcompile: Optional[Union[str, callable]] = None, - workspace_dir: Optional[str] = None, - **kwargs, - ) -> Any: - """Export the executable to a library which can then be loaded back. - - Parameters - ---------- - file_name : str - The name of the shared library. - - fcompile : function(target, file_list, kwargs), optional - The compilation function to use create the final library object during - - workspace_dir : str, optional - The path of the directory used to create the intermediate - artifacts when exporting the module. - If this is not provided a temporary dir will be created. - - kwargs : dict, optional - Additional arguments passed to fcompile - - Returns - ------- - result of fcompile() : unknown, optional - If the compilation function returns an artifact it would be returned via - export_library, if any. - - Examples - -------- - .. code:: python - - ex = relax.build(mod, target) - # export the library - ex.export_library("exported.so") - - # load it back for future uses. 
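The jit() and export_library() helpers removed from relax.Executable below are re-introduced on the shared tvm.runtime.Executable base class later in this patch, so the export/reload flow keeps working through tvm.compile; a condensed sketch based on the new tests/python/runtime/test_executable.py (the module name Mod, the 4-element buffers, and the temporary file path are illustrative):

    import os
    import tempfile

    import numpy as np
    import tvm
    from tvm.script import tir as T

    @tvm.script.ir_module
    class Mod:
        @T.prim_func
        def add(
            A: T.Buffer((4,), "float32"),
            B: T.Buffer((4,), "float32"),
            C: T.Buffer((4,), "float32"),
        ):
            for i in range(4):
                C[i] = A[i] + B[i]

    ex = tvm.compile(Mod, target=tvm.target.Target("llvm"))
    rt_mod = ex.jit()          # produces a runnable runtime.Module
    assert ex.jit() is rt_mod  # repeated calls hand back the same module

    lib_path = os.path.join(tempfile.mkdtemp(), "exported.so")
    ex.export_library(lib_path)
    loaded = tvm.runtime.load_module(lib_path)

    a = tvm.nd.array(np.ones(4, "float32"))
    b = tvm.nd.array(np.ones(4, "float32"))
    c = tvm.nd.array(np.zeros(4, "float32"))
    loaded["add"](a, b, c)
    np.testing.assert_allclose(c.numpy(), np.full(4, 2.0, dtype="float32"))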
- rt_mod = tvm.runtime.load_module("exported.so") - vm = tvm.relax.VirtualMachine(rt_mod, tvm.cuda()) - """ - return self.mod.export_library( - file_name=file_name, fcompile=fcompile, workspace_dir=workspace_dir, **kwargs - ) - def _vmcodegen( builder: "relax.ExecBuilder", @@ -202,6 +103,7 @@ def _vmlink( builder: "relax.ExecBuilder", target: Optional[Union[str, tvm.target.Target]], tir_mod: Optional[tvm.IRModule] = None, + tir_pipeline: Optional[Union[str, tvm.transform.Pass]] = "default", ext_libs: List[tvm.runtime.Module] = None, params: Optional[Dict[str, list]] = None, *, @@ -249,7 +151,7 @@ def _vmlink( tir_ext_libs = [] if tir_mod is not None and len(tir_mod.get_global_vars()) > 0: tir_mod = _auto_attach_system_lib_prefix(tir_mod, target, system_lib) - lib = tvm.build(tir_mod, target=target) + lib = tvm.tir.build(tir_mod, target=target, pipeline=tir_pipeline) for ext_mod in ext_libs: if ext_mod.is_device_module: tir_ext_libs.append(ext_mod) @@ -260,14 +162,16 @@ def _vmlink( lib.import_module(mod) elif len(tir_ext_libs) > 0: print("Warning: No TIR module is found, but external modules for TIR are provided.") - return Executable(_ffi_api.VMLink(builder, target, lib, relax_ext_libs, params)) # type: ignore + lib = _ffi_api.VMLink(builder, target, lib, relax_ext_libs, params) # type: ignore + return VMExecutable(lib) def build( mod: tvm.IRModule, target: Optional[Union[str, tvm.target.Target]] = None, params: Optional[Dict[str, list]] = None, - pipeline: Union[None, str, tvm.transform.Pass] = "default_build", + relax_pipeline: Union[None, str, tvm.transform.Pass] = "default", + tir_pipeline: Union[None, str, tvm.transform.Pass] = "default", exec_mode: str = "bytecode", *, system_lib: Optional[bool] = None, @@ -336,14 +240,14 @@ def _extract_attrs(mod: tvm.IRModule): if not params: params = {} - if pipeline is not None: - if isinstance(pipeline, str): - pipeline = relax.get_pipeline(pipeline) + if relax_pipeline is not None: + if isinstance(relax_pipeline, str): + relax_pipeline = relax.get_pipeline(relax_pipeline) if target is None: - mod = pipeline(mod) + mod = relax_pipeline(mod) else: with target: - mod = pipeline(mod) + mod = relax_pipeline(mod) ext_libs, constants = _extract_attrs(mod) params.update(dict(constants)) @@ -353,6 +257,7 @@ def _extract_attrs(mod: tvm.IRModule): builder=builder, target=target, tir_mod=_filter_tir(mod), + tir_pipeline=tir_pipeline, ext_libs=ext_libs, params=params, system_lib=system_lib, diff --git a/python/tvm/runtime/__init__.py b/python/tvm/runtime/__init__.py index b748f84beca4..c7e407f0285f 100644 --- a/python/tvm/runtime/__init__.py +++ b/python/tvm/runtime/__init__.py @@ -23,13 +23,14 @@ from .script_printer import Scriptable from .object_generic import ObjectGeneric, ObjectTypes from .ndarray import NDArray, DataType, DataTypeCode, Device -from .module import Module, num_threads +from .module import Module from .profiling import Report +from .executable import Executable # function exposures from .ndarray import device, cpu, cuda, gpu, opencl, cl, vulkan, metal, mtl from .ndarray import vpi, rocm, ext_dev -from .module import load_module, enabled, system_lib, load_static_library +from .module import load_module, enabled, system_lib, load_static_library, num_threads from .container import String, ShapeTuple # , BoxBool from .object_generic import convert_to_object, convert, const from .params import ( diff --git a/python/tvm/runtime/executable.py b/python/tvm/runtime/executable.py new file mode 100644 index 000000000000..cf4e5b05874b --- 
/dev/null +++ b/python/tvm/runtime/executable.py @@ -0,0 +1,166 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, no-member + +"""Executable object for TVM Runtime""" +from typing import Any, Callable, Dict, List, Optional, Union + +import tvm +from tvm.contrib import utils as _utils +from . import PackedFunc, Module + + +class Executable: + """The executable object generated by `tvm.compile`.""" + + def __init__(self, mod: Module): + """Initialize the Executable object.""" + self.mod: Module = mod + self._jitted_mod: Optional[Module] = None + self.entry_name = mod.entry_name + + def __getitem__(self, name: str) -> PackedFunc: + """Get the PackedFunc from the jitted module.""" + return self.jit().get_function(name, query_imports=True) + + def __call__(self, *args, **kwargs) -> Any: + """Call the executable.""" + return self.jit().get_function(self.entry_name, query_imports=True)(*args, **kwargs) + + def jit( + self, + *, + fcompile: Optional[Callable[[str, List[str], Dict[str, Any]], None]] = None, + addons: Optional[List[str]] = None, + force_recompile: bool = False, + **kwargs, + ) -> Module: + """Just-in-time compile and link the modules. + + The Executable returned by tvm.compile may not be directly + runnable as they may contain cuda source files and objects that + are yet to be compiled and linked. + This function helps to create a runtime.Module for these cases. + + Parameters + ---------- + fcompile : function(target, file_list, kwargs), optional + The compilation function to use create the final library object during + + addons : list of str, optional + Additional object files to link against. + + force_recompile : bool, optional + If True, force a recompile of the module. + + kwargs : dict, optional + Additional arguments passed to fcompile + + Returns + ------- + rt_mod: tvm.runtime.Module + A runnable runtime module that can be passed to VirtualMachine. + + Examples + -------- + .. code:: python + + ex = tvm.compile(mod, target) + rt_mod = ex.jit() + + """ + + # If the module is already jitted and we don't want to force a recompile, + # return the cached module + if self._jitted_mod is not None and not force_recompile: + return self._jitted_mod + + # TODO(tvm-team): Update runtime.Module interface + # to query these properties as bitmask. + def _not_runnable(x): + return x.type_key in ("c", "static_library") + + # pylint:disable = protected-access + not_runnable_list = self.mod._collect_from_import_tree(_not_runnable) + + # everything is runnable, directly return mod. + if len(not_runnable_list) == 0: + return self.mod + + # found source module, or other not runnable modules need to be export and load + # TODO(tvm-team): Support runnable but not exportable module. 
+ # by collecting the link and allow export_library skip those modules. + workspace_dir = _utils.tempdir() + dso_path = workspace_dir.relpath("exported.so") + self.mod.export_library(dso_path, fcompile=fcompile, addons=addons, **kwargs) + self._jitted_mod = tvm.runtime.load_module(dso_path) + return self._jitted_mod + + def export_library( + self, + file_name: str, + *, + fcompile: Optional[Union[str, Callable[[str, List[str], Dict[str, Any]], None]]] = None, + addons: Optional[List[str]] = None, + workspace_dir: Optional[str] = None, + **kwargs, + ) -> Any: + """Export the executable to a library which can then be loaded back. + + Parameters + ---------- + file_name : str + The name of the shared library. + + fcompile : function(target, file_list, kwargs), optional + The compilation function to use create the final library object during + + addons : list of str, optional + Additional object files to link against. + + workspace_dir : str, optional + The path of the directory used to create the intermediate + artifacts when exporting the module. + If this is not provided a temporary dir will be created. + + kwargs : dict, optional + Additional arguments passed to fcompile + + Returns + ------- + result of fcompile() : unknown, optional + If the compilation function returns an artifact it would be returned via + export_library, if any. + + Examples + -------- + .. code:: python + + ex = tvm.compile(mod, target) + # export the library + ex.export_library("exported.so") + + # load it back for future uses. + rt_mod = tvm.runtime.load_module("exported.so") + """ + return self.mod.export_library( + file_name=file_name, + fcompile=fcompile, + addons=addons, + workspace_dir=workspace_dir, + **kwargs, + ) diff --git a/python/tvm/runtime/relax_vm.py b/python/tvm/runtime/relax_vm.py index 5b8bbe6d330e..1dc4aef6f4b4 100644 --- a/python/tvm/runtime/relax_vm.py +++ b/python/tvm/runtime/relax_vm.py @@ -44,7 +44,7 @@ class VirtualMachine(object): def __init__( self, - rt_mod: Union[tvm.runtime.Module, "tvm.relax.Executable"], + rt_mod: Union[tvm.runtime.Module, tvm.runtime.Executable], device: Union[Device, List[Device]], memory_cfg: Optional[Union[str, Dict[Device, str]]] = None, profile: bool = False, @@ -54,7 +54,7 @@ def __init__( Parameters ---------- - rt_mod: Union[tvm.runtime.Module, tvm.relax.Executable] + rt_mod: Union[tvm.runtime.Module, tvm.runtime.Executable] Runtime module exported by the result of build. device : Union[Device, List[Device]] @@ -72,13 +72,7 @@ def __init__( Whether or not to enable profiling. 
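Since VirtualMachine now consumes the generic tvm.runtime.Executable directly (calling jit() on demand), the Relax flow condenses to the sketch below, adapted from the new tests/python/driver/test_compile.py; the module name RxMod, the (3, 4) shape, and the llvm target are illustrative, and note that relax.build itself now takes relax_pipeline=/tir_pipeline= keywords in place of the former pipeline= argument:

    import numpy as np
    import tvm
    from tvm import relax
    from tvm.script import ir as I
    from tvm.script import relax as R

    @I.ir_module
    class RxMod:
        @R.function
        def main(x: R.Tensor((3, 4), "float32"), y: R.Tensor((3, 4), "float32")) -> R.Tensor:
            z = R.add(x, y)
            return z

    ex = tvm.compile(RxMod, tvm.target.Target("llvm"))  # routes to relax.build internally
    dev = tvm.cpu(0)
    vm = relax.VirtualMachine(ex, dev)  # the Executable is accepted without an explicit jit()

    x = tvm.nd.array(np.random.uniform(size=(3, 4)).astype("float32"), dev)
    y = tvm.nd.array(np.random.uniform(size=(3, 4)).astype("float32"), dev)
    z = vm["main"](x, y)
    np.testing.assert_allclose(z.numpy(), x.numpy() + y.numpy())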
""" if not isinstance(rt_mod, tvm.runtime.Module): - # important to keep this import local - # as the relax_vm needs to be isolated from compiler - # if we do not use the jit feature - # pylint:disable=import-outside-toplevel - from tvm import relax - - if isinstance(rt_mod, relax.Executable): + if isinstance(rt_mod, tvm.runtime.Executable): rt_mod = rt_mod.jit() else: raise ValueError("Expect the rt_mod to be an runtime.Module") @@ -101,10 +95,7 @@ def _setup_device(self, dev: Device, memory_cfg: Union[str, Dict[Device, str]]) devs = dev if not isinstance(dev, (list, tuple)): if not isinstance(dev, tvm.runtime.Device): - raise TypeError( - "dev is expected to be Device or \ - List[Device]" - ) + raise TypeError("dev is expected to be Device or List[Device]") devs = [dev] # CPU is required for executing shape functions diff --git a/python/tvm/tir/build.py b/python/tvm/tir/build.py index ee6280b74091..14bc189b9f6d 100644 --- a/python/tvm/tir/build.py +++ b/python/tvm/tir/build.py @@ -105,7 +105,7 @@ def tir_to_runtime( def build( mod: Union[PrimFunc, IRModule], target: Optional[Union[str, Target]] = None, - pipeline: Union[None, str, tvm.transform.Pass] = "default_tir", + pipeline: Union[None, str, tvm.transform.Pass] = "default", ): """Build a function with a signature, generating code for devices coupled with target information. diff --git a/python/tvm/tir/pipeline.py b/python/tvm/tir/pipeline.py index c8019c922981..ae78b0573822 100644 --- a/python/tvm/tir/pipeline.py +++ b/python/tvm/tir/pipeline.py @@ -151,11 +151,11 @@ def finalize_device_passes(): # pylint: disable=unused-argument # global map of pre-built pipelines PIPELINE_MAP = { - "default_tir": default_tir_pipeline, + "default": default_tir_pipeline, } -def get_tir_pipeline(name: str = "default_tir", **kwargs) -> tvm.transform.Pass: +def get_tir_pipeline(name: str = "default", **kwargs) -> tvm.transform.Pass: """Get pre-build pipeline by name Parameters diff --git a/src/relax/backend/vm/codegen_vm.cc b/src/relax/backend/vm/codegen_vm.cc index 18da88be805d..bd56b0fd7bd7 100644 --- a/src/relax/backend/vm/codegen_vm.cc +++ b/src/relax/backend/vm/codegen_vm.cc @@ -170,7 +170,7 @@ class CodeGenVM : public ExprFunctor { builder_->EmitCall("vm.builtin.read_if_cond", {cond_value}, cond_reg); // obtain the temp exec in progress. - vm::Executable* exec = builder_->exec(); + vm::VMExecutable* exec = builder_->exec(); // Record the offset of If instruction size_t if_offset = exec->instr_offset.size(); @@ -436,7 +436,7 @@ TVM_REGISTER_GLOBAL("relax.VMCodeGen").set_body_typed(VMCodeGen); * module(s). * \return The created module. */ -void LinkModules(ObjectPtr exec, const Map& params, +void LinkModules(ObjectPtr exec, const Map& params, const tvm::runtime::Module& lib, const Array& ext_libs) { // query if we need const loader for ext_modules // Wrap all submodules in the initialization wrapper. 
@@ -482,7 +482,7 @@ void LinkModules(ObjectPtr exec, const Map */ Module VMLink(ExecBuilder builder, Target target, Optional lib, Array ext_libs, Map params) { - ObjectPtr executable = builder->Get(); + ObjectPtr executable = builder->Get(); if (!lib.defined()) { lib = codegen::CSourceModuleCreate(";", "", Array{}); } diff --git a/src/relax/backend/vm/exec_builder.cc b/src/relax/backend/vm/exec_builder.cc index 0e6f59b4604e..36bfa7e2421e 100644 --- a/src/relax/backend/vm/exec_builder.cc +++ b/src/relax/backend/vm/exec_builder.cc @@ -33,13 +33,13 @@ TVM_REGISTER_NODE_TYPE(ExecBuilderNode); ExecBuilder ExecBuilderNode::Create() { ExecBuilder ret(make_object()); - ret->exec_ = make_object(); + ret->exec_ = make_object(); return ret; } -Executable* ExecBuilderNode::exec() const { return exec_.get(); } +VMExecutable* ExecBuilderNode::exec() const { return exec_.get(); } -ObjectPtr ExecBuilderNode::Get() { +ObjectPtr ExecBuilderNode::Get() { this->Formalize(); this->CheckExecutable(); return exec_; @@ -270,7 +270,7 @@ void ExecBuilderNode::CheckExecutable() { void ExecBuilderNode::Formalize() { // a pass to formalize user-specified register indexes in the order of use - // and decide the number of registers to allocate for each VMFunction in the Executable + // and decide the number of registers to allocate for each VMFunction in the VMExecutable for (auto it = this->exec_->func_table.begin(); it != this->exec_->func_table.end(); ++it) { if (it->kind == VMFuncInfo::FuncKind::kPackedFunc) continue; if (it->kind == VMFuncInfo::FuncKind::kVMTIRFunc) continue; @@ -395,7 +395,7 @@ TVM_REGISTER_GLOBAL("relax.ExecBuilderF").set_body_typed([](ExecBuilder builder, }); TVM_REGISTER_GLOBAL("relax.ExecBuilderGet").set_body_typed([](ExecBuilder builder) { - ObjectPtr p_exec = builder->Get(); + ObjectPtr p_exec = builder->Get(); return runtime::Module(p_exec); }); diff --git a/src/runtime/relax_vm/executable.cc b/src/runtime/relax_vm/executable.cc index f45786c3da32..bf122cc04b6f 100644 --- a/src/runtime/relax_vm/executable.cc +++ b/src/runtime/relax_vm/executable.cc @@ -52,7 +52,7 @@ enum ConstantType : int { ICHECK(val) << "Invalid VM file format in the " << section << " section." \ << "\n"; -std::string Executable::Stats() const { +std::string VMExecutable::Stats() const { std::ostringstream oss; oss << "Relax VM executable statistics:" << std::endl; @@ -116,14 +116,14 @@ std::string Executable::Stats() const { return oss.str(); } -void Executable::SetInstructionData(Index i, Index j, ExecWord val) { +void VMExecutable::SetInstructionData(Index i, Index j, ExecWord val) { ICHECK_LT(i, instr_offset.size()); Index instr_idx = instr_offset[i]; ICHECK_LT(instr_idx + j, instr_data.size()); instr_data[instr_idx + j] = val; } -Instruction Executable::GetInstruction(Index i) const { +Instruction VMExecutable::GetInstruction(Index i) const { Index offset = instr_offset[i]; Opcode op = static_cast(instr_data[offset]); switch (op) { @@ -173,7 +173,7 @@ void LoadHeader(dmlc::Stream* strm) { STREAM_CHECK(version == RELAX_VM_VERSION, "version"); } -void Executable::SaveToBinary(dmlc::Stream* stream) { +void VMExecutable::SaveToBinary(dmlc::Stream* stream) { std::string code; // Initialize the stream object. 
dmlc::MemoryStringStream strm(&code); @@ -193,20 +193,20 @@ void Executable::SaveToBinary(dmlc::Stream* stream) { stream->Write(code); } -void Executable::SaveToFile(const String& file_name, const String& format) { +void VMExecutable::SaveToFile(const String& file_name, const String& format) { std::string data; dmlc::MemoryStringStream writer(&data); dmlc::SeekStream* strm = &writer; - Executable::SaveToBinary(strm); + VMExecutable::SaveToBinary(strm); runtime::SaveBinaryToFile(file_name, data); } -Module Executable::LoadFromBinary(void* stream) { +Module VMExecutable::LoadFromBinary(void* stream) { std::string code; static_cast(stream)->Read(&code); dmlc::MemoryStringStream strm(&code); - ObjectPtr exec = make_object(); + ObjectPtr exec = make_object(); // Load header. LoadHeader(&strm); @@ -223,19 +223,19 @@ Module Executable::LoadFromBinary(void* stream) { return Module(exec); } -TVM_REGISTER_GLOBAL("runtime.module.loadbinary_relax.Executable") - .set_body_typed(Executable::LoadFromBinary); +TVM_REGISTER_GLOBAL("runtime.module.loadbinary_relax.VMExecutable") + .set_body_typed(VMExecutable::LoadFromBinary); -Module Executable::LoadFromFile(const String& file_name) { +Module VMExecutable::LoadFromFile(const String& file_name) { std::string data; runtime::LoadBinaryFromFile(file_name, &data); dmlc::MemoryStringStream reader(&data); dmlc::Stream* strm = &reader; - return Executable::LoadFromBinary(reinterpret_cast(strm)); + return VMExecutable::LoadFromBinary(reinterpret_cast(strm)); } -TVM_REGISTER_GLOBAL("runtime.module.loadfile_relax.Executable") - .set_body_typed(Executable::LoadFromFile); +TVM_REGISTER_GLOBAL("runtime.module.loadfile_relax.VMExecutable") + .set_body_typed(VMExecutable::LoadFromFile); void VMFuncInfo::Save(dmlc::Stream* strm) const { int32_t temp_kind = static_cast(kind); @@ -261,9 +261,9 @@ bool VMFuncInfo::Load(dmlc::Stream* strm) { return true; } -void Executable::SaveGlobalSection(dmlc::Stream* strm) { strm->Write(func_table); } +void VMExecutable::SaveGlobalSection(dmlc::Stream* strm) { strm->Write(func_table); } -void Executable::SaveConstantSection(dmlc::Stream* strm) { +void VMExecutable::SaveConstantSection(dmlc::Stream* strm) { strm->Write(static_cast(this->constants.size())); for (const auto& it : this->constants) { if (it.IsObjectRef()) { @@ -301,12 +301,12 @@ void Executable::SaveConstantSection(dmlc::Stream* strm) { } } -void Executable::SaveCodeSection(dmlc::Stream* strm) { +void VMExecutable::SaveCodeSection(dmlc::Stream* strm) { strm->Write(instr_offset); strm->Write(instr_data); } -void Executable::LoadGlobalSection(dmlc::Stream* strm) { +void VMExecutable::LoadGlobalSection(dmlc::Stream* strm) { STREAM_CHECK(strm->Read(&func_table), "Global Section"); // setup func map for (size_t i = 0; i < func_table.size(); ++i) { @@ -314,7 +314,7 @@ void Executable::LoadGlobalSection(dmlc::Stream* strm) { } } -void Executable::LoadConstantSection(dmlc::Stream* strm) { +void VMExecutable::LoadConstantSection(dmlc::Stream* strm) { uint64_t sz; // Load the number of constants. 
STREAM_CHECK(strm->Read(&sz, sizeof(sz)), "constant"); @@ -375,7 +375,7 @@ void Executable::LoadConstantSection(dmlc::Stream* strm) { } } -void Executable::LoadCodeSection(dmlc::Stream* strm) { +void VMExecutable::LoadCodeSection(dmlc::Stream* strm) { STREAM_CHECK(strm->Read(&(this->instr_offset)), "instr offset"); STREAM_CHECK(strm->Read(&(this->instr_data)), "instr data"); } @@ -404,21 +404,21 @@ std::string RegNameToStr(RegName reg) { return "%" + std::to_string(reg); } -Module Executable::VMLoadExecutable() const { +Module VMExecutable::VMLoadExecutable() const { ObjectPtr vm = VirtualMachine::Create(); - vm->LoadExecutable(GetObjectPtr(const_cast(this))); + vm->LoadExecutable(GetObjectPtr(const_cast(this))); return Module(vm); } -Module Executable::VMProfilerLoadExecutable() const { +Module VMExecutable::VMProfilerLoadExecutable() const { ObjectPtr vm = VirtualMachine::CreateProfiler(); - vm->LoadExecutable(GetObjectPtr(const_cast(this))); + vm->LoadExecutable(GetObjectPtr(const_cast(this))); return Module(vm); } -bool Executable::HasFunction(const String& name) const { return func_map.count(name); } +bool VMExecutable::HasFunction(const String& name) const { return func_map.count(name); } -String Executable::AsText() const { +String VMExecutable::AsText() const { auto get_func_name = [&](Index index) -> std::string { if (static_cast(index) < func_table.size()) { return func_table[index].name; @@ -495,7 +495,7 @@ String Executable::AsText() const { return String(os.str()); } -String Executable::AsPython() const { +String VMExecutable::AsPython() const { auto get_func_name = [&](Index index) -> std::string { if (static_cast(index) < func_table.size()) { return "\"" + func_table[index].name + "\""; @@ -573,7 +573,7 @@ String Executable::AsPython() const { return String(os.str()); } -TVM_REGISTER_GLOBAL("relax.ExecutableLoadFromFile").set_body_typed(Executable::LoadFromFile); +TVM_REGISTER_GLOBAL("relax.ExecutableLoadFromFile").set_body_typed(VMExecutable::LoadFromFile); } // namespace relax_vm } // namespace runtime diff --git a/src/runtime/relax_vm/vm.cc b/src/runtime/relax_vm/vm.cc index ebb5afb1f4ae..05d5570c0c6f 100644 --- a/src/runtime/relax_vm/vm.cc +++ b/src/runtime/relax_vm/vm.cc @@ -198,7 +198,7 @@ class VirtualMachineImpl : public VirtualMachine { //--------------------------------------------------- // Public facing functions overloading //--------------------------------------------------- - void LoadExecutable(ObjectPtr exec) final; + void LoadExecutable(ObjectPtr exec) final; void Init(const std::vector& devices, const std::vector& alloc_types) final; VMClosure GetClosure(const String& func_name) final { @@ -425,7 +425,7 @@ class VirtualMachineImpl : public VirtualMachine { // Internal states for execution. //-------------------------------------------------------- /*! \brief The loaded executable. */ - ObjectPtr exec_; + ObjectPtr exec_; /*! \brief The global constant pool */ std::vector const_pool_; /*! 
@@ -462,7 +462,7 @@ class VirtualMachineImpl : public VirtualMachine { PackedFunc instrument_ = nullptr; }; -void VirtualMachineImpl::LoadExecutable(ObjectPtr exec) { +void VirtualMachineImpl::LoadExecutable(ObjectPtr exec) { this->exec_ = exec; this->imports_ = exec_->imports(); } diff --git a/tests/python/driver/test_compile.py b/tests/python/driver/test_compile.py new file mode 100644 index 000000000000..e66bd7c2906d --- /dev/null +++ b/tests/python/driver/test_compile.py @@ -0,0 +1,121 @@ +#!/usr/bin/env python3 +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import numpy as np + +import tvm +import tvm.testing +from tvm import relax, te +from tvm.runtime import Executable +from tvm.script import ir as I +from tvm.script import relax as R +from tvm.script import tir as T + + +def test_compile_tir(): + """Test tvm.compile with TIR input.""" + n = te.var("n") + A = te.placeholder((n,), name="A") + B = te.placeholder((n,), name="B") + C = te.compute(A.shape, lambda i: A[i] + B[i], name="C") + func = te.create_prim_func([A, B, C]) + + # Test compile with PrimFunc + exec_prim = tvm.compile(func) + assert isinstance(exec_prim, Executable) + + # Test compile with IRModule containing PrimFunc + mod = tvm.IRModule.from_expr(func) + exec_mod = tvm.compile(mod) + assert isinstance(exec_mod, Executable) + + # Verify the compiled module works + dev = tvm.cpu(0) + a_np = np.random.uniform(size=10).astype(np.float32) + b_np = np.random.uniform(size=10).astype(np.float32) + a = tvm.nd.array(a_np, dev) + b = tvm.nd.array(b_np, dev) + c = tvm.nd.array(np.zeros(10, dtype=np.float32), dev) + + exec_prim(a, b, c) + np.testing.assert_allclose(c.numpy(), a_np + b_np) + exec_mod(a, b, c) + np.testing.assert_allclose(c.numpy(), a_np + b_np) + + +def test_compile_relax(): + """Test tvm.compile with Relax input.""" + # Define a simple Relax program + @I.ir_module + class MyModule: + @R.function + def main(x: R.Tensor((3, 4), "float32"), y: R.Tensor((3, 4), "float32")) -> R.Tensor: + z = R.add(x, y) + return z + + # Test compile with Relax module + target = tvm.target.Target("llvm") + exec_relax = tvm.compile(MyModule, target) + assert isinstance(exec_relax, Executable) + + # Verify the compiled module works + dev = tvm.cpu(0) + x_np = np.random.uniform(size=(3, 4)).astype(np.float32) + y_np = np.random.uniform(size=(3, 4)).astype(np.float32) + x = tvm.nd.array(x_np, dev) + y = tvm.nd.array(y_np, dev) + + vm = relax.VirtualMachine(exec_relax, dev) + z = vm["main"](x, y) + np.testing.assert_allclose(z.numpy(), x_np + y_np) + + +@tvm.testing.skip_if_32bit(reason="skipping test for i386.") +def test_compile_mixed_module(): + @tvm.script.ir_module + class MyModule: + @T.prim_func + def add_one(X: T.Buffer((4,), "float32"), Y: T.Buffer((4,), "float32")): + for i in range(4): + Y[i] = 
X[i] + 1 + + @R.function + def main(x: R.Tensor((4,), "float32")): + cls = MyModule + with R.dataflow(): + y = R.call_tir(cls.add_one, [x], R.Tensor((4,), "float32")) + return y + + # Test with custom pipeline + target = tvm.target.Target("c") + ex = tvm.compile(MyModule, target) + assert isinstance(ex, Executable) + + dev = tvm.cpu(0) + x = tvm.nd.array(np.array([1, 2, 3, 4], dtype=np.float32), dev) + y = tvm.nd.array(np.zeros(4, dtype=np.float32), dev) + # For tir function, we can directly call the function + ex["add_one"](x, y) + np.testing.assert_allclose(y.numpy(), x.numpy() + 1) + # For relax function, we need to use the vm to call the function + vm = relax.VirtualMachine(ex, dev) + z = vm["main"](x) + np.testing.assert_allclose(z.numpy(), x.numpy() + 1) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/runtime/test_executable.py b/tests/python/runtime/test_executable.py new file mode 100644 index 000000000000..571ce7adb2bf --- /dev/null +++ b/tests/python/runtime/test_executable.py @@ -0,0 +1,263 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""Tests for the Executable class.""" + +import os +import tempfile + +import numpy as np + +import tvm +import tvm.testing +from tvm.runtime import Executable +from tvm.script import tir as T + + +@tvm.script.ir_module +class MyModule: + @T.prim_func + def add( + A: T.Buffer((10,), "float32"), + B: T.Buffer((10,), "float32"), + C: T.Buffer((10,), "float32"), + ): + for i in range(10): + C[i] = A[i] + B[i] + + +def test_executable_init(): + """Test initialization of Executable class.""" + lib = tvm.tir.build(MyModule, target="llvm") + executable = Executable(lib) + + assert executable.mod is lib + assert executable._jitted_mod is None + + +def test_executable_getitem(): + """Test __getitem__ method of Executable class.""" + lib = tvm.tir.build(MyModule, target="llvm") + executable = Executable(lib) + + # Jit the module first + executable.jit() + + # Test __getitem__ + add_func = executable["add"] + + # Verify the function works + a = tvm.nd.array(np.array([1.0] * 10, dtype="float32")) + b = tvm.nd.array(np.array([2.0] * 10, dtype="float32")) + c = tvm.nd.array(np.array([0.0] * 10, dtype="float32")) + + add_func(a, b, c) + + # Check results + tvm.testing.assert_allclose(c.numpy(), np.array([3.0] * 10, dtype="float32")) + + +def test_executable_jit_already_jitted(): + """Test jit method when module is already jitted.""" + lib = tvm.tir.build(MyModule, target="llvm") + executable = Executable(lib) + + # First jit call + jitted_mod1 = executable.jit() + + # Second jit call should return the cached jitted module + jitted_mod2 = executable.jit() + assert jitted_mod2 is jitted_mod1 + + # Test with force_recompile + jitted_mod3 = executable.jit(force_recompile=True) + # The module might be different after force recompilation + + # Verify both modules work correctly + a = tvm.nd.array(np.array([1.0] * 10, dtype="float32")) + b = tvm.nd.array(np.array([2.0] * 10, dtype="float32")) + c1 = tvm.nd.array(np.array([0.0] * 10, dtype="float32")) + c2 = tvm.nd.array(np.array([0.0] * 10, dtype="float32")) + + jitted_mod1["add"](a, b, c1) + jitted_mod3["add"](a, b, c2) + + tvm.testing.assert_allclose(c1.numpy(), np.array([3.0] * 10, dtype="float32")) + tvm.testing.assert_allclose(c2.numpy(), np.array([3.0] * 10, dtype="float32")) + + +def test_executable_export_library(): + """Test export_library method.""" + lib = tvm.tir.build(MyModule, target="llvm") + executable = Executable(lib) + + # Create a temporary directory for the library + temp_dir = tempfile.mkdtemp() + try: + lib_path = os.path.join(temp_dir, "test_lib.so") + executable.export_library(lib_path) + + # Verify the library was created + assert os.path.exists(lib_path) + + # Load the library back + loaded_mod = tvm.runtime.load_module(lib_path) + assert loaded_mod is not None + + # Test the loaded module + a = tvm.nd.array(np.array([1.0] * 10, dtype="float32")) + b = tvm.nd.array(np.array([2.0] * 10, dtype="float32")) + c = tvm.nd.array(np.array([0.0] * 10, dtype="float32")) + + loaded_mod["add"](a, b, c) + + # Check results + tvm.testing.assert_allclose(c.numpy(), np.array([3.0] * 10, dtype="float32")) + finally: + # Clean up + if os.path.exists(temp_dir): + import shutil + + shutil.rmtree(temp_dir) + + +def test_executable_export_library_with_workspace(): + """Test export_library method with workspace_dir.""" + lib = tvm.tir.build(MyModule, target="llvm") + executable = Executable(lib) + + # Create temporary directories + temp_dir = tempfile.mkdtemp() + workspace_dir = tempfile.mkdtemp() + + try: + lib_path = os.path.join(temp_dir, "test_lib.so") 
+ executable.export_library(lib_path, workspace_dir=workspace_dir) + + # Verify the library was created + assert os.path.exists(lib_path) + + # Load the library back + loaded_mod = tvm.runtime.load_module(lib_path) + assert loaded_mod is not None + + # Test the loaded module + a = tvm.nd.array(np.array([1.0] * 10, dtype="float32")) + b = tvm.nd.array(np.array([2.0] * 10, dtype="float32")) + c = tvm.nd.array(np.array([0.0] * 10, dtype="float32")) + + loaded_mod["add"](a, b, c) + + # Check results + tvm.testing.assert_allclose(c.numpy(), np.array([3.0] * 10, dtype="float32")) + finally: + # Clean up + for directory in [temp_dir, workspace_dir]: + if os.path.exists(directory): + import shutil + + shutil.rmtree(directory) + + +def test_executable_integration(): + """Integration test for Executable with a simple TVM module.""" + # Create target and build + target = tvm.target.Target("llvm") + lib = tvm.tir.build(MyModule, target=target) + + # Create an executable + executable = Executable(lib) + + # Test jit + jitted_mod = executable.jit() + assert jitted_mod is not None + + # Test __getitem__ + add_func = executable["add"] + assert add_func is not None + + # Test the function works + a = tvm.nd.array(np.array([1.0] * 10, dtype="float32")) + b = tvm.nd.array(np.array([2.0] * 10, dtype="float32")) + c = tvm.nd.array(np.array([0.0] * 10, dtype="float32")) + + add_func(a, b, c) + + # Check results + tvm.testing.assert_allclose(c.numpy(), np.array([3.0] * 10, dtype="float32")) + + # Test export_library + temp_dir = tempfile.mkdtemp() + try: + lib_path = os.path.join(temp_dir, "test_lib.so") + executable.export_library(lib_path) + + # Verify the library was created + assert os.path.exists(lib_path) + + # Load the library back + loaded_mod = tvm.runtime.load_module(lib_path) + assert loaded_mod is not None + + # Test the loaded module + loaded_add = loaded_mod["add"] + c_loaded = tvm.nd.array(np.array([0.0] * 10, dtype="float32")) + loaded_add(a, b, c_loaded) + + # Check results + tvm.testing.assert_allclose(c_loaded.numpy(), np.array([3.0] * 10, dtype="float32")) + + finally: + # Clean up + if os.path.exists(temp_dir): + import shutil + + shutil.rmtree(temp_dir) + + +def test_executable_jit_force_recompile(): + """Test jit method with force_recompile=True.""" + # Create target and build + target = tvm.target.Target("c") + lib = tvm.tir.build(MyModule, target=target) + + # Create an executable + executable = Executable(lib) + + # First jit call + jitted_mod1 = executable.jit() + + # Second jit call without force_recompile should return the same module + jitted_mod2 = executable.jit() + assert jitted_mod1 is jitted_mod2 + + # Third jit call with force_recompile should return a new module + jitted_mod3 = executable.jit(force_recompile=True) + assert jitted_mod3 is not jitted_mod1 + + # Test the function works + a = tvm.nd.array(np.array([1.0] * 10, dtype="float32")) + b = tvm.nd.array(np.array([2.0] * 10, dtype="float32")) + c = tvm.nd.array(np.array([0.0] * 10, dtype="float32")) + + jitted_mod3["add"](a, b, c) + + # Check results + tvm.testing.assert_allclose(c.numpy(), np.array([3.0] * 10, dtype="float32")) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/scripts/task_python_unittest.sh b/tests/scripts/task_python_unittest.sh index b074af314786..ddc775933cdc 100755 --- a/tests/scripts/task_python_unittest.sh +++ b/tests/scripts/task_python_unittest.sh @@ -37,21 +37,20 @@ run_pytest ${TVM_UNITTEST_TESTSUITE_NAME}-platform-minimal-test tests/python/all # Then run all unittests on 
both ctypes and cython.
 TEST_FILES=(
     "arith"
+    "ci"
     "codegen"
+    "driver"
     "ir"
     "meta_schedule"
     "runtime"
+    "target"
     "te"
     "testing"
     "tir-analysis"
     "tir-base"
     "tir-schedule"
     "tir-transform"
-    "tir-usmp"
     "tvmscript"
-    "usmp"
-    "ci"
-    "target"
 )
 for TEST_FILE in ${TEST_FILES[@]}; do
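For reference, the pipeline renames touched earlier in this patch work out as follows: the default TIR pipeline is now registered as "default" (previously "default_tir"), the Relax pipeline map gains "default" as an alias of "default_build", and relax.build exposes separate relax_pipeline/tir_pipeline arguments. A short sketch, reusing the Mod and RxMod modules from the earlier sketches:

    import tvm
    from tvm import relax

    # TIR path: pipeline name "default" replaces "default_tir"
    lib = tvm.tir.build(Mod, target="llvm", pipeline="default")

    # Relax path: separate pipelines for the Relax and TIR stages
    ex = relax.build(RxMod, target="llvm", relax_pipeline="default", tir_pipeline="default")
    vm = relax.VirtualMachine(ex, tvm.cpu(0))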