Skip to content

Conversation

obtuseangleAZ
Copy link
Contributor

@obtuseangleAZ obtuseangleAZ commented May 23, 2024

  • Add vector.interleave to spirv.VectorShuffle conversion,
  • Remove the vector.interleave to vector.shuffle conversion from populateVectorToSPIRVPatterns and CMake/Bazel dependencies

@llvmbot
Copy link
Member

llvmbot commented May 23, 2024

@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir

Author: Angel Zhang (angelz913)

Changes

Add vector.interleave to spirv.VectorShuffle conversion, and remove the vector.interleave to vector.shuffle conversion in populateVectorToSPIRVPatterns.


Full diff: https://github.com/llvm/llvm-project/pull/93240.diff

1 Files Affected:

  • (modified) mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp (+41-5)
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index c2dd37f481466..b86ebe1a4bb54 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -578,6 +578,44 @@ struct VectorShuffleOpConvert final
   }
 };
 
+struct VectorInterleaveOpConvert final
+    : public OpConversionPattern<vector::InterleaveOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::InterleaveOp interleaveOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+     // Check the source vector type
+    auto sourceType = interleaveOp.getSourceVectorType();
+    if (sourceType.getRank() != 1 || sourceType.isScalable()) {
+      return rewriter.notifyMatchFailure(interleaveOp,
+                                         "unsupported source vector type");
+    }  
+
+    // Check the result vector type
+    auto oldResultType = interleaveOp.getResultVectorType();
+    Type newResultType = getTypeConverter()->convertType(oldResultType);
+    if (!newResultType)
+      return rewriter.notifyMatchFailure(interleaveOp,
+                                         "unsupported result vector type");
+
+    // Interleave the indices
+    int n = sourceType.getNumElements();
+    auto seq = llvm::seq<int64_t>(2 * n);
+    auto indices = llvm::to_vector(llvm::map_range(
+        seq, [n](int i) { return (i % 2 ? n : 0) + i / 2; }));
+
+    // Emit a SPIR-V shuffle.
+    rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
+        interleaveOp, newResultType, adaptor.getLhs(), adaptor.getRhs(),
+        rewriter.getI32ArrayAttr(indices));
+
+    llvm::errs() << "vector.interleave to spirv.VectorShuffle succeeded\n";
+    
+    return success();
+  }
+};
+
 struct VectorLoadOpConverter final
     : public OpConversionPattern<vector::LoadOp> {
   using OpConversionPattern::OpConversionPattern;
@@ -821,17 +859,15 @@ void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
       VectorReductionPattern<CL_INT_MAX_MIN_OPS>,
       VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>,
       VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,
-      VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
-      VectorSplatPattern, VectorLoadOpConverter, VectorStoreOpConverter>(
+      VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert, 
+      VectorInterleaveOpConvert, VectorSplatPattern, 
+      VectorLoadOpConverter, VectorStoreOpConverter>(
       typeConverter, patterns.getContext(), PatternBenefit(1));
 
   // Make sure that the more specialized dot product pattern has higher benefit
   // than the generic one that extracts all elements.
   patterns.add<VectorReductionToFPDotProd>(typeConverter, patterns.getContext(),
                                            PatternBenefit(2));
-
-  // Need this until vector.interleave is handled.
-  vector::populateVectorInterleaveToShufflePatterns(patterns);
 }
 
 void mlir::populateVectorReductionToSPIRVDotProductPatterns(

@llvmbot
Copy link
Member

llvmbot commented May 23, 2024

@llvm/pr-subscribers-mlir-spirv

Author: Angel Zhang (angelz913)

Changes

Add vector.interleave to spirv.VectorShuffle conversion, and remove the vector.interleave to vector.shuffle conversion in populateVectorToSPIRVPatterns.


Full diff: https://github.com/llvm/llvm-project/pull/93240.diff

1 Files Affected:

  • (modified) mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp (+41-5)
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index c2dd37f481466..b86ebe1a4bb54 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -578,6 +578,44 @@ struct VectorShuffleOpConvert final
   }
 };
 
+struct VectorInterleaveOpConvert final
+    : public OpConversionPattern<vector::InterleaveOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::InterleaveOp interleaveOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+     // Check the source vector type
+    auto sourceType = interleaveOp.getSourceVectorType();
+    if (sourceType.getRank() != 1 || sourceType.isScalable()) {
+      return rewriter.notifyMatchFailure(interleaveOp,
+                                         "unsupported source vector type");
+    }  
+
+    // Check the result vector type
+    auto oldResultType = interleaveOp.getResultVectorType();
+    Type newResultType = getTypeConverter()->convertType(oldResultType);
+    if (!newResultType)
+      return rewriter.notifyMatchFailure(interleaveOp,
+                                         "unsupported result vector type");
+
+    // Interleave the indices
+    int n = sourceType.getNumElements();
+    auto seq = llvm::seq<int64_t>(2 * n);
+    auto indices = llvm::to_vector(llvm::map_range(
+        seq, [n](int i) { return (i % 2 ? n : 0) + i / 2; }));
+
+    // Emit a SPIR-V shuffle.
+    rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
+        interleaveOp, newResultType, adaptor.getLhs(), adaptor.getRhs(),
+        rewriter.getI32ArrayAttr(indices));
+
+    llvm::errs() << "vector.interleave to spirv.VectorShuffle succeeded\n";
+    
+    return success();
+  }
+};
+
 struct VectorLoadOpConverter final
     : public OpConversionPattern<vector::LoadOp> {
   using OpConversionPattern::OpConversionPattern;
@@ -821,17 +859,15 @@ void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
       VectorReductionPattern<CL_INT_MAX_MIN_OPS>,
       VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>,
       VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,
-      VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
-      VectorSplatPattern, VectorLoadOpConverter, VectorStoreOpConverter>(
+      VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert, 
+      VectorInterleaveOpConvert, VectorSplatPattern, 
+      VectorLoadOpConverter, VectorStoreOpConverter>(
       typeConverter, patterns.getContext(), PatternBenefit(1));
 
   // Make sure that the more specialized dot product pattern has higher benefit
   // than the generic one that extracts all elements.
   patterns.add<VectorReductionToFPDotProd>(typeConverter, patterns.getContext(),
                                            PatternBenefit(2));
-
-  // Need this until vector.interleave is handled.
-  vector::populateVectorInterleaveToShufflePatterns(patterns);
 }
 
 void mlir::populateVectorReductionToSPIRVDotProductPatterns(

Copy link

github-actions bot commented May 23, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

@obtuseangleAZ obtuseangleAZ force-pushed the vector-interleave-to-spirv-shuffle branch 2 times, most recently from 5a54b60 to cdc9def Compare May 23, 2024 21:37
@llvmbot llvmbot added mlir:vectorops bazel "Peripheral" support tier build system: utils/bazel mlir:vector labels May 27, 2024
@obtuseangleAZ obtuseangleAZ force-pushed the vector-interleave-to-spirv-shuffle branch 2 times, most recently from c006ee5 to cf47613 Compare May 27, 2024 16:08
@obtuseangleAZ obtuseangleAZ force-pushed the vector-interleave-to-spirv-shuffle branch 2 times, most recently from 24c6d24 to aed3118 Compare May 27, 2024 20:37
@obtuseangleAZ obtuseangleAZ force-pushed the vector-interleave-to-spirv-shuffle branch from aed3118 to 9a688a7 Compare May 27, 2024 20:53
Copy link
Member

@kuhar kuhar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@kuhar kuhar merged commit 57c10fa into llvm:main May 27, 2024
@obtuseangleAZ obtuseangleAZ deleted the vector-interleave-to-spirv-shuffle branch June 6, 2024 19:34
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
bazel "Peripheral" support tier build system: utils/bazel mlir:spirv mlir:vector mlir:vectorops mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants