Skip to content

Commit

Permalink
[WIP] support remsi canonicalizer
Browse files Browse the repository at this point in the history
  • Loading branch information
yuanfz98 committed Nov 22, 2023
1 parent de797bb commit 257ed48
Show file tree
Hide file tree
Showing 9 changed files with 709 additions and 0 deletions.
104 changes: 104 additions & 0 deletions include/triton-shared/Analysis/MultiBufferAnalysis.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
//===----------------------------------------------------------------------===//
//
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
//
//===----------------------------------------------------------------------===//

#ifndef TRITON_ANALYSIS_MASKANALYSIS_H
#define TRITON_ANALYSIS_MASKANALYSIS_H

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"

#include "triton/Dialect/Triton/IR/Dialect.h"

#include <utility>

namespace mlir {

class PatternRewriter;

namespace triton {

struct MultiBufferState {
SmallVector<int64_t> values;
int64_t scalar = 0;

int64_t getShape() const { return values.size(); }

bool hasValue() const { return !values.empty(); }

bool hasScalar() const { return scalar != 0; }

bool isEmpty() const { return !hasValue() && !scalar; }

bool hasSameShape(MultiBufferState other) const {
return hasValue() && other.hasValue() && getShape() == other.getShape();
}

// Recursively parse a Value; call the coresponding function based on the
// defining operation and Value type
LogicalResult parse(Value operand, const Location loc,
PatternRewriter &rewriter);

private:
// -------
// Utility functions to operate on MultiBufferState
// -------
LogicalResult addStates(const MultiBufferState &lhsState,
const MultiBufferState &rhsState, Location loc,
PatternRewriter &rewriter);

LogicalResult mulStates(const MultiBufferState &lhsState,
const MultiBufferState &rhsState, Location loc,
PatternRewriter &rewriter);

LogicalResult divStates(const MultiBufferState &lhsState,
const MultiBufferState &rhsState, Location loc,
PatternRewriter &rewriter);

LogicalResult remStates(const MultiBufferState &lhsState,
const MultiBufferState &rhsState, Location loc,
PatternRewriter &rewriter);
// -------
// Helper functions to parse values to populate MultiBufferState
// -------

LogicalResult parseConstant(arith::ConstantOp op, const Location loc,
PatternRewriter &rewriter);

LogicalResult parseAdd(arith::AddIOp op, const Location loc,
PatternRewriter &rewriter);

LogicalResult parseMul(arith::MulIOp op, const Location loc,
PatternRewriter &rewriter);

LogicalResult parseRem(arith::RemSIOp op, const Location loc,
PatternRewriter &rewriter);

LogicalResult parseDiv(arith::DivSIOp op, const Location loc,
PatternRewriter &rewriter);

LogicalResult parseMakeRange(triton::MakeRangeOp op, const Location loc,
PatternRewriter &rewriter);

LogicalResult parseSplat(triton::SplatOp op, const Location loc,
PatternRewriter &rewriter);

LogicalResult parseGetProgramId(triton::GetProgramIdOp op, const Location loc,
PatternRewriter &rewriter);

LogicalResult parseStore(triton::StoreOp op, const Location loc,
PatternRewriter &rewriter);

LogicalResult parseAddPtr(triton::AddPtrOp op, const Location loc,
PatternRewriter &rewriter);
};

} // namespace triton

} // namespace mlir

#endif
10 changes: 10 additions & 0 deletions include/triton-shared/Conversion/TritonToLinalg/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,14 @@ def TritonToLinalg : Pass<"triton-to-linalg", "mlir::ModuleOp"> {
let constructor = "triton::createTritonToLinalgPass()";
}

def MultiBufferRewrite : Pass<"multibuffer-rewrite", "mlir::ModuleOp"> {
let summary = "MultiBuffer Rewrite";
let constructor = "triton::createMultiBufferRewritePass()";
}

def TritonToLinalgCanonicalize : Pass<"triton-to-linalg-canonicalize", "mlir::ModuleOp"> {
let summary = "Triton to linalg canonicalizer";
let constructor = "triton::createTritonToLinalgCanonicalizePass()";
}

#endif
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@ namespace triton {

std::unique_ptr<OperationPass<ModuleOp>> createTritonToLinalgPass();

std::unique_ptr<OperationPass<ModuleOp>> createMultiBufferRewritePass();

std::unique_ptr<OperationPass<ModuleOp>> createTritonToLinalgCanonicalizePass();

void populateTritonToLinalgCanonicalizationPatterns(
RewritePatternSet &patterns);

Expand Down
1 change: 1 addition & 0 deletions lib/Analysis/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ add_mlir_library(TritonSharedAnalysis
OpFoldResultUtils.cpp
PtrAnalysis.cpp
UseAnalysis.cpp
MultiBufferAnalysis.cpp

DEPENDS
TritonAnalysis
Expand Down
Loading

0 comments on commit 257ed48

Please sign in to comment.