Skip to content

Commit 0f127ab

Browse files
committed
Add canonicalizer removing rank 0 affine.parallel
1 parent 3e31569 commit 0f127ab

File tree

3 files changed

+44
-0
lines changed

3 files changed

+44
-0
lines changed

mlir/include/mlir/Dialect/Affine/IR/AffineOps.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -587,6 +587,8 @@ def AffineParallelOp : Affine_Op<"parallel", [ImplicitAffineTerminator]> {
587587
static StringRef getUpperBoundsMapAttrName() { return "upperBoundsMap"; }
588588
static StringRef getStepsAttrName() { return "steps"; }
589589
}];
590+
591+
let hasCanonicalizer = 1;
590592
}
591593

592594
def AffinePrefetchOp : Affine_Op<"prefetch"> {

mlir/lib/Dialect/Affine/IR/AffineOps.cpp

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2614,6 +2614,33 @@ static LogicalResult verify(AffineVectorStoreOp op) {
26142614
return success();
26152615
}
26162616

2617+
namespace {
2618+
/// This pattern removes affine.parallel ops with no induction variables
2619+
struct AffineParallelRank0LoopRemover
2620+
: public OpRewritePattern<AffineParallelOp> {
2621+
using OpRewritePattern<AffineParallelOp>::OpRewritePattern;
2622+
2623+
LogicalResult matchAndRewrite(AffineParallelOp op,
2624+
PatternRewriter &rewriter) const override {
2625+
// Check that there are no induction variables
2626+
if (op.lowerBoundsMap().getNumResults() != 0)
2627+
return failure();
2628+
// Remove the affine.parallel wrapper, retain the body in the same location
2629+
auto &parentOps = rewriter.getInsertionBlock()->getOperations();
2630+
auto &parallelBodyOps = op.region().front().getOperations();
2631+
parentOps.splice(mlir::Block::iterator(op), parallelBodyOps,
2632+
parallelBodyOps.begin(), std::prev(parallelBodyOps.end()));
2633+
rewriter.eraseOp(op);
2634+
return success();
2635+
}
2636+
};
2637+
} // end anonymous namespace
2638+
2639+
void AffineParallelOp::getCanonicalizationPatterns(
2640+
OwningRewritePatternList &results, MLIRContext *context) {
2641+
results.insert<AffineParallelRank0LoopRemover>(context);
2642+
}
2643+
26172644
//===----------------------------------------------------------------------===//
26182645
// TableGen'd op method definitions
26192646
//===----------------------------------------------------------------------===//

mlir/test/Dialect/Affine/canonicalize.mlir

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -604,3 +604,18 @@ func @drop_duplicate_bounds(%N : index) {
604604
}
605605
return
606606
}
607+
608+
// -----
609+
610+
// CHECK: func @remove_rank0_affine_parallel(%[[OUT:.*]]: memref<f32>)
611+
func @remove_rank0_affine_parallel(%out: memref<f32>) {
612+
// CHECK-NEXT: %[[CST:.*]] = constant
613+
%cst = constant 0.0 : f32
614+
// CHECK-NEXT: affine.store %[[CST]], %[[OUT]][] : memref<f32>
615+
affine.parallel () = () to () {
616+
affine.parallel () = () to () {
617+
affine.store %cst, %out[] : memref<f32>
618+
}
619+
}
620+
return
621+
}

0 commit comments

Comments
 (0)