Commit
PR #17886: [NVIDIA] Optimize deterministic scalar scatter
Imported from GitHub PR #17886

This PR is the first of two steps to improve the performance of deterministic scatter. Previously, the scatter op was expanded into a deterministic form in `xla/service/ScatterExpander.cc`. Because that expansion takes a while-loop-based approach, its performance is extremely poor. We designed and implemented a prefix-scan-based approach that rewrites the scatter operation into an efficient deterministic scatter.

This PR rewrites scatter operations with scalar indices and updates, and leaves all other scatter operations to be handled by the original ScatterExpander. The second PR, #18326, will handle non-scalar indices and updates.

Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit

Bugs resolved: jax-ml/jax#17844

Copybara import of the project:

--
42cc615 by Chenhao Jiang <chenhaoj@nvidia.com>:

Optimize deterministic scalar scatter

Performance Takeaways:
- Our optimized implementation is significantly faster than the original deterministic scatter, especially at larger index sizes, with speedups of over 9,300x for some input and index sizes.
- Our implementation is somewhat slower than the non-deterministic scatter. In most cases the slowdown is a factor of roughly 1x to 4x. In the worst case, with a rare index size setup, the slowdown factor is 9.15.
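To illustrate the general idea behind a prefix-scan-based deterministic scatter, here is a minimal pure-Python sketch for the scalar scatter-add case. It is not the code from this PR: the function name `deterministic_scatter_add`, the 1-D operand, the addition combiner, and the choice to skip out-of-bounds indices are all assumptions made for illustration. The sketch stably sorts the updates by target index, runs a segmented inclusive prefix sum over them, and writes only the last value of each segment, so every target index is written at most once and the accumulation order is fixed.

```python
def deterministic_scatter_add(operand, indices, updates):
    """Scatter-add scalar updates into a 1-D operand in a fixed order.

    Hypothetical helper for illustration only; not an XLA API.
    """
    # Stable sort of update positions by target index, so updates that
    # hit the same index keep their original relative order.
    order = sorted(range(len(indices)), key=lambda i: indices[i])
    sorted_idx = [indices[i] for i in order]
    sorted_upd = [updates[i] for i in order]

    # Segmented inclusive prefix sum: the running sum restarts whenever
    # the target index changes. This loop is serial; a GPU version would
    # use a parallel scan with a fixed combination order instead.
    scanned = []
    for pos, value in enumerate(sorted_upd):
        if pos > 0 and sorted_idx[pos] == sorted_idx[pos - 1]:
            scanned.append(scanned[-1] + value)
        else:
            scanned.append(value)

    # The last element of each segment holds the full sum for that index,
    # so the final write touches each target index at most once.
    result = list(operand)
    for pos, idx in enumerate(sorted_idx):
        is_segment_end = pos + 1 == len(sorted_idx) or sorted_idx[pos + 1] != idx
        # Assumption: out-of-bounds indices are skipped.
        if is_segment_end and 0 <= idx < len(result):
            result[idx] += scanned[pos]
    return result
```

For example, `deterministic_scatter_add([0.0] * 4, [2, 0, 2, 1, 2], [1.0, 2.0, 3.0, 4.0, 5.0])` accumulates the three updates aimed at index 2 in input order and returns `[2.0, 4.0, 9.0, 0.0]`. Because the final write has unique indices, the result does not depend on the order in which those writes are executed.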
Full Microbenchmark:

| Input Size | Index Size | Non-Det | Original Det | New Det | Slowdown (vs Non-det) | Speedup (vs Original Det) |
|------------|------------|---------|--------------|---------|-----------------------|---------------------------|
| 10 | 10 | 3.96E-05 | 7.82E-05 | 4.66E-05 | 1.18 | 1.68 |
| 10 | 100 | 3.72E-05 | 4.83E-04 | 9.73E-05 | 2.62 | 4.96 |
| 10 | 1000 | 3.92E-05 | 4.20E-03 | 6.62E-05 | 1.69 | 63.50 |
| 10 | 10000 | 4.36E-05 | 4.31E-02 | 1.21E-04 | 2.77 | 357.37 |
| 10 | 100000 | 1.06E-04 | 4.33E-01 | 1.71E-04 | 1.61 | 2536.56 |
| 10 | 1000000 | 4.31E-04 | 4.17E+00 | 4.45E-04 | 1.03 | 9372.37 |
| 100 | 10 | 4.27E-05 | 7.76E-05 | 4.71E-05 | 1.10 | 1.65 |
| 100 | 100 | 4.01E-05 | 4.91E-04 | 5.61E-05 | 1.40 | 8.75 |
| 100 | 1000 | 5.17E-05 | 4.21E-03 | 1.10E-04 | 2.13 | 38.24 |
| 100 | 10000 | 4.08E-05 | 4.27E-02 | 1.05E-04 | 2.57 | 407.45 |
| 100 | 100000 | 7.60E-05 | 4.14E-01 | 1.69E-04 | 2.22 | 2455.08 |
| 100 | 1000000 | 2.86E-04 | 4.17E+00 | 4.62E-04 | 1.62 | 9009.13 |
| 1000 | 10 | 3.95E-05 | 7.85E-05 | 4.97E-05 | 1.26 | 1.58 |
| 1000 | 100 | 4.16E-05 | 4.85E-04 | 5.27E-05 | 1.27 | 9.21 |
| 1000 | 1000 | 3.90E-05 | 4.25E-03 | 6.35E-05 | 1.63 | 66.86 |
| 1000 | 10000 | 4.08E-05 | 4.25E-02 | 1.22E-04 | 3.00 | 346.99 |
| 1000 | 100000 | 4.26E-05 | 4.15E-01 | 1.92E-04 | 4.51 | 2161.72 |
| 1000 | 1000000 | 1.73E-04 | 4.26E+00 | 4.75E-04 | 2.74 | 8964.91 |
| 10000 | 10 | 4.17E-05 | 8.00E-05 | 4.76E-05 | 1.14 | 1.68 |
| 10000 | 100 | 3.68E-05 | 7.16E-04 | 1.10E-04 | 3.00 | 6.49 |
| 10000 | 1000 | 4.13E-05 | 4.23E-03 | 1.01E-04 | 2.44 | 42.12 |
| 10000 | 10000 | 3.71E-05 | 4.23E-02 | 1.44E-04 | 3.89 | 293.14 |
| 10000 | 100000 | 9.70E-05 | 4.28E-01 | 1.72E-04 | 1.77 | 2494.21 |
| 10000 | 1000000 | 1.18E-04 | 4.17E+00 | 4.91E-04 | 4.15 | 8488.57 |
| 100000 | 10 | 3.73E-05 | 7.25E-05 | 4.92E-05 | 1.32 | 1.47 |
| 100000 | 100 | 4.09E-05 | 4.91E-04 | 6.33E-05 | 1.55 | 7.76 |
| 100000 | 1000 | 4.10E-05 | 4.25E-03 | 6.40E-05 | 1.56 | 66.39 |
| 100000 | 10000 | 3.78E-05 | 4.22E-02 | 1.26E-04 | 3.34 | 334.38 |
| 100000 | 100000 | 4.42E-05 | 4.16E-01 | 1.67E-04 | 3.79 | 2486.22 |
| 100000 | 1000000 | 5.37E-05 | 4.17E+00 | 4.92E-04 | 9.15 | 8474.51 |
| 1000000 | 10 | 3.97E-05 | 8.10E-05 | 5.12E-05 | 1.29 | 1.58 |
| 1000000 | 100 | 4.56E-05 | 4.94E-04 | 6.08E-05 | 1.33 | 8.13 |
| 1000000 | 1000 | 4.47E-05 | 4.29E-03 | 6.17E-05 | 1.38 | 69.44 |
| 1000000 | 10000 | 4.48E-05 | 4.27E-02 | 1.18E-04 | 2.63 | 362.68 |
| 1000000 | 100000 | 4.25E-05 | 4.19E-01 | 1.78E-04 | 4.19 | 2352.46 |
| 1000000 | 1000000 | 6.59E-05 | 4.18E+00 | 5.01E-04 | 7.60 | 8334.87 |

Merging this change closes #17886

FUTURE_COPYBARA_INTEGRATE_REVIEW=#17886 from serach24:chenhao/opt_det_scatter_scalar 42cc615
PiperOrigin-RevId: 686779279