diff --git a/bin/RegisterTritonDialects.h b/bin/RegisterTritonDialects.h index b76ff8c291cc..749ee17458ea 100644 --- a/bin/RegisterTritonDialects.h +++ b/bin/RegisterTritonDialects.h @@ -16,8 +16,6 @@ #include "triton/Dialect/TritonGPU/Transforms/Passes.h" #include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h" -#include "cpu/include/TritonCPUToLLVM/Passes.h" -#include "cpu/include/TritonToTritonCPU/Passes.h" #include "nvidia/include/NVGPUToLLVM/Passes.h" #include "nvidia/include/TritonNVIDIAGPUToLLVM/Passes.h" #include "triton/Conversion/TritonGPUToLLVM/Passes.h" @@ -66,12 +64,6 @@ inline void registerTritonDialects(mlir::DialectRegistry ®istry) { mlir::registerTritonAMDGPUStreamPipelineV2(); mlir::registerTritonAMDGPUCanonicalizePointers(); - // CPU passes - mlir::triton::cpu::registerTritonToTritonCPUPasses(); - mlir::triton::cpu::registerTritonToTritonCPUPipeline(); - mlir::triton::cpu::registerTritonCPUToLLVMPasses(); - mlir::triton::cpu::registerTritonCPUToLLVMPipeline(); - // TODO: register Triton & TritonGPU passes registry.insert traits = [], string baseCppClass = "::mlir::Attribute"> : AttrDef { - let description = [{TritonCPU attr.}]; - let attrName = "triton.cpu." # attrMnemonic; + let description = [{ + WIP... + }]; } #endif diff --git a/include/triton/Dialect/TritonCPU/IR/TritonCPUDialect.td b/include/triton/Dialect/TritonCPU/IR/TritonCPUDialect.td index 260db2743046..9ccac13f0b58 100644 --- a/include/triton/Dialect/TritonCPU/IR/TritonCPUDialect.td +++ b/include/triton/Dialect/TritonCPU/IR/TritonCPUDialect.td @@ -17,7 +17,6 @@ def TritonCPU_Dialect : Dialect { let dependentDialects = [ "triton::TritonDialect", "tensor::TensorDialect", - "mlir::memref::MemRefDialect", ]; let extraClassDeclaration = [{ @@ -25,8 +24,6 @@ def TritonCPU_Dialect : Dialect { }]; let useDefaultTypePrinterParser = 1; - let useDefaultAttributePrinterParser = 1; - let usePropertiesForAttributes = 1; } #endif diff --git a/include/triton/Dialect/TritonCPU/IR/TritonCPUOps.td b/include/triton/Dialect/TritonCPU/IR/TritonCPUOps.td index bb7417ebd03e..16d9e433e899 100644 --- a/include/triton/Dialect/TritonCPU/IR/TritonCPUOps.td +++ b/include/triton/Dialect/TritonCPU/IR/TritonCPUOps.td @@ -7,57 +7,6 @@ include "triton/Dialect/TritonCPU/IR/TritonCPUAttrDefs.td" include "mlir/Dialect/Arith/IR/ArithBase.td" include "triton/Dialect/Triton/IR/TritonTypes.td" include "triton/Dialect/Triton/IR/TritonAttrDefs.td" -include "triton/Dialect/Triton/IR/TritonInterfaces.td" include "mlir/IR/OpBase.td" -include "mlir/Interfaces/SideEffectInterfaces.td" // Pure -include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType -include "mlir/Interfaces/DestinationStyleOpInterface.td" -include "mlir/Interfaces/ViewLikeInterface.td" - -class TTC_Op traits = []> : - Op { -} - -def TTC_ExtractMemRefOp : TTC_Op<"extract_memref", [NoMemoryEffect]> { - let summary = "Extract base memref from a block pointer"; - - let description = [{ - Extract base memref from a block pointer. It covers whole base tensor memory, - not only the block referenced. Base pointer, shape, and strides are used - in the resulting memref. Offsets and block shape are ignored. - - }]; - - let arguments = (ins TT_TensorPtr:$src); - - let results = (outs AnyRankedOrUnrankedMemRef:$result); - - let hasCanonicalizer = 1; - - let assemblyFormat = "$src attr-dict `:` type($src) `->` type($result)"; -} - -def TTC_ExtractIndicesOp : TTC_Op<"extract_indices", [NoMemoryEffect]> { - let summary = "Extract indices from a block pointer."; - - let description = [{ - Extract indices that can be used to access the block using its base memref. - Indices are supposed to be used for vector loads/stores with the base - memref extracted from the same block pointer. - }]; - - let arguments = (ins TT_TensorPtr:$src); - - let results = (outs Variadic:$result); - - let builders = [ - OpBuilder<(ins "Value":$src)> - ]; - - let hasCanonicalizer = 1; - - let assemblyFormat = "$src attr-dict `:` type($src) `->` type($result)"; -} #endif diff --git a/lib/Conversion/CMakeLists.txt b/lib/Conversion/CMakeLists.txt index 83db4ae41607..5c3aa2c1a827 100644 --- a/lib/Conversion/CMakeLists.txt +++ b/lib/Conversion/CMakeLists.txt @@ -1,4 +1,4 @@ -#add_subdirectory(TritonToTritonCPU) +add_subdirectory(TritonToTritonCPU) add_subdirectory(TritonToTritonGPU) -#add_subdirectory(TritonCPUToLLVM) +add_subdirectory(TritonCPUToLLVM) add_subdirectory(TritonGPUToLLVM) diff --git a/lib/Dialect/TritonCPU/IR/Dialect.cpp b/lib/Dialect/TritonCPU/IR/Dialect.cpp index e5eb53caf686..e28a65358dca 100644 --- a/lib/Dialect/TritonCPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonCPU/IR/Dialect.cpp @@ -2,19 +2,16 @@ #include -#include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/IR/DialectImplementation.h" #include "mlir/IR/OpImplementation.h" #include "triton/Analysis/Utility.h" #include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonCPU/IR/Dialect.cpp.inc" #include "triton/Dialect/TritonCPU/IR/Dialect.h" #include "triton/Tools/Sys/GetEnv.hpp" #include "llvm/ADT/TypeSwitch.h" -#include "triton/Dialect/TritonCPU/IR/Dialect.cpp.inc" - using namespace mlir; -using namespace mlir::triton; using namespace mlir::triton::cpu; //===----------------------------------------------------------------------===// @@ -23,35 +20,6 @@ using namespace mlir::triton::cpu; #define GET_ATTRDEF_CLASSES #include "triton/Dialect/TritonCPU/IR/TritonCPUAttrDefs.cpp.inc" -void ExtractMemRefOp::getCanonicalizationPatterns(RewritePatternSet &patterns, - MLIRContext *context) {} - -void ExtractIndicesOp::getCanonicalizationPatterns(RewritePatternSet &patterns, - MLIRContext *context) {} - -/// Parse an attribute registered to this dialect. -::mlir::Attribute -TritonCPUDialect::parseAttribute(::mlir::DialectAsmParser &parser, - ::mlir::Type type) const { - llvm_unreachable("parse stub called"); -} - -/// Print an attribute registered to this dialect. -void TritonCPUDialect::printAttribute(::mlir::Attribute attr, - ::mlir::DialectAsmPrinter &os) const { - llvm_unreachable("print stub called"); -} - -void ExtractIndicesOp::build(::mlir::OpBuilder &builder, - ::mlir::OperationState &state, Value src) { - assert(triton::isTensorPointerType(src.getType()) && - "Unexecpeted source type"); - auto tensorTy = dyn_cast( - dyn_cast(src.getType()).getPointeeType()); - SmallVector resTypes(tensorTy.getRank(), builder.getIndexType()); - build(builder, state, resTypes, src); -} - void TritonCPUDialect::initialize() { registerTypes(); @@ -66,9 +34,6 @@ void TritonCPUDialect::initialize() { >(); } -#define GET_OP_CLASSES -#include "triton/Dialect/TritonCPU/IR/Ops.cpp.inc" - // verify TritonCPU ops LogicalResult TritonCPUDialect::verifyOperationAttribute(Operation *op, NamedAttribute attr) { diff --git a/python/src/llvm.cc b/python/src/llvm.cc index e6ef24b44345..f9b98a2540a2 100644 --- a/python/src/llvm.cc +++ b/python/src/llvm.cc @@ -3,8 +3,6 @@ #include "mlir/Target/LLVMIR/ModuleTranslation.h" #include "triton/Tools/Sys/GetEnv.hpp" #include "llvm/ADT/SmallVector.h" -#include "llvm/Bitcode/BitcodeWriter.h" -#include "llvm/CodeGen/CommandFlags.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/LegacyPassManager.h" #include "llvm/IR/Module.h" @@ -23,7 +21,6 @@ #include "llvm/Support/SourceMgr.h" #include "llvm/Support/TargetSelect.h" #include "llvm/Target/TargetMachine.h" -#include "llvm/TargetParser/Host.h" #include "llvm/Transforms/IPO/AlwaysInliner.h" #include "llvm/Transforms/InstCombine/InstCombine.h" #include @@ -393,70 +390,6 @@ void init_triton_llvm(py::module &&m) { py::arg("flags") = std::vector{}, py::arg("enable_fp_fusion") = false); - m.def("set_host_target", [](llvm::Module *mod) { - mod->setTargetTriple(llvm::sys::getDefaultTargetTriple()); - std::string error; - auto target = - llvm::TargetRegistry::lookupTarget(mod->getTargetTriple(), error); - std::unique_ptr machine{target->createTargetMachine( - mod->getTargetTriple(), llvm::sys::getHostCPUName(), "", {}, - llvm::Reloc::PIC_)}; - mod->setDataLayout(machine->createDataLayout()); - }); - - m.def( - "translate_to_host_asm", - [](std::string llvmIR) -> py::object { - std::string res; - { - // when allow_threads goes out of scope, gil will be released - py::gil_scoped_release allow_threads; - // create LLVM module from C++ - llvm::LLVMContext context; - std::unique_ptr buffer = - llvm::MemoryBuffer::getMemBuffer(llvmIR.c_str()); - llvm::SMDiagnostic error; - std::unique_ptr module = - llvm::parseIR(buffer->getMemBufferRef(), error, context); - if (!module) { - llvm::report_fatal_error( - "failed to parse IR: " + error.getMessage() + - "lineno: " + std::to_string(error.getLineNo())); - } - res = translateLLVMIRToASM( - *module, llvm::sys::getDefaultTargetTriple(), - llvm::sys::getHostCPUName().str(), "", {}, false, false); - } - return py::str(res); - }, - ret::take_ownership); - - m.def( - "translate_to_bc", - [](const std::string llvmIR) -> py::object { - py::gil_scoped_release allow_threads; - // create LLVM module - llvm::LLVMContext context; - std::unique_ptr buffer = - llvm::MemoryBuffer::getMemBuffer(llvmIR.c_str()); - llvm::SMDiagnostic error; - std::unique_ptr module = - llvm::parseIR(buffer->getMemBufferRef(), error, context); - if (!module) { - llvm::report_fatal_error( - "failed to parse IR: " + error.getMessage() + - "lineno: " + std::to_string(error.getLineNo())); - } - // Write bitcode to a buffer. - llvm::SmallVector buf; - llvm::BitcodeWriter writer(buf); - writer.writeModule(*module); - writer.writeStrtab(); - std::string bitcode(buf.begin(), buf.end()); - return py::bytes(bitcode); - }, - ret::take_ownership); - m.def( "translate_to_asm", [](std::string llvmIR, std::string triple, std::string proc, diff --git a/python/src/passes.cc b/python/src/passes.cc index a2694f0e703e..f31542af5d5e 100644 --- a/python/src/passes.cc +++ b/python/src/passes.cc @@ -44,8 +44,8 @@ void init_triton_passes_ttir(py::module &&m) { ADD_PASS_WRAPPER_4("add_convert_to_ttgpuir", createConvertTritonToTritonGPUPass, const std::string &, int, int, int); - // ADD_PASS_WRAPPER_0("add_convert_to_ttcpuir", - // createConvertTritonToTritonCPUPass); + ADD_PASS_WRAPPER_0("add_convert_to_ttcpuir", + createConvertTritonToTritonCPUPass); } void init_triton_passes_ttgpuir(py::module &&m) { @@ -81,7 +81,6 @@ void init_triton_passes_convert(py::module &&m) { ADD_PASS_WRAPPER_0("add_cf_to_llvmir", createConvertControlFlowToLLVMPass); ADD_PASS_WRAPPER_0("add_index_to_llvmir", createConvertIndexToLLVMPass); ADD_PASS_WRAPPER_0("add_arith_to_llvmir", createArithToLLVMConversionPass); - ADD_PASS_WRAPPER_0("add_math_to_llvmir", createConvertMathToLLVMPass); } void init_triton_passes_llvmir(py::module &&m) { diff --git a/python/test/unit/language/conftest.py b/python/test/unit/language/conftest.py index 44615b8b883b..091f9ea41e7f 100644 --- a/python/test/unit/language/conftest.py +++ b/python/test/unit/language/conftest.py @@ -3,4 +3,3 @@ def pytest_configure(config): config.addinivalue_line("markers", "interpreter: indicate whether interpreter supports the test") - config.addinivalue_line("markers", "cpu: indicate whether test is supported on cpu") diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 07fe8e2e8d59..3013bbf53177 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -193,7 +193,6 @@ def filter_layouts(layouts): return [l for l in layouts if is_layout_applicable(l)] -@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("dtype_x", list(dtypes) + ["bfloat16"]) def test_empty_kernel(dtype_x, device): @@ -416,7 +415,6 @@ def test_dtype_codegen(): # --------------- -@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("dtype_x, dtype_y, op", [ # (dtype_x, dtype_y, op) @@ -471,7 +469,6 @@ def test_bin_op(dtype_x, dtype_y, op, num_ctas, device): test_broadcast=(op != "%"), filter_y=filter_y, test_scalar=not skip_scalar_test) -@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("dtype, order", [(dtype, order) for dtype in dtypes_with_bfloat16 for order in [0, 1]]) def test_addptr(dtype, order, device): @@ -498,7 +495,6 @@ def kernel(x, y, ORDER: tl.constexpr, SIZE: tl.constexpr): np.testing.assert_allclose(y, to_numpy(y_tri)) -@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("dtype_x, dtype_y", [ # (dtype_x, dtype_y) for dtype_x in int_dtypes for dtype_y in int_dtypes @@ -519,7 +515,6 @@ def test_floordiv(dtype_x, dtype_y, num_ctas, device): _test_binary(dtype_x, dtype_y, expr, numpy_expr, filter_y=filter_y, device=device, num_ctas=num_ctas) -@pytest.mark.cpu def test_unsigned_name_mangling(device): # Test that uint32 and int32 are mangled differently by the compiler SIZE = 128 @@ -556,7 +551,6 @@ def kernel(O1, O2, X, Y, SIZE: tl.constexpr): # test bitwise ops # --------------- -@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("dtype_x, dtype_y, op", [ # (dtype_x, dtype_y, op) @@ -581,7 +575,6 @@ def test_bitwise_op(dtype_x, dtype_y, op, num_ctas, device): _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device, num_ctas=num_ctas) -@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("dtype_x, dtype_y, op", [ # (dtype_x, dtype_y, op) for op in ['<<', '>>'] for dtype_x in int_dtypes + uint_dtypes for dtype_y in uint_dtypes @@ -604,7 +597,6 @@ def test_shift_op(dtype_x, dtype_y, op, num_ctas, device): ops = ['==', '!=', '>', '<', '>=', '<='] -@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize( "dtype_x, dtype_y, op, mode_x, mode_y", @@ -629,7 +621,6 @@ def test_compare_op(dtype_x, dtype_y, op, mode_x, mode_y, num_ctas, device): # --------------- # test broadcast # --------------- -@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("dtype", dtypes_with_bfloat16) def test_broadcast(dtype, device): @@ -664,7 +655,6 @@ def broadcast_kernel(x_ptr, y_ptr, y_broadcasted_ptr, M: tl.constexpr, N: tl.con # ---------- -@pytest.mark.cpu @pytest.mark.interpreter def test_slice(device): @@ -696,7 +686,6 @@ def slice_kernel(XBLOCK: tl.constexpr): # ------------------ -@pytest.mark.cpu @pytest.mark.interpreter def test_invalid_slice(device): dst = torch.empty(128, device=device) @@ -712,7 +701,6 @@ def _kernel(dst): # ---------------- # test expand_dims # ---------------- -@pytest.mark.cpu @pytest.mark.interpreter def test_expand_dims(device): @@ -761,7 +749,6 @@ def expand_dims_kernel(dummy, N: tl.constexpr): expand_dims_kernel[(1, )](dummy_tensor, N) -@pytest.mark.cpu @pytest.mark.interpreter def test_expand_dims_error_cases(device): @@ -825,7 +812,6 @@ def duplicate_dim2(dummy, N: tl.constexpr): # ---------------------------- # test invalid program id axis # ---------------------------- -@pytest.mark.cpu @pytest.mark.interpreter def test_invalid_pid_axis(device): dst = torch.empty(128, device=device) @@ -842,7 +828,6 @@ def _kernel(dst): # --------------- # test where # --------------- -@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("dtype", dtypes_with_bfloat16 + ["*int32"]) @pytest.mark.parametrize("num_ctas", num_ctas_list) @@ -895,7 +880,6 @@ def where_kernel(cond_ptr, a_ptr, b_ptr, output_ptr, n_elements, BLOCK_SIZE: tl. assert (z == to_numpy(z_tri)).all() -@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("num_ctas", num_ctas_list) def test_where_broadcast(num_ctas, device): @@ -940,7 +924,6 @@ def where_scalar_condition(a_ptr, out_ptr, BLOCK_SIZE: tl.constexpr): # --------------- -@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("dtype_x, expr", [(dtype_x, ' -x') for dtype_x in dtypes_with_bfloat16] + [(dtype_x, ' ~x') @@ -955,7 +938,6 @@ def test_unary_op(dtype_x, expr, num_ctas, device): # ---------------- -@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("dtype_x, expr, x", [(dtype_x, expr, x) @@ -966,7 +948,6 @@ def test_math_op(dtype_x, expr, x, device): _test_unary(dtype_x, f'tl.{expr}({x})', f'np.{expr}({x}) ', device=device) -@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("dtype", [dtype for dtype in ["float32", "float64"]]) def test_math_erf_op(dtype, device): @@ -988,7 +969,6 @@ def kernel(Z, X, SIZE: tl.constexpr): torch.testing.assert_close(z_tri, z_ref) -@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("dtype", [dtype for dtype in ["float32", "float64"]]) def test_math_fma_op(dtype, device): @@ -1014,7 +994,6 @@ def kernel(Z, X, Y, W, SIZE: tl.constexpr): torch.testing.assert_close(z_tri, z_ref) -@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("expr", ["tl.math.fdiv(x, y)", "tl.math.div_rn(x, y)"]) @pytest.mark.parametrize("num_ctas", num_ctas_list) @@ -1027,7 +1006,6 @@ def test_math_divide_op(expr, num_ctas, device): # ------------- # test precise math # ------------- -@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("expr_prec, expr_ref", [('tl.math.sqrt_rn(x)', 'tl.math.sqrt(x.to(tl.float64)).to(tl.float32)'), @@ -1068,7 +1046,6 @@ def kernel(X, Y, OUT, OUT_REF, BLOCK: tl.constexpr): # ---------------- -@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("dtype_x", [(dtype_x) for dtype_x in dtypes_with_bfloat16]) def test_abs(dtype_x, device): @@ -1114,7 +1091,6 @@ def abs_kernel(X, Z, SIZE: tl.constexpr): # ---------------- -@pytest.mark.cpu @pytest.mark.interpreter def test_shapes_as_params(device): @@ -1184,7 +1160,6 @@ def make_ptr_str(name, shape): return f"{name} + {' + '.join(offsets)}" -@pytest.mark.cpu # TODO: handle `%4 = triton_gpu.convert_layout %3 : tensor<32xi32, #blocked0> -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>`` @pytest.mark.parametrize("expr, dtype_str", [(f'x[{s}]', d) for s in ['None, :', ':, None', 'None, :, :', ':, :, None'] @@ -1254,7 +1229,6 @@ def tuples_fn(a, b): a * b -@pytest.mark.cpu @pytest.mark.interpreter def test_tuples(device): @@ -1347,7 +1321,6 @@ def noinline_multi_values_fn(x, y, Z): tl.store(Z, z) -@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("mode", ["simple", "call_graph", "shared", "dynamic", "multi_values"]) def test_noinline(mode, device): @@ -1601,7 +1574,6 @@ def change_value(X, BLOCK_SIZE: tl.constexpr, sem: tl.constexpr): # --------------- -@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("dtype_x, dtype_z, bitcast, size", [(dtype_x, dtype_z, False, 1024) for dtype_x in dtypes for dtype_z in dtypes] + [ @@ -4306,7 +4278,6 @@ def kernel(VALUE, X): # ---------------- -@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("op", ['+', '-', '*', '/', '%', '<', '>', '<<', '>>', '&', '^', '|']) @pytest.mark.parametrize("is_lhs_constexpr", [False, True]) @@ -4344,7 +4315,6 @@ def kernel(Z, X, Y): np.testing.assert_allclose(z, to_numpy(z_tri), rtol=1e-3) -@pytest.mark.cpu @pytest.mark.interpreter def test_constexpr_shape(device): @@ -4358,7 +4328,6 @@ def kernel(X): np.testing.assert_equal(to_numpy(x_tri), np.arange(0, 256)) -@pytest.mark.cpu @pytest.mark.interpreter def test_constexpr_scalar_shape(device): @@ -4376,7 +4345,6 @@ def kernel(X, s): reshape_list = [((64, ), (8, 8)), ((2, 32), (16, 4)), ((512, ), (2, 2, 2, 2, 2, 2, 2, 2, 2)), ((64, 32), (16, 8, 16))] -@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("formats", reshape_list) def test_reshape(formats, device): diff --git a/python/triton/runtime/jit.py b/python/triton/runtime/jit.py index 45178a40bb29..8f47d046ac27 100644 --- a/python/triton/runtime/jit.py +++ b/python/triton/runtime/jit.py @@ -605,6 +605,13 @@ def run(self, *args, grid, warmup, **kwargs): sigvals = sig_and_spec[:len(sigkeys)] signature = {k: ('*i8' if (v == 'none') else v) for (k, v) in zip(sigkeys, sigvals)} + # The CPU launcher will provide the grid ids directly to the kernel. + # Note that this design is interim and subject to change. + if target.backend == 'cpu': + signature["__grid0"] = 'i32' + signature["__grid1"] = 'i32' + signature["__grid2"] = 'i32' + configs = (backend.get_attrs_descriptor(self.params, bound_vals), ) constant_params = configs[0].get_constants() constants = { diff --git a/third_party/cpu/CMakeLists.txt b/third_party/cpu/CMakeLists.txt index d8be71ad6c11..683889547b0a 100644 --- a/third_party/cpu/CMakeLists.txt +++ b/third_party/cpu/CMakeLists.txt @@ -1,8 +1,3 @@ -include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include) -include_directories(${CMAKE_CURRENT_BINARY_DIR}/include) -add_subdirectory(include) -add_subdirectory(lib) if(TRITON_BUILD_PYTHON_MODULE) add_triton_plugin(TritonCPU ${CMAKE_CURRENT_SOURCE_DIR}/triton_cpu.cc LINK_LIBS TritonCPUToLLVM) - target_link_libraries(TritonCPU PUBLIC MLIRMathToLibm) endif() diff --git a/third_party/cpu/backend/compiler.py b/third_party/cpu/backend/compiler.py index 357b5f448fe9..3c293cdf468f 100644 --- a/third_party/cpu/backend/compiler.py +++ b/third_party/cpu/backend/compiler.py @@ -4,7 +4,7 @@ import re from dataclasses import dataclass -from typing import Any, Tuple +from typing import Any from triton._C.libtriton import cpu, ir, llvm, passes from triton.backends.compiler import BaseBackend, GPUTarget @@ -20,8 +20,6 @@ class CPUOptions: cluster_dims: tuple = (1, 1, 1) extern_libs: dict = None debug: bool = False - allowed_dot_input_precisions: Tuple[str] = ("ieee",) - allow_fp8e4nv: bool = False # TODO: We may introduce CPU-specific options like # of cores. @@ -42,7 +40,7 @@ def supports_target(target: GPUTarget): def __init__(self, target: tuple) -> None: super().__init__(target) - self.binary_ext = "bc" + self.binary_ext = "exe" def parse_options(self, opts) -> Any: args = {k: opts[k] for k in CPUOptions.__dataclass_fields__.keys() if k in opts} @@ -64,6 +62,7 @@ def make_ttir(mod, metadata, opt): pm = ir.pass_manager(mod.context) pm.enable_debug() passes.common.add_inliner(pm) + passes.ttir.add_rewrite_tensor_pointer(pm) passes.ttir.add_combine(pm) passes.common.add_canonicalizer(pm) passes.ttir.add_reorder_broadcast(pm) @@ -78,34 +77,33 @@ def make_ttcir(mod, metadata, opt): # TTIR -> TTCIR pm = ir.pass_manager(mod.context) pm.enable_debug() - cpu.passes.ttcpuir.add_triton_to_triton_cpu_pipeline(pm) + passes.ttir.add_convert_to_ttcpuir(pm) + + # + # TODO: + # + passes.common.add_cse(pm) passes.common.add_symbol_dce(pm) - passes.common.add_canonicalizer(pm) pm.run(mod) - metadata["cluster_dims"] = (opt.cluster_dims[0], opt.cluster_dims[1], opt.cluster_dims[2]) return mod @staticmethod def make_llir(src, metadata, options): - # warp-specialization mutates num_warps - num_warp_groups = src.get_int_attr("triton_gpu.num-warp-groups-per-cta") - if num_warp_groups is not None: - metadata["num_warps"] *= num_warp_groups - metadata["threads_per_warp"] = 1 mod = src # TritonCPU -> LLVM-IR (MLIR) pm = ir.pass_manager(mod.context) pm.enable_debug() passes.convert.add_scf_to_cf(pm) passes.convert.add_index_to_llvmir(pm) - cpu.passes.ttcpuir.add_triton_cpu_to_llvmir_pipeline(pm) - passes.convert.add_math_to_llvmir(pm) - cpu.passes.ttcpuir.add_math_to_libm(pm) - cpu.passes.ttcpuir.add_vector_to_llvmir(pm) - cpu.passes.ttcpuir.add_memref_to_llvmir(pm) + + cpu.passes.ttcpuir.add_to_llvmir(pm) + passes.common.add_canonicalizer(pm) + passes.common.add_cse(pm) + + passes.convert.add_scf_to_cf(pm) + passes.convert.add_cf_to_llvmir(pm) passes.convert.add_arith_to_llvmir(pm) - cpu.passes.ttcpuir.add_func_to_llvmir(pm) passes.common.add_canonicalizer(pm) passes.common.add_cse(pm) passes.common.add_symbol_dce(pm) @@ -113,40 +111,45 @@ def make_llir(src, metadata, options): passes.llvmir.add_di_scope(pm) pm.run(mod) - # Find kernel fn - kernel_names = cpu.find_kernel_names(mod) - assert len(kernel_names) == 1, f"expected exactly 1 kernel in a module, got {kernel_names}" - # LLVM-IR (MLIR) -> LLVM-IR (LLVM) llvm.init_targets() context = llvm.context() llvm_mod = llvm.to_module(mod, context) - llvm.set_host_target(llvm_mod) - #if options.extern_libs: - # paths = [path for (name, path) in options.extern_libs] - # llvm.link_extern_libs(llvm_mod, paths) + + # TODO: + if not llvm_mod: + metadata["shared"] = 0 + return src + + if options.extern_libs: + paths = [path for (name, path) in options.extern_libs] + llvm.link_extern_libs(llvm_mod, paths) llvm.optimize_module(llvm_mod, llvm.OPTIMIZE_O3) - # Get some metadata + + # CPU doesn't have SMEM, but just to make it work for now. metadata["shared"] = 0 - metadata["name"] = kernel_names[0] + + # Cleanup ret = str(llvm_mod) del llvm_mod del context return ret @staticmethod - def make_bc(src, metadata, options): - if os.environ.get("TRITON_CPU_ASM_DUMP", "0") == "1": - print("********** Module ASM **********") - print(llvm.translate_to_host_asm(src)) - ret = llvm.translate_to_bc(src) - return ret + def make_exe(src, metadata, options): + # Just a quick hack while developing the backend. + names = re.findall(r"\s+define void @([a-zA-Z_][a-zA-Z0-9_]*)\(", str(src)) + assert len(names) == 1 + metadata["name"] = names[0] + + # TODO: Call llc to create an executable. + return src def add_stages(self, stages, options): stages["ttir"] = lambda src, metadata: self.make_ttir(src, metadata, options) stages["ttcir"] = lambda src, metadata: self.make_ttcir(src, metadata, options) stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options) - stages["bc"] = lambda src, metadata: self.make_bc(src, metadata, options) + stages["exe"] = lambda src, metadata: self.make_exe(src, metadata, options) @functools.lru_cache() def hash(self): diff --git a/third_party/cpu/backend/driver.cpp b/third_party/cpu/backend/driver.cpp deleted file mode 100644 index babff3dfdebe..000000000000 --- a/third_party/cpu/backend/driver.cpp +++ /dev/null @@ -1,224 +0,0 @@ -//===- driver.cpp ---------------------------------------------------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "llvm/Bitcode/BitcodeReader.h" -#include "llvm/ExecutionEngine/Orc/CompileUtils.h" -#include "llvm/ExecutionEngine/Orc/Core.h" -#include "llvm/ExecutionEngine/Orc/ExecutionUtils.h" -#include "llvm/ExecutionEngine/Orc/IRCompileLayer.h" -#include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h" -#include "llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h" -#include "llvm/ExecutionEngine/Orc/ThreadSafeModule.h" -#include "llvm/ExecutionEngine/SectionMemoryManager.h" -#include "llvm/IR/LLVMContext.h" -#include "llvm/IR/Module.h" -#include "llvm/Support/MemoryBuffer.h" -#include "llvm/Support/TargetSelect.h" - -#include -#include -#include -#include -#include -#include -#include - -#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION -#include - -static PyObject *getDeviceProperties(PyObject *self, PyObject *args) { - int device_id; - if (!PyArg_ParseTuple(args, "i", &device_id)) - return NULL; - - return Py_BuildValue("{s:i}", "max_shared_mem", 0); -} - -bool getBoolEnv(const std::string &env) { - const char *s = std::getenv(env.c_str()); - std::string str(s ? s : ""); - std::transform(str.begin(), str.end(), str.begin(), - [](unsigned char c) { return std::tolower(c); }); - return (str == "on" || str == "true" || str == "1"); -} - -llvm::orc::ThreadSafeContext &getThreadSafeContext() { - static llvm::orc::ThreadSafeContext tsc; - static std::once_flag init_flag; - std::call_once(init_flag, []() { - auto context = std::make_unique(); - tsc = llvm::orc::ThreadSafeContext(std::move(context)); - }); - return tsc; -} - -std::string llvmErrToString(const llvm::Error &err) { - std::string res; - llvm::raw_string_ostream os(res); - os << err; - return res; -}; - -struct CompiledKernel { - std::unique_ptr execution_session; - std::unique_ptr data_layout; - std::unique_ptr mangle; - std::unique_ptr object_layer; - std::unique_ptr compiler_layer; - llvm::orc::JITDylib *dylib = nullptr; - - CompiledKernel() = default; - CompiledKernel(CompiledKernel &&) = default; - - ~CompiledKernel() { - if (execution_session) - llvm::cantFail(execution_session->endSession()); - } -}; - -std::vector> compiled_kernels; - -static PyObject *loadBitcode(PyObject *self, PyObject *args) { - const char *name; - int shared; - PyObject *py_bytes; - int devId; - - if (!PyArg_ParseTuple(args, "sSii", &name, &py_bytes, &shared, &devId)) { - std::cerr << "loadBitcode arg parse failed" << std::endl; - return NULL; - } - - std::string kernel_name = name; - size_t binary_size = PyBytes_Size(py_bytes); - const char *binary_ptr = PyBytes_AsString(py_bytes); - - llvm::LLVMContext context; - auto buf = llvm::MemoryBuffer::getMemBuffer( - llvm::StringRef(binary_ptr, binary_size)); - auto mod = llvm::parseBitcodeFile(*buf, context); - if (!mod) { - std::cerr << "Failed to parse LLVM bitcode module" << std::endl; - return NULL; - } - - if (getBoolEnv("MLIR_ENABLE_DUMP")) { - llvm::errs() << "********** Loaded Module (kernel_name=" << name - << ") **********\n" - << **mod << "\n"; - } - - auto init_err = llvm::InitializeNativeTarget(); - if (init_err) { - std::cerr << "Failed to initialize native target." << std::endl; - return NULL; - } - - llvm::InitializeNativeTargetAsmPrinter(); - llvm::InitializeNativeTargetAsmParser(); - - auto self_epc = - llvm::cantFail(llvm::orc::SelfExecutorProcessControl::Create()); - - auto detect_host_res = llvm::orc::JITTargetMachineBuilder::detectHost(); - if (!detect_host_res) { - std::cerr << "Failed to initialize JITTargetMachineBuilder: " - << llvmErrToString(detect_host_res.takeError()); - return NULL; - } - llvm::orc::JITTargetMachineBuilder tmb = std::move(*detect_host_res); - - auto data_layout_res = tmb.getDefaultDataLayoutForTarget(); - if (!data_layout_res) { - std::cerr << "Failed to initialize data layout: " - << llvmErrToString(data_layout_res.takeError()); - return NULL; - } - - CompiledKernel kernel; - kernel.execution_session = - std::make_unique(std::move(self_epc)); - kernel.data_layout = - std::make_unique(std::move(*data_layout_res)); - kernel.mangle = std::make_unique( - *kernel.execution_session, *kernel.data_layout); - kernel.object_layer = std::make_unique( - *kernel.execution_session, - []() { return std::make_unique(); }); - kernel.compiler_layer = std::make_unique( - *kernel.execution_session, *kernel.object_layer, - std::make_unique(std::move(tmb))); - - auto dylib_res = kernel.execution_session->createJITDylib("
"); - if (!dylib_res) { - std::cerr << "Failed to create initialize JITDylib: " - << llvmErrToString(dylib_res.takeError()); - return NULL; - } - - kernel.dylib = &(*dylib_res); - kernel.dylib->addGenerator(llvm::cantFail( - llvm::orc::DynamicLibrarySearchGenerator::GetForCurrentProcess( - kernel.data_layout->getGlobalPrefix()))); - - // Compile module. - (**mod).setDataLayout(*kernel.data_layout); - llvm::orc::ThreadSafeModule tsm(std::move(*mod), getThreadSafeContext()); - auto err = kernel.compiler_layer->add(*kernel.dylib, std::move(tsm)); - if (err) { - std::cerr << "Cannot add LLVM module: " << llvmErrToString(err); - return NULL; - } - - // Find kernel function pointer. - auto lookup_res = - kernel.execution_session->lookup({kernel.dylib}, (*kernel.mangle)(name)); - if (!lookup_res) { - std::cerr << "Failed to find function " << std::string(name) - << "\nError: " << llvmErrToString(lookup_res.takeError()); - return NULL; - } - uint64_t fn_ptr = lookup_res->getAddress().getValue(); - - compiled_kernels.push_back( - std::make_unique(std::move(kernel))); - auto *kernel_ptr = compiled_kernels.back().get(); - - return Py_BuildValue("(KKii)", reinterpret_cast(kernel_ptr), - reinterpret_cast(fn_ptr), 0, 0); -} - -static PyObject *initContext(PyObject *self, PyObject *args) { - return Py_BuildValue("(K)", (uint64_t)0); -} - -static PyObject *initDevices(PyObject *self, PyObject *args) { - return Py_BuildValue("(i)", 1); -} - -static PyMethodDef ModuleMethods[] = { - {"load_binary", loadBitcode, METH_VARARGS, - "Load provided SPV into ZE driver"}, - {"get_device_properties", getDeviceProperties, METH_VARARGS, - "Get the properties for a given device"}, - {NULL, NULL, 0, NULL} // sentinel -}; - -static struct PyModuleDef ModuleDef = {PyModuleDef_HEAD_INIT, "cpu_utils", - NULL, // documentation - -1, // size - ModuleMethods}; - -PyMODINIT_FUNC PyInit_cpu_utils(void) { - PyObject *m = PyModule_Create(&ModuleDef); - if (m == NULL) { - return NULL; - } - PyModule_AddFunctions(m, ModuleMethods); - return m; -} diff --git a/third_party/cpu/backend/driver.py b/third_party/cpu/backend/driver.py index 743684d2640f..3f3816a99b9f 100644 --- a/third_party/cpu/backend/driver.py +++ b/third_party/cpu/backend/driver.py @@ -1,100 +1,5 @@ -import os -import hashlib -import tempfile -from pathlib import Path -from triton.runtime.build import _build -from triton.runtime.cache import get_cache_manager -from triton.backends.driver import DriverBase from triton.backends.compiler import GPUTarget - -dirname = os.getenv("TRITON_SYS_PATH", default="/usr/local") -llvm_root = os.getenv("LLVM_PATH", default="~/.triton/llvm") -llvm_root = os.path.expanduser(llvm_root) -llvm_dirs = os.listdir(llvm_root) -if len(llvm_dirs) == 1: - llvm_root = os.path.join(llvm_root, llvm_dirs[0]) -include_dir = [ - os.path.join(dirname, "include"), - os.path.join(llvm_root, "include"), -] -library_dir = [os.path.join(dirname, "lib"), os.path.join(llvm_root, "lib")] -libraries = [ - "LLVMOrcJIT", - "LLVMPasses", - "LLVMX86CodeGen", - "LLVMX86AsmParser", - "LLVMX86Desc", - "LLVMX86Info", - "LLVMGlobalISel", - "LLVMSelectionDAG", - "LLVMHipStdPar", - "LLVMCoroutines", - "LLVMipo", - "LLVMFrontendOpenMP", - "LLVMInstrumentation", - "LLVMAsmPrinter", - "LLVMCodeGen", - "LLVMObjCARCOpts", - "LLVMLinker", - "LLVMVectorize", - "LLVMScalarOpts", - "LLVMInstCombine", - "LLVMFrontendOffloading", - "LLVMExecutionEngine", - "LLVMAggressiveInstCombine", - "LLVMTransformUtils", - "LLVMTarget", - "LLVMRuntimeDyld", - "LLVMJITLink", - "LLVMIRPrinter", - "LLVMBitWriter", - "LLVMAnalysis", - "LLVMProfileData", - "LLVMSymbolize", - "LLVMDebugInfoDWARF", - "LLVMObject", - "LLVMTextAPI", - "LLVMMCParser", - "LLVMMCDisassembler", - "LLVMMC", - "LLVMIRReader", - "LLVMCFGuard", - "LLVMBitReader", - "LLVMAsmParser", - "LLVMCore", - "LLVMBinaryFormat", - "LLVMOrcTargetProcess", - "LLVMTargetParser", - "LLVMRemarks", - "LLVMOrcShared", - "LLVMOption", - "LLVMDebugInfoCodeView", - "LLVMCodeGenTypes", - "LLVMBitstreamReader", - "LLVMSupport", - "LLVMDemangle", - "stdc++", -] - - -def compile_module_from_src(src, name): - key = hashlib.md5(src.encode("utf-8")).hexdigest() - cache = get_cache_manager(key) - cache_path = cache.get_file(f"{name}.so") - if cache_path is None: - with tempfile.TemporaryDirectory() as tmpdir: - src_path = os.path.join(tmpdir, "main.cpp") - with open(src_path, "w") as f: - f.write(src) - so = _build(name, src_path, tmpdir, library_dir, include_dir, libraries) - with open(so, "rb") as f: - cache_path = cache.put(f.read(), f"{name}.so", binary=True) - import importlib.util - spec = importlib.util.spec_from_file_location(name, cache_path) - mod = importlib.util.module_from_spec(spec) - spec.loader.exec_module(mod) - return mod - +from triton.backends.driver import CPUDriverBase # ------------------------ # Utils @@ -110,12 +15,22 @@ def __new__(cls): def __init__(self): pass - dirname = os.path.dirname(os.path.realpath(__file__)) - mod = compile_module_from_src(Path(os.path.join(dirname, "driver.cpp")).read_text(), "cpu_utils") - self.load_binary = mod.load_binary - def get_device_properties(self, *args): - return {"max_shared_mem": 0} + @staticmethod + def get_device_properties(device): + # This is just dummy for now. We will need to implement driver.c. + return { + "max_shared_mem": 0, + "multiprocessor_count": 0, + "sm_clock_rate": 0, + "mem_clock_rate": 0, + "mem_bus_width": 0, + } + + @staticmethod + def load_binary(name, kernel_asm, shared, device): + # This is just dummy for now. We will need to implement driver.c. + return (None, kernel_asm, 0, 0) # ------------------------ @@ -123,228 +38,27 @@ def get_device_properties(self, *args): # ------------------------ -def ty_to_cpp(ty): - if ty[0] == '*': - return "void*" - return { - "i1": "int32_t", - "i8": "int8_t", - "i16": "int16_t", - "i32": "int32_t", - "i64": "int64_t", - "u1": "uint32_t", - "u8": "uint8_t", - "u16": "uint16_t", - "u32": "uint32_t", - "u64": "uint64_t", - "fp16": "float", - "bf16": "float", - "fp32": "float", - "f32": "float", - "fp64": "double", - }[ty] - - def make_launcher(constants, signature, ids): - # Record the end of regular arguments; - # subsequent arguments are architecture-specific descriptors. - arg_decls = ', '.join(f"{ty_to_cpp(ty)} arg{i}" for i, ty in signature.items()) - arg_types = (', '.join(f"{ty_to_cpp(ty)}" for i, ty in signature.items()) + ", " if len(signature) > 0 else '') + "uint32_t, uint32_t, uint32_t" - - def _extracted_type(ty): - if ty[0] == '*': - return "PyObject*" - return ty_to_cpp(ty) - - def format_of(ty): - return { - "PyObject*": "O", - "float": "f", - "double": "d", - "long": "l", - "int8_t": "b", - "int16_t": "h", - "int32_t": "i", - "int64_t": "l", - "uint8_t": "B", - "uint16_t": "H", - "uint32_t": "I", - "uint64_t": "K", - }[ty] - - args_format = ''.join([format_of(_extracted_type(ty)) for ty in signature.values()]) - format = "iiiOKOOOO" + args_format - args_list = ', '.join(f"arg{i}" for i, ty in signature.items()) if len(signature) > 0 else '' - arg_ptrs_list = ', '.join(f"&arg{i}" for i, ty in signature.items()) if len(signature) > 0 else '' - - # generate glue code - src = f""" -#include -#include -#include -#include - -#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION -#include -#include - -using kernel_ptr_t = void(*)({arg_types}); - -typedef struct _DevicePtrInfo {{ - void* dev_ptr; - bool valid; -}} DevicePtrInfo; - -static inline DevicePtrInfo getPointer(PyObject *obj, int idx) {{ - DevicePtrInfo ptr_info; - ptr_info.dev_ptr = 0; - ptr_info.valid = true; - if (PyLong_Check(obj)) {{ - ptr_info.dev_ptr = (void*) PyLong_AsLongLong(obj); - return ptr_info; - }} - if (obj == Py_None) {{ - // valid nullptr - return ptr_info; - }} - PyObject *ptr = PyObject_GetAttrString(obj, "data_ptr"); - if(ptr){{ - PyObject *empty_tuple = PyTuple_New(0); - PyObject *ret = PyObject_Call(ptr, empty_tuple, NULL); - Py_DECREF(empty_tuple); - Py_DECREF(ptr); - if (!PyLong_Check(ret)) {{ - PyErr_SetString(PyExc_TypeError, "data_ptr method of Pointer object must return 64-bit int"); - ptr_info.valid = false; - return ptr_info; - }} - ptr_info.dev_ptr = (void*) PyLong_AsLongLong(ret); - if(!ptr_info.dev_ptr) {{ - return ptr_info; - }} - Py_DECREF(ret); // Thanks ChatGPT! - return ptr_info; - }} - PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method"); - ptr_info.valid = false; - return ptr_info; -}} - -static void run_omp_kernels(uint32_t gridX, uint32_t gridY, uint32_t gridZ, kernel_ptr_t kernel_ptr {', ' + arg_decls if len(arg_decls) > 0 else ''}) {{ - // TODO: add OMP pragmas to run in parallel - for (uint32_t z = 0; z < gridZ; ++z) {{ - for (uint32_t y = 0; y < gridY; ++y) {{ - for (uint32_t x = 0; x < gridX; ++x) {{ - (*kernel_ptr)({args_list + ', ' if len(arg_decls) > 0 else ''} x, y, z); - }} - }} - }} -}} - -static PyObject* launch(PyObject* self, PyObject* args) {{ - - - int gridX, gridY, gridZ; - PyObject *launch_enter_hook = NULL; - PyObject *launch_exit_hook = NULL; - PyObject *kernel_metadata = NULL; - PyObject *launch_metadata = NULL; - PyObject *py_obj_stream; - void* pKrnl; - - {' '.join([f"{_extracted_type(ty)} arg{i}; " for i, ty in signature.items()])} - if(!PyArg_ParseTuple(args, \"{format}\", &gridX, &gridY, &gridZ, &py_obj_stream, &pKrnl, - &kernel_metadata, &launch_metadata, - &launch_enter_hook, &launch_exit_hook {', ' + arg_ptrs_list if len(signature) > 0 else ''})) {{ - return NULL; - }} - - void *pStream = PyLong_AsVoidPtr(py_obj_stream); - kernel_ptr_t kernel_ptr = reinterpret_cast(pKrnl); - - // extract launch metadata - if (launch_enter_hook != Py_None){{ - PyObject* args = Py_BuildValue("(O)", launch_metadata); - PyObject* ret = PyObject_CallObject(launch_enter_hook, args); - Py_DECREF(args); - if (!ret) - return NULL; - }} - - {"; ".join([f"DevicePtrInfo ptr_info{i} = getPointer(arg{i}, {i}); if (!ptr_info{i}.valid) return NULL;" if ty[0] == "*" else "" for i, ty in signature.items()])}; - run_omp_kernels(gridX, gridY, gridZ, kernel_ptr {',' + ', '.join(f"ptr_info{i}.dev_ptr" if ty[0]=="*" else f"arg{i}" for i, ty in signature.items()) if len(signature) > 0 else ''}); - - if(launch_exit_hook != Py_None){{ - PyObject* args = Py_BuildValue("(O)", launch_metadata); - PyObject* ret = PyObject_CallObject(launch_exit_hook, args); - Py_DECREF(args); - if (!ret) - return NULL; - }} - - if (PyErr_Occurred()) {{ - return NULL; - }} - - // return None - Py_INCREF(Py_None); - return Py_None; -}} - -static PyMethodDef ModuleMethods[] = {{ - {{"launch", launch, METH_VARARGS, "Entry point for all kernels with this signature"}}, - {{NULL, NULL, 0, NULL}} // sentinel -}}; - -static struct PyModuleDef ModuleDef = {{ - PyModuleDef_HEAD_INIT, - \"__triton_cpu_launcher\", - NULL, //documentation - -1, //size - ModuleMethods -}}; - -PyMODINIT_FUNC PyInit___triton_cpu_launcher(void) {{ - PyObject *m = PyModule_Create(&ModuleDef); - if(m == NULL) {{ - return NULL; - }} - PyModule_AddFunctions(m, ModuleMethods); - return m; -}} -""" - return src + pass class CPULauncher(object): def __init__(self, src, metadata): - ids = {"ids_of_const_exprs": src.fn.constexprs if hasattr(src, "fn") else tuple()} - constants = src.constants if hasattr(src, "constants") else dict() - cst_key = lambda i: src.fn.arg_names.index(i) if isinstance(i, str) else i - constants = {cst_key(key): value for key, value in constants.items()} - signature = {cst_key(key): value for key, value in src.signature.items()} - src = make_launcher(constants, signature, ids) - mod = compile_module_from_src(src, "__triton_cpu_launcher") - self.launch = mod.launch + # TODO: + self.launch = lambda *args, **kwargs: None def __call__(self, *args, **kwargs): self.launch(*args, **kwargs) -class CPUDriver(DriverBase): +class CPUDriver(CPUDriverBase): def __init__(self): self.utils = CPUUtils() self.launcher_cls = CPULauncher super().__init__() - def get_current_device(self): - return 0 - - def get_current_stream(self, device): - return 0 - def get_current_target(self): # Capability and warp size are zeros for CPU. # TODO: GPUTarget naming isn't obviously good. diff --git a/third_party/cpu/include/CMakeLists.txt b/third_party/cpu/include/CMakeLists.txt deleted file mode 100644 index fc9a19e52b0d..000000000000 --- a/third_party/cpu/include/CMakeLists.txt +++ /dev/null @@ -1,2 +0,0 @@ -add_subdirectory(TritonCPUToLLVM) -add_subdirectory(TritonToTritonCPU) diff --git a/third_party/cpu/include/TritonCPUToLLVM/CMakeLists.txt b/third_party/cpu/include/TritonCPUToLLVM/CMakeLists.txt deleted file mode 100644 index 64b36523d35d..000000000000 --- a/third_party/cpu/include/TritonCPUToLLVM/CMakeLists.txt +++ /dev/null @@ -1,3 +0,0 @@ -set(LLVM_TARGET_DEFINITIONS Passes.td) -mlir_tablegen(Passes.h.inc -gen-pass-decls --name TritonCPUToLLVM) -add_public_tablegen_target(TritonCPUConversionPassIncGen) diff --git a/third_party/cpu/include/TritonCPUToLLVM/Passes.h b/third_party/cpu/include/TritonCPUToLLVM/Passes.h deleted file mode 100644 index 74f74b00870c..000000000000 --- a/third_party/cpu/include/TritonCPUToLLVM/Passes.h +++ /dev/null @@ -1,36 +0,0 @@ -#ifndef TRITONCPU_CONVERSION_TRITONCPUTOLLVM_PASSES_H -#define TRITONCPU_CONVERSION_TRITONCPUTOLLVM_PASSES_H - -#include "mlir/Conversion/LLVMCommon/TypeConverter.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Transforms/DialectConversion.h" - -#include - -namespace mlir { - -class ModuleOp; -template class OperationPass; - -namespace triton { -namespace cpu { - -#define GEN_PASS_DECL -#include "cpu/include/TritonCPUToLLVM/Passes.h.inc" - -std::unique_ptr> createFuncOpToLLVMPass(); -std::unique_ptr> createMemoryOpToLLVMPass(); -std::unique_ptr> createGetProgramIdOpToLLVMPass(); - -void tritonCPUToLLVMPipelineBuilder(OpPassManager &pm); -void registerTritonCPUToLLVMPipeline(); - -#define GEN_PASS_REGISTRATION -#include "cpu/include/TritonCPUToLLVM/Passes.h.inc" - -} // namespace cpu -} // namespace triton - -} // namespace mlir - -#endif diff --git a/third_party/cpu/include/TritonCPUToLLVM/Passes.td b/third_party/cpu/include/TritonCPUToLLVM/Passes.td deleted file mode 100644 index c75b58b572f1..000000000000 --- a/third_party/cpu/include/TritonCPUToLLVM/Passes.td +++ /dev/null @@ -1,46 +0,0 @@ -#ifndef TRITONCPU_CONVERSION_PASSES -#define TRITONCPU_CONVERSION_PASSES - -include "mlir/Pass/PassBase.td" - -def FuncOpToLLVM : Pass<"triton-cpu-func-op-to-llvm", "mlir::ModuleOp"> { - let summary = "Convert FuncOp to LLVM for CPU."; - let description = [{ - - }]; - let constructor = "mlir::triton::cpu::createFuncOpToLLVMPass()"; - - let dependentDialects = ["mlir::arith::ArithDialect", - "mlir::scf::SCFDialect", - "mlir::LLVM::LLVMDialect", - "mlir::triton::TritonDialect", - "mlir::triton::cpu::TritonCPUDialect"]; -} - -def MemoryOpToLLVM : Pass<"triton-cpu-memory-op-to-llvm", "mlir::ModuleOp"> { - let summary = "Convert Triton memory operations to LLVM for CPU."; - let description = [{ - - }]; - let constructor = "mlir::triton::cpu::createMemoryOpToLLVMPass()"; - - let dependentDialects = ["mlir::arith::ArithDialect", - "mlir::scf::SCFDialect", - "mlir::memref::MemRefDialect", - "mlir::LLVM::LLVMDialect", - "mlir::triton::TritonDialect", - "mlir::triton::cpu::TritonCPUDialect"]; -} - -def GetProgramIdOpToLLVM : Pass<"triton-cpu-get-program-id-op-to-llvm", "mlir::ModuleOp"> { - let summary = "Convert Triton GetProgramId to LLVM for CPU."; - let description = [{ - - }]; - let constructor = "mlir::triton::cpu::createGetProgramIdOpToLLVMPass()"; - - let dependentDialects = ["mlir::LLVM::LLVMDialect", - "mlir::triton::TritonDialect"]; -} - -#endif diff --git a/third_party/cpu/include/TritonToTritonCPU/CMakeLists.txt b/third_party/cpu/include/TritonToTritonCPU/CMakeLists.txt deleted file mode 100644 index 56e231273ed6..000000000000 --- a/third_party/cpu/include/TritonToTritonCPU/CMakeLists.txt +++ /dev/null @@ -1,3 +0,0 @@ -set(LLVM_TARGET_DEFINITIONS Passes.td) -mlir_tablegen(Passes.h.inc -gen-pass-decls --name TritonToTritonCPU) -add_public_tablegen_target(TritonToTritonCPUPassIncGen) diff --git a/third_party/cpu/include/TritonToTritonCPU/Passes.h b/third_party/cpu/include/TritonToTritonCPU/Passes.h deleted file mode 100644 index ab98a8741a16..000000000000 --- a/third_party/cpu/include/TritonToTritonCPU/Passes.h +++ /dev/null @@ -1,37 +0,0 @@ -#ifndef TRITONTOTRITONCPU_CONVERSION_PASSES_H -#define TRITONTOTRITONCPU_CONVERSION_PASSES_H - -#include "mlir/Conversion/LLVMCommon/TypeConverter.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Transforms/DialectConversion.h" - -#include - -namespace mlir { - -class ModuleOp; -template class OperationPass; - -namespace triton { -namespace cpu { - -#define GEN_PASS_DECL -#include "cpu/include/TritonToTritonCPU/Passes.h.inc" - -std::unique_ptr> createConvertElementwiseOps(); -std::unique_ptr> createConvertMemoryOps(); -std::unique_ptr> createConvertPtrOps(); -std::unique_ptr> createConvertDotOp(); - -void tritonToTritonCPUPipelineBuilder(OpPassManager &pm); -void registerTritonToTritonCPUPipeline(); - -#define GEN_PASS_REGISTRATION -#include "cpu/include/TritonToTritonCPU/Passes.h.inc" - -} // namespace cpu -} // namespace triton - -} // namespace mlir - -#endif diff --git a/third_party/cpu/include/TritonToTritonCPU/Passes.td b/third_party/cpu/include/TritonToTritonCPU/Passes.td deleted file mode 100644 index 77e6528c6943..000000000000 --- a/third_party/cpu/include/TritonToTritonCPU/Passes.td +++ /dev/null @@ -1,62 +0,0 @@ -#ifndef TRITONTOTRITONCPU_CONVERSION_PASSES -#define TRITONTOTRITONCPU_CONVERSION_PASSES - -include "mlir/Pass/PassBase.td" - -def ConvertMemoryOps : Pass<"triton-cpu-convert-memory-ops", "mlir::ModuleOp"> { - let summary = "Convert Triton memory ops."; - let description = [{ - - }]; - let constructor = "mlir::triton::cpu::createConvertMemoryOps()"; - - let dependentDialects = ["mlir::arith::ArithDialect", - "mlir::memref::MemRefDialect", - "mlir::vector::VectorDialect", - "mlir::triton::TritonDialect", - "mlir::triton::cpu::TritonCPUDialect"]; -} - -def ConvertElementwiseOps : Pass<"triton-cpu-convert-elementwise-ops", "mlir::ModuleOp"> { - let summary = "Convert elementwise ops."; - let description = [{ - - }]; - let constructor = "mlir::triton::cpu::createConvertElementwiseOps()"; - - let dependentDialects = ["mlir::arith::ArithDialect", - "mlir::memref::MemRefDialect", - "mlir::vector::VectorDialect", - "mlir::triton::TritonDialect", - "mlir::triton::cpu::TritonCPUDialect"]; -} - -def ConvertPtrOps : Pass<"triton-cpu-convert-ptr-ops", "mlir::ModuleOp"> { - let summary = "Convert Triton ops related to pointer arithmetics."; - let description = [{ - - }]; - let constructor = "mlir::triton::cpu::createConvertPtrOps()"; - - let dependentDialects = ["mlir::arith::ArithDialect", - "mlir::memref::MemRefDialect", - "mlir::vector::VectorDialect", - "mlir::triton::TritonDialect", - "mlir::triton::cpu::TritonCPUDialect"]; -} - -def ConvertDotOp : Pass<"triton-cpu-convert-dot-op", "mlir::ModuleOp"> { - let summary = "Convert Triton DotOp."; - let description = [{ - - }]; - let constructor = "mlir::triton::cpu::createConvertDotOp()"; - - let dependentDialects = ["mlir::arith::ArithDialect", - "mlir::memref::MemRefDialect", - "mlir::vector::VectorDialect", - "mlir::triton::TritonDialect", - "mlir::triton::cpu::TritonCPUDialect"]; -} - -#endif diff --git a/third_party/cpu/lib/CMakeLists.txt b/third_party/cpu/lib/CMakeLists.txt deleted file mode 100644 index fc9a19e52b0d..000000000000 --- a/third_party/cpu/lib/CMakeLists.txt +++ /dev/null @@ -1,2 +0,0 @@ -add_subdirectory(TritonCPUToLLVM) -add_subdirectory(TritonToTritonCPU) diff --git a/third_party/cpu/lib/TritonCPUToLLVM/CMakeLists.txt b/third_party/cpu/lib/TritonCPUToLLVM/CMakeLists.txt deleted file mode 100644 index 884c9352ef1b..000000000000 --- a/third_party/cpu/lib/TritonCPUToLLVM/CMakeLists.txt +++ /dev/null @@ -1,13 +0,0 @@ -add_triton_library(TritonCPUToLLVM - FuncOpToLLVM.cpp - GetProgramIdOpToLLVM.cpp - MemoryOpToLLVM.cpp - Pipeline.cpp - TypeConverter.cpp - - DEPENDS - TritonCPUToLLVMConversionPassIncGen - - LINK_LIBS PUBLIC - MLIRVectorToLLVMPass -) diff --git a/third_party/cpu/lib/TritonCPUToLLVM/FuncOpToLLVM.cpp b/third_party/cpu/lib/TritonCPUToLLVM/FuncOpToLLVM.cpp deleted file mode 100644 index 5895341fc34b..000000000000 --- a/third_party/cpu/lib/TritonCPUToLLVM/FuncOpToLLVM.cpp +++ /dev/null @@ -1,276 +0,0 @@ -#include "TypeConverter.h" - -#include "cpu/include/TritonCPUToLLVM/Passes.h" - -#include "mlir/Analysis/DataFlowFramework.h" -#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" -#include "mlir/Conversion/LLVMCommon/VectorPattern.h" -#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" -#include "mlir/Dialect/Index/IR/IndexDialect.h" -#include "mlir/Dialect/Index/IR/IndexOps.h" -#include "mlir/Dialect/LLVMIR/LLVMDialect.h" -#include "mlir/Pass/Pass.h" - -#include "triton/Analysis/Allocation.h" -#include "triton/Analysis/AxisInfo.h" -#include "triton/Analysis/Membar.h" -#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" -#include "triton/Conversion/TritonGPUToLLVM/Utility.h" -#include "triton/Dialect/Triton/IR/Dialect.h" -#include "triton/Dialect/TritonCPU/IR/Dialect.h" - -namespace mlir { -namespace triton { -#define GEN_PASS_DEF_FUNCOPTOLLVM -#include "cpu/include/TritonCPUToLLVM/Passes.h.inc" -} // namespace triton -} // namespace mlir - -namespace mlir { -FailureOr -convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp, - ConversionPatternRewriter &rewriter, - const LLVMTypeConverter &converter); -} - -using namespace mlir; -using namespace mlir::triton; - -namespace { - -class TritonLLVMConversionTarget : public ConversionTarget { -public: - explicit TritonLLVMConversionTarget(MLIRContext &ctx) - : ConversionTarget(ctx) { - addLegalDialect(); - addLegalOp(); - } -}; - -struct FuncOpConversion : public ConvertOpToLLVMPattern { - FuncOpConversion(LLVMTypeConverter &converter, PatternBenefit benefit) - : ConvertOpToLLVMPattern(converter, benefit) {} - - /// Only retain those attributes that are not constructed by - /// `LLVMFuncOp::build`. If `filterArgAttrs` is set, also filter out argument - /// attributes. - static void filterFuncAttributes(triton::FuncOp op, bool filterArgAttrs, - SmallVectorImpl &result) { - - for (const auto &attr : op->getAttrs()) { - if (attr.getName() == SymbolTable::getSymbolAttrName() || - attr.getName() == op.getFunctionTypeAttrName() || - attr.getName() == "std.varargs" || - (filterArgAttrs && attr.getName() == op.getArgAttrsAttrName())) - continue; - result.push_back(attr); - } - } - - triton::FuncOp amendProgramIdArgs(triton::FuncOp funcOp, - ConversionPatternRewriter &rewriter) const { - // Push back a variable that indicates the current stack pointer of shared - // memory to the function arguments. - auto loc = funcOp.getLoc(); - auto ctx = funcOp->getContext(); - // 1. Modify the function type to add new arguments. - auto funcTy = funcOp.getFunctionType(); - auto amendedInputTy = llvm::to_vector<4>(funcTy.getInputs()); - amendedInputTy.push_back(i32_ty); - amendedInputTy.push_back(i32_ty); - amendedInputTy.push_back(i32_ty); - auto amendedFuncTy = FunctionType::get(funcTy.getContext(), amendedInputTy, - funcTy.getResults()); - // 2. Modify the argument attributes to add new arguments. - SmallVector amendedAttrs; - filterFuncAttributes(funcOp, /*filterArgAttrs=*/true, amendedAttrs); - SmallVector amendedArgAttrs; - if (funcOp.getAllArgAttrs()) - amendedArgAttrs = llvm::to_vector<4>(funcOp.getAllArgAttrs()); - amendedArgAttrs.emplace_back(DictionaryAttr::get(ctx)); - amendedArgAttrs.emplace_back(DictionaryAttr::get(ctx)); - amendedArgAttrs.emplace_back(DictionaryAttr::get(ctx)); - amendedAttrs.push_back(rewriter.getNamedAttr( - funcOp.getArgAttrsAttrName(), rewriter.getArrayAttr(amendedArgAttrs))); - // 3. Add a new arguments to the region - auto amendedFuncOp = rewriter.create( - funcOp.getLoc(), funcOp.getName(), amendedFuncTy, amendedAttrs); - auto ®ion = funcOp.getBody(); - region.addArgument(funcTy, loc); - rewriter.inlineRegionBefore(region, amendedFuncOp.getBody(), - amendedFuncOp.end()); - return amendedFuncOp; - } - - LogicalResult - matchAndRewrite(triton::FuncOp funcOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - // Prevent LLVM's inliner to inline this function - auto modifiedFuncOp = funcOp; - if (LLVM::isKernel(funcOp)) - modifiedFuncOp = amendProgramIdArgs(modifiedFuncOp, rewriter); - - LLVM::LLVMFuncOp newFuncOp = *mlir::convertFuncOpToLLVMFuncOp( - modifiedFuncOp, rewriter, *getTypeConverter()); - if (!newFuncOp) - return failure(); - - // required by AxisInfoAnalysis - if (LLVM::isKernel(funcOp)) - rewriter.eraseOp(modifiedFuncOp); - rewriter.eraseOp(funcOp); - return success(); - } -}; - -struct ReturnOpConversion : public ConvertOpToLLVMPattern { - using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; - - LogicalResult - matchAndRewrite(triton::ReturnOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - LLVM::ReturnOp newOp; - if (adaptor.getOperands().size() < 2) { - // Single or no return value. - newOp = - rewriter.create(op.getLoc(), adaptor.getOperands()); - } else { - // Pack the results into a struct. - auto funcOp = op->getParentOfType(); - auto packedResultsTy = this->getTypeConverter()->packFunctionResults( - funcOp.getResultTypes()); - Value packedResults = - rewriter.create(op.getLoc(), packedResultsTy); - auto loc = op.getLoc(); - for (auto it : llvm::enumerate(adaptor.getOperands())) { - packedResults = - insert_val(packedResultsTy, packedResults, it.value(), it.index()); - } - newOp = rewriter.create(op.getLoc(), packedResults); - } - newOp->setAttrs(op->getAttrs()); - rewriter.replaceOp(op, newOp->getResults()); - return success(); - } -}; - -// CallOpInterfaceLowering is adapted from -// https://github.com/llvm/llvm-project/blob/fae656b2dd80246c3c6f01e9c77c49560368752c/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp#L485 -struct CallOpConversion : public ConvertOpToLLVMPattern { - CallOpConversion(LLVMTypeConverter &converter, PatternBenefit benefit) - : ConvertOpToLLVMPattern(converter, benefit) {} - - LogicalResult - matchAndRewrite(triton::CallOp callOp, - typename triton::CallOp::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto promotedOperands = promoteOperands(callOp, adaptor, rewriter); - auto newCallOp = - convertCallOpToLLVMCallOp(callOp, promotedOperands, rewriter); - if (!newCallOp) - return failure(); - auto results = getCallOpResults(callOp, newCallOp, rewriter); - rewriter.replaceOp(callOp, results); - return success(); - } - -private: - SmallVector - promoteOperands(triton::CallOp callOp, - typename triton::CallOp::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const { - auto loc = callOp.getLoc(); - auto caller = callOp->getParentOfType(); - auto promotedOperands = this->getTypeConverter()->promoteOperands( - callOp.getLoc(), /*opOperands=*/callOp->getOperands(), - adaptor.getOperands(), rewriter); - return promotedOperands; - } - - LLVM::CallOp - convertCallOpToLLVMCallOp(triton::CallOp callOp, - ArrayRef promotedOperands, - ConversionPatternRewriter &rewriter) const { - // Pack the result types into a struct. - Type packedResult = nullptr; - unsigned numResults = callOp.getNumResults(); - auto resultTypes = llvm::to_vector<4>(callOp.getResultTypes()); - - if (numResults != 0) { - if (!(packedResult = - this->getTypeConverter()->packFunctionResults(resultTypes))) - return nullptr; - } - auto newCallOp = rewriter.create( - callOp.getLoc(), packedResult ? TypeRange(packedResult) : TypeRange(), - promotedOperands, callOp->getAttrs()); - return newCallOp; - } - - SmallVector - getCallOpResults(triton::CallOp callOp, LLVM::CallOp newCallOp, - ConversionPatternRewriter &rewriter) const { - auto numResults = callOp.getNumResults(); - SmallVector results; - if (numResults < 2) { - // If < 2 results, packing did not do anything and we can just return. - results.append(newCallOp.result_begin(), newCallOp.result_end()); - } else { - // Otherwise, it had been converted to an operation producing a structure. - // Extract individual results from the structure and return them as list. - results.reserve(numResults); - for (unsigned i = 0; i < numResults; ++i) { - results.push_back(rewriter.create( - callOp.getLoc(), newCallOp->getResult(0), i)); - } - } - return results; - } -}; - -struct FuncOpToLLVM : public triton::impl::FuncOpToLLVMBase { - using FuncOpToLLVMBase::FuncOpToLLVMBase; - - FuncOpToLLVM() : FuncOpToLLVMBase() {} - - void runOnOperation() override { - MLIRContext *context = &getContext(); - ModuleOp mod = getOperation(); - - mlir::LowerToLLVMOptions option(context); - TritonCPUToLLVMTypeConverter typeConverter(context, option); - TritonLLVMConversionTarget convTarget(*context); - - // Lower tt.func - RewritePatternSet funcPatterns(context); - funcPatterns.add(typeConverter, - /*benefit=*/1); - mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter, - funcPatterns); - if (failed( - applyPartialConversion(mod, convTarget, std::move(funcPatterns)))) - return signalPassFailure(); - - // Lower tt.call, tt.return - int benefit = 10; - RewritePatternSet patterns(context); - patterns.add(typeConverter, benefit); - patterns.add(typeConverter, benefit); - if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) - return signalPassFailure(); - } -}; - -} // anonymous namespace - -namespace mlir { -namespace triton { -namespace cpu { - -std::unique_ptr> createFuncOpToLLVMPass() { - return std::make_unique(); -} - -} // namespace cpu -} // namespace triton -} // namespace mlir diff --git a/third_party/cpu/lib/TritonCPUToLLVM/GetProgramIdOpToLLVM.cpp b/third_party/cpu/lib/TritonCPUToLLVM/GetProgramIdOpToLLVM.cpp deleted file mode 100644 index 4c593f1ff7aa..000000000000 --- a/third_party/cpu/lib/TritonCPUToLLVM/GetProgramIdOpToLLVM.cpp +++ /dev/null @@ -1,98 +0,0 @@ -#include "TypeConverter.h" - -#include "cpu/include/TritonCPUToLLVM/Passes.h" - -#include "mlir/Analysis/DataFlowFramework.h" -#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" -#include "mlir/Conversion/LLVMCommon/VectorPattern.h" -#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" -#include "mlir/Dialect/Index/IR/IndexDialect.h" -#include "mlir/Dialect/Index/IR/IndexOps.h" -#include "mlir/Dialect/LLVMIR/LLVMDialect.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Pass/Pass.h" - -#include "triton/Analysis/Allocation.h" -#include "triton/Analysis/AxisInfo.h" -#include "triton/Analysis/Membar.h" -#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" -#include "triton/Conversion/TritonGPUToLLVM/Utility.h" -#include "triton/Dialect/Triton/IR/Dialect.h" -#include "triton/Dialect/TritonCPU/IR/Dialect.h" - -namespace mlir { -namespace triton { -#define GEN_PASS_DEF_GETPROGRAMIDOPTOLLVM -#include "cpu/include/TritonCPUToLLVM/Passes.h.inc" -} // namespace triton -} // namespace mlir - -using namespace mlir; -using namespace mlir::triton; -using namespace mlir::triton::cpu; - -namespace { - -class TritonLLVMConversionTarget : public ConversionTarget { -public: - explicit TritonLLVMConversionTarget(MLIRContext &ctx) - : ConversionTarget(ctx) { - addLegalDialect(); - addLegalOp(); - } -}; - -// TODO: use enums to access struct fields. -struct GetProgramIdOpConversion : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(GetProgramIdOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto funcOp = op->getParentOfType(); - assert(funcOp && "expected LLVM::FuncOp as a parent of GetProgramIdOp"); - auto args = funcOp.getArguments(); - // Last three args are x, y, z program ids. - auto argIdx = args.size() - 3 + op.getAxisAsInt(); - assert(argIdx < args.size() && "out-of-bounds arg index"); - assert(args[argIdx].getType().isInteger(32) && "unexpected arg type"); - rewriter.replaceOp(op, args[argIdx]); - return success(); - } -}; - -struct GetProgramIdOpToLLVM - : public triton::impl::GetProgramIdOpToLLVMBase { - using GetProgramIdOpToLLVMBase::GetProgramIdOpToLLVMBase; - - GetProgramIdOpToLLVM() : GetProgramIdOpToLLVMBase() {} - - void runOnOperation() override { - MLIRContext *context = &getContext(); - ModuleOp mod = getOperation(); - - mlir::LowerToLLVMOptions option(context); - TritonCPUToLLVMTypeConverter typeConverter(context, option); - TritonLLVMConversionTarget convTarget(*context); - - RewritePatternSet patterns(context); - patterns.add(typeConverter, context); - - if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) - return signalPassFailure(); - } -}; - -} // anonymous namespace - -namespace mlir { -namespace triton { -namespace cpu { - -std::unique_ptr> createGetProgramIdOpToLLVMPass() { - return std::make_unique(); -} - -} // namespace cpu -} // namespace triton -} // namespace mlir diff --git a/third_party/cpu/lib/TritonCPUToLLVM/MemoryOpToLLVM.cpp b/third_party/cpu/lib/TritonCPUToLLVM/MemoryOpToLLVM.cpp deleted file mode 100644 index 594495c4ab9d..000000000000 --- a/third_party/cpu/lib/TritonCPUToLLVM/MemoryOpToLLVM.cpp +++ /dev/null @@ -1,277 +0,0 @@ -#include "TypeConverter.h" - -#include "cpu/include/TritonCPUToLLVM/Passes.h" - -#include "mlir/Analysis/DataFlowFramework.h" -#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" -#include "mlir/Conversion/LLVMCommon/VectorPattern.h" -#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" -#include "mlir/Dialect/Index/IR/IndexDialect.h" -#include "mlir/Dialect/Index/IR/IndexOps.h" -#include "mlir/Dialect/LLVMIR/LLVMDialect.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Pass/Pass.h" - -#include "triton/Analysis/Allocation.h" -#include "triton/Analysis/AxisInfo.h" -#include "triton/Analysis/Membar.h" -#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" -#include "triton/Conversion/TritonGPUToLLVM/Utility.h" -#include "triton/Dialect/Triton/IR/Dialect.h" -#include "triton/Dialect/TritonCPU/IR/Dialect.h" - -namespace mlir { -namespace triton { -#define GEN_PASS_DEF_MEMORYOPTOLLVM -#include "cpu/include/TritonCPUToLLVM/Passes.h.inc" -} // namespace triton -} // namespace mlir - -using namespace mlir; -using namespace mlir::triton; -using namespace mlir::triton::cpu; - -namespace { - -class TritonLLVMConversionTarget : public ConversionTarget { -public: - explicit TritonLLVMConversionTarget(MLIRContext &ctx) - : ConversionTarget(ctx) { - addLegalDialect(); - addLegalOp(); - } -}; - -// TODO: use enums to access struct fields. -struct ExtractMemRefOpConversion : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(ExtractMemRefOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - Value tensorPtrStruct = rewriter.getRemappedValue(op.getSrc()); - auto memRefTy = cast(op.getType()); - auto rank = memRefTy.getRank(); - auto memRefStructTy = getTypeConverter()->convertType(op.getType()); - auto memRefStructFields = - cast(memRefStructTy).getBody(); - auto i64Ty = IntegerType::get(getContext(), 64); - - auto copyValue = [&](Value to, int64_t idxFrom, int64_t idxTo) { - auto valueTy = memRefStructFields[idxTo]; - Value val = rewriter.create( - loc, valueTy, tensorPtrStruct, idxFrom); - return rewriter.create(loc, memRefStructTy, to, val, - idxTo); - }; - - Value res = undef(memRefStructTy); - // Copy base. - res = copyValue(res, 0, 1); - // Use 0 offset. - res = rewriter.create(loc, memRefStructTy, res, - i64_val(0), 2); - // Copy shape. - res = copyValue(res, 2, 3); - // Copy strides. - res = copyValue(res, 3, 4); - - rewriter.replaceOp(op, res); - - return success(); - } -}; - -struct ExtractIndicesOpConversion - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(ExtractIndicesOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - - auto loc = op.getLoc(); - Value tensorPtrStruct = rewriter.getRemappedValue(op.getSrc()); - auto rank = op.getNumResults(); - auto i64Ty = IntegerType::get(getContext(), 64); - SmallVector indices; - - for (int64_t i = 0; i < rank; i++) { - Value offs = rewriter.create( - loc, i64Ty, tensorPtrStruct, SmallVector{1, i}); - Value stride = rewriter.create( - loc, i64Ty, tensorPtrStruct, SmallVector{3, i}); - indices.push_back(rewriter.create(loc, offs, stride)); - } - - rewriter.replaceOp(op, indices); - - return success(); - } -}; - -struct MakeTensorPtrOpConversion : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(MakeTensorPtrOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto structTy = getTypeConverter()->convertType(op.getType()); - auto i64Ty = IntegerType::get(getContext(), 64); - - auto insertArray = [&](Value structVal, auto values, int64_t idx, - Type zextTo = nullptr) { - for (int64_t i = 0; i < static_cast(values.size()); ++i) { - Value val = values[i]; - if (zextTo) - val = rewriter.create(loc, zextTo, val); - structVal = rewriter.create( - loc, structTy, structVal, val, SmallVector{idx, i}); - } - return structVal; - }; - - Value res = undef(structTy); - // 0 - base pointer. - auto base = rewriter.getRemappedValue(op.getBase()); - res = rewriter.create(loc, structTy, res, base, 0); - // 1 - array for offsets. Promote values to i64. - res = insertArray(res, op.getOffsets(), 1, i64Ty); - // 2 - array for shape. - res = insertArray(res, op.getShape(), 2); - // 3 - array for strides. - res = insertArray(res, op.getStrides(), 3); - - rewriter.replaceOp(op, res); - - return success(); - } -}; - -struct AdvanceOpConversion : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(AdvanceOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto i64Ty = IntegerType::get(getContext(), 64); - Value res = rewriter.getRemappedValue(op.getPtr()); - Type structTy = res.getType(); - auto offsets = op.getOffsets(); - - for (int64_t i = 0; i < offsets.size(); ++i) { - auto oldOffset = rewriter.create( - loc, i64Ty, res, SmallVector{1, i}); - auto step = rewriter.create(loc, i64Ty, offsets[i]); - auto newOffset = rewriter.create(loc, oldOffset, step); - res = rewriter.create(loc, structTy, res, newOffset, - SmallVector{1, i}); - } - - rewriter.replaceOp(op, res); - - return success(); - } -}; - -struct LoadOpConversion : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(triton::LoadOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - Type ptrTy = LLVM::LLVMPointerType::get(getContext()); - Value ptr = rewriter.getRemappedValue(op.getPtr()); - Type resTy = getTypeConverter()->convertType(op.getType()); - rewriter.replaceOpWithNewOp(op, resTy, ptr, 0, - op.getIsVolatile()); - return success(); - } -}; - -struct StoreOpConversion : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(triton::StoreOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - Value ptr = rewriter.getRemappedValue(op.getPtr()); - Value val = rewriter.getRemappedValue(op.getValue()); - rewriter.replaceOpWithNewOp(op, val, ptr); - return success(); - } -}; - -struct PtrToIntOpConversion : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(triton::PtrToIntOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value src = rewriter.getRemappedValue(op.getSrc()); - Type resTy = getTypeConverter()->convertType(op.getType()); - rewriter.replaceOpWithNewOp(op, resTy, src); - return success(); - } -}; - -struct IntToPtrOpConversion : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(triton::IntToPtrOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value src = rewriter.getRemappedValue(op.getSrc()); - Type resTy = getTypeConverter()->convertType(op.getType()); - rewriter.replaceOpWithNewOp(op, resTy, src); - return success(); - } -}; - -struct MemoryOpToLLVM - : public triton::impl::MemoryOpToLLVMBase { - using MemoryOpToLLVMBase::MemoryOpToLLVMBase; - - MemoryOpToLLVM() : MemoryOpToLLVMBase() {} - - void runOnOperation() override { - MLIRContext *context = &getContext(); - ModuleOp mod = getOperation(); - - mlir::LowerToLLVMOptions option(context); - TritonCPUToLLVMTypeConverter typeConverter(context, option); - TritonLLVMConversionTarget convTarget(*context); - - RewritePatternSet patterns(context); - patterns.add(typeConverter, context); - patterns.add(typeConverter, context); - patterns.add(typeConverter, context); - patterns.add(typeConverter, context); - patterns.add(typeConverter, context); - patterns.add(typeConverter, context); - patterns.add(typeConverter, context); - patterns.add(typeConverter, context); - - if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) - return signalPassFailure(); - } -}; - -} // anonymous namespace - -namespace mlir { -namespace triton { -namespace cpu { - -std::unique_ptr> createMemoryOpToLLVMPass() { - return std::make_unique(); -} - -} // namespace cpu -} // namespace triton -} // namespace mlir diff --git a/third_party/cpu/lib/TritonCPUToLLVM/Pipeline.cpp b/third_party/cpu/lib/TritonCPUToLLVM/Pipeline.cpp deleted file mode 100644 index 914f56e668f8..000000000000 --- a/third_party/cpu/lib/TritonCPUToLLVM/Pipeline.cpp +++ /dev/null @@ -1,25 +0,0 @@ -#include "cpu/include/TritonCPUToLLVM/Passes.h" - -#include "mlir/Conversion/Passes.h" -#include "mlir/Pass/PassManager.h" - -namespace mlir { -namespace triton { -namespace cpu { - -void tritonCPUToLLVMPipelineBuilder(OpPassManager &pm) { - pm.addPass(mlir::triton::cpu::createFuncOpToLLVMPass()); - pm.addPass(mlir::triton::cpu::createGetProgramIdOpToLLVMPass()); - pm.addPass(mlir::triton::cpu::createMemoryOpToLLVMPass()); - // pm.addPass(mlir::createReconcileUnrealizedCastsPass()); -} - -void registerTritonCPUToLLVMPipeline() { - PassPipelineRegistration<>("triton-cpu-to-llvmir", - "TritonCPU to LLVM conversion pipeline.", - tritonCPUToLLVMPipelineBuilder); -} - -} // namespace cpu -} // namespace triton -} // namespace mlir diff --git a/third_party/cpu/lib/TritonCPUToLLVM/TypeConverter.cpp b/third_party/cpu/lib/TritonCPUToLLVM/TypeConverter.cpp deleted file mode 100644 index 144cb57b1115..000000000000 --- a/third_party/cpu/lib/TritonCPUToLLVM/TypeConverter.cpp +++ /dev/null @@ -1,43 +0,0 @@ -#include "TypeConverter.h" - -using namespace mlir; -using namespace mlir::triton; - -TritonCPUToLLVMTypeConverter::TritonCPUToLLVMTypeConverter( - MLIRContext *ctx, LowerToLLVMOptions &option, - const DataLayoutAnalysis *analysis) - : LLVMTypeConverter(ctx, option, analysis) { - addConversion([&](triton::PointerType type) -> std::optional { - return convertTritonPointerType(type); - }); - addConversion([this](RankedTensorType tensorTy) -> std::optional { - if (isa(tensorTy.getElementType())) - return VectorType::get(tensorTy.getShape(), - IntegerType::get(tensorTy.getContext(), 64)); - return std::nullopt; - }); -} - -Type TritonCPUToLLVMTypeConverter::convertTritonPointerType( - triton::PointerType type) { - auto ctx = type.getContext(); - auto pointeeType = type.getPointeeType(); - if (isa(pointeeType)) { - // struct { - // ptr base_ptr; - // array offsets; - // array shape; - // array strides; - // } - auto tensorTy = cast(pointeeType); - auto rank = tensorTy.getShape().size(); - auto i64Ty = IntegerType::get(ctx, 64); - SmallVector types; - types.push_back(LLVM::LLVMPointerType::get(ctx)); - types.push_back(LLVM::LLVMArrayType::get(ctx, i64Ty, rank)); - types.push_back(LLVM::LLVMArrayType::get(ctx, i64Ty, rank)); - types.push_back(LLVM::LLVMArrayType::get(ctx, i64Ty, rank)); - return LLVM::LLVMStructType::getLiteral(ctx, types); - } - return LLVM::LLVMPointerType::get(ctx); -} diff --git a/third_party/cpu/lib/TritonCPUToLLVM/TypeConverter.h b/third_party/cpu/lib/TritonCPUToLLVM/TypeConverter.h deleted file mode 100644 index 35d74a9ec430..000000000000 --- a/third_party/cpu/lib/TritonCPUToLLVM/TypeConverter.h +++ /dev/null @@ -1,22 +0,0 @@ -#ifndef TRITON_CONVERSION_TRITONCPU_TO_LLVM_TYPECONVERTER_H -#define TRITON_CONVERSION_TRITONCPU_TO_LLVM_TYPECONVERTER_H - -#include "mlir/Conversion/LLVMCommon/TypeConverter.h" -#include "mlir/Dialect/LLVMIR/LLVMDialect.h" -#include "triton/Conversion/MLIRTypes.h" -#include "triton/Dialect/Triton/IR/Types.h" - -using namespace mlir; -using namespace mlir::triton; - -class TritonCPUToLLVMTypeConverter : public LLVMTypeConverter { -public: - using TypeConverter::convertType; - - TritonCPUToLLVMTypeConverter(MLIRContext *ctx, LowerToLLVMOptions &option, - const DataLayoutAnalysis *analysis = nullptr); - - Type convertTritonPointerType(triton::PointerType type); -}; - -#endif diff --git a/third_party/cpu/lib/TritonToTritonCPU/CMakeLists.txt b/third_party/cpu/lib/TritonToTritonCPU/CMakeLists.txt deleted file mode 100644 index 9fa892b449ac..000000000000 --- a/third_party/cpu/lib/TritonToTritonCPU/CMakeLists.txt +++ /dev/null @@ -1,15 +0,0 @@ -add_triton_library(TritonToTritonCPU - ConvertDotOp.cpp - ConvertElementwiseOps.cpp - ConvertMemoryOps.cpp - ConvertPtrOps.cpp - Pipeline.cpp - TypeConverter.cpp - - DEPENDS - TritonToTritonCPUPassIncGen - - LINK_LIBS PUBLIC - TritonCPUIR - MLIRVectorDialect -) diff --git a/third_party/cpu/lib/TritonToTritonCPU/ConvertDotOp.cpp b/third_party/cpu/lib/TritonToTritonCPU/ConvertDotOp.cpp deleted file mode 100644 index b6fbb1893202..000000000000 --- a/third_party/cpu/lib/TritonToTritonCPU/ConvertDotOp.cpp +++ /dev/null @@ -1,102 +0,0 @@ -#include "TypeConverter.h" - -#include "cpu/include/TritonToTritonCPU/Passes.h" - -#include "mlir/Analysis/DataFlowFramework.h" -#include "mlir/Dialect/Index/IR/IndexDialect.h" -#include "mlir/Dialect/Index/IR/IndexOps.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Dialect/Vector/IR/VectorOps.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" - -#include "triton/Analysis/Allocation.h" -#include "triton/Analysis/AxisInfo.h" -#include "triton/Analysis/Membar.h" -#include "triton/Conversion/TritonGPUToLLVM/Utility.h" -#include "triton/Dialect/Triton/IR/Dialect.h" -#include "triton/Dialect/TritonCPU/IR/Dialect.h" - -namespace mlir { -namespace triton { -#define GEN_PASS_DEF_CONVERTDOTOP -#include "cpu/include/TritonToTritonCPU/Passes.h.inc" -} // namespace triton -} // namespace mlir - -using namespace mlir; -using namespace mlir::triton; -using namespace mlir::triton::cpu; - -namespace { - -class PtrConversionTarget : public ConversionTarget { -public: - explicit PtrConversionTarget(MLIRContext &ctx, TypeConverter &converter) - : ConversionTarget(ctx) { - addLegalDialect(); - addLegalDialect(); - addLegalDialect(); - addLegalDialect(); - addLegalOp(); - - addIllegalOp(); - } -}; - -struct DotOpConversion : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(triton::DotOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - MLIRContext *ctx = op.getContext(); - Value a = rewriter.getRemappedValue(op.getA()); - Value b = rewriter.getRemappedValue(op.getB()); - Value c = rewriter.getRemappedValue(op.getC()); - auto aMap = AffineMap::getMultiDimMapWithTargets(3, {0, 2}, ctx); - auto bMap = AffineMap::getMultiDimMapWithTargets(3, {2, 1}, ctx); - auto cMap = AffineMap::getMultiDimMapWithTargets(3, {0, 1}, ctx); - auto iteratorTypes = rewriter.getArrayAttr( - {vector::IteratorTypeAttr::get(ctx, vector::IteratorType::parallel), - vector::IteratorTypeAttr::get(ctx, vector::IteratorType::parallel), - vector::IteratorTypeAttr::get(ctx, vector::IteratorType::reduction)}); - rewriter.replaceOpWithNewOp( - op, a, b, c, rewriter.getAffineMapArrayAttr({aMap, bMap, cMap}), - iteratorTypes); - return success(); - } -}; - -struct ConvertDotOp : public triton::impl::ConvertDotOpBase { - using ConvertDotOpBase::ConvertDotOpBase; - - ConvertDotOp() : ConvertDotOpBase() {} - - void runOnOperation() override { - MLIRContext *context = &getContext(); - ModuleOp mod = getOperation(); - - TritonToTritonCPUTypeConverter typeConverter; - PtrConversionTarget convTarget(*context, typeConverter); - RewritePatternSet patterns(context); - patterns.add(typeConverter, context); - - if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) - return signalPassFailure(); - } -}; - -} // namespace - -namespace mlir { -namespace triton { -namespace cpu { - -std::unique_ptr> createConvertDotOp() { - return std::make_unique(); -} - -} // namespace cpu -} // namespace triton -} // namespace mlir diff --git a/third_party/cpu/lib/TritonToTritonCPU/ConvertElementwiseOps.cpp b/third_party/cpu/lib/TritonToTritonCPU/ConvertElementwiseOps.cpp deleted file mode 100644 index 70e8c4ed3c66..000000000000 --- a/third_party/cpu/lib/TritonToTritonCPU/ConvertElementwiseOps.cpp +++ /dev/null @@ -1,300 +0,0 @@ -#include "TypeConverter.h" - -#include "cpu/include/TritonToTritonCPU/Passes.h" - -#include "mlir/Analysis/DataFlowFramework.h" -#include "mlir/Dialect/Index/IR/IndexDialect.h" -#include "mlir/Dialect/Index/IR/IndexOps.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Dialect/Vector/IR/VectorOps.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" - -#include "triton/Analysis/Allocation.h" -#include "triton/Analysis/AxisInfo.h" -#include "triton/Analysis/Membar.h" -#include "triton/Dialect/Triton/IR/Dialect.h" -#include "triton/Dialect/TritonCPU/IR/Dialect.h" - -namespace mlir { -namespace triton { -#define GEN_PASS_DEF_CONVERTELEMENTWISEOPS -#include "cpu/include/TritonToTritonCPU/Passes.h.inc" -} // namespace triton -} // namespace mlir - -using namespace mlir; -using namespace mlir::triton; -using namespace mlir::triton::cpu; - -namespace { - -class ElementwiseOpConversionTarget : public ConversionTarget { -public: - explicit ElementwiseOpConversionTarget(MLIRContext &ctx, - TypeConverter &converter) - : ConversionTarget(ctx) { - addLegalDialect(); - addLegalDialect(); - addLegalDialect(); - addLegalOp(); - - addDynamicallyLegalDialect( - [&](Operation *op) -> std::optional { - return converter.isLegal(op); - }); - addDynamicallyLegalDialect( - [&](Operation *op) -> std::optional { - return converter.isLegal(op); - }); - - addIllegalOp(); - addIllegalOp(); - addIllegalOp(); - addIllegalOp(); - addIllegalOp(); - addIllegalOp(); - } -}; - -template -struct ElementwiseOpConversion : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - using OpConversionPattern::getTypeConverter; - using typename OpConversionPattern::OpAdaptor; - - LogicalResult - matchAndRewrite(OpT op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - OperationState newState(op.getLoc(), ResOpT::getOperationName()); - // Convert operands. - for (auto operand : op->getOperands()) { - Value newOperand = rewriter.getRemappedValue(operand); - newState.operands.push_back(newOperand); - } - // Convert result types. - if (failed(getTypeConverter()->convertTypes(op->getResultTypes(), - newState.types))) { - return failure(); - } - newState.attributes = op->getAttrs(); - - auto newOp = rewriter.create(newState); - rewriter.replaceOp(op, newOp); - - return success(); - } -}; - -template <> -struct ElementwiseOpConversion - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - assert(isa(op.getType())); - auto resTy = - dyn_cast(getTypeConverter()->convertType(op.getType())); - assert(resTy); - if (auto denseAttr = dyn_cast(op.getValueAttr())) { - rewriter.replaceOpWithNewOp(op, resTy, - denseAttr.reshape(resTy)); - } else { - llvm_unreachable("Unexpected constant attribute"); - } - return success(); - } -}; - -template <> -struct ElementwiseOpConversion - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(triton::ReshapeOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - assert(isa(op.getType())); - auto loc = op.getLoc(); - auto src = rewriter.getRemappedValue(op.getSrc()); - auto srcShape = dyn_cast(src.getType()).getShape(); - auto resTy = - dyn_cast(getTypeConverter()->convertType(op.getType())); - auto dstShape = resTy.getShape(); - auto elemTy = resTy.getElementType(); - - // There are restrictions on how shape can be modified by ShapeCastOp - // when rank is changed. For now, we simply detect it and handle through - // a cast to 1D vector. Better solution may be required later. - if (canCastShape(srcShape, dstShape)) { - rewriter.replaceOpWithNewOp( - op, VectorType::get(dstShape, elemTy), src); - } else { - SmallVector tmpShape({resTy.getNumElements()}); - auto tmp = rewriter.create( - loc, VectorType::get(tmpShape, elemTy), src); - rewriter.replaceOpWithNewOp( - op, VectorType::get(dstShape, elemTy), tmp); - } - return success(); - } - -private: - bool canCastShape(ArrayRef src, ArrayRef dst) const { - if (src.size() == dst.size()) - return true; - if (src.size() > dst.size()) - return canCastShape(dst, src); - - size_t srcIdx = 0; - size_t dstIdx = 0; - while (srcIdx < src.size() && dstIdx < dst.size()) { - if (src[srcIdx] == 1) { - ++srcIdx; - } else { - // Source dim size should be a product of continuous dest dim sizes. - int64_t srcSize = src[srcIdx++]; - int64_t dstSize = dst[dstIdx++]; - while (dstSize < srcSize && dstIdx < dst.size()) - dstSize *= dst[dstIdx++]; - if (dstSize != srcSize) - return false; - } - } - - // Skip trailing 1s. - while (srcIdx < src.size() && src[srcIdx] == 1) - ++srcIdx; - while (dstIdx < dst.size() && dst[dstIdx] == 1) - ++dstIdx; - - return srcIdx == src.size() && dstIdx == dst.size(); - } -}; - -struct ConvertElementwiseOps - : public triton::impl::ConvertElementwiseOpsBase { - using ConvertElementwiseOpsBase::ConvertElementwiseOpsBase; - - ConvertElementwiseOps() : ConvertElementwiseOpsBase() {} - - void runOnOperation() override { - MLIRContext *context = &getContext(); - ModuleOp mod = getOperation(); - - TritonToTritonCPUTypeConverter typeConverter; - ElementwiseOpConversionTarget convTarget(*context, typeConverter); - RewritePatternSet patterns(context); - - patterns.add>(typeConverter, - context); - patterns.add>(typeConverter, - context); - patterns.add>(typeConverter, - context); - patterns.add>(typeConverter, - context); - patterns.add>(typeConverter, - context); - patterns.add>(typeConverter, - context); - patterns.add>(typeConverter, - context); - patterns.add>(typeConverter, - context); - patterns.add>(typeConverter, - context); - patterns.add>(typeConverter, - context); - patterns.add>(typeConverter, - context); - patterns.add>(typeConverter, - context); - patterns.add>(typeConverter, - context); - patterns.add>(typeConverter, - context); - patterns.add>(typeConverter, - context); - patterns.add>(typeConverter, - context); - patterns.add>(typeConverter, - context); - patterns.add>(typeConverter, - context); - patterns.add>(typeConverter, - context); - patterns.add>(typeConverter, - context); - patterns.add>(typeConverter, - context); - patterns.add>(typeConverter, - context); - patterns.add>(typeConverter, - context); - patterns.add>(typeConverter, context); - patterns.add>(typeConverter, - context); - patterns.add>(typeConverter, - context); - patterns.add>(typeConverter, - context); - patterns.add>(typeConverter, - context); - patterns.add>(typeConverter, - context); - patterns.add>(typeConverter, - context); - patterns.add>(typeConverter, - context); - - patterns.add>(typeConverter, - context); - patterns.add>(typeConverter, context); - patterns.add>(typeConverter, context); - patterns.add>(typeConverter, context); - patterns.add>(typeConverter, context); - patterns.add>(typeConverter, context); - patterns.add>(typeConverter, context); - patterns.add>(typeConverter, context); - patterns.add>(typeConverter, context); - patterns.add>(typeConverter, context); - patterns.add>(typeConverter, context); - patterns.add>(typeConverter, context); - patterns.add>(typeConverter, context); - - patterns.add>( - typeConverter, context); - patterns - .add>( - typeConverter, context); - patterns.add< - ElementwiseOpConversion>( - typeConverter, context); - patterns.add>( - typeConverter, context); - patterns.add>( - typeConverter, context); - patterns.add>(typeConverter, - context); - - if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) - return signalPassFailure(); - } -}; - -} // namespace - -namespace mlir { -namespace triton { -namespace cpu { - -std::unique_ptr> createConvertElementwiseOps() { - return std::make_unique(); -} - -} // namespace cpu -} // namespace triton -} // namespace mlir diff --git a/third_party/cpu/lib/TritonToTritonCPU/ConvertMemoryOps.cpp b/third_party/cpu/lib/TritonToTritonCPU/ConvertMemoryOps.cpp deleted file mode 100644 index 1679ecc7af90..000000000000 --- a/third_party/cpu/lib/TritonToTritonCPU/ConvertMemoryOps.cpp +++ /dev/null @@ -1,277 +0,0 @@ -#include "TypeConverter.h" - -#include "cpu/include/TritonToTritonCPU/Passes.h" - -#include "mlir/Analysis/DataFlowFramework.h" -#include "mlir/Conversion/LLVMCommon/TypeConverter.h" -#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" -#include "mlir/Dialect/Index/IR/IndexDialect.h" -#include "mlir/Dialect/Index/IR/IndexOps.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Dialect/Utils/IndexingUtils.h" -#include "mlir/Dialect/Vector/IR/VectorOps.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" - -#include "triton/Analysis/Allocation.h" -#include "triton/Analysis/AxisInfo.h" -#include "triton/Analysis/Membar.h" -#include "triton/Dialect/Triton/IR/Dialect.h" -#include "triton/Dialect/TritonCPU/IR/Dialect.h" - -namespace mlir { -namespace triton { -#define GEN_PASS_DEF_CONVERTMEMORYOPS -#include "cpu/include/TritonToTritonCPU/Passes.h.inc" -} // namespace triton -} // namespace mlir - -using namespace mlir; -using namespace mlir::triton; -using namespace mlir::triton::cpu; - -namespace { - -struct LoadOpConversion : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(triton::LoadOp loadOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = loadOp.getLoc(); - auto mask = loadOp.getMask(); - auto ptr = loadOp.getPtr(); - auto boundaryChecks = loadOp.getBoundaryCheck(); - - if (!triton::isTensorPointerType(ptr.getType())) { - return lowerToScalarLoads(loadOp, rewriter); - } - - // TODO: support masks. - if (mask) { - llvm_unreachable("unsupported load op"); - } - - auto memRef = rewriter.getRemappedValue(ptr); - auto rank = dyn_cast(memRef.getType()).getRank(); - auto resTy = dyn_cast( - getTypeConverter()->convertType(loadOp.getResult().getType())); - auto indices = rewriter.create(loc, ptr).getResults(); - SmallVector inBounds(rank, true); - for (auto dim : boundaryChecks) { - inBounds[dim] = false; - } - auto vecRead = rewriter.create(loc, resTy, memRef, - indices, inBounds); - rewriter.replaceOp(loadOp, vecRead); - return success(); - } - - LogicalResult lowerToScalarLoads(triton::LoadOp loadOp, - ConversionPatternRewriter &rewriter) const { - // Scalar loads and boundary checks are not expected. - assert(loadOp.getBoundaryCheck().empty()); - assert(isa(loadOp.getType())); - - auto loc = loadOp.getLoc(); - auto vecTy = - dyn_cast(getTypeConverter()->convertType(loadOp.getType())); - auto ptrs = rewriter.getRemappedValue(loadOp.getPtr()); - auto mask = loadOp.getMask() ? rewriter.getRemappedValue(loadOp.getMask()) - : nullptr; - auto ptrTy = - dyn_cast(loadOp.getPtr().getType()).getElementType(); - auto cache = loadOp.getCache(); - auto evict = loadOp.getEvict(); - auto isVolatile = loadOp.getIsVolatile(); - - Value defaultVal = loadOp.getOther(); - if (!defaultVal) - defaultVal = rewriter.create( - loc, rewriter.getZeroAttr(vecTy.getElementType())); - Value dst = rewriter.create(loc, vecTy, defaultVal); - - int64_t numElems = vecTy.getNumElements(); - auto strides = computeStrides(vecTy.getShape()); - for (auto idx = 0; idx < numElems; ++idx) { - auto indices = delinearize(idx, strides); - Block *headerBlock = rewriter.getBlock(); - Block *condBlock = nullptr; - Value origDst = dst; - // Create a conditional block for load if there is a mask. - if (mask) { - condBlock = - rewriter.splitBlock(headerBlock, rewriter.getInsertionPoint()); - rewriter.setInsertionPointToStart(condBlock); - } - - Value ptr = rewriter.create(loc, ptrs, indices); - ptr = rewriter.create(loc, ptrTy, ptr); - Value val = - rewriter.create(loc, ptr, cache, evict, isVolatile); - dst = rewriter.create(loc, val, dst, indices); - - // Add predicate and branches. - if (mask) { - Block *footerBlock = - rewriter.splitBlock(condBlock, rewriter.getInsertionPoint()); - Value resDst = dst; - dst = footerBlock->addArgument(dst.getType(), dst.getLoc()); - rewriter.setInsertionPointToEnd(headerBlock); - auto predicate = rewriter.create(loc, mask, indices); - rewriter.create(loc, predicate, condBlock, - footerBlock, origDst); - rewriter.setInsertionPointToEnd(condBlock); - rewriter.create(loc, footerBlock, resDst); - rewriter.setInsertionPointToStart(footerBlock); - } - } - - rewriter.replaceOp(loadOp, dst); - - return success(); - } -}; - -struct StoreOpConversion : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(triton::StoreOp storeOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = storeOp.getLoc(); - auto mask = storeOp.getMask(); - auto ptr = storeOp.getPtr(); - auto boundaryChecks = storeOp.getBoundaryCheck(); - - if (!triton::isTensorPointerType(ptr.getType())) { - return lowerToScalarStores(storeOp, rewriter); - } - - // TODO: support masks. - if (mask) { - llvm_unreachable("unsupported store op"); - } - - auto value = rewriter.getRemappedValue(storeOp.getValue()); - auto memRef = rewriter.getRemappedValue(ptr); - auto rank = dyn_cast(memRef.getType()).getRank(); - auto indices = rewriter.create(loc, ptr).getResults(); - SmallVector inBounds(rank, true); - for (auto dim : boundaryChecks) { - inBounds[dim] = false; - } - auto vecWrite = rewriter.create(loc, value, memRef, - indices, inBounds); - rewriter.replaceOp(storeOp, vecWrite); - return success(); - } - - LogicalResult lowerToScalarStores(triton::StoreOp storeOp, - ConversionPatternRewriter &rewriter) const { - // Scalar stores and boundary checks are not expected. - assert(storeOp.getBoundaryCheck().empty()); - assert(isa(storeOp.getValue().getType())); - - auto loc = storeOp.getLoc(); - auto ptrs = rewriter.getRemappedValue(storeOp.getPtr()); - auto mask = storeOp.getMask() ? rewriter.getRemappedValue(storeOp.getMask()) - : nullptr; - auto vals = rewriter.getRemappedValue(storeOp.getValue()); - auto tensorTy = dyn_cast(storeOp.getPtr().getType()); - auto ptrTy = tensorTy.getElementType(); - auto cache = storeOp.getCache(); - auto evict = storeOp.getEvict(); - - int64_t numElems = tensorTy.getNumElements(); - auto strides = computeStrides(tensorTy.getShape()); - for (auto idx = 0; idx < numElems; ++idx) { - auto indices = delinearize(idx, strides); - Block *headerBlock = rewriter.getBlock(); - Block *condBlock = nullptr; - // Create a conditional block for store if there is a mask. - if (mask) { - condBlock = - rewriter.splitBlock(headerBlock, rewriter.getInsertionPoint()); - rewriter.setInsertionPointToStart(condBlock); - } - - Value ptr = rewriter.create(loc, ptrs, indices); - ptr = rewriter.create(loc, ptrTy, ptr); - Value val = rewriter.create(loc, vals, indices); - rewriter.create(loc, ptr, val, cache, evict); - - // Add predicate and branches. - if (mask) { - Block *footerBlock = - rewriter.splitBlock(condBlock, rewriter.getInsertionPoint()); - rewriter.setInsertionPointToEnd(headerBlock); - auto predicate = rewriter.create(loc, mask, indices); - rewriter.create(loc, predicate, condBlock, - footerBlock); - rewriter.setInsertionPointToEnd(condBlock); - rewriter.create(loc, footerBlock); - rewriter.setInsertionPointToStart(footerBlock); - } - } - - rewriter.eraseOp(storeOp); - - return success(); - } -}; - -class MemoryOpConversionTarget : public ConversionTarget { -public: - explicit MemoryOpConversionTarget(MLIRContext &ctx) : ConversionTarget(ctx) { - addLegalDialect(); - addLegalDialect(); - addLegalDialect(); - addLegalDialect(); - addLegalDialect(); - addLegalOp(); - - // Allow only scalar loads and stores. - addDynamicallyLegalOp([](triton::LoadOp loadOp) { - return loadOp.getType().isIntOrIndexOrFloat(); - }); - addDynamicallyLegalOp([](triton::StoreOp storeOp) { - return storeOp.getValue().getType().isIntOrIndexOrFloat(); - }); - } -}; - -struct ConvertMemoryOps - : public triton::impl::ConvertMemoryOpsBase { - using ConvertMemoryOpsBase::ConvertMemoryOpsBase; - - ConvertMemoryOps() : ConvertMemoryOpsBase() {} - - void runOnOperation() override { - MLIRContext *context = &getContext(); - ModuleOp mod = getOperation(); - - MemoryOpConversionTarget convTarget(*context); - TritonToTritonCPUTypeConverter pointerConverter; - RewritePatternSet patterns(context); - patterns.add(pointerConverter, context); - patterns.add(pointerConverter, context); - - if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) - return signalPassFailure(); - } -}; - -} // anonymous namespace - -namespace mlir { -namespace triton { -namespace cpu { - -std::unique_ptr> createConvertMemoryOps() { - return std::make_unique(); -} - -} // namespace cpu -} // namespace triton -} // namespace mlir diff --git a/third_party/cpu/lib/TritonToTritonCPU/ConvertPtrOps.cpp b/third_party/cpu/lib/TritonToTritonCPU/ConvertPtrOps.cpp deleted file mode 100644 index ade8b858bbfb..000000000000 --- a/third_party/cpu/lib/TritonToTritonCPU/ConvertPtrOps.cpp +++ /dev/null @@ -1,191 +0,0 @@ -#include "TypeConverter.h" - -#include "cpu/include/TritonToTritonCPU/Passes.h" - -#include "mlir/Analysis/DataFlowFramework.h" -#include "mlir/Dialect/Index/IR/IndexDialect.h" -#include "mlir/Dialect/Index/IR/IndexOps.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Dialect/Vector/IR/VectorOps.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" - -#include "triton/Analysis/Allocation.h" -#include "triton/Analysis/AxisInfo.h" -#include "triton/Analysis/Membar.h" -#include "triton/Conversion/TritonGPUToLLVM/Utility.h" -#include "triton/Dialect/Triton/IR/Dialect.h" -#include "triton/Dialect/TritonCPU/IR/Dialect.h" - -namespace mlir { -namespace triton { -#define GEN_PASS_DEF_CONVERTPTROPS -#include "cpu/include/TritonToTritonCPU/Passes.h.inc" -} // namespace triton -} // namespace mlir - -using namespace mlir; -using namespace mlir::triton; -using namespace mlir::triton::cpu; - -namespace { - -unsigned getElemBitWidth(Type type) { - if (auto tensorTy = dyn_cast(type)) - return tensorTy.getElementType().getIntOrFloatBitWidth(); - if (auto vectorTy = dyn_cast(type)) - return vectorTy.getElementType().getIntOrFloatBitWidth(); - return type.getIntOrFloatBitWidth(); -} - -class PtrConversionTarget : public ConversionTarget { -public: - explicit PtrConversionTarget(MLIRContext &ctx, TypeConverter &converter) - : ConversionTarget(ctx) { - addLegalDialect(); - addLegalDialect(); - addLegalDialect(); - addLegalOp(); - - // Allow only scalar pointer conversion. - addDynamicallyLegalOp( - [](triton::PtrToIntOp op) { return op.getType().isInteger(); }); - addDynamicallyLegalOp([](triton::IntToPtrOp op) { - return op.getSrc().getType().isInteger(); - }); - } -}; - -struct MakeRangeOpConversion : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(triton::MakeRangeOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - int32_t start = static_cast(op.getStart()); - int32_t end = static_cast(op.getEnd()); - assert(end >= start); - - llvm::SmallVector values; - values.reserve(end - start); - for (int32_t v = start; v < end; ++v) { - values.push_back(v); - } - - Type resTy = getTypeConverter()->convertType(op.getType()); - auto newOp = rewriter.create( - op.getLoc(), resTy, rewriter.getI32VectorAttr(values)); - - rewriter.replaceOp(op, newOp); - return success(); - } -}; - -struct SplatOpConversion : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(triton::SplatOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - Value val = op.getSrc(); - Type dstValType = getTypeConverter()->convertType(val.getType()); - // Cast pointer - if (isa(val.getType())) - val = rewriter - .create( - loc, getTypeConverter()->convertType(val.getType()), val) - .getResult(); - Type resType = getTypeConverter()->convertType(op.getType()); - auto cast = rewriter.create(loc, resType, val); - - rewriter.replaceOp(op, cast); - return success(); - } -}; - -struct AddPtrOpConversion : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(triton::AddPtrOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - Value ptr = rewriter.getRemappedValue(op.getPtr()); - Value offset = rewriter.getRemappedValue(op.getOffset()); - unsigned offsetBitWidth = getElemBitWidth(offset.getType()); - unsigned elemBitWidth = getPointeeBitWidth(op.getPtr().getType()); - // Compute scale. i1 elements take 1 byte. - Value scale = rewriter.create( - loc, (elemBitWidth + 7) / 8, offsetBitWidth); - if (isa(offset.getType())) - scale = rewriter.create(loc, offset.getType(), scale); - offset = rewriter.create(loc, offset, scale); - offset = rewriter.create(loc, ptr.getType(), offset); - rewriter.replaceOpWithNewOp(op, ptr.getType(), ptr, offset); - return success(); - } -}; - -struct PtrToIntOpConversion : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(triton::PtrToIntOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value val = rewriter.getRemappedValue(op.getSrc()); - auto resTy = getTypeConverter()->convertType(op.getType()); - rewriter.replaceOpWithNewOp(op, resTy, val); - return success(); - } -}; - -struct IntToPtrOpConversion : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(triton::IntToPtrOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value val = rewriter.getRemappedValue(op.getSrc()); - auto resTy = getTypeConverter()->convertType(op.getType()); - rewriter.replaceOpWithNewOp(op, resTy, val); - return success(); - } -}; - -struct ConvertPtrOps : public triton::impl::ConvertPtrOpsBase { - using ConvertPtrOpsBase::ConvertPtrOpsBase; - - ConvertPtrOps() : ConvertPtrOpsBase() {} - - void runOnOperation() override { - MLIRContext *context = &getContext(); - ModuleOp mod = getOperation(); - - TritonToTritonCPUTypeConverter typeConverter; - PtrConversionTarget convTarget(*context, typeConverter); - RewritePatternSet patterns(context); - patterns.add(typeConverter, context); - patterns.add(typeConverter, context); - patterns.add(typeConverter, context); - patterns.add(typeConverter, context); - patterns.add(typeConverter, context); - - if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) - return signalPassFailure(); - } -}; - -} // namespace - -namespace mlir { -namespace triton { -namespace cpu { - -std::unique_ptr> createConvertPtrOps() { - return std::make_unique(); -} - -} // namespace cpu -} // namespace triton -} // namespace mlir diff --git a/third_party/cpu/lib/TritonToTritonCPU/Pipeline.cpp b/third_party/cpu/lib/TritonToTritonCPU/Pipeline.cpp deleted file mode 100644 index 16bff114ed81..000000000000 --- a/third_party/cpu/lib/TritonToTritonCPU/Pipeline.cpp +++ /dev/null @@ -1,26 +0,0 @@ -#include "cpu/include/TritonToTritonCPU/Passes.h" - -#include "mlir/Conversion/Passes.h" -#include "mlir/Pass/PassManager.h" - -namespace mlir { -namespace triton { -namespace cpu { - -void tritonToTritonCPUPipelineBuilder(OpPassManager &pm) { - pm.addPass(mlir::triton::cpu::createConvertMemoryOps()); - pm.addPass(mlir::triton::cpu::createConvertPtrOps()); - pm.addPass(mlir::triton::cpu::createConvertElementwiseOps()); - pm.addPass(mlir::triton::cpu::createConvertDotOp()); - // pm.addPass(mlir::createReconcileUnrealizedCastsPass()); -} - -void registerTritonToTritonCPUPipeline() { - PassPipelineRegistration<>("triton-to-triton-cpu", - "Triton to TritonCPU conversion pipeline.", - tritonToTritonCPUPipelineBuilder); -} - -} // namespace cpu -} // namespace triton -} // namespace mlir diff --git a/third_party/cpu/lib/TritonToTritonCPU/TypeConverter.cpp b/third_party/cpu/lib/TritonToTritonCPU/TypeConverter.cpp deleted file mode 100644 index 07b2da0468ba..000000000000 --- a/third_party/cpu/lib/TritonToTritonCPU/TypeConverter.cpp +++ /dev/null @@ -1,51 +0,0 @@ -#include "TypeConverter.h" - -#include "triton/Dialect/TritonCPU/IR/Dialect.h" - -using namespace mlir; -using namespace mlir::triton; -using namespace mlir::triton::cpu; - -TritonToTritonCPUTypeConverter::TritonToTritonCPUTypeConverter() { - addConversion([](Type type) { return type; }); - addConversion([](triton::PointerType ptrTy) -> Type { - if (triton::isTensorPointerType(ptrTy)) { - // Tensor pointer is translated into a memref - auto tensorTy = dyn_cast(ptrTy.getPointeeType()); - auto elemTy = tensorTy.getElementType(); - // TODO: use dynamic strides - SmallVector shape(tensorTy.getRank(), ShapedType::kDynamic); - return MemRefType::get(shape, elemTy); - } - return IntegerType::get(ptrTy.getContext(), 64); - }); - addConversion([this](RankedTensorType tensorTy) -> Type { - Type elemTy = convertType(tensorTy.getElementType()); - return VectorType::get(tensorTy.getShape(), elemTy); - }); - - // Converted ops produce vectors instead of tensors. Provide conversion - // here for users. Also, convert pointers when required. - addSourceMaterialization([&](OpBuilder &builder, Type type, ValueRange inputs, - Location loc) -> std::optional { - if (isa(type)) - return builder.create(loc, type, inputs); - return builder.create(loc, type, inputs) - .getResult(0); - }); - - // Converted loads and stores consume memrefs instead of pointers, use extract - // op to get them. Also, provide conversion for vector users and pointer - // casts. - addTargetMaterialization([&](OpBuilder &builder, Type type, ValueRange inputs, - Location loc) -> std::optional { - if (type.isInteger() && isa(inputs.front().getType())) - return builder.create(loc, type, inputs); - if (isa(type)) - return builder.create(loc, type, inputs) - .getResult(0); - if (isa(type)) - return builder.create(loc, type, inputs); - llvm_unreachable("Unexpected target materizalization"); - }); -} diff --git a/third_party/cpu/lib/TritonToTritonCPU/TypeConverter.h b/third_party/cpu/lib/TritonToTritonCPU/TypeConverter.h deleted file mode 100644 index cb89f0886c60..000000000000 --- a/third_party/cpu/lib/TritonToTritonCPU/TypeConverter.h +++ /dev/null @@ -1,19 +0,0 @@ -#ifndef TRITON_CONVERSION_TRITON_TO_TRITONCPU_TYPECONVERTER_H -#define TRITON_CONVERSION_TRITON_TO_TRITONCPU_TYPECONVERTER_H - -#include "triton/Conversion/MLIRTypes.h" -#include "triton/Dialect/Triton/IR/Types.h" - -using namespace mlir; -using namespace mlir::triton; - -class TritonToTritonCPUTypeConverter : public TypeConverter { -public: - using TypeConverter::convertType; - - TritonToTritonCPUTypeConverter(); - - Type convertTritonPointerType(triton::PointerType type); -}; - -#endif diff --git a/third_party/cpu/triton_cpu.cc b/third_party/cpu/triton_cpu.cc index efc949d6f4a1..302951d04d59 100644 --- a/third_party/cpu/triton_cpu.cc +++ b/third_party/cpu/triton_cpu.cc @@ -1,20 +1,9 @@ -#include "TritonCPUToLLVM/Passes.h" -#include "TritonToTritonCPU/Passes.h" - -#include "triton/Dialect/TritonCPU/IR/Dialect.h" - -#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h" -#include "mlir/Conversion/Passes.h" -#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.h" -#include "mlir/Dialect/Vector/IR/VectorOps.h" -#include "mlir/Dialect/Vector/Transforms/Passes.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" #include "triton/Conversion/TritonCPUToLLVM/Passes.h" #include "triton/Dialect/TritonCPU/IR/Dialect.h" #include "llvm/IR/Constants.h" #include "llvm/Support/TargetSelect.h" - #include #include #include @@ -25,26 +14,8 @@ namespace py = pybind11; void init_triton_cpu_passes_ttcpuir(py::module &&m) { using namespace mlir::triton; - // m.def("add_to_llvmir", [](mlir::PassManager &pm) { - // pm.addPass(mlir::triton::createConvertTritonCPUToLLVMPass()); - // }); - m.def("add_triton_to_triton_cpu_pipeline", [](mlir::PassManager &pm) { - mlir::triton::cpu::tritonToTritonCPUPipelineBuilder(pm); - }); - m.def("add_triton_cpu_to_llvmir_pipeline", [](mlir::PassManager &pm) { - mlir::triton::cpu::tritonCPUToLLVMPipelineBuilder(pm); - }); - m.def("add_vector_to_llvmir", [](mlir::PassManager &pm) { - pm.addPass(mlir::createConvertVectorToLLVMPass()); - }); - m.def("add_memref_to_llvmir", [](mlir::PassManager &pm) { - pm.addPass(mlir::createFinalizeMemRefToLLVMConversionPass()); - }); - m.def("add_math_to_libm", [](mlir::PassManager &pm) { - pm.addPass(mlir::createConvertMathToLibmPass()); - }); - m.def("add_func_to_llvmir", [](mlir::PassManager &pm) { - pm.addPass(mlir::createConvertFuncToLLVMPass()); + m.def("add_to_llvmir", [](mlir::PassManager &pm) { + pm.addPass(mlir::triton::createConvertTritonCPUToLLVMPass()); }); } @@ -54,18 +25,8 @@ void init_triton_cpu(py::module &&m) { m.def("load_dialects", [](mlir::MLIRContext &context) { mlir::DialectRegistry registry; - registry.insert(); + registry.insert(); context.appendDialectRegistry(registry); context.loadAllAvailableDialects(); }); - - m.def("find_kernel_names", [](mlir::ModuleOp &mod) { - std::vector res; - mod.walk([&](mlir::FunctionOpInterface funcOp) { - if (funcOp.getVisibility() == mlir::SymbolTable::Visibility::Public) - res.push_back(funcOp.getName().str()); - }); - return res; - }); }