20
20
#include " mlir/Dialect/Utils/IndexingUtils.h"
21
21
#include " mlir/Dialect/Utils/StaticValueUtils.h"
22
22
#include " mlir/Dialect/Utils/StructuredOpsUtils.h"
23
+ #include " mlir/IR/BuiltinTypeInterfaces.h"
23
24
#include " mlir/Interfaces/TilingInterface.h"
24
25
#include " mlir/Interfaces/ValueBoundsOpInterface.h"
25
26
#include " llvm/Support/Debug.h"
@@ -887,26 +888,55 @@ struct PackOpTiling
887
888
888
889
ArrayRef<OpFoldResult> offsets (allOffsets[0 ]);
889
890
ArrayRef<OpFoldResult> sizes (allSizes[0 ]);
890
-
891
891
auto packOp = cast<PackOp>(op);
892
- // It is not trivial to infer dest tile from source tile if `packOp` has
893
- // padding semantic.
894
- if (packOp.getPaddingValue ())
895
- return failure ();
896
-
897
892
Location loc = packOp.getLoc ();
898
-
899
893
SmallVector<OpFoldResult> outerDimOffsets, outerDimSizes;
900
894
DenseMap<int64_t , OpFoldResult> dimAndTileMapping =
901
895
packOp.getDimAndTileMapping ();
902
896
for (auto dim : llvm::seq<int64_t >(packOp.getSourceRank ())) {
903
897
if (dimAndTileMapping.count (dim)) {
904
- FailureOr<int64_t > cstSize =
898
+ FailureOr<int64_t > cstTileSize =
905
899
ValueBoundsConstraintSet::computeConstantBound (
906
900
presburger::BoundType::UB, sizes[dim],
907
901
/* stopCondition=*/ nullptr , /* closedUB=*/ true );
908
902
std::optional<int64_t > cstInnerSize =
909
903
getConstantIntValue (dimAndTileMapping[dim]);
904
+
905
+ // If a dimension is not tiled, it is always valid to fuse the pack op,
906
+ // even if the op has padding semantics. Because it always generates a
907
+ // full slice along the dimension.
908
+ // TODO: It could be untiled if the `srcDimSize` is dynamic. It is a
909
+ // hard check to determine if a dimension is tiled or not.
910
+ int64_t srcDimSize = packOp.getSourceType ().getDimSize (dim);
911
+ int64_t destDimSize = packOp.getDestType ().getDimSize (dim);
912
+ bool isTiled = failed (cstTileSize) ||
913
+ ShapedType::isDynamic (srcDimSize) ||
914
+ cstTileSize.value () != srcDimSize;
915
+ if (!isTiled) {
916
+ outerDimOffsets.push_back (offsets[dim]);
917
+ if (ShapedType::isStatic (destDimSize)) {
918
+ outerDimSizes.push_back (b.getIndexAttr (destDimSize));
919
+ } else {
920
+ outerDimSizes.push_back (
921
+ b.createOrFold <tensor::DimOp>(loc, packOp.getDest (), dim));
922
+ }
923
+ continue ;
924
+ }
925
+
926
+ // If the dimension needs padding, it is not supported because there are
927
+ // iterations that only write padding values to the whole tile. The
928
+ // consumer fusion is driven by the source, so it is not possible to map
929
+ // an empty slice to the tile.
930
+ bool needExtraPadding =
931
+ ShapedType::isDynamic (destDimSize) || !cstInnerSize ||
932
+ destDimSize * cstInnerSize.value () != srcDimSize;
933
+ // Prioritize the case that the op already says that it does not need
934
+ // padding.
935
+ if (!packOp.getPaddingValue ())
936
+ needExtraPadding = false ;
937
+ if (needExtraPadding)
938
+ return failure ();
939
+
910
940
// Currently fusing `packOp` as consumer only expects perfect tiling
911
941
// scenario because even if without padding semantic, the `packOp` may
912
942
// also yield incomplete tiles. E.g. tensor<30xf32> -> tensor<5x6xf32>,
@@ -921,9 +951,9 @@ struct PackOpTiling
921
951
// another word, we can only support tiling with consumer if the tile
922
952
// size for the producer is a multiple of the inner tile size for the
923
953
// packed dimensions at this moment.
924
- if (failed (cstSize) || !cstInnerSize || *cstSize % *cstInnerSize != 0 ) {
954
+ if ((failed (cstTileSize) || !cstInnerSize ||
955
+ *cstTileSize % *cstInnerSize != 0 ))
925
956
return failure ();
926
- }
927
957
928
958
using AV = affine::AffineValueExpr;
929
959
affine::AffineBuilder ab (b, loc);
@@ -988,7 +1018,8 @@ struct PackOpTiling
988
1018
loc, packOp.getDest (), outputOffsets, outputSizes, strides);
989
1019
tiledOperands.push_back (outSlice);
990
1020
991
- assert (!packOp.getPaddingValue () && " Expect no padding semantic" );
1021
+ if (auto val = packOp.getPaddingValue ())
1022
+ tiledOperands.push_back (val);
992
1023
for (auto tile : packOp.getInnerTiles ())
993
1024
tiledOperands.push_back (tile);
994
1025
0 commit comments