Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Cleanup] Use HloPredicateIs(Not)Op #20825

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 7 additions & 9 deletions xla/service/all_reduce_reassociate.cc
Original file line number Diff line number Diff line change
Expand Up @@ -86,25 +86,23 @@ HloInstruction* LookThroughForAllReduce(HloInstruction* instr,
const Literal& reduction_identity) {
// Match reduce-scatter pattern. Support only the non-formatted case at the
// moment.
if (instr->opcode() == HloOpcode::kDynamicSlice) {
if (HloPredicateIsOp<HloOpcode::kDynamicSlice>(instr)) {
// Dynamic-slice to be matched needs to be immediately using an AllReduce.
if (instr->operand(0)->opcode() != HloOpcode::kAllReduce ||
instr->operand(0)->user_count() != 1 || instr->user_count() != 1) {
return nullptr;
}
return instr;
}
while (instr->opcode() != HloOpcode::kAllReduce) {
while (HloPredicateIsNotOp<HloOpcode::kAllReduce>(instr)) {
if (instr->user_count() != 1) {
return nullptr;
}
if (instr->opcode() != HloOpcode::kReshape &&
instr->opcode() != HloOpcode::kPad &&
instr->opcode() != HloOpcode::kSlice &&
instr->opcode() != HloOpcode::kConvert) {
if (HloPredicateIsNotOp<HloOpcode::kReshape, HloOpcode::kPad,
HloOpcode::kSlice, HloOpcode::kConvert>(instr)) {
return nullptr;
}
if (instr->opcode() == HloOpcode::kPad) {
if (HloPredicateIsOp<HloOpcode::kPad>(instr)) {
if (!instr->operand(1)->IsConstant()) {
return nullptr;
}
Expand Down Expand Up @@ -223,7 +221,7 @@ absl::StatusOr<bool> AllReduceReassociate::Run(
continue;
}
if (lhs->opcode() != rhs->opcode() ||
(lhs->opcode() == HloOpcode::kDynamicSlice &&
(HloPredicateIsOp<HloOpcode::kDynamicSlice>(lhs) &&
!ShapeUtil::Compatible(lhs->operand(0)->shape(),
rhs->operand(0)->shape()))) {
continue;
Expand All @@ -232,7 +230,7 @@ absl::StatusOr<bool> AllReduceReassociate::Run(
HloAllReduceInstruction* ar1 = nullptr;
bool reduce_scatter_pattern_match = false;
// Check Dynamic-slice pattern is identical
if (lhs->opcode() == HloOpcode::kDynamicSlice) {
if (HloPredicateIsOp<HloOpcode::kDynamicSlice>(lhs)) {
HloInstruction* original_rhs_operand = rhs->mutable_operand(0);
TF_RETURN_IF_ERROR(rhs->ReplaceOperandWith(0, lhs->mutable_operand(0)));
if (!lhs->Identical(*rhs)) {
Expand Down
4 changes: 2 additions & 2 deletions xla/service/all_reduce_simplifier.cc
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,8 @@ absl::StatusOr<bool> AllReduceSimplifier::Run(
for (auto computation : module->computations(execution_threads)) {
for (HloInstruction* inst : computation->MakeInstructionPostOrder()) {
// AllGather and ReduceScatter with the same input and output shape
if ((inst->opcode() == HloOpcode::kAllGather ||
inst->opcode() == HloOpcode::kReduceScatter) &&
if ((HloPredicateIsOp<HloOpcode::kAllGather, HloOpcode::kReduceScatter>(
inst)) &&
ShapeUtil::Compatible(inst->shape(), inst->operand(0)->shape())) {
changed = true;
TF_RETURN_IF_ERROR(
Expand Down
6 changes: 3 additions & 3 deletions xla/service/buffer_assignment.cc
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,7 @@ static const HloInstruction* GetEntryParameterInstruction(
for (const auto& p : alloc.assigned_buffers()) {
const HloValue* value = p.first;
const HloInstruction* instr = value->instruction();
if (instr->opcode() == HloOpcode::kParameter &&
if (HloPredicateIsOp<HloOpcode::kParameter>(instr) &&
instr->parent() == instr->GetModule()->entry_computation()) {
return instr;
}
Expand Down Expand Up @@ -1047,7 +1047,7 @@ std::string BufferAssignment::ToVerboseString(
buf_strs.push_back(absl::StrCat(
"\n\t\tOperator: ", xla::OpMetadataToString(instr->metadata())));
}
if (instr->opcode() == HloOpcode::kParameter &&
if (HloPredicateIsOp<HloOpcode::kParameter>(instr) &&
(instr->parent() == instr->GetModule()->entry_computation())) {
// Special case on entry parameters as they sometimes have hundreds of
// indices in their shapes, and overwhelm the output.
Expand Down Expand Up @@ -1483,7 +1483,7 @@ absl::Status BufferAssigner::AssignSingleHloBuffer(

const HloInstruction* instruction = value->instruction();
const bool is_entry_parameter =
instruction->opcode() == HloOpcode::kParameter &&
HloPredicateIsOp<HloOpcode::kParameter>(instruction) &&
instruction->parent() == instruction->GetModule()->entry_computation();

if (is_entry_parameter) {
Expand Down
15 changes: 8 additions & 7 deletions xla/service/collective_ops_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -167,11 +167,12 @@ absl::StatusOr<CollectiveOpGroupMode> GetCollectiveOpGroupMode(

absl::StatusOr<bool> GetCollectiveUseGlobalDeviceIds(
const HloInstruction* hlo) {
const bool is_all_reduce = (hlo->opcode() == HloOpcode::kAllReduce ||
hlo->opcode() == HloOpcode::kAllReduceStart ||
hlo->opcode() == HloOpcode::kReduceScatter);
const bool is_all_gather = (hlo->opcode() == HloOpcode::kAllGather ||
hlo->opcode() == HloOpcode::kAllGatherStart);
const bool is_all_reduce =
(HloPredicateIsOp<HloOpcode::kAllReduce, HloOpcode::kAllReduceStart,
HloOpcode::kReduceScatter>(hlo));
const bool is_all_gather =
(HloPredicateIsOp<HloOpcode::kAllGather, HloOpcode::kAllGatherStart>(
hlo));
if (!is_all_reduce && !is_all_gather) {
return absl::InvalidArgumentError(
"GetReplicaGroupCountAndSize only supports AllReduce and AllGather.");
Expand Down Expand Up @@ -745,7 +746,7 @@ bool IsCollective(const HloInstruction* instruction) {
if (IsNonFusionCollective(instruction)) {
return true;
}
if (instruction->opcode() == HloOpcode::kFusion &&
if (HloPredicateIsOp<HloOpcode::kFusion>(instruction) &&
instruction->IsCustomFusion()) {
for (const auto* inner_inst : instruction->fused_instructions()) {
if (IsCollective(inner_inst)) {
Expand All @@ -757,7 +758,7 @@ bool IsCollective(const HloInstruction* instruction) {
}

HloInstruction* IsOrHasCollectiveWithChannelId(HloInstruction* instruction) {
if (instruction->opcode() == HloOpcode::kFusion) {
if (HloPredicateIsOp<HloOpcode::kFusion>(instruction)) {
for (auto* inner_inst : instruction->fused_instructions()) {
if (IsOrHasCollectiveWithChannelId(inner_inst) != nullptr) {
return inner_inst;
Expand Down
12 changes: 6 additions & 6 deletions xla/service/conditional_code_motion_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -359,9 +359,9 @@ ENTRY main {
for (int i = 0; i < on_false->root_instruction()->operand_count(); ++i) {
const HloInstruction* root_operand =
on_false->root_instruction()->operand(i);
if (root_operand->opcode() == HloOpcode::kAdd) {
if (HloPredicateIsOp<HloOpcode::kAdd>(root_operand)) {
on_false_add_idx = i;
} else if (root_operand->opcode() == HloOpcode::kSubtract) {
} else if (HloPredicateIsOp<HloOpcode::kSubtract>(root_operand)) {
on_false_sub_idx = i;
}
}
Expand Down Expand Up @@ -425,9 +425,9 @@ ENTRY main {
for (int i = 0; i < on_false->root_instruction()->operand_count(); ++i) {
const HloInstruction* root_operand =
on_false->root_instruction()->operand(i);
if (root_operand->opcode() == HloOpcode::kGetTupleElement) {
if (HloPredicateIsOp<HloOpcode::kGetTupleElement>(root_operand)) {
on_false_gte_idx = i;
} else if (root_operand->opcode() == HloOpcode::kConstant) {
} else if (HloPredicateIsOp<HloOpcode::kConstant>(root_operand)) {
on_false_const_idx = i;
}
}
Expand Down Expand Up @@ -501,9 +501,9 @@ ENTRY main {
for (int i = 0; i < on_false->root_instruction()->operand_count(); ++i) {
const HloInstruction* root_operand =
on_false->root_instruction()->operand(i);
if (root_operand->opcode() == HloOpcode::kConstant) {
if (HloPredicateIsOp<HloOpcode::kConstant>(root_operand)) {
on_false_const_idx = i;
} else if (root_operand->opcode() == HloOpcode::kGetTupleElement) {
} else if (HloPredicateIsOp<HloOpcode::kGetTupleElement>(root_operand)) {
on_false_gte_idx = i;
}
}
Expand Down
8 changes: 4 additions & 4 deletions xla/service/conditional_simplifier.cc
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ absl::StatusOr<bool> TryRemoveUnusedConditionalOperands(
for (HloInstruction* user : param->users()) {
// If the user is not a get tuple element, assume it is unsafe to remove
// elements from the tuple.
if (user->opcode() != HloOpcode::kGetTupleElement) {
if (HloPredicateIsNotOp<HloOpcode::kGetTupleElement>(user)) {
return false;
}
tuple_indices_to_keep.insert(user->tuple_index());
Expand Down Expand Up @@ -215,7 +215,7 @@ bool RemoveUnusedTupleElements(HloInstruction* conditional_op) {
std::vector<bool> used_indices(old_tuple_shapes_size, false);
for (const HloInstruction* user : conditional_op->users()) {
// We only deal with the case where all users are GTE instructions.
if (user->opcode() != HloOpcode::kGetTupleElement) {
if (HloPredicateIsNotOp<HloOpcode::kGetTupleElement>(user)) {
VLOG(3) << "Skip RemoveUnusedTupleElements due to non-GTE user:\n"
<< user->ToShortString();
return false;
Expand Down Expand Up @@ -363,7 +363,7 @@ bool MergeDuplicateTupleElements(HloInstruction* conditional) {
}

for (const HloInstruction* user : conditional->users()) {
if (user->opcode() != HloOpcode::kGetTupleElement) {
if (HloPredicateIsNotOp<HloOpcode::kGetTupleElement>(user)) {
VLOG(3) << "Skip MergeDuplicateTupleElements due not all users are "
"kGetTupleElement:\n"
<< conditional->ToShortString();
Expand Down Expand Up @@ -614,7 +614,7 @@ absl::StatusOr<bool> ConditionalSimplifier::Run(
std::vector<HloInstruction*> conditional_ops;
for (auto* comp : module->computations(execution_threads)) {
for (auto* instr : comp->MakeInstructionPostOrder()) {
if (instr->opcode() == HloOpcode::kConditional) {
if (HloPredicateIsOp<HloOpcode::kConditional>(instr)) {
// Verifier wants a single send/recv with a given channel. This pass
// clones computations which can result in that getting violated.
if (InstructionCallsChannelInstructions(*instr)) {
Expand Down
29 changes: 13 additions & 16 deletions xla/service/dynamic_dimension_inference.cc
Original file line number Diff line number Diff line change
Expand Up @@ -946,7 +946,7 @@ absl::Status DynamicDimensionInferenceVisitor::HandleSetDimensionSize(
HloInstruction* hlo) {
bool dimension_is_static = false;
const HloInstruction* size = hlo->operand(1);
if (size->opcode() == HloOpcode::kConstant) {
if (HloPredicateIsOp<HloOpcode::kConstant>(size)) {
// Check if we are setting a dimension size to its static size. If so,
// removes the dynamic dimension.
//
Expand Down Expand Up @@ -1332,7 +1332,7 @@ absl::Status DynamicDimensionInferenceVisitor::HandleReshape(
return false;
}
VLOG(3) << "Found " << found_dims.size() << "\n";
if (op->opcode() == HloOpcode::kReshape) {
if (HloPredicateIsOp<HloOpcode::kReshape>(op)) {
for (auto op_dim_index : found_dims) {
auto orig_reshape_pair = find_reshape_group_pair(op, op_dim_index);
if (is_reverse_reshape_group_pair(op, orig_reshape_pair, hlo,
Expand Down Expand Up @@ -2259,7 +2259,7 @@ absl::Status DynamicDimensionInferenceVisitor::HandleWhile(
// hlo->while_condition()->parameter_instruction(0);
TF_ASSIGN_OR_RETURN(WhileUtil::MakeInstructionsLiveInResult result,
WhileUtil::MakeInstructionsLiveIn(hlo, operands_to_add));
TF_RET_CHECK(result.replacement_instr->opcode() == HloOpcode::kTuple);
TF_RET_CHECK(HloPredicateIsOp<HloOpcode::kTuple>(result.replacement_instr));
// WhileUtil creates a new while hlo and tuple. Update the dynamic size
// mapping for the newly created tuple.
HloInstruction* new_tuple_operand =
Expand Down Expand Up @@ -2367,7 +2367,7 @@ absl::Status DynamicDimensionInferenceVisitor::HandleWhile(
TF_RET_CHECK(!index.empty());
HloInstruction* gte =
result.replacement_instr->mutable_operand(index.front());
TF_RET_CHECK(gte->opcode() == HloOpcode::kGetTupleElement);
TF_RET_CHECK(HloPredicateIsOp<HloOpcode::kGetTupleElement>(gte));
TF_RET_CHECK(gte->operand(0) == hlo);
ShapeUtil::GetMutableSubshape(gte->mutable_shape(),
ShapeIndexView(index).subspan(1))
Expand Down Expand Up @@ -2432,35 +2432,32 @@ absl::StatusOr<bool> DynamicDimensionInferenceVisitor::RequiresPadToStatic(
auto uses =
dataflow_analysis_.GetValueDefinedAt(instr, shape_index).GetUses();
for (const auto& use : uses) {
if (use.instruction->opcode() == HloOpcode::kAsyncStart ||
use.instruction->opcode() == HloOpcode::kAsyncUpdate ||
use.instruction->opcode() == HloOpcode::kAsyncDone ||
use.instruction->opcode() == HloOpcode::kCall ||
use.instruction->opcode() == HloOpcode::kTuple ||
use.instruction->opcode() == HloOpcode::kGetTupleElement ||
use.instruction->opcode() == HloOpcode::kConditional) {
if (HloPredicateIsOp<HloOpcode::kAsyncStart, HloOpcode::kAsyncUpdate,
HloOpcode::kAsyncDone, HloOpcode::kCall,
HloOpcode::kTuple, HloOpcode::kGetTupleElement,
HloOpcode::kConditional>(use.instruction)) {
// These uses do not require padding as they do not operate the data.
continue;
}
if (use.instruction->opcode() == HloOpcode::kWhile) {
if (HloPredicateIsOp<HloOpcode::kWhile>(use.instruction)) {
TF_RET_CHECK(use.operand_number == 0);
HloInstruction* root = use.instruction->while_body()->root_instruction();
if (parent_->HasDynamicDimension(root, use.operand_index)) {
return true;
}
continue;
}
if (use.instruction->opcode() == HloOpcode::kSetDimensionSize) {
if (HloPredicateIsOp<HloOpcode::kSetDimensionSize>(use.instruction)) {
// The dynamic size cannot itself be dynamic.
TF_RET_CHECK(use.operand_number == 0);
// SetDimensionSize will be removed, so the array must be padded if it
// is a user of the array.
return true;
}
if (use.instruction->opcode() == HloOpcode::kGetDimensionSize) {
if (HloPredicateIsOp<HloOpcode::kGetDimensionSize>(use.instruction)) {
return true;
}
if (use.instruction->opcode() != HloOpcode::kCustomCall ||
if (HloPredicateIsNotOp<HloOpcode::kCustomCall>(use.instruction) ||
!use.instruction->IsCustomCall({"PadToStatic", "Sharding",
"SPMDShardToFullShape",
"SPMDFullToShardShape"})) {
Expand Down Expand Up @@ -2863,7 +2860,7 @@ bool DynamicDimensionInference::CanInfer(HloInstruction* hlo) {
// However, if there are called computations, we may need to run inference on
// them. Similarly, custom calls can do anything based on the user callbacks.
if (hlo->shape().is_static() && hlo->called_computations().empty() &&
hlo->opcode() != HloOpcode::kCustomCall) {
HloPredicateIsNotOp<HloOpcode::kCustomCall>(hlo)) {
return false;
}
// The dimensions of all operands must either be 1) not dynamic, or 2) have a
Expand Down
10 changes: 5 additions & 5 deletions xla/service/dynamic_padder_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ namespace m = ::xla::match;
namespace op = xla::testing::opcode_matchers;

OpDynamismSupport OpHasDynamismSupport(HloInstruction* hlo) {
if (hlo->opcode() != HloOpcode::kCustomCall) {
if (HloPredicateIsNotOp<HloOpcode::kCustomCall>(hlo)) {
return OpDynamismSupport::kNoSupport;
}
if (hlo->custom_call_target() == "OpWithDynamicLowering") {
Expand Down Expand Up @@ -591,7 +591,7 @@ ENTRY main {
HloInstruction* while_inst = nullptr;
for (HloInstruction* inst :
module_->entry_computation()->MakeInstructionPostOrder()) {
if (inst->opcode() == HloOpcode::kWhile) {
if (HloPredicateIsOp<HloOpcode::kWhile>(inst)) {
ASSERT_EQ(while_inst, nullptr)
<< "while_inst: " << while_inst->name() << ", inst: " << inst->name();
while_inst = inst;
Expand Down Expand Up @@ -674,7 +674,7 @@ ENTRY main {
module_ = GetHloModule(hlo_text);

auto op_supports_dynamism = [](HloInstruction* hlo) {
if (hlo->opcode() != HloOpcode::kCustomCall) {
if (HloPredicateIsNotOp<HloOpcode::kCustomCall>(hlo)) {
return OpDynamismSupport::kNoSupport;
}
if (hlo->custom_call_target() == "ComputeActivations" ||
Expand All @@ -697,15 +697,15 @@ ENTRY main {

for (HloComputation* computation : module_->computations()) {
for (HloInstruction* instruction : computation->instructions()) {
if (instruction->opcode() == HloOpcode::kCustomCall) {
if (HloPredicateIsOp<HloOpcode::kCustomCall>(instruction)) {
EXPECT_NE(instruction->custom_call_target(), "PadToStatic");
EXPECT_NE(instruction->custom_call_target(), "SliceToDynamic");
if (instruction->custom_call_target() == "ComputeActivations") {
EXPECT_TRUE(instruction->operand(1)->shape().is_dynamic());
} else if (instruction->custom_call_target() == "ApplyGradients") {
EXPECT_TRUE(instruction->operand(1)->shape().is_dynamic());
}
} else if (instruction->opcode() == HloOpcode::kWhile) {
} else if (HloPredicateIsOp<HloOpcode::kWhile>(instruction)) {
const Shape& shape = instruction->shape();
EXPECT_TRUE(shape.tuple_shapes(1).is_dynamic());
EXPECT_TRUE(shape.tuple_shapes(3).is_dynamic());
Expand Down
4 changes: 2 additions & 2 deletions xla/service/gather_expander_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ ENTRY main {

HloInstruction* while_instr = nullptr;
for (auto* instr : module->entry_computation()->instructions()) {
if (instr->opcode() == HloOpcode::kWhile) {
if (HloPredicateIsOp<HloOpcode::kWhile>(instr)) {
ASSERT_EQ(while_instr, nullptr)
<< "Expected exactly one while instruction in the entry computation "
"after gather expansion";
Expand Down Expand Up @@ -159,7 +159,7 @@ ENTRY main {

HloInstruction* while_instr = nullptr;
for (auto* instr : module->entry_computation()->instructions()) {
if (instr->opcode() == HloOpcode::kWhile) {
if (HloPredicateIsOp<HloOpcode::kWhile>(instr)) {
ASSERT_EQ(while_instr, nullptr)
<< "Expected exactly one while instruction in the entry computation "
"after gather expansion";
Expand Down
2 changes: 1 addition & 1 deletion xla/service/hlo_domain_remover.cc
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ absl::StatusOr<int64_t> HloDomainRemover::RemoveExitDomains(
// users vector could be changed during the loop(e.g. ReplaceAllUsesWith).
const std::vector<HloInstruction*> users(instruction->users());
for (HloInstruction* user : users) {
if (user->opcode() == HloOpcode::kDomain &&
if (HloPredicateIsOp<HloOpcode::kDomain>(user) &&
user->user_side_metadata().Kind() == domain_kind &&
user->operand_side_metadata().Kind() == domain_kind) {
VLOG(5) << "Removing exit domain " << user->name();
Expand Down
2 changes: 1 addition & 1 deletion xla/service/hlo_domain_verifier.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ absl::Status HloDomainVerifier::RunContext::PopulateDomainKinds(
for (HloComputation* computation :
module_->computations(execution_threads)) {
for (HloInstruction* instruction : computation->instructions()) {
if (instruction->opcode() == HloOpcode::kDomain) {
if (HloPredicateIsOp<HloOpcode::kDomain>(instruction)) {
TF_RET_CHECK(instruction->user_side_metadata().Kind() ==
instruction->operand_side_metadata().Kind())
<< instruction->ToString();
Expand Down
7 changes: 4 additions & 3 deletions xla/service/latency_hiding_scheduler_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -62,13 +62,13 @@ bool MaxConcurrentCollectivePermutesBelowThreshold(
int max_concurrent_collective_permutes = 0;
int num_concurrent_collective_permutes = 0;
for (HloInstruction* instruction : instruction_sequence) {
if (instruction->opcode() == HloOpcode::kCollectivePermuteStart) {
if (HloPredicateIsOp<HloOpcode::kCollectivePermuteStart>(instruction)) {
num_concurrent_collective_permutes += 1;
max_concurrent_collective_permutes =
std::max(max_concurrent_collective_permutes,
num_concurrent_collective_permutes);
}
if (instruction->opcode() == HloOpcode::kCollectivePermuteDone) {
if (HloPredicateIsOp<HloOpcode::kCollectivePermuteDone>(instruction)) {
num_concurrent_collective_permutes -= 1;
}
}
Expand Down Expand Up @@ -125,7 +125,8 @@ class TestLatencyEstimator : public LatencyEstimator {
? kMediumCost
: kLowCost * ShapeUtil::ElementsIn(instr->shape());
}
if (instr->IsOutputFusion() || instr->opcode() == HloOpcode::kConvolution) {
if (instr->IsOutputFusion() ||
HloPredicateIsOp<HloOpcode::kConvolution>(instr)) {
return instr->shape().IsTuple()
? kHighCost
: kMediumCost * ShapeUtil::ElementsIn(instr->shape());
Expand Down
Loading