@@ -578,6 +578,47 @@ struct VectorShuffleOpConvert final
578578 }
579579};
580580
581+ struct VectorInterleaveOpConvert final
582+ : public OpConversionPattern<vector::InterleaveOp> {
583+ using OpConversionPattern::OpConversionPattern;
584+
585+ LogicalResult
586+ matchAndRewrite (vector::InterleaveOp interleaveOp, OpAdaptor adaptor,
587+ ConversionPatternRewriter &rewriter) const override {
588+ // Check the result vector type.
589+ VectorType oldResultType = interleaveOp.getResultVectorType ();
590+ Type newResultType = getTypeConverter ()->convertType (oldResultType);
591+ if (!newResultType)
592+ return rewriter.notifyMatchFailure (interleaveOp,
593+ " unsupported result vector type" );
594+
595+ // Interleave the indices.
596+ VectorType sourceType = interleaveOp.getSourceVectorType ();
597+ int n = sourceType.getNumElements ();
598+
599+ // Input vectors of size 1 are converted to scalars by the type converter.
600+ // We cannot use `spirv::VectorShuffleOp` directly in this case, and need to
601+ // use `spirv::CompositeConstructOp`.
602+ if (n == 1 ) {
603+ Value newOperands[] = {adaptor.getLhs (), adaptor.getRhs ()};
604+ rewriter.replaceOpWithNewOp <spirv::CompositeConstructOp>(
605+ interleaveOp, newResultType, newOperands);
606+ return success ();
607+ }
608+
609+ auto seq = llvm::seq<int64_t >(2 * n);
610+ auto indices = llvm::map_to_vector (
611+ seq, [n](int i) { return (i % 2 ? n : 0 ) + i / 2 ; });
612+
613+ // Emit a SPIR-V shuffle.
614+ rewriter.replaceOpWithNewOp <spirv::VectorShuffleOp>(
615+ interleaveOp, newResultType, adaptor.getLhs (), adaptor.getRhs (),
616+ rewriter.getI32ArrayAttr (indices));
617+
618+ return success ();
619+ }
620+ };
621+
581622struct VectorLoadOpConverter final
582623 : public OpConversionPattern<vector::LoadOp> {
583624 using OpConversionPattern::OpConversionPattern;
@@ -822,16 +863,14 @@ void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
822863 VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>,
823864 VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,
824865 VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
825- VectorSplatPattern, VectorLoadOpConverter, VectorStoreOpConverter>(
826- typeConverter, patterns.getContext (), PatternBenefit (1 ));
866+ VectorInterleaveOpConvert, VectorSplatPattern, VectorLoadOpConverter,
867+ VectorStoreOpConverter>(typeConverter, patterns.getContext (),
868+ PatternBenefit (1 ));
827869
828870 // Make sure that the more specialized dot product pattern has higher benefit
829871 // than the generic one that extracts all elements.
830872 patterns.add <VectorReductionToFPDotProd>(typeConverter, patterns.getContext (),
831873 PatternBenefit (2 ));
832-
833- // Need this until vector.interleave is handled.
834- vector::populateVectorInterleaveToShufflePatterns (patterns);
835874}
836875
837876void mlir::populateVectorReductionToSPIRVDotProductPatterns (
0 commit comments