Skip to content

Commit

Permalink
dialects: (linalg) add correct generic printing for fill, and matmul (#2971)
Browse files Browse the repository at this point in the history

I happened to need the fill operator and decided to do the matmul as
well. This continues the work of #2959
I also made sure that the code works on more than just floating-point
values.
I moved some duplicate code to determine the type arguments of the
hidden region into a common method of the `NamedOp` base class.
I also used the implicit builder to build the regions as I think this is
much more readable.

Doing the matmul was a bit more tricky:

- for some reason this op prints the attributes in front of the `ins`
and `outs`
- the generic form includes the `linalg.memoized_indexing_maps` attribute,
which holds the indexing maps that would be generated when converting to a
linalg.generic
  • Loading branch information
jorendumoulin authored Aug 1, 2024
1 parent edeae99 commit 3543d82
Show file tree
Hide file tree
Showing 6 changed files with 217 additions and 171 deletions.
18 changes: 12 additions & 6 deletions tests/filecheck/dialects/csl/csl-stencil-ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,12 @@ builtin.module {
// CHECK-GENERIC-NEXT: %16 = "arith.addf"(%15, %6) <{"fastmath" = #arith.fastmath<none>}> : (tensor<510xf32>, tensor<510xf32>) -> tensor<510xf32>
// CHECK-GENERIC-NEXT: %17 = "arith.addf"(%16, %5) <{"fastmath" = #arith.fastmath<none>}> : (tensor<510xf32>, tensor<510xf32>) -> tensor<510xf32>
// CHECK-GENERIC-NEXT: %18 = "tensor.empty"() : () -> tensor<510xf32>
// CHECK-GENERIC-NEXT: %19 = "linalg.fill"(%4, %18) <{"operandSegmentSizes" = array<i32: 1, 1>}> : (f32, tensor<510xf32>) -> tensor<510xf32>
// CHECK-GENERIC-NEXT: %20 = "arith.mulf"(%17, %19) <{"fastmath" = #arith.fastmath<none>}> : (tensor<510xf32>, tensor<510xf32>) -> tensor<510xf32>
// CHECK-GENERIC-NEXT: "stencil.return"(%20) : (tensor<510xf32>) -> ()
// CHECK-GENERIC-NEXT: %19 = "linalg.fill"(%4, %18) <{"operandSegmentSizes" = array<i32: 1, 1>}> ({
// CHECK-GENERIC-NEXT: ^2(%20 : f32, %21 : f32):
// CHECK-GENERIC-NEXT: "linalg.yield"(%20) : (f32) -> ()
// CHECK-GENERIC-NEXT: }) : (f32, tensor<510xf32>) -> tensor<510xf32>
// CHECK-GENERIC-NEXT: %22 = "arith.mulf"(%17, %19) <{"fastmath" = #arith.fastmath<none>}> : (tensor<510xf32>, tensor<510xf32>) -> tensor<510xf32>
// CHECK-GENERIC-NEXT: "stencil.return"(%22) : (tensor<510xf32>) -> ()
// CHECK-GENERIC-NEXT: }) : (!stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>, tensor<4x510xf32>) -> !stencil.temp<[0,1]x[0,1]xtensor<510xf32>>
// CHECK-GENERIC-NEXT: "stencil.store"(%1, %b) {"bounds" = #stencil.bounds<[0, 0], [1, 1]>} : (!stencil.temp<[0,1]x[0,1]xtensor<510xf32>>, !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>) -> ()
// CHECK-GENERIC-NEXT: "func.return"() : () -> ()
Expand Down Expand Up @@ -195,9 +198,12 @@ builtin.module {
// CHECK-GENERIC-NEXT: %17 = "arith.addf"(%16, %15) <{"fastmath" = #arith.fastmath<none>}> : (tensor<510xf32>, tensor<510xf32>) -> tensor<510xf32>
// CHECK-GENERIC-NEXT: %18 = "arith.constant"() <{"value" = 1.666600e-01 : f32}> : () -> f32
// CHECK-GENERIC-NEXT: %19 = "tensor.empty"() : () -> tensor<510xf32>
// CHECK-GENERIC-NEXT: %20 = "linalg.fill"(%18, %19) <{"operandSegmentSizes" = array<i32: 1, 1>}> : (f32, tensor<510xf32>) -> tensor<510xf32>
// CHECK-GENERIC-NEXT: %21 = "arith.mulf"(%17, %20) <{"fastmath" = #arith.fastmath<none>}> : (tensor<510xf32>, tensor<510xf32>) -> tensor<510xf32>
// CHECK-GENERIC-NEXT: "csl_stencil.yield"(%21) : (tensor<510xf32>) -> ()
// CHECK-GENERIC-NEXT: %20 = "linalg.fill"(%18, %19) <{"operandSegmentSizes" = array<i32: 1, 1>}> ({
// CHECK-GENERIC-NEXT: ^3(%21 : f32, %22 : f32):
// CHECK-GENERIC-NEXT: "linalg.yield"(%21) : (f32) -> ()
// CHECK-GENERIC-NEXT: }) : (f32, tensor<510xf32>) -> tensor<510xf32>
// CHECK-GENERIC-NEXT: %23 = "arith.mulf"(%17, %20) <{"fastmath" = #arith.fastmath<none>}> : (tensor<510xf32>, tensor<510xf32>) -> tensor<510xf32>
// CHECK-GENERIC-NEXT: "csl_stencil.yield"(%23) : (tensor<510xf32>) -> ()
// CHECK-GENERIC-NEXT: }) : (!stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>, tensor<510xf32>) -> !stencil.temp<[0,1]x[0,1]xtensor<510xf32>>
// CHECK-GENERIC-NEXT: "stencil.store"(%2, %b) {"bounds" = #stencil.bounds<[0, 0], [1, 1]>} : (!stencil.temp<[0,1]x[0,1]xtensor<510xf32>>, !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>) -> ()
// CHECK-GENERIC-NEXT: "func.return"() : () -> ()
Expand Down
118 changes: 73 additions & 45 deletions tests/filecheck/dialects/linalg/linalg_ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -28,69 +28,97 @@ linalg.add ins(%m1, %m2 : memref<4x16xf32>, memref<4x16xf32>) outs(%m3 : memref<
%mul = linalg.mul ins(%t1, %t2 : tensor<4x16xf32>, tensor<4x16xf32>) outs(%t3 : tensor<4x16xf32>) -> tensor<4x16xf32>
linalg.mul ins(%m1, %m2 : memref<4x16xf32>, memref<4x16xf32>) outs(%m3 : memref<4x16xf32>)


%2, %3 = "test.op"() : () -> (memref<64x9216xf32>, memref<9216x4096xf32>)
%4 = "test.op"() : () -> (memref<64x4096xf32>)
linalg.matmul {id} ins(%2, %3 : memref<64x9216xf32>, memref<9216x4096xf32>) outs(%4 : memref<64x4096xf32>)

%fill = linalg.fill ins(%0 : f32) outs(%t3 : tensor<4x16xf32>) -> tensor<4x16xf32>
linalg.fill ins(%0 : f32) outs(%m3 : memref<4x16xf32>)

// CHECK: module {
// CHECK-NEXT: %0, %1 = "test.op"() : () -> (f32, memref<1x256xf32>)
// CHECK-NEXT: linalg.generic {indexing_maps = [affine_map<(d0, d1) -> ()>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%0 : f32) outs(%1 : memref<1x256xf32>) {
// CHECK-NEXT: ^0(%{{.*}}: f32, %{{.*}}: f32):
// CHECK-NEXT: %{{.*}} %{{.*}} = "test.op"() : () -> (f32, memref<1x256xf32>)
// CHECK-NEXT: linalg.generic {indexing_maps = [affine_map<(d0, d1) -> ()>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%{{.*}} : f32) outs(%{{.*}} : memref<1x256xf32>) {
// CHECK-NEXT: ^0(%{{.*}} f32, %{{.*}} f32):
// CHECK-NEXT: linalg.yield %{{.*}} : f32
// CHECK-NEXT: }
// CHECK-NEXT: linalg.generic {indexing_maps = [affine_map<(d0, d1) -> ()>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"], doc = "a_docstring", library_call = "a_library_call"} ins(%0 : f32) outs(%1 : memref<1x256xf32>) {
// CHECK-NEXT: ^1(%arg3_1 : f32, %arg4_1 : f32):
// CHECK-NEXT: linalg.yield %arg3_1 : f32
// CHECK-NEXT: linalg.generic {indexing_maps = [affine_map<(d0, d1) -> ()>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"], doc = "a_docstring", library_call = "a_library_call"} ins(%{{.*}} : f32) outs(%{{.*}} : memref<1x256xf32>) {
// CHECK-NEXT: ^1(%{{.*}} : f32, %{{.*}} : f32):
// CHECK-NEXT: linalg.yield %{{.*}} : f32
// CHECK-NEXT: }
// CHECK-NEXT: linalg.generic {indexing_maps = [affine_map<(d0, d1) -> ()>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%0 : f32) outs(%1 : memref<1x256xf32>) attrs = {"hello" = "world"} {
// CHECK-NEXT: ^{{.*}}(%{{.*}}: f32, %{{.*}}: f32):
// CHECK-NEXT: linalg.generic {indexing_maps = [affine_map<(d0, d1) -> ()>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%{{.*}} : f32) outs(%{{.*}} : memref<1x256xf32>) attrs = {"hello" = "world"} {
// CHECK-NEXT: ^{{.*}}(%{{.*}} f32, %{{.*}} f32):
// CHECK-NEXT: linalg.yield %{{.*}} : f32
// CHECK-NEXT: }
// CHECK-NEXT: %t1, %t2, %t3 = "test.op"() : () -> (tensor<4x16xf32>, tensor<4x16xf32>, tensor<4x16xf32>)
// CHECK-NEXT: %m1, %m2, %m3 = "test.op"() : () -> (memref<4x16xf32>, memref<4x16xf32>, memref<4x16xf32>)
// CHECK-NEXT: %sum = linalg.add ins(%t1, %t2 : tensor<4x16xf32>, tensor<4x16xf32>) outs(%t3 : tensor<4x16xf32>) -> tensor<4x16xf32>
// CHECK-NEXT: linalg.add ins(%m1, %m2 : memref<4x16xf32>, memref<4x16xf32>) outs(%m3 : memref<4x16xf32>)
// CHECK-NEXT: %mul = linalg.mul ins(%t1, %t2 : tensor<4x16xf32>, tensor<4x16xf32>) outs(%t3 : tensor<4x16xf32>) -> tensor<4x16xf32>
// CHECK-NEXT: linalg.mul ins(%m1, %m2 : memref<4x16xf32>, memref<4x16xf32>) outs(%m3 : memref<4x16xf32>)
// CHECK-NEXT: %2, %3 = "test.op"() : () -> (memref<64x9216xf32>, memref<9216x4096xf32>)
// CHECK-NEXT: %4 = "test.op"() : () -> memref<64x4096xf32>
// CHECK-NEXT: linalg.matmul {"id"} ins(%2, %3 : memref<64x9216xf32>, memref<9216x4096xf32>) outs(%4 : memref<64x4096xf32>)
// CHECK-NEXT: %{{.*}} %{{.*}} %{{.*}} = "test.op"() : () -> (tensor<4x16xf32>, tensor<4x16xf32>, tensor<4x16xf32>)
// CHECK-NEXT: %{{.*}} %{{.*}} %{{.*}} = "test.op"() : () -> (memref<4x16xf32>, memref<4x16xf32>, memref<4x16xf32>)
// CHECK-NEXT: %{{.*}} = linalg.add ins(%{{.*}} %{{.*}} : tensor<4x16xf32>, tensor<4x16xf32>) outs(%{{.*}} : tensor<4x16xf32>) -> tensor<4x16xf32>
// CHECK-NEXT: linalg.add ins(%{{.*}} %{{.*}} : memref<4x16xf32>, memref<4x16xf32>) outs(%{{.*}} : memref<4x16xf32>)
// CHECK-NEXT: %{{.*}} = linalg.mul ins(%{{.*}} %{{.*}} : tensor<4x16xf32>, tensor<4x16xf32>) outs(%{{.*}} : tensor<4x16xf32>) -> tensor<4x16xf32>
// CHECK-NEXT: linalg.mul ins(%{{.*}} %{{.*}} : memref<4x16xf32>, memref<4x16xf32>) outs(%{{.*}} : memref<4x16xf32>)
// CHECK-NEXT: %{{.*}} %{{.*}} = "test.op"() : () -> (memref<64x9216xf32>, memref<9216x4096xf32>)
// CHECK-NEXT: %{{.*}} = "test.op"() : () -> memref<64x4096xf32>
// CHECK-NEXT: linalg.matmul {"id"} ins(%{{.*}} %{{.*}} : memref<64x9216xf32>, memref<9216x4096xf32>) outs(%{{.*}} : memref<64x4096xf32>)
// CHECK-NEXT: %{{.*}} = linalg.fill ins(%{{.*}} : f32) outs(%{{.*}} : tensor<4x16xf32>) -> tensor<4x16xf32>
// CHECK-NEXT: linalg.fill ins(%{{.*}} : f32) outs(%{{.*}} : memref<4x16xf32>)
// CHECK-NEXT: }

// CHECK-GENERIC: "linalg.generic"(%0, %1) <{"indexing_maps" = [affine_map<(d0, d1) -> ()>, affine_map<(d0, d1) -> (d0, d1)>], "iterator_types" = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>], "operandSegmentSizes" = array<i32: 1, 1>}> ({
// CHECK-GENERIC-NEXT: ^0(%{{.*}}: f32, %{{.*}}: f32):
// CHECK-GENERIC-NEXT: "linalg.yield"(%{{.*}}) : (f32) -> ()
// CHECK-GENERIC: "linalg.generic"(%{{.*}} %{{.*}} <{"indexing_maps" = [affine_map<(d0, d1) -> ()>, affine_map<(d0, d1) -> (d0, d1)>], "iterator_types" = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>], "operandSegmentSizes" = array<i32: 1, 1>}> ({
// CHECK-GENERIC-NEXT: ^0(%{{.*}} f32, %{{.*}} f32):
// CHECK-GENERIC-NEXT: "linalg.yield"(%{{.*}} : (f32) -> ()
// CHECK-GENERIC-NEXT: }) : (f32, memref<1x256xf32>) -> ()
// CHECK-GENERIC-NEXT: "linalg.generic"(%0, %1) <{"indexing_maps" = [affine_map<(d0, d1) -> ()>, affine_map<(d0, d1) -> (d0, d1)>], "iterator_types" = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>], "doc" = "a_docstring", "library_call" = "a_library_call", "operandSegmentSizes" = array<i32: 1, 1>}> ({
// CHECK-GENERIC-NEXT: ^1(%arg3_1 : f32, %arg4_1 : f32):
// CHECK-GENERIC-NEXT: "linalg.yield"(%arg3_1) : (f32) -> ()
// CHECK-GENERIC-NEXT: "linalg.generic"(%{{.*}} %{{.*}} <{"indexing_maps" = [affine_map<(d0, d1) -> ()>, affine_map<(d0, d1) -> (d0, d1)>], "iterator_types" = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>], "doc" = "a_docstring", "library_call" = "a_library_call", "operandSegmentSizes" = array<i32: 1, 1>}> ({
// CHECK-GENERIC-NEXT: ^1(%{{.*}} : f32, %{{.*}} : f32):
// CHECK-GENERIC-NEXT: "linalg.yield"(%{{.*}} : (f32) -> ()
// CHECK-GENERIC-NEXT: }) : (f32, memref<1x256xf32>) -> ()

// CHECK-GENERIC: "linalg.generic"(%0, %1) <{"indexing_maps" = [affine_map<(d0, d1) -> ()>, affine_map<(d0, d1) -> (d0, d1)>], "iterator_types" = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>], "operandSegmentSizes" = array<i32: 1, 1>}> ({
// CHECK-GENERIC-NEXT: ^{{.*}}(%{{.*}}: f32, %{{.*}}: f32):
// CHECK-GENERIC-NEXT: "linalg.yield"(%{{.*}}) : (f32) -> ()
// CHECK-GENERIC: "linalg.generic"(%{{.*}} %{{.*}} <{"indexing_maps" = [affine_map<(d0, d1) -> ()>, affine_map<(d0, d1) -> (d0, d1)>], "iterator_types" = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>], "operandSegmentSizes" = array<i32: 1, 1>}> ({
// CHECK-GENERIC-NEXT: ^{{.*}}(%{{.*}} f32, %{{.*}} f32):
// CHECK-GENERIC-NEXT: "linalg.yield"(%{{.*}} : (f32) -> ()
// CHECK-GENERIC-NEXT: }) {"hello" = "world"} : (f32, memref<1x256xf32>) -> ()

// CHECK-GENERIC-NEXT: %t1, %t2, %t3 = "test.op"() : () -> (tensor<4x16xf32>, tensor<4x16xf32>, tensor<4x16xf32>)
// CHECK-GENERIC-NEXT: %m1, %m2, %m3 = "test.op"() : () -> (memref<4x16xf32>, memref<4x16xf32>, memref<4x16xf32>)
// CHECK-GENERIC-NEXT: %sum = "linalg.add"(%t1, %t2, %t3) <{"operandSegmentSizes" = array<i32: 2, 1>}> ({
// CHECK-GENERIC-NEXT: ^3(%2 : f32, %3 : f32, %4 : f32):
// CHECK-GENERIC-NEXT: %5 = "arith.addf"(%2, %3) : (f32, f32) -> f32
// CHECK-GENERIC-NEXT: "linalg.yield"(%5) : (f32) -> ()
// CHECK-GENERIC-NEXT: %{{.*}} %{{.*}} %{{.*}} = "test.op"() : () -> (tensor<4x16xf32>, tensor<4x16xf32>, tensor<4x16xf32>)
// CHECK-GENERIC-NEXT: %{{.*}} %{{.*}} %{{.*}} = "test.op"() : () -> (memref<4x16xf32>, memref<4x16xf32>, memref<4x16xf32>)

// CHECK-GENERIC-NEXT: %{{.*}} = "linalg.add"(%{{.*}} %{{.*}} %{{.*}} <{"operandSegmentSizes" = array<i32: 2, 1>}> ({
// CHECK-GENERIC-NEXT: ^3(%{{.*}} : f32, %{{.*}} : f32, %{{.*}} : f32):
// CHECK-GENERIC-NEXT: %{{.*}} = "arith.addf"(%{{.*}} %{{.*}} : (f32, f32) -> f32
// CHECK-GENERIC-NEXT: "linalg.yield"(%{{.*}} : (f32) -> ()
// CHECK-GENERIC-NEXT: }) : (tensor<4x16xf32>, tensor<4x16xf32>, tensor<4x16xf32>) -> tensor<4x16xf32>
// CHECK-GENERIC-NEXT: "linalg.add"(%m1, %m2, %m3) <{"operandSegmentSizes" = array<i32: 2, 1>}> ({
// CHECK-GENERIC-NEXT: ^4(%6 : f32, %7 : f32, %8 : f32):
// CHECK-GENERIC-NEXT: %9 = "arith.addf"(%6, %7) : (f32, f32) -> f32
// CHECK-GENERIC-NEXT: "linalg.yield"(%9) : (f32) -> ()

// CHECK-GENERIC-NEXT: "linalg.add"(%{{.*}} %{{.*}} %{{.*}} <{"operandSegmentSizes" = array<i32: 2, 1>}> ({
// CHECK-GENERIC-NEXT: ^4(%{{.*}} : f32, %{{.*}} : f32, %{{.*}} : f32):
// CHECK-GENERIC-NEXT: %{{.*}} = "arith.addf"(%{{.*}} %{{.*}} : (f32, f32) -> f32
// CHECK-GENERIC-NEXT: "linalg.yield"(%{{.*}} : (f32) -> ()
// CHECK-GENERIC-NEXT: }) : (memref<4x16xf32>, memref<4x16xf32>, memref<4x16xf32>) -> ()
// CHECK-GENERIC-NEXT: %mul = "linalg.mul"(%t1, %t2, %t3) <{"operandSegmentSizes" = array<i32: 2, 1>}> ({
// CHECK-GENERIC-NEXT: ^5(%10 : f32, %11 : f32, %12 : f32):
// CHECK-GENERIC-NEXT: %13 = "arith.mulf"(%10, %11) : (f32, f32) -> f32
// CHECK-GENERIC-NEXT: "linalg.yield"(%13) : (f32) -> ()

// CHECK-GENERIC-NEXT: %{{.*}} = "linalg.mul"(%{{.*}} %{{.*}} %{{.*}} <{"operandSegmentSizes" = array<i32: 2, 1>}> ({
// CHECK-GENERIC-NEXT: ^5(%{{.*}} : f32, %{{.*}} : f32, %{{.*}} : f32):
// CHECK-GENERIC-NEXT: %{{.*}} = "arith.mulf"(%{{.*}} %{{.*}} : (f32, f32) -> f32
// CHECK-GENERIC-NEXT: "linalg.yield"(%{{.*}} : (f32) -> ()
// CHECK-GENERIC-NEXT: }) : (tensor<4x16xf32>, tensor<4x16xf32>, tensor<4x16xf32>) -> tensor<4x16xf32>
// CHECK-GENERIC-NEXT: "linalg.mul"(%m1, %m2, %m3) <{"operandSegmentSizes" = array<i32: 2, 1>}> ({
// CHECK-GENERIC-NEXT: ^6(%14 : f32, %15 : f32, %16 : f32):
// CHECK-GENERIC-NEXT: %17 = "arith.mulf"(%14, %15) : (f32, f32) -> f32
// CHECK-GENERIC-NEXT: "linalg.yield"(%17) : (f32) -> ()

// CHECK-GENERIC-NEXT: "linalg.mul"(%{{.*}} %{{.*}} %{{.*}} <{"operandSegmentSizes" = array<i32: 2, 1>}> ({
// CHECK-GENERIC-NEXT: ^6(%{{.*}} : f32, %{{.*}} : f32, %{{.*}} : f32):
// CHECK-GENERIC-NEXT: %{{.*}} = "arith.mulf"(%{{.*}} %{{.*}} : (f32, f32) -> f32
// CHECK-GENERIC-NEXT: "linalg.yield"(%{{.*}} : (f32) -> ()
// CHECK-GENERIC-NEXT: }) : (memref<4x16xf32>, memref<4x16xf32>, memref<4x16xf32>) -> ()

// CHECK-GENERIC-NEXT: %{{.*}} %{{.*}} = "test.op"() : () -> (memref<64x9216xf32>, memref<9216x4096xf32>)
// CHECK-GENERIC-NEXT: %{{.*}} = "test.op"() : () -> memref<64x4096xf32>

// CHECK-GENERIC-NEXT: "linalg.matmul"(%{{.*}} %{{.*}} %{{.*}} <{"operandSegmentSizes" = array<i32: 2, 1>}> ({
// CHECK-GENERIC-NEXT: ^7(%{{.*}} : f32, %{{.*}} : f32, %{{.*}} : f32):
// CHECK-GENERIC-NEXT: %{{.*}} = "arith.mulf"(%{{.*}}, %{{.*}} : (f32, f32) -> f32
// CHECK-GENERIC-NEXT: %{{.*}} = "arith.addf"(%{{.*}}, %{{.*}} : (f32, f32) -> f32
// CHECK-GENERIC-NEXT: "linalg.yield"(%{{.*}} : (f32) -> ()
// CHECK-GENERIC-NEXT: }) {"id", "linalg.memoized_indexing_maps" = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>]} : (memref<64x9216xf32>, memref<9216x4096xf32>, memref<64x4096xf32>) -> ()

// CHECK-GENERIC-NEXT: %{{.*}} = "linalg.fill"(%{{.*}}, %{{.*}} <{"operandSegmentSizes" = array<i32: 1, 1>}> ({
// CHECK-GENERIC-NEXT: ^8(%{{.*}} : f32, %{{.*}} : f32):
// CHECK-GENERIC-NEXT: "linalg.yield"(%{{.*}} : (f32) -> ()
// CHECK-GENERIC-NEXT: }) : (f32, tensor<4x16xf32>) -> tensor<4x16xf32>

// CHECK-GENERIC-NEXT: "linalg.fill"(%{{.*}}, %{{.*}} <{"operandSegmentSizes" = array<i32: 1, 1>}> ({
// CHECK-GENERIC-NEXT: ^9(%{{.*}} : f32, %{{.*}} : f32):
// CHECK-GENERIC-NEXT: "linalg.yield"(%{{.*}} : (f32) -> ()
// CHECK-GENERIC-NEXT: }) : (f32, memref<4x16xf32>) -> ()
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ builtin.module {
%0, %1 = "test.op"() : () -> (tensor<2x3xf32>, tensor<2x3xf32>)

// CHECK: Input type is tensor<2x3xf32> but must be an instance of AnyFloat or IntegerType.
%res_fill = "linalg.fill"(%0, %1) <{"operandSegmentSizes" = array<i32: 1, 1>}> : (tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>
%res_fill = linalg.fill ins (%0: tensor<2x3xf32>) outs (%1: tensor<2x3xf32>) -> tensor<2x3xf32>

}

Expand Down
7 changes: 2 additions & 5 deletions tests/interpreters/test_linalg_interpreter.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from typing import cast

import pytest

from xdsl.builder import ImplicitBuilder
Expand All @@ -22,7 +20,7 @@
from xdsl.interpreters.linalg import LinalgFunctions
from xdsl.interpreters.ptr import TypedPtr
from xdsl.interpreters.shaped_array import ShapedArray
from xdsl.ir import Attribute, Block, Region
from xdsl.ir import Block, Region
from xdsl.ir.affine import AffineExpr, AffineMap
from xdsl.utils.test_value import TestSSAValue

Expand Down Expand Up @@ -199,9 +197,8 @@ def test_fill_op():
interpreter.register_implementations(ArithFunctions())
interpreter.register_implementations(LinalgFunctions())
constant = arith.Constant(FloatAttr(1.0, f32))
constant = cast(Attribute, constant)
op = linalg.FillOp(
(TestSSAValue(constant),),
(TestSSAValue(constant.result.type),),
(TestSSAValue(TensorType(f32, [2, 3])),),
(TensorType(f32, [2, 3]),),
)
Expand Down
Loading

0 comments on commit 3543d82

Please sign in to comment.