
Commit bce951c

[mlir][linalg] Update vectorization logic for linalg.unpack (#149156)
This PR makes sure that we don't generate an unnecessary `tensor.empty` when vectorizing `linalg.unpack`.

To better visualize the changes implemented here, consider this IR:

```mlir
func.func @example(
    %source: tensor<8x4x16x16xf32>,
    %dest: tensor<64x127xf32>) -> tensor<64x127xf32> {
  %res = linalg.unpack %source
    outer_dims_perm = [1, 0]
    inner_dims_pos = [0, 1]
    inner_tiles = [16, 16]
    into %dest : tensor<8x4x16x16xf32> -> tensor<64x127xf32>
  return %res : tensor<64x127xf32>
}
```

Below is the output after vectorization, BEFORE and AFTER this PR.

BEFORE (note the `tensor.empty` and the fact that `%arg1` is not used):

```mlir
func.func @example(%arg0: tensor<8x4x16x16xf32>, %arg1: tensor<64x127xf32>) -> tensor<64x127xf32> {
  %cst = arith.constant 0.000000e+00 : f32
  %c0 = arith.constant 0 : index
  %0 = vector.transfer_read %arg0[%c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true]} : tensor<8x4x16x16xf32>, vector<8x4x16x16xf32>
  %1 = vector.transpose %0, [1, 2, 0, 3] : vector<8x4x16x16xf32> to vector<4x16x8x16xf32>
  %2 = vector.shape_cast %1 : vector<4x16x8x16xf32> to vector<64x128xf32>
  %3 = tensor.empty() : tensor<64x127xf32>
  %c0_0 = arith.constant 0 : index
  %4 = vector.transfer_write %2, %3[%c0_0, %c0_0] {in_bounds = [true, false]} : vector<64x128xf32>, tensor<64x127xf32>
  return %4 : tensor<64x127xf32>
}
```

AFTER (note that `%arg1` is correctly used):

```mlir
func.func @example(%arg0: tensor<8x4x16x16xf32>, %arg1: tensor<64x127xf32>) -> tensor<64x127xf32> {
  %cst = arith.constant 0.000000e+00 : f32
  %c0 = arith.constant 0 : index
  %0 = vector.transfer_read %arg0[%c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true]} : tensor<8x4x16x16xf32>, vector<8x4x16x16xf32>
  %1 = vector.transpose %0, [1, 2, 0, 3] : vector<8x4x16x16xf32> to vector<4x16x8x16xf32>
  %2 = vector.shape_cast %1 : vector<4x16x8x16xf32> to vector<64x128xf32>
  %c0_0 = arith.constant 0 : index
  %3 = vector.transfer_write %2, %arg1[%c0_0, %c0_0] {in_bounds = [true, false]} : vector<64x128xf32>, tensor<64x127xf32>
  return %3 : tensor<64x127xf32>
}
```
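For reference, the tests for this path are driven through the transform dialect. A minimal sketch of such a driver, assuming a `transform.structured.match` + `transform.structured.vectorize` sequence (the vector sizes here are illustrative, not taken from this PR):

```mlir
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
    // Find the linalg.unpack op to vectorize.
    %0 = transform.structured.match ops{["linalg.unpack"]} in %arg0
      : (!transform.any_op) -> !transform.any_op
    // Vectorize it; vector_sizes may be omitted when shapes are static.
    transform.structured.vectorize %0 vector_sizes [64, 128] : !transform.any_op
    transform.yield
  }
}
```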
1 parent ace6e20 commit bce951c

2 files changed: +31 −25 lines changed

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 1 addition & 4 deletions
@@ -1928,11 +1928,8 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
           unpackOp.getDestType().hasStaticShape()
               ? vectorSizes
               : shapeCastOp.getResultVectorType().getShape());
-  Value dest = rewriter.create<tensor::EmptyOp>(
-      loc, reifiedRetShapes[0],
-      shapeCastOp.getResult().getType().getElementType());
   Operation *write = createWriteOrMaskedWrite(
-      rewriter, loc, shapeCastOp.getResult(), dest,
+      rewriter, loc, shapeCastOp.getResult(), unpackOp.getDest(),
       /*writeIndices=*/{}, useInBoundsInsteadOfMasking);
   newResults.push_back(write->getResult(0));
   return success();
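The key change: the write built by `createWriteOrMaskedWrite` now targets `unpackOp.getDest()` rather than a freshly created `tensor.empty`, so the op's destination operand is honored. For intuition, a minimal sketch of the `in_bounds = [true, false]` write this produces when the flattened vector overruns the destination (`%vec`, `%dest`, and `%c0` are hypothetical values):

```mlir
// %vec : vector<64x128xf32>, %dest : tensor<64x127xf32>.
// Dim 1 overruns the destination (128 > 127), so it is marked out of
// bounds; elements past column 127 are simply not stored.
%w = vector.transfer_write %vec, %dest[%c0, %c0] {in_bounds = [true, false]}
  : vector<64x128xf32>, tensor<64x127xf32>
```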

mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir

Lines changed: 30 additions & 21 deletions
@@ -1158,6 +1158,7 @@ module attributes {transform.with_named_sequence} {
 // -----
 
 // CHECK-LABEL: func @test_vectorize_dynamic_shapes_unpack
+// CHECK-SAME: %[[ARG_0:.*]]: tensor<?x?xf32>,
 func.func @test_vectorize_dynamic_shapes_unpack(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?x16x2xf32>) -> tensor<?x?xf32> {
 // CHECK: %[[C0:.*]] = arith.constant 0
 // CHECK: %[[DIM:.*]] = tensor.dim %arg0, %[[C0]] : tensor<?x?xf32>
@@ -1175,9 +1176,8 @@ func.func @test_vectorize_dynamic_shapes_unpack(%arg0: tensor<?x?xf32>, %arg1: t
 // CHECK: %[[read0:.*]] = vector.mask %[[readMsk0]] {{.*}} vector.transfer_read %{{.*}} : tensor<?x?x16x2xf32>, vector<2x1x16x2xf32> } : vector<2x1x16x2xi1> -> vector<2x1x16x2xf32>
 // CHECK: %[[trans0:.*]] = vector.transpose %[[read0]], [0, 3, 1, 2] : vector<2x1x16x2xf32> to vector<2x2x1x16xf32>
 // CHECK: %[[sc0:.*]] = vector.shape_cast %[[trans0]] : vector<2x2x1x16xf32> to vector<4x16xf32>
-// CHECK: %[[empt0:.*]] = tensor.empty
 // CHECK: %[[writeMsk0:.*]] = vector.create_mask {{.*}} : vector<4x16xi1>
-// CHECK: %[[write0:.*]] = vector.mask %[[writeMsk0:.*]] {{.*}} vector.transfer_write %[[sc0]], %[[empt0]]
+// CHECK: %[[write0:.*]] = vector.mask %[[writeMsk0:.*]] {{.*}} vector.transfer_write %[[sc0]], %[[ARG_0]]
 // CHECK: return %[[write0]]
   %ret = linalg.unpack %arg1 inner_dims_pos = [1, 0] inner_tiles = [16, 2] into %arg0 : tensor<?x?x16x2xf32> -> tensor<?x?xf32>
   return %ret : tensor<?x?xf32>
@@ -1193,6 +1193,8 @@ module attributes {transform.with_named_sequence} {
 // -----
 
 // CHECK-LABEL: func @test_vectorize_unpack
+// CHECK-SAME: %[[SRC:.*]]: tensor<8x8x32x16xf32>
+// CHECK-SAME: %[[DEST:.*]]: tensor<256x128xf32>
 func.func @test_vectorize_unpack(%source: tensor<8x8x32x16xf32>, %dest: tensor<256x128xf32>) -> tensor<256x128xf32> {
 // CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
 // CHECK: %[[C0:.*]]= arith.constant 0 : index
@@ -1201,15 +1203,14 @@ func.func @test_vectorize_unpack(%source: tensor<8x8x32x16xf32>, %dest: tensor<2
 // CHECK: %[[C32:.*]] = arith.constant 32 : index
 // CHECK: %[[C16:.*]] = arith.constant 16 : index
 // CHECK: %[[MSK0:.*]] = vector.create_mask %[[C8]], %[[C80]], %[[C32]], %[[C16]] : vector<16x8x32x16xi1>
-// CHECK: %[[READ0:.*]] = vector.mask %[[MSK0]] {{.*}} : vector<16x8x32x16xi1> -> vector<16x8x32x16xf32>
+// CHECK: %[[READ0:.*]] = vector.mask %[[MSK0]] { vector.transfer_read %[[SRC]]{{.*}}} : vector<16x8x32x16xi1> -> vector<16x8x32x16xf32>
 // CHECK: %[[TRANSP0:.*]] = vector.transpose %[[READ0]], [0, 2, 1, 3] : vector<16x8x32x16xf32> to vector<16x32x8x16xf32>
 // CHECK: %[[SHAPC:.*]] = vector.shape_cast %[[TRANSP0]] : vector<16x32x8x16xf32> to vector<512x128xf32>
-// CHECK: %[[EMPT:.*]] = tensor.empty() : tensor<256x128xf32>
 // CHECK: %[[C01:.*]] = arith.constant 0 : index
 // CHECK: %[[C256:.*]] = arith.constant 256 : index
 // CHECK: %[[C128:.*]] = arith.constant 128 : index
 // CHECK: %[[WRITEMSK:.*]] = vector.create_mask %[[C256]], %[[C128]] : vector<512x128xi1>
-// CHECK: %[[WRIT:.*]] = vector.mask %[[WRITEMSK]] {{.*}} : vector<512x128xi1> -> tensor<256x128xf32>
+// CHECK: %[[WRIT:.*]] = vector.mask %[[WRITEMSK]] { vector.transfer_write %[[SHAPC]], %[[DEST]]{{.*}}} : vector<512x128xi1> -> tensor<256x128xf32>
 // CHECK: return %[[WRIT]] : tensor<256x128xf32>
   %0 = linalg.unpack %source inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %dest : tensor<8x8x32x16xf32> -> tensor<256x128xf32>
   return %0 : tensor<256x128xf32>
@@ -1225,15 +1226,16 @@ func.func @test_vectorize_unpack(%source: tensor<8x8x32x16xf32>, %dest: tensor<2
 // -----
 
 // CHECK-LABEL: func @test_vectorize_unpack_no_masks
+// CHECK-SAME: %[[SRC:.*]]: tensor<8x8x32x16xf32>
+// CHECK-SAME: %[[DEST:.*]]: tensor<256x128xf32>
 func.func @test_vectorize_unpack_no_masks(%source: tensor<8x8x32x16xf32>, %dest: tensor<256x128xf32>) -> tensor<256x128xf32> {
 // CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
 // CHECK: %[[C0:.*]] = arith.constant 0 : index
-// CHECK: %[[READ:.*]] = vector.transfer_read {{.*}} : tensor<8x8x32x16xf32>, vector<8x8x32x16xf32>
+// CHECK: %[[READ:.*]] = vector.transfer_read %[[SRC]]{{.*}}} : tensor<8x8x32x16xf32>, vector<8x8x32x16xf32>
 // CHECK: %[[TRANSP:.*]] = vector.transpose %[[READ]], [0, 2, 1, 3] : vector<8x8x32x16xf32> to vector<8x32x8x16xf32>
 // CHECK: %[[SHAPC:.*]] = vector.shape_cast %[[TRANSP]] : vector<8x32x8x16xf32> to vector<256x128xf32>
-// CHECK: %[[EMPT:.*]] = tensor.empty() : tensor<256x128xf32>
 // CHECK: %[[C00:.*]] = arith.constant 0 : index
-// CHECK: %[[WRIT:.*]] = vector.transfer_write %[[SHAPC]], {{.*}} : vector<256x128xf32>, tensor<256x128xf32>
+// CHECK: %[[WRIT:.*]] = vector.transfer_write %[[SHAPC]], %[[DEST]]{{.*}}} : vector<256x128xf32>, tensor<256x128xf32>
 // CHECK: return %[[WRIT]] : tensor<256x128xf32>
   %0 = linalg.unpack %source inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %dest : tensor<8x8x32x16xf32> -> tensor<256x128xf32>
   return %0 : tensor<256x128xf32>
@@ -1248,16 +1250,17 @@ func.func @test_vectorize_unpack_no_masks(%source: tensor<8x8x32x16xf32>, %dest:
 
 // -----
 
-// CHECK-LABEL: test_vectorize_unpack_with_outer_perm
+// CHECK-LABEL: test_vectorize_unpack_with_outer_perm
+// CHECK-SAME: %[[SRC:.*]]: tensor<8x8x32x16xf32>
+// CHECK-SAME: %[[DEST:.*]]: tensor<256x128xf32>
 func.func @test_vectorize_unpack_with_outer_perm(%source: tensor<8x8x32x16xf32>, %dest: tensor<256x128xf32>) -> tensor<256x128xf32> {
 // CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
 // CHECK: %[[C0:.*]] = arith.constant 0 : index
-// CHECK: %[[READ:.*]] = vector.transfer_read {{.*}} : tensor<8x8x32x16xf32>, vector<8x8x32x16xf32>
+// CHECK: %[[READ:.*]] = vector.transfer_read %[[SRC]]{{.*}}} : tensor<8x8x32x16xf32>, vector<8x8x32x16xf32>
 // CHECK: %[[TRANSP:.*]] = vector.transpose %[[READ]], [1, 2, 0, 3] : vector<8x8x32x16xf32> to vector<8x32x8x16xf32>
 // CHECK: %[[SHAPC:.*]] = vector.shape_cast %[[TRANSP]] : vector<8x32x8x16xf32> to vector<256x128xf32>
-// CHECK: %[[EMPT:.*]] = tensor.empty() : tensor<256x128xf32>
 // CHECK: %[[C00:.*]] = arith.constant 0 : index
-// CHECK: %[[WRIT:.*]] = vector.transfer_write %[[SHAPC]], {{.*}} : vector<256x128xf32>, tensor<256x128xf32>
+// CHECK: %[[WRIT:.*]] = vector.transfer_write %[[SHAPC]], %[[DEST]]{{.*}}} : vector<256x128xf32>, tensor<256x128xf32>
 // CHECK: return %[[WRIT]] : tensor<256x128xf32>
   %0 = linalg.unpack %source outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %dest : tensor<8x8x32x16xf32> -> tensor<256x128xf32>
   return %0 : tensor<256x128xf32>
@@ -1327,15 +1330,17 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+// CHECK-LABEL: test_vectorize_unpack_no_vector_sizes
+// CHECK-SAME: %[[SRC:.*]]: tensor<8x8x32x16xf32>
+// CHECK-SAME: %[[DEST:.*]]: tensor<256x128xf32>
 func.func @test_vectorize_unpack_no_vector_sizes(%source: tensor<8x8x32x16xf32>, %dest: tensor<256x128xf32>) -> tensor<256x128xf32> {
 // CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
 // CHECK: %[[C0:.*]] = arith.constant 0 : index
-// CHECK: %[[READ:.*]] = vector.transfer_read {{.*}} : tensor<8x8x32x16xf32>, vector<8x8x32x16xf32>
+// CHECK: %[[READ:.*]] = vector.transfer_read %[[SRC]]{{.*}}} : tensor<8x8x32x16xf32>, vector<8x8x32x16xf32>
 // CHECK: %[[TRANSP:.*]] = vector.transpose %[[READ]], [0, 2, 1, 3] : vector<8x8x32x16xf32> to vector<8x32x8x16xf32>
 // CHECK: %[[SHAPC:.*]] = vector.shape_cast %[[TRANSP]] : vector<8x32x8x16xf32> to vector<256x128xf32>
-// CHECK: %[[EMPT:.*]] = tensor.empty() : tensor<256x128xf32>
 // CHECK: %[[C00:.*]] = arith.constant 0 : index
-// CHECK: %[[WRIT:.*]] = vector.transfer_write %[[SHAPC]], {{.*}} : vector<256x128xf32>, tensor<256x128xf32>
+// CHECK: %[[WRIT:.*]] = vector.transfer_write %[[SHAPC]], %[[DEST]]{{.*}}} : vector<256x128xf32>, tensor<256x128xf32>
 // CHECK: return %[[WRIT]] : tensor<256x128xf32>
   %0 = linalg.unpack %source inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %dest : tensor<8x8x32x16xf32> -> tensor<256x128xf32>
   return %0 : tensor<256x128xf32>
@@ -1350,15 +1355,17 @@ func.func @test_vectorize_unpack_no_vector_sizes(%source: tensor<8x8x32x16xf32>,
 
 // -----
 
+// CHECK-LABEL: test_vectorize_unpack_no_vector_sizes_slice_output
+// CHECK-SAME: %[[SRC:.*]]: tensor<8x4x16x16xf32>
+// CHECK-SAME: %[[DEST:.*]]: tensor<64x127xf32>
 func.func @test_vectorize_unpack_no_vector_sizes_slice_output(%source: tensor<8x4x16x16xf32>, %dest: tensor<64x127xf32>) -> tensor<64x127xf32> {
 // CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
 // CHECK: %[[C0:.*]] = arith.constant 0 : index
-// CHECK: %[[READ:.*]] = vector.transfer_read {{.*}} : tensor<8x4x16x16xf32>, vector<8x4x16x16xf32>
+// CHECK: %[[READ:.*]] = vector.transfer_read %[[SRC]]{{.*}}} : tensor<8x4x16x16xf32>, vector<8x4x16x16xf32>
 // CHECK: %[[TRANSP:.*]] = vector.transpose %[[READ]], [1, 2, 0, 3] : vector<8x4x16x16xf32> to vector<4x16x8x16xf32>
 // CHECK: %[[SHAPC:.*]] = vector.shape_cast %[[TRANSP]] : vector<4x16x8x16xf32> to vector<64x128xf32>
-// CHECK: %[[EMPT:.*]] = tensor.empty() : tensor<64x127xf32>
 // CHECK: %[[C00:.*]] = arith.constant 0 : index
-// CHECK: %[[WRIT:.*]] = vector.transfer_write %[[SHAPC]], %[[EMPT]]{{\[}}%[[C00]], %[[C00]]]
+// CHECK: %[[WRIT:.*]] = vector.transfer_write %[[SHAPC]], %[[DEST]]
 // CHECK-SAME: {in_bounds = [true, false]} : vector<64x128xf32>, tensor<64x127xf32>
 // CHECK: return %[[WRIT]] : tensor<64x127xf32>
   %0 = linalg.unpack %source outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [16, 16] into %dest : tensor<8x4x16x16xf32> -> tensor<64x127xf32>
@@ -1374,18 +1381,20 @@ func.func @test_vectorize_unpack_no_vector_sizes_slice_output(%source: tensor<8x
 
 // -----
 
+// CHECK-LABEL: test_vectorize_unpack_no_vector_sizes_permute
+// CHECK-SAME: %[[SRC:.*]]: tensor<4x7x4xf32>
+// CHECK-SAME: %[[DEST:.*]]: tensor<7x16xf32>
 func.func @test_vectorize_unpack_no_vector_sizes_permute(%source: tensor<4x7x4xf32>, %dest: tensor<7x16xf32>) -> tensor<7x16xf32> {
   %0 = linalg.unpack %source outer_dims_perm=[1, 0] inner_dims_pos = [1] inner_tiles = [4] into %dest : tensor<4x7x4xf32> -> tensor<7x16xf32>
   return %0 : tensor<7x16xf32>
 }
 // CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
 // CHECK: %[[C0:.*]] = arith.constant 0 : index
-// CHECK: %[[READ:.*]] = vector.transfer_read {{.*}} : tensor<4x7x4xf32>, vector<4x7x4xf32>
+// CHECK: %[[READ:.*]] = vector.transfer_read %[[SRC]]{{.*}}} : tensor<4x7x4xf32>, vector<4x7x4xf32>
 // CHECK: %[[TRANSP:.*]] = vector.transpose %[[READ]], [1, 0, 2] : vector<4x7x4xf32> to vector<7x4x4xf32>
 // CHECK: %[[SHAPC:.*]] = vector.shape_cast %[[TRANSP]] : vector<7x4x4xf32> to vector<7x16xf32>
-// CHECK: %[[EMPT:.*]] = tensor.empty() : tensor<7x16xf32>
 // CHECK: %[[C00:.*]] = arith.constant 0 : index
-// CHECK: %[[WRIT:.*]] = vector.transfer_write %[[SHAPC]], {{.*}} : vector<7x16xf32>, tensor<7x16xf32>
+// CHECK: %[[WRIT:.*]] = vector.transfer_write %[[SHAPC]], %[[DEST]]{{.*}}} : vector<7x16xf32>, tensor<7x16xf32>
 // CHECK: return %[[WRIT]] : tensor<7x16xf32>
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
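The checks above can be reproduced by running the test file through the transform interpreter; the invocation is along these lines (hypothetical — see the actual RUN line at the top of linalg-ops.mlir):

```mlir
// Hypothetical invocation; the file's real RUN line may differ.
// RUN: mlir-opt %s -transform-interpreter -split-input-file | FileCheck %s
```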
