Skip to content

Commit

Permalink
Add a flag for enabling the scatter_determinism_expander on GPU.
Browse files Browse the repository at this point in the history
  • Loading branch information
serach24 committed Nov 12, 2024
1 parent 1ecb608 commit 985079f
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 1 deletion.
11 changes: 11 additions & 0 deletions xla/debug_options_flags.cc
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() {
opts.set_xla_gpu_dot_merger_threshold_mb(32);
opts.set_xla_enable_fast_math(false);
opts.set_xla_gpu_experimental_parallel_collective_overlap_limit(1);
opts.set_xla_gpu_enable_scatter_determinism_expander(true);
return opts;
}

Expand Down Expand Up @@ -2046,6 +2047,16 @@ void MakeDebugOptionsFlags(std::vector<tsl::Flag>* flag_list,
debug_options->xla_gpu_experimental_parallel_collective_overlap_limit(),
"This controls how many in-flight collectives "
"latency hiding scheduler can schedule."));
// Registers the xla_gpu_enable_scatter_determinism_expander boolean flag.
// Adjacent string literals are concatenated with no separator, so every
// fragment of the help text has to carry its own trailing space.
flag_list->push_back(tsl::Flag(
    "xla_gpu_enable_scatter_determinism_expander",
    bool_setter_for(
        &DebugOptions::set_xla_gpu_enable_scatter_determinism_expander),
    debug_options->xla_gpu_enable_scatter_determinism_expander(),
    "Enable the scatter determinism expander, an optimized pass that "
    "rewrites scatter operations to ensure deterministic behavior with high "
    "performance. "
    "Note that even when this flag is disabled, scatter operations may still "
    "be deterministic, although with additional overhead."));
} // NOLINT(readability/fn_size)

// Allocates flag_values and flag_objects; this function must not be called more
Expand Down
4 changes: 3 additions & 1 deletion xla/service/gpu/gpu_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -711,7 +711,9 @@ absl::Status RunOptimizationPasses(
if (RequireDeterminism(hlo_module->config())) {
// Scatter can be non-deterministic if indices are not unique or a
// non-associative combiner function is used. Eliminate these Scatter ops.
pipeline.AddPass<ScatterDeterminismExpander>();
if (debug_options.xla_gpu_enable_scatter_determinism_expander()) {
pipeline.AddPass<ScatterDeterminismExpander>();
}
pipeline.AddPass<ScatterExpander>(
ScatterExpander::kEliminateIndeterministicScatters);
}
Expand Down
7 changes: 7 additions & 0 deletions xla/xla.proto
Original file line number Diff line number Diff line change
Expand Up @@ -1030,6 +1030,13 @@ message DebugOptions {
}
PGLEStrictnessLevel xla_gpu_pgle_accuracy_checker = 341;

// Enable the scatter determinism expander, an optimized pass that
// rewrites scatter operations to ensure deterministic behavior with high
// performance.
// Note that even when this flag is disabled, scatter operations may still
// be deterministic, although with additional overhead.
bool xla_gpu_enable_scatter_determinism_expander = 342;

// Next id: 343

// Extra options to pass to the compilation backend (e.g. LLVM); specific
Expand Down

0 comments on commit 985079f

Please sign in to comment.