Skip to content

Commit

Permalink
Add the scatter indices to operand space mapping
Browse files Browse the repository at this point in the history
and permute the offset columns according to
scatter_dims_to_operand_dims, so that the offsets
and the indices can be added together correctly.
  • Loading branch information
serach24 committed Nov 12, 2024
1 parent 3b7b56a commit 126c952
Showing 1 changed file with 80 additions and 17 deletions.
97 changes: 80 additions & 17 deletions xla/service/scatter_determinism_expander.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ limitations under the License.
#include "xla/literal_util.h"
#include "xla/service/hlo_creation_utils.h"
#include "xla/service/scatter_utils.h"
#include "xla/shape.h"
#include "xla/shape_util.h"
#include "xla/util.h"
#include "xla/xla_data.pb.h"
Expand Down Expand Up @@ -443,18 +444,17 @@ static HloInstruction* FindLastOccurrenceIndices(
template <typename T>
HloInstruction* ExpandIndexOffsetsFromUpdateShape(
HloComputation* parent, const Shape& update_shape,
const ScatterDimensionNumbers& dim_num, const Shape& operand_shape) {
const ScatterDimensionNumbers& dim_num, const Shape& operand_shape,
absl::Span<const int64_t> index_to_operand_map) {
// Calculate the offset tensor for each element of the update tensor.
// The offset tensor has shape (num_elements_in_update, operand_rank).

int64_t num_elements = ShapeUtil::ElementsIn(update_shape);
int64_t operand_rank = operand_shape.dimensions_size();

Array2D<T> offset_tensor(num_elements, operand_rank);

std::vector<bool> is_inserted_window_dims(operand_rank, false);
for (int64_t dim : dim_num.inserted_window_dims()) {
is_inserted_window_dims[dim] = true;
}

for (int64_t linear_index = 0; linear_index < num_elements; ++linear_index) {
// Calculate the multi-dimensional index from the linear index
Expand All @@ -471,7 +471,8 @@ HloInstruction* ExpandIndexOffsetsFromUpdateShape(
// inserted window dims.
int64_t dim_size =
update_shape.dimensions(i + 1 - inserted_window_dim_size);
offset_tensor(linear_index, i) = current_index / dim_size;
offset_tensor(linear_index, index_to_operand_map[i]) =
current_index / dim_size;
current_index %= dim_size;
}
}
Expand Down Expand Up @@ -611,15 +612,41 @@ absl::StatusOr<HloInstruction*> CheckValidIndices(
valid_index_mask, {0}));
}

// Pads `indices` with zero-valued columns, appended at the end, so that the
// result has one column per operand dimension.
//
// Parameters:
//   operand_rank: rank of the scatter operand; the result has this many
//     columns.
//   indices_to_operand_map: the scatter-dims-to-operand-dims mapping. Only its
//     size is used here, as the number of index columns already present.
//   indices: the index tensor. Dimension 0 is read as the batch size N and the
//     padding is concatenated along dimension 1.
//     NOTE(review): the declared result shape assumes `indices` is
//     [N, indices_to_operand_map.size()] -- confirm against the caller.
//
// Returns a concatenate instruction of shape [N, operand_rank], added to the
// computation that owns `indices`.
absl::StatusOr<HloInstruction*> AddImplicitDimensionsToIndices(
    int64_t operand_rank, absl::Span<const int64_t> indices_to_operand_map,
    HloInstruction* indices) {
  const Shape& indices_shape = indices->shape();
  HloComputation* computation = indices->parent();

  // The padding must use the same element type as the indices; otherwise the
  // concatenate below would mix element types (e.g. S64 indices with S32
  // zeros).
  const PrimitiveType index_type = indices_shape.element_type();

  // Batch size (N) and the number of index columns that are already present.
  const int64_t batch_size = indices_shape.dimensions(0);
  const int64_t num_indices_dims = indices_to_operand_map.size();
  const int64_t num_implicit_dims = operand_rank - num_indices_dims;
  CHECK_GE(num_implicit_dims, 0);

  // Zero padding of shape [N, operand_rank - S], built by broadcasting a
  // scalar zero of the index element type.
  HloInstruction* zero = computation->AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::Zero(index_type)));
  HloInstruction* zero_filled_tensor =
      computation->AddInstruction(HloInstruction::CreateBroadcast(
          ShapeUtil::MakeShape(index_type, {batch_size, num_implicit_dims}),
          zero, {}));

  // Concatenate the original index columns with the zero padding.
  const Shape expanded_shape =
      ShapeUtil::MakeShape(index_type, {batch_size, operand_rank});
  return computation->AddInstruction(HloInstruction::CreateConcatenate(
      expanded_shape, {indices, zero_filled_tensor}, 1));
}

absl::StatusOr<HloInstruction*> ScatterDeterminismExpander::ExpandInstruction(
HloInstruction* inst) {
auto* scatter = Cast<HloScatterInstruction>(inst);
auto scatter_operands = scatter->scatter_operands();
HloInstruction* scatter_indices = scatter->scatter_indices();
std::vector<HloInstruction*> scatter_updates(
scatter->scatter_updates().begin(), scatter->scatter_updates().end());
const ScatterDimensionNumbers& dim_numbers =
scatter->scatter_dimension_numbers();
ScatterDimensionNumbers dim_numbers = scatter->scatter_dimension_numbers();

// If the updates tensors are empty, there is no need to update the operands.
// The operands can be forwarded.
Expand All @@ -644,19 +671,23 @@ absl::StatusOr<HloInstruction*> ScatterDeterminismExpander::ExpandInstruction(

// Canonicalize the scatter_indices, after which the size of its most-major
// dimension must be same as the while loop trip count.
HloInstruction* original_scatter_indices = scatter_indices;
TF_ASSIGN_OR_RETURN(scatter_indices,
CanonicalizeScatterIndices(
scatter_indices, dim_numbers.index_vector_dim()));
// TODO(Chenhao): For maintainability, simplify the code so that
// scatter_indices always has more than one dimension.
CHECK_GE(scatter_indices->shape().dimensions_size(), 1);
CHECK_EQ(scatter_indices_count, scatter_indices->shape().dimensions(0));
bool has_scalar_indices = scatter_indices->shape().dimensions_size() == 1 ||
scatter_indices->shape().dimensions(1) == 1;

// Canonicalize the updates, after which the size of their most-major
// dimensions must be same as the while loop trip count.
TF_ASSIGN_OR_RETURN(scatter_updates, CanonicalizeScatterUpdates(
scatter_updates, scatter_indices,
dim_numbers, scatter_indices_count));
TF_ASSIGN_OR_RETURN(
scatter_updates,
CanonicalizeScatterUpdates(scatter_updates, original_scatter_indices,
dim_numbers, scatter_indices_count));

HloComputation* parent = scatter->parent();
auto updates_shape = scatter_updates[0]->shape();
Expand All @@ -680,12 +711,26 @@ absl::StatusOr<HloInstruction*> ScatterDeterminismExpander::ExpandInstruction(
// Extract operand dimensions
const Shape& operand_shape = scatter_operands[0]->shape();

// Add the implicit dimensions to the index_to_operand_map
absl::flat_hash_set<int64_t> existing_dims(
dim_numbers.scatter_dims_to_operand_dims().begin(),
dim_numbers.scatter_dims_to_operand_dims().end());
std::vector<int64_t> full_index_to_operand_dims(
dim_numbers.mutable_scatter_dims_to_operand_dims()->begin(),
dim_numbers.mutable_scatter_dims_to_operand_dims()->end());
for (int i = 0; i < operand_shape.dimensions_size(); i++) {
if (existing_dims.find(i) == existing_dims.end())
full_index_to_operand_dims.push_back(i);
}

HloInstruction* index_offsets =
scatter_indices->shape().element_type() == S32
? ExpandIndexOffsetsFromUpdateShape<int32_t>(
scatter->parent(), update_shape, dim_numbers, operand_shape)
scatter->parent(), update_shape, dim_numbers, operand_shape,
full_index_to_operand_dims)
: ExpandIndexOffsetsFromUpdateShape<int64_t>(
scatter->parent(), update_shape, dim_numbers, operand_shape);
scatter->parent(), update_shape, dim_numbers, operand_shape,
full_index_to_operand_dims);

int num_operand_dims = operand_shape.dimensions_size();
std::vector<int64_t> actual_update_window_dims(num_operand_dims);
Expand All @@ -702,6 +747,14 @@ absl::StatusOr<HloInstruction*> ScatterDeterminismExpander::ExpandInstruction(
}
}

// map scatter_indices into operand space
TF_ASSIGN_OR_RETURN(
scatter_indices,
AddImplicitDimensionsToIndices(
scatter_operands[0]->shape().dimensions_size(),
dim_numbers.scatter_dims_to_operand_dims(), scatter_indices));
CHECK(scatter_indices->shape().dimensions(0) == scatter_indices_count);
// TODO(chenhao) check valid indices with the map!
// If any updates are out of bound, we change the corresponding indices to
// be oob_tensor values
TF_ASSIGN_OR_RETURN(
Expand Down Expand Up @@ -740,11 +793,16 @@ absl::StatusOr<HloInstruction*> ScatterDeterminismExpander::ExpandInstruction(
i++) {
new_dim_numbers.add_inserted_window_dims(i);
}
// Set the scatter_dims_to_operand_dims
// copy from the original scatter_dims_to_operand_dims
for (int i = 0; i < dim_numbers.scatter_dims_to_operand_dims_size(); i++) {
new_dim_numbers.add_scatter_dims_to_operand_dims(
dim_numbers.scatter_dims_to_operand_dims(i));
// // Set the scatter_dims_to_operand_dims
// // copy from the original scatter_dims_to_operand_dims
// for (int i = 0; i < dim_numbers.scatter_dims_to_operand_dims_size(); i++)
// {
// new_dim_numbers.add_scatter_dims_to_operand_dims(
// dim_numbers.scatter_dims_to_operand_dims(i));
// }
// Set the scatter_dims_to_operand_dims to be ordered from 0 to operand_rank
for (int i = 0; i < operand_shape.dimensions_size(); i++) {
new_dim_numbers.add_scatter_dims_to_operand_dims(i);
}
} else {
new_dim_numbers = dim_numbers;
Expand Down Expand Up @@ -838,6 +896,11 @@ bool CheckOutputDependency(HloComputation* to_apply, int operand_size) {
bool ScatterDeterminismExpander::InstructionMatchesPattern(
HloInstruction* inst) {
auto* scatter = DynCast<HloScatterInstruction>(inst);

// TODO(chenhao): There are some tricky cases that we need to avoid:
// 1. Unusual batch dimensions.
// 2. The operand rank differing from the indices rank plus the number of
//    inserted window dimensions.
return (scatter != nullptr) && !IsScatterDeterministic(scatter) &&
CheckOutputDependency(scatter->to_apply(),
scatter->scatter_operands().size());
Expand Down

0 comments on commit 126c952

Please sign in to comment.