Skip to content

Commit

Permalink
[BACKEND][CPU] Make it buildable and runnable in a different environm…
Browse files Browse the repository at this point in the history
…ent (triton-lang#8)

* [BACKEND][CPU] Make it buildable and runnable in a different environment

* Revert seemingly inconsistent python code formatting
  • Loading branch information
minjang committed Oct 23, 2024
1 parent 82ef809 commit cb4875e
Show file tree
Hide file tree
Showing 7 changed files with 14 additions and 10 deletions.
5 changes: 3 additions & 2 deletions include/triton/Conversion/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
add_subdirectory(TritonCPUToLLVM)
# TODO(minjang): I will remove these scratches soon.
# add_subdirectory(TritonCPUToLLVM)
add_subdirectory(TritonGPUToLLVM)
add_subdirectory(TritonToTritonCPU)
# add_subdirectory(TritonToTritonCPU)
add_subdirectory(TritonToTritonGPU)
5 changes: 3 additions & 2 deletions lib/Conversion/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#add_subdirectory(TritonToTritonCPU)
# TODO(minjang): I will remove these scratches soon.
# add_subdirectory(TritonToTritonCPU)
add_subdirectory(TritonToTritonGPU)
#add_subdirectory(TritonCPUToLLVM)
# add_subdirectory(TritonCPUToLLVM)
add_subdirectory(TritonGPUToLLVM)
3 changes: 0 additions & 3 deletions python/src/passes.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
#include "triton/Analysis/Allocation.h"
#include "triton/Analysis/Membar.h"
#include "triton/Conversion/TritonGPUToLLVM/Passes.h"
#include "triton/Conversion/TritonToTritonCPU/Passes.h"
#include "triton/Conversion/TritonToTritonGPU/Passes.h"
#include "triton/Dialect/Triton/Transforms/Passes.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
Expand Down Expand Up @@ -44,8 +43,6 @@ void init_triton_passes_ttir(py::module &&m) {
ADD_PASS_WRAPPER_4("add_convert_to_ttgpuir",
createConvertTritonToTritonGPUPass, const std::string &,
int, int, int);
// ADD_PASS_WRAPPER_0("add_convert_to_ttcpuir",
// createConvertTritonToTritonCPUPass);
}

void init_triton_passes_ttgpuir(py::module &&m) {
Expand Down
3 changes: 3 additions & 0 deletions python/triton/runtime/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ def _build(name, src, srcdir, library_dirs, include_dirs, libraries):
cc_cmd += [f'-l{lib}' for lib in libraries]
cc_cmd += [f"-L{dir}" for dir in library_dirs]
cc_cmd += [f"-I{dir}" for dir in include_dirs if dir is not None]
# CPU backend uses C++ (driver.cpp). Some old version compilers need a specific C++17 flag.
if src.endswith(".cpp") or src.endswith(".cc"):
cc_cmd += ["-std=c++17"]
ret = subprocess.check_call(cc_cmd)
if ret == 0:
return so
Expand Down
2 changes: 1 addition & 1 deletion third_party/cpu/backend/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class CPUOptions:
cluster_dims: tuple = (1, 1, 1)
extern_libs: dict = None
debug: bool = False
allowed_dot_input_precisions: Tuple[str] = ("ieee",)
allowed_dot_input_precisions: Tuple[str] = ("ieee", )
allow_fp8e4nv: bool = False
enable_fp_fusion: bool = True

Expand Down
4 changes: 3 additions & 1 deletion third_party/cpu/backend/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@
"LLVMSupport",
"LLVMDemangle",
"stdc++",
"z",
]


Expand Down Expand Up @@ -176,7 +177,8 @@ def format_of(ty):
arg_ptrs_list = ', '.join(f"&arg{i}" for i, ty in signature.items()) if len(signature) > 0 else ''
kernel_fn_args = [i for i in signature.keys() if i not in constants]
kernel_fn_args_list = ', '.join(f"arg{i}" for i in kernel_fn_args) if len(kernel_fn_args) > 0 else ''
kernel_fn_arg_types = (', '.join(f"{ty_to_cpp(signature[i])}" for i in kernel_fn_args) + ", " if len(signature) > 0 else '') + "uint32_t, uint32_t, uint32_t"
kernel_fn_arg_types = (', '.join(f"{ty_to_cpp(signature[i])}" for i in kernel_fn_args) +
", " if len(signature) > 0 else '') + "uint32_t, uint32_t, uint32_t"

# generate glue code
src = f"""
Expand Down
2 changes: 1 addition & 1 deletion third_party/cpu/include/TritonCPUToLLVM/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls --name TritonCPUToLLVM)
add_public_tablegen_target(TritonCPUConversionPassIncGen)
add_public_tablegen_target(TritonCPUToLLVMConversionPassIncGen)

0 comments on commit cb4875e

Please sign in to comment.