win32+clang support
 * based on Windows support PR triton-lang#2465 by @andreigh
   - triton-lang#2465
 * manually applied, rebased, and fixed lint errors
 * remove '-A' platform option to use ninja
 * use sysconfig.get_config_var() to get the path of python*.lib
 * clang fix for windows
 * remove '-fPIC' for windows clang
 * fix download_and_copy() to support windows
 * add "exe" extension for windows
 * use "pyd" extension for windows to make importlib work
 * third_party/nvidia: fix for windows
 * win32: fix _path_to_binary()
 * add library_dir, include_dir for win32
 * backend/compiler lazy remove temp files to support windows
 * additional work done by @mantaionut (2024/05/31)
 * rework for latest triton and cleanup (2024/10/14)
 * extract minimal fixes to support win32+clang (2024/10/16)
 * get exe/so extension using sysconfig (suggested by @anmyachev)

see also:
 intel/intel-xpu-backend-for-triton#2478

Original-author-by: Andrei Gheorghe <andrei@dharmaventures.co>
Signed-off-by: Won-Kyu Park <wkpark@gmail.com>
wkpark committed Oct 19, 2024
1 parent ca451d5 commit 4d691d7
Showing 6 changed files with 55 additions and 25 deletions.

python/setup.py: 32 changes (21 additions & 11 deletions)

@@ -167,7 +167,7 @@ def get_json_package_info():
 def get_llvm_package_info():
     system = platform.system()
     try:
-        arch = {"x86_64": "x64", "arm64": "arm64", "aarch64": "arm64"}[platform.machine()]
+        arch = {"x86_64": "x64", "AMD64": "x64", "arm64": "arm64", "aarch64": "arm64"}[platform.machine()]
     except KeyError:
         arch = platform.machine()
     if system == "Darwin":
@@ -196,6 +196,8 @@ def get_llvm_package_info():
                 f"LLVM pre-compiled image is not available for {system}-{arch}. Proceeding with user-configured LLVM from source build."
             )
             return Package("llvm", "LLVM-C.lib", "", "LLVM_INCLUDE_DIRS", "LLVM_LIBRARY_DIR", "LLVM_SYSPATH")
+    elif system == "Windows":
+        system_suffix = f"windows-{arch}"
     else:
         print(
             f"LLVM pre-compiled image is not available for {system}-{arch}. Proceeding with user-configured LLVM from source build."
@@ -281,18 +283,21 @@ def download_and_copy(name, src_path, dst_path, variable, version, url_func):
     base_dir = os.path.dirname(__file__)
     system = platform.system()
     try:
-        arch = {"x86_64": "64", "arm64": "aarch64", "aarch64": "aarch64"}[platform.machine()]
+        arch = {"x86_64": "64", "AMD64": "64", "arm64": "aarch64", "aarch64": "aarch64"}[platform.machine()]
     except KeyError:
         arch = platform.machine()
-    supported = {"Linux": "linux", "Darwin": "linux"}
+    supported = {"Linux": "linux", "Darwin": "linux", "Windows": "win"}
+    if system not in supported:
+        return
+
     url = url_func(supported[system], arch, version)
     tmp_path = os.path.join(triton_cache_path, "nvidia", name)  # path to cache the download
     dst_path = os.path.join(base_dir, os.pardir, "third_party", "nvidia", "backend", dst_path)  # final binary path
     platform_name = "sbsa-linux" if arch == "aarch64" else "x86_64-linux"
     src_path = src_path(platform_name, version) if callable(src_path) else src_path
     src_path = os.path.join(tmp_path, src_path)
     download = not os.path.exists(src_path)
-    if os.path.exists(dst_path) and system == "Linux" and shutil.which(dst_path) is not None:
+    if os.path.exists(dst_path) and system in ("Linux", "Windows") and shutil.which(dst_path) is not None:
         curr_version = subprocess.check_output([dst_path, "--version"]).decode("utf-8").strip()
         curr_version = re.search(r"V([.|\d]+)", curr_version).group(1)
         download = download or curr_version != version
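
Note: the version check above runs the cached binary with --version and extracts the version number with the regex shown. A minimal sketch of that parse, using an illustrative ptxas banner (the sample string is an assumption, not captured output):

    import re

    # Illustrative ptxas --version banner; exact wording varies by release.
    sample = "Cuda compilation tools, release 12.4, V12.4.131"
    match = re.search(r"V([.|\d]+)", sample)
    print(match.group(1))  # -> "12.4.131", compared against the pinned version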
@@ -421,6 +426,10 @@ def build_extension(self, ext):
             "-DTRITON_CODEGEN_BACKENDS=" + ';'.join([b.name for b in backends if not b.is_external]),
             "-DTRITON_PLUGIN_DIRS=" + ';'.join([b.src_dir for b in backends if b.is_external])
         ]
+        if platform.system() == "Windows":
+            installed_base = sysconfig.get_config_var('installed_base')
+            py_lib_dirs = os.getenv("PYTHON_LIB_DIRS", os.path.join(installed_base, "libs"))
+            cmake_args.append("-DPYTHON_LIB_DIRS=" + py_lib_dirs)
         if lit_dir is not None:
             cmake_args.append("-DLLVM_EXTERNAL_LIT=" + lit_dir)
         cmake_args.extend(thirdparty_cmake_args)
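
Note: on Windows the linker needs the Python import library (python312.lib and friends), which a standard python.org install keeps under <installed_base>\libs and which CMake does not find on its own; PYTHON_LIB_DIRS passes it through. A sketch of what the lookup resolves to (the path in the comment is illustrative):

    import os
    import sysconfig

    # installed_base is the root of the Python installation, e.g.
    # C:\Users\me\AppData\Local\Programs\Python\Python312 (illustrative).
    installed_base = sysconfig.get_config_var("installed_base")
    py_lib_dirs = os.getenv("PYTHON_LIB_DIRS", os.path.join(installed_base, "libs"))
    print(py_lib_dirs)  # directory expected to contain python312.lib on Windows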
@@ -430,9 +439,8 @@ def build_extension(self, ext):
         build_args = ["--config", cfg]
 
         if platform.system() == "Windows":
+            cmake_args += ["-DCMAKE_BUILD_TYPE=" + cfg]
             cmake_args += [f"-DCMAKE_RUNTIME_OUTPUT_DIRECTORY_{cfg.upper()}={extdir}"]
-            if sys.maxsize > 2**32:
-                cmake_args += ["-A", "x64"]
         else:
             cmake_args += ["-DCMAKE_BUILD_TYPE=" + cfg]
         max_jobs = os.getenv("MAX_JOBS", str(2 * os.cpu_count()))
@@ -499,8 +507,10 @@ def get_platform_dependent_src_path(subdir):
      if int(version_major) >= 12 and int(version_minor1) >= 5 else subdir)(*version.split('.')))
 
 
+exe = sysconfig.get_config_var("EXE")
+
 download_and_copy(
-    name="ptxas", src_path="bin/ptxas", dst_path="bin/ptxas", variable="TRITON_PTXAS_PATH",
+    name="ptxas", src_path=f"bin/ptxas{exe}", dst_path=f"bin/ptxas{exe}", variable="TRITON_PTXAS_PATH",
     version=NVIDIA_TOOLCHAIN_VERSION["ptxas"], url_func=lambda system, arch, version:
     ((lambda version_major, version_minor1, version_minor2:
      f"https://anaconda.org/nvidia/cuda-nvcc-tools/{version}/download/{system}-{arch}/cuda-nvcc-tools-{version}-0.tar.bz2"
@@ -509,17 +519,17 @@ def get_platform_dependent_src_path(subdir):
      (*version.split('.'))))
 download_and_copy(
     name="cuobjdump",
-    src_path="bin/cuobjdump",
-    dst_path="bin/cuobjdump",
+    src_path=f"bin/cuobjdump{exe}",
+    dst_path=f"bin/cuobjdump{exe}",
     variable="TRITON_CUOBJDUMP_PATH",
     version=NVIDIA_TOOLCHAIN_VERSION["cuobjdump"],
     url_func=lambda system, arch, version:
     f"https://anaconda.org/nvidia/cuda-cuobjdump/{version}/download/{system}-{arch}/cuda-cuobjdump-{version}-0.tar.bz2",
 )
 download_and_copy(
     name="nvdisasm",
-    src_path="bin/nvdisasm",
-    dst_path="bin/nvdisasm",
+    src_path=f"bin/nvdisasm{exe}",
+    dst_path=f"bin/nvdisasm{exe}",
     variable="TRITON_NVDISASM_PATH",
     version=NVIDIA_TOOLCHAIN_VERSION["nvdisasm"],
     url_func=lambda system, arch, version:

python/triton/backends/compiler.py: 11 changes (8 additions & 3 deletions)

@@ -2,6 +2,7 @@
 import re
 import hashlib
 import subprocess
+import sysconfig
 
 from abc import ABCMeta, abstractmethod, abstractclassmethod
 from dataclasses import dataclass
@@ -228,20 +229,24 @@ def __init__(self, target: GPUTarget) -> None:
 
     @staticmethod
     def _path_to_binary(binary: str):
+        exe = sysconfig.get_config_var("EXE")
         base_dir = os.path.join(os.path.dirname(__file__), os.pardir)
         paths = [
             os.environ.get(f"TRITON_{binary.upper()}_PATH", ""),
-            os.path.join(base_dir, "third_party", "cuda", "bin", binary),
+            os.path.join(base_dir, "third_party", "cuda", "bin", f"{binary}{exe}"),
         ]
         for p in paths:
-            bin = p.split(" ")[0]
+            if os.name != "nt":
+                bin = p.split(" ")[0]
+            else:
+                bin = p
             if os.path.exists(bin) and os.path.isfile(bin):
                 result = subprocess.check_output([bin, "--version"], stderr=subprocess.STDOUT)
                 if result is not None:
                     version = re.search(r".*release (\d+\.\d+).*", result.decode("utf-8"), flags=re.MULTILINE)
                     if version is not None:
                         return p, version.group(1)
-        raise RuntimeError(f"Cannot find {binary}")
+        raise RuntimeError(f"Cannot find {binary}{exe}")
 
     @abstractclassmethod
     def supports_target(target: GPUTarget):
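
Note: the os.name guard above exists because splitting on the first space truncates Windows paths, which routinely contain spaces. A one-line illustration with a hypothetical install path:

    # Hypothetical Windows path; everything after the first space is lost.
    p = r"C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4\bin\ptxas.exe"
    print(p.split(" ")[0])  # -> 'C:\Program' (not an existing file, so lookup fails)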

python/triton/compiler/compiler.py: 4 changes (3 additions & 1 deletion)

@@ -15,6 +15,7 @@
 import re
 import functools
 import os
+import sysconfig
 
 # - ^\s*tt\.func\s+ : match the start of the string, any leading whitespace, the keyword func,
 #   and any following whitespace
@@ -151,7 +152,8 @@ def triton_key():
 
     # backend
     libtriton_hash = hashlib.sha256()
-    with open(os.path.join(TRITON_PATH, "_C/libtriton.so"), "rb") as f:
+    so_ext = sysconfig.get_config_var("EXT_SUFFIX").split(".")[-1]
+    with open(os.path.join(TRITON_PATH, "_C", "libtriton." + so_ext), "rb") as f:
         while True:
             chunk = f.read(1024**2)
             if not chunk:
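
Note: EXT_SUFFIX is the interpreter's full extension-module suffix; keeping only the last dot-separated component yields "so" on Linux and "pyd" on Windows, so the hash reads libtriton.pyd there. Illustrative values (version-dependent):

    import sysconfig

    ext = sysconfig.get_config_var("EXT_SUFFIX")
    # e.g. ".cpython-312-x86_64-linux-gnu.so" on Linux,
    #      ".cp312-win_amd64.pyd" on Windows
    print(ext.split(".")[-1])  # -> "so" or "pyd"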

python/triton/runtime/build.py: 1 change (1 addition & 0 deletions)

@@ -47,6 +47,7 @@ def _build(name, src, srcdir, library_dirs, include_dirs, libraries):
     cc_cmd += [f'-l{lib}' for lib in libraries]
     cc_cmd += [f"-L{dir}" for dir in library_dirs]
     cc_cmd += [f"-I{dir}" for dir in include_dirs if dir is not None]
+    if os.name == "nt": cc_cmd.pop(cc_cmd.index("-fPIC"))
     ret = subprocess.check_call(cc_cmd)
     if ret == 0:
         return so
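
Note: -fPIC has no meaning for Windows COFF output, and clang typically ignores or warns about it on Windows targets, so the flag is stripped. cc_cmd.pop(cc_cmd.index("-fPIC")) removes exactly one occurrence and would raise ValueError if the flag were absent; a more defensive variant (a sketch, not what the patch does):

    import os

    cc_cmd = ["clang", "main.c", "-O3", "-shared", "-fPIC", "-o", "main.pyd"]  # example command
    if os.name == "nt":
        # Drop every -fPIC rather than assuming exactly one occurrence exists.
        cc_cmd = [arg for arg in cc_cmd if arg != "-fPIC"]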

third_party/nvidia/backend/compiler.py: 17 changes (9 additions & 8 deletions)

@@ -11,6 +11,7 @@
 import signal
 import os
 import subprocess
+import sysconfig
 from pathlib import Path
 
 
@@ -20,9 +21,10 @@ def min_dot_size(target: GPUTarget):
 
 @functools.lru_cache()
 def _path_to_binary(binary: str):
+    exe = sysconfig.get_config_var("EXE")
     paths = [
         os.environ.get(f"TRITON_{binary.upper()}_PATH", ""),
-        os.path.join(os.path.dirname(__file__), "bin", binary),
+        os.path.join(os.path.dirname(__file__), "bin", f"{binary}{exe}"),
     ]
 
     for bin in paths:
@@ -32,7 +34,7 @@ def _path_to_binary(binary: str):
             version = re.search(r".*release (\d+\.\d+).*", result.decode("utf-8"), flags=re.MULTILINE)
             if version is not None:
                 return bin, version.group(1)
-    raise RuntimeError(f"Cannot find {binary}")
+    raise RuntimeError(f"Cannot find {binary}{exe}")
 
 
 @functools.lru_cache()
@@ -340,15 +342,9 @@ def make_cubin(src, metadata, opt, capability):
            ]
            try:
                subprocess.run(ptxas_cmd, check=True, close_fds=False, stderr=flog)
-                if os.path.exists(fsrc.name):
-                    os.remove(fsrc.name)
-                if os.path.exists(flog.name):
-                    os.remove(flog.name)
            except subprocess.CalledProcessError as e:
                with open(flog.name) as log_file:
                    log = log_file.read()
-                if os.path.exists(flog.name):
-                    os.remove(flog.name)
 
                if e.returncode == 255:
                    error = 'Internal Triton PTX codegen error'
@@ -365,6 +361,11 @@ def make_cubin(src, metadata, opt, capability):
                cubin = f.read()
            if os.path.exists(fbin):
                os.remove(fbin)
+
+           if os.path.exists(fsrc.name):
+               os.remove(fsrc.name)
+           if os.path.exists(flog.name):
+               os.remove(flog.name)
            return cubin
 
     def add_stages(self, stages, options):
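
Note: this reordering is the "lazy remove temp files" item from the commit message. POSIX lets you unlink a file that still has open handles, but on Windows os.remove raises PermissionError while fsrc/flog are still open, so cleanup now happens once, after all reads complete. A minimal standalone sketch of the failure mode, using only the standard library:

    import os
    import tempfile

    # delete=False so removal is under our control, as in the ptxas wrapper.
    f = tempfile.NamedTemporaryFile(delete=False)
    f.write(b"scratch")
    try:
        os.remove(f.name)  # fine on POSIX; PermissionError on Windows while f is open
    except PermissionError:
        pass
    finally:
        f.close()
        if os.path.exists(f.name):
            os.remove(f.name)  # always safe once every handle is closed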

third_party/nvidia/backend/driver.py: 15 changes (13 additions & 2 deletions)

@@ -1,6 +1,7 @@
 import functools
 import os
 import hashlib
+import sysconfig
 import subprocess
 import tempfile
 from pathlib import Path
@@ -14,12 +15,20 @@
 libdevice_dir = os.path.join(dirname, "lib")
 libraries = ['cuda']
 
+if os.name == "nt":
+    include_dir += [os.path.join(os.environ.get("CUDA_PATH"), "include")]
+
 
 @functools.lru_cache()
 def libcuda_dirs():
     env_libcuda_path = os.getenv("TRITON_LIBCUDA_PATH")
     if env_libcuda_path:
         return [env_libcuda_path]
+    if os.name == "nt":
+        installed_base = sysconfig.get_config_var('installed_base')
+        dirs = [os.path.join(os.environ.get("CUDA_PATH"), "lib", "x64")]
+        dirs += [os.getenv("PYTHON_LIB_DIRS", os.path.join(installed_base, "libs"))]
+        return dirs
 
     libs = subprocess.check_output(["/sbin/ldconfig", "-p"]).decode()
     # each line looks like the following:
@@ -48,15 +57,17 @@ def library_dirs():
 def compile_module_from_src(src, name):
     key = hashlib.sha256(src.encode("utf-8")).hexdigest()
     cache = get_cache_manager(key)
-    cache_path = cache.get_file(f"{name}.so")
+    so_ext = sysconfig.get_config_var("EXT_SUFFIX").split(".")[-1]
+    so_name = f'{name}.{so_ext}'
+    cache_path = cache.get_file(so_name)
     if cache_path is None:
         with tempfile.TemporaryDirectory() as tmpdir:
             src_path = os.path.join(tmpdir, "main.c")
             with open(src_path, "w") as f:
                 f.write(src)
             so = _build(name, src_path, tmpdir, library_dirs(), include_dir, libraries)
             with open(so, "rb") as f:
-                cache_path = cache.put(f.read(), f"{name}.so", binary=True)
+                cache_path = cache.put(f.read(), so_name, binary=True)
     import importlib.util
     spec = importlib.util.spec_from_file_location(name, cache_path)
     mod = importlib.util.module_from_spec(spec)
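
Note: the "pyd" extension matters here because importlib only loads extension modules whose filenames end in one of importlib.machinery.EXTENSION_SUFFIXES; a ".so" name would not import on Windows. A quick check:

    import importlib.machinery

    # Typical values; exact entries depend on the interpreter build.
    # Linux:   ['.cpython-312-x86_64-linux-gnu.so', '.abi3.so', '.so']
    # Windows: ['.cp312-win_amd64.pyd', '.pyd']
    print(importlib.machinery.EXTENSION_SUFFIXES)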