diff --git a/xla/service/BUILD b/xla/service/BUILD index d47e3734ef365..688a0d0e174ec 100644 --- a/xla/service/BUILD +++ b/xla/service/BUILD @@ -2558,6 +2558,25 @@ cc_library( ], ) +cc_library( + name = "scatter_utils", + srcs = ["scatter_utils.cc"], + hdrs = ["scatter_utils.h"], + deps = [ + ":call_inliner", + ":hlo_creation_utils", + "//xla:shape_util", + "//xla/hlo/ir:hlo", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:statusor", + ], +) + cc_library( name = "scatter_expander", srcs = ["scatter_expander.cc"], @@ -2566,6 +2585,7 @@ cc_library( ":call_inliner", ":hlo_creation_utils", ":op_expander_pass", + ":scatter_utils", ":while_util", "//xla:literal_util", "//xla/hlo/ir:hlo", @@ -2574,6 +2594,29 @@ cc_library( ], ) +cc_library( + name = "scatter_determinism_expander", + srcs = ["scatter_determinism_expander.cc"], + hdrs = ["scatter_determinism_expander.h"], + deps = [ + ":hlo_creation_utils", + ":op_expander_pass", + ":scatter_utils", + "//xla:array", + "//xla:array2d", + "//xla:comparison_util", + "//xla:literal_util", + "//xla:shape_util", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/strings:str_format", + "@tsl//tsl/platform:logging", + "@tsl//tsl/platform:statusor", + ], +) + xla_cc_test( name = "scatter_expander_test", srcs = ["scatter_expander_test.cc"], @@ -2588,6 +2631,24 @@ xla_cc_test( "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", "//xla/tsl/lib/core:status_test_util", + "@tsl//tsl/platform:statusor", + ], +) + +xla_test( + name = "scatter_determinism_expander_test", + srcs = ["scatter_determinism_expander_test.cc"], + backends = [ + "cpu", + "gpu", + ], + deps = [ + ":scatter_determinism_expander", + "//xla:literal", + "//xla:test", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", + "@tsl//tsl/platform:statusor", ], ) diff --git a/xla/service/gpu/BUILD b/xla/service/gpu/BUILD index 5e86294dd8626..dd8d46345f49d 100644 --- a/xla/service/gpu/BUILD +++ b/xla/service/gpu/BUILD @@ -1606,6 +1606,7 @@ cc_library( "//xla/service:result_caster", "//xla/service:rng_bit_generator_expander", "//xla/service:rng_expander", + "//xla/service:scatter_determinism_expander", "//xla/service:scatter_expander", "//xla/service:scatter_simplifier", "//xla/service:sharding_remover", diff --git a/xla/service/gpu/gpu_compiler.cc b/xla/service/gpu/gpu_compiler.cc index f25ebff77bd46..9b02c7789a125 100755 --- a/xla/service/gpu/gpu_compiler.cc +++ b/xla/service/gpu/gpu_compiler.cc @@ -219,6 +219,7 @@ limitations under the License. #include "xla/service/result_caster.h" #include "xla/service/rng_bit_generator_expander.h" #include "xla/service/rng_expander.h" +#include "xla/service/scatter_determinism_expander.h" #include "xla/service/scatter_expander.h" #include "xla/service/scatter_simplifier.h" #include "xla/service/sharding_remover.h" @@ -700,6 +701,7 @@ absl::Status RunOptimizationPasses( if (RequireDeterminism(hlo_module->config())) { // Scatter can be indeterministic if indices are not unique or a non // associative combiner function is used. Eliminate these Scatter ops. 
+    pipeline.AddPass<ScatterDeterminismExpander>();
     pipeline.AddPass<ScatterExpander>(
         ScatterExpander::kEliminateIndeterministicScatters);
   }
diff --git a/xla/service/scatter_determinism_expander.cc b/xla/service/scatter_determinism_expander.cc
new file mode 100644
index 0000000000000..b938121a107af
--- /dev/null
+++ b/xla/service/scatter_determinism_expander.cc
@@ -0,0 +1,462 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/scatter_determinism_expander.h"
+
+#include <cstdint>
+#include <vector>
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/strings/str_format.h"
+#include "xla/array.h"
+#include "xla/array2d.h"
+#include "xla/comparison_util.h"
+#include "xla/hlo/ir/hlo_casting_utils.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/literal_util.h"
+#include "xla/service/hlo_creation_utils.h"
+#include "xla/service/scatter_utils.h"
+#include "xla/shape_util.h"
+#include "xla/util.h"
+#include "xla/xla_data.pb.h"
+#include "tsl/platform/logging.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+
+// Canonicalizes the scatter_updates in order to keep them uniform while
+// performing the scatter operation.
+static absl::StatusOr<std::vector<HloInstruction*>> CanonicalizeScatterUpdates(
+    const std::vector<HloInstruction*>& scatter_updates,
+    HloInstruction* scatter_indices, const ScatterDimensionNumbers& dim_numbers,
+    int64_t scatter_loop_trip_count) {
+  std::vector<HloInstruction*> adjusted_updates;
+  adjusted_updates.reserve(scatter_updates.size());
+  for (HloInstruction* update : scatter_updates) {
+    TF_ASSIGN_OR_RETURN(
+        HloInstruction * canonical_update,
+        PermuteScatterAndWindowDims(update, dim_numbers.update_window_dims()));
+    TF_ASSIGN_OR_RETURN(
+        HloInstruction * adjusted_update,
+        AdjustScatterDims(scatter_indices->shape(), canonical_update,
+                          dim_numbers.index_vector_dim()));
+    CHECK_EQ(scatter_loop_trip_count, adjusted_update->shape().dimensions(0));
+    adjusted_updates.push_back(adjusted_update);
+  }
+  return adjusted_updates;
+}
+
+// Create the out-of-bound tensor for the scatter operation.
+HloInstruction* CreateOutOfBoundTensor(HloComputation* parent, + HloInstruction* scatter_indices, + const Shape& scatter_shape) { + if (scatter_indices->shape().rank() == 1) { + CHECK_EQ(scatter_shape.dimensions_size(), 1); + Array out_of_bound_array({scatter_indices->shape().dimensions(0)}, + scatter_shape.dimensions(0)); + return parent->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateFromArray(out_of_bound_array))); + } + // More than one dimension in scatter_indices + Array2D out_of_bound_array(scatter_indices->shape().dimensions(0), + scatter_indices->shape().dimensions(1)); + for (int i = 0; i < scatter_indices->shape().dimensions(0); ++i) { + for (int j = 0; j < scatter_indices->shape().dimensions(1); ++j) { + out_of_bound_array(i, j) = scatter_shape.dimensions(j); + } + } + return parent->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR2FromArray2D(out_of_bound_array))); +} + +// Computation for sorting the scalar scatter indices and updates together +HloComputation* ScalarSortingComparison(HloModule* module, + const Shape key_shape, + const Shape update_shape, + int64_t num_updates) { + HloComputation::Builder builder("sorting_computation"); + auto param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, key_shape, "lhs_key")); + auto param1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, key_shape, "rhs_key")); + const int kExistingParams = 2; + for (int i = 0; i < num_updates; ++i) { + builder.AddInstruction( + HloInstruction::CreateParameter(kExistingParams + i, update_shape, + absl::StrFormat("lhs_update_%d", i))); + builder.AddInstruction( + HloInstruction::CreateParameter(kExistingParams + 1 + i, update_shape, + absl::StrFormat("rhs_update_%d", i))); + } + builder.AddInstruction( + HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), param0, + param1, ComparisonDirection::kLt)); + return module->AddEmbeddedComputation(builder.Build()); +} + +static std::vector SortIndicesAndUpdates( + HloInstruction* scatter_indices, + const std::vector& scatter_updates, int64_t num_indices, + HloScatterInstruction* scatter, HloComputation* parent) { + const Shape& indices_shape = scatter_indices->shape(); + const Shape& updates_shape = scatter_updates[0]->shape(); + auto updates_dims = updates_shape.dimensions(); + // Since we canonicalized the scatter updates, the first dim will always be + // the number of updates and the rest will be the shape of each update + + HloInstruction* scalar_indices = scatter_indices; + + std::vector single_update_dimensions(updates_dims.begin() + 1, + updates_dims.end()); + + const Shape update_shape = ShapeUtil::MakeShape(updates_shape.element_type(), + single_update_dimensions); + + const Shape& scalar_index_shape = + ShapeUtil::MakeShape(indices_shape.element_type(), {num_indices}); + + auto* comparison = ScalarSortingComparison( + scatter->GetModule(), + ShapeUtil::MakeShape(indices_shape.element_type(), {}), + ShapeUtil::MakeShape(updates_shape.element_type(), {}), + scatter_updates.size()); + + std::vector sort_operands = {scalar_indices}; + std::vector sort_shapes = {scalar_index_shape}; + for (auto update : scatter_updates) { + sort_operands.push_back(update); + sort_shapes.push_back(update->shape()); + } + + auto* sorting = parent->AddInstruction(HloInstruction::CreateSort( + ShapeUtil::MakeTupleShape(sort_shapes), 0, sort_operands, comparison, + /*is_stable=*/false)); + auto* sorted_scalar_indices = + parent->AddInstruction(HloInstruction::CreateGetTupleElement( + 
scalar_indices->shape(), sorting, 0)); + + std::vector sorted_updates(scatter_updates.size()); + for (int i = 0; i < scatter_updates.size(); i++) { + sorted_updates[i] = + parent->AddInstruction(HloInstruction::CreateGetTupleElement( + scatter_updates[i]->shape(), sorting, i + 1)); + } + std::vector sorted_tensors = {sorted_scalar_indices}; + sorted_tensors.insert(sorted_tensors.end(), sorted_updates.begin(), + sorted_updates.end()); + return sorted_tensors; +} + +// CreateScanWithIndices performs a prefix scan operation (akin to parallel +// prefix sum) on the updates and indices, to compute the accumulated updates in +// log(n) time. +// +// High-level algorithm: +// +// Iteration through log2(num_updates): +// - For each iteration, the `updates` tensor will be sliced and padded to +// perform shifting by `offset`. +// - Similarly, the `indices` tensor is also sliced and padded. +// - A mask is created that compares each element of shifted `indices` and +// original `indices` are equal (used to avoid combining updates from +// different indices). +// - The `to_apply` function is used to combine the original and shifted +// updates to generate a combined update tensor. +// - Based on the mask, the new update tensor will choose from either the +// combined update or the original update. +// - The result becomes the `new_updates`, which is then used as the +// input for the next iteration. +static absl::StatusOr CreateScanWithIndices( + HloComputation* parent, HloInstruction* updates, HloInstruction* indices, + HloComputation* to_apply) { + const Shape& updates_shape = updates->shape(); + const Shape& indices_shape = indices->shape(); + // Get the length of the input array + int64_t num_updates = updates_shape.dimensions(0); + + // Calculate the number of iterations needed (log_2(n)) + int64_t log_n = Log2Ceiling(static_cast(num_updates)); + + // Start to traverse + HloInstruction* prev_updates = updates; + HloInstruction* prev_indices = indices; + HloInstruction* new_updates = nullptr; + + std::vector start_indices = {0}; + std::vector strides = {1}; + + for (int64_t iteration = 0; iteration < log_n; ++iteration) { + int64_t offset = static_cast(1) << iteration; + std::vector end_indices = {num_updates - offset}; + + auto shifted_updates_shape = ShapeUtil::MakeShape( + updates_shape.element_type(), {num_updates - offset}); + auto padding_updates_shape = + ShapeUtil::MakeShape(updates_shape.element_type(), {offset}); + + auto shifted_indices_shape = ShapeUtil::MakeShape( + indices_shape.element_type(), {num_updates - offset}); + auto padding_indices_shape = + ShapeUtil::MakeShape(indices_shape.element_type(), {offset}); + + auto* shifted_updates = parent->AddInstruction( + HloInstruction::CreateSlice(shifted_updates_shape, prev_updates, + start_indices, end_indices, strides)); + auto* padding_updates = + parent->AddInstruction(HloInstruction::CreateBroadcast( + padding_updates_shape, + parent->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(updates_shape.element_type(), 0))), + {})); + + auto* shifted_indices = parent->AddInstruction( + HloInstruction::CreateSlice(shifted_indices_shape, prev_indices, + start_indices, end_indices, strides)); + auto* padding_indices = + parent->AddInstruction(HloInstruction::CreateBroadcast( + padding_indices_shape, + parent->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(indices_shape.element_type(), 0))), + {})); + + auto* concatenated_updates = + parent->AddInstruction(HloInstruction::CreateConcatenate( + 
updates_shape, {padding_updates, shifted_updates}, 0)); + auto* concatenated_indices = + parent->AddInstruction(HloInstruction::CreateConcatenate( + indices_shape, {padding_indices, shifted_indices}, 0)); + + auto* indices_mask = parent->AddInstruction(HloInstruction::CreateCompare( + ShapeUtil::MakeShape(PRED, {num_updates}), prev_indices, + concatenated_indices, ComparisonDirection::kEq)); + std::vector map_operands = {prev_updates, + concatenated_updates}; + TF_ASSIGN_OR_RETURN(HloInstruction * reduced_updates, + MakeMapHlo(map_operands, to_apply)); + new_updates = parent->AddInstruction(HloInstruction::CreateTernary( + updates_shape, HloOpcode::kSelect, indices_mask, reduced_updates, + prev_updates)); + prev_updates = new_updates; + } + return new_updates; +} + +absl::StatusOr> ComputePrefixScan( + const std::vector& sorted_updates, + HloInstruction* sorted_scalar_indices, HloScatterInstruction* scatter, + HloComputation* parent) { + std::vector prefix_scans(sorted_updates.size()); + for (int i = 0; i < sorted_updates.size(); i++) { + // TODO(chenhao) change to use the extracted computation + TF_ASSIGN_OR_RETURN( + HloComputation * to_apply, + CallComputationAndGetIthOutputWithBinaryParams(scatter->to_apply(), i)); + TF_ASSIGN_OR_RETURN(prefix_scans[i], + CreateScanWithIndices(parent, sorted_updates[i], + sorted_scalar_indices, to_apply)); + } + return prefix_scans; +} + +static HloInstruction* FindLastOccurrenceIndices( + HloInstruction* scatter_indices, HloInstruction* sorted_scalar_indices, + HloInstruction* scatter, HloComputation* parent, int64_t num_indices) { + int64_t indices_len = sorted_scalar_indices->shape().dimensions(0); + HloInstruction* sorted_indices = sorted_scalar_indices; + auto* sorted_indices_preceding_part = + parent->AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(scatter_indices->shape().element_type(), + {indices_len - 1}), + sorted_scalar_indices, {0}, {indices_len - 1}, {1})); + auto* sorted_indices_following_part = + parent->AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(scatter_indices->shape().element_type(), + {indices_len - 1}), + sorted_scalar_indices, {1}, {indices_len}, {1})); + auto* indices_mask_without_padding = + parent->AddInstruction(HloInstruction::CreateCompare( + ShapeUtil::MakeShape(PRED, {indices_len - 1}), + sorted_indices_preceding_part, sorted_indices_following_part, + ComparisonDirection::kNe)); + // Pad the comparison with a true value at the end + auto* true_constant = parent->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(true))); + auto* padding = parent->AddInstruction(HloInstruction::CreateBroadcast( + ShapeUtil::MakeShape(PRED, {1}), true_constant, {})); + std::vector padding_operands = {indices_mask_without_padding, + padding}; + auto* indices_mask = parent->AddInstruction(HloInstruction::CreateConcatenate( + ShapeUtil::MakeShape(PRED, {indices_len}), padding_operands, 0)); + + // Mask the indices + indices_mask = parent->AddInstruction(HloInstruction::CreateBroadcast( + ShapeUtil::MakeShape(PRED, scatter_indices->shape().dimensions()), + indices_mask, {0})); + + auto* out_of_bound_tensor = + CreateOutOfBoundTensor(parent, scatter_indices, scatter->shape()); + + auto* masked_indices = parent->AddInstruction(HloInstruction::CreateTernary( + sorted_indices->shape(), HloOpcode::kSelect, indices_mask, sorted_indices, + out_of_bound_tensor)); + return masked_indices; +} + +absl::StatusOr ScatterDeterminismExpander::ExpandInstruction( + HloInstruction* inst) { + auto* 
scatter = Cast(inst); + auto scatter_operands = scatter->scatter_operands(); + HloInstruction* scatter_indices = scatter->scatter_indices(); + std::vector scatter_updates( + scatter->scatter_updates().begin(), scatter->scatter_updates().end()); + const ScatterDimensionNumbers& dim_numbers = + scatter->scatter_dimension_numbers(); + + // If the updates tensors are empty, there is no need to update the operands. + // The operands can be forwarded. + if (ShapeUtil::IsZeroElementArray(scatter_updates[0]->shape())) { + if (scatter_operands.size() == 1) { + return scatter_operands[0]; + } + return scatter->parent()->AddInstruction( + HloInstruction::CreateTuple(scatter_operands)); + } + + // Compute the trip count for the while loop to be used for scatter. This + // should be the number of indices we should scatter into the operand. + int64_t scatter_indices_count = ScatterIndicesCount(scatter); + if (!IsInt32(scatter_indices_count)) { + // 2147483647 is the maximum value for a 32-bit signed integer (INT32_MAX). + return Unimplemented( + "Scatter operations with more than 2147483647 scatter indices are not " + "supported. This error occurred for %s.", + scatter->ToString()); + } + + // Canonicalize the scatter_indices, after which the size of its most-major + // dimension must be same as the while loop trip count. + TF_ASSIGN_OR_RETURN(scatter_indices, + CanonicalizeScatterIndices( + scatter_indices, dim_numbers.index_vector_dim())); + CHECK_EQ(scatter_indices_count, scatter_indices->shape().dimensions(0)); + + // Canonicalize the updates, after which the size of their most-major + // dimensions must be same as the while loop trip count. + TF_ASSIGN_OR_RETURN(scatter_updates, CanonicalizeScatterUpdates( + scatter_updates, scatter_indices, + dim_numbers, scatter_indices_count)); + + HloComputation* parent = scatter->parent(); + + // Sort the scatter indices and updates together based on the scatter indices. 
+ int64_t num_indices = ShapeUtil::ElementsIn(scatter_updates[0]->shape()); + std::vector sorted_tensors = SortIndicesAndUpdates( + scatter_indices, scatter_updates, num_indices, scatter, parent); + HloInstruction* sorted_scalar_indices = sorted_tensors[0]; + std::vector sorted_updates(sorted_tensors.begin() + 1, + sorted_tensors.end()); + + TF_ASSIGN_OR_RETURN(std::vector prefix_scan_updates, + ComputePrefixScan(sorted_updates, sorted_scalar_indices, + scatter, parent)); + + HloInstruction* last_occurrence_indices = FindLastOccurrenceIndices( + scatter_indices, sorted_scalar_indices, scatter, parent, num_indices); + + // Finally, recreate the scatter instruction with unique indices + return parent->AddInstruction(HloInstruction::CreateScatter( + scatter->shape(), scatter_operands, last_occurrence_indices, + prefix_scan_updates, scatter->to_apply(), dim_numbers, + /*indices_are_sorted=*/true, /*unique_indices=*/true)); +} + +namespace { +void RecursivelyGetInputParamNumbers( + const HloInstruction* instruction, std::vector& param_numbers, + absl::flat_hash_set& visited) { + if (!visited.emplace(instruction).second) { + return; + } + + if (instruction->opcode() == HloOpcode::kParameter) { + param_numbers.push_back(instruction->parameter_number()); + return; + } + for (HloInstruction* operand : instruction->operands()) { + RecursivelyGetInputParamNumbers(operand, param_numbers, visited); + } +} + +// Check if every output of the scatter computation only depends on the +// corresponding operand and updates +bool CheckOutputDependency(HloComputation* to_apply, int operand_size) { + HloInstruction* root = to_apply->root_instruction(); + if (!root->shape().IsTuple()) { + return true; + } + CHECK_EQ(operand_size, root->operand_count()); + + // traverse the tuple output of the computation + for (int i = 0; i < operand_size; ++i) { + const HloInstruction* output = root->operand(i); + std::vector param_numbers; + absl::flat_hash_set visited; + RecursivelyGetInputParamNumbers(output, param_numbers, visited); + // The input dependencies can be at most 2 + if (param_numbers.size() > 2) { + return false; + } + for (int64_t param_number : param_numbers) { + if (param_number != i && param_number != operand_size + i) { + return false; + } + } + } + return true; +} + +} // namespace + +bool ScatterDeterminismExpander::InstructionMatchesPattern( + HloInstruction* inst) { + auto* scatter = DynCast(inst); + // Need to check if updates and indices are scalar, as the current pass does + // not expand scatter with multi-dimensional updates or indices. This is + // temporary and will be removed in a future PR soon. + if (scatter == nullptr) { + return false; + } + + const Shape& indices_shape = scatter->scatter_indices()->shape(); + const Shape& updates_shape = scatter->scatter_updates()[0]->shape(); + + // Check if indices and updates are effectively 1D. 
+  bool indices_are_1d =
+      (indices_shape.rank() == 1 ||
+       (indices_shape.rank() == 2 && indices_shape.dimensions(1) == 1));
+  bool updates_are_1d =
+      (updates_shape.rank() == 1 ||
+       (updates_shape.rank() == 2 && updates_shape.dimensions(1) == 1));
+
+  return indices_are_1d && updates_are_1d && !IsScatterDeterministic(scatter) &&
+         CheckOutputDependency(scatter->to_apply(),
+                               scatter->scatter_operands().size());
+}
+
+}  // namespace xla
diff --git a/xla/service/scatter_determinism_expander.h b/xla/service/scatter_determinism_expander.h
new file mode 100644
index 0000000000000..a14ed4482e2ee
--- /dev/null
+++ b/xla/service/scatter_determinism_expander.h
@@ -0,0 +1,44 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_SCATTER_DETERMINISM_EXPANDER_H_
+#define XLA_SERVICE_SCATTER_DETERMINISM_EXPANDER_H_
+
+#include "xla/service/op_expander_pass.h"
+
+namespace xla {
+
+// This pass rewrites scatter operations into a prefix-scan based algorithm
+// that ensures that the scatter results are deterministic. Note that the
+// computation after the expansion still contains a scatter operation, but it
+// does not have duplicated indices and hence the results are guaranteed to be
+// deterministic.
+class ScatterDeterminismExpander : public OpExpanderPass {
+ public:
+  explicit ScatterDeterminismExpander() = default;
+
+  absl::string_view name() const override {
+    return "scatter_determinism_expander";
+  }
+
+ protected:
+  bool InstructionMatchesPattern(HloInstruction* inst) override;
+
+  absl::StatusOr<HloInstruction*> ExpandInstruction(
+      HloInstruction* inst) override;
+};
+
+}  // namespace xla
+
+#endif  // XLA_SERVICE_SCATTER_DETERMINISM_EXPANDER_H_
diff --git a/xla/service/scatter_determinism_expander_test.cc b/xla/service/scatter_determinism_expander_test.cc
new file mode 100644
index 0000000000000..23e7e87d4bcce
--- /dev/null
+++ b/xla/service/scatter_determinism_expander_test.cc
@@ -0,0 +1,330 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/ + +#include "xla/service/scatter_determinism_expander.h" + +#include +#include +#include + +#include "xla/literal.h" +#include "xla/test.h" +#include "xla/tests/hlo_test_base.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace { + +class ScatterDeterminismExpanderTest : public HloTestBase {}; + +TEST_F(ScatterDeterminismExpanderTest, + DoNotEliminateScatterWithAssociativeCombiner) { + const char* const kModuleStr = R"( + HloModule scatter_determinism_expander + + scatter_computation { + arg1.173 = s32[] parameter(1) + arg0.172 = s32[] parameter(0) + ROOT add.48 = s32[] add(arg0.172, arg1.173) + } + + ENTRY fused_computation { + bitcast.2335 = s32[1,4096] parameter(0) + pad.96 = s32[4096,2] parameter(1) + bitcast.2748 = s32[4096,1,1] parameter(2) + ROOT scatter.48 = s32[1,4096] scatter(bitcast.2335, pad.96, bitcast.2748), + update_window_dims={1,2}, inserted_window_dims={}, + scatter_dims_to_operand_dims={0,1}, index_vector_dim=1, + to_apply=scatter_computation + })"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kModuleStr)); + + ScatterDeterminismExpander scatter_determinism_expander; + TF_ASSERT_OK_AND_ASSIGN( + bool result, RunHloPass(&scatter_determinism_expander, module.get())); + EXPECT_FALSE(result); +} + +TEST_F(ScatterDeterminismExpanderTest, + EliminateScatterWithNonAssociativeCombiner) { + const char* const kModuleStr = R"( + HloModule scatter_determinism_expander + + scatter_computation { + arg1.173 = f32[] parameter(1) + arg0.172 = f32[] parameter(0) + ROOT add.48 = f32[] add(arg0.172, arg1.173) + } + + ENTRY fused_computation { + bitcast.2335 = f32[4096] parameter(0) + pad.96 = s32[4096,1] parameter(1) + bitcast.2748 = f32[4096] parameter(2) + ROOT scatter.48 = f32[4096] scatter(bitcast.2335, pad.96, bitcast.2748), + update_window_dims={}, inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, index_vector_dim=1, + to_apply=scatter_computation + })"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kModuleStr)); + + ScatterDeterminismExpander scatter_determinism_expander; + TF_ASSERT_OK_AND_ASSIGN( + bool result, RunHloPass(&scatter_determinism_expander, module.get())); + EXPECT_TRUE(result); +} + +TEST_F(ScatterDeterminismExpanderTest, + DoNotEliminateScatterWithAssociativeFp32Combiner) { + const char* const kModuleStr = R"( + HloModule scatter_determinism_expander + + scatter_computation { + arg1.173 = f32[] parameter(1) + arg0.172 = f32[] parameter(0) + ROOT max.48 = f32[] maximum(arg0.172, arg1.173) + } + + ENTRY fused_computation { + bitcast.2335 = f32[1,4096] parameter(0) + pad.96 = s32[4096,2] parameter(1) + bitcast.2748 = f32[4096,1,1] parameter(2) + ROOT scatter.48 = f32[1,4096] scatter(bitcast.2335, pad.96, bitcast.2748), + update_window_dims={1,2}, inserted_window_dims={}, + scatter_dims_to_operand_dims={0,1}, index_vector_dim=1, + to_apply=scatter_computation + })"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kModuleStr)); + + ScatterDeterminismExpander scatter_determinism_expander; + TF_ASSERT_OK_AND_ASSIGN( + bool result, RunHloPass(&scatter_determinism_expander, module.get())); + EXPECT_FALSE(result); +} + +TEST_F(ScatterDeterminismExpanderTest, ScatterAddCorrectnessTest) { + const char* const kModuleStr = R"( + HloModule scatter_determinism_expander + + scatter_computation { + arg1.173 = f32[] parameter(1) + arg0.172 = f32[] parameter(0) + ROOT add.48 = f32[] add(arg0.172, 
arg1.173) + } + + ENTRY scatter_add_computation { + operand = f32[4] constant({0, 0, 0, 0}) + indices = s32[7,1] constant({{0}, {1}, {2}, {3}, {1}, {1}, {2}}) + updates = f32[7] constant({2, 1, 5, 3, 8, 7, 9}) + ROOT scatter.48 = f32[4] scatter(operand, indices, updates), + update_window_dims={}, inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, index_vector_dim=1, + to_apply=scatter_computation + })"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kModuleStr)); + + ScatterDeterminismExpander scatter_determinism_expander; + TF_ASSERT_OK_AND_ASSIGN( + bool result, RunHloPass(&scatter_determinism_expander, module.get())); + + EXPECT_TRUE(result); + + std::vector expected_result = {2.0, 16.0, 14.0, 3.0}; + + Literal result_literal = ExecuteAndTransfer(std::move(module), {}); + + auto result_data = result_literal.data(); + std::vector actual_result(result_data.begin(), result_data.end()); + + EXPECT_EQ(actual_result, expected_result); +} + +TEST_F(ScatterDeterminismExpanderTest, ScatterAddHloVerificationTest) { + const char* const kModuleStr = R"( + HloModule scatter_determinism_expander + + scatter_computation { + arg1.173 = f32[] parameter(1) + arg0.172 = f32[] parameter(0) + ROOT add.48 = f32[] add(arg0.172, arg1.173) + } + + ENTRY scatter_add_computation { + operand = f32[2] constant({0, 0}) + indices = s32[3,1] constant({{0}, {1}, {1}}) + updates = f32[3] constant({2, 1, 5}) + ROOT scatter.48 = f32[2] scatter(operand, indices, updates), + update_window_dims={}, inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, index_vector_dim=1, + to_apply=scatter_computation + })"; + + const char* const kExpectedPattern = R"( + CHECK: ENTRY %scatter_add_computation () -> f32[2] { + CHECK-DAG: %[[INDICES:.*]] = s32[3,1]{1,0} constant({ {0}, {1}, {1} }) + CHECK-DAG: %[[RESHAPE:.*]] = s32[3]{0} reshape(%[[INDICES]]) + CHECK-DAG: %[[OPERAND:.*]] = f32[2]{0} constant({0, 0}) + CHECK-DAG: %[[RESHAPE1:.*]] = s32[3]{0} reshape(%[[INDICES]]) + CHECK-DAG: %[[UPDATES:.*]] = f32[3]{0} constant({2, 1, 5}) + CHECK-DAG: %[[TRANSPOSE:.*]] = f32[3]{0} transpose(%[[UPDATES]]), dimensions={0} + CHECK-DAG: %[[RESHAPE2:.*]] = f32[3]{0} reshape(%[[TRANSPOSE]]) + CHECK-DAG: %[[SORT:.*]] = (s32[3]{0}, f32[3]{0}) sort(%[[RESHAPE1]], %[[RESHAPE2]]), dimensions={0}, to_apply=%sorting_computation + CHECK-DAG: %[[GET_TUPLE_ELEMENT:.*]] = s32[3]{0} get-tuple-element(%[[SORT]]), index=0 + CHECK-DAG: %[[SLICE4:.*]] = s32[2]{0} slice(%[[GET_TUPLE_ELEMENT]]), slice={[0:2]} + CHECK-DAG: %[[SLICE5:.*]] = s32[2]{0} slice(%[[GET_TUPLE_ELEMENT]]), slice={[1:3]} + CHECK-DAG: %[[COMPARE3:.*]] = pred[2]{0} compare(%[[SLICE4]], %[[SLICE5]]), direction=NE + CHECK-DAG: %[[CONSTANT4:.*]] = pred[] constant(true) + CHECK-DAG: %[[BROADCAST4:.*]] = pred[1]{0} broadcast(%[[CONSTANT4]]), dimensions={} + CHECK-DAG: %[[CONCAT_COMPARE4:.*]] = pred[3]{0} concatenate(%[[COMPARE3]], %[[BROADCAST4]]), dimensions={0} + CHECK-DAG: %[[BROADCAST5:.*]] = pred[3]{0} broadcast(%[[CONCAT_COMPARE4]]), dimensions={0} + CHECK-DAG: %[[CONSTANT5:.*]] = s32[3]{0} constant({2, 2, 2}) + CHECK-DAG: %[[SELECT2:.*]] = s32[3]{0} select(%[[BROADCAST5]], %[[GET_TUPLE_ELEMENT]], %[[CONSTANT5]]) + CHECK-DAG: %[[CONSTANT3:.*]] = s32[] constant(0) + CHECK-DAG: %[[BROADCAST3:.*]] = s32[2]{0} broadcast(%[[CONSTANT3]]), dimensions={} + CHECK-DAG: %[[SLICE3:.*]] = s32[1]{0} slice(%[[GET_TUPLE_ELEMENT]]), slice={[0:1]} + CHECK-DAG: %[[CONCAT3:.*]] = s32[3]{0} concatenate(%[[BROADCAST3]], %[[SLICE3]]), dimensions={0} + CHECK-DAG: 
%[[COMPARE2:.*]] = pred[3]{0} compare(%[[GET_TUPLE_ELEMENT]], %[[CONCAT3]]), direction=EQ + CHECK-DAG: %[[CONSTANT1:.*]] = s32[] constant(0) + CHECK-DAG: %[[BROADCAST1:.*]] = s32[1]{0} broadcast(%[[CONSTANT1]]), dimensions={} + CHECK-DAG: %[[SLICE1:.*]] = s32[2]{0} slice(%[[GET_TUPLE_ELEMENT]]), slice={[0:2]} + CHECK-DAG: %[[CONCAT1:.*]] = s32[3]{0} concatenate(%[[BROADCAST1]], %[[SLICE1]]), dimensions={0} + CHECK-DAG: %[[COMPARE1:.*]] = pred[3]{0} compare(%[[GET_TUPLE_ELEMENT]], %[[CONCAT1]]), direction=EQ + CHECK-DAG: %[[GET_TUPLE_ELEMENT1:.*]] = f32[3]{0} get-tuple-element(%[[SORT]]), index=1 + CHECK-DAG: %[[CONSTANT_F32:.*]] = f32[] constant(0) + CHECK-DAG: %[[BROADCAST_F32:.*]] = f32[1]{0} broadcast(%[[CONSTANT_F32]]), dimensions={} + CHECK-DAG: %[[SLICE_F32:.*]] = f32[2]{0} slice(%[[GET_TUPLE_ELEMENT1]]), slice={[0:2]} + CHECK-DAG: %[[CONCAT_F32:.*]] = f32[3]{0} concatenate(%[[BROADCAST_F32]], %[[SLICE_F32]]), dimensions={0} + CHECK-DAG: %[[MAP:.*]] = f32[3]{0} map(%[[GET_TUPLE_ELEMENT1]], %[[CONCAT_F32]]), dimensions={0}, to_apply=%scatter_computation + CHECK-DAG: %[[SELECT:.*]] = f32[3]{0} select(%[[COMPARE1]], %[[MAP]], %[[GET_TUPLE_ELEMENT1]]) + CHECK-DAG: %[[CONSTANT2:.*]] = f32[] constant(0) + CHECK-DAG: %[[BROADCAST2:.*]] = f32[2]{0} broadcast(%[[CONSTANT2]]), dimensions={} + CHECK-DAG: %[[SLICE2:.*]] = f32[1]{0} slice(%[[SELECT]]), slice={[0:1]} + CHECK-DAG: %[[CONCAT2:.*]] = f32[3]{0} concatenate(%[[BROADCAST2]], %[[SLICE2]]), dimensions={0} + CHECK-DAG: %[[MAP1:.*]] = f32[3]{0} map(%[[SELECT]], %[[CONCAT2]]), dimensions={0}, to_apply=%scatter_computation + CHECK-DAG: %[[SELECT1:.*]] = f32[3]{0} select(%[[COMPARE2]], %[[MAP1]], %[[SELECT]]) + CHECK-DAG: ROOT %[[SCATTER:.*]] = f32[2]{0} scatter(%[[OPERAND]], %[[SELECT2]], %[[SELECT1]]), + CHECK-SAME: update_window_dims={}, + CHECK-SAME: inserted_window_dims={0}, + CHECK-SAME: scatter_dims_to_operand_dims={0}, + CHECK-SAME: index_vector_dim=1, + CHECK-SAME: indices_are_sorted=true, + CHECK-SAME: unique_indices=true, + CHECK-SAME: to_apply=%scatter_computation + )"; + + RunAndFilecheckHloRewrite(kModuleStr, ScatterDeterminismExpander(), + kExpectedPattern); +} + +TEST_F(ScatterDeterminismExpanderTest, ScatterAddOutOfBoundCorrectnessTest) { + const char* const kModuleStr = R"( + HloModule scatter_determinism_expander + + scatter_computation { + arg1.173 = f32[] parameter(1) + arg0.172 = f32[] parameter(0) + ROOT add.48 = f32[] add(arg0.172, arg1.173) + } + + ENTRY scatter_add_computation { + operand = f32[4] constant({0, 0, 0, 0}) + indices = s32[7,1] constant({{0}, {1}, {5}, {4}, {1}, {1}, {2}}) + updates = f32[7] constant({2, 1, 5, 3, 8, 7, 9}) + ROOT scatter.48 = f32[4] scatter(operand, indices, updates), + update_window_dims={}, inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, index_vector_dim=1, + to_apply=scatter_computation + })"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kModuleStr)); + + ScatterDeterminismExpander scatter_determinism_expander; + TF_ASSERT_OK_AND_ASSIGN( + bool result, RunHloPass(&scatter_determinism_expander, module.get())); + + EXPECT_TRUE(result); + + std::vector expected_result = {2.0, 16.0, 9.0, 0.0}; + + Literal result_literal = ExecuteAndTransfer(std::move(module), {}); + + auto result_data = result_literal.data(); + std::vector actual_result(result_data.begin(), result_data.end()); + + EXPECT_EQ(actual_result, expected_result); +} + +TEST_F(ScatterDeterminismExpanderTest, ScatterAddReproducibilityTest) { + const char* const kModuleStr = R"( + HloModule 
scatter_determinism_expander + + scatter_computation { + arg1.173 = f32[] parameter(1) + arg0.172 = f32[] parameter(0) + ROOT add.48 = f32[] add(arg0.172, arg1.173) + } + + ENTRY scatter_add_computation { + operand = f32[3] constant({0, 0, 0}) + indices = s32[100,1] constant({{0}, {3}, {0}, {1}, {0}, {3}, {1}, {2}, {1}, {2}, {2}, {2}, {0}, {2}, {1}, {0}, {1}, {1}, {2}, {0}, {2}, {1}, {2}, {1}, {2}, {2}, {3}, {2}, {2}, {0}, {3}, {0}, {3}, {2}, {0}, {3}, {3}, {3}, {3}, {3}, {2}, {3}, {3}, {0}, {0}, {3}, {3}, {3}, {2}, {3}, {2}, {3}, {0}, {0}, {2}, {0}, {1}, {3}, {1}, {3}, {2}, {2}, {2}, {1}, {0}, {3}, {1}, {1}, {1}, {1}, {1}, {2}, {2}, {3}, {0}, {2}, {2}, {0}, {2}, {1}, {0}, {2}, {2}, {2}, {0}, {2}, {0}, {1}, {3}, {0}, {2}, {3}, {3}, {2}, {0}, {3}, {3}, {2}, {3}, {2}}) + updates = f32[100] constant({0.02379167, 0.8527204, 0.8132185, 0.5140263, 0.17172801, 0.8026866, 0.5124631, 0.34838438, 0.50526905, 0.3370521, 0.10868239, 0.10520637, 0.83827364, 0.78986526, 0.34059846, 0.8349273, 0.24575627, 0.21387374, 0.02423227, 0.5617423, 0.28066766, 0.94366455, 0.61214995, 0.7383388, 0.52419806, 0.65466726, 0.41012764, 0.24028647, 0.74443066, 0.03544927, 0.851014, 0.02434528, 0.47239733, 0.72706807, 0.35055435, 0.6274171, 0.61077535, 0.06525731, 0.8091929, 0.21307838, 0.6465323, 0.3245015, 0.5538883, 0.8849807, 0.9591211, 0.83856845, 0.48919427, 0.11810577, 0.16933143, 0.83657074, 0.587505, 0.6867087, 0.95522237, 0.5797727, 0.28024232, 0.34749162, 0.5199702, 0.9811766, 0.5645981, 0.2446456, 0.68722725, 0.9616587, 0.480047, 0.88953114, 0.7083205, 0.948612, 0.67764974, 0.44131804, 0.36789334, 0.95148766, 0.30909216, 0.70908046, 0.8749926, 0.60973287, 0.60751855, 0.22647333, 0.5363518, 0.96195626, 0.08158326, 0.5266887, 0.85922587, 0.648262, 0.4657668, 0.31623375, 0.43507564, 0.48351157, 0.41285944, 0.73501325, 0.15267539, 0.67055714, 0.08459568, 0.04527426, 0.21078384, 0.4654404, 0.7363906, 0.23245859, 0.22119188, 0.99092937, 0.878675, 0.4102913}) + ROOT scatter.48 = f32[3] scatter(operand, indices, updates), + update_window_dims={}, inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, index_vector_dim=1, + to_apply=scatter_computation + })"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kModuleStr)); + + ScatterDeterminismExpander scatter_determinism_expander; + TF_ASSERT_OK_AND_ASSIGN( + bool result, RunHloPass(&scatter_determinism_expander, module.get())); + + EXPECT_TRUE(result); + + auto cloned_module = module->Clone(); + Literal first_result_literal = + ExecuteAndTransfer(std::move(cloned_module), {}); + auto first_result_span = first_result_literal.data(); + std::vector first_result(first_result_span.begin(), + first_result_span.end()); + + const int num_trials = 20; + std::vector> results; + + for (int i = 0; i < num_trials; ++i) { + auto cloned_module = module->Clone(); + + Literal result_literal = ExecuteAndTransfer(std::move(cloned_module), {}); + + auto result_data = result_literal.data(); + std::vector actual_result(result_data.begin(), result_data.end()); + + EXPECT_EQ(actual_result, first_result) + << "Results are not reproducible across trials!"; + } +} + +} // namespace +} // namespace xla diff --git a/xla/service/scatter_expander.cc b/xla/service/scatter_expander.cc index cd7f72c64c177..1bd4178afcd3c 100644 --- a/xla/service/scatter_expander.cc +++ b/xla/service/scatter_expander.cc @@ -24,113 +24,12 @@ limitations under the License. 
#include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/literal_util.h" -#include "xla/service/call_inliner.h" #include "xla/service/hlo_creation_utils.h" +#include "xla/service/scatter_utils.h" #include "xla/service/while_util.h" namespace xla { -// Transposes the given scatter_indices such that the index_vector_dim becomes -// the most-minor dimension. -static absl::StatusOr TransposeIndexVectorDimToLast( - HloInstruction* scatter_indices, int64_t index_vector_dim) { - const Shape& scatter_indices_shape = scatter_indices->shape(); - - if (scatter_indices_shape.dimensions_size() == index_vector_dim) { - return scatter_indices; - } - - if (index_vector_dim == (scatter_indices_shape.dimensions_size() - 1)) { - return scatter_indices; - } - - std::vector permutation; - permutation.reserve(scatter_indices_shape.dimensions_size()); - for (int64_t i = 0, e = scatter_indices_shape.dimensions_size(); i < e; i++) { - if (i != index_vector_dim) { - permutation.push_back(i); - } - } - permutation.push_back(index_vector_dim); - return MakeTransposeHlo(scatter_indices, permutation); -} - -// Canonicalizes the scatter_indices tensor in order to keep them uniform while -// performing the scatter operation. -static absl::StatusOr CanonicalizeScatterIndices( - HloInstruction* scatter_indices, int64_t index_vector_dim) { - // Transpose the non-index-vector dimensions to the front. - TF_ASSIGN_OR_RETURN( - HloInstruction * transposed_scatter_indices, - TransposeIndexVectorDimToLast(scatter_indices, index_vector_dim)); - if (scatter_indices->shape().rank() == index_vector_dim + 1 && - scatter_indices->shape().dimensions(index_vector_dim) == 1) { - auto new_shape = - ShapeUtil::DeleteDimension(index_vector_dim, scatter_indices->shape()); - TF_ASSIGN_OR_RETURN(scatter_indices, - MakeReshapeHlo(new_shape, scatter_indices)); - } - bool indices_are_scalar = - index_vector_dim == scatter_indices->shape().dimensions_size(); - - // The number of dimensions in scatter_indices that are index dimensions. - const int64_t index_dims_in_scatter_indices = indices_are_scalar ? 0 : 1; - - // If there is only one index (i.e. scatter_indices has rank 1 and this - // scatter is really just a dynamic update slice) add a leading degenerate - // dimension for uniformity. Otherwise create a "collapsed" leading dimension - // that subsumes all of the non-index-vector dimensions. - const Shape& shape = transposed_scatter_indices->shape(); - if (shape.dimensions_size() == index_dims_in_scatter_indices) { - return PrependDegenerateDims(transposed_scatter_indices, 1); - } else { - // Collapse all but the dimensions (0 or 1) in scatter_indices containing - // the index vectors. - return CollapseFirstNDims( - transposed_scatter_indices, - shape.dimensions_size() - index_dims_in_scatter_indices); - } -} - -// Permutes the `updates` tensor such that all the scatter dims appear in the -// major dimensions and all the window dimensions appear in the minor -// dimensions. 
-static absl::StatusOr PermuteScatterAndWindowDims( - HloInstruction* updates, absl::Span update_window_dims) { - std::vector permutation; - const int64_t updates_rank = updates->shape().rank(); - permutation.reserve(updates_rank); - - for (int64_t i = 0; i < updates_rank; ++i) { - bool is_scatter_dim = !absl::c_binary_search(update_window_dims, i); - if (is_scatter_dim) { - permutation.push_back(i); - } - } - for (auto window_dim : update_window_dims) { - permutation.push_back(window_dim); - } - - return MakeTransposeHlo(updates, permutation); -} - -// Expands or contracts the scatter indices in the updates tensor. -static absl::StatusOr AdjustScatterDims( - const Shape& scatter_indices_shape, HloInstruction* updates, - int64_t index_vector_dim) { - int64_t num_scatter_dims = scatter_indices_shape.dimensions_size(); - if (index_vector_dim < scatter_indices_shape.dimensions_size()) { - --num_scatter_dims; - } - if (num_scatter_dims == 0) { - // If there are no scatter dims, this must be a dynamic-update-slice kind of - // scatter. In this case, we prepend a degenerate dimension to work - // uniformly in the while loop. - return PrependDegenerateDims(updates, 1); - } - return CollapseFirstNDims(updates, num_scatter_dims); -} - // Expands an index vector from the scatter_indices tensor into a vector that // can be used to dynamic-update-slice to perform the scatter update. static absl::StatusOr ExpandIndexVectorIntoOperandSpace( @@ -218,33 +117,6 @@ static absl::StatusOr CheckIndexValidity( return MakeBroadcastHlo(valid_index_reduced, {}, window_sizes); } -static absl::StatusOr CallAndGetOutput( - HloComputation* original, int output_index) { - HloInstruction* original_root = original->root_instruction(); - if (!original_root->shape().IsTuple()) { - return original; - } - HloComputation* new_comp = [&] { - HloComputation::Builder builder( - absl::StrCat(original->name(), ".dup.", output_index)); - for (int i = 0, n = original->num_parameters(); i < n; ++i) { - HloInstruction* original_param = original->parameter_instruction(i); - builder.AddInstruction(HloInstruction::CreateParameter( - i, original_param->shape(), original_param->name())); - } - return original->parent()->AddEmbeddedComputation(builder.Build()); - }(); - HloInstruction* call_original = new_comp->AddInstruction( - HloInstruction::CreateCall(original_root->shape(), - new_comp->parameter_instructions(), original)); - new_comp->set_root_instruction( - new_comp->AddInstruction( - HloInstruction::CreateGetTupleElement(call_original, output_index)), - /*accept_different_shape=*/true); - TF_RETURN_IF_ERROR(CallInliner::Inline(call_original).status()); - return new_comp; -} - // Body of the while loop that performs the scatter operation using other HLOs. static absl::StatusOr> ScatterLoopBody( HloScatterInstruction* scatter, HloInstruction* induction_var, @@ -377,22 +249,6 @@ static absl::StatusOr> ScatterLoopBody( return updated_loop_state; } -static int64_t ScatterTripCount(const HloScatterInstruction* scatter) { - // Compute the trip count for the while loop to be used for scatter. This - // should be the number of indices we should scatter into the operand. 
- const HloInstruction* scatter_indices = scatter->scatter_indices(); - const Shape& scatter_indices_shape = scatter_indices->shape(); - const ScatterDimensionNumbers& dim_numbers = - scatter->scatter_dimension_numbers(); - int64_t scatter_loop_trip_count = 1; - for (int64_t i = 0, e = scatter_indices_shape.dimensions_size(); i < e; i++) { - if (i != dim_numbers.index_vector_dim()) { - scatter_loop_trip_count *= scatter_indices_shape.dimensions(i); - } - } - return scatter_loop_trip_count; -} - // High Level Algorithm. // // 1. Canonicalize the scatter_indices tensor such that it has rank 2, where @@ -431,7 +287,7 @@ absl::StatusOr ScatterExpander::ExpandInstruction( // Compute the trip count for the while loop to be used for scatter. This // should be the number of indices we should scatter into the operand. - int64_t scatter_loop_trip_count = ScatterTripCount(scatter); + int64_t scatter_loop_trip_count = ScatterIndicesCount(scatter); if (!IsInt32(scatter_loop_trip_count)) { return Unimplemented( "Scatter operations with more than 2147483647 scatter indices are not " @@ -485,48 +341,13 @@ absl::StatusOr ScatterExpander::ExpandInstruction( return MaybeMakeTuple(results); } -namespace { - -bool IsCombinerAssociative(const HloComputation* combiner) { - // Consider simple binary combiner functions only. - if (combiner->instruction_count() != 3) { - return false; - } - switch (combiner->root_instruction()->opcode()) { - // Minimum and Maximum are common associative combiners. - case HloOpcode::kMinimum: - case HloOpcode::kMaximum: - return true; - // Other common combiners are associative at least for integer arithmetic. - case HloOpcode::kAdd: - case HloOpcode::kMultiply: - case HloOpcode::kOr: - case HloOpcode::kXor: - return combiner->root_instruction()->shape().IsInteger(); - default: - return false; - } -} - -bool IsDeterministic(const HloScatterInstruction* scatter) { - if (scatter->unique_indices()) { - return true; - } - if (IsCombinerAssociative(scatter->to_apply())) { - return true; - } - return false; -} - -} // namespace - bool ScatterExpander::InstructionMatchesPattern(HloInstruction* inst) { auto* scatter = DynCast(inst); return (scatter != nullptr) && (mode_ == kEliminateAllScatters || (mode_ == kEliminateSimpleScatters && - ScatterTripCount(scatter) == 1) || + ScatterIndicesCount(scatter) == 1) || (mode_ == kEliminateIndeterministicScatters && - !IsDeterministic(scatter))); + !IsScatterDeterministic(scatter))); } } // namespace xla diff --git a/xla/service/scatter_utils.cc b/xla/service/scatter_utils.cc new file mode 100644 index 0000000000000..fe97c0d3e9f25 --- /dev/null +++ b/xla/service/scatter_utils.cc @@ -0,0 +1,242 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "xla/service/scatter_utils.h" + +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/log/check.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/types/span.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/service/call_inliner.h" +#include "xla/service/hlo_creation_utils.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" + +namespace xla { + +absl::StatusOr TransposeIndexVectorDimToLast( + HloInstruction* scatter_indices, int64_t index_vector_dim) { + const Shape& scatter_indices_shape = scatter_indices->shape(); + if (index_vector_dim >= (scatter_indices_shape.dimensions_size() - 1)) { + return scatter_indices; + } + + std::vector permutation; + permutation.reserve(scatter_indices_shape.dimensions_size()); + for (int64_t i = 0; i < scatter_indices_shape.dimensions_size(); i++) { + if (i != index_vector_dim) { + permutation.push_back(i); + } + } + permutation.push_back(index_vector_dim); + return MakeTransposeHlo(scatter_indices, permutation); +} + +absl::StatusOr PermuteScatterAndWindowDims( + HloInstruction* updates, absl::Span update_window_dims) { + std::vector permutation; + const int64_t updates_rank = updates->shape().rank(); + permutation.reserve(updates_rank); + + for (int64_t i = 0; i < updates_rank; ++i) { + bool is_scatter_dim = !absl::c_binary_search(update_window_dims, i); + if (is_scatter_dim) { + permutation.push_back(i); + } + } + for (int64_t window_dim : update_window_dims) { + permutation.push_back(window_dim); + } + + return MakeTransposeHlo(updates, permutation); +} + +// Expands or contracts the scatter indices in the updates tensor. +absl::StatusOr AdjustScatterDims( + const Shape& scatter_indices_shape, HloInstruction* updates, + int64_t index_vector_dim) { + int64_t num_scatter_dims = scatter_indices_shape.dimensions_size(); + if (index_vector_dim < scatter_indices_shape.dimensions_size()) { + --num_scatter_dims; + } + if (num_scatter_dims == 0) { + // If there are no scatter dims, this must be a dynamic-update-slice kind of + // scatter. In this case, we prepend a degenerate dimension to work + // uniformly in the while loop. + return PrependDegenerateDims(updates, 1); + } + return CollapseFirstNDims(updates, num_scatter_dims); +} + +absl::StatusOr CanonicalizeScatterIndices( + HloInstruction* scatter_indices, int64_t index_vector_dim) { + // Transpose the non-index-vector dimensions to the front. + TF_ASSIGN_OR_RETURN( + HloInstruction * transposed_scatter_indices, + TransposeIndexVectorDimToLast(scatter_indices, index_vector_dim)); + if (scatter_indices->shape().rank() - 1 == index_vector_dim && + scatter_indices->shape().dimensions(index_vector_dim) == 1) { + auto new_shape = + ShapeUtil::DeleteDimension(index_vector_dim, scatter_indices->shape()); + TF_ASSIGN_OR_RETURN(scatter_indices, + MakeReshapeHlo(new_shape, scatter_indices)); + } + bool indices_are_scalar = + index_vector_dim == scatter_indices->shape().dimensions_size(); + + // The number of dimensions in scatter_indices that are index dimensions. + const int64_t index_dims_in_scatter_indices = indices_are_scalar ? 0 : 1; + + // If there is only one index (i.e. 
scatter_indices has rank 1 and this + // scatter is really just a dynamic update slice) add a leading degenerate + // dimension for uniformity. Otherwise create a "collapsed" leading dimension + // that subsumes all of the non-index-vector dimensions. + const Shape& shape = transposed_scatter_indices->shape(); + if (shape.dimensions_size() == index_dims_in_scatter_indices) { + return PrependDegenerateDims(transposed_scatter_indices, 1); + } + // Collapse all but the dimensions (0 or 1) in scatter_indices containing + // the index vectors. + return CollapseFirstNDims( + transposed_scatter_indices, + shape.dimensions_size() - index_dims_in_scatter_indices); +} + +absl::StatusOr CallAndGetOutput(HloComputation* original, + int output_index) { + HloInstruction* original_root = original->root_instruction(); + if (!original_root->shape().IsTuple()) { + return original; + } + HloComputation* new_comp = [&] { + HloComputation::Builder builder( + absl::StrCat(original->name(), ".dup.", output_index)); + for (int i = 0, n = original->num_parameters(); i < n; ++i) { + HloInstruction* original_param = original->parameter_instruction(i); + builder.AddInstruction(HloInstruction::CreateParameter( + i, original_param->shape(), original_param->name())); + } + return original->parent()->AddEmbeddedComputation(builder.Build()); + }(); + HloInstruction* call_original = new_comp->AddInstruction( + HloInstruction::CreateCall(original_root->shape(), + new_comp->parameter_instructions(), original)); + new_comp->set_root_instruction( + new_comp->AddInstruction( + HloInstruction::CreateGetTupleElement(call_original, output_index)), + /*accept_different_shape=*/true); + TF_RETURN_IF_ERROR(CallInliner::Inline(call_original).status()); + return new_comp; +} + +absl::StatusOr CallComputationAndGetIthOutputWithBinaryParams( + HloComputation* original, int output_index) { + HloInstruction* original_root = original->root_instruction(); + if (!original_root->shape().IsTuple()) { + return original; + } + int64_t num_params = original->num_parameters(); + int64_t num_outputs = original_root->shape().tuple_shapes_size(); + + CHECK_EQ(num_params / 2, num_outputs); + HloComputation* new_comp = [&] { + HloComputation::Builder builder( + absl::StrCat(original->name(), ".dup.", output_index)); + HloInstruction* original_param_lhs = + original->parameter_instruction(output_index); + builder.AddInstruction(HloInstruction::CreateParameter( + 0, original_param_lhs->shape(), original_param_lhs->name())); + HloInstruction* original_param_rhs = + original->parameter_instruction(output_index + num_outputs); + builder.AddInstruction(HloInstruction::CreateParameter( + 1, original_param_rhs->shape(), original_param_rhs->name())); + return original->parent()->AddEmbeddedComputation(builder.Build()); + }(); + std::vector operands; + operands.reserve(num_params); + for (int i = 0; i < num_outputs; ++i) { + operands.push_back(new_comp->parameter_instruction(0)); + } + for (int i = 0; i < num_outputs; ++i) { + operands.push_back(new_comp->parameter_instruction(1)); + } + + HloInstruction* call_original = new_comp->AddInstruction( + HloInstruction::CreateCall(original_root->shape(), operands, original)); + new_comp->set_root_instruction( + new_comp->AddInstruction( + HloInstruction::CreateGetTupleElement(call_original, output_index)), + /*accept_different_shape=*/true); + TF_RETURN_IF_ERROR(CallInliner::Inline(call_original).status()); + return new_comp; +} + +int64_t ScatterIndicesCount(const HloScatterInstruction* scatter) { + // Compute the 
trip count for the while loop to be used for scatter. This + // should be the number of indices we should scatter into the operand. + const HloInstruction* scatter_indices = scatter->scatter_indices(); + const Shape& scatter_indices_shape = scatter_indices->shape(); + const ScatterDimensionNumbers& dim_numbers = + scatter->scatter_dimension_numbers(); + int64_t scatter_loop_trip_count = 1; + for (int64_t i = 0, e = scatter_indices_shape.dimensions_size(); i < e; i++) { + if (i != dim_numbers.index_vector_dim()) { + scatter_loop_trip_count *= scatter_indices_shape.dimensions(i); + } + } + return scatter_loop_trip_count; +} + +bool IsScatterCombinerAssociative(const HloComputation* combiner) { + // Consider simple binary combiner functions only. + if (combiner->instruction_count() != 3) { + return false; + } + switch (combiner->root_instruction()->opcode()) { + // Minimum and Maximum are common associative combiners. + case HloOpcode::kMinimum: + case HloOpcode::kMaximum: + return true; + // Other common combiners are associative at least for integer arithmetic. + case HloOpcode::kAdd: + case HloOpcode::kMultiply: + case HloOpcode::kOr: + case HloOpcode::kXor: + return combiner->root_instruction()->shape().IsInteger(); + default: + return false; + } +} + +bool IsScatterDeterministic(const HloScatterInstruction* scatter) { + if (scatter->unique_indices()) { + return true; + } + if (IsScatterCombinerAssociative(scatter->to_apply())) { + return true; + } + return false; +} +} // namespace xla diff --git a/xla/service/scatter_utils.h b/xla/service/scatter_utils.h new file mode 100644 index 0000000000000..22209e4fef7bb --- /dev/null +++ b/xla/service/scatter_utils.h @@ -0,0 +1,61 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_SCATTER_UTILS_H_ +#define XLA_SERVICE_SCATTER_UTILS_H_ + +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" + +namespace xla { + +// Transposes the given scatter_indices such that the index_vector_dim becomes +// the most-minor dimension. +absl::StatusOr TransposeIndexVectorDimToLast( + HloInstruction* scatter_indices, int64_t index_vector_dim); + +// Permutes the `updates` tensor such that all the scatter dims appear in the +// major dimensions and all the window dimensions appear in the minor +// dimensions. +absl::StatusOr PermuteScatterAndWindowDims( + HloInstruction* updates, absl::Span update_window_dims); + +// Expands or contracts the scatter indices in the updates tensor. +absl::StatusOr AdjustScatterDims( + const Shape& scatter_indices_shape, HloInstruction* updates, + int64_t index_vector_dim); + +// Canonicalizes the scatter_indices tensor in order to keep them uniform while +// performing the scatter operation. 
+absl::StatusOr<HloInstruction*> CanonicalizeScatterIndices(
+    HloInstruction* scatter_indices, int64_t index_vector_dim);
+
+absl::StatusOr<HloComputation*> CallAndGetOutput(HloComputation* original,
+                                                 int output_index);
+absl::StatusOr<HloComputation*> CallComputationAndGetIthOutputWithBinaryParams(
+    HloComputation* original, int output_index);
+
+int64_t ScatterIndicesCount(const HloScatterInstruction* scatter);
+
+// Checks if the combiner is associative.
+bool IsScatterCombinerAssociative(const HloComputation* combiner);
+
+// Checks if the scatter operation is deterministic.
+bool IsScatterDeterministic(const HloScatterInstruction* scatter);
+
+}  // namespace xla
+
+#endif  // XLA_SERVICE_SCATTER_UTILS_H_
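Reviewer note: the following is a minimal host-side sketch, in plain C++ on std::vector, of the algorithm the expander emits as HLO. It is illustrative only and not part of the patch: the add combiner is hard-coded where the pass uses the scatter's to_apply computation, everything runs sequentially where the emitted HLO is data-parallel, and the input values are taken from ScatterAddCorrectnessTest above (expected result {2, 16, 14, 3}). It shows the three steps in order: sort updates by scatter index, run a segmented prefix scan so each position accumulates earlier updates with the same index, then keep only the last occurrence of every index before the final scatter with unique indices.

// Illustrative standalone example (assumed, not taken from the patch).
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <numeric>
#include <utility>
#include <vector>

int main() {
  std::vector<int64_t> indices = {0, 1, 2, 3, 1, 1, 2};
  std::vector<float> updates = {2, 1, 5, 3, 8, 7, 9};
  const int64_t n = static_cast<int64_t>(indices.size());
  const int64_t operand_size = 4;  // the scatter operand is f32[4]

  // 1) Sort (index, update) pairs by index, mirroring SortIndicesAndUpdates.
  std::vector<int64_t> perm(n);
  std::iota(perm.begin(), perm.end(), 0);
  std::stable_sort(perm.begin(), perm.end(), [&](int64_t a, int64_t b) {
    return indices[a] < indices[b];
  });
  std::vector<int64_t> sorted_idx(n);
  std::vector<float> sorted_upd(n);
  for (int64_t i = 0; i < n; ++i) {
    sorted_idx[i] = indices[perm[i]];
    sorted_upd[i] = updates[perm[i]];
  }

  // 2) Segmented prefix scan in ceil(log2(n)) passes, mirroring
  //    CreateScanWithIndices: at offset d, combine position i with position
  //    i - d only when both positions hold the same scatter index.
  for (int64_t d = 1; d < n; d <<= 1) {
    std::vector<float> next = sorted_upd;
    for (int64_t i = d; i < n; ++i) {
      if (sorted_idx[i] == sorted_idx[i - d]) {
        next[i] = sorted_upd[i] + sorted_upd[i - d];
      }
    }
    sorted_upd = std::move(next);
  }

  // 3) Keep only the last occurrence of each index; FindLastOccurrenceIndices
  //    masks the others to an out-of-bounds index so the final scatter with
  //    unique indices drops them.
  std::vector<float> result(operand_size, 0.0f);
  for (int64_t i = 0; i < n; ++i) {
    bool is_last = (i + 1 == n) || (sorted_idx[i] != sorted_idx[i + 1]);
    if (is_last && sorted_idx[i] < operand_size) {
      result[sorted_idx[i]] += sorted_upd[i];
    }
  }

  for (float v : result) std::cout << v << " ";  // prints: 2 16 14 3
  std::cout << "\n";
  return 0;
}

Because every pass of the scan combines updates in a fixed order determined by the sorted positions, the accumulated value for each index no longer depends on the order in which hardware atomics happen to fire, which is what makes the rewritten scatter reproducible across runs.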