@@ -266,3 +266,131 @@ func.func @unpack_16x1x1x2_to_32x1(%arg0 : tensor<16x1x1x2xf32>) -> tensor<32x1x
266266 : tensor <16 x1 x1 x2 xf32 > -> tensor <32 x1 xf32 >
267267 return %unpack : tensor <32 x1 xf32 >
268268}
269+
270+ // -----
271+
272+ // CHECK-LABEL: func.func @pad_like_pack(
273+ // CHECK-SAME: %[[ARG0:.+]]: tensor<32x64xf32>)
274+ // CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1, 2], [3]] output_shape [1, 1, 32, 64] : tensor<32x64xf32> into tensor<1x1x32x64xf32>
275+ // CHECK: return %[[EXPANDED]] : tensor<1x1x32x64xf32>
276+ func.func @pad_like_pack (%arg0: tensor <32 x64 xf32 >) -> tensor <1 x1 x32 x64 xf32 > {
277+ %empty = tensor.empty () : tensor <1 x1 x32 x64 xf32 >
278+ %0 = tensor.pack %arg0 inner_dims_pos = [0 , 1 ] inner_tiles = [32 , 64 ] into %empty : tensor <32 x64 xf32 > -> tensor <1 x1 x32 x64 xf32 >
279+ return %0 : tensor <1 x1 x32 x64 xf32 >
280+ }
281+
282+ // -----
283+
284+ // CHECK-LABEL: func.func @pad_like_pack_with_outer_dims_perm(
285+ // CHECK-SAME: %[[ARG0:.+]]: tensor<32x64xf32>)
286+ // CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1, 2], [3]] output_shape [1, 1, 32, 64] : tensor<32x64xf32> into tensor<1x1x32x64xf32>
287+ // CHECK: return %[[EXPANDED]] : tensor<1x1x32x64xf32>
288+ func.func @pad_like_pack_with_outer_dims_perm (%arg0: tensor <32 x64 xf32 >) -> tensor <1 x1 x32 x64 xf32 > {
289+ %empty = tensor.empty () : tensor <1 x1 x32 x64 xf32 >
290+ %0 = tensor.pack %arg0 outer_dims_perm = [1 , 0 ] inner_dims_pos = [0 , 1 ] inner_tiles = [32 , 64 ] into %empty : tensor <32 x64 xf32 > -> tensor <1 x1 x32 x64 xf32 >
291+ return %0 : tensor <1 x1 x32 x64 xf32 >
292+ }
293+
294+ // -----
295+
296+ // CHECK-LABEL: func.func @inner_pad_like_pack(
297+ // CHECK-SAME: %[[ARG0:.+]]: tensor<32x64xf32>)
298+ // CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1, 2]] output_shape [32, 1, 64] : tensor<32x64xf32> into tensor<32x1x64xf32>
299+ // CHECK: return %[[EXPANDED]] : tensor<32x1x64xf32>
300+ func.func @inner_pad_like_pack (%arg0: tensor <32 x64 xf32 >) -> tensor <32 x1 x64 xf32 > {
301+ %empty = tensor.empty () : tensor <32 x1 x64 xf32 >
302+ %0 = tensor.pack %arg0 inner_dims_pos = [1 ] inner_tiles = [64 ] into %empty : tensor <32 x64 xf32 > -> tensor <32 x1 x64 xf32 >
303+ return %0 : tensor <32 x1 x64 xf32 >
304+ }
305+
306+ // -----
307+
308+ // Do not simplify pack with inner dimension shuffling.
309+ // CHECK-LABEL: func.func @pad_and_inner_dim_shuffle_pack(
310+ // CHECK-SAME: %[[ARG0:.+]]: tensor<32x64xf32>)
311+ // CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<1x1x64x32xf32>
312+ // CHECK: %[[PACK:.+]] = tensor.pack %[[ARG0]] inner_dims_pos = [1, 0] inner_tiles = [64, 32] into %[[EMPTY]] : tensor<32x64xf32> -> tensor<1x1x64x32xf32>
313+ // CHECK: return %[[PACK]] : tensor<1x1x64x32xf32>
314+ func.func @pad_and_inner_dim_shuffle_pack (%arg0: tensor <32 x64 xf32 >) -> tensor <1 x1 x64 x32 xf32 > {
315+ %empty = tensor.empty () : tensor <1 x1 x64 x32 xf32 >
316+ %0 = tensor.pack %arg0 inner_dims_pos = [1 , 0 ] inner_tiles = [64 , 32 ] into %empty : tensor <32 x64 xf32 > -> tensor <1 x1 x64 x32 xf32 >
317+ return %0 : tensor <1 x1 x64 x32 xf32 >
318+ }
319+
320+ // -----
321+
322+ // Do not simplify pack with inner dimension transpose.
323+ // CHECK-LABEL: func.func @pad_like_pack_with_transpose(
324+ // CHECK-SAME: %[[ARG0:.+]]: tensor<32x64x16xf32>)
325+ // CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<32x1x16x64xf32>
326+ // CHECK: %[[PACK:.+]] = tensor.pack %[[ARG0]] inner_dims_pos = [1] inner_tiles = [64] into %[[EMPTY]] : tensor<32x64x16xf32> -> tensor<32x1x16x64xf32>
327+ // CHECK: return %[[PACK]] : tensor<32x1x16x64xf32>
328+ func.func @pad_like_pack_with_transpose (%arg0: tensor <32 x64 x16 xf32 >) -> tensor <32 x1 x16 x64 xf32 > {
329+ %empty = tensor.empty () : tensor <32 x1 x16 x64 xf32 >
330+ %0 = tensor.pack %arg0 inner_dims_pos = [1 ] inner_tiles = [64 ] into %empty : tensor <32 x64 x16 xf32 > -> tensor <32 x1 x16 x64 xf32 >
331+ return %0 : tensor <32 x1 x16 x64 xf32 >
332+ }
333+
334+ // -----
335+
336+ // CHECK-LABEL: func.func @unpad_like_unpack(
337+ // CHECK-SAME: %[[ARG0:.+]]: tensor<1x1x32x64xf32>)
338+ // CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1, 2], [3]] : tensor<1x1x32x64xf32> into tensor<32x64xf32>
339+ // CHECK: return %[[COLLAPSED]] : tensor<32x64xf32>
340+ func.func @unpad_like_unpack (%arg0: tensor <1 x1 x32 x64 xf32 >) -> tensor <32 x64 xf32 > {
341+ %empty = tensor.empty () : tensor <32 x64 xf32 >
342+ %0 = tensor.unpack %arg0 inner_dims_pos = [0 , 1 ] inner_tiles = [32 , 64 ] into %empty : tensor <1 x1 x32 x64 xf32 > -> tensor <32 x64 xf32 >
343+ return %0 : tensor <32 x64 xf32 >
344+ }
345+
346+ // -----
347+
348+ // CHECK-LABEL: func.func @unpad_like_unpack_with_outer_dims_perm(
349+ // CHECK-SAME: %[[ARG0:.+]]: tensor<1x1x32x64xf32>)
350+ // CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1, 2], [3]] : tensor<1x1x32x64xf32> into tensor<32x64xf32>
351+ // CHECK: return %[[COLLAPSED]] : tensor<32x64xf32>
352+ func.func @unpad_like_unpack_with_outer_dims_perm (%arg0: tensor <1 x1 x32 x64 xf32 >) -> tensor <32 x64 xf32 > {
353+ %empty = tensor.empty () : tensor <32 x64 xf32 >
354+ %0 = tensor.unpack %arg0 outer_dims_perm = [1 , 0 ] inner_dims_pos = [0 , 1 ] inner_tiles = [32 , 64 ] into %empty : tensor <1 x1 x32 x64 xf32 > -> tensor <32 x64 xf32 >
355+ return %0 : tensor <32 x64 xf32 >
356+ }
357+
358+ // -----
359+
360+ // CHECK-LABEL: func.func @inner_unpad_like_unpack(
361+ // CHECK-SAME: %[[ARG0:.+]]: tensor<32x1x64xf32>)
362+ // CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0], [1, 2]] : tensor<32x1x64xf32> into tensor<32x64xf32>
363+ // CHECK: return %[[COLLAPSED]] : tensor<32x64xf32>
364+ func.func @inner_unpad_like_unpack (%arg0: tensor <32 x1 x64 xf32 >) -> tensor <32 x64 xf32 > {
365+ %empty = tensor.empty () : tensor <32 x64 xf32 >
366+ %0 = tensor.unpack %arg0 inner_dims_pos = [1 ] inner_tiles = [64 ] into %empty : tensor <32 x1 x64 xf32 > -> tensor <32 x64 xf32 >
367+ return %0 : tensor <32 x64 xf32 >
368+ }
369+
370+ // -----
371+
372+ // Do not simplify unpack with inner dimension shuffling.
373+ // CHECK-LABEL: func.func @unpad_and_inner_dim_shuffle_pack(
374+ // CHECK-SAME: %[[ARG0:.+]]: tensor<1x1x32x64xf32>)
375+ // CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<64x32xf32>
376+ // CHECK: %[[UNPACK:.+]] = tensor.unpack %[[ARG0]] inner_dims_pos = [1, 0] inner_tiles = [32, 64] into %[[EMPTY]] : tensor<1x1x32x64xf32> -> tensor<64x32xf32>
377+ // CHECK: return %[[UNPACK]] : tensor<64x32xf32>
378+ func.func @unpad_and_inner_dim_shuffle_pack (%arg0: tensor <1 x1 x32 x64 xf32 >) -> tensor <64 x32 xf32 > {
379+ %empty = tensor.empty () : tensor <64 x32 xf32 >
380+ %0 = tensor.unpack %arg0 inner_dims_pos = [1 , 0 ] inner_tiles = [32 , 64 ] into %empty : tensor <1 x1 x32 x64 xf32 > -> tensor <64 x32 xf32 >
381+ return %0 : tensor <64 x32 xf32 >
382+ }
383+
384+ // -----
385+
386+ // Do not simplify unpack with inner dimension transpose.
387+ // CHECK-LABEL: func.func @unpad_like_unpack_with_transpose(
388+ // CHECK-SAME: %[[ARG0:.+]]: tensor<32x1x16x64xf32>)
389+ // CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<32x64x16xf32>
390+ // CHECK: %[[UNPACK:.+]] = tensor.unpack %[[ARG0]] inner_dims_pos = [1] inner_tiles = [64] into %[[EMPTY]] : tensor<32x1x16x64xf32> -> tensor<32x64x16xf32>
391+ // CHECK: return %[[UNPACK]] : tensor<32x64x16xf32>
392+ func.func @unpad_like_unpack_with_transpose (%arg0: tensor <32 x1 x16 x64 xf32 >) -> tensor <32 x64 x16 xf32 > {
393+ %empty = tensor.empty () : tensor <32 x64 x16 xf32 >
394+ %0 = tensor.unpack %arg0 inner_dims_pos = [1 ] inner_tiles = [64 ] into %empty : tensor <32 x1 x16 x64 xf32 > -> tensor <32 x64 x16 xf32 >
395+ return %0 : tensor <32 x64 x16 xf32 >
396+ }
0 commit comments