Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Windows native port #2478

Open
wants to merge 53 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 44 commits
Commits
Show all changes
53 commits
Select commit Hold shift + click to select a range
1b87208
First commit, triton builds on windows
gshimansky Oct 4, 2024
62cbe2c
Fixed loading of nvidia driver.py module
gshimansky Oct 5, 2024
022402d
Fixed compile command line for windows compilers
gshimansky Oct 9, 2024
69dee26
Make loader buildable on windows
gshimansky Oct 10, 2024
9382773
Fixed loading spirv_utils dynamic library module
gshimansky Oct 11, 2024
dbc280a
Fixed getting device context on windows
gshimansky Oct 11, 2024
bcf7d94
Formatter corrections
gshimansky Oct 11, 2024
f43486d
Disable compilation warning psapi
gshimansky Oct 15, 2024
e3e8068
Specify different levels of C++ for Linux and Windows
gshimansky Oct 15, 2024
af38761
Fixed constants of different size on linux and windows
gshimansky Oct 15, 2024
33fc428
Fixed -fPIC
gshimansky Oct 15, 2024
f1c0970
Fixed print_helper cuda windows addition
gshimansky Oct 16, 2024
fb33d3d
Removed superfluous C++ standard options
gshimansky Oct 16, 2024
c992c68
Restored comment
gshimansky Oct 16, 2024
9a68984
Removed torch.cuda.synchronize and delete=False for temp files
gshimansky Oct 17, 2024
3b2ebf6
Restored original test code
gshimansky Oct 17, 2024
6c47d44
Restored delete=False when creating a tempfile
gshimansky Oct 17, 2024
37a7e8e
Use sysconfig instead of condition
gshimansky Oct 17, 2024
fe2116b
Take commit 188370325395c79a4ba3de0bc47e39a19fc83224 from upstream tr…
gshimansky Oct 18, 2024
726dab6
Revert hashing algorithm to what it was
gshimansky Oct 18, 2024
7c5c234
Changed condition to sysconfig call
gshimansky Oct 18, 2024
a3268ea
Removed redundant PARTIAL_SOURCES_INTENDED
gshimansky Oct 21, 2024
62774dc
Manually delete tempfile after it is not needed any more
gshimansky Oct 22, 2024
ae77f77
Merge branch 'main' into gregory/windows-support
gshimansky Oct 22, 2024
34589f1
Use /Zc:preprocessor instead of MSVC workarounds
gshimansky Oct 22, 2024
14bc5c2
Use command list instead of string
gshimansky Oct 22, 2024
97d2441
Merge branch 'main' into gregory/windows-support
gshimansky Oct 23, 2024
835497a
Merge branch 'main' into gregory/windows-support
gshimansky Oct 25, 2024
6b5a1d4
Remove change that is already implemented in 9d424e02ed4db695cc58baf9…
gshimansky Oct 25, 2024
39d252a
Remove change because __builtin_prefetch is no longer called on Windows
gshimansky Oct 25, 2024
14f01c8
Remove ifdef because AMD headers build on windows successfully
gshimansky Oct 25, 2024
ec7dff7
Merge branch 'main' into gregory/windows-support
gshimansky Oct 28, 2024
f57b7bf
Merge branch 'main' into gregory/windows-support
gshimansky Oct 28, 2024
42a6307
Removed windows llvm URL arch because we don't have llvm windows bina…
gshimansky Oct 29, 2024
79937ce
Removed ifdef around AMD calls because they successfully compile on w…
gshimansky Oct 29, 2024
34f2dce
Use 'long long' type for int64_t
gshimansky Oct 29, 2024
506ee92
Removed debug print
gshimansky Oct 30, 2024
dc8a628
Merge branch 'main' into gregory/windows-support
gshimansky Oct 30, 2024
81b009a
Removed cuobjdump.exe, nvdisasm.exe and ptxas.exe from ignore
gshimansky Oct 30, 2024
1c985de
Merge branch 'main' into gregory/windows-support
gshimansky Oct 31, 2024
cb3bbf9
Merge branch 'main' into gregory/windows-support
gshimansky Oct 31, 2024
fad8e24
Merge branch 'main' into gregory/windows-support
gshimansky Nov 1, 2024
e9db820
Merge branch 'main' into gregory/windows-support
gshimansky Nov 6, 2024
289b1fe
Removed redundant code
gshimansky Nov 6, 2024
74a1c93
Merge branch 'main' into gregory/windows-support
gshimansky Nov 7, 2024
a8cd2fa
Removed redundant code
gshimansky Nov 7, 2024
12a097d
Merge branch 'main' into gregory/windows-support
gshimansky Nov 7, 2024
d15586c
Removed redundant code
gshimansky Nov 8, 2024
36a1977
Merge branch 'main' into gregory/windows-support
gshimansky Nov 8, 2024
916613a
Merge branch 'main' into gregory/windows-support
gshimansky Nov 8, 2024
f017395
Remove empty line
gshimansky Nov 8, 2024
fdb63be
Merge branch 'main' into gregory/windows-support
gshimansky Nov 12, 2024
d940e61
Merge branch 'main' into gregory/windows-support
gshimansky Nov 14, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ python/*.whl
python/triton/_C/*.pyd
python/triton/_C/*.so
python/triton/_C/*.dylib
python/triton/_C/*.pdb
python/triton/_C/*.exe
python/triton/_C/*.ilk

benchmarks/dist
benchmarks/*.egg-info/
Expand Down
65 changes: 49 additions & 16 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,34 +8,46 @@ endif()

include(ExternalProject)

set(CMAKE_CXX_STANDARD 17)

set(CMAKE_INCLUDE_CURRENT_DIR ON)

project(triton)
include(CTest)

if(NOT WIN32)
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake")
endif()


list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake")

# Options
if(WIN32)
set(DEFAULT_BUILD_PROTON OFF)
else()
set(DEFAULT_BUILD_PROTON ON)
endif()

# Define the option with the determined default value
option(TRITON_BUILD_PROTON "Build the Triton Proton profiler" ${DEFAULT_BUILD_PROTON})
option(TRITON_BUILD_TUTORIALS "Build C++ Triton tutorials" ON)
option(TRITON_BUILD_PYTHON_MODULE "Build Python Triton bindings" OFF)
option(TRITON_BUILD_PROTON "Build the Triton Proton profiler" ON)
option(TRITON_BUILD_UT "Build C++ Triton Unit Tests" ON)
set(TRITON_CODEGEN_BACKENDS "" CACHE STRING "Enable different codegen backends")

# Ensure Python3 vars are set correctly
# used conditionally in this file and by lit tests

# Customized release build type with assertions: TritonRelBuildWithAsserts
set(CMAKE_C_FLAGS_TRITONRELBUILDWITHASSERTS "-O2 -g")
set(CMAKE_CXX_FLAGS_TRITONRELBUILDWITHASSERTS "-O2 -g")
set(CMAKE_C_FLAGS_TRITONBUILDWITHO1 "-O1")
set(CMAKE_CXX_FLAGS_TRITONBUILDWITHO1 "-O1")
if(NOT MSVC)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_C_FLAGS_TRITONRELBUILDWITHASSERTS "-O2 -g")
set(CMAKE_CXX_FLAGS_TRITONRELBUILDWITHASSERTS "-O2 -g")
set(CMAKE_C_FLAGS_TRITONBUILDWITHO1 "-O1")
set(CMAKE_CXX_FLAGS_TRITONBUILDWITHO1 "-O1")
else()
set(CMAKE_CXX_STANDARD 20)
victor-eds marked this conversation as resolved.
Show resolved Hide resolved
set(CMAKE_C_FLAGS_TRITONRELBUILDWITHASSERTS "/Zi /Ob0 /Od /RTC1 /bigobj /Zc:preprocessor")
set(CMAKE_CXX_FLAGS_TRITONRELBUILDWITHASSERTS "/Zi /Ob0 /Od /RTC1 /bigobj /Zc:preprocessor")
set(CMAKE_EXE_LINKER_FLAGS_TRITONRELBUILDWITHASSERTS "/debug:fastlink /INCREMENTAL")
set(CMAKE_MODULE_LINKER_FLAGS_TRITONRELBUILDWITHASSERTS "/debug:fastlink /INCREMENTAL")
set(CMAKE_SHARED_LINKER_FLAGS_TRITONRELBUILDWITHASSERTS "/debug:fastlink /INCREMENTAL")
set(CMAKE_STATIC_LINKER_FLAGS_TRITONRELBUILDWITHASSERTS "/debug:fastlink /INCREMENTAL")
endif()

# Default build type
if(NOT CMAKE_BUILD_TYPE)
Expand All @@ -53,7 +65,15 @@ endif()

# Compiler flags
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__STDC_FORMAT_MACROS -fPIC -std=gnu++17")
if(NOT MSVC)
if(NOT WIN32)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__STDC_FORMAT_MACROS -fPIC")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why was -std=gnu++17 removed?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It was a redundant option, as was discussed earlier in this PR. The C++ standard level is defined by CMAKE_CXX_STANDARD, which we previously had defined here https://github.com/intel/intel-xpu-backend-for-triton/pull/2478/files#diff-1e7de1ae2d059d21e1dd75d5812d5a34b0222cef273b7c3a2af62eb747f9d20aL11 but it has now been moved into a condition for Linux/Windows, because MSVC is unable to parse some of the templates at level 17 and so needs 20, while gcc refuses to compile the code at level 20 https://github.com/intel/intel-xpu-backend-for-triton/pull/2478/files#diff-1e7de1ae2d059d21e1dd75d5812d5a34b0222cef273b7c3a2af62eb747f9d20aR37 .

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we make this change upstream?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Doesn't set(CMAKE_CXX_STANDARD 17) correspond to -std=c++17, not -std=gnu++17?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is -std=gnu++17 necessary? Looks like all tests in this PR pass without it.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is -std=gnu++17 necessary? Looks like all tests in this PR pass without it.

The Triton project has been using this extension for a very long time; for example, I found a mention of gnu++11 in triton-lang/triton@50587bb. Even if we assume that all of the code, not just the part being tested, does not use this extension, it is unlikely that they will decide to remove something that they have been using for a long time.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Then it might be a legacy, like the Windows build support, that Triton inherited from previous projects, and by now nobody knows why it was added.

else()
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__STDC_FORMAT_MACROS -Wno-deprecated")
endif()
else()
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__STDC_FORMAT_MACROS /wd4244 /wd4624 /wd4715 /wd4530")
endif()


# #########
Expand Down Expand Up @@ -107,7 +127,11 @@ endfunction()


# Disable warnings that show up in external code (gtest;pybind11)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Werror -Wno-covered-switch-default -fvisibility=hidden")
if(NOT MSVC)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Werror -Wno-covered-switch-default -fvisibility=hidden")
else()
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /WX-")
endif()

include_directories(".")
include_directories(${MLIR_INCLUDE_DIRS})
Expand All @@ -117,7 +141,8 @@ include_directories(${PROJECT_BINARY_DIR}/include) # Tablegen'd files
include_directories(${PROJECT_SOURCE_DIR}/third_party)
include_directories(${PROJECT_BINARY_DIR}/third_party) # Tablegen'd files

# link_directories(${LLVM_LIBRARY_DIR})
link_directories(${LLVM_LIBRARY_DIR})

add_subdirectory(include)
add_subdirectory(lib)

Expand Down Expand Up @@ -146,6 +171,8 @@ if(TRITON_BUILD_PYTHON_MODULE)
# using pip install.
include_directories(${PYTHON_INCLUDE_DIRS})
include_directories(${PYBIND11_INCLUDE_DIR})
message(STATUS "PYTHON_LIB_DIRS ${PYTHON_LIB_DIRS}")
link_directories(${PYTHON_LIB_DIRS})
else()
# Otherwise, we might be building from top CMakeLists.txt directly.
# Try to find Python and pybind11 packages.
Expand Down Expand Up @@ -228,7 +255,7 @@ if(TRITON_BUILD_PYTHON_MODULE)
LLVMAArch64CodeGen
LLVMAArch64AsmParser
)
elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "x86_64")
elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "x86_64" OR CMAKE_SYSTEM_PROCESSOR MATCHES "AMD64")
list(APPEND TRITON_LIBRARIES
LLVMX86CodeGen
LLVMX86AsmParser
Expand Down Expand Up @@ -263,6 +290,8 @@ if(TRITON_BUILD_PYTHON_MODULE)
target_link_libraries(triton PUBLIC ${TRITON_LIBRARIES})
if(WIN32)
target_link_libraries(triton PRIVATE ${CMAKE_DL_LIBS})
set_target_properties(triton PROPERTIES SUFFIX ".pyd")
set_target_properties(triton PROPERTIES PREFIX "lib")
else()
target_link_libraries(triton PRIVATE z)
endif()
Expand All @@ -289,6 +318,10 @@ if(NOT TRITON_BUILD_PYTHON_MODULE)
add_subdirectory(third_party/${CODEGEN_BACKEND})
endforeach()
endif()
if(WIN32)
option(CMAKE_USE_WIN32_THREADS_INIT "using WIN32 threads" ON)
option(gtest_disable_pthreads "Disable uses of pthreads in gtest." ON)
endif()

add_subdirectory(third_party/f2reduce)
add_subdirectory(bin)
Expand Down
57 changes: 55 additions & 2 deletions python/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,49 @@ def copy_externals():
]


def find_vswhere():
program_files = os.environ.get("ProgramFiles(x86)", "C:\\Program Files (x86)")
vswhere_path = Path(program_files) / "Microsoft Visual Studio" / "Installer" / "vswhere.exe"
if vswhere_path.exists():
return vswhere_path
return None
victor-eds marked this conversation as resolved.
Show resolved Hide resolved


def find_visual_studio(version_ranges):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we reuse code from CLFinder.py?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can, and in theory it should work, but it doesn't. For some reason, when I import these functions from CLFinder I get the error LINK : fatal error LNK1168: cannot open C:\b\tr\python\triton\_C\libtriton.pyd for writing. I have no idea how they are related, but this link error is consistently reproducible for me. I tried to debug the problem and found that the environment after calling set_env_vars is identical whether the functions are reused from CLFinder.py or setup.py has its own copies, so now I am out of ideas as to how libtriton.pyd may end up locked.

vswhere = find_vswhere()
if not vswhere:
raise FileNotFoundError("vswhere.exe not found.")

for version_range in version_ranges:
command = [
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For me it works only if I specify -products:
"C:\Program Files (x86)\Microsoft Visual Studio\Installer\vswhere.exe" -version "[17.0,18.0)" -products Microsoft.VisualStudio.Product.BuildTools -requires Microsoft.VisualStudio.Component.VC.Tools.x86.x64 -property installationPath -prerelease

@gshimansky do you know why?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For me it works just fine if I don't specify -products or if I specify -products "*". Specifying -products Microsoft.VisualStudio.Product.BuildTools doesn't find anything for me but specifying -products Microsoft.VisualStudio.Product.Professional does.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

-products "*" works for me too. It seems that this is because I did not install the entire studio, but only the build tools. Should we add -products "*" to allow this case?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know whether it may result in multiple different products found as the result, but we can add it for now.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know whether it may result in multiple different products found as the result, but we can add it for now.

I thought about that too. But couldn't it be the same default behavior when -products parameter is not specified at all?

str(vswhere), "-version", version_range, "-requires", "Microsoft.VisualStudio.Component.VC.Tools.x86.x64",
"-property", "installationPath", "-prerelease"
]

try:
output = subprocess.check_output(command, text=True).strip()
if output:
return output
except subprocess.CalledProcessError:
continue

return None


def set_env_vars(vs_path, arch="x64"):
vcvarsall_path = Path(vs_path) / "VC" / "Auxiliary" / "Build" / "vcvarsall.bat"
if not vcvarsall_path.exists():
raise FileNotFoundError(f"vcvarsall.bat not found in expected path: {vcvarsall_path}")

command = ["call", vcvarsall_path, arch, "&&", "set"]
output = subprocess.check_output(command, shell=True, text=True)

for line in output.splitlines():
if '=' in line:
var, value = line.split('=', 1)
os.environ[var] = value


# Taken from https://github.com/pytorch/pytorch/blob/master/tools/setup_helpers/env.py
def check_env_flag(name: str, default: str = "") -> bool:
return os.getenv(name, default).upper() in ["ON", "1", "YES", "TRUE", "Y"]
Expand Down Expand Up @@ -281,10 +324,10 @@ def download_and_copy(name, src_path, dst_path, variable, version, url_func):
base_dir = os.path.dirname(__file__)
system = platform.system()
try:
arch = {"x86_64": "64", "arm64": "aarch64", "aarch64": "aarch64"}[platform.machine()]
arch = {"x86_64": "64", "AMD64": "64", "arm64": "aarch64", "aarch64": "aarch64"}[platform.machine()]
except KeyError:
arch = platform.machine()
supported = {"Linux": "linux", "Darwin": "linux"}
supported = {"Linux": "linux", "Darwin": "linux", "Windows": "win"}
url = url_func(supported[system], arch, version)
tmp_path = os.path.join(triton_cache_path, "nvidia", name) # path to cache the download
dst_path = os.path.join(base_dir, os.pardir, "third_party", "nvidia", "backend", dst_path) # final binary path
Expand Down Expand Up @@ -401,6 +444,11 @@ def get_proton_cmake_args(self):
def build_extension(self, ext):
lit_dir = shutil.which('lit')
ninja_dir = shutil.which('ninja')
if platform.system() == "Windows":
vs_path = find_visual_studio(["[17.0,18.0)", "[16.0,17.0)"])
env = set_env_vars(vs_path)
if not vs_path:
raise EnvironmentError("Visual Studio 2019 or 2022 not found.")
# lit is used by the test suite
thirdparty_cmake_args = get_thirdparty_packages([get_llvm_package_info()])
thirdparty_cmake_args += self.get_pybind11_cmake_args()
Expand All @@ -421,6 +469,10 @@ def build_extension(self, ext):
"-DTRITON_CODEGEN_BACKENDS=" + ';'.join([b.name for b in backends if not b.is_external]),
"-DTRITON_PLUGIN_DIRS=" + ';'.join([b.src_dir for b in backends if b.is_external])
]
if platform.system() == "Windows":
installed_base = sysconfig.get_config_var('installed_base')
py_lib_dirs = os.getenv("PYTHON_LIB_DIRS", os.path.join(installed_base, "libs"))
cmake_args.append("-DPYTHON_LIB_DIRS=" + py_lib_dirs)
if lit_dir is not None:
cmake_args.append("-DLLVM_EXTERNAL_LIT=" + lit_dir)
cmake_args.extend(thirdparty_cmake_args)
Expand All @@ -431,6 +483,7 @@ def build_extension(self, ext):

cmake_args += [f"-DCMAKE_BUILD_TYPE={cfg}"]
if platform.system() == "Windows":
cmake_args += ["-DCMAKE_BUILD_TYPE=" + cfg]
anmyachev marked this conversation as resolved.
Show resolved Hide resolved
cmake_args += [f"-DCMAKE_RUNTIME_OUTPUT_DIRECTORY_{cfg.upper()}={extdir}"]
else:
max_jobs = os.getenv("MAX_JOBS", str(2 * os.cpu_count()))
Expand Down
35 changes: 21 additions & 14 deletions python/test/unit/language/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2623,10 +2623,11 @@ def test_scan_layouts(M, N, src_layout, axis, device):
}}
"""

with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f:
with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir', delete=False) as f:
anmyachev marked this conversation as resolved.
Show resolved Hide resolved
f.write(ir)
f.flush()
f.close()
kernel = triton.compile(f.name)
os.remove(f.name)
rs = RandomState(17)
x = rs.randint(-100, 100, (M, N)).astype('int32')

Expand Down Expand Up @@ -2757,10 +2758,11 @@ def test_reduce_layouts(M, N, src_layout, axis, epilogue_kind, dtype_str, reduce
}}) {{axis = {axis} : i32}} : (tensor<{M}x{N}x{ty}, #src>) -> tensor<{rdims_1d}x{ty}, #{GPU_DIALECT}.slice<{{dim = {axis}, parent = #src}}>>
""" + epilogue

with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f:
with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir', delete=False) as f:
f.write(ir)
f.flush()
f.close()
whitneywhtsang marked this conversation as resolved.
Show resolved Hide resolved
kernel = triton.compile(f.name)
os.remove(f.name)

rs = RandomState(17)
x = numpy_random((M, N), dtype_str=dtype_str, rs=rs, low=0, high=10)
Expand Down Expand Up @@ -2811,10 +2813,11 @@ def test_store_op(M, src_layout, device):
}}
"""

with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f:
with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir', delete=False) as f:
f.write(ir)
f.flush()
f.close()
store_kernel = triton.compile(f.name)
os.remove(f.name)

rs = RandomState(17)
x = rs.randint(0, 4, (M, 1)).astype('float32')
Expand Down Expand Up @@ -2861,10 +2864,11 @@ def test_convert1d(M, src_layout, dst_layout, src_dim, dst_dim, device):
}}
}}
"""
with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f:
with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir', delete=False) as f:
f.write(ir)
f.flush()
f.close()
kernel = triton.compile(f.name)
os.remove(f.name)

rs = RandomState(17)
x = rs.randint(0, 4, (M, )).astype('int32')
Expand Down Expand Up @@ -2943,10 +2947,11 @@ def test_chain_reduce(M, N, src_layout, op, device, first_axis):
}}
}}
"""
with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f:
with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir', delete=False) as f:
f.write(ir)
f.flush()
f.close()
kernel = triton.compile(f.name)
os.remove(f.name)

rs = RandomState(17)
x = rs.randint(0, 4, (M, N)).astype('int32')
Expand Down Expand Up @@ -5344,10 +5349,11 @@ def test_convert2d(M, N, src_layout, interm_layout, dst_layout, dtype, device):
x = to_triton(numpy_random((M, N), dtype_str=dtype), device=device)
z = torch.empty_like(x, device=device)

with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f:
with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir', delete=False) as f:
f.write(ir)
f.flush()
f.close()
kernel = triton.compile(f.name)
os.remove(f.name)
kernel[(1, 1, 1)](x.data_ptr(), z.data_ptr())

assert torch.equal(z, x)
Expand Down Expand Up @@ -5457,10 +5463,11 @@ def do_test(src_layout, dst_layout):
x = to_triton(numpy_random((M, N), dtype_str=dtype), device=device)
z = torch.empty_like(x)

with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f:
with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir', delete=False) as f:
f.write(ir)
f.flush()
f.close()
kernel = triton.compile(f.name)
os.remove(f.name)
kernel[(1, 1, 1)](x.data_ptr(), z.data_ptr())

assert torch.equal(z, x)
Expand Down
8 changes: 6 additions & 2 deletions python/test/unit/runtime/test_subproc.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import multiprocessing
import shutil
import os

import triton
import triton.language as tl
Expand Down Expand Up @@ -27,7 +28,10 @@ def kernel_sub(a, b, o, N: tl.constexpr):

def test_compile_in_subproc() -> None:
config = AttrsDescriptor.from_hints({i: 16 for i in range(4)})
multiprocessing.set_start_method('fork')
if os.name == "nt":
multiprocessing.set_start_method('spawn')
else:
multiprocessing.set_start_method('fork')
proc = multiprocessing.Process(target=compile_fn, args=(config, ))
proc.start()
proc.join()
Expand Down Expand Up @@ -92,7 +96,7 @@ def test_compile_in_forked_subproc_with_forced_gc(fresh_triton_cache) -> None:

# stage 2.p
shutil.rmtree(fresh_triton_cache)
assert multiprocessing.get_start_method() == 'fork'
assert multiprocessing.get_start_method() in ['fork', 'spawn']
proc = multiprocessing.Process(target=compile_empty_kernel_with_gc, args=(config, ))

# stage 3.c
Expand Down
7 changes: 6 additions & 1 deletion python/triton/backends/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import re
import hashlib
import subprocess
import sysconfig

from abc import ABCMeta, abstractmethod, abstractclassmethod
from dataclasses import dataclass
Expand Down Expand Up @@ -231,13 +232,17 @@ def __init__(self, target: GPUTarget) -> None:

@staticmethod
def _path_to_binary(binary: str):
binary += sysconfig.get_config_var("EXE")
anmyachev marked this conversation as resolved.
Show resolved Hide resolved
base_dir = os.path.join(os.path.dirname(__file__), os.pardir)
paths = [
os.environ.get(f"TRITON_{binary.upper()}_PATH", ""),
os.path.join(base_dir, "third_party", "cuda", "bin", binary),
]
for p in paths:
bin = p.split(" ")[0]
if os.name != "nt":
bin = p.split(" ")[0]
else:
bin = p
if os.path.exists(bin) and os.path.isfile(bin):
result = subprocess.check_output([bin, "--version"], stderr=subprocess.STDOUT)
if result is not None:
Expand Down
Loading