From 3d1f6b230cb29cbada4965126b406b8cd4cc31f6 Mon Sep 17 00:00:00 2001 From: Chenhao Jiang Date: Fri, 15 Nov 2024 11:50:24 -0800 Subject: [PATCH] PR #19275: [NVIDIA] Add fixes for supporting determinism expander for high-dimensional scatter operation and a flag to disable it Imported from GitHub PR https://github.com/openxla/xla/pull/19275 This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in xla/service/ScatterExpander.cc. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates. The change of this PR is on top of https://github.com/openxla/xla/pull/17886, and has fixed issues reported in the reverted PR https://github.com/openxla/xla/pull/18326. The issue was that the changes in https://github.com/openxla/xla/pull/18326 were not able to handle different kinds of complicated but realistic scatter dimension numbers. Specifically, this PR unifies the implementation of 1D and multi-dimensional scatter operation to make the code easier to maintain, adds multiple tests for various scatter dimension numbers, and thoroughly handles all cases of different kinds of dimension numbers. Moreover, this PR also adds an option `xla_gpu_enable_scatter_determinism_expander`, the default value of which is set to be true. This option could make sure that although unlikely, if anything happens with changes in this PR, the user can easily disable the `scatter_determinism_expander` pass without getting blocked. Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit Bugs resolved: https://github.com/jax-ml/jax/issues/17844 Copybara import of the project: -- 3b7b56a2b95e52654daf83a359d17a809dc3b784 by Chenhao Jiang : PR #18326: [NVIDIA] Complete the optimization of deterministic scatter operations Imported from GitHub PR https://github.com/openxla/xla/pull/18326 This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in xla/service/ScatterExpander.cc. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates. The change of this PR is on top of https://github.com/openxla/xla/pull/17886 Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit Bugs resolved: https://github.com/jax-ml/jax/issues/17844 Copybara import of the project: -- de647d44eb28af71e1580b6e8ed9adc751e50f52 by Chenhao Jiang : Support scatter with non-scalar indices and updates Merging this change closes #18326 PiperOrigin-RevId: 691023328 -- 126c952d6ccd3a4c00e1885923cb0f8ba6db9cf2 by Chenhao Jiang : Add the scatter indices to operand space mapping and change the offset column-wise permutation based on scatter_dims_to_operand_dims, so that they can add together correctly. -- 1ecb608e3687cda358965d9fb60144362fdba477 by Chenhao Jiang : Fix the scatter determinism expander for various dimension numbers -- 985079f4257e632e85162b5525cfd4655ddf555d by Chenhao Jiang : Add a flag for enabling the scatter_determinism_expander on GPU. Merging this change closes #19275 COPYBARA_INTEGRATE_REVIEW=https://github.com/openxla/xla/pull/18326 from serach24:chenhao/opt_det_scatter_full de647d44eb28af71e1580b6e8ed9adc751e50f52 PiperOrigin-RevId: 696956113 --- xla/debug_options_flags.cc | 11 + xla/service/BUILD | 1 + xla/service/gpu/gpu_compiler.cc | 4 +- xla/service/scatter_determinism_expander.cc | 607 +++++++++++++-- .../scatter_determinism_expander_test.cc | 713 ++++++++++++++++-- xla/xla.proto | 9 +- 6 files changed, 1202 insertions(+), 143 deletions(-) diff --git a/xla/debug_options_flags.cc b/xla/debug_options_flags.cc index ee233faf5ff0b..0ce60d0bc6e57 100644 --- a/xla/debug_options_flags.cc +++ b/xla/debug_options_flags.cc @@ -293,6 +293,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_enable_fast_math(false); opts.set_xla_gpu_experimental_parallel_collective_overlap_limit(1); opts.set_xla_pjrt_allow_auto_layout_in_hlo(false); + opts.set_xla_gpu_enable_scatter_determinism_expander(true); return opts; } @@ -2064,6 +2065,16 @@ void MakeDebugOptionsFlags(std::vector* flag_list, "Experimental: Make unset entry computation layout mean auto layout " "instead of default layout in HLO when run through PjRT. In other cases " "(StableHLO or non-PjRT) the auto layout is already used.")); + flag_list->push_back(tsl::Flag( + "xla_gpu_enable_scatter_determinism_expander", + bool_setter_for( + &DebugOptions::set_xla_gpu_enable_scatter_determinism_expander), + debug_options->xla_gpu_enable_scatter_determinism_expander(), + "Enable the scatter determinism expander, an optimized pass that " + "rewrites scatter operations to ensure deterministic behavior with high " + "performance." + "Note that even when this flag is disabled, scatter operations may still " + "be deterministic, although with additional overhead.")); } // NOLINT(readability/fn_size) // Allocates flag_values and flag_objects; this function must not be called more diff --git a/xla/service/BUILD b/xla/service/BUILD index 80fc0ccf0a893..f3a02c1d66961 100644 --- a/xla/service/BUILD +++ b/xla/service/BUILD @@ -2121,6 +2121,7 @@ cc_library( "//xla/hlo/transforms:op_expander_pass", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", "@tsl//tsl/platform:logging", "@tsl//tsl/platform:statusor", ], diff --git a/xla/service/gpu/gpu_compiler.cc b/xla/service/gpu/gpu_compiler.cc index 41a463f9ed5d2..7409231544c98 100755 --- a/xla/service/gpu/gpu_compiler.cc +++ b/xla/service/gpu/gpu_compiler.cc @@ -711,7 +711,9 @@ absl::Status RunOptimizationPasses( if (RequireDeterminism(hlo_module->config())) { // Scatter can be indeterministic if indices are not unique or a non // associative combiner function is used. Eliminate these Scatter ops. - pipeline.AddPass(); + if (debug_options.xla_gpu_enable_scatter_determinism_expander()) { + pipeline.AddPass(); + } pipeline.AddPass( ScatterExpander::kEliminateIndeterministicScatters); } diff --git a/xla/service/scatter_determinism_expander.cc b/xla/service/scatter_determinism_expander.cc index ea462cc3c08fc..e9fcb5bcd6439 100644 --- a/xla/service/scatter_determinism_expander.cc +++ b/xla/service/scatter_determinism_expander.cc @@ -15,11 +15,14 @@ limitations under the License. #include "xla/service/scatter_determinism_expander.h" +#include #include +#include #include #include "absl/container/flat_hash_set.h" #include "absl/strings/str_format.h" +#include "absl/types/span.h" #include "xla/array.h" #include "xla/array2d.h" #include "xla/comparison_util.h" @@ -32,6 +35,7 @@ limitations under the License. #include "xla/literal_util.h" #include "xla/service/hlo_creation_utils.h" #include "xla/service/scatter_utils.h" +#include "xla/shape.h" #include "xla/shape_util.h" #include "xla/util.h" #include "xla/xla_data.pb.h" @@ -62,47 +66,128 @@ static absl::StatusOr> CanonicalizeScatterUpdates( return adjusted_updates; } -// Create the out-of-bound tensor for the scatter operation. -HloInstruction* CreateOutOfBoundTensor(HloComputation* parent, - HloInstruction* scatter_indices, - const Shape& scatter_shape) { - if (scatter_indices->shape().rank() == 1) { - CHECK_EQ(scatter_shape.dimensions_size(), 1); - Array out_of_bound_array({scatter_indices->shape().dimensions(0)}, - scatter_shape.dimensions(0)); - return parent->AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateFromArray(out_of_bound_array))); - } - // More than one dimension in scatter_indices - Array2D out_of_bound_array(scatter_indices->shape().dimensions(0), - scatter_indices->shape().dimensions(1)); +template +HloInstruction* CreateBoundTensorGeneric( + HloComputation* parent, HloInstruction* scatter_indices, + absl::Span operand_dims, + absl::Span index_to_operand_map, bool is_out_of_bound = true, + std::optional> window_sizes = std::nullopt) { + CHECK_GT(scatter_indices->shape().dimensions_size(), 1); + Array2D out_of_bound_array(scatter_indices->shape().dimensions(0), + operand_dims.size()); for (int i = 0; i < scatter_indices->shape().dimensions(0); ++i) { - for (int j = 0; j < scatter_indices->shape().dimensions(1); ++j) { - out_of_bound_array(i, j) = scatter_shape.dimensions(j); + for (int j = 0; j < operand_dims.size(); ++j) { + int mapped_index = index_to_operand_map[j]; + out_of_bound_array(i, j) = + is_out_of_bound + ? operand_dims[mapped_index] + : operand_dims[mapped_index] - (*window_sizes)[mapped_index]; + // : operand_dims[j] - (*window_sizes)[j] ; } } return parent->AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2FromArray2D(out_of_bound_array))); + LiteralUtil::CreateR2FromArray2D(out_of_bound_array))); +} + +// Creates a tensor for the scatter operation based on the value of +// is_out_of_bound. +// +// When is_out_of_bound is true, the tensor is filled with values representing +// the maximum bounds of the scatter shape (out-of-bound values). This is used +// to simulate out-of-bound conditions in the scatter operation. +// +// When is_out_of_bound is false, the tensor is filled with the maximum valid +// indices (calculated as operand_dimensions - window_dimensions). This is used +// to check whether indices are within valid bounds for non-scalar updates. +// +// This function is reusable for both out-of-bound tensor generation and valid +// index checks in scatter operations with non-scalar updates. +absl::StatusOr CreateBoundTensor( + HloComputation* parent, HloInstruction* scatter_indices, + absl::Span operand_dims, + absl::Span index_to_operand_map, bool is_out_of_bound = true, + std::optional> window_sizes = std::nullopt) { + if (!is_out_of_bound && !window_sizes.has_value()) { + return FailedPrecondition( + "window_sizes must be provided when is_out_of_bound is false."); + } + + PrimitiveType type = scatter_indices->shape().element_type(); + if (type == S32) { + return CreateBoundTensorGeneric(parent, scatter_indices, + operand_dims, index_to_operand_map, + is_out_of_bound, window_sizes); + } else if (type == S64) { + return CreateBoundTensorGeneric(parent, scatter_indices, + operand_dims, index_to_operand_map, + is_out_of_bound, window_sizes); + } + return FailedPrecondition("Unexpected type for bound tensor: %s", + PrimitiveType_Name(type)); +} + +// indices shape: (num_indices, num_dims) +// updates shape: (num_indices,) +HloInstruction* FlattenIndices(HloComputation* parent, HloInstruction* indices, + absl::Span operand_dims) { + if (indices->shape().dimensions(1) == 1) { + // Originally scalar indices + return parent->AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(indices->shape().element_type(), + {indices->shape().dimensions(0)}), + indices)); + } + // Step 1: based on the operand_dims, calculate the strides + Array2D strides(operand_dims.size(), 1); + int64_t stride = 1; + for (int i = operand_dims.size() - 1; i >= 0; --i) { + strides(i, 0) = stride; + stride *= operand_dims[i]; + } + auto strides_tensor = parent->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR2FromArray2D(strides))); + + // Step 2: calculate the flattened indices + auto dot_shape = ShapeUtil::MakeShape(indices->shape().element_type(), + {indices->shape().dimensions(0), 1}); + DotDimensionNumbers dim_numbers; + dim_numbers.add_lhs_contracting_dimensions(1); + dim_numbers.add_rhs_contracting_dimensions(0); + PrecisionConfig precision_config; + auto flattened_indices = parent->AddInstruction(HloInstruction::CreateDot( + dot_shape, indices, strides_tensor, dim_numbers, precision_config)); + return parent->AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(indices->shape().element_type(), + {indices->shape().dimensions(0)}), + flattened_indices)); } // Computation for sorting the scalar scatter indices and updates together -HloComputation* ScalarSortingComparison(HloModule* module, - const Shape key_shape, - const Shape update_shape, - int64_t num_updates) { +static HloComputation* SortingComparison(HloModule* module, + const PrimitiveType indices_type, + const PrimitiveType updates_type, + int64_t num_updates, + bool has_scalar_indices) { + Shape key_shape = ShapeUtil::MakeShape(indices_type, {}); + Shape update_shape = ShapeUtil::MakeShape(updates_type, {}); HloComputation::Builder builder("sorting_computation"); auto param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, key_shape, "lhs_key")); auto param1 = builder.AddInstruction( HloInstruction::CreateParameter(1, key_shape, "rhs_key")); - const int kExistingParams = 2; + int param_count = 2; for (int i = 0; i < num_updates; ++i) { - builder.AddInstruction( - HloInstruction::CreateParameter(kExistingParams + i, update_shape, - absl::StrFormat("lhs_update_%d", i))); - builder.AddInstruction( - HloInstruction::CreateParameter(kExistingParams + 1 + i, update_shape, - absl::StrFormat("rhs_update_%d", i))); + builder.AddInstruction(HloInstruction::CreateParameter( + param_count, update_shape, absl::StrFormat("lhs_update_%d", i))); + builder.AddInstruction(HloInstruction::CreateParameter( + param_count + 1, update_shape, absl::StrFormat("rhs_update_%d", i))); + param_count += 2; + } + if (!has_scalar_indices) { + builder.AddInstruction(HloInstruction::CreateParameter( + param_count, key_shape, "lhs_permutation")); + builder.AddInstruction(HloInstruction::CreateParameter( + param_count + 1, key_shape, "rhs_permutation")); } builder.AddInstruction( HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), param0, @@ -113,14 +198,20 @@ HloComputation* ScalarSortingComparison(HloModule* module, static std::vector SortIndicesAndUpdates( HloInstruction* scatter_indices, const std::vector& scatter_updates, int64_t num_indices, - HloScatterInstruction* scatter, HloComputation* parent) { + HloScatterInstruction* scatter, HloComputation* parent, + absl::Span operand_dims, bool has_scalar_indices) { const Shape& indices_shape = scatter_indices->shape(); const Shape& updates_shape = scatter_updates[0]->shape(); auto updates_dims = updates_shape.dimensions(); // Since we canonicalized the scatter updates, the first dim will always be // the number of updates and the rest will be the shape of each update + HloInstruction* scalar_indices = + FlattenIndices(scatter->parent(), scatter_indices, operand_dims); - HloInstruction* scalar_indices = scatter_indices; + // Create the shape for a single index tuple + // Create [0...num_indices] tensor for permutation in sorting + auto indices_permutation = parent->AddInstruction(HloInstruction::CreateIota( + ShapeUtil::MakeShape(indices_shape.element_type(), {num_indices}), 0)); std::vector single_update_dimensions(updates_dims.begin() + 1, updates_dims.end()); @@ -130,19 +221,23 @@ static std::vector SortIndicesAndUpdates( const Shape& scalar_index_shape = ShapeUtil::MakeShape(indices_shape.element_type(), {num_indices}); + auto* comparison = SortingComparison( + scatter->GetModule(), indices_shape.element_type(), + updates_shape.element_type(), scatter_updates.size(), has_scalar_indices); - auto* comparison = ScalarSortingComparison( - scatter->GetModule(), - ShapeUtil::MakeShape(indices_shape.element_type(), {}), - ShapeUtil::MakeShape(updates_shape.element_type(), {}), - scatter_updates.size()); - + // The sorting operation contains the scalar indices and the updates, and if + // the scatter indices were not scalar, the sorting operation will also + // contain the indices permutation std::vector sort_operands = {scalar_indices}; std::vector sort_shapes = {scalar_index_shape}; for (auto update : scatter_updates) { sort_operands.push_back(update); sort_shapes.push_back(update->shape()); } + if (!has_scalar_indices) { + sort_operands.push_back(indices_permutation); + sort_shapes.push_back(indices_permutation->shape()); + } auto* sorting = parent->AddInstruction(HloInstruction::CreateSort( ShapeUtil::MakeTupleShape(sort_shapes), 0, sort_operands, comparison, @@ -160,6 +255,34 @@ static std::vector SortIndicesAndUpdates( std::vector sorted_tensors = {sorted_scalar_indices}; sorted_tensors.insert(sorted_tensors.end(), sorted_updates.begin(), sorted_updates.end()); + if (has_scalar_indices) { + return sorted_tensors; + } + // When the scatter indices were not scalar, need to return the sorted scatter + // indices + auto* sorted_indices_arg = + parent->AddInstruction(HloInstruction::CreateGetTupleElement( + indices_permutation->shape(), sorting, sorted_tensors.size())); + sorted_indices_arg = parent->AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(sorted_indices_arg->shape().element_type(), + {num_indices, 1}), + sorted_indices_arg)); + // Use gather of sorted_indices_arg to get the sorted original indices + GatherDimensionNumbers gather_dim_numbers; + gather_dim_numbers.add_offset_dims( + 1); // Preserving the inner dimension (columns) + gather_dim_numbers.add_start_index_map( + 0); // Mapping start_indices to the first dimension of the operand + gather_dim_numbers.add_collapsed_slice_dims(0); + gather_dim_numbers.set_index_vector_dim(1); + std::vector slice_sizes = {1, + scatter_indices->shape().dimensions(1)}; + auto* sorted_expanded_indices = + parent->AddInstruction(HloInstruction::CreateGather( + scatter_indices->shape(), scatter_indices, sorted_indices_arg, + gather_dim_numbers, slice_sizes, + /*indices_are_sorted=*/true)); + sorted_tensors.push_back(sorted_expanded_indices); return sorted_tensors; } @@ -259,7 +382,6 @@ absl::StatusOr> ComputePrefixScan( std::vector prefix_scans(sorted_updates.size()); HloInstruction* prefix_scan_update = nullptr; for (int i = 0; i < sorted_updates.size(); i++) { - // TODO(chenhao) change to use the extracted computation TF_ASSIGN_OR_RETURN( HloComputation * to_apply, CallComputationAndGetIthOutputWithBinaryParams(scatter->to_apply(), i)); @@ -273,19 +395,18 @@ absl::StatusOr> ComputePrefixScan( } static HloInstruction* FindLastOccurrenceIndices( - HloInstruction* scatter_indices, HloInstruction* sorted_scalar_indices, - HloInstruction* scatter, HloComputation* parent, int64_t num_indices) { - int64_t indices_len = sorted_scalar_indices->shape().dimensions(0); - HloInstruction* sorted_indices = sorted_scalar_indices; + HloInstruction* sorted_indices, HloInstruction* sorted_scalar_indices, + HloInstruction* scatter, HloComputation* parent, int64_t num_indices, + HloInstruction* out_of_bound_tensor) { + int64_t indices_len = sorted_indices->shape().dimensions(0); + const PrimitiveType& indices_type = sorted_indices->shape().element_type(); auto* sorted_indices_preceding_part = parent->AddInstruction(HloInstruction::CreateSlice( - ShapeUtil::MakeShape(scatter_indices->shape().element_type(), - {indices_len - 1}), + ShapeUtil::MakeShape(indices_type, {indices_len - 1}), sorted_scalar_indices, {0}, {indices_len - 1}, {1})); auto* sorted_indices_following_part = parent->AddInstruction(HloInstruction::CreateSlice( - ShapeUtil::MakeShape(scatter_indices->shape().element_type(), - {indices_len - 1}), + ShapeUtil::MakeShape(indices_type, {indices_len - 1}), sorted_scalar_indices, {1}, {indices_len}, {1})); auto* indices_mask_without_padding = parent->AddInstruction(HloInstruction::CreateCompare( @@ -304,18 +425,227 @@ static HloInstruction* FindLastOccurrenceIndices( // Mask the indices indices_mask = parent->AddInstruction(HloInstruction::CreateBroadcast( - ShapeUtil::MakeShape(PRED, scatter_indices->shape().dimensions()), + ShapeUtil::MakeShape(PRED, sorted_indices->shape().dimensions()), indices_mask, {0})); - auto* out_of_bound_tensor = - CreateOutOfBoundTensor(parent, scatter_indices, scatter->shape()); - auto* masked_indices = parent->AddInstruction(HloInstruction::CreateTernary( sorted_indices->shape(), HloOpcode::kSelect, indices_mask, sorted_indices, out_of_bound_tensor)); return masked_indices; } +template +HloInstruction* ExpandIndexOffsetsFromUpdateShape( + HloComputation* parent, const Shape& update_shape, + const ScatterDimensionNumbers& dim_num, const Shape& operand_shape, + absl::Span index_to_operand_map, + absl::Span actual_update_window_dims) { + // Calculate the offset tensor for each element of the update tensor. + // The offset tensor is represented in (num_elements_in_update, index_dim). + + int64_t num_elements = ShapeUtil::ElementsIn(update_shape); + int64_t operand_rank = operand_shape.dimensions_size(); + + Array2D offset_tensor(num_elements, operand_rank); + + std::vector is_inserted_window_dims(operand_rank, false); + for (int i = 0; i < dim_num.inserted_window_dims_size(); ++i) { + is_inserted_window_dims[dim_num.inserted_window_dims(i)] = true; + } + + // Compute the inverse of the index_to_operand_map + std::vector operand_to_index_map(operand_rank, -1); + for (int i = 0; i < operand_rank; ++i) { + operand_to_index_map[index_to_operand_map[i]] = i; + } + + for (int64_t linear_index = 0; linear_index < num_elements; ++linear_index) { + // Calculate the multi-dimensional index from the linear index + int64_t current_index = linear_index; + int inserted_window_dim_size = 0; + // Handle 0th to (operand_rank-2)th dimensions + for (int i = operand_rank - 1; i >= 0; --i) { + if (is_inserted_window_dims[i]) { + inserted_window_dim_size++; + offset_tensor(linear_index, operand_to_index_map[i]) = 0; + } else { + int64_t dim_size = actual_update_window_dims[i]; + offset_tensor(linear_index, operand_to_index_map[i]) = + current_index % dim_size; + current_index /= dim_size; + } + } + } + + // Return the offset tensor as an HloInstruction + return parent->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR2FromArray2D(offset_tensor))); +} + +// Expand the indices based on index_offset +HloInstruction* ExpandIndices(HloComputation* parent, HloInstruction* indices, + HloInstruction* index_offsets) { + // For each index we need to add the index_offset to the base index + // To do that, we first broadcast the indices and index_offsets to the same + // shape, then add the index_offset to the base index and flatten the + // result Broadcast to be (num_indices, length_of_index_offsets, + // length_of_indices). + bool is_one_dimensional = indices->shape().dimensions_size() == 1; + + int64_t num_indices = indices->shape().dimensions(0); + int64_t num_offsets = index_offsets->shape().dimensions(0); + int64_t index_length = + is_one_dimensional ? 1 : indices->shape().dimensions(1); + + Shape final_shape = + ShapeUtil::MakeShape(indices->shape().element_type(), + {num_indices, num_offsets, index_length}); + auto broadcasted_indices = + parent->AddInstruction(HloInstruction::CreateBroadcast( + final_shape, indices, + is_one_dimensional ? std::vector{0} + : std::vector{0, 2})); + auto broadcasted_offsets = parent->AddInstruction( + HloInstruction::CreateBroadcast(final_shape, index_offsets, {1, 2})); + auto expanded_indices = parent->AddInstruction(HloInstruction::CreateBinary( + final_shape, HloOpcode::kAdd, broadcasted_indices, broadcasted_offsets)); + // Flatten the result to be (num_indices * num_offsets, index_length) + if (is_one_dimensional) { + return parent->AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(indices->shape().element_type(), + {num_indices * num_offsets}), + expanded_indices)); + } + return parent->AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(indices->shape().element_type(), + {num_indices * num_offsets, index_length}), + expanded_indices)); +} + +// Function to create a reduction computation for logical AND +HloComputation* ReduceAndComputation(HloModule* module) { + // Create a computation builder + HloComputation::Builder builder("reduce_logical_and"); + + // Define the scalar shape for boolean operations + const Shape bool_shape = ShapeUtil::MakeShape(PRED, {}); + + // Add parameters for the reduction computation. + // These represent the elements to be combined (lhs and rhs). + HloInstruction* lhs = builder.AddInstruction( + HloInstruction::CreateParameter(0, bool_shape, "lhs")); + HloInstruction* rhs = builder.AddInstruction( + HloInstruction::CreateParameter(1, bool_shape, "rhs")); + + // Create the logical AND operation between the two parameters + builder.AddInstruction( + HloInstruction::CreateBinary(bool_shape, HloOpcode::kAnd, lhs, rhs)); + + // Build and return the computation object + return module->AddEmbeddedComputation(builder.Build()); +} + +absl::StatusOr CheckValidIndices( + HloComputation* parent, HloInstruction* indices, + absl::Span operand_dims, + absl::Span window_sizes, + absl::Span full_index_to_operand_dims) { + // check if indices and indices with the largest offsets are out of bound + // Essentially we need to do the following: + // 1. Check base indices >= [0, 0, 0, ...] + // 2. Check last indices <= [bounds...] + // 3. For each check, generate a same size tensor, and then do a reduce across + // rows to get a mask of size (n, 1) + auto init_reduce_value = parent->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(true))); + auto reduce_computation = ReduceAndComputation(parent->parent()); + + // 1. Check base indices >= [0, 0, 0, ...] + // first generate a zero tensor of the same size as the indices + auto* zero_constant = parent->AddInstruction( + HloInstruction::CreateConstant(indices->shape().element_type() == S64 + ? LiteralUtil::CreateR0(0) + : LiteralUtil::CreateR0(0))); + auto* zero_broadcasted = parent->AddInstruction( + HloInstruction::CreateBroadcast(indices->shape(), zero_constant, {})); + auto* zero_check = parent->AddInstruction(HloInstruction::CreateCompare( + ShapeUtil::MakeShape(PRED, indices->shape().dimensions()), indices, + zero_broadcasted, ComparisonDirection::kGe)); + HloInstruction* zero_check_mask; + // Reduce across rows to get a mask (for multi-dimensional indices). + zero_check_mask = parent->AddInstruction(HloInstruction::CreateReduce( + ShapeUtil::MakeShape(PRED, {indices->shape().dimensions(0)}), zero_check, + init_reduce_value, {1}, reduce_computation)); + // 2. Check last indices <= [bounds...] + // Check if the index is OOB w.r.t. the operand dimensions and window sizes. + TF_ASSIGN_OR_RETURN( + HloInstruction * max_valid_index_constant, + CreateBoundTensor(parent, indices, operand_dims, + full_index_to_operand_dims, false, window_sizes)); + auto oob_check = parent->AddInstruction(HloInstruction::CreateCompare( + ShapeUtil::MakeShape(PRED, indices->shape().dimensions()), + max_valid_index_constant, indices, ComparisonDirection::kGe)); + HloInstruction* oob_check_mask; + if (indices->shape().rank() == 1) { + oob_check_mask = oob_check; + } else { + // Reduce across rows to get a mask (for multi-dimensional indices). + oob_check_mask = parent->AddInstruction(HloInstruction::CreateReduce( + ShapeUtil::MakeShape(PRED, {indices->shape().dimensions(0)}), oob_check, + init_reduce_value, {1}, reduce_computation)); + } + // Combine the results of the two checks above. + auto* valid_index_mask = parent->AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(PRED, {indices->shape().dimensions(0)}), + HloOpcode::kAnd, zero_check_mask, oob_check_mask)); + return parent->AddInstruction(HloInstruction::CreateBroadcast( + ShapeUtil::MakeShape(PRED, indices->shape().dimensions()), + valid_index_mask, {0})); +} + +// Add dimensions that are not covered in the indices_to_operand_map to the end +// of indices +absl::StatusOr AddImplicitDimensionsToIndices( + int64_t operand_rank, absl::Span indices_to_operand_map, + HloInstruction* indices) { + const Shape& indices_shape = indices->shape(); + HloComputation* computation = indices->parent(); + + // Get the batch size (N) and S (number of dimensions in index_vector) + int64_t batch_size = indices_shape.dimensions(0); + int64_t num_indices_dims = indices_to_operand_map.size(); + + // Create a tensor of zeros with the target shape [N, operand_rank] + Shape expanded_shape = ShapeUtil::MakeShape(indices_shape.element_type(), + {batch_size, operand_rank}); + + HloInstruction* zero_filled_tensor = computation->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR2FromArray2D( + Array2D(batch_size, operand_rank - num_indices_dims, 0)))); + // Concatenate the zero-filled tensor with the index_vector + HloInstruction* expanded_indices = + computation->AddInstruction(HloInstruction::CreateConcatenate( + expanded_shape, {indices, zero_filled_tensor}, 1)); + return expanded_indices; +} + +std::vector ComputeFullIndexToOperandDims( + const Shape& operand_shape, ScatterDimensionNumbers& dim_numbers) { + std::vector full_index_to_operand_dims( + dim_numbers.mutable_scatter_dims_to_operand_dims()->begin(), + dim_numbers.mutable_scatter_dims_to_operand_dims()->end()); + // Add the implicit dimensions to the index_to_operand_map + absl::flat_hash_set existing_dims( + dim_numbers.scatter_dims_to_operand_dims().begin(), + dim_numbers.scatter_dims_to_operand_dims().end()); + + for (int i = 0; i < operand_shape.dimensions_size(); i++) { + if (existing_dims.find(i) == existing_dims.end()) + full_index_to_operand_dims.push_back(i); + } + return full_index_to_operand_dims; +} + absl::StatusOr ScatterDeterminismExpander::ExpandInstruction( HloInstruction* inst) { auto* scatter = Cast(inst); @@ -323,8 +653,7 @@ absl::StatusOr ScatterDeterminismExpander::ExpandInstruction( HloInstruction* scatter_indices = scatter->scatter_indices(); std::vector scatter_updates( scatter->scatter_updates().begin(), scatter->scatter_updates().end()); - const ScatterDimensionNumbers& dim_numbers = - scatter->scatter_dimension_numbers(); + ScatterDimensionNumbers dim_numbers = scatter->scatter_dimension_numbers(); // If the updates tensors are empty, there is no need to update the operands. // The operands can be forwarded. @@ -349,33 +678,170 @@ absl::StatusOr ScatterDeterminismExpander::ExpandInstruction( // Canonicalize the scatter_indices, after which the size of its most-major // dimension must be same as the while loop trip count. + HloInstruction* original_scatter_indices = scatter_indices; TF_ASSIGN_OR_RETURN(scatter_indices, CanonicalizeScatterIndices( scatter_indices, dim_numbers.index_vector_dim())); CHECK_EQ(scatter_indices_count, scatter_indices->shape().dimensions(0)); + // We compromise for maintainability and make the scatter_indices always 2D, + // so that the implementation could be easier, as we do not need to maintain + // two sets of code for 1D and 2D scatter_indices. + if (scatter_indices->shape().dimensions_size() == 1) { + scatter_indices = + scatter->parent()->AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(scatter_indices->shape().element_type(), + {scatter_indices->shape().dimensions(0), 1}), + scatter_indices)); + } + CHECK_GT(scatter_indices->shape().dimensions_size(), 1); + bool has_scalar_indices = scatter_indices->shape().dimensions(1) == 1; // Canonicalize the updates, after which the size of their most-major // dimensions must be same as the while loop trip count. - TF_ASSIGN_OR_RETURN(scatter_updates, CanonicalizeScatterUpdates( - scatter_updates, scatter_indices, - dim_numbers, scatter_indices_count)); + TF_ASSIGN_OR_RETURN( + scatter_updates, + CanonicalizeScatterUpdates(scatter_updates, original_scatter_indices, + dim_numbers, scatter_indices_count)); HloComputation* parent = scatter->parent(); + auto updates_shape = scatter_updates[0]->shape(); + auto updates_dims = scatter_updates[0]->shape().dimensions(); + // Since we canonicalized the scatter updates, the first dim will always be + // the number of updates and the rest will be the shape of each update + std::vector one_update_dimensions(updates_dims.begin() + 1, + updates_dims.end()); + const Shape& update_shape = + ShapeUtil::MakeShape(updates_shape.element_type(), one_update_dimensions); + + ScatterDimensionNumbers new_dim_numbers; + // Check if each update is a scalar based on update shape + bool non_scalar_update = scatter_updates[0]->shape().dimensions_size() > 1; + + std::vector full_index_to_operand_dims = + ComputeFullIndexToOperandDims(scatter_operands[0]->shape(), dim_numbers); + + TF_ASSIGN_OR_RETURN( + HloInstruction * out_of_bound_tensor, + CreateBoundTensor(parent, scatter_indices, scatter->shape().dimensions(), + full_index_to_operand_dims)); + + if (non_scalar_update) { + // Extract operand dimensions + const Shape& operand_shape = scatter_operands[0]->shape(); + + int num_operand_dims = operand_shape.dimensions_size(); + std::vector actual_update_window_dims(num_operand_dims); + int update_dim_index = 0; + for (int i = 0; i < num_operand_dims; ++i) { + if (std::find(dim_numbers.inserted_window_dims().begin(), + dim_numbers.inserted_window_dims().end(), + i) != dim_numbers.inserted_window_dims().end()) { + actual_update_window_dims[i] = 1; + } else { + actual_update_window_dims[i] = + update_shape.dimensions(update_dim_index); + update_dim_index++; + } + } + + HloInstruction* index_offsets = + scatter_indices->shape().element_type() == S32 + ? ExpandIndexOffsetsFromUpdateShape( + scatter->parent(), update_shape, dim_numbers, operand_shape, + full_index_to_operand_dims, actual_update_window_dims) + : ExpandIndexOffsetsFromUpdateShape( + scatter->parent(), update_shape, dim_numbers, operand_shape, + full_index_to_operand_dims, actual_update_window_dims); + + // Map scatter_indices into operand space + TF_ASSIGN_OR_RETURN( + scatter_indices, + AddImplicitDimensionsToIndices( + scatter_operands[0]->shape().dimensions_size(), + dim_numbers.scatter_dims_to_operand_dims(), scatter_indices)); + CHECK(scatter_indices->shape().dimensions(0) == scatter_indices_count); + // If any updates are out of bound, we change the corresponding indices to + // be oob_tensor values + TF_ASSIGN_OR_RETURN( + HloInstruction * oob_check_mask, + CheckValidIndices(scatter->parent(), scatter_indices, + scatter_operands[0]->shape().dimensions(), + actual_update_window_dims, + full_index_to_operand_dims)); + + scatter_indices = parent->AddInstruction(HloInstruction::CreateTernary( + scatter_indices->shape(), HloOpcode::kSelect, oob_check_mask, + scatter_indices, out_of_bound_tensor)); + scatter_indices = + ExpandIndices(scatter->parent(), scatter_indices, index_offsets); + + // Check if the number of indices is the same as + // (num of indices before expanding * num of offsets) + CHECK_EQ(scatter_indices->shape().dimensions(0), + scatter_indices_count * ShapeUtil::ElementsIn(update_shape)); + + // Expand the updates + const int64_t num_elements = + ShapeUtil::ElementsIn(scatter_updates[0]->shape()); + for (int i = 0; i < scatter_updates.size(); i++) { + scatter_updates[i] = parent->AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(scatter_updates[i]->shape().element_type(), + {num_elements}), + scatter_updates[i])); + } + + // Create a new dimension numbers for the new scatter operation + // As we have scalar updates, there is no update_window_dims + new_dim_numbers.clear_update_window_dims(); + new_dim_numbers.set_index_vector_dim(1); + // Mitigate the missed dimensions + for (int i = 0; i < operand_shape.dimensions_size() - + dim_numbers.input_batching_dims_size(); + i++) { + new_dim_numbers.add_inserted_window_dims(i); + } + for (int i = 0; i < operand_shape.dimensions_size(); i++) { + new_dim_numbers.add_scatter_dims_to_operand_dims( + full_index_to_operand_dims[i]); + } + } else { + new_dim_numbers = dim_numbers; + } // Sort the scatter indices and updates together based on the scatter indices. int64_t num_indices = ShapeUtil::ElementsIn(scatter_updates[0]->shape()); std::vector sorted_tensors = SortIndicesAndUpdates( - scatter_indices, scatter_updates, num_indices, scatter, parent); + scatter_indices, scatter_updates, num_indices, scatter, parent, + scatter_operands[0]->shape().dimensions(), has_scalar_indices); HloInstruction* sorted_scalar_indices = sorted_tensors[0]; - std::vector sorted_updates(sorted_tensors.begin() + 1, - sorted_tensors.end()); + std::vector sorted_updates( + sorted_tensors.begin() + 1, + sorted_tensors.begin() + 1 + scatter_updates.size()); + HloInstruction* sorted_indices; + if (has_scalar_indices) { + sorted_indices = parent->AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(sorted_scalar_indices->shape().element_type(), + {num_indices, 1}), + sorted_scalar_indices)); + } else { + sorted_indices = sorted_tensors[sorted_tensors.size() - 1]; + } TF_ASSIGN_OR_RETURN(std::vector prefix_scan_updates, ComputePrefixScan(sorted_updates, sorted_scalar_indices, scatter, parent)); - - HloInstruction* last_occurrence_indices = FindLastOccurrenceIndices( - scatter_indices, sorted_scalar_indices, scatter, parent, num_indices); + if (non_scalar_update) { + // As the indices are expanded, we need to recompute out-of-bound tensor + // with the same shape + TF_ASSIGN_OR_RETURN( + out_of_bound_tensor, + CreateBoundTensor(parent, sorted_indices, + scatter_operands[0]->shape().dimensions(), + full_index_to_operand_dims)); + } + HloInstruction* last_occurrence_indices = + FindLastOccurrenceIndices(sorted_indices, sorted_scalar_indices, scatter, + parent, num_indices, out_of_bound_tensor); CHECK(last_occurrence_indices != nullptr) << "Last occurrence indices should not be nullptr"; @@ -383,7 +849,7 @@ absl::StatusOr ScatterDeterminismExpander::ExpandInstruction( // Finally, recreate the scatter instruction with unique indices return parent->AddInstruction(HloInstruction::CreateScatter( scatter->shape(), scatter_operands, last_occurrence_indices, - prefix_scan_updates, scatter->to_apply(), dim_numbers, + prefix_scan_updates, scatter->to_apply(), new_dim_numbers, /*indices_are_sorted=*/true, /*unique_indices=*/true)); } @@ -437,25 +903,8 @@ bool CheckOutputDependency(HloComputation* to_apply, int operand_size) { bool ScatterDeterminismExpander::InstructionMatchesPattern( HloInstruction* inst) { auto* scatter = DynCast(inst); - // Need to check if updates and indices are scalar, as the current pass does - // not expand scatter with multi-dimensional updates or indices. This is - // temporary and will be removed in a future PR soon. - if (scatter == nullptr) { - return false; - } - - const Shape& indices_shape = scatter->scatter_indices()->shape(); - const Shape& updates_shape = scatter->scatter_updates()[0]->shape(); - - // Check if indices and updates are effectively 1D. - bool indices_are_1d = - (indices_shape.rank() == 1 || - (indices_shape.rank() == 2 && indices_shape.dimensions(1) == 1)); - bool updates_are_1d = - (updates_shape.rank() == 1 || - (updates_shape.rank() == 2 && updates_shape.dimensions(1) == 1)); - return indices_are_1d && updates_are_1d && !IsScatterDeterministic(scatter) && + return (scatter != nullptr) && !IsScatterDeterministic(scatter) && CheckOutputDependency(scatter->to_apply(), scatter->scatter_operands().size()); } diff --git a/xla/service/scatter_determinism_expander_test.cc b/xla/service/scatter_determinism_expander_test.cc index 366dbc768c5d0..eed978dd9e160 100644 --- a/xla/service/scatter_determinism_expander_test.cc +++ b/xla/service/scatter_determinism_expander_test.cc @@ -89,6 +89,36 @@ TEST_F(ScatterDeterminismExpanderTest, EXPECT_TRUE(result); } +TEST_F(ScatterDeterminismExpanderTest, + EliminateNonScalarScatterWithNonAssociativeCombiner) { + const char* const kModuleStr = R"( + HloModule scatter_determinisic_expander + + scatter_computation { + arg1.173 = f32[] parameter(1) + arg0.172 = f32[] parameter(0) + ROOT add.48 = f32[] add(arg0.172, arg1.173) + } + + ENTRY fused_computation { + bitcast.2335 = f32[1,4096] parameter(0) + pad.96 = s32[4096,2] parameter(1) + bitcast.2748 = f32[4096,1,1] parameter(2) + ROOT scatter.48 = f32[1,4096] scatter(bitcast.2335, pad.96, bitcast.2748), + update_window_dims={1,2}, inserted_window_dims={}, + scatter_dims_to_operand_dims={0,1}, index_vector_dim=1, + to_apply=scatter_computation + })"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kModuleStr)); + + ScatterDeterminismExpander scatter_determinism_expander; + TF_ASSERT_OK_AND_ASSIGN( + bool result, RunHloPass(&scatter_determinism_expander, module.get())); + EXPECT_TRUE(result); +} + TEST_F(ScatterDeterminismExpanderTest, DoNotEliminateScatterWithAssociativeFp32Combiner) { const char* const kModuleStr = R"( @@ -148,7 +178,7 @@ TEST_F(ScatterDeterminismExpanderTest, DoNotEliminateScatterWithOneUpdate) { EXPECT_FALSE(result); } -TEST_F(ScatterDeterminismExpanderTest, ScatterAddCorrectnessTest) { +TEST_F(ScatterDeterminismExpanderTest, ScalarScatterAddCorrectnessTest) { const char* const kModuleStr = R"( HloModule scatter_determinism_expander @@ -170,6 +200,9 @@ TEST_F(ScatterDeterminismExpanderTest, ScatterAddCorrectnessTest) { TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kModuleStr)); + auto cloned_module = module->Clone(); + Literal expected_literal = ExecuteAndTransfer(std::move(cloned_module), {}); + auto expected_result = expected_literal.data(); ScatterDeterminismExpander scatter_determinism_expander; TF_ASSERT_OK_AND_ASSIGN( @@ -177,7 +210,46 @@ TEST_F(ScatterDeterminismExpanderTest, ScatterAddCorrectnessTest) { EXPECT_TRUE(result); - std::vector expected_result = {2.0, 16.0, 14.0, 3.0}; + Literal result_literal = ExecuteAndTransfer(std::move(module), {}); + + auto result_data = result_literal.data(); + std::vector actual_result(result_data.begin(), result_data.end()); + + EXPECT_EQ(actual_result, expected_result); +} + +TEST_F(ScatterDeterminismExpanderTest, + ScalarScatterAddOutOfBoundCorrectnessTest) { + const char* const kModuleStr = R"( + HloModule scatter_determinism_expander + + scatter_computation { + arg1.173 = f32[] parameter(1) + arg0.172 = f32[] parameter(0) + ROOT add.48 = f32[] add(arg0.172, arg1.173) + } + + ENTRY scatter_add_computation { + operand = f32[4] constant({0, 0, 0, 0}) + indices = s32[7,1] constant({{0}, {1}, {5}, {4}, {1}, {1}, {2}}) + updates = f32[7] constant({2, 1, 5, 3, 8, 7, 9}) + ROOT scatter.48 = f32[4] scatter(operand, indices, updates), + update_window_dims={}, inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, index_vector_dim=1, + to_apply=scatter_computation + })"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kModuleStr)); + auto cloned_module = module->Clone(); + Literal expected_literal = ExecuteAndTransfer(std::move(cloned_module), {}); + auto expected_result = expected_literal.data(); + + ScatterDeterminismExpander scatter_determinism_expander; + TF_ASSERT_OK_AND_ASSIGN( + bool result, RunHloPass(&scatter_determinism_expander, module.get())); + + EXPECT_TRUE(result); Literal result_literal = ExecuteAndTransfer(std::move(module), {}); @@ -187,6 +259,477 @@ TEST_F(ScatterDeterminismExpanderTest, ScatterAddCorrectnessTest) { EXPECT_EQ(actual_result, expected_result); } +TEST_F(ScatterDeterminismExpanderTest, + ScatterAddWithNonScalarIndexCorrectnessTest) { + const char* const kModuleStr = R"( + HloModule scatter_determinism_expander + + scatter_computation { + arg1.173 = f32[] parameter(1) + arg0.172 = f32[] parameter(0) + ROOT add.48 = f32[] add(arg0.172, arg1.173) + } + + ENTRY scatter_add_computation { + operand = f32[3, 3] constant({{0, 0, 0}, {0, 0, 0}, {0, 0, 0}}) + indices = s32[3, 2] constant({{0, 0}, {1, 1}, {2,2}}) + updates = f32[3] constant({2, 1, 3}) + ROOT scatter.48 = f32[3,3] scatter(operand, indices, updates), + update_window_dims={}, inserted_window_dims={0, 1}, + scatter_dims_to_operand_dims={0, 1}, index_vector_dim=1, + to_apply=scatter_computation + })"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kModuleStr)); + auto cloned_module = module->Clone(); + Literal expected_literal = ExecuteAndTransfer(std::move(cloned_module), {}); + auto expected_result = expected_literal.data(); + + ScatterDeterminismExpander scatter_determinism_expander; + TF_ASSERT_OK_AND_ASSIGN( + bool result, RunHloPass(&scatter_determinism_expander, module.get())); + + EXPECT_TRUE(result); + + Literal result_literal = ExecuteAndTransfer(std::move(module), {}); + + auto result_data = result_literal.data(); + std::vector actual_result(result_data.begin(), result_data.end()); + + EXPECT_EQ(actual_result, expected_result); +} + +TEST_F(ScatterDeterminismExpanderTest, + ScatterAddWithNonScalarUpdateCorrectnessTest) { + const char* const kModuleStr = R"( + HloModule scatter_determinism_expander + + scatter_computation { + arg1.173 = f32[] parameter(1) + arg0.172 = f32[] parameter(0) + ROOT add.48 = f32[] add(arg0.172, arg1.173) + } + + ENTRY scatter_add_computation { + operand = f32[4] constant({0, 0, 0, 0}) + indices = s32[3, 1] constant({{1}, {2}, {3}}) + updates = f32[3, 2] constant({{1, 2}, {4, 7}, {10, 13}}) + ROOT scatter.48 = f32[4] scatter(operand, indices, updates), + update_window_dims={1}, inserted_window_dims={}, + scatter_dims_to_operand_dims={0}, index_vector_dim=1, + to_apply=scatter_computation + })"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kModuleStr)); + auto cloned_module = module->Clone(); + Literal expected_literal = ExecuteAndTransfer(std::move(cloned_module), {}); + auto expected_result = expected_literal.data(); + ScatterDeterminismExpander scatter_determinism_expander; + TF_ASSERT_OK_AND_ASSIGN( + bool result, RunHloPass(&scatter_determinism_expander, module.get())); + + EXPECT_TRUE(result); + + Literal result_literal = ExecuteAndTransfer(std::move(module), {}); + + auto result_data = result_literal.data(); + std::vector actual_result(result_data.begin(), result_data.end()); + + EXPECT_EQ(actual_result, expected_result); +} + +TEST_F(ScatterDeterminismExpanderTest, + ScatterAddWithNonScalarIndexAndUpdateCorrectness2DTest1) { + const char* const kModuleStr = R"( + HloModule scatter_determinism_expander + + scatter_computation { + arg1.173 = f32[] parameter(1) + arg0.172 = f32[] parameter(0) + ROOT add.48 = f32[] add(arg0.172, arg1.173) + } + + ENTRY scatter_add_computation { + operand = f32[3, 3] constant({{0, 0, 0}, {0, 0, 0}, {0, 0, 0}}) + indices = s32[4, 2] constant({{0, 0}, {0, 1}, {1, 1}, {1, 2}}) + updates = f32[4, 2] constant({{1, 2}, {4, 7}, {10, 13}, {21, 27}}) + ROOT scatter.48 = f32[3, 3] scatter(operand, indices, updates), + update_window_dims={1}, inserted_window_dims={0}, + scatter_dims_to_operand_dims={1, 0}, index_vector_dim=1, + to_apply=scatter_computation + })"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kModuleStr)); + auto cloned_module = module->Clone(); + Literal expected_literal = ExecuteAndTransfer(std::move(cloned_module), {}); + auto expected_result = expected_literal.data(); + + ScatterDeterminismExpander scatter_determinism_expander; + TF_ASSERT_OK_AND_ASSIGN( + bool result, RunHloPass(&scatter_determinism_expander, module.get())); + + EXPECT_TRUE(result); + + Literal result_literal = ExecuteAndTransfer(std::move(module), {}); + + auto result_data = result_literal.data(); + std::vector actual_result(result_data.begin(), result_data.end()); + + EXPECT_EQ(actual_result, expected_result); +} + +TEST_F(ScatterDeterminismExpanderTest, + ScatterAddWithNonScalarIndexAndUpdateCorrectness2DTest2) { + const char* const kModuleStr = R"( + HloModule scatter_determinism_expander + + scatter_computation { + arg1.173 = f32[] parameter(1) + arg0.172 = f32[] parameter(0) + ROOT add.48 = f32[] add(arg0.172, arg1.173) + } + + ENTRY scatter_add_computation { + operand = f32[3, 3] constant({{0, 0, 0}, {0, 0, 0}, {0, 0, 0}}) + indices = s32[4, 2] constant({{0, 0}, {0, 1}, {1, 1}, {1, 2}}) + updates = f32[4, 2] constant({{1, 2}, {4, 7}, {10, 13}, {21, 27}}) + ROOT scatter.48 = f32[3, 3] scatter(operand, indices, updates), + update_window_dims={1}, inserted_window_dims={1}, + scatter_dims_to_operand_dims={1, 0}, index_vector_dim=1, + to_apply=scatter_computation + })"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kModuleStr)); + auto cloned_module = module->Clone(); + Literal expected_literal = ExecuteAndTransfer(std::move(cloned_module), {}); + auto expected_result = expected_literal.data(); + + ScatterDeterminismExpander scatter_determinism_expander; + TF_ASSERT_OK_AND_ASSIGN( + bool result, RunHloPass(&scatter_determinism_expander, module.get())); + + EXPECT_TRUE(result); + + Literal result_literal = ExecuteAndTransfer(std::move(module), {}); + + auto result_data = result_literal.data(); + std::vector actual_result(result_data.begin(), result_data.end()); + + EXPECT_EQ(actual_result, expected_result); +} + +TEST_F(ScatterDeterminismExpanderTest, + ScatterAddWithNonScalarIndexAndUpdateCorrectness2DTest3) { + const char* const kModuleStr = R"( + HloModule scatter_determinism_expander + + scatter_computation { + arg1.173 = f32[] parameter(1) + arg0.172 = f32[] parameter(0) + ROOT add.48 = f32[] add(arg0.172, arg1.173) + } + + ENTRY scatter_add_computation { + operand = f32[3, 3] constant({{0, 0, 0}, {0, 0, 0}, {0, 0, 0}}) + indices = s32[4, 2] constant({{0, 0}, {0, 1}, {1, 1}, {1, 2}}) + updates = f32[4, 2] constant({{1, 2}, {4, 7}, {10, 13}, {21, 27}}) + ROOT scatter.48 = f32[3, 3] scatter(operand, indices, updates), + update_window_dims={1}, inserted_window_dims={0}, + scatter_dims_to_operand_dims={0, 1}, index_vector_dim=1, + to_apply=scatter_computation + })"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kModuleStr)); + auto cloned_module = module->Clone(); + Literal expected_literal = ExecuteAndTransfer(std::move(cloned_module), {}); + auto expected_result = expected_literal.data(); + + ScatterDeterminismExpander scatter_determinism_expander; + TF_ASSERT_OK_AND_ASSIGN( + bool result, RunHloPass(&scatter_determinism_expander, module.get())); + + EXPECT_TRUE(result); + + Literal result_literal = ExecuteAndTransfer(std::move(module), {}); + + auto result_data = result_literal.data(); + std::vector actual_result(result_data.begin(), result_data.end()); + + EXPECT_EQ(actual_result, expected_result); +} + +TEST_F(ScatterDeterminismExpanderTest, + ScatterAddWithNonScalarIndexAndUpdateCorrectness2DTest4) { + const char* const kModuleStr = R"( + HloModule scatter_determinism_expander + + scatter_computation { + arg1.173 = f32[] parameter(1) + arg0.172 = f32[] parameter(0) + ROOT add.48 = f32[] add(arg0.172, arg1.173) + } + + ENTRY scatter_add_computation { + operand = f32[3, 3] constant({{0, 0, 0}, {0, 0, 0}, {0, 0, 0}}) + indices = s32[4, 2] constant({{0, 0}, {0, 1}, {1, 1}, {1, 2}}) + updates = f32[4, 2] constant({{1, 2}, {4, 7}, {10, 13}, {21, 27}}) + ROOT scatter.48 = f32[3, 3] scatter(operand, indices, updates), + update_window_dims={1}, inserted_window_dims={1}, + scatter_dims_to_operand_dims={0, 1}, index_vector_dim=1, + to_apply=scatter_computation + })"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kModuleStr)); + auto cloned_module = module->Clone(); + Literal expected_literal = ExecuteAndTransfer(std::move(cloned_module), {}); + auto expected_result = expected_literal.data(); + + ScatterDeterminismExpander scatter_determinism_expander; + TF_ASSERT_OK_AND_ASSIGN( + bool result, RunHloPass(&scatter_determinism_expander, module.get())); + + EXPECT_TRUE(result); + + Literal result_literal = ExecuteAndTransfer(std::move(module), {}); + + auto result_data = result_literal.data(); + std::vector actual_result(result_data.begin(), result_data.end()); + + EXPECT_EQ(actual_result, expected_result); +} + +TEST_F(ScatterDeterminismExpanderTest, + ScatterAddWithNonScalarIndexAndUpdateCorrectness3DTest1) { + const char* const kModuleStr = R"( + HloModule scatter_determinism_expander + + scatter_computation { + arg1.173 = f32[] parameter(1) + arg0.172 = f32[] parameter(0) + ROOT add.48 = f32[] add(arg0.172, arg1.173) + } + + ENTRY scatter_add_computation { + operand = f32[3, 3, 3] constant({{{0, 0, 0}, + {0, 0, 0}, + {0, 0, 0}}, + {{0, 0, 0}, + {0, 0, 0}, + {0, 0, 0}}, + {{0, 0, 0}, + {0, 0, 0}, + {0, 0, 0}}}) + indices = s32[4, 2] constant({{0, 0}, {0, 1}, {1, 1}, {1, 2}}) + updates = f32[4, 2] constant({{1, 2}, {4, 7}, {10, 13}, {21, 27}}) + ROOT scatter.48 = f32[3, 3, 3] scatter(operand, indices, updates), + update_window_dims={1}, inserted_window_dims={1, 2}, + scatter_dims_to_operand_dims={2, 0}, index_vector_dim=1, + to_apply=scatter_computation + })"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kModuleStr)); + + auto cloned_module = module->Clone(); + Literal expected_literal = ExecuteAndTransfer(std::move(cloned_module), {}); + auto expected_result = expected_literal.data(); + + ScatterDeterminismExpander scatter_determinism_expander; + TF_ASSERT_OK_AND_ASSIGN( + bool result, RunHloPass(&scatter_determinism_expander, module.get())); + + EXPECT_TRUE(result); + + Literal result_literal = ExecuteAndTransfer(std::move(module), {}); + + auto result_data = result_literal.data(); + std::vector actual_result(result_data.begin(), result_data.end()); + + EXPECT_EQ(actual_result, expected_result); +} + +TEST_F(ScatterDeterminismExpanderTest, + ScatterAddWithNonScalarIndexAndUpdateCorrectness3DTest2) { + const char* const kModuleStr = R"( + HloModule scatter_determinism_expander + + scatter_computation { + arg1.173 = f32[] parameter(1) + arg0.172 = f32[] parameter(0) + ROOT add.48 = f32[] add(arg0.172, arg1.173) + } + + ENTRY scatter_add_computation { + operand = f32[3, 3, 3] constant({{{0, 0, 0}, + {0, 0, 0}, + {0, 0, 0}}, + {{0, 0, 0}, + {0, 0, 0}, + {0, 0, 0}}, + {{0, 0, 0}, + {0, 0, 0}, + {0, 0, 0}}}) + indices = s32[4, 2] constant({{0, 0}, {0, 1}, {1, 1}, {1, 2}}) + updates = f32[4, 2] constant({{1, 2}, {4, 7}, {10, 13}, {21, 27}}) + ROOT scatter.48 = f32[3, 3, 3] scatter(operand, indices, updates), + update_window_dims={1}, inserted_window_dims={1, 2}, + scatter_dims_to_operand_dims={2, 1}, index_vector_dim=1, + to_apply=scatter_computation + })"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kModuleStr)); + + auto cloned_module = module->Clone(); + Literal expected_literal = ExecuteAndTransfer(std::move(cloned_module), {}); + auto expected_result = expected_literal.data(); + + ScatterDeterminismExpander scatter_determinism_expander; + TF_ASSERT_OK_AND_ASSIGN( + bool result, RunHloPass(&scatter_determinism_expander, module.get())); + + EXPECT_TRUE(result); + + Literal result_literal = ExecuteAndTransfer(std::move(module), {}); + + auto result_data = result_literal.data(); + std::vector actual_result(result_data.begin(), result_data.end()); + + EXPECT_EQ(actual_result, expected_result); +} + +TEST_F(ScatterDeterminismExpanderTest, + ScatterAddWithNonScalarIndexAndUpdateCorrectness3DTest3) { + const char* const kModuleStr = R"( + HloModule scatter_determinism_expander + + scatter_computation { + arg1.173 = f32[] parameter(1) + arg0.172 = f32[] parameter(0) + ROOT add.48 = f32[] add(arg0.172, arg1.173) + } + + ENTRY scatter_add_computation { + operand = f32[3, 3, 3] constant({{{0, 0, 0}, + {0, 0, 0}, + {0, 0, 0}}, + {{0, 0, 0}, + {0, 0, 0}, + {0, 0, 0}}, + {{0, 0, 0}, + {0, 0, 0}, + {0, 0, 0}}}) + indices = s32[2, 2] constant({{0, 0}, {1, 1}}) + updates = f32[2, 2, 2] constant({{{1, 2}, {4, 7}}, {{10, 13}, {21, 27}}}) + ROOT scatter.48 = f32[3, 3, 3] scatter(operand, indices, updates), + update_window_dims={1, 2}, inserted_window_dims={1}, + scatter_dims_to_operand_dims={2, 0}, index_vector_dim=1, + to_apply=scatter_computation + })"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kModuleStr)); + + auto cloned_module = module->Clone(); + Literal expected_literal = ExecuteAndTransfer(std::move(cloned_module), {}); + auto expected_result = expected_literal.data(); + + ScatterDeterminismExpander scatter_determinism_expander; + TF_ASSERT_OK_AND_ASSIGN( + bool result, RunHloPass(&scatter_determinism_expander, module.get())); + + EXPECT_TRUE(result); + + Literal result_literal = ExecuteAndTransfer(std::move(module), {}); + + auto result_data = result_literal.data(); + std::vector actual_result(result_data.begin(), result_data.end()); + + EXPECT_EQ(actual_result, expected_result); +} + +TEST_F(ScatterDeterminismExpanderTest, + ScatterAddWithNonScalarIndexAndUpdateCorrectness3DTest4) { + const char* const kModuleStr = R"( + HloModule scatter_determinism_expander + + scatter_computation { + arg1.173 = f32[] parameter(1) + arg0.172 = f32[] parameter(0) + ROOT add.48 = f32[] add(arg0.172, arg1.173) + } + + ENTRY scatter_add_computation { + operand = f32[3, 3, 3] constant({{{0, 0, 0}, + {0, 0, 0}, + {0, 0, 0}}, + {{0, 0, 0}, + {0, 0, 0}, + {0, 0, 0}}, + {{0, 0, 0}, + {0, 0, 0}, + {0, 0, 0}}}) + indices = s32[2, 2] constant({{0, 0}, {1, 1}}) + updates = f32[2, 2, 2] constant({{{1, 2}, {4, 7}}, {{10, 13}, {21, 27}}}) + ROOT scatter.48 = f32[3, 3, 3] scatter(operand, indices, updates), + update_window_dims={1, 2}, inserted_window_dims={2}, + scatter_dims_to_operand_dims={0, 1}, index_vector_dim=1, + to_apply=scatter_computation + })"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kModuleStr)); + + auto cloned_module = module->Clone(); + Literal expected_literal = ExecuteAndTransfer(std::move(cloned_module), {}); + auto expected_result = expected_literal.data(); + + ScatterDeterminismExpander scatter_determinism_expander; + TF_ASSERT_OK_AND_ASSIGN( + bool result, RunHloPass(&scatter_determinism_expander, module.get())); + + EXPECT_TRUE(result); + + Literal result_literal = ExecuteAndTransfer(std::move(module), {}); + + auto result_data = result_literal.data(); + std::vector actual_result(result_data.begin(), result_data.end()); + + EXPECT_EQ(actual_result, expected_result); +} + +TEST_F(ScatterDeterminismExpanderTest, ComplicatedMultiDimensionalScatterTest) { + const char* const kModuleStr = R"( + HloModule scatter_determinism_expander + + scatter_computation { + arg0 = f32[] parameter(0) + arg1 = f32[] parameter(1) + ROOT add.48 = f32[] add(arg0, arg1) + } + + ENTRY fused_computation { + p0 = f32[1,1,3072,3]{3,2,1,0} parameter(0) + p1 = s32[1,1,128,2,3]{4,3,2,1,0} parameter(1) + p2 = f32[1,1,128,2,3]{4,3,2,1,0} parameter(2) + ROOT scatter.50 = f32[1,1,3072,3]{3,2,1,0} scatter(p0, p1, p2), update_window_dims={4}, inserted_window_dims={0,1,2}, scatter_dims_to_operand_dims={0,1,2}, index_vector_dim=4, to_apply=scatter_computation + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kModuleStr)); + + ScatterDeterminismExpander scatter_determinism_expander; + TF_ASSERT_OK_AND_ASSIGN( + bool result, RunHloPass(&scatter_determinism_expander, module.get())); + EXPECT_TRUE(result); +} + TEST_F(ScatterDeterminismExpanderTest, ScatterAddHloVerificationTest) { const char* const kModuleStr = R"( HloModule scatter_determinism_expander @@ -199,8 +742,8 @@ TEST_F(ScatterDeterminismExpanderTest, ScatterAddHloVerificationTest) { ENTRY scatter_add_computation { operand = f32[2] constant({0, 0}) - indices = s32[3,1] constant({{0}, {1}, {1}}) - updates = f32[3] constant({2, 1, 5}) + indices = s32[2,1] constant({{1}, {1}}) + updates = f32[2] constant({2, 1}) ROOT scatter.48 = f32[2] scatter(operand, indices, updates), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, @@ -209,48 +752,41 @@ TEST_F(ScatterDeterminismExpanderTest, ScatterAddHloVerificationTest) { const char* const kExpectedPattern = R"( CHECK: ENTRY %scatter_add_computation () -> f32[2] { - CHECK-DAG: %[[INDICES:.*]] = s32[3,1]{1,0} constant({ {0}, {1}, {1} }) - CHECK-DAG: %[[RESHAPE:.*]] = s32[3]{0} reshape(%[[INDICES]]) + CHECK-DAG: %[[INDICES:.*]] = s32[2,1]{1,0} constant({ {1}, {1} }) + CHECK-DAG: %[[RESHAPE:.*]] = s32[2]{0} reshape(%[[INDICES]]) + CHECK-DAG: %[[IOTA:.*]] = s32[2]{0} iota(), iota_dimension=0 CHECK-DAG: %[[OPERAND:.*]] = f32[2]{0} constant({0, 0}) - CHECK-DAG: %[[RESHAPE1:.*]] = s32[3]{0} reshape(%[[INDICES]]) - CHECK-DAG: %[[UPDATES:.*]] = f32[3]{0} constant({2, 1, 5}) - CHECK-DAG: %[[TRANSPOSE:.*]] = f32[3]{0} transpose(%[[UPDATES]]), dimensions={0} - CHECK-DAG: %[[RESHAPE2:.*]] = f32[3]{0} reshape(%[[TRANSPOSE]]) - CHECK-DAG: %[[SORT:.*]] = (s32[3]{0}, f32[3]{0}) sort(%[[RESHAPE1]], %[[RESHAPE2]]), dimensions={0}, to_apply=%sorting_computation - CHECK-DAG: %[[GET_TUPLE_ELEMENT:.*]] = s32[3]{0} get-tuple-element(%[[SORT]]), index=0 - CHECK-DAG: %[[SLICE4:.*]] = s32[2]{0} slice(%[[GET_TUPLE_ELEMENT]]), slice={[0:2]} - CHECK-DAG: %[[SLICE5:.*]] = s32[2]{0} slice(%[[GET_TUPLE_ELEMENT]]), slice={[1:3]} - CHECK-DAG: %[[COMPARE3:.*]] = pred[2]{0} compare(%[[SLICE4]], %[[SLICE5]]), direction=NE - CHECK-DAG: %[[CONSTANT4:.*]] = pred[] constant(true) - CHECK-DAG: %[[BROADCAST4:.*]] = pred[1]{0} broadcast(%[[CONSTANT4]]), dimensions={} - CHECK-DAG: %[[CONCAT_COMPARE4:.*]] = pred[3]{0} concatenate(%[[COMPARE3]], %[[BROADCAST4]]), dimensions={0} - CHECK-DAG: %[[BROADCAST5:.*]] = pred[3]{0} broadcast(%[[CONCAT_COMPARE4]]), dimensions={0} - CHECK-DAG: %[[CONSTANT5:.*]] = s32[3]{0} constant({2, 2, 2}) - CHECK-DAG: %[[SELECT2:.*]] = s32[3]{0} select(%[[BROADCAST5]], %[[GET_TUPLE_ELEMENT]], %[[CONSTANT5]]) - CHECK-DAG: %[[CONSTANT3:.*]] = s32[] constant(0) - CHECK-DAG: %[[BROADCAST3:.*]] = s32[2]{0} broadcast(%[[CONSTANT3]]), dimensions={} - CHECK-DAG: %[[SLICE3:.*]] = s32[1]{0} slice(%[[GET_TUPLE_ELEMENT]]), slice={[0:1]} - CHECK-DAG: %[[CONCAT3:.*]] = s32[3]{0} concatenate(%[[BROADCAST3]], %[[SLICE3]]), dimensions={0} - CHECK-DAG: %[[COMPARE2:.*]] = pred[3]{0} compare(%[[GET_TUPLE_ELEMENT]], %[[CONCAT3]]), direction=EQ - CHECK-DAG: %[[CONSTANT1:.*]] = s32[] constant(0) - CHECK-DAG: %[[BROADCAST1:.*]] = s32[1]{0} broadcast(%[[CONSTANT1]]), dimensions={} - CHECK-DAG: %[[SLICE1:.*]] = s32[2]{0} slice(%[[GET_TUPLE_ELEMENT]]), slice={[0:2]} - CHECK-DAG: %[[CONCAT1:.*]] = s32[3]{0} concatenate(%[[BROADCAST1]], %[[SLICE1]]), dimensions={0} - CHECK-DAG: %[[COMPARE1:.*]] = pred[3]{0} compare(%[[GET_TUPLE_ELEMENT]], %[[CONCAT1]]), direction=EQ - CHECK-DAG: %[[GET_TUPLE_ELEMENT1:.*]] = f32[3]{0} get-tuple-element(%[[SORT]]), index=1 - CHECK-DAG: %[[CONSTANT_F32:.*]] = f32[] constant(0) - CHECK-DAG: %[[BROADCAST_F32:.*]] = f32[1]{0} broadcast(%[[CONSTANT_F32]]), dimensions={} - CHECK-DAG: %[[SLICE_F32:.*]] = f32[2]{0} slice(%[[GET_TUPLE_ELEMENT1]]), slice={[0:2]} - CHECK-DAG: %[[CONCAT_F32:.*]] = f32[3]{0} concatenate(%[[BROADCAST_F32]], %[[SLICE_F32]]), dimensions={0} - CHECK-DAG: %[[MAP:.*]] = f32[3]{0} map(%[[GET_TUPLE_ELEMENT1]], %[[CONCAT_F32]]), dimensions={0}, to_apply=%scatter_computation - CHECK-DAG: %[[SELECT:.*]] = f32[3]{0} select(%[[COMPARE1]], %[[MAP]], %[[GET_TUPLE_ELEMENT1]]) - CHECK-DAG: %[[CONSTANT2:.*]] = f32[] constant(0) - CHECK-DAG: %[[BROADCAST2:.*]] = f32[2]{0} broadcast(%[[CONSTANT2]]), dimensions={} - CHECK-DAG: %[[SLICE2:.*]] = f32[1]{0} slice(%[[SELECT]]), slice={[0:1]} - CHECK-DAG: %[[CONCAT2:.*]] = f32[3]{0} concatenate(%[[BROADCAST2]], %[[SLICE2]]), dimensions={0} - CHECK-DAG: %[[MAP1:.*]] = f32[3]{0} map(%[[SELECT]], %[[CONCAT2]]), dimensions={0}, to_apply=%scatter_computation - CHECK-DAG: %[[SELECT1:.*]] = f32[3]{0} select(%[[COMPARE2]], %[[MAP1]], %[[SELECT]]) - CHECK-DAG: ROOT %[[SCATTER:.*]] = f32[2]{0} scatter(%[[OPERAND]], %[[SELECT2]], %[[SELECT1]]), + CHECK-DAG: %[[RESHAPE1:.*]] = s32[2]{0} reshape(%[[INDICES]]) + CHECK-DAG: %[[RESHAPE2:.*]] = s32[2,1]{1,0} reshape(%[[RESHAPE1]]) + CHECK-DAG: %[[RESHAPE4:.*]] = s32[2]{0} reshape(%[[RESHAPE2]]) + CHECK-DAG: %[[UPDATES:.*]] = f32[2]{0} constant({2, 1}) + CHECK-DAG: %[[TRANSPOSE:.*]] = f32[2]{0} transpose(%[[UPDATES]]), dimensions={0} + CHECK-DAG: %[[RESHAPE3:.*]] = f32[2]{0} reshape(%[[TRANSPOSE]]) + CHECK-DAG: %[[SORT:.*]] = (s32[2]{0}, f32[2]{0}) sort(%[[RESHAPE4]], %[[RESHAPE3]]), dimensions={0}, to_apply=%sorting_computation + CHECK-DAG: %[[GET_TUPLE_ELEMENT:.*]] = s32[2]{0} get-tuple-element(%[[SORT]]), index=0 + CHECK-DAG: %[[SLICE2:.*]] = s32[1]{0} slice(%[[GET_TUPLE_ELEMENT]]), slice={[0:1]} + CHECK-DAG: %[[SLICE3:.*]] = s32[1]{0} slice(%[[GET_TUPLE_ELEMENT]]), slice={[1:2]} + CHECK-DAG: %[[COMPARE2:.*]] = pred[1]{0} compare(%[[SLICE2]], %[[SLICE3]]), direction=NE + CHECK-DAG: %[[CONSTANT3:.*]] = pred[] constant(true) + CHECK-DAG: %[[BROADCAST2:.*]] = pred[1]{0} broadcast(%[[CONSTANT3]]), dimensions={} + CHECK-DAG: %[[CONCATENATE2:.*]] = pred[2]{0} concatenate(%[[COMPARE2]], %[[BROADCAST2]]), dimensions={0} + CHECK-DAG: %[[BROADCAST3:.*]] = pred[2,1]{1,0} broadcast(%[[CONCATENATE2]]), dimensions={0} + CHECK-DAG: %[[RESHAPE5:.*]] = s32[2,1]{1,0} reshape(%[[GET_TUPLE_ELEMENT]]) + CHECK-DAG: %[[CONSTANT:.*]] = s32[2,1]{1,0} constant({ {2}, {2} }) + CHECK-DAG: %[[SELECT1:.*]] = s32[2,1]{1,0} select(%[[BROADCAST3]], %[[RESHAPE5]], %[[CONSTANT]]) + CHECK-DAG: %[[CONSTANT2:.*]] = s32[] constant(0) + CHECK-DAG: %[[BROADCAST1:.*]] = s32[1]{0} broadcast(%[[CONSTANT2]]), dimensions={} + CHECK-DAG: %[[SLICE1:.*]] = s32[1]{0} slice(%[[GET_TUPLE_ELEMENT]]), slice={[0:1]} + CHECK-DAG: %[[CONCATENATE1:.*]] = s32[2]{0} concatenate(%[[BROADCAST1]], %[[SLICE1]]), dimensions={0} + CHECK-DAG: %[[COMPARE1:.*]] = pred[2]{0} compare(%[[GET_TUPLE_ELEMENT]], %[[CONCATENATE1]]), direction=EQ + CHECK-DAG: %[[GET_TUPLE_ELEMENT1:.*]] = f32[2]{0} get-tuple-element(%[[SORT]]), index=1 + CHECK-DAG: %[[CONSTANT1:.*]] = f32[] constant(0) + CHECK-DAG: %[[BROADCAST:.*]] = f32[1]{0} broadcast(%[[CONSTANT1]]), dimensions={} + CHECK-DAG: %[[SLICE:.*]] = f32[1]{0} slice(%[[GET_TUPLE_ELEMENT1]]), slice={[0:1]} + CHECK-DAG: %[[CONCATENATE:.*]] = f32[2]{0} concatenate(%[[BROADCAST]], %[[SLICE]]), dimensions={0} + CHECK-DAG: %[[MAP:.*]] = f32[2]{0} map(%[[GET_TUPLE_ELEMENT1]], %[[CONCATENATE]]), dimensions={0}, to_apply=%scatter_computation + CHECK-DAG: %[[SELECT:.*]] = f32[2]{0} select(%[[COMPARE1]], %[[MAP]], %[[GET_TUPLE_ELEMENT1]]) + CHECK-DAG: ROOT %[[SCATTER:.*]] = f32[2]{0} scatter(%[[OPERAND]], %[[SELECT1]], %[[SELECT]]), CHECK-SAME: update_window_dims={}, CHECK-SAME: inserted_window_dims={0}, CHECK-SAME: scatter_dims_to_operand_dims={0}, @@ -264,7 +800,7 @@ TEST_F(ScatterDeterminismExpanderTest, ScatterAddHloVerificationTest) { kExpectedPattern); } -TEST_F(ScatterDeterminismExpanderTest, ScatterAddOutOfBoundCorrectnessTest) { +TEST_F(ScatterDeterminismExpanderTest, ScalarScatterAddReproducibilityTest) { const char* const kModuleStr = R"( HloModule scatter_determinism_expander @@ -275,10 +811,30 @@ TEST_F(ScatterDeterminismExpanderTest, ScatterAddOutOfBoundCorrectnessTest) { } ENTRY scatter_add_computation { - operand = f32[4] constant({0, 0, 0, 0}) - indices = s32[7,1] constant({{0}, {1}, {5}, {4}, {1}, {1}, {2}}) - updates = f32[7] constant({2, 1, 5, 3, 8, 7, 9}) - ROOT scatter.48 = f32[4] scatter(operand, indices, updates), + operand = f32[3] constant({0, 0, 0}) + indices = s32[100,1] constant({{0}, {3}, {0}, {1}, {0}, {3}, {1}, {2}, {1}, {2}, {2}, {2}, {0}, {2}, {1}, + {0}, {1}, {1}, {2}, {0}, {2}, {1}, {2}, {1}, {2}, {2}, {3}, {2}, {2}, {0}, + {3}, {0}, {3}, {2}, {0}, {3}, {3}, {3}, {3}, {3}, {2}, {3}, {3}, {0}, {0}, + {3}, {3}, {3}, {2}, {3}, {2}, {3}, {0}, {0}, {2}, {0}, {1}, {3}, {1}, {3}, + {2}, {2}, {2}, {1}, {0}, {3}, {1}, {1}, {1}, {1}, {1}, {2}, {2}, {3}, {0}, + {2}, {2}, {0}, {2}, {1}, {0}, {2}, {2}, {2}, {0}, {2}, {0}, {1}, {3}, {0}, + {2}, {3}, {3}, {2}, {0}, {3}, {3}, {2}, {3}, {2}}) + updates = f32[100] constant({0.02379167, 0.8527204, 0.8132185, 0.5140263, 0.17172801, 0.8026866, 0.5124631, + 0.34838438, 0.50526905, 0.3370521, 0.10868239, 0.10520637, 0.83827364, 0.78986526, + 0.34059846, 0.8349273, 0.24575627, 0.21387374, 0.02423227, 0.5617423, 0.28066766, + 0.94366455, 0.61214995, 0.7383388, 0.52419806, 0.65466726, 0.41012764, 0.24028647, + 0.74443066, 0.03544927, 0.851014, 0.02434528, 0.47239733, 0.72706807, 0.35055435, + 0.6274171, 0.61077535, 0.06525731, 0.8091929, 0.21307838, 0.6465323, 0.3245015, + 0.5538883, 0.8849807, 0.9591211, 0.83856845, 0.48919427, 0.11810577, 0.16933143, + 0.83657074, 0.587505, 0.6867087, 0.95522237, 0.5797727, 0.28024232, 0.34749162, + 0.5199702, 0.9811766, 0.5645981, 0.2446456, 0.68722725, 0.9616587, 0.480047, + 0.88953114, 0.7083205, 0.948612, 0.67764974, 0.44131804, 0.36789334, 0.95148766, + 0.30909216, 0.70908046, 0.8749926, 0.60973287, 0.60751855, 0.22647333, 0.5363518, + 0.96195626, 0.08158326, 0.5266887, 0.85922587, 0.648262, 0.4657668, 0.31623375, + 0.43507564, 0.48351157, 0.41285944, 0.73501325, 0.15267539, 0.67055714, 0.08459568, + 0.04527426, 0.21078384, 0.4654404, 0.7363906, 0.23245859, 0.22119188, 0.99092937, + 0.878675, 0.4102913}) + ROOT scatter.48 = f32[3] scatter(operand, indices, updates), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, to_apply=scatter_computation @@ -293,17 +849,30 @@ TEST_F(ScatterDeterminismExpanderTest, ScatterAddOutOfBoundCorrectnessTest) { EXPECT_TRUE(result); - std::vector expected_result = {2.0, 16.0, 9.0, 0.0}; + auto cloned_module = module->Clone(); + Literal first_result_literal = + ExecuteAndTransfer(std::move(cloned_module), {}); + auto first_result_span = first_result_literal.data(); + std::vector first_result(first_result_span.begin(), + first_result_span.end()); - Literal result_literal = ExecuteAndTransfer(std::move(module), {}); + const int num_trials = 20; + std::vector> results; - auto result_data = result_literal.data(); - std::vector actual_result(result_data.begin(), result_data.end()); + for (int i = 0; i < num_trials; ++i) { + auto cloned_module = module->Clone(); - EXPECT_EQ(actual_result, expected_result); + Literal result_literal = ExecuteAndTransfer(std::move(cloned_module), {}); + + auto result_data = result_literal.data(); + std::vector actual_result(result_data.begin(), result_data.end()); + + EXPECT_EQ(actual_result, first_result) + << "Results are not reproducible across trials!"; + } } -TEST_F(ScatterDeterminismExpanderTest, ScatterAddReproducibilityTest) { +TEST_F(ScatterDeterminismExpanderTest, NonScalarScatterAddReproducibilityTest) { const char* const kModuleStr = R"( HloModule scatter_determinism_expander @@ -314,12 +883,32 @@ TEST_F(ScatterDeterminismExpanderTest, ScatterAddReproducibilityTest) { } ENTRY scatter_add_computation { - operand = f32[3] constant({0, 0, 0}) - indices = s32[100,1] constant({{0}, {3}, {0}, {1}, {0}, {3}, {1}, {2}, {1}, {2}, {2}, {2}, {0}, {2}, {1}, {0}, {1}, {1}, {2}, {0}, {2}, {1}, {2}, {1}, {2}, {2}, {3}, {2}, {2}, {0}, {3}, {0}, {3}, {2}, {0}, {3}, {3}, {3}, {3}, {3}, {2}, {3}, {3}, {0}, {0}, {3}, {3}, {3}, {2}, {3}, {2}, {3}, {0}, {0}, {2}, {0}, {1}, {3}, {1}, {3}, {2}, {2}, {2}, {1}, {0}, {3}, {1}, {1}, {1}, {1}, {1}, {2}, {2}, {3}, {0}, {2}, {2}, {0}, {2}, {1}, {0}, {2}, {2}, {2}, {0}, {2}, {0}, {1}, {3}, {0}, {2}, {3}, {3}, {2}, {0}, {3}, {3}, {2}, {3}, {2}}) - updates = f32[100] constant({0.02379167, 0.8527204, 0.8132185, 0.5140263, 0.17172801, 0.8026866, 0.5124631, 0.34838438, 0.50526905, 0.3370521, 0.10868239, 0.10520637, 0.83827364, 0.78986526, 0.34059846, 0.8349273, 0.24575627, 0.21387374, 0.02423227, 0.5617423, 0.28066766, 0.94366455, 0.61214995, 0.7383388, 0.52419806, 0.65466726, 0.41012764, 0.24028647, 0.74443066, 0.03544927, 0.851014, 0.02434528, 0.47239733, 0.72706807, 0.35055435, 0.6274171, 0.61077535, 0.06525731, 0.8091929, 0.21307838, 0.6465323, 0.3245015, 0.5538883, 0.8849807, 0.9591211, 0.83856845, 0.48919427, 0.11810577, 0.16933143, 0.83657074, 0.587505, 0.6867087, 0.95522237, 0.5797727, 0.28024232, 0.34749162, 0.5199702, 0.9811766, 0.5645981, 0.2446456, 0.68722725, 0.9616587, 0.480047, 0.88953114, 0.7083205, 0.948612, 0.67764974, 0.44131804, 0.36789334, 0.95148766, 0.30909216, 0.70908046, 0.8749926, 0.60973287, 0.60751855, 0.22647333, 0.5363518, 0.96195626, 0.08158326, 0.5266887, 0.85922587, 0.648262, 0.4657668, 0.31623375, 0.43507564, 0.48351157, 0.41285944, 0.73501325, 0.15267539, 0.67055714, 0.08459568, 0.04527426, 0.21078384, 0.4654404, 0.7363906, 0.23245859, 0.22119188, 0.99092937, 0.878675, 0.4102913}) - ROOT scatter.48 = f32[3] scatter(operand, indices, updates), - update_window_dims={}, inserted_window_dims={0}, - scatter_dims_to_operand_dims={0}, index_vector_dim=1, + operand = f32[3, 3] constant({{0, 0, 0}, {0, 0, 0}, {0, 0, 0}}) + indices = s32[50, 2] constant({{0, 0}, {0, 1}, {1, 1}, {2, 2}, {0, 1}, {1, 0}, {2, 1}, {1, 2}, {0, 2}, {2, 0}, + {1, 1}, {2, 2}, {0, 0}, {0, 1}, {2, 1}, {1, 2}, {2, 0}, {0, 2}, {1, 0}, {1, 1}, + {1, 2}, {2, 1}, {0, 0}, {1, 1}, {0, 2}, {2, 0}, {1, 0}, {2, 2}, {1, 2}, {0, 1}, + {2, 1}, {1, 0}, {0, 2}, {2, 0}, {0, 1}, {2, 1}, {1, 1}, {1, 0}, {2, 2}, {0, 0}, + {0, 1}, {1, 2}, {2, 0}, {1, 1}, {0, 2}, {2, 1}, {1, 2}, {2, 1}, {1, 1}, {0, 2}}) + updates = f32[50, 2] constant({{0.02379167, 0.8527204}, {0.8132185, 0.5140263}, {0.17172801, 0.8026866}, + {0.5124631, 0.34838438}, {0.50526905, 0.3370521}, {0.10868239, 0.10520637}, + {0.83827364, 0.78986526}, {0.34059846, 0.8349273}, {0.24575627, 0.21387374}, + {0.02423227, 0.5617423}, {0.28066766, 0.94366455}, {0.61214995, 0.7383388}, + {0.52419806, 0.65466726}, {0.41012764, 0.24028647}, {0.74443066, 0.03544927}, + {0.851014, 0.02434528}, {0.47239733, 0.72706807}, {0.35055435, 0.6274171}, + {0.61077535, 0.06525731}, {0.8091929, 0.21307838}, {0.6465323, 0.3245015}, + {0.5538883, 0.8849807}, {0.9591211, 0.83856845}, {0.48919427, 0.11810577}, + {0.16933143, 0.83657074}, {0.587505, 0.6867087}, {0.95522237, 0.5797727}, + {0.28024232, 0.34749162}, {0.5199702, 0.9811766}, {0.5645981, 0.2446456}, + {0.68722725, 0.9616587}, {0.480047, 0.88953114}, {0.7083205, 0.948612}, + {0.67764974, 0.44131804}, {0.36789334, 0.95148766}, {0.30909216, 0.70908046}, + {0.8749926, 0.60973287}, {0.60751855, 0.22647333}, {0.5363518, 0.96195626}, + {0.08158326, 0.5266887}, {0.85922587, 0.648262}, {0.4657668, 0.31623375}, + {0.43507564, 0.48351157}, {0.41285944, 0.73501325}, {0.15267539, 0.67055714}, + {0.08459568, 0.04527426}, {0.21078384, 0.4654404}, {0.7363906, 0.23245859}, + {0.22119188, 0.99092937}, {0.878675, 0.4102913}}) + ROOT scatter.48 = f32[3, 3] scatter(operand, indices, updates), + update_window_dims={1}, inserted_window_dims={0}, + scatter_dims_to_operand_dims={0, 1}, index_vector_dim=1, to_apply=scatter_computation })"; diff --git a/xla/xla.proto b/xla/xla.proto index 1778db2085783..b7683f9f99e84 100644 --- a/xla/xla.proto +++ b/xla/xla.proto @@ -1044,7 +1044,14 @@ message DebugOptions { bool xla_pjrt_allow_auto_layout_in_hlo = 344; - // Next id: 345 + // Enable the scatter determinism expander, an optimized pass that + // rewrites scatter operations to ensure deterministic behavior with high + // performance. + // Note that even when this flag is disabled, scatter operations may still + // be deterministic, although with additional overhead. + bool xla_gpu_enable_scatter_determinism_expander = 345; + + // Next id: 346 // Extra options to pass to the compilation backend (e.g. LLVM); specific // interpretation of these values is left to the backend.