Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add caching to trace_to_subjaxpr_dynamic. #10775

Merged
merged 1 commit into from
Jul 21, 2022

Conversation

pschuh
Copy link
Collaborator

@pschuh pschuh commented May 20, 2022

This allows the MLIR lowering code to cache call lowerings.

example output:

module @jit_fun.0 {
  func.func public @main(%arg0: tensor<4x8xf32>) -> tensor<4x8xf32> {
    %0 = call @square(%arg0) : (tensor<4x8xf32>) -> tensor<4x8xf32>
    %1 = call @square(%0) : (tensor<4x8xf32>) -> tensor<4x8xf32>
    return %1 : tensor<4x8xf32>
  }
  func.func private @square(%arg0: tensor<4x8xf32>) -> tensor<4x8xf32> {
    %0 = mhlo.multiply %arg0, %arg0 : tensor<4x8xf32>
    return %0 : tensor<4x8xf32>
  }
}

If / when jaxprs support recursion, this approach will still work because the mlir lowering cache operates on Jaxpr object identity.

@pschuh pschuh requested review from hawkinsp and mattjj and removed request for hawkinsp May 20, 2022 01:33
@mattjj mattjj assigned pschuh and unassigned mattjj May 23, 2022
This allows the MLIR lowering code to cache call lowerings.

example output:

```
module @jit_fun.0 {
  func.func public @main(%arg0: tensor<4x8xf32>) -> tensor<4x8xf32> {
    %0 = call @square(%arg0) : (tensor<4x8xf32>) -> tensor<4x8xf32>
    %1 = call @square(%0) : (tensor<4x8xf32>) -> tensor<4x8xf32>
    return %1 : tensor<4x8xf32>
  }
  func.func private @square(%arg0: tensor<4x8xf32>) -> tensor<4x8xf32> {
    %0 = mhlo.multiply %arg0, %arg0 : tensor<4x8xf32>
    return %0 : tensor<4x8xf32>
  }
}
```

If / when jaxprs support recursion, this approach will still work because the mlir lowering cache operates on Jaxpr object identity.
Copy link
Collaborator

@mattjj mattjj left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you!

@google-ml-butler google-ml-butler bot added kokoro:force-run pull ready Ready for copybara import and testing labels Jul 20, 2022
@copybara-service copybara-service bot merged commit be6db2e into jax-ml:main Jul 21, 2022
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
pull ready Ready for copybara import and testing
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants