From 58e2c361b9bd8497330e2d33b0a4305e6c80e52b Mon Sep 17 00:00:00 2001 From: Hugo Pompougnac Date: Thu, 21 Nov 2024 11:23:56 +0000 Subject: [PATCH 1/4] dialects: (linalg) linalg.fill attribute positioning --- .../with-mlir/dialects/linalg/ops.mlir | 28 +++++++++++-------- xdsl/dialects/linalg.py | 2 ++ 2 files changed, 19 insertions(+), 11 deletions(-) diff --git a/tests/filecheck/mlir-conversion/with-mlir/dialects/linalg/ops.mlir b/tests/filecheck/mlir-conversion/with-mlir/dialects/linalg/ops.mlir index 3a62737b2c..e66daaba59 100644 --- a/tests/filecheck/mlir-conversion/with-mlir/dialects/linalg/ops.mlir +++ b/tests/filecheck/mlir-conversion/with-mlir/dialects/linalg/ops.mlir @@ -58,6 +58,9 @@ linalg.fill ins(%4 : f32) outs(%1 : memref<1x256xf32>) %18, %19 = "test.op"() : () -> (memref<64x9216xf32>, memref<9216x4096xf32>) %20 = "test.op"() : () -> (memref<64x4096xf32>) +%zero = arith.constant 0: f32 + +linalg.fill {id} ins(%zero : f32) outs(%20 : memref<64x4096xf32>) linalg.matmul {id} ins(%18, %19 : memref<64x9216xf32>, memref<9216x4096xf32>) outs(%20 : memref<64x4096xf32>) @@ -69,7 +72,7 @@ linalg.matmul {id} ins(%18, %19 : memref<64x9216xf32>, memref<9216x4096xf32>) ou %quant_mat_mul = linalg.quantized_matmul ins(%21, %22, %23, %24 : tensor<64x9216xi8>, tensor<9216x4096xi8>, i32, i32) outs(%25 : tensor<64x4096xi32>) -> tensor<64x4096xi32> -// CHECK-NEXT: #map = affine_map<(d0, d1) -> ()> +// CHECK: #map = affine_map<(d0, d1) -> ()> // CHECK-NEXT: #map1 = affine_map<(d0, d1) -> (d0, d1)> // CHECK-NEXT: module { // CHECK-NEXT: %0:2 = "test.op"() : () -> (f32, memref<1x256xf32>) @@ -86,9 +89,9 @@ linalg.matmul {id} ins(%18, %19 : memref<64x9216xf32>, memref<9216x4096xf32>) ou // CHECK-NEXT: %cst = arith.constant 0.000000e+00 : f32 // CHECK-NEXT: %3 = linalg.fill ins(%cst : f32) outs(%1#0 : tensor<2x3xf32>) -> tensor<2x3xf32> // CHECK-NEXT: linalg.fill ins(%cst : f32) outs(%0#1 : memref<1x256xf32>) -// CHECK-NEXT: %4 = linalg.mul ins(%1#0, %1#0 : tensor<2x3xf32>, tensor<2x3xf32>) outs(%1#1 : tensor<2x3xf32>) -> tensor<2x3xf32> +// CHECK-NEXT: %4 = linalg.mul ins(%1#0, %1#0 : tensor<2x3xf32>, tensor<2x3xf32>) outs(%1#1 : tensor<2x3xf32>) -> tensor<2x3xf32> // CHECK-NEXT: %5:2 = "test.op"() : () -> (tensor<16x64xf32>, tensor<64x16xf32>) -// CHECK-NEXT: %transposed = linalg.transpose ins(%5#0 : tensor<16x64xf32>) outs(%5#1 : tensor<64x16xf32>) permutation = [1, 0] +// CHECK-NEXT: %transposed = linalg.transpose ins(%5#0 : tensor<16x64xf32>) outs(%5#1 : tensor<64x16xf32>) permutation = [1, 0] // CHECK-NEXT: %6:2 = "test.op"() : () -> (tensor<64x9216xf32>, tensor<9216x4096xf32>) // CHECK-NEXT: %7 = "test.op"() : () -> tensor<64x4096xf32> // CHECK-NEXT: %8 = linalg.matmul ins(%6#0, %6#1 : tensor<64x9216xf32>, tensor<9216x4096xf32>) outs(%7 : tensor<64x4096xf32>) -> tensor<64x4096xf32> @@ -97,19 +100,22 @@ linalg.matmul {id} ins(%18, %19 : memref<64x9216xf32>, memref<9216x4096xf32>) ou // CHECK-NEXT: %11:3 = "test.op"() : () -> (tensor<1x1x5x5xf32>, tensor<1x1x3x3xf32>, tensor<1x1x3x3xf32>) // CHECK-NEXT: %12 = linalg.conv_2d_nchw_fchw {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%11#0, %11#1 : tensor<1x1x5x5xf32>, tensor<1x1x3x3xf32>) outs(%11#2 : tensor<1x1x3x3xf32>) -> tensor<1x1x3x3xf32> // CHECK-NEXT: %13:2 = "test.op"() : () -> (tensor<16xf32>, tensor<16x64xf32>) -// CHECK-NEXT: %broadcasted = linalg.broadcast ins(%13#0 : tensor<16xf32>) outs(%13#1 : tensor<16x64xf32>) dimensions = [1] -// CHECK-NEXT: %{{.*}} = linalg.generic {indexing_maps = [#map1, #map1, 
#map1], iterator_types = ["parallel", "parallel"]} ins(%1#0, %1#0 : tensor<2x3xf32>, tensor<2x3xf32>) outs(%1#0 : tensor<2x3xf32>) { -// CHECK-NEXT: ^bb0(%in: f32, %in_1: f32, %out: f32): -// CHECK-NEXT: %{{.*}} = arith.addf %in, %in_1 : f32 -// CHECK-NEXT: linalg.yield %{{.*}} : f32 +// CHECK-NEXT: %broadcasted = linalg.broadcast ins(%13#0 : tensor<16xf32>) outs(%13#1 : tensor<16x64xf32>) dimensions = [1] +// CHECK-NEXT: %14 = linalg.generic {indexing_maps = [#map1, #map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%1#0, %1#0 : tensor<2x3xf32>, tensor<2x3xf32>) outs(%1#0 : tensor<2x3xf32>) { +// CHECK-NEXT: ^bb0(%in: f32, %in_2: f32, %out: f32): +// CHECK-NEXT: %21 = arith.addf %in, %in_2 : f32 +// CHECK-NEXT: linalg.yield %21 : f32 // CHECK-NEXT: } -> tensor<2x3xf32> -// CHECK-NEXT: %{{.*}} = linalg.sub ins(%{{.*}}, %{{.*}} : tensor<2x3xf32>, tensor<2x3xf32>) outs(%{{.*}} : tensor<2x3xf32>) -> tensor<2x3xf32> +// CHECK-NEXT: %15 = linalg.sub ins(%1#0, %1#0 : tensor<2x3xf32>, tensor<2x3xf32>) outs(%1#1 : tensor<2x3xf32>) -> tensor<2x3xf32> // CHECK-NEXT: %16:2 = "test.op"() : () -> (memref<64x9216xf32>, memref<9216x4096xf32>) // CHECK-NEXT: %17 = "test.op"() : () -> memref<64x4096xf32> +// CHECK-NEXT: %cst_0 = arith.constant 0.000000e+00 : f32 +// CHECK-NEXT: linalg.fill {id} ins(%cst_0 : f32) outs(%17 : memref<64x4096xf32>) // CHECK-NEXT: linalg.matmul {id} ins(%16#0, %16#1 : memref<64x9216xf32>, memref<9216x4096xf32>) outs(%17 : memref<64x4096xf32>) // CHECK-NEXT: %18:2 = "test.op"() : () -> (tensor<64x9216xi8>, tensor<9216x4096xi8>) // CHECK-NEXT: %c0_i32 = arith.constant 0 : i32 -// CHECK-NEXT: %c0_i32_0 = arith.constant 0 : i32 +// CHECK-NEXT: %c0_i32_1 = arith.constant 0 : i32 // CHECK-NEXT: %19 = "test.op"() : () -> tensor<64x4096xi32> -// CHECK-NEXT: %20 = linalg.quantized_matmul ins(%18#0, %18#1, %c0_i32, %c0_i32_0 : tensor<64x9216xi8>, tensor<9216x4096xi8>, i32, i32) outs(%19 : tensor<64x4096xi32>) -> tensor<64x4096xi32> +// CHECK-NEXT: %20 = linalg.quantized_matmul ins(%18#0, %18#1, %c0_i32, %c0_i32_1 : tensor<64x9216xi8>, tensor<9216x4096xi8>, i32, i32) outs(%19 : tensor<64x4096xi32>) -> tensor<64x4096xi32> // CHECK-NEXT: } +// CHECK-NEXT: diff --git a/xdsl/dialects/linalg.py b/xdsl/dialects/linalg.py index 854ed50c86..dd1097d67e 100644 --- a/xdsl/dialects/linalg.py +++ b/xdsl/dialects/linalg.py @@ -642,6 +642,8 @@ class FillOp(NamedOpBase): name = "linalg.fill" + PRINT_ATTRS_IN_FRONT: ClassVar[bool] = True + def __init__( self, inputs: Sequence[SSAValue], From 512bfc97deda96d1719f9f8d4fe7e47af1dd7927 Mon Sep 17 00:00:00 2001 From: Hugo Pompougnac Date: Thu, 21 Nov 2024 22:53:26 +0000 Subject: [PATCH 2/4] More conservative rewrite --- .../mlir-conversion/with-mlir/dialects/linalg/ops.mlir | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/filecheck/mlir-conversion/with-mlir/dialects/linalg/ops.mlir b/tests/filecheck/mlir-conversion/with-mlir/dialects/linalg/ops.mlir index e66daaba59..01bc06a7e6 100644 --- a/tests/filecheck/mlir-conversion/with-mlir/dialects/linalg/ops.mlir +++ b/tests/filecheck/mlir-conversion/with-mlir/dialects/linalg/ops.mlir @@ -72,7 +72,7 @@ linalg.matmul {id} ins(%18, %19 : memref<64x9216xf32>, memref<9216x4096xf32>) ou %quant_mat_mul = linalg.quantized_matmul ins(%21, %22, %23, %24 : tensor<64x9216xi8>, tensor<9216x4096xi8>, i32, i32) outs(%25 : tensor<64x4096xi32>) -> tensor<64x4096xi32> -// CHECK: #map = affine_map<(d0, d1) -> ()> +// CHECK-NEXT: #map = affine_map<(d0, d1) -> ()> // CHECK-NEXT: #map1 = 
affine_map<(d0, d1) -> (d0, d1)> // CHECK-NEXT: module { // CHECK-NEXT: %0:2 = "test.op"() : () -> (f32, memref<1x256xf32>) @@ -89,9 +89,9 @@ linalg.matmul {id} ins(%18, %19 : memref<64x9216xf32>, memref<9216x4096xf32>) ou // CHECK-NEXT: %cst = arith.constant 0.000000e+00 : f32 // CHECK-NEXT: %3 = linalg.fill ins(%cst : f32) outs(%1#0 : tensor<2x3xf32>) -> tensor<2x3xf32> // CHECK-NEXT: linalg.fill ins(%cst : f32) outs(%0#1 : memref<1x256xf32>) -// CHECK-NEXT: %4 = linalg.mul ins(%1#0, %1#0 : tensor<2x3xf32>, tensor<2x3xf32>) outs(%1#1 : tensor<2x3xf32>) -> tensor<2x3xf32> +// CHECK-NEXT: %4 = linalg.mul ins(%1#0, %1#0 : tensor<2x3xf32>, tensor<2x3xf32>) outs(%1#1 : tensor<2x3xf32>) -> tensor<2x3xf32> // CHECK-NEXT: %5:2 = "test.op"() : () -> (tensor<16x64xf32>, tensor<64x16xf32>) -// CHECK-NEXT: %transposed = linalg.transpose ins(%5#0 : tensor<16x64xf32>) outs(%5#1 : tensor<64x16xf32>) permutation = [1, 0] +// CHECK-NEXT: %transposed = linalg.transpose ins(%5#0 : tensor<16x64xf32>) outs(%5#1 : tensor<64x16xf32>) permutation = [1, 0] // CHECK-NEXT: %6:2 = "test.op"() : () -> (tensor<64x9216xf32>, tensor<9216x4096xf32>) // CHECK-NEXT: %7 = "test.op"() : () -> tensor<64x4096xf32> // CHECK-NEXT: %8 = linalg.matmul ins(%6#0, %6#1 : tensor<64x9216xf32>, tensor<9216x4096xf32>) outs(%7 : tensor<64x4096xf32>) -> tensor<64x4096xf32> From 10a895724a747cc779284c5d21d6ec00cb1b9ff4 Mon Sep 17 00:00:00 2001 From: Hugo Pompougnac Date: Thu, 21 Nov 2024 22:56:53 +0000 Subject: [PATCH 3/4] Revert the test rewrite --- .../with-mlir/dialects/linalg/ops.mlir | 22 +++++++------------ 1 file changed, 8 insertions(+), 14 deletions(-) diff --git a/tests/filecheck/mlir-conversion/with-mlir/dialects/linalg/ops.mlir b/tests/filecheck/mlir-conversion/with-mlir/dialects/linalg/ops.mlir index 01bc06a7e6..3a62737b2c 100644 --- a/tests/filecheck/mlir-conversion/with-mlir/dialects/linalg/ops.mlir +++ b/tests/filecheck/mlir-conversion/with-mlir/dialects/linalg/ops.mlir @@ -58,9 +58,6 @@ linalg.fill ins(%4 : f32) outs(%1 : memref<1x256xf32>) %18, %19 = "test.op"() : () -> (memref<64x9216xf32>, memref<9216x4096xf32>) %20 = "test.op"() : () -> (memref<64x4096xf32>) -%zero = arith.constant 0: f32 - -linalg.fill {id} ins(%zero : f32) outs(%20 : memref<64x4096xf32>) linalg.matmul {id} ins(%18, %19 : memref<64x9216xf32>, memref<9216x4096xf32>) outs(%20 : memref<64x4096xf32>) @@ -100,22 +97,19 @@ linalg.matmul {id} ins(%18, %19 : memref<64x9216xf32>, memref<9216x4096xf32>) ou // CHECK-NEXT: %11:3 = "test.op"() : () -> (tensor<1x1x5x5xf32>, tensor<1x1x3x3xf32>, tensor<1x1x3x3xf32>) // CHECK-NEXT: %12 = linalg.conv_2d_nchw_fchw {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%11#0, %11#1 : tensor<1x1x5x5xf32>, tensor<1x1x3x3xf32>) outs(%11#2 : tensor<1x1x3x3xf32>) -> tensor<1x1x3x3xf32> // CHECK-NEXT: %13:2 = "test.op"() : () -> (tensor<16xf32>, tensor<16x64xf32>) -// CHECK-NEXT: %broadcasted = linalg.broadcast ins(%13#0 : tensor<16xf32>) outs(%13#1 : tensor<16x64xf32>) dimensions = [1] -// CHECK-NEXT: %14 = linalg.generic {indexing_maps = [#map1, #map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%1#0, %1#0 : tensor<2x3xf32>, tensor<2x3xf32>) outs(%1#0 : tensor<2x3xf32>) { -// CHECK-NEXT: ^bb0(%in: f32, %in_2: f32, %out: f32): -// CHECK-NEXT: %21 = arith.addf %in, %in_2 : f32 -// CHECK-NEXT: linalg.yield %21 : f32 +// CHECK-NEXT: %broadcasted = linalg.broadcast ins(%13#0 : tensor<16xf32>) outs(%13#1 : tensor<16x64xf32>) dimensions = [1] +// CHECK-NEXT: %{{.*}} = linalg.generic 
{indexing_maps = [#map1, #map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%1#0, %1#0 : tensor<2x3xf32>, tensor<2x3xf32>) outs(%1#0 : tensor<2x3xf32>) { +// CHECK-NEXT: ^bb0(%in: f32, %in_1: f32, %out: f32): +// CHECK-NEXT: %{{.*}} = arith.addf %in, %in_1 : f32 +// CHECK-NEXT: linalg.yield %{{.*}} : f32 // CHECK-NEXT: } -> tensor<2x3xf32> -// CHECK-NEXT: %15 = linalg.sub ins(%1#0, %1#0 : tensor<2x3xf32>, tensor<2x3xf32>) outs(%1#1 : tensor<2x3xf32>) -> tensor<2x3xf32> +// CHECK-NEXT: %{{.*}} = linalg.sub ins(%{{.*}}, %{{.*}} : tensor<2x3xf32>, tensor<2x3xf32>) outs(%{{.*}} : tensor<2x3xf32>) -> tensor<2x3xf32> // CHECK-NEXT: %16:2 = "test.op"() : () -> (memref<64x9216xf32>, memref<9216x4096xf32>) // CHECK-NEXT: %17 = "test.op"() : () -> memref<64x4096xf32> -// CHECK-NEXT: %cst_0 = arith.constant 0.000000e+00 : f32 -// CHECK-NEXT: linalg.fill {id} ins(%cst_0 : f32) outs(%17 : memref<64x4096xf32>) // CHECK-NEXT: linalg.matmul {id} ins(%16#0, %16#1 : memref<64x9216xf32>, memref<9216x4096xf32>) outs(%17 : memref<64x4096xf32>) // CHECK-NEXT: %18:2 = "test.op"() : () -> (tensor<64x9216xi8>, tensor<9216x4096xi8>) // CHECK-NEXT: %c0_i32 = arith.constant 0 : i32 -// CHECK-NEXT: %c0_i32_1 = arith.constant 0 : i32 +// CHECK-NEXT: %c0_i32_0 = arith.constant 0 : i32 // CHECK-NEXT: %19 = "test.op"() : () -> tensor<64x4096xi32> -// CHECK-NEXT: %20 = linalg.quantized_matmul ins(%18#0, %18#1, %c0_i32, %c0_i32_1 : tensor<64x9216xi8>, tensor<9216x4096xi8>, i32, i32) outs(%19 : tensor<64x4096xi32>) -> tensor<64x4096xi32> +// CHECK-NEXT: %20 = linalg.quantized_matmul ins(%18#0, %18#1, %c0_i32, %c0_i32_0 : tensor<64x9216xi8>, tensor<9216x4096xi8>, i32, i32) outs(%19 : tensor<64x4096xi32>) -> tensor<64x4096xi32> // CHECK-NEXT: } -// CHECK-NEXT: From 0f06b170e310f207454f57bfacefe4194ec08f34 Mon Sep 17 00:00:00 2001 From: Hugo Pompougnac Date: Thu, 21 Nov 2024 23:05:09 +0000 Subject: [PATCH 4/4] Correction of the test --- .../with-mlir/dialects/linalg/ops.mlir | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/tests/filecheck/mlir-conversion/with-mlir/dialects/linalg/ops.mlir b/tests/filecheck/mlir-conversion/with-mlir/dialects/linalg/ops.mlir index 3a62737b2c..9d942b95d9 100644 --- a/tests/filecheck/mlir-conversion/with-mlir/dialects/linalg/ops.mlir +++ b/tests/filecheck/mlir-conversion/with-mlir/dialects/linalg/ops.mlir @@ -59,6 +59,9 @@ linalg.fill ins(%4 : f32) outs(%1 : memref<1x256xf32>) %18, %19 = "test.op"() : () -> (memref<64x9216xf32>, memref<9216x4096xf32>) %20 = "test.op"() : () -> (memref<64x4096xf32>) +%zero = arith.constant 0: f32 +linalg.fill {id} ins(%zero : f32) outs(%20 : memref<64x4096xf32>) + linalg.matmul {id} ins(%18, %19 : memref<64x9216xf32>, memref<9216x4096xf32>) outs(%20 : memref<64x4096xf32>) @@ -99,17 +102,19 @@ linalg.matmul {id} ins(%18, %19 : memref<64x9216xf32>, memref<9216x4096xf32>) ou // CHECK-NEXT: %13:2 = "test.op"() : () -> (tensor<16xf32>, tensor<16x64xf32>) // CHECK-NEXT: %broadcasted = linalg.broadcast ins(%13#0 : tensor<16xf32>) outs(%13#1 : tensor<16x64xf32>) dimensions = [1] // CHECK-NEXT: %{{.*}} = linalg.generic {indexing_maps = [#map1, #map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%1#0, %1#0 : tensor<2x3xf32>, tensor<2x3xf32>) outs(%1#0 : tensor<2x3xf32>) { -// CHECK-NEXT: ^bb0(%in: f32, %in_1: f32, %out: f32): -// CHECK-NEXT: %{{.*}} = arith.addf %in, %in_1 : f32 +// CHECK-NEXT: ^bb0(%in: f32, %in_2: f32, %out: f32): +// CHECK-NEXT: %{{.*}} = arith.addf %in, %in_2 : f32 // CHECK-NEXT: linalg.yield %{{.*}} : 
f32 // CHECK-NEXT: } -> tensor<2x3xf32> // CHECK-NEXT: %{{.*}} = linalg.sub ins(%{{.*}}, %{{.*}} : tensor<2x3xf32>, tensor<2x3xf32>) outs(%{{.*}} : tensor<2x3xf32>) -> tensor<2x3xf32> // CHECK-NEXT: %16:2 = "test.op"() : () -> (memref<64x9216xf32>, memref<9216x4096xf32>) // CHECK-NEXT: %17 = "test.op"() : () -> memref<64x4096xf32> +// CHECK-NEXT: %cst_0 = arith.constant 0.000000e+00 : f32 +// CHECK-NEXT: linalg.fill {id} ins(%cst_0 : f32) outs(%17 : memref<64x4096xf32>) // CHECK-NEXT: linalg.matmul {id} ins(%16#0, %16#1 : memref<64x9216xf32>, memref<9216x4096xf32>) outs(%17 : memref<64x4096xf32>) // CHECK-NEXT: %18:2 = "test.op"() : () -> (tensor<64x9216xi8>, tensor<9216x4096xi8>) // CHECK-NEXT: %c0_i32 = arith.constant 0 : i32 -// CHECK-NEXT: %c0_i32_0 = arith.constant 0 : i32 +// CHECK-NEXT: %c0_i32_1 = arith.constant 0 : i32 // CHECK-NEXT: %19 = "test.op"() : () -> tensor<64x4096xi32> -// CHECK-NEXT: %20 = linalg.quantized_matmul ins(%18#0, %18#1, %c0_i32, %c0_i32_0 : tensor<64x9216xi8>, tensor<9216x4096xi8>, i32, i32) outs(%19 : tensor<64x4096xi32>) -> tensor<64x4096xi32> +// CHECK-NEXT: %20 = linalg.quantized_matmul ins(%18#0, %18#1, %c0_i32, %c0_i32_1 : tensor<64x9216xi8>, tensor<9216x4096xi8>, i32, i32) outs(%19 : tensor<64x4096xi32>) -> tensor<64x4096xi32> // CHECK-NEXT: }
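
For context on PATCH 1/4: the change to xdsl/dialects/linalg.py is a single class-level flag, PRINT_ATTRS_IN_FRONT, and the test additions exercise it through `linalg.fill {id} ins(...) outs(...)`. Below is a minimal, hypothetical sketch of that pattern, assuming a NamedOpBase-style printer that consults the flag to decide whether attributes such as {id} are printed right after the operation name (matching MLIR's output for linalg.fill) or after the operands; it is an illustration of the idea, not xDSL's actual printer implementation.

from typing import ClassVar

class NamedOpBase:
    """Hypothetical stand-in for a named-op base class with a generic printer."""

    name: ClassVar[str]
    PRINT_ATTRS_IN_FRONT: ClassVar[bool] = False

    def print_op(self, attrs: str, ins: str, outs: str) -> str:
        # When the flag is set, attributes go between the op name and the
        # ins(...) clause, which is where MLIR places them for linalg.fill.
        if self.PRINT_ATTRS_IN_FRONT:
            return f"{self.name} {attrs} ins({ins}) outs({outs})"
        return f"{self.name} ins({ins}) outs({outs}) {attrs}"

class FillOp(NamedOpBase):
    name = "linalg.fill"
    # The one line PATCH 1/4 adds to FillOp in xdsl/dialects/linalg.py.
    PRINT_ATTRS_IN_FRONT: ClassVar[bool] = True

print(FillOp().print_op("{id}", "%zero : f32", "%20 : memref<64x4096xf32>"))
# linalg.fill {id} ins(%zero : f32) outs(%20 : memref<64x4096xf32>)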