Skip to content

Commit

Permalink
PR #18326: [NVIDIA] Complete the optimization of deterministic scatte…
Browse files Browse the repository at this point in the history
…r operations

Imported from GitHub PR #18326

This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in xla/service/ScatterExpander.cc. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates.

The change of this PR is on top of #17886

Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit

Bugs resolved: jax-ml/jax#17844
Copybara import of the project:

--
de647d4 by Chenhao Jiang <chenhaoj@nvidia.com>:

Support scatter with non-scalar indices and updates

Merging this change closes #18326

FUTURE_COPYBARA_INTEGRATE_REVIEW=#18326 from serach24:chenhao/opt_det_scatter_full de647d4
PiperOrigin-RevId: 690490783
  • Loading branch information
serach24 authored and Google-ML-Automation committed Oct 28, 2024
1 parent a7bd679 commit 49b64aa
Show file tree
Hide file tree
Showing 3 changed files with 712 additions and 86 deletions.
1 change: 1 addition & 0 deletions xla/service/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -2114,6 +2114,7 @@ cc_library(
"//xla/hlo/transforms:op_expander_pass",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/types:span",
"@tsl//tsl/platform:logging",
"@tsl//tsl/platform:statusor",
],
Expand Down
Loading

0 comments on commit 49b64aa

Please sign in to comment.