From aa65c3077c306816a637c565e57ae798b6d238e9 Mon Sep 17 00:00:00 2001 From: Alexey Gladyshev Date: Tue, 25 Jan 2022 20:48:08 +0300 Subject: [PATCH] [TVM EP] Improved usability of TVM EP (#10241) * improved usability of TVM EP * moved technical import under a condition related to TVM EP only * Revert "moved technical import under a condition related to TVM EP only" * add conditional _ld_preload.py file extension for TVM EP * improve readability of inserted code (cherry picked from commit a0fe4a7c1c156d43840af42111a559dbbd9753ed) --- setup.py | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/setup.py b/setup.py index 2a1ec7f25d9b5..6130fcdbba2fb 100644 --- a/setup.py +++ b/setup.py @@ -12,6 +12,7 @@ import platform import subprocess import sys +import textwrap import datetime from pathlib import Path @@ -145,6 +146,33 @@ def _rewrite_ld_preload_tensorrt(self, to_preload): f.write(' import os\n') f.write(' os.environ["ORT_TENSORRT_UNAVAILABLE"] = "1"\n') + def _rewrite_ld_preload_tvm(self): + with open('onnxruntime/capi/_ld_preload.py', 'a') as f: + f.write(textwrap.dedent( + """ + import warnings + + try: + # This import is necessary in order to delegate the loading of libtvm.so to TVM. + import tvm + except ImportError as e: + warnings.warn( + f"WARNING: Failed to import TVM, libtvm.so was not loaded. More details: {e}" + ) + try: + # Working between the C++ and Python parts in TVM EP is done using the PackedFunc and + # Registry classes. In order to use a Python function in C++ code, it must be registered in + # the global table of functions. Registration is carried out through the JIT interface, + # so it is necessary to call special functions for registration. + # To do this, we need to make the following import. + import onnxruntime.providers.stvm + except ImportError as e: + warnings.warn( + f"WARNING: Failed to register python functions to work with TVM EP. More details: {e}" + ) + """ + )) + def run(self): if is_manylinux: source = 'onnxruntime/capi/onnxruntime_pybind11_state.so' @@ -207,6 +235,8 @@ def run(self): self._rewrite_ld_preload(to_preload) self._rewrite_ld_preload_cuda(to_preload_cuda) self._rewrite_ld_preload_tensorrt(to_preload_tensorrt) + if package_name == 'onnxruntime-tvm': + self._rewrite_ld_preload_tvm() _bdist_wheel.run(self) if is_manylinux and not disable_auditwheel_repair: file = glob(path.join(self.dist_dir, '*linux*.whl'))[0]