Skip to content

Commit

Permalink
[mlir][vector] Add tests for populateSinkVectorOpsPatterns (2/N) (llv…
Browse files Browse the repository at this point in the history
…m#122338)

Adds tests for scalable vectors in:

  * "vector-sink.mlir".

This test file exercises patterns included in
`populateSinkVectorOpsPatterns`:

  * `ReorderElementwiseOpsOnBroadcast`,
  * `ReorderCastOpsOnBroadcast`,
  * `ReorderElementwiseOpsOnTranspose`.

This PR focuses on adding tests for the latter two patterns
(`ReorderCastOpsOnBroadcast` and `ReorderElementwiseOpsOnTranspose`).

Tests for `ReorderElementwiseOpsOnBroadcast` were added in llvm#102286. Please
note that in PR llvm#102856, I renamed:

  * `populateSinkVectorBroadcastPatterns`, to
  * `populateSinkVectorOpsPatterns`.
  • Loading branch information
banach-space authored Jan 15, 2025
1 parent c82a6a0 commit e9504c5
Showing 1 changed file with 103 additions and 0 deletions.
103 changes: 103 additions & 0 deletions mlir/test/Dialect/Vector/vector-sink.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,16 @@ func.func @broadcast_vector_extsi(%a : vector<4xi8>) -> vector<2x4xi32> {

// -----

func.func @broadcast_vector_extsi_scalable(%a : vector<[4]xi8>) -> vector<2x[4]xi32> {
// CHECK: %[[EXT:.+]] = arith.extsi %{{.+}} : vector<[4]xi8> to vector<[4]xi32>
// CHECK: vector.broadcast %[[EXT:.+]] : vector<[4]xi32> to vector<2x[4]xi32>
%b = vector.broadcast %a : vector<[4]xi8> to vector<2x[4]xi8>
%r = arith.extsi %b : vector<2x[4]xi8> to vector<2x[4]xi32>
return %r : vector<2x[4]xi32>
}

// -----

func.func @broadcast_scalar_extsi(%a : i8) -> vector<2x4xi32> {
// CHECK: %[[EXT:.+]] = arith.extsi %{{.+}} : i8 to i32
// CHECK: vector.broadcast %[[EXT]] : i32 to vector<2x4xi32>
Expand All @@ -236,6 +246,16 @@ func.func @broadcast_scalar_extsi(%a : i8) -> vector<2x4xi32> {
return %r : vector<2x4xi32>
}

// -----

func.func @broadcast_scalar_extsi_scalable(%a : i8) -> vector<2x[4]xi32> {
// CHECK: %[[EXT:.+]] = arith.extsi %{{.+}} : i8 to i32
// CHECK: vector.broadcast %[[EXT]] : i32 to vector<2x[4]xi32>
%b = vector.broadcast %a : i8 to vector<2x[4]xi8>
%r = arith.extsi %b : vector<2x[4]xi8> to vector<2x[4]xi32>
return %r : vector<2x[4]xi32>
}

//===----------------------------------------------------------------------===//
// [Pattern: ReorderElementwiseOpsOnTranspose]
//===----------------------------------------------------------------------===//
Expand All @@ -250,6 +270,16 @@ func.func @transpose_extsi(%a : vector<4x2xi8>) -> vector<2x4xi32> {

// -----

func.func @transpose_extsi_scalable(%a : vector<[4]x2xi8>) -> vector<2x[4]xi32> {
// CHECK: %[[EXT:.+]] = arith.extsi %{{.+}} : vector<[4]x2xi8> to vector<[4]x2xi32>
// CHECK: vector.transpose %[[EXT]], [1, 0] : vector<[4]x2xi32> to vector<2x[4]xi32>
%b = vector.transpose %a, [1, 0]: vector<[4]x2xi8> to vector<2x[4]xi8>
%r = arith.extsi %b : vector<2x[4]xi8> to vector<2x[4]xi32>
return %r : vector<2x[4]xi32>
}

// -----

// CHECK-LABEL: func @transpose_elementwise_same_type
// CHECK-SAME: (%[[A:.+]]: vector<4x2xf32>, %[[B:.+]]: vector<4x2xf32>)
// CHECK: %[[ADD:.+]] = arith.addf %[[A]], %[[B]] : vector<4x2xf32>
Expand All @@ -265,6 +295,21 @@ func.func @transpose_elementwise_same_type(%a : vector<4x2xf32>, %b : vector<4x2

// -----

// CHECK-LABEL: func @transpose_elementwise_same_type_scalable
// CHECK-SAME: (%[[A:.+]]: vector<[4]x2xf32>, %[[B:.+]]: vector<[4]x2xf32>)
// CHECK: %[[ADD:.+]] = arith.addf %[[A]], %[[B]] : vector<[4]x2xf32>
// CHECK: %[[T:.+]] = vector.transpose %[[ADD]], [1, 0]
// CHECK: return %[[T]]

func.func @transpose_elementwise_same_type_scalable(%a : vector<[4]x2xf32>, %b : vector<[4]x2xf32>) -> vector<2x[4]xf32> {
%at = vector.transpose %a, [1, 0]: vector<[4]x2xf32> to vector<2x[4]xf32>
%bt = vector.transpose %b, [1, 0]: vector<[4]x2xf32> to vector<2x[4]xf32>
%r = arith.addf %at, %bt : vector<2x[4]xf32>
return %r : vector<2x[4]xf32>
}

// -----

// CHECK-LABEL: func @transpose_elementwise_diff_operand_types
// CHECK-SAME: (%[[COND:.+]]: vector<4x2xi1>, %[[A:.+]]: vector<4x2xf32>, %[[B:.+]]: vector<4x2xf32>)
// CHECK: %[[S:.+]] = arith.select %[[COND]], %[[A]], %[[B]] : vector<4x2xi1>, vector<4x2xf32>
Expand All @@ -280,6 +325,21 @@ func.func @transpose_elementwise_diff_operand_types(%cond: vector<4x2xi1>, %a :

// -----

// CHECK-LABEL: func @transpose_elementwise_diff_operand_types_scalable
// CHECK-SAME: (%[[COND:.+]]: vector<[4]x2xi1>, %[[A:.+]]: vector<[4]x2xf32>, %[[B:.+]]: vector<[4]x2xf32>)
// CHECK: %[[S:.+]] = arith.select %[[COND]], %[[A]], %[[B]] : vector<[4]x2xi1>, vector<[4]x2xf32>
// CHECK: %[[T:.+]] = vector.transpose %[[S]], [1, 0] : vector<[4]x2xf32> to vector<2x[4]xf32>
// CHECK: return %[[T]]
func.func @transpose_elementwise_diff_operand_types_scalable(%cond: vector<[4]x2xi1>, %a : vector<[4]x2xf32>, %b : vector<[4]x2xf32>) -> vector<2x[4]xf32> {
%condt = vector.transpose %cond, [1, 0]: vector<[4]x2xi1> to vector<2x[4]xi1>
%at = vector.transpose %a, [1, 0]: vector<[4]x2xf32> to vector<2x[4]xf32>
%bt = vector.transpose %b, [1, 0]: vector<[4]x2xf32> to vector<2x[4]xf32>
%r = arith.select %condt, %at, %bt : vector<2x[4]xi1>, vector<2x[4]xf32>
return %r : vector<2x[4]xf32>
}

// -----

// CHECK-LABEL: func @transpose_elementwise_diff_operand_result_type
// CHECK-SAME: (%[[A:.+]]: vector<4x2xf32>, %[[B:.+]]: vector<4x2xf32>)
// CHECK: %[[CMP:.+]] = arith.cmpf olt, %[[A]], %[[B]] : vector<4x2xf32>
Expand All @@ -294,6 +354,20 @@ func.func @transpose_elementwise_diff_operand_result_type(%a : vector<4x2xf32>,

// -----

// CHECK-LABEL: func @transpose_elementwise_diff_operand_result_type_scalable
// CHECK-SAME: (%[[A:.+]]: vector<[4]x2xf32>, %[[B:.+]]: vector<[4]x2xf32>)
// CHECK: %[[CMP:.+]] = arith.cmpf olt, %[[A]], %[[B]] : vector<[4]x2xf32>
// CHECK: %[[T:.+]] = vector.transpose %[[CMP]], [1, 0] : vector<[4]x2xi1> to vector<2x[4]xi1>
// CHECK: return %[[T]]
func.func @transpose_elementwise_diff_operand_result_type_scalable(%a : vector<[4]x2xf32>, %b : vector<[4]x2xf32>) -> vector<2x[4]xi1> {
%at = vector.transpose %a, [1, 0]: vector<[4]x2xf32> to vector<2x[4]xf32>
%bt = vector.transpose %b, [1, 0]: vector<[4]x2xf32> to vector<2x[4]xf32>
%r = arith.cmpf olt, %at, %bt : vector<2x[4]xf32>
return %r : vector<2x[4]xi1>
}

// -----

// CHECK-LABEL: func @transpose_elementwise_splat_constant
// CHECK-SAME: (%[[A:.+]]: vector<4x6x3x2xf32>)
// CHECK: %[[B:.+]] = arith.constant dense<5.000000e+00> : vector<4x6x3x2xf32>
Expand All @@ -310,6 +384,22 @@ func.func @transpose_elementwise_splat_constant(%a : vector<4x6x3x2xf32>) -> vec

// -----

// CHECK-LABEL: func @transpose_elementwise_splat_constant_scalable
// CHECK-SAME: (%[[A:.+]]: vector<[4]x6x3x2xf32>)
// CHECK: %[[B:.+]] = arith.constant dense<5.000000e+00> : vector<[4]x6x3x2xf32>
// CHECK: %[[ADD:.+]] = arith.addf %[[A]], %[[B]] : vector<[4]x6x3x2xf32>
// CHECK: %[[T:.+]] = vector.transpose %[[ADD]], [1, 0, 3, 2] : vector<[4]x6x3x2xf32> to vector<6x[4]x2x3xf32>
// CHECK: return %[[T:.+]] : vector<6x[4]x2x3xf32>

func.func @transpose_elementwise_splat_constant_scalable(%a : vector<[4]x6x3x2xf32>) -> vector<6x[4]x2x3xf32> {
%b = arith.constant dense<5.0> : vector<6x[4]x2x3xf32>
%at = vector.transpose %a, [1, 0, 3, 2]: vector<[4]x6x3x2xf32> to vector<6x[4]x2x3xf32>
%r = arith.addf %at, %b : vector<6x[4]x2x3xf32>
return %r : vector<6x[4]x2x3xf32>
}

// -----

// CHECK-LABEL: func @transpose_elementwise_diff_map
// CHECK: vector.transpose
// CHECK: vector.transpose
Expand All @@ -320,3 +410,16 @@ func.func @transpose_elementwise_diff_map(%a : vector<4x6x3x2xf32>, %b: vector<6
%r = arith.addf %at, %bt : vector<6x4x2x3xf32>
return %r : vector<6x4x2x3xf32>
}

// -----

// CHECK-LABEL: func @transpose_elementwise_diff_map_scalable
// CHECK: vector.transpose
// CHECK: vector.transpose
// CHECK: arith.addf
func.func @transpose_elementwise_diff_map_scalable(%a : vector<[4]x6x3x2xf32>, %b: vector<6x2x[4]x3xf32>) -> vector<6x[4]x2x3xf32> {
%at = vector.transpose %a, [1, 0, 3, 2]: vector<[4]x6x3x2xf32> to vector<6x[4]x2x3xf32>
%bt = vector.transpose %b, [0, 2, 1, 3]: vector<6x2x[4]x3xf32> to vector<6x[4]x2x3xf32>
%r = arith.addf %at, %bt : vector<6x[4]x2x3xf32>
return %r : vector<6x[4]x2x3xf32>
}

0 comments on commit e9504c5

Please sign in to comment.