
PR #17886: [NVIDIA] Optimize deterministic scalar scatter #18419

Merged
merged 1 commit into main from test_686779279 on Oct 17, 2024

Conversation

copybara-service[bot]

PR #17886: [NVIDIA] Optimize deterministic scalar scatter

Imported from GitHub PR #17886

This PR is the first step (out of two) to improve the performance of deterministic scatter. Originally, the scatter op is expanded into a deterministic form in `xla/service/ScatterExpander.cc`; however, because that expansion takes a while-loop-based approach, its performance is extremely poor. We designed and implemented a prefix-scan-based approach that rewrites the scatter operation into an efficient deterministic one. This PR handles scatter operations with scalar indices and updates, and leaves the other scatter operations to the original ScatterExpander. The second PR will handle non-scalar indices and updates.
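For intuition, here is a minimal NumPy sketch of the sort-plus-prefix-scan idea. It is illustrative only: the actual change rewrites the scatter HLO inside XLA, and `deterministic_scatter_add` is a hypothetical helper name, not an API from this PR.

```python
import numpy as np

def deterministic_scatter_add(operand, indices, updates):
    """Scatter-add scalar updates into `operand` deterministically
    via a stable sort plus a prefix scan (sketch, not the XLA emitter)."""
    # 1) Stable-sort updates by target index so duplicates become contiguous.
    order = np.argsort(indices, kind="stable")
    sorted_idx = indices[order]
    sorted_upd = updates[order]
    # 2) Inclusive prefix scan (cumulative sum) over the sorted updates.
    prefix = np.cumsum(sorted_upd)
    # 3) The scan value at the end of each run of equal indices, minus the
    #    scan value at the end of the previous run, is that index's total.
    run_end = np.append(sorted_idx[1:] != sorted_idx[:-1], True)
    run_totals = np.diff(prefix[run_end], prepend=0.0)
    # 4) Each target index is written exactly once, so there are no
    #    duplicate-index races and the accumulation order is fixed.
    out = np.asarray(operand, dtype=np.result_type(operand, updates)).copy()
    out[sorted_idx[run_end]] += run_totals
    return out

# Example: index 2 receives both 10.0 and 1.0 in a fixed order.
print(deterministic_scatter_add(np.zeros(4),
                                np.array([2, 0, 2]),
                                np.array([10.0, 5.0, 1.0])))
# [ 5.  0. 11.  0.]
```

Sorting plus a scan turns the per-duplicate serial accumulation into data-parallel primitives while keeping a fixed accumulation order, which is where the speedup over the while-loop expansion comes from.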

The second PR is at #18326

Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit

Bugs resolved: jax-ml/jax#17844
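The linked issue concerns `vmap` of `scatter_add` being extremely slow when `xla_gpu_deterministic_ops` is enabled. A minimal JAX sketch of that pattern (assuming a CUDA-enabled jaxlib; the array sizes are arbitrary and only for illustration):

```python
# Enable XLA's deterministic GPU ops before JAX initializes its backend.
import os
os.environ["XLA_FLAGS"] = "--xla_gpu_deterministic_ops=true"

import jax
import jax.numpy as jnp

def scatter_add_once(idx, upd):
    # A single scalar-index, scalar-update scatter-add into a 1-D operand.
    return jnp.zeros(1024).at[idx].add(upd)

# vmap-ing the scatter-add is the pattern that previously fell into the
# slow while-loop-based deterministic expansion.
batched = jax.jit(jax.vmap(scatter_add_once))
out = batched(jnp.arange(64), jnp.ones(64))
print(out.shape)  # (64, 1024)
```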
Copybara import of the project:

--
42cc615 by Chenhao Jiang <chenhaoj@nvidia.com>:

Optimize deterministic scalar scatter

Performance Takeaways:

  • Compared with the original deterministic expansion, our optimized implementation shows significant speedups, especially at larger index sizes, reaching over 9,300x for certain input and index sizes.
  • Compared with the non-deterministic scatter, our implementation incurs only a modest slowdown: roughly 1x-4x in most cases, with a worst-case factor of 9.15 for one uncommon index-size setup.

Full Microbenchmark:

| Input Size | Index Size | Non-Det | Original Det | New Det  | Slowdown (vs Non-det) | Speedup (vs Original Det) |
|------------|------------|---------|--------------|----------|-----------------------|---------------------------|
| 10         | 10         | 3.96E-05| 7.82E-05     | 4.66E-05 | 1.18                  | 1.68                      |
| 10         | 100        | 3.72E-05| 4.83E-04     | 9.73E-05 | 2.62                  | 4.96                      |
| 10         | 1000       | 3.92E-05| 4.20E-03     | 6.62E-05 | 1.69                  | 63.50                     |
| 10         | 10000      | 4.36E-05| 4.31E-02     | 1.21E-04 | 2.77                  | 357.37                    |
| 10         | 100000     | 1.06E-04| 4.33E-01     | 1.71E-04 | 1.61                  | 2536.56                   |
| 10         | 1000000    | 4.31E-04| 4.17E+00     | 4.45E-04 | 1.03                  | 9372.37                   |
| 100        | 10         | 4.27E-05| 7.76E-05     | 4.71E-05 | 1.10                  | 1.65                      |
| 100        | 100        | 4.01E-05| 4.91E-04     | 5.61E-05 | 1.40                  | 8.75                      |
| 100        | 1000       | 5.17E-05| 4.21E-03     | 1.10E-04 | 2.13                  | 38.24                     |
| 100        | 10000      | 4.08E-05| 4.27E-02     | 1.05E-04 | 2.57                  | 407.45                    |
| 100        | 100000     | 7.60E-05| 4.14E-01     | 1.69E-04 | 2.22                  | 2455.08                   |
| 100        | 1000000    | 2.86E-04| 4.17E+00     | 4.62E-04 | 1.62                  | 9009.13                   |
| 1000       | 10         | 3.95E-05| 7.85E-05     | 4.97E-05 | 1.26                  | 1.58                      |
| 1000       | 100        | 4.16E-05| 4.85E-04     | 5.27E-05 | 1.27                  | 9.21                      |
| 1000       | 1000       | 3.90E-05| 4.25E-03     | 6.35E-05 | 1.63                  | 66.86                     |
| 1000       | 10000      | 4.08E-05| 4.25E-02     | 1.22E-04 | 3.00                  | 346.99                    |
| 1000       | 100000     | 4.26E-05| 4.15E-01     | 1.92E-04 | 4.51                  | 2161.72                   |
| 1000       | 1000000    | 1.73E-04| 4.26E+00     | 4.75E-04 | 2.74                  | 8964.91                   |
| 10000      | 10         | 4.17E-05| 8.00E-05     | 4.76E-05 | 1.14                  | 1.68                      |
| 10000      | 100        | 3.68E-05| 7.16E-04     | 1.10E-04 | 3.00                  | 6.49                      |
| 10000      | 1000       | 4.13E-05| 4.23E-03     | 1.01E-04 | 2.44                  | 42.12                     |
| 10000      | 10000      | 3.71E-05| 4.23E-02     | 1.44E-04 | 3.89                  | 293.14                    |
| 10000      | 100000     | 9.70E-05| 4.28E-01     | 1.72E-04 | 1.77                  | 2494.21                   |
| 10000      | 1000000    | 1.18E-04| 4.17E+00     | 4.91E-04 | 4.15                  | 8488.57                   |
| 100000     | 10         | 3.73E-05| 7.25E-05     | 4.92E-05 | 1.32                  | 1.47                      |
| 100000     | 100        | 4.09E-05| 4.91E-04     | 6.33E-05 | 1.55                  | 7.76                      |
| 100000     | 1000       | 4.10E-05| 4.25E-03     | 6.40E-05 | 1.56                  | 66.39                     |
| 100000     | 10000      | 3.78E-05| 4.22E-02     | 1.26E-04 | 3.34                  | 334.38                    |
| 100000     | 100000     | 4.42E-05| 4.16E-01     | 1.67E-04 | 3.79                  | 2486.22                   |
| 100000     | 1000000    | 5.37E-05| 4.17E+00     | 4.92E-04 | 9.15                  | 8474.51                   |
| 1000000    | 10         | 3.97E-05| 8.10E-05     | 5.12E-05 | 1.29                  | 1.58                      |
| 1000000    | 100        | 4.56E-05| 4.94E-04     | 6.08E-05 | 1.33                  | 8.13                      |
| 1000000    | 1000       | 4.47E-05| 4.29E-03     | 6.17E-05 | 1.38                  | 69.44                     |
| 1000000    | 10000      | 4.48E-05| 4.27E-02     | 1.18E-04 | 2.63                  | 362.68                    |
| 1000000    | 100000     | 4.25E-05| 4.19E-01     | 1.78E-04 | 4.19                  | 2352.46                   |
| 1000000    | 1000000    | 6.59E-05| 4.18E+00     | 5.01E-04 | 7.60                  | 8334.87                   |
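For reading the table: the ratio columns appear to be computed as new-det time over non-det time (slowdown) and original-det time over new-det time (speedup). Checking the first row:

```python
# First table row (input size 10, index size 10); values copied from above.
non_det, original_det, new_det = 3.96e-05, 7.82e-05, 4.66e-05
print(round(new_det / non_det, 2))       # 1.18 -> Slowdown (vs Non-det)
print(round(original_det / new_det, 2))  # 1.68 -> Speedup (vs Original Det)
```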

Merging this change closes #17886

FUTURE_COPYBARA_INTEGRATE_REVIEW=#17886 from serach24:chenhao/opt_det_scatter_scalar 42cc615

copybara-service bot force-pushed the test_686779279 branch 3 times, most recently from a0f7808 to 6c68a10, on October 17, 2024 11:45
COPYBARA_INTEGRATE_REVIEW=#17886 from serach24:chenhao/opt_det_scatter_scalar 42cc615
PiperOrigin-RevId: 686871951
copybara-service bot merged commit 9b59f66 into main on Oct 17, 2024
copybara-service bot deleted the test_686779279 branch on October 17, 2024 12:22

Successfully merging this pull request may close these issues.

vmap with scatter_add extremely slow when using xla_gpu_deterministic_ops