Skip to content

Commit

Permalink
[mlir][tensor] Improve FoldTensorCastProducerOp (dynamic shapes) (#…
Browse files Browse the repository at this point in the history
…114559)

Currently, `FoldTensorCastProducerOp` incorrectly folds the following:
```mlir
    %pack = tensor.pack %src
      padding_value(%pad : i32)
      inner_dims_pos = [0, 1]
      inner_tiles = [%c8, 1]
      into %cast : tensor<7x?xi32> -> tensor<1x1x?x1xi32>
    %res = tensor.cast %pack : tensor<1x1x?x1xi32> to tensor<1x1x8x1xi32>
```
as (note the static trailing dim in the result and dynamic tile
dimension that corresponds to that):
```mlir
    %res = tensor.pack %src
      padding_value(%pad : i32)
      inner_dims_pos = [0, 1]
      inner_tiles = [%c8, 1]
      into %cast : tensor<7x?xi32> -> tensor<1x1x8x1xi32>
```

This triggers an Op verification failure and is due to the fact that the
folder does not update the inner tile sizes in the pack Op. This PR
addresses that.

Note, supporting other Ops with size-like attributes is left as a TODO.
  • Loading branch information
banach-space authored Nov 5, 2024
1 parent a993dfc commit 9b9369e
Show file tree
Hide file tree
Showing 2 changed files with 133 additions and 33 deletions.
143 changes: 112 additions & 31 deletions mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4698,6 +4698,111 @@ OpFoldResult UnPackOp::fold(FoldAdaptor adaptor) {
//===----------------------------------------------------------------------===//
// Common Canonicalizers and Folders.
//===----------------------------------------------------------------------===//
bool foldTensorCastPrecondition(DestinationStyleOpInterface op) {
// 1. InsertSliceOp has its own logic about folding tensor.cast ops.
// 2. Exclude DPS ops that are also LoopLike from this interface as they
// might need special handling of attached regions.
if (isa<InsertSliceOp>(op.getOperation()) ||
isa<LoopLikeOpInterface>(op.getOperation()))
return false;

// If no operand comes from a tensor::CastOp and can be folded then fail.
bool hasTensorCastOperand =
llvm::any_of(op->getOpOperands(), [&](OpOperand &opOperand) {
if (llvm::isa<BlockArgument>(opOperand.get()))
return false;
auto castOp = opOperand.get().getDefiningOp<tensor::CastOp>();
return castOp && canFoldIntoConsumerOp(castOp);
});

return hasTensorCastOperand;
}

static SmallVector<Value> getNewOperands(DestinationStyleOpInterface op,
SmallVector<Type> &newResTy) {
SmallVector<Value> newOperands;
newOperands.reserve(op->getNumOperands());

// Assumes that the result has dpsInits followed by nonDpsInits.
int64_t dpsInitIdx = 0;
for (OpOperand &opOperand : op->getOpOperands()) {
auto tensorCastOp = opOperand.get().getDefiningOp<tensor::CastOp>();
bool fold = canFoldIntoConsumerOp(tensorCastOp);
newOperands.push_back(fold ? tensorCastOp.getOperand() : opOperand.get());
if (op.isDpsInit(&opOperand) &&
!llvm::isa<MemRefType>(newOperands.back().getType()))
newResTy[dpsInitIdx++] = newOperands.back().getType();
}
return newOperands;
}

/// Folds a tensor.cast op into a consuming tensor::PackOp op if the
/// `tensor.cast` has source that is more static than the consuming op.
///
/// Example:
/// ```mlir
/// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
/// %2 = tensor.pack %1 ... : tensor<?x?xf32> ...
/// ```
///
/// folds into:
///
/// ```mlir
/// %2 = tensor.pack %0 ... : tensor<8x16xf32> ...
/// ```
struct FoldTensorCastPackOp : public OpRewritePattern<PackOp> {
using OpRewritePattern<PackOp>::OpRewritePattern;

LogicalResult matchAndRewrite(PackOp op,
PatternRewriter &rewriter) const override {
if (!foldTensorCastPrecondition(op))
return failure();

SmallVector<Type> newResultTypes(op->getResultTypes());
SmallVector<Value> newOperands = getNewOperands(op, newResultTypes);

// Get the updated mixed-tile-sizes attribute.
SmallVector<OpFoldResult> newMixedTileSizes;
for (auto it : llvm::zip(cast<ShapedType>(newResultTypes[0])
.getShape()
.take_back(op.getMixedTiles().size()),
op.getMixedTiles())) {
int64_t shape = std::get<0>(it);
if (shape == ShapedType::kDynamic) {
newMixedTileSizes.push_back(std::get<1>(it));
continue;
}

if (Attribute attr =
llvm::dyn_cast_if_present<Attribute>(std::get<1>(it))) {
// Already a constant
newMixedTileSizes.push_back(std::get<1>(it));
} else {
int64_t tileSize = getConstantIntValue(std::get<1>(it)).value();
assert(tileSize == shape && "tile size and dim size don't match!");
newMixedTileSizes.push_back(
(rewriter.getIntegerAttr(rewriter.getIndexType(), shape)));
}
}

// Clone op.
PackOp newOp = rewriter.create<PackOp>(
op.getLoc(), newOperands[0], newOperands[1], op.getInnerDimsPos(),
newMixedTileSizes, op.getPaddingValue(), op.getOuterDimsPerm());

// Replace op.
Value oldResult = op.getResult();
Value newResult = newOp.getResult();
Value replacement = (newResult.getType() != oldResult.getType())
? rewriter.create<tensor::CastOp>(
op->getLoc(), oldResult.getType(), newResult)
: newResult;

rewriter.replaceOp(op, {replacement});

return success();
}
};

/// Folds a tensor.cast op into a consuming DestinationStyleOpInterface op if
/// the `tensor.cast` has source that is more static than the consuming op.
Expand All @@ -4722,42 +4827,17 @@ struct FoldTensorCastProducerOp

LogicalResult matchAndRewrite(DestinationStyleOpInterface op,
PatternRewriter &rewriter) const override {
// InsertSliceOp has its own logic about folding tensor.cast ops.
if (isa<InsertSliceOp>(op.getOperation()))
return failure();

// Exclude DPS ops that are also LoopLike from this interface as they
// might need special handling of attached regions.
if (isa<LoopLikeOpInterface>(op.getOperation()))
// Reject tensor::PackOp - there's dedicated pattern for that instead.
if (!foldTensorCastPrecondition(op) || dyn_cast<tensor::PackOp>(*op))
return failure();

// If no operand comes from a tensor::CastOp and can be folded then fail.
bool hasTensorCastOperand =
llvm::any_of(op->getOpOperands(), [&](OpOperand &opOperand) {
if (llvm::isa<BlockArgument>(opOperand.get()))
return false;
auto castOp = opOperand.get().getDefiningOp<tensor::CastOp>();
return castOp && canFoldIntoConsumerOp(castOp);
});
if (!hasTensorCastOperand)
return failure();
SmallVector<Type> newResultTypes(op->getResultTypes());
SmallVector<Value> newOperands = getNewOperands(op, newResultTypes);

SmallVector<Type, 4> newResultTypes(op->getResultTypes());
SmallVector<Value, 4> newOperands;
newOperands.reserve(op->getNumOperands());
// Assumes that the result has dpsInits followed by nonDpsInits.
int64_t dpsInitIdx = 0;
for (OpOperand &opOperand : op->getOpOperands()) {
auto tensorCastOp = opOperand.get().getDefiningOp<tensor::CastOp>();
bool fold = canFoldIntoConsumerOp(tensorCastOp);
newOperands.push_back(fold ? tensorCastOp.getOperand() : opOperand.get());
if (op.isDpsInit(&opOperand) &&
!llvm::isa<MemRefType>(newOperands.back().getType()))
newResultTypes[dpsInitIdx++] = newOperands.back().getType();
}
// Clone op
auto newOp = clone(rewriter, op, newResultTypes, newOperands);

// Clone op.
Operation *newOp = clone(rewriter, op, newResultTypes, newOperands);
SmallVector<Value, 4> replacements;
replacements.reserve(newOp->getNumResults());
for (auto [oldResult, newResult] :
Expand All @@ -4781,6 +4861,7 @@ struct FoldTensorCastProducerOp

void TensorDialect::getCanonicalizationPatterns(
RewritePatternSet &results) const {
results.add<FoldTensorCastPackOp>(getContext());
results.add<FoldTensorCastProducerOp>(getContext());
}

Expand Down
23 changes: 21 additions & 2 deletions mlir/test/Dialect/Tensor/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2718,18 +2718,37 @@ func.func @dim_out_of_bounds() -> vector<7xi32> {

// -----

// CHECK-LABEL: func.func @test_destination_multiple_result(
// CHECK-LABEL: func.func @fold_cast_multiple_results(
// CHECK-SAME: %[[ARG1:.*]]: tensor<2x2xf32>,
// CHECK-SAME: %[[ARG2:.*]]: tensor<2x2xf32>) -> index {
// CHECK: %[[RES:.*]]:2 = test.destination_style_op ins(%[[ARG1]] : tensor<2x2xf32>)
// CHECK-SAME: outs(%[[ARG2]] : tensor<2x2xf32>) -> tensor<2x2xf32>, index
// CHECK: return %[[RES]]#1 : index
func.func @test_destination_multiple_result(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> index {
func.func @fold_cast_multiple_results(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> index {
%cast = tensor.cast %arg0 : tensor<2x2xf32> to tensor<?x2xf32>
%cast_0 = tensor.cast %arg1 : tensor<2x2xf32> to tensor<?x2xf32>
%0:2 = test.destination_style_op ins(%cast : tensor<?x2xf32>) outs(%cast_0 : tensor<?x2xf32>) -> tensor<?x2xf32>, index
return %0#1 : index
}
// -----

// CHECK-LABEL: func.func @fold_cast_pack_dynamic_tile_size
// CHECK-SAME: %[[DEST:.*]]: tensor<1x1x8x1xi32>,
// CHECK-SAME: %[[SRC:.*]]: tensor<7x?xi32>,
// CHECK-SAME: %[[PAD:.*]]: i32) -> tensor<1x1x8x1xi32> {
// CHECK: %[[PACK:.*]] = tensor.pack %[[SRC]] padding_value(%[[PAD]] : i32) inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %[[DEST]] : tensor<7x?xi32> -> tensor<1x1x8x1xi32>
// CHECK: return %[[PACK]] : tensor<1x1x8x1xi32>
func.func @fold_cast_pack_dynamic_tile_size(
%dest: tensor<1x1x8x1xi32>,
%src: tensor<7x?xi32>,
%pad: i32) -> tensor<1x1x8x1xi32> {

%cast = tensor.cast %dest : tensor<1x1x8x1xi32> to tensor<1x1x?x1xi32>
%c8 = arith.constant 8 : index
%pack = tensor.pack %src padding_value(%pad : i32) inner_dims_pos = [0, 1] inner_tiles = [%c8, 1] into %cast : tensor<7x?xi32> -> tensor<1x1x?x1xi32>
%res = tensor.cast %pack : tensor<1x1x?x1xi32> to tensor<1x1x8x1xi32>
return %res : tensor<1x1x8x1xi32>
}

// -----

Expand Down

0 comments on commit 9b9369e

Please sign in to comment.