@@ -61,6 +61,11 @@ struct LoopPipelinerInternal {
6161 // / `idx` of `key` in the epilogue.
6262 void setValueMapping (Value key, Value el, int64_t idx);
6363
64+ // / Return the defining op of the given value. If the Value is an argument of
65+ // / the loop, return the associated defining op in the loop and its distance to
66+ // / the Value.
67+ std::pair<Operation *, int64_t > getDefiningOpAndDistance (Value value);
68+
6469public:
6570 // / Initialize the information for the given `op`, return true if it
6671 // / satisfies the pre-condition to apply pipelining.
@@ -240,11 +245,12 @@ LoopPipelinerInternal::analyzeCrossStageValues() {
240245 unsigned stage = stages[op];
241246
242247 auto analyzeOperand = [&](OpOperand &operand) {
243- Operation * def = operand.get (). getDefiningOp ( );
248+ auto [ def, distance] = getDefiningOpAndDistance ( operand.get ());
244249 if (!def)
245250 return ;
246251 auto defStage = stages.find (def);
247- if (defStage == stages.end () || defStage->second == stage)
252+ if (defStage == stages.end () || defStage->second == stage ||
253+ defStage->second == stage + distance)
248254 return ;
249255 assert (stage > defStage->second );
250256 LiverangeInfo &info = crossStageValues[operand.get ()];
@@ -261,6 +267,25 @@ LoopPipelinerInternal::analyzeCrossStageValues() {
261267 return crossStageValues;
262268}
263269
270+ std::pair<Operation *, int64_t >
271+ LoopPipelinerInternal::getDefiningOpAndDistance (Value value) {
272+ int64_t distance = 0 ;
273+ if (auto arg = dyn_cast<BlockArgument>(value)) {
274+ if (arg.getOwner () != forOp.getBody ())
275+ return {nullptr , 0 };
276+ // Ignore induction variable.
277+ if (arg.getArgNumber () == 0 )
278+ return {nullptr , 0 };
279+ distance++;
280+ value =
281+ forOp.getBody ()->getTerminator ()->getOperand (arg.getArgNumber () - 1 );
282+ }
283+ Operation *def = value.getDefiningOp ();
284+ if (!def)
285+ return {nullptr , 0 };
286+ return {def, distance};
287+ }
288+
264289scf::ForOp LoopPipelinerInternal::createKernelLoop (
265290 const llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo>
266291 &crossStageValues,
@@ -366,10 +391,9 @@ LogicalResult LoopPipelinerInternal::createKernel(
366391 rewriter.setInsertionPointAfter (newOp);
367392 continue ;
368393 }
369- auto arg = dyn_cast<BlockArgument>(operand->get ());
394+ Value source = operand->get ();
395+ auto arg = dyn_cast<BlockArgument>(source);
370396 if (arg && arg.getOwner () == forOp.getBody ()) {
371- // If the value is a loop carried value coming from stage N + 1 remap,
372- // it will become a direct use.
373397 Value ret = forOp.getBody ()->getTerminator ()->getOperand (
374398 arg.getArgNumber () - 1 );
375399 Operation *dep = ret.getDefiningOp ();
@@ -378,15 +402,19 @@ LogicalResult LoopPipelinerInternal::createKernel(
378402 auto stageDep = stages.find (dep);
379403 if (stageDep == stages.end () || stageDep->second == useStage)
380404 continue ;
381- assert (stageDep->second == useStage + 1 );
382- nestedNewOp->setOperand (operand->getOperandNumber (),
383- mapping.lookupOrDefault (ret));
384- continue ;
405+ // If the value is a loop carried value coming from stage N + 1 remap,
406+ // it will become a direct use.
407+ if (stageDep->second == useStage + 1 ) {
408+ nestedNewOp->setOperand (operand->getOperandNumber (),
409+ mapping.lookupOrDefault (ret));
410+ continue ;
411+ }
412+ source = ret;
385413 }
386414 // For operands defined in a previous stage we need to remap it to use
387415 // the correct region argument. We look for the right version of the
388416 // Value based on the stage where it is used.
389- Operation *def = operand-> get () .getDefiningOp ();
417+ Operation *def = source .getDefiningOp ();
390418 if (!def)
391419 continue ;
392420 auto stageDef = stages.find (def);
@@ -418,9 +446,29 @@ LogicalResult LoopPipelinerInternal::createKernel(
418446 // We create a mapping between original values and the associated loop
419447 // returned values that will be needed by the epilogue.
420448 llvm::SmallVector<Value> yieldOperands;
421- for (Value retVal : forOp.getBody ()->getTerminator ()->getOperands ()) {
422- yieldOperands.push_back (mapping.lookupOrDefault (retVal));
449+ for (OpOperand &yieldOperand :
450+ forOp.getBody ()->getTerminator ()->getOpOperands ()) {
451+ Value source = mapping.lookupOrDefault (yieldOperand.get ());
452+ // When we don't peel the epilogue and the yield value is used outside the
453+ // loop we need to make sure we return the version from numStages -
454+ // defStage.
455+ if (!peelEpilogue &&
456+ !forOp.getResult (yieldOperand.getOperandNumber ()).use_empty ()) {
457+ Operation *def = getDefiningOpAndDistance (yieldOperand.get ()).first ;
458+ if (def) {
459+ auto defStage = stages.find (def);
460+ if (defStage != stages.end () && defStage->second < maxStage) {
461+ Value pred = predicates[defStage->second ];
462+ source = rewriter.create <arith::SelectOp>(
463+ pred.getLoc (), pred, source,
464+ newForOp.getBody ()
465+ ->getArguments ()[yieldOperand.getOperandNumber () + 1 ]);
466+ }
467+ }
468+ }
469+ yieldOperands.push_back (source);
423470 }
471+
424472 for (auto &it : crossStageValues) {
425473 int64_t version = maxStage - it.second .lastUseStage + 1 ;
426474 unsigned numVersionReturned = it.second .lastUseStage - it.second .defStage ;
@@ -444,9 +492,11 @@ LogicalResult LoopPipelinerInternal::createKernel(
444492 Operation *def = retVal.value ().getDefiningOp ();
445493 assert (def && " Only support loop carried dependencies of distance 1" );
446494 unsigned defStage = stages[def];
447- setValueMapping (forOp.getRegionIterArgs ()[retVal.index ()],
448- newForOp->getResult (retVal.index ()),
449- maxStage - defStage + 1 );
495+ if (defStage > 0 ) {
496+ setValueMapping (forOp.getRegionIterArgs ()[retVal.index ()],
497+ newForOp->getResult (retVal.index ()),
498+ maxStage - defStage + 1 );
499+ }
450500 }
451501 rewriter.create <scf::YieldOp>(forOp.getLoc (), yieldOperands);
452502 return success ();
0 commit comments