Skip to content

Commit

Permalink
PR tensorflow#18326: [NVIDIA] Complete the optimization of determinis…
Browse files Browse the repository at this point in the history
…tic scatter operations

Imported from GitHub PR openxla/xla#18326

This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in xla/service/ScatterExpander.cc. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates.

The change of this PR is on top of openxla/xla#17886

Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit

Bugs resolved: jax-ml/jax#17844
Copybara import of the project:

--
de647d44eb28af71e1580b6e8ed9adc751e50f52 by Chenhao Jiang <chenhaoj@nvidia.com>:

Support scatter with non-scalar indices and updates

Merging this change closes tensorflow#18326

PiperOrigin-RevId: 691023328
  • Loading branch information
serach24 authored and tensorflower-gardener committed Oct 29, 2024
1 parent 3c8feff commit acb9905
Show file tree
Hide file tree
Showing 3 changed files with 712 additions and 86 deletions.
1 change: 1 addition & 0 deletions third_party/xla/xla/service/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -2113,6 +2113,7 @@ cc_library(
"//xla/hlo/transforms:op_expander_pass",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/types:span",
"@local_tsl//tsl/platform:logging",
"@local_tsl//tsl/platform:statusor",
],
Expand Down
Loading

0 comments on commit acb9905

Please sign in to comment.