Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MLIR] Fix incorrect memref::DimOp canonicalization, add tensor::DimOp canonicalization #84225

Merged
merged 8 commits into from
Mar 12, 2024

Conversation

sahas3
Copy link
Contributor

@sahas3 sahas3 commented Mar 6, 2024

The current canonicalization of memref.dim operating on the result of memref.reshape into memref.load is incorrect as it doesn't check whether the index operand of memref.dim dominates the source memref.reshape op. It always introduces memref.load right after memref.reshape to ensure the memref is not mutated before the memref.load call. As a result, the following error is observed:

$> mlir-opt --canonicalize input.mlir

func.func @reshape_dim(%arg0: memref<*xf32>, %arg1: memref<?xindex>, %arg2: index) -> index {
    %c4 = arith.constant 4 : index
    %reshape = memref.reshape %arg0(%arg1) : (memref<*xf32>, memref<?xindex>) -> memref<*xf32>
    %0 = arith.muli %arg2, %c4 : index
    %dim = memref.dim %reshape, %0 : memref<*xf32>
    return %dim : index
  }

results in:

dominator.mlir:22:12: error: operand #1 does not dominate this use
    %dim = memref.dim %reshape, %0 : memref<*xf32>
           ^
dominator.mlir:22:12: note: see current operation: %1 = "memref.load"(%arg1, %2) <{nontemporal = false}> : (memref<?xindex>, index) -> index
dominator.mlir:21:10: note: operand defined here (op in the same block)
    %0 = arith.muli %arg2, %c4 : index

Properly fixing this issue requires a dominator analysis which is expensive to run within a canonicalization pattern. So, this patch moves the canonicalization pattern to tensor.dim. Since tensors are immutable we don't need to worry about where to introduce the tensor.extract call after canonicalization.

Copy link

github-actions bot commented Mar 6, 2024

Thank you for submitting a Pull Request (PR) to the LLVM Project!

This PR will be automatically labeled and the relevant teams will be
notified.

If you wish to, you can add reviewers by using the "Reviewers" section on this page.

If this is not working for you, it is probably because you do not have write
permissions for the repository. In which case you can instead tag reviewers by
name in a comment by using @ followed by their GitHub username.

If you have received no comments on your PR for a week, you can request a review
by "ping"ing the PR by adding a comment “Ping”. The common courtesy "ping" rate
is once a week. Please remember that you are asking for valuable time from other developers.

If you have further questions, they may be answered by the LLVM GitHub User Guide.

You can also ask questions in a comment on this PR, on the LLVM Discord or on the forums.

@llvmbot
Copy link
Collaborator

llvmbot commented Mar 6, 2024

@llvm/pr-subscribers-mlir-linalg
@llvm/pr-subscribers-mlir-tensor
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-memref

Author: Sayan Saha (sahas3)

Changes

The current canonicalization of memref.dim operating on the result of memref.reshape into memref.load is incorrect as it doesn't check whether the index operand of memref.dim dominates the source memref.reshape op. It always introduces memref.load right after memref.reshape to ensure the memref is not mutated before the memref.load call. As a result, the following error is observed:

$&gt; mlir-opt --canonicalize input.mlir

func.func @<!-- -->reshape_dim(%arg0: memref&lt;*xf32&gt;, %arg1: memref&lt;?xindex&gt;, %arg2: index) -&gt; index {
    %c4 = arith.constant 4 : index
    %reshape = memref.reshape %arg0(%arg1) : (memref&lt;*xf32&gt;, memref&lt;?xindex&gt;) -&gt; memref&lt;*xf32&gt;
    %0 = arith.muli %arg2, %c4 : index
    %dim = memref.dim %reshape, %0 : memref&lt;*xf32&gt;
    return %dim : index
  }

results in:

dominator.mlir:22:12: error: operand #<!-- -->1 does not dominate this use
    %dim = memref.dim %reshape, %0 : memref&lt;*xf32&gt;
           ^
dominator.mlir:22:12: note: see current operation: %1 = "memref.load"(%arg1, %2) &lt;{nontemporal = false}&gt; : (memref&lt;?xindex&gt;, index) -&gt; index
dominator.mlir:21:10: note: operand defined here (op in the same block)
    %0 = arith.muli %arg2, %c4 : index

Properly fixing this issue requires a dominator analysis which is expensive to run within a canonicalization pattern. So, this patch moves the canonicalization pattern to tensor.dim. Since tensors are immutable we don't need to worry about where to introduce the tensor.extract call after canonicalization.


Full diff: https://github.com/llvm/llvm-project/pull/84225.diff

6 Files Affected:

  • (modified) mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td (-1)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/Loops.cpp (-1)
  • (modified) mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp (-33)
  • (modified) mlir/lib/Dialect/Tensor/IR/TensorOps.cpp (+26-1)
  • (modified) mlir/test/Dialect/MemRef/canonicalize.mlir (-42)
  • (modified) mlir/test/Dialect/Tensor/canonicalize.mlir (+80)
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index c71517666b609c..2333c92fd7b12c 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -629,7 +629,6 @@ def MemRef_DimOp : MemRef_Op<"dim", [
     Speculation::Speculatability getSpeculatability();
   }];
 
-  let hasCanonicalizer = 1;
   let hasFolder = 1;
 }
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
index b0a4de2da1e869..e1cb5b477debbc 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
@@ -317,7 +317,6 @@ static void lowerLinalgToLoopsImpl(Operation *enclosingOp) {
   MLIRContext *context = enclosingOp->getContext();
   RewritePatternSet patterns(context);
   patterns.add<LinalgRewritePattern<LoopType>>(context);
-  memref::DimOp::getCanonicalizationPatterns(patterns, context);
   tensor::DimOp::getCanonicalizationPatterns(patterns, context);
   affine::AffineApplyOp::getCanonicalizationPatterns(patterns, context);
   patterns.add<FoldAffineOp>(context);
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 248193481acfc6..00b7fa122a6c96 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1069,39 +1069,6 @@ OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
   return {};
 }
 
-namespace {
-/// Fold dim of a memref reshape operation to a load into the reshape's shape
-/// operand.
-struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
-  using OpRewritePattern<DimOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(DimOp dim,
-                                PatternRewriter &rewriter) const override {
-    auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();
-
-    if (!reshape)
-      return failure();
-
-    // Place the load directly after the reshape to ensure that the shape memref
-    // was not mutated.
-    rewriter.setInsertionPointAfter(reshape);
-    Location loc = dim.getLoc();
-    Value load =
-        rewriter.create<LoadOp>(loc, reshape.getShape(), dim.getIndex());
-    if (load.getType() != dim.getType())
-      load = rewriter.create<arith::IndexCastOp>(loc, dim.getType(), load);
-    rewriter.replaceOp(dim, load);
-    return success();
-  }
-};
-
-} // namespace
-
-void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
-                                        MLIRContext *context) {
-  results.add<DimOfMemRefReshape>(context);
-}
-
 // ---------------------------------------------------------------------------
 // DmaStartOp
 // ---------------------------------------------------------------------------
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index fe2f250e6b9290..ce9792f813cbb3 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -824,11 +824,36 @@ struct DimOfDestStyleOp : public OpRewritePattern<DimOp> {
     return success();
   }
 };
+
+/// Fold dim of a tensor reshape operation to a extract into the reshape's shape
+/// operand.
+struct DimOfReshapeOp : public OpRewritePattern<DimOp> {
+  using OpRewritePattern<DimOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(DimOp dim,
+                                PatternRewriter &rewriter) const override {
+    auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();
+
+    if (!reshape)
+      return failure();
+
+    // Since tensors are immutable we don't need to worry about where to place
+    // the load call
+    rewriter.setInsertionPointAfter(dim);
+    Location loc = dim.getLoc();
+    Value load =
+        rewriter.create<ExtractOp>(loc, reshape.getShape(), dim.getIndex());
+    if (load.getType() != dim.getType())
+      load = rewriter.create<arith::IndexCastOp>(loc, dim.getType(), load);
+    rewriter.replaceOp(dim, load);
+    return success();
+  }
+};
 } // namespace
 
 void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
-  results.add<DimOfCastOp, DimOfDestStyleOp>(context);
+  results.add<DimOfCastOp, DimOfDestStyleOp, DimOfReshapeOp>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index a772a25da57382..0054a8ac785a89 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -242,48 +242,6 @@ func.func @dim_of_alloca_with_dynamic_size(%arg0: memref<*xf32>) -> index {
 
 // -----
 
-// Test case: Folding of memref.dim(memref.reshape %v %shp, %idx) -> memref.load %shp[%idx]
-// CHECK-LABEL: func @dim_of_memref_reshape(
-//  CHECK-SAME:     %[[MEM:[0-9a-z]+]]: memref<*xf32>,
-//  CHECK-SAME:     %[[SHP:[0-9a-z]+]]: memref<?xindex>
-//  CHECK-NEXT:   %[[IDX:.*]] = arith.constant 3
-//  CHECK-NEXT:   %[[DIM:.*]] = memref.load %[[SHP]][%[[IDX]]]
-//  CHECK-NEXT:   memref.store
-//   CHECK-NOT:   memref.dim
-//       CHECK:   return %[[DIM]] : index
-func.func @dim_of_memref_reshape(%arg0: memref<*xf32>, %arg1: memref<?xindex>)
-    -> index {
-  %c3 = arith.constant 3 : index
-  %0 = memref.reshape %arg0(%arg1)
-      : (memref<*xf32>, memref<?xindex>) -> memref<*xf32>
-  // Update the shape to test that he load ends up in the right place.
-  memref.store %c3, %arg1[%c3] : memref<?xindex>
-  %1 = memref.dim %0, %c3 : memref<*xf32>
-  return %1 : index
-}
-
-// -----
-
-// Test case: Folding of memref.dim(memref.reshape %v %shp, %idx) -> memref.load %shp[%idx]
-// CHECK-LABEL: func @dim_of_memref_reshape_i32(
-//  CHECK-SAME:     %[[MEM:[0-9a-z]+]]: memref<*xf32>,
-//  CHECK-SAME:     %[[SHP:[0-9a-z]+]]: memref<?xi32>
-//  CHECK-NEXT:   %[[IDX:.*]] = arith.constant 3
-//  CHECK-NEXT:   %[[DIM:.*]] = memref.load %[[SHP]][%[[IDX]]]
-//  CHECK-NEXT:   %[[CAST:.*]] = arith.index_cast %[[DIM]]
-//   CHECK-NOT:   memref.dim
-//       CHECK:   return %[[CAST]] : index
-func.func @dim_of_memref_reshape_i32(%arg0: memref<*xf32>, %arg1: memref<?xi32>)
-    -> index {
-  %c3 = arith.constant 3 : index
-  %0 = memref.reshape %arg0(%arg1)
-      : (memref<*xf32>, memref<?xi32>) -> memref<*xf32>
-  %1 = memref.dim %0, %c3 : memref<*xf32>
-  return %1 : index
-}
-
-// -----
-
 // CHECK-LABEL: func @alloc_const_fold
 func.func @alloc_const_fold() -> memref<?xf32> {
   // CHECK-NEXT: memref.alloc() : memref<4xf32>
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index d17c23adfb14d8..45d37c553a0025 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -2250,3 +2250,83 @@ func.func @infer_and_fold_pack_unpack_same_tiles(%t: tensor<10x20x4x4xf32>) -> t
 // CHECK-LABEL: func.func @infer_and_fold_pack_unpack_same_tiles
 // CHECK-SAME:    %[[SRC:[0-9a-zA-Z]+]]
 // CHECK:         return %[[SRC]]
+
+// -----
+
+// Test case: Folding of tensor.dim(tensor.reshape %v %shp, %idx) -> memref.extract %shp[%idx]
+// CHECK-LABEL: func @dim_of_reshape(
+//  CHECK-SAME:     %[[MEM:[0-9a-z]+]]: tensor<*xf32>,
+//  CHECK-SAME:     %[[SHP:[0-9a-z]+]]: tensor<?xindex>
+//  CHECK-NEXT:   %[[IDX:.*]] = arith.constant 3
+//  CHECK-NEXT:   %[[DIM:.*]] = tensor.extract %[[SHP]][%[[IDX]]]
+//   CHECK-NOT:   tensor.store
+//   CHECK-NOT:   tensor.dim
+//   CHECK-NOT: tensor.reshape
+//       CHECK:   return %[[DIM]] : index
+func.func @dim_of_reshape(%arg0: tensor<*xf32>, %arg1: tensor<?xindex>)
+    -> index {
+  %c3 = arith.constant 3 : index
+  %0 = tensor.reshape %arg0(%arg1)
+      : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
+  // Update the shape to test that the load ends up in the right place.
+  tensor.insert %c3 into %arg1[%c3] : tensor<?xindex>
+  %1 = tensor.dim %0, %c3 : tensor<*xf32>
+  return %1 : index
+}
+
+// -----
+
+// Test case: Folding of tensor.dim(tensor.reshape %v %shp, %idx) -> tensor.extract %shp[%idx]
+// CHECK-LABEL: func @dim_of_reshape_i32(
+//       CHECK:  tensor.extract
+//  CHECK-NEXT:  %[[CAST:.*]] = arith.index_cast
+//   CHECK-NOT:  tensor.dim
+//   CHECK-NOT:  tensor.reshape
+//       CHECK:  return %[[CAST]] : index
+func.func @dim_of_reshape_i32(%arg0: tensor<*xf32>, %arg1: tensor<?xi32>)
+    -> index {
+    %c3 = arith.constant 3 : index
+    %0 = tensor.reshape %arg0(%arg1)
+        : (tensor<*xf32>, tensor<?xi32>) -> tensor<*xf32>
+    %1 = tensor.dim %0, %c3 : tensor<*xf32>
+    return %1 : index
+}
+
+// -----
+
+// Test case: tensor.dim(tensor.reshape %v %shp, %idx) is not folded into tensor.extract %shp[%idx]
+// CHECK-LABEL: func @dim_of_reshape_for(
+//       CHECK: scf.for
+//  CHECK-NEXT: tensor.extract
+//   CHECK-NOT: tensor.dim
+//   CHECK-NOT: tensor.reshape
+func.func @dim_of_reshape_for( %arg0: tensor<*xf32>, %arg1: tensor<?xindex>) -> index {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %c4 = arith.constant 4 : index
+
+    %0 = tensor.reshape %arg0(%arg1) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
+
+    %1 = scf.for %arg2 = %c0 to %c4 step %c1 iter_args(%arg3 = %c1) -> (index) {
+      %2 = tensor.dim %0, %arg2 : tensor<*xf32>
+      %3 = arith.muli %arg3, %2 : index
+      scf.yield %3 : index
+    }
+    return %1 : index
+}
+
+// -----
+
+// Test case: tensor.dim(tensor.reshape %v %shp, %idx) is not folded into tensor.extract %shp[%idx]
+// CHECK-LABEL: func @dim_of_reshape_undominated(
+//       CHECK: arith.muli
+//  CHECK-NEXT: tensor.extract
+//   CHECK-NOT: tensor.dim
+//   CHECK-NOT: tensor.reshape
+func.func @dim_of_reshape_undominated(%arg0: tensor<*xf32>, %arg1: tensor<?xindex>, %arg2: index) -> index {
+    %c4 = arith.constant 4 : index
+    %reshape = tensor.reshape %arg0(%arg1) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
+    %0 = arith.muli %arg2, %c4 : index
+    %dim = tensor.dim %reshape, %0 : tensor<*xf32>
+    return %dim : index
+  }

@joker-eph joker-eph changed the title [BugFix] : Move DimOp canonicalization from memref to tensor. [MLIR] Fix incorrect memref::DimOp canonicalization, move it to tensor::DimOp Mar 6, 2024
@joker-eph
Copy link
Collaborator

LGTM overall, but should we also have the negative tests you have on the memref::DimOp?

@MaheshRavishankar
Copy link
Contributor

I am not sure that removing the memref.dim canonicalization here is the right answer. The canonicalization itself seems fine apart from it inserting operation after the op.

Copy link
Contributor

@MaheshRavishankar MaheshRavishankar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think the memrefs canonicalization needs to be removed

@MaheshRavishankar MaheshRavishankar dismissed their stale review March 7, 2024 05:48

Actually I understand why this is a problem. This op is just bonkers. The semantics are pretty poor. So at least it makes sense to remove the canonicalization.

@sahas3 sahas3 changed the title [MLIR] Fix incorrect memref::DimOp canonicalization, move it to tensor::DimOp [MLIR] Fix incorrect memref::DimOp canonicalization, add tensor::DimOp canonicalization Mar 8, 2024
Copy link
Collaborator

@joker-eph joker-eph left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The tensor dimOp canonicalization seemed valuable, are you bringing this another PR?

mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp Show resolved Hide resolved
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp Show resolved Hide resolved
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp Outdated Show resolved Hide resolved
@sahas3
Copy link
Contributor Author

sahas3 commented Mar 8, 2024

The tensor dimOp canonicalization seemed valuable, are you bringing this another PR?

@joker-eph tensor.dim canonicalization is still part of this PR: https://github.com/llvm/llvm-project/pull/84225/files/e50232affc5a111db92b63df30031031b6ce7bc4#diff-025c342b2a76f2edaae5557ff3e81b0b9ef9f47107359e78d85654581318d154

Are you suggesting it should be it's own PR?

@joker-eph
Copy link
Collaborator

The tensor dimOp canonicalization seemed valuable, are you bringing this another PR?

@joker-eph tensor.dim canonicalization is still part of this PR

Nevermind: GitHub was showing me the diff since the last review (or just for the most recent push) so I was only seeing the Memref one.

Copy link

github-actions bot commented Mar 8, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

@joker-eph
Copy link
Collaborator

@MaheshRavishankar any concern left here?

@MaheshRavishankar
Copy link
Contributor

@MaheshRavishankar any concern left here?

Nope. I dismissed my review here. Thanks for checking

@joker-eph
Copy link
Collaborator

Thanks! Merging then

@joker-eph joker-eph merged commit 26722f5 into llvm:main Mar 12, 2024
4 checks passed
Copy link

@sahas3 Congratulations on having your first Pull Request (PR) merged into the LLVM Project!

Your changes will be combined with recent changes from other authors, then tested
by our build bots. If there is a problem with a build, you may recieve a report in an email or a comment on this PR.

Please check whether problems have been caused by your change specifically, as
the builds can include changes from many authors. It is not uncommon for your
change to be included in a build that fails due to someone else's changes, or
infrastructure issues.

How to do this, and the rest of the post-merge process, is covered in detail here.

If your change does cause a problem, it may be reverted, or you can revert it yourself.
This is a normal part of LLVM development. You can fix your changes and open a new PR to merge them again.

If you don't get any reports, no action is required from you. Your changes are working as expected, well done!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants