Skip to content

Commit

Permalink
[mlir] Fix bug in UnPackOp tiling implementation causing infinite loop (#113571)
Browse files Browse the repository at this point in the history

This fixes a bug in the tiling implementation of tensor.unpack that was
causing an infinite loop when certain unpack ops get tiled and fused as
a producer. The tiled implementation of tensor.unpack sometimes needs to
create an additional tensor.extract_slice on the result of the tiled
unpack op, but this slice was getting added to the `generatedSlices` of
the tiling result. The `generatedSlices` are used to find the next
producers to fuse, so it caused an infinite loop of fusing the same
unpack op after it was already in the loop. This fixes the bug by adding
the slice of the source to `generatedSlices` instead of the slice of the
result.

Signed-off-by: Max Dawkins <max.dawkins@gmail.com>
  • Loading branch information
Max191 authored Oct 25, 2024
1 parent 6ee5ff9 commit f1595ec
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 4 deletions.
8 changes: 4 additions & 4 deletions mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -554,13 +554,14 @@ struct UnPackOpTiling
sliceSrcIndices.append(numInnerTiles, zeroAttr);
sliceSrcSizes.append(unpackOp.getMixedTiles());
sliceSrcStrides.append(numInnerTiles, oneAttr);
Value sliceSource =
SmallVector<Operation *> generatedSlices;
ExtractSliceOp sliceSource =
b.create<ExtractSliceOp>(loc, unpackOp.getSource(), sliceSrcIndices,
sliceSrcSizes, sliceSrcStrides);
generatedSlices.push_back(sliceSource);

SmallVector<OpFoldResult> destStrides(destRank, oneAttr);
Value sliceDest;
SmallVector<Operation *> generatedSlices;
if (isPerfectTilingCase) {
auto destSliceOp = b.create<ExtractSliceOp>(loc, unpackOp.getDest(),
offsets, sizes, destStrides);
Expand All @@ -571,7 +572,7 @@ struct UnPackOpTiling
unpackOp.getDestType().getElementType());
}

SmallVector<Value> tiledOperands = {sliceSource, sliceDest};
SmallVector<Value> tiledOperands = {sliceSource.getResult(), sliceDest};
for (auto tile : unpackOp.getInnerTiles())
tiledOperands.push_back(tile);

Expand All @@ -586,7 +587,6 @@ struct UnPackOpTiling
auto extractSlice =
b.create<ExtractSliceOp>(loc, tiledUnpackOp->getResult(0),
resultOffsetsFromDest, sizes, destStrides);
generatedSlices.push_back(extractSlice);
return TilingResult{
{tiledUnpackOp}, {extractSlice.getResult()}, generatedSlices};
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -587,3 +587,50 @@ module attributes {transform.with_named_sequence} {
// CHECK: %[[INSERT_SLICE:.+]] = tensor.insert_slice %[[IF_RESULT]]
// CHECK: scf.yield %[[INSERT_SLICE]]
// CHECK: return %[[FOR_RESULT]]

// -----

// Regression test: fuse a tensor.unpack as a producer when the unpack does not
// tile the destination perfectly. Dim 1 of the source has outer size 1 and inner
// tile 8 (1 * 8 = 8), but the destination extent on that dim is only 2, while
// dim 2 matches exactly (288 * 4 = 1152). Per the commit description, tiling and
// fusing such an unpack used to loop forever.
func.func @imperfect_unpack_producer_fusion(%source: tensor<1x1x288x8x4xf32>, %dest: tensor<1x2x1152xf32>) -> tensor<1x2x1152xf32> {
// Producer: the unpack op that gets fused into the consumer's tiled loop.
%0 = tensor.unpack %source
outer_dims_perm = [0, 1, 2]
inner_dims_pos = [1, 2]
inner_tiles = [8, 4] into %dest
: tensor<1x1x288x8x4xf32> -> tensor<1x2x1152xf32>
%1 = tensor.empty() : tensor<1x2x1152xf32>
%cst = arith.constant 1.0 : f32
// Consumer: elementwise generic (identity indexing maps, all-parallel iterators)
// that adds 1.0 to every element of the unpacked tensor.
%2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
iterator_types = ["parallel", "parallel", "parallel"]}
ins(%0 : tensor<1x2x1152xf32>)
outs(%1 : tensor<1x2x1152xf32>) {
^bb0(%in: f32, %out: f32):
%7 = arith.addf %in, %cst : f32
linalg.yield %7 : f32
} -> tensor<1x2x1152xf32>
return %2 : tensor<1x2x1152xf32>
}

// Transform script driving the test: match the linalg.generic consumer and run
// transform.structured.fuse on it.
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
// NOTE(review): despite its name, %matmul is bound to the matched
// linalg.generic op; this payload contains no matmul.
%matmul = transform.structured.match ops{["linalg.generic"]} in %arg1
: (!transform.any_op) -> !transform.any_op
// Presumably [0, 1, 0] are per-dimension tile sizes, with 0 leaving a dimension
// untiled, so only dim 1 is tiled (size 1) -- confirm against the
// transform.structured.fuse documentation. The two results are unused here.
%a, %b = transform.structured.fuse %matmul [0, 1, 0]
: (!transform.any_op) -> (!transform.any_op, !transform.any_op)
transform.yield
}
}

// CHECK-LABEL: func @imperfect_unpack_producer_fusion
// CHECK-SAME: %[[ARG0:.+]]: tensor<1x1x288x8x4xf32>
// CHECK-SAME: %[[ARG1:.+]]: tensor<1x2x1152xf32>
// CHECK: %[[FOR_RESULT:.+]] = scf.for{{.*}}iter_args(%[[ITER_ARG:.+]] = {{.*}})
// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[ARG0]]
// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[SLICE]]
// CHECK-DAG: %[[UNPACK_SLICE:.+]] = tensor.extract_slice %[[UNPACK]]
// CHECK-DAG: %[[INIT_SLICE:.+]] = tensor.extract_slice %[[ITER_ARG]]
// CHECK: %[[GENERIC:.+]] = linalg.generic
// CHECK-SAME: ins(%[[UNPACK_SLICE]]
// CHECK-SAME: outs(%[[INIT_SLICE]]
// CHECK: %[[INSERT_SLICE:.+]] = tensor.insert_slice %[[GENERIC]] into %[[ITER_ARG]]
// CHECK: scf.yield %[[INSERT_SLICE]]
// CHECK: return %[[FOR_RESULT]]

0 comments on commit f1595ec

Please sign in to comment.