Skip to content

Commit

Permalink
[OPTIMIZER] Adjusted heuristics (#1001)
Browse files Browse the repository at this point in the history
  • Loading branch information
ptillet authored Dec 21, 2022
1 parent 655b71a commit 19d1831
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions lib/Dialect/TritonGPU/Transforms/Combine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -461,14 +461,18 @@ class FoldConvertAndReduce : public mlir::RewritePattern {
public:
explicit FoldConvertAndReduce(mlir::MLIRContext *context)
: mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(),
2, context) {}
1, context) {}

mlir::LogicalResult
matchAndRewrite(mlir::Operation *cvtOp,
mlir::PatternRewriter &rewriter) const override {
auto cvt = dyn_cast<triton::gpu::ConvertLayoutOp>(*cvtOp);
auto srcEncoding =
cvt.getOperand().getType().cast<RankedTensorType>().getEncoding();
auto dstEncoding =
cvt.getResult().getType().cast<RankedTensorType>().getEncoding();
if (srcEncoding.isa<triton::gpu::SliceEncodingAttr>())
return failure();
SetVector<Operation *> cvtSlices;
auto filter = [&](Operation *op) {
return op->getBlock() == cvt->getBlock() &&
Expand Down Expand Up @@ -499,7 +503,7 @@ class FoldConvertAndReduce : public mlir::RewritePattern {
llvm::MapVector<Value, Attribute> toConvert;
if (argOp && (argOp != cvt) && cvtSlices.count(argOp) == 0 &&
failed(simulateBackwardRematerialization(argOp, processed, layout,
toConvert, srcEncoding))) {
toConvert, dstEncoding))) {
return failure();
}
}
Expand Down

0 comments on commit 19d1831

Please sign in to comment.