Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
PR #18326: [NVIDIA] Complete the optimization of deterministic scatte…
…r operations Imported from GitHub PR #18326 This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in xla/service/ScatterExpander.cc. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates. The change of this PR is on top of #17886 Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit Bugs resolved: jax-ml/jax#17844 Copybara import of the project: -- de647d4 by Chenhao Jiang <chenhaoj@nvidia.com>: Support scatter with non-scalar indices and updates Merging this change closes #18326 FUTURE_COPYBARA_INTEGRATE_REVIEW=#18326 from serach24:chenhao/opt_det_scatter_full de647d4 PiperOrigin-RevId: 690490783
- Loading branch information