diff --git a/python/tvm/libinfo.py b/python/tvm/libinfo.py index f9f28b6853e2..69429179fc69 100644 --- a/python/tvm/libinfo.py +++ b/python/tvm/libinfo.py @@ -195,7 +195,9 @@ def find_include_path(name=None, search_path=None, optional=False): include_path : list(string) List of all found paths to header files. """ - if os.environ.get("TVM_HOME", None): + if os.environ.get("TVM_SOURCE_DIR", None): + source_dir = os.environ["TVM_SOURCE_DIR"] + elif os.environ.get("TVM_HOME", None): source_dir = os.environ["TVM_HOME"] else: ffi_dir = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) @@ -204,7 +206,7 @@ def find_include_path(name=None, search_path=None, optional=False): if os.path.isdir(os.path.join(source_dir, "include")): break else: - raise AssertionError("Cannot find the source directory given ffi_dir: {ffi_dir}") + raise AssertionError(f"Cannot find the source directory given ffi_dir: {ffi_dir}") third_party_dir = os.path.join(source_dir, "3rdparty") header_path = [] diff --git a/python/tvm/relax/backend/cuda/flashinfer.py b/python/tvm/relax/backend/cuda/flashinfer.py index 0f81675a8fb9..1fea39e9a221 100644 --- a/python/tvm/relax/backend/cuda/flashinfer.py +++ b/python/tvm/relax/backend/cuda/flashinfer.py @@ -24,6 +24,8 @@ from pathlib import Path from typing import List +import tvm_ffi + import tvm from tvm.target import Target @@ -124,17 +126,51 @@ def get_object_file_path(src: Path) -> Path: # ------------------------------------------------------------------------ # 2) Include paths # ------------------------------------------------------------------------ - tvm_home = os.environ["TVM_SOURCE_DIR"] include_paths = [ FLASHINFER_INCLUDE_DIR, FLASHINFER_CSRC_DIR, FLASHINFER_TVM_BINDING_DIR, - Path(tvm_home).resolve() / "include", - Path(tvm_home).resolve() / "ffi" / "include", - Path(tvm_home).resolve() / "ffi" / "3rdparty" / "dlpack" / "include", - Path(tvm_home).resolve() / "3rdparty" / "dmlc-core" / "include", ] + CUTLASS_INCLUDE_DIRS + if os.environ.get("TVM_SOURCE_DIR", None) or os.environ.get("TVM_HOME", None): + # Respect TVM_SOURCE_DIR and TVM_HOME if they are set + tvm_home = ( + os.environ["TVM_SOURCE_DIR"] + if os.environ.get("TVM_SOURCE_DIR", None) + else os.environ["TVM_HOME"] + ) + include_paths += [ + Path(tvm_home).resolve() / "include", + Path(tvm_home).resolve() / "ffi" / "include", + Path(tvm_home).resolve() / "ffi" / "3rdparty" / "dlpack" / "include", + Path(tvm_home).resolve() / "3rdparty" / "dmlc-core" / "include", + ] + else: + # If TVM_SOURCE_DIR and TVM_HOME are not set, use the default TVM package path + tvm_package_path = Path(tvm.__file__).resolve().parent + if (tvm_package_path / "include").exists(): + # The package is installed from pip. + tvm_ffi_package_path = Path(tvm_ffi.__file__).resolve().parent + include_paths += [ + tvm_package_path / "include", + tvm_package_path / "3rdparty" / "dmlc-core" / "include", + tvm_ffi_package_path / "include", + ] + elif (tvm_package_path.parent.parent / "include").exists(): + # The package is installed from source. + include_paths += [ + tvm_package_path.parent.parent / "include", + tvm_package_path.parent.parent / "ffi" / "include", + tvm_package_path.parent.parent / "ffi" / "3rdparty" / "dlpack" / "include", + tvm_package_path.parent.parent / "3rdparty" / "dmlc-core" / "include", + ] + else: + # warning: TVM is not installed in the system. + print( + "Warning: Include path for TVM cannot be found. " + "FlashInfer kernel compilation may fail due to missing headers." + ) + # ------------------------------------------------------------------------ # 3) Function to compile a single source file # ------------------------------------------------------------------------