Skip to content

Commit

Permalink
PR openxla#18326: [NVIDIA] Complete the optimization of deterministic…
Browse files Browse the repository at this point in the history
… scatter operations

Imported from GitHub PR openxla#18326

This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in xla/service/ScatterExpander.cc. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates.

The change of this PR is on top of openxla#17886

Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit

Bugs resolved: jax-ml/jax#17844
Copybara import of the project:

--
de647d4 by Chenhao Jiang <chenhaoj@nvidia.com>:

Support scatter with non-scalar indices and updates

Merging this change closes openxla#18326

COPYBARA_INTEGRATE_REVIEW=openxla#18326 from serach24:chenhao/opt_det_scatter_full de647d4
PiperOrigin-RevId: 691023328
  • Loading branch information
serach24 committed Nov 6, 2024
1 parent 0f6331b commit 3ee2848
Show file tree
Hide file tree
Showing 3 changed files with 712 additions and 86 deletions.
1 change: 1 addition & 0 deletions xla/service/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -2116,6 +2116,7 @@ cc_library(
"//xla/hlo/transforms:op_expander_pass",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/types:span",
"@tsl//tsl/platform:logging",
"@tsl//tsl/platform:statusor",
],
Expand Down
Loading

0 comments on commit 3ee2848

Please sign in to comment.