Skip to content

Commit

Permalink
Add a flag for enabling the scatter_determinism_expander on GPU.
Browse files Browse the repository at this point in the history
  • Loading branch information
serach24 committed Nov 15, 2024
1 parent d36c8ac commit 678886f
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 2 deletions.
11 changes: 11 additions & 0 deletions xla/debug_options_flags.cc
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() {
opts.set_xla_enable_fast_math(false);
opts.set_xla_gpu_experimental_parallel_collective_overlap_limit(1);
opts.set_xla_pjrt_allow_auto_layout_in_hlo(false);
opts.set_xla_gpu_enable_scatter_determinism_expander(true);
return opts;
}

Expand Down Expand Up @@ -2064,6 +2065,16 @@ void MakeDebugOptionsFlags(std::vector<tsl::Flag>* flag_list,
"Experimental: Make unset entry computation layout mean auto layout "
"instead of default layout in HLO when run through PjRT. In other cases "
"(StableHLO or non-PjRT) the auto layout is already used."));
// Registers the --xla_gpu_enable_scatter_determinism_expander boolean flag.
// The value currently held in `debug_options` is passed as the flag's default.
// Adjacent string literals are concatenated verbatim, so every fragment of
// the help text must carry its own trailing space.
flag_list->push_back(tsl::Flag(
"xla_gpu_enable_scatter_determinism_expander",
bool_setter_for(
&DebugOptions::set_xla_gpu_enable_scatter_determinism_expander),
debug_options->xla_gpu_enable_scatter_determinism_expander(),
"Enable the scatter determinism expander, an optimized pass that "
"rewrites scatter operations to ensure deterministic behavior with high "
"performance. "
"Note that even when this flag is disabled, scatter operations may still "
"be deterministic, although with additional overhead."));
} // NOLINT(readability/fn_size)

// Allocates flag_values and flag_objects; this function must not be called more
Expand Down
4 changes: 3 additions & 1 deletion xla/service/gpu/gpu_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -711,7 +711,9 @@ absl::Status RunOptimizationPasses(
if (RequireDeterminism(hlo_module->config())) {
// Scatter can be non-deterministic if indices are not unique or a
// non-associative combiner function is used. Eliminate these Scatter ops.
// ScatterDeterminismExpander is only added when the
// xla_gpu_enable_scatter_determinism_expander debug option is set, and it
// is added ahead of ScatterExpander, which is always added in this branch.
if (debug_options.xla_gpu_enable_scatter_determinism_expander()) {
pipeline.AddPass<ScatterDeterminismExpander>();
}
pipeline.AddPass<ScatterExpander>(
ScatterExpander::kEliminateIndeterministicScatters);
}
Expand Down
9 changes: 8 additions & 1 deletion xla/xla.proto
Original file line number Diff line number Diff line change
Expand Up @@ -1044,7 +1044,14 @@ message DebugOptions {

bool xla_pjrt_allow_auto_layout_in_hlo = 344;

// Next id: 345
// Enable the scatter determinism expander, an optimized pass that
// rewrites scatter operations to ensure deterministic behavior with high
// performance.
// Note that even when this flag is disabled, scatter operations may still
// be deterministic, although with additional overhead.
bool xla_gpu_enable_scatter_determinism_expander = 345;

// Next id: 346

// Extra options to pass to the compilation backend (e.g. LLVM); specific
// interpretation of these values is left to the backend.
Expand Down

0 comments on commit 678886f

Please sign in to comment.