Lowering stablehlo.custom_call to linalg #2680
You're right, the … In any case, can you please tell me a bit more about your specific scenario? For example, is it always possible to identify the target function via …
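On the question of identifying the target: every `stablehlo.custom_call` carries the external function it dispatches to in its `call_target_name`. A quick way to see which external functions an exported module depends on is to scan the printed IR for those names. This is a minimal sketch, assuming the module text comes from `jax.export`'s `mlir_module()`; the helper `custom_call_targets` is hypothetical, and a regex over printed IR is a heuristic, not a real IR parser.

```python
import re

import jax
import jax.numpy as jnp
import jax.export

# Heuristic pattern (an assumption, not a real IR parser). It covers the
# pretty form `stablehlo.custom_call @target(...)` and the generic form
# `"stablehlo.custom_call"(...) {call_target_name = "target", ...}`.
_TARGET_RE = re.compile(
    r'stablehlo\.custom_call\s+@([\w.$]+)|call_target_name\s*=\s*"([^"]+)"'
)


def custom_call_targets(mlir_text):
    """Hypothetical helper: sorted, de-duplicated custom_call targets."""
    # findall yields one (pretty, generic) tuple per match; one group is empty.
    return sorted({p or g for p, g in _TARGET_RE.findall(mlir_text)})


def largest_evec(m):
    _, eigvecs = jnp.linalg.eigh(m)
    return eigvecs[..., -1]


a = jnp.array([[1, -2j], [2j, 1]])
text = jax.export.export(jax.jit(largest_evec), platforms=["cpu"])(a).mlir_module()
print(custom_call_targets(text))
```

On CPU this should list a LAPACK routine; per the original report, ops like these are the ones `--stablehlo-legalize-to-linalg` leaves unconverted.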
I'm lowering StableHLO exported from AF2 to Linalg and encountered an external call to …
Hi @endaiHW, are these … Something like:
cc @GleasonK
Hmm. Once you're in StableHLO with custom_calls, lowering to another dialect isn't typical. In this case it seems that AF2 assumes there is an LAPACK library to dispatch to for some of its math, and I'm not sure how that would be represented in Linalg. I wonder if there's a way to export the model without the LAPACK dependency and use a reference implementation of some sort. Here's a shorter repro:

```python
import jax
import jax.numpy as jnp
import jax.export

def largest_evec(m):
    # eigh returns eigenvalues in ascending order, so the last column
    # is the eigenvector of the largest eigenvalue.
    _, eigvecs = jnp.linalg.eigh(m)
    return eigvecs[..., -1]

a = jnp.array([[1, -2j], [2j, 1]])
module = jax.export.export(jax.jit(largest_evec), platforms=['cpu'])(a).mlir_module()
print(module)
# ...
# %8:3 = stablehlo.custom_call @lapack_cheevd_ffi(%7)
```

Could you say more about what your goal is? Are you intending to compile for a different platform that doesn't have LAPACK? Is the ideal case a reference JAX implementation?
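To illustrate the "reference implementation" idea for this particular repro: if only the dominant eigenvector is needed, it can be computed with plain JAX ops, and the exported StableHLO then contains a loop and matmuls instead of a LAPACK custom_call. This is a sketch under stated assumptions, not what AF2 does: it swaps in power iteration (a different algorithm from LAPACK's `cheevd`), the name `dominant_evec` is hypothetical, it converges to the eigenvector of the largest-*magnitude* eigenvalue (which equals `eigh`'s last column only when that eigenvalue is also the largest algebraically), and it assumes the all-ones start vector is not orthogonal to that eigenvector.

```python
import jax
import jax.numpy as jnp
import jax.export


def dominant_evec(m, num_iters=100):
    """Hypothetical LAPACK-free sketch: power iteration on a Hermitian matrix.

    Uses only matmul, abs, sqrt and a loop, so no LAPACK custom_call is
    emitted. Converges to the eigenvector of the largest-magnitude eigenvalue,
    assuming the all-ones start vector has a component along it.
    """
    def step(_, v):
        w = m @ v
        # Renormalize every step so the iterate neither overflows nor vanishes.
        return w / jnp.sqrt(jnp.sum(jnp.abs(w) ** 2))

    v0 = jnp.ones(m.shape[-1], dtype=m.dtype)
    return jax.lax.fori_loop(0, num_iters, step, v0)


a = jnp.array([[1, -2j], [2j, 1]])
exported = jax.export.export(jax.jit(dominant_evec), platforms=["cpu"])(a)
text = exported.mlir_module()
print("lapack" in text)  # checks whether any LAPACK call survived export

v = jax.jit(dominant_evec)(a)
_, ref = jnp.linalg.eigh(a)
# Eigenvectors are unique only up to a phase, so compare |<ref, v>|.
print(jnp.abs(jnp.vdot(ref[..., -1], v)))
```

For this matrix the eigenvalues are 3 and -1, so the dominant one is also the largest and the two results should agree up to phase. A full replacement for `eigh` (all eigenpairs, arbitrary spectra) would need a different algorithm, such as Jacobi sweeps.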
Hi @GleasonK,
Is the conversion failing because it can't convert the custom calls? If so, we could mark custom_call as a legal op to unblock you. I'd argue that custom_call is more expressive than func for external library calls (for instance, it carries operand and result layouts as part of the op signature). Regardless of which op is used, some additional work will still be needed to call the library properly, so if you have any idea what the end-state IR should look like in your pipeline, that would be helpful. For example, if we represented this as a func.func, why would that be better than a custom_call?
Request description
I am trying to lower StableHLO to Linalg, but the --stablehlo-legalize-to-linalg pass does not seem to convert stablehlo.custom_call. Is there a suitable pass available to handle this?