diff --git a/stablehlo/tests/transforms/stablehlo_aggressive_simplification.mlir b/stablehlo/tests/transforms/stablehlo_aggressive_simplification.mlir index 809c070012..73da68c87b 100644 --- a/stablehlo/tests/transforms/stablehlo_aggressive_simplification.mlir +++ b/stablehlo/tests/transforms/stablehlo_aggressive_simplification.mlir @@ -185,6 +185,86 @@ func.func @compare_unsigned_arg(%arg0: tensor) // ----- +///////// +// ComplexOp + +// CHECK-LABEL: @complex_collapse_simplify +func.func @complex_collapse_simplify(%arg0: tensor<4xcomplex>) -> tensor<4xcomplex> { + %0 = stablehlo.real %arg0 : (tensor<4xcomplex>) -> tensor<4xf32> + %1 = stablehlo.imag %arg0 : (tensor<4xcomplex>) -> tensor<4xf32> + %2 = stablehlo.complex %0, %1 : tensor<4xcomplex> + // CHECK: return %arg0 + return %2 : tensor<4xcomplex> +} + +// CHECK-LABEL: @complex_expand_simplify +func.func @complex_expand_simplify(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> (tensor<4xf32>, tensor<4xf32>) { + %0 = stablehlo.complex %arg0, %arg1 : tensor<4xcomplex> + %1 = stablehlo.real %0 : (tensor<4xcomplex>) -> tensor<4xf32> + %2 = stablehlo.imag %0 : (tensor<4xcomplex>) -> tensor<4xf32> + // CHECK: return %arg0, %arg1 + return %1, %2 : tensor<4xf32>, tensor<4xf32> +} + +// ----- + +//////// +// ConcatenateOp + +// CHECK-LABEL: concatenate_noop +func.func @concatenate_noop(%arg0: tensor<4xi32>) -> tensor<4xi32> { + // CHECK-SAME: [[ARG:%.+]]: tensor<4xi32> + %0 = "stablehlo.concatenate"(%arg0) <{ dimension = 0 : i64 }> : (tensor<4xi32>) -> tensor<4xi32> + + // CHECK: return [[ARG]] + func.return %0 : tensor<4xi32> +} + +// CHECK-LABEL: concatenate_with_empty +func.func @concatenate_with_empty(%arg0: tensor<4xi32>, %arg1: tensor<0xi32>) -> tensor<8xi32> { + // CHECK-SAME: [[ARG0:%.+]]: tensor<4xi32> + // CHECK-SAME: [[ARG1:%.+]]: tensor<0xi32> + // CHECK: stablehlo.concatenate [[ARG0]], [[ARG0]], dim = 0 + %0 = "stablehlo.concatenate"(%arg0, %arg0, %arg1) <{ dimension = 0 : i64 }> : (tensor<4xi32>, 
tensor<4xi32>, tensor<0xi32>) -> tensor<8xi32> + func.return %0 : tensor<8xi32> +} + + +// CHECK-LABEL: concatenate_empty_bool +func.func @concatenate_empty_bool(%arg0: tensor<0xi1>, %arg1: tensor<0xi1>) -> tensor<0xi1> { + // CHECK: stablehlo.constant dense<> + %0 = "stablehlo.concatenate"(%arg0, %arg1) <{ dimension = 0 : i64 }> : (tensor<0xi1>, tensor<0xi1>) -> tensor<0xi1> + func.return %0 : tensor<0xi1> +} + +// CHECK-LABEL: concatenate_forward +func.func @concatenate_forward(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<12xi32> { + // CHECK: [[CST:%.+]] = stablehlo.constant dense<[0, 1, 2, 3]> : tensor<4xi32> + // CHECK: stablehlo.concatenate %arg0, %arg1, [[CST]], dim = 0 : (tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) -> tensor<12xi32> + %0 = stablehlo.concatenate %arg0, %arg1, dim = 0 : (tensor<4xi32>, tensor<4xi32>) -> tensor<8xi32> + %c = stablehlo.constant dense<[0, 1, 2, 3]> : tensor<4xi32> + %1 = stablehlo.concatenate %0, %c, dim = 0 : (tensor<8xi32>, tensor<4xi32>) -> tensor<12xi32> + func.return %1 : tensor<12xi32> +} + +// CHECK-LABEL: concatenate_zero_extent +func.func @concatenate_zero_extent(%arg0: tensor<0xi32>, %arg1: tensor<0xi32>) -> tensor<0xi32> { + // CHECK: stablehlo.constant dense<> + %0 = "stablehlo.concatenate"(%arg0, %arg1) <{ dimension = 0 : i64 }> : (tensor<0xi32>, tensor<0xi32>) -> tensor<0xi32> + + func.return %0 : tensor<0xi32> +} + +// CHECK-LABEL: concatenate_empty_float +func.func @concatenate_empty_float(%arg0: tensor<0xf32>, %arg1: tensor<0xf32>) -> tensor<0xf32> { + // CHECK: stablehlo.constant dense<> + %0 = "stablehlo.concatenate"(%arg0, %arg1) <{ dimension = 0 : i64 }> : (tensor<0xf32>, tensor<0xf32>) -> tensor<0xf32> + + func.return %0 : tensor<0xf32> +} + +// ----- + ///////// // ConvertOp @@ -270,14 +350,119 @@ func.func @dynamic_broadcast_in_dim_op_almost_not_actually_dynamic(%arg0: tensor func.return %0 : tensor<5x4xf32> } +// CHECK-LABEL: func @dynamic_broadcast_in_dim_to_shape_of +func.func 
@dynamic_broadcast_in_dim_to_shape_of(%arg0: tensor) -> tensor { + // CHECK-SAME: %[[ARG:.*]]: tensor + %0 = shape.shape_of %arg0 : tensor -> tensor<1xindex> + %2 = "stablehlo.dynamic_broadcast_in_dim"(%arg0, %0) <{ broadcast_dimensions = array }> : (tensor, tensor<1xindex>) -> tensor + // CHECK: return %[[ARG]] : tensor + func.return %2 : tensor +} + +// CHECK-LABEL: @dynamic_broadcast_of_reshape +func.func @dynamic_broadcast_of_reshape(%arg: tensor, + %shape: tensor<2xindex>) -> tensor { + // CHECK: [[RESHAPE:%.*]] = stablehlo.dynamic_reshape + // CHECK: return [[RESHAPE]] + %0 = "stablehlo.dynamic_reshape"(%arg, %shape) : (tensor, tensor<2xindex>) -> tensor + %1 = "stablehlo.dynamic_broadcast_in_dim"(%0, %shape) { broadcast_dimensions = array } : (tensor, tensor<2xindex>) -> tensor + func.return %1 : tensor +} + +// CHECK-LABEL: @dynamic_broadcast_in_dim_of_reshape_permuted +func.func @dynamic_broadcast_in_dim_of_reshape_permuted(%arg: tensor, + %shape: tensor<2xindex>) -> tensor { + // CHECK: stablehlo.dynamic_reshape + // CHECK: stablehlo.dynamic_broadcast_in_dim + %0 = "stablehlo.dynamic_reshape"(%arg, %shape) : (tensor, tensor<2xindex>) -> tensor + %1 = "stablehlo.dynamic_broadcast_in_dim"(%0, %shape) { broadcast_dimensions = array } : (tensor, tensor<2xindex>) -> tensor + func.return %1 : tensor +} + +// ----- + +///////// +// DynamicGatherOp + +// CHECK-LABEL: @simplify_dynamic_gather_i64 +func.func @simplify_dynamic_gather_i64(%arg0: tensor<375682x256xf16>, %arg1: tensor<16x64xi64>) -> tensor<16x64x256xf16> { + %c = stablehlo.constant dense<[1, 256]> : tensor<2xi64> + %0 = "stablehlo.dynamic_gather"(%arg0, %arg1, %c) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false}> : (tensor<375682x256xf16>, tensor<16x64xi64>, tensor<2xi64>) -> tensor<16x64x256xf16> + // CHECK: %[[RET:.+]] = "stablehlo.gather"(%arg0, %arg1) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<375682x256xf16>, 
tensor<16x64xi64>) -> tensor<16x64x256xf16> + // CHECK: return %[[RET]] + return %0 : tensor<16x64x256xf16> +} + +// CHECK-LABEL: @simplify_dynamic_gather_i32 +func.func @simplify_dynamic_gather_i32(%arg0: tensor<375682x256xf16>, %arg1: tensor<16x64xi64>) -> tensor<16x64x256xf16> { + %c = stablehlo.constant dense<[1, 256]> : tensor<2xi32> + %0 = "stablehlo.dynamic_gather"(%arg0, %arg1, %c) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false}> : (tensor<375682x256xf16>, tensor<16x64xi64>, tensor<2xi32>) -> tensor<16x64x256xf16> + // CHECK: %[[RET:.+]] = "stablehlo.gather"(%arg0, %arg1) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<375682x256xf16>, tensor<16x64xi64>) -> tensor<16x64x256xf16> + // CHECK: return %[[RET]] + return %0 : tensor<16x64x256xf16> +} + +// ----- + +///////// +// DynamicIotaOp + +// CHECK-LABEL: @dynamic_iota_broadcast +func.func @dynamic_iota_broadcast(%arg0 : tensor<2xindex>) -> tensor<5x?xi32> { + // CHECK: [[IOTA:%.+]] = stablehlo.iota dim = 0 : tensor<5xi32> + // CHECK: [[BROADCAST:%.+]] = stablehlo.dynamic_broadcast_in_dim [[IOTA]], %arg0, dims = [0] : (tensor<5xi32>, tensor<2xindex>) -> tensor<5x?xi32> + %0 = "stablehlo.dynamic_iota"(%arg0) <{iota_dimension = 0 : i64}> : (tensor<2xindex>) -> tensor<5x?xi32> + + // CHECK: return [[BROADCAST]] + func.return %0 : tensor<5x?xi32> +} + +// CHECK-LABEL: @dynamic_iota_broadcast_second +func.func @dynamic_iota_broadcast_second(%arg0 : tensor<2xindex>) -> tensor<5x?xi32> { + // CHECK-NEXT: [[CAST1:%.+]] = arith.index_cast %arg0 : tensor<2xindex> to tensor<2xi64> + // CHECK-NEXT: [[SLICE:%.+]] = stablehlo.slice [[CAST1]] [1:2] : (tensor<2xi64>) -> tensor<1xi64> + // CHECK-NEXT: [[CAST2:%.+]] = arith.index_cast [[SLICE]] : tensor<1xi64> to tensor<1xindex> + // CHECK-NEXT: [[IOTA:%.+]] = stablehlo.dynamic_iota [[CAST2]], dim = 0 : (tensor<1xindex>) -> tensor + // CHECK-NEXT: [[BROADCAST:%.+]] = stablehlo.dynamic_broadcast_in_dim 
[[IOTA]], %arg0, dims = [1] : (tensor, tensor<2xindex>) -> tensor<5x?xi32> + %0 = "stablehlo.dynamic_iota"(%arg0) <{iota_dimension = 1 : i64}> : (tensor<2xindex>) -> tensor<5x?xi32> + + // CHECK: return [[BROADCAST]] + func.return %0 : tensor<5x?xi32> +} + +// CHECK-LABEL: @dynamic_iota_is_static +func.func @dynamic_iota_is_static(%arg0 : tensor<1xindex>) -> tensor<4xi32> { + // CHECK: [[RESULT:%.*]] = stablehlo.iota + // CHECK: return [[RESULT]] + %0 = "stablehlo.dynamic_iota"(%arg0) <{iota_dimension = 0 : i64}> : (tensor<1xindex>) -> tensor<4xi32> + func.return %0 : tensor<4xi32> +} + +// ----- + +///////// +// DynamicPadOp + +// CHECK-LABEL: func.func @dynamic_pad_to_pad +func.func @dynamic_pad_to_pad(%arg0: tensor<2x3xi32>, %arg1: tensor) -> tensor<5x9xi32> { + %low = stablehlo.constant dense<[0, 1]> : tensor<2xi32> + %high = stablehlo.constant dense<[2, 1]> : tensor<2xi32> + %interior = stablehlo.constant dense<[1, 2]> : tensor<2xi32> + %0 = stablehlo.dynamic_pad %arg0, %arg1, %low, %high, %interior + : (tensor<2x3xi32>, tensor, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<5x9xi32> + // CHECK: [[PAD:%.*]] = stablehlo.pad %arg0, %arg1, low = [0, 1], high = [2, 1], interior = [1, 2] : (tensor<2x3xi32>, tensor) -> tensor<5x9xi32> + // CHECK: return [[PAD]] + func.return %0 : tensor<5x9xi32> +} + // ----- ///////// // DynamicReshapeOp -// CHECK-LABEL: func.func @dynamic_reshape +// CHECK-LABEL: func.func @dynamic_reshape_is_static // CHECK-SAME: ([[ARG0:%.+]]: tensor<1xf32>, [[ARG1:%.+]]: tensor, [[ARG2:%.+]]: tensor<2xi32>) -func.func @dynamic_reshape(%arg0: tensor<1xf32>, %arg1: tensor, %arg2: tensor<2xi32>) +func.func @dynamic_reshape_is_static(%arg0: tensor<1xf32>, %arg1: tensor, %arg2: tensor<2xi32>) -> (tensor<1x1xf32>, tensor<2x1xf32>, tensor<1x2xi32>) { %c0 = stablehlo.constant dense<[2, 1]> : tensor<2xi32> @@ -292,6 +477,107 @@ func.func @dynamic_reshape(%arg0: tensor<1xf32>, %arg1: tensor, %arg2: return %0, %1, %2 : tensor<1x1xf32>, 
tensor<2x1xf32>, tensor<1x2xi32> } +// CHECK-LABEL: func @dynamic_reshape_shape_of +// CHECK-SAME: [[ARG0:%[a-zA-Z0-9]+]] +// CHECK-SAME: [[ARG1:%[a-zA-Z0-9]+]] +func.func @dynamic_reshape_shape_of(%arg0: tensor, %shape: tensor<2xindex>) -> tensor<2xindex> { + // CHECK: return [[ARG1]] + %0 = "stablehlo.dynamic_reshape"(%arg0, %shape) : (tensor, tensor<2xindex>) -> tensor + %1 = shape.shape_of %0 : tensor -> tensor<2xindex> + func.return %1 : tensor<2xindex> +} + +// CHECK-LABEL: func @dynamic_reshape_of_same_operand_result +// CHECK-SAME: [[ARG0:%[a-zA-Z0-9]+]] +// CHECK-SAME: [[ARG1:%[a-zA-Z0-9]+]] +func.func @dynamic_reshape_of_same_operand_result(%arg0: tensor, %arg1: tensor<1xindex>) -> tensor { + %0 = stablehlo.dynamic_reshape %arg0, %arg1 : (tensor, tensor<1xindex>) -> tensor + %1 = stablehlo.add %0, %0 : tensor + %2 = stablehlo.dynamic_reshape %1, %arg1 : (tensor, tensor<1xindex>) -> tensor + // CHECK: [[ADD:%.+]] = stablehlo.add + // CHECK: return [[ADD]] + return %2 : tensor +} + +// ----- + +///////// +// DynamicSliceOp + +// CHECK-LABEL: dynamic_slice_variable_start +func.func @dynamic_slice_variable_start(%arg0: tensor<3x4xi32>, %arg1: tensor, %arg2: tensor) -> tensor<1x4xi32> { + // CHECK: stablehlo.dynamic_slice + %0 = stablehlo.dynamic_slice %arg0, %arg1, %arg2, sizes = [1, 4] : (tensor<3x4xi32>, tensor, tensor) -> tensor<1x4xi32> + func.return %0 : tensor<1x4xi32> +} + +// CHECK-LABEL: dynamic_slice_constant_start +func.func @dynamic_slice_constant_start(%arg0: tensor<4xi32>) -> tensor<2xi32> { + // CHECK: stablehlo.slice %arg0 [1:3] : (tensor<4xi32>) -> tensor<2xi32> + %c = stablehlo.constant dense<1> : tensor + %0 = stablehlo.dynamic_slice %arg0, %c, sizes = [2] : (tensor<4xi32>, tensor) -> tensor<2xi32> + func.return %0 : tensor<2xi32> +} + +// CHECK-LABEL: dynamic_slice_constant_start_dynamic_shape +func.func @dynamic_slice_constant_start_dynamic_shape(%arg0: tensor, %arg1: tensor<2xi64>) -> tensor<1x4xi32> { + // CHECK: stablehlo.dynamic_slice 
+ // CHECK-NOT: stablehlo.slice + %c = stablehlo.constant dense<1> : tensor + %c_0 = stablehlo.constant dense<0> : tensor + %0 = stablehlo.dynamic_slice %arg0, %c, %c_0, sizes = [1, 4] : (tensor, tensor, tensor) -> tensor<1x4xi32> + func.return %0 : tensor<1x4xi32> +} + +// CHECK-LABEL: dynamic_slice_constant_start_upper_bound +func.func @dynamic_slice_constant_start_upper_bound(%arg0: tensor<8x4xi32>, %arg1: tensor<2xi64>) -> tensor<1x4xi32> { + // CHECK: stablehlo.slice %arg0 [7:8, 0:4] : (tensor<8x4xi32>) -> tensor<1x4xi32> + %c = stablehlo.constant dense<10> : tensor + %c_0 = stablehlo.constant dense<0> : tensor + %0 = stablehlo.dynamic_slice %arg0, %c, %c_0, sizes = [1, 4] : (tensor<8x4xi32>, tensor, tensor) -> tensor<1x4xi32> + func.return %0 : tensor<1x4xi32> +} + +// CHECK-LABEL: dynamic_slice_constant_start_lower_bound +func.func @dynamic_slice_constant_start_lower_bound(%arg0: tensor<8x4xi32>, %arg1: tensor<2xi64>) -> tensor<1x4xi32> { + // CHECK: stablehlo.slice %arg0 [0:1, 0:4] : (tensor<8x4xi32>) -> tensor<1x4xi32> + %c = stablehlo.constant dense<-1> : tensor + %c_0 = stablehlo.constant dense<0> : tensor + %0 = stablehlo.dynamic_slice %arg0, %c, %c_0, sizes = [1, 4] : (tensor<8x4xi32>, tensor, tensor) -> tensor<1x4xi32> + func.return %0 : tensor<1x4xi32> +} + +// ----- + +//////// +// DynamicUpdateSliceOp + +// CHECK-LABEL: dynamic_update_slice_noop +func.func @dynamic_update_slice_noop(%arg0: tensor<3x4xi64>, %arg1: tensor<3x0xi64>) -> tensor<3x4xi64> { + // CHECK: return %arg0 + %c = stablehlo.constant dense<0> : tensor + %0 = stablehlo.dynamic_update_slice %arg0, %arg1, %c, %c : (tensor<3x4xi64>, tensor<3x0xi64>, tensor, tensor) -> tensor<3x4xi64> + func.return %0 : tensor<3x4xi64> +} + +// CHECK-LABEL: dynamic_update_slice_noop_dynamic +func.func @dynamic_update_slice_noop_dynamic(%arg0: tensor, %arg1: tensor) -> tensor { + %c = stablehlo.constant dense<0> : tensor + %0 = stablehlo.dynamic_update_slice %arg0, %arg1, %c, %c : (tensor, tensor, tensor, 
tensor) -> tensor + func.return %0 : tensor + // CHECK: %[[CST:.*]] = stablehlo.constant dense<0> : tensor + // CHECK: %[[VAL:.*]] = stablehlo.dynamic_update_slice %arg0, %arg1, %[[CST]], %[[CST]] : (tensor, tensor, tensor, tensor) -> tensor + // CHECK: return %[[VAL]] : tensor +} + +// CHECK-LABEL: dynamic_update_slice_identity_update +func.func @dynamic_update_slice_identity_update(%arg0: tensor<3x4xi64>, %arg1: tensor<3x4xi64>) -> tensor<3x4xi64> { + // CHECK: return %arg1 + %c = stablehlo.constant dense<0> : tensor + %0 = stablehlo.dynamic_update_slice %arg0, %arg1, %c, %c : (tensor<3x4xi64>, tensor<3x4xi64>, tensor, tensor) -> tensor<3x4xi64> + func.return %0 : tensor<3x4xi64> +} + // ----- ///////// @@ -448,6 +734,54 @@ func.func @complex(%arg0: tensor<2xf32>, %arg1: tensor<2xf32>) -> (tensor<2xf32> return %r, %i : tensor<2xf32>, tensor<2xf32> } +///////// +// IotaOp + +// CHECK-LABEL: @iota_constant +func.func @iota_constant() -> tensor<1xi32> { + // CHECK: [[CONST:%.+]] = stablehlo.constant dense<0> : tensor<1xi32> + %0 = stablehlo.iota dim = 0 : tensor<1xi32> + + // CHECK: return [[CONST]] : tensor<1xi32> + func.return %0 : tensor<1xi32> +} + +// CHECK-LABEL: @iota_constant_multi +func.func @iota_constant_multi() -> tensor<1x4xi32> { + // CHECK: [[CONST:%.+]] = stablehlo.constant dense<0> : tensor<1x4xi32> + %0 = stablehlo.iota dim = 0 : tensor<1x4xi32> + + // CHECK: return [[CONST]] : tensor<1x4xi32> + func.return %0 : tensor<1x4xi32> +} + + +// CHECK-LABEL: @iota_not_lowered_to_constant +func.func @iota_not_lowered_to_constant() -> tensor<4xi32> { + // CHECK: [[RESULT:%.*]] = stablehlo.iota + // CHECK: return [[RESULT]] + %0 = stablehlo.iota dim = 0 : tensor<4xi32> + func.return %0 : tensor<4xi32> +} + +// CHECK-LABEL: @iota_broadcast_dim1 +func.func @iota_broadcast_dim1() -> tensor<5x4xi32> { + // CHECK: [[IOTA:%.+]] = stablehlo.iota dim = 0 : tensor<5xi32> + // CHECK: [[RESULT:%.+]] = stablehlo.broadcast_in_dim [[IOTA]], dims = [0] : (tensor<5xi32>) -> 
tensor<5x4xi32> + %0 = stablehlo.iota dim = 0 : tensor<5x4xi32> + + func.return %0 : tensor<5x4xi32> +} + +// CHECK-LABEL: @iota_broadcast_dim2 +func.func @iota_broadcast_dim2() -> tensor<5x4xi32> { + // CHECK: [[IOTA:%.+]] = stablehlo.iota dim = 0 : tensor<4xi32> + // CHECK: [[RESULT:%.+]] = stablehlo.broadcast_in_dim [[IOTA]], dims = [1] : (tensor<4xi32>) -> tensor<5x4xi32> + %0 = stablehlo.iota dim = 1 : tensor<5x4xi32> + + func.return %0 : tensor<5x4xi32> +} + // ----- ///////// @@ -542,6 +876,51 @@ func.func @or_one(%arg0: tensor<2xi1>) -> tensor<2xi1> { // ----- +//////// +// PadOp + +// CHECK-LABEL: @pad_zero_length +func.func @pad_zero_length(%arg0: tensor<5x0xf32>, %arg1: tensor) -> tensor<7x2xf32> { + %0 = stablehlo.pad %arg0, %arg1, low = [1, 1], high = [1, 1], interior = [0, 0] + : (tensor<5x0xf32>, tensor) -> tensor<7x2xf32> + // CHECK: %[[RES:.+]] = stablehlo.broadcast_in_dim %arg1, dims = [] : (tensor) -> tensor<7x2xf32> + // CHECK: return %[[RES]] + return %0 : tensor<7x2xf32> +} + +// ----- + +///////// +// RealDynamicSliceOp + +// CHECK-LABEL: @simplify_real_dynamic_slice_to_slice +func.func @simplify_real_dynamic_slice_to_slice(%arg0: tensor) -> tensor<1x4xf32> { + %0 = stablehlo.constant dense<[0, 0]> : tensor<2xi32> + %1 = stablehlo.constant dense<[1, 4]> : tensor<2xi32> + %2 = stablehlo.constant dense<[1, 1]> : tensor<2xi32> + %3 = stablehlo.real_dynamic_slice %arg0, %0, %1, %2 : (tensor, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<1x4xf32> + // CHECK: %[[RESULT:.*]] = stablehlo.slice %arg0 [0:1, 0:4] : (tensor) -> tensor<1x4xf32> + // CHECK: return %[[RESULT]] : tensor<1x4xf32> + return %3 : tensor<1x4xf32> +} + +// CHECK-LABEL: @simplify_real_dynamic_slice_to_dynamic_slice +func.func @simplify_real_dynamic_slice_to_dynamic_slice(%arg0: tensor, %arg1: tensor<2xi32>) -> tensor<1x4xf32> { + %0 = stablehlo.constant dense<[1, 4]> : tensor<2xi32> + %1 = stablehlo.add %arg1, %0 : tensor<2xi32> + %2 = stablehlo.constant dense<[1, 1]> : 
tensor<2xi32> + %3 = stablehlo.real_dynamic_slice %arg0, %arg1, %1, %2 : (tensor, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<1x4xf32> + return %3 : tensor<1x4xf32> + // CHECK: [[START_INDEX_0_1D:%.*]] = stablehlo.slice %arg1 [0:1] : (tensor<2xi32>) -> tensor<1xi32> + // CHECK-NEXT: [[START_INDEX_0_0D:%.*]] = stablehlo.reshape [[START_INDEX_0_1D]] : (tensor<1xi32>) -> tensor + // CHECK-NEXT: [[START_INDEX_1_1D:%.*]] = stablehlo.slice %arg1 [1:2] : (tensor<2xi32>) -> tensor<1xi32> + // CHECK-NEXT: [[START_INDEX_1_0D:%.*]] = stablehlo.reshape [[START_INDEX_1_1D]] : (tensor<1xi32>) -> tensor + // CHECK-NEXT: [[RESULT:%.*]] = stablehlo.dynamic_slice %arg0, [[START_INDEX_0_0D]], [[START_INDEX_1_0D]], sizes = [1, 4] : (tensor, tensor, tensor) -> tensor<1x4xf32> + // CHECK-NEXT: return [[RESULT]] : tensor<1x4xf32> +} + +// ----- + ///////// // ReduceOp @@ -1189,6 +1568,81 @@ func.func @select_into_minmax2(%arg0: tensor, %arg1: tensor, %arg2: te tensor, tensor, tensor, tensor } +// CHECK-LABEL: func @select_op_not_as_pred( +func.func @select_op_not_as_pred(%arg0: tensor<4xi1>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> tensor<4xf32> { + %0 = stablehlo.not %arg0 : tensor<4xi1> + %1 = stablehlo.select %0, %arg1, %arg2 : tensor<4xi1>, tensor<4xf32> + // CHECK-NOT: stablehlo.not + // CHECK: %[[R:.*]] = stablehlo.select %arg0, %arg2, %arg1 + // CHECK: return %[[R]] + return %1 : tensor<4xf32> +} + +// CHECK-LABEL: func @select_op_broadcasted_not_as_pred( +func.func @select_op_broadcasted_not_as_pred(%arg0: tensor<1xi1>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> tensor<4xf32> { + %0 = stablehlo.not %arg0 : tensor<1xi1> + %1 = stablehlo.broadcast_in_dim %0, dims = [0] : (tensor<1xi1>) -> tensor<4xi1> + %2 = stablehlo.select %1, %arg1, %arg2 : tensor<4xi1>, tensor<4xf32> + + // CHECK-NOT: stablehlo.not + // CHECK: %[[B:.*]] = stablehlo.broadcast_in_dim %arg0, dims = [0] : (tensor<1xi1>) -> tensor<4xi1> + // CHECK: %[[R:.*]] = stablehlo.select %[[B]], %arg2, 
%arg1 + // CHECK: return %[[R]] + return %2 : tensor<4xf32> +} + +// ----- + +///////// +// SliceOp + +// CHECK-LABEL: slice_of_concat +// CHECK-SAME: [[ARG0:%.+]]: tensor<2x5xf32>, [[ARG1:%.+]]: tensor<1x5xf32> +func.func @slice_of_concat(%arg0: tensor<2x5xf32>, %arg1: tensor<1x5xf32>) -> tensor<1x5xf32> { + %0 = stablehlo.concatenate %arg0, %arg1, dim = 0 : (tensor<2x5xf32>, tensor<1x5xf32>) -> tensor<3x5xf32> + // CHECK-NOT: stablehlo.concatenate + // CHECK: stablehlo.slice [[ARG0]] + %1 = stablehlo.slice %0 [1:2, 0:5] : (tensor<3x5xf32>) -> tensor<1x5xf32> + return %1 : tensor<1x5xf32> +} + +// CHECK-LABEL: slice_2D_noop +// CHECK-SAME: [[ARG:%.+]]: tensor<2x2xi64> +func.func @slice_2D_noop(%arg0: tensor<2x2xi64>) -> tensor<2x2xi64> { + %0 = stablehlo.slice %arg0 [0:2, 0:2] : (tensor<2x2xi64>) -> tensor<2x2xi64> + + // CHECK-NEXT: return [[ARG]] + func.return %0 : tensor<2x2xi64> +} + +// ----- + +///////// +// SortOp + +// CHECK-LABEL: @sort_op_second_arg_unused +// CHECK-SAME: [[ARG0:%.+]]: tensor<3xi32>, [[ARG1:%.+]]: tensor<3xi32> +func.func @sort_op_second_arg_unused(%arg0: tensor<3xi32>, %arg1: tensor<3xi32>) -> tensor<3xi32> { + // CHECK: "stablehlo.sort"([[ARG0]]) + %0:2 = "stablehlo.sort"(%arg0, %arg1) <{dimension = 0 : i64, is_stable = false}> ({ + ^bb0(%arg2: tensor, %arg3: tensor, %arg4: tensor, %arg5: tensor): + %1 = stablehlo.compare GT, %arg2, %arg3 : (tensor, tensor) -> tensor + stablehlo.return %1 : tensor + }) : (tensor<3xi32>, tensor<3xi32>) -> (tensor<3xi32>, tensor<3xi32>) + return %0#0 : tensor<3xi32> +} + +// CHECK-LABEL: @sort_op_set_default_dimension +func.func @sort_op_set_default_dimension(%arg0: tensor<3x5xi32>) -> tensor<3x5xi32> { + // CHECK: stablehlo.sort{{.*}}dimension = 1 : i64 + %0 = "stablehlo.sort"(%arg0) <{dimension = -1 : i64, is_stable = false}> ({ + ^bb0(%arg1: tensor, %arg2: tensor): + %1 = stablehlo.compare GT, %arg1, %arg2 : (tensor, tensor) -> tensor + stablehlo.return %1 : tensor + }) : (tensor<3x5xi32>) -> 
tensor<3x5xi32> + return %0 : tensor<3x5xi32> +} + // ----- ///////// @@ -1223,6 +1677,66 @@ func.func @transpose_is_not_reshape(%arg0: tensor<1x4x5x2xf32>) -> tensor<2x4x1x // ----- +//////// +// TupleOp + + +// CHECK-LABEL: unpack_repack_same_tuple +// CHECK-SAME: ([[ARG0:%.*]]: tuple, !stablehlo.token, tensor>) +func.func @unpack_repack_same_tuple(%arg0: tuple, !stablehlo.token, tensor>) -> tuple, !stablehlo.token, tensor> { + %0 = stablehlo.get_tuple_element %arg0[0] : (tuple, !stablehlo.token, tensor>) -> tensor + %1 = stablehlo.get_tuple_element %arg0[1] : (tuple, !stablehlo.token, tensor>) -> !stablehlo.token + %2 = stablehlo.get_tuple_element %arg0[2] : (tuple, !stablehlo.token, tensor>) -> tensor + %3 = stablehlo.tuple %0, %1, %2 : tuple, !stablehlo.token, tensor> + // CHECK: return [[ARG0]] + return %3 : tuple, !stablehlo.token, tensor> +} + +// CHECK-LABEL: unpack_repack_same_tuple_single_element +// CHECK-SAME: ([[ARG0:%.*]]: tuple>) +func.func @unpack_repack_same_tuple_single_element(%arg0: tuple>) -> tuple> { + %0 = stablehlo.get_tuple_element %arg0[0] : (tuple>) -> tensor + %1 = stablehlo.tuple %0 : tuple> + // CHECK: return [[ARG0]] + return %1 : tuple> +} + +// ----- + +//////// +// WhileOp DCE + +// CHECK-LABEL: while_op_with_outfeed_no_dce +func.func @while_op_with_outfeed_no_dce(%arg0: tensor) -> tensor { + // CHECK: stablehlo.while + %0 = stablehlo.while(%iterArg = %arg0) : tensor + cond { + %1 = stablehlo.compare LT, %iterArg, %iterArg : (tensor, tensor) -> tensor + stablehlo.return %1 : tensor + } do { + %1 = stablehlo.create_token : !stablehlo.token + %2 = "stablehlo.outfeed"(%iterArg, %1) <{outfeed_config = ""}> : (tensor, !stablehlo.token) -> !stablehlo.token + stablehlo.return %iterArg : tensor + } + return %arg0 : tensor +} + +// CHECK-LABEL: while_op_dce_no_side_effect +func.func @while_op_dce_no_side_effect(%arg0: tensor) -> tensor { + // CHECK-NOT: stablehlo.while + %0 = stablehlo.while(%iterArg = %arg0) : tensor + cond { + %1 = 
stablehlo.compare LT, %iterArg, %iterArg : (tensor, tensor) -> tensor + stablehlo.return %1 : tensor + } do { + %1 = stablehlo.create_token : !stablehlo.token + stablehlo.return %iterArg : tensor + } + return %arg0 : tensor +} + +// ----- + ///////// // Generic Zero Extent Ops @@ -1264,7 +1778,7 @@ func.func @xor_cst_on_rhs(%arg0: tensor<2xi1>) -> tensor<2xi1> { // CHECK-LABEL: func.func @add_zero_ext func.func @add_zero_ext(%arg0 : tensor<5x0xi32>, %arg1 : tensor<5x0xi32>) -> tensor<5x0xi32> { - // CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<5x0xi32> + // CHECK: %[[EMPTY:.+]] = stablehlo.constant dense<> // CHECK: return %[[EMPTY]] %0 = stablehlo.add %arg0, %arg1 : tensor<5x0xi32> func.return %0 : tensor<5x0xi32> @@ -1275,7 +1789,7 @@ func.func @add_zero_ext(%arg0 : tensor<5x0xi32>, %arg1 : tensor<5x0xi32>) -> ten // CHECK-LABEL: func.func @add_zero_ext_dynamic func.func @add_zero_ext_dynamic(%arg0 : tensor, %arg1 : tensor) -> tensor { %0 = stablehlo.add %arg0, %arg1 : tensor - // CHECK-NOT: tensor.empty() + // CHECK-NOT: stablehlo.constant dense<> func.return %0 : tensor } @@ -1297,16 +1811,27 @@ func.func @scatter_zero_ext(%arg0 : tensor, %arg1 : tensor<1x0xi32>, %arg2 indices_are_sorted = true, unique_indices = true } : (tensor, tensor<1x0xi32>, tensor<1xf32>) -> tensor - // CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<1x0xi32> - // CHECK: %[[SCATTER:.+]] = "stablehlo.scatter"(%arg0, %0, %arg2) + // CHECK: %[[EMPTY:.+]] = stablehlo.constant dense<> : tensor<1x0xi32> + // CHECK: %[[SCATTER:.+]] = "stablehlo.scatter"(%arg0, %[[EMPTY]], %arg2) // CHECK: return %[[SCATTER]] func.return %0 : tensor } // ----- +// CHECK-LABEL: slice_zero_extent +func.func @slice_zero_extent(%arg0: tensor<1x5xf32>) -> tensor<0x5xf32> { + %0 = stablehlo.slice %arg0 [1:1, 0:5] : (tensor<1x5xf32>) -> tensor<0x5xf32> + // CHECK-NOT: stablehlo.slice + // CHECK: [[CST:%.+]] = stablehlo.constant dense<> : tensor<0x5xf32> + // CHECK: return [[CST]] + return %0 : tensor<0x5xf32> +} + +// 
----- + // CHECK-LABEL: @sort_zero_extent -func.func public @sort_zero_extent(%arg0: tensor<0xi16> {jax.arg_info = "a", mhlo.sharding = "{replicated}"}) -> (tensor<0xi32> {jax.result_info = ""}) { +func.func public @sort_zero_extent(%arg0: tensor<0xi16>) -> (tensor<0xi32> {jax.result_info = ""}) { %0 = stablehlo.iota dim = 0 : tensor<0xi32> %1:2 = "stablehlo.sort"(%arg0, %0) ({ ^bb0(%arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor): @@ -1314,7 +1839,7 @@ func.func public @sort_zero_extent(%arg0: tensor<0xi16> {jax.arg_info = "a", mhl stablehlo.return %2 : tensor }) {dimension = 0 : i64, is_stable = true} : (tensor<0xi16>, tensor<0xi32>) -> (tensor<0xi16>, tensor<0xi32>) - // CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<0xi32> + // CHECK: %[[EMPTY:.+]] = stablehlo.constant dense<> : tensor<0xi32> // CHECK: return %[[EMPTY]] return %1#1 : tensor<0xi32> } @@ -1323,10 +1848,9 @@ func.func public @sort_zero_extent(%arg0: tensor<0xi16> {jax.arg_info = "a", mhl // ----- // CHECK-LABEL: @while_zero_extent -// CHECK: %[[R0:.+]] = tensor.empty() : tensor<75x0xf32> -// CHECK: %[[R1:.+]] = tensor.empty() : tensor<75x0xf32> -// CHECK: %[[R2:.+]]:2 = stablehlo.while -// CHECK: return %[[R2]]#0, %[[R0]] +// CHECK: %[[R0:.+]] = stablehlo.constant dense<> : tensor<75x0xf32> +// CHECK: %[[R2:.+]] = stablehlo.while +// CHECK: return %[[R2]], %[[R0]] func.func public @while_zero_extent(%arg0: tensor, %arg1: tensor<3xf32>, %arg2: tensor<75x0xf32>) -> (tensor, tensor<75x0xf32>) { diff --git a/stablehlo/transforms/Passes.td b/stablehlo/transforms/Passes.td index 3044dd0d45..6f5e5b2c8d 100644 --- a/stablehlo/transforms/Passes.td +++ b/stablehlo/transforms/Passes.td @@ -54,7 +54,6 @@ def StablehloAggressiveSimplificationPass let summary = "Canonicalizes StableHLO operations"; let dependentDialects = [ "mlir::stablehlo::StablehloDialect", - "mlir::tensor::TensorDialect", ]; } diff --git a/stablehlo/transforms/StablehloAggressiveSimplification.cpp 
b/stablehlo/transforms/StablehloAggressiveSimplification.cpp index 4ccdd7ded2..0c9849dbaa 100644 --- a/stablehlo/transforms/StablehloAggressiveSimplification.cpp +++ b/stablehlo/transforms/StablehloAggressiveSimplification.cpp @@ -10,9 +10,7 @@ #include #include #include -#include #include -#include #include #include @@ -23,10 +21,9 @@ #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVectorExtras.h" #include "llvm/Support/ErrorHandling.h" -#include "mlir/Dialect/CommonFolders.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Block.h" +#include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributeInterfaces.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypeInterfaces.h" @@ -104,10 +101,7 @@ m_AnyAttrOf(MatcherA, MatcherB) -> m_AnyAttrOf; // CompareOp ///////////////////////////////// -static mlir::stablehlo::ComparisonDirection invertDirection( - mlir::stablehlo::ComparisonDirection direction) { - using mlir::stablehlo::ComparisonDirection; - +static ComparisonDirection invertDirection(ComparisonDirection direction) { switch (direction) { case ComparisonDirection::EQ: case ComparisonDirection::NE: @@ -125,16 +119,15 @@ static mlir::stablehlo::ComparisonDirection invertDirection( llvm::report_fatal_error("Unhandled case"); } -struct CompareOpCanon final : OpRewritePattern { +struct CompareOpCanon final : OpRewritePattern { using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(mlir::stablehlo::CompareOp op, + LogicalResult matchAndRewrite(CompareOp op, PatternRewriter &rewriter) const override { RankedTensorType type = op.getType(); // Bail out on non-integer comparison. // TODO: Support more comparison types. 
- using mlir::stablehlo::ComparisonType; std::optional compType = op.getCompareType(); if (!compType || !llvm::is_contained({ComparisonType::SIGNED, ComparisonType::UNSIGNED}, @@ -142,7 +135,6 @@ struct CompareOpCanon final : OpRewritePattern { return failure(); } - using mlir::stablehlo::ComparisonDirection; ComparisonDirection direction = op.getComparisonDirection(); Value lhs = op.getLhs(); Value rhs = op.getRhs(); @@ -154,23 +146,22 @@ struct CompareOpCanon final : OpRewritePattern { case ComparisonDirection::EQ: case ComparisonDirection::GE: case ComparisonDirection::LE: { - rewriter.replaceOpWithNewOp( + rewriter.replaceOpWithNewOp( op, SplatElementsAttr::get(type, rewriter.getBoolAttr(true))); return success(); } case ComparisonDirection::GT: case ComparisonDirection::LT: case ComparisonDirection::NE: { - rewriter.replaceOpWithNewOp( - op, rewriter.getZeroAttr(type)); + rewriter.replaceOpWithNewOp(op, + rewriter.getZeroAttr(type)); return success(); } } llvm_unreachable("Unhandled case"); } - // Pattern: compare(cst, X, comparator) -> compare(X, cst, - // inverse(comparator)) + // Pattern: compare(cst, X, comparator) -> compare(X, cst, inv(comparator)) TypedAttr lhsAttr, rhsAttr; matchPattern(lhs, m_Constant(&lhsAttr)); matchPattern(rhs, m_Constant(&rhsAttr)); @@ -189,106 +180,83 @@ struct CompareOpCanon final : OpRewritePattern { }; ////////////////////////////////// -// SelectOp +// ConcatenateOp ///////////////////////////////// -struct SelectOpCanon final : OpRewritePattern { +// Pattern: concatenate(X) -> X +class ConcatenateOpNoop : public OpRewritePattern { + public: using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(mlir::stablehlo::SelectOp op, + LogicalResult matchAndRewrite(ConcatenateOp op, PatternRewriter &rewriter) const override { - RankedTensorType type = op.getType(); + if (op.getInputs().size() != 1 || + op.getInputs().front().getType() != op.getType()) + return rewriter.notifyMatchFailure(op, "not single operand 
noop-concat"); - Value trueVal = op.getOnTrue(); - Value falseVal = op.getOnFalse(); - - // Eliminate select with two identical outcomes. - if (trueVal == falseVal) { - rewriter.replaceOp(op, trueVal); - return success(); - } - - // Simplify when the condition is a constant. - Value pred = op.getPred(); - ElementsAttr cond; - if (!matchPattern(pred, m_Constant(&cond))) return failure(); + rewriter.replaceOp(op, op.getInputs().front()); + return success(); + } +}; - // Handle splat predicate and select either `trueVal` or `falseVal`. - if (cond.isSplat()) { - rewriter.replaceOp(op, cond.getSplatValue() ? trueVal : falseVal); +// Pattern: concatenate(X, Y, []) -> concatenate(X, Y) +class ConcatenateOpRemoveEmpty : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(ConcatenateOp op, + PatternRewriter &rewriter) const override { + auto axis = op.getDimension(); + llvm::SmallVector newOperands = llvm::to_vector( + llvm::make_filter_range(op.getOperands(), [&](Value operand) { + return cast(operand.getType()).getDimSize(axis) != 0; + })); + + // Only handle nonempty new operands, empty handled by + // ZeroExtentToEmptyConstant pattern. + if (!newOperands.empty() && newOperands.size() < op.getNumOperands()) { + rewriter.modifyOpInPlace(op, [&] { op->setOperands(newOperands); }); return success(); } - // Handle elementwise selection when both outcomes are also constants. This - // will create a new, likely non-splat constant. 
- if (cond.getNumElements() > kFoldOpEltLimit) return failure(); - - ElementsAttr trueAttr; - if (!matchPattern(trueVal, m_Constant(&trueAttr))) return failure(); - - ElementsAttr falseAttr; - if (!matchPattern(falseVal, m_Constant(&falseAttr))) return failure(); - - SmallVector newValues; - newValues.reserve(cond.getNumElements()); - for (auto [condElem, trueElem, falseElem] : llvm::zip_equal( - cond.getValues(), trueAttr.getValues(), - falseAttr.getValues())) { - newValues.push_back(condElem ? trueElem : falseElem); - } - - rewriter.replaceOpWithNewOp( - op, DenseElementsAttr::get(type, newValues)); - return success(); + return failure(); } }; -struct CompareSelectIntoMinMax final - : OpRewritePattern { +// Pattern: concatenate(concatenate(X, Y), Z) -> concatenate(X, Y, Z) +class ConcatenateOpFlatten : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(mlir::stablehlo::SelectOp op, + LogicalResult matchAndRewrite(ConcatenateOp op, PatternRewriter &rewriter) const override { - Value pred = op.getPred(); - Value trueVal = op.getOnTrue(); - Value falseVal = op.getOnFalse(); + auto getFlattenedOperands = [&](const Value &val) -> ValueRange { + auto definingOp = dyn_cast_or_null(val.getDefiningOp()); + // To avoid inflate the memory footprint, only flatten the + // ConcatenateOp when it has only one use. 
+ if (definingOp && definingOp->hasOneUse() && + definingOp.getDimension() == op.getDimension()) + return definingOp.getInputs(); + return val; + }; - auto cmpOp = pred.getDefiningOp(); - if (!cmpOp) return failure(); + bool needToFlatten = false; + int operandCount = 0; + llvm::for_each(op.getInputs(), [&](Value val) { + auto result = getFlattenedOperands(val); + if (result.size() != 1 || result[0] != val) needToFlatten = true; + operandCount += result.size(); + }); - using mlir::stablehlo::ComparisonDirection; - ComparisonDirection direction = cmpOp.getComparisonDirection(); - Value cmpLhs = cmpOp.getLhs(); - Value cmpRhs = cmpOp.getRhs(); + if (!needToFlatten) + return rewriter.notifyMatchFailure(op, "no need to flatten"); - // Turn into canonical form: - // b <= a ? a : b ---> a >= b ? a : b - // b < a ? a : b ---> a > b ? a : b - // b >= a ? a : b ---> a <= b ? a : b - // b > a ? a : b ---> a < b ? a : b - if (cmpLhs == falseVal && cmpRhs == trueVal) { - direction = invertDirection(direction); - } else if (!(cmpLhs == trueVal && cmpRhs == falseVal)) { - return failure(); - } + llvm::SmallVector newOperands; + newOperands.reserve(operandCount); - switch (direction) { - case ComparisonDirection::GE: - case ComparisonDirection::GT: { - rewriter.replaceOpWithNewOp(op, trueVal, - falseVal); - return success(); - } - case ComparisonDirection::LE: - case ComparisonDirection::LT: { - rewriter.replaceOpWithNewOp(op, trueVal, - falseVal); - return success(); - } - default: { - return failure(); - } + for (auto operand : op.getInputs()) { + auto flattenedOperands = getFlattenedOperands(operand); + newOperands.append(flattenedOperands.begin(), flattenedOperands.end()); } + + rewriter.modifyOpInPlace(op, [&] { op->setOperands(newOperands); }); + return success(); } }; @@ -333,7 +301,7 @@ static OpTy refineOpWithNewOp(PatternRewriter &rewriter, Operation *op, if (llvm::any_of(opResult.getUsers(), [&](Operation *user) { return user->getDialect() != op->getDialect(); })) - 
replacementResult = rewriter.create( + replacementResult = rewriter.create( op->getLoc(), opResult.getType(), newOpResult); replacementResults.push_back(replacementResult); } @@ -345,10 +313,10 @@ static OpTy refineOpWithNewOp(PatternRewriter &rewriter, Operation *op, /// If a DynamicBroadCastInDimOp is not actually dynamic, use an ordinary /// BroadcastInDimOp. struct DynamicBroadcastInDimOpNotActuallyDynamic final - : OpRewritePattern { + : OpRewritePattern { using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(mlir::stablehlo::DynamicBroadcastInDimOp op, + LogicalResult matchAndRewrite(DynamicBroadcastInDimOp op, PatternRewriter &rewriter) const override { RankedTensorType operandType = op.getOperand().getType(); if (!operandType.hasStaticShape()) @@ -357,7 +325,7 @@ struct DynamicBroadcastInDimOpNotActuallyDynamic final RankedTensorType type = op.getType(); // output has static shape, replace with broadcast_in_dim if (type.hasStaticShape()) { - rewriter.replaceOpWithNewOp( + rewriter.replaceOpWithNewOp( op, type, op.getOperand(), op.getBroadcastDimensionsAttr()); return success(); } @@ -366,7 +334,7 @@ struct DynamicBroadcastInDimOpNotActuallyDynamic final // then replace with broadcast_in_dim if (llvm::SmallVector shape; succeeded(hlo::matchInts(op.getOutputDimensions(), shape))) { - refineOpWithNewOp( + refineOpWithNewOp( rewriter, op, RankedTensorType::get(shape, type.getElementType()), op.getOperand(), op.getBroadcastDimensionsAttr()); return success(); @@ -376,22 +344,258 @@ struct DynamicBroadcastInDimOpNotActuallyDynamic final } }; +////////////////////////////////// +// DynamicGatherOp +///////////////////////////////// + +DenseI64ArrayAttr convertToI64Array(OpBuilder &b, Attribute attr) { + auto denseAttr = cast(attr); + SmallVector result; + result.reserve(denseAttr.getNumElements()); + for (auto elem : denseAttr.getValues()) + result.push_back(elem.getSExtValue()); + return b.getDenseI64ArrayAttr(result); +} + 
+////////////////////////////////// +// DynamicIotaOp +///////////////////////////////// + +struct DynamicIotaIsStatic : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(DynamicIotaOp iota, + PatternRewriter &rewriter) const override { + // Result type has static shape, replace with iota. + auto resultTy = cast(iota.getType()); + if (!resultTy.hasStaticShape()) + return rewriter.notifyMatchFailure(iota, "requires output static shape"); + rewriter.replaceOpWithNewOp(iota, resultTy, + iota.getIotaDimension()); + return success(); + } +}; + +// Dynamic Iota operations across multiple dimensions can be reduced to an iota +// and a ranked broadcast. +// Pattern: dynamic_iota(shape, dim) -> +// dynamic_broadcast_in_dim(dynamic_iota(slice(shape), dim), shape) +struct DynamicIotaOpToBroadcast : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(DynamicIotaOp iota, + PatternRewriter &rewriter) const override { + auto resultTy = cast(iota.getType()); + if (resultTy.getRank() < 2) + return rewriter.notifyMatchFailure(iota, "requires rank >= 2"); + + auto iotaDimension = static_cast(iota.getIotaDimension()); + + // Handle case where iota dimension is index, need to convert to/from i64 + // to interop with slice. These canonicalize away if input is i64. 
+ auto convertedShape = rewriter.create( + iota.getLoc(), + RankedTensorType::get( + cast(iota.getOutputShape().getType()).getShape(), + rewriter.getI64Type()), + iota.getOutputShape()); + + auto slicedShape = rewriter.create( + iota.getLoc(), convertedShape, + rewriter.getDenseI64ArrayAttr(iotaDimension), + rewriter.getDenseI64ArrayAttr(iotaDimension + 1), + rewriter.getDenseI64ArrayAttr(1)); + + auto convertedSlicedShape = rewriter.create( + iota.getLoc(), + RankedTensorType::get( + {1}, + cast(iota.getOutputShape().getType()).getElementType()), + slicedShape); + + auto iotaType = RankedTensorType::get({resultTy.getDimSize(iotaDimension)}, + resultTy.getElementType()); + + auto newIota = rewriter.create( + iota.getLoc(), iotaType, convertedSlicedShape, + rewriter.getI64IntegerAttr(0)); + + rewriter.replaceOpWithNewOp( + iota, resultTy, newIota, iota.getOutputShape(), + rewriter.getDenseI64ArrayAttr(iotaDimension)); + return success(); + } +}; + ////////////////////////////////// // DynamicReshapeOp ///////////////////////////////// -struct DynamicReshapeOpCanon final - : OpRewritePattern { +struct DynamicReshapeOpIsStatic final : OpRewritePattern { using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(mlir::stablehlo::DynamicReshapeOp op, + LogicalResult matchAndRewrite(DynamicReshapeOp op, PatternRewriter &rewriter) const override { // This is a noop when the output type is already a static shape. 
RankedTensorType type = op.getType(); - if (!type.hasStaticShape()) return failure(); + if (!type.hasStaticShape()) + return rewriter.notifyMatchFailure(op, "dynamic reshape not static"); + + rewriter.replaceOpWithNewOp(op, type, op.getOperand()); + return success(); + } +}; + +// Pattern: dynamic_reshape(op(dynamic_reshape(X, shape)), shape) +// -> op(dynamic_reshape(X, shape)) +// [if op has same operand and result shape] +class DynamicReshapeOpSameOperandAndResultShape + : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(DynamicReshapeOp op, + PatternRewriter &rewriter) const override { + Operation *defOp = op.getOperand().getDefiningOp(); + if (!defOp || + !defOp->hasTrait()) { + return rewriter.notifyMatchFailure( + op, "dynamic reshape parent not same operand and result shape"); + } + DynamicReshapeOp reshape = + defOp->getOperand(0).getDefiningOp(); + if (!reshape) + return rewriter.notifyMatchFailure( + op, "dynamic reshape not wrapping same operand and result shape"); + if (reshape.getOutputShape() == op.getOutputShape()) { + rewriter.replaceOp(op, {defOp->getResult(0)}); + return success(); + } + return failure(); + } +}; + +////////////////////////////////// +// DynamicSliceOp +///////////////////////////////// + +// Canonicalizes DynamicSlice ops that can be replaced instead with Slice ops. +// This canonicalization is applied the case when the `begin` input values are +// compile time constants and thus can be made into a tensor. 
+// +// Pattern: dynamic_slice(X, begin, slice_sizes) -> slice(X, begin, slice_sizes) +struct DynamicSliceOpToSlice : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(DynamicSliceOp dynamicSlice, + PatternRewriter &rewriter) const override { + Value input = dynamicSlice.getOperand(); + auto inputType = cast(input.getType()); + if (!inputType.hasStaticShape()) + return rewriter.notifyMatchFailure(dynamicSlice, + "dynamic slice input not static"); + + auto sliceSizes = dynamicSlice.getSliceSizes(); + SmallVector tempStartIndices; + for (const auto &indexAndSliceStart : + llvm::enumerate(dynamicSlice.getStartIndices())) { + APInt val; + Value start = indexAndSliceStart.value(); + int64_t index = indexAndSliceStart.index(); + if (!matchPattern(start, m_ConstantInt(&val))) + return rewriter.notifyMatchFailure(dynamicSlice, + "dynamic slice input not constant"); + + // Clamp the indices within bounds to faithfully mirror dynamic slice + // semantics. + int64_t clampedStart = + std::clamp(val.getSExtValue(), static_cast(0), + inputType.getDimSize(index) - sliceSizes[index]); + tempStartIndices.push_back(clampedStart); + } + + // At this point we've determined that the start indices are all constants; + // pack them into a single tensor. 
+ auto sliceStartIndices = rewriter.getDenseI64ArrayAttr(tempStartIndices); + SmallVector tempSliceLimits; + for (const auto &[start, size] : llvm::zip(tempStartIndices, sliceSizes)) { + tempSliceLimits.push_back(start + size); + } + auto sliceLimits = rewriter.getDenseI64ArrayAttr(tempSliceLimits); + + auto sliceStrides = rewriter.getDenseI64ArrayAttr( + SmallVector(inputType.getRank(), 1)); - rewriter.replaceOpWithNewOp(op, type, - op.getOperand()); + rewriter.replaceOpWithNewOp(dynamicSlice, input, sliceStartIndices, + sliceLimits, sliceStrides); + return success(); + } +}; + +////////////////////////////////// +// RealDynamicSliceOp +///////////////////////////////// + +// Pattern: real_dynamic_slice(X, start, limit, strides) +// -> dynamic_slice(X, start, limit, strides) +// [if strides, start are constants, limit = start + constant] +struct RealDynamicSliceOpToDynamicSlice + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(RealDynamicSliceOp op, + PatternRewriter &rewriter) const override { + // This rewrite only works for unit strides because DynamicSliceOp + // doesn't support strides (i.e. it implicitly has unit strides). + DenseIntElementsAttr stridesAttr; + if (!matchPattern(op.getStrides(), m_Constant(&stridesAttr))) + return rewriter.notifyMatchFailure(op, "requires constant strides"); + if (!llvm::all_of(stridesAttr.getValues(), + [&](APInt stride) { return stride == 1; })) + return rewriter.notifyMatchFailure(op, "requires unit strides"); + + // Check that slice sizes are fully static (DynamicSliceOp style). + // To detect that, we check whether `limit_indices` is defined as + // `start_indices + constant` or `constant + start_indices`. + DenseIntElementsAttr sliceSizesAttr; + auto m_startIndices = matchers::m_Val(op.getStartIndices()); + // Only handle the AddOp case, if all constant we fold to SliceOp. 
+ if (!matchPattern( + op.getLimitIndices(), + m_Op(m_startIndices, m_Constant(&sliceSizesAttr))) && + !matchPattern(op.getLimitIndices(), + m_Op(m_Constant(&sliceSizesAttr), m_startIndices))) + return rewriter.notifyMatchFailure( + op, "requires limit indices equal to start indices plus constant"); + + // RealDynamicSliceOp can take tensors of integer or index element types. + // DynamicSliceOp::slice_sizes only supports i64 element type. + // Adapt accordingly in order to be compatible with DynamicSliceOp. + SmallVector sliceSizes; + for (auto element : sliceSizesAttr.getValues()) { + sliceSizes.push_back(element.getSExtValue()); + } + + // RealDynamicSliceOp::start_indices is a 1-dimensional tensor. + // DynamicSliceOp::start_indices is a vararg of 0-dimensional tensors. + // Adapt accordingly in order to be compatible with DynamicSliceOp. + SmallVector startIndices; + for (auto i = 0; i < static_cast(sliceSizes.size()); ++i) { + auto startIndex1D = rewriter.create( + op.getLoc(), op.getStartIndices(), rewriter.getDenseI64ArrayAttr(i), + rewriter.getDenseI64ArrayAttr(i + 1), + rewriter.getDenseI64ArrayAttr(1)); + auto startIndex0DType = RankedTensorType::get( + {}, + cast(op.getStartIndices().getType()).getElementType()); + auto startIndex0D = rewriter.create( + op.getLoc(), startIndex0DType, startIndex1D); + startIndices.push_back(startIndex0D); + } + + rewriter.replaceOpWithNewOp( + op, op.getOperand(), startIndices, + rewriter.getDenseI64ArrayAttr(sliceSizes)); return success(); } }; @@ -401,16 +605,14 @@ struct DynamicReshapeOpCanon final ///////////////////////////////// // Pattern: reduce[A](_, _, fn:return A) -> A... 
-struct ReduceNoopVariableReturn final - : OpRewritePattern { +struct ReduceOpNoopVariableReturn final : OpRewritePattern { using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(mlir::stablehlo::ReduceOp op, + LogicalResult matchAndRewrite(ReduceOp op, PatternRewriter &rewriter) const override { // If all returned values in the ReduceOp region exists outside the // region, replace the ReduceOp with those values. - if (auto retOp = dyn_cast( - op.getBody().front().getTerminator())) { + if (auto retOp = dyn_cast(op.getBody().front().getTerminator())) { Region *retRegion = retOp->getParentRegion(); if (llvm::any_of(retOp.getResults(), [retRegion](Value result) { return result.getParentRegion() == retRegion; @@ -426,10 +628,10 @@ struct ReduceNoopVariableReturn final }; // Pattern: reduce(empty_0, empty_1, ...) -> [broadcast_in_dim(empty_i)...] -struct EmptyReduceOpCanon final : OpRewritePattern { +struct ReduceOpEmptyCanon final : OpRewritePattern { using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(mlir::stablehlo::ReduceOp op, + LogicalResult matchAndRewrite(ReduceOp op, PatternRewriter &rewriter) const override { // We require all reduce shapes to be the same, up to the element types, so // we can just use the first operand and the first result as @@ -444,8 +646,7 @@ struct EmptyReduceOpCanon final : OpRewritePattern { SmallVector broadcasts(op.getNumResults()); for (auto [bcast, init, outTy] : llvm::zip_equal( broadcasts, op.getInitValues(), op.getResultTypes())) { - bcast = rewriter.create(loc, outTy, - init, empty); + bcast = rewriter.create(loc, outTy, init, empty); } rewriter.replaceOp(op, broadcasts); return success(); @@ -458,8 +659,8 @@ struct EmptyReduceOpCanon final : OpRewritePattern { SmallVector broadcasts(op.getNumResults()); for (auto [bcast, init, shape, outTy] : llvm::zip_equal( broadcasts, op.getInitValues(), shapes, op.getResultTypes())) { - bcast = rewriter.create( - loc, outTy, init, shape, empty); + 
bcast = rewriter.create(loc, outTy, init, shape, + empty); } rewriter.replaceOp(op, broadcasts); return success(); @@ -467,11 +668,10 @@ struct EmptyReduceOpCanon final : OpRewritePattern { }; // Pattern: reduce(in_1, in_2, _, _) -> reduce(in_1, _, _) [if unused(in_2)] -struct UnusedResultReduceOpCanon final - : OpRewritePattern { +struct ReduceOpUnusedResultCanon final : OpRewritePattern { using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(mlir::stablehlo::ReduceOp op, + LogicalResult matchAndRewrite(ReduceOp op, PatternRewriter &rewriter) const override { SmallVector usedResults; llvm::copy_if(op.getResults(), std::back_inserter(usedResults), @@ -485,7 +685,7 @@ struct UnusedResultReduceOpCanon final const auto numOperandPairs = numOperands / pairSize; Block &reducerBlock = op.getBody().front(); - auto retOp = cast(reducerBlock.getTerminator()); + auto retOp = cast(reducerBlock.getTerminator()); assert(numOperandPairs == op.getNumResults() && numOperandPairs == retOp.getNumOperands()); @@ -564,8 +764,7 @@ struct UnusedResultReduceOpCanon final if (usedReturnOperands[en.index()]) newReturnOperands.push_back(mapper.lookup(en.value())); - rewriter.create(retOp.getLoc(), - newReturnOperands); + rewriter.create(retOp.getLoc(), newReturnOperands); // Build new results list (unused entries will be null). SmallVector newResults(op.getNumResults()); @@ -585,11 +784,10 @@ struct UnusedResultReduceOpCanon final // TODO: This is duplicated with a pattern in shape refinement, consider // consolidating. // Pattern: get_dimension_size(X, i) -> X.shape[i] -struct GetDimensionSizeOpCanon final - : OpRewritePattern { +struct GetDimensionSizeOpCanon final : OpRewritePattern { using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(mlir::stablehlo::GetDimensionSizeOp op, + LogicalResult matchAndRewrite(GetDimensionSizeOp op, PatternRewriter &rewriter) const override { // Fold get_dimension_size when the queried dim is statically known. 
RankedTensorType operandTy = op.getOperand().getType(); @@ -599,7 +797,7 @@ struct GetDimensionSizeOpCanon final auto elemTy = cast(op.getType().getElementType()); IntegerAttr elemVal = rewriter.getIntegerAttr(elemTy, dimSize); - rewriter.replaceOpWithNewOp( + rewriter.replaceOpWithNewOp( op, DenseElementsAttr::get(op.getType(), elemVal)); return success(); } @@ -612,17 +810,16 @@ struct GetDimensionSizeOpCanon final /// Converts gather ops to slice ops in case we have a single set of constant /// indices. // Pattern: gather(X, cst_start_indices) -> slice(X, slice_start, slice_end) -struct GatherOpCanon final : OpRewritePattern { +struct GatherOpCanon final : OpRewritePattern { using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(mlir::stablehlo::GatherOp gather, + LogicalResult matchAndRewrite(GatherOp gather, PatternRewriter &rewriter) const override { DenseIntElementsAttr index; if (!matchPattern(gather.getStartIndices(), m_Constant(&index))) return failure(); - mlir::stablehlo::GatherDimensionNumbersAttr dnums = - gather.getDimensionNumbers(); + GatherDimensionNumbersAttr dnums = gather.getDimensionNumbers(); if (dnums.getIndexVectorDim() != 0 || index.getType().getRank() > 1) return failure(); @@ -657,7 +854,7 @@ struct GatherOpCanon final : OpRewritePattern { Type elementType = gather.getType().getElementType(); auto sliceType = RankedTensorType::get(sliceShape, elementType); - Value result = rewriter.create( + Value result = rewriter.create( gather.getLoc(), sliceType, gather.getOperand(), rewriter.getDenseI64ArrayAttr(sliceStart), rewriter.getDenseI64ArrayAttr(sliceEnd), @@ -671,8 +868,7 @@ struct GatherOpCanon final : OpRewritePattern { reshapeShape.push_back(dim); } auto reshapeType = RankedTensorType::get(reshapeShape, elementType); - result = rewriter.create(gather.getLoc(), - reshapeType, result); + result = rewriter.create(gather.getLoc(), reshapeType, result); } result.setType(gather.getType()); @@ -681,16 +877,366 @@ struct 
GatherOpCanon final : OpRewritePattern { } }; +////////////////////////////////// +// IotaOp +///////////////////////////////// + +// Iota operations across multiple dimensions can be reduced to an iota and a +// ranked broadcast. +// Pattern: iota(dim) : multi_rank +// -> broadcast_in_dim(iota(dim) : array, multi_rank) +struct IotaOpBroadcast : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(IotaOp iota, + PatternRewriter &rewriter) const override { + auto resultTy = cast(iota.getType()); + if (resultTy.getRank() < 2) + return rewriter.notifyMatchFailure(iota, "itoa not broadcastable"); + + auto iotaDim = iota.getIotaDimension(); + auto iotaDimSize = resultTy.getDimSize(iotaDim); + auto iota1D = rewriter.create( + iota.getLoc(), + RankedTensorType::get({iotaDimSize}, resultTy.getElementType()), + rewriter.getI64IntegerAttr(0)); + + auto broadcastAttr = + rewriter.getDenseI64ArrayAttr({static_cast(iotaDim)}); + rewriter.replaceOpWithNewOp(iota, resultTy, iota1D, + broadcastAttr); + return success(); + } +}; + +////////////////////////////////// +// PadOp +///////////////////////////////// + +// If the input tensor has a dimension of length-0, the input tensor is +// irrelevant. Instead we can broadcast the pad value to the output size rather +// than pad the input tensor. + +// If the input tensor has a dimension of length-0, the input tensor is +// irrelevant. Instead we can broadcast the pad value to the output size rather +// than pad the input tensor. 
+ +// Pattern: pad(empty_tensor, _) -> broadcast_in_dim(empty_tensor, _) +struct PadOpBroadcastEmptyTensor : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(PadOp op, + PatternRewriter &rewriter) const override { + auto operand = op.getOperand(); + auto padVal = op.getPaddingValue(); + + auto resultTy = cast(op.getType()); + + if (cast(operand.getType()).getNumElements() != 0) + return rewriter.notifyMatchFailure(op, "operand is not empty tensor"); + + if (resultTy.hasStaticShape()) { + rewriter.replaceOpWithNewOp( + op, resultTy, padVal, rewriter.getDenseI64ArrayAttr({})); + return success(); + } + + llvm::SmallVector reifiedShapes; + if (failed(op.reifyReturnTypeShapes(rewriter, op.getOperands(), + reifiedShapes))) + return rewriter.notifyMatchFailure(op, "failed to reify return type"); + + rewriter.replaceOpWithNewOp( + op, op.getType(), padVal, reifiedShapes.front(), + rewriter.getDenseI64ArrayAttr({})); + return success(); + } +}; + +////////////////////////////////// +// SelectOp +///////////////////////////////// + +struct SelectOpCanon final : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(SelectOp op, + PatternRewriter &rewriter) const override { + RankedTensorType type = op.getType(); + + Value trueVal = op.getOnTrue(); + Value falseVal = op.getOnFalse(); + + // Eliminate select with two identical outcomes. + if (trueVal == falseVal) { + rewriter.replaceOp(op, trueVal); + return success(); + } + + // Simplify when the condition is a constant. + Value pred = op.getPred(); + ElementsAttr cond; + if (!matchPattern(pred, m_Constant(&cond))) return failure(); + + // Handle splat predicate and select either `trueVal` or `falseVal`. + if (cond.isSplat()) { + rewriter.replaceOp(op, cond.getSplatValue() ? trueVal : falseVal); + return success(); + } + + // Handle elementwise selection when both outcomes are also constants. 
This + // will create a new, likely non-splat constant. + if (cond.getNumElements() > kFoldOpEltLimit) return failure(); + + ElementsAttr trueAttr; + if (!matchPattern(trueVal, m_Constant(&trueAttr))) return failure(); + + ElementsAttr falseAttr; + if (!matchPattern(falseVal, m_Constant(&falseAttr))) return failure(); + + SmallVector newValues; + newValues.reserve(cond.getNumElements()); + for (auto [condElem, trueElem, falseElem] : llvm::zip_equal( + cond.getValues(), trueAttr.getValues(), + falseAttr.getValues())) { + newValues.push_back(condElem ? trueElem : falseElem); + } + + rewriter.replaceOpWithNewOp( + op, DenseElementsAttr::get(type, newValues)); + return success(); + } +}; + +struct CompareSelectIntoMinMax final : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(SelectOp op, + PatternRewriter &rewriter) const override { + Value pred = op.getPred(); + Value trueVal = op.getOnTrue(); + Value falseVal = op.getOnFalse(); + + auto cmpOp = pred.getDefiningOp(); + if (!cmpOp) return failure(); + + ComparisonDirection direction = cmpOp.getComparisonDirection(); + Value cmpLhs = cmpOp.getLhs(); + Value cmpRhs = cmpOp.getRhs(); + + // Turn into canonical form: + // b <= a ? a : b ---> a >= b ? a : b + // b < a ? a : b ---> a > b ? a : b + // b >= a ? a : b ---> a <= b ? a : b + // b > a ? a : b ---> a < b ? 
a : b + if (cmpLhs == falseVal && cmpRhs == trueVal) { + direction = invertDirection(direction); + } else if (!(cmpLhs == trueVal && cmpRhs == falseVal)) { + return failure(); + } + + switch (direction) { + case ComparisonDirection::GE: + case ComparisonDirection::GT: { + rewriter.replaceOpWithNewOp(op, trueVal, falseVal); + return success(); + } + case ComparisonDirection::LE: + case ComparisonDirection::LT: { + rewriter.replaceOpWithNewOp(op, trueVal, falseVal); + return success(); + } + default: { + return failure(); + } + } + } +}; + +////////////////////////////////// +// SliceOp +///////////////////////////////// + +// In cases where a concat is fed into a slice, it is possible the concat +// can be simplified or bypassed. This checks which inputs to the concat are +// used by the slice, either reducing the number of concatenated values or +// entirely removes the concat. +// Pattern: slice(concat(X,Y,Z,...),...) -> concat(slice(X),slice(Y),slice(Z)) +struct SliceOpConcatSimplify : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(SliceOp slice, + PatternRewriter &rewriter) const override { + auto resultTy = cast(slice.getType()); + if (!resultTy.hasStaticShape()) + return rewriter.notifyMatchFailure(slice, "result shape not static"); + + auto concat = slice.getOperand().getDefiningOp(); + if (!concat) + return rewriter.notifyMatchFailure(slice, "slice input not concat"); + + auto concatType = cast(concat.getType()); + auto dimension = concat.getDimension(); + + auto start = slice.getStartIndices(); + auto limit = slice.getLimitIndices(); + + int64_t sliceStart = start[dimension]; + int64_t sliceLimit = limit[dimension]; + + // We need to determine what inputs from the concat affect the slice, and + // how the bounds of the slice need to be updated for the minimally required + // inputs. 
+ int64_t runningSize = 0; + int64_t frontOffset = concatType.getShape()[dimension]; + + auto subsetStart = concat.operand_end(); + auto subsetEnd = concat.operand_end(); + for (auto it = concat.operand_begin(); it < concat.operand_end(); ++it) { + auto input = *it; + ShapedType inputTy = cast(input.getType()); + if (inputTy.isDynamicDim(dimension)) + return rewriter.notifyMatchFailure( + slice, "concat input has dynamic dimension"); + + auto dimSize = inputTy.getShape()[dimension]; + + // If this position is in the slice its the start of the subset and we + // need to update the start and limit values. + if (runningSize + dimSize > sliceStart && + subsetStart == concat.operand_end()) { + subsetStart = it; + frontOffset = runningSize; + } + + // Determine the last required offset. + if (runningSize < sliceLimit) { + subsetEnd = it + 1; + } + + runningSize += dimSize; + } + + auto subsetSize = subsetEnd - subsetStart; + // We need all inputs so no optimization. + if (subsetSize == concat.getNumOperands()) + return rewriter.notifyMatchFailure(slice, + "slice needs all concat inputs"); + + // If there's nothing to slice that means the output is an empty tensor and + // there is dead code. We do nothing here and rely on other passes to clean + // this up. 
+ if (subsetSize == 0) + return rewriter.notifyMatchFailure(slice, "slice is empty"); + + if (subsetSize > 1 && !concat.getResult().hasOneUse()) + return rewriter.notifyMatchFailure(slice, + "slice is not the only concat user"); + + auto concatRange = OperandRange(subsetStart, subsetEnd); + auto newConcat = rewriter.create( + concat.getLoc(), concatRange, concat.getDimension()); + + SmallVector newStart(start); + SmallVector newLimit(limit); + newStart[dimension] -= frontOffset; + newLimit[dimension] -= frontOffset; + + rewriter.replaceOpWithNewOp( + slice, newConcat, rewriter.getDenseI64ArrayAttr(newStart), + rewriter.getDenseI64ArrayAttr(newLimit), slice.getStrides()); + return success(); + } +}; + +////////////////////////////////// +// SortOp +///////////////////////////////// + +/// Drops the operands if the results are not used and they are not used in +/// op.comparator(). + +// Pattern: sort(X,Y) -> sort(X) [if Y unused and unused in comparator] +struct SortOpDropUnusedArgs : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(SortOp op, + PatternRewriter &rewriter) const override { + DenseSet erasedArgs; + unsigned numOperands = op.getNumOperands(); + for (unsigned i = 0; i < numOperands; ++i) { + if (!op.getResult(i).use_empty()) continue; + Block &block = op.getComparator().front(); + if (!block.getArgument(i * 2).use_empty()) continue; + if (!block.getArgument(i * 2 + 1).use_empty()) continue; + erasedArgs.insert(i); + } + if (erasedArgs.empty()) return failure(); + + SmallVector newOperands; + BitVector erasedBlockArgs(op.getNumOperands() * 2); + for (const auto &en : llvm::enumerate(op.getInputs())) { + if (erasedArgs.contains(en.index())) { + erasedBlockArgs.set(en.index() * 2); + erasedBlockArgs.set(en.index() * 2 + 1); + } else { + newOperands.push_back(en.value()); + } + } + + auto newOp = rewriter.create(op.getLoc(), newOperands, + op.getDimension(), op.getIsStable()); + Region ®ion = 
newOp.getComparator(); + rewriter.inlineRegionBefore(op.getComparator(), region, region.end()); + region.front().eraseArguments(erasedBlockArgs); + + SmallVector results; + for (unsigned i = 0, j = 0; i < numOperands; ++i) { + if (erasedArgs.contains(i)) { + results.push_back({}); + } else { + results.push_back(newOp.getResult(j++)); + } + } + rewriter.replaceOp(op, results); + + return success(); + } +}; + +/// Set the sorting dimension to the last dimension if it's not set and the rank +/// is known. +// Pattern: sort(X) -> sort(X, dim = N) [when dim can be inferred] +struct SortOpSetDimension : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(SortOp op, + PatternRewriter &rewriter) const override { + if (op.getResults().size() == 0 || + static_cast(op.getDimension()) != -1) + return rewriter.notifyMatchFailure(op, + "dimension already set or no results"); + + auto type = cast(op.getResultTypes()[0]); + IntegerAttr dim = rewriter.getI64IntegerAttr(type.getRank() - 1); + auto newOp = + rewriter.create(op.getLoc(), op.getResultTypes(), + op.getInputs(), dim, op.getIsStableAttr()); + newOp.getComparator().takeBody(op.getComparator()); + rewriter.replaceOp(op, newOp.getResults()); + return success(); + } +}; + ////////////////////////////////// // TransposeOp ///////////////////////////////// // Pattern: transpose(X, [no_mem_layout_change...]) -> reshape(X) -struct TransposeIsReshape final - : OpRewritePattern { +struct TransposeIsReshape final : OpRewritePattern { using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(mlir::stablehlo::TransposeOp op, + LogicalResult matchAndRewrite(TransposeOp op, PatternRewriter &rewriter) const override { auto input = op.getOperand(); auto permutation = op.getPermutation(); @@ -713,8 +1259,115 @@ struct TransposeIsReshape final if (nonZeroPerms[i - 1] > nonZeroPerms[i]) return rewriter.notifyMatchFailure(op, "memory layout change"); - 
rewriter.replaceOpWithNewOp(op, op.getType(), - input); + rewriter.replaceOpWithNewOp(op, op.getType(), input); + return success(); + } +}; + +////////////////////////////////// +// TupleOp +///////////////////////////////// + +// Pattern: tuple(get_tuple_element(X, 0), get_tuple_element(X, 1), ...) -> X +struct TupleIsRepacking : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TupleOp op, + PatternRewriter &rewriter) const override { + if (op.getVal().empty()) + return rewriter.notifyMatchFailure(op, "empty tuple"); + + // Get parent tuple + Value firstElement = op.getVal().front(); + auto firstElementOp = firstElement.getDefiningOp(); + if (!firstElementOp) + return rewriter.notifyMatchFailure(op, "parent not get_tuple_element"); + + Value tuplePredecessor = firstElementOp.getOperand(); + if (tuplePredecessor.getType() != op.getType()) + return rewriter.notifyMatchFailure( + op, "tuple predecessor type does not match"); + + // Check that this is a repacking of the parent tuple. + for (const auto &elementAndIdx : llvm::enumerate(op.getVal())) { + auto elementOp = elementAndIdx.value().getDefiningOp(); + if (!elementOp || + elementOp.getIndexAttr().getInt() != + static_cast(elementAndIdx.index()) || + elementOp.getOperand() != tuplePredecessor) + return rewriter.notifyMatchFailure( + op, "not a repacking of the parent tuple"); + } + + rewriter.replaceOp(op, tuplePredecessor); + return success(); + } +}; + +///////////////////////////////// +// WhileOp +///////////////////////////////// + +// Turn loop invariant values into implicit capture. +// Check if at least one value is forwarded from one iteration to +// the next, or one of the yielded values is an implicit capture already. +// Otherwise there is nothing to do here.
+ +// Pattern: while -> while (loop invariants as implicit captures) +struct WhileOpImplicitCapture : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(WhileOp whileOp, + PatternRewriter &rewriter) const override { + Block *cond = whileOp.SingleBlock::getBody(0); + Block *body = whileOp.SingleBlock::getBody(1); + auto bodyReturnOp = cast(body->getTerminator()); + if (!llvm::any_of(llvm::zip(whileOp->getOperands(), body->getArguments(), + bodyReturnOp->getOperands()), + [&](auto zip) { + return (std::get<0>(zip) == std::get<2>(zip) || + std::get<1>(zip) == std::get<2>(zip)); + })) + return rewriter.notifyMatchFailure(whileOp, "no loop invariant found"); + + SmallVector newOperands, resultsToReplace; + SmallVector invariantArgIdxs; + BitVector invariantArgIdxBitVector(cond->getNumArguments()); + for (const auto &enumeratedOperands : llvm::enumerate(llvm::zip( + whileOp.getOperands(), cond->getArguments(), body->getArguments(), + bodyReturnOp->getOperands(), whileOp->getResults()))) { + const auto &operands = enumeratedOperands.value(); + Value whileOperand = std::get<0>(operands); + BlockArgument condBlockArg = std::get<1>(operands); + BlockArgument bodyBlockArg = std::get<2>(operands); + Value bodyReturnOperand = std::get<3>(operands); + Value whileResult = std::get<4>(operands); + + bool forwarded = (whileOperand == bodyReturnOperand || + bodyBlockArg == bodyReturnOperand); + if (forwarded) { + invariantArgIdxs.push_back(enumeratedOperands.index()); + invariantArgIdxBitVector.set(enumeratedOperands.index()); + condBlockArg.replaceAllUsesWith(whileOperand); + bodyBlockArg.replaceAllUsesWith(whileOperand); + whileResult.replaceAllUsesWith(whileOperand); + continue; + } + newOperands.push_back(whileOperand); + resultsToReplace.push_back(whileResult); + } + cond->eraseArguments(invariantArgIdxBitVector); + body->eraseArguments(invariantArgIdxBitVector); + for (int idx : llvm::reverse(invariantArgIdxs)) + 
bodyReturnOp->eraseOperand(idx); + + WhileOp newWhileOp = rewriter.create( + whileOp.getLoc(), bodyReturnOp->getOperandTypes(), newOperands); + newWhileOp.getBodyRegion(0).takeBody(whileOp.getBodyRegion(0)); + newWhileOp.getBodyRegion(1).takeBody(whileOp.getBodyRegion(1)); + for (auto results : llvm::zip(resultsToReplace, newWhileOp->getResults())) + std::get<0>(results).replaceAllUsesWith(std::get<1>(results)); + rewriter.eraseOp(whileOp); return success(); } }; @@ -724,48 +1377,52 @@ struct TransposeIsReshape final ///////////////////////////////// /// Check if a `t` is a tensor with zero extents. -static std::optional isZeroExtent(Type t) { +static std::optional getMaybeZeroExtentType(Type t) { auto type = dyn_cast(t); if (type && type.hasStaticShape() && type.getNumElements() == 0) return type; return std::nullopt; } // Replace instances of zero extent tensors with empty tensors -// Pattern: op(X : zero_extent_tensor) -> tensor.empty() -struct ZeroExtentTensorCanon final : RewritePattern { - ZeroExtentTensorCanon(MLIRContext *context, PatternBenefit benefit) +// Pattern: op(X : zero_extent_tensor) -> constant([]) +struct ZeroExtentToEmptyConstant final : RewritePattern { + ZeroExtentToEmptyConstant(MLIRContext *context, PatternBenefit benefit) : RewritePattern(MatchAnyOpTypeTag(), benefit, context) {} LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override { auto loc = op->getLoc(); - if (!isa_and_present(op->getDialect())) + if (!isa_and_present(op->getDialect())) return rewriter.notifyMatchFailure(op, "not stablehlo"); + if (isa(op)) + return rewriter.notifyMatchFailure(op, "op is empty constant"); // If the result is a zero-extent tensor, replace the whole op with an empty - // tensor. + // constant. 
bool didUpdate = false; for (auto result : op->getResults()) { - auto resultType = isZeroExtent(result.getType()); + auto resultType = getMaybeZeroExtentType(result.getType()); if (!resultType || result.use_empty()) continue; - rewriter.replaceAllUsesWith(result, rewriter.create( - loc, resultType->getShape(), - resultType->getElementType())); + rewriter.replaceAllUsesWith( + result, rewriter.create( + loc, result.getType(), + DenseElementsAttr::get(resultType.value(), + ArrayRef()))); didUpdate = true; } // If one of the operands is a zero-extent tensor, replace the operand with // an empty tensor. for (OpOperand &operand : op->getOpOperands()) { - auto operandType = isZeroExtent(operand.get().getType()); - if (!operandType || operand.get().getDefiningOp()) - continue; + auto operandType = getMaybeZeroExtentType(operand.get().getType()); + if (!operandType || operand.get().getDefiningOp()) continue; Operation *owner = operand.getOwner(); int operandNum = operand.getOperandNumber(); - auto emptyTensorOp = rewriter.create( - loc, operandType->getShape(), operandType->getElementType()); + auto emptyConstantOp = rewriter.create( + loc, operandType.value(), + DenseElementsAttr::get(operandType.value(), ArrayRef())); rewriter.modifyOpInPlace( - owner, [&]() { owner->setOperand(operandNum, emptyTensorOp); }); + owner, [&]() { owner->setOperand(operandNum, emptyConstantOp); }); didUpdate = true; } return success(didUpdate); @@ -786,9 +1443,7 @@ struct ReorderElementwiseAndShapeOp final return rewriter.notifyMatchFailure( op, "expected to have an op before elementise op"); - if (!isa(definingOp) && - !isa(definingOp) && - !isa(definingOp)) + if (!isa(definingOp)) return rewriter.notifyMatchFailure( op, "defining operation of unexpected type"); @@ -838,19 +1493,26 @@ void populateStablehloCanonicalizationPatterns(MLIRContext *context, RewritePatternSet *patterns, PatternBenefit benefit) { populateWithGenerated(*patterns); - patterns->add< - // Arithmetic ops. 
- CompareOpCanon, SelectOpCanon, CompareSelectIntoMinMax, - // TODO: Dynamism Refinements, consider merging with canonicalize dynamism - GetDimensionSizeOpCanon, DynamicBroadcastInDimOpNotActuallyDynamic, - DynamicReshapeOpCanon, - // Reduce op. - ReduceNoopVariableReturn, EmptyReduceOpCanon, UnusedResultReduceOpCanon, - // Shape manipulation(-ish) ops. - GatherOpCanon, TransposeIsReshape, - // Types. - ZeroExtentTensorCanon>(context, benefit); patterns->add(context); + patterns->add< + CompareOpCanon, CompareSelectIntoMinMax, ConcatenateOpFlatten, + ConcatenateOpNoop, ConcatenateOpRemoveEmpty, DynamicIotaOpToBroadcast, + DynamicReshapeOpSameOperandAndResultShape, DynamicSliceOpToSlice, + GatherOpCanon, IotaOpBroadcast, PadOpBroadcastEmptyTensor, + RealDynamicSliceOpToDynamicSlice, ReduceOpEmptyCanon, + ReduceOpNoopVariableReturn, ReduceOpUnusedResultCanon, SelectOpCanon, + SliceOpConcatSimplify, SortOpDropUnusedArgs, SortOpSetDimension, + TransposeIsReshape, TupleIsRepacking, WhileOpImplicitCapture>(context, + benefit); + + // Generic patterns + patterns->add( + context, benefit); + + // TODO: Dynamism Refinements, consider merging with canonicalize dynamism + patterns + ->add(context); } } // namespace stablehlo diff --git a/stablehlo/transforms/StablehloAggressiveSimplificationPatterns.td b/stablehlo/transforms/StablehloAggressiveSimplificationPatterns.td index 31f1f475c0..faea08078e 100644 --- a/stablehlo/transforms/StablehloAggressiveSimplificationPatterns.td +++ b/stablehlo/transforms/StablehloAggressiveSimplificationPatterns.td @@ -12,49 +12,87 @@ include "mlir/IR/OpBase.td" include "stablehlo/dialect/StablehloOps.td" +include "mlir/Dialect/Shape/IR/ShapeOps.td" + +/////////// +//// Op & Type Constraints + +class DimSizeEquals : Constraint< + CPred<"llvm::cast($0.getType()).getDimSize($1.getInt()) == " # dimSize>, + "dim size is " # dimSize>; + +def AllDimsNonExpanding : Constraint< + CPred<"$0 && cast($0).size() == llvm::cast($1.getType()).getRank()">, + 
"all dims are non-expanding">; + +def AllZero : Constraint< + CPred<"llvm::all_of($0, [](Value operand) {return matchPattern(operand, m_Zero()); })">, + "is all zero">; + +def CommutativeOp : Constraint< + CPred<"$0.getDefiningOp()->hasTrait()">, + "op is commutative">; + +def HasOneUse : Constraint>; -//// Utilities def NotConstantOp : Constraint< CPred<"llvm::isa($0) || !llvm::isa($0.getDefiningOp())">, "is not a constant.">; -def OperandsEqual : Constraint, "operands are equal">; - -def TypesEqual : Constraint, "operands are equal">; - def NumberOfElementsEqual : Constraint< CPred<"llvm::cast($0.getType()).getNumElements() == llvm::cast($1.getType()).getNumElements()">, "same number of elements">; +def OperandsEqual : Constraint, "operands are equal">; + def RankEqual : Constraint< CPred<"llvm::cast($0.getType()).getRank() == llvm::cast($1.getType()).getRank()">, "same rank">; -def EmptyI64Array : AttrConstraint< - CPred<"cast($_self).empty()">, "is empty i64 array">; +def TypesEqual : Constraint, "operands are equal">; -def CommutativeOp : Constraint< - CPred<"$0.getDefiningOp()->hasTrait()">, "op is commutative">; +/////////// +//// Attribute Constraints def AnySplat : AttrConstraint, "is any splat">; def AnyZero : AttrConstraint< - CPred<"::mlir::matchPattern($_self, m_AnyAttrOf(m_Zero(), m_AnyZeroFloat()))">, "is int or float zero">; + CPred<"::mlir::matchPattern($_self, m_AnyAttrOf(m_Zero(), m_AnyZeroFloat()))">, + "is int or float zero">; -def IntZero : AttrConstraint< - CPred<"::mlir::matchPattern($_self, m_Zero())">, "is integer zero">; +def DenseIntElementsAttr : AttrConstraint< + CPred<"llvm::isa($_self)">, + "is dense int elements attr">; + +def EmptyI64Array : AttrConstraint< + CPred<"cast($_self).empty()">, + "is empty i64 array">; def IntOne : AttrConstraint< - CPred<"::mlir::matchPattern($_self, m_One())">, "is integer one">; + CPred<"::mlir::matchPattern($_self, m_One())">, + "is integer one">; + +def IntZero : AttrConstraint< + 
CPred<"::mlir::matchPattern($_self, m_Zero())">,"is integer zero">; def IotaDims : AttrConstraint< - CPred<"isIotaRange(cast($_self).asArrayRef())">, "is iota dimensions">; + CPred<"isIotaRange(cast($_self).asArrayRef())">, + "is iota dimensions">; def SortedDims : AttrConstraint< - CPred<"llvm::is_sorted(cast($_self).asArrayRef())">, "is sorted dimensions">; + CPred<"llvm::is_sorted(cast($_self).asArrayRef())">, + "is sorted dimensions">; -def AllDimsNonExpanding : Constraint< - CPred<"$0 && cast($0).size() == llvm::cast($1.getType()).getRank()">, "all dims are non-expanding">; +def ZeroExtent : AttrConstraint< + CPred<"cast($_self).getNumElements() == 0">, + "is zero extent">; + +/////////// +//// Native Code Call Utilities + +def CastIntElementsAttr : NativeCodeCall<"cast($0)">; + +def ConvertToI64Array : NativeCodeCall<"convertToI64Array($_builder, $0)">; def GetOperandN : NativeCodeCall<"$0.getDefiningOp()->getOperand($1.getInt())">; @@ -145,28 +183,80 @@ def : Pat<(StableHLO_DynamicBroadcastInDimOp $shape, $dims, $expanding, $nonexpanding), (StableHLO_DynamicBroadcastInDimOp $operand, $shape, (MergeBroadcastDims $dims, $dims_p), (GetEmptyI64Array), (GetEmptyI64Array))>; -// Pattern: dynamic_broadcast_in_dim(X, _, _, [all_nonexpanding...]) -> cast(X) +// Pattern: dynamic_broadcast_in_dim(X, _, _, [all_nonexpanding...]) -> convert(X) // No-op, but wrap in ConvertOp to preserve dynamic output shape, can be // important if this result is returned, where refining type would require // also updating the funciton signature. 
-def : Pat<(StableHLO_DynamicBroadcastInDimOp:$op $operand, $shape, $dims, $expanding, $nonexpanding), +def : Pat<(StableHLO_DynamicBroadcastInDimOp:$op $operand, $shape, IotaDims:$dims, $expanding, $nonexpanding), (StableHLO_ConvertOpWithShape $op, $operand), [(AllDimsNonExpanding $nonexpanding, $op)]>; // Pattern: dynamic_broadcast_in_dim(dynamic_reshape(X, shape), shape) -> dynamic_reshape(X, shape) // If sharing same shape operand, is dynamic reshape. def : Pat<(StableHLO_DynamicBroadcastInDimOp - (StableHLO_DynamicReshapeOp $operand, $shape), $shape, $dims, $expanding, $nonexpanding), + (StableHLO_DynamicReshapeOp $operand, $shape), $shape, IotaDims:$dims, $expanding, $nonexpanding), (StableHLO_DynamicReshapeOp $operand, $shape)>; +// Pattern: dynamic_broadcast_in_dim(X, shape_of(X)) -> X +def : Pat<(StableHLO_DynamicBroadcastInDimOp + $operand, (Shape_ShapeOfOp $operand), IotaDims:$dims, $expanding, $nonexpanding), + (replaceWithValue $operand)>; + +//////// +// DynamicGatherOp + +// Pattern: dynamic_gather(x, constant(slice_sizes)) -> gather(x, slice_sizes) +def : Pat<(StableHLO_DynamicGatherOp $operand, $start_indices, (StableHLO_ConstantOp DenseIntElementsAttr:$slice_sizes), $dimension_numbers, $indices_are_sorted), + (StableHLO_GatherOp $operand, $start_indices, $dimension_numbers, (ConvertToI64Array $slice_sizes), $indices_are_sorted)>; + +//////// +// DynamicPadOp + +// Pattern: dynamic_pad(X, low, high, interior) -> pad(X, low, high, interior) +// [if low, high, interior are all constants] +def : Pat<(StableHLO_DynamicPadOp $input, + $padding_value, + (ConstantLikeMatcher AnyIntElementsAttr:$edge_padding_low), + (ConstantLikeMatcher AnyIntElementsAttr:$edge_padding_high), + (ConstantLikeMatcher AnyIntElementsAttr:$interior_padding)), + (StableHLO_PadOp $input, $padding_value, + (ConvertToI64Array $edge_padding_low), + (ConvertToI64Array $edge_padding_high), + (ConvertToI64Array $interior_padding))>; //////// // DynamicReshapeOp // Pattern: 
dynamic_reshape(dynamic_reshape(X, _), shape)) -> dynamic_reshape(X, shape) -def : Pat<(StableHLO_DynamicReshapeOp (StableHLO_DynamicReshapeOp $operand, $shape_p), $shape), +def : Pat<(StableHLO_DynamicReshapeOp (StableHLO_DynamicReshapeOp $operand, $shape_p), $shape), (StableHLO_DynamicReshapeOp $operand, $shape)>; +// Pattern: shape_of(dynamic_reshape(X, shape)) -> shape +def : Pat<(Shape_ShapeOfOp:$op (StableHLO_DynamicReshapeOp $x, $shape)), + (replaceWithValue $shape), + [(TypesEqual $shape, $op)]>; + +//////// +// DynamicUpdateSliceOp + +// Pattern: dynamic_update_slice(X, update : zero_extent)) -> X +def : Pat<(StableHLO_DynamicUpdateSliceOp $operand, (ConstantLikeMatcher ZeroExtent:$update), $start_indices), + (replaceWithValue $operand)>; + +// Pattern: dynamic_update_slice(X, update, start_indices : zero)) -> update +def : Pat<(StableHLO_DynamicUpdateSliceOp AnyStaticShapeTensor:$operand, AnyStaticShapeTensor:$update, $start_indices), + (replaceWithValue $update), + [(TypesEqual $operand, $update), (AllZero $start_indices)]>; + + +//////// +// ComplexOp + +// Pattern: complex(real(X), imag(X))) -> X +def : Pat<(StableHLO_ComplexOp (StableHLO_RealOp $operand), (StableHLO_ImagOp $operand)), + (replaceWithValue $operand)>; + + //////// // ImagOp @@ -174,6 +264,15 @@ def : Pat<(StableHLO_DynamicReshapeOp (StableHLO_DynamicReshapeOp $operand, $sh def : Pat<(StableHLO_ImagOp (StableHLO_ComplexOp $lhs, $rhs)), (replaceWithValue $rhs)>; +//////// +// IotaOp + +// Pattern: iota(dim) : type -> constant(0) : type [if type[dim] == 1] +def : Pat<(StableHLO_IotaOp:$iota $dim), + (StableHLO_ConstantLike<"0"> $iota), + [(DimSizeEquals<1> $iota, $dim)]>; + + //////// // MaxOp @@ -215,6 +314,21 @@ def : Pat<(StableHLO_OrOp $lhs, (StableHLO_ConstantOp:$one IntOne:$value)), def : Pat<(StableHLO_OrOp $lhs, (StableHLO_ConstantOp:$zero IntZero:$value)), (replaceWithValue $lhs)>; +//////// +// RealDynamicSliceOp + +// Pattern: real_dynamic_slice(X, start, limit, strides) +// -> 
slice(X, start, limit, strides) +// [if start, limit, strides are all constants] +def : Pat<(StableHLO_RealDynamicSliceOp $operand, + (ConstantLikeMatcher DenseIntElementsAttr:$start_indices), + (ConstantLikeMatcher DenseIntElementsAttr:$limit_indices), + (ConstantLikeMatcher DenseIntElementsAttr:$strides)), + (StableHLO_SliceOp $operand, + (ConvertToI64Array $start_indices), + (ConvertToI64Array $limit_indices), + (ConvertToI64Array $strides))>; + //////// // RealOp @@ -242,6 +356,20 @@ def : Pat<(StableHLO_ReshapeOp:$reshape $operand), (replaceWithValue $operand), [(TypesEqual $reshape, $operand)]>; + +//////// +// SelectOp + +// Pattern: select(not(p), t, f) => select(p, f, t) +def : Pat< + (StableHLO_SelectOp (StableHLO_NotOp $pred), $on_true, $on_false), + (StableHLO_SelectOp $pred, $on_false, $on_true)>; + +// Pattern: select(broadcast(not(p)), t, f) => select(broadcast(p), f, t) +def : Pat<(StableHLO_SelectOp (StableHLO_BroadcastInDimOp:$b (StableHLO_NotOp $pred), $broadcast_dimensions), $on_true, $on_false), + (StableHLO_SelectOp (StableHLO_BroadcastInDimOp $pred, $broadcast_dimensions, (returnType $b)), $on_false, $on_true), + [(HasOneUse $b)]>; + //////// // SubtractOp @@ -254,6 +382,14 @@ def : Pat<(StableHLO_SubtractOp AnyStaticShapeTensor:$operand, $operand), def : Pat<(StableHLO_SubtractOp $lhs, (StableHLO_ConstantOp AnyZero:$value)), (replaceWithValue $lhs)>; +//////// +// SliceOp + +// Pattern: slice(X, [A:A], [B:B], ...) -> X +def : Pat<(StableHLO_SliceOp:$op AnyStaticShapeTensor:$operand, $start_indices, $limit_indices, $strides), + (replaceWithValue $operand), + [(TypesEqual $operand, $op)]>; + //////// // TransposeOp