From 50a001f3c2af33691490e136f071654f32d65c33 Mon Sep 17 00:00:00 2001
From: MaheshRavishankar <1663364+MaheshRavishankar@users.noreply.github.com>
Date: Mon, 18 Nov 2024 11:30:01 -0800
Subject: [PATCH] [DispatchCreation] Add CSE before canonicalization of
 `flow.dispatch.workgroups` (#19178)

The real change in this PR is to add CSE before canonicalization so that
operands of `flow.dispatch.workgroups` can be better de-duped. Since this
change needed a pipeline test, I noticed that the existing pipeline test
lived in Flow and exercised the global optimization, dispatch creation, and
flow transformation pipelines. This PR moves the test to dispatch creation
and tests only that pipeline.

Signed-off-by: MaheshRavishankar
---
 .../Dialect/Flow/Transforms/test/BUILD.bazel  |   1 -
 .../Flow/Transforms/test/CMakeLists.txt       |   1 -
 .../Flow/Transforms/test/pipeline_tests.mlir  |  99 ------------
 .../iree/compiler/DispatchCreation/Passes.cpp |   1 +
 .../DispatchCreation/test/BUILD.bazel         |   1 +
 .../DispatchCreation/test/CMakeLists.txt      |   1 +
 .../DispatchCreation/test/pipeline_tests.mlir | 151 ++++++++++++++++++
 7 files changed, 154 insertions(+), 101 deletions(-)
 delete mode 100644 compiler/src/iree/compiler/Dialect/Flow/Transforms/test/pipeline_tests.mlir
 create mode 100644 compiler/src/iree/compiler/DispatchCreation/test/pipeline_tests.mlir

diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/BUILD.bazel
index 3ce7528fcb35..ad9c6f310413 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/BUILD.bazel
@@ -30,7 +30,6 @@ iree_lit_test_suite(
             "outline_constants.mlir",
             "outline_dispatch_externs.mlir",
             "outline_dispatch_regions.mlir",
-            "pipeline_tests.mlir",
             "top_level_scf_to_cfg.mlir",
             "verify_input_ir.mlir",
         ],
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt
index 53280e7805b9..e9c4a14901fc 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt
@@ -28,7 +28,6 @@ iree_lit_test_suite(
     "outline_constants.mlir"
     "outline_dispatch_externs.mlir"
     "outline_dispatch_regions.mlir"
-    "pipeline_tests.mlir"
     "top_level_scf_to_cfg.mlir"
     "verify_input_ir.mlir"
   TOOLS
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/pipeline_tests.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/pipeline_tests.mlir
deleted file mode 100644
index 8973ba5a0278..000000000000
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/pipeline_tests.mlir
+++ /dev/null
@@ -1,99 +0,0 @@
-// TODO(hanchung): Split the transformation pipeline tests into two mlir files.
-// RUN: iree-opt --iree-global-optimization-transformation-pipeline --iree-dispatch-creation-pipeline --iree-flow-transformation-pipeline --split-input-file %s | FileCheck %s
-
-#map = affine_map<(d0, d1) -> (d0)>
-#map1 = affine_map<(d0, d1) -> (d1)>
-#map2 = affine_map<(d0, d1) -> (d0, d1)>
-#map3 = affine_map<(d0, d1) -> ()>
-util.func public @main(%arg0: tensor<833xi32>, %arg1: tensor<833x833xf32>, %arg2: tensor<f32>) -> tensor<f32> {
-  %cst = arith.constant 5.66893432E-4 : f32
-  %0 = tensor.empty() : tensor<833x833xf32>
-  %1 = linalg.generic {
-      indexing_maps = [#map2, #map3, #map2], iterator_types = ["parallel", "parallel"]}
-      ins(%arg1, %arg2 : tensor<833x833xf32>, tensor<f32>)
-      outs(%0 : tensor<833x833xf32>) {
-    ^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
-      %2 = arith.divf %b0, %b1 : f32
-      linalg.yield %2 : f32
-  } -> tensor<833x833xf32>
-  %4 = linalg.generic {
-      indexing_maps = [#map, #map1, #map2, #map2], iterator_types = ["parallel", "parallel"]}
-      ins(%arg0, %arg0, %1 : tensor<833xi32>, tensor<833xi32>, tensor<833x833xf32>)
-      outs(%0 : tensor<833x833xf32>) {
-    ^bb0(%b0 : i32, %b1 : i32, %b2 : f32, %b3 : f32):
-      %5 = arith.cmpi eq, %b0, %b1 : i32
-      %6 = arith.select %5, %b2, %cst : f32
-      linalg.yield %6 : f32
-  } -> tensor<833x833xf32>
-  %7 = tensor.empty() : tensor<f32>
-  %8 = linalg.fill ins(%cst : f32) outs(%7 : tensor<f32>) -> tensor<f32>
-  %9 = linalg.generic {
-      indexing_maps = [#map2, #map3], iterator_types = ["reduction", "reduction"]}
-      ins(%4 : tensor<833x833xf32>) outs(%7 : tensor<f32>) {
-    ^bb0(%b0 : f32, %b1 : f32):
-      %10 = arith.addf %b1, %b0 : f32
-      linalg.yield %10 : f32
-  } -> tensor<f32>
-  util.return %9 : tensor<f32>
-}
-// Check that the linalg op with two reduction loops get folded into a single
-// reduction which then prevents the parallel ops to be folded into it.
-// See https://github.com/iree-org/iree/issues/13285
-// CHECK: flow.executable private @[[EXECUTABLE0:[a-zA-Z0-9_]+]]
-// CHECK: func.func @[[FUNC0:[a-zA-Z0-9_x]+]]
-// CHECK: linalg.generic
-// CHECK-SAME: ["reduction", "reduction"]
-// CHECK-NOT: linalg.generic
-// CHECK: util.func public @main(
-// CHECK: %[[T0:.+]] = flow.dispatch @[[EXECUTABLE0]]::@[[FUNC0]]
-// CHECK: util.return %[[T0]]
-
-// -----
-
-#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-#map1 = affine_map<(d0, d1, d2) -> (d0, d1, 0)>
-#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>
-#map3 = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d4)>
-#map4 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
-util.func public @grouped_quantized_matmul(%arg0: tensor<4096x32x128xi4>, %arg1: tensor<1x1x32x128xf32>, %arg2: tensor<4096x32x1xf32>, %arg3: tensor<4096x32x1xf32>) -> tensor<1x1x4096xf32> {
-  %cst = arith.constant 0.000000e+00 : f32
-  %0 = tensor.empty() : tensor<1x1x4096xf32>
-  %1 = tensor.empty() : tensor<4096x32x128xf32>
-  %2 = linalg.fill ins(%cst : f32) outs(%0 : tensor<1x1x4096xf32>) -> tensor<1x1x4096xf32>
-  %3 = linalg.generic {indexing_maps = [#map, #map1, #map1, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0, %arg2, %arg3 : tensor<4096x32x128xi4>, tensor<4096x32x1xf32>, tensor<4096x32x1xf32>) outs(%1 : tensor<4096x32x128xf32>) {
-  ^bb0(%in: i4, %in_0: f32, %in_1: f32, %out: f32):
-    %5 = arith.extui %in : i4 to i32
-    %6 = arith.uitofp %5 : i32 to f32
-    %7 = arith.subf %6, %in_1 : f32
-    %8 = arith.mulf %7, %in_0 : f32
-    linalg.yield %8 : f32
-  } -> tensor<4096x32x128xf32>
-  %4 = linalg.generic {indexing_maps = [#map2, #map3, #map4], iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%arg1, %3 : tensor<1x1x32x128xf32>, tensor<4096x32x128xf32>) outs(%2 : tensor<1x1x4096xf32>) {
-  ^bb0(%in: f32, %in_0: f32, %out: f32):
-    %5 = arith.mulf %in, %in_0 : f32
-    %6 = arith.addf %5, %out : f32
-    linalg.yield %6 : f32
-  } -> tensor<1x1x4096xf32>
-  util.return %4 : tensor<1x1x4096xf32>
-}
-// Check that the two linalg.generic ops are fused into the same dispatch.
-// CHECK: flow.executable private @[[EXECUTABLE0:[a-zA-Z0-9_]+]]
-// CHECK: func.func @[[FUNC0:[a-zA-Z0-9_x]+]]
-// CHECK: %[[GEN0:.+]] = linalg.generic
-// CHECK-SAME: ["parallel", "parallel", "parallel"]
-// CHECK: arith.extui
-// CHECK: arith.uitofp
-// CHECK: arith.subf
-// CHECK: arith.mulf
-// CHECK: %[[GEN1:.+]] = linalg.generic
-// CHECK-SAME: ["parallel", "reduction", "reduction"]
-// CHECK-SAME: ins(
-// CHECK-SAME: %[[GEN0]]
-// CHECK-SAME: outs(
-// CHECK: arith.mulf
-// CHECK: arith.addf
-// CHECK: flow.dispatch.tensor.store %[[GEN1]]
-// CHECK: util.func public @grouped_quantized_matmul(
-// CHECK: %[[T0:.+]] = flow.dispatch @[[EXECUTABLE0]]::@[[FUNC0]]
-// CHECK: %[[RS:.+]] = flow.tensor.reshape %[[T0]] : tensor<4096xf32> -> tensor<1x1x4096xf32>
-// CHECK: util.return %[[RS]]
diff --git a/compiler/src/iree/compiler/DispatchCreation/Passes.cpp b/compiler/src/iree/compiler/DispatchCreation/Passes.cpp
index 3fc56829d86b..abd3850ea3b9 100644
--- a/compiler/src/iree/compiler/DispatchCreation/Passes.cpp
+++ b/compiler/src/iree/compiler/DispatchCreation/Passes.cpp
@@ -312,6 +312,7 @@ void buildDispatchCreationPassPipeline(
       //   acts as a contiguous view of the tensor
       // - Apply tensor -> flow patterns
       .addPass(DispatchCreation::createConvertTensorToFlowPass)
+      .addPass(createCSEPass)
       .addPass(IREE::Flow::createCanonicalizerPass)
       /// Creates the workgroup count region where the materialized computation
       /// is derived as a program slice of the body of the dispatch. This method
diff --git a/compiler/src/iree/compiler/DispatchCreation/test/BUILD.bazel b/compiler/src/iree/compiler/DispatchCreation/test/BUILD.bazel
index 880a55f029c7..f1a6c4b4bdbf 100644
--- a/compiler/src/iree/compiler/DispatchCreation/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/DispatchCreation/test/BUILD.bazel
@@ -41,6 +41,7 @@ iree_lit_test_suite(
             "fusion_preprocessing.mlir",
             "pad_fusion_with_consumer.mlir",
             "pad_fusion_with_producer.mlir",
+            "pipeline_tests.mlir",
             "set_encoding.mlir",
             "sink_reshapes.mlir",
             "split_reduction.mlir",
diff --git a/compiler/src/iree/compiler/DispatchCreation/test/CMakeLists.txt b/compiler/src/iree/compiler/DispatchCreation/test/CMakeLists.txt
index 7de76f95a2a1..769ca7620613 100644
--- a/compiler/src/iree/compiler/DispatchCreation/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/DispatchCreation/test/CMakeLists.txt
@@ -39,6 +39,7 @@ iree_lit_test_suite(
     "hoist_encoding_ops.mlir"
     "pad_fusion_with_consumer.mlir"
     "pad_fusion_with_producer.mlir"
+    "pipeline_tests.mlir"
     "set_encoding.mlir"
     "sink_reshapes.mlir"
     "split_reduction.mlir"
diff --git a/compiler/src/iree/compiler/DispatchCreation/test/pipeline_tests.mlir b/compiler/src/iree/compiler/DispatchCreation/test/pipeline_tests.mlir
new file mode 100644
index 000000000000..03a9e885965d
--- /dev/null
+++ b/compiler/src/iree/compiler/DispatchCreation/test/pipeline_tests.mlir
@@ -0,0 +1,151 @@
+// RUN: iree-opt --pass-pipeline="builtin.module(iree-dispatch-creation-fold-unit-extent-dims, iree-dispatch-creation-pipeline)" --split-input-file --mlir-print-local-scope %s | FileCheck %s
+
+#map = affine_map<(d0, d1) -> (d0)>
+#map1 = affine_map<(d0, d1) -> (d1)>
+#map2 = affine_map<(d0, d1) -> (d0, d1)>
+#map3 = affine_map<(d0, d1) -> ()>
+util.func public @main(%arg0: tensor<833xi32>, %arg1: tensor<833x833xf32>, %arg2: tensor<f32>) -> tensor<f32> {
+  %cst = arith.constant 5.66893432E-4 : f32
+  %0 = tensor.empty() : tensor<833x833xf32>
+  %1 = linalg.generic {
+      indexing_maps = [#map2, #map3, #map2], iterator_types = ["parallel", "parallel"]}
+      ins(%arg1, %arg2 : tensor<833x833xf32>, tensor<f32>)
+      outs(%0 : tensor<833x833xf32>) {
+    ^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
+      %2 = arith.divf %b0, %b1 : f32
+      linalg.yield %2 : f32
+  } -> tensor<833x833xf32>
+  %4 = linalg.generic {
+      indexing_maps = [#map, #map1, #map2, #map2], iterator_types = ["parallel", "parallel"]}
+      ins(%arg0, %arg0, %1 : tensor<833xi32>, tensor<833xi32>, tensor<833x833xf32>)
+      outs(%0 : tensor<833x833xf32>) {
+    ^bb0(%b0 : i32, %b1 : i32, %b2 : f32, %b3 : f32):
+      %5 = arith.cmpi eq, %b0, %b1 : i32
+      %6 = arith.select %5, %b2, %cst : f32
+      linalg.yield %6 : f32
+  } -> tensor<833x833xf32>
+  %7 = tensor.empty() : tensor<f32>
+  %8 = linalg.fill ins(%cst : f32) outs(%7 : tensor<f32>) -> tensor<f32>
+  %9 = linalg.generic {
+      indexing_maps = [#map2, #map3], iterator_types = ["reduction", "reduction"]}
+      ins(%4 : tensor<833x833xf32>) outs(%7 : tensor<f32>) {
+    ^bb0(%b0 : f32, %b1 : f32):
+      %10 = arith.addf %b1, %b0 : f32
+      linalg.yield %10 : f32
+  } -> tensor<f32>
+  util.return %9 : tensor<f32>
+}
+// Check that the linalg op with two reduction loops gets folded into a single
+// reduction, which then prevents the parallel ops from being folded into it.
+// See https://github.com/iree-org/iree/issues/13285
+// CHECK-LABEL: func public @main
+// CHECK-SAME: %[[ARG0:.+]]: tensor<833xi32>
+// CHECK-SAME: %[[ARG1:.+]]: tensor<833x833xf32>
+// CHECK-SAME: %[[ARG2:.+]]: tensor<f32>
+// CHECK: %[[DISPATCH:.+]] = flow.dispatch.workgroups(%[[ARG0]], %[[ARG1]], %[[ARG2]])
+// CHECK-NEXT: %[[ARG3:.+]]: !flow.dispatch.tensor<readonly:tensor<833xi32>>
+// CHECK-SAME: %[[ARG4:.+]]: !flow.dispatch.tensor<readonly:tensor<833x833xf32>>
+// CHECK-SAME: %[[ARG5:.+]]: !flow.dispatch.tensor<readonly:tensor<f32>>
+// CHECK-SAME: %[[ARG6:.+]]: !flow.dispatch.tensor<writeonly:tensor<f32>>
+// CHECK-DAG: %[[L0:.+]] = flow.dispatch.tensor.load %[[ARG3]]
+// CHECK-DAG: %[[L1:.+]] = flow.dispatch.tensor.load %[[ARG4]]
+// CHECK-DAG: %[[L2:.+]] = flow.dispatch.tensor.load %[[ARG5]]
+// CHECK: %[[GENERIC:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[L0]], %[[L0]], %[[L1]], %[[L2]] :
+// CHECK: flow.dispatch.tensor.store %[[GENERIC]], %[[ARG6]]
+// CHECK: return %[[DISPATCH]]
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d1, 0)>
+#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>
+#map3 = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d4)>
+#map4 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
+util.func public @grouped_quantized_matmul(%arg0: tensor<4096x32x128xi4>, %arg1: tensor<1x1x32x128xf32>, %arg2: tensor<4096x32x1xf32>, %arg3: tensor<4096x32x1xf32>) -> tensor<1x1x4096xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = tensor.empty() : tensor<1x1x4096xf32>
+  %1 = tensor.empty() : tensor<4096x32x128xf32>
+  %2 = linalg.fill ins(%cst : f32) outs(%0 : tensor<1x1x4096xf32>) -> tensor<1x1x4096xf32>
+  %3 = linalg.generic {indexing_maps = [#map, #map1, #map1, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0, %arg2, %arg3 : tensor<4096x32x128xi4>, tensor<4096x32x1xf32>, tensor<4096x32x1xf32>) outs(%1 : tensor<4096x32x128xf32>) {
+  ^bb0(%in: i4, %in_0: f32, %in_1: f32, %out: f32):
+    %5 = arith.extui %in : i4 to i32
+    %6 = arith.uitofp %5 : i32 to f32
+    %7 = arith.subf %6, %in_1 : f32
+    %8 = arith.mulf %7, %in_0 : f32
+    linalg.yield %8 : f32
+  } -> tensor<4096x32x128xf32>
+  %4 = linalg.generic {indexing_maps = [#map2, #map3, #map4], iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%arg1, %3 : tensor<1x1x32x128xf32>, tensor<4096x32x128xf32>) outs(%2 : tensor<1x1x4096xf32>) {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %5 = arith.mulf %in, %in_0 : f32
+    %6 = arith.addf %5, %out : f32
+    linalg.yield %6 : f32
+  } -> tensor<1x1x4096xf32>
+  util.return %4 : tensor<1x1x4096xf32>
+}
+// Check that the two linalg.generic ops are fused into the same dispatch.
+// CHECK-LABEL: func public @grouped_quantized_matmul
+// CHECK-SAME: %[[ARG0:.+]]: tensor<4096x32x128xi4>,
+// CHECK-SAME: %[[ARG1:.+]]: tensor<1x1x32x128xf32>,
+// CHECK-SAME: %[[ARG2:.+]]: tensor<4096x32x1xf32>,
+// CHECK-SAME: %[[ARG3:.+]]: tensor<4096x32x1xf32>)
+// CHECK-DAG: %[[RESHAPED_ARG2:.+]] = flow.tensor.reshape %[[ARG2]] : tensor<4096x32x1xf32> -> tensor<4096x32xf32>
+// CHECK-DAG: %[[RESHAPED_ARG3:.+]] = flow.tensor.reshape %[[ARG3]] : tensor<4096x32x1xf32> -> tensor<4096x32xf32>
+// CHECK-DAG: %[[RESHAPED_ARG1:.+]] = flow.tensor.reshape %[[ARG1]] : tensor<1x1x32x128xf32> -> tensor<32x128xf32>
+// CHECK: %[[DISPATCH:.+]] = flow.dispatch.workgroups(%[[ARG0]], %[[RESHAPED_ARG2]], %[[RESHAPED_ARG3]], %[[RESHAPED_ARG1]])
+// CHECK: %[[GENERIC1:.+]] = linalg.generic
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]
+// CHECK: %[[GENERIC2:.+]] = linalg.generic
+// CHECK-SAME: iterator_types = ["parallel", "reduction", "reduction"]
+// CHECK-SAME: ins(%{{.+}}, %[[GENERIC1]] :
+// CHECK: flow.dispatch.tensor.store %[[GENERIC2]]
+// CHECK: %[[RESHAPE:.+]] = flow.tensor.reshape %[[DISPATCH]]
+// CHECK: return %[[RESHAPE]]
+
+// -----
+
+util.func public @verify_operand_cse(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view, %arg2: !hal.fence, %arg3: !hal.fence) -> !hal.buffer_view {
+  %c12 = arith.constant 12 : index
+  %c0 = arith.constant 0 : index
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 3 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = hal.buffer_view.dim<%arg0 : !hal.buffer_view>[0] : index
+  %1 = hal.buffer_view.dim<%arg0 : !hal.buffer_view>[2] : index
+  %2 = hal.tensor.import wait(%arg2) => %arg0 : !hal.buffer_view -> tensor<?x12x?x64xf32>{%0, %1}
+  %3 = hal.buffer_view.dim<%arg1 : !hal.buffer_view>[0] : index
+  %4 = hal.buffer_view.dim<%arg1 : !hal.buffer_view>[3] : index
+  %5 = hal.tensor.import wait(%arg2) => %arg1 : !hal.buffer_view -> tensor<?x12x64x?xf32>{%3, %4}
+  %6 = arith.maxui %0, %3 : index
+  %collapsed = tensor.collapse_shape %2 [[0, 1], [2], [3]] : tensor<?x12x?x64xf32> into tensor<?x?x64xf32>
+  %collapsed_0 = tensor.collapse_shape %5 [[0, 1], [2], [3]] : tensor<?x12x64x?xf32> into tensor<?x64x?xf32>
+  %7 = arith.muli %6, %c12 : index
+  %8 = tensor.empty(%7, %1, %4) : tensor<?x?x?xf32>
+  %9 = linalg.fill ins(%cst : f32) outs(%8 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+  %10 = linalg.batch_matmul ins(%collapsed, %collapsed_0 : tensor<?x?x64xf32>, tensor<?x64x?xf32>) outs(%9 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+  %11 = arith.divui %7, %c12 : index
+  %expanded = tensor.expand_shape %10 [[0, 1], [2], [3]] output_shape [%11, 12, %1, %4] : tensor<?x?x?xf32> into tensor<?x12x?x?xf32>
+  %12 = hal.tensor.barrier join(%expanded : tensor<?x12x?x?xf32>) => %arg3 : !hal.fence
+  %dim = tensor.dim %12, %c0 : tensor<?x12x?x?xf32>
+  %dim_1 = tensor.dim %12, %c2 : tensor<?x12x?x?xf32>
+  %dim_2 = tensor.dim %12, %c3 : tensor<?x12x?x?xf32>
+  %13 = hal.tensor.export %12 : tensor<?x12x?x?xf32>{%dim, %dim_1, %dim_2} -> !hal.buffer_view
+  util.return %13 : !hal.buffer_view
+}
+// Check that after forming the dispatch.workgroups op, the size of each
+// `flow.dispatch.tensor.load` matches the corresponding dynamic dimension,
+// which allows checking that the slice is a full slice. Running CSE before
+// canonicalization makes this happen for this case.
+
+// CHECK-LABEL: func public @verify_operand_cse
+// CHECK: %[[DISPATCH:.+]] = flow.dispatch.workgroups
+// CHECK-DAG: %[[DIM1:.+]] = flow.dispatch.workload.ordinal %{{.+}}, 0
+// CHECK-DAG: %[[DIM2:.+]] = flow.dispatch.workload.ordinal %{{.+}}, 1
+// CHECK-DAG: %[[DIM3:.+]] = flow.dispatch.workload.ordinal %{{.+}}, 2
+// CHECK-DAG: %[[DIM4:.+]] = flow.dispatch.workload.ordinal %{{.+}}, 3
+// CHECK: flow.dispatch.tensor.load
+// CHECK-SAME: sizes = [%[[DIM1]], %[[DIM2]], 64]
+// CHECK-SAME: !flow.dispatch.tensor<readonly:tensor<?x?x64xf32>>{%[[DIM1]], %[[DIM2]]}
+// CHECK: flow.dispatch.tensor.load
+// CHECK-SAME: sizes = [%[[DIM3]], 64, %[[DIM4]]]
+// CHECK-SAME: !flow.dispatch.tensor<readonly:tensor<?x64x?xf32>>{%[[DIM3]], %[[DIM4]]}
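
Note: to illustrate the change described in the commit message, here is a
minimal hand-written sketch (hypothetical IR, not taken from this patch; %t,
%d0, and %d1 are illustrative names, and the dispatch signatures are elided).
Without CSE, two structurally identical dimension queries reach dispatch
formation as distinct SSA values, so the formed `flow.dispatch.workgroups` op
captures a redundant operand:

  %d0 = tensor.dim %t, %c0 : tensor<?x?xf32>
  %d1 = tensor.dim %t, %c0 : tensor<?x?xf32>  // same value as %d0
  %r = flow.dispatch.workgroups[%d0, %d1](%t, %d0, %d1) : ...

Running CSE before canonicalization replaces %d1 with %d0, so the dispatch
captures a single index operand; the `flow.dispatch.tensor.load` ops inside
the region are then tied to that one value and can be recognized as full
slices, as exercised by @verify_operand_cse above:

  %d0 = tensor.dim %t, %c0 : tensor<?x?xf32>
  %r = flow.dispatch.workgroups[%d0](%t, %d0) : ...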