[NVIDIA] Optimize deterministic scalar scatter #17886

Closed
53 changes: 53 additions & 0 deletions xla/service/BUILD
@@ -2554,6 +2554,18 @@ cc_library(
],
)


cc_library(
name = "scatter_utils",
srcs = ["scatter_utils.cc"],
hdrs = ["scatter_utils.h"],
deps = [
":call_inliner",
":hlo_creation_utils",
],
)


cc_library(
name = "scatter_expander",
srcs = ["scatter_expander.cc"],
@@ -2562,6 +2574,7 @@ cc_library(
":call_inliner",
":hlo_creation_utils",
":op_expander_pass",
":scatter_utils",
":while_util",
"//xla:literal_util",
"//xla/hlo/ir:hlo",
@@ -2570,6 +2583,25 @@ cc_library(
],
)


cc_library(
name = "scatter_determinism_expander",
srcs = ["scatter_determinism_expander.cc"],
hdrs = ["scatter_determinism_expander.h"],
deps = [
":call_inliner",
":hlo_creation_utils",
":op_expander_pass",
":scatter_utils",
":while_util",
"//xla:literal_util",
"//xla/hlo/ir:hlo",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_set",
"@tsl//tsl/platform:statusor",
],
)

xla_cc_test(
name = "scatter_expander_test",
srcs = ["scatter_expander_test.cc"],
@@ -2587,6 +2619,27 @@ xla_cc_test(
],
)

xla_test(
name = "scatter_determinism_expander_test",
srcs = ["scatter_determinism_expander_test.cc"],
backends = [
"cpu",
"gpu",
],
deps = [
":scatter_determinism_expander",
"//xla:literal",
"//xla:shape_util",
"//xla:test",
"//xla:types",
"//xla/hlo/ir:hlo",
"//xla/hlo/utils:hlo_matchers",
"//xla/tests:hlo_test_base",
"//xla/tests:xla_internal_test_main",
"//xla/tsl/lib/core:status_test_util",
],
)

cc_library(
name = "triangular_solve_expander",
srcs = ["triangular_solve_expander.cc"],
1 change: 1 addition & 0 deletions xla/service/gpu/BUILD
@@ -1605,6 +1605,7 @@ cc_library(
"//xla/service:result_caster",
"//xla/service:rng_bit_generator_expander",
"//xla/service:rng_expander",
"//xla/service:scatter_determinism_expander",
"//xla/service:scatter_expander",
"//xla/service:scatter_simplifier",
"//xla/service:sharding_remover",
2 changes: 2 additions & 0 deletions xla/service/gpu/gpu_compiler.cc
@@ -220,6 +220,7 @@ limitations under the License.
#include "xla/service/rng_bit_generator_expander.h"
#include "xla/service/rng_expander.h"
#include "xla/service/scatter_expander.h"
#include "xla/service/scatter_determinism_expander.h"
#include "xla/service/scatter_simplifier.h"
#include "xla/service/sharding_remover.h"
#include "xla/service/simplify_fp_conversions.h"
@@ -700,6 +701,7 @@ absl::Status RunOptimizationPasses(
if (RequireDeterminism(hlo_module->config())) {
// Scatter can be indeterministic if indices are not unique or a non
// associative combiner function is used. Eliminate these Scatter ops.
pipeline.AddPass<ScatterDeterminismExpander>();
Member

I see that you do several canonicalizations that may already be done with GpuScatterExpander. I think it would make sense to move GpuScatterExpander first, check which canonicalizations it already applies, and avoid duplicating those.

Member

Sorry, I mixed up GpuScatterExpander and ScatterSimplifier. ScatterSimplifier does a number of simplifications that seem related to what you do. The normalized scatter has this form (copied from the comment in scatter_simplifier.h):

// The output scatter's attributes will have the following characteristics:
// - scatter_indices is a two-dimensional tensor
// - index_vector_dim is 1
// - inserted_window_dims is []
// - update_window_dims is [0, 1, ...]
// - scatter_dims_to_operand_dims is [0, 1, ...]
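
The properties listed above can be read as a simple predicate. The following is only a sketch of that reading: `ScatterAttrs`, `IsIota`, and `IsNormalized` are hypothetical names invented for illustration, not XLA types or functions, and the check follows the quoted comment literally rather than the actual ScatterSimplifier implementation.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for a scatter's attributes; not XLA's real types.
struct ScatterAttrs {
  std::size_t scatter_indices_rank;
  std::int64_t index_vector_dim;
  std::vector<std::int64_t> inserted_window_dims;
  std::vector<std::int64_t> update_window_dims;
  std::vector<std::int64_t> scatter_dims_to_operand_dims;
};

// Returns true if dims is exactly [0, 1, ..., dims.size() - 1].
bool IsIota(const std::vector<std::int64_t>& dims) {
  for (std::size_t i = 0; i < dims.size(); ++i) {
    if (dims[i] != static_cast<std::int64_t>(i)) return false;
  }
  return true;
}

// Checks the five characteristics from the quoted comment, read literally.
bool IsNormalized(const ScatterAttrs& s) {
  return s.scatter_indices_rank == 2 && s.index_vector_dim == 1 &&
         s.inserted_window_dims.empty() && IsIota(s.update_window_dims) &&
         IsIota(s.scatter_dims_to_operand_dims);
}
```

A scatter with a non-empty inserted_window_dims or a permuted scatter_dims_to_operand_dims would fail this check, which is the kind of case the simplifier is described as rewriting away.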

Contributor Author

Yes, the canonicalization in ScatterSimplifier overlaps with the one in ScatterExpander, as the comment in scatter_simplifier.h notes just above the part you pasted:

// It implements the first two steps of the algorithm described in
// ScatterExpander::ExpandInstruction (scatter_expander.cc). Additionally, it
// transposes updates and operands to transform scatter_dims_to_operand_dims
// into the identity mapping. This is different from the algorithm in
// ScatterExpander, which instead applies the mapping in scatter_indices.

I followed the exact same canonicalization as ScatterExpander, which is why I extracted those functions into scatter_utils.cc so that both ScatterExpander and ScatterDeterminismExpander can reuse them. In gpu_compiler.cc I followed the same convention and added the ScatterDeterminismExpander pass before the ScatterExpander pass that runs in kEliminateIndeterministicScatters mode.

Member

Ok, I understand that it would be a bit harder to rewrite your pass based on the different canonicalization used in ScatterSimplifier. I guess there is still some potential to combine these canonicalizations, but it is somewhat orthogonal to your change.

pipeline.AddPass<ScatterExpander>(
ScatterExpander::kEliminateIndeterministicScatters);
}
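
The code comment in this hunk says a scatter can be non-deterministic when indices repeat and the combiner is not associative. A minimal standalone sketch of that effect, assuming a float-add combiner: `ScatterAdd` is a hypothetical helper written for illustration, not an XLA API, and `visit_order` stands in for whatever order a parallel backend happens to apply the updates in.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical helper (not an XLA API): scatter-add `updates` into a copy of
// `operand`, visiting the updates in the order given by `visit_order`.
std::vector<float> ScatterAdd(std::vector<float> operand,
                              const std::vector<std::size_t>& indices,
                              const std::vector<float>& updates,
                              const std::vector<std::size_t>& visit_order) {
  for (std::size_t k : visit_order) {
    // Float addition is not associative, so when several updates hit the
    // same slot the final value depends on the visiting order.
    operand[indices[k]] += updates[k];
  }
  return operand;
}
```

With operand {0.0f}, indices {0, 0, 0}, and updates {1e20f, -1e20f, 1.0f}, visiting in order 0, 1, 2 leaves 1.0f in slot 0, while visiting in order 1, 2, 0 leaves 0.0f, because 1.0f is absorbed when added to -1e20f. Fixing the application order is one way to make such a scatter reproducible.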