Skip to content

Commit 6f58c16

Browse files
[mlir][linalg] Use ub.poison in data layout propagation if a packed operand requires padding. (#159467)
In the past, it was hard to set padding values because we did not have ub.poison. It is not always correct if we set zeros as padding values. Now we can use `ub.poison` in this case. The revision adds the support for setting padding value using `ub.poison` when padding is required in the propagation. Otherwise, it creates an invalid pack op. Additionally the revision adds a control option for allowing padding in the pattern which is false by default. To correctly do this, a new `requirePaddingValueStrict` method is added which assumes dynamic dims would mean padding is required. The revision also removes trailing white space in the lit test file. Co-authored-by : Nirvedh Meshram <nirvedh@gmail.com> --------- Signed-off-by: hanhanW <hanhan0912@gmail.com> Signed-off-by: Nirvedh Meshram <nirvedh@gmail.com> Co-authored-by: Nirvedh Meshram <nirvedh@gmail.com>
1 parent 876296e commit 6f58c16

File tree

6 files changed

+231
-78
lines changed

6 files changed

+231
-78
lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,14 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
239239
ArrayRef<int64_t> outerDimsPerm,
240240
ArrayRef<OpFoldResult> innerTiles);
241241

242+
// Same as above function but here dynamic dimensions are assumed
243+
// to require padding.
244+
static bool requirePaddingValueStrict(ArrayRef<int64_t> inputShape,
245+
ArrayRef<int64_t> innerDimsPos,
246+
ArrayRef<int64_t> outputShape,
247+
ArrayRef<int64_t> outerDimsPerm,
248+
ArrayRef<OpFoldResult> innerTiles);
249+
242250
static Value createDestinationTensor(OpBuilder &b, Location loc,
243251
Value source, ArrayRef<OpFoldResult> innerTileSizes,
244252
ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm);

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1914,9 +1914,12 @@ void populateElementwiseOpsFusionPatterns(
19141914
using ControlPropagationFn = std::function<bool(OpOperand *opOperand)>;
19151915

19161916
/// Patterns to bubble up or down data layout ops across other operations.
1917+
/// The function also has an option to allow the patterns to propagate with
1918+
/// poison padding if requested by the caller.
19171919
void populateDataLayoutPropagationPatterns(
19181920
RewritePatternSet &patterns,
1919-
const ControlPropagationFn &controlPackUnPackPropagation);
1921+
const ControlPropagationFn &controlPackUnPackPropagation,
1922+
bool PoisonPaddingOk = false);
19201923

19211924
/// Patterns to sink extract slice across other operations.
19221925
void populateExtractSliceSinkingPatterns(

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5310,6 +5310,32 @@ bool PackOp::requirePaddingValue(ArrayRef<int64_t> inputShape,
53105310
return false;
53115311
}
53125312

5313+
bool PackOp::requirePaddingValueStrict(ArrayRef<int64_t> inputShape,
5314+
ArrayRef<int64_t> innerDimsPos,
5315+
ArrayRef<int64_t> outputShape,
5316+
ArrayRef<int64_t> outerDimsPerm,
5317+
ArrayRef<OpFoldResult> innerTiles) {
5318+
SmallVector<int64_t> outputTileSizes(
5319+
outputShape.take_front(inputShape.size()));
5320+
if (!outerDimsPerm.empty()) {
5321+
assert(outerDimsPerm.size() == outputTileSizes.size() &&
5322+
"expected output and outer_dims_perm to have same size");
5323+
applyPermutationToVector(outputTileSizes,
5324+
invertPermutationVector(outerDimsPerm));
5325+
}
5326+
for (auto [pos, tileSize] : llvm::zip_equal(innerDimsPos, innerTiles)) {
5327+
if (ShapedType::isDynamic(inputShape[pos]) ||
5328+
ShapedType::isDynamic(outputTileSizes[pos]))
5329+
return true;
5330+
std::optional<int64_t> constantTile = getConstantIntValue(tileSize);
5331+
if (!constantTile)
5332+
return true;
5333+
if (inputShape[pos] % (*constantTile) != 0)
5334+
return true;
5335+
}
5336+
return false;
5337+
}
5338+
53135339
LogicalResult PackOp::verify() {
53145340
if (failed(commonVerifierPackAndUnPackOp(*this)))
53155341
return failure();

0 commit comments

Comments
 (0)