Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add MHLO simplifications to StableHLO (#2608)
Porting over / re-organizing useful simplifications from MHLO to StableHLO for broader community use. This includes all tests from [xla/mlir_hlo/tests/Dialect/mhlo/canonicalize/canonicalize.mlir](https://github.com/openxla/xla/blob/2c4b82cab1679273044188da5de780ec8f0eefad/xla/mlir_hlo/tests/Dialect/mhlo/canonicalize/canonicalize.mlir#L4) Closes #2607 Overview of simplifications now in StableHLO: ``` add(X, 0) -> X add(cst, X) -> add(X, cst) add(cst,cst) -> cst and(X, 0) -> 0 and(X, 1) -> X and(cst, X) -> and(X, cst) broadcast_in_dim(X, [dims...]) -> transpose(X, [dims...]) [if same numel & rank] broadcast_in_dim(X, [iota...]) -> X broadcast_in_dim(X, [sorted...]) -> reshape(X, [sorted...]) [if same numel] broadcast_in_dim(broadcast_in_dim(X, [dimsA...]), [dimsB...]) -> broadcast_in_dim(X, merge(dimsA, dimsB)) broadcast_in_dim(splat, _) -> constant(splat) compare(X, X, [EQ,GE,LE]) -> true compare(X, X, [NE,GT,LT]) -> false compare(cst, X, comparator) -> compare(X, cst, inv(comparator)) complex(real(X), imag(X))) -> X concatenate(X) -> X concatenate(X, Y, []) -> concatenate(X, Y) concatenate(concatenate(X, Y), Z) -> concatenate(X, Y, Z) convert(X, [X.type]) -> X dynamic_broadcast_in_dim(X, _, _, [all_nonexpanding...]) -> convert(X) dynamic_broadcast_in_dim(X, shape_of(X)) -> X dynamic_broadcast_in_dim(dynamic_broadcast_in_dim(X, _, [dimsA...]), shape, [dimsB...]) -> dynamic_broadcast_in_dim(X, shape, merge(dimsA, dimsB)) dynamic_broadcast_in_dim(dynamic_reshape(X, shape), shape) -> dynamic_reshape(X, shape) dynamic_gather(x, constant(slice_sizes)) -> gather(x, slice_sizes) dynamic_iota(shape, dim) -> dynamic_broadcast_in_dim(dynamic_iota(slice(shape), dim), shape) dynamic_pad(X, low, high, interior) -> pad(X, low, high, interior) [if low, high, interior are all constants] dynamic_reshape(dynamic_reshape(X, _), shape)) -> dynamic_reshape(X, shape) dynamic_reshape(op(dynamic_reshape(X, shape)), shape) -> op(dynamic_reshape(X, shape)) [if op has same operand and result shape] dynamic_slice(X, begin, slice_sizes) -> slice(X, begin, slice_sizes) dynamic_update_slice(X, update : zero_extent)) -> X dynamic_update_slice(X, update, start_indices : zero)) -> update gather(X, cst_start_indices) -> slice(X, slice_start, slice_end) get_dimension_size(X, i) -> X.shape[i] get_tuple_element(tuple(X_0, X_1, ...), i) -> X_i imag(complex(R,I)) -> I iota(dim) : multi_rank -> broadcast_in_dim(iota(dim) : array, multi_rank) iota(dim) : type -> constant(0) : type [if type[dim] == 1] max(cst, X) -> max(X, cst) minimum(cst, X) -> minimum(X, cst) multiply(X, 0i) -> 0i multiply(X, 1i) -> X multiply(cst, X) -> multiply(X, cst) or(X, 0) -> X or(X, 1) -> 1 or(cst, X) -> or(X, cst) pad(empty_tensor, _) -> broadcast_in_dim(empty_tensor, _) real(complex(R,I)) -> X real_dynamic_slice(X, start, limit, strides) -> dynamic_slice(X, start, limit, strides) [if strides, start are constants, limit = start + constant] real_dynamic_slice(X, start, limit, strides) -> slice(X, start, limit, strides) [if start, limit, strides are all constants] reduce(X..., dims=[], add) -> X... reduce(empty_0, empty_1, ...) -> [broadcast_in_dim(empty_i)...] reduce(in_1, in_2, _, _) -> reduce(in_1, _, _) [if unused(in_2)] reduce[A](_, _, fn:return A) -> A... reshape(X, [X.shape]) -> X reshape(cst, shape) -> cst reshape(reshape(X, _), [shape]) -> reshape(X, [shape]) select(broadcast(not(p)), t, f) => select(broadcast(p), f, t) select(not(p), t, f) => select(p, f, t) shape_of(dynamic_reshape(X, shape)) -> shape slice(X, [A:A], [B:B], ...) -> X slice(concat(X,Y,Z,...),...) -> concat(slice(X),slice(Y),slice(Z)) sort(X) -> sort(X, dim = N) [when dim can be inferred] sort(X,Y) -> sort(X) [if Y unused and unused in comparator] subtract(X, 0) -> X subtract(X, X) -> 0 transpose(X, [iota...]) -> X transpose(X, [no_mem_layout_change...]) -> reshape(X) tuple(get_tuple_element(X, 0), get_tuple_element(X, 1), ...) -> X while -> while (loop invariants as implicit captures) xor(cst, X) -> xor(X, cst) op(X : zero_extent_tensor) -> constant([]) ```
- Loading branch information