Skip to content

[mlir][scf] Return replacements explicitly in SCFTilingResult. #143217

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 9 additions & 7 deletions mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -136,15 +136,17 @@ struct SCFTilingResult {
SmallVector<Value> initialValues;
/// The `scf.for` operations that iterate over the tiles.
SmallVector<LoopLikeOpInterface> loops;
/// The result generated by the loop nest in tiling, may hold partial results,
/// which need to be merged to match the computation of the untiled operation.
/// `mergeResult` contains the operations used to perform this merge from
/// partial results and the values that can be used as replacements of
/// the untiled operation.
MergeResult mergeResult;
/// Values to use as replacements for the untiled op. Is the same size as the
/// number of results of the untiled op.
SmallVector<Value> replacements;
/// Slices generated after tiling that can be used for fusing with the tiled
/// producer.
SmallVector<Operation *> generatedSlices;
/// In cases where there as an additional merge step after tiling
/// return the merged ops after tiling. This list is empty when reduction
/// tiling strategy is
/// `scf::SCFTilingOptions::ReductionTilingStrategy::FullReduction.
SmallVector<Operation *> mergeOps;
};

/// Method to tile an op that implements the `TilingInterface` using
Expand Down Expand Up @@ -362,7 +364,7 @@ lowerToLoopsUsingSCFForOp(RewriterBase &rewriter, TilingInterface op);
/// ```
FailureOr<scf::SCFTilingResult>
tileReductionUsingScf(RewriterBase &b, PartialReductionOpInterface op,
ArrayRef<OpFoldResult> tileSize);
ArrayRef<OpFoldResult> tileSizes);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice haha


} // namespace scf
} // namespace mlir
Expand Down
3 changes: 2 additions & 1 deletion mlir/include/mlir/Interfaces/TilingInterface.td
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,8 @@ def TilingInterface : OpInterface<"TilingInterface"> {
];
}

def PartialReductionOpInterface : OpInterface<"PartialReductionOpInterface"> {
def PartialReductionOpInterface :
OpInterface<"PartialReductionOpInterface", [TilingInterface]> {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice! I didn't know we could do this

let description = [{
Interface for allowing operations to expose information needed to
tile reductions using partial reduction followed by merge. This is
Expand Down
10 changes: 5 additions & 5 deletions mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2381,7 +2381,7 @@ transform::ScalarizeOp::applyToOne(transform::TransformRewriter &rewriter,
return emitDefaultDefiniteFailure(target);

if (target->getNumResults())
rewriter.replaceOp(target, maybeTilingResult->mergeResult.replacements);
rewriter.replaceOp(target, maybeTilingResult->replacements);
else
rewriter.eraseOp(target);

Expand Down Expand Up @@ -2800,12 +2800,12 @@ DiagnosedSilenceableFailure transform::TileReductionUsingForOp::applyToOne(

if (failed(result))
return emitDefaultSilenceableFailure(target);
rewriter.replaceOp(target, result->mergeResult.replacements);
rewriter.replaceOp(target, result->replacements);
for (Value initValue : result->initialValues)
results.push_back(initValue.getDefiningOp());
for (auto parallelTiledOp : result->tiledOps)
results.push_back(parallelTiledOp);
for (auto mergeOp : result->mergeResult.mergeOps)
for (auto mergeOp : result->mergeOps)
results.push_back(mergeOp);
results.push_back(result->loops.front());
return DiagnosedSilenceableFailure::success();
Expand Down Expand Up @@ -3229,7 +3229,7 @@ transform::TileUsingForOp::apply(transform::TransformRewriter &rewriter,
if (failed(maybeTilingResult))
return DiagnosedSilenceableFailure::definiteFailure();

rewriter.replaceOp(op, maybeTilingResult->mergeResult.replacements);
rewriter.replaceOp(op, maybeTilingResult->replacements);

tiled.append(maybeTilingResult->tiledOps);
for (const auto &en2 : llvm::enumerate(maybeTilingResult->loops))
Expand Down Expand Up @@ -3465,7 +3465,7 @@ DiagnosedSilenceableFailure transform::tileToForallOpImpl(
if (failed(maybeTilingResult))
return transformOp.emitDefaultSilenceableFailure(tileableOp);

rewriter.replaceOp(tileableOp, maybeTilingResult->mergeResult.replacements);
rewriter.replaceOp(tileableOp, maybeTilingResult->replacements);

tilingResult = *maybeTilingResult;

Expand Down
58 changes: 30 additions & 28 deletions mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1058,48 +1058,50 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
assert(succeeded(tilingResult) &&
"expected tiling result to be computed after loop generation");

SmallVector<Value> partialResults;
if (loops.empty()) {
// If loops are empty, the tiled op is used as the replacement for the
// untiled op.
partialResults = tilingResult->tiledValues;
} else {
partialResults = llvm::map_to_vector(loops.front()->getResults(),
return scf::SCFTilingResult{tilingResult->tiledOps, initTensors, loops,
tilingResult->tiledValues,
tilingResult->generatedSlices};
}

auto loopResults = llvm::map_to_vector(loops.front()->getResults(),
[](OpResult r) -> Value { return r; });

// For the full reduction case, there is nothing more to do.
if (options.reductionStrategy ==
scf::SCFTilingOptions::ReductionTilingStrategy::FullReduction) {
return scf::SCFTilingResult{tilingResult->tiledOps, initTensors, loops,
loopResults, tilingResult->generatedSlices};
}

// The results of the loop needs to be merged.
FailureOr<MergeResult> mergeResult =
mergeTilingResults(rewriter, op, partialResults, options);
mergeTilingResults(rewriter, op, loopResults, options);
if (failed(mergeResult)) {
return rewriter.notifyMatchFailure(
op, "Failed to merge partial results from tiling");
}

return scf::SCFTilingResult{tilingResult->tiledOps, initTensors, loops,
mergeResult.value(),
tilingResult->generatedSlices};
return scf::SCFTilingResult{tilingResult->tiledOps,
initTensors,
loops,
mergeResult->replacements,
tilingResult->generatedSlices,
mergeResult->mergeOps};
}

FailureOr<scf::SCFTilingResult>
mlir::scf::tileReductionUsingScf(RewriterBase &b,
PartialReductionOpInterface op,
ArrayRef<OpFoldResult> tileSizes) {
SCFTilingOptions options;
options.setLoopType(SCFTilingOptions::LoopType::ForOp);
options.setReductionTilingStrategy(SCFTilingOptions::ReductionTilingStrategy::
PartialReductionOuterReduction);
options.setTileSizes(tileSizes);

TilingInterface tilingInterfaceOp =
dyn_cast<TilingInterface>(op.getOperation());
if (!tilingInterfaceOp) {
return b.notifyMatchFailure(
op,
"Operation implementing PartialReductionOpInterface should implement "
"TilingInterface");
}

return tileUsingSCF(b, tilingInterfaceOp, options);
ArrayRef<OpFoldResult> tileSize) {
scf::SCFTilingOptions options;
options.setLoopType(scf::SCFTilingOptions::LoopType::ForOp);
options.setReductionTilingStrategy(
scf::SCFTilingOptions::ReductionTilingStrategy::
PartialReductionOuterReduction);
options.setTileSizes(tileSize);
return tileUsingSCF(b, op, options);
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -1539,8 +1541,8 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
tiledAndFusedOps.insert_range(tilingResult->tiledOps);

DenseMap<Value, Value> replacements;
for (auto [origVal, replacement] : llvm::zip_equal(
consumer->getResults(), tilingResult->mergeResult.replacements)) {
for (auto [origVal, replacement] :
llvm::zip_equal(consumer->getResults(), tilingResult->replacements)) {
replacements[origVal] = replacement;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -260,8 +260,7 @@ applyTileToAll(RewriterBase &rewriter, Operation *transformOp,
return failure();

// Perform the replacement of tiled and fused values.
rewriter.replaceOp(tilingInterfaceOp,
tiledResults->mergeResult.replacements);
rewriter.replaceOp(tilingInterfaceOp, tiledResults->replacements);

// Report back the relevant handles to the transform op.
tiledOps.push_back(tiledResults->tiledOps.front());
Expand Down