Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/how_to/tutorials/optimize_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,7 +426,7 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I


with target:
ex = relax.build(mod, target, pipeline=relax.get_pipeline("opt_llm"))
ex = tvm.compile(mod, target, relax_pipeline=relax.get_pipeline("opt_llm"))
vm = relax.VirtualMachine(ex, dev)


Expand Down
6 changes: 3 additions & 3 deletions include/tvm/relax/exec_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -125,12 +125,12 @@ class ExecBuilderNode : public Object {
/*!
* \brief Raw access to underlying executable build in progress.
*/
vm::Executable* exec() const;
vm::VMExecutable* exec() const;
/*!
* \brief Finalize the build, run formalize and get the final result.
* \note This function should not be called during construction.
*/
ObjectPtr<vm::Executable> Get();
ObjectPtr<vm::VMExecutable> Get();
/*!
* \brief Create an ExecBuilder.
* \return The ExecBuilder.
Expand Down Expand Up @@ -165,7 +165,7 @@ class ExecBuilderNode : public Object {
void Formalize();

/*! \brief The mutable internal executable. */
ObjectPtr<vm::Executable> exec_; // mutable
ObjectPtr<vm::VMExecutable> exec_; // mutable
/*! \brief internal dedup map when creating index for a new constant */
std::unordered_map<ObjectRef, vm::Index, StructuralHash, StructuralEqual> const_dedup_map_;
};
Expand Down
30 changes: 15 additions & 15 deletions include/tvm/runtime/relax_vm/executable.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,12 +81,12 @@ struct VMFuncInfo {
};

/*!
* \brief The executable emitted by the VM compiler.
* \brief The virtual machine executable emitted by the VM compiler.
*
* The executable contains information (e.g. data in different memory regions)
* to run in a virtual machine.
*/
class Executable : public runtime::ModuleNode {
class VMExecutable : public runtime::ModuleNode {
public:
/*! \brief Get the property of the runtime module .*/
int GetPropertyMask() const final { return ModulePropertyMask::kBinarySerializable; };
Expand Down Expand Up @@ -121,18 +121,18 @@ class Executable : public runtime::ModuleNode {
*/
String AsPython() const;
/*!
* \brief Write the Executable to the binary stream in serialized form.
* \brief Write the VMExecutable to the binary stream in serialized form.
* \param stream The binary stream to save the executable to.
*/
void SaveToBinary(dmlc::Stream* stream) final;
/*!
* \brief Load Executable from the binary stream in serialized form.
* \brief Load VMExecutable from the binary stream in serialized form.
* \param stream The binary stream that load the executable from.
* \return The loaded executable, in the form of a `runtime::Module`.
*/
static Module LoadFromBinary(void* stream);
/*!
* \brief Write the Executable to the provided path as a file containing its serialized content.
* \brief Write the VMExecutable to the provided path as a file containing its serialized content.
* \param file_name The name of the file to write the serialized data to.
* \param format The target format of the saved file.
*/
Expand All @@ -141,10 +141,10 @@ class Executable : public runtime::ModuleNode {
Module VMLoadExecutable() const;
/*! \brief Create a Relax virtual machine with profiler and load `this` as the executable. */
Module VMProfilerLoadExecutable() const;
/*! \brief Check if the Executable contains a specific function. */
/*! \brief Check if the VMExecutable contains a specific function. */
bool HasFunction(const String& name) const;
/*!
* \brief Load Executable from the file.
* \brief Load VMExecutable from the file.
* \param file_name The path of the file that load the executable from.
* \return The loaded executable, in the form of a `runtime::Module`.
*/
Expand All @@ -161,15 +161,15 @@ class Executable : public runtime::ModuleNode {
/*! \brief The byte data of instruction. */
std::vector<ExecWord> instr_data;

virtual ~Executable() {}
virtual ~VMExecutable() {}

TVM_MODULE_VTABLE_BEGIN("relax.Executable");
TVM_MODULE_VTABLE_ENTRY("stats", &Executable::Stats);
TVM_MODULE_VTABLE_ENTRY("as_text", &Executable::AsText);
TVM_MODULE_VTABLE_ENTRY("as_python", &Executable::AsPython);
TVM_MODULE_VTABLE_ENTRY("vm_load_executable", &Executable::VMLoadExecutable);
TVM_MODULE_VTABLE_ENTRY("vm_profiler_load_executable", &Executable::VMProfilerLoadExecutable);
TVM_MODULE_VTABLE_ENTRY("has_function", &Executable::HasFunction);
TVM_MODULE_VTABLE_BEGIN("relax.VMExecutable");
TVM_MODULE_VTABLE_ENTRY("stats", &VMExecutable::Stats);
TVM_MODULE_VTABLE_ENTRY("as_text", &VMExecutable::AsText);
TVM_MODULE_VTABLE_ENTRY("as_python", &VMExecutable::AsPython);
TVM_MODULE_VTABLE_ENTRY("vm_load_executable", &VMExecutable::VMLoadExecutable);
TVM_MODULE_VTABLE_ENTRY("vm_profiler_load_executable", &VMExecutable::VMProfilerLoadExecutable);
TVM_MODULE_VTABLE_ENTRY("has_function", &VMExecutable::HasFunction);
TVM_MODULE_VTABLE_END();

private:
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/runtime/relax_vm/vm.h
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ class VirtualMachine : public runtime::ModuleNode {
* \brief Load the executable for the virtual machine.
* \param exec The executable.
*/
virtual void LoadExecutable(ObjectPtr<Executable> exec) = 0;
virtual void LoadExecutable(ObjectPtr<VMExecutable> exec) = 0;
/*!
* \brief Get global function in the VM.
* \param func_name The name of the function.
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@
from . import te

# tvm.driver
from .driver import build
from .driver import build, compile

# others
from . import arith
Expand Down
38 changes: 20 additions & 18 deletions python/tvm/contrib/hexagon/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,12 @@
from typing import Union

import tvm
from tvm import relax
import tvm.contrib.hexagon as hexagon
from tvm import rpc as _rpc
from tvm import runtime
from tvm.contrib import utils
import tvm.contrib.hexagon as hexagon
from .tools import export_module, HEXAGON_SIMULATOR_NAME

from .tools import HEXAGON_SIMULATOR_NAME, export_module


class Session:
Expand Down Expand Up @@ -202,26 +203,26 @@ def load_module(self, module: Union[str, pathlib.Path, tvm.runtime.Module]):
return self._rpc.get_function("tvm.hexagon.load_module")(str(remote_file_path))

def get_executor_from_factory(
self, module: Union[ExecutorFactoryModule, relax.Executable, str], hexagon_arch: str = "v68"
self, module: Union[runtime.executable, str], hexagon_arch: str = "v68"
):
"""Create a local GraphModule which consumes a remote libmod.

Parameters
----------

module : Union[relax.Executable]
module : Union[runtime.Executable, str]

The module to upload to the remote
session and load.
hexagon_arch : str
The hexagon arch to be used
"""
if isinstance(module, (relax.Executable, str)):
if isinstance(module, (runtime.Executable, str)):
return self._relax_vm_executable_executor(module, hexagon_arch=hexagon_arch)

raise TypeError(f"Unsupported executor type: {type(module)}")

def _set_device_type(self, module: Union[str, pathlib.Path, GraphExecutorFactoryModule]):
def _set_device_type(self, module: Union[str, pathlib.Path]):
"""Set session device type(hexagon, cpu) based on target in module.

Parameters
Expand All @@ -244,40 +245,41 @@ def _set_device_type(self, module: Union[str, pathlib.Path, GraphExecutorFactory
self._requires_cpu_device = False

def _relax_vm_executable_executor(
self, vm_exec: Union[relax.Executable, str], hexagon_arch: str
self, executable: Union[runtime.Executable, str], hexagon_arch: str
):
"""Create a local TVM module which consumes a remote vm executable.

Paramters
---------
Parameters
----------

vm_exec : relax.Executable
The Relax VM Executable to upload to the remote and load. This will typically be the
output of `relax.build` or the path to an already built and exported shared library
executable : runtime.Executable
The Executable to upload to the remote and load. This will typically be the
output of `tvm.compile` or the path to an already built and exported shared library
hexagon_arch : str
The hexagon arch to be used

Returns
-------
TVMModule :
TVM module object
"""
assert self._rpc is not None, "Hexagon session must be started using __enter__ prior to use"

if isinstance(vm_exec, relax.Executable):
if isinstance(executable, runtime.Executable):
temp_dir = utils.tempdir()
path_exec = temp_dir.relpath("exec.so")

vm_exec.mod.export_library(
executable.export_library(
path_exec,
fcompile=hexagon.create_aot_shared,
hexagon_arch=hexagon_arch,
)

path = self.upload(path_exec, "exec.so")
elif isinstance(vm_exec, str):
path_exec = vm_exec
elif isinstance(executable, str):
path_exec = executable
else:
raise TypeError(f"Unsupported executor type: {type(vm_exec)}")
raise TypeError(f"Unsupported executor type: {type(executable)}")

path = self.upload(path_exec, "exec.so")
return self._rpc.get_function("tvm.hexagon.load_module")(str(path))
Expand Down
4 changes: 3 additions & 1 deletion python/tvm/driver/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=redefined-builtin

"""Namespace for driver APIs"""
from .build_module import build
from .build_module import build, compile
85 changes: 82 additions & 3 deletions python/tvm/driver/build_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,95 @@

# pylint: disable=invalid-name
"""The build utils in python."""
from typing import Union, Optional
import warnings
from typing import Callable, Optional, Union

import tvm
from tvm.tir import PrimFunc
from tvm.ir.module import IRModule
from tvm.runtime import Executable
from tvm.target import Target
from tvm.tir import PrimFunc


def build(
mod: Union[PrimFunc, IRModule],
target: Optional[Union[str, Target]] = None,
pipeline: Optional[Union[str, tvm.transform.Pass]] = "default_tir",
pipeline: Optional[Union[str, tvm.transform.Pass]] = "default",
):
"""
Build a function with a signature, generating code for devices
coupled with target information.

This function is deprecated. Use `tvm.compile` or `tvm.tir.build` instead.

Parameters
----------
mod : Union[PrimFunc, IRModule]
The input to be built.
target : Optional[Union[str, Target]]
The target for compilation.
pipeline : Optional[Union[str, tvm.transform.Pass]]
The pipeline to use for compilation.

Returns
-------
tvm.runtime.Module
A module combining both host and device code.
"""
warnings.warn(
"build is deprecated. Use `tvm.compile` or `tvm.tir.build` instead.",
DeprecationWarning,
)
return tvm.tir.build(mod, target, pipeline)


def _contains_relax(mod: Union[PrimFunc, IRModule]) -> bool:
if isinstance(mod, PrimFunc):
return False
if isinstance(mod, IRModule):
return any(isinstance(func, tvm.relax.Function) for _, func in mod.functions_items())

raise ValueError(f"Function input must be a PrimFunc or IRModule, but got {type(mod)}")


def compile( # pylint: disable=redefined-builtin
mod: Union[PrimFunc, IRModule],
target: Optional[Target] = None,
*,
relax_pipeline: Optional[Union[tvm.transform.Pass, Callable, str]] = "default",
tir_pipeline: Optional[Union[tvm.transform.Pass, Callable, str]] = "default",
) -> Executable:
"""
Compile an IRModule to a runtime executable.

This function serves as a unified entry point for compiling both TIR and Relax modules.
It automatically detects the module type and routes to the appropriate build function.

Parameters
----------
mod : Union[PrimFunc, IRModule]
The input module to be compiled. Can be a PrimFunc or an IRModule containing
TIR or Relax functions.
target : Optional[Target]
The target platform to compile for.
relax_pipeline : Optional[Union[tvm.transform.Pass, Callable, str]]
The compilation pipeline to use for Relax functions.
Only used if the module contains Relax functions.
tir_pipeline : Optional[Union[tvm.transform.Pass, Callable, str]]
The compilation pipeline to use for TIR functions.

Returns
-------
Executable
A runtime executable that can be loaded and executed.
"""
# TODO(tvm-team): combine two path into unified one
if _contains_relax(mod):
return tvm.relax.build(
mod,
target,
relax_pipeline=relax_pipeline,
tir_pipeline=tir_pipeline,
)
lib = tvm.tir.build(mod, target, pipeline=tir_pipeline)
return Executable(lib)
6 changes: 3 additions & 3 deletions python/tvm/meta_schedule/relax_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,7 +382,7 @@ def compile_relax(
target: Union[Target, str],
params: Optional[Dict[str, NDArray]],
enable_warning: bool = False,
) -> "relax.Executable":
) -> "relax.VMExecutable":
"""Compile a relax program with a MetaSchedule database.

Parameters
Expand All @@ -401,8 +401,8 @@ def compile_relax(

Returns
-------
lib : relax.Executable
The built runtime module or vm Executable for the given relax workload.
lib : relax.VMExecutable
The built runtime module or vm VMExecutable for the given relax workload.
"""
# pylint: disable=import-outside-toplevel
from tvm.relax.transform import BindParams, MetaScheduleApplyDatabase
Expand Down
17 changes: 7 additions & 10 deletions python/tvm/meta_schedule/testing/custom_builder_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,21 +17,18 @@
"""Customized builder and runner methods"""
# pylint: disable=import-outside-toplevel

from typing import TYPE_CHECKING, Dict, Union, Callable
from typing import Dict, Union, Callable

if TYPE_CHECKING:
import numpy as np # type: ignore
from tvm.ir import IRModule
from tvm.meta_schedule.runner import EvaluatorConfig, RPCConfig
from tvm.runtime import Device, Module, NDArray
from tvm.target import Target
import numpy as np # type: ignore
from tvm.meta_schedule.runner import RPCConfig
from tvm.runtime import Module, Executable


def run_module_via_rpc(
rpc_config: "RPCConfig",
lib: Union["Module", "Executable"],
rpc_config: RPCConfig,
lib: Union[Module, Executable],
dev_type: str,
args: Union[Dict[int, "np.ndarray"], Dict[str, "np.ndarray"]],
args: Union[Dict[int, np.ndarray], Dict[str, np.ndarray]],
continuation: Callable,
):
"""Execute a tvm.runtime.Module on RPC remote"""
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/meta_schedule/testing/tune_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,8 @@ def f_calculator(

Parameters
----------
rt_mod : Union[tvm.runtime.Module, tvm.runtime.vm.Executable]
The runtime module or vm executable.
rt_mod : tvm.runtime.Module
The runtime module.
dev : tvm.device
The device type to run workload.
input_data : Dict[str, np.ndarray]
Expand Down
Loading
Loading