-
Notifications
You must be signed in to change notification settings - Fork 454
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[NVIDIA] Optimize deterministic scalar scatter #17886
Conversation
Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA). View this failed invocation of the CLA check for more information. For the most up to date status, view the checks section at the bottom of the pull request. |
d2332e5
to
82f2237
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
- Is it possible to smash into one commit with a lot more detailed commit message?
- Could you provide microbenchmark results, esp. comparing deterministic and non-deterministic scatter performance? If the performance is comparable, maybe we could even try to make it deterministic by default?
ScatterDeterminismExpander scatter_determinism_expander; | ||
TF_ASSERT_OK_AND_ASSIGN( | ||
bool result, RunHloPass(&scatter_determinism_expander, module.get())); | ||
EXPECT_TRUE(result); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could we
a. FileCheck the result of the rewrite
b. Launch it and verify correctness
c. Verify that it's indeed deterministic by launching multiple times and comparing numerics
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this test serves one single purpose: to check if we will pattern match when the scatter combiner is non-associative, and followed the same pattern as in scatter_expander_tests
.
The a, b, c that you mentioned are all included in the rest of the tests, specifically:
FileCheck the result of the rewrite -> ScatterAddHloVerificationTest
Launch it and verify correctness -> ScatterAddCorrectnessTest
and ScatterAddOutOfBoundCorrectnessTest
Verify that it's indeed deterministic by launching multiple times and comparing numerics -> ScatterAddReproducibilityTest
I think it is doable, but won't PRs be squashed to merge?
This is provided in the evaluation section of the attached doc. |
)"; | ||
|
||
RunAndFilecheckHloRewrite(kModuleStr, ScatterDeterminismExpander(), | ||
kExpectedPattern, nullptr /*after_pass_checks*/, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: we usually annotate parameters like this:
/*after_pass_checks=*/nullptr
But you are passing the default values, so better to just remove them.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Changed as requested
xla/service/scatter_utils.cc
Outdated
return scatter_indices; | ||
} | ||
|
||
if (index_vector_dim == (scatter_indices_shape.dimensions_size() - 1)) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think the if above and the if here can be simplified into one if block to:
if (index_vector_dim >= scatter_indices_shape.dimensions_size() - 1) {
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Changed as requested
@@ -695,6 +696,7 @@ absl::Status RunOptimizationPasses( | |||
if (RequireDeterminism(hlo_module->config())) { | |||
// Scatter can be indeterministic if indices are not unique or a non | |||
// associative combiner function is used. Eliminate these Scatter ops. | |||
pipeline.AddPass<ScatterDeterminismExpander>(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see that you do several canonicalizations that may already be done with GpuScatterExpander. I think it would make sense to move GpuScatterExpander first, check which canonicalizations it already applies, and avoid duplicating those.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry, I mixed up GpuScatterExpander and ScatterSimplifier. ScatterSimplifier does a bunch of simplifications which seem related to what you do, the normalized scatter has this form (copied from the comment in scatter_simplifier.h):
// The output scatter's attributes will have the following characteristics:
// - scatter_indices is a two-dimensional tensor
// - index_vector_dim is 1
// - inserted_window_dims is []
// - update_window_dims is [0, 1, ...]
// - scatter_dims_to_operand_dims is [0, 1, ...]
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, there are overlaps between the canonicalization process of ScatterSimplifier and ScatterExpander, as pointed out in the comments in the scatter_simplifier.h, above what you pasted:
// It implements the first two steps of the algorithm decribed in
// ScatterExpander::ExpandInstruction (scatter_expander.cc). Additionally, it
// transposes updates and operands to transform scatter_dims_to_operand_dims
// into the identity mapping. This is different from the algorithm in
// ScatterExpander, which instead applies the mapping in scatter_indices.
I was following the exact same canonicalization of ScatterExpander, that is why I extracted those functions into the scatter_utils.cc file, to be reused by both ScatterExpander and ScatterDeterminismExpander. In the gpu_compiler.cc, I was also following the same convention, adding the ScatterDeterminismExpander pass before the ScatterExpander with kEliminateIndeterministicScatters matching
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok, I understand that it would be a bit harder to rewrite your pass based on the different canonicalization used in ScatterSimplifier. I guess there is still some potential to combine these canonicalizations, but it is somewhat orthogonal to your change.
@@ -230,8 +230,7 @@ TEST_F(ScatterDeterminismExpanderTest, ScatterAddHloVerificationTest) { | |||
)"; | |||
|
|||
RunAndFilecheckHloRewrite(kModuleStr, ScatterDeterminismExpander(), | |||
kExpectedPattern, nullptr /*after_pass_checks*/, | |||
nullptr /*config*/); | |||
kExpectedPattern, nullptr, nullptr); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry for not being clear, I meant not passing the values for after_pass_checks and config at all. They have default values which are nullptr, no need to explicitly pass a default value.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I understand now. Changed as requested
@@ -695,6 +696,7 @@ absl::Status RunOptimizationPasses( | |||
if (RequireDeterminism(hlo_module->config())) { | |||
// Scatter can be indeterministic if indices are not unique or a non | |||
// associative combiner function is used. Eliminate these Scatter ops. | |||
pipeline.AddPass<ScatterDeterminismExpander>(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok, I understand that it would be a bit harder to rewrite your pass based on the different canonicalization used in ScatterSimplifier. I guess there is still some potential to combine these canonicalizations, but it is somewhat orthogonal to your change.
// traverse the tuple output of the computation | ||
for (int i = 0; i < operand_size; ++i) { | ||
const HloInstruction* output = root->operand(i); | ||
std::unordered_set<const HloInstruction*> input_dependencies; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We prefer to use absl::flat_hash_map instead of std::unordered_set because it is faster.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Changed
} | ||
|
||
namespace { | ||
void RecursivelyGetInputDependencies( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In theory, this recursion can have exponential runtime if you don't also keep track of which instructions you have already visited. Currently, you do deduplication of parameters. If you have a visited
set instead, you don't need that and can use a vector for dependencies, and also avoid the exponential runtime.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Makes sense. My initial thought was there isn't really complicated scatter computations so I did not bother to optimize this here.
Changed as suggested.
} | ||
} | ||
|
||
// Check if the every output of the computation only depends on the |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: "Check if the every" -> "Check if every"
also "scatter computation" instead of just "computation". Makes it clearer that this function does not process arbitrary computations, but just scatter computations.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Changed as requested
if (input_dependencies.size() > 2) { | ||
return false; | ||
} | ||
if (input_dependencies.size() == 2) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What if input_dependencies.size() == 1? For example there can be scatter computations that just throw away the initial value, and just use the value from updates.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Changed as requested
return false; | ||
} | ||
if (input_dependencies.size() == 2) { | ||
for (const HloInstruction* input : input_dependencies) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If all we care about are the parameter numbers, maybe also just store the parameter numbers instead of the HloInstruction?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Changed as requested
9d8eb96
to
92f6051
Compare
|
||
#include "xla/service/scatter_determinism_expander.h" | ||
#include <cstdint> | ||
#include "absl/container/flat_hash_set.h" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You will also need a corresponding BUILD dependency "@com_google_absl//absl/container:flat_hash_set"
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
changed
HloInstruction* scatter_indices, | ||
const Shape& scatter_shape) { | ||
if (scatter_indices->shape().rank() == 1) { | ||
CHECK(scatter_shape.dimensions_size() == 1); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: CHECK_EQ instead of CHECK() with '==' op
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
changed
LiteralUtil::CreateFromArray(out_of_bound_array))); | ||
} | ||
// More than one dimension in scatter_indices | ||
Array2D<int32_t> out_of_ound_array(scatter_indices->shape().dimensions(0), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: out_of_ound_array -> out_of_bound_array
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
changed
|
||
auto* sorting = parent->AddInstruction(HloInstruction::CreateSort( | ||
ShapeUtil::MakeTupleShape(sort_shapes), 0, sort_operands, comparison, | ||
false /*is_stable*/)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: we prefer annotations like /is_stable=/false
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
changed
int64_t num_updates = updates_shape.dimensions(0); | ||
|
||
// Calculate the number of iterations needed (log_2(n)) | ||
int64_t log_n = static_cast<int64_t>(std::ceil(std::log2(num_updates))); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You can use Log2Ceiling:
Line 562 in b4abe20
constexpr inline int Log2Ceiling(T x) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
changed
std::vector<int64_t> strides = {1}; | ||
|
||
for (int64_t iteration = 0; iteration < log_n; ++iteration) { | ||
offset = 1 << iteration; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
offset = static_cast<int64_t>(1) << iteration
Unfortunately even if iteration is int64_t, it would still compute the shift with the int32_t value otherwise.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
changed
int64_t log_n = static_cast<int64_t>(std::ceil(std::log2(num_updates))); | ||
|
||
// Placeholder for offset calculation (2^d) | ||
int64_t offset; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't see an advantage to declare the variable here, instead of in the loop where it is assigned. The compiler should be able to do this optimization, and it seems easier to read to move the declaration down.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
changed
2ddc9e6
to
f79da13
Compare
Could we duplicate some of it in the commit message? Google Docs don't tend to live for very long: the access can be pulled at any time, whereas the commit message is there forever. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, just a few small things left. Also please address the comment from George and copy the benchmark numbers into the PR description.
scatter_indices, sorted_scalar_indices, scatter, parent, num_indices); | ||
|
||
// Finally, recreate the scatter instruction with unique indices | ||
auto* new_scatter = parent->AddInstruction(HloInstruction::CreateScatter( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: directly return without assigning to new_scatter
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
changed
auto* new_scatter = parent->AddInstruction(HloInstruction::CreateScatter( | ||
scatter->shape(), scatter_operands, last_occurrence_indices, | ||
prefix_scan_updates, scatter->to_apply(), dim_numbers, | ||
true /*indices_are_sorted*/, true /*unique_indices*/)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: annotate /*indices_are_sorted=*/true
and /*unique_indices=*/true
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
changed
int64_t scatter_indices_count = ScatterIndicesCount(scatter); | ||
if (!IsInt32(scatter_indices_count)) { | ||
// 2147483647 is the maximum value for a 32-bit signed integer (INT32_MAX). | ||
return Unimplemented( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I wonder whether it would better to just not match in this case (so moving the check to InstructionMatchesPattern), so that we can still support it (although very slow)?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In the show version (scatter_expander.cc
), they have exactly the same check and will abort so I followed that here. Even if we do not match this here, it will still abort in the scatter_expander.cc
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah, that makes sense, thanks for the explanation.
f79da13
to
1a524ef
Compare
I squashed the commits into one and included the microbenchmark results there. |
Linter complains with:
|
Performance Takeaways: - Our optimized implementation shows significant speedup, especially with larger index sizes, achieving up to over 9,300x speedup in certain input and index sizes. - Our implementation has a slight slowdown compared to the non-deterministic scatter. For most cases, we have a slowdown around 1x - 4x. In the worst case with a rare index size setup, we have a slowdown factor of 9.15. Full Microbenchmark: | Input Size | Index Size | Non-Det | Original Det | New Det | Slowdown (vs Non-det) | Speedup (vs Original Det) | |------------|------------|---------|--------------|----------|-----------------------|---------------------------| | 10 | 10 | 3.96E-05| 7.82E-05 | 4.66E-05 | 1.18 | 1.68 | | 10 | 100 | 3.72E-05| 4.83E-04 | 9.73E-05 | 2.62 | 4.96 | | 10 | 1000 | 3.92E-05| 4.20E-03 | 6.62E-05 | 1.69 | 63.50 | | 10 | 10000 | 4.36E-05| 4.31E-02 | 1.21E-04 | 2.77 | 357.37 | | 10 | 100000 | 1.06E-04| 4.33E-01 | 1.71E-04 | 1.61 | 2536.56 | | 10 | 1000000 | 4.31E-04| 4.17E+00 | 4.45E-04 | 1.03 | 9372.37 | | 100 | 10 | 4.27E-05| 7.76E-05 | 4.71E-05 | 1.10 | 1.65 | | 100 | 100 | 4.01E-05| 4.91E-04 | 5.61E-05 | 1.40 | 8.75 | | 100 | 1000 | 5.17E-05| 4.21E-03 | 1.10E-04 | 2.13 | 38.24 | | 100 | 10000 | 4.08E-05| 4.27E-02 | 1.05E-04 | 2.57 | 407.45 | | 100 | 100000 | 7.60E-05| 4.14E-01 | 1.69E-04 | 2.22 | 2455.08 | | 100 | 1000000 | 2.86E-04| 4.17E+00 | 4.62E-04 | 1.62 | 9009.13 | | 1000 | 10 | 3.95E-05| 7.85E-05 | 4.97E-05 | 1.26 | 1.58 | | 1000 | 100 | 4.16E-05| 4.85E-04 | 5.27E-05 | 1.27 | 9.21 | | 1000 | 1000 | 3.90E-05| 4.25E-03 | 6.35E-05 | 1.63 | 66.86 | | 1000 | 10000 | 4.08E-05| 4.25E-02 | 1.22E-04 | 3.00 | 346.99 | | 1000 | 100000 | 4.26E-05| 4.15E-01 | 1.92E-04 | 4.51 | 2161.72 | | 1000 | 1000000 | 1.73E-04| 4.26E+00 | 4.75E-04 | 2.74 | 8964.91 | | 10000 | 10 | 4.17E-05| 8.00E-05 | 4.76E-05 | 1.14 | 1.68 | | 10000 | 100 | 3.68E-05| 7.16E-04 | 1.10E-04 | 3.00 | 6.49 | | 10000 | 1000 | 4.13E-05| 4.23E-03 | 1.01E-04 | 2.44 | 42.12 | | 10000 | 10000 | 3.71E-05| 4.23E-02 | 1.44E-04 | 3.89 | 293.14 | | 10000 | 100000 | 9.70E-05| 4.28E-01 | 1.72E-04 | 1.77 | 2494.21 | | 10000 | 1000000 | 1.18E-04| 4.17E+00 | 4.91E-04 | 4.15 | 8488.57 | | 100000 | 10 | 3.73E-05| 7.25E-05 | 4.92E-05 | 1.32 | 1.47 | | 100000 | 100 | 4.09E-05| 4.91E-04 | 6.33E-05 | 1.55 | 7.76 | | 100000 | 1000 | 4.10E-05| 4.25E-03 | 6.40E-05 | 1.56 | 66.39 | | 100000 | 10000 | 3.78E-05| 4.22E-02 | 1.26E-04 | 3.34 | 334.38 | | 100000 | 100000 | 4.42E-05| 4.16E-01 | 1.67E-04 | 3.79 | 2486.22 | | 100000 | 1000000 | 5.37E-05| 4.17E+00 | 4.92E-04 | 9.15 | 8474.51 | | 1000000 | 10 | 3.97E-05| 8.10E-05 | 5.12E-05 | 1.29 | 1.58 | | 1000000 | 100 | 4.56E-05| 4.94E-04 | 6.08E-05 | 1.33 | 8.13 | | 1000000 | 1000 | 4.47E-05| 4.29E-03 | 6.17E-05 | 1.38 | 69.44 | | 1000000 | 10000 | 4.48E-05| 4.27E-02 | 1.18E-04 | 2.63 | 362.68 | | 1000000 | 100000 | 4.25E-05| 4.19E-01 | 1.78E-04 | 4.19 | 2352.46 | | 1000000 | 1000000 | 6.59E-05| 4.18E+00 | 5.01E-04 | 7.60 | 8334.87 |
1a524ef
to
42cc615
Compare
Imported from GitHub PR #17886 This PR is the 1st step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in `xla/service/ScatterExpander.cc`. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR rewrites the scatter operation with scalar indices and updates, and leave the other scatter operations to be handled by original ScatterExpander. The 2nd PR to come will handle non-scalar indices and updates. The second PR is at #18326 Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit Bugs resolved: jax-ml/jax#17844 Copybara import of the project: -- 42cc615 by Chenhao Jiang <chenhaoj@nvidia.com>: Optimize deterministic scalar scatter Performance Takeaways: - Our optimized implementation shows significant speedup, especially with larger index sizes, achieving up to over 9,300x speedup in certain input and index sizes. - Our implementation has a slight slowdown compared to the non-deterministic scatter. For most cases, we have a slowdown around 1x - 4x. In the worst case with a rare index size setup, we have a slowdown factor of 9.15. Full Microbenchmark: | Input Size | Index Size | Non-Det | Original Det | New Det | Slowdown (vs Non-det) | Speedup (vs Original Det) | |------------|------------|---------|--------------|----------|-----------------------|---------------------------| | 10 | 10 | 3.96E-05| 7.82E-05 | 4.66E-05 | 1.18 | 1.68 | | 10 | 100 | 3.72E-05| 4.83E-04 | 9.73E-05 | 2.62 | 4.96 | | 10 | 1000 | 3.92E-05| 4.20E-03 | 6.62E-05 | 1.69 | 63.50 | | 10 | 10000 | 4.36E-05| 4.31E-02 | 1.21E-04 | 2.77 | 357.37 | | 10 | 100000 | 1.06E-04| 4.33E-01 | 1.71E-04 | 1.61 | 2536.56 | | 10 | 1000000 | 4.31E-04| 4.17E+00 | 4.45E-04 | 1.03 | 9372.37 | | 100 | 10 | 4.27E-05| 7.76E-05 | 4.71E-05 | 1.10 | 1.65 | | 100 | 100 | 4.01E-05| 4.91E-04 | 5.61E-05 | 1.40 | 8.75 | | 100 | 1000 | 5.17E-05| 4.21E-03 | 1.10E-04 | 2.13 | 38.24 | | 100 | 10000 | 4.08E-05| 4.27E-02 | 1.05E-04 | 2.57 | 407.45 | | 100 | 100000 | 7.60E-05| 4.14E-01 | 1.69E-04 | 2.22 | 2455.08 | | 100 | 1000000 | 2.86E-04| 4.17E+00 | 4.62E-04 | 1.62 | 9009.13 | | 1000 | 10 | 3.95E-05| 7.85E-05 | 4.97E-05 | 1.26 | 1.58 | | 1000 | 100 | 4.16E-05| 4.85E-04 | 5.27E-05 | 1.27 | 9.21 | | 1000 | 1000 | 3.90E-05| 4.25E-03 | 6.35E-05 | 1.63 | 66.86 | | 1000 | 10000 | 4.08E-05| 4.25E-02 | 1.22E-04 | 3.00 | 346.99 | | 1000 | 100000 | 4.26E-05| 4.15E-01 | 1.92E-04 | 4.51 | 2161.72 | | 1000 | 1000000 | 1.73E-04| 4.26E+00 | 4.75E-04 | 2.74 | 8964.91 | | 10000 | 10 | 4.17E-05| 8.00E-05 | 4.76E-05 | 1.14 | 1.68 | | 10000 | 100 | 3.68E-05| 7.16E-04 | 1.10E-04 | 3.00 | 6.49 | | 10000 | 1000 | 4.13E-05| 4.23E-03 | 1.01E-04 | 2.44 | 42.12 | | 10000 | 10000 | 3.71E-05| 4.23E-02 | 1.44E-04 | 3.89 | 293.14 | | 10000 | 100000 | 9.70E-05| 4.28E-01 | 1.72E-04 | 1.77 | 2494.21 | | 10000 | 1000000 | 1.18E-04| 4.17E+00 | 4.91E-04 | 4.15 | 8488.57 | | 100000 | 10 | 3.73E-05| 7.25E-05 | 4.92E-05 | 1.32 | 1.47 | | 100000 | 100 | 4.09E-05| 4.91E-04 | 6.33E-05 | 1.55 | 7.76 | | 100000 | 1000 | 4.10E-05| 4.25E-03 | 6.40E-05 | 1.56 | 66.39 | | 100000 | 10000 | 3.78E-05| 4.22E-02 | 1.26E-04 | 3.34 | 334.38 | | 100000 | 100000 | 4.42E-05| 4.16E-01 | 1.67E-04 | 3.79 | 2486.22 | | 100000 | 1000000 | 5.37E-05| 4.17E+00 | 4.92E-04 | 9.15 | 8474.51 | | 1000000 | 10 | 3.97E-05| 8.10E-05 | 5.12E-05 | 1.29 | 1.58 | | 1000000 | 100 | 4.56E-05| 4.94E-04 | 6.08E-05 | 1.33 | 8.13 | | 1000000 | 1000 | 4.47E-05| 4.29E-03 | 6.17E-05 | 1.38 | 69.44 | | 1000000 | 10000 | 4.48E-05| 4.27E-02 | 1.18E-04 | 2.63 | 362.68 | | 1000000 | 100000 | 4.25E-05| 4.19E-01 | 1.78E-04 | 4.19 | 2352.46 | | 1000000 | 1000000 | 6.59E-05| 4.18E+00 | 5.01E-04 | 7.60 | 8334.87 | Merging this change closes #17886 FUTURE_COPYBARA_INTEGRATE_REVIEW=#17886 from serach24:chenhao/opt_det_scatter_scalar 42cc615 PiperOrigin-RevId: 686779279
Imported from GitHub PR openxla/xla#17886 This PR is the 1st step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in `xla/service/ScatterExpander.cc`. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR rewrites the scatter operation with scalar indices and updates, and leave the other scatter operations to be handled by original ScatterExpander. The 2nd PR to come will handle non-scalar indices and updates. The second PR is at openxla/xla#18326 Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit Bugs resolved: jax-ml/jax#17844 Copybara import of the project: -- 42cc615ed047b28405a0634c42f741a678be605a by Chenhao Jiang <chenhaoj@nvidia.com>: Optimize deterministic scalar scatter Performance Takeaways: - Our optimized implementation shows significant speedup, especially with larger index sizes, achieving up to over 9,300x speedup in certain input and index sizes. - Our implementation has a slight slowdown compared to the non-deterministic scatter. For most cases, we have a slowdown around 1x - 4x. In the worst case with a rare index size setup, we have a slowdown factor of 9.15. Full Microbenchmark: | Input Size | Index Size | Non-Det | Original Det | New Det | Slowdown (vs Non-det) | Speedup (vs Original Det) | |------------|------------|---------|--------------|----------|-----------------------|---------------------------| | 10 | 10 | 3.96E-05| 7.82E-05 | 4.66E-05 | 1.18 | 1.68 | | 10 | 100 | 3.72E-05| 4.83E-04 | 9.73E-05 | 2.62 | 4.96 | | 10 | 1000 | 3.92E-05| 4.20E-03 | 6.62E-05 | 1.69 | 63.50 | | 10 | 10000 | 4.36E-05| 4.31E-02 | 1.21E-04 | 2.77 | 357.37 | | 10 | 100000 | 1.06E-04| 4.33E-01 | 1.71E-04 | 1.61 | 2536.56 | | 10 | 1000000 | 4.31E-04| 4.17E+00 | 4.45E-04 | 1.03 | 9372.37 | | 100 | 10 | 4.27E-05| 7.76E-05 | 4.71E-05 | 1.10 | 1.65 | | 100 | 100 | 4.01E-05| 4.91E-04 | 5.61E-05 | 1.40 | 8.75 | | 100 | 1000 | 5.17E-05| 4.21E-03 | 1.10E-04 | 2.13 | 38.24 | | 100 | 10000 | 4.08E-05| 4.27E-02 | 1.05E-04 | 2.57 | 407.45 | | 100 | 100000 | 7.60E-05| 4.14E-01 | 1.69E-04 | 2.22 | 2455.08 | | 100 | 1000000 | 2.86E-04| 4.17E+00 | 4.62E-04 | 1.62 | 9009.13 | | 1000 | 10 | 3.95E-05| 7.85E-05 | 4.97E-05 | 1.26 | 1.58 | | 1000 | 100 | 4.16E-05| 4.85E-04 | 5.27E-05 | 1.27 | 9.21 | | 1000 | 1000 | 3.90E-05| 4.25E-03 | 6.35E-05 | 1.63 | 66.86 | | 1000 | 10000 | 4.08E-05| 4.25E-02 | 1.22E-04 | 3.00 | 346.99 | | 1000 | 100000 | 4.26E-05| 4.15E-01 | 1.92E-04 | 4.51 | 2161.72 | | 1000 | 1000000 | 1.73E-04| 4.26E+00 | 4.75E-04 | 2.74 | 8964.91 | | 10000 | 10 | 4.17E-05| 8.00E-05 | 4.76E-05 | 1.14 | 1.68 | | 10000 | 100 | 3.68E-05| 7.16E-04 | 1.10E-04 | 3.00 | 6.49 | | 10000 | 1000 | 4.13E-05| 4.23E-03 | 1.01E-04 | 2.44 | 42.12 | | 10000 | 10000 | 3.71E-05| 4.23E-02 | 1.44E-04 | 3.89 | 293.14 | | 10000 | 100000 | 9.70E-05| 4.28E-01 | 1.72E-04 | 1.77 | 2494.21 | | 10000 | 1000000 | 1.18E-04| 4.17E+00 | 4.91E-04 | 4.15 | 8488.57 | | 100000 | 10 | 3.73E-05| 7.25E-05 | 4.92E-05 | 1.32 | 1.47 | | 100000 | 100 | 4.09E-05| 4.91E-04 | 6.33E-05 | 1.55 | 7.76 | | 100000 | 1000 | 4.10E-05| 4.25E-03 | 6.40E-05 | 1.56 | 66.39 | | 100000 | 10000 | 3.78E-05| 4.22E-02 | 1.26E-04 | 3.34 | 334.38 | | 100000 | 100000 | 4.42E-05| 4.16E-01 | 1.67E-04 | 3.79 | 2486.22 | | 100000 | 1000000 | 5.37E-05| 4.17E+00 | 4.92E-04 | 9.15 | 8474.51 | | 1000000 | 10 | 3.97E-05| 8.10E-05 | 5.12E-05 | 1.29 | 1.58 | | 1000000 | 100 | 4.56E-05| 4.94E-04 | 6.08E-05 | 1.33 | 8.13 | | 1000000 | 1000 | 4.47E-05| 4.29E-03 | 6.17E-05 | 1.38 | 69.44 | | 1000000 | 10000 | 4.48E-05| 4.27E-02 | 1.18E-04 | 2.63 | 362.68 | | 1000000 | 100000 | 4.25E-05| 4.19E-01 | 1.78E-04 | 4.19 | 2352.46 | | 1000000 | 1000000 | 6.59E-05| 4.18E+00 | 5.01E-04 | 7.60 | 8334.87 | Merging this change closes #17886 FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#17886 from serach24:chenhao/opt_det_scatter_scalar 42cc615ed047b28405a0634c42f741a678be605a PiperOrigin-RevId: 686779279
Imported from GitHub PR #17886 This PR is the 1st step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in `xla/service/ScatterExpander.cc`. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR rewrites the scatter operation with scalar indices and updates, and leave the other scatter operations to be handled by original ScatterExpander. The 2nd PR to come will handle non-scalar indices and updates. The second PR is at #18326 Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit Bugs resolved: jax-ml/jax#17844 Copybara import of the project: -- 42cc615 by Chenhao Jiang <chenhaoj@nvidia.com>: Optimize deterministic scalar scatter Performance Takeaways: - Our optimized implementation shows significant speedup, especially with larger index sizes, achieving up to over 9,300x speedup in certain input and index sizes. - Our implementation has a slight slowdown compared to the non-deterministic scatter. For most cases, we have a slowdown around 1x - 4x. In the worst case with a rare index size setup, we have a slowdown factor of 9.15. Full Microbenchmark: | Input Size | Index Size | Non-Det | Original Det | New Det | Slowdown (vs Non-det) | Speedup (vs Original Det) | |------------|------------|---------|--------------|----------|-----------------------|---------------------------| | 10 | 10 | 3.96E-05| 7.82E-05 | 4.66E-05 | 1.18 | 1.68 | | 10 | 100 | 3.72E-05| 4.83E-04 | 9.73E-05 | 2.62 | 4.96 | | 10 | 1000 | 3.92E-05| 4.20E-03 | 6.62E-05 | 1.69 | 63.50 | | 10 | 10000 | 4.36E-05| 4.31E-02 | 1.21E-04 | 2.77 | 357.37 | | 10 | 100000 | 1.06E-04| 4.33E-01 | 1.71E-04 | 1.61 | 2536.56 | | 10 | 1000000 | 4.31E-04| 4.17E+00 | 4.45E-04 | 1.03 | 9372.37 | | 100 | 10 | 4.27E-05| 7.76E-05 | 4.71E-05 | 1.10 | 1.65 | | 100 | 100 | 4.01E-05| 4.91E-04 | 5.61E-05 | 1.40 | 8.75 | | 100 | 1000 | 5.17E-05| 4.21E-03 | 1.10E-04 | 2.13 | 38.24 | | 100 | 10000 | 4.08E-05| 4.27E-02 | 1.05E-04 | 2.57 | 407.45 | | 100 | 100000 | 7.60E-05| 4.14E-01 | 1.69E-04 | 2.22 | 2455.08 | | 100 | 1000000 | 2.86E-04| 4.17E+00 | 4.62E-04 | 1.62 | 9009.13 | | 1000 | 10 | 3.95E-05| 7.85E-05 | 4.97E-05 | 1.26 | 1.58 | | 1000 | 100 | 4.16E-05| 4.85E-04 | 5.27E-05 | 1.27 | 9.21 | | 1000 | 1000 | 3.90E-05| 4.25E-03 | 6.35E-05 | 1.63 | 66.86 | | 1000 | 10000 | 4.08E-05| 4.25E-02 | 1.22E-04 | 3.00 | 346.99 | | 1000 | 100000 | 4.26E-05| 4.15E-01 | 1.92E-04 | 4.51 | 2161.72 | | 1000 | 1000000 | 1.73E-04| 4.26E+00 | 4.75E-04 | 2.74 | 8964.91 | | 10000 | 10 | 4.17E-05| 8.00E-05 | 4.76E-05 | 1.14 | 1.68 | | 10000 | 100 | 3.68E-05| 7.16E-04 | 1.10E-04 | 3.00 | 6.49 | | 10000 | 1000 | 4.13E-05| 4.23E-03 | 1.01E-04 | 2.44 | 42.12 | | 10000 | 10000 | 3.71E-05| 4.23E-02 | 1.44E-04 | 3.89 | 293.14 | | 10000 | 100000 | 9.70E-05| 4.28E-01 | 1.72E-04 | 1.77 | 2494.21 | | 10000 | 1000000 | 1.18E-04| 4.17E+00 | 4.91E-04 | 4.15 | 8488.57 | | 100000 | 10 | 3.73E-05| 7.25E-05 | 4.92E-05 | 1.32 | 1.47 | | 100000 | 100 | 4.09E-05| 4.91E-04 | 6.33E-05 | 1.55 | 7.76 | | 100000 | 1000 | 4.10E-05| 4.25E-03 | 6.40E-05 | 1.56 | 66.39 | | 100000 | 10000 | 3.78E-05| 4.22E-02 | 1.26E-04 | 3.34 | 334.38 | | 100000 | 100000 | 4.42E-05| 4.16E-01 | 1.67E-04 | 3.79 | 2486.22 | | 100000 | 1000000 | 5.37E-05| 4.17E+00 | 4.92E-04 | 9.15 | 8474.51 | | 1000000 | 10 | 3.97E-05| 8.10E-05 | 5.12E-05 | 1.29 | 1.58 | | 1000000 | 100 | 4.56E-05| 4.94E-04 | 6.08E-05 | 1.33 | 8.13 | | 1000000 | 1000 | 4.47E-05| 4.29E-03 | 6.17E-05 | 1.38 | 69.44 | | 1000000 | 10000 | 4.48E-05| 4.27E-02 | 1.18E-04 | 2.63 | 362.68 | | 1000000 | 100000 | 4.25E-05| 4.19E-01 | 1.78E-04 | 4.19 | 2352.46 | | 1000000 | 1000000 | 6.59E-05| 4.18E+00 | 5.01E-04 | 7.60 | 8334.87 | Merging this change closes #17886 FUTURE_COPYBARA_INTEGRATE_REVIEW=#17886 from serach24:chenhao/opt_det_scatter_scalar 42cc615 PiperOrigin-RevId: 686779279
Imported from GitHub PR #17886 This PR is the 1st step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in `xla/service/ScatterExpander.cc`. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR rewrites the scatter operation with scalar indices and updates, and leave the other scatter operations to be handled by original ScatterExpander. The 2nd PR to come will handle non-scalar indices and updates. The second PR is at #18326 Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit Bugs resolved: jax-ml/jax#17844 Copybara import of the project: -- 42cc615 by Chenhao Jiang <chenhaoj@nvidia.com>: Optimize deterministic scalar scatter Performance Takeaways: - Our optimized implementation shows significant speedup, especially with larger index sizes, achieving up to over 9,300x speedup in certain input and index sizes. - Our implementation has a slight slowdown compared to the non-deterministic scatter. For most cases, we have a slowdown around 1x - 4x. In the worst case with a rare index size setup, we have a slowdown factor of 9.15. Full Microbenchmark: | Input Size | Index Size | Non-Det | Original Det | New Det | Slowdown (vs Non-det) | Speedup (vs Original Det) | |------------|------------|---------|--------------|----------|-----------------------|---------------------------| | 10 | 10 | 3.96E-05| 7.82E-05 | 4.66E-05 | 1.18 | 1.68 | | 10 | 100 | 3.72E-05| 4.83E-04 | 9.73E-05 | 2.62 | 4.96 | | 10 | 1000 | 3.92E-05| 4.20E-03 | 6.62E-05 | 1.69 | 63.50 | | 10 | 10000 | 4.36E-05| 4.31E-02 | 1.21E-04 | 2.77 | 357.37 | | 10 | 100000 | 1.06E-04| 4.33E-01 | 1.71E-04 | 1.61 | 2536.56 | | 10 | 1000000 | 4.31E-04| 4.17E+00 | 4.45E-04 | 1.03 | 9372.37 | | 100 | 10 | 4.27E-05| 7.76E-05 | 4.71E-05 | 1.10 | 1.65 | | 100 | 100 | 4.01E-05| 4.91E-04 | 5.61E-05 | 1.40 | 8.75 | | 100 | 1000 | 5.17E-05| 4.21E-03 | 1.10E-04 | 2.13 | 38.24 | | 100 | 10000 | 4.08E-05| 4.27E-02 | 1.05E-04 | 2.57 | 407.45 | | 100 | 100000 | 7.60E-05| 4.14E-01 | 1.69E-04 | 2.22 | 2455.08 | | 100 | 1000000 | 2.86E-04| 4.17E+00 | 4.62E-04 | 1.62 | 9009.13 | | 1000 | 10 | 3.95E-05| 7.85E-05 | 4.97E-05 | 1.26 | 1.58 | | 1000 | 100 | 4.16E-05| 4.85E-04 | 5.27E-05 | 1.27 | 9.21 | | 1000 | 1000 | 3.90E-05| 4.25E-03 | 6.35E-05 | 1.63 | 66.86 | | 1000 | 10000 | 4.08E-05| 4.25E-02 | 1.22E-04 | 3.00 | 346.99 | | 1000 | 100000 | 4.26E-05| 4.15E-01 | 1.92E-04 | 4.51 | 2161.72 | | 1000 | 1000000 | 1.73E-04| 4.26E+00 | 4.75E-04 | 2.74 | 8964.91 | | 10000 | 10 | 4.17E-05| 8.00E-05 | 4.76E-05 | 1.14 | 1.68 | | 10000 | 100 | 3.68E-05| 7.16E-04 | 1.10E-04 | 3.00 | 6.49 | | 10000 | 1000 | 4.13E-05| 4.23E-03 | 1.01E-04 | 2.44 | 42.12 | | 10000 | 10000 | 3.71E-05| 4.23E-02 | 1.44E-04 | 3.89 | 293.14 | | 10000 | 100000 | 9.70E-05| 4.28E-01 | 1.72E-04 | 1.77 | 2494.21 | | 10000 | 1000000 | 1.18E-04| 4.17E+00 | 4.91E-04 | 4.15 | 8488.57 | | 100000 | 10 | 3.73E-05| 7.25E-05 | 4.92E-05 | 1.32 | 1.47 | | 100000 | 100 | 4.09E-05| 4.91E-04 | 6.33E-05 | 1.55 | 7.76 | | 100000 | 1000 | 4.10E-05| 4.25E-03 | 6.40E-05 | 1.56 | 66.39 | | 100000 | 10000 | 3.78E-05| 4.22E-02 | 1.26E-04 | 3.34 | 334.38 | | 100000 | 100000 | 4.42E-05| 4.16E-01 | 1.67E-04 | 3.79 | 2486.22 | | 100000 | 1000000 | 5.37E-05| 4.17E+00 | 4.92E-04 | 9.15 | 8474.51 | | 1000000 | 10 | 3.97E-05| 8.10E-05 | 5.12E-05 | 1.29 | 1.58 | | 1000000 | 100 | 4.56E-05| 4.94E-04 | 6.08E-05 | 1.33 | 8.13 | | 1000000 | 1000 | 4.47E-05| 4.29E-03 | 6.17E-05 | 1.38 | 69.44 | | 1000000 | 10000 | 4.48E-05| 4.27E-02 | 1.18E-04 | 2.63 | 362.68 | | 1000000 | 100000 | 4.25E-05| 4.19E-01 | 1.78E-04 | 4.19 | 2352.46 | | 1000000 | 1000000 | 6.59E-05| 4.18E+00 | 5.01E-04 | 7.60 | 8334.87 | Merging this change closes #17886 FUTURE_COPYBARA_INTEGRATE_REVIEW=#17886 from serach24:chenhao/opt_det_scatter_scalar 42cc615 PiperOrigin-RevId: 686779279
…r operations Imported from GitHub PR #18326 This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in xla/service/ScatterExpander.cc. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates. The change of this PR is on top of #17886 Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit Bugs resolved: jax-ml/jax#17844 Copybara import of the project: -- de647d4 by Chenhao Jiang <chenhaoj@nvidia.com>: Support scatter with non-scalar indices and updates Merging this change closes #18326 FUTURE_COPYBARA_INTEGRATE_REVIEW=#18326 from serach24:chenhao/opt_det_scatter_full de647d4 PiperOrigin-RevId: 690490783
…r operations Imported from GitHub PR #18326 This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in xla/service/ScatterExpander.cc. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates. The change of this PR is on top of #17886 Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit Bugs resolved: jax-ml/jax#17844 Copybara import of the project: -- de647d4 by Chenhao Jiang <chenhaoj@nvidia.com>: Support scatter with non-scalar indices and updates Merging this change closes #18326 COPYBARA_INTEGRATE_REVIEW=#18326 from serach24:chenhao/opt_det_scatter_full de647d4 PiperOrigin-RevId: 691023328
…r operations Imported from GitHub PR openxla/xla#18326 This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in xla/service/ScatterExpander.cc. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates. The change of this PR is on top of openxla/xla#17886 Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit Bugs resolved: jax-ml/jax#17844 Copybara import of the project: -- de647d44eb28af71e1580b6e8ed9adc751e50f52 by Chenhao Jiang <chenhaoj@nvidia.com>: Support scatter with non-scalar indices and updates Merging this change closes #18326 PiperOrigin-RevId: 691023328
… scatter operations Imported from GitHub PR openxla#18326 This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in xla/service/ScatterExpander.cc. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates. The change of this PR is on top of openxla#17886 Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit Bugs resolved: jax-ml/jax#17844 Copybara import of the project: -- de647d4 by Chenhao Jiang <chenhaoj@nvidia.com>: Support scatter with non-scalar indices and updates Merging this change closes openxla#18326 COPYBARA_INTEGRATE_REVIEW=openxla#18326 from serach24:chenhao/opt_det_scatter_full de647d4 PiperOrigin-RevId: 691023328
… scatter operations Imported from GitHub PR openxla#18326 This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in xla/service/ScatterExpander.cc. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates. The change of this PR is on top of openxla#17886 Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit Bugs resolved: jax-ml/jax#17844 Copybara import of the project: -- de647d4 by Chenhao Jiang <chenhaoj@nvidia.com>: Support scatter with non-scalar indices and updates Merging this change closes openxla#18326 COPYBARA_INTEGRATE_REVIEW=openxla#18326 from serach24:chenhao/opt_det_scatter_full de647d4 PiperOrigin-RevId: 691023328
… high-dimensional scatter operation and a flag to disable it Imported from GitHub PR #19275 This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in xla/service/ScatterExpander.cc. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates. The change of this PR is on top of #17886, and has fixed issues reported in the reverted PR #18326. The issue was that the changes in #18326 were not able to handle different kinds of complicated but realistic scatter dimension numbers. Specifically, this PR unifies the implementation of 1D and multi-dimensional scatter operation to make the code easier to maintain, adds multiple tests for various scatter dimension numbers, and thoroughly handles all cases of different kinds of dimension numbers. Moreover, this PR also adds an option `xla_gpu_enable_scatter_determinism_expander`, the default value of which is set to be true. This option could make sure that although unlikely, if anything happens with changes in this PR, the user can easily disable the `scatter_determinism_expander` pass without getting blocked. Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit Bugs resolved: jax-ml/jax#17844 Copybara import of the project: -- 3b7b56a by Chenhao Jiang <chenhaoj@nvidia.com>: PR #18326: [NVIDIA] Complete the optimization of deterministic scatter operations Imported from GitHub PR #18326 This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in xla/service/ScatterExpander.cc. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates. The change of this PR is on top of #17886 Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit Bugs resolved: jax-ml/jax#17844 Copybara import of the project: -- de647d4 by Chenhao Jiang <chenhaoj@nvidia.com>: Support scatter with non-scalar indices and updates Merging this change closes #18326 PiperOrigin-RevId: 691023328 -- 126c952 by Chenhao Jiang <chenhaoj@nvidia.com>: Add the scatter indices to operand space mapping and change the offset column-wise permutation based on scatter_dims_to_operand_dims, so that they can add together correctly. -- 1ecb608 by Chenhao Jiang <chenhaoj@nvidia.com>: Fix the scatter determinism expander for various dimension numbers -- 985079f by Chenhao Jiang <chenhaoj@nvidia.com>: Add a flag for enabling the scatter_determinism_expander on GPU. Merging this change closes #19275 FUTURE_COPYBARA_INTEGRATE_REVIEW=#18326 from serach24:chenhao/opt_det_scatter_full de647d4 PiperOrigin-RevId: 696078761
… high-dimensional scatter operation and a flag to disable it Imported from GitHub PR openxla/xla#19275 This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in xla/service/ScatterExpander.cc. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates. The change of this PR is on top of openxla/xla#17886, and has fixed issues reported in the reverted PR openxla/xla#18326. The issue was that the changes in openxla/xla#18326 were not able to handle different kinds of complicated but realistic scatter dimension numbers. Specifically, this PR unifies the implementation of 1D and multi-dimensional scatter operation to make the code easier to maintain, adds multiple tests for various scatter dimension numbers, and thoroughly handles all cases of different kinds of dimension numbers. Moreover, this PR also adds an option `xla_gpu_enable_scatter_determinism_expander`, the default value of which is set to be true. This option could make sure that although unlikely, if anything happens with changes in this PR, the user can easily disable the `scatter_determinism_expander` pass without getting blocked. Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit Bugs resolved: jax-ml/jax#17844 Copybara import of the project: -- 3b7b56a2b95e52654daf83a359d17a809dc3b784 by Chenhao Jiang <chenhaoj@nvidia.com>: PR #18326: [NVIDIA] Complete the optimization of deterministic scatter operations Imported from GitHub PR openxla/xla#18326 This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in xla/service/ScatterExpander.cc. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates. The change of this PR is on top of openxla/xla#17886 Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit Bugs resolved: jax-ml/jax#17844 Copybara import of the project: -- de647d44eb28af71e1580b6e8ed9adc751e50f52 by Chenhao Jiang <chenhaoj@nvidia.com>: Support scatter with non-scalar indices and updates Merging this change closes #18326 PiperOrigin-RevId: 691023328 -- 126c952d6ccd3a4c00e1885923cb0f8ba6db9cf2 by Chenhao Jiang <chenhaoj@nvidia.com>: Add the scatter indices to operand space mapping and change the offset column-wise permutation based on scatter_dims_to_operand_dims, so that they can add together correctly. -- 1ecb608e3687cda358965d9fb60144362fdba477 by Chenhao Jiang <chenhaoj@nvidia.com>: Fix the scatter determinism expander for various dimension numbers -- 985079f4257e632e85162b5525cfd4655ddf555d by Chenhao Jiang <chenhaoj@nvidia.com>: Add a flag for enabling the scatter_determinism_expander on GPU. Merging this change closes #19275 FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#18326 from serach24:chenhao/opt_det_scatter_full de647d44eb28af71e1580b6e8ed9adc751e50f52 PiperOrigin-RevId: 696078761
… high-dimensional scatter operation and a flag to disable it Imported from GitHub PR #19275 This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in xla/service/ScatterExpander.cc. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates. The change of this PR is on top of #17886, and has fixed issues reported in the reverted PR #18326. The issue was that the changes in #18326 were not able to handle different kinds of complicated but realistic scatter dimension numbers. Specifically, this PR unifies the implementation of 1D and multi-dimensional scatter operation to make the code easier to maintain, adds multiple tests for various scatter dimension numbers, and thoroughly handles all cases of different kinds of dimension numbers. Moreover, this PR also adds an option `xla_gpu_enable_scatter_determinism_expander`, the default value of which is set to be true. This option could make sure that although unlikely, if anything happens with changes in this PR, the user can easily disable the `scatter_determinism_expander` pass without getting blocked. Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit Bugs resolved: jax-ml/jax#17844 Copybara import of the project: -- 3b7b56a by Chenhao Jiang <chenhaoj@nvidia.com>: PR #18326: [NVIDIA] Complete the optimization of deterministic scatter operations Imported from GitHub PR #18326 This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in xla/service/ScatterExpander.cc. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates. The change of this PR is on top of #17886 Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit Bugs resolved: jax-ml/jax#17844 Copybara import of the project: -- de647d4 by Chenhao Jiang <chenhaoj@nvidia.com>: Support scatter with non-scalar indices and updates Merging this change closes #18326 PiperOrigin-RevId: 691023328 -- 126c952 by Chenhao Jiang <chenhaoj@nvidia.com>: Add the scatter indices to operand space mapping and change the offset column-wise permutation based on scatter_dims_to_operand_dims, so that they can add together correctly. -- 1ecb608 by Chenhao Jiang <chenhaoj@nvidia.com>: Fix the scatter determinism expander for various dimension numbers -- 985079f by Chenhao Jiang <chenhaoj@nvidia.com>: Add a flag for enabling the scatter_determinism_expander on GPU. Merging this change closes #19275 FUTURE_COPYBARA_INTEGRATE_REVIEW=#18326 from serach24:chenhao/opt_det_scatter_full de647d4 PiperOrigin-RevId: 696078761
… high-dimensional scatter operation and a flag to disable it Imported from GitHub PR #19275 This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in xla/service/ScatterExpander.cc. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates. The change of this PR is on top of #17886, and has fixed issues reported in the reverted PR #18326. The issue was that the changes in #18326 were not able to handle different kinds of complicated but realistic scatter dimension numbers. Specifically, this PR unifies the implementation of 1D and multi-dimensional scatter operation to make the code easier to maintain, adds multiple tests for various scatter dimension numbers, and thoroughly handles all cases of different kinds of dimension numbers. Moreover, this PR also adds an option `xla_gpu_enable_scatter_determinism_expander`, the default value of which is set to be true. This option could make sure that although unlikely, if anything happens with changes in this PR, the user can easily disable the `scatter_determinism_expander` pass without getting blocked. Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit Bugs resolved: jax-ml/jax#17844 Copybara import of the project: -- 3b7b56a by Chenhao Jiang <chenhaoj@nvidia.com>: PR #18326: [NVIDIA] Complete the optimization of deterministic scatter operations Imported from GitHub PR #18326 This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in xla/service/ScatterExpander.cc. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates. The change of this PR is on top of #17886 Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit Bugs resolved: jax-ml/jax#17844 Copybara import of the project: -- de647d4 by Chenhao Jiang <chenhaoj@nvidia.com>: Support scatter with non-scalar indices and updates Merging this change closes #18326 PiperOrigin-RevId: 691023328 -- 126c952 by Chenhao Jiang <chenhaoj@nvidia.com>: Add the scatter indices to operand space mapping and change the offset column-wise permutation based on scatter_dims_to_operand_dims, so that they can add together correctly. -- 1ecb608 by Chenhao Jiang <chenhaoj@nvidia.com>: Fix the scatter determinism expander for various dimension numbers -- 985079f by Chenhao Jiang <chenhaoj@nvidia.com>: Add a flag for enabling the scatter_determinism_expander on GPU. Merging this change closes #19275 FUTURE_COPYBARA_INTEGRATE_REVIEW=#18326 from serach24:chenhao/opt_det_scatter_full de647d4 PiperOrigin-RevId: 696078761
… high-dimensional scatter operation and a flag to disable it Imported from GitHub PR #19275 This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in xla/service/ScatterExpander.cc. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates. The change of this PR is on top of #17886, and has fixed issues reported in the reverted PR #18326. The issue was that the changes in #18326 were not able to handle different kinds of complicated but realistic scatter dimension numbers. Specifically, this PR unifies the implementation of 1D and multi-dimensional scatter operation to make the code easier to maintain, adds multiple tests for various scatter dimension numbers, and thoroughly handles all cases of different kinds of dimension numbers. Moreover, this PR also adds an option `xla_gpu_enable_scatter_determinism_expander`, the default value of which is set to be true. This option could make sure that although unlikely, if anything happens with changes in this PR, the user can easily disable the `scatter_determinism_expander` pass without getting blocked. Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit Bugs resolved: jax-ml/jax#17844 Copybara import of the project: -- 3b7b56a by Chenhao Jiang <chenhaoj@nvidia.com>: PR #18326: [NVIDIA] Complete the optimization of deterministic scatter operations Imported from GitHub PR #18326 This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in xla/service/ScatterExpander.cc. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates. The change of this PR is on top of #17886 Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit Bugs resolved: jax-ml/jax#17844 Copybara import of the project: -- de647d4 by Chenhao Jiang <chenhaoj@nvidia.com>: Support scatter with non-scalar indices and updates Merging this change closes #18326 PiperOrigin-RevId: 691023328 -- 126c952 by Chenhao Jiang <chenhaoj@nvidia.com>: Add the scatter indices to operand space mapping and change the offset column-wise permutation based on scatter_dims_to_operand_dims, so that they can add together correctly. -- 1ecb608 by Chenhao Jiang <chenhaoj@nvidia.com>: Fix the scatter determinism expander for various dimension numbers -- 985079f by Chenhao Jiang <chenhaoj@nvidia.com>: Add a flag for enabling the scatter_determinism_expander on GPU. Merging this change closes #19275 FUTURE_COPYBARA_INTEGRATE_REVIEW=#18326 from serach24:chenhao/opt_det_scatter_full de647d4 PiperOrigin-RevId: 696078761
… high-dimensional scatter operation and a flag to disable it Imported from GitHub PR #19275 This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in xla/service/ScatterExpander.cc. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates. The change of this PR is on top of #17886, and has fixed issues reported in the reverted PR #18326. The issue was that the changes in #18326 were not able to handle different kinds of complicated but realistic scatter dimension numbers. Specifically, this PR unifies the implementation of 1D and multi-dimensional scatter operation to make the code easier to maintain, adds multiple tests for various scatter dimension numbers, and thoroughly handles all cases of different kinds of dimension numbers. Moreover, this PR also adds an option `xla_gpu_enable_scatter_determinism_expander`, the default value of which is set to be true. This option could make sure that although unlikely, if anything happens with changes in this PR, the user can easily disable the `scatter_determinism_expander` pass without getting blocked. Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit Bugs resolved: jax-ml/jax#17844 Copybara import of the project: -- 3b7b56a by Chenhao Jiang <chenhaoj@nvidia.com>: PR #18326: [NVIDIA] Complete the optimization of deterministic scatter operations Imported from GitHub PR #18326 This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in xla/service/ScatterExpander.cc. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates. The change of this PR is on top of #17886 Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit Bugs resolved: jax-ml/jax#17844 Copybara import of the project: -- de647d4 by Chenhao Jiang <chenhaoj@nvidia.com>: Support scatter with non-scalar indices and updates Merging this change closes #18326 PiperOrigin-RevId: 691023328 -- 126c952 by Chenhao Jiang <chenhaoj@nvidia.com>: Add the scatter indices to operand space mapping and change the offset column-wise permutation based on scatter_dims_to_operand_dims, so that they can add together correctly. -- 1ecb608 by Chenhao Jiang <chenhaoj@nvidia.com>: Fix the scatter determinism expander for various dimension numbers -- 985079f by Chenhao Jiang <chenhaoj@nvidia.com>: Add a flag for enabling the scatter_determinism_expander on GPU. Merging this change closes #19275 FUTURE_COPYBARA_INTEGRATE_REVIEW=#18326 from serach24:chenhao/opt_det_scatter_full de647d4 PiperOrigin-RevId: 696078761
… high-dimensional scatter operation and a flag to disable it Imported from GitHub PR openxla/xla#19275 This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in xla/service/ScatterExpander.cc. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates. The change of this PR is on top of openxla/xla#17886, and has fixed issues reported in the reverted PR openxla/xla#18326. The issue was that the changes in openxla/xla#18326 were not able to handle different kinds of complicated but realistic scatter dimension numbers. Specifically, this PR unifies the implementation of 1D and multi-dimensional scatter operation to make the code easier to maintain, adds multiple tests for various scatter dimension numbers, and thoroughly handles all cases of different kinds of dimension numbers. Moreover, this PR also adds an option `xla_gpu_enable_scatter_determinism_expander`, the default value of which is set to be true. This option could make sure that although unlikely, if anything happens with changes in this PR, the user can easily disable the `scatter_determinism_expander` pass without getting blocked. Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit Bugs resolved: jax-ml/jax#17844 Copybara import of the project: -- 3b7b56a2b95e52654daf83a359d17a809dc3b784 by Chenhao Jiang <chenhaoj@nvidia.com>: PR #18326: [NVIDIA] Complete the optimization of deterministic scatter operations Imported from GitHub PR openxla/xla#18326 This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in xla/service/ScatterExpander.cc. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates. The change of this PR is on top of openxla/xla#17886 Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit Bugs resolved: jax-ml/jax#17844 Copybara import of the project: -- de647d44eb28af71e1580b6e8ed9adc751e50f52 by Chenhao Jiang <chenhaoj@nvidia.com>: Support scatter with non-scalar indices and updates Merging this change closes #18326 PiperOrigin-RevId: 691023328 -- 126c952d6ccd3a4c00e1885923cb0f8ba6db9cf2 by Chenhao Jiang <chenhaoj@nvidia.com>: Add the scatter indices to operand space mapping and change the offset column-wise permutation based on scatter_dims_to_operand_dims, so that they can add together correctly. -- 1ecb608e3687cda358965d9fb60144362fdba477 by Chenhao Jiang <chenhaoj@nvidia.com>: Fix the scatter determinism expander for various dimension numbers -- 985079f4257e632e85162b5525cfd4655ddf555d by Chenhao Jiang <chenhaoj@nvidia.com>: Add a flag for enabling the scatter_determinism_expander on GPU. Merging this change closes #19275 FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#18326 from serach24:chenhao/opt_det_scatter_full de647d44eb28af71e1580b6e8ed9adc751e50f52 PiperOrigin-RevId: 696078761
… high-dimensional scatter operation and a flag to disable it Imported from GitHub PR #19275 This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in xla/service/ScatterExpander.cc. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates. The change of this PR is on top of #17886, and has fixed issues reported in the reverted PR #18326. The issue was that the changes in #18326 were not able to handle different kinds of complicated but realistic scatter dimension numbers. Specifically, this PR unifies the implementation of 1D and multi-dimensional scatter operation to make the code easier to maintain, adds multiple tests for various scatter dimension numbers, and thoroughly handles all cases of different kinds of dimension numbers. Moreover, this PR also adds an option `xla_gpu_enable_scatter_determinism_expander`, the default value of which is set to be true. This option could make sure that although unlikely, if anything happens with changes in this PR, the user can easily disable the `scatter_determinism_expander` pass without getting blocked. Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit Bugs resolved: jax-ml/jax#17844 Copybara import of the project: -- 3b7b56a by Chenhao Jiang <chenhaoj@nvidia.com>: PR #18326: [NVIDIA] Complete the optimization of deterministic scatter operations Imported from GitHub PR #18326 This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in xla/service/ScatterExpander.cc. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates. The change of this PR is on top of #17886 Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit Bugs resolved: jax-ml/jax#17844 Copybara import of the project: -- de647d4 by Chenhao Jiang <chenhaoj@nvidia.com>: Support scatter with non-scalar indices and updates Merging this change closes #18326 PiperOrigin-RevId: 691023328 -- 126c952 by Chenhao Jiang <chenhaoj@nvidia.com>: Add the scatter indices to operand space mapping and change the offset column-wise permutation based on scatter_dims_to_operand_dims, so that they can add together correctly. -- 1ecb608 by Chenhao Jiang <chenhaoj@nvidia.com>: Fix the scatter determinism expander for various dimension numbers -- 985079f by Chenhao Jiang <chenhaoj@nvidia.com>: Add a flag for enabling the scatter_determinism_expander on GPU. Merging this change closes #19275 FUTURE_COPYBARA_INTEGRATE_REVIEW=#18326 from serach24:chenhao/opt_det_scatter_full de647d4 PiperOrigin-RevId: 696078761
… high-dimensional scatter operation and a flag to disable it Imported from GitHub PR #19275 This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in xla/service/ScatterExpander.cc. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates. The change of this PR is on top of #17886, and has fixed issues reported in the reverted PR #18326. The issue was that the changes in #18326 were not able to handle different kinds of complicated but realistic scatter dimension numbers. Specifically, this PR unifies the implementation of 1D and multi-dimensional scatter operation to make the code easier to maintain, adds multiple tests for various scatter dimension numbers, and thoroughly handles all cases of different kinds of dimension numbers. Moreover, this PR also adds an option `xla_gpu_enable_scatter_determinism_expander`, the default value of which is set to be true. This option could make sure that although unlikely, if anything happens with changes in this PR, the user can easily disable the `scatter_determinism_expander` pass without getting blocked. Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit Bugs resolved: jax-ml/jax#17844 Copybara import of the project: -- 3b7b56a by Chenhao Jiang <chenhaoj@nvidia.com>: PR #18326: [NVIDIA] Complete the optimization of deterministic scatter operations Imported from GitHub PR #18326 This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in xla/service/ScatterExpander.cc. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates. The change of this PR is on top of #17886 Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit Bugs resolved: jax-ml/jax#17844 Copybara import of the project: -- de647d4 by Chenhao Jiang <chenhaoj@nvidia.com>: Support scatter with non-scalar indices and updates Merging this change closes #18326 PiperOrigin-RevId: 691023328 -- 126c952 by Chenhao Jiang <chenhaoj@nvidia.com>: Add the scatter indices to operand space mapping and change the offset column-wise permutation based on scatter_dims_to_operand_dims, so that they can add together correctly. -- 1ecb608 by Chenhao Jiang <chenhaoj@nvidia.com>: Fix the scatter determinism expander for various dimension numbers -- 985079f by Chenhao Jiang <chenhaoj@nvidia.com>: Add a flag for enabling the scatter_determinism_expander on GPU. Merging this change closes #19275 FUTURE_COPYBARA_INTEGRATE_REVIEW=#18326 from serach24:chenhao/opt_det_scatter_full de647d4 PiperOrigin-RevId: 696078761
… high-dimensional scatter operation and a flag to disable it Imported from GitHub PR #19275 This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in xla/service/ScatterExpander.cc. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates. The change of this PR is on top of #17886, and has fixed issues reported in the reverted PR #18326. The issue was that the changes in #18326 were not able to handle different kinds of complicated but realistic scatter dimension numbers. Specifically, this PR unifies the implementation of 1D and multi-dimensional scatter operation to make the code easier to maintain, adds multiple tests for various scatter dimension numbers, and thoroughly handles all cases of different kinds of dimension numbers. Moreover, this PR also adds an option `xla_gpu_enable_scatter_determinism_expander`, the default value of which is set to be true. This option could make sure that although unlikely, if anything happens with changes in this PR, the user can easily disable the `scatter_determinism_expander` pass without getting blocked. Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit Bugs resolved: jax-ml/jax#17844 Copybara import of the project: -- 3b7b56a by Chenhao Jiang <chenhaoj@nvidia.com>: PR #18326: [NVIDIA] Complete the optimization of deterministic scatter operations Imported from GitHub PR #18326 This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in xla/service/ScatterExpander.cc. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates. The change of this PR is on top of #17886 Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit Bugs resolved: jax-ml/jax#17844 Copybara import of the project: -- de647d4 by Chenhao Jiang <chenhaoj@nvidia.com>: Support scatter with non-scalar indices and updates Merging this change closes #18326 PiperOrigin-RevId: 691023328 -- 126c952 by Chenhao Jiang <chenhaoj@nvidia.com>: Add the scatter indices to operand space mapping and change the offset column-wise permutation based on scatter_dims_to_operand_dims, so that they can add together correctly. -- 1ecb608 by Chenhao Jiang <chenhaoj@nvidia.com>: Fix the scatter determinism expander for various dimension numbers -- 985079f by Chenhao Jiang <chenhaoj@nvidia.com>: Add a flag for enabling the scatter_determinism_expander on GPU. Merging this change closes #19275 FUTURE_COPYBARA_INTEGRATE_REVIEW=#18326 from serach24:chenhao/opt_det_scatter_full de647d4 PiperOrigin-RevId: 696790875
… high-dimensional scatter operation and a flag to disable it Imported from GitHub PR openxla/xla#19275 This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in xla/service/ScatterExpander.cc. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates. The change of this PR is on top of openxla/xla#17886, and has fixed issues reported in the reverted PR openxla/xla#18326. The issue was that the changes in openxla/xla#18326 were not able to handle different kinds of complicated but realistic scatter dimension numbers. Specifically, this PR unifies the implementation of 1D and multi-dimensional scatter operation to make the code easier to maintain, adds multiple tests for various scatter dimension numbers, and thoroughly handles all cases of different kinds of dimension numbers. Moreover, this PR also adds an option `xla_gpu_enable_scatter_determinism_expander`, the default value of which is set to be true. This option could make sure that although unlikely, if anything happens with changes in this PR, the user can easily disable the `scatter_determinism_expander` pass without getting blocked. Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit Bugs resolved: jax-ml/jax#17844 Copybara import of the project: -- 3b7b56a2b95e52654daf83a359d17a809dc3b784 by Chenhao Jiang <chenhaoj@nvidia.com>: PR #18326: [NVIDIA] Complete the optimization of deterministic scatter operations Imported from GitHub PR openxla/xla#18326 This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in xla/service/ScatterExpander.cc. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates. The change of this PR is on top of openxla/xla#17886 Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit Bugs resolved: jax-ml/jax#17844 Copybara import of the project: -- de647d44eb28af71e1580b6e8ed9adc751e50f52 by Chenhao Jiang <chenhaoj@nvidia.com>: Support scatter with non-scalar indices and updates Merging this change closes #18326 PiperOrigin-RevId: 691023328 -- 126c952d6ccd3a4c00e1885923cb0f8ba6db9cf2 by Chenhao Jiang <chenhaoj@nvidia.com>: Add the scatter indices to operand space mapping and change the offset column-wise permutation based on scatter_dims_to_operand_dims, so that they can add together correctly. -- 1ecb608e3687cda358965d9fb60144362fdba477 by Chenhao Jiang <chenhaoj@nvidia.com>: Fix the scatter determinism expander for various dimension numbers -- 985079f4257e632e85162b5525cfd4655ddf555d by Chenhao Jiang <chenhaoj@nvidia.com>: Add a flag for enabling the scatter_determinism_expander on GPU. Merging this change closes #19275 FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#18326 from serach24:chenhao/opt_det_scatter_full de647d44eb28af71e1580b6e8ed9adc751e50f52 PiperOrigin-RevId: 696790875
… high-dimensional scatter operation and a flag to disable it Imported from GitHub PR #19275 This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in xla/service/ScatterExpander.cc. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates. The change of this PR is on top of #17886, and has fixed issues reported in the reverted PR #18326. The issue was that the changes in #18326 were not able to handle different kinds of complicated but realistic scatter dimension numbers. Specifically, this PR unifies the implementation of 1D and multi-dimensional scatter operation to make the code easier to maintain, adds multiple tests for various scatter dimension numbers, and thoroughly handles all cases of different kinds of dimension numbers. Moreover, this PR also adds an option `xla_gpu_enable_scatter_determinism_expander`, the default value of which is set to be true. This option could make sure that although unlikely, if anything happens with changes in this PR, the user can easily disable the `scatter_determinism_expander` pass without getting blocked. Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit Bugs resolved: jax-ml/jax#17844 Copybara import of the project: -- 3b7b56a by Chenhao Jiang <chenhaoj@nvidia.com>: PR #18326: [NVIDIA] Complete the optimization of deterministic scatter operations Imported from GitHub PR #18326 This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in xla/service/ScatterExpander.cc. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates. The change of this PR is on top of #17886 Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit Bugs resolved: jax-ml/jax#17844 Copybara import of the project: -- de647d4 by Chenhao Jiang <chenhaoj@nvidia.com>: Support scatter with non-scalar indices and updates Merging this change closes #18326 PiperOrigin-RevId: 691023328 -- 126c952 by Chenhao Jiang <chenhaoj@nvidia.com>: Add the scatter indices to operand space mapping and change the offset column-wise permutation based on scatter_dims_to_operand_dims, so that they can add together correctly. -- 1ecb608 by Chenhao Jiang <chenhaoj@nvidia.com>: Fix the scatter determinism expander for various dimension numbers -- 985079f by Chenhao Jiang <chenhaoj@nvidia.com>: Add a flag for enabling the scatter_determinism_expander on GPU. Merging this change closes #19275 FUTURE_COPYBARA_INTEGRATE_REVIEW=#18326 from serach24:chenhao/opt_det_scatter_full de647d4 PiperOrigin-RevId: 696790875
… high-dimensional scatter operation and a flag to disable it Imported from GitHub PR openxla/xla#19275 This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in xla/service/ScatterExpander.cc. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates. The change of this PR is on top of openxla/xla#17886, and has fixed issues reported in the reverted PR openxla/xla#18326. The issue was that the changes in openxla/xla#18326 were not able to handle different kinds of complicated but realistic scatter dimension numbers. Specifically, this PR unifies the implementation of 1D and multi-dimensional scatter operation to make the code easier to maintain, adds multiple tests for various scatter dimension numbers, and thoroughly handles all cases of different kinds of dimension numbers. Moreover, this PR also adds an option `xla_gpu_enable_scatter_determinism_expander`, the default value of which is set to be true. This option could make sure that although unlikely, if anything happens with changes in this PR, the user can easily disable the `scatter_determinism_expander` pass without getting blocked. Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit Bugs resolved: jax-ml/jax#17844 Copybara import of the project: -- 3b7b56a2b95e52654daf83a359d17a809dc3b784 by Chenhao Jiang <chenhaoj@nvidia.com>: PR #18326: [NVIDIA] Complete the optimization of deterministic scatter operations Imported from GitHub PR openxla/xla#18326 This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in xla/service/ScatterExpander.cc. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates. The change of this PR is on top of openxla/xla#17886 Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit Bugs resolved: jax-ml/jax#17844 Copybara import of the project: -- de647d44eb28af71e1580b6e8ed9adc751e50f52 by Chenhao Jiang <chenhaoj@nvidia.com>: Support scatter with non-scalar indices and updates Merging this change closes #18326 PiperOrigin-RevId: 691023328 -- 126c952d6ccd3a4c00e1885923cb0f8ba6db9cf2 by Chenhao Jiang <chenhaoj@nvidia.com>: Add the scatter indices to operand space mapping and change the offset column-wise permutation based on scatter_dims_to_operand_dims, so that they can add together correctly. -- 1ecb608e3687cda358965d9fb60144362fdba477 by Chenhao Jiang <chenhaoj@nvidia.com>: Fix the scatter determinism expander for various dimension numbers -- 985079f4257e632e85162b5525cfd4655ddf555d by Chenhao Jiang <chenhaoj@nvidia.com>: Add a flag for enabling the scatter_determinism_expander on GPU. Merging this change closes #19275 FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#18326 from serach24:chenhao/opt_det_scatter_full de647d44eb28af71e1580b6e8ed9adc751e50f52 PiperOrigin-RevId: 696790875
… scatter operations Imported from GitHub PR openxla#18326 This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in xla/service/ScatterExpander.cc. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates. The change of this PR is on top of openxla#17886 Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit Bugs resolved: jax-ml/jax#17844 Copybara import of the project: -- de647d4 by Chenhao Jiang <chenhaoj@nvidia.com>: Support scatter with non-scalar indices and updates Merging this change closes openxla#18326 COPYBARA_INTEGRATE_REVIEW=openxla#18326 from serach24:chenhao/opt_det_scatter_full de647d4 PiperOrigin-RevId: 691023328
… high-dimensional scatter operation and a flag to disable it Imported from GitHub PR #19275 This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in xla/service/ScatterExpander.cc. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates. The change of this PR is on top of #17886, and has fixed issues reported in the reverted PR #18326. The issue was that the changes in #18326 were not able to handle different kinds of complicated but realistic scatter dimension numbers. Specifically, this PR unifies the implementation of 1D and multi-dimensional scatter operation to make the code easier to maintain, adds multiple tests for various scatter dimension numbers, and thoroughly handles all cases of different kinds of dimension numbers. Moreover, this PR also adds an option `xla_gpu_enable_scatter_determinism_expander`, the default value of which is set to be true. This option could make sure that although unlikely, if anything happens with changes in this PR, the user can easily disable the `scatter_determinism_expander` pass without getting blocked. Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit Bugs resolved: jax-ml/jax#17844 Copybara import of the project: -- b016044 by Chenhao Jiang <chenhaoj@nvidia.com>: PR #18326: [NVIDIA] Complete the optimization of deterministic scatter operations Imported from GitHub PR #18326 This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in xla/service/ScatterExpander.cc. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates. The change of this PR is on top of #17886 Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit Bugs resolved: jax-ml/jax#17844 Copybara import of the project: -- de647d4 by Chenhao Jiang <chenhaoj@nvidia.com>: Support scatter with non-scalar indices and updates Merging this change closes #18326 PiperOrigin-RevId: 691023328 -- fbdb066 by Chenhao Jiang <chenhaoj@nvidia.com>: Add the scatter indices to operand space mapping and change the offset column-wise permutation based on scatter_dims_to_operand_dims, so that they can add together correctly. -- d36c8ac by Chenhao Jiang <chenhaoj@nvidia.com>: Fix the scatter determinism expander for various dimension numbers -- 678886f by Chenhao Jiang <chenhaoj@nvidia.com>: Add a flag for enabling the scatter_determinism_expander on GPU. Merging this change closes #19275 FUTURE_COPYBARA_INTEGRATE_REVIEW=#18326 from serach24:chenhao/opt_det_scatter_full de647d4 PiperOrigin-RevId: 696078761
… high-dimensional scatter operation and a flag to disable it Imported from GitHub PR openxla/xla#19275 This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in xla/service/ScatterExpander.cc. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates. The change of this PR is on top of openxla/xla#17886, and has fixed issues reported in the reverted PR openxla/xla#18326. The issue was that the changes in openxla/xla#18326 were not able to handle different kinds of complicated but realistic scatter dimension numbers. Specifically, this PR unifies the implementation of 1D and multi-dimensional scatter operation to make the code easier to maintain, adds multiple tests for various scatter dimension numbers, and thoroughly handles all cases of different kinds of dimension numbers. Moreover, this PR also adds an option `xla_gpu_enable_scatter_determinism_expander`, the default value of which is set to be true. This option could make sure that although unlikely, if anything happens with changes in this PR, the user can easily disable the `scatter_determinism_expander` pass without getting blocked. Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit Bugs resolved: jax-ml/jax#17844 Copybara import of the project: -- b01604490908fbe43685aed7178d0a66602b7a8c by Chenhao Jiang <chenhaoj@nvidia.com>: PR #18326: [NVIDIA] Complete the optimization of deterministic scatter operations Imported from GitHub PR openxla/xla#18326 This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in xla/service/ScatterExpander.cc. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates. The change of this PR is on top of openxla/xla#17886 Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit Bugs resolved: jax-ml/jax#17844 Copybara import of the project: -- de647d44eb28af71e1580b6e8ed9adc751e50f52 by Chenhao Jiang <chenhaoj@nvidia.com>: Support scatter with non-scalar indices and updates Merging this change closes #18326 PiperOrigin-RevId: 691023328 -- fbdb066fd38a2fadb4322caaabe8c8d1a9fa77e3 by Chenhao Jiang <chenhaoj@nvidia.com>: Add the scatter indices to operand space mapping and change the offset column-wise permutation based on scatter_dims_to_operand_dims, so that they can add together correctly. -- d36c8ac7260c241c4ca6ed7dc16018f8030c0b80 by Chenhao Jiang <chenhaoj@nvidia.com>: Fix the scatter determinism expander for various dimension numbers -- 678886f97bd133c4ffa2fbf0365e15c808383a6f by Chenhao Jiang <chenhaoj@nvidia.com>: Add a flag for enabling the scatter_determinism_expander on GPU. Merging this change closes #19275 FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#18326 from serach24:chenhao/opt_det_scatter_full de647d44eb28af71e1580b6e8ed9adc751e50f52 PiperOrigin-RevId: 696078761
… high-dimensional scatter operation and a flag to disable it Imported from GitHub PR #19275 This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in xla/service/ScatterExpander.cc. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates. The change of this PR is on top of #17886, and has fixed issues reported in the reverted PR #18326. The issue was that the changes in #18326 were not able to handle different kinds of complicated but realistic scatter dimension numbers. Specifically, this PR unifies the implementation of 1D and multi-dimensional scatter operation to make the code easier to maintain, adds multiple tests for various scatter dimension numbers, and thoroughly handles all cases of different kinds of dimension numbers. Moreover, this PR also adds an option `xla_gpu_enable_scatter_determinism_expander`, the default value of which is set to be true. This option could make sure that although unlikely, if anything happens with changes in this PR, the user can easily disable the `scatter_determinism_expander` pass without getting blocked. Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit Bugs resolved: jax-ml/jax#17844 Copybara import of the project: -- 3b7b56a by Chenhao Jiang <chenhaoj@nvidia.com>: PR #18326: [NVIDIA] Complete the optimization of deterministic scatter operations Imported from GitHub PR #18326 This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in xla/service/ScatterExpander.cc. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates. The change of this PR is on top of #17886 Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit Bugs resolved: jax-ml/jax#17844 Copybara import of the project: -- de647d4 by Chenhao Jiang <chenhaoj@nvidia.com>: Support scatter with non-scalar indices and updates Merging this change closes #18326 PiperOrigin-RevId: 691023328 -- 126c952 by Chenhao Jiang <chenhaoj@nvidia.com>: Add the scatter indices to operand space mapping and change the offset column-wise permutation based on scatter_dims_to_operand_dims, so that they can add together correctly. -- 1ecb608 by Chenhao Jiang <chenhaoj@nvidia.com>: Fix the scatter determinism expander for various dimension numbers -- 985079f by Chenhao Jiang <chenhaoj@nvidia.com>: Add a flag for enabling the scatter_determinism_expander on GPU. Merging this change closes #19275 FUTURE_COPYBARA_INTEGRATE_REVIEW=#18326 from serach24:chenhao/opt_det_scatter_full de647d4 PiperOrigin-RevId: 696790875
… high-dimensional scatter operation and a flag to disable it Imported from GitHub PR openxla/xla#19275 This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in xla/service/ScatterExpander.cc. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates. The change of this PR is on top of openxla/xla#17886, and has fixed issues reported in the reverted PR openxla/xla#18326. The issue was that the changes in openxla/xla#18326 were not able to handle different kinds of complicated but realistic scatter dimension numbers. Specifically, this PR unifies the implementation of 1D and multi-dimensional scatter operation to make the code easier to maintain, adds multiple tests for various scatter dimension numbers, and thoroughly handles all cases of different kinds of dimension numbers. Moreover, this PR also adds an option `xla_gpu_enable_scatter_determinism_expander`, the default value of which is set to be true. This option could make sure that although unlikely, if anything happens with changes in this PR, the user can easily disable the `scatter_determinism_expander` pass without getting blocked. Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit Bugs resolved: jax-ml/jax#17844 Copybara import of the project: -- 3b7b56a2b95e52654daf83a359d17a809dc3b784 by Chenhao Jiang <chenhaoj@nvidia.com>: PR #18326: [NVIDIA] Complete the optimization of deterministic scatter operations Imported from GitHub PR openxla/xla#18326 This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in xla/service/ScatterExpander.cc. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates. The change of this PR is on top of openxla/xla#17886 Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit Bugs resolved: jax-ml/jax#17844 Copybara import of the project: -- de647d44eb28af71e1580b6e8ed9adc751e50f52 by Chenhao Jiang <chenhaoj@nvidia.com>: Support scatter with non-scalar indices and updates Merging this change closes #18326 PiperOrigin-RevId: 691023328 -- 126c952d6ccd3a4c00e1885923cb0f8ba6db9cf2 by Chenhao Jiang <chenhaoj@nvidia.com>: Add the scatter indices to operand space mapping and change the offset column-wise permutation based on scatter_dims_to_operand_dims, so that they can add together correctly. -- 1ecb608e3687cda358965d9fb60144362fdba477 by Chenhao Jiang <chenhaoj@nvidia.com>: Fix the scatter determinism expander for various dimension numbers -- 985079f4257e632e85162b5525cfd4655ddf555d by Chenhao Jiang <chenhaoj@nvidia.com>: Add a flag for enabling the scatter_determinism_expander on GPU. Merging this change closes #19275 FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#18326 from serach24:chenhao/opt_det_scatter_full de647d44eb28af71e1580b6e8ed9adc751e50f52 PiperOrigin-RevId: 696790875
… high-dimensional scatter operation and a flag to disable it Imported from GitHub PR #19275 This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in xla/service/ScatterExpander.cc. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates. The change of this PR is on top of #17886, and has fixed issues reported in the reverted PR #18326. The issue was that the changes in #18326 were not able to handle different kinds of complicated but realistic scatter dimension numbers. Specifically, this PR unifies the implementation of 1D and multi-dimensional scatter operation to make the code easier to maintain, adds multiple tests for various scatter dimension numbers, and thoroughly handles all cases of different kinds of dimension numbers. Moreover, this PR also adds an option `xla_gpu_enable_scatter_determinism_expander`, the default value of which is set to be true. This option could make sure that although unlikely, if anything happens with changes in this PR, the user can easily disable the `scatter_determinism_expander` pass without getting blocked. Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit Bugs resolved: jax-ml/jax#17844 Copybara import of the project: -- 3b7b56a by Chenhao Jiang <chenhaoj@nvidia.com>: PR #18326: [NVIDIA] Complete the optimization of deterministic scatter operations Imported from GitHub PR #18326 This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in xla/service/ScatterExpander.cc. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates. The change of this PR is on top of #17886 Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit Bugs resolved: jax-ml/jax#17844 Copybara import of the project: -- de647d4 by Chenhao Jiang <chenhaoj@nvidia.com>: Support scatter with non-scalar indices and updates Merging this change closes #18326 PiperOrigin-RevId: 691023328 -- 126c952 by Chenhao Jiang <chenhaoj@nvidia.com>: Add the scatter indices to operand space mapping and change the offset column-wise permutation based on scatter_dims_to_operand_dims, so that they can add together correctly. -- 1ecb608 by Chenhao Jiang <chenhaoj@nvidia.com>: Fix the scatter determinism expander for various dimension numbers -- 985079f by Chenhao Jiang <chenhaoj@nvidia.com>: Add a flag for enabling the scatter_determinism_expander on GPU. Merging this change closes #19275 COPYBARA_INTEGRATE_REVIEW=#18326 from serach24:chenhao/opt_det_scatter_full de647d4 PiperOrigin-RevId: 696956113
… high-dimensional scatter operation and a flag to disable it Imported from GitHub PR openxla/xla#19275 This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in xla/service/ScatterExpander.cc. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates. The change of this PR is on top of openxla/xla#17886, and has fixed issues reported in the reverted PR openxla/xla#18326. The issue was that the changes in openxla/xla#18326 were not able to handle different kinds of complicated but realistic scatter dimension numbers. Specifically, this PR unifies the implementation of 1D and multi-dimensional scatter operation to make the code easier to maintain, adds multiple tests for various scatter dimension numbers, and thoroughly handles all cases of different kinds of dimension numbers. Moreover, this PR also adds an option `xla_gpu_enable_scatter_determinism_expander`, the default value of which is set to be true. This option could make sure that although unlikely, if anything happens with changes in this PR, the user can easily disable the `scatter_determinism_expander` pass without getting blocked. Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit Bugs resolved: jax-ml/jax#17844 Copybara import of the project: -- 3b7b56a2b95e52654daf83a359d17a809dc3b784 by Chenhao Jiang <chenhaoj@nvidia.com>: PR #18326: [NVIDIA] Complete the optimization of deterministic scatter operations Imported from GitHub PR openxla/xla#18326 This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in xla/service/ScatterExpander.cc. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates. The change of this PR is on top of openxla/xla#17886 Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit Bugs resolved: jax-ml/jax#17844 Copybara import of the project: -- de647d44eb28af71e1580b6e8ed9adc751e50f52 by Chenhao Jiang <chenhaoj@nvidia.com>: Support scatter with non-scalar indices and updates Merging this change closes #18326 PiperOrigin-RevId: 691023328 -- 126c952d6ccd3a4c00e1885923cb0f8ba6db9cf2 by Chenhao Jiang <chenhaoj@nvidia.com>: Add the scatter indices to operand space mapping and change the offset column-wise permutation based on scatter_dims_to_operand_dims, so that they can add together correctly. -- 1ecb608e3687cda358965d9fb60144362fdba477 by Chenhao Jiang <chenhaoj@nvidia.com>: Fix the scatter determinism expander for various dimension numbers -- 985079f4257e632e85162b5525cfd4655ddf555d by Chenhao Jiang <chenhaoj@nvidia.com>: Add a flag for enabling the scatter_determinism_expander on GPU. Merging this change closes #19275 PiperOrigin-RevId: 696956113
This PR is the 1st step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in
xla/service/ScatterExpander.cc
. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR rewrites the scatter operation with scalar indices and updates, and leave the other scatter operations to be handled by original ScatterExpander. The 2nd PR to come will handle non-scalar indices and updates.The second PR is at #18326
Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit
Bugs resolved: jax-ml/jax#17844