From fd48251127299ed9fc0b55d97e21fea94954b110 Mon Sep 17 00:00:00 2001 From: MaheshRavishankar Date: Thu, 21 Sep 2023 16:24:11 -0700 Subject: [PATCH 1/3] [mlir][TilingInterface] Add `scf::tileUsingSCFForallOp` method to tile using the interface to generate `scf::forall`. Similar to `scf::tileUsingSCFForOp` that is a method that tiles operations that implement the `TilingInterface`, using `scf.for` operations, this method introduces tiling of operations using `scf.forall`. Most of this implementation is derived from `linalg::tileToForallOp` method. Eventually that method will either be deprecated or moved to use the method introduced here. --- .../SCF/Transforms/TileUsingInterface.h | 17 +++ .../SCF/Transforms/TileUsingInterface.cpp | 133 ++++++++++++++++++ .../TilingInterface/tile-using-scfforall.mlir | 37 +++++ .../TilingInterface/TestTilingInterface.cpp | 69 +++++++++ 4 files changed, 256 insertions(+) create mode 100644 mlir/test/Interfaces/TilingInterface/tile-using-scfforall.mlir diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h index 9f49d97e141e0..06cce19894e9f 100644 --- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h +++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h @@ -51,6 +51,17 @@ struct SCFTilingOptions { interchangeVector = llvm::to_vector(interchange); return *this; } + + /// Specify mapping of loops to devices. This is only respected when the loop + /// constructs support such a mapping (like `scf.forall`). Will be ignored + /// when using loop constructs that dont support such a mapping (like + /// `scf.for`) + SmallVector mappingVector = {}; + SCFTilingOptions &setMapping(ArrayRef mapping) { + mappingVector = llvm::to_vector( + llvm::map_range(mapping, [](auto attr) -> Attribute { return attr; })); + return *this; + } }; /// Transformation information returned after tiling. @@ -82,6 +93,12 @@ struct SCFTileAndFuseOptions { } }; +/// Method to tile and op that implements the `TilingInterface` using +/// `scf.forall`. +FailureOr +tileUsingSCFForallOp(RewriterBase &rewriter, TilingInterface op, + const SCFTilingOptions &options); + /// Fuse the producer of the source of `candidateSliceOp` by computing the /// required slice of the producer in-place. Note that the method /// replaces the uses of `candidateSliceOp` with the tiled and fused producer diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp index 96d6169111b38..a58cd7a7541a5 100644 --- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp @@ -122,6 +122,24 @@ static OpFoldResult getBoundedTileSize(OpBuilder &b, Location loc, b, loc, minMap, SmallVector{iv, tileSize, size}); } +/// Clones the operation and updates the destination if the operation +/// implements the `DestinationStyleOpInterface`. +static Operation *cloneOpAndUpdateDestinationArgs(RewriterBase &rewriter, + Operation *op, + ValueRange newDestArgs) { + Operation *clonedOp = rewriter.clone(*op); + if (auto destinationStyleOp = + dyn_cast(clonedOp)) { + // Note that this is assuming that + auto [start, end] = destinationStyleOp.getDpsInitsPositionRange(); + assert((end - start == newDestArgs.size()) && + "expected as many new destination args as number of inits of the " + "operation"); + clonedOp->setOperands(start, end - start, newDestArgs); + } + return clonedOp; +} + /// Generate an empty loop nest that represents the tiled loop nest shell. /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space. /// - `tileSizes` is the tile sizes to use. Zero represent untiled loops. @@ -728,6 +746,121 @@ mlir::scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp( getAsOperations(forLoops), replacements}; } +//===----------------------------------------------------------------------===// +// tileUsingSCFForAllOp implementation. +//===----------------------------------------------------------------------===// + +FailureOr +mlir::scf::tileUsingSCFForallOp(RewriterBase &rewriter, TilingInterface op, + const scf::SCFTilingOptions &options) { + Location loc = op->getLoc(); + OpBuilder::InsertionGuard g(rewriter); + + // 1. Get the range of loops that are represented by the operation. + SmallVector loopRanges = op.getIterationDomain(rewriter); + if (loopRanges.empty()) + return op->emitOpError("expected non-empty loop ranges"); + auto hasStrideOne = [](Range r) { return !isConstantIntValue(r.stride, 1); }; + if (llvm::any_of(loopRanges, hasStrideOne)) + return op->emitOpError("only stride-1 supported atm"); + + // 2. Get the tile sizes. If tile size is 0, it is not tiled and distributed. + // To make it easier, pad the tile sizes to loopRanges.size with value 0. + SmallVector tileSizeVector = + options.tileSizeComputationFunction(rewriter, op); + tileSizeVector.resize(loopRanges.size(), rewriter.getIndexAttr(0)); + + // 3. Build the offsets, sizes and steps for the tile and distributed loops. + SmallVector lbs, ubs, steps; + for (auto [index, tileSize, loopRange] : + llvm::enumerate(tileSizeVector, loopRanges)) { + if (isConstantIntValue(tileSize, 0)) + continue; + lbs.push_back(loopRange.offset); + ubs.push_back(loopRange.size); + steps.push_back(tileSize); + } + + // 4. Gather destination tensors. + SmallVector dest; + if (failed(tensor::getOrCreateDestinations(rewriter, loc, op, dest))) + return op->emitOpError("failed to get destination tensors"); + + // 5. Build the device mapping attribute; + std::optional mappingAttr; + if (!options.mappingVector.empty()) { + mappingAttr = rewriter.getArrayAttr(ArrayRef(options.mappingVector)); + } + + // 6. Create the ForallOp. We don't use the lambda body-builder + // version because we require the use of RewriterBase in the body, so we + // manually move the insertion point to the body below. + auto forallOp = + rewriter.create(loc, lbs, ubs, steps, dest, mappingAttr); + + // 7. Get the tile offset and sizes. + rewriter.setInsertionPoint(forallOp.getTerminator()); + SmallVector tiledOffsets, tiledSizes; + tiledOffsets.reserve(loopRanges.size()); + tiledSizes.reserve(loopRanges.size()); + ValueRange ivs = forallOp.getInductionVars(); + { + int materializedLoopNum = 0; + for (auto [index, tileSize, loopRange] : + llvm::enumerate(tileSizeVector, loopRanges)) { + if (isConstantIntValue(tileSize, 0)) { + tiledOffsets.push_back(loopRange.offset); + tiledSizes.push_back(loopRange.size); + continue; + } + Value iv = ivs[materializedLoopNum++]; + tiledOffsets.push_back(iv); + tiledSizes.push_back( + getBoundedTileSize(rewriter, loc, loopRange, iv, tileSize)); + } + } + + // 8. Tile the operation. Clone the operation to allow fix up of destination + // operands + ArrayRef destBbArgs = forallOp.getOutputBlockArguments(); + Operation *clonedOp = + cloneOpAndUpdateDestinationArgs(rewriter, op, destBbArgs); + FailureOr tilingResult = + cast(clonedOp).getTiledImplementation( + rewriter, tiledOffsets, tiledSizes); + if (failed(tilingResult)) + return clonedOp->emitError("Failed to tile op: "); + rewriter.eraseOp(clonedOp); + + // 9. Parallel insert back into the result tensor. + for (auto [index, tiledValue, destBBArg] : + llvm::enumerate(tilingResult->tiledValues, destBbArgs)) { + // 9.a. Partial subset information is inserted just before the terminator. + rewriter.setInsertionPoint(forallOp.getTerminator()); + + SmallVector resultOffsets, resultSizes; + if (failed(op.getResultTilePosition(rewriter, index, tiledOffsets, + tiledSizes, resultOffsets, + resultSizes))) + return op->emitOpError("output offsets couldn't be calculated"); + SmallVector strides(resultSizes.size(), + rewriter.getIndexAttr(1)); + + // 5.b. Parallel insertions are inserted at the end of the combining + // terminator. + rewriter.setInsertionPointToEnd(forallOp.getTerminator().getBody()); + rewriter.create( + loc, tiledValue, destBBArg, resultOffsets, resultSizes, strides); + } + + // 10. Return the tiling result; + return scf::SCFTilingResult{ + tilingResult->tiledOps, + {forallOp.getOperation()}, + llvm::to_vector(llvm::map_range(forallOp.getResults(), + [](auto val) -> Value { return val; }))}; +} + //===----------------------------------------------------------------------===// // lowerToLoopsUsingSCFForOp implementation. //===----------------------------------------------------------------------===// diff --git a/mlir/test/Interfaces/TilingInterface/tile-using-scfforall.mlir b/mlir/test/Interfaces/TilingInterface/tile-using-scfforall.mlir new file mode 100644 index 0000000000000..bfc352c764ad1 --- /dev/null +++ b/mlir/test/Interfaces/TilingInterface/tile-using-scfforall.mlir @@ -0,0 +1,37 @@ +// RUN: mlir-opt -test-tiling-interface=tile-using-scf-forall -split-input-file %s | FileCheck %s + +func.func @simple_matmul(%arg0 : tensor, %arg1 : tensor, + %arg2 : tensor) -> tensor { + %0 = linalg.matmul {__internal_transform__ = "simple_gemm"} + ins(%arg0, %arg1 : tensor, tensor) + outs(%arg2 : tensor) -> tensor + return %0 : tensor +} +// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0] -> (10, -d0 + s0)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0] -> (20, -d0 + s0)> +// CHECK: func.func @simple_matmul( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor +// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index +// CHECK-DAG: %[[M:.+]] = tensor.dim %[[ARG0]], %[[C0]] +// CHECK-DAG: %[[K:.+]] = tensor.dim %[[ARG0]], %[[C1]] +// CHECK-DAG: %[[N:.+]] = tensor.dim %[[ARG1]], %[[C1]] +// CHECK: %[[RESULT:.+]] = scf.forall (%[[IV0:[a-zA-Z0-9]+]], %[[IV1:[a-zA-Z0-9]+]]) = +// CHECK-SAME: (0, 0) to (%[[M]], %[[N]]) step (10, 20) shared_outs(%[[INIT:.+]] = %[[ARG2]]) +// CHECK: %[[TS_Y:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[M]]] +// CHECK: %[[TS_X:.+]] = affine.min #[[MAP1]](%[[IV1]])[%[[N]]] +// CHECK: %[[LHS_TILE:.+]] = tensor.extract_slice %[[ARG0]] +// CHECK-SAME: [%[[IV0]], 0] [%[[TS_Y]], %[[K]]] [1, 1] +// CHECK: %[[RHS_TILE:.+]] = tensor.extract_slice %[[ARG1]] +// CHECK-SAME: [0, %[[IV1]]] [%[[K]], %[[TS_X]]] [1, 1] +// CHECK: %[[INIT_TILE:.+]] = tensor.extract_slice %[[INIT]] +// CHECK-SAME: [%[[IV0]], %[[IV1]]] [%[[TS_Y]], %[[TS_X]]] [1, 1] +// CHECK: %[[GEMM_TILE:.+]] = linalg.matmul +// CHECK-SAME: ins(%[[LHS_TILE]], %[[RHS_TILE]] : +// CHECK-SAME: outs(%[[INIT_TILE]] : +// CHECK: scf.forall.in_parallel { +// CHECK: tensor.parallel_insert_slice %[[GEMM_TILE]] into %[[INIT]] +// CHECK-SAME: [%[[IV0]], %[[IV1]]] [%[[TS_Y]], %[[TS_X]]] [1, 1] +// CHECK: return %[[RESULT]] diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp index 2573e11979dbc..2bec859b50f26 100644 --- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp +++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp @@ -186,6 +186,51 @@ struct TestTileUsingSCFForOp TransformationFilter filter; }; +/// Pattern for testing `tileUsingSCFForallOp` (that tiles operations using +/// the `TilingInterface` with `scf.forall` ops for iterating over the tiles) +/// while using a `filter` to avoid recursive application. +struct TestTileUsingSCFForallOp + : public OpInterfaceRewritePattern { + TestTileUsingSCFForallOp(MLIRContext *context, scf::SCFTilingOptions options, + TransformationFilter filter = TransformationFilter(), + PatternBenefit benefit = 1) + : OpInterfaceRewritePattern(context, benefit), + options(std::move(options)), filter(std::move(filter)) {} + + /// Construct a generic pattern applied to `opName`. + TestTileUsingSCFForallOp(StringRef opName, MLIRContext *context, + scf::SCFTilingOptions options, + TransformationFilter filter = TransformationFilter(), + PatternBenefit benefit = 1) + : OpInterfaceRewritePattern(context, benefit), + options(std::move(options)), filter(std::move(filter)) {} + + LogicalResult matchAndRewrite(TilingInterface op, + PatternRewriter &rewriter) const override { + if (failed(filter.checkAndNotify(rewriter, op))) + return failure(); + + FailureOr tilingResult = + scf::tileUsingSCFForallOp(rewriter, op, options); + if (failed(tilingResult)) + return rewriter.notifyMatchFailure(op, "failed to tile operation"); + + if (op->getNumResults()) { + rewriter.replaceOp(op, tilingResult->replacements); + } else { + rewriter.eraseOp(op); + } + + for (auto *tiledOp : tilingResult->tiledOps) + filter.replaceTransformationFilter(rewriter, tiledOp); + return success(); + } + +private: + scf::SCFTilingOptions options; + TransformationFilter filter; +}; + /// Pattern for testing `TileConsumerAndFuseProducersUsingSCFForOp` pattern /// (that tiles and fuses operations using the `TilingInterface` with `scf.for` /// ops for iterating over the tiles) while using a `filter` to avoid recursive @@ -415,6 +460,12 @@ struct TestTilingInterfacePass "Test tiling using TilingInterface with scf.for operations"), llvm::cl::init(false)}; + Option testTilingForAll{ + *this, "tile-using-scf-forall", + llvm::cl::desc( + "Test tiling using TilingInterface with scf.forall operations"), + llvm::cl::init(false)}; + Option testTileConsumerFuseAndYieldProducer{ *this, "tile-consumer-fuse-and-yield-producer-using-scf-for", llvm::cl::desc( @@ -455,6 +506,20 @@ static void addPatternForTiling(MLIRContext *context, patterns.add(context, tilingOptions, filter); } +static void addPatternForTilingUsingForall(MLIRContext *context, + RewritePatternSet &patterns, + StringRef filterName, + ArrayRef tileSizes, + ArrayRef interchange = {}) { + scf::SCFTilingOptions tilingOptions; + SmallVector tileSizesOfr = + getAsIndexOpFoldResult(context, tileSizes); + tilingOptions.setTileSizes(tileSizesOfr).setInterchange(interchange); + TransformationFilter filter(StringAttr::get(context, filterName), + StringAttr::get(context, "tiled")); + patterns.add(context, tilingOptions, filter); +} + static void addPatternForTileFuseAndYield(MLIRContext *context, RewritePatternSet &patterns, StringRef filterName, @@ -514,6 +579,10 @@ void TestTilingInterfacePass::addTestPatterns(MLIRContext *context, addPatternForTiling(context, patterns, "simple_copy_memref", {10, 20}); return; } + if (testTilingForAll) { + addPatternForTilingUsingForall(context, patterns, "simple_gemm", {10, 20}); + return; + } if (testTileConsumerAndFuseProducer) { // 1. Tile and fuse of gemm with fill producer and bias-add consumer. addPatternForTileAndFuse(context, patterns, "fusion", {10, 20}); From d1f1103792c206c65b756bb3b935c528a5485681 Mon Sep 17 00:00:00 2001 From: MaheshRavishankar Date: Wed, 18 Oct 2023 16:20:39 -0700 Subject: [PATCH 2/3] Add lit tests. --- .../SCF/Transforms/TileUsingInterface.cpp | 11 +- .../TilingInterface/tile-using-scfforall.mlir | 133 +++++++++++++++++- .../TilingInterface/TestTilingInterface.cpp | 10 ++ 3 files changed, 144 insertions(+), 10 deletions(-) diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp index a58cd7a7541a5..a45918eb062ee 100644 --- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp @@ -101,10 +101,10 @@ static bool tileDividesIterationDomain(Range loopRange) { /// `tileSize`, i.e., `min(tileSize, range.end() - iv)`. static OpFoldResult getBoundedTileSize(OpBuilder &b, Location loc, Range loopRange, Value iv, - Value tileSize) { + OpFoldResult tileSize) { std::optional ts = getConstantIntValue(tileSize); if (ts && ts.value() == 1) - return getAsOpFoldResult(tileSize); + return tileSize; if (tileDividesIterationDomain( Range{loopRange.offset, loopRange.size, tileSize})) @@ -130,12 +130,7 @@ static Operation *cloneOpAndUpdateDestinationArgs(RewriterBase &rewriter, Operation *clonedOp = rewriter.clone(*op); if (auto destinationStyleOp = dyn_cast(clonedOp)) { - // Note that this is assuming that - auto [start, end] = destinationStyleOp.getDpsInitsPositionRange(); - assert((end - start == newDestArgs.size()) && - "expected as many new destination args as number of inits of the " - "operation"); - clonedOp->setOperands(start, end - start, newDestArgs); + destinationStyleOp.getDpsInitsMutable().assign(newDestArgs); } return clonedOp; } diff --git a/mlir/test/Interfaces/TilingInterface/tile-using-scfforall.mlir b/mlir/test/Interfaces/TilingInterface/tile-using-scfforall.mlir index bfc352c764ad1..709ecb6a97e3c 100644 --- a/mlir/test/Interfaces/TilingInterface/tile-using-scfforall.mlir +++ b/mlir/test/Interfaces/TilingInterface/tile-using-scfforall.mlir @@ -1,9 +1,9 @@ -// RUN: mlir-opt -test-tiling-interface=tile-using-scf-forall -split-input-file %s | FileCheck %s +// RUN: mlir-opt -test-tiling-interface=tile-using-scf-forall -split-input-file %s | FileCheck %s func.func @simple_matmul(%arg0 : tensor, %arg1 : tensor, %arg2 : tensor) -> tensor { %0 = linalg.matmul {__internal_transform__ = "simple_gemm"} - ins(%arg0, %arg1 : tensor, tensor) + ins(%arg0, %arg1 : tensor, tensor) outs(%arg2 : tensor) -> tensor return %0 : tensor } @@ -35,3 +35,132 @@ func.func @simple_matmul(%arg0 : tensor, %arg1 : tensor, // CHECK: tensor.parallel_insert_slice %[[GEMM_TILE]] into %[[INIT]] // CHECK-SAME: [%[[IV0]], %[[IV1]]] [%[[TS_Y]], %[[TS_X]]] [1, 1] // CHECK: return %[[RESULT]] + +// ----- + +#map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +#map1 = affine_map<(d0, d1, d2) -> (d0, d2, d1)> +#map2 = affine_map<(d0, d1, d2) -> (d2, d0, d1)> +func.func @multi_result(%arg0 : tensor<128x200x300xf32>) -> (tensor<128x300x200xf32>, tensor<300x128x200xf32>) { + %init0 = tensor.empty() : tensor<128x300x200xf32> + %init1 = tensor.empty() : tensor<300x128x200xf32> + %0:2 = linalg.generic { + indexing_maps = [#map0, #map1, #map2], + iterator_types = ["parallel", "parallel", "parallel"]} + {__internal_transform__ = "parallel_generic_transpose"} + ins(%arg0 : tensor<128x200x300xf32>) + outs(%init0, %init1 : tensor<128x300x200xf32>, tensor<300x128x200xf32>) { + ^bb0(%b0 : f32, %b1 : f32, %b2 : f32): + linalg.yield %b0, %b0 : f32, f32 + } -> (tensor<128x300x200xf32>, tensor<300x128x200xf32>) + return %0#0, %0#1 : tensor<128x300x200xf32>, tensor<300x128x200xf32> +} +// CHECK-DAG: #[[$MAP0:.+]] = affine_map<(d0) -> (10, -d0 + 128)> +// CHECK-LABEL: func.func @multi_result( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<128x200x300xf32>) +// CHECK-DAG: %[[INIT0:.+]] = tensor.empty() +// CHECK-DAG: %[[INIT1:.+]] = tensor.empty() +// CHECK: %[[OUTER:[a-zA-Z0-9]+]]:2 = scf.forall (%[[IV0:[a-zA-Z0-9]+]], %[[IV1:[a-zA-Z0-9]+]]) = (0, 0) to (128, 300) step (10, 20) +// CHECK-SAME: shared_outs(%[[ARG1:[a-zA-Z0-9]+]] = %[[INIT0]], %[[ARG2:[a-zA-Z0-9]+]] = %[[INIT1]]) +// CHECK: %[[TS_Y:.+]] = affine.min #[[$MAP0]](%[[IV0]]) +// CHECK: %[[ARG_TILE:.+]] = tensor.extract_slice %[[ARG0]] +// CHECK-SAME: [%[[IV0]], 0, %[[IV1]]] [%[[TS_Y]], 200, 20] [1, 1, 1] +// CHECK-DAG: %[[INIT0_TILE:.+]] = tensor.extract_slice %[[ARG1]] +// CHECK-SAME: [%[[IV0]], %[[IV1]], 0] [%[[TS_Y]], 20, 200] [1, 1, 1] +// CHECK-DAG: %[[INIT1_TILE:.+]] = tensor.extract_slice %[[ARG2]] +// CHECK-SAME: [%[[IV1]], %[[IV0]], 0] [20, %[[TS_Y]], 200] [1, 1, 1] +// CHECK: %[[RESULT_TILE:.+]]:2 = linalg.generic +// CHECK-SAME: ins(%[[ARG_TILE]] : +// CHECK-SAME: outs(%[[INIT0_TILE]], %[[INIT1_TILE]] : +// CHECK: scf.forall.in_parallel { +// CHECK-DAG: tensor.parallel_insert_slice %[[RESULT_TILE]]#0 into %[[ARG1]][%[[IV0]], %[[IV1]], 0] [%[[TS_Y]], 20, 200] [1, 1, 1] +// CHECK-DAG: tensor.parallel_insert_slice %[[RESULT_TILE]]#1 into %[[ARG2]][%[[IV1]], %[[IV0]], 0] [20, %[[TS_Y]], 200] [1, 1, 1] +// CHECK: } +// CHECK: return %[[OUTER]]#0, %[[OUTER]]#1 + +// ----- + +func.func @conv2D(%arg0 : tensor, %arg1 : tensor, + %arg2 : tensor) -> tensor { + %0 = linalg.conv_2d_nhwc_hwcf { + strides = dense<[2, 3]> : tensor<2xi64>, + dilation = dense<[4, 5]> : tensor<2xi64>, + __internal_transform__ = "simple_conv"} + ins(%arg0, %arg1 : tensor, tensor) + outs(%arg2 : tensor) -> tensor + return %0 : tensor +} +// CHECK-DAG: #[[$MAP0:.+]] = affine_map<(d0)[s0] -> (10, -d0 + s0)> +// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0)[s0] -> (20, -d0 + s0)> +// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0)[s0] -> (30, -d0 + s0)> +// CHECK-DAG: #[[$MAP3:.+]] = affine_map<(d0)[s0] -> (d0 + s0 * 2 - 2)> +// CHECK-DAG: #[[$MAP4:.+]] = affine_map<(d0)[s0] -> (d0 + s0 * 3 - 3)> +// CHECK-LABEL: func.func @conv2D( +// CHECK-SAME: %[[INPUT:[a-zA-Z0-9]+]]: tensor +// CHECK-SAME: %[[FILTER:[a-zA-Z0-9]+]]: tensor +// CHECK-SAME: %[[INIT:[a-zA-Z0-9]+]]: tensor +// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index +// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index +// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index +// CHECK-DAG: %[[N:.+]] = tensor.dim %[[INPUT]], %[[C0]] +// CHECK-DAG: %[[C:.+]] = tensor.dim %[[INPUT]], %[[C3]] +// CHECK-DAG: %[[P:.+]] = tensor.dim %[[FILTER]], %[[C0]] +// CHECK-DAG: %[[Q:.+]] = tensor.dim %[[FILTER]], %[[C1]] +// CHECK-DAG: %[[F:.+]] = tensor.dim %[[FILTER]], %[[C3]] +// CHECK-DAG: %[[R:.+]] = tensor.dim %[[INIT]], %[[C1]] +// CHECK-DAG: %[[S:.+]] = tensor.dim %[[INIT]], %[[C2]] +// CHECK: %[[RESULT:.+]] = scf.forall (%[[IV0:[a-zA-Z0-9]+]], %[[IV1:[a-zA-Z0-9]+]], %[[IV2:[a-zA-Z0-9]+]]) = +// CHECK-SAME: (0, 0, 0) to (%[[P]], %[[Q]], %[[C]]) step (10, 20, 30) shared_outs(%[[INIT0:.+]] = %[[INIT]]) +// CHECK-DAG: %[[TS_P:.+]] = affine.min #[[$MAP0]](%[[IV0]])[%[[P]]] +// CHECK-DAG: %[[TS_Q:.+]] = affine.min #[[$MAP1]](%[[IV1]])[%[[Q]]] +// CHECK-DAG: %[[TS_C:.+]] = affine.min #[[$MAP2]](%[[IV2]])[%[[C]]] +// CHECK-DAG: %[[TS_H:.+]] = affine.apply #[[$MAP3]](%[[TS_P]])[%[[R]]] +// CHECK-DAG: %[[TS_W:.+]] = affine.apply #[[$MAP4]](%[[TS_Q]])[%[[S]]] +// CHECK-DAG: %[[INPUT_TILE:.+]] = tensor.extract_slice %[[INPUT]] +// CHECK-SAME: [0, %[[IV0]], %[[IV1]], %[[IV2]]] [%[[N]], %[[TS_H]], %[[TS_W]], %[[TS_C]]] +// CHECK-DAG: %[[FILTER_TILE:.+]] = tensor.extract_slice %[[FILTER]] +// CHECK-SAME: [%[[IV0]], %[[IV1]], %[[IV2]], 0] [%[[TS_P]], %[[TS_Q]], %[[TS_C]], %[[F]]] +// CHECK-DAG: %[[INIT_TILE:.+]] = tensor.extract_slice %[[INIT0]] +// CHECK-SAME: [0, 0, 0, 0] [%[[N]], %[[R]], %[[S]], %[[F]]] +// CHECK: %[[CONV_TILE:.+]] = linalg.conv_2d_nhwc_hwcf +// CHECK-SAME: dilation = dense<[4, 5]> : tensor<2xi64>, strides = dense<[2, 3]> : tensor<2xi64> +// CHECK-SAME: ins(%[[INPUT_TILE]], %[[FILTER_TILE]] : +// CHECK-SAME: outs(%[[INIT_TILE]] : +// CHECK: scf.forall.in_parallel +// CHECK: tensor.parallel_insert_slice %[[CONV_TILE]] into %[[INIT0]] +// CHECK-SAME: [0, 0, 0, 0] [%[[N]], %[[R]], %[[S]], %[[F]]] [1, 1, 1, 1] +// CHECK: return %[[RESULT]] + +// ----- + +// CHECK: #[[$MAP_ADD:.+]] = affine_map<(d0, d1) -> (d0 + d1)> + +func.func @indexed_semantics(%arg0: tensor, %arg1: tensor) -> tensor { + // Check that we correctly amend "linalg.index" results. + + %0 = linalg.generic { + indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> (d0, d1)>], + iterator_types = ["parallel", "parallel"]} + {__internal_transform__ = "indexed_semantics"} + ins(%arg0: tensor) + outs(%arg1: tensor) { + ^bb0(%arg2: f32, %arg3: f32): + %1 = linalg.index 0 : index + %2 = linalg.index 1 : index + %3 = arith.addi %1, %2 : index + %4 = arith.index_cast %3 : index to i64 + %5 = arith.uitofp %4 : i64 to f32 + %6 = arith.addf %5, %arg2 : f32 + linalg.yield %6 : f32 + } -> (tensor) + return %0 : tensor +} +// CHECK-LABEL: @indexed_semantics +// CHECK: scf.forall (%[[I0:.+]], %[[I1:.+]]) = +// CHECK: %[[INDEX0:.+]] = linalg.index 0 +// CHECK: %[[INDEX0_AMENDED:.+]] = affine.apply #[[$MAP_ADD]](%[[INDEX0]], %[[I0]]) +// CHECK: %[[INDEX1:.+]] = linalg.index 1 +// CHECK: %[[INDEX1_AMENDED:.+]] = affine.apply #[[$MAP_ADD]](%[[INDEX1]], %[[I1]]) +// CHECK: arith.addi %[[INDEX0_AMENDED]], %[[INDEX1_AMENDED]] diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp index 2bec859b50f26..04632567ee2a7 100644 --- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp +++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp @@ -580,7 +580,17 @@ void TestTilingInterfacePass::addTestPatterns(MLIRContext *context, return; } if (testTilingForAll) { + // 1. Tiling M and N dims of `linalg.matmul` on tensors. addPatternForTilingUsingForall(context, patterns, "simple_gemm", {10, 20}); + // 2. Tiling 3D parallel generic op which implements a transpose. + addPatternForTilingUsingForall(context, patterns, + "parallel_generic_transpose", {10, 0, 20}); + // 3. Tiling 2D conv op. + addPatternForTilingUsingForall(context, patterns, "simple_conv", + {0, 0, 0, 0, 10, 20, 30}); + // 4. Tiling a simple op with `linalg.index` inside. + addPatternForTilingUsingForall(context, patterns, "indexed_semantics", + {10, 20}); return; } if (testTileConsumerAndFuseProducer) { From 55f9518c728af5be1cef142e592ad612438e0edc Mon Sep 17 00:00:00 2001 From: MaheshRavishankar Date: Thu, 19 Oct 2023 23:00:35 -0700 Subject: [PATCH 3/3] Address comments. --- .../SCF/Transforms/TileUsingInterface.h | 6 ++--- .../SCF/Transforms/TileUsingInterface.cpp | 27 +++++++++---------- .../TilingInterface/tile-using-scfforall.mlir | 1 + .../TilingInterface/TestTilingInterface.cpp | 23 +++++++++------- 4 files changed, 30 insertions(+), 27 deletions(-) diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h index 06cce19894e9f..81325b62791c4 100644 --- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h +++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h @@ -58,8 +58,8 @@ struct SCFTilingOptions { /// `scf.for`) SmallVector mappingVector = {}; SCFTilingOptions &setMapping(ArrayRef mapping) { - mappingVector = llvm::to_vector( - llvm::map_range(mapping, [](auto attr) -> Attribute { return attr; })); + mappingVector = llvm::map_to_vector( + mapping, [](auto attr) -> Attribute { return attr; }); return *this; } }; @@ -93,7 +93,7 @@ struct SCFTileAndFuseOptions { } }; -/// Method to tile and op that implements the `TilingInterface` using +/// Method to tile an op that implements the `TilingInterface` using /// `scf.forall`. FailureOr tileUsingSCFForallOp(RewriterBase &rewriter, TilingInterface op, diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp index a45918eb062ee..2c6e66de6dc60 100644 --- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp @@ -767,8 +767,7 @@ mlir::scf::tileUsingSCFForallOp(RewriterBase &rewriter, TilingInterface op, // 3. Build the offsets, sizes and steps for the tile and distributed loops. SmallVector lbs, ubs, steps; - for (auto [index, tileSize, loopRange] : - llvm::enumerate(tileSizeVector, loopRanges)) { + for (auto [tileSize, loopRange] : llvm::zip(tileSizeVector, loopRanges)) { if (isConstantIntValue(tileSize, 0)) continue; lbs.push_back(loopRange.offset); @@ -781,7 +780,7 @@ mlir::scf::tileUsingSCFForallOp(RewriterBase &rewriter, TilingInterface op, if (failed(tensor::getOrCreateDestinations(rewriter, loc, op, dest))) return op->emitOpError("failed to get destination tensors"); - // 5. Build the device mapping attribute; + // 5. Build the device mapping attribute. std::optional mappingAttr; if (!options.mappingVector.empty()) { mappingAttr = rewriter.getArrayAttr(ArrayRef(options.mappingVector)); @@ -796,13 +795,10 @@ mlir::scf::tileUsingSCFForallOp(RewriterBase &rewriter, TilingInterface op, // 7. Get the tile offset and sizes. rewriter.setInsertionPoint(forallOp.getTerminator()); SmallVector tiledOffsets, tiledSizes; - tiledOffsets.reserve(loopRanges.size()); - tiledSizes.reserve(loopRanges.size()); ValueRange ivs = forallOp.getInductionVars(); { int materializedLoopNum = 0; - for (auto [index, tileSize, loopRange] : - llvm::enumerate(tileSizeVector, loopRanges)) { + for (auto [tileSize, loopRange] : llvm::zip(tileSizeVector, loopRanges)) { if (isConstantIntValue(tileSize, 0)) { tiledOffsets.push_back(loopRange.offset); tiledSizes.push_back(loopRange.size); @@ -816,7 +812,7 @@ mlir::scf::tileUsingSCFForallOp(RewriterBase &rewriter, TilingInterface op, } // 8. Tile the operation. Clone the operation to allow fix up of destination - // operands + // operands. ArrayRef destBbArgs = forallOp.getOutputBlockArguments(); Operation *clonedOp = cloneOpAndUpdateDestinationArgs(rewriter, op, destBbArgs); @@ -824,7 +820,7 @@ mlir::scf::tileUsingSCFForallOp(RewriterBase &rewriter, TilingInterface op, cast(clonedOp).getTiledImplementation( rewriter, tiledOffsets, tiledSizes); if (failed(tilingResult)) - return clonedOp->emitError("Failed to tile op: "); + return clonedOp->emitError("failed to tile op: "); rewriter.eraseOp(clonedOp); // 9. Parallel insert back into the result tensor. @@ -836,24 +832,25 @@ mlir::scf::tileUsingSCFForallOp(RewriterBase &rewriter, TilingInterface op, SmallVector resultOffsets, resultSizes; if (failed(op.getResultTilePosition(rewriter, index, tiledOffsets, tiledSizes, resultOffsets, - resultSizes))) + resultSizes))) { return op->emitOpError("output offsets couldn't be calculated"); + } + SmallVector strides(resultSizes.size(), rewriter.getIndexAttr(1)); - - // 5.b. Parallel insertions are inserted at the end of the combining + // 9.b. Parallel insertions are inserted at the end of the combining // terminator. rewriter.setInsertionPointToEnd(forallOp.getTerminator().getBody()); rewriter.create( loc, tiledValue, destBBArg, resultOffsets, resultSizes, strides); } - // 10. Return the tiling result; + // 10. Return the tiling result. return scf::SCFTilingResult{ tilingResult->tiledOps, {forallOp.getOperation()}, - llvm::to_vector(llvm::map_range(forallOp.getResults(), - [](auto val) -> Value { return val; }))}; + llvm::map_to_vector(forallOp.getResults(), + [](auto val) -> Value { return val; })}; } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Interfaces/TilingInterface/tile-using-scfforall.mlir b/mlir/test/Interfaces/TilingInterface/tile-using-scfforall.mlir index 709ecb6a97e3c..314efde45720a 100644 --- a/mlir/test/Interfaces/TilingInterface/tile-using-scfforall.mlir +++ b/mlir/test/Interfaces/TilingInterface/tile-using-scfforall.mlir @@ -34,6 +34,7 @@ func.func @simple_matmul(%arg0 : tensor, %arg1 : tensor, // CHECK: scf.forall.in_parallel { // CHECK: tensor.parallel_insert_slice %[[GEMM_TILE]] into %[[INIT]] // CHECK-SAME: [%[[IV0]], %[[IV1]]] [%[[TS_Y]], %[[TS_X]]] [1, 1] +// CHECK: mapping = [#gpu.block, #gpu.block] // CHECK: return %[[RESULT]] // ----- diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp index 04632567ee2a7..e5d7dc54409e4 100644 --- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp +++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp @@ -16,6 +16,7 @@ #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" @@ -443,9 +444,9 @@ struct TestTilingInterfacePass TestTilingInterfacePass(const TestTilingInterfacePass &pass) : PassWrapper(pass) {} void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); + registry.insert(); linalg::registerTilingInterfaceExternalModels(registry); tensor::registerTilingInterfaceExternalModels(registry); } @@ -506,15 +507,16 @@ static void addPatternForTiling(MLIRContext *context, patterns.add(context, tilingOptions, filter); } -static void addPatternForTilingUsingForall(MLIRContext *context, - RewritePatternSet &patterns, - StringRef filterName, - ArrayRef tileSizes, - ArrayRef interchange = {}) { +static void addPatternForTilingUsingForall( + MLIRContext *context, RewritePatternSet &patterns, StringRef filterName, + ArrayRef tileSizes, + ArrayRef mapping = {}, + ArrayRef interchange = {}) { scf::SCFTilingOptions tilingOptions; SmallVector tileSizesOfr = getAsIndexOpFoldResult(context, tileSizes); tilingOptions.setTileSizes(tileSizesOfr).setInterchange(interchange); + tilingOptions.setMapping(mapping); TransformationFilter filter(StringAttr::get(context, filterName), StringAttr::get(context, "tiled")); patterns.add(context, tilingOptions, filter); @@ -581,7 +583,10 @@ void TestTilingInterfacePass::addTestPatterns(MLIRContext *context, } if (testTilingForAll) { // 1. Tiling M and N dims of `linalg.matmul` on tensors. - addPatternForTilingUsingForall(context, patterns, "simple_gemm", {10, 20}); + addPatternForTilingUsingForall( + context, patterns, "simple_gemm", {10, 20}, + {gpu::GPUBlockMappingAttr::get(context, gpu::MappingId::DimY), + gpu::GPUBlockMappingAttr::get(context, gpu::MappingId::DimX)}); // 2. Tiling 3D parallel generic op which implements a transpose. addPatternForTilingUsingForall(context, patterns, "parallel_generic_transpose", {10, 0, 20});