Skip to content

Commit

Permalink
[TVM EP] Improved usability of TVM EP (microsoft#10241)
Browse files Browse the repository at this point in the history
* improved usability of TVM EP
* moved technical import under a condition related to TVM EP only
* Revert "moved technical import under a condition related to TVM EP only"
* add conditional _ld_preload.py file extension for TVM EP
* improve readability of inserted code

(cherry picked from commit a0fe4a7)
  • Loading branch information
KJlaccHoeUM9l authored and Peter Salas committed Nov 7, 2022
1 parent 39269f7 commit aa65c30
Showing 1 changed file with 30 additions and 0 deletions.
30 changes: 30 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import platform
import subprocess
import sys
import textwrap
import datetime

from pathlib import Path
Expand Down Expand Up @@ -145,6 +146,33 @@ def _rewrite_ld_preload_tensorrt(self, to_preload):
f.write(' import os\n')
f.write(' os.environ["ORT_TENSORRT_UNAVAILABLE"] = "1"\n')

def _rewrite_ld_preload_tvm(self):
with open('onnxruntime/capi/_ld_preload.py', 'a') as f:
f.write(textwrap.dedent(
"""
import warnings
try:
# This import is necessary in order to delegate the loading of libtvm.so to TVM.
import tvm
except ImportError as e:
warnings.warn(
f"WARNING: Failed to import TVM, libtvm.so was not loaded. More details: {e}"
)
try:
# Working between the C++ and Python parts in TVM EP is done using the PackedFunc and
# Registry classes. In order to use a Python function in C++ code, it must be registered in
# the global table of functions. Registration is carried out through the JIT interface,
# so it is necessary to call special functions for registration.
# To do this, we need to make the following import.
import onnxruntime.providers.stvm
except ImportError as e:
warnings.warn(
f"WARNING: Failed to register python functions to work with TVM EP. More details: {e}"
)
"""
))

def run(self):
if is_manylinux:
source = 'onnxruntime/capi/onnxruntime_pybind11_state.so'
Expand Down Expand Up @@ -207,6 +235,8 @@ def run(self):
self._rewrite_ld_preload(to_preload)
self._rewrite_ld_preload_cuda(to_preload_cuda)
self._rewrite_ld_preload_tensorrt(to_preload_tensorrt)
if package_name == 'onnxruntime-tvm':
self._rewrite_ld_preload_tvm()
_bdist_wheel.run(self)
if is_manylinux and not disable_auditwheel_repair:
file = glob(path.join(self.dist_dir, '*linux*.whl'))[0]
Expand Down

0 comments on commit aa65c30

Please sign in to comment.