|
| 1 | +# Licensed to the Apache Software Foundation (ASF) under one |
| 2 | +# or more contributor license agreements. See the NOTICE file |
| 3 | +# distributed with this work for additional information |
| 4 | +# regarding copyright ownership. The ASF licenses this file |
| 5 | +# to you under the Apache License, Version 2.0 (the |
| 6 | +# "License"); you may not use this file except in compliance |
| 7 | +# with the License. You may obtain a copy of the License at |
| 8 | +# |
| 9 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 10 | +# |
| 11 | +# Unless required by applicable law or agreed to in writing, |
| 12 | +# software distributed under the License is distributed on an |
| 13 | +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| 14 | +# KIND, either express or implied. See the License for the |
| 15 | +# specific language governing permissions and limitations |
| 16 | +# under the License. |
| 17 | +"""Module related objects and functions.""" |
| 18 | +# pylint: disable=invalid-name |
| 19 | + |
| 20 | +from enum import IntEnum |
| 21 | +from . import _ffi_api |
| 22 | + |
| 23 | +from . import core |
| 24 | +from .registry import register_object |
| 25 | + |
| 26 | +__all__ = ["Module", "ModulePropertyMask", "system_lib", "load_module"] |
| 27 | + |
| 28 | + |
| 29 | +class ModulePropertyMask(IntEnum): |
| 30 | + """Runtime Module Property Mask.""" |
| 31 | + |
| 32 | + BINARY_SERIALIZABLE = 0b001 |
| 33 | + RUNNABLE = 0b010 |
| 34 | + COMPILATION_EXPORTABLE = 0b100 |
| 35 | + |
| 36 | + |
| 37 | +@register_object("ffi.Module") |
| 38 | +class Module(core.Object): |
| 39 | + """Runtime Module.""" |
| 40 | + |
| 41 | + def __new__(cls): |
| 42 | + instance = super(Module, cls).__new__(cls) # pylint: disable=no-value-for-parameter |
| 43 | + instance.entry_name = "__tvm_ffi_main__" |
| 44 | + instance._entry = None |
| 45 | + return instance |
| 46 | + |
| 47 | + @property |
| 48 | + def entry_func(self): |
| 49 | + """Get the entry function |
| 50 | +
|
| 51 | + Returns |
| 52 | + ------- |
| 53 | + f : tvm.ffi.Function |
| 54 | + The entry function if exist |
| 55 | + """ |
| 56 | + if self._entry: |
| 57 | + return self._entry |
| 58 | + self._entry = self.get_function("__tvm_ffi_main__") |
| 59 | + return self._entry |
| 60 | + |
| 61 | + @property |
| 62 | + def kind(self): |
| 63 | + """Get type key of the module.""" |
| 64 | + return _ffi_api.ModuleGetKind(self) |
| 65 | + |
| 66 | + @property |
| 67 | + def imports(self): |
| 68 | + """Get imported modules |
| 69 | +
|
| 70 | + Returns |
| 71 | + ---------- |
| 72 | + modules : list of Module |
| 73 | + The module |
| 74 | + """ |
| 75 | + return self.imports_ |
| 76 | + |
| 77 | + def implements_function(self, name, query_imports=False): |
| 78 | + """Returns True if the module has a definition for the global function with name. Note |
| 79 | + that has_function(name) does not imply get_function(name) is non-null since the module |
| 80 | + may be, eg, a CSourceModule which cannot supply a packed-func implementation of the function |
| 81 | + without further compilation. However, get_function(name) non null should always imply |
| 82 | + has_function(name). |
| 83 | +
|
| 84 | + Parameters |
| 85 | + ---------- |
| 86 | + name : str |
| 87 | + The name of the function |
| 88 | +
|
| 89 | + query_imports : bool |
| 90 | + Whether to also query modules imported by this module. |
| 91 | +
|
| 92 | + Returns |
| 93 | + ------- |
| 94 | + b : Bool |
| 95 | + True if module (or one of its imports) has a definition for name. |
| 96 | + """ |
| 97 | + return _ffi_api.ModuleImplementsFunction(self, name, query_imports) |
| 98 | + |
| 99 | + def get_function(self, name, query_imports=False): |
| 100 | + """Get function from the module. |
| 101 | +
|
| 102 | + Parameters |
| 103 | + ---------- |
| 104 | + name : str |
| 105 | + The name of the function |
| 106 | +
|
| 107 | + query_imports : bool |
| 108 | + Whether also query modules imported by this module. |
| 109 | +
|
| 110 | + Returns |
| 111 | + ------- |
| 112 | + f : tvm.ffi.Function |
| 113 | + The result function. |
| 114 | + """ |
| 115 | + func = _ffi_api.ModuleGetFunction(self, name, query_imports) |
| 116 | + if func is None: |
| 117 | + raise AttributeError(f"Module has no function '{name}'") |
| 118 | + return func |
| 119 | + |
| 120 | + def import_module(self, module): |
| 121 | + """Add module to the import list of current one. |
| 122 | +
|
| 123 | + Parameters |
| 124 | + ---------- |
| 125 | + module : tvm.runtime.Module |
| 126 | + The other module. |
| 127 | + """ |
| 128 | + _ffi_api.ModuleImportModule(self, module) |
| 129 | + |
| 130 | + def __getitem__(self, name): |
| 131 | + if not isinstance(name, str): |
| 132 | + raise ValueError("Can only take string as function name") |
| 133 | + return self.get_function(name) |
| 134 | + |
| 135 | + def __call__(self, *args): |
| 136 | + if self._entry: |
| 137 | + return self._entry(*args) |
| 138 | + # pylint: disable=not-callable |
| 139 | + return self.entry_func(*args) |
| 140 | + |
| 141 | + def inspect_source(self, fmt=""): |
| 142 | + """Get source code from module, if available. |
| 143 | +
|
| 144 | + Parameters |
| 145 | + ---------- |
| 146 | + fmt : str, optional |
| 147 | + The specified format. |
| 148 | +
|
| 149 | + Returns |
| 150 | + ------- |
| 151 | + source : str |
| 152 | + The result source code. |
| 153 | + """ |
| 154 | + return _ffi_api.ModuleInspectSource(self, fmt) |
| 155 | + |
| 156 | + def get_write_formats(self): |
| 157 | + """Get the format of the module.""" |
| 158 | + return _ffi_api.ModuleGetWriteFormats(self) |
| 159 | + |
| 160 | + def get_property_mask(self): |
| 161 | + """Get the runtime module property mask. The mapping is stated in ModulePropertyMask. |
| 162 | +
|
| 163 | + Returns |
| 164 | + ------- |
| 165 | + mask : int |
| 166 | + Bitmask of runtime module property |
| 167 | + """ |
| 168 | + return _ffi_api.ModuleGetPropertyMask(self) |
| 169 | + |
| 170 | + def is_binary_serializable(self): |
| 171 | + """Module 'binary serializable', save_to_bytes is supported. |
| 172 | +
|
| 173 | + Returns |
| 174 | + ------- |
| 175 | + b : Bool |
| 176 | + True if the module is binary serializable. |
| 177 | + """ |
| 178 | + return (self.get_property_mask() & ModulePropertyMask.BINARY_SERIALIZABLE) != 0 |
| 179 | + |
| 180 | + def is_runnable(self): |
| 181 | + """Module 'runnable', get_function is supported. |
| 182 | +
|
| 183 | + Returns |
| 184 | + ------- |
| 185 | + b : Bool |
| 186 | + True if the module is runnable. |
| 187 | + """ |
| 188 | + return (self.get_property_mask() & ModulePropertyMask.RUNNABLE) != 0 |
| 189 | + |
| 190 | + def is_compilation_exportable(self): |
| 191 | + """Module 'compilation exportable', write_to_file is supported for object or source. |
| 192 | +
|
| 193 | + Returns |
| 194 | + ------- |
| 195 | + b : Bool |
| 196 | + True if the module is compilation exportable. |
| 197 | + """ |
| 198 | + return (self.get_property_mask() & ModulePropertyMask.COMPILATION_EXPORTABLE) != 0 |
| 199 | + |
| 200 | + def clear_imports(self): |
| 201 | + """Remove all imports of the module.""" |
| 202 | + _ffi_api.ModuleClearImports(self) |
| 203 | + |
| 204 | + def write_to_file(self, file_name, fmt=""): |
| 205 | + """Write the current module to file. |
| 206 | +
|
| 207 | + Parameters |
| 208 | + ---------- |
| 209 | + file_name : str |
| 210 | + The name of the file. |
| 211 | + fmt : str |
| 212 | + The format of the file. |
| 213 | +
|
| 214 | + See Also |
| 215 | + -------- |
| 216 | + runtime.Module.export_library : export the module to shared library. |
| 217 | + """ |
| 218 | + _ffi_api.ModuleWriteToFile(self, file_name, fmt) |
| 219 | + |
| 220 | + |
| 221 | +def system_lib(symbol_prefix=""): |
| 222 | + """Get system-wide library module singleton. |
| 223 | +
|
| 224 | + System lib is a global module that contains self register functions in startup. |
| 225 | + Unlike normal dso modules which need to be loaded explicitly. |
| 226 | + It is useful in environments where dynamic loading api like dlopen is banned. |
| 227 | +
|
| 228 | + The system lib is intended to be linked and loaded during the entire life-cyle of the program. |
| 229 | + If you want dynamic loading features, use dso modules instead. |
| 230 | +
|
| 231 | + Parameters |
| 232 | + ---------- |
| 233 | + symbol_prefix: Optional[str] |
| 234 | + Optional symbol prefix that can be used for search. When we lookup a symbol |
| 235 | + symbol_prefix + name will first be searched, then the name without symbol_prefix. |
| 236 | +
|
| 237 | + Returns |
| 238 | + ------- |
| 239 | + module : runtime.Module |
| 240 | + The system-wide library module. |
| 241 | + """ |
| 242 | + return _ffi_api.SystemLib(symbol_prefix) |
| 243 | + |
| 244 | + |
| 245 | +def load_module(path): |
| 246 | + """Load module from file. |
| 247 | +
|
| 248 | + Parameters |
| 249 | + ---------- |
| 250 | + path : str |
| 251 | + The path to the module file. |
| 252 | +
|
| 253 | + Returns |
| 254 | + ------- |
| 255 | + module : ffi.Module |
| 256 | + The loaded module |
| 257 | + """ |
| 258 | + return _ffi_api.ModuleLoadFromFile(path) |
0 commit comments