Add more StableHLO Simplification Patterns #2607

Closed
GleasonK opened this issue Oct 29, 2024 · 2 comments · Fixed by #2608

@GleasonK (Member)

Request description

There are loads of simplification patterns in MHLO, most of which would be very useful to have in StableHLO. This ticket tracks porting those simplification patterns into the StablehloAggressiveSimplification pass.
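
A minimal sketch of the kind of rewrite these patterns perform; the tool invocation in the comment assumes the pass is registered under a `--stablehlo-aggressive-simplification` flag (flag name is an assumption, not taken from this issue):

```
// Hypothetical input (not one of the ported tests), run through something like:
//   stablehlo-opt --stablehlo-aggressive-simplification input.mlir
func.func @add_zero(%arg0: tensor<4xf32>) -> tensor<4xf32> {
  %0 = stablehlo.constant dense<0.000000e+00> : tensor<4xf32>
  %1 = stablehlo.add %arg0, %0 : tensor<4xf32>
  func.return %1 : tensor<4xf32>
}
// After the add(X, 0) -> X pattern fires, the add folds away and %arg0 is returned directly.
```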

@GleasonK (Member, Author)

Ported everything from xla/mlir_hlo/tests/Dialect/mhlo/canonicalize/canonicalize.mlir except for the following tests, since they aren't pure StableHLO->StableHLO simplifications (a sketch of a contrasting, ported pattern follows the excluded tests below). Some of these probably have precedent for being added to the StableHLO canonicalizations, but others should likely live at a later phase of compilation.

```
////////
// DynamicBroadcastInDimOp

// CHECK-LABEL: func @dynamic_broadcast_in_dim_to_same_shape_2
func.func @dynamic_broadcast_in_dim_to_same_shape_2(%arg0: tensor<?xf32>) -> tensor<?xf32> {
  // CHECK-SAME: %[[ARG:.*]]: tensor<?xf32>
  %0 = shape.shape_of %arg0 : tensor<?xf32> -> !shape.shape
  %1 = shape.to_extent_tensor %0 : !shape.shape -> tensor<1xindex>
  %2 = "stablehlo.dynamic_broadcast_in_dim"(%arg0, %1) <{ broadcast_dimensions = array<i64: 0> }> : (tensor<?xf32>, tensor<1xindex>) -> tensor<?xf32>
  // CHECK: return %[[ARG]] : tensor<?xf32>
  func.return %2 : tensor<?xf32>
}

// CHECK-LABEL: func @dynamic_broadcast_in_dim_to_same_shape_3
func.func @dynamic_broadcast_in_dim_to_same_shape_3(%arg0: tensor<?xf32>) -> tensor<?xf32> {
  // CHECK-SAME: %[[ARG:.*]]: tensor<?xf32>
  %0 = shape.shape_of %arg0 : tensor<?xf32> -> tensor<?xindex>
  %1 = tensor.cast %0 : tensor<?xindex> to tensor<1xindex>
  %2 = "stablehlo.dynamic_broadcast_in_dim"(%arg0, %1) <{ broadcast_dimensions = array<i64: 0> }> : (tensor<?xf32>, tensor<1xindex>) -> tensor<?xf32>
  // CHECK: return %[[ARG]] : tensor<?xf32>
  func.return %2 : tensor<?xf32>
}

// CHECK-LABEL: func @dynamic_broadcast_in_dim_to_same_shape_4
func.func @dynamic_broadcast_in_dim_to_same_shape_4(%arg0: tensor<?xf32>) -> tensor<?xf32> {
  // CHECK-SAME: %[[ARG:.*]]: tensor<?xf32>
  %0 = shape.shape_of %arg0 : tensor<?xf32> -> !shape.shape
  %1 = shape.to_extent_tensor %0 : !shape.shape -> tensor<?xindex>
  %2 = tensor.cast %1 : tensor<?xindex> to tensor<1xindex>
  %3 = "stablehlo.dynamic_broadcast_in_dim"(%arg0, %2) <{ broadcast_dimensions = array<i64: 0> }> : (tensor<?xf32>, tensor<1xindex>) -> tensor<?xf32>
  // CHECK: return %[[ARG]] : tensor<?xf32>
  func.return %3 : tensor<?xf32>
}

////////
// (Dynamic)PadOp

// CHECK-LABEL: @pad_zero_length_dyn
func.func @pad_zero_length_dyn(%arg0: tensor<?x0xf32>, %arg1: tensor<f32>) -> tensor<?x2xf32> {
  // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
  // CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
  // CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
  // CHECK-DAG: %[[DIM:.+]] = tensor.dim %arg0, %[[C0]] : tensor<?x0xf32>
  // CHECK-DAG: %[[SUB:.+]] = arith.subi %[[DIM]], %[[C1]]
  // CHECK-DAG: %[[MAX:.+]] = arith.maxsi %[[SUB]], %[[C0]]
  // CHECK-DAG: %[[MUL:.+]] = arith.muli %[[MAX]], %[[C2]]
  // CHECK-DAG: %[[ADD1:.+]] = arith.addi %[[DIM]], %[[MUL]]
  // CHECK-DAG: %[[ADD2:.+]] = arith.addi %[[ADD1]], %[[C2]]
  // CHECK-DAG: %[[SHAPE:.+]] = tensor.from_elements %[[ADD2]], %[[C2]] : tensor<2xindex>
  // CHECK-DAG: %[[BROAD:.+]] = "stablehlo.dynamic_broadcast_in_dim"(%arg1, %[[SHAPE]]) <{broadcast_dimensions = dense<> : tensor<0xi64>}> : (tensor<f32>, tensor<2xindex>) -> tensor<?x2xf32>
  %0 = "stablehlo.pad"(%arg0, %arg1) {
    edge_padding_low = array<i64: 1, 1>,
    edge_padding_high = array<i64: 1, 1>,
    interior_padding = array<i64: 2, 2>
  } : (tensor<?x0xf32>, tensor<f32>) -> tensor<?x2xf32>
  // CHECK: return %[[BROAD]]
  func.return %0 : tensor<?x2xf32>
}

// CHECK-LABEL: @dynamic_pad_length_dyn
func.func @dynamic_pad_length_dyn(
  %arg0: tensor<?x0xf32>, %arg1: tensor<2xi32>, %arg2: tensor<2xi32>,
  %arg3: tensor<2xi32>) -> tensor<?x?xf32> {
  // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : i32
  // CHECK-DAG: %[[C1:.+]] = arith.constant 1 : i32
  // CHECK-DAG: %[[CI0:.+]] = arith.constant 0 : index
  // CHECK-DAG: %[[CI1:.+]] = arith.constant 1 : index
  // CHECK-DAG: %[[CST:.+]] = arith.constant dense<0.000000e+00> : tensor<f32>
  // CHECK: %[[DIM0:.+]] = tensor.dim %arg0, %[[CI0]]
  // CHECK: %[[CAST:.+]] = arith.index_cast %[[DIM0]] : index to i32
  // CHECK: %[[EX0:.+]] = tensor.extract %arg1[%[[CI0]]]
  // CHECK: %[[EX1:.+]] = tensor.extract %arg2[%[[CI0]]]
  // CHECK: %[[EX2:.+]] = tensor.extract %arg3[%[[CI0]]]
  // CHECK: %[[CMP:.+]] = arith.cmpi slt, %[[CAST]], %[[C1]] : i32
  // CHECK: %[[SUB:.+]] = arith.subi %[[CAST]], %[[C1]] : i32
  // CHECK: %[[SEL:.+]] = arith.select %[[CMP]], %[[C0]], %[[SUB]] : i32
  // CHECK: %[[MUL:.+]] = arith.muli %[[EX2]], %[[SEL]] : i32
  // CHECK: %[[ADD0:.+]] = arith.addi %[[MUL]], %[[CAST]] : i32
  // CHECK: %[[ADD1:.+]] = arith.addi %[[ADD0]], %[[EX0]] : i32
  // CHECK: %[[ADD2:.+]] = arith.addi %[[ADD1]], %[[EX1]] : i32
  // CHECK: %[[EX3:.+]] = tensor.extract %arg1[%[[CI1]]]
  // CHECK: %[[EX4:.+]] = tensor.extract %arg2[%[[CI1]]]
  // CHECK: %[[ADD3:.+]] = arith.addi %[[EX3]], %[[EX4]] : i32
  // CHECK: %[[SHAPE:.+]] = tensor.from_elements %[[ADD2]], %[[ADD3]] : tensor<2xi32>
  // CHECK: %[[BROAD:.+]] = "stablehlo.dynamic_broadcast_in_dim"(%[[CST]], %[[SHAPE]]) <{broadcast_dimensions = dense<> : tensor<0xi64>}>
  %0 = arith.constant dense<0.0> : tensor<f32>
  %1 = "stablehlo.dynamic_pad"(%arg0, %0, %arg1, %arg2, %arg3) {
  } : (tensor<?x0xf32>, tensor<f32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<?x?xf32>
  // CHECK: return %[[BROAD]]
  func.return %1 : tensor<?x?xf32>
}



////////
// ConcatenateOp

// CHECK-LABEL: concatenate_const_1D
func.func @concatenate_const_1D() -> tensor<4xi32> {
  // CHECK: [[VAL:%.+]]= stablehlo.constant dense<[0, 1, 2, 3]>
  %0 = stablehlo.constant dense<[0, 1]> : tensor<2xi32>
  %1 = stablehlo.constant dense<[2, 3]> : tensor<2xi32>
  %2 = "stablehlo.concatenate"(%0, %1) <{ dimension = 0 : i64 }> : (tensor<2xi32>, tensor<2xi32>) -> tensor<4xi32>

  // CHECK: return [[VAL]]
  func.return %2 : tensor<4xi32>
}

// CHECK-LABEL: concatenate_const_1D_float
func.func @concatenate_const_1D_float() -> tensor<4xf32> {
  // CHECK: [[VAL:%.+]] = stablehlo.constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]>

  %0 = stablehlo.constant dense<[0.0, 1.0]> : tensor<2xf32>
  %1 = stablehlo.constant dense<[2.0, 3.0]> : tensor<2xf32>
  %2 = "stablehlo.concatenate"(%0, %1) <{ dimension = 0 : i64 }> : (tensor<2xf32>, tensor<2xf32>) -> tensor<4xf32>

  // CHECK: return [[VAL]]
  func.return %2 : tensor<4xf32>
}

// CHECK-LABEL: concatenate_const_2D_vertical
func.func @concatenate_const_2D_vertical() -> tensor<2x2xi32> {
  // CHECK: [[VAL:%.+]]= stablehlo.constant dense<[
  // CHECK-SAME: [0, 1], [2, 3]
  // CHECK-SAME: ]>
  %0 = stablehlo.constant dense<[[0, 1]]> : tensor<1x2xi32>
  %1 = stablehlo.constant dense<[[2, 3]]> : tensor<1x2xi32>
  %2 = "stablehlo.concatenate"(%0, %1) <{ dimension = 0 : i64 }> : (tensor<1x2xi32>, tensor<1x2xi32>) -> tensor<2x2xi32>

  // CHECK: return [[VAL]]
  func.return %2 : tensor<2x2xi32>
}

// CHECK-LABEL: concatenate_const_2D_horizontal
func.func @concatenate_const_2D_horizontal() -> tensor<2x2xi32> {
  // CHECK: [[VAL:%.+]]= stablehlo.constant dense<[
  // CHECK-SAME: [0, 2], [1, 3]
  // CHECK-SAME: ]>
  %0 = stablehlo.constant dense<[[0], [1]]> : tensor<2x1xi32>
  %1 = stablehlo.constant dense<[[2], [3]]> : tensor<2x1xi32>
  %2 = "stablehlo.concatenate"(%0, %1) <{ dimension = 1 : i64 }> : (tensor<2x1xi32>, tensor<2x1xi32>) -> tensor<2x2xi32>

  // CHECK: return [[VAL]]
  func.return %2 : tensor<2x2xi32>
}

////////
// Tensor/Shape canonicalize

// CHECK-LABEL: concatenate_noop_typecast
func.func @concatenate_noop_typecast(%arg0: tensor<?xi32>) -> tensor<4xi32> {
  // CHECK-SAME: [[ARG:%.+]]: tensor<?xi32>
  // CHECK-NEXT: [[RES:%.+]] = tensor.cast [[ARG]] : tensor<?xi32> to tensor<4xi32>
  %0 = "stablehlo.concatenate"(%arg0) <{ dimension = 0 : i64 }> : (tensor<?xi32>) -> tensor<4xi32>

  // CHECK: return [[RES]]
  func.return %0 : tensor<4xi32>
}
```
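
For contrast, a pattern such as concatenate(concatenate(X, Y), Z) -> concatenate(X, Y, Z) stays entirely within StableHLO, so it did move over. A sketch of what such a test looks like (hypothetical function name, not copied from the ported file):

```
// CHECK-LABEL: func @concatenate_of_concatenate
func.func @concatenate_of_concatenate(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>, %arg2: tensor<2xi32>) -> tensor<6xi32> {
  // CHECK: stablehlo.concatenate{{.*}}%arg0, %arg1, %arg2
  %0 = "stablehlo.concatenate"(%arg0, %arg1) <{ dimension = 0 : i64 }> : (tensor<2xi32>, tensor<2xi32>) -> tensor<4xi32>
  %1 = "stablehlo.concatenate"(%0, %arg2) <{ dimension = 0 : i64 }> : (tensor<4xi32>, tensor<2xi32>) -> tensor<6xi32>
  func.return %1 : tensor<6xi32>
}
```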

@GleasonK (Member, Author)

Current list of simplifications in #2608:

```
add(X, 0) -> X
add(cst, X) -> add(X, cst)
add(cst,cst) -> cst
and(X, 0) -> 0
and(X, 1) -> X
and(cst, X) -> and(X, cst)
broadcast_in_dim(X, [dims...]) -> transpose(X, [dims...]) [if same numel & rank]
broadcast_in_dim(X, [iota...]) -> X
broadcast_in_dim(X, [sorted...]) -> reshape(X, [sorted...]) [if same numel]
broadcast_in_dim(broadcast_in_dim(X, [dimsA...]), [dimsB...]) -> broadcast_in_dim(X, merge(dimsA, dimsB))
broadcast_in_dim(splat, _) -> constant(splat)
compare(X, X, [EQ,GE,LE]) -> true
compare(X, X, [NE,GT,LT]) -> false
compare(cst, X, comparator) -> compare(X, cst, inv(comparator))
complex(real(X), imag(X)) -> X
concatenate(X) -> X
concatenate(X, Y, []) -> concatenate(X, Y)
concatenate(concatenate(X, Y), Z) -> concatenate(X, Y, Z)
convert(X, [X.type]) -> X
dynamic_broadcast_in_dim(X, _, _, [all_nonexpanding...]) -> convert(X)
dynamic_broadcast_in_dim(X, shape_of(X)) -> X
dynamic_broadcast_in_dim(dynamic_broadcast_in_dim(X, _, [dimsA...]), shape, [dimsB...]) -> dynamic_broadcast_in_dim(X, shape, merge(dimsA, dimsB))
dynamic_broadcast_in_dim(dynamic_reshape(X, shape), shape) -> dynamic_reshape(X, shape)
dynamic_gather(x, constant(slice_sizes)) -> gather(x, slice_sizes)
dynamic_iota(shape, dim) -> dynamic_broadcast_in_dim(dynamic_iota(slice(shape), dim), shape)
dynamic_pad(X, low, high, interior) -> pad(X, low, high, interior) [if low, high, interior are all constants]
dynamic_reshape(dynamic_reshape(X, _), shape) -> dynamic_reshape(X, shape)
dynamic_reshape(op(dynamic_reshape(X, shape)), shape) -> op(dynamic_reshape(X, shape)) [if op has same operand and result shape]
dynamic_slice(X, begin, slice_sizes) -> slice(X, begin, slice_sizes)
dynamic_update_slice(X, update : zero_extent) -> X
dynamic_update_slice(X, update, start_indices : zero) -> update
gather(X, cst_start_indices) -> slice(X, slice_start, slice_end)
get_dimension_size(X, i) -> X.shape[i]
get_tuple_element(tuple(X_0, X_1, ...), i) -> X_i
imag(complex(R,I)) -> I
iota(dim) : multi_rank -> broadcast_in_dim(iota(dim) : array, multi_rank)
iota(dim) : type -> constant(0) : type [if type[dim] == 1]
max(cst, X) -> max(X, cst)
minimum(cst, X) -> minimum(X, cst)
multiply(X, 0i) -> 0i
multiply(X, 1i) -> X
multiply(cst, X) -> multiply(X, cst)
or(X, 0) -> X
or(X, 1) -> 1
or(cst, X) -> or(X, cst)
pad(empty_tensor, _) -> broadcast_in_dim(empty_tensor, _)
real(complex(R,I)) -> R
real_dynamic_slice(X, start, limit, strides) -> dynamic_slice(X, start, limit, strides) [if strides, start are constants, limit = start + constant]
real_dynamic_slice(X, start, limit, strides) -> slice(X, start, limit, strides) [if start, limit, strides are all constants]
reduce(X..., dims=[], add) -> X...
reduce(empty_0, empty_1, ...) -> [broadcast_in_dim(empty_i)...]
reduce(in_1, in_2, _, _) -> reduce(in_1, _, _) [if unused(in_2)]
reduce[A](_, _, fn:return A) -> A...
reshape(X, [X.shape]) -> X
reshape(cst, shape) -> cst
reshape(reshape(X, _), [shape]) -> reshape(X, [shape])
select(broadcast(not(p)), t, f) => select(broadcast(p), f, t)
select(not(p), t, f) => select(p, f, t)
shape_of(dynamic_reshape(X, shape)) -> shape
slice(X, [A:A], [B:B], ...) -> X
slice(concat(X,Y,Z,...),...) -> concat(slice(X),slice(Y),slice(Z))
sort(X) -> sort(X, dim = N) [when dim can be inferred]
sort(X,Y) -> sort(X) [if Y unused and unused in comparator]
subtract(X, 0) -> X
subtract(X, X) -> 0
transpose(X, [iota...]) -> X
transpose(X, [no_mem_layout_change...]) -> reshape(X)
tuple(get_tuple_element(X, 0), get_tuple_element(X, 1), ...) -> X
while -> while (loop invariants as implicit captures)
xor(cst, X) -> xor(X, cst)
op(X : zero_extent_tensor) -> constant([])
```
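
To make a couple of the entries above concrete, here is roughly what compare(X, X, EQ) -> true and subtract(X, X) -> 0 look like in IR (an illustrative sketch, not one of the ported tests):

```
func.func @self_compare_and_subtract(%arg0: tensor<4xi32>) -> (tensor<4xi1>, tensor<4xi32>) {
  // Folds to a constant dense<true> : tensor<4xi1>.
  %0 = stablehlo.compare EQ, %arg0, %arg0 : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
  // Folds to a constant dense<0> : tensor<4xi32>.
  %1 = stablehlo.subtract %arg0, %arg0 : tensor<4xi32>
  func.return %0, %1 : tensor<4xi1>, tensor<4xi32>
}
```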

GleasonK added a commit that referenced this issue Oct 30, 2024
Porting over / re-organizing useful simplifications from MHLO to
StableHLO for broader community use.

This includes all tests from
[xla/mlir_hlo/tests/Dialect/mhlo/canonicalize/canonicalize.mlir](https://github.com/openxla/xla/blob/2c4b82cab1679273044188da5de780ec8f0eefad/xla/mlir_hlo/tests/Dialect/mhlo/canonicalize/canonicalize.mlir#L4)

Closes #2607
