Add more StableHLO Simplification Patterns #2607
Ported everything from `xla/mlir_hlo/tests/Dialect/mhlo/canonicalize/canonicalize.mlir` except for the following tests, since they aren't pure StableHLO->StableHLO simplifications. Some of these probably merit being added to the StableHLO canonicalizations, but others likely belong in a later phase of compilation.
Current list of simplifications in #2608:
GleasonK added a commit that referenced this issue on Oct 30, 2024:
Porting over / re-organizing useful simplifications from MHLO to StableHLO for broader community use. This includes all tests from [xla/mlir_hlo/tests/Dialect/mhlo/canonicalize/canonicalize.mlir](https://github.com/openxla/xla/blob/2c4b82cab1679273044188da5de780ec8f0eefad/xla/mlir_hlo/tests/Dialect/mhlo/canonicalize/canonicalize.mlir#L4).

Closes #2607

Overview of simplifications now in StableHLO:

```
add(X, 0) -> X
add(cst, X) -> add(X, cst)
add(cst,cst) -> cst
and(X, 0) -> 0
and(X, 1) -> X
and(cst, X) -> and(X, cst)
broadcast_in_dim(X, [dims...]) -> transpose(X, [dims...]) [if same numel & rank]
broadcast_in_dim(X, [iota...]) -> X
broadcast_in_dim(X, [sorted...]) -> reshape(X, [sorted...]) [if same numel]
broadcast_in_dim(broadcast_in_dim(X, [dimsA...]), [dimsB...]) -> broadcast_in_dim(X, merge(dimsA, dimsB))
broadcast_in_dim(splat, _) -> constant(splat)
compare(X, X, [EQ,GE,LE]) -> true
compare(X, X, [NE,GT,LT]) -> false
compare(cst, X, comparator) -> compare(X, cst, inv(comparator))
complex(real(X), imag(X)) -> X
concatenate(X) -> X
concatenate(X, Y, []) -> concatenate(X, Y)
concatenate(concatenate(X, Y), Z) -> concatenate(X, Y, Z)
convert(X, [X.type]) -> X
dynamic_broadcast_in_dim(X, _, _, [all_nonexpanding...]) -> convert(X)
dynamic_broadcast_in_dim(X, shape_of(X)) -> X
dynamic_broadcast_in_dim(dynamic_broadcast_in_dim(X, _, [dimsA...]), shape, [dimsB...]) -> dynamic_broadcast_in_dim(X, shape, merge(dimsA, dimsB))
dynamic_broadcast_in_dim(dynamic_reshape(X, shape), shape) -> dynamic_reshape(X, shape)
dynamic_gather(x, constant(slice_sizes)) -> gather(x, slice_sizes)
dynamic_iota(shape, dim) -> dynamic_broadcast_in_dim(dynamic_iota(slice(shape), dim), shape)
dynamic_pad(X, low, high, interior) -> pad(X, low, high, interior) [if low, high, interior are all constants]
dynamic_reshape(dynamic_reshape(X, _), shape) -> dynamic_reshape(X, shape)
dynamic_reshape(op(dynamic_reshape(X, shape)), shape) -> op(dynamic_reshape(X, shape)) [if op has same operand and result shape]
dynamic_slice(X, begin, slice_sizes) -> slice(X, begin, slice_sizes)
dynamic_update_slice(X, update : zero_extent) -> X
dynamic_update_slice(X, update, start_indices : zero) -> update
gather(X, cst_start_indices) -> slice(X, slice_start, slice_end)
get_dimension_size(X, i) -> X.shape[i]
get_tuple_element(tuple(X_0, X_1, ...), i) -> X_i
imag(complex(R,I)) -> I
iota(dim) : multi_rank -> broadcast_in_dim(iota(dim) : array, multi_rank)
iota(dim) : type -> constant(0) : type [if type[dim] == 1]
max(cst, X) -> max(X, cst)
minimum(cst, X) -> minimum(X, cst)
multiply(X, 0i) -> 0i
multiply(X, 1i) -> X
multiply(cst, X) -> multiply(X, cst)
or(X, 0) -> X
or(X, 1) -> 1
or(cst, X) -> or(X, cst)
pad(empty_tensor, _) -> broadcast_in_dim(empty_tensor, _)
real(complex(R,I)) -> R
real_dynamic_slice(X, start, limit, strides) -> dynamic_slice(X, start, limit, strides) [if strides, start are constants, limit = start + constant]
real_dynamic_slice(X, start, limit, strides) -> slice(X, start, limit, strides) [if start, limit, strides are all constants]
reduce(X..., dims=[], add) -> X...
reduce(empty_0, empty_1, ...) -> [broadcast_in_dim(empty_i)...]
reduce(in_1, in_2, _, _) -> reduce(in_1, _, _) [if unused(in_2)]
reduce[A](_, _, fn:return A) -> A...
reshape(X, [X.shape]) -> X
reshape(cst, shape) -> cst
reshape(reshape(X, _), [shape]) -> reshape(X, [shape])
select(broadcast(not(p)), t, f) -> select(broadcast(p), f, t)
select(not(p), t, f) -> select(p, f, t)
shape_of(dynamic_reshape(X, shape)) -> shape
slice(X, [A:A], [B:B], ...) -> X
slice(concat(X,Y,Z,...),...) -> concat(slice(X),slice(Y),slice(Z))
sort(X) -> sort(X, dim = N) [when dim can be inferred]
sort(X,Y) -> sort(X) [if Y unused and unused in comparator]
subtract(X, 0) -> X
subtract(X, X) -> 0
transpose(X, [iota...]) -> X
transpose(X, [no_mem_layout_change...]) -> reshape(X)
tuple(get_tuple_element(X, 0), get_tuple_element(X, 1), ...) -> X
while -> while (loop invariants as implicit captures)
xor(cst, X) -> xor(X, cst)
op(X : zero_extent_tensor) -> constant([])
```
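The `merge(dimsA, dimsB)` and `[iota...]` notation in the `broadcast_in_dim` rules is shorthand for some dimension bookkeeping. As a rough illustration, here is a small Python sketch that works on plain shape tuples, assuming the StableHLO spec semantics for `broadcast_in_dim` (operand dimension `i` maps to result dimension `broadcast_dimensions[i]`) and `transpose` (result dimension `d` reads operand dimension `permutation[d]`). The helper names are hypothetical; this is not the C++ pattern code in StableHLO.

```python
from typing import List, Optional, Sequence, Tuple

# Hypothetical helpers operating on plain shape tuples, not StableHLO's C++ API.
Shape = Tuple[int, ...]


def merge_broadcast_dims(dims_a: Sequence[int], dims_b: Sequence[int]) -> List[int]:
    """Dims for broadcast_in_dim(broadcast_in_dim(X, dims_a), dims_b) as one op.

    Operand dim i of X lands on inner dim dims_a[i], which the outer
    broadcast then places on result dim dims_b[dims_a[i]].
    """
    return [dims_b[d] for d in dims_a]


def is_identity_broadcast(operand_shape: Shape, dims: Sequence[int],
                          result_shape: Shape) -> bool:
    """broadcast_in_dim(X, [iota...]) -> X, only if the shape is unchanged."""
    return (list(dims) == list(range(len(operand_shape)))
            and tuple(operand_shape) == tuple(result_shape))


def broadcast_as_transpose(operand_shape: Shape, dims: Sequence[int],
                           result_shape: Shape) -> Optional[List[int]]:
    """Transpose permutation if the broadcast only reorders dims, else None."""
    rank = len(operand_shape)
    if len(result_shape) != rank or sorted(dims) != list(range(rank)):
        return None
    # Sizes must match dim-for-dim; a size-1 dim growing to N is a real
    # broadcast and would change the element count.
    if any(operand_shape[i] != result_shape[d] for i, d in enumerate(dims)):
        return None
    perm = [0] * rank
    for i, d in enumerate(dims):
        perm[d] = i  # transpose: result dim d reads operand dim perm[d]
    return perm


def transpose_shape(shape: Shape, perm: Sequence[int]) -> Shape:
    """Result shape of transpose(X, perm): dim d has size shape[perm[d]]."""
    return tuple(shape[p] for p in perm)


# X : (4,) -> (2, 4) via dims [1] -> (5, 2, 4) via dims [1, 2].
print(merge_broadcast_dims([1], [1, 2]))                      # [2]
perm = broadcast_as_transpose((2, 3, 4), [1, 2, 0], (4, 2, 3))
print(perm, transpose_shape((2, 3, 4), perm))                  # [2, 0, 1] (4, 2, 3)
```

Under the semantics assumed here, the transpose permutation comes out as the inverse of the broadcast dimensions (for `[1, 2, 0]` it is `[2, 0, 1]`); the two only coincide when the permutation is its own inverse.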
Request description
There are loads of simplification patterns in MHLO, most of which would be very useful to have in StableHLO. This ticket tracks the porting of these simplification patterns into the `StablehloAggressiveSimplification` pass.
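To give a feel for what applying a batch of such patterns involves, here is a toy Python sketch: a few of the listed rules (constant folding of `add`/`multiply`, moving a constant to the right-hand side, `add(X, 0) -> X`, `multiply(X, 1i) -> X`, `subtract(X, X) -> 0`) are applied bottom-up until nothing changes, loosely in the spirit of MLIR's greedy pattern rewriting. The expression classes and rule functions are hypothetical, terms are assumed to be integer-valued (an identity like `subtract(X, X) -> 0` does not hold for floating-point NaN or infinity), and none of this is StableHLO's actual API.

```python
from dataclasses import dataclass
from typing import Callable, List, Optional, Tuple, Union


# Hypothetical toy IR: integer constants, named values, and n-ary ops.
@dataclass(frozen=True)
class Const:
    value: int


@dataclass(frozen=True)
class Var:
    name: str


@dataclass(frozen=True)
class Op:
    name: str
    args: Tuple["Expr", ...]


Expr = Union[Const, Var, Op]
Rule = Callable[[Expr], Optional[Expr]]  # returns None when it doesn't match


def fold_consts(e: Expr) -> Optional[Expr]:  # add(cst, cst) -> cst (and multiply)
    if isinstance(e, Op) and e.name in ("add", "multiply") and all(
            isinstance(a, Const) for a in e.args):
        a, b = e.args[0].value, e.args[1].value
        return Const(a + b if e.name == "add" else a * b)
    return None


def const_to_rhs(e: Expr) -> Optional[Expr]:  # add(cst, X) -> add(X, cst)
    if isinstance(e, Op) and e.name in ("add", "multiply"):
        lhs, rhs = e.args
        if isinstance(lhs, Const) and not isinstance(rhs, Const):
            return Op(e.name, (rhs, lhs))
    return None


def add_zero(e: Expr) -> Optional[Expr]:  # add(X, 0) -> X
    if isinstance(e, Op) and e.name == "add" and e.args[1] == Const(0):
        return e.args[0]
    return None


def mul_one(e: Expr) -> Optional[Expr]:  # multiply(X, 1) -> X
    if isinstance(e, Op) and e.name == "multiply" and e.args[1] == Const(1):
        return e.args[0]
    return None


def sub_self(e: Expr) -> Optional[Expr]:  # subtract(X, X) -> 0, integers only
    if isinstance(e, Op) and e.name == "subtract" and e.args[0] == e.args[1]:
        return Const(0)
    return None


RULES: List[Rule] = [fold_consts, const_to_rhs, add_zero, mul_one, sub_self]


def rewrite_once(e: Expr, rules: List[Rule]) -> Tuple[Expr, bool]:
    """Simplify children first, then try each rule at this node."""
    changed = False
    if isinstance(e, Op):
        new_args = []
        for a in e.args:
            na, c = rewrite_once(a, rules)
            new_args.append(na)
            changed |= c
        e = Op(e.name, tuple(new_args))
    for rule in rules:
        out = rule(e)
        if out is not None:
            return out, True
    return e, changed


def simplify(e: Expr, rules: List[Rule] = RULES, max_iters: int = 100) -> Expr:
    """Re-run the rules until a fixed point (or the iteration cap) is reached."""
    for _ in range(max_iters):
        e, changed = rewrite_once(e, rules)
        if not changed:
            break
    return e


x, y = Var("x"), Var("y")
print(simplify(Op("add", (Op("add", (Const(2), Const(-2))), x))))  # Var(name='x')
print(simplify(Op("subtract", (Op("multiply", (Const(1), y)), y))))  # Const(value=0)
```

Moving constants to the right-hand side first is what lets a single `add(X, 0)` rule also cover `add(0, X)`; the iteration cap is only a guard in case a rule set fails to terminate.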