diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index c2d6bc610cd92a..1847066b2d1e36 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -4698,6 +4698,111 @@ OpFoldResult UnPackOp::fold(FoldAdaptor adaptor) {
 //===----------------------------------------------------------------------===//
 // Common Canonicalizers and Folders.
 //===----------------------------------------------------------------------===//
+bool foldTensorCastPrecondition(DestinationStyleOpInterface op) {
+  // 1. InsertSliceOp has its own logic about folding tensor.cast ops.
+  // 2. Exclude DPS ops that are also LoopLike from this interface as they
+  // might need special handling of attached regions.
+  if (isa<InsertSliceOp>(op.getOperation()) ||
+      isa<LoopLikeOpInterface>(op.getOperation()))
+    return false;
+
+  // If no operand comes from a tensor::CastOp and can be folded then fail.
+  bool hasTensorCastOperand =
+      llvm::any_of(op->getOpOperands(), [&](OpOperand &opOperand) {
+        if (llvm::isa<BlockArgument>(opOperand.get()))
+          return false;
+        auto castOp = opOperand.get().getDefiningOp<tensor::CastOp>();
+        return castOp && canFoldIntoConsumerOp(castOp);
+      });
+
+  return hasTensorCastOperand;
+}
+
+static SmallVector<Value> getNewOperands(DestinationStyleOpInterface op,
+                                         SmallVector<Type> &newResTy) {
+  SmallVector<Value> newOperands;
+  newOperands.reserve(op->getNumOperands());
+
+  // Assumes that the result has dpsInits followed by nonDpsInits.
+  int64_t dpsInitIdx = 0;
+  for (OpOperand &opOperand : op->getOpOperands()) {
+    auto tensorCastOp = opOperand.get().getDefiningOp<tensor::CastOp>();
+    bool fold = canFoldIntoConsumerOp(tensorCastOp);
+    newOperands.push_back(fold ? tensorCastOp.getOperand() : opOperand.get());
+    if (op.isDpsInit(&opOperand) &&
+        !llvm::isa<MemRefType>(newOperands.back().getType()))
+      newResTy[dpsInitIdx++] = newOperands.back().getType();
+  }
+  return newOperands;
+}
+
+/// Folds a tensor.cast op into a consuming tensor::PackOp op if the
+/// `tensor.cast` has a source that is more static than the consuming op.
+///
+/// Example:
+/// ```mlir
+///   %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
+///   %2 = tensor.pack %1 ... : tensor<?x?xf32> ...
+/// ```
+///
+/// folds into:
+///
+/// ```mlir
+///   %2 = tensor.pack %0 ... : tensor<8x16xf32> ...
+/// ```
+struct FoldTensorCastPackOp : public OpRewritePattern<PackOp> {
+  using OpRewritePattern<PackOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(PackOp op,
+                                PatternRewriter &rewriter) const override {
+    if (!foldTensorCastPrecondition(op))
+      return failure();
+
+    SmallVector<Type> newResultTypes(op->getResultTypes());
+    SmallVector<Value> newOperands = getNewOperands(op, newResultTypes);
+
+    // Get the updated mixed-tile-sizes attribute.
+    SmallVector<OpFoldResult> newMixedTileSizes;
+    for (auto it : llvm::zip(cast<ShapedType>(newResultTypes[0])
+                                 .getShape()
+                                 .take_back(op.getMixedTiles().size()),
+                             op.getMixedTiles())) {
+      int64_t shape = std::get<0>(it);
+      if (shape == ShapedType::kDynamic) {
+        newMixedTileSizes.push_back(std::get<1>(it));
+        continue;
+      }
+
+      if (Attribute attr =
+              llvm::dyn_cast_if_present<Attribute>(std::get<1>(it))) {
+        // Already a constant.
+        newMixedTileSizes.push_back(std::get<1>(it));
+      } else {
+        int64_t tileSize = getConstantIntValue(std::get<1>(it)).value();
+        assert(tileSize == shape && "tile size and dim size don't match!");
+        newMixedTileSizes.push_back(
+            (rewriter.getIntegerAttr(rewriter.getIndexType(), shape)));
+      }
+    }
+
+    // Clone op.
+    PackOp newOp = rewriter.create<PackOp>(
+        op.getLoc(), newOperands[0], newOperands[1], op.getInnerDimsPos(),
+        newMixedTileSizes, op.getPaddingValue(), op.getOuterDimsPerm());
+
+    // Replace op.
+    Value oldResult = op.getResult();
+    Value newResult = newOp.getResult();
+    Value replacement = (newResult.getType() != oldResult.getType())
+                            ? rewriter.create<tensor::CastOp>(
+                                  op->getLoc(), oldResult.getType(), newResult)
+                            : newResult;
+
+    rewriter.replaceOp(op, {replacement});
+
+    return success();
+  }
+};
 
 /// Folds a tensor.cast op into a consuming DestinationStyleOpInterface op if
 /// the `tensor.cast` has source that is more static than the consuming op.
@@ -4722,42 +4827,17 @@ struct FoldTensorCastProducerOp
 
   LogicalResult matchAndRewrite(DestinationStyleOpInterface op,
                                 PatternRewriter &rewriter) const override {
-    // InsertSliceOp has its own logic about folding tensor.cast ops.
-    if (isa<InsertSliceOp>(op.getOperation()))
-      return failure();
-    // Exclude DPS ops that are also LoopLike from this interface as they
-    // might need special handling of attached regions.
-    if (isa<LoopLikeOpInterface>(op.getOperation()))
+    // Reject tensor::PackOp - there's a dedicated pattern for that instead.
+    if (!foldTensorCastPrecondition(op) || dyn_cast<tensor::PackOp>(*op))
       return failure();
 
-    // If no operand comes from a tensor::CastOp and can be folded then fail.
-    bool hasTensorCastOperand =
-        llvm::any_of(op->getOpOperands(), [&](OpOperand &opOperand) {
-          if (llvm::isa<BlockArgument>(opOperand.get()))
-            return false;
-          auto castOp = opOperand.get().getDefiningOp<tensor::CastOp>();
-          return castOp && canFoldIntoConsumerOp(castOp);
-        });
-    if (!hasTensorCastOperand)
-      return failure();
+    SmallVector<Type> newResultTypes(op->getResultTypes());
+    SmallVector<Value> newOperands = getNewOperands(op, newResultTypes);
 
-    SmallVector<Type> newResultTypes(op->getResultTypes());
-    SmallVector<Value> newOperands;
-    newOperands.reserve(op->getNumOperands());
-    // Assumes that the result has dpsInits followed by nonDpsInits.
-    int64_t dpsInitIdx = 0;
-    for (OpOperand &opOperand : op->getOpOperands()) {
-      auto tensorCastOp = opOperand.get().getDefiningOp<tensor::CastOp>();
-      bool fold = canFoldIntoConsumerOp(tensorCastOp);
-      newOperands.push_back(fold ? tensorCastOp.getOperand() : opOperand.get());
-      if (op.isDpsInit(&opOperand) &&
-          !llvm::isa<MemRefType>(newOperands.back().getType()))
-        newResultTypes[dpsInitIdx++] = newOperands.back().getType();
-    }
+    // Clone op
+    auto newOp = clone(rewriter, op, newResultTypes, newOperands);
 
-    // Clone op.
-    Operation *newOp = clone(rewriter, op, newResultTypes, newOperands);
 
     SmallVector<Value> replacements;
     replacements.reserve(newOp->getNumResults());
     for (auto [oldResult, newResult] :
@@ -4781,6 +4861,7 @@ struct FoldTensorCastProducerOp
 
 void TensorDialect::getCanonicalizationPatterns(
     RewritePatternSet &results) const {
+  results.add<FoldTensorCastPackOp>(getContext());
   results.add<FoldTensorCastProducerOp>(getContext());
 }
 
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 236d2a3e60eb2c..2186aab9a527cc 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -2718,18 +2718,37 @@ func.func @dim_out_of_bounds() -> vector<7xi32> {
 
 // -----
 
-// CHECK-LABEL: func.func @test_destination_multiple_result(
+// CHECK-LABEL: func.func @fold_cast_multiple_results(
 // CHECK-SAME:    %[[ARG1:.*]]: tensor<2x2xf32>,
 // CHECK-SAME:    %[[ARG2:.*]]: tensor<2x2xf32>) -> index {
 // CHECK:         %[[RES:.*]]:2 = test.destination_style_op ins(%[[ARG1]] : tensor<2x2xf32>)
 // CHECK-SAME:      outs(%[[ARG2]] : tensor<2x2xf32>) -> tensor<2x2xf32>, index
 // CHECK:         return %[[RES]]#1 : index
-func.func @test_destination_multiple_result(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> index {
+func.func @fold_cast_multiple_results(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> index {
   %cast = tensor.cast %arg0 : tensor<2x2xf32> to tensor<?x?xf32>
   %cast_0 = tensor.cast %arg1 : tensor<2x2xf32> to tensor<?x?xf32>
   %0:2 = test.destination_style_op ins(%cast : tensor<?x?xf32>) outs(%cast_0 : tensor<?x?xf32>) -> tensor<?x?xf32>, index
   return %0#1 : index
 }
+// -----
+
+// CHECK-LABEL: func.func @fold_cast_pack_dynamic_tile_size
+// CHECK-SAME:    %[[DEST:.*]]: tensor<1x1x8x1xi32>,
+// CHECK-SAME:    %[[SRC:.*]]: tensor<7x?xi32>,
+// CHECK-SAME:    %[[PAD:.*]]: i32) -> tensor<1x1x8x1xi32> {
+// CHECK:         %[[PACK:.*]] = tensor.pack %[[SRC]] padding_value(%[[PAD]] : i32) inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %[[DEST]] : tensor<7x?xi32> -> tensor<1x1x8x1xi32>
+// CHECK:         return %[[PACK]] : tensor<1x1x8x1xi32>
func.func @fold_cast_pack_dynamic_tile_size(
+  %dest: tensor<1x1x8x1xi32>,
+  %src: tensor<7x?xi32>,
+  %pad: i32) -> tensor<1x1x8x1xi32> {
+
+  %cast = tensor.cast %dest : tensor<1x1x8x1xi32> to tensor<1x1x?x1xi32>
+  %c8 = arith.constant 8 : index
+  %pack = tensor.pack %src padding_value(%pad : i32) inner_dims_pos = [0, 1] inner_tiles = [%c8, 1] into %cast : tensor<7x?xi32> -> tensor<1x1x?x1xi32>
+  %res = tensor.cast %pack : tensor<1x1x?x1xi32> to tensor<1x1x8x1xi32>
+  return %res : tensor<1x1x8x1xi32>
+}
 
 // -----
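Note (editorial, not part of the patch): a minimal before/after sketch of what the new `FoldTensorCastPackOp` pattern achieves, mirroring the `fold_cast_pack_dynamic_tile_size` test above. It assumes the pattern runs as part of `mlir-opt -canonicalize` together with the existing `tensor.cast` folders; the value names are taken from the test and are purely illustrative.

```mlir
// Before canonicalization: the destination is cast to a dynamic shape and the
// inner tile size is the SSA constant %c8.
%cast = tensor.cast %dest : tensor<1x1x8x1xi32> to tensor<1x1x?x1xi32>
%c8 = arith.constant 8 : index
%pack = tensor.pack %src padding_value(%pad : i32)
    inner_dims_pos = [0, 1] inner_tiles = [%c8, 1]
    into %cast : tensor<7x?xi32> -> tensor<1x1x?x1xi32>
%res = tensor.cast %pack : tensor<1x1x?x1xi32> to tensor<1x1x8x1xi32>

// After canonicalization: the casts fold away and the matching tile size is
// rewritten as the static attribute 8.
%res = tensor.pack %src padding_value(%pad : i32)
    inner_dims_pos = [0, 1] inner_tiles = [8, 1]
    into %dest : tensor<7x?xi32> -> tensor<1x1x8x1xi32>
```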