@@ -55,6 +55,30 @@ fillInterchangeVector(ArrayRef<int64_t> interchangeVector,
5555 return filledVector;
5656}
5757
58+ // / Convert a list of ops of type `SrcOpTy` to list of `Operation *`.
59+ template <typename SrcOpTy>
60+ static SmallVector<Operation *> getAsOperations (ArrayRef<SrcOpTy> ops) {
61+ return llvm::to_vector (
62+ llvm::map_range (ops, [](auto op) -> Operation * { return op; }));
63+ }
64+ template <typename SrcOpTy>
65+ static SmallVector<Operation *>
66+ getAsOperations (const SmallVector<SrcOpTy> &ops) {
67+ return getAsOperations (ArrayRef<SrcOpTy>(ops));
68+ }
69+
70+ // / Convert a list of `Operation *` to a list of `DstOpTy.
71+ template <typename DstOpTy>
72+ static SmallVector<DstOpTy> castToTypedOperations (ArrayRef<Operation *> ops) {
73+ return llvm::to_vector (
74+ llvm::map_range (ops, [](Operation *op) { return cast<DstOpTy>(op); }));
75+ }
76+ template <typename DstOpTy>
77+ static SmallVector<DstOpTy>
78+ castToTypedOperations (const SmallVector<Operation *> &ops) {
79+ return castToTypedOperations<DstOpTy>(ArrayRef<Operation *>(ops));
80+ }
81+
5882// ===----------------------------------------------------------------------===//
5983// tileUsingSCFForOp implementation.
6084// ===----------------------------------------------------------------------===//
@@ -77,10 +101,9 @@ static bool tileDividesIterationDomain(Range loopRange) {
77101// / `tileSize`, i.e., `min(tileSize, range.end() - iv)`.
78102static OpFoldResult getBoundedTileSize (OpBuilder &b, Location loc,
79103 Range loopRange, Value iv,
80- Value tileSize) {
81- std::optional<int64_t > ts = getConstantIntValue (tileSize);
82- if (ts && ts.value () == 1 )
83- return getAsOpFoldResult (tileSize);
104+ OpFoldResult tileSize) {
105+ if (isConstantIntValue (tileSize, 1 ))
106+ return tileSize;
84107
85108 if (tileDividesIterationDomain (
86109 Range{loopRange.offset , loopRange.size , tileSize}))
@@ -296,8 +319,8 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
296319 tileSizeVector.append (numLoops - tileSizeVector.size (), zero);
297320 }
298321
299- scf::SCFTilingResult tilingResult;
300322 SmallVector<OpFoldResult> offsets, sizes;
323+ SmallVector<scf::ForOp> forLoops;
301324 {
302325 // If there is an interchange specified, permute the iteration domain and
303326 // the tile sizes.
@@ -320,8 +343,8 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
320343 // 3. Materialize an empty loop nest that iterates over the tiles. These
321344 // loops for now do not return any values even if the original operation has
322345 // results.
323- tilingResult. loops = generateTileLoopNest (
324- rewriter, op. getLoc (), iterationDomain, tileSizeVector, offsets, sizes);
346+ forLoops = generateTileLoopNest (rewriter, op. getLoc (), iterationDomain,
347+ tileSizeVector, offsets, sizes);
325348
326349 if (!interchangeVector.empty ()) {
327350 auto inversePermutation = invertPermutationVector (interchangeVector);
@@ -331,30 +354,30 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
331354 }
332355
333356 LLVM_DEBUG ({
334- if (!tilingResult. loops .empty ()) {
357+ if (!forLoops .empty ()) {
335358 llvm::dbgs () << " LoopNest shell :\n " ;
336- tilingResult. loops .front ().dump ();
359+ forLoops .front ().dump ();
337360 llvm::dbgs () << " \n " ;
338361 }
339362 });
340363
341364 // 4. Generate the tiled implementation within the inner most loop.
342- if (!tilingResult.loops .empty ())
343- rewriter.setInsertionPoint (
344- tilingResult.loops .back ().getBody ()->getTerminator ());
365+ if (!forLoops.empty ())
366+ rewriter.setInsertionPoint (forLoops.back ().getBody ()->getTerminator ());
345367 FailureOr<TilingResult> tiledImplementation =
346368 op.getTiledImplementation (rewriter, offsets, sizes);
347- tilingResult. tiledOps . append (tiledImplementation-> tiledOps );
369+
348370 if (op->getNumResults () == 0 ) {
349- // nothing more to do.
350- return tilingResult ;
371+ return scf::SCFTilingResult{
372+ tiledImplementation-> tiledOps , getAsOperations (forLoops), {}} ;
351373 }
352374
353375 // If loops are empty, the tiled op is used as the replacement for the untiled
354376 // op.
355- if (tilingResult.loops .empty ()) {
356- tilingResult.replacements = tiledImplementation->tiledValues ;
357- return tilingResult;
377+ if (forLoops.empty ()) {
378+ return scf::SCFTilingResult{tiledImplementation->tiledOps ,
379+ getAsOperations (forLoops),
380+ tiledImplementation->tiledValues };
358381 }
359382
360383 // 5. Yield all the results of the tiled operation. The surrounding loop
@@ -378,18 +401,18 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
378401 destinationTensors)))
379402 return rewriter.notifyMatchFailure (op, " failed to get destinations" );
380403
381- tilingResult. replacements = yieldTiledValues (
404+ SmallVector<Value> replacements = yieldTiledValues (
382405 rewriter, destinationTensors, tiledImplementation.value (),
383- resultOffsetsList, resultSizesList, tilingResult.loops );
384-
406+ resultOffsetsList, resultSizesList, forLoops);
385407 LLVM_DEBUG ({
386- if (!tilingResult. loops .empty ()) {
408+ if (!forLoops .empty ()) {
387409 llvm::dbgs () << " After tiled implementation :\n " ;
388- tilingResult. loops .front ().dump ();
410+ forLoops .front ().dump ();
389411 llvm::dbgs () << " \n " ;
390412 }
391413 });
392- return tilingResult;
414+ return scf::SCFTilingResult{tiledImplementation->tiledOps ,
415+ getAsOperations (forLoops), replacements};
393416}
394417
395418FailureOr<scf::SCFReductionTilingResult>
@@ -467,6 +490,7 @@ mlir::scf::tileReductionUsingScf(RewriterBase &b,
467490 results.mergeOp = mergeOp;
468491 return results;
469492}
493+
470494// ===----------------------------------------------------------------------===//
471495// tileConsumerAndFuseProducerGreedilyUsingSCFForOp implementation.
472496// ===----------------------------------------------------------------------===//
@@ -637,28 +661,31 @@ mlir::scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
637661 }
638662
639663 // 1. First tile the consumer.
640- scf::SCFTileAndFuseResult tileAndFuseResult;
664+ SmallVector<scf::ForOp> forLoops;
665+ SetVector<Operation *> fusedProducers, tiledAndFusedOps;
666+ DenseMap<Value, Value> replacements;
641667 llvm::SmallDenseMap<Value, int64_t > yieldedValueToResultNumber;
642668 {
643669 FailureOr<scf::SCFTilingResult> tilingResult =
644670 tileUsingSCFForOp (rewriter, consumer, options.tilingOptions );
645671 if (failed (tilingResult))
646672 return rewriter.notifyMatchFailure (consumer, " failed to tile consumer" );
647673 for (auto *tiledOp : tilingResult->tiledOps )
648- tileAndFuseResult.tiledAndFusedOps .insert (tiledOp);
649- tileAndFuseResult.loops = std::move (tilingResult->loops );
650- for (const auto &result : llvm::enumerate (
651- llvm::zip (consumer->getResults (), tilingResult->replacements ))) {
652- tileAndFuseResult.replacements [std::get<0 >(result.value ())] =
653- std::get<1 >(result.value ());
674+ tiledAndFusedOps.insert (tiledOp);
675+ forLoops = castToTypedOperations<scf::ForOp>(tilingResult->loops );
676+ for (auto [index, origValue, replacement] :
677+ llvm::enumerate (consumer->getResults (), tilingResult->replacements )) {
678+ replacements[origValue] = replacement;
654679 yieldedValueToResultNumber[tilingResult->tiledOps .back ()->getResult (
655- result. index ()) ] = result. index () ;
680+ index) ] = index;
656681 }
657682 }
658683
659684 // If there are no loops generated, fusion is immaterial.
660- if (tileAndFuseResult.loops .empty ())
661- return tileAndFuseResult;
685+ if (forLoops.empty ()) {
686+ return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps,
687+ getAsOperations (forLoops), replacements};
688+ }
662689
663690 // 2. Typically, the operands of the tiled operation are slices of the
664691 // operands of the untiled operation. These are expressed in IR using
@@ -675,7 +702,7 @@ mlir::scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
675702 };
676703
677704 std::deque<tensor::ExtractSliceOp> candidates;
678- addCandidateSlices (tileAndFuseResult. tiledAndFusedOps .back (), candidates);
705+ addCandidateSlices (tiledAndFusedOps.back (), candidates);
679706 OpBuilder::InsertionGuard g (rewriter);
680707 while (!candidates.empty ()) {
681708 // Traverse the slices in BFS fashion.
@@ -685,19 +712,20 @@ mlir::scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
685712 // The operands of the fused producer might themselved be slices of
686713 // values produced by operations that implement the `TilingInterface`.
687714 // Add these operations to the worklist.
688- std::optional<scf::SCFFuseProducerOfSliceResult> fusedProducer =
689- tileAndFuseProducerOfSlice (rewriter, candidateSliceOp,
690- tileAndFuseResult.loops );
691- if (!fusedProducer)
715+ std::optional<scf::SCFFuseProducerOfSliceResult> fusedResult =
716+ tileAndFuseProducerOfSlice (rewriter, candidateSliceOp, forLoops);
717+ if (!fusedResult)
692718 continue ;
693719
694720 if (Operation *tiledAndFusedOp =
695- fusedProducer->tiledAndFusedProducer .getDefiningOp ()) {
696- tileAndFuseResult.tiledAndFusedOps .insert (tiledAndFusedOp);
721+ fusedResult->tiledAndFusedProducer .getDefiningOp ()) {
722+ fusedProducers.insert (fusedResult->origProducer .getDefiningOp ());
723+ tiledAndFusedOps.insert (tiledAndFusedOp);
697724 addCandidateSlices (tiledAndFusedOp, candidates);
698725 }
699726 }
700- return tileAndFuseResult;
727+ return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps,
728+ getAsOperations (forLoops), replacements};
701729}
702730
703731// ===----------------------------------------------------------------------===//
0 commit comments