Lowering stablehlo.custom_call to linalg #2680

Open
endaiHW opened this issue Dec 30, 2024 · 6 comments
Labels
Transformations Pertaining to MLIR passes and transformations


@endaiHW

endaiHW commented Dec 30, 2024

Request description

I am trying to lower StableHLO to Linalg, but the --stablehlo-legalize-to-linalg pass does not seem to convert stablehlo.custom_call. Is there a suitable pass to handle this?

@sdasgup3
Member

sdasgup3 commented Jan 2, 2025

You're right: the --stablehlo-legalize-to-linalg pass in MLIR doesn't directly handle stablehlo.custom_call operations. My best guess is that this is because custom_call ops are inherently implementation-defined and can represent arbitrary functions or external library calls. There's no general way to convert them to standard Linalg operations without knowing their specific implementation.

In any case, could you tell me a bit more about your specific scenario? For example, is it always possible to identify the target function via call_target_name?
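As a rough illustration of identifying targets by call_target_name, here is a minimal Python sketch. It assumes the module is available as MLIR text, for example the string printed from a JAX export. The helper name custom_call_targets is hypothetical. The regexes handle only bare symbol names such as @lapack_ssyevd_ffi and not quoted ones. In the pretty-printed form the target follows the @, while the generic form carries it in a call_target_name attribute.

```python
import re
from collections import Counter

# Pretty-printed form: stablehlo.custom_call @lapack_ssyevd_ffi(%3) ...
_PRETTY_TARGET = re.compile(r'stablehlo\.custom_call\s+@([\w$.]+)')
# Generic form: "stablehlo.custom_call"(%3) {call_target_name = "lapack_ssyevd_ffi", ...}
_GENERIC_TARGET = re.compile(r'call_target_name\s*=\s*"([^"]+)"')


def custom_call_targets(mlir_text: str) -> Counter:
  """Hypothetical helper: count each custom_call target in StableHLO text."""
  counts = Counter(_PRETTY_TARGET.findall(mlir_text))
  counts.update(_GENERIC_TARGET.findall(mlir_text))
  return counts


if __name__ == '__main__':
  ir = ('%4:3 = stablehlo.custom_call @lapack_ssyevd_ffi(%3) {mhlo.backend_config = '
        '{mode = 86 : ui8, uplo = 76 : ui8}} : (tensor<1x68x4x4xf32>) -> '
        '(tensor<1x68x4x4xf32>, tensor<1x68x4xf32>, tensor<1x68xi32>)')
  for target, n in custom_call_targets(ir).items():
    print(target, n)  # prints: lapack_ssyevd_ffi 1
```

An inventory like this shows which external targets, such as lapack_ssyevd_ffi, would need a matching library symbol.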

@endaiHW
Author

endaiHW commented Jan 6, 2025

I'm lowering StableHLO exported from AF2 to Linalg and encountered an external call to lapack_ssyevd_ffi. Can this be directly lowered to func.func, or are there additional steps needed to handle such external calls properly?
%4:3 = stablehlo.custom_call @lapack_ssyevd_ffi(%3) {mhlo.backend_config = {mode = 86 : ui8, uplo = 76 : ui8}, operand_layouts = [dense<[2, 3, 1, 0]> : tensor<4xindex>], output_operand_aliases = [#stablehlo.output_operand_alias<output_tuple_indices = [0], operand_index = 0, operand_tuple_indices = []>], result_layouts = [dense<[2, 3, 1, 0]> : tensor<4xindex>, dense<[2, 1, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>]} : (tensor<1x68x4x4xf32>) -> (tensor<1x68x4x4xf32>, tensor<1x68x4xf32>, tensor<1x68xi32>)

@sdasgup3
Member

sdasgup3 commented Jan 8, 2025

Hi @endaiHW

Are these custom_calls to the external LAPACK library expected to end up as llvm.call instructions, which are later linked against the implementation of lapack_ssyevd_ffi? If so, lowering them directly to func.func (plus func.call) during StableHLO-to-Linalg legalization makes sense here.

Something like:

func.func private @lapack_ssyevd_ffi(%arg0: tensor<1x68x4x4xf32>) -> (tensor<1x68x4x4xf32>, tensor<1x68x4xf32>, tensor<1x68xi32>)

// ... other code ...

%4:3 = func.call @lapack_ssyevd_ffi(%3) {
  operand_layouts = [dense<[2, 3, 1, 0]> : tensor<4xindex>],
  result_layouts = [dense<[2, 3, 1, 0]> : tensor<4xindex>, dense<[2, 1, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>]
} : (tensor<1x68x4x4xf32>) -> (tensor<1x68x4x4xf32>, tensor<1x68x4xf32>, tensor<1x68xi32>) 
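The rewrite sdasgup3 describes can be sketched at the text level. The Python below uses a hypothetical helper, custom_call_to_func, which is not part of StableHLO or MLIR. It derives the private declaration and the func.call from one pretty-printed custom_call line.

The sketch makes several assumptions:
- The op has at least one result.
- Operand and symbol names are bare, not quoted.
- There is no trailing loc(...).

Unlike the example above, it drops every attribute, including the layouts and mhlo.backend_config. A real pass would more likely use the MLIR C++ or Python APIs than string handling.

```python
import re

# Pretty-printed custom_call with results: "%r = stablehlo.custom_call @target(%a) {...} : (...) -> ..."
_CALL = re.compile(
    r'^(?P<res>%[^=]+?)\s*=\s*stablehlo\.custom_call\s+@(?P<name>[\w$.]+)\((?P<args>[^)]*)\)')


def _split_types(types: str) -> list[str]:
  """Split an MLIR type list on top-level commas (commas nested in <...> are kept)."""
  parts, depth, cur = [], 0, ''
  for ch in types:
    if ch in '<([':
      depth += 1
    elif ch in '>)]':
      depth -= 1
    if ch == ',' and depth == 0:
      parts.append(cur.strip())
      cur = ''
    else:
      cur += ch
  if cur.strip():
    parts.append(cur.strip())
  return parts


def _fmt_results(types: list[str]) -> str:
  # A single result type is printed bare; zero or several need parentheses.
  return types[0] if len(types) == 1 else '(' + ', '.join(types) + ')'


def custom_call_to_func(line: str) -> tuple[str, str]:
  """Hypothetical helper: (private func.func declaration, func.call) for one custom_call."""
  line = line.strip()
  m = _CALL.match(line)
  if m is None:
    raise ValueError('expected a pretty-printed stablehlo.custom_call with results')
  # The op's function type follows the last " : (" on the line.
  signature = line.rsplit(' : (', 1)[1]
  ins, outs = signature.split(') -> ', 1)
  outs = outs.strip()
  if outs.startswith('(') and outs.endswith(')'):
    outs = outs[1:-1]
  fn_type = f"({', '.join(_split_types(ins))}) -> {_fmt_results(_split_types(outs))}"
  name = m.group('name')
  decl = f'func.func private @{name}{fn_type}'
  call = f"{m.group('res')} = func.call @{name}({m.group('args')}) : {fn_type}"
  return decl, call


if __name__ == '__main__':
  decl, call = custom_call_to_func(
      '%0 = stablehlo.custom_call @foo(%a) {mhlo.backend_config = {mode = 86 : ui8}} '
      ': (tensor<f32>) -> tensor<f32>')
  print(decl)  # func.func private @foo(tensor<f32>) -> tensor<f32>
  print(call)  # %0 = func.call @foo(%a) : (tensor<f32>) -> tensor<f32>
```

Dropping attributes loses information. For lapack_ssyevd_ffi, 86 and 76 are the ASCII codes of 'V' and 'L', presumably LAPACK's jobz and uplo flags, so a wrapper would still need to recover them from somewhere.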

cc @GleasonK

@GleasonK
Member

GleasonK commented Jan 8, 2025

Hmm. Once you're in StableHLO with custom_calls, lowering them to another dialect isn't typical. In this case, it seems that AF2 assumes there is a LAPACK library to dispatch some math to, and I'm not sure how that would be represented in Linalg. I wonder if there's a way to export the model without the LAPACK dependency and use a reference implementation of some sort.

Here's a shorter repro:

import jax
import jax.numpy as jnp
import jax.export

def largest_evec(m):
  _, eigvecs = jnp.linalg.eigh(m)
  return eigvecs[..., -1]

a = jnp.array([[1, -2j], [2j, 1]])

module = jax.export.export(jax.jit(largest_evec), platforms=['cpu'])(a).mlir_module()
print(module)
# ...
# %8:3 = stablehlo.custom_call @lapack_cheevd_ffi(%7)

Could you say more about your goal? Are you intending to compile for a different platform that doesn't have LAPACK? If so, would a reference JAX implementation be the ideal case?
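One way to picture the reference-implementation idea is the minimal JAX sketch below. It assumes that, as in largest_evec above, only the eigenvector of the largest eigenvalue is needed. largest_evec_ref is a hypothetical name, and it swaps eigh for plain shifted power iteration. Under jax.jit this lowers to a while loop over dot/divide ops instead of a LAPACK custom_call.

It is not a drop-in replacement:
- Convergence slows when the top two eigenvalues are close.
- A repeated top eigenvalue yields some vector in its eigenspace.
- The all-ones start vector must not be orthogonal to the target eigenvector.

```python
import jax
import jax.export
import jax.numpy as jnp


def largest_evec_ref(m, num_iters=100):
  """Eigenvector for the largest eigenvalue of a Hermitian m of shape (..., n, n).

  Sketch only: shifted power iteration instead of eigh, so no LAPACK call.
  """
  n = m.shape[-1]
  # The Frobenius norm bounds the spectral radius, so m + shift*I has only
  # non-negative eigenvalues and the largest eigenvalue of m becomes the
  # dominant one that power iteration converges to.
  shift = jnp.sqrt(jnp.sum(jnp.abs(m) ** 2, axis=(-2, -1), keepdims=True))
  b = m + shift * jnp.eye(n, dtype=m.dtype)
  # Assumed starting vector: must not be orthogonal to the target eigenvector.
  v0 = jnp.ones(m.shape[:-1], dtype=m.dtype)

  def step(_, v):
    w = jnp.einsum('...ij,...j->...i', b, v)
    return w / jnp.linalg.norm(w, axis=-1, keepdims=True)

  return jax.lax.fori_loop(0, num_iters, step, v0)


if __name__ == '__main__':
  a = jnp.array([[1, -2j], [2j, 1]], dtype=jnp.complex64)
  exported = jax.export.export(jax.jit(largest_evec_ref), platforms=['cpu'])(a)
  text = exported.mlir_module()
  # Expected to contain a while loop rather than a lapack_*_ffi custom_call.
  print('lapack' in text)
```

The result matches eigh's eigenvector only up to a phase factor. Whether the accuracy and iteration count are acceptable for AF2 would still need checking.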

GleasonK added the Transformations (Pertaining to MLIR passes and transformations) label Jan 8, 2025
@endaiHW
Author

endaiHW commented Jan 10, 2025

Hi @GleasonK,
Apologies for the late reply! Our goal is to run AF2 through the MLIR pathway on CPU for subsequent optimization. We don’t need to compile for a platform without LAPACK support. As @sdasgup3 previously mentioned, lowering to func.call and linking LAPACK should suffice for our use case. Thanks for your help!

@GleasonK
Member

Is the conversion failing because it can't convert the custom calls? If so, we could mark custom_call as a legal op to unblock you.

I'd argue that custom_call is more expressive than func for external library calls (for example, it carries operand and result layouts as part of the op). Regardless of which op is used, some additional work will still be needed to call the library properly. If you have an idea of what the end-state IR should look like in your pipeline, that would be helpful. In other words, if we represented this as a func.func, why would that be better than a custom_call?

3 participants