@@ -911,6 +911,21 @@ func.func @reinterpret_noop(%arg : memref<2x3x4xf32>) -> memref<2x3x4xf32> {
911911
912912// -----
913913
914+ // CHECK-LABEL: func @reinterpret_constant_fold
915+ // CHECK-SAME: (%[[ARG:.*]]: memref<f32>)
916+ // CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [0], sizes: [100, 100], strides: [100, 1]
917+ // CHECK: %[[CAST:.*]] = memref.cast %[[RES]]
918+ // CHECK: return %[[CAST]]
919+ func.func @reinterpret_constant_fold (%arg0: memref <f32 >) -> memref <?x?xf32 , strided <[?, ?], offset : ?>> {
920+ %c0 = arith.constant 0 : index
921+ %c1 = arith.constant 1 : index
922+ %c100 = arith.constant 100 : index
923+ %reinterpret_cast = memref.reinterpret_cast %arg0 to offset : [%c0 ], sizes : [%c100 , %c100 ], strides : [%c100 , %c1 ] : memref <f32 > to memref <?x?xf32 , strided <[?, ?], offset : ?>>
924+ return %reinterpret_cast : memref <?x?xf32 , strided <[?, ?], offset : ?>>
925+ }
926+
927+ // -----
928+
914929// CHECK-LABEL: func @reinterpret_of_reinterpret
915930// CHECK-SAME: (%[[ARG:.*]]: memref<?xi8>, %[[SIZE1:.*]]: index, %[[SIZE2:.*]]: index)
916931// CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [0], sizes: [%[[SIZE2]]], strides: [1]
@@ -996,10 +1011,9 @@ func.func @reinterpret_of_extract_strided_metadata_same_type(%arg0 : memref<?x?x
9961011// when the strides don't match.
9971012// CHECK-LABEL: func @reinterpret_of_extract_strided_metadata_w_different_stride
9981013// CHECK-SAME: (%[[ARG:.*]]: memref<8x2xf32>)
999- // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
1000- // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
1001- // CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [%[[C0]]], sizes: [4, 2, 2], strides: [1, 1, %[[C1]]]
1002- // CHECK: return %[[RES]]
1014+ // CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [0], sizes: [4, 2, 2], strides: [1, 1, 1]
1015+ // CHECK: %[[CAST:.*]] = memref.cast %[[RES]]
1016+ // CHECK: return %[[CAST]]
10031017func.func @reinterpret_of_extract_strided_metadata_w_different_stride (%arg0 : memref <8 x2 xf32 >) -> memref <?x?x?xf32 , strided <[?, ?, ?], offset : ?>> {
10041018 %base , %offset , %sizes:2 , %strides:2 = memref.extract_strided_metadata %arg0 : memref <8 x2 xf32 > -> memref <f32 >, index , index , index , index , index
10051019 %m2 = memref.reinterpret_cast %base to offset : [%offset ], sizes : [4 , 2 , 2 ], strides : [1 , 1 , %strides#1 ] : memref <f32 > to memref <?x?x?xf32 , strided <[?, ?, ?], offset : ?>>
@@ -1011,11 +1025,9 @@ func.func @reinterpret_of_extract_strided_metadata_w_different_stride(%arg0 : me
10111025// when the offset doesn't match.
10121026// CHECK-LABEL: func @reinterpret_of_extract_strided_metadata_w_different_offset
10131027// CHECK-SAME: (%[[ARG:.*]]: memref<8x2xf32>)
1014- // CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
1015- // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
1016- // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
1017- // CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [1], sizes: [%[[C8]], %[[C2]]], strides: [%[[C2]], %[[C1]]]
1018- // CHECK: return %[[RES]]
1028+ // CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [1], sizes: [8, 2], strides: [2, 1]
1029+ // CHECK: %[[CAST:.*]] = memref.cast %[[RES]]
1030+ // CHECK: return %[[CAST]]
10191031func.func @reinterpret_of_extract_strided_metadata_w_different_offset (%arg0 : memref <8 x2 xf32 >) -> memref <?x?xf32 , strided <[?, ?], offset : ?>> {
10201032 %base , %offset , %sizes:2 , %strides:2 = memref.extract_strided_metadata %arg0 : memref <8 x2 xf32 > -> memref <f32 >, index , index , index , index , index
10211033 %m2 = memref.reinterpret_cast %base to offset : [1 ], sizes : [%sizes#0 , %sizes#1 ], strides : [%strides#0 , %strides#1 ] : memref <f32 > to memref <?x?xf32 , strided <[?, ?], offset : ?>>
0 commit comments