Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][sparse] introduce a pass to stage complex sparse operations in… #68436

Merged
merged 2 commits into from
Oct 6, 2023

Conversation

PeimingLiu
Copy link
Member

…to simple steps

@llvmbot llvmbot added mlir:sparse Sparse compiler in MLIR mlir labels Oct 6, 2023
@llvmbot
Copy link
Member

llvmbot commented Oct 6, 2023

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-sparse

Changes

…to simple steps


Full diff: https://github.com/llvm/llvm-project/pull/68436.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h (+9)
  • (modified) mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td (+12)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt (+1)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp (+17)
  • (added) mlir/lib/Dialect/SparseTensor/Transforms/StageSparseOperations.cpp (+4)
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
index c1e217675020f08..c537e92a51d5333 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
@@ -87,6 +87,15 @@ std::unique_ptr<Pass> createSparsificationPass();
 std::unique_ptr<Pass>
 createSparsificationPass(const SparsificationOptions &options);
 
+//===----------------------------------------------------------------------===//
+// The StageSparseOperations pass.
+//===----------------------------------------------------------------------===//
+
+/// Sets up StageSparseOperation rewriting rules.
+void populateStageSparseOperationsPatterns(RewritePatternSet &patterns);
+
+std::unique_ptr<Pass> createStageSparseOperationsPass();
+
 //===----------------------------------------------------------------------===//
 // The PostSparsificationRewriting pass.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
index d8d5dbb5ad3ce75..7071c3091d33f3a 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
@@ -123,6 +123,18 @@ def SparsificationPass : Pass<"sparsification", "ModuleOp"> {
   ];
 }
 
+def StageSparseOperations : Pass<"stage-sparse-ops", "func::FuncOp"> {
+  let summary = "Decompose a complex sparse operations into multiple stages";
+  let description = [{
+    A pass that decomposes a complex sparse operations into multiple stages.
+    E.g., CSR -> CSC conversion is staged into CSR -> COO (unordered) -> sort -> CSC.
+  }];
+  let constructor = "mlir::createStageSparseOperationsPass()";
+  let dependentDialects = [
+    "sparse_tensor::SparseTensorDialect",
+  ];
+}
+
 def PostSparsificationRewrite : Pass<"post-sparsification-rewrite", "ModuleOp"> {
   let summary = "Applies sparse tensor rewriting rules after sparsification";
   let description = [{
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
index 5ef9d906f0e8b7c..0ca6668c8c74745 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
@@ -14,6 +14,7 @@ add_mlir_dialect_library(MLIRSparseTensorTransforms
   SparseVectorization.cpp
   Sparsification.cpp
   SparsificationAndBufferizationPass.cpp
+  StageSparseOperations.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SparseTensor
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
index f50d3d4606554a1..e1f88ad9c0e1140 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -30,6 +30,7 @@ namespace mlir {
 #define GEN_PASS_DEF_SPARSEBUFFERREWRITE
 #define GEN_PASS_DEF_SPARSEVECTORIZATION
 #define GEN_PASS_DEF_SPARSEGPUCODEGEN
+#define GEN_PASS_DEF_STAGESPARSEOPERATIONS
 #define GEN_PASS_DEF_STORAGESPECIFIERTOLLVM
 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc"
 } // namespace mlir
@@ -92,6 +93,18 @@ struct SparsificationPass
   }
 };
 
+struct StageSparseOperationsPass
+    : public impl::StageSparseOperationsBase<StageSparseOperationsPass> {
+  StageSparseOperationsPass() = default;
+  StageSparseOperationsPass(const StageSparseOperationsPass &pass) = default;
+  void runOnOperation() override {
+    auto *ctx = &getContext();
+    RewritePatternSet patterns(ctx);
+    populateStageSparseOperationsPatterns(patterns);
+    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+  }
+};
+
 struct PostSparsificationRewritePass
     : public impl::PostSparsificationRewriteBase<
           PostSparsificationRewritePass> {
@@ -384,6 +397,10 @@ mlir::createSparsificationPass(const SparsificationOptions &options) {
   return std::make_unique<SparsificationPass>(options);
 }
 
+std::unique_ptr<Pass> mlir::createStageSparseOperationsPass() {
+  return std::make_unique<StageSparseOperationsPass>();
+}
+
 std::unique_ptr<Pass> mlir::createPostSparsificationRewritePass() {
   return std::make_unique<PostSparsificationRewritePass>();
 }
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/StageSparseOperations.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/StageSparseOperations.cpp
new file mode 100644
index 000000000000000..4adc4d131198cc7
--- /dev/null
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/StageSparseOperations.cpp
@@ -0,0 +1,4 @@
+#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
+
+void mlir::populateStageSparseOperationsPatterns(
+    RewritePatternSet & /*patterns*/) {}

@PeimingLiu PeimingLiu merged commit 0637440 into llvm:main Oct 6, 2023
2 checks passed
@PeimingLiu PeimingLiu deleted the stage-sparse branch October 6, 2023 21:23
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:sparse Sparse compiler in MLIR mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants