@@ -221,10 +221,21 @@ static SmallVector<int64_t> computeOuterDims(ArrayRef<int64_t> perm,
221
221
// / inner_dims_pos = [0]
222
222
// / inner_tiles = [8]
223
223
// / into %init : tensor<?xf32> -> tensor<?x8xf32>
224
- static FailureOr<std::tuple<Value, AffineMap>>
225
- getOrCreatePackedViewOfOperand (OpBuilder &b, Location loc, PackInfo packInfo,
226
- GenericOp genericOp, OpOperand *opOperand,
227
- bool poisonPaddingOk) {
224
+
225
+ struct PackedOperandDetails {
226
+ SmallVector<OpFoldResult> innerTileSizes;
227
+ SmallVector<int64_t > innerDimsPos;
228
+ SmallVector<int64_t > outerDimsPerm;
229
+ AffineMap indexingMap;
230
+ };
231
+
232
+ /// Helper function for getOrCreatePackedViewOfOperand that populates
233
+ /// the details of the packed operand that needs to be formed and also
234
+ /// returns whether the packing would require padding.
235
+ static bool getPackedOperandDetails (
236
+ OpBuilder &b, PackInfo packInfo, GenericOp genericOp, OpOperand *opOperand,
237
+ DenseMap<OpOperand *, PackedOperandDetails> &packedOperandMap) {
238
+ PackedOperandDetails currOperandDetails;
228
239
int64_t numOrigLoops = genericOp.getNumLoops ();
229
240
int64_t numInnerLoops = packInfo.getNumTiledLoops ();
230
241
int64_t numLoops = numOrigLoops + numInnerLoops;
@@ -233,9 +244,12 @@ getOrCreatePackedViewOfOperand(OpBuilder &b, Location loc, PackInfo packInfo,
233
244
SmallVector<AffineExpr> exprs (origIndexingMap.getResults ());
234
245
235
246
// If the OpOperand is a scalar or a zero-rank tensor, no need to pack.
236
- if (genericOp.isScalar (opOperand) || exprs.empty ())
237
- return std::make_tuple (opOperand->get (),
238
- AffineMap::get (numLoops, 0 , exprs, b.getContext ()));
247
+ if (genericOp.isScalar (opOperand) || exprs.empty ()) {
248
+ currOperandDetails.indexingMap =
249
+ AffineMap::get (numLoops, 0 , exprs, b.getContext ());
250
+ packedOperandMap[opOperand] = currOperandDetails;
251
+ return false ;
252
+ }
239
253
240
254
// Step 1. Construct the information of packing data dimensions; append inner
241
255
// dimensions to the indexing maps for the operand.
@@ -283,32 +297,57 @@ getOrCreatePackedViewOfOperand(OpBuilder &b, Location loc, PackInfo packInfo,
283
297
exprs = auxVec;
284
298
}
285
299
}
286
- auto indexingMap = AffineMap::get (numLoops, 0 , exprs, b.getContext ());
300
+ currOperandDetails.indexingMap =
301
+ AffineMap::get (numLoops, 0 , exprs, b.getContext ());
287
302
288
303
// The operand does not have dimensions that relate to the pack op.
289
- if (innerDimsPos.empty () && outerDimsPerm.empty ())
290
- return std::make_tuple (opOperand->get (), indexingMap);
304
+ if (innerDimsPos.empty () && outerDimsPerm.empty ()) {
305
+ packedOperandMap[opOperand] = currOperandDetails;
306
+ return false ;
307
+ }
291
308
auto inputType = cast<RankedTensorType>(opOperand->get ().getType ());
292
- auto maybeIntInnerTileSizes = getConstantIntValues (innerTileSizes);
293
- if (!maybeIntInnerTileSizes.has_value ()) {
294
- return failure ();
309
+
310
+ auto maybeIntInnerTileSizes =
311
+ llvm::map_to_vector (innerTileSizes, [](OpFoldResult ofr) -> int64_t {
312
+ std::optional<int64_t > maybeCst = getConstantIntValue (ofr);
313
+ return maybeCst.value_or (ShapedType::kDynamic );
314
+ });
315
+ bool requirePadding = linalg::PackOp::requirePaddingValueStrict (
316
+ inputType.getShape (), innerDimsPos,
317
+ linalg::PackOp::inferPackedType (inputType, maybeIntInnerTileSizes,
318
+ innerDimsPos, outerDimsPerm)
319
+ .getShape (),
320
+ outerDimsPerm, innerTileSizes);
321
+ currOperandDetails.innerDimsPos = innerDimsPos;
322
+ currOperandDetails.innerTileSizes = innerTileSizes;
323
+ currOperandDetails.outerDimsPerm = outerDimsPerm;
324
+ packedOperandMap[opOperand] = currOperandDetails;
325
+
326
+ if (requirePadding)
327
+ return true ;
328
+ return false ;
329
+ }
330
+
331
+ static std::tuple<Value, AffineMap> getOrCreatePackedViewOfOperand (
332
+ OpBuilder &b, Location loc, OpOperand *opOperand,
333
+ DenseMap<OpOperand *, PackedOperandDetails> packedOperandMap) {
334
+ assert (packedOperandMap.contains (opOperand) &&
335
+ " packed operand details expected to be populated" );
336
+ auto currOperandDetails = packedOperandMap[opOperand];
337
+ auto innerDimsPos = currOperandDetails.innerDimsPos ;
338
+ auto outerDimsPerm = currOperandDetails.outerDimsPerm ;
339
+ auto innerTileSizes = currOperandDetails.innerTileSizes ;
340
+ if (innerDimsPos.empty () && outerDimsPerm.empty ()) {
341
+ return std::make_tuple (opOperand->get (), currOperandDetails.indexingMap );
295
342
}
296
- if (!poisonPaddingOk &&
297
- linalg::PackOp::requirePaddingValueStrict (
298
- inputType.getShape (), innerDimsPos,
299
- linalg::PackOp::inferPackedType (inputType, *maybeIntInnerTileSizes,
300
- innerDimsPos, outerDimsPerm)
301
- .getShape (),
302
- outerDimsPerm, innerTileSizes))
303
- return failure ();
304
343
auto empty = linalg::PackOp::createDestinationTensor (
305
344
b, loc, opOperand->get (), innerTileSizes, innerDimsPos, outerDimsPerm);
306
345
auto poison = ub::PoisonOp::create (
307
346
b, loc, getElementTypeOrSelf (opOperand->get ().getType ()));
308
347
Value packedOperand =
309
348
linalg::PackOp::create (b, loc, opOperand->get (), empty, innerDimsPos,
310
349
innerTileSizes, poison, outerDimsPerm);
311
- return std::make_tuple (packedOperand, indexingMap);
350
+ return std::make_tuple (packedOperand, currOperandDetails. indexingMap );
312
351
}
313
352
314
353
// / This function is a helper subroutine to pack a genericOp and return it. It
@@ -330,14 +369,18 @@ packGenericOp(RewriterBase &rewriter, GenericOp genericOp, Value dest,
330
369
packOp.getInnerDimsPos () == unPackOp.getInnerDimsPos () &&
331
370
llvm::equal (packOp.getMixedTiles (), unPackOp.getMixedTiles ());
332
371
};
372
+ DenseMap<OpOperand *, PackedOperandDetails> packedOperandMap;
373
+ bool requiresPadding = false ;
333
374
for (OpOperand *inputOperand : genericOp.getDpsInputOperands ()) {
334
- auto mayBepackedOperandAndIndexing = getOrCreatePackedViewOfOperand (
335
- rewriter, loc, packInfo, genericOp, inputOperand, poisonPaddingOk);
336
- if (failed (mayBepackedOperandAndIndexing)) {
337
- return failure ();
338
- }
339
- auto packedOperand = std::get<0 >(*mayBepackedOperandAndIndexing);
340
- auto packedIndexingMap = std::get<1 >(*mayBepackedOperandAndIndexing);
375
+ requiresPadding |= getPackedOperandDetails (rewriter, packInfo, genericOp,
376
+ inputOperand, packedOperandMap);
377
+ }
378
+ if (requiresPadding && !poisonPaddingOk) {
379
+ return failure ();
380
+ }
381
+ for (OpOperand *inputOperand : genericOp.getDpsInputOperands ()) {
382
+ auto [packedOperand, packedIndexingMap] = getOrCreatePackedViewOfOperand (
383
+ rewriter, loc, inputOperand, packedOperandMap);
341
384
auto unpackOp = inputOperand->get ().getDefiningOp <linalg::UnPackOp>();
342
385
auto packOp = packedOperand.getDefiningOp <linalg::PackOp>();
343
386
if (packOp && unpackOp && hasEquivalentTiles (packOp, unpackOp)) {
@@ -492,15 +535,15 @@ bubbleUpPackOpThroughGenericOp(RewriterBase &rewriter, linalg::PackOp packOp,
492
535
}
493
536
494
537
// Rebuild the indexing map for the corresponding init operand.
495
- auto mayBepackedOperandAndIndexing =
496
- getOrCreatePackedViewOfOperand (rewriter, genericOp. getLoc (), *packInfo,
497
- genericOp, opOperand, poisonPaddingOk );
498
- if (failed (mayBepackedOperandAndIndexing) ) {
538
+ DenseMap<OpOperand *, PackedOperandDetails> packedOperandMap;
539
+ bool requiresPadding = getPackedOperandDetails (rewriter, *packInfo, genericOp ,
540
+ opOperand, packedOperandMap );
541
+ if (requiresPadding && !poisonPaddingOk ) {
499
542
return failure ();
500
543
}
501
- auto packedOutOperand = std::get< 0 >(*mayBepackedOperandAndIndexing);
502
- auto packedOutIndexingMap = std::get< 1 >(*mayBepackedOperandAndIndexing);
503
-
544
+ auto [ packedOutOperand, packedOutIndexingMap] =
545
+ getOrCreatePackedViewOfOperand (rewriter, genericOp. getLoc (), opOperand,
546
+ packedOperandMap);
504
547
// Forward the new tensor.empty as a destination if it is one of the following
505
548
// situations:
506
549
// 1) The dps init operand is a tensor.empty.
@@ -1139,14 +1182,17 @@ pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp,
1139
1182
return failure ();
1140
1183
1141
1184
// Rebuild the indexing map for the corresponding init operand.
1142
- auto mayBepackedOperandAndIndexing = getOrCreatePackedViewOfOperand (
1143
- rewriter, genericOp.getLoc (), *packInfo, genericOp,
1144
- genericOp.getDpsInitOperand (0 ), poisonPaddingOk);
1145
- if (failed (mayBepackedOperandAndIndexing)) {
1185
+ DenseMap<OpOperand *, PackedOperandDetails> packedOperandMap;
1186
+ bool requiresPadding =
1187
+ getPackedOperandDetails (rewriter, *packInfo, genericOp,
1188
+ genericOp.getDpsInitOperand (0 ), packedOperandMap);
1189
+ if (requiresPadding && !poisonPaddingOk) {
1146
1190
return failure ();
1147
1191
}
1148
- auto packedOutOperand = std::get<0 >(*mayBepackedOperandAndIndexing);
1149
- auto packedOutIndexingMap = std::get<1 >(*mayBepackedOperandAndIndexing);
1192
+ auto [packedOutOperand, packedOutIndexingMap] =
1193
+ getOrCreatePackedViewOfOperand (rewriter, genericOp.getLoc (),
1194
+ genericOp.getDpsInitOperand (0 ),
1195
+ packedOperandMap);
1150
1196
auto destPack = packedOutOperand.getDefiningOp <linalg::PackOp>();
1151
1197
1152
1198
// Forward the new tensor.empty as a destination if it is one of the following
0 commit comments