Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce triton-arith-to-linalg pass #85

Merged
1 change: 1 addition & 0 deletions include/triton-shared/Conversion/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
add_subdirectory(TritonToLinalg)
add_subdirectory(TritonToStructured)
add_subdirectory(TritonArithToLinalg)
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls --name TritonArithToLinalg)
add_public_tablegen_target(TritonArithToLinalgConversionPassIncGen)
15 changes: 15 additions & 0 deletions include/triton-shared/Conversion/TritonArithToLinalg/Passes.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
#ifndef TRITON_ARITH_TO_LINALG_CONVERSION_PASSES_H
#define TRITON_ARITH_TO_LINALG_CONVERSION_PASSES_H

#include "triton-shared/Conversion/TritonArithToLinalg/TritonArithToLinalg.h"

namespace mlir {
namespace triton {

#define GEN_PASS_REGISTRATION
#include "triton-shared/Conversion/TritonArithToLinalg/Passes.h.inc"

} // namespace triton
} // namespace mlir

#endif
20 changes: 20 additions & 0 deletions include/triton-shared/Conversion/TritonArithToLinalg/Passes.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
#ifndef TRITON_ARITH_TO_LINALG_CONVERSION_PASSES
#define TRITON_ARITH_TO_LINALG_CONVERSION_PASSES

include "mlir/Pass/PassBase.td"

def TritonArithToLinalg : Pass<"triton-arith-to-linalg", "mlir::ModuleOp"> {
let summary = "Convert Triton arithmetic operations into linalg";
let options = [
Option<"pidsToFuncArgs", "pids-to-func-args", "bool", /*default*/"true",
"Convert tt.get_program_id and tt.get_num_programs to reference to function arguments">,
Option<"ttToFuncFunc", "tt-to-func-func", "bool", /*default*/"true",
"Convert tt.func to func.func">,
Option<"addptrToLinalg", "addptr-to-linalg", "bool", /*default*/"true",
"Convert tt.addptr on tensors to linalg">,
Option<"assertToCf", "assert-to-cf", "bool", /*default*/"true",
"Convert tt.assert to cf.assert">,
];
}

#endif
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
#ifndef TRITON_CONVERSION_TRITONARITHTOLINALG_TRITONARITHTOLINALG_H
#define TRITON_CONVERSION_TRITONARITHTOLINALG_TRITONARITHTOLINALG_H

#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"

#include "triton/Dialect/Triton/IR/Dialect.h"

namespace mlir {
namespace triton {

#define GEN_PASS_DECL
#include "triton-shared/Conversion/TritonArithToLinalg/Passes.h.inc"

void populateTritonArithToLinalgCanonicalizationPatterns(
RewritePatternSet &patterns);

void populateTritonArithToLinalgConversionPatterns(bool pidsToFuncArgs,
bool addptrToLinalg,
bool assertToCf,
RewritePatternSet &patterns);

} // namespace triton
} // namespace mlir

#endif // TRITON_CONVERSION_TRITONARITHTOLINALG_TRITONARITHTOLINALG_H
1 change: 1 addition & 0 deletions lib/Conversion/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
add_subdirectory(TritonToLinalg)
add_subdirectory(TritonToStructured)
add_subdirectory(TritonArithToLinalg)
24 changes: 24 additions & 0 deletions lib/Conversion/TritonArithToLinalg/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
add_mlir_conversion_library(TritonArithToLinalg
TritonArithToLinalg.cpp
TritonArithToLinalgPass.cpp

DEPENDS
TritonArithToLinalgConversionPassIncGen

LINK_COMPONENTS
Core

LINK_LIBS PUBLIC
MLIRArithDialect
MLIRDialectUtils
MLIRIR
MLIRMathDialect
MLIRPass
MLIRTensorDialect
MLIRTransforms
MLIRSupport
TritonIR
TritonTransforms
TritonTilingExtIR
TritonStructuredIR
)
Loading