@@ -429,20 +429,24 @@ namespace {
429429// / result type.
430430// / - The permutation map doesn't perform permutation (broadcasting is allowed).
431431struct TransferReadToVectorLoadLowering
432- : public OpRewritePattern <vector::TransferReadOp> {
432+ : public MaskableOpRewritePattern <vector::TransferReadOp> {
433433 TransferReadToVectorLoadLowering (MLIRContext *context,
434434 std::optional<unsigned > maxRank,
435435 PatternBenefit benefit = 1 )
436- : OpRewritePattern <vector::TransferReadOp>(context, benefit),
436+ : MaskableOpRewritePattern <vector::TransferReadOp>(context, benefit),
437437 maxTransferRank (maxRank) {}
438438
439- LogicalResult matchAndRewrite (vector::TransferReadOp read,
440- PatternRewriter &rewriter) const override {
439+ FailureOr<mlir::Value>
440+ matchAndRewriteMaskableOp (vector::TransferReadOp read,
441+ MaskingOpInterface maskOp,
442+ PatternRewriter &rewriter) const override {
441443 if (maxTransferRank && read.getVectorType ().getRank () > *maxTransferRank) {
442444 return rewriter.notifyMatchFailure (
443445 read, " vector type is greater than max transfer rank" );
444446 }
445447
448+ if (maskOp)
449+ return rewriter.notifyMatchFailure (read, " Masked case not supported" );
446450 SmallVector<unsigned > broadcastedDims;
447451 // Permutations are handled by VectorToSCF or
448452 // populateVectorTransferPermutationMapLoweringPatterns.
@@ -485,7 +489,7 @@ struct TransferReadToVectorLoadLowering
485489 return rewriter.notifyMatchFailure (read, " out-of-bounds needs mask" );
486490
487491 // Create vector load op.
488- Operation *loadOp ;
492+ Operation *res ;
489493 if (read.getMask ()) {
490494 if (read.getVectorType ().getRank () != 1 )
491495 // vector.maskedload operates on 1-D vectors.
@@ -495,24 +499,20 @@ struct TransferReadToVectorLoadLowering
495499
496500 Value fill = rewriter.create <vector::SplatOp>(
497501 read.getLoc (), unbroadcastedVectorType, read.getPadding ());
498- loadOp = rewriter.create <vector::MaskedLoadOp>(
502+ res = rewriter.create <vector::MaskedLoadOp>(
499503 read.getLoc (), unbroadcastedVectorType, read.getSource (),
500504 read.getIndices (), read.getMask (), fill);
501505 } else {
502- loadOp = rewriter.create <vector::LoadOp>(
506+ res = rewriter.create <vector::LoadOp>(
503507 read.getLoc (), unbroadcastedVectorType, read.getSource (),
504508 read.getIndices ());
505509 }
506510
507511 // Insert a broadcasting op if required.
508- if (!broadcastedDims.empty ()) {
509- rewriter.replaceOpWithNewOp <vector::BroadcastOp>(
510- read, read.getVectorType (), loadOp->getResult (0 ));
511- } else {
512- rewriter.replaceOp (read, loadOp->getResult (0 ));
513- }
514-
515- return success ();
512+ if (!broadcastedDims.empty ())
513+ res = rewriter.create <vector::BroadcastOp>(
514+ read.getLoc (), read.getVectorType (), res->getResult (0 ));
515+ return res->getResult (0 );
516516 }
517517
518518 std::optional<unsigned > maxTransferRank;
@@ -581,19 +581,23 @@ struct VectorStoreToMemrefStoreLowering
581581// / - The permutation map is the minor identity map (neither permutation nor
582582// / broadcasting is allowed).
583583struct TransferWriteToVectorStoreLowering
584- : public OpRewritePattern <vector::TransferWriteOp> {
584+ : public MaskableOpRewritePattern <vector::TransferWriteOp> {
585585 TransferWriteToVectorStoreLowering (MLIRContext *context,
586586 std::optional<unsigned > maxRank,
587587 PatternBenefit benefit = 1 )
588- : OpRewritePattern <vector::TransferWriteOp>(context, benefit),
588+ : MaskableOpRewritePattern <vector::TransferWriteOp>(context, benefit),
589589 maxTransferRank (maxRank) {}
590590
591- LogicalResult matchAndRewrite (vector::TransferWriteOp write,
592- PatternRewriter &rewriter) const override {
591+ FailureOr<mlir::Value>
592+ matchAndRewriteMaskableOp (vector::TransferWriteOp write,
593+ MaskingOpInterface maskOp,
594+ PatternRewriter &rewriter) const override {
593595 if (maxTransferRank && write.getVectorType ().getRank () > *maxTransferRank) {
594596 return rewriter.notifyMatchFailure (
595597 write, " vector type is greater than max transfer rank" );
596598 }
599+ if (maskOp)
600+ return rewriter.notifyMatchFailure (write, " Masked case not supported" );
597601
598602 // Permutations are handled by VectorToSCF or
599603 // populateVectorTransferPermutationMapLoweringPatterns.
@@ -645,14 +649,16 @@ struct TransferWriteToVectorStoreLowering
645649 << write;
646650 });
647651
648- rewriter.replaceOpWithNewOp <vector::MaskedStoreOp>(
649- write, write. getSource (), write.getIndices (), write.getMask (),
650- write.getVector ());
652+ rewriter.create <vector::MaskedStoreOp>(
653+ write. getLoc (), write.getSource (), write.getIndices (),
654+ write.getMask (), write. getVector ());
651655 } else {
652- rewriter.replaceOpWithNewOp <vector::StoreOp>(
653- write, write. getVector (), write.getSource (), write.getIndices ());
656+ rewriter.create <vector::StoreOp>(write. getLoc (), write. getVector (),
657+ write.getSource (), write.getIndices ());
654658 }
655- return success ();
659+ // There's no return value for StoreOps. Use Value() to signal success to
660+ // matchAndRewrite.
661+ return Value ();
656662 }
657663
658664 std::optional<unsigned > maxTransferRank;
0 commit comments