@@ -224,10 +224,10 @@ struct VectorizationState {
224224 // / Masks an operation with the canonical vector mask if the operation needs
225225 // / masking. Returns the masked operation or the original operation if masking
226226 // / is not needed. If provided, the canonical mask for this operation is
227- // / permuted using `maybeMaskingMap `.
227+ // / permuted using `maybeIndexingMap `.
228228 Operation *
229229 maskOperation (RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp,
230- std::optional<AffineMap> maybeMaskingMap = std::nullopt );
230+ std::optional<AffineMap> maybeIndexingMap = std::nullopt );
231231
232232private:
233233 // / Initializes the iteration space static sizes using the Linalg op
@@ -422,16 +422,28 @@ Value VectorizationState::getOrCreateMaskFor(
422422 return mask;
423423}
424424
425- // / Masks an operation with the canonical vector mask if the operation needs
426- // / masking. Returns the masked operation or the original operation if masking
427- // / is not needed. If provided, the canonical mask for this operation is
428- // / permuted using `maybeMaskingMap`.
429425Operation *
430426VectorizationState::maskOperation (RewriterBase &rewriter, Operation *opToMask,
431427 LinalgOp linalgOp,
432- std::optional<AffineMap> maybeMaskingMap ) {
428+ std::optional<AffineMap> maybeIndexingMap ) {
433429 LDBG (" Trying to mask: " << *opToMask << " \n " );
434430
431+ std::optional<AffineMap> maybeMaskingMap = std::nullopt ;
432+ // The Operand indexing map may contain "zero" results, e.g.:
433+ // (d0, d1, d2, d3) -> (d0, d1, d2, 0)
434+ // When applied to canonical vector shapes like these:
435+ // (1, 16, 16, 4)
436+ // we would get:
437+ // (1, 16, 16, 0)
438+ // Instead, we should extract the following permutation map for masking:
439+ // (d0, d1, d2, d3) -> (d0, d1, d2)
440+ // This way, the corresponding vector/mask type will be:
441+ // vector<1x16x16xty>
442+ // rather than:
443+ // vector<1x16x16x0xty>
444+ if (maybeIndexingMap)
445+ maybeMaskingMap = maybeIndexingMap->dropZeroResults ();
446+
435447 // Create or retrieve mask for this operation.
436448 Value mask =
437449 getOrCreateMaskFor (rewriter, opToMask, linalgOp, maybeMaskingMap);
@@ -640,21 +652,7 @@ static Value buildVectorWrite(RewriterBase &rewriter, Value value,
640652 loc, value, outputOperand->get (), ValueRange{});
641653 }
642654
643- // The operand map may contain "zero" results, e.g.:
644- // (d0, d1, d2, d3) -> (d0, d1, d2, 0)
645- // When applied to canonical vector shapes like these:
646- // (1, 16, 16, 4)
647- // we would get:
648- // (1, 16, 16, 0)
649- // Instead, we should extract the following map:
650- // (d0, d1, d2, d3) -> (d0, d1, d2)
651- // This way, the corresponding vector/mask type will be:
652- // vector<1x16x16xty>
653- // rather than:
654- // vector<1x16x16x0xty>
655- AffineMap opOperantMapWithoutZeros = opOperandMap.dropZeroResults ();
656- write =
657- state.maskOperation (rewriter, write, linalgOp, opOperantMapWithoutZeros);
655+ write = state.maskOperation (rewriter, write, linalgOp, opOperandMap);
658656
659657 // If masked, set in-bounds to true. Masking guarantees that the access will
660658 // be in-bounds.
@@ -1332,16 +1330,6 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
13321330 // permutation map and masking map.
13331331 AffineMap indexingMap = linalgOp.getMatchingIndexingMap (opOperand);
13341332
1335- // Remove zeros from indexing map to use it as masking map.
1336- SmallVector<int64_t > zeroPos;
1337- auto results = indexingMap.getResults ();
1338- for (const auto &result : llvm::enumerate (results)) {
1339- if (isa<AffineConstantExpr>(result.value ())) {
1340- zeroPos.push_back (result.index ());
1341- }
1342- }
1343- AffineMap maskingMap = indexingMap.dropResults (zeroPos);
1344-
13451333 AffineMap readMap;
13461334 VectorType readType;
13471335 Type elemType = getElementTypeOrSelf (opOperand->get ());
@@ -1371,7 +1359,7 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
13711359 Operation *read = rewriter.create <vector::TransferReadOp>(
13721360 loc, readType, opOperand->get (), indices, readMap,
13731361 ArrayRef<bool >(inBounds));
1374- read = state.maskOperation (rewriter, read, linalgOp, maskingMap );
1362+ read = state.maskOperation (rewriter, read, linalgOp, indexingMap );
13751363 Value readValue = read->getResult (0 );
13761364
13771365 // 3.b. If masked, set in-bounds to true. Masking guarantees that the access
0 commit comments