From 985079f4257e632e85162b5525cfd4655ddf555d Mon Sep 17 00:00:00 2001 From: Chenhao Jiang Date: Tue, 12 Nov 2024 21:17:16 +0000 Subject: [PATCH] Add a flag for enabling the scatter_determinism_expander on GPU. --- xla/debug_options_flags.cc | 11 +++++++++++ xla/service/gpu/gpu_compiler.cc | 4 +++- xla/xla.proto | 7 +++++++ 3 files changed, 21 insertions(+), 1 deletion(-) diff --git a/xla/debug_options_flags.cc b/xla/debug_options_flags.cc index fb591e28ff311..b29284f8610a8 100644 --- a/xla/debug_options_flags.cc +++ b/xla/debug_options_flags.cc @@ -293,6 +293,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_gpu_dot_merger_threshold_mb(32); opts.set_xla_enable_fast_math(false); opts.set_xla_gpu_experimental_parallel_collective_overlap_limit(1); + opts.set_xla_gpu_enable_scatter_determinism_expander(true); return opts; } @@ -2046,6 +2047,16 @@ void MakeDebugOptionsFlags(std::vector* flag_list, debug_options->xla_gpu_experimental_parallel_collective_overlap_limit(), "This controls how many in-flight collectives " "latency hiding scheduler can schedule.")); + flag_list->push_back(tsl::Flag( + "xla_gpu_enable_scatter_determinism_expander", + bool_setter_for( + &DebugOptions::set_xla_gpu_enable_scatter_determinism_expander), + debug_options->xla_gpu_enable_scatter_determinism_expander(), + "Enable the scatter determinism expander, an optimized pass that " + "rewrites scatter operations to ensure deterministic behavior with high " + "performance." 
+ " Note that even when this flag is disabled, scatter operations may still " + "be deterministic, although with additional overhead.")); } // NOLINT(readability/fn_size) // Allocates flag_values and flag_objects; this function must not be called more diff --git a/xla/service/gpu/gpu_compiler.cc b/xla/service/gpu/gpu_compiler.cc index a6d0650b0eab8..e1d16392c982d 100755 --- a/xla/service/gpu/gpu_compiler.cc +++ b/xla/service/gpu/gpu_compiler.cc @@ -711,7 +711,9 @@ absl::Status RunOptimizationPasses( if (RequireDeterminism(hlo_module->config())) { // Scatter can be indeterministic if indices are not unique or a non // associative combiner function is used. Eliminate these Scatter ops. - pipeline.AddPass(); + if (debug_options.xla_gpu_enable_scatter_determinism_expander()) { + pipeline.AddPass(); + } pipeline.AddPass( ScatterExpander::kEliminateIndeterministicScatters); } diff --git a/xla/xla.proto b/xla/xla.proto index 6ab37cf9b4e73..1e2582930dde8 100644 --- a/xla/xla.proto +++ b/xla/xla.proto @@ -1030,6 +1030,13 @@ message DebugOptions { } PGLEStrictnessLevel xla_gpu_pgle_accuracy_checker = 341; + // Enable the scatter determinism expander, an optimized pass that + // rewrites scatter operations to ensure deterministic behavior with high + // performance. + // Note that even when this flag is disabled, scatter operations may still + // be deterministic, although with additional overhead. + bool xla_gpu_enable_scatter_determinism_expander = 342; + // Next id: 343 // Extra options to pass to the compilation backend (e.g. LLVM); specific