Skip to content

Commit

Permalink
Add MHLO simplifications to StableHLO (#2608)
Browse files Browse the repository at this point in the history
Porting over / re-organizing useful simplifications from MHLO to
StableHLO for broader community use.

This includes all tests from
[xla/mlir_hlo/tests/Dialect/mhlo/canonicalize/canonicalize.mlir](https://github.com/openxla/xla/blob/2c4b82cab1679273044188da5de780ec8f0eefad/xla/mlir_hlo/tests/Dialect/mhlo/canonicalize/canonicalize.mlir#L4)

Closes #2607

Overview of simplifications now in StableHLO:

```
add(X, 0) -> X
add(cst, X) -> add(X, cst)
add(cst,cst) -> cst
and(X, 0) -> 0
and(X, 1) -> X
and(cst, X) -> and(X, cst)
broadcast_in_dim(X, [dims...]) -> transpose(X, [dims...]) [if same numel & rank]
broadcast_in_dim(X, [iota...]) -> X
broadcast_in_dim(X, [sorted...]) -> reshape(X, [sorted...]) [if same numel]
broadcast_in_dim(broadcast_in_dim(X, [dimsA...]), [dimsB...]) -> broadcast_in_dim(X, merge(dimsA, dimsB))
broadcast_in_dim(splat, _) -> constant(splat)
compare(X, X, [EQ,GE,LE]) -> true
compare(X, X, [NE,GT,LT]) -> false
compare(cst, X, comparator) -> compare(X, cst, inv(comparator))
complex(real(X), imag(X))) -> X
concatenate(X) -> X
concatenate(X, Y, []) -> concatenate(X, Y)
concatenate(concatenate(X, Y), Z) -> concatenate(X, Y, Z)
convert(X, [X.type]) -> X
dynamic_broadcast_in_dim(X, _, _, [all_nonexpanding...]) -> convert(X)
dynamic_broadcast_in_dim(X, shape_of(X)) -> X
dynamic_broadcast_in_dim(dynamic_broadcast_in_dim(X, _, [dimsA...]), shape, [dimsB...]) -> dynamic_broadcast_in_dim(X, shape, merge(dimsA, dimsB))
dynamic_broadcast_in_dim(dynamic_reshape(X, shape), shape) -> dynamic_reshape(X, shape)
dynamic_gather(x, constant(slice_sizes)) -> gather(x, slice_sizes)
dynamic_iota(shape, dim) -> dynamic_broadcast_in_dim(dynamic_iota(slice(shape), dim), shape)
dynamic_pad(X, low, high, interior) -> pad(X, low, high, interior) [if low, high, interior are all constants]
dynamic_reshape(dynamic_reshape(X, _), shape)) -> dynamic_reshape(X, shape)
dynamic_reshape(op(dynamic_reshape(X, shape)), shape) -> op(dynamic_reshape(X, shape)) [if op has same operand and result shape]
dynamic_slice(X, begin, slice_sizes) -> slice(X, begin, slice_sizes)
dynamic_update_slice(X, update : zero_extent)) -> X
dynamic_update_slice(X, update, start_indices : zero)) -> update
gather(X, cst_start_indices) -> slice(X, slice_start, slice_end)
get_dimension_size(X, i) -> X.shape[i]
get_tuple_element(tuple(X_0, X_1, ...), i) -> X_i
imag(complex(R,I)) -> I
iota(dim) : multi_rank -> broadcast_in_dim(iota(dim) : array, multi_rank)
iota(dim) : type -> constant(0) : type [if type[dim] == 1]
max(cst, X) -> max(X, cst)
minimum(cst, X) -> minimum(X, cst)
multiply(X, 0i) -> 0i
multiply(X, 1i) -> X
multiply(cst, X) -> multiply(X, cst)
or(X, 0) -> X
or(X, 1) -> 1
or(cst, X) -> or(X, cst)
pad(empty_tensor, _) -> broadcast_in_dim(empty_tensor, _)
real(complex(R,I)) -> X
real_dynamic_slice(X, start, limit, strides) -> dynamic_slice(X, start, limit, strides) [if strides, start are constants, limit = start + constant]
real_dynamic_slice(X, start, limit, strides) -> slice(X, start, limit, strides) [if start, limit, strides are all constants]
reduce(X..., dims=[], add) -> X...
reduce(empty_0, empty_1, ...) -> [broadcast_in_dim(empty_i)...]
reduce(in_1, in_2, _, _) -> reduce(in_1, _, _) [if unused(in_2)]
reduce[A](_, _, fn:return A) -> A...
reshape(X, [X.shape]) -> X
reshape(cst, shape) -> cst
reshape(reshape(X, _), [shape]) -> reshape(X, [shape])
select(broadcast(not(p)), t, f) => select(broadcast(p), f, t)
select(not(p), t, f) => select(p, f, t)
shape_of(dynamic_reshape(X, shape)) -> shape
slice(X, [A:A], [B:B], ...) -> X
slice(concat(X,Y,Z,...),...) -> concat(slice(X),slice(Y),slice(Z))
sort(X) -> sort(X, dim = N) [when dim can be inferred]
sort(X,Y) -> sort(X) [if Y unused and unused in comparator]
subtract(X, 0) -> X
subtract(X, X) -> 0
transpose(X, [iota...]) -> X
transpose(X, [no_mem_layout_change...]) -> reshape(X)
tuple(get_tuple_element(X, 0), get_tuple_element(X, 1), ...) -> X
while -> while (loop invariants as implicit captures)
xor(cst, X) -> xor(X, cst)
op(X : zero_extent_tensor) -> constant([])
```
  • Loading branch information
GleasonK authored Oct 30, 2024
1 parent 27c1081 commit 5d15ab0
Show file tree
Hide file tree
Showing 4 changed files with 1,528 additions and 207 deletions.
Loading

0 comments on commit 5d15ab0

Please sign in to comment.