Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add MHLO simplifications to StableHLO #2608

Merged
merged 4 commits into from
Oct 30, 2024

Conversation

GleasonK
Copy link
Member

@GleasonK GleasonK commented Oct 29, 2024

Porting over / re-organizing useful simplifications from MHLO to StableHLO for broader community use.

This includes all tests from xla/mlir_hlo/tests/Dialect/mhlo/canonicalize/canonicalize.mlir

Closes #2607

Overview of simplifications now in StableHLO:

add(X, 0) -> X
add(cst, X) -> add(X, cst)
add(cst,cst) -> cst
and(X, 0) -> 0
and(X, 1) -> X
and(cst, X) -> and(X, cst)
broadcast_in_dim(X, [dims...]) -> transpose(X, [dims...]) [if same numel & rank]
broadcast_in_dim(X, [iota...]) -> X
broadcast_in_dim(X, [sorted...]) -> reshape(X, [sorted...]) [if same numel]
broadcast_in_dim(broadcast_in_dim(X, [dimsA...]), [dimsB...]) -> broadcast_in_dim(X, merge(dimsA, dimsB))
broadcast_in_dim(splat, _) -> constant(splat)
compare(X, X, [EQ,GE,LE]) -> true
compare(X, X, [NE,GT,LT]) -> false
compare(cst, X, comparator) -> compare(X, cst, inv(comparator))
complex(real(X), imag(X))) -> X
concatenate(X) -> X
concatenate(X, Y, []) -> concatenate(X, Y)
concatenate(concatenate(X, Y), Z) -> concatenate(X, Y, Z)
convert(X, [X.type]) -> X
dynamic_broadcast_in_dim(X, _, _, [all_nonexpanding...]) -> convert(X)
dynamic_broadcast_in_dim(X, shape_of(X)) -> X
dynamic_broadcast_in_dim(dynamic_broadcast_in_dim(X, _, [dimsA...]), shape, [dimsB...]) -> dynamic_broadcast_in_dim(X, shape, merge(dimsA, dimsB))
dynamic_broadcast_in_dim(dynamic_reshape(X, shape), shape) -> dynamic_reshape(X, shape)
dynamic_gather(x, constant(slice_sizes)) -> gather(x, slice_sizes)
dynamic_iota(shape, dim) -> dynamic_broadcast_in_dim(dynamic_iota(slice(shape), dim), shape)
dynamic_pad(X, low, high, interior) -> pad(X, low, high, interior) [if low, high, interior are all constants]
dynamic_reshape(dynamic_reshape(X, _), shape)) -> dynamic_reshape(X, shape)
dynamic_reshape(op(dynamic_reshape(X, shape)), shape) -> op(dynamic_reshape(X, shape)) [if op has same operand and result shape]
dynamic_slice(X, begin, slice_sizes) -> slice(X, begin, slice_sizes)
dynamic_update_slice(X, update : zero_extent)) -> X
dynamic_update_slice(X, update, start_indices : zero)) -> update
gather(X, cst_start_indices) -> slice(X, slice_start, slice_end)
get_dimension_size(X, i) -> X.shape[i]
get_tuple_element(tuple(X_0, X_1, ...), i) -> X_i
imag(complex(R,I)) -> I
iota(dim) : multi_rank -> broadcast_in_dim(iota(dim) : array, multi_rank)
iota(dim) : type -> constant(0) : type [if type[dim] == 1]
max(cst, X) -> max(X, cst)
minimum(cst, X) -> minimum(X, cst)
multiply(X, 0i) -> 0i
multiply(X, 1i) -> X
multiply(cst, X) -> multiply(X, cst)
or(X, 0) -> X
or(X, 1) -> 1
or(cst, X) -> or(X, cst)
pad(empty_tensor, _) -> broadcast_in_dim(empty_tensor, _)
real(complex(R,I)) -> X
real_dynamic_slice(X, start, limit, strides) -> dynamic_slice(X, start, limit, strides) [if strides, start are constants, limit = start + constant]
real_dynamic_slice(X, start, limit, strides) -> slice(X, start, limit, strides) [if start, limit, strides are all constants]
reduce(X..., dims=[], add) -> X...
reduce(empty_0, empty_1, ...) -> [broadcast_in_dim(empty_i)...]
reduce(in_1, in_2, _, _) -> reduce(in_1, _, _) [if unused(in_2)]
reduce[A](_, _, fn:return A) -> A...
reshape(X, [X.shape]) -> X
reshape(cst, shape) -> cst
reshape(reshape(X, _), [shape]) -> reshape(X, [shape])
select(broadcast(not(p)), t, f) => select(broadcast(p), f, t)
select(not(p), t, f) => select(p, f, t)
shape_of(dynamic_reshape(X, shape)) -> shape
slice(X, [A:A], [B:B], ...) -> X
slice(concat(X,Y,Z,...),...) -> concat(slice(X),slice(Y),slice(Z))
sort(X) -> sort(X, dim = N) [when dim can be inferred]
sort(X,Y) -> sort(X) [if Y unused and unused in comparator]
subtract(X, 0) -> X
subtract(X, X) -> 0
transpose(X, [iota...]) -> X
transpose(X, [no_mem_layout_change...]) -> reshape(X)
tuple(get_tuple_element(X, 0), get_tuple_element(X, 1), ...) -> X
while -> while (loop invariants as implicit captures)
xor(cst, X) -> xor(X, cst)
op(X : zero_extent_tensor) -> constant([])

@GleasonK GleasonK requested a review from abhigunj October 29, 2024 21:02
@GleasonK GleasonK added the Transformations Pertaining to MLIR passes and transformations label Oct 29, 2024
@GleasonK GleasonK merged commit 5d15ab0 into openxla:main Oct 30, 2024
10 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Transformations Pertaining to MLIR passes and transformations
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Add more StableHLO Simplification Patterns
2 participants