Skip to content

Commit

Permalink
Resolve comments
Browse files Browse the repository at this point in the history
  • Loading branch information
serach24 committed Oct 15, 2024
1 parent 7894ab7 commit 2ddc9e6
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 28 deletions.
53 changes: 26 additions & 27 deletions xla/service/scatter_determinism_expander.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ limitations under the License.

#include "xla/service/scatter_determinism_expander.h"
#include <cstdint>
#include <unordered_set>
#include "absl/container/flat_hash_set.h"

#include "tsl/platform/logging.h"
#include "tsl/platform/statusor.h"
Expand Down Expand Up @@ -60,22 +60,22 @@ HloInstruction* CreateOutOfBoundTensor(HloComputation* parent,
HloInstruction* scatter_indices,
const Shape& scatter_shape) {
if (scatter_indices->shape().rank() == 1) {
CHECK(scatter_shape.dimensions_size() == 1);
CHECK_EQ(scatter_shape.dimensions_size(), 1);
Array<int32_t> out_of_bound_array({scatter_indices->shape().dimensions(0)},
scatter_shape.dimensions(0));
return parent->AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::CreateFromArray(out_of_bound_array)));
}
// More than one dimension in scatter_indices
Array2D<int32_t> out_of_ound_array(scatter_indices->shape().dimensions(0),
scatter_indices->shape().dimensions(1));
Array2D<int32_t> out_of_bound_array(scatter_indices->shape().dimensions(0),
scatter_indices->shape().dimensions(1));
for (int i = 0; i < scatter_indices->shape().dimensions(0); ++i) {
for (int j = 0; j < scatter_indices->shape().dimensions(1); ++j) {
out_of_ound_array(i, j) = scatter_shape.dimensions(j);
out_of_bound_array(i, j) = scatter_shape.dimensions(j);
}
}
return parent->AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::CreateR2FromArray2D<int>(out_of_ound_array)));
LiteralUtil::CreateR2FromArray2D<int>(out_of_bound_array)));
}

// Computation for sorting the scalar scatter indices and updates together
Expand Down Expand Up @@ -139,7 +139,7 @@ static std::vector<HloInstruction*> SortIndicesAndUpdates(

auto* sorting = parent->AddInstruction(HloInstruction::CreateSort(
ShapeUtil::MakeTupleShape(sort_shapes), 0, sort_operands, comparison,
false /*is_stable*/));
/*is_stable=*/false));
auto* sorted_scalar_indices =
parent->AddInstruction(HloInstruction::CreateGetTupleElement(
scalar_indices->shape(), sorting, 0));
Expand Down Expand Up @@ -184,10 +184,7 @@ static StatusOr<HloInstruction*> CreateScanWithIndices(
int64_t num_updates = updates_shape.dimensions(0);

// Calculate the number of iterations needed (log_2(n))
int64_t log_n = static_cast<int64_t>(std::ceil(std::log2(num_updates)));

// Placeholder for offset calculation (2^d)
int64_t offset;
int64_t log_n = Log2Ceiling(static_cast<uint64_t>(num_updates));

// Start to traverse
HloInstruction* prev_updates = updates;
Expand All @@ -198,7 +195,7 @@ static StatusOr<HloInstruction*> CreateScanWithIndices(
std::vector<int64_t> strides = {1};

for (int64_t iteration = 0; iteration < log_n; ++iteration) {
offset = 1 << iteration;
int64_t offset = static_cast<int64_t>(1) << iteration;
std::vector<int64_t> end_indices = {num_updates - offset};

auto shifted_updates_shape = ShapeUtil::MakeShape(
Expand Down Expand Up @@ -384,19 +381,23 @@ StatusOr<HloInstruction*> ScatterDeterminismExpander::ExpandInstruction(
}

namespace {
void RecursivelyGetInputDependencies(
const HloInstruction* instruction,
std::unordered_set<const HloInstruction*>& dependencies) {
void RecursivelyGetInputParamNumbers(
const HloInstruction* instruction, std::vector<int64_t>& param_numbers,
absl::flat_hash_set<const HloInstruction*>& visited) {
if (!visited.emplace(instruction).second) {
return;
}

if (instruction->opcode() == HloOpcode::kParameter) {
dependencies.emplace(instruction);
param_numbers.push_back(instruction->parameter_number());
return;
}
for (HloInstruction* operand : instruction->operands()) {
RecursivelyGetInputDependencies(operand, dependencies);
RecursivelyGetInputParamNumbers(operand, param_numbers, visited);
}
}

// Check if the every output of the computation only depends on the
// Check if every output of the scatter computation only depends on the
// corresponding operand and updates
bool CheckOutputDependency(HloComputation* to_apply, int operand_size) {
HloInstruction* root = to_apply->root_instruction();
Expand All @@ -408,18 +409,16 @@ bool CheckOutputDependency(HloComputation* to_apply, int operand_size) {
// traverse the tuple output of the computation
for (int i = 0; i < operand_size; ++i) {
const HloInstruction* output = root->operand(i);
std::unordered_set<const HloInstruction*> input_dependencies;
RecursivelyGetInputDependencies(output, input_dependencies);
std::vector<int64_t> param_numbers;
absl::flat_hash_set<const HloInstruction*> visited;
RecursivelyGetInputParamNumbers(output, param_numbers, visited);
// The input dependencies can be at most 2
if (input_dependencies.size() > 2) {
if (param_numbers.size() > 2) {
return false;
}
if (input_dependencies.size() == 2) {
for (const HloInstruction* input : input_dependencies) {
int64_t param_number = input->parameter_number();
if (param_number != i && param_number != operand_size + i) {
return false;
}
for (int64_t param_number : param_numbers) {
if (param_number != i && param_number != operand_size + i) {
return false;
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion xla/service/scatter_determinism_expander_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ TEST_F(ScatterDeterminismExpanderTest, ScatterAddHloVerificationTest) {
)";

RunAndFilecheckHloRewrite(kModuleStr, ScatterDeterminismExpander(),
kExpectedPattern, nullptr, nullptr);
kExpectedPattern);
}

TEST_F(ScatterDeterminismExpanderTest, ScatterAddOutOfBoundCorrectnessTest) {
Expand Down

0 comments on commit 2ddc9e6

Please sign in to comment.