add @code_hlo macro to get mlir code #39

Merged: 2 commits into EnzymeAD:main on Jul 13, 2024
Conversation

Pangoraw (Collaborator)

Example usage:

julia> using Reactant

julia> W = Reactant.ConcreteRArray(randn(Float32, 10, 20))
       x = Reactant.ConcreteRArray(randn(Float32, 20, 5))
       Reactant.@code_hlo W * x
Module:
module attributes {transform.with_named_sequence} {
  func.func @main(%arg0: tensor<20x10xf32>, %arg1: tensor<5x20xf32>) -> tensor<5x10xf32> {
    %0 = stablehlo.dot_general %arg1, %arg0, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<5x20xf32>, tensor<20x10xf32>) -> tensor<5x10xf32>
    return %0 : tensor<5x10xf32>
  }
}
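
For context, a macro like this can work by rewriting the call so that it is traced into an MLIR module instead of being executed, then printing the module's textual IR. Below is a minimal sketch of that idea, assuming a hypothetical trace_to_mlir helper; it is not the actual implementation this PR adds:

# Illustrative sketch only: `trace_to_mlir` is a hypothetical stand-in,
# not Reactant's real API.
macro code_hlo_sketch(call)
    Meta.isexpr(call, :call) || error("expected a function call")
    f, args = call.args[1], call.args[2:end]
    quote
        # Trace the call into an MLIR module instead of executing it,
        # then print the module's textual IR.
        m = trace_to_mlir($(esc(f)), $(map(esc, args)...))
        println(m)
    end
end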

cc @mofeing

mofeing linked an issue on Jul 13, 2024 that may be closed by this pull request.

mofeing (Collaborator) commented on Jul 13, 2024:

Awesome! Did you manage to fix the problem with sum? Also, why does transform.with_named_sequence appear there?

EDIT: sum works, and the transform.with_named_sequence attribute seems to be a consequence of a pass in Enzyme-JAX.
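
For reference, a quick way to re-check the sum case from the REPL (output elided here; it should lower to a stablehlo.reduce):

julia> x = Reactant.ConcreteRArray(randn(Float32, 10))

julia> Reactant.@code_hlo sum(x)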

mofeing requested a review from wsmoses on July 13, 2024, 14:51.
wsmoses (Member) commented on Jul 13, 2024:

Oh this is amazing!

One extra thing: would it be possible to have a flag that shows the HLO before the optimization pass pipeline runs?

And yeah, that's a leftover from the transform dialect we use for optimizations.

wsmoses (Member) commented on Jul 13, 2024:

Also @Pangoraw, if you'd like, you're more than welcome to add yourself to the contributors list [and also, if you'd be interested, to help @mofeing and myself get some cool things implemented!]

Pangoraw (Collaborator, Author) commented:

I added a run_pipeline option (default = true); when it is set to false, the pass pipeline is not run on the module:

julia> W = Reactant.ConcreteRArray(randn(Float32, 10, 20))
       x = Reactant.ConcreteRArray(randn(Float32, 20, 5))
       run_pipeline = false
       Reactant.@code_hlo run_pipeline = run_pipeline W * x
Module:
module {
  func.func @main(%arg0: tensor<5x20xf32>, %arg1: tensor<20x10xf32>) -> (tensor<20x10xf32>, tensor<5x20xf32>, tensor<5x10xf32>) {
    %0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<5x20xf32>) -> tensor<20x5xf32>
    %1 = stablehlo.transpose %arg1, dims = [1, 0] : (tensor<20x10xf32>) -> tensor<10x20xf32>
    %2 = stablehlo.dot_general %1, %0, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<10x20xf32>, tensor<20x5xf32>) -> tensor<10x5xf32>
    %3 = stablehlo.transpose %1, dims = [1, 0] : (tensor<10x20xf32>) -> tensor<20x10xf32>
    %4 = stablehlo.transpose %0, dims = [1, 0] : (tensor<20x5xf32>) -> tensor<5x20xf32>
    %5 = stablehlo.transpose %2, dims = [1, 0] : (tensor<10x5xf32>) -> tensor<5x10xf32>
    return %3, %4, %5 : tensor<20x10xf32>, tensor<5x20xf32>, tensor<5x10xf32>
  }
}
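
The flag should also work inline without the intermediate variable, e.g. Reactant.@code_hlo run_pipeline = false W * x. Note how the unoptimized module still contains the transposes Reactant inserts around the computation (presumably bridging Julia's column-major arrays and StableHLO's row-major tensors); the optimization pipeline folds them away, leaving only the single stablehlo.dot_general shown in the first example.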

wsmoses (Member) left a review:

LGTM!

mofeing merged commit d05542c into EnzymeAD:main on Jul 13, 2024.
6 of 13 checks passed.
Pangoraw deleted the code_hlo_macro branch on July 15, 2024, 08:54.
Successfully merging this pull request may close: Macro to pretty print the stable hlo code