Skip to content

Commit 9e8b46c

Browse files
committed
[FFI][REFACTOR] Establish ffi.Module in python
This PR refactors and establishes ffi.Module under the python tvm ffi api. Also moves export_library method to executable so it aligns more with compiled artifact.
1 parent b8eb80b commit 9e8b46c

File tree

6 files changed

+402
-343
lines changed

6 files changed

+402
-343
lines changed

python/tvm/ffi/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from .ndarray import cpu, cuda, rocm, opencl, metal, vpi, vulkan, ext_dev, hexagon, webgpu
3131
from .ndarray import from_dlpack, NDArray, Shape
3232
from .container import Array, Map
33+
from .module import Module, ModulePropertyMask, system_lib, load_module
3334
from . import serialization
3435
from . import access_path
3536
from . import testing
@@ -71,4 +72,8 @@
7172
"testing",
7273
"access_path",
7374
"serialization",
75+
"Module",
76+
"ModulePropertyMask",
77+
"system_lib",
78+
"load_module",
7479
]

python/tvm/ffi/module.py

Lines changed: 258 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,258 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
"""Module related objects and functions."""
18+
# pylint: disable=invalid-name
19+
20+
from enum import IntEnum
21+
from . import _ffi_api
22+
23+
from . import core
24+
from .registry import register_object
25+
26+
__all__ = ["Module", "ModulePropertyMask", "system_lib", "load_module"]
27+
28+
29+
class ModulePropertyMask(IntEnum):
30+
"""Runtime Module Property Mask."""
31+
32+
BINARY_SERIALIZABLE = 0b001
33+
RUNNABLE = 0b010
34+
COMPILATION_EXPORTABLE = 0b100
35+
36+
37+
@register_object("ffi.Module")
38+
class Module(core.Object):
39+
"""Runtime Module."""
40+
41+
def __new__(cls):
42+
instance = super(Module, cls).__new__(cls) # pylint: disable=no-value-for-parameter
43+
instance.entry_name = "__tvm_ffi_main__"
44+
instance._entry = None
45+
return instance
46+
47+
@property
48+
def entry_func(self):
49+
"""Get the entry function
50+
51+
Returns
52+
-------
53+
f : tvm.ffi.Function
54+
The entry function if exist
55+
"""
56+
if self._entry:
57+
return self._entry
58+
self._entry = self.get_function("__tvm_ffi_main__")
59+
return self._entry
60+
61+
@property
62+
def kind(self):
63+
"""Get type key of the module."""
64+
return _ffi_api.ModuleGetKind(self)
65+
66+
@property
67+
def imports(self):
68+
"""Get imported modules
69+
70+
Returns
71+
----------
72+
modules : list of Module
73+
The module
74+
"""
75+
return self.imports_
76+
77+
def implements_function(self, name, query_imports=False):
78+
"""Returns True if the module has a definition for the global function with name. Note
79+
that has_function(name) does not imply get_function(name) is non-null since the module
80+
may be, eg, a CSourceModule which cannot supply a packed-func implementation of the function
81+
without further compilation. However, get_function(name) non null should always imply
82+
has_function(name).
83+
84+
Parameters
85+
----------
86+
name : str
87+
The name of the function
88+
89+
query_imports : bool
90+
Whether to also query modules imported by this module.
91+
92+
Returns
93+
-------
94+
b : Bool
95+
True if module (or one of its imports) has a definition for name.
96+
"""
97+
return _ffi_api.ModuleImplementsFunction(self, name, query_imports)
98+
99+
def get_function(self, name, query_imports=False):
100+
"""Get function from the module.
101+
102+
Parameters
103+
----------
104+
name : str
105+
The name of the function
106+
107+
query_imports : bool
108+
Whether also query modules imported by this module.
109+
110+
Returns
111+
-------
112+
f : tvm.ffi.Function
113+
The result function.
114+
"""
115+
func = _ffi_api.ModuleGetFunction(self, name, query_imports)
116+
if func is None:
117+
raise AttributeError(f"Module has no function '{name}'")
118+
return func
119+
120+
def import_module(self, module):
121+
"""Add module to the import list of current one.
122+
123+
Parameters
124+
----------
125+
module : tvm.runtime.Module
126+
The other module.
127+
"""
128+
_ffi_api.ModuleImportModule(self, module)
129+
130+
def __getitem__(self, name):
131+
if not isinstance(name, str):
132+
raise ValueError("Can only take string as function name")
133+
return self.get_function(name)
134+
135+
def __call__(self, *args):
136+
if self._entry:
137+
return self._entry(*args)
138+
# pylint: disable=not-callable
139+
return self.entry_func(*args)
140+
141+
def inspect_source(self, fmt=""):
142+
"""Get source code from module, if available.
143+
144+
Parameters
145+
----------
146+
fmt : str, optional
147+
The specified format.
148+
149+
Returns
150+
-------
151+
source : str
152+
The result source code.
153+
"""
154+
return _ffi_api.ModuleInspectSource(self, fmt)
155+
156+
def get_write_formats(self):
157+
"""Get the format of the module."""
158+
return _ffi_api.ModuleGetWriteFormats(self)
159+
160+
def get_property_mask(self):
161+
"""Get the runtime module property mask. The mapping is stated in ModulePropertyMask.
162+
163+
Returns
164+
-------
165+
mask : int
166+
Bitmask of runtime module property
167+
"""
168+
return _ffi_api.ModuleGetPropertyMask(self)
169+
170+
def is_binary_serializable(self):
171+
"""Module 'binary serializable', save_to_bytes is supported.
172+
173+
Returns
174+
-------
175+
b : Bool
176+
True if the module is binary serializable.
177+
"""
178+
return (self.get_property_mask() & ModulePropertyMask.BINARY_SERIALIZABLE) != 0
179+
180+
def is_runnable(self):
181+
"""Module 'runnable', get_function is supported.
182+
183+
Returns
184+
-------
185+
b : Bool
186+
True if the module is runnable.
187+
"""
188+
return (self.get_property_mask() & ModulePropertyMask.RUNNABLE) != 0
189+
190+
def is_compilation_exportable(self):
191+
"""Module 'compilation exportable', write_to_file is supported for object or source.
192+
193+
Returns
194+
-------
195+
b : Bool
196+
True if the module is compilation exportable.
197+
"""
198+
return (self.get_property_mask() & ModulePropertyMask.COMPILATION_EXPORTABLE) != 0
199+
200+
def clear_imports(self):
201+
"""Remove all imports of the module."""
202+
_ffi_api.ModuleClearImports(self)
203+
204+
def write_to_file(self, file_name, fmt=""):
205+
"""Write the current module to file.
206+
207+
Parameters
208+
----------
209+
file_name : str
210+
The name of the file.
211+
fmt : str
212+
The format of the file.
213+
214+
See Also
215+
--------
216+
runtime.Module.export_library : export the module to shared library.
217+
"""
218+
_ffi_api.ModuleWriteToFile(self, file_name, fmt)
219+
220+
221+
def system_lib(symbol_prefix=""):
222+
"""Get system-wide library module singleton.
223+
224+
System lib is a global module that contains self register functions in startup.
225+
Unlike normal dso modules which need to be loaded explicitly.
226+
It is useful in environments where dynamic loading api like dlopen is banned.
227+
228+
The system lib is intended to be linked and loaded during the entire life-cyle of the program.
229+
If you want dynamic loading features, use dso modules instead.
230+
231+
Parameters
232+
----------
233+
symbol_prefix: Optional[str]
234+
Optional symbol prefix that can be used for search. When we lookup a symbol
235+
symbol_prefix + name will first be searched, then the name without symbol_prefix.
236+
237+
Returns
238+
-------
239+
module : runtime.Module
240+
The system-wide library module.
241+
"""
242+
return _ffi_api.SystemLib(symbol_prefix)
243+
244+
245+
def load_module(path):
246+
"""Load module from file.
247+
248+
Parameters
249+
----------
250+
path : str
251+
The path to the module file.
252+
253+
Returns
254+
-------
255+
module : ffi.Module
256+
The loaded module
257+
"""
258+
return _ffi_api.ModuleLoadFromFile(path)

python/tvm/relax/vm_build.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,10 @@ def _auto_attach_system_lib_prefix(
9999
return tir_mod
100100

101101

102+
def _is_device_module(mod: tvm.runtime.Module) -> bool:
103+
return mod.kind in ["cuda", "opencl", "metal", "hip", "vulkan", "webgpu"]
104+
105+
102106
def _vmlink(
103107
builder: "relax.ExecBuilder",
104108
target: Optional[Union[str, tvm.target.Target]],
@@ -153,7 +157,7 @@ def _vmlink(
153157
tir_mod = _auto_attach_system_lib_prefix(tir_mod, target, system_lib)
154158
lib = tvm.tir.build(tir_mod, target=target, pipeline=tir_pipeline)
155159
for ext_mod in ext_libs:
156-
if ext_mod.is_device_module():
160+
if _is_device_module(ext_mod):
157161
tir_ext_libs.append(ext_mod)
158162
else:
159163
relax_ext_libs.append(ext_mod)

python/tvm/runtime/executable.py

Lines changed: 27 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,10 @@
1717
# pylint: disable=invalid-name, no-member
1818

1919
"""Executable object for TVM Runtime"""
20-
from typing import Any, Callable, Dict, List, Optional, Union
20+
from typing import Any, Callable, Dict, List, Optional
2121

2222
import tvm
23+
2324
from tvm.contrib import utils as _utils
2425
from . import PackedFunc, Module
2526

@@ -94,7 +95,7 @@ def _not_runnable(x):
9495
return x.kind in ("c", "static_library")
9596

9697
# pylint:disable = protected-access
97-
not_runnable_list = self.mod._collect_from_import_tree(_not_runnable)
98+
not_runnable_list = self._collect_from_import_tree(_not_runnable)
9899

99100
# everything is runnable, directly return mod.
100101
if len(not_runnable_list) == 0:
@@ -105,20 +106,27 @@ def _not_runnable(x):
105106
# by collecting the link and allow export_library skip those modules.
106107
workspace_dir = _utils.tempdir()
107108
dso_path = workspace_dir.relpath("exported.so")
108-
self.mod.export_library(dso_path, fcompile=fcompile, addons=addons, **kwargs)
109+
self.export_library(dso_path, fcompile=fcompile, addons=addons, **kwargs)
109110
self._jitted_mod = tvm.runtime.load_module(dso_path)
110111
return self._jitted_mod
111112

112113
def export_library(
113114
self,
114-
file_name: str,
115+
file_name,
115116
*,
116-
fcompile: Optional[Union[str, Callable[[str, List[str], Dict[str, Any]], None]]] = None,
117-
addons: Optional[List[str]] = None,
118-
workspace_dir: Optional[str] = None,
117+
fcompile=None,
118+
addons=None,
119+
workspace_dir=None,
119120
**kwargs,
120-
) -> Any:
121-
"""Export the executable to a library which can then be loaded back.
121+
):
122+
"""
123+
Export the module and all imported modules into a single device library.
124+
125+
This function only works on host LLVM modules, other runtime::Module
126+
subclasses will work with this API but they must support implement
127+
the save and load mechanisms of modules completely including saving
128+
from streams and files. This will pack your non-shared library module
129+
into a single shared library which can later be loaded by TVM.
122130
123131
Parameters
124132
----------
@@ -127,6 +135,15 @@ def export_library(
127135
128136
fcompile : function(target, file_list, kwargs), optional
129137
The compilation function to use create the final library object during
138+
export.
139+
140+
For example, when fcompile=_cc.create_shared, or when it is not supplied but
141+
module is "llvm," this is used to link all produced artifacts
142+
into a final dynamic library.
143+
144+
This behavior is controlled by the type of object exported.
145+
If fcompile has attribute object_format, will compile host library
146+
to that format. Otherwise, will use default format "o".
130147
131148
addons : list of str, optional
132149
Additional object files to link against.
@@ -144,20 +161,9 @@ def export_library(
144161
result of fcompile() : unknown, optional
145162
If the compilation function returns an artifact it would be returned via
146163
export_library, if any.
147-
148-
Examples
149-
--------
150-
.. code:: python
151-
152-
ex = tvm.compile(mod, target)
153-
# export the library
154-
ex.export_library("exported.so")
155-
156-
# load it back for future uses.
157-
rt_mod = tvm.runtime.load_module("exported.so")
158164
"""
159165
return self.mod.export_library(
160-
file_name=file_name,
166+
file_name,
161167
fcompile=fcompile,
162168
addons=addons,
163169
workspace_dir=workspace_dir,

0 commit comments

Comments
 (0)