Skip to content

Commit 15c9016

Browse files
reviewer comments
Signed-off-by: Nirvedh Meshram <nirvedh@gmail.com>
1 parent 5a5c53e commit 15c9016

File tree

2 files changed

+92
-46
lines changed

2 files changed

+92
-46
lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -242,10 +242,10 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
242242
// Same as above function but here dynamic dimensions are assumed
243243
// to require padding.
244244
static bool requirePaddingValueStrict(ArrayRef<int64_t> inputShape,
245-
ArrayRef<int64_t> innerDimsPos,
246-
ArrayRef<int64_t> outputShape,
247-
ArrayRef<int64_t> outerDimsPerm,
248-
ArrayRef<OpFoldResult> innerTiles);
245+
ArrayRef<int64_t> innerDimsPos,
246+
ArrayRef<int64_t> outputShape,
247+
ArrayRef<int64_t> outerDimsPerm,
248+
ArrayRef<OpFoldResult> innerTiles);
249249

250250
static Value createDestinationTensor(OpBuilder &b, Location loc,
251251
Value source, ArrayRef<OpFoldResult> innerTileSizes,

mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp

Lines changed: 88 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -221,10 +221,21 @@ static SmallVector<int64_t> computeOuterDims(ArrayRef<int64_t> perm,
221221
/// inner_dims_pos = [0]
222222
/// inner_tiles = [8]
223223
/// into %init : tensor<?xf32> -> tensor<?x8xf32>
224-
static FailureOr<std::tuple<Value, AffineMap>>
225-
getOrCreatePackedViewOfOperand(OpBuilder &b, Location loc, PackInfo packInfo,
226-
GenericOp genericOp, OpOperand *opOperand,
227-
bool poisonPaddingOk) {
224+
225+
struct PackedOperandDetails {
226+
SmallVector<OpFoldResult> innerTileSizes;
227+
SmallVector<int64_t> innerDimsPos;
228+
SmallVector<int64_t> outerDimsPerm;
229+
AffineMap indexingMap;
230+
};
231+
232+
/// Helper function for getOrCreatePackedViewOfOperand that populates
233+
/// the details of the packedOperand that needs to be formed and also
234+
/// returns whether the packing would require padding.
235+
static bool getPackedOperandDetails(
236+
OpBuilder &b, PackInfo packInfo, GenericOp genericOp, OpOperand *opOperand,
237+
DenseMap<OpOperand *, PackedOperandDetails> &packedOperandMap) {
238+
PackedOperandDetails currOperandDetails;
228239
int64_t numOrigLoops = genericOp.getNumLoops();
229240
int64_t numInnerLoops = packInfo.getNumTiledLoops();
230241
int64_t numLoops = numOrigLoops + numInnerLoops;
@@ -233,9 +244,12 @@ getOrCreatePackedViewOfOperand(OpBuilder &b, Location loc, PackInfo packInfo,
233244
SmallVector<AffineExpr> exprs(origIndexingMap.getResults());
234245

235246
// If the OpOperand is a scalar or a zero-rank tensor, no need to pack.
236-
if (genericOp.isScalar(opOperand) || exprs.empty())
237-
return std::make_tuple(opOperand->get(),
238-
AffineMap::get(numLoops, 0, exprs, b.getContext()));
247+
if (genericOp.isScalar(opOperand) || exprs.empty()) {
248+
currOperandDetails.indexingMap =
249+
AffineMap::get(numLoops, 0, exprs, b.getContext());
250+
packedOperandMap[opOperand] = currOperandDetails;
251+
return false;
252+
}
239253

240254
// Step 1. Construct the information of packing data dimensions; append inner
241255
// dimensions to the indexing maps for the operand.
@@ -283,32 +297,57 @@ getOrCreatePackedViewOfOperand(OpBuilder &b, Location loc, PackInfo packInfo,
283297
exprs = auxVec;
284298
}
285299
}
286-
auto indexingMap = AffineMap::get(numLoops, 0, exprs, b.getContext());
300+
currOperandDetails.indexingMap =
301+
AffineMap::get(numLoops, 0, exprs, b.getContext());
287302

288303
// The operand does not have dimensions that relate to the pack op.
289-
if (innerDimsPos.empty() && outerDimsPerm.empty())
290-
return std::make_tuple(opOperand->get(), indexingMap);
304+
if (innerDimsPos.empty() && outerDimsPerm.empty()) {
305+
packedOperandMap[opOperand] = currOperandDetails;
306+
return false;
307+
}
291308
auto inputType = cast<RankedTensorType>(opOperand->get().getType());
292-
auto maybeIntInnerTileSizes = getConstantIntValues(innerTileSizes);
293-
if (!maybeIntInnerTileSizes.has_value()) {
294-
return failure();
309+
310+
auto maybeIntInnerTileSizes =
311+
llvm::map_to_vector(innerTileSizes, [](OpFoldResult ofr) -> int64_t {
312+
std::optional<int64_t> maybeCst = getConstantIntValue(ofr);
313+
return maybeCst.value_or(ShapedType::kDynamic);
314+
});
315+
bool requirePadding = linalg::PackOp::requirePaddingValueStrict(
316+
inputType.getShape(), innerDimsPos,
317+
linalg::PackOp::inferPackedType(inputType, maybeIntInnerTileSizes,
318+
innerDimsPos, outerDimsPerm)
319+
.getShape(),
320+
outerDimsPerm, innerTileSizes);
321+
currOperandDetails.innerDimsPos = innerDimsPos;
322+
currOperandDetails.innerTileSizes = innerTileSizes;
323+
currOperandDetails.outerDimsPerm = outerDimsPerm;
324+
packedOperandMap[opOperand] = currOperandDetails;
325+
326+
if (requirePadding)
327+
return true;
328+
return false;
329+
}
330+
331+
static std::tuple<Value, AffineMap> getOrCreatePackedViewOfOperand(
332+
OpBuilder &b, Location loc, OpOperand *opOperand,
333+
DenseMap<OpOperand *, PackedOperandDetails> packedOperandMap) {
334+
assert(packedOperandMap.contains(opOperand) &&
335+
"packed operand details expected to be populated");
336+
auto currOperandDetails = packedOperandMap[opOperand];
337+
auto innerDimsPos = currOperandDetails.innerDimsPos;
338+
auto outerDimsPerm = currOperandDetails.outerDimsPerm;
339+
auto innerTileSizes = currOperandDetails.innerTileSizes;
340+
if (innerDimsPos.empty() && outerDimsPerm.empty()) {
341+
return std::make_tuple(opOperand->get(), currOperandDetails.indexingMap);
295342
}
296-
if (!poisonPaddingOk &&
297-
linalg::PackOp::requirePaddingValueStrict(
298-
inputType.getShape(), innerDimsPos,
299-
linalg::PackOp::inferPackedType(inputType, *maybeIntInnerTileSizes,
300-
innerDimsPos, outerDimsPerm)
301-
.getShape(),
302-
outerDimsPerm, innerTileSizes))
303-
return failure();
304343
auto empty = linalg::PackOp::createDestinationTensor(
305344
b, loc, opOperand->get(), innerTileSizes, innerDimsPos, outerDimsPerm);
306345
auto poison = ub::PoisonOp::create(
307346
b, loc, getElementTypeOrSelf(opOperand->get().getType()));
308347
Value packedOperand =
309348
linalg::PackOp::create(b, loc, opOperand->get(), empty, innerDimsPos,
310349
innerTileSizes, poison, outerDimsPerm);
311-
return std::make_tuple(packedOperand, indexingMap);
350+
return std::make_tuple(packedOperand, currOperandDetails.indexingMap);
312351
}
313352

314353
/// This function is a helper subroutine to pack a genericOp and return it. It
@@ -330,14 +369,18 @@ packGenericOp(RewriterBase &rewriter, GenericOp genericOp, Value dest,
330369
packOp.getInnerDimsPos() == unPackOp.getInnerDimsPos() &&
331370
llvm::equal(packOp.getMixedTiles(), unPackOp.getMixedTiles());
332371
};
372+
DenseMap<OpOperand *, PackedOperandDetails> packedOperandMap;
373+
bool requiresPadding = false;
333374
for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
334-
auto mayBepackedOperandAndIndexing = getOrCreatePackedViewOfOperand(
335-
rewriter, loc, packInfo, genericOp, inputOperand, poisonPaddingOk);
336-
if (failed(mayBepackedOperandAndIndexing)) {
337-
return failure();
338-
}
339-
auto packedOperand = std::get<0>(*mayBepackedOperandAndIndexing);
340-
auto packedIndexingMap = std::get<1>(*mayBepackedOperandAndIndexing);
375+
requiresPadding |= getPackedOperandDetails(rewriter, packInfo, genericOp,
376+
inputOperand, packedOperandMap);
377+
}
378+
if (requiresPadding && !poisonPaddingOk) {
379+
return failure();
380+
}
381+
for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
382+
auto [packedOperand, packedIndexingMap] = getOrCreatePackedViewOfOperand(
383+
rewriter, loc, inputOperand, packedOperandMap);
341384
auto unpackOp = inputOperand->get().getDefiningOp<linalg::UnPackOp>();
342385
auto packOp = packedOperand.getDefiningOp<linalg::PackOp>();
343386
if (packOp && unpackOp && hasEquivalentTiles(packOp, unpackOp)) {
@@ -492,15 +535,15 @@ bubbleUpPackOpThroughGenericOp(RewriterBase &rewriter, linalg::PackOp packOp,
492535
}
493536

494537
// Rebuild the indexing map for the corresponding init operand.
495-
auto mayBepackedOperandAndIndexing =
496-
getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(), *packInfo,
497-
genericOp, opOperand, poisonPaddingOk);
498-
if (failed(mayBepackedOperandAndIndexing)) {
538+
DenseMap<OpOperand *, PackedOperandDetails> packedOperandMap;
539+
bool requiresPadding = getPackedOperandDetails(rewriter, *packInfo, genericOp,
540+
opOperand, packedOperandMap);
541+
if (requiresPadding && !poisonPaddingOk) {
499542
return failure();
500543
}
501-
auto packedOutOperand = std::get<0>(*mayBepackedOperandAndIndexing);
502-
auto packedOutIndexingMap = std::get<1>(*mayBepackedOperandAndIndexing);
503-
544+
auto [packedOutOperand, packedOutIndexingMap] =
545+
getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(), opOperand,
546+
packedOperandMap);
504547
// Forward the new tensor.empty as a destination if it is one of the following
505548
// situations:
506549
// 1) The dps init operand is a tensor.empty.
@@ -1139,14 +1182,17 @@ pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp,
11391182
return failure();
11401183

11411184
// Rebuild the indexing map for the corresponding init operand.
1142-
auto mayBepackedOperandAndIndexing = getOrCreatePackedViewOfOperand(
1143-
rewriter, genericOp.getLoc(), *packInfo, genericOp,
1144-
genericOp.getDpsInitOperand(0), poisonPaddingOk);
1145-
if (failed(mayBepackedOperandAndIndexing)) {
1185+
DenseMap<OpOperand *, PackedOperandDetails> packedOperandMap;
1186+
bool requiresPadding =
1187+
getPackedOperandDetails(rewriter, *packInfo, genericOp,
1188+
genericOp.getDpsInitOperand(0), packedOperandMap);
1189+
if (requiresPadding && !poisonPaddingOk) {
11461190
return failure();
11471191
}
1148-
auto packedOutOperand = std::get<0>(*mayBepackedOperandAndIndexing);
1149-
auto packedOutIndexingMap = std::get<1>(*mayBepackedOperandAndIndexing);
1192+
auto [packedOutOperand, packedOutIndexingMap] =
1193+
getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(),
1194+
genericOp.getDpsInitOperand(0),
1195+
packedOperandMap);
11501196
auto destPack = packedOutOperand.getDefiningOp<linalg::PackOp>();
11511197

11521198
// Forward the new tensor.empty as a destination if it is one of the following

0 commit comments

Comments
 (0)