Skip to content

Commit

Permalink
[stream] Deduplicate stream.yield values from stream.async.execute
Browse files Browse the repository at this point in the history
If cleanup leaves the same value yielded more than once from a
`stream.async.execute`, the blocks backing that value can be allocated
twice (this occurs because ScheduleAllocation assumes all returns are
unique). The double allocation also results in bad numerical behavior, as
the separate returns may be used by different dispatches and fetch
incorrect values.

Signed-off-by: Rob Suderman <rob.suderman@gmail.com>
  • Loading branch information
rsuderman committed Feb 21, 2025
1 parent 284ec9c commit a80c055
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 0 deletions.
81 changes: 81 additions & 0 deletions compiler/src/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2275,10 +2275,91 @@ struct ElideNoOpAsyncExecuteOp : public OpRewritePattern<AsyncExecuteOp> {
}
};

/// Deduplicates values that are yielded more than once from the body of a
/// `stream.async.execute` op.
///
/// When the terminating `stream.yield` returns the same SSA value through
/// several operands, the execute op is rebuilt so that each distinct value is
/// yielded exactly once. Every result of the original op is then replaced by
/// the new result that carries the first occurrence of its yielded value; the
/// result timepoint is forwarded unchanged.
///
/// The pattern fails to match (and leaves the IR untouched) when all yielded
/// values are already distinct.
struct DeduplicateYieldCmdExecuteOp : public OpRewritePattern<AsyncExecuteOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AsyncExecuteOp op,
                                PatternRewriter &rewriter) const override {
    auto yieldOp =
        cast<IREE::Stream::YieldOp>(op.getBody().front().getTerminator());
    auto yieldedValues = yieldOp.getResourceOperands();
    const size_t oldYieldCount = yieldedValues.size();

    // Distinct yielded values, in order of first appearance.
    llvm::SmallVector<Value> uniqueValues;
    // Original yield/result index of each entry in `uniqueValues`.
    llvm::SmallVector<size_t> keptIndices;
    // For every original result: the index of the new result replacing it.
    llvm::SmallVector<unsigned> remapping;
    for (size_t i = 0; i < oldYieldCount; ++i) {
      Value yielded = yieldedValues[i];
      auto it = std::find(uniqueValues.begin(), uniqueValues.end(), yielded);
      if (it != uniqueValues.end()) {
        // Duplicate: reuse the result of the first occurrence.
        remapping.push_back(static_cast<unsigned>(it - uniqueValues.begin()));
        continue;
      }
      remapping.push_back(static_cast<unsigned>(uniqueValues.size()));
      keptIndices.push_back(i);
      uniqueValues.push_back(yielded);
    }

    // Nothing to do when every yielded value is already unique.
    if (uniqueValues.size() == oldYieldCount) {
      return failure();
    }

    // Gather the per-result information for the results that survive. The
    // tied-operand entries must be filtered together with the results;
    // forwarding the original list unchanged would leave more tie entries than
    // results and shift ties whenever a non-trailing result is dropped.
    // NOTE(review): assumes the tied-operands attribute holds one entry per
    // resource result, in result order — confirm against the TiedOpInterface
    // storage contract.
    auto oldTiedOperands = op.getTiedOperandsAttr();
    llvm::SmallVector<Type> newTypes;
    llvm::SmallVector<Value> newResultSizes;
    llvm::SmallVector<Value> newYieldSizes;
    llvm::SmallVector<int64_t> newTiedOperands;
    for (size_t i : keptIndices) {
      newTypes.push_back(op.getResults()[i].getType());
      newResultSizes.push_back(op.getResultSizes()[i]);
      newYieldSizes.push_back(yieldOp.getResourceOperandSizes()[i]);
      // The attribute is optional; only read it when present.
      if (oldTiedOperands && i < oldTiedOperands.getValue().size()) {
        newTiedOperands.push_back(
            llvm::cast<IntegerAttr>(oldTiedOperands.getValue()[i]).getInt());
      }
    }

    auto newExecuteOp = rewriter.create<IREE::Stream::AsyncExecuteOp>(
        op.getLoc(), newTypes, newResultSizes, op.getAwaitTimepoint(),
        op.getResourceOperands(), op.getResourceOperandSizes(),
        newTiedOperands);
    newExecuteOp.setAffinityAttr(op.getAffinityAttr());

    // Move (not clone) the body into the new op; `yieldOp` and the values it
    // references stay valid across the move.
    rewriter.inlineRegionBefore(op.getRegion(), newExecuteOp.getRegion(),
                                newExecuteOp.getRegion().end());

    // Rewrite the terminator so it yields each distinct value exactly once.
    rewriter.setInsertionPoint(yieldOp);
    rewriter.replaceOpWithNewOp<IREE::Stream::YieldOp>(yieldOp, uniqueValues,
                                                       newYieldSizes);

    // Replace all original results (duplicates included) followed by the
    // result timepoint.
    llvm::SmallVector<Value> replacements;
    for (unsigned newIndex : remapping) {
      replacements.push_back(newExecuteOp.getResult(newIndex));
    }
    replacements.push_back(newExecuteOp.getResultTimepoint());
    rewriter.replaceOp(op, replacements);

    return success();
  }
};

} // namespace

void AsyncExecuteOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.insert<DeduplicateYieldCmdExecuteOp>(context);
results.insert<ElideImmediateTimepointWait<AsyncExecuteOp>>(context);
results.insert<ChainDependentAwaits<AsyncExecuteOp>>(context);
results.insert<CloneCapturedAsyncExecuteSubviewOps>(context);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -483,6 +483,23 @@ util.func private @ElideImmediateAsyncExecuteWaits(%arg0: !stream.resource<*>, %
util.return %0#0, %0#1 : !stream.resource<*>, !stream.timepoint
}

// -----

// Tests that duplicate values yielded from a stream.async.execute region are
// deduplicated: the region below yields %1 twice (first and third operands),
// and the CHECK lines expect the resulting stream.yield to carry only the two
// distinct dispatch results.
// NOTE(review): the remapping of %0#2 onto the first result is not covered by
// a CHECK on util.return — consider adding one.
// NOTE(review): the function name contains typos ("Dedeuplicate", "ASync");
// left as-is here because CHECK-LABEL depends on it.
// CHECK-LABEL: @DedeuplicateASyncExecuteReturns
util.func private @DedeuplicateASyncExecuteReturns(%arg0: !stream.resource<*>, %arg1: index) -> (!stream.resource<*>, !stream.resource<*>, !stream.resource<*>, !stream.timepoint) {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
// The execute op starts with three resource results plus a timepoint.
// CHECK: %[[VAL:.+]], %[[TP:.+]] = stream.async.execute
%0:4 = stream.async.execute with(%arg0 as %arg2: !stream.resource<*>{%arg1}) -> %arg0 as !stream.resource<*>{%arg1}, %arg0 as !stream.resource<*>{%arg1}, %arg0 as !stream.resource<*>{%arg1} {
// VAL is rebound here to the two-result dispatch inside the region.
// CHECK: %[[VAL:.+]]:2 = stream.async.dispatch
%1, %2 = stream.async.dispatch @executable::@dispatch0[%c1, %c1, %c1](%arg2[%c0 to %arg1 for %arg1]) : (!stream.resource<*>{%arg1}) -> !stream.resource<*>{%arg1}, !stream.resource<*>{%arg1}
// After canonicalization only the two distinct values should be yielded.
// CHECK: stream.yield %[[VAL]]#0, %[[VAL]]#1
stream.yield %1, %2, %1 : !stream.resource<*>{%arg1}, !stream.resource<*>{%arg1}, !stream.resource<*>{%arg1}
} => !stream.timepoint
util.return %0#0, %0#1, %0#2, %0#3 : !stream.resource<*>, !stream.resource<*>, !stream.resource<*>, !stream.timepoint
}


// -----

// CHECK-LABEL: @ChainAsyncExecuteWaits
Expand Down

0 comments on commit a80c055

Please sign in to comment.