From b12a2d730e46f1de62b2b6e9e858c38e8f175885 Mon Sep 17 00:00:00 2001 From: hanhanW Date: Tue, 27 Feb 2024 23:07:12 +0000 Subject: [PATCH] [mlir][LinAlg] Vectorize reverse-like ops using vector.gather ops. The reverse op is treated as a VectorMemoryAccessKind::Contiguous load. It is contiguous slice, but we'll need to compute indices differently and apply a reverse at vector level. It takes non-trivial efforts for the approach. The revision flips the case to use vector.gather. Otherwise there are functionality issues. E.g., the below example loaded `2, 3, 4` (which is a bug), but what we want is `2, 1, 0`. Before vectorization: ```mlir func.func @vectorize_reverse_like_tensor_extract(%arg0: tensor<1x2x3xf32>, %arg1: tensor<1x1x3xf32>, %arg2: index) -> tensor<1x1x3xf32> { %c1 = arith.constant 1 : index %c0 = arith.constant 0 : index %c2 = arith.constant 2 : index %0 = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel", "parallel"]} outs(%arg1 : tensor<1x1x3xf32>) { ^bb0(%out: f32): %1 = linalg.index 1 : index %2 = linalg.index 0 : index %3 = affine.apply #map1(%1, %2, %arg2) %4 = linalg.index 2 : index %5 = arith.subi %c2, %4 : index %extracted = tensor.extract %arg0[%c0, %3, %5] : tensor<1x2x3xf32> linalg.yield %extracted : f32 } -> tensor<1x1x3xf32> return %0 : tensor<1x1x3xf32> } ``` Partial IR after vectorization: ``` %5 = vector.constant_mask [1, 1, 3] : vector<1x1x4xi1> %6 = vector.broadcast %arg0 : index to vector<1x1x4xindex> %7 = vector.shape_cast %6 : vector<1x1x4xindex> to vector<4xindex> %8 = vector.extractelement %7[%c0_i32 : i32] : vector<4xindex> %9 = vector.transfer_read %3[%c0, %8, %c2], %cst, %5 {in_bounds = [true, true, true]} : tensor<1x2x3xf32>, vector<1x1x4xf32> ``` --- .../Linalg/Transforms/Vectorization.cpp | 3 +- .../Linalg/vectorize-tensor-extract.mlir | 45 +++++++++++++++++++ 2 files changed, 46 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp index ac043e87223df..1e703dacfd0c7 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -891,8 +891,7 @@ static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val, // Conservatively reject Ops that could lead to indices with stride other // than 1. - if (!isa( - ancestor)) + if (!isa(ancestor)) return false; bool result = false; diff --git a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir index 96953c234a087..85e1c56dd45a0 100644 --- a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir +++ b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir @@ -550,3 +550,48 @@ module attributes {transform.with_named_sequence} { transform.yield } } + +// ----- + +#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +#map1 = affine_map<(d0, d1, d2) -> (d0 + d1 + d2)> +func.func @vectorize_reverse_like_tensor_extract(%arg0: tensor<1x2x3xf32>, %arg1: tensor<1x1x3xf32>, %arg2: index) -> tensor<1x1x3xf32> { + %c1 = arith.constant 1 : index + %c0 = arith.constant 0 : index + %c2 = arith.constant 2 : index + %0 = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel", "parallel"]} outs(%arg1 : tensor<1x1x3xf32>) { + ^bb0(%out: f32): + %1 = linalg.index 1 : index + %2 = linalg.index 0 : index + %3 = affine.apply #map1(%1, %2, %arg2) + %4 = linalg.index 2 : index + %5 = arith.subi %c2, %4 : index + %extracted = tensor.extract %arg0[%c0, %3, %5] : tensor<1x2x3xf32> + linalg.yield %extracted : f32 + } -> tensor<1x1x3xf32> + return %0 : tensor<1x1x3xf32> +} +// CHECK-LABEL: func.func @vectorize_reverse_like_tensor_extract +// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]*]] +// CHECK-SAME: %[[ARG1:[0-9a-zA-Z]*]] +// CHECK-SAME: %[[ARG2:[0-9a-zA-Z]*]] +// CHECK-DAG: %[[CST:.+]] = arith.constant dense<3> : vector<1x1x3xindex> +// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[MASK:.*]] = arith.constant dense : vector<1x1x3xi1> +// CHECK-DAG: %[[PASSTHRU:.*]] = arith.constant dense<0.000000e+00> : vector<1x1x3xf32> +// CHECK-DAG: %[[INIT_IDX:.+]] = arith.constant dense<[2, 1, 0]> : vector<3xindex> +// CHECK: %[[T0:.+]] = vector.broadcast %[[ARG2]] : index to vector<1x1x3xindex> +// CHECK: %[[T1:.+]] = arith.muli %[[T0]], %[[CST]] : vector<1x1x3xindex> +// CHECK: %[[T2:.+]] = vector.broadcast %[[INIT_IDX]] +// CHECK: %[[T3:.+]] = arith.addi %[[T2]], %[[T1]] +// CHECK: %[[GATHER:.*]] = vector.gather %[[ARG0]][%[[C0]], %[[C0]], %[[C0]]] [%[[T3]]], %[[MASK]], %[[PASSTHRU]] +// CHECK: vector.transfer_write %[[GATHER]] + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op + %2 = transform.structured.vectorize_children_and_apply_patterns %1 { vectorize_nd_extract } : (!transform.any_op) -> !transform.any_op + transform.yield + } +}