@@ -224,10 +224,10 @@ struct VectorizationState {
224224 // / Masks an operation with the canonical vector mask if the operation needs
225225 // / masking. Returns the masked operation or the original operation if masking
226226 // / is not needed. If provided, the canonical mask for this operation is
227- // / permuted using `maybeMaskingMap `.
227+ // / permuted using `maybeIndexingMap` (after dropping its zero results).
228228 Operation *
229229 maskOperation (RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp,
230- std::optional<AffineMap> maybeMaskingMap = std::nullopt );
230+ std::optional<AffineMap> maybeIndexingMap = std::nullopt );
231231
232232private:
233233 // / Initializes the iteration space static sizes using the Linalg op
@@ -422,16 +422,28 @@ Value VectorizationState::getOrCreateMaskFor(
422422 return mask;
423423}
424424
425- // / Masks an operation with the canonical vector mask if the operation needs
426- // / masking. Returns the masked operation or the original operation if masking
427- // / is not needed. If provided, the canonical mask for this operation is
428- // / permuted using `maybeMaskingMap`.
429425Operation *
430426VectorizationState::maskOperation (RewriterBase &rewriter, Operation *opToMask,
431427 LinalgOp linalgOp,
432- std::optional<AffineMap> maybeMaskingMap ) {
428+ std::optional<AffineMap> maybeIndexingMap ) {
433429 LDBG (" Trying to mask: " << *opToMask << " \n " );
434430
431+ std::optional<AffineMap> maybeMaskingMap = std::nullopt ;
432+ // The operand indexing map may contain "zero" results, e.g.:
433+ // (d0, d1, d2, d3) -> (d0, d1, d2, 0)
434+ // When applied to canonical vector shapes like these:
435+ // (1, 16, 16, 4)
436+ // we would get:
437+ // (1, 16, 16, 0)
438+ // Instead, we should extract the following permutation map for masking:
439+ // (d0, d1, d2, d3) -> (d0, d1, d2)
440+ // This way, the corresponding vector/mask type will be:
441+ // vector<1x16x16xty>
442+ // rather than:
443+ // vector<1x16x16x0xty>
444+ if (maybeIndexingMap)
445+ maybeMaskingMap = maybeIndexingMap->dropZeroResults ();
446+
435447 // Create or retrieve mask for this operation.
436448 Value mask =
437449 getOrCreateMaskFor (rewriter, opToMask, linalgOp, maybeMaskingMap);
@@ -630,21 +642,8 @@ static Value buildVectorWrite(RewriterBase &rewriter, Value value,
630642 loc, value, outputOperand->get (), ValueRange{});
631643 }
632644
633- // The operand map may contain "zero" results, e.g.:
634- // (d0, d1, d2, d3) -> (d0, d1, d2, 0)
635- // When applied to canonical vector shapes like these:
636- // (1, 16, 16, 4)
637- // we would get:
638- // (1, 16, 16, 0)
639- // Instead, we should extract the following map:
640- // (d0, d1, d2, d3) -> (d0, d1, d2)
641- // This way, the corresponding vector/mask type will be:
642- // vector<1x16x16xty>
643- // rather than:
644- // vector<1x16x16x0xty>
645- AffineMap opOperantMapWithoutZeros = opOperandMap.dropZeroResults ();
646645 write =
647- state.maskOperation (rewriter, write, linalgOp, opOperantMapWithoutZeros );
646+ state.maskOperation (rewriter, write, linalgOp, opOperandMap );
648647
649648 // If masked, set in-bounds to true. Masking guarantees that the access will
650649 // be in-bounds.
@@ -1330,16 +1329,6 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
13301329 // permutation map and masking map.
13311330 AffineMap indexingMap = linalgOp.getMatchingIndexingMap (opOperand);
13321331
1333- // Remove zeros from indexing map to use it as masking map.
1334- SmallVector<int64_t > zeroPos;
1335- auto results = indexingMap.getResults ();
1336- for (const auto &result : llvm::enumerate (results)) {
1337- if (isa<AffineConstantExpr>(result.value ())) {
1338- zeroPos.push_back (result.index ());
1339- }
1340- }
1341- AffineMap maskingMap = indexingMap.dropResults (zeroPos);
1342-
13431332 AffineMap readMap;
13441333 VectorType readType;
13451334 Type elemType = getElementTypeOrSelf (opOperand->get ());
@@ -1369,7 +1358,7 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
13691358 Operation *read = rewriter.create <vector::TransferReadOp>(
13701359 loc, readType, opOperand->get (), indices, readMap,
13711360 ArrayRef<bool >(inBounds));
1372- read = state.maskOperation (rewriter, read, linalgOp, maskingMap );
1361+ read = state.maskOperation (rewriter, read, linalgOp, indexingMap );
13731362 Value readValue = read->getResult (0 );
13741363
13751364 // 3.b. If masked, set in-bounds to true. Masking guarantees that the access
0 commit comments