From 19d1831f3a6c6d86fbeb81ee1a1e65edff9476d2 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Wed, 21 Dec 2022 00:32:10 -0800 Subject: [PATCH] [OPTIMIZER] Adjusted heuristics (#1001) --- lib/Dialect/TritonGPU/Transforms/Combine.cpp | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/lib/Dialect/TritonGPU/Transforms/Combine.cpp b/lib/Dialect/TritonGPU/Transforms/Combine.cpp index 26c269f8f0fb..7dcdc0162b10 100644 --- a/lib/Dialect/TritonGPU/Transforms/Combine.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Combine.cpp @@ -461,14 +461,18 @@ class FoldConvertAndReduce : public mlir::RewritePattern { public: explicit FoldConvertAndReduce(mlir::MLIRContext *context) : mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(), - 2, context) {} + 1, context) {} mlir::LogicalResult matchAndRewrite(mlir::Operation *cvtOp, mlir::PatternRewriter &rewriter) const override { auto cvt = dyn_cast(*cvtOp); auto srcEncoding = + cvt.getOperand().getType().cast().getEncoding(); + auto dstEncoding = cvt.getResult().getType().cast().getEncoding(); + if (srcEncoding.isa()) + return failure(); SetVector cvtSlices; auto filter = [&](Operation *op) { return op->getBlock() == cvt->getBlock() && @@ -499,7 +503,7 @@ class FoldConvertAndReduce : public mlir::RewritePattern { llvm::MapVector toConvert; if (argOp && (argOp != cvt) && cvtSlices.count(argOp) == 0 && failed(simulateBackwardRematerialization(argOp, processed, layout, - toConvert, srcEncoding))) { + toConvert, dstEncoding))) { return failure(); } }