Skip to content

Commit 6ff4718

Browse files
authored
[mlir][linalg] Improve linalg.pack consumer fusion. (#148993)
If a dimension is not tiled, it is always valid to fuse the pack op, even if it has padding semantics. Because it always generates a full slice along the dimension. If a dimension is tiled and it does not need extra padding, the fusion is valid. The revision also formats corresponding tests for consistency. --------- Signed-off-by: hanhanW <hanhan0912@gmail.com>
1 parent 8f4deff commit 6ff4718

File tree

2 files changed

+376
-128
lines changed

2 files changed

+376
-128
lines changed

mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp

Lines changed: 42 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include "mlir/Dialect/Utils/IndexingUtils.h"
2121
#include "mlir/Dialect/Utils/StaticValueUtils.h"
2222
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
23+
#include "mlir/IR/BuiltinTypeInterfaces.h"
2324
#include "mlir/Interfaces/TilingInterface.h"
2425
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
2526
#include "llvm/Support/Debug.h"
@@ -887,26 +888,55 @@ struct PackOpTiling
887888

888889
ArrayRef<OpFoldResult> offsets(allOffsets[0]);
889890
ArrayRef<OpFoldResult> sizes(allSizes[0]);
890-
891891
auto packOp = cast<PackOp>(op);
892-
// It is not trivial to infer dest tile from source tile if `packOp` has
893-
// padding semantic.
894-
if (packOp.getPaddingValue())
895-
return failure();
896-
897892
Location loc = packOp.getLoc();
898-
899893
SmallVector<OpFoldResult> outerDimOffsets, outerDimSizes;
900894
DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
901895
packOp.getDimAndTileMapping();
902896
for (auto dim : llvm::seq<int64_t>(packOp.getSourceRank())) {
903897
if (dimAndTileMapping.count(dim)) {
904-
FailureOr<int64_t> cstSize =
898+
FailureOr<int64_t> cstTileSize =
905899
ValueBoundsConstraintSet::computeConstantBound(
906900
presburger::BoundType::UB, sizes[dim],
907901
/*stopCondition=*/nullptr, /*closedUB=*/true);
908902
std::optional<int64_t> cstInnerSize =
909903
getConstantIntValue(dimAndTileMapping[dim]);
904+
905+
// If a dimension is not tiled, it is always valid to fuse the pack op,
906+
// even if the op has padding semantics. Because it always generates a
907+
// full slice along the dimension.
908+
// TODO: It could be untiled if the `srcDimSize` is dynamic. It is a
909+
// hard check to determine if a dimension is tiled or not.
910+
int64_t srcDimSize = packOp.getSourceType().getDimSize(dim);
911+
int64_t destDimSize = packOp.getDestType().getDimSize(dim);
912+
bool isTiled = failed(cstTileSize) ||
913+
ShapedType::isDynamic(srcDimSize) ||
914+
cstTileSize.value() != srcDimSize;
915+
if (!isTiled) {
916+
outerDimOffsets.push_back(offsets[dim]);
917+
if (ShapedType::isStatic(destDimSize)) {
918+
outerDimSizes.push_back(b.getIndexAttr(destDimSize));
919+
} else {
920+
outerDimSizes.push_back(
921+
b.createOrFold<tensor::DimOp>(loc, packOp.getDest(), dim));
922+
}
923+
continue;
924+
}
925+
926+
// If the dimension needs padding, it is not supported because there are
927+
// iterations that only write padding values to the whole tile. The
928+
// consumer fusion is driven by the source, so it is not possible to map
929+
// an empty slice to the tile.
930+
bool needExtraPadding =
931+
ShapedType::isDynamic(destDimSize) || !cstInnerSize ||
932+
destDimSize * cstInnerSize.value() != srcDimSize;
933+
// Prioritize the case that the op already says that it does not need
934+
// padding.
935+
if (!packOp.getPaddingValue())
936+
needExtraPadding = false;
937+
if (needExtraPadding)
938+
return failure();
939+
910940
// Currently fusing `packOp` as consumer only expects perfect tiling
911941
// scenario because even if without padding semantic, the `packOp` may
912942
// also yield incomplete tiles. E.g. tensor<30xf32> -> tensor<5x6xf32>,
@@ -921,9 +951,9 @@ struct PackOpTiling
921951
// another word, we can only support tiling with consumer if the tile
922952
// size for the producer is a multiple of the inner tile size for the
923953
// packed dimensions at this moment.
924-
if (failed(cstSize) || !cstInnerSize || *cstSize % *cstInnerSize != 0) {
954+
if ((failed(cstTileSize) || !cstInnerSize ||
955+
*cstTileSize % *cstInnerSize != 0))
925956
return failure();
926-
}
927957

928958
using AV = affine::AffineValueExpr;
929959
affine::AffineBuilder ab(b, loc);
@@ -988,7 +1018,8 @@ struct PackOpTiling
9881018
loc, packOp.getDest(), outputOffsets, outputSizes, strides);
9891019
tiledOperands.push_back(outSlice);
9901020

991-
assert(!packOp.getPaddingValue() && "Expect no padding semantic");
1021+
if (auto val = packOp.getPaddingValue())
1022+
tiledOperands.push_back(val);
9921023
for (auto tile : packOp.getInnerTiles())
9931024
tiledOperands.push_back(tile);
9941025

0 commit comments

Comments
 (0)