Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions python/tvm/ffi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from .ndarray import cpu, cuda, rocm, opencl, metal, vpi, vulkan, ext_dev, hexagon, webgpu
from .ndarray import from_dlpack, NDArray, Shape
from .container import Array, Map
from .module import Module, ModulePropertyMask, system_lib, load_module
from . import serialization
from . import access_path
from . import testing
Expand Down Expand Up @@ -71,4 +72,8 @@
"testing",
"access_path",
"serialization",
"Module",
"ModulePropertyMask",
"system_lib",
"load_module",
]
258 changes: 258 additions & 0 deletions python/tvm/ffi/module.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,258 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Module related objects and functions."""
# pylint: disable=invalid-name

from enum import IntEnum
from . import _ffi_api

from . import core
from .registry import register_object

__all__ = ["Module", "ModulePropertyMask", "system_lib", "load_module"]


class ModulePropertyMask(IntEnum):
"""Runtime Module Property Mask."""

BINARY_SERIALIZABLE = 0b001
RUNNABLE = 0b010
COMPILATION_EXPORTABLE = 0b100


@register_object("ffi.Module")
class Module(core.Object):
"""Runtime Module."""

def __new__(cls):
instance = super(Module, cls).__new__(cls) # pylint: disable=no-value-for-parameter
instance.entry_name = "__tvm_ffi_main__"
instance._entry = None
return instance

@property
def entry_func(self):
"""Get the entry function

Returns
-------
f : tvm.ffi.Function
The entry function if exist
"""
if self._entry:
return self._entry
self._entry = self.get_function("__tvm_ffi_main__")
return self._entry

@property
def kind(self):
"""Get type key of the module."""
return _ffi_api.ModuleGetKind(self)

@property
def imports(self):
"""Get imported modules

Returns
----------
modules : list of Module
The module
"""
return self.imports_

def implements_function(self, name, query_imports=False):
"""Returns True if the module has a definition for the global function with name. Note
that has_function(name) does not imply get_function(name) is non-null since the module
may be, eg, a CSourceModule which cannot supply a packed-func implementation of the function
without further compilation. However, get_function(name) non null should always imply
has_function(name).

Parameters
----------
name : str
The name of the function

query_imports : bool
Whether to also query modules imported by this module.

Returns
-------
b : Bool
True if module (or one of its imports) has a definition for name.
"""
return _ffi_api.ModuleImplementsFunction(self, name, query_imports)

def get_function(self, name, query_imports=False):
"""Get function from the module.

Parameters
----------
name : str
The name of the function

query_imports : bool
Whether also query modules imported by this module.

Returns
-------
f : tvm.ffi.Function
The result function.
"""
func = _ffi_api.ModuleGetFunction(self, name, query_imports)
if func is None:
raise AttributeError(f"Module has no function '{name}'")
return func

def import_module(self, module):
"""Add module to the import list of current one.

Parameters
----------
module : tvm.runtime.Module
The other module.
"""
_ffi_api.ModuleImportModule(self, module)

def __getitem__(self, name):
if not isinstance(name, str):
raise ValueError("Can only take string as function name")
return self.get_function(name)

def __call__(self, *args):
if self._entry:
return self._entry(*args)
# pylint: disable=not-callable
return self.entry_func(*args)

def inspect_source(self, fmt=""):
"""Get source code from module, if available.

Parameters
----------
fmt : str, optional
The specified format.

Returns
-------
source : str
The result source code.
"""
return _ffi_api.ModuleInspectSource(self, fmt)

def get_write_formats(self):
"""Get the format of the module."""
return _ffi_api.ModuleGetWriteFormats(self)

def get_property_mask(self):
"""Get the runtime module property mask. The mapping is stated in ModulePropertyMask.

Returns
-------
mask : int
Bitmask of runtime module property
"""
return _ffi_api.ModuleGetPropertyMask(self)

def is_binary_serializable(self):
"""Module 'binary serializable', save_to_bytes is supported.

Returns
-------
b : Bool
True if the module is binary serializable.
"""
return (self.get_property_mask() & ModulePropertyMask.BINARY_SERIALIZABLE) != 0

def is_runnable(self):
"""Module 'runnable', get_function is supported.

Returns
-------
b : Bool
True if the module is runnable.
"""
return (self.get_property_mask() & ModulePropertyMask.RUNNABLE) != 0

def is_compilation_exportable(self):
"""Module 'compilation exportable', write_to_file is supported for object or source.

Returns
-------
b : Bool
True if the module is compilation exportable.
"""
return (self.get_property_mask() & ModulePropertyMask.COMPILATION_EXPORTABLE) != 0

def clear_imports(self):
"""Remove all imports of the module."""
_ffi_api.ModuleClearImports(self)

def write_to_file(self, file_name, fmt=""):
"""Write the current module to file.

Parameters
----------
file_name : str
The name of the file.
fmt : str
The format of the file.

See Also
--------
runtime.Module.export_library : export the module to shared library.
"""
_ffi_api.ModuleWriteToFile(self, file_name, fmt)


def system_lib(symbol_prefix=""):
"""Get system-wide library module singleton.

System lib is a global module that contains self register functions in startup.
Unlike normal dso modules which need to be loaded explicitly.
It is useful in environments where dynamic loading api like dlopen is banned.

The system lib is intended to be linked and loaded during the entire life-cyle of the program.
If you want dynamic loading features, use dso modules instead.

Parameters
----------
symbol_prefix: Optional[str]
Optional symbol prefix that can be used for search. When we lookup a symbol
symbol_prefix + name will first be searched, then the name without symbol_prefix.

Returns
-------
module : runtime.Module
The system-wide library module.
"""
return _ffi_api.SystemLib(symbol_prefix)


def load_module(path):
"""Load module from file.

Parameters
----------
path : str
The path to the module file.

Returns
-------
module : ffi.Module
The loaded module
"""
return _ffi_api.ModuleLoadFromFile(path)
6 changes: 5 additions & 1 deletion python/tvm/relax/vm_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,10 @@ def _auto_attach_system_lib_prefix(
return tir_mod


def _is_device_module(mod: tvm.runtime.Module) -> bool:
return mod.kind in ["cuda", "opencl", "metal", "hip", "vulkan", "webgpu"]


def _vmlink(
builder: "relax.ExecBuilder",
target: Optional[Union[str, tvm.target.Target]],
Expand Down Expand Up @@ -153,7 +157,7 @@ def _vmlink(
tir_mod = _auto_attach_system_lib_prefix(tir_mod, target, system_lib)
lib = tvm.tir.build(tir_mod, target=target, pipeline=tir_pipeline)
for ext_mod in ext_libs:
if ext_mod.is_device_module():
if _is_device_module(ext_mod):
tir_ext_libs.append(ext_mod)
else:
relax_ext_libs.append(ext_mod)
Expand Down
46 changes: 26 additions & 20 deletions python/tvm/runtime/executable.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@
# pylint: disable=invalid-name, no-member

"""Executable object for TVM Runtime"""
from typing import Any, Callable, Dict, List, Optional, Union
from typing import Any, Callable, Dict, List, Optional

import tvm

from tvm.contrib import utils as _utils
from . import PackedFunc, Module

Expand Down Expand Up @@ -105,20 +106,27 @@ def _not_runnable(x):
# by collecting the link and allow export_library skip those modules.
workspace_dir = _utils.tempdir()
dso_path = workspace_dir.relpath("exported.so")
self.mod.export_library(dso_path, fcompile=fcompile, addons=addons, **kwargs)
self.export_library(dso_path, fcompile=fcompile, addons=addons, **kwargs)
self._jitted_mod = tvm.runtime.load_module(dso_path)
return self._jitted_mod

def export_library(
self,
file_name: str,
file_name,
*,
fcompile: Optional[Union[str, Callable[[str, List[str], Dict[str, Any]], None]]] = None,
addons: Optional[List[str]] = None,
workspace_dir: Optional[str] = None,
fcompile=None,
addons=None,
workspace_dir=None,
**kwargs,
) -> Any:
"""Export the executable to a library which can then be loaded back.
):
"""
Export the module and all imported modules into a single device library.

This function only works on host LLVM modules, other runtime::Module
subclasses will work with this API but they must support implement
the save and load mechanisms of modules completely including saving
from streams and files. This will pack your non-shared library module
into a single shared library which can later be loaded by TVM.

Parameters
----------
Expand All @@ -127,6 +135,15 @@ def export_library(

fcompile : function(target, file_list, kwargs), optional
The compilation function to use create the final library object during
export.

For example, when fcompile=_cc.create_shared, or when it is not supplied but
module is "llvm," this is used to link all produced artifacts
into a final dynamic library.

This behavior is controlled by the type of object exported.
If fcompile has attribute object_format, will compile host library
to that format. Otherwise, will use default format "o".

addons : list of str, optional
Additional object files to link against.
Expand All @@ -144,20 +161,9 @@ def export_library(
result of fcompile() : unknown, optional
If the compilation function returns an artifact it would be returned via
export_library, if any.

Examples
--------
.. code:: python

ex = tvm.compile(mod, target)
# export the library
ex.export_library("exported.so")

# load it back for future uses.
rt_mod = tvm.runtime.load_module("exported.so")
"""
return self.mod.export_library(
file_name=file_name,
file_name,
fcompile=fcompile,
addons=addons,
workspace_dir=workspace_dir,
Expand Down
Loading
Loading