diff --git a/xla/debug_options_flags.cc b/xla/debug_options_flags.cc
index fb591e28ff3112..b29284f8610a89 100644
--- a/xla/debug_options_flags.cc
+++ b/xla/debug_options_flags.cc
@@ -293,6 +293,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() {
   opts.set_xla_gpu_dot_merger_threshold_mb(32);
   opts.set_xla_enable_fast_math(false);
   opts.set_xla_gpu_experimental_parallel_collective_overlap_limit(1);
+  opts.set_xla_gpu_enable_scatter_determinism_expander(true);
   return opts;
 }
@@ -2046,6 +2047,16 @@ void MakeDebugOptionsFlags(std::vector<tsl::Flag>* flag_list,
       debug_options->xla_gpu_experimental_parallel_collective_overlap_limit(),
       "This controls how many in-flight collectives "
       "latency hiding scheduler can schedule."));
+  flag_list->push_back(tsl::Flag(
+      "xla_gpu_enable_scatter_determinism_expander",
+      bool_setter_for(
+          &DebugOptions::set_xla_gpu_enable_scatter_determinism_expander),
+      debug_options->xla_gpu_enable_scatter_determinism_expander(),
+      "Enable the scatter determinism expander, an optimized pass that "
+      "rewrites scatter operations to ensure deterministic behavior with "
+      "high performance. "
+      "Note that even when this flag is disabled, scatter operations may "
+      "still be deterministic, although with additional overhead."));
 }  // NOLINT(readability/fn_size)

 // Allocates flag_values and flag_objects; this function must not be called more
diff --git a/xla/service/BUILD b/xla/service/BUILD
index 84ed6874707697..a9fac60b6398c6 100644
--- a/xla/service/BUILD
+++ b/xla/service/BUILD
@@ -2121,6 +2121,7 @@ cc_library(
         "//xla/hlo/transforms:op_expander_pass",
         "@com_google_absl//absl/container:flat_hash_set",
         "@com_google_absl//absl/strings:str_format",
+        "@com_google_absl//absl/types:span",
         "@tsl//tsl/platform:logging",
         "@tsl//tsl/platform:statusor",
     ],
diff --git a/xla/service/gpu/gpu_compiler.cc b/xla/service/gpu/gpu_compiler.cc
index a6d0650b0eab89..e1d16392c982d0 100755
--- a/xla/service/gpu/gpu_compiler.cc
+++ b/xla/service/gpu/gpu_compiler.cc
@@ -711,7 +711,9 @@ absl::Status RunOptimizationPasses(
   if (RequireDeterminism(hlo_module->config())) {
     // Scatter can be indeterministic if indices are not unique or a non
     // associative combiner function is used. Eliminate these Scatter ops.
-    pipeline.AddPass<ScatterDeterminismExpander>();
+    if (debug_options.xla_gpu_enable_scatter_determinism_expander()) {
+      pipeline.AddPass<ScatterDeterminismExpander>();
+    }
     pipeline.AddPass<ScatterExpander>(
         ScatterExpander::kEliminateIndeterministicScatters);
   }
diff --git a/xla/service/scatter_determinism_expander.cc b/xla/service/scatter_determinism_expander.cc
index ea462cc3c08fce..e9fcb5bcd6439c 100644
--- a/xla/service/scatter_determinism_expander.cc
+++ b/xla/service/scatter_determinism_expander.cc
@@ -15,11 +15,14 @@ limitations under the License.

 #include "xla/service/scatter_determinism_expander.h"

+#include <algorithm>
 #include <cstdint>
+#include <optional>
 #include <vector>

 #include "absl/container/flat_hash_set.h"
 #include "absl/strings/str_format.h"
+#include "absl/types/span.h"
 #include "xla/array.h"
 #include "xla/array2d.h"
 #include "xla/comparison_util.h"
@@ -32,6 +35,7 @@ limitations under the License.
 #include "xla/literal_util.h"
 #include "xla/service/hlo_creation_utils.h"
 #include "xla/service/scatter_utils.h"
+#include "xla/shape.h"
 #include "xla/shape_util.h"
 #include "xla/util.h"
 #include "xla/xla_data.pb.h"
@@ -62,47 +66,128 @@ static absl::StatusOr<std::vector<HloInstruction*>> CanonicalizeScatterUpdates(
   return adjusted_updates;
 }

-// Create the out-of-bound tensor for the scatter operation.
-HloInstruction* CreateOutOfBoundTensor(HloComputation* parent,
-                                       HloInstruction* scatter_indices,
-                                       const Shape& scatter_shape) {
-  if (scatter_indices->shape().rank() == 1) {
-    CHECK_EQ(scatter_shape.dimensions_size(), 1);
-    Array<int32_t> out_of_bound_array({scatter_indices->shape().dimensions(0)},
-                                      scatter_shape.dimensions(0));
-    return parent->AddInstruction(HloInstruction::CreateConstant(
-        LiteralUtil::CreateFromArray(out_of_bound_array)));
-  }
-  // More than one dimension in scatter_indices
-  Array2D<int32_t> out_of_bound_array(scatter_indices->shape().dimensions(0),
-                                      scatter_indices->shape().dimensions(1));
+template <typename T>
+HloInstruction* CreateBoundTensorGeneric(
+    HloComputation* parent, HloInstruction* scatter_indices,
+    absl::Span<const int64_t> operand_dims,
+    absl::Span<const int64_t> index_to_operand_map,
+    bool is_out_of_bound = true,
+    std::optional<absl::Span<const int64_t>> window_sizes = std::nullopt) {
+  CHECK_GT(scatter_indices->shape().dimensions_size(), 1);
+  Array2D<T> out_of_bound_array(scatter_indices->shape().dimensions(0),
+                                operand_dims.size());
   for (int i = 0; i < scatter_indices->shape().dimensions(0); ++i) {
-    for (int j = 0; j < scatter_indices->shape().dimensions(1); ++j) {
-      out_of_bound_array(i, j) = scatter_shape.dimensions(j);
+    for (int j = 0; j < operand_dims.size(); ++j) {
+      int mapped_index = index_to_operand_map[j];
+      out_of_bound_array(i, j) =
+          is_out_of_bound
+              ? operand_dims[mapped_index]
+              : operand_dims[mapped_index] - (*window_sizes)[mapped_index];
     }
   }
   return parent->AddInstruction(HloInstruction::CreateConstant(
-      LiteralUtil::CreateR2FromArray2D<int32_t>(out_of_bound_array)));
+      LiteralUtil::CreateR2FromArray2D<T>(out_of_bound_array)));
+}
+
+// Creates a bound tensor for the scatter operation, depending on the value of
+// is_out_of_bound.
+//
+// When is_out_of_bound is true, the tensor is filled with the dimension sizes
+// of the scatter operand, i.e. values that are guaranteed to be out of
+// bounds. This is used to mark indices that must not contribute to the
+// scatter result.
+//
+// When is_out_of_bound is false, the tensor is filled with the largest valid
+// start indices (computed as operand_dimensions - window_dimensions). This is
+// used to check whether indices are within valid bounds for non-scalar
+// updates.
+//
+// The function is therefore reused both for out-of-bound tensor generation
+// and for valid-index checks in scatter operations with non-scalar updates.
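// As a concrete illustration (hypothetical values, chosen only to show the
// convention above; they are not taken from any call site): with
// operand_dims = {4, 5}, index_to_operand_map = {0, 1} and
// window_sizes = {1, 2},
//   is_out_of_bound == true  produces rows of {4, 5}  (operand_dims), while
//   is_out_of_bound == false produces rows of {3, 3}  (operand_dims minus
//   window_sizes), i.e. the largest start index whose update window still
//   fits inside the operand.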
+absl::StatusOr CreateBoundTensor( + HloComputation* parent, HloInstruction* scatter_indices, + absl::Span operand_dims, + absl::Span index_to_operand_map, bool is_out_of_bound = true, + std::optional> window_sizes = std::nullopt) { + if (!is_out_of_bound && !window_sizes.has_value()) { + return FailedPrecondition( + "window_sizes must be provided when is_out_of_bound is false."); + } + + PrimitiveType type = scatter_indices->shape().element_type(); + if (type == S32) { + return CreateBoundTensorGeneric(parent, scatter_indices, + operand_dims, index_to_operand_map, + is_out_of_bound, window_sizes); + } else if (type == S64) { + return CreateBoundTensorGeneric(parent, scatter_indices, + operand_dims, index_to_operand_map, + is_out_of_bound, window_sizes); + } + return FailedPrecondition("Unexpected type for bound tensor: %s", + PrimitiveType_Name(type)); +} + +// indices shape: (num_indices, num_dims) +// updates shape: (num_indices,) +HloInstruction* FlattenIndices(HloComputation* parent, HloInstruction* indices, + absl::Span operand_dims) { + if (indices->shape().dimensions(1) == 1) { + // Originally scalar indices + return parent->AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(indices->shape().element_type(), + {indices->shape().dimensions(0)}), + indices)); + } + // Step 1: based on the operand_dims, calculate the strides + Array2D strides(operand_dims.size(), 1); + int64_t stride = 1; + for (int i = operand_dims.size() - 1; i >= 0; --i) { + strides(i, 0) = stride; + stride *= operand_dims[i]; + } + auto strides_tensor = parent->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR2FromArray2D(strides))); + + // Step 2: calculate the flattened indices + auto dot_shape = ShapeUtil::MakeShape(indices->shape().element_type(), + {indices->shape().dimensions(0), 1}); + DotDimensionNumbers dim_numbers; + dim_numbers.add_lhs_contracting_dimensions(1); + dim_numbers.add_rhs_contracting_dimensions(0); + PrecisionConfig precision_config; + auto flattened_indices = parent->AddInstruction(HloInstruction::CreateDot( + dot_shape, indices, strides_tensor, dim_numbers, precision_config)); + return parent->AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(indices->shape().element_type(), + {indices->shape().dimensions(0)}), + flattened_indices)); } // Computation for sorting the scalar scatter indices and updates together -HloComputation* ScalarSortingComparison(HloModule* module, - const Shape key_shape, - const Shape update_shape, - int64_t num_updates) { +static HloComputation* SortingComparison(HloModule* module, + const PrimitiveType indices_type, + const PrimitiveType updates_type, + int64_t num_updates, + bool has_scalar_indices) { + Shape key_shape = ShapeUtil::MakeShape(indices_type, {}); + Shape update_shape = ShapeUtil::MakeShape(updates_type, {}); HloComputation::Builder builder("sorting_computation"); auto param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, key_shape, "lhs_key")); auto param1 = builder.AddInstruction( HloInstruction::CreateParameter(1, key_shape, "rhs_key")); - const int kExistingParams = 2; + int param_count = 2; for (int i = 0; i < num_updates; ++i) { - builder.AddInstruction( - HloInstruction::CreateParameter(kExistingParams + i, update_shape, - absl::StrFormat("lhs_update_%d", i))); - builder.AddInstruction( - HloInstruction::CreateParameter(kExistingParams + 1 + i, update_shape, - absl::StrFormat("rhs_update_%d", i))); + builder.AddInstruction(HloInstruction::CreateParameter( + param_count, 
update_shape, absl::StrFormat("lhs_update_%d", i))); + builder.AddInstruction(HloInstruction::CreateParameter( + param_count + 1, update_shape, absl::StrFormat("rhs_update_%d", i))); + param_count += 2; + } + if (!has_scalar_indices) { + builder.AddInstruction(HloInstruction::CreateParameter( + param_count, key_shape, "lhs_permutation")); + builder.AddInstruction(HloInstruction::CreateParameter( + param_count + 1, key_shape, "rhs_permutation")); } builder.AddInstruction( HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), param0, @@ -113,14 +198,20 @@ HloComputation* ScalarSortingComparison(HloModule* module, static std::vector SortIndicesAndUpdates( HloInstruction* scatter_indices, const std::vector& scatter_updates, int64_t num_indices, - HloScatterInstruction* scatter, HloComputation* parent) { + HloScatterInstruction* scatter, HloComputation* parent, + absl::Span operand_dims, bool has_scalar_indices) { const Shape& indices_shape = scatter_indices->shape(); const Shape& updates_shape = scatter_updates[0]->shape(); auto updates_dims = updates_shape.dimensions(); // Since we canonicalized the scatter updates, the first dim will always be // the number of updates and the rest will be the shape of each update + HloInstruction* scalar_indices = + FlattenIndices(scatter->parent(), scatter_indices, operand_dims); - HloInstruction* scalar_indices = scatter_indices; + // Create the shape for a single index tuple + // Create [0...num_indices] tensor for permutation in sorting + auto indices_permutation = parent->AddInstruction(HloInstruction::CreateIota( + ShapeUtil::MakeShape(indices_shape.element_type(), {num_indices}), 0)); std::vector single_update_dimensions(updates_dims.begin() + 1, updates_dims.end()); @@ -130,19 +221,23 @@ static std::vector SortIndicesAndUpdates( const Shape& scalar_index_shape = ShapeUtil::MakeShape(indices_shape.element_type(), {num_indices}); + auto* comparison = SortingComparison( + scatter->GetModule(), indices_shape.element_type(), + updates_shape.element_type(), scatter_updates.size(), has_scalar_indices); - auto* comparison = ScalarSortingComparison( - scatter->GetModule(), - ShapeUtil::MakeShape(indices_shape.element_type(), {}), - ShapeUtil::MakeShape(updates_shape.element_type(), {}), - scatter_updates.size()); - + // The sorting operation contains the scalar indices and the updates, and if + // the scatter indices were not scalar, the sorting operation will also + // contain the indices permutation std::vector sort_operands = {scalar_indices}; std::vector sort_shapes = {scalar_index_shape}; for (auto update : scatter_updates) { sort_operands.push_back(update); sort_shapes.push_back(update->shape()); } + if (!has_scalar_indices) { + sort_operands.push_back(indices_permutation); + sort_shapes.push_back(indices_permutation->shape()); + } auto* sorting = parent->AddInstruction(HloInstruction::CreateSort( ShapeUtil::MakeTupleShape(sort_shapes), 0, sort_operands, comparison, @@ -160,6 +255,34 @@ static std::vector SortIndicesAndUpdates( std::vector sorted_tensors = {sorted_scalar_indices}; sorted_tensors.insert(sorted_tensors.end(), sorted_updates.begin(), sorted_updates.end()); + if (has_scalar_indices) { + return sorted_tensors; + } + // When the scatter indices were not scalar, need to return the sorted scatter + // indices + auto* sorted_indices_arg = + parent->AddInstruction(HloInstruction::CreateGetTupleElement( + indices_permutation->shape(), sorting, sorted_tensors.size())); + sorted_indices_arg = 
parent->AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(sorted_indices_arg->shape().element_type(), + {num_indices, 1}), + sorted_indices_arg)); + // Use gather of sorted_indices_arg to get the sorted original indices + GatherDimensionNumbers gather_dim_numbers; + gather_dim_numbers.add_offset_dims( + 1); // Preserving the inner dimension (columns) + gather_dim_numbers.add_start_index_map( + 0); // Mapping start_indices to the first dimension of the operand + gather_dim_numbers.add_collapsed_slice_dims(0); + gather_dim_numbers.set_index_vector_dim(1); + std::vector slice_sizes = {1, + scatter_indices->shape().dimensions(1)}; + auto* sorted_expanded_indices = + parent->AddInstruction(HloInstruction::CreateGather( + scatter_indices->shape(), scatter_indices, sorted_indices_arg, + gather_dim_numbers, slice_sizes, + /*indices_are_sorted=*/true)); + sorted_tensors.push_back(sorted_expanded_indices); return sorted_tensors; } @@ -259,7 +382,6 @@ absl::StatusOr> ComputePrefixScan( std::vector prefix_scans(sorted_updates.size()); HloInstruction* prefix_scan_update = nullptr; for (int i = 0; i < sorted_updates.size(); i++) { - // TODO(chenhao) change to use the extracted computation TF_ASSIGN_OR_RETURN( HloComputation * to_apply, CallComputationAndGetIthOutputWithBinaryParams(scatter->to_apply(), i)); @@ -273,19 +395,18 @@ absl::StatusOr> ComputePrefixScan( } static HloInstruction* FindLastOccurrenceIndices( - HloInstruction* scatter_indices, HloInstruction* sorted_scalar_indices, - HloInstruction* scatter, HloComputation* parent, int64_t num_indices) { - int64_t indices_len = sorted_scalar_indices->shape().dimensions(0); - HloInstruction* sorted_indices = sorted_scalar_indices; + HloInstruction* sorted_indices, HloInstruction* sorted_scalar_indices, + HloInstruction* scatter, HloComputation* parent, int64_t num_indices, + HloInstruction* out_of_bound_tensor) { + int64_t indices_len = sorted_indices->shape().dimensions(0); + const PrimitiveType& indices_type = sorted_indices->shape().element_type(); auto* sorted_indices_preceding_part = parent->AddInstruction(HloInstruction::CreateSlice( - ShapeUtil::MakeShape(scatter_indices->shape().element_type(), - {indices_len - 1}), + ShapeUtil::MakeShape(indices_type, {indices_len - 1}), sorted_scalar_indices, {0}, {indices_len - 1}, {1})); auto* sorted_indices_following_part = parent->AddInstruction(HloInstruction::CreateSlice( - ShapeUtil::MakeShape(scatter_indices->shape().element_type(), - {indices_len - 1}), + ShapeUtil::MakeShape(indices_type, {indices_len - 1}), sorted_scalar_indices, {1}, {indices_len}, {1})); auto* indices_mask_without_padding = parent->AddInstruction(HloInstruction::CreateCompare( @@ -304,18 +425,227 @@ static HloInstruction* FindLastOccurrenceIndices( // Mask the indices indices_mask = parent->AddInstruction(HloInstruction::CreateBroadcast( - ShapeUtil::MakeShape(PRED, scatter_indices->shape().dimensions()), + ShapeUtil::MakeShape(PRED, sorted_indices->shape().dimensions()), indices_mask, {0})); - auto* out_of_bound_tensor = - CreateOutOfBoundTensor(parent, scatter_indices, scatter->shape()); - auto* masked_indices = parent->AddInstruction(HloInstruction::CreateTernary( sorted_indices->shape(), HloOpcode::kSelect, indices_mask, sorted_indices, out_of_bound_tensor)); return masked_indices; } +template +HloInstruction* ExpandIndexOffsetsFromUpdateShape( + HloComputation* parent, const Shape& update_shape, + const ScatterDimensionNumbers& dim_num, const Shape& operand_shape, + absl::Span index_to_operand_map, + 
absl::Span actual_update_window_dims) { + // Calculate the offset tensor for each element of the update tensor. + // The offset tensor is represented in (num_elements_in_update, index_dim). + + int64_t num_elements = ShapeUtil::ElementsIn(update_shape); + int64_t operand_rank = operand_shape.dimensions_size(); + + Array2D offset_tensor(num_elements, operand_rank); + + std::vector is_inserted_window_dims(operand_rank, false); + for (int i = 0; i < dim_num.inserted_window_dims_size(); ++i) { + is_inserted_window_dims[dim_num.inserted_window_dims(i)] = true; + } + + // Compute the inverse of the index_to_operand_map + std::vector operand_to_index_map(operand_rank, -1); + for (int i = 0; i < operand_rank; ++i) { + operand_to_index_map[index_to_operand_map[i]] = i; + } + + for (int64_t linear_index = 0; linear_index < num_elements; ++linear_index) { + // Calculate the multi-dimensional index from the linear index + int64_t current_index = linear_index; + int inserted_window_dim_size = 0; + // Handle 0th to (operand_rank-2)th dimensions + for (int i = operand_rank - 1; i >= 0; --i) { + if (is_inserted_window_dims[i]) { + inserted_window_dim_size++; + offset_tensor(linear_index, operand_to_index_map[i]) = 0; + } else { + int64_t dim_size = actual_update_window_dims[i]; + offset_tensor(linear_index, operand_to_index_map[i]) = + current_index % dim_size; + current_index /= dim_size; + } + } + } + + // Return the offset tensor as an HloInstruction + return parent->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR2FromArray2D(offset_tensor))); +} + +// Expand the indices based on index_offset +HloInstruction* ExpandIndices(HloComputation* parent, HloInstruction* indices, + HloInstruction* index_offsets) { + // For each index we need to add the index_offset to the base index + // To do that, we first broadcast the indices and index_offsets to the same + // shape, then add the index_offset to the base index and flatten the + // result Broadcast to be (num_indices, length_of_index_offsets, + // length_of_indices). + bool is_one_dimensional = indices->shape().dimensions_size() == 1; + + int64_t num_indices = indices->shape().dimensions(0); + int64_t num_offsets = index_offsets->shape().dimensions(0); + int64_t index_length = + is_one_dimensional ? 1 : indices->shape().dimensions(1); + + Shape final_shape = + ShapeUtil::MakeShape(indices->shape().element_type(), + {num_indices, num_offsets, index_length}); + auto broadcasted_indices = + parent->AddInstruction(HloInstruction::CreateBroadcast( + final_shape, indices, + is_one_dimensional ? 
std::vector<int64_t>{0}
+                                      : std::vector<int64_t>{0, 2}));
+  auto broadcasted_offsets = parent->AddInstruction(
+      HloInstruction::CreateBroadcast(final_shape, index_offsets, {1, 2}));
+  auto expanded_indices = parent->AddInstruction(HloInstruction::CreateBinary(
+      final_shape, HloOpcode::kAdd, broadcasted_indices, broadcasted_offsets));
+  // Flatten the result to be (num_indices * num_offsets, index_length)
+  if (is_one_dimensional) {
+    return parent->AddInstruction(HloInstruction::CreateReshape(
+        ShapeUtil::MakeShape(indices->shape().element_type(),
+                             {num_indices * num_offsets}),
+        expanded_indices));
+  }
+  return parent->AddInstruction(HloInstruction::CreateReshape(
+      ShapeUtil::MakeShape(indices->shape().element_type(),
+                           {num_indices * num_offsets, index_length}),
+      expanded_indices));
+}
+
+// Creates a reduction computation for logical AND.
+HloComputation* ReduceAndComputation(HloModule* module) {
+  // Create a computation builder.
+  HloComputation::Builder builder("reduce_logical_and");
+
+  // Define the scalar shape for boolean operations.
+  const Shape bool_shape = ShapeUtil::MakeShape(PRED, {});
+
+  // Add parameters for the reduction computation.
+  // These represent the elements to be combined (lhs and rhs).
+  HloInstruction* lhs = builder.AddInstruction(
+      HloInstruction::CreateParameter(0, bool_shape, "lhs"));
+  HloInstruction* rhs = builder.AddInstruction(
+      HloInstruction::CreateParameter(1, bool_shape, "rhs"));
+
+  // Create the logical AND operation between the two parameters.
+  builder.AddInstruction(
+      HloInstruction::CreateBinary(bool_shape, HloOpcode::kAnd, lhs, rhs));
+
+  // Build and return the computation object.
+  return module->AddEmbeddedComputation(builder.Build());
+}
+
+absl::StatusOr<HloInstruction*> CheckValidIndices(
+    HloComputation* parent, HloInstruction* indices,
+    absl::Span<const int64_t> operand_dims,
+    absl::Span<const int64_t> window_sizes,
+    absl::Span<const int64_t> full_index_to_operand_dims) {
+  // Check whether the base indices and the indices shifted by the largest
+  // offsets are out of bounds. Essentially we need to do the following:
+  // 1. Check base indices >= [0, 0, 0, ...]
+  // 2. Check last indices <= [bounds...]
+  // 3. For each check, generate a tensor of the same size, then reduce across
+  //    rows to get a mask of size (n, 1).
+  auto init_reduce_value = parent->AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
+  auto reduce_computation = ReduceAndComputation(parent->parent());
+
+  // 1. Check base indices >= [0, 0, 0, ...]
+  // First generate a zero tensor of the same size as the indices.
+  auto* zero_constant = parent->AddInstruction(
+      HloInstruction::CreateConstant(indices->shape().element_type() == S64
+                                         ? LiteralUtil::CreateR0<int64_t>(0)
+                                         : LiteralUtil::CreateR0<int32_t>(0)));
+  auto* zero_broadcasted = parent->AddInstruction(
+      HloInstruction::CreateBroadcast(indices->shape(), zero_constant, {}));
+  auto* zero_check = parent->AddInstruction(HloInstruction::CreateCompare(
+      ShapeUtil::MakeShape(PRED, indices->shape().dimensions()), indices,
+      zero_broadcasted, ComparisonDirection::kGe));
+  HloInstruction* zero_check_mask;
+  // Reduce across rows to get a mask (for multi-dimensional indices).
+  zero_check_mask = parent->AddInstruction(HloInstruction::CreateReduce(
+      ShapeUtil::MakeShape(PRED, {indices->shape().dimensions(0)}), zero_check,
+      init_reduce_value, {1}, reduce_computation));
+  // 2. Check last indices <= [bounds...]
+  // Check if the index is OOB w.r.t. the operand dimensions and window sizes.
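  // Worked example for the bound used below (hypothetical shapes, for
  // illustration only): with operand_dims = {4} and window_sizes = {2}, the
  // largest valid start index is 4 - 2 = 2, so a start index of 3 is rejected
  // even though 3 < 4, because its update window {3, 4} would reach past the
  // end of the operand.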
+ TF_ASSIGN_OR_RETURN( + HloInstruction * max_valid_index_constant, + CreateBoundTensor(parent, indices, operand_dims, + full_index_to_operand_dims, false, window_sizes)); + auto oob_check = parent->AddInstruction(HloInstruction::CreateCompare( + ShapeUtil::MakeShape(PRED, indices->shape().dimensions()), + max_valid_index_constant, indices, ComparisonDirection::kGe)); + HloInstruction* oob_check_mask; + if (indices->shape().rank() == 1) { + oob_check_mask = oob_check; + } else { + // Reduce across rows to get a mask (for multi-dimensional indices). + oob_check_mask = parent->AddInstruction(HloInstruction::CreateReduce( + ShapeUtil::MakeShape(PRED, {indices->shape().dimensions(0)}), oob_check, + init_reduce_value, {1}, reduce_computation)); + } + // Combine the results of the two checks above. + auto* valid_index_mask = parent->AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(PRED, {indices->shape().dimensions(0)}), + HloOpcode::kAnd, zero_check_mask, oob_check_mask)); + return parent->AddInstruction(HloInstruction::CreateBroadcast( + ShapeUtil::MakeShape(PRED, indices->shape().dimensions()), + valid_index_mask, {0})); +} + +// Add dimensions that are not covered in the indices_to_operand_map to the end +// of indices +absl::StatusOr AddImplicitDimensionsToIndices( + int64_t operand_rank, absl::Span indices_to_operand_map, + HloInstruction* indices) { + const Shape& indices_shape = indices->shape(); + HloComputation* computation = indices->parent(); + + // Get the batch size (N) and S (number of dimensions in index_vector) + int64_t batch_size = indices_shape.dimensions(0); + int64_t num_indices_dims = indices_to_operand_map.size(); + + // Create a tensor of zeros with the target shape [N, operand_rank] + Shape expanded_shape = ShapeUtil::MakeShape(indices_shape.element_type(), + {batch_size, operand_rank}); + + HloInstruction* zero_filled_tensor = computation->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR2FromArray2D( + Array2D(batch_size, operand_rank - num_indices_dims, 0)))); + // Concatenate the zero-filled tensor with the index_vector + HloInstruction* expanded_indices = + computation->AddInstruction(HloInstruction::CreateConcatenate( + expanded_shape, {indices, zero_filled_tensor}, 1)); + return expanded_indices; +} + +std::vector ComputeFullIndexToOperandDims( + const Shape& operand_shape, ScatterDimensionNumbers& dim_numbers) { + std::vector full_index_to_operand_dims( + dim_numbers.mutable_scatter_dims_to_operand_dims()->begin(), + dim_numbers.mutable_scatter_dims_to_operand_dims()->end()); + // Add the implicit dimensions to the index_to_operand_map + absl::flat_hash_set existing_dims( + dim_numbers.scatter_dims_to_operand_dims().begin(), + dim_numbers.scatter_dims_to_operand_dims().end()); + + for (int i = 0; i < operand_shape.dimensions_size(); i++) { + if (existing_dims.find(i) == existing_dims.end()) + full_index_to_operand_dims.push_back(i); + } + return full_index_to_operand_dims; +} + absl::StatusOr ScatterDeterminismExpander::ExpandInstruction( HloInstruction* inst) { auto* scatter = Cast(inst); @@ -323,8 +653,7 @@ absl::StatusOr ScatterDeterminismExpander::ExpandInstruction( HloInstruction* scatter_indices = scatter->scatter_indices(); std::vector scatter_updates( scatter->scatter_updates().begin(), scatter->scatter_updates().end()); - const ScatterDimensionNumbers& dim_numbers = - scatter->scatter_dimension_numbers(); + ScatterDimensionNumbers dim_numbers = scatter->scatter_dimension_numbers(); // If the updates tensors are 
empty, there is no need to update the operands. // The operands can be forwarded. @@ -349,33 +678,170 @@ absl::StatusOr ScatterDeterminismExpander::ExpandInstruction( // Canonicalize the scatter_indices, after which the size of its most-major // dimension must be same as the while loop trip count. + HloInstruction* original_scatter_indices = scatter_indices; TF_ASSIGN_OR_RETURN(scatter_indices, CanonicalizeScatterIndices( scatter_indices, dim_numbers.index_vector_dim())); CHECK_EQ(scatter_indices_count, scatter_indices->shape().dimensions(0)); + // We compromise for maintainability and make the scatter_indices always 2D, + // so that the implementation could be easier, as we do not need to maintain + // two sets of code for 1D and 2D scatter_indices. + if (scatter_indices->shape().dimensions_size() == 1) { + scatter_indices = + scatter->parent()->AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(scatter_indices->shape().element_type(), + {scatter_indices->shape().dimensions(0), 1}), + scatter_indices)); + } + CHECK_GT(scatter_indices->shape().dimensions_size(), 1); + bool has_scalar_indices = scatter_indices->shape().dimensions(1) == 1; // Canonicalize the updates, after which the size of their most-major // dimensions must be same as the while loop trip count. - TF_ASSIGN_OR_RETURN(scatter_updates, CanonicalizeScatterUpdates( - scatter_updates, scatter_indices, - dim_numbers, scatter_indices_count)); + TF_ASSIGN_OR_RETURN( + scatter_updates, + CanonicalizeScatterUpdates(scatter_updates, original_scatter_indices, + dim_numbers, scatter_indices_count)); HloComputation* parent = scatter->parent(); + auto updates_shape = scatter_updates[0]->shape(); + auto updates_dims = scatter_updates[0]->shape().dimensions(); + // Since we canonicalized the scatter updates, the first dim will always be + // the number of updates and the rest will be the shape of each update + std::vector one_update_dimensions(updates_dims.begin() + 1, + updates_dims.end()); + const Shape& update_shape = + ShapeUtil::MakeShape(updates_shape.element_type(), one_update_dimensions); + + ScatterDimensionNumbers new_dim_numbers; + // Check if each update is a scalar based on update shape + bool non_scalar_update = scatter_updates[0]->shape().dimensions_size() > 1; + + std::vector full_index_to_operand_dims = + ComputeFullIndexToOperandDims(scatter_operands[0]->shape(), dim_numbers); + + TF_ASSIGN_OR_RETURN( + HloInstruction * out_of_bound_tensor, + CreateBoundTensor(parent, scatter_indices, scatter->shape().dimensions(), + full_index_to_operand_dims)); + + if (non_scalar_update) { + // Extract operand dimensions + const Shape& operand_shape = scatter_operands[0]->shape(); + + int num_operand_dims = operand_shape.dimensions_size(); + std::vector actual_update_window_dims(num_operand_dims); + int update_dim_index = 0; + for (int i = 0; i < num_operand_dims; ++i) { + if (std::find(dim_numbers.inserted_window_dims().begin(), + dim_numbers.inserted_window_dims().end(), + i) != dim_numbers.inserted_window_dims().end()) { + actual_update_window_dims[i] = 1; + } else { + actual_update_window_dims[i] = + update_shape.dimensions(update_dim_index); + update_dim_index++; + } + } + + HloInstruction* index_offsets = + scatter_indices->shape().element_type() == S32 + ? 
ExpandIndexOffsetsFromUpdateShape( + scatter->parent(), update_shape, dim_numbers, operand_shape, + full_index_to_operand_dims, actual_update_window_dims) + : ExpandIndexOffsetsFromUpdateShape( + scatter->parent(), update_shape, dim_numbers, operand_shape, + full_index_to_operand_dims, actual_update_window_dims); + + // Map scatter_indices into operand space + TF_ASSIGN_OR_RETURN( + scatter_indices, + AddImplicitDimensionsToIndices( + scatter_operands[0]->shape().dimensions_size(), + dim_numbers.scatter_dims_to_operand_dims(), scatter_indices)); + CHECK(scatter_indices->shape().dimensions(0) == scatter_indices_count); + // If any updates are out of bound, we change the corresponding indices to + // be oob_tensor values + TF_ASSIGN_OR_RETURN( + HloInstruction * oob_check_mask, + CheckValidIndices(scatter->parent(), scatter_indices, + scatter_operands[0]->shape().dimensions(), + actual_update_window_dims, + full_index_to_operand_dims)); + + scatter_indices = parent->AddInstruction(HloInstruction::CreateTernary( + scatter_indices->shape(), HloOpcode::kSelect, oob_check_mask, + scatter_indices, out_of_bound_tensor)); + scatter_indices = + ExpandIndices(scatter->parent(), scatter_indices, index_offsets); + + // Check if the number of indices is the same as + // (num of indices before expanding * num of offsets) + CHECK_EQ(scatter_indices->shape().dimensions(0), + scatter_indices_count * ShapeUtil::ElementsIn(update_shape)); + + // Expand the updates + const int64_t num_elements = + ShapeUtil::ElementsIn(scatter_updates[0]->shape()); + for (int i = 0; i < scatter_updates.size(); i++) { + scatter_updates[i] = parent->AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(scatter_updates[i]->shape().element_type(), + {num_elements}), + scatter_updates[i])); + } + + // Create a new dimension numbers for the new scatter operation + // As we have scalar updates, there is no update_window_dims + new_dim_numbers.clear_update_window_dims(); + new_dim_numbers.set_index_vector_dim(1); + // Mitigate the missed dimensions + for (int i = 0; i < operand_shape.dimensions_size() - + dim_numbers.input_batching_dims_size(); + i++) { + new_dim_numbers.add_inserted_window_dims(i); + } + for (int i = 0; i < operand_shape.dimensions_size(); i++) { + new_dim_numbers.add_scatter_dims_to_operand_dims( + full_index_to_operand_dims[i]); + } + } else { + new_dim_numbers = dim_numbers; + } // Sort the scatter indices and updates together based on the scatter indices. 
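  // For intuition (hypothetical values, not taken from the tests): with a
  // [3, 3] operand, FlattenIndices turns the 2-D index {1, 2} into the scalar
  // key 1 * 3 + 2 = 5 using row-major strides {3, 1}. Sorting on these scalar
  // keys groups duplicate destinations next to each other, and the
  // permutation emitted by the sort is later used to gather the original 2-D
  // indices back in the same sorted order.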
int64_t num_indices = ShapeUtil::ElementsIn(scatter_updates[0]->shape()); std::vector sorted_tensors = SortIndicesAndUpdates( - scatter_indices, scatter_updates, num_indices, scatter, parent); + scatter_indices, scatter_updates, num_indices, scatter, parent, + scatter_operands[0]->shape().dimensions(), has_scalar_indices); HloInstruction* sorted_scalar_indices = sorted_tensors[0]; - std::vector sorted_updates(sorted_tensors.begin() + 1, - sorted_tensors.end()); + std::vector sorted_updates( + sorted_tensors.begin() + 1, + sorted_tensors.begin() + 1 + scatter_updates.size()); + HloInstruction* sorted_indices; + if (has_scalar_indices) { + sorted_indices = parent->AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(sorted_scalar_indices->shape().element_type(), + {num_indices, 1}), + sorted_scalar_indices)); + } else { + sorted_indices = sorted_tensors[sorted_tensors.size() - 1]; + } TF_ASSIGN_OR_RETURN(std::vector prefix_scan_updates, ComputePrefixScan(sorted_updates, sorted_scalar_indices, scatter, parent)); - - HloInstruction* last_occurrence_indices = FindLastOccurrenceIndices( - scatter_indices, sorted_scalar_indices, scatter, parent, num_indices); + if (non_scalar_update) { + // As the indices are expanded, we need to recompute out-of-bound tensor + // with the same shape + TF_ASSIGN_OR_RETURN( + out_of_bound_tensor, + CreateBoundTensor(parent, sorted_indices, + scatter_operands[0]->shape().dimensions(), + full_index_to_operand_dims)); + } + HloInstruction* last_occurrence_indices = + FindLastOccurrenceIndices(sorted_indices, sorted_scalar_indices, scatter, + parent, num_indices, out_of_bound_tensor); CHECK(last_occurrence_indices != nullptr) << "Last occurrence indices should not be nullptr"; @@ -383,7 +849,7 @@ absl::StatusOr ScatterDeterminismExpander::ExpandInstruction( // Finally, recreate the scatter instruction with unique indices return parent->AddInstruction(HloInstruction::CreateScatter( scatter->shape(), scatter_operands, last_occurrence_indices, - prefix_scan_updates, scatter->to_apply(), dim_numbers, + prefix_scan_updates, scatter->to_apply(), new_dim_numbers, /*indices_are_sorted=*/true, /*unique_indices=*/true)); } @@ -437,25 +903,8 @@ bool CheckOutputDependency(HloComputation* to_apply, int operand_size) { bool ScatterDeterminismExpander::InstructionMatchesPattern( HloInstruction* inst) { auto* scatter = DynCast(inst); - // Need to check if updates and indices are scalar, as the current pass does - // not expand scatter with multi-dimensional updates or indices. This is - // temporary and will be removed in a future PR soon. - if (scatter == nullptr) { - return false; - } - - const Shape& indices_shape = scatter->scatter_indices()->shape(); - const Shape& updates_shape = scatter->scatter_updates()[0]->shape(); - - // Check if indices and updates are effectively 1D. 
- bool indices_are_1d = - (indices_shape.rank() == 1 || - (indices_shape.rank() == 2 && indices_shape.dimensions(1) == 1)); - bool updates_are_1d = - (updates_shape.rank() == 1 || - (updates_shape.rank() == 2 && updates_shape.dimensions(1) == 1)); - return indices_are_1d && updates_are_1d && !IsScatterDeterministic(scatter) && + return (scatter != nullptr) && !IsScatterDeterministic(scatter) && CheckOutputDependency(scatter->to_apply(), scatter->scatter_operands().size()); } diff --git a/xla/service/scatter_determinism_expander_test.cc b/xla/service/scatter_determinism_expander_test.cc index 366dbc768c5d09..eed978dd9e160c 100644 --- a/xla/service/scatter_determinism_expander_test.cc +++ b/xla/service/scatter_determinism_expander_test.cc @@ -89,6 +89,36 @@ TEST_F(ScatterDeterminismExpanderTest, EXPECT_TRUE(result); } +TEST_F(ScatterDeterminismExpanderTest, + EliminateNonScalarScatterWithNonAssociativeCombiner) { + const char* const kModuleStr = R"( + HloModule scatter_determinisic_expander + + scatter_computation { + arg1.173 = f32[] parameter(1) + arg0.172 = f32[] parameter(0) + ROOT add.48 = f32[] add(arg0.172, arg1.173) + } + + ENTRY fused_computation { + bitcast.2335 = f32[1,4096] parameter(0) + pad.96 = s32[4096,2] parameter(1) + bitcast.2748 = f32[4096,1,1] parameter(2) + ROOT scatter.48 = f32[1,4096] scatter(bitcast.2335, pad.96, bitcast.2748), + update_window_dims={1,2}, inserted_window_dims={}, + scatter_dims_to_operand_dims={0,1}, index_vector_dim=1, + to_apply=scatter_computation + })"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kModuleStr)); + + ScatterDeterminismExpander scatter_determinism_expander; + TF_ASSERT_OK_AND_ASSIGN( + bool result, RunHloPass(&scatter_determinism_expander, module.get())); + EXPECT_TRUE(result); +} + TEST_F(ScatterDeterminismExpanderTest, DoNotEliminateScatterWithAssociativeFp32Combiner) { const char* const kModuleStr = R"( @@ -148,7 +178,7 @@ TEST_F(ScatterDeterminismExpanderTest, DoNotEliminateScatterWithOneUpdate) { EXPECT_FALSE(result); } -TEST_F(ScatterDeterminismExpanderTest, ScatterAddCorrectnessTest) { +TEST_F(ScatterDeterminismExpanderTest, ScalarScatterAddCorrectnessTest) { const char* const kModuleStr = R"( HloModule scatter_determinism_expander @@ -170,6 +200,9 @@ TEST_F(ScatterDeterminismExpanderTest, ScatterAddCorrectnessTest) { TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kModuleStr)); + auto cloned_module = module->Clone(); + Literal expected_literal = ExecuteAndTransfer(std::move(cloned_module), {}); + auto expected_result = expected_literal.data(); ScatterDeterminismExpander scatter_determinism_expander; TF_ASSERT_OK_AND_ASSIGN( @@ -177,7 +210,46 @@ TEST_F(ScatterDeterminismExpanderTest, ScatterAddCorrectnessTest) { EXPECT_TRUE(result); - std::vector expected_result = {2.0, 16.0, 14.0, 3.0}; + Literal result_literal = ExecuteAndTransfer(std::move(module), {}); + + auto result_data = result_literal.data(); + std::vector actual_result(result_data.begin(), result_data.end()); + + EXPECT_EQ(actual_result, expected_result); +} + +TEST_F(ScatterDeterminismExpanderTest, + ScalarScatterAddOutOfBoundCorrectnessTest) { + const char* const kModuleStr = R"( + HloModule scatter_determinism_expander + + scatter_computation { + arg1.173 = f32[] parameter(1) + arg0.172 = f32[] parameter(0) + ROOT add.48 = f32[] add(arg0.172, arg1.173) + } + + ENTRY scatter_add_computation { + operand = f32[4] constant({0, 0, 0, 0}) + indices = s32[7,1] constant({{0}, {1}, {5}, {4}, {1}, {1}, {2}}) + updates = f32[7] 
constant({2, 1, 5, 3, 8, 7, 9}) + ROOT scatter.48 = f32[4] scatter(operand, indices, updates), + update_window_dims={}, inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, index_vector_dim=1, + to_apply=scatter_computation + })"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kModuleStr)); + auto cloned_module = module->Clone(); + Literal expected_literal = ExecuteAndTransfer(std::move(cloned_module), {}); + auto expected_result = expected_literal.data(); + + ScatterDeterminismExpander scatter_determinism_expander; + TF_ASSERT_OK_AND_ASSIGN( + bool result, RunHloPass(&scatter_determinism_expander, module.get())); + + EXPECT_TRUE(result); Literal result_literal = ExecuteAndTransfer(std::move(module), {}); @@ -187,6 +259,477 @@ TEST_F(ScatterDeterminismExpanderTest, ScatterAddCorrectnessTest) { EXPECT_EQ(actual_result, expected_result); } +TEST_F(ScatterDeterminismExpanderTest, + ScatterAddWithNonScalarIndexCorrectnessTest) { + const char* const kModuleStr = R"( + HloModule scatter_determinism_expander + + scatter_computation { + arg1.173 = f32[] parameter(1) + arg0.172 = f32[] parameter(0) + ROOT add.48 = f32[] add(arg0.172, arg1.173) + } + + ENTRY scatter_add_computation { + operand = f32[3, 3] constant({{0, 0, 0}, {0, 0, 0}, {0, 0, 0}}) + indices = s32[3, 2] constant({{0, 0}, {1, 1}, {2,2}}) + updates = f32[3] constant({2, 1, 3}) + ROOT scatter.48 = f32[3,3] scatter(operand, indices, updates), + update_window_dims={}, inserted_window_dims={0, 1}, + scatter_dims_to_operand_dims={0, 1}, index_vector_dim=1, + to_apply=scatter_computation + })"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kModuleStr)); + auto cloned_module = module->Clone(); + Literal expected_literal = ExecuteAndTransfer(std::move(cloned_module), {}); + auto expected_result = expected_literal.data(); + + ScatterDeterminismExpander scatter_determinism_expander; + TF_ASSERT_OK_AND_ASSIGN( + bool result, RunHloPass(&scatter_determinism_expander, module.get())); + + EXPECT_TRUE(result); + + Literal result_literal = ExecuteAndTransfer(std::move(module), {}); + + auto result_data = result_literal.data(); + std::vector actual_result(result_data.begin(), result_data.end()); + + EXPECT_EQ(actual_result, expected_result); +} + +TEST_F(ScatterDeterminismExpanderTest, + ScatterAddWithNonScalarUpdateCorrectnessTest) { + const char* const kModuleStr = R"( + HloModule scatter_determinism_expander + + scatter_computation { + arg1.173 = f32[] parameter(1) + arg0.172 = f32[] parameter(0) + ROOT add.48 = f32[] add(arg0.172, arg1.173) + } + + ENTRY scatter_add_computation { + operand = f32[4] constant({0, 0, 0, 0}) + indices = s32[3, 1] constant({{1}, {2}, {3}}) + updates = f32[3, 2] constant({{1, 2}, {4, 7}, {10, 13}}) + ROOT scatter.48 = f32[4] scatter(operand, indices, updates), + update_window_dims={1}, inserted_window_dims={}, + scatter_dims_to_operand_dims={0}, index_vector_dim=1, + to_apply=scatter_computation + })"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kModuleStr)); + auto cloned_module = module->Clone(); + Literal expected_literal = ExecuteAndTransfer(std::move(cloned_module), {}); + auto expected_result = expected_literal.data(); + ScatterDeterminismExpander scatter_determinism_expander; + TF_ASSERT_OK_AND_ASSIGN( + bool result, RunHloPass(&scatter_determinism_expander, module.get())); + + EXPECT_TRUE(result); + + Literal result_literal = ExecuteAndTransfer(std::move(module), {}); + + auto result_data = result_literal.data(); + 
std::vector actual_result(result_data.begin(), result_data.end()); + + EXPECT_EQ(actual_result, expected_result); +} + +TEST_F(ScatterDeterminismExpanderTest, + ScatterAddWithNonScalarIndexAndUpdateCorrectness2DTest1) { + const char* const kModuleStr = R"( + HloModule scatter_determinism_expander + + scatter_computation { + arg1.173 = f32[] parameter(1) + arg0.172 = f32[] parameter(0) + ROOT add.48 = f32[] add(arg0.172, arg1.173) + } + + ENTRY scatter_add_computation { + operand = f32[3, 3] constant({{0, 0, 0}, {0, 0, 0}, {0, 0, 0}}) + indices = s32[4, 2] constant({{0, 0}, {0, 1}, {1, 1}, {1, 2}}) + updates = f32[4, 2] constant({{1, 2}, {4, 7}, {10, 13}, {21, 27}}) + ROOT scatter.48 = f32[3, 3] scatter(operand, indices, updates), + update_window_dims={1}, inserted_window_dims={0}, + scatter_dims_to_operand_dims={1, 0}, index_vector_dim=1, + to_apply=scatter_computation + })"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kModuleStr)); + auto cloned_module = module->Clone(); + Literal expected_literal = ExecuteAndTransfer(std::move(cloned_module), {}); + auto expected_result = expected_literal.data(); + + ScatterDeterminismExpander scatter_determinism_expander; + TF_ASSERT_OK_AND_ASSIGN( + bool result, RunHloPass(&scatter_determinism_expander, module.get())); + + EXPECT_TRUE(result); + + Literal result_literal = ExecuteAndTransfer(std::move(module), {}); + + auto result_data = result_literal.data(); + std::vector actual_result(result_data.begin(), result_data.end()); + + EXPECT_EQ(actual_result, expected_result); +} + +TEST_F(ScatterDeterminismExpanderTest, + ScatterAddWithNonScalarIndexAndUpdateCorrectness2DTest2) { + const char* const kModuleStr = R"( + HloModule scatter_determinism_expander + + scatter_computation { + arg1.173 = f32[] parameter(1) + arg0.172 = f32[] parameter(0) + ROOT add.48 = f32[] add(arg0.172, arg1.173) + } + + ENTRY scatter_add_computation { + operand = f32[3, 3] constant({{0, 0, 0}, {0, 0, 0}, {0, 0, 0}}) + indices = s32[4, 2] constant({{0, 0}, {0, 1}, {1, 1}, {1, 2}}) + updates = f32[4, 2] constant({{1, 2}, {4, 7}, {10, 13}, {21, 27}}) + ROOT scatter.48 = f32[3, 3] scatter(operand, indices, updates), + update_window_dims={1}, inserted_window_dims={1}, + scatter_dims_to_operand_dims={1, 0}, index_vector_dim=1, + to_apply=scatter_computation + })"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kModuleStr)); + auto cloned_module = module->Clone(); + Literal expected_literal = ExecuteAndTransfer(std::move(cloned_module), {}); + auto expected_result = expected_literal.data(); + + ScatterDeterminismExpander scatter_determinism_expander; + TF_ASSERT_OK_AND_ASSIGN( + bool result, RunHloPass(&scatter_determinism_expander, module.get())); + + EXPECT_TRUE(result); + + Literal result_literal = ExecuteAndTransfer(std::move(module), {}); + + auto result_data = result_literal.data(); + std::vector actual_result(result_data.begin(), result_data.end()); + + EXPECT_EQ(actual_result, expected_result); +} + +TEST_F(ScatterDeterminismExpanderTest, + ScatterAddWithNonScalarIndexAndUpdateCorrectness2DTest3) { + const char* const kModuleStr = R"( + HloModule scatter_determinism_expander + + scatter_computation { + arg1.173 = f32[] parameter(1) + arg0.172 = f32[] parameter(0) + ROOT add.48 = f32[] add(arg0.172, arg1.173) + } + + ENTRY scatter_add_computation { + operand = f32[3, 3] constant({{0, 0, 0}, {0, 0, 0}, {0, 0, 0}}) + indices = s32[4, 2] constant({{0, 0}, {0, 1}, {1, 1}, {1, 2}}) + updates = f32[4, 2] constant({{1, 2}, {4, 7}, 
{10, 13}, {21, 27}}) + ROOT scatter.48 = f32[3, 3] scatter(operand, indices, updates), + update_window_dims={1}, inserted_window_dims={0}, + scatter_dims_to_operand_dims={0, 1}, index_vector_dim=1, + to_apply=scatter_computation + })"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kModuleStr)); + auto cloned_module = module->Clone(); + Literal expected_literal = ExecuteAndTransfer(std::move(cloned_module), {}); + auto expected_result = expected_literal.data(); + + ScatterDeterminismExpander scatter_determinism_expander; + TF_ASSERT_OK_AND_ASSIGN( + bool result, RunHloPass(&scatter_determinism_expander, module.get())); + + EXPECT_TRUE(result); + + Literal result_literal = ExecuteAndTransfer(std::move(module), {}); + + auto result_data = result_literal.data(); + std::vector actual_result(result_data.begin(), result_data.end()); + + EXPECT_EQ(actual_result, expected_result); +} + +TEST_F(ScatterDeterminismExpanderTest, + ScatterAddWithNonScalarIndexAndUpdateCorrectness2DTest4) { + const char* const kModuleStr = R"( + HloModule scatter_determinism_expander + + scatter_computation { + arg1.173 = f32[] parameter(1) + arg0.172 = f32[] parameter(0) + ROOT add.48 = f32[] add(arg0.172, arg1.173) + } + + ENTRY scatter_add_computation { + operand = f32[3, 3] constant({{0, 0, 0}, {0, 0, 0}, {0, 0, 0}}) + indices = s32[4, 2] constant({{0, 0}, {0, 1}, {1, 1}, {1, 2}}) + updates = f32[4, 2] constant({{1, 2}, {4, 7}, {10, 13}, {21, 27}}) + ROOT scatter.48 = f32[3, 3] scatter(operand, indices, updates), + update_window_dims={1}, inserted_window_dims={1}, + scatter_dims_to_operand_dims={0, 1}, index_vector_dim=1, + to_apply=scatter_computation + })"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kModuleStr)); + auto cloned_module = module->Clone(); + Literal expected_literal = ExecuteAndTransfer(std::move(cloned_module), {}); + auto expected_result = expected_literal.data(); + + ScatterDeterminismExpander scatter_determinism_expander; + TF_ASSERT_OK_AND_ASSIGN( + bool result, RunHloPass(&scatter_determinism_expander, module.get())); + + EXPECT_TRUE(result); + + Literal result_literal = ExecuteAndTransfer(std::move(module), {}); + + auto result_data = result_literal.data(); + std::vector actual_result(result_data.begin(), result_data.end()); + + EXPECT_EQ(actual_result, expected_result); +} + +TEST_F(ScatterDeterminismExpanderTest, + ScatterAddWithNonScalarIndexAndUpdateCorrectness3DTest1) { + const char* const kModuleStr = R"( + HloModule scatter_determinism_expander + + scatter_computation { + arg1.173 = f32[] parameter(1) + arg0.172 = f32[] parameter(0) + ROOT add.48 = f32[] add(arg0.172, arg1.173) + } + + ENTRY scatter_add_computation { + operand = f32[3, 3, 3] constant({{{0, 0, 0}, + {0, 0, 0}, + {0, 0, 0}}, + {{0, 0, 0}, + {0, 0, 0}, + {0, 0, 0}}, + {{0, 0, 0}, + {0, 0, 0}, + {0, 0, 0}}}) + indices = s32[4, 2] constant({{0, 0}, {0, 1}, {1, 1}, {1, 2}}) + updates = f32[4, 2] constant({{1, 2}, {4, 7}, {10, 13}, {21, 27}}) + ROOT scatter.48 = f32[3, 3, 3] scatter(operand, indices, updates), + update_window_dims={1}, inserted_window_dims={1, 2}, + scatter_dims_to_operand_dims={2, 0}, index_vector_dim=1, + to_apply=scatter_computation + })"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kModuleStr)); + + auto cloned_module = module->Clone(); + Literal expected_literal = ExecuteAndTransfer(std::move(cloned_module), {}); + auto expected_result = expected_literal.data(); + + ScatterDeterminismExpander scatter_determinism_expander; + 
TF_ASSERT_OK_AND_ASSIGN( + bool result, RunHloPass(&scatter_determinism_expander, module.get())); + + EXPECT_TRUE(result); + + Literal result_literal = ExecuteAndTransfer(std::move(module), {}); + + auto result_data = result_literal.data(); + std::vector actual_result(result_data.begin(), result_data.end()); + + EXPECT_EQ(actual_result, expected_result); +} + +TEST_F(ScatterDeterminismExpanderTest, + ScatterAddWithNonScalarIndexAndUpdateCorrectness3DTest2) { + const char* const kModuleStr = R"( + HloModule scatter_determinism_expander + + scatter_computation { + arg1.173 = f32[] parameter(1) + arg0.172 = f32[] parameter(0) + ROOT add.48 = f32[] add(arg0.172, arg1.173) + } + + ENTRY scatter_add_computation { + operand = f32[3, 3, 3] constant({{{0, 0, 0}, + {0, 0, 0}, + {0, 0, 0}}, + {{0, 0, 0}, + {0, 0, 0}, + {0, 0, 0}}, + {{0, 0, 0}, + {0, 0, 0}, + {0, 0, 0}}}) + indices = s32[4, 2] constant({{0, 0}, {0, 1}, {1, 1}, {1, 2}}) + updates = f32[4, 2] constant({{1, 2}, {4, 7}, {10, 13}, {21, 27}}) + ROOT scatter.48 = f32[3, 3, 3] scatter(operand, indices, updates), + update_window_dims={1}, inserted_window_dims={1, 2}, + scatter_dims_to_operand_dims={2, 1}, index_vector_dim=1, + to_apply=scatter_computation + })"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kModuleStr)); + + auto cloned_module = module->Clone(); + Literal expected_literal = ExecuteAndTransfer(std::move(cloned_module), {}); + auto expected_result = expected_literal.data(); + + ScatterDeterminismExpander scatter_determinism_expander; + TF_ASSERT_OK_AND_ASSIGN( + bool result, RunHloPass(&scatter_determinism_expander, module.get())); + + EXPECT_TRUE(result); + + Literal result_literal = ExecuteAndTransfer(std::move(module), {}); + + auto result_data = result_literal.data(); + std::vector actual_result(result_data.begin(), result_data.end()); + + EXPECT_EQ(actual_result, expected_result); +} + +TEST_F(ScatterDeterminismExpanderTest, + ScatterAddWithNonScalarIndexAndUpdateCorrectness3DTest3) { + const char* const kModuleStr = R"( + HloModule scatter_determinism_expander + + scatter_computation { + arg1.173 = f32[] parameter(1) + arg0.172 = f32[] parameter(0) + ROOT add.48 = f32[] add(arg0.172, arg1.173) + } + + ENTRY scatter_add_computation { + operand = f32[3, 3, 3] constant({{{0, 0, 0}, + {0, 0, 0}, + {0, 0, 0}}, + {{0, 0, 0}, + {0, 0, 0}, + {0, 0, 0}}, + {{0, 0, 0}, + {0, 0, 0}, + {0, 0, 0}}}) + indices = s32[2, 2] constant({{0, 0}, {1, 1}}) + updates = f32[2, 2, 2] constant({{{1, 2}, {4, 7}}, {{10, 13}, {21, 27}}}) + ROOT scatter.48 = f32[3, 3, 3] scatter(operand, indices, updates), + update_window_dims={1, 2}, inserted_window_dims={1}, + scatter_dims_to_operand_dims={2, 0}, index_vector_dim=1, + to_apply=scatter_computation + })"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kModuleStr)); + + auto cloned_module = module->Clone(); + Literal expected_literal = ExecuteAndTransfer(std::move(cloned_module), {}); + auto expected_result = expected_literal.data(); + + ScatterDeterminismExpander scatter_determinism_expander; + TF_ASSERT_OK_AND_ASSIGN( + bool result, RunHloPass(&scatter_determinism_expander, module.get())); + + EXPECT_TRUE(result); + + Literal result_literal = ExecuteAndTransfer(std::move(module), {}); + + auto result_data = result_literal.data(); + std::vector actual_result(result_data.begin(), result_data.end()); + + EXPECT_EQ(actual_result, expected_result); +} + +TEST_F(ScatterDeterminismExpanderTest, + ScatterAddWithNonScalarIndexAndUpdateCorrectness3DTest4) { + 
const char* const kModuleStr = R"(
+    HloModule scatter_determinism_expander
+
+    scatter_computation {
+      arg1.173 = f32[] parameter(1)
+      arg0.172 = f32[] parameter(0)
+      ROOT add.48 = f32[] add(arg0.172, arg1.173)
+    }
+
+    ENTRY scatter_add_computation {
+      operand = f32[3, 3, 3] constant({{{0, 0, 0},
+                                        {0, 0, 0},
+                                        {0, 0, 0}},
+                                       {{0, 0, 0},
+                                        {0, 0, 0},
+                                        {0, 0, 0}},
+                                       {{0, 0, 0},
+                                        {0, 0, 0},
+                                        {0, 0, 0}}})
+      indices = s32[2, 2] constant({{0, 0}, {1, 1}})
+      updates = f32[2, 2, 2] constant({{{1, 2}, {4, 7}}, {{10, 13}, {21, 27}}})
+      ROOT scatter.48 = f32[3, 3, 3] scatter(operand, indices, updates),
+          update_window_dims={1, 2}, inserted_window_dims={2},
+          scatter_dims_to_operand_dims={0, 1}, index_vector_dim=1,
+          to_apply=scatter_computation
+    })";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(kModuleStr));
+
+  auto cloned_module = module->Clone();
+  Literal expected_literal = ExecuteAndTransfer(std::move(cloned_module), {});
+  auto expected_result = expected_literal.data<float>();
+
+  ScatterDeterminismExpander scatter_determinism_expander;
+  TF_ASSERT_OK_AND_ASSIGN(
+      bool result, RunHloPass(&scatter_determinism_expander, module.get()));
+
+  EXPECT_TRUE(result);
+
+  Literal result_literal = ExecuteAndTransfer(std::move(module), {});
+
+  auto result_data = result_literal.data<float>();
+  std::vector<float> actual_result(result_data.begin(), result_data.end());
+
+  EXPECT_EQ(actual_result, expected_result);
+}
+
+TEST_F(ScatterDeterminismExpanderTest, ComplicatedMultiDimensionalScatterTest) {
+  const char* const kModuleStr = R"(
+    HloModule scatter_determinism_expander
+
+    scatter_computation {
+      arg0 = f32[] parameter(0)
+      arg1 = f32[] parameter(1)
+      ROOT add.48 = f32[] add(arg0, arg1)
+    }
+
+    ENTRY fused_computation {
+      p0 = f32[1,1,3072,3]{3,2,1,0} parameter(0)
+      p1 = s32[1,1,128,2,3]{4,3,2,1,0} parameter(1)
+      p2 = f32[1,1,128,2,3]{4,3,2,1,0} parameter(2)
+      ROOT scatter.50 = f32[1,1,3072,3]{3,2,1,0} scatter(p0, p1, p2), update_window_dims={4}, inserted_window_dims={0,1,2}, scatter_dims_to_operand_dims={0,1,2}, index_vector_dim=4, to_apply=scatter_computation
+    }
+    )";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(kModuleStr));
+
+  ScatterDeterminismExpander scatter_determinism_expander;
+  TF_ASSERT_OK_AND_ASSIGN(
+      bool result, RunHloPass(&scatter_determinism_expander, module.get()));
+  EXPECT_TRUE(result);
+}
+
 TEST_F(ScatterDeterminismExpanderTest, ScatterAddHloVerificationTest) {
   const char* const kModuleStr = R"(
     HloModule scatter_determinism_expander
@@ -199,8 +742,8 @@ TEST_F(ScatterDeterminismExpanderTest, ScatterAddHloVerificationTest) {
 
     ENTRY scatter_add_computation {
       operand = f32[2] constant({0, 0})
-      indices = s32[3,1] constant({{0}, {1}, {1}})
-      updates = f32[3] constant({2, 1, 5})
+      indices = s32[2,1] constant({{1}, {1}})
+      updates = f32[2] constant({2, 1})
       ROOT scatter.48 = f32[2] scatter(operand, indices, updates),
           update_window_dims={}, inserted_window_dims={0},
           scatter_dims_to_operand_dims={0}, index_vector_dim=1,
@@ -209,48 +752,41 @@ TEST_F(ScatterDeterminismExpanderTest, ScatterAddHloVerificationTest) {
 
   const char* const kExpectedPattern = R"(
     CHECK: ENTRY %scatter_add_computation () -> f32[2] {
-    CHECK-DAG: %[[INDICES:.*]] = s32[3,1]{1,0} constant({ {0}, {1}, {1} })
-    CHECK-DAG: %[[RESHAPE:.*]] = s32[3]{0} reshape(%[[INDICES]])
+    CHECK-DAG: %[[INDICES:.*]] = s32[2,1]{1,0} constant({ {1}, {1} })
+    CHECK-DAG: %[[RESHAPE:.*]] = s32[2]{0} reshape(%[[INDICES]])
+    CHECK-DAG: %[[IOTA:.*]] = s32[2]{0} iota(), iota_dimension=0
     CHECK-DAG: %[[OPERAND:.*]] = f32[2]{0} constant({0, 0})
-    CHECK-DAG: %[[RESHAPE1:.*]] = s32[3]{0} reshape(%[[INDICES]])
-    CHECK-DAG: %[[UPDATES:.*]] = f32[3]{0} constant({2, 1, 5})
-    CHECK-DAG: %[[TRANSPOSE:.*]] = f32[3]{0} transpose(%[[UPDATES]]), dimensions={0}
-    CHECK-DAG: %[[RESHAPE2:.*]] = f32[3]{0} reshape(%[[TRANSPOSE]])
-    CHECK-DAG: %[[SORT:.*]] = (s32[3]{0}, f32[3]{0}) sort(%[[RESHAPE1]], %[[RESHAPE2]]), dimensions={0}, to_apply=%sorting_computation
-    CHECK-DAG: %[[GET_TUPLE_ELEMENT:.*]] = s32[3]{0} get-tuple-element(%[[SORT]]), index=0
-    CHECK-DAG: %[[SLICE4:.*]] = s32[2]{0} slice(%[[GET_TUPLE_ELEMENT]]), slice={[0:2]}
-    CHECK-DAG: %[[SLICE5:.*]] = s32[2]{0} slice(%[[GET_TUPLE_ELEMENT]]), slice={[1:3]}
-    CHECK-DAG: %[[COMPARE3:.*]] = pred[2]{0} compare(%[[SLICE4]], %[[SLICE5]]), direction=NE
-    CHECK-DAG: %[[CONSTANT4:.*]] = pred[] constant(true)
-    CHECK-DAG: %[[BROADCAST4:.*]] = pred[1]{0} broadcast(%[[CONSTANT4]]), dimensions={}
-    CHECK-DAG: %[[CONCAT_COMPARE4:.*]] = pred[3]{0} concatenate(%[[COMPARE3]], %[[BROADCAST4]]), dimensions={0}
-    CHECK-DAG: %[[BROADCAST5:.*]] = pred[3]{0} broadcast(%[[CONCAT_COMPARE4]]), dimensions={0}
-    CHECK-DAG: %[[CONSTANT5:.*]] = s32[3]{0} constant({2, 2, 2})
-    CHECK-DAG: %[[SELECT2:.*]] = s32[3]{0} select(%[[BROADCAST5]], %[[GET_TUPLE_ELEMENT]], %[[CONSTANT5]])
-    CHECK-DAG: %[[CONSTANT3:.*]] = s32[] constant(0)
-    CHECK-DAG: %[[BROADCAST3:.*]] = s32[2]{0} broadcast(%[[CONSTANT3]]), dimensions={}
-    CHECK-DAG: %[[SLICE3:.*]] = s32[1]{0} slice(%[[GET_TUPLE_ELEMENT]]), slice={[0:1]}
-    CHECK-DAG: %[[CONCAT3:.*]] = s32[3]{0} concatenate(%[[BROADCAST3]], %[[SLICE3]]), dimensions={0}
-    CHECK-DAG: %[[COMPARE2:.*]] = pred[3]{0} compare(%[[GET_TUPLE_ELEMENT]], %[[CONCAT3]]), direction=EQ
-    CHECK-DAG: %[[CONSTANT1:.*]] = s32[] constant(0)
-    CHECK-DAG: %[[BROADCAST1:.*]] = s32[1]{0} broadcast(%[[CONSTANT1]]), dimensions={}
-    CHECK-DAG: %[[SLICE1:.*]] = s32[2]{0} slice(%[[GET_TUPLE_ELEMENT]]), slice={[0:2]}
-    CHECK-DAG: %[[CONCAT1:.*]] = s32[3]{0} concatenate(%[[BROADCAST1]], %[[SLICE1]]), dimensions={0}
-    CHECK-DAG: %[[COMPARE1:.*]] = pred[3]{0} compare(%[[GET_TUPLE_ELEMENT]], %[[CONCAT1]]), direction=EQ
-    CHECK-DAG: %[[GET_TUPLE_ELEMENT1:.*]] = f32[3]{0} get-tuple-element(%[[SORT]]), index=1
-    CHECK-DAG: %[[CONSTANT_F32:.*]] = f32[] constant(0)
-    CHECK-DAG: %[[BROADCAST_F32:.*]] = f32[1]{0} broadcast(%[[CONSTANT_F32]]), dimensions={}
-    CHECK-DAG: %[[SLICE_F32:.*]] = f32[2]{0} slice(%[[GET_TUPLE_ELEMENT1]]), slice={[0:2]}
-    CHECK-DAG: %[[CONCAT_F32:.*]] = f32[3]{0} concatenate(%[[BROADCAST_F32]], %[[SLICE_F32]]), dimensions={0}
-    CHECK-DAG: %[[MAP:.*]] = f32[3]{0} map(%[[GET_TUPLE_ELEMENT1]], %[[CONCAT_F32]]), dimensions={0}, to_apply=%scatter_computation
-    CHECK-DAG: %[[SELECT:.*]] = f32[3]{0} select(%[[COMPARE1]], %[[MAP]], %[[GET_TUPLE_ELEMENT1]])
-    CHECK-DAG: %[[CONSTANT2:.*]] = f32[] constant(0)
-    CHECK-DAG: %[[BROADCAST2:.*]] = f32[2]{0} broadcast(%[[CONSTANT2]]), dimensions={}
-    CHECK-DAG: %[[SLICE2:.*]] = f32[1]{0} slice(%[[SELECT]]), slice={[0:1]}
-    CHECK-DAG: %[[CONCAT2:.*]] = f32[3]{0} concatenate(%[[BROADCAST2]], %[[SLICE2]]), dimensions={0}
-    CHECK-DAG: %[[MAP1:.*]] = f32[3]{0} map(%[[SELECT]], %[[CONCAT2]]), dimensions={0}, to_apply=%scatter_computation
-    CHECK-DAG: %[[SELECT1:.*]] = f32[3]{0} select(%[[COMPARE2]], %[[MAP1]], %[[SELECT]])
-    CHECK-DAG: ROOT %[[SCATTER:.*]] = f32[2]{0} scatter(%[[OPERAND]], %[[SELECT2]], %[[SELECT1]]),
+    CHECK-DAG: %[[RESHAPE1:.*]] = s32[2]{0} reshape(%[[INDICES]])
+    CHECK-DAG: %[[RESHAPE2:.*]] = s32[2,1]{1,0} reshape(%[[RESHAPE1]])
+    CHECK-DAG: %[[RESHAPE4:.*]] = s32[2]{0} reshape(%[[RESHAPE2]])
+    CHECK-DAG: %[[UPDATES:.*]] = f32[2]{0} constant({2, 1})
+    CHECK-DAG: %[[TRANSPOSE:.*]] = f32[2]{0} transpose(%[[UPDATES]]), dimensions={0}
+    CHECK-DAG: %[[RESHAPE3:.*]] = f32[2]{0} reshape(%[[TRANSPOSE]])
+    CHECK-DAG: %[[SORT:.*]] = (s32[2]{0}, f32[2]{0}) sort(%[[RESHAPE4]], %[[RESHAPE3]]), dimensions={0}, to_apply=%sorting_computation
+    CHECK-DAG: %[[GET_TUPLE_ELEMENT:.*]] = s32[2]{0} get-tuple-element(%[[SORT]]), index=0
+    CHECK-DAG: %[[SLICE2:.*]] = s32[1]{0} slice(%[[GET_TUPLE_ELEMENT]]), slice={[0:1]}
+    CHECK-DAG: %[[SLICE3:.*]] = s32[1]{0} slice(%[[GET_TUPLE_ELEMENT]]), slice={[1:2]}
+    CHECK-DAG: %[[COMPARE2:.*]] = pred[1]{0} compare(%[[SLICE2]], %[[SLICE3]]), direction=NE
+    CHECK-DAG: %[[CONSTANT3:.*]] = pred[] constant(true)
+    CHECK-DAG: %[[BROADCAST2:.*]] = pred[1]{0} broadcast(%[[CONSTANT3]]), dimensions={}
+    CHECK-DAG: %[[CONCATENATE2:.*]] = pred[2]{0} concatenate(%[[COMPARE2]], %[[BROADCAST2]]), dimensions={0}
+    CHECK-DAG: %[[BROADCAST3:.*]] = pred[2,1]{1,0} broadcast(%[[CONCATENATE2]]), dimensions={0}
+    CHECK-DAG: %[[RESHAPE5:.*]] = s32[2,1]{1,0} reshape(%[[GET_TUPLE_ELEMENT]])
+    CHECK-DAG: %[[CONSTANT:.*]] = s32[2,1]{1,0} constant({ {2}, {2} })
+    CHECK-DAG: %[[SELECT1:.*]] = s32[2,1]{1,0} select(%[[BROADCAST3]], %[[RESHAPE5]], %[[CONSTANT]])
+    CHECK-DAG: %[[CONSTANT2:.*]] = s32[] constant(0)
+    CHECK-DAG: %[[BROADCAST1:.*]] = s32[1]{0} broadcast(%[[CONSTANT2]]), dimensions={}
+    CHECK-DAG: %[[SLICE1:.*]] = s32[1]{0} slice(%[[GET_TUPLE_ELEMENT]]), slice={[0:1]}
+    CHECK-DAG: %[[CONCATENATE1:.*]] = s32[2]{0} concatenate(%[[BROADCAST1]], %[[SLICE1]]), dimensions={0}
+    CHECK-DAG: %[[COMPARE1:.*]] = pred[2]{0} compare(%[[GET_TUPLE_ELEMENT]], %[[CONCATENATE1]]), direction=EQ
+    CHECK-DAG: %[[GET_TUPLE_ELEMENT1:.*]] = f32[2]{0} get-tuple-element(%[[SORT]]), index=1
+    CHECK-DAG: %[[CONSTANT1:.*]] = f32[] constant(0)
+    CHECK-DAG: %[[BROADCAST:.*]] = f32[1]{0} broadcast(%[[CONSTANT1]]), dimensions={}
+    CHECK-DAG: %[[SLICE:.*]] = f32[1]{0} slice(%[[GET_TUPLE_ELEMENT1]]), slice={[0:1]}
+    CHECK-DAG: %[[CONCATENATE:.*]] = f32[2]{0} concatenate(%[[BROADCAST]], %[[SLICE]]), dimensions={0}
+    CHECK-DAG: %[[MAP:.*]] = f32[2]{0} map(%[[GET_TUPLE_ELEMENT1]], %[[CONCATENATE]]), dimensions={0}, to_apply=%scatter_computation
+    CHECK-DAG: %[[SELECT:.*]] = f32[2]{0} select(%[[COMPARE1]], %[[MAP]], %[[GET_TUPLE_ELEMENT1]])
+    CHECK-DAG: ROOT %[[SCATTER:.*]] = f32[2]{0} scatter(%[[OPERAND]], %[[SELECT1]], %[[SELECT]]),
     CHECK-SAME:      update_window_dims={},
     CHECK-SAME:      inserted_window_dims={0},
     CHECK-SAME:      scatter_dims_to_operand_dims={0},
@@ -264,7 +800,7 @@ TEST_F(ScatterDeterminismExpanderTest, ScatterAddHloVerificationTest) {
                           kExpectedPattern);
 }
 
-TEST_F(ScatterDeterminismExpanderTest, ScatterAddOutOfBoundCorrectnessTest) {
+TEST_F(ScatterDeterminismExpanderTest, ScalarScatterAddReproducibilityTest) {
   const char* const kModuleStr = R"(
     HloModule scatter_determinism_expander
 
@@ -275,10 +811,30 @@ TEST_F(ScatterDeterminismExpanderTest, ScatterAddOutOfBoundCorrectnessTest) {
     }
 
    ENTRY scatter_add_computation {
-      operand = f32[4] constant({0, 0, 0, 0})
-      indices = s32[7,1] constant({{0}, {1}, {5}, {4}, {1}, {1}, {2}})
-      updates = f32[7] constant({2, 1, 5, 3, 8, 7, 9})
-      ROOT scatter.48 = f32[4] scatter(operand, indices, updates),
+      operand = f32[3] constant({0, 0, 0})
+      indices = s32[100,1] constant({{0}, {3}, {0}, {1}, {0}, {3}, {1}, {2}, {1}, {2}, {2}, {2}, {0}, {2}, {1},
+                                     {0}, {1}, {1}, {2}, {0}, {2}, {1}, {2}, {1}, {2},
+                                     {2}, {3}, {2}, {2}, {0},
+                                     {3}, {0}, {3}, {2}, {0}, {3}, {3}, {3}, {3}, {3}, {2}, {3}, {3}, {0}, {0},
+                                     {3}, {3}, {3}, {2}, {3}, {2}, {3}, {0}, {0}, {2}, {0}, {1}, {3}, {1}, {3},
+                                     {2}, {2}, {2}, {1}, {0}, {3}, {1}, {1}, {1}, {1}, {1}, {2}, {2}, {3}, {0},
+                                     {2}, {2}, {0}, {2}, {1}, {0}, {2}, {2}, {2}, {0}, {2}, {0}, {1}, {3}, {0},
+                                     {2}, {3}, {3}, {2}, {0}, {3}, {3}, {2}, {3}, {2}})
+      updates = f32[100] constant({0.02379167, 0.8527204, 0.8132185, 0.5140263, 0.17172801, 0.8026866, 0.5124631,
+                                   0.34838438, 0.50526905, 0.3370521, 0.10868239, 0.10520637, 0.83827364, 0.78986526,
+                                   0.34059846, 0.8349273, 0.24575627, 0.21387374, 0.02423227, 0.5617423, 0.28066766,
+                                   0.94366455, 0.61214995, 0.7383388, 0.52419806, 0.65466726, 0.41012764, 0.24028647,
+                                   0.74443066, 0.03544927, 0.851014, 0.02434528, 0.47239733, 0.72706807, 0.35055435,
+                                   0.6274171, 0.61077535, 0.06525731, 0.8091929, 0.21307838, 0.6465323, 0.3245015,
+                                   0.5538883, 0.8849807, 0.9591211, 0.83856845, 0.48919427, 0.11810577, 0.16933143,
+                                   0.83657074, 0.587505, 0.6867087, 0.95522237, 0.5797727, 0.28024232, 0.34749162,
+                                   0.5199702, 0.9811766, 0.5645981, 0.2446456, 0.68722725, 0.9616587, 0.480047,
+                                   0.88953114, 0.7083205, 0.948612, 0.67764974, 0.44131804, 0.36789334, 0.95148766,
+                                   0.30909216, 0.70908046, 0.8749926, 0.60973287, 0.60751855, 0.22647333, 0.5363518,
+                                   0.96195626, 0.08158326, 0.5266887, 0.85922587, 0.648262, 0.4657668, 0.31623375,
+                                   0.43507564, 0.48351157, 0.41285944, 0.73501325, 0.15267539, 0.67055714, 0.08459568,
+                                   0.04527426, 0.21078384, 0.4654404, 0.7363906, 0.23245859, 0.22119188, 0.99092937,
+                                   0.878675, 0.4102913})
+      ROOT scatter.48 = f32[3] scatter(operand, indices, updates),
           update_window_dims={}, inserted_window_dims={0},
           scatter_dims_to_operand_dims={0}, index_vector_dim=1,
           to_apply=scatter_computation
@@ -293,17 +849,30 @@ TEST_F(ScatterDeterminismExpanderTest, ScatterAddOutOfBoundCorrectnessTest) {
 
   EXPECT_TRUE(result);
 
-  std::vector<float> expected_result = {2.0, 16.0, 9.0, 0.0};
+  auto cloned_module = module->Clone();
+  Literal first_result_literal =
+      ExecuteAndTransfer(std::move(cloned_module), {});
+  auto first_result_span = first_result_literal.data<float>();
+  std::vector<float> first_result(first_result_span.begin(),
+                                  first_result_span.end());
 
-  Literal result_literal = ExecuteAndTransfer(std::move(module), {});
+  const int num_trials = 20;
+  std::vector<std::vector<float>> results;
 
-  auto result_data = result_literal.data<float>();
-  std::vector<float> actual_result(result_data.begin(), result_data.end());
+  for (int i = 0; i < num_trials; ++i) {
+    auto cloned_module = module->Clone();
 
-  EXPECT_EQ(actual_result, expected_result);
+    Literal result_literal = ExecuteAndTransfer(std::move(cloned_module), {});
+
+    auto result_data = result_literal.data<float>();
+    std::vector<float> actual_result(result_data.begin(), result_data.end());
+
+    EXPECT_EQ(actual_result, first_result)
+        << "Results are not reproducible across trials!";
+  }
 }
 
-TEST_F(ScatterDeterminismExpanderTest, ScatterAddReproducibilityTest) {
+TEST_F(ScatterDeterminismExpanderTest, NonScalarScatterAddReproducibilityTest) {
   const char* const kModuleStr = R"(
     HloModule scatter_determinism_expander
 
@@ -314,12 +883,32 @@ TEST_F(ScatterDeterminismExpanderTest, ScatterAddReproducibilityTest) {
     }
 
    ENTRY scatter_add_computation {
-      operand = f32[3] constant({0, 0, 0})
-      indices = s32[100,1] constant({{0}, {3}, {0}, {1}, {0}, {3}, {1}, {2}, {1}, {2}, {2}, {2}, {0}, {2}, {1}, {0}, {1}, {1}, {2}, {0}, {2}, {1}, {2}, {1}, {2}, {2}, {3}, {2}, {2}, {0}, {3}, {0}, {3}, {2}, {0}, {3}, {3}, {3}, {3}, {3}, {2}, {3}, {3}, {0}, {0}, {3}, {3}, {3}, {2}, {3}, {2}, {3}, {0}, {0}, {2}, {0}, {1}, {3}, {1}, {3}, {2}, {2}, {2}, {1}, {0}, {3}, {1}, {1}, {1}, {1}, {1}, {2}, {2}, {3}, {0}, {2}, {2}, {0}, {2}, {1}, {0}, {2}, {2}, {2}, {0}, {2}, {0}, {1}, {3}, {0}, {2}, {3}, {3}, {2}, {0}, {3}, {3}, {2}, {3}, {2}})
-      updates = f32[100] constant({0.02379167, 0.8527204, 0.8132185, 0.5140263, 0.17172801, 0.8026866, 0.5124631, 0.34838438, 0.50526905, 0.3370521, 0.10868239, 0.10520637, 0.83827364, 0.78986526, 0.34059846, 0.8349273, 0.24575627, 0.21387374, 0.02423227, 0.5617423, 0.28066766, 0.94366455, 0.61214995, 0.7383388, 0.52419806, 0.65466726, 0.41012764, 0.24028647, 0.74443066, 0.03544927, 0.851014, 0.02434528, 0.47239733, 0.72706807, 0.35055435, 0.6274171, 0.61077535, 0.06525731, 0.8091929, 0.21307838, 0.6465323, 0.3245015, 0.5538883, 0.8849807, 0.9591211, 0.83856845, 0.48919427, 0.11810577, 0.16933143, 0.83657074, 0.587505, 0.6867087, 0.95522237, 0.5797727, 0.28024232, 0.34749162, 0.5199702, 0.9811766, 0.5645981, 0.2446456, 0.68722725, 0.9616587, 0.480047, 0.88953114, 0.7083205, 0.948612, 0.67764974, 0.44131804, 0.36789334, 0.95148766, 0.30909216, 0.70908046, 0.8749926, 0.60973287, 0.60751855, 0.22647333, 0.5363518, 0.96195626, 0.08158326, 0.5266887, 0.85922587, 0.648262, 0.4657668, 0.31623375, 0.43507564, 0.48351157, 0.41285944, 0.73501325, 0.15267539, 0.67055714, 0.08459568, 0.04527426, 0.21078384, 0.4654404, 0.7363906, 0.23245859, 0.22119188, 0.99092937, 0.878675, 0.4102913})
-      ROOT scatter.48 = f32[3] scatter(operand, indices, updates),
-          update_window_dims={}, inserted_window_dims={0},
-          scatter_dims_to_operand_dims={0}, index_vector_dim=1,
+      operand = f32[3, 3] constant({{0, 0, 0}, {0, 0, 0}, {0, 0, 0}})
+      indices = s32[50, 2] constant({{0, 0}, {0, 1}, {1, 1}, {2, 2}, {0, 1}, {1, 0}, {2, 1}, {1, 2}, {0, 2}, {2, 0},
+                                     {1, 1}, {2, 2}, {0, 0}, {0, 1}, {2, 1}, {1, 2}, {2, 0}, {0, 2}, {1, 0}, {1, 1},
+                                     {1, 2}, {2, 1}, {0, 0}, {1, 1}, {0, 2}, {2, 0}, {1, 0}, {2, 2}, {1, 2}, {0, 1},
+                                     {2, 1}, {1, 0}, {0, 2}, {2, 0}, {0, 1}, {2, 1}, {1, 1}, {1, 0}, {2, 2}, {0, 0},
+                                     {0, 1}, {1, 2}, {2, 0}, {1, 1}, {0, 2}, {2, 1}, {1, 2}, {2, 1}, {1, 1}, {0, 2}})
+      updates = f32[50, 2] constant({{0.02379167, 0.8527204}, {0.8132185, 0.5140263}, {0.17172801, 0.8026866},
+                                     {0.5124631, 0.34838438}, {0.50526905, 0.3370521}, {0.10868239, 0.10520637},
+                                     {0.83827364, 0.78986526}, {0.34059846, 0.8349273}, {0.24575627, 0.21387374},
+                                     {0.02423227, 0.5617423}, {0.28066766, 0.94366455}, {0.61214995, 0.7383388},
+                                     {0.52419806, 0.65466726}, {0.41012764, 0.24028647}, {0.74443066, 0.03544927},
+                                     {0.851014, 0.02434528}, {0.47239733, 0.72706807}, {0.35055435, 0.6274171},
+                                     {0.61077535, 0.06525731}, {0.8091929, 0.21307838}, {0.6465323, 0.3245015},
+                                     {0.5538883, 0.8849807}, {0.9591211, 0.83856845}, {0.48919427, 0.11810577},
+                                     {0.16933143, 0.83657074}, {0.587505, 0.6867087}, {0.95522237, 0.5797727},
+                                     {0.28024232, 0.34749162}, {0.5199702, 0.9811766}, {0.5645981, 0.2446456},
+                                     {0.68722725, 0.9616587}, {0.480047, 0.88953114}, {0.7083205, 0.948612},
+                                     {0.67764974, 0.44131804}, {0.36789334, 0.95148766}, {0.30909216, 0.70908046},
+                                     {0.8749926, 0.60973287}, {0.60751855, 0.22647333}, {0.5363518, 0.96195626},
+                                     {0.08158326, 0.5266887}, {0.85922587, 0.648262}, {0.4657668, 0.31623375},
+                                     {0.43507564, 0.48351157}, {0.41285944, 0.73501325}, {0.15267539, 0.67055714},
+                                     {0.08459568, 0.04527426}, {0.21078384, 0.4654404}, {0.7363906, 0.23245859},
+                                     {0.22119188, 0.99092937}, {0.878675, 0.4102913}})
+      ROOT scatter.48 = f32[3, 3] scatter(operand, indices, updates),
+          update_window_dims={1}, inserted_window_dims={0},
+          scatter_dims_to_operand_dims={0, 1}, index_vector_dim=1,
           to_apply=scatter_computation
     })";
 
diff --git a/xla/xla.proto b/xla/xla.proto
index 6ab37cf9b4e73d..1e2582930dde80 100644
--- a/xla/xla.proto
+++ b/xla/xla.proto
@@ -1030,6 +1030,13 @@ message DebugOptions {
   }
   PGLEStrictnessLevel xla_gpu_pgle_accuracy_checker = 341;
 
+  // Enable the scatter determinism expander, an optimized pass that
+  // rewrites scatter operations to ensure deterministic behavior with high
+  // performance.
+  // Note that even when this flag is disabled, scatter operations may still
+  // be deterministic, although with additional overhead.
+  bool xla_gpu_enable_scatter_determinism_expander = 342;
+
   // Next id: 343
 
   // Extra options to pass to the compilation backend (e.g. LLVM); specific
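
Reviewer note, not part of the patch: the DebugOptions field above is the knob that gates the new expander. A minimal sketch of how client code might opt out of it, assuming the usual debug-options plumbing (xla::GetDebugOptionsFromFlags() and HloModuleConfig::set_debug_options()); the helper name is illustrative only, and the same effect should be reachable from the command line via the XLA_FLAGS environment variable with --xla_gpu_enable_scatter_determinism_expander=false.

#include "xla/debug_options_flags.h"
#include "xla/service/hlo_module_config.h"
#include "xla/xla.pb.h"

// Sketch (hypothetical helper): build an HloModuleConfig whose DebugOptions
// turn the new flag off. Per the flag's description above, scatter ops may
// still end up deterministic, just with additional overhead.
xla::HloModuleConfig MakeConfigWithoutScatterDeterminismExpander() {
  xla::DebugOptions debug_options = xla::GetDebugOptionsFromFlags();
  debug_options.set_xla_gpu_enable_scatter_determinism_expander(false);
  xla::HloModuleConfig config;
  config.set_debug_options(debug_options);
  return config;
}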